# From Llama 3 to Llama 3.1

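This notebook adapts the Llama 3 implementation to Llama 3.1, starting with the components that carry over unchanged, such as the grouped-query attention module below.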

### Imports

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

### Grouped-query attention

- Unchanged from Llama 3.

In [3]:
class GroupedQueryAttention(nn.Module):
    """Causal multi-head attention with grouped key/value heads.

    The ``n_heads`` query heads are partitioned into ``n_kv_groups`` groups.
    Every query head in a group shares a single key head and value head,
    so the K/V projections are ``n_heads // n_kv_groups`` times smaller
    than the query projection. Queries and keys are passed through
    ``compute_rope`` using cos/sin tables from ``precompute_rope_params``
    (both defined elsewhere in this notebook) before attention scores are
    computed.

    Args:
        d_model: Width of the input and output embeddings.
        d_k: Per-head dimension of queries and keys.
        d_v: Per-head dimension of values.
        context_length: Maximum supported sequence length; sizes both the
            causal mask and the RoPE cos/sin tables.
        n_heads: Number of query heads.
        n_kv_groups: Number of key/value heads; must divide ``n_heads``.
        dtype: Optional dtype for the four linear projection weights.
    """

    def __init__(self, d_model, d_k, d_v, 
                    context_length, n_heads,
                    n_kv_groups, dtype=None):
        super().__init__()

        assert n_heads % n_kv_groups == 0, "Number of heads must be divisible by number of key-value groups"
        
        self.n_heads = n_heads
        self.n_kv_groups = n_kv_groups
        # Number of query heads that share each key/value head.
        self.group_size = n_heads // n_kv_groups
        self.d_k = d_k

        self.wq = nn.Linear(d_model, n_heads * d_k, bias=False, dtype=dtype)
        self.wk = nn.Linear(d_model, n_kv_groups * d_k, bias=False, dtype=dtype)
        self.wv = nn.Linear(d_model, n_kv_groups * d_v, bias=False, dtype=dtype)
        self.linear = nn.Linear(n_heads * d_v, d_model, bias=False, dtype=dtype)     
        
        # Strictly-upper-triangular ones: entry (t, s) is 1 where s > t,
        # i.e. the future positions that the causal mask blocks.
        self.register_buffer('mask', 
            torch.triu(torch.ones(context_length, context_length), 
            diagonal=1))   
            
        cos, sin = precompute_rope_params(d=self.d_k, context_length=context_length)
        self.register_buffer("cos", cos)
        self.register_buffer("sin", sin) 
        
    def forward(self, x):
        """Apply causal grouped-query attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)`` with
                ``seq_len <= context_length``.

        Returns:
            Tensor of shape ``(batch, seq_len, d_model)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the ``context_length`` the
                module was built with.
        """
        seq_len = x.size(1)
        max_len = self.mask.size(0)
        if seq_len > max_len:
            # Without this check the failure surfaces later as an opaque
            # shape/broadcast error from the mask or RoPE tables.
            raise ValueError(
                f"Sequence length {seq_len} exceeds context_length {max_len}"
            )

        q = rearrange(self.wq(x), 'b t (h k) -> b h t k', h=self.n_heads)
        k = rearrange(self.wk(x), 'b t (nkv k) -> b nkv t k', nkv=self.n_kv_groups)
        v = rearrange(self.wv(x), 'b t (nkv v) -> b nkv t v', nkv=self.n_kv_groups)

        # Rotating keys before the repeat below touches only n_kv_groups
        # heads instead of n_heads.
        q = compute_rope(q, self.cos, self.sin)
        k = compute_rope(k, self.cos, self.sin)

        # Expand each K/V head to its group of query heads; '(nkv gsz)'
        # places consecutive query heads in the same group, matching the
        # contiguous head layout produced by the wq rearrange above.
        k = repeat(k, 'b nkv t k -> b (nkv gsz) t k', gsz=self.group_size)
        v = repeat(v, 'b nkv t v -> b (nkv gsz) t v', gsz=self.group_size)
        
        attn = torch.einsum('bhtk, bhsk -> bhts', q, k) / self.d_k**0.5
        # Slice first, then convert: calling .bool() on the whole buffer
        # would allocate a fresh context_length x context_length tensor on
        # every forward pass.
        causal_mask = self.mask[:seq_len, :seq_len].bool()
        attn = attn.masked_fill(causal_mask, -torch.inf)
        attn = F.softmax(attn, dim=3)
        out = torch.einsum('bhts, bhsv -> bhtv', attn, v)
        out = rearrange(out, 'b h t v -> b t (h v)')
        return self.linear(out)