## flash attention and attention mechanism 
### self attention

In [None]:
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
import torch

batchsize = 2
seqlen = 128
num_heads = 8
head_dim = 64

q,k,v = [torch.randn(batchsize, seqlen, num_heads, head_dim) for _ in range(3)]

In [2]:
# flash_attn_func needs bf16 and cuda
q = q.to(torch.bfloat16).cuda()
k = k.to(torch.bfloat16).cuda()
v = v.to(torch.bfloat16).cuda()

In [3]:
output_flash_attn = flash_attn_func(q, k, v)
print(output_flash_attn.shape)

torch.Size([2, 128, 8, 64])


In [4]:
activation = torch.einsum('bqhc,bkhc->bhqk',q,k) / (head_dim**0.5)
output = torch.einsum('bhqk,bkhc->bqhc',activation,v)
print(output.shape)

torch.Size([2, 128, 8, 64])


In [8]:
output[0,0,0]

tensor([ 13.6875,   2.9844,   1.0938,   6.7500,  20.8750,   9.7500,   0.8945,
          3.7812,   7.3125,  -3.6250,  -0.6055,  10.7500,  -6.7188, -20.0000,
          8.4375, -19.0000, -12.1875,   3.0469,   6.6562,  20.1250,   3.6719,
         12.1250, -10.5625,   8.5000,  -6.3125,  15.7500, -11.6250,   4.1875,
         -7.2500,   4.1562, -13.3125,   5.5625,   8.3750,  -8.8125,  13.5000,
         -5.0312,   9.8750,  13.4375,   5.1875,  -0.0776,  -3.4688, -10.6250,
         -2.0781, -24.1250,  -2.1875,   8.5000, -17.1250,  12.6875,   9.1250,
          8.6250,   9.3125,   0.7578, -16.0000,   2.4375,  -9.1875,  14.6250,
        -11.9375,  -7.3438,   7.1250,   2.2656,  -1.1016,  -1.9297, -10.1875,
         14.3125], device='cuda:0', dtype=torch.bfloat16)

In [6]:
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=q.device), 1) # 1 means diagonal and above
# node can attend itself and the nodes before it

### cross attention

In [15]:
memory_len = 256
memory_k = torch.randn(batchsize, memory_len, num_heads, head_dim, device=q.device).to(torch.bfloat16)
memory_v = torch.randn(batchsize, memory_len, num_heads, head_dim, device=q.device).to(torch.bfloat16)

In [16]:
output_cross_attention = flash_attn_func(q, memory_k, memory_v)

In [55]:
# original scaled-dot product self-attention 
# no learnable parameters
import copy
import torch
import torch.nn as nn


batchsize = 16
seqlen = 128
dim = 1024
heads = 16
x = torch.randn(batchsize, seqlen, dim) # b k h c
q,k,v = copy.deepcopy(x), copy.deepcopy(x), copy.deepcopy(x)


score = torch.einsum("bqc,bkc->bqk",q,k) * (q.shape[-1] ** -0.5)
mask = torch.zeros_like(score)
score += mask

# print(score.mean(),score.std()) # check the varaince of score if no scaled 
score = nn.functional.softmax(score,dim=-1)

output = torch.einsum("bqk,bkc->bqc",score,v)




In [56]:
# Multi-Head attention
import copy
import torch.nn as nn
q,k,v = copy.deepcopy(x), copy.deepcopy(x), copy.deepcopy(x)

Wq = nn.Linear(dim,dim)
Wk = nn.Linear(dim,dim)
Wv = nn.Linear(dim,dim)
Wo = nn.Linear(dim,dim)

# projection into many lower dimensions
q = Wq(q).reshape(batchsize, seqlen,heads,dim//heads) # b q h c
k = Wk(k).reshape(batchsize, seqlen,heads,dim//heads) # b k h c 
v = Wv(v).reshape(batchsize, seqlen,heads,dim//heads) # b k h c

# heads times scaled-dot product attention

## compute heads attention
score = torch.einsum("bqhc,bkhc->bqkh",q,k) * ((dim//heads)**-0.5)
mask = torch.zeros_like(score)
score += mask
# print(score.mean(),score.std()) # check the varaince of score if no scaled 

score = nn.functional.softmax(score,dim=-1)

## output heads times and concatenation
output = torch.einsum("bqkh,bkhc->bqhc",score,v).reshape(batchsize, seqlen,-1)

# projection final layers

output = Wo(output)

In [15]:
# Multi-Head attention (Cross Attention)
import copy
import torch
import torch.nn as nn

batchsize = 16
seqlen_query = 128
seqlen_memory = 256
dim = 1024
heads = 16
x_query = torch.randn(batchsize, seqlen_query, dim) # b q c
x_memory = torch.randn(batchsize, seqlen_memory, dim) # b m c
q = copy.deepcopy(x_query)
k,v = copy.deepcopy(x_memory),copy.deepcopy(x_memory)

# Define the projection weight matrices
Wq = nn.Linear(dim, dim)
Wk = nn.Linear(dim, dim)
Wv = nn.Linear(dim, dim)
Wo = nn.Linear(dim, dim)

# projection into many lower dimension space
q = Wq(q).reshape(batchsize, seqlen_query,heads,dim//heads) # b q h c
k = Wk(k).reshape(batchsize, seqlen_memory,heads,dim//heads) # b m h c 
v = Wv(v).reshape(batchsize, seqlen_memory,heads,dim//heads) # b m h c

# score
score = torch.einsum("bqhc,bmhc->bqmh",q,k)
score = nn.functional.softmax(score,dim=-1)

## output heads times and concatenation
output = torch.einsum("bqmh,bmhc->bqhc",score,v).reshape(batchsize, seqlen_query,-1)

# projection final layers

output = Wo(output)

In [20]:
# Li Mu layer norm
# layer norm on token
LN_token = nn.LayerNorm(dim)
print(LN_token(output).mean(dim=-1)[0,0],LN_token(output).std(dim=-1)[0,0])
print(LN_token(output).mean(dim=[-1,-2])[0],LN_token(output).std(dim=[-1,-2])[0])

tensor(-6.5193e-09, grad_fn=<SelectBackward0>) tensor(1.0005, grad_fn=<SelectBackward0>)
tensor(4.6566e-10, grad_fn=<SelectBackward0>) tensor(1.0000, grad_fn=<SelectBackward0>)


In [16]:
LN_sample = nn.LayerNorm((seqlen_query, dim))
LN_sample(output).mean(dim=-1)[0,0],LN_sample(output).std(dim=-1)[0,0]


(tensor(-0.0191, grad_fn=<SelectBackward0>),
 tensor(1.0172, grad_fn=<SelectBackward0>))

In [17]:
# mean on last two dimensions
LN_sample(output).mean(dim=[-1,-2])[0],LN_sample(output).std(dim=[-1,-2])[0]

(tensor(-9.3132e-10, grad_fn=<SelectBackward0>),
 tensor(1.0000, grad_fn=<SelectBackward0>))