### AI5000 Deep Learning
#### Assignment 3

In [7]:
import torch
import torch.nn as nn
from tqdm import tqdm

1. **Self-Attention for Object Recognition with CNNs**: 

Implement a sample CNN with one or more self-attention layer(s) for performing object recognition over CIFAR-10 dataset. You have to implement the self-attention layer yourself and use it in the forward function defined by you. All
other layers (fully connected, nonlinearity, conv layer, etc.) can be bulit-in implementations. The network can be a simpler one (e.g., it may have 1x Conv, 4x [Conv followed by SA], 1x Conv, and 1x GAP). Please refer to the reading material provided here or any other similar one. [10 Marks]

In [8]:
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        
        out_channels = max(in_channels//8, 1)
        self.query_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.final_conv = nn.Conv2d(out_channels, in_channels, kernel_size=1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input):
        batch_size, C, W, H = input.size()

        # batch x C' x W x H -> batch x WH x C'
        query_matrix = self.query_conv(input).view(batch_size, -1, W*H).permute(0, 2, 1)  

        # batch x C' x W x H -> batch x C' x WH 
        key_matrix = self.key_conv(input).view(batch_size, -1, W*H)

        # batch x C' x W x H -> batch x WH x C'
        value_matrix = self.value_conv(input).view(batch_size, -1, W*H).permute(0, 2, 1) 

        # Computing product of key and query matrices and taking softmax
        # batch x WH x C', batch x C' x WH -> WH x WH
        attention_weights = self.softmax(torch.matmul(query_matrix, key_matrix))

        # Computing product of attention weights and value matrix
        # WH x WH, batch x WH x C'-> batch x WH x C' -> batch x C' x WH -> batch x C' x W x H
        output = torch.matmul(attention_weights, value_matrix).permute(0, 2, 1).view(batch_size, -1, W, H)

        # Applying 1x1 conv to get the final output
        output = self.final_conv(output)

        return input + output

SA_block = SelfAttention(1)
x = torch.rand(1, 1, 3, 3)
SA_block.forward(x)

tensor([[[[-0.0877, -0.0682,  0.2165],
          [ 0.0914,  0.1210,  0.8490],
          [ 0.8329,  0.1387,  0.4903]]]], grad_fn=<AddBackward0>)

In [61]:
# create a CNN with one or more self attention layers to perform object recognition on CIFAR-10 dataset using pytorch and cuda

# network structure: 
# 1x Conv
# 4x (Conv + Self Attention)
# 1x Conv
# 1x GAP

class CNNwithSelfAttention(nn.Module):
    def __init__(self, output_channels = 10):
        super(CNNwithSelfAttention, self).__init__()

        self.conv0 = nn.Conv2d(3, 16, kernel_size=3, padding=1)

        self.conv1 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.sa1 = SelfAttention(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.sa2 = SelfAttention(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.sa3 = SelfAttention(128)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.sa4 = SelfAttention(256)
        
        self.conv5 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        
        #put a linear layer to get 10 channels in output, output_dim = [batch_size, 10]
        self.Linear = nn.Linear(128, output_channels)

    def forward(self, x):
        x = self.conv0(x)

        x = self.conv1(x)
        x = self.sa1(x)
        x = self.conv2(x)
        x = self.sa2(x)
        x = self.conv3(x)
        x = self.sa3(x)
        x = self.conv4(x)
        x = self.sa4(x)

        x = self.conv5(x)

        x = self.global_avg_pool(x)
        
        x = x.view(x.size(0), -1)
        x = self.Linear(x)

        return x

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

CNNwithSA = CNNwithSelfAttention()
x = torch.rand(100, 3, 32, 32)
output = CNNwithSA.forward(x)
print(output.size())

print("Params:", count_parameters(CNNwithSA))


torch.Size([100, 10])
Params: 733118


In [None]:
def adjust_lr(optimizer, epoch, lr):
    if epoch < 10:
        lr = lr
    elif epoch < 20:
        lr = lr/10
    else:
        lr = lr/100
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

In [123]:
# Creating a CNN with self attention layers to perform object recognition on CIFAR-10 dataset using pytorch and cuda

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)

# Hyper-parameters
num_epochs = 50
batch_size = 128
learning_rate = 1e-3
frame_size = (32, 32)

transforms_to_train = transforms.Compose([
                                            # transforms.ColorJitter(brightness=.33, saturation=.33),
                                            # transforms.RandomHorizontalFlip(p=0.5),
                                            # transforms.RandomAffine(degrees=(-10, 10), scale=(0.9, 1.10)),
                                            transforms.ToTensor(),
                                            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
                                        ])

transforms_to_test = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
                                        ])

# CIFAR-10 dataset
train_dataset = torchvision.datasets.CIFAR10(root = './data',
                                             train = True, 
                                             transform = transforms_to_train,
                                             download = True)

test_dataset = torchvision.datasets.CIFAR10(root = './data',
                                            train = False, 
                                            transform = transforms_to_test)
val_size = int(0.2 * len(train_dataset))
train_size = len(train_dataset) - val_size

train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])

# Data loader for training, testing, validation

train_loader = torch.utils.data.DataLoader(dataset = train_dataset,
                                             batch_size = batch_size,
                                             shuffle = True)
val_loader = torch.utils.data.DataLoader(dataset = val_dataset,
                                            batch_size = batch_size,
                                            shuffle = False)

test_loader = torch.utils.data.DataLoader(dataset = test_dataset,
                                            batch_size = batch_size,
                                            shuffle = False)

print(len(train_loader), len(val_loader), len(test_loader))

Device: cpu
Files already downloaded and verified
313 79 79


In [11]:
# defining model:
model = CNNwithSelfAttention().to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

# Train the model
total_steps = len(train_loader)
loss_list = []

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model.forward(images)

        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print ('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, loss.item()))
    if (epoch+1) % 5 == 0:
        torch.save(model.state_dict(), f"model_{epoch+1}.pt")
    loss_list.append(loss.item())

KeyboardInterrupt: 

In [None]:
# test model accuracy on test data
model = CNNwithSelfAttention().to(device)
model.eval()

# load weights
model.load_state_dict(torch.load("model_21.pt", map_location=torch.device('cpu')))
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model.forward(images)
        
        _, predicted_value = torch.max(outputs, 1)

        correct += (predicted_value == labels).sum().item()
        total += labels.size(0)

    print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))

torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])


2. **Object Recognition with Vision Transformer**: 

Implement and train an Encoder only Transformer (ViT-like) for the above object recognition task. In other words, implement multi-headed self-attention for the image classification (i.e., appending a < class > token to the image patches that are accepted as input tokens). Compare the performance of the two implementations (try to keep the number of parameters to be comparable and use the same amount of training and testing
data). [10 Marks]

In [110]:
class AttentionBlock(nn.Module):
    def __init__(self, inputdim, qdim, vdim):
        super(AttentionBlock, self).__init__()
        self.qdim = qdim
        self.query = nn.Linear(inputdim, qdim)
        self.key = nn.Linear(inputdim, qdim)
        self.value = nn.Linear(inputdim, vdim)
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(self, x):
        
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        
        # Compute attention scores
        scores = torch.matmul(q, k.permute(0, 2, 1))
        scores = scores / torch.sqrt(torch.tensor(self.qdim))
        attention_weights = self.softmax(scores)
        
        # Apply attention weights to value
        output = torch.matmul(attention_weights, v)
        
        # return [B, N, V]
        return output

In [112]:
class MultiHeadAttention(nn.Module):
    def __init__(self, inputdim, qdim, vdim, num_heads):
        super(MultiHeadAttention, self).__init__()

        self.num_heads = num_heads
        self.heads = nn.ModuleList([AttentionBlock(inputdim, qdim, vdim) for _ in range(num_heads)])
        self.fc = nn.Linear(num_heads * vdim, inputdim)

    def forward(self, x):
        # input shape: [B, N, D]

        # Apply attention heads, shape = [B, N, V]
        head_outputs = [head(x) for head in self.heads]

        # Concatenate outputs, shape = [B, N, V * num_heads]
        multihead_output = torch.cat(head_outputs, dim=-1)

        # Apply linear layer, shape = [B, N, D]
        output = self.fc(multihead_output)

        return output

# temp = MultiHeadAttention(128, 32, 64, 4)
# x = torch.rand(100, 64, 128)
# output = temp.forward(x)
# print(output.shape)

torch.Size([100, 64, 128])


In [12]:
# Implementing a ViT like transformer model with self attention layers to perform object recognition on CIFAR-10 dataset using pytorch and cuda

# creating MultiHeadAttention class using the Self Attention class defined above

# class MultiHeadAttention(nn.Module):
#     def __init__(self, in_channels, num_heads):
#         super(MultiHeadAttention, self).__init__()

#         self.num_heads = num_heads
#         self.attention_heads = nn.ModuleList([SelfAttention(in_channels) for _ in range(num_heads)])
#         self.conv1x1 = nn.Conv2d(in_channels*num_heads, in_channels, kernel_size=1)

#     def forward(self, input):

#         attention_outputs = [head(input) for head in self.attention_heads]
#         concatenated_attention_output = torch.cat(attention_outputs, dim = 1)
#         output = self.conv1x1(concatenated_attention_output)

#         return output

# MHA = MultiHeadAttention(3, 8)
# x = torch.rand(1, 3, 32, 32)
# output = MHA.forward(x)


In [116]:
# Create the Transformer Encoder block in the ViT model

class Transformer_Encoder(nn.Module):
    def __init__(self, input_dim, qdim, vdim, num_heads, mlp_dim):
        # input_dim = D
        # 
        super(Transformer_Encoder, self).__init__()

        self.norm1 = nn.LayerNorm(input_dim)
        self.multi_head_attention = MultiHeadAttention(input_dim, qdim, vdim, num_heads)

        self.norm2 = nn.LayerNorm(input_dim)

        self.mlp = nn.Sequential(
            nn.Linear(input_dim, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, input_dim)
        )

    def forward(self, input):
        # input shape = [B, N, D]
        x = self.norm1(input)
        x = self.multi_head_attention(x)
        x_afterMHA = x + input

        # B, N, D
        x = self.norm2(x_afterMHA)

        x = x.view(x.size(0) * x.size(1), -1)
        x = self.mlp(x)
        x = x.view(input.shape[0], input.shape[1], -1)
        x = x + x_afterMHA

        return x

# transformer = Transformer_Encoder(128, 32, 64, 4, 512)
# x = torch.rand(100, 64, 128)
# output = transformer.forward(x)
# print(output.shape)

torch.Size([100, 64, 128])


In [121]:
class ViTEncoder(nn.Module):
    def __init__(self, in_channels, input_shape = [32, 32], q_dim = 64, v_dim = 64, num_heads = 8, mlp_dim = 256, num_layers_encoder = 4, proj_dim = 128, patch_size = 4, output_classes = 10):
        super(ViTEncoder, self).__init__()
        self.patch_size = patch_size
        self.unfold = nn.Unfold(kernel_size = patch_size, stride = patch_size)

        self.class_token = nn.Parameter(torch.randn(1, 1, proj_dim))
        self.positional_encoding = nn.Parameter(torch.randn((1, (input_shape[0] * input_shape[1])//(patch_size * patch_size) + 1, proj_dim)))

        self.linear_proj = nn.Linear(in_channels * patch_size * patch_size, proj_dim)
        self.transformer_blocks = nn.ModuleList([Transformer_Encoder(proj_dim, q_dim, v_dim, num_heads, mlp_dim) for _ in range(num_layers_encoder)])
        self.mlp_final = nn.Sequential(
            nn.Linear(proj_dim, proj_dim),
            nn.ReLU(),
            nn.Linear(proj_dim, output_classes)
        )

    def extract_patches(self, image):
        # image shape: [B, C, H, W]
        # return shape: [B, num_patches, C * patch_size * patch_size]
        patched_image = self.unfold(image).permute(0, 2, 1)

        return patched_image

    def forward(self, input):
        # [B, num_patches, C * patch_size * patch_size]
        x = self.extract_patches(input) 
        batch_size, num_patches, _ = x.size()
        
        # [B * num_patches, C * patch_size * patch_size]
        x = x.reshape (x.size(0) * x.size(1), -1) 
        
        # [B * num_patches, proj_dim]
        x = self.linear_proj(x) 

        # [B, num_patches, proj_dim]
        x = x.view(batch_size, num_patches, -1)

        # add class token
        class_token = self.class_token.expand(batch_size, -1, -1)
        x = torch.cat((class_token, x), dim = 1)
        
        # add positional encoding
        x = x + self.positional_encoding
        
        # [B, num_patches, proj_dim]
        for transformer in self.transformer_blocks:
            x = transformer(x)

        # take only 1st output (from class token)
        x = x[:, 0, :]

        # [B, output_classes] 
        x = self.linear_final(x)

        return x

encoder = ViTEncoder(3)
x = torch.rand(100, 3, 32, 32)
print(encoder.forward(x).shape)
    

# compute params in the encoder:
print("Params:", count_parameters(encoder))

torch.Size([100, 10])
Params: 1336970


In [63]:
# defining model:
model = ViTEncoder(3).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

# Train the model
total_steps = len(train_loader)
loss_list = []

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model.forward(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, total_steps, loss.item()))
    loss_list.append(loss.item())

KeyboardInterrupt: 

In [68]:
vit_model = ViTEncoder(3).to(device)
vit_model.load_state_dict(torch.load("vit_32.pt", map_location=torch.device('cpu')))

print(len(test_loader))
# test model accuracy on test data
vit_model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = vit_model.forward(images)
        
        _, predicted_value = torch.max(outputs, 1)

        correct += (predicted_value == labels).sum().item()
        total += labels.size(0)
        print(correct/total * 100)

    print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))

1
35.05
Accuracy of the model on the test images: 35.05 %


In [31]:
import torch
from torch import nn

from einops import rearrange
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    omega = torch.arange(dim // 4) / (dim // 4 - 1)
    omega = 1.0 / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)

# classes

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

class SimpleViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.pos_embedding = posemb_sincos_2d(
            h = image_height // patch_height,
            w = image_width // patch_width,
            dim = dim,
        ) 

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        self.pool = "mean"
        self.to_latent = nn.Identity()

        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        device = img.device

        x = self.to_patch_embedding(img)
        
        x += self.pos_embedding.to(device, dtype=x.dtype)

        print(x.shape)
        x = self.transformer(x)
        x = x.mean(dim = 1)

        x = self.to_latent(x)
        return self.linear_head(x)

vit = SimpleViT(
    image_size = 32,
    patch_size = 8,
    num_classes = 10,
    dim = 256,
    depth = 4,
    heads = 8,
    mlp_dim = 512,
    channels = 3
)
x = torch.randn(10, 3, 32, 32)
output = vit(x)

torch.Size([10, 16, 256])
