Skip to content

Commit

Permalink
Pretty Print Tensors (#257)
Browse files Browse the repository at this point in the history
  • Loading branch information
agelas committed Apr 7, 2023
1 parent ca8ee07 commit d8f64ce
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 3 deletions.
2 changes: 1 addition & 1 deletion burn-ndarray/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ mod tests {
type TestBackend = crate::NdArrayBackend<f32>;
type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, burn_tensor::Int>;

use alloc::format;
burn_tensor::testgen_all!();

#[cfg(feature = "std")]
Expand Down
90 changes: 88 additions & 2 deletions burn-tensor/src/tensor/api/base.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use alloc::format;
use alloc::string::String;
use alloc::vec;
use alloc::vec::Vec;
use core::ops::Range;
use core::{fmt::Debug, ops::Range};

use crate::{backend::Backend, Bool, Data, Float, Int, Shape, TensorKind};

Expand Down Expand Up @@ -266,13 +269,93 @@ where
}
}

impl<B, const D: usize, K> Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
    <K as BasicOps<B>>::Elem: Debug,
{
    /// Recursively formats the tensor data for display and appends it to the provided accumulator string.
    ///
    /// This function is designed to work with tensors of any dimensionality.
    /// It traverses the tensor dimensions recursively, converting the elements
    /// to strings and appending them to the accumulator string with the
    /// appropriate formatting.
    ///
    /// # Arguments
    ///
    /// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output.
    /// * `depth` - The current depth of the tensor dimensions being processed.
    /// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension.
    fn display_recursive(&self, acc: &mut String, depth: usize, multi_index: &mut [usize]) {
        if depth == 0 {
            acc.push('[');
        }

        // The shape is invariant for the whole traversal; fetch it once instead
        // of calling `self.dims()` in the branch condition and on every loop step.
        let dims = self.dims();

        if depth == dims.len() - 1 {
            // Innermost dimension: push the scalar elements into the accumulator.
            for i in 0..dims[depth] {
                if i > 0 {
                    acc.push_str(", ");
                }
                multi_index[depth] = i;
                // Build a D-dimensional unit range selecting exactly one element
                // at the current multi-index.
                let range: [core::ops::Range<usize>; D] =
                    core::array::from_fn(|dim| multi_index[dim]..multi_index[dim] + 1);
                // NOTE(review): this clones the whole tensor once per element,
                // which is O(elements) clones for the full print; acceptable for
                // a Display path, but worth revisiting if printing huge tensors.
                let elem = &self.clone().index(range).to_data().value[0];
                acc.push_str(&format!("{:?}", elem));
            }
        } else {
            // Outer dimension: bracket each sub-tensor and recurse one level deeper.
            for i in 0..dims[depth] {
                if i > 0 {
                    acc.push_str(", ");
                }
                acc.push('[');
                multi_index[depth] = i;
                self.display_recursive(acc, depth + 1, multi_index);
                acc.push(']');
            }
        }

        if depth == 0 {
            acc.push(']');
        }
    }
}

/// Pretty print tensors
impl<B, const D: usize, K> core::fmt::Display for Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
    <K as BasicOps<B>>::Elem: Debug,
{
    /// Renders the tensor as a struct-like listing: data, shape, device,
    /// backend name, kind and element dtype.
    ///
    /// Note: the previous `B::IntElem: core::fmt::Display` bound was removed —
    /// the body only uses `{:?}` (Debug) and `&str` formatting, so the bound
    /// needlessly narrowed which backends could pretty-print. Removing an
    /// unused bound only broadens the impl and cannot break existing callers.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        writeln!(f, "Tensor {{")?;
        write!(f, " data: ")?;

        // Format the (possibly multi-dimensional) data via the recursive helper.
        let mut acc = String::new();
        let mut multi_index = vec![0; D];
        self.display_recursive(&mut acc, 0, &mut multi_index);
        write!(f, "{}", acc)?;
        writeln!(f, ",")?;

        writeln!(f, " shape: {:?},", self.dims())?;
        writeln!(f, " device: {:?},", self.device())?;
        writeln!(f, " backend: {:?},", B::name())?;
        writeln!(f, " kind: {:?},", K::name())?;
        writeln!(f, " dtype: {:?},", K::elem_type_name())?;
        write!(f, "}}")
    }
}

/// Trait that list all operations that can be applied on all tensors.
///
/// # Warnings
///
/// This is an internal trait, use the public API provided by [tensor struct](Tensor).
pub trait BasicOps<B: Backend>: TensorKind<B> {
type Elem;
type Elem: 'static;

fn empty<const D: usize>(shape: Shape<D>, device: &B::Device) -> Self::Primitive<D>;
fn shape<const D: usize>(tensor: &Self::Primitive<D>) -> Shape<D>;
Expand Down Expand Up @@ -310,6 +393,9 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
rhs: Self::Primitive<D>,
) -> Tensor<B, D, Bool>;
fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool>;
fn elem_type_name() -> &'static str {
core::any::type_name::<Self::Elem>()
}
}

impl<B: Backend> BasicOps<B> for Float {
Expand Down
10 changes: 10 additions & 0 deletions burn-tensor/src/tensor/api/kind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,26 @@ pub struct Bool;

/// Kind marker trait implemented by [`Float`], [`Int`] and [`Bool`], tying a
/// tensor kind to the backend primitive that stores it.
pub trait TensorKind<B: Backend>: Clone + core::fmt::Debug {
    /// Backend storage primitive for a `D`-dimensional tensor of this kind.
    type Primitive<const D: usize>: Clone + core::fmt::Debug;
    /// Human-readable kind name (e.g. used by `Display` for tensors).
    fn name() -> &'static str;
}

/// Float tensors are backed by the backend's float primitive.
impl<B: Backend> TensorKind<B> for Float {
    type Primitive<const D: usize> = B::TensorPrimitive<D>;
    fn name() -> &'static str {
        "Float"
    }
}

/// Int tensors are backed by the backend's integer primitive.
impl<B: Backend> TensorKind<B> for Int {
    type Primitive<const D: usize> = B::IntTensorPrimitive<D>;
    fn name() -> &'static str {
        "Int"
    }
}

/// Bool tensors are backed by the backend's boolean primitive.
impl<B: Backend> TensorKind<B> for Bool {
    type Primitive<const D: usize> = B::BoolTensorPrimitive<D>;
    fn name() -> &'static str {
        "Bool"
    }
}
81 changes: 81 additions & 0 deletions burn-tensor/src/tests/stats/basic.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#[burn_tensor_testgen::testgen(stats)]
mod tests {
use super::*;
use burn_tensor::backend::Backend;
use burn_tensor::{Data, Tensor};

#[test]
Expand All @@ -13,4 +14,84 @@ mod tests {
let data_expected = Data::from([[2.4892], [15.3333]]);
data_expected.assert_approx_eq(&data_actual, 3);
}

#[test]
fn test_display_2d_int_tensor() {
    // A 3x3 integer tensor should render its data row-major, plus metadata.
    let tensor_int: burn_tensor::Tensor<TestBackend, 2, burn_tensor::Int> =
        Tensor::from_data(Data::from([[1, 2, 3], [4, 5, 6], [7, 8, 9]]));

    let rendered = format!("{}", tensor_int);
    let want = format!(
        "Tensor {{\n data: [[1, 2, 3], [4, 5, 6], [7, 8, 9]],\n shape: [3, 3],\n device: Cpu,\n backend: \"{}\",\n kind: \"Int\",\n dtype: \"i64\",\n}}",
        TestBackend::name()
    );
    assert_eq!(rendered, want);
}

#[test]
fn test_display_2d_float_tensor() {
    // A 3x3 float tensor should render its data row-major, plus metadata.
    let tensor_float: burn_tensor::Tensor<TestBackend, 2, burn_tensor::Float> =
        Tensor::from_data(Data::from([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6], [7.7, 8.8, 9.9]]));

    let rendered = format!("{}", tensor_float);
    let want = format!(
        "Tensor {{\n data: [[1.1, 2.2, 3.3], [4.4, 5.5, 6.6], [7.7, 8.8, 9.9]],\n shape: [3, 3],\n device: Cpu,\n backend: \"{}\",\n kind: \"Float\",\n dtype: \"f32\",\n}}",
        TestBackend::name()
    );
    assert_eq!(rendered, want);
}

#[test]
fn test_display_2d_bool_tensor() {
    // A 3x3 boolean tensor should render its data row-major, plus metadata.
    let values = Data::from([
        [true, false, true],
        [false, true, false],
        [false, true, true],
    ]);
    let tensor_bool: burn_tensor::Tensor<TestBackend, 2, burn_tensor::Bool> =
        Tensor::from_data(values);

    let rendered = format!("{}", tensor_bool);
    let want = format!(
        "Tensor {{\n data: [[true, false, true], [false, true, false], [false, true, true]],\n shape: [3, 3],\n device: Cpu,\n backend: \"{}\",\n kind: \"Bool\",\n dtype: \"bool\",\n}}",
        TestBackend::name()
    );
    assert_eq!(rendered, want);
}

#[test]
fn test_display_3d_tensor() {
    // Three dimensions: nested brackets two levels deep inside the data field.
    let tensor: burn_tensor::Tensor<TestBackend, 3, burn_tensor::Int> =
        Tensor::from_data(Data::from([
            [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
            [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]],
        ]));

    let rendered = format!("{}", tensor);
    let want = format!(
        "Tensor {{\n data: [[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]],\n shape: [2, 3, 4],\n device: Cpu,\n backend: \"{}\",\n kind: \"Int\",\n dtype: \"i64\",\n}}",
        TestBackend::name()
    );
    assert_eq!(rendered, want);
}

#[test]
fn test_display_4d_tensor() {
    // Four dimensions: nested brackets three levels deep inside the data field.
    let tensor: burn_tensor::Tensor<TestBackend, 4, burn_tensor::Int> =
        Tensor::from_data(Data::from([
            [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
            [[[13, 14, 15], [16, 17, 18]], [[19, 20, 21], [22, 23, 24]]],
        ]));

    let rendered = format!("{}", tensor);
    let want = format!(
        "Tensor {{\n data: [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], [[[13, 14, 15], [16, 17, 18]], [[19, 20, 21], [22, 23, 24]]]],\n shape: [2, 2, 2, 3],\n device: Cpu,\n backend: \"{}\",\n kind: \"Int\",\n dtype: \"i64\",\n}}",
        TestBackend::name()
    );
    assert_eq!(rendered, want);
}
}

0 comments on commit d8f64ce

Please sign in to comment.