<a href="https://colab.research.google.com/github/svetaU/Attention/blob/main/SA_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import math
try:
  import einops
except ModuleNotFoundError: 
  !pip install --quiet einops
from einops import rearrange
import torch
from torch import nn
import torch.nn.functional as F
#try:
#    import pytorch_lightning as pl
#except ModuleNotFoundError: 
#    !pip install --quiet pytorch-lightning>=1.5
#    import pytorch_lightning as pl

# Test data setup

In [2]:
# A batch of 2 sequences, each with 3 tokens of dimension 3 -> shape (2, 3, 3).
x = torch.tensor([
    [[1., 1., 0.], [2., 2., 1.], [3., 3., 2.]],
    [[4., 4., 3.], [5., 5., 4.], [6., 6., 5.]],
])

# The same tokens with every feature vector repeated twice -> shape (2, 3, 6).
# Used as the 6-dim input for the multi-head test further down (heads = 2).
xm = torch.cat((x, x), dim=-1)

# Hand-picked projection weights of shape (9, 3): copied into the bias-free
# nn.Linear(3, 9) that produces the interleaved q/k/v channels for `x`.
w = torch.tensor([
    [ 1., -1., -1.],
    [ 1.,  1.,  1.],
    [-1., -1.,  1.],
    [ 1.,  1.,  1.],
    [-1.,  1., -1.],
    [ 1.,  1.,  1.],
    [ 1., -1.,  1.],
    [ 1.,  1.,  1.],
    [-1.,  1.,  1.],
])

In [3]:
# Project the tokens of `x` into concatenated q/k/v features using the fixed
# weights `w`, then split them. With the '(d k)' grouping the q/k/v selector
# varies fastest, so projected channel c goes to q, k or v according to c % 3.
dim = x.shape[2]
to_qvk = nn.Linear(in_features=dim, out_features=3 * dim, bias=False)
with torch.no_grad():
    # Overwrite the random initialisation so the results are reproducible.
    to_qvk.weight.copy_(w)
    qkv = to_qvk(x)
    q, k, v = rearrange(qkv, 'b t (d k) -> k b t d ', k=3).unbind(0)

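# Single-head attention, version 1 (`torch.matmul`)
Scaled dot-product attention; positions where `mask == 0` are excluded from the softmax.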

In [11]:
def single_head_attention_1(q, k, v, mask=None):
    """Scaled dot-product attention for a single head, written with ``torch.matmul``.

    Args:
        q, k, v: query/key/value tensors. ``q`` and ``k`` must share their last
            (feature) dimension, and the leading dimensions must be compatible
            with ``torch.matmul``.
        mask: optional tensor broadcastable to the logit shape ``(..., t_q, t_k)``.
            Positions where ``mask == 0`` are excluded from the softmax.
            NOTE: this is the opposite convention to ``single_head_attention_2``,
            which excludes positions where its boolean mask is ``True``.

    Returns:
        ``(values, attention)`` where ``attention`` is the softmax over the last
        axis of ``q @ k^T / sqrt(d)`` (``d`` = ``q.size(-1)``) and ``values`` is
        ``attention @ v``.
    """
    d = q.size(-1)
    attn_logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d)
    if mask is not None:
        # Use the most negative finite value of the logits' dtype instead of a
        # hard-coded -9e15: that constant cannot be represented in float16, so
        # masked_fill raises there. finfo.min still gives masked positions
        # exactly 0 weight after softmax, and a fully masked row stays finite.
        attn_logits = attn_logits.masked_fill(mask == 0, torch.finfo(attn_logits.dtype).min)
    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

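# Single-head attention, version 2 (`torch.einsum`)
The same computation, but here `mask` is boolean and `True` marks positions to exclude — the opposite convention to version 1.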

In [12]:
def single_head_attention_2(q, k, v, mask=None):
  """Scaled dot-product attention for a single head, written with ``torch.einsum``.

  Args:
    q, k, v: 3-D tensors laid out as (batch, tokens, features); the einsum
      subscripts 'b i d' / 'b j d' require exactly three dimensions.
    mask: optional boolean tensor of shape (i, j), shared by every batch
      element, or (b, i, j) for a per-example mask. ``True`` marks positions
      to exclude. NOTE: this is the opposite convention to
      ``single_head_attention_1``, which excludes positions where ``mask == 0``.

  Returns:
    ``(values, attention)`` with shapes (b, i, d_v) and (b, i, j).
  """
  scaled_dot_prod = torch.einsum('b i d , b j d -> b i j', q, k) / math.sqrt(q.size()[-1])
  if mask is not None:
    assert mask.shape in (scaled_dot_prod.shape[1:], scaled_dot_prod.shape), (
        f"mask shape {tuple(mask.shape)} must be {tuple(scaled_dot_prod.shape[1:])} "
        f"or {tuple(scaled_dot_prod.shape)}")
    # Use the dtype's most negative finite value rather than -inf: a row whose
    # keys are all masked would otherwise become softmax(-inf, ..., -inf) = NaN.
    # Partially masked rows still give exactly 0 weight to masked keys.
    scaled_dot_prod = scaled_dot_prod.masked_fill(mask, torch.finfo(scaled_dot_prod.dtype).min)
  attention = torch.softmax(scaled_dot_prod, dim=-1)
  values = torch.einsum('b i j , b j d -> b i d', attention, v)
  return values, attention

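# Multi-head attention
Splits the features across `heads`, attends within each head, and concatenates the head outputs. The q/k/v projection is a bias-free `nn.Linear` created inside the function.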

In [3]:
def multi_head_attention(x, heads, mask=None, to_qvk=None):
  """Multi-head scaled dot-product self-attention over ``x``.

  Args:
    x: input of shape (b, t, dim); rank 3 is required by the einops pattern
      'b t (d k h)'.
    heads: number of heads. Each head uses dim_head = dim // heads features.
      If dim is not a multiple of heads, the remainder is dropped and the
      output width is dim_head * heads rather than dim.
    mask: optional boolean tensor of shape (t, t), applied to every batch
      element and head; True marks (query, key) pairs to exclude.
    to_qvk: optional module mapping the last axis from dim to
      3 * heads * dim_head features (e.g. a bias-free nn.Linear), so callers
      can supply fixed or trained weights. If omitted, a new randomly
      initialised nn.Linear is built on every call, so the result then
      depends on the torch RNG state and no weights persist between calls.

  Returns:
    ``(values, attention)``. ``values`` has shape (b, t, heads * dim_head), with
    the head outputs concatenated along the last axis. ``attention`` has shape
    (b, heads, t, t).

  Raises:
    ValueError: if dim // heads is not positive (e.g. heads > dim), which
      would leave every head with no features.
  """
  dim = x.size()[-1]
  dim_head = dim // heads
  if dim_head <= 0:
    raise ValueError(f"heads={heads} leaves no features per head for dim={dim}")
  if to_qvk is None:
    to_qvk = nn.Linear(dim, dim_head * heads * 3, bias=False)
  qkv = to_qvk(x)
  # '(d k h)' groups the projected channels with the head index h varying
  # fastest, then the q/k/v selector k, then the per-head feature index d.
  q, k, v = tuple(rearrange(qkv, 'b t (d k h) -> k b h t d ', k=3, h=heads))
  scaled_dot_prod = torch.einsum('b h i d , b h j d -> b h i j', q, k) / math.sqrt(q.size()[-1])
  if mask is not None:
    assert mask.shape == scaled_dot_prod.shape[2:]
    # Use the dtype's most negative finite value rather than -inf so a fully
    # masked row yields finite (uniform) weights instead of NaN.
    scaled_dot_prod = scaled_dot_prod.masked_fill(mask, torch.finfo(scaled_dot_prod.dtype).min)
  attention = torch.softmax(scaled_dot_prod, dim=-1)
  values = torch.einsum('b h i j , b h j d -> b h i d', attention, v)
  # Concatenate the per-head outputs back into one feature axis.
  values = rearrange(values, 'b h t d -> b t (h d)')
  return values, attention

# Run the attention functions on the test data

In [13]:
# Run both single-head implementations on the same q, k, v (no mask) and check
# that the matmul-based and einsum-based versions agree.
values_1, attention_1 = single_head_attention_1(q, k, v)
values_2, attention_2 = single_head_attention_2(q, k, v)
torch.testing.assert_close(attention_1, attention_2)
torch.testing.assert_close(values_1, values_2)

In [4]:
# multi_head_attention builds a randomly initialised q/k/v projection on each
# call, so seed the RNG to make the output below reproducible on re-runs.
torch.manual_seed(0)
heads = 2
values_m, attention_m = multi_head_attention(xm, heads)
# xm is (2, 3, 6) and 6 is divisible by 2, so the output keeps xm's shape.
assert values_m.shape == xm.shape
assert attention_m.shape == (xm.shape[0], heads, xm.shape[1], xm.shape[1])

In [6]:
# Multi-head output for the first token of the first batch element (all heads concatenated).
values_m[0, 0, :]

tensor([ 1.6464, -0.8931, -0.4826, -1.7495, -0.6201,  0.3286],
       grad_fn=<SliceBackward0>)