# In [2]:
import torch
import torch.nn as nn

class HeadInferTransformerLayer(nn.Module):
    """Multi-head self-attention layer with a per-head KV cache stored on the CPU.

    Between calls every head's keys/values live in ``kv_cache`` as CPU tensors;
    they are moved to ``device`` only while that head's attention is computed.
    Heads with index ``h < num_heads_on_gpu`` concatenate cached and new
    keys/values on ``device``; the remaining heads concatenate on the CPU and
    transfer only the result. ``num_heads_on_gpu`` defaults to ``num_heads``,
    so the CPU-concatenation path is used only if a caller lowers it.

    Args:
        d_model: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        device: Device holding the projection weights and running attention.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, d_model, num_heads, device):
        super().__init__()
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.device = device

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)
        # Moves every submodule's parameters (the four projections) to `device`.
        self.to(device)

        # Heads with index < num_heads_on_gpu concatenate cache + new K/V on
        # `device`; higher-index heads concatenate on the CPU instead.
        self.num_heads_on_gpu = num_heads

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape ``(batch, seq_len, d_model)`` into ``(batch, num_heads, seq_len, head_dim)``."""
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, kv_cache):
        """Attend over cached plus new keys/values and append the new ones to the cache.

        Args:
            x: Input of shape ``(batch, seq_len, d_model)`` on ``self.device``.
            kv_cache: Dict with lists ``'K'`` and ``'V'`` of length ``num_heads``;
                entry ``h`` has shape ``(batch, 1, cached_len, head_dim)``. It is
                updated in place: each entry is replaced by a CPU tensor with
                the ``seq_len`` new positions appended along dim 2.

        Returns:
            Tensor of shape ``(batch, seq_len, d_model)``.
        """
        batch_size, seq_len, _ = x.shape

        # Project once for all heads rather than once per head.
        Q = self._split_heads(self.W_q(x), batch_size, seq_len)
        K_new = self._split_heads(self.W_k(x), batch_size, seq_len)
        V_new = self._split_heads(self.W_v(x), batch_size, seq_len)
        scale = self.head_dim ** 0.5

        head_outputs = []
        for h in range(self.num_heads):
            k_h = K_new[:, h:h + 1]  # (batch, 1, seq_len, head_dim)
            v_h = V_new[:, h:h + 1]
            if h < self.num_heads_on_gpu:
                K = torch.cat([kv_cache['K'][h].to(self.device), k_h], dim=2)
                V = torch.cat([kv_cache['V'][h].to(self.device), v_h], dim=2)
                kv_cache['K'][h] = K.cpu()
                kv_cache['V'][h] = V.cpu()
            else:
                # Both cat operands must be on the CPU: the cache stays where it
                # is and only the new slice is copied off the device.
                K_cpu = torch.cat([kv_cache['K'][h].cpu(), k_h.cpu()], dim=2)
                V_cpu = torch.cat([kv_cache['V'][h].cpu(), v_h.cpu()], dim=2)
                kv_cache['K'][h] = K_cpu
                kv_cache['V'][h] = V_cpu
                K = K_cpu.to(self.device)
                V = V_cpu.to(self.device)

            # NOTE(review): no causal mask is applied, so each query attends to
            # every cached and new position (including later ones in this call).
            # For autoregressive prefill with seq_len > 1 a mask is presumably
            # needed — TODO confirm intended semantics.
            scores = torch.matmul(Q[:, h:h + 1], K.transpose(-2, -1)) / scale
            probs = self.softmax(scores)
            head_outputs.append(torch.matmul(probs, V))

        attn = torch.cat(head_outputs, dim=1)  # (batch, num_heads, seq_len, head_dim)
        merged = attn.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.W_o(merged)

def initialize_kv_cache(batch_size, num_heads, seq_len, head_dim, device='cpu', dtype=None):
    """Create an empty per-head key/value cache.

    Args:
        batch_size: Batch dimension of every cache entry.
        num_heads: Number of entries in each of the ``'K'`` and ``'V'`` lists.
        seq_len: Not used; accepted for call-site compatibility. Every entry
            starts with zero cached positions.
        head_dim: Size of the last dimension of every entry.
        device: Device on which the empty tensors are allocated.
        dtype: Optional tensor dtype; ``None`` uses torch's default dtype.

    Returns:
        Dict ``{'K': [...], 'V': [...]}``; each list holds ``num_heads``
        distinct tensors of shape ``(batch_size, 1, 0, head_dim)``.
    """
    def _empty():
        return torch.zeros(batch_size, 1, 0, head_dim, device=device, dtype=dtype)

    return {
        'K': [_empty() for _ in range(num_heads)],
        'V': [_empty() for _ in range(num_heads)],
    }

if __name__ == '__main__':
    # Smoke test: one forward pass of the layer over a fresh, empty cache.
    batch_size = 2
    seq_len = 32
    d_model = 256
    num_heads = 8
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    layer = HeadInferTransformerLayer(d_model, num_heads, device)
    layer.to(device)
    x = torch.randn(batch_size, seq_len, d_model, device=device)

    kv_cache = initialize_kv_cache(batch_size, num_heads, seq_len, d_model // num_heads)
    # Inference only: no_grad keeps the output and the tensors written into
    # kv_cache free of autograd history.
    with torch.no_grad():
        output = layer(x, kv_cache)

    print("Output shape:", output.shape)

# Output shape: torch.Size([2, 32, 256])
