<a href="https://colab.research.google.com/github/wldud01/Pytorch_tutorial/blob/main/Understanding_self_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
#@title Embedding sentence

# Toy input sentence; tokens are obtained by splitting on whitespace, so
# punctuation stays attached to its word.
sentence = "안녕하세요. 반갑습니다. 인생은 짧고 처음 디저트를 먹습니다."

# Vocabulary lookup: each word maps to its position in the sorted token list.
dc = dict(
    (word, position)
    for position, word in enumerate(sorted(sentence.split()))
)

print(dc)

{'디저트를': 0, '먹습니다.': 1, '반갑습니다.': 2, '안녕하세요.': 3, '인생은': 4, '짧고': 5, '처음': 6}


In [6]:
# Convert the sentence into a tensor of vocabulary indices, one per token.

import torch

sentence_int = torch.tensor([dc[token] for token in sentence.split()])
print(sentence_int)

tensor([3, 2, 4, 5, 6, 0, 1])


In [7]:
# Number of rows in the embedding table.
vocab_size = 50_000

# Seed before creating the layer so its random initial weights are reproducible.
torch.manual_seed(123)

# Embedding layer mapping each token index to a 3-dimensional vector.
embed = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=3)
# detach() drops the autograd graph so the embeddings act as plain inputs.
embedded_sentence = embed(sentence_int).detach()

print(embedded_sentence, embedded_sentence.shape, sep="\n")

tensor([[-1.1925,  0.6984, -1.4097],
        [-0.2196, -0.3792,  0.7671],
        [ 0.1794,  1.8951,  0.4954],
        [ 0.2692, -0.0770, -1.0205],
        [-0.1690,  0.9178,  1.5810],
        [ 0.3374, -0.1778, -0.3035],
        [-0.5880,  0.3486,  0.6603]])
torch.Size([7, 3])


In [14]:
torch.manual_seed(123)

# Dimensionality of each token embedding.
d = embedded_sentence.size(1)
print(d)

# Queries and keys must share a dimension because they are dot-multiplied;
# the value dimension is arbitrary.
d_q = d_k = 2
d_v = 4

# Randomly initialised projection matrices.
W_query = torch.nn.Parameter(torch.rand(d, d_q))  # (3, 2)
W_key = torch.nn.Parameter(torch.rand(d, d_k))  # (3, 2)
W_value = torch.nn.Parameter(torch.rand(d, d_v))  # (3, 4)

print(W_query)

3
Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]], requires_grad=True)


In [13]:
# Second token of the sentence (index 1).
x_2 = embedded_sentence[1]
print(x_2)

# Query, key and value vectors for x_2.
query_2 = torch.matmul(x_2, W_query)
key_2 = torch.matmul(x_2, W_key)
value_2 = torch.matmul(x_2, W_value)

print(query_2.shape, query_2, key_2.shape, value_2.shape, sep="\n")

tensor([-0.2196, -0.3792,  0.7671])
torch.Size([2])
tensor([-0.1037,  0.2902], grad_fn=<SqueezeBackward4>)
torch.Size([2])
torch.Size([4])


In [17]:
# Keys and values for every token of the embedded sentence.
keys = torch.matmul(embedded_sentence, W_key)
values = torch.matmul(embedded_sentence, W_value)

# Show one sample key, then the shapes of both projections.
print(keys[4])
print("key,shape:", keys.shape)
print("values.shape", values.shape)

tensor([0.6443, 1.7357], grad_fn=<SelectBackward0>)
key,shape: torch.Size([7, 2])
values.shape torch.Size([7, 4])


In [18]:
# Unnormalised attention score: dot product of x_2's query with the key at index 4.
print("query_2", query_2)
omega_24 = torch.dot(query_2, keys[4])
print(omega_24)

query_2 tensor([-0.1037,  0.2902], grad_fn=<SqueezeBackward4>)
tensor(0.4368, grad_fn=<DotBackward0>)


In [15]:
# Scores of query_2 against every key: (2,) @ (2, 7) -> (7,)
omega_2 = torch.matmul(query_2, keys.t())
print(omega_2)

tensor([-0.1197,  0.0518,  0.4487, -0.1807,  0.4368, -0.0794,  0.1677],
       grad_fn=<SqueezeBackward4>)


In [19]:
import torch.nn.functional as F

# Scale the scores by sqrt(d_k), then apply softmax to turn them into a
# probability distribution over the tokens.
attention_weights_2 = F.softmax(omega_2.div(d_k ** 0.5), dim=0)
print(attention_weights_2)

tensor([0.1202, 0.1357, 0.1797, 0.1151, 0.1782, 0.1237, 0.1473],
       grad_fn=<SoftmaxBackward0>)


In [20]:
# Context vector: the value vectors summed with the softmax attention weights.
context_vector_2 = torch.matmul(attention_weights_2, values)

print(context_vector_2.shape, context_vector_2, sep="\n")

torch.Size([4])
tensor([0.2634, 0.5714, 0.3120, 0.5370], grad_fn=<SqueezeBackward4>)


In [21]:
#@title Self attention
import torch.nn as nn
class SelfAttention(nn.Module):
  """Single-head scaled dot-product self-attention.

  Projects the input into queries, keys and values with three learnable
  matrices, scores every query against every key, and returns the
  attention-weighted sum of the values.

  Args:
    d_in: size of the last dimension of the input (embedding size).
    d_out_kq: size of the query and key projections.
    d_put_v: size of the value projection, i.e. of each output vector.
      (The name looks like a typo of ``d_out_v``; it is kept unchanged so
      existing keyword callers do not break.)
  """

  def __init__(self,d_in, d_out_kq, d_put_v):
    super().__init__()
    # Stored for the sqrt(d_out_kq) scaling applied in forward().
    self.d_out_kq = d_out_kq
    # Projection weights, initialised with torch.rand.
    self.W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
    self.W_key = nn.Parameter(torch.rand(d_in, d_out_kq))
    self.W_value = nn.Parameter(torch.rand(d_in, d_put_v))

  def forward(self,x):
    """Compute one context vector per input token.

    Args:
      x: tensor of shape (seq_len, d_in), optionally with leading batch
        dimensions, i.e. (..., seq_len, d_in).

    Returns:
      Tensor of shape (..., seq_len, d_put_v).
    """
    keys = x @ self.W_key
    queries = x @ self.W_query
    values = x @ self.W_value

    # Swap only the last two dimensions. For 2-D input this equals `.T`,
    # but unlike `.T` it stays correct when batch dimensions are present.
    attn_scores = queries @ keys.transpose(-2, -1)
    # Scale by sqrt(d_out_kq) and normalise over the key axis.
    attn_weights = torch.softmax(
        attn_scores / self.d_out_kq**0.5, dim=-1
    )
    return attn_weights @ values

In [22]:
torch.manual_seed(123)

# A single attention head with a 4-dimensional value projection.
d_in = 3
d_out_kq = 2
d_out_v = 4

# Build the self-attention module and run it on the whole sentence.
sa = SelfAttention(d_in, d_out_kq, d_out_v)
print(sa(embedded_sentence))

tensor([[-0.4831, -0.2340, -0.3703, -0.5371],
        [ 0.2634,  0.5714,  0.3120,  0.5370],
        [ 0.9139,  1.4875,  1.0045,  1.6377],
        [-0.2496, -0.0196, -0.1716, -0.2275],
        [ 0.9082,  1.4865,  1.0015,  1.6333],
        [ 0.0316,  0.2859,  0.0857,  0.1780],
        [ 0.3908,  0.7338,  0.4397,  0.7393]], grad_fn=<MmBackward0>)


In [23]:
class MultiHeadAttentionWrapper(nn.Module):
  """Runs several SelfAttention heads on the same input and concatenates their outputs."""

  def __init__(self, d_in, d_out_kq, d_out_v, num_heads):
    super().__init__()
    # One SelfAttention module per head, all built with the same dimensions.
    heads = []
    for _ in range(num_heads):
      heads.append(SelfAttention(d_in, d_out_kq, d_out_v))
    self.heads = nn.ModuleList(heads)

  def forward(self, x):
    """Return the outputs of all heads joined along the last dimension."""
    per_head_outputs = [head(x) for head in self.heads]
    return torch.cat(per_head_outputs, dim=-1)

In [24]:
torch.manual_seed(123)

# A single head whose value projection outputs one number per token.
d_in = 3
d_out_kq = 2
d_out_v = 1

sa = SelfAttention(d_in, d_out_kq, d_out_v)
print(sa(embedded_sentence))

tensor([[-0.1627],
        [ 0.1739],
        [ 0.5123],
        [-0.0651],
        [ 0.5108],
        [ 0.0621],
        [ 0.2366]], grad_fn=<MmBackward0>)


In [26]:
torch.manual_seed(123)

# NOTE(review): block_size is not used in this cell, and shape[1] is the
# embedding dimension rather than the sequence length — confirm intent.
block_size = embedded_sentence.shape[1]

# Four heads, each built from the d_in / d_out_kq / d_out_v set earlier.
mha = MultiHeadAttentionWrapper(d_in, d_out_kq, d_out_v, num_heads=4)
context_vecs = mha(embedded_sentence)

print(context_vecs)
print("context vector shape", context_vecs.shape)

tensor([[-0.1627, -0.1817, -0.1925, -0.2708],
        [ 0.1739,  0.4257,  0.4324,  0.1601],
        [ 0.5123,  1.4266,  1.4526,  0.9668],
        [-0.0651, -0.1244,  0.1373, -0.1819],
        [ 0.5108,  1.4200,  1.3082,  0.9486],
        [ 0.0621,  0.0892,  0.3831, -0.0157],
        [ 0.2366,  0.7066,  0.5258,  0.3166]], grad_fn=<CatBackward0>)
context vector shape torch.Size([7, 4])


In [2]:
#@title Cross Attention
import torch.nn as nn

class CrossAttention(nn.Module):
  """Single-head scaled dot-product cross-attention.

  Queries are computed from one sequence while keys and values are computed
  from a second sequence, so each position of the first sequence gets a
  weighted sum of the second sequence's values.

  Args:
    d_in: size of the last dimension of both inputs.
    d_out_qk: size of the query and key projections.
    d_out_v: size of the value projection, i.e. of each output vector.
  """

  def __init__(self, d_in, d_out_qk, d_out_v):
    super().__init__()
    # Attribute keeps its original name; it feeds the sqrt scaling in forward().
    self.d_out_kq = d_out_qk
    # Use the constructor argument, not a same-looking global name, so the
    # weight shapes always agree with the scaling factor.
    self.W_query = nn.Parameter(torch.rand(d_in, d_out_qk))
    self.W_key = nn.Parameter(torch.rand(d_in, d_out_qk))
    self.W_value = nn.Parameter(torch.rand(d_in, d_out_v))

  def forward(self, x_1, x_2):
    """Attend from x_1 (queries) to x_2 (keys and values).

    Args:
      x_1: tensor of shape (..., len_1, d_in) that provides the queries.
      x_2: tensor of shape (..., len_2, d_in) that provides keys and values.

    Returns:
      Tensor of shape (..., len_1, d_out_v).
    """
    # Queries from the first sequence.
    queries_1 = x_1 @ self.W_query

    # Keys and values from the second sequence.
    keys_2 = x_2 @ self.W_key
    values_2 = x_2 @ self.W_value

    # Swap only the last two dimensions. For 2-D input this equals `.T`,
    # but it also stays correct when batch dimensions are present.
    attn_scores = queries_1 @ keys_2.transpose(-2, -1)
    attn_weights = torch.softmax(
        attn_scores / self.d_out_kq**0.5, dim=-1
    )
    return attn_weights @ values_2

In [8]:
import torch
torch.manual_seed(123)

d_in = 3
d_out_kq = 2
d_out_v = 4

crossattn = CrossAttention(d_in, d_out_kq, d_out_v)

# Two inputs of different lengths: the embedded sentence and a random
# 8-row sequence with the same feature size.
first_input = embedded_sentence
second_input = torch.rand(8, d_in)

print('first input shape', first_input.shape)
print("second input shape", second_input.shape)

first input shape torch.Size([7, 3])
second input shape torch.Size([8, 3])


In [10]:
# first_input is passed as x_1 and second_input as x_2.
context_vectors = crossattn(first_input, second_input)

print(context_vectors)
print("output shape", context_vectors.shape)

tensor([[0.3860, 0.8021, 0.5985, 0.9250],
        [0.4357, 0.8886, 0.6678, 1.0311],
        [0.4874, 0.9718, 0.7359, 1.1353],
        [0.4054, 0.8359, 0.6258, 0.9667],
        [0.4863, 0.9709, 0.7347, 1.1336],
        [0.4231, 0.8665, 0.6503, 1.0042],
        [0.4429, 0.9006, 0.6775, 1.0460]], grad_fn=<MmBackward0>)
output shape torch.Size([7, 4])


In [19]:
#@title Causal attention

torch.manual_seed(123)

d_in = 3
d_out_kq = 2
d_out_v = 4

# Random projection matrices, reproducible thanks to the seed above.
W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
W_key = nn.Parameter(torch.rand(d_in, d_out_kq))
W_value = nn.Parameter(torch.rand(d_in, d_out_v))

x = embedded_sentence

keys = torch.matmul(x, W_key)
queries = torch.matmul(x, W_query)
values = torch.matmul(x, W_value)

# Unmasked scores and weights: no position is hidden yet.
attn_scores = torch.matmul(queries, keys.t())
attn_weights = torch.softmax(attn_scores.div(d_out_kq ** 0.5), dim=1)

print(attn_scores, attn_scores.shape, "", sep="\n")
print("Attention weights" ,"\n",attn_weights)

tensor([[ 0.9265, -0.3509, -2.5037,  1.0740, -2.5363,  0.4344, -0.9315],
        [-0.1197,  0.0518,  0.4487, -0.1807,  0.4368, -0.0794,  0.1677],
        [-1.3374,  0.4991,  3.4707, -1.5023,  3.5360, -0.6004,  1.2903],
        [ 0.4730, -0.1851, -1.3934,  0.5869, -1.3953,  0.2432, -0.5191],
        [-1.2598,  0.4810,  3.4806, -1.4859,  3.5151, -0.6049,  1.2954],
        [ 0.1076, -0.0437, -0.3491,  0.1443, -0.3454,  0.0613, -0.1303],
        [-0.2787,  0.1112,  0.8626, -0.3597,  0.8584, -0.1510,  0.3216]],
       grad_fn=<MmBackward0>)
torch.Size([7, 7])

Attention weights 
 tensor([[0.2729, 0.1106, 0.0241, 0.3028, 0.0236, 0.1927, 0.0733],
        [0.1202, 0.1357, 0.1797, 0.1151, 0.1782, 0.1237, 0.1473],
        [0.0133, 0.0489, 0.3995, 0.0119, 0.4184, 0.0225, 0.0855],
        [0.2178, 0.1368, 0.0582, 0.2360, 0.0581, 0.1851, 0.1080],
        [0.0141, 0.0484, 0.4035, 0.0120, 0.4135, 0.0224, 0.0861],
        [0.1616, 0.1452, 0.1170, 0.1659, 0.1173, 0.1564, 0.1366],
        [0.0965, 0.1272, 0.2163, 0.0911, 0.2157, 0.1056, 0.1476]],
       grad_fn=<SoftmaxBackward0>)

In [20]:
# Number of tokens, i.e. the side length of the square score matrix.
block_size = attn_scores.size(0)
# Lower-triangular matrix of ones: 1 on and below the diagonal, 0 above it.
mask_simple = torch.ones(block_size, block_size).tril()
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1.]])


In [21]:
# Multiply the attention weights element-wise by the lower-triangular mask.
# The zeros above the diagonal hide the future positions.
masked_simple = torch.mul(attn_weights, mask_simple)
print(masked_simple)

tensor([[0.2729, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1202, 0.1357, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0133, 0.0489, 0.3995, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2178, 0.1368, 0.0582, 0.2360, 0.0000, 0.0000, 0.0000],
        [0.0141, 0.0484, 0.4035, 0.0120, 0.4135, 0.0000, 0.0000],
        [0.1616, 0.1452, 0.1170, 0.1659, 0.1173, 0.1564, 0.0000],
        [0.0965, 0.1272, 0.2163, 0.0911, 0.2157, 0.1056, 0.1476]],
       grad_fn=<MulBackward0>)


In [23]:
# Sum along each row (dim=1); keepdim gives a column that broadcasts below.
row_sums = torch.sum(masked_simple, dim=1, keepdim=True)
# Divide every row by its own sum so the remaining weights sum to 1 again.
masked_simple_norm = masked_simple.div(row_sums)

print(row_sums, masked_simple_norm, sep="\n")

tensor([[0.2729],
        [0.2559],
        [0.4617],
        [0.6488],
        [0.8915],
        [0.8634],
        [1.0000]], grad_fn=<SumBackward1>)
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4697, 0.5303, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0289, 0.1058, 0.8653, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3357, 0.2108, 0.0897, 0.3638, 0.0000, 0.0000, 0.0000],
        [0.0158, 0.0543, 0.4526, 0.0135, 0.4638, 0.0000, 0.0000],
        [0.1872, 0.1682, 0.1355, 0.1921, 0.1359, 0.1812, 0.0000],
        [0.0965, 0.1272, 0.2163, 0.0911, 0.2157, 0.1056, 0.1476]],
       grad_fn=<DivBackward0>)


In [24]:
# More efficient alternative: mask the scores before the softmax.

# Strictly upper-triangular matrix of ones (diagonal excluded).
mask = torch.ones(block_size, block_size).triu(diagonal=1)
print(mask)

# Replace the scores wherever the mask is 1 with negative infinity.
masked = attn_scores.masked_fill(mask.bool(), float("-inf"))
print(masked)

tensor([[0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0.]])
tensor([[ 0.9265,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.1197,  0.0518,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-1.3374,  0.4991,  3.4707,    -inf,    -inf,    -inf,    -inf],
        [ 0.4730, -0.1851, -1.3934,  0.5869,    -inf,    -inf,    -inf],
        [-1.2598,  0.4810,  3.4806, -1.4859,  3.5151,    -inf,    -inf],
        [ 0.1076, -0.0437, -0.3491,  0.1443, -0.3454,  0.0613,    -inf],
        [-0.2787,  0.1112,  0.8626, -0.3597,  0.8584, -0.1510,  0.3216]],
       grad_fn=<MaskedFillBackward0>)


In [25]:
# Softmax over the masked scores: exp(-inf) is 0, so the masked positions
# receive zero weight.
attn_weights = masked.div(d_out_kq ** 0.5).softmax(dim=1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4697, 0.5303, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0289, 0.1058, 0.8653, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3357, 0.2108, 0.0897, 0.3638, 0.0000, 0.0000, 0.0000],
        [0.0158, 0.0543, 0.4526, 0.0135, 0.4638, 0.0000, 0.0000],
        [0.1872, 0.1682, 0.1355, 0.1921, 0.1359, 0.1812, 0.0000],
        [0.0965, 0.1272, 0.2163, 0.0911, 0.2157, 0.1056, 0.1476]],
       grad_fn=<SoftmaxBackward0>)
