In [9]:
# !conda install -y pytorch torchvision torchaudio -c pytorch-nightly
# !conda install -y mpmath
# !conda install -y cython 

Collecting package metadata (current_repodata.json): done
Solving environment: done

# All requested packages already installed.

Collecting package metadata (current_repodata.json): done
Solving environment: done

# All requested packages already installed.

Collecting package metadata (current_repodata.json): done
Solving environment: done

## Package Plan ##

  environment location: /Users/ratneshjamidar/miniconda3

  added / updated specs:
    - cython


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    cython-0.29.35             |  py310h313beb8_0         2.1 MB
    ------------------------------------------------------------
                                           Total:         2.1 MB

The following NEW packages will be INSTALLED:

  cython             pkgs/main/osx-arm64::cython-0.29.35-py310h313beb8_0 



Downloading and Extracting Packages
                                     

# Reference: https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html

In [5]:
sentence = "Always stay curious my son"
vocab = {word: i for i, word in enumerate(sorted(sentence.split(" ")))}
vocab

{'Always': 0, 'curious': 1, 'my': 2, 'son': 3, 'stay': 4}

Dumb Tokenizer, for every word it will give an integer, let's create a tensor 

In [6]:
import torch
tensor = torch.tensor([vocab[word] for word in sentence.split(" ")])
tensor

tensor([0, 4, 1, 2, 3])

Since we have sentence vector, we can now prepare vector embedding, we will pick embedding dimension to be `32`
We will do random initialisation of embedding of size `5*32` , since our tensor size is `5`

In [25]:
torch.manual_seed(33)

#effectively a look table for each vector, it usually have size more than sentence length, equals to vocab size
embed = torch.nn.Embedding(5, 32)
embedded_sentence = embed(tensor).detach()
embedded_sentence.shape

torch.Size([5, 32])

Now we will setup key, query and value matrices W<sub>k</sub>, W<sub>q</sub> and W<sub>v</sub>

Dimensions of  W<sub>k</sub> and  W<sub>q</sub> is  d<sub>k</sub>*d ,  d<sub>q</sub>*d,  d<sub>v</sub>*d
where q==k and d = dimension of token vector


In [34]:
d = embedded_sentence.shape[1]
q, k, v = 36, 36, 48

W_query = torch.nn.Parameter(torch.rand(q,d))
W_key = torch.nn.Parameter(torch.rand(k,d))
W_value = torch.nn.Parameter(torch.rand(v,d))

print(W_query.shape)
print(W_key.shape)
print(W_value.shape)

torch.Size([36, 32])
torch.Size([36, 32])
torch.Size([48, 32])


In [71]:
queries = torch.matmul(W_query,embedded_sentence.T).T
keys = torch.matmul(W_query,embedded_sentence.T).T
values = torch.matmul(W_value,embedded_sentence.T).T

print(queries.shape)
print(keys.shape)

print(values.shape)

torch.Size([5, 36])
torch.Size([5, 36])
torch.Size([5, 48])


In [72]:
# unnormalize attention weights

omega =  queries.matmul(keys.T)
print(omega.shape)

torch.Size([5, 5])


In [52]:
import torch.nn.functional as F
attention_weights = F.softmax(omega/k**(0.5), dim=0)
print(attention_weights.shape)

torch.Size([5, 5])


In [51]:
context_vector = attention_weights.matmul(values)
print(context_vector.shape)

torch.Size([5, 48])


In [53]:
# Let's extend this to multihead attention now 

In [82]:
h = 3

embedded_sentence_multi_head = embedded_sentence.T.repeat(3,1,1)

print("embedded_sentence_multi_head shape {}".format(embedded_sentence_multi_head.shape) )
W_query_multi_head = torch.nn.Parameter(torch.rand(h, q, d))
W_key_multi_head = torch.nn.Parameter(torch.rand(h, k, d))
W_value_multi_head = torch.nn.Parameter(torch.rand(h, v, d))

keys_multi_head = torch.bmm(W_key_multi_head, embedded_sentence_multi_head)
queries_multi_head = torch.bmm(W_query_multi_head, embedded_sentence_multi_head)
values_multi_head = torch.bmm(W_value_multi_head,embedded_sentence_multi_head)
print("keys_multi_head shape {}".format(keys_multi_head.shape) )
print("queries_multi_head shape {}".format(queries_multi_head.shape) )
print("values_multi_head shape {}".format(values_multi_head.shape) )


omega_multi_head = torch.bmm(keys_multi_head.transpose(-2,-1),queries_multi_head)

print("omega_multi_head shape {}".format(omega_multi_head.shape) )

context_vector = torch.bmm(F.softmax(omega_multi_head/k**0.5),values_multi_head.transpose(-1,-2))
print("context_vector shape {}".format(context_vector.shape) )


embedded_sentence_multi_head shape torch.Size([3, 32, 5])
keys_multi_head shape torch.Size([3, 36, 5])
queries_multi_head shape torch.Size([3, 36, 5])
values_multi_head shape torch.Size([3, 48, 5])
omega_multi_head shape torch.Size([3, 5, 5])
context_vector shape torch.Size([3, 5, 48])


  context_vector = torch.bmm(F.softmax(omega_multi_head/k**0.5),values_multi_head.transpose(-1,-2))


Cross attention

In [98]:
# this is decoder input
embedded_sentence_2 = torch.rand(16, 32)

#embedded_sentence is encoder output 


W_key_cross_attention = torch.nn.Parameter(torch.rand(k, d))
W_value_cross_attention = torch.nn.Parameter(torch.rand(v, d))
W_query_cross_attention = torch.nn.Parameter(torch.rand(q, d))

keys_cross_attention = torch.matmul(W_key_cross_attention,embedded_sentence.T )
values_cross_attention = torch.matmul(W_value_cross_attention,embedded_sentence.T )

query_cross_attention =  torch.matmul(W_query_cross_attention,embedded_sentence_2.T )
print("query_cross_attention shape {}".format(query_cross_attention.shape) )

context_vector_cross_attention = query_cross_attention.T.matmul(F.softmax(keys_cross_attention.matmul(values_cross_attention.T)/k**0.5, dim=0))
print("context_vector_cross_attention shape {}".format(context_vector_cross_attention.shape) )


query_cross_attention shape torch.Size([36, 16])
context_vector_cross_attention shape torch.Size([16, 48])
