In [None]:
import torch
import torch.nn as nn
from torch.optim import AdamW
import torch.nn.functional as F
import pandas as pd
import ast
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader


In [None]:
class RoPE(nn.Module):
    """Rotary position embedding applied jointly to a query and a key tensor.

    Cos/sin tables of shape ``(1, max_seq_len, 1, d_head)`` are precomputed
    once and registered as non-persistent buffers, so they follow the module
    through ``.to(...)`` / ``.half()`` without being written to checkpoints.

    Args:
        d_model: Model width; the per-head width is ``d_model // n_heads``.
        n_heads: Number of attention heads.
        max_seq_len: Number of positions the tables are precomputed for.
        base: Base of the geometric progression of rotation frequencies.
        device: Device the tables are initially built on (CPU when ``None``).

    Raises:
        ValueError: If the per-head width is odd (the two rotated halves
            would not line up).
    """

    def __init__(self, d_model, n_heads, max_seq_len=4096, base=10000, device=None):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        if self.d_head % 2 != 0:
            raise ValueError(
                f"RoPE needs an even head dimension, got d_model // n_heads = {self.d_head}"
            )
        self.max_seq_len = max_seq_len
        # Kept for backward compatibility; the buffers below are the source of
        # truth for the device once the module has been moved.
        self.device = device or torch.device("cpu")

        inv_freq = 1.0 / (base ** (torch.arange(0, self.d_head, 2).float() / self.d_head)).to(self.device)
        self.register_buffer("inv_freq", inv_freq)

        # Buffers (rather than plain attributes) so that module.to(device) moves
        # them together with the parameters; persistent=False keeps the
        # state_dict layout identical to earlier checkpoints.
        cos, sin = self._build_freqs()
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)

    def _build_freqs(self):
        """Return the ``(cos, sin)`` tables, each ``(1, max_seq_len, 1, d_head)``."""
        t = torch.arange(self.max_seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Duplicate the angles so both halves of the head dim share them.
        emb = torch.cat((freqs, freqs), dim=-1)
        table_shape = (1, self.max_seq_len, 1, self.d_head)
        return emb.cos().reshape(table_shape), emb.sin().reshape(table_shape)

    def _rotate_half(self, x):
        """Map the last-dim halves ``(x1, x2)`` to ``(-x2, x1)``."""
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_rotary_pos_emb(self, q, k):
        """Rotate ``q`` and ``k``, both shaped ``(B, S, H, D)``.

        Raises:
            NotImplementedError: If ``q`` and ``k`` differ in shape.
            ValueError: If the sequence length exceeds ``max_seq_len``.
        """
        B, S, H, D = q.shape
        if q.shape != k.shape:
            raise NotImplementedError("q and k must have the same shape")
        if S > self.max_seq_len:
            raise ValueError(
                f"sequence length {S} exceeds max_seq_len {self.max_seq_len}"
            )
        cos, sin = self.cos[:, :S, :, :], self.sin[:, :S, :, :]
        return (q * cos) + (self._rotate_half(q) * sin), (k * cos) + (self._rotate_half(k) * sin)

    def forward(self, q, k):
        """Return the rotated ``(q, k)`` pair; invoked through ``module(q, k)``."""
        return self._apply_rotary_pos_emb(q, k)

In [None]:
class MoE(nn.Module):
    """Top-1 mixture of linear experts, routed once per batch row.

    Every sequence in the batch is sent, as a whole, through the single expert
    with the highest routing probability for that row.

    Args:
        in_dim: Input feature size of each expert.
        out_dim: Output feature size of each expert.
        num_experts: Number of ``nn.Linear`` experts.
        dropout: Accepted for interface compatibility; not applied here.
        top_k: Must be 1.

    Raises:
        NotImplementedError: If ``top_k != 1``.
    """

    def __init__(self, in_dim, out_dim, num_experts, dropout=0.0, top_k=1):
        super().__init__()
        if top_k != 1:
            raise NotImplementedError("Top-k is not implemented for MoE")
        self.top_k = top_k
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_experts = num_experts

        self.experts = nn.ModuleList([nn.Linear(in_dim, out_dim) for _ in range(num_experts)])

    def forward(self, x, expert_probs, inference=False):
        """Route each batch row of ``x`` through its arg-max expert.

        Args:
            x: Input of shape ``(B, S, in_dim)``.
            expert_probs: Routing scores of shape ``(B, num_experts)``.
            inference: Must be ``False``.

        Returns:
            ``(x_out, expert_idx)`` where ``x_out`` is ``(B, S, out_dim)`` in
            the dtype of ``x`` and ``expert_idx`` is the ``(B,)`` tensor of
            selected expert indices.

        Raises:
            NotImplementedError: If ``inference`` is true.
            ValueError: If ``expert_probs`` is not ``(B, num_experts)``.
        """
        if inference:
            raise NotImplementedError("Inference is not implemented for MoE")

        B, S, D = x.shape

        if expert_probs.shape != (B, self.num_experts):
            raise ValueError(f"Expert probabilities must be of shape (B, num_experts), got {expert_probs.shape}")

        expert_idx = torch.argmax(expert_probs, dim=-1)

        # argmax alone is non-differentiable, so whatever produced expert_probs
        # (the router) would never receive a gradient. Multiply the expert
        # output by a straight-through gate: its forward value is exactly 1
        # (outputs are unchanged) while its gradient w.r.t. the selected
        # probability is that of the usual `prob * expert(x)` gating.
        gate = None
        if expert_probs.requires_grad:
            selected = expert_probs.gather(-1, expert_idx.unsqueeze(-1))  # (B, 1)
            gate = (1.0 + (selected - selected.detach())).unsqueeze(-1)   # (B, 1, 1)

        # Follow the dtype/device of the input instead of always using float32.
        x_out = x.new_zeros(B, S, self.out_dim)

        for i, expert in enumerate(self.experts):
            rows = expert_idx == i
            routed = expert(x[rows])
            if gate is not None:
                routed = routed * gate[rows].to(routed.dtype)
            # Each row belongs to exactly one expert, so a plain write suffices.
            x_out[rows] = routed.to(x_out.dtype)

        return x_out, expert_idx

In [None]:
class TransformerBlock(nn.Module):
    """Post-norm transformer block: rotary self-attention followed by a GELU MLP.

    With ``use_moe=True`` the four projections (qkv, attention output, MLP in,
    MLP out) are ``MoE`` layers that share one routing distribution per batch
    row, produced by a linear router from the token at position ``router_idx``.

    Args:
        dim: Model width.
        n_heads: Number of attention heads; must divide ``dim``.
        mlp_ratio: Hidden MLP width as a multiple of ``dim``.
        dropout: Dropout probability applied to both sub-layer outputs before
            the residual addition (only while the module is in training mode).
        is_causal: Forwarded to ``F.scaled_dot_product_attention``.
        use_moe: Use ``MoE`` projections instead of ``nn.Linear``.
        num_experts: Number of experts per ``MoE`` projection.
        router_idx: Sequence position whose features feed the router;
            required when ``use_moe`` is true.
        verbose_router: Print the per-call expert usage percentages.

    Raises:
        ValueError: If ``dim`` is not divisible by ``n_heads``, or if
            ``use_moe`` is true and ``router_idx`` is ``None``.
    """

    def __init__(self, dim, n_heads, mlp_ratio=4.0, dropout=0.0, is_causal=True, use_moe=False, num_experts=4, router_idx=None, verbose_router=False):
        super().__init__()
        if dim % n_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by n_heads ({n_heads})")
        self.dim = dim
        self.n_heads = n_heads
        self.mlp_ratio = mlp_ratio
        self.num_experts = num_experts
        self.dropout = dropout
        self.router_idx = router_idx
        self.is_causal = is_causal
        self.verbose_router = verbose_router

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        # Do not hard-code "cuda": constructing the block must also work on
        # CPU-only hosts.
        rope_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.rope = RoPE(d_model=dim, n_heads=n_heads, device=rope_device, max_seq_len=4096)

        hidden_dim = int(dim * mlp_ratio)
        self.use_moe = use_moe
        if use_moe:
            if router_idx is None:
                raise ValueError("router_idx must be provided when using MoE")
            self.router = nn.Linear(dim, num_experts)
            self.qkv = MoE(dim, 3 * dim, num_experts, dropout=dropout)
            self.o = MoE(dim, dim, num_experts, dropout=dropout)
            self.mlp_in = MoE(dim, hidden_dim, num_experts, dropout=dropout)
            self.mlp_out = MoE(hidden_dim, dim, num_experts, dropout=dropout)
        else:
            self.qkv = nn.Linear(dim, 3 * dim)
            self.o = nn.Linear(dim, dim)
            self.mlp_in = nn.Linear(dim, hidden_dim)
            self.mlp_out = nn.Linear(hidden_dim, dim)

    def mlp(self, x, expert_probs=None):
        """Apply ``mlp_in -> GELU -> mlp_out``.

        ``expert_probs`` is only used on the MoE path, where both projections
        are routed with the same distribution.
        """
        if self.use_moe:
            x, _ = self.mlp_in(x, expert_probs)
            x = F.gelu(x)
            x, _ = self.mlp_out(x, expert_probs)
        else:
            x = F.gelu(self.mlp_in(x))
            x = self.mlp_out(x)
        return x

    def attn(self, x):
        """Rotary multi-head self-attention over ``x`` of shape ``(B, S, D)``.

        Returns:
            ``(attn_out, expert_idx, expert_probs)``. ``attn_out`` is
            ``(B, S, D)``; on the MoE path ``expert_idx`` is the routing of
            the output projection and ``expert_probs`` the router softmax,
            on the dense path both are ``None``.
        """
        B, S, D = x.shape
        head_dim = D // self.n_heads
        num_heads = self.n_heads

        if self.use_moe:
            # One routing decision per batch row, taken from a single token.
            router_out = self.router(x[:, self.router_idx])
            expert_probs = F.softmax(router_out, dim=-1)
            if self.verbose_router:
                top_experts = torch.argmax(expert_probs, dim=-1)  # [batch]
                num_experts = expert_probs.size(-1)
                counts = torch.bincount(top_experts.flatten(), minlength=num_experts)
                usage = counts.float() / counts.sum() * 100
                print(f"Expert usage (%): {usage.tolist()}")

            qkv, expert_idx = self.qkv(x, expert_probs)
        else:
            qkv = self.qkv(x)
            expert_idx = None
            expert_probs = None

        q, k, v = qkv.split(D, dim=2)
        q = q.view(B, S, num_heads, head_dim)
        k = k.view(B, S, num_heads, head_dim)
        v = v.view(B, S, num_heads, head_dim)

        # RoPE operates on the (B, S, H, D) layout, before heads are moved forward.
        q, k = self.rope(q, k)

        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=self.is_causal)

        # The permuted result is not guaranteed to be contiguous (it depends on
        # the attention backend), so merge the heads with reshape, not view.
        attn_out = attn_out.permute(0, 2, 1, 3).reshape(B, S, D)
        if self.use_moe:
            attn_out, expert_idx = self.o(attn_out, expert_probs)
        else:
            attn_out = self.o(attn_out)
            expert_idx = None

        return attn_out, expert_idx, expert_probs

    def _residual_dropout(self, x):
        """Apply the configured dropout; a no-op when ``dropout`` is 0."""
        if self.dropout > 0:
            return F.dropout(x, p=self.dropout, training=self.training)
        return x

    def forward(self, x):
        """Run attention and MLP sub-layers, each followed by add & LayerNorm."""
        attn_out, _, expert_probs = self.attn(x)
        x = self.norm1(x + self._residual_dropout(attn_out))

        # expert_probs is None on the dense path, matching mlp()'s default.
        mlp_out = self.mlp(x, expert_probs)

        x = self.norm2(x + self._residual_dropout(mlp_out))
        return x

In [None]:
# Load the addition dataset. The 'lhs_seq' / 'rhs_seq' columns hold Python
# list literals, which are parsed with ast.literal_eval.
df = pd.read_csv("../data/addition.csv")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

_x = torch.tensor([ast.literal_eval(cell) for cell in df['lhs_seq']])
_y = torch.tensor([ast.literal_eval(cell) for cell in df['rhs_seq']])

# Full training sequence: lhs tokens followed by rhs tokens.
X = torch.cat((_x, _y), dim=1).to(device)

# Mask is 1 over the lhs columns and 0 over the rhs columns.
# NOTE(review): this keeps the loss on the lhs part and drops it on the rhs
# part — confirm that is the intended side.
loss_mask = torch.cat((torch.ones_like(_x), torch.zeros_like(_y)), dim=1).to(device)


In [53]:
# Character-level vocabulary; a token id is the character's position in this string.
vocab = "0123456789ri+=_"
vocab_size = len(vocab)


def encode(s):
    """Return the list of vocab indices for the characters of ``s``.

    Raises ValueError (from ``str.index``) for a character not in ``vocab``.
    """
    return list(map(vocab.index, s))


def decode(l):
    """Return the string formed by looking up each index of ``l`` in ``vocab``."""
    return ''.join(vocab[i] for i in l)


# Index of the last vocab symbol ('_').
IGNORE_INDEX = vocab_size - 1

# Model hyperparameters.
D_MODEL = 64
N_HEADS = 4
N_LAYERS = 8


In [None]:
class AdditionDataset(Dataset):
    """Dataset pairing each row of ``X`` with the matching row of ``loss_mask``.

    Args:
        X: Indexable collection of samples (first dimension is the sample axis).
        loss_mask: Indexable collection aligned with ``X`` along that axis.

    Raises:
        ValueError: If ``X`` and ``loss_mask`` contain a different number of rows.
    """

    def __init__(self, X, loss_mask):
        # Fail at construction instead of with an IndexError mid-epoch.
        if len(X) != len(loss_mask):
            raise ValueError(
                f"X and loss_mask must have the same length, got {len(X)} and {len(loss_mask)}"
            )
        self.X = X
        self.loss_mask = loss_mask

    def __len__(self):
        """Number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(sample, mask)`` pair stored at ``idx``."""
        return self.X[idx], self.loss_mask[idx]

# Wrap the token and mask tensors, then serve them in reshuffled batches of 8000 rows.
dataset = AdditionDataset(X=X, loss_mask=loss_mask)
loader = DataLoader(
    dataset=dataset,
    batch_size=8000,
    shuffle=True,
)

In [None]:
# class AdditionModel(nn.Module):
#     def __init__(self, d_model, n_heads, n_layers, vocab_size, num_experts):
#         super().__init__()
#         self.encoder = nn.Embedding(vocab_size, d_model)
#         self.transformer = nn.ModuleList([TransformerBlock(d_model, n_heads, is_causal=True, use_moe=False) for _ in range(n_layers)])
#         self.decoder = nn.Linear(d_model, vocab_size)
#         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#         self.moe = TransformerBlock(d_model, n_heads, is_causal=True, use_moe=True, num_experts=num_experts, router_idx=7)
#         # self.router = nn.Linear(d_model, num_experts)
#         self.to(self.device)


#     def preprocess(self, x):
#         self.moe(x)

#     def forward(self, x):
#         for layer in self.transformer:
#             x = layer(x)
#         x = self.decoder(x)
#         return x

#     def train(self, optimizer,epochs=1000):
#         for epoch in (range(epochs)):
#             for i, (batch_x, batch_mask) in enumerate(loader):
#                 x = self.encoder(batch_x)
#                 # x[:, :8, :] = self.preprocess(x[:, :8,:])

#                 y_hat = self.forward(x[:, :-1])
#                 loss = F.cross_entropy(
#                     input = y_hat.permute(0, 2, 1),
#                     target= batch_x[:, 1:],
#                     reduction="none"
#                 )
#                 masked_loss = loss * batch_mask[:, 1:]
#                 loss = masked_loss.mean()
#                 loss.backward()
#                 optimizer.step()
#                 optimizer.zero_grad()
#                 print(f"Epoch {i} loss: {loss.item()}")



In [None]:
# model = AdditionModel(D_MODEL, N_HEADS, N_LAYERS, vocab_size, 4)
# optimizer = AdamW(model.parameters(), lr=0.001)
# model.train(optimizer, epochs=1)


In [None]:
class RRM(nn.Module):
    """Recursive model: one MoE ``TransformerBlock`` applied repeatedly.

    Token ids are embedded by ``encoder``, passed ``max_recursions`` times
    through the same transformer block, and projected to vocabulary logits by
    ``decoder``.

    Args:
        d_model: Model width.
        n_heads: Number of attention heads.
        max_recursions: How many times the shared block is applied.
        vocab_size: Size of the token vocabulary.
        num_experts: Number of experts in the block's MoE projections.
        recursive_idx: Accepted for interface compatibility; currently unused.
        router_idx: Sequence position whose features feed the router.
    """

    def __init__(self, d_model, n_heads, max_recursions, vocab_size, num_experts, recursive_idx=None, router_idx=6):
        # NOTE: the default used to be the (undefined) global `recursive_idx`,
        # which raised NameError at class-definition time on a fresh kernel.
        super().__init__()
        self.max_recursions = max_recursions
        self.encoder = nn.Embedding(vocab_size, d_model)
        self.transformer = TransformerBlock(d_model, n_heads, is_causal=True, use_moe=True, num_experts=num_experts, router_idx=router_idx, verbose_router=True)
        self.decoder = nn.Linear(d_model, vocab_size)
        # Fall back to CPU instead of crashing on hosts without CUDA.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)

    def forward(self, x):
        """Map embedded tokens ``(B, S, d_model)`` to logits ``(B, S, vocab_size)``."""
        for _ in range(self.max_recursions):
            x = self.transformer(x)
        # Project to the vocabulary; without this the loss would be computed
        # over d_model "classes" and the decoder would never be trained.
        return self.decoder(x)

    def fit(self, optimizer, epochs=1, data_loader=None):
        """Train with next-token cross-entropy.

        Args:
            optimizer: Optimizer over this model's parameters.
            epochs: Number of passes over the data.
            data_loader: Iterable of ``(token_ids, loss_mask)`` batches;
                defaults to the notebook-level ``loader``.
        """
        if data_loader is None:
            data_loader = loader
        super().train(True)
        for epoch in range(epochs):
            for i, (batch_x, batch_mask) in enumerate(data_loader):
                batch_x = batch_x.to(self.device)
                batch_mask = batch_mask.to(self.device)

                # Clear first so stale gradients from an interrupted run are
                # never folded into this step.
                optimizer.zero_grad()

                x = self.encoder(batch_x)

                # Predict token t+1 from tokens <= t.
                y_hat = self(x[:, :-1])
                loss = F.cross_entropy(
                    input=y_hat.permute(0, 2, 1),
                    target=batch_x[:, 1:],
                    reduction="none",
                )
                target_mask = batch_mask[:, 1:].to(loss.dtype)
                # Average over the unmasked positions only; a plain mean()
                # would also count the masked ones in the denominator.
                loss = (loss * target_mask).sum() / target_mask.sum().clamp(min=1)
                loss.backward()
                optimizer.step()
                print(f"Epoch {epoch} batch {i} loss: {loss.item()}")

    def train(self, optimizer=True, epochs=1, data_loader=None):
        """Run the training loop, or switch train/eval mode.

        ``nn.Module.eval()`` calls ``self.train(False)``, so a boolean first
        argument is delegated to ``nn.Module.train``; anything else is treated
        as an optimizer and forwarded to ``fit``.
        """
        if isinstance(optimizer, bool):
            return super().train(optimizer)
        self.fit(optimizer, epochs=epochs, data_loader=data_loader)


In [None]:
# Build the recursive MoE model from the shared hyperparameters (D_MODEL
# instead of a second hard-coded 64) and train it for one epoch.
model = RRM(D_MODEL, N_HEADS, max_recursions=4, vocab_size=vocab_size, num_experts=4)
optimizer = AdamW(model.parameters(), lr=0.001)
model.train(optimizer, epochs=1)