From 9f2bc599b880f069cc5f1ede21c69554adf3ee1e Mon Sep 17 00:00:00 2001 From: Alex Errant <109672176+AlexErrant@users.noreply.github.com> Date: Tue, 24 Oct 2023 13:32:01 -0500 Subject: [PATCH] Add a `sync` feature to common, core, and tensor (#893) --- burn-common/Cargo.toml | 1 + burn-common/src/reader.rs | 20 ++++++++-------- burn-core/Cargo.toml | 2 ++ burn-core/src/grad_clipping/base.rs | 6 ++--- burn-core/src/record/tensor.rs | 13 +++++----- burn-tensor/Cargo.toml | 1 + burn-tensor/src/tensor/api/base.rs | 24 +++++++++---------- burn-tensor/src/tensor/api/numeric.rs | 4 ++-- burn/Cargo.toml | 3 +++ examples/mnist-inference-web/build-for-web.sh | 1 + examples/mnist-inference-web/run-server.sh | 1 + 11 files changed, 42 insertions(+), 34 deletions(-) diff --git a/burn-common/Cargo.toml b/burn-common/Cargo.toml index 04c41adbb3..431e2de48c 100644 --- a/burn-common/Cargo.toml +++ b/burn-common/Cargo.toml @@ -15,6 +15,7 @@ default = ["std"] std = ["rand/std"] +wasm-sync = [] [target.'cfg(target_family = "wasm")'.dependencies] async-trait = { workspace = true } diff --git a/burn-common/src/reader.rs b/burn-common/src/reader.rs index 7aa32a8f44..91f4492c1a 100644 --- a/burn-common/src/reader.rs +++ b/burn-common/src/reader.rs @@ -1,7 +1,7 @@ use alloc::boxed::Box; use core::marker::PhantomData; -#[cfg(target_family = "wasm")] +#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] #[async_trait::async_trait] /// Allows to create async reader. pub trait AsyncReader: Send { @@ -15,10 +15,10 @@ pub enum Reader { Concrete(T), /// Sync data variant. Sync(Box>), - #[cfg(target_family = "wasm")] + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] /// Async data variant. Async(Box>), - #[cfg(target_family = "wasm")] + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] /// Future data variant. Future(core::pin::Pin + Send>>), } @@ -52,7 +52,7 @@ where } } -#[cfg(target_family = "wasm")] +#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] #[async_trait::async_trait] impl AsyncReader for MappedReader where @@ -67,7 +67,7 @@ where } impl Reader { - #[cfg(target_family = "wasm")] + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] /// Read the data. pub async fn read(self) -> T { match self { @@ -78,7 +78,7 @@ impl Reader { } } - #[cfg(not(target_family = "wasm"))] + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] /// Read the data. pub fn read(self) -> T { match self { @@ -92,9 +92,9 @@ impl Reader { match self { Self::Concrete(data) => Some(data), Self::Sync(reader) => Some(reader.read()), - #[cfg(target_family = "wasm")] + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] Self::Async(_func) => return None, - #[cfg(target_family = "wasm")] + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] Self::Future(_future) => return None, } } @@ -106,10 +106,10 @@ impl Reader { O: 'static + Send, F: 'static + Send, { - #[cfg(target_family = "wasm")] + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] return Reader::Async(Box::new(MappedReader::new(self, mapper))); - #[cfg(not(target_family = "wasm"))] + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] Reader::Sync(Box::new(MappedReader::new(self, mapper))) } } diff --git a/burn-core/Cargo.toml b/burn-core/Cargo.toml index dd043b7e83..c78f56fbb2 100644 --- a/burn-core/Cargo.toml +++ b/burn-core/Cargo.toml @@ -29,6 +29,8 @@ dataset-minimal = ["burn-dataset"] dataset-sqlite = ["burn-dataset/sqlite"] dataset-sqlite-bundled = ["burn-dataset/sqlite-bundled"] +wasm-sync = ["burn-tensor/wasm-sync", "burn-common/wasm-sync"] + # Backend autodiff = ["burn-autodiff"] diff --git a/burn-core/src/grad_clipping/base.rs b/burn-core/src/grad_clipping/base.rs index fabc055d55..91a6be069d 100644 --- a/burn-core/src/grad_clipping/base.rs +++ b/burn-core/src/grad_clipping/base.rs @@ -68,7 +68,7 @@ impl GradientClipping { clipped_grad.mask_fill(lower_mask, -threshold) } - #[cfg(target_family = "wasm")] + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] fn clip_by_norm( &self, _grad: Tensor, @@ -77,7 +77,7 @@ impl GradientClipping { todo!("Not yet supported on wasm"); } - #[cfg(not(target_family = "wasm"))] + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn clip_by_norm( &self, grad: Tensor, @@ -96,7 +96,7 @@ impl GradientClipping { } } - #[cfg(not(target_family = "wasm"))] + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn l2_norm(tensor: Tensor) -> Tensor { let squared = tensor.powf(2.0); let sum = squared.sum(); diff --git a/burn-core/src/record/tensor.rs b/burn-core/src/record/tensor.rs index 110459f1fa..70badf2169 100644 --- a/burn-core/src/record/tensor.rs +++ b/burn-core/src/record/tensor.rs @@ -44,7 +44,6 @@ impl<'de, S: PrecisionSettings> Deserialize<'de> for FloatTensorSerde { } } -// #[cfg(not(target_family = "wasm"))] impl Serialize for IntTensorSerde { fn serialize(&self, serializer: Se) -> Result where @@ -90,10 +89,10 @@ impl Record for Tensor { type Item = FloatTensorSerde; fn into_item(self) -> Self::Item { - #[cfg(target_family = "wasm")] + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] todo!("Recording float tensors isn't yet supported on wasm."); - #[cfg(not(target_family = "wasm"))] + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] FloatTensorSerde::new(self.into_data().convert().serialize()) } @@ -106,10 +105,10 @@ impl Record for Tensor { type Item = IntTensorSerde; fn into_item(self) -> Self::Item { - #[cfg(target_family = "wasm")] + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] todo!("Recording int tensors isn't yet supported on wasm."); - #[cfg(not(target_family = "wasm"))] + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] IntTensorSerde::new(self.into_data().convert().serialize()) } @@ -122,10 +121,10 @@ impl Record for Tensor { type Item = BoolTensorSerde; fn into_item(self) -> Self::Item { - #[cfg(target_family = "wasm")] + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] todo!("Recording bool tensors isn't yet supported on wasm."); - #[cfg(not(target_family = "wasm"))] + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] BoolTensorSerde::new(self.into_data().serialize()) } diff --git a/burn-tensor/Cargo.toml b/burn-tensor/Cargo.toml index 3c5c59a66d..495e4264d2 100644 --- a/burn-tensor/Cargo.toml +++ b/burn-tensor/Cargo.toml @@ -16,6 +16,7 @@ experimental-named-tensor = [] export_tests = ["burn-tensor-testgen"] std = ["rand/std", "half/std"] benchmark = [] +wasm-sync = [] [dependencies] burn-common = { path = "../burn-common", version = "0.10.0", default-features = false } diff --git a/burn-tensor/src/tensor/api/base.rs b/burn-tensor/src/tensor/api/base.rs index 1f960505d0..3a12b834e3 100644 --- a/burn-tensor/src/tensor/api/base.rs +++ b/burn-tensor/src/tensor/api/base.rs @@ -2,11 +2,11 @@ use alloc::vec::Vec; -#[cfg(not(target_family = "wasm"))] +#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] use alloc::format; -#[cfg(not(target_family = "wasm"))] +#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] use alloc::string::String; -#[cfg(not(target_family = "wasm"))] +#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] use alloc::vec; use burn_common::{reader::Reader, stub::Mutex}; @@ -325,25 +325,25 @@ where Self::new(K::to_device(self.primitive, device)) } - #[cfg(target_family = "wasm")] + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] /// Returns the data of the current tensor. pub async fn into_data(self) -> Data { K::into_data(self.primitive).read().await } - #[cfg(not(target_family = "wasm"))] + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] /// Returns the data of the current tensor. pub fn into_data(self) -> Data { K::into_data(self.primitive).read() } - #[cfg(target_family = "wasm")] + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] /// Returns the data of the current tensor. pub async fn to_data(&self) -> Data { K::into_data(self.primitive.clone()).read().await } - #[cfg(not(target_family = "wasm"))] + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] /// Returns the data of the current tensor without taking ownership. pub fn to_data(&self) -> Data { Self::into_data(self.clone()) @@ -467,7 +467,7 @@ where K: BasicOps, >::Elem: Debug, { - #[cfg(not(target_family = "wasm"))] + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] #[inline] fn push_newline_indent(acc: &mut String, indent: usize) { acc.push('\n'); @@ -476,7 +476,7 @@ where } } - #[cfg(not(target_family = "wasm"))] + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn fmt_inner_tensor( &self, acc: &mut String, @@ -498,7 +498,7 @@ where } } - #[cfg(not(target_family = "wasm"))] + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn fmt_outer_tensor( &self, acc: &mut String, @@ -533,7 +533,7 @@ where /// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output. /// * `depth` - The current depth of the tensor dimensions being processed. /// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension. - #[cfg(not(target_family = "wasm"))] + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn display_recursive( &self, acc: &mut String, @@ -644,7 +644,7 @@ where fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { writeln!(f, "Tensor {{")?; - #[cfg(not(target_family = "wasm"))] + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] { let po = PRINT_OPTS.lock().unwrap(); let mut acc = String::new(); diff --git a/burn-tensor/src/tensor/api/numeric.rs b/burn-tensor/src/tensor/api/numeric.rs index 786316ce71..fe3d064120 100644 --- a/burn-tensor/src/tensor/api/numeric.rs +++ b/burn-tensor/src/tensor/api/numeric.rs @@ -9,7 +9,7 @@ where K: Numeric, K::Elem: Element, { - #[cfg(not(target_family = "wasm"))] + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] /// Convert the tensor into a scalar. /// /// # Panics @@ -21,7 +21,7 @@ where data.value[0] } - #[cfg(target_family = "wasm")] + #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] /// Convert the tensor into a scalar. /// /// # Panics diff --git a/burn/Cargo.toml b/burn/Cargo.toml index 663d4785fa..238652a05e 100644 --- a/burn/Cargo.toml +++ b/burn/Cargo.toml @@ -18,6 +18,9 @@ std = ["burn-core/std"] # Training with full features train = ["burn-train/default", "autodiff", "dataset"] +# Useful when targeting WASM and not using WGPU. +wasm-sync = ["burn-core/wasm-sync"] + ## Include nothing train-minimal = ["burn-train"] diff --git a/examples/mnist-inference-web/build-for-web.sh b/examples/mnist-inference-web/build-for-web.sh index 99166342ae..7fb9ea1d19 100755 --- a/examples/mnist-inference-web/build-for-web.sh +++ b/examples/mnist-inference-web/build-for-web.sh @@ -1,3 +1,4 @@ +#!/usr/bin/env bash # Add wasm32 target for compiler. rustup target add wasm32-unknown-unknown diff --git a/examples/mnist-inference-web/run-server.sh b/examples/mnist-inference-web/run-server.sh index bc8e94dcd7..0ce038f1d1 100755 --- a/examples/mnist-inference-web/run-server.sh +++ b/examples/mnist-inference-web/run-server.sh @@ -1,3 +1,4 @@ +#!/usr/bin/env bash # Opening index.html file directly by a browser does not work because of # the security restrictions by the browser. Viewing the HTML file will fail with