In [1]:
import torch
from torch import nn

## Multi-Head Attention

In [2]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention: queries, keys and values are all projected from ``x``.

    Args:
        n_dim: Model width, i.e. the size of the last dimension of the input.
        n_heads: Number of attention heads. Must be positive and divide
            ``n_dim`` evenly, because every head works on ``n_dim // n_heads``
            features.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide ``n_dim``.
    """

    def __init__(self, n_dim: int, n_heads: int):
        super().__init__()
        # Fail fast: with a non-divisible pair the ``view`` calls below either
        # raise an obscure shape error or, for some sequence lengths, silently
        # reshape the activations into a wrong layout.
        if n_heads <= 0 or n_dim % n_heads != 0:
            raise ValueError(
                f"n_dim ({n_dim}) must be divisible by a positive n_heads ({n_heads})"
            )
        self.n_dim = n_dim
        self.n_heads = n_heads
        self.head_dim = self.n_dim // self.n_heads
        self.q_linear = nn.Linear(n_dim, n_dim)
        self.k_linear = nn.Linear(n_dim, n_dim)
        self.v_linear = nn.Linear(n_dim, n_dim)
        self.out_linear = nn.Linear(n_dim, n_dim)

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None):
        """Apply self-attention to ``x``.

        Args:
            x: Input whose first dimension is the batch and whose last
                dimension is ``n_dim``, e.g. ``(batch, seq_len, n_dim)``.
            attention_mask: Optional mask that must broadcast against the
                score tensor ``(batch, n_heads, seq_len, seq_len)``. Positions
                where the mask equals 0 are excluded from attention. A query
                row whose positions are all masked produces NaN after softmax.

        Returns:
            Tensor of shape ``(batch, seq_len, n_dim)``.
        """
        bs = x.size(0)
        q = self.split_head(self.q_linear(x))
        k = self.split_head(self.k_linear(x))
        v = self.split_head(self.v_linear(x))
        # Scaled dot-product scores: (batch, n_heads, seq_len, seq_len).
        scores = torch.matmul(q, k.transpose(-1, -2)) / self.head_dim ** 0.5
        if attention_mask is not None:
            scores.masked_fill_(attention_mask == 0, float('-inf'))
        weights = torch.softmax(scores, dim=-1)
        out = torch.matmul(weights, v)
        # Merge the heads back: (batch, n_heads, seq_len, head_dim) -> (batch, seq_len, n_dim).
        out = out.transpose(1, 2).contiguous().view(bs, -1, self.n_dim)
        return self.out_linear(out)

    def split_head(self, x: torch.Tensor):
        """Reshape ``(batch, seq_len, n_dim)`` into ``(batch, n_heads, seq_len, head_dim)``."""
        bs = x.size(0)
        return x.view(bs, -1, self.n_heads, self.head_dim).transpose(1, 2)

In [3]:
# Smoke test: run a random batch through the module.
mha = MultiHeadAttention(32, 4)
# Named `mha_input` rather than `input` so the builtin `input()` is not shadowed.
mha_input = torch.rand((32, 16, 32))  # (batch, seq_len, n_dim)
out = mha(mha_input)
out

tensor([[[-0.2974, -0.1306,  0.0310,  ...,  0.0902,  0.1056, -0.0551],
         [-0.2978, -0.1315,  0.0307,  ...,  0.0893,  0.1045, -0.0551],
         [-0.2980, -0.1326,  0.0310,  ...,  0.0905,  0.1049, -0.0531],
         ...,
         [-0.2976, -0.1325,  0.0303,  ...,  0.0896,  0.1050, -0.0536],
         [-0.2970, -0.1319,  0.0312,  ...,  0.0891,  0.1049, -0.0551],
         [-0.2962, -0.1325,  0.0296,  ...,  0.0890,  0.1056, -0.0545]],

        [[-0.2957, -0.1326,  0.0064,  ...,  0.0912,  0.1377, -0.0308],
         [-0.2947, -0.1326,  0.0071,  ...,  0.0880,  0.1367, -0.0319],
         [-0.2959, -0.1317,  0.0072,  ...,  0.0897,  0.1383, -0.0318],
         ...,
         [-0.2951, -0.1347,  0.0057,  ...,  0.0880,  0.1391, -0.0294],
         [-0.2944, -0.1360,  0.0057,  ...,  0.0877,  0.1387, -0.0287],
         [-0.2929, -0.1330,  0.0068,  ...,  0.0883,  0.1396, -0.0303]],

        [[-0.3553, -0.1200,  0.0087,  ...,  0.0742,  0.1008, -0.0897],
         [-0.3562, -0.1199,  0.0106,  ...,  0

## Multi-Query Attention

In [4]:
class MultiQueryAttention(nn.Module):
    """Multi-query self-attention: many query heads share one key/value head.

    Args:
        n_dim: Model width, i.e. the size of the last dimension of the input.
        n_heads: Number of query heads. Must be positive and divide ``n_dim``
            evenly, because every head works on ``n_dim // n_heads`` features.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide ``n_dim``.
    """

    def __init__(self, n_dim: int, n_heads: int):
        super().__init__()
        # Fail fast: with a non-divisible pair the ``view`` calls below either
        # raise an obscure shape error or silently reshape into a wrong layout.
        if n_heads <= 0 or n_dim % n_heads != 0:
            raise ValueError(
                f"n_dim ({n_dim}) must be divisible by a positive n_heads ({n_heads})"
            )
        self.n_dim = n_dim
        self.n_heads = n_heads
        self.head_dim = self.n_dim // self.n_heads
        self.q_linear = nn.Linear(self.n_dim, self.n_dim)
        # Unlike MHA, keys and values are projected to a single head only.
        self.k_linear = nn.Linear(self.n_dim, self.head_dim)
        self.v_linear = nn.Linear(self.n_dim, self.head_dim)
        self.out_linear = nn.Linear(self.n_dim, self.n_dim)

    def split_head(self, x: torch.Tensor, n: int = None):
        """Reshape the last dimension of ``x`` into heads.

        Args:
            x: Tensor whose last dimension holds ``n * head_dim`` features.
            n: Number of heads to split into; defaults to ``self.n_heads``.

        Returns:
            Tensor of shape ``(batch, n, seq_len, head_dim)``.
        """
        bs = x.size(0)
        heads = self.n_heads if n is None else n
        return x.view(bs, -1, heads, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None):
        """Apply multi-query self-attention to ``x``.

        Args:
            x: Input whose first dimension is the batch and whose last
                dimension is ``n_dim``, e.g. ``(batch, seq_len, n_dim)``.
            attention_mask: Optional mask that must broadcast against the
                score tensor ``(batch, n_heads, seq_len, seq_len)``. Positions
                where the mask equals 0 are excluded from attention. A query
                row whose positions are all masked produces NaN after softmax.

        Returns:
            Tensor of shape ``(batch, seq_len, n_dim)``.
        """
        bs = x.size(0)
        q = self.split_head(self.q_linear(x))
        k = self.split_head(self.k_linear(x), 1)
        v = self.split_head(self.v_linear(x), 1)
        # Share the single K/V head with every query head (expand is a view, no copy).
        k = k.expand(-1, self.n_heads, -1, -1)
        v = v.expand(-1, self.n_heads, -1, -1)
        scores = torch.matmul(q, k.transpose(-1, -2)) / self.head_dim ** 0.5
        if attention_mask is not None:
            scores.masked_fill_(attention_mask == 0, float('-inf'))
        weights = torch.softmax(scores, dim=-1)
        out = torch.matmul(weights, v)
        # Merge the heads back: (batch, n_heads, seq_len, head_dim) -> (batch, seq_len, n_dim).
        out = out.transpose(1, 2).contiguous().view(bs, -1, self.n_dim)
        return self.out_linear(out)

In [5]:
# Smoke test: run a random batch through the module.
mqa = MultiQueryAttention(32, 4)
# Named `mqa_input` rather than `input` so the builtin `input()` is not shadowed.
mqa_input = torch.rand((32, 128, 32))  # (batch, seq_len, n_dim)
out = mqa(mqa_input)
out

tensor([[[ 0.1611, -0.2697,  0.1478,  ..., -0.1122,  0.0269, -0.0183],
         [ 0.1612, -0.2703,  0.1490,  ..., -0.1133,  0.0281, -0.0196],
         [ 0.1592, -0.2689,  0.1486,  ..., -0.1113,  0.0278, -0.0192],
         ...,
         [ 0.1600, -0.2695,  0.1490,  ..., -0.1119,  0.0273, -0.0190],
         [ 0.1612, -0.2701,  0.1485,  ..., -0.1125,  0.0271, -0.0187],
         [ 0.1608, -0.2704,  0.1494,  ..., -0.1120,  0.0258, -0.0192]],

        [[ 0.1663, -0.2458,  0.1352,  ..., -0.1233,  0.0326, -0.0444],
         [ 0.1665, -0.2458,  0.1347,  ..., -0.1230,  0.0330, -0.0437],
         [ 0.1661, -0.2460,  0.1356,  ..., -0.1226,  0.0324, -0.0449],
         ...,
         [ 0.1664, -0.2463,  0.1354,  ..., -0.1241,  0.0328, -0.0446],
         [ 0.1666, -0.2458,  0.1357,  ..., -0.1233,  0.0326, -0.0452],
         [ 0.1660, -0.2450,  0.1344,  ..., -0.1229,  0.0324, -0.0456]],

        [[ 0.1519, -0.2341,  0.1297,  ..., -0.1129,  0.0252, -0.0425],
         [ 0.1524, -0.2349,  0.1302,  ..., -0

## Grouped-Query Attention

In [6]:
class GroupQueryAttention(nn.Module):
    """Grouped-query self-attention: query heads are split into groups that share K/V heads.

    Args:
        n_dim: Model width, i.e. the size of the last dimension of the input.
        n_heads: Number of query heads. Must be positive and divide ``n_dim``.
        n_groups: Number of key/value heads. Must be positive and divide
            ``n_heads``; each K/V head serves ``n_heads // n_groups`` query heads.

    Raises:
        ValueError: If ``n_heads`` does not divide ``n_dim`` or ``n_groups``
            does not divide ``n_heads`` (or either is not positive).
    """

    def __init__(self, n_dim: int, n_heads: int, n_groups: int):
        super().__init__()
        # Fail fast: otherwise the mismatch only surfaces as an obscure shape
        # error inside ``view`` / ``matmul`` during the forward pass.
        if n_heads <= 0 or n_dim % n_heads != 0:
            raise ValueError(
                f"n_dim ({n_dim}) must be divisible by a positive n_heads ({n_heads})"
            )
        if n_groups <= 0 or n_heads % n_groups != 0:
            raise ValueError(
                f"n_heads ({n_heads}) must be divisible by a positive n_groups ({n_groups})"
            )
        self.n_dim = n_dim
        self.n_heads = n_heads
        self.head_dim = self.n_dim // self.n_heads
        self.n_groups = n_groups
        self.n_rep = self.n_heads // self.n_groups
        self.q_linear = nn.Linear(self.n_dim, self.n_dim)
        self.k_linear = nn.Linear(self.n_dim, self.n_groups * self.head_dim)
        self.v_linear = nn.Linear(self.n_dim, self.n_groups * self.head_dim)
        self.out_linear = nn.Linear(self.n_dim, self.n_dim)

    def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """Repeat every K/V head ``n_rep`` times along the head dimension.

        Args:
            x: Tensor of shape ``(batch, seq_len, kv_heads, head_dim)``.
            n_rep: Number of consecutive copies of each head.

        Returns:
            Tensor of shape ``(batch, seq_len, kv_heads * n_rep, head_dim)``;
            ``x`` itself when ``n_rep == 1``.
        """
        if n_rep == 1:
            return x
        bs, seq_len, kv_heads, head_dim = x.size()
        # Insert a repeat axis after the head axis, then fold it into the heads,
        # so copies of the same head end up adjacent.
        repeated = x[:, :, :, None, :].expand(bs, seq_len, kv_heads, n_rep, head_dim)
        return repeated.contiguous().view(bs, seq_len, kv_heads * n_rep, head_dim)

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None):
        """Apply grouped-query self-attention to ``x``.

        Args:
            x: Input of shape ``(batch, seq_len, n_dim)``.
            attention_mask: Optional mask that must broadcast against the
                score tensor ``(batch, n_heads, seq_len, seq_len)``. Positions
                where the mask equals 0 are excluded from attention. A query
                row whose positions are all masked produces NaN after softmax.

        Returns:
            Tensor of shape ``(batch, seq_len, n_dim)``.
        """
        bs, seq_len, n_dim = x.size()
        xq = self.q_linear(x).view(bs, -1, self.n_heads, self.head_dim)
        xk = self.k_linear(x).view(bs, -1, self.n_groups, self.head_dim)
        xv = self.v_linear(x).view(bs, -1, self.n_groups, self.head_dim)

        # Bring K/V up to n_heads heads, then move the head axis ahead of the sequence axis.
        xq = xq.transpose(1, 2)
        xk = self.repeat_kv(xk, self.n_rep).transpose(1, 2)
        xv = self.repeat_kv(xv, self.n_rep).transpose(1, 2)

        scores = torch.matmul(xq, xk.transpose(-1, -2)) / self.head_dim ** 0.5
        if attention_mask is not None:
            scores.masked_fill_(attention_mask == 0, float('-inf'))
        weights = torch.softmax(scores, dim=-1)
        out = torch.matmul(weights, xv)
        # Merge the heads back: (batch, n_heads, seq_len, head_dim) -> (batch, seq_len, n_dim).
        out = out.transpose(1, 2).contiguous().view(bs, seq_len, n_dim)
        return self.out_linear(out)
        

In [7]:
# Smoke test: run a random batch through the module (8 query heads, 4 K/V groups).
gqa = GroupQueryAttention(32, 8, 4)
# Named `gqa_input` rather than `input` so the builtin `input()` is not shadowed.
gqa_input = torch.rand((32, 128, 32))  # (batch, seq_len, n_dim)
out = gqa(gqa_input)
out

tensor([[[-0.3260,  0.0513,  0.1795,  ...,  0.0380, -0.2347, -0.2409],
         [-0.3239,  0.0509,  0.1806,  ...,  0.0374, -0.2339, -0.2407],
         [-0.3254,  0.0503,  0.1798,  ...,  0.0387, -0.2335, -0.2408],
         ...,
         [-0.3251,  0.0519,  0.1801,  ...,  0.0378, -0.2345, -0.2399],
         [-0.3257,  0.0500,  0.1795,  ...,  0.0373, -0.2338, -0.2403],
         [-0.3252,  0.0513,  0.1803,  ...,  0.0380, -0.2345, -0.2408]],

        [[-0.3143,  0.0495,  0.2043,  ...,  0.0436, -0.2519, -0.2410],
         [-0.3128,  0.0491,  0.2047,  ...,  0.0431, -0.2526, -0.2423],
         [-0.3124,  0.0496,  0.2054,  ...,  0.0433, -0.2519, -0.2411],
         ...,
         [-0.3127,  0.0495,  0.2062,  ...,  0.0436, -0.2523, -0.2411],
         [-0.3135,  0.0493,  0.2042,  ...,  0.0435, -0.2527, -0.2418],
         [-0.3135,  0.0498,  0.2055,  ...,  0.0431, -0.2530, -0.2413]],

        [[-0.3214,  0.0605,  0.2003,  ...,  0.0362, -0.2619, -0.2429],
         [-0.3227,  0.0615,  0.2004,  ...,  0

**Why do MQA and GQA accelerate inference?**

They share key/value heads across query heads, so the KV cache is smaller: less memory is used, and fewer keys and values have to be loaded from the cache at each decoding step.