In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

In [2]:
import sys
sys.path.append(os.path.abspath('..'))

# Augment the English training set

In [22]:
import pandas as pd
import numpy as np
from tqdm.auto import tqdm, trange

In [3]:
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

In [4]:
model_name = '/home/dale/models/detox-parallel/mbart_5000_EN'
tokenizer = AutoTokenizer.from_pretrained('facebook/mbart-large-50')
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to('cuda:0')

In [88]:
def paraphrase(
    text, model, tokenizer, 
    n=None, 
    max_length="auto", 
    min_length='auto',
    beams=5,
    repetition_penalty=16.0,
):
    texts = [text] if isinstance(text, str) else text
    inputs = tokenizer(texts, return_tensors="pt", padding=True)["input_ids"].to(
        model.device
    )

    if max_length == "auto":
        max_length = int(inputs.shape[1] * 1.2 + 10 ) 
    if min_length == 'auto':
        for i in range(inputs.shape[1]):
            min_length = i + 1
            if (inputs[:, i] == tokenizer.eos_token_id).any().item():
                break
    min_length = int(min_length * 0.5 + 1) 

    result = model.generate(
        inputs,
        num_return_sequences=n or 1,
        do_sample=False,
        temperature=1.0,
        repetition_penalty=repetition_penalty,
        max_length=max_length,
        min_length=min_length,
        num_beams=beams,
    )
    results = [tokenizer.decode(r, skip_special_tokens=True) for r in result]

    if not n and isinstance(text, str):
        return results[0]
    return results

In [83]:
paraphrase('fuck this', model, tokenizer)

"I don't like this."

In [8]:
twitter_unmarked = pd.read_csv('/home/dale/data/toxic_corpora/en-parallel/input_twitter_unmarked.csv')
print(twitter_unmarked.shape)
print(twitter_unmarked.columns.tolist())

(82691, 6)
['Unnamed: 0', 'sentence', 'dataset', 'toxicity_score', 'iteration', 'length']


In [9]:
jigsaw_unmarked = pd.read_csv('/home/dale/data/toxic_corpora/en-parallel/input_jigsaw_unmarked.csv')
print(jigsaw_unmarked.shape)
print(jigsaw_unmarked.columns.tolist())

(97968, 6)
['Unnamed: 0', 'sentence', 'dataset', 'toxicity_score', 'iteration', 'length']


In [10]:
reddit_unmarked = pd.read_csv('/home/dale/data/toxic_corpora/en-parallel/input_reddit_unmarked.csv')
print(reddit_unmarked.shape)
print(reddit_unmarked.columns.tolist())

(232347, 6)
['Unnamed: 0', 'sentence', 'dataset', 'toxicity_score', 'iteration', 'length']


In [11]:
all_unmarked = pd.concat([twitter_unmarked, jigsaw_unmarked, reddit_unmarked]).reset_index(drop=True)
print(all_unmarked.shape)
all_unmarked.groupby('dataset').toxicity_score.describe()

(413006, 6)


Unnamed: 0_level_0,count,mean,std,min,25%,50%,75%,max
dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
jigsaw,97968.0,0.984946,0.030152,0.800043,0.987282,0.99657,0.998868,0.999677
reddit,232347.0,0.991976,0.016702,0.800127,0.992619,0.996761,0.998475,0.99967
twitter,82691.0,0.991451,0.018476,0.800028,0.992899,0.996755,0.998302,0.999658


In [15]:
detox_en_train = pd.read_csv('../data/english_data/en_data.csv', sep='\t')

In [16]:
with open('../data/english_data/test_toxic_parallel.txt', 'r') as f:
    test_inputs = [line.strip() for line in f.readlines()]

In [17]:
already = set(test_inputs).union(set(detox_en_train.toxic_comment))
print(len(already))

12598


In [18]:
all_unmarked = all_unmarked[all_unmarked.sentence.apply(lambda x: x not in already)]
print(all_unmarked.shape)

(412540, 6)


In [20]:
unmarked_balanced = all_unmarked.groupby('dataset').sample(70_000, random_state=1)
print(unmarked_balanced.shape)

(210000, 6)


In [21]:
unmarked_balanced.sample(3)

Unnamed: 0.1,Unnamed: 0,sentence,dataset,toxicity_score,iteration,length
152519,26474,meeting with this lame duck .,jigsaw,0.99477,,5
17365,83386,shit our spring break next week so imma be bac...,twitter,0.973974,,14
143822,49367,step away from the computer and the ridiculous...,jigsaw,0.984254,,18


In [48]:
def detokenize(text):
    for symbol in ",.?'":
        text = text.replace(' ' + symbol, symbol)
    return text

In [51]:
bs = 16

In [89]:
detoxified_unmarked = []
inputs = unmarked_balanced.sentence.tolist()
for i in trange(0, len(inputs), bs):
    try:
        results = paraphrase(inputs[i:i+bs], model, tokenizer, beams=5)
    except Exception as e:
        print(e)
        results = [None for t in inputs[i:i+bs]]
    detoxified_unmarked.extend(results)

  0%|          | 0/13125 [00:00<?, ?it/s]

In [92]:
unmarked_balanced['neutral'] = detoxified_unmarked

In [94]:
pd.options.display.max_colwidth = 300

In [97]:
unmarked_balanced.sample(3)

Unnamed: 0.1,Unnamed: 0,sentence,dataset,toxicity_score,iteration,length,neutral
158871,77849,"he boobytrapped his apartment , to kill anyone coming in .",jigsaw,0.926426,,9,He boobytrapped his apartment to kill anyone coming in.
133699,100197,"perhaps the people of taif saw that this fellow was a con artist , thief and predator .",jigsaw,0.983709,,16,perhaps the people of taif saw that this fellow was a thief and predator.
137827,3917,"and not , au revoir , do call again and leave more crap .",jigsaw,0.993696,,11,Do not call again and leave more problems.


In [98]:
unmarked_balanced[['sentence', 'neutral', 'dataset']].to_csv('detox_en_augmented.tsv', index=None, sep='\t')

# Translate the English augmented data to Russian

In [12]:
import pandas as pd
from tqdm.auto import tqdm, trange

In [10]:
detox_en_augmented = pd.read_csv('detox_en_augmented.tsv', sep='\t')

In [None]:
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# model_name = "facebook/wmt19-en-ru"
model_name = 'Helsinki-NLP/opus-mt-en-ru'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).cuda();

In [4]:
def detokenize(text):
    for symbol in ",.?'":
        text = text.replace(' ' + symbol, symbol)
    return text

In [6]:
from nltk.tokenize import sent_tokenize


def detokenize(text):
    for symbol in ",.?'!":
        text = text.replace(' ' + symbol, symbol)
    return text


def translate(texts, model, tokenizer, num_beams=5, max_length='auto', repetition_penalty=16.0, **kwargs):
    sentences = []
    sent_sizes = []
    for text in texts:
        sents = sent_tokenize(detokenize(text))
        sentences.extend(sents)
        sent_sizes.append(len(sents))
    inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
    if max_length == 'auto': 
        max_length = int(10 + inputs.input_ids.shape[1] * 1.5)
    with torch.no_grad():
        outputs = model.generate(**inputs.to(model.device), num_beams=num_beams, repetition_penalty=repetition_penalty, **kwargs)
    out_sents = [tokenizer.decode(out, skip_special_tokens=True) for out in outputs]
    out_texts = []
    cursor = 0
    for size in sent_sizes:
        out_texts.append(' '.join(out_sents[cursor: cursor+size]))
        cursor += size
    return out_texts

In [8]:
batch_size = 16 

In [19]:
detox_en_augmented.head(10)

Unnamed: 0,sentence,neutral,dataset
0,the demonization of the china and the chinese ...,the demonization of the china and the chinese ...,jigsaw
1,hatred of those white supremacist nazis was pr...,hatred of those white supremacist nazis was pr...,jigsaw
2,"if a person is black , deadly force is general...","If a person is black, deadly force is generall...",jigsaw
3,what a shitty upbringing that would be .,what a bad upbringing that would be.,jigsaw
4,trudeau is nothing but corruptions .,Trudeau is nothing but corruptions.,jigsaw
5,and dont give me that bs about those priests b...,and dont give me that bs about those priests b...,jigsaw
6,theyve been feeding this rubbish to their citi...,theyve been feeding this to their citizens sin...,jigsaw
7,the reason white people have that advantage is...,the reason white people have that advantage is...,jigsaw
8,thats why i donr blame trump for the present m...,That's why I don't blame Trump for the present...,jigsaw
9,one thing harper was not is a lapdog that crav...,one thing harper was not is a lapdog that crav...,jigsaw


In [None]:
toxic_ru = []
neutral_ru = []
for i in trange(0, len(detox_en_augmented), batch_size):
    toxic_ru.extend(translate(detox_en_augmented.sentence[i:i+batch_size].tolist(), model, tokenizer))
    neutral_ru.extend(translate(detox_en_augmented.neutral[i:i+batch_size].tolist(), model, tokenizer))

  0%|          | 0/13125 [00:00<?, ?it/s]

In [22]:
detox_en_augmented['toxic_ru'] = toxic_ru
detox_en_augmented['neutral_ru'] = neutral_ru

In [24]:
detox_en_augmented.to_csv('detox_en_augmented.tsv', index=None, sep='\t')

# Prepare the training data

In [3]:
import pandas as pd
from tqdm.auto import tqdm, trange

In [4]:
detox_en_augmented = pd.read_csv('detox_en_augmented.tsv', sep='\t')

In [5]:
detox_en2ru = pd.read_csv('detox_en2ru.tsv', sep='\t')

In [6]:
detox_en2ru.sample(3)

Unnamed: 0,idx,toxic_comment,neutral_comment,toxicity_score,dataset,toxic,confidence_toxic,is_match,confidence_is_match,toxic_ru,neutral_ru,edit_distance_en,edit_distance_ru,edit_sim_en,edit_sim_ru,accuracy,similarity,fluency,joint
9885,69927,no one gives a fuck about the music on ur ipod,No one cares about the music on your iPod,0.900609,twitter,False,0.9968,True,0.9777,Никого не волнует музыка на айподе.,Никого не волнует музыка на твоем iPod,14,10,0.695652,0.736842,0.479064,0.68084,0.521709,0.170164
15999,9271,i think we have better things to care about in...,I think we have better things to care about in...,0.998004,reddit,False,0.9997,True,0.9824,"Думаю, в Париже есть вещи поважнее, чем показы...","Думаю, в Париже есть вещи поважнее, чем показы...",5,7,0.938272,0.887097,0.997856,0.934951,1.0,0.932946
4559,22998,it was directed towards callin me a bitch,it was directed towards callin me a bad name,0.991726,twitter,False,0.9896,True,0.9727,"Он был направлен на то, чтобы называть меня су...","Он был направлен на то, чтобы называть меня пл...",7,13,0.840909,0.775862,0.961407,0.775528,1.0,0.745598


In [7]:
detox_en_augmented.sample(3)

Unnamed: 0,sentence,neutral,dataset,toxic_ru,neutral_ru
169718,"don 't get too close bitch , tuh !","don 't get too close, tuh!",twitter,"Не подбирайся слишком близко, сука!","Не подходи слишком близко, Тихо!"
196111,in bed with my pussy,I am in love with this woman.,twitter,в постели с моей киской,Я люблю эту женщину.
36806,the name mt mckinley sucked !,the name mt mckinley sucked!,jigsaw,имя МакКинли было отстойным!,имя МакКинли было отстойным!


In [8]:
from textdistance import levenshtein

In [9]:
detox_en_augmented['edit_distance_ru'] = [levenshtein.distance(*row) for row in detox_en_augmented[['toxic_ru', 'neutral_ru']].values]
detox_en_augmented['edit_sim_ru'] = [levenshtein.normalized_similarity(*row) for row in detox_en_augmented[['toxic_ru', 'neutral_ru']].values]

In [10]:
detox_en_augmented['edit_distance_en'] = [levenshtein.distance(*row) for row in detox_en_augmented[['sentence', 'neutral']].values]
detox_en_augmented['edit_sim_en'] = [levenshtein.normalized_similarity(*row) for row in detox_en_augmented[['sentence', 'neutral']].values]

In [11]:
detox_en2ru.describe()

Unnamed: 0,idx,toxicity_score,confidence_toxic,confidence_is_match,edit_distance_en,edit_distance_ru,edit_sim_en,edit_sim_ru,accuracy,similarity,fluency,joint
count,19766.0,19766.0,19766.0,19766.0,19766.0,19766.0,19766.0,19766.0,19766.0,19766.0,19766.0,19766.0
mean,78035.783416,0.990181,0.984174,0.972585,15.705757,21.46327,0.696492,0.60278,0.838157,0.711838,0.869159,0.525728
std,60839.351686,0.019532,0.025847,0.025094,10.989535,14.26659,0.186762,0.209683,0.250159,0.216294,0.172396,0.261383
min,7.0,0.800983,0.5237,0.8,0.0,0.0,0.0,0.0,0.103904,0.0,0.0,0.0
25%,31839.0,0.991065,0.9811,0.9611,8.0,11.0,0.590909,0.451613,0.793939,0.604774,0.788633,0.309918
50%,63895.0,0.996133,0.9941,0.9805,12.0,18.0,0.741935,0.626667,0.974691,0.764827,0.944624,0.554619
75%,101255.25,0.998179,0.9985,0.9911,20.0,28.0,0.83871,0.765957,0.994927,0.873708,1.0,0.745741
max,238836.0,0.999647,1.0,0.9999,105.0,271.0,1.0,1.0,0.999766,1.0,1.0,0.999527


In [12]:
detox_en_augmented.describe()

Unnamed: 0,edit_distance_ru,edit_sim_ru,edit_distance_en,edit_sim_en
count,210000.0,210000.0,210000.0,210000.0
mean,17.198533,0.663594,11.559924,0.766713
std,12.668022,0.221797,8.207094,0.166556
min,0.0,0.0,0.0,0.0
25%,8.0,0.510204,6.0,0.684211
50%,15.0,0.681818,10.0,0.806452
75%,24.0,0.828571,15.0,0.887097
max,260.0,1.0,245.0,1.0


In [13]:
detox_en2ru_filtered = detox_en2ru[
    (detox_en2ru.edit_distance_ru >= detox_en2ru.edit_distance_en.quantile(0.01)) 
    & (detox_en2ru.edit_distance_ru <= detox_en2ru.edit_distance_en.quantile(0.99)) 
    & (detox_en2ru.edit_sim_ru >= detox_en2ru.edit_sim_en.quantile(0.01)) 
    & (detox_en2ru.edit_sim_ru <= detox_en2ru.edit_sim_en.quantile(0.99))
]

print(detox_en2ru.shape)
print(detox_en2ru_filtered.shape)

(19766, 19)
(18239, 19)


In [14]:
detox_augmented_filtered = detox_en_augmented[
    (detox_en_augmented.edit_distance_ru >= detox_en2ru.edit_distance_en.quantile(0.01)) 
    & (detox_en_augmented.edit_distance_ru <= detox_en2ru.edit_distance_en.quantile(0.99)) 
    & (detox_en_augmented.edit_sim_ru >= detox_en2ru.edit_sim_en.quantile(0.01)) 
    & (detox_en_augmented.edit_sim_ru <= detox_en2ru.edit_sim_en.quantile(0.99))
]

print(detox_en_augmented.shape)
print(detox_augmented_filtered.shape)

(210000, 9)
(178112, 9)


In [15]:
pd.options.display.max_colwidth = 300

In [16]:
detox_augmented_filtered.sample(3)

Unnamed: 0,sentence,neutral,dataset,toxic_ru,neutral_ru,edit_distance_ru,edit_sim_ru,edit_distance_en,edit_sim_en
170589,it 's bitch follow me pleeease ?,follow me pleeease?,twitter,Это сучка за моей плейезой?,Следуешь за мной?,18,0.333333,13,0.59375
88547,"i get the idea of goodie bags or whatever to an extent , but entire fucking gifts ! ?","i get the idea of goodie bags or whatever to an extent, but entire gifts!?",reddit,"Я понимаю, что такое шикарные сумки или в какой-то мере. Но целые чертовы подарки! ?","Я понимаю, что такое шикарные сумки или вроде того. Но целые подарки! ?",20,0.761905,11,0.870588
73805,they build and program that shit .,They build and program that.,reddit,Они строят и программируют это дерьмо.,Они строят и программируют это.,7,0.815789,7,0.794118


In [21]:
from sklearn.model_selection import train_test_split
train, val = train_test_split(detox_en2ru_filtered, random_state=1, test_size=500)

In [19]:
from sacrebleu import CHRF
chrfpp = CHRF(word_order=2)

In [22]:
chrfpp.corpus_score(val.toxic_ru.tolist(), [val.neutral_ru.tolist()]).score

59.6204444270353

In [17]:
def paraphrase(
    text, model, tokenizer, n=None, max_length="auto", beams=5,
):
    texts = [text] if isinstance(text, str) else text
    inputs = tokenizer(texts, return_tensors="pt", padding=True)["input_ids"].to(
        model.device
    )

    if max_length == "auto":
        max_length = inputs.shape[1] + 10

    result = model.generate(
        inputs,
        num_return_sequences=n or 1,
        do_sample=False,
        temperature=1.0,
        repetition_penalty=10.0,
        max_length=max_length,
        min_length=int(0.5 * max_length),
        num_beams=beams,
        #forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang],
    )
    texts = [tokenizer.decode(r, skip_special_tokens=True) for r in result]

    if not n and isinstance(text, str):
        return texts[0]
    return texts[0]

In [18]:
test_data = pd.read_csv('../data/russian_data/test.tsv', sep='\t')
test_inputs = test_data["toxic_comment"].values.tolist()

# Train a model

In [24]:
from datasets import Dataset, DatasetDict

Add special `<s>` token to indicate augmented data (it is not used otherwise)

In [84]:
raw_data = DatasetDict({
    'train_en': Dataset.from_dict({
        'text': train.toxic_comment.tolist() + ['<s>' + t for t in detox_augmented_filtered.sentence.tolist()], 
        'target': train.neutral_comment.tolist() + ['<s>' + t for t in detox_augmented_filtered.neutral.tolist()], 
    }),
    'train_ru': Dataset.from_dict({
        'text': train.toxic_ru.tolist() + ['<s>' + t for t in detox_augmented_filtered.toxic_ru.tolist() ], 
        'target': train.neutral_ru.tolist()  + ['<s>' + t for t in detox_augmented_filtered.neutral_ru.tolist()], 
    }),
    'train_ru_clean': Dataset.from_dict({
        'text': train.toxic_ru.tolist(),
        'target': train.neutral_ru.tolist(),
    }),
    'train': Dataset.from_dict({
        'text': train.toxic_ru.tolist() + train.toxic_comment.tolist() \
            + ['<s>' + t for t in detox_augmented_filtered.toxic_ru.tolist() + detox_augmented_filtered.sentence.tolist()], 
        'target': train.neutral_ru.tolist() + train.neutral_comment.tolist() \
            + ['<s>' + t for t in detox_augmented_filtered.neutral_ru.tolist() + detox_augmented_filtered.neutral.tolist()], 
    }),
    'dev': Dataset.from_dict({'text': val.toxic_ru, 'target': val.neutral_ru}),
})
raw_data

DatasetDict({
    train_en: Dataset({
        features: ['text', 'target'],
        num_rows: 195851
    })
    train_ru: Dataset({
        features: ['text', 'target'],
        num_rows: 195851
    })
    train_ru_clean: Dataset({
        features: ['text', 'target'],
        num_rows: 17739
    })
    train: Dataset({
        features: ['text', 'target'],
        num_rows: 391702
    })
    dev: Dataset({
        features: ['text', 'target'],
        num_rows: 500
    })
})

In [85]:
raw_data['train'][-1]

{'text': '<s>she might not make it me fuck you nursey',
 'target': '<s>she might not make it me mess with you nursey'}

In [74]:
base_model = 'facebook/mbart-large-50'

In [32]:
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

In [33]:
model = AutoModelForSeq2SeqLM.from_pretrained(base_model)
tokenizer = AutoTokenizer.from_pretrained(base_model)

In [56]:
prefix = ""

def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["text"]]
    model_inputs = tokenizer(inputs, padding=True)
    labels = tokenizer(examples["target"], padding=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [86]:
tok_data = raw_data.map(preprocess_function, batched=True)

  0%|          | 0/196 [00:00<?, ?ba/s]

  0%|          | 0/196 [00:00<?, ?ba/s]

  0%|          | 0/18 [00:00<?, ?ba/s]

  0%|          | 0/392 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [62]:
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [66]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

training_args = Seq2SeqTrainingArguments(
    output_dir="/home/dale/models/detox-parallel/translate-en2ru-full_aug_bilingual-mbart",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=1, # 8 is too much 
    weight_decay=1e-5,
    max_steps=50_000,
    learning_rate=1e-5,
    evaluation_strategy="steps",
    save_strategy="steps",
    save_total_limit=2,
    eval_steps=1000, 
    save_steps=1000,
    logging_steps=1000,
    load_best_model_at_end=True,
    # trying to save memory: see https://huggingface.co/docs/transformers/performance
    fp16=True,
    gradient_checkpointing=True,
    optim="adafactor",
    gradient_accumulation_steps=1,
)

In [67]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tok_data["train"],
    eval_dataset=tok_data["dev"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

max_steps is given, it will override any value given in num_train_epochs
Using amp half precision backend


About 4 hours of training

In [75]:
trainer.train()

The following columns in the training set don't have a corresponding argument in `MBartForConditionalGeneration.forward` and have been ignored: text, target. If text, target are not expected by `MBartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 391702
  Num Epochs = 2
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 50000
  args.max_grad_norm,


Step,Training Loss,Validation Loss
1000,1.27,1.082111
2000,0.3434,0.490913
3000,0.3151,0.455586
4000,0.304,0.448654
5000,0.2978,0.4316
6000,0.2959,0.417865
7000,0.2934,0.413381
8000,0.2898,0.410315
9000,0.2889,0.403366
10000,0.2771,0.399524


The following columns in the evaluation set don't have a corresponding argument in `MBartForConditionalGeneration.forward` and have been ignored: text, target. If text, target are not expected by `MBartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 500
  Batch size = 1
Saving model checkpoint to /home/dale/models/detox-parallel/translate-en2ru-full_aug_bilingual-mbart/checkpoint-1000
Configuration saved in /home/dale/models/detox-parallel/translate-en2ru-full_aug_bilingual-mbart/checkpoint-1000/config.json
Model weights saved in /home/dale/models/detox-parallel/translate-en2ru-full_aug_bilingual-mbart/checkpoint-1000/pytorch_model.bin
tokenizer config file saved in /home/dale/models/detox-parallel/translate-en2ru-full_aug_bilingual-mbart/checkpoint-1000/tokenizer_config.json
Special tokens file saved in /home/dale/models/detox-parallel/translate-en2ru-full_aug_bilingual-mbart/checkpoint-1000/special_tokens_map.jso

TrainOutput(global_step=50000, training_loss=0.28522207397460936, metrics={'train_runtime': 18125.3899, 'train_samples_per_second': 22.068, 'train_steps_per_second': 2.759, 'total_flos': 3.915323299735142e+16, 'train_loss': 0.28522207397460936, 'epoch': 1.02})

In [76]:
from tqdm.auto import tqdm, trange

In [77]:
preds = []
model.eval()
for text in tqdm(val.toxic_ru):
    with torch.inference_mode():
        out = tokenizer.decode(
            model.generate(**tokenizer(text, return_tensors='pt').to(model.device), num_beams=5, max_length=256)[0], 
            skip_special_tokens=True,
        )
        preds.append(out)

  0%|          | 0/500 [00:00<?, ?it/s]

61.4175 was the previous model score. With augmentation, it is higher: 61.9785. 

In [78]:
print(chrfpp.corpus_score(preds, [val.neutral_ru.tolist()]).score)

61.97856261830022


In [79]:
test_outputs = [paraphrase(text, model, tokenizer) for text in tqdm(test_inputs)]

  0%|          | 0/1000 [00:00<?, ?it/s]

In [80]:
p = '../results/translate-train-full_augmented_bilingual-mbart/'
os.makedirs(p)

In [81]:
with open(p + 'results_ru.txt', 'w') as f:
    for text in test_outputs:
        f.write(text+'\n')

```
python evaluate_ru.py \
    --result_filename scores \
    --input_dir results/translate-train-full_augmented_bilingual-mbart \
    --output_dir results
```

```
Style accuracy:       0.47033900022506714
Meaning preservation: 0.8848435282707214
Joint fluency:        -0.10494334995746613
Joint score:          -0.04591362178325653
Scores after calibration:
Style accuracy:       0.5233051180839539
Meaning preservation: 0.8273530006408691
Joint fluency:        0.879315197467804
Joint score:          0.36825796961784363
```

In [82]:
training_args = Seq2SeqTrainingArguments(
    output_dir="/home/dale/models/detox-parallel/translate-en2ru-full_aug_bilingual-mbart-finetune",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=1, # 8 is too much 
    weight_decay=1e-5,
    max_steps=5_000,
    learning_rate=1e-5,
    evaluation_strategy="steps",
    save_strategy="steps",
    save_total_limit=2,
    eval_steps=1000, 
    save_steps=1000,
    logging_steps=1000,
    load_best_model_at_end=True,
    # trying to save memory: see https://huggingface.co/docs/transformers/performance
    fp16=True,
    gradient_checkpointing=True,
    optim="adafactor",
    gradient_accumulation_steps=1,
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [88]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tok_data["train_ru_clean"],
    eval_dataset=tok_data["dev"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

max_steps is given, it will override any value given in num_train_epochs
Using amp half precision backend


In [89]:
trainer.train()

The following columns in the training set don't have a corresponding argument in `MBartForConditionalGeneration.forward` and have been ignored: text, target. If text, target are not expected by `MBartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 17739
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 5000


Step,Training Loss,Validation Loss
1000,0.3454,0.344882
2000,0.3416,0.336609
3000,0.2833,0.340101
4000,0.2674,0.335528
5000,0.2469,0.34545


  args.max_grad_norm,
The following columns in the evaluation set don't have a corresponding argument in `MBartForConditionalGeneration.forward` and have been ignored: text, target. If text, target are not expected by `MBartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 500
  Batch size = 1
Saving model checkpoint to /home/dale/models/detox-parallel/translate-en2ru-full_aug_bilingual-mbart-finetune/checkpoint-1000
Configuration saved in /home/dale/models/detox-parallel/translate-en2ru-full_aug_bilingual-mbart-finetune/checkpoint-1000/config.json
Model weights saved in /home/dale/models/detox-parallel/translate-en2ru-full_aug_bilingual-mbart-finetune/checkpoint-1000/pytorch_model.bin
tokenizer config file saved in /home/dale/models/detox-parallel/translate-en2ru-full_aug_bilingual-mbart-finetune/checkpoint-1000/tokenizer_config.json
Special tokens file saved in /home/dale/models/detox-parallel/translate-en2ru-full_

TrainOutput(global_step=5000, training_loss=0.2969224395751953, metrics={'train_runtime': 1972.0811, 'train_samples_per_second': 20.283, 'train_steps_per_second': 2.535, 'total_flos': 5273016131518464.0, 'train_loss': 0.2969224395751953, 'epoch': 2.25})

In [90]:
preds = []
model.eval()
for text in tqdm(val.toxic_ru):
    with torch.inference_mode():
        out = tokenizer.decode(
            model.generate(**tokenizer(text, return_tensors='pt').to(model.device), num_beams=5, max_length=256)[0], 
            skip_special_tokens=True,
        )
        preds.append(out)

  0%|          | 0/500 [00:00<?, ?it/s]

In [91]:
print(chrfpp.corpus_score(preds, [val.neutral_ru.tolist()]).score)

62.257692944801725


In [92]:
test_outputs = [paraphrase(text, model, tokenizer) for text in tqdm(test_inputs)]

  0%|          | 0/1000 [00:00<?, ?it/s]

In [93]:
p = '../results/translate-train-full_augmented_bilingual-mbart-finetune/'
os.makedirs(p)

In [94]:
with open(p + 'results_ru.txt', 'w') as f:
    for text in test_outputs:
        f.write(text+'\n')

```
python evaluate_ru.py \
    --result_filename scores \
    --input_dir results/translate-train-full_augmented_bilingual-mbart-finetune \
    --output_dir results
```

```
Style accuracy:       0.5113834142684937
Meaning preservation: 0.8692513704299927
Joint fluency:        -0.11415145546197891
Joint score:          -0.04862915724515915
Scores after calibration:
Style accuracy:       0.5602450370788574
Meaning preservation: 0.8039894700050354
Joint fluency:        0.8687258362770081
Joint score:          0.3829019069671631
```