In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

# Model hyper-parameters shared by the cells below.
embed_size = 256    # width of the token embedding fed to each attention head
dropout = 0         # dropout probability (0 disables dropout)
block_size = 1024   # context length; also the capacity of the KV cache

In [2]:
class Head(nn.Module):
  """Single causal self-attention head with a sliding-window key/value cache.

  Each call to ``forward`` appends the keys/values of the new tokens to a
  cache holding at most ``block_size`` positions. When the cache is full the
  oldest entries are dropped so the newest ``block_size`` tokens are kept.
  Attention is computed only over the cache slots that have actually been
  written.

  The cache persists across calls; call ``reset_cache()`` before starting a
  new, unrelated sequence. The cache is meant for incremental decoding
  (typically under ``torch.no_grad()``).

  Relies on the module-level ``embed_size``, ``dropout`` and ``block_size``.

  Args:
    head_size: output dimension of the query/key/value projections.
  """

  def __init__(self, head_size):
    super().__init__()
    # forward() needs the head size for cache allocation and score scaling.
    self.head_size=head_size
    self.query=nn.Linear(embed_size, head_size, bias=False)
    self.key=nn.Linear(embed_size, head_size, bias=False)
    self.value=nn.Linear(embed_size, head_size, bias=False)
    self.dropout=nn.Dropout(dropout)

    # Lower-triangular causal mask: row i is True for key positions 0..i.
    # Non-persistent so it follows .to(device) without entering state_dict.
    self.register_buffer(
      'tril',
      torch.tril(torch.ones(block_size, block_size, dtype=torch.bool)),
      persistent=False,
    )

    #caching:
    self.k_cache=None
    self.v_cache=None
    self.cache_index=0  # number of valid (written) positions in the cache

  def reset_cache(self):
    """Drop all cached keys/values so the next forward() starts a new sequence."""
    self.k_cache=None
    self.v_cache=None
    self.cache_index=0

  def forward(self, x):
    """Attend the new tokens in ``x`` over the cached context plus themselves.

    Args:
      x: tensor of shape (B, T, embed_size) holding the T newest tokens.

    Returns:
      Tensor of shape (B, T, head_size).

    Raises:
      ValueError: if T exceeds ``block_size`` (the chunk cannot fit the cache).
    """
    B, T, C=x.shape
    if T > block_size:
      raise ValueError(f"sequence chunk of length {T} exceeds block_size {block_size}")

    q=self.query(x)
    k=self.key(x)
    v=self.value(x)

    # (Re)allocate the cache when it is missing or no longer matches the input
    # (different batch size, device or dtype); stale contents are discarded.
    if (self.k_cache is None or self.v_cache is None
        or self.k_cache.shape[0] != B
        or self.k_cache.device != k.device
        or self.k_cache.dtype != k.dtype):
      # block_size is the context length, i.e. the maximum number of tokens
      # a position can look back to, and therefore the cache capacity.
      self.k_cache=torch.zeros(B, block_size, self.head_size, device=k.device, dtype=k.dtype)
      self.v_cache=torch.zeros(B, block_size, self.head_size, device=k.device, dtype=k.dtype)
      self.cache_index=0

    # If the new tokens do not fit, slide the window left by the overflow so
    # the oldest entries fall off. clone() avoids overlapping-memory copies.
    overflow=self.cache_index + T - block_size
    if overflow > 0:
      self.k_cache[:, :-overflow, :]=self.k_cache[:, overflow:, :].clone()
      self.v_cache[:, :-overflow, :]=self.v_cache[:, overflow:, :].clone()
      self.cache_index-=overflow

    # Append the new keys/values right after the valid region.
    start=self.cache_index
    end=start + T
    self.k_cache[:, start:end, :]=k
    self.v_cache[:, start:end, :]=v
    self.cache_index=end

    # Only the written part of the cache takes part in attention; including
    # the zero-initialised tail would leak probability mass to empty slots.
    keys=self.k_cache[:, :end, :]
    values=self.v_cache[:, :end, :]

    wei=q @ keys.transpose(2, 1) / self.head_size ** 0.5  # (B, T, end)
    # Causal mask for multi-token chunks: the i-th new token sits at cache
    # position start+i and may only see cache positions 0..start+i.
    causal=self.tril[start:end, :end]
    wei=wei.masked_fill(~causal, float('-inf'))
    wei=F.softmax(wei, dim=2)
    wei=self.dropout(wei)
    out=wei @ values  # (B, T, head_size)

    return out

In [5]:
# Toy demonstration of the cache shift: one batch entry, three token slots,
# three features per slot.
k_cache, v_cache = torch.zeros(1, 3, 3), torch.zeros(1, 3, 3)

# Fill every slot of the key cache with random integers in [0, 10).
steps = 3
for i in range(steps):
  sample = torch.randint(10, (1, 3))
  k_cache[:, i] = sample
print("k_cache Before:\n", k_cache)

# Slide both caches left by `shift` slots. The last `shift` slots keep their
# old contents (to be overwritten by the next token); clone() avoids copying
# between overlapping views of the same storage.
shift = 1
for cache in (k_cache, v_cache):
  cache[:, :-shift] = cache[:, shift:].clone()
print("k_cache After:\n", k_cache)

k_cache Before:
 tensor([[[2., 2., 5.],
         [1., 0., 3.],
         [8., 5., 2.]]])
k_cache After:
 tensor([[[1., 0., 3.],
         [8., 5., 2.],
         [8., 5., 2.]]])
