#### Notes on self-attention, one of the building blocks of Transformers

Some references:
- https://arxiv.org/abs/2207.09238
- https://arxiv.org/pdf/2105.09121.pdf
- https://johnthickstun.com/docs/transformers.pdf
- https://peterbloem.nl/blog/transformers
- https://github.com/mlvu/worksheets/blob/master/Worksheet%205%2C%20Pytorch.ipynb
- https://github.com/pbloem/former/issues/4
- https://towardsdatascience.com/illustrated-guide-to-transformers-step-by-step-explanation-f74876522bc0


A statistical model is autoregressive if it predicts future values from past values.
Cross-attention is computed exactly like self-attention; only the inputs differ.


Cross-attention asymmetrically combines two separate embedding sequences of the same dimension:
one supplies the queries and the other supplies the keys and values (KV).
Self-attention, in contrast, takes a single embedding sequence as input.
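
The difference between the two can be sketched in a few lines. This is a minimal, unparameterized illustration (no learned query, key, or value projections); the names `context` and `attention` and the tensor sizes are assumptions made for the example, not part of the notebook.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

b, k = 2, 8          # batch size and embedding dimension (illustrative values)
t_q, t_kv = 5, 7     # the two sequences may have different lengths

x = torch.rand(b, t_q, k)         # sequence that supplies the queries
context = torch.rand(b, t_kv, k)  # hypothetical second sequence supplying keys and values

def attention(q_in, kv_in):
    """Unparameterized scaled dot-product attention of q_in over kv_in."""
    scores = torch.bmm(q_in, kv_in.transpose(1, 2)) / (q_in.size(-1) ** 0.5)
    weights = F.softmax(scores, dim=2)   # each query's weights sum to 1
    return torch.bmm(weights, kv_in)     # (b, t_q, k)

self_out = attention(x, x)          # self-attention: one input sequence
cross_out = attention(x, context)   # cross-attention: two input sequences
```

In both cases the output has one vector per query position, so its length follows the query sequence, not the key/value sequence.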

Read up:

- https://doordash.engineering/2022/09/08/evolving-doordashs-substitution-recommendations-algorithm/
- https://arxiv.org/pdf/2205.11728.pdf
- https://arxiv.org/pdf/2102.06156.pdf

In [4]:
# Seattle, 3 Apr 2023
# Python 3.10.9 (Conda)

import torch
from torch import nn
import torch.nn.functional as F
from pytorch_transformers import BertTokenizer  # legacy name of the Hugging Face transformers package



b, t, k = 10, 20, 30
# We represent the input, a sequence of t vectors of dimension k,
# as a t-by-k matrix 𝐗. Adding a minibatch dimension b gives
# an input tensor of size (b, t, k).

# t is the maximum sentence length; k is the dimension of the input vectors.

# Masking plays an important role in the transformer. It serves two purposes:
# - In the encoder and decoder: to zero attention outputs wherever the input sentences contain only padding.
# - In the decoder: to prevent the decoder from 'peeking' ahead at the rest of the translated
#   sentence when predicting the next word.



# assume we have some tensor x with size (b, t, k)
x = torch.rand(b,t,k)

raw_weights = torch.bmm(x, x.transpose(1, 2))
# Equivalent einsum form:
#raw_weights = torch.einsum('btk,bsk->bts', x, x)
weights = F.softmax(raw_weights, dim=2)
raw_weights.shape, x.transpose(1, 2).shape




(torch.Size([10, 20, 20]), torch.Size([10, 30, 20]))
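
The decoder-side masking described in the comments above can be sketched on the same kind of raw weights. This is a minimal illustration under assumed small sizes; the variable names `future` and `masked` are introduced here for the example only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, t, k = 2, 4, 3                 # small illustrative sizes
x = torch.rand(b, t, k)

raw_weights = torch.bmm(x, x.transpose(1, 2))          # (b, t, t)

# True strictly above the diagonal, i.e. where key position j lies after query position i.
future = torch.triu(torch.ones(t, t, dtype=torch.bool), diagonal=1)

# Scores set to -inf receive exactly zero weight after the softmax.
masked = raw_weights.masked_fill(future, float('-inf'))
weights = F.softmax(masked, dim=2)

# Apply the weights to the inputs; y[:, i] depends only on x[:, :i + 1].
y = torch.bmm(weights, x)                              # (b, t, k)
```

Because the first position can only attend to itself, its output equals its input, which is a quick sanity check that the mask is oriented correctly.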

In [6]:
class SelfAttention(nn.Module):

    def __init__(self, k, heads=4, mask=False):
      
        super().__init__()
    
        # The embedding dimension must be divisible by the number of heads.
        assert k % heads == 0

        # heads > 1 gives multi-head self-attention.
        # Note: the `mask` argument is accepted but not applied anywhere in forward().

        self.k, self.heads = k, heads

        # These compute the queries, keys and values for all
        # heads
        self.tokeys    = nn.Linear(k, k, bias=False)
        self.toqueries = nn.Linear(k, k, bias=False)
        self.tovalues  = nn.Linear(k, k, bias=False)

        # This will be applied after the multi-head self-attention operation.
        self.unifyheads = nn.Linear(k, k)

    def forward(self, x):

        b, t, k = x.size()
        h = self.heads

        queries = self.toqueries(x)
        keys    = self.tokeys(x)   
        values  = self.tovalues(x)

        s = k // h

        keys    = keys.view(b, t, h, s)
        queries = queries.view(b, t, h, s)
        values  = values.view(b, t, h, s)

        # - fold heads into the batch dimension
        keys = keys.transpose(1, 2).contiguous().view(b * h, t, s)
        queries = queries.transpose(1, 2).contiguous().view(b * h, t, s)
        values = values.transpose(1, 2).contiguous().view(b * h, t, s)

        # Option 1: batched matrix multiplication on the folded tensors

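        # Dot product of queries and keys
        dot = torch.bmm(queries, keys.transpose(1, 2))
        # - dot has size (b*h, t, t) and contains the raw weights

        # Queries and keys together determine the attention weights.

        # Scale the dot product. This divides by the square root of the full
        # embedding dimension k; "Attention Is All You Need" scales by the
        # square root of the per-head key dimension (s here).
        dot = dot / (k ** (1/2))

        # Normalize
        dot = F.softmax(dot, dim=2)
        # - dot now contains row-wise normalized weights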

        # apply the self attention to the values
        out = torch.bmm(dot, values).view(b, h, t, s)

        # swap h, t back, unify heads
        out = out.transpose(1, 2).contiguous().view(b, t, s * h)


        # Option 2 (einsum; can be faster). It operates on the unfolded
        # (b, t, h, s) tensors, i.e. before the heads are folded into the batch dimension.

        #dot = torch.einsum('bthe,bihe->bhti', queries, keys) / (k ** (1/2))
        #dot = F.softmax(dot, dim=-1)
        #out = torch.einsum('bhti,bihe->bthe', dot, values)
        #out = torch.einsum('bthe,khe->btk', out, self.unifyheads.weight.view(k, h, s))
        #return out + self.unifyheads.bias
    
        return self.unifyheads(out)
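
The two routes in `forward` (folding heads into the batch dimension for `bmm`, or using `einsum` on the unfolded tensors) should give the same attention output. The following is a small standalone check on random tensors standing in for the per-head queries, keys, and values. It omits the learned projections, and it scales by the per-head dimension `s`, which is an assumption of this sketch and differs from the `k ** (1/2)` used in the class; the helper `fold` is also hypothetical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, t, h, s = 2, 5, 4, 3           # batch, sequence length, heads, per-head dimension

# Stand-ins for the per-head projections, before heads are folded into the batch.
queries = torch.rand(b, t, h, s)
keys    = torch.rand(b, t, h, s)
values  = torch.rand(b, t, h, s)

def fold(z):
    # (b, t, h, s) -> (b*h, t, s)
    return z.transpose(1, 2).contiguous().view(b * h, t, s)

# Route 1: fold heads into the batch dimension and use bmm.
dot = torch.bmm(fold(queries), fold(keys).transpose(1, 2)) / (s ** 0.5)
dot = F.softmax(dot, dim=2)
out_bmm = torch.bmm(dot, fold(values)).view(b, h, t, s)
out_bmm = out_bmm.transpose(1, 2).contiguous().view(b, t, h * s)

# Route 2: einsum directly on the (b, t, h, s) tensors.
dot = torch.einsum('bthe,bihe->bhti', queries, keys) / (s ** 0.5)
dot = F.softmax(dot, dim=-1)
out_einsum = torch.einsum('bhti,bihe->bthe', dot, values).reshape(b, t, h * s)

print(torch.allclose(out_bmm, out_einsum, atol=1e-6))
```

Whether the einsum route is actually faster depends on the hardware and tensor sizes, so it is worth timing both before committing to one.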



In [20]:
from torchtext.transforms import BERTTokenizer

VOCAB_FILE = "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt"

sequence_a = "This is a short sequence."
sequence_b = "This is a rather long sequence. It is at least longer than the sequence A."

# Hugging Face tokenizer (imported above); not used below.
hf_tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

# torchtext tokenizer; return_tokens=True returns the tokens themselves rather than their ids.
tokenizer = BERTTokenizer(vocab_path=VOCAB_FILE, do_lower_case=True, return_tokens=True)
encoded_sequence_a = tokenizer(sequence_a)
len(encoded_sequence_a)
encoded_sequence_a

['this', 'is', 'a', 'short', 'sequence', '.']
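
To connect the tokens to the `(b, t, k)` input used earlier, here is a rough sketch of padding two token lists to a common length, embedding them, and masking the padded positions in the attention weights. The toy vocabulary, the `[PAD]` id, the hand-written `tokens_b` (a shortened stand-in for sequence B), and the `encode` helper are all assumptions for illustration; a real pipeline would use the tokenizer's own vocabulary and ids.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

tokens_a = ['this', 'is', 'a', 'short', 'sequence', '.']
tokens_b = ['this', 'is', 'a', 'rather', 'long', 'sequence', '.']  # hypothetical, shortened

# Hypothetical toy vocabulary with id 0 reserved for padding.
vocab = {'[PAD]': 0}
for tok in tokens_a + tokens_b:
    vocab.setdefault(tok, len(vocab))

t = max(len(tokens_a), len(tokens_b))   # pad every sequence to the longest one

def encode(tokens):
    ids = [vocab[tok] for tok in tokens]
    return ids + [vocab['[PAD]']] * (t - len(ids))

ids = torch.tensor([encode(tokens_a), encode(tokens_b)])   # (b, t)
is_pad = ids == vocab['[PAD]']                              # (b, t), True at padding

k = 8                                                       # illustrative embedding size
embed = nn.Embedding(len(vocab), k, padding_idx=0)
x = embed(ids)                                              # (b, t, k)

raw_weights = torch.bmm(x, x.transpose(1, 2))               # (b, t, t)
# Hide padded key positions from every query before normalizing.
raw_weights = raw_weights.masked_fill(is_pad[:, None, :], float('-inf'))
weights = F.softmax(raw_weights, dim=2)
```

After the softmax, no query places any weight on a padded position, which is the first masking purpose listed in the comments of the earlier cell.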