<a href="https://colab.research.google.com/github/teelch0/Data-Mining/blob/main/trainable_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Attention with trainable weights

In [None]:
import torch

In [None]:
# build a small embedding table: 4 rows, each an 8-dimensional vector
inputs = torch.nn.Embedding( num_embeddings=4, embedding_dim=8 )

In [None]:
# replace the Embedding layer with its weight matrix; as the output shows,
# this is an nn.Parameter that still has requires_grad=True
inputs = inputs.weight
inputs

Parameter containing:
tensor([[-1.1118, -2.4002,  0.7627,  1.0455, -3.1239,  0.6044, -0.7964, -0.0549],
        [ 0.5018,  1.6092,  1.3768,  1.4527, -0.0953,  0.3263,  0.5369, -0.8991],
        [-0.2023,  1.3612, -0.0531, -1.6946,  0.9287, -0.4567, -2.6515,  0.0233],
        [-0.3179,  0.6661, -0.9793, -0.6810, -0.1212,  1.0293,  0.6244, -0.1560]],
       requires_grad=True)

In [None]:
# drop autograd tracking so the embeddings behave like plain input data;
# .detach() is the supported way to do this (the legacy .data attribute
# bypasses autograd's in-place modification checks)
inputs = inputs.detach()
inputs

tensor([[-1.1118, -2.4002,  0.7627,  1.0455, -3.1239,  0.6044, -0.7964, -0.0549],
        [ 0.5018,  1.6092,  1.3768,  1.4527, -0.0953,  0.3263,  0.5369, -0.8991],
        [-0.2023,  1.3612, -0.0531, -1.6946,  0.9287, -0.4567, -2.6515,  0.0233],
        [-0.3179,  0.6661, -0.9793, -0.6810, -0.1212,  1.0293,  0.6244, -0.1560]])

In [None]:
# input size of the projections and the size they project to
d_in, d_out = 8, 6

# draw the query, key and value projection matrices (in that order) with torch.rand;
# requires_grad=False keeps them out of autograd for this step-by-step walkthrough
W_q, W_k, W_v = (
    torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    for _ in range( 3 )
)

In [None]:
# pick one input embedding (row index 2) and project it with W_q to obtain its query vector
query = torch.matmul( inputs[2], W_q )
query

tensor([-1.3401, -0.9156, -1.3707, -3.3514, -2.1593, -3.4785])

In [None]:
# project every input row with W_k and W_v: one key vector and one value vector per input
# (the keys are used for the attention scores in the next step)
keys, values = torch.matmul( inputs, W_k ), torch.matmul( inputs, W_v )
print("Keys:", keys)
print("Values:", values )

Keys: tensor([[-1.8908, -1.7768, -4.0872, -3.7123, -2.3631, -3.1359],
        [ 2.6818,  1.2621,  3.0026,  2.7216,  3.0776,  2.9235],
        [-0.2769, -0.0932, -1.1301,  0.9661, -1.2257, -2.0170],
        [-0.3507,  0.5743,  0.0617,  0.1077, -0.7202,  0.3298]])
Values: tensor([[-1.3290, -2.8374, -1.8810, -1.9154, -0.9527, -3.7274],
        [ 3.2560,  2.2823,  1.6171,  3.2857,  1.1398,  1.3761],
        [-2.7452, -2.8251,  0.1613, -0.7266, -0.4205,  0.1436],
        [-0.5715, -0.0403, -1.0059,  0.4403,  0.0192, -0.0224]])


In [None]:
# dot product of the query with every key: one raw (unnormalised) score per input row
attention_scores = torch.matmul( query, keys.t() )
attention_scores

tensor([ 38.2147, -34.8005,   8.4305,  -0.0933])

In [None]:
# divide the scores by sqrt(key dimension), then softmax them into weights
attention_weights = ( attention_scores / keys.shape[-1]**0.5 ).softmax( dim = -1 )
attention_weights

tensor([9.9999e-01, 1.1335e-13, 5.2390e-06, 1.6143e-07])

In [None]:
# sanity check: the softmax output should add up to 1
attention_weights.sum()

tensor(1.)

In [None]:
# weighted sum of the value vectors, using the attention weights as coefficients
context_vector = torch.matmul( attention_weights, values )
context_vector

tensor([-1.3290, -2.8374, -1.8809, -1.9154, -0.9527, -3.7274])

In [None]:
import torch.nn as nn


In [None]:
# here's a first version of a SimpleAttention class:

class SimpleAttention( nn.Module ):
  """Single-head scaled dot-product self-attention with explicit weight matrices.

  W_q, W_k and W_v are (d_in, d_out) matrices filled by torch.rand and
  registered as nn.Parameter with requires_grad=False, so they belong to the
  module's state but autograd does not track them.
  """

  def __init__(self, d_in, d_out):
    """Create the query, key and value projection matrices.

    Args:
      d_in: size of the last dimension of the input vectors.
      d_out: size of the projected queries, keys and values.
    """
    super().__init__()
    # create weight matrices:
    self.W_q = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_k = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_v = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )

  # x = embedding vectors (inputs)
  def forward( self, x ):
    """Compute one context vector for every input vector.

    Args:
      x: tensor of shape (n, d_in), or (..., n, d_in) with leading batch
        dimensions.

    Returns:
      Tensor of shape (n, d_out), or (..., n, d_out) for batched input.
    """
    queries = x @ self.W_q
    keys = x @ self.W_k
    values = x @ self.W_v
    # swap only the last two axes; for 2-D input this equals keys.T, but it
    # also works when x carries leading batch dimensions
    scores = queries @ keys.transpose( -2, -1 )
    # scale by sqrt(d_out) before normalising each row of scores
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    context = weights @ values
    return context

In [None]:
# here's how to use this class:
# create an instance that maps 8-dimensional inputs to 6-dimensional outputs:
simple = SimpleAttention( d_in = 8, d_out = 6 )

In [None]:
# inspect the value projection matrix of the instance (d_in rows, d_out columns)
simple.W_v

Parameter containing:
tensor([[0.5806, 0.3486, 0.7062, 0.2142, 0.3308, 0.3980],
        [0.3569, 0.6781, 0.7178, 0.3885, 0.8321, 0.2606],
        [0.1146, 0.6278, 0.8210, 0.7774, 0.3684, 0.2445],
        [0.8441, 0.0332, 0.2504, 0.2416, 0.6393, 0.2418],
        [0.4270, 0.3990, 0.1195, 0.6528, 0.9573, 0.8969],
        [0.6678, 0.7807, 0.8325, 0.2539, 0.0949, 0.4308],
        [0.3899, 0.5472, 0.2075, 0.4764, 0.0356, 0.9468],
        [0.7345, 0.1856, 0.0201, 0.6234, 0.8641, 0.4271]])

In [None]:
# run the input embeddings through the module;
# each output row is the context vector for the matching input row
context_vectors = simple( inputs )
context_vectors

tensor([[-1.8135, -2.7202, -1.6556, -2.6228, -4.4199, -3.9459],
        [ 1.9758,  2.5223,  3.3570,  1.8698,  2.1228,  1.4861],
        [-1.8134, -2.7201, -1.6555, -2.6228, -4.4198, -3.9458],
        [ 0.8819,  1.4235,  1.9262,  0.7932,  0.9448,  0.4997]])

In [None]:
# here's a second version of the SimpleAttention class;
# it uses nn.Linear (with bias=False) to do things more efficiently,
# and its weights track gradients (note the grad_fn in the output below)

class SimpleAttention( nn.Module ):
  """Single-head scaled dot-product self-attention built from nn.Linear layers.

  The query, key and value projections are bias-free nn.Linear layers mapping
  d_in features to d_out features.
  """

  def __init__(self, d_in, d_out):
    """Create the query, key and value projection layers.

    Args:
      d_in: size of the last dimension of the input vectors.
      d_out: size of the projected queries, keys and values.
    """
    super().__init__()
    # create weight matrices:
    self.W_q = nn.Linear( d_in, d_out, bias=False )
    self.W_k = nn.Linear( d_in, d_out, bias=False )
    self.W_v = nn.Linear( d_in, d_out, bias=False )

  # x = embedding vectors (inputs)
  def forward( self, x ):
    """Compute one context vector for every input vector.

    Args:
      x: tensor of shape (n, d_in), or (..., n, d_in) with leading batch
        dimensions.

    Returns:
      Tensor of shape (n, d_out), or (..., n, d_out) for batched input.
    """
    queries = self.W_q( x )
    keys = self.W_k( x )
    values = self.W_v( x )
    # swap only the last two axes; for 2-D input this equals keys.T, but it
    # also works when x carries leading batch dimensions
    scores = queries @ keys.transpose( -2, -1 )
    # scale by sqrt(d_out) before normalising each row of scores
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    context = weights @ values
    return context

In [None]:
# here's how to use this class:
# create an instance that maps 8-dimensional inputs to 6-dimensional outputs:
simple = SimpleAttention( d_in = 8, d_out = 6 )

In [None]:
# run the input embeddings through the module; each output row is the context
# vector for the matching input row (the printed result carries a grad_fn)
context_vectors = simple( inputs )
context_vectors

tensor([[-1.5887e-01, -1.8098e-01,  1.0701e-01, -2.4173e-01,  1.6169e-01,
          3.2807e-01],
        [ 1.7067e-02, -2.1231e-01,  9.7150e-02, -1.3696e-01,  2.2546e-01,
          3.6515e-01],
        [ 1.8039e-01, -1.5505e-01,  6.4141e-02, -4.3418e-04,  3.4362e-01,
          3.7114e-01],
        [ 3.3379e-01, -8.7434e-01,  1.5158e-01,  2.9614e-01, -4.6630e-01,
          3.2409e-02]], grad_fn=<MmBackward0>)