In [1]:
!pip install -Uqq torch
!pip install -Uqq numpy

In [54]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# Raw Implementation Of GPT Like Model
### Download The Dataset

In [3]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2023-03-08 12:25:10--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2023-03-08 12:25:12 (1.79 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [5]:
with open('data/tinyshakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()

### Inspect The Data

In [5]:
len(text)

1115394

In [6]:
# first 1000 characters
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



### Get The Vocabulary

In [7]:
# get set of all chars in the text and then get that as a sorted list
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
vocab_size


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


65

### Tokenize The Input
Since this is a character level language model, we'll just translate individual characters to integers.

Other tokenizers to look into:
1. SentencePiece (Google)
2. Tiktoken (OpenAI)

In [8]:
stoi = { ch:i for i, ch in enumerate(chars) }
itos = { i:ch for i, ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # take a string and output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # take a list of integers and output a string

print(encode("Hello World!"))
print(decode(encode("Hello World!")))

[20, 43, 50, 50, 53, 1, 35, 53, 56, 50, 42, 2]
Hello World!


In [9]:
# stoi is a lookup table where key is the index and value is the character
type(stoi)

dict

In [9]:
# encode the dataset and get a tensor
# data type is int16 because our vocab size is only 65
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

### Create Training and Validation Splits

In [10]:
n = int(0.9*len(data)) # split 90% of data
train_data = data[:n] # first 90% is training data
val_data = data[n:] # rest is validation data

len(train_data), len(val_data)

(1003854, 111540)

### Create Batches Of Data To Train The Model

Sample random chunks of data from the training set. These chunks are of fixed max length.
In a chunk of 9 characters like `[18, 47, 56, 57, 58,  1, 15, 47, 58]` there are 8 examples for the model to train itself on like:
1. In the context of 18, 47 likely comes next.
2. In the context of 18 and 47, 56 likely comes next and so on.

This also helps the transformer network get used to seeing context length of 1 character upto the max context length.

In [11]:
block_size = 8 # max length of chunks
train_data[:block_size + 1] # first 9 chars in the training set

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [12]:
# x are the inputs to the transformer
x = train_data[:block_size]
# y is the next block
y = train_data[1:block_size + 1]
for t in range(block_size):
    context = x[:t+1] # all chars of x upto t incl. t
    target = y[t]
    print(f'input: {context}\ttarget: {target}')

input: tensor([18])	target: 47
input: tensor([18, 47])	target: 56
input: tensor([18, 47, 56])	target: 57
input: tensor([18, 47, 56, 57])	target: 58
input: tensor([18, 47, 56, 57, 58])	target: 1
input: tensor([18, 47, 56, 57, 58,  1])	target: 15
input: tensor([18, 47, 56, 57, 58,  1, 15])	target: 47
input: tensor([18, 47, 56, 57, 58,  1, 15, 47])	target: 58


In [13]:
# add batching to process multiple inputs simultaneously
batch_size = 4 # number of independent sequences to be processed parallely
block_size = 8 # max length of the context

def get_batch(split):
    # generate a small batch of inputs x and targets y
    data = train_data if split == 'train' else val_data
    # generate batch_size number of random offsets in the dataset
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix]) # stack converts multiple rows into a list of rows
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])

    return x,y

In [15]:
# example batches
xb, yb = get_batch('train')
print(f'inputs: {xb}\ninputs_shape: {xb.shape}\n')
print(f'targets: {yb}\ntargets_shape: {yb.shape}\n')

inputs: tensor([[63,  1, 51, 43, 61,  5, 42,  1],
        [59,  1, 42, 43, 40, 39, 57, 43],
        [39,  1, 42, 59, 49, 43, 42, 53],
        [51, 47, 56, 58, 46,  8,  0,  0]])
inputs_shape: torch.Size([4, 8])

targets: tensor([[ 1, 51, 43, 61,  5, 42,  1, 46],
        [ 1, 42, 43, 40, 39, 57, 43,  1],
        [ 1, 42, 59, 49, 43, 42, 53, 51],
        [47, 56, 58, 46,  8,  0,  0, 34]])
targets_shape: torch.Size([4, 8])



In [16]:
# input and target mapping for batches
for b in range(batch_size):
    for t in range(block_size):
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f'input: {context.tolist()}\ttarget: {target}')

input: [53]	target: 51
input: [53, 51]	target: 39
input: [53, 51, 39]	target: 52
input: [53, 51, 39, 52]	target: 11
input: [53, 51, 39, 52, 11]	target: 1
input: [53, 51, 39, 52, 11, 1]	target: 47
input: [53, 51, 39, 52, 11, 1, 47]	target: 44
input: [53, 51, 39, 52, 11, 1, 47, 44]	target: 1
input: [53]	target: 53
input: [53, 53]	target: 51
input: [53, 53, 51]	target: 10
input: [53, 53, 51, 10]	target: 0
input: [53, 53, 51, 10, 0]	target: 35
input: [53, 53, 51, 10, 0, 35]	target: 46
input: [53, 53, 51, 10, 0, 35, 46]	target: 39
input: [53, 53, 51, 10, 0, 35, 46, 39]	target: 58
input: [53]	target: 49
input: [53, 49]	target: 1
input: [53, 49, 1]	target: 47
input: [53, 49, 1, 47]	target: 52
input: [53, 49, 1, 47, 52]	target: 42
input: [53, 49, 1, 47, 52, 42]	target: 43
input: [53, 49, 1, 47, 52, 42, 43]	target: 43
input: [53, 49, 1, 47, 52, 42, 43, 43]	target: 42
input: [43]	target: 1
input: [43, 1]	target: 53
input: [43, 1, 53]	target: 58
input: [43, 1, 53, 58]	target: 46
input: [43, 1, 53

## Bigram Language Model As A Baseline Model

Right now we're only predicting what comes next based on just the individual identity of a single token. This is because the tokens aren't aware of each other. They can only see themselves. So we're only making predictions based on what the actual token is.

Notice that in the implementation of generate method even though we pass a sequence of characters as context, the Bigram model only looks at the last character in the sequence to make predictions for the next character. The generate method accepts a sequence as context to keep it general.

In [17]:
class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # logits are basically the scores for the next character in a sequence
        # channel means all the possible tokens (here chars) you can have
        logits = self.token_embedding_table(idx) # (B, T, C) Batch, Time, Channel

        # loss function
        if targets is None:
            loss = None
        else:
            # logits need to be reshaped because cross_entropy expects channels as the second dimension
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            # cross_entropy calculates loss a -log likelihood: -ln(char/65)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    # idx is the current context of some characters
    # generate function extends the input (B, T) to B by T+1, T+2 and so on...
    # and continues to do so for max_new_tokens
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # get predictions
            logits, _ = self(idx)
            # focus only on the last element in the time dimension
            logits = logits[:, -1, :]
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=1)
            # sample from the probability distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)

        return idx

In [18]:
# example prediction
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(f'logits: {logits}\nloss: {loss}\n{logits.shape}')

logits: tensor([[-0.9729,  0.2434, -0.1356,  ...,  0.1614, -1.3162, -0.0710],
        [-1.3694,  0.4819, -0.6065,  ...,  1.4325,  2.7159, -0.2356],
        [ 2.1531, -1.6002,  0.9560,  ...,  0.2818,  1.5296,  0.7737],
        ...,
        [ 0.5601, -2.4942, -1.4284,  ...,  0.2002, -1.6565,  0.5220],
        [ 1.9211, -0.2959,  1.6537,  ..., -0.5056,  0.7793,  0.9916],
        [-0.1882,  0.4575,  0.3669,  ..., -0.2465, -1.4268, -0.8499]],
       grad_fn=<ViewBackward0>)
loss: 4.5496134757995605
torch.Size([32, 65])


In [19]:
# example generation
# B = 1 and T = 1 to kick off the generation
# 0 is also encoded as \n, which is a good place to start
idx = torch.zeros((1, 1), dtype=torch.long)
print(decode(m.generate(idx, max_new_tokens=100)[0].tolist())) #[0] to get a single batch


BL&xr
rwv!IepVIlbjFAHzCYJifXG&3&!a;pCZntuF;YjfRAuFAl;:lGJXTPYn!HM-- sWmJdL&NfcVJdF$nMvzbYjgvUpb'?vpC


### Training The Bigram Model

In [20]:
# creating a pytorch optimizer to get the gradients and update the parameters
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [21]:
batch_size = 32
for steps in range(10000):
    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(f'loss: {loss.item()}')

loss: 2.4108874797821045


In [22]:
# test generation
print(decode(m.generate(idx, max_new_tokens=400)[0].tolist()))


NGowerarisaxed, f I s m fouiethan ond mellchashe the THe orongendy'le, k, nk merp er q$ffffos, twidy
A:
Cousfress sefthe ma nd theaw'swate ghered,
I
Aur KETor byerin fame oGS:
lee.
NRFre:
I messen:
Cor D&
Cl
YWhind, her bee ted t, w trd Cotltixet ce t HENG me.
s worede of bor t tosesthesoicassw ay ir oukipserfonove; y Gise.
Wharar:
Whiro th.
W: ceandutous rtiffomy

VI wemofancat r,
Coen wher moner


# Implementing Self Attention
We want the tokens to be aware of each other. Specifically we want the current token to be aware of the tokens that have appeared before (and not the tokens that come after) and couple them. To do that we calculate the average of all the previous token which then acts as the summary of all the information before the current token which it can use to predict the next token. Keep in mind this approach loses out on a lot of spacial information about the previous tokens.

In [23]:
# consider an example sample
# B represents number of batchs
# T represents number of sequences of tokens in each batch
# C represents the number of token embeddings in each sequence
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)
print(f'shape of x: {x.shape}\n')
print(f'x: {x}\n')
print(f'first batch is: {x[0]}')
B, T, C = x.shape
print(f'num sequences (T) in each batch: {T}\n')
print(f'num tokens (C) in each sequence: {C}\n')

shape of x: torch.Size([4, 8, 2])

x: tensor([[[-0.5064, -0.8704],
         [ 0.2256, -0.6402],
         [-0.6735,  1.1263],
         [-0.2469, -0.6209],
         [ 1.8725,  0.1073],
         [ 0.0131,  1.2031],
         [ 1.0011,  0.9082],
         [ 1.0427, -0.4253]],

        [[ 0.7450,  0.3972],
         [ 1.0596,  0.4509],
         [-1.7042,  0.4181],
         [ 0.9725, -0.7836],
         [ 0.2181, -0.6505],
         [-1.6790, -0.2189],
         [-0.5362,  0.4421],
         [-0.5073, -0.4883]],

        [[ 0.8481, -0.9877],
         [ 0.6018, -1.3321],
         [ 0.7131, -1.4039],
         [ 2.3594, -1.2647],
         [ 0.4230, -0.4270],
         [ 0.5923,  1.3385],
         [-0.0307,  0.2107],
         [-0.4659, -0.2931]],

        [[ 0.0439, -2.2378],
         [ 0.5821, -1.7240],
         [-2.1801,  0.0250],
         [-0.6911, -0.1005],
         [ 0.5991, -0.4966],
         [ 0.6118,  0.1392],
         [ 0.2605, -0.5143],
         [ 0.0079,  0.6442]]])

first batch is: tensor([[

In the above sample:
1. x is a list of batches of tokens from the input text. N(batches) = B. Each batch in our case is a random sample from the input text. This means x is the encoded form of a sequence of characters from somewhere in the text.
2. These sequences , in encoded form, are of fixed number of tokens and are layered one after the other in a batch. N(sequences in a batch) = T.
3. Each sequence is an ordered list of individual tokens that are present in the sequence. N(tokens in a sequence) = C. Therefore the dimention C contains the information about the position and value of a token in a randomly sampled sequence of tokenized text.

The cell below shows a batch from the input text to compare a batch from our random sample:

In [21]:
# a batch from the training set
print(f'batch from training input:\n{xb}\n')
print(f'shape of each batch: {xb.shape}\n')
T, C = xb.shape
print(f'num sequences (T) in each batch: {T}\n')
print(f'num tokens (C) in each sequence: {C}\n')

batch from training input:
tensor([[63,  1, 51, 43, 61,  5, 42,  1],
        [59,  1, 42, 43, 40, 39, 57, 43],
        [39,  1, 42, 59, 49, 43, 42, 53],
        [51, 47, 56, 58, 46,  8,  0,  0]])

shape of each batch: torch.Size([4, 8])

num sequences (T) in each batch: 4

num tokens (C) in each sequence: 8



### Self Attention V1
We'll calculate the average of all tokens in a sequence preceding the current token and including the current token to couple the context with the token. This is how a current token can become aware of the previous tokens in its sequence.

Keep in mind this way of communicating context is extremely lossy. We lose out on information about the spacial arrangement of all the tokens in that sequence, we can only see the average.

In [46]:
# calc average of prev tokens
# this implementation is inefficient
xbow = torch.zeros((B, T, C)) # bow means bag of words i.e. average of the prev tokens

for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1] # everything upto the current token including the current token
        xbow[b, t] = torch.mean(xprev, 0)

print(f'xbow for the sample:\n{xbow}\n')

xbow for the sample:
tensor([[[-0.5064, -0.8704],
         [-0.1404, -0.7553],
         [-0.3181, -0.1281],
         [-0.3003, -0.2513],
         [ 0.1343, -0.1796],
         [ 0.1141,  0.0509],
         [ 0.2408,  0.1734],
         [ 0.3410,  0.0985]],

        [[ 0.7450,  0.3972],
         [ 0.9023,  0.4240],
         [ 0.0335,  0.4220],
         [ 0.2682,  0.1206],
         [ 0.2582, -0.0336],
         [-0.0647, -0.0645],
         [-0.1320,  0.0079],
         [-0.1789, -0.0541]],

        [[ 0.8481, -0.9877],
         [ 0.7249, -1.1599],
         [ 0.7210, -1.2412],
         [ 1.1306, -1.2471],
         [ 0.9891, -1.0831],
         [ 0.9229, -0.6795],
         [ 0.7867, -0.5523],
         [ 0.6301, -0.5199]],

        [[ 0.0439, -2.2378],
         [ 0.3130, -1.9809],
         [-0.5180, -1.3123],
         [-0.5613, -1.0093],
         [-0.3292, -0.9068],
         [-0.1724, -0.7324],
         [-0.1106, -0.7013],
         [-0.0957, -0.5331]]])



### Understanding What's Happening In The Above Loop

For each batch we are creating a new tensor xprev which stores all the previous sequences upto an including the current sequence(looping across T dim). The bow sequence (values in T dim) for each sequence will be the mean of the corresponding xprev.

Question: The mean is calculated only across the 0th dim. Does this mean that we are only calculating the average of the tokens in the same index for previous sequences? Is the context only moving vertically instead of horizontally?


In [47]:
example_xbow = torch.zeros((T, C))
first_batch = x[0]
print(f'first_batch:\n{first_batch}\n')

for t in range(T):
    xprev = first_batch[:t+1]
    print(f'sequences before and including current sequence:\n{xprev}\n')
    example_xbow[t] = torch.mean(xprev, 0) # take the mean of all the token sequences in the 0th dim

print(f'example xbow:\n{example_xbow}\n')

first_batch:
tensor([[-0.5064, -0.8704],
        [ 0.2256, -0.6402],
        [-0.6735,  1.1263],
        [-0.2469, -0.6209],
        [ 1.8725,  0.1073],
        [ 0.0131,  1.2031],
        [ 1.0011,  0.9082],
        [ 1.0427, -0.4253]])

sequences before and including current sequence:
tensor([[-0.5064, -0.8704]])

sequences before and including current sequence:
tensor([[-0.5064, -0.8704],
        [ 0.2256, -0.6402]])

sequences before and including current sequence:
tensor([[-0.5064, -0.8704],
        [ 0.2256, -0.6402],
        [-0.6735,  1.1263]])

sequences before and including current sequence:
tensor([[-0.5064, -0.8704],
        [ 0.2256, -0.6402],
        [-0.6735,  1.1263],
        [-0.2469, -0.6209]])

sequences before and including current sequence:
tensor([[-0.5064, -0.8704],
        [ 0.2256, -0.6402],
        [-0.6735,  1.1263],
        [-0.2469, -0.6209],
        [ 1.8725,  0.1073]])

sequences before and including current sequence:
tensor([[-0.5064, -0.8704],
        [

In [48]:
# first batch and its corresponding xbow
x[0], xbow[0]
print(f'first batch:\n{x[0]}\n')
print(f'xbow for the first batch:\n{xbow[0]}\n')

first batch:
tensor([[-0.5064, -0.8704],
        [ 0.2256, -0.6402],
        [-0.6735,  1.1263],
        [-0.2469, -0.6209],
        [ 1.8725,  0.1073],
        [ 0.0131,  1.2031],
        [ 1.0011,  0.9082],
        [ 1.0427, -0.4253]])

xbow for the first batch:
tensor([[-0.5064, -0.8704],
        [-0.1404, -0.7553],
        [-0.3181, -0.1281],
        [-0.3003, -0.2513],
        [ 0.1343, -0.1796],
        [ 0.1141,  0.0509],
        [ 0.2408,  0.1734],
        [ 0.3410,  0.0985]])



### Using Matrix Multiplication To Make Average Calculation More Efficient

In [49]:
# matrix multiplication example

# we are multiplying row i of a with column i of b
a = torch.ones(3, 3)
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b
print(f'matrix a:\n{a}')
print(f'matrix b:\n{b}')
print(f'matrix c:\n{c}')

matrix a:
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
matrix b:
tensor([[5., 5.],
        [0., 6.],
        [6., 4.]])
matrix c:
tensor([[11., 15.],
        [11., 15.],
        [11., 15.]])


In [27]:
# getting a lower triangular matrix
torch.tril(torch.ones(3, 3))

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

### Performing MatMul Using A Lower Triangular Matrix

When we multiply by a tril matrix of ones then the resultant matrix is a matrix in which each element is a sum of only the elements in the previous and current rows(in the corresponding column). So we can use this to calculate mean by using MatMul which is much faster than looping. By dividing each row by its sum (in the 1th dim) and then performing MatMul we'll get xbow.

In [50]:
# performing multiplication with a tril matrix
a = torch.tril(torch.ones(3, 3))
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b
print(f'matrix a:\n{a}')
print(f'matrix b:\n{b}')
print(f'matrix c:\n{c}')

matrix a:
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
matrix b:
tensor([[8., 5.],
        [6., 1.],
        [5., 9.]])
matrix c:
tensor([[ 8.,  5.],
        [14.,  6.],
        [19., 15.]])


In [28]:
# using a lower triangular matrix gives each element in c matrix as sum of prev elements upon matrix multiplication
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, 1, keepdim=True) # normalize each row such that sum of all elements in a row is one
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b # in c each row will be the average of the previous rows
print(f'matrix a:\n{a}')
print(f'matrix b:\n{b}')
print(f'matrix c:\n{c}')

matrix a:
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
matrix b:
tensor([[3., 1.],
        [9., 5.],
        [1., 7.]])
matrix c:
tensor([[3.0000, 1.0000],
        [6.0000, 3.0000],
        [4.3333, 4.3333]])


## Self Attention V2

In [51]:
# version 2
# wei (short for weights) is the matrix we'll multiply our inputs with to get xbow
wei = torch.tril(torch.ones(T, T)) # T, T because we want to mul with T, C to get T, C
wei = wei / wei.sum(1, keepdim = True)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [30]:
# pytorch will perform a batched MatMul in parallel
xbow2 = wei @ x # ((B), T, T) @ (B, T, C) -> (B, T, C) for each batch a T, T multiplies to a T, C
xbow2

tensor([[[-0.0613,  1.0447],
         [-0.4936,  0.2141],
         [-0.8303,  0.1645],
         [-0.4739,  0.1569],
         [-0.1582, -0.1843],
         [-0.0348, -0.0981],
         [ 0.1552, -0.1741],
         [ 0.2458, -0.2741]],

        [[ 1.6547,  0.5282],
         [ 0.2378,  0.5293],
         [ 0.3971,  0.9392],
         [ 0.3010,  0.7072],
         [ 0.2754,  0.6264],
         [ 0.3268,  0.7056],
         [ 0.1109,  0.5249],
         [ 0.2850,  0.5235]],

        [[ 1.0538,  0.1509],
         [ 0.9856, -0.1346],
         [ 0.4760,  0.4241],
         [ 0.1226,  0.6300],
         [-0.0399,  0.4964],
         [ 0.1006,  0.6127],
         [ 0.0724,  0.6003],
         [-0.0534,  0.5590]],

        [[ 1.1385,  0.4626],
         [-0.0193, -0.1204],
         [ 0.3171,  0.2542],
         [ 0.2987,  0.1163],
         [ 0.1737,  0.0299],
         [ 0.3123, -0.1032],
         [ 0.0927,  0.0549],
         [ 0.1127,  0.2565]]])

## Self Attention V3

We begin with wei being all zeros as a blank state to quantify how much of each token from the past do we want to aggregate. By performing the masked_fill operation we make sure that the tokens in the future aren't counted as they become zero when the softmax acts on them. Softmax converts all zeros to ones and -infs to zeros and then divides each row by the sum of the elements in that row. When we multiply this matrix to our input matrix we get our xbow.

In [53]:
# version 3 (using softmax)
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
# perform a softmax across all elements in a row, i.e. convert 0 to 1 and -inf to 0
# and then normalize the row, i.e. divide by the sum of the elements in the row
wei = F.softmax(wei, dim=1)
xbow3 = wei @ x

xbow3

tensor([[[-0.5064, -0.8704],
         [-0.1404, -0.7553],
         [-0.3181, -0.1281],
         [-0.3003, -0.2513],
         [ 0.1343, -0.1796],
         [ 0.1141,  0.0509],
         [ 0.2408,  0.1734],
         [ 0.3410,  0.0985]],

        [[ 0.7450,  0.3972],
         [ 0.9023,  0.4240],
         [ 0.0335,  0.4220],
         [ 0.2682,  0.1206],
         [ 0.2582, -0.0336],
         [-0.0647, -0.0645],
         [-0.1320,  0.0079],
         [-0.1789, -0.0541]],

        [[ 0.8481, -0.9877],
         [ 0.7249, -1.1599],
         [ 0.7210, -1.2412],
         [ 1.1306, -1.2471],
         [ 0.9891, -1.0831],
         [ 0.9229, -0.6795],
         [ 0.7867, -0.5523],
         [ 0.6301, -0.5199]],

        [[ 0.0439, -2.2378],
         [ 0.3130, -1.9809],
         [-0.5180, -1.3123],
         [-0.5613, -1.0093],
         [-0.3292, -0.9068],
         [-0.1724, -0.7324],
         [-0.1106, -0.7013],
         [-0.0957, -0.5331]]])

## Self Attention V4
Notice that `wei` has the same value for lower triangular elements in the same row. We don't want this to be uniform because different tokens find different tokens useful. This affinity should be data dependent. Self attention solves this problem by:

Every single token emits two vectors `key` and `query`. The `query` vector signifies "what am I looking for?" and `key` vector signifies "what do I contain?". To get the affinities between the tokens in a sequence is by taking the dot product between `keys` and `queries`. The `query` of a token dot products with all the `keys` of the other tokens and that dot product becomes `wei` i.e. the affinities.

This means that if the `key` and the `query` are aligned then they will interact more and their affinity for each other will be high in comparison to other tokens in the sequence.

Note:
This attention is called self attention because the keys, queries and the values, all come from the same source x.
In encoder-decoder transformers we can have keys from x but queries and values can come from a different source. This is called cross attention where we have nodes from a separate source which we'd like to pool information from.

### Questions
1. What is the significance of `head_size`?
2. Why use a Linear network to get head and value?

In [55]:
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# create a single head of self attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False) # v is what x communicates, the value that is passed for aggregation

# all tokens produce a key and a query in parallel
k = key(x) # (B, T, 16)
q = query(x) # (B, T, 16)
v = value(x)

# transpose last two dimensions to multiply
# (B, T, 16) @ (B, 16, T) -> (B, T, T)
# when all the queries dot product with all the keys, communication happens
wei = q @ k.transpose(-2, -1) * head_size**-0.5 # multiply with sqrt of head_size to bring down variance to order of 1

print(f'key: {k}\nquery: {q}\n')
print(f'key_shape: {k.shape}\nquery_shape: {q.shape}\n')

tril = torch.tril(torch.ones(T, T))
# we have implemented a decoder block
# to make it an encoder block we simply remove this line
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=1)
out = wei @ v

print(f'out: {out}\n')
print(f'out_shape: {out.shape}\n')

key: tensor([[[ 2.5851e-01, -4.1287e-01, -3.4331e-01,  7.8676e-01,  5.9700e-01,
          -1.9373e-01,  4.3467e-01,  7.6232e-01,  1.6628e-01, -1.0726e+00,
           4.7617e-01, -2.1293e-01, -1.3552e-01, -2.3179e-01, -4.1144e-01,
          -1.6659e-01],
         [-1.1676e+00,  7.9744e-03, -4.3502e-02, -8.9976e-02, -1.1000e+00,
          -1.0923e-01,  6.4384e-01, -4.8356e-02,  8.1363e-01,  4.7386e-01,
           7.5937e-01, -1.8870e-01, -3.2153e-01,  5.1348e-01,  1.6355e-01,
           6.3696e-01],
         [ 1.2313e-01, -8.0748e-02, -1.0561e-01,  1.0522e-01, -1.8357e-01,
          -4.5372e-01,  2.1923e-01,  6.7337e-01,  4.6693e-01,  1.5496e+00,
           5.2441e-01,  8.6864e-01,  4.0643e-01, -8.4416e-01, -8.0568e-02,
           7.1581e-01],
         [ 2.7439e-01, -8.5302e-02,  3.0850e-01,  2.1752e-01, -2.8407e-01,
           3.1418e-01, -5.2261e-01,  1.9404e-01, -2.0093e-01, -1.9599e-01,
           6.4786e-02,  1.6166e-01, -7.1450e-01, -8.9573e-01, -3.8610e-01,
           8.0427e-01],

In [56]:
wei[0]

tensor([[0.1509, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1041, 0.1390, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1052, 0.1119, 0.1155, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1041, 0.1725, 0.1721, 0.1845, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1234, 0.1602, 0.1658, 0.1935, 0.3870, 0.0000, 0.0000, 0.0000],
        [0.1304, 0.1850, 0.2660, 0.1928, 0.2238, 0.3054, 0.0000, 0.0000],
        [0.1776, 0.1325, 0.1601, 0.2673, 0.2169, 0.3796, 0.4267, 0.0000],
        [0.1042, 0.0989, 0.1205, 0.1618, 0.1723, 0.3150, 0.5733, 1.0000]],
       grad_fn=<SelectBackward0>)

In [71]:
first_random_batch = x[0]
print(f'first random batch:\n{first_random_batch}\n')

head_size = 16
key = nn.Linear(C, head_size, bias=False)

keys = key(first_random_batch)
print(f'keys for x:\n{keys}')

first random batch:
tensor([[ 1.1596, -1.0325,  0.5718, -0.5659, -1.1066, -1.1276,  1.0541, -0.7281,
          0.4765,  0.5276,  1.3571,  0.5979, -0.5872, -0.3202, -0.4962,  1.2043,
          1.5086, -1.5002,  2.3669,  1.5391,  0.0814, -0.0291,  1.0651, -1.0877,
          0.3377, -0.1240,  0.2149, -0.0827, -0.3799, -0.1128, -0.9547,  0.2862],
        [ 0.7990,  2.1681,  0.5425, -0.2722, -0.5659,  0.8444, -0.3682, -0.1785,
         -0.2294,  0.0187, -0.7815, -0.5654, -0.1671,  0.7232,  0.9447,  1.0631,
          1.0371,  0.0221,  1.6805, -0.7907,  1.2298, -0.2323,  0.2118,  1.0211,
          0.0134, -0.8806,  0.5978,  0.0731,  0.4269,  1.7898,  0.9360, -0.6859],
        [ 0.3657, -0.1974, -1.0244,  0.2563, -0.4931,  0.1098, -1.8298, -0.2591,
          0.1388, -1.2638, -0.4718, -0.7666, -0.0617,  0.4684,  0.3686, -1.5033,
         -0.2658, -0.3045,  0.1504,  0.6511, -0.6238,  0.5750, -0.0661, -0.3766,
         -1.5509,  2.1236, -0.2280,  1.7334, -0.7190,  0.4774, -0.6639,  0.7482],
     

### Multi Head Attention
We can apply multiple attentions in parallel and concatenate their results.