In [12]:
!curl https://raw.githubusercontent.com/karpathy/ng-video-lecture/refs/heads/master/input.txt -o input.txt

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 1089k  100 1089k    0     0  2484k      0 --:--:-- --:--:-- --:--:-- 2504k


In [48]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [13]:
with open('input.txt', 'r') as f:
    text = f.read()

In [16]:
print("Total characters" , (len(text)))

Total characters 1115394


In [19]:
text[:100]

'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou'

In [None]:
# These are all the characters that the model will know. This is the vocab of the model
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(chars)
print(vocab_size)

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
65


In [24]:
# Tokenization 
# A integer representation of the data. We will map our vocab to integer representations
stoi = {char:i for i,char in enumerate(chars)}
itos = {i:char for i,char in enumerate(chars)}

def encode(s)->list:
    return [stoi[c] for c in s]
def decode(listofint)->list:
    return [itos[i] for i in listofint]

In [25]:
encode("Welcome to AI3")

[35, 43, 50, 41, 53, 51, 43, 1, 58, 53, 1, 13, 21, 9]

In [26]:
decode([35, 43, 50, 41, 53, 51, 43, 1, 58, 53, 1, 13, 21, 9])

['W', 'e', 'l', 'c', 'o', 'm', 'e', ' ', 't', 'o', ' ', 'A', 'I', '3']

In [28]:
print(''.join(decode([35, 43, 50, 41, 53, 51, 43, 1, 58, 53, 1, 13, 21, 9])))

Welcome to AI3


In [29]:
# For encoding openai uses tiktoken and google uses sentencepiece

In [39]:
data = torch.tensor(encode(text),dtype=torch.long)
print(data.shape,data.dtype)
print(data[:10])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47])


In [40]:
n = int(0.9*len(data))

In [42]:
train_data = data[:n]
val_data = data[n:]

In [43]:
block_size = 8


In [44]:
train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [46]:
torch.manual_seed(1337)
batch_size = 4
block_size = 8

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data)-block_size, (batch_size,))
    x = torch.stack([data[i:block_size+i] for i in ix])
    y = torch.stack([data[i+1:block_size+i+1] for i in ix])
    return x,y

xb,yb = get_batch("train")
print("inputs")
print(xb.shape)
print(xb)
print("targets")
print(yb.shape)
print(yb)



inputs
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
targets
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])


In [57]:
class BigramLangModel(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.token_embd_table = nn.Embedding(vocab_size,vocab_size)
    
    def forward(self,idx,targets=None):
        logits = self.token_embd_table(idx) # (Batch,Block - Time,Channels-hold the unnormalized prob vals)
        if targets is None:
            loss = None
        else:
            B,T,C = logits.shape
            logits = logits.view(B*T,C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits,targets)
        return logits,loss

    def generate(self,idx,max_new_tokens):
        # idx is array of idx of shape(B,T)
        for i in range(max_new_tokens):
            logits,loss = self(idx)
            # concerned with only the last pred of the entire block
            # contains only the last value of the sequence of each batch
            logits = logits[:,-1,:] #(B,C)
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs,num_samples=1) #(B,1)
            # Adding the predictions to the end one char at a time (B,T+1)
            idx = torch.cat((idx,idx_next),dim=1) 
        return idx
m = BigramLangModel()
logits,loss = m(xb,yb)
print(logits.shape)
print(loss)
input = torch.zeros((1,1),dtype=torch.long)
max_new_tokens = 100
output = m.generate(input,max_new_tokens=max_new_tokens)[0].tolist()
print(''.join(decode(output)))


torch.Size([32, 65])
tensor(4.6351, grad_fn=<NllLossBackward0>)

VuWj;XdeaV,JwK,JugYjGpFkIQ-!xu3pphoQoX.GzW JmcpeHFRl!NKFn,NE,kAAd
mtBGXdwGntRDx3f!iJE;kgEEedw3SuOMn!


In [59]:
optimizer = torch.optim.AdamW(m.parameters(),lr=1e-3)

In [62]:
batch_size = 32
for steps in range(1000000000000):
    xb,yb = get_batch('train')
    logits,loss = m(xb,yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
print(loss.item())

KeyboardInterrupt: 

In [63]:
input = torch.zeros((100,1),dtype=torch.long)
max_new_tokens = 1000
output = m.generate(input,max_new_tokens=max_new_tokens)[0].tolist()
print(''.join(decode(output)))


Theaicaloo ourowanef oud h'l ll pacquthisafan:
O:
LO fetilalpamyoff isththeasth! wifang ay womus Hom!

F es, thowir ged f n mn,
Thodyoowemom, w man, coflo.
I 'd ns nd. anoout t,
Fr wigrel Cove t my g ouce,
An arpissan CHus the hoiarsslacke y; r I y coby nds;
S:
Wancuar t e th y bestre,
Who y bllene s,
n.

RKINofer, ano m meayos tth I lein y senot t wn
punftr ve'd, ay th teldustr har IO,
CEESe harr he t IETheawindust cre sw ICrin
By chor sour, f;

UTok ties caroul.
T: h urddshen, bead myst a ICAnd his:
We a Foutou ayond suld y s fomord hehmessagdr tend wane ad s fas!
MIIE:
HUpowouruber I gsthest pengimuthe bes! sing.

Th bis on pof ll wigh dowil sh hareas?
QUK:
I bas ll thou he p amon hin sedallloulin.

'd ses by utat:
Wh we t thouo in wiveseit o tse, ombe;
Tharep kilo, kainttheain k!
WERor hayowre acurkeatsshanad llato st
HI, t w pild, har he g wice urathofondyoesthe aty; IN:
The t meno hy s.
S:
Tu, tore Se h to ce Bundous tasthiono thar w;
Th ce.

STo:

D whewarsushenowengouldgheseth

In [64]:
input = torch.zeros((1,1),dtype=torch.long)
max_new_tokens = 500
output = m.generate(input,max_new_tokens=max_new_tokens)[0].tolist()
print(''.join(decode(output)))


WIEL:
ARDWer w gadsand:
sel, mepr, cerixt ck-bat t Weresisiloun D:
Tasein, lodot m f'dingan, h w:

Himmer f fo ther touthil.
topo, grobalchentlolerd the the; thar wof fiennd ort. nk may, thechy, m tyontu
S:
DUSar hind aisas, tanoureleswiomathard;
ADIt th, ht. wio ar
ica re wh d f.
yshoy wotanelo My ork
Yopan al ord in scut,
re thastrent et m icengenk rove indourn angimboonckeathor
IZWAn beavene ackis ve'd, ove ourea aned yougote.
Fiks m!'s?
ARMy amyom s h wanesou.
CO, senthat iofaveas rim,
the--
