In [1]:
import math
import os
import random
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
# --- Checkpoint handling -------------------------------------------------
PATH = "models/pavan_gpt_100k_1.91.bin"
LOAD_MODEL = False

# --- Reproducibility -----------------------------------------------------
torch.manual_seed(1)

# --- Model / optimisation hyper-parameters -------------------------------
inf = math.inf
context_length = 256  # number of tokens the model attends over
model_dim = 128  # width of the residual stream
n_layers = 6  # number of attention + MLP layers
n_heads = 0  # attention heads per layer  # TODO
head_dim = 128
vocab_size = 65
learning_rate = 3e-4
max_iters = 5000
eval_iters = 100
batch_size = 32  # Takes 27k iters

# (context_length, context_length) float matrix: 1 on and below the diagonal,
# -inf strictly above it.
lower_triangular_matrix = torch.ones(context_length, context_length).masked_fill(
    torch.ones(context_length, context_length, dtype=torch.bool).triu(diagonal=1),
    -inf,
)

def tokenise(str: str):
    """Convert a string into a 1-D tensor of integer token ids.

    Each character is looked up in the module-level ``char_map``.

    Args:
        str: Text to encode (the name shadows the builtin but is kept so
            existing keyword callers keep working).

    Returns:
        A ``torch.long`` tensor with one id per character.

    Raises:
        KeyError: If a character is not a key of ``char_map``.
    """
    # Pin the dtype: without it an empty string produces torch's default empty
    # *float* tensor, which cannot be used to index an embedding table.
    return torch.tensor([char_map[ch] for ch in str], dtype=torch.long)

def decode(tokens: list[int]):
    """Convert a sequence of token ids back into text.

    Each id is looked up in the module-level ``reverse_char_map``.

    Args:
        tokens: Iterable of token ids. Plain ints and 0-d integer tensors
            (e.g. from iterating a 1-D tensor) are both accepted.

    Returns:
        The decoded string; an empty sequence gives ``""``.

    Raises:
        KeyError: If an id is not a key of ``reverse_char_map``.
    """
    # int() normalises tensor elements: a 0-d tensor hashes by identity and
    # would never match the integer keys of the lookup table.
    return ''.join(reverse_char_map[int(tok)] for tok in tokens)

# Read the whole corpus. The context manager closes the handle once the text
# is in memory, and the explicit encoding keeps the read platform-independent.
with open("tiny_shakesphere.txt", "r", encoding="utf-8") as file:
    raw_text = file.read()

# Character-level vocabulary: every distinct character, in sorted order.
vocab = sorted(set(raw_text))

# The model's embedding table has `vocab_size` rows; fail early with a clear
# message instead of an out-of-range index error deep inside the model.
if len(vocab) > vocab_size:
    raise ValueError(
        f"corpus has {len(vocab)} distinct characters but vocab_size is {vocab_size}"
    )

char_map = {ch: index for index, ch in enumerate(vocab)}  # character -> id
reverse_char_map = {index: ch for ch, index in char_map.items()}  # id -> character

# From here on `full_data` is the tokenised corpus (1-D tensor of ids).
full_data = tokenise(raw_text)

total_datapoints = full_data.shape[0]

# First 90% of the tokens are used for training, the remaining 10% for validation.
split_at = int(total_datapoints*0.9)
training_data: torch.Tensor = full_data[:split_at]
validation_data: torch.Tensor = full_data[split_at:total_datapoints]


def sample_data(split: str = "train"): # With replacement
    """Draw a random batch of (input, target) windows from one data split.

    Args:
        split: ``'train'`` selects ``training_data``; any other value selects
            ``validation_data``.

    Returns:
        ``(x, y)``, each of shape ``(batch_size, context_length)``; ``y`` is
        ``x`` shifted one position to the right in the underlying data.
    """
    source = validation_data if split != 'train' else training_data
    # One random start offset per batch row (rows may repeat).
    starts = torch.randint(len(source) - context_length, (batch_size,))
    # Row r of `window` holds the indices starts[r], starts[r]+1, ...
    window = starts.unsqueeze(1) + torch.arange(context_length)
    return source[window], source[window + 1]

@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss on the training and validation splits.

    The model is switched to eval mode, ``eval_iters`` random batches are
    scored per split (training first, then validation), and the model is
    switched to train mode before returning.

    Returns:
        ``(train_loss, val_loss)`` as 0-d float tensors.
    """
    def mean_loss(split):
        # Average the scalar loss over `eval_iters` freshly sampled batches.
        batch_losses = torch.zeros(eval_iters)
        for step in range(eval_iters):
            inputs, targets = sample_data(split)
            batch_losses[step] = model(inputs, targets)[1].item()
        return batch_losses.mean()

    model.eval()
    train_mean = mean_loss('train')
    val_mean = mean_loss('val')
    model.train()

    return train_mean, val_mean


class Layer(nn.Module):
    """Empty ``nn.Module`` placeholder: it defines no parameters and no ``forward``."""

    def __init__(self) -> None:
        nn.Module.__init__(self)
        

class AttentionHead(nn.Module):
    """Single causal self-attention head with an output projection and dropout.

    Layer sizes come from the module-level ``model_dim`` and ``head_dim``.
    The causal mask is rebuilt on every call rather than stored as a buffer,
    so the ``state_dict`` keys are exactly those of the five sub-modules.
    """

    def __init__(self):
        super().__init__()
        self.key = nn.Linear(model_dim, head_dim)
        self.query = nn.Linear(model_dim, head_dim)
        self.value = nn.Linear(model_dim, head_dim)
        self.proj = nn.Linear(head_dim, model_dim)
        self.dropout = nn.Dropout(0.2)

    def forward(self, idx):
        """Apply causal self-attention.

        Args:
            idx: Activations of shape ``(batch, seq_len, model_dim)``.

        Returns:
            Tensor of shape ``(batch, seq_len, model_dim)``.
        """
        key = self.key(idx)  # (batch, seq_len, head_dim)
        query = self.query(idx)  # (batch, seq_len, head_dim)
        value = self.value(idx)  # (batch, seq_len, head_dim)

        # Scaled dot-product scores: (batch, seq_len, seq_len)
        scores = (query @ key.transpose(-2, -1)) / math.sqrt(key.shape[-1])

        # Mask future positions by *position*, not by value. Zeroing the upper
        # triangle and then masking every score equal to 0 would also hide a
        # legitimate score of exactly 0, and turns an all-zero row into NaN
        # after the softmax.
        seq_len = scores.shape[-1]
        future = torch.ones(
            seq_len, seq_len, dtype=torch.bool, device=scores.device
        ).triu(diagonal=1)
        scores = scores.masked_fill(future, float("-inf"))

        weights = F.softmax(scores, dim=-1)  # each row sums to 1 over visible positions

        attended = weights @ value  # (batch, seq_len, head_dim)

        return self.dropout(self.proj(attended))  # (batch, seq_len, model_dim)
    

class MLP(nn.Module):
    """Position-wise feed-forward block: expand 4x, ReLU, project back, dropout."""

    def __init__(self):
        super().__init__()
        hidden_dim = 4 * model_dim
        expand = nn.Linear(model_dim, hidden_dim)
        contract = nn.Linear(hidden_dim, model_dim)
        self.layers = nn.Sequential(expand, nn.ReLU(), contract)
        self.dropout = nn.Dropout(0.2)

    def forward(self, idx):
        """Run the feed-forward stack on ``idx`` and apply dropout to the result."""
        return self.dropout(self.layers(idx))

class Transformer(nn.Module):
    """Character-level transformer: embeddings, stacked attention/MLP layers, unembedding.

    Sizes come from the module-level ``vocab_size``, ``model_dim``,
    ``context_length`` and ``n_layers``. Constructing the model prints its
    parameter count.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, model_dim)
        self.pos_embedding = nn.Embedding(context_length, model_dim)
        # The attribute name (spelling included) is part of the state_dict
        # keys, so it is left as-is to keep saved checkpoints loadable.
        self.attention_layes = nn.ModuleList([AttentionHead() for _ in range(n_layers)])
        self.mlp_layers = nn.ModuleList([MLP() for _ in range(n_layers)])
        self.unembed_layer = nn.Linear(model_dim, vocab_size)

        self.total_parameters = sum(p.numel() for p in self.parameters())
        print(f"Model has {self.total_parameters//1000}k params")

    def forward(self, idx, targets = None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: Integer token ids of shape ``(batch, seq_len)``.
            targets: Optional integer token ids with as many elements as
                ``idx``; when given, the loss is returned as well.

        Returns:
            ``logits`` of shape ``(batch, seq_len, vocab_size)`` when
            ``targets`` is ``None``, otherwise ``(logits, loss)``.
        """
        seq_len = idx.shape[-1]

        # Build position ids on the same device as the input so the model
        # keeps working after it has been moved off the CPU.
        positions = torch.arange(seq_len, device=idx.device)

        # (batch, seq_len, model_dim); the (seq_len, model_dim) positional
        # term broadcasts over the batch dimension.
        residual_stream = self.token_embedding(idx) + self.pos_embedding(positions)

        for attention, mlp in zip(self.attention_layes, self.mlp_layers):
            residual_stream = residual_stream + attention(residual_stream)
            residual_stream = residual_stream + mlp(residual_stream)

        logits = self.unembed_layer(residual_stream)  # (batch, seq_len, vocab_size)
        if targets is None:
            return logits

        # reshape (unlike view) also accepts non-contiguous targets.
        loss = F.cross_entropy(
            logits.reshape(-1, logits.shape[-1]), targets.reshape(-1)
        )
        return logits, loss
    

# Build the model once; when LOAD_MODEL is set, restore its weights from PATH
# instead of constructing (and reporting the size of) a second instance.
model = Transformer()

if LOAD_MODEL:
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source. map_location keeps a checkpoint that was saved
    # on another device loadable on this machine.
    model.load_state_dict(torch.load(PATH, map_location="cpu"))
    model.eval()

train_loss,val_loss = estimate_loss()
print(f"Initial training loss: {train_loss}, val loss: {val_loss}")

loss_value = []
val_loss_value = []
iters = []
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
# Integer logging interval (about 20 reports per run). A float interval only
# matches when max_iters happens to be a multiple of 20.
step_value = max(1, max_iters // 20)
# Make sure dropout is active for training regardless of what ran before.
model.train()
start_time = time.time()
for step in range(max_iters):
    # Evaluate before the training forward pass so that no autograd graph is
    # held in memory while the evaluation batches are scored.
    if step % step_value == 0:
        train_loss,val_loss = estimate_loss()
        iters.append(step)
        # Store plain floats so the history can be plotted directly.
        loss_value.append(float(train_loss))
        val_loss_value.append(float(val_loss))
        print(f"iter:{step} training loss: {train_loss}, val loss: {val_loss}")

    X,Y= sample_data() # (B, context_length)
    logits, loss = model(X, Y)  # (B, context_length, vocab_size)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

end_time = time.time()
# Each loop step processes one batch, so these are iterations, not epochs.
print(f"Took {end_time-start_time}s for {max_iters} iterations")

plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.plot(iters,loss_value, color='blue', label="Training")
plt.plot(iters, val_loss_value, "red", label = "validation")
plt.legend()
plt.show()

Model has 1236k params


KeyboardInterrupt: 

In [23]:
def generate_text(input: str, max_tokens: int = 1000):
    """Print ``input`` followed by ``max_tokens`` characters sampled from the model.

    Sampling runs with the model in eval mode and gradients disabled; the
    model's previous train/eval mode is restored afterwards.

    Args:
        input: Prompt text (the name shadows the builtin but is kept so
            existing keyword callers keep working). Must not be empty.
        max_tokens: Number of tokens to sample.

    Returns:
        None. The text is written to stdout as it is generated.

    Raises:
        ValueError: If the prompt encodes to zero tokens.
    """
    input_tokens = tokenise(input)
    if input_tokens.numel() == 0:
        raise ValueError("generate_text needs a non-empty prompt")
    print(input, end='')

    was_training = model.training
    model.eval()  # dropout must be off while sampling
    try:
        with torch.no_grad():
            for _ in range(max_tokens):
                # Trim *before* the forward pass so a prompt longer than the
                # context window never reaches the model untrimmed.
                window = input_tokens[-context_length:]
                logits = model(window.unsqueeze(0))[0][-1]  # last position's logits
                probs = F.softmax(logits, dim=0)
                token = torch.multinomial(probs, 1)  # shape (1,)
                print(decode([token.item()]), end='')
                input_tokens = torch.cat([window, token.to(window.dtype)])
    finally:
        model.train(was_training)
                

# Sample a continuation of the prompt "you"; generate_text prints it as it goes.
generate_text("you")  

your lord;
From your from my faulters.

herredness on my and bears:
Bell though trel'd topil'd cortain not swould or
Littleman tor lass of the strongs hone;
Now and answer.

First Secondier: well, by crabelandamp'd you not? forse?

Epioram,
God bell, the alse your which obth'ts you,
More for his comforthersomenly die,
Cought mother or it, Bukchil'd he grosery?'

DUKE VINCENTIO:
No. I, to not you be many.

POMPEY:
Go, homon is be Coase have you cbare,
Or led, true and edgrille, my orrackly to age,
Werel twhen not say, not to hour flitter'd gary
Thou lie From For
chento worse chride pherinon
Is thee of a' thine: he chountie! Speet to ie at are our than and reclorine,
Andwelcome have plectate.'
Chown, palys tekunctio, my indid,
Why sidon flow, of honds with that
Lord himself; and be his ear the Ciriall:
Go this must sir?

GAUCESSTUS:
Then thee; and crovedles tence hate,
And seech them, this know
za, sir, sagancinmon'd tend, till ques.

DUKE VINCENTIO:
Low not not and evilt her,
I'll ats i

In [11]:
# Sample a continuation of a single-space prompt; generate_text prints it as it goes.
generate_text(" ")  

 whis bens,
Whis thirgar tink for and by eee acaing ifs fall andoy thow And wommith to theame?

Rome hatheut?

BENVO:
Mill tak. And Romparn the the dell sor thisild,
Heake lus nowell theur, who say anto arier this Waite,
And arte the orrest their flor eneque?

PRICICK:
I you
Tord, brees Marred thou ancel.

'lance; cas and it misess worsly fis.

EDWABES:
Nure muel
Tooder:
I her. Ray kide mall bour gis, ase, and is is us cady--y the hese we younor he it slooks hade to meicke:
'botion nut
ETarids whath. That she stre homse him,
Yo migh creast! liftotord your is hive:
I will agarin I as beake; and in you ar poner,
But in that and thaul aurt bretle fore I propursed
CYodes noth are my as him feal,
No wregns, and your thall gend eeford.

PAPURIANUS:
Whe wo worth
So kim you faul the thee but the bing
Kirne bee!-Hay forcher's as naye his
Whicl say both'd reid bive; and the thickel geet
Aw o
So mor's the proubstervicarting Buse.

Ply axtid ye twill firfaires irvettlal yere;
Away, where itrances!

In [22]:
# Sanity check: cross-entropy of the model's predictions for the prompt "alik"
# against the one-character-shifted target "like".
a = tokenise("alik").unsqueeze(0)  # (1, 4) prompt ids
target_tokens = tokenise("like").unsqueeze(0)  # (1, 4) expected next ids
output = model(a)  # (1, 4, vocab_size) logits
# Flatten with view instead of the deprecated non-in-place Tensor.resize, and
# take the sizes from the tensors rather than hard-coding 4 and 65. Integer
# class targets give the same loss as one-hot float targets.
F.cross_entropy(output.view(-1, output.shape[-1]), target_tokens.view(-1))

tensor(2.3337, grad_fn=<DivBackward1>)

In [22]:
# Save the current weights under a new checkpoint name.
PATH = "models/pavan_gpt_1M_1.48.bin"
# torch.save does not create missing directories, so make sure the target
# folder exists before writing.
checkpoint_dir = os.path.dirname(PATH)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)
torch.save(model.state_dict(), PATH)
