diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000..f5d883e911 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "crates/burn-wgpu/dawn"] + path = crates/burn-wgpu/dawn + url = https://dawn.googlesource.com/dawn diff --git a/Cargo.lock b/Cargo.lock index b0fe41112a..4b76664d76 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -314,6 +314,29 @@ dependencies = [ "serde", ] +[[package]] +name = "bindgen" +version = "0.69.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a00dc851838a2120612785d195287475a3ac45514741da670b735818822129a0" +dependencies = [ + "bitflags 2.5.0", + "cexpr", + "clang-sys", + "itertools 0.12.1", + "lazy_static", + "lazycell", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "syn 2.0.60", + "which", +] + [[package]] name = "bindgen_cuda" version = "0.1.5" @@ -712,14 +735,17 @@ dependencies = [ name = "burn-wgpu" version = "0.14.0" dependencies = [ + "bindgen", "burn-common", "burn-compute", "burn-fusion", "burn-jit", "burn-tensor", "bytemuck", + "cmake", "derive-new", "futures-intrusive", + "git2", "hashbrown 0.14.5", "log", "pollster", @@ -862,6 +888,15 @@ dependencies = [ "once_cell", ] +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + [[package]] name = "cfg-expr" version = "0.15.8" @@ -908,6 +943,17 @@ dependencies = [ "inout", ] +[[package]] +name = "clang-sys" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67523a3b4be3ce1989d607a828d036249522dd9c1c8de7f4dd2dae43a37369d1" +dependencies = [ + "glob", + "libc", + "libloading 0.8.3", +] + [[package]] name = "clap" version = "3.2.25" @@ -2025,6 +2071,21 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +[[package]] +name = "git2" +version = "0.18.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "232e6a7bfe35766bf715e55a88b39a700596c0ccfd88cd3680b4cdb40d66ef70" +dependencies = [ + "bitflags 2.5.0", + "libc", + "libgit2-sys", + "log", + "openssl-probe", + "openssl-sys", + "url", +] + [[package]] name = "github-device-flow" version = "0.2.0" @@ -2714,6 +2775,12 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + [[package]] name = "lebe" version = "0.5.2" @@ -2737,6 +2804,20 @@ dependencies = [ "once_cell", ] +[[package]] +name = "libgit2-sys" +version = "0.16.2+1.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee4126d8b4ee5c9d9ea891dd875cfdc1e9d0950437179104b183d7d8a74d24e8" +dependencies = [ + "cc", + "libc", + "libssh2-sys", + "libz-sys", + "openssl-sys", + "pkg-config", +] + [[package]] name = "libloading" version = "0.7.4" @@ -2784,6 +2865,32 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "libssh2-sys" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dc8a030b787e2119a731f1951d6a773e2280c660f8ec4b0f5e1505a386e71ee" +dependencies = [ + "cc", + "libc", + "libz-sys", + "openssl-sys", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "libz-sys" +version = "1.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e143b5e666b2695d28f6bca6497720813f699c9602dd7f5cac91008b8ada7f9" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linux-raw-sys" version = "0.4.13" @@ -3616,6 +3723,16 @@ dependencies = [ 
"yansi", ] +[[package]] +name = "prettyplease" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f12335488a2f3b0a83b14edad48dca9879ce89b2edd10e80237e4e852dd645e" +dependencies = [ + "proc-macro2", + "syn 2.0.60", +] + [[package]] name = "proc-macro-error" version = "1.0.4" @@ -4551,6 +4668,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + [[package]] name = "signal-hook" version = "0.3.17" diff --git a/_typos.toml b/_typos.toml index 0d002ff695..a5b9a183bb 100644 --- a/_typos.toml +++ b/_typos.toml @@ -5,4 +5,5 @@ extend-ignore-identifiers-re = ["ratatui", "Ratatui", "NdArray*", "ND"] extend-exclude = [ "assets/ModuleSerialization.xml", "examples/image-classification-web/src/model/label.txt", + "crates/burn-wgpu/dawn", ] diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml index 7dcbf46014..f0b6767d24 100644 --- a/crates/burn-core/Cargo.toml +++ b/crates/burn-core/Cargo.toml @@ -84,7 +84,7 @@ ndarray = ["burn-ndarray"] tch = ["burn-tch"] candle = ["burn-candle"] candle-cuda = ["candle", "burn-candle/cuda"] -wgpu = ["burn-wgpu"] +wgpu = ["burn-wgpu/wgpu"] # Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files. 
record-item-custom-serde = ["thiserror", "regex"] diff --git a/crates/burn-jit/Cargo.toml b/crates/burn-jit/Cargo.toml index d0f727e5d3..0657401b3d 100644 --- a/crates/burn-jit/Cargo.toml +++ b/crates/burn-jit/Cargo.toml @@ -25,6 +25,7 @@ export_tests = [ "burn-ndarray", "fusion", ] +dawn = [] [dependencies] burn-common = { path = "../burn-common", version = "0.14.0" } diff --git a/crates/burn-jit/src/kernel/base.rs b/crates/burn-jit/src/kernel/base.rs index 54c2350016..db2e36d67d 100644 --- a/crates/burn-jit/src/kernel/base.rs +++ b/crates/burn-jit/src/kernel/base.rs @@ -1,8 +1,8 @@ use crate::{compute::WorkGroup, gpu::ComputeShader}; -#[cfg(target_family = "wasm")] +#[cfg(any(target_family = "wasm", feature = "dawn"))] pub(crate) const WORKGROUP_DEFAULT: usize = 16; -#[cfg(not(target_family = "wasm"))] +#[cfg(all(not(target_family = "wasm"), not(feature = "dawn")))] pub(crate) const WORKGROUP_DEFAULT: usize = 32; /// Dynamic jit kernel to create a [compute shader](ComputeShader). diff --git a/crates/burn-wgpu/Cargo.toml b/crates/burn-wgpu/Cargo.toml index 039a112589..69a2458b6a 100644 --- a/crates/burn-wgpu/Cargo.toml +++ b/crates/burn-wgpu/Cargo.toml @@ -11,12 +11,14 @@ repository = "https://github.com/tracel-ai/burn/tree/main/burn-wgpu" version.workspace = true [features] -default = ["fusion", "burn-jit/default"] +default = ["fusion", "burn-jit/default", "wgpu"] fusion = ["burn-fusion", "burn-jit/fusion"] autotune = ["burn-jit/autotune"] template = ["burn-jit/template"] doc = ["burn-jit/doc"] std = ["burn-jit/std"] +dawn = ["burn-jit/dawn", "dep:bindgen", "dep:cmake", "dep:git2"] +wgpu = [] [dependencies] burn-jit = { path = "../burn-jit", version = "0.14.0", default-features = false } @@ -34,6 +36,11 @@ futures-intrusive = { workspace = true } derive-new = { workspace = true } hashbrown = { workspace = true } +[build-dependencies] +bindgen = { version = "0.69.4", optional = true } +cmake = { version = "0.1", optional = true } +git2 = { version = "0.18.2", 
optional = true } + [dev-dependencies] burn-jit = { path = "../burn-jit", version = "0.14.0", default-features = false, features = [ "export_tests", diff --git a/crates/burn-wgpu/README.md b/crates/burn-wgpu/README.md index 4a17176bbc..00c9aa36cd 100644 --- a/crates/burn-wgpu/README.md +++ b/crates/burn-wgpu/README.md @@ -6,7 +6,7 @@ [![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-wgpu/blob/master/README.md) This crate provides a WGPU backend for [Burn](https://github.com/tracel-ai/burn) using the -[wgpu](https://github.com/gfx-rs/wgpu). +[wgpu](https://github.com/gfx-rs/wgpu) or [Dawn](https://dawn.googlesource.com/dawn/). The backend supports Vulkan, Metal, DirectX11/12, OpenGL, WebGPU. @@ -39,3 +39,7 @@ You can set `BURN_WGPU_MAX_TASKS` to a positive integer that determines how many | OpenGL | No | Yes | Yes | Yes | Yes | Yes | Yes | No | | WebGpu | No | Yes | No | No | No | No | No | Yes | | Dx11/Dx12 | No | Yes | No | No | Yes | No | No | No | + +## Building with the `dawn` backend enabled + +This crate can be built using Dawn as the backing WebGPU implementation. To do this, enable the `dawn` feature. Note that Dawn requires `python3` and `ninja` (https://ninja-build.org/) to build and may take a non-negligible time to compile. 
diff --git a/crates/burn-wgpu/build.rs b/crates/burn-wgpu/build.rs new file mode 100644 index 0000000000..de09fe829d --- /dev/null +++ b/crates/burn-wgpu/build.rs @@ -0,0 +1,236 @@ +#[cfg(all(feature = "dawn", not(any(target_os = "macos", target_os = "linux"))))] +compile_error!("The 'dawn' backend currently only builds on macOS and Linux."); + +fn main() { + #[cfg(feature = "dawn")] + link_and_bind_dawn(); +} + +#[cfg(feature = "dawn")] +fn link_and_bind_dawn() { + use bindgen::builder; + use std::env; + use std::path::PathBuf; + + let src_dir = env::current_dir().unwrap(); + let dawn_src_dir = src_dir.join("dawn"); + + let repo = match git2::Repository::open("../..") { + Ok(repo) => repo, + Err(err) => panic!("failed to open repo: {err}"), + }; + let mut submodules = match repo.submodules() { + Ok(submodules) => submodules, + Err(err) => panic!("failed to list git submodules: {err}"), + }; + for submodule in submodules.iter_mut() { + if submodule.name().unwrap().ends_with("dawn") { + match submodule.update(true, None) { + Ok(_) => (), + Err(_) => { + // if the working directory is empty, but the module is present in .git/modules + // update will fail, so try to manually reinit the submodule (more context in + // https://github.com/libgit2/libgit2/issues/3820) + match submodule.init(false) { + Ok(_) => (), + Err(err) => panic!("failed to init the dawn submodule: {err}"), + }; + match submodule.repo_init(false) { + Ok(_) => (), + Err(err) => panic!("failed to init the dawn submodule repo: {err}"), + }; + match submodule.clone(None) { + Ok(_) => (), + Err(err) => panic!("failed to clone the dawn submodule: {err}"), + }; + match submodule.sync() { + Ok(_) => (), + Err(err) => panic!("failed to sync the dawn submodule: {err}"), + } + } + } + } + } + + let _ = std::process::Command::new("python3") + .current_dir(dawn_src_dir.clone()) + .arg("tools/fetch_dawn_dependencies.py") + .arg("--use-test-deps") + .spawn() + .expect("failed to fetch Dawn dependencies") + .wait(); + + let 
dst = cmake::Config::new(dawn_src_dir.clone()) + .profile("Release") + .generator("Ninja") // would be nice to use make, but the Dawn build has a tendency to quietly + // fail when generating code, leading to confusing errors + .build_target("webgpu_dawn") + .build(); + + let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); + let dawn_build_dir = dst.join("build"); + let dawn_build_dir = dawn_build_dir.display(); + let dawn_src_dir = dawn_src_dir.display(); + + let bindings = builder() + .header("dawn.h") + .clang_args([ + "-x", + "c++", + "-I", + std::format!("{dawn_build_dir}/gen/include/").as_str(), + "-I", + std::format!("{dawn_src_dir}/include/").as_str(), + "--std=c++17", + ]) + .allowlist_function(".*GetProcs.*") + .allowlist_function(".*SetProcs.*") + .allowlist_function("wgpu.*") + .allowlist_file(".*webgpu.h") + .layout_tests(false) + .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) + .generate() + .expect("Unable to generate Dawn bindings"); + + bindings + .write_to_file(out_path.join("dawn_native_bindings_gen.rs")) + .expect("Couldn't write Dawn bindings!"); + + println!("cargo:rustc-link-search={dawn_build_dir}/src/dawn/common"); + println!("cargo:rustc-link-lib=static=dawn_common"); + println!("cargo:rustc-link-search={dawn_build_dir}/src/dawn/native"); + println!("cargo:rustc-link-lib=static=dawn_native"); + println!("cargo:rustc-link-lib=static=webgpu_dawn"); + println!("cargo:rustc-link-search={dawn_build_dir}/src/dawn/platform"); + println!("cargo:rustc-link-lib=static=dawn_platform"); + println!("cargo:rustc-link-search={dawn_build_dir}/src/tint"); + println!("cargo:rustc-link-lib=static=tint_api"); + println!("cargo:rustc-link-lib=static=tint_api_common"); + println!("cargo:rustc-link-lib=static=tint_api_options"); + println!("cargo:rustc-link-lib=static=tint_lang_core"); + println!("cargo:rustc-link-lib=static=tint_lang_core_constant"); + println!("cargo:rustc-link-lib=static=tint_lang_core_intrinsic"); + 
println!("cargo:rustc-link-lib=static=tint_lang_core_ir"); + println!("cargo:rustc-link-lib=static=tint_lang_core_ir_transform"); + println!("cargo:rustc-link-lib=static=tint_lang_core_type"); + #[cfg(target_os = "linux")] + { + println!("cargo:rustc-link-lib=static=tint_lang_glsl_writer"); + println!("cargo:rustc-link-lib=static=tint_lang_glsl_writer_ast_printer"); + println!("cargo:rustc-link-lib=static=tint_lang_glsl_writer_ast_raise"); + println!("cargo:rustc-link-lib=static=tint_lang_glsl_writer_common"); + println!("cargo:rustc-link-lib=static=tint_lang_glsl_writer_printer"); + println!("cargo:rustc-link-lib=static=tint_lang_glsl_writer_raise"); + } + println!("cargo:rustc-link-lib=static=tint_lang_hlsl_writer_common"); + #[cfg(target_os = "macos")] + { + println!("cargo:rustc-link-lib=static=tint_lang_msl"); + println!("cargo:rustc-link-lib=static=tint_lang_msl_intrinsic"); + println!("cargo:rustc-link-lib=static=tint_lang_msl_ir"); + println!("cargo:rustc-link-lib=static=tint_lang_msl_writer"); + println!("cargo:rustc-link-lib=static=tint_lang_msl_writer_ast_printer"); + println!("cargo:rustc-link-lib=static=tint_lang_msl_writer_ast_raise"); + println!("cargo:rustc-link-lib=static=tint_lang_msl_writer_common"); + println!("cargo:rustc-link-lib=static=tint_lang_msl_writer_printer"); + println!("cargo:rustc-link-lib=static=tint_lang_msl_writer_raise"); + } + #[cfg(target_os = "linux")] + { + println!("cargo:rustc-link-lib=static=tint_lang_spirv"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_intrinsic"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_ir"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_reader"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_reader_ast_lower"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_reader_ast_parser"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_reader_common"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_reader_lower"); + 
println!("cargo:rustc-link-lib=static=tint_lang_spirv_reader_parser"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_type"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_writer"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_writer_ast_printer"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_writer_ast_raise"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_writer_common"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_writer_printer"); + println!("cargo:rustc-link-lib=static=tint_lang_spirv_writer_raise"); + } + println!("cargo:rustc-link-lib=static=tint_lang_wgsl"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_ast"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_ast_transform"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_common"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_features"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_helpers"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_inspector"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_intrinsic"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_ir"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_program"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_reader"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_reader_lower"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_reader_parser"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_reader_program_to_ir"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_resolver"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_sem"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_writer"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_writer_ast_printer"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_writer_ir_to_program"); + println!("cargo:rustc-link-lib=static=tint_lang_wgsl_writer_raise"); + 
println!("cargo:rustc-link-lib=static=tint_lang_wgsl_writer_syntax_tree_printer"); + println!("cargo:rustc-link-lib=static=tint_utils_containers"); + println!("cargo:rustc-link-lib=static=tint_utils_debug"); + println!("cargo:rustc-link-lib=static=tint_utils_diagnostic"); + println!("cargo:rustc-link-lib=static=tint_utils_generator"); + println!("cargo:rustc-link-lib=static=tint_utils_ice"); + println!("cargo:rustc-link-lib=static=tint_utils_id"); + println!("cargo:rustc-link-lib=static=tint_utils_macros"); + println!("cargo:rustc-link-lib=static=tint_utils_math"); + println!("cargo:rustc-link-lib=static=tint_utils_memory"); + println!("cargo:rustc-link-lib=static=tint_utils_reflection"); + println!("cargo:rustc-link-lib=static=tint_utils_result"); + println!("cargo:rustc-link-lib=static=tint_utils_rtti"); + println!("cargo:rustc-link-lib=static=tint_utils_strconv"); + println!("cargo:rustc-link-lib=static=tint_utils_symbol"); + println!("cargo:rustc-link-lib=static=tint_utils_text"); + println!("cargo:rustc-link-lib=static=tint_utils_traits"); + println!("cargo:rustc-link-search={dawn_build_dir}/third_party/abseil/absl/strings"); + println!("cargo:rustc-link-lib=static=absl_str_format_internal"); + println!("cargo:rustc-link-lib=static=absl_strings"); + println!("cargo:rustc-link-lib=static=absl_strings_internal"); + println!("cargo:rustc-link-search={dawn_build_dir}/third_party/abseil/absl/base"); + println!("cargo:rustc-link-lib=static=absl_base"); + println!("cargo:rustc-link-lib=static=absl_spinlock_wait"); + println!("cargo:rustc-link-lib=static=absl_throw_delegate"); + println!("cargo:rustc-link-lib=static=absl_raw_logging_internal"); + println!("cargo:rustc-link-lib=static=absl_log_severity"); + println!("cargo:rustc-link-search={dawn_build_dir}/third_party/abseil/absl/numeric"); + println!("cargo:rustc-link-lib=static=absl_int128"); + println!("cargo:rustc-link-search={dawn_build_dir}/third_party/abseil/absl/hash"); + 
println!("cargo:rustc-link-lib=static=absl_city"); + println!("cargo:rustc-link-lib=static=absl_hash"); + println!("cargo:rustc-link-lib=static=absl_low_level_hash"); + println!("cargo:rustc-link-search={dawn_build_dir}/third_party/abseil/absl/container"); + println!("cargo:rustc-link-lib=static=absl_hashtablez_sampler"); + println!("cargo:rustc-link-lib=static=absl_raw_hash_set"); + #[cfg(target_os = "linux")] + { + println!("cargo:rustc-link-search={dawn_build_dir}/third_party/spirv-tools/source"); + println!("cargo:rustc-link-lib=static=SPIRV-Tools"); + println!("cargo:rustc-link-search={dawn_build_dir}/third_party/spirv-tools/source/opt"); + println!("cargo:rustc-link-lib=static=SPIRV-Tools-opt"); + + // Has to go at the end of the list, otherwise the linker will complain + // about missing c++ symbols. + println!("cargo:rustc-link-lib=dylib=stdc++"); + } + #[cfg(target_os = "macos")] + { + println!("cargo:rustc-link-lib=framework=CoreFoundation"); + println!("cargo:rustc-link-lib=framework=IOKit"); + println!("cargo:rustc-link-lib=framework=IOSurface"); + println!("cargo:rustc-link-lib=framework=Metal"); + println!("cargo:rustc-link-lib=framework=QuartzCore"); + println!("cargo:rustc-link-lib=framework=Cocoa"); + // Has to go at the end of the list, otherwise the linker will complain + // about missing c++ symbols. 
+ println!("cargo:rustc-link-lib=dylib=c++"); + } +} diff --git a/crates/burn-wgpu/dawn b/crates/burn-wgpu/dawn new file mode 160000 index 0000000000..fb97d04c0c --- /dev/null +++ b/crates/burn-wgpu/dawn @@ -0,0 +1 @@ +Subproject commit fb97d04c0c2e2307dc11a9d9de4eab607af111f9 diff --git a/crates/burn-wgpu/dawn.h b/crates/burn-wgpu/dawn.h new file mode 100644 index 0000000000..4a29d37346 --- /dev/null +++ b/crates/burn-wgpu/dawn.h @@ -0,0 +1 @@ +#include "dawn/webgpu.h" diff --git a/crates/burn-wgpu/src/compiler/wgsl/base.rs b/crates/burn-wgpu/src/compiler/wgsl/base.rs index a268b4d166..dddc0bc414 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/base.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/base.rs @@ -218,7 +218,9 @@ impl Display for Variable { } Variable::ConstantScalar(number, elem) => match elem { Elem::F32 => f.write_fmt(format_args!("{number}f")), - Elem::I32 => f.write_fmt(format_args!("{number}i")), + // Dawn seems to get tripped up by the 'i' suffix, while wgpu is happy + // with or without it, so emit the literal without it. 
+ Elem::I32 => f.write_fmt(format_args!("{number}")), Elem::U32 => f.write_fmt(format_args!("{number}u")), Elem::Bool => f.write_fmt(format_args!("bool({number})")), }, diff --git a/crates/burn-wgpu/src/compute/dawn_api_shim.rs b/crates/burn-wgpu/src/compute/dawn_api_shim.rs new file mode 100644 index 0000000000..31c8bfa47d --- /dev/null +++ b/crates/burn-wgpu/src/compute/dawn_api_shim.rs @@ -0,0 +1,820 @@ +#![allow(missing_docs)] + +use crate::compute::{ + dawn_native_bindings::*, webgpu_api::*, WgpuServer, WgpuStorage, +}; +use crate::{create_client, GraphicsApi, RuntimeOptions, WgpuDevice}; +use alloc::sync::Arc; +use burn_compute::{ + channel::MutexComputeChannel, client::ComputeClient, memory_management::SimpleMemoryManagement, + ComputeRuntime, +}; +use burn_jit::compute::WorkGroup; +use std::num::NonZeroU64; + +#[derive(Debug)] +pub struct DawnApi {} + +#[derive(Debug)] +pub struct DawnAdapter { + adapter: WGPUAdapter, +} + +impl Adapter for DawnAdapter { + fn get_info(&self) -> DawnAdapterInfo { + let mut adapter_info = WGPUAdapterProperties { + nextInChain: std::ptr::null_mut::(), + vendorID: 0, + vendorName: std::ptr::null(), + architecture: std::ptr::null(), + deviceID: 0, + name: std::ptr::null(), + driverDescription: std::ptr::null(), + adapterType: 0, + backendType: 0, + compatibilityMode: 0, + }; + unsafe { + wgpuAdapterGetProperties(self.adapter, &mut adapter_info); + } + DawnAdapterInfo { adapter_info } + } +} + +#[derive(Debug)] +pub struct DawnAdapterInfo { + adapter_info: WGPUAdapterProperties, +} + +impl AdapterInfo for DawnAdapterInfo { + fn backend(&self) -> DawnBackend { + DawnBackend::from_u32(self.adapter_info.backendType) + } + + fn device(&self) -> DeviceId { + self.adapter_info.deviceID + } +} + +#[derive(Debug)] +pub enum DawnBackend { + Undefined = WGPUBackendType_WGPUBackendType_Undefined as isize, + Null = WGPUBackendType_WGPUBackendType_Null as isize, + WebGPU = WGPUBackendType_WGPUBackendType_WebGPU as isize, + D3D11 = 
WGPUBackendType_WGPUBackendType_D3D11 as isize, + D3D12 = WGPUBackendType_WGPUBackendType_D3D12 as isize, + Metal = WGPUBackendType_WGPUBackendType_Metal as isize, + Vulkan = WGPUBackendType_WGPUBackendType_Vulkan as isize, + OpenGL = WGPUBackendType_WGPUBackendType_OpenGL as isize, + OpenGLES = WGPUBackendType_WGPUBackendType_OpenGLES as isize, +} + +impl DawnBackend { + #[allow(non_upper_case_globals)] + fn from_u32(val: u32) -> DawnBackend { + match val { + WGPUBackendType_WGPUBackendType_Undefined => DawnBackend::Undefined, + WGPUBackendType_WGPUBackendType_Null => DawnBackend::Null, + WGPUBackendType_WGPUBackendType_WebGPU => DawnBackend::WebGPU, + WGPUBackendType_WGPUBackendType_D3D11 => DawnBackend::D3D11, + WGPUBackendType_WGPUBackendType_D3D12 => DawnBackend::D3D12, + WGPUBackendType_WGPUBackendType_Metal => DawnBackend::Metal, + WGPUBackendType_WGPUBackendType_Vulkan => DawnBackend::Vulkan, + WGPUBackendType_WGPUBackendType_OpenGL => DawnBackend::OpenGL, + WGPUBackendType_WGPUBackendType_OpenGLES => DawnBackend::OpenGLES, + _ => panic!("Unknown Dawn backend type: {}", val), + } + } +} + +impl core::convert::AsRef for DawnBackend { + fn as_ref(&self) -> &'static str { + match self { + DawnBackend::Undefined => "undefined", + DawnBackend::Null => "null", + DawnBackend::WebGPU => "webgpu", + DawnBackend::D3D11 => "dx11", + DawnBackend::D3D12 => "dx12", + DawnBackend::Metal => "metal", + DawnBackend::Vulkan => "vulkan", + DawnBackend::OpenGL => "opengl", + DawnBackend::OpenGLES => "opengles", + } + } +} + +#[derive(Debug)] +pub struct DawnBindGroup { + bind_group: WGPUBindGroup, +} +unsafe impl Send for DawnBindGroup {} +impl BindGroup for DawnBindGroup {} + +#[derive(Debug)] +pub struct DawnBindGroupLayout { + layout: WGPUBindGroupLayout, +} +impl BindGroupLayout for DawnBindGroupLayout {} + +#[derive(Debug)] +pub struct DawnBuffer { + buffer: WGPUBuffer, + size: u64, +} +unsafe impl Send for DawnBuffer {} +unsafe impl Sync for DawnBuffer {} + +impl Buffer 
for DawnBuffer { + fn as_entire_buffer_binding(&self) -> BufferBinding<'_, DawnBuffer> { + BufferBinding { + buffer: self, + offset: 0, + size: Some(NonZeroU64::new((*self).size).unwrap()), + } + } + + fn destroy(&self) { + unsafe { + wgpuBufferDestroy((*self).buffer.into()); + } + } + + async fn read(&self, device: &DawnDevice) -> Vec { + let mut read_data = BufferReadData { + read_done: std::sync::Mutex::new(false), + cv: std::sync::Condvar::new(), + }; + unsafe { + let data_ptr = std::mem::transmute::<*mut BufferReadData, *mut std::os::raw::c_void>( + std::ptr::addr_of_mut!(read_data), + ); + let mut sz = (*self).size; + if sz % 4 != 0 { + sz += 4 - sz % 4; + } + wgpuBufferMapAsync( + (*self).buffer.into(), + WGPUMapMode_WGPUMapMode_Read, + 0, + sz as usize, + Some(buffer_reader_cb), + data_ptr, + ); + + let mut read_done = read_data.read_done.lock().unwrap(); + let should_process = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(true)); + let spt = should_process.clone(); + let instance = DawnInstance { + instance: wgpuAdapterGetInstance(wgpuDeviceGetAdapter((*device).device)), + }; + let handle = std::thread::spawn(move || { + let inst = instance; + while spt.load(std::sync::atomic::Ordering::Relaxed) { + wgpuInstanceProcessEvents(inst.instance); + std::thread::sleep(std::time::Duration::from_micros(10)); + } + }); + while !*read_done { + let res = read_data + .cv + .wait_timeout(read_done, std::time::Duration::from_micros(100)) + .unwrap(); + read_done = res.0; + } + should_process.store(false, std::sync::atomic::Ordering::Relaxed); + handle.join().unwrap(); + + let mpd_rng = + wgpuBufferGetConstMappedRange((*self).buffer.into(), 0, (*self).size as usize); + let slice = std::slice::from_raw_parts(mpd_rng as *const u8, (*self).size as usize); + slice.to_vec() + } + } + + fn size(&self) -> u64 { + (*self).size + } +} + +pub type DawnBufferUsages = u32; +pub const MAP_READ: DawnBufferUsages = WGPUBufferUsage_WGPUBufferUsage_MapRead; +pub const COPY_SRC: 
DawnBufferUsages = WGPUBufferUsage_WGPUBufferUsage_CopySrc; +pub const COPY_DST: DawnBufferUsages = WGPUBufferUsage_WGPUBufferUsage_CopyDst; +pub const STORAGE: DawnBufferUsages = WGPUBufferUsage_WGPUBufferUsage_Storage; + +#[derive(Debug)] +pub struct DawnCommandBuffer { + buffer: WGPUCommandBuffer, +} +impl CommandBuffer for DawnCommandBuffer {} + +#[derive(Debug)] +pub struct DawnCommandEncoder { + encoder: WGPUCommandEncoder, +} +unsafe impl Send for DawnCommandEncoder {} +unsafe impl Sync for DawnCommandEncoder {} + +impl CommandEncoder + for DawnCommandEncoder +{ + fn dispatch_compute_pass( + &mut self, + desc: &ComputePassDescriptor, + pipeline: Arc, + bind_group: DawnBindGroup, + work_group: WorkGroup, + ) { + let label = match desc.label { + Some(name) => name, + None => "", + }; + let pass_desc = WGPUComputePassDescriptor { + nextInChain: std::ptr::null(), + label: std::ffi::CString::new(label).unwrap().into_raw(), + timestampWrites: std::ptr::null(), + }; + let pass: WGPUComputePassEncoder; + unsafe { + pass = wgpuCommandEncoderBeginComputePass(self.encoder.into(), &pass_desc); + } + unsafe { + wgpuComputePassEncoderSetPipeline(pass, pipeline.pipeline.into()); + wgpuComputePassEncoderSetBindGroup( + pass, + 0, + bind_group.bind_group.into(), + 0, + (&[]).as_ptr(), + ); + wgpuComputePassEncoderDispatchWorkgroups( + pass, + work_group.x, + work_group.y, + work_group.z, + ); + } + unsafe { + wgpuComputePassEncoderEnd(pass); + } + } + + fn copy_buffer_to_buffer( + &mut self, + src: &DawnBuffer, + src_offset: u64, + dest: &DawnBuffer, + dest_offset: u64, + size: u64, + ) { + unsafe { + wgpuCommandEncoderCopyBufferToBuffer( + (*self).encoder.into(), + (*src).buffer.into(), + src_offset, + (*dest).buffer.into(), + dest_offset, + size, + ); + } + } + + fn finish(self) -> DawnCommandBuffer { + let cmd_buf_desc = WGPUCommandBufferDescriptor { + nextInChain: std::ptr::null(), + label: std::ptr::null(), + }; + let cmd_buf: WGPUCommandBuffer; + unsafe { + cmd_buf = 
wgpuCommandEncoderFinish(self.encoder.into(), &cmd_buf_desc); + } + DawnCommandBuffer { buffer: cmd_buf } + } +} + +#[derive(Debug)] +pub struct DawnComputePipeline { + pipeline: WGPUComputePipeline, +} +unsafe impl Send for DawnComputePipeline {} +unsafe impl Sync for DawnComputePipeline {} + +impl ComputePipeline for DawnComputePipeline { + fn get_bind_group_layout(&self, id: u32) -> DawnBindGroupLayout { + let layout: WGPUBindGroupLayout; + unsafe { + layout = wgpuComputePipelineGetBindGroupLayout((*self).pipeline.into(), id); + } + DawnBindGroupLayout { layout: layout } + } +} + +#[derive(Debug)] +pub struct DawnDevice { + device: WGPUDevice, +} +unsafe impl Send for DawnDevice {} +unsafe impl Sync for DawnDevice {} + +impl + Device< + DawnBindGroup, + DawnBindGroupLayout, + DawnBuffer, + DawnCommandEncoder, + DawnComputePipeline, + DawnPipelineLayout, + DawnShaderModule, + > for DawnDevice +{ + fn create_bind_group( + &self, + desc: &BindGroupDescriptor<'_, DawnBindGroupLayout, DawnBuffer>, + ) -> DawnBindGroup { + let entries = (*desc) + .entries + .iter() + .map(|entry| { + let resource = match &entry.resource { + BindingResource::Buffer(res) => res, + }; + WGPUBindGroupEntry { + nextInChain: std::ptr::null(), + binding: entry.binding, + buffer: resource.buffer.buffer.into(), + offset: resource.offset, + size: resource.size.unwrap().get(), + sampler: std::ptr::null_mut(), + textureView: std::ptr::null_mut(), + } + }) + .collect::>(); + let label = match desc.label { + None => std::ptr::null(), + Some(name) => std::ffi::CString::new(name).unwrap().into_raw(), + }; + let bg_desc = WGPUBindGroupDescriptor { + nextInChain: std::ptr::null(), + label: label, + layout: (*desc).layout.layout, + entryCount: entries.len(), + entries: entries.as_ptr(), + }; + let bind_group: WGPUBindGroup; + unsafe { + bind_group = wgpuDeviceCreateBindGroup((*self).device.into(), &bg_desc); + } + DawnBindGroup { + bind_group: bind_group, + } + } + + fn create_buffer(&self, desc: 
&BufferDescriptor) -> DawnBuffer { + let label = match desc.label { + None => std::ptr::null(), + Some(name) => std::ffi::CString::new(name).unwrap().into_raw(), + }; + let buf_desc = WGPUBufferDescriptor { + nextInChain: std::ptr::null(), + label: label, + usage: (*desc).usage, + size: (*desc).size, + mappedAtCreation: (*desc).mapped_at_creation as u32, + }; + let buffer: WGPUBuffer; + unsafe { + buffer = wgpuDeviceCreateBuffer((*self).device.into(), &buf_desc); + } + DawnBuffer { + buffer: buffer, + size: (*desc).size, + } + } + + fn create_buffer_init(&self, desc: &BufferInitDescriptor) -> DawnBuffer { + let label = match desc.label { + None => std::ptr::null(), + Some(name) => std::ffi::CString::new(name).unwrap().into_raw(), + }; + let buf_desc = WGPUBufferDescriptor { + nextInChain: std::ptr::null(), + label: label, + usage: (*desc).usage, + size: (*desc).contents.len() as u64, + mappedAtCreation: 1, + }; + let buffer: WGPUBuffer; + unsafe { + buffer = wgpuDeviceCreateBuffer((*self).device.into(), &buf_desc); + let data = wgpuBufferGetMappedRange(buffer, 0, (*desc).contents.len()); + let src_ptr = &(*desc).contents[0] as *const u8; + std::ptr::copy_nonoverlapping(src_ptr, data as *mut u8, (*desc).contents.len()); + wgpuBufferUnmap(buffer); + } + DawnBuffer { + buffer: buffer, + size: (*desc).contents.len() as u64, + } + } + + fn create_command_encoder(&self, desc: &CommandEncoderDescriptor) -> DawnCommandEncoder { + let label = match desc.label { + None => std::ptr::null(), + Some(name) => std::ffi::CString::new(name).unwrap().into_raw(), + }; + let encoder_desc = WGPUCommandEncoderDescriptor { + nextInChain: std::ptr::null(), + label: label, + }; + let encoder: WGPUCommandEncoder; + unsafe { + encoder = wgpuDeviceCreateCommandEncoder((*self).device.into(), &encoder_desc); + } + DawnCommandEncoder { encoder: encoder } + } + + fn create_compute_pipeline( + &self, + desc: &ComputePipelineDescriptor, + ) -> DawnComputePipeline { + let label = match desc.label { 
+ None => std::ptr::null(), + Some(name) => std::ffi::CString::new(name).unwrap().into_raw(), + }; + let layout = match desc.layout { + None => std::ptr::null_mut(), + Some(layout) => layout.layout, + }; + let pip_desc = WGPUComputePipelineDescriptor { + nextInChain: std::ptr::null(), + label: label, + layout: layout, + compute: WGPUProgrammableStageDescriptor { + nextInChain: std::ptr::null(), + module: (*(*desc).module).module, + entryPoint: std::ffi::CString::new((*desc).entry_point) + .unwrap() + .into_raw(), + constantCount: 0, + constants: std::ptr::null(), + }, + }; + let pipeline: WGPUComputePipeline; + unsafe { + pipeline = wgpuDeviceCreateComputePipeline((*self).device.into(), &pip_desc); + } + DawnComputePipeline { pipeline: pipeline } + } + + fn create_shader_module(&self, desc: &ShaderModuleDescriptor) -> DawnShaderModule { + let label = match desc.label { + None => std::ptr::null(), + Some(name) => std::ffi::CString::new(name).unwrap().into_raw(), + }; + let src = match &desc.source { + ShaderSource::Wgsl(source) => source.to_string(), + }; + let wgsl_desc = WGPUShaderModuleWGSLDescriptor { + chain: WGPUChainedStruct { + next: std::ptr::null(), + sType: WGPUSType_WGPUSType_ShaderModuleWGSLDescriptor, + }, + code: std::ffi::CString::new(src).unwrap().into_raw(), + }; + let module: WGPUShaderModule; + unsafe { + let sh_desc = WGPUShaderModuleDescriptor { + nextInChain: std::mem::transmute::< + *const WGPUShaderModuleWGSLDescriptor, + *const WGPUChainedStruct, + >(&wgsl_desc), + label: label, + }; + module = wgpuDeviceCreateShaderModule((*self).device.into(), &sh_desc); + } + DawnShaderModule { module: module } + } +} + +#[derive(Debug)] +pub struct DawnInstance { + instance: WGPUInstance, +} +unsafe impl Send for DawnInstance {} + +#[derive(Debug)] +pub struct DawnPipelineLayout { + layout: WGPUPipelineLayout, +} +impl PipelineLayout for DawnPipelineLayout {} + +#[derive(Debug)] +pub struct DawnQueue { + queue: WGPUQueue, +} +unsafe impl Send for 
DawnQueue {} + +impl Queue for DawnQueue { + fn submit(&self, buf: Option) { + match buf { + None => (), + Some(buf) => unsafe { + wgpuQueueSubmit((*self).queue.into(), 1, std::ptr::addr_of!(buf.buffer)); + }, + }; + } + + fn write_buffer(&self, buffer: &DawnBuffer, offset: u64, data: &[u8]) { + unsafe { + let data_ptr = + std::mem::transmute::<*const u8, *const std::os::raw::c_void>(data.as_ptr()); + let mut sz = data.len(); + if sz % 4 != 0 { + sz += 2; + } + wgpuQueueWriteBuffer( + (*self).queue.into(), + (*buffer).buffer.into(), + offset, + data_ptr, + sz, + ); + } + } +} + +#[derive(Debug)] +pub struct DawnShaderModule { + module: WGPUShaderModule, +} +impl ShaderModule for DawnShaderModule {} + +/// The compute instance is shared across all [dawn runtimes](WgpuRuntime). +static RUNTIME: ComputeRuntime> = + ComputeRuntime::new(); + +type Server = WgpuServer>>; + +impl WebGPUApi for DawnApi { + type Adapter = DawnAdapter; + type AdapterInfo = DawnAdapterInfo; + type Backend = DawnBackend; + type BindGroup = DawnBindGroup; + type BindGroupLayout = DawnBindGroupLayout; + type Buffer = DawnBuffer; + type CommandBuffer = DawnCommandBuffer; + type CommandEncoder = DawnCommandEncoder; + type ComputePipeline = DawnComputePipeline; + type Device = DawnDevice; + type PipelineLayout = DawnPipelineLayout; + type Queue = DawnQueue; + type ShaderModule = DawnShaderModule; + + const MAP_READ: u32 = MAP_READ; + const COPY_SRC: u32 = COPY_SRC; + const COPY_DST: u32 = COPY_DST; + const STORAGE: u32 = STORAGE; + + type Server = WgpuServer>>; + type Channel = MutexComputeChannel>>>; + + fn client(device: &WgpuDevice) -> ComputeClient { + RUNTIME.client(device, move || { + pollster::block_on(create_client::( + device, + RuntimeOptions::default(), + )) + }) + } + + async fn select_device(adapter: &DawnAdapter) -> (DawnDevice, DawnQueue) { + let mut req_data = DevRequestData { + device: std::ptr::null::() as WGPUDevice, + is_set: std::sync::Mutex::new(false), + cv: 
std::sync::Condvar::new(), + }; + let desc = WGPUDeviceDescriptor { + nextInChain: std::ptr::null(), + label: std::ptr::null(), + requiredFeatureCount: 1, + requiredFeatures: &WGPUFeatureName_WGPUFeatureName_ShaderF16, + requiredLimits: std::ptr::null(), + defaultQueue: WGPUQueueDescriptor { + nextInChain: std::ptr::null(), + label: std::ptr::null(), + }, + deviceLostCallback: None, + deviceLostCallbackInfo: WGPUDeviceLostCallbackInfo { + nextInChain: std::ptr::null(), + mode: 0, + callback: None, + userdata: std::ptr::null_mut(), + }, + deviceLostUserdata: std::ptr::null_mut(), + uncapturedErrorCallbackInfo: WGPUUncapturedErrorCallbackInfo { + nextInChain: std::ptr::null(), + callback: None, + userdata: std::ptr::null_mut(), + }, + }; + + unsafe { + let data_ptr = std::mem::transmute::<*mut DevRequestData, *mut std::os::raw::c_void>( + std::ptr::addr_of_mut!(req_data), + ); + wgpuAdapterRequestDevice( + (*adapter).adapter.into(), + &desc, + Some(request_device_cb), + data_ptr, + ); + } + + let mut is_set = req_data.is_set.lock().unwrap(); + while !*is_set { + is_set = req_data.cv.wait(is_set).unwrap(); + } + + unsafe { + wgpuDeviceSetUncapturedErrorCallback( + req_data.device, + Some(device_error_callback), + std::ptr::null_mut(), + ); + wgpuDeviceSetLoggingCallback( + req_data.device, + Some(device_logging_callback), + std::ptr::null_mut(), + ); + } + + let dev = DawnDevice { + device: req_data.device, + }; + let queue: WGPUQueue; + unsafe { + queue = wgpuDeviceGetQueue(dev.device.into()); + } + (dev, DawnQueue { queue }) + } + + fn select_adapter(_: &WgpuDevice) -> DawnAdapter { + let instance: WGPUInstance; + let instance_desc = WGPUInstanceDescriptor { + nextInChain: std::ptr::null(), + features: WGPUInstanceFeatures { + nextInChain: std::ptr::null(), + timedWaitAnyEnable: 0, + timedWaitAnyMaxCount: 0, + }, + }; + unsafe { + instance = wgpuCreateInstance(&instance_desc); + } + let mut req_data = AdapterRequestData { + adapter: std::ptr::null::() as 
WGPUAdapter, + is_set: std::sync::Mutex::new(false), + cv: std::sync::Condvar::new(), + }; + unsafe { + let data_ptr = std::mem::transmute::<*mut AdapterRequestData, *mut std::os::raw::c_void>( + std::ptr::addr_of_mut!(req_data), + ); + wgpuInstanceRequestAdapter( + instance, + std::ptr::null(), + Some(request_adapter_cb), + data_ptr, + ); + } + + let mut is_set = req_data.is_set.lock().unwrap(); + while !*is_set { + is_set = req_data.cv.wait(is_set).unwrap(); + } + + DawnAdapter { + adapter: req_data.adapter, + } + } + + fn device_poll(device: &DawnDevice) { + let instance: WGPUInstance; + let dev = (*device).device; + unsafe { + instance = wgpuAdapterGetInstance(wgpuDeviceGetAdapter(dev.into())); + wgpuInstanceProcessEvents(instance.into()); + wgpuDeviceTick(dev.into()); + } + } + + fn init_sync(device: &WgpuDevice, options: RuntimeOptions) { + let device = Arc::new(device); + let client = pollster::block_on(create_client::(&device, options)); + + RUNTIME.register(&device, client) + } + + async fn init_async(device: &WgpuDevice, options: RuntimeOptions) { + let device = Arc::new(device); + let client = create_client::(&device, options).await; + + RUNTIME.register(&device, client) + } +} + +#[allow(non_upper_case_globals)] +extern "C" fn device_error_callback( + type_: WGPUErrorType, + message: *const ::std::os::raw::c_char, + _userdata: *mut ::std::os::raw::c_void, +) { + let type_str = match type_ { + WGPUErrorType_WGPUErrorType_Validation => "Validation", + WGPUErrorType_WGPUErrorType_OutOfMemory => "Out of memory", + WGPUErrorType_WGPUErrorType_Internal => "Internal", + WGPUErrorType_WGPUErrorType_Unknown => "Unknown", + WGPUErrorType_WGPUErrorType_DeviceLost => "Device lost", + _ => "", + }; + unsafe { + let msg_str = std::ffi::CStr::from_ptr(message).to_str().unwrap(); + println!("{} error: {}", type_str, msg_str); + } +} + +extern "C" fn device_logging_callback( + _type_: WGPULoggingType, + message: *const ::std::os::raw::c_char, + _userdata: *mut 
::std::os::raw::c_void, +) { + unsafe { + let msg_str = std::ffi::CStr::from_ptr(message).to_str().unwrap(); + println!("Device log: {}", msg_str); + } +} + +extern "C" fn request_device_cb( + _status: WGPURequestDeviceStatus, + device: WGPUDevice, + _message: *const ::std::os::raw::c_char, + userdata: *mut ::std::os::raw::c_void, +) { + unsafe { + let req_data = + std::mem::transmute::<*mut std::os::raw::c_void, *mut DevRequestData>(userdata); + (*req_data).device = device; + let mut is_set = (*req_data).is_set.lock().unwrap(); + *is_set = true; + (*req_data).cv.notify_one(); + } +} + +#[repr(C)] +struct DevRequestData { + device: WGPUDevice, + is_set: std::sync::Mutex, + cv: std::sync::Condvar, +} + +extern "C" fn request_adapter_cb( + _status: WGPURequestAdapterStatus, + adapter: WGPUAdapter, + _message: *const ::std::os::raw::c_char, + userdata: *mut ::std::os::raw::c_void, +) { + unsafe { + let req_data = + std::mem::transmute::<*mut std::os::raw::c_void, *mut AdapterRequestData>(userdata); + (*req_data).adapter = adapter; + let mut is_set = (*req_data).is_set.lock().unwrap(); + *is_set = true; + (*req_data).cv.notify_one(); + } +} + +#[repr(C)] +struct AdapterRequestData { + adapter: WGPUAdapter, + is_set: std::sync::Mutex, + cv: std::sync::Condvar, +} + +#[repr(C)] +struct BufferReadData { + read_done: std::sync::Mutex, + cv: std::sync::Condvar, +} + +unsafe extern "C" fn buffer_reader_cb( + _status: WGPUBufferMapAsyncStatus, + userdata: *mut ::std::os::raw::c_void, +) { + unsafe { + let read_data = + std::mem::transmute::<*mut std::os::raw::c_void, *mut BufferReadData>(userdata); + let mut read_done = (*read_data).read_done.lock().unwrap(); + (*read_done) = true; + (*read_data).cv.notify_one(); + } +} diff --git a/crates/burn-wgpu/src/compute/dawn_native_bindings.rs b/crates/burn-wgpu/src/compute/dawn_native_bindings.rs new file mode 100644 index 0000000000..5f3a03ae25 --- /dev/null +++ b/crates/burn-wgpu/src/compute/dawn_native_bindings.rs @@ -0,0 +1,6 @@ 
+#![allow(non_upper_case_globals)] +#![allow(non_camel_case_types)] +#![allow(non_snake_case)] +#![allow(dead_code)] + +include!(concat!(env!("OUT_DIR"), "/dawn_native_bindings_gen.rs")); diff --git a/crates/burn-wgpu/src/compute/mod.rs b/crates/burn-wgpu/src/compute/mod.rs index 4139c3868f..9382e3d30b 100644 --- a/crates/burn-wgpu/src/compute/mod.rs +++ b/crates/burn-wgpu/src/compute/mod.rs @@ -1,5 +1,17 @@ +#[cfg(feature = "dawn")] +mod dawn_api_shim; +#[cfg(feature = "dawn")] +mod dawn_native_bindings; mod server; mod storage; +mod webgpu_api; +#[cfg(feature = "wgpu")] +pub mod wgpu_api_shim; +#[cfg(feature = "dawn")] +pub use dawn_api_shim::*; pub use server::*; pub use storage::*; +pub use webgpu_api::*; +#[cfg(feature = "wgpu")] +pub use wgpu_api_shim::*; diff --git a/crates/burn-wgpu/src/compute/server.rs b/crates/burn-wgpu/src/compute/server.rs index 4d95d87ad6..fef053f47d 100644 --- a/crates/burn-wgpu/src/compute/server.rs +++ b/crates/burn-wgpu/src/compute/server.rs @@ -1,4 +1,5 @@ use super::WgpuStorage; +use crate::compute::webgpu_api::*; use alloc::{borrow::Cow, sync::Arc}; use burn_compute::{ memory_management::MemoryManagement, @@ -7,35 +8,32 @@ use burn_compute::{ use burn_jit::compute::{JitAutotuneKey, JitKernel, Kernel, WorkGroup}; use burn_tensor::Reader; use hashbrown::HashMap; -use wgpu::{ - util::{BufferInitDescriptor, DeviceExt}, - BindGroup, CommandEncoder, ComputePipeline, ShaderModuleDescriptor, -}; /// Wgpu compute server. #[derive(Debug)] -pub struct WgpuServer> { +pub struct WgpuServer>> { memory_management: MM, - device: Arc, - queue: wgpu::Queue, - encoder: CommandEncoder, - pipelines: HashMap>, + device: Arc, + queue: W::Queue, + encoder: W::CommandEncoder, + pipelines: HashMap>, tasks_max: usize, tasks_count: usize, } -impl WgpuServer +impl WgpuServer where - MM: MemoryManagement, + W: WebGPUApi, + MM: MemoryManagement>, { /// Create a new server. 
pub fn new( memory_management: MM, - device: Arc, - queue: wgpu::Queue, + device: Arc, + queue: W::Queue, tasks_max: usize, ) -> Self { - let encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { + let encoder = device.create_command_encoder(&CommandEncoderDescriptor { label: Some("Command Encoder"), }); @@ -53,7 +51,7 @@ where fn submit(&mut self) { let mut new_encoder = self .device - .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + .create_command_encoder(&CommandEncoderDescriptor { label: None }); core::mem::swap(&mut new_encoder, &mut self.encoder); self.queue.submit(Some(new_encoder.finish())); @@ -65,25 +63,24 @@ where fn register_compute( &mut self, - pipeline: Arc, - bind_group: BindGroup, + pipeline: Arc, + bind_group: W::BindGroup, work_group: WorkGroup, ) { - let mut compute = self + self .encoder - .begin_compute_pass(&wgpu::ComputePassDescriptor { - label: None, - timestamp_writes: None, - }); - - compute.set_pipeline(&pipeline); - compute.set_bind_group(0, &bind_group, &[]); - compute.dispatch_workgroups(work_group.x, work_group.y, work_group.z); + .dispatch_compute_pass(&ComputePassDescriptor { + label: None, + }, + pipeline, + bind_group, + work_group, + ); self.tasks_count += 1; } - fn pipeline(&mut self, kernel: Kernel) -> Arc { + fn pipeline(&mut self, kernel: Kernel) -> Arc { let kernel_id = kernel.id(); if let Some(pipeline) = self.pipelines.get(&kernel_id) { return pipeline.clone(); @@ -96,15 +93,15 @@ where pipeline } - fn compile_source(&self, source: &str) -> Arc { - let module = self.device.create_shader_module(ShaderModuleDescriptor { + fn compile_source(&self, source: &str) -> Arc { + let module = self.device.create_shader_module(&ShaderModuleDescriptor { label: None, - source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), + source: ShaderSource::Wgsl(Cow::Borrowed(source)), }); Arc::new( self.device - .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + 
.create_compute_pipeline(&ComputePipelineDescriptor { label: None, layout: None, module: &module, @@ -113,14 +110,14 @@ where ) } - fn buffer_reader(&mut self, handle: server::Binding) -> BufferReader { + fn buffer_reader(&mut self, handle: server::Binding) -> BufferReader { let resource = self.memory_management.get(handle.memory); let size = resource.size(); - let buffer_dest = self.device.create_buffer(&wgpu::BufferDescriptor { + let buffer_dest = self.device.create_buffer(&BufferDescriptor { label: None, size, - usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + usage: W::MAP_READ | W::COPY_DST, mapped_at_creation: false, }); @@ -140,53 +137,32 @@ where } #[derive(new)] -struct BufferReader { - buffer: wgpu::Buffer, +struct BufferReader { + buffer: W::Buffer, } -impl BufferReader { +impl BufferReader +where + W: WebGPUApi, +{ #[cfg(target_family = "wasm")] - async fn read(self, device: alloc::sync::Arc) -> Vec { - self.read_async(&device).await + async fn read(self, device: alloc::sync::Arc) -> Vec { + self.buffer.read(&device).await } #[cfg(not(target_family = "wasm"))] - fn read(self, device: &wgpu::Device) -> Vec { - pollster::block_on(self.read_async(device)) - } - - async fn read_async(&self, device: &wgpu::Device) -> Vec { - let buffer_slice = self.buffer.slice(..); - let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel(); - buffer_slice.map_async(wgpu::MapMode::Read, move |v| { - sender - .send(v) - .expect("Unable to send buffer slice result to async channel.") - }); - - device.poll(wgpu::Maintain::Wait); - - let result = receiver.receive().await; - - if let Some(Ok(())) = result { - let data = buffer_slice.get_mapped_range(); - let result = bytemuck::cast_slice(&data).to_vec(); - - drop(data); - self.buffer.unmap(); - result - } else { - panic!("Unable to read buffer {:?}", result) - } + fn read(self, device: &W::Device) -> Vec { + pollster::block_on(self.buffer.read(device)) } } -impl ComputeServer for 
WgpuServer +impl ComputeServer for WgpuServer where - MM: MemoryManagement, + W: WebGPUApi, + MM: MemoryManagement>, { type Kernel = Kernel; - type Storage = WgpuStorage; + type Storage = WgpuStorage; type MemoryManagement = MM; type AutotuneKey = JitAutotuneKey; @@ -213,7 +189,7 @@ where let buffer_src = Arc::new(self.device.create_buffer_init(&BufferInitDescriptor { label: Some("Buffer Src"), contents: data, - usage: wgpu::BufferUsages::COPY_SRC, + usage: W::COPY_SRC, })); let resource = self.memory_management.get(binding.memory); @@ -247,13 +223,13 @@ where let entries = memory_handles .iter() .enumerate() - .map(|(i, buffer)| wgpu::BindGroupEntry { + .map(|(i, buffer)| BindGroupEntry { binding: i as u32, resource: buffer.as_binding(), }) .collect::>(); - let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor { + let bind_group = self.device.create_bind_group(&BindGroupDescriptor { label: None, layout: &group_layout, entries: &entries, @@ -268,6 +244,6 @@ where fn sync(&mut self) { self.submit(); - self.device.poll(wgpu::Maintain::Wait); + W::device_poll(&self.device); } } diff --git a/crates/burn-wgpu/src/compute/storage.rs b/crates/burn-wgpu/src/compute/storage.rs index 12988b1352..9fa35c5ce4 100644 --- a/crates/burn-wgpu/src/compute/storage.rs +++ b/crates/burn-wgpu/src/compute/storage.rs @@ -1,15 +1,19 @@ +use crate::compute::{BindingResource, Buffer, BufferBinding, BufferDescriptor, Device, WebGPUApi}; use burn_compute::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization}; use hashbrown::HashMap; use std::{num::NonZeroU64, sync::Arc}; /// Buffer storage for wgpu. 
-pub struct WgpuStorage { - memory: HashMap>, +pub struct WgpuStorage { + memory: HashMap>, deallocations: Vec, - device: Arc, + device: Arc, } -impl core::fmt::Debug for WgpuStorage { +impl core::fmt::Debug for WgpuStorage +where + W: WebGPUApi, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str(format!("WgpuStorage {{ device: {:?} }}", self.device).as_str()) } @@ -17,32 +21,35 @@ impl core::fmt::Debug for WgpuStorage { /// The memory resource that can be allocated for wgpu. #[derive(new, Debug)] -pub struct WgpuResource { +pub struct WgpuResource { /// The wgpu buffer. - pub buffer: Arc, + pub buffer: Arc, /// How the resource is used. pub kind: WgpuResourceKind, } -impl WgpuResource { +impl WgpuResource +where + W: WebGPUApi, +{ /// Return the binding view of the buffer. - pub fn as_binding(&self) -> wgpu::BindingResource { + pub fn as_binding(&self) -> BindingResource<'_, W::Buffer> { let binding = match &self.kind { WgpuResourceKind::Full => self.buffer.as_entire_buffer_binding(), - WgpuResourceKind::Slice(offs, size) => wgpu::BufferBinding { - buffer: &self.buffer, - offset: *offs, + WgpuResourceKind::Slice { offset, size } => BufferBinding::<'_> { + buffer: self.buffer.as_ref(), + offset: *offset, size: Some(*size), }, }; - wgpu::BindingResource::Buffer(binding) + BindingResource::Buffer(binding) } /// Return the buffer size. pub fn size(&self) -> u64 { match self.kind { WgpuResourceKind::Full => self.buffer.size(), - WgpuResourceKind::Slice(_, size) => size.get(), + WgpuResourceKind::Slice { offset: _, size } => size.get(), } } @@ -50,7 +57,7 @@ impl WgpuResource { pub fn offset(&self) -> u64 { match self.kind { WgpuResourceKind::Full => 0, - WgpuResourceKind::Slice(offset, _) => offset, + WgpuResourceKind::Slice { offset, size: _ } => offset, } } } @@ -61,13 +68,16 @@ pub enum WgpuResourceKind { /// Represents an entire buffer. Full, /// A slice over a buffer. 
- Slice(wgpu::BufferAddress, wgpu::BufferSize), + Slice { offset: u64, size: NonZeroU64 }, } /// Keeps actual wgpu buffer references in a hashmap with ids as key. -impl WgpuStorage { - /// Create a new storage on the given [device](wgpu::Device). - pub fn new(device: Arc) -> Self { +impl WgpuStorage +where + W: WebGPUApi, +{ + /// Create a new storage on the given [device](WebGPUDevice). + pub fn new(device: Arc) -> Self { Self { memory: HashMap::new(), deallocations: Vec::new(), @@ -85,8 +95,11 @@ impl WgpuStorage { } } -impl ComputeStorage for WgpuStorage { - type Resource = WgpuResource; +impl ComputeStorage for WgpuStorage +where + W: WebGPUApi, +{ + type Resource = WgpuResource; fn get(&mut self, handle: &StorageHandle) -> Self::Resource { let buffer = self.memory.get(&handle.id).unwrap(); @@ -97,23 +110,25 @@ impl ComputeStorage for WgpuStorage { } StorageUtilization::Slice { offset, size } => WgpuResource::new( buffer.clone(), - WgpuResourceKind::Slice(offset as u64, NonZeroU64::new(size as u64).unwrap()), + WgpuResourceKind::Slice { + offset: offset as u64, + size: NonZeroU64::new(size as u64).unwrap(), + }, ), } } fn alloc(&mut self, size: usize) -> StorageHandle { let id = StorageId::new(); - let buffer = Arc::new(self.device.create_buffer(&wgpu::BufferDescriptor { + + let buffer = self.device.create_buffer(&BufferDescriptor { label: None, size: size as u64, - usage: wgpu::BufferUsages::COPY_DST - | wgpu::BufferUsages::STORAGE - | wgpu::BufferUsages::COPY_SRC, + usage: W::COPY_DST | W::STORAGE | W::COPY_SRC, mapped_at_creation: false, - })); + }); - self.memory.insert(id.clone(), buffer); + self.memory.insert(id.clone(), buffer.into()); StorageHandle::new(id, StorageUtilization::Full(size)) } diff --git a/crates/burn-wgpu/src/compute/webgpu_api.rs b/crates/burn-wgpu/src/compute/webgpu_api.rs new file mode 100644 index 0000000000..28e535218a --- /dev/null +++ b/crates/burn-wgpu/src/compute/webgpu_api.rs @@ -0,0 +1,200 @@ +use crate::{GraphicsApi, 
RuntimeOptions, WgpuDevice}; +use burn_compute::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer}; +use burn_jit::compute::WorkGroup; +use std::borrow::Cow; +use alloc::sync::Arc; + +pub trait Adapter: core::fmt::Debug { + fn get_info(&self) -> AdapterInfo; +} + +pub trait AdapterInfo: core::fmt::Debug { + fn backend(&self) -> Backend; + fn device(&self) -> DeviceId; +} + +pub trait BindGroup: Send + core::fmt::Debug {} + +pub struct BindGroupDescriptor<'a, BindGroupLayout, Buffer> { + pub label: Option<&'a str>, + pub layout: &'a BindGroupLayout, + pub entries: &'a Vec>, +} + +pub struct BindGroupEntry<'a, Buffer> { + pub binding: u32, + pub resource: BindingResource<'a, Buffer>, +} + +pub trait BindGroupLayout: core::fmt::Debug {} + +pub enum BindingResource<'a, Buffer> { + Buffer(BufferBinding<'a, Buffer>), +} + +pub trait Buffer: Send + Sync + core::fmt::Debug { + fn as_entire_buffer_binding(&self) -> BufferBinding<'_, Buffer>; + fn destroy(&self); + #[allow(async_fn_in_trait)] + async fn read(&self, device: &Device) -> Vec; + fn size(&self) -> u64; +} + +pub struct BufferBinding<'a, Buffer> { + pub buffer: &'a Buffer, + pub offset: u64, + pub size: Option, +} + +pub struct BufferDescriptor<'a> { + pub label: Option<&'a str>, + pub size: u64, + pub usage: u32, + pub mapped_at_creation: bool, +} + +pub struct BufferInitDescriptor<'a> { + pub label: Option<&'a str>, + pub contents: &'a [u8], + pub usage: u32, +} + +pub trait CommandBuffer: core::fmt::Debug {} + +pub struct CommandEncoderDescriptor<'a> { + pub label: Option<&'a str>, +} + +pub trait CommandEncoder: + Send + Sync + core::fmt::Debug +{ + fn dispatch_compute_pass( + &mut self, + desc: &ComputePassDescriptor, + pipeline: Arc, + bind_group: BindGroup, + work_group: WorkGroup, + ); + fn copy_buffer_to_buffer( + &mut self, + src: &Buffer, + src_offset: u64, + dst: &Buffer, + dst_offset: u64, + size: u64, + ); + fn finish(self) -> CommandBuffer; +} + +pub struct 
ComputePassDescriptor<'a> { + pub label: Option<&'a str>, +} + +pub trait ComputePipeline: Send + Sync + core::fmt::Debug { + fn get_bind_group_layout(&self, id: u32) -> BindGroupLayout; +} + +pub struct ComputePipelineDescriptor<'a, PipelineLayout, ShaderModule> { + pub label: Option<&'a str>, + pub layout: Option<&'a PipelineLayout>, + pub module: &'a ShaderModule, + pub entry_point: &'a str, +} + +pub trait Device< + BindGroup, + BindGroupLayout, + Buffer, + CommandEncoder, + ComputePipeline, + PipelineLayout, + ShaderModule, +>: Send + Sync + core::fmt::Debug +{ + fn create_bind_group( + &self, + desc: &BindGroupDescriptor<'_, BindGroupLayout, Buffer>, + ) -> BindGroup; + fn create_buffer(&self, desc: &BufferDescriptor) -> Buffer; + fn create_buffer_init(&self, desc: &BufferInitDescriptor) -> Buffer; + fn create_command_encoder(&self, desc: &CommandEncoderDescriptor) -> CommandEncoder; + fn create_compute_pipeline( + &self, + desc: &ComputePipelineDescriptor, + ) -> ComputePipeline; + fn create_shader_module(&self, desc: &ShaderModuleDescriptor) -> ShaderModule; +} + +pub type DeviceId = u32; + +pub trait PipelineLayout: core::fmt::Debug {} + +pub trait Queue: Send + core::fmt::Debug { + fn submit(&self, buf: Option); + fn write_buffer(&self, buf: &Buffer, offset: u64, data: &[u8]); +} + +pub enum ShaderSource<'a> { + Wgsl(Cow<'a, str>), +} + +pub trait ShaderModule: core::fmt::Debug {} + +pub struct ShaderModuleDescriptor<'a> { + pub label: Option<&'a str>, + pub source: ShaderSource<'a>, +} + +pub trait WebGPUApi: Send + Sync + core::fmt::Debug + 'static { + type Adapter: Adapter; + type AdapterInfo: AdapterInfo; + type Backend: core::convert::AsRef; + type BindGroup: BindGroup; + type BindGroupLayout: BindGroupLayout; + type Buffer: Buffer; + type CommandBuffer: CommandBuffer; + type CommandEncoder: CommandEncoder< + Self::BindGroup, + Self::Buffer, + Self::CommandBuffer, + Self::ComputePipeline, + >; + type ComputePipeline: ComputePipeline; + type Device: 
Device< + Self::BindGroup, + Self::BindGroupLayout, + Self::Buffer, + Self::CommandEncoder, + Self::ComputePipeline, + Self::PipelineLayout, + Self::ShaderModule, + >; + type PipelineLayout: PipelineLayout; + type Queue: Queue; + type ShaderModule: ShaderModule; + + const MAP_READ: u32; + const COPY_SRC: u32; + const COPY_DST: u32; + const STORAGE: u32; + + type Server: ComputeServer< + Kernel = burn_jit::compute::Kernel, + AutotuneKey = burn_jit::compute::JitAutotuneKey, + >; + type Channel: ComputeChannel; + + fn client(device: &WgpuDevice) -> ComputeClient; + #[allow(async_fn_in_trait)] + async fn select_device(adapter: &Self::Adapter) -> (Self::Device, Self::Queue); + #[allow(async_fn_in_trait)] + #[cfg(target_family = "wasm")] + async fn select_adapter(device: &WgpuDevice) -> Self::Adapter; + #[cfg(not(target_family = "wasm"))] + fn select_adapter(device: &WgpuDevice) -> Self::Adapter; + fn device_poll(device: &Self::Device); + + fn init_sync(device: &WgpuDevice, options: RuntimeOptions); + #[allow(async_fn_in_trait)] + async fn init_async(device: &WgpuDevice, options: RuntimeOptions); +} diff --git a/crates/burn-wgpu/src/compute/wgpu_api_shim.rs b/crates/burn-wgpu/src/compute/wgpu_api_shim.rs new file mode 100644 index 0000000000..29e9656784 --- /dev/null +++ b/crates/burn-wgpu/src/compute/wgpu_api_shim.rs @@ -0,0 +1,455 @@ +use crate::{ + compute::{webgpu_api::*, WgpuServer, WgpuStorage}, + create_client, GraphicsApi, RuntimeOptions, WgpuDevice, +}; +use alloc::sync::Arc; +use burn_compute::{ + channel::MutexComputeChannel, client::ComputeClient, memory_management::SimpleMemoryManagement, + ComputeRuntime, +}; +use burn_jit::compute::WorkGroup; + +#[derive(Debug)] +pub struct WgpuApi {} + +pub struct WgpuBackend { + backend: wgpu::Backend, +} + +impl Adapter for wgpu::Adapter { + fn get_info(&self) -> wgpu::AdapterInfo { + wgpu::Adapter::get_info(self) + } +} + +impl AdapterInfo for wgpu::AdapterInfo { + fn backend(&self) -> WgpuBackend { + WgpuBackend { + 
backend: self.backend, + } + } + + fn device(&self) -> DeviceId { + self.device + } +} + +impl core::convert::AsRef for WgpuBackend { + fn as_ref(&self) -> &str { + wgpu::Backend::to_str(self.backend) + } +} + +impl BindGroup for wgpu::BindGroup {} + +impl BindGroupLayout for wgpu::BindGroupLayout {} + +impl Buffer for wgpu::Buffer { + fn as_entire_buffer_binding(&self) -> BufferBinding<'_, wgpu::Buffer> { + let binding = wgpu::Buffer::as_entire_buffer_binding(self); + BufferBinding { + buffer: binding.buffer, + offset: binding.offset, + size: binding.size, + } + } + + fn destroy(&self) { + wgpu::Buffer::destroy(self) + } + + async fn read(&self, device: &wgpu::Device) -> Vec { + let buffer_slice = self.slice(..); + let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel(); + buffer_slice.map_async(wgpu::MapMode::Read, move |v| { + sender + .send(v) + .expect("Unable to send buffer slice result to async channel.") + }); + + device.poll(wgpu::Maintain::Wait); + + let result = receiver.receive().await; + + if let Some(Ok(())) = result { + let data = buffer_slice.get_mapped_range(); + let result = bytemuck::cast_slice(&data).to_vec(); + + drop(data); + self.unmap(); + result + } else { + panic!("Unable to read buffer {:?}", result) + } + } + + fn size(&self) -> u64 { + wgpu::Buffer::size(self) + } +} + +impl CommandBuffer for wgpu::CommandBuffer {} + +impl CommandEncoder + for wgpu::CommandEncoder +{ + fn dispatch_compute_pass( + &mut self, + desc: &ComputePassDescriptor, + pipeline: Arc, + bind_group: wgpu::BindGroup, + work_group: WorkGroup, + ) { + let mut compute = self.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: desc.label, + timestamp_writes: None, + }); + + compute.set_pipeline(&pipeline); + compute.set_bind_group(0, &bind_group, &[]); + compute.dispatch_workgroups(work_group.x, work_group.y, work_group.z); + } + + fn copy_buffer_to_buffer( + &mut self, + src: &wgpu::Buffer, + src_offset: u64, + dst: &wgpu::Buffer, + 
dst_offset: u64, + size: u64, + ) { + wgpu::CommandEncoder::copy_buffer_to_buffer(self, src, src_offset, dst, dst_offset, size) + } + + fn finish(self) -> wgpu::CommandBuffer { + wgpu::CommandEncoder::finish(self) + } +} + +impl ComputePipeline for wgpu::ComputePipeline { + fn get_bind_group_layout(&self, id: u32) -> wgpu::BindGroupLayout { + wgpu::ComputePipeline::get_bind_group_layout(self, id) + } +} + +impl + Device< + wgpu::BindGroup, + wgpu::BindGroupLayout, + wgpu::Buffer, + wgpu::CommandEncoder, + wgpu::ComputePipeline, + wgpu::PipelineLayout, + wgpu::ShaderModule, + > for wgpu::Device +{ + fn create_bind_group( + &self, + desc: &BindGroupDescriptor<'_, wgpu::BindGroupLayout, wgpu::Buffer>, + ) -> wgpu::BindGroup { + let entries = desc + .entries + .iter() + .map(|entry| { + let BindingResource::Buffer(resource) = &entry.resource; + wgpu::BindGroupEntry { + binding: entry.binding, + resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding { + buffer: resource.buffer, + offset: resource.offset, + size: resource.size, + }), + } + }) + .collect::>(); + + wgpu::Device::create_bind_group( + self, + &wgpu::BindGroupDescriptor { + label: desc.label, + layout: desc.layout, + entries: &entries, + }, + ) + } + + fn create_buffer(&self, desc: &BufferDescriptor) -> wgpu::Buffer { + wgpu::Device::create_buffer( + self, + &wgpu::BufferDescriptor { + label: desc.label, + size: desc.size, + usage: wgpu::BufferUsages::from_bits(desc.usage).unwrap(), + mapped_at_creation: desc.mapped_at_creation, + }, + ) + } + + fn create_buffer_init(&self, desc: &BufferInitDescriptor) -> wgpu::Buffer { + wgpu::util::DeviceExt::create_buffer_init( + self, + &wgpu::util::BufferInitDescriptor { + label: desc.label, + contents: desc.contents, + usage: wgpu::BufferUsages::from_bits(desc.usage).unwrap(), + }, + ) + } + + fn create_command_encoder(&self, desc: &CommandEncoderDescriptor) -> wgpu::CommandEncoder { + wgpu::Device::create_command_encoder( + self, + &wgpu::CommandEncoderDescriptor 
{ label: desc.label }, + ) + } + + fn create_compute_pipeline( + &self, + desc: &ComputePipelineDescriptor, + ) -> wgpu::ComputePipeline { + wgpu::Device::create_compute_pipeline( + self, + &wgpu::ComputePipelineDescriptor { + label: desc.label, + layout: desc.layout, + module: desc.module, + entry_point: desc.entry_point, + }, + ) + } + + fn create_shader_module(&self, desc: &ShaderModuleDescriptor) -> wgpu::ShaderModule { + let source = match &desc.source { + ShaderSource::Wgsl(source) => source.to_string(), + }; + wgpu::Device::create_shader_module( + self, + wgpu::ShaderModuleDescriptor { + label: desc.label, + source: wgpu::ShaderSource::Wgsl(source.into()), + }, + ) + } +} + +impl PipelineLayout for wgpu::PipelineLayout {} + +impl Queue for wgpu::Queue { + fn submit(&self, buf: Option) { + wgpu::Queue::submit(self, buf); + } + + fn write_buffer(&self, buf: &wgpu::Buffer, offset: u64, data: &[u8]) { + wgpu::Queue::write_buffer(self, buf, offset, data) + } +} + +impl ShaderModule for wgpu::ShaderModule {} + +/// The compute instance is shared across all [wgpu runtimes](WgpuRuntime). 
+static RUNTIME: ComputeRuntime> = + ComputeRuntime::new(); + +type Server = WgpuServer>>; + +impl WebGPUApi for WgpuApi { + type Adapter = wgpu::Adapter; + type AdapterInfo = wgpu::AdapterInfo; + type Backend = WgpuBackend; + type BindGroup = wgpu::BindGroup; + type BindGroupLayout = wgpu::BindGroupLayout; + type Buffer = wgpu::Buffer; + type CommandBuffer = wgpu::CommandBuffer; + type CommandEncoder = wgpu::CommandEncoder; + type ComputePipeline = wgpu::ComputePipeline; + type Device = wgpu::Device; + type PipelineLayout = wgpu::PipelineLayout; + type Queue = wgpu::Queue; + type ShaderModule = wgpu::ShaderModule; + + const MAP_READ: u32 = wgpu::BufferUsages::MAP_READ.bits(); + const COPY_SRC: u32 = wgpu::BufferUsages::COPY_SRC.bits(); + const COPY_DST: u32 = wgpu::BufferUsages::COPY_DST.bits(); + const STORAGE: u32 = wgpu::BufferUsages::STORAGE.bits(); + + type Server = WgpuServer>>; + type Channel = MutexComputeChannel>>>; + + fn client(device: &WgpuDevice) -> ComputeClient { + RUNTIME.client(device, move || { + pollster::block_on(create_client::( + device, + RuntimeOptions::default(), + )) + }) + } + + async fn select_device(adapter: &wgpu::Adapter) -> (wgpu::Device, wgpu::Queue) { + let limits = adapter.limits(); + + let (device, queue) = adapter + .request_device( + &wgpu::DeviceDescriptor { + label: None, + required_features: wgpu::Features::empty(), + required_limits: limits, + }, + None, + ) + .await + .map_err(|err| { + format!( + "Unable to request the device with the adapter {:?}, err {:?}", + adapter.get_info(), + err + ) + }) + .unwrap(); + + (device, queue) + } + + #[cfg(target_family = "wasm")] + async fn select_adapter(_device: &WgpuDevice) -> Self::Adapter { + let instance = wgpu::Instance::default(); + + instance + .request_adapter(&wgpu::RequestAdapterOptionsBase::default()) + .await + .unwrap() + } + + #[cfg(not(target_family = "wasm"))] + fn select_adapter(device: &WgpuDevice) -> wgpu::Adapter { + use wgpu::DeviceType; + + let instance = 
wgpu::Instance::default(); + let mut adapters_other = Vec::new(); + let mut adapters = Vec::new(); + + instance + .enumerate_adapters(G::backend().into()) + .into_iter() + .for_each(|adapter| { + let device_type = adapter.get_info().device_type; + + if let DeviceType::Other = device_type { + adapters_other.push(adapter); + return; + } + + let is_same_type = match device { + WgpuDevice::DiscreteGpu(_) => device_type == DeviceType::DiscreteGpu, + WgpuDevice::IntegratedGpu(_) => device_type == DeviceType::IntegratedGpu, + WgpuDevice::VirtualGpu(_) => device_type == DeviceType::VirtualGpu, + WgpuDevice::Cpu => device_type == DeviceType::Cpu, + WgpuDevice::BestAvailable => true, + }; + + if is_same_type { + adapters.push(adapter); + } + }); + + fn select( + num: usize, + error: &str, + mut adapters: Vec, + mut adapters_other: Vec, + ) -> wgpu::Adapter { + if adapters.len() <= num { + if adapters_other.len() <= num { + panic!( + "{}, adapters {:?}, other adapters {:?}", + error, + adapters + .into_iter() + .map(|adapter| adapter.get_info()) + .collect::>(), + adapters_other + .into_iter() + .map(|adapter| adapter.get_info()) + .collect::>(), + ); + } else { + return adapters_other.remove(num); + } + } + + adapters.remove(num) + } + let adapter = match device { + WgpuDevice::DiscreteGpu(num) => select( + *num, + "No Discrete GPU device found", + adapters, + adapters_other, + ), + WgpuDevice::IntegratedGpu(num) => select( + *num, + "No Integrated GPU device found", + adapters, + adapters_other, + ), + WgpuDevice::VirtualGpu(num) => select( + *num, + "No Virtual GPU device found", + adapters, + adapters_other, + ), + WgpuDevice::Cpu => select(0, "No CPU device found", adapters, adapters_other), + WgpuDevice::BestAvailable => { + let mut most_performant_adapter = None; + let mut current_score = -1; + + adapters + .into_iter() + .chain(adapters_other) + .for_each(|adapter| { + let info = adapter.get_info(); + let score = match info.device_type { + DeviceType::DiscreteGpu => 
5, + DeviceType::Other => 4, // Let's be optimistic with the Other device, it's + // often a Discrete Gpu. + DeviceType::IntegratedGpu => 3, + DeviceType::VirtualGpu => 2, + DeviceType::Cpu => 1, + }; + + if score > current_score { + most_performant_adapter = Some(adapter); + current_score = score; + } + }); + + if let Some(adapter) = most_performant_adapter { + adapter + } else { + panic!("No adapter found for graphics API {:?}", G::default()); + } + } + }; + + log::info!("Using adapter {:?}", adapter.get_info()); + + adapter + } + + fn device_poll(device: &Self::Device) { + device.poll(wgpu::Maintain::Wait); + } + + fn init_sync(device: &WgpuDevice, options: RuntimeOptions) { + let device = Arc::new(device); + let client = pollster::block_on(create_client::(&device, options)); + + RUNTIME.register(&device, client) + } + + async fn init_async(device: &WgpuDevice, options: RuntimeOptions) { + let device = Arc::new(device); + let client = create_client::(&device, options).await; + + RUNTIME.register(&device, client) + } +} diff --git a/crates/burn-wgpu/src/lib.rs b/crates/burn-wgpu/src/lib.rs index 59a246840d..c159e211af 100644 --- a/crates/burn-wgpu/src/lib.rs +++ b/crates/burn-wgpu/src/lib.rs @@ -27,7 +27,14 @@ pub use runtime::*; pub use burn_jit::compute::WorkGroup; pub use burn_jit::{tensor::JitTensor, JitBackend}; -#[cfg(feature = "fusion")] +pub use crate::compute::WebGPUApi; + +#[cfg(feature = "dawn")] +pub use crate::compute::DawnApi; +#[cfg(feature = "wgpu")] +pub use crate::compute::WgpuApi; + +#[cfg(all(feature = "fusion", feature = "wgpu"))] /// Tensor backend that uses the [wgpu] crate for executing GPU compute shaders. /// /// This backend can target multiple graphics APIs, including: @@ -45,9 +52,9 @@ pub use burn_jit::{tensor::JitTensor, JitBackend}; /// You can disable the `fusion` feature flag to remove that functionality, which might be /// necessary on `wasm` for now. 
pub type Wgpu = - burn_fusion::Fusion, F, I>>; + burn_fusion::Fusion, F, I>>; -#[cfg(not(feature = "fusion"))] +#[cfg(all(not(feature = "fusion"), feature = "wgpu"))] /// Tensor backend that uses the [wgpu] crate for executing GPU compute shaders. /// /// This backend can target multiple graphics APIs, including: @@ -64,13 +71,61 @@ pub type Wgpu = /// /// You can enable the `fusion` feature flag to add that functionality, which might improve /// performance. -pub type Wgpu = JitBackend, F, I>; +pub type Wgpu = JitBackend, F, I>; + +#[cfg(all(feature = "fusion", feature = "dawn"))] +/// Tensor backend that uses Dawn for executing GPU compute shaders. +/// +/// This backend can target multiple graphics APIs, including: +/// - [Vulkan] on Linux, Windows, and Android. +/// - [OpenGL](crate::OpenGl) on Linux, Windows, and Android. +/// - [DirectX 12](crate::Dx12) on Windows. +/// - [Metal] on Apple hardware. +/// - [WebGPU](crate::WebGpu) on supported browsers and `wasm` runtimes. +/// +/// # Notes +/// +/// This version of the Dawn backend uses [burn_fusion] to compile and optimize streams of tensor +/// operations for improved performance. +/// +/// You can disable the `fusion` feature flag to remove that functionality, which might be +/// necessary on `wasm` for now. +pub type Dawn = + burn_fusion::Fusion, F, I>>; + +#[cfg(all(not(feature = "fusion"), feature = "dawn"))] +/// Tensor backend that uses Dawn for executing GPU compute shaders. +/// +/// This backend can target multiple graphics APIs, including: +/// - [Vulkan] on Linux, Windows, and Android. +/// - [OpenGL](crate::OpenGl) on Linux, Windows, and Android. +/// - [DirectX 12](crate::Dx12) on Windows. +/// - [Metal] on Apple hardware. +/// - [WebGPU](crate::WebGpu) on supported browsers and `wasm` runtimes. +/// +/// # Notes +/// +/// This version of the Dawn backend doesn't use [burn_fusion] to compile and optimize streams of tensor +/// operations. 
+/// +/// You can enable the `fusion` feature flag to add that functionality, which might improve +/// performance. +pub type Dawn = JitBackend, F, I>; + +#[cfg(all(test, feature = "wgpu"))] +mod tests_wgpu { + use super::*; + + pub type TestRuntime = crate::WgpuRuntime; + + burn_jit::testgen_all!(); +} -#[cfg(test)] -mod tests { +#[cfg(all(test, feature = "dawn"))] +mod tests_dawn { use super::*; - pub type TestRuntime = crate::WgpuRuntime; + pub type TestRuntime = crate::WgpuRuntime; burn_jit::testgen_all!(); } diff --git a/crates/burn-wgpu/src/runtime.rs b/crates/burn-wgpu/src/runtime.rs index e3192f2238..ae85b01af5 100644 --- a/crates/burn-wgpu/src/runtime.rs +++ b/crates/burn-wgpu/src/runtime.rs @@ -1,6 +1,6 @@ use crate::{ compiler::wgsl, - compute::{WgpuServer, WgpuStorage}, + compute::{Adapter, AdapterInfo, WebGPUApi, WgpuServer, WgpuStorage}, GraphicsApi, WgpuDevice, }; use alloc::sync::Arc; @@ -10,38 +10,29 @@ use burn_compute::{ client::ComputeClient, memory_management::{DeallocStrategy, SimpleMemoryManagement, SliceStrategy}, tune::Tuner, - ComputeRuntime, }; use burn_jit::Runtime; use burn_tensor::backend::{DeviceId, DeviceOps}; use std::marker::PhantomData; -use wgpu::{AdapterInfo, DeviceDescriptor}; /// Runtime that uses the [wgpu] crate with the wgsl compiler. /// /// The [graphics api](GraphicsApi) type is passed as generic. #[derive(Debug)] -pub struct WgpuRuntime { +pub struct WgpuRuntime { + _w: PhantomData, _g: PhantomData, } -/// The compute instance is shared across all [wgpu runtimes](WgpuRuntime). 
-static RUNTIME: ComputeRuntime> = - ComputeRuntime::new(); - -type Server = WgpuServer>; - -impl Runtime for WgpuRuntime { +impl Runtime for WgpuRuntime { type Compiler = wgsl::WgslCompiler; - type Server = WgpuServer>; + type Server = W::Server; - type Channel = MutexComputeChannel>>; + type Channel = W::Channel; type Device = WgpuDevice; fn client(device: &Self::Device) -> ComputeClient { - RUNTIME.client(device, move || { - pollster::block_on(create_client::(device, RuntimeOptions::default())) - }) + W::client::(device) } fn name() -> &'static str { @@ -91,29 +82,26 @@ impl Default for RuntimeOptions { } /// Init the client sync, useful to configure the runtime options. -pub fn init_sync(device: &WgpuDevice, options: RuntimeOptions) { - let device = Arc::new(device); - let client = pollster::block_on(create_client::(&device, options)); - - RUNTIME.register(&device, client) +pub fn init_sync(device: &WgpuDevice, options: RuntimeOptions) { + W::init_sync::(device, options) } /// Init the client async, necessary for wasm. 
-pub async fn init_async(device: &WgpuDevice, options: RuntimeOptions) { - let device = Arc::new(device); - let client = create_client::(&device, options).await; - - RUNTIME.register(&device, client) +pub async fn init_async( + device: &WgpuDevice, + options: RuntimeOptions, +) { + W::init_async::(device, options).await } -async fn create_client( +pub async fn create_client( device: &WgpuDevice, options: RuntimeOptions, ) -> ComputeClient< - WgpuServer>, - MutexComputeChannel>>, + WgpuServer>>, + MutexComputeChannel>>>, > { - let (device_wgpu, queue, info) = select_device::(device).await; + let (device_wgpu, queue, info) = select_device::(device).await; log::info!( "Created wgpu compute server on device {:?} => {:?}", @@ -128,171 +116,35 @@ async fn create_client( let server = WgpuServer::new(memory_management, device, queue, options.tasks_max); let channel = MutexComputeChannel::new(server); - let tuner_device_id = tuner_device_id(info); + let tuner_device_id = tuner_device_id::(info); ComputeClient::new(channel, Arc::new(RwLock::new(Tuner::new(&tuner_device_id)))) } /// Select the wgpu device and queue based on the provided [device](WgpuDevice). 
-pub async fn select_device( +pub async fn select_device( device: &WgpuDevice, -) -> (wgpu::Device, wgpu::Queue, wgpu::AdapterInfo) { +) -> (W::Device, W::Queue, W::AdapterInfo) { #[cfg(target_family = "wasm")] - let adapter = select_adapter::(device).await; + let adapter = select_adapter::(device).await; #[cfg(not(target_family = "wasm"))] - let adapter = select_adapter::(device); - - let limits = adapter.limits(); + let adapter = select_adapter::(device); - let (device, queue) = adapter - .request_device( - &DeviceDescriptor { - label: None, - required_features: wgpu::Features::empty(), - required_limits: limits, - }, - None, - ) - .await - .map_err(|err| { - format!( - "Unable to request the device with the adapter {:?}, err {:?}", - adapter.get_info(), - err - ) - }) - .unwrap(); + let (device, queue) = W::select_device(&adapter).await; (device, queue, adapter.get_info()) } -fn tuner_device_id(info: AdapterInfo) -> String { - format!("wgpu-{}-{}", info.device, info.backend.to_str()) +fn tuner_device_id(info: W::AdapterInfo) -> String { + format!("wgpu-{}-{}", info.device(), info.backend().as_ref()) } #[cfg(target_family = "wasm")] -async fn select_adapter(_device: &WgpuDevice) -> wgpu::Adapter { - let instance = wgpu::Instance::default(); - - instance - .request_adapter(&wgpu::RequestAdapterOptionsBase::default()) - .await - .unwrap() +async fn select_adapter(_device: &WgpuDevice) -> W::Adapter { + W::select_adapter::(device) } #[cfg(not(target_family = "wasm"))] -fn select_adapter(device: &WgpuDevice) -> wgpu::Adapter { - use wgpu::DeviceType; - - let instance = wgpu::Instance::default(); - let mut adapters_other = Vec::new(); - let mut adapters = Vec::new(); - - instance - .enumerate_adapters(G::backend().into()) - .into_iter() - .for_each(|adapter| { - let device_type = adapter.get_info().device_type; - - if let DeviceType::Other = device_type { - adapters_other.push(adapter); - return; - } - - let is_same_type = match device { - WgpuDevice::DiscreteGpu(_) 
=> device_type == DeviceType::DiscreteGpu, - WgpuDevice::IntegratedGpu(_) => device_type == DeviceType::IntegratedGpu, - WgpuDevice::VirtualGpu(_) => device_type == DeviceType::VirtualGpu, - WgpuDevice::Cpu => device_type == DeviceType::Cpu, - WgpuDevice::BestAvailable => true, - }; - - if is_same_type { - adapters.push(adapter); - } - }); - - fn select( - num: usize, - error: &str, - mut adapters: Vec, - mut adapters_other: Vec, - ) -> wgpu::Adapter { - if adapters.len() <= num { - if adapters_other.len() <= num { - panic!( - "{}, adapters {:?}, other adapters {:?}", - error, - adapters - .into_iter() - .map(|adapter| adapter.get_info()) - .collect::>(), - adapters_other - .into_iter() - .map(|adapter| adapter.get_info()) - .collect::>(), - ); - } - - return adapters_other.remove(num); - } - - adapters.remove(num) - } - - let adapter = match device { - WgpuDevice::DiscreteGpu(num) => select( - *num, - "No Discrete GPU device found", - adapters, - adapters_other, - ), - WgpuDevice::IntegratedGpu(num) => select( - *num, - "No Integrated GPU device found", - adapters, - adapters_other, - ), - WgpuDevice::VirtualGpu(num) => select( - *num, - "No Virtual GPU device found", - adapters, - adapters_other, - ), - WgpuDevice::Cpu => select(0, "No CPU device found", adapters, adapters_other), - WgpuDevice::BestAvailable => { - let mut most_performant_adapter = None; - let mut current_score = -1; - - adapters - .into_iter() - .chain(adapters_other) - .for_each(|adapter| { - let info = adapter.get_info(); - let score = match info.device_type { - DeviceType::DiscreteGpu => 5, - DeviceType::Other => 4, // Let's be optimistic with the Other device, it's - // often a Discrete Gpu. 
- DeviceType::IntegratedGpu => 3, - DeviceType::VirtualGpu => 2, - DeviceType::Cpu => 1, - }; - - if score > current_score { - most_performant_adapter = Some(adapter); - current_score = score; - } - }); - - if let Some(adapter) = most_performant_adapter { - adapter - } else { - panic!("No adapter found for graphics API {:?}", G::default()); - } - } - }; - - log::info!("Using adapter {:?}", adapter.get_info()); - - adapter +fn select_adapter(device: &WgpuDevice) -> W::Adapter { + W::select_adapter::(device) } diff --git a/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs b/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs index 960fc7647d..69e5ba5e27 100644 --- a/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs +++ b/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs @@ -1,5 +1,5 @@ use burn::{ - backend::wgpu::{AutoGraphicsApi, WgpuRuntime}, + backend::wgpu::{AutoGraphicsApi, WgpuApi, WgpuRuntime}, tensor::{Distribution, Tensor}, }; use custom_wgpu_kernel::{ @@ -71,7 +71,7 @@ fn autodiff(device: &B::Device) { } fn main() { - type MyBackend = burn::backend::wgpu::JitBackend, f32, i32>; + type MyBackend = burn::backend::wgpu::JitBackend, f32, i32>; type MyAutodiffBackend = burn::backend::Autodiff; let device = Default::default(); inference::(&device); diff --git a/examples/custom-wgpu-kernel/src/backward.rs b/examples/custom-wgpu-kernel/src/backward.rs index 5a2a03129b..7c2c1c2860 100644 --- a/examples/custom-wgpu-kernel/src/backward.rs +++ b/examples/custom-wgpu-kernel/src/backward.rs @@ -9,13 +9,13 @@ use burn::{ ops::{broadcast_shape, Backward, Ops, OpsKind}, Autodiff, NodeID, }, - wgpu::{FloatElement, GraphicsApi, IntElement, JitBackend, WgpuRuntime}, + wgpu::{FloatElement, GraphicsApi, IntElement, JitBackend, WebGPUApi, WgpuRuntime}, }, tensor::Shape, }; -impl AutodiffBackend - for Autodiff, F, I>> +impl AutodiffBackend + for Autodiff, F, I>> { } diff --git a/examples/custom-wgpu-kernel/src/forward.rs 
b/examples/custom-wgpu-kernel/src/forward.rs index f23b54c6fe..f4861772fb 100644 --- a/examples/custom-wgpu-kernel/src/forward.rs +++ b/examples/custom-wgpu-kernel/src/forward.rs @@ -4,8 +4,8 @@ use super::Backend; use burn::{ backend::wgpu::{ build_info, into_contiguous, kernel_wgsl, FloatElement, GraphicsApi, IntElement, - JitBackend, JitTensor, Kernel, KernelSource, SourceKernel, SourceTemplate, WgpuRuntime, - WorkGroup, WorkgroupSize, + JitBackend, JitTensor, Kernel, KernelSource, SourceKernel, SourceTemplate, WebGPUApi, + WgpuRuntime, WorkGroup, WorkgroupSize, }, tensor::Shape, }; @@ -37,7 +37,7 @@ impl KernelSource for FusedMatmulAddRelu { } /// Implement our custom backend trait for the existing backend `WgpuBackend`. -impl Backend for JitBackend, F, I> { +impl Backend for JitBackend, F, I> { fn fused_matmul_add_relu( lhs: FloatTensor, rhs: FloatTensor, diff --git a/examples/image-classification-web/src/web.rs b/examples/image-classification-web/src/web.rs index 6fc7d3f8f5..b0c366f4a4 100644 --- a/examples/image-classification-web/src/web.rs +++ b/examples/image-classification-web/src/web.rs @@ -11,7 +11,7 @@ use crate::model::{label::LABELS, normalizer::Normalizer, squeezenet::Model as S use burn::{backend::NdArray, prelude::*, tensor::activation::softmax}; use burn_candle::Candle; -use burn_wgpu::{init_async, AutoGraphicsApi, Wgpu, WgpuDevice}; +use burn_wgpu::{init_async, AutoGraphicsApi, Wgpu, WgpuApi, WgpuDevice}; use serde::Serialize; use wasm_bindgen::prelude::*; @@ -106,7 +106,7 @@ impl ImageClassifier { log::info!("Loading the model to the Wgpu backend"); let start = Instant::now(); let device = WgpuDevice::default(); - init_async::(&device, Default::default()).await; + init_async::(&device, Default::default()).await; self.model = ModelType::WithWgpuBackend(Model::new(&device)); let duration = start.elapsed(); log::debug!("Model is loaded to the Wgpu backend in {:?}", duration);