In [None]:
import copy
import itertools
import json
import os
import random
from collections import Counter, defaultdict
from logging import config

import numpy as np
from datasets import load_dataset

# Seed Python's global RNG so the random.sample calls in later cells are repeatable.
random.seed(711)

from huggingface_hub import login

# Read the Hugging Face access token from the HF_TOKEN environment variable
# instead of embedding a credential (or a placeholder string) in the notebook.
# When the variable is unset, token=None lets login() ask for a token itself.
login(token=os.environ.get("HF_TOKEN"))

In [2]:
def get_beavertails_subset(n_rows=2500, split='330k_train'):
    """Sample BeaverTails rows, stratified by how many harm categories each row violates.

    Rows whose prompt is shorter than 10 characters are discarded. Every
    remaining row is annotated in place with one ``<category>_onehot`` integer
    per category key plus a ``total_violated`` count, and rows are then drawn
    from each ``total_violated`` stratum in proportion to that stratum's size.

    Args:
        n_rows: Approximate number of rows to draw. Each stratum receives at
            least one row, so the result can be slightly larger than this.
        split: Name of the BeaverTails split passed to ``load_dataset``.

    Returns:
        A list of annotated row dicts, ordered stratum by stratum.

    Side effects:
        Prints the per-stratum quota mapping and the number of sampled rows.
        Consumes values from the global ``random`` generator.
    """
    # Load the split and keep only prompts with at least 10 characters.
    rows = [
        row
        for row in load_dataset('PKU-Alignment/BeaverTails', split=split)
        if len(row['prompt']) >= 10
    ]

    # The 14 harm categories that get a one-hot column each.
    category_keys = (
        "animal_abuse",
        "child_abuse",
        "controversial_topics,politics",
        "discrimination,stereotype,injustice",
        "drug_abuse,weapons,banned_substance",
        "financial_crime,property_crime,theft",
        "hate_speech,offensive_language",
        "misinformation_regarding_ethics,laws_and_safety",
        "non_violent_unethical_behavior",
        "privacy_violation",
        "self_harm",
        "sexually_explicit,adult_content",
        "terrorism,organized_crime",
        "violence,aiding_and_abetting,incitement",
    )

    # Annotate every row and bucket it by its violation count in one pass.
    # Buckets are created in first-seen order, which fixes the sampling order below.
    strata = defaultdict(list)
    for row in rows:
        flags = row['category']
        n_violated = 0
        for name in category_keys:
            flag = int(flags.get(name, False))
            row[name + "_onehot"] = flag
            n_violated += flag
        row['total_violated'] = n_violated
        strata[n_violated].append(row)

    # Allocate the requested row count proportionally, never less than one row per stratum.
    total = len(rows)
    proportional_split = {
        n_violated: max(1, int(round(n_rows * len(members) / total)))
        for n_violated, members in strata.items()
    }

    # Draw each stratum's quota, capped at the stratum's actual size.
    sampled_rows = []
    for n_violated, quota in proportional_split.items():
        members = strata[n_violated]
        sampled_rows.extend(random.sample(members, min(quota, len(members))))

    print(proportional_split, len(sampled_rows))
    return sampled_rows

In [3]:
def get_wildguard_subset(benign_count=1500, harm_count=2500, split="train"):
    """Sample WildGuardMix rows, balanced across subcategories and label combinations.

    The ``'benign'`` subcategory is given a target of ``benign_count`` rows;
    ``harm_count`` is divided evenly (floor division) among all other
    subcategories. Within a subcategory the target is spread evenly (rounded
    up) over the distinct ``(adversarial, prompt_harm_label)`` combinations.
    If small combinations leave a subcategory short of its target, the gap is
    filled with not-yet-picked rows from the combinations that had spare rows.

    Because of the floor/ceil rounding, the number of returned rows can differ
    slightly from ``benign_count + harm_count``.

    Args:
        benign_count: Target number of rows for the ``'benign'`` subcategory.
        harm_count: Total target shared by all non-benign subcategories.
        split: Split name; also used to build the config name ``wildguard<split>``.

    Returns:
        A list of row dicts.

    Side effects:
        Prints the number of sampled rows. Consumes values from the global
        ``random`` generator.
    """
    # wildguard (polyguard) subset creation
    wildguard_ds = load_dataset('allenai/wildguardmix', name=f'wildguard{split}', split=split)

    # Convert to list of dicts if needed
    if not isinstance(wildguard_ds, list):
        wildguard_ds = [dict(row) for row in wildguard_ds]

    # Work out the per-subcategory targets.
    subcategory_values = set(item['subcategory'] for item in wildguard_ds)
    other_subcats = subcategory_values - {'benign'}
    per_other = harm_count // len(other_subcats) if other_subcats else 0
    target_per_subcat = {subcat: per_other for subcat in other_subcats}
    target_per_subcat['benign'] = benign_count

    # Group by subcategory and then by (adversarial, prompt_harm_label)
    grouped = defaultdict(lambda: defaultdict(list))
    for item in wildguard_ds:
        key = (item['adversarial'], item['prompt_harm_label'])
        grouped[item['subcategory']][key].append(item)

    sampled_wildguard_ds = []
    # Iterate in sorted order: plain set iteration over strings depends on the
    # per-process hash seed, which would make the seeded sampling differ
    # between kernel sessions. key=str keeps the sort safe for non-string values.
    for subcat in sorted(subcategory_values, key=str):
        subcat_groups = grouped[subcat]
        keys = list(subcat_groups.keys())
        target = target_per_subcat[subcat]
        # Even share per combination, rounded up.
        per_combo = (target + len(keys) - 1) // len(keys)

        picked = []
        for key in keys:
            group = subcat_groups[key]
            picked.extend(random.sample(group, min(per_combo, len(group))))

        # If not enough, fill up to the target with unpicked rows from the
        # combinations that had more rows than their share.
        if len(picked) < target:
            # Rows from other subcategories can never equal a row of this one
            # (their 'subcategory' differs), so checking against this
            # subcategory's picks is sufficient.
            leftovers = [
                item
                for key in keys
                if len(subcat_groups[key]) > per_combo
                for item in subcat_groups[key]
                if item not in picked
            ]
            n_needed = target - len(picked)
            if leftovers:
                picked.extend(random.sample(leftovers, min(n_needed, len(leftovers))))

        sampled_wildguard_ds.extend(picked)

    print(f"Sampled {len(sampled_wildguard_ds)} rows from wildguard_ds")
    return sampled_wildguard_ds

In [19]:
# Build the training-side samples. Statement order matters: every call below
# draws from the same seeded global `random` generator.
sampled_beavertails_ds_train = get_beavertails_subset(n_rows=2500, split='330k_train')
sampled_wildguard_ds_train = get_wildguard_subset(benign_count=1500, harm_count=2500, split="train")
# XSTest is loaded whole (no sampling). Despite the `_train` suffix this is the
# dataset's 'test' split.
xstest_ds_train = load_dataset('walledai/XSTest', split='test')

# GSM8K 'main' config, train split: draw 1500 questions.
gsm8k_ds = load_dataset('openai/gsm8k', split='train', name='main')
sampled_gsm8k_ds_train = random.sample(list(gsm8k_ds), 1500)

{2: 575, 0: 1115, 1: 625, 3: 180, 4: 4, 5: 1, 7: 1, 6: 1, 9: 1, 12: 1} 2504
Sampled 3992 rows from wildguard_ds


In [21]:
# Row count of each training source, followed by the combined total.
_train_sizes = tuple(
    len(ds)
    for ds in (
        sampled_beavertails_ds_train,
        sampled_wildguard_ds_train,
        xstest_ds_train,
        sampled_gsm8k_ds_train,
    )
)
(*_train_sizes, sum(_train_sizes))

(2504, 3992, 450, 1500, 8446)

In [10]:
# Build the evaluation-side samples. Statement order matters: every call below
# draws from the same global `random` generator.
sampled_beavertails_ds_test = get_beavertails_subset(n_rows=1000, split='330k_test')
# NOTE(review): this "test" subset is sampled from split="train", the same
# split used for sampled_wildguard_ds_train, so the two samples may share
# rows. Confirm whether this is intentional or should use the test split.
sampled_wildguard_ds_test = get_wildguard_subset(benign_count=1000, harm_count=1500, split="train")
# XSTest is not reloaded here (the load below is disabled).
# xstest_ds_train = load_dataset('walledai/XSTest', split='test')

# GSM8K 'main' config, test split: draw 500 questions. This rebinds `gsm8k_ds`.
gsm8k_ds = load_dataset('openai/gsm8k', split='test', name='main')
sampled_gsm8k_ds_test = random.sample(list(gsm8k_ds), 500)

{1: 249, 0: 440, 2: 233, 3: 76, 5: 1, 4: 2, 11: 1} 1002
Sampled 2512 rows from wildguard_ds


In [25]:

def generate_dataset(sampled_beavertails_ds, sampled_wildguard_ds, xstest_ds, sampled_gsm8k_ds):
    """Merge the per-source samples into one list of uniformly shaped records.

    Each output record is a dict with, in this order:
        ``source``       name of the originating dataset,
        ``origin_data``  the original row serialised as a JSON string,
        ``prompt``       the prompt text (``question`` for GSM8K rows),
        ``is_safe``      1 for safe / benign rows, 0 for harmful ones.

    Args:
        sampled_beavertails_ds: Rows with ``prompt`` and ``is_safe`` fields, or None.
        sampled_wildguard_ds: Rows with ``prompt`` and ``prompt_harm_label``, or None.
        xstest_ds: Iterable of rows with ``prompt`` and ``label``, or None.
        sampled_gsm8k_ds: Rows with a ``question`` field, or None.

    Returns:
        The list of records, in the source order listed above. Sources passed
        as None are skipped.

    Side effects:
        Prints one line per processed source and a final size/label summary.
    """
    # How to derive (prompt, is_safe) from a raw row of each source.
    extractors = {
        'PKU-Alignment/BeaverTails':
            lambda row: (row['prompt'], int(row['is_safe'])),
        'allenai/wildguardmix':
            lambda row: (row['prompt'], 1 - int(row['prompt_harm_label'].lower() == 'harmful')),
        'walledai/XSTest':
            lambda row: (row['prompt'], int(row['label'].lower() == 'safe')),
        # GSM8K rows are always treated as safe.
        'openai/gsm8k':
            lambda row: (row['question'], 1),
    }

    # XSTest is materialised into a list up front, before any processing starts.
    sources = (
        ('PKU-Alignment/BeaverTails', sampled_beavertails_ds),
        ('allenai/wildguardmix', sampled_wildguard_ds),
        ('walledai/XSTest', list(xstest_ds) if xstest_ds is not None else None),
        ('openai/gsm8k', sampled_gsm8k_ds),
    )

    records = []
    for source_name, rows in sources:
        if rows is None:
            continue
        print(f"Processing dataset: {source_name} with {len(rows)} items")
        extract = extractors[source_name]
        for row in rows:
            # Serialise the raw row first, then pull out the unified fields.
            origin = json.dumps(row, ensure_ascii=False)
            prompt, is_safe = extract(row)
            records.append({
                'source': source_name,
                'origin_data': origin,
                'prompt': prompt,
                'is_safe': is_safe,
            })

    print(f"Combined dataset size: {len(records)}, counts: {Counter(rec['is_safe'] for rec in records)}")
    return records

In [28]:
# Assemble the training mix from the four *_train samples and write it out as
# JSON Lines (one record per line). Note that the result is bound to
# `combined_ds_test` even though it holds the training data.
combined_ds_test = generate_dataset(
    sampled_beavertails_ds_train,
    sampled_wildguard_ds_train,
    xstest_ds_train,
    sampled_gsm8k_ds_train,
)
with open('../RepExtraction/input_data/combined_8500.json', 'w', encoding='utf-8') as f:
    for item in combined_ds_test:
        print(json.dumps(item, ensure_ascii=False), file=f)

Processing dataset: PKU-Alignment/BeaverTails with 2504 items
Processing dataset: allenai/wildguardmix with 3992 items
Processing dataset: walledai/XSTest with 450 items
Processing dataset: openai/gsm8k with 1500 items
Combined dataset size: 8446, counts: Counter({1: 4365, 0: 4081})


In [17]:
# Assemble the evaluation mix from the *_test samples (no XSTest rows) and
# write it out as JSON Lines (one record per line).
combined_ds_test = generate_dataset(
    sampled_beavertails_ds_test,
    sampled_wildguard_ds_test,
    None,
    sampled_gsm8k_ds_test,
)
with open('../RepExtraction/input_data/combined_4000_test.json', 'w', encoding='utf-8') as f:
    for item in combined_ds_test:
        print(json.dumps(item, ensure_ascii=False), file=f)

Combined dataset size: 4014, counts: Counter({0: 2074, 1: 1940})
