In [1]:
import sys
sys.path.append('..')

In [2]:
from minbpe import RegexTokenizer

tokenizer = RegexTokenizer()
tokenizer_path = "../output/tokenizer/darija_tokenizer.model"
tokenizer.load(model_file=tokenizer_path)


def get_vocab_size(tokenizer: RegexTokenizer) -> int:
    vocab = tokenizer.vocab
    special_tokens = tokenizer.special_tokens

    return len(vocab) + len(special_tokens)

In [3]:
tokenizer.vocab

{0: b'\x00',
 1: b'\x01',
 2: b'\x02',
 3: b'\x03',
 4: b'\x04',
 5: b'\x05',
 6: b'\x06',
 7: b'\x07',
 8: b'\x08',
 9: b'\t',
 10: b'\n',
 11: b'\x0b',
 12: b'\x0c',
 13: b'\r',
 14: b'\x0e',
 15: b'\x0f',
 16: b'\x10',
 17: b'\x11',
 18: b'\x12',
 19: b'\x13',
 20: b'\x14',
 21: b'\x15',
 22: b'\x16',
 23: b'\x17',
 24: b'\x18',
 25: b'\x19',
 26: b'\x1a',
 27: b'\x1b',
 28: b'\x1c',
 29: b'\x1d',
 30: b'\x1e',
 31: b'\x1f',
 32: b' ',
 33: b'!',
 34: b'"',
 35: b'#',
 36: b'$',
 37: b'%',
 38: b'&',
 39: b"'",
 40: b'(',
 41: b')',
 42: b'*',
 43: b'+',
 44: b',',
 45: b'-',
 46: b'.',
 47: b'/',
 48: b'0',
 49: b'1',
 50: b'2',
 51: b'3',
 52: b'4',
 53: b'5',
 54: b'6',
 55: b'7',
 56: b'8',
 57: b'9',
 58: b':',
 59: b';',
 60: b'<',
 61: b'=',
 62: b'>',
 63: b'?',
 64: b'@',
 65: b'A',
 66: b'B',
 67: b'C',
 68: b'D',
 69: b'E',
 70: b'F',
 71: b'G',
 72: b'H',
 73: b'I',
 74: b'J',
 75: b'K',
 76: b'L',
 77: b'M',
 78: b'N',
 79: b'O',
 80: b'P',
 81: b'Q',
 82: b'R',
 83: b'

## Create the model

In [4]:
import torch
torch.manual_seed(3647)

<torch._C.Generator at 0x7bfabe5417d0>

In [None]:
from transformer.pedro_model import GPTLanguageModel
from transformer import BASE_CONFIG, selConfig

selConfig('gpt2-medium (355M)')

block_size = BASE_CONFIG['context_length']
n_embd = BASE_CONFIG['emb_dim']
n_head = BASE_CONFIG['n_heads']
n_layer = BASE_CONFIG['n_layers']
dropout = BASE_CONFIG['dropout']
batch_size = 2
vocab_size = get_vocab_size(tokenizer)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = GPTLanguageModel(
    vocab_size=vocab_size,
    block_size=block_size,
    n_embd=n_embd,
    n_head=n_head,
    n_layer=n_layer,
    dropout=dropout,
    device=device
).to(device)

print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')

{'vocab_size': 50257, 'context_length': 1024, 'dropout': 0.2, 'qkv_bias': True, 'emb_dim': 1024, 'n_layers': 24, 'n_heads': 16}
336.885774 M parameters


In [6]:
model

GPTLanguageModel(
  (token_embedding_table): Embedding(16398, 1024)
  (position_embedding_table): Embedding(1024, 1024)
  (blocks): Sequential(
    (0): Block(
      (self_attention): MultiHeadAttention(
        (heads): ModuleList(
          (0-15): 16 x Head(
            (key): Linear(in_features=1024, out_features=64, bias=False)
            (query): Linear(in_features=1024, out_features=64, bias=False)
            (value): Linear(in_features=1024, out_features=64, bias=False)
            (dropout): Dropout(p=0.2, inplace=False)
          )
        )
        (projection): Linear(in_features=1024, out_features=1024, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (feed_forward): FeedForward(
        (net): Sequential(
          (0): Linear(in_features=1024, out_features=4096, bias=True)
          (1): GELU()
          (2): Linear(in_features=4096, out_features=1024, bias=True)
          (3): Dropout(p=0.2, inplace=False)
        )
      )
      (layer_norm_1

## Load the data

In [7]:
import numpy as np

data_path = "../output/encoded_data/encoded_atlaset.npy"
data = np.load(data_path, mmap_mode='r')
print('Data shape:', data.shape)

Data shape: (452560,)


In [8]:
split_index = int(0.9*len(data))
split_index

407304

## Helper functions

In [9]:
from typing import Tuple


def get_batch(split: str) -> Tuple[torch.Tensor, torch.Tensor]:
    # generate a small batch of data of inputs x and targets y
    if split == 'train':
        start_index = 0
        end_index = split_index
    else:
        start_index = split_index
        end_index = len(data)

    index = torch.randint(start_index, end_index - block_size, (batch_size,))
    x_batch, y_batch = [], []
    for i in index:
        x_batch.append(data[i:i+block_size])
        y_batch.append(data[i+1:i+block_size+1])

    x_batch = np.array(x_batch)
    y_batch = np.array(y_batch)

    x_batch = torch.tensor(x_batch, dtype=torch.long).to(device)
    y_batch = torch.tensor(y_batch, dtype=torch.long).to(device)

    return x_batch, y_batch

In [10]:
from typing import Dict


@torch.no_grad()
def estimate_loss() -> Dict:
    output = {}
    eval_iters = 1000
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            x, y = get_batch(split)
            _, loss = model(x, y)
            losses[k] = loss.item()
        output[split] = losses.mean()
    model.train()
    return output

In [11]:
def save_checkpoint(
    model: GPTLanguageModel,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    loss: float,
    file_path: str = "checkpoint.pth"
) -> None:
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss
    }
    torch.save(checkpoint, file_path)

## Training

In [None]:
torch.cuda.empty_cache()

: 

In [26]:
from tqdm import tqdm

torch.set_float32_matmul_precision('high')

gradient_accumulation_steps = 8
eval_interval = 1000
save_interval = 10000

# equivalent to len(data) - block_size
total_data_to_process = split_index - block_size
total_data_to_process_in_batches = total_data_to_process // batch_size

learning_rate = 3e-4
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

batches_processed = 0
train_losses, val_losses = [], []
optimizer.zero_grad(set_to_none=True)
for i in tqdm(
    iterable=range(0, total_data_to_process, batch_size),
    desc="Processing",
    total=total_data_to_process_in_batches
):
    # Load a batch of data
    x_batch, y_batch = [], []
    for j in range(i, i+batch_size):
        x_batch.append(data[j:j+block_size])
        y_batch.append(data[j+1:j+block_size+1])

    x_batch = np.array(x_batch)
    y_batch = np.array(y_batch)

    x_batch = torch.tensor(x_batch, dtype=torch.long).to(device)
    y_batch = torch.tensor(y_batch, dtype=torch.long).to(device)

    # Forward pass
    logits, loss = model(x_batch, y_batch)
    loss /= gradient_accumulation_steps
    loss.backward()

    # Gradient accumulation
    batches_processed += 1
    if batches_processed % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    # Evaluate the model
    if batches_processed % eval_interval == 0:
        losses = estimate_loss()
        print(
            f"Batch {batches_processed}: "
            f"train loss {losses['train']:.4f}, "
            f"val loss {losses['val']:.4f}"
        )
        train_losses.append(losses['train'])
        val_losses.append(losses['val'])

    # Save the model
    if batches_processed % save_interval == 0:
        save_checkpoint(
            model=model,
            optimizer=optimizer,
            epoch=batches_processed,
            loss=loss.item(),
            file_path=f"../output/pre_training/run_11/checkpoint_{batches_processed}.pth"
        )

if batches_processed % gradient_accumulation_steps != 0:
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

Processing:   0%|          | 1002/203140 [02:46<559:35:04,  9.97s/it] 

Batch 1000: train loss 8.8364, val loss 8.9887


Processing:   1%|          | 2002/203140 [05:34<552:26:26,  9.89s/it] 

Batch 2000: train loss 9.2697, val loss 9.4190


Processing:   1%|▏         | 3002/203140 [08:20<550:37:53,  9.90s/it] 

Batch 3000: train loss 9.7622, val loss 9.9576


Processing:   2%|▏         | 4002/203140 [11:08<551:04:53,  9.96s/it] 

Batch 4000: train loss 9.9918, val loss 10.2262


Processing:   2%|▏         | 5002/203140 [13:55<544:31:08,  9.89s/it] 

Batch 5000: train loss 9.9077, val loss 10.1444


Processing:   3%|▎         | 6002/203140 [16:42<546:07:57,  9.97s/it] 

Batch 6000: train loss 10.1379, val loss 10.3038


Processing:   3%|▎         | 7002/203140 [19:30<541:14:58,  9.93s/it] 

Batch 7000: train loss 10.1546, val loss 10.3824


Processing:   4%|▍         | 8002/203140 [22:16<533:35:21,  9.84s/it]

Batch 8000: train loss 10.3148, val loss 10.5464


Processing:   4%|▍         | 9002/203140 [25:04<537:58:17,  9.98s/it]

Batch 9000: train loss 10.4091, val loss 10.6450


Processing:   5%|▍         | 9999/203140 [26:51<5:41:34,  9.42it/s]  

Batch 10000: train loss 10.1710, val loss 10.3744


Processing:   5%|▌         | 11002/203140 [30:39<524:33:09,  9.83s/it]

Batch 11000: train loss 10.8984, val loss 11.1518


Processing:   6%|▌         | 12002/203140 [33:24<521:12:24,  9.82s/it]

Batch 12000: train loss 10.5420, val loss 10.7526


Processing:   6%|▋         | 13002/203140 [36:11<524:19:14,  9.93s/it]

Batch 13000: train loss 10.7335, val loss 11.0229


Processing:   7%|▋         | 14002/203140 [38:58<524:23:29,  9.98s/it]

Batch 14000: train loss 10.7016, val loss 11.0130


Processing:   7%|▋         | 15002/203140 [41:45<515:46:33,  9.87s/it]

Batch 15000: train loss 10.3583, val loss 10.6376


Processing:   8%|▊         | 16002/203140 [44:33<511:16:14,  9.84s/it]

Batch 16000: train loss 10.5111, val loss 10.7844


Processing:   8%|▊         | 17002/203140 [47:19<515:35:54,  9.97s/it]

Batch 17000: train loss 10.9179, val loss 11.2288


Processing:   9%|▉         | 18002/203140 [50:07<513:14:54,  9.98s/it]

Batch 18000: train loss 10.6125, val loss 10.9614


Processing:   9%|▉         | 19002/203140 [52:54<502:27:36,  9.82s/it]

Batch 19000: train loss 11.1777, val loss 11.6232


Processing:  10%|▉         | 19999/203140 [54:39<5:19:53,  9.54it/s]  

Batch 20000: train loss 10.7890, val loss 11.0736


Processing:  10%|█         | 21002/203140 [58:27<503:08:17,  9.94s/it]

Batch 21000: train loss 11.2303, val loss 11.6890


Processing:  11%|█         | 22002/203140 [1:01:14<502:53:27,  9.99s/it]

Batch 22000: train loss 11.2662, val loss 11.6514


Processing:  11%|█▏        | 23002/203140 [1:04:02<501:05:27, 10.01s/it]

Batch 23000: train loss 11.1808, val loss 11.5100


Processing:  12%|█▏        | 24002/203140 [1:06:50<492:46:51,  9.90s/it]

Batch 24000: train loss 11.2231, val loss 11.6648


Processing:  12%|█▏        | 25002/203140 [1:09:38<492:32:14,  9.95s/it]

Batch 25000: train loss 11.0690, val loss 11.3552


Processing:  13%|█▎        | 26002/203140 [1:12:25<487:12:06,  9.90s/it]

Batch 26000: train loss 11.0955, val loss 11.4315


Processing:  13%|█▎        | 27002/203140 [1:15:12<487:39:40,  9.97s/it]

Batch 27000: train loss 11.2047, val loss 11.6925


Processing:  14%|█▍        | 28002/203140 [1:17:59<477:11:34,  9.81s/it]

Batch 28000: train loss 11.1704, val loss 11.4702


Processing:  14%|█▍        | 29002/203140 [1:20:45<479:34:31,  9.91s/it]

Batch 29000: train loss 11.2261, val loss 11.6258


Processing:  15%|█▍        | 29999/203140 [1:22:30<5:07:02,  9.40it/s]  

Batch 30000: train loss 10.8345, val loss 11.2438


Processing:  15%|█▌        | 31002/203140 [1:26:20<479:28:25, 10.03s/it]

Batch 31000: train loss 10.8668, val loss 11.2571


Processing:  16%|█▌        | 32002/203140 [1:29:08<468:42:53,  9.86s/it]

Batch 32000: train loss 11.2964, val loss 11.6316


Processing:  16%|█▌        | 33002/203140 [1:31:54<462:23:06,  9.78s/it]

Batch 33000: train loss 11.2474, val loss 11.6119


Processing:  17%|█▋        | 34002/203140 [1:34:40<464:21:47,  9.88s/it]

Batch 34000: train loss 11.4634, val loss 11.7738


Processing:  17%|█▋        | 35002/203140 [1:37:25<457:22:38,  9.79s/it]

Batch 35000: train loss 11.6599, val loss 12.1790


Processing:  18%|█▊        | 36002/203140 [1:40:11<461:56:34,  9.95s/it]

Batch 36000: train loss 11.8006, val loss 12.3035


Processing:  18%|█▊        | 37002/203140 [1:42:59<460:07:14,  9.97s/it]

Batch 37000: train loss 11.7089, val loss 12.0829


Processing:  19%|█▊        | 38002/203140 [1:45:47<459:46:28, 10.02s/it]

Batch 38000: train loss 11.5856, val loss 12.0529


Processing:  19%|█▉        | 39002/203140 [1:48:35<457:08:04, 10.03s/it]

Batch 39000: train loss 12.1000, val loss 12.5853


Processing:  20%|█▉        | 39999/203140 [1:50:22<4:47:49,  9.45it/s]  

Batch 40000: train loss 11.2447, val loss 11.7151


Processing:  20%|██        | 41002/203140 [1:54:11<451:40:24, 10.03s/it]

Batch 41000: train loss 12.0029, val loss 12.5644


Processing:  21%|██        | 42002/203140 [1:56:58<442:12:27,  9.88s/it]

Batch 42000: train loss 12.0068, val loss 12.5688


Processing:  21%|██        | 43002/203140 [1:59:46<437:49:41,  9.84s/it]

Batch 43000: train loss 12.0169, val loss 12.5965


Processing:  22%|██▏       | 44002/203140 [2:02:32<437:59:24,  9.91s/it]

Batch 44000: train loss 12.0426, val loss 12.5831


Processing:  22%|██▏       | 45002/203140 [2:05:17<430:31:26,  9.80s/it]

Batch 45000: train loss 11.7593, val loss 12.3518


Processing:  23%|██▎       | 46002/203140 [2:08:03<427:41:47,  9.80s/it]

Batch 46000: train loss 11.9505, val loss 12.4214


Processing:  23%|██▎       | 47002/203140 [2:10:48<425:22:29,  9.81s/it]

Batch 47000: train loss 12.4336, val loss 13.0183


Processing:  24%|██▎       | 48002/203140 [2:13:34<423:04:08,  9.82s/it]

Batch 48000: train loss 12.5088, val loss 13.1105


Processing:  24%|██▍       | 49002/203140 [2:16:22<428:26:36, 10.01s/it]

Batch 49000: train loss 12.0178, val loss 12.5876


Processing:  25%|██▍       | 49999/203140 [2:18:08<4:30:11,  9.45it/s]  

Batch 50000: train loss 11.9929, val loss 12.5964


Processing:  25%|██▌       | 51002/203140 [2:21:56<422:44:33, 10.00s/it]

Batch 51000: train loss 12.5582, val loss 13.2869


Processing:  26%|██▌       | 52002/203140 [2:24:43<414:21:52,  9.87s/it]

Batch 52000: train loss 12.0332, val loss 12.5240


Processing:  26%|██▌       | 53002/203140 [2:27:29<408:00:19,  9.78s/it]

Batch 53000: train loss 12.2188, val loss 12.8130


Processing:  27%|██▋       | 54002/203140 [2:30:16<413:57:57,  9.99s/it]

Batch 54000: train loss 12.8692, val loss 13.5555


Processing:  27%|██▋       | 55002/203140 [2:33:03<407:50:55,  9.91s/it]

Batch 55000: train loss 12.2992, val loss 12.9438


Processing:  28%|██▊       | 56002/203140 [2:35:51<410:19:39, 10.04s/it]

Batch 56000: train loss 12.5071, val loss 13.2518


Processing:  28%|██▊       | 57002/203140 [2:38:39<403:52:14,  9.95s/it]

Batch 57000: train loss 11.7964, val loss 12.5267


Processing:  29%|██▊       | 58002/203140 [2:41:24<393:57:36,  9.77s/it]

Batch 58000: train loss 12.4159, val loss 12.9738


Processing:  29%|██▉       | 59002/203140 [2:44:11<396:42:21,  9.91s/it]

Batch 59000: train loss 12.3507, val loss 12.9438


Processing:  30%|██▉       | 59999/203140 [2:45:57<4:12:39,  9.44it/s]  

Batch 60000: train loss 12.0350, val loss 12.6536


Processing:  30%|███       | 61002/203140 [2:49:46<395:48:31, 10.02s/it]

Batch 61000: train loss 12.4404, val loss 12.3951


Processing:  31%|███       | 62002/203140 [2:52:34<386:47:05,  9.87s/it]

Batch 62000: train loss 12.1945, val loss 12.2200


Processing:  31%|███       | 63002/203140 [2:55:21<388:26:28,  9.98s/it]

Batch 63000: train loss 12.3456, val loss 12.2884


Processing:  32%|███▏      | 64002/203140 [2:58:08<381:24:01,  9.87s/it]

Batch 64000: train loss 12.0875, val loss 11.9454


Processing:  32%|███▏      | 65002/203140 [3:00:54<379:38:55,  9.89s/it]

Batch 65000: train loss 12.5839, val loss 12.4020


Processing:  32%|███▏      | 66002/203140 [3:03:41<379:10:54,  9.95s/it]

Batch 66000: train loss 12.4237, val loss 12.1079


Processing:  33%|███▎      | 67002/203140 [3:06:29<379:06:36, 10.03s/it]

Batch 67000: train loss 12.1588, val loss 11.8565


Processing:  33%|███▎      | 68002/203140 [3:09:18<378:02:34, 10.07s/it]

Batch 68000: train loss 12.6163, val loss 12.3701


Processing:  34%|███▍      | 69002/203140 [3:12:08<373:30:03, 10.02s/it]

Batch 69000: train loss 12.5054, val loss 12.1844


Processing:  34%|███▍      | 69999/203140 [3:13:54<3:55:14,  9.43it/s]  

Batch 70000: train loss 11.8696, val loss 11.5072


Processing:  35%|███▍      | 71002/203140 [3:17:42<361:37:23,  9.85s/it]

Batch 71000: train loss 12.1058, val loss 11.7372


Processing:  35%|███▌      | 72002/203140 [3:20:27<357:27:18,  9.81s/it]

Batch 72000: train loss 11.9178, val loss 11.4090


Processing:  36%|███▌      | 73002/203140 [3:23:14<356:44:04,  9.87s/it]

Batch 73000: train loss 12.2435, val loss 11.9496


Processing:  36%|███▋      | 73999/203140 [3:25:17<5:58:16,  6.01it/s]  


KeyboardInterrupt: 

In [27]:
input_tokens = tokenizer.encode("Hola, como ")
input_tokens = torch.tensor(
    input_tokens, dtype=torch.long).unsqueeze(0).to(device)

model.eval()
with torch.no_grad():
    output = model.generate(input_tokens=input_tokens, max_new_tokens=50)

print(tokenizer.decode(output[0].tolist()))

Hola, como 14, debemos votar a otros productos agrícola?
 Si haber tomado la vez más parte avanzada y año 69, presentada por haber tomado la Sra. Aparicio Sánchez por la iniciativa de él.
    –    – Aparicio S


In [28]:
save_checkpoint(
            model=model,
            optimizer=optimizer,
            epoch=batches_processed,
            loss=loss.item(),
            file_path=f"../output/pre_training/run_11/checkpoint_{batches_processed}.pth"
        )