In [None]:
!pip install torch pandas numpy matplotlib scikit-learn seaborn

## `TAGAT (modular transformer)`

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from datasets import load_dataset
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

#config class
class Config:
    """Hyperparameters for the gated modular (mixture-of-experts) transformer.

    All values are class attributes so they can be read from either the class
    or an instance.
    """
    vocab_size = 30522   # size of the token embedding table
    d_model = 128        # embedding / hidden width
    num_heads = 4        # attention heads inside each expert
    num_experts = 4      # experts per gated layer
    top_k = 2            # experts selected per token by the gate
    num_layers = 1       # number of gated modular layers
    max_seq_len = 64     # tokenizer pad/truncate length and positional table size
    batch_size = 32
    num_classes = 3  # 3 classes for MultiNLI
    lr = 1e-4
    lambda_entropy = 1 # 1/1 for specialized, 2/3 for balanced
    lambda_balance = 1   # weight of the load-balancing loss term
    # Training/eval code and the script entry point read `config.device`,
    # so it has to be defined here rather than left commented out.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

config = Config()


#positional encoding
class Positional_encoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of embeddings.

    Args:
        d_model: width of the embeddings the encoding is added to.
        max_len: number of positions precomputed in the table.
    """

    def __init__(self, d_model, max_len = 512):
        super().__init__()
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)  # even indices
        # For odd d_model there is one fewer odd column than even columns,
        # so trim the angle table to match.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])  # odd indices

        # Registered as a buffer so it follows the module across .to(device);
        # non-persistent so the state_dict keys stay unchanged.
        self.register_buffer('pe', pe.unsqueeze(0), persistent=False)

    # NOTE: this must be a method of the class (not nested inside __init__),
    # otherwise nn.Module has no forward to dispatch to.
    def forward(self, x):
        """Return ``x`` plus the encoding for its first ``x.size(1)`` positions."""
        return x + self.pe[:, :x.size(1)].to(x.device)

#expert submodules
class Expert(nn.Module):
    """One expert: a post-norm self-attention + feed-forward encoder block.

    Args:
        d_model: feature width of the input and output.
        num_heads: number of attention heads (must divide ``d_model``).
        dropout: dropout probability used inside the feed-forward network.
    """

    def __init__(self, d_model, num_heads, dropout=0.2):
        super().__init__()

        #multi head self attention
        self.attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

        #feed forward network: expand to 4 * d_model, then project back
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout),
        )

        #layer normalization + residuals
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Apply self-attention and the FFN, each followed by residual + LayerNorm."""
        # x is passed as query, key and value; the attention weights are discarded
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + attn_out)
        ff_out = self.ff(x)
        return self.norm2(x + ff_out)

#film modulation
class MetaFilm(nn.Module):
    """Produce FiLM parameters (gamma, beta) from the input features.

    Args:
        d_model: feature width of the input; gamma and beta each have this width.
    """

    def __init__(self, d_model):
        super().__init__()
        # a single projection emits gamma and beta concatenated on the last axis
        self.linear = nn.Linear(d_model, d_model * 2)

    # NOTE: defined at class level (not nested inside __init__) so that
    # calling the module actually dispatches here.
    def forward(self, x):
        """Return ``(gamma, beta)``, the two equal halves of the projection of ``x``."""
        gamma_beta = self.linear(x)
        return gamma_beta.chunk(2, dim = -1)

#gated modular layer
class GatedModularLayer(nn.Module):
    """Top-k mixture-of-experts layer with per-expert FiLM modulation.

    A linear gate scores every expert for every token, the ``top_k`` highest
    scoring experts are kept and their scores renormalised with a softmax.
    Each selected expert's output is modulated as ``gamma * out + beta`` and
    mixed into the token's output with its routing weight.

    Args:
        config: object providing ``d_model``, ``num_heads``, ``num_experts``
            and ``top_k``.
    """

    def __init__(self, config):
        super().__init__()
        self.d_model = config.d_model
        self.num_experts = config.num_experts
        self.top_k = config.top_k

        # one expert and one FiLM generator per expert slot
        self.experts = nn.ModuleList([Expert(config.d_model, config.num_heads) for _ in range(config.num_experts)])
        self.meta_nets = nn.ModuleList([MetaFilm(config.d_model) for _ in range(config.num_experts)])
        # per-token score for each expert
        self.gate = nn.Linear(config.d_model, config.num_experts)
        # normalises residual + mixed expert output
        self.norm = nn.LayerNorm(config.d_model)

    def forward(self, x, return_gate_probs=False):
        """Route tokens through their top-k experts.

        Args:
            x: tensor of shape (B, T, D).
            return_gate_probs: when true, also return the dense (B, T, N)
                routing weights (zero for experts that were not selected).

        Returns:
            ``(output, entropy, balance)`` or, with ``return_gate_probs``,
            ``(output, entropy, balance, full_probs)`` where

            * ``output`` is (B, T, D),
            * ``entropy`` is the summed entropy of the top-k routing weights
              divided by ``B * T``,
            * ``balance`` is (N,): the mean routing weight each expert received
              per token. It stays attached to the autograd graph so a loss
              built on it can train the gate.
        """
        B, T, _ = x.shape
        gate_logits = self.gate(x)  # (B, T, N)
        topk_vals, topk_idx = torch.topk(gate_logits, self.top_k, dim=-1)  # both (B, T, top_k)
        topk_probs = F.softmax(topk_vals, dim=-1)  # routing weights over the selected experts

        # Dense routing weights: topk indices are distinct per token, so a plain
        # out-of-place scatter is enough and keeps gradients flowing to the gate.
        full_probs = topk_probs.new_zeros(B, T, self.num_experts).scatter(-1, topk_idx, topk_probs)

        # 1e-8 guards against log(0)
        entropy = (-topk_probs * torch.log(topk_probs + 1e-8)).sum()

        output = torch.zeros_like(x)
        for i, (expert, meta_net) in enumerate(zip(self.experts, self.meta_nets)):
            mask = (topk_idx == i).any(-1)  # (B, T): tokens that selected expert i
            if not mask.any():
                continue

            # the expert only sees the tokens routed to it; the rest are zeroed
            x_mask = x.masked_fill(~mask.unsqueeze(-1), 0.0)
            expert_out = expert(x_mask)
            # FiLM parameters are computed from the unmasked input
            gamma, beta = meta_net(x)
            modulated = gamma * expert_out + beta

            # routing weight is zero wherever expert i was not selected, so
            # unselected tokens contribute nothing to the mix
            output = output + full_probs[..., i:i + 1] * modulated

        output = self.norm(x + output)
        balance = full_probs.sum(dim=(0, 1)) / (B * T)

        if return_gate_probs:
            return output, entropy / (B * T), balance, full_probs

        return output, entropy / (B * T), balance


#full transformer
class GatedModularTransformer(nn.Module):
    """Sequence classifier: embedding -> positional encoding -> gated layers -> linear head.

    Args:
        config: object providing ``vocab_size``, ``d_model``, ``max_seq_len``,
            ``num_layers`` and ``num_classes`` (plus whatever the gated layers read).
    """

    def __init__(self, config):
        super().__init__()
        # tokenizer is kept on the model so callers can map ids back to tokens
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        # token id -> dense vector lookup
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_enc = Positional_encoding(config.d_model, config.max_seq_len)
        # stack of mixture-of-experts layers
        self.layers = nn.ModuleList(
            [GatedModularLayer(config) for _ in range(config.num_layers)]
        )
        # classification head
        self.fc_out = nn.Linear(config.d_model, config.num_classes)

    def forward(self, input_ids, return_gate_probs=False):
        """Classify a batch of token id sequences.

        Returns:
            ``(logits, total_entropy, balances)`` where ``total_entropy`` is the
            sum of the per-layer entropies and ``balances`` is a list with one
            balance tensor per layer. With ``return_gate_probs`` a fourth item,
            the gate probabilities of the first layer, is appended.
        """
        hidden = self.pos_enc(self.embed(input_ids))

        entropy_sum = 0
        per_layer_balance = []
        first_layer_probs = None

        for depth, block in enumerate(self.layers):
            # only the first layer's routing is exposed for inspection
            if return_gate_probs and depth == 0:
                hidden, layer_entropy, layer_balance, first_layer_probs = block(
                    hidden, return_gate_probs=True
                )
            else:
                hidden, layer_entropy, layer_balance = block(hidden)
            entropy_sum += layer_entropy
            per_layer_balance.append(layer_balance)

        # the representation at sequence position 0 feeds the classifier
        logits = self.fc_out(hidden[:, 0])
        if not return_gate_probs:
            return logits, entropy_sum, per_layer_balance
        return logits, entropy_sum, per_layer_balance, first_layer_probs

def get_dataloaders(config, train_size=1000, test_size=300):
    """Build train/test DataLoaders over a small subset of GLUE MNLI.

    Args:
        config: object providing ``max_seq_len`` and ``batch_size``.
        train_size: number of leading training examples to keep.
        test_size: number of leading ``validation_matched`` examples to keep.

    Returns:
        ``(train_loader, test_loader)``; each batch is a dict with
        ``'input_ids'`` and ``'label'`` tensors.
    """
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    dataset = load_dataset("glue", "mnli")

    # MNLI has no plain 'validation' split; the matched validation split is
    # used for evaluation.
    train_split = dataset['train']
    test_split = dataset['validation_matched']

    #subsets --> sizes are clamped so an oversized request cannot index past the split
    train_dataset = train_split.select(range(min(train_size, len(train_split))))
    test_dataset = test_split.select(range(min(test_size, len(test_split))))

    def preprocess(example):
        # MNLI examples are (premise, hypothesis) pairs, encoded as one sequence pair
        tokens = tokenizer(example['premise'], example['hypothesis'],
                           truncation=True, padding='max_length', max_length=config.max_seq_len)
        return {'input_ids': tokens['input_ids'], 'label': example['label']}

    train_dataset = train_dataset.map(preprocess)
    test_dataset = test_dataset.map(preprocess)

    train_dataset.set_format(type='torch', columns=['input_ids', 'label'])
    test_dataset.set_format(type='torch', columns=['input_ids', 'label'])

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size)
    return train_loader, test_loader

#train
def train_model(model, train_loader, config, epochs=8):
    """Train ``model`` with cross-entropy plus entropy and load-balance terms.

    The loss is ``CE + lambda_entropy * entropy + lambda_balance * balance_loss``
    where ``balance_loss`` is the squared deviation of the layer-averaged
    expert load from the uniform load ``1 / num_experts``. Because the entropy
    term is added with a positive weight, minimising the loss pushes routing
    entropy down.

    Args:
        model: callable returning ``(logits, entropy, balances)`` for a batch
            of input ids, with ``balances`` a list of per-layer tensors.
        train_loader: iterable of dicts with ``'input_ids'`` and ``'label'``.
        config: object providing ``lr``, ``num_experts``, ``lambda_entropy``,
            ``lambda_balance`` and optionally ``device``.
        epochs: number of passes over ``train_loader``.
    """
    model.train()
    # Prefer the configured device; otherwise use wherever the model lives.
    device = getattr(config, 'device', None)
    if device is None:
        device = next(model.parameters()).device

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    ce_loss = nn.CrossEntropyLoss()
    uniform_load = 1 / config.num_experts

    for epoch in range(epochs):
        total_loss, total_acc, num_batches = 0.0, 0.0, 0
        for batch in train_loader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['label'].to(device)

            logits, entropy, balance = model(input_ids)
            loss_cls = ce_loss(logits, labels)
            avg_balance = torch.stack(balance).mean(dim=0)
            loss_balance = ((avg_balance - uniform_load) ** 2).sum()
            loss = loss_cls + config.lambda_entropy * entropy + config.lambda_balance * loss_balance

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # No per-step torch.cuda.empty_cache(): it only hands cached blocks
            # back to the driver and slows the loop down.

            total_loss += loss.item()
            total_acc += (logits.argmax(dim=1) == labels).float().mean().item()
            num_batches += 1

        # Both metrics are reported as per-batch means; counting batches also
        # avoids dividing by zero on an empty loader.
        denom = max(num_batches, 1)
        print(f"Epoch {epoch + 1}: Loss={total_loss / denom:.3f}, Acc={total_acc / denom:.3f}")


#evaluation + visualization
def plot_heatmap(gate_probs, input_ids, tokenizer, max_tokens=20):
    """Show an experts-by-tokens heatmap of gate probabilities for the first example.

    Args:
        gate_probs: tensor whose first axis indexes examples; the first
            example is plotted with tokens on its leading axis.
        input_ids: token ids; the first row supplies the x-axis labels.
        tokenizer: object exposing ``convert_ids_to_tokens``.
        max_tokens: number of leading tokens to display.
    """
    first_example = gate_probs[0].cpu().detach().numpy()
    shown_ids = input_ids[0][:max_tokens].cpu().tolist()
    x_labels = tokenizer.convert_ids_to_tokens(shown_ids)

    # transpose so that rows are experts and columns are tokens
    grid = first_example[:max_tokens].T
    y_labels = [f'Expert {i}' for i in range(grid.shape[0])]

    plt.figure(figsize=(max_tokens * 0.4, 4))
    sns.heatmap(
        grid,
        annot=True,
        fmt=".2f",
        cmap="viridis",
        xticklabels=x_labels,
        yticklabels=y_labels,
    )
    plt.xlabel("tokens")
    plt.ylabel("experts")
    plt.title("Gate Probabilities Heatmap (Top-k Experts)")
    plt.show()


def eval_visual(model, test_loader, config):
    """Evaluate on the first test batch and plot its first-layer gate probabilities.

    Prints the accuracy of that single batch (not of the whole loader).

    Args:
        model: model accepting ``return_gate_probs=True`` and exposing ``tokenizer``.
        test_loader: iterable of dicts with ``'input_ids'`` and ``'label'``.
        config: object optionally providing ``device``.
    """
    model.eval()
    # Prefer the configured device; otherwise use wherever the model lives.
    device = getattr(config, 'device', None)
    if device is None:
        device = next(model.parameters()).device

    batch = next(iter(test_loader))
    input_ids = batch['input_ids'].to(device)
    labels = batch['label'].to(device)

    with torch.no_grad():
        # entropy and balance outputs are not needed for evaluation
        logits, _, _, gate_probs = model(input_ids, return_gate_probs=True)
        acc = (logits.argmax(dim=1) == labels).float().mean().item()
        print(f"eval accuracy: {acc:.3f}")

    plot_heatmap(gate_probs, input_ids, model.tokenizer)


#main
if __name__ == "__main__":
    # Everything below reads config.device; fall back to auto-detection when
    # the config object does not define it.
    if not hasattr(config, 'device'):
        config.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = GatedModularTransformer(config).to(config.device)
    train_loader, test_loader = get_dataloaders(config)

    print("starting training")
    train_model(model, train_loader, config, epochs=3)

    print("creating visuals")
    eval_visual(model, test_loader, config)

    print("done")

# Vanilla Transformer


In [None]:
import math
import torch
import torch.nn as nn
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from tqdm import tqdm
import time


class Config:
    """Hyperparameters for the vanilla Transformer baseline."""

    # input
    vocab_size = 30522   # embedding table size
    max_seq_len = 128    # pad/truncate length and positional table size

    # encoder
    d_model = 128
    num_heads = 4
    num_layers = 2
    dropout = 0.4

    # task / batching
    num_classes = 3
    batch_size = 32

    # runtime: kept as a plain string because later code compares it to 'cuda'
    device = 'cuda' if torch.cuda.is_available() else 'cpu'


config = Config()


class pos_enc(nn.Module):
    """Sinusoidal positional encoding stored as a registered buffer.

    Args:
        d_model: width of the embeddings the encoding is added to.
        max_len: number of positions precomputed in the table.
    """

    def __init__(self, d_model, max_len=512):
        super().__init__()

        positions = torch.arange(max_len).unsqueeze(1).float()
        inv_freq = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = positions * inv_freq

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)  # even feature indices
        table[:, 1::2] = torch.cos(angles)  # odd feature indices

        # leading axis of size 1 lets the table broadcast over the batch
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encoding of its first ``x.size(1)`` positions."""
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len]


class plain_trans(nn.Module):
    """Baseline classifier: embedding -> positional encoding -> Transformer encoder -> linear head.

    Args:
        config: object providing ``vocab_size``, ``d_model``, ``max_seq_len``,
            ``num_heads``, ``dropout``, ``num_layers`` and ``num_classes``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model

        # token embedding followed by fixed positional encoding
        self.embed = nn.Embedding(config.vocab_size, width)
        self.pos_enc = pos_enc(width, config.max_seq_len)

        # stack of identical encoder layers with a 4x feed-forward expansion
        layer_template = nn.TransformerEncoderLayer(
            d_model=width,
            nhead=config.num_heads,
            dim_feedforward=4 * width,
            dropout=config.dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            layer_template, num_layers=config.num_layers
        )

        # dropout on the pooled representation, then the classification head
        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(width, config.num_classes)

    def forward(self, input_ids):
        """Return class logits for a batch of token id sequences."""
        hidden = self.pos_enc(self.embed(input_ids))
        encoded = self.transformer_encoder(hidden)
        pooled = encoded.mean(dim=1)  # average over the sequence axis
        return self.classifier(self.dropout(pooled))


def compute_metrics(logits, labels):
    """Return accuracy: the fraction of rows whose argmax over the last axis equals the label."""
    hits = torch.eq(logits.argmax(dim=-1), labels)
    num_correct = hits.sum().item()
    return num_correct / labels.size(0)


def train_epoch(model, dataloader, optimizer, criterion, config):
    """Run one training pass over ``dataloader``.

    Args:
        model: module mapping input ids to logits.
        dataloader: iterable of dicts with ``'input_ids'`` and ``'label'``.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss taking ``(logits, labels)``.
        config: object providing ``device``.

    Returns:
        ``(mean_loss, mean_accuracy, elapsed_seconds)``; the means are weighted
        by batch size and are 0 when the dataloader yields no samples.
    """
    model.train()
    total_loss = 0.0
    total_acc = 0.0
    n = 0
    start_time = time.perf_counter()

    for batch in tqdm(dataloader, desc="Training"):
        input_ids = batch['input_ids'].to(config.device)
        labels = batch['label'].to(config.device)
        batch_size = input_ids.size(0)

        optimizer.zero_grad()
        logits = model(input_ids)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # weight by batch size so a short final batch does not skew the mean
        total_loss += loss.item() * batch_size
        total_acc += compute_metrics(logits, labels) * batch_size
        n += batch_size

    epoch_time = time.perf_counter() - start_time
    # guard against an empty dataloader
    denom = max(n, 1)
    return total_loss / denom, total_acc / denom, epoch_time


@torch.no_grad()
def eval_epoch(model, dataloader, criterion, config):
    """Run one evaluation pass over ``dataloader`` without gradient tracking.

    Args:
        model: module mapping input ids to logits.
        dataloader: iterable of dicts with ``'input_ids'`` and ``'label'``.
        criterion: loss taking ``(logits, labels)``.
        config: object providing ``device``.

    Returns:
        ``(mean_loss, mean_accuracy, elapsed_seconds, mean_forward_seconds)``.
        The loss and accuracy means are weighted by batch size; all means are
        0 when the dataloader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_acc = 0.0
    n = 0
    start_time = time.perf_counter()
    inference_times = []

    for batch in tqdm(dataloader, desc="Evaluating"):
        input_ids = batch['input_ids'].to(config.device)
        labels = batch['label'].to(config.device)
        batch_size = input_ids.size(0)

        start_infer = time.perf_counter()
        logits = model(input_ids)
        if config.device == 'cuda':
            # wait for queued kernels so the timer covers the actual forward pass
            torch.cuda.synchronize()
        inference_times.append(time.perf_counter() - start_infer)

        loss = criterion(logits, labels)

        total_loss += loss.item() * batch_size
        total_acc += compute_metrics(logits, labels) * batch_size
        n += batch_size

    epoch_time = time.perf_counter() - start_time
    avg_inference_latency = sum(inference_times) / len(inference_times) if inference_times else 0
    # same empty-input guard as the latency average above
    denom = max(n, 1)
    return total_loss / denom, total_acc / denom, epoch_time, avg_inference_latency


def collate_fn(batch):
    """Stack per-sample ``'input_ids'`` and ``'label'`` into batch tensors.

    Args:
        batch: non-empty sequence of dicts with ``'input_ids'`` (equal-length
            sequences or tensors) and ``'label'`` (scalars or 0-d tensors).

    Returns:
        Dict with ``'input_ids'`` stacked along a new leading axis and
        ``'label'`` as a 1-D tensor.
    """
    # as_tensor reuses values that are already tensors instead of
    # copy-constructing them, which torch.tensor() warns about.
    input_ids = torch.stack([torch.as_tensor(x['input_ids']) for x in batch])
    labels = torch.stack([torch.as_tensor(x['label']) for x in batch])
    return {'input_ids': input_ids, 'label': labels}


def main():
    """Tokenize GLUE MNLI, train the baseline Transformer, and print per-epoch metrics."""
    dataset = load_dataset("glue", "mnli")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    def tokenize_fn(batch):
        # encode each premise/hypothesis pair as one padded, truncated sequence
        return tokenizer(
            batch['premise'],
            batch['hypothesis'],
            padding='max_length',
            max_length=config.max_seq_len,
            truncation=True,
        )

    dataset = dataset.map(tokenize_fn, batched=True)
    dataset.set_format(type='torch', columns=['input_ids', 'label'])

    train_loader = DataLoader(
        dataset['train'],
        batch_size=config.batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )
    val_loader = DataLoader(
        dataset['validation_matched'],
        batch_size=config.batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )

    model = plain_trans(config).to(config.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    # halve the learning rate every 5 epochs
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
    criterion = nn.CrossEntropyLoss()

    mib = 1024 ** 2
    epochs = 20 #set epochs here
    for epoch in range(epochs):
        train_loss, train_acc, train_time = train_epoch(model, train_loader, optimizer, criterion, config)
        val_loss, val_acc, val_time, val_infer_latency = eval_epoch(model, val_loader, criterion, config)
        scheduler.step()

        # GPU memory figures in MB; zeros when not running on CUDA
        gpu_mem_allocated = gpu_mem_reserved = gpu_mem_max_allocated = 0
        if config.device == 'cuda':
            gpu_mem_allocated = torch.cuda.memory_allocated() / mib
            gpu_mem_reserved = torch.cuda.memory_reserved() / mib
            gpu_mem_max_allocated = torch.cuda.max_memory_allocated() / mib
            torch.cuda.reset_peak_memory_stats()  # start the peak counter fresh for the next epoch

        report = [
            f"epoch {epoch+1}",
            f"train Loss: {train_loss:.4f}",
            f"Train Acc: {train_acc:.4f}",
            f"Train Time: {train_time:.2f}s",
            f"val Loss: {val_loss:.4f}",
            f"Val Acc: {val_acc:.4f}",
            f"Val Time: {val_time:.2f}s",
            f"val infer latency: {val_infer_latency*1000:.2f} ms",
            f"gpu Mem Allocated: {gpu_mem_allocated:.2f} MB",
            f"GPU Mem Reserved: {gpu_mem_reserved:.2f} MB",
            f"gpu Max Mem Allocated: {gpu_mem_max_allocated:.2f} MB",
        ]
        print(" | ".join(report))


if __name__ == "__main__":
    main()
