In [158]:
import torch
import torch.nn as nn
import math

In [159]:
# Note: building blocks such as attention are deliberately implemented by hand here instead of using torch's built-in modules.

In [160]:
class SwiGLU(nn.Module):
    """Gated activation: ``SiLU(gate) * value``.

    One bias-free linear layer projects the input to ``2 * intermediate_size``
    features; the first half is used as the gate and the second half as the
    value.

    Args:
        dim_in: size of the last dimension of the input.
        intermediate_size: size of the last dimension of the output.
    """

    def __init__(self, dim_in, intermediate_size):
        super().__init__()
        self.dim_in = dim_in
        # Gate and value projections are fused into a single matmul.
        self.fc = nn.Linear(dim_in, intermediate_size * 2, bias=False)
        self.swish = nn.SiLU()

    def forward(self, x):
        """Map ``[..., dim_in]`` to ``[..., intermediate_size]``."""
        projected = self.fc(x)  # [..., 2 * intermediate_size]
        half = projected.shape[-1] // 2
        gate = projected[..., :half]  # [..., intermediate_size]
        val = projected[..., half:]   # [..., intermediate_size]
        return val * self.swish(gate)

In [161]:
class RoPE:
    """Rotary position embedding over interleaved (even, odd) feature pairs.

    For position ``m`` and pair index ``i`` the rotation angle is
    ``m * base ** (-2 * i / d)``. Cosine/sine tables are cached in float32 and
    grown lazily as longer sequences are seen; they are rebuilt when the
    feature size or the device changes.

    Args:
        base: frequency base of the rotation (default ``1e6``).

    Attributes:
        cos_cache, sin_cache: ``(cached_seq, cached_d // 2)`` float32 tables,
            or ``None`` before first use.
        theta: ``(cached_d // 2,)`` per-pair frequencies.
        cached_seq: number of positions currently cached.
        cached_d: feature size the cache was built for.
    """

    def __init__(self, base=1e6):
        self.base = base
        self.cos_cache = None
        self.sin_cache = None
        self.theta = None
        self.cached_seq = 0
        self.cached_d = 0

    @staticmethod
    def _device_matches(actual, requested):
        """Return True if `actual` satisfies `requested` (no index = any index)."""
        return actual.type == requested.type and (
            requested.index is None or actual.index == requested.index
        )

    def _angles(self, start, stop, device):
        """Return the ``(stop - start, d // 2)`` angle table for positions [start, stop)."""
        # Positions are float32 so they share a dtype with `theta`.
        positions = torch.arange(start, stop, dtype=torch.float32, device=device)
        return positions[:, None] * self.theta[None, :]

    def get_rot_cached(self, d, seq_length, device=None):
        """Return ``(cos, sin)`` tables of shape ``(seq_length, d // 2)``.

        Args:
            d: feature size of the tensors to be rotated.
            seq_length: number of positions required.
            device: device the tables must live on. ``None`` keeps the cache
                where it already is (CPU on first use).

        Returns:
            Two float32 tensors that are slices of the internal cache.
        """
        half_d = d // 2

        if device is not None:
            device = torch.device(device)
        elif self.cos_cache is not None:
            device = self.cos_cache.device
        else:
            device = torch.device("cpu")

        needs_rebuild = (
            self.cos_cache is None
            or d != self.cached_d
            or not self._device_matches(self.cos_cache.device, device)
        )

        if needs_rebuild:
            exponents = torch.arange(half_d, dtype=torch.float32, device=device)
            self.theta = self.base ** (-2 * exponents / d)
            # Never shrink: keep at least as many positions as were cached before.
            length = max(seq_length, self.cached_seq)
            angles = self._angles(0, length, device)
            self.cos_cache = torch.cos(angles)
            self.sin_cache = torch.sin(angles)
            self.cached_d = d
            self.cached_seq = length
        elif seq_length > self.cached_seq:
            # Only compute the rows that are missing and append them.
            extra = self._angles(self.cached_seq, seq_length, self.cos_cache.device)
            self.cos_cache = torch.cat([self.cos_cache, torch.cos(extra)], dim=0)
            self.sin_cache = torch.cat([self.sin_cache, torch.sin(extra)], dim=0)
            self.cached_seq = seq_length

        return self.cos_cache[:seq_length, :half_d], self.sin_cache[:seq_length, :half_d]

    def apply(self, t):
        """Rotate `t` according to the position of each row along dim ``-2``.

        Args:
            t: tensor whose last two dims are ``(seq_length, d)``, e.g.
                ``(batch_size, heads, seq_length, d)``; ``d`` must be even.

        Returns:
            Tensor with the same shape, dtype and device as `t`.

        Raises:
            ValueError: if ``d`` is odd (features are rotated in pairs).
        """
        seq_length, d = t.shape[-2:]
        if d % 2 != 0:
            raise ValueError(f"RoPE needs an even feature dimension, got {d}")

        # Build/fetch the tables on the same device as the input.
        r_cos, r_sin = self.get_rot_cached(d, seq_length, device=t.device)
        cos = r_cos.repeat_interleave(2, dim=-1)
        sin = r_sin.repeat_interleave(2, dim=-1)

        # (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
        t_conj = torch.stack((-t[..., 1::2], t[..., 0::2]), dim=-1).flatten(-2)

        rotated = t * cos + t_conj * sin
        # The float32 tables promote half/bf16 inputs; cast back so that the
        # result can be mixed with tensors of the input dtype downstream.
        return rotated.to(t.dtype)

In [162]:
class GroupedQueryAttention(nn.Module):
    """Grouped-query attention: `group_size` query heads share one key/value head.

    Args:
        num_q_heads: number of query heads.
        group_size: number of query heads per key/value head.
        dim_model: size of the model (input and output) features.
        dim_k: per-head size of queries, keys and values (dim_k == dim_v).
        dropout: dropout probability applied to the attention weights.
        enable_rope: when True, queries and keys are rotated with `RoPE`.
    """

    def __init__(self, num_q_heads, group_size, dim_model, dim_k, dropout=0.1, enable_rope = True):
        super().__init__()
        assert dim_model % num_q_heads == 0, "dim_model must be divisible by num_q_heads"
        assert num_q_heads % group_size == 0, "num_q_heads must be divisible by group_size"

        self.enable_rope = enable_rope
        self.group_size = group_size
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_q_heads // group_size
        self.dim_model = dim_model
        self.dim_k = dim_k  # keys and values share the same per-head size

        self.rope = RoPE() if self.enable_rope else None

        q_width = self.dim_k * self.num_q_heads
        kv_width = self.dim_k * self.num_kv_heads
        self.q_proj = nn.Linear(dim_model, q_width, bias=False)
        self.k_proj = nn.Linear(dim_model, kv_width, bias=False)
        self.v_proj = nn.Linear(dim_model, kv_width, bias=False)
        self.fc = nn.Linear(q_width, dim_model, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Attend over `x`.

        Args:
            x: tensor of shape (batch_size, seq_length, dim_model).
            mask: optional tensor broadcastable to
                (batch_size, num_kv_heads, group_size, seq_length, seq_length);
                positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (batch_size, seq_length, dim_model).
        """
        bsz, seq_len = x.shape[:2]

        def split_heads(projected, n_heads):
            # (bsz, seq_len, n_heads * dim_k) -> (bsz, n_heads, seq_len, dim_k)
            return projected.view(bsz, seq_len, n_heads, self.dim_k).transpose(1, 2)

        q = split_heads(self.q_proj(x), self.num_q_heads)
        k = split_heads(self.k_proj(x), self.num_kv_heads)
        v = split_heads(self.v_proj(x), self.num_kv_heads)

        if self.enable_rope:
            q = self.rope.apply(q)
            k = self.rope.apply(k)

        # Fold the query heads into (kv head, member of group) so that every
        # group broadcasts against its single key/value head.
        grouped_q = q.view(bsz, self.num_kv_heads, self.group_size, seq_len, self.dim_k)
        keys_t = k.transpose(-1, -2).unsqueeze(2)  # (bsz, num_kv_heads, 1, dim_k, seq_len)

        # (bsz, num_kv_heads, group_size, seq_len, seq_len)
        scores = torch.matmul(grouped_q, keys_t) / math.sqrt(self.dim_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        probs = self.dropout(nn.functional.softmax(scores, dim=-1))

        values = v.unsqueeze(2)  # (bsz, num_kv_heads, 1, seq_len, dim_k)
        context = torch.matmul(probs, values)  # (bsz, num_kv_heads, group_size, seq_len, dim_k)

        # Undo the grouping, then merge heads back into the feature dimension.
        context = context.contiguous().view(bsz, self.num_q_heads, seq_len, self.dim_k)
        merged = context.transpose(1, 2).contiguous().view(bsz, seq_len, self.num_q_heads * self.dim_k)
        return self.fc(merged)

In [163]:
class DecoderLayer(nn.Module):
    """One decoder block: causal grouped-query attention followed by a SwiGLU FFN.

    Both sub-blocks normalise their input with RMSNorm first and add the
    (dropped-out) result back onto the residual stream.

    Args:
        dim_model: size of the residual-stream features.
        dim_k: per-head size used by the attention.
        num_q_heads: number of query heads.
        group_size: number of query heads per key/value head.
        Intermediate_size: width of the SwiGLU output fed to `down_proj`.
        eps: epsilon of both RMSNorm layers.
        dropout: probability used by every dropout in the block.
    """

    def __init__(self, dim_model, dim_k, num_q_heads, group_size, Intermediate_size, eps=1e-6, dropout=0.1):
        super().__init__()

        # Attention sub-block
        self.gq_attn = GroupedQueryAttention(
            num_q_heads, group_size, dim_model, dim_k,
            dropout=dropout, enable_rope=True,
        )
        self.rms_norm_1 = nn.RMSNorm(normalized_shape=dim_model, eps=eps)
        self.attention_dropout = nn.Dropout(p=dropout)

        # Feed-forward sub-block
        self.swiglu = SwiGLU(dim_in=dim_model, intermediate_size=Intermediate_size)
        self.down_proj = nn.Linear(Intermediate_size, dim_model, bias=False)
        self.ffn_dropout = nn.Dropout(p=dropout)
        self.rms_norm_2 = nn.RMSNorm(normalized_shape=dim_model, eps=eps)

        self._initialize_weights()

    def _initialize_weights(self):
        """Linear weights ~ N(0, 0.02), linear biases = 0, RMSNorm weights = 1."""
        for module in self.modules():
            if isinstance(module, nn.RMSNorm):
                nn.init.ones_(module.weight)
                continue
            if not isinstance(module, nn.Linear):
                continue
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Run the block on `x`, whose first two dims are (batch_size, seq_length)."""
        _, seq_length = x.shape[:2]

        # Lower-triangular causal mask, given three leading singleton dims so it
        # broadcasts over the attention scores.
        causal = torch.ones(seq_length, seq_length, device=x.device).tril().bool()
        causal = causal[None, None, None, :, :]

        attn_out = self.gq_attn(self.rms_norm_1(x), mask=causal)
        x = self.attention_dropout(attn_out) + x

        ffn_out = self.down_proj(self.swiglu(self.rms_norm_2(x)))
        return self.ffn_dropout(ffn_out) + x

In [164]:
class Proto2(nn.Module):
    """Decoder-only language model: embedding -> decoder stack -> RMSNorm -> logits.

    The output projection shares its weight tensor with the token embedding.

    Args:
        vocab_size: number of token ids.
        dim_model: size of the residual-stream features.
        dim_k: per-head size used by the attention.
        num_q_heads: number of query heads.
        group_size: number of query heads per key/value head.
        num_decoder_layers: number of `DecoderLayer` blocks.
        Intermediate_size: FFN width passed to every decoder layer.
        eps: epsilon of the RMSNorm layers.
        dropout: dropout probability (embeddings and every decoder layer).
    """

    def __init__(self, vocab_size, dim_model, dim_k, num_q_heads, group_size, num_decoder_layers, Intermediate_size, eps=1e-6, dropout=0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, dim_model)

        layers = []
        for _ in range(num_decoder_layers):
            layers.append(
                DecoderLayer(dim_model, dim_k, num_q_heads, group_size, Intermediate_size, eps, dropout=dropout)
            )
        self.decoder_layers = nn.ModuleList(layers)

        self.rms_norm = nn.RMSNorm(dim_model, eps)  # final norm before the head
        self.output_head = nn.Linear(dim_model, vocab_size, bias=False)
        # Weight tying: the head reuses the embedding matrix.
        self.output_head.weight = self.token_embedding.weight
        self.dropout = nn.Dropout(dropout)
        self._initialize_weights()

    def _initialize_weights(self):
        """Linear/Embedding weights ~ N(0, 0.02), linear biases = 0, RMSNorm weights = 1."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                bias = module.bias
                if bias is not None:
                    nn.init.zeros_(bias)
                continue
            if isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                continue
            if isinstance(module, nn.RMSNorm):
                nn.init.ones_(module.weight)

    def forward(self, input_ids, targets=None):
        """Compute logits and, when `targets` is given, the next-token loss.

        Args:
            input_ids: token ids fed to the embedding.
            targets: optional token ids; position ``i`` of the logits is scored
                against position ``i + 1`` of `targets`.

        Returns:
            `logits` when `targets` is None, otherwise ``(logits, loss)``.
        """
        hidden = self.dropout(self.token_embedding(input_ids))
        for block in self.decoder_layers:
            hidden = block(hidden)
        logits = self.output_head(self.rms_norm(hidden))

        if targets is None:
            return logits

        # Drop the last prediction and the first target so they line up.
        predictions = logits[..., :-1, :].contiguous()
        expected = targets[..., 1:].contiguous()
        loss = nn.functional.cross_entropy(
            predictions.view(-1, predictions.size(-1)),
            expected.view(-1),
        )
        return logits, loss

In [165]:
vocab_size = 151936 # random for demo

# All hyperparameters in one mapping so they can be inspected or tweaked together.
model_config = {
    "vocab_size": vocab_size,
    "dim_model": 1024,
    "dim_k": 1024 // 8,  # possibly derived from dim_model // (num_q_heads // group_size)
    "num_q_heads": 16,
    "group_size": 2,
    "num_decoder_layers": 28,
    "Intermediate_size": 3072,
    "eps": 1e-6,
    "dropout": 0.1,  # should drop to 0.0 in finetuning
}

model = Proto2(**model_config)

In [166]:
# List every parameter with its shape; the running total is printed after each one.
total = 0
for name, param in model.named_parameters():
    total += param.numel()
    print(f"Parameter name: {name}")
    print(f"Parameter shape: {param.shape}\n")
    print(f"Total params: {total}\n")



Parameter name: token_embedding.weight
Parameter shape: torch.Size([151936, 1024])

Total params: 155582464

Parameter name: decoder_layers.0.gq_attn.q_proj.weight
Parameter shape: torch.Size([2048, 1024])

Total params: 157679616

Parameter name: decoder_layers.0.gq_attn.k_proj.weight
Parameter shape: torch.Size([1024, 1024])

Total params: 158728192

Parameter name: decoder_layers.0.gq_attn.v_proj.weight
Parameter shape: torch.Size([1024, 1024])

Total params: 159776768

Parameter name: decoder_layers.0.gq_attn.fc.weight
Parameter shape: torch.Size([1024, 2048])

Total params: 161873920

Parameter name: decoder_layers.0.rms_norm_1.weight
Parameter shape: torch.Size([1024])

Total params: 161874944

Parameter name: decoder_layers.0.swiglu.fc.weight
Parameter shape: torch.Size([6144, 1024])

Total params: 168166400

Parameter name: decoder_layers.0.down_proj.weight
Parameter shape: torch.Size([1024, 3072])

Total params: 171312128

Parameter name: decoder_layers.0.rms_norm_2.weight
Par