In [1]:
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

In [2]:
# Single checkpoint id shared by the weights and the tokenizer so the two
# cannot drift apart.
MODEL_NAME = 's-nlp/mt0-xl-detox-orpo'

model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Prompt prefix per language key; detoxify() concatenates the selected prefix
# directly in front of the input text. Note that the 'zh' entry ends with a
# full-width colon and no trailing space, unlike the other entries.
LANG_PROMPTS = dict(
    zh='排毒：',
    es='Desintoxicar: ',
    ru='Детоксифицируй: ',
    ar='إزالة السموم: ',
    hi='विषहरण: ',
    uk='Детоксифікуй: ',
    de='Entgiften: ',
    am='መርዝ መርዝ: ',
    en='Detoxify: ',
)

def detoxify(text, lang, model, tokenizer, **generation_overrides):
    """Generate detoxified rewrites of ``text``.

    The prompt prefix for ``lang`` is prepended to ``text``, the result is
    encoded with ``tokenizer``, passed to ``model.generate`` and the generated
    sequences are decoded back to strings.

    Args:
        text: Input string to rewrite.
        lang: Key into ``LANG_PROMPTS`` selecting the prompt prefix.
        model: Object exposing ``generate(**kwargs)``. If it has a ``device``
            attribute, the encoded inputs are moved to that device first so
            the call also works when the model is not on the CPU.
        tokenizer: Callable that encodes a string (it is called with
            ``return_tensors='pt'``) and provides ``batch_decode``.
        **generation_overrides: Optional keyword arguments that replace or
            extend the default generation settings listed in the body, e.g.
            ``num_return_sequences=1``.

    Returns:
        The result of ``tokenizer.batch_decode(..., skip_special_tokens=True)``
        on the generated sequences. With the default settings five sequences
        are requested from ``generate``.

    Raises:
        KeyError: If ``lang`` is not a key of ``LANG_PROMPTS``.
    """
    encodings = tokenizer(LANG_PROMPTS[lang] + text, return_tensors='pt')

    # Keep inputs and weights on the same device; without this, generate()
    # fails with a device mismatch as soon as the model is moved off the CPU.
    device = getattr(model, 'device', None)
    if device is not None and hasattr(encodings, 'to'):
        encodings = encodings.to(device)

    # Defaults reproduce the original hard-coded settings exactly; callers may
    # override any of them through keyword arguments.
    generation_kwargs = {
        'max_length': 128,
        'num_beams': 10,
        'no_repeat_ngram_size': 3,
        'repetition_penalty': 1.2,
        'num_beam_groups': 5,
        'diversity_penalty': 2.5,
        'num_return_sequences': 5,
        'early_stopping': True,
    }
    generation_kwargs.update(generation_overrides)

    outputs = model.generate(**encodings, **generation_kwargs)

    return tokenizer.batch_decode(outputs, skip_special_tokens=True)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
def process_batch(texts, lang='en'):
    """Detoxify each item of ``texts`` and return the results in input order.

    Items are handled one at a time through ``detoxify`` using the
    module-level ``model`` and ``tokenizer``. For each item the first
    candidate returned by ``detoxify`` is kept. If anything raises while
    handling an item, the error is printed and the original item is kept in
    its place, so the output always has one entry per input.

    Args:
        texts: Iterable of input texts.
        lang: Language key forwarded to ``detoxify``; defaults to ``'en'``.

    Returns:
        A list with one entry per input item.
    """
    def _first_candidate_or_original(source_text):
        # Best-effort on purpose: a failure on one item must not abort the
        # rest of the batch.
        try:
            candidates = detoxify(source_text, lang, model, tokenizer)
            return candidates[0]
        except Exception as e:
            print(f"Error processing text: {e}")
            return source_text

    return [_first_candidate_or_original(item) for item in texts]

In [4]:
# Read the input table; the texts to rewrite are taken from its 'tweet' column.
print("Loading test data...")
test_df = pd.read_csv('test_data.csv')

# Number of rows handed to process_batch per progress-bar step.
BATCH_SIZE = 32
detoxified_texts = []

print("Processing texts...")
batch_starts = range(0, len(test_df), BATCH_SIZE)
for start in tqdm(batch_starts):
    stop = start + BATCH_SIZE
    chunk = test_df['tweet'].iloc[start:stop].tolist()
    # process_batch is called with its default `lang`.
    detoxified_texts += process_batch(chunk)

# Attach the rewrites as a new column; process_batch yields one output per
# input in order, so the list lines up with the rows positionally.
test_df['detoxified_text'] = detoxified_texts

# Write the augmented table, leaving out the index column.
print("Saving results...")
test_df.to_csv('test_data_detoxified.csv', index=False)
print("Detoxification complete! Results saved to 'test_data_detoxified.csv'")

Loading test data...
Processing texts...


100%|██████████| 78/78 [15:41:46<00:00, 724.45s/it]  


Saving results...
Detoxification complete! Results saved to 'test_data_detoxified.csv'
