From c873d87ac815c1c1d9e695474c7948909b7b0112 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Thu, 13 Jun 2024 14:56:08 +0100 Subject: [PATCH] Add option to flush queue instead of waiting for completion. (#1864) * Make sync_type an option on sync instead of adding submit --- backend-comparison/benches/autodiff.rs | 7 ++- backend-comparison/benches/binary.rs | 7 ++- backend-comparison/benches/conv2d.rs | 7 ++- .../benches/conv_transpose2d.rs | 7 ++- backend-comparison/benches/custom_gelu.rs | 3 +- backend-comparison/benches/data.rs | 9 ++-- backend-comparison/benches/load_record.rs | 3 +- backend-comparison/benches/matmul.rs | 7 ++- backend-comparison/benches/max_pool2d.rs | 7 ++- backend-comparison/benches/unary.rs | 7 ++- crates/burn-autodiff/src/backend.rs | 5 +- crates/burn-candle/src/backend.rs | 35 ++++++++------ crates/burn-common/src/lib.rs | 3 ++ crates/burn-common/src/sync_type.rs | 8 ++++ crates/burn-compute/src/channel/base.rs | 6 +-- crates/burn-compute/src/channel/cell.rs | 5 +- crates/burn-compute/src/channel/mpsc.rs | 17 +++---- crates/burn-compute/src/channel/mutex.rs | 5 +- crates/burn-compute/src/client.rs | 6 +-- crates/burn-compute/src/server.rs | 4 +- .../burn-compute/src/tune/tune_benchmark.rs | 4 +- crates/burn-compute/tests/dummy/server.rs | 4 +- crates/burn-compute/tests/integration_test.rs | 6 ++- crates/burn-cuda/src/compute/server.rs | 14 ++++-- crates/burn-fusion/src/backend.rs | 6 +-- crates/burn-jit/src/backend.rs | 6 +-- crates/burn-tch/src/backend.rs | 29 ++++++------ crates/burn-tensor/src/tensor/backend/base.rs | 3 +- crates/burn-wgpu/src/compute/server.rs | 46 +++++++++---------- 29 files changed, 168 insertions(+), 108 deletions(-) create mode 100644 crates/burn-common/src/sync_type.rs diff --git a/backend-comparison/benches/autodiff.rs b/backend-comparison/benches/autodiff.rs index 0d14af026c..9c5d130ed1 100644 --- a/backend-comparison/benches/autodiff.rs +++ b/backend-comparison/benches/autodiff.rs @@ -7,7 +7,10 @@ use burn::{ Distribution, Tensor, }, }; -use burn_common::benchmark::{run_benchmark, Benchmark}; +use burn_common::{ + benchmark::{run_benchmark, Benchmark}, + sync_type::SyncType, +}; pub struct AutodiffOverheadBenchmark { config: nn::LstmConfig, @@ -47,7 +50,7 @@ impl Benchmark for AutodiffOverheadBenchmark { } fn sync(&self) { - B::sync(&self.device) + B::sync(&self.device, SyncType::Wait) } } diff --git a/backend-comparison/benches/binary.rs b/backend-comparison/benches/binary.rs index 16db7a1853..b54974ed87 100644 --- a/backend-comparison/benches/binary.rs +++ b/backend-comparison/benches/binary.rs @@ -1,6 +1,9 @@ use backend_comparison::persistence::save; use burn::tensor::{backend::Backend, Distribution, Shape, Tensor}; -use burn_common::benchmark::{run_benchmark, Benchmark}; +use burn_common::{ + benchmark::{run_benchmark, Benchmark}, + sync_type::SyncType, +}; pub struct BinaryBenchmark { shape: Shape, @@ -31,7 +34,7 @@ impl Benchmark for BinaryBenchmark { } fn sync(&self) { - B::sync(&self.device) + B::sync(&self.device, SyncType::Wait); } } diff --git a/backend-comparison/benches/conv2d.rs b/backend-comparison/benches/conv2d.rs index 343cecc543..527eaf96b2 100644 --- a/backend-comparison/benches/conv2d.rs +++ b/backend-comparison/benches/conv2d.rs @@ -2,7 +2,10 @@ use backend_comparison::persistence::save; use burn::tensor::{ backend::Backend, module::conv2d, ops::ConvOptions, Distribution, Shape, Tensor, }; -use burn_common::benchmark::{run_benchmark, Benchmark}; +use burn_common::{ + benchmark::{run_benchmark, Benchmark}, + sync_type::SyncType, +}; pub struct Conv2dBenchmark { input_shape: Shape<4>, @@ -48,7 +51,7 @@ impl Benchmark for Conv2dBenchmark { } fn sync(&self) { - B::sync(&self.device) + B::sync(&self.device, SyncType::Wait) } } diff --git a/backend-comparison/benches/conv_transpose2d.rs b/backend-comparison/benches/conv_transpose2d.rs index 13ad6e1c44..f3879b159a 100644 --- a/backend-comparison/benches/conv_transpose2d.rs +++ b/backend-comparison/benches/conv_transpose2d.rs @@ -3,7 +3,10 @@ use burn::tensor::{ backend::Backend, module::conv_transpose2d, ops::ConvTransposeOptions, Distribution, Shape, Tensor, }; -use burn_common::benchmark::{run_benchmark, Benchmark}; +use burn_common::{ + benchmark::{run_benchmark, Benchmark}, + sync_type::SyncType, +}; pub struct ConvTranspose2dBenchmark { input_shape: Shape<4>, @@ -49,7 +52,7 @@ impl Benchmark for ConvTranspose2dBenchmark { } fn sync(&self) { - B::sync(&self.device) + B::sync(&self.device, SyncType::Wait) } } diff --git a/backend-comparison/benches/custom_gelu.rs b/backend-comparison/benches/custom_gelu.rs index 8f42eb11ea..f7333c4e43 100644 --- a/backend-comparison/benches/custom_gelu.rs +++ b/backend-comparison/benches/custom_gelu.rs @@ -2,6 +2,7 @@ use backend_comparison::persistence::save; use burn::backend::Autodiff; use burn::tensor::{backend::Backend, Distribution, Shape, Tensor}; use burn_common::benchmark::{run_benchmark, Benchmark}; +use burn_common::sync_type::SyncType; use core::f64::consts::SQRT_2; use derive_new::new; @@ -68,7 +69,7 @@ impl Benchmark for CustomGeluBenchmark { } fn sync(&self) { - B::sync(&self.device) + B::sync(&self.device, SyncType::Wait) } fn num_samples(&self) -> usize { diff --git a/backend-comparison/benches/data.rs b/backend-comparison/benches/data.rs index afe625bdd0..49c9a06e0b 100644 --- a/backend-comparison/benches/data.rs +++ b/backend-comparison/benches/data.rs @@ -1,6 +1,9 @@ use backend_comparison::persistence::save; use burn::tensor::{backend::Backend, Data, Distribution, Shape, Tensor}; -use burn_common::benchmark::{run_benchmark, Benchmark}; +use burn_common::{ + benchmark::{run_benchmark, Benchmark}, + sync_type::SyncType, +}; use derive_new::new; #[derive(new)] @@ -29,7 +32,7 @@ impl Benchmark for ToDataBenchmark { } fn sync(&self) { - B::sync(&self.device) + B::sync(&self.device, SyncType::Wait) } } @@ -66,7 +69,7 @@ impl Benchmark for FromDataBenchmark { } fn sync(&self) { - B::sync(&self.device) + B::sync(&self.device, SyncType::Wait) } } diff --git a/backend-comparison/benches/load_record.rs b/backend-comparison/benches/load_record.rs index 5dfe2e641e..5ab2d1d64b 100644 --- a/backend-comparison/benches/load_record.rs +++ b/backend-comparison/benches/load_record.rs @@ -3,6 +3,7 @@ use burn::tensor::backend::Backend; use burn::tensor::Device; use burn::{config::Config, module::Module, nn}; use burn_common::benchmark::{run_benchmark, Benchmark}; +use burn_common::sync_type::SyncType; use derive_new::new; #[derive(Module, Debug)] @@ -93,7 +94,7 @@ impl Benchmark for LoadRecordBenchmark { } fn sync(&self) { - B::sync(&self.device) + B::sync(&self.device, SyncType::Wait) } } diff --git a/backend-comparison/benches/matmul.rs b/backend-comparison/benches/matmul.rs index 260be58caf..15941a61bb 100644 --- a/backend-comparison/benches/matmul.rs +++ b/backend-comparison/benches/matmul.rs @@ -1,6 +1,9 @@ use backend_comparison::persistence::save; use burn::tensor::{backend::Backend, Distribution, Shape, Tensor}; -use burn_common::benchmark::{run_benchmark, Benchmark}; +use burn_common::{ + benchmark::{run_benchmark, Benchmark}, + sync_type::SyncType, +}; use derive_new::new; #[derive(new)] @@ -37,7 +40,7 @@ impl Benchmark for MatmulBenchmark { } fn sync(&self) { - B::sync(&self.device) + B::sync(&self.device, SyncType::Wait) } } diff --git a/backend-comparison/benches/max_pool2d.rs b/backend-comparison/benches/max_pool2d.rs index 80e2ffab2b..40ade42fc4 100644 --- a/backend-comparison/benches/max_pool2d.rs +++ b/backend-comparison/benches/max_pool2d.rs @@ -1,6 +1,9 @@ use backend_comparison::persistence::save; use burn::tensor::{backend::Backend, module::max_pool2d, Distribution, Shape, Tensor}; -use burn_common::benchmark::{run_benchmark, Benchmark}; +use burn_common::{ + benchmark::{run_benchmark, Benchmark}, + sync_type::SyncType, +}; pub struct MaxPool2dBenchmark { shape: Shape<4>, @@ -37,7 +40,7 @@ impl Benchmark for MaxPool2dBenchmark { } fn sync(&self) { - B::sync(&self.device) + B::sync(&self.device, SyncType::Wait) } } diff --git a/backend-comparison/benches/unary.rs b/backend-comparison/benches/unary.rs index 98fa89c850..c381371b66 100644 --- a/backend-comparison/benches/unary.rs +++ b/backend-comparison/benches/unary.rs @@ -1,6 +1,9 @@ use backend_comparison::persistence::save; use burn::tensor::{backend::Backend, Distribution, Shape, Tensor}; -use burn_common::benchmark::{run_benchmark, Benchmark}; +use burn_common::{ + benchmark::{run_benchmark, Benchmark}, + sync_type::SyncType, +}; use derive_new::new; #[derive(new)] @@ -30,7 +33,7 @@ impl Benchmark for UnaryBenchmark { } fn sync(&self) { - B::sync(&self.device) + B::sync(&self.device, SyncType::Wait) } } diff --git a/crates/burn-autodiff/src/backend.rs b/crates/burn-autodiff/src/backend.rs index 0445cb784b..e5d887a8bc 100644 --- a/crates/burn-autodiff/src/backend.rs +++ b/crates/burn-autodiff/src/backend.rs @@ -5,6 +5,7 @@ use crate::{ tensor::AutodiffTensor, AutodiffBridge, }; +use burn_common::sync_type::SyncType; use burn_tensor::backend::{AutodiffBackend, Backend}; use core::marker::PhantomData; @@ -43,8 +44,8 @@ impl Backend for Autodiff { B::seed(seed) } - fn sync(device: &B::Device) { - B::sync(device); + fn sync(device: &B::Device, sync_type: SyncType) { + B::sync(device, sync_type) } } diff --git a/crates/burn-candle/src/backend.rs b/crates/burn-candle/src/backend.rs index f66a28bb29..8ed8b95d43 100644 --- a/crates/burn-candle/src/backend.rs +++ b/crates/burn-candle/src/backend.rs @@ -1,7 +1,7 @@ use std::marker::PhantomData; use burn_tensor::{ - backend::{Backend, DeviceId, DeviceOps}, + backend::{Backend, DeviceId, DeviceOps, SyncType}, Device, }; use candle_core::DeviceLocation; @@ -105,20 +105,25 @@ impl Backend for Candle { panic!("Manual seed not supported by Candle. ") } - fn sync(device: &Device) { - let device: candle_core::Device = (*device).into(); - - match device { - candle_core::Device::Cpu => (), - candle_core::Device::Cuda(device) => { - #[cfg(feature = "cuda")] - device.synchronize().unwrap(); + fn sync(device: &Device, sync_type: SyncType) { + match sync_type { + SyncType::Wait => { + let device: candle_core::Device = (*device).into(); + + match device { + candle_core::Device::Cpu => (), + candle_core::Device::Cuda(device) => { + #[cfg(feature = "cuda")] + device.synchronize().unwrap(); + } + candle_core::Device::Metal(device) => { + // For some reason, device.wait_until_completed() does not seem to work, + // and neither does writing and reading a value with into_data + panic!("Device synchronization unavailable with Metal device on Candle backend") + } + } } - candle_core::Device::Metal(device) => { - // For some reason, device.wait_until_completed() does not seem to work, - // and neither does writing and reading a value with into_data - panic!("Device synchronization unavailable with Metal device on Candle backend") - } - } + SyncType::Flush => (), // Nothhing to flush. + }; } } diff --git a/crates/burn-common/src/lib.rs b/crates/burn-common/src/lib.rs index 4a69c89be5..b5f5338983 100644 --- a/crates/burn-common/src/lib.rs +++ b/crates/burn-common/src/lib.rs @@ -25,6 +25,9 @@ pub mod benchmark; /// notation. pub mod reader; +/// Synchronization type module, used both by ComputeServer and Backends. +pub mod sync_type; + extern crate alloc; /// Network utilities. diff --git a/crates/burn-common/src/sync_type.rs b/crates/burn-common/src/sync_type.rs new file mode 100644 index 0000000000..afe484819d --- /dev/null +++ b/crates/burn-common/src/sync_type.rs @@ -0,0 +1,8 @@ +/// What kind of synchronization to use. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SyncType { + /// Submit all outstanding tasks to the task queue if any. + Flush, + /// Submit all tasks to the task queue and wait for all of them to complete. + Wait, +} diff --git a/crates/burn-compute/src/channel/base.rs b/crates/burn-compute/src/channel/base.rs index b11bdbb915..d74a23f9b3 100644 --- a/crates/burn-compute/src/channel/base.rs +++ b/crates/burn-compute/src/channel/base.rs @@ -3,7 +3,7 @@ use crate::{ storage::ComputeStorage, }; use alloc::vec::Vec; -use burn_common::reader::Reader; +use burn_common::{reader::Reader, sync_type::SyncType}; /// The ComputeChannel trait links the ComputeClient to the ComputeServer /// while ensuring thread-safety @@ -26,6 +26,6 @@ pub trait ComputeChannel: Clone + core::fmt::Debug + Send /// Executes the `kernel` over the given `bindings`. fn execute(&self, kernel: Server::Kernel, bindings: Vec>); - /// Wait for the completion of every task in the server. - fn sync(&self); + /// Perform some synchronization of commands on the server. + fn sync(&self, sync_type: SyncType); } diff --git a/crates/burn-compute/src/channel/cell.rs b/crates/burn-compute/src/channel/cell.rs index c41f895fdd..4d809694d4 100644 --- a/crates/burn-compute/src/channel/cell.rs +++ b/crates/burn-compute/src/channel/cell.rs @@ -4,6 +4,7 @@ use crate::storage::ComputeStorage; use alloc::sync::Arc; use alloc::vec::Vec; use burn_common::reader::Reader; +use burn_common::sync_type::SyncType; /// A channel using a [ref cell](core::cell::RefCell) to access the server with mutability. /// @@ -68,8 +69,8 @@ where .execute(kernel_description, bindings) } - fn sync(&self) { - self.server.borrow_mut().sync() + fn sync(&self, sync_type: SyncType) { + self.server.borrow_mut().sync(sync_type) } } diff --git a/crates/burn-compute/src/channel/mpsc.rs b/crates/burn-compute/src/channel/mpsc.rs index 6ca9e96a70..b0af33e018 100644 --- a/crates/burn-compute/src/channel/mpsc.rs +++ b/crates/burn-compute/src/channel/mpsc.rs @@ -3,7 +3,7 @@ use std::{ thread, }; -use burn_common::reader::Reader; +use burn_common::{reader::Reader, sync_type::SyncType}; use super::ComputeChannel; use crate::{ @@ -44,7 +44,7 @@ where Create(Vec, Callback>), Empty(usize, Callback>), ExecuteKernel(Server::Kernel, Vec>), - Sync(Callback<()>), + Sync(SyncType, Callback<()>), } impl MpscComputeChannel @@ -77,8 +77,8 @@ where Message::ExecuteKernel(kernel, bindings) => { server.execute(kernel, bindings); } - Message::Sync(callback) => { - server.sync(); + Message::Sync(sync_type, callback) => { + server.sync(sync_type); callback.send(()).unwrap(); } }; @@ -157,11 +157,12 @@ where .unwrap() } - fn sync(&self) { + fn sync(&self, sync_type: SyncType) { let (callback, response) = mpsc::channel(); - - self.state.sender.send(Message::Sync(callback)).unwrap(); - + self.state + .sender + .send(Message::Sync(sync_type, callback)) + .unwrap(); self.response(response) } } diff --git a/crates/burn-compute/src/channel/mutex.rs b/crates/burn-compute/src/channel/mutex.rs index e6db609040..141cfca7be 100644 --- a/crates/burn-compute/src/channel/mutex.rs +++ b/crates/burn-compute/src/channel/mutex.rs @@ -4,6 +4,7 @@ use crate::storage::ComputeStorage; use alloc::sync::Arc; use alloc::vec::Vec; use burn_common::reader::Reader; +use burn_common::sync_type::SyncType; use spin::Mutex; /// The MutexComputeChannel ensures thread-safety by locking the server @@ -59,7 +60,7 @@ where self.server.lock().execute(kernel, handles) } - fn sync(&self) { - self.server.lock().sync() + fn sync(&self, sync_type: SyncType) { + self.server.lock().sync(sync_type) } } diff --git a/crates/burn-compute/src/client.rs b/crates/burn-compute/src/client.rs index b8ea924340..cb1ec1a8ba 100644 --- a/crates/burn-compute/src/client.rs +++ b/crates/burn-compute/src/client.rs @@ -6,8 +6,8 @@ use crate::{ }; use alloc::vec::Vec; use alloc::{boxed::Box, sync::Arc}; -use burn_common::reader::Reader; use burn_common::stub::RwLock; +use burn_common::{reader::Reader, sync_type::SyncType}; /// The ComputeClient is the entry point to require tasks from the ComputeServer. /// It should be obtained for a specific device via the Compute struct. @@ -69,8 +69,8 @@ where } /// Wait for the completion of every task in the server. - pub fn sync(&self) { - self.channel.sync() + pub fn sync(&self, sync_type: SyncType) { + self.channel.sync(sync_type) } /// Executes the fastest kernel in the autotune operation, using (cached) runtime benchmarks diff --git a/crates/burn-compute/src/server.rs b/crates/burn-compute/src/server.rs index aa0360815f..cb495b34a8 100644 --- a/crates/burn-compute/src/server.rs +++ b/crates/burn-compute/src/server.rs @@ -4,7 +4,7 @@ use crate::{ tune::AutotuneKey, }; use alloc::vec::Vec; -use burn_common::reader::Reader; +use burn_common::{reader::Reader, sync_type::SyncType}; use core::fmt::Debug; /// The compute server is responsible for handling resources and computations over resources. @@ -46,7 +46,7 @@ where fn execute(&mut self, kernel: Self::Kernel, bindings: Vec>); /// Wait for the completion of every task in the server. - fn sync(&mut self); + fn sync(&mut self, command: SyncType); } /// Server handle containing the [memory handle](MemoryManagement::Handle). diff --git a/crates/burn-compute/src/tune/tune_benchmark.rs b/crates/burn-compute/src/tune/tune_benchmark.rs index 53a03558e9..e0fd6a2493 100644 --- a/crates/burn-compute/src/tune/tune_benchmark.rs +++ b/crates/burn-compute/src/tune/tune_benchmark.rs @@ -1,4 +1,5 @@ use burn_common::benchmark::Benchmark; +use burn_common::sync_type::SyncType; use crate::channel::ComputeChannel; use crate::client::ComputeClient; @@ -41,6 +42,7 @@ impl> Benchmark for TuneBenchmark { } fn sync(&self) { - self.client.sync(); + // For benchmarks - we need to wait for all tasks to complete before returning. + self.client.sync(SyncType::Wait); } } diff --git a/crates/burn-compute/tests/dummy/server.rs b/crates/burn-compute/tests/dummy/server.rs index c49d2e0679..452c1a5550 100644 --- a/crates/burn-compute/tests/dummy/server.rs +++ b/crates/burn-compute/tests/dummy/server.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use burn_common::reader::Reader; +use burn_common::{reader::Reader, sync_type::SyncType}; use burn_compute::{ memory_management::{simple::SimpleMemoryManagement, MemoryHandle, MemoryManagement}, server::{Binding, ComputeServer, Handle}, @@ -62,7 +62,7 @@ where kernel.compute(&mut resources); } - fn sync(&mut self) { + fn sync(&mut self, _: SyncType) { // Nothing to do with dummy backend. } } diff --git a/crates/burn-compute/tests/integration_test.rs b/crates/burn-compute/tests/integration_test.rs index 9d051f0cc8..7d4a14f67a 100644 --- a/crates/burn-compute/tests/integration_test.rs +++ b/crates/burn-compute/tests/integration_test.rs @@ -246,6 +246,8 @@ fn autotune_cache_different_keys_return_a_cache_miss() { #[serial] #[cfg(feature = "std")] fn autotune_cache_different_checksums_return_a_cache_miss() { + use burn_common::sync_type::SyncType; + type Runtime = ComputeRuntime; let runtime = Runtime::new(); let client = runtime.client(&DummyDevice, dummy::init_client); @@ -260,7 +262,7 @@ fn autotune_cache_different_checksums_return_a_cache_miss() { let cache_test_autotune_kernel_1 = dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_1, handles_1); client.autotune_execute(Box::new(cache_test_autotune_kernel_1)); - client.sync(); + client.sync(SyncType::Wait); // we use a second compute client in order to have freshly initialized autotune cache // and test invalidation of the cache when the checksum of the operation set is @@ -278,7 +280,7 @@ fn autotune_cache_different_checksums_return_a_cache_miss() { dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_2, handles_2); cache_test_autotune_kernel_2.generate_random_checksum = true; client.autotune_execute(Box::new(cache_test_autotune_kernel_2)); - client.sync(); + client.sync(SyncType::Wait); let obtained_resource = client.read(out_2.binding()); diff --git a/crates/burn-cuda/src/compute/server.rs b/crates/burn-cuda/src/compute/server.rs index c49cbb22ec..7218794c86 100644 --- a/crates/burn-cuda/src/compute/server.rs +++ b/crates/burn-cuda/src/compute/server.rs @@ -7,6 +7,7 @@ use burn_compute::{ use burn_cube::ir::CubeDim; use burn_cube::prelude::*; use burn_jit::JitAutotuneKey; +use burn_tensor::backend::SyncType; use cudarc::driver::sys::CUctx_st; use cudarc::driver::sys::CUfunc_st; use std::collections::HashMap; @@ -110,9 +111,16 @@ impl> ComputeServer for CudaServer { // self.memory_management.storage().perform_deallocations(); } - fn sync(&mut self) { - let ctx = self.get_context(); - ctx.sync(); + fn sync(&mut self, sync_type: SyncType) { + match sync_type { + // Synchronize the stream if waiting. + SyncType::Wait => { + let ctx = self.get_context(); + ctx.sync(); + } + // Nothing to do - all tasks are already submitted to the stream. + SyncType::Flush => (), + } } fn get_resource( diff --git a/crates/burn-fusion/src/backend.rs b/crates/burn-fusion/src/backend.rs index e5b8cb1a66..e794c9f2e1 100644 --- a/crates/burn-fusion/src/backend.rs +++ b/crates/burn-fusion/src/backend.rs @@ -2,7 +2,7 @@ use crate::{ client::FusionClient, stream::Context, FusionClientLocator, FusionTensor, PrecisionBridge, }; use burn_tensor::{ - backend::{Backend, DeviceOps}, + backend::{Backend, DeviceOps, SyncType}, ops::FloatTensor, repr::{OperationDescription, ReprBackend}, Device, @@ -45,10 +45,10 @@ impl Backend for Fusion { B::seed(seed); } - fn sync(device: &Self::Device) { + fn sync(device: &Self::Device, sync_type: SyncType) { let client = CLIENTS.client::(&device.clone()); client.drain(); - B::sync(device) + B::sync(device, sync_type); } fn ad_enabled() -> bool { diff --git a/crates/burn-jit/src/backend.rs b/crates/burn-jit/src/backend.rs index 8cd0631fc8..4a058bb25a 100644 --- a/crates/burn-jit/src/backend.rs +++ b/crates/burn-jit/src/backend.rs @@ -2,7 +2,7 @@ use crate::{ tensor::JitTensor, FloatElement, IntElement, JitAutotuneKey, JitRuntime, PrecisionBridge, }; use burn_compute::server::ComputeServer; -use burn_tensor::backend::Backend; +use burn_tensor::backend::{Backend, SyncType}; use rand::{rngs::StdRng, SeedableRng}; use std::{marker::PhantomData, sync::Mutex}; @@ -48,9 +48,9 @@ where false } - fn sync(device: &Self::Device) { + fn sync(device: &Self::Device, sync_type: SyncType) { let client = R::client(device); - client.sync(); + client.sync(sync_type); } } diff --git a/crates/burn-tch/src/backend.rs b/crates/burn-tch/src/backend.rs index bf436702f9..c2bb584aa6 100644 --- a/crates/burn-tch/src/backend.rs +++ b/crates/burn-tch/src/backend.rs @@ -2,7 +2,7 @@ use crate::PrecisionBridge; use super::element::TchElement; use super::TchTensor; -use burn_tensor::backend::{Backend, DeviceId, DeviceOps}; +use burn_tensor::backend::{Backend, DeviceId, DeviceOps, SyncType}; use burn_tensor::ops::IntTensorOps; use burn_tensor::{Int, Tensor}; @@ -114,19 +114,20 @@ impl Backend for LibTorch { "tch".to_string() } - fn sync(device: &Self::Device) { - match device { - LibTorchDevice::Cpu => (), - LibTorchDevice::Cuda(index) => { - tch::Cuda::synchronize(*index as i64); - } - _ => { - // When there is no explicit way to synchronize, we write and read one value to sync - Tensor::::from_primitive(>::int_zeros( - [1].into(), - device, - )) - .into_data(); + fn sync(device: &Self::Device, sync_type: SyncType) { + if sync_type == SyncType::Wait { + match device { + LibTorchDevice::Cpu => (), + LibTorchDevice::Cuda(index) => { + tch::Cuda::synchronize(*index as i64); + } + _ => { + // When there is no explicit way to synchronize, we write and read one value to sync + Tensor::::from_primitive( + >::int_zeros([1].into(), device), + ) + .into_data(); + } } } } diff --git a/crates/burn-tensor/src/tensor/backend/base.rs b/crates/burn-tensor/src/tensor/backend/base.rs index 182235df1d..0964fd602f 100644 --- a/crates/burn-tensor/src/tensor/backend/base.rs +++ b/crates/burn-tensor/src/tensor/backend/base.rs @@ -1,4 +1,5 @@ use alloc::string::String; +pub use burn_common::sync_type::SyncType; use crate::ops::*; use crate::tensor::Element; @@ -96,7 +97,7 @@ pub trait Backend: fn seed(seed: u64); /// Sync the backend, ensure that all computation are finished. - fn sync(_device: &Self::Device) {} + fn sync(_device: &Self::Device, _sync_type: SyncType) {} } /// Trait that allows a backend to support autodiff. diff --git a/crates/burn-wgpu/src/compute/server.rs b/crates/burn-wgpu/src/compute/server.rs index 7c52fa6e76..dd9fc302bd 100644 --- a/crates/burn-wgpu/src/compute/server.rs +++ b/crates/burn-wgpu/src/compute/server.rs @@ -8,7 +8,7 @@ use burn_compute::{ }; use burn_cube::prelude::*; use burn_jit::JitAutotuneKey; -use burn_tensor::Reader; +use burn_tensor::{backend::SyncType, Reader}; use hashbrown::HashMap; use wgpu::{ util::{BufferInitDescriptor, DeviceExt, StagingBelt}, @@ -60,23 +60,6 @@ where } } - fn submit(&mut self) { - self.staging_belt.finish(); - - let mut new_encoder = self - .device - .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); - core::mem::swap(&mut new_encoder, &mut self.encoder); - - self.queue.submit(Some(new_encoder.finish())); - self.tasks_count = 0; - - // Cleanup allocations and deallocations. - self.memory_management.storage().perform_deallocations(); - - self.staging_belt.recall(); - } - fn register_compute( &mut self, pipeline: Arc, @@ -150,7 +133,7 @@ where ); self.tasks_count += 1; - self.submit(); + self.sync(SyncType::Flush); BufferReader::new(buffer_dest) } @@ -304,12 +287,29 @@ where self.register_compute(pipeline, bind_group, work_group); if self.tasks_count >= self.tasks_max { - self.submit(); + self.sync(SyncType::Flush); } } - fn sync(&mut self) { - self.submit(); - self.device.poll(wgpu::Maintain::Wait); + fn sync(&mut self, sync_type: SyncType) { + // Flush commands to the queue. + self.staging_belt.finish(); + + let mut new_encoder = self + .device + .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + core::mem::swap(&mut new_encoder, &mut self.encoder); + + self.queue.submit(Some(new_encoder.finish())); + self.tasks_count = 0; + + // Cleanup allocations and deallocations. + self.memory_management.storage().perform_deallocations(); + + self.staging_belt.recall(); + + if sync_type == SyncType::Wait { + self.device.poll(wgpu::Maintain::Wait); + } } }