In [1]:
# Reload edited local modules (such as `utils`) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

# Exercise 1

<img src="./images/01.png" width=800>

In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchinfo import summary
from torchvision import transforms
from tqdm.autonotebook import tqdm

from utils import accuracy_score_wrapper, set_seed, train_network


In [None]:
import wandb
# Authenticate with Weights & Biases (may prompt for an API key if none is configured).
wandb.login()

In [3]:
# Reproducibility: request deterministic cuDNN kernels, then seed through the
# project's set_seed helper (presumably seeds torch/numpy -- confirm in utils).
torch.backends.cudnn.deterministic = True
set_seed(42)
# Train on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Dataset and DataLoader

In [4]:
# Download (if needed) and load both MNIST splits; ToTensor turns each image
# into a float tensor. The train split is built first, then the test split.
mnist_train, mnist_test = (
    torchvision.datasets.MNIST("./data", train=is_train, transform=transforms.ToTensor(), download=True)
    for is_train in (True, False)
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:56<00:00, 175626.19it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 205896.13it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:06<00:00, 241701.51it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2723060.14it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [None]:
class LargestDigitVariable(Dataset):
    """
    Variable-size "largest digit" bags built from an image-classification dataset.

    Each item is a bag of between 1 and ``max_to_sample`` images drawn at random
    (independently, so repeats are possible) from ``dataset``. The bag is
    zero-padded at the end to exactly ``max_to_sample`` slots. Its label is the
    largest label among the sampled (non-padding) images.

    NOTE(review): bags are re-drawn on every access using NumPy's global RNG,
    so the same index yields a different bag each time (this applies to a test
    split as well).
    """

    def __init__(self, dataset, max_to_sample=6):
        """
        dataset: indexable dataset whose items are (image_tensor, label) pairs.
        max_to_sample: maximum number of images per bag, and the padded bag length.
        """
        super().__init__()
        self.dataset = dataset
        self.max_to_sample = max_to_sample

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """
        Return ``(x, y)``.

        ``x`` has shape ``(max_to_sample, *image_shape)``: the real images come
        first, followed by all-zero padding images. ``y`` is the max label of
        the real images. ``index`` is ignored; every call samples a fresh bag.
        """
        # randint's upper bound is exclusive; +1 lets a bag actually contain
        # max_to_sample images (previously it never exceeded max_to_sample - 1).
        how_many = np.random.randint(1, self.max_to_sample + 1)
        # Sample from the full index range [0, len) so item 0 is reachable
        # (a lower bound of 1 excluded it and crashed on 1-item datasets).
        selected = np.random.randint(0, len(self.dataset), size=how_many)
        # Fetch each sampled item once and reuse it for both image and label.
        items = [self.dataset[i] for i in selected]
        images = [image for image, _ in items]
        # Padding images match the real images' shape/dtype/device, so any
        # image size works (not only 1x28x28).
        padding = [torch.zeros_like(images[0]) for _ in range(self.max_to_sample - how_many)]
        x_new = torch.stack(images + padding)
        y_new = max(label for _, label in items)
        return x_new, y_new

In [None]:
B = 128  # batch size

# Wrap each MNIST split so every item is a random, zero-padded bag of digits
# labelled by the largest digit it contains.
largest_train, largest_test = (LargestDigitVariable(split) for split in (mnist_train, mnist_test))
T = largest_train.max_to_sample  # padded number of items per bag

train_loader = DataLoader(largest_train, shuffle=True, batch_size=B)
test_loader = DataLoader(largest_test, batch_size=B)

In [None]:
def get_mask_fill(x, time_dimension=1, fill=0):
    """
    Build a boolean mask marking which time steps of ``x`` hold real data.

    An item at a given time step is treated as padding when *every* value
    belonging to it equals ``fill``; otherwise it is marked as present.

    :param x: (B, ..., T, ...) tensor
    :type x: tensor
    :param time_dimension: time dimension of x (negative indices allowed), defaults to 1
    :type time_dimension: int, optional
    :param fill: the constant used to denote a padding item, defaults to 0
    :type fill: int, optional
    :return: boolean tensor that keeps dim 0 and ``time_dimension`` of ``x``
        (``(B, T)`` for the default layout); True where the item is present.
    :rtype: tensor
    """
    if time_dimension < 0:
        # Normalise e.g. -2 to a positive index so it is excluded from the reduction.
        time_dimension += x.dim()
    dimensions_to_sum_over = [d for d in range(1, x.dim()) if d != time_dimension]
    with torch.no_grad():
        present = x != fill
        if not dimensions_to_sum_over:
            # Nothing to reduce (x is already (B, T)). Passing an empty `dim`
            # list to torch.sum would reduce over every dimension instead.
            return present
        return present.sum(dim=dimensions_to_sum_over) > 0

In [None]:
class DotScore(nn.Module):
    """
    Scaled dot-product attention score: score(s, c) = (s . c) / sqrt(H).
    """

    def __init__(self, H):
        """
        H: the number of dimensions of the states and the context; used for the
        1/sqrt(H) scaling.
        """
        super().__init__()
        self.H = H

    def forward(self, states, context):
        """
        states: (B, T, H) shape
        context: (B, H) shape
        output: (B, T, 1), giving a score to each of the T items based on the context
        """
        # (B, T, H) @ (B, H, 1) -> (B, T, 1). Dividing by sqrt(H) keeps the score
        # magnitude from growing with the hidden size.
        return torch.bmm(states, context.unsqueeze(2)) / self.H ** 0.5

In [None]:
class GeneralScore(nn.Module):
    """
    Bilinear ("general") attention score: score(s, c) = s^T W c + b.
    """

    def __init__(self, H):
        """
        H: the number of dimensions of both the states and the context.
        """
        super().__init__()
        # nn.Bilinear holds the learnable W (shape 1 x H x H) and, by default, a bias.
        self.w = nn.Bilinear(H, H, 1)

    def forward(self, states, context):
        """
        states: (B, T, H) shape
        context: (B, H) shape
        output: (B, T, 1), giving a score to each of the T items based on the context
        """
        n_items = states.size(1)
        # Give every one of the T items its own copy of the context: (B, H) -> (B, T, H).
        tiled_context = context.unsqueeze(1).repeat(1, n_items, 1)
        return self.w(states, tiled_context)  # (B, T, H) x (B, T, H) -> (B, T, 1)

In [None]:
class AdditiveAttentionScore(nn.Module):
    """
    Additive attention score: score(s, c) = v(tanh(w([s; c]))).

    Here ``w`` and ``v`` are Linear layers (with biases), and [s; c] is the
    concatenation of an item state with the context.
    """

    def __init__(self, H):
        """
        H: the number of dimensions of both the states and the context.
        """
        super().__init__()
        self.v = nn.Linear(H, 1)
        # Input is 2*H wide because state and context are concatenated.
        self.w = nn.Linear(2 * H, H)

    def forward(self, states, context):
        """
        states: (B, T, H) shape
        context: (B, H) shape
        output: (B, T, 1), giving a score to each of the T items based on the context
        """
        n_items = states.size(1)
        tiled_context = context.unsqueeze(1).expand(-1, n_items, -1)  # (B, H) -> (B, T, H)
        joint = torch.cat([states, tiled_context], dim=2)  # (B, T, 2*H)
        hidden = torch.tanh(self.w(joint))  # (B, T, H)
        return self.v(hidden)  # (B, T, 1)

In [None]:
class ApplyAttention(nn.Module):
    """
    This helper module is used to apply the results of an attention mechanism to a set of inputs.
    """

    def __init__(self):
        super().__init__()

    def forward(self, states, attention_scores, mask=None):
        """
        states: (B, T, H) shape giving the T different possible inputs
        attention_scores: (B, T, 1) score for each item at each context
        mask: None if all items are present. Else a boolean tensor of shape
            (B, T) (or (B, T, 1)), with `True` indicating which items are present / valid.

        returns: a tuple with two tensors. The first tensor is the final context
        from applying the attention to the states (B, H) shape. The second tensor
        is the weights for each state with shape (B, T, 1), summing to 1 over T.

        The caller's `attention_scores` tensor is not modified.
        """
        if mask is not None:
            if mask.dim() == attention_scores.dim() - 1:
                # (B, T) -> (B, T, 1) so it broadcasts against the scores.
                mask = mask.unsqueeze(-1)
            # Out-of-place fill: the old `attention_scores[~mask] = -1000.0`
            # overwrote the caller's tensor. A large finite negative value
            # (rather than -inf) drives masked weights to ~0 after the softmax.
            # A bag with every item masked still gets finite (uniform) weights
            # instead of NaNs.
            attention_scores = attention_scores.masked_fill(~mask, -1000.0)
        # Compute the weight for each score.
        weights = F.softmax(attention_scores, dim=1)  # (B, T, 1) still, but sum(T) = 1

        final_context = (states * weights).sum(dim=1)  # (B, T, H) * (B, T, 1) -> (B, H)
        return final_context, weights

### Backbone: Fully Connected

In [None]:
class Flatten2(nn.Module):
    """
    Takes a tensor of shape (A, B, C, D, E, ...)
    and flattens everything but the first two dimensions,
    giving a result of shape (A, B, C*D*E*...).

    Uses `reshape` rather than `view`, so non-contiguous inputs (e.g. after a
    transpose/permute) are accepted. Contiguous inputs are still returned as a
    view, without copying.
    """
    def forward(self, input):
        return input.reshape(input.size(0), input.size(1), -1)

In [None]:
class SmarterAttentionNetFC(nn.Module):
    """
    Attention-based bag classifier with a fully-connected backbone.

    Each of the T items in a bag is flattened and embedded independently. The
    masked mean of the embeddings is the query context for `score_net`, and the
    attention-weighted sum of the embeddings is fed to `prediction_net`.
    """

    def __init__(self, input_size, hidden_size, out_size, score_net=None):
        """
        input_size: number of features per item after flattening (D).
        hidden_size: width H of the backbone / attention space.
        out_size: number of outputs (class logits).
        score_net: optional scoring module mapping ((B, T, H), (B, H)) -> (B, T, 1);
            defaults to AdditiveAttentionScore(hidden_size).
        """
        super().__init__()
        # Flatten2 followed by three Linear + LeakyReLU stages: (B, T, ...) -> (B, T, H).
        widths = [input_size, hidden_size, hidden_size, hidden_size]
        stages = [Flatten2()]
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Linear(fan_in, fan_out))
            stages.append(nn.LeakyReLU())
        self.backbone = nn.Sequential(*stages)

        # Try swapping in a different scoring module and compare the results!
        if score_net is None:
            score_net = AdditiveAttentionScore(hidden_size)
        self.score_net = score_net

        self.apply_attn = ApplyAttention()

        # (B, H) -> (B, out_size)
        self.prediction_net = nn.Sequential(
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.LeakyReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, out_size),
        )

    def forward(self, input):
        """
        input: (B, T, ...) bag of items; items whose values are all 0 count as padding.
        returns: (B, out_size) logits.
        """
        present = get_mask_fill(input)  # (B, T) boolean
        embedded = self.backbone(input)  # (B, T, H)

        # Mean over the present items only. The 1e-10 avoids dividing by zero
        # when a bag has no present items.
        total = (present.unsqueeze(-1) * embedded).sum(dim=1)  # (B, H)
        count = present.sum(dim=1, keepdim=True) + 1e-10  # (B, 1)
        query = total / count

        scores = self.score_net(embedded, query)  # (B, T, H), (B, H) -> (B, T, 1)
        attended, _ = self.apply_attn(embedded, scores, mask=present)
        return self.prediction_net(attended)
        

In [None]:
# Hyper-parameters for the fully-connected variant: hidden width and number of classes.
neurons, classes = 256, 10
D = 28 * 28  # each 1x28x28 MNIST image is flattened to D features
attn_add_fc = SmarterAttentionNetFC(
    D, neurons, classes,
    score_net=AdditiveAttentionScore(neurons),
)

### Backbone: CNN

In [None]:
# Image channels (MNIST is grayscale) and a base conv filter count.
# NOTE(review): n_filters is not passed to SmarterAttentionNetCNN further down,
# which is constructed with 32 filters.
C, n_filters = 1, 16

In [None]:
def cnn_layer(in_filters, out_filters, kernel_size=3):
    """
    Build a Conv2d -> BatchNorm2d -> LeakyReLU block.

    in_filters: number of channels in the input to this layer
    out_filters: number of channels this layer outputs
    kernel_size: filter size; padding of kernel_size // 2 keeps W and H
        unchanged for odd kernel sizes
    """
    conv = nn.Conv2d(in_filters, out_filters, kernel_size, padding=kernel_size // 2)
    norm = nn.BatchNorm2d(out_filters)
    act = nn.LeakyReLU()  # default negative slope, kept for brevity
    return nn.Sequential(conv, norm, act)

In [None]:
class SmarterAttentionNetCNN(nn.Module):
    """
    Attention-based bag classifier with a convolutional backbone.

    Every image in a (B, T, C, W, H) bag is encoded by a small CNN and flattened
    to a ``hidden_size`` vector. The masked mean of those vectors is the query
    context for `score_net`, and the attention-weighted sum is passed to
    `prediction_net`.
    """

    def __init__(self, input_size, n_filters ,out_size, score_net=None):
        """
        input_size: (C, W, H) shape of a single image.
        n_filters: filters in the first two conv layers (the third uses 2*n_filters).
        out_size: number of outputs (class logits).
        score_net: optional scoring module mapping ((B, T, hidden_size), (B, hidden_size))
            -> (B, T, 1); defaults to AdditiveAttentionScore(hidden_size).
        """
        super().__init__()
        C, W, H = input_size
        self.conv_backbone = nn.Sequential(
            cnn_layer(C, n_filters, 3),  # (B*T, n_filters, W, H)
            cnn_layer(n_filters, n_filters, 3),  # (B*T, n_filters, W, H)
            nn.MaxPool2d(2),  # (B*T, n_filters, W//2, H//2)
            cnn_layer(n_filters, 2*n_filters, 3),  # (B*T, 2*n_filters, W//2, H//2)
        )  # returns (B*T, 2*n_filters, W//2, H//2)
        self.hidden_size = 2 * n_filters * (W//2) * (H//2)
        # NOTE(review): the flattened feature map is large (12544 for
        # n_filters=32 on 28x28 inputs). The hidden_size x hidden_size Linear
        # layers below, plus the default AdditiveAttentionScore, therefore hold
        # hundreds of millions of parameters. Pooling or projecting before
        # flattening would shrink this considerably.
        self.score_net = AdditiveAttentionScore(self.hidden_size) if (score_net is None) else score_net

        self.apply_attn = ApplyAttention()

        self.prediction_net = nn.Sequential(  # (B, hidden_size) -> (B, out_size)
            nn.BatchNorm1d(self.hidden_size),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.LeakyReLU(),
            nn.BatchNorm1d(self.hidden_size),
            nn.Linear(self.hidden_size, out_size),
        )

    def forward(self, input):
        """
        input: (B, T, C, W, H) bag of images; all-zero images count as padding.
        returns: (B, out_size) logits.
        """
        B, T, C, W, H = input.shape
        mask = get_mask_fill(input)  # (B, T)

        # Fold the bag dimension into the batch for the CNN. `reshape` (rather
        # than `view`) also accepts non-contiguous inputs.
        h = self.conv_backbone(input.reshape(B * T, C, W, H))
        h = h.reshape(B, T, -1)  # (B*T, 2*n_filters, W//2, H//2) -> (B, T, hidden_size)

        # Mean over the present items only. The 1e-10 avoids dividing by zero
        # for a bag that is entirely padding.
        h_context = (mask.unsqueeze(-1) * h).sum(dim=1)  # (B, hidden_size)
        h_context = h_context / (mask.sum(dim=1).unsqueeze(-1) + 1e-10)

        scores = self.score_net(h, h_context)  # (B, T, H), (B, H) -> (B, T, 1)
        final_context, _ = self.apply_attn(h, scores, mask=mask)
        return self.prediction_net(final_context)
        

In [None]:
attn_add_cnn = SmarterAttentionNetCNN(
    input_size=(1, 28, 28),  # (C, W, H) of one MNIST image
    n_filters=32,
    out_size=classes,
)

## Training

In [None]:
loss_func = nn.CrossEntropyLoss()
score_funcs = dict(Accuracy=accuracy_score_wrapper)
epochs = 10
# Hyper-parameters recorded with each wandb run.
config = dict(
    device=device,
    loss_func=type(loss_func).__name__,
    epochs=epochs,
    batch_size=B,
)

In [None]:
# Experiment name -> model; the training loop starts one wandb run per entry.
models = dict(fc=attn_add_fc, cnn=attn_add_cnn)

In [None]:
for experiment, model in models.items():

    optimizer = optim.AdamW(model.parameters())
    config['optimizer'] = optimizer.defaults

    # torchinfo's keyword is `input_size`. The previous `inpt_size` typo was
    # not recognised as an input shape (unknown kwargs are meant for
    # model.forward), so the summary carried no input/output shapes.
    model_stats = summary(model, input_size=(B, T, 1, 28, 28), device=device)
    with open('model_summary.txt', 'w') as f:
        f.write(str(model_stats))

    wandb.init(
        project="Exercise10_1",
        name=experiment,
        config=config
    )
    try:
        artifact = wandb.Artifact('model_summary', type='model_architecture')
        artifact.add_file('model_summary.txt')
        wandb.log_artifact(artifact)
        results = train_network(
            model=model,
            optimizer=optimizer,
            loss_func=loss_func,
            train_loader=train_loader,
            test_loader=test_loader,
            epochs=epochs,
            device=device,
            score_funcs=score_funcs
        )
    finally:
        # Close this experiment's run before the next wandb.init, even if
        # training fails. Previously only one finish() ran, after the loop.
        wandb.finish()

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 100%|██████████| 1/1 [00:33<00:00, 33.98s/it]
