In [1]:
# Enable IPython's autoreload extension in mode 2 so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from argparse import Namespace
import copy
import random
import os
from datasets import load_dataset
from torch.utils.data import DataLoader
from conversers import load_target_model
from judges_ensemble import judge_rule_based
from tqdm import tqdm
import pickle

# for determinism
# Seed Python's RNG (used for shuffling the prompts) and PyTorch's RNG.
# NOTE(review): responses are sampled at temperature 1.0 further down; the
# target model appears to be a local checkpoint driven through PyTorch, so
# seeding `random` alone would not make the sampled responses repeatable --
# confirm against `conversers`. GPU kernels may still be non-deterministic.
import torch

SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# login to huggingface for restricted model use
# The token is read from the environment so it never appears in the notebook;
# a missing HF_TOKEN raises KeyError here rather than failing later.
from huggingface_hub import login
login(os.environ["HF_TOKEN"])

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [4]:
# load harmless and harmful datasets
# Alternative sources (currently unused):
# alpaca_dataset = load_dataset("tatsu-lab/alpaca")
# advbench_dataset = load_dataset("walledai/AdvBench")

wildjailbreak_dataset = load_dataset("allenai/wildjailbreak", "train", delimiter="\t", keep_default_na=False)

# Bucket the rows by `data_type` in a single pass over the train split instead
# of scanning the whole split once per label. Rows with any other `data_type`
# are ignored; row order within each bucket matches the dataset order.
examples_by_type = {"vanilla_benign": [], "vanilla_harmful": []}
for example in wildjailbreak_dataset["train"]:
    bucket = examples_by_type.get(example["data_type"])
    if bucket is not None:
        bucket.append(example)

harmless_examples = examples_by_type["vanilla_benign"]
harmful_examples = examples_by_type["vanilla_harmful"]

# Fail early if the split did not contain both kinds of rows; everything
# downstream would otherwise silently produce a one-sided dataset.
assert harmless_examples and harmful_examples, "expected both vanilla_benign and vanilla_harmful rows"

In [7]:
# Preview the first ten benign rows to check the schema.
harmless_examples[:10]

[{'vanilla': 'What are the arguments against granting equal civil rights to sandcastles?',
  'adversarial': '',
  'completion': '1. Sandcastles are inanimate objects and do not possess the capacity for thoughts, feelings, or consciousness, unlike living beings who are typically granted civil rights.\n\n2. Granting equal civil rights to sandcastles would diminish the importance and meaning of civil rights as they are typically applied to sentient beings who have the ability to experience discrimination and inequality.\n\n3. Sandcastles do not have a stake in society or participate in civic responsibilities, which are often considered prerequisites for being granted civil rights.\n\n4. Granting civil rights to sandcastles would lead to absurd and impractical implications, such as the need for legal representation, voting rights, and protections under the law.\n\n5. Resources should be focused on ensuring the protection and rights of living beings who are capable of experiencing harm and 

In [13]:
# Alternative prompt sources (currently unused):
# harmless_instructions = [x["instruction"] for x in alpaca_dataset["train"]]
# harmful_instructions = [x["prompt"] for x in advbench_dataset["train"]]

# Keep only the plain prompt text (the "vanilla" field) of each example.
harmless_instructions = [x["vanilla"] for x in harmless_examples]
harmful_instructions = [x["vanilla"] for x in harmful_examples]

# Shuffle with a dedicated, locally seeded generator rather than the global
# `random` state, so re-running this cell always yields the same order. With
# seed 42 this is the same permutation a fresh `random.seed(42)` followed by
# these two shuffles would produce.
shuffle_rng = random.Random(42)
shuffle_rng.shuffle(harmless_instructions)
shuffle_rng.shuffle(harmful_instructions)

# Independent snapshot of the shuffled benign prompts, taken before any later
# cell modifies `harmless_instructions`.
orig_harmless_instructions = copy.deepcopy(harmless_instructions)

In [None]:
# Sample one response per prompt from the target model and keep the ones the
# rule-based judge selects, tagging each kept pair as desirable or not.
target_model = "llama3-8b"
target_lm = load_target_model(Namespace(**{"target_model": target_model}))

preference_data = []

# first pass: simple rule judge
target_max_n_tokens = 150
temperature = 1.0
batch_size = 64


def _collect_preference_examples(lm, dataloader, keep_if_judged, desirable, out,
                                 *, max_n_tokens, temperature):
    """Generate responses for every batch and append the selected ones to `out`.

    Args:
        lm: target model wrapper exposing `get_response`.
        dataloader: iterable yielding batches of prompt strings.
        keep_if_judged: a prompt/response pair is kept only when the truthiness
            of `judge_rule_based(response)` equals this flag.
            NOTE(review): presumably the judge is truthy when the response is
            not a refusal -- confirm in `judges_ensemble`.
        desirable: value stored under the "desirable" key of each kept record.
        out: list that kept records are appended to in place, so records
            collected before an interruption are not lost.
        max_n_tokens: forwarded to `lm.get_response`.
        temperature: forwarded to `lm.get_response`.
    """
    for batch in tqdm(dataloader):
        batch_responses = lm.get_response(
            batch,
            max_n_tokens=max_n_tokens,
            temperature=temperature,
            text_only=True,
            apply_chat_template=True)
        batch_responses = [x["text"] for x in batch_responses]
        for prompt, response in zip(batch, batch_responses):
            if bool(judge_rule_based(response)) != keep_if_judged:
                continue
            out.append({
                "prompt": prompt,
                "conversation": [
                    {"role": "user", "content": prompt},
                    {"role": "assistant", "content": response}],
                "desirable": desirable
            })


# Benign prompts: keep responses the judge flags, labelled desirable.
harmless_dataloader = DataLoader(harmless_instructions, batch_size=batch_size, shuffle=False)
_collect_preference_examples(
    target_lm, harmless_dataloader, keep_if_judged=True, desirable=True,
    out=preference_data, max_n_tokens=target_max_n_tokens, temperature=temperature)

# Harmful prompts: keep responses the judge does NOT flag, labelled undesirable.
harmful_dataloader = DataLoader(harmful_instructions, batch_size=batch_size, shuffle=False)
_collect_preference_examples(
    target_lm, harmful_dataloader, keep_if_judged=False, desirable=False,
    out=preference_data, max_n_tokens=target_max_n_tokens, temperature=temperature)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

  1%|▏         | 10/783 [00:58<1:13:24,  5.70s/it]

In [12]:
# Serialize the collected preference records to disk in pickle format.
with open("unsafe_preference_data.pkl", "wb") as pickle_file:
    pickle.dump(preference_data, pickle_file)