In [1]:
import torch
import torch.nn as nn
import numpy as np
import math

import os
import sys
sys.path.append("../")

import matplotlib.pyplot as plt

from dataclasses import dataclass, field

print(f"Torch version: {torch.__version__}")
os.system("nvidia-smi")

Torch version: 2.0.1+cu117
Fri Jul 28 15:22:19 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.43.04    Driver Version: 515.43.04    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-DGXS...  On   | 00000000:07:00.0 Off |                    0 |
| N/A   33C    P0    51W / 300W |    511MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-DGXS...  On   | 00000000:08:00.0 Off |                    0 |
| N/A   34C    P0    50W / 300W |   5835MiB / 16384MiB |     

0

In [2]:
@dataclass
class CFG:
    embedding_dim: int = field(default=256)
    head_dim: int = field(default=64)

In [3]:
base_cfg = CFG()
embeddim = base_cfg.embedding_dim
headdim =  base_cfg.head_dim

tokens = torch.randn(1, 5, embeddim)
Q_latent = torch.randn(embeddim, headdim) / math.sqrt(embeddim)
K_latent = torch.randn(embeddim, headdim) / math.sqrt(embeddim)
V_latent = torch.randn(embeddim, embeddim) / math.sqrt(embeddim)

In [4]:
Q = torch.einsum("BTE, EH -> BTH", tokens, Q_latent)
K = torch.einsum("BTE, EH -> BTH", tokens, K_latent)
V = torch.einsum("BTE, EF -> BTF", tokens, V_latent)
scores = torch.einsum("BTH, BSH -> BTS", Q, K)
attn = torch.nn.functional.softmax(scores / math.sqrt(headdim), dim=-1)

result = torch.einsum("BST, BTF -> BSF", attn, V)

### Same is done in PyTorch wrapper around CUDA kernel:

In [5]:
import torch.nn.functional as F

attn_torch = F.scaled_dot_product_attention(Q, K, V)
torch.allclose(attn_torch, result, atol=1E-6, rtol=1E-6)

True

In [6]:
attn_torch.shape

torch.Size([1, 5, 256])

### Doing MHA

In [14]:
embedding_dim = 768
num_heads = 12 # 768 // 12 = 64 - each head operating on 64-dim vector space
headdim = embedding_dim // num_heads

tokens = torch.randn(1, 5, embedding_dim)
Q_latent = torch.randn(embedding_dim, headdim * num_heads) / math.sqrt(embedding_dim)
K_latent = torch.randn(embedding_dim, headdim * num_heads) / math.sqrt(embedding_dim)
V_latent = torch.randn(embedding_dim, headdim * num_heads) / math.sqrt(embedding_dim)

In [9]:
!pip install einops -qq

In [26]:
import einops

Q = torch.einsum("BTE, EH -> BTH", tokens, Q_latent)
K = torch.einsum("BTE, EH -> BTH", tokens, K_latent)
V = torch.einsum("BTE, EF -> BTF", tokens, V_latent)

Q_mh = einops.rearrange(Q, "B T (heads headdim) -> B T heads headdim", heads=num_heads, headdim=headdim)
K_mh = einops.rearrange(K, "B T (heads headdim) -> B T heads headdim", heads=num_heads, headdim=headdim)
V_mh = einops.rearrange(V, "B T (heads headdim) -> B T heads headdim", heads=num_heads, headdim=headdim)

scores_mh = torch.einsum("BTHD, BSHD -> BHTS", Q_mh, K_mh)
attmath_mh = F.softmax(scores_mh / math.sqrt(headdim), dim=-1)
result = torch.einsum("BHTS, BTHD -> BHTD", attmath_mh, V_mh) # B heads tokens embed_dim (for each head)
result = einops.rearrange(result, "B H T D -> B T (H D)")

In [27]:
result.shape

torch.Size([1, 5, 768])

### Torch version of MHA

In [40]:
mha = nn.MultiheadAttention(embedding_dim, num_heads, batch_first=True)

In [36]:
mha.in_proj_weight

torch.Size([2304, 768])

In [49]:
attn_mask = -1E-4 * torch.triu(torch.ones(tokens.shape[1], tokens.shape[1]), 1)
print(attn_mask)

tensor([[-0.0000e+00, -1.0000e-04, -1.0000e-04, -1.0000e-04, -1.0000e-04],
        [-0.0000e+00, -0.0000e+00, -1.0000e-04, -1.0000e-04, -1.0000e-04],
        [-0.0000e+00, -0.0000e+00, -0.0000e+00, -1.0000e-04, -1.0000e-04],
        [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -1.0000e-04],
        [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]])


In [50]:
mha(tokens, tokens, tokens, attn_mask=attn_mask)

(tensor([[[ 0.1252, -0.2225, -0.0994,  ..., -0.2800, -0.0930, -0.3960],
          [ 0.1050, -0.1977, -0.1301,  ..., -0.3278,  0.0305, -0.3866],
          [ 0.1150, -0.1873, -0.1067,  ..., -0.1942,  0.0123, -0.2534],
          [ 0.2343, -0.2250, -0.1684,  ..., -0.3058, -0.1578, -0.2287],
          [ 0.2323, -0.2044, -0.1477,  ..., -0.3364, -0.1396, -0.2945]]],
        grad_fn=<TransposeBackward0>),
 tensor([[[0.2313, 0.1711, 0.1839, 0.1844, 0.2293],
          [0.1671, 0.1977, 0.1935, 0.1976, 0.2442],
          [0.1922, 0.1974, 0.1754, 0.1853, 0.2497],
          [0.1930, 0.2019, 0.2351, 0.1927, 0.1773],
          [0.1794, 0.2096, 0.2178, 0.1780, 0.2152]]], grad_fn=<MeanBackward1>))