# Multi-head attention (MHA) with a KV cache

In [1]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F

In [2]:
import math

## Multi-head attention without KV caching or masking

In [2]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with no KV cache and no attention mask.

    Every position in ``x`` attends to every position, both earlier and
    later, so the layer is bidirectional.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()

        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_kv = d_model // num_heads  # width of a single head

        # Bias-free query/key/value projections, created in q, k, v order.
        self.W_q, self.W_k, self.W_v = (
            nn.Linear(d_model, d_model, bias=False) for _ in range(3)
        )
        # Output projection that mixes the concatenated heads (has a bias).
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (batch, seq, d_model).

        Returns a tensor of shape (batch, seq, d_model).
        """
        batch, seq, _ = x.shape

        def heads_first(t):
            # (batch, seq, d_model) -> (batch, num_heads, seq, d_kv)
            return t.view(batch, seq, self.num_heads, self.d_kv).transpose(1, 2)

        Q, K, V = (heads_first(proj(x)) for proj in (self.W_q, self.W_k, self.W_v))

        # Scaled dot-product attention, computed independently per head.
        weights = F.softmax(Q @ K.transpose(-2, -1) / math.sqrt(self.d_kv), dim=-1)
        merged = (weights @ V).transpose(1, 2).contiguous().view(batch, seq, self.d_model)
        return self.W_o(merged)



## MHA with KV caching

In [6]:
class MultiHeadAttention_KVCache(nn.Module):
    """Multi-head self-attention that can extend a cache of past keys/values.

    No attention mask is applied: every query attends to every cached and
    every new key. That is fine when decoding one token per call. A call with
    ``seq > 1`` is not causal, because earlier tokens can see later ones.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()

        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_kv = d_model // num_heads  # width of a single head

        # Bias-free query/key/value projections, created in q, k, v order.
        self.W_q, self.W_k, self.W_v = (
            nn.Linear(d_model, d_model, bias=False) for _ in range(3)
        )
        # Output projection that mixes the concatenated heads (has a bias).
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None, use_cache=None):
        """Attend from the new tokens in ``x`` to cached + new keys/values.

        Args:
            x: input of shape (batch, seq, d_model).
            kv_cache: None, or a dict whose "k"/"v" tensors hold the keys and
                values of earlier calls, shaped (batch, num_heads, past_len, d_kv).
            use_cache: when truthy, return the concatenated keys/values
                (detached) as the next cache; otherwise the cache slot is None.

        Returns:
            (output, new_kv_cache), with output shaped (batch, seq, d_model).
        """
        batch, seq, _ = x.shape

        def heads_first(t):
            # (batch, seq, d_model) -> (batch, num_heads, seq, d_kv)
            return t.view(batch, seq, self.num_heads, self.d_kv).transpose(1, 2)

        Q, K, V = (heads_first(proj(x)) for proj in (self.W_q, self.W_k, self.W_v))

        if kv_cache is not None:
            cached_k, cached_v = kv_cache["k"], kv_cache["v"]
            print(f"Using KV cache: "
                  f"K cache shape {cached_k.shape}, "
                  f"V cache shape {cached_v.shape}")
            # Past positions first, then this call's tokens (sequence axis = 2).
            K = torch.cat((cached_k, K), dim=2)
            V = torch.cat((cached_v, V), dim=2)

        # Detached views: the cache must not keep earlier autograd graphs alive.
        new_kv_cache = {"k": K.detach(), "v": V.detach()} if use_cache else None

        print(f"Q shape: {Q.shape}")
        print(f"K shape (after cache): {K.shape}")
        print(f"V shape (after cache): {V.shape}")

        weights = F.softmax(Q @ K.transpose(-2, -1) / math.sqrt(self.d_kv), dim=-1)
        merged = (weights @ V).transpose(1, 2).contiguous().view(batch, seq, self.d_model)
        return self.W_o(merged), new_kv_cache

In [13]:
batch = 2
num_heads = 8
seq = 10
d_model = 64
num_decode_steps = 5  # decode only the first few of the `seq` tokens

mha_kv = MultiHeadAttention_KVCache(d_model, num_heads)

kv_cache = None

x = torch.randn(batch, seq, d_model)

# Full-sequence (training-style) use is a single call over all tokens, e.g.
#   output, _ = mha_kv(x, kv_cache=None, use_cache=False)
# This module applies no mask, so that call is not causal.

# Inference: feed one token per step and carry the growing cache forward.
for i in range(num_decode_steps):
    print(f"\n--------- step {i} -------------")
    x_token = x[:, i:i + 1, :]   # shape: (batch, 1, d_model)
    output, kv_cache = mha_kv(x_token, kv_cache=kv_cache, use_cache=True)



--------- step 0 -------------
Q shape: torch.Size([2, 8, 1, 8])
K shape (after cache): torch.Size([2, 8, 1, 8])
V shape (after cache): torch.Size([2, 8, 1, 8])

--------- step 1 -------------
Using KV cache: K cache shape torch.Size([2, 8, 1, 8]), V cache shape torch.Size([2, 8, 1, 8])
Q shape: torch.Size([2, 8, 1, 8])
K shape (after cache): torch.Size([2, 8, 2, 8])
V shape (after cache): torch.Size([2, 8, 2, 8])

--------- step 2 -------------
Using KV cache: K cache shape torch.Size([2, 8, 2, 8]), V cache shape torch.Size([2, 8, 2, 8])
Q shape: torch.Size([2, 8, 1, 8])
K shape (after cache): torch.Size([2, 8, 3, 8])
V shape (after cache): torch.Size([2, 8, 3, 8])

--------- step 3 -------------
Using KV cache: K cache shape torch.Size([2, 8, 3, 8]), V cache shape torch.Size([2, 8, 3, 8])
Q shape: torch.Size([2, 8, 1, 8])
K shape (after cache): torch.Size([2, 8, 4, 8])
V shape (after cache): torch.Size([2, 8, 4, 8])

--------- step 4 -------------
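Using KV cache: K cache shape torch.Size([2, 8, 4, 8]), V cache shape torch.Size([2, 8, 4, 8])
Q shape: torch.Size([2, 8, 1, 8])
K shape (after cache): torch.Size([2, 8, 5, 8])
V shape (after cache): torch.Size([2, 8, 5, 8])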

In [None]:
def calculate_kv_cache_size(kv_cache):
    """Return the number of bytes held by a KV cache's key and value tensors.

    Args:
        kv_cache: dict with "k" and "v" tensors (the cache an attention
            module's forward returns with ``use_cache=True``), or None.

    Returns:
        ``numel * element_size`` summed over K and V. Each tensor uses its own
        element size, so K and V may have different dtypes. A missing cache
        (None) counts as 0 bytes.
    """
    if kv_cache is None:
        return 0

    k = kv_cache["k"]
    v = kv_cache["v"]

    # Bytes per element of K (the label printed below is historical).
    precision = k.element_size()
    print(f"precision : {precision}")

    # Size each tensor with its own dtype; K's element size need not match V's.
    total_bytes = k.numel() * k.element_size() + v.numel() * v.element_size()
    print(f"k numel : {k.numel()}\nv numel : {v.numel()}")
    return total_bytes

size_bytes = calculate_kv_cache_size(kv_cache)
print(f"raw byte size : {size_bytes}")
# One MiB is 2**20 bytes, so this is the cache footprint in MiB.
size_mb = size_bytes / 2**20
print(f"size in mb : {size_mb}")

precision : 4
k numel : 640
v numel : 640
raw byte size : 5120
size in mb : 0.0048828125


## MHA with KV caching and causal masking

In [7]:
class MultiHeadAttention_KVCache_Masked(nn.Module):
    """Multi-head self-attention with an optional KV cache and a causal mask.

    With ``past_len`` cached keys, query ``i`` of the current call sits at
    absolute position ``past_len + i``. It may attend only to keys at
    positions ``<= past_len + i``. As a result, a full-sequence call, a
    token-by-token cached decode, and a prefill followed by cached decoding
    all produce the same outputs.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()

        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_kv = d_model // num_heads  # width of a single head

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None, use_cache=None):
        """Causally attend from the tokens in ``x`` to cached + new keys/values.

        Args:
            x: input of shape (batch, seq, d_model).
            kv_cache: None, or a dict whose "k"/"v" tensors hold the keys and
                values of earlier calls, shaped (batch, num_heads, past_len, d_kv).
            use_cache: when truthy, also return the concatenated (detached)
                keys/values so the next call can extend them.

        Returns:
            (output, new_kv_cache): output has shape (batch, seq, d_model);
            new_kv_cache is a {"k", "v"} dict, or None when use_cache is falsy.
        """
        batch, seq, _ = x.shape

        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        # (batch, seq, d_model) -> (batch, num_heads, seq, d_kv)
        Q = Q.view(batch, seq, self.num_heads, self.d_kv).transpose(1, 2)
        K = K.view(batch, seq, self.num_heads, self.d_kv).transpose(1, 2)
        V = V.view(batch, seq, self.num_heads, self.d_kv).transpose(1, 2)

        past_len = 0
        if kv_cache is not None:
            print( f"Using KV cache: "
                f"K cache shape {kv_cache['k'].shape}, "
                f"V cache shape {kv_cache['v'].shape}")
            past_len = kv_cache["k"].shape[2]
            # Prepend the cached keys/values to this call's (sequence axis = 2).
            K = torch.concat([kv_cache["k"], K], dim=2)
            V = torch.concat([kv_cache["v"], V], dim=2)

        if use_cache:
            # Detach so the cache does not keep earlier autograd graphs alive.
            new_kv_cache = {
                "k" : K.detach(),
                "v" : V.detach()
            }
        else:
            new_kv_cache = None

        # Query length is `seq`; key length covers cached (past_len) + new (seq).
        total_len = past_len + seq

        print(f"Q shape: {Q.shape}")
        print(f"K shape (after cache): {K.shape}")
        print(f"V shape (after cache): {V.shape}")

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_kv)

        # A single query sits at the last position and may see every key, so
        # masking is only needed when several new tokens arrive at once.
        if seq > 1:
            # True where key j is in the future of query i (j > past_len + i).
            # Allocate the mask on the scores' device. The old code always
            # built it on the CPU, which breaks masked_fill for CUDA inputs.
            future = torch.triu(
                torch.ones(seq, total_len, dtype=torch.bool, device=scores.device),
                diagonal=past_len + 1,
            )
            # (seq, total_len) broadcasts over (batch, num_heads, seq, total_len).
            scores = scores.masked_fill(future, float("-inf"))

        attn_weights = F.softmax(scores, dim=-1)
        attn_outputs = torch.matmul(attn_weights, V)

        attn_outputs = attn_outputs.transpose(1, 2).contiguous()
        attn_outputs = attn_outputs.view(batch, seq, self.d_model)

        output = self.W_o(attn_outputs)
        return output, new_kv_cache

In [8]:
batch = 2
num_heads = 8
seq = 10
d_model = 64
num_decode_steps = 5  # decode only the first few of the `seq` tokens

mha_kv = MultiHeadAttention_KVCache_Masked(d_model, num_heads)

kv_cache = None

x = torch.randn(batch, seq, d_model)

# Full-sequence (training-style) use is a single call over all tokens, e.g.
#   output, _ = mha_kv(x, kv_cache=None, use_cache=False)
# The causal mask makes that call match the step-by-step decode below.

# Inference: feed one token per step and carry the growing cache forward.
for i in range(num_decode_steps):
    print(f"\n--------- step {i} -------------")
    x_token = x[:, i:i + 1, :]   # shape: (batch, 1, d_model)
    output, kv_cache = mha_kv(x_token, kv_cache=kv_cache, use_cache=True)



--------- step 0 -------------
Q shape: torch.Size([2, 8, 1, 8])
K shape (after cache): torch.Size([2, 8, 1, 8])
V shape (after cache): torch.Size([2, 8, 1, 8])

--------- step 1 -------------
Using KV cache: K cache shape torch.Size([2, 8, 1, 8]), V cache shape torch.Size([2, 8, 1, 8])
Q shape: torch.Size([2, 8, 1, 8])
K shape (after cache): torch.Size([2, 8, 2, 8])
V shape (after cache): torch.Size([2, 8, 2, 8])

--------- step 2 -------------
Using KV cache: K cache shape torch.Size([2, 8, 2, 8]), V cache shape torch.Size([2, 8, 2, 8])
Q shape: torch.Size([2, 8, 1, 8])
K shape (after cache): torch.Size([2, 8, 3, 8])
V shape (after cache): torch.Size([2, 8, 3, 8])

--------- step 3 -------------
Using KV cache: K cache shape torch.Size([2, 8, 3, 8]), V cache shape torch.Size([2, 8, 3, 8])
Q shape: torch.Size([2, 8, 1, 8])
K shape (after cache): torch.Size([2, 8, 4, 8])
V shape (after cache): torch.Size([2, 8, 4, 8])

--------- step 4 -------------
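Using KV cache: K cache shape torch.Size([2, 8, 4, 8]), V cache shape torch.Size([2, 8, 4, 8])
Q shape: torch.Size([2, 8, 1, 8])
K shape (after cache): torch.Size([2, 8, 5, 8])
V shape (after cache): torch.Size([2, 8, 5, 8])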

In [5]:
def calculate_kv_cache_size(kv_cache):
    """Return the number of bytes held by a KV cache's key and value tensors.

    Args:
        kv_cache: dict with "k" and "v" tensors (the cache an attention
            module's forward returns with ``use_cache=True``), or None.

    Returns:
        ``numel * element_size`` summed over K and V. Each tensor uses its own
        element size, so K and V may have different dtypes. A missing cache
        (None) counts as 0 bytes.
    """
    if kv_cache is None:
        return 0

    k = kv_cache["k"]
    v = kv_cache["v"]

    # Bytes per element of K (the label printed below is historical).
    precision = k.element_size()
    print(f"precision : {precision}")

    # Size each tensor with its own dtype; K's element size need not match V's.
    total_bytes = k.numel() * k.element_size() + v.numel() * v.element_size()
    print(f"k numel : {k.numel()}\nv numel : {v.numel()}")
    return total_bytes

size_bytes = calculate_kv_cache_size(kv_cache)
print(f"raw byte size : {size_bytes}")
# One MiB is 2**20 bytes, so this is the cache footprint in MiB.
size_mb = size_bytes / 2**20
print(f"size in mb : {size_mb}")

precision : 4
k numel : 640
v numel : 640
raw byte size : 5120
size in mb : 0.0048828125


### What the causal mask does

In [12]:
randomM = torch.randn(5,5)
# Entries strictly above the diagonal mark "future" positions to hide.
mask = torch.ones(5, 5, dtype=torch.bool).triu(diagonal=1)
print(mask)
# Hidden scores become -inf, so softmax assigns them zero weight.
masked_rM = randomM.masked_fill(mask, float("-inf"))
print(masked_rM)

tensor([[False,  True,  True,  True,  True],
        [False, False,  True,  True,  True],
        [False, False, False,  True,  True],
        [False, False, False, False,  True],
        [False, False, False, False, False]])
tensor([[ 0.3363,    -inf,    -inf,    -inf,    -inf],
        [-0.0253,  0.9355,    -inf,    -inf,    -inf],
        [ 0.8504, -1.1353,  0.6112,    -inf,    -inf],
        [ 1.3716, -1.2834, -0.2104,  0.7902,    -inf],
        [ 1.2539,  0.3034, -0.9586,  1.9031,  0.3068]])
