In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from models import BiagramLanguageModel , Head
import config

# Seed PyTorch's global random number generator so that runs are repeatable.
torch.manual_seed(1337)


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.2 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "C:\Users\sachi\AppData\Local\Programs\Python\Python39\lib\runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "C:\Users\sachi\AppData\Local\Programs\Python\Python39\lib\runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "c:\users\sachi\onedrive\documents\github\envs\funchat\lib\site-packages\ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "c:\users\sachi\onedrive\documents\github\envs\funchat\lib\site-packages\traitlets

<torch._C.Generator at 0x250c54e4930>

In [2]:
# Load the raw training corpus as one string.
with open('text.txt', 'r', encoding='utf-8') as f:
    text = f.read()


# Character-level vocabulary: every distinct character of the corpus, sorted.
chars = sorted(set(text))
vocab_size = len(chars)

# Lookup tables between a character and its integer id (its position in `chars`).
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Encoder: take a string, output a list of integers.

    Raises KeyError if `s` contains a character that is not in `stoi`.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take a list of integers, output a string.

    Raises KeyError if `l` contains an id that is not in `itos`.
    """
    return ''.join(itos[i] for i in l)


# The dtype is spelled out so the ids are int64 even for an empty corpus
# (torch.tensor([]) would otherwise default to a floating-point tensor).
data = torch.tensor(encode(text), dtype=torch.long)

# First 90% of the encoded corpus is used for training, the remainder for
# validation. The split point is taken from the tensor that is actually sliced.
n = int(0.9 * len(data))
data_train = data[:n]
data_val = data[n:]


def get_batch(split):
    """Draw one random batch of (input, target) sequences from a data split.

    Args:
        split: "train" selects `data_train`; any other value selects `data_val`.

    Returns:
        A pair `(x, y)` of tensors, each stacking `config.batch_size` windows of
        `config.block_size` consecutive ids, both moved to `config.device`.
        Row `k` of `y` is the window of row `k` of `x` shifted one position
        further into the data.
    """
    source = data_val if split != "train" else data_train

    # Random window start offsets, one per batch row.
    upper = len(source) - config.block_size
    starts = torch.randint(upper, (config.batch_size,))

    inputs = torch.stack([source[s:s + config.block_size] for s in starts])
    targets = torch.stack([source[s + 1:s + 1 + config.block_size] for s in starts])

    return inputs.to(config.device), targets.to(config.device)


In [3]:
# Build the language model for this vocabulary size, then move it to the
# device selected in config.
model = BiagramLanguageModel(vocab_size)
model = model.to(config.device)

In [4]:
# Report the total number of elements across all model parameters, in millions.
print(
    sum(param.numel() for param in model.parameters()) / 1e6,
    'M parameters',
)

0.021697 M parameters


## Training

In [5]:
# Schedule settings used by the training loop and by estimate_loss().
max_iters = 10000      # total number of optimisation steps
eval_interval = 100    # the losses are estimated and printed every this many steps
eval_iters = 200       # number of batches averaged per split in each estimate

# AdamW over all model parameters; the learning rate comes from config.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config.learning_rate,
)

@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss on the train and validation splits.

    For each split, `eval_iters` random batches are drawn with `get_batch`,
    passed through `model`, and their losses are averaged.

    Gradient tracking is disabled for the whole call: nothing here is ever
    backpropagated, so recording an autograd graph for every forward pass
    would only cost time and memory.

    Returns:
        dict: maps 'train' and 'val' to a 0-dim tensor holding the mean of the
        `eval_iters` batch losses for that split.
    """
    out = {}
    model.eval()
    try:
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split)
                _, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # Put the model back in training mode even when the evaluation is
        # interrupted or fails part-way through.
        model.train()
    return out

# Training loop. The counter is called `step` rather than `iter` so that the
# builtin iter() is not rebound in the notebook's global namespace.
for step in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # forward pass, then drop the old gradients, backpropagate and update
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()



step 0: train loss 4.1587, val loss 4.1598
step 100: train loss 2.9403, val loss 2.9487
step 200: train loss 2.7542, val loss 2.7605
step 300: train loss 2.6176, val loss 2.6193
step 400: train loss 2.5121, val loss 2.5231
step 500: train loss 2.4573, val loss 2.4680
step 600: train loss 2.4344, val loss 2.4418
step 700: train loss 2.4144, val loss 2.4300
step 800: train loss 2.4035, val loss 2.4124
step 900: train loss 2.3958, val loss 2.4081
step 1000: train loss 2.3871, val loss 2.3980
step 1100: train loss 2.3770, val loss 2.3901
step 1200: train loss 2.3718, val loss 2.3907
step 1300: train loss 2.3666, val loss 2.3827
step 1400: train loss 2.3647, val loss 2.3836
step 1500: train loss 2.3616, val loss 2.3791
step 1600: train loss 2.3555, val loss 2.3777
step 1700: train loss 2.3524, val loss 2.3763
step 1800: train loss 2.3448, val loss 2.3695
step 1900: train loss 2.3401, val loss 2.3712
step 2000: train loss 2.3360, val loss 2.3697
step 2100: train loss 2.3366, val loss 2.3683


In [6]:
# Two prompts of 30 ids each, taken from consecutive windows of the encoded
# corpus (positions 50-79 and 80-109), stacked into one batch on the model's device.
temp = torch.stack((data[50:80], data[80:110]))
temp = temp.to(config.device)

# Let the model continue both prompts. 500 is passed straight to generate();
# presumably it is the number of new tokens to sample - confirm against
# BiagramLanguageModel.generate.
ans = model.generate(temp, 500)

In [7]:
# Decode every generated row of token ids back into text.
# Rows that are already strings are kept unchanged, so running this cell a
# second time (when `ans` already holds the decoded text) no longer raises.
ans = [row if isinstance(row, str) else decode(row.tolist()) for row in ans]

In [8]:
# Show the decoded samples as this cell's output.
ans

[' me speak.\n\nAll:\nSpeak, speak.\n\nCETH:\nRorid owingh is sowr thad set bobe toe.\nSthr-and mealild\nhy ar highe us hathe.\nWar dilthoate awice my.\n\nHAER:\nAy onoug\nYowno, tof it he me milfllill, aes iree sen cin lat Het drovets, heen me nghhoulerans!\nel lind te lllliser cechiry:\nSupr aisspll, ye whe nes normopeeelaves\nMomy ll, dem thakeeo Windo whre eiingh wisti fourive wees ime st sousower; th\nhe kind thrupirf son; igis! muf thin inle ont ffaf Pre?\n\nWASo myr figuea!\n\nWied isad adsal this ghe thinin cme amar tey Ire ts I fr tho!\nMy',
 "\n\nFirst Citizen:\nYou are all reacre hinf ty ancle! wa het fere a nou th blad at netresopatrisicmo ignus Ceras thory thind nd ad yeplithur,\nSeciulllloule. But, hacod,\nEnke, if lon will tho asy tas sthis rgot.\n\nTEBRUSe lileti may her ches, thend dul dy ce fe Ges, ter perat hiroos Rlothu nom, I'ighh fithe, hine, amen,\nTillotois paly ow Eenaldo ald bore ango Citrer be toow.\n\nMy dor cturew dsinth.\n\nBUCELONTALUS:\nI thesim orwaen l