In [23]:
import torch 

In [48]:
device ='cuda'

def get_cos_sin(seq_length, head_dim, theta=3.14):
    # take sequence length and head dimension as input.
    # return (seq_length, head_dim)
    assert head_dim%2==0
    i = torch.arange(head_dim//2).unsqueeze(0).to(device)
    position = torch.arange(seq_length).unsqueeze(1).to(device)
    angle = 1.0/(theta**(-2*i/head_dim))
    return torch.cos(position*angle).repeat(1,2), torch.sin(position*angle).repeat(1,2)

def attention(q,k,v,is_causal=True,mask=None):
    # take q,k,v of shape (batch_size, num_head, seq_length, head_dim) 
    assert len(q.shape)==4 
    assert is_causal==True or mask!=None
    batch_size, num_head, seq_length, head_dim = q.shape
    if is_causal:
        mask = torch.triu(torch.ones(seq_length, seq_length), diagonal=1).bool().to(device)
    scaled_dot_product = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(q.size(-1)))
    scaled_dot_product.masked_fill_(mask, float('-inf'))
    attention_weights = torch.nn.functional.softmax(scaled_dot_product, dim=-1)
    return torch.matmul(attention_weights, v)

In [49]:
# normal attention
torch.manual_seed(42)
batch_size=1
seq_length=12
hidden_dim=128
num_head=4
# num_kv_head=4
assert hidden_dim%num_head==0
head_dim=hidden_dim//num_head

cos, sin = get_cos_sin(seq_length, head_dim)
x = torch.randn(batch_size,seq_length, hidden_dim).to(device)

In [54]:
def rotate_half(x):
    # [x1,x2,x3,x4] -> [x3,x4,x1,x2]
    head_size = x.shape[-1]
    assert head_size%2==0
    x1 = x[..., : head_size // 2]  
    x2 = x[..., head_size // 2 :]  
    return torch.cat([-x2, x1], dim=-1)

def apply_rotary_pos_emb(x, cos, sin):
    # take x of shape (batch_size, num_heads, seq_length, head_dim)
    # cos, sin of shape (seq_length, head_dim)
    batch_size, num_head, seq_length, head_dim = x.size()
    assert cos.size(0)==seq_length
    assert cos.size(1)==head_dim
    x = x * cos + rotate_half(x) * sin
    return x

In [55]:
def CausalAttention(x,cos,sin):
    # take x of shape (batch_size, seq_length, hidden_dim)
    # return (batch_size, seq_length, hidden_dim)
    batch_size, seq_length, hidden_dim = x.shape
    query_prj = torch.nn.Linear(hidden_dim, hidden_dim).to(device)
    key_prj = torch.nn.Linear(hidden_dim, hidden_dim).to(device)
    value_prj = torch.nn.Linear(hidden_dim, hidden_dim).to(device)
    
    head_dim = hidden_dim//num_head
    query = query_prj(x).view(batch_size,seq_length, num_head, head_dim)
    value = value_prj(x).view(batch_size,seq_length, num_head, head_dim)
    key = key_prj(x).view(batch_size,seq_length, num_head, head_dim)
    query = query.transpose(1, 2) # batch_size, num_head, seq_length, head_dim
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)
    query = apply_rotary_pos_emb(query, cos, sin)
    key = apply_rotary_pos_emb(key, cos, sin)
    out = attention(query, key, value).transpose(1, 2).contiguous().view(batch_size,seq_length, hidden_dim)
    ref_out = torch.nn.functional.scaled_dot_product_attention(query, key, value,is_causal=True).transpose(1, 2).contiguous().view(batch_size,seq_length, hidden_dim)
    torch.testing.assert_close(ref_out, out, rtol=1e-6, atol=1e-6)
    # print(query.size(), key.size(), value.size())
    # print(out.size())
    return out

CausalAttention(x,cos,sin)

tensor([[[-0.3027,  0.2380,  0.6034,  ...,  0.5112,  0.7894, -0.6095],
         [-0.0117, -0.0066,  0.8311,  ...,  0.5305,  0.6811, -0.3407],
         [-0.2062,  0.0031,  0.3447,  ...,  0.2466,  0.3350, -0.1707],
         ...,
         [ 0.2361, -0.0146, -0.0264,  ..., -0.0757,  0.0509, -0.3383],
         [ 0.1223,  0.0140,  0.0098,  ..., -0.1104,  0.0818, -0.3325],
         [ 0.2322, -0.0486, -0.0757,  ..., -0.1011, -0.0754, -0.3290]]],
       device='cuda:0', grad_fn=<ViewBackward0>)

In [56]:
from dataclasses import dataclass
import torch.nn as nn

@dataclass
class Config:
    batch_size: int = 1
    seq_length: int = 12
    hidden_dim: int = 128
    num_heads: int = 4
    num_layers: int = 16

config = Config()

class DecoderLayer(nn.Module):
    def __init__(self, hidden_dim=128):
        self.layer_norm1 = torch.nn.LayerNorm(hidden_dim)
        self.layer_norm2 = torch.nn.LayerNorm(hidden_dim)
        self.attention = CausalAttention
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, 4*hidden_dim),
            nn.GELU(),
            nn.Linear(4*hidden_dim, hidden_dim)
        )
    def forward(self, x, cos, sin):
        out = self.attention(self.layer_norm1(x),cos,sin)
        out = self.mlp(self.layer_norm2(out))
        return out
    
class GPT(nn.Module):
    def __init__(self, config):
        self.num_heads = config.num_heads
        self.hidden_dim = config.hidden_dim
        self.vocab_size = config.vocab_size
        self.seq_length = config.seq_length
        self.num_layers = config.num_layers
        self.layers = nn.ModuleList([DecoderLayer(self.hidden_dim) for _ in range(self.num_layers)])
        self.cos, self.sin = get_cos_sin(self.seq_length, self.hidden_dim//self.num_heads)
    def forward(self, x):
        for layer in self.layers:
            x = layer(x, self.cos, self.sin)
        return x