### Multi Head Attention

In [4]:
import torch

In [2]:
def my_scaled_dot_product_attention(query, key=None, value=None):
    key = key if key is not None else query
    value = value if value is not None else query
    # query and key must have same embedding dimension
    assert query.size(-1) == key.size(-1)

    dk = key.size(-1) # embed dimension of key
    # query, key, value = (bs, seq_len, embed_dim)
    
    # compute dot-product to obtain pairwise "similarity" and scale it
    qk = query @ key.transpose(-1, -2) / dk**0.5
    
    # apply softmax
    # attn_weights = (bs, seq_len, seq_len)
    attn_weights = torch.softmax(qk, dim=-1)

    # compute weighted sum of value vectors
    # attn = (bs, seq_len, embed_dim)
    attn = attn_weights @ value
    return attn, attn_weights

In [5]:
X = torch.normal(mean=0, std=1, size=(2, 3, 6))
torch_attended = torch.nn.functional.scaled_dot_product_attention(X, X, X)
attended, attn_weights = my_scaled_dot_product_attention(X, X, X)
assert torch.allclose(torch_attended, attended) == True

#### Batch Multiplication

In [10]:
batch_size = 3
A = torch.randn(batch_size, 10, 256)

output = []
for batch_idx in range(batch_size):
    pairwise_dot_product = A[batch_idx] @ A[batch_idx].transpose(-1, -2)
    output.append(pairwise_dot_product)

# Output has shape (batch_size, 10, 10)
output[0].size()

torch.Size([10, 10])

### Naive Implementation:

In [11]:
class AttentionBlock(torch.nn.Module):
    def __init__(self, input_dim: int, output_dim: int, bias=False):
        super().__init__()
        # Linear layers to project Query, Key and Value 
        self.W_q = torch.nn.Linear(input_dim, output_dim, bias=bias)
        self.W_k = torch.nn.Linear(input_dim, output_dim, bias=bias)
        self.W_v = torch.nn.Linear(input_dim, output_dim, bias=bias)

    def forward(self, query, key, value):
        # project Q, K, V
        q_logits = self.W_q(query)
        k_logits = self.W_k(key)
        v_logits = self.W_v(value)

        # apply scaled dot product attention on projected values
        attn, weights = my_scaled_dot_product_attention(q_logits, k_logits, v_logits)
        return attn, weights

class MyMultiheadAttention(torch.nn.Module):
    def __init__(self, embed_dim: int, n_heads: int, projection_bias=False):
        super().__init__()
        assert embed_dim % n_heads == 0, "embed_dim must be divisible by n_heads"
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        head_embed_dim = self.embed_dim // n_heads
        # for each head, create an attention block
        self.head_blocks = torch.nn.ModuleList([AttentionBlock(input_dim=embed_dim, output_dim=head_embed_dim, bias=projection_bias) for i in range(self.n_heads)])
        # final projection of MHA
        self.projection = torch.nn.Linear(embed_dim, embed_dim, bias=projection_bias)


    def forward(self, query, key, value):
        # these lists are to store output of each head
        attns_list = []
        attn_weights_list = []

        # for every head pass the original query, key, value
        for head in self.head_blocks:
            attn, attn_weights = head(query, key, value)
            attns_list.append(attn)
            attn_weights_list.append(attn_weights)

        # concatenate attention outputs and take average of attention weights
        attns, attn_weights = torch.cat(attns_list, dim=2), torch.stack(attn_weights_list).mean(dim=0)
        # shape: (bs, seq_len, embed_dim), attn_weights: (bs, seq_len, seq_len)
        return self.projection(attns), attn_weights

### Text Classification

In [13]:
import datasets
from transformers import AutoTokenizer

original_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")


news_ds = datasets.load_dataset("SetFit/bbc-news", split="train")
# train a new tokenizer with limited vocab size for demo
tokenizer = original_tokenizer.train_new_from_iterator(news_ds['text'], vocab_size=1000)

README.md:   0%|          | 0.00/880 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


train.jsonl:   0%|          | 0.00/2.87M [00:00<?, ?B/s]

test.jsonl:   0%|          | 0.00/2.28M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1225 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [14]:
def tokenize(batch):
    return tokenizer(batch['text'], truncation=True)

ds = news_ds.map(tokenize, batched=True).select_columns(['label', 'input_ids', 'text']).train_test_split()

class_id_to_class = {
    0: "tech",
    1: "business",
    2: "sports",
    3: "entertainment",
    4: "politics",
}
num_classes = len(class_id_to_class)

Map:   0%|          | 0/1225 [00:00<?, ? examples/s]

In [15]:
class TextClassifier(torch.nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int, num_classes: int, mha: torch.nn.Module):
        super().__init__()
        self.embedding = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim, padding_idx=0)
        self.mha = mha
        self.fc1 = torch.nn.Linear(in_features=embed_dim, out_features=128)
        self.relu = torch.nn.ReLU()
        self.final = torch.nn.Linear(in_features=128, out_features=num_classes)

    def forward(self, input_ids: torch.Tensor, **kwargs):
        # inputs: (bs, seq_len)
        # embeddings: (bs, seq_len, embed_dim)
        embeddings = self.get_embeddings(input_ids)
        attn, attn_weights = self.get_attention(embeddings, embeddings, embeddings)
        
        # take the first token's embeddings i.e. embeddings of CLS token
        # cls_token_embeddings: (bs, embed_dim)
        cls_token_embeddings = attn[:, 0, :] 
        return self.final(self.relu(self.fc1(cls_token_embeddings)))
    
    def get_embeddings(self, input_ids):
        return self.embedding(input_ids)
    
    def get_attention(self, query, key, value):
        attn, attn_weights = self.mha(query, key, value)
        return attn, attn_weights

n_heads = 8
embed_dim = 64
vocab_size = tokenizer.vocab_size
torch_mha = torch.nn.MultiheadAttention(embed_dim=embed_dim, num_heads=n_heads, batch_first=True)
my_mha = MyMultiheadAttention(embed_dim=embed_dim, n_heads=n_heads, projection_bias=True)
torch_classifier = TextClassifier(vocab_size=tokenizer.vocab_size, embed_dim=embed_dim, num_classes=num_classes, mha=torch_mha)
my_classifier = TextClassifier(vocab_size=tokenizer.vocab_size, embed_dim=embed_dim, num_classes=num_classes, mha=my_mha)

In [16]:
from torch.utils.data import DataLoader
import time

def collate_fn(batch):
    labels = []
    input_ids = []
    for row in batch:
        labels.append(row['label'])
        input_ids.append(torch.LongTensor(row['input_ids']))

    input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=0)
    labels = torch.LongTensor(labels)
    input_ids = torch.Tensor(input_ids)
    return {"labels": labels, "input_ids": input_ids}

train_dl = test_dl = DataLoader(ds['train'], shuffle=True, batch_size=32, collate_fn=collate_fn)
test_dl = DataLoader(ds['test'], shuffle=False, batch_size=32, collate_fn=collate_fn)

def train(model: torch.nn.Module, train_dl, val_dl, epochs=10) -> list[tuple[float, float]]:
    optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.CrossEntropyLoss()
    losses = []
    train_start = time.time()
    for epoch in range(epochs):
        epoch_start = time.time()
        train_loss = 0.0
        model.train()
        for batch in train_dl:
            optim.zero_grad()
            logits = model(**batch)
            loss = loss_fn(logits, batch['labels'])
            loss.backward()
            optim.step()
            train_loss += loss.item() * batch['labels'].size(0)

        train_loss /= len(train_dl.dataset)

        model.eval()
        val_loss = 0.0
        val_accuracy = 0.0
        with torch.no_grad():
            for batch in val_dl:
                logits = model(**batch)
                loss = loss_fn(logits, batch['labels'])
                val_loss += loss.item() * batch['labels'].size(0)
                val_accuracy += (logits.argmax(dim=1) == batch['labels']).sum()

        val_loss /= len(val_dl.dataset)
        val_accuracy /= len(val_dl.dataset)
        log_steps = max(1, int(0.2 * epochs))

        losses.append((train_loss, val_loss))
        if epoch % log_steps == 0 or epoch == epochs - 1:
            epoch_duartion = time.time() - epoch_start
            print(f'Epoch {epoch+1}/{epochs}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}. Epoch Duration: {epoch_duartion:.1f} seconds')

    train_duration = time.time() - train_start
    print(f"Training finished. Took {train_duration:.1f} seconds")

    return losses

In [17]:
def get_model_param_count(model):
    return sum(t.numel() for t in model.parameters())

print(f"My classifier params: {get_model_param_count(my_classifier):,}")
print(f"Torch classifier params: {get_model_param_count(torch_classifier):,}")


My classifier params: 89,605
Torch classifier params: 89,605


In [18]:

torch_losses = train(torch_classifier, train_dl, test_dl, epochs=10)
my_losses = train(my_classifier, train_dl, test_dl, epochs=10)

Epoch 1/10, Training Loss: 1.5996, Validation Loss: 1.5754, Validation Accuracy: 0.3290. Epoch Duration: 8.5 seconds
Epoch 3/10, Training Loss: 0.9998, Validation Loss: 0.7223, Validation Accuracy: 0.7622. Epoch Duration: 8.2 seconds
Epoch 5/10, Training Loss: 0.4822, Validation Loss: 0.4587, Validation Accuracy: 0.8469. Epoch Duration: 8.2 seconds
Epoch 7/10, Training Loss: 0.2149, Validation Loss: 0.3461, Validation Accuracy: 0.8795. Epoch Duration: 8.2 seconds
Epoch 9/10, Training Loss: 0.0957, Validation Loss: 0.3107, Validation Accuracy: 0.9088. Epoch Duration: 8.3 seconds
Epoch 10/10, Training Loss: 0.0662, Validation Loss: 0.2853, Validation Accuracy: 0.9153. Epoch Duration: 8.3 seconds
Training finished. Took 82.4 seconds
Epoch 1/10, Training Loss: 1.6022, Validation Loss: 1.5797, Validation Accuracy: 0.2410. Epoch Duration: 11.5 seconds
Epoch 3/10, Training Loss: 0.9476, Validation Loss: 0.7273, Validation Accuracy: 0.6971. Epoch Duration: 11.6 seconds
Epoch 5/10, Training Los

In [22]:
import toolz
import pandas as pd

def predict(texts, model, bs=32):
    output_dfs = []
    for batch in toolz.partition_all(bs, texts):
        inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            class_probs = torch.softmax(model(**inputs), dim=1).numpy()
            pred_classes = class_probs.argmax(axis=1)
            col_names = [f"class_{i}_prob" for i in range(class_probs.shape[-1])]
            df = pd.DataFrame(class_probs, columns=col_names)
            df['pred_class'] = pred_classes
            df['pred_class_name'] = df['pred_class'].map(class_id_to_class)
            output_dfs.append(df)

    return pd.concat(output_dfs)

my_preds_df = predict(ds['test']['text'], my_classifier)
my_preds_df['model'] = 'My Model'
my_preds_df['actual_class'] = ds['test']['label']
torch_preds_df = predict(ds['test']['text'], torch_classifier)
torch_preds_df['model'] = 'Torch Model'
torch_preds_df['actual_class'] = ds['test']['label']

from sklearn.metrics import classification_report

print("My Classifier")
print(classification_report(my_preds_df['actual_class'], my_preds_df['pred_class']))

print("Torch Classifier")
print(classification_report(torch_preds_df['actual_class'], torch_preds_df['pred_class']))

My Classifier
              precision    recall  f1-score   support

           0       0.87      0.73      0.79        55
           1       0.88      0.87      0.88        70
           2       0.93      0.99      0.96        71
           3       0.87      0.94      0.90        49
           4       0.84      0.87      0.86        62

    accuracy                           0.88       307
   macro avg       0.88      0.88      0.88       307
weighted avg       0.88      0.88      0.88       307

Torch Classifier
              precision    recall  f1-score   support

           0       0.87      0.95      0.90        55
           1       0.93      0.94      0.94        70
           2       0.95      0.97      0.96        71
           3       0.91      0.84      0.87        49
           4       0.91      0.85      0.88        62

    accuracy                           0.92       307
   macro avg       0.91      0.91      0.91       307
weighted avg       0.92      0.92      0.91   

In [24]:
class MyEfficientMultiHeadAttention(torch.nn.Module):
    def __init__(self, embed_dim: int, n_heads: int, projection_bias=False):
        super().__init__()
        assert embed_dim % n_heads == 0, "embed_dim must be divisible by n_heads"
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.head_embed_dim = self.embed_dim // n_heads
        self.W_q = torch.nn.Linear(embed_dim, embed_dim, bias=projection_bias)
        self.W_k = torch.nn.Linear(embed_dim, embed_dim, bias=projection_bias)
        self.W_v = torch.nn.Linear(embed_dim, embed_dim, bias=projection_bias)
        self.projection = torch.nn.Linear(embed_dim, embed_dim, bias=projection_bias)

    def forward(self, query, key, value):
        # shape of query = (bs, seq_len, embed_dim)
        batch_size = query.size(0)

        # linear projection of query, key and value
        q = self.W_q(query)
        k = self.W_k(key)
        v = self.W_v(value)

        # reshape the projected query, key, value
        # to (bs, n_heads, seq_len, head_embed_dim)
        q = self.split_heads(q)
        k = self.split_heads(k)
        v = self.split_heads(v)

        # do scaled dot product attention
        # attn.shape = (bs, n_heads, seq_len, head_embed_dim)
        # attn_weights.shape (bs, n_heads, seq_len, seq_len)
        attn, attn_weights = my_scaled_dot_product_attention(q, k, v)
        # swap the n_heads and seq_len so that we have
        # (bs, seq_len, n_heads, head_embed_dim)
        # call .contiguous() so that view function will work later
        attn = attn.transpose(1, 2).contiguous()
        # "combine" (n_heads, head_embed_dim) matrix as a single "embed_dim" vector
        attn = attn.view(batch_size, -1, self.embed_dim)

        output = self.projection(attn)
        return output, attn_weights.mean(dim=1)

    def split_heads(self, x):
        # x.shape = (bs, seq_len, embed_dim)
        batch_size = x.size(0)
        # first split the embed_dim into (n_heads, head_embed_dim)
        temp =  x.view(batch_size, -1, self.n_heads, self.head_embed_dim)
        # now we swap seq_len and n_heads dimension
         # output shape = (bs, n_heads, seq_len, head_embed_dim)
        return temp.transpose(1, 2)
class MyEfficientMultiHeadAttention(torch.nn.Module):
    def __init__(self, embed_dim: int, n_heads: int, projection_bias=False):
        super().__init__()
        assert embed_dim % n_heads == 0, "embed_dim must be divisible by n_heads"
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.head_embed_dim = self.embed_dim // n_heads
        self.W_q = torch.nn.Linear(embed_dim, embed_dim, bias=projection_bias)
        self.W_k = torch.nn.Linear(embed_dim, embed_dim, bias=projection_bias)
        self.W_v = torch.nn.Linear(embed_dim, embed_dim, bias=projection_bias)
        self.projection = torch.nn.Linear(embed_dim, embed_dim, bias=projection_bias)

    def forward(self, query, key, value):
        # shape of query = (bs, seq_len, embed_dim)
        batch_size = query.size(0)

        # linear projection of query, key and value
        q = self.W_q(query)
        k = self.W_k(key)
        v = self.W_v(value)

        # reshape the projected query, key, value
        # to (bs, n_heads, seq_len, head_embed_dim)
        q = self.split_heads(q)
        k = self.split_heads(k)
        v = self.split_heads(v)

        # do scaled dot product attention
        # attn.shape = (bs, n_heads, seq_len, head_embed_dim)
        # attn_weights.shape (bs, n_heads, seq_len, seq_len)
        attn, attn_weights = my_scaled_dot_product_attention(q, k, v)
        # swap the n_heads and seq_len so that we have
        # (bs, seq_len, n_heads, head_embed_dim)
        # call .contiguous() so that view function will work later
        attn = attn.transpose(1, 2).contiguous()
        # "combine" (n_heads, head_embed_dim) matrix as a single "embed_dim" vector
        attn = attn.view(batch_size, -1, self.embed_dim)

        output = self.projection(attn)
        return output, attn_weights.mean(dim=1)

    def split_heads(self, x):
        # x.shape = (bs, seq_len, embed_dim)
        batch_size = x.size(0)
        # first split the embed_dim into (n_heads, head_embed_dim)
        temp =  x.view(batch_size, -1, self.n_heads, self.head_embed_dim)
        # now we swap seq_len and n_heads dimension
         # output shape = (bs, n_heads, seq_len, head_embed_dim)
        return temp.transpose(1, 2)


In [26]:
my_efficient_mha = MyEfficientMultiHeadAttention(embed_dim=embed_dim, n_heads=n_heads, projection_bias=True)
my_efficient_classifier = TextClassifier(vocab_size=tokenizer.vocab_size, embed_dim=embed_dim, num_classes=num_classes, mha=my_efficient_mha)
my_efficient_losses = train(my_efficient_classifier, train_dl, test_dl, epochs=10)

Epoch 1/10, Training Loss: 1.6054, Validation Loss: 1.5861, Validation Accuracy: 0.4072. Epoch Duration: 9.8 seconds
Epoch 3/10, Training Loss: 1.1183, Validation Loss: 0.7892, Validation Accuracy: 0.7068. Epoch Duration: 9.6 seconds
Epoch 5/10, Training Loss: 0.4770, Validation Loss: 0.4058, Validation Accuracy: 0.8371. Epoch Duration: 9.6 seconds
Epoch 7/10, Training Loss: 0.2780, Validation Loss: 0.4079, Validation Accuracy: 0.8339. Epoch Duration: 9.6 seconds
Epoch 9/10, Training Loss: 0.2086, Validation Loss: 0.3294, Validation Accuracy: 0.8730. Epoch Duration: 9.6 seconds
Epoch 10/10, Training Loss: 0.1688, Validation Loss: 0.3112, Validation Accuracy: 0.8958. Epoch Duration: 9.6 seconds
Training finished. Took 96.5 seconds
