In [2]:
# Install dependencies if not already available
!pip install gputil psutil -q

In [3]:
import GPUtil
import platform
import psutil

def show_hardware_info():
    print("=== 🧠 CPU INFORMATION ===")
    print(f"System: {platform.system()} {platform.release()}")
    print(f"Processor: {platform.processor()}")
    print(f"Architecture: {platform.machine()}")
    print(f"CPU Cores: {psutil.cpu_count(logical=False)} (Physical), {psutil.cpu_count(logical=True)} (Logical)")
    print(f"Total RAM: {round(psutil.virtual_memory().total / (1024**3), 2)} GB")

    print("\n=== ⚡ GPU INFORMATION ===")
    gpus = GPUtil.getGPUs()
    if not gpus:
        print("No GPU detected.")
    else:
        for gpu in gpus:
            print(f"GPU ID: {gpu.id}")
            print(f"Name: {gpu.name}")
            print(f"Driver Version: {gpu.driver}")
            print(f"Total Memory: {gpu.memoryTotal} MB")
            print(f"Used Memory: {gpu.memoryUsed} MB")
            print(f"Free Memory: {gpu.memoryFree} MB")
            print(f"GPU Load: {gpu.load * 100:.1f}%")
            print("-" * 40)

show_hardware_info()


=== 🧠 CPU INFORMATION ===
System: Linux 6.8.0-85-generic
Processor: x86_64
Architecture: x86_64
CPU Cores: 32 (Physical), 64 (Logical)
Total RAM: 251.75 GB

=== ⚡ GPU INFORMATION ===
GPU ID: 0
Name: NVIDIA GeForce GTX 1080 Ti
Driver Version: 580.95.05
Total Memory: 11264.0 MB
Used Memory: 3.0 MB
Free Memory: 11164.0 MB
GPU Load: 0.0%
----------------------------------------


In [5]:
# If you already have these, skip. Pin to stable versions.
!pip install "transformers>=4.45.0" "accelerate>=0.34.0" "peft>=0.12.0" "datasets>=2.20.0" "trl>=0.9.6" "bitsandbytes>=0.44.1"

Collecting transformers>=4.45.0
  Downloading transformers-4.57.1-py3-none-any.whl.metadata (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m258.7 kB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting accelerate>=0.34.0
  Downloading accelerate-1.11.0-py3-none-any.whl.metadata (19 kB)
Collecting peft>=0.12.0
  Downloading peft-0.17.1-py3-none-any.whl.metadata (14 kB)
Collecting datasets>=2.20.0
  Downloading datasets-4.3.0-py3-none-any.whl.metadata (18 kB)
Collecting trl>=0.9.6
  Downloading trl-0.24.0-py3-none-any.whl.metadata (11 kB)
Collecting bitsandbytes>=0.44.1
  Downloading bitsandbytes-0.48.1-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting huggingface-hub<1.0,>=0.34.0 (from transformers>=4.45.0)
  Downloading huggingface_hub-0.36.0-py3-none-any.whl.metadata (14 kB)
Collecting regex!=2019.12.17 (from transformers>=4.45.0)
  Downloading regex-2025.10.23-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.man

In [8]:
import torch

# 1) Basic environment info
print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    dev = torch.cuda.current_device()
    name = torch.cuda.get_device_name(dev)
    cap  = torch.cuda.get_device_capability(dev)
    print(f"GPU: {name} | SM capability: {cap} | Total VRAM: {torch.cuda.get_device_properties(dev).total_memory/1024**3:.2f} GB")

# 2) BitsAndBytes availability check
try:
    import bitsandbytes as bnb  # noqa
    BNB_IMPORTED = True
except Exception as e:
    BNB_IMPORTED = False
    print("bitsandbytes import failed:", repr(e))

sm_major, sm_minor = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
ALLOW_4BIT = bool(torch.cuda.is_available() and BNB_IMPORTED and sm_major >= 6)

# 3) Model choice and conservative defaults for your 1080 Ti (11 GB)
MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"   # we will load later
SEQ_LEN = 512
DTYPE = torch.float16        # Pascal cards do not support bfloat16
USE_4BIT = True              # we will fall back if ALLOW_4BIT is False
GRAD_ACCUM = 16
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj"]

print("\n=== CONFIG SUMMARY ===")
print("MODEL_ID:", MODEL_ID)
print("SEQ_LEN:", SEQ_LEN)
print("DTYPE:", DTYPE)
print("BNB imported:", BNB_IMPORTED)
print("ALLOW_4BIT:", ALLOW_4BIT)
print("USE_4BIT (requested):", USE_4BIT)
print("LoRA r/alpha/dropout:", LORA_R, LORA_ALPHA, LORA_DROPOUT)
print("Target modules:", TARGET_MODULES)


PyTorch: 2.4.0+cu121
CUDA available: True
GPU: NVIDIA GeForce GTX 1080 Ti | SM capability: (6, 1) | Total VRAM: 10.90 GB

=== CONFIG SUMMARY ===
MODEL_ID: meta-llama/Llama-3.2-3B-Instruct
SEQ_LEN: 512
DTYPE: torch.float16
BNB imported: True
ALLOW_4BIT: True
USE_4BIT (requested): True
LoRA r/alpha/dropout: 16 32 0.05
Target modules: ['q_proj', 'k_proj', 'v_proj', 'o_proj']


In [11]:
# Step 3 — HF login and dataset smoke test
import os
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"  # keep things quiet

from huggingface_hub import login, whoami

# Use an env var if you have set one, else you will be prompted in the notebook.
token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
if token:
    login(token=token, add_to_git_credential=True)
else:
    print("No HF token in env. You will be prompted. Create one at https://huggingface.co/settings/tokens")
    login(add_to_git_credential=True)

print("Authenticated as:", whoami().get("name") or whoami().get("email") or "unknown")

from datasets import load_dataset

DATASET_ID = "ai4privacy/pii-masking-200k"

# Load a tiny slice to keep RAM light for now
ds = load_dataset(DATASET_ID, split="train[:200]")
print(ds)

# Peek at one record so we know the field names
ex = ds[0]
print("Example keys:", list(ex.keys()))
for k, v in ex.items():
    s = str(v)
    print(f"- {k}: {s[:200]}{'...' if len(s) > 200 else ''}")


No HF token in env. You will be prompted. Create one at https://huggingface.co/settings/tokens


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

Authenticated as: chinu-codes


README.md: 0.00B [00:00, ?B/s]

english_pii_43k.jsonl:   0%|          | 0.00/73.8M [00:00<?, ?B/s]

french_pii_62k.jsonl:   0%|          | 0.00/116M [00:00<?, ?B/s]

german_pii_52k.jsonl:   0%|          | 0.00/97.8M [00:00<?, ?B/s]

italian_pii_50k.jsonl:   0%|          | 0.00/93.1M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/209261 [00:00<?, ? examples/s]

Dataset({
    features: ['source_text', 'target_text', 'privacy_mask', 'span_labels', 'mbert_text_tokens', 'mbert_bio_labels', 'id', 'language', 'set'],
    num_rows: 200
})
Example keys: ['source_text', 'target_text', 'privacy_mask', 'span_labels', 'mbert_text_tokens', 'mbert_bio_labels', 'id', 'language', 'set']
- source_text: A student's assessment was found on device bearing IMEI: 06-184755-866851-3. The document falls under the various topics discussed in our Optimization curriculum. Can you please collect it?
- target_text: A student's assessment was found on device bearing IMEI: [PHONEIMEI]. The document falls under the various topics discussed in our [JOBAREA] curriculum. Can you please collect it?
- privacy_mask: [{'value': '06-184755-866851-3', 'start': 57, 'end': 75, 'label': 'PHONEIMEI'}, {'value': 'Optimization', 'start': 138, 'end': 150, 'label': 'JOBAREA'}]
- span_labels: [[0, 57, "O"], [57, 75, "PHONEIMEI"], [75, 138, "O"], [138, 150, "JOBAREA"], [150, 189, "O"]]
- mber

In [13]:
import re
from transformers import pipeline

# Reuse 'tokenizer' and 'base_model' already loaded

SYSTEM_V2 = (
    "You redact personal or secret information from user text. "
    "Return the SAME text but with only the sensitive VALUES replaced by placeholders. "
    "Do NOT change surrounding words like 'IMEI', 'Email', 'Phone', labels, or punctuation. "
    "Allowed placeholders: [NAME_1], [EMAIL_1], [PHONE_1], [ADDRESS_1], [DOB_1], [SSN_1], [CARD_1], "
    "[IP_1], [USERNAME_1], [PASSWORD_1], [APIKEY_1], and dataset-style tags like [PHONEIMEI]. "
    "Output ONLY the redacted text between <safe> and </safe>. No other text."
)

def make_prompt_v2(text: str) -> str:
    # No <<< >>> delimiters; keep it minimal and deterministic
    messages = [
        {"role": "system", "content": SYSTEM_V2},
        {"role": "user", "content": text},
        {"role": "assistant", "content": "<safe>"}  # bias start of the block
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

SAFE_BLOCK_RE = re.compile(r"<safe>(.*?)</safe>", flags=re.DOTALL)

def extract_safe(generated: str) -> str:
    m = SAFE_BLOCK_RE.search(generated)
    if m:
        return m.group(1).strip()
    # fallback: keep everything after first <safe> if closing tag missing
    if "<safe>" in generated:
        return generated.split("<safe>", 1)[1].strip()
    return generated.strip()

gen_v2 = pipeline(
    "text-generation",
    model=base_model,
    tokenizer=tokenizer,
    device_map="auto",
)

def redact_infer_v2(text: str, max_new_tokens=96) -> str:
    prompt = make_prompt_v2(text)
    out = gen_v2(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=False,     # deterministic
        top_p=1.0,
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False
    )[0]["generated_text"]
    # Hard stop at the closing tag if present
    if "</safe>" in out:
        out = out.split("</safe>", 1)[0]
    # Ensure we only return the inside of the block
    return extract_safe(f"<safe>{out}</safe>")

# Test again with the same example
sample = ds[0]["source_text"]
print("SOURCE:", sample)
masked_v2 = redact_infer_v2(sample)
print("\nPREDICTED:")
print(f"<safe>{masked_v2}</safe>")


Device set to use cuda:0


SOURCE: A student's assessment was found on device bearing IMEI: 06-184755-866851-3. The document falls under the various topics discussed in our Optimization curriculum. Can you please collect it?

PREDICTED:
<safe>A student's assessment was found on device bearing [PHONEIMEI]. The document falls under the various topics discussed in our Optimization curriculum. Can you please collect it?</safe>


In [14]:
# Step 5 — Prepare prompt→answer strings for supervised fine-tuning (no training yet)
from datasets import load_dataset

# We reuse the tokenizer from earlier. If not present, re-create it quickly.
try:
    tokenizer
except NameError:
    from transformers import AutoTokenizer
    MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    tokenizer.model_max_length = 512  # keep in sync with your SEQ_LEN knob

DATASET_ID = "ai4privacy/pii-masking-200k"

# Load a manageable slice, then keep only English
raw = load_dataset(DATASET_ID, split="train[:3000]")
raw_en = raw.filter(lambda x: (x.get("language") or "").startswith("en")).select(range(min(2000, len(raw))))

SYSTEM = (
    "You redact personal or secret information from user text. "
    "Return the SAME text but with only the sensitive VALUES replaced by placeholders. "
    "Do not change surrounding words like 'IMEI', 'Email', 'Phone', or punctuation. "
    "Allowed placeholders include dataset labels such as [PHONEIMEI], [EMAIL], etc. "
    "Output ONLY the redacted text between <safe> and </safe>."
)

def format_chat(input_text: str, target_masked: str) -> str:
    # The assistant message includes the exact desired output between tags
    messages = [
        {"role": "system", "content": SYSTEM},
        {"role": "user", "content": input_text},
        {"role": "assistant", "content": f"<safe>{target_masked}</safe>"},
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)

formatted = raw_en.map(
    lambda ex: {"sft_text": format_chat(ex["source_text"], ex["target_text"])},
    desc="Formatting to chat template",
)

# Quick sanity: show two samples (truncated)
print(formatted)
print("\nSample 1 (truncated):\n", formatted[0]["sft_text"][:600])
print("\nSample 2 (truncated):\n", formatted[1]["sft_text"][:600])

# Token length check on a subset so we know how tight 512 is
subset = formatted.select(range(min(50, len(formatted))))
lens = subset.map(lambda ex: {"len": len(tokenizer(ex["sft_text"], truncation=False)["input_ids"])})
avg_len = sum(r["len"] for r in lens) / len(lens)
max_len = max(r["len"] for r in lens)
print(f"\nAvg tokens over {len(lens)} samples: {avg_len:.1f} | Max: {max_len} | SEQ_LEN limit = {tokenizer.model_max_length}")


Filter:   0%|          | 0/3000 [00:00<?, ? examples/s]

Formatting to chat template:   0%|          | 0/2000 [00:00<?, ? examples/s]

Dataset({
    features: ['source_text', 'target_text', 'privacy_mask', 'span_labels', 'mbert_text_tokens', 'mbert_bio_labels', 'id', 'language', 'set', 'sft_text'],
    num_rows: 2000
})

Sample 1 (truncated):
 <|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Oct 2025

You redact personal or secret information from user text. Return the SAME text but with only the sensitive VALUES replaced by placeholders. Do not change surrounding words like 'IMEI', 'Email', 'Phone', or punctuation. Allowed placeholders include dataset labels such as [PHONEIMEI], [EMAIL], etc. Output ONLY the redacted text between <safe> and </safe>.<|eot_id|><|start_header_id|>user<|end_header_id|>

A student's assessment was found on device bearing IMEI: 06-184755-86

Sample 2 (truncated):
 <|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Oct 2025

You redact personal or secret informatio

Map:   0%|          | 0/50 [00:00<?, ? examples/s]


Avg tokens over 50 samples: 197.3 | Max: 258 | SEQ_LEN limit = 512


In [16]:
# Step 6 — LoRA wrap (no training yet)
import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Reuse knobs from earlier cells, or set safe defaults
try:
    LORA_R
except NameError:
    LORA_R = 16
try:
    LORA_ALPHA
except NameError:
    LORA_ALPHA = 32
try:
    LORA_DROPOUT
except NameError:
    LORA_DROPOUT = 0.05
try:
    TARGET_MODULES
except NameError:
    TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj"]

assert "base_model" in globals(), "Please run the model loading cell first to create `base_model`."

# Make sure caching is off for checkpointing
base_model.config.use_cache = False

# If loaded in 4-bit/8-bit, prep for k-bit training, which also toggles grad checkpointing
is_kbit = bool(getattr(base_model, "is_loaded_in_4bit", False) or getattr(base_model, "is_loaded_in_8bit", False))
if is_kbit:
    base_model = prepare_model_for_kbit_training(base_model, use_gradient_checkpointing=True)
else:
    # Fallback: still enable checkpointing
    base_model.gradient_checkpointing_enable()

# Define LoRA
lora_cfg = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    target_modules=TARGET_MODULES,
    task_type="CAUSAL_LM",
)

# Wrap
lora_model = get_peft_model(base_model, lora_cfg)

# Report trainable params
lora_model.print_trainable_parameters()

# Quick VRAM check
if torch.cuda.is_available():
    print(f"VRAM allocated after LoRA wrap: {torch.cuda.memory_allocated()/1024**3:.2f} GB")

trainable params: 9,175,040 || all params: 3,221,924,864 || trainable%: 0.2848
VRAM allocated after LoRA wrap: 2.87 GB


In [19]:
# Step 7 — Tiny SFT using vanilla Trainer (labels only after "<safe>")
import os, gc, torch
from typing import List, Dict, Any
from transformers import TrainingArguments, Trainer
from transformers import DataCollatorWithPadding
from datasets import DatasetDict

assert "lora_model" in globals(), "LoRA model not found. Please run the LoRA wrap cell first."
assert "tokenizer" in globals(), "Tokenizer not found. Please load tokenizer."
assert "formatted" in globals() and "sft_text" in formatted.column_names, "Run the formatting step to create 'sft_text'."

# 1) Find the token index of the response template
RESP_TEMPLATE = "<safe>"
resp_ids = tokenizer(RESP_TEMPLATE, add_special_tokens=False)["input_ids"]

def find_sublist(haystack: List[int], needle: List[int]) -> int:
    """Return start index of 'needle' in 'haystack', or -1 if not found."""
    if not needle or len(needle) > len(haystack):
        return -1
    # simple scan; fast enough for our batch sizes
    for i in range(len(haystack) - len(needle) + 1):
        if haystack[i:i+len(needle)] == needle:
            return i
    return -1

# 2) Tokenise and build labels with -100 before the response start
def tok_and_mask(example: Dict[str, Any]) -> Dict[str, Any]:
    enc = tokenizer(
        example["sft_text"],
        truncation=True,
        max_length=tokenizer.model_max_length,
        add_special_tokens=False,
    )
    ids = enc["input_ids"]
    start = find_sublist(ids, resp_ids)
    labels = [-100] * len(ids)
    if start != -1:
        # put loss on everything from the start of "<safe>" onwards
        labels[start:] = ids[start:]
    enc["labels"] = labels
    return enc

tokenised = formatted.map(tok_and_mask, remove_columns=[c for c in formatted.column_names if c != "sft_text"])
# Simple split
splits = tokenised.train_test_split(test_size=0.1, seed=42)
dsd = DatasetDict(train=splits["train"], test=splits["test"])
print(dsd)

# 3) Collator that pads input_ids, attention_mask, and labels together
class CausalPadCollator(DataCollatorWithPadding):
    def __call__(self, features):
        labels = [f.pop("labels") for f in features]
        batch = super().__call__(features)  # pads input_ids and attention_mask
        # pad labels to same length with -100
        max_len = batch["input_ids"].shape[1]
        padded = []
        for lab in labels:
            if len(lab) < max_len:
                lab = lab + [-100] * (max_len - len(lab))
            else:
                lab = lab[:max_len]
            padded.append(lab)
        batch["labels"] = torch.tensor(padded, dtype=torch.long)


Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['sft_text', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 1800
    })
    test: Dataset({
        features: ['sft_text', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 200
    })
})


In [21]:
# Step 7c — Train for 1 epoch with minimal, version-safe TrainingArguments
import os, gc, torch
from transformers import TrainingArguments, Trainer, DataCollatorWithPadding

assert "dsd" in globals(), "The tokenised DatasetDict 'dsd' is missing. Please re-run the dataset formatting cell."
assert "lora_model" in globals(), "LoRA model not found. Please run the LoRA wrap cell."
assert "tokenizer" in globals(), "Tokenizer not found. Please load tokenizer."

# Collator that pads input_ids, attention_mask, and labels together
class CausalPadCollator(DataCollatorWithPadding):
    def __call__(self, features):
        labels = [f.pop("labels") for f in features]
        batch = super().__call__(features)  # pads input_ids and attention_mask
        # pad labels to same length with -100
        max_len = batch["input_ids"].shape[1]
        padded = []
        for lab in labels:
            if len(lab) < max_len:
                lab = lab + [-100] * (max_len - len(lab))
            else:
                lab = lab[:max_len]
            padded.append(lab)
        batch["labels"] = torch.tensor(padded, dtype=torch.long)
        return batch

collator = CausalPadCollator(tokenizer=tokenizer, padding=True)

OUTPUT_DIR = "./outputs/safe-prompt-3b-lora"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Keep args minimal for 4.57.1 compatibility (no evaluation_strategy here)
args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,
    num_train_epochs=1,
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    logging_steps=10,
    save_steps=200,
    save_total_limit=2,
    fp16=True,
    optim="adamw_torch",   # safest option with your stack
    report_to=[],
)

trainer = Trainer(
    model=lora_model,
    args=args,
    train_dataset=dsd["train"],
    eval_dataset=dsd["test"],
    tokenizer=tokenizer,
    data_collator=collator,
)

print("Starting training…")
train_out = trainer.train()
print("Training complete.\n")

# Explicit evaluation (since we did not set an eval strategy)
eval_metrics = trainer.evaluate()
print("Eval:", eval_metrics)

# Save adapters and tokenizer
trainer.save_model()
tokenizer.save_pretrained(OUTPUT_DIR)

# Cleanup
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"VRAM now: {torch.cuda.memory_allocated()/1024**3:.2f} GB")
print("Adapters saved to:", OUTPUT_DIR)


  trainer = Trainer(
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 128009, 'pad_token_id': 128009}.


Starting training…


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


Step,Training Loss
10,0.5814
20,0.2867
30,0.2097
40,0.1292
50,0.124
60,0.1037
70,0.0954
80,0.0806
90,0.0761
100,0.0645


Training complete.



Eval: {'eval_loss': nan, 'eval_runtime': 75.2314, 'eval_samples_per_second': 2.658, 'eval_steps_per_second': 2.658, 'epoch': 1.0}
VRAM now: 2.94 GB
Adapters saved to: ./outputs/safe-prompt-3b-lora


In [22]:
# Step 8 — Reload adapters from disk and test inference with leak checks
import os, re, gc, torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from peft import PeftModel

# Knobs
MODEL_ID   = "meta-llama/Llama-3.2-3B-Instruct"
ADAPTER_DIR = "./outputs/safe-prompt-3b-lora"
SEQ_LEN    = 512
DTYPE      = torch.float16

# 1) Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
tokenizer.model_max_length = SEQ_LEN

# 2) Load base model (4-bit if possible)
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

bnb_config = None
ALLOW_4BIT = False
try:
    import bitsandbytes as bnb  # noqa
    sm_major, _ = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
    ALLOW_4BIT = bool(torch.cuda.is_available() and sm_major >= 6)
except Exception:
    ALLOW_4BIT = False

if ALLOW_4BIT:
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=DTYPE,
    )

kwargs = dict(
    torch_dtype=DTYPE,
    device_map="auto",
    trust_remote_code=False,
)
if bnb_config is not None:
    kwargs["quantization_config"] = bnb_config

base_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
base_model.config.use_cache = False

# 3) Attach trained adapters
model = PeftModel.from_pretrained(base_model, ADAPTER_DIR)
model.eval()

# 4) Build generator
gen = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device_map="auto",
)

# 5) Prompt and parser with strict tag-only contract
SYSTEM = (
    "You redact personal or secret information from user text. "
    "Return the SAME text but with only the sensitive VALUES replaced by placeholders. "
    "Do NOT change surrounding words like 'IMEI', 'Email', 'Phone', labels, or punctuation. "
    "Allowed placeholders include dataset-style tags like [EMAIL], [PHONEIMEI], etc. "
    "Output ONLY the redacted text between <safe> and </safe>. No other text."
)

def make_prompt(text: str) -> str:
    messages = [
        {"role": "system", "content": SYSTEM},
        {"role": "user", "content": text},
        {"role": "assistant", "content": "<safe>"},
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

def redact(text: str, max_new_tokens=128) -> str:
    prompt = make_prompt(text)
    out = gen(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=False,            # deterministic
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False
    )[0]["generated_text"]
    # Trim at the first closing tag if present
    if "</safe>" in out:
        out = out.split("</safe>", 1)[0]
    return out.strip()

# 6) Light leak checks with conservative fallback
EMAIL_RE = re.compile(r"\b[^\s@]+@[^\s@]+\.[^\s@]+\b")
PHONE_RE = re.compile(r"\b(?:\+?\d{1,3}[\s.\-]?)?(?:\(?\d{3}\)?[\s.\-]?)?\d{3}[\s.\-]?\d{4}\b")

def postcheck(output_text: str) -> str:
    fixed = output_text
    # Basic masks if the model leaked values
    fixed = EMAIL_RE.sub("[EMAIL]", fixed)
    fixed = PHONE_RE.sub("[PHONE]", fixed)
    return fixed

def redact_safe(text: str) -> str:
    masked = redact(text)
    masked = postcheck(masked)
    return f"<safe>{masked}</safe>"

# 7) Try a few samples (uses 'ds' from earlier; fallback to load one)
try:
    ds
except NameError:
    from datasets import load_dataset
    ds = load_dataset("ai4privacy/pii-masking-200k", split="train[:3]")

for i in range(min(3, len(ds))):
    src = ds[i]["source_text"]
    print(f"\nSOURCE {i+1}:\n", src)
    print("\nPREDICTED:")
    print(redact_safe(src))


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cuda:0
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



SOURCE 1:
 A student's assessment was found on device bearing IMEI: 06-184755-866851-3. The document falls under the various topics discussed in our Optimization curriculum. Can you please collect it?

PREDICTED:
<safe><safe>A student's assessment was found on device bearing IMEI: [IMEI]. The document falls under the various topics discussed in our Optimization curriculum. Can you please collect it?</safe>

SOURCE 2:
 Dear Omer, as per our records, your license 78B5R2MVFAHJ48500 is still registered in our records for access to the educational tools. Please feedback on it's operability.

PREDICTED:
<safe><safe>Dear [FIRSTNAME], as per our records, your license [VEHICLEVRM] is still registered in our records for access to the educational tools. Please feedback on it's operability.</safe>

SOURCE 3:
 Kattie could you please share your recomndations about vegetarian diet for 72 old Intersex person with 158centimeters?

PREDICTED:
<safe><safe>Kattie could you please share your recomndation

In [24]:
# Step 9 — Redaction tester for arbitrary inputs
import re
from dataclasses import dataclass
from typing import List, Tuple, Callable
from transformers import pipeline

SAFE_OPEN = "<safe>"
SAFE_CLOSE = "</safe>"

# Reuse tokenizer/model already in memory. If 'gen' is missing, rebuild it.
try:
    gen
except NameError:
    gen = pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")

SYSTEM_TEST = (
    "You redact personal or secret information from user text. "
    "Return the SAME text but replace only sensitive VALUES with placeholders. "
    "Do not change surrounding words like 'IMEI', 'Email', 'Phone', or punctuation. "
    "Allowed placeholders include dataset-style tags like [EMAIL], [PHONEIMEI], [FIRSTNAME], etc. "
    "Output ONLY the redacted text between <safe> and </safe>. No other text."
)

@dataclass
class Detector:
    name: str
    pattern: re.Pattern
    placeholder: str
    post: Callable[[str], str] = lambda s: s
    
def make_prompt_v3(text: str) -> str:
    messages = [
        {"role": "system", "content": SYSTEM_TEST},
        {"role": "user", "content": text},
        {"role": "assistant", "content": SAFE_OPEN},  # model will continue after <safe>
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

def strip_one_prefix(s: str, prefix: str) -> str:
    return s[len(prefix):] if s.startswith(prefix) else s

def redact_raw(text: str, max_new_tokens: int = 128) -> str:
    prompt = make_prompt_v3(text)
    out = gen(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False
    )[0]["generated_text"]
    # Cut at first closing tag if present
    if SAFE_CLOSE in out:
        out = out.split(SAFE_CLOSE, 1)[0]
    # Remove any leading <safe> that the model may echo
    out = strip_one_prefix(out.strip(), SAFE_OPEN)
    # Final cleanup: if model still echoed tags inside, strip all
    out = out.replace(SAFE_OPEN, "").replace(SAFE_CLOSE, "").strip()
    return out

# A tiny starter set; extend as needed
DETECTORS: List[Detector] = [
    Detector("email", re.compile(r"\b[^\s@]+@[^\s@]+\.[^\s@]+\b"), "[EMAIL]"),
    Detector("phone", re.compile(r"\b(?:\+?\d{1,3}[\s.\-]?)?(?:\(?\d{3}\)?[\s.\-]?)?\d{3}[\s.\-]?\d{4}\b"), "[PHONE]"),
    # Detector("ipv4", re.compile(r"\b(?:(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\b"), "[IP]"),
    # Detector("ssn_us", re.compile(r"\b\d{3}-\d{2}-\d{4}\b"), "[SSN]"),
]

def apply_validators(text: str, mode: str = "warn") -> Tuple[str, List[str]]:
    """
    mode = "off" | "warn" | "enforce"
    returns: (possibly_modified_text, list_of_detector_names_that_matched)
    """
    hits = []
    out = text
    for d in DETECTORS:
        if d.pattern.search(out):
            hits.append(d.name)
            if mode == "enforce":
                out = d.pattern.sub(d.placeholder, out)
    return out, hits
    
# Example integration with your redact function:
def redact_safe_prompt(text: str, max_new_tokens: int = 128, validate_mode: str = "enforce") -> str:
    masked = redact_raw(text, max_new_tokens=max_new_tokens)  # your LLM output inside the <safe> block
    masked, hits = apply_validators(masked, mode=validate_mode)
    # Optional: log hits to a counter or print in dev
    # if hits: print(f"[validator hits] {hits}")
    return masked

# Try your own prompts by editing the list below
tests = [
    "Hi, I am Vishal Shinde. Email me at vishal@example.com and call +1 415 555 0199.",
    "Card 4111 1111 1111 1111 expires 12/26, CVV 123.",
    "Server at 10.0.0.5 uses API key sk_live_ABC123XYZ456 for uploads.",
    "IMEI: 06-184755-866851-3 belongs to this device.",
    "My SSN is 123-45-6789. Do not store it.",
    "Send the package to 221B Baker Street, London.",
]

for i, t in enumerate(tests, 1):
    print(f"\nINPUT {i}: {t}")
    print("OUTPUT:", redact_safe_prompt(t))



INPUT 1: Hi, I am Vishal Shinde. Email me at vishal@example.com and call +1 415 555 0199.


You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


OUTPUT: Hi, I am [FIRSTNAME] [LASTNAME]. Email me at [EMAIL] and call [PHONEIMEI].

INPUT 2: Card 4111 1111 1111 1111 expires 12/26, CVV 123.
OUTPUT: Card [MASKEDNUMBER] [MASKEDNUMBER] [MASKEDNUMBER] [MASKEDNUMBER] expires [DATE], CVV [MASKEDNUMBER].

INPUT 3: Server at 10.0.0.5 uses API key sk_live_ABC123XYZ456 for uploads.
OUTPUT: Server at [IP] uses API key [PASSWORD] for uploads.

INPUT 4: IMEI: 06-184755-866851-3 belongs to this device.
OUTPUT: IMEI: [IMEI] belongs to this device.

INPUT 5: My SSN is 123-45-6789. Do not store it.
OUTPUT: My SSN is [SSN]. Do not store it.

INPUT 6: Send the package to 221B Baker Street, London.
OUTPUT: Send the package to [BUILDINGNUMBER] [STREET], [CITY].
