In [1]:
import torch

import numpy as np
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

import matplotlib.pyplot as plt

from src.mlx.models.gpt_mlx import GPT
from src.mlx.models.gpt_mlx import generate_attention_mask
from src.mlx.models.gpt_mlx import GPTParams
from src.mlx.utils.mlx_lm_dataset import LanguageModelingDataset

from src.tokenizers.tokenizer import Tokenizer

from tqdm.notebook import tqdm

In [2]:
# Load the pretrained tokenizer; its vocabulary size sets the model's vocab_size in the config cell.
tokenizer = Tokenizer.from_pretrained("tokenizers/tokenizer_ru_toxics/")

In [3]:
# Model / training configuration: every tunable knob lives in this cell.
params = GPTParams(
    vocab_size=len(tokenizer.vocab),
    context_size=128,
    input_dim=312,
    query_dim=32,
    value_dim=32,
    feed_forward_hidden_dim=1024,
    n_heads=4,
    n_decoder_blocks=6,
)

BATCH_SIZE = 16
N_EPOCHS = 1
LEARNING_RATE = 2e-5
SEED = 42

# Seed the RNGs before the model is constructed (next cell) so parameter
# initialisation is reproducible under Restart & Run All. numpy is seeded too
# in case the dataloader shuffles with it (unverified — harmless otherwise).
mx.random.seed(SEED)
np.random.seed(SEED)

# Attention mask built once and reused by every forward pass.
# NOTE(review): the length is context_size - 1, presumably because each sample
# is a context_size window split into shifted inputs/labels — confirm against
# LanguageModelingDataset.
MASK = generate_attention_mask(params.context_size - 1)

In [4]:
# Build the GPT model from the hyper-parameters defined in the config cell.
model = GPT(params)

In [5]:
# Load the pre-tokenised corpus (presumably written with torch.save).
# NOTE(review): torch.load unpickles the file — only load files you produced yourself.
tokens = torch.load("preprocessed_datasets/ru_toxics_tokens")
# Convert through a single contiguous numpy buffer instead of `.tolist()`, which
# creates one Python int object per token before MLX copies it again.
# `.cpu()` keeps this working for tensors saved from an accelerator, and int32
# matches the dtype MLX infers for a list of Python ints, so downstream code
# sees the same dtype as before.
tokens = mx.array(torch.as_tensor(tokens).cpu().numpy().astype(np.int32))

In [6]:
# Wrap the token stream in a dataset and batch it for training.
# NOTE(review): params.context_size is not passed here, so the sample length
# comes from LanguageModelingDataset's own default — confirm it matches the
# length MASK was built for (params.context_size - 1).
dataset = LanguageModelingDataset(tokens)
dataloader = dataset.to_dataloader(batch_size=BATCH_SIZE)

In [7]:
def loss_func(model, inputs, labels, mask=None):
    """Mean token-level cross-entropy of ``model`` on one batch.

    Args:
        model: the GPT module; called as ``model(inputs, mask)``.
        inputs: token ids fed to the model.
        labels: target token ids; flattened to line up with the flattened logits.
        mask: attention mask passed to the model. ``None`` (default) uses the
            global ``MASK`` from the config cell, looked up at call time so a
            re-run of that cell is picked up.

    Returns:
        Scalar mean cross-entropy over every label position.
    """
    if mask is None:
        mask = MASK
    logits = model(inputs, mask)
    # Collapse all leading axes so each row holds the class logits for one label
    # position; shape[-1] avoids hard-coding the rank of the model output.
    logits = logits.reshape(-1, logits.shape[-1])

    return mx.mean(nn.losses.cross_entropy(logits, labels.flatten()))


optimizer = optim.Adam(learning_rate=LEARNING_RATE)

# Wrapped loss returning (loss, grads), with grads taken w.r.t. the model's
# trainable parameters.
loss_and_grad_fn = nn.value_and_grad(model, loss_func)

In [11]:
# Batches per epoch for the progress bar (ceil division; assumes the dataloader
# keeps the final partial batch — TODO confirm).
n_batches = -(-len(dataset) // BATCH_SIZE)
train_losses = []  # per-step training loss, kept for plotting after the run

for epoch in range(N_EPOCHS):
    progress = tqdm(dataloader, total=n_batches, desc=f"epoch {epoch + 1}/{N_EPOCHS}")
    for inputs, labels in progress:
        inputs, labels = mx.array(inputs), mx.array(labels)

        loss, grads = loss_and_grad_fn(model, inputs, labels)
        optimizer.update(model, grads)

        # MLX is lazy: without forcing evaluation here each step only appends to
        # the compute graph, so memory grows without bound and no actual update
        # is computed until something finally reads the parameters.
        mx.eval(model.parameters(), optimizer.state, loss)

        train_losses.append(loss.item())
        progress.set_postfix(loss=f"{train_losses[-1]:.4f}")

  0%|          | 0/3322 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [8]:
# Grab a single batch for interactive debugging. `next(iter(...))` raises on an
# empty dataloader instead of silently keeping a stale `batch` left over from
# the training cell's loop variable.
batch = next(iter(dataloader))

In [10]:
# Split the debug batch into (inputs, labels) and convert each half to an MLX array.
inputs, labels = map(mx.array, batch)

In [11]:
# One forward/backward pass on the debug batch (no optimizer step is taken).
loss, grads = loss_and_grad_fn(model, inputs, labels)

In [12]:
# Display the debug-batch loss; for reference, a uniform predictor scores ln(vocab_size).
loss

array(7.9926, dtype=float32)