# This notebook contains a simple example of using the function `PCInfer` from `Torch2PC` to train a convolutional neural network on MNIST.

The first code cell downloads and loads the MNIST data and defines some hyperparameters. Apart from cloning and importing `Torch2PC`, it contains nothing specific to `Torch2PC`.

If you use this code, please cite this paper:
[https://arxiv.org/abs/2106.13082](https://arxiv.org/abs/2106.13082)



In [None]:
import torch 
import torch.nn as nn
import numpy as np
import torchvision 
import matplotlib.pyplot as plt
from time import time as tm

# Import TorchSeq2PC 
!git clone https://github.com/RobertRosenbaum/Torch2PC.git
from Torch2PC import TorchSeq2PC as T2PC  

# Seed rng
torch.manual_seed(0)

# # This patches an error that sometimes arises in
# # downloading MNIST
# from six.moves import urllib
# opener = urllib.request.build_opener()
# opener.addheaders = [('User-agent', 'Mozilla/5.0')]
# urllib.request.install_opener(opener)

# This seems to be a more reliable and faster
# source for MNIST
!wget -nc www.di.ens.fr/~lelarge/MNIST.tar.gz
!tar -zxvf MNIST.tar.gz

# Load training and testing data from MNIST dataset
# These lines return data structures that contain
# the training and testing data 
from torchvision.datasets import MNIST

# Get training data structure
train_dataset = MNIST('./', 
      train=True, 
      transform=torchvision.transforms.ToTensor(),  
      download=True)

# Number of trainin data points
m = len(train_dataset)

# Print the size of the training data set
print('\n\n\n')
print("Number of data points in training set = ",m)
print("Size of training inputs (X)=",train_dataset.data.size())
print("Size of training labels (Y)=",train_dataset.targets.size())

# Define batch size
batch_size = 300      # Batch size to use with training data

# Create data loader. 
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=True)


# Choose device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device = ',device)

# Define the nunmber of epochs, learning rate, 
# and how often to print progress
num_epochs=2
LearningRate=.002
PrintEvery=50

# Choose an optimizer
WhichOptimizer=torch.optim.Adam

# Compute size of each batch
steps_per_epoch = len(train_loader) 
total_num_steps  = num_epochs*steps_per_epoch
print("steps per epoch (mini batch size)=",steps_per_epoch)


# The next code cell builds a convolutional neural network model using `Sequential`. 

The `PCInfer` function treats each element of a `Sequential` model as a layer. As such, if you want a block of several modules to be treated as a single layer, you need to wrap that block in its own nested `Sequential` (as below). For the model below, each block will be treated as a layer (5 layers in all).

In [None]:

# Define model using Sequential.
# PCInfer treats each top-level element of `model` as one layer, so each
# nested Sequential below is one layer (5 layers in all).
# For 1x28x28 MNIST images the feature maps shrink as follows:
#   Conv2d(1,10,3) -> 10x26x26, MaxPool2d(2) -> 10x13x13,
#   Conv2d(10,5,3) -> 5x11x11,  Flatten      -> 5*11*11 = 605 features.
model = nn.Sequential(

    nn.Sequential(
        nn.Conv2d(1, 10, 3),
        nn.ReLU(),
        nn.MaxPool2d(2),
    ),

    nn.Sequential(
        nn.Conv2d(10, 5, 3),
        nn.ReLU(),
        nn.Flatten(),
    ),

    nn.Sequential(
        nn.Linear(5*11*11, 50),
        nn.ReLU(),
    ),

    nn.Sequential(
        nn.Linear(50, 30),
        nn.ReLU(),
    ),

    nn.Sequential(
        nn.Linear(30, 10),
    ),

).to(device)

# Define the loss function
LossFun = nn.CrossEntropyLoss()

# Compute one batch of output and loss to make sure
# things are working. The model already lives on `device`,
# so its output does too; no extra .to(device) is needed.
with torch.no_grad():
    X, Y = next(iter(train_loader))
    X = X.to(device)
    Y = Y.to(device)
    Yhat = model(X)
    print('output shape = ', Yhat.shape)
    print('loss on initial model = ', LossFun(Yhat, Y).item())


# Count the trainable parameters
NumParams = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Number of trainable parameters in model =', NumParams)



# The next code cell defines hyperparameters for `PCInfer`

The hyperparameter `ErrType` controls which algorithm to use for computing the beliefs and prediction errors. It should be equal to `'Strict'`, `'FixedPred'`, or `'Exact'`. `'Strict'` uses a strict interpretation of predictive coding (without the fixed prediction assumption), `'FixedPred'` uses the fixed prediction assumption, and `'Exact'` computes the exact gradients (same as those computed by backpropagation). See "On the relationship between predictive coding and backpropagation" for more information on these algorithms.

`eta` and `n` are the step size and number of steps to use for the iterations that compute the prediction errors and beliefs. These parameters are not used when `ErrType='Exact'`.

In [None]:
# Define PC hyperparameters

ErrType="Strict"
eta=.1
n=20

# The next code cell uses `PCInfer` to train the model

The only line that differs from a typical training loop in PyTorch is the line

```python
vhat,Loss,dLdy,v,epsilon=T2PC.PCInfer(model,LossFun,X,Y,ErrType,eta,n)
```

which computes the outputs, loss, etc., and sets the `.grad` attributes of all parameters in `model` to the parameter updates computed by predictive coding. It therefore takes the place of the usual `Loss.backward()` call.

For `ErrType='Exact'`, the gradients are set to the gradient of the loss with respect to that parameter, i.e., the same values computed by calling `Loss.backward()` after a single forward pass. For other values of `ErrType`, refer to the paper for an explanation of how the parameter updates are computed.

In [None]:

# Define the optimizer
optimizer = WhichOptimizer(model.parameters(), lr=LearningRate)

# Preallocate storage for the loss at every training step
LossesToPlot = np.zeros(total_num_steps)


step = 0    # Global step counter (index into LossesToPlot)
t1 = tm()   # Get start time
for k in range(num_epochs):

    # Iterating the loader afresh each epoch reshuffles the data
    for i, (X, Y) in enumerate(train_loader):

        # Send this batch to the current device
        X = X.to(device)
        Y = Y.to(device)

        # Perform predictive-coding inference on this batch. Besides
        # returning the outputs and loss, PCInfer sets the .grad of every
        # model parameter, so no Loss.backward() call is needed.
        vhat, Loss, dLdy, v, epsilon = T2PC.PCInfer(model, LossFun, X, Y, ErrType, eta, n)

        # Update parameters using the .grad values set by PCInfer
        optimizer.step()

        # Clear gradients before the next batch. The optimizer was built
        # from model.parameters(), so this covers every parameter.
        optimizer.zero_grad()

        # Print and store loss
        loss_value = Loss.item()
        if i % PrintEvery == 0:
            print('epoch =', k, 'step =', i, 'Loss =', loss_value)
        LossesToPlot[step] = loss_value
        step += 1

# Compute and print time spent training
tTrain = tm()-t1
print('Training time = ', tTrain, 'sec')

# Plot the loss curve
plt.figure()
plt.plot(LossesToPlot)
plt.ylim(bottom=0)
plt.ylabel('training loss')
plt.xlabel('iteration number')

