In [None]:
import torch
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import AdamW, get_scheduler, set_seed

from src.classification import SequenceClassificationTask
from src.dataloaders import MultitaskMetabatchSampler
from src.models import MultitaskModel
from src.preprocessing import Preprocessor, NLIPreprocess
from src.ru import preprocess_danetqa, preprocess_parus
from src.tasks import Tasks
from src.tokenizers import TokenizerConfig

In [None]:
# Fix the random seed before any data, sampler or model is created so that a
# "Restart & Run All" repeats the same run. `set_seed` seeds Python's `random`,
# NumPy and PyTorch; anything that draws from those generators becomes repeatable.
SEED = 42
set_seed(SEED)

# Hugging Face dataset identifier from which every task subset is loaded.
rsg = "russian_super_glue"
# Tokenizer settings shared by all tasks.
cfg = TokenizerConfig(max_length=512)
# Pretrained encoder checkpoint; passed both to the task collection and to the model.
encoder_path = "DeepPavlov/rubert-base-cased"

In [None]:
# Every task is a sequence-classification task over one subset of the `rsg`
# dataset. They differ only in the subset name and the preprocessing step, and
# all share the tokenizer configuration `cfg`.
danetqa_task = SequenceClassificationTask(
    name="danetqa",
    dataset_dict=load_dataset(rsg, "danetqa"),
    preprocessor=Preprocessor([preprocess_danetqa]),
    tokenizer_config=cfg,
)

parus_task = SequenceClassificationTask(
    name="parus",
    dataset_dict=load_dataset(rsg, "parus"),
    preprocessor=Preprocessor([preprocess_parus]),
    tokenizer_config=cfg,
)

terra_task = SequenceClassificationTask(
    name="terra",
    dataset_dict=load_dataset(rsg, "terra"),
    preprocessor=Preprocessor([NLIPreprocess()]),
    tokenizer_config=cfg,
)

# Bundle the tasks together with the encoder checkpoint; later cells read
# `tasks.data` (for the sampler) and `tasks.heads` (for the model).
tasks = Tasks([danetqa_task, parus_task, terra_task], encoder_path)

In [None]:
# Sampler over the "train" split of all tasks. Iterating it yields meta-batches,
# and its length is used to size the learning-rate schedule.
train_sampler = MultitaskMetabatchSampler(
    tasks.data,
    "train",
    batch_size=12,
)

In [None]:
# Multitask model built from the encoder checkpoint and the heads exposed by
# the task collection.
model = MultitaskModel(
    encoder_path,
    tasks.heads,
)

In [None]:
# Optimizer and learning-rate schedule for the training run.
#
# `torch.optim.AdamW` is used instead of the `AdamW` class from transformers,
# which that library has deprecated in favour of the PyTorch implementation.
# `eps` and `weight_decay` are pinned to the defaults of the transformers class,
# because the PyTorch defaults differ (eps=1e-8, weight_decay=1e-2) and would
# otherwise silently change the training hyperparameters.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=5e-5,
    eps=1e-6,
    weight_decay=0.0,
)
num_epochs = 1
# The scheduler is stepped once per meta-batch, so the schedule spans
# (number of epochs) x (meta-batches per epoch) steps.
num_training_steps = num_epochs * len(train_sampler)
# Linear schedule with no warmup steps.
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the model's parameters to the chosen device.
model.to(device)
# Put the model into training mode.
model.train()

In [None]:
# Multitask training loop. Each item yielded by `train_sampler` is a meta-batch
# mapping a task name to that task's batch; one optimizer step is taken per
# meta-batch.
progress_bar = tqdm(total=num_training_steps)
for epoch_num in range(num_epochs):
    for meta_batch in train_sampler:
        # Accumulate the losses of all tasks in the meta-batch so that a single
        # backward pass covers every task.
        loss = 0
        for name, data in meta_batch.items():
            # NOTE(review): the return value is discarded, so this relies on
            # `.to()` moving the batch in place — confirm for the batch type
            # produced by the sampler.
            data.to(device)
            # Call the module itself rather than `.forward` so that any hooks
            # registered on it still run (assumes MultitaskModel is a
            # torch.nn.Module, as its use of .to()/.train()/.parameters() suggests).
            outputs = model(name, only_head={"labels"}, **data)
            loss += outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        # Show the summed loss on the progress bar as a plain float instead of
        # printing the tensor repr on every step, which floods the cell output.
        progress_bar.set_postfix(loss=f"{loss.item():.4f}")
        progress_bar.update(1)
progress_bar.close()