In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

device = 'cuda' if torch.cuda.is_available() else "cpu"

In [3]:
device

'cpu'

In [4]:
query, key, value = torch.randn(2,3,8, device = device), torch.randn(2,3,8, device = device), torch.randn(2,3,8, device = device)

In [5]:
F.scaled_dot_product_attention(query, key, value)

tensor([[[ 0.0738,  0.7135, -0.4148,  0.3798,  0.2800, -0.5711, -1.0276,
           0.2604],
         [-0.0544,  0.4186, -0.0078,  0.4152,  0.2415, -0.3661, -0.7775,
           0.2893],
         [-0.5978, -0.6032,  1.9235,  0.5572,  0.0575,  0.4908,  0.2839,
           0.2406]],

        [[ 0.0606, -0.0222,  0.3741,  0.6267,  0.9806,  0.6901, -0.0340,
           0.5006],
         [ 0.3133, -0.2853,  0.1878,  0.5794,  1.2674, -0.1754, -0.2621,
           0.7463],
         [-0.4307,  0.3035,  0.4884,  0.4645,  0.6270,  1.3981,  0.4026,
          -0.0508]]])

In [6]:
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt = "f(*args, **kwargs)", globals = {"args" : args, "kwargs" : kwargs, 'f' : f}
    )
    return t0.blocked_autorange().mean * 1e6
    

In [10]:
batch_size = 16
embed_dimension = 32
max_sequence_len = 1024
num_heads = 32
dtype = torch.float32

In [11]:
query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)

In [12]:
print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

The default implementation runs in 3781691.513 microseconds


In [13]:

# Lets explore the speed of each of the 3 implementations
from torch.backends.cuda import sdp_kernel, SDPBackend

# Helpful arguments mapper
backend_map = {
    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
    SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {
        "enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
}

with sdp_kernel(**backend_map[SDPBackend.MATH]):
    print(f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")


with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
    try:
        print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
    try:
        print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")

The math implementation runs in 3695212.725 microseconds
The flash attention implementation runs in 3720808.256 microseconds
The memory efficient implementation runs in 3701618.127 microseconds


In [20]:
class CausalSelfAttention(nn.Module):
    def __init__(self, num_heads : int, embed_dimension:int, bias:bool=False, is_causal:bool=False, dropout:float=0.0 ):
        
        super().__init__()
        
        assert embed_dimension % num_heads == 0
        self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias = bias)
        self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
        
        self.dropout = dropout
        self.resid_dropout = nn.Dropout(dropout)
        self.num_heads = num_heads
        self.embed_dimension = embed_dimension
        
        self.is_causal = is_causal
        
    def forward():
        
        query_projected = self.attn(x)
        
        batch_size = query_projected.size(0)
        embed_dim = query_projected.size(2)
        head_dim = embed_dim // (self.num_heads *3)
        
        
        query, key, value = query_projected.chunk(3,1)
        query = query.view(batch_size, -1, sel.num_heads, head_dim).transpose(1,2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1,2)
        
        if self.training:
            dropout = self.dropout
            is_causal = self.is_causal
        else:
            dropout = 0.0
            is_causal = False
        y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p= dropout, is_causal=is_causal)
        y = y.transpose(1,2).view(batch_size, -1, self.num_heads*head_dim)
        
        y =self.resid_dropou(self.c_proj(y))
        
        return y

In [21]:
num_heads = 8
heads_per_dim = 64
embed_dimension = num_heads * heads_per_dim
dtype = torch.float16
model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to("cuda").to(dtype).eval()
print(model)

RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx