In [71]:
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

# torch.manual_seed(0)

from flash_attn import flash_attn_func

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [31]:
device = 'cuda:0'
dtype=torch.float16
embed_dim = 4
num_heads = 2
seq_len = 4

In [2]:

# unpacked version of flash attn (q, k, v)
q = torch.rand(1, 512, 30, 30, dtype=dtype).to(device)
dp = 0.1

s = time.time()
out = flash_attn_func(q, q, q, dropout_p = dp)
e = time.time()
e - s

0.0006763935089111328

In [3]:
k, v = q, q

s = time.time()
# similarity
sim = q @ k

# attention
attn = sim.softmax(dim=-1)

# aggregate values
out = attn @ v

e = time.time()
e - s

0.8314464092254639

In [67]:
def multiply_by_ychunks(x, y, chunks=1):
    if chunks <= 1:
        return x @ y
    else:
        return torch.cat([x @ _y for _y in y.chunk(chunks, dim=-1)], dim=-1)


def multiply_by_xchunks(x, y, chunks=1):
    if chunks <= 1:
        return x @ y
    else:
        return torch.cat([_x @ y for _x in x.chunk(chunks, dim=-2)], dim=-2)
    
class MultiheadAttention(nn.Module):
    def __init__(self,
                 d_model,
                 num_head=8,
                 dropout=0.,
                 use_linear=True,
                 d_att=None,
                 use_dis=False,
                 qk_chunks=1,
                 max_mem_len_ratio=-1,
                 top_k=-1):
        super().__init__()
        self.d_model = d_model
        self.num_head = num_head
        self.use_dis = use_dis
        self.qk_chunks = qk_chunks
        self.max_mem_len_ratio = float(max_mem_len_ratio)
        self.top_k = top_k

        self.hidden_dim = d_model // num_head
        self.d_att = self.hidden_dim if d_att is None else d_att
        self.T = self.d_att**0.5
        self.use_linear = use_linear

        if use_linear:
            self.linear_Q = nn.Linear(d_model, d_model)
            self.linear_K = nn.Linear(d_model, d_model)
            self.linear_V = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.drop_prob = dropout
        self.projection = nn.Linear(d_model, d_model)
        self._init_weight()

    def forward(self, Q, K, V):
        """
        :param Q: A 3d tensor with shape of [T_q, bs, C_q]
        :param K: A 3d tensor with shape of [T_k, bs, C_k]
        :param V: A 3d tensor with shape of [T_v, bs, C_v]
        """
        # print("MultiheadAttention:", Q.shape, K.shape, V.shape)
        num_head = self.num_head
        hidden_dim = self.hidden_dim

        bs = Q.size()[1]

        # Linear projections
        if self.use_linear:
            Q = self.linear_Q(Q)
            K = self.linear_K(K)
            V = self.linear_V(V)

        # Scale
        Q = Q / self.T

        if not self.training and self.max_mem_len_ratio > 0:
            mem_len_ratio = float(K.size(0)) / Q.size(0)
            if mem_len_ratio > self.max_mem_len_ratio:
                scaling_ratio = math.log(mem_len_ratio) / math.log(
                    self.max_mem_len_ratio)
                Q = Q * scaling_ratio

        # Multi-head
        Q = Q.view(-1, bs, num_head, self.d_att).permute(1, 2, 0, 3)
        K = K.view(-1, bs, num_head, self.d_att).permute(1, 2, 3, 0)
        V = V.view(-1, bs, num_head, hidden_dim).permute(1, 2, 0, 3)
        # print(Q.shape, K.shape, V.shape)

        # Multiplication
        QK = multiply_by_ychunks(Q, K, self.qk_chunks)
        if self.use_dis:
            QK = 2 * QK - K.pow(2).sum(dim=-2, keepdim=True)

        # Activation
        if not self.training and self.top_k > 0 and self.top_k < QK.size()[-1]:
            top_QK, indices = torch.topk(QK, k=self.top_k, dim=-1)
            top_attn = torch.softmax(top_QK, dim=-1)
            attn = torch.zeros_like(QK).scatter_(-1, indices, top_attn)
        else:
            attn = torch.softmax(QK, dim=-1)

        # Dropouts
        attn = self.dropout(attn)

        # Weighted sum
        outputs = multiply_by_xchunks(attn, V,
                                      self.qk_chunks).permute(2, 0, 1, 3)

        # Restore shape
        outputs = outputs.reshape(-1, bs, self.d_model)
        outputs = self.projection(outputs)
        # print(outputs.shape)
        return outputs, attn
    
    def _init_weight(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

## Verify the same output for Mulithead and flash

Note: we don't expect any speedup in this low dim

In [65]:
attn2 = MultiheadAttention(embed_dim, num_head = num_heads)

In [37]:
x = torch.rand(seq_len, 1, embed_dim)

In [38]:
s = time.time()
out, out_drop = attn2(x, x, x)
e = time.time()
e - s

torch.Size([1, 2, 4, 2]) torch.Size([1, 2, 2, 4]) torch.Size([1, 2, 4, 2])


0.0006425380706787109

In [39]:
out

tensor([[[-0.1673,  0.1168, -0.6121, -0.8281]],

        [[-0.1641,  0.1154, -0.6128, -0.8274]],

        [[-0.1716,  0.1183, -0.6124, -0.8301]],

        [[-0.1704,  0.1181, -0.6121, -0.8295]]], grad_fn=<ViewBackward0>)

In [23]:
# to use flash attention in the same manner as Multihead
# 1 reshape current tensor to match: bs, seq len, nheads, headdim
d_att = embed_dim // num_heads
bs = x.size()[1]

torch.Size([1, 36, 8, 16])

In [40]:
s = time.time()
d_att = embed_dim // num_heads
bs = x.size()[1]
y = x.view(-1, bs, num_heads, d_att).permute(1, 0, 2, 3).to(dtype).to(device)
out3 = flash_attn_func(y, y, y, dropout_p = dp)
out3 = out3.reshape(-1, bs, embed_dim)
e = time.time()
e - s

0.0008258819580078125

In [29]:
out3.shape

torch.Size([36, 1, 128])

In [41]:
out

tensor([[[-0.1673,  0.1168, -0.6121, -0.8281]],

        [[-0.1641,  0.1154, -0.6128, -0.8274]],

        [[-0.1716,  0.1183, -0.6124, -0.8301]],

        [[-0.1704,  0.1181, -0.6121, -0.8295]]], grad_fn=<ViewBackward0>)

## Repeat experiment but in higher dim

From earlier experiments wittnessed seq_len = 289 and embed_dim = 256

Also have seen seq len 900, embed_dim 256

In [44]:
embed_dim = 256
num_heads = 8
seq_len = 900
bs = 4

In [68]:
attn4 = MultiheadAttention(embed_dim, num_head = num_heads).to(device)

In [51]:
lq = torch.rand(seq_len, bs, embed_dim).to(device) # lq = larger q

In [56]:
s = time.time()
out4, out_drop = attn4(lq, lq, lq)
e = time.time()
e - s

torch.Size([4, 8, 900, 32]) torch.Size([4, 8, 32, 900]) torch.Size([4, 8, 900, 32])


0.0014352798461914062

In [73]:
s = time.time()
d_att = embed_dim // num_heads
bs = lq.size()[1]
y = lq.view(-1, bs, num_heads, d_att).permute(1, 0, 2, 3).to(dtype).to(device)
out5 = flash_attn_func(y, y, y, dropout_p = dp)
out5 = out5.reshape(-1, bs, embed_dim)
e = time.time()
e - s

0.0004987716674804688

## Repeat calculation 10000 times and compare results


In [60]:
num_iterations = 10000

In [74]:
total_time = 0

for _ in range(num_iterations):
    start_time = time.time()
    
    # Perform the function call
    out6, out_drop = attn4(lq, lq, lq)
    
    end_time = time.time()
    
    total_time += (end_time - start_time)

avg_runtime = total_time / num_iterations
avg_runtime

0.0010495921850204467

In [76]:
total_time = 0

for _ in range(num_iterations):
    start_time = time.time()
    
    d_att = embed_dim // num_heads
    bs = lq.size()[1]
    y = lq.view(-1, bs, num_heads, d_att).permute(1, 0, 2, 3).to(dtype).to(device)
    out7 = flash_attn_func(y, y, y, dropout_p = dp)
    out7 = out7.reshape(-1, bs, embed_dim)
    
    end_time = time.time()
    
    total_time += (end_time - start_time)

avg_runtime = total_time / num_iterations
avg_runtime

7.724754810333252e-05

 ## Increasing realism of experiments

making it more realistic I'd have 3 seperate qkv for each (not sure how cuda caches computations, but might be artifiically speeding up)

Note: turned off manual seed

Notice that reshapping all 3 takes significantly more time, just barely faster than multiheadd attention now

Mixed emotions about this, because most likely the tensor creation is the slow part, I'd like to run isolated experiments with full models

In [78]:
s = time.time()
rq = torch.rand(seq_len, bs, embed_dim).to(device)
rk = torch.rand(seq_len, bs, embed_dim).to(device)
rv = torch.rand(seq_len, bs, embed_dim).to(device)
out8, out_drop = attn4(rq, rk, rv)
e = time.time()
e - s

0.013247489929199219

In [79]:
s = time.time()
rq = torch.rand(seq_len, bs, embed_dim).to(device)
rk = torch.rand(seq_len, bs, embed_dim).to(device)
rv = torch.rand(seq_len, bs, embed_dim).to(device)
d_att = embed_dim // num_heads
bs = rq.size()[1]
q = rq.view(-1, bs, num_heads, d_att).permute(1, 0, 2, 3).to(dtype).to(device)
k = rk.view(-1, bs, num_heads, d_att).permute(1, 0, 2, 3).to(dtype).to(device)
v = rv.view(-1, bs, num_heads, d_att).permute(1, 0, 2, 3).to(dtype).to(device)
out9 = flash_attn_func(q, k, v, dropout_p = dp)
out9 = out9.reshape(-1, bs, embed_dim)
e = time.time()
e - s

0.01303863525390625

In [81]:
total_time = 0

for _ in range(num_iterations):
    start_time = time.time()
    
    ## start func call
    rq = torch.rand(seq_len, bs, embed_dim).to(device)
    rk = torch.rand(seq_len, bs, embed_dim).to(device)
    rv = torch.rand(seq_len, bs, embed_dim).to(device)
    out10, out_drop = attn4(rq, rk, rv)
    # end func call
    
    end_time = time.time()
    
    total_time += (end_time - start_time)

avg_runtime = total_time / num_iterations
avg_runtime

0.012385438585281372

In [80]:
total_time = 0

for _ in range(num_iterations):
    start_time = time.time()
    ## start func call
    
    rq = torch.rand(seq_len, bs, embed_dim).to(device)
    rk = torch.rand(seq_len, bs, embed_dim).to(device)
    rv = torch.rand(seq_len, bs, embed_dim).to(device)
    d_att = embed_dim // num_heads
    bs = rq.size()[1]
    q = rq.view(-1, bs, num_heads, d_att).permute(1, 0, 2, 3).to(dtype).to(device)
    k = rk.view(-1, bs, num_heads, d_att).permute(1, 0, 2, 3).to(dtype).to(device)
    v = rv.view(-1, bs, num_heads, d_att).permute(1, 0, 2, 3).to(dtype).to(device)
    out11 = flash_attn_func(q, k, v, dropout_p = dp)
    out11 = out11.reshape(-1, bs, embed_dim)
    
    ## end 
    end_time = time.time()
    
    total_time += (end_time - start_time)

avg_runtime = total_time / num_iterations
avg_runtime

0.012145824599266053