In [1]:
import torch

import numpy as np
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

import matplotlib.pyplot as plt

from src.mlx.models.gpt_mlx import GPT
from src.mlx.models.gpt_mlx import generate_attention_mask
from src.mlx.models.gpt_mlx import GPTParams
from src.mlx.utils.mlx_lm_dataset import LanguageModelingDataset

from src.tokenizers.tokenizer import Tokenizer

from tqdm.notebook import tqdm

In [2]:
tokenizer = Tokenizer.from_pretrained("tokenizers/tokenizer_ru_toxics/")

In [3]:
params = GPTParams(
    vocab_size=len(tokenizer.vocab),
    context_size=128,
    input_dim=312,
    query_dim=32,
    value_dim=32,
    feed_forward_hidden_dim=1024,
    n_heads=4,
    n_decoder_blocks=6,
)

BATCH_SIZE = 16
N_EPOCHS = 1
LEARNING_RATE = 2e-5
MASK = generate_attention_mask(params.context_size - 1)

In [4]:
model = GPT(params)

In [5]:
tokens = torch.load("preprocessed_datasets/ru_toxics_tokens")
tokens = mx.array(tokens.tolist())

In [6]:
dataset = LanguageModelingDataset(tokens)
dataloader = dataset.to_dataloader(batch_size=BATCH_SIZE)

In [10]:
def loss_func(model, inputs, labels):
    logits = model(inputs, MASK)
    logits = logits.reshape(-1, logits.shape[2])
    
    return mx.mean(nn.losses.cross_entropy(logits, labels.flatten()))

optimizer = optim.Adam(learning_rate=LEARNING_RATE)

loss_and_grad_fn = nn.value_and_grad(model, loss_func)

In [11]:
for epoch in range(N_EPOCHS):
    for batch in tqdm(dataloader, total=np.ceil(len(dataset) / BATCH_SIZE).astype(int)):
        inputs, labels = batch
        inputs, labels = mx.array(inputs), mx.array(labels)

        loss, grads = loss_and_grad_fn(model, inputs, labels)

        optimizer.update(model, grads)

  0%|          | 0/3322 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [8]:
for batch in dataloader:
    break

In [12]:
inputs, labels = batch.values()

inputs, labels = mx.array(inputs), mx.array(labels)

In [13]:
loss, grads = loss_and_grad_fn(model, inputs, labels)

In [14]:
loss

array(8.05253, dtype=float32)

In [15]:
grads

{'embeddings': {'weight': array([[0.000822256, 0.00189436, -0.00115391, ..., 0.000859065, 0.00152886, -0.000539912],
         [-4.74873e-05, 0.000133152, 0.000549505, ..., 7.5604e-05, -0.000456788, -0.000232249],
         [-0.00126738, 0.000199548, -0.00223851, ..., 0.00135388, 0.00260444, 0.000812456],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]], dtype=float32)},
 'positional_embeddings': array([[-0.000477559, -0.00180093, -0.000761468, ..., 0.000485165, 0.000498768, 4.24828e-05],
        [-0.000913079, -0.000492609, -0.00233857, ..., 0.000624657, 0.000340673, -0.000682636],
        [7.30869e-05, -0.000349676, -0.000795952, ..., -7.43272e-05, 0.000984964, -0.000116955],
        ...,
        [-0.00122993, 0.000861806, -1.87649e-05, ..., 0.000915424, -0.00019698, 0.000572759],
        [0.000179105, -0.000335839, 0.000451862, ..., 0.00012408, 0.000157587, 0.000638091],
        [0, 0, 0, ..., 0, 0, 0]], dtype=float32)