In [1]:
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F

import tiktoken

import matplotlib.pyplot as plt

import math

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
class RecAdam(torch.optim.Optimizer):
    def __init__(
        self,
        params,
        lr=1e-4,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        rectification=True,
        pretrain_step=0,
        total_step=1000,
        k=0.5,
        init_beta=10.0,
        final_beta=0.1,
    ):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)

        super(RecAdam, self).__init__(params, defaults)

        self.rectification = rectification
        self.pretrain_step = pretrain_step
        self.total_step = total_step
        self.k = k
        self.init_beta = init_beta
        self.final_beta = final_beta

        self.beta = init_beta
        self.current_step = 0

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        if self.rectification:
            self.beta = self.init_beta - (self.init_beta - self.final_beta) * min(
                1, self.current_step / self.total_step
            )
            self.current_step += 1

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad.data

                if grad.is_sparse:
                    raise RuntimeError("RecAdam does not support sparse gradients")

                state = self.state[p]

                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p.data)
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                    if hasattr(p, "pre_trained_params"):
                        state["theta"] = p.pre_trained_params
                    else:
                        state["theta"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = exp_avg_sq.sqrt().add_(group["eps"])

                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]
                step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1

                if self.rectification and hasattr(p, "pre_trained_params"):
                    p.data.addcdiv_(exp_avg, denom, value=-step_size)
                    p.data.add_(
                        (p.data - state["theta"]), alpha=-self.beta * group["lr"]
                    )
                else:
                    p.data.addcdiv_(exp_avg, denom, value=-step_size)

                if group["weight_decay"] != 0:
                    p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])

        return loss

In [4]:
@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50257
    n_layer: int = 12
    n_heads: int = 12
    n_embd: int = 768

In [36]:
class LoRALayer(nn.Module):
    def __init__(self, in_features, out_features, rank=64, alpha=128):
        super().__init__()
        self.rank = rank
        self.lora_A = nn.Parameter(torch.zeros(in_features, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, out_features))
        self.scaling = alpha / rank

        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x):
        return (x @ (self.lora_A @ self.lora_B)) * self.scaling

In [6]:
class KAdapter(nn.Module):
    def __init__(self, config: GPTConfig, k=2):
        super().__init__()
        self.k = k
        self.adapters = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(config.n_embd, config.n_embd // 2),
                    nn.ReLU(),
                    nn.Linear(config.n_embd // 2, config.n_embd),
                )
                for _ in range(k)
            ]
        )

        for adapter in self.adapters:
            nn.init.normal_(adapter[0].weight, std=0.02)
            nn.init.zeros_(adapter[0].bias)
            nn.init.normal_(adapter[2].weight, std=0.02)
            nn.init.zeros_(adapter[2].bias)

    def forward(self, x):
        output = 0

        for adapter in self.adapters:
            output += adapter(x)

        return output / self.k

In [32]:
class CausalSelfAttention(nn.Module):
    def __init__(self, config: GPTConfig, lora: bool = False):
        super().__init__()

        assert config.n_embd % config.n_heads == 0

        self.n_heads = config.n_heads
        self.n_embd = config.n_embd

        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)

        """
        equivalent to:
        
        for i in range(config.n_heads): 
            self.key = nn.Linear(config.n_embd, config.head_size)
            self.query = nn.Linear(config.n_embd, config.head_size)
            self.value = nn.Linear(config.n_embd, config.head_size)
        """

        self.c_proj = nn.Linear(config.n_embd, config.n_embd)

        self.c_proj.SCALE_INIT = 1

        self.lora = lora

        if lora:
            self.lora_attn = LoRALayer(config.n_embd, 3 * config.n_embd)
            self.lora_proj = LoRALayer(config.n_embd, config.n_embd)

    def forward(self, x):
        B, T, C = (
            x.size()
        )  # batch size, sequence length, embedding dimensionality (n_embd)

        qkv = self.c_attn(x)

        if self.lora:
            qkv = qkv + self.lora_attn(x)

        q, k, v = qkv.split(self.n_embd, dim=2)

        k = k.view(B, T, self.n_heads, C // self.n_heads).transpose(
            1, 2
        )  # B, n_heads, T, head_size
        q = q.view(B, T, self.n_heads, C // self.n_heads).transpose(
            1, 2
        )  # B, n_heads, T, head_size
        v = v.view(B, T, self.n_heads, C // self.n_heads).transpose(
            1, 2
        )  # B, n_heads, T, head_size

        # attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # B, n_heads, T, T
        # attn = attn.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        # attn = F.softmax(attn, dim=-1)
        # y = attn @ v # B, n_heads, T, head_size

        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        y = y.transpose(1, 2).contiguous().view(B, T, C)  # B, T, C

        out = self.c_proj(y)

        if self.lora:
            out = out + self.lora_proj(y)

        return out


class MLP(nn.Module):
    def __init__(self, config: GPTConfig, lora: bool = False):
        super().__init__()

        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU(approximate="tanh")
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)

        self.c_proj.SCALE_INIT = 1

        self.lora = lora

        if lora:
            self.lora_fc = LoRALayer(config.n_embd, 4 * config.n_embd)
            self.lora_proj = LoRALayer(4 * config.n_embd, config.n_embd)

    def forward(self, x):
        fc_out = self.c_fc(x)

        if self.lora:
            fc_out = fc_out + self.lora_fc(x)

        fc_out = self.gelu(fc_out)

        proj_out = self.c_proj(fc_out)

        if self.lora:
            proj_out = proj_out + self.lora_proj(fc_out)

        return proj_out


class Block(nn.Module):
    def __init__(
        self,
        config: GPTConfig,
        lora: bool = False,
        k_adapters: int = 0,
    ):
        super().__init__()

        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config, lora)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config, lora)

        self.k_adapters = k_adapters

        if k_adapters > 0:
            self.adapters = KAdapter(config, k_adapters)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))

        x = x + self.mlp(self.ln_2(x))

        if self.k_adapters > 0:
            x = x + self.adapters(x)

        return x


In [33]:
class GPT(nn.Module):
    def __init__(
        self,
        config: GPTConfig,
        lora: bool = False,
        k_adapters: int = 0,
        use_modular: bool = False,
    ):
        super().__init__()

        self.config = config

        if use_modular:
            self.modular = nn.ModuleDict(
                dict(
                    wte=nn.Embedding(config.vocab_size, config.n_embd),
                    wpe=nn.Embedding(config.block_size, config.n_embd),
                    h=nn.ModuleList(
                        [Block(config, lora, k_adapters) for _ in range(config.n_layer)]
                    ),
                    ln_f=nn.LayerNorm(config.n_embd),
                )
            )

            self.modular.apply(self._init_weights)

            self.projection = nn.Linear(config.n_embd, config.n_embd, bias=False)
            self.scale_factor = 0.05

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_size, config.n_embd),
                wpe=nn.Embedding(config.block_size, config.n_embd),
                h=nn.ModuleList(
                    [Block(config, lora, k_adapters) for _ in range(config.n_layer)]
                ),
                ln_f=nn.LayerNorm(config.n_embd),
            )
        )

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.transformer.wte.weight = self.lm_head.weight

        self.use_modular = use_modular

        self.weight_masks = {}

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02

            if hasattr(module, "SCALE_INIT"):
                std *= (2 * self.config.n_layer) ** -0.5

            torch.nn.init.normal_(module.weight, mean=0.0, std=std)

            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def set_weight_masks(self, masks):
        for name, mask in masks.items():
            if name in dict(self.named_parameters()):
                param = dict(self.named_parameters())[name]

                mask = mask.to(device)

                def grad_hook(grad, mask=mask):
                    return grad * mask

                param.register_hook(grad_hook)

                self.weight_masks[name] = mask

    def forward(self, idx, targets=None):
        B, T = idx.size()

        assert T <= self.config.block_size

        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)

        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)

        x = tok_emb + pos_emb

        for block in self.transformer.h:
            x = block(x)

        x = self.transformer.ln_f(x)

        if self.use_modular:
            tok_emb_mod = self.modular.wte(idx)
            pos_emb_mod = self.modular.wpe(pos)
            x_mod = tok_emb_mod + pos_emb_mod

            for block in self.modular.h:
                x_mod = block(x_mod)

            x_mod = self.modular.ln_f(x_mod)

            x = x + (self.scale_factor * self.projection(x_mod))

        logits = self.lm_head(x)  # B, T, vocab_size

        loss = None

        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

    @classmethod
    def from_pretrained(
        cls,
        override_args=None,
        lora: bool = False,
        k_adapters: int = 0,
        use_modular: bool = False,
    ):
        override_args = override_args or {}

        assert all(k == "dropout" for k in override_args)

        from transformers import GPT2LMHeadModel

        config = GPTConfig()
        model = GPT(config, lora=lora, k_adapters=k_adapters, use_modular=use_modular)
        sd = model.state_dict()
        sd_keys = sd.keys()
        sd_keys = [k for k in sd_keys if not k.endswith(".attn.bias")]

        model_hf = GPT2LMHeadModel.from_pretrained("gpt2")

        sd_hf = model_hf.state_dict()

        sd_keys_hf = sd_hf.keys()
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith(".attn.masked_bias")]
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith(".attn.bias")]

        transposed = [
            "attn.c_attn.weight",
            "attn.c_proj.weight",
            "mlp.c_fc.weight",
            "mlp.c_proj.weight",
        ]

        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        if use_modular:
            with torch.no_grad():
                for k, v in sd.items():
                    if k.startswith("transformer."):
                        modular_key = k.replace("transformer.", "modular.")
                        if modular_key in sd:
                            sd[modular_key].copy_(v)

        return model

    def configure_optimizers(self, weight_decay, learning_rate):
        base_params = []
        adapt_params = []

        for name, param in self.named_parameters():
            if any(x in name for x in ["lora_A", "lora_B", "adapters", "modular"]):
                adapt_params.append(param)
            else:
                base_params.append(param)

        use_fused = "cuda" in device

        recadam_optimizer = RecAdam(
            [{"params": base_params}],
            lr=learning_rate * 0.1,
            betas=(0.9, 0.95),
            eps=1e-8,
            weight_decay=weight_decay,
            rectification=True,
            pretrain_step=0,
            total_step=200,
            k=0.5,
            init_beta=10.0,
            final_beta=0.1,
        )

        adamw_optimizer = torch.optim.AdamW(
            [{"params": adapt_params}],
            lr=learning_rate,
            betas=(0.9, 0.95),
            eps=1e-8,
            weight_decay=weight_decay,
            fused=use_fused,
        )

        return [recadam_optimizer, adamw_optimizer]

    def hyper_scale(self):
        original_wte = self.transformer.wte.weight.data
        original_wpe = self.transformer.wpe.weight.data

        self.transformer.wte = nn.Embedding(
            self.config.vocab_size, 2 * self.config.n_embd
        )
        self.transformer.wpe = nn.Embedding(
            self.config.block_size, 2 * self.config.n_embd
        )

        self.transformer.wte.weight.data = torch.cat([original_wte] * 2, dim=1)
        self.transformer.wpe.weight.data = torch.cat([original_wpe] * 2, dim=1)

        for block in self.transformer.h:
            # Layer Norm 1

            original_ln_1_weight = block.ln_1.weight.data
            original_ln_1_bias = block.ln_1.bias.data

            block.ln_1 = nn.LayerNorm(2 * self.config.n_embd)

            block.ln_1.weight.data = torch.cat([original_ln_1_weight] * 2, dim=0)
            block.ln_1.bias.data = torch.cat([original_ln_1_bias] * 2, dim=0)

            block.attn.n_heads = 2 * block.attn.n_heads
            block.attn.n_embd = 2 * block.attn.n_embd

            # Attention Head

            original_c_attn_weight = block.attn.c_attn.weight.data
            original_c_attn_bias = block.attn.c_attn.bias.data

            q_w, k_w, v_w = original_c_attn_weight.chunk(3, dim=0)
            q_b, k_b, v_b = original_c_attn_bias.chunk(3)

            block.attn.c_attn = nn.Linear(
                2 * self.config.n_embd, 6 * self.config.n_embd
            )

            new_weight = (
                torch.cat(
                    [q_w, q_w, k_w, k_w, v_w, v_w],
                    dim=0,
                )
            ) / 2
            new_weight = torch.cat([new_weight] * 2, dim=1)

            new_bias = torch.cat([q_b, q_b, k_b, k_b, v_b, v_b])

            block.attn.c_attn.weight.data = new_weight
            block.attn.c_attn.bias.data = new_bias

            original_c_proj_weight = block.attn.c_proj.weight.data
            original_c_proj_bias = block.attn.c_proj.bias.data

            block.attn.c_proj = nn.Linear(
                2 * self.config.n_embd, 2 * self.config.n_embd
            )

            block.attn.c_proj.weight.data = (
                torch.cat(
                    [torch.cat([original_c_proj_weight] * 2, dim=1)] * 2,
                    dim=0,
                )
                / 2
            )

            block.attn.c_proj.bias.data = torch.cat([original_c_proj_bias] * 2, dim=0)

            # MLP

            original_c_fc_weight = block.mlp.c_fc.weight.data
            original_c_fc_bias = block.mlp.c_fc.bias.data

            block.mlp.c_fc = nn.Linear(2 * self.config.n_embd, 8 * self.config.n_embd)

            block.mlp.c_fc.weight.data = (
                torch.cat(
                    [torch.cat([original_c_fc_weight] * 2, dim=1)] * 2,
                    dim=0,
                )
                / 2
            )

            block.mlp.c_fc.bias.data = torch.cat([original_c_fc_bias] * 2, dim=0)

            original_c_proj_weight = block.mlp.c_proj.weight.data
            original_c_proj_bias = block.mlp.c_proj.bias.data

            block.mlp.c_proj = nn.Linear(8 * self.config.n_embd, 2 * self.config.n_embd)

            block.mlp.c_proj.weight.data = (
                torch.cat(
                    [torch.cat([original_c_proj_weight] * 2, dim=1)] * 2,
                    dim=0,
                )
                / 2
            )

            block.mlp.c_proj.bias.data = torch.cat([original_c_proj_bias] * 2, dim=0)

            # Layer Norm 2

            original_ln_2_weight = block.ln_2.weight.data
            original_ln_2_bias = block.ln_2.bias.data

            block.ln_2 = nn.LayerNorm(2 * self.config.n_embd)

            block.ln_2.weight.data = torch.cat([original_ln_2_weight] * 2, dim=0)
            block.ln_2.bias.data = torch.cat([original_ln_2_bias] * 2, dim=0)

        original_ln_f_weight = self.transformer.ln_f.weight.data
        original_ln_f_bias = self.transformer.ln_f.bias.data

        self.transformer.ln_f = nn.LayerNorm(2 * self.config.n_embd)

        self.transformer.ln_f.weight.data = torch.cat([original_ln_f_weight] * 2, dim=0)
        self.transformer.ln_f.bias.data = torch.cat([original_ln_f_bias] * 2, dim=0)

        original_lm_head = self.lm_head.weight.data

        self.lm_head = nn.Linear(
            2 * self.config.n_embd, self.config.vocab_size, bias=False
        )

        self.lm_head.weight.data = torch.cat([original_lm_head] * 2, dim=1) / 2

In [37]:
model = GPT.from_pretrained(lora=True, k_adapters=0, use_modular=False)

model = model.to(device)

model = torch.compile(model, backend="aot_eager")

In [11]:
# model.hyper_scale()

In [38]:
def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params


total_params, trainable_params = count_parameters(model)

print(f"Total Parameters: {total_params:,}")

Total Parameters: 133,876,992


In [162]:
total_batch_size = 16384
B = 4
T = model.config.block_size

assert total_batch_size % (B * T) == 0

grad_accumulation_steps = total_batch_size // (B * T)

print(
    f"Batch size: {total_batch_size}, Gradient accumulation steps: {grad_accumulation_steps}"
)

Batch size: 16384, Gradient accumulation steps: 4


In [163]:
class DataLoader:
    def __init__(self, B, T, fileName):
        self.B = B
        self.T = T

        with open(fileName, "r") as f:
            text = f.read()

        encoder = tiktoken.get_encoding("gpt2")
        tokens = encoder.encode(text)

        self.tokens = torch.tensor(tokens, dtype=torch.long)

        self.curr = self.B * self.T

    def next(self):
        B, T = self.B, self.T

        buffer = self.tokens[self.curr : self.curr + B * T + 1]

        x = buffer[:-1].view(B, T)
        y = buffer[1:].view(B, T)

        self.curr += B * T  

        if self.curr + (B * T + 1) > len(self.tokens):
            self.curr = B * T

        return x, y
    
    def reset(self):
        self.curr = 0


data_loader = DataLoader(B=B, T=T, fileName="data/current_events_dataset.txt")

In [164]:
optimizers = model.configure_optimizers(weight_decay=0.1, learning_rate=1e-4)

In [165]:
def eval2(fileName):
    model.eval()

    total_loss = 0.0
    total_batches = 0
    
    eval_data = torch.load(fileName, weights_only=True, map_location=device)

    x_eval = eval_data['context']
    y_eval = eval_data['target']

    with torch.no_grad():
        total_batches = x_eval.shape[0] // B
        
        for i in range(0, total_batches):
            x, y = x_eval[i : i + B], y_eval[i : i + B]    

            x, y = x.to(device), y.to(device)

            logits, _ = model(x)
            
            logits = logits[:, -1, :]
            
            loss = F.cross_entropy(logits, y)

            total_loss += loss.item()

    average_loss = total_loss / total_batches
    
    model.train()

    return average_loss

In [None]:
for name, param in model.named_parameters():
    if not any(x in name for x in ["lora_A", "lora_B", "adapters", "modular"]):
        param.requires_grad = False

In [166]:
i = 0

In [139]:
while i <= 10:
    for opt in optimizers:
        opt.zero_grad()

    loss_acum = 0.0

    for micro_step in range(grad_accumulation_steps):
        x, y = data_loader.next()
        x, y = x.to(device), y.to(device)

        logits, loss = model(x, y)

        loss = loss / grad_accumulation_steps

        loss_acum += loss.item()

        loss.backward()

    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    for opt in optimizers:
        opt.step()
    
    general_loss = eval2("data/general_eval.pt")
    current_events_loss = eval2("data/current_events_eval.pt")
    
    print(f"Epoch {i}, General Loss: {general_loss}, Current Events Loss: {current_events_loss}")
        
    i += 1

In [174]:
general_loss = eval2("data/general_eval.pt")
current_events_loss = eval2("data/current_events_eval.pt")

print("General Loss", general_loss)
print("Current Events Loss", current_events_loss)
print("Ratio", general_loss / current_events_loss)

General Loss 4.279356412887573
Current Events Loss 6.460156536102295
Ratio 0.6624230216361758
