# Minimal PCN

This notebook provides a minimal `PyTorch` implementation of a PCN trained to generate MNIST digits in a supervised manner.

## Setup

In [1]:
#@title Imports


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import numpy as np
import random

In [2]:
#@title Utils


def get_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def set_seed(seed):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def vectorise(batch):
    batch_size = batch.size(0)
    return batch.reshape(batch_size, -1).squeeze()


def one_hot(labels, n_classes=10):
    arr = torch.eye(n_classes)
    return arr[labels]


def accuracy(pred_labels, true_labels):
    batch_size = pred_labels.size(0)
    correct = 0
    for b in range(batch_size):
        if torch.argmax(pred_labels[b, :]) == torch.argmax(true_labels[b, :]):
            correct += 1
    return correct / batch_size


In [3]:
#@title Datasets


class MNIST(datasets.MNIST):
    def __init__(self, train, path="./data", normalise=True):
        if normalise:
            transform = transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307), (0.3081))]
            )
        else:
            transform = transforms.Compose([transforms.ToTensor()])
        super().__init__(path, download=True, transform=transform, train=train)

    def __getitem__(self, index):
        img, label = super().__getitem__(index)
        img = vectorise(img)
        label = one_hot(label)
        return img, label
        


In [4]:
#@title PCN


class PCN(object):

    def __init__(
        self, 
        network, 
        dt=0.01, 
        device="cpu"
        ):
      
        self.network = network.to(device)
        self.n_layers = len(self.network)
        self.n_nodes = self.n_layers + 1
        self.dt = dt
        self.n_params = sum(p.numel() for p in network.parameters() if p.requires_grad)
        self.device = device

    def reset(self):
        self.zero_grad()
        self.preds = [None] * self.n_nodes
        self.errs = [None] * self.n_nodes
        self.xs = [None] * self.n_nodes

    def reset_xs(self, prior, init_std):
        self.set_prior(prior)
        self.propagate_xs()
        for l in range(self.n_layers):
            self.xs[l] = torch.empty(self.xs[l].shape).normal_(mean=0, std=init_std).to(self.device)

    def set_obs(self, obs):
        self.xs[-1] = obs.clone()

    def set_prior(self, prior):
        self.xs[0] = prior.clone()

    def forward(self, x):
        return self.network(x)

    def propagate_xs(self):
        for l in range(1, self.n_layers):
            self.xs[l] = self.network[l - 1](self.xs[l - 1])

    def infer_train(
        self, 
        obs, 
        prior,
        n_iters, 
        init_std=0.05
        ):

        self.reset()
        self.set_prior(prior)
        self.propagate_xs()
        self.set_obs(obs)

        for t in range(n_iters):
            self.network.zero_grad()
            self.preds[-1] = self.network[self.n_layers - 1](self.xs[self.n_layers - 1])
            self.errs[-1] = self.xs[-1] - self.preds[-1]

            for l in reversed(range(1, self.n_layers)):
                self.preds[l] = self.network[l - 1](self.xs[l - 1])
                self.errs[l] = self.xs[l] - self.preds[l]
                _, epsdfdx = torch.autograd.functional.vjp(self.network[l], self.xs[l], self.errs[l + 1])
                with torch.no_grad():
                    dx = epsdfdx - self.errs[l]
                    self.xs[l] = self.xs[l] + self.dt * dx

            if (t+1) != n_iters:
                self.clear_grads()

        self.set_weight_grads()

    def infer_test(
        self, 
        obs, 
        prior, 
        n_iters,
        init_std=0.05,
        ):
        
        self.reset()
        self.reset_xs(prior, init_std)
        self.set_obs(obs)

        for t in range(n_iters):
            self.network.zero_grad()
            self.preds[-1] = self.network[self.n_layers - 1](self.xs[self.n_layers - 1])
            self.errs[-1] = self.xs[-1] - self.preds[-1]

            for l in reversed(range(1, self.n_layers)):
                self.preds[l] = self.network[l - 1](self.xs[l - 1])
                self.errs[l] = self.xs[l] - self.preds[l]
                _, epsdfdx = torch.autograd.functional.vjp(self.network[l], self.xs[l], self.errs[l + 1])
                with torch.no_grad():
                    dx = epsdfdx - self.errs[l]
                    self.xs[l] = self.xs[l] + self.dt * dx

            _, epsdfdx = torch.autograd.functional.vjp(self.network[0], self.xs[0], self.errs[1])
            with torch.no_grad():
                self.xs[0] = self.xs[0] + self.dt * epsdfdx

            if (t+1) != n_iters:
                self.clear_grads()

        return self.xs[0]

    def set_weight_grads(self):
        for l in range(self.n_layers):
            for w in self.network[l].parameters():
                dw = torch.autograd.grad(
                    self.preds[l + 1], 
                    w,
                    - self.errs[l + 1],
                    allow_unused=True,
                    retain_graph=True
                )[0]
                w.grad = dw.clone()

    def zero_grad(self):
        self.network.zero_grad()

    def save_weights(self, path):
        torch.save(self.network.state_dict(), path)

    def load_weights(self, path):
        self.network.load_state_dict(torch.load(path))

    def clear_grads(self):
        with torch.no_grad():
            for l in range(1, self.n_nodes):
                self.preds[l] = self.preds[l].clone()
                self.errs[l] = self.errs[l].clone()
                self.xs[l] = self.xs[l].clone()

    def __str__(self):
        return f"PCN(\n{self.network}\n"


In [5]:
#@title Archs
       

network = nn.Sequential(
    nn.Sequential(
        nn.Linear(10, 250),
        nn.Tanh()
    ),
    nn.Sequential(
        nn.Linear(250, 250),
        nn.Tanh()
    ),
    nn.Linear(250, 28*28)
)


In [6]:
#@title Training script


def train(seed):
    set_seed(seed)
    device = get_device()

    train_data = MNIST(train=True, normalise=True)
    test_data = MNIST(train=False, normalise=True)
    train_loader = DataLoader(train_data, batch_size=64, shuffle=True, drop_last=True)
    test_loader = DataLoader(test_data, batch_size=64, shuffle=True, drop_last=True)

    model = PCN(
        network=network,
        device=device
    )
    optimizer = optim.Adam(model.network.parameters(), lr=1e-4)

    train_losses, test_losses = [], []
    test_accs = []
    n_epochs = 2
    log_every = 100

    for epoch in range(1, n_epochs+1):
        print(f"Epoch {epoch}\n-------------------------------")
        train_loss = 0
        for batch_id, (img_batch, label_batch) in enumerate(train_loader):
            img_batch = img_batch.to(device)
            label_batch = label_batch.to(device)

            model.infer_train(
                obs=img_batch, 
                prior=label_batch, 
                n_iters=10
            )
            optimizer.step()
            train_loss += (model.errs[-1]**2).mean().item()

            if batch_id % log_every == 0:
                print(f"Train loss: {(model.errs[-1]**2).mean().item():.5f} [{batch_id * len(img_batch)}/{len(train_loader.dataset)}]")

        test_loss, test_acc = (0, 0)
        for batch_id, (img_batch, label_batch) in enumerate(test_loader):
            img_batch = img_batch.to(device)
            label_batch = label_batch.to(device)

            label_preds = model.infer_test(
                obs=img_batch, 
                prior=label_batch, 
                n_iters=100
            )
            test_loss += (model.errs[-1]**2).mean().item()
            test_acc += accuracy(label_preds, label_batch)

        train_losses.append(train_loss / len(train_loader))
        test_losses.append(test_loss / len(test_loader))
        test_accs.append(test_acc / len(test_loader))
        print(f"\nAvg test accuracy: {test_accs[epoch-1]:.4f}\n")

    np.save("train_losses.npy", train_losses)
    np.save("test_losses.npy", test_losses)
    np.save("test_accs.npy", test_accs)
    model.save_weights("weights.pth")


## Train

In [7]:
train(seed=0)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Epoch 1
-------------------------------
Train loss: 0.93041 [0/60000]
Train loss: 0.61235 [6400/60000]
Train loss: 0.58521 [12800/60000]
Train loss: 0.51824 [19200/60000]
Train loss: 0.47155 [25600/60000]
Train loss: 0.43134 [32000/60000]
Train loss: 0.41525 [38400/60000]
Train loss: 0.38573 [44800/60000]
Train loss: 0.36918 [51200/60000]
Train loss: 0.35273 [57600/60000]

Avg test accuracy: 0.8250

Epoch 2
-------------------------------
Train loss: 0.35036 [0/60000]
Train loss: 0.32395 [6400/60000]
Train loss: 0.29842 [12800/60000]
Train loss: 0.30604 [19200/60000]
Train loss: 0.29617 [25600/60000]
Train loss: 0.27753 [32000/60000]
Train loss: 0.26285 [38400/60000]
Train loss: 0.25684 [44800/60000]
Train loss: 0.24507 [51200/60000]
Train loss: 0.25178 [57600/60000]

Avg test accuracy: 0.8411

Epoch 3
-------------------------------
Train loss: 0.24458 [0/60000]
Train loss: 0.24750 [6400/60000]
Train loss: 0.23