Skip to content

Commit

Permalink
Add a sync feature to common, core, and tensor (#893)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexErrant committed Oct 24, 2023
1 parent d021c7d commit 9f2bc59
Show file tree
Hide file tree
Showing 11 changed files with 42 additions and 34 deletions.
1 change: 1 addition & 0 deletions burn-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ default = ["std"]

std = ["rand/std"]

wasm-sync = []

[target.'cfg(target_family = "wasm")'.dependencies]
async-trait = { workspace = true }
Expand Down
20 changes: 10 additions & 10 deletions burn-common/src/reader.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use alloc::boxed::Box;
use core::marker::PhantomData;

#[cfg(target_family = "wasm")]
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
#[async_trait::async_trait]
/// Allows to create async reader.
pub trait AsyncReader<T>: Send {
Expand All @@ -15,10 +15,10 @@ pub enum Reader<T> {
Concrete(T),
/// Sync data variant.
Sync(Box<dyn SyncReader<T>>),
#[cfg(target_family = "wasm")]
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
/// Async data variant.
Async(Box<dyn AsyncReader<T>>),
#[cfg(target_family = "wasm")]
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
/// Future data variant.
Future(core::pin::Pin<Box<dyn core::future::Future<Output = T> + Send>>),
}
Expand Down Expand Up @@ -52,7 +52,7 @@ where
}
}

#[cfg(target_family = "wasm")]
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
#[async_trait::async_trait]
impl<I, O, F> AsyncReader<O> for MappedReader<I, O, F>
where
Expand All @@ -67,7 +67,7 @@ where
}

impl<T> Reader<T> {
#[cfg(target_family = "wasm")]
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
/// Read the data.
pub async fn read(self) -> T {
match self {
Expand All @@ -78,7 +78,7 @@ impl<T> Reader<T> {
}
}

#[cfg(not(target_family = "wasm"))]
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
/// Read the data.
pub fn read(self) -> T {
match self {
Expand All @@ -92,9 +92,9 @@ impl<T> Reader<T> {
match self {
Self::Concrete(data) => Some(data),
Self::Sync(reader) => Some(reader.read()),
#[cfg(target_family = "wasm")]
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
Self::Async(_func) => return None,
#[cfg(target_family = "wasm")]
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
Self::Future(_future) => return None,
}
}
Expand All @@ -106,10 +106,10 @@ impl<T> Reader<T> {
O: 'static + Send,
F: 'static + Send,
{
#[cfg(target_family = "wasm")]
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
return Reader::Async(Box::new(MappedReader::new(self, mapper)));

#[cfg(not(target_family = "wasm"))]
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
Reader::Sync(Box::new(MappedReader::new(self, mapper)))
}
}
2 changes: 2 additions & 0 deletions burn-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ dataset-minimal = ["burn-dataset"]
dataset-sqlite = ["burn-dataset/sqlite"]
dataset-sqlite-bundled = ["burn-dataset/sqlite-bundled"]

wasm-sync = ["burn-tensor/wasm-sync", "burn-common/wasm-sync"]

# Backend
autodiff = ["burn-autodiff"]

Expand Down
6 changes: 3 additions & 3 deletions burn-core/src/grad_clipping/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ impl GradientClipping {
clipped_grad.mask_fill(lower_mask, -threshold)
}

#[cfg(target_family = "wasm")]
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
fn clip_by_norm<B: Backend, const D: usize>(
&self,
_grad: Tensor<B, D>,
Expand All @@ -77,7 +77,7 @@ impl GradientClipping {
todo!("Not yet supported on wasm");
}

#[cfg(not(target_family = "wasm"))]
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn clip_by_norm<B: Backend, const D: usize>(
&self,
grad: Tensor<B, D>,
Expand All @@ -96,7 +96,7 @@ impl GradientClipping {
}
}

#[cfg(not(target_family = "wasm"))]
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn l2_norm<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, 1> {
let squared = tensor.powf(2.0);
let sum = squared.sum();
Expand Down
13 changes: 6 additions & 7 deletions burn-core/src/record/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ impl<'de, S: PrecisionSettings> Deserialize<'de> for FloatTensorSerde<S> {
}
}

// #[cfg(not(target_family = "wasm"))]
impl<S: PrecisionSettings> Serialize for IntTensorSerde<S> {
fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
where
Expand Down Expand Up @@ -90,10 +89,10 @@ impl<B: Backend, const D: usize> Record for Tensor<B, D> {
type Item<S: PrecisionSettings> = FloatTensorSerde<S>;

fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
#[cfg(target_family = "wasm")]
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
todo!("Recording float tensors isn't yet supported on wasm.");

#[cfg(not(target_family = "wasm"))]
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
FloatTensorSerde::new(self.into_data().convert().serialize())
}

Expand All @@ -106,10 +105,10 @@ impl<B: Backend, const D: usize> Record for Tensor<B, D, Int> {
type Item<S: PrecisionSettings> = IntTensorSerde<S>;

fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
#[cfg(target_family = "wasm")]
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
todo!("Recording int tensors isn't yet supported on wasm.");

#[cfg(not(target_family = "wasm"))]
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
IntTensorSerde::new(self.into_data().convert().serialize())
}

Expand All @@ -122,10 +121,10 @@ impl<B: Backend, const D: usize> Record for Tensor<B, D, Bool> {
type Item<S: PrecisionSettings> = BoolTensorSerde;

fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
#[cfg(target_family = "wasm")]
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
todo!("Recording bool tensors isn't yet supported on wasm.");

#[cfg(not(target_family = "wasm"))]
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
BoolTensorSerde::new(self.into_data().serialize())
}

Expand Down
1 change: 1 addition & 0 deletions burn-tensor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ experimental-named-tensor = []
export_tests = ["burn-tensor-testgen"]
std = ["rand/std", "half/std"]
benchmark = []
wasm-sync = []

[dependencies]
burn-common = { path = "../burn-common", version = "0.10.0", default-features = false }
Expand Down
24 changes: 12 additions & 12 deletions burn-tensor/src/tensor/api/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

use alloc::vec::Vec;

#[cfg(not(target_family = "wasm"))]
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
use alloc::format;
#[cfg(not(target_family = "wasm"))]
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
use alloc::string::String;
#[cfg(not(target_family = "wasm"))]
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
use alloc::vec;

use burn_common::{reader::Reader, stub::Mutex};
Expand Down Expand Up @@ -325,25 +325,25 @@ where
Self::new(K::to_device(self.primitive, device))
}

#[cfg(target_family = "wasm")]
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
/// Returns the data of the current tensor.
pub async fn into_data(self) -> Data<K::Elem, D> {
K::into_data(self.primitive).read().await
}

#[cfg(not(target_family = "wasm"))]
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
/// Returns the data of the current tensor.
pub fn into_data(self) -> Data<K::Elem, D> {
K::into_data(self.primitive).read()
}

#[cfg(target_family = "wasm")]
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
/// Returns the data of the current tensor.
pub async fn to_data(&self) -> Data<K::Elem, D> {
K::into_data(self.primitive.clone()).read().await
}

#[cfg(not(target_family = "wasm"))]
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
/// Returns the data of the current tensor without taking ownership.
pub fn to_data(&self) -> Data<K::Elem, D> {
Self::into_data(self.clone())
Expand Down Expand Up @@ -467,7 +467,7 @@ where
K: BasicOps<B>,
<K as BasicOps<B>>::Elem: Debug,
{
#[cfg(not(target_family = "wasm"))]
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
#[inline]
fn push_newline_indent(acc: &mut String, indent: usize) {
acc.push('\n');
Expand All @@ -476,7 +476,7 @@ where
}
}

#[cfg(not(target_family = "wasm"))]
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn fmt_inner_tensor(
&self,
acc: &mut String,
Expand All @@ -498,7 +498,7 @@ where
}
}

#[cfg(not(target_family = "wasm"))]
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn fmt_outer_tensor(
&self,
acc: &mut String,
Expand Down Expand Up @@ -533,7 +533,7 @@ where
/// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output.
/// * `depth` - The current depth of the tensor dimensions being processed.
/// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension.
#[cfg(not(target_family = "wasm"))]
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn display_recursive(
&self,
acc: &mut String,
Expand Down Expand Up @@ -644,7 +644,7 @@ where
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
writeln!(f, "Tensor {{")?;

#[cfg(not(target_family = "wasm"))]
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
{
let po = PRINT_OPTS.lock().unwrap();
let mut acc = String::new();
Expand Down
4 changes: 2 additions & 2 deletions burn-tensor/src/tensor/api/numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ where
K: Numeric<B>,
K::Elem: Element,
{
#[cfg(not(target_family = "wasm"))]
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
/// Convert the tensor into a scalar.
///
/// # Panics
Expand All @@ -21,7 +21,7 @@ where
data.value[0]
}

#[cfg(target_family = "wasm")]
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
/// Convert the tensor into a scalar.
///
/// # Panics
Expand Down
3 changes: 3 additions & 0 deletions burn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ std = ["burn-core/std"]
# Training with full features
train = ["burn-train/default", "autodiff", "dataset"]

# Useful when targeting WASM and not using WGPU.
wasm-sync = ["burn-core/wasm-sync"]

## Include nothing
train-minimal = ["burn-train"]

Expand Down
1 change: 1 addition & 0 deletions examples/mnist-inference-web/build-for-web.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#!/usr/bin/env bash

# Add wasm32 target for compiler.
rustup target add wasm32-unknown-unknown
Expand Down
1 change: 1 addition & 0 deletions examples/mnist-inference-web/run-server.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#!/usr/bin/env bash

# Opening index.html file directly by a browser does not work because of
# the security restrictions by the browser. Viewing the HTML file will fail with
Expand Down

0 comments on commit 9f2bc59

Please sign in to comment.