In [None]:
import os

import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
from IPython import display
from torch import nn

In [None]:
# # Mounting Google Drive to save the data and results
# from google.colab import drive
# drive.mount('/content/drive')

# Accessing the GPU for training

In [None]:
# Select the device used for training: the first CUDA GPU when one is available,
# otherwise the CPU. The GPU name is only queried inside the "available" branch,
# because torch.cuda.get_device_name raises on machines without a usable CUDA
# device, which would stop this cell before `device` is ever defined.
if torch.cuda.is_available():
  device = torch.device("cuda:0")
  print("GPU")
  print(torch.cuda.get_device_name(device))
else:
  device = torch.device("cpu")
  print("CPU")

# Loading in the data and creating dataloaders

In [None]:
# Pre-processing applied to every image: convert it to a tensor, then normalise
# each of the three colour channels with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Number of images per mini-batch, shared by both loaders
batch_size = 32

# Training split of CIFAR-10 (downloaded into ./data when missing).
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
# The training batches are reshuffled so the sample order is randomised.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

# Held-out split of CIFAR-10, used here as the validation set.
valset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
# Validation batches keep a fixed order; shuffling is not needed for evaluation.
valloader = torch.utils.data.DataLoader(
    valset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

# Human-readable class names, indexed by the integer label of each image
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

# Creating the model

In [None]:
class MLP_Conv_Block(nn.Module):
    """Weighted sum of K parallel convolutions, with input-dependent weights.

    A small MLP branch (spatial average pool -> flatten -> linear -> ReLU)
    turns the input into K non-negative scalars per sample. The input is also
    passed through K independent convolutions, and the block returns the sum
    of the convolution outputs, each scaled by its own scalar.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by every convolution.
        K: Number of parallel convolutions (and of predicted scalars); >= 1.
        kernel_size: Kernel size of every convolution. Defaults to 5, the
            value that was previously hard-coded.

    Raises:
        ValueError: If ``K`` is smaller than 1.
    """

    def __init__(self, in_channels, out_channels, K, kernel_size=5):
        super(MLP_Conv_Block, self).__init__()
        if K < 1:
            raise ValueError(f"K must be at least 1, got {K}")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.K = K
        # MLP branch: pools each channel to a single value, flattens the result
        # to (batch, in_channels) and maps it to K scalars per sample.
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),  # spatial average pooling
            nn.Flatten(),                  # (batch, in_channels, 1, 1) -> (batch, in_channels)
            nn.Linear(in_channels, K),
            nn.ReLU()                      # chosen non-linearity; keeps the scalars >= 0
        )
        # Convolution branch: K independent convolutions over the same input.
        self.convs = nn.ModuleList(
            [nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size) for _ in range(K)]
        )

    def forward(self, x):
        """Return the weighted sum of the K convolutions, flattened per sample.

        Args:
            x: Input batch of shape (batch, in_channels, H, W).

        Returns:
            Tensor of shape (batch, out_channels * H_out * W_out), where
            H_out and W_out are the spatial sizes produced by the convolutions.
        """
        # Read the batch size from the input so any batch size works.
        batch_size = x.size(0)
        # One scalar per sample and per convolution: shape (batch, K).
        gates = self.fc(x)
        output = 0
        for i, conv in enumerate(self.convs):
            # Column i holds the weight of convolution i; reshape it to
            # (batch, 1, 1, 1) so it broadcasts over channels and positions.
            output = output + gates[:, i].reshape(batch_size, 1, 1, 1) * conv(x)
        # Flatten everything except the batch dimension.
        return output.flatten(start_dim=1)


class Classifier(nn.Module):
    """MLP classification head applied after spatial average pooling.

    The input is average-pooled to one value per channel, flattened, and then
    passed through three hidden linear layers (512, 256 and 128 units, each
    followed by ReLU) and a final linear layer producing one score per class.

    Args:
        in_features: Number of channels of the incoming feature map, which is
            also the input width of the first linear layer.
        out_classes: Number of output scores (one per class).
    """

    def __init__(self, in_features, out_classes):
        super().__init__()
        # The block output is pooled and flattened before reaching the MLP.
        layers = [nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten()]
        width_in = in_features
        # Hidden layers, each followed by a ReLU non-linearity.
        for width_out in (512, 256, 128):
            layers.append(nn.Linear(width_in, width_out))
            layers.append(nn.ReLU())
            width_in = width_out
        # Output layer: raw class scores, no activation.
        layers.append(nn.Linear(width_in, out_classes))
        self.classify = nn.Sequential(*layers)

    def forward(self, x):
        """Return the class scores, shape (batch, out_classes), for ``x``."""
        return self.classify(x)

class Model(nn.Module):
    """One MLP_Conv_Block followed by the MLP classifier.

    Args:
        in_channels: Number of channels of the input images.
        out_channels: Number of channels produced by the convolution block.
        K: Number of parallel convolutions inside the block.
        num_classes: Number of output classes. Defaults to 10, the value that
            was previously hard-coded.
        input_size: Height and width of the (square) input images. Defaults
            to 32, which gives the previously hard-coded 28x28 feature map.

    Raises:
        ValueError: If ``input_size`` is too small for the 5x5 convolution.
    """

    # Kernel size of the convolutions built by MLP_Conv_Block by default.
    _CONV_KERNEL_SIZE = 5

    def __init__(self, in_channels, out_channels, K, num_classes=10, input_size=32):
        super(Model, self).__init__()
        self.model = nn.Sequential(
            MLP_Conv_Block(in_channels, out_channels, K),
            nn.Flatten()
        )
        # A convolution without padding and with stride 1 shrinks each spatial
        # side from W to W - k + 1; for W = 32 and k = 5 that is 28.
        feature_size = input_size - self._CONV_KERNEL_SIZE + 1
        if feature_size < 1:
            raise ValueError(
                f"input_size must be at least {self._CONV_KERNEL_SIZE}, got {input_size}"
            )
        # The block returns its feature map flattened, so the classifier sees
        # out_channels * feature_size * feature_size input features.
        self.classifier = Classifier(out_channels * feature_size * feature_size, num_classes)

    def forward(self, x):
        """Return the class scores, shape (batch, num_classes), for a batch of images."""
        x = self.model(x)  # flattened block output: (batch, features)
        # The classifier starts with a 2-D pooling layer, so append two singleton
        # spatial dimensions: (batch, features) -> (batch, features, 1, 1).
        x = x.unsqueeze(2).unsqueeze(3)
        return self.classifier(x)

# Build the network and move its parameters to the selected device. The training
# and evaluation loops send every batch to `device`, so the model has to live
# there as well; otherwise the forward pass fails with a device mismatch when
# `device` is a GPU. When `device` is the CPU, `.to(device)` changes nothing.
net = Model(in_channels=3, out_channels=32, K=10).to(device)

# Loss function and optimiser

In [None]:
# Training objective: cross-entropy between the class scores and the labels.
criterion = nn.CrossEntropyLoss()
# Plain stochastic gradient descent over all parameters of `net`.
optimizer = torch.optim.SGD(params=net.parameters(), lr=1e-3)

# Training loop

In [None]:
# Training configuration
NUM_EPOCHS = 20    # number of full passes over the training set
LOG_EVERY = 2000   # mini-batches between two intermediate loss reports

# Per-epoch history, used for the curves plotted after training
train_losses = []
train_accs = []
val_losses = []
val_accs = []

for epoch in range(NUM_EPOCHS):

    # ---- Training pass ----
    net.train()  # make sure the network is in training mode, whatever ran before

    epoch_loss_sum = 0.0  # sum of the batch losses over the whole epoch
    running_loss = 0.0    # sum of the batch losses since the last intermediate report
    total_train = 0       # number of training images seen this epoch
    correct_train = 0     # number of those images classified correctly

    for i, (inputs, labels) in enumerate(trainloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Reset the gradients left over from the previous step
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Statistics. The epoch total is kept separate from the reporting
        # window, so resetting the window below cannot distort the epoch loss.
        batch_loss = loss.item()
        epoch_loss_sum += batch_loss
        running_loss += batch_loss
        predicted = outputs.argmax(dim=1)  # index of the highest score = predicted class
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()

        if i % LOG_EVERY == LOG_EVERY - 1:  # report every LOG_EVERY mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / LOG_EVERY:.3f}')
            running_loss = 0.0

    # Training accuracy (in %) and mean batch loss for the epoch
    train_acc = 100 * correct_train / total_train
    train_loss = epoch_loss_sum / len(trainloader)
    train_accs.append(train_acc)
    train_losses.append(train_loss)

    # ---- Validation pass ----
    net.eval()  # evaluation mode while measuring validation metrics
    running_val_loss = 0.0
    total_val = 0
    correct_val = 0

    with torch.no_grad():  # no gradients are needed for evaluation
        for images, labels in valloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = net(images)
            batch_val_loss = criterion(outputs, labels)
            running_val_loss += batch_val_loss.item()
            predicted = outputs.argmax(dim=1)
            total_val += labels.size(0)
            correct_val += (predicted == labels).sum().item()

    # Validation accuracy (in %) and mean batch loss for the epoch
    val_acc = 100 * correct_val / total_val
    val_loss = running_val_loss / len(valloader)
    val_accs.append(val_acc)
    val_losses.append(val_loss)

    # Summary at the end of each epoch
    print(f'Epoch {epoch + 1} train loss: {train_loss:.3f}, train acc: {train_acc:.2f}%, val loss: {val_loss:.3f}, val acc: {val_acc:.2f}%')

print('Finished Training')

# Plot loss and accuracy curves
epochs = range(1, len(train_losses) + 1)  # 1-based epoch numbers for the x axis

# Create the output folder first; savefig fails when the target folder is missing.
os.makedirs('results', exist_ok=True)

# One figure per metric: (metric name, training series, validation series)
for metric, train_series, val_series in (
    ('loss', train_losses, val_losses),
    ('accuracy', train_accs, val_accs),
):
    plt.figure(figsize=(10, 5))
    plt.plot(epochs, train_series, 'b', label=f'Training {metric}')
    plt.plot(epochs, val_series, 'r', label=f'Validation {metric}')
    plt.title(f'Training and validation {metric}')
    plt.xlabel('Epochs')
    plt.ylabel(metric.capitalize())
    plt.legend()
    plt.savefig(f'results/Train and Val {metric}.png', dpi=300)
    plt.show()


# Final Model accuracy on CIFAR-10 Validation Set

In [None]:
# Final accuracy: share of validation images whose highest-scoring class
# matches the label.
net.eval()  # set the model to evaluation mode
total = 0    # images evaluated so far
correct = 0  # images classified correctly so far

with torch.no_grad():  # gradients are not needed for evaluation
    for images, labels in valloader:
        images, labels = images.to(device), labels.to(device)

        # Class scores for the batch
        outputs = net(images)

        # The predicted class is the index of the highest score
        predicted = torch.max(outputs, dim=1).indices

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Accuracy in percent
accuracy = 100 * correct / total
print(f'Accuracy of the network on the {total} test images: {accuracy:.2f} %')
