<a href="https://colab.research.google.com/github/zxcej/COMP691_LABS/blob/main/2023_Lab7_ex.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Lab 7: Self-Attention


This lab covers the following topics:

- Gain insight into the self-attention operation using the sequential MNIST example from before.
- Gain insight into positional encodings

## 0 Initialization

Run the code cell below to download the MNIST digits dataset:

In [1]:
!wget -O MNIST.tar.gz https://activeeon-public.s3.eu-west-2.amazonaws.com/datasets/MNIST.new.tar.gz
!tar -zxvf MNIST.tar.gz

import torchvision
import torch
import torchvision.transforms as transforms
from torch import nn
import torch.nn.functional as F

from torch.utils.data import Subset

dataset = torchvision.datasets.MNIST('./', download=False, transform=transforms.Compose([transforms.ToTensor()]), train=True)
train_indices = torch.arange(0, 10000)
train_dataset = Subset(dataset, train_indices)

dataset=torchvision.datasets.MNIST('./', download=False, transform=transforms.Compose([transforms.ToTensor()]), train=False)
test_indices = torch.arange(0, 10000)
test_dataset = Subset(dataset, test_indices)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64,
                                          shuffle=True, num_workers=0)

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16,
                                          shuffle=False, num_workers=0)

--2023-03-20 15:38:40--  https://activeeon-public.s3.eu-west-2.amazonaws.com/datasets/MNIST.new.tar.gz
Resolving activeeon-public.s3.eu-west-2.amazonaws.com (activeeon-public.s3.eu-west-2.amazonaws.com)... 52.95.148.138
Connecting to activeeon-public.s3.eu-west-2.amazonaws.com (activeeon-public.s3.eu-west-2.amazonaws.com)|52.95.148.138|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 34812527 (33M) [application/x-gzip]
Saving to: 'MNIST.tar.gz'

     0K .......... .......... .......... .......... ..........  0%  569K 60s
    50K .......... .......... .......... .......... ..........  0%  576K 59s
   100K .......... .......... .......... .......... ..........  0%  588K 59s
   150K .......... .......... .......... .......... ..........  0%  587K 58s
   200K .......... .......... .......... .......... ..........  0%  588K 58s
   250K .......... .......... .......... .......... ..........  0% 1.16M 53s
   300K .......... .......... .......... .......... .......... 

## Exercise 1: Self-Attention without Positional Encoding

In this section, will implement a very simple model based on self-attention without positional encoding. The model you will implement will consider the input image as a sequence of 28 rows. You may use PyTorch's [`nn.MultiheadAttention`](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) for this part. Implement a model with the following architecture:

* **Input**: Input image of shape `(batch_size, sequence_length, input_size)`, where sequence_length = image_height and input_size = image_width.

* **Linear 1**: Linear layer which converts input of shape `(sequence_length*batch_size, input_size)` to input of shape `(sequence_length*batch_size, embed_dim)`, where `embed_dim` is the embedding dimension.

* **Attention 1**: `nn.MultiheadAttention` layer with 8 heads which takes an input of shape `(sequence_length, batch_size, embed_dim)` and outputs a tensor of shape `(sequence_length, batch_size, embed_dim)`. 

* **ReLU**: ReLU activation layer.

* **Linear 2**: Linear layer which converts input of shape `(sequence_length*batch_size, embed_dim)` to input of shape `(sequence_length*batch_size, embed_dim)`.

* **ReLU**: ReLU activation layer.

* **Attention 2**: `nn.MultiheadAttention` layer with 8 heads which takes an input of shape `(sequence_length, batch_size, embed_dim)` and outputs a tensor of shape `(sequence_length, batch_size, embed_dim)`.

* **ReLU**: ReLU activation layer.

* **AvgPool**: Average along the sequence dimension from `(batch_size, sequence_length, embed_dim)` to `(batch_size, embed_dim)`

* **Linear 3**: Linear layer which takes an input of shape `(batch_size, embed_dim)` and outputs the class logits of shape `(batch_size, 10)`.


**NOTE**: Be cautious of correctly permuting and reshaping the input between layers. E.g. if `x` is of shape `(batch_size, sequence_length, input_size)`, note that `x.reshape(batch_size*sequence_length, -1) != x.permute(1,0,2).reshape(batch_size*sequence_length, -1)`. In this example, `x.reshape(batch_size*sequence_length, -1)` has `[batch0_seq0, batch0_seq1, ..., batch1_seq0, batch1_seq1, ...]` format, while `x.permute(1,0,2).reshape(batch_size*sequence_length, -1)` has `[batch0_seq0, batch1_seq0, ..., batch0_seq1, batch1_seq1, ...]` format.

In [2]:
# Self-attention without positional encoding
torch.manual_seed(691)

# Define your model here
class myModel(nn.Module):
    def __init__(self, input_size, embed_dim, seq_length,
                 num_classes=10, num_heads=8):
        super(myModel, self).__init__()
        self.layer1 = nn.Linear(input_size, embed_dim)
        self.attention1 = nn.MultiheadAttention(embed_dim, num_heads)
        self.layer2 = nn.Linear(embed_dim, embed_dim)
        self.attention2 = nn.MultiheadAttention(embed_dim, num_heads)
        # From 64*28*64 to 64*64 using average pooling
        self.avgpool = nn.AvgPool1d(kernel_size=seq_length)
        self.fc = nn.Linear(embed_dim, num_classes)
        self.activation = nn.ReLU()
        

    def forward(self,x):
        # TODO: Implement myModel forward pass
        batch_size, sequence_length, input_size = x.shape
        x = x.reshape(batch_size*sequence_length, input_size)
        x = self.layer1(x)
        x = self.activation(x)
        x = x.reshape(batch_size, sequence_length, -1)
        x = x.permute(1, 0, 2)
        x, _ = self.attention1(x, x, x)
        x = x.permute(1, 0, 2)
        x = x.reshape(batch_size*sequence_length, -1)
        x = self.layer2(x)
        x = self.activation(x)
        x = x.reshape(batch_size, sequence_length, -1)
        x = x.permute(1, 0, 2)
        x, _ = self.attention2(x, x, x)
        x = x.permute(1, 0, 2)
        x = x.reshape(batch_size,sequence_length, -1)
        # From 64*28*64 to 64*64 using average pooling
        x = x.permute(0, 2, 1)
        x = self.avgpool(x).squeeze()
        x = self.fc(x)
        return x


Train and evaluate your model by running the cell below. Expect to see  `60-80%` test accuracy.

In [3]:
# Same training code 

import torch 
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
sequence_length = 28
input_size = 28
hidden_size = 64
num_layers = 2
num_classes = 10
num_epochs = 8
learning_rate = 0.005

# Initialize model
model = myModel(input_size=input_size, embed_dim=hidden_size, seq_length=sequence_length)
model = model.to(device)
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()

        optimizer.step()

        if (i+1) % 10 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))


# Test the model
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total)) 

Epoch [1/8], Step [10/157], Loss: 2.2581
Epoch [1/8], Step [20/157], Loss: 2.1685
Epoch [1/8], Step [30/157], Loss: 1.9935
Epoch [1/8], Step [40/157], Loss: 1.8937
Epoch [1/8], Step [50/157], Loss: 1.9072
Epoch [1/8], Step [60/157], Loss: 1.5847
Epoch [1/8], Step [70/157], Loss: 1.7249
Epoch [1/8], Step [80/157], Loss: 1.7521
Epoch [1/8], Step [90/157], Loss: 1.6420
Epoch [1/8], Step [100/157], Loss: 1.6267
Epoch [1/8], Step [110/157], Loss: 1.5332
Epoch [1/8], Step [120/157], Loss: 1.5796
Epoch [1/8], Step [130/157], Loss: 1.4901
Epoch [1/8], Step [140/157], Loss: 1.4746
Epoch [1/8], Step [150/157], Loss: 1.4116
Epoch [2/8], Step [10/157], Loss: 1.2400
Epoch [2/8], Step [20/157], Loss: 1.3711
Epoch [2/8], Step [30/157], Loss: 1.2489
Epoch [2/8], Step [40/157], Loss: 1.1527
Epoch [2/8], Step [50/157], Loss: 1.0967
Epoch [2/8], Step [60/157], Loss: 0.9711
Epoch [2/8], Step [70/157], Loss: 0.9957
Epoch [2/8], Step [80/157], Loss: 0.8141
Epoch [2/8], Step [90/157], Loss: 0.8850
Epoch [2/8

## Exercise 2: Self-Attention with Positional Encoding

Implement a similar model to exercise 1, except this time your embedded input should be added with the positional encoding. For the purpose of this lab, we will use a learned positional encoding, which will be a trainable embedding. Your positional encodings will be added to the initial transformation of the input.

* **Input**: Input image of shape `(batch_size, sequence_length, input_size)`, where sequence_length = image_height and input_size = image_width.

* **Linear 1**: Linear layer which converts input of shape `(batch_size*sequence_length, input_size)` to input of shape `(batch_size*sequence_length, embed_dim)`, where `embed_dim` is the embedding dimension.

* **Add Positional Encoding**: Add a learnable positional encoding of shape `(sequence_length, batch_size, embed_dim)` to input of shape `(sequence_length, batch_size, embed_dim)`, where `pos_embed` is the positional embedding size. The output will be of shape `(sequence_length, batch_size, embed_dim)`.

* **Attention 1**: `nn.MultiheadAttention` layer with 8 heads which takes an input of shape `(sequence_length, batch_size, embed_dim)` and outputs a tensor of shape `(sequence_length, batch_size, embed_dim)`.

* **ReLU**: ReLU activation layer.

* **Linear 2**: Linear layer which converts input of shape `(sequence_length*batch_size, features_dim)` to input of shape `(sequence_length*batch_size, features_dim)`.

* **ReLU**: ReLU activation layer.

* **Attention 2**: `nn.MultiheadAttention` layer with 8 heads which takes an input of shape `(sequence_length, batch_size, features_dim)` and outputs a tensor of shape `(sequence_length, batch_size, features_dim)`.

* **ReLU**: ReLU activation layer.

* **AvgPool**: Average along the sequence dimension from `(batch_size, sequence_length, features_dim)` to `(batch_size, features_dim)`

* **Linear 3**: Linear layer which takes an input of shape `(batch_size, sequence_length*features_dim)` and outputs the class logits of shape `(batch_size, 10)`.


In [13]:
# Self-attention with positional encoding
torch.manual_seed(691)

# Define your model here
class myModel(nn.Module):
    def __init__(self, input_size, embed_dim, seq_length,
                 num_classes=10, num_heads=8):
        super(myModel, self).__init__()

        self.seq_length = seq_length
        self.embed_dim = embed_dim
        
        self.layer1 = nn.Linear(input_size, embed_dim)
        self.positional_encoding = nn.Parameter(torch.rand(self.seq_length, self.embed_dim))
        self.attention1 = nn.MultiheadAttention(embed_dim, num_heads)
        self.layer2 = nn.Linear(embed_dim, embed_dim)
        self.attention2 = nn.MultiheadAttention(embed_dim, num_heads)
        self.avgpool = nn.AvgPool1d(kernel_size=seq_length)
        self.fc = nn.Linear(seq_length*embed_dim, num_classes)
        self.activation = nn.ReLU()
        

    def forward(self,x):
        # TODO: Implement myModel forward pass
        batch_size, sequence_length, input_size = x.shape
        x = x.reshape(batch_size*sequence_length, input_size)
        x = self.layer1(x)
        x = x.reshape(batch_size, sequence_length, -1)
        x = x.permute(1, 0, 2)
        # Add positional encoding, positional encoding is (seq_length, embed_dim), reshape to (seq_length, batch_size, embed_dim)
        x, _ = self.attention1(x, x, x)
        x = self.activation(x)
        x = x + self.positional_encoding.reshape(sequence_length, 1, -1)
        x = x.reshape(sequence_length*batch_size, -1)
        x = self.layer2(x)
        x = self.activation(x)
        x = x.reshape(sequence_length, batch_size, -1)
        x , _ = self.attention2(x, x, x)
        x = self.activation(x)
        #print (x.shape)
        x = x.permute(1,2,0)
        #print (x.shape)
        x = self.avgpool(x).squeeze()
        #x is now (batch_size, embed_dim), reshape to (batch_size, embed_dim*seq_length)
        x = x.unsqueeze(1).repeat(1, sequence_length, 1).reshape(batch_size, -1)
        #print (x.shape)
        x = self.fc(x)
        return x

Use the same training code as the one from part 1 to train your model. You may copy the training loop here. Expect to see close to `~90+%` test accuracy.

In [14]:
# Same training code 

import torch 
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
sequence_length = 28
input_size = 28
hidden_size = 64
num_layers = 2
num_classes = 10
num_epochs = 8
learning_rate = 0.005

# Initialize model
model = myModel(input_size=input_size, embed_dim=hidden_size, seq_length=sequence_length)
model = model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()

        optimizer.step()

        if (i+1) % 10 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))


# Test the model
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total)) 

Epoch [1/8], Step [10/157], Loss: 2.3084
Epoch [1/8], Step [20/157], Loss: 2.2603
Epoch [1/8], Step [30/157], Loss: 2.1109
Epoch [1/8], Step [40/157], Loss: 2.1115
Epoch [1/8], Step [50/157], Loss: 1.9868
Epoch [1/8], Step [60/157], Loss: 1.8934
Epoch [1/8], Step [70/157], Loss: 1.9498
Epoch [1/8], Step [80/157], Loss: 1.8909
Epoch [1/8], Step [90/157], Loss: 1.9409
Epoch [1/8], Step [100/157], Loss: 1.8412
Epoch [1/8], Step [110/157], Loss: 1.4989
Epoch [1/8], Step [120/157], Loss: 1.6666
Epoch [1/8], Step [130/157], Loss: 1.7141
Epoch [1/8], Step [140/157], Loss: 1.4904
Epoch [1/8], Step [150/157], Loss: 1.3420
Epoch [2/8], Step [10/157], Loss: 1.6007
Epoch [2/8], Step [20/157], Loss: 1.3797
Epoch [2/8], Step [30/157], Loss: 1.2969
Epoch [2/8], Step [40/157], Loss: 1.2998
Epoch [2/8], Step [50/157], Loss: 1.0776
Epoch [2/8], Step [60/157], Loss: 1.2083
Epoch [2/8], Step [70/157], Loss: 0.9983
Epoch [2/8], Step [80/157], Loss: 1.3606
Epoch [2/8], Step [90/157], Loss: 1.2627
Epoch [2/8