In [None]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

In [1]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

print("length of dataset in characters: ", len(text))

length of dataset in characters:  1115394


In [2]:
# let's look at the first 500 characters
print(text[:500])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor


In [3]:
# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
print(chars)
vocab_size = len(chars)
# print(''.join(chars))
print(vocab_size)


# We can tradeoff between the size of the vocabulary(code book size) and the length of the sequences.

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
65


In [4]:
# Tokenization: convert characters to integers and vice versa
# Character level language model

# Example tokenization methods: SentencePiece, Byte Pair Encoding (BPE), WordPiece

# create a mapping from characters to integers
stoi = {ch: i for i, ch in enumerate(chars)} # character:integer mapping
itos = {i: ch for i, ch in enumerate(chars)} # integer:character mapping
# encoder: take a string, output a list of integers
def encode(s): return [stoi[c] for c in s]
# decoder: take a list of integers, output a string
def decode(l): return ''.join([itos[i] for i in l])


print(encode("hii there"))
print(decode(encode("hii there")))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [5]:
# let's now encode the entire text dataset and store it into a torch.Tensor
import torch
data = torch.tensor(encode(text), dtype=torch.long)

# 1115394 is total number of characters in the dataset
print(data.shape, data.dtype)
# the 200 characters we looked at earier will to the GPT look like this
print(data[:200])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59])


In [8]:
# Let's now split up the data into train and validation sets
n = int(0.9*len(data))  # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [9]:
# block size or context length is the length of the input sequence, how many tokens
# we feed to the model at once
block_size = 8

print(train_data[:block_size+1]) # tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

# These are the first 9 characters in the training data. When we sample chunk of data like this,
# this has multiple examples or samples packed into it because all these characters follow each other.
# And when we plug it into the transformer, we are going to simultaneously train it to make prediction
# at one of these positions.
# We can think of this as a sliding window over the data and sample the next character as the target.
# So in the chunk of 9 characters, there are 8 individual examples in there.

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])


In [12]:
x = train_data[:block_size]
y = train_data[1:block_size+1] # Offset by one, the target is the next character in the sequence

print(f"x: {x} \ny: {y}")



# These are the 8 examples hidden in the chunk of 9 characters that we sampled from the training set.
# We train on all the examples here with context between 1 all the way upto the context of block_size i.e. 8.
# t can be thought of as the time dimension, the position in the sequence or the length of the input sequence.
# t can go from 0 to block_size-1.
for t in range(block_size):
    # print(t) [0, 1, 2, 3, 4, 5, 6, 7]
    context = x[:t+1]
    target = y[t]
    print(f"when input is {context} the target: {target}")

x: tensor([18, 47, 56, 57, 58,  1, 15, 47]) 
y: tensor([47, 56, 57, 58,  1, 15, 47, 58])
when input is tensor([18]) the target: 47
when input is tensor([18, 47]) the target: 56
when input is tensor([18, 47, 56]) the target: 57
when input is tensor([18, 47, 56, 57]) the target: 58
when input is tensor([18, 47, 56, 57, 58]) the target: 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target: 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target: 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target: 58


In [15]:
torch.manual_seed(1337)
batch_size = 4  # how many independent sequences or examples will we process in parallel?
block_size = 8  # what is the maximum context length for predictions?


def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data

    # Generate 4(batch_size) random integers(starting indices) between 0 and len(data) - block_size.
    ix = torch.randint(len(data) - block_size, (batch_size,))   # shape (batch_size,)
    x = torch.stack([data[i:i+block_size] for i in ix], dim=0)  # (batch_size, block_size)
    y = torch.stack([data[i+1:i+block_size+1] for i in ix], dim=0)  # (batch_size, block_size)

    # sequences_x = []
    # sequences_y = []
    # for i in ix:
    #     seq_x = data[i:i+block_size+1]
    #     sequences_x.append(seq_x)

    #     seq_y = data[i+1:i+block_size+1]
    #     sequences_y.append(seq_y)
    # x = torch.stack(sequences_x, dim=0) # (batch_size, block_size+1)
    # y = torch.stack(sequences_y, dim=0) # (batch_size, block_size)

    return x, y


xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('---------------------------------------------------')

for b in range(batch_size):  # batch dimension
    for t in range(block_size):  # time dimension
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
targets:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
---------------------------------------------------
when input is [24] the target: 43
when input is [24, 43] the target: 58
when input is [24, 43, 58] the target: 5
when input is [24, 43, 58, 5] the target: 57
when input is [24, 43, 58, 5, 57] the target: 1
when input is [24, 43, 58, 5, 57, 1] the target: 46
when input is [24, 43, 58, 5, 57, 1, 46] the target: 43
when input is [24, 43, 58, 5, 57, 1, 46, 43] the target: 39
when input is [44] the target: 53
when input is [44, 53] the target: 56
when input is [44, 53, 56] the target: 1
when input is [44, 53, 56, 1] the target: 58
when input is [44, 53, 

In [16]:
print(xb)  # our input to the transformer

tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])


In [None]:
print(xb)  # our input to the transformer

In [36]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

embedding_size = 32  # each character is represented by a vector of size embedding_size
block_size = 8  # context length, how many characters we feed to the model at once


class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size, embedding_size=embedding_size):
        super().__init__()
        # Each character is represented by a vector of size embedding_size.
        # So in our case we have a vocabulary of size vocab_size = 65 and each character
        # is represented by a vector of size embedding_size = 32.
        # In this case, the token embedding table is a square matrix of size (vocab_size, embedding_size).
        # Each row of this matrix corresponds to a character in the vocabulary with its embedding vector.

        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, embedding_size)

        # each position from 0 to block_size-1 will also have a corresponding embedding vector
        self.positional_embedding_table = nn.Embedding(block_size, embedding_size)  # positional embeddings

        self.lm_head = nn.Linear(embedding_size, vocab_size)  # output layer to map embeddings to logits

    def forward(self, idx, targets=None):

        # idx is of shape (B, T) where B is the batch size and T is the block size, here (4, 8)
        # idx and targets are both (B,T) tensor of integers
        # So each of the 8 characters (integer correspoding to that character) for every example
        # will index into the token embedding table to get their corresponding embedding vectors.
        # So every single integer in our input sequence is going to refer to the embedding table
        # and is going to pluck out a row from the embedding table corresponding to that integer
        # i.e it's index and every integer will be replaced by its corresponding embedding vector.

        B, T = idx.shape

        token_embeddings = self.token_embedding_table(idx)  # (B,T,C) = (4, 8, 32)

        # All the integers from 0 to block_size-1 or T - 1 will get embedded through the
        # positional embedding table to get and pluck out a row or embedding vector
        # corresponding to that integer.
        positional_embeddings = self.positional_embedding_table(torch.arange(T))  # (T, C) = (8, 32)

        # now x holds not just the token identities but the positions at which the tokens occur
        x = token_embeddings + positional_embeddings  # (B, T, C) = (4, 8, 32)
        # x = token_embeddings

        logits = self.lm_head(x)  # (B, T, vocab_size) = (4, 8, 65)

        if targets is None:
            loss = None
        else:
            B, T, vocab_size  = logits.shape
            logits = logits.view(B*T, vocab_size) # (32, 65)
            targets = targets.view(B*T)  # (32,)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):

            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]

            # get the predictions
            logits, loss = self(idx_cond) # logits is (B, T, C)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx


m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)


# kick off generation with character corresponding to 0 which is a 'new line character'
idx = torch.zeros((1, 1), dtype=torch.long) # 1 batch, 1 context length
generated = m.generate(idx=idx, max_new_tokens=100)
print(generated.shape) # [1, 101] 1 for batch, we generate until context length becomes 100, 101 tokens (1 initial + 100 generated)
print(decode(generated[0].tolist()))  # decode the generated sequence

torch.Size([256, 65])
tensor(4.4649, grad_fn=<NllLossBackward0>)
torch.Size([1, 101])

SmVddomOVWydk3'BrlK
QduWGGHPmiu&UXBlIHTZ'yfsDuEtqWPUlOZt&-lV&qBohwN.l;N3z:miimwvg,gAo3EPN3hOw$!VyTuE


In [37]:
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)


batch_size = 32
for steps in range(1000):  # increase number of steps for good results...

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

2.569643020629883


In [38]:
print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=500)[0].tolist()))


IAthape f gidom:
n, t nrd v t:
Q oto t hieret wouan r chersod  sesthande pKs s the s bleaunthisid EN
WES:
Te hee woce ld g wot.
An damedrpl ph an
AKe hameg thy cavind trkelyovesyave
TAKoude uth veillaBartga;
$lenu inonofete mrave

Mvatlo mis'ray
H t?

Cpe wirs watheanthinerense athandst.
AO-pit chesellongoEx theroidlos tery I uc ithe methene,a hd de Wear at yt be hite phanim mpasig' an.

rathan?
 llloc

ASesipp wilele ba eyo t nll tond pthealders as pde llaldsthed,
AE lei're f ML
Cxhalis ton y b


In [27]:
torch.arange(8)

tensor([0, 1, 2, 3, 4, 5, 6, 7])

###  The Mathematical Trick of Self Attention

In [58]:
torch.manual_seed(1337)

B, T, C = 4, 5, 7 # batch, time, channels

x = torch.randint(10, (B, T, C)).float()
# x = torch.ones((B, T, C))

# We want to compute the attention scores(for now just the average) for each token in the sequence
# with respect to every other token preceeding it and not the ones that come after it.
# We can do this by taking the average of the token embeddings for each token in the sequence
# upto the current token and not the ones that come after it.
# This is a very simple example of self attention.
# So say for token 5 in sequence of 8 tokens, we want to average the embeddings of tokens
# i,e channels of 0, 1, 2, 3, 4 and 5 and that will be the attention score for token 5.
print(x.shape)
print(x[1])


# Version 1
scores = torch.zeros((B, T, C))  # initialize scores tensor to hold the attention scores
for b in range(B):  # batch dimension
    for t in range(T):  # time dimension

        # shape (t, C), t will go from 0 to T-1 i.e 0 to 4 in this case, (1, 7) (2, 7) (3, 7) (4, 7) (5, 7)
        prev = x[b, :t+1, :]  # take all the tokens upto and including the current token
        attention_score = torch.mean(prev, dim=0)  # average the embeddings across the time dimension | shape (7,)
        # print(attention_score.shape) # (C,) # (7,)
        scores[b, t] = attention_score  # store the attention score in the scores tensor

print(scores.shape)  # (B, T, C) = (4, 5, 7)
scores[1]

torch.Size([4, 5, 7])
tensor([[9., 9., 3., 9., 8., 4., 8.],
        [3., 5., 4., 3., 3., 8., 9.],
        [1., 6., 8., 6., 0., 0., 3.],
        [4., 3., 8., 6., 6., 1., 7.],
        [6., 8., 6., 5., 2., 4., 4.]])
torch.Size([4, 5, 7])


tensor([[9.0000, 9.0000, 3.0000, 9.0000, 8.0000, 4.0000, 8.0000],
        [6.0000, 7.0000, 3.5000, 6.0000, 5.5000, 6.0000, 8.5000],
        [4.3333, 6.6667, 5.0000, 6.0000, 3.6667, 4.0000, 6.6667],
        [4.2500, 5.7500, 5.7500, 6.0000, 4.2500, 3.2500, 6.7500],
        [4.6000, 6.2000, 5.8000, 5.8000, 3.8000, 3.4000, 6.2000]])

In [71]:
# Version 2

# Doing the same thing with Matrix Multiplication
# We can do the same thing with matrix multiplication.
# We can take the average of all the tokens upto and including the current token
# by multiplying the token embeddings with a triangular matrix of ones.
# This triangular matrix will have ones in the upper triangle and zeros in the lower triangle.
# This will ensure that we only take the tokens upto and including the current token
triangular_matrix = torch.tril(torch.ones((T, T)))  # shape (T, T) = (5, 5)
print(triangular_matrix)
scores = triangular_matrix @ x
print(scores.shape)  # (B, T, C) = (4, 5, 7)
print(x[1])
print(scores[1])  # this gives us the sum of all the tokens upto and including the current token
# But we want the average, so we divide by the number of tokens

# print(torch.arange(1, T+1)) # tensor([1, 2, 3, 4, 5])
scores = scores / torch.arange(1, T+1).view(1, T, 1)  # shape (1, T, 1) = (1, 5, 1)
print(scores[1])  # this gives us the average of all the tokens upto and including the current token

# OR
triangular_matrix = triangular_matrix / triangular_matrix.sum(dim=1, keepdim=True)  # normalize the triangular matrix
print(triangular_matrix) # each row sums to 1
scores = triangular_matrix @ x  # now this is the average of all the tokens upto and including the current token
# (T, T) @ (B, T, C) = (B, T, C)
print(scores[1])

tensor([[1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.]])
torch.Size([4, 5, 7])
tensor([[9., 9., 3., 9., 8., 4., 8.],
        [3., 5., 4., 3., 3., 8., 9.],
        [1., 6., 8., 6., 0., 0., 3.],
        [4., 3., 8., 6., 6., 1., 7.],
        [6., 8., 6., 5., 2., 4., 4.]])
tensor([[ 9.,  9.,  3.,  9.,  8.,  4.,  8.],
        [12., 14.,  7., 12., 11., 12., 17.],
        [13., 20., 15., 18., 11., 12., 20.],
        [17., 23., 23., 24., 17., 13., 27.],
        [23., 31., 29., 29., 19., 17., 31.]])
tensor([[9.0000, 9.0000, 3.0000, 9.0000, 8.0000, 4.0000, 8.0000],
        [6.0000, 7.0000, 3.5000, 6.0000, 5.5000, 6.0000, 8.5000],
        [4.3333, 6.6667, 5.0000, 6.0000, 3.6667, 4.0000, 6.6667],
        [4.2500, 5.7500, 5.7500, 6.0000, 4.2500, 3.2500, 6.7500],
        [4.6000, 6.2000, 5.8000, 5.8000, 3.8000, 3.4000, 6.2000]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.00

In [74]:
# Version 3

# Using softmax to compute the attention scores
# We can also use softmax to compute the attention scores.

triangular_matrix = torch.tril(torch.ones((T, T)))  # shape (T, T) = (5, 5)
wei = torch.zeros((T, T))  # initialize weights tensor to hold the attention scores

# In the wei tensor, we will set the upper triangle to -inf. The places where we have 0 in the
# triangular matrix, we will set the corresponding positions in the wei tensor to -inf.
#  So that when we apply softmax, the softmax will ignore those positions and assign 0 probability to them.
wei = wei.masked_fill(triangular_matrix == 0, float('-inf'))  # set the upper triangle to -inf
print(wei)

wei = F.softmax(wei, dim=-1)  # apply softmax to get the attention scores
print(wei)

scores = wei @ x  # (T, T) @ (B, T, C) = (B, T, C)
print(x[1])
print(scores[1])  # this gives us the average of all the tokens upto and including the current token

tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]])
tensor([[9., 9., 3., 9., 8., 4., 8.],
        [3., 5., 4., 3., 3., 8., 9.],
        [1., 6., 8., 6., 0., 0., 3.],
        [4., 3., 8., 6., 6., 1., 7.],
        [6., 8., 6., 5., 2., 4., 4.]])
tensor([[9.0000, 9.0000, 3.0000, 9.0000, 8.0000, 4.0000, 8.0000],
        [6.0000, 7.0000, 3.5000, 6.0000, 5.5000, 6.0000, 8.5000],
        [4.3333, 6.6667, 5.0000, 6.0000, 3.6667, 4.0000, 6.6667],
        [4.2500, 5.7500, 5.7500, 6.0000, 4.2500, 3.2500, 6.7500],
        [4.6000, 6.2000, 5.8000, 5.8000, 3.8000, 3.4000, 6.2000]])


In [81]:
# Single Head Self Attention


B, T, C = 4, 8, 32  # batch, time, channels

x = torch.randint(10, (B, T, C)).float() # (4, 8, 32)

# Each 32 dimensional token emits 3 vectors K, Q, V
# K is the key vector, Q is the query vector and V is the value vector.


# You can kind of think as K, Q, V as projection the input token embedding which initially
# is of size C(32) or in 32 dimensional space and projecting it to a lower dimensional space like 16.
key = nn.Linear(C, 16, bias=False)  # key projection
query = nn.Linear(C, 16, bias=False)  # query projection
value = nn.Linear(C, 16, bias=False)  # value projection

k = key(x)  # (B, T, 16)
q = query(x)  # (B, T, 16)
v = value(x)  # (B, T, 16)

wei = q @ k.transpose(-2, -1)  # (B, T, 16) @ (B, 16, T) = (B, T, T)


# Scaled Dot Product Attention
# To prevent the dot products from growing too large, we scale them by the square root of the key dimension.
# This helps to stabilize the gradients during training.
scale = 16 ** 0.5
wei = wei / scale

# Kinda forget about B dimension for now, we will just focus on T dimension.
# In Q, each row is a query vector for a token with 16 dimensions.
# In K_transpose, each column is a key vector with 16 dimensions for a token.
# In V, each row is a value vector for a token with 16 dimensions.

# So when we do q @ k.transpose(-2, -1), we are computing the dot product between
# each query vector and each key vector for every token in the sequence.
# So the first row of wei of shape(T, T) will be the dot product of the query vector of the first token
# with all the key vectors of all the tokens in the sequence. So the first row of wei will
# contain the attention scores for the first token with respect to all the tokens in the sequence
# and tell us how much attention the first token should pay to all the tokens in the sequence.
# Similarly the second row tells us how much attention the second token should pay to all the
# tokens in the sequence.
# And this happens for all the examples in the batch simultaneously.


# Causal self attention is a mechanism that allows each token in the sequence to attend to all the tokens
# that come before it and not the ones that come after it.
# So we want to set the upper triangle of the wei tensor to -inf so that when we apply softmax,
# the softmax will ignore those positions and assign 0 weights to them.
tril = torch.tril(torch.ones((T, T)))  # shape (T, T) = (8, 8)
wei = wei.masked_fill(tril == 0, float('-inf'))  # set the upper triangle to -inf
print(wei[0])


# And now what we do is take this wei tensor and apply softmax to it.
wei = F.softmax(wei, dim=-1)  # (B, T, T)


out = wei @ v # (B, T, T) @ (B, T, 16) = (B, T, 16)
# Now we have the attention scores for each token in the sequence with respect to all the tokens in the sequence.
# Now we can use these attention scores to compute the output of the self attention layer.
# Say for the first token, the first row of wei tells us how much attention first token has to pay to
# each token in the sequence. We take each score in the row multiply it with the corresponding value vector
# for which the score was computed and add them up to get the output for the first token.
# This will be the output of the self attention layer for the first token.
# Say scores in the first row of wei are [0.1, 0.2, 0.3, 0.4, 0.0, 0.0, 0.0, 0.0]
# And the value vectors for the tokens are: v1, v2, v3, v4, v5, v6, v7, v8 (each is a vector of size 16)
# So the output for the first token will be:
# output = 0.1 * v1 + 0.2 * v2 + 0.3 * v3 + 0.4 * v4 + 0.0 * v5 + 0.0 * v6 + 0.0 * v7 + 0.0 * v8
# This will be the output of the self attention layer for the first token and the first row of out will be
# the output and the enhanced representation of the first token.
# This is exactly what this matrix multiplication does between wei and v does.


tensor([[ -0.5330,     -inf,     -inf,     -inf,     -inf,     -inf,     -inf,
             -inf],
        [-14.1926,  -9.7073,     -inf,     -inf,     -inf,     -inf,     -inf,
             -inf],
        [  6.8117,   1.9836,   9.7450,     -inf,     -inf,     -inf,     -inf,
             -inf],
        [ -3.5694,  -6.2294,  -6.2758,  -6.3100,     -inf,     -inf,     -inf,
             -inf],
        [  6.8332,   9.4518,   8.6119,   8.3535,   8.5837,     -inf,     -inf,
             -inf],
        [  3.4085,  -2.4052,  -0.9684,  -3.2428,   0.2396,  -0.5170,     -inf,
             -inf],
        [ -0.5364,  -4.5374,   0.6402,  -4.6938,  -5.0907,   0.4215, -10.4793,
             -inf],
        [ -1.0625,   3.4254,   7.2073,   1.1701,   3.9532,   4.1309,   4.3983,
           2.1042]], grad_fn=<SelectBackward0>)


## Attention Matrix A (4x4)

Each row contains attention weights for one token over all tokens (including itself).

Let:
A = 
\[
\begin{bmatrix}
0.1 & 0.2 & 0.3 & 0.4 \\
0.3 & 0.3 & 0.2 & 0.2 \\
0.25 & 0.25 & 0.25 & 0.25 \\
0.4 & 0.3 & 0.2 & 0.1 \\
\end{bmatrix}
\]

## Value Matrix V (4x10)

Each row is a value vector of dimension 10:

V =
\[
\begin{bmatrix}
1 & 0 & 1 & 0 & 1 & 0 & 1 & 0 & 1 & 0 \\
0 & 1 & 0 & 1 & 0 & 1 & 0 & 1 & 0 & 1 \\
1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 \\
0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \\
\end{bmatrix}
\]

## Compute Output Z = A × V (4x10)

Let’s compute row 0 of Z explicitly:

Z₀ = 0.1 * V₀ + 0.2 * V₁ + 0.3 * V₂ + 0.4 * V₃

Break it down:

- 0.1 * V₀ = [0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0]
- 0.2 * V₁ = [0, 0.2, 0, 0.2, 0, 0.2, 0, 0.2, 0, 0.2]
- 0.3 * V₂ = [0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3]
- 0.4 * V₃ = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

Now sum them up:

Z₀ = 
[

0.1 + 0   + 0.3 + 0,  
0   + 0.2 + 0.3 + 0,  
0.1 + 0   + 0.3 + 0,  
0   + 0.2 + 0.3 + 0,  
0.1 + 0   + 0.3 + 0,  
0   + 0.2 + 0.3 + 0,  
0.1 + 0   + 0.3 + 0,  
0   + 0.2 + 0.3 + 0,  
0.1 + 0   + 0.3 + 0,  
0   + 0.2 + 0.3 + 0  

]

Z₀ =
[0.4, 0.5, 0.4, 0.5, 0.4, 0.5, 0.4, 0.5, 0.4, 0.5]

Repeat the same for Z₁, Z₂, Z₃ to get the full matrix Z ∈ ℝ⁴ˣ¹⁰.


In [83]:
# Multi Head Self Attention
# In practice, we use multiple heads of self attention to capture different aspects of the input sequence.
# And then we concatenate the outputs of all the heads and pass it through a linear layer to get the final output.

# Say we have 2 heads of self attention, we will have 2 sets of K, Q, V vectors.
num_heads = 2
d_model = 32  # dimension of the input token embeddings
d_k = d_model // num_heads  # dimension of each head, here 16

B, T, C = 4, 8, d_model  # batch, time, channels

x = torch.randint(10, (B, T, C)).float()  # (4, 8, 32)

# Each head will have its own set of linear layers for K, Q, V
key_1 = nn.Linear(C, d_k, bias=False)    # key projection for head 1
query_1 = nn.Linear(C, d_k, bias=False)  # query projection for head 1
value_1 = nn.Linear(C, d_k, bias=False)  # value projection for head 1

k_1 = key_1(x)  # (B, T, 16)
q_1 = query_1(x)  # (B, T, 16)
v_1 = value_1(x)  # (B, T, 16)
wei_1 = q_1 @ k_1.transpose(-2, -1)  # (B, T, 16) @ (B, 16, T) = (B, T, T)
# Scaled Dot Product Attention for head 1
scale = d_k ** 0.5
wei_1 = wei_1 / scale  # (B, T, T)
# Causal self attention for head 1
mask_1 = torch.tril(torch.ones(T, T), diagonal=1).bool()  # upper triangular matrix
wei_1 = wei_1.masked_fill(mask_1, float('-inf'))  # apply mask
wei_1 = wei_1.softmax(dim=-1)  # (B, T, T)
# Now we can compute the output for head 1
out_1 = wei_1 @ v_1  # (B, T, T) @ (B, T, 16) = (B, T, 16)


# Now we do the same for head 2

key_2 = nn.Linear(C, d_k, bias=False)    # key projection for head 2
query_2 = nn.Linear(C, d_k, bias=False)  # query projection for head 2
value_2 = nn.Linear(C, d_k, bias=False)  # value projection for head 2

k_2 = key_2(x)  # (B, T, 16)
q_2 = query_2(x)  # (B, T, 16)
v_2 = value_2(x)  # (B, T, 16)

wei_2 = q_2 @ k_2.transpose(-2, -1)  # (B, T, 16) @ (B, 16, T) = (B, T, T)
# Scaled Dot Product Attention for head 2
scale = d_k ** 0.5
wei_2 = wei_2 / scale  # (B, T, T)
# Causal self attention for head 2
mask_2 = torch.tril(torch.ones(T, T), diagonal=1).bool()  # upper triangular matrix
wei_2 = wei_2.masked_fill(mask_2, float('-inf'))  # apply mask
wei_2 = wei_2.softmax(dim=-1)  # (B, T, T)
# Now we can compute the output for head 2
out_2 = wei_2 @ v_2  # (B, T, T) @ (B, T, 16) = (B, T, 16)

print(out_1.shape)  # (4, 8, 16)
print(out_2.shape)  # (4, 8, 16)

# Now we concatenate the outputs of both heads
out = torch.cat((out_1, out_2), dim=-1)  # (B, T, 16*2) = (4, 8, 32)
# And we pass it through a linear layer to get the final output

# Each token in the sequence of size 32 is separately passsed through a linear layer
lm_head = nn.Linear(d_model, d_model)  # output layer to map embeddings to logits
final_output = lm_head(out)  # (B, T, C) = (4, 8, 32)
print(final_output.shape)  # (4, 8, 32)


torch.Size([4, 8, 16])
torch.Size([4, 8, 16])
torch.Size([4, 8, 32])


In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# hyperparameters
batch_size = 16 # 64   # how many independent sequences will we process in parallel?
block_size = 32  # 256 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 500
learning_rate = 1e-3 # 10-4
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 64 # 384
# d_model = n_embd  # dimension of the model, same as the embedding size
# d_k = 64/4 = 16  # dimension of each head, here 16
n_head = 4
n_layer = 4 # 6
dropout = 0.0 # 0.2
# ------------

# send the network to the GPU if available
device = 'mps' if torch.backends.mps.is_available() else (
    'cuda' if torch.cuda.is_available() else 'cpu')
print("Using device = " + device)
if device == 'cpu':
    print("WARNING: Using CPU will cause slower train times")
# ------------

torch.manual_seed(1337)

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
# encoder: take a string, output a list of integers
def encode(s): return [stoi[c] for c in s]
# decoder: take a list of integers, output a string
def decode(l): return ''.join([itos[i] for i in l])


# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))  # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading


def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y


@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # block_size is the maximum context length
        # this is not the parameter of the module
        self.register_buffer('tril', torch.tril(
            torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)   # (B,T,C)
        q = self.query(x)  # (B,T,C)
        # compute attention scores ("affinities")
        # (B, T, C) @ (B, C, T) -> (B, T, T)
        wei = q @ k.transpose(-2, -1) * C**-0.5
        wei = wei.masked_fill(
            self.tril[:T, :T] == 0, float('-inf'))  # (B, T, T)
        wei = F.softmax(wei, dim=-1)  # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x)  # (B,T,C)
        out = wei @ v  # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out


class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out


class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    # This is on a per token basis, so it does not have any parameters that depend on the sequence length.
    # It is applied to each token independently. So each token independently goes through the
    # feedforward network.

    # ? WE PRODUCE THE PROBABILITY FOR EVERY SINGLE TOKEN THAT MIGHT COME NEXT
    # ? AND WE DO THIS AT EVERY POINT IN TIME OF THAT TRANSFORMER  i.e FOR EVERY SINGLE
    # ? TOKEN IN THE SEQUENCE AND WE DO THIS IN PARALLEL FOR ALL THE TOKENS IN THE SEQUENCE
    # ? SO WE HAVE A SEPARATE FEEDFORWARD NETWORK FOR EVERY SINGLE TOKEN IN THE SEQUENCE

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size) # this is where tokens attend and communicate with each other
        self.ffwd = FeedFoward(n_embd)  # this is where tokens think individually about the information they gathered in attention step.
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        # Residual connection around self-attention and feedforward

        # pre-norm formulation
        norm_x1 = self.ln1(x)      # layer norm before self-attention
        x = x + self.sa(norm_x1)   # residual connection around self-attention

        norm_x2 = self.ln2(x)      # layer norm before feedforward
        x = x + self.ffwd(norm_x2) # residual connection around feedforward
        return x

# super simple bigram model


class BigramLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            *[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)  # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

        # ? WE PRODUCE THE PROBABILITY FOR EVERY SINGLE TOKEN THAT MIGHT COME NEXT
        # ? AND WE DO THIS AT EVERY POINT IN TIME OF THAT TRANSFORMER  i.e FOR EVERY SINGLE
        # ? TOKEN IN THE SEQUENCE AND WE DO THIS IN PARALLEL FOR ALL THE TOKENS IN THE SEQUENCE
        # ? SO WE HAVE A SEPARATE FEEDFORWARD NETWORK FOR EVERY SINGLE TOKEN IN THE SEQUENCE.

        # ? SO THIS lm_head IS A LINEAR LAYER THAT MAPS THE EMBEDDING OF EACH TOKEN
        # ? TO THE LOGITS FOR THE NEXT TOKEN IN THE SEQUENCE AND THIS IS DONE IN PARALLEL FOR ALL
        # ? THE TOKENS IN THE SEQUENCE AND self.lm_head THUS RUNS ON EVERY SINGLE TOKEN IN THE SEQUENCE
        # ? SEPARATELY AND IN PARALLEL.

        # ? WE ARE PREDICTING THE NEXT TOKEN FOR EACH TOKEN IN THE SEQUENCE
        # ? SO THIS IS A MULTI-HEAD CAUSAL SELF-ATTENTION MECHANISM THAT ALLOWS
        # ? EACH TOKEN TO ATTEND TO ALL OTHER TOKENS IN THE SEQUENCE SIMULTANEOUSLY.

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx)  # (B,T,C)
        pos_emb = self.position_embedding_table(
            torch.arange(T, device=device))  # (T,C)
        x = tok_emb + pos_emb  # (B,T,C)
        x = self.blocks(x)  # (B,T,C)
        x = self.ln_f(x)  # (B,T,C)

        # ? If the input was (B, T) indicies, then at every single (B, T) we calculate logits(vocab_size)
        # ? for what token comes next in the sequence. Thus et every single token in the sequence
        # ? we are predicting the next token in the sequence and this is done in parallel for all the tokens
        # ? in the sequence.
        logits = self.lm_head(x)  # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C) # (B*T, 65)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx


model = BigramLanguageModel()
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(
            f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))

Using device = mps
0.209729 M parameters
step 0: train loss 4.4116, val loss 4.4022
step 500: train loss 2.3137, val loss 2.3145
step 1000: train loss 2.1030, val loss 2.1295
step 1500: train loss 1.9671, val loss 2.0329
step 2000: train loss 1.8799, val loss 1.9676
step 2500: train loss 1.8212, val loss 1.9456
step 3000: train loss 1.7750, val loss 1.9170
step 3500: train loss 1.7498, val loss 1.8924
step 4000: train loss 1.7191, val loss 1.8581
step 4500: train loss 1.6964, val loss 1.8502
step 4999: train loss 1.6633, val loss 1.8268

Forsul presenchelders, else dofenievey:
All know I heaver the horselted, my charn blones,
Beath Lord, as a simpery imbuag hold
by serveing shalk, as leht thy
who croes fears it of sulb.

HENRY BOLINGBRARET:
Say live, I mosty frishond not slead 'tis proging.

QUEEN MARGAn:
Why, they foull lead,
I agailn art your had cried men
Is thoughmy that prongmanoroud Volnow,
And say shall rachaly all was had;
Thought from shap thy, as it sirring a tasts latesse my

In [None]:
class CrossAttentionHead(nn.Module):
    """One head of cross-attention"""

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output):
        B, T, C = x.shape
        # k,v from encoder output, q from decoder input
        k = self.key(encoder_output)    # (B,S,C) where S is encoder sequence length
        q = self.query(x)               # (B,T,C)
        v = self.value(encoder_output)  # (B,S,C)

        # compute attention scores
        wei = q @ k.transpose(-2, -1) * C**-0.5  # (B,T,S)
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        # weighted aggregation of values
        out = wei @ v  # (B,T,C)
        return out

class MultiHeadCrossAttention(nn.Module):
    """Multiple heads of cross-attention in parallel"""

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([CrossAttentionHead(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output):
        out = torch.cat([h(x, encoder_output) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class DecoderBlock(nn.Module):
    """Transformer decoder block: self-attention, cross-attention, and computation"""

    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)  # self-attention
        self.ca = MultiHeadCrossAttention(n_head, head_size)  # cross-attention
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ln3 = nn.LayerNorm(n_embd)

    def forward(self, x, encoder_output):
        # self-attention with residual
        x = x + self.sa(self.ln1(x))
        # cross-attention with residual
        x = x + self.ca(self.ln2(x), encoder_output)
        # feedforward with residual
        x = x + self.ffwd(self.ln3(x))
        return x

class EncoderBlock(nn.Module):
    """Transformer encoder block: self-attention and computation"""

    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

class TransformerModel(nn.Module):
    """Full Transformer with encoder and decoder"""

    def __init__(self, src_vocab_size, tgt_vocab_size):
        super().__init__()
        # Encoder side
        self.encoder_token_embedding = nn.Embedding(src_vocab_size, n_embd)
        self.encoder_position_embedding = nn.Embedding(block_size, n_embd)
        self.encoder_blocks = nn.ModuleList([EncoderBlock(n_embd, n_head) for _ in range(n_layer)])
        self.encoder_ln_f = nn.LayerNorm(n_embd)

        # Decoder side
        self.decoder_token_embedding = nn.Embedding(tgt_vocab_size, n_embd)
        self.decoder_position_embedding = nn.Embedding(block_size, n_embd)
        self.decoder_blocks = nn.ModuleList([DecoderBlock(n_embd, n_head) for _ in range(n_layer)])
        self.decoder_ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, tgt_vocab_size)

    def encode(self, src_idx):
        B, T = src_idx.shape
        # Get embeddings
        tok_emb = self.encoder_token_embedding(src_idx)
        pos_emb = self.encoder_position_embedding(torch.arange(T, device=device))
        x = tok_emb + pos_emb

        # Pass through encoder blocks
        for block in self.encoder_blocks:
            x = block(x)
        return self.encoder_ln_f(x)

    def decode(self, tgt_idx, encoder_output):
        B, T = tgt_idx.shape
        # Get embeddings
        tok_emb = self.decoder_token_embedding(tgt_idx)
        pos_emb = self.decoder_position_embedding(torch.arange(T, device=device))
        x = tok_emb + pos_emb

        # Pass through decoder blocks
        for block in self.decoder_blocks:
            x = block(x, encoder_output)
        x = self.decoder_ln_f(x)

        # Get logits
        logits = self.lm_head(x)
        return logits

    def forward(self, src_idx, tgt_idx, targets=None):
        # Encode source sequence
        encoder_output = self.encode(src_idx)

        # Decode target sequence
        logits = self.decode(tgt_idx, encoder_output)

        # If we have targets, compute loss
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, src_idx, max_new_tokens, temperature=1.0):
        """Generate a target sequence given a source sequence"""
        self.eval()
        with torch.no_grad():
            # First encode the source sequence
            encoder_output = self.encode(src_idx)

            # Start with just a BOS token
            B = src_idx.shape[0]
            tgt_idx = torch.zeros((B, 1), dtype=torch.long, device=device)

            # Generate one token at a time
            for _ in range(max_new_tokens):
                # Get predictions for next token
                logits = self.decode(tgt_idx, encoder_output)
                logits = logits[:, -1, :] / temperature  # focus on last token
                probs = F.softmax(logits, dim=-1)

                # Sample from the distribution
                next_token = torch.multinomial(probs, num_samples=1)

                # Append to the sequence
                tgt_idx = torch.cat([tgt_idx, next_token], dim=1)

        return tgt_idx

# Example usage for translation task:
# src_vocab_size = len(source_language_chars)
# tgt_vocab_size = len(target_language_chars)
# model = TransformerModel(src_vocab_size, tgt_vocab_size).to(device)
#
# # Training:
# optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# src_batch, tgt_batch = get_translation_batch('train')
# logits, loss = model(src_batch, tgt_batch[:, :-1], tgt_batch[:, 1:])
#
# # Generation:
# src_text = "Hello"  # example source text
# src_ids = torch.tensor([encode_source(src_text)], device=device)
# generated_ids = model.generate(src_ids, max_new_tokens=50)
# generated_text = decode_target(generated_ids[0].tolist())

In [None]:
class TransformerEncoder(nn.Module):
    """Transformer Encoder that processes the input sequence"""

    def __init__(self, n_embd, n_head, n_layer, dropout=0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.position_embedding = nn.Embedding(block_size, n_embd)

        # Key difference from decoder: No causal mask in encoder attention
        # Create encoder-specific attention block that allows bidirectional attention
        self.blocks = nn.ModuleList([EncoderBlock(n_embd, n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, idx):
        B, T = idx.shape

        # Get embeddings
        tok_emb = self.token_embedding(idx)
        pos_emb = self.position_embedding(torch.arange(T, device=device))
        x = self.dropout(tok_emb + pos_emb)

        # Pass through encoder blocks
        for block in self.blocks:
            x = block(x)

        return self.ln_f(x)

class EncoderBlock(nn.Module):
    """Transformer Encoder block with bidirectional attention"""

    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.attention = EncoderMultiHeadAttention(n_head, head_size)
        self.ff = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        # Use pre-norm formulation like in the decoder
        norm_x = self.ln1(x)
        x = x + self.attention(norm_x)
        norm_x = self.ln2(x)
        x = x + self.ff(norm_x)
        return x

class EncoderMultiHeadAttention(nn.Module):
    """Multi-head attention for encoder - allows bidirectional attention"""

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([EncoderHead(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class EncoderHead(nn.Module):
    """Single head of encoder self-attention - no causal mask"""

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

        # Compute attention scores - no causal mask
        wei = q @ k.transpose(-2, -1) * C**-0.5
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        out = wei @ v
        return out

class TransformerDecoderBlock(nn.Module):
    """Transformer Decoder block with both self-attention and cross-attention"""

    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        # Self attention (causal, like before)
        self.sa = MultiHeadAttention(n_head, head_size)
        # Cross attention to encoder outputs
        self.ca = CrossAttention(n_head, head_size)
        # Feed forward
        self.ff = FeedFoward(n_embd)
        # Layer norms
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ln3 = nn.LayerNorm(n_embd)

    def forward(self, x, encoder_out):
        # Self attention
        norm_x = self.ln1(x)
        x = x + self.sa(norm_x)
        # Cross attention
        norm_x = self.ln2(x)
        x = x + self.ca(norm_x, encoder_out)
        # Feed forward
        norm_x = self.ln3(x)
        x = x + self.ff(norm_x)
        return x

class CrossAttention(nn.Module):
    """Cross-attention module to attend to encoder outputs"""

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([CrossAttentionHead(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_out):
        out = torch.cat([h(x, encoder_out) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class CrossAttentionHead(nn.Module):
    """Single head of cross-attention"""

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_out):
        # x is from decoder, encoder_out is from encoder
        B, T, C = x.shape

        # Project encoder outputs to key and value
        k = self.key(encoder_out)
        v = self.value(encoder_out)
        # Project decoder state to query
        q = self.query(x)

        # Compute attention scores
        wei = q @ k.transpose(-2, -1) * C**-0.5
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        out = wei @ v
        return out

class FullTransformer(nn.Module):
    """Complete Transformer with both encoder and decoder"""

    def __init__(self):
        super().__init__()
        # Encoder
        self.encoder = TransformerEncoder(n_embd, n_head, n_layer//2)  # Use half the layers for encoder

        # Decoder components
        self.decoder_token_embedding = nn.Embedding(vocab_size, n_embd)
        self.decoder_position_embedding = nn.Embedding(block_size, n_embd)
        self.decoder_blocks = nn.ModuleList([
            TransformerDecoderBlock(n_embd, n_head) for _ in range(n_layer//2)
        ])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, enc_idx, dec_idx, targets=None):
        B, T = dec_idx.shape

        # Run through encoder
        encoder_out = self.encoder(enc_idx)

        # Decoder embeddings
        tok_emb = self.decoder_token_embedding(dec_idx)
        pos_emb = self.decoder_position_embedding(torch.arange(T, device=device))
        x = tok_emb + pos_emb

        # Run through decoder blocks
        for block in self.decoder_blocks:
            x = block(x, encoder_out)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, enc_idx, max_new_tokens, temperature=1.0):
        # First encode the input sequence
        encoder_out = self.encoder(enc_idx)

        # Start with just the start token for decoder
        dec_idx = torch.zeros((1, 1), dtype=torch.long, device=device)

        # Auto-regressively generate tokens
        for _ in range(max_new_tokens):
            # Crop context if needed
            dec_idx_cond = dec_idx[:, -block_size:]

            # Forward pass through decoder
            tok_emb = self.decoder_token_embedding(dec_idx_cond)
            pos_emb = self.decoder_position_embedding(torch.arange(dec_idx_cond.size(1), device=device))
            x = tok_emb + pos_emb

            # Run through decoder blocks
            for block in self.decoder_blocks:
                x = block(x, encoder_out)

            x = self.ln_f(x)
            logits = self.lm_head(x)

            # Focus on last time step
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)

            # Sample
            idx_next = torch.multinomial(probs, num_samples=1)
            dec_idx = torch.cat((dec_idx, idx_next), dim=1)

        return dec_idx