In this setup, the encoder does not know which language the decoder will generate, so it has to produce language-invariant embeddings.

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

In [2]:
from datasets import load_dataset

In [3]:
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

In [4]:
import numpy as np
import pandas as pd
import gc
import random

from tqdm.auto import tqdm, trange

# Collect the training tasks

* Paraphrasing: https://huggingface.co/datasets/GEM/opusparcus
* Translation: https://huggingface.co/datasets/open_subtitles + news_commentary? + tatoeba?
* Detox: ordinary data

In [3]:
opus_para_en = load_dataset("GEM/opusparcus", "en.80")
opus_para_en

Reusing dataset opusparcus (/home/dale/.cache/huggingface/datasets/GEM___opusparcus/en.80/1.0.0/79d36ae4eced4f3c2c5a2ab9f94a584de7adca957186408d33798d0d87b018f2)


  0%|          | 0/5 [00:00<?, ?it/s]

DatasetDict({
    test: Dataset({
        features: ['lang', 'input', 'target', 'annot_score', 'gem_id', 'references'],
        num_rows: 982
    })
    validation: Dataset({
        features: ['lang', 'input', 'target', 'annot_score', 'gem_id', 'references'],
        num_rows: 1015
    })
    test.full: Dataset({
        features: ['lang', 'input', 'target', 'annot_score', 'gem_id', 'references'],
        num_rows: 1445
    })
    validation.full: Dataset({
        features: ['lang', 'input', 'target', 'annot_score', 'gem_id', 'references'],
        num_rows: 1455
    })
    train: Dataset({
        features: ['lang', 'input', 'target', 'annot_score', 'gem_id', 'references'],
        num_rows: 5200000
    })
})

In [4]:
opus_para_ru = load_dataset("GEM/opusparcus", "ru.80")
opus_para_ru

Reusing dataset opusparcus (/home/dale/.cache/huggingface/datasets/GEM___opusparcus/ru.80/1.0.0/79d36ae4eced4f3c2c5a2ab9f94a584de7adca957186408d33798d0d87b018f2)


  0%|          | 0/5 [00:00<?, ?it/s]

DatasetDict({
    test: Dataset({
        features: ['lang', 'input', 'target', 'annot_score', 'gem_id', 'references'],
        num_rows: 1068
    })
    validation: Dataset({
        features: ['lang', 'input', 'target', 'annot_score', 'gem_id', 'references'],
        num_rows: 1020
    })
    test.full: Dataset({
        features: ['lang', 'input', 'target', 'annot_score', 'gem_id', 'references'],
        num_rows: 1855
    })
    validation.full: Dataset({
        features: ['lang', 'input', 'target', 'annot_score', 'gem_id', 'references'],
        num_rows: 1854
    })
    train: Dataset({
        features: ['lang', 'input', 'target', 'annot_score', 'gem_id', 'references'],
        num_rows: 2300000
    })
})

In [6]:
random.choice(opus_para_ru['train'])

{'lang': 'ru',
 'input': 'Я думаю, что могу.',
 'target': 'Я думаю, что делаю.',
 'annot_score': 0.0,
 'gem_id': 'gem-opusparcus-train-79213069',
 'references': ['Я думаю, что делаю.']}

In [7]:
opensub = load_dataset("open_subtitles", lang1="en", lang2='ru')
opensub

Using custom data configuration en-ru-lang1=en,lang2=ru
Reusing dataset open_subtitles (/home/dale/.cache/huggingface/datasets/open_subtitles/en-ru-lang1=en,lang2=ru/0.0.0/c1ec973ca4b6e588740d8f167cc0e24ea3f626e70bc7ffe467e944730500e198)


  0%|          | 0/1 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'meta', 'translation'],
        num_rows: 25910105
    })
})

In [8]:
random.choice(opensub['train'])

{'id': '11986737',
 'meta': {'year': 2011,
  'imdbId': 1640807,
  'subtitleId': {'en': 4414177, 'ru': 6337203},
  'sentenceIds': {'en': [961], 'ru': [907]}},
 'translation': {'en': 'Package?', 'ru': 'Посылка?'}}

In [9]:
news_commentary = load_dataset("news_commentary", lang1="en", lang2='ru')
news_commentary

Using custom data configuration en-ru-lang1=en,lang2=ru
Reusing dataset news_commentary (/home/dale/.cache/huggingface/datasets/news_commentary/en-ru-lang1=en,lang2=ru/0.0.0/cfab724ce975dc2da51cdae45302389860badc88b74db8570d561ced6004f8b4)


  0%|          | 0/1 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'translation'],
        num_rows: 190104
    })
})

In [10]:
random.choice(news_commentary['train'])

{'id': '105617',
 'translation': {'en': 'Peace-related activities – particularly reintegration and other reconciliation programs – need to be given priority in budget allocations.',
  'ru': 'Деятельность, направленная на установление мира, – в особенности возвращение людей в общество и т.п. – должна пользоваться приоритетом при распределении бюджета.'}}

In [11]:
tatoeba = load_dataset("tatoeba", lang1="en", lang2='ru')
tatoeba

Using custom data configuration en-ru-lang1=en,lang2=ru
Reusing dataset tatoeba (/home/dale/.cache/huggingface/datasets/tatoeba/en-ru-lang1=en,lang2=ru/0.0.0/b3ea9c6bb2af47699c5fc0a155643f5a0da287c7095ea14824ee0a8afd74daf6)


  0%|          | 0/1 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'translation'],
        num_rows: 523656
    })
})

In [12]:
random.choice(tatoeba['train'])

{'id': '40914',
 'translation': {'en': "I'd like to dance with you.",
  'ru': 'Я хотел бы потанцевать с тобой.'}}

In [23]:
def get_paraphrase_task(batch_size=1):
    """Sample a monolingual paraphrase batch from the Opusparcus train splits.

    A single coin flip selects English or Russian for the whole batch. Every
    sampled pair is then swapped with probability 0.5, so either sentence of
    the pair may end up as the source.

    Args:
        batch_size: number of sentence pairs to draw.

    Returns:
        A tuple ``(sources, targets, src_lang_code, tgt_lang_code)``; the two
        language codes are always identical for this task.
    """
    prefix = ''  # optional task prefix such as 'paraphrase: ' (currently disabled)
    use_english = random.random() < 0.5
    corpus = opus_para_en if use_english else opus_para_ru
    lang_code = 'en_XX' if use_english else 'ru_RU'

    sources, targets = [], []
    for _ in range(batch_size):
        pair = random.choice(corpus['train'])
        first, second = pair['input'], pair['target']
        if random.random() < 0.5:
            first, second = second, first
        sources.append(prefix + first)
        targets.append(second)
    return sources, targets, lang_code, lang_code
        
get_paraphrase_task(1)

(['Argh, me back.'], ["I'm back."], 'en_XX', 'en_XX')

In [24]:
def get_translate_task(batch_size=1):
    """Sample an en<->ru translation batch from one randomly chosen parallel corpus.

    The corpus (tatoeba, opensub or news_commentary) and the translation
    direction are each drawn once per call and shared by the whole batch.

    Args:
        batch_size: number of sentence pairs to draw.

    Returns:
        A tuple ``(sources, targets, src_lang_code, tgt_lang_code)``.
    """
    prefix = ''  # optional task prefix such as 'translate: ' (currently disabled)
    corpus = random.choice([tatoeba, opensub, news_commentary])
    if random.random() < 0.5:
        src_id, tgt_id = 'en_XX', 'ru_RU'
    else:
        src_id, tgt_id = 'ru_RU', 'en_XX'
    # The 'translation' dicts are keyed by the two-letter part of the language code.
    src_key, tgt_key = src_id[:2], tgt_id[:2]

    sources, targets = [], []
    for _ in range(batch_size):
        pair = random.choice(corpus['train'])
        sources.append(prefix + pair['translation'][src_key])
        targets.append(pair['translation'][tgt_key])
    return sources, targets, src_id, tgt_id

get_translate_task(2)

(["You don't get to go to hotels with boys that I don't know, with boys I do know, with any kind of boy!",
  'I want some fucking Colace.'],
 ['15 лет! Ты не должна ходить по отелям с парнями, которых я не знаю. И даже с парнями, которых я знаю!',
  'Мне нужен Колас.'],
 'en_XX',
 'ru_RU')

In [21]:
model_name = 'facebook/mbart-large-50'
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [22]:
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).cuda();

In [25]:
def get_batch(batch_size=4):
    """Build a tokenized multitask batch: half paraphrase, half translation.

    Each task contributes ``batch_size // 2`` single-sentence samples. The
    tokenizer's source/target languages are set per sample before encoding,
    and targets are encoded inside ``as_target_tokenizer()``.

    Returns:
        ``(inputs, targets)``: two parallel lists of tokenizer outputs, each
        produced from a one-element list of sentences (unpadded).
    """
    per_task = batch_size // 2
    inputs, targets = [], []
    for sample_task in (get_paraphrase_task, get_translate_task):
        for _ in range(per_task):
            src_texts, tgt_texts, src_lang, tgt_lang = sample_task(1)
            tokenizer.src_lang = src_lang
            tokenizer.tgt_lang = tgt_lang
            inputs.append(tokenizer(src_texts))
            with tokenizer.as_target_tokenizer():
                targets.append(tokenizer(tgt_texts))
    return inputs, targets

In [26]:
x, y = get_batch(6)
x, y

([{'input_ids': [[250004, 32774, 6528, 4, 642, 1221, 73203, 25842, 5, 2]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]},
  {'input_ids': [[250004, 1401, 2258, 20740, 297, 5, 2]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1]]},
  {'input_ids': [[250021, 33473, 4, 4789, 77, 131666, 414, 89476, 5, 2]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]},
  {'input_ids': [[250021, 1509, 19458, 7193, 811, 4, 12386, 96357, 5, 2]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]},
  {'input_ids': [[250021, 1851, 690, 4, 414, 23494, 1725, 33868, 49, 37325, 2325, 115458, 6, 22782, 49, 1392, 8165, 165769, 24192, 106, 80389, 1745, 8118, 38850, 49, 38592, 46, 5641, 424, 242910, 244, 114, 827, 9561, 46, 12578, 57750, 10757, 87879, 227, 36553, 4544, 5, 447, 34496, 99582, 4, 414, 2176, 3216, 1086, 7736, 88426, 213585, 19596, 129, 210507, 216019, 312, 4, 68929, 31825, 36553, 4544, 718, 42066, 8165, 4, 20946, 145089, 13299, 54064, 1219, 49, 2660, 4, 3077, 114628, 90870, 312, 518, 228317, 36553, 4

In [27]:
for xx, yy in zip(x, y):
    print(tokenizer.decode(xx['input_ids'][0]))
    print(tokenizer.decode(yy['input_ids'][0]))

en_XX Great idea, we will eat together.</s>
en_XX The breakfast.</s>
en_XX We carpooled.</s>
en_XX That's a lot of driving up here.</s>
ru_RU Нет, ты не знаешь что произошло.</s>
ru_RU Ты же не знаешь, кто это был.</s>
ru_RU Я знаю кто вы, Коннор.</s>
en_XX I know who you are, Connor.</s>
ru_RU Но то, что должно их действительно встревожить ‐ в 2011 году приблизительно 14,5 % всего населения в мире – один из каждых семи человек – жил ниже этой черты бедности. Учитывая, что мы уже обязались достичь цели по ликвидации чрезвычайной, хронической бедности к 2030 году, наше первое решение состояло в том, чтобы считать критерий для измерения бедности постоянным.</s>
en_XX कुछ आलोचकों का यह तर्क है कि 2005 की $1.25 की गरीबी रेखा बहुत कम थी। लेकिन उन्हें जिस बात पर चिंता करनी चाहिए वह यह है कि वर्ष 2011 में दुनिया की लगभग 14.5% आबादी - हर सात लोगों में से एक - इसके नीचे रह रही थी। यह देखते हुए कि हम पहले से ही 2030 तक चरम, चिर गरीबी को समाप्त करने के लक्ष्य के लिए प्रतिबद्ध हैं, हमारा पहला निर्

In [28]:
tokenizer.pad([{k: v[0] for k, v in item.items()} for item in x])

{'input_ids': [[250004, 32774, 6528, 4, 642, 1221, 73203, 25842, 5, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [250004, 1401, 2258, 20740, 297, 5, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [250021, 33473, 4, 4789, 77, 131666, 414, 89476, 5, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [250021, 1509, 19458, 7193, 811, 4, 12386, 96357, 5, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

In [29]:
import transformers
transformers.logging.set_verbosity_error()

In [30]:
from transformers.optimization import Adafactor
optimizer = Adafactor(model.parameters(), scale_parameter=False, relative_step=False, lr=1e-5, clip_threshold=1.0)

In [31]:
from torch.optim.lr_scheduler import LambdaLR

num_warmup_steps = 1000  # length of the linear warmup, in scheduler steps

def lr_lambda(current_step: int):
    """Learning-rate multiplier: linear ramp from 0 to 1 over the warmup, then constant 1."""
    if current_step >= num_warmup_steps:
        return 1.0
    warmup_span = float(max(1.0, num_warmup_steps))
    return float(current_step) / warmup_span

scheduler = LambdaLR(optimizer, lr_lambda)

In [32]:
model.train()
model.gradient_checkpointing_enable()

In [33]:
from tqdm.auto import tqdm, trange

In [34]:
import torch
import gc

def cleanup():
    """Run Python garbage collection, then empty the PyTorch CUDA cache."""
    gc.collect()
    torch.cuda.empty_cache()

cleanup()

In [35]:
gradient_steps = 4  # the training loop calls optimizer.step() once per this many iterations
batch_size = 6  # passed to get_batch(); split evenly between paraphrase and translation samples
window = 1000  # caps the averaging horizon of the running loss (ewm_loss)
report_steps = 1000  # print the running loss and save model + tokenizer every this many iterations
cleanup_steps = 100  # call cleanup() every this many iterations
save_path = '/home/dale/models/detox-parallel/bart-multitask-pretrain-invariant'  # checkpoint directory

In [36]:
ewm_loss = 4

In [37]:
cleanup()

With ~6 iterations per second, 1M iterations take 1000000/6/60/60 ≈ 46 hours.

In [None]:
# Multitask fine-tuning loop (paraphrase + translation) with gradient accumulation,
# best-effort recovery from CUDA errors, a running loss estimate and periodic checkpoints.
model.train()
tq = trange(0 , 1_000_000)
cleanup()

for i in tq:
    x, y = get_batch(batch_size)
    # get_batch() returns one single-sentence encoding per item: unwrap each, then pad to a rectangle.
    xpad = tokenizer.pad([{k: v[0] for k, v in item.items()} for item in x])
    ypad = tokenizer.pad([{k: v[0] for k, v in item.items()} for item in y])

    try:
        labels = torch.tensor(ypad['input_ids'], device=model.device)
        # Mask padding in the labels (-100 is the ignore index of the cross-entropy loss).
        labels[labels==tokenizer.pad_token_id] = -100
        loss = model(
            input_ids=torch.tensor(xpad['input_ids'], device=model.device),
            attention_mask=torch.tensor(xpad['attention_mask'], device=model.device),
            labels=labels,
        ).loss
        loss.backward()
        # Gradients accumulate across iterations. Step only once a full group of
        # `gradient_steps` micro-batches has been seen: `i % gradient_steps == 0` would
        # step at i == 0 after a single micro-batch and drop the trailing ones.
        if (i + 1) % gradient_steps == 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()
    except RuntimeError as e:
        # Deliberately best-effort (typically CUDA OOM on a long batch): report it,
        # discard the partially accumulated gradients, free memory and move on.
        print('error', i, e)
        loss = None
        optimizer.zero_grad(set_to_none=True)
        cleanup()
        continue

    # Running mean for the first `window` steps, exponential moving average afterwards.
    w = 1 / max(1, min(i, window))
    ewm_loss = ewm_loss * (1-w) + loss.item() * w
    tq.set_description(f'{ewm_loss:3.4f}')

    if i > 0 and i % report_steps == 0:
        print('step', i, 'loss', ewm_loss, )
        model.save_pretrained(save_path)
        tokenizer.save_pretrained(save_path)
    if i % cleanup_steps == 0:
        cleanup()

# The last periodic checkpoint is written before the final iterations, so persist the end state too.
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)

    

  0%|          | 0/1000000 [00:00<?, ?it/s]

step 1000 loss 7.8495250988006555
error 1151 CUDA out of memory. Tried to allocate 1002.00 MiB (GPU 0; 10.76 GiB total capacity; 8.36 GiB already allocated; 725.69 MiB free; 8.93 GiB reserved in total by PyTorch)
step 2000 loss 6.971980854110847
step 3000 loss 4.503658014299016
step 4000 loss 3.0397753827503657
step 5000 loss 2.328842712437607
error 5990 CUDA out of memory. Tried to allocate 854.00 MiB (GPU 0; 10.76 GiB total capacity; 8.12 GiB already allocated; 787.69 MiB free; 8.87 GiB reserved in total by PyTorch)
step 6000 loss 2.0246113876562606
step 7000 loss 1.8595813628118134
error 7090 CUDA out of memory. Tried to allocate 1.02 GiB (GPU 0; 10.76 GiB total capacity; 8.28 GiB already allocated; 525.69 MiB free; 9.12 GiB reserved in total by PyTorch)
error 7115 CUDA out of memory. Tried to allocate 814.00 MiB (GPU 0; 10.76 GiB total capacity; 9.40 GiB already allocated; 151.69 MiB free; 9.49 GiB reserved in total by PyTorch)
error 7338 CUDA out of memory. Tried to allocate 912.0

In [39]:
print(i, ewm_loss)
# step 141000 loss 0.9336458999525007
# 598943 0.7913156676625156
# 999999 0.7528545096685079

999999 1.0905373548772155


```

```

```
step 1000 loss 7.8495250988006555
step 2000 loss 6.971980854110847
step 3000 loss 4.503658014299016
step 4000 loss 3.0397753827503657
step 5000 loss 2.328842712437607
step 6000 loss 2.0246113876562606
step 7000 loss 1.8595813628118134
step 8000 loss 1.7656280765961248
step 9000 loss 1.7232480028436197
step 10000 loss 1.7234741253498884
step 11000 loss 1.6784266652217328
step 12000 loss 1.6633502409092902
step 13000 loss 1.6524237591194655
step 14000 loss 1.6367858148524403
step 15000 loss 1.621678100603032
step 16000 loss 1.6167106889640455
step 17000 loss 1.610022509102348
step 18000 loss 1.5817169323459432
step 19000 loss 1.5722098078269628
step 20000 loss 1.5679702238868234
step 21000 loss 1.5610694801590381
step 22000 loss 1.561103833062269
step 23000 loss 1.5497197078101173
step 24000 loss 1.5374166345673048
step 25000 loss 1.5412099212469161
step 26000 loss 1.5387520727699775
step 27000 loss 1.5450489197441235
step 28000 loss 1.5465633031182555
step 29000 loss 1.548109695684887
step 30000 loss 1.5392959037461993
step 31000 loss 1.5214723890115276
step 32000 loss 1.5197352038997345
step 33000 loss 1.5191886800053331
step 34000 loss 1.5071750278883136
step 35000 loss 1.4951174692222116
step 36000 loss 1.5110572869894487
step 37000 loss 1.5035447392609362
step 38000 loss 1.4937061133745018
step 39000 loss 1.5010554298354957
step 40000 loss 1.486516372689772
step 41000 loss 1.4813608206837234
step 42000 loss 1.4922767825661034
step 43000 loss 1.4789339267590702
step 44000 loss 1.4735573651956568
step 45000 loss 1.4621720642907539
step 46000 loss 1.4624380894905862
step 47000 loss 1.4647909544826903
step 48000 loss 1.4605912022122536
step 49000 loss 1.4638088939337457
step 50000 loss 1.4601541562835065
step 51000 loss 1.4568530646179267
step 52000 loss 1.4455715473337545
step 53000 loss 1.4385002881553164
step 54000 loss 1.4275163370868755
step 55000 loss 1.4280296323270911
step 56000 loss 1.415750276519965
step 57000 loss 1.4304115232860126
step 58000 loss 1.4350156803490897
step 59000 loss 1.4369355844267218
step 60000 loss 1.4283767094088695
step 61000 loss 1.429359077257854
step 62000 loss 1.4302132939200987
step 63000 loss 1.4428398229490416
step 64000 loss 1.4318759719514011
step 65000 loss 1.4255519478190595
step 66000 loss 1.4234853517876795
step 67000 loss 1.413439031855459
step 68000 loss 1.4148290818568556
step 69000 loss 1.405174908258505
step 70000 loss 1.4159819581703303
step 71000 loss 1.4142312050551875
step 72000 loss 1.4098882850239693
step 73000 loss 1.4093389173919988
step 74000 loss 1.3966641162948248
step 75000 loss 1.4124522088922455
step 76000 loss 1.4068154301855185
step 77000 loss 1.409533058917267
step 78000 loss 1.4075952773595297
step 79000 loss 1.400638904607081
step 80000 loss 1.4066685893286581
step 81000 loss 1.4057836900763392
step 82000 loss 1.385551296633067
step 83000 loss 1.3909922026470212
step 84000 loss 1.4009384998257617
step 85000 loss 1.4118246969220813
step 86000 loss 1.4109142263599457
step 87000 loss 1.402157676788603
step 88000 loss 1.40006970371994
step 89000 loss 1.390521758636719
step 90000 loss 1.3839341735154245
step 91000 loss 1.3880099119437543
step 92000 loss 1.3886318891892488
step 93000 loss 1.390096048908005
step 94000 loss 1.3868484540830661
step 95000 loss 1.3847126415979027
step 96000 loss 1.376475900571145
step 97000 loss 1.3697450944932394
step 98000 loss 1.3691839251840876
step 99000 loss 1.3860356547859314
step 100000 loss 1.3807216223047314
step 101000 loss 1.3963250578672481
step 102000 loss 1.381718494378525
step 103000 loss 1.3867386296175699
step 104000 loss 1.3747667828236734
step 105000 loss 1.3757010894743196
step 106000 loss 1.3627061974459742
step 107000 loss 1.3761756791113087
step 108000 loss 1.3648930810086939
step 109000 loss 1.375110728236904
step 110000 loss 1.375891402740335
step 111000 loss 1.3806493404146039
step 112000 loss 1.3648925215626757
step 113000 loss 1.3668712401155252
step 114000 loss 1.351714990858381
step 115000 loss 1.3580255847234042
step 116000 loss 1.3600642680345183
step 117000 loss 1.3513963500297228
step 118000 loss 1.3671170076103165
step 119000 loss 1.3609962514393041
step 120000 loss 1.3540291630676364
step 121000 loss 1.3638124138623196
step 122000 loss 1.3662583619871882
step 123000 loss 1.3456631593237443
step 124000 loss 1.3446384864127559
step 125000 loss 1.3523965853922022
step 126000 loss 1.3513094280967
step 127000 loss 1.3651944989943638
step 128000 loss 1.350642460967927
step 129000 loss 1.339702469783254
step 130000 loss 1.340435085508887
step 131000 loss 1.351269869766291
step 132000 loss 1.344327490579696
step 133000 loss 1.3535759422240654
step 134000 loss 1.358697595711307
step 135000 loss 1.3550793867665347
step 136000 loss 1.3556051740578334
step 137000 loss 1.3539432265586322
step 138000 loss 1.351645375457388
step 139000 loss 1.3329083108055126
step 140000 loss 1.328960430712879
step 141000 loss 1.3395015047754533
step 142000 loss 1.3376708943471236
step 143000 loss 1.3399167724838987
step 144000 loss 1.3391197170874636
step 145000 loss 1.3340307605203736
step 146000 loss 1.3435670633954624
step 147000 loss 1.3415995513449999
step 148000 loss 1.3384918063549696
step 149000 loss 1.350963800868836
step 150000 loss 1.3358103464593043
step 151000 loss 1.334069401766121
step 152000 loss 1.3442333004223217
step 153000 loss 1.3272806209559203
step 154000 loss 1.3324173407753037
step 155000 loss 1.3310733793326808
step 156000 loss 1.3299070366686063
step 157000 loss 1.3278618451018642
step 158000 loss 1.3411352952373696
step 159000 loss 1.3254056563856131
step 160000 loss 1.3227620278713996
step 161000 loss 1.3318189009700612
step 162000 loss 1.34123620795669
step 163000 loss 1.3384616679225427
step 164000 loss 1.3167609319689726
step 165000 loss 1.316144160657587
step 166000 loss 1.319476455271136
step 167000 loss 1.3287669541319378
step 168000 loss 1.320705268067573
step 169000 loss 1.3181178431390548
step 170000 loss 1.3223915783077766
step 171000 loss 1.3080522195015636
step 172000 loss 1.3111654091495746
step 173000 loss 1.3377228920739324
step 174000 loss 1.33661053732826
step 175000 loss 1.3441622881373327
step 176000 loss 1.329836377286269
step 177000 loss 1.3191058659188997
step 178000 loss 1.3219125257458657
step 179000 loss 1.3140343884660775
step 180000 loss 1.3227725058951971
step 181000 loss 1.3192630414527942
step 182000 loss 1.3155189349693552
step 183000 loss 1.314023183256946
step 184000 loss 1.3261253488463896
step 185000 loss 1.321916067652572
step 186000 loss 1.3146131114423225
step 187000 loss 1.3158746480019954
step 188000 loss 1.309435959482614
step 189000 loss 1.3132142503440258
step 190000 loss 1.3130569536308465
step 191000 loss 1.3133588763676434
step 192000 loss 1.3149334933265298
step 193000 loss 1.3100141813526676
step 194000 loss 1.3317128545010182
step 195000 loss 1.3224027818585606
step 196000 loss 1.309476686901512
step 197000 loss 1.310167286517061
step 198000 loss 1.3082282657392057
step 199000 loss 1.3063352371723151
step 200000 loss 1.3153543664657465
step 201000 loss 1.2990609985683863
step 202000 loss 1.3213456139707835
step 203000 loss 1.3117750872213911
step 204000 loss 1.3061624213358713
step 205000 loss 1.3067514004474508
step 206000 loss 1.298021090947502
step 207000 loss 1.3050726174984952
step 208000 loss 1.2940952094890747
step 209000 loss 1.3060495546278659
step 210000 loss 1.294749523387651
step 211000 loss 1.2909695474597749
step 212000 loss 1.3036224128995029
step 213000 loss 1.3004102189749753
step 214000 loss 1.2888341406709947
step 215000 loss 1.2888423265367135
step 216000 loss 1.293618114140688
step 217000 loss 1.286390335067559
step 218000 loss 1.305842937402965
step 219000 loss 1.3031337922162076
step 220000 loss 1.3037511910869009
step 221000 loss 1.2849795153547088
step 222000 loss 1.3000277144312373
step 223000 loss 1.2981221889389567
step 224000 loss 1.292322057781449
step 225000 loss 1.3024142567633405
step 226000 loss 1.2926525958200987
step 227000 loss 1.284270309031029
step 228000 loss 1.2844790143156553
step 229000 loss 1.2826334835374635
step 230000 loss 1.2884224988645745
...
999999 1.0905373548772155
```

In [42]:
model.eval();

In [43]:
tokenizer.src_lang = 'ru_RU'

In [44]:
tokenizer.src_lang = 'en_XX'

In [45]:
tokenizer('привет')

{'input_ids': [250004, 146038, 2], 'attention_mask': [1, 1, 1]}

In [46]:
tokenizer.convert_tokens_to_ids('en_XX')

250004

In [47]:
def paraphrase(
    text, model, tokenizer,
    n=None,
    max_length="auto",
    beams=5,
    src_lang='en_XX',
    tgt_lang='en_XX',
    **kwargs
):
    """Generate text from ``text`` with beam search, forcing ``tgt_lang`` as the first token.

    Covers both paraphrasing (``src_lang == tgt_lang``) and translation.

    Args:
        text: a single string or a list of strings.
        model: seq2seq model exposing ``generate`` and ``device``.
        tokenizer: tokenizer with settable ``src_lang`` / ``tgt_lang``.
        n: number of sequences to return per input (``None`` means one).
        max_length: maximum generated length, or ``"auto"`` for padded input length + 10.
        beams: number of beams.
        src_lang: language code of the input.
        tgt_lang: language code of the output.
        **kwargs: extra arguments forwarded to ``model.generate``; they override
            the defaults set here (e.g. ``repetition_penalty``).

    Returns:
        A single string when ``text`` is a string and ``n`` is not set;
        otherwise the list of all decoded sequences (``n`` per input, in input order).
    """
    tokenizer.src_lang = src_lang
    tokenizer.tgt_lang = tgt_lang
    texts = [text] if isinstance(text, str) else text
    if not texts:
        return []

    encoded = tokenizer(texts, return_tensors="pt", padding=True)
    input_ids = encoded["input_ids"].to(model.device)
    # Without the mask, padded positions of shorter inputs in a batch would be attended to.
    attention_mask = encoded["attention_mask"].to(model.device)

    if max_length == "auto":
        max_length = input_ids.shape[1] + 10

    generation_args = dict(
        attention_mask=attention_mask,
        num_return_sequences=n or 1,
        do_sample=False,
        temperature=1.0,
        repetition_penalty=10.0,
        max_length=max_length,
        min_length=int(0.5 * max_length),
        num_beams=beams,
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
    )
    generation_args.update(kwargs)

    result = model.generate(input_ids, **generation_args)
    outputs = [tokenizer.decode(r, skip_special_tokens=True) for r in result]

    if not n and isinstance(text, str):
        return outputs[0]
    return outputs

In [48]:
print(paraphrase('paraphrase: I like to play with my nice dog.', model, tokenizer))

I'd like to play with my nice dog.


In [49]:
print(paraphrase('translate: I like to play with my nice dog.', model, tokenizer, tgt_lang='ru_RU'))

Я люблю играть со своей миленькой собакой.


In [50]:
print(paraphrase('translate: I hate to play with my fucking dog.', model, tokenizer, tgt_lang='ru_RU'))

Я ненавижу играть со своей чёртовой собакой.


In [51]:
print(paraphrase('translate: Ненавижу играть со своей долбаной собакой.', model, tokenizer, src_lang='ru_RU'))

I hate playing with my motherfucking dog, man.


In [52]:
print(paraphrase('paraphrase: I hate to play with my fucking dog.', model, tokenizer, tgt_lang='en_XX'))

I don't want to play with my dog.


# Train a detoxifier adapter on top of the encoder

In [5]:
from transformers.modeling_outputs import BaseModelOutput


def paraphrase_adapt(
    text, model, tokenizer, adapter,
    n=None,
    max_length="auto",
    beams=5,
    src_lang='en',
    tgt_lang='en',
    **kwargs
):
    """Generate text from encoder states that were passed through ``adapter``.

    The input is encoded with ``model.model.encoder``, the last hidden state is
    transformed by ``adapter``, and decoding runs on the transformed states
    with ``tgt_lang`` forced as the first generated token.

    Args:
        text: a single string or a list of strings.
        model: seq2seq model exposing ``model.encoder``, ``generate`` and ``device``.
        tokenizer: tokenizer with settable ``src_lang`` / ``tgt_lang``.
        adapter: callable applied to the encoder's last hidden state.
        n: number of sequences to return per input (``None`` means one).
        max_length: maximum generated length, or ``"auto"`` for padded input length + 10.
        beams: number of beams.
        src_lang: input language; the short forms ``'en'`` / ``'ru'`` are
            expanded to the codes used elsewhere in this file (``'en_XX'`` / ``'ru_RU'``).
        tgt_lang: output language, same conventions as ``src_lang``.
        **kwargs: extra arguments forwarded to ``model.generate``; they override
            the defaults set here.

    Returns:
        A single string when ``text`` is a string and ``n`` is not set;
        otherwise the list of all decoded sequences (``n`` per input, in input order).
    """
    # The bare two-letter defaults are not language codes of this tokenizer; map them.
    short_codes = {'en': 'en_XX', 'ru': 'ru_RU'}
    src_lang = short_codes.get(src_lang, src_lang)
    tgt_lang = short_codes.get(tgt_lang, tgt_lang)

    tokenizer.src_lang = src_lang
    tokenizer.tgt_lang = tgt_lang
    texts = [text] if isinstance(text, str) else text
    if not texts:
        return []

    encoded_inputs = tokenizer(texts, return_tensors="pt", padding=True)
    input_ids = encoded_inputs["input_ids"].to(model.device)
    # The adapter training loop passes the attention mask to the encoder and the decoder;
    # do the same here so padded positions are ignored at inference time as well.
    attention_mask = encoded_inputs["attention_mask"].to(model.device)

    if max_length == "auto":
        max_length = input_ids.shape[1] + 10

    with torch.inference_mode():
        encoded = model.model.encoder(input_ids=input_ids, attention_mask=attention_mask)
        adapted = BaseModelOutput(last_hidden_state=adapter(encoded.last_hidden_state))

    generation_args = dict(
        encoder_outputs=adapted,
        attention_mask=attention_mask,
        num_return_sequences=n or 1,
        do_sample=False,
        temperature=1.0,
        repetition_penalty=1.0,
        max_length=max_length,
        num_beams=beams,
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
    )
    generation_args.update(kwargs)

    result = model.generate(**generation_args)
    outputs = [tokenizer.decode(r, skip_special_tokens=True) for r in result]

    if not n and isinstance(text, str):
        return outputs[0]
    return outputs

In [6]:
model_name = '/home/dale/models/detox-parallel/bart-multitask-pretrain-invariant'

In [7]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).cuda();

In [8]:
test_data = pd.read_csv('../data/russian_data/test.tsv', sep='\t')
test_inputs_ru = test_data["toxic_comment"].values.tolist()

In [9]:
with open('../data/english_data/test_toxic_parallel.txt', 'r') as f:
    test_inputs_en = [line.strip() for line in f.readlines()]

In [10]:
class Adapter(torch.nn.Module):
    """Residual feed-forward adapter: ``forward(x)`` returns ``x + delta(x)``.

    ``delta`` is Linear(model_dim -> ffn_dim), LeakyReLU, Linear(ffn_dim -> model_dim),
    so the output keeps the trailing dimension of the input.
    """

    def __init__(self, model_dim=1024, ffn_dim=4096):
        super().__init__()
        self.model_dim, self.ffn_dim = model_dim, ffn_dim
        expand = torch.nn.Linear(model_dim, ffn_dim, bias=True)
        activation = torch.nn.LeakyReLU()
        project = torch.nn.Linear(ffn_dim, model_dim, bias=True)
        self.delta = torch.nn.Sequential(expand, activation, project)

    def forward(self, x):
        correction = self.delta(x)
        return x + correction

## en->ru

In [10]:
detox_en2ru = pd.read_csv('detox_en2ru_yandex.tsv', sep='\t')

In [11]:
from sklearn.model_selection import train_test_split
# Hold out 10% of the rows; the fixed random_state makes the split repeatable.
# NOTE(review): `dev` is not used by the cells in this section.
train, dev = train_test_split(detox_en2ru, test_size=0.1, random_state=1)
print(train.shape, dev.shape)

(17789, 15) (1977, 15)


In [12]:
# Fresh adapter with the default sizes (1024 -> 4096 -> 1024), moved to the GPU.
adapter = Adapter()
adapter.cuda()

Adapter(
  (delta): Sequential(
    (0): Linear(in_features=1024, out_features=4096, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=4096, out_features=1024, bias=True)
  )
)

In [13]:
# Switch both modules to training mode before optimising the adapter.
model.train()
adapter.train()

Adapter(
  (delta): Sequential(
    (0): Linear(in_features=1024, out_features=4096, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=4096, out_features=1024, bias=True)
  )
)

In [14]:
# Only the adapter's parameters are handed to the optimizer, so only they are updated.
optimizer = torch.optim.Adam(adapter.parameters(), lr=5e-5)

In [15]:
# Train the adapter on English toxic -> neutral pairs.
#
# The pretrained model is frozen first. Without this, backward() also computes
# and stores a gradient for every parameter of `model`; those gradients are
# never used (the optimizer holds only the adapter's parameters) and never
# cleared (optimizer.zero_grad() only touches the adapter), so they only cost
# GPU memory and compute. The adapter's own gradients are unaffected.
model.requires_grad_(False)
model.zero_grad(set_to_none=True)  # drop gradients left by earlier backward passes

# Loop-invariant: both sides of each pair are tokenized with this language code.
tokenizer.src_lang = "en_XX"

losses = []
for i in trange(5_000):
    text_batch = train.sample(16)
    batch_in = tokenizer(text_batch.toxic_comment.tolist(), return_tensors="pt", padding=True).to(model.device)
    batch_out = tokenizer(text_batch.neutral_comment.tolist(), return_tensors="pt", padding=True).to(model.device)
    # Padding positions in the labels are replaced by -100, the label value
    # that the model's loss skips.
    batch_out.input_ids[batch_out.input_ids==tokenizer.pad_token_id] = -100

    # No graph is built for the encoder; gradients start at the adapter.
    with torch.no_grad():
        encoded = model.model.encoder(**batch_in)
    transformed = adapter(encoded.last_hidden_state)

    # The adapted hidden states are passed in place of the encoder's own output.
    total_out = model(
        encoder_outputs=[transformed],
        attention_mask=batch_in.attention_mask,
        decoder_attention_mask=batch_out.attention_mask,
        labels=batch_out.input_ids,
    )
    total_out.loss.backward()
    torch.nn.utils.clip_grad_norm_(adapter.parameters(), 1.0)
    optimizer.step()
    optimizer.zero_grad()
    losses.append(total_out.loss.item())
    if i % 100 == 0:
        # Mean loss over (at most) the last 100 steps.
        print(np.mean(losses[-100:]))

  0%|          | 0/5000 [00:00<?, ?it/s]

1.6552413702011108
1.0699604988098144
0.9457136896252633
0.9134821289777756
0.9283235603570938
0.9031834518909454
0.881840660572052
0.8903419297933578
0.8470703360438346
0.8436727476119995
0.8622586065530777
0.8464151296019554
0.7979262927174569
0.8425427156686783
0.8484166666865349
0.8185198873281478
0.8075836738944053
0.8029882007837296
0.8022268304228782
0.7951979848742485
0.8026005041599273
0.7974074131250382
0.7784970554709435
0.8007240790128708
0.7517123523354531
0.7912288221716881
0.7903843292593956
0.7914382806420326
0.7596189120411873
0.7278305983543396
0.7563984030485154
0.7332312130928039
0.763190213739872
0.7413530832529068
0.7073853594064713
0.7406839892268181
0.7398436322808266
0.7264939627051353
0.7382057321071624
0.7121048700809479
0.7396052488684655
0.7141054171323776
0.7169818970561027
0.6934740725159645
0.7164686095714569
0.6986848366260529
0.701336784362793
0.7051332432031632
0.6789515882730484
0.6860926705598831


In [16]:
# Switch both modules to evaluation mode before generating outputs.
model.eval();
adapter.eval();

In [17]:
# Qualitative check on one random row: print the toxic source, the reference
# rewrite, and the output generated through the adapter (5 beams).
# NOTE(review): the row is drawn from the full table, which includes the rows
# used for training, so this is not necessarily a held-out example.
row = detox_en2ru.sample(1).iloc[0]
text = row.toxic_comment
print(text)
print(row.neutral_comment)
print(paraphrase_adapt(text, model, tokenizer, adapter, src_lang='en_XX', tgt_lang='en_XX', beams=5))

awwww , thanks baba ! thank fuck it 's over '
awwww , thanks baba ! thank it 's over '
awwww, thanks baba! thank it's over


In [18]:
# Generate an output for every Russian test input, one sentence per call,
# using the adapter that was trained on the English pairs above.
test_outputs_ru = [paraphrase_adapt(text, model, tokenizer, adapter, src_lang='ru_RU', tgt_lang='ru_RU') for text in tqdm(test_inputs_ru)]

  0%|          | 0/1000 [00:00<?, ?it/s]

In [19]:
# Generate an output for every English test input, one sentence per call.
test_outputs_en = [paraphrase_adapt(text, model, tokenizer, adapter, src_lang='en_XX', tgt_lang='en_XX') for text in tqdm(test_inputs_en)]

  0%|          | 0/671 [00:00<?, ?it/s]

In [20]:
# Output directory for this adapter's generations (the trailing slash is
# relied on by the `path + filename` concatenations that follow).
path = '../results/enc_adapter_mbart_invariante_v1_en2ru/'
# makedirs also creates a missing `../results` parent, and exist_ok=True makes
# the cell safe to re-run without a separate existence check.
os.makedirs(path, exist_ok=True)

In [21]:
# Write the generations one per line, in the same order as the test inputs.
# UTF-8 is requested explicitly: the Russian outputs are non-ASCII, and the
# platform's default encoding is not guaranteed to be able to represent them.
with open(path + 'results_ru.txt', 'w', encoding='utf-8') as f:
    for line in test_outputs_ru:
        f.write(line + '\n')
with open(path + 'results_en.txt', 'w', encoding='utf-8') as f:
    for line in test_outputs_en:
        f.write(line + '\n')

In [22]:
# Persist only the adapter's weights (its state_dict), stored inside the base
# model's checkpoint directory.
torch.save(adapter.state_dict(), '/home/dale/models/detox-parallel/bart-multitask-pretrain-invariant/enc_adapter_mbart_invariante_v1_en2ru')

```
cd /home/dale/projects/multilingual_detox
python evaluate_ru.py \
    --result_filename scores \
    --input_dir results/enc_adapter_mbart_invariante_v1_en2ru\
    --output_dir results

Style accuracy:       0.6464164853096008
Meaning preservation: 0.863937497138977
Joint fluency:        -0.15687231719493866
Joint score:          -0.08820652961730957
Scores after calibration:
Style accuracy:       0.681774914264679
Meaning preservation: 0.7961680889129639
Joint fluency:        0.8195968866348267
Joint score:          0.4360015392303467
```

```
cd /home/dale/projects/paradetox2/evaluation_detox
python metric.py --inputs /home/dale/projects/multilingual_detox/data/english_data/test_toxic_parallel.txt \
    --preds /home/dale/projects/multilingual_detox/results/enc_adapter_mbart_invariante_v1_en2ru/results_en.txt \
    --cola_classifier_path /home/dale/models/cola_classifier_fairseq \
    --wieting_model_path /home/dale/models/wieting_similarity/sim.pt \
    --wieting_tokenizer_path /home/dale/models/wieting_similarity/sim.sp.30k.model \
    --batch_size 32
cat results.md
```

| Model | ACC | EMB_SIM | SIM | CharPPL | TokenPPL | FL | GM | J | BLEU |
| ----- | --- | ------- | --- | ------- | -------- | -- | -- | - | ---- |
results_en.txt|0.8644|0.8945|0.8631|6.6233|222.5190|0.8495|10.6144|0.6213|0.7151|

## ru -> en: adapter trained on Russian data, evaluated on the Russian and English test sets

In [11]:
# Tab-separated table for the Russian direction; it carries its own `split`
# column, whose value counts are displayed as the cell output.
detox_ru2en = pd.read_csv('detox_ru2en_yandex.tsv', sep='\t')
detox_ru2en.split.value_counts()

train    5058
dev      1000
test     1000
Name: split, dtype: int64

In [12]:
# Select the rows by the table's own `split` column instead of splitting at random.
train_ru = detox_ru2en[detox_ru2en["split"] == 'train']
dev_ru = detox_ru2en[detox_ru2en["split"] == 'dev']

In [13]:
# Start from a freshly initialised adapter for this direction.
adapter = Adapter()

In [14]:
# Move the new adapter to the GPU and put both modules in training mode.
adapter.cuda()
model.train()
adapter.train()

Adapter(
  (delta): Sequential(
    (0): Linear(in_features=1024, out_features=4096, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=4096, out_features=1024, bias=True)
  )
)

In [15]:
# New optimizer bound to the new adapter's parameters only.
optimizer = torch.optim.Adam(adapter.parameters(), lr=5e-5)

In [16]:
# Show the vocabulary ids of the two language-code tokens used in this notebook.
tokenizer.convert_tokens_to_ids('en_XX'), tokenizer.convert_tokens_to_ids('ru_RU')

(250004, 250021)

In [23]:
# Drop the references to the tensors left over from the training loop so they
# can be garbage-collected, then ask PyTorch to release its cached CUDA memory.
batch_in = batch_out = encoded = transformed = total_out = None
gc.collect()
torch.cuda.empty_cache()

In [25]:
# Train the adapter on Russian toxic -> neutral pairs.
#
# The pretrained model is frozen first. Without this, backward() also computes
# and stores a gradient for every parameter of `model`; those gradients are
# never used (the optimizer holds only the adapter's parameters) and never
# cleared (optimizer.zero_grad() only touches the adapter), so they only cost
# GPU memory and compute. The adapter's own gradients are unaffected.
model.requires_grad_(False)
model.zero_grad(set_to_none=True)  # drop gradients left by earlier backward passes

# Loop-invariant: both sides of each pair are tokenized with this language code.
tokenizer.src_lang = "ru_RU"

losses = []
for i in trange(5_000):
    text_batch = train_ru.sample(8)  # 16 seems too much for CUDA
    batch_in = tokenizer(text_batch.toxic_comment.tolist(), return_tensors="pt", padding=True).to(model.device)
    batch_out = tokenizer(text_batch.neutral_comment.tolist(), return_tensors="pt", padding=True).to(model.device)
    # Padding positions in the labels are replaced by -100, the label value
    # that the model's loss skips.
    batch_out.input_ids[batch_out.input_ids==tokenizer.pad_token_id] = -100

    # No graph is built for the encoder; gradients start at the adapter.
    with torch.no_grad():
        encoded = model.model.encoder(**batch_in)
    transformed = adapter(encoded.last_hidden_state)

    # The adapted hidden states are passed in place of the encoder's own output.
    total_out = model(
        encoder_outputs=[transformed],
        attention_mask=batch_in.attention_mask,
        decoder_attention_mask=batch_out.attention_mask,
        labels=batch_out.input_ids,
    )
    total_out.loss.backward()
    torch.nn.utils.clip_grad_norm_(adapter.parameters(), 1.0)
    optimizer.step()
    optimizer.zero_grad()
    losses.append(total_out.loss.item())
    if i % 100 == 0:
        # Mean loss over (at most) the last 100 steps.
        print(np.mean(losses[-100:]))

  0%|          | 0/5000 [00:00<?, ?it/s]

2.068622589111328
1.3378848874568938
1.2991722577810287
1.2289144092798232
1.2466943675279618
1.1212247076630593
1.1497117337584495
1.067726196050644
1.169718002974987
1.086239886879921
1.1233404409885406
1.071696414500475
1.025642173886299
1.0709281569719316
1.0384460037946701
1.002386423945427
0.9811897620558738
0.9852833244204521
0.9410369583964348
0.9337509196996688
0.9226903989911079
0.9151546737551689
0.8957649406790733
0.9339609095454215
0.8964924520254135
0.8888657385110855
0.8600492045283318
0.8435651037096977
0.8476184010505676
0.8520551937818527
0.8466779348254204
0.8357182237505912
0.8519056951999664
0.8565433916449546
0.8138781163096428
0.8407841888070107
0.8198639583587647
0.8384282004833221
0.8114404201507568
0.7829724657535553
0.8283802568912506
0.7674878004193306
0.7641333591938019
0.7590398514270782
0.7427831056714058
0.7247042085230351
0.7783362740278243
0.7672648370265961
0.7224531561136246
0.71433149933815


In [26]:
# Switch both modules to evaluation mode before generating outputs.
model.eval();
adapter.eval();

In [37]:
# Qualitative check on one random row: print the toxic source, the reference
# rewrite, and the output generated through the adapter (5 beams).
# NOTE(review): the row is drawn from the full table (train, dev and test
# splits), so it may be an example the adapter was trained on.
row = detox_ru2en.sample(1).iloc[0]
text = row.toxic_comment
print(text)
print(row.neutral_comment)
print(paraphrase_adapt(text, model, tokenizer, adapter, src_lang='ru_RU', tgt_lang='ru_RU', beams=5))

ну что поверили этим педираторам всё это фуфло как и они сами
Ну что поверили им? Все это обманчива как и они сами
Ну что поверили этим людям, всё это обман, как и они сами


In [38]:
# Generate an output for every Russian test input, one sentence per call.
test_outputs_ru = [paraphrase_adapt(text, model, tokenizer, adapter, src_lang='ru_RU', tgt_lang='ru_RU') for text in tqdm(test_inputs_ru)]

  0%|          | 0/1000 [00:00<?, ?it/s]

In [39]:
# Generate an output for every English test input, one sentence per call,
# using the adapter that was trained on the Russian pairs above.
test_outputs_en = [paraphrase_adapt(text, model, tokenizer, adapter, src_lang='en_XX', tgt_lang='en_XX') for text in tqdm(test_inputs_en)]

  0%|          | 0/671 [00:00<?, ?it/s]

In [40]:
# Output directory for this adapter's generations (the trailing slash is
# relied on by the `path + filename` concatenations that follow).
path = '../results/enc_adapter_mbart_invariante_v1_ru2en/'
# makedirs also creates a missing `../results` parent, and exist_ok=True makes
# the cell safe to re-run without a separate existence check.
os.makedirs(path, exist_ok=True)

In [41]:
# Write the generations one per line, in the same order as the test inputs.
# UTF-8 is requested explicitly: the Russian outputs are non-ASCII, and the
# platform's default encoding is not guaranteed to be able to represent them.
with open(path + 'results_ru.txt', 'w', encoding='utf-8') as f:
    for line in test_outputs_ru:
        f.write(line + '\n')
with open(path + 'results_en.txt', 'w', encoding='utf-8') as f:
    for line in test_outputs_en:
        f.write(line + '\n')

In [42]:
# Persist only the adapter's weights (its state_dict), stored inside the base
# model's checkpoint directory.
torch.save(adapter.state_dict(), '/home/dale/models/detox-parallel/bart-multitask-pretrain-invariant/enc_adapter_mbart_invariante_v1_ru2en')

```
cd /home/dale/projects/multilingual_detox
python evaluate_ru.py \
    --result_filename scores \
    --input_dir results/enc_adapter_mbart_invariante_v1_ru2en\
    --output_dir results

Style accuracy:       0.6849309206008911
Meaning preservation: 0.8704279661178589
Joint fluency:        -0.1665312647819519
Joint score:          -0.0964374989271164
Scores after calibration:
Style accuracy:       0.7164378762245178
Meaning preservation: 0.8056419491767883
Joint fluency:        0.8084890842437744
Joint score:          0.4693962335586548
```

```
cd /home/dale/projects/paradetox2/evaluation_detox
python metric.py --inputs /home/dale/projects/multilingual_detox/data/english_data/test_toxic_parallel.txt \
    --preds /home/dale/projects/multilingual_detox/results/enc_adapter_mbart_invariante_v1_ru2en/results_en.txt \
    --cola_classifier_path /home/dale/models/cola_classifier_fairseq \
    --wieting_model_path /home/dale/models/wieting_similarity/sim.pt \
    --wieting_tokenizer_path /home/dale/models/wieting_similarity/sim.sp.30k.model \
    --batch_size 32
cat results.md
```

| Model | ACC | EMB_SIM | SIM | CharPPL | TokenPPL | FL | GM | J | BLEU |
| ----- | --- | ------- | --- | ------- | -------- | -- | -- | - | ---- |
results_en.txt|0.7004|0.8163|0.8151|36.4955|202.3833|0.8614|0.0000|0.4526|0.6083|