In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

device = 'cuda' if torch.cuda.is_available() else "cpu"

In [2]:
device

'cuda'

In [3]:
query, key, value = torch.randn(2,3,8, device = device), torch.randn(2,3,8, device = device), torch.randn(2,3,8, device = device)

In [4]:
F.scaled_dot_product_attention(query, key, value)

tensor([[[-1.3102,  0.3511, -0.0225,  0.0732,  1.0476,  0.0821, -0.1841,
          -0.7381],
         [-1.4471,  0.3482,  0.0497,  0.1377,  1.0303,  0.0467, -0.3995,
          -0.7993],
         [-0.9609,  0.0229, -0.9022,  0.2971,  0.9908, -0.3394,  0.3428,
          -0.2445]],

        [[ 0.1529,  0.4307,  0.3648,  0.7064, -0.2333, -0.0723, -0.3468,
           0.0796],
         [ 0.3643,  0.3583, -0.3079,  0.4511, -0.1409, -0.3006, -0.7080,
           0.5138],
         [ 0.2387,  0.3994,  0.2214,  0.8081, -0.2716, -0.2258, -0.4523,
           0.0724]]], device='cuda:0')

In [5]:
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt = "f(*args, **kwargs)", globals = {"args" : args, "kwargs" : kwargs, 'f' : f}
    )
    return t0.blocked_autorange().mean * 1e6
    

In [6]:
batch_size = 16
embed_dimension = 32
max_sequence_len = 1024
num_heads = 32
dtype = torch.float32

In [7]:
query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)

In [8]:
print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

The default implementation runs in 274997.736 microseconds


In [9]:

# Lets explore the speed of each of the 3 implementations
from torch.backends.cuda import sdp_kernel, SDPBackend

# Helpful arguments mapper
backend_map = {
    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
    SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {
        "enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
}

with sdp_kernel(**backend_map[SDPBackend.MATH]):
    print(f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")


with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
    try:
        print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
    try:
        print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 GiB (GPU 0; 1.96 GiB total capacity; 328.13 MiB already allocated; 1.58 GiB free; 342.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [10]:
class CausalSelfAttention(nn.Module):
    def __init__(self, num_heads : int, embed_dimension:int, bias:bool=False, is_causal:bool=False, dropout:float=0.0 ):
        
        super().__init__()
        
        assert embed_dimension % num_heads == 0
        self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias = bias)
        self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
        
        self.dropout = dropout
        self.resid_dropout = nn.Dropout(dropout)
        self.num_heads = num_heads
        self.embed_dimension = embed_dimension
        
        self.is_causal = is_causal
        
    def forward():
        
        query_projected = self.attn(x)
        
        batch_size = query_projected.size(0)
        embed_dim = query_projected.size(2)
        head_dim = embed_dim // (self.num_heads *3)
        
        
        query, key, value = query_projected.chunk(3,1)
        query = query.view(batch_size, -1, sel.num_heads, head_dim).transpose(1,2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1,2)
        
        if self.training:
            dropout = self.dropout
            is_causal = self.is_causal
        else:
            dropout = 0.0
            is_causal = False
        y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p= dropout, is_causal=is_causal)
        y = y.transpose(1,2).view(batch_size, -1, self.num_heads*head_dim)
        
        y =self.resid_dropou(self.c_proj(y))
        
        return y

In [11]:
num_heads = 8
heads_per_dim = 64
embed_dimension = num_heads * heads_per_dim
dtype = torch.float16
model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to("cuda").to(dtype).eval()
print(model)

CausalSelfAttention(
  (c_attn): Linear(in_features=512, out_features=1536, bias=False)
  (c_proj): Linear(in_features=512, out_features=512, bias=False)
  (resid_dropout): Dropout(p=0.1, inplace=False)
)
