# Coding Test 

You will be assessed overall on:

    1) How far you get in the alloted time.
    2) Code optimisations.
    3) Code reusability.
    4) Code readability.



# Transformers

## 1. Coding Test
Below you will find some code which is a classical example of the use of a transformer model.
Your task is to complete the code such that the training runs and the validation loss decreases.

Based on [Attention is all you need, Vaswani et al. 2017](https://arxiv.org/abs/1706.03762), write the code
1. for the multi-head self-attention mechanism
2. for the forward pass of the layer
3. for generating causal masks.

Finally, run the training and validation with the provided functions and dataset.

Note: `pip install portalocker`

In [39]:
import os
import time
from copy import deepcopy

import math
from typing import Tuple, Optional
import torch
from torch import Tensor, nn
from torch.utils.data import dataset
from torchtext.datasets import PennTreebank
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

In [50]:
def attention(query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None, dropout = None) -> Tensor:
    """ Computes the scaled dot-product attention over the provided query, key, and value tensors.

    :param query: Tensor of shape (batch_size, n_heads, seq_len, dim), queries for attention.
    :param key: Tensor of shape (batch_size, n_heads, seq_len, dim), keys for attention.
    :param value: Tensor of shape (batch_size, n_heads, seq_len, dim), values for attention.
    :param mask: Optional tensor for masking irrelevant parts of the input sequence.
    :return: Tensor of shape (batch_size, n_heads, seq_len, dim), the weighted sum of the value using the attention weights.     
    """
    # weights via scaled dot product attention
    d_k = query.size(-1)
    weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        
        weights = weights.masked_fill(mask == 0, float('-inf')) # shape: (batch_size, n_heads, seq_len, seq_len)
    normalized_weights = weights.softmax(dim=-1)

    # apply dropout
    if dropout is not None:
        normalized_weights = dropout(normalized_weights)

    return torch.matmul(normalized_weights, value)

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, dim: int, n_heads: int, dropout: Optional[float] = None):
        """ Initializes the MultiHeadSelfAttention layer. 

        :param dim: Total dimensionality of the input and output of the self-attention layer.
        :param n_heads: Number of attention heads. `dim` should be divisible by `n_heads` to equally
                        split the dimensionality across all heads.
        :param dropout: Optional dropout rate for attention weights; defaults to None if not provided.
        """
        super(MultiHeadSelfAttention, self).__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        assert n_heads * self.head_dim == dim, f"embedding dim={dim} not divisible by n_heads={n_heads}."
        self.dropout = nn.Dropout(p=dropout)
        self.linear_query = nn.Linear(dim, dim)
        self.linear_key = nn.Linear(dim, dim)
        self.linear_value = nn.Linear(dim, dim)
        self.linear_cat_attn = nn.Linear(dim, dim)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """ The forward method for MultiHeadSelfAttention.
        
        :param query: Tensor of shape (batch_size, seq_len, dim), queries for attention.
        :param key: Tensor of shape (batch_size, seq_len, dim), keys for attention.
        :param value: Tensor of shape (batch_size, seq_len, dim), values for attention.
        :param mask: Optional tensor for masking irrelevant parts of the input sequence.
        :return: Tensor of shape (batch_size, seq_len, dim), the output of the attention layer.
        """
        # linear projection
        query = self.linear_query(query)
        key = self.linear_key(key)
        value = self.linear_value(value)

        # split the heads
        query = query.view(query.shape[0], -1, self.n_heads, self.head_dim).transpose(1, 2)
        key = key.view(key.shape[0], -1, self.n_heads, self.head_dim).transpose(1, 2)
        value = value.view(value.shape[0], -1, self.n_heads, self.head_dim).transpose(1, 2)
        
        # attention output: (batch_size, n_heads, seq_len, dim)
        attn_output =  attention(query, key, value, mask, self.dropout)

        # concatenate heads
        attn_output = attn_output.transpose(1, 2).contiguous().view(query.size(0), -1, self.dim)

        return self.linear_cat_attn(attn_output)


In [51]:
class EncoderBlock(nn.Module):
    def __init__(self, attn: MultiHeadSelfAttention, d_model: int, dim_feedforward: int, layer_norm_eps=1e-5, dropout=0.1):
        super(EncoderBlock, self).__init__()
        self.attn = attn
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """ The encoder block, generally, passes the input through a multi-head self-attention layer followed by feedforward networks to transform the input.
    

        :param x: Tensor of shape (batch_size, seq_len, d_model), the input to the encoder block.
        :param mask: Optional tensor of shape (batch_size, seq_len, seq_len), a mask tensor for the
                     self-attention layer to ignore certain positions within the input sequence.
        :return: Tensor of shape (batch_size, seq_len, d_model), the output of the encoder block.
        """
        # a multi-head self-attention layer -> residual connection -> normalization
        x = self.norm1(x + self._self_attn(x, mask))

        # feedforward network -> residual connection -> normalization
        x = self.norm2(x + self._feed_forward(x))
        return x

    def _self_attn(self, x: Tensor, mask: Optional[Tensor]) -> Tensor:
        x = self.attn(x, x, x, mask)
        return self.dropout1(x)

    def _feed_forward(self, x: Tensor) -> Tensor:
        x = self.dropout(self.activation(self.linear1(x)))
        return self.dropout2(self.linear2(x))

In [52]:
class TransformerEncoder(nn.Module):
    def __init__(self, encoder: EncoderBlock, n_blocks: int):
        super(TransformerEncoder, self).__init__()
        self.encoder_blocks = nn.ModuleList([deepcopy(encoder) for _ in range(n_blocks)])

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        for encoder in self.encoder_blocks:
            x = encoder(x, mask)
        return x

In [53]:
class PositionalEncodingTorch(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
    
class TransformerModelManualAttn(nn.Module):

    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
                 nlayers: int, dropout: float = 0.5):
        super().__init__()
        self.pos_encoder = PositionalEncodingTorch(d_model, dropout)
        encoder_layers = EncoderBlock(attn=MultiHeadSelfAttention(dim=d_model, n_heads=nhead, dropout=dropout),
                                      d_model=d_model, dim_feedforward=d_hid, dropout=dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.embedding = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.linear = nn.Linear(d_model, ntoken)

        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
        """
        Move the input data through the transformer model. The process involves
        embedding the input, applying positional encoding, passing the result through
        a transformer encoder, and finally projecting the output to the vocabulary space.

        :param src: Tensor of shape (batch_size, seq_len), containing the indices of input tokens.
        :param src_mask: Optional tensor used for masking in self-attention layers.
        :return: Tensor of shape (batch_size, seq_len, ntoken), containing the logits of predicted token probabilities.
        """
        src = self.embedding(src)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        output = self.linear(output)
        return output

## Data Processing

In [42]:
# Load data, load tokenizer, build vocabulary
train_iter, val_iter, test_iter = PennTreebank()
tokenizer = get_tokenizer('basic_english')
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>'])
vocab.set_default_index(vocab['<unk>'])

# get train, validation, and test data
train_data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in train_iter] # shape (num_examples, num_tokens)
val_data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in val_iter]
test_data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in test_iter] 

# concat into one long sequence for training, validation, and testing
train_data = torch.cat(tuple(filter(lambda t: t.numel() > 0, train_data))) # shape (num_tokens_in_text1+num_tokens_in_text2+num_tokens_in_text3+...,)
val_data = torch.cat(tuple(filter(lambda t: t.numel() > 0, val_data)))
test_data = torch.cat(tuple(filter(lambda t: t.numel() > 0, test_data)))

# print the lengths of the training, validation, and test data
print(f"Number of tokens in training data: {train_data.size(0)}")
print(f"Number of tokens in validation data: {val_data.size(0)}")
print(f"Number of tokens in test data: {test_data.size(0)}")

# Number of tokens in training data: 924412
# Number of tokens in validation data: 73339
# Number of tokens in test data: 82114

Number of tokens in training data: 924412
Number of tokens in validation data: 73339
Number of tokens in test data: 82114


In [43]:

# batchify the data
batch_size = 20
eval_batch_size = 10

# sequence lengths
train_seq_len = train_data.shape[0] // batch_size
val_seq_len = val_data.shape[0] // eval_batch_size
test_seq_len = test_data.shape[0] // eval_batch_size

# trim the data to fit the batch size
train_data = train_data[:train_seq_len * batch_size]
val_data = val_data[:val_seq_len * eval_batch_size]
test_data = test_data[:test_seq_len * eval_batch_size]

# reshape the data where each column is a sequence
train_data = train_data.view(batch_size, train_seq_len).t().contiguous()
val_data = val_data.view(eval_batch_size, val_seq_len).t().contiguous()
test_data = test_data.view(eval_batch_size, test_seq_len).t().contiguous()

print("Training sequence length after trimming and batching:", train_data.size(0))
print("Validation sequence length after trimming and batching:", val_data.size(0))
print("Test sequence length after trimming and batching:", test_data.size(0))
# Training sequence length after trimming and batching: 46220
# Validation sequence length after trimming and batching: 7333
# Test sequence length after trimming and batching: 8211


# move to the device
device = "mps" # torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
train_data.to(device)
val_data.to(device)
test_data.to(device)

# ======================  My local Python environment does not have the required packages to run this code. Due to time constraint, 
# ======================= I cannot re-set a new one to run this code. I am sorry for the inconvenience for reviewing my code. =======================


Training sequence length after trimming and batching: 46220
Validation sequence length after trimming and batching: 7333
Test sequence length after trimming and batching: 8211


tensor([[  99,  244,  291,  ...,    2,  933,    5],
        [  15,    4,   71,  ...,    2,  428,   61],
        [  25,   44,  124,  ...,    1, 3625,   48],
        ...,
        [ 188, 5267,    0,  ...,    0,    0,   95],
        [   5,    4,    3,  ...,    7, 1230,   65],
        [ 300,  434,    1,  ...,   34,   12,  380]], device='mps:0')

## Modeling

In [58]:

# Model Parameters
ntokens = len(vocab)
emsize = 200
d_hid = 200
nlayers = 2
nhead = 2
dropout = 0.1
use_causal_mask = True

# Create Model
model = TransformerModelManualAttn(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)



## Training --- One Iteration

#### Preparing Model Inputs

In [45]:
window_size = 35
start_seq_pos = 0
batch_win1 = train_data[start_seq_pos:start_seq_pos + window_size] 
print(batch_win1.shape) # shape: (window_size, batch_size)
# torch.Size([35, 20])

torch.Size([35, 20])


In [46]:
target_win1 = train_data[start_seq_pos + 1:start_seq_pos + window_size + 1] # shape (window_size, batch_size)
print(target_win1.shape)  # shape: (window_size, batch_size)
# torch.Size([35, 20])

torch.Size([35, 20])


In [47]:
# mask
mask = torch.tril(torch.ones((1, window_size, window_size)))

print(mask.shape) # shape: (1, window_size, window_size)

torch.Size([1, 35, 35])


In [56]:
batch_win1.shape

torch.Size([35, 20])

In [57]:
batch_win1 = batch_win1.to(device)
mask = mask.to(device) 

output = model(src=batch_win1.T, src_mask=mask)

Weights:  torch.Size([20, 2, 35, 35])
Key:  torch.Size([20, 2, 35, 100])
Weights:  torch.Size([20, 2, 35, 35])
Key:  torch.Size([20, 2, 35, 100])


In [61]:
output

tensor([[[ 6.4839e-01, -1.9069e-01, -6.7224e-01,  ...,  1.4212e+00,
           6.8141e-02, -3.2933e-01],
         [-8.3399e-01, -5.4409e-01, -1.0539e+00,  ...,  1.3813e+00,
          -2.3350e-01, -2.6184e-01],
         [ 1.2623e-01, -1.8705e-01, -3.3221e-01,  ...,  1.3431e+00,
          -7.6831e-01, -3.7534e-01],
         ...,
         [-4.2288e-01, -3.4111e-01, -8.1829e-01,  ...,  1.6811e+00,
          -2.0424e-01, -6.1760e-01],
         [-1.2379e-01, -4.9680e-01, -6.8506e-01,  ...,  1.7153e+00,
          -3.4596e-01, -9.7069e-01],
         [ 1.6114e-01, -4.3917e-01, -5.6082e-01,  ...,  2.0197e+00,
          -6.7841e-01, -4.9689e-01]],

        [[-1.1922e+00, -1.1030e-01, -1.6944e-01,  ...,  3.0151e-01,
           2.7070e-01, -6.6524e-01],
         [-1.1228e+00, -4.6843e-01, -4.3408e-02,  ...,  8.9467e-01,
          -9.1702e-01,  6.2178e-03],
         [-1.5348e+00, -3.3282e-01, -1.4233e-01,  ...,  6.1153e-01,
          -6.1223e-01, -4.2305e-01],
         ...,
         [-5.6631e-01,  4

#### Objective Function

In [63]:
output.view(-1, ntokens).shape

torch.Size([700, 9922])

In [66]:
criterion = nn.CrossEntropyLoss()
loss = criterion(output.view(-1, ntokens), target_win1.to(device).reshape(-1))


In [67]:
lr = 5.0  # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr, )
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optimizer.step()

## Full Training Steps

#### Train & Eval Function

In [None]:
def get_batch(source: Tensor, i: int, window_size: int) -> Tuple[Tensor, Tensor]:
    seq_len = min(window_size, len(source) - 1 - i) # ensure that each token in the batch has a next token
    data = source[i:i + seq_len]
    target = source[i + 1:i + seq_len + 1].reshape(-1)
    return data, target

def train(
        model,
        train_data: Tensor,
        window_size: int,
        criterion,
        ntokens: int,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler,
        epoch: int = 0,
        device: torch.device = None,
        use_causal_mask: bool = True
) -> None:
    model.train()
    total_loss = 0.
    log_interval = 200
    start_time = time.time()
    src_mask = torch.tril(torch.ones((1, window_size, window_size))).to(device) if use_causal_mask else None
    num_batches = len(train_data) // window_size
    for batch, i in enumerate(range(0, train_data.shape[0] - 1, window_size)):  # [0, 35, 70, 105, ...]
        data, targets = get_batch(train_data, i, window_size=window_size)
        if data.shape[0] < window_size and src_mask is not None:
            src_mask = torch.tril(torch.ones((1, data.shape[0], data.shape[0]))).to(device)
        output = model(src=data, src_mask=src_mask)
        loss = criterion(output.view(-1, ntokens), targets)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            ppl = math.exp(cur_loss)
            print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '
                  f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '
                  f'loss {cur_loss:5.2f}'
                  f' | ppl {ppl:8.2f}'
                  )
            total_loss = 0
            start_time = time.time()


def evaluate(
        model,
        eval_data: Tensor,
        window_size: int,
        ntokens: int,
        criterion,
        device: torch.device = None,
        use_causal_mask: bool = True,
) -> float:
    model.eval()
    total_loss = 0.
    src_mask = torch.tril(torch.ones((1, window_size, window_size))).to(device) if use_causal_mask else None
    with torch.no_grad():
        for i in range(0, eval_data.shape[0] - 1, window_size):
            data, targets = get_batch(eval_data, i, window_size=window_size)
            if data.shape[0] < window_size and src_mask is not None:
                src_mask = torch.tril(torch.ones((1, data.shape[0], data.shape[0]))).to(device)
            output = model(data, src_mask=src_mask)
            output_flat = output.view(-1, ntokens)
            total_loss += data.shape[0] * criterion(output_flat, targets).item()
    return total_loss / (len(eval_data) - 1)

bptt = 35

#### Run model

In [None]:
criterion = nn.CrossEntropyLoss()
lr = 5.0  # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr, )
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

best_val_loss = float('inf')
epochs = 10

# Train
home = os.path.join(os.path.expanduser("~"), "transformer_test")
save_dir = os.path.join(os.path.join(home, 'pytorch_example')) # feel free to change path
os.makedirs(save_dir, exist_ok=True)
best_model_params_path = os.path.join(save_dir, "best_model_params.pt")

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(model, train_data, bptt, criterion, ntokens, optimizer, scheduler, epoch, device, use_causal_mask)
    val_loss = evaluate(model, val_data, bptt, ntokens, criterion, device, use_causal_mask)
    val_ppl = math.exp(val_loss)
    elapsed = time.time() - epoch_start_time
    print('-' * 89)
    print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
          f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_model_params_path)

    scheduler.step()
model.load_state_dict(torch.load(best_model_params_path))  # load best model states

# Test
test_loss = evaluate(model, test_data, bptt, ntokens, criterion, device)
test_ppl = math.exp(test_loss)
print('=' * 89)
print(f'| End of training | test loss {test_loss:5.2f} | '
      f'test ppl {test_ppl:8.2f}')
print('=' * 89)

## 2. General knowledge and LLMs

**Describe the attention complexity in memory and computation costs. Do you know methods to try to reduce this cost?**

Computational Cost is in the complexity of $O(N^2 \dot D)$ if each query/key is of dimension $D$ and there are $N$ tokens
* The dot product operation for a single pair of query and key is $O(D)$.
* This dot product is computed for each pair of tokens in a sequence of length $N$.

The memory cost is in the complexity of $O(N^2 \dot D)$ because the attention matrix (N by N) needs to be stored during the forward pass to calculate gradients during backpropagation

Methods to reduce the cost:
* Sparse Attention 
* Low-Rank Approximations (LoRA)

**Why use a causal mask in attention? Are there any other interesting masks or patterns that can be used for training or for generation?**
* At the training timen, a causal mask for auto-regressive language modeling can prevent the model from "seeing the future tokens.
* Padding Mask is commonly used to ignore padding tokens in input sequences, ensuring that these don't affect the computation of attention scores. This is commonly used for many NLP tasks, e.g., auto-encoding language modeling and text classification via batching. If batched generation is required, it can be used for generative tasks.

