In [6]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
import time
import random
import math

inf = torch.inf
context_length = 8 # No of tokens
model_dim = 64 # dimension of the model -> residual stream
n_layers = 2 # no of layers
n_heads = 2 # No of attention heads for layer
head_dim = 8
vocab_size = 65
learning_rate = 0.01
max_iters = 100

lower_triangular_matrix = torch.tensor([[1 if i<=j else -torch.inf for i in range(7)] for j in range(7)]).float()



def sample_training_data():
    X = training_data[:,:-1]
    Y = training_data[:,1:] 
    return X,Y

torch.no_grad()
def get_val_loss():
    X = validation_data[:,:-1]
    Y = validation_data[:,1:]
    _, loss = model(X, Y)  # (B, context_length, vocab_size)

    return loss.item()


class AttentionHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.key = nn.Linear(model_dim, model_dim)
        self.query = nn.Linear(model_dim, model_dim)
        self.value = nn.Linear(model_dim, model_dim)
        self.relu = nn.ReLU()
    
    def forward(self, idx):
        # idx -> (batch, context_length, model_dim)

        # (x,y,_) = idx.shape
        # for i in range(x):
        #     for j in range(y):
        #         for k in range(j):
        #             if j!=k:
        #                 idx[i][j]+=idx[i][k]
        #         if j>=1:
        #             idx[i][j]/=(j)
        # return idx

        # return lower_triangular_matrix@idx

        key = self.key(idx) # (batch, context_length, head_dim)
        query = self.query(idx)
        value = self.value(idx) # (batch, context_length, head_dim)

        attention = query@torch.transpose(key,1,2) # (batch, context_length, context_length)
        attention =torch.tril(attention)
        print("attention",attention.shape, attention[0,:])


        attention = attention*lower_triangular_matrix

        print("attention",attention.shape, attention[0,:])

        attention = F.softmax((attention/math.sqrt(model_dim)),1) # probs along context_length sum to 1
        print("attention with softmax", attention[0,:])

        return attention@value  # (batch, context_length, model_dim)
    

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(model_dim, model_dim), nn.Linear(model_dim, model_dim))
        self.relu = nn.ReLU()
    
    def forward(self, idx):
        logits = self.layers(idx)
        return self.relu(logits)



class Transformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, model_dim)
        self.pos_embedding = nn.Embedding(context_length-1, model_dim)
        self.attention_head = AttentionHead()
        
        self.mlp = MLP()

        self.unembed_layer = nn.Linear(model_dim,vocab_size)

    def forward(self, idx, targets = None):
        # idx -> [1,2,0,3..] (batch, context_length)


        residual_stream = self.token_embedding(idx)  # (batch, context_length, model_dim)
        residual_stream = residual_stream + self.pos_embedding(torch.tensor([i for i in range(context_length-1)])) # Pos embedding will be # (context_length, model_dim)
        
        residual_stream= residual_stream + self.attention_head(residual_stream)

        residual_stream= residual_stream + self.mlp(residual_stream) 

        residual_stream = self.unembed_layer(residual_stream) # (batch, context_length, vocab_size)
        if targets is None:
            return residual_stream
        (x,y,z) = residual_stream.shape
        loss = F.cross_entropy(residual_stream.resize(x*y,z), F.one_hot(targets, vocab_size).resize(x*y, z).float())
        return residual_stream, loss

def tokenise(str: str):
    return torch.tensor([char_map[i] for i in str])


file = open("tiny_shakesphere.txt", "r")
full_data = file.read()

vocab = list(sorted((set(full_data))))

char_map = {vocab[i]: i for i in range(len(vocab))}
full_data = tokenise(full_data)

full_data = full_data[:len(full_data)- len(full_data)%context_length].reshape(-1,context_length) # Make it multiple of context length

random.shuffle(full_data)
total_datapoints  = full_data.shape[0]//10
training_data = full_data[:int(total_datapoints*0.6), :]
validation_data = full_data[int(total_datapoints*0.6):total_datapoints,:]



model = Transformer()
print(get_val_loss())
loss_value = []
val_loss_value = []
iters = []
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
step_value = max_iters/20
start_time = time.time()
for iter in range(max_iters):
    X,Y= sample_training_data() # (B, context_length)
    logits, loss = model(X, Y)  # (B, context_length, vocab_size)
    if iter%step_value ==0:
        model.eval()
        with torch.no_grad():
            val_loss = get_val_loss()
            iters.append(iter)
            loss_value.append(loss.item())
            val_loss_value.append(val_loss)
            print(f"iter:{iter} training loss: {loss.item()}, val loss: {val_loss}")
        model.train()

    
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
end_time = time.time()
print(f"Took {end_time-start_time}s for {max_iters} epochs")

plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.plot(iters,loss_value, color='blue', label="Training")
plt.plot(iters, val_loss_value, "red", label = "validation")
plt.legend()
plt.show()



        

attention torch.Size([5577, 7, 7]) tensor([[ 0.0000,  1.7787, 10.0693,  0.0000,  0.6208,  2.4768,  0.0000],
        [11.5818,  7.5673,  1.6659,  0.0000,  8.9626, 14.3081,  7.1469],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  2.7706,  0.0000,  0.0000],
        [ 0.0000,  1.2150,  2.9298,  4.1431,  0.0000,  3.0354,  0.0000],
        [ 4.5180,  0.0000,  0.0000,  0.0000,  0.0000,  6.6173,  7.0408],
        [ 0.0000,  0.0000,  6.4178,  0.0000,  6.5771,  1.3985,  1.7506],
        [ 0.3864,  1.9618,  1.4119,  0.0000,  0.0000,  0.0000,  4.9324]],
       grad_fn=<SliceBackward0>)
attention torch.Size([5577, 7, 7]) tensor([[ 0.0000,    -inf,    -inf,     nan,    -inf,    -inf,     nan],
        [11.5818,  7.5673,    -inf,     nan,    -inf,    -inf,    -inf],
        [ 0.0000,  0.0000,  0.0000,     nan,    -inf,     nan,     nan],
        [ 0.0000,  1.2150,  2.9298,  4.1431,     nan,    -inf,     nan],
        [ 4.5180,  0.0000,  0.0000,  0.0000,  0.0000,    -inf,    -inf],
        [ 0.0000,  0



attention torch.Size([8365, 7, 7]) tensor([[nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan]], grad_fn=<SliceBackward0>)
attention torch.Size([8365, 7, 7]) tensor([[nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan]], grad_fn=<SliceBackward0>)
attention with softmax tensor([[nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan]

KeyboardInterrupt: 

In [None]:
torch.softmax(attn,1)

tensor([[1.0000, 0.0000, 0.0000],
        [0.5077, 0.4923, 0.0000],
        [0.4229, 0.3701, 0.2071]])

In [12]:
datapoints

139424

In [47]:
full_data.shape[1]

8

In [16]:
data[:500]

"First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you know Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us kill him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be done: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor"

In [17]:
type(data)

str

In [23]:
list(sorted((set(data))))

['\n',
 ' ',
 '!',
 '$',
 '&',
 "'",
 ',',
 '-',
 '.',
 '3',
 ':',
 ';',
 '?',
 'A',
 'B',
 'C',
 'D',
 'E',
 'F',
 'G',
 'H',
 'I',
 'J',
 'K',
 'L',
 'M',
 'N',
 'O',
 'P',
 'Q',
 'R',
 'S',
 'T',
 'U',
 'V',
 'W',
 'X',
 'Y',
 'Z',
 'a',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'o',
 'p',
 'q',
 'r',
 's',
 't',
 'u',
 'v',
 'w',
 'x',
 'y',
 'z']