In [2]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '7'

In [3]:
import sys

In [4]:
sys.path.append(os.path.abspath('..'))

# 1. Translate the English training set into Russian

In [6]:
import pandas as pd
import numpy as np
from tqdm.auto import tqdm, trange

In [9]:
parts = []
for split in ['train', 'dev', 'test']:
    df = pd.read_csv(f'../data/russian_data/{split}.tsv', sep='\t')
    df['split'] = split
    parts.append(df)

In [10]:
detox_ru = pd.concat(parts, ignore_index=True)

In [11]:
detox_ru

Unnamed: 0,idx,toxic_comment,neutral_comment,toxicity_score,dataset,toxic,confidence_toxic,is_match,confidence_is_match,split
0,6142,"и,чё,блядь где этот херой был до этого со свои...","Ну и где этот герой был,со своими доказательст...",0.999007,ru,False,0.9999,True,0.9985,train
1,9210,ебанько из какого ебонария тебя выпустили???😂😂😂,Откуда ты взялся такой?,0.999326,ru,False,0.9999,True,0.9972,train
2,8692,"херну всякую пишут,из-за этого лайка.долбоебизм.","Чушь всякую пишут, из- за этого лайка.",0.998891,ru,False,0.9999,True,0.9964,train
3,10814,в гсвг за такие сапоги пиздюлей получил бы от ...,В ГСВГ за такие сапоги наказали бы сослуживцы,0.998883,ru,False,0.9999,True,0.9950,train
4,7915,какой дебил эту придумывает 🤣,Кто это придумывает?,0.921756,ru,False,0.9999,True,0.9950,train
...,...,...,...,...,...,...,...,...,...,...
7053,6522,"дебил,за решотку тебя и лишить всего!",Открыть бы на тебя судебное дело,0.995274,ru_test,False,0.9987,True,0.9777,test
7054,17248,попытка вы*бнуться? не вы одни не пьете. ждете...,Почему вы хвастаетесь? не вы одни не употребля...,0.998892,ru,False,0.9997,True,0.9800,test
7055,11330,вы нытики заколебали.едь в чехию.скажешь за до...,"Езжайте в Чехию и потом будете говорить , про ...",0.947868,ru,False,0.9999,True,0.9166,test
7056,67,в глазах чужого государства они конечно красав...,В глазах другого государства они конечно краса...,0.994028,ru,False,0.9341,True,0.9727,test


In [12]:
detox_ru.split.value_counts()

train    5058
dev      1000
test     1000
Name: split, dtype: int64

In [13]:
pd.options.display.max_colwidth = 200

In [14]:
detox_ru.shape

(7058, 10)

In [15]:
import nltk
nltk.download('punkt')
from nltk import sent_tokenize

[nltk_data] Downloading package punkt to /home/dale/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [16]:
import re

def split_by_symbol(text, symbol=',', max_len=400):
    if len(text) <= max_len:
        return [text]
    chunks = re.split(symbol, text)
    if len(chunks) <= 1:
        return [text]
    result = [chunks[0]]
    for chunk in chunks[1:]:
        result.append(symbol)
        result.append(chunk)
    return result

def join_texts(texts, max_len=400):
    result = []
    prev_text = ''
    for text in texts:
        if len(text) + len(prev_text) > max_len:
            result.append(prev_text)
            prev_text = text
        else:
            prev_text = prev_text + text
    result.append(prev_text)
    return result

def hard_split(text, max_len=300):
    parts = list(sent_tokenize(text))
    result = []
    for part in parts:
        chunks = [part]
        for symbol in [',', '-', ' ']:
            chunks = [c2 for c in chunks for c2 in split_by_symbol(c, symbol, max_len=max_len)]
        result.extend(chunks)
    result = join_texts(result, max_len=max_len)
    return result

In [23]:
'''
How to obtain a fresh SID:
* go to translate.yandex.ru
* open the "network" panel of the developers console
* enter any text in the translation form
* find the request to "https://translate.yandex.net/api/v1/tr.json/translate" and copy its first parameter ("id")
'''

import requests

SID = 'd893eb46.629f444a.4b7f3984.74722d74657874-1-0'

def translate_yandex(search_str, direction='en-ru', full_response=False):
    try:
        url = f'https://translate.yandex.net/api/v1/tr.json/translate?id={SID}&srv=tr-text&lang={direction}&reason=auto&format=text'

        post_header = {}
        post_header['Accept'] = '*/*'
        post_header['Accept-Encoding'] = 'gzip, deflate'
        post_header['Accept-Language'] = 'en-US,en;q=0.9'
        post_header['Cache-Control'] = 'no-cache'
        post_header['Connection'] = 'keep-alive'
        post_header['Content-Type'] = 'application/x-www-form-urlencoded'
        post_header['Host'] = 'translate.yandex.com'
        post_header['Referer'] = 'https://translate.yandex.com/'
        post_header['User-Agent'] = 'Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.159 YaBrowser/21.8.2.383 Yowser/2.5 Safari/537.36'


        data_payload = {'text': search_str, 'options': '4'}

        resp = requests.post(url, headers=post_header, data=data_payload).json()
        if full_response:
            return resp

        if resp.get('message') == 'The text size exceeds the maximum':
            parts = hard_split(search_str)
            if len(parts) > 1:
                return 200, ' '.join([translate_yandex(part, dir=dir)[1] for part in parts])

        return resp['code'], resp['text'][0]
    except Exception as e:
        print(resp)
        return 0, ''

In [24]:
translate_yandex('go to hell')

(200, 'иди к черту')

In [25]:
detox_ru.toxic_comment[:5].tolist()

['и,чё,блядь где этот херой был до этого со своими доказательствами?',
 'ебанько из какого ебонария тебя выпустили???😂😂😂',
 'херну всякую пишут,из-за этого лайка.долбоебизм.',
 'в гсвг за такие сапоги пиздюлей получил бы от сослуживцев',
 'какой дебил эту придумывает 🤣']

In [26]:
[translate_yandex(t, 'ru-en')[1] for t in detox_ru.toxic_comment[:5]]

['and, what, where the fuck was this dick before with his evidence?',
 'what fucking hospital did they let you out of???😂😂😂',
 'they write all sorts of bullshit, because of this like.motherfucking.',
 'in gsvg , I would have received a pussy from my colleagues for such boots',
 'what kind of moron comes up with this one 🤣']

In [27]:
[translate_yandex(t, 'ru-en')[1] for t in detox_ru.neutral_comment[:5]]

['Well, where was this hero, with his evidence?',
 'Where did you come from?',
 'They write all sorts of nonsense, because of this like.',
 'In the GSVG, colleagues would be punished for such boots',
 "Who's coming up with this?"]

In [29]:
toxic_en = []
neutral_en = []
for i in trange(0, len(detox_ru)):
    toxic_en.append(translate_yandex(detox_ru.toxic_comment[i], 'ru-en')[1])
    neutral_en.append(translate_yandex(detox_ru.neutral_comment[i], 'ru-en')[1])

  0%|          | 0/7058 [00:00<?, ?it/s]

In [30]:
detox_ru['toxic_en'] = toxic_en
detox_ru['neutral_en'] = neutral_en

In [31]:
detox_ru.sample(5)

Unnamed: 0,idx,toxic_comment,neutral_comment,toxicity_score,dataset,toxic,confidence_toxic,is_match,confidence_is_match,split,toxic_en,neutral_en
4481,16602,дима петровский пидор ты гнойный,дима петровский не хороший человек,0.999302,ru,False,0.9999,True,0.9915,train,Dima Petrovsky You're a purulent faggot,Dima Petrovsky is not a good person
1373,7266,уроды....он один из лучших полководцев в истории...жаль,Жалко одного из лучших полководцев в истории,0.986603,ru_test,False,0.9996,True,0.9379,train,freaks....he is one of the best generals in history...sorry,I feel sorry for one of the best generals in history
1239,8643,зачем с таким жить тварь. расстрелять,Зачем с таким жить?,0.99934,ru,False,0.9955,True,0.9069,train,why live with such a creature. shoot,Why live with this?
6975,16053,"ни хуя себе чучело, а где главный?????",А где главный?,0.948019,ru,False,0.9999,True,0.96,test,"not a fucking scarecrow, but where is the main one?????",And where is the main one?
6631,12592,о долбоёб проснулся иди на укол тебе чекають,о проснулся иди на укол тебе чекають,0.999194,ru,False,0.9911,True,0.9794,test,"oh, the fucker woke up, go for an injection to check you",oh woke up go for a shot you check out


In [32]:
detox_ru.to_csv('detox_ru2en_yandex.tsv', index=None, sep='\t')

# Start from here

In [34]:
(detox_ru.toxic_en == detox_ru.neutral_en).mean()

0.006942476622272598

In [35]:
from textdistance import levenshtein

In [37]:
detox_ru['edit_distance_ru'] = [levenshtein.distance(*row) for row in detox_ru[['toxic_comment', 'neutral_comment']].values]
detox_ru['edit_distance_en'] = [levenshtein.distance(*row) for row in detox_ru[['toxic_en', 'neutral_en']].values]

detox_ru['edit_sim_ru'] = [levenshtein.normalized_similarity(*row) for row in detox_ru[['toxic_comment', 'neutral_comment']].values]
detox_ru['edit_sim_en'] = [levenshtein.normalized_similarity(*row) for row in detox_ru[['toxic_en', 'neutral_en']].values]

In [38]:
detox_ru.describe()

Unnamed: 0,idx,toxicity_score,confidence_toxic,confidence_is_match,edit_distance_ru,edit_distance_en,edit_sim_ru,edit_sim_en
count,7058.0,7058.0,7058.0,7058.0,7058.0,7058.0,7058.0,7058.0
mean,8509.176821,0.984075,0.984749,0.972116,22.059365,29.011193,0.639245,0.599037
std,6270.499049,0.030714,0.023933,0.024514,15.346834,18.704621,0.202165,0.202235
min,1.0,0.800721,0.9003,0.9,0.0,0.0,0.030303,0.020408
25%,3351.25,0.98622,0.9799,0.9601,11.0,15.0,0.510204,0.448276
50%,6819.0,0.996124,0.9972,0.9811,18.0,25.0,0.677966,0.62069
75%,13514.0,0.999306,0.9999,0.9906,29.0,38.0,0.8,0.761905
max,21360.0,0.999356,0.9999,0.9999,132.0,172.0,1.0,1.0


In [40]:
detox_ru2en = detox_ru

In [41]:
detox_ru2en.to_csv('detox_ru2en_yandex.tsv', sep='\t', index=None)

# 2. Train the Russian model 

In [4]:
import pandas as pd

In [42]:
detox_ru2en = pd.read_csv('detox_ru2en_yandex.tsv', sep='\t')

In [43]:
from datasets import Dataset, DatasetDict

In [44]:
from sklearn.model_selection import train_test_split

Filter text pairs by similarity to escape translation artifacts

In [48]:
detox_ru2en_filtered = detox_ru2en[
    ((detox_ru2en.edit_distance_en >= detox_ru2en.edit_distance_ru.quantile(0.01)) 
    & (detox_ru2en.edit_distance_en <= detox_ru2en.edit_distance_ru.quantile(0.99)) 
    & (detox_ru2en.edit_sim_en >= detox_ru2en.edit_sim_ru.quantile(0.01)) 
    & (detox_ru2en.edit_sim_en <= detox_ru2en.edit_sim_ru.quantile(0.99))
    )
]

print(detox_ru2en.shape)
print(detox_ru2en_filtered.shape)

(7058, 16)
(6739, 16)


In [54]:
train = detox_ru2en_filtered[detox_ru2en_filtered.split=='train']
val = detox_ru2en_filtered[detox_ru2en_filtered.split=='dev']
test = detox_ru2en_filtered[detox_ru2en_filtered.split=='test']

In [55]:
raw_data = DatasetDict({
    'train': Dataset.from_dict({'text': train.toxic_en, 'target': train.neutral_en}),
    'dev': Dataset.from_dict({'text': val.toxic_en, 'target': val.neutral_en}),
})
raw_data

DatasetDict({
    train: Dataset({
        features: ['text', 'target'],
        num_rows: 4825
    })
    dev: Dataset({
        features: ['text', 'target'],
        num_rows: 961
    })
})

### Compute and evaluate baselines

In [51]:
from sacrebleu import CHRF
chrfpp = CHRF(word_order=2)

Baseline chrf++: 60% if not change the texts. 

In [57]:
chrfpp.corpus_score(val.toxic_en.tolist(), [val.neutral_en.tolist()]).score

64.26783177670117

A baseline that removes bad words

In [58]:
from nltk import wordpunct_tokenize
from collections import Counter

def detokenize(text):
    for symbol in ",.?!'":
        text = text.replace(' ' + symbol, symbol)
    return text

class Remover:
    def __init__(self, ratio_threshold=2):
        self.ratio_threshold = ratio_threshold
    def fit(self, x, y):
        self.x_count = Counter(w.lower() for text in x for w in wordpunct_tokenize(text))
        self.y_count = Counter(w.lower() for text in y for w in wordpunct_tokenize(text))
    def predict(self, x):
        results = []
        for text in x:
            words = []
            for w in wordpunct_tokenize(text):
                key = w.lower()
                if (self.x_count[key] + 1) / (self.y_count[key] + 1) > self.ratio_threshold:
                    continue
                words.append(w)
            results.append(detokenize(' '.join(words)))
        return results

In [60]:
remover = Remover(2.0)
remover.fit(train.toxic_en, train.neutral_en)

In [61]:
chrfpp.corpus_score(remover.predict(val.toxic_en), [val.neutral_en.tolist()]).score

64.26938732564818

A simple word-based translation

In [62]:
from collections import deque
from itertools import product
from tqdm.auto import tqdm

# https://johnlekberg.com/blog/2020-10-25-seq-align.html


def needleman_wunsch(x, y, sim=None, verbose=False):
    """Run the Needleman-Wunsch algorithm on two sequences.

    x, y -- sequences.

    Code based on pseudocode in Section 3 of:

    Naveed, Tahir; Siddiqui, Imitaz Saeed; Ahmed, Shaftab.
    "Parallel Needleman-Wunsch Algorithm for Grid." n.d.
    https://upload.wikimedia.org/wikipedia/en/c/c4/ParallelNeedlemanAlgorithm.pdf
    """
    N, M = len(x), len(y)
    if sim is None:
        s = lambda a, b: int(a == b)
    else:
        s = sim

    DIAG = -1, -1
    LEFT = -1, 0
    UP = 0, -1

    # Create tables F and Ptr
    F = {}
    Ptr = {}

    F[-1, -1] = 0
    for i in range(N):
        F[i, -1] = -i
    for j in range(M):
        F[-1, j] = -j

    option_Ptr = DIAG, LEFT, UP
    for i, j in product(range(N), range(M)):
        option_F = (
            F[i - 1, j - 1] + s(x[i], y[j]),
            F[i - 1, j] - 1,
            F[i, j - 1] - 1,
        )
        F[i, j], Ptr[i, j] = max(zip(option_F, option_Ptr))

    # Work backwards from (N - 1, M - 1) to (0, 0)
    # to find the best alignment.
    alignment = deque()
    i, j = N - 1, M - 1
    if verbose:
        tq = tqdm(total=max(N, M))
        
    while i >= 0 and j >= 0:
        direction = Ptr[i, j]
        if direction == DIAG:
            element = i, j
        elif direction == LEFT:
            element = i, None
        elif direction == UP:
            element = None, j
        alignment.appendleft(element)
        di, dj = direction
        i, j = i + di, j + dj
    while i >= 0:
        alignment.appendleft((i, None))
        i -= 1
    while j >= 0:
        alignment.appendleft((None, j))
        j -= 1

    return list(alignment)

In [63]:
from collections import Counter, defaultdict
import nltk
from tqdm.auto import tqdm, trange
import numpy as np


class TextReplacer:
    def __init__(self, max_n=3, smooth_n=10, min_n=10, min_p=0.01):
        self.max_n = max_n
        self.smooth_n = smooth_n
        self.min_n = min_n
        self.min_p = min_p
        
        self.replace_proba = {}
        self.replaced_tuples = {}
        
    def tokenize(self, text):
        return nltk.wordpunct_tokenize('_bos_ ' + text + ' _eos_')
    
    def detokenize(self, text):
        text = text.strip()
        for symbol in '.,?!':
            text = text.replace(' ' + symbol, symbol)
        if text.startswith('_bos_'):
            text = text[5:]
        if text.endswith('_eos_'):
            text = text[:-5]
        return text.strip()
    
    def fit(self, x_train, y_train):
        raw_counts = Counter()
        replace_counts = Counter()
        
        for i in trange(len(x_train)):
            xx, yy = x_train[i], y_train[i]
            xx, yy = self.tokenize(xx), self.tokenize(yy)
            alignment = needleman_wunsch(xx, yy)
            ixx, iyy = list(zip(*alignment))
            for gram_size in range(1, self.max_n + 1):
                for start in range(len(ixx) - gram_size + 1):
                    xgram = tuple([xx[c] for c in ixx[start: start + gram_size] if c is not None])
                    ygram = tuple([yy[c] for c in iyy[start: start + gram_size] if c is not None])
                    if xgram:
                        xg, yg = ' '.join([''] + list(xgram) + ['']), ' '.join([''] + list(ygram) + [''])
                        raw_counts[xg] += 1
                        if xgram != ygram:
                            replace_counts[(xg, yg)] += 1
    
        self.replace_proba = defaultdict(list)
        self.replaced_tuples = dict()

        for pair, n_sub in replace_counts.most_common():
            if n_sub >= self.min_n:
                xx, yy = pair
                pr = n_sub / (self.smooth_n + raw_counts[xx])
                if pr >= self.min_p:
                    self.replace_proba[xx].append([yy, pr])
                    self.replaced_tuples[tuple(xx.strip().split())] = raw_counts[xx]

        for k, v in self.replace_proba.items():
            tot = sum(p for r, p in v)
            if tot < 1:
                v.append([k, 1 - tot])
        
        return self

    def transform_one(self, text, n_out=None, temperature=None):
        xx = self.tokenize(text)
        found_grams = []
        for gram_size in range(1, self.max_n + 1):
            for start in range(len(xx) - gram_size + 1):
                xgram = tuple([c for c in xx[start: start + gram_size] if c is not None])
                if xgram and xgram in self.replaced_tuples:
                    found_grams.append((xgram, self.replaced_tuples[xgram], len(xgram)))
        found_grams = sorted(found_grams, key=lambda x: (x[2], x[1]), reverse=True)
        
        results = []
        for i in range(n_out or 1):
            untext = ' '.join([''] + xx + [''])
            for gram, gn, gl in found_grams:
                gram_text = ' '.join([''] + list(gram) + [''])
                reps, ww = zip(*self.replace_proba[gram_text])
                if not temperature:
                    chosen_rep = list(reps)[np.argmax(ww)]
                else: # chose randomly
                    weights = [w ** (1 / temperature) for w in ww]
                    chosen_rep = random.choices(list(reps), weights=weights)[0]
                untext = untext.replace(gram_text, chosen_rep)
            results.append(self.detokenize(untext))
        if not n_out:
            return results[0]
        return results
    
    def transform(self, texts, n_out=None, temperature=None):
        return [self.transform_one(text, n_out=n_out, temperature=temperature) for text in tqdm(texts)]

In [64]:
replacer = TextReplacer(min_n=3)
replacer.fit(train.toxic_en.tolist(), train.neutral_en.tolist())

  0%|          | 0/4825 [00:00<?, ?it/s]

<__main__.TextReplacer at 0x7f2acd0e8c10>

In [65]:
chrfpp.corpus_score(replacer.transform(val.toxic_en), [val.neutral_en.tolist()]).score

  0%|          | 0/961 [00:00<?, ?it/s]

64.33395440698264

### Prepare test methods

In [66]:
def paraphrase(
    text, model, tokenizer, n=None, max_length="auto", beams=5,
):
    texts = [text] if isinstance(text, str) else text
    inputs = tokenizer(texts, return_tensors="pt", padding=True)["input_ids"].to(
        model.device
    )

    if max_length == "auto":
        max_length = inputs.shape[1] + 10

    result = model.generate(
        inputs,
        num_return_sequences=n or 1,
        do_sample=False,
        temperature=1.0,
        repetition_penalty=10.0,
        max_length=max_length,
        min_length=int(0.5 * max_length),
        num_beams=beams,
        #forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang],
    )
    texts = [tokenizer.decode(r, skip_special_tokens=True) for r in result]

    if not n and isinstance(text, str):
        return texts[0]
    return texts[0]

In [68]:
with open('../data/english_data/test_toxic_parallel.txt', 'r') as f:
    test_inputs = [line.strip() for line in f.readlines()]
    
test_inputs[:5]

['. or the loud ass one - thousand ton beast roaring towards you howling its horn .',
 'mandated  and " right fucking now " would be good .',
 '* neither * of my coworkers gave a shit when it came time to ditch mitch . ugh .',
 '* well shit , cunt shot himself .',
 "&gt i wouldn 't care how ignorant you are in you weren 't pretending to know shit ."]

## Fine-tune mBART

In [69]:
base_model = 'facebook/mbart-large-50'

In [70]:
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

In [71]:
model = AutoModelForSeq2SeqLM.from_pretrained(base_model)# .cuda()
tokenizer = AutoTokenizer.from_pretrained(base_model)

In [72]:
prefix = ""

def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["text"]]
    model_inputs = tokenizer(inputs, padding=True)
    labels = tokenizer(examples["target"], padding=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [73]:
tok_data = raw_data.map(preprocess_function, batched=True)

  0%|          | 0/5 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [74]:
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [75]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

training_args = Seq2SeqTrainingArguments(
    output_dir="/home/dale/models/detox-parallel/translate-ru2en-full-mbart",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=1, # 8 is too much 
    weight_decay=1e-5,
    max_steps=10_000,
    learning_rate=1e-5,
    evaluation_strategy="steps",
    save_strategy="steps",
    save_total_limit=1,
    eval_steps=500, 
    save_steps=500,
    logging_steps=500,
    load_best_model_at_end=True,
    # trying to save memory: see https://huggingface.co/docs/transformers/performance
    fp16=True,
    gradient_checkpointing=True,
    optim="adafactor",
    gradient_accumulation_steps=1,
)

In [76]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tok_data["train"],
    eval_dataset=tok_data["dev"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

max_steps is given, it will override any value given in num_train_epochs
Using amp half precision backend


In [77]:
trainer.train()

The following columns in the training set don't have a corresponding argument in `MBartForConditionalGeneration.forward` and have been ignored: target, text. If target, text are not expected by `MBartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 4825
  Num Epochs = 17
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 10000
  args.max_grad_norm,


Step,Training Loss,Validation Loss
500,2.5715,0.384305
1000,0.4143,0.348551
1500,0.3502,0.350092
2000,0.2878,0.378403
2500,0.2463,0.376747
3000,0.2008,0.374015
3500,0.1613,0.405448
4000,0.1382,0.4394
4500,0.1153,0.459208
5000,0.0972,0.493408


The following columns in the evaluation set don't have a corresponding argument in `MBartForConditionalGeneration.forward` and have been ignored: target, text. If target, text are not expected by `MBartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 961
  Batch size = 1
Saving model checkpoint to /home/dale/models/detox-parallel/translate-ru2en-full-mbart/checkpoint-500
Configuration saved in /home/dale/models/detox-parallel/translate-ru2en-full-mbart/checkpoint-500/config.json
Model weights saved in /home/dale/models/detox-parallel/translate-ru2en-full-mbart/checkpoint-500/pytorch_model.bin
tokenizer config file saved in /home/dale/models/detox-parallel/translate-ru2en-full-mbart/checkpoint-500/tokenizer_config.json
Special tokens file saved in /home/dale/models/detox-parallel/translate-ru2en-full-mbart/checkpoint-500/special_tokens_map.json
The following columns in the evaluation set don't have a corresponding ar

TrainOutput(global_step=10000, training_loss=0.2542056949615479, metrics={'train_runtime': 4164.9986, 'train_samples_per_second': 19.208, 'train_steps_per_second': 2.401, 'total_flos': 1.0560774671204352e+16, 'train_loss': 0.2542056949615479, 'epoch': 16.56})

In [51]:
from tqdm.auto import tqdm, trange

In [78]:
preds = []
model.eval()
for text in tqdm(val.toxic_en):
    with torch.inference_mode():
        out = tokenizer.decode(
            model.generate(**tokenizer(text, return_tensors='pt').to(model.device), num_beams=5, max_length=256)[0], 
            skip_special_tokens=True,
        )
        preds.append(out)
        
print(chrfpp.corpus_score(preds, [val.neutral_en.tolist()]).score)

  0%|          | 0/961 [00:00<?, ?it/s]

65.12601361802126


In [79]:
val.toxic_en[:5]

5058                                                                                         what kind of scum from the duma offered old people pasta
5059                                                                                                               pizdobol you don't touch my mother
5060                                                                                                                   lucky idiots also justify them
5061                       and this one climbs into deputies, well, if not an alcoholic, then a faggot would be a disgrace to sing his fucking songs.
5062    the creatures are not people let her go she will be alive and if you don't need her don't start a dog I love dogs there is a dog in the house
Name: toxic_en, dtype: object

In [80]:
preds[:5]

['What kind of person from the duma offered old people pasta',
 "You don't touch my mother",
 'Lucky people also justify them',
 'and this one climbs into deputies, well, if not an alcoholic, then it would be a disgrace to sing his songs.',
 "People let her go she will be alive and if you don't need her don't start a dog I love dogs there is a dog in the house"]

In [81]:
test_outputs = [paraphrase(text, model, tokenizer) for text in tqdm(test_inputs)]

  0%|          | 0/671 [00:00<?, ?it/s]

In [82]:
with open('../results/translate-train_yandex-full-mbart/results_en.txt', 'w') as f:
    for text in test_outputs:
        f.write(text+'\n')

### Fine-tune mBART with BOTH English (original) and Russian (translated) data

In [83]:
raw_data = DatasetDict({
    'train': Dataset.from_dict({
        'text': train.toxic_en.tolist() + train.toxic_comment.tolist(), 
        'target': train.neutral_en.tolist() + train.neutral_comment.tolist()}),
    'dev': Dataset.from_dict({'text': val.toxic_en, 'target': val.neutral_en}),
})
raw_data

DatasetDict({
    train: Dataset({
        features: ['text', 'target'],
        num_rows: 9650
    })
    dev: Dataset({
        features: ['text', 'target'],
        num_rows: 961
    })
})

In [95]:
import random
random.choice(raw_data['train'])

{'text': "don't lie, miserable, 100 thousand came to white (not a fact! yellow house) in Washington!",
 'target': '100,000 people came to the White House in Washington.'}

In [96]:
base_model = 'facebook/mbart-large-50'

In [97]:
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

In [None]:
model = AutoModelForSeq2SeqLM.from_pretrained(base_model)# .cuda()
tokenizer = AutoTokenizer.from_pretrained(base_model)

In [99]:
prefix = ""

def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["text"]]
    model_inputs = tokenizer(inputs, padding=True)
    labels = tokenizer(examples["target"], padding=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [100]:
tok_data = raw_data.map(preprocess_function, batched=True)

  0%|          | 0/10 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [101]:
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [102]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

training_args = Seq2SeqTrainingArguments(
    output_dir="/home/dale/models/detox-parallel/translate-ru2en_yandex-full_bilingual-mbart",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=1, # 8 is too much 
    weight_decay=1e-5,
    max_steps=10_000,
    learning_rate=1e-5,
    evaluation_strategy="steps",
    save_strategy="steps",
    save_total_limit=1,
    eval_steps=500, 
    save_steps=500,
    logging_steps=500,
    load_best_model_at_end=True,
    # trying to save memory: see https://huggingface.co/docs/transformers/performance
    fp16=True,
    gradient_checkpointing=True,
    optim="adafactor",
    gradient_accumulation_steps=1,
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [103]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tok_data["train"],
    eval_dataset=tok_data["dev"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

max_steps is given, it will override any value given in num_train_epochs
Using amp half precision backend


In [104]:
trainer.train()

The following columns in the training set don't have a corresponding argument in `MBartForConditionalGeneration.forward` and have been ignored: target, text. If target, text are not expected by `MBartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 9650
  Num Epochs = 9
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 10000
  args.max_grad_norm,


Step,Training Loss,Validation Loss
500,2.7405,0.399959
1000,0.4602,0.362033
1500,0.3777,0.344698
2000,0.3401,0.337635
2500,0.325,0.340251
3000,0.2673,0.341592
3500,0.2657,0.344094
4000,0.2283,0.36349
4500,0.2159,0.35397
5000,0.198,0.37944


The following columns in the evaluation set don't have a corresponding argument in `MBartForConditionalGeneration.forward` and have been ignored: target, text. If target, text are not expected by `MBartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 961
  Batch size = 1
Saving model checkpoint to /home/dale/models/detox-parallel/translate-ru2en_yandex-full_bilingual-mbart/checkpoint-500
Configuration saved in /home/dale/models/detox-parallel/translate-ru2en_yandex-full_bilingual-mbart/checkpoint-500/config.json
Model weights saved in /home/dale/models/detox-parallel/translate-ru2en_yandex-full_bilingual-mbart/checkpoint-500/pytorch_model.bin
tokenizer config file saved in /home/dale/models/detox-parallel/translate-ru2en_yandex-full_bilingual-mbart/checkpoint-500/tokenizer_config.json
Special tokens file saved in /home/dale/models/detox-parallel/translate-ru2en_yandex-full_bilingual-mbart/checkpoint-500/special_toke

TrainOutput(global_step=10000, training_loss=0.33632232246398924, metrics={'train_runtime': 4184.5054, 'train_samples_per_second': 19.118, 'train_steps_per_second': 2.39, 'total_flos': 1.0677778515296256e+16, 'train_loss': 0.33632232246398924, 'epoch': 8.29})

In [105]:
from tqdm.auto import tqdm, trange

In [106]:
preds = []
model.eval()
for text in tqdm(val.toxic_en):
    with torch.inference_mode():
        out = tokenizer.decode(
            model.generate(**tokenizer(text, return_tensors='pt').to(model.device), num_beams=5, max_length=256)[0], 
            skip_special_tokens=True,
        )
        preds.append(out)

  0%|          | 0/961 [00:00<?, ?it/s]

In [107]:
print(chrfpp.corpus_score(preds, [val.neutral_en.tolist()]).score)

65.54680919764111


In [108]:
val.toxic_en[:5]

5058                                                                                         what kind of scum from the duma offered old people pasta
5059                                                                                                               pizdobol you don't touch my mother
5060                                                                                                                   lucky idiots also justify them
5061                       and this one climbs into deputies, well, if not an alcoholic, then a faggot would be a disgrace to sing his fucking songs.
5062    the creatures are not people let her go she will be alive and if you don't need her don't start a dog I love dogs there is a dog in the house
Name: toxic_en, dtype: object

In [109]:
preds[:5]

['what kind of person from the duma offered old people pasta',
 "You don't touch my mother",
 'Lucky people also justify their actions',
 'And this one climbs into deputies, well, if not an alcoholic, then it would be a disgrace to sing his songs',
 "Let her go she will be alive and if you don't need her don't start a dog I love dogs there is a dog in the house"]

In [110]:
test_outputs = [paraphrase(text, model, tokenizer) for text in tqdm(test_inputs)]

  0%|          | 0/671 [00:00<?, ?it/s]

In [111]:
with open('../results/translate-train_yandex-full_bilingual-mbart/results_en.txt', 'w') as f:
    for text in test_outputs:
        f.write(text+'\n')

# 3. Evaluation results

```
cd /home/dale/projects/paradetox2/evaluation_detox
python metric.py --inputs /home/dale/projects/multilingual_detox/data/english_data/test_toxic_parallel.txt \
    --preds /home/dale/projects/multilingual_detox/results/translate-train_yandex-full-mbart/results_en.txt \
    --cola_classifier_path /home/dale/models/cola_classifier_fairseq \
    --wieting_model_path /home/dale/models/wieting_similarity/sim.pt \
    --wieting_tokenizer_path /home/dale/models/wieting_similarity/sim.sp.30k.model \
    --batch_size 32
cat results.md
```

| Model | ACC | EMB_SIM | SIM | CharPPL | TokenPPL | FL | GM | J | BLEU |
| ----- | --- | ------- | --- | ------- | -------- | -- | -- | - | ---- |
results_en.txt|0.6528|0.8848|0.8660|6.3930|146.1058|0.8823|0.0000|0.4681|0.7305|

```
cd /home/dale/projects/paradetox2/evaluation_detox
python metric.py --inputs /home/dale/projects/multilingual_detox/data/english_data/test_toxic_parallel.txt \
    --preds /home/dale/projects/multilingual_detox/results/translate-train_yandex-full_bilingual-mbart/results_en.txt \
    --cola_classifier_path /home/dale/models/cola_classifier_fairseq \
    --wieting_model_path /home/dale/models/wieting_similarity/sim.pt \
    --wieting_tokenizer_path /home/dale/models/wieting_similarity/sim.sp.30k.model \
    --batch_size 32
cat results.md
```

| Model | ACC | EMB_SIM | SIM | CharPPL | TokenPPL | FL | GM | J | BLEU |
| ----- | --- | ------- | --- | ------- | -------- | -- | -- | - | ---- |
results_en.txt|0.7765|0.8657|0.8229|5.9643|97.8173|0.9031|6.4201|0.5566|0.6886|