# Autodifferentiation and Gradient Descent in Burn

This notebook demonstrates how to use automatic differentiation in Burn to compute gradients and implement gradient descent.

In [13]:
// Dependency declarations
:dep burn = {path = "../../crates/burn"}
:dep burn-ndarray = {path = "../../crates/burn-ndarray"}
:dep burn-autodiff = {path = "../../crates/burn-autodiff"}


In [14]:
// Import packages
use burn::prelude::*;
use burn_autodiff::Autodiff;
use burn_ndarray::NdArray;

// Type alias for the backend used throughout this notebook.
// `NdArray<f32>` is the CPU tensor backend (outputs below report `device: Cpu`);
// wrapping it in `Autodiff` enables automatic differentiation, so tensors on `B`
// can be marked with `require_grad()` and later differentiated via `backward()`.
type B = Autodiff<NdArray<f32>>;


## 1. Understanding require_grad()

In Burn, tensors can be marked for gradient tracking using `.require_grad()`. This tells the framework to track operations on this tensor so gradients can be computed later.

In [15]:
let device = <B as Backend>::Device::default();

// A plain tensor: operations on it are not recorded for differentiation.
let x = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device);
println!("Regular tensor x: {}", x);

// The same values, but flagged with `require_grad()` so that operations
// consuming `y` are tracked. Its printed form looks identical to `x`.
let y = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device).require_grad();
println!("Tensor y with require_grad: {}", y);

// Build a small graph on top of `y`: scale by 2, then reduce to one value.
let z = y.clone().mul_scalar(2.0);
let result = z.sum();
println!("result = sum(y * 2) = {}", result);


Regular tensor x: Tensor {
  data:
[1.0, 2.0, 3.0, 4.0],
  shape:  [4],
  device:  Cpu,
  backend:  "autodiff<ndarray>",
  kind:  "Float",
  dtype:  "f32",
}
Tensor y with require_grad: Tensor {
  data:
[1.0, 2.0, 3.0, 4.0],
  shape:  [4],
  device:  Cpu,
  backend:  "autodiff<ndarray>",
  kind:  "Float",
  dtype:  "f32",
}
result = sum(y * 2) = Tensor {
  data:
[20.0],
  shape:  [1],
  device:  Cpu,
  backend:  "autodiff<ndarray>",
  kind:  "Float",
  dtype:  "f32",
}


## 2. Computing Gradients with backward()

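Calling `.backward()` on a result tensor computes the gradient of that result with respect to every tensor in its computation graph that was marked with `require_grad()`. It returns a gradients container. Retrieve an individual gradient with `tensor.grad(&grads)`, which returns an `Option`.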

In [16]:
// Worked example:
//   y      = [1, 2, 3, 4]
//   z      = 2 * y  = [2, 4, 6, 8]
//   result = sum(z) = 20
// Each y_i contributes 2 * y_i to the sum, hence
//   d(result)/dy = 1 * 2 = [2, 2, 2, 2]

let device = <B as Backend>::Device::default();
let y = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device).require_grad();
let z = y.clone().mul_scalar(2.0);
let result = z.sum();

// Reverse pass: walk the recorded graph back from `result`.
let grads = result.backward();

// `grad` returns an Option; unwrapping is safe because `y` requires grad and
// feeds into `result`. The gradient is reported on the inner (non-autodiff)
// backend, as the printed `backend: "ndarray"` shows.
let y_grad = y.grad(&grads).unwrap();
println!("y = {}", y);
println!("d(result)/dy = {}", y_grad);


y = Tensor {
  data:
[1.0, 2.0, 3.0, 4.0],
  shape:  [4],
  device:  Cpu,
  backend:  "autodiff<ndarray>",
  kind:  "Float",
  dtype:  "f32",
}
d(result)/dy = Tensor {
  data:
[2.0, 2.0, 2.0, 2.0],
  shape:  [4],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}


## 3. More Complex Example: Quadratic Function
Let's compute the gradient of a more complex function: f(x) = x²

The derivative is: f'(x) = 2x

In [18]:
// Target: f(x) = x^2 applied element-wise; analytic derivative f'(x) = 2x.

let device = <B as Backend>::Device::default();
let x = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device).require_grad();

// Keep `y` (via clone) so it can still be printed after the reduction.
let y = x.clone().powf_scalar(2.0);
let result = y.clone().sum();

// Differentiate the summed value; since each y_i depends only on x_i,
// the gradient w.r.t. x is exactly the element-wise derivative 2x.
let grads = result.backward();
let x_grad = x.grad(&grads).unwrap();

println!("x = {}", x);
println!("x^2 = {}", y);
println!("d(x^2)/dx = {}", x_grad);

// Analytic check: 2 * [1, 2, 3, 4] = [2, 4, 6, 8]
println!("Expected: [2, 4, 6, 8]");


x = Tensor {
  data:
[1.0, 2.0, 3.0, 4.0],
  shape:  [4],
  device:  Cpu,
  backend:  "autodiff<ndarray>",
  kind:  "Float",
  dtype:  "f32",
}
x^2 = Tensor {
  data:
[1.0, 4.0, 9.0, 16.0],
  shape:  [4],
  device:  Cpu,
  backend:  "autodiff<ndarray>",
  kind:  "Float",
  dtype:  "f32",
}
d(x^2)/dx = Tensor {
  data:
[2.0, 4.0, 6.0, 8.0],
  shape:  [4],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
Expected: [2, 4, 6, 8]


## 4. Chain Rule Example

Let's verify the chain rule: d/dx f(g(x)) = f'(g(x)) · g'(x)

Example: for y = sin(x²), we want dy/dx.

Let u = x², so y = sin(u). Then:
- dy/du = cos(u)
- du/dx = 2x
- dy/dx = dy/du · du/dx = cos(x²) · 2x

In [26]:
// y     = sin(x^2)
// dy/dx = cos(x^2) * 2x   (chain rule with inner function u = x^2)

let device = <B as Backend>::Device::default();
let x = Tensor::<B, 1>::from_floats([0.0, 1.0, 2.0, 3.0], &device).require_grad();

// Forward pass: square, then take the sine. Summing is valid for an
// element-wise check because each y_i depends only on x_i, so
// d(sum)/dx_i = dy_i/dx_i.
let x_squared = x.clone().powf_scalar(2.0);
let y = x_squared.sin();
let result = y.clone().sum();

// Backward pass.
let grads = result.backward();
let x_grad = x.grad(&grads).unwrap();

println!("x = {}", x);
println!("y = sin(x^2) = {}", y);
println!("dy/dx = {}", x_grad);

// Cross-check against the analytic derivative cos(x^2) * 2x.
let two_x = x.clone().mul_scalar(2.0);
let expected_grad = x.clone().powf_scalar(2.0).cos() * two_x;
println!("Expected (cos(x^2) * 2x): {}", expected_grad);


x = Tensor {
  data:
[0.0, 1.0, 2.0, 3.0],
  shape:  [4],
  device:  Cpu,
  backend:  "autodiff<ndarray>",
  kind:  "Float",
  dtype:  "f32",
}
y = sin(x^2) = Tensor {
  data:
[0.0, 0.84147096, -0.7568025, 0.4121185],
  shape:  [4],
  device:  Cpu,
  backend:  "autodiff<ndarray>",
  kind:  "Float",
  dtype:  "f32",
}
dy/dx = Tensor {
  data:
[0.0, 1.0806046, -2.6145744, -5.4667816],
  shape:  [4],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
Expected (cos(x^2) * 2x): Tensor {
  data:
[0.0, 1.0806046, -2.6145744, -5.4667816],
  shape:  [4],
  device:  Cpu,
  backend:  "autodiff<ndarray>",
  kind:  "Float",
  dtype:  "f32",
}


## 5. Gradient Descent from Scratch

Now let's implement the classic gradient descent algorithm to find the minimum of a function.

We'll minimize: f(x) = (x - 3)²

The minimum is at x = 3, where f(x) = 0

In [24]:
// Objective for this section: f(x) = (x - 3)^2, which is minimized at x = 3.

/// Element-wise squared distance of `x` from 3, i.e. `(x - 3)^2`.
///
/// Borrows `x` and clones it internally, so the caller keeps ownership of the
/// tensor (e.g. to query its gradient after calling `backward`).
fn loss<B: Backend>(x: &Tensor<B, 1>) -> Tensor<B, 1> {
    let shifted = x.clone().sub_scalar(3.0);
    shifted.powf_scalar(2.0)
}

let device = <B as Backend>::Device::default();
// The parameter lives as a plain f32; a fresh tracked tensor is built from it
// on every step. Start the search at x = 0.
let mut x_val: f32 = 0.0;
let learning_rate: f32 = 0.1;

println!("Starting gradient descent to minimize (x - 3)^2");
println!("Expected minimum: x = 3");
println!("---");

for i in 0..20 {
    // New graph leaf holding the current value, with gradient tracking on.
    let x = Tensor::<B, 1>::from_floats([x_val], &device).require_grad();

    // Forward: evaluate the objective and read it back to the host for logging.
    let loss_value = loss(&x);
    let loss_scalar = loss_value.clone().into_scalar().elem::<f32>();
    println!("Iteration {}: x = {:.4}, loss = {:.4}", i, x_val, loss_scalar);

    // Backward: df/dx at the current point, as an f32.
    let grads = loss_value.backward();
    let grad_val = x.grad(&grads).unwrap().into_scalar().elem::<f32>();

    // Step against the gradient: x <- x - lr * df/dx
    x_val -= grad_val * learning_rate;
}

println!("---");
println!("Final x = {:.4}", x_val);


Starting gradient descent to minimize (x - 3)^2
Expected minimum: x = 3
---
Iteration 0: x = 0.0000, loss = 9.0000
Iteration 1: x = 0.6000, loss = 5.7600
Iteration 2: x = 1.0800, loss = 3.6864
Iteration 3: x = 1.4640, loss = 2.3593
Iteration 4: x = 1.7712, loss = 1.5099
Iteration 5: x = 2.0170, loss = 0.9664
Iteration 6: x = 2.2136, loss = 0.6185
Iteration 7: x = 2.3709, loss = 0.3958
Iteration 8: x = 2.4967, loss = 0.2533
Iteration 9: x = 2.5973, loss = 0.1621
Iteration 10: x = 2.6779, loss = 0.1038
Iteration 11: x = 2.7423, loss = 0.0664
Iteration 12: x = 2.7938, loss = 0.0425
Iteration 13: x = 2.8351, loss = 0.0272
Iteration 14: x = 2.8681, loss = 0.0174
Iteration 15: x = 2.8944, loss = 0.0111
Iteration 16: x = 2.9156, loss = 0.0071
Iteration 17: x = 2.9324, loss = 0.0046
Iteration 18: x = 2.9460, loss = 0.0029
Iteration 19: x = 2.9568, loss = 0.0019
---
Final x = 2.9654


## 6. Linear Regression with Gradient Descent

Let's use gradient descent to fit a simple linear regression model: y = wx + b

We'll generate synthetic data for x = 0.0, 0.1, …, 9.9, where the true relationship is y = 2x + 1 plus uniform noise drawn from [-0.25, 0.25].

In [29]:
use burn::tensor::{Distribution, TensorData};

let device = <B as Backend>::Device::default();

// Synthetic dataset: y = 2x + 1 + noise, for x = 0.0, 0.1, ..., 9.9.
let num_samples: usize = 100;
let x_data: Vec<f32> = (0..num_samples).map(|i| i as f32 / 10.0).collect();

// Draw all noise samples in one tensor op and copy them to the host once.
// (Slicing the tensor and calling `into_scalar` per element would run one
// tensor operation per sample.)
let noise: Vec<f32> =
    Tensor::<B, 1>::random([num_samples], Distribution::Uniform(-0.25, 0.25), &device)
        .into_data()
        .to_vec::<f32>()
        .expect("noise tensor on an f32 backend holds f32 values");
let y_data: Vec<f32> = x_data
    .iter()
    .zip(&noise)
    .map(|(&x, &eps)| 2.0 * x + 1.0 + eps)
    .collect();

// Column vectors of shape [num_samples, 1], placed explicitly on `device`.
let x = Tensor::<B, 2>::from_data(TensorData::new(x_data.clone(), [num_samples, 1]), &device);
let y = Tensor::<B, 2>::from_data(TensorData::new(y_data.clone(), [num_samples, 1]), &device);

println!("Generated {} data points", num_samples);
println!("True relationship: y = 2x + 1");
println!("First 5 x values: {:?}", &x_data[0..5]);
println!("First 5 y values: {:?}", &y_data[0..5]);


Generated 100 data points
True relationship: y = 2x + 1
First 5 x values: [0.0, 0.1, 0.2, 0.3, 0.4]
First 5 y values: [0.87993187, 0.98804677, 1.5366085, 1.7324162, 1.653858]


In [31]:
// Start from fixed, deliberately off parameters (the data's true values are
// w = 2, b = 1).
let device = <B as Backend>::Device::default();
let mut w_val: f32 = 0.5;
let mut b_val: f32 = 0.5;

let learning_rate: f32 = 0.01;
let num_epochs = 100;

println!("Training linear regression with gradient descent...");
println!("Initial w = {:.4}, b = {:.4}", w_val, b_val);

for epoch in 0..num_epochs {
    // Fresh [1, 1] graph leaves for this epoch's parameters.
    let w = Tensor::<B, 2>::from_floats([[w_val]], &device).require_grad();
    let b = Tensor::<B, 2>::from_floats([[b_val]], &device).require_grad();

    // Forward pass: y_pred = x * w + b
    let y_pred = x.clone().matmul(w.clone()) + b.clone();

    // Mean squared error: (1/n) * sum((y_pred - y)^2).
    // Named `mse` to avoid shadowing the `loss` function defined earlier.
    let mse = (y_pred - y.clone()).powf_scalar(2.0).mean();

    // Backward pass: read both parameter gradients back as f32.
    let grads = mse.backward();
    let w_grad_val: f32 = w.grad(&grads).unwrap().into_scalar().elem::<f32>();
    let b_grad_val: f32 = b.grad(&grads).unwrap().into_scalar().elem::<f32>();

    // Gradient step.
    w_val -= w_grad_val * learning_rate;
    b_val -= b_grad_val * learning_rate;

    if epoch % 20 == 0 {
        // NOTE: `mse` was evaluated with this epoch's *pre-update* parameters,
        // while the w/b printed alongside it are the values after the step.
        let loss_val: f32 = mse.into_scalar().elem::<f32>();
        println!("Epoch {:3}: loss = {:.4}, w = {:.4}, b = {:.4}", epoch, loss_val, w_val, b_val);
    }
}

println!("---");
println!("Final: w = {:.4}, b = {:.4}", w_val, b_val);
println!("True: w = 2.0, b = 1.0");


Training linear regression with gradient descent...
Initial w = 0.5000, b = 0.5000
Epoch   0: loss = 81.7705, w = 1.5358, b = 0.6586
Epoch  20: loss = 0.0365, w = 2.0384, b = 0.7594
Epoch  40: loss = 0.0341, w = 2.0351, b = 0.7810
Epoch  60: loss = 0.0321, w = 2.0322, b = 0.8006
Epoch  80: loss = 0.0305, w = 2.0295, b = 0.8184
---
Final: w = 2.0272, b = 0.8336
True: w = 2.0, b = 1.0


## Summary

In this notebook, we covered:

- **require_grad()**: Mark tensors for gradient tracking
- **backward()**: Compute gradients automatically using reverse-mode autodiff
- **grad()**: Retrieve computed gradients
- **Gradient Descent**: Implemented from scratch to minimize a quadratic function
- **Linear Regression**: Used gradient descent to fit a linear model to data

These concepts are the foundation of neural network training in Burn!