# Recurrent Neural Networks and Transformer

## Text Classification with RNN

In [None]:
import math
import portalocker
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchtext.datasets import AG_NEWS
# Load the training split; samples are consumed as (label, text) pairs below.
train_iter = AG_NEWS(split='train')

# Let's check what the data looks like
# NOTE(review): whether len() is supported depends on the installed torchtext
# version -- confirm against the environment in use.
print(len(train_iter))
# Wrap in iter() so this works whether AG_NEWS returns a true iterator or an
# object that is only iterable (calling next() directly on the latter fails).
print(next(iter(train_iter)))

### <font size='4'>Implement a RNNCell</font>

In [None]:
# Documentation of nn.Module https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module
class RNNCell(torch.nn.Module):
    """
    A single recurrent cell computing h_t = tanh(W_x x_t + W_h h_{t-1}),
    where each linear map carries its own bias.
    """

    def __init__(self, input_dim: int, hidden_dim: int):
        """
        Build the two linear projections and the tanh activation.

        Inputs:
        - input_dim: Dimension of the input x_t
        - hidden_dim: Dimension of the hidden states h_{t-1} and h_t
        """
        # Required so that sub-modules and parameters get registered.
        super().__init__()

        # Input-to-hidden projection.
        self.linear_x = torch.nn.Linear(input_dim, hidden_dim)
        # Hidden-to-hidden projection.
        self.linear_h = torch.nn.Linear(hidden_dim, hidden_dim)
        # Squashing non-linearity applied to the summed projections.
        self.non_linear = torch.nn.Tanh()

    def forward(self, x_cur: torch.Tensor, h_prev: torch.Tensor):
        """
        Compute h_t given x_t and h_{t-1}.

        Inputs:
        - x_cur: x_t, a tensor with the shape of BxC, where B is the batch size
          and C is the input dimension.
        - h_prev: h_{t-1}, a tensor with the shape of BxH, where H is the hidden
          dimension.

        Return:
        - h_cur: h_t, a tensor with the shape of BxH.
        """
        pre_activation = self.linear_x(x_cur) + self.linear_h(h_prev)
        return self.non_linear(pre_activation)

In [None]:
# Sanity check: 2 samples with 8 input features and 16 hidden features must
# produce a hidden state of shape (2, 16).
x = torch.randn(2, 8)
h = torch.randn(2, 16)
model = RNNCell(input_dim=8, hidden_dim=16)
y = model(x, h)
assert tuple(y.shape) == (2, 16)
print(y.shape)

### <font size='4'>Implement a single-layer (single-stack) RNN</font>

In [None]:
class RNN(torch.nn.Module):
    """
    A single-layer, unidirectional RNN: one RNNCell is applied step by step
    over the sequence from left to right.
    """

    def __init__(self, input_dim: int, hidden_dim: int):
        """
        Constructor of the RNN module.

        Inputs:
        - input_dim: Dimension of the input x_t
        - hidden_dim: Dimension of the hidden state h_{t-1} and h_t
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        # The same cell (shared weights) is reused at every time step.
        self.rnn_cell = RNNCell(input_dim, hidden_dim)

    def forward(self, x: torch.Tensor):
        """
        Compute the hidden representation of every token in the input sequence.

        Input:
        - x: A tensor with the shape of BxLxC, where B is the batch size, L is the
          sequence length, and C is the channel dimension

        Return:
        - h: A tensor with the shape of BxLxH, where H is the hidden dimension of RNNCell
        """
        batch_size = x.shape[0]

        # The initial hidden state is all zeros, on the same device/dtype as x.
        state = x.new_zeros((batch_size, self.hidden_dim))

        states = []
        # Walk over the time axis; each slice is the BxC input of one step.
        for x_t in x.unbind(dim=1):
            state = self.rnn_cell(x_t, state)
            states.append(state)

        # Stack the per-step BxH states along a new time axis -> BxLxH.
        return torch.stack(states, dim=1)
        

In [None]:
# Sanity check: a (2, 10, 8) input must map to a (2, 10, 16) output.
x = torch.randn(2, 10, 8)
model = RNN(input_dim=8, hidden_dim=16)
y = model(x)
assert tuple(y.shape) == (2, 10, 16)
print(y.shape)

### <font size='4'>Implement a RNN-based text classifier</font>

In [None]:
class RNNClassifier(nn.Module):
    """
    A RNN-based classifier for text classification. Token indices are first
    converted into word embeddings, which are fed into a RNN. The hidden
    representations of all positions are averaged into a single sentence
    embedding, which is the input of a linear classifier.
    """
    def __init__(self,
            vocab_size: int, embed_dim: int, rnn_hidden_dim: int, num_class: int, pad_token: int
        ):
        """
        Constructor.

        Inputs:
        - vocab_size: Vocabulary size, indicating how many tokens we have in total.
        - embed_dim: The dimension of word embeddings
        - rnn_hidden_dim: The hidden dimension of the RNN.
        - num_class: Number of classes.
        - pad_token: The index of the padding token.
        """
        super(RNNClassifier, self).__init__()

        # Word embedding layer; the row at `pad_token` is the padding embedding.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_token)

        # RNN layer; batch_first=True so that inputs/outputs are BxLx(...).
        self.rnn = nn.RNN(input_size=embed_dim, hidden_size=rnn_hidden_dim, batch_first=True)

        # Classification layer
        self.fc = nn.Linear(rnn_hidden_dim, num_class)

    def init_weights(self):
        """
        Re-initialize the embedding and classifier weights uniformly in
        [-0.5, 0.5] and zero the classifier bias.

        Note: this overwrites every embedding row, including the padding row.
        """
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text):
        """
        Get classification scores (logits) of the input.

        Input:
        - text: Integer tensor of token indices with the shape of BxL.

        Return:
        - logits: Tensor with the shape of BxK, where K is the number of classes
        """
        # Word embeddings: BxL -> BxLxE
        embedded = self.embedding(text)

        # Hidden representation of every position: BxLxH
        rnn_output, _ = self.rnn(embedded)

        # Average over the sequence axis (all L positions are included): BxH
        avg_output = rnn_output.mean(dim=1)

        # Classification layer: BxK
        logits = self.fc(avg_output)

        return logits

In [None]:
# Let's run a sanity check of your model
vocab_size = 10
embed_dim = 16
rnn_hidden_dim = 32
num_class = 3

# Two identical rows containing every token index once: shape (2, vocab_size).
x = torch.arange(vocab_size).view(1, -1)
x = torch.cat((x, x), dim=0)
print('x.shape: {}'.format(x.shape))
model = RNNClassifier(vocab_size, embed_dim, rnn_hidden_dim, num_class, 0)
y = model(x)
assert len(y.shape) == 2 and y.shape[0] == 2 and y.shape[1] == num_class
print(y.shape)

# Repeat the check on the GPU only when one is present, so this cell does not
# crash on CPU-only machines.
if torch.cuda.is_available():
    model = model.to('cuda:0')
    x = x.to('cuda:0')
    y = model(x)
    print(y.shape, y)

### Set up data related stuff

In [None]:
# check here for details https://github.com/pytorch/text/blob/main/torchtext/data/utils.py#L52-#L166
from torchtext.data.utils import get_tokenizer
# check here for details https://github.com/pytorch/text/blob/main/torchtext/vocab/vocab_factory.py#L65-L113
from torchtext.vocab import build_vocab_from_iterator

# A tokenizer splits an input sentence into a list of tokens, punctuation included.
# For example
# >>> tokens = tokenizer("You can now install TorchText using pip!")
# >>> tokens
# >>> ['you', 'can', 'now', 'install', 'torchtext', 'using', 'pip', '!']
tokenizer = get_tokenizer('basic_english')

train_iter = AG_NEWS(split='train')


def yield_tokens(data_iter):
    """Yield the token list of each (label, text) sample, ignoring the label."""
    for _label, raw_text in data_iter:
        yield tokenizer(raw_text)


# Build the vocab object, which maps tokens to indices
# Check here for details https://github.com/pytorch/text/blob/main/torchtext/vocab/vocab.py
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])

# Out-of-vocabulary tokens are mapped to the index of "<unk>".
vocab.set_default_index(vocab["<unk>"])


def text_pipeline(x):
    """Convert a raw sentence into a list of token indices."""
    return vocab(tokenizer(x))


def label_pipeline(x):
    """Convert a label to an int and shift it down by one."""
    return int(x) - 1


# The padding token we need to use.
# The vocab lookup always returns a list, hence the length check and the [0].
# NOTE(review): '<pad>' is not listed in `specials`, so unless it occurs in the
# corpus this lookup falls back to the default index (that of '<unk>') --
# confirm this is intended.
PAD_TOKEN = vocab(tokenizer('<pad>'))
assert len(PAD_TOKEN) == 1
PAD_TOKEN = PAD_TOKEN[0]

### <font size='4'>Collate Batched Data with Data Loaders</font>

In [None]:
### Documentation of DataLoader https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
from torch.utils.data import DataLoader  

# Merges a list of samples to form a mini-batch of Tensor(s)
def collate_batch(batch):
    """
    Merge a list of (label, text) samples into padded mini-batch tensors.

    Input:
    - batch: A list of (label, text) pairs, where the length denotes the batch size.
      Labels go through `label_pipeline` and texts through `text_pipeline`.

    Returns:
    - batched_label: An int64 Tensor with the shape of (B,)
    - batched_text: An int64 Tensor of token indices with the shape of (B, L),
      where L is the length of the longest sentence in the batch. Shorter
      sentences are right-padded with PAD_TOKEN.
    """
    label_list, text_list = [], []
    for _label, _text in batch:
        label_list.append(label_pipeline(_label))
        text_list.append(torch.tensor(text_pipeline(_text), dtype=torch.int64))

    # Length of the longest sentence in this mini-batch.
    max_len = max(tokens.size(0) for tokens in text_list)

    # Start from a (B, L) tensor filled with the padding index, then copy each
    # sentence into the front of its row; the tail keeps the padding value.
    batched_text = torch.full((len(text_list), max_len), PAD_TOKEN, dtype=torch.int64)
    for row, tokens in enumerate(text_list):
        batched_text[row, :tokens.size(0)] = tokens

    batched_label = torch.tensor(label_list, dtype=torch.int64)

    return batched_label, batched_text

# Inspect the first collated mini-batch only.
train_iter = AG_NEWS(split='train')
dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)
for idx, (label, data) in enumerate(dataloader):
    if idx >= 1:
        break
    print(f'label.shape: {label.shape}')
    print(f'label: {label}')
    print(f'data.shape: {data.shape}')

### <font size='4'>Functions of training for a single epoch and evaluation</font>

In [None]:
import time

def train(model, dataloader, loss_func, device, grad_norm_clip, optimizer=None, epoch=None):
    """
    Train `model` for one pass over `dataloader`.

    Inputs:
    - model: The module to train; it is switched to training mode.
    - dataloader: Iterable yielding (label, text) mini-batches.
    - loss_func: Callable mapping (logits, label) to a scalar loss.
    - device: Device every mini-batch is moved to.
    - grad_norm_clip: Maximum gradient norm passed to clip_grad_norm_.
    - optimizer: Optimizer to step. When None, the module-level `optimizer`
      is used, which is how existing callers rely on this function.
    - epoch: Epoch number shown in the progress log. When None, the
      module-level `epoch` is used (0 if it does not exist).

    Progress (running accuracy since the previous log line) is printed every
    500 batches. Nothing is returned.
    """
    # Backward compatibility: fall back to the notebook-level globals when the
    # caller does not pass these explicitly.
    module_scope = globals()
    if optimizer is None:
        if "optimizer" not in module_scope:
            raise NameError(
                "train() needs an optimizer: pass `optimizer=` or define a "
                "module-level `optimizer`"
            )
        optimizer = module_scope["optimizer"]
    if epoch is None:
        epoch = module_scope.get("epoch", 0)

    model.train()
    total_acc, total_count = 0, 0
    log_interval = 500

    for idx, (label, text) in enumerate(dataloader):
        label = label.to(device)
        text = text.to(device)
        optimizer.zero_grad()

        # Forward pass and loss.
        logits = model(text)
        loss = loss_func(logits, label)

        # Backward pass, then clip and take exactly ONE optimizer step per batch.
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm_clip)
        optimizer.step()

        # Accuracy is measured on the logits computed before this update.
        total_acc += (logits.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            print('| epoch {:3d} | {:5d}/{:5d} batches '
                  '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),
                                              total_acc/total_count))
            # Restart the running statistics for the next logging window.
            total_acc, total_count = 0, 0

def evaluate(model, dataloader, loss_func, device):
    """
    Compute the classification accuracy of `model` over `dataloader`.

    Inputs:
    - model: The module to evaluate; it is switched to evaluation mode.
    - dataloader: Iterable yielding (label, text) mini-batches.
    - loss_func: Unused; kept so that existing callers keep working.
    - device: Device every mini-batch is moved to.

    Return:
    - The fraction of correctly classified samples, or 0.0 when the
      dataloader yields no samples.
    """
    model.eval()
    total_acc, total_count = 0, 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for label, text in dataloader:
            label = label.to(device)
            text = text.to(device)

            logits = model(text)

            # Count each sample exactly once.
            total_acc += (logits.argmax(1) == label).sum().item()
            total_count += label.size(0)

    if total_count == 0:
        return 0.0
    return total_acc/total_count

### <font size='4' color='red'>Task 1.6: Define the model and loss function to train the model (3 points)</font>

In [None]:
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset

# A GPU is required here. Because of this assert, the "cpu" fallback below is
# only reachable when assertions are disabled.
assert torch.cuda.is_available()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyper parameters
epochs = 3            # number of training epochs
lr = 0.0005           # learning rate
batch_size = 64       # batch size for training
word_embed_dim = 64   # dimension of the word embeddings
rnn_hidden_dim = 96   # hidden dimension of the recurrent layer

# The number of classes is the number of distinct labels in the training split.
train_iter = AG_NEWS(split='train')
num_class = len({label for (label, text) in train_iter})
vocab_size = len(vocab)

# Placeholders; both are assigned once the classifier is defined.
model = None
loss_func = None
# Define the classifier
class TextClassificationModel(nn.Module):
    """
    GRU-based text classifier: embed the token indices, run a GRU over them,
    and classify from the GRU output at the last position of the sequence.
    """

    def __init__(self, vocab_size, word_embed_dim, rnn_hidden_dim, num_class):
        """
        Inputs:
        - vocab_size: Number of rows of the embedding table.
        - word_embed_dim: Dimension of the word embeddings.
        - rnn_hidden_dim: Hidden dimension of the GRU.
        - num_class: Number of output classes.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, word_embed_dim)
        # batch_first=True: the GRU consumes and produces BxLx(...) tensors.
        self.rnn = nn.GRU(word_embed_dim, rnn_hidden_dim, batch_first=True)
        self.fc = nn.Linear(rnn_hidden_dim, num_class)

    def forward(self, text):
        """
        Input:
        - text: Integer tensor of token indices with the shape of BxL.

        Return:
        - logits: Tensor with the shape of BxK, where K is the number of classes.
        """
        per_step, _ = self.rnn(self.embedding(text))
        # Index -1 along the sequence axis picks the output at the last position.
        return self.fc(per_step[:, -1])
        
# Build the classifier and its loss, then move the model to the chosen device.
model = TextClassificationModel(vocab_size, word_embed_dim, rnn_hidden_dim, num_class)
loss_func = nn.CrossEntropyLoss()
model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-8)

# Best validation accuracy seen so far (None until the first epoch finishes).
total_accu = None

# Map-style datasets are needed for random_split and shuffling.
train_iter, test_iter = AG_NEWS()
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)

# Hold out 5% of the training data for validation.
num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = random_split(
    train_dataset, [num_train, len(train_dataset) - num_train]
)

# Only the training loader shuffles its samples.
train_dataloader = DataLoader(split_train_, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=batch_size, shuffle=False, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_batch)

# You should be able get a validation accuracy around 87%
for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(model, train_dataloader, loss_func, device, 1)
    accu_val = evaluate(model, valid_dataloader, loss_func, device)
    if total_accu is None or not total_accu > accu_val:
        # First epoch, or no drop in validation accuracy: record it.
        total_accu = accu_val
    else:
        # Validation accuracy dropped: advance the learning-rate schedule.
        scheduler.step()
    print('-' * 59)
    print('| end of epoch {:3d} | time: {:5.2f}s | '
          'valid accuracy {:8.3f} '.format(epoch,
                                           time.time() - epoch_start_time,
                                           accu_val))
    print('-' * 59)

## Text Classification with Transformer Encoder

### <font size='4'>Implement the multi-head attention module</font>

In [None]:
class MultiHeadAttention(nn.Module):
    """
    A module that computes multi-head scaled dot-product attention given
    query, key, and value tensors.
    """
    def __init__(self, input_dim: int, num_heads: int):
        """
        Constructor.

        Inputs:
        - input_dim: Dimension of the input query, key, and value. Here we assume they all have
          the same dimensions. But they could have different dimensions in other problems.
        - num_heads: Number of attention heads; must divide input_dim.
        """
        super(MultiHeadAttention, self).__init__()

        assert input_dim % num_heads == 0

        self.input_dim = input_dim
        self.num_heads = num_heads
        self.dim_per_head = input_dim // num_heads

        # Linear transformations for key, value, and query. Each one produces
        # the features of all heads at once (input_dim = num_heads * dim_per_head).
        self.key_layer = nn.Linear(input_dim, input_dim)
        self.value_layer = nn.Linear(input_dim, input_dim)
        self.query_layer = nn.Linear(input_dim, input_dim)

        # Output layer applied to the concatenated heads.
        self.output_layer = nn.Linear(input_dim, input_dim)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor=None):
        """
        Compute the attended feature representations.

        Inputs:
        - query: Tensor of the shape BxLxC, where B is the batch size, L is the sequence length,
          and C is the channel dimension
        - key: Tensor of the shape BxLxC
        - value: Tensor of the shape BxLxC
        - mask: Optional tensor broadcastable to the score tensor of shape
          B x num_heads x L x L. Positions where `mask == 0` (False) are
          excluded from the attention.

        Return:
        - out: Tensor of the shape BxLxC
        """
        b = query.shape[0]

        # Project, then split the channel axis into (num_heads, dim_per_head) and
        # move the head axis forward so all heads are processed in parallel:
        # B x L x C -> B x num_heads x L x dim_per_head
        Q = self.query_layer(query).view(b, -1, self.num_heads, self.dim_per_head).transpose(1, 2)
        K = self.key_layer(key).view(b, -1, self.num_heads, self.dim_per_head).transpose(1, 2)
        V = self.value_layer(value).view(b, -1, self.num_heads, self.dim_per_head).transpose(1, 2)

        # Scaled dot product between every query and every key: B x num_heads x L x L
        dot_prod_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.dim_per_head)

        if mask is not None:
            # A very negative score becomes a (near-)zero weight after the
            # softmax, so masked positions do not contribute to the output.
            dot_prod_scores = dot_prod_scores.masked_fill(mask == 0, -1e9)

        # Attention weights: normalize over the key axis.
        attention_scores = F.softmax(dot_prod_scores, dim=-1)

        # Weighted sum of the per-head values: B x num_heads x L x dim_per_head
        attended = torch.matmul(attention_scores, V)

        # Concatenate the heads by undoing the split above:
        # B x num_heads x L x dim_per_head -> B x L x C
        attended = attended.transpose(1, 2).contiguous().view(b, -1, self.input_dim)

        out = self.output_layer(attended)

        return out

In [None]:
# Sanity check: the attention output must have the same shape as its input.
x = torch.randn(2, 10, 8)
# Random boolean mask of shape (2, 10), expanded to (2, 1, 10, 1).
mask = torch.randn(2, 10) > 0.5
mask = mask[:, None, :, None]
num_heads = 4
model = MultiHeadAttention(8, num_heads)
y = model(x, x, x, mask)
assert tuple(y.shape) == tuple(x.shape)
print(y.shape)

### <font size='4'>Implement a Feedforward Network</font>

In [None]:
class FeedForwardNetwork(nn.Module):
    """
    A position-wise two-layer fully-connected network with a ReLU in between.
    Dropout is applied to the input and to the hidden activation.
    """

    def __init__(self, input_dim, ff_dim, dropout):
        """
        Inputs:
        - input_dim: Input (and output) dimension
        - ff_dim: Hidden dimension
        - dropout: Dropout ratio
        """
        super().__init__()

        # Expand to the hidden width, then project back to the input width.
        self.fc1 = nn.Linear(input_dim, ff_dim)
        self.fc2 = nn.Linear(ff_dim, input_dim)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor):
        """
        Input:
        - x: Tensor of the shape BxLxC, where B is the batch size, L is the sequence length,
          and C is the channel dimension

        Return:
        - y: Tensor of the shape BxLxC
        """
        # dropout -> fc1 -> ReLU
        hidden = self.relu(self.fc1(self.dropout(x)))
        # dropout -> fc2
        return self.fc2(self.dropout(hidden))
        

In [None]:
# Sanity check: the feedforward network must preserve the input shape.
x = torch.randn(2, 10, 8)
ff_dim = 4
model = FeedForwardNetwork(8, ff_dim, 0.1)
y = model(x)
assert tuple(y.shape) == tuple(x.shape)
print(y.shape)

### <font size='4'>Implement a Single Transformer Encoder Cell</font>

In [None]:
class TransformerEncoderCell(nn.Module):
    """
    A single cell (unit) for the Transformer encoder: a self-attention
    sub-layer followed by a feedforward sub-layer, each wrapped as
    LayerNorm(input + Dropout(sublayer(input))).
    """
    def __init__(self, input_dim: int, num_heads: int, ff_dim: int, dropout: float):
        """
        Inputs:
        - input_dim: Input dimension for each token in a sequence
        - num_heads: Number of attention heads in a multi-head attention module
        - ff_dim: The hidden dimension for a feedforward network
        - dropout: Dropout ratio for the output of the multi-head attention and feedforward
          modules.
        """
        super(TransformerEncoderCell, self).__init__()

        # Sub-layer 1: multi-head self-attention -> dropout -> layer norm
        # https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html#torch.nn.LayerNorm
        self.self_attention = MultiHeadAttention(input_dim=input_dim, num_heads=num_heads)
        self.dropout1 = nn.Dropout(p=dropout)
        self.norm1 = nn.LayerNorm(normalized_shape=input_dim)

        # Sub-layer 2: feedforward network -> dropout -> layer norm
        self.feedforward = FeedForwardNetwork(input_dim=input_dim, ff_dim=ff_dim, dropout=dropout)
        self.dropout2 = nn.Dropout(p=dropout)
        self.norm2 = nn.LayerNorm(normalized_shape=input_dim)

    def forward(self, x: torch.Tensor, mask: torch.Tensor=None):
        """
        Inputs:
        - x: Tensor of the shape BxLxC, where B is the batch size, L is the sequence length,
          and C is the channel dimension
        - mask: Tensor for multi-head attention

        Return:
        - y: Tensor of the shape BxLxC
        """
        # Self-attention sub-layer: query, key, and value are all x.
        attn_output = self.self_attention(x, x, x, mask=mask)
        attn_output = self.dropout1(attn_output)
        # Residual connection around the attention, then layer norm.
        y = self.norm1(x + attn_output)

        # Feedforward sub-layer.
        ff_output = self.feedforward(y)
        ff_output = self.dropout2(ff_output)
        # The residual must come from the input of THIS sub-layer (y), not from
        # the original x; otherwise the attention result is bypassed.
        y = self.norm2(y + ff_output)

        return y

In [None]:
# Sanity check: an encoder cell must preserve the input shape.
x = torch.randn(2, 10, 8)
# Random boolean mask of shape (2, 10), expanded to (2, 1, 10, 1).
mask = torch.randn(2, 10) > 0.5
mask = mask[:, None, :, None]
num_heads = 4
model = TransformerEncoderCell(8, num_heads, 32, 0.1)
y = model(x, mask)
assert tuple(y.shape) == tuple(x.shape)
print(y.shape)

### <font size='4'>Implement Transformer Encoder</font>

In [None]:
class TransformerEncoder(nn.Module):
    """
    A full encoder: a stack of TransformerEncoderCells followed by a final
    layer normalization.
    """

    def __init__(self, input_dim: int, num_heads: int, ff_dim: int, num_cells: int, dropout: float=0.1):
        """
        Inputs:
        - input_dim: Input dimension for each token in a sequence
        - num_heads: Number of attention heads in a multi-head attention module
        - ff_dim: The hidden dimension for a feedforward network
        - num_cells: Number of TransformerEncoderCells
        - dropout: Dropout ratio for the output of the multi-head attention and feedforward
          modules.
        """
        super().__init__()

        # nn.ModuleList registers every cell as a sub-module.
        # https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html#torch.nn.ModuleList
        cells = [
            TransformerEncoderCell(input_dim, num_heads, ff_dim, dropout)
            for _ in range(num_cells)
        ]
        self.encoder_cells = nn.ModuleList(cells)

        # Layer normalization applied to the output of the whole stack.
        self.norm = nn.LayerNorm(input_dim)

    def forward(self, x: torch.Tensor, mask: torch.Tensor=None):
        """
        Inputs:
        - x: Tensor of the shape BxLxC, where B is the batch size, L is the sequence length,
          and C is the channel dimension
        - mask: Tensor for multi-head attention

        Return:
        - y: Tensor of the shape of BxLxC, which is the normalized output of the encoder
        """
        hidden = x
        # Each cell consumes the output of the previous one; the same mask is
        # passed to every cell.
        for cell in self.encoder_cells:
            hidden = cell(hidden, mask)

        return self.norm(hidden)
        

In [None]:
# Sanity check: a 2-cell encoder must preserve the input shape.
x = torch.randn(2, 10, 8)
# A mask is built here but the model below is called without it.
mask = torch.randn(2, 10) > 0.5
mask = mask[:, None, :, None]
num_heads = 4
model = TransformerEncoder(8, num_heads, 32, 2, 0.1)
y = model(x)
assert tuple(y.shape) == tuple(x.shape)
print(y.shape)

### <font size='4'>Implement Positional Encoding</font>

In [None]:
class PositionalEncoding(nn.Module):
    """
    A module that adds sinusoidal positional encoding to each of the token's
    features, so that the Transformer is position aware.
    """
    def __init__(self, input_dim: int, max_len: int=10000):
        """
        Inputs:
        - input_dim: Input dimension about the features for each token
        - max_len: The maximum sequence length; it is also used as the base of
          the sinusoid wavelengths (in place of the constant 10000).

        Raises:
        - ValueError: If max_len is not positive (its logarithm is needed).
        """
        super(PositionalEncoding, self).__init__()

        if max_len <= 0:
            raise ValueError("max_len must be positive, got {}".format(max_len))

        self.input_dim = input_dim
        self.max_len = max_len

    def forward(self, x):
        """
        Compute the positional encoding and add it to x.

        Input:
        - x: Tensor of the shape BxLxC, where B is the batch size, L is the sequence length,
          and C is the channel dimension

        Return:
        - x: Tensor of the shape BxLxC, with the positional encoding added to the input
        """
        seq_len = x.shape[1]
        input_dim = x.shape[2]

        # Section 3.5 of https://arxiv.org/pdf/1706.03762.pdf, with 10000
        # replaced by max_len:
        #   PE(pos, 2i)   = sin(pos / max_len^(2i / d_model))
        #   PE(pos, 2i+1) = cos(pos / max_len^(2i / d_model))

        # Only the first seq_len positions are needed: L x 1
        position = torch.arange(seq_len, dtype=torch.float, device=x.device).unsqueeze(1)
        # 1 / max_len^(2i / d_model), computed in log space: ceil(C / 2) values
        div_term = torch.exp(
            torch.arange(0, input_dim, 2, dtype=torch.float, device=x.device)
            * (-math.log(self.max_len) / input_dim)
        )
        angles = position * div_term  # L x ceil(C / 2)

        pe = torch.zeros(seq_len, input_dim, dtype=torch.float, device=x.device)
        pe[:, 0::2] = torch.sin(angles)
        # For odd C there is one fewer odd column than even columns.
        pe[:, 1::2] = torch.cos(angles[:, : input_dim // 2])

        # Add a batch axis so that the encoding broadcasts over B.
        return x + pe.unsqueeze(0).to(x.dtype)


In [None]:
# Sanity check: adding the positional encoding must preserve the input shape.
x = torch.randn(1, 100, 20)
pe = PositionalEncoding(20)
y = pe(x)
assert tuple(y.shape) == tuple(x.shape)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

# Plot channels 4-7 of the positional encoding over the first 100 positions.
plt.figure(figsize=(15, 5))
# The second positional argument of PositionalEncoding is max_len, which must
# be positive; use its default instead of passing 0.
pe = PositionalEncoding(20)
# An all-zero input makes the output equal to the encoding itself.
y = pe(torch.zeros(1, 100, 20))
plt.plot(np.arange(100), y[0, :, 4:8].detach().numpy())
plt.legend(["dim %d"%p for p in [4,5,6,7]])

### <font size='4'>Implement a Transformer-based Text Classifier</font>

In [None]:
class TransformerClassifier(nn.Module):
    """
    A Transformer-based text classifier.

    Token indices are embedded, combined with positional encodings, passed through a
    stack of Transformer encoder layers, mean-pooled over the sequence dimension and
    finally projected to class logits.
    """
    def __init__(self, 
            vocab_size: int, embed_dim: int, num_heads: int, trx_ff_dim: int, 
            num_trx_cells: int, num_class: int, dropout: float=0.1, pad_token: int=0
        ):
        """
        Inputs:
        - vocab_size: Vocabulary size, indicating how many tokens we have in total.
        - embed_dim: The dimension of word embeddings
        - num_heads: Number of attention heads in a multi-head attention module
        - trx_ff_dim: The hidden dimension for a feedforward network
        - num_trx_cells: Number of Transformer encoder layers in the stack
        - num_class: Number of output classes
        - dropout: Dropout ratio
        - pad_token: The index of the padding token.
        """
        super(TransformerClassifier, self).__init__()
        
        self.embed_dim = embed_dim
        
        # word embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_token)
        
        # positional encoding layer
        self.pos_enc = PositionalEncoding(embed_dim)
        
        # Transformer encoder. The features in this notebook are laid out as BxLxC, so
        # the layers must be batch-first; with the PyTorch default (sequence-first)
        # attention would be computed across the samples of a batch instead of across
        # the tokens of a sentence.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=trx_ff_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_trx_cells)
        
        # output layer
        self.output_layer = nn.Linear(embed_dim, num_class)

    def forward(self, text, mask=None):
        """
        Inputs:
        - text: Tensor of token indices with the shape of BxL.
        - mask: Optional boolean tensor in which True marks a real (non-padding) token.
          Any shape that flattens to BxL is accepted, e.g. BxL itself or the Bx1x1xL
          masks built as `(x != pad).unsqueeze(-2).unsqueeze(1)` in this notebook.
        
        Return:
        - logits: Tensor with the shape of BxK, where K is the number of classes
        """
        
        # word embeddings, note we multiply the embeddings by a factor
        embedded = self.embedding(text) * math.sqrt(self.embed_dim)
        
        # positional encoding
        embedded = self.pos_enc(embedded)
        
        # nn.TransformerEncoder expects a BxL key-padding mask in which True marks the
        # positions to IGNORE, i.e. the inverse of the "True = keep" masks used here.
        key_padding_mask = None
        if mask is not None:
            key_padding_mask = ~mask.reshape(mask.shape[0], -1).bool()
        
        # Transformer encoder
        transformer_output = self.transformer_encoder(embedded, src_key_padding_mask=key_padding_mask)
        
        # average pooling over the features of all tokens
        pooled = torch.mean(transformer_output, dim=1)
        
        # output layer
        logits = self.output_layer(pooled)
        
        return logits

In [None]:
# Small configuration for a quick shape check of TransformerClassifier
vocab_size, embed_dim, num_heads = 10, 16, 4
trx_ff_dim, num_trx_cells, num_class = 16, 2, 3

# Two identical rows that contain every token index once; index 0 acts as padding
x = torch.arange(vocab_size).view(1, -1).repeat(2, 1)
mask = (x != 0).unsqueeze(-2).unsqueeze(1)
model = TransformerClassifier(vocab_size, embed_dim, num_heads, trx_ff_dim, num_trx_cells, num_class)
print('x: {}, mask: {}'.format(x.shape, mask.shape))
y = model(x, mask)
assert y.dim() == 2
assert y.shape[0] == x.shape[0]
assert y.shape[1] == num_class
print(y.shape)

### <font size='4'>Define the Model and Loss Function</font>

In [None]:
import time

# Use the GPU when one is available and fall back to the CPU otherwise. (A hard
# assert on CUDA would make this fallback unreachable.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
epochs = 5 # epoch
lr = 0.0005  # learning rate
batch_size = 64 # batch size for training
  
train_iter = AG_NEWS(split='train')
num_class = len({label for (label, text) in train_iter})
vocab_size = len(vocab)
emsize = 64

num_heads = 4
num_trx_cells = 2

gradient_norm_clip = 1

# Define a Transformer-based text classifier and a loss function.
model = TransformerClassifier(vocab_size, emsize, num_heads, 4*emsize, num_trx_cells, num_class).to(device)

# The training loop below refers to the loss as `loss_func`; `criterion` is kept as
# an alias so that either name resolves to the same object.
loss_func = nn.CrossEntropyLoss()
criterion = loss_func

optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, 1e-8)
total_accu = None

# You should be able to get a validation accuracy around 89%
for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(model, train_dataloader, loss_func, device, gradient_norm_clip)
    accu_val = evaluate(model, valid_dataloader, loss_func, device)
    # NOTE(review): the scheduler only advances when the validation accuracy drops
    # below the best value so far; confirm this is intended for a cosine schedule.
    if total_accu is not None and total_accu > accu_val:
        scheduler.step()
    else:
        total_accu = accu_val
    print('-' * 59)
    print('| end of epoch {:3d} | time: {:5.2f}s | '
          'valid accuracy {:8.3f} '.format(epoch,
                                           time.time() - epoch_start_time,
                                           accu_val))
    print('-' * 59)

## Image Classification with Transformer

### <font size='4'>Implement VisionTransformer for Image Classification</font>

In [None]:
class VisionTransformerClassifier(nn.Module):
    """
    In the model, we partition an image into non-overlapping patches. Each patch is treated as a token.
    A strided convolution maps every patch to an embedding whose dimension equals the number of
    RGB values in the patch: if the patch size is 4, the embedding dimension is 4*4*3.
    The token sequence is processed by a Transformer encoder, mean-pooled and classified.
    You can check this paper https://arxiv.org/pdf/2010.11929.pdf for reference.
    """
    def __init__(self, 
            patch_size: int, num_heads: int, trx_ff_dim: int, 
            num_trx_cells: int, num_class: int, dropout: float=0.1
        ):
        """
        Inputs:
        - patch_size: Size of the non-overlapping patches
        - num_heads: Number of attention heads
        - trx_ff_dim: Hidden dimension of the feedforward network in a Transformer encoder
        - num_trx_cells: Number of Transformer encoder layers in the stack
        - num_class: Number of image classes
        - dropout: Dropout ratio
        """
        super(VisionTransformerClassifier, self).__init__()
        
        self.patch_size = patch_size
        
        # Intuitively, we need 2D positional encodings for each patch according to its x and y coordinates.
        # But the reference paper https://arxiv.org/pdf/2010.11929.pdf shows there is no significant
        # difference in accuracy, so the 1D encoding is reused.

        # One embedding channel per RGB value inside a patch
        self.embedding_dim = patch_size * patch_size * 3
        # Informational only: patch count for a 224x224 input. forward() does not rely on
        # it and works with whatever number of patches the convolution produces.
        self.num_patches = (224 // patch_size) ** 2
        
        # Patch embedding layer: kernel == stride == patch_size, so each output location
        # sees exactly one non-overlapping patch.
        self.patch_embedding = nn.Conv2d(in_channels=3, 
                                         out_channels=self.embedding_dim, 
                                         kernel_size=patch_size, stride=patch_size)
        
        # Positional encoding layer
        self.positional_encoding = PositionalEncoding(self.embedding_dim, dropout)
        
        # Transformer encoder layers. Tokens are laid out as BxLxC, so the layers must be
        # batch-first; with the PyTorch default (sequence-first) attention would mix the
        # images of a batch instead of the patches of one image.
        encoder_layer = nn.TransformerEncoderLayer(d_model=self.embedding_dim, 
                                                   nhead=num_heads, 
                                                   dim_feedforward=trx_ff_dim, 
                                                   dropout=dropout,
                                                   batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_trx_cells)
        
        # Output layer
        self.output_layer = nn.Linear(self.embedding_dim, num_class)

    def forward(self, image: torch.Tensor):
        """
        Input:
        - image: Tensor of the shape BxCxHxW, where H and W are the height and width, respectively.
        
        Return:
        - logits: Classification logits of the shape BxK, where K is the number of classes

        Raises:
        - ValueError: if `image` is not a 4-dimensional tensor
        """
        if image.dim() != 4:
            raise ValueError("expected a BxCxHxW image tensor, got shape {}".format(tuple(image.shape)))
        
        # Partition the image into non-overlapping patches and embed them
        patches = self.patch_embedding(image)  # shape: B x C' x H' x W', where H' and W' depend on patch_size
        patches = patches.flatten(2).transpose(1, 2)  # shape: B x L x C', where L = H' x W'
        
        # Apply positional encoding
        patches = self.positional_encoding(patches)
        
        # Apply Transformer encoder
        patches = self.transformer_encoder(patches)
        
        # Average the patch features and classify
        logits = self.output_layer(patches.mean(dim=1))

        return logits

In [None]:
# Sanity check: a batch of two 32x32 RGB images must yield one row of logits per image
image = torch.randn((2, 3, 32, 32))
patch_size, num_heads, num_trx_cells = 4, 4, 2
trx_ff_dim = 16
dropout = 0.1
num_class = 5

vit = VisionTransformerClassifier(patch_size, num_heads, trx_ff_dim, num_trx_cells, num_class, dropout)
logits = vit(image)
assert logits.dim() == 2
assert logits.shape[0] == image.shape[0]
assert logits.shape[1] == num_class
print(logits.shape)

### Prepare Data Loaders

In [None]:
# Download and unpack the CIFAR-10 python archive into the datasets folder.
# The cell relies on IPython magics (`%cd`, `!cmd`) and therefore only runs in a notebook.
%cd ../datasets

# Pick the download method that matches your environment:
# 1 -- Linux
# 2 -- MacOS
# 3 -- Command Prompt on Windows
# 4 -- manually downloading the data
choice = 1


if choice == 1:
    # should work well on Linux and in Powershell on Windows
    !wget http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
elif choice == 2 or choice ==3:
    # if wget is not available for you, try curl
    # should work well on MacOS
    !curl http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz --output cifar-10-python.tar.gz
else:
    # Any other value: nothing is downloaded, the archive must already be in place
    print('Please manually download the data from http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz and put it under the datasets folder.')
# The archive is extracted for every choice, including the manual one
!tar -xzvf cifar-10-python.tar.gz

# Remove the archive after extraction; `del` is used for choice 3, `rm` otherwise
if choice==3:
    !del cifar-10-python.tar.gz
else:
    !rm cifar-10-python.tar.gz

In [None]:
from six.moves import cPickle as pickle
import numpy as np
import os
from imageio import imread
import platform

def load_pickle(f):
    """
    Unpickle and return the object stored in the open binary file `f`.

    Under Python 3 the payload is decoded with the latin1 encoding; under Python 2
    the plain loader is used. Any other major version raises ValueError.
    """
    version = platform.python_version_tuple()
    major = version[0]
    if major == '3':
        return pickle.load(f, encoding='latin1')
    if major == '2':
        return pickle.load(f)
    raise ValueError("invalid python version: {}".format(version))

def load_CIFAR_batch(filename):
  """
  Load a single CIFAR batch file.

  Inputs:
  - filename: Path of a pickled batch holding the keys 'data' and 'labels'

  Return:
  - X: Float array of the shape Nx32x32x3 (channels last)
  - Y: Array with the N labels

  The number of images N is inferred from the file instead of being fixed to 10000,
  so regular batches as well as smaller subsets can be loaded.
  """
  with open(filename, 'rb') as f:
    datadict = load_pickle(f)
  X = datadict['data']
  Y = datadict['labels']
  # Every row is viewed as a 3x32x32 image and then transposed to 32x32x3
  X = X.reshape(-1, 3, 32, 32).transpose(0,2,3,1).astype("float")
  Y = np.array(Y)
  return X, Y

def load_CIFAR10(ROOT):
  """
  Load all of CIFAR-10 from the directory ROOT.

  The files data_batch_1 ... data_batch_5 are concatenated into the training set and
  test_batch provides the test set. Returns (Xtr, Ytr, Xte, Yte).
  """
  train_batches = [
    load_CIFAR_batch(os.path.join(ROOT, 'data_batch_%d' % (b, )))
    for b in range(1, 6)
  ]
  Xtr = np.concatenate([images for images, _ in train_batches])
  Ytr = np.concatenate([labels for _, labels in train_batches])
  Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
  return Xtr, Ytr, Xte, Yte


def get_CIFAR10_data(cifar10_dir, num_training=49000, num_validation=1000, num_test=1000,
                     subtract_mean=True):
    """
    Load the CIFAR-10 dataset from disk and prepare train/val/test splits.

    The first `num_training` training images form the training split, the following
    `num_validation` images the validation split, and the first `num_test` test images
    the test split. When `subtract_mean` is set, the mean training image is subtracted
    from every split. Images are returned channels-first (NxCxHxW) in a dictionary with
    the keys X_train, y_train, X_val, y_val, X_test and y_test.
    """
    # Load the raw CIFAR-10 data
    X_all, y_all, X_test_all, y_test_all = load_CIFAR10(cifar10_dir)

    # Index arrays selecting each split
    val_idx = np.arange(num_training, num_training + num_validation)
    train_idx = np.arange(num_training)
    test_idx = np.arange(num_test)

    X_val, y_val = X_all[val_idx], y_all[val_idx]
    X_train, y_train = X_all[train_idx], y_all[train_idx]
    X_test, y_test = X_test_all[test_idx], y_test_all[test_idx]

    # Normalize the data: subtract the mean training image from every split (in place)
    if subtract_mean:
        mean_image = X_train.mean(axis=0)
        for split in (X_train, X_val, X_test):
            split -= mean_image

    # Move the channel axis to the front and package everything into a dictionary
    def channels_first(images):
        return images.transpose(0, 3, 1, 2).copy()

    return {
        'X_train': channels_first(X_train), 'y_train': y_train,
        'X_val': channels_first(X_val), 'y_val': y_val,
        'X_test': channels_first(X_test), 'y_test': y_test,
    }

# Build the train, val, and test splits (see get_CIFAR10_data for the details)
# and report the shape of every array.
cifar10_dir = 'cifar-10-batches-py'
data = get_CIFAR10_data(cifar10_dir)
for k in data:
    v = data[k]
    print(('%s: ' % k, v.shape))

In [None]:
from torch.utils.data.dataset import TensorDataset

def make_dataloader(x, y, batch_size, is_train):
    """
    Wrap numpy images `x` and labels `y` into a DataLoader.

    Every sample is a (label, image) pair with an int64 label and a float32 image.
    Training loaders are shuffled and drop the last incomplete batch; evaluation
    loaders keep the original order and all samples.
    """
    labels = torch.from_numpy(y).long()
    images = torch.from_numpy(x).float()
    pairs = TensorDataset(labels, images)
    return DataLoader(
        pairs,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=2,
        drop_last=is_train,
    )
    
# Peek at the first two training batches
train_loader = make_dataloader(data['X_train'], data['y_train'], batch_size=8, is_train=True)
for idx, (lab, im) in enumerate(train_loader):
    if idx >= 2:
        break
    print(im.shape, lab)

### <font size='4'>Define the Model and Loss Function</font>

In [None]:
import time

patch_size = 4
embed_dim = 128
num_heads = 4
trx_ff_dim = 128
num_trx_cells = 2
num_class = 10

# Define the model
model = VisionTransformerClassifier(patch_size=patch_size,
                                    num_heads=num_heads,
                                    trx_ff_dim=trx_ff_dim,
                                    num_trx_cells=num_trx_cells,
                                    num_class=num_class)

# Define the loss function. The training loop below refers to it as `loss_func`;
# `criterion` is kept as an alias so that either name resolves to the same object.
loss_func = nn.CrossEntropyLoss()
criterion = loss_func

# Xavier initialization for every weight matrix / kernel
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
model = model.to(device)

batch_size = 16
train_loader = make_dataloader(data['X_train'], data['y_train'], batch_size, True)
# NOTE(review): the loader called `val_loader` is built from the test split although
# data['X_val'] / data['y_val'] exist; confirm that this is intended.
val_loader = make_dataloader(data['X_test'], data['y_test'], batch_size, False)

# Hyperparameters
epochs = 5 # epoch
lr = 0.001
# The loop passes `gradient_norm_clip`; binding only `gradient_norm_clips` would leave
# a stale value from an earlier cell in effect. Both names are bound to the same value.
gradient_norm_clip = 0.1
gradient_norm_clips = gradient_norm_clip

optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, 1e-8)
total_accu = None

# You should be able to get an accuracy around 36%
for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(model, train_loader, loss_func, device, gradient_norm_clip)
    accu_val = evaluate(model, val_loader, loss_func, device)
    if total_accu is not None and total_accu > accu_val:
        scheduler.step()
    else:
        total_accu = accu_val
    print('-' * 59)
    print('| end of epoch {:3d} | time: {:5.2f}s | '
          'valid accuracy {:8.3f} '.format(epoch,
                                           time.time() - epoch_start_time,
                                           accu_val))
    print('-' * 59)

## Machine Translation with Transformer

### <font size='4'>Implement Transformer Decoder Cell</font>

In [None]:
class TransformerDecoderCell(nn.Module):
    """
    A single cell (unit) of the Transformer decoder.

    The cell applies, in this order, masked self-attention over the decoder tokens,
    cross-attention from the decoder tokens to the encoder output, and a position-wise
    feedforward network. Every sub-layer is wrapped in a residual connection followed
    by layer normalization.
    """
    def __init__(self, input_dim: int, num_heads: int, ff_dim: int, dropout: float=0.1):
        """
        Inputs:
        - input_dim: Input dimension for each token in a sequence
        - num_heads: Number of attention heads in a multi-head attention module
        - ff_dim: The hidden dimension for a feedforward network
        - dropout: Dropout ratio used inside the attention modules and the feedforward
          network.
        """
        super(TransformerDecoderCell, self).__init__()

        # Needed to expand the per-batch masks to one mask per attention head
        self.num_heads = num_heads
    
        # Multi-head attention layer for the decoder. Features are laid out as BxLxC,
        # hence batch_first=True (the PyTorch default is sequence-first).
        self.self_attn = nn.MultiheadAttention(embed_dim=input_dim, 
                                                num_heads=num_heads, 
                                                dropout=dropout,
                                                batch_first=True)
        self.self_attn_ln = nn.LayerNorm(input_dim)

        # Multi-head attention layer for the encoder-decoder attention
        self.enc_dec_attn = nn.MultiheadAttention(embed_dim=input_dim, 
                                                  num_heads=num_heads, 
                                                  dropout=dropout,
                                                  batch_first=True)
        self.enc_dec_attn_ln = nn.LayerNorm(input_dim)

        # Feedforward network
        self.ffn = nn.Sequential(
            nn.Linear(input_dim, ff_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, input_dim),
            nn.Dropout(dropout)
        )
        self.ffn_ln = nn.LayerNorm(input_dim)

    @staticmethod
    def _additive_mask(mask, batch_size, num_heads, query_len, key_len, like):
        """
        Convert a "True = may attend" mask into the additive float mask expected by
        nn.MultiheadAttention.

        Accepted shapes are LqxLk, Bx(1|Lq)xLk and Bx1x(1|Lq)xLk; size-1 dimensions are
        broadcast. Blocked positions receive the most negative finite value instead of
        -inf, so a query whose keys are all blocked yields finite output instead of NaN.
        Returns None when `mask` is None.
        """
        if mask is None:
            return None
        keep = mask.to(device=like.device, dtype=torch.bool)
        if keep.dim() == 2:
            keep = keep.expand(query_len, key_len)
        else:
            if keep.dim() == 3:
                keep = keep.unsqueeze(1)
            # nn.MultiheadAttention wants one LqxLk mask per (sample, head) pair,
            # ordered with the sample index varying slowest.
            keep = keep.expand(batch_size, num_heads, query_len, key_len)
            keep = keep.reshape(batch_size * num_heads, query_len, key_len)
        additive = torch.zeros(keep.shape, dtype=like.dtype, device=like.device)
        return additive.masked_fill(~keep, torch.finfo(like.dtype).min)
        
    def forward(self, x: torch.Tensor, encoder_output: torch.Tensor, src_mask=None, tgt_mask=None):            
        """
        Inputs: 
        - x: Tensor of BxLdxC, word embeddings on the decoder side
        - encoder_output: Tensor of BxLexC, word embeddings on the encoder side
        - src_mask: Optional boolean tensor (e.g. Bx1x1xLe), True marks encoder tokens
          that may be attended to
        - tgt_mask: Optional boolean tensor (e.g. Bx1xLdxLd), True marks decoder tokens
          that may be attended to
        
        Return:
        - y: Tensor of BxLdxC. Attended features for all tokens on the decoder side.
        """
        batch_size, tgt_len = x.shape[0], x.shape[1]
        src_len = encoder_output.shape[1]

        self_mask = self._additive_mask(tgt_mask, batch_size, self.num_heads, tgt_len, tgt_len, x)
        cross_mask = self._additive_mask(src_mask, batch_size, self.num_heads, tgt_len, src_len, x)
        
        # Self-attention over the decoder tokens
        self_att, _ = self.self_attn(x, x, x, attn_mask=self_mask)
        x = self.self_attn_ln(x + self_att)
        
        # Encoder-decoder attention: decoder tokens query the encoded features
        enc_dec_att, _ = self.enc_dec_attn(x, encoder_output, encoder_output, attn_mask=cross_mask)
        x = self.enc_dec_attn_ln(x + enc_dec_att)
        
        # Feedforward network
        ffn_output = self.ffn(x)
        y = self.ffn_ln(x + ffn_output)
        
        return y

In [None]:
# Shape check for a single decoder cell with random features and random boolean masks
dec_feats = torch.randn((3, 10, 16))
dec_mask = torch.randn((3, 1, 10, 10)) > 0.5

enc_feats = torch.randn((3, 12, 16))
enc_mask = torch.randn((3, 1, 1, 12)) > 0.5

model = TransformerDecoderCell(16, 2, 32, 0.1)
z = model(dec_feats, enc_feats, enc_mask, dec_mask)
assert z.dim() == dec_feats.dim()
assert all(dim_z == dim_x for dim_z, dim_x in zip(z.shape, dec_feats.shape))
print(z.shape)

### <font size='4'>Implement Transformer Decoder</font>

In [None]:
class TransformerDecoder(nn.Module):
    """
    A TransformerDecoder is a stack of multiple TransformerDecoderCells and a Layer Norm.
    """
    def __init__(self, input_dim: int, num_heads: int, ff_dim: int, num_cells: int, dropout=0.1):
        """
        Inputs:
        - input_dim: Input dimension for each token in a sequence
        - num_heads: Number of attention heads in a multi-head attention module
        - ff_dim: The hidden dimension for a feedforward network
        - num_cells: How many TransformerDecoderCells in stack
        - dropout: Dropout ratio passed on to every TransformerDecoderCell
        """
        super(TransformerDecoder, self).__init__()
        
        # nn.ModuleList registers the cells as sub-modules so that their parameters are
        # visible to optimizers and .to(device).
        # https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html#torch.nn.ModuleList
        self.decoder_cells = nn.ModuleList(
            [TransformerDecoderCell(input_dim, num_heads, ff_dim, dropout) for _ in range(num_cells)]
        )
        # Layer normalization applied to the output of the entire decoder stack
        self.norm = nn.LayerNorm(input_dim)
    
    def forward(self, x: torch.Tensor, encoder_output: torch.Tensor, src_mask=None, tgt_mask=None):            
        """
        Inputs: 
        - x: Tensor of BxLdxC, word embeddings on the decoder side
        - encoder_output: Tensor of BxLexC, word embeddings on the encoder side
        - src_mask: Tensor, masks of the tokens on the encoder side
        - tgt_mask: Tensor, masks of the tokens on the decoder side
        
        Return:
        - y: Tensor of BxLdxC. Attended features for all tokens on the decoder side.
        """
        
        # Feed x through the stack of cells, then normalize the final output. The
        # attribute names must match the ones registered in the constructor
        # (`decoder_cells` and `norm`).
        y = x
        for cell in self.decoder_cells:
            y = cell(y, encoder_output, src_mask, tgt_mask)
        y = self.norm(y)
        
        return y

In [None]:
# Shape check for the full decoder stack with random features and random boolean masks
dec_feats = torch.randn((3, 10, 16))
dec_mask = torch.randn((3, 1, 10, 10)) > 0.5

enc_feats = torch.randn((3, 12, 16))
enc_mask = torch.randn((3, 1, 1, 12)) > 0.5

model = TransformerDecoder(16, 2, 32, 2, 0.1)
z = model(dec_feats, enc_feats, enc_mask, dec_mask)
assert z.dim() == dec_feats.dim()
assert all(dim_z == dim_x for dim_z, dim_x in zip(z.shape, dec_feats.shape))
print(z.shape)

### <font size='4'>Implement a Transformer-based Sequence-to-sequence model</font>

In [None]:
class Seq2SeqTransformer(nn.Module):
    """
    Transformer-based sequence-to-sequence model.

    Source and target tokens are embedded and combined with positional encodings. The
    source sequence is encoded, the target sequence is decoded while attending to the
    encoder output, and a linear layer maps every decoder feature to one logit per
    token of the target vocabulary.
    """
    def __init__(self, 
            num_encoder_layers: int, num_decoder_layers: int, embed_dim: int,
            num_heads: int, src_vocab_size: int, tgt_vocab_size: int,
            trx_ff_dim: int = 512, dropout: float = 0.1, pad_token: int=0
        ):
        """
        Inputs:
        - num_encoder_layers: How many TransformerEncoderCell in stack
        - num_decoder_layers: How many TransformerDecoderCell in stack
        - embed_dim: Word embeddings dimension
        - num_heads: Number of attention heads
        - src_vocab_size: Number of tokens in the source language vocabulary
        - tgt_vocab_size: Number of tokens in the target language vocabulary
        - trx_ff_dim: Hidden dimension in the feedforward network
        - dropout: Dropout ratio
        - pad_token: Index of the padding token in both vocabularies
        """
        super(Seq2SeqTransformer, self).__init__()
        
        self.embed_dim = embed_dim
        
        # Word embeddings for both the source and target languages
        self.src_token_embed = nn.Embedding(src_vocab_size, embed_dim, padding_idx=pad_token)
        self.tgt_token_embed = nn.Embedding(tgt_vocab_size, embed_dim, padding_idx=pad_token)
        
        # Positional encoding for both the source and target languages
        self.src_positional_encoding = PositionalEncoding(embed_dim, dropout)
        self.tgt_positional_encoding = PositionalEncoding(embed_dim, dropout)
        
        # Transformer encoder
        self.encoder = TransformerEncoder(
            input_dim=embed_dim,
            num_heads=num_heads,
            ff_dim=trx_ff_dim,
            num_cells=num_encoder_layers,
            dropout=dropout
        )
        
        # Transformer decoder
        self.decoder = TransformerDecoder(
            input_dim=embed_dim,
            num_heads=num_heads,
            ff_dim=trx_ff_dim,
            num_cells=num_decoder_layers,
            dropout=dropout
        )
        
        # Output layer: one class per token of the target vocabulary
        self.output_layer = nn.Linear(embed_dim, tgt_vocab_size)
        
    def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
        """
        Inputs:
        - src: Tensor of BxLe, word indexes in the source language
        - tgt: Tensor of BxLd, word indexes in the target language
        - src_mask: Tensor, masks of the tokens on the encoder side
        - tgt_mask: Tensor, masks of the tokens on the decoder side
        
        Return:
        - logits: Tensor of BxLdxK. K is the number of classes in the output, i.e. the
          size of the target vocabulary.
        """
        
        # Get word embeddings. Note that they are scaled by sqrt(embed_dim).
        src_embed = self.src_token_embed(src) * math.sqrt(self.embed_dim)
        tgt_embed = self.tgt_token_embed(tgt) * math.sqrt(self.embed_dim)
        
        # Add positional encodings to the word embeddings
        src_embed = self.src_positional_encoding(src_embed)
        tgt_embed = self.tgt_positional_encoding(tgt_embed)
        
        # Encode the source sentence, then decode the target while attending to it.
        # The masks are passed by keyword: the decoder signature is
        # (x, encoder_output, src_mask, tgt_mask), and swapping the two masks would
        # apply the padding mask to self-attention and the causal mask to cross-attention.
        encoder_output = self.encoder(src_embed, src_mask)
        decoder_output = self.decoder(tgt_embed, encoder_output, src_mask=src_mask, tgt_mask=tgt_mask)
        
        # Get the logits
        logits = self.output_layer(decoder_output)
        
        return logits
        

The model has a single output layer, `output_layer`: an `nn.Linear` that maps each decoder feature to `tgt_vocab_size` logits, i.e. one class per token in the target-language vocabulary.

In [None]:
# Source side: two identical rows containing every source token index once
src_vocab_size = 10
src = torch.arange(src_vocab_size).view(1, -1).repeat(2, 1)
src_mask = torch.randn((2, 1, 1, src_vocab_size)) > 0.5

# Target side: two identical rows containing every target token index once
tgt_vocab_size = 12
tgt = torch.arange(tgt_vocab_size).view(1, -1).repeat(2, 1)
tgt_mask = torch.randn((2, 1, tgt_vocab_size, tgt_vocab_size)) > 0.5

model = Seq2SeqTransformer(2, 2, 16, 2, src_vocab_size, tgt_vocab_size, 32, 0.1, 0)
z = model(src, tgt, src_mask, tgt_mask)
print(z.shape)

### Create Attention Masks

In [None]:
def subsequent_mask(size):
    """
    Mask out subsequent positions.

    Returns a boolean tensor of the shape 1 x size x size in which entry [0, i, j] is
    True when j <= i, i.e. position i may look at itself and at earlier positions only.
    """
    lower_triangle = torch.ones(1, size, size).tril()
    return lower_triangle == 1


def create_mask(src, tgt, pad_token=0):
    """
    Build the attention masks for a batch of source and target sequences.

    Inputs:
    - src: Tensor of BxLe, word indexes in the source language
    - tgt: Tensor of BxLd, word indexes in the target language
    - pad_token: Index of the padding token; positions holding it are masked out

    Return:
    - src_mask: Boolean tensor of Bx1x1xLe, True for every non-padding source token
    - tgt_mask: Boolean tensor of Bx1xLdxLd, True where the key is a non-padding target
      token that does not lie after the query position
    """
    src_mask = (src != pad_token).unsqueeze(-2).unsqueeze(1)

    # Padding mask (Bx1xLd) combined with the causal mask (1xLdxLd) by broadcasting.
    # The causal mask is created on the CPU, so it is moved next to the padding mask.
    tgt_mask = (tgt != pad_token).unsqueeze(-2)
    causal = subsequent_mask(tgt.shape[1]).type_as(tgt_mask.data)
    tgt_mask = tgt_mask & causal

    return src_mask, tgt_mask.unsqueeze(1)

In [None]:
# Let's visualize what the target mask looks like
import seaborn
seaborn.set_context(context="talk")
import matplotlib.pyplot as plt

# Show the causal mask for a sequence of 20 tokens
plt.figure(figsize=(5,5))
causal_mask_image = subsequent_mask(20)[0].numpy()
plt.imshow(causal_mask_image)

# Shapes of the masks produced for a toy batch of two identical sequences
x = torch.arange(src_vocab_size).view(1, -1).repeat(2, 1)
src_mask, tgt_mask = create_mask(x, x)
print(src_mask.shape, tgt_mask.shape)

### Prepare Data Loaders

In [None]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import Multi30k
from typing import Iterable, List


# Translation direction: German -> English
SRC_LANGUAGE = 'de'
TGT_LANGUAGE = 'en'

# Create source and target language tokenizers. Make sure to install the dependencies.
# !pip install -U spacy
# !python -m spacy download en_core_web_sm
# !python -m spacy download de_core_news_sm
token_transform = {
    SRC_LANGUAGE: get_tokenizer('spacy', language='de_core_news_sm'),
    TGT_LANGUAGE: get_tokenizer('spacy', language='en_core_web_sm'),
}

# Place-holder, filled with one vocabulary per language further below
vocab_transform = {}


# helper function to yield list of tokens
def yield_tokens(data_iter: Iterable, language: str) -> Iterable[List[str]]:
    """
    Lazily tokenize one side of every sentence pair in `data_iter`.

    Inputs:
    - data_iter: Iterable of (source sentence, target sentence) pairs
    - language: SRC_LANGUAGE or TGT_LANGUAGE, selects which side of the pair is tokenized

    Yields one list of tokens per sentence pair. This is a generator, hence the return
    annotation is an iterable of token lists rather than a single list.
    """
    # Position of each language inside a (source, target) pair
    language_index = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}
    position = language_index[language]
    tokenize = token_transform[language]

    for data_sample in data_iter:
        yield tokenize(data_sample[position])

# Special symbols; their position in the list must equal their index so that they are
# inserted into the vocabulary in the right order.
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3

for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    # Fresh training data iterator for every language
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    # Create torchtext's Vocab object
    vocab_transform[ln] = build_vocab_from_iterator(
        yield_tokens(train_iter, ln),
        min_freq=1,
        specials=special_symbols,
        special_first=True,
    )
    # Return UNK_IDX for tokens that are not in the vocabulary. If no default index is
    # set, looking up an unknown token throws a RuntimeError.
    vocab_transform[ln].set_default_index(UNK_IDX)

In [None]:
from torch.nn.utils.rnn import pad_sequence

# helper function to chain several transforms into a single callable
def sequential_transforms(*transforms):
    """
    Return a function that feeds its input through `transforms` from left to right.

    With no transforms the returned function hands its input back unchanged.
    """
    def func(txt_input):
        result = txt_input
        for step in transforms:
            result = step(result)
        return result
    return func

# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]):
    """
    Wrap `token_ids` into a 1-D int64 tensor [BOS_IDX, *token_ids, EOS_IDX].

    The dtype is fixed explicitly: an empty `token_ids` list would otherwise become a
    float32 tensor, and the concatenated result would no longer be an integer index
    tensor.
    """
    return torch.cat((torch.tensor([BOS_IDX], dtype=torch.long), 
                      torch.tensor(token_ids, dtype=torch.long), 
                      torch.tensor([EOS_IDX], dtype=torch.long)))

# Per-language pipeline that converts a raw string into a tensor of token indices:
# tokenization -> numericalization -> add BOS/EOS and create the tensor
text_transform = {}
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    tokenize = token_transform[ln]
    numericalize = vocab_transform[ln]
    text_transform[ln] = sequential_transforms(tokenize, numericalize, tensor_transform)


# function to collate data samples into batch tensors
def collate_fn(batch):
    """
    Turn a list of (source sentence, target sentence) string pairs into two padded
    index tensors of the shape B x L, where shorter sequences are filled with PAD_IDX.
    """
    to_src = text_transform[SRC_LANGUAGE]
    to_tgt = text_transform[TGT_LANGUAGE]

    # Strip the trailing newline of every sentence before converting it
    pairs = [
        (to_src(src_sample.rstrip("\n")), to_tgt(tgt_sample.rstrip("\n")))
        for src_sample, tgt_sample in batch
    ]
    src_batch = [src_ids for src_ids, _ in pairs]
    tgt_batch = [tgt_ids for _, tgt_ids in pairs]

    # pad_sequence stacks along dimension 1 (L x B); transpose to B x L
    padded_src = pad_sequence(src_batch, padding_value=PAD_IDX)
    padded_tgt = pad_sequence(tgt_batch, padding_value=PAD_IDX)
    return padded_src.transpose(0, 1), padded_tgt.transpose(0, 1)

In [None]:
BATCH_SIZE = 8
language_pair = (SRC_LANGUAGE, TGT_LANGUAGE)

train_iter = Multi30k(split='train', language_pair=language_pair)
train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

val_iter = Multi30k(split='valid', language_pair=language_pair)
val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

# Print the shapes of the first three training batches
for idx, (src, tgt) in enumerate(train_dataloader):
    if idx >= 3:
        break
    print('src: {}, tgt: {}'.format(src.shape, tgt.shape))

### <font size='4'>Define the Model and Loss Function</font>

In [None]:
torch.manual_seed(0)

SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])
EMBED_SIZE = 512
NUM_ATTN_HEADS = 8
FF_DIM = 512
BATCH_SIZE = 128
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3

# Define the model. The rest of the notebook refers to it as `transformer`; `model`
# is bound to the same object further below.
transformer = Seq2SeqTransformer(
    num_encoder_layers=NUM_ENCODER_LAYERS,
    num_decoder_layers=NUM_DECODER_LAYERS,
    embed_dim=EMBED_SIZE,
    num_heads=NUM_ATTN_HEADS,
    src_vocab_size=SRC_VOCAB_SIZE,
    tgt_vocab_size=TGT_VOCAB_SIZE,
    trx_ff_dim=FF_DIM,
    dropout=0.1,
    pad_token=PAD_IDX
)

# Define the loss function.
# Some target positions in a batch are paddings. We don't want to penalize the model
# for its output at such positions, so `ignore_index` suppresses the loss wherever the
# ground-truth label equals PAD_IDX.
# Check here for more details https://pytorch.org/docs/stable/nn.html#loss-functions.
# The training and evaluation functions refer to the loss as `loss_fn`; `criterion`
# is kept as an alias.
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
criterion = loss_fn

# Xavier initialization for every weight matrix
for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
transformer = transformer.to(device)
model = transformer

optimizer = torch.optim.Adam(
    transformer.parameters(), 
    lr=0.0001, 
    betas=(0.9, 0.98), 
    eps=1e-9
)

### Model Training and Validation

In [None]:
#@title
def train_epoch(model, optimizer):
    """
    Train `model` for one pass over the Multi30k training split.

    Inputs:
    - model: Sequence-to-sequence model called as model(src, tgt_input, src_mask, tgt_mask)
    - optimizer: Optimizer that updates the parameters of `model`

    Return:
    - The training loss averaged over the batches (0.0 if there was no batch)
    """
    model.train()
    losses = 0
    num_batches = 0
    
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)
    
    for src, tgt in train_dataloader:
        src = src.to(device)
        tgt = tgt.to(device)

        # Teacher forcing: the decoder sees tokens [0, L-1) and predicts tokens [1, L)
        tgt_input = tgt[:, :-1]
        tgt_out = tgt[:, 1:]

        # Batches are padded with PAD_IDX, so the masks must be built with that index;
        # the default pad_token of create_mask is 0, which is the <unk> token here.
        src_mask, tgt_mask = create_mask(src, tgt_input, PAD_IDX)
        src_mask = src_mask.to(device)
        tgt_mask = tgt_mask.to(device)

        optimizer.zero_grad()

        logits = model(src, tgt_input, src_mask, tgt_mask)

        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()
        num_batches += 1

    # Batches are counted by hand because a DataLoader over an iterable dataset does
    # not necessarily support len().
    return losses / num_batches if num_batches else 0.0


def evaluate(model):
    """
    Compute the average loss of `model` on the Multi30k validation split.

    Inputs:
    - model: Sequence-to-sequence model called as model(src, tgt_input, src_mask, tgt_mask)

    Return:
    - The validation loss averaged over the batches (0.0 if there was no batch)
    """
    model.eval()
    losses = 0
    num_batches = 0

    val_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    # No gradients are needed for evaluation
    with torch.no_grad():
        for src, tgt in val_dataloader:
            src = src.to(device)
            tgt = tgt.to(device)

            # The decoder sees tokens [0, L-1) and is scored on tokens [1, L)
            tgt_input = tgt[:, :-1]
            tgt_out = tgt[:, 1:]

            # Batches are padded with PAD_IDX, so the masks must be built with that
            # index; the default pad_token of create_mask is 0, the <unk> token here.
            src_mask, tgt_mask = create_mask(src, tgt_input, PAD_IDX)
            src_mask = src_mask.to(device)
            tgt_mask = tgt_mask.to(device)

            logits = model(src, tgt_input, src_mask, tgt_mask)

            loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
            losses += loss.item()
            num_batches += 1

    # Batches are counted by hand because a DataLoader over an iterable dataset does
    # not necessarily support len().
    return losses / num_batches if num_batches else 0.0

from timeit import default_timer as timer
NUM_EPOCHS = 10

# You should be able to get train loss around 1.5 and val loss around 2.2
for epoch in range(1, NUM_EPOCHS+1):
    # Only the training pass is timed; validation runs after the clock is stopped
    start_time = timer()
    train_loss = train_epoch(transformer, optimizer)
    end_time = timer()
    val_loss = evaluate(transformer)
    summary = (f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s")
    print(summary)