<a href="https://colab.research.google.com/github/svetaU/Attention/blob/main/SA_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import numpy as np
import math
try:
  import einops
except ModuleNotFoundError: 
  !pip install --quiet einops
from einops import rearrange
import torch
from torch import nn
import torch.nn.functional as F
#try:
#    import pytorch_lightning as pl
#except ModuleNotFoundError: 
#    !pip install --quiet pytorch-lightning>=1.5
#    import pytorch_lightning as pl

# Test data setup

In [2]:
# Toy input x: a batch of 2 sequences, each holding 3 tokens of dimension 3
# (shape (2, 3, 3)). Within each sequence token values grow by 1 per position.
x = torch.tensor([
    [[1., 1., 0.], [2., 2., 1.], [3., 3., 2.]],  # sequence 0
    [[4., 4., 3.], [5., 5., 4.], [6., 6., 5.]],  # sequence 1
])

# Fixed projection weight of shape (9, 3), i.e. (3 * dim, dim) -- the layout
# nn.Linear(dim, 3 * dim) expects for its weight (out_features, in_features).
w = torch.tensor([
    [ 1., -1., -1.],
    [ 1.,  1.,  1.],
    [-1., -1.,  1.],
    [ 1.,  1.,  1.],
    [-1.,  1., -1.],
    [ 1.,  1.,  1.],
    [ 1., -1.,  1.],
    [ 1.,  1.,  1.],
    [-1.,  1.,  1.],
])

In [3]:
# Project every token to 3 * dim features with a bias-free linear layer,
# then split those features into query, key and value tensors.
dim = x.shape[-1]
to_qvk = nn.Linear(dim, 3 * dim, bias=False)
with torch.no_grad():
    # Copying into a Parameter in place has to happen outside autograd; the
    # forward pass also runs here so q, k and v carry no autograd graph.
    to_qvk.weight.copy_(w)
    qkv = to_qvk(x)
    # '(d k)' reads the 3 * dim features as d groups of 3 interleaved entries,
    # so q, k and v receive output features 0::3, 1::3 and 2::3 respectively.
    q, k, v = rearrange(qkv, 'b t (d k) -> k b t d ', k=3).unbind(0)

# Function 1: single-head attention with `torch.matmul`

In [5]:
def single_head_attention_1(q, k, v, mask=None):
    """Single-head scaled dot-product attention built on ``torch.matmul``.

    Computes ``softmax(q @ k^T / sqrt(d)) @ v`` over the last two dimensions;
    leading dimensions are handled by matmul broadcasting.

    Args:
        q: queries, last two dims ``(i, d)``.
        k: keys, last two dims ``(j, d)``.
        v: values, last two dims ``(j, d_v)``.
        mask: optional tensor broadcastable to the ``(..., i, j)`` scores.
            Positions where ``mask == 0`` are blocked.

    Returns:
        ``(values, attention)``: the attention-weighted values and the
        softmax-normalised attention weights.
    """
    scale = math.sqrt(q.shape[-1])
    scores = (q @ k.transpose(-2, -1)) / scale
    if mask is not None:
        # Blocked positions get a huge negative *finite* value, so a row whose
        # positions are all blocked softmaxes to uniform weights rather than NaN.
        scores = scores.masked_fill(mask == 0, -9e15)
    weights = F.softmax(scores, dim=-1)
    return weights @ v, weights

# Function 2: single-head attention with `torch.einsum`

In [6]:
def single_head_attention_2(q, k, v, mask=None):
    """Single-head scaled dot-product attention written with ``torch.einsum``.

    Computes ``softmax(q @ k^T / sqrt(d)) @ v`` over the last two dimensions.
    Any leading dimensions (batch, or batch and heads) are carried through,
    matching the broadcasting behaviour of ``single_head_attention_1``.

    Args:
        q: queries, shape ``(..., i, d)``.
        k: keys, shape ``(..., j, d)``.
        v: values, shape ``(..., j, d_v)``.
        mask: optional boolean tensor of shape exactly ``(i, j)``. Positions
            where the mask is ``True`` are blocked. NOTE: this is the opposite
            of ``single_head_attention_1``, which blocks positions where
            ``mask == 0``.

    Returns:
        ``(values, attention)`` with shapes ``(..., i, d_v)`` and ``(..., i, j)``.

    Raises:
        AssertionError: if ``mask`` is given and its shape is not ``(i, j)``.
    """
    scaled_dot_prod = torch.einsum('... i d , ... j d -> ... i j', q, k) / math.sqrt(q.shape[-1])
    if mask is not None:
        assert mask.shape == scaled_dot_prod.shape[-2:]
        # Use the most negative finite value instead of -inf: partially blocked
        # rows are unchanged (exp underflows to exactly 0), but a fully blocked
        # row now gives uniform weights instead of NaN -- the same outcome as
        # the large negative fill used by single_head_attention_1.
        scaled_dot_prod = scaled_dot_prod.masked_fill(
            mask, torch.finfo(scaled_dot_prod.dtype).min
        )
    attention = torch.softmax(scaled_dot_prod, dim=-1)
    values = torch.einsum('... i j , ... j d -> ... i d', attention, v)
    return values, attention

# Test functions: both implementations should produce identical outputs

In [7]:
values_1, attention_1 = single_head_attention_1(q, k, v)
values_2, attention_2 = single_head_attention_2(q, k, v)
# Check agreement explicitly instead of relying on eyeballing the printed tensors.
torch.testing.assert_close(values_1, values_2)
torch.testing.assert_close(attention_1, attention_2)

# Masked path. The two functions use opposite mask conventions: function 1
# blocks positions where mask == 0, function 2 blocks positions where mask is
# True. So pass a causal "keep" mask to the first and its complement to the second.
seq_len = q.shape[-2]
causal_keep = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
masked_values_1, masked_attention_1 = single_head_attention_1(q, k, v, mask=causal_keep)
masked_values_2, masked_attention_2 = single_head_attention_2(q, k, v, mask=~causal_keep)
torch.testing.assert_close(masked_values_1, masked_values_2)
torch.testing.assert_close(masked_attention_1, masked_attention_2)

In [8]:
# Display the output values of the matmul-based implementation (function 1).
values_1

tensor([[[-2.3632e+00,  3.0897e+00,  3.6324e-01],
         [-2.0585e+00,  2.1756e+00,  5.8529e-02],
         [-2.0100e+00,  2.0299e+00,  9.9600e-03]],

        [[-5.0017e+00,  1.1005e+01,  3.0017e+00],
         [-5.0003e+00,  1.1001e+01,  3.0003e+00],
         [-5.0001e+00,  1.1000e+01,  3.0001e+00]]])

In [9]:
# Display the output values of the einsum-based implementation (function 2);
# expected to match values_1 above.
values_2

tensor([[[-2.3632e+00,  3.0897e+00,  3.6324e-01],
         [-2.0585e+00,  2.1756e+00,  5.8529e-02],
         [-2.0100e+00,  2.0299e+00,  9.9600e-03]],

        [[-5.0017e+00,  1.1005e+01,  3.0017e+00],
         [-5.0003e+00,  1.1001e+01,  3.0003e+00],
         [-5.0001e+00,  1.1000e+01,  3.0001e+00]]])