<a href="https://colab.research.google.com/github/teelch0/Data-Mining/blob/main/trainable_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Attention with trainable weights

In [32]:
import torch

In [33]:
inputs = torch.nn.Embedding( 4, 8 )

In [34]:
inputs = inputs.weight
inputs

Parameter containing:
tensor([[-1.2806,  0.4944,  0.4300,  0.0465,  1.4924,  0.2436, -1.0622, -0.4265],
        [-2.4479, -1.0193, -0.2891,  0.4367,  1.2408, -0.8976, -1.7771, -1.0501],
        [ 1.3180, -0.5762,  0.0250,  0.8471,  1.0583,  0.0349,  0.8652,  1.2088],
        [ 0.2441, -1.5191,  0.0631,  0.8979, -0.9527,  2.3419,  1.3389, -0.2293]],
       requires_grad=True)

In [35]:
inputs = inputs.data
inputs

tensor([[-1.2806,  0.4944,  0.4300,  0.0465,  1.4924,  0.2436, -1.0622, -0.4265],
        [-2.4479, -1.0193, -0.2891,  0.4367,  1.2408, -0.8976, -1.7771, -1.0501],
        [ 1.3180, -0.5762,  0.0250,  0.8471,  1.0583,  0.0349,  0.8652,  1.2088],
        [ 0.2441, -1.5191,  0.0631,  0.8979, -0.9527,  2.3419,  1.3389, -0.2293]])

In [36]:
# set dimensions
d_in = 8
d_out = 6

# create weight matrices
W_q = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
W_k = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
W_v = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )

In [37]:
# choose an input vector and transform it into our query vector using W_q
query = inputs[2] @ W_q
query

tensor([1.5024, 2.4207, 2.9022, 2.1201, 2.3435, 2.8275])

In [38]:
# calculate attention scores using the keys generated by W_k:
keys = inputs @ W_k
values = inputs @ W_v
print("Keys:", keys)
print("Values:", values )

Keys: tensor([[-0.5073,  1.2651, -0.1021,  0.8676, -0.5303,  0.0434],
        [-3.9572, -0.3074, -3.3211, -2.9162, -2.4289, -2.8603],
        [ 3.3206,  1.5578,  2.4214,  2.4110,  2.3086,  3.4707],
        [ 1.0083, -0.9660,  0.0181,  0.8910,  0.4738,  2.5151]])
Values: tensor([[-3.5163e-01,  1.0783e-01, -3.1948e-01,  1.3662e+00, -2.8707e-03,
          5.8443e-03],
        [-2.1639e+00, -2.3352e+00, -3.9932e+00, -1.9737e+00, -3.2772e+00,
         -3.0560e+00],
        [ 2.9322e+00,  7.6027e-01,  2.1532e+00,  1.8534e+00,  1.5613e+00,
          3.6199e+00],
        [ 6.9703e-01,  5.8646e-01, -7.3323e-01,  1.0819e-01,  1.4628e+00,
          5.2351e-02]])


In [39]:
attention_scores = query @ keys.T
attention_scores

tensor([  2.7231, -36.2900,  36.1226,   9.3399])

In [40]:
attention_weights = torch.softmax( attention_scores / keys.shape[-1]**0.5, dim = -1 )
attention_weights

tensor([1.1974e-06, 1.4496e-13, 9.9998e-01, 1.7841e-05])

In [41]:
attention_weights.sum()

tensor(1.0000)

In [42]:
context_vector = attention_weights @ values
context_vector

tensor([2.9321, 0.7603, 2.1531, 1.8534, 1.5613, 3.6199])

In [43]:
import torch.nn as nn


In [44]:
# here's a first version of a SimpleAttention class:

class SimpleAttention( nn.Module ):
  def __init__(self, d_in, d_out):
    super().__init__()
    # create weight matrices:
    self.W_q = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_k = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_v = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )

  # x = embedding vectors (inputs)
  def forward( self, x ):
    queries = x @ self.W_q
    keys = x @ self.W_k
    values = x @ self.W_v
    scores = queries @ keys.T
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    context = weights @ values
    return context

In [45]:
# here's how to use this class:
# instantiate an instance of it:
simple = SimpleAttention( d_in = 8, d_out = 6 )

In [46]:
simple.W_v

Parameter containing:
tensor([[0.7293, 0.8502, 0.6565, 0.4199, 0.1735, 0.9826],
        [0.3794, 0.2877, 0.7248, 0.0818, 0.8074, 0.6125],
        [0.8329, 0.3381, 0.4947, 0.1881, 0.6524, 0.6146],
        [0.1393, 0.7182, 0.7430, 0.3475, 0.9769, 0.4539],
        [0.0665, 0.1510, 0.0243, 0.3057, 0.1000, 0.6977],
        [0.5152, 0.1690, 0.9540, 0.0295, 0.7460, 0.0525],
        [0.6649, 0.5394, 0.1135, 0.1347, 0.3084, 0.3127],
        [0.2835, 0.2572, 0.6093, 0.0678, 0.2464, 0.7825]])

In [47]:
context_vectors = simple( inputs )
context_vectors

tensor([[ 1.7185,  1.8073,  1.8980,  0.7321,  1.4871,  1.2243],
        [-4.2112, -3.3516, -3.8323, -0.9717, -2.3623, -3.5680],
        [ 1.8876,  2.5140,  1.9831,  1.3276,  1.3042,  3.2954],
        [ 1.8796,  2.4783,  1.9795,  1.2974,  1.3141,  3.1905]])

In [48]:
# here's a second version of a SimpleAttention class ;
# it uses nn.Linear to do things more efficiently

class SimpleAttention( nn.Module ):
  def __init__(self, d_in, d_out):
    super().__init__()
    # create weight matrices:
    self.W_q = nn.Linear( d_in, d_out, bias=False )
    self.W_k = nn.Linear( d_in, d_out, bias=False )
    self.W_v = nn.Linear( d_in, d_out, bias=False )

  # x = embedding vectors (inputs)
  def forward( self, x ):
    queries = self.W_q( x )
    keys = self.W_k( x )
    values = self.W_v( x )
    scores = queries @ keys.T
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    context = weights @ values
    return weights

In [49]:
# here's how to use this class:
# instantiate an instance of it:
simple = SimpleAttention( d_in = 8, d_out = 6 )

In [50]:
context_vectors = simple( inputs )
context_vectors

tensor([[0.2009, 0.1474, 0.3745, 0.2772],
        [0.1809, 0.1300, 0.3843, 0.3048],
        [0.2137, 0.2182, 0.1984, 0.3697],
        [0.2475, 0.3959, 0.1855, 0.1711]], grad_fn=<SoftmaxBackward0>)

In [51]:
# the problem with this is that each context vector uses information from ALL of the embedding vectors
# in practice, we should only use information about the preceding embedding vectors
# to accomplish this, we'll implement causal attention AKA masked attention

In [58]:
# this is a hack to get some example weights to work with!
weights = simple( inputs )
weights

tensor([[0.2009, 0.1474, 0.3745, 0.2772],
        [0.1809, 0.1300, 0.3843, 0.3048],
        [0.2137, 0.2182, 0.1984, 0.3697],
        [0.2475, 0.3959, 0.1855, 0.1711]], grad_fn=<SoftmaxBackward0>)

In [53]:
# note that these have already been normalized:
weights.sum( dim=-1 )

tensor([1.2787, 1.4662, 1.1613, 1.2218], grad_fn=<SumBackward1>)

In [54]:
# masking method #1
simple_mask = torch.tril( torch.ones( weights.shape[0], weights.shape[0] ) )
simple_mask

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [59]:
masked_weights = weights*simple_mask
masked_weights

tensor([[0.2009, 0.0000, 0.0000, 0.0000],
        [0.1809, 0.1300, 0.0000, 0.0000],
        [0.2137, 0.2182, 0.1984, 0.0000],
        [0.2475, 0.3959, 0.1855, 0.1711]], grad_fn=<MulBackward0>)

In [60]:
masked_weights.sum( dim=-1 )

tensor([0.2009, 0.3109, 0.6303, 1.0000], grad_fn=<SumBackward1>)

In [61]:
# now, we need to normalize the masked_weights so that each row has sum 1
row_sums = masked_weights.sum( dim=-1, keepdim=True)
row_sums

tensor([[0.2009],
        [0.3109],
        [0.6303],
        [1.0000]], grad_fn=<SumBackward1>)

In [62]:
masked_weights = masked_weights / row_sums
masked_weights.sum( dim=-1)

tensor([1., 1., 1., 1.], grad_fn=<SumBackward1>)

In [63]:
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5820, 0.4180, 0.0000, 0.0000],
        [0.3391, 0.3462, 0.3147, 0.0000],
        [0.2475, 0.3959, 0.1855, 0.1711]], grad_fn=<DivBackward0>)

In [64]:
# masking method #2
mask = torch.triu( torch.ones(weights.shape[0], weights.shape[0]), diagonal = 1 )
mask

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])

In [65]:
mask.bool()

tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])

In [66]:
weights

tensor([[0.2009, 0.1474, 0.3745, 0.2772],
        [0.1809, 0.1300, 0.3843, 0.3048],
        [0.2137, 0.2182, 0.1984, 0.3697],
        [0.2475, 0.3959, 0.1855, 0.1711]], grad_fn=<SoftmaxBackward0>)

In [67]:
weights = weights.masked_fill( mask.bool(), -torch.inf )
weights

tensor([[0.2009,   -inf,   -inf,   -inf],
        [0.1809, 0.1300,   -inf,   -inf],
        [0.2137, 0.2182, 0.1984,   -inf],
        [0.2475, 0.3959, 0.1855, 0.1711]], grad_fn=<MaskedFillBackward0>)

In [68]:
masked_weights = torch.softmax( weights, dim=-1 )
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5127, 0.4873, 0.0000, 0.0000],
        [0.3345, 0.3360, 0.3294, 0.0000],
        [0.2484, 0.2881, 0.2334, 0.2301]], grad_fn=<SoftmaxBackward0>)

In [69]:
## Dropout
# idea: randomly select some data to leave out to avoid overfitting
dropout = nn.Dropout( 0.5 )

In [70]:
dropout( masked_weights )

tensor([[0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6589, 0.0000],
        [0.0000, 0.5762, 0.0000, 0.0000]], grad_fn=<MulBackward0>)

In [71]:
# we need to be able to give our LLM batches of input
# for example:
batches = torch.stack( (inputs, inputs), dim=0)

In [72]:
batches.shape

torch.Size([2, 4, 8])

In [73]:
# this class needs to handle batches of input!

class CausalAttention( nn.Module ):
  def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
    super().__init__()
    # create weight matrices:
    self.W_q = nn.Linear( d_in, d_out, bias=False )
    self.W_k = nn.Linear( d_in, d_out, bias=False )
    self.W_v = nn.Linear( d_in, d_out, bias=False )
    self.dropout = Dropout( dropout )

  # x = embedding vectors (inputs)
  def forward( self, x ):
    queries = self.W_q( x )
    keys = self.W_k( x )
    values = self.W_v( x )
    scores = queries @ keys.T
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    context = weights @ values
    return context