# Theoretical energy of deep linear networks

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/thebuckleylab/jpc/blob/main/examples/analytical_test.ipynb)

In [20]:
%%capture
!pip install torch==2.3.1
!pip install torchvision==0.18.1
!pip install plotly==5.11.0
!pip install -U kaleido

In [21]:
import jpc

import jax
import equinox as eqx
import equinox.nn as nn
import optax

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import plotly.graph_objs as go
import plotly.io as pio

pio.renderers.default = 'iframe'

## Hyperparameters

We define some global parameters, including network architecture, learning rate, batch size etc.

In [22]:
SEED = 0
LEARNING_RATE = 1e-3
BATCH_SIZE = 64
TEST_EVERY = 10
N_TRAIN_ITERS = 100

## Dataset

Some utils to fetch MNIST.

In [23]:
#@title data utils


def get_mnist_loaders(batch_size):
    train_data = MNIST(train=True, normalise=True)
    test_data = MNIST(train=False, normalise=True)
    train_loader = DataLoader(
        dataset=train_data,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True
    )
    test_loader = DataLoader(
        dataset=test_data,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True
    )
    return train_loader, test_loader


class MNIST(datasets.MNIST):
    def __init__(self, train, normalise=True, save_dir="data"):
        if normalise:
            transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(
                        mean=(0.1307), std=(0.3081)
                    )
                ]
            )
        else:
            transform = transforms.Compose([transforms.ToTensor()])
        super().__init__(save_dir, download=True, train=train, transform=transform)

    def __getitem__(self, index):
        img, label = super().__getitem__(index)
        img = torch.flatten(img)
        label = one_hot(label)
        return img, label


def one_hot(labels, n_classes=10):
    arr = torch.eye(n_classes)
    return arr[labels]
    

## Plotting

In [24]:
def plot_energies(energies):
    n_train_iters = len(energies["theory"])
    train_iters = [b+1 for b in range(n_train_iters)]

    fig = go.Figure()
    for energy_type, energy in energies.items():
        is_theory = energy_type == "theory"
        fig.add_traces(
            go.Scatter(
                x=train_iters,
                y=energy,
                name=energy_type,
                mode="lines",
                line=dict(
                    width=3, 
                    dash="dash" if is_theory else "solid"
                ),
                legendrank=1 if is_theory else 2
            )
        )

    fig.update_layout(
        height=300,
        width=400,
        xaxis=dict(
            title="Training iteration",
            tickvals=[1, int(train_iters[-1]/2), train_iters[-1]],
            ticktext=[1, int(train_iters[-1]/2), train_iters[-1]],
        ),
        yaxis=dict(
            title="Energy",
            nticks=3
        ),
        font=dict(size=16),
    )
    fig.write_image("dln_energy_example.pdf")
    return fig

## Linear network

In [25]:
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 5)

network = [
    eqx.nn.Linear(784, 300, key=subkeys[0], use_bias=False),
    eqx.nn.Linear(300, 300, key=subkeys[1], use_bias=False),
    eqx.nn.Linear(300, 300, key=subkeys[2], use_bias=False),
    eqx.nn.Linear(300, 300, key=subkeys[3], use_bias=False),
    eqx.nn.Linear(300, 300, key=subkeys[4], use_bias=False),
    eqx.nn.Linear(300, 300, key=subkeys[1], use_bias=False),
    eqx.nn.Linear(300, 300, key=subkeys[2], use_bias=False),
    eqx.nn.Linear(300, 300, key=subkeys[3], use_bias=False),
    eqx.nn.Linear(300, 300, key=subkeys[4], use_bias=False),
    eqx.nn.Linear(300, 300, key=subkeys[4], use_bias=False),
    eqx.nn.Linear(300, 10, key=subkeys[5], use_bias=False),
]

## Train and test

A PC network can be trained in a single line of code with `jpc.make_pc_step()`. See the documentation for more. Similarly, we can use `jpc.test_discriminative_pc()` to compute the network accuracy. Note that these functions are already "jitted" for performance.

Below we simply wrap each of these functions in our training and test loops, respectively.

In [26]:
def evaluate(model, test_loader):
    test_acc = 0
    for batch_id, (img_batch, label_batch) in enumerate(test_loader):
        img_batch = img_batch.numpy()
        label_batch = label_batch.numpy()

        test_acc += jpc.test_discriminative_pc(
            model=model,
            y=label_batch,
            x=img_batch
        )

    return test_acc / len(test_loader)


def train(
      model,  
      lr,
      batch_size,
      test_every,
      n_train_iters
):
    optim = optax.adam(lr)
    opt_state = optim.init(eqx.filter(model, eqx.is_array))
    train_loader, test_loader = get_mnist_loaders(batch_size)

    num_energies, theory_energies = [], []
    for iter, (img_batch, label_batch) in enumerate(train_loader):
        img_batch = img_batch.numpy()
        label_batch = label_batch.numpy()

        theory_energies.append(
            jpc.linear_equilib_energy(
                network=model, 
                x=img_batch, 
                y=label_batch
            )
        )
        result = jpc.make_pc_step(
            model,
            optim,
            opt_state,
            y=label_batch,
            x=img_batch,
            record_energies=True
        )
        model, optim, opt_state = result["model"], result["optim"], result["opt_state"]
        train_loss, t_max = result["loss"], result["t_max"]
        num_energies.append(result["energies"][:, t_max-1].sum())

        if ((iter+1) % test_every) == 0:
            avg_test_acc = evaluate(model, test_loader)
            print(
                f"Train iter {iter+1}, train loss={train_loss:4f}, "
                f"avg test accuracy={avg_test_acc:4f}"
            )
            if (iter+1) >= n_train_iters:
                break

    return {
        "experiment": num_energies,
        "theory": theory_energies
    }


## Run

In [27]:
energies = train(
    model=network,
    lr=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    test_every=TEST_EVERY,
    n_train_iters=N_TRAIN_ITERS
)
plot_energies(energies)

Train iter 10, train loss=0.077298, avg test accuracy=0.642328
Train iter 20, train loss=0.071461, avg test accuracy=0.719551
Train iter 30, train loss=0.053858, avg test accuracy=0.739183
Train iter 40, train loss=0.055664, avg test accuracy=0.753506
Train iter 50, train loss=0.049541, avg test accuracy=0.805188
Train iter 60, train loss=0.048094, avg test accuracy=0.750701
Train iter 70, train loss=0.046416, avg test accuracy=0.817107
Train iter 80, train loss=0.051727, avg test accuracy=0.834135
Train iter 90, train loss=0.053808, avg test accuracy=0.827825
Train iter 100, train loss=0.045950, avg test accuracy=0.833834
