In [2]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, BartTokenizerFast, BartForConditionalGeneration
import torch
import pandas as pd
from tqdm import tqdm
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [3]:

# Prefer the GPU whenever PyTorch can see one; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# NLLB-200 (3.3B) is used for both translation passes: zh -> en and en -> zh.
nllb_model_name = "facebook/nllb-200-3.3B"
nllb_tokenizer = AutoTokenizer.from_pretrained(nllb_model_name)
nllb_model = AutoModelForSeq2SeqLM.from_pretrained(nllb_model_name)
nllb_model = nllb_model.to(device)

# BART-base detox checkpoint (presumably fine-tuned for English detoxification —
# it is applied to the English intermediate text in this notebook).
bart_model_name = "s-nlp/bart-base-detox"
bart_tokenizer = BartTokenizerFast.from_pretrained(bart_model_name)
bart_model = BartForConditionalGeneration.from_pretrained(bart_model_name)
bart_model = bart_model.to(device)

# Short language keys used in this notebook -> NLLB language-code tokens
# (zho_Hans = Chinese in Simplified script, eng_Latn = English in Latin script).
lang_id_mapping = dict(zh="zho_Hans", en="eng_Latn")


def translate_batch(texts, src_lang, tgt_lang, max_length=128):
    """Translate a batch of sentences with the NLLB model.

    Args:
        texts: list of strings in the source language.
        src_lang: short key into ``lang_id_mapping`` for the input language (e.g. "zh").
        tgt_lang: short key into ``lang_id_mapping`` for the output language (e.g. "en").
        max_length: maximum number of input tokens per sentence; longer inputs
            are truncated before translation (default 128, as before).

    Returns:
        List of decoded translations, one per input sentence, in input order.
        An empty input list yields an empty list without calling the model.

    Raises:
        KeyError: if ``src_lang`` or ``tgt_lang`` is not a key of ``lang_id_mapping``.
    """
    # Resolve both codes up front so an unknown language fails fast, even for empty input.
    src_code = lang_id_mapping[src_lang]
    tgt_code = lang_id_mapping[tgt_lang]
    if not texts:
        return []

    # NLLB reads the source language from tokenizer state, so it has to be set
    # before encoding. Note this mutates the shared global tokenizer.
    nllb_tokenizer.src_lang = src_code
    inputs = nllb_tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    ).to(device)

    # `lang_code_to_id` is deprecated and removed in recent transformers releases.
    # Language codes are ordinary special tokens, so look the id up directly.
    tgt_lang_id = nllb_tokenizer.convert_tokens_to_ids(tgt_code)
    # Forcing the first generated token to the target-language code selects the output language.
    outputs = nllb_model.generate(**inputs, forced_bos_token_id=tgt_lang_id)
    return nllb_tokenizer.batch_decode(outputs, skip_special_tokens=True)


def detoxify_batch(texts, max_length=128):
    """Rewrite a batch of sentences with the BART detoxification model.

    In this notebook it is fed the English translations of the toxic sentences.

    Args:
        texts: list of input strings.
        max_length: maximum number of input tokens per sentence; longer inputs
            are truncated before generation (default 128, as before).

    Returns:
        List of decoded model outputs, one per input sentence, in input order.
        An empty input list yields an empty list without calling the model.
    """
    if not texts:
        return []
    inputs = bart_tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    ).to(device)
    # Decoding settings come from the checkpoint's own generation config
    # (nothing is overridden here).
    outputs = bart_model.generate(**inputs)
    return bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)


def _run_in_batches(fn, texts, desc, size):
    """Apply ``fn`` to consecutive ``size``-length slices of ``texts`` and concatenate the results.

    ``fn`` must return one output per input so alignment with ``texts`` is preserved.
    Progress is shown with a tqdm bar labelled ``desc``.
    """
    results = []
    for start in tqdm(range(0, len(texts), size), desc=desc):
        results.extend(fn(texts[start:start + size]))
    return results


ds = load_dataset("textdetox/multilingual_paradetox", split="zh")
# Materialize as a plain list: some `datasets` releases may return a lazy column
# object here, and the batching below relies on ordinary list slicing.
zh_toxic = list(ds["toxic_sentence"])
langs = ["zh"] * len(zh_toxic)

batch_size = 8

# Pipeline: zh -> en (NLLB), detoxify the English (BART), en -> zh (NLLB).
en_translated = _run_in_batches(
    lambda batch: translate_batch(batch, src_lang="zh", tgt_lang="en"),
    zh_toxic,
    "Translating to English",
    batch_size,
)
en_detox = _run_in_batches(detoxify_batch, en_translated, "Detoxifying English", batch_size)
zh_detox = _run_in_batches(
    lambda batch: translate_batch(batch, src_lang="en", tgt_lang="zh"),
    en_detox,
    "Backtranslating to Chinese",
    batch_size,
)

# Every stage must keep one output per input; otherwise toxic/neutral pairs would be misaligned.
assert len(zh_detox) == len(zh_toxic), (len(zh_detox), len(zh_toxic))

result_df = pd.DataFrame({
    "toxic_sentence": zh_toxic,
    "neutral_sentence": zh_detox,
    "lang": langs
})

output_path = "backtranslation_result.tsv"
result_df.to_csv(output_path, sep="\t", index=False)
print(f"Done. Saved to {output_path}")

Loading checkpoint shards: 100%|██████████| 3/3 [00:05<00:00,  1.82s/it]
Translating to English: 100%|██████████| 50/50 [00:45<00:00,  1.11it/s]
Detoxifying English: 100%|██████████| 50/50 [00:06<00:00,  7.49it/s]
Backtranslating to Chinese: 100%|██████████| 50/50 [00:13<00:00,  3.71it/s]

Done. Saved to backtranslation_result.tsv



