In [5]:
import torch
import torch.nn as nn
from einops import rearrange,repeat

In [6]:
def scaled_dot_product_attention(q,k,v,is_causal=False):
    """Scaled dot-product attention over per-head tensors.

    Shapes:
        q:       b x n x t1 x hd
        k, v:    b x n x t2 x hd
        scores:  b x n x t1 x t2
        returns: b x n x t1 x hd

    Args:
        q, k, v: query / key / value tensors; the last axis is the head dimension.
        is_causal: if True, query position i may only attend to key positions 0..i
            (lower-triangular mask aligned to the top-left of the t1 x t2 score matrix).

    Returns:
        softmax(q @ k^T / sqrt(hd)) @ v
    """
    # Scale by the head dimension (last axis). q.size(2) is the sequence length t1,
    # which silently gives the wrong temperature whenever t1 != hd.
    scale = 1 / q.size(-1) ** 0.5

    scores = (q @ k.transpose(-1, -2)) * scale

    if is_causal:
        t1, t2 = scores.size(-2), scores.size(-1)
        # One (t1, t2) boolean mask broadcasts over batch and heads, instead of
        # materialising a full b x n x t1 x t2 tensor of ones.
        keep = torch.ones(t1, t2, dtype=torch.bool, device=scores.device).tril()
        scores = scores.masked_fill(~keep, float('-inf'))

    weights = scores.softmax(dim=-1)
    return weights @ v

# Multi-Head Attention


In [7]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with one fused projection for queries, keys and values.

    Args:
        dim: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads, each of width ``dim // n_heads``.
        is_causal: forwarded to ``scaled_dot_product_attention`` to request a causal mask.
        qkv_bias: whether the fused q/k/v projection carries a bias term.
    """

    def __init__(self, dim, n_heads, is_causal = False, qkv_bias=False):
        super().__init__()

        assert dim % n_heads == 0, 'dim should be div by num heads'
        self.dim, self.n_heads, self.is_causal = dim, n_heads, is_causal
        self.head_dim = dim // n_heads

        # Submodules are created in the same order (qkv, then proj) as before.
        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def _to_heads(self, t, batch, seq_len):
        """Reshape (batch, seq, dim) -> (batch, n_heads, seq, head_dim)."""
        return t.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Self-attend over ``x`` of shape (batch, seq, dim); returns the same shape."""
        batch, seq_len, width = x.shape

        q, k, v = (self._to_heads(part, batch, seq_len) for part in self.qkv(x).chunk(3, dim=-1))

        context = scaled_dot_product_attention(q, k, v, is_causal=self.is_causal)
        # merge heads back: (batch, n_heads, seq, head_dim) -> (batch, seq, dim)
        merged = context.transpose(1, 2).reshape(batch, seq_len, width)

        return self.proj(merged)

In [8]:
mha = MultiHeadAttention(dim=512, n_heads=8, is_causal=True)
mha(torch.rand(1, 128, 512)).size()  # expected: torch.Size([1, 128, 512])

torch.Size([1, 128, 512])

# Cross Attention

In [9]:
class MultiHeadCrossAttention(nn.Module):
    """Multi-head cross-attention: queries from ``decoder_out``, keys/values from ``encoder_out``.

    Args:
        dim: decoder (query) width and output width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        q_bias: whether the query projection has a bias term.
        kv_bias: whether the fused key/value projection has a bias term.
        kv_dim: last-axis width of ``encoder_out``. Defaults to ``dim``; set it when
            the encoder and decoder have different model widths.
    """
    def __init__(self, dim, n_heads, q_bias=False, kv_bias=False, kv_dim=None):
        super().__init__()
        
        assert dim % n_heads == 0, 'dim should be div by num heads'
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = self.dim // self.n_heads
        self.kv_dim = dim if kv_dim is None else kv_dim
        
        self.q = nn.Linear(self.dim,self.dim,bias=q_bias)
        # keys/values are projected from the encoder width into the decoder width
        self.kv = nn.Linear(self.kv_dim,self.dim*2,bias=kv_bias)
        self.proj = nn.Linear(self.dim, self.dim)
        
    def forward(self, decoder_out, encoder_out):
        """decoder_out: (b, t_dec, dim), encoder_out: (b, t_enc, kv_dim) -> (b, t_dec, dim)."""
        b,s,d = decoder_out.size()
        
        q = self.q(decoder_out)
        k,v = self.kv(encoder_out).chunk(2,dim=-1)
        
        # (b, t, dim) -> (b, n_heads, t, head_dim); k/v keep their own length t_enc
        q = q.view(b,q.size(1),self.n_heads,self.head_dim).permute(0,2,1,3)
        k = k.view(b,k.size(1),self.n_heads,self.head_dim).permute(0,2,1,3)
        v = v.view(b,v.size(1),self.n_heads,self.head_dim).permute(0,2,1,3)
        
        # no causal mask: every decoder position may attend to every encoder position
        attention = scaled_dot_product_attention(q,k,v,is_causal=False)
        attention = attention.permute(0,2,1,3).contiguous().view(b,s,d)
        
        return self.proj(attention)

In [10]:
# decoder has 128 positions, encoder 100; both use width 512
dec_out, enc_out = torch.rand(1, 128, 512), torch.rand(1, 100, 512)
mhca = MultiHeadCrossAttention(dim=512, n_heads=8)
mhca(dec_out, enc_out).size()  # expected: torch.Size([1, 128, 512])

torch.Size([1, 128, 512])

# Grouped-Query Attention

In [11]:
class GroupedQueryAttention(nn.Module):
    """Grouped-query self-attention (GQA).

    Queries use ``n_heads`` heads while keys/values use only ``n_groups`` heads;
    each k/v head is shared by ``n_heads // n_groups`` consecutive query heads.

    Args:
        dim: model width; must be divisible by ``n_heads``.
        n_heads: number of query heads; must be divisible by ``n_groups``.
        n_groups: number of key/value heads.
        is_causal: apply a causal mask (default True, matching the previous fixed behavior).
    """
    def __init__(self,dim,n_heads,n_groups,is_causal=True):
        super().__init__()
        
        assert dim % n_heads == 0, 'dim should be div by n_heads'
        assert n_heads % n_groups == 0, 'n_heads should be div by n_groups'
        
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = self.dim // self.n_heads
        self.n_groups = n_groups
        self.n_repeats = self.n_heads // self.n_groups
        self.is_causal = is_causal
        
        # queries get n_heads heads; keys and values get only n_groups heads each
        self.q = nn.Linear(self.dim,self.head_dim*self.n_heads)
        self.kv = nn.Linear(self.dim,self.head_dim*self.n_groups*2)
        self.proj = nn.Linear(self.dim, self.dim)
        
    def forward(self,x):
        """x: (b, s, dim) -> (b, s, dim)."""
        b,s,d = x.shape
        
        q = self.q(x)
        k,v = self.kv(x).chunk(2,dim=-1)
        
        q = q.view(b,s,self.n_heads,self.head_dim).permute(0,2,1,3)
        k = k.view(b,s,self.n_groups,self.head_dim).permute(0,2,1,3)
        v = v.view(b,s,self.n_groups,self.head_dim).permute(0,2,1,3)
        
        # repeat interleave: [1,2] * 3 => [1,1,1,2,2,2]
        # k,v: b x n_groups x s x h => b x n_heads x s x h
        k = k.repeat_interleave(self.n_repeats, dim=1)
        v = v.repeat_interleave(self.n_repeats, dim=1)
        
        attention = scaled_dot_product_attention(q,k,v,is_causal=self.is_causal)
        attention = attention.permute(0,2,1,3).contiguous().view(b,s,d)
        
        return self.proj(attention)

In [12]:
x = torch.rand(1, 12, 16)
gqa = GroupedQueryAttention(dim=16, n_heads=8, n_groups=4)
gqa(x).size()  # expected: torch.Size([1, 12, 16])

torch.Size([1, 12, 16])

### Grouped-Query Attention with einops

In [21]:
class GQA(nn.Module):
    """Grouped-query self-attention, written with einops ``rearrange``/``repeat``.

    Args:
        dim: model width; must be divisible by ``n_heads``.
        n_heads: number of query heads; must be divisible by ``n_groups``.
        n_groups: number of key/value heads, each shared by ``n_heads // n_groups`` query heads.
        is_causal: apply a causal mask (default True, matching the previous fixed behavior).
    """
    def __init__(self,dim,n_heads,n_groups,is_causal=True):
        super().__init__()
        
        assert dim % n_heads == 0, 'dim should be div by n_heads'
        assert n_heads % n_groups == 0, 'n_heads should be div by n_groups'
        
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = self.dim // self.n_heads
        self.n_groups = n_groups
        self.n_repeats = self.n_heads // self.n_groups
        self.is_causal = is_causal
        
        self.q = nn.Linear(self.dim,self.head_dim*self.n_heads)
        self.kv = nn.Linear(self.dim,self.head_dim*self.n_groups*2)
        self.proj = nn.Linear(self.dim, self.dim)
        
    def forward(self,x):
        """x: (b, s, dim) -> (b, s, dim)."""
        b,s,d = x.shape
        
        q = self.q(x)
        k,v = self.kv(x).chunk(2,dim=-1)
        
        q = rearrange(q,'b s (n h) -> b n s h',n=self.n_heads, h=self.head_dim)
        k = rearrange(k,'b s (g h) -> b g s h',g=self.n_groups, h=self.head_dim)
        v = rearrange(v,'b s (g h) -> b g s h',g=self.n_groups, h=self.head_dim)
        
        # '(g r)' puts the repeat axis innermost, i.e. repeat-interleave over groups
        k = repeat(k,'b g s h -> b (g r) s h',r=self.n_repeats)
        v = repeat(v,'b g s h -> b (g r) s h',r=self.n_repeats)
        
        attention = scaled_dot_product_attention(q,k,v,is_causal=self.is_causal)
        attention = rearrange(attention,'b n s h -> b s (n h)')
        
        return self.proj(attention)

In [22]:
x = torch.rand(1, 12, 16)
gqa = GQA(dim=16, n_heads=8, n_groups=4)
gqa(x).size()  # expected: torch.Size([1, 12, 16])

torch.Size([1, 12, 16])