In [None]:
# Optional one-time environment setup. The `if False:` guard keeps these shell
# commands from running during a normal "Run All"; change it to True to install.
if False:
    !pip install --upgrade pip
    # NOTE(review): torch is pulled from the cu129 wheel index while the unsloth
    # extra below is tagged cu121 — confirm that mixing the two is intended.
    !pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu129
    !pip install "unsloth[cu121]"
    !pip install ipywidgets
    !pip install bitsandbytes
    !pip install orjson
    # !sudo dnf install python3.13-devel

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Hub identifier of the checkpoint used for every generation in this notebook.
# (The name suggests a 4-bit bitsandbytes build of Llama 3.2 3B Instruct — confirm.)
MODEL_ID = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"

# Tokenizer and model are module-level globals that `chat` reads directly.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# dtype="auto" / device_map="auto" leave dtype and device placement to the library.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, dtype="auto", device_map="auto"
)

# Put the module in evaluation mode; the notebook only runs inference.
model.eval()

def chat(user_text, max_new_tokens=64, temperature=0.0, do_sample=False):
    """Send a single user message to the global model and return its reply.

    Args:
        user_text: Text of the one user turn that makes up the conversation.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Forwarded to ``model.generate``.
        do_sample: Forwarded to ``model.generate`` (False means greedy decoding).

    Returns:
        The decoded continuation only (prompt tokens are sliced off and special
        tokens skipped), with surrounding whitespace stripped.
    """
    # Use plain-string content. A list of {"type": "text", ...} parts is the
    # multimodal message format; text-only chat templates may stringify such a
    # list and feed its Python repr to the model instead of the text itself.
    messages = [{"role": "user", "content": user_text}]

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=do_sample,
            # Reuse EOS as the padding id for generation.
            pad_token_id=tokenizer.eos_token_id
        )

    # generate() returns prompt + continuation; keep only the continuation.
    prompt_len = inputs["input_ids"].shape[-1]
    return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()

# print(chat("What is the capital of France?"))

In [None]:
import re
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any

# --- Search hyper-parameters (used as defaults by train_ucbr_prompts_with_logging) ---
EPOCHS = 12                # number of optimisation rounds
K = 3                      # number of arms requested from the bandit each round
C0 = 1.5                   # base exploration coefficient handed to UCBR
TRAIN_BS = 256             # examples sampled from the train set per prompt
TEST_BS  = 512             # examples sampled from the test set to score arms
MISMATCH_CAP = 8           # stop scanning a train batch after this many errors
NEW_PROMPTS_PER_ROUND = 1  # maximum number of prompts added per round
BATCH_WRITE = 4096         # flush / progress interval of the final evaluation loop

# Template used to classify one text with a candidate prompt.
PROMPT_EVAL_TMPL = """# Task:
{{prompt}}
# Output format:
Answer Yes or No as labels
# Prediction:
Text: {{text}}
Label:
"""

# Template asking the model why a prompt mis-handled an example.
PROMPT_GRAD_TMPL = """I am trying to write a prompt to identify attack vectors.
My current prompt is:
"{{prompt}}"
However, this prompt incorrectly processes the following example:
{{error_string}}
Please explain why this prompt might process this example incorrectly.
Reason:
"""

# Template asking the model for improved prompts wrapped in <START>/<END>.
PROMPT_UPDATE_TMPL = """I am trying to write a prompt to identify attack vectors.
My current prompt is: "{{prompt}}"
However, the problem with this prompt is: {{reason}}
Based on the above issue, Please write {{number_of_new_prompts}} different improved prompts.
Each prompt is wrapped with <START> and <END>.
These {{number_of_new_prompts}} new prompts are:
"""

# Generation settings for yes/no classification calls (short, greedy).
GEN_ARGS: Dict[str, Any] = {
    "max_new_tokens": 8,
    "temperature": 0.0,
    "do_sample": False
}

# Whole-word yes/no detectors. The leading \b matters because these patterns
# are applied with .search(): without it the trailing letter of any word
# ("probably", "certain") or digit of any token was accepted as an answer.
YES_PAT = re.compile(r"\b(yes|y|true|1)\b", re.I)
NO_PAT  = re.compile(r"\b(no|n|false|0)\b", re.I)
START_TAG, END_TAG = "<START>", "<END>"
# Captures the text between <START> and <END> (a closing </END> is accepted too).
EXTRACT_NEW_PROMPT = re.compile(r"<START>([\s\S]*?)<\/?END>", re.I)

In [None]:
import math, random
import numpy as np
from sklearn.metrics import f1_score, matthews_corrcoef, confusion_matrix

def fill_eval(prompt_str, text):
    """Render PROMPT_EVAL_TMPL for one candidate prompt and one input text.

    Placeholders are substituted in order: ``{{prompt}}`` first, then ``{{text}}``.
    """
    rendered = PROMPT_EVAL_TMPL
    for placeholder, value in (("{{prompt}}", prompt_str), ("{{text}}", text)):
        rendered = rendered.replace(placeholder, value)
    return rendered

def fill_grad(prompt_str, error_text):
    """Render PROMPT_GRAD_TMPL with the current prompt and a mis-handled example.

    ``{{prompt}}`` is substituted before ``{{error_string}}``.
    """
    with_prompt = PROMPT_GRAD_TMPL.replace("{{prompt}}", prompt_str)
    with_error = with_prompt.replace("{{error_string}}", error_text)
    return with_error

def fill_update(prompt_str, reason, k):
    """Render PROMPT_UPDATE_TMPL.

    Substitutes, in this order, the current prompt, the diagnosed reason and the
    requested number of new prompts (``k`` is converted with ``str``).
    """
    substitutions = (
        ("{{prompt}}", prompt_str),
        ("{{reason}}", reason),
        ("{{number_of_new_prompts}}", str(k)),
    )
    rendered = PROMPT_UPDATE_TMPL
    for placeholder, value in substitutions:
        rendered = rendered.replace(placeholder, value)
    return rendered

def render_eval_input(prompt_str: str, text: str) -> str:
    """Build the classification input for ``text``; thin alias of ``fill_eval``."""
    rendered: str = fill_eval(prompt_str=prompt_str, text=text)
    return rendered

def normalize_yesno(raw: str, inp: str = '') -> str:
    """Collapse raw model output to the label "yes" or "no".

    Only the first line of ``raw`` is inspected, after trimming backticks,
    quotes and spaces and lower-casing it. YES_PAT is tried before NO_PAT;
    empty or unrecognised output falls back to "no". ``inp`` is accepted for
    call-site compatibility and is not used.
    """
    if raw:
        head = raw.splitlines()[0].strip("`'\" ").lower()
        for label, pattern in (("yes", YES_PAT), ("no", NO_PAT)):
            if pattern.search(head):
                return label
    return "no"

def llm_yesno(prompt_str: str, text: str) -> str:
    """Classify ``text`` with ``prompt_str`` and return "yes" or "no"."""
    model_input = render_eval_input(prompt_str, text)
    return normalize_yesno(chat(model_input, **GEN_ARGS), model_input)

def gen_reason(prompt_str: str, error_text: str) -> str:
    """Ask the model why ``prompt_str`` mis-handles ``error_text``.

    Returns the model's explanation, generated greedily with up to 256 new tokens.
    """
    return chat(
        fill_grad(prompt_str, error_text),
        max_new_tokens=256,
        temperature=0.0,
        do_sample=False,
    )

def apply_reason_to_prompt(prompt_str: str, reason: str, k: int = 3) -> str:
    """Ask the model for improved prompts and return the first usable one.

    Args:
        prompt_str: The prompt being improved.
        reason: Explanation of what is wrong with ``prompt_str``.
        k: Number of rewrites requested from the model. Only one is returned.

    Returns:
        The first ``<START>…<END>`` candidate that is at least 8 characters
        long after stripping, with surrounding whitespace removed. If the
        reply has no such candidate, ``prompt_str`` with a fixed fallback
        instruction appended.
    """
    upd_inp = fill_update(prompt_str, reason, k)
    upd_out = chat(upd_inp, max_new_tokens=1024, temperature=0.0, do_sample=False)

    for raw_prompt in EXTRACT_NEW_PROMPT.findall(upd_out):
        new_prompt = raw_prompt.strip()
        if len(new_prompt) >= 8:
            # Return the stripped text that was just validated, not the raw
            # capture, which still carries the whitespace around the tags.
            return new_prompt

    return (prompt_str + " Only consider code that executes in a web browser context. Answer Yes or No.")


def sample_batch(dataset, n: int):
    """Return up to ``n`` examples from ``dataset``.

    When the dataset has ``n`` items or fewer, the dataset object itself is
    returned (no copy). Otherwise ``n`` items are drawn without replacement
    using the global ``random`` state.
    """
    needs_sampling = len(dataset) > n
    return random.sample(dataset, n) if needs_sampling else dataset

def eval_f1_batch(prompt_str: str, batch):
    """Score ``prompt_str`` on ``batch`` and return the F1 of the "yes" class.

    Each example must provide ``"text"`` and a string ``"label"``; labels are
    compared case-insensitively against "yes".
    """
    gold, predicted = [], []
    for example in batch:
        answer = llm_yesno(prompt_str, example["text"])
        predicted.append(int(answer == "yes"))
        gold.append(int(example["label"].lower() == "yes"))
    return f1_score(gold, predicted)

def evaluate_full(prompt_str: str, dataset):
    """Run ``prompt_str`` over every example and return summary metrics.

    Returns:
        Dict with ``acc``, ``f1``, ``mcc`` and the confusion counts ``tp``,
        ``fp``, ``tn``, ``fn`` (positive class = "yes").
    """
    gold, predicted = [], []
    for example in dataset:
        answer = llm_yesno(prompt_str, example["text"])
        predicted.append(int(answer == "yes"))
        gold.append(int(example["label"].lower() == "yes"))

    # With labels=[1, 0] the matrix rows/columns are ordered positive, negative.
    matrix = confusion_matrix(gold, predicted, labels=[1, 0])
    tp = int(matrix[0, 0])
    fn = int(matrix[0, 1])
    fp = int(matrix[1, 0])
    tn = int(matrix[1, 1])

    total = tp + tn + fp + fn
    return {
        "acc": (tp + tn) / max(1, total),  # guard against an empty dataset
        "f1": f1_score(gold, predicted),
        "mcc": matthews_corrcoef(gold, predicted),
        "tp": tp,
        "fp": fp,
        "tn": tn,
        "fn": fn,
    }

class UCBR:
    """Bandit over candidate prompts with randomised UCB-style selection.

    Attributes (all plain lists / numbers, mutated by the training loop):
        P:  the arms (prompt strings).
        s:  per-arm accumulated "success" mass, initialised to 1.0.
        f:  per-arm accumulated "failure" mass, initialised to 1.0.
        N:  update counter, starts at 1 and grows by one per ``update`` call.
        c0: base exploration coefficient.
    """

    def __init__(self, prompts, c0=C0):
        self.P = prompts[:]  # copy so the caller's list is not aliased
        self.s = [1.0] * len(self.P)
        self.f = [1.0] * len(self.P)
        self.N = 1
        self.c0 = c0

    def select_k(self, k: int):
        """Pick up to ``k`` distinct arm indices, weighted by their UCB score.

        Arms are drawn one at a time *without* replacement, so the result has
        exactly ``min(k, number_of_arms)`` indices. Drawing with replacement
        and de-duplicating afterwards would frequently return fewer arms than
        requested. An empty arm list yields ``[]``.
        """
        # Exploration weight shrinks as more updates are observed.
        c = self.c0 / math.sqrt(self.N)
        weights = []
        for si, fi in zip(self.s, self.f):
            t = si + fi
            score = (si / t) + c * math.sqrt(max(1e-9, math.log(self.N) / t))
            # random.choices needs strictly positive total weight.
            weights.append(max(1e-9, score))

        candidates = list(range(len(self.P)))
        chosen = []
        for _ in range(min(k, len(candidates))):
            pos = random.choices(range(len(candidates)), weights=weights, k=1)[0]
            chosen.append(candidates.pop(pos))
            weights.pop(pos)  # keep weights aligned with remaining candidates
        return chosen

    def update(self, i: int, score: float, batch_len: int):
        """Credit arm ``i`` with ``score`` (expected in [0, 1]) over ``batch_len`` examples."""
        self.s[i] += score * batch_len
        self.f[i] += (1.0 - score) * batch_len
        self.N += 1

In [None]:
import json, pathlib
from copy import deepcopy
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed

LOGDIR = pathlib.Path("./artifacts_gemma3")

def preview_tail(path, n=10):
    """Print the last ``n`` lines of the UTF-8 text file at ``path``.

    Any error (missing file, decoding problem, ...) is reported on stdout as
    ``(preview error) <message>`` instead of being raised.
    """
    try:
        content = pathlib.Path(path).read_text(encoding="utf-8")
        tail = content.splitlines()[-n:]
        for row in tail:
            print(row)
    except Exception as err:
        print(f"(preview error) {err}")


def train_ucbr_prompts_with_logging(base_prompt, train_set, test_set,
                                    train_bs=TRAIN_BS, test_bs=TEST_BS,
                                    mismatch_cap=MISMATCH_CAP, new_prompts_per_round=NEW_PROMPTS_PER_ROUND,
                                    k_select=K, epochs=EPOCHS,
                                    log_mismatch_from="test",
                                    mismatch_log_sample=64):
    """Iteratively grow and score a pool of prompts, logging every round.

    Each epoch:
      1. Existing prompts are run on a sampled train batch; the first
         misclassified example of a prompt is used to ask the model for a
         reason and then for a rewritten prompt, which joins the pool (at most
         ``new_prompts_per_round`` additions per epoch).
      2. The bandit picks arms among the prompts that existed at the start of
         the epoch; each is scored by F1 on a sampled test batch and the bandit
         is updated.
      3. The pool and per-arm F1 are appended to ``epoch_prompts.jsonl`` and
         sampled misclassifications to ``mismatches_sampled.jsonl`` in LOGDIR.

    Args:
        base_prompt: Seed prompt (arm 0).
        train_set, test_set: Lists of dicts with ``"text"`` and ``"label"``
            (and optionally ``"id"``).
        train_bs, test_bs: Sample sizes for the train / test batches.
        mismatch_cap: Stop scanning a train batch after this many errors.
        new_prompts_per_round: Maximum prompts added per epoch.
        k_select: Number of arms requested from the bandit per epoch.
        epochs: Number of epochs.
        log_mismatch_from: ``"test"`` logs errors from the epoch's test batch;
            anything else samples ``mismatch_log_sample`` train examples.
        mismatch_log_sample: Cap on logged errors per arm (and train sample size).

    Returns:
        ``(best, epoch_log_path, mismatch_log_path)`` where ``best`` is
        ``{"prompt": str, "f1": float}`` for the highest batch F1 seen.
    """
    EPOCH_PROMPTS_PATH = LOGDIR / "epoch_prompts.jsonl"
    MISMATCH_SAMPLES_PATH = LOGDIR / "mismatches_sampled.jsonl"

    P = [base_prompt]
    ucbr = UCBR(P, c0=C0)
    best = {"prompt": P[0], "f1": -1.0}

    # Make sure the log directory exists, then start both logs empty. Without
    # the mkdir a missing directory skipped the truncation and made the
    # append-mode opens below fail with FileNotFoundError.
    LOGDIR.mkdir(parents=True, exist_ok=True)
    EPOCH_PROMPTS_PATH.write_text("", encoding="utf-8")
    MISMATCH_SAMPLES_PATH.write_text("", encoding="utf-8")

    def _collect_mismatches_for_prompt(prompt_text, pool, cap, epoch):
        """Return up to ``cap`` log rows for examples ``prompt_text`` gets wrong."""
        rows = []
        for ex in pool:
            txt = ex.get("text", "")
            if not txt or not txt.strip():
                continue

            pred = llm_yesno(prompt_text, ex["text"])
            gold = "yes" if ex["label"].lower() == "yes" else "no"
            if pred != gold:
                # Keep the log compact: single-line preview of the first 220 chars.
                short = ex["text"][:220].replace("\n"," ")
                rows.append({
                    "epoch": epoch,
                    "prompt": prompt_text,
                    "id": ex.get("id"),
                    "label": gold,
                    "pred": pred,
                    "text_preview": short
                })
            if len(rows) >= cap:
                break
        return rows

    for t in range(1, epochs + 1):
        print(f"[epoch {t:02}] start  arms={len(P)}", flush=True)
        Pexp = P[:]

        # --- expansion: derive new prompts from training errors ---
        for p in P:
            # Once this round's quota is filled, further scans could not add
            # anything; stop instead of paying for discarded model calls.
            if (len(Pexp) - len(P)) >= new_prompts_per_round:
                break

            train_batch = sample_batch(train_set, train_bs)

            mismatches = []
            for ex in train_batch:
                pred = llm_yesno(p, ex["text"])
                gold = "yes" if ex["label"].lower() == "yes" else "no"
                if pred != gold:
                    mismatches.append({"id": ex.get("id"), "text": ex["text"], "gold": gold, "pred": pred})
                if len(mismatches) >= mismatch_cap:
                    break

            if mismatches:
                # Only the first error drives the rewrite.
                reason = gen_reason(p, mismatches[0]["text"])
                p_new  = apply_reason_to_prompt(p, reason, k = new_prompts_per_round)
                Pexp.append(p_new)

        # --- selection & scoring ---
        # The bandit still holds the pre-expansion pool, so indices are valid
        # for Pexp (which only appends) and new prompts compete from next epoch.
        idxs = ucbr.select_k(k_select)
        test_batch = sample_batch(test_set, test_bs)
        print(f"[epoch {t:02}] select={idxs} n_test={len(test_batch)}", flush=True)

        f1_by_arm = {}
        for i in idxs:
            p_cur = Pexp[i]
            f1 = eval_f1_batch(p_cur, test_batch)
            f1_by_arm[i] = f1
            ucbr.update(i, f1, len(test_batch))
            if f1 > best["f1"]:
                best = {"prompt": p_cur, "f1": f1}

        # --- logging: one row per arm in the expanded pool ---
        with open(EPOCH_PROMPTS_PATH, "a", encoding="utf-8") as fw:
            for i, ptxt in enumerate(Pexp):
                rec = {
                    "epoch": t,
                    "arm_index": i,
                    "prompt": ptxt,
                    "selected": int(i in idxs),
                    "f1_if_selected": f1_by_arm.get(i, None),
                }
                fw.write(json.dumps(rec, ensure_ascii=False) + "\n")

        # NOTE(review): with log_mismatch_from == "test" this re-runs the model
        # on the batch eval_f1_batch has just scored.
        for i in idxs:
            p_cur = Pexp[i]
            pool = test_batch if log_mismatch_from == "test" else sample_batch(train_set, mismatch_log_sample)
            rows = _collect_mismatches_for_prompt(p_cur, pool, mismatch_log_sample, t)
            if rows:
                with open(MISMATCH_SAMPLES_PATH, "a", encoding="utf-8") as fw:
                    for r in rows:
                        fw.write(json.dumps(r, ensure_ascii=False) + "\n")

        # --- commit the expanded pool and give new arms their priors ---
        P = Pexp
        ucbr.P = P
        if len(ucbr.s) < len(P):
            add = len(P) - len(ucbr.s)
            ucbr.s += [1.0] * add
            ucbr.f += [1.0] * add

    return best, str(EPOCH_PROMPTS_PATH), str(MISMATCH_SAMPLES_PATH)

In [None]:
import json, random, pathlib, numpy as np
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed

def set_seed(seed=7):
    """Seed Python's ``random`` module and NumPy's global RNG with ``seed``."""
    for seeder in (random.seed, np.random.seed):
        seeder(seed)


set_seed(777)

def load_jsonl(p):
    """Lazily yield one parsed JSON value per non-blank line of the file at ``p``.

    The file is opened as UTF-8 when iteration starts and closed once the
    generator is exhausted.
    """
    with open(p, "r", encoding="utf-8") as handle:
        yield from (json.loads(raw) for raw in handle if raw.strip())

# Locations of the two JSONL splits: DEV is loaded as `train`, EVAL as `test`.
# NOTE(review): absolute, machine-specific paths — consider a configurable data dir.
DEV_JSONL  = "/home/user/llm/dataset_xsshield/dev.jsonl" 
EVAL_JSONL = "/home/user/llm/dataset_xsshield/eval.jsonl"

try:
    # Materialise the generators so the datasets can be sampled and indexed.
    train = list(load_jsonl(DEV_JSONL))
    test  = list(load_jsonl(EVAL_JSONL))
except FileNotFoundError:
    print(f"[ERROR] Data files not found. Please check paths: {DEV_JSONL} and {EVAL_JSONL}")
    # Fallback so the rest of the notebook still runs: ten references to one
    # placeholder dict, all labelled "no"; `test` is the very same list as `train`.
    # NOTE(review): every metric computed on this fallback is meaningless.
    train = [{"text": "test data for error", "label": "no"}] * 10
    test = train

print(f"[data] train={len(train)}  test={len(test)}")
# Label distribution of the test split (labels lower-cased before counting).
print("[data] test label dist:", Counter([x["label"].lower() for x in test]))

def clean_dataset_in_place(dataset, name="Dataset"):
    """Return a new list with only the items whose ``"text"`` is a non-blank string.

    Despite the name, ``dataset`` itself is not modified; callers must use the
    returned list. Items with a missing, ``None``, non-string or
    whitespace-only ``"text"`` are dropped instead of raising.

    Args:
        dataset: Iterable of dict-like examples.
        name: Human-readable dataset name. Currently unused; kept so existing
            call sites keep working.

    Returns:
        A new list containing the retained items, in their original order.
    """
    def _has_text(item):
        value = item.get("text")
        return isinstance(value, str) and bool(value.strip())

    return [item for item in dataset if _has_text(item)]

# Drop examples without usable text; the helper returns a new list, so rebind.
train = clean_dataset_in_place(train, name="Train set")
test = clean_dataset_in_place(test, name="Test set")

# Initial prompt (arm 0) that the optimisation loop tries to improve.
SEED_PROMPT = "Is This JavaScript code slice XSS attack payload?"

# The training routine writes its JSONL logs under LOGDIR.
LOGDIR.mkdir(parents=True, exist_ok=True)

# Run the prompt search with the notebook-level hyper-parameters.
# `best` is {"prompt", "f1"}; the two paths point at the epoch and mismatch logs.
best, ep_path, mm_path = train_ucbr_prompts_with_logging(
    SEED_PROMPT, train, test,
    train_bs=TRAIN_BS, test_bs=TEST_BS,
    mismatch_cap=MISMATCH_CAP, new_prompts_per_round=NEW_PROMPTS_PER_ROUND,
    k_select=K, epochs=EPOCHS,
    log_mismatch_from="test",
    mismatch_log_sample=64,
)
# F1 here was measured on a sampled test batch during the search.
print(f"[best F1 est] {best['f1']}")
print(f"[saved] {ep_path}")
print(f"[saved] {mm_path}")

import hashlib

# In-memory memo of chat() replies, keyed by a digest of (input, kwargs).
CACHE = {}
def chat_cached(inp, **kw):
    """Memoised wrapper around ``chat``.

    The cache key is a SHA-256 digest of the input text *and* the generation
    keyword arguments. Keying on ``hash(inp)`` alone would return a reply
    produced with different settings and could hand back another input's
    reply on a hash collision.

    Args:
        inp: Full model input text.
        **kw: Keyword arguments forwarded unchanged to ``chat``.

    Returns:
        The (possibly cached) reply string from ``chat``.
    """
    fingerprint = repr((inp, sorted(kw.items())))
    key = hashlib.sha256(fingerprint.encode("utf-8")).hexdigest()
    if key in CACHE:
        return CACHE[key]
    out = chat(inp, **kw)
    CACHE[key] = out
    return out

def worker_llm_yesno(prompt_str, text):
    """Classify ``text`` via the cached chat path; blank or missing text is "no"."""
    if text and text.strip():
        model_input = render_eval_input(prompt_str, text)
        reply = chat_cached(model_input, **GEN_ARGS)
        return normalize_yesno(reply)
    return "no"

# Output artefacts of the final evaluation, all under LOGDIR.
BEST_PATH = LOGDIR / "best_prompt.txt"                      # winning prompt text
PARTIAL_FILE = LOGDIR / "predictions_eval_partial.jsonl"    # append-only, used for resuming
FINAL_FILE   = LOGDIR / "predictions_eval.jsonl"            # sorted copy written at the end

LOGDIR.mkdir(parents=True, exist_ok=True) 

# Persist the best prompt found by the search.
best_prompt_str = best["prompt"]
BEST_PATH.write_text(best_prompt_str, encoding="utf-8")


# Repeats the summary printed right after training, with the F1 formatted to 4 decimals.
print(f"[best F1 est] {best['f1']:.4f} (Note: F1 is from final eval or last batch if resuming)")
print(f"[saved] {ep_path}")
print(f"[saved] {mm_path}")

# --- Resume support: pick up after the highest index already written ---
# NOTE(review): the partial file does not record which prompt produced it, so
# a file left over from a run with a different best prompt is resumed as-is.
start_idx = 0
existing = {}
if PARTIAL_FILE.exists():
    print("[resume] loading partial preds...", flush=True)
    with open(PARTIAL_FILE, "r", encoding="utf-8") as f:
        for line in f:
            try:
                obj = json.loads(line)
                existing[int(obj["_idx"])] = obj
            except (ValueError, KeyError, TypeError):
                # Malformed or truncated row (e.g. from an interrupted write):
                # skip it. Catching only data errors keeps Ctrl-C working.
                continue
    start_idx = max(existing.keys()) + 1 if existing else 0
    print(f"[resume] loaded {len(existing)} rows; will start at idx={start_idx}", flush=True)

n = len(test)
best_prompt_str = best["prompt"]
print(f"[eval] total={n}  start_idx={start_idx}  mode=SEQUENTIAL", flush=True)

# --- Sequential evaluation of the best prompt over the remaining test rows ---
processed_count = 0
failed_count = 0  # rows whose inference raised and were recorded as "no"

with open(PARTIAL_FILE, "a", encoding="utf-8") as outf:
    for idx in range(start_idx, n):
        ex_obj = test[idx]

        if not ex_obj.get("text", "").strip():
            # Nothing to classify: default to the negative label.
            pred = "no"
        else:
            try:
                pred = worker_llm_yesno(best_prompt_str, ex_obj["text"])
            except Exception as e:
                # Best effort: keep going, but leave a trace so systematic
                # failures (e.g. out-of-memory) do not pass silently as "no".
                failed_count += 1
                print(f"[WARN] idx={idx}: inference failed ({type(e).__name__}: {e}); recording 'no'", flush=True)
                pred = "no"

        rec = {"_idx": idx, "label": ex_obj["label"].lower(), "pred": pred}
        outf.write(json.dumps(rec, ensure_ascii=False) + "\n")

        processed_count += 1
        if processed_count % BATCH_WRITE == 0:
            # Periodic flush bounds how much work is lost if the run is interrupted.
            outf.flush()
            print(f"[INFO] Processed {processed_count} samples (Total {idx+1}/{n})...", flush=True)

if failed_count:
    print(f"[WARN] {failed_count} sample(s) failed during inference and were recorded as 'no'.", flush=True)
print("[eval] finished writing partial file. Now aggregate to final.", flush=True)

# --- Aggregate the partial predictions into the final file and metrics ---
rows = []
skipped_rows = 0
with open(PARTIAL_FILE, "r", encoding="utf-8") as f:
    for line in f:
        try:
            row = json.loads(line)
        except ValueError:
            # Unparseable line (e.g. truncated by an interrupted run).
            skipped_rows += 1
            continue
        # Keep only complete records so sorting and metric code cannot fail on
        # a missing key.
        if isinstance(row, dict) and {"_idx", "label", "pred"} <= row.keys():
            rows.append(row)
        else:
            skipped_rows += 1
if skipped_rows:
    print(f"[WARN] ignored {skipped_rows} malformed row(s) in {PARTIAL_FILE}", flush=True)
rows.sort(key=lambda x: x["_idx"])

y_true = [1 if r["label"]=="yes" else 0 for r in rows]
y_pred = [1 if r["pred"]=="yes" else 0 for r in rows]

if len(y_true) > 0:
    # labels=[1, 0] orders the matrix as positive, negative.
    cm = confusion_matrix(y_true, y_pred, labels=[1,0])
    tp, fn, fp, tn = int(cm[0,0]), int(cm[0,1]), int(cm[1,0]), int(cm[1,1])
    acc = (tp+tn)/len(y_true)
    f1  = f1_score(y_true, y_pred)
    mcc = matthews_corrcoef(y_true, y_pred)

    # Final file = the usable rows, sorted by example index.
    with open(FINAL_FILE, "w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")

    print(f"\n[FINAL] acc={acc:.4f}  f1={f1:.4f}  mcc={mcc:.4f}")
    print(f"[CM] tp={tp} fp={fp} tn={tn} fn={fn}")
    print(f"[saved] {BEST_PATH}")
    print(f"[saved] {FINAL_FILE}")
else:
    print("\n[FINAL] Not enough data processed for final metrics.")