Implement Element for bool #1878

Merged: 10 commits, Jun 14, 2024
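Context for the diffs below: this PR makes `bool` a tensor `Element`, and along the way replaces `num_traits::ToPrimitive` (whose `to_*` methods return `Option`) with a crate-local `cast::ToElement` trait whose conversions are infallible, which is why the `.unwrap()` calls disappear throughout. `ToElement`'s definition is not part of this diff; the following is a minimal standalone sketch of the shape implied by its usage here (`&dyn ToElement`, `elem.to_f64()`, and so on), not burn-tensor's actual code.

```rust
// Minimal standalone sketch of the assumed ToElement shape; not burn's actual definition.
// Object-safe, with infallible conversions (compare num_traits::ToPrimitive's Option returns).
pub trait ToElement {
    fn to_f64(&self) -> f64;
    fn to_f32(&self) -> f32 {
        self.to_f64() as f32
    }
    fn to_u8(&self) -> u8 {
        self.to_f64() as u8
    }
}

impl ToElement for f32 {
    fn to_f64(&self) -> f64 {
        *self as f64
    }
}

impl ToElement for bool {
    fn to_f64(&self) -> f64 {
        if *self { 1.0 } else { 0.0 }
    }
}

fn main() {
    let x: &dyn ToElement = &true;
    assert_eq!(x.to_f64(), 1.0); // no unwrap() needed
}
```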
6 changes: 2 additions & 4 deletions crates/burn-jit/src/kernel/pool/max_pool2d.rs
@@ -21,8 +21,7 @@ impl<E: JitElement> PoolStrategy for MaxPool<E> {
 
     fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator {
         let max_val = scope.create_local(item);
-        let max_initial =
-            Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), item.elem());
+        let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64(), item.elem());
         cpa!(scope, max_val = max_initial);
         max_val
     }
@@ -68,8 +67,7 @@ impl<E: JitElement> PoolStrategy for MaxPoolWithIndices<E> {
 
     fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator {
         let max_val = scope.create_local(item);
-        let max_initial =
-            Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), item.elem());
+        let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64(), item.elem());
        cpa!(scope, max_val = max_initial);
         let max_index = scope.create_local(Elem::UInt);
         (max_val, max_index)
3 changes: 1 addition & 2 deletions crates/burn-jit/src/kernel/reduce/naive/argmax.rs
@@ -16,8 +16,7 @@ impl<E: JitElement> ReduceDimNaive<E> for Argmax {
     ) -> Self::Accumulator {
         let index = scope.create_local(Elem::UInt);
         let max = scope.create_local(input_item);
-        let max_initial =
-            Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), input_item.elem());
+        let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64(), input_item.elem());
         cpa!(scope, max = max_initial);
 
         (max, index)
3 changes: 1 addition & 2 deletions crates/burn-jit/src/kernel/reduce/naive/argmin.rs
@@ -17,8 +17,7 @@ impl<E: JitElement> ReduceDimNaive<E> for Argmin {
     ) -> Self::Accumulator {
         let index = scope.create_local(Elem::UInt);
         let min = scope.create_local(input_item);
-        let min_initial =
-            Variable::ConstantScalar(E::maximum_value().to_f64().unwrap(), input_item.elem());
+        let min_initial = Variable::ConstantScalar(E::maximum_value().to_f64(), input_item.elem());
         cpa!(scope, min = min_initial);
 
         (min, index)
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/reduce/shared/argmax.rs
@@ -18,7 +18,7 @@ impl<E: JitElement> ReduceDimShared<E> for Argmax {
         let value_shared_memory = scope.create_shared(input_item, shared_memory_size);
         let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size);
 
-        let max = Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), input_item.elem());
+        let max = Variable::ConstantScalar(E::minimum_value().to_f64(), input_item.elem());
         cpa!(scope, value_shared_memory[write_position] = max);
         (value_shared_memory, index_shared_memory)
     }
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/reduce/shared/argmin.rs
@@ -19,7 +19,7 @@ impl<E: JitElement> ReduceDimShared<E> for Argmin {
         let value_shared_memory = scope.create_shared(input_item, shared_memory_size);
         let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size);
 
-        let min = Variable::ConstantScalar(E::maximum_value().to_f64().unwrap(), input_item.elem());
+        let min = Variable::ConstantScalar(E::maximum_value().to_f64(), input_item.elem());
         cpa!(scope, value_shared_memory[write_position] = min);
         (value_shared_memory, index_shared_memory)
     }
2 changes: 0 additions & 2 deletions crates/burn-ndarray/src/element.rs
@@ -1,6 +1,5 @@
 use burn_tensor::Element;
 use ndarray::LinalgScalar;
-use num_traits::One;
 use num_traits::Signed;
 
 #[cfg(not(feature = "std"))]
@@ -20,7 +19,6 @@ where
 /// A general element for ndarray backend.
 pub trait NdArrayElement:
     Element
-    + One
     + ndarray::LinalgScalar
     + ndarray::ScalarOperand
     + ExpElement
12 changes: 7 additions & 5 deletions crates/burn-ndarray/src/ops/base.rs
@@ -563,16 +563,18 @@ where
     where
         E: Signed,
     {
+        let zero = 0.elem();
+        let one = 1.elem::<E>();
         NdArrayTensor::new(
             tensor
                 .array
                 .mapv(|x| {
-                    if x > E::zero() {
-                        E::one()
-                    } else if x < E::zero() {
-                        -E::one()
+                    if x > zero {
+                        one
+                    } else if x < zero {
+                        -one
                     } else {
-                        E::zero()
+                        zero
                     }
                 })
                 .into_shared(),
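The sign hunk above replaces the `num_traits` identities (`E::zero()`, `E::one()`) with burn's `ElementConversion` (`0.elem()`, `1.elem::<E>()`), which is also why the `One` bound could be dropped from `NdArrayElement` earlier. A hedged usage sketch of the same pattern, assuming `burn_tensor` as a dependency; the extra `Copy`/`PartialOrd`/`Neg` bounds are added for self-containment and may be redundant with what `Element` already provides:

```rust
use burn_tensor::{Element, ElementConversion};

// Same sign logic as the diff above, written as a free function for illustration.
fn sign_of<E>(x: E) -> E
where
    E: Element + Copy + PartialOrd + core::ops::Neg<Output = E>,
{
    let zero = 0.elem::<E>();
    let one = 1.elem::<E>();
    if x > zero {
        one
    } else if x < zero {
        -one
    } else {
        zero
    }
}

fn main() {
    assert_eq!(sign_of(-3.5f32), -1.0);
    assert_eq!(sign_of(0.0f32), 0.0);
}
```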
3 changes: 2 additions & 1 deletion crates/burn-ndarray/src/ops/macros.rs
@@ -34,6 +34,7 @@ macro_rules! keepdim {
 }};
 }
 
+use burn_tensor::ElementConversion;
 pub(crate) use keepdim;
 use ndarray::Axis;
 
@@ -63,7 +64,7 @@ pub(crate) fn prod_dim<E: NdArrayElement, const D1: usize, const D2: usize>(
 ) -> NdArrayTensor<E, D2> {
     let array = tensor
         .array
-        .fold_axis(Axis(dim), E::one(), |acc, &x| acc.mul(x.elem()))
+        .fold_axis(Axis(dim), 1.elem::<E>(), |acc, &x| acc.mul(x.elem()))
         .into_shared();
 
     NdArrayTensor { array }
10 changes: 5 additions & 5 deletions crates/burn-ndarray/src/ops/tensor.rs
@@ -406,7 +406,7 @@ impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
     fn float_cos<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
         let array = tensor
             .array
-            .mapv_into(|a| (a.to_f64().unwrap()).cos().elem())
+            .mapv_into(|a| (a.to_f64()).cos().elem())
             .into_shared();
 
         NdArrayTensor::new(array)
@@ -415,7 +415,7 @@ impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
     fn float_sin<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
         let array = tensor
             .array
-            .mapv_into(|a| (a.to_f64().unwrap()).sin().elem())
+            .mapv_into(|a| (a.to_f64()).sin().elem())
             .into_shared();
 
         NdArrayTensor::new(array)
@@ -424,7 +424,7 @@ impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
     fn float_tanh<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
         let array = tensor
             .array
-            .mapv_into(|a| (a.to_f64().unwrap()).tanh().elem())
+            .mapv_into(|a| (a.to_f64()).tanh().elem())
             .into_shared();
 
         NdArrayTensor::new(array)
@@ -433,7 +433,7 @@ impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
     fn float_erf<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
         let array = tensor
             .array
-            .mapv_into(|a| erf(a.to_f64().unwrap()).elem())
+            .mapv_into(|a| erf(a.to_f64()).elem())
             .into_shared();
 
         NdArrayTensor::new(array)
@@ -473,7 +473,7 @@ impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
         lhs: NdArrayTensor<E, D>,
         rhs: NdArrayTensor<E, D>,
     ) -> NdArrayTensor<E, D> {
-        NdArrayMathOps::elementwise_op(lhs, rhs, |a, b| a.powf_elem(b.to_f32().unwrap()))
+        NdArrayMathOps::elementwise_op(lhs, rhs, |a, b| a.powf_elem(b.to_f32()))
     }
 
     fn float_permute<const D: usize>(
3 changes: 2 additions & 1 deletion crates/burn-tensor/src/tensor/api/base.rs
@@ -20,6 +20,7 @@ use serde::{Serialize, Serializer};
 use crate::check::TensorCheck;
 use crate::tensor::api::chunk::chunk;
 use crate::tensor::api::narrow::narrow;
+use crate::Element;
 use crate::{backend::Backend, check, Bool, Data, DataSerialize, Float, Int, Shape, TensorKind};
 
 /// A tensor with a given backend, shape and data type.
@@ -1159,7 +1160,7 @@ impl<B: Backend, const D: usize> core::ops::BitXor<T> for Tensor<B, D> {
 /// This is an internal trait, use the public API provided by [tensor struct](Tensor).
 pub trait BasicOps<B: Backend>: TensorKind<B> {
     /// The type of the tensor elements.
-    type Elem: 'static + Copy;
+    type Elem: Element;
 
     /// Creates an empty tensor with the given shape.
     ///
3 changes: 1 addition & 2 deletions crates/burn-tensor/src/tensor/api/numeric.rs
@@ -6,7 +6,6 @@ use crate::{
     backend::Backend, check, check::TensorCheck, BasicOps, Bool, Distribution, Element,
     ElementConversion, Float, Int, Shape, Tensor, TensorKind,
 };
-use num_traits::Zero;
 
 impl<B, const D: usize, K> Tensor<B, D, K>
 where
@@ -656,7 +655,7 @@ where
     ///
     /// A boolean tensor with the same shape as the input tensor.
     pub fn bool(self) -> Tensor<B, D, Bool> {
-        K::not_equal_elem::<D>(self.primitive, K::Elem::zero())
+        K::not_equal_elem::<D>(self.primitive, 0.elem())
     }
 
     /// Create a random tensor of the given shape on the given device where each element is
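These two burn-tensor changes go together: `BasicOps::Elem` is now bound by `Element` (previous file) rather than `'static + Copy`, so `Tensor::bool` can build its zero threshold with `0.elem()` for any tensor kind instead of requiring `num_traits::Zero`. A standalone sketch of why the stronger bound matters (not burn's code; `ElemLike`/`from_i64` are stand-ins for `Element`/`ElementConversion::from_elem`):

```rust
// Stand-in trait: what `'static + Copy` alone cannot do is produce constants.
trait ElemLike: 'static + Copy + PartialEq {
    fn from_i64(v: i64) -> Self; // analogous to ElementConversion::from_elem
}

impl ElemLike for f32 {
    fn from_i64(v: i64) -> Self {
        v as f32
    }
}

impl ElemLike for bool {
    fn from_i64(v: i64) -> Self {
        v != 0
    }
}

// Analogous to `Tensor::bool`: mark every element that differs from zero.
fn not_equal_zero<E: ElemLike>(values: &[E]) -> Vec<bool> {
    let zero = E::from_i64(0);
    values.iter().map(|v| *v != zero).collect()
}

fn main() {
    assert_eq!(not_equal_zero(&[0.0f32, 2.5, -1.0]), vec![false, true, true]);
}
```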
12 changes: 0 additions & 12 deletions crates/burn-tensor/src/tensor/data.rs
@@ -192,18 +192,6 @@ impl<E: Element> DataSerialize<E> {
     }
 }
 
-impl<const D: usize> Data<bool, D> {
-    /// Converts the data to a different element type.
-    pub fn convert<E: Element>(self) -> Data<E, D> {
-        let value: Vec<E> = self.value.into_iter().map(|a| (a as i64).elem()).collect();
-
-        Data {
-            value,
-            shape: self.shape,
-        }
-    }
-}
-
 impl<E: Element, const D: usize> Data<E, D> {
     /// Populates the data with random values.
     pub fn random<R: RngCore>(shape: Shape<D>, distribution: Distribution, rng: &mut R) -> Self {
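The specialized `Data<bool, D>::convert` can be deleted because, with `bool: Element`, bool data goes through the generic `impl<E: Element, const D: usize> Data<E, D>` path like every other element type. A hypothetical usage sketch, assuming `burn_tensor` as a dependency, that `Data::new`/`Shape::new` and the generic `convert` keep their existing signatures, and that bool's `ToElement` impl maps `true`/`false` to 1/0:

```rust
use burn_tensor::{Data, Shape};

fn bools_to_floats() -> Data<f32, 1> {
    let flags = Data::new(vec![true, false, true], Shape::new([3]));
    // Previously handled by the specialized impl removed above;
    // now covered by the generic `Data::<E, D>::convert`.
    flags.convert::<f32>()
}
```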
@@ -1,16 +1,13 @@
 use core::cmp::Ordering;
 
-use crate::Distribution;
+use crate::{cast::ToElement, Distribution};
 use half::{bf16, f16};
-use num_traits::{identities::Zero, One, ToPrimitive};
 use rand::RngCore;
 use serde::{Deserialize, Serialize};
 
 /// Element trait for tensor.
 pub trait Element:
-    ToPrimitive
-    + Zero
-    + One
+    ToElement
     + ElementRandom
     + ElementConversion
     + ElementPrecision
@@ -38,7 +35,7 @@ pub trait ElementConversion {
     /// # Returns
     ///
     /// The converted element.
-    fn from_elem<E: ToPrimitive>(elem: E) -> Self;
+    fn from_elem<E: ToElement>(elem: E) -> Self;
 
     /// Converts and returns the converted element.
     fn elem<E: Element>(self) -> E;
@@ -105,7 +102,7 @@ macro_rules! make_element {
         }
 
         impl ElementConversion for $type {
-            fn from_elem<E: ToPrimitive>(elem: E) -> Self {
+            fn from_elem<E: ToElement>(elem: E) -> Self {
                 #[allow(clippy::redundant_closure_call)]
                 $convert(&elem)
             }
@@ -140,71 +137,71 @@ macro_rules! make_element {
 
 make_element!(
     ty f64 Precision::Double,
-    convert |elem: &dyn ToPrimitive| elem.to_f64().unwrap(),
+    convert |elem: &dyn ToElement| elem.to_f64(),
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
     cmp |a: &f64, b: &f64| a.total_cmp(b),
     dtype DType::F64
 );
 
 make_element!(
     ty f32 Precision::Full,
-    convert |elem: &dyn ToPrimitive| elem.to_f32().unwrap(),
+    convert |elem: &dyn ToElement| elem.to_f32(),
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
     cmp |a: &f32, b: &f32| a.total_cmp(b),
     dtype DType::F32
 );
 
 make_element!(
     ty i64 Precision::Double,
-    convert |elem: &dyn ToPrimitive| elem.to_i64().unwrap(),
+    convert |elem: &dyn ToElement| elem.to_i64(),
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
     cmp |a: &i64, b: &i64| Ord::cmp(a, b),
     dtype DType::I64
 );
 
 make_element!(
     ty i32 Precision::Full,
-    convert |elem: &dyn ToPrimitive| elem.to_i32().unwrap(),
+    convert |elem: &dyn ToElement| elem.to_i32(),
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
     cmp |a: &i32, b: &i32| Ord::cmp(a, b),
     dtype DType::I32
 );
 
 make_element!(
     ty u32 Precision::Full,
-    convert |elem: &dyn ToPrimitive| elem.to_u32().unwrap(),
+    convert |elem: &dyn ToElement| elem.to_u32(),
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
     cmp |a: &u32, b: &u32| Ord::cmp(a, b),
     dtype DType::U32
 );
 
 make_element!(
     ty i16 Precision::Half,
-    convert |elem: &dyn ToPrimitive| elem.to_i16().unwrap(),
+    convert |elem: &dyn ToElement| elem.to_i16(),
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
     cmp |a: &i16, b: &i16| Ord::cmp(a, b),
     dtype DType::I16
 );
 
 make_element!(
     ty i8 Precision::Other,
-    convert |elem: &dyn ToPrimitive| elem.to_i8().unwrap(),
+    convert |elem: &dyn ToElement| elem.to_i8(),
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
     cmp |a: &i8, b: &i8| Ord::cmp(a, b),
     dtype DType::I8
 );
 
 make_element!(
     ty u8 Precision::Other,
-    convert |elem: &dyn ToPrimitive| elem.to_u8().unwrap(),
+    convert |elem: &dyn ToElement| elem.to_u8(),
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
     cmp |a: &u8, b: &u8| Ord::cmp(a, b),
     dtype DType::U8
 );
 
 make_element!(
     ty f16 Precision::Half,
-    convert |elem: &dyn ToPrimitive| f16::from_f32(elem.to_f32().unwrap()),
+    convert |elem: &dyn ToElement| f16::from_f32(elem.to_f32()),
     random |distribution: Distribution, rng: &mut R| {
         let sample: f32 = distribution.sampler(rng).sample();
         f16::from_elem(sample)
@@ -214,7 +211,7 @@ make_element!(
 );
 make_element!(
     ty bf16 Precision::Half,
-    convert |elem: &dyn ToPrimitive| bf16::from_f32(elem.to_f32().unwrap()),
+    convert |elem: &dyn ToElement| bf16::from_f32(elem.to_f32()),
     random |distribution: Distribution, rng: &mut R| {
         let sample: f32 = distribution.sampler(rng).sample();
         bf16::from_elem(sample)
@@ -223,6 +220,17 @@ make_element!(
     dtype DType::BF16
 );
 
+make_element!(
+    ty bool Precision::Other,
+    convert |elem: &dyn ToElement| elem.to_u8() != 0,
+    random |distribution: Distribution, rng: &mut R| {
+        let sample: u8 = distribution.sampler(rng).sample();
+        bool::from_elem(sample)
+    },
+    cmp |a: &bool, b: &bool| Ord::cmp(a, b),
+    dtype DType::Bool
+);
+
 #[allow(missing_docs)]
 #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
 pub enum DType {
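With the `make_element!` invocation above, `bool` participates in the same conversion machinery as the numeric element types. A hypothetical usage sketch, assuming `burn_tensor` as a dependency: the bool-to-element direction additionally assumes bool's `ToElement` impl (not shown in this section) maps `true`/`false` to 1/0.

```rust
use burn_tensor::ElementConversion;

fn main() {
    // Convert rule from the diff: any element whose to_u8() is non-zero becomes true.
    let t: bool = 3.elem();
    let f: bool = 0.elem();
    assert!(t);
    assert!(!f);

    // The reverse direction also goes through ElementConversion now that bool is an Element.
    let one: f32 = true.elem();
    assert_eq!(one, 1.0);
}
```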