Skip to content

Commit

Permalink
Add option to flush queue instead of waiting for completion. (#1864)
Browse files Browse the repository at this point in the history
* Make sync_type an option on sync instead of adding submit
  • Loading branch information
ArthurBrussee committed Jun 13, 2024
1 parent 71bd5ef commit c873d87
Show file tree
Hide file tree
Showing 29 changed files with 168 additions and 108 deletions.
7 changes: 5 additions & 2 deletions backend-comparison/benches/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ use burn::{
Distribution, Tensor,
},
};
use burn_common::benchmark::{run_benchmark, Benchmark};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};

pub struct AutodiffOverheadBenchmark<B: AutodiffBackend> {
config: nn::LstmConfig,
Expand Down Expand Up @@ -47,7 +50,7 @@ impl<B: AutodiffBackend> Benchmark for AutodiffOverheadBenchmark<B> {
}

fn sync(&self) {
B::sync(&self.device)
B::sync(&self.device, SyncType::Wait)
}
}

Expand Down
7 changes: 5 additions & 2 deletions backend-comparison/benches/binary.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use backend_comparison::persistence::save;
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
use burn_common::benchmark::{run_benchmark, Benchmark};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};

pub struct BinaryBenchmark<B: Backend, const D: usize> {
shape: Shape<D>,
Expand Down Expand Up @@ -31,7 +34,7 @@ impl<B: Backend, const D: usize> Benchmark for BinaryBenchmark<B, D> {
}

fn sync(&self) {
B::sync(&self.device)
B::sync(&self.device, SyncType::Wait);
}
}

Expand Down
7 changes: 5 additions & 2 deletions backend-comparison/benches/conv2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ use backend_comparison::persistence::save;
use burn::tensor::{
backend::Backend, module::conv2d, ops::ConvOptions, Distribution, Shape, Tensor,
};
use burn_common::benchmark::{run_benchmark, Benchmark};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};

pub struct Conv2dBenchmark<B: Backend> {
input_shape: Shape<4>,
Expand Down Expand Up @@ -48,7 +51,7 @@ impl<B: Backend> Benchmark for Conv2dBenchmark<B> {
}

fn sync(&self) {
B::sync(&self.device)
B::sync(&self.device, SyncType::Wait)
}
}

Expand Down
7 changes: 5 additions & 2 deletions backend-comparison/benches/conv_transpose2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ use burn::tensor::{
backend::Backend, module::conv_transpose2d, ops::ConvTransposeOptions, Distribution, Shape,
Tensor,
};
use burn_common::benchmark::{run_benchmark, Benchmark};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};

pub struct ConvTranspose2dBenchmark<B: Backend> {
input_shape: Shape<4>,
Expand Down Expand Up @@ -49,7 +52,7 @@ impl<B: Backend> Benchmark for ConvTranspose2dBenchmark<B> {
}

fn sync(&self) {
B::sync(&self.device)
B::sync(&self.device, SyncType::Wait)
}
}

Expand Down
3 changes: 2 additions & 1 deletion backend-comparison/benches/custom_gelu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use backend_comparison::persistence::save;
use burn::backend::Autodiff;
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
use burn_common::benchmark::{run_benchmark, Benchmark};
use burn_common::sync_type::SyncType;
use core::f64::consts::SQRT_2;
use derive_new::new;

Expand Down Expand Up @@ -68,7 +69,7 @@ impl<B: Backend, const D: usize> Benchmark for CustomGeluBenchmark<B, D> {
}

fn sync(&self) {
B::sync(&self.device)
B::sync(&self.device, SyncType::Wait)
}

fn num_samples(&self) -> usize {
Expand Down
9 changes: 6 additions & 3 deletions backend-comparison/benches/data.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use backend_comparison::persistence::save;
use burn::tensor::{backend::Backend, Data, Distribution, Shape, Tensor};
use burn_common::benchmark::{run_benchmark, Benchmark};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};
use derive_new::new;

#[derive(new)]
Expand Down Expand Up @@ -29,7 +32,7 @@ impl<B: Backend, const D: usize> Benchmark for ToDataBenchmark<B, D> {
}

fn sync(&self) {
B::sync(&self.device)
B::sync(&self.device, SyncType::Wait)
}
}

Expand Down Expand Up @@ -66,7 +69,7 @@ impl<B: Backend, const D: usize> Benchmark for FromDataBenchmark<B, D> {
}

fn sync(&self) {
B::sync(&self.device)
B::sync(&self.device, SyncType::Wait)
}
}

Expand Down
3 changes: 2 additions & 1 deletion backend-comparison/benches/load_record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use burn::tensor::backend::Backend;
use burn::tensor::Device;
use burn::{config::Config, module::Module, nn};
use burn_common::benchmark::{run_benchmark, Benchmark};
use burn_common::sync_type::SyncType;
use derive_new::new;

#[derive(Module, Debug)]
Expand Down Expand Up @@ -93,7 +94,7 @@ impl<B: Backend> Benchmark for LoadRecordBenchmark<B> {
}

fn sync(&self) {
B::sync(&self.device)
B::sync(&self.device, SyncType::Wait)
}
}

Expand Down
7 changes: 5 additions & 2 deletions backend-comparison/benches/matmul.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use backend_comparison::persistence::save;
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
use burn_common::benchmark::{run_benchmark, Benchmark};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};
use derive_new::new;

#[derive(new)]
Expand Down Expand Up @@ -37,7 +40,7 @@ impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {
}

fn sync(&self) {
B::sync(&self.device)
B::sync(&self.device, SyncType::Wait)
}
}

Expand Down
7 changes: 5 additions & 2 deletions backend-comparison/benches/max_pool2d.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use backend_comparison::persistence::save;
use burn::tensor::{backend::Backend, module::max_pool2d, Distribution, Shape, Tensor};
use burn_common::benchmark::{run_benchmark, Benchmark};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};

pub struct MaxPool2dBenchmark<B: Backend> {
shape: Shape<4>,
Expand Down Expand Up @@ -37,7 +40,7 @@ impl<B: Backend> Benchmark for MaxPool2dBenchmark<B> {
}

fn sync(&self) {
B::sync(&self.device)
B::sync(&self.device, SyncType::Wait)
}
}

Expand Down
7 changes: 5 additions & 2 deletions backend-comparison/benches/unary.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use backend_comparison::persistence::save;
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
use burn_common::benchmark::{run_benchmark, Benchmark};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};
use derive_new::new;

#[derive(new)]
Expand Down Expand Up @@ -30,7 +33,7 @@ impl<B: Backend, const D: usize> Benchmark for UnaryBenchmark<B, D> {
}

fn sync(&self) {
B::sync(&self.device)
B::sync(&self.device, SyncType::Wait)
}
}

Expand Down
5 changes: 3 additions & 2 deletions crates/burn-autodiff/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{
tensor::AutodiffTensor,
AutodiffBridge,
};
use burn_common::sync_type::SyncType;
use burn_tensor::backend::{AutodiffBackend, Backend};
use core::marker::PhantomData;

Expand Down Expand Up @@ -43,8 +44,8 @@ impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
B::seed(seed)
}

fn sync(device: &B::Device) {
B::sync(device);
fn sync(device: &B::Device, sync_type: SyncType) {
B::sync(device, sync_type)
}
}

Expand Down
35 changes: 20 additions & 15 deletions crates/burn-candle/src/backend.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::marker::PhantomData;

use burn_tensor::{
backend::{Backend, DeviceId, DeviceOps},
backend::{Backend, DeviceId, DeviceOps, SyncType},
Device,
};
use candle_core::DeviceLocation;
Expand Down Expand Up @@ -105,20 +105,25 @@ impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
panic!("Manual seed not supported by Candle. ")
}

fn sync(device: &Device<Self>) {
let device: candle_core::Device = (*device).into();

match device {
candle_core::Device::Cpu => (),
candle_core::Device::Cuda(device) => {
#[cfg(feature = "cuda")]
device.synchronize().unwrap();
fn sync(device: &Device<Self>, sync_type: SyncType) {
match sync_type {
SyncType::Wait => {
let device: candle_core::Device = (*device).into();

match device {
candle_core::Device::Cpu => (),
candle_core::Device::Cuda(device) => {
#[cfg(feature = "cuda")]
device.synchronize().unwrap();
}
candle_core::Device::Metal(device) => {
// For some reason, device.wait_until_completed() does not seem to work,
// and neither does writing and reading a value with into_data
panic!("Device synchronization unavailable with Metal device on Candle backend")
}
}
}
candle_core::Device::Metal(device) => {
// For some reason, device.wait_until_completed() does not seem to work,
// and neither does writing and reading a value with into_data
panic!("Device synchronization unavailable with Metal device on Candle backend")
}
}
SyncType::Flush => (), // Nothing to flush.
};
}
}
3 changes: 3 additions & 0 deletions crates/burn-common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ pub mod benchmark;
/// notation.
pub mod reader;

/// Synchronization type module, used both by ComputeServer and Backends.
pub mod sync_type;

extern crate alloc;

/// Network utilities.
Expand Down
8 changes: 8 additions & 0 deletions crates/burn-common/src/sync_type.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
/// Selects how pending work should be synchronized.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SyncType {
    /// Hand any outstanding tasks over to the task queue, if there are some.
    Flush,
    /// Hand every task over to the task queue, then block until all of them have completed.
    Wait,
}
6 changes: 3 additions & 3 deletions crates/burn-compute/src/channel/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{
storage::ComputeStorage,
};
use alloc::vec::Vec;
use burn_common::reader::Reader;
use burn_common::{reader::Reader, sync_type::SyncType};

/// The ComputeChannel trait links the ComputeClient to the ComputeServer
/// while ensuring thread-safety
Expand All @@ -26,6 +26,6 @@ pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug + Send
/// Executes the `kernel` over the given `bindings`.
fn execute(&self, kernel: Server::Kernel, bindings: Vec<Binding<Server>>);

/// Wait for the completion of every task in the server.
fn sync(&self);
/// Perform some synchronization of commands on the server.
fn sync(&self, sync_type: SyncType);
}
5 changes: 3 additions & 2 deletions crates/burn-compute/src/channel/cell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::storage::ComputeStorage;
use alloc::sync::Arc;
use alloc::vec::Vec;
use burn_common::reader::Reader;
use burn_common::sync_type::SyncType;

/// A channel using a [ref cell](core::cell::RefCell) to access the server with mutability.
///
Expand Down Expand Up @@ -68,8 +69,8 @@ where
.execute(kernel_description, bindings)
}

fn sync(&self) {
self.server.borrow_mut().sync()
fn sync(&self, sync_type: SyncType) {
self.server.borrow_mut().sync(sync_type)
}
}

Expand Down
17 changes: 9 additions & 8 deletions crates/burn-compute/src/channel/mpsc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{
thread,
};

use burn_common::reader::Reader;
use burn_common::{reader::Reader, sync_type::SyncType};

use super::ComputeChannel;
use crate::{
Expand Down Expand Up @@ -44,7 +44,7 @@ where
Create(Vec<u8>, Callback<Handle<Server>>),
Empty(usize, Callback<Handle<Server>>),
ExecuteKernel(Server::Kernel, Vec<Binding<Server>>),
Sync(Callback<()>),
Sync(SyncType, Callback<()>),
}

impl<Server> MpscComputeChannel<Server>
Expand Down Expand Up @@ -77,8 +77,8 @@ where
Message::ExecuteKernel(kernel, bindings) => {
server.execute(kernel, bindings);
}
Message::Sync(callback) => {
server.sync();
Message::Sync(sync_type, callback) => {
server.sync(sync_type);
callback.send(()).unwrap();
}
};
Expand Down Expand Up @@ -157,11 +157,12 @@ where
.unwrap()
}

fn sync(&self) {
fn sync(&self, sync_type: SyncType) {
let (callback, response) = mpsc::channel();

self.state.sender.send(Message::Sync(callback)).unwrap();

self.state
.sender
.send(Message::Sync(sync_type, callback))
.unwrap();
self.response(response)
}
}
Expand Down
5 changes: 3 additions & 2 deletions crates/burn-compute/src/channel/mutex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::storage::ComputeStorage;
use alloc::sync::Arc;
use alloc::vec::Vec;
use burn_common::reader::Reader;
use burn_common::sync_type::SyncType;
use spin::Mutex;

/// The MutexComputeChannel ensures thread-safety by locking the server
Expand Down Expand Up @@ -59,7 +60,7 @@ where
self.server.lock().execute(kernel, handles)
}

fn sync(&self) {
self.server.lock().sync()
fn sync(&self, sync_type: SyncType) {
self.server.lock().sync(sync_type)
}
}
Loading

0 comments on commit c873d87

Please sign in to comment.