Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Element for bool #1878

Merged
merged 10 commits into from
Jun 14, 2024
2 changes: 0 additions & 2 deletions crates/burn-ndarray/src/element.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use burn_tensor::Element;
use ndarray::LinalgScalar;
use num_traits::One;
use num_traits::Signed;

#[cfg(not(feature = "std"))]
Expand All @@ -20,7 +19,6 @@ where
/// A general element for ndarray backend.
pub trait NdArrayElement:
Element
+ One
laggui marked this conversation as resolved.
Show resolved Hide resolved
+ ndarray::LinalgScalar
+ ndarray::ScalarOperand
+ ExpElement
Expand Down
10 changes: 5 additions & 5 deletions crates/burn-ndarray/src/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -567,12 +567,12 @@ where
tensor
.array
.mapv(|x| {
if x > E::zero() {
E::one()
} else if x < E::zero() {
-E::one()
if x > <E as burn_tensor::identities::Zero>::zero() {
<E as burn_tensor::identities::One>::one()
} else if x < <E as burn_tensor::identities::Zero>::zero() {
-<E as burn_tensor::identities::One>::one()
} else {
E::zero()
<E as burn_tensor::identities::Zero>::zero()
}
})
.into_shared(),
Expand Down
6 changes: 5 additions & 1 deletion crates/burn-ndarray/src/ops/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,11 @@ pub(crate) fn prod_dim<E: NdArrayElement, const D1: usize, const D2: usize>(
) -> NdArrayTensor<E, D2> {
let array = tensor
.array
.fold_axis(Axis(dim), E::one(), |acc, &x| acc.mul(x.elem()))
.fold_axis(
Axis(dim),
<E as burn_tensor::identities::One>::one(),
|acc, &x| acc.mul(x.elem()),
)
.into_shared();

NdArrayTensor { array }
Expand Down
5 changes: 2 additions & 3 deletions crates/burn-tensor/src/tensor/api/numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@ use alloc::vec::Vec;
use crate::alloc::borrow::ToOwned;

use crate::{
backend::Backend, check, check::TensorCheck, BasicOps, Bool, Distribution, Element,
ElementConversion, Float, Int, Shape, Tensor, TensorKind,
backend::Backend, check, check::TensorCheck, identities::Zero, BasicOps, Bool, Distribution,
Element, ElementConversion, Float, Int, Shape, Tensor, TensorKind,
};
use num_traits::Zero;

impl<B, const D: usize, K> Tensor<B, D, K>
where
Expand Down
12 changes: 0 additions & 12 deletions crates/burn-tensor/src/tensor/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,18 +192,6 @@ impl<E: Element> DataSerialize<E> {
}
}

impl<const D: usize> Data<bool, D> {
    /// Converts boolean data into data of a different element type `E`.
    ///
    /// Each `bool` is first widened to an `i64` (`false` -> 0, `true` -> 1)
    /// and then converted to `E` through `ElementConversion::elem`.
    pub fn convert<E: Element>(self) -> Data<E, D> {
        // Destructure up front so the shape moves through untouched.
        let Data { value, shape } = self;

        let converted: Vec<E> = value
            .into_iter()
            .map(|flag| i64::from(flag).elem())
            .collect();

        Data {
            value: converted,
            shape,
        }
    }
}

impl<E: Element, const D: usize> Data<E, D> {
/// Populates the data with random values.
pub fn random<R: RngCore>(shape: Shape<D>, distribution: Distribution, rng: &mut R) -> Self {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use core::cmp::Ordering;

use crate::Distribution;
use crate::{
cast::ToPrimitive,
identities::{One, Zero},
Distribution,
};
use half::{bf16, f16};
use num_traits::{identities::Zero, One, ToPrimitive};
use rand::RngCore;
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -223,6 +226,17 @@ make_element!(
dtype DType::BF16
);

// Registers `bool` as a tensor `Element` through the `make_element!` macro,
// providing its conversion, random-sampling, comparison, and dtype hooks.
make_element!(
    // `Precision::Other`: bool does not belong to a floating-point precision class.
    ty bool Precision::Other,
    // Any nonzero source value maps to `true`; zero maps to `false`.
    // NOTE(review): `to_u8().unwrap()` assumes the incoming value is representable
    // as `u8` (`to_u8` returns `None` for negative or out-of-range values) — confirm
    // callers never convert such values, or this will panic.
    convert |elem: &dyn ToPrimitive| elem.to_u8().unwrap() != 0,
    // Draw a `u8` sample from the distribution, then reuse the `convert` path
    // (via `from_elem`) to collapse it to a bool.
    random |distribution: Distribution, rng: &mut R| {
        let sample: u8 = distribution.sampler(rng).sample();
        bool::from_elem(sample)
    },
    // Total order on bool: `false < true` (std `Ord` impl).
    cmp |a: &bool, b: &bool| Ord::cmp(a, b),
    // Dedicated dtype tag so backends can distinguish boolean tensor data.
    dtype DType::Bool
);

#[allow(missing_docs)]
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum DType {
Expand Down
Loading
Loading