In [None]:
! git clone https://github.com/Bogacz-Group/PredictiveCoding.git
! pip install -r PredictiveCoding/requirements.txt

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torch.nn.functional as F
import random
import numpy as np
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
import copy

# load predictive coding library
import PredictiveCoding.predictive_coding as pc

In [None]:
n_train = 10000
n_val = 500
n_test = 5000
batch_size = 500
seed = 0  # single knob controlling every random choice made in this notebook

# Seed every RNG drawn from below and in later cells. This lets Restart & Run All
# reproduce the same train subset, shuffling order, val/test split and the
# default torch weight initialisation of the model built afterwards.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# get mnist data: ToTensor gives floats in [0, 1], then each image is flattened to a vector
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: torch.flatten(x))])
dataset_train = datasets.MNIST('./data', download=True, train=True, transform=transform)
dataset_eval = datasets.MNIST('./data', download=True, train=False, transform=transform)

# Randomly sample the train dataset
train_dataset = torch.utils.data.Subset(dataset_train, random.sample(range(len(dataset_train)), n_train))
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Randomly split the held-out MNIST split into val / test; the remainder is not used
val_dataset, test_dataset, not_used = torch.utils.data.random_split(
    dataset_eval,
    [n_val, n_test, len(dataset_eval) - n_val - n_test],
    generator=torch.Generator().manual_seed(seed),
)

In [8]:
# Layer widths of the generative network (label -> hidden -> hidden -> image).
input_size = 10                    # one-hot digit label: one unit per MNIST class
hidden_size = hidden2_size = 256   # both hidden PC layers share one width
output_size = 28 * 28              # flattened 28x28 MNIST image, one unit per pixel

In [9]:
activation_fn = nn.ReLU


def _pc_hidden_stack(widths):
    """Yield Linear -> PCLayer -> activation for each consecutive pair in `widths`.

    Modules are produced in the same order as a hand-written listing, so module
    indices inside the resulting nn.Sequential are unchanged.
    """
    for fan_in, fan_out in zip(widths, widths[1:]):
        yield nn.Linear(fan_in, fan_out)
        yield pc.PCLayer()  # contains the neural activity of the layer just produced
        yield activation_fn()


# Generative PC network: label vector -> two hidden PC layers -> flattened image prediction.
pc_model = nn.Sequential(
    *_pc_hidden_stack([input_size, hidden_size, hidden2_size]),
    nn.Linear(hidden2_size, output_size),
)

In [10]:
# Switch the network (and every PCLayer in it) into training mode; the module is echoed as the cell output.
pc_model.train(mode=True)

Sequential(
  (0): Linear(in_features=10, out_features=256, bias=True)
  (1): PCLayer()
  (2): ReLU()
  (3): Linear(in_features=256, out_features=256, bias=True)
  (4): PCLayer()
  (5): ReLU()
  (6): Linear(in_features=256, out_features=784, bias=True)
)

In [13]:
# Inference phase: number of neural-activity updates, optimised with Adam.
T = 20
optimizer_x_fn = optim.Adam
optimizer_x_kwargs = dict(lr=0.1)

# Learning phase: weight updates with Adam (weight_decay adds an L2 penalty to the gradients).
optimizer_p_fn = optim.Adam
optimizer_p_kwargs = dict(lr=0.001, weight_decay=0.001)

trainer = pc.PCTrainer(
    pc_model,
    T=T,
    optimizer_x_fn=optimizer_x_fn,
    optimizer_x_kwargs=optimizer_x_kwargs,
    optimizer_p_fn=optimizer_p_fn,
    optimizer_p_kwargs=optimizer_p_kwargs,
)

In [14]:
epochs = 10


def loss_fn(output, _target):
    """Squared-error energy at the output layer: 0.5 * sum((output - target)^2).

    The sum runs over every element (batch and pixels), not a mean. `_target` is
    passed in by the trainer through `loss_fn_kwargs`.
    """
    residual = output - _target
    return 0.5 * torch.sum(residual ** 2)

# Train the generative PC network: labels are fed as inputs, images are the targets.
# Each train_on_batch call runs the trainer's activity updates (T of them) and weight
# updates; see pc.PCTrainer for the exact schedule.
for epoch in tqdm(range(epochs), desc="epochs"):
    for data, label in train_loader:
        # Pin num_classes to the model's input width. Without it, F.one_hot infers the
        # width as (largest label in this batch) + 1, so a batch lacking digit 9 would
        # yield fewer than input_size columns and break the first nn.Linear.
        labels_one_hot = F.one_hot(label, num_classes=input_size).float()
        trainer.train_on_batch(
            inputs=labels_one_hot,
            loss_fn=loss_fn,
            loss_fn_kwargs={'_target': data},
            is_log_progress=False,
            is_return_results_every_t=False,
            is_checking_after_callback_after_t=False,
        )