From 47a81270e11f66a9102bb0112644412cb8cd9f3e Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Fri, 14 Jun 2024 13:12:14 +0100 Subject: [PATCH] Make autodiff compile on wasm (#1889) --- crates/burn-autodiff/src/ops/bool_tensor.rs | 2 ++ crates/burn-autodiff/src/ops/int_tensor.rs | 3 +++ crates/burn-autodiff/src/ops/tensor.rs | 8 +++++--- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/crates/burn-autodiff/src/ops/bool_tensor.rs b/crates/burn-autodiff/src/ops/bool_tensor.rs index 3ecc39ce10..4a4e256882 100644 --- a/crates/burn-autodiff/src/ops/bool_tensor.rs +++ b/crates/burn-autodiff/src/ops/bool_tensor.rs @@ -121,10 +121,12 @@ impl BoolTensorOps for Autodiff { B::bool_flip(tensor, axes) } + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn bool_argwhere(tensor: BoolTensor) -> IntTensor { B::bool_argwhere(tensor) } + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn bool_nonzero(tensor: BoolTensor) -> Vec> { B::bool_nonzero(tensor) } diff --git a/crates/burn-autodiff/src/ops/int_tensor.rs b/crates/burn-autodiff/src/ops/int_tensor.rs index 5e3e44a285..c1e24f3419 100644 --- a/crates/burn-autodiff/src/ops/int_tensor.rs +++ b/crates/burn-autodiff/src/ops/int_tensor.rs @@ -383,6 +383,7 @@ impl IntTensorOps for Autodiff { B::int_expand(tensor, shape) } + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn int_sort( tensor: IntTensor, dim: usize, @@ -391,6 +392,7 @@ impl IntTensorOps for Autodiff { B::int_sort(tensor, dim, descending) } + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn int_sort_with_indices( tensor: IntTensor, dim: usize, @@ -399,6 +401,7 @@ impl IntTensorOps for Autodiff { B::int_sort_with_indices(tensor, dim, descending) } + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn int_argsort( tensor: IntTensor, dim: usize, diff --git a/crates/burn-autodiff/src/ops/tensor.rs b/crates/burn-autodiff/src/ops/tensor.rs index 720ebc983e..3b7d941264 100644 --- a/crates/burn-autodiff/src/ops/tensor.rs +++ b/crates/burn-autodiff/src/ops/tensor.rs @@ -21,7 +21,6 @@ use burn_tensor::{ }; use super::maxmin::MaxMinDim; -use super::sort::SortDim; impl FloatTensorOps for Autodiff { fn float_from_data( @@ -2369,12 +2368,13 @@ impl FloatTensorOps for Autodiff } } + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn float_sort( tensor: FloatTensor, dim: usize, descending: bool, ) -> FloatTensor { - match SortDim + match super::sort::SortDim .prepare::([tensor.node]) .compute_bound() .stateful() @@ -2391,12 +2391,13 @@ impl FloatTensorOps for Autodiff } } + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn float_sort_with_indices( tensor: FloatTensor, dim: usize, descending: bool, ) -> (FloatTensor, IntTensor) { - match SortDim + match super::sort::SortDim .prepare::([tensor.node]) .compute_bound() .stateful() @@ -2419,6 +2420,7 @@ impl FloatTensorOps for Autodiff } } + #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn float_argsort( tensor: FloatTensor, dim: usize,