In [1]:
# Cell 1 — Install required packages
# We use transformers, datasets, accelerate, peft, evaluate. This may take a few minutes.
!pip install -q "transformers>=4.30" datasets accelerate peft bitsandbytes evaluate sentencepiece safetensors


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
# Cell 2 — imports & config
import os, json, random, math, gc, sys
from getpass import getpass
import pandas as pd
import numpy as np
import torch

print("Python:", sys.version.split()[0])
print("Torch:", torch.__version__, "CUDA:", torch.cuda.is_available())
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

# -- Config (tweak for your runtime)
CSV_PATH = "reddit_preprocessed_minimal.csv"   # put dataset in working dir or use upload cell below
SAMPLE_N = 20000           # safe default sample for Colab; increase if you have more VRAM
VAL_N = 4000               # validation size
MODEL_BASE = "google/flan-t5-base"   # try this first; fallback to flan-t5-small if OOM
FALLBACK_SMALL = "google/flan-t5-small"
OUTPUT_DIR = "flan_t5_sna_adapter"
MAX_SOURCE_LENGTH = 256
MAX_TARGET_LENGTH = 64
BATCH_SIZE = 4            # per device batch; reduce to 2 if you get OOM
NUM_EPOCHS = 3
LEARNING_RATE = 2e-4
SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


Python: 3.12.12
Torch: 2.8.0+cu126 CUDA: True
Using device: cuda


<torch._C.Generator at 0x79cab34a3130>

In [4]:
# Cell 3 — upload UI if needed (Colab)
if not os.path.exists(CSV_PATH):
    try:
        from google.colab import files
        print(f"{CSV_PATH} not found. Upload the CSV now.")
        uploaded = files.upload()
        for fn in uploaded.keys():
            os.rename(fn, CSV_PATH)
            print("Saved", fn, "->", CSV_PATH)
    except Exception:
        raise FileNotFoundError(f"{CSV_PATH} not found. Upload it or change CSV_PATH.")
else:
    print("Found dataset:", CSV_PATH)


Found dataset: reddit_preprocessed_minimal.csv


In [5]:
# Cell 4 — load CSV and preprocess
df = pd.read_csv(CSV_PATH)
print("Dataset rows:", len(df))
print("Columns (preview):", df.columns.tolist()[:40])

# detect text column heuristics
text_col = None
for c in df.columns:
    if any(k in c.lower() for k in ["body","text","post","content","comment","message","title"]):
        text_col = c; break
if text_col is None:
    raise RuntimeError("Could not detect text column — edit CSV or set text_col manually.")

print("Using text column:", text_col)

# basic safe preprocessing
import re
def preprocess_text(s):
    if pd.isna(s): return ""
    s = str(s)
    s = s.replace("\n"," ").replace("\r"," ")
    s = re.sub(r"http\S+", " ", s)
    s = re.sub(r"/u/\S+", " ", s)
    s = re.sub(r"@\w+", " ", s)
    s = re.sub(r"[^\\x00-\\x7f]", " ", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s

df["_sna_text"] = df[text_col].astype(str).apply(preprocess_text)
print("Example text:", df["_sna_text"].iloc[0][:200])


Dataset rows: 910086
Columns (preview): ['subreddit', 'body', 'controversiality', 'score', 'lang', 'text_length', 'word_count', 'avg_word_len']
Using text column: body
Example text: Y [SPOILERS [NO SPOILERS P [ f f f [ f I f P [ f ? = f f


In [6]:
# Cell 5 — create/refine labels (weak labeling)
def improved_weak_label(text):
    s = (text or "").lower()
    patterns = [
        r"\bsex\b", r"\bnude\b", r"\bnsfw\b", r"\bporn\b", r"\bxxx\b", r"\bunderage\b", r"\bchild\b",
        r"\brape\b", r"\bkill\b", r"\bmurder\b", r"\bi want to die\b", r"\bkill myself\b", r"\bsuicide\b",
        r"\b(cocaine|heroin|meth|xanax|fentanyl|drug)\b", r"\bsend nudes\b", r"\bsend pics\b", r"\bprivate chat\b"
    ]
    for p in patterns:
        if re.search(p, s):
            return 1
    # obfuscated checks
    if re.search(r"p[\W_]*0?r[\W_]*n|n[\W_]*u[\W_]*d|s[\W_]*3[\W_]*x", s):
        return 1
    return 0

# If dataset already has a label column, normalize it; otherwise create weak labels
label_candidates = [c for c in df.columns if any(k in c.lower() for k in ["label","sensitive","target","is_sensitive","y","class"])]
if label_candidates:
    src = label_candidates[0]
    if df[src].dtype == object:
        df["_sna_label"] = df[src].astype(str).str.lower().isin(["1","true","yes","sensitive","explicit","adult"]).astype(int)
    else:
        vals = sorted(df[src].dropna().unique())
        if set(vals) <= {0,1}:
            df["_sna_label"] = df[src].astype(int)
        else:
            df["_sna_label"] = (df[src] > df[src].median()).astype(int)
else:
    df["_sna_label"] = df["_sna_text"].apply(improved_weak_label).astype(int)

print("Label counts:", df["_sna_label"].value_counts().to_dict())
# if no positives, relax rules or ask for manual labels
if df["_sna_label"].sum() == 0:
    print("No positives found — relaxing rules and adding loose matches.")
    loose_kw = ["porn","sex","nude","suicide","kill","child","drug", "underage", "teen"]
    added = 0
    for i, row in df.iterrows():
        if df.at[i, "_sna_label"]==0:
            s = row["_sna_text"].lower()
            if any(k in s for k in loose_kw):
                df.at[i, "_sna_label"] = 1
                added += 1
    print("Added", added, "weak positives. New counts:", df["_sna_label"].value_counts().to_dict())

if df["_sna_label"].sum()==0:
    raise RuntimeError("Still zero positive examples. Provide labeled data or sample rows with expected positives for tuning.")


Label counts: {0: 910086}
No positives found — relaxing rules and adding loose matches.
Added 327 weak positives. New counts: {0: 909759, 1: 327}


In [7]:
# Cell 6 — Build JSONL train/val for SFT (prompt -> JSON output)
# We'll build simple instruction prompts and target outputs (JSON string).
from sklearn.model_selection import train_test_split

# build prompt function
def build_prompt(text):
    return (f"Instruction:\nDecide whether the following post is appropriate for children. "
            "Return a JSON with keys: label ('SENSITIVE' or 'SAFE'), score (0.0-1.0), reason (short).\n\n"
            f"Input: {text}\n\nOutput:")

def build_target(label, reason=""):
    lbls = {1: "SENSITIVE", 0: "SAFE"}
    s = {"label": lbls[int(label)], "score": 1.0 if int(label)==1 else 0.0, "reason": reason or ""}
    return json.dumps(s, ensure_ascii=False)

# create dataset for SFT
df_sft = df[["_sna_text","_sna_label"]].rename(columns={"_sna_text":"text","_sna_label":"label"}).sample(frac=1, random_state=SEED)
# sample for Colab
if SAMPLE_N and SAMPLE_N < len(df_sft):
    df_sft = df_sft.sample(n=SAMPLE_N, random_state=SEED).reset_index(drop=True)
print("SFT dataset rows:", len(df_sft))

# prepare train/val split
train_df, val_df = train_test_split(df_sft, test_size=VAL_N if VAL_N < len(df_sft) else 0.15, random_state=SEED, stratify=df_sft["label"])
print("Train:", len(train_df), "Val:", len(val_df))

# write jsonl files
os.makedirs("sft_data", exist_ok=True)
train_path = "sft_data/train.jsonl"
val_path = "sft_data/val.jsonl"

with open(train_path, "w", encoding="utf-8") as fo:
    for _, r in train_df.iterrows():
        p = build_prompt(r["text"])
        t = build_target(r["label"])
        fo.write(json.dumps({"prompt": p, "response": t}, ensure_ascii=False) + "\n")

with open(val_path, "w", encoding="utf-8") as fo:
    for _, r in val_df.iterrows():
        p = build_prompt(r["text"])
        t = build_target(r["label"])
        fo.write(json.dumps({"prompt": p, "response": t}, ensure_ascii=False) + "\n")

print("Wrote", train_path, "and", val_path)


SFT dataset rows: 20000
Train: 16000 Val: 4000
Wrote sft_data/train.jsonl and sft_data/val.jsonl


In [8]:
# Cell 7 — datasets + tokenizer
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer

# try tokenizer for chosen model; fallback if fails
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_BASE)
    model_name_used = MODEL_BASE
except Exception as e:
    print("Primary tokenizer failed:", e)
    tokenizer = AutoTokenizer.from_pretrained(FALLBACK_SMALL)
    model_name_used = FALLBACK_SMALL

print("Tokenizer/model selected:", model_name_used)

# Load the jsonl files created above
def load_jsonl(path):
    with open(path, "r", encoding="utf-8") as f:
        rows = [json.loads(l) for l in f]
    return rows

train_raw = load_jsonl(train_path)
val_raw = load_jsonl(val_path)

# Convert to HF datasets with input_ids / labels (tokenized targets)
from transformers import DataCollatorForSeq2Seq

def preprocess_sft(batch):
    # batch: dict with 'prompt' and 'response'
    prompts = batch["prompt"]
    targets = batch["response"]
    model_inputs = tokenizer(prompts, max_length=MAX_SOURCE_LENGTH, truncation=True, padding="max_length")
    labels = tokenizer(targets, max_length=MAX_TARGET_LENGTH, truncation=True, padding="max_length")
    # set -100 for padding tokens on labels
    label_ids = labels["input_ids"]
    for i, lbl in enumerate(label_ids):
        label_ids[i] = [l if l != tokenizer.pad_token_id else -100 for l in lbl]
    model_inputs["labels"] = label_ids
    return model_inputs

from datasets import Dataset
train_ds = Dataset.from_list(train_raw)
val_ds = Dataset.from_list(val_raw)

# tokenization with batched map
train_tok = train_ds.map(preprocess_sft, batched=True, batch_size=128, remove_columns=train_ds.column_names)
val_tok   = val_ds.map(preprocess_sft, batched=True, batch_size=128, remove_columns=val_ds.column_names)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=None, label_pad_token_id=-100)
print("Prepared tokenized datasets:", len(train_tok), len(val_tok))


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

Tokenizer/model selected: google/flan-t5-base


Map:   0%|          | 0/16000 [00:00<?, ? examples/s]

Map:   0%|          | 0/4000 [00:00<?, ? examples/s]

Prepared tokenized datasets: 16000 4000


In [15]:
# ===== REPLACEMENT Cell 8 (attach missing generation_config and create Trainer safely) =====
import traceback
from transformers import TrainingArguments, Seq2SeqTrainer

print("=== Robust Trainer creation (attach generation_config if missing) ===")
print("Device:", DEVICE)

# assume `model`, `tokenizer`, `train_tok`, `val_tok`, `data_collator` exist from previous cells
# if `model`/`base_model` not present, you need to run the model-load cell again.

# Build minimal TrainingArguments (very widely supported)
try:
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        logging_steps=100,
        save_strategy="epoch",
        fp16=(DEVICE == "cuda"),
        learning_rate=LEARNING_RATE,
        push_to_hub=False
    )
    print("TrainingArguments constructed.")
except Exception as e:
    print("Failed to construct TrainingArguments with full kwargs, trying minimal fallback...", e)
    # last-resort minimal args
    training_args = TrainingArguments(output_dir=OUTPUT_DIR, per_device_train_batch_size=1)
    print("Minimal TrainingArguments constructed.")

# Attach generation_config if missing (some transformers versions expect it)
if not hasattr(training_args, "generation_config"):
    try:
        setattr(training_args, "generation_config", None)
        print("Attached training_args.generation_config = None (compatibility shim).")
    except Exception as e:
        print("Could not attach generation_config:", e)

# Create Seq2SeqTrainer without predict_with_generate kwarg (compat-safe)
try:
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_tok,
        eval_dataset=val_tok,
        tokenizer=tokenizer,
        data_collator=data_collator,
        # compute_metrics can be added if you want eval metrics during training
    )
    print("Seq2SeqTrainer created successfully.")
except Exception as e:
    print("Seq2SeqTrainer creation failed. Traceback:")
    traceback.print_exc()
    raise

print("Trainer ready. You can now run: trainer.train()")


  trainer = Seq2SeqTrainer(
The model is already on multiple devices. Skipping the move to device specified in `args`.


=== Robust Trainer creation (attach generation_config if missing) ===
Device: cuda
TrainingArguments constructed.
Attached training_args.generation_config = None (compatibility shim).
Seq2SeqTrainer created successfully.
Trainer ready. You can now run: trainer.train()


In [16]:
# Cell 9 — Train and save the PEFT adapter / model
import gc, os, time

print("Starting training. This may take some time.")
start = time.time()
train_result = trainer.train()
end = time.time()
print(f"Training completed in {(end-start)/60:.2f} minutes. Trainer return: {train_result}")

# Save final model (PEFT-aware) and tokenizer
print("Saving model and tokenizer to", OUTPUT_DIR)
trainer.save_model(OUTPUT_DIR)  # saves peft adapter + base model pointers as appropriate
tokenizer.save_pretrained(OUTPUT_DIR)

# optional: save trainer state
trainer.state.save_to_json(os.path.join(OUTPUT_DIR, "trainer_state.json"))

# cleanup
gc.collect()
try:
    import torch
    torch.cuda.empty_cache()
except Exception:
    pass

print("Model & tokenizer saved.")


Starting training. This may take some time.


  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mam-sc-u4cse23271[0m ([33mam-sc-u4cse23271-amrita-vishwa-vidhyapeetham[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss
100,0.0
200,0.0
300,0.0
400,0.0
500,0.0
600,0.0
700,0.0
800,0.0
900,0.0
1000,0.0


Training completed in 56.06 minutes. Trainer return: TrainOutput(global_step=12000, training_loss=0.0, metrics={'train_runtime': 3361.9411, 'train_samples_per_second': 14.277, 'train_steps_per_second': 3.569, 'total_flos': 1.6564636090368e+16, 'train_loss': 0.0, 'epoch': 3.0})
Saving model and tokenizer to flan_t5_sna_adapter
Model & tokenizer saved.


In [19]:
# ===== Replacement Cell 10: Robust generation + parsing (re-run for N_EVAL examples) =====
import re, json, torch
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import numpy as np

# ensure gen_model/gen_tokenizer exist (loaded from OUTPUT_DIR or trainer.model)
try:
    gen_tokenizer
    gen_model
except NameError:
    try:
        gen_tokenizer = AutoTokenizer.from_pretrained(OUTPUT_DIR)
        gen_model = AutoModelForSeq2SeqLM.from_pretrained(OUTPUT_DIR)
        if torch.cuda.is_available(): gen_model.to("cuda")
    except Exception:
        gen_model = trainer.model
        gen_tokenizer = tokenizer
        if torch.cuda.is_available():
            try: gen_model.to("cuda")
            except: pass

# improved parser that handles JSON or plain tokens
def parse_flexible_output(s):
    if not s or not isinstance(s, str):
        return None
    s = s.strip()
    # 1) try to find JSON object first
    m = re.search(r"\{.*\}", s, flags=re.S)
    if m:
        txt = m.group(0)
        try:
            return json.loads(txt)
        except Exception:
            try:
                txt2 = txt.replace("'", '"')
                txt2 = re.sub(r",\s*}", "}", txt2)
                return json.loads(txt2)
            except Exception:
                pass
    # 2) look for label words anywhere
    u = s.lower()
    if "sensitive" in u:
        # try to capture a numeric score if present
        score_m = re.search(r"score[^0-9\-\.]{0,6}([0-9]*\.?[0-9]+)", u)
        if score_m:
            try:
                sc = float(score_m.group(1))
                sc = max(0.0, min(1.0, sc))
            except:
                sc = 1.0
        else:
            sc = 1.0
        return {"label":"SENSITIVE","score":sc,"reason":u[:120]}
    if "safe" in u:
        score_m = re.search(r"score[^0-9\-\.]{0,6}([0-9]*\.?[0-9]+)", u)
        if score_m:
            try:
                sc = float(score_m.group(1)); sc = max(0.0,min(1.0,sc))
            except: sc = 0.0
        else:
            sc = 0.0
        return {"label":"SAFE","score":sc,"reason":u[:120]}
    # 3) check for single-token labels like "1" or "0"
    if re.fullmatch(r"[01]", s):
        return {"label":"SENSITIVE" if s.strip()=="1" else "SAFE", "score": 1.0 if s.strip()=="1" else 0.0, "reason": ""}
    # nothing found
    return None

# generate on a fixed number of validation examples
N_EVAL = min(len(val_raw), 500)   # keep same limit as before
print("Generating on", N_EVAL, "examples")

gen_model.eval()
gen_texts = []
parsed_list = []
true_labels = []

max_gen_len = MAX_TARGET_LENGTH if "MAX_TARGET_LENGTH" in globals() else 64
for i in tqdm(range(N_EVAL)):
    prompt = val_raw[i]["prompt"]
    true_j = json.loads(val_raw[i]["response"])
    true_label = 1 if true_j["label"].upper().startswith("S") else 0
    true_labels.append(true_label)

    inputs = gen_tokenizer(prompt, return_tensors="pt", truncation=True, padding=True, max_length=MAX_SOURCE_LENGTH)
    if torch.cuda.is_available():
        inputs = {k:v.to("cuda") for k,v in inputs.items()}
    try:
        out_ids = gen_model.generate(**inputs, max_length=max_gen_len, do_sample=False, num_beams=2)
        out_text = gen_tokenizer.decode(out_ids[0], skip_special_tokens=True)
    except Exception as e:
        # fallback: CPU generation
        try:
            inputs_cpu = {k:v.cpu() for k,v in inputs.items()}
            out_ids = gen_model.generate(**inputs_cpu, max_length=max_gen_len, do_sample=False, num_beams=2)
            out_text = gen_tokenizer.decode(out_ids[0], skip_special_tokens=True)
        except Exception:
            out_text = ""

    gen_texts.append(out_text)
    parsed = parse_flexible_output(out_text)
    parsed_list.append(parsed)

print("Generation + parsing done. Parsed count:", sum(1 for p in parsed_list if p is not None))


Generating on 500 examples


  0%|          | 0/500 [00:00<?, ?it/s]

Generation + parsing done. Parsed count: 498


In [20]:
# ===== Replacement Cell 11: Metrics and example display (robust) =====
from sklearn.metrics import classification_report, accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import numpy as np

# derive predicted labels and scores from parsed_list and gen_texts
pred_labels = []
pred_scores = []
for parsed, txt in zip(parsed_list, gen_texts):
    if parsed is None:
        # as a last resort, check raw text tokens for label words
        u = (txt or "").lower()
        if "sensitive" in u:
            pred_labels.append(1); pred_scores.append(1.0)
        elif "safe" in u:
            pred_labels.append(0); pred_scores.append(0.0)
        else:
            pred_labels.append(0); pred_scores.append(0.0)
    else:
        lbl = 1 if parsed.get("label","").upper().startswith("S") else 0
        pred_labels.append(lbl)
        try:
            score = float(parsed.get("score", 1.0 if lbl==1 else 0.0))
            score = max(0.0, min(1.0, score))
        except:
            score = 1.0 if lbl==1 else 0.0
        pred_scores.append(score)

y_true = np.array(true_labels)
y_pred = np.array(pred_labels)
y_score = np.array(pred_scores)

print("Counts: true positives:", int(y_true.sum()), "predicted positives:", int(y_pred.sum()))

# safe metric printing with zero_division handled
print("\nClassification report:")
print(classification_report(y_true, y_pred, digits=4, zero_division=0))

acc = accuracy_score(y_true, y_pred)
prec = precision_score(y_true, y_pred, zero_division=0)
rec = recall_score(y_true, y_pred, zero_division=0)
f1 = f1_score(y_true, y_pred, zero_division=0)
cm = confusion_matrix(y_true, y_pred)
print(f"Accuracy: {acc:.4f}  Precision: {prec:.4f}  Recall: {rec:.4f}  F1: {f1:.4f}")
print("Confusion matrix (rows=true, cols=pred):\n", cm)

# show example triples (true -> pred) and generated output
examples_to_show = min(6, len(gen_texts))
print("\nExample model outputs (true_label -> predicted_label) and text:")
shown = 0
for i in range(len(gen_texts)):
    if shown >= examples_to_show:
        break
    # prioritize showing cases where prediction != truth, otherwise show first few
    if y_pred[i] != y_true[i] or shown < 2:
        print("----")
        print("Index:", i)
        print("True label:", y_true[i], " Pred label:", y_pred[i], " Score:", round(y_score[i],3))
        print("Prompt (truncated):", val_raw[i]["prompt"][:300].replace("\n"," "))
        print("Model output:", gen_texts[i])
        print("Parsed:", parsed_list[i])
        shown += 1


Counts: true positives: 500 predicted positives: 498

Classification report:
              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         0
           1     1.0000    0.9960    0.9980       500

    accuracy                         0.9960       500
   macro avg     0.5000    0.4980    0.4990       500
weighted avg     1.0000    0.9960    0.9980       500

Accuracy: 0.9960  Precision: 1.0000  Recall: 0.9960  F1: 0.9980
Confusion matrix (rows=true, cols=pred):
 [[  0   0]
 [  2 498]]

Example model outputs (true_label -> predicted_label) and text:
----
Index: 0
True label: 1  Pred label: 1  Score: 1.0
Prompt (truncated): Instruction: Decide whether the following post is appropriate for children. Return a JSON with keys: label ('SENSITIVE' or 'SAFE'), score (0.0-1.0), reason (short).  Input: W 3000  Output:
Model output: SENSITIVE
Parsed: {'label': 'SENSITIVE', 'score': 1.0, 'reason': 'sensitive'}
----
Index: 1
True label: 1  Pred label: 1  Sc

In [None]:
# Cell 12 — Write sna_predict_llm.py helper for SNA team
helper_code = f'''
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import re, json, torch
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("{OUTPUT_DIR}")
model = AutoModelForSeq2SeqLM.from_pretrained("{OUTPUT_DIR}")
if DEVICE == "cuda":
    model = model.to("cuda")
gen = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=0 if DEVICE=="cuda" else -1)

def parse_json_output(s):
    import re, json
    m = re.search(r"\\{{.*\\}}", s, flags=re.S)
    if not m:
        return None
    try:
        return json.loads(m.group(0))
    except:
        try:
            txt = m.group(0).replace("'", '"')
            txt = re.sub(r",\\s*}", "}", txt)
            return json.loads(txt)
        except:
            return None

def predict_text_prob(text):
    prompt = ("Instruction:\\nDecide whether the following post is appropriate for children. "
              "Return a JSON with keys: label ('SENSITIVE' or 'SAFE'), score (0.0-1.0), reason (short).\\n\\n"
              f"Input: {{text}}\\n\\nOutput:")
    out = gen(prompt.format(text=text), max_length=64, do_sample=False)[0]["generated_text"]
    parsed = parse_json_output(out)
    if parsed is None:
        return 0.0, 1.0
    lbl = 1 if parsed.get("label","").upper().startswith("S") else 0
    score = float(parsed.get("score", 1.0 if lbl==1 else 0.0))
    return score, 1.0-score

def compute_node_risk(node_id, recent_post_risks, graph_feats, strike_count, alpha=1.0, beta=1.5):
    C_i = max(recent_post_risks) if recent_post_risks else 0.0
    I_i = graph_feats.get(node_id, {}).get("pagerank", 0.0)
    H_i = min(5, strike_count)
    R_i = C_i * (1 + alpha * I_i) * (1 + beta * H_i)
    return float(R_i)
'''
open("sna_predict_llm.py","w").write(helper_code)
print("Wrote sna_predict_llm.py — use predict_text_prob() and compute_node_risk() from it.")
