In [1]:
import argparse
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '3'

import torch
from sklearn.model_selection import train_test_split
from transformers import (MT5ForConditionalGeneration, MT5Tokenizer, Trainer,
                          TrainingArguments)

from utils import (ToxicDataset, fix_tokenizer, load_data, load_only_russian,
                   set_random_seed)
import pandas as pd

In [2]:
# Tokenizer and model come from the same pretrained checkpoint.
tokenizer = MT5Tokenizer.from_pretrained('google/mt5-base')

# 'cuda:0' is indexed among the devices exposed via CUDA_VISIBLE_DEVICES
# (set in the import cell).
model = (
    MT5ForConditionalGeneration
    .from_pretrained('google/mt5-base')
    .to(torch.device('cuda:0'))
)

In [3]:
class ToxicDataset(torch.utils.data.Dataset):
    """Parallel detoxification dataset: toxic sentence in, neutral rewrite out.

    Rows are tokenized lazily in ``__getitem__``. Source and target are both
    truncated/padded to exactly ``max_length`` tokens.

    Args:
        data: Frame providing the ``en_toxic_comment`` (model input) and
            ``neutral_comment`` (generation target) columns.
        tokenizer: Callable tokenizer with the Hugging Face call signature
            (``max_length``, ``truncation``, ``padding``, ``return_tensors``).
            If it exposes ``pad_token_id``, that id is used for label masking.
        max_length: Fixed token length for both source and target.
            Defaults to 60, the value previously hard-coded.
    """

    # Label value excluded from the loss (PyTorch cross-entropy's default
    # ``ignore_index``, which the HF seq2seq models rely on).
    IGNORE_INDEX = -100

    def __init__(self, data: pd.DataFrame, tokenizer, max_length: int = 60):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def _encode(self, text):
        """Tokenize one string to a fixed-length batch-of-one encoding."""
        # `padding="max_length"` alone is sufficient; the deprecated
        # `pad_to_max_length=True` flag was redundant and has been dropped.
        return self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

    def __getitem__(self, idx):
        """Return the ``idx``-th example as a dict of 1-D tensors.

        The dict holds every tensor produced by the tokenizer for the source
        text plus ``labels`` (target token ids with padding masked out).
        """
        row = self.data.iloc[idx]
        source = self._encode(row.en_toxic_comment)
        target = self._encode(row.neutral_comment)

        labels = target["input_ids"]
        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is not None:
            # Without this, the loss is also computed on padding positions,
            # so the model is trained to emit <pad> after every target.
            labels = labels.masked_fill(labels == pad_id, self.IGNORE_INDEX)

        # The tokenizer returns a batch dimension of size 1; drop it so the
        # DataLoader can stack examples itself.
        item = {key: tensor.squeeze(0) for key, tensor in source.items()}
        item["labels"] = labels.squeeze(0)
        return item

    def __len__(self):
        """Number of rows (examples) in the underlying frame."""
        return len(self.data)

In [4]:
# Seed through the project helper before any randomised step.
set_random_seed(42)

data = pd.read_csv(
    'data/english_data/jigsaw_twitter_reddit_7_translated.tsv',
    sep='\t',
)

# Hold out 5% of the row labels for validation; the split itself is seeded.
train_ids, valid_ids = train_test_split(
    data.index.values,
    test_size=0.05,
    random_state=42,
)

# Keep only the (toxic source, neutral target) columns for each split.
train_part = data.loc[train_ids, ['en_toxic_comment', 'neutral_comment']]
valid_part = data.loc[valid_ids, ['en_toxic_comment', 'neutral_comment']]

trainset = ToxicDataset(train_part, tokenizer)
valset = ToxicDataset(valid_part, tokenizer)

In [5]:
train_args = TrainingArguments(
    # Output location; with save_strategy="no" no checkpoints are written.
    output_dir='cross_lingual_mt5',
    save_strategy="no",

    # What to run.
    do_train=True,
    do_eval=True,

    # Optimisation schedule. max_steps takes precedence over num_train_epochs.
    max_steps=10_000,
    learning_rate=3e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,

    # Logging / evaluation cadence.
    # NOTE(review): no eval_steps is given; per the HF docs it should fall
    # back to logging_steps — confirm for the installed version.
    logging_steps=100,
    evaluation_strategy="steps",

    seed=42,
)

In [6]:
# Bundle the model, both dataset splits and the run configuration into a
# single Trainer instance.
trainer = Trainer(
    args=train_args,
    model=model,
    tokenizer=tokenizer,
    train_dataset=trainset,
    eval_dataset=valset,
)

max_steps is given, it will override any value given in num_train_epochs


In [7]:
# Launch fine-tuning; the run length is bounded by max_steps in train_args.
trainer.train()

***** Running training *****
  Num examples = 18777
  Num Epochs = 3
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 1
  Total optimization steps = 10000
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
