In [1]:
%load_ext autoreload
%autoreload 2

# Exercise 2

<img src="./images/02.png" width=800>

In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import wandb
from torch.utils.data import DataLoader, Dataset
from torchinfo import summary
from torchvision import transforms
from tqdm.autonotebook import tqdm

from utils import accuracy_score_wrapper, set_seed, train_network, weight_reset


In [None]:
# Authenticate with Weights & Biases before the wandb.init() calls in the training loop.
# `wandb` is already imported in the imports cell at the top of the notebook.
# NOTE(review): presumably picks up WANDB_API_KEY or prompts interactively — confirm for your setup.
wandb.login()

In [3]:
# Reproducibility: request deterministic cuDNN kernels AND disable the cuDNN
# autotuner (benchmark mode), which may otherwise choose different algorithms run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

SEED = 42
# NOTE(review): set_seed comes from the local utils module; presumably seeds numpy/torch — confirm.
set_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Dataset and DataLoader

In [None]:
# Download (if needed) and load both MNIST splits, converting every image with ToTensor.
# The generator is consumed in order: the train split first, then the test split.
mnist_train, mnist_test = (
    torchvision.datasets.MNIST("./data", train=is_train_split, transform=transforms.ToTensor(), download=True)
    for is_train_split in (True, False)
)

In [10]:
class LargestDigitVariable(Dataset):
    """Bag-of-samples dataset whose label is the largest label inside the bag.

    Every access draws a random bag of between 1 and ``max_to_sample`` items
    (inclusive) from ``dataset``, stacks their images along a new leading
    "time" dimension and zero-pads the stack to exactly ``max_to_sample``
    entries, so all items share one shape. The label is the maximum label of
    the sampled items (for MNIST: the largest digit in the bag).

    Sampling uses the global ``np.random`` state and ignores ``index``, so the
    same index yields a different bag on each access; seed NumPy for
    reproducibility.
    """

    def __init__(self, dataset, max_to_sample=6):
        """
        :param dataset: indexable dataset returning ``(image_tensor, label)`` pairs
        :param max_to_sample: largest possible bag size; also the padded length T of every item
        """
        super().__init__()
        self.dataset = dataset
        self.max_to_sample = max_to_sample

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(x, y)`` where ``x`` is ``(max_to_sample, *image_shape)`` and ``y`` the largest label."""
        # randint's upper bound is exclusive: +1 so a full bag of max_to_sample items can occur
        how_many = np.random.randint(1, self.max_to_sample + 1)
        # Lower bound 0 so the first dataset element can be drawn as well
        selected = np.random.randint(0, len(self.dataset), size=how_many)
        # Fetch every sampled item exactly once and reuse it for both image and label
        samples = [self.dataset[i] for i in selected]
        images = [image for image, _ in samples]
        # All-zero padding shaped like the real images (get_mask_fill treats all-zero items as absent)
        padding = [torch.zeros_like(images[0]) for _ in range(self.max_to_sample - how_many)]
        x_new = torch.stack(images + padding)
        y_new = max(label for _, label in samples)
        return x_new, y_new

In [None]:
B = 128  # batch size

# Wrap both MNIST splits so each item is a zero-padded bag of images labelled by its largest digit.
# NOTE(review): bags are resampled on every access, so test-set metrics vary from epoch to epoch.
largest_train, largest_test = (LargestDigitVariable(split) for split in (mnist_train, mnist_test))
T = largest_train.max_to_sample  # padded bag length (the "time" dimension)

train_loader, test_loader = (
    DataLoader(largest_train, batch_size=B, shuffle=True),
    DataLoader(largest_test, batch_size=B),
)

In [6]:
def get_mask_fill(x, time_dimension=1, fill=0):
    """
    Compute a boolean mask of which items along the time dimension are present.

    An item (one index along ``time_dimension``) counts as present when at least
    one of its values differs from ``fill``; items made entirely of ``fill`` are
    treated as padding.

    :param x: (B, ..., T, ...)
    :type x: tensor
    :param time_dimension: time dimension of x (negative values count from the end), defaults to 1
    :type time_dimension: int, optional
    :param fill: the constant used to denote a missing / padded item, defaults to 0
    :type fill: int, optional
    :return: boolean tensor, (B, T) for time_dimension >= 1, True where an item is present
    :rtype: tensor
    """
    ndim = x.dim()
    if time_dimension < 0:
        # Normalise negative indices; otherwise the time dimension would be reduced too
        time_dimension += ndim
    # Reduce over every dimension except batch (0) and time
    dimensions_to_sum_over = [d for d in range(1, ndim) if d != time_dimension]
    with torch.no_grad():
        not_fill = x != fill
        if not dimensions_to_sum_over:
            # Nothing left to reduce (e.g. x is (B, T)): each entry is its own item.
            # torch.sum with an empty dim list reduces ALL dims, collapsing the mask to a scalar.
            return not_fill
        return not_fill.sum(dim=dimensions_to_sum_over) > 0

In [None]:
class DotScore(nn.Module):
    """Scaled dot-product attention score: score_t = h_t^T h_bar / sqrt(H)."""

    def __init__(self, H):
        """
        H: the number of dimensions coming into the dot score; scores are divided by sqrt(H).
        """
        super().__init__()
        self.H = H

    def forward(self, states, context):
        """
        states: (B, T, H) shape
        context: (B, H) shape
        output: (B, T, 1), giving a score to each of the T items based on the context
        """
        # compute $\boldsymbol{h}_t^\top \bar{\boldsymbol{h}}$ for every t in one batched matmul:
        # (B, T, H) @ (B, H, 1) -> (B, T, 1)
        return torch.bmm(states, context.unsqueeze(2)) / np.sqrt(self.H)

In [None]:
class GeneralScore(nn.Module):
    """Bilinear ("general") attention score: score_t = h_t^T W h_bar + b."""

    def __init__(self, H, init_method='random'):
        """
        H: The number of dimensions of both the states and the context.
        init_method: 'random' (default; nn.Bilinear's own initialisation) or
            'dotscore_like' (W starts at I / sqrt(H) plus uniform noise in
            [-0.01, 0.01] and the bias at 0, so the initial scores approximate
            DotScore's scaled dot product).

        Raises ValueError for any other init_method.
        """
        super().__init__()
        self.w = nn.Bilinear(H, H, 1)

        if init_method == 'dotscore_like':
            with torch.no_grad():
                # nn.Bilinear's weight has shape (out_features=1, H, H)
                identity_scaled = torch.eye(H).unsqueeze(0) / np.sqrt(H)
                epsilon = torch.empty_like(identity_scaled).uniform_(-.01, .01)
                self.w.weight.copy_(identity_scaled + epsilon)
                # DotScore has no bias term; zero it so the initial scores actually match it
                self.w.bias.zero_()
        elif init_method != 'random':
            raise ValueError(f"Unknown init_method: {init_method}")

    def forward(self, states, context):
        """
        states: (B, T, H) shape
        context: (B, H) shape
        output: (B, T, 1), giving a score to each of the T items based on the context
        """
        T = states.size(1)
        # Broadcast the context to every time step as a view instead of copying it T times
        context = context.unsqueeze(1).expand(-1, T, -1)  # (B, H) -> (B, T, H)
        return self.w(states, context)  # (B, T, H), (B, T, H) -> (B, T, 1)

In [None]:
class ApplyAttention(nn.Module):
    """
    This helper module is used to apply the results of an attention mechanism to a set of inputs.
    """

    def __init__(self):
        super().__init__()

    def forward(self, states, attention_scores, mask=None):
        """
        states: (B, T, H) shape giving the T different possible inputs
        attention_scores: (B, T, 1) score for each item at each context
        mask: None if all items are present. Else a boolean tensor of shape
            (B, T), with `True` indicating which items are present / valid.

        returns: a tuple with two tensors. The first tensor is the final context
        from applying the attention to the states (B, H) shape. The second tensor
        is the weights for each state with shape (B, T, 1).

        The caller's ``attention_scores`` tensor is never modified.
        """
        if mask is not None:
            # Give absent items a large negative score so softmax assigns them ~0 weight.
            # masked_fill is out-of-place: no mutation of the caller's tensor, and it also
            # works when attention_scores is a leaf that requires grad.
            attention_scores = attention_scores.masked_fill(~mask.unsqueeze(-1), -1000.0)
        # compute the weight for each score: (B, T, 1) still, but sum over T = 1
        weights = F.softmax(attention_scores, dim=1)

        final_context = (states * weights).sum(dim=1)  # (B, T, H) * (B, T, 1) -> (B, H)
        return final_context, weights

### Backbone: Fully Connected

In [None]:
class Flatten2(nn.Module):
    """
    Takes a tensor of shape (A, B, C, D, E, ...)
    and flattens everything but the first two dimensions,
    giving a result of shape (A, B, C*D*E*...)
    """
    def forward(self, input):
        # reshape (not view) so non-contiguous inputs, e.g. after permute(), also work;
        # it still returns a view without copying whenever the memory layout allows it.
        return input.reshape(input.size(0), input.size(1), -1)

In [None]:
class SmarterAttentionNetFC(nn.Module):
    """
    Attention-based bag classifier with a fully connected backbone.

    Each of the T items in a bag is embedded by ``backbone``; a context vector is
    the mean embedding of the non-padded items; ``score_net`` scores every item
    against that context; ``apply_attn`` turns the scores into an attention-weighted
    sum of the embeddings, which ``prediction_net`` maps to class logits.
    """

    def __init__(self, input_size, hidden_size, out_size, score_net=None):
        """
        input_size: number of features per item after flattening (e.g. 28*28 for MNIST)
        hidden_size: width H of the backbone and prediction layers
        out_size: number of outputs (classes)
        score_net: module mapping (states (B, T, H), context (B, H)) -> scores (B, T, 1);
            defaults to DotScore(hidden_size) when None.
        """
        super().__init__()
        self.backbone = nn.Sequential(
            Flatten2(),  # (B, T, ...) -> (B, T, D)
            nn.Linear(input_size, hidden_size),  # (B, T, D) -> (B, T, H)
            nn.LeakyReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.LeakyReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.LeakyReLU(),
        )  # returns (B, T, H)

        # The scoring function is pluggable (try changing it and compare results!).
        # forward() always needs one, so fall back to DotScore when none is given.
        self.score_net = DotScore(hidden_size) if score_net is None else score_net

        self.apply_attn = ApplyAttention()

        self.prediction_net = nn.Sequential(  # (B, H) -> (B, out_size)
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.LeakyReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, out_size),
        )

    def forward(self, input):
        """
        input: (B, T, ...) bag of T items; items whose values are all 0 are treated as padding.
        returns: (B, out_size) logits
        """
        mask = get_mask_fill(input)  # (B, T), True for present items

        h = self.backbone(input)  # (B, T, ...) -> (B, T, H)

        # Masked mean over the present items (torch.mean would include the padding):
        # sum the valid embeddings, then divide by how many there are, plus a small
        # value in case a bag was entirely empty.
        h_context = (mask.unsqueeze(-1) * h).sum(dim=1)  # (B, T, H) -> (B, H)
        h_context = h_context / (mask.sum(dim=1).unsqueeze(-1) + 1e-10)

        scores = self.score_net(h, h_context)  # (B, T, H), (B, H) -> (B, T, 1)

        final_context, _ = self.apply_attn(h, scores, mask=mask)  # (B, H)

        return self.prediction_net(final_context)
        

In [None]:
# Model hyper-parameters shared by all three variants
neurons = 256   # hidden size H of backbone, score nets and prediction head
classes = 10    # one logit per MNIST digit
D = 28 * 28     # flattened pixel count of a single 28x28 image

# Same architecture, differing only in the attention scoring function.
# Each score net is constructed before its enclosing network (argument evaluation order).
attn_dc = SmarterAttentionNetFC(
    D, neurons, classes,
    score_net=DotScore(neurons),
)
attn_gc = SmarterAttentionNetFC(
    D, neurons, classes,
    score_net=GeneralScore(neurons, init_method='random'),
)
attn_gcinit = SmarterAttentionNetFC(
    D, neurons, classes,
    score_net=GeneralScore(neurons, init_method='dotscore_like'),
)

## Training

In [None]:
loss_func = nn.CrossEntropyLoss()
score_funcs = {"Accuracy": accuracy_score_wrapper}
epochs = 10

# Run metadata logged to Weights & Biases. Values are kept as plain strings/ints:
# the loss is recorded by class name, and the device likewise as its string form
# rather than the torch.device object.
config = {
    'device': str(device),
    'loss_func': loss_func.__class__.__name__,
    'epochs': epochs,
    'batch_size': B,
}

In [None]:
# Experiments to run; each key becomes the W&B run name in the training loop
models = dict(
    dot_score=attn_dc,
    general_score=attn_gc,
    general_score_init=attn_gcinit,
)

In [None]:
results_by_model = {}  # experiment name -> value returned by train_network
for experiment, model in models.items():
    print(experiment)
    optimizer = optim.AdamW(model.parameters())
    # Per-run copy so the shared `config` dict is not mutated by the loop
    run_config = {**config, 'optimizer': optimizer.defaults}

    # Architecture summary, written to a file and uploaded as a W&B artifact.
    # `input_size` gives torchinfo an example input shape so it can run a forward
    # pass and report per-layer output shapes; verbose=0 keeps the table out of
    # the notebook output since it is saved to the file instead.
    model_summary = summary(model, input_size=(B, T, 1, 28, 28), verbose=0)
    # The summary table may contain non-ASCII box-drawing characters: force UTF-8
    with open('model_summary.txt', 'w', encoding='utf-8') as f:
        f.write(str(model_summary))

    run = wandb.init(
        project="Exercise10_2",
        name=experiment,
        config=run_config,
    )
    try:
        artifact = wandb.Artifact('model_summary', type='model_architecture')
        artifact.add_file('model_summary.txt')
        run.log_artifact(artifact)

        results = train_network(
            model=model,
            optimizer=optimizer,
            loss_func=loss_func,
            train_loader=train_loader,
            test_loader=test_loader,
            epochs=epochs,
            device=device,
            score_funcs=score_funcs,
        )
        results_by_model[experiment] = results
    finally:
        # Close the run explicitly so each experiment is logged as its own W&B run,
        # even if training raises.
        run.finish()

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 100%|██████████| 1/1 [00:33<00:00, 33.98s/it]
