Implement Element for bool #1878

Merged · 10 commits · Jun 14, 2024
crates/burn-ndarray/src/element.rs — 0 additions, 2 deletions

@@ -1,6 +1,5 @@
 use burn_tensor::Element;
 use ndarray::LinalgScalar;
-use num_traits::One;
 use num_traits::Signed;

 #[cfg(not(feature = "std"))]
@@ -20,7 +19,6 @@ where
 /// A general element for ndarray backend.
 pub trait NdArrayElement:
     Element
-    + One
     + ndarray::LinalgScalar
     + ndarray::ScalarOperand
     + ExpElement
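Dropping `num_traits::One` from the supertrait list works because callers now obtain constants through burn's `ElementConversion` (`1.elem::<E>()`) rather than `E::one()`. A minimal, self-contained sketch of that conversion pattern — `ElemConvert` and its impl here are illustrative stand-ins, not burn's real definitions:

```rust
// Illustrative stand-in for the `elem()`-style conversion used in this PR:
// a literal is converted into the target element type on demand, so element
// traits no longer need `num_traits::One`/`Zero` as supertraits.
trait ElemConvert {
    fn elem<E: From<u8>>(self) -> E;
}

impl ElemConvert for i32 {
    fn elem<E: From<u8>>(self) -> E {
        // Fine for the 0/1 literals used here; a real impl converts generally.
        E::from(self as u8)
    }
}

fn main() {
    let zero: f64 = 0.elem();
    let one = 1.elem::<f64>();
    assert_eq!((zero, one), (0.0, 1.0));
}
```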
crates/burn-ndarray/src/ops/base.rs — 7 additions, 5 deletions

@@ -567,12 +567,14 @@ where
         tensor
             .array
             .mapv(|x| {
-                if x > E::zero() {
-                    E::one()
-                } else if x < E::zero() {
-                    -E::one()
+                let zero = 0.elem();
+                let one = 1.elem::<E>();
+                if x > zero {
+                    one
+                } else if x < zero {
+                    -one
                 } else {
-                    E::zero()
+                    zero
                 }
             })
             .into_shared(),
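For concreteness, here is the rewritten sign computation applied to a plain `f32` array — a runnable sketch with the generic `E` fixed to `f32`, so `0.elem()`/`1.elem::<E>()` become literals:

```rust
use ndarray::array;

fn main() {
    let t = array![[-2.0_f32, 0.0], [3.5, -0.25]];
    // Same three-way branch as the patched mapv closure, with E = f32.
    let sign = t.mapv(|x| {
        let (zero, one) = (0.0_f32, 1.0);
        if x > zero {
            one
        } else if x < zero {
            -one
        } else {
            zero
        }
    });
    assert_eq!(sign, array![[-1.0, 0.0], [1.0, -1.0]]);
}
```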
crates/burn-ndarray/src/ops/macros.rs — 2 additions, 1 deletion

@@ -34,6 +34,7 @@ macro_rules! keepdim {
     }};
 }

+use burn_tensor::ElementConversion;
 pub(crate) use keepdim;
 use ndarray::Axis;

@@ -63,7 +64,7 @@ pub(crate) fn prod_dim<E: NdArrayElement, const D1: usize, const D2: usize>(
 ) -> NdArrayTensor<E, D2> {
     let array = tensor
         .array
-        .fold_axis(Axis(dim), E::one(), |acc, &x| acc.mul(x.elem()))
+        .fold_axis(Axis(dim), 1.elem::<E>(), |acc, &x| acc.mul(x.elem()))
         .into_shared();

     NdArrayTensor { array }
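The `prod_dim` change swaps `E::one()` for `1.elem::<E>()` as the fold seed. The same `fold_axis` call on a concrete `f64` array (a sketch with literals in place of the generic conversions):

```rust
use ndarray::{array, Axis};

fn main() {
    let t = array![[1.0_f64, 2.0], [3.0, 4.0]];
    // Product along axis 0, seeded with 1.0 just as prod_dim seeds the fold
    // with `1.elem::<E>()` after this change.
    let prod = t.fold_axis(Axis(0), 1.0, |acc, &x| acc * x);
    assert_eq!(prod, array![3.0, 8.0]);
}
```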
crates/burn-tensor/src/tensor/api/base.rs — 2 additions, 1 deletion

@@ -20,6 +20,7 @@ use serde::{Serialize, Serializer};
 use crate::check::TensorCheck;
 use crate::tensor::api::chunk::chunk;
 use crate::tensor::api::narrow::narrow;
+use crate::Element;
 use crate::{backend::Backend, check, Bool, Data, DataSerialize, Float, Int, Shape, TensorKind};

 /// A tensor with a given backend, shape and data type.
@@ -1159,7 +1160,7 @@ impl<B: Backend, const D: usize> core::ops::BitXor<T> for Tensor<B, D> {
 /// This is an internal trait, use the public API provided by [tensor struct](Tensor).
 pub trait BasicOps<B: Backend>: TensorKind<B> {
     /// The type of the tensor elements.
-    type Elem: 'static + Copy;
+    type Elem: Element;

     /// Creates an empty tensor with the given shape.
     ///
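Tightening `type Elem` from `'static + Copy` to `Element` is what lets generic tensor code (like `bool()` below) compare and convert elements. A hypothetical miniature of the idea — `Element`, `BasicOps`, and `FloatKind` here are simplified stand-ins, not burn's real traits:

```rust
// With only `'static + Copy`, generic code could not compare elements or
// build one from a literal; an `Element`-style bound provides both.
trait Element: 'static + Copy + PartialOrd {
    fn from_u8(value: u8) -> Self;
}

impl Element for f32 {
    fn from_u8(value: u8) -> Self {
        value as f32
    }
}

trait BasicOps {
    type Elem: Element; // previously just `'static + Copy`
}

struct FloatKind;

impl BasicOps for FloatKind {
    type Elem = f32;
}

// Compiles only because the associated type now carries Element's bounds.
fn any_nonzero<K: BasicOps>(values: &[K::Elem]) -> bool {
    values.iter().any(|&v| v != K::Elem::from_u8(0))
}

fn main() {
    assert!(any_nonzero::<FloatKind>(&[0.0, 0.5]));
}
```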
crates/burn-tensor/src/tensor/api/numeric.rs — 1 addition, 2 deletions

@@ -6,7 +6,6 @@ use crate::{
     backend::Backend, check, check::TensorCheck, BasicOps, Bool, Distribution, Element,
     ElementConversion, Float, Int, Shape, Tensor, TensorKind,
 };
-use num_traits::Zero;

 impl<B, const D: usize, K> Tensor<B, D, K>
 where
@@ -656,7 +655,7 @@ where
     ///
     /// A boolean tensor with the same shape as the input tensor.
     pub fn bool(self) -> Tensor<B, D, Bool> {
-        K::not_equal_elem::<D>(self.primitive, K::Elem::zero())
+        K::not_equal_elem::<D>(self.primitive, 0.elem())
     }

     /// Create a random tensor of the given shape on the given device where each element is
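The semantics of `bool()` are unchanged: every element is compared against zero, only now the zero comes from `0.elem()` rather than the removed `Zero` trait. Sketched on a plain slice, with `f32` standing in for the tensor primitive:

```rust
// `bool()` boils down to element-wise `x != 0`; the patch only changes how
// the zero constant is produced.
fn to_bool(values: &[f32]) -> Vec<bool> {
    let zero: f32 = 0.0; // the real code writes `0.elem()`
    values.iter().map(|&x| x != zero).collect()
}

fn main() {
    assert_eq!(to_bool(&[0.0, 1.5, -2.0]), vec![false, true, true]);
}
```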
crates/burn-tensor/src/tensor/data.rs — 0 additions, 12 deletions

@@ -192,18 +192,6 @@ impl<E: Element> DataSerialize<E> {
     }
 }

-impl<const D: usize> Data<bool, D> {
-    /// Converts the data to a different element type.
-    pub fn convert<E: Element>(self) -> Data<E, D> {
-        let value: Vec<E> = self.value.into_iter().map(|a| (a as i64).elem()).collect();
-
-        Data {
-            value,
-            shape: self.shape,
-        }
-    }
-}
-
 impl<E: Element, const D: usize> Data<E, D> {
     /// Populates the data with random values.
     pub fn random<R: RngCore>(shape: Shape<D>, distribution: Distribution, rng: &mut R) -> Self {
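The dedicated `Data<bool, D>` impl becomes dead weight once `bool: Element`: the blanket `impl<E: Element, const D: usize> Data<E, D>` now covers booleans too. A toy model of that subsumption (the types here are stand-ins, not burn's):

```rust
// Once bool implements Element, one blanket impl serves every element type,
// so hand-written Data<bool, D> methods can be deleted.
trait Element: Copy {}
impl Element for bool {}
impl Element for f32 {}

struct Data<E, const D: usize> {
    value: Vec<E>,
}

impl<E: Element, const D: usize> Data<E, D> {
    // One generic impl now serves bool and the numeric elements alike.
    fn num_elements(&self) -> usize {
        self.value.len()
    }
}

fn main() {
    let bools = Data::<bool, 1> { value: vec![true, false] };
    let floats = Data::<f32, 2> { value: vec![1.0; 4] };
    assert_eq!(bools.num_elements() + floats.num_elements(), 6);
}
```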
Element trait definition (file path not shown in this capture)

@@ -1,16 +1,13 @@
 use core::cmp::Ordering;

-use crate::Distribution;
+use crate::{cast::ToPrimitive, Distribution};
 use half::{bf16, f16};
-use num_traits::{identities::Zero, One, ToPrimitive};
 use rand::RngCore;
 use serde::{Deserialize, Serialize};

 /// Element trait for tensor.
 pub trait Element:
     ToPrimitive
-    + Zero
-    + One
     + ElementRandom
     + ElementConversion
     + ElementPrecision
@@ -223,6 +220,17 @@ make_element!(
     dtype DType::BF16
 );

+make_element!(
+    ty bool Precision::Other,
+    convert |elem: &dyn ToPrimitive| elem.to_u8().unwrap() != 0,
+    random |distribution: Distribution, rng: &mut R| {
+        let sample: u8 = distribution.sampler(rng).sample();
+        bool::from_elem(sample)
+    },
+    cmp |a: &bool, b: &bool| Ord::cmp(a, b),
+    dtype DType::Bool
+);
+
 #[allow(missing_docs)]
 #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
 pub enum DType {
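The `convert` closure in the new `make_element!` arm defines how every other element type maps into `bool`: anything that converts to a nonzero `u8` becomes `true`. The same logic in isolation — this sketch uses `num_traits::ToPrimitive` directly, whereas the diff reaches it through `crate::cast::ToPrimitive`:

```rust
use num_traits::ToPrimitive;

// Mirrors the `convert` closure above: nonzero maps to true, zero to false.
// Note it panics (`unwrap`) for values with no u8 representation, e.g. -1.
fn bool_from_elem(elem: &dyn ToPrimitive) -> bool {
    elem.to_u8().unwrap() != 0
}

fn main() {
    assert!(bool_from_elem(&1.0_f32));
    assert!(bool_from_elem(&7_u8));
    assert!(!bool_from_elem(&0_i64));
}
```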