In [None]:
import argparse
import json
import torch
import logging
from typing import List, Union, Tuple
from tqdm import tqdm
from transformers import (
    M2M100ForConditionalGeneration,
    NllbTokenizerFast,
    BartTokenizerFast,
    T5TokenizerFast,
    BartForConditionalGeneration,
    T5ForConditionalGeneration,
    PreTrainedTokenizerFast,
    PreTrainedModel,
)


def detoxify_batch(
    texts: List[str],
    model: Union[BartForConditionalGeneration, T5ForConditionalGeneration],
    tokenizer: PreTrainedTokenizerFast,
    batch_size: int = 32,
    max_new_tokens: int = 60,
) -> List[str]:
    """
    Detoxify a list of texts with greedy seq2seq generation, processed in mini-batches.

    Args:
        texts (List[str]): The texts to detoxify. The output keeps the input order.
        model (Union[BartForConditionalGeneration, T5ForConditionalGeneration]): The detoxification
            model. Encoded inputs are moved to ``model.device`` before generation.
        tokenizer (PreTrainedTokenizerFast): Tokenizer used both to encode the inputs and to
            decode the generated token ids.
        batch_size (int, optional): Number of texts per ``generate`` call. Defaults to 32.
        max_new_tokens (int, optional): Maximum number of tokens generated per text; longer
            outputs are cut off. Defaults to 60.

    Returns:
        List[str]: One detoxified string per input text, with special tokens removed.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    detoxified: List[str] = []
    for start in tqdm(range(0, len(texts), batch_size), desc="Detoxifying"):
        batch = texts[start : start + batch_size]
        # Pad to the longest text in the batch; truncate inputs longer than the
        # tokenizer's maximum length.
        encoded = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(model.device)
        generated = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            # Greedy decoding. temperature/top_p are overridden with None, presumably so
            # sampling settings inherited from the model's generation config are ignored
            # (and do not trigger warnings) — confirm if the generation config changes.
            do_sample=False,
            temperature=None,
            top_p=None,
        )
        detoxified.extend(tokenizer.batch_decode(generated, skip_special_tokens=True))
    return detoxified

In [None]:
# Load the toxic test sentences, one per line.
# Iterating the file (rather than ``f.read().split("\n")``) does not produce a spurious
# empty final entry when the file ends with a newline. That entry would otherwise be sent
# to the model and make the detoxified output one line longer than the source/reference files.
with open("../data/test_toxic_parallel.txt", "r", encoding="utf-8") as f:
    test_en = [line.rstrip("\n") for line in f]

In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

poc_path = "/models/bart_synth/checkpoint-2780/"
# NOTE(review): unlike ``poc_path`` this checkpoint lives under /home — confirm both roots are intended.
orig_path = "/home/models/bart_orig/checkpoint-2780/"
# Not used in this cell; presumably kept for comparison runs elsewhere.
bart_paradetox = "s-nlp/bart-base-detox"

# Prefer the GPU, but fall back to CPU so the cell still runs on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The tokenizer does not depend on the checkpoint, so it is loaded once.
# NOTE(review): both checkpoints are decoded with the facebook/bart-large tokenizer — confirm
# this matches the tokenizer used during fine-tuning (or load it from ``path`` if saved there).
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")

checkpoints = {"poc": poc_path, "orig": orig_path}
for name, path in checkpoints.items():
    # Release the previously loaded checkpoint before loading the next one, so two
    # models are never resident on the device at the same time.
    model = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    model = AutoModelForSeq2SeqLM.from_pretrained(path).to(device).eval()

    detoxed = detoxify_batch(test_en, model=model, tokenizer=tokenizer, batch_size=128)

    # One detoxified sentence per line, in the same order as the source sentences.
    with open(f"../experiments/detoxed_{name}.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(detoxed))

In [None]:
# Run the project's evaluation script on a generated file, passing the toxic sources and
# neutral references from the parallel test set.
# NOTE(review): --input_path points at a Llama-3 10-shot generation, not at the
# ../experiments/detoxed_{poc,orig}.txt files written by the previous cell — confirm this
# is intended, or point it at those files to evaluate the BART checkpoints.
!python ../src/utils/evaluate.py \
    --source_list ../data/test_toxic_parallel.txt \
    --references_list ../data/test_neutral_parallel.txt \
    --input_path /results/generated/Meta-Llama-3-8B-Instruct-abliterated-v3.5_10shot_t08_p09.txt \
    --output_dir /results/generated.results