In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt

from torch.optim import Adam
from torch.nn import CrossEntropyLoss

from src.torch.models.gpt import GPTParams, GPT
from src.torch.models.gpt import generate_attention_mask

from src.tokenizers.tokenizer import Tokenizer
from src.torch.utils.lm_dataset import LanguageModelingDataset

from torch.utils.data import DataLoader

from tqdm.notebook import tqdm

# Device string handed to GPTParams and used to place the attention mask and
# every training batch (see the config cell and the training loop).
DEVICE = "cpu"

In [5]:
# Path given to Tokenizer.from_pretrained; the tokenizer's vocab later fixes the
# model's vocab_size.
TOKENIZER_DIR = "tokenizers/tokenizer_ru_toxics/"
tokenizer = Tokenizer.from_pretrained(TOKENIZER_DIR)

In [6]:
# Token data wrapped in LanguageModelingDataset (presumably pre-tokenized text,
# judging by the path — confirm against the preprocessing step).
#
# map_location="cpu" keeps the load working on machines that lack the device the
# file was saved from; batches are moved to DEVICE inside the training loop, so
# the dataset itself can stay in host memory.
#
# NOTE: torch.load unpickles the file — only load files from a trusted source.
TOKENS_PATH = "preprocessed_datasets/ru_toxics_tokens"
dataset = LanguageModelingDataset(torch.load(TOKENS_PATH, map_location="cpu"))

In [7]:
# Seed torch's global RNG so that DataLoader shuffling and any torch-based weight
# initialisation performed after this cell are repeatable across kernel restarts.
SEED = 42
torch.manual_seed(SEED)

# Model hyper-parameters. vocab_size follows the tokenizer's vocabulary and
# context_size follows the second dimension of the dataset's token tensor.
params = GPTParams(
    vocab_size=len(tokenizer.vocab),
    context_size=dataset.tokens.shape[1],
    input_dim=312,
    query_dim=32,
    value_dim=32,
    feed_forward_hidden_dim=1024,
    n_heads=4,
    n_decoder_blocks=6,
    device=DEVICE
)

# Training hyper-parameters.
BATCH_SIZE = 16
N_EPOCHS = 10
LEARNING_RATE = 2e-5

# Attention mask built once and reused for every batch. It covers
# context_size - 1 positions — presumably because inputs and labels are the
# token sequence shifted by one; TODO confirm against LanguageModelingDataset.
MASK = generate_attention_mask(params.context_size - 1).to(DEVICE)

dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

In [8]:
# Build the network from the hyper-parameters defined in the config cell.
model = GPT(params)

# Adam over every model parameter with the configured learning rate.
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

# Token-level cross-entropy with default settings.
# NOTE(review): no ignore_index is set, so if sequences are padded the padding
# positions contribute to the loss — confirm this is intended.
loss_func = CrossEntropyLoss()

In [6]:
# Per-step training loss; reset every time this cell is run.
losses = []

for epoch in tqdm(range(N_EPOCHS)):
    for batch in tqdm(dataloader, desc=f"Epoch #{epoch}"):
        # Relies on the batch mapping yielding inputs first and labels second.
        # NOTE(review): confirm the key order produced by LanguageModelingDataset.
        inputs, labels = batch.values()

        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)

        # Clear gradients BEFORE the forward/backward pass. If a previous run of
        # this cell was interrupted between backward() and the old trailing
        # zero_grad(), stale gradients would otherwise be added into this step.
        optimizer.zero_grad()

        logits = model(inputs, MASK)
        # Collapse all leading dimensions so each row is one prediction over the
        # last (class) dimension; reshape also works if logits is non-contiguous.
        logits = logits.reshape(-1, logits.size(-1))

        loss = loss_func(logits, labels.flatten())
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch #0:   0%|          | 0/3322 [00:00<?, ?it/s]

KeyboardInterrupt: 