From 2e060dc1bec4fb40765bc2ae4ce8352d02d42fdd Mon Sep 17 00:00:00 2001 From: Piotr Stankiewicz Date: Thu, 14 Mar 2024 11:34:20 +0100 Subject: [PATCH 1/2] burn-wgpu: Move any use of wgpu into a separate module Currently the burn-wgpu crate is hardcoded to use the wgpu WebGPU implementaion. It would be nice to be able to use other WebGPU implementations as new features may land at different times, and there may be potential performance gains to unlock. So, separate any use of the wgpu crate into a separate module, in preparation for adding the ability to target other WebGPU implementaions. Signed-off-by: Piotr Stankiewicz --- crates/burn-wgpu/src/compute/mod.rs | 2 + crates/burn-wgpu/src/compute/server.rs | 93 +++----- crates/burn-wgpu/src/compute/storage.rs | 29 +-- crates/burn-wgpu/src/compute/wgpu_api_shim.rs | 212 ++++++++++++++++++ crates/burn-wgpu/src/runtime.rs | 156 +------------ 5 files changed, 274 insertions(+), 218 deletions(-) create mode 100644 crates/burn-wgpu/src/compute/wgpu_api_shim.rs diff --git a/crates/burn-wgpu/src/compute/mod.rs b/crates/burn-wgpu/src/compute/mod.rs index 4139c3868f..8b996dcafc 100644 --- a/crates/burn-wgpu/src/compute/mod.rs +++ b/crates/burn-wgpu/src/compute/mod.rs @@ -1,5 +1,7 @@ mod server; mod storage; +mod wgpu_api_shim; pub use server::*; pub use storage::*; +pub use wgpu_api_shim::*; diff --git a/crates/burn-wgpu/src/compute/server.rs b/crates/burn-wgpu/src/compute/server.rs index 4d95d87ad6..b6e8b2f620 100644 --- a/crates/burn-wgpu/src/compute/server.rs +++ b/crates/burn-wgpu/src/compute/server.rs @@ -1,4 +1,11 @@ use super::WgpuStorage; +use crate::compute::{ + webgpu_device_poll, webgpu_read_buffer, WebGPUBindGroup, WebGPUBindGroupDescriptor, + WebGPUBindGroupEntry, WebGPUBuffer, WebGPUBufferDescriptor, WebGPUCommandEncoder, + WebGPUCommandEncoderDescriptor, WebGPUComputePassDescriptor, WebGPUComputePipeline, + WebGPUComputePipelineDescriptor, WebGPUDevice, WebGPUQueue, WebGPUShaderModuleDescriptor, + WebGPUShaderSource, COPY_DST, MAP_READ, +}; use alloc::{borrow::Cow, sync::Arc}; use burn_compute::{ memory_management::MemoryManagement, @@ -7,19 +14,15 @@ use burn_compute::{ use burn_jit::compute::{JitAutotuneKey, JitKernel, Kernel, WorkGroup}; use burn_tensor::Reader; use hashbrown::HashMap; -use wgpu::{ - util::{BufferInitDescriptor, DeviceExt}, - BindGroup, CommandEncoder, ComputePipeline, ShaderModuleDescriptor, -}; /// Wgpu compute server. #[derive(Debug)] pub struct WgpuServer> { memory_management: MM, - device: Arc, - queue: wgpu::Queue, - encoder: CommandEncoder, - pipelines: HashMap>, + device: Arc, + queue: WebGPUQueue, + encoder: WebGPUCommandEncoder, + pipelines: HashMap>, tasks_max: usize, tasks_count: usize, } @@ -31,11 +34,11 @@ where /// Create a new server. pub fn new( memory_management: MM, - device: Arc, - queue: wgpu::Queue, + device: Arc, + queue: WebGPUQueue, tasks_max: usize, ) -> Self { - let encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { + let encoder = device.create_command_encoder(&WebGPUCommandEncoderDescriptor { label: Some("Command Encoder"), }); @@ -53,7 +56,7 @@ where fn submit(&mut self) { let mut new_encoder = self .device - .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + .create_command_encoder(&WebGPUCommandEncoderDescriptor { label: None }); core::mem::swap(&mut new_encoder, &mut self.encoder); self.queue.submit(Some(new_encoder.finish())); @@ -71,7 +74,7 @@ where ) { let mut compute = self .encoder - .begin_compute_pass(&wgpu::ComputePassDescriptor { + .begin_compute_pass(&WebGPUComputePassDescriptor { label: None, timestamp_writes: None, }); @@ -83,7 +86,7 @@ where self.tasks_count += 1; } - fn pipeline(&mut self, kernel: Kernel) -> Arc { + fn pipeline(&mut self, kernel: Kernel) -> Arc { let kernel_id = kernel.id(); if let Some(pipeline) = self.pipelines.get(&kernel_id) { return pipeline.clone(); @@ -96,15 +99,17 @@ where pipeline } - fn compile_source(&self, source: &str) -> Arc { - let module = self.device.create_shader_module(ShaderModuleDescriptor { - label: None, - source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), - }); + fn compile_source(&self, source: &str) -> Arc { + let module = self + .device + .create_shader_module(WebGPUShaderModuleDescriptor { + label: None, + source: WebGPUShaderSource::Wgsl(Cow::Borrowed(source)), + }); Arc::new( self.device - .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + .create_compute_pipeline(&WebGPUComputePipelineDescriptor { label: None, layout: None, module: &module, @@ -117,10 +122,10 @@ where let resource = self.memory_management.get(handle.memory); let size = resource.size(); - let buffer_dest = self.device.create_buffer(&wgpu::BufferDescriptor { + let buffer_dest = self.device.create_buffer(&WebGPUBufferDescriptor { label: None, size, - usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + usage: MAP_READ | COPY_DST, mapped_at_creation: false, }); @@ -141,43 +146,12 @@ where #[derive(new)] struct BufferReader { - buffer: wgpu::Buffer, + buffer: WebGPUBuffer, } impl BufferReader { - #[cfg(target_family = "wasm")] - async fn read(self, device: alloc::sync::Arc) -> Vec { - self.read_async(&device).await - } - - #[cfg(not(target_family = "wasm"))] - fn read(self, device: &wgpu::Device) -> Vec { - pollster::block_on(self.read_async(device)) - } - - async fn read_async(&self, device: &wgpu::Device) -> Vec { - let buffer_slice = self.buffer.slice(..); - let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel(); - buffer_slice.map_async(wgpu::MapMode::Read, move |v| { - sender - .send(v) - .expect("Unable to send buffer slice result to async channel.") - }); - - device.poll(wgpu::Maintain::Wait); - - let result = receiver.receive().await; - - if let Some(Ok(())) = result { - let data = buffer_slice.get_mapped_range(); - let result = bytemuck::cast_slice(&data).to_vec(); - - drop(data); - self.buffer.unmap(); - result - } else { - panic!("Unable to read buffer {:?}", result) - } + fn read(&self, device: &WebGPUDevice) -> Vec { + webgpu_read_buffer(&self.buffer, device) } } @@ -225,7 +199,6 @@ where resource.offset(), buffer_src.size(), ); - self.tasks_count += 1; handle } @@ -247,13 +220,13 @@ where let entries = memory_handles .iter() .enumerate() - .map(|(i, buffer)| wgpu::BindGroupEntry { + .map(|(i, buffer)| WebGPUBindGroupEntry { binding: i as u32, resource: buffer.as_binding(), }) .collect::>(); - let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor { + let bind_group = self.device.create_bind_group(&WebGPUBindGroupDescriptor { label: None, layout: &group_layout, entries: &entries, @@ -268,6 +241,6 @@ where fn sync(&mut self) { self.submit(); - self.device.poll(wgpu::Maintain::Wait); + webgpu_device_poll(&self.device); } } diff --git a/crates/burn-wgpu/src/compute/storage.rs b/crates/burn-wgpu/src/compute/storage.rs index 12988b1352..0ab02727b5 100644 --- a/crates/burn-wgpu/src/compute/storage.rs +++ b/crates/burn-wgpu/src/compute/storage.rs @@ -1,12 +1,16 @@ +use crate::compute::{ + WebGPUBindingResource, WebGPUBuffer, WebGPUBufferAddress, WebGPUBufferBinding, + WebGPUBufferDescriptor, WebGPUBufferSize, WebGPUDevice, COPY_DST, COPY_SRC, STORAGE, +}; use burn_compute::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization}; use hashbrown::HashMap; use std::{num::NonZeroU64, sync::Arc}; /// Buffer storage for wgpu. pub struct WgpuStorage { - memory: HashMap>, + memory: HashMap>, deallocations: Vec, - device: Arc, + device: Arc, } impl core::fmt::Debug for WgpuStorage { @@ -19,23 +23,23 @@ impl core::fmt::Debug for WgpuStorage { #[derive(new, Debug)] pub struct WgpuResource { /// The wgpu buffer. - pub buffer: Arc, + pub buffer: Arc, /// How the resource is used. pub kind: WgpuResourceKind, } impl WgpuResource { /// Return the binding view of the buffer. - pub fn as_binding(&self) -> wgpu::BindingResource { + pub fn as_binding(&self) -> WebGPUBindingResource { let binding = match &self.kind { WgpuResourceKind::Full => self.buffer.as_entire_buffer_binding(), - WgpuResourceKind::Slice(offs, size) => wgpu::BufferBinding { + WgpuResourceKind::Slice(offs, size) => WebGPUBufferBinding { buffer: &self.buffer, offset: *offs, size: Some(*size), }, }; - wgpu::BindingResource::Buffer(binding) + WebGPUBindingResource::Buffer(binding) } /// Return the buffer size. @@ -61,13 +65,13 @@ pub enum WgpuResourceKind { /// Represents an entire buffer. Full, /// A slice over a buffer. - Slice(wgpu::BufferAddress, wgpu::BufferSize), + Slice(WebGPUBufferAddress, WebGPUBufferSize), } /// Keeps actual wgpu buffer references in a hashmap with ids as key. impl WgpuStorage { - /// Create a new storage on the given [device](wgpu::Device). - pub fn new(device: Arc) -> Self { + /// Create a new storage on the given [device](WebGPUDevice). + pub fn new(device: Arc) -> Self { Self { memory: HashMap::new(), deallocations: Vec::new(), @@ -104,12 +108,11 @@ impl ComputeStorage for WgpuStorage { fn alloc(&mut self, size: usize) -> StorageHandle { let id = StorageId::new(); - let buffer = Arc::new(self.device.create_buffer(&wgpu::BufferDescriptor { + + let buffer = Arc::new(self.device.create_buffer(&WebGPUBufferDescriptor { label: None, size: size as u64, - usage: wgpu::BufferUsages::COPY_DST - | wgpu::BufferUsages::STORAGE - | wgpu::BufferUsages::COPY_SRC, + usage: COPY_DST | STORAGE | COPY_SRC, mapped_at_creation: false, })); diff --git a/crates/burn-wgpu/src/compute/wgpu_api_shim.rs b/crates/burn-wgpu/src/compute/wgpu_api_shim.rs new file mode 100644 index 0000000000..12f14a1cb8 --- /dev/null +++ b/crates/burn-wgpu/src/compute/wgpu_api_shim.rs @@ -0,0 +1,212 @@ +use crate::{GraphicsApi, WgpuDevice}; + +pub type WebGPUAdapter = wgpu::Adapter; +pub type WebGPUAdapterInfo = wgpu::AdapterInfo; +pub type WebGPUBindGroup = wgpu::BindGroup; +pub type WebGPUBindGroupDescriptor<'a> = wgpu::BindGroupDescriptor<'a>; +pub type WebGPUBindGroupEntry<'a> = wgpu::BindGroupEntry<'a>; +pub type WebGPUBindingResource<'a> = wgpu::BindingResource<'a>; +pub type WebGPUBuffer = wgpu::Buffer; +pub type WebGPUBufferAddress = wgpu::BufferAddress; +pub type WebGPUBufferBinding<'a> = wgpu::BufferBinding<'a>; +pub type WebGPUBufferSize = wgpu::BufferSize; +pub type WebGPUBufferDescriptor<'a> = wgpu::BufferDescriptor<'a>; + +pub type WebGPUBufferUsages = wgpu::BufferUsages; +pub const MAP_READ: WebGPUBufferUsages = wgpu::BufferUsages::MAP_READ; +pub const COPY_SRC: WebGPUBufferUsages = wgpu::BufferUsages::COPY_SRC; +pub const COPY_DST: WebGPUBufferUsages = wgpu::BufferUsages::COPY_DST; +pub const STORAGE: WebGPUBufferUsages = wgpu::BufferUsages::STORAGE; + +pub type WebGPUCommandEncoder = wgpu::CommandEncoder; +pub type WebGPUCommandEncoderDescriptor<'a> = wgpu::CommandEncoderDescriptor<'a>; +pub type WebGPUComputePassDescriptor<'a> = wgpu::ComputePassDescriptor<'a>; +pub type WebGPUComputePipeline = wgpu::ComputePipeline; +pub type WebGPUComputePipelineDescriptor<'a> = wgpu::ComputePipelineDescriptor<'a>; +pub type WebGPUDevice = wgpu::Device; +pub type WebGPUQueue = wgpu::Queue; +pub type WebGPUShaderModuleDescriptor<'a> = wgpu::ShaderModuleDescriptor<'a>; +pub type WebGPUShaderSource<'a> = wgpu::ShaderSource<'a>; + +pub async fn webgpu_select_device(adapter: &WebGPUAdapter) -> (WebGPUDevice, WebGPUQueue) { + let limits = adapter.limits(); + + let (device, queue) = adapter + .request_device( + &wgpu::DeviceDescriptor { + label: None, + features: wgpu::Features::empty(), + limits, + }, + None, + ) + .await + .map_err(|err| { + format!( + "Unable to request the device with the adapter {:?}, err {:?}", + adapter.get_info(), + err + ) + }) + .unwrap(); + + (device, queue) +} + +#[cfg(target_family = "wasm")] +async fn webgpu_select_adapter(_device: &WgpuDevice) -> WebGPUAdapter { + let instance = wgpu::Instance::default(); + + instance + .request_adapter(&wgpu::RequestAdapterOptionsBase::default()) + .await + .unwrap() +} + +#[cfg(not(target_family = "wasm"))] +pub fn webgpu_select_adapter(device: &WgpuDevice) -> WebGPUAdapter { + use wgpu::DeviceType; + + let instance = wgpu::Instance::default(); + let mut adapters_other = Vec::new(); + let mut adapters = Vec::new(); + + instance + .enumerate_adapters(G::backend().into()) + .for_each(|adapter| { + let device_type = adapter.get_info().device_type; + + if let DeviceType::Other = device_type { + adapters_other.push(adapter); + return; + } + + let is_same_type = match device { + WgpuDevice::DiscreteGpu(_) => device_type == DeviceType::DiscreteGpu, + WgpuDevice::IntegratedGpu(_) => device_type == DeviceType::IntegratedGpu, + WgpuDevice::VirtualGpu(_) => device_type == DeviceType::VirtualGpu, + WgpuDevice::Cpu => device_type == DeviceType::Cpu, + WgpuDevice::BestAvailable => true, + }; + + if is_same_type { + adapters.push(adapter); + } + }); + + fn select( + num: usize, + error: &str, + mut adapters: Vec, + mut adapters_other: Vec, + ) -> wgpu::Adapter { + if adapters.len() <= num { + if adapters_other.len() <= num { + panic!( + "{}, adapters {:?}, other adapters {:?}", + error, + adapters + .into_iter() + .map(|adapter| adapter.get_info()) + .collect::>(), + adapters_other + .into_iter() + .map(|adapter| adapter.get_info()) + .collect::>(), + ); + } else { + return adapters_other.remove(num); + } + } + + adapters.remove(num) + } + let adapter = match device { + WgpuDevice::DiscreteGpu(num) => select( + *num, + "No Discrete GPU device found", + adapters, + adapters_other, + ), + WgpuDevice::IntegratedGpu(num) => select( + *num, + "No Integrated GPU device found", + adapters, + adapters_other, + ), + WgpuDevice::VirtualGpu(num) => select( + *num, + "No Virtual GPU device found", + adapters, + adapters_other, + ), + WgpuDevice::Cpu => select(0, "No CPU device found", adapters, adapters_other), + WgpuDevice::BestAvailable => { + let mut most_performant_adapter = None; + let mut current_score = -1; + + adapters + .into_iter() + .chain(adapters_other) + .for_each(|adapter| { + let info = adapter.get_info(); + let score = match info.device_type { + DeviceType::DiscreteGpu => 5, + DeviceType::Other => 4, // Let's be optimistic with the Other device, it's + // often a Discrete Gpu. + DeviceType::IntegratedGpu => 3, + DeviceType::VirtualGpu => 2, + DeviceType::Cpu => 1, + }; + + if score > current_score { + most_performant_adapter = Some(adapter); + current_score = score; + } + }); + + if let Some(adapter) = most_performant_adapter { + adapter + } else { + panic!("No adapter found for graphics API {:?}", G::default()); + } + } + }; + + log::info!("Using adapter {:?}", adapter.get_info()); + + adapter +} + +pub fn webgpu_read_buffer(buffer: &WebGPUBuffer, device: &WebGPUDevice) -> Vec { + pollster::block_on(webgpu_read_buffer_async(buffer, device)) +} + +async fn webgpu_read_buffer_async(buffer: &WebGPUBuffer, device: &WebGPUDevice) -> Vec { + let buffer_slice = buffer.slice(..); + let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel(); + buffer_slice.map_async(wgpu::MapMode::Read, move |v| { + sender + .send(v) + .expect("Unable to send buffer slice result to async channel.") + }); + + device.poll(wgpu::Maintain::Wait); + + let result = receiver.receive().await; + + if let Some(Ok(())) = result { + let data = buffer_slice.get_mapped_range(); + let result = bytemuck::cast_slice(&data).to_vec(); + + drop(data); + buffer.unmap(); + result + } else { + panic!("Unable to read buffer {:?}", result) + } +} + +pub fn webgpu_device_poll(device: &WebGPUDevice) { + device.poll(wgpu::Maintain::Wait); +} diff --git a/crates/burn-wgpu/src/runtime.rs b/crates/burn-wgpu/src/runtime.rs index e3192f2238..ae0b63445e 100644 --- a/crates/burn-wgpu/src/runtime.rs +++ b/crates/burn-wgpu/src/runtime.rs @@ -1,6 +1,9 @@ use crate::{ compiler::wgsl, - compute::{WgpuServer, WgpuStorage}, + compute::{ + webgpu_select_adapter, webgpu_select_device, WebGPUAdapter, WebGPUAdapterInfo, + WebGPUDevice, WebGPUQueue, WgpuServer, WgpuStorage, + }, GraphicsApi, WgpuDevice, }; use alloc::sync::Arc; @@ -15,7 +18,6 @@ use burn_compute::{ use burn_jit::Runtime; use burn_tensor::backend::{DeviceId, DeviceOps}; use std::marker::PhantomData; -use wgpu::{AdapterInfo, DeviceDescriptor}; /// Runtime that uses the [wgpu] crate with the wgsl compiler. /// @@ -135,164 +137,28 @@ async fn create_client( /// Select the wgpu device and queue based on the provided [device](WgpuDevice). pub async fn select_device( device: &WgpuDevice, -) -> (wgpu::Device, wgpu::Queue, wgpu::AdapterInfo) { +) -> (WebGPUDevice, WebGPUQueue, WebGPUAdapterInfo) { #[cfg(target_family = "wasm")] let adapter = select_adapter::(device).await; #[cfg(not(target_family = "wasm"))] let adapter = select_adapter::(device); - let limits = adapter.limits(); - - let (device, queue) = adapter - .request_device( - &DeviceDescriptor { - label: None, - required_features: wgpu::Features::empty(), - required_limits: limits, - }, - None, - ) - .await - .map_err(|err| { - format!( - "Unable to request the device with the adapter {:?}, err {:?}", - adapter.get_info(), - err - ) - }) - .unwrap(); + let (device, queue) = webgpu_select_device(&adapter).await; (device, queue, adapter.get_info()) } -fn tuner_device_id(info: AdapterInfo) -> String { +fn tuner_device_id(info: WebGPUAdapterInfo) -> String { format!("wgpu-{}-{}", info.device, info.backend.to_str()) } #[cfg(target_family = "wasm")] -async fn select_adapter(_device: &WgpuDevice) -> wgpu::Adapter { - let instance = wgpu::Instance::default(); - - instance - .request_adapter(&wgpu::RequestAdapterOptionsBase::default()) - .await - .unwrap() +async fn select_adapter(_device: &WgpuDevice) -> WebGPUAdapter { + webgpu_select_adapter::(device) } #[cfg(not(target_family = "wasm"))] -fn select_adapter(device: &WgpuDevice) -> wgpu::Adapter { - use wgpu::DeviceType; - - let instance = wgpu::Instance::default(); - let mut adapters_other = Vec::new(); - let mut adapters = Vec::new(); - - instance - .enumerate_adapters(G::backend().into()) - .into_iter() - .for_each(|adapter| { - let device_type = adapter.get_info().device_type; - - if let DeviceType::Other = device_type { - adapters_other.push(adapter); - return; - } - - let is_same_type = match device { - WgpuDevice::DiscreteGpu(_) => device_type == DeviceType::DiscreteGpu, - WgpuDevice::IntegratedGpu(_) => device_type == DeviceType::IntegratedGpu, - WgpuDevice::VirtualGpu(_) => device_type == DeviceType::VirtualGpu, - WgpuDevice::Cpu => device_type == DeviceType::Cpu, - WgpuDevice::BestAvailable => true, - }; - - if is_same_type { - adapters.push(adapter); - } - }); - - fn select( - num: usize, - error: &str, - mut adapters: Vec, - mut adapters_other: Vec, - ) -> wgpu::Adapter { - if adapters.len() <= num { - if adapters_other.len() <= num { - panic!( - "{}, adapters {:?}, other adapters {:?}", - error, - adapters - .into_iter() - .map(|adapter| adapter.get_info()) - .collect::>(), - adapters_other - .into_iter() - .map(|adapter| adapter.get_info()) - .collect::>(), - ); - } - - return adapters_other.remove(num); - } - - adapters.remove(num) - } - - let adapter = match device { - WgpuDevice::DiscreteGpu(num) => select( - *num, - "No Discrete GPU device found", - adapters, - adapters_other, - ), - WgpuDevice::IntegratedGpu(num) => select( - *num, - "No Integrated GPU device found", - adapters, - adapters_other, - ), - WgpuDevice::VirtualGpu(num) => select( - *num, - "No Virtual GPU device found", - adapters, - adapters_other, - ), - WgpuDevice::Cpu => select(0, "No CPU device found", adapters, adapters_other), - WgpuDevice::BestAvailable => { - let mut most_performant_adapter = None; - let mut current_score = -1; - - adapters - .into_iter() - .chain(adapters_other) - .for_each(|adapter| { - let info = adapter.get_info(); - let score = match info.device_type { - DeviceType::DiscreteGpu => 5, - DeviceType::Other => 4, // Let's be optimistic with the Other device, it's - // often a Discrete Gpu. - DeviceType::IntegratedGpu => 3, - DeviceType::VirtualGpu => 2, - DeviceType::Cpu => 1, - }; - - if score > current_score { - most_performant_adapter = Some(adapter); - current_score = score; - } - }); - - if let Some(adapter) = most_performant_adapter { - adapter - } else { - panic!("No adapter found for graphics API {:?}", G::default()); - } - } - }; - - log::info!("Using adapter {:?}", adapter.get_info()); - - adapter +fn select_adapter(device: &WgpuDevice) -> WebGPUAdapter { + webgpu_select_adapter::(device) } From 3212b1a807e86c3a3da9e4ad7282aa2887a339e2 Mon Sep 17 00:00:00 2001 From: Piotr Stankiewicz Date: Thu, 14 Mar 2024 12:46:20 +0100 Subject: [PATCH 2/2] burn-wgpu: Allow using Dawn instead of/along wgpu Dawn, Google's WebGPU impelementation, currently supports feattures which wgpu does not. For example using 16 bit floats in shaders. So, add the ability to build the burn-wgpu backend against Dawn. Signed-off-by: Piotr Stankiewicz --- .gitmodules | 3 + Cargo.lock | 123 +++ _typos.toml | 1 + crates/burn-core/Cargo.toml | 2 +- crates/burn-jit/Cargo.toml | 1 + crates/burn-jit/src/kernel/base.rs | 4 +- crates/burn-wgpu/Cargo.toml | 9 +- crates/burn-wgpu/README.md | 6 +- crates/burn-wgpu/build.rs | 236 +++++ crates/burn-wgpu/dawn | 1 + crates/burn-wgpu/dawn.h | 1 + crates/burn-wgpu/src/compiler/wgsl/base.rs | 4 +- crates/burn-wgpu/src/compute/dawn_api_shim.rs | 820 ++++++++++++++++++ .../src/compute/dawn_native_bindings.rs | 6 + crates/burn-wgpu/src/compute/mod.rs | 12 +- crates/burn-wgpu/src/compute/server.rs | 109 +-- crates/burn-wgpu/src/compute/storage.rs | 68 +- crates/burn-wgpu/src/compute/webgpu_api.rs | 200 +++++ crates/burn-wgpu/src/compute/wgpu_api_shim.rs | 609 +++++++++---- crates/burn-wgpu/src/lib.rs | 69 +- crates/burn-wgpu/src/runtime.rs | 74 +- .../examples/custom-wgpu-kernel.rs | 4 +- examples/custom-wgpu-kernel/src/backward.rs | 6 +- examples/custom-wgpu-kernel/src/forward.rs | 6 +- examples/image-classification-web/src/web.rs | 4 +- 25 files changed, 2046 insertions(+), 332 deletions(-) create mode 100644 .gitmodules create mode 100644 crates/burn-wgpu/build.rs create mode 160000 crates/burn-wgpu/dawn create mode 100644 crates/burn-wgpu/dawn.h create mode 100644 crates/burn-wgpu/src/compute/dawn_api_shim.rs create mode 100644 crates/burn-wgpu/src/compute/dawn_native_bindings.rs create mode 100644 crates/burn-wgpu/src/compute/webgpu_api.rs diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000..f5d883e911 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "crates/burn-wgpu/dawn"] + path = crates/burn-wgpu/dawn + url = https://dawn.googlesource.com/dawn diff --git a/Cargo.lock b/Cargo.lock index b0fe41112a..4b76664d76 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -314,6 +314,29 @@ dependencies = [ "serde", ] +[[package]] +name = "bindgen" +version = "0.69.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a00dc851838a2120612785d195287475a3ac45514741da670b735818822129a0" +dependencies = [ + "bitflags 2.5.0", + "cexpr", + "clang-sys", + "itertools 0.12.1", + "lazy_static", + "lazycell", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "syn 2.0.60", + "which", +] + [[package]] name = "bindgen_cuda" version = "0.1.5" @@ -712,14 +735,17 @@ dependencies = [ name = "burn-wgpu" version = "0.14.0" dependencies = [ + "bindgen", "burn-common", "burn-compute", "burn-fusion", "burn-jit", "burn-tensor", "bytemuck", + "cmake", "derive-new", "futures-intrusive", + "git2", "hashbrown 0.14.5", "log", "pollster", @@ -862,6 +888,15 @@ dependencies = [ "once_cell", ] +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + [[package]] name = "cfg-expr" version = "0.15.8" @@ -908,6 +943,17 @@ dependencies = [ "inout", ] +[[package]] +name = "clang-sys" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67523a3b4be3ce1989d607a828d036249522dd9c1c8de7f4dd2dae43a37369d1" +dependencies = [ + "glob", + "libc", + "libloading 0.8.3", +] + [[package]] name = "clap" version = "3.2.25" @@ -2025,6 +2071,21 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +[[package]] +name = "git2" +version = "0.18.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "232e6a7bfe35766bf715e55a88b39a700596c0ccfd88cd3680b4cdb40d66ef70" +dependencies = [ + "bitflags 2.5.0", + "libc", + "libgit2-sys", + "log", + "openssl-probe", + "openssl-sys", + "url", +] + [[package]] name = "github-device-flow" version = "0.2.0" @@ -2714,6 +2775,12 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + [[package]] name = "lebe" version = "0.5.2" @@ -2737,6 +2804,20 @@ dependencies = [ "once_cell", ] +[[package]] +name = "libgit2-sys" +version = "0.16.2+1.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee4126d8b4ee5c9d9ea891dd875cfdc1e9d0950437179104b183d7d8a74d24e8" +dependencies = [ + "cc", + "libc", + "libssh2-sys", + "libz-sys", + "openssl-sys", + "pkg-config", +] + [[package]] name = "libloading" version = "0.7.4" @@ -2784,6 +2865,32 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "libssh2-sys" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dc8a030b787e2119a731f1951d6a773e2280c660f8ec4b0f5e1505a386e71ee" +dependencies = [ + "cc", + "libc", + "libz-sys", + "openssl-sys", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "libz-sys" +version = "1.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e143b5e666b2695d28f6bca6497720813f699c9602dd7f5cac91008b8ada7f9" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linux-raw-sys" version = "0.4.13" @@ -3616,6 +3723,16 @@ dependencies = [ "yansi", ] +[[package]] +name = "prettyplease" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f12335488a2f3b0a83b14edad48dca9879ce89b2edd10e80237e4e852dd645e" +dependencies = [ + "proc-macro2", + "syn 2.0.60", +] + [[package]] name = "proc-macro-error" version = "1.0.4" @@ -4551,6 +4668,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + [[package]] name = "signal-hook" version = "0.3.17" diff --git a/_typos.toml b/_typos.toml index 0d002ff695..a5b9a183bb 100644 --- a/_typos.toml +++ b/_typos.toml @@ -5,4 +5,5 @@ extend-ignore-identifiers-re = ["ratatui", "Ratatui", "NdArray*", "ND"] extend-exclude = [ "assets/ModuleSerialization.xml", "examples/image-classification-web/src/model/label.txt", + "crates/burn-wgpu/dawn", ] diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml index 7dcbf46014..f0b6767d24 100644 --- a/crates/burn-core/Cargo.toml +++ b/crates/burn-core/Cargo.toml @@ -84,7 +84,7 @@ ndarray = ["burn-ndarray"] tch = ["burn-tch"] candle = ["burn-candle"] candle-cuda = ["candle", "burn-candle/cuda"] -wgpu = ["burn-wgpu"] +wgpu = ["burn-wgpu/wgpu"] # Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files. record-item-custom-serde = ["thiserror", "regex"] diff --git a/crates/burn-jit/Cargo.toml b/crates/burn-jit/Cargo.toml index d0f727e5d3..0657401b3d 100644 --- a/crates/burn-jit/Cargo.toml +++ b/crates/burn-jit/Cargo.toml @@ -25,6 +25,7 @@ export_tests = [ "burn-ndarray", "fusion", ] +dawn = [] [dependencies] burn-common = { path = "../burn-common", version = "0.14.0" } diff --git a/crates/burn-jit/src/kernel/base.rs b/crates/burn-jit/src/kernel/base.rs index 54c2350016..db2e36d67d 100644 --- a/crates/burn-jit/src/kernel/base.rs +++ b/crates/burn-jit/src/kernel/base.rs @@ -1,8 +1,8 @@ use crate::{compute::WorkGroup, gpu::ComputeShader}; -#[cfg(target_family = "wasm")] +#[cfg(any(target_family = "wasm", feature = "dawn"))] pub(crate) const WORKGROUP_DEFAULT: usize = 16; -#[cfg(not(target_family = "wasm"))] +#[cfg(all(not(target_family = "wasm"), not(feature = "dawn")))] pub(crate) const WORKGROUP_DEFAULT: usize = 32; /// Dynamic jit kernel to create a [compute shader](ComputeShader). diff --git a/crates/burn-wgpu/Cargo.toml b/crates/burn-wgpu/Cargo.toml index 039a112589..69a2458b6a 100644 --- a/crates/burn-wgpu/Cargo.toml +++ b/crates/burn-wgpu/Cargo.toml @@ -11,12 +11,14 @@ repository = "https://github.com/tracel-ai/burn/tree/main/burn-wgpu" version.workspace = true [features] -default = ["fusion", "burn-jit/default"] +default = ["fusion", "burn-jit/default", "wgpu"] fusion = ["burn-fusion", "burn-jit/fusion"] autotune = ["burn-jit/autotune"] template = ["burn-jit/template"] doc = ["burn-jit/doc"] std = ["burn-jit/std"] +dawn = ["burn-jit/dawn", "dep:bindgen", "dep:cmake", "dep:git2"] +wgpu = [] [dependencies] burn-jit = { path = "../burn-jit", version = "0.14.0", default-features = false } @@ -34,6 +36,11 @@ futures-intrusive = { workspace = true } derive-new = { workspace = true } hashbrown = { workspace = true } +[build-dependencies] +bindgen = { version = "0.69.4", optional = true } +cmake = { version = "0.1", optional = true } +git2 = { version = "0.18.2", optional = true } + [dev-dependencies] burn-jit = { path = "../burn-jit", version = "0.14.0", default-features = false, features = [ "export_tests", diff --git a/crates/burn-wgpu/README.md b/crates/burn-wgpu/README.md index 4a17176bbc..00c9aa36cd 100644 --- a/crates/burn-wgpu/README.md +++ b/crates/burn-wgpu/README.md @@ -6,7 +6,7 @@ [![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-wgpu/blob/master/README.md) This crate provides a WGPU backend for [Burn](https://github.com/tracel-ai/burn) using the -[wgpu](https://github.com/gfx-rs/wgpu). +[wgpu](https://github.com/gfx-rs/wgpu) or [Dawn](https://dawn.googlesource.com/dawn/). The backend supports Vulkan, Metal, DirectX11/12, OpenGL, WebGPU. @@ -39,3 +39,7 @@ You can set `BURN_WGPU_MAX_TASKS` to a positive integer that determines how many | OpenGL | No | Yes | Yes | Yes | Yes | Yes | Yes | No | | WebGpu | No | Yes | No | No | No | No | No | Yes | | Dx11/Dx12 | No | Yes | No | No | Yes | No | No | No | + +## Building with the `dawn` backend enabled + +This crate can be built using Dawn as the backing WebGPU implementation. To do this enable the `dawn` feature. Note that Dawn requires `python3`and `ninja` (https://ninja-build.org/) to build and may take a non-negligible time to compile. diff --git a/crates/burn-wgpu/build.rs b/crates/burn-wgpu/build.rs new file mode 100644 index 0000000000..de09fe829d --- /dev/null +++ b/crates/burn-wgpu/build.rs @@ -0,0 +1,236 @@ +#[cfg(all(feature = "dawn", not(any(target_os = "macos", target_os = "linux"))))] +compile_error!("The 'dawn' backend currently only builds on macos."); + +fn main() { + #[cfg(feature = "dawn")] + link_and_bind_dawn(); +} + +#[cfg(feature = "dawn")] +fn link_and_bind_dawn() { + use bindgen::builder; + use std::env; + use std::path::PathBuf; + + let src_dir = env::current_dir().unwrap(); + let dawn_src_dir = src_dir.join("dawn"); + + let repo = match git2::Repository::open("../..") { + Ok(repo) => repo, + Err(err) => panic!("failed to open repo: {err}"), + }; + let mut submodules = match repo.submodules() { + Ok(submodules) => submodules, + Err(err) => panic!("failed to list git submodules: {err}"), + }; + for submodule in submodules.iter_mut() { + if submodule.name().unwrap().ends_with("dawn") { + match submodule.update(true, None) { + Ok(_) => (), + Err(_) => { + // if the working directory is empty, but the module is present in .git/modules + // update will fail, so try to manually reinit the submodule (more context in + // https://github.com/libgit2/libgit2/issues/3820) + match submodule.init(false) { + Ok(_) => (), + Err(err) => panic!("failed to init the dawn submodule: {err}"), + }; + match submodule.repo_init(false) { + Ok(_) => (), + Err(err) => panic!("failed to initi the dawn submodule repo: {err}"), + }; + match submodule.clone(None) { + Ok(_) => (), + Err(err) => panic!("failed to clone the dawn submodule: {err}"), + }; + match submodule.sync() { + Ok(_) => (), + Err(err) => panic!("failed to sync the dawn submodule: {err}"), + } + } + } + } + } + + let _ = std::process::Command::new("python3") + .current_dir(dawn_src_dir.clone()) + .arg("tools/fetch_dawn_dependencies.py") + .arg("--use-test-deps") + .spawn() + .expect("failed to fetch Dawn dependencies") + .wait(); + + let dst = cmake::Config::new(dawn_src_dir.clone()) + .profile("Release") + .generator("Ninja") // would be nice to use make, but the Dawn build has a tendency to quietly + // fail when generating code, leading to confusing errors + .build_target("webgpu_dawn") + .build(); + + let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); + let dawn_build_dir = dst.join("build"); + let dawn_build_dir = dawn_build_dir.display(); + let dawn_src_dir = dawn_src_dir.display(); + + let bindings = builder() + .header("dawn.h") + .clang_args([ + "-x", + "c++", + "-I", + std::format!("{dawn_build_dir}/gen/include/").as_str(), + "-I", + std::format!("{dawn_src_dir}/include/").as_str(), + "--std=c++17", + ]) + .allowlist_function(".*GetProcs.*") + .allowlist_function(".*SetProcs.*") + .allowlist_function("wgpu.*") + .allowlist_file(".*webgpu.h") + .layout_tests(false) + .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) + .generate() + .expect("Unable to generate Dawn bindings"); + + bindings + .write_to_file(out_path.join("dawn_native_bindings_gen.rs")) + .expect("Couldn't write Dawn bindings!"); + + println!("cargo:rustc-link-search={dawn_build_dir}/src/dawn/common"); + println!("cargo:rustc-link-lib=static=dawn_common"); + println!("cargo:rustc-link-search={dawn_build_dir}/src/dawn/native"); + println!("cargo:rustc-link-lib=static=dawn_native"); + println!("cargo:rustc-link-lib=static=webgpu_dawn"); + println!("cargo:rustc-link-search={dawn_build_dir}/src/dawn/platform"); + println!("cargo:rustc-link-lib=static=dawn_platform"); + println!("cargo:rustc-link-search={dawn_build_dir}/src/tint"); + println!("cargo:rustc-link-lib=static=tint_api"); + println!("cargo:rustc-link-lib=static=tint_api_common"); + println!("cargo:rustc-link-lib=static=tint_api_options"); + println!("cargo:rustc-link-lib=static=tint_lang_core"); + println!("cargo:rustc-link-lib=static=tint_lang_core_constant"); + println!("cargo:rustc-link-lib=static=tint_lang_core_intrinsic"); + println!("cargo:rustc-link-lib=static=tint_lang_core_ir"); + println!("cargo:rustc-link-lib=static=tint_lang_core_ir_transform"); + println!("cargo:rustc-link-lib=static=tint_lang_core_type"); + #[cfg(target_os = "linux")] + { + println!("cargo:rustc-link-lib=static=tint_lang_glsl_writer"); + println!("cargo:rustc-link-lib=static=tint_lang_glsl_writer_ast_printer"); + println!("cargo:rustc-link-lib=static=tint_lang_glsl_writer_ast_raise"); + println!("cargo:rustc-link-lib=static=tint_lang_glsl_writer_common"); + println!("cargo:rustc-link-lib=static=tint_lang_glsl_writer_printer"); + println!("cargo:rustc-link-lib=static=tint_lang_glsl_writer_raise"); + } + println!("cargo:rustc-link-lib=static=tint_lang_hlsl_writer_common"); + #[cfg(target_os = "macos")] + { + println!("cargo:rustc-link-lib=static=tint_lang_msl"); + println!("cargo:rustc-link-lib=static=tint_lang_msl_intrinsic"); + println!("cargo:rustc-link-lib=static=tint_lang_msl_ir"); + println!("cargo:rustc-link-lib=static=tint_lang_msl_writer"); + println!("cargo:rustc-link-lib=static=tint_lang_msl_writer_ast_printer"); + println!("cargo:rustc-link-lib=static=tint_lang_msl_writer_ast_raise"); + println!("cargo:rustc-link-lib=static=tint_lang_msl_writer_common"); + println!("cargo:rustc-link-lib=static=tint_lang_msl_writer_printer"); + println!("cargo:rustc-link-lib=static=tint_lang_msl_writer_raise"); + } + #[cfg(target_os = "linux")] + { + println!("cargo:rustc-link-lib=static=tint_lang_spirv"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_intrinsic"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_ir"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_reader"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_reader_ast_lower"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_reader_ast_parser"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_reader_common"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_reader_lower"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_reader_parser"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_type"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_writer"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_writer_ast_printer"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_writer_ast_raise"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_writer_common"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_writer_printer"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_writer_raise"); + } + println!("cargo:rustc-link-lib=static=tint_lang_wgsl"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_ast"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_ast_transform"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_common"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_features"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_helpers"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_inspector"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_intrinsic"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_ir"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_program"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_reader"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_reader_lower"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_reader_parser"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_reader_program_to_ir"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_resolver"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_sem"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_writer"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_writer_ast_printer"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_writer_ir_to_program"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_writer_raise"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_writer_syntax_tree_printer"); + println!("cargo:rustc-link-lib=static=tint_utils_containers"); + println!("cargo:rustc-link-lib=static=tint_utils_debug"); + println!("cargo:rustc-link-lib=static=tint_utils_diagnostic"); + println!("cargo:rustc-link-lib=static=tint_utils_generator"); + println!("cargo:rustc-link-lib=static=tint_utils_ice"); + println!("cargo:rustc-link-lib=static=tint_utils_id"); + println!("cargo:rustc-link-lib=static=tint_utils_macros"); + println!("cargo:rustc-link-lib=static=tint_utils_math"); + println!("cargo:rustc-link-lib=static=tint_utils_memory"); + println!("cargo:rustc-link-lib=static=tint_utils_reflection"); + println!("cargo:rustc-link-lib=static=tint_utils_result"); + println!("cargo:rustc-link-lib=static=tint_utils_rtti"); + println!("cargo:rustc-link-lib=static=tint_utils_strconv"); + println!("cargo:rustc-link-lib=static=tint_utils_symbol"); + println!("cargo:rustc-link-lib=static=tint_utils_text"); + println!("cargo:rustc-link-lib=static=tint_utils_traits"); + println!("cargo:rustc-link-search={dawn_build_dir}/third_party/abseil/absl/strings"); + println!("cargo:rustc-link-lib=static=absl_str_format_internal"); + println!("cargo:rustc-link-lib=static=absl_strings"); + println!("cargo:rustc-link-lib=static=absl_strings_internal"); + println!("cargo:rustc-link-search={dawn_build_dir}/third_party/abseil/absl/base"); + println!("cargo:rustc-link-lib=static=absl_base"); + println!("cargo:rustc-link-lib=static=absl_spinlock_wait"); + println!("cargo:rustc-link-lib=static=absl_throw_delegate"); + println!("cargo:rustc-link-lib=static=absl_raw_logging_internal"); + println!("cargo:rustc-link-lib=static=absl_log_severity"); + println!("cargo:rustc-link-search={dawn_build_dir}/third_party/abseil/absl/numeric"); + println!("cargo:rustc-link-lib=static=absl_int128"); + println!("cargo:rustc-link-search={dawn_build_dir}/third_party/abseil/absl/hash"); + println!("cargo:rustc-link-lib=static=absl_city"); + println!("cargo:rustc-link-lib=static=absl_hash"); + println!("cargo:rustc-link-lib=static=absl_low_level_hash"); + println!("cargo:rustc-link-search={dawn_build_dir}/third_party/abseil/absl/container"); + println!("cargo:rustc-link-lib=static=absl_hashtablez_sampler"); + println!("cargo:rustc-link-lib=static=absl_raw_hash_set"); + #[cfg(target_os = "linux")] + { + println!("cargo:rustc-link-search={dawn_build_dir}/third_party/spirv-tools/source"); + println!("cargo:rustc-link-lib=static=SPIRV-Tools"); + println!("cargo:rustc-link-search={dawn_build_dir}/third_party/spirv-tools/source/opt"); + println!("cargo:rustc-link-lib=static=SPIRV-Tools-opt"); + + // Has to go at the end of the list, otherwise the linker will complain + // about missing c++ symbols. + println!("cargo:rustc-link-lib=dylib=stdc++"); + } + #[cfg(target_os = "macos")] + { + println!("cargo:rustc-link-lib=framework=CoreFoundation"); + println!("cargo:rustc-link-lib=framework=IOKit"); + println!("cargo:rustc-link-lib=framework=IOSurface"); + println!("cargo:rustc-link-lib=framework=Metal"); + println!("cargo:rustc-link-lib=framework=QuartzCore"); + println!("cargo:rustc-link-lib=framework=Cocoa"); + // Has to go at the end of the list, otherwise the linker will complain + // about missing c++ symbols. + println!("cargo:rustc-link-lib=dylib=c++"); + } +} diff --git a/crates/burn-wgpu/dawn b/crates/burn-wgpu/dawn new file mode 160000 index 0000000000..fb97d04c0c --- /dev/null +++ b/crates/burn-wgpu/dawn @@ -0,0 +1 @@ +Subproject commit fb97d04c0c2e2307dc11a9d9de4eab607af111f9 diff --git a/crates/burn-wgpu/dawn.h b/crates/burn-wgpu/dawn.h new file mode 100644 index 0000000000..4a29d37346 --- /dev/null +++ b/crates/burn-wgpu/dawn.h @@ -0,0 +1 @@ +#include "dawn/webgpu.h" diff --git a/crates/burn-wgpu/src/compiler/wgsl/base.rs b/crates/burn-wgpu/src/compiler/wgsl/base.rs index a268b4d166..dddc0bc414 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/base.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/base.rs @@ -218,7 +218,9 @@ impl Display for Variable { } Variable::ConstantScalar(number, elem) => match elem { Elem::F32 => f.write_fmt(format_args!("{number}f")), - Elem::I32 => f.write_fmt(format_args!("{number}i")), + // Dawn seems to get tripped up by the 'i' suffix, while wgpu is happy + // with or without it, so emit the literal without it. + Elem::I32 => f.write_fmt(format_args!("{number}")), Elem::U32 => f.write_fmt(format_args!("{number}u")), Elem::Bool => f.write_fmt(format_args!("bool({number})")), }, diff --git a/crates/burn-wgpu/src/compute/dawn_api_shim.rs b/crates/burn-wgpu/src/compute/dawn_api_shim.rs new file mode 100644 index 0000000000..31c8bfa47d --- /dev/null +++ b/crates/burn-wgpu/src/compute/dawn_api_shim.rs @@ -0,0 +1,820 @@ +#![allow(missing_docs)] + +use crate::compute::{ + dawn_native_bindings::*, webgpu_api::*, WgpuServer, WgpuStorage, +}; +use crate::{create_client, GraphicsApi, RuntimeOptions, WgpuDevice}; +use alloc::sync::Arc; +use burn_compute::{ + channel::MutexComputeChannel, client::ComputeClient, memory_management::SimpleMemoryManagement, + ComputeRuntime, +}; +use burn_jit::compute::WorkGroup; +use std::num::NonZeroU64; + +#[derive(Debug)] +pub struct DawnApi {} + +#[derive(Debug)] +pub struct DawnAdapter { + adapter: WGPUAdapter, +} + +impl Adapter for DawnAdapter { + fn get_info(&self) -> DawnAdapterInfo { + let mut adapter_info = WGPUAdapterProperties { + nextInChain: std::ptr::null_mut::(), + vendorID: 0, + vendorName: std::ptr::null(), + architecture: std::ptr::null(), + deviceID: 0, + name: std::ptr::null(), + driverDescription: std::ptr::null(), + adapterType: 0, + backendType: 0, + compatibilityMode: 0, + }; + unsafe { + wgpuAdapterGetProperties(self.adapter, &mut adapter_info); + } + DawnAdapterInfo { adapter_info } + } +} + +#[derive(Debug)] +pub struct DawnAdapterInfo { + adapter_info: WGPUAdapterProperties, +} + +impl AdapterInfo for DawnAdapterInfo { + fn backend(&self) -> DawnBackend { + DawnBackend::from_u32(self.adapter_info.backendType) + } + + fn device(&self) -> DeviceId { + self.adapter_info.deviceID + } +} + +#[derive(Debug)] +pub enum DawnBackend { + Undefined = WGPUBackendType_WGPUBackendType_Undefined as isize, + Null = WGPUBackendType_WGPUBackendType_Null as isize, + WebGPU = WGPUBackendType_WGPUBackendType_WebGPU as isize, + D3D11 = WGPUBackendType_WGPUBackendType_D3D11 as isize, + D3D12 = WGPUBackendType_WGPUBackendType_D3D12 as isize, + Metal = WGPUBackendType_WGPUBackendType_Metal as isize, + Vulkan = WGPUBackendType_WGPUBackendType_Vulkan as isize, + OpenGL = WGPUBackendType_WGPUBackendType_OpenGL as isize, + OpenGLES = WGPUBackendType_WGPUBackendType_OpenGLES as isize, +} + +impl DawnBackend { + #[allow(non_upper_case_globals)] + fn from_u32(val: u32) -> DawnBackend { + match val { + WGPUBackendType_WGPUBackendType_Undefined => DawnBackend::Undefined, + WGPUBackendType_WGPUBackendType_Null => DawnBackend::Null, + WGPUBackendType_WGPUBackendType_WebGPU => DawnBackend::WebGPU, + WGPUBackendType_WGPUBackendType_D3D11 => DawnBackend::D3D11, + WGPUBackendType_WGPUBackendType_D3D12 => DawnBackend::D3D12, + WGPUBackendType_WGPUBackendType_Metal => DawnBackend::Metal, + WGPUBackendType_WGPUBackendType_Vulkan => DawnBackend::Vulkan, + WGPUBackendType_WGPUBackendType_OpenGL => DawnBackend::OpenGL, + WGPUBackendType_WGPUBackendType_OpenGLES => DawnBackend::OpenGLES, + _ => panic!("Unknown Dawn backend type: {}", val), + } + } +} + +impl core::convert::AsRef for DawnBackend { + fn as_ref(&self) -> &'static str { + match self { + DawnBackend::Undefined => "undefined", + DawnBackend::Null => "null", + DawnBackend::WebGPU => "webgpu", + DawnBackend::D3D11 => "dx11", + DawnBackend::D3D12 => "dx12", + DawnBackend::Metal => "metal", + DawnBackend::Vulkan => "vulkan", + DawnBackend::OpenGL => "opengl", + DawnBackend::OpenGLES => "opengles", + } + } +} + +#[derive(Debug)] +pub struct DawnBindGroup { + bind_group: WGPUBindGroup, +} +unsafe impl Send for DawnBindGroup {} +impl BindGroup for DawnBindGroup {} + +#[derive(Debug)] +pub struct DawnBindGroupLayout { + layout: WGPUBindGroupLayout, +} +impl BindGroupLayout for DawnBindGroupLayout {} + +#[derive(Debug)] +pub struct DawnBuffer { + buffer: WGPUBuffer, + size: u64, +} +unsafe impl Send for DawnBuffer {} +unsafe impl Sync for DawnBuffer {} + +impl Buffer for DawnBuffer { + fn as_entire_buffer_binding(&self) -> BufferBinding<'_, DawnBuffer> { + BufferBinding { + buffer: self, + offset: 0, + size: Some(NonZeroU64::new((*self).size).unwrap()), + } + } + + fn destroy(&self) { + unsafe { + wgpuBufferDestroy((*self).buffer.into()); + } + } + + async fn read(&self, device: &DawnDevice) -> Vec { + let mut read_data = BufferReadData { + read_done: std::sync::Mutex::new(false), + cv: std::sync::Condvar::new(), + }; + unsafe { + let data_ptr = std::mem::transmute::<*mut BufferReadData, *mut std::os::raw::c_void>( + std::ptr::addr_of_mut!(read_data), + ); + let mut sz = (*self).size; + if sz % 4 != 0 { + sz += 2; + } + wgpuBufferMapAsync( + (*self).buffer.into(), + WGPUMapMode_WGPUMapMode_Read, + 0, + sz as usize, + Some(buffer_reader_cb), + data_ptr, + ); + + let mut read_done = read_data.read_done.lock().unwrap(); + let should_process = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(true)); + let spt = should_process.clone(); + let instance = DawnInstance { + instance: wgpuAdapterGetInstance(wgpuDeviceGetAdapter((*device).device)), + }; + let handle = std::thread::spawn(move || { + let inst = instance; + while spt.load(std::sync::atomic::Ordering::Relaxed) { + wgpuInstanceProcessEvents(inst.instance); + std::thread::sleep(std::time::Duration::from_micros(10)); + } + }); + while !*read_done { + let res = read_data + .cv + .wait_timeout(read_done, std::time::Duration::from_micros(100)) + .unwrap(); + read_done = res.0; + } + should_process.store(false, std::sync::atomic::Ordering::Relaxed); + handle.join().unwrap(); + + let mpd_rng = + wgpuBufferGetConstMappedRange((*self).buffer.into(), 0, (*self).size as usize); + let slice = std::slice::from_raw_parts(mpd_rng as *const u8, (*self).size as usize); + slice.to_vec() + } + } + + fn size(&self) -> u64 { + (*self).size + } +} + +pub type DawnBufferUsages = u32; +pub const MAP_READ: DawnBufferUsages = WGPUBufferUsage_WGPUBufferUsage_MapRead; +pub const COPY_SRC: DawnBufferUsages = WGPUBufferUsage_WGPUBufferUsage_CopySrc; +pub const COPY_DST: DawnBufferUsages = WGPUBufferUsage_WGPUBufferUsage_CopyDst; +pub const STORAGE: DawnBufferUsages = WGPUBufferUsage_WGPUBufferUsage_Storage; + +#[derive(Debug)] +pub struct DawnCommandBuffer { + buffer: WGPUCommandBuffer, +} +impl CommandBuffer for DawnCommandBuffer {} + +#[derive(Debug)] +pub struct DawnCommandEncoder { + encoder: WGPUCommandEncoder, +} +unsafe impl Send for DawnCommandEncoder {} +unsafe impl Sync for DawnCommandEncoder {} + +impl CommandEncoder + for DawnCommandEncoder +{ + fn dispatch_compute_pass( + &mut self, + desc: &ComputePassDescriptor, + pipeline: Arc, + bind_group: DawnBindGroup, + work_group: WorkGroup, + ) { + let label = match desc.label { + Some(name) => name, + None => "", + }; + let pass_desc = WGPUComputePassDescriptor { + nextInChain: std::ptr::null(), + label: std::ffi::CString::new(label).unwrap().into_raw(), + timestampWrites: std::ptr::null(), + }; + let pass: WGPUComputePassEncoder; + unsafe { + pass = wgpuCommandEncoderBeginComputePass(self.encoder.into(), &pass_desc); + } + unsafe { + wgpuComputePassEncoderSetPipeline(pass, pipeline.pipeline.into()); + wgpuComputePassEncoderSetBindGroup( + pass, + 0, + bind_group.bind_group.into(), + 0, + (&[]).as_ptr(), + ); + wgpuComputePassEncoderDispatchWorkgroups( + pass, + work_group.x, + work_group.y, + work_group.z, + ); + } + unsafe { + wgpuComputePassEncoderEnd(pass); + } + } + + fn copy_buffer_to_buffer( + &mut self, + src: &DawnBuffer, + src_offset: u64, + dest: &DawnBuffer, + dest_offset: u64, + size: u64, + ) { + unsafe { + wgpuCommandEncoderCopyBufferToBuffer( + (*self).encoder.into(), + (*src).buffer.into(), + src_offset, + (*dest).buffer.into(), + dest_offset, + size, + ); + } + } + + fn finish(self) -> DawnCommandBuffer { + let cmd_buf_desc = WGPUCommandBufferDescriptor { + nextInChain: std::ptr::null(), + label: std::ptr::null(), + }; + let cmd_buf: WGPUCommandBuffer; + unsafe { + cmd_buf = wgpuCommandEncoderFinish(self.encoder.into(), &cmd_buf_desc); + } + DawnCommandBuffer { buffer: cmd_buf } + } +} + +#[derive(Debug)] +pub struct DawnComputePipeline { + pipeline: WGPUComputePipeline, +} +unsafe impl Send for DawnComputePipeline {} +unsafe impl Sync for DawnComputePipeline {} + +impl ComputePipeline for DawnComputePipeline { + fn get_bind_group_layout(&self, id: u32) -> DawnBindGroupLayout { + let layout: WGPUBindGroupLayout; + unsafe { + layout = wgpuComputePipelineGetBindGroupLayout((*self).pipeline.into(), id); + } + DawnBindGroupLayout { layout: layout } + } +} + +#[derive(Debug)] +pub struct DawnDevice { + device: WGPUDevice, +} +unsafe impl Send for DawnDevice {} +unsafe impl Sync for DawnDevice {} + +impl + Device< + DawnBindGroup, + DawnBindGroupLayout, + DawnBuffer, + DawnCommandEncoder, + DawnComputePipeline, + DawnPipelineLayout, + DawnShaderModule, + > for DawnDevice +{ + fn create_bind_group( + &self, + desc: &BindGroupDescriptor<'_, DawnBindGroupLayout, DawnBuffer>, + ) -> DawnBindGroup { + let entries = (*desc) + .entries + .iter() + .map(|entry| { + let resource = match &entry.resource { + BindingResource::Buffer(res) => res, + }; + WGPUBindGroupEntry { + nextInChain: std::ptr::null(), + binding: entry.binding, + buffer: resource.buffer.buffer.into(), + offset: resource.offset, + size: resource.size.unwrap().get(), + sampler: std::ptr::null_mut(), + textureView: std::ptr::null_mut(), + } + }) + .collect::>(); + let label = match desc.label { + None => std::ptr::null(), + Some(name) => std::ffi::CString::new(name).unwrap().into_raw(), + }; + let bg_desc = WGPUBindGroupDescriptor { + nextInChain: std::ptr::null(), + label: label, + layout: (*desc).layout.layout, + entryCount: entries.len(), + entries: entries.as_ptr(), + }; + let bind_group: WGPUBindGroup; + unsafe { + bind_group = wgpuDeviceCreateBindGroup((*self).device.into(), &bg_desc); + } + DawnBindGroup { + bind_group: bind_group, + } + } + + fn create_buffer(&self, desc: &BufferDescriptor) -> DawnBuffer { + let label = match desc.label { + None => std::ptr::null(), + Some(name) => std::ffi::CString::new(name).unwrap().into_raw(), + }; + let buf_desc = WGPUBufferDescriptor { + nextInChain: std::ptr::null(), + label: label, + usage: (*desc).usage, + size: (*desc).size, + mappedAtCreation: (*desc).mapped_at_creation as u32, + }; + let buffer: WGPUBuffer; + unsafe { + buffer = wgpuDeviceCreateBuffer((*self).device.into(), &buf_desc); + } + DawnBuffer { + buffer: buffer, + size: (*desc).size, + } + } + + fn create_buffer_init(&self, desc: &BufferInitDescriptor) -> DawnBuffer { + let label = match desc.label { + None => std::ptr::null(), + Some(name) => std::ffi::CString::new(name).unwrap().into_raw(), + }; + let buf_desc = WGPUBufferDescriptor { + nextInChain: std::ptr::null(), + label: label, + usage: (*desc).usage, + size: (*desc).contents.len() as u64, + mappedAtCreation: 1, + }; + let buffer: WGPUBuffer; + unsafe { + buffer = wgpuDeviceCreateBuffer((*self).device.into(), &buf_desc); + let data = wgpuBufferGetMappedRange(buffer, 0, (*desc).contents.len()); + let src_ptr = &(*desc).contents[0] as *const u8; + std::ptr::copy_nonoverlapping(src_ptr, data as *mut u8, (*desc).contents.len()); + wgpuBufferUnmap(buffer); + } + DawnBuffer { + buffer: buffer, + size: (*desc).contents.len() as u64, + } + } + + fn create_command_encoder(&self, desc: &CommandEncoderDescriptor) -> DawnCommandEncoder { + let label = match desc.label { + None => std::ptr::null(), + Some(name) => std::ffi::CString::new(name).unwrap().into_raw(), + }; + let encoder_desc = WGPUCommandEncoderDescriptor { + nextInChain: std::ptr::null(), + label: label, + }; + let encoder: WGPUCommandEncoder; + unsafe { + encoder = wgpuDeviceCreateCommandEncoder((*self).device.into(), &encoder_desc); + } + DawnCommandEncoder { encoder: encoder } + } + + fn create_compute_pipeline( + &self, + desc: &ComputePipelineDescriptor, + ) -> DawnComputePipeline { + let label = match desc.label { + None => std::ptr::null(), + Some(name) => std::ffi::CString::new(name).unwrap().into_raw(), + }; + let layout = match desc.layout { + None => std::ptr::null_mut(), + Some(layout) => layout.layout, + }; + let pip_desc = WGPUComputePipelineDescriptor { + nextInChain: std::ptr::null(), + label: label, + layout: layout, + compute: WGPUProgrammableStageDescriptor { + nextInChain: std::ptr::null(), + module: (*(*desc).module).module, + entryPoint: std::ffi::CString::new((*desc).entry_point) + .unwrap() + .into_raw(), + constantCount: 0, + constants: std::ptr::null(), + }, + }; + let pipeline: WGPUComputePipeline; + unsafe { + pipeline = wgpuDeviceCreateComputePipeline((*self).device.into(), &pip_desc); + } + DawnComputePipeline { pipeline: pipeline } + } + + fn create_shader_module(&self, desc: &ShaderModuleDescriptor) -> DawnShaderModule { + let label = match desc.label { + None => std::ptr::null(), + Some(name) => std::ffi::CString::new(name).unwrap().into_raw(), + }; + let src = match &desc.source { + ShaderSource::Wgsl(source) => source.to_string(), + }; + let wgsl_desc = WGPUShaderModuleWGSLDescriptor { + chain: WGPUChainedStruct { + next: std::ptr::null(), + sType: WGPUSType_WGPUSType_ShaderModuleWGSLDescriptor, + }, + code: std::ffi::CString::new(src).unwrap().into_raw(), + }; + let module: WGPUShaderModule; + unsafe { + let sh_desc = WGPUShaderModuleDescriptor { + nextInChain: std::mem::transmute::< + *const WGPUShaderModuleWGSLDescriptor, + *const WGPUChainedStruct, + >(&wgsl_desc), + label: label, + }; + module = wgpuDeviceCreateShaderModule((*self).device.into(), &sh_desc); + } + DawnShaderModule { module: module } + } +} + +#[derive(Debug)] +pub struct DawnInstance { + instance: WGPUInstance, +} +unsafe impl Send for DawnInstance {} + +#[derive(Debug)] +pub struct DawnPipelineLayout { + layout: WGPUPipelineLayout, +} +impl PipelineLayout for DawnPipelineLayout {} + +#[derive(Debug)] +pub struct DawnQueue { + queue: WGPUQueue, +} +unsafe impl Send for DawnQueue {} + +impl Queue for DawnQueue { + fn submit(&self, buf: Option) { + match buf { + None => (), + Some(buf) => unsafe { + wgpuQueueSubmit((*self).queue.into(), 1, std::ptr::addr_of!(buf.buffer)); + }, + }; + } + + fn write_buffer(&self, buffer: &DawnBuffer, offset: u64, data: &[u8]) { + unsafe { + let data_ptr = + std::mem::transmute::<*const u8, *const std::os::raw::c_void>(data.as_ptr()); + let mut sz = data.len(); + if sz % 4 != 0 { + sz += 2; + } + wgpuQueueWriteBuffer( + (*self).queue.into(), + (*buffer).buffer.into(), + offset, + data_ptr, + sz, + ); + } + } +} + +#[derive(Debug)] +pub struct DawnShaderModule { + module: WGPUShaderModule, +} +impl ShaderModule for DawnShaderModule {} + +/// The compute instance is shared across all [dawn runtimes](WgpuRuntime). +static RUNTIME: ComputeRuntime> = + ComputeRuntime::new(); + +type Server = WgpuServer>>; + +impl WebGPUApi for DawnApi { + type Adapter = DawnAdapter; + type AdapterInfo = DawnAdapterInfo; + type Backend = DawnBackend; + type BindGroup = DawnBindGroup; + type BindGroupLayout = DawnBindGroupLayout; + type Buffer = DawnBuffer; + type CommandBuffer = DawnCommandBuffer; + type CommandEncoder = DawnCommandEncoder; + type ComputePipeline = DawnComputePipeline; + type Device = DawnDevice; + type PipelineLayout = DawnPipelineLayout; + type Queue = DawnQueue; + type ShaderModule = DawnShaderModule; + + const MAP_READ: u32 = MAP_READ; + const COPY_SRC: u32 = COPY_SRC; + const COPY_DST: u32 = COPY_DST; + const STORAGE: u32 = STORAGE; + + type Server = WgpuServer>>; + type Channel = MutexComputeChannel>>>; + + fn client(device: &WgpuDevice) -> ComputeClient { + RUNTIME.client(device, move || { + pollster::block_on(create_client::( + device, + RuntimeOptions::default(), + )) + }) + } + + async fn select_device(adapter: &DawnAdapter) -> (DawnDevice, DawnQueue) { + let mut req_data = DevRequestData { + device: std::ptr::null::() as WGPUDevice, + is_set: std::sync::Mutex::new(false), + cv: std::sync::Condvar::new(), + }; + let desc = WGPUDeviceDescriptor { + nextInChain: std::ptr::null(), + label: std::ptr::null(), + requiredFeatureCount: 1, + requiredFeatures: &WGPUFeatureName_WGPUFeatureName_ShaderF16, + requiredLimits: std::ptr::null(), + defaultQueue: WGPUQueueDescriptor { + nextInChain: std::ptr::null(), + label: std::ptr::null(), + }, + deviceLostCallback: None, + deviceLostCallbackInfo: WGPUDeviceLostCallbackInfo { + nextInChain: std::ptr::null(), + mode: 0, + callback: None, + userdata: std::ptr::null_mut(), + }, + deviceLostUserdata: std::ptr::null_mut(), + uncapturedErrorCallbackInfo: WGPUUncapturedErrorCallbackInfo { + nextInChain: std::ptr::null(), + callback: None, + userdata: std::ptr::null_mut(), + }, + }; + + unsafe { + let data_ptr = std::mem::transmute::<*mut DevRequestData, *mut std::os::raw::c_void>( + std::ptr::addr_of_mut!(req_data), + ); + wgpuAdapterRequestDevice( + (*adapter).adapter.into(), + &desc, + Some(request_device_cb), + data_ptr, + ); + } + + let mut is_set = req_data.is_set.lock().unwrap(); + while !*is_set { + is_set = req_data.cv.wait(is_set).unwrap(); + } + + unsafe { + wgpuDeviceSetUncapturedErrorCallback( + req_data.device, + Some(device_error_callback), + std::ptr::null_mut(), + ); + wgpuDeviceSetLoggingCallback( + req_data.device, + Some(device_logging_callback), + std::ptr::null_mut(), + ); + } + + let dev = DawnDevice { + device: req_data.device, + }; + let queue: WGPUQueue; + unsafe { + queue = wgpuDeviceGetQueue(dev.device.into()); + } + (dev, DawnQueue { queue }) + } + + fn select_adapter(_: &WgpuDevice) -> DawnAdapter { + let instance: WGPUInstance; + let instance_desc = WGPUInstanceDescriptor { + nextInChain: std::ptr::null(), + features: WGPUInstanceFeatures { + nextInChain: std::ptr::null(), + timedWaitAnyEnable: 0, + timedWaitAnyMaxCount: 0, + }, + }; + unsafe { + instance = wgpuCreateInstance(&instance_desc); + } + let mut req_data = AdapterRequestData { + adapter: std::ptr::null::() as WGPUAdapter, + is_set: std::sync::Mutex::new(false), + cv: std::sync::Condvar::new(), + }; + unsafe { + let data_ptr = std::mem::transmute::<*mut AdapterRequestData, *mut std::os::raw::c_void>( + std::ptr::addr_of_mut!(req_data), + ); + wgpuInstanceRequestAdapter( + instance, + std::ptr::null(), + Some(request_adapter_cb), + data_ptr, + ); + } + + let mut is_set = req_data.is_set.lock().unwrap(); + while !*is_set { + is_set = req_data.cv.wait(is_set).unwrap(); + } + + DawnAdapter { + adapter: req_data.adapter, + } + } + + fn device_poll(device: &DawnDevice) { + let instance: WGPUInstance; + let dev = (*device).device; + unsafe { + instance = wgpuAdapterGetInstance(wgpuDeviceGetAdapter(dev.into())); + wgpuInstanceProcessEvents(instance.into()); + wgpuDeviceTick(dev.into()); + } + } + + fn init_sync(device: &WgpuDevice, options: RuntimeOptions) { + let device = Arc::new(device); + let client = pollster::block_on(create_client::(&device, options)); + + RUNTIME.register(&device, client) + } + + async fn init_async(device: &WgpuDevice, options: RuntimeOptions) { + let device = Arc::new(device); + let client = create_client::(&device, options).await; + + RUNTIME.register(&device, client) + } +} + +#[allow(non_upper_case_globals)] +extern "C" fn device_error_callback( + type_: WGPUErrorType, + message: *const ::std::os::raw::c_char, + _userdata: *mut ::std::os::raw::c_void, +) { + let type_str = match type_ { + WGPUErrorType_WGPUErrorType_Validation => "Validation", + WGPUErrorType_WGPUErrorType_OutOfMemory => "Out of memory", + WGPUErrorType_WGPUErrorType_Internal => "Internal", + WGPUErrorType_WGPUErrorType_Unknown => "Unknown", + WGPUErrorType_WGPUErrorType_DeviceLost => "Device lost", + _ => "", + }; + unsafe { + let msg_str = std::ffi::CStr::from_ptr(message).to_str().unwrap(); + println!("{} error: {}", type_str, msg_str); + } +} + +extern "C" fn device_logging_callback( + _type_: WGPULoggingType, + message: *const ::std::os::raw::c_char, + _userdata: *mut ::std::os::raw::c_void, +) { + unsafe { + let msg_str = std::ffi::CStr::from_ptr(message).to_str().unwrap(); + println!("Device log: {}", msg_str); + } +} + +extern "C" fn request_device_cb( + _status: WGPURequestDeviceStatus, + device: WGPUDevice, + _message: *const ::std::os::raw::c_char, + userdata: *mut ::std::os::raw::c_void, +) { + unsafe { + let req_data = + std::mem::transmute::<*mut std::os::raw::c_void, *mut DevRequestData>(userdata); + (*req_data).device = device; + let mut is_set = (*req_data).is_set.lock().unwrap(); + *is_set = true; + (*req_data).cv.notify_one(); + } +} + +#[repr(C)] +struct DevRequestData { + device: WGPUDevice, + is_set: std::sync::Mutex, + cv: std::sync::Condvar, +} + +extern "C" fn request_adapter_cb( + _status: WGPURequestAdapterStatus, + adapter: WGPUAdapter, + _message: *const ::std::os::raw::c_char, + userdata: *mut ::std::os::raw::c_void, +) { + unsafe { + let req_data = + std::mem::transmute::<*mut std::os::raw::c_void, *mut AdapterRequestData>(userdata); + (*req_data).adapter = adapter; + let mut is_set = (*req_data).is_set.lock().unwrap(); + *is_set = true; + (*req_data).cv.notify_one(); + } +} + +#[repr(C)] +struct AdapterRequestData { + adapter: WGPUAdapter, + is_set: std::sync::Mutex, + cv: std::sync::Condvar, +} + +#[repr(C)] +struct BufferReadData { + read_done: std::sync::Mutex, + cv: std::sync::Condvar, +} + +unsafe extern "C" fn buffer_reader_cb( + _status: WGPUBufferMapAsyncStatus, + userdata: *mut ::std::os::raw::c_void, +) { + unsafe { + let read_data = + std::mem::transmute::<*mut std::os::raw::c_void, *mut BufferReadData>(userdata); + let mut read_done = (*read_data).read_done.lock().unwrap(); + (*read_done) = true; + (*read_data).cv.notify_one(); + } +} diff --git a/crates/burn-wgpu/src/compute/dawn_native_bindings.rs b/crates/burn-wgpu/src/compute/dawn_native_bindings.rs new file mode 100644 index 0000000000..5f3a03ae25 --- /dev/null +++ b/crates/burn-wgpu/src/compute/dawn_native_bindings.rs @@ -0,0 +1,6 @@ +#![allow(non_upper_case_globals)] +#![allow(non_camel_case_types)] +#![allow(non_snake_case)] +#![allow(dead_code)] + +include!(concat!(env!("OUT_DIR"), "/dawn_native_bindings_gen.rs")); diff --git a/crates/burn-wgpu/src/compute/mod.rs b/crates/burn-wgpu/src/compute/mod.rs index 8b996dcafc..9382e3d30b 100644 --- a/crates/burn-wgpu/src/compute/mod.rs +++ b/crates/burn-wgpu/src/compute/mod.rs @@ -1,7 +1,17 @@ +#[cfg(feature = "dawn")] +mod dawn_api_shim; +#[cfg(feature = "dawn")] +mod dawn_native_bindings; mod server; mod storage; -mod wgpu_api_shim; +mod webgpu_api; +#[cfg(feature = "wgpu")] +pub mod wgpu_api_shim; +#[cfg(feature = "dawn")] +pub use dawn_api_shim::*; pub use server::*; pub use storage::*; +pub use webgpu_api::*; +#[cfg(feature = "wgpu")] pub use wgpu_api_shim::*; diff --git a/crates/burn-wgpu/src/compute/server.rs b/crates/burn-wgpu/src/compute/server.rs index b6e8b2f620..fef053f47d 100644 --- a/crates/burn-wgpu/src/compute/server.rs +++ b/crates/burn-wgpu/src/compute/server.rs @@ -1,11 +1,5 @@ use super::WgpuStorage; -use crate::compute::{ - webgpu_device_poll, webgpu_read_buffer, WebGPUBindGroup, WebGPUBindGroupDescriptor, - WebGPUBindGroupEntry, WebGPUBuffer, WebGPUBufferDescriptor, WebGPUCommandEncoder, - WebGPUCommandEncoderDescriptor, WebGPUComputePassDescriptor, WebGPUComputePipeline, - WebGPUComputePipelineDescriptor, WebGPUDevice, WebGPUQueue, WebGPUShaderModuleDescriptor, - WebGPUShaderSource, COPY_DST, MAP_READ, -}; +use crate::compute::webgpu_api::*; use alloc::{borrow::Cow, sync::Arc}; use burn_compute::{ memory_management::MemoryManagement, @@ -17,28 +11,29 @@ use hashbrown::HashMap; /// Wgpu compute server. #[derive(Debug)] -pub struct WgpuServer> { +pub struct WgpuServer>> { memory_management: MM, - device: Arc, - queue: WebGPUQueue, - encoder: WebGPUCommandEncoder, - pipelines: HashMap>, + device: Arc, + queue: W::Queue, + encoder: W::CommandEncoder, + pipelines: HashMap>, tasks_max: usize, tasks_count: usize, } -impl WgpuServer +impl WgpuServer where - MM: MemoryManagement, + W: WebGPUApi, + MM: MemoryManagement>, { /// Create a new server. pub fn new( memory_management: MM, - device: Arc, - queue: WebGPUQueue, + device: Arc, + queue: W::Queue, tasks_max: usize, ) -> Self { - let encoder = device.create_command_encoder(&WebGPUCommandEncoderDescriptor { + let encoder = device.create_command_encoder(&CommandEncoderDescriptor { label: Some("Command Encoder"), }); @@ -56,7 +51,7 @@ where fn submit(&mut self) { let mut new_encoder = self .device - .create_command_encoder(&WebGPUCommandEncoderDescriptor { label: None }); + .create_command_encoder(&CommandEncoderDescriptor { label: None }); core::mem::swap(&mut new_encoder, &mut self.encoder); self.queue.submit(Some(new_encoder.finish())); @@ -68,25 +63,24 @@ where fn register_compute( &mut self, - pipeline: Arc, - bind_group: BindGroup, + pipeline: Arc, + bind_group: W::BindGroup, work_group: WorkGroup, ) { - let mut compute = self + self .encoder - .begin_compute_pass(&WebGPUComputePassDescriptor { - label: None, - timestamp_writes: None, - }); - - compute.set_pipeline(&pipeline); - compute.set_bind_group(0, &bind_group, &[]); - compute.dispatch_workgroups(work_group.x, work_group.y, work_group.z); + .dispatch_compute_pass(&ComputePassDescriptor { + label: None, + }, + pipeline, + bind_group, + work_group, + ); self.tasks_count += 1; } - fn pipeline(&mut self, kernel: Kernel) -> Arc { + fn pipeline(&mut self, kernel: Kernel) -> Arc { let kernel_id = kernel.id(); if let Some(pipeline) = self.pipelines.get(&kernel_id) { return pipeline.clone(); @@ -99,17 +93,15 @@ where pipeline } - fn compile_source(&self, source: &str) -> Arc { - let module = self - .device - .create_shader_module(WebGPUShaderModuleDescriptor { - label: None, - source: WebGPUShaderSource::Wgsl(Cow::Borrowed(source)), - }); + fn compile_source(&self, source: &str) -> Arc { + let module = self.device.create_shader_module(&ShaderModuleDescriptor { + label: None, + source: ShaderSource::Wgsl(Cow::Borrowed(source)), + }); Arc::new( self.device - .create_compute_pipeline(&WebGPUComputePipelineDescriptor { + .create_compute_pipeline(&ComputePipelineDescriptor { label: None, layout: None, module: &module, @@ -118,14 +110,14 @@ where ) } - fn buffer_reader(&mut self, handle: server::Binding) -> BufferReader { + fn buffer_reader(&mut self, handle: server::Binding) -> BufferReader { let resource = self.memory_management.get(handle.memory); let size = resource.size(); - let buffer_dest = self.device.create_buffer(&WebGPUBufferDescriptor { + let buffer_dest = self.device.create_buffer(&BufferDescriptor { label: None, size, - usage: MAP_READ | COPY_DST, + usage: W::MAP_READ | W::COPY_DST, mapped_at_creation: false, }); @@ -145,22 +137,32 @@ where } #[derive(new)] -struct BufferReader { - buffer: WebGPUBuffer, +struct BufferReader { + buffer: W::Buffer, } -impl BufferReader { - fn read(&self, device: &WebGPUDevice) -> Vec { - webgpu_read_buffer(&self.buffer, device) +impl BufferReader +where + W: WebGPUApi, +{ + #[cfg(target_family = "wasm")] + async fn read(self, device: alloc::sync::Arc) -> Vec { + self.buffer.read(&device).await + } + + #[cfg(not(target_family = "wasm"))] + fn read(self, device: &W::Device) -> Vec { + pollster::block_on(self.buffer.read(device)) } } -impl ComputeServer for WgpuServer +impl ComputeServer for WgpuServer where - MM: MemoryManagement, + W: WebGPUApi, + MM: MemoryManagement>, { type Kernel = Kernel; - type Storage = WgpuStorage; + type Storage = WgpuStorage; type MemoryManagement = MM; type AutotuneKey = JitAutotuneKey; @@ -187,7 +189,7 @@ where let buffer_src = Arc::new(self.device.create_buffer_init(&BufferInitDescriptor { label: Some("Buffer Src"), contents: data, - usage: wgpu::BufferUsages::COPY_SRC, + usage: W::COPY_SRC, })); let resource = self.memory_management.get(binding.memory); @@ -199,6 +201,7 @@ where resource.offset(), buffer_src.size(), ); + self.tasks_count += 1; handle } @@ -220,13 +223,13 @@ where let entries = memory_handles .iter() .enumerate() - .map(|(i, buffer)| WebGPUBindGroupEntry { + .map(|(i, buffer)| BindGroupEntry { binding: i as u32, resource: buffer.as_binding(), }) .collect::>(); - let bind_group = self.device.create_bind_group(&WebGPUBindGroupDescriptor { + let bind_group = self.device.create_bind_group(&BindGroupDescriptor { label: None, layout: &group_layout, entries: &entries, @@ -241,6 +244,6 @@ where fn sync(&mut self) { self.submit(); - webgpu_device_poll(&self.device); + W::device_poll(&self.device); } } diff --git a/crates/burn-wgpu/src/compute/storage.rs b/crates/burn-wgpu/src/compute/storage.rs index 0ab02727b5..9fa35c5ce4 100644 --- a/crates/burn-wgpu/src/compute/storage.rs +++ b/crates/burn-wgpu/src/compute/storage.rs @@ -1,19 +1,19 @@ -use crate::compute::{ - WebGPUBindingResource, WebGPUBuffer, WebGPUBufferAddress, WebGPUBufferBinding, - WebGPUBufferDescriptor, WebGPUBufferSize, WebGPUDevice, COPY_DST, COPY_SRC, STORAGE, -}; +use crate::compute::{BindingResource, Buffer, BufferBinding, BufferDescriptor, Device, WebGPUApi}; use burn_compute::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization}; use hashbrown::HashMap; use std::{num::NonZeroU64, sync::Arc}; /// Buffer storage for wgpu. -pub struct WgpuStorage { - memory: HashMap>, +pub struct WgpuStorage { + memory: HashMap>, deallocations: Vec, - device: Arc, + device: Arc, } -impl core::fmt::Debug for WgpuStorage { +impl core::fmt::Debug for WgpuStorage +where + W: WebGPUApi, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str(format!("WgpuStorage {{ device: {:?} }}", self.device).as_str()) } @@ -21,32 +21,35 @@ impl core::fmt::Debug for WgpuStorage { /// The memory resource that can be allocated for wgpu. #[derive(new, Debug)] -pub struct WgpuResource { +pub struct WgpuResource { /// The wgpu buffer. - pub buffer: Arc, + pub buffer: Arc, /// How the resource is used. pub kind: WgpuResourceKind, } -impl WgpuResource { +impl WgpuResource +where + W: WebGPUApi, +{ /// Return the binding view of the buffer. - pub fn as_binding(&self) -> WebGPUBindingResource { + pub fn as_binding(&self) -> BindingResource<'_, W::Buffer> { let binding = match &self.kind { WgpuResourceKind::Full => self.buffer.as_entire_buffer_binding(), - WgpuResourceKind::Slice(offs, size) => WebGPUBufferBinding { - buffer: &self.buffer, - offset: *offs, + WgpuResourceKind::Slice { offset, size } => BufferBinding::<'_> { + buffer: self.buffer.as_ref(), + offset: *offset, size: Some(*size), }, }; - WebGPUBindingResource::Buffer(binding) + BindingResource::Buffer(binding) } /// Return the buffer size. pub fn size(&self) -> u64 { match self.kind { WgpuResourceKind::Full => self.buffer.size(), - WgpuResourceKind::Slice(_, size) => size.get(), + WgpuResourceKind::Slice { offset: _, size } => size.get(), } } @@ -54,7 +57,7 @@ impl WgpuResource { pub fn offset(&self) -> u64 { match self.kind { WgpuResourceKind::Full => 0, - WgpuResourceKind::Slice(offset, _) => offset, + WgpuResourceKind::Slice { offset, size: _ } => offset, } } } @@ -65,13 +68,16 @@ pub enum WgpuResourceKind { /// Represents an entire buffer. Full, /// A slice over a buffer. - Slice(WebGPUBufferAddress, WebGPUBufferSize), + Slice { offset: u64, size: NonZeroU64 }, } /// Keeps actual wgpu buffer references in a hashmap with ids as key. -impl WgpuStorage { +impl WgpuStorage +where + W: WebGPUApi, +{ /// Create a new storage on the given [device](WebGPUDevice). - pub fn new(device: Arc) -> Self { + pub fn new(device: Arc) -> Self { Self { memory: HashMap::new(), deallocations: Vec::new(), @@ -89,8 +95,11 @@ impl WgpuStorage { } } -impl ComputeStorage for WgpuStorage { - type Resource = WgpuResource; +impl ComputeStorage for WgpuStorage +where + W: WebGPUApi, +{ + type Resource = WgpuResource; fn get(&mut self, handle: &StorageHandle) -> Self::Resource { let buffer = self.memory.get(&handle.id).unwrap(); @@ -101,7 +110,10 @@ impl ComputeStorage for WgpuStorage { } StorageUtilization::Slice { offset, size } => WgpuResource::new( buffer.clone(), - WgpuResourceKind::Slice(offset as u64, NonZeroU64::new(size as u64).unwrap()), + WgpuResourceKind::Slice { + offset: offset as u64, + size: NonZeroU64::new(size as u64).unwrap(), + }, ), } } @@ -109,14 +121,14 @@ impl ComputeStorage for WgpuStorage { fn alloc(&mut self, size: usize) -> StorageHandle { let id = StorageId::new(); - let buffer = Arc::new(self.device.create_buffer(&WebGPUBufferDescriptor { + let buffer = self.device.create_buffer(&BufferDescriptor { label: None, size: size as u64, - usage: COPY_DST | STORAGE | COPY_SRC, + usage: W::COPY_DST | W::STORAGE | W::COPY_SRC, mapped_at_creation: false, - })); + }); - self.memory.insert(id.clone(), buffer); + self.memory.insert(id.clone(), buffer.into()); StorageHandle::new(id, StorageUtilization::Full(size)) } diff --git a/crates/burn-wgpu/src/compute/webgpu_api.rs b/crates/burn-wgpu/src/compute/webgpu_api.rs new file mode 100644 index 0000000000..28e535218a --- /dev/null +++ b/crates/burn-wgpu/src/compute/webgpu_api.rs @@ -0,0 +1,200 @@ +use crate::{GraphicsApi, RuntimeOptions, WgpuDevice}; +use burn_compute::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer}; +use burn_jit::compute::WorkGroup; +use std::borrow::Cow; +use alloc::sync::Arc; + +pub trait Adapter: core::fmt::Debug { + fn get_info(&self) -> AdapterInfo; +} + +pub trait AdapterInfo: core::fmt::Debug { + fn backend(&self) -> Backend; + fn device(&self) -> DeviceId; +} + +pub trait BindGroup: Send + core::fmt::Debug {} + +pub struct BindGroupDescriptor<'a, BindGroupLayout, Buffer> { + pub label: Option<&'a str>, + pub layout: &'a BindGroupLayout, + pub entries: &'a Vec>, +} + +pub struct BindGroupEntry<'a, Buffer> { + pub binding: u32, + pub resource: BindingResource<'a, Buffer>, +} + +pub trait BindGroupLayout: core::fmt::Debug {} + +pub enum BindingResource<'a, Buffer> { + Buffer(BufferBinding<'a, Buffer>), +} + +pub trait Buffer: Send + Sync + core::fmt::Debug { + fn as_entire_buffer_binding(&self) -> BufferBinding<'_, Buffer>; + fn destroy(&self); + #[allow(async_fn_in_trait)] + async fn read(&self, device: &Device) -> Vec; + fn size(&self) -> u64; +} + +pub struct BufferBinding<'a, Buffer> { + pub buffer: &'a Buffer, + pub offset: u64, + pub size: Option, +} + +pub struct BufferDescriptor<'a> { + pub label: Option<&'a str>, + pub size: u64, + pub usage: u32, + pub mapped_at_creation: bool, +} + +pub struct BufferInitDescriptor<'a> { + pub label: Option<&'a str>, + pub contents: &'a [u8], + pub usage: u32, +} + +pub trait CommandBuffer: core::fmt::Debug {} + +pub struct CommandEncoderDescriptor<'a> { + pub label: Option<&'a str>, +} + +pub trait CommandEncoder: + Send + Sync + core::fmt::Debug +{ + fn dispatch_compute_pass( + &mut self, + desc: &ComputePassDescriptor, + pipeline: Arc, + bind_group: BindGroup, + work_group: WorkGroup, + ); + fn copy_buffer_to_buffer( + &mut self, + src: &Buffer, + src_offset: u64, + dst: &Buffer, + dst_offset: u64, + size: u64, + ); + fn finish(self) -> CommandBuffer; +} + +pub struct ComputePassDescriptor<'a> { + pub label: Option<&'a str>, +} + +pub trait ComputePipeline: Send + Sync + core::fmt::Debug { + fn get_bind_group_layout(&self, id: u32) -> BindGroupLayout; +} + +pub struct ComputePipelineDescriptor<'a, PipelineLayout, ShaderModule> { + pub label: Option<&'a str>, + pub layout: Option<&'a PipelineLayout>, + pub module: &'a ShaderModule, + pub entry_point: &'a str, +} + +pub trait Device< + BindGroup, + BindGroupLayout, + Buffer, + CommandEncoder, + ComputePipeline, + PipelineLayout, + ShaderModule, +>: Send + Sync + core::fmt::Debug +{ + fn create_bind_group( + &self, + desc: &BindGroupDescriptor<'_, BindGroupLayout, Buffer>, + ) -> BindGroup; + fn create_buffer(&self, desc: &BufferDescriptor) -> Buffer; + fn create_buffer_init(&self, desc: &BufferInitDescriptor) -> Buffer; + fn create_command_encoder(&self, desc: &CommandEncoderDescriptor) -> CommandEncoder; + fn create_compute_pipeline( + &self, + desc: &ComputePipelineDescriptor, + ) -> ComputePipeline; + fn create_shader_module(&self, desc: &ShaderModuleDescriptor) -> ShaderModule; +} + +pub type DeviceId = u32; + +pub trait PipelineLayout: core::fmt::Debug {} + +pub trait Queue: Send + core::fmt::Debug { + fn submit(&self, buf: Option); + fn write_buffer(&self, buf: &Buffer, offset: u64, data: &[u8]); +} + +pub enum ShaderSource<'a> { + Wgsl(Cow<'a, str>), +} + +pub trait ShaderModule: core::fmt::Debug {} + +pub struct ShaderModuleDescriptor<'a> { + pub label: Option<&'a str>, + pub source: ShaderSource<'a>, +} + +pub trait WebGPUApi: Send + Sync + core::fmt::Debug + 'static { + type Adapter: Adapter; + type AdapterInfo: AdapterInfo; + type Backend: core::convert::AsRef; + type BindGroup: BindGroup; + type BindGroupLayout: BindGroupLayout; + type Buffer: Buffer; + type CommandBuffer: CommandBuffer; + type CommandEncoder: CommandEncoder< + Self::BindGroup, + Self::Buffer, + Self::CommandBuffer, + Self::ComputePipeline, + >; + type ComputePipeline: ComputePipeline; + type Device: Device< + Self::BindGroup, + Self::BindGroupLayout, + Self::Buffer, + Self::CommandEncoder, + Self::ComputePipeline, + Self::PipelineLayout, + Self::ShaderModule, + >; + type PipelineLayout: PipelineLayout; + type Queue: Queue; + type ShaderModule: ShaderModule; + + const MAP_READ: u32; + const COPY_SRC: u32; + const COPY_DST: u32; + const STORAGE: u32; + + type Server: ComputeServer< + Kernel = burn_jit::compute::Kernel, + AutotuneKey = burn_jit::compute::JitAutotuneKey, + >; + type Channel: ComputeChannel; + + fn client(device: &WgpuDevice) -> ComputeClient; + #[allow(async_fn_in_trait)] + async fn select_device(adapter: &Self::Adapter) -> (Self::Device, Self::Queue); + #[allow(async_fn_in_trait)] + #[cfg(target_family = "wasm")] + async fn select_adapter(device: &WgpuDevice) -> Self::Adapter; + #[cfg(not(target_family = "wasm"))] + fn select_adapter(device: &WgpuDevice) -> Self::Adapter; + fn device_poll(device: &Self::Device); + + fn init_sync(device: &WgpuDevice, options: RuntimeOptions); + #[allow(async_fn_in_trait)] + async fn init_async(device: &WgpuDevice, options: RuntimeOptions); +} diff --git a/crates/burn-wgpu/src/compute/wgpu_api_shim.rs b/crates/burn-wgpu/src/compute/wgpu_api_shim.rs index 12f14a1cb8..29e9656784 100644 --- a/crates/burn-wgpu/src/compute/wgpu_api_shim.rs +++ b/crates/burn-wgpu/src/compute/wgpu_api_shim.rs @@ -1,212 +1,455 @@ -use crate::{GraphicsApi, WgpuDevice}; - -pub type WebGPUAdapter = wgpu::Adapter; -pub type WebGPUAdapterInfo = wgpu::AdapterInfo; -pub type WebGPUBindGroup = wgpu::BindGroup; -pub type WebGPUBindGroupDescriptor<'a> = wgpu::BindGroupDescriptor<'a>; -pub type WebGPUBindGroupEntry<'a> = wgpu::BindGroupEntry<'a>; -pub type WebGPUBindingResource<'a> = wgpu::BindingResource<'a>; -pub type WebGPUBuffer = wgpu::Buffer; -pub type WebGPUBufferAddress = wgpu::BufferAddress; -pub type WebGPUBufferBinding<'a> = wgpu::BufferBinding<'a>; -pub type WebGPUBufferSize = wgpu::BufferSize; -pub type WebGPUBufferDescriptor<'a> = wgpu::BufferDescriptor<'a>; - -pub type WebGPUBufferUsages = wgpu::BufferUsages; -pub const MAP_READ: WebGPUBufferUsages = wgpu::BufferUsages::MAP_READ; -pub const COPY_SRC: WebGPUBufferUsages = wgpu::BufferUsages::COPY_SRC; -pub const COPY_DST: WebGPUBufferUsages = wgpu::BufferUsages::COPY_DST; -pub const STORAGE: WebGPUBufferUsages = wgpu::BufferUsages::STORAGE; - -pub type WebGPUCommandEncoder = wgpu::CommandEncoder; -pub type WebGPUCommandEncoderDescriptor<'a> = wgpu::CommandEncoderDescriptor<'a>; -pub type WebGPUComputePassDescriptor<'a> = wgpu::ComputePassDescriptor<'a>; -pub type WebGPUComputePipeline = wgpu::ComputePipeline; -pub type WebGPUComputePipelineDescriptor<'a> = wgpu::ComputePipelineDescriptor<'a>; -pub type WebGPUDevice = wgpu::Device; -pub type WebGPUQueue = wgpu::Queue; -pub type WebGPUShaderModuleDescriptor<'a> = wgpu::ShaderModuleDescriptor<'a>; -pub type WebGPUShaderSource<'a> = wgpu::ShaderSource<'a>; - -pub async fn webgpu_select_device(adapter: &WebGPUAdapter) -> (WebGPUDevice, WebGPUQueue) { - let limits = adapter.limits(); - - let (device, queue) = adapter - .request_device( - &wgpu::DeviceDescriptor { - label: None, - features: wgpu::Features::empty(), - limits, - }, - None, - ) - .await - .map_err(|err| { - format!( - "Unable to request the device with the adapter {:?}, err {:?}", - adapter.get_info(), - err - ) - }) - .unwrap(); +use crate::{ + compute::{webgpu_api::*, WgpuServer, WgpuStorage}, + create_client, GraphicsApi, RuntimeOptions, WgpuDevice, +}; +use alloc::sync::Arc; +use burn_compute::{ + channel::MutexComputeChannel, client::ComputeClient, memory_management::SimpleMemoryManagement, + ComputeRuntime, +}; +use burn_jit::compute::WorkGroup; - (device, queue) +#[derive(Debug)] +pub struct WgpuApi {} + +pub struct WgpuBackend { + backend: wgpu::Backend, } -#[cfg(target_family = "wasm")] -async fn webgpu_select_adapter(_device: &WgpuDevice) -> WebGPUAdapter { - let instance = wgpu::Instance::default(); +impl Adapter for wgpu::Adapter { + fn get_info(&self) -> wgpu::AdapterInfo { + wgpu::Adapter::get_info(self) + } +} - instance - .request_adapter(&wgpu::RequestAdapterOptionsBase::default()) - .await - .unwrap() +impl AdapterInfo for wgpu::AdapterInfo { + fn backend(&self) -> WgpuBackend { + WgpuBackend { + backend: self.backend, + } + } + + fn device(&self) -> DeviceId { + self.device + } } -#[cfg(not(target_family = "wasm"))] -pub fn webgpu_select_adapter(device: &WgpuDevice) -> WebGPUAdapter { - use wgpu::DeviceType; +impl core::convert::AsRef for WgpuBackend { + fn as_ref(&self) -> &str { + wgpu::Backend::to_str(self.backend) + } +} - let instance = wgpu::Instance::default(); - let mut adapters_other = Vec::new(); - let mut adapters = Vec::new(); +impl BindGroup for wgpu::BindGroup {} - instance - .enumerate_adapters(G::backend().into()) - .for_each(|adapter| { - let device_type = adapter.get_info().device_type; +impl BindGroupLayout for wgpu::BindGroupLayout {} - if let DeviceType::Other = device_type { - adapters_other.push(adapter); - return; - } +impl Buffer for wgpu::Buffer { + fn as_entire_buffer_binding(&self) -> BufferBinding<'_, wgpu::Buffer> { + let binding = wgpu::Buffer::as_entire_buffer_binding(self); + BufferBinding { + buffer: binding.buffer, + offset: binding.offset, + size: binding.size, + } + } - let is_same_type = match device { - WgpuDevice::DiscreteGpu(_) => device_type == DeviceType::DiscreteGpu, - WgpuDevice::IntegratedGpu(_) => device_type == DeviceType::IntegratedGpu, - WgpuDevice::VirtualGpu(_) => device_type == DeviceType::VirtualGpu, - WgpuDevice::Cpu => device_type == DeviceType::Cpu, - WgpuDevice::BestAvailable => true, - }; + fn destroy(&self) { + wgpu::Buffer::destroy(self) + } - if is_same_type { - adapters.push(adapter); - } + async fn read(&self, device: &wgpu::Device) -> Vec { + let buffer_slice = self.slice(..); + let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel(); + buffer_slice.map_async(wgpu::MapMode::Read, move |v| { + sender + .send(v) + .expect("Unable to send buffer slice result to async channel.") }); - fn select( - num: usize, - error: &str, - mut adapters: Vec, - mut adapters_other: Vec, - ) -> wgpu::Adapter { - if adapters.len() <= num { - if adapters_other.len() <= num { - panic!( - "{}, adapters {:?}, other adapters {:?}", - error, - adapters - .into_iter() - .map(|adapter| adapter.get_info()) - .collect::>(), - adapters_other - .into_iter() - .map(|adapter| adapter.get_info()) - .collect::>(), - ); - } else { - return adapters_other.remove(num); - } - } + device.poll(wgpu::Maintain::Wait); - adapters.remove(num) - } - let adapter = match device { - WgpuDevice::DiscreteGpu(num) => select( - *num, - "No Discrete GPU device found", - adapters, - adapters_other, - ), - WgpuDevice::IntegratedGpu(num) => select( - *num, - "No Integrated GPU device found", - adapters, - adapters_other, - ), - WgpuDevice::VirtualGpu(num) => select( - *num, - "No Virtual GPU device found", - adapters, - adapters_other, - ), - WgpuDevice::Cpu => select(0, "No CPU device found", adapters, adapters_other), - WgpuDevice::BestAvailable => { - let mut most_performant_adapter = None; - let mut current_score = -1; - - adapters - .into_iter() - .chain(adapters_other) - .for_each(|adapter| { - let info = adapter.get_info(); - let score = match info.device_type { - DeviceType::DiscreteGpu => 5, - DeviceType::Other => 4, // Let's be optimistic with the Other device, it's - // often a Discrete Gpu. - DeviceType::IntegratedGpu => 3, - DeviceType::VirtualGpu => 2, - DeviceType::Cpu => 1, - }; - - if score > current_score { - most_performant_adapter = Some(adapter); - current_score = score; - } - }); - - if let Some(adapter) = most_performant_adapter { - adapter - } else { - panic!("No adapter found for graphics API {:?}", G::default()); - } + let result = receiver.receive().await; + + if let Some(Ok(())) = result { + let data = buffer_slice.get_mapped_range(); + let result = bytemuck::cast_slice(&data).to_vec(); + + drop(data); + self.unmap(); + result + } else { + panic!("Unable to read buffer {:?}", result) } - }; + } + + fn size(&self) -> u64 { + wgpu::Buffer::size(self) + } +} - log::info!("Using adapter {:?}", adapter.get_info()); +impl CommandBuffer for wgpu::CommandBuffer {} + +impl CommandEncoder + for wgpu::CommandEncoder +{ + fn dispatch_compute_pass( + &mut self, + desc: &ComputePassDescriptor, + pipeline: Arc, + bind_group: wgpu::BindGroup, + work_group: WorkGroup, + ) { + let mut compute = self.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: desc.label, + timestamp_writes: None, + }); - adapter + compute.set_pipeline(&pipeline); + compute.set_bind_group(0, &bind_group, &[]); + compute.dispatch_workgroups(work_group.x, work_group.y, work_group.z); + } + + fn copy_buffer_to_buffer( + &mut self, + src: &wgpu::Buffer, + src_offset: u64, + dst: &wgpu::Buffer, + dst_offset: u64, + size: u64, + ) { + wgpu::CommandEncoder::copy_buffer_to_buffer(self, src, src_offset, dst, dst_offset, size) + } + + fn finish(self) -> wgpu::CommandBuffer { + wgpu::CommandEncoder::finish(self) + } } -pub fn webgpu_read_buffer(buffer: &WebGPUBuffer, device: &WebGPUDevice) -> Vec { - pollster::block_on(webgpu_read_buffer_async(buffer, device)) +impl ComputePipeline for wgpu::ComputePipeline { + fn get_bind_group_layout(&self, id: u32) -> wgpu::BindGroupLayout { + wgpu::ComputePipeline::get_bind_group_layout(self, id) + } } -async fn webgpu_read_buffer_async(buffer: &WebGPUBuffer, device: &WebGPUDevice) -> Vec { - let buffer_slice = buffer.slice(..); - let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel(); - buffer_slice.map_async(wgpu::MapMode::Read, move |v| { - sender - .send(v) - .expect("Unable to send buffer slice result to async channel.") - }); +impl + Device< + wgpu::BindGroup, + wgpu::BindGroupLayout, + wgpu::Buffer, + wgpu::CommandEncoder, + wgpu::ComputePipeline, + wgpu::PipelineLayout, + wgpu::ShaderModule, + > for wgpu::Device +{ + fn create_bind_group( + &self, + desc: &BindGroupDescriptor<'_, wgpu::BindGroupLayout, wgpu::Buffer>, + ) -> wgpu::BindGroup { + let entries = desc + .entries + .iter() + .map(|entry| { + let BindingResource::Buffer(resource) = &entry.resource; + wgpu::BindGroupEntry { + binding: entry.binding, + resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding { + buffer: resource.buffer, + offset: resource.offset, + size: resource.size, + }), + } + }) + .collect::>(); + + wgpu::Device::create_bind_group( + self, + &wgpu::BindGroupDescriptor { + label: desc.label, + layout: desc.layout, + entries: &entries, + }, + ) + } + + fn create_buffer(&self, desc: &BufferDescriptor) -> wgpu::Buffer { + wgpu::Device::create_buffer( + self, + &wgpu::BufferDescriptor { + label: desc.label, + size: desc.size, + usage: wgpu::BufferUsages::from_bits(desc.usage).unwrap(), + mapped_at_creation: desc.mapped_at_creation, + }, + ) + } + + fn create_buffer_init(&self, desc: &BufferInitDescriptor) -> wgpu::Buffer { + wgpu::util::DeviceExt::create_buffer_init( + self, + &wgpu::util::BufferInitDescriptor { + label: desc.label, + contents: desc.contents, + usage: wgpu::BufferUsages::from_bits(desc.usage).unwrap(), + }, + ) + } + + fn create_command_encoder(&self, desc: &CommandEncoderDescriptor) -> wgpu::CommandEncoder { + wgpu::Device::create_command_encoder( + self, + &wgpu::CommandEncoderDescriptor { label: desc.label }, + ) + } + + fn create_compute_pipeline( + &self, + desc: &ComputePipelineDescriptor, + ) -> wgpu::ComputePipeline { + wgpu::Device::create_compute_pipeline( + self, + &wgpu::ComputePipelineDescriptor { + label: desc.label, + layout: desc.layout, + module: desc.module, + entry_point: desc.entry_point, + }, + ) + } - device.poll(wgpu::Maintain::Wait); + fn create_shader_module(&self, desc: &ShaderModuleDescriptor) -> wgpu::ShaderModule { + let source = match &desc.source { + ShaderSource::Wgsl(source) => source.to_string(), + }; + wgpu::Device::create_shader_module( + self, + wgpu::ShaderModuleDescriptor { + label: desc.label, + source: wgpu::ShaderSource::Wgsl(source.into()), + }, + ) + } +} - let result = receiver.receive().await; +impl PipelineLayout for wgpu::PipelineLayout {} - if let Some(Ok(())) = result { - let data = buffer_slice.get_mapped_range(); - let result = bytemuck::cast_slice(&data).to_vec(); +impl Queue for wgpu::Queue { + fn submit(&self, buf: Option) { + wgpu::Queue::submit(self, buf); + } - drop(data); - buffer.unmap(); - result - } else { - panic!("Unable to read buffer {:?}", result) + fn write_buffer(&self, buf: &wgpu::Buffer, offset: u64, data: &[u8]) { + wgpu::Queue::write_buffer(self, buf, offset, data) } } -pub fn webgpu_device_poll(device: &WebGPUDevice) { - device.poll(wgpu::Maintain::Wait); +impl ShaderModule for wgpu::ShaderModule {} + +/// The compute instance is shared across all [wgpu runtimes](WgpuRuntime). +static RUNTIME: ComputeRuntime> = + ComputeRuntime::new(); + +type Server = WgpuServer>>; + +impl WebGPUApi for WgpuApi { + type Adapter = wgpu::Adapter; + type AdapterInfo = wgpu::AdapterInfo; + type Backend = WgpuBackend; + type BindGroup = wgpu::BindGroup; + type BindGroupLayout = wgpu::BindGroupLayout; + type Buffer = wgpu::Buffer; + type CommandBuffer = wgpu::CommandBuffer; + type CommandEncoder = wgpu::CommandEncoder; + type ComputePipeline = wgpu::ComputePipeline; + type Device = wgpu::Device; + type PipelineLayout = wgpu::PipelineLayout; + type Queue = wgpu::Queue; + type ShaderModule = wgpu::ShaderModule; + + const MAP_READ: u32 = wgpu::BufferUsages::MAP_READ.bits(); + const COPY_SRC: u32 = wgpu::BufferUsages::COPY_SRC.bits(); + const COPY_DST: u32 = wgpu::BufferUsages::COPY_DST.bits(); + const STORAGE: u32 = wgpu::BufferUsages::STORAGE.bits(); + + type Server = WgpuServer>>; + type Channel = MutexComputeChannel>>>; + + fn client(device: &WgpuDevice) -> ComputeClient { + RUNTIME.client(device, move || { + pollster::block_on(create_client::( + device, + RuntimeOptions::default(), + )) + }) + } + + async fn select_device(adapter: &wgpu::Adapter) -> (wgpu::Device, wgpu::Queue) { + let limits = adapter.limits(); + + let (device, queue) = adapter + .request_device( + &wgpu::DeviceDescriptor { + label: None, + required_features: wgpu::Features::empty(), + required_limits: limits, + }, + None, + ) + .await + .map_err(|err| { + format!( + "Unable to request the device with the adapter {:?}, err {:?}", + adapter.get_info(), + err + ) + }) + .unwrap(); + + (device, queue) + } + + #[cfg(target_family = "wasm")] + async fn select_adapter(_device: &WgpuDevice) -> Self::Adapter { + let instance = wgpu::Instance::default(); + + instance + .request_adapter(&wgpu::RequestAdapterOptionsBase::default()) + .await + .unwrap() + } + + #[cfg(not(target_family = "wasm"))] + fn select_adapter(device: &WgpuDevice) -> wgpu::Adapter { + use wgpu::DeviceType; + + let instance = wgpu::Instance::default(); + let mut adapters_other = Vec::new(); + let mut adapters = Vec::new(); + + instance + .enumerate_adapters(G::backend().into()) + .into_iter() + .for_each(|adapter| { + let device_type = adapter.get_info().device_type; + + if let DeviceType::Other = device_type { + adapters_other.push(adapter); + return; + } + + let is_same_type = match device { + WgpuDevice::DiscreteGpu(_) => device_type == DeviceType::DiscreteGpu, + WgpuDevice::IntegratedGpu(_) => device_type == DeviceType::IntegratedGpu, + WgpuDevice::VirtualGpu(_) => device_type == DeviceType::VirtualGpu, + WgpuDevice::Cpu => device_type == DeviceType::Cpu, + WgpuDevice::BestAvailable => true, + }; + + if is_same_type { + adapters.push(adapter); + } + }); + + fn select( + num: usize, + error: &str, + mut adapters: Vec, + mut adapters_other: Vec, + ) -> wgpu::Adapter { + if adapters.len() <= num { + if adapters_other.len() <= num { + panic!( + "{}, adapters {:?}, other adapters {:?}", + error, + adapters + .into_iter() + .map(|adapter| adapter.get_info()) + .collect::>(), + adapters_other + .into_iter() + .map(|adapter| adapter.get_info()) + .collect::>(), + ); + } else { + return adapters_other.remove(num); + } + } + + adapters.remove(num) + } + let adapter = match device { + WgpuDevice::DiscreteGpu(num) => select( + *num, + "No Discrete GPU device found", + adapters, + adapters_other, + ), + WgpuDevice::IntegratedGpu(num) => select( + *num, + "No Integrated GPU device found", + adapters, + adapters_other, + ), + WgpuDevice::VirtualGpu(num) => select( + *num, + "No Virtual GPU device found", + adapters, + adapters_other, + ), + WgpuDevice::Cpu => select(0, "No CPU device found", adapters, adapters_other), + WgpuDevice::BestAvailable => { + let mut most_performant_adapter = None; + let mut current_score = -1; + + adapters + .into_iter() + .chain(adapters_other) + .for_each(|adapter| { + let info = adapter.get_info(); + let score = match info.device_type { + DeviceType::DiscreteGpu => 5, + DeviceType::Other => 4, // Let's be optimistic with the Other device, it's + // often a Discrete Gpu. + DeviceType::IntegratedGpu => 3, + DeviceType::VirtualGpu => 2, + DeviceType::Cpu => 1, + }; + + if score > current_score { + most_performant_adapter = Some(adapter); + current_score = score; + } + }); + + if let Some(adapter) = most_performant_adapter { + adapter + } else { + panic!("No adapter found for graphics API {:?}", G::default()); + } + } + }; + + log::info!("Using adapter {:?}", adapter.get_info()); + + adapter + } + + fn device_poll(device: &Self::Device) { + device.poll(wgpu::Maintain::Wait); + } + + fn init_sync(device: &WgpuDevice, options: RuntimeOptions) { + let device = Arc::new(device); + let client = pollster::block_on(create_client::(&device, options)); + + RUNTIME.register(&device, client) + } + + async fn init_async(device: &WgpuDevice, options: RuntimeOptions) { + let device = Arc::new(device); + let client = create_client::(&device, options).await; + + RUNTIME.register(&device, client) + } } diff --git a/crates/burn-wgpu/src/lib.rs b/crates/burn-wgpu/src/lib.rs index 59a246840d..c159e211af 100644 --- a/crates/burn-wgpu/src/lib.rs +++ b/crates/burn-wgpu/src/lib.rs @@ -27,7 +27,14 @@ pub use runtime::*; pub use burn_jit::compute::WorkGroup; pub use burn_jit::{tensor::JitTensor, JitBackend}; -#[cfg(feature = "fusion")] +pub use crate::compute::WebGPUApi; + +#[cfg(feature = "dawn")] +pub use crate::compute::DawnApi; +#[cfg(feature = "wgpu")] +pub use crate::compute::WgpuApi; + +#[cfg(all(feature = "fusion", feature = "wgpu"))] /// Tensor backend that uses the [wgpu] crate for executing GPU compute shaders. /// /// This backend can target multiple graphics APIs, including: @@ -45,9 +52,9 @@ pub use burn_jit::{tensor::JitTensor, JitBackend}; /// You can disable the `fusion` feature flag to remove that functionality, which might be /// necessary on `wasm` for now. pub type Wgpu = - burn_fusion::Fusion, F, I>>; + burn_fusion::Fusion, F, I>>; -#[cfg(not(feature = "fusion"))] +#[cfg(all(not(feature = "fusion"), feature = "wgpu"))] /// Tensor backend that uses the [wgpu] crate for executing GPU compute shaders. /// /// This backend can target multiple graphics APIs, including: @@ -64,13 +71,61 @@ pub type Wgpu = /// /// You can enable the `fusion` feature flag to add that functionality, which might improve /// performance. -pub type Wgpu = JitBackend, F, I>; +pub type Wgpu = JitBackend, F, I>; + +#[cfg(all(feature = "fusion", feature = "dawn"))] +/// Tensor backend that uses Dawn for executing GPU compute shaders. +/// +/// This backend can target multiple graphics APIs, including: +/// - [Vulkan] on Linux, Windows, and Android. +/// - [OpenGL](crate::OpenGl) on Linux, Windows, and Android. +/// - [DirectX 12](crate::Dx12) on Windows. +/// - [Metal] on Apple hardware. +/// - [WebGPU](crate::WebGpu) on supported browsers and `wasm` runtimes. +/// +/// # Notes +/// +/// This version of the Dawn backend uses [burn_fusion] to compile and optimize streams of tensor +/// operations for improved performance. +/// +/// You can disable the `fusion` feature flag to remove that functionality, which might be +/// necessary on `wasm` for now. +pub type Dawn = + burn_fusion::Fusion, F, I>>; + +#[cfg(all(not(feature = "fusion"), feature = "dawn"))] +/// Tensor backend that uses Dawn for executing GPU compute shaders. +/// +/// This backend can target multiple graphics APIs, including: +/// - [Vulkan] on Linux, Windows, and Android. +/// - [OpenGL](crate::OpenGl) on Linux, Windows, and Android. +/// - [DirectX 12](crate::Dx12) on Windows. +/// - [Metal] on Apple hardware. +/// - [WebGPU](crate::WebGpu) on supported browsers and `wasm` runtimes. +/// +/// # Notes +/// +/// This version of the Dawn backend doesn't use [burn_fusion] to compile and optimize streams of tensor +/// operations. +/// +/// You can enable the `fusion` feature flag to add that functionality, which might improve +/// performance. +pub type Dawn = JitBackend, F, I>; + +#[cfg(all(test, feature = "wgpu"))] +mod tests_wgpu { + use super::*; + + pub type TestRuntime = crate::WgpuRuntime; + + burn_jit::testgen_all!(); +} -#[cfg(test)] -mod tests { +#[cfg(all(test, feature = "dawn"))] +mod tests_dawn { use super::*; - pub type TestRuntime = crate::WgpuRuntime; + pub type TestRuntime = crate::WgpuRuntime; burn_jit::testgen_all!(); } diff --git a/crates/burn-wgpu/src/runtime.rs b/crates/burn-wgpu/src/runtime.rs index ae0b63445e..ae85b01af5 100644 --- a/crates/burn-wgpu/src/runtime.rs +++ b/crates/burn-wgpu/src/runtime.rs @@ -1,9 +1,6 @@ use crate::{ compiler::wgsl, - compute::{ - webgpu_select_adapter, webgpu_select_device, WebGPUAdapter, WebGPUAdapterInfo, - WebGPUDevice, WebGPUQueue, WgpuServer, WgpuStorage, - }, + compute::{Adapter, AdapterInfo, WebGPUApi, WgpuServer, WgpuStorage}, GraphicsApi, WgpuDevice, }; use alloc::sync::Arc; @@ -13,7 +10,6 @@ use burn_compute::{ client::ComputeClient, memory_management::{DeallocStrategy, SimpleMemoryManagement, SliceStrategy}, tune::Tuner, - ComputeRuntime, }; use burn_jit::Runtime; use burn_tensor::backend::{DeviceId, DeviceOps}; @@ -23,27 +19,20 @@ use std::marker::PhantomData; /// /// The [graphics api](GraphicsApi) type is passed as generic. #[derive(Debug)] -pub struct WgpuRuntime { +pub struct WgpuRuntime { + _w: PhantomData, _g: PhantomData, } -/// The compute instance is shared across all [wgpu runtimes](WgpuRuntime). -static RUNTIME: ComputeRuntime> = - ComputeRuntime::new(); - -type Server = WgpuServer>; - -impl Runtime for WgpuRuntime { +impl Runtime for WgpuRuntime { type Compiler = wgsl::WgslCompiler; - type Server = WgpuServer>; + type Server = W::Server; - type Channel = MutexComputeChannel>>; + type Channel = W::Channel; type Device = WgpuDevice; fn client(device: &Self::Device) -> ComputeClient { - RUNTIME.client(device, move || { - pollster::block_on(create_client::(device, RuntimeOptions::default())) - }) + W::client::(device) } fn name() -> &'static str { @@ -93,29 +82,26 @@ impl Default for RuntimeOptions { } /// Init the client sync, useful to configure the runtime options. -pub fn init_sync(device: &WgpuDevice, options: RuntimeOptions) { - let device = Arc::new(device); - let client = pollster::block_on(create_client::(&device, options)); - - RUNTIME.register(&device, client) +pub fn init_sync(device: &WgpuDevice, options: RuntimeOptions) { + W::init_sync::(device, options) } /// Init the client async, necessary for wasm. -pub async fn init_async(device: &WgpuDevice, options: RuntimeOptions) { - let device = Arc::new(device); - let client = create_client::(&device, options).await; - - RUNTIME.register(&device, client) +pub async fn init_async( + device: &WgpuDevice, + options: RuntimeOptions, +) { + W::init_async::(device, options).await } -async fn create_client( +pub async fn create_client( device: &WgpuDevice, options: RuntimeOptions, ) -> ComputeClient< - WgpuServer>, - MutexComputeChannel>>, + WgpuServer>>, + MutexComputeChannel>>>, > { - let (device_wgpu, queue, info) = select_device::(device).await; + let (device_wgpu, queue, info) = select_device::(device).await; log::info!( "Created wgpu compute server on device {:?} => {:?}", @@ -130,35 +116,35 @@ async fn create_client( let server = WgpuServer::new(memory_management, device, queue, options.tasks_max); let channel = MutexComputeChannel::new(server); - let tuner_device_id = tuner_device_id(info); + let tuner_device_id = tuner_device_id::(info); ComputeClient::new(channel, Arc::new(RwLock::new(Tuner::new(&tuner_device_id)))) } /// Select the wgpu device and queue based on the provided [device](WgpuDevice). -pub async fn select_device( +pub async fn select_device( device: &WgpuDevice, -) -> (WebGPUDevice, WebGPUQueue, WebGPUAdapterInfo) { +) -> (W::Device, W::Queue, W::AdapterInfo) { #[cfg(target_family = "wasm")] - let adapter = select_adapter::(device).await; + let adapter = select_adapter::(device).await; #[cfg(not(target_family = "wasm"))] - let adapter = select_adapter::(device); + let adapter = select_adapter::(device); - let (device, queue) = webgpu_select_device(&adapter).await; + let (device, queue) = W::select_device(&adapter).await; (device, queue, adapter.get_info()) } -fn tuner_device_id(info: WebGPUAdapterInfo) -> String { - format!("wgpu-{}-{}", info.device, info.backend.to_str()) +fn tuner_device_id(info: W::AdapterInfo) -> String { + format!("wgpu-{}-{}", info.device(), info.backend().as_ref()) } #[cfg(target_family = "wasm")] -async fn select_adapter(_device: &WgpuDevice) -> WebGPUAdapter { - webgpu_select_adapter::(device) +async fn select_adapter(_device: &WgpuDevice) -> W::Adapter { + W::select_adapter::(device) } #[cfg(not(target_family = "wasm"))] -fn select_adapter(device: &WgpuDevice) -> WebGPUAdapter { - webgpu_select_adapter::(device) +fn select_adapter(device: &WgpuDevice) -> W::Adapter { + W::select_adapter::(device) } diff --git a/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs b/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs index 960fc7647d..69e5ba5e27 100644 --- a/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs +++ b/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs @@ -1,5 +1,5 @@ use burn::{ - backend::wgpu::{AutoGraphicsApi, WgpuRuntime}, + backend::wgpu::{AutoGraphicsApi, WgpuApi, WgpuRuntime}, tensor::{Distribution, Tensor}, }; use custom_wgpu_kernel::{ @@ -71,7 +71,7 @@ fn autodiff(device: &B::Device) { } fn main() { - type MyBackend = burn::backend::wgpu::JitBackend, f32, i32>; + type MyBackend = burn::backend::wgpu::JitBackend, f32, i32>; type MyAutodiffBackend = burn::backend::Autodiff; let device = Default::default(); inference::(&device); diff --git a/examples/custom-wgpu-kernel/src/backward.rs b/examples/custom-wgpu-kernel/src/backward.rs index 5a2a03129b..7c2c1c2860 100644 --- a/examples/custom-wgpu-kernel/src/backward.rs +++ b/examples/custom-wgpu-kernel/src/backward.rs @@ -9,13 +9,13 @@ use burn::{ ops::{broadcast_shape, Backward, Ops, OpsKind}, Autodiff, NodeID, }, - wgpu::{FloatElement, GraphicsApi, IntElement, JitBackend, WgpuRuntime}, + wgpu::{FloatElement, GraphicsApi, IntElement, JitBackend, WebGPUApi, WgpuRuntime}, }, tensor::Shape, }; -impl AutodiffBackend - for Autodiff, F, I>> +impl AutodiffBackend + for Autodiff, F, I>> { } diff --git a/examples/custom-wgpu-kernel/src/forward.rs b/examples/custom-wgpu-kernel/src/forward.rs index f23b54c6fe..f4861772fb 100644 --- a/examples/custom-wgpu-kernel/src/forward.rs +++ b/examples/custom-wgpu-kernel/src/forward.rs @@ -4,8 +4,8 @@ use super::Backend; use burn::{ backend::wgpu::{ build_info, into_contiguous, kernel_wgsl, FloatElement, GraphicsApi, IntElement, - JitBackend, JitTensor, Kernel, KernelSource, SourceKernel, SourceTemplate, WgpuRuntime, - WorkGroup, WorkgroupSize, + JitBackend, JitTensor, Kernel, KernelSource, SourceKernel, SourceTemplate, WebGPUApi, + WgpuRuntime, WorkGroup, WorkgroupSize, }, tensor::Shape, }; @@ -37,7 +37,7 @@ impl KernelSource for FusedMatmulAddRelu { } /// Implement our custom backend trait for the existing backend `WgpuBackend`. -impl Backend for JitBackend, F, I> { +impl Backend for JitBackend, F, I> { fn fused_matmul_add_relu( lhs: FloatTensor, rhs: FloatTensor, diff --git a/examples/image-classification-web/src/web.rs b/examples/image-classification-web/src/web.rs index 6fc7d3f8f5..b0c366f4a4 100644 --- a/examples/image-classification-web/src/web.rs +++ b/examples/image-classification-web/src/web.rs @@ -11,7 +11,7 @@ use crate::model::{label::LABELS, normalizer::Normalizer, squeezenet::Model as S use burn::{backend::NdArray, prelude::*, tensor::activation::softmax}; use burn_candle::Candle; -use burn_wgpu::{init_async, AutoGraphicsApi, Wgpu, WgpuDevice}; +use burn_wgpu::{init_async, AutoGraphicsApi, Wgpu, WgpuApi, WgpuDevice}; use serde::Serialize; use wasm_bindgen::prelude::*; @@ -106,7 +106,7 @@ impl ImageClassifier { log::info!("Loading the model to the Wgpu backend"); let start = Instant::now(); let device = WgpuDevice::default(); - init_async::(&device, Default::default()).await; + init_async::(&device, Default::default()).await; self.model = ModelType::WithWgpuBackend(Model::new(&device)); let duration = start.elapsed(); log::debug!("Model is loaded to the Wgpu backend in {:?}", duration);