# Minimal Autodiff Engine

This notebook contains a minimal implementation of an automatic differentiation engine that supports reverse-mode autodiff on N-dimensional tensors:
- Tensor: core data structure and computation graph
- Optimizers: update model parameters using gradients (SGD, SGDMomentum, Adam)
- MLP module: neural network library with linear layers and activations
- Loss and training functions: measure prediction error and run training loop

Note: this project is self-contained. Its code is not reused by the CNN, Transformer, or Diffusion projects, which rely on PyTorch autograd instead.

## Setup

In [285]:
import numpy as np
import matplotlib.pyplot as plt
np.set_printoptions(precision=3)
np.random.seed(0)

def _unbroadcast(grad, target_shape):
    """
    Reduce a gradient to match the given target shape.

    Params:
        grad: np.array, the gradient with the broadcasted shape
        target_shape: tuple of int, the shape of the tensor before broadcasting

    Returns:
        result: np.array, the gradient with reduced dimensions and summed so it matches the target shape
    """
    # grad.ndim is number of axes, len(target_shape) is desired number of axes.
    # For [[a, b], [c, d]], the result of sum(axis=0) is [a+c, b+d]
    while grad.ndim > len(target_shape):
        grad = grad.sum(axis=0)
    
    # For every axis where the original tensor had size 1 and was broadcasted, we sum
    # the gradient along that axis to collapse it back.
    for axis, dim in reversed(list(enumerate(target_shape))):
        if dim == 1:
            grad = grad.sum(axis=axis, keepdims=True)
    return grad
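
To see why the summation is needed, consider a bias of shape `(3,)` added to a batch of shape `(4, 3)`. The NumPy-only sketch below (variable names are illustrative and not part of the engine) reduces the upstream gradient by hand and checks it against a finite-difference estimate.

```python
import numpy as np

# Forward: b (shape (3,)) is broadcast across the 4 rows of x (shape (4, 3)).
x = np.arange(12, dtype=np.float64).reshape(4, 3)
b = np.array([0.1, 0.2, 0.3])

def loss_fn(bias):
    return (x + bias).sum()

# For loss = out.sum(), the upstream gradient is all ones with out's shape.
upstream = np.ones((4, 3))

# Each entry of b contributed to every row, so its gradient sums over axis 0.
grad_b = upstream.sum(axis=0)

# Finite-difference check, one coordinate of b at a time.
eps = 1e-6
numeric = np.array([
    (loss_fn(b + eps * np.eye(3)[i]) - loss_fn(b)) / eps
    for i in range(3)
])

print(grad_b.shape)                              # (3,)
print(np.allclose(grad_b, numeric, atol=1e-4))   # True
```

This is the same reduction `_unbroadcast` performs when `grad` has more axes than `target_shape`.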

## Tensor

The `Tensor` class is the core data structure. Each operation on tensors records its inputs, so running a model builds a computation graph that maps input data to predictions and then to a loss. Calling `backward()` on the loss traverses that graph in reverse to compute the gradient at every node, which is what lets the network's weights be updated during training.
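
Before the full class, here is a scalar-only sketch of the same idea: record parents during the forward pass, sort the graph topologically, then walk it in reverse while accumulating gradients. `Node` and the standalone `backward` function are hypothetical stand-ins written for this illustration. They are not the `Tensor` API.

```python
class Node:
    """Hypothetical scalar stand-in for Tensor, just enough to show graph traversal."""

    def __init__(self, value, parents=()):
        self.value = value
        self.grad = 0.0
        # Tuple of (parent_node, local_derivative_of_self_wrt_parent)
        self.parents = parents

    def __add__(self, other):
        return Node(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Node(self.value * other.value, ((self, other.value), (other, self.value)))


def backward(root):
    # Depth-first topological sort: a node is appended after all of its parents
    topo, visited = [], set()

    def build(node):
        if node not in visited:
            visited.add(node)
            for parent, _ in node.parents:
                build(parent)
            topo.append(node)

    build(root)

    # Seed with d(root)/d(root) = 1, then apply the chain rule from output to inputs
    root.grad = 1.0
    for node in reversed(topo):
        for parent, local in node.parents:
            # Accumulate rather than overwrite: a node may feed several consumers
            parent.grad += node.grad * local


x = Node(3.0)
y = x * x + x        # x is used three times, dy/dx = 2x + 1
backward(y)
print(x.grad)        # 7.0
```

Because `x` appears three times in the expression, its gradient is the sum of three contributions. Overwriting instead of accumulating would silently drop some of them.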

In [286]:
class Tensor:
    def __init__(self, data, requires_grad=False):
        # Numpy scalar, vector, matrix
        self.data = np.array(data, dtype=np.float64)

        # Determines if we should store gradient
        self.requires_grad = requires_grad

        # Gradient of the loss wrt this tensor (same shape as .data)
        self.grad = None

        # Closure that computes local gradients and accumulates them into parent tensors
        self._backward = lambda: None

        # Parent tensors
        self._prev = ()

        # Op name for logging
        self._op = ""

    def __repr__(self):
        # String representation for logging
        return f"Tensor(data={self.data}, grad={self.grad}, op={self._op})"

    def backward(self):
        """Runs backpropagation from this scalar tensor, accumulating gradients into .grad of every upstream tensor that requires them."""
        if self.data.shape != ():
            raise ValueError("backward() assumes scalar output on the loss layer")

        # Set loss gradient to ones
        self.grad = np.array(1.0, dtype=np.float64)

        # Build DFS topological sort of computation graph
        topo, visited = [], set()
        def build(v):
            if v not in visited:
                visited.add(v)
                for p in v._prev:
                    build(p)
                topo.append(v)
        build(self)

        # Reverse topo sort order (starting from the outermost / loss layer)
        for v in reversed(topo):
            if v.grad is not None:
                v._backward()

    def __add__(self, other):
        """Elementwise addition with broadcasting, returns a new Tensor and sets up gradient flow to both parents"""
        if not isinstance(other, Tensor):
            other = Tensor(other)

        # Construct new tensor resulting from the operation
        output_data = self.data + other.data
        output = Tensor(data=output_data, requires_grad=(self.requires_grad or other.requires_grad))
        output._prev = (self, other)
        output._op = "add"

        # Define how the output propagates its gradient back to its parents.
        # For addition, the local derivative with respect to each parent is 1, so the
        # upstream gradient passes straight through (after undoing any broadcasting).
        # output.grad holds the gradient of the final scalar (the loss) wrt the output.
        def _backward():
            if self.requires_grad:
                grad = _unbroadcast(output.grad, self.data.shape)
                self.grad = grad if self.grad is None else self.grad + grad
            if other.requires_grad:
                grad = _unbroadcast(output.grad, other.data.shape)
                other.grad = grad if other.grad is None else other.grad + grad

        output._backward = _backward
        return output

    def __mul__(self, other):
        """Elementwise multiplication with broadcasting, returns a new Tensor and sets up gradient flow to both parents"""
        if not isinstance(other, Tensor):
            other = Tensor(other)

        output_data = self.data * other.data
        output = Tensor(data=output_data, requires_grad=(self.requires_grad or other.requires_grad))
        output._prev = (self, other)
        output._op = "mul"

        def _backward():
            if self.requires_grad:
                grad = output.grad * other.data
                grad = _unbroadcast(grad, self.data.shape)
                self.grad = grad if self.grad is None else self.grad + grad
            if other.requires_grad:
                grad = output.grad * self.data
                grad = _unbroadcast(grad, other.data.shape)
                other.grad = grad if other.grad is None else other.grad + grad

        output._backward = _backward
        return output

    def __rmul__(self, other):
        """Reverse multiplication operator, uses predefined ops"""
        if not isinstance(other, Tensor):
            other = Tensor(other)
        return other * self

    def __matmul__(self, other):
        """Matrix multiplication (@ operator), returns a new Tensor and sets up gradient flow to both parents"""
        if not isinstance(other, Tensor):
            other = Tensor(other)

        output_data = self.data @ other.data
        output = Tensor(data=output_data, requires_grad=(self.requires_grad or other.requires_grad))
        output._prev = (self, other)
        output._op = "matmul"

        def _backward():
            if self.requires_grad:
                if other.data.ndim == 1:
                    grad = np.outer(output.grad, other.data)
                else:
                    grad = output.grad @ other.data.T
                grad = _unbroadcast(grad, self.data.shape)
                self.grad = grad if self.grad is None else self.grad + grad
            if other.requires_grad:
                if self.data.ndim == 1:
                    # reshape keeps the 1-D @ 1-D case (scalar output) in other's shape
                    grad = np.outer(self.data, output.grad).reshape(other.data.shape)
                else:
                    grad = self.data.T @ output.grad
                grad = _unbroadcast(grad, other.data.shape)
                other.grad = grad if other.grad is None else other.grad + grad

        output._backward = _backward
        return output

    def __rmatmul__(self, other):
        """Reverse matrix multiplication, uses predefined ops"""
        if not isinstance(other, Tensor):
            other = Tensor(other)
        return other.__matmul__(self)

    def __neg__(self):
        """Negative operator, uses predefined ops"""
        return self * -1

    def __sub__(self, other):
        """Subtraction operator, uses predefined ops"""
        if not isinstance(other, Tensor):
            other = Tensor(other)
        return self + (-other)

    def __rsub__(self, other):
        """Reverse subtraction operator, uses predefined ops"""
        if not isinstance(other, Tensor):
            other = Tensor(other)
        return other + (-self)

    def __pow__(self, power):
        """Power (** operator) with a scalar exponent, returns a new Tensor and sets up gradient flow to the parent"""
        assert isinstance(power, (int, float)), "only scalar powers supported"
        output_data = self.data ** power
        output = Tensor(output_data, requires_grad=self.requires_grad)
        output._prev = (self,)
        output._op = "pow"

        def _backward():
            if self.requires_grad:
                grad = output.grad * (power * (self.data ** (power - 1)))
                self.grad = grad if self.grad is None else self.grad + grad

        output._backward = _backward
        return output

    def __truediv__(self, other):
        """True division, uses predefined ops"""
        if not isinstance(other, Tensor):
            other = Tensor(other)
        return self * (other ** -1)

    def exp(self):
        """e^x, returns a new Tensor and sets up gradient flow to the parent"""
        output_data = np.exp(self.data)
        output = Tensor(data=output_data, requires_grad=self.requires_grad)
        output._prev = (self,)
        output._op = "exp"

        def _backward():
            if self.requires_grad:
                grad = output.grad * output.data
                self.grad = grad if self.grad is None else self.grad + grad

        output._backward = _backward
        return output

    def log(self):
        """Natural logarithm, returns a new Tensor and sets up gradient flow to the parent"""
        output_data = np.log(self.data)
        output = Tensor(data=output_data, requires_grad=self.requires_grad)
        output._prev = (self,)
        output._op = "log"

        def _backward():
            if self.requires_grad:
                grad = output.grad * (1.0 / self.data)
                self.grad = grad if self.grad is None else self.grad + grad

        output._backward = _backward
        return output

    def relu(self):
        """ReLU activation, returns a new Tensor and sets up gradient flow to the parent"""
        output_data = np.maximum(0, self.data)
        output = Tensor(data=output_data, requires_grad=self.requires_grad)
        output._prev = (self,)
        output._op = "relu"

        def _backward():
            if self.requires_grad:
                grad = output.grad * (self.data > 0).astype(self.data.dtype)
                self.grad = grad if self.grad is None else self.grad + grad

        output._backward = _backward
        return output

    def sigmoid(self):
        """Sigmoid activation, returns a new Tensor and sets up gradient flow to the parent"""
        output_data = 1 / (1 + np.exp(-self.data))
        output = Tensor(data=output_data, requires_grad=self.requires_grad)
        output._prev = (self,)
        output._op = "sigmoid"

        def _backward():
            if self.requires_grad:
                grad = output.grad * (output.data * (1 - output.data))
                self.grad = grad if self.grad is None else self.grad + grad

        output._backward = _backward
        return output

    def tanh(self):
        """Tanh activation, returns a new Tensor and sets up gradient flow to the parent"""
        output_data = np.tanh(self.data)
        output = Tensor(data=output_data, requires_grad=self.requires_grad)
        output._prev = (self,)
        output._op = "tanh"

        def _backward():
            if self.requires_grad:
                grad = output.grad * (1 - output.data ** 2)
                self.grad = grad if self.grad is None else self.grad + grad

        output._backward = _backward
        return output

    def sum(self, axis=None, keepdims=False):
        """Sum of tensor elements over given axes, returns a new Tensor"""
        output_data = self.data.sum(axis=axis, keepdims=keepdims)
        output = Tensor(data=output_data, requires_grad=self.requires_grad)
        output._prev = (self,)
        output._op = "sum"

        def _backward():
            if self.requires_grad:
                grad = output.grad
                if not keepdims and axis is not None:
                    grad = np.expand_dims(grad, axis=axis)
                grad = np.broadcast_to(grad, self.data.shape)
                self.grad = grad if self.grad is None else self.grad + grad

        output._backward = _backward
        return output

    def mean(self, axis=None, keepdims=False):
        """Mean of tensor elements over given axes, returns a new Tensor"""
        output_data = self.data.mean(axis=axis, keepdims=keepdims)
        output = Tensor(data=output_data, requires_grad=self.requires_grad)
        output._prev = (self,)
        output._op = "mean"

        def _backward():
            if self.requires_grad:
                grad = output.grad
                if axis is None:
                    denom = self.data.size
                else:
                    axes = (axis,) if isinstance(axis, int) else axis
                    denom = 1
                    for a in axes:
                        denom *= self.data.shape[a]
                    if not keepdims:
                        grad = np.expand_dims(grad, axis=axis)

                grad = np.ones_like(self.data) * (grad / denom)
                self.grad = grad if self.grad is None else self.grad + grad

        output._backward = _backward
        return output

    def max(self, axis=None, keepdims=False):
        """Max of tensor elements over given axes, returns a new Tensor"""
        output_data = self.data.max(axis=axis, keepdims=keepdims)
        output = Tensor(output_data, requires_grad=self.requires_grad)
        output._prev = (self,)
        output._op = "max"

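        def _backward():
            if self.requires_grad:
                out_data, out_grad = output.data, output.grad
                # Restore the reduced axes so both arrays broadcast against self.data
                if not keepdims and axis is not None:
                    out_data = np.expand_dims(out_data, axis=axis)
                    out_grad = np.expand_dims(out_grad, axis=axis)
                # Route the upstream gradient only to positions holding the maximum
                # (tied maxima each receive the full gradient)
                mask = (self.data == out_data)
                grad = mask * out_grad
                self.grad = grad if self.grad is None else self.grad + grad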

        output._backward = _backward
        return output

    def reshape(self, *shape):
        """Reshape tensor to given shape, returns a new Tensor"""
        output_data = self.data.reshape(*shape)
        output = Tensor(data=output_data, requires_grad=self.requires_grad)
        output._prev = (self,)
        output._op = "reshape"

        def _backward():
            if self.requires_grad:
                grad = output.grad.reshape(self.data.shape)
                self.grad = grad if self.grad is None else self.grad + grad

        output._backward = _backward
        return output

    def transpose(self, *axes):
        """Transpose tensor according to given axes, returns a new Tensor"""
        output_data = self.data.transpose(*axes)
        output = Tensor(data=output_data, requires_grad=self.requires_grad)
        output._prev = (self,)
        output._op = "transpose"

        def _backward():
            if self.requires_grad:
                if axes:
                    inv_axes = np.argsort(axes)
                    grad = output.grad.transpose(*inv_axes)
                else:
                    grad = output.grad.transpose()
                self.grad = grad if self.grad is None else self.grad + grad

        output._backward = _backward
        return output

    def broadcast_to(self, shape):
        """Broadcast tensor to given shape, returns a new Tensor"""
        output_data = np.broadcast_to(self.data, shape)
        output = Tensor(data=output_data, requires_grad=self.requires_grad)
        output._prev = (self,)
        output._op = "broadcast_to"

        def _backward():
            if self.requires_grad:
                grad = _unbroadcast(output.grad, self.data.shape)
                self.grad = grad if self.grad is None else self.grad + grad

        output._backward = _backward
        return output

    def __getitem__(self, idx):
        """Slice operation ([...] to retrieve), returns a new Tensor"""
        output_data = self.data[idx]
        output = Tensor(output_data, requires_grad=self.requires_grad)
        output._prev = (self,)
        output._op = "slice"

        def _backward():
            if self.requires_grad:
                grad = np.zeros_like(self.data)
                np.add.at(grad, idx, output.grad)
                self.grad = grad if self.grad is None else self.grad + grad

        output._backward = _backward
        return output

### Testing

In [287]:
x = Tensor([1, 2], requires_grad=True)
W = Tensor([[1, 2], [1, 1]], requires_grad=True)
b = Tensor([1, 1], requires_grad=True)
y = W @ x + b
label = Tensor([5, 5])
loss = ((y - label) ** 2).sum()
loss.backward()

print("x:", x)
print("W:", W)
print("b:", b)
print("y:", y)
print("label:", label)
print("loss:", loss)

x: Tensor(data=[1. 2.], grad=[0. 2.], op=)
W: Tensor(data=[[1. 2.]
 [1. 1.]], grad=[[ 2.  4.]
 [-2. -4.]], op=)
b: Tensor(data=[1. 1.], grad=[ 2. -2.], op=)
y: Tensor(data=[6. 4.], grad=[ 2. -2.], op=add)
label: Tensor(data=[5. 5.], grad=None, op=)
loss: Tensor(data=2.0, grad=1.0, op=sum)


In [288]:
a = Tensor([3.0], requires_grad=True)
b = Tensor([4.0], requires_grad=True)
c = (a + b).sum()
c.backward()
print("add:", a.grad, b.grad, "(expected 1, 1)")

a = Tensor([3.0], requires_grad=True)
b = Tensor([4.0], requires_grad=True)
c = (a * b).sum()
c.backward()
print("mul:", a.grad, b.grad, "(expected 4, 3)")

x = Tensor([2.0], requires_grad=True)
y = (x ** 3).sum()
y.backward()
print("pow:", x.grad, "(expected 12)")

a = Tensor([6.0], requires_grad=True)
b = Tensor([2.0], requires_grad=True)
c = (a / b).sum()
c.backward()
print("div:", a.grad, b.grad, "(expected 1/2, -6/4 = -1.5)")

x = Tensor([2.0], requires_grad=True)
y = x.exp().sum()
y.backward()
print("exp:", x.grad, "(expected 7.389)")

x = Tensor([2.0], requires_grad=True)
y = x.log().sum()
y.backward()
print("log:", x.grad, "(expected 0.5)")

x = Tensor([-1.0, 2.0], requires_grad=True)
y = x.relu().sum()
y.backward()
print("relu:", x.grad, "(expected [0, 1])")

x = Tensor([0.0], requires_grad=True)
y = x.sigmoid().sum()
y.backward()
print("sigmoid:", x.grad, "(expected 0.25)")

x = Tensor([0.0], requires_grad=True)
y = x.tanh().sum()
y.backward()
print("tanh:", x.grad, "(expected 1.0)")

W = Tensor([[1.0, 2.0]], requires_grad=True)
x = Tensor([3.0, 4.0], requires_grad=True)
y = (W @ x).sum()
y.backward()
print("matmul:", W.grad, x.grad, "(expected [[3, 4]], [1, 2])")

x = Tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x.sum()
y.backward()
print("sum:", x.grad, "(expected [1, 1, 1])")

x = Tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x.mean()
y.backward()
print("mean:", x.grad, "(expected [1/3, 1/3, 1/3])")

x = Tensor([[1.0, 2.0]], requires_grad=True)
y = x.broadcast_to((3,2))
z = y.sum()
z.backward()
print("broadcast:", x.grad, "(expected [[3, 3]])")

x = Tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
y = x.reshape(4,)
z = y.transpose().sum()
z.backward()
print("reshape+transpose:", x.grad, "(expected all ones)")

add: [1.] [1.] (expected 1, 1)
mul: [4.] [3.] (expected 4, 3)
pow: [12.] (expected 12)
div: [0.5] [-1.5] (expected 1/2, -6/4 = -1.5)
exp: [7.389] (expected 7.389)
log: [0.5] (expected 0.5)
relu: [0. 1.] (expected [0, 1])
sigmoid: [0.25] (expected 0.25)
tanh: [1.] (expected 1.0)
matmul: [[3. 4.]] [1. 2.] (expected [[3, 4]], [1, 2])
sum: [1. 1. 1.] (expected [1, 1, 1])
mean: [0.333 0.333 0.333] (expected [1/3, 1/3, 1/3])
broadcast: [[3. 3.]] (expected [[3, 3]])
reshape+transpose: [[1. 1.]
 [1. 1.]] (expected all ones)
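
The checks above compare against hand-derived values. A complementary approach is to compare against a numerical estimate. The sketch below is NumPy-only, and `numerical_grad` is a hypothetical helper that is not part of the engine. It recomputes the gradient with respect to `W` for the `W @ x + b` example from the first test cell using central differences.

```python
import numpy as np

def numerical_grad(f, x, eps=1e-6):
    """Central-difference estimate of df/dx for a scalar-valued f (hypothetical helper)."""
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        orig = x[idx]
        x[idx] = orig + eps
        f_plus = f(x)
        x[idx] = orig - eps
        f_minus = f(x)
        x[idx] = orig                      # restore the entry before moving on
        grad[idx] = (f_plus - f_minus) / (2 * eps)
    return grad

W = np.array([[1.0, 2.0], [1.0, 1.0]])
x = np.array([1.0, 2.0])
b = np.array([1.0, 1.0])
label = np.array([5.0, 5.0])

def loss(W_):
    return (((W_ @ x + b) - label) ** 2).sum()

dW = numerical_grad(loss, W)
expected = np.array([[2.0, 4.0], [-2.0, -4.0]])   # W.grad printed by the first test cell
print(np.allclose(dW, expected, atol=1e-5))       # True
```

The same helper could be pointed at any other op by wrapping its forward computation in a scalar-valued function.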


## Optimizers

As we saw in numerical optimization, once backpropagation has filled in `.grad` for each `Tensor`, there are several strategies for turning those gradients into parameter updates. Each strategy is controlled by hyperparameters such as the learning rate.
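
For reference, the update rules that the three optimizer classes implement can be written as follows. The symbols are notation chosen for this summary, not names from the code: $\theta$ is a parameter, $g$ its gradient, $\eta$ the learning rate (`lr`), $\mu$ the momentum coefficient, $t$ the step count, and $m$, $s$ the first and second moment estimates (stored in the code as `momentum` and `velocities`).

$$
\begin{aligned}
\text{SGD:}\quad & \theta \leftarrow \theta - \eta\, g \\
\text{SGD with momentum:}\quad & v \leftarrow \mu\, v - \eta\, g, \qquad \theta \leftarrow \theta + v \\
\text{Adam:}\quad & m \leftarrow \beta_1 m + (1-\beta_1)\, g, \qquad s \leftarrow \beta_2 s + (1-\beta_2)\, g^2 \\
& \hat m = \frac{m}{1-\beta_1^{t}}, \qquad \hat s = \frac{s}{1-\beta_2^{t}}, \qquad \theta \leftarrow \theta - \eta\, \frac{\hat m}{\sqrt{\hat s} + \epsilon}
\end{aligned}
$$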

In [289]:
class SGD:
    """Optimizer for performing stochastic gradient descent"""

    def __init__(self, params, lr=1e-3):
        self.params = [param for param in params if param.requires_grad]
        self.lr = lr

    def step(self):
        """Update .data based on the .grad and learning parameters"""
        for param in self.params:
            if param.grad is not None:
                param.data = param.data - self.lr * param.grad

    def zero_grad(self):
        """Resets the gradients of the parameters given to the optimizer"""
        for param in self.params:
            param.grad = None


class SGDMomentum:
    """Optimizer for performing stochastic gradient descent with momentum"""

    def __init__(self, params, lr=1e-2, momentum=0.9):
        self.params = [param for param in params if param.requires_grad]
        self.lr = lr
        self.momentum = momentum
        self.velocities = [np.zeros_like(param.data) for param in self.params]

    def step(self):
        """Update .data based on the .grad and learning parameters"""
        for idx, param in enumerate(self.params):
            if param.grad is not None:
                # Update velocities for next step
                self.velocities[idx] = self.momentum * self.velocities[idx] - self.lr * param.grad
                param.data = param.data + self.velocities[idx]

    def zero_grad(self):
        """Resets the gradients of the parameters given to the optimizer"""
        for param in self.params:
            param.grad = None


class Adam:
    """Optimizer for performing Adam with first and second moments"""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        self.params = [param for param in params if param.requires_grad]
        self.lr = lr
        self.betas = betas
        self.eps = eps
        self.time = 0
        self.momentum = [np.zeros_like(param.data) for param in self.params]
        self.velocities = [np.zeros_like(param.data) for param in self.params]

    def step(self):
        """Update .data based on the .grad and learning parameters"""
        self.time += 1
        beta1, beta2 = self.betas
        for idx, param in enumerate(self.params):
            if param.grad is None:
                continue
            grad = param.grad

            # Update first and second moment estimates
            self.momentum[idx] = beta1 * self.momentum[idx] + (1 - beta1) * grad
            self.velocities[idx] = beta2 * self.velocities[idx] + (1 - beta2) * (grad * grad)

            # Correct bias in estimates
            m_hat = self.momentum[idx] / (1 - beta1 ** self.time)
            v_hat = self.velocities[idx] / (1 - beta2 ** self.time)

            param.data -= self.lr * m_hat / (np.sqrt(v_hat) + self.eps)

    def zero_grad(self):
        """Resets the gradients of the parameters given to the optimizer"""
        for param in self.params:
            param.grad = None
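
One consequence of the bias correction in `Adam.step` is that the very first update has magnitude close to `lr`, whatever the scale of the gradient. The NumPy-only sketch below replays the first step by hand for three gradient magnitudes. The value `lr = 0.1` and the gradients are illustrative.

```python
import numpy as np

# betas and eps mirror the Adam defaults above; lr is an illustrative value.
lr, beta1, beta2, eps = 0.1, 0.9, 0.999, 1e-8

first_steps = []
for g in (0.01, 1.0, 100.0):
    # Moments start at zero, so after one update they are scaled copies of g and g**2.
    m = (1 - beta1) * g
    v = (1 - beta2) * g * g

    # Bias correction at t = 1 undoes that scaling: m_hat ~ g, v_hat ~ g**2.
    m_hat = m / (1 - beta1 ** 1)
    v_hat = v / (1 - beta2 ** 1)

    first_steps.append(lr * m_hat / (np.sqrt(v_hat) + eps))

print(first_steps)  # each entry is close to lr
```

Plain SGD would instead take steps proportional to `g`, so these three cases would differ by four orders of magnitude.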

### Testing

In [290]:
def test_optimizer(optimizer_cls):
    print("Testing", optimizer_cls.__name__)
    w = Tensor([0.0], requires_grad=True)
    optimizer = optimizer_cls([w], lr=0.1)

    for step in range(20):
        # Forward: simple quadratic loss
        loss = ((w - Tensor([5.0])) ** 2).sum()
        loss.backward()

        # Optimizer step
        optimizer.step()
        optimizer.zero_grad()

        print(f"step {step:2d}: w = {w.data}, loss = {loss.data}")


test_optimizer(SGD)
print("\n\n==\n\n")
test_optimizer(SGDMomentum)
print("\n\n==\n\n")
test_optimizer(Adam)


Testing SGD
step  0: w = [1.], loss = 25.0
step  1: w = [1.8], loss = 16.0
step  2: w = [2.44], loss = 10.240000000000002
step  3: w = [2.952], loss = 6.553599999999998
step  4: w = [3.362], loss = 4.194303999999998
step  5: w = [3.689], loss = 2.6843545599999996
step  6: w = [3.951], loss = 1.7179869183999996
step  7: w = [4.161], loss = 1.0995116277759995
step  8: w = [4.329], loss = 0.7036874417766399
step  9: w = [4.463], loss = 0.4503599627370493
step 10: w = [4.571], loss = 0.2882303761517112
step 11: w = [4.656], loss = 0.184467440737095
step 12: w = [4.725], loss = 0.11805916207174093
step 13: w = [4.78], loss = 0.07555786372591429
step 14: w = [4.824], loss = 0.04835703278458515
step 15: w = [4.859], loss = 0.030948500982134555
step 16: w = [4.887], loss = 0.019807040628566166
step 17: w = [4.91], loss = 0.012676506002282305
step 18: w = [4.928], loss = 0.008112963841460612
step 19: w = [4.942], loss = 0.005192296858534741


==


Testing SGDMomentum
step  0: w = [1.], loss = 2

## MLP Library

Now that we have the `Tensor` class and optimizers, let's create the modules that will allow us to build a simple neural network.

In [291]:
class Module:
    """Base class to construct neural networks"""

    def parameters(self):
        return []

    def zero_grad(self):
        """Reset the gradients in the entire network"""
        for param in self.parameters():
            param.grad = None


class Linear(Module):
    """Construct a linear or fully-connected layer"""

    def __init__(self, in_dim, out_dim):
        self.W = Tensor(np.random.randn(out_dim, in_dim) * 0.01, requires_grad=True)
        self.b = Tensor(np.zeros(out_dim), requires_grad=True)

    def __call__(self, x):
        if not isinstance(x, Tensor):
            x = Tensor(x)

        # If single sample, add a batch dimension
        added_batch = (x.data.ndim == 1)
        if added_batch:
            x = x.reshape(1, -1)

        # (N, in_dim) @ (in_dim, out_dim) + (out_dim,) -> (N, out_dim)
        out = x @ self.W.transpose() + self.b

        # If original input was 1D, squeeze back to (out_dim,)
        if added_batch:
            # Prefer reshape over indexing to avoid scatter-style backward
            out = out.reshape(-1)

        return out

    def parameters(self):
        return [self.W, self.b]


class ReLU(Module):
    def __call__(self, x):
        return x.relu()


class Tanh(Module):
    def __call__(self, x):
        return x.tanh()


class Sigmoid(Module):
    def __call__(self, x):
        return x.sigmoid()


class MLP(Module):
    """Construct a neural network of fully-connected layers with activation functions or
       non-linearities (ReLU, Tanh, Sigmoid)"""

    def __init__(self, in_dim, hidden, out_dim, activation=ReLU):
        """
        Constructor for neural network

        Params:
            in_dim: int, number of dimensions for the input layer
            hidden: list of int, list of dimensions for the hidden layers
            out_dim: int, number of dimensions for the output layer
            activation: Module subclass (ReLU, Tanh, or Sigmoid), instantiated after every linear layer except the last
        """
        dims = [in_dim] + hidden + [out_dim]

        self.layers = []
        for idx in range(len(dims) - 1):
            # Add linear layer
            self.layers.append(Linear(dims[idx], dims[idx + 1]))
            # Add activation after every linear except the last one
            if idx < len(dims) - 2:
                self.layers.append(activation())

    def __call__(self, x):
        out = x
        for layer in self.layers:
            out = layer(out)
        return out

    def parameters(self):
        return [param for layer in self.layers for param in layer.parameters()]
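
To make the shapes concrete, the sketch below replays the layer-stacking logic of `MLP` with plain NumPy arrays. The sizes are illustrative and the code does not use the classes above. It alternates affine maps and ReLU, skips the activation after the last layer, and counts the parameters.

```python
import numpy as np

rng = np.random.default_rng(0)

# in_dim=4, one hidden layer of width 8, out_dim=3 (illustrative sizes)
dims = [4, 8, 3]
out = rng.standard_normal((5, dims[0]))   # batch of 5 samples

n_params = 0
shapes = []
for idx in range(len(dims) - 1):
    # Same layout as Linear: W is (out_dim, in_dim), b is (out_dim,)
    W = rng.standard_normal((dims[idx + 1], dims[idx])) * 0.01
    b = np.zeros(dims[idx + 1])
    n_params += W.size + b.size

    # (N, in_dim) @ (in_dim, out_dim) + (out_dim,) -> (N, out_dim)
    out = out @ W.T + b

    # Activation after every layer except the last one
    if idx < len(dims) - 2:
        out = np.maximum(0, out)
    shapes.append(out.shape)

print(shapes)     # [(5, 8), (5, 3)]
print(n_params)   # 67
```

The final layer is left without an activation so that its outputs can be used as raw logits or regression values.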

## Loss and Training Functions

Let's add loss functions (mean-squared error for regression, cross-entropy for classification) and a function to perform the training loop.

In [292]:
def mse_loss(pred, target):
    """Loss function that computes the mean-squared error"""
    return ((pred - target) ** 2).mean()


def cross_entropy_loss(logits, target_idx):
    """
    Computes cross-entropy loss between raw logits and class index targets

    Params:
        logits: Tensor, the shape is (num_classes,) or (batch, num_classes)
        target_idx: int or array of int, the correct class index (single sample) or indices (mini-batch)

    Returns:
        loss: Tensor, scalar cross-entropy loss
    """
    # Shift logits for numerical stability
    logits_stable = logits - logits.max(axis=-1, keepdims=True)

    # Log-softmax: subtract log(sum(exp(logits))) to turn logits into log probabilities
    log_probs = logits_stable - logits_stable.exp().sum(axis=-1, keepdims=True).log()

    # Single sample
    if np.ndim(target_idx) == 0:
        return -log_probs[target_idx]
    # Mini-batch
    else:
        idx = np.arange(len(target_idx))
        return (-log_probs[idx, target_idx]).mean()


def train(model, data, labels, optimizer, loss_fn, epochs=100, batch_size=32):
    """
    Train a model using the given optimizer and loss function

    Params:
        model: Module, the model to train (must implement __call__ and parameters())
        data: np.ndarray, input samples with the sample index on axis 0
        labels: np.ndarray, target outputs aligned with data along axis 0
        optimizer: object, optimizer with step() and zero_grad() methods
        loss_fn: function, computes a scalar loss given (pred, target)
        epochs: int, number of passes over the dataset
        batch_size: int, number of samples per batch
    """
    num_samples = len(data)
    for epoch in range(epochs):
        epoch_loss = 0.0

        # Iterate in mini-batches
        for start in range(0, num_samples, batch_size):
            end = start + batch_size
            x_batch = data[start:end]
            # Flatten each sample to a 1-D feature vector
            x_batch = x_batch.reshape(len(x_batch), -1)
            y_batch = labels[start:end]

            # Forward pass (model should handle batch input)
            preds = model(x_batch)

            # Compute loss (reduces to scalar inside loss_fn)
            loss = loss_fn(preds, y_batch)

            # Backward and optimizer step
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            epoch_loss += float(loss.data) * len(y_batch)

        avg_loss = epoch_loss / num_samples
        print(f"epoch {epoch}: loss {avg_loss:.4f}")
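
The max-shift in `cross_entropy_loss` matters once logits get large. The NumPy-only sketch below compares a naive log-softmax with the shifted version on deliberately extreme logits. The `log_softmax` function is a hypothetical helper that mirrors the two lines in `cross_entropy_loss` without using `Tensor`.

```python
import numpy as np

def log_softmax(logits):
    """Shifted log-softmax along the last axis (hypothetical NumPy-only helper)."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

logits = np.array([[1000.0, 1001.0, 1002.0],
                   [0.0, 0.0, 0.0]])
targets = np.array([2, 1])

# Naive version: exp(1000) overflows to inf, and inf / inf gives nan
with np.errstate(over="ignore", invalid="ignore"):
    exp = np.exp(logits)
    naive = np.log(exp / exp.sum(axis=-1, keepdims=True))

stable = log_softmax(logits)

# Mini-batch cross-entropy: pick the log probability of the target class per row
loss = -stable[np.arange(len(targets)), targets].mean()

print(np.isnan(naive[0]).all())       # True
print(np.isfinite(stable).all())      # True
print(loss)
```

Subtracting the row maximum leaves the softmax unchanged mathematically, because the shift cancels between the numerator and the denominator.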

## Training on MNIST

Let's see if we can construct a neural network and train it on MNIST to validate the implementation.

In [293]:
from datasets import load_dataset

dataset = load_dataset("mnist")

dataset_train = dataset["train"]
dataset_test = dataset["test"]

X_train = np.array([np.array(img).reshape(-1) / 255.0 for img in dataset_train["image"]])
y_train = np.array(dataset_train["label"])

X_test = np.array([np.array(img).reshape(-1) / 255.0 for img in dataset_test["image"]])
y_test = np.array(dataset_test["label"])

In [294]:
neural_network = MLP(784, [256, 128, 10], 10)
optimizer = Adam(params=neural_network.parameters())
train(neural_network, X_train, y_train, optimizer, cross_entropy_loss)

epoch 0: loss 0.7274
epoch 1: loss 0.2214
epoch 2: loss 0.1289
epoch 3: loss 0.0890
epoch 4: loss 0.0639
epoch 5: loss 0.0497
epoch 6: loss 0.0386
epoch 7: loss 0.0312
epoch 8: loss 0.0268
epoch 9: loss 0.0233
epoch 10: loss 0.0215
epoch 11: loss 0.0195
epoch 12: loss 0.0154
epoch 13: loss 0.0151
epoch 14: loss 0.0165
epoch 15: loss 0.0144
epoch 16: loss 0.0117
epoch 17: loss 0.0108
epoch 18: loss 0.0157
epoch 19: loss 0.0099
epoch 20: loss 0.0109
epoch 21: loss 0.0118
epoch 22: loss 0.0079
epoch 23: loss 0.0107
epoch 24: loss 0.0085
epoch 25: loss 0.0094
epoch 26: loss 0.0086
epoch 27: loss 0.0086
epoch 28: loss 0.0077
epoch 29: loss 0.0057
epoch 30: loss 0.0113
epoch 31: loss 0.0079
epoch 32: loss 0.0050
epoch 33: loss 0.0092
epoch 34: loss 0.0057
epoch 35: loss 0.0051
epoch 36: loss 0.0071
epoch 37: loss 0.0076
epoch 38: loss 0.0065
epoch 39: loss 0.0070
epoch 40: loss 0.0040
epoch 41: loss 0.0058
epoch 42: loss 0.0062
epoch 43: loss 0.0058
epoch 44: loss 0.0049
epoch 45: loss 0.007

In [296]:
preds = neural_network(X_test)
loss = cross_entropy_loss(preds, y_test)
print("test loss:", float(loss.data))

pred_labels = preds.data.argmax(axis=1)
accuracy = (pred_labels == y_test).mean()
print("test accuracy:", accuracy)

test loss: 0.2806237125266356
test accuracy: 0.9786
