In [1]:
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [2]:
torch.manual_seed(42)  
sample_rng = torch.Generator()
sample_rng.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

In [3]:
batch_size = 32
sequence_length = 1024
hidden_dim = 512
num_heads = 8
head_dim = hidden_dim // num_heads

n_iter = 5

## CPU implementation

In [4]:
x = torch.randn(batch_size, sequence_length, hidden_dim, generator=sample_rng)

dt = 0
for i in range(n_iter):
    t0 = time.time()
    wq = torch.nn.Linear(hidden_dim, hidden_dim)
    wk = torch.nn.Linear(hidden_dim, hidden_dim)
    wv = torch.nn.Linear(hidden_dim, hidden_dim)
    wo = torch.nn.Linear(hidden_dim, hidden_dim)

    with torch.no_grad(): 

        q = wq(x).view(batch_size, sequence_length, num_heads, head_dim).transpose(1,2)
        k = wk(x).view(batch_size, sequence_length, num_heads, head_dim).transpose(1,2)
        v = wv(x).view(batch_size, sequence_length, num_heads, head_dim).transpose(1,2)
        
        attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        attn = F.softmax(attn, dim=-1)
        y = attn @ v
        y = y.transpose(1, 2).contiguous().view(batch_size, sequence_length, hidden_dim)
        manual_output = wo(y)
    t1 = time.time()
    dt += (t1 - t0)*1000
print(f"CPU from scratch implementation: {dt/n_iter:.2f}ms")


CPU from scratch implementation: 752.25ms


In [5]:
x = torch.randn(batch_size, sequence_length, hidden_dim, generator=sample_rng)

#warmup
multihead_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True, dropout=0.0)
attn_output, attn_output_weights = multihead_attn(x, x, x)

dt = 0
for i in range(n_iter):
    
    t0 = time.time()
    multihead_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True, dropout=0.0)
    attn_output, attn_output_weights = multihead_attn(x, x, x)
    t1 = time.time()
    dt += (t1 - t0)*1000   
print(f"CPU pytorch implementation: {dt/n_iter:.2f}ms")

CPU pytorch implementation: 701.31ms


## GPU implementation

In [6]:
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
x = torch.randn(batch_size, sequence_length, hidden_dim, device=device)

dt = 0
for i in range(n_iter):
    torch.cuda.synchronize()
    t0 = time.time()
    wq = torch.nn.Linear(hidden_dim, hidden_dim, device=device)
    wk = torch.nn.Linear(hidden_dim, hidden_dim, device=device)
    wv = torch.nn.Linear(hidden_dim, hidden_dim, device=device)
    wo = torch.nn.Linear(hidden_dim, hidden_dim, device=device)

    with torch.no_grad(): 

        q = wq(x).view(batch_size, sequence_length, num_heads, head_dim).transpose(1,2)
        k = wk(x).view(batch_size, sequence_length, num_heads, head_dim).transpose(1,2)
        v = wv(x).view(batch_size, sequence_length, num_heads, head_dim).transpose(1,2)
        
        attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        attn = F.softmax(attn, dim=-1)
        y = attn @ v
        y = y.transpose(1, 2).contiguous().view(batch_size, sequence_length, hidden_dim)
        manual_output = wo(y)
    torch.cuda.synchronize()
    t1 = time.time()
    dt += (t1 - t0)*1000
print(f"GPU from scratch implementation: {dt/n_iter:.2f}ms")


GPU from scratch implementation: 60.79ms


In [9]:
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
x = torch.randn(batch_size, sequence_length, hidden_dim, device=device)

#warmup
multihead_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True, dropout=0.0, device=device)
attn_output, attn_output_weights = multihead_attn(x, x, x)

dt = 0
for i in range(n_iter):
    torch.cuda.synchronize()
    t0 = time.time()
    multihead_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True, dropout=0.0, device=device)
    attn_output, attn_output_weights = multihead_attn(x, x, x)
    torch.cuda.synchronize()
    t1 = time.time()
    dt += (t1 - t0)*1000   
print(f"GPU pytorch implementation: {dt/n_iter:.2f}ms")


GPU pytorch implementation: 15.98ms
