# Imports

In [1]:
import warnings

warnings.filterwarnings('ignore')

In [2]:
import transformers
import datasets
import random
import numpy as np
import pandas as pd

from typing import Callable, Any, Dict
from pathlib import Path
from datasets import load_dataset, load_metric
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from IPython.display import display, HTML

transformers.manual_seed(432)

In [3]:
%load_ext tensorboard

# Loading data

In [8]:
base_path = Path('..').resolve()
data_path = base_path / Path('data/raw/filtered.tsv')
model_cktp_path = base_path / 'models' / 'pretrained.pt'

In [9]:
AMOUNT_OF_PAIRS = 50000
EPOCHS = 15
VAL_RATIO = 0.1
MAX_LENGTH = 75
MIN_TOX = 0.75
MODEL_CHECKPOINT = "t5-small"
SEED=432

In [10]:
transformers.set_seed(42)
raw_data = pd.read_csv(data_path, sep='\t', index_col=False)
raw_data = raw_data[raw_data.columns[1:]]
raw_data.head()

Unnamed: 0,reference,translation,similarity,lenght_diff,ref_tox,trn_tox
0,"If Alkar is flooding her with psychic waste, t...","if Alkar floods her with her mental waste, it ...",0.785171,0.010309,0.014195,0.981983
1,Now you're getting nasty.,you're becoming disgusting.,0.749687,0.071429,0.065473,0.999039
2,"Well, we could spare your life, for one.","well, we can spare your life.",0.919051,0.268293,0.213313,0.985068
3,"Ah! Monkey, you've got to snap out of it.","monkey, you have to wake up.",0.664333,0.309524,0.053362,0.994215
4,I've got orders to put her down.,I have orders to kill her.,0.726639,0.181818,0.009402,0.999348


In [11]:
raw_data = raw_data[
    (raw_data['ref_tox'] >= MIN_TOX) &
    (raw_data['trn_tox'] <= 1 - MIN_TOX)
]

In [12]:
raw_data['id'] = pd.RangeIndex(0, len(raw_data))
train_split, val_split = train_test_split(
    range(raw_data[raw_data['id'] < AMOUNT_OF_PAIRS]['id'].max() + 1),
    test_size=VAL_RATIO,
    random_state=SEED
)
train_dataframe = raw_data[raw_data['id'].isin(train_split)]
val_dataframe = raw_data[raw_data['id'].isin(val_split)]

# Dataset

In [13]:
ModelTokenizer = transformers.models.t5.tokenization_t5_fast.T5TokenizerFast
model_tokenizer: ModelTokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)

In [14]:
class DeToxificationDataset(Dataset):
    def __init__(self,
                 dataframe: pd.DataFrame,
                 tokenizer: ModelTokenizer,
                 max_length: int):
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.reference = self.dataframe['reference'].values
        self.translation = self.dataframe['translation'].values

    def __getitem__(self, index) -> Dict[str, Any]:
        inputs = self.reference[index]
        targets = self.translation[index]
        model_inputs = self.tokenizer.__call__(inputs, max_length=self.max_length, truncation=True)
        labels = self.tokenizer.__call__(targets, max_length=self.max_length, truncation=True)    
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    def __len__(self) -> int:
        return len(self.dataframe)

In [15]:
train_dataset = DeToxificationDataset(train_dataframe, model_tokenizer, MAX_LENGTH)
val_dataset = DeToxificationDataset(val_dataframe, model_tokenizer, MAX_LENGTH)

In [16]:
train_dataset[0]

{'input_ids': [27, 31, 51, 59, 3, 13366, 43, 3, 9, 861, 233, 3, 233, 4065, 8, 337, 6472, 9311, 38, 140, 113, 31, 7, 3, 13366, 67, 5, 301, 233, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [27, 31, 51, 59, 352, 12, 8885, 1082, 28, 3, 9, 6472, 9311, 24, 656, 135, 67, 5, 1]}

In [18]:
def random_sample(some_dataset: DeToxificationDataset, tokenizer: ModelTokenizer):
    idx = np.random.randint(0, len(some_dataset))
    model_inputs = some_dataset[idx]
    ref = tokenizer.batch_decode(
        model_inputs['input_ids'], 
        skip_special_tokens=True
    )
    trn = tokenizer.batch_decode(
        model_inputs['labels'], 
        skip_special_tokens=True
    )
    print(' '.join(ref))
    print(' '.join(trn))

random_sample(train_dataset, train_dataset.tokenizer)

Death . 
 he ' s dead . 


In [19]:
len(train_dataset), len(val_dataset)

(45000, 5000)

# Fine-tuning the model

In [20]:
metric = load_metric("sacrebleu")

In [21]:
model = transformers.AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT)

In [22]:
# defining the parameters for training
batch_size = 25
args = transformers.Seq2SeqTrainingArguments(
    f"{MODEL_CHECKPOINT}-finetuned-de-toxification",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=EPOCHS,
    predict_with_generate=True,
    fp16=True,
    report_to=['tensorboard'],
    seed=SEED,
)

In [23]:
data_collator = transformers.DataCollatorForSeq2Seq(model_tokenizer, model=model)

In [24]:
# simple postprocessing for text
def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]

    return preds, labels

# compute metrics function to pass to trainer
def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    decoded_preds = model_tokenizer.batch_decode(preds, skip_special_tokens=True)
    
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, model_tokenizer.pad_token_id)
    decoded_labels = model_tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    prediction_lens = [np.count_nonzero(pred != model_tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result

In [25]:
# instead of writing train loop we will use Seq2SeqTrainer
trainer = transformers.Seq2SeqTrainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    tokenizer=model_tokenizer,
    compute_metrics=compute_metrics
)

In [22]:
trainer.train()

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Bleu,Gen Len
1,2.1255,1.714027,23.6079,13.2658
2,1.8366,1.656125,24.5168,13.2156
3,1.773,1.627491,24.742,13.2044
4,1.7528,1.610212,24.9668,13.1194
5,1.7205,1.595325,25.2218,13.1142
6,1.7064,1.585188,25.3431,13.0936
7,1.6948,1.577094,25.4473,13.0886
8,1.6816,1.570604,25.5223,13.1
9,1.6722,1.565133,25.6639,13.0674
10,1.6634,1.562667,25.7071,13.0354


TrainOutput(global_step=13500, training_loss=1.7195449806495948, metrics={'train_runtime': 3867.0013, 'train_samples_per_second': 174.554, 'train_steps_per_second': 3.491, 'total_flos': 9233126075596800.0, 'train_loss': 1.7195449806495948, 'epoch': 15.0})

In [23]:
# saving model
trainer.save_model(model_cktp_path)

In [27]:
# loading the model and run inference for it
ModelType = transformers.models.t5.modeling_t5.T5ForConditionalGeneration
model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_cktp_path)
model.eval()
model.config.use_cache = False

In [28]:
def de_toxification(model: ModelType, inference_request: str, tokenizer: ModelTokenizer) -> str:
    input_ids = tokenizer(inference_request, return_tensors="pt").input_ids
    outputs = model.generate(input_ids=input_ids)
    return tokenizer.decode(outputs[0], skip_special_tokens=True, temperature=0)

In [33]:
inference_request = 'Bob is stupid bastard!'
de_tox = de_toxification(model, inference_request, model_tokenizer)
print(de_tox)

Bob is a bad guy!
