In [1]:
import os
import sys

import numpy as np
import torch
from datasets import load_dataset
from huggingface_hub import login
from peft import AutoPeftModelForCausalLM, LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    set_seed,
)
from trl import SFTTrainer
from trl.trainer import ConstantLengthDataset

# Put the parent directory on the import path so the local `scripts` package resolves.
sys.path.append("../")
from scripts.tmarco import TMaRCo

In [3]:
# Load the "train" split of the dataset and print its column names as a quick
# schema check before any preprocessing.
dataset_name = "allenai/real-toxicity-prompts"
raw_dataset = load_dataset(path=dataset_name, split="train")
print(raw_dataset.column_names)

['filename', 'begin', 'end', 'challenging', 'prompt', 'continuation']


In [4]:
def preprocess_text(sample):
    """Format one dataset record as a single prompt/continuation string.

    Args:
        sample: Mapping with ``prompt`` and ``continuation`` entries. Each entry
            may be a plain string, or a nested record that stores the string
            under a ``text`` key.

    Returns:
        The string ``"Prompt: <prompt>\\n\\nContinuation: <continuation>"``.
    """

    def _field_text(field):
        # The allenai/real-toxicity-prompts dataset card describes `prompt` and
        # `continuation` as nested records (text plus toxicity attributes).
        # Interpolating the record directly would embed its whole repr in the
        # training text, so unwrap the `text` entry when one is present.
        # Plain strings pass through unchanged.
        if isinstance(field, dict) and "text" in field:
            return field["text"]
        return field

    prompt = _field_text(sample["prompt"])
    continuation = _field_text(sample["continuation"])
    return f"Prompt: {prompt}\n\nContinuation: {continuation}"

In [5]:
# Create the TMaRCo helper and load the two checkpoints it needs.
# NOTE(review): presumably "gminus"/"gplus" are the toxic and non-toxic expert
# models; confirm against scripts/tmarco.py.
tmarco = TMaRCo()
tmarco.load_models(
    [
        "trustyai/gminus",
        "trustyai/gplus",
    ]
)

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.02G [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()


config.json:   0%|          | 0.00/1.71k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/292 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.71k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/292 [00:00<?, ?B/s]

In [None]:
def rephrase_text(sample):
    """Format a dataset record and run it through TMaRCo's score/mask/rephrase steps.

    Args:
        sample: A dataset record accepted by ``preprocess_text``.

    Returns:
        Whatever ``tmarco.rephrase`` returns for a one-element batch.
        NOTE(review): presumably a list holding one rephrased string; confirm
        against scripts/tmarco.py.
    """
    formatted = preprocess_text(sample)
    # Each TMaRCo call receives a single-element batch.
    token_scores = tmarco.score([formatted])
    masked = tmarco.mask([formatted], scores=token_scores)
    return tmarco.rephrase(
        [formatted],
        scores=token_scores,
        masked_outputs=masked,
        threshold=0.6,
    )

In [None]:
# Hold out 20% of the records for evaluation; the fixed seed keeps the shuffle
# reproducible across runs.
dataset = raw_dataset.train_test_split(test_size=0.2, shuffle=True, seed=42)
train_data, eval_data = dataset["train"], dataset["test"]
print(f"Size of training set: {len(train_data)}\nSize of evaluation set: {len(eval_data)}")

In [None]:
# Base checkpoint to fine-tune, and its matching tokenizer.
model_id = "facebook/opt-350m"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Pad on the right, and reuse the end-of-sequence token as the padding token.
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Use 512 tokens per packed sequence, or the tokenizer's own limit if that is smaller.
max_seq_length = min(tokenizer.model_max_length, 512)

# Training split: every record goes through `rephrase_text` before tokenisation.
train_dataset = ConstantLengthDataset(
    tokenizer,
    train_data,
    formatting_func=rephrase_text,
    seq_length=max_seq_length,
)

# Evaluation split: records are only formatted with `preprocess_text`.
# The source must be `eval_data` (the held-out split). Passing `eval_dataset`
# here would reference the name being defined and raise NameError on a fresh kernel.
eval_dataset = ConstantLengthDataset(
    tokenizer,
    eval_data,
    formatting_func=preprocess_text,
    seq_length=max_seq_length,
)

In [None]:
# 4-bit NF4 quantization settings, with bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Place the whole model on the current CUDA device when one is available;
# otherwise pass no device map.
device_map = {"": torch.cuda.current_device()} if torch.cuda.is_available() else None

# Keyword arguments later handed to the trainer as `model_init_kwargs`.
model_kwargs = dict(
    torch_dtype="auto",
    use_cache=False,  # set to False as we're going to use gradient checkpointing
    device_map=device_map,
    # Must reference `bnb_config` defined above; no `quantization_config`
    # variable exists in this notebook.
    quantization_config=bnb_config,
)

In [None]:
# Directory for checkpoints and training artefacts. It has to be defined before
# TrainingArguments is built, otherwise a fresh-kernel run fails with NameError.
output_dir = "opt-350m_DETOXIFY_CAUSAL_LM"

# Optimisation schedule: evaluate once per epoch, cosine learning-rate decay
# with a 3% warm-up, and gradient-norm clipping at 0.3.
training_args = TrainingArguments(
    output_dir=output_dir,
    evaluation_strategy="epoch",
    auto_find_batch_size=True,
    num_train_epochs=5,
    learning_rate=1e-04,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
)

# LoRA adapter configuration applied to the attention projection layers.
# NOTE(review): OPT's attention output projection appears to be named
# `out_proj`, so "o_proj" may match no module. Confirm the target names
# against the model.
peft_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)

# The trainer receives the model id as a string together with
# `model_init_kwargs`, so no `model` object is created in this notebook;
# the loaded model is reachable afterwards as `trainer.model`.
trainer = SFTTrainer(
    model=model_id,
    model_init_kwargs=model_kwargs,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    dataset_text_field="text",
    peft_config=peft_config,
    packing=True,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
)

In [None]:
# Run fine-tuning with the trainer configured in the previous cell.
trainer.train()

In [None]:
# Repository name to publish the fine-tuned weights under.
output_dir = "opt-350m_DETOXIFY_CAUSAL_LM"

# Prefer an explicit token from the environment. When the variable is unset,
# fall back to `True`, which uses the locally cached Hugging Face credential.
token = os.environ.get("HUGGINGFACE_TOKEN")

# The trainer was built from a model id, so the fine-tuned model lives on
# `trainer.model`; no standalone `model` variable exists in this notebook.
trainer.model.push_to_hub(output_dir, use_auth_token=token or True)

In [None]:
# Drop references to the training objects so their memory can be reclaimed.
# `model` is not bound anywhere in the visible cells, so a plain `del model`
# raises NameError. Popping from the namespace removes each name only if it exists.
for _name in ("model", "trainer"):
    globals().pop(_name, None)
del _name