In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import ast
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
import math
import itertools

In [22]:
import torch
from torch.utils.data import TensorDataset, random_split
import itertools

class AddTokenizer:
    def __init__(self, min_int=-500, max_int=500):
        self.min_int = min_int
        self.max_int = max_int
        self.token_to_id = {}
        self.id_to_token = []
        self.add_special_tokens()
        self.build_vocab()
        
    def add_special_tokens(self):
        specials = ["+", "-", "*", "/", "="]
        for tok in specials:
            self.token_to_id[tok] = len(self.id_to_token)
            self.id_to_token.append(tok)

    def build_vocab(self):
        for i in range(self.min_int, self.max_int + 1):
            tok = str(i)
            self.token_to_id[tok] = len(self.id_to_token)
            self.id_to_token.append(tok)

    def encode_int(self, i):
        return [self.token_to_id[str(i)]]

    def encode(self, s):
        return [self.token_to_id[s]]

    def decode(self, ids):
        return "".join(self.id_to_token[i] for i in ids)


def build_addition_dataset(min_int=-500, max_int=500):
    tokenizer = AddTokenizer(min_int, max_int)

    X = []
    mask = []

    for a, b in itertools.product(range(min_int, max_int + 1),
                                  range(min_int, max_int + 1)):

        c = a + b
        if not (min_int <= c <= max_int):
            continue

        seq = (
            tokenizer.encode_int(a)
            + tokenizer.encode("+")
            + tokenizer.encode_int(b)
            + tokenizer.encode("=")
            + tokenizer.encode_int(c)
        )

        # convert to tensor
        X.append(seq)
        mask.append([1] * len(seq))

    # convert lists to tensors of shape [N, L]
    X = torch.tensor(X, dtype=torch.long)
    mask = torch.tensor(mask, dtype=torch.long)

    return tokenizer, X, mask


def train_val_split(X, mask, val_frac=0.1):
    N = len(X)
    val_size = int(N * val_frac)
    train_size = N - val_size
    train_data, val_data = random_split(
        TensorDataset(X, mask),
        [train_size, val_size]
    )
    return train_data, val_data


In [None]:
import torch
from torch.utils.data import TensorDataset, random_split

def make_addition_dataset(max_int=500):
    # all ordered pairs (a,b)
    a_vals = torch.arange(0, max_int+1)
    b_vals = torch.arange(0, max_int+1)
    A, B = torch.meshgrid(a_vals, b_vals, indexing='ij')

    A = A.reshape(-1)      # (N,)
    B = B.reshape(-1)      # (N,)
    C = A + B              # (N,)

    # filter valid sums
    valid = (C <= max_int)
    A = A[valid]
    B = B[valid]
    C = C[valid]

    X = torch.stack([A, B, C], dim=1)  # (N, 3)

    return X


def train_val_split(X, val_frac=0.1):
    N = len(X)
    val_size = int(N * val_frac)
    train_size = N - val_size
    ds = TensorDataset(X)
    return random_split(ds, [train_size, val_size])


In [50]:
X = make_addition_dataset(max_int=500)
train_data, val_data = train_val_split(X)

print(X.shape)      # → [~251k examples, 3]
print(train_data[0]) # → tensor([a, b, a+b])


torch.Size([125751, 3])
(tensor([ 23, 119, 142]),)


In [51]:
class LowRankExpert(nn.Module):
    def __init__(self, in_dim, out_dim, rank):
        super().__init__()
        self.A = nn.Linear(rank, out_dim, bias=False)
        self.B = nn.Linear(in_dim, rank, bias=False)

    def forward(self, x):
        return self.A(self.B(x))

In [52]:
class MoEWithLoadBalancing(nn.Module):
    def __init__(self, in_dim, out_dim, num_experts, dropout=0.0, top_k=1):
        super().__init__()
        if top_k != 1:
            raise NotImplementedError("Top-k is not implemented for MoE")
        self.top_k = top_k
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_experts = num_experts

        # self.experts = nn.ModuleList([nn.Linear(in_dim, out_dim) for _ in range(num_experts)])
        self.experts = nn.ModuleList([LowRankExpert(in_dim, out_dim, 4) for _ in range(num_experts)])

    def forward(self, x, expert_probs, return_load_balance_loss=False):
        B, S, D = x.shape

        if expert_probs.shape != (B, self.num_experts):
            raise ValueError(f"Expert probabilities must be of shape (B, num_experts), got {expert_probs.shape}")
        
        expert_idx = torch.argmax(expert_probs, dim=-1)

        x_out = torch.zeros(B, S, self.out_dim, device=x.device, dtype=x.dtype)

        for i, expert in enumerate(self.experts):
            mask = expert_idx == i
            if mask.any():
                x_out[mask, :, :] = expert(x[mask, :, :])

        if return_load_balance_loss:
            lb_loss = compute_load_balancing_loss(expert_probs, expert_idx, self.num_experts)
            return x_out, expert_idx, lb_loss
        
        return x_out, expert_idx


In [53]:
class TransformerBlockWithLoadBalancing(nn.Module):
    def __init__(self, dim, n_heads, mlp_ratio=4.0, dropout=0.0, is_causal=True, 
                 use_moe=False, num_experts=4, router_idx=None, verbose_router=False):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.mlp_ratio = mlp_ratio
        self.num_experts = num_experts
        self.dropout = dropout
        self.router_idx = router_idx
        self.is_causal = is_causal
        self.verbose_router = verbose_router

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.use_moe = use_moe
        if use_moe:
            if router_idx is None:
                raise ValueError("router_idx must be provided when using MoE")
            self.router = nn.Linear(dim, num_experts)
            self.qkv = MoEWithLoadBalancing(dim, 3 * dim, num_experts, dropout=dropout)
            self.o = MoEWithLoadBalancing(dim, dim, num_experts, dropout=dropout)
            self.mlp_in = MoEWithLoadBalancing(dim, int(dim * mlp_ratio), num_experts, dropout=dropout)
            self.mlp_out = MoEWithLoadBalancing(int(dim * mlp_ratio), dim, num_experts, dropout=dropout)
        else:
            self.qkv = nn.Linear(dim, 3 * dim)
            self.o = nn.Linear(dim, dim)
            self.mlp_in = nn.Linear(dim, int(dim * mlp_ratio))
            self.mlp_out = nn.Linear(int(dim * mlp_ratio), dim)

    def forward(self, x, return_load_balance_loss=False):
        B, S, D = x.shape
        total_lb_loss = 0.0
        
        # Attention block
        if self.use_moe:
            router_out = self.router(x[:, self.router_idx])
            expert_probs = F.softmax(router_out, dim=-1)
            
            if self.verbose_router:
                top_experts = torch.argmax(expert_probs, dim=-1)
                counts = torch.bincount(top_experts.flatten(), minlength=self.num_experts)
                usage = counts.float() / counts.sum() * 100
                print(f"Expert usage (%): {[f'{u:.1f}' for u in usage.tolist()]}")
            
            # QKV projection
            if return_load_balance_loss:
                qkv, _, lb_loss_qkv = self.qkv(x, expert_probs, return_load_balance_loss=True)
                total_lb_loss += lb_loss_qkv
            else:
                qkv, _ = self.qkv(x, expert_probs)
        else:
            qkv = self.qkv(x)

        # Split into Q, K, V and compute attention
        q, k, v = qkv.split(D, dim=2)
        q = q.view(B, S, self.n_heads, D // self.n_heads).permute(0, 2, 1, 3)
        k = k.view(B, S, self.n_heads, D // self.n_heads).permute(0, 2, 1, 3)
        v = v.view(B, S, self.n_heads, D // self.n_heads).permute(0, 2, 1, 3)
        
        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=self.is_causal)
        attn_out = attn_out.permute(0, 2, 1, 3).contiguous().view(B, S, D)
        
        # Output projection
        if self.use_moe:
            if return_load_balance_loss:
                attn_out, _, lb_loss_o = self.o(attn_out, expert_probs, return_load_balance_loss=True)
                total_lb_loss += lb_loss_o
            else:
                attn_out, _ = self.o(attn_out, expert_probs)
        else:
            attn_out = self.o(attn_out)
        
        x = self.norm1(x + attn_out)
        
        # MLP block
        if self.use_moe:
            if return_load_balance_loss:
                mlp_out, _, lb_loss_mlp_in = self.mlp_in(x, expert_probs, return_load_balance_loss=True)
                total_lb_loss += lb_loss_mlp_in
            else:
                mlp_out, _ = self.mlp_in(x, expert_probs)
            
            mlp_out = F.gelu(mlp_out)
            
            if return_load_balance_loss:
                mlp_out, _, lb_loss_mlp_out = self.mlp_out(mlp_out, expert_probs, return_load_balance_loss=True)
                total_lb_loss += lb_loss_mlp_out
            else:
                mlp_out, _ = self.mlp_out(mlp_out, expert_probs)
        else:
            mlp_out = F.gelu(self.mlp_in(x))
            mlp_out = self.mlp_out(mlp_out)
        
        x = self.norm2(x + mlp_out)
        
        if return_load_balance_loss:
            return x, total_lb_loss
        return x


In [55]:
# Model hyperparameters
D_MODEL = 32
N_HEADS = 2
DEPTH = 4
LOAD_BALANCE_COEFF = 0.01  # Alpha parameter for load balancing

In [56]:
class ToolModel(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, depth):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.transformers = nn.ModuleList([
            TransformerBlockWithLoadBalancing(d_model, n_heads, mlp_ratio=4, is_causal=False, use_moe=False)
            for _ in range(depth)
        ])


        self.project = nn.Linear(d_model, vocab_size)


    def forward(self, x):
        x = self.embed(x)  # Embed the input first: [B, 3] -> [B, 3, D]
        for tf in self.transformers:
            x = tf(x)

        x = self.project(x)
        return x

        # # embedding → scalar value
        # self.to_num = nn.Linear(d_model, 1)

        # # numeric result → embedding update
        # self.from_num = nn.Linear(1, d_model)

        # # gate
        # self.gate = nn.Linear(2 * d_model, 1)

    # def forward(self, x):
    #     # x shape: [batch, 3]  (n1, n2, target_sum)

    #     emb = self.embed(x)           # [B, 3, D]
    #     h1, h2 = emb[:, 0], emb[:, 1]

    #     num1 = self.to_num(h1)        # [B, 1]
    #     num2 = self.to_num(h2)        # [B, 1]

    #     num_sum = num1 + num2         # pure operation

    #     # update latent
    #     gate = torch.sigmoid(self.gate(torch.cat([h1, h2], dim=-1)))  # [B, 1]
    #     update = self.from_num(num_sum)

    #     # apply update to h1 only (simple test)
    #     h1_new = h1 + gate * update

    #     return num_sum, num1, num2


In [57]:
# Create dataset and dataloader
class AdditionDataset(Dataset):
    def __init__(self, X):
        self.X = X

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx]

dataset = AdditionDataset(X)
loader = DataLoader(dataset, batch_size=256, shuffle=True)


In [58]:
_X = X.unsqueeze(0)

In [61]:
model = ToolModel(vocab_size=501, d_model=D_MODEL, n_heads=N_HEADS, depth=DEPTH)
model.to(torch.device("cuda"))
optimizer = AdamW(model.parameters(), lr=1e-3)
model.train()

for epoch in range(2000):
    for step, batch in enumerate(loader, start=1):
        batch = batch.to(next(model.parameters()).device)  # ensure same device

        y_hat = model(batch)
        y_hat = y_hat.transpose(2,1)

        loss = F.cross_entropy(y_hat, batch)



        optimizer.zero_grad(set_to_none=True)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        print(loss.item())

    if epoch % 1 == 0:
        print(f"[step {epoch}] "
            f"loss={loss.item():.4f}  ")



6.367236614227295
6.358667373657227
6.307708263397217
6.281274795532227
6.192593097686768
6.186535358428955
6.059232234954834
6.0523762702941895
5.9908270835876465
5.987705230712891
5.975044250488281
5.910215854644775
5.892751693725586
5.812833786010742
5.80508279800415
5.766063690185547
5.6840386390686035
5.698946475982666
5.684512615203857
5.637073993682861
5.6229023933410645
5.541421890258789
5.505592346191406
5.516163349151611
5.46511697769165
5.442641735076904
5.413387298583984
5.356739521026611
5.357078552246094
5.336609363555908
5.25968599319458
5.232937335968018
5.163944721221924
5.152922630310059
5.168474197387695
5.141955852508545
5.095445156097412
5.066081523895264
5.0337958335876465
5.0008440017700195
4.962721347808838
4.9335103034973145
4.895631313323975
4.915548801422119
4.872579097747803
4.81920051574707
4.798313617706299
4.78100061416626
4.7504777908325195
4.695327281951904
4.672508239746094
4.654851913452148
4.592723369598389
4.575587749481201
4.575484752655029
4.55752

KeyboardInterrupt: 

In [41]:
batch.shape

torch.Size([256, 3])

In [42]:
y_hat.shape

torch.Size([256, 3, 101])