# Tensor Operations in Burn

This notebook demonstrates basic tensor operations in Burn, a deep learning framework written in Rust.
First, install the `evcxr_jupyter` kernel binary with Cargo:
```shell
cargo install evcxr_jupyter
```

Then, register the kernel with Jupyter:
```shell
evcxr_jupyter --install
```

In [7]:

// Dependency declarations for the notebook.
// Each `:dep` line takes a Cargo.toml-style dependency entry; just prefix it
// with `:dep`.
// Both crates are referenced by relative path, so the notebook expects to sit
// two directories below a checkout that contains `crates/burn` and
// `crates/burn-ndarray`.

:dep burn = {path = "../../crates/burn"}
:dep burn-ndarray = {path = "../../crates/burn-ndarray"}

In [None]:
// Import packages: the Burn prelude plus the NdArray backend type.
use burn::prelude::*;
use burn_ndarray::NdArray;

// Type alias for the backend (using CPU/NdArray) with f32 as the float type.
// `B` is the backend parameter of every `Tensor<B, D>` in the cells below.
type B = NdArray<f32>;

## 1. Tensor Creation

In [4]:
// Default device of backend `B`; every constructor below (and in later cells)
// takes a reference to it.
let device = <B as Backend>::Device::default();

// Create an empty tensor (uninitialized values).
// The second type parameter of `Tensor` is the number of dimensions, so a
// [2, 3, 4] shape needs `Tensor<B, 3>`. Only the shape is printed because the
// contents are not meaningful.
let empty: Tensor<B, 3> = Tensor::empty([2, 3, 4], &device);
println!("Empty tensor shape: {:?}", empty.shape());

// Create a 3x3 tensor filled with zeros
let zeros: Tensor<B, 2> = Tensor::zeros([3, 3], &device);
println!("Zeros tensor: {}", zeros);

// Create a 2x4 tensor filled with ones
let ones: Tensor<B, 2> = Tensor::ones([2, 4], &device);
println!("Ones tensor: {}", ones);

// Create a 2x3 tensor filled with a specific value (7.0)
let full: Tensor<B, 2> = Tensor::full([2, 3], 7.0, &device);
println!("Full tensor (7.0): {}", full);

Empty tensor shape: Shape { dims: [2, 3, 4] }
Zeros tensor: Tensor {
  data:
[[0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0]],
  shape:  [3, 3],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
Ones tensor: Tensor {
  data:
[[1.0, 1.0, 1.0, 1.0],
 [1.0, 1.0, 1.0, 1.0]],
  shape:  [2, 4],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
Full tensor (7.0): Tensor {
  data:
[[7.0, 7.0, 7.0],
 [7.0, 7.0, 7.0]],
  shape:  [2, 3],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}


In [13]:
// Create a 1-D tensor from an array of values, then reshape it to 2x3.
// The six values fill the result row by row.
let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let from_slice = Tensor::<B, 1>::from_floats(data, &device).reshape([2, 3]);
println!("From slice:\n{}", from_slice);

// Create a random tensor with the default distribution.
// NOTE(review): Burn documents `Distribution::Default` as uniform over [0, 1);
// the recorded samples are consistent with that - confirm against the docs.
use burn::tensor::Distribution;
let random: Tensor<B, 1> = Tensor::random([5], Distribution::Default, &device);
println!("Random tensor: {}", random);

// Create a tensor with normal distribution: Normal(mean, std) = (0.0, 1.0)
let normal: Tensor<B, 1> = Tensor::random([5], Distribution::Normal(0.0, 1.0), &device);
println!("Normal distribution: {}", normal);

// Create a tensor with uniform distribution in range [0, 10)
let uniform: Tensor<B, 1> = Tensor::random([5], Distribution::Uniform(0.0, 10.0), &device);
println!("Uniform [0, 10): {}", uniform);

From slice:
Tensor {
  data:
[[1.0, 2.0, 3.0],
 [4.0, 5.0, 6.0]],
  shape:  [2, 3],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
Random tensor: Tensor {
  data:
[0.32371014, 0.41100568, 0.94457513, 0.8408601, 0.42262083],
  shape:  [5],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
Normal distribution: Tensor {
  data:
[-0.22402725, 1.8367178, -1.1049407, -0.6302627, 1.1106112],
  shape:  [5],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
Uniform [0, 10): Tensor {
  data:
[8.110331, 7.335061, 9.858947, 6.0834813, 3.6619747],
  shape:  [5],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}


## 2. Shape Operations

In [16]:
// Reshape tensor - change the dimensions without changing the data.
// The element count must stay the same (6 values -> 2x3 -> 1x2x3).
let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &device).reshape([2, 3]);
println!("Original (2x3):\n{}", tensor);

// The target rank changes, so the result type is annotated explicitly.
// `clone()` keeps `tensor` available for the flatten call below.
let reshaped: Tensor<B, 3> = tensor.clone().reshape([1, 2, 3]);
println!("Reshaped (1x2x3): {}", reshaped);

// Flatten - reshape to 1D by merging dimensions 0 through 1 into one.
// This is the last use of `tensor`, so no clone is needed.
let flat: Tensor<B, 1> = tensor.flatten(0, 1);
println!("Flattened: {}", flat);

Original (2x3):
Tensor {
  data:
[[1.0, 2.0, 3.0],
 [4.0, 5.0, 6.0]],
  shape:  [2, 3],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
Reshaped (1x2x3): Tensor {
  data:
[[[1.0, 2.0, 3.0],
  [4.0, 5.0, 6.0]]],
  shape:  [1, 2, 3],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
Flattened: Tensor {
  data:
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
  shape:  [6],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}


In [18]:
// Transpose - swap the two dimensions of a 2x2 tensor (rows become columns)
let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device).reshape([2, 2]);
println!("Original:\n{}", tensor);

// `clone()` keeps `tensor` available for the `.t()` call below.
let transposed = tensor.clone().transpose();
println!("Transposed:\n{}", transposed);

// Also .t() works for 2D tensors; it prints the same result as transpose()
let t = tensor.t();
println!("Using .t():\n{}", t);

Original:
Tensor {
  data:
[[1.0, 2.0],
 [3.0, 4.0]],
  shape:  [2, 2],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
Transposed:
Tensor {
  data:
[[1.0, 3.0],
 [2.0, 4.0]],
  shape:  [2, 2],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
Using .t():
Tensor {
  data:
[[1.0, 3.0],
 [2.0, 4.0]],
  shape:  [2, 2],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}


In [20]:
// Squeeze - remove dimensions of size 1.
// The const generic is the rank of the result: both size-1 dimensions of
// [1, 1, 2] are dropped, leaving the rank-1 shape [2].
let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device).reshape([1, 1, 2]);
println!("Before squeeze [1,1,2]: shape = {:?}", tensor.shape());

let squeezed = tensor.squeeze::<1>();
println!("After squeeze: shape = {:?}", squeezed.shape());

// Unsqueeze - raise the rank to the value given as the const generic (3 here).
// No position is passed: as the output shows, the new size-1 dimension is
// added at the front ([2, 2] -> [1, 2, 2]).
// NOTE(review): Burn's `unsqueeze_dim` is the variant that takes an explicit
// position - confirm against the Tensor API docs.
let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device).reshape([2, 2]);
println!("Before unsqueeze [2,2]: shape = {:?}", tensor.shape());

let unsqueezed = tensor.unsqueeze::<3>();
println!("After unsqueeze: shape = {:?}", unsqueezed.shape());

Before squeeze [1,1,2]: shape = Shape { dims: [1, 1, 2] }
After squeeze: shape = Shape { dims: [2] }
Before unsqueeze [2,2]: shape = Shape { dims: [2, 2] }
After unsqueeze: shape = Shape { dims: [1, 2, 2] }


## 3. Indexing and Slicing

In [22]:
// Create a tensor for indexing examples: the values 1..=12 reshaped to 3 rows
// by 4 columns, filled row by row.
let tensor = Tensor::<B, 1>::from_floats(
    [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
&device
).reshape([3, 4]);
println!("Original tensor:\n{}", tensor);

Original tensor:
Tensor {
  data:
[[1.0, 2.0, 3.0, 4.0],
 [5.0, 6.0, 7.0, 8.0],
 [9.0, 10.0, 11.0, 12.0]],
  shape:  [3, 4],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}


In [24]:
// Slice tensor - select a portion using one range per dimension.
// Ranges are zero-based and half-open (the end index is excluded).
// Here `tensor` is a 3x4 tensor holding 1..=12 row by row.
// Rows 1..3 -> rows 1 and 2; columns 1..4 -> columns 1, 2 and 3.
let sliced = tensor.clone().slice([1..3, 1..4]);
println!("Sliced [1..3, 1..4]:\n{}", sliced);

// Get single row (row 1, all four columns); the result stays 2-D: [1, 4]
let row = tensor.clone().slice([1..2, 0..4]);
println!("Row 1: {}", row);

// Get single column (all three rows, column 2); the result stays 2-D: [3, 1].
// Last use of `tensor` in this cell, so it is not cloned.
let col = tensor.slice([0..3, 2..3]);
println!("Column 2: {}", col);

Sliced [1..3, 1..4]:
Tensor {
  data:
[[6.0, 7.0, 8.0],
 [10.0, 11.0, 12.0]],
  shape:  [2, 3],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
Row 1: Tensor {
  data:
[[5.0, 6.0, 7.0, 8.0]],
  shape:  [1, 4],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
Column 2: Tensor {
  data:
[[3.0],
 [7.0],
 [11.0]],
  shape:  [3, 1],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}


## 4. Basic Math Operations

In [26]:
// Two 2x2 operands for the arithmetic examples.
let a = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device).reshape([2, 2]);
let b = Tensor::<B, 1>::from_floats([5.0, 6.0, 7.0, 8.0], &device).reshape([2, 2]);

println!("a = {}", a);
println!("b = {}", b);

// All four operators below combine elements at the same position.
// The operands are cloned so `a` and `b` can be reused for each example.

// Addition
let c = a.clone() + b.clone();
println!("a + b = {}", c);

// Subtraction
let c = a.clone() - b.clone();
println!("a - b = {}", c);

// Multiplication (element-wise, not a matrix product)
let c = a.clone() * b.clone();
println!("a * b = {}", c);

// Division (element-wise)
let c = a.clone() / b.clone();
println!("a / b = {}", c);

a = Tensor {
  data:
[[1.0, 2.0],
 [3.0, 4.0]],
  shape:  [2, 2],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
b = Tensor {
  data:
[[5.0, 6.0],
 [7.0, 8.0]],
  shape:  [2, 2],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
a + b = Tensor {
  data:
[[6.0, 8.0],
 [10.0, 12.0]],
  shape:  [2, 2],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
a - b = Tensor {
  data:
[[-4.0, -4.0],
 [-4.0, -4.0]],
  shape:  [2, 2],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
a * b = Tensor {
  data:
[[5.0, 12.0],
 [21.0, 32.0]],
  shape:  [2, 2],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
a / b = Tensor {
  data:
[[0.2, 0.33333334],
 [0.42857143, 0.5]],
  shape:  [2, 2],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}


In [28]:
// Scalar operations: the scalar is applied to every element of the tensor
let a = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device).reshape([2, 2]);

println!("a = {}", a);

// Add scalar (each element + 10)
let c = a.clone() + 10.0;
println!("a + 10 = {}", c);

// Multiply scalar (each element * 2)
let c = a.clone() * 2.0;
println!("a * 2 = {}", c);

a = Tensor {
  data:
[[1.0, 2.0],
 [3.0, 4.0]],
  shape:  [2, 2],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
a + 10 = Tensor {
  data:
[[11.0, 12.0],
 [13.0, 14.0]],
  shape:  [2, 2],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
a * 2 = Tensor {
  data:
[[2.0, 4.0],
 [6.0, 8.0]],
  shape:  [2, 2],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}


In [30]:
// Matrix multiplication of two 2x2 matrices
let a = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device).reshape([2, 2]);
let b = Tensor::<B, 1>::from_floats([5.0, 6.0, 7.0, 8.0], &device).reshape([2, 2]);

println!("a = {}", a);
println!("b = {}", b);

// `a` and `b` are not used again in this cell, so they are passed without cloning.
let result = a.matmul(b);
println!("a @ b (matmul) = {}", result);

// Verify by hand (each row of a times each column of b):
//   row [1, 2]: [1*5 + 2*7, 1*6 + 2*8] = [19, 22]
//   row [3, 4]: [3*5 + 4*7, 3*6 + 4*8] = [43, 50]

a = Tensor {
  data:
[[1.0, 2.0],
 [3.0, 4.0]],
  shape:  [2, 2],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
b = Tensor {
  data:
[[5.0, 6.0],
 [7.0, 8.0]],
  shape:  [2, 2],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
a @ b (matmul) = Tensor {
  data:
[[19.0, 22.0],
 [43.0, 50.0]],
  shape:  [2, 2],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}


## 5. Element-wise Math Functions

In [32]:
// Input for the element-wise functions; each call below maps every element
// independently.
let a: Tensor<B, 1> = Tensor::from_floats([0.0, 1.0, 2.0], &device);

println!("a = {}", a);

// Exponential (e^x)
println!("exp(a) = {}", a.clone().exp());

// Natural logarithm. `a` contains 0.0, so 1 is added first to keep the
// argument of log strictly positive.
println!("log(a + 1) = {}", (a.clone() + 1.0).log());

// Power with a scalar exponent. The printed labels say `powf`, but the method
// called is `powf_scalar`; an exponent of 0.5 gives the square root.
println!("a.powf(2) = {}", a.clone().powf_scalar(2.0));
println!("a.powf(0.5) = {}", a.clone().powf_scalar(0.5));

a = Tensor {
  data:
[0.0, 1.0, 2.0],
  shape:  [3],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
exp(a) = Tensor {
  data:
[1.0, 2.7182817, 7.389056],
  shape:  [3],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
log(a + 1) = Tensor {
  data:
[0.0, 0.6931472, 1.0986123],
  shape:  [3],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
a.powf(2) = Tensor {
  data:
[0.0, 1.0, 4.0],
  shape:  [3],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
a.powf(0.5) = Tensor {
  data:
[0.0, 1.0, 1.4142135],
  shape:  [3],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}


In [33]:
// Trigonometric functions on angles given in radians: 0, PI/4 and PI/2.
// f32 cannot represent PI/2 exactly, so in the output below cos(PI/2) prints
// as a tiny non-zero value (about -4.4e-8) and tan(PI/2) as a very large
// finite number instead of being undefined.
let angles: Tensor<B, 1> = Tensor::from_floats([0.0, std::f32::consts::PI / 4.0, std::f32::consts::PI / 2.0], &device);

println!("angles = {}", angles);
println!("sin(angles) = {}", angles.clone().sin());
println!("cos(angles) = {}", angles.clone().cos());
println!("tan(angles) = {}", angles.clone().tan());

angles = Tensor {
  data:
[0.0, 0.7853982, 1.5707964],
  shape:  [3],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
sin(angles) = Tensor {
  data:
[0.0, 0.70710677, 1.0],
  shape:  [3],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
cos(angles) = Tensor {
  data:
[1.0, 0.70710677, -4.371139e-8],
  shape:  [3],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
tan(angles) = Tensor {
  data:
[0.0, 1.0, -22877332.0],
  shape:  [3],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}


## 6. Reduction Operations

In [35]:
// 2x3 input for the reductions. Each reduction below runs over ALL elements
// and, as the output shows, returns a one-element tensor of shape [1] rather
// than a plain scalar.
let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &device).reshape([2, 3]);
println!("Tensor:\n{}", tensor);

// Sum all elements
println!("Sum: {}", tensor.clone().sum());

// Mean of all elements
println!("Mean: {}", tensor.clone().mean());

// Product of all elements
println!("Product: {}", tensor.clone().prod());

// Maximum and minimum over all elements
println!("Max: {}", tensor.clone().max());
println!("Min: {}", tensor.clone().min());

Tensor:
Tensor {
  data:
[[1.0, 2.0, 3.0],
 [4.0, 5.0, 6.0]],
  shape:  [2, 3],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
Sum: Tensor {
  data:
[21.0],
  shape:  [1],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
Mean: Tensor {
  data:
[3.5],
  shape:  [1],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
Product: Tensor {
  data:
[720.0],
  shape:  [1],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
Max: Tensor {
  data:
[6.0],
  shape:  [1],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
Min: Tensor {
  data:
[1.0],
  shape:  [1],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}


In [37]:
// Reduce along specific dimensions.
// The reduced dimension is kept with size 1, so the rank does not change:
// [2, 3] becomes [1, 3] for dim 0 and [2, 1] for dim 1.
let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &device).reshape([2, 3]);
println!("Tensor:\n{}", tensor);

// Sum along dimension 0: collapses the rows, giving one sum per column
println!("Sum dim 0: {}", tensor.clone().sum_dim(0));

// Sum along dimension 1: collapses the columns, giving one sum per row
println!("Sum dim 1: {}", tensor.clone().sum_dim(1));

// Mean along dimension 0: one mean per column
println!("Mean dim 0: {}", tensor.clone().mean_dim(0));

Tensor:
Tensor {
  data:
[[1.0, 2.0, 3.0],
 [4.0, 5.0, 6.0]],
  shape:  [2, 3],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
Sum dim 0: Tensor {
  data:
[[5.0, 7.0, 9.0]],
  shape:  [1, 3],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
Sum dim 1: Tensor {
  data:
[[6.0],
 [15.0]],
  shape:  [2, 1],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
Mean dim 0: Tensor {
  data:
[[2.5, 3.5, 4.5]],
  shape:  [1, 3],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}


## 7. Comparison and Selection

In [42]:
// Two 1-D tensors of the same length, compared position by position.
let a: Tensor<B, 1> = Tensor::from_floats([1.0, 5.0, 3.0, 8.0], &device);
let b: Tensor<B, 1> = Tensor::from_floats([4.0, 2.0, 6.0, 7.0], &device);

println!("a = {}", a);
println!("b = {}", b);

// Element-wise comparison returns a boolean tensor (kind "Bool" in the output)
// with the same shape as the inputs.
let greater = a.clone().greater(b.clone());
println!("a > b: {}", greater);

// `lower` is Burn's name for the less-than comparison.
let less = a.clone().lower(b.clone());
println!("a < b: {}", less);

// No position holds equal values here, so every entry is false.
let equal = a.clone().equal(b.clone());
println!("a == b: {}", equal);

a = Tensor {
  data:
[1.0, 5.0, 3.0, 8.0],
  shape:  [4],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
b = Tensor {
  data:
[4.0, 2.0, 6.0, 7.0],
  shape:  [4],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
a > b: Tensor {
  data:
[false, true, false, true],
  shape:  [4],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Bool",
  dtype:  "bool",
}
a < b: Tensor {
  data:
[true, false, true, false],
  shape:  [4],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Bool",
  dtype:  "bool",
}
a == b: Tensor {
  data:
[false, false, false, false],
  shape:  [4],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Bool",
  dtype:  "bool",
}


In [41]:
// Conditional selection
let a: Tensor<B, 1> = Tensor::from_floats([1.0, 5.0, 3.0, 8.0], &device);

// mask_where: where the condition is true, take the value from the replacement
// tensor; elsewhere keep the original value. Here elements > 4 become 0.
// `greater_elem` compares every element against a scalar and yields the mask.
// The replacement tensor is built with the same shape ([4]) as `a`.
let condition = a.clone().greater_elem(4.0);
let result = a.clone().mask_where(condition, Tensor::zeros([4], &device));
println!("Original: {}", a);
println!("Where > 4, replace with 0: {}", result);

// mask_fill: simpler - replace the values where the mask is true with a
// single scalar instead of a whole tensor.
let result = a.clone().mask_fill(a.clone().greater_elem(4.0), -1.0);
println!("Where > 4, replace with -1: {}", result);

Original: Tensor {
  data:
[1.0, 5.0, 3.0, 8.0],
  shape:  [4],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
Where > 4, replace with 0: Tensor {
  data:
[1.0, 0.0, 3.0, 0.0],
  shape:  [4],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
Where > 4, replace with -1: Tensor {
  data:
[1.0, -1.0, 3.0, -1.0],
  shape:  [4],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}


## Summary

In this notebook, we covered:
- **Tensor Creation**: empty, zeros, ones, full, from_floats, random
- **Shape Operations**: reshape, transpose, flatten, squeeze, unsqueeze
- **Indexing and Slicing**: slice operation with ranges
- **Math Operations**: add, sub, mul, div, matmul
- **Element-wise Functions**: exp, log, powf_scalar, sin, cos, tan
- **Reduction Operations**: sum, mean, prod, max, min
- **Comparison**: greater, lower, equal, mask_where, mask_fill

These are the fundamental building blocks for neural networks in Burn!

Note: Burn's API differs from PyTorch in several ways:
- Uses Rust's type system with explicit dimensions
- Method names like `lower` instead of `lt`, `powf_scalar` instead of `pow`
- Requires explicit device management and backend specification
- Distribution parameters are positional enum-variant arguments, e.g. `Distribution::Normal(mean, std)`, instead of keyword arguments