# Test Transformer & Trainer

In this notebook, I test my implementations of `Transformer` and `Trainer` on the TinyStories dataset.

In [1]:
import os
import sys

# Put the parent directory on the import path so that the `src.*` packages
# imported further down resolve when the notebook runs from its own folder.
module_path = os.path.abspath(os.pardir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

import torch
import torch.nn as nn
import torch.nn.functional as F

from src.lion.lion import Lion
from src.transformer.transformer import Transformer
from src.training.trainer import Trainer

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

In [2]:
from datasets import load_dataset

# Load the "Qilex/TinyStories_10M" dataset through Hugging Face `datasets`.
# The result is indexed by split name in the next cell (`ds['train']`).
ds = load_dataset("Qilex/TinyStories_10M")

In [3]:
# Take the 'text' column of the 'train' split; these are the raw texts that
# are later handed to TextDataset.
text_data = ds['train']['text']

In [4]:
from tqdm import tqdm
from transformers import T5TokenizerFast

# Load the pretrained "t5-small" tokenizer. In this notebook it supplies
# `encode`, `pad_token_id` and `vocab_size`; the model itself is the custom
# Transformer imported from `src`.
tokenizer = T5TokenizerFast.from_pretrained("t5-small")

In [5]:
from torch.utils.data import Dataset

class TextDataset(Dataset):
    """Next-token-prediction dataset built from raw text strings.

    Each item is tokenized on access and split into an input sequence and the
    same sequence shifted left by one position.

    Args:
        texts: Indexable, sized collection of strings; ``texts[idx]`` is
            tokenized on demand in ``__getitem__``.
        tokenizer: Object exposing ``encode(...)`` and ``pad_token_id``
            (a Hugging Face tokenizer in this notebook).
        max_length: Length cap forwarded to ``tokenizer.encode``; truncation
            is always enabled.
        padding: Padding strategy forwarded to ``tokenizer.encode``. The
            default ``'max_length'`` keeps the original behaviour of padding
            every item to ``max_length``. Pass ``False`` to get unpadded
            items and let a padding collate function (such as ``collate_fn``
            in this notebook) pad each batch only to its longest item.
    """

    def __init__(self, texts, tokenizer, max_length=2048, padding='max_length'):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = padding

    def __len__(self):
        """Return the number of texts in the dataset."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize ``texts[idx]`` and return the shifted input/target pair.

        Returns:
            Dict with
            ``"src"``: all token ids except the last one,
            ``"tgt"``: all token ids except the first one (``src`` shifted by one),
            ``"src_mask"``: float tensor, 1.0 where ``src`` is not
            ``pad_token_id`` and 0.0 where it is.
        """
        # return_tensors="pt" yields a tensor with a leading batch dimension
        # of size 1; squeeze(0) drops it so each item is a single sequence.
        tokens = self.tokenizer.encode(
            self.texts[idx],
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
            padding=self.padding,
        ).squeeze(0)

        # Teacher-forcing shift: position i of src is trained to predict
        # position i of tgt, which is token i + 1 of the original sequence.
        src = tokens[:-1]
        tgt = tokens[1:]

        src_mask = (src != self.tokenizer.pad_token_id).float()

        return {
            "src": src,
            "tgt": tgt,
            "src_mask": src_mask,
        }

In [6]:
# Wrap the raw texts. The constructor only stores its arguments, so
# tokenization happens lazily in TextDataset.__getitem__, one item at a time.
text_dataset = TextDataset(texts=text_data, tokenizer=tokenizer)

In [7]:
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    """Merge a list of dataset items into one dict of padded batch tensors.

    Args:
        batch: Sequence of dicts holding "src", "tgt" and "src_mask" tensors
            (the items produced by ``TextDataset.__getitem__``).

    Returns:
        Dict with the same three keys. Each value is built with
        ``pad_sequence(..., batch_first=True)``: "src" and "tgt" are filled
        with ``tokenizer.pad_token_id`` and "src_mask" with 0.
    """
    # `tokenizer` is the notebook-level tokenizer created in an earlier cell.
    pad_id = tokenizer.pad_token_id
    fill_by_key = (("src", pad_id), ("tgt", pad_id), ("src_mask", 0))

    collated = {}
    for key, fill in fill_by_key:
        sequences = [sample[key] for sample in batch]
        collated[key] = pad_sequence(sequences, batch_first=True, padding_value=fill)
    return collated

# Small shuffled loader used to inspect one batch in the next cell. The
# Trainer below receives the datasets and collate_fn directly, together with
# its own batch_size argument, so this loader is not what gets trained on.
batch_size = 2
data_loader = DataLoader(text_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

In [8]:
# Sanity check: print the first batch and the shape of each of its tensors,
# then stop. The loader shuffles, so the printed ids differ between runs.
for batch in data_loader:
    print(batch)
    for key in ('src', 'tgt', 'src_mask'):
        print(batch[key].shape)
    break

{'src': tensor([[1447, 1286,    3,  ...,    0,    0,    0],
        [1447, 1286,    3,  ...,    0,    0,    0]]), 'tgt': tensor([[1286,    3,    9,  ...,    0,    0,    0],
        [1286,    3,    9,  ...,    0,    0,    0]]), 'src_mask': tensor([[1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.]])}
torch.Size([2, 2047])
torch.Size([2, 2047])
torch.Size([2, 2047])


In [9]:
from torch.utils.data import random_split

# Fraction of the samples used for training; the rest is held out for evaluation.
split_ratio = 0.8
size = len(text_dataset)
train_size = int(size * split_ratio)
eval_size = size - train_size  # remainder, so both sizes always sum to `size`

# Seed the split so the same samples land in train/eval on every
# Restart & Run All; an unseeded random_split reshuffles on each run.
split_seed = 42
train_dataset, eval_dataset = random_split(
    text_dataset,
    [train_size, eval_size],
    generator=torch.Generator().manual_seed(split_seed),
)

In [10]:
from src.training.metrics import compute_metrics

# The tokenizer's vocabulary size is used for both the input and the output side.
vocab_size = tokenizer.vocab_size

model_config = dict(
    num_layers=6,
    d_model=512,
    num_heads=8,
    d_ff=256,
    input_dim=vocab_size,
    output_dim=vocab_size,
    max_len=2048,  # same value as TextDataset's default max_length
)

# NOTE(review): the recorded `trainer.train()` run fails with
# "mat1 and mat2 shapes cannot be multiplied (8188x128 and 256x512)".
# 8188 = 4 * 2047 (trainer batch size x sequence length) and 256x512 matches
# d_ff x d_model, so a 128-wide (= d_ff / 2) activation seems to reach a
# d_ff -> d_model projection. That points at the feed-forward code in
# src.transformer rather than at these arguments — TODO confirm.
model = Transformer(**model_config)

trainer_config = dict(
    # What to train and where.
    model=model,
    device=device,
    # Data: the Trainer gets the split datasets plus the padding collate function.
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    collate_fn=collate_fn,
    batch_size=4,
    # Optimisation: the Lion class itself is passed (not an instance), with
    # its keyword arguments given separately.
    optimizer=Lion,
    optimizer_kwargs=dict(lr=1e-4),
    # NOTE(review): F.cross_entropy is passed as-is, so no ignore_index is set
    # here; whether padded target positions are excluded from the loss depends
    # on Trainer — confirm.
    loss_fn=F.cross_entropy,
    metrics_fn=compute_metrics,
    num_epochs=100,
)

trainer = Trainer(**trainer_config)

In [11]:
# NOTE(review): in the recorded run this call raises a RuntimeError (shape
# mismatch in a matrix multiply, see the output below) while the progress bar
# is still at 0/12461, i.e. before the first training step completes.
trainer.train()

Epoch 1/100:   0%|          | 0/12461 [00:00<?, ?it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (8188x128 and 256x512)