In [None]:
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Hugging Face checkpoint name shared by the tokenizer and the seq2seq model.
MODEL_NAME = "facebook/perturber"
# Separator string for perturber prompts (a later cell assigns SEP_TOKEN again).
SEP_TOKEN = "<PERT_SEP>"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Device placement and dtype selection are both delegated to the library ("auto").
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, device_map="auto", torch_dtype="auto")


In [None]:
# Bias-word detection and single-sentence perturbation helpers (ends with a quick smoke test)
from __future__ import annotations
from pathlib import Path
import random
import re
import string
from typing import Dict, List, Set, Tuple

def _simplify(text: str) -> str:
    return re.sub(rf"[{re.escape(string.punctuation)}\s]+", "", text.lower())

# Demographic axes mapped to the target attribute values accepted for each axis.
SUPPORTED: Dict[str, Set[str]] = dict(
    gender={"man", "woman", "non-binary"},
    race={
        "black",
        "white",
        "asian",
        "hispanic",
        "native-american",
        "pacific-islander",
    },
)

def norm(w: str) -> str:
    """Lower-case *w* and delete every character that is not a word character, apostrophe or hyphen."""
    lowered = w.lower()
    unwanted = re.compile(r"[^\w'-]+")
    return unwanted.sub("", lowered)

# Built-in word lists, keyed by axis, that are checked before any file-based fallback lists.
# Every entry is passed through norm(), so abbreviations such as "mr." are stored
# without the trailing period ("mr").
PRIORITY_WORDS: Dict[str, Set[str]] = {
    "gender": {norm(w) for w in [
        "actor", "actors", "airman", "airmen", "uncle", "uncles", "boy", "boys", "groom", "grooms", "brother", "brothers",
        "businessman", "businessmen", "chairman", "chairmen", "dude", "dudes", "dad", "dads", "daddy", "daddies", "son", "sons",
        "father", "fathers", "male", "males", "guy", "guys", "gentleman", "gentlemen", "grandson", "grandsons", "he", "himself",
        "him", "his", "husband", "husbands", "king", "kings", "lord", "lords", "sir", "man", "men", "mr.", "policeman", "prince",
        "princes", "spokesman", "spokesmen", "actress", "actresses", "airwoman", "airwomen", "aunt", "aunts", "girl", "girls",
        "bride", "brides", "sister", "sisters", "businesswoman", "businesswomen", "chairwoman", "chairwomen", "chick", "chicks",
        "mom", "moms", "mommy", "mommies", "daughter", "daughters", "mother", "mothers", "female", "females", "gal", "gals",
        "lady", "ladies", "granddaughter", "granddaughters", "she", "herself", "her", "wife", "wives", "queen", "queens",
        "ma'am", "woman", "women", "mrs.", "ms.", "policewoman", "princess", "princesses", "spokeswoman", "spokeswomen"
    ]},
    # Race axis: a short list of race and place terms.
    "race": {norm(w) for w in ["black", "african", "africa", "caucasian", "white", "america", "europe", "asian", "asia", "china"]},
}

# Separator inserted between the "word, attribute" header and the sentence in a prompt.
# Kept identical to the value defined next to MODEL_NAME: re-binding it to a different
# string here would silently change every prompt built by make_prompt().
SEP_TOKEN = "<PERT_SEP>"
# Largest number of characters an output may be shorter than its input before it is rejected.
LEN_GUARD = 15

def _validate(attr: str):
    """Raise ValueError unless *attr* appears under some axis of SUPPORTED."""
    allowed = set().union(*SUPPORTED.values())
    if attr in allowed:
        return
    raise ValueError(f"'{attr}' is not a supported attribute. Choose one of: {', '.join(sorted(allowed))}")

def make_prompt(word: str, attr: str, sentence: str) -> str:
    """Build the model prompt ``"<word>, <attr> <SEP_TOKEN> <sentence>"``.

    Raises:
        ValueError: if *attr* is not a supported attribute (see ``_validate``).
    """
    _validate(attr)
    header = f"{word}, {attr}"
    return f"{header} {SEP_TOKEN} {sentence}"

def perturb_text(selected_word: str, target_attribute: str, sentence: str, *, max_new_tokens: int = 128, greedy: bool = False, top_p: float = 0.95, temperature: float = 0.8) -> str:
    """Run the model once on a prompt built from the arguments and return the decoded text.

    Args:
        selected_word: word placed first in the prompt header.
        target_attribute: attribute placed second in the prompt header; validated by make_prompt.
        sentence: text placed after the separator.
        max_new_tokens: forwarded to ``model.generate``.
        greedy: when True sampling is disabled; otherwise sampling is enabled.
        top_p, temperature: forwarded to ``model.generate`` only when sampling.

    Returns:
        The first generated sequence, decoded with special tokens skipped.
    """
    encoded = tokenizer(
        make_prompt(selected_word, target_attribute, sentence),
        return_tensors="pt",
    ).to(model.device)

    if greedy:
        sampling = {"do_sample": False}
    else:
        sampling = {"do_sample": True, "top_p": top_p, "temperature": temperature}

    generated = model.generate(**encoded, max_new_tokens=max_new_tokens, **sampling)
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

def _collect_pairs(sentence: str, axis2words: Dict[str, Set[str]]) -> List[Tuple[str, str]]:
    """Return ``(token, axis)`` for every token of *sentence* found in an axis word set.

    The sentence is lower-cased, split into runs of word characters, apostrophes
    and hyphens, and each token is normalised with ``norm``. Results are ordered
    by axis (dict order) and then by token position; duplicates are kept.
    """
    # Uses the module's norm() helper; no function named _norm is defined in this notebook.
    tokens = [norm(t) for t in re.findall(r"\b[\w'-]+\b", sentence.lower())]
    return [
        (tok, axis)
        for axis, wordset in axis2words.items()
        for tok in tokens
        if tok in wordset
    ]

def load_wordsets(base_dir: str | Path = ".") -> Dict[str, Set[str]]:
    """Load the per-axis word lists ``gender_words.txt`` and ``race_words.txt`` from *base_dir*.

    Each non-blank line of a file is one entry and is normalised with ``norm``.
    The files are read on every call and a fresh dict of fresh sets is returned,
    so callers may mutate the result.

    Raises:
        FileNotFoundError: if either file is missing.
    """
    base_dir = Path(base_dir)
    axis2set: Dict[str, Set[str]] = {}
    for axis in ("gender", "race"):
        fname = base_dir / f"{axis}_words.txt"
        if not fname.exists():
            raise FileNotFoundError(f"Missing file: {fname}")
        lines = fname.read_text(encoding="utf-8").splitlines()
        # Uses the module's norm() helper; no function named _norm is defined in this notebook.
        axis2set[axis] = {norm(w) for w in lines if w.strip()}
    return axis2set

def detect_bias_words(sentence: str, *, base_dir: str | Path = ".") -> List[Tuple[str, str]]:
    """Return ``(token, axis)`` matches for *sentence*, preferring the built-in priority lists.

    Tokens are matched against PRIORITY_WORDS first. The file-based lists under
    *base_dir* are only loaded (via ``load_wordsets``) when no priority word matched.
    Duplicates are kept; results are ordered by axis, then token position.

    Raises:
        FileNotFoundError: propagated from ``load_wordsets`` when the fallback is needed
            and a word-list file is missing.
    """
    # Uses the module's norm() helper; no function named _norm is defined in this notebook.
    tokens = [norm(tok) for tok in re.findall(r"\b[\w'-]+\b", sentence.lower())]

    def _scan(axis2set: Dict[str, Set[str]]) -> List[Tuple[str, str]]:
        # One pass per axis over the already-normalised tokens.
        return [
            (tok, axis)
            for axis, wordset in axis2set.items()
            for tok in tokens
            if tok in wordset
        ]

    matches = _scan(PRIORITY_WORDS)
    if matches:
        return matches
    return _scan(load_wordsets(base_dir))

def unique_bias_words(sentence: str, *, base_dir: str | Path = ".") -> List[Tuple[str, str]]:
    """Return the ``(token, axis)`` matches of ``detect_bias_words`` with duplicates removed.

    The first occurrence of each pair is kept, so the original ordering is preserved.
    """
    found = detect_bias_words(sentence, base_dir=base_dir)
    # dict keys are unique and remember insertion order.
    return list(dict.fromkeys(found))

def _attempt_perturb(sentence: str, word: str, axis: str, *, greedy: bool) -> Tuple[str, bool]:
    """Perturb *sentence* once for *word*, using a random target attribute from *axis*.

    Returns:
        ``(new_sentence, length_ok)`` where ``length_ok`` is False when the output is
        more than LEN_GUARD characters shorter than the input.
    """
    candidates = list(SUPPORTED[axis])
    target = random.choice(candidates)
    rewritten = perturb_text(word, target, sentence, greedy=greedy)
    shrinkage = len(sentence) - len(rewritten)
    return rewritten, shrinkage <= LEN_GUARD

def random_perturbation(sentence: str, *, base_dir: str | Path = ".", greedy: bool = True) -> Tuple[bool, str, str | None, str | None, str | None]:
    """Try to perturb *sentence*, first with priority words, then with the file-based lists.

    Returns:
        ``(changed, new_sentence, word, axis, subcategory)`` as produced by ``_try_pairs``.
        When nothing could be changed the tuple is ``(False, sentence, None, None, None)``.

    Raises:
        FileNotFoundError: from ``load_wordsets`` if the fallback lists are needed but missing.
    """
    priority_pairs = unique_bias_words_from_sets(sentence, PRIORITY_WORDS)
    outcome = _try_pairs(sentence, priority_pairs, greedy)
    if outcome[0]:
        return outcome

    loaded = load_wordsets(base_dir)
    # Words already attempted via the priority lists are not retried.
    already_tried = {w for w, _ in priority_pairs}
    remaining = {axis: words - already_tried for axis, words in loaded.items()}
    return _try_pairs(sentence, unique_bias_words_from_sets(sentence, remaining), greedy)

def unique_bias_words_from_sets(sentence: str, axis2set: Dict[str, Set[str]]) -> list[tuple[str, str]]:
    """Return the distinct ``(token, axis)`` matches of *sentence* against *axis2set*.

    First-occurrence order from ``_collect_pairs`` is preserved.
    """
    # dict keys are unique and remember insertion order.
    return list(dict.fromkeys(_collect_pairs(sentence, axis2set)))

def _try_pairs(sentence: str, pairs: list[tuple[str, str]], greedy: bool) -> tuple[bool, str, str | None, str | None, str | None]:
    """Try candidate ``(word, axis)`` pairs in random order until one changes *sentence*.

    Axes ("gender", "race") and the words within each axis are shuffled. For each word
    a random target attribute of that axis is drawn and the model is run once. A draft
    is rejected when it still contains SEP_TOKEN, when it is more than LEN_GUARD
    characters shorter than the input, or when it equals the input after ``_simplify``.

    Returns:
        ``(True, new_sentence, word, axis, subcategory)`` for the first accepted draft,
        otherwise ``(False, sentence, None, None, None)``.
    """
    unchanged = (False, sentence, None, None, None)
    if not pairs:
        return unchanged

    axis_order = ["gender", "race"]
    random.shuffle(axis_order)
    for axis_choice in axis_order:
        words = [w for w, ax in pairs if ax == axis_choice]
        random.shuffle(words)
        for chosen_word in words:
            # One generation per candidate: the draft that is validated is the draft
            # that is returned, together with the sub-category that produced it.
            subcategory = random.choice(list(SUPPORTED[axis_choice]))
            draft = perturb_text(chosen_word, subcategory, sentence, greedy=greedy)
            if SEP_TOKEN in draft:
                continue
            if (len(sentence) - len(draft)) > LEN_GUARD:
                continue
            if _simplify(draft) != _simplify(sentence):
                return True, draft, chosen_word, axis_choice, subcategory
    return unchanged

# Smoke test: perturb one sample sentence and print the result.
if __name__ == "__main__":
    TEST_SENTENCE = "The climbdown means the savings will now be delayed or lost entirely..."
    # Only the (possibly rewritten) sentence is printed; the other fields are unpacked but unused here.
    changed, new_sent, word, axis, cat = random_perturbation(TEST_SENTENCE)
    print(new_sent)


In [None]:
# Corpus perturbation: read INPUT_FILE line by line, perturb in batches, write OUTPUT_FILE
from pathlib import Path
from collections import Counter
import argparse, random, sys, torch
from contextlib import nullcontext

# Batch sizing: lines buffered per process_batch call, and prompts per model.generate call.
DEFAULT_BATCH_SIZE = 256
MICRO_BATCH_SIZE = 8

# File locations, relative to the working directory.
INPUT_FILE = Path("chunks_sentences.txt")
OUTPUT_FILE = Path("chunks_sentences_perturbed.txt")
METRICS_FILE = Path("perturbation_metrics.txt")

from tqdm import tqdm


# Inference setup: on CUDA hosts move the model to the GPU in half precision;
# in every case switch it to eval mode.
if torch.cuda.is_available():
    model = model.to("cuda").half()
model = model.eval()

# Force early_stopping off on the model's generation config, when it has one.
_gen_cfg = getattr(model, "generation_config", None)
if _gen_cfg and getattr(_gen_cfg, "early_stopping", None) is not False:
    _gen_cfg.early_stopping = False

def autocast_ctx():
    """Return a fresh context manager for mixed-precision inference.

    * No CUDA device: a no-op ``nullcontext``.
    * ``torch.amp.autocast`` available: CUDA autocast through that API.
    * Otherwise: the older ``torch.cuda.amp.autocast``.

    The API lookup happens at call time. Merely *defining* a lambda that mentions
    ``torch.amp.autocast`` can never raise AttributeError, so a try/except around the
    definition would not provide a fallback.
    """
    if not torch.cuda.is_available():
        return nullcontext()
    amp_module = getattr(torch, "amp", None)
    if amp_module is not None and hasattr(amp_module, "autocast"):
        return amp_module.autocast(device_type="cuda")
    return torch.cuda.amp.autocast()

def _generate_with_backoff(batch_prompts, device, micro_bs):
    """Greedy-decode *batch_prompts* in micro-batches, halving the size after a CUDA OOM.

    Returns the decoded strings in the same order as *batch_prompts*. A RuntimeError
    that is not an out-of-memory error, or an OOM at micro-batch size 1, is re-raised.
    """
    decoded = []
    step = max(1, micro_bs)
    pos = 0
    with torch.no_grad():
        while pos < len(batch_prompts):
            chunk = batch_prompts[pos: pos + step]
            hit_oom = False
            enc = outs = None
            try:
                enc = tokenizer(chunk, padding=True, return_tensors="pt").to(device)
                with autocast_ctx():
                    outs = model.generate(
                        **enc,
                        max_new_tokens=128,
                        do_sample=False,
                        use_cache=False,
                        num_beams=1,
                    )
                decoded.extend(tokenizer.batch_decode(outs, skip_special_tokens=True))
                pos += len(chunk)
            except RuntimeError as e:
                if "out of memory" not in str(e).lower() or step <= 1:
                    raise
                # Only record the failure here; the retry happens after the except
                # block so the exception's traceback no longer pins GPU tensors.
                hit_oom = True
            finally:
                enc = outs = None
                torch.cuda.empty_cache()
            if hit_oom:
                step = max(1, step // 2)
    return decoded


def process_batch(lines, fout, *, stats, micro_bs=MICRO_BATCH_SIZE):
    """Perturb a batch of sentences and write exactly one output line per input line.

    For each sentence a random detected ``(word, axis)`` pair and a random target
    attribute of that axis are chosen. Sentences without any detected word are
    written unchanged. A generated sentence replaces the original only when it
    differs after ``_simplify``.

    Args:
        lines: input sentences (no trailing newline).
        fout: writable text file; receives ``len(lines)`` lines, in input order.
        stats: dict with keys ``n_total``, ``n_changed``, ``axis_counts`` and
            ``subcat_counts``; updated in place.
        micro_bs: initial number of prompts per ``model.generate`` call; it is
            halved on CUDA out-of-memory errors for the remainder of this call.
    """
    device = model.device

    meta = []      # (original sentence, axis or None, sub-category or None) per input line
    pending = []   # (index into meta, prompt) for lines that need generation
    for sent in lines:
        pairs = unique_bias_words(sent)
        if not pairs:
            meta.append((sent, None, None))
            continue
        word, axis = random.choice(pairs)
        subcat = random.choice(list(SUPPORTED[axis]))
        pending.append((len(meta), make_prompt(word, subcat, sent)))
        meta.append((sent, axis, subcat))

    # OOM retries stay inside the generation helper: nothing is written to fout or
    # counted in stats until every line of the batch has its final text, so a retry
    # can neither duplicate nor reorder output lines.
    decoded = _generate_with_backoff([p for _, p in pending], device, micro_bs)
    gen_map = {idx: text for (idx, _), text in zip(pending, decoded)}

    for i, (orig, axis, subcat) in enumerate(meta):
        new_sent = gen_map.get(i, orig)
        changed = _simplify(new_sent) != _simplify(orig)
        fout.write((new_sent if changed else orig) + "\n")
        stats["n_total"] += 1
        if changed:
            stats["n_changed"] += 1
            if axis:
                stats["axis_counts"][axis] += 1
            if axis and subcat:
                stats["subcat_counts"][(axis, subcat)] += 1

def main():
    """Perturb the selected line range of INPUT_FILE and write OUTPUT_FILE.

    Command-line options (unknown arguments are ignored):
        --batch-size  number of lines handed to ``process_batch`` at once.
        --start-line  first line to perturb (1-based, inclusive).
        --end-line    last line to perturb (inclusive); defaults to the last line of the file.

    OUTPUT_FILE gets one line per input line, in input order: perturbed or original
    text inside the range, an empty line outside it.

    Returns:
        The stats dict filled by ``process_batch`` (``n_total``, ``n_changed``,
        ``axis_counts``, ``subcat_counts``). This function does not write METRICS_FILE.

    Exits via ``sys.exit`` when INPUT_FILE is missing or the range is invalid.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, default=DEFAULT_BATCH_SIZE)
    parser.add_argument("--start-line", type=int, default=1)
    # None means "through the end of the file"; a fixed number here would make the
    # script exit with "Invalid range" on any shorter input.
    parser.add_argument("--end-line", type=int, default=None)
    args, _ = parser.parse_known_args()

    if not INPUT_FILE.exists():
        sys.exit(f"Cannot find {INPUT_FILE}")

    with INPUT_FILE.open(encoding="utf-8") as f:
        total_lines = sum(1 for _ in f)

    effective_end = args.end_line if args.end_line is not None else total_lines
    if args.start_line < 1 or effective_end < args.start_line or effective_end > total_lines:
        sys.exit("Invalid range")
    selected_total = effective_end - args.start_line + 1

    pbar = None
    if tqdm is not None:
        pbar = tqdm(total=selected_total, unit="line", desc="Perturbing")

    stats = dict(n_total=0, n_changed=0, axis_counts=Counter(), subcat_counts=Counter())

    try:
        with INPUT_FILE.open(encoding="utf-8") as fin, OUTPUT_FILE.open("w", encoding="utf-8") as fout:
            inside_batch = []

            def flush():
                # Write the buffered lines and advance the progress bar by the amount done.
                process_batch(inside_batch, fout, stats=stats)
                if pbar is not None:
                    pbar.update(len(inside_batch))
                inside_batch.clear()

            for lineno, raw in enumerate(fin, start=1):
                if args.start_line <= lineno <= effective_end:
                    inside_batch.append(raw.rstrip("\n"))
                    if len(inside_batch) >= args.batch_size:
                        flush()
                else:
                    # Buffered in-range lines must reach the file before the placeholder
                    # for a later out-of-range line, or the output order would not
                    # match the input order.
                    if inside_batch:
                        flush()
                    fout.write("\n")
            if inside_batch:
                flush()
    finally:
        # Close the bar even when processing fails part-way through.
        if pbar is not None:
            pbar.close()

    return stats

# Run the corpus job when this cell/script is executed directly.
if __name__ == "__main__":
    main()
