In [18]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F
from dataclasses import dataclass
import math

In [19]:
@dataclass
class ModelConfig:
    """Hyperparameters for the model components in this notebook.

    Every attribute carries a type annotation so that ``@dataclass`` turns it
    into a real field: ``ModelConfig(d_model=256, n_h=4)`` builds a variant
    configuration, and equality/repr reflect the actual settings. Defaults
    remain readable as class attributes (e.g. ``ModelConfig.d_model``).
    """

    d_model: int = 512  # embedding dim
    n_h: int = 8  # number of attention heads; CausalAttention asserts it divides d_model
    seq_len: int = 512  # sequence length; also sizes CausalAttention's causal-mask buffer
    batch_size: int = 8  # batch size
    dff: int = 2048  # hidden dim (presumably the feed-forward inner size; not used in these cells)
    dp: float = 0.1  # dropout probability
    n_layer: int = 6  # number of layers
    vocab_size: int = 37000  # vocabulary size
    k: int = 2  # relative positions j - i are clipped to [-k, k] in CausalAttention

In [36]:
class CausalAttention(nn.Module):
    """Multi-head self-attention with learned relative position representations.

    The offset ``j - i`` between query position ``i`` and key position ``j`` is
    clipped to ``[-config.k, config.k]`` and used to look up learned vectors.
    These vectors are added to the keys when scoring (``look_up_table_k``) and
    to the values when aggregating (``look_up_table_v``). This is the scheme of
    Shaw et al. (2018), "Self-Attention with Relative Position Representations".

    Args:
        config: object exposing ``d_model``, ``n_h``, ``seq_len``, ``dp`` and ``k``.
        decoder: if True (default), apply a causal mask so position ``i`` only
            attends to positions ``<= i``; if False, attention is bidirectional.

    Positions enter only through clipped relative offsets, so the layer accepts
    sequences of any length. In decoder mode the registered ``mask`` buffer
    (sized ``config.seq_len``) is used when it is large enough. For longer
    inputs, an equivalent causal mask is built on the fly.
    """

    def __init__(self, config, decoder=True):
        super().__init__()
        assert config.d_model % config.n_h == 0
        self.config = config
        self.decoder = decoder
        # Fused projection producing queries, keys and values in one matmul.
        self.c_attn = nn.Linear(config.d_model, config.d_model * 3)
        self.c_proj = nn.Linear(config.d_model, config.d_model)
        self.dropout = nn.Dropout(config.dp)

        d_k = config.d_model // config.n_h

        # Lower-triangular ones: 1 where key position j <= query position i.
        # Kept as a registered buffer so the state_dict layout is unchanged.
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.seq_len, config.seq_len)).view(1, 1, config.seq_len, config.seq_len)
        )
        # One learned d_k-vector per clipped offset in [-k, k] (2k + 1 rows), shared
        # across heads. Zero-initialised, so at init the layer reduces to plain
        # scaled dot-product attention.
        self.look_up_table_k = nn.Parameter(torch.zeros((2 * config.k + 1, d_k)))
        self.look_up_table_v = nn.Parameter(torch.zeros((2 * config.k + 1, d_k)))

    def _future_mask(self, seq_len, device):
        """Return a boolean mask, True where key position j > query position i.

        Uses the precomputed ``mask`` buffer when it covers ``seq_len``;
        otherwise builds the identical mask directly.
        """
        if seq_len <= self.mask.size(-1):
            return self.mask[:, :, :seq_len, :seq_len] == 0
        return torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).triu(1)

    def forward(self, x):
        """Apply relative-position self-attention.

        Args:
            x: tensor of shape (batch, seq_len, d_model).

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        batch_size, seq_len, d_model = x.size()
        n_h = self.config.n_h
        d_k = d_model // n_h

        # Project once, split into q/k/v, then move heads forward:
        # (batch, n_h, seq_len, d_k).
        q, k, v = self.c_attn(x).split(d_model, dim=2)
        q = q.view(batch_size, seq_len, n_h, d_k).transpose(1, 2)
        k = k.view(batch_size, seq_len, n_h, d_k).transpose(1, 2)
        v = v.view(batch_size, seq_len, n_h, d_k).transpose(1, 2)

        # relative_pos[i, j] = clip(j - i, -k, k) + k, a row index into the lookup tables.
        i = torch.arange(seq_len, device=x.device).unsqueeze(1)
        j = torch.arange(seq_len, device=x.device).unsqueeze(0)
        relative_pos = torch.clamp(j - i, -self.config.k, self.config.k) + self.config.k

        # (seq_len, seq_len, d_k) relative embeddings for every (i, j) pair.
        aij_k = self.look_up_table_k[relative_pos]
        aij_v = self.look_up_table_v[relative_pos]

        # e_ij = q_i . (k_j + a^K_ij) / sqrt(d_k)
        raw_attention_score = q @ k.transpose(-2, -1)
        relative_pos_score = torch.einsum('bhid,ijd->bhij', q, aij_k)
        attention_score = raw_attention_score + relative_pos_score
        attention_score = attention_score * (1.0 / math.sqrt(d_k))

        if self.decoder:
            attention_score = attention_score.masked_fill(
                self._future_mask(seq_len, x.device), float('-inf')
            )

        attention_score = F.softmax(attention_score, dim=-1)
        attention_score = self.dropout(attention_score)

        # z_i = sum_j alpha_ij (v_j + a^V_ij)
        standard_out = attention_score @ v
        relative_v_out = torch.einsum('bhij,ijd->bhid', attention_score, aij_v)
        out = standard_out + relative_v_out

        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        return self.c_proj(out)

In [37]:
# Smoke test: run the layer once in bidirectional mode and check the output shape.
torch.manual_seed(0)  # reproducible weights and input
config = ModelConfig()
selfattn = CausalAttention(config=config, decoder=False).eval()  # eval() disables dropout

x = torch.randn((config.batch_size, config.seq_len, config.d_model))
with torch.no_grad():  # shape check only; no autograd graph needed
    out = selfattn(x)

assert out.shape == (config.batch_size, config.seq_len, config.d_model)
out.shape

torch.Size([8, 512, 512])
