<a href="https://colab.research.google.com/github/shitkov/courses/blob/master/transformers/transformers_shitkov_02_T5_baseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
# Download the train/dev/test TSV splits from the s-nlp/russe_detox_2022 repository.
!wget https://raw.githubusercontent.com/s-nlp/russe_detox_2022/main/data/input/train.tsv
!wget https://raw.githubusercontent.com/s-nlp/russe_detox_2022/main/data/input/dev.tsv
!wget https://raw.githubusercontent.com/s-nlp/russe_detox_2022/main/data/input/test.tsv

# Download the evaluation helper modules that the import cell pulls names from.
!wget https://raw.githubusercontent.com/s-nlp/russe_detox_2022/main/evaluation/ru_detoxification_evaluation.py
!wget https://raw.githubusercontent.com/s-nlp/russe_detox_2022/main/evaluation/ru_detoxification_metrics.py

# Install the modelling dependencies (versions are not pinned).
!pip install transformers sentencepiece

In [57]:
import gc

import numpy as np
import pandas as pd
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split

from typing import Tuple, List, Dict, Union

from tqdm.auto import tqdm, trange

import torch
from torch.utils.data import Dataset, DataLoader

from transformers import T5ForConditionalGeneration, AutoTokenizer
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel

from ru_detoxification_evaluation import load_model
from ru_detoxification_metrics import evaluate_style
from ru_detoxification_metrics import evaluate_cosine_similarity
from ru_detoxification_metrics import evaluate_cola_relative

# Data

In [3]:
# Read the training split (rows keyed by the file's 'index' column) and turn
# missing values into empty strings in a single chained expression.
df = pd.read_csv('train.tsv', sep='\t', index_col='index').fillna('')

In [4]:
# Flatten each row into (toxic, neutral) pairs: one pair per non-empty
# reference, stopping at the first empty reference column of the row.
df_train_toxic, df_train_neutral = [], []
reference_columns = ['neutral_comment1', 'neutral_comment2', 'neutral_comment3']

for _, row in df.iterrows():
    for column in reference_columns:
        reference = row[column]
        if len(reference) == 0:
            break
        df_train_toxic.append(row['toxic_comment'])
        df_train_neutral.append(reference)

In [5]:
# Assemble the flat (toxic, neutral) pair table from the two parallel lists.
df = pd.DataFrame({
    'toxic_comment': df_train_toxic,
    'neutral_comment': df_train_neutral,
})

# Shuffle with a fixed seed so that a fresh "Restart & Run All" produces the
# same row order (and therefore the same seeded train/validation split).
df = shuffle(df, random_state=42)

In [6]:
class PairsDataset(torch.utils.data.Dataset):
    """Dataset of tokenized source/target pairs.

    `x` and `y` are mappings from field name to a per-example sequence.
    `x` must contain 'input_ids'; `y` must contain 'input_ids' and
    'attention_mask'.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y

    @property
    def n(self):
        """Number of examples, measured on the source-side 'input_ids'."""
        return len(self.x['input_ids'])

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        """Return every source field for example `idx` plus the target-side
        attention mask ('decoder_attention_mask') and token ids ('labels')."""
        assert idx < self.n
        item = {}
        for field, column in self.x.items():
            item[field] = column[idx]
        target_mask = self.y['attention_mask'][idx]
        target_ids = self.y['input_ids'][idx]
        item['decoder_attention_mask'] = target_mask
        item['labels'] = target_ids
        return item

In [7]:
from typing import List, Dict, Union

class DataCollatorWithPadding:
    """Collate a list of per-example feature dicts into one dict of tensors.

    Padding is delegated to `tokenizer.pad`, called twice: once on the
    features as given, and once on the target side ('labels' and
    'decoder_attention_mask'), presented to the tokenizer under the
    'input_ids' / 'attention_mask' keys.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        padded = self.tokenizer.pad(features, padding=True)

        # Re-label the target fields so the tokenizer pads them as well.
        target_side = {
            'input_ids': padded['labels'],
            'attention_mask': padded['decoder_attention_mask'],
        }
        padded_targets = self.tokenizer.pad(target_side, padding=True)

        padded['labels'] = padded_targets['input_ids']
        padded['decoder_attention_mask'] = padded_targets['attention_mask']

        tensors = {}
        for name, values in padded.items():
            tensors[name] = torch.tensor(values)
        return tensors

# Utils

In [8]:
def cleanup():
    """Free memory: run Python's garbage collector, then ask PyTorch to empty its CUDA cache."""
    # Collect first so objects that are no longer referenced are gone before
    # the cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()

In [9]:
def evaluate_model(model, test_dataloader):
    """Return the example-weighted mean loss of `model` over `test_dataloader`.

    Args:
        model: callable that accepts a batch as keyword arguments and returns
            an object with a scalar tensor `.loss`; must expose `.device`.
        test_dataloader: iterable of dict batches whose values are tensors
            sharing the same first (example) dimension.

    Returns:
        The mean loss weighted by the number of examples in each batch, or
        NaN when the loader yields no batches.
    """
    weighted_loss_sum = 0.0
    n_examples = 0

    with torch.no_grad():
        for batch in test_dataloader:
            loss = model(**{k: v.to(model.device) for k, v in batch.items()}).loss
            # `len(batch)` would count the dict's keys, not its rows, so the
            # batch size is read from the first tensor instead.
            batch_size = len(next(iter(batch.values())))
            weighted_loss_sum += batch_size * loss.item()
            n_examples += batch_size

    if n_examples == 0:
        return float('nan')
    return weighted_loss_sum / n_examples

In [10]:
def train_loop(
    model, train_dataloader, val_dataloader,
    max_epochs=30,
    max_steps=1_000,
    lr=3e-5,
    gradient_accumulation_steps=1,
    cleanup_step=100,
    report_step=300,
    window=100,
    save_path=None,
    save_step=1000,
):
    """Fine-tune `model` in place with Adam on batches from `train_dataloader`.

    Args:
        model: module called with the batch as keyword arguments and returning
            an object with `.loss`; must expose `.device`, `.parameters()`,
            `.train()` and `.eval()` (and `.save_pretrained()` if `save_path`
            is given).
        train_dataloader: sized iterable of dict batches of tensors that
            include a 'labels' entry.
        val_dataloader: loader passed to `evaluate_model` for periodic
            validation, or None to skip validation.
        max_epochs: maximum number of passes over `train_dataloader`.
        max_steps: maximum number of optimizer steps; training stops once
            this count is reached.
        lr: Adam learning rate.
        gradient_accumulation_steps: number of batches whose gradients are
            accumulated before each optimizer step.
        cleanup_step: `cleanup()` is called every this many batches.
        report_step: validation runs every this many batches and on the last
            batch of each epoch.
        window: smoothing window of the running training loss shown in the
            progress bar.
        save_path: optional directory; when given, the model is saved there
            every `save_step` optimizer steps. None (default) disables
            in-loop checkpointing, so the function no longer depends on
            notebook-level globals to build a path.
        save_step: checkpoint period, in optimizer steps.

    Returns:
        None. `model` is updated in place.
    """
    cleanup()
    optimizer = torch.optim.Adam(params=[p for p in model.parameters() if p.requires_grad], lr=lr)

    ewm_loss = 0
    step = 0
    model.train()
    optimizer.zero_grad()

    for epoch in trange(max_epochs):
        if step >= max_steps:
            break
        tq = tqdm(train_dataloader)
        for i, batch in enumerate(tq):
            try:
                # Label id 0 is replaced by -100 (presumably the pad id being
                # mapped to the loss's ignore index -- TODO confirm).
                batch['labels'][batch['labels'] == 0] = -100
                loss = model(**{k: v.to(model.device) for k, v in batch.items()}).loss
                loss.backward()
            except Exception as e:
                # Best effort: report the failing batch, free memory, move on.
                print('error on step', i, e)
                loss = None  # drop the reference to the previous loss before cleanup
                cleanup()
                continue

            # Count batches from 1 so that the first batch can trigger a step
            # and every optimizer step covers exactly
            # `gradient_accumulation_steps` batches.
            if (i + 1) % gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                step += 1
                # Checkpoint once per `save_step` completed optimizer steps.
                if save_path is not None and step % save_step == 0:
                    model.save_pretrained(save_path)
                if step >= max_steps:
                    break

            if i % cleanup_step == 0:
                cleanup()

            # Plain running mean for the first `window` batches of an epoch,
            # exponential moving average with weight 1/window afterwards.
            w = 1 / min(i + 1, window)
            ewm_loss = ewm_loss * (1 - w) + loss.item() * w
            tq.set_description(f'loss: {ewm_loss:4.4f}')

            if (i and i % report_step == 0 or i == len(train_dataloader) - 1) and val_dataloader is not None:
                model.eval()
                eval_loss = evaluate_model(model, val_dataloader)
                model.train()
                print(f'epoch {epoch}, step {i}/{step}: train loss: {ewm_loss:4.4f}  val loss: {eval_loss:4.4f}')

    cleanup()

In [11]:
def train_model(x, y, model_name, test_size=0.1, batch_size=32, **kwargs):
    """Load a T5 checkpoint, fine-tune it on (x, y) text pairs and return it.

    Args:
        x: list of source texts.
        y: list of target texts, aligned with `x`.
        model_name: checkpoint name or path given to `from_pretrained` for
            both the model and the tokenizer.
        test_size: share of the pairs held out for validation.
        batch_size: batch size of both data loaders.
        **kwargs: forwarded unchanged to `train_loop`.

    Returns:
        The fine-tuned model.
    """
    # Use the GPU when one is present instead of failing on CPU-only runtimes.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    x1, x2, y1, y2 = train_test_split(x, y, test_size=test_size, random_state=42)
    train_dataset = PairsDataset(tokenizer(x1), tokenizer(y1))
    test_dataset = PairsDataset(tokenizer(x2), tokenizer(y2))

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, drop_last=False, shuffle=True, collate_fn=data_collator)
    # Validation needs no shuffling; a fixed order keeps batch composition
    # identical from one validation report to the next.
    val_dataloader = DataLoader(test_dataset, batch_size=batch_size, drop_last=False, shuffle=False, collate_fn=data_collator)

    train_loop(model, train_dataloader, val_dataloader, **kwargs)
    return model

# Model

In [12]:
# Checkpoint identifier handed to `train_model` as the model to fine-tune.
model_name = 'sberbank-ai/ruT5-base'

In [13]:
# Named training sets to iterate over; the key is used in log lines and in
# the saved checkpoint's directory name.
datasets = dict(train=df)

# Train

In [None]:
# Train one model per (step budget, dataset) combination and save each under
# a directory named after both.
step_budgets = [300, 1000, 3000, 10000]

for steps in step_budgets:
    for dname, d in datasets.items():
        print(f'\n\n\n  {dname}  {steps} \n=====================\n\n')
        toxic_texts = d['toxic_comment'].tolist()
        neutral_texts = d['neutral_comment'].tolist()
        model = train_model(
            toxic_texts,
            neutral_texts,
            model_name=model_name,
            batch_size=16,
            max_epochs=1000,
            max_steps=steps,
        )
        model.save_pretrained(f't5_base_{dname}_{steps}')

In [18]:
# Collect garbage and empty the CUDA cache once training has finished.
cleanup()

## Inference

In [15]:
# Directory of the fine-tuned checkpoint. The training cell saves to the
# relative path f't5_base_{dname}_{steps}', so the same relative path is used
# here instead of an absolute, Colab-only '/content/...' location.
model_name = 't5_base_train_10000'

In [21]:
# Identifier of the original pretrained checkpoint; used below to load the tokenizer.
base_model_name = 'sberbank-ai/ruT5-base'

In [22]:
# The tokenizer comes from the base checkpoint: the training cell only calls
# `model.save_pretrained`, so no tokenizer files are written next to the
# fine-tuned weights in the visible code.
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

In [19]:
# Load the fine-tuned weights from disk; `local_files_only=True` is passed to `from_pretrained`.
model = T5ForConditionalGeneration.from_pretrained(model_name, local_files_only=True)

In [None]:
# Move the model to the GPU; the trailing ';' keeps the cell from echoing the returned module.
model.cuda();

In [23]:
def paraphrase(text, model, n=None, max_length='auto', temperature=0.0, beams=3):
    """Rewrite one text or a list of texts with `model` using beam search.

    Relies on the notebook-level `tokenizer` for encoding and decoding.

    Args:
        text: a single string or a list of strings.
        model: generation model exposing `.device` and `.generate()`.
        n: number of sequences to return per input; None means one.
        max_length: maximum generated length, or 'auto' to derive it from the
            padded input length (1.2x + 10 tokens).
        temperature: forwarded to `generate`; decoding runs with
            `do_sample=False`.
        beams: number of beams.

    Returns:
        A single string when `text` is a string and `n` is falsy; otherwise a
        list of decoded strings.
    """
    texts = [text] if isinstance(text, str) else text
    encoded = tokenizer(texts, return_tensors='pt', padding=True)
    inputs = encoded['input_ids'].to(model.device)
    # Hand the padding mask over explicitly rather than leaving it to
    # `generate` to infer which positions are padding.
    attention_mask = encoded['attention_mask'].to(model.device)
    if max_length == 'auto':
        max_length = int(inputs.shape[1] * 1.2) + 10
    result = model.generate(
        inputs,
        attention_mask=attention_mask,
        num_return_sequences=n or 1,
        do_sample=False,
        temperature=temperature,
        repetition_penalty=3.0,
        max_length=max_length,
        bad_words_ids=[[2]],  # unk
        num_beams=beams,
    )
    texts = [tokenizer.decode(r, skip_special_tokens=True) for r in result]
    if not n and isinstance(text, str):
        return texts[0]
    return texts

In [25]:
# Load the dev split. NOTE: this rebinds `df`, which until now held the
# shuffled training pairs.
df = pd.read_csv('dev.tsv', sep='\t')
# Source sentences to detoxify, in file order.
toxic_inputs = df['toxic_comment'].tolist()

In [59]:
# For every dev row, gather its three reference columns into one list.
reference_columns = ['neutral_comment1', 'neutral_comment2', 'neutral_comment3']
neutral_references = []
for index, row in df.iterrows():
    neutral_references.append([row[column] for column in reference_columns])

In [None]:
para_results = []
problematic_batch = []  # start offsets of the batches whose generation failed
batch_size = 8

for i in tqdm(range(0, len(toxic_inputs), batch_size)):
    batch = toxic_inputs[i:i + batch_size]
    try:
        para_results.extend(paraphrase(batch, model, temperature=0.0))
    except Exception as e:
        # Best effort: report and record the failure, then fall back to the
        # untouched inputs so the output stays aligned with `toxic_inputs`.
        print(i, e)
        problematic_batch.append(i)
        # `extend`, not `append`: appending would nest a list inside
        # `para_results` and break the later per-sentence string handling.
        para_results.extend(batch)

In [27]:
# Store the generated sentences, one per line.
# NOTE(review): the file name says 3000 while the checkpoint loaded above is
# the 10000-step one -- confirm which is intended.
with open('t5_base_3000_dev.txt', 'w') as file:
    output_lines = []
    for sentence in para_results:
        output_lines.append(sentence + '\n')
    file.write(''.join(output_lines))

In [28]:
# Collect garbage and empty the CUDA cache after generation.
cleanup()

## Evaluate

In [33]:
# Run the evaluation models on the GPU only when one is actually present, so
# the evaluation section also works on CPU-only runtimes.
use_cuda = torch.cuda.is_available()

In [39]:
# Read the stored predictions back, one per line, without trailing whitespace.
with open('t5_base_3000_dev.txt', 'r') as file:
    preds = list(map(str.rstrip, file))

### Style Transfer Accuracy (STA)

In [34]:
# Load the toxicity classifier and its tokenizer used for the STA metric.
style_model, style_tokenizer = load_model('SkolkovoInstitute/russian_toxicity_classifier', use_cuda=use_cuda)

Downloading:   0%|          | 0.00/1.04k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/712M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/585 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.40M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [41]:
# Score the predictions with the toxicity classifier, targeting the neutral label.
accuracy = evaluate_style(
    model = style_model,
    tokenizer = style_tokenizer,
    texts = preds,
    target_label=0,  # 1 is toxic, 0 is neutral
    batch_size=32, 
    verbose=True
)

  0%|          | 0/25 [00:00<?, ?it/s]

In [43]:
# Report the mean of the values returned by `evaluate_style`.
print(f'Style transfer accuracy (STA):  {np.mean(accuracy)}')

Style transfer accuracy (STA):  0.6835970282554626


In [45]:
# Collect garbage and empty the CUDA cache before the next evaluation model is loaded.
cleanup()

### Meaning Preservation Score (SIM)

In [46]:
# Load the LaBSE encoder as a plain `AutoModel`, plus its tokenizer, for the SIM metric.
meaning_model, meaning_tokenizer = load_model('cointegrated/LaBSE-en-ru', use_cuda=use_cuda, model_class=AutoModel)

Downloading:   0%|          | 0.00/806 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/516M [00:00<?, ?B/s]

Some weights of the model checkpoint at cointegrated/LaBSE-en-ru were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/521k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [47]:
# Compare each original toxic sentence with its rewrite; the two lists are
# expected to be aligned by position (TODO confirm against the helper).
similarity = evaluate_cosine_similarity(
    model = meaning_model,
    tokenizer = meaning_tokenizer,
    original_texts = toxic_inputs,
    rewritten_texts = preds,
    batch_size=32,
    verbose=True,
    )

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

In [48]:
# Report the mean of the values returned by `evaluate_cosine_similarity`.
print(f'Meaning preservation (SIM):  {np.mean(similarity)}')

Meaning preservation (SIM):  0.799630880355835


In [50]:
# Collect garbage and empty the CUDA cache before the next evaluation model is loaded.
cleanup()

### Fluency score (FL)

In [51]:
# Load the corruption detector and its tokenizer used for the fluency metric.
# NOTE(review): the tokenizer variable is spelled `cola_tolenizer`; the
# `evaluate_cola_relative` call reads it under the same spelling, so any
# rename has to change both places together.
cola_model, cola_tolenizer = load_model('SkolkovoInstitute/rubert-base-corruption-detector', use_cuda=use_cuda)

Downloading:   0%|          | 0.00/1.03k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/712M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/508 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.65M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.62M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [52]:
# Score the rewrites against their originals with the corruption detector.
# The meaning of target_label=1 is defined by the helper -- TODO confirm.
fluency = evaluate_cola_relative(
    model = cola_model,
    tokenizer = cola_tolenizer,
    original_texts = toxic_inputs,
    rewritten_texts = preds,
    target_label=1,
    batch_size=32,
    verbose=True
)

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

In [53]:
# Report the mean of the values returned by `evaluate_cola_relative`.
print(f'Fluency score (FL):  {np.mean(fluency)}')

Fluency score (FL):  0.7631889581680298


In [54]:
# Collect garbage and empty the CUDA cache after the last evaluation model.
cleanup()

### Joint score (J)

In [55]:
# Product of the three metric results (element-wise if the helpers return
# per-sentence arrays -- TODO confirm their return types).
joint = accuracy * similarity * fluency

In [56]:
# Report the mean of the joint score.
print(f'Joint score (J):   {np.mean(joint)}')

Joint score (J):   0.414192795753479


In [63]:
# Sanity check: show every prediction whose type is not exactly `str`.
for p in preds:
    if type(p) == str:
        continue
    print(p)