In [1]:
# To run on CUDA, put the ID of the device you want to use here.
# This is set ahead of the torch import in the following cell.
import os
os.environ.update({'CUDA_VISIBLE_DEVICES': '5'})

In [2]:
import pandas as pd
from transformers import T5ForConditionalGeneration, AutoTokenizer
import torch
from tqdm.auto import tqdm, trange
import gc

def cleanup():
    """
    Release memory that is still held by batches which are no longer needed.

    Two steps are run in order: Python's garbage collector first, then
    ``torch.cuda.empty_cache()``. Nothing is returned.
    """
    release_steps = (gc.collect, torch.cuda.empty_cache)
    for release in release_steps:
        release()

Reading the inputs

In [3]:
# Read the tab-separated dev.tsv file and keep the `toxic_comment` column as a plain list.
df = pd.read_csv('../../data/input/dev.tsv', delimiter='\t')
toxic_inputs = df['toxic_comment'].to_list()

Loading the model. For the baseline we use the tokenizer of `sberbank-ai/ruT5-base` (Sberbank-AI) and the weights of `SkolkovoInstitute/ruT5-base-detox`.

In [7]:
# Identifiers handed to `from_pretrained` in the next cell:
# `base_model_name` is used for the tokenizer, `model_name` for the model weights.
base_model_name, model_name = (
    'sberbank-ai/ruT5-base',
    'SkolkovoInstitute/ruT5-base-detox',
)

In [8]:
# The model weights come from `model_name`; the tokenizer is loaded separately
# from `base_model_name`. The two loads are independent of each other.
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

In [9]:
# Move the model to the GPU only when CUDA is usable; otherwise leave it on the CPU
# so the notebook still runs (paraphrase() sends its inputs to model.device either way).
if torch.cuda.is_available():
    model.cuda()
else:
    print('CUDA is not available, the model stays on CPU')

Preparing the paraphrasing function, with a small example

In [10]:
def paraphrase(text, model, n=None, max_length='auto', temperature=0.0, beams=3):
    """
    Rewrite one text, or a batch of texts, with ``model.generate`` and beam search.

    Encoding and decoding use the module-level ``tokenizer``.

    Args:
        text: a single string, or an iterable of strings processed as one padded batch.
        model: the generation model; the encoded inputs are moved to ``model.device``.
        n: number of sequences to return per input; a falsy value means 1.
        max_length: maximum length passed to ``generate``, or ``'auto'`` to derive it
            from the padded input length as ``int(length * 1.2) + 10``.
        temperature: forwarded to ``generate`` unchanged.
            NOTE(review): sampling is switched off (``do_sample=False``), so this
            value presumably has no effect on the result - confirm against the
            transformers generation documentation.
        beams: number of beams; raised to ``n`` when ``n`` is larger, since beam
            search cannot return more sequences than it has beams.

    Returns:
        A single string when ``text`` is a string and ``n`` is falsy; otherwise a flat
        list of decoded strings (``n`` consecutive entries per input when ``n`` is set).
        An empty input yields an empty list without calling the tokenizer or the model.
    """
    texts = [text] if isinstance(text, str) else list(text)
    if not texts:
        # Nothing to encode: an empty batch has no sequence dimension to measure.
        return []

    encoded = tokenizer(texts, return_tensors='pt', padding=True)
    input_ids = encoded['input_ids'].to(model.device)
    # Pass the padding mask explicitly instead of leaving generate() to infer it.
    attention_mask = encoded['attention_mask'].to(model.device)

    if max_length == 'auto':
        max_length = int(input_ids.shape[1] * 1.2) + 10

    num_sequences = n or 1
    result = model.generate(
        input_ids,
        attention_mask=attention_mask,
        num_return_sequences=num_sequences,
        do_sample=False,
        temperature=temperature,
        repetition_penalty=3.0,
        max_length=max_length,
        bad_words_ids=[[2]],  # unk
        num_beams=max(beams, num_sequences),
    )
    decoded = [tokenizer.decode(r, skip_special_tokens=True) for r in result]
    if not n and isinstance(text, str):
        return decoded[0]
    return decoded

In [11]:
# Quick check on a single sentence; it is passed as a list, so a list is printed.
print(
    paraphrase(
        text=['Дмитрий вы ебанулись, уже все выложено'],
        model=model,
        temperature=50.0,
        beams=10,
    )
)

['Дмитрий вы с ума сошли, уже все выложено']


The inference

In [28]:
para_results = []
problematic_batch = []  # start offsets of the batches that failed, for later inspection
batch_size = 8

for i in tqdm(range(0, len(toxic_inputs), batch_size)):
    batch = toxic_inputs[i:i + batch_size]
    try:
        para_results.extend(paraphrase(batch, model, temperature=0.0))
    except Exception as e:
        # Best effort: report the failure, remember where it happened, and fall back
        # to the unchanged inputs for this batch.
        print(f'batch starting at {i} failed: {e!r}')
        problematic_batch.append(i)
        # extend(), not append(): para_results must stay a flat list of strings,
        # one per input, because the saving cell concatenates '\n' to every entry.
        para_results.extend(batch)

# Every input must have exactly one output line, whether paraphrased or copied.
assert len(para_results) == len(toxic_inputs)

HBox(children=(IntProgress(value=0), HTML(value='')))




Saving the results

In [29]:
# Write one result per line. The encoding is set explicitly because the sentences are
# Russian text and open() would otherwise use a platform-dependent default encoding.
with open('../../data/output/t5_base_10000_dev.txt', 'w', encoding='utf-8') as file:
    file.writelines(sentence + '\n' for sentence in para_results)