# Theoretical energy of deep linear networks

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/thebuckleylab/jpc/blob/main/examples/analytical_test.ipynb)

In [1]:
%%capture
# Install pinned versions of the extra packages this notebook needs;
# the %%capture cell magic above hides the installer output.
# torch / torchvision are used only for downloading and batching MNIST.
!pip install torch==2.3.1
!pip install torchvision==0.18.1
# plotly draws the energy figure; kaleido is plotly's static-image export
# backend, needed by the `fig.write_image(...)` call in the plotting cell.
!pip install plotly==5.11.0
!pip install -U kaleido

In [2]:
import jpc

import jax
from jax import vmap
import jax.numpy as jnp
import equinox as eqx
import equinox.nn as nn
import optax

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import plotly.graph_objs as go
import plotly.io as pio

# Display plotly figures through the 'iframe' renderer by default.
pio.renderers.default = 'iframe'

import warnings
# Silence every warning category for the remainder of the session.
warnings.simplefilter('ignore')  # ignore warnings

## Hyperparameters

We define some global parameters, including the learning rate, batch size, and number of training iterations. The network architecture itself is set in the "Linear network" section below.

In [3]:
SEED = 0  # random seed
LEARNING_RATE = 1e-3  # learning rate handed to the Adam optimiser (optax.adam)
BATCH_SIZE = 64  # samples per batch for both the train and test loaders
# Forwarded as `max_t1` to jpc.make_pc_step; presumably the upper limit of
# the inference integration -- TODO confirm against the jpc documentation.
MAX_T1 = 300
TEST_EVERY = 10  # evaluate on the test set every this many training iterations
N_TRAIN_ITERS = 100  # target number of training iterations (one batch each)

## Dataset

Some utils to fetch MNIST.

In [4]:
def get_mnist_loaders(batch_size):
    """Build the MNIST train and test data loaders.

    Both splits are normalised, shuffled, and batched with ``batch_size``;
    the final incomplete batch of each split is dropped so that every batch
    has the same size.

    Args:
        batch_size: number of samples per batch for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects.
    """
    # Settings shared by the two loaders.
    shared = {"batch_size": batch_size, "shuffle": True, "drop_last": True}

    # Train split first, then test split (same order as the return value).
    splits = [MNIST(train=is_train, normalise=True) for is_train in (True, False)]
    train_loader, test_loader = (
        DataLoader(dataset=split, **shared) for split in splits
    )
    return train_loader, test_loader


class MNIST(datasets.MNIST):
    """MNIST dataset yielding flattened images and one-hot encoded labels.

    Thin wrapper around ``torchvision.datasets.MNIST`` that always downloads
    the data if missing and converts each sample to the format expected by
    the rest of this notebook.

    Args:
        train: whether to load the training split (``True``) or the test
            split (``False``).
        normalise: if ``True``, standardise the pixel values after
            converting to a tensor, using the commonly quoted MNIST
            mean (0.1307) and standard deviation (0.3081).
        save_dir: directory where the dataset is stored/downloaded.
    """

    def __init__(self, train, normalise=True, save_dir="data"):
        steps = [transforms.ToTensor()]
        if normalise:
            # One-element tuples: MNIST has a single channel and
            # ``Normalize`` is documented to take per-channel sequences.
            # (``(0.1307)`` without the trailing comma is just a float.)
            steps.append(
                transforms.Normalize(mean=(0.1307,), std=(0.3081,))
            )
        transform = transforms.Compose(steps)
        super().__init__(save_dir, download=True, train=train, transform=transform)

    def __getitem__(self, index):
        """Return the ``index``-th sample as ``(flat_image, one_hot_label)``."""
        img, label = super().__getitem__(index)
        img = torch.flatten(img)  # collapse the image tensor to 1-D
        label = one_hot(label)  # integer class -> one-hot vector
        return img, label


def one_hot(labels, n_classes=10):
    """One-hot encode integer class labels.

    Args:
        labels: a class index, or a collection/tensor of class indices,
            used to index the rows of an identity matrix.
        n_classes: total number of classes (length of each encoding).

    Returns:
        The selected row(s) of the ``n_classes x n_classes`` identity
        matrix, as a float tensor.
    """
    # Row ``i`` of the identity matrix is the one-hot vector for class ``i``.
    return torch.eye(n_classes)[labels]
    

## Plotting

In [5]:
def plot_total_energies(energies):
    """Plot total energy curves against training iteration.

    One line is drawn per entry of ``energies``. The entry keyed
    ``"theory"`` is drawn dashed and ranked first in the legend; every other
    entry is drawn solid. The figure is also written to
    ``dln_total_energy.pdf`` in the working directory.

    Args:
        energies: mapping from a curve name to a sequence of energy values,
            one per training iteration. Must contain a ``"theory"`` entry,
            whose length sets the x-axis.

    Returns:
        The plotly ``Figure``.
    """
    # 1-based iteration numbers for the x-axis.
    train_iters = list(range(1, len(energies["theory"]) + 1))

    fig = go.Figure()
    for label, values in energies.items():
        if label == "theory":
            dash_style, colour, rank = "dash", "rgb(27, 158, 119)", 1
        else:
            dash_style, colour, rank = "solid", "#00CC96", 2

        trace = go.Scatter(
            x=train_iters,
            y=values,
            name=label,
            mode="lines",
            line=dict(width=3, dash=dash_style, color=colour),
            legendrank=rank,
        )
        fig.add_traces(trace)

    # Ticks at the first, middle, and last training iteration.
    final_iter = train_iters[-1]
    x_ticks = [1, int(final_iter / 2), final_iter]

    x_axis = dict(
        title="Training iteration",
        tickvals=x_ticks,
        ticktext=list(x_ticks),
    )
    y_axis = dict(title="Energy", nticks=3)
    fig.update_layout(
        height=300,
        width=450,
        xaxis=x_axis,
        yaxis=y_axis,
        font=dict(size=16),
    )
    fig.write_image("dln_total_energy.pdf")
    return fig


## Linear network

In [6]:
# Deep linear network: 784 inputs, `n_hidden` hidden layers of size `width`,
# 10 outputs, linear activations, and no biases.
# The PRNG key is derived from the SEED hyperparameter rather than a
# hard-coded literal, so the seed is controlled in one place.
key = jax.random.PRNGKey(SEED)
width, n_hidden = 300, 10
network = jpc.make_mlp(
    key,
    [784] + [width]*n_hidden + [10],
    act_fn="linear",
    use_bias=False
)

## Train and test

A PC network can be trained in a single line of code with `jpc.make_pc_step()`. See the documentation for more. Similarly, we can use `jpc.test_discriminative_pc()` to compute the network accuracy. Note that these functions are already "jitted" for performance.

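Below we simply wrap each of these functions in our training and test loops, respectively. At every training iteration we also record the theoretical energy given by `jpc.linear_equilib_energy()` so that it can be compared with the energy obtained numerically from `jpc.make_pc_step()`.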

In [7]:
def evaluate(model, test_loader):
    """Average the test loss and accuracy of ``model`` over ``test_loader``.

    Args:
        model: network passed on to ``jpc.test_discriminative_pc``.
        test_loader: iterable of ``(image_batch, label_batch)`` torch
            tensors that also supports ``len()`` (number of batches).

    Returns:
        ``(mean_loss, mean_accuracy)``: the per-batch values returned by
        ``jpc.test_discriminative_pc`` averaged over all batches.
    """
    loss_total, acc_total = 0, 0
    for images, labels in test_loader:
        # Convert the torch tensors to numpy arrays before handing them to jpc.
        images, labels = images.numpy(), labels.numpy()

        batch_loss, batch_acc = jpc.test_discriminative_pc(
            model=model,
            output=labels,
            input=images
        )
        loss_total += batch_loss
        acc_total += batch_acc

    n_batches = len(test_loader)
    return loss_total / n_batches, acc_total / n_batches


def train(
    model,
    lr,
    batch_size,
    max_t1,
    test_every,
    n_train_iters
):
    """Train ``model`` on MNIST and record theoretical vs. numerical energies.

    For every training batch, the theoretical energy is computed with
    ``jpc.linear_equilib_energy`` on the current (pre-update) model, then one
    update is performed with ``jpc.make_pc_step`` and the numerically
    obtained total energy is recorded from its result.

    Args:
        model: network to train (updated functionally; the caller's object
            is not modified by this function's own code).
        lr: learning rate for the Adam optimiser.
        batch_size: batch size for the MNIST train and test loaders.
        max_t1: forwarded as ``max_t1`` to ``jpc.make_pc_step``.
        test_every: evaluate and print test accuracy every this many
            training iterations.
        n_train_iters: number of training iterations to run. Training stops
            after exactly this many iterations (or earlier if the training
            loader is exhausted), whether or not it is a multiple of
            ``test_every``.

    Returns:
        A dict with two arrays holding one value per training iteration:
        ``"experiment"`` (numerical total energies) and ``"theory"``
        (theoretical total energies).
    """
    optim = optax.adam(lr)
    # The optimiser state is initialised on a (model arrays, None) pair;
    # presumably the second slot is for an optional extra model that is not
    # used here -- TODO confirm against the jpc documentation.
    opt_state = optim.init(
        (eqx.filter(model, eqx.is_array), None)
    )
    train_loader, test_loader = get_mnist_loaders(batch_size)

    num_total_energies, theory_total_energies = [], []
    # 1-based step counter (also avoids shadowing the builtin `iter`).
    for step, (img_batch, label_batch) in enumerate(train_loader, start=1):
        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()

        # Theoretical energy, evaluated before the parameters are updated.
        theory_total_energies.append(
            jpc.linear_equilib_energy(
                network=model,
                x=img_batch,
                y=label_batch
            )
        )
        result = jpc.make_pc_step(
            model,
            optim,
            opt_state,
            output=label_batch,
            input=img_batch,
            max_t1=max_t1,
            record_energies=True
        )
        model, optim, opt_state = result["model"], result["optim"], result["opt_state"]
        train_loss, t_max = result["loss"], result["t_max"]
        # Sum the recorded energies over the first axis at index t_max-1
        # (presumably layers x inference steps -- TODO confirm).
        num_total_energies.append(result["energies"][:, t_max-1].sum())

        if step % test_every == 0:
            _, avg_test_acc = evaluate(model, test_loader)
            print(
                f"Train iter {step}, train loss={train_loss:4f}, "
                f"avg test accuracy={avg_test_acc:4f}"
            )

        # Checked on every iteration (not only on test iterations) so that
        # training cannot overshoot when n_train_iters is not a multiple of
        # test_every.
        if step >= n_train_iters:
            break

    return {
        "experiment": jnp.array(num_total_energies),
        "theory": jnp.array(theory_total_energies)
    }


## Run

In [8]:
# Train the linear network with the global hyperparameters, collecting the
# numerical ("experiment") and theoretical ("theory") total energies.
energies = train(
    model=network,
    lr=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    test_every=TEST_EVERY,
    max_t1=MAX_T1,
    n_train_iters=N_TRAIN_ITERS
)
# Plot both curves; this also writes the figure to dln_total_energy.pdf.
plot_total_energies(energies)

Train iter 10, train loss=0.084618, avg test accuracy=45.022034
Train iter 20, train loss=0.062942, avg test accuracy=59.715546
Train iter 30, train loss=0.058635, avg test accuracy=72.215546
Train iter 40, train loss=0.058973, avg test accuracy=70.983574
Train iter 50, train loss=0.053510, avg test accuracy=77.163460
Train iter 60, train loss=0.052320, avg test accuracy=76.101761
Train iter 70, train loss=0.063362, avg test accuracy=76.352165
Train iter 80, train loss=0.057726, avg test accuracy=76.382210
Train iter 90, train loss=0.054073, avg test accuracy=75.801285
Train iter 100, train loss=0.053406, avg test accuracy=78.225159
