In [1]:
import torch
import sys
import math
import os
sys.path.append('../')  

from Classes.tokenizer import Tokenizer as T
from Classes.myGPT import Model
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == 'cuda':
    torch.backends.cudnn.benchmark = True
    
data_path = f'data/tokenized_inputs/'
runs_path = f'runs/'

block_size = 128
batch_size = 32
n_heads = 6
n_layers = 10
d_model = 768
dropout = 0.2
learning_rate = 3e-4
epochs = 1
eval_iters = 25
vocab_size = 15_000
max_iters = 100_000
dff = n_heads * 4

static_attributes = {
    'vocab_size': vocab_size,
    'n_heads': n_heads,
    'n_layers': n_layers,
    'device': device,
    'd_model': d_model,
    'batch_size': batch_size,
    'epochs': epochs,
    'eval_iters': eval_iters,
    'learning_rate': learning_rate,
    'dropout': dropout,
    'block_size': block_size,
    'dff': dff,
}

# FOR GENERATOR
def make_batches(data, block_size, batch_size):
    data_len = len(data)
    last_ix = 0
    
    while last_ix + batch_size * block_size <= data_len:
        batch_X = []
        batch_Y = []
        
        for _ in range(batch_size):
            X = data[last_ix: last_ix + block_size]
            Y = data[last_ix + 1: last_ix + block_size + 1]
            batch_X.append(X)
            batch_Y.append(Y)
            last_ix += block_size
        yield torch.stack(batch_X), torch.stack(batch_Y)

#FOR GENERATOR
def estimate_loss(m, train, val, block_size, batch_size, eval_iters):
    def calculate_loss(data):
        l = []
        counter = 0
        for x, y in make_batches(data, block_size, batch_size):
            if counter >= eval_iters:
                break
            x, y = x.to(device), y.to(device)
            _, loss = m(x, y)
            l.append(loss.item())
            counter += 1
        return sum(l) / len(l) if len(l) > 0 else 0.0
    
    m.eval()
    with torch.no_grad():
        train_loss = calculate_loss(train)
        val_loss = calculate_loss(val)
    m.train()
    return train_loss, val_loss

def calculate_total_batches(data, block_size, batch_size):
    return len(data) // (block_size * batch_size)

# Custom learning rate scheduler
def get_lr(it):
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    coeff = 0.5 * (1.0 + math.cos(math.pi * it / max_iters))
    return learning_rate * coeff

gradient_accumulation_steps = 4  # used to simulate larger batch sizes
learning_rate = 5e-4  # max learning rate
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
warmup_iters = 1_000  # how many steps to warm up for

m = Model(vocab_size=vocab_size, 
          block_size=block_size,
          dropout=dropout,
          dff=dff,
          n_heads=n_heads, 
          d_model=d_model,
          n_layers=n_layers,).to(device)

optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)
n_params = sum(p.numel() for p in m.parameters() if p.requires_grad)
print(f"Model has {n_params:,} trainable parameters")
t = T()
n_files = 0
val = torch.load(f'{data_path}/val.pt')
files = os.listdir(data_path) 

Model has 47,153,544 trainable parameters


In [2]:
writer = SummaryWriter()
steps = 1
for file in files:
    if file.startswith('tns'):
        train = torch.load(data_path + file)
        total_batches = calculate_total_batches(train, block_size, batch_size)
        train_dl = make_batches(train, block_size, batch_size)

        for epoch, (Xb, Yb) in enumerate(tqdm(train_dl, total=total_batches)):

            for param_group in optimizer.param_groups:
                param_group['lr'] = get_lr(steps)

            Xb, Yb = Xb.to(device), Yb.to(device)
            logits, loss = m(Xb, Yb)
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            writer.add_scalar('Loss/train', loss, steps)
            steps += 1

            if (epoch+1) % 100 == 0:
                _, val_loss = estimate_loss(m, train, val, block_size, batch_size, eval_iters)
                writer.add_scalar('Loss/val', val_loss, steps)

        train_loss, val_loss = estimate_loss(m, train, val, block_size, batch_size, eval_iters)

        if steps >= max_iters:
            break

# save torch model
torch.save(m.state_dict())

100%|██████████| 1228/1228 [07:52<00:00,  2.60it/s]


Epoch: 1228 | Train Loss: 2.814. Val Loss: 2.806


100%|██████████| 1227/1227 [07:52<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 2.328. Val Loss: 2.428


100%|██████████| 1228/1228 [07:52<00:00,  2.60it/s]


Epoch: 1228 | Train Loss: 2.161. Val Loss: 2.265


100%|██████████| 1227/1227 [07:53<00:00,  2.59it/s]


Epoch: 1227 | Train Loss: 2.008. Val Loss: 2.165


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.939. Val Loss: 2.091


100%|██████████| 1227/1227 [07:54<00:00,  2.59it/s]


Epoch: 1227 | Train Loss: 1.896. Val Loss: 2.043


100%|██████████| 1226/1226 [07:51<00:00,  2.60it/s]


Epoch: 1226 | Train Loss: 1.829. Val Loss: 2.006


100%|██████████| 1227/1227 [07:52<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.806. Val Loss: 1.967


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.781. Val Loss: 1.948


100%|██████████| 226/226 [01:25<00:00,  2.63it/s]


Epoch: 226 | Train Loss: 1.575. Val Loss: 1.939


100%|██████████| 1228/1228 [07:52<00:00,  2.60it/s]


Epoch: 1228 | Train Loss: 1.713. Val Loss: 1.914


100%|██████████| 1227/1227 [07:53<00:00,  2.59it/s]


Epoch: 1227 | Train Loss: 1.736. Val Loss: 1.901


100%|██████████| 1227/1227 [07:53<00:00,  2.59it/s]


Epoch: 1227 | Train Loss: 1.719. Val Loss: 1.875


100%|██████████| 1227/1227 [07:52<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.702. Val Loss: 1.863


100%|██████████| 1227/1227 [07:54<00:00,  2.59it/s]


Epoch: 1227 | Train Loss: 1.689. Val Loss: 1.844


100%|██████████| 1228/1228 [07:54<00:00,  2.59it/s]


Epoch: 1228 | Train Loss: 1.679. Val Loss: 1.828


100%|██████████| 1227/1227 [07:54<00:00,  2.59it/s]


Epoch: 1227 | Train Loss: 1.641. Val Loss: 1.821


100%|██████████| 1227/1227 [07:56<00:00,  2.58it/s]


Epoch: 1227 | Train Loss: 1.632. Val Loss: 1.806


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.604. Val Loss: 1.788


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.595. Val Loss: 1.776


100%|██████████| 1227/1227 [07:50<00:00,  2.61it/s]


Epoch: 1227 | Train Loss: 1.600. Val Loss: 1.766


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.605. Val Loss: 1.759


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.552. Val Loss: 1.751


100%|██████████| 1228/1228 [07:53<00:00,  2.59it/s]


Epoch: 1228 | Train Loss: 1.559. Val Loss: 1.740


100%|██████████| 1228/1228 [07:54<00:00,  2.59it/s]


Epoch: 1228 | Train Loss: 1.577. Val Loss: 1.732


100%|██████████| 1228/1228 [07:54<00:00,  2.59it/s]


Epoch: 1228 | Train Loss: 1.532. Val Loss: 1.719


100%|██████████| 1228/1228 [07:53<00:00,  2.60it/s]


Epoch: 1228 | Train Loss: 1.518. Val Loss: 1.717


100%|██████████| 1227/1227 [07:57<00:00,  2.57it/s]


Epoch: 1227 | Train Loss: 1.532. Val Loss: 1.708


100%|██████████| 1227/1227 [07:54<00:00,  2.58it/s]


Epoch: 1227 | Train Loss: 1.527. Val Loss: 1.696


100%|██████████| 1227/1227 [07:54<00:00,  2.59it/s]


Epoch: 1227 | Train Loss: 1.519. Val Loss: 1.689


100%|██████████| 1227/1227 [07:53<00:00,  2.59it/s]


Epoch: 1227 | Train Loss: 1.536. Val Loss: 1.684


100%|██████████| 1227/1227 [07:53<00:00,  2.59it/s]


Epoch: 1227 | Train Loss: 1.482. Val Loss: 1.674


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.485. Val Loss: 1.668


100%|██████████| 1227/1227 [07:52<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.494. Val Loss: 1.660


100%|██████████| 1227/1227 [07:52<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.472. Val Loss: 1.656


100%|██████████| 1228/1228 [07:52<00:00,  2.60it/s]


Epoch: 1228 | Train Loss: 1.479. Val Loss: 1.649


100%|██████████| 1227/1227 [07:53<00:00,  2.59it/s]


Epoch: 1227 | Train Loss: 1.485. Val Loss: 1.642


100%|██████████| 1227/1227 [07:53<00:00,  2.59it/s]


Epoch: 1227 | Train Loss: 1.506. Val Loss: 1.638


100%|██████████| 1227/1227 [07:53<00:00,  2.59it/s]


Epoch: 1227 | Train Loss: 1.494. Val Loss: 1.631


100%|██████████| 1227/1227 [07:53<00:00,  2.59it/s]


Epoch: 1227 | Train Loss: 1.468. Val Loss: 1.622


100%|██████████| 1227/1227 [07:53<00:00,  2.59it/s]


Epoch: 1227 | Train Loss: 1.462. Val Loss: 1.614


100%|██████████| 1227/1227 [07:52<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.470. Val Loss: 1.609


100%|██████████| 1226/1226 [07:52<00:00,  2.59it/s]


Epoch: 1226 | Train Loss: 1.484. Val Loss: 1.606


100%|██████████| 1228/1228 [07:52<00:00,  2.60it/s]


Epoch: 1228 | Train Loss: 1.443. Val Loss: 1.602


100%|██████████| 1226/1226 [07:52<00:00,  2.59it/s]


Epoch: 1226 | Train Loss: 1.455. Val Loss: 1.596


100%|██████████| 1227/1227 [07:52<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.473. Val Loss: 1.589


100%|██████████| 1228/1228 [07:52<00:00,  2.60it/s]


Epoch: 1228 | Train Loss: 1.451. Val Loss: 1.587


100%|██████████| 1227/1227 [07:52<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.488. Val Loss: 1.579


100%|██████████| 1226/1226 [07:50<00:00,  2.60it/s]


Epoch: 1226 | Train Loss: 1.425. Val Loss: 1.573


100%|██████████| 1228/1228 [07:51<00:00,  2.61it/s]


Epoch: 1228 | Train Loss: 1.424. Val Loss: 1.571


100%|██████████| 1228/1228 [07:51<00:00,  2.61it/s]


Epoch: 1228 | Train Loss: 1.444. Val Loss: 1.566


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.472. Val Loss: 1.557


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.439. Val Loss: 1.553


100%|██████████| 1228/1228 [07:51<00:00,  2.60it/s]


Epoch: 1228 | Train Loss: 1.407. Val Loss: 1.550


100%|██████████| 1228/1228 [07:51<00:00,  2.60it/s]


Epoch: 1228 | Train Loss: 1.439. Val Loss: 1.546


100%|██████████| 1228/1228 [07:51<00:00,  2.61it/s]


Epoch: 1228 | Train Loss: 1.440. Val Loss: 1.541


100%|██████████| 1227/1227 [07:50<00:00,  2.61it/s]


Epoch: 1227 | Train Loss: 1.436. Val Loss: 1.538


100%|██████████| 1228/1228 [07:51<00:00,  2.60it/s]


Epoch: 1228 | Train Loss: 1.424. Val Loss: 1.533


100%|██████████| 1226/1226 [07:50<00:00,  2.60it/s]


Epoch: 1226 | Train Loss: 1.436. Val Loss: 1.528


100%|██████████| 1226/1226 [07:50<00:00,  2.60it/s]


Epoch: 1226 | Train Loss: 1.443. Val Loss: 1.526


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.425. Val Loss: 1.523


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.458. Val Loss: 1.520


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.413. Val Loss: 1.516


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.471. Val Loss: 1.514


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.442. Val Loss: 1.511


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.440. Val Loss: 1.507


100%|██████████| 1228/1228 [07:51<00:00,  2.60it/s]


Epoch: 1228 | Train Loss: 1.445. Val Loss: 1.504


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.477. Val Loss: 1.501


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.436. Val Loss: 1.500


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.448. Val Loss: 1.497


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.458. Val Loss: 1.495


100%|██████████| 1228/1228 [07:52<00:00,  2.60it/s]


Epoch: 1228 | Train Loss: 1.500. Val Loss: 1.493


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.447. Val Loss: 1.492


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.443. Val Loss: 1.491


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.480. Val Loss: 1.490


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.445. Val Loss: 1.489


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.464. Val Loss: 1.489


100%|██████████| 1228/1228 [07:51<00:00,  2.60it/s]


Epoch: 1228 | Train Loss: 1.471. Val Loss: 1.488


100%|██████████| 1228/1228 [07:51<00:00,  2.60it/s]


Epoch: 1228 | Train Loss: 1.476. Val Loss: 1.488


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.503. Val Loss: 1.487


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.457. Val Loss: 1.487


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.475. Val Loss: 1.487


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.489. Val Loss: 1.487


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.509. Val Loss: 1.487


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.489. Val Loss: 1.487


100%|██████████| 1228/1228 [07:51<00:00,  2.60it/s]


Epoch: 1228 | Train Loss: 1.453. Val Loss: 1.487


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.447. Val Loss: 1.487


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.493. Val Loss: 1.487


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.468. Val Loss: 1.487


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.441. Val Loss: 1.487


100%|██████████| 1228/1228 [07:51<00:00,  2.60it/s]


Epoch: 1228 | Train Loss: 1.451. Val Loss: 1.488


100%|██████████| 1227/1227 [07:50<00:00,  2.61it/s]


Epoch: 1227 | Train Loss: 1.440. Val Loss: 1.488


100%|██████████| 1227/1227 [07:51<00:00,  2.60it/s]


Epoch: 1227 | Train Loss: 1.494. Val Loss: 1.489


100%|██████████| 1228/1228 [07:51<00:00,  2.61it/s]


Epoch: 1228 | Train Loss: 1.444. Val Loss: 1.489


100%|██████████| 1228/1228 [07:52<00:00,  2.60it/s]


Epoch: 1228 | Train Loss: 1.429. Val Loss: 1.490


100%|██████████| 1228/1228 [07:51<00:00,  2.60it/s]


Epoch: 1228 | Train Loss: 1.425. Val Loss: 1.491


100%|██████████| 1228/1228 [07:53<00:00,  2.60it/s]


Epoch: 1228 | Train Loss: 1.414. Val Loss: 1.492


100%|██████████| 1226/1226 [07:50<00:00,  2.60it/s]


Epoch: 1226 | Train Loss: 1.455. Val Loss: 1.492


100%|██████████| 1229/1229 [07:51<00:00,  2.61it/s]


Epoch: 1229 | Train Loss: 1.417. Val Loss: 1.494


100%|██████████| 1227/1227 [07:50<00:00,  2.61it/s]


Epoch: 1227 | Train Loss: 1.404. Val Loss: 1.494


100%|██████████| 1227/1227 [07:50<00:00,  2.61it/s]


Epoch: 1227 | Train Loss: 1.414. Val Loss: 1.496


100%|██████████| 1227/1227 [07:50<00:00,  2.61it/s]


Epoch: 1227 | Train Loss: 1.447. Val Loss: 1.498


 24%|██▎       | 291/1227 [01:48<05:47,  2.69it/s]


KeyboardInterrupt: 

In [6]:
seed_text = "Once upon a time"
tokenized_seed = t.encode(seed_text,False,False)  # Make sure to use your actual tokenization method
tokenized_seed = torch.tensor(tokenized_seed).unsqueeze(0).to(device)  # Add a batch dimension

# Generate 1000 new tokens
generated_tokens = m.generate(tokenized_seed, max_new_tokens=1000)
decoded_text = t.decode(generated_tokens[0].tolist())
decoded_text[256:]

'⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇ ly ⁇  parked with ⁇  Sarah ⁇  A lady ⁇  her mom ⁇  and Sarah ⁇ s mom ⁇ .  ⁇ There is no park right here, ⁇  Sarah ⁇ s little girl  ⁇ Ok, ⁇  "But why do you need to be the first to be the first to follow the orders?" Sarah ⁇ s mom ⁇ s eyes lit up ⁇  "Of course! I ⁇ ll help you get the right number  ⁇  to follow the way he is in such fun and safe". Sarah smiled and said "That\'s the way I\'ve to follow the laws of the park!". <|endoftext|> One day, a little boat named Bobby went for a harbor. Bobby was a very independent boat. He loved to go on new water and play with the other boats. Bobby saw a big boat named  ⁇ room, and he wanted to be friends too. Bobby said to  ⁇ room, "Look at my new toy! I\'m big and I can play with your boat!"  ⁇ room said, "Okay, let\'s play together!" But Bobby did not want to play with  ⁇ room, so he was having too much fun.  ⁇ room asked Bobby