In [1]:
!pip -q install datasets transformers scipy tqdm --upgrade


**MS MARCO & FiQA & NQ**

This notebook presents an ablation study on DeRAG attacks using Differential Evolution (DE) under a controlled setting. Specifically, we evaluate how suffix length (`L ∈ {1..10}`) and retrieval depth (`Top-K ∈ {1, 10}`) affect the effectiveness of adversarial prompt suffixes.

###  Experiment Setup

- **Corpus**: Random sample of 1,000 documents per dataset
- **Queries**: 120 filtered questions per dataset (must contain `?` and have more than 20 and fewer than 500 WordPiece tokens, not counting `[CLS]`/`[SEP]`)
- **Target Document**: Document ranked **800th** under baseline retrieval
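- **Encoder**: `bert-base-uncased`; the query + suffix is encoded in `fp16` on GPU, while the corpus `[CLS]` embeddings are precomputed once on CPU with a separately loaded model at its default precision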
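- **Suffix**: L ∈ {1, …, 10} tokens written into the last L positions of the 512-token padded input (after the query's `[SEP]` and padding), with their attention-mask entries switched on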
 

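Each query is paired with one target passage: the document ranked **800th** for that query under baseline retrieval. The goal is to use a DE-optimized suffix to **promote** that document into the **Top-K** positions. Suffixes are optimized with a **hinge loss** on the cosine similarity of `[CLS]` embeddings, $\max(0,\; s_K - s_{\text{tgt}})$. Here $s_{\text{tgt}}$ is the similarity between the adversarial query and the target, and $s_K$ is the K-th highest similarity among all other documents, so the loss reaches zero once the target is in the Top-K. Targets longer than 510 tokens are skipped.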

###  Evaluation Metrics
For each `(query, L, K)` setting, we report:
- **Suffix tokens** (`suffix`, `suffix_len`)
- **Final rank of target document** (`final_rank`)
- **Change in cosine similarity** (`Δcos`)
- **Change in nDCG** (`ΔnDCG@K`)
- **Success rate** (whether `rank ≤ K`)
- **Objective evaluations** used by the DE optimizer (`iter_used`, SciPy's `nfev`)

###  Output
Results are saved to:
```text
exp_results/{dataset}/ablation/ablation_results.csv
```

In [8]:


import os, random, math, json, warnings, numpy as np, pandas as pd, torch, tqdm
from datasets     import load_dataset
from transformers import BertTokenizer, BertModel
from scipy.optimize import differential_evolution
warnings.filterwarnings("ignore")

# Experiment configuration for the suffix-length (L) / Top-K ablation.
SEED             = 42
DEVICE           = "cuda" if torch.cuda.is_available() else "cpu"
N_DOCS, N_Q      = 1_000, 120     # sampled corpus size, queries per dataset
TAIL_MAX         = 10             # suffix lengths L = 1..TAIL_MAX are swept
RANK_TARGET      = 800            # baseline rank of the document to promote
POP, MAXITER     = 20, 2_000      # differential_evolution popsize / maxiter
PATIENCE         = 20             # stale DE generations before early stop
BATCH_CLS        = 16             # batch size for corpus [CLS] encoding
# BEIR dataset names. They are used both as the HF id suffix ("BeIR/<name>",
# canonically lower-case, e.g. "BeIR/nq") and as the output directory name;
# the plotting cell reads exp_results/nq/ablation/..., so "nq" must be lower-case.
DATASETS         = ["msmarco", "fiqa", "nq"]
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

# Shared WordPiece tokenizer. VOC is the upper bound (exclusive) of the token
# ids that differential evolution may pick for the suffix.
tok  = BertTokenizer.from_pretrained("bert-base-uncased")
VOC  = tok.vocab_size


def enc(txts, dev):
    """Tokenize a batch of strings into fixed-length (512) BERT inputs on device `dev`.

    Sequences are truncated or padded to exactly 512 tokens; ``de_opt`` later
    writes the adversarial suffix into the trailing positions of this layout.
    """
    batch = tok(
        txts,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding="max_length",
    )
    return batch.to(dev)

def cos_row(x, Y):
    """Cosine similarity between one row vector ``x`` ([1, d]) and every row of ``Y`` ([n, d]).

    ``Y`` is moved to ``x``'s device first; returns a 1-D tensor of length n.
    """
    rows = Y.to(x.device)
    return torch.nn.functional.cosine_similarity(x.expand_as(rows), rows, dim=1)

def dcg(rank:int)->float:
    """DCG gain of a single relevant item at 1-based ``rank``: 1 / log2(rank + 1)."""
    discount = math.log2(rank + 1)
    return 1 / discount

def is_good_q(t:str)->bool:
    """Keep questions: text contains '?' and has 21..499 WordPiece tokens (excluding [CLS]/[SEP])."""
    if "?" not in t:
        return False
    n_tokens = len(tok(t)["input_ids"]) - 2  # drop [CLS] and [SEP]
    return 20 < n_tokens < 500

def suffix_str(ids):
    """Render suffix token ids as space-separated WordPiece tokens (for logging/CSV)."""
    pieces = tok.convert_ids_to_tokens(ids)
    return " ".join(pieces)

def de_opt(ids, msk, CP, tgt_cls, L, topk):
    """Search an L-token suffix with differential evolution that pushes a target doc into the top-k.

    The suffix occupies the last ``L`` positions of the 512-token input and their
    attention-mask entries are set to 1. The objective is the hinge loss
    ``max(0, s_k - s_tgt)``. ``s_tgt`` is the cosine between the patched query's
    [CLS] vector and ``tgt_cls``. ``s_k`` is the k-th largest cosine against the
    distractors ``CP``.

    Args:
        ids, msk: 1-D ``input_ids`` / ``attention_mask`` of one query (length 512, from ``enc``).
        CP: [N-1, dim] [CLS] embeddings of every non-target document. ``cos_row``
            moves them to the model's device. Passing them there already avoids a
            host-to-device copy on every objective evaluation.
        tgt_cls: [1, dim] [CLS] embedding of the target document.
        L: suffix length.
        topk: k in the hinge loss.

    Returns:
        ``(suffix_ids, cls, nfev)`` where:
        - ``suffix_ids`` are the chosen token ids.
        - ``cls`` is the [1, dim] [CLS] of the patched query.
        - ``nfev`` is the number of objective evaluations (``res.nfev``).

    Raises:
        ValueError: if the suffix slots overlap unmasked (real) query tokens.
    """
    pos = list(range(512 - L, 512))
    bounds = [(0, VOC - 1)] * L
    if bool((msk[pos] != 0).any()):
        # Writing the suffix here would silently clobber part of the query.
        raise ValueError(f"suffix of length {L} would overwrite query tokens")

    def encode_with_suffix(suffix):
        # Patch the tail slots, enable attention on them, and return the [CLS] vector.
        p, m = ids.clone(), msk.clone()
        m[pos] = 1
        for i, t in zip(pos, suffix):
            p[i] = t
        # Pure inference: the loss is read via .item(), so no autograd graph is needed.
        with torch.no_grad():
            return bert(input_ids=p.unsqueeze(0),
                        attention_mask=m.unsqueeze(0)).last_hidden_state[:, 0, :]

    def loss_from_cls(c):
        kth = torch.topk(cos_row(c, CP), topk).values[-1]
        sim = torch.nn.functional.cosine_similarity(c, tgt_cls)[0]
        return max(0., (kth - sim).item())

    def obj(v):
        # DE searches a continuous box; round each coordinate to a token id.
        return loss_from_cls(encode_with_suffix([int(round(x)) for x in v]))

    best, stale = 1e9, 0

    def cb(xk, convergence=None):
        # Older SciPy passes `convergence` as a keyword and newer SciPy positionally,
        # so accept both. Returning True stops DE. Stop on zero loss or after
        # PATIENCE generations without improvement.
        nonlocal best, stale
        cur = obj(xk)
        if cur < best:
            best, stale = cur, 0
        else:
            stale += 1
        return cur == 0 or stale >= PATIENCE

    res = differential_evolution(obj, bounds, popsize=POP, maxiter=MAXITER,
                                 tol=0, polish=False, seed=SEED, callback=cb)
    suf = [int(round(x)) for x in res.x]
    return suf, encode_with_suffix(suf), res.nfev

def prepare(ds:str):
    """Load a BEIR dataset, sample N_DOCS documents and N_Q queries, and embed the documents.

    Returns:
        docs: list of N_DOCS corpus rows (dicts), sampled without replacement.
        qs: list of N_Q query rows that pass ``is_good_q``.
        C_CLS: [N_DOCS, hidden] tensor of document [CLS] embeddings, computed on
            CPU with a separately loaded BERT (default dtype).

    Raises:
        ValueError: if fewer than N_Q queries pass the filter.
    """
    corpus  = load_dataset(f"BeIR/{ds}", "corpus",  split="corpus")
    queries = load_dataset(f"BeIR/{ds}", "queries", split="queries")
    # Sample row indices rather than materialising the whole corpus as a list of
    # dicts (MS MARCO has millions of passages). random.sample draws the same
    # indices for any sequence of equal length, so the sampled docs are unchanged.
    doc_idx = random.sample(range(len(corpus)), N_DOCS)
    docs = [corpus[i] for i in doc_idx]

    pool = [q for q in queries if is_good_q(q["text"])]
    if len(pool) < N_Q:
        raise ValueError(f"{ds}: only {len(pool)} queries pass the filter, need {N_Q}")
    qs = random.sample(pool, N_Q)

    cpu_bert = BertModel.from_pretrained("bert-base-uncased").eval().to("cpu")
    CLS = []
    with torch.no_grad():
        for i in tqdm.tqdm(range(0, N_DOCS, BATCH_CLS), desc=f"{ds}-CLS"):
            batch = [d["text"] for d in docs[i:i + BATCH_CLS]]
            CLS.append(cpu_bert(**enc(batch, "cpu")).last_hidden_state[:, 0, :])
    C_CLS = torch.cat(CLS)
    del cpu_bert
    torch.cuda.empty_cache()

    return docs, qs, C_CLS

def run_ablation(ds:str):
    """Sweep Top-K in {1, 10} and suffix length L in 1..TAIL_MAX for dataset ``ds``.

    Per query, the document ranked RANK_TARGET under the clean query is the attack
    target. ``de_opt`` searches an L-token suffix, and the target's new rank and
    metric deltas are recorded. Results are written to
    exp_results/{ds}/ablation/ablation_results.csv. The file is rewritten after
    every (K, L) sweep, so a crash does not lose completed sweeps.
    """
    out_dir = f"exp_results/{ds}/ablation"
    os.makedirs(out_dir, exist_ok=True)
    out_csv = f"{out_dir}/ablation_results.csv"

    docs, qs, C_CLS = prepare(ds)

    global bert
    # Half precision only on GPU; several CPU kernels lack float16 support.
    dtype = torch.float16 if DEVICE == "cuda" else torch.float32
    bert = BertModel.from_pretrained("bert-base-uncased",
                                     torch_dtype=dtype
                                    ).to(DEVICE).eval()

    rec = []
    for K in (1, 10):
        for L in range(1, TAIL_MAX + 1):
            for q in tqdm.tqdm(qs, desc=f"{ds}  K={K} L={L}"):
                qtxt = q["text"]
                # Tokenise once; reuse for the clean forward pass and for DE.
                q_enc = enc([qtxt], DEVICE)
                ids, msk = q_enc["input_ids"][0], q_enc["attention_mask"][0]
                with torch.no_grad():
                    qcls = bert(**q_enc).last_hidden_state[:, 0, :]

                sims = cos_row(qcls.cpu(), C_CLS)
                order = torch.argsort(sims, descending=True)
                tgt = int(order[RANK_TARGET - 1])
                tgt_txt = docs[tgt]["text"]
                # Skip very long targets (> 510 tokens incl. [CLS]/[SEP]).
                if len(tok(tgt_txt)["input_ids"]) > 510:
                    continue

                baseline_sim = sims[tgt].item()
                orig_top1    = docs[int(order[0])]["text"]
                tgt_cls      = C_CLS[tgt:tgt + 1].to(DEVICE)
                # Distractors = every doc except the target. Move them to the model
                # device once, not on every objective evaluation inside de_opt.
                CP           = C_CLS[[i for i in range(N_DOCS) if i != tgt]].to(DEVICE)

                suf, cls, it = de_opt(ids, msk, CP, tgt_cls, L, K)
                new_sims = cos_row(cls.cpu(), C_CLS)
                # 1-based rank: number of docs strictly more similar, plus one.
                fr = (new_sims > new_sims[tgt]).sum().item() + 1

                rec.append(dict(dataset=ds, topK=K, L=L,
                                query=qtxt,
                                target_excerpt=tgt_txt[:120].replace("\n", " "),
                                orig_top1_excerpt=orig_top1[:120].replace("\n", " "),
                                suffix=suffix_str(suf),
                                suffix_len=len(suf),
                                suffix_token_ids=json.dumps(suf),
                                baseline_rank=RANK_TARGET,
                                final_rank=fr,
                                delta_rank=RANK_TARGET - fr,
                                delta_cos=new_sims[tgt].item() - baseline_sim,
                                delta_ndcg=dcg(fr) - dcg(RANK_TARGET),
                                success=int(fr <= K),
                                iter_used=it))  # objective evaluations (res.nfev)
                torch.cuda.empty_cache()
            # Checkpoint after each (K, L) sweep; the final write below is identical.
            pd.DataFrame(rec).to_csv(out_csv, index=False, encoding="utf-8")

    pd.DataFrame(rec).to_csv(out_csv, index=False, encoding="utf-8")
    print(f"✓ {ds} ablation FINISHED — rows = {len(rec)}")

if __name__ == "__main__":
    # Entry point: run the full ablation sweep for each configured dataset in turn.
    os.makedirs("exp_results", exist_ok=True)
    for ds in DATASETS:
        run_ablation(ds)
        # Drop cached GPU blocks before the next dataset reloads the model.
        torch.cuda.empty_cache()


msmarco-CLS: 100%|██████████████████████████████████████████████████████████████████████| 63/63 [03:13<00:00,  3.08s/it]
msmarco  K=1 L=1: 100%|███████████████████████████████████████████████████████████████| 120/120 [12:51<00:00,  6.43s/it]
msmarco  K=1 L=2: 100%|███████████████████████████████████████████████████████████████| 120/120 [31:14<00:00, 15.62s/it]
msmarco  K=1 L=3: 100%|███████████████████████████████████████████████████████████████| 120/120 [45:53<00:00, 22.94s/it]
msmarco  K=1 L=4: 100%|███████████████████████████████████████████████████████████████| 120/120 [57:14<00:00, 28.62s/it]
msmarco  K=1 L=5: 100%|█████████████████████████████████████████████████████████████| 120/120 [1:12:20<00:00, 36.17s/it]
msmarco  K=1 L=6: 100%|█████████████████████████████████████████████████████████████| 120/120 [1:31:35<00:00, 45.79s/it]
msmarco  K=1 L=7: 100%|█████████████████████████████████████████████████████████████| 120/120 [1:50:54<00:00, 55.45s/it]
msmarco  K=1 L=8: 100%|█████████

✓ msmarco ablation 完成 — rows = 2400





#  DeRAG Loss Function Comparison

This repository evaluates the effectiveness of **Differential Evolution (DE)** optimized tail-patch attacks under different **loss objectives** (`hinge` vs. `cosine`) across four BEIR datasets.

##  Datasets

- **MS MARCO** (Open-domain QA)
- **FiQA** (Financial QA)
- **SciFact** (Scientific Fact Verification)
- **FEVER** (Fact Extraction and Verification)

## Experiment Setup

- **Corpus**: 1,000 documents per dataset
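- **Queries**: 100 queries per dataset, drawn from queries that have at least one positive qrel (sampled with replacement if fewer than 100 qualify)
  - FiQA: must contain a `?` and have more than 15 and fewer than 500 WordPiece tokens (excluding `[CLS]`/`[SEP]`)
  - FEVER: more than 10 and fewer than 500 WordPiece tokens; no `?` requirement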
- **Target Document**: The **100th-ranked** document under baseline similarity (cosine).
- **Model**: `bert-base-uncased`
- **Device**: GPU with `float16` for inference
- **Optimization**: Differential Evolution (DE)
  - Population: 20
  - Max Iterations: 2,000
  - Suffix Length: `L = 5` tokens

##  Objective

For each `(query, target document)` pair, we use DE to generate a **5-token suffix** that is appended to the query to **promote the target document's rank**.

We compare two loss modes:

- `hinge`: Max-margin loss between target and top distractor
- `cos`: Negative cosine similarity to the target document only (distractors are ignored)

##  Evaluation Metrics

For each query and loss mode:

- `suffix` (token sequence)
- `suffix_len`
- `suffix_token_ids`
- `baseline_rank` (always 100)
- `final_rank`
- `delta_rank`
- `delta_cos`
- `delta_nDCG`
- `success` (whether final rank == 1)
- `iter_used` (number of objective evaluations, SciPy's `nfev`)


In [46]:

import os
import random
import math
import json
import warnings

import numpy as np
import pandas as pd
import torch
import tqdm
from datasets import load_dataset
from transformers import BertTokenizer, BertModel
from scipy.optimize import differential_evolution

# --- Loss-comparison (hinge vs. cos) configuration: FEVER -------------------
warnings.filterwarnings("ignore")

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
N_DOCS = 1_000      # documents in the sampled corpus
N_Q = 100           # queries per dataset
TAIL_L = 5          # suffix length in tokens
RANK_TARGET = 100   # baseline rank of the document to promote
POP = 20            # differential_evolution popsize
MAXITER = 2_000     # differential_evolution maxiter
PATIENCE = 20       # stale DE generations before early stop
BATCH_CLS = 16      # batch size for corpus [CLS] encoding

# Shared WordPiece tokenizer; VOC bounds the token ids DE may pick for the suffix.
tok = BertTokenizer.from_pretrained("bert-base-uncased")
VOC = tok.vocab_size 

def enc(txts, dev):
    """Tokenize ``txts`` to padded/truncated 512-token BERT tensors and move them to ``dev``."""
    encoded = tok(txts, padding="max_length", truncation=True,
                  max_length=512, return_tensors="pt")
    return encoded.to(dev)

def get_ids_msk(text, dev):
    """Return the 1-D ``input_ids`` and ``attention_mask`` (length 512) of a single text."""
    encoded = enc([text], dev)
    input_ids, attention_mask = encoded["input_ids"], encoded["attention_mask"]
    return input_ids[0], attention_mask[0]

def cos_row(x, Y):
    """Cosine similarity of row vector ``x`` ([1, d]) against each row of ``Y`` ([n, d]); returns shape [n]."""
    others = Y.to(x.device)
    broadcast_x = x.expand_as(others)
    return torch.nn.functional.cosine_similarity(broadcast_x, others, dim=1)

def dcg(rank: int) -> float:
    """Single-item DCG gain at 1-based ``rank``."""
    log_discount = math.log2(1 + rank)
    return 1.0 / log_discount

def is_good_q(text: str) -> bool:
    """FEVER query filter: 11..499 WordPiece tokens (excluding [CLS]/[SEP]).

    Unlike the other filters in this notebook, no '?' is required, presumably
    because FEVER queries are claims rather than questions.
    """
    token_count = len(tok(text)["input_ids"]) - 2
    return 10 < token_count < 500

def suffix_str(ids):
    """Human-readable suffix: its WordPiece tokens joined by single spaces."""
    return " ".join(token for token in tok.convert_ids_to_tokens(ids))

def de_opt(ids, msk, CP, tgt_cls, loss_mode="hinge", L=None, topk=1):
    """Search a suffix with differential evolution that raises a target document's rank.

    The suffix occupies the last ``L`` positions of the 512-token input and their
    attention-mask entries are set to 1.

    Args:
        ids, msk: 1-D ``input_ids`` / ``attention_mask`` of one query (length 512,
            as returned by ``get_ids_msk``).
        CP: [N-1, dim] [CLS] embeddings of all non-target documents. ``cos_row``
            moves them to the model's device. Passing them there already avoids a
            copy per objective evaluation.
        tgt_cls: [1, dim] [CLS] embedding of the target document.
        loss_mode: one of two objectives.
            - ``"hinge"``: ``max(0, s_k - s_tgt)``, where ``s_k`` is the k-th
              highest distractor cosine.
            - ``"cos"``: ``-s_tgt`` (distractors are ignored).
        L: suffix length; defaults to ``TAIL_L``.
        topk: k used by the hinge loss (default 1: beat the best distractor).

    Returns:
        ``(suffix_ids, cls, nfev)``:
        - ``suffix_ids`` are the chosen token ids.
        - ``cls`` is the [1, dim] [CLS] of the patched query.
        - ``nfev`` is the number of objective evaluations (``res.nfev``).

    Raises:
        ValueError: unknown ``loss_mode``, or the suffix slots overlap real query tokens.
    """
    if loss_mode not in ("hinge", "cos"):
        # Previously any other string silently fell back to the hinge loss.
        raise ValueError(f"unknown loss_mode: {loss_mode!r}")
    if L is None:
        L = TAIL_L
    pos = list(range(512 - L, 512))
    bounds = [(0, VOC - 1)] * L
    if bool((msk[pos] != 0).any()):
        raise ValueError(f"suffix of length {L} would overwrite query tokens")

    def encode_with_suffix(suffix):
        # Patch the tail slots, enable attention on them, return the [CLS] vector.
        p, m = ids.clone(), msk.clone()
        for i, t in zip(pos, suffix):
            p[i] = t
        m[pos] = 1
        # Inference only: no autograd graph is needed for a scalar read via .item().
        with torch.no_grad():
            return bert(input_ids=p.unsqueeze(0),
                        attention_mask=m.unsqueeze(0)).last_hidden_state[:, 0, :]

    def loss_from_cls(c):
        # c: [1, dim] [CLS] of the patched query.
        sim = torch.nn.functional.cosine_similarity(c, tgt_cls)[0]
        if loss_mode == "cos":
            return -sim.item()
        kth = torch.topk(cos_row(c, CP), topk).values[-1]
        return max(0.0, (kth - sim).item())

    def obj(v):
        # DE searches a continuous box; round each coordinate to a token id.
        return loss_from_cls(encode_with_suffix([int(round(x)) for x in v]))

    best, stale = 1e9, 0

    def cb(xk, convergence=None):
        # Accept `convergence` positionally (new SciPy) or as a keyword (old SciPy).
        # Returning True stops DE. With the "cos" loss the value is negative, so
        # only the PATIENCE rule can trigger an early stop.
        nonlocal best, stale
        cur = obj(xk)
        if cur < best:
            best, stale = cur, 0
        else:
            stale += 1
        return (cur == 0) or (stale >= PATIENCE)

    res = differential_evolution(
        obj,
        bounds,
        popsize=POP,
        maxiter=MAXITER,
        tol=0,
        polish=False,
        seed=SEED,
        callback=cb
    )
    suf = [int(round(x)) for x in res.x]
    return suf, encode_with_suffix(suf), res.nfev

def load_qrels_fever():
    """Map each FEVER query id to the list of positively judged doc ids (train-split qrels).

    Column names are detected heuristically: a column containing "query", one
    containing "corpus" or "doc", and one named score/label/relevance.

    Raises:
        RuntimeError: if the columns cannot be detected or no positive judgement exists.
    """
    qrels_all = load_dataset("BeIR/fever-qrels", split="train")
    columns = list(qrels_all[0].keys())
    r_qid_key = next((k for k in columns if "query" in k.lower()), None)
    r_doc_key = next((k for k in columns if "corpus" in k.lower() or "doc" in k.lower()), None)
    r_score_key = next((k for k in columns if k.lower() in {"score","label","relevance"}), None)
    if not (r_qid_key and r_doc_key and r_score_key):
        raise RuntimeError("fever:No qrels。")

    pos_dict = {}
    # Read whole columns at once instead of decoding one dict per row.
    for qid, did, score in zip(qrels_all[r_qid_key], qrels_all[r_doc_key], qrels_all[r_score_key]):
        if int(score) <= 0:
            continue
        pos_dict.setdefault(str(qid), []).append(str(did))
    if not pos_dict:
        raise RuntimeError("fever: qrels no ans！")
    return pos_dict

def prepare_fever():
    """Build the FEVER pool: N_Q queries, N_DOCS documents, and their [CLS] embeddings.

    The document set is the first positive document of every sampled query plus
    random filler documents.

    Returns:
        ``(docs, qs, pos_dict, pos_text, C_CLS)``:
        - docs: list of {"id", "text"} dicts, positives first (sorted by id), then fillers.
        - qs: list of {"id", "text"} queries, sampled with replacement when fewer
          than N_Q candidates pass the filter.
        - pos_dict: qrels mapping from ``load_qrels_fever``.
        - pos_text: positive doc id -> first 120 characters, newlines flattened.
        - C_CLS: [N_DOCS, hidden] document [CLS] embeddings computed on CPU.

    Raises:
        RuntimeError: fields cannot be detected, no query qualifies, or the corpus
            has too few documents to fill N_DOCS.
    """
    print("\n### Preparing FEVER ###")

    queries = load_dataset("BeIR/fever", "queries", split="queries")
    corpus  = load_dataset("BeIR/fever", "corpus",  split="corpus")
    pos_dict = load_qrels_fever()
    sample_q = queries[0]
    qid_key  = next((k for k in sample_q.keys() if "id" in k.lower()), None)
    text_key = next((k for k in sample_q.keys() if k.lower() in {"text","query","body"}), None)
    if not (qid_key and text_key):
        # message: "could not detect the fields in queries"
        raise RuntimeError("fever: 无法检测到 queries 中的字段。")
    sample_d = corpus[0]
    doc_id_key   = next((k for k in sample_d.keys() if "id" in k.lower()), None)
    doc_text_key = next((k for k in sample_d.keys()
                         if "text" in k.lower() or "body" in k.lower() or "passage" in k.lower()), None)
    if not (doc_id_key and doc_text_key):
        # message: "could not detect the fields in corpus"
        raise RuntimeError("fever: 无法检测到 corpus 中的字段。")
    cand = [
        {"id": str(q[qid_key]), "text": q[text_key]}
        for q in queries
        if (str(q[qid_key]) in pos_dict) and isinstance(q.get(text_key), str) and is_good_q(q[text_key])
    ]
    if len(cand) == 0:
        raise RuntimeError("fever: no query！")
    if len(cand) < N_Q:
        qs = random.choices(cand, k=N_Q)
    else:
        qs = random.sample(cand, N_Q)

    # Sorted, so the document order does not depend on per-process string-hash
    # randomisation (set iteration order varies between interpreter runs).
    pos_ids = sorted({pos_dict[q["id"]][0] for q in qs})
    pos_set = set(pos_ids)
    # Read only the id column; iterating full rows of the multi-million-document
    # corpus would decode every text.
    all_doc_ids = [str(x) for x in corpus[doc_id_key]]
    other_ids = [did for did in all_doc_ids if did not in pos_set]
    if len(other_ids) < N_DOCS - len(pos_ids):
        # message: "not enough filler documents (need ..., only ...)"
        raise RuntimeError(f"fever: 填充文档不足（需要 {N_DOCS - len(pos_ids)}, 只有 {len(other_ids)}）。")
    fillers = random.sample(other_ids, N_DOCS - len(pos_ids))
    sel_ids = pos_ids + fillers
    # id -> row index (last occurrence wins). Fetch texts only for selected rows
    # instead of building an id -> text dict for the whole corpus.
    id2row = {did: i for i, did in enumerate(all_doc_ids)}
    sel_texts = corpus.select([id2row[did] for did in sel_ids])[doc_text_key]
    docs = [{"id": did, "text": text} for did, text in zip(sel_ids, sel_texts)]
    pos_text = {d["id"]: d["text"][:120].replace("\n", " ") for d in docs[:len(pos_ids)]}

    cpu_bert = BertModel.from_pretrained("bert-base-uncased").eval().to("cpu")
    CLS_list = []
    with torch.no_grad():
        for i in tqdm.tqdm(range(0, N_DOCS, BATCH_CLS), desc="CLS-fever"):
            batch_texts = [item["text"] for item in docs[i : i + BATCH_CLS]]
            CLS_list.append(cpu_bert(**enc(batch_texts, "cpu")).last_hidden_state[:, 0, :])
    C_CLS = torch.cat(CLS_list, dim=0)
    del cpu_bert
    torch.cuda.empty_cache()

    return docs, qs, pos_dict, pos_text, C_CLS

def run_loss_compare_fever():
    """Compare the hinge and cos DE objectives on FEVER; write detail and aggregate CSVs.

    For every sampled query, the document ranked RANK_TARGET under the clean query
    is attacked with a TAIL_L-token suffix; success means it reaches rank 1.

    Outputs (in exp_results/loss_compare/fever/):
    - loss_detail.csv: one row per (loss mode, query). It is rewritten after each
      mode as a checkpoint.
    - loss_aggregate.csv: per-mode means.
    """
    out_dir = "exp_results/loss_compare/fever"
    os.makedirs(out_dir, exist_ok=True)
    detail_csv = f"{out_dir}/loss_detail.csv"

    docs, qs, pos_dict, pos_text, C_CLS = prepare_fever()

    global bert
    # Half precision only on GPU; several CPU kernels lack float16 support.
    dtype = torch.float16 if DEVICE == "cuda" else torch.float32
    bert = BertModel.from_pretrained("bert-base-uncased",
                                     torch_dtype=dtype).to(DEVICE).eval()

    rec = []
    for mode in ("hinge", "cos"):
        for q in tqdm.tqdm(qs, desc=f"fever-{mode}"):
            qtxt = q["text"]
            qid  = q["id"]
            # Tokenise once; reuse for the clean forward pass and for DE.
            q_enc = enc([qtxt], DEVICE)
            ids, msk = q_enc["input_ids"][0], q_enc["attention_mask"][0]
            with torch.no_grad():
                qcls = bert(**q_enc).last_hidden_state[:, 0, :]

            sims = cos_row(qcls.cpu(), C_CLS)
            order = torch.argsort(sims, descending=True)
            true_pos_id      = pos_dict[qid][0]
            true_pos_excerpt = pos_text[true_pos_id]
            tgt_idx       = int(order[RANK_TARGET - 1].item())
            tgt_id        = docs[tgt_idx]["id"]
            tgt_excerpt   = docs[tgt_idx]["text"][:120].replace("\n", " ")
            baseline_sim  = sims[tgt_idx].item()
            tgt_cls = C_CLS[tgt_idx : tgt_idx + 1].to(DEVICE)
            # Distractors moved to the model device once, not per objective evaluation.
            CP = torch.cat([C_CLS[:tgt_idx], C_CLS[tgt_idx + 1:]], dim=0).to(DEVICE)

            suf, cls_adv, iters = de_opt(ids, msk, CP, tgt_cls, loss_mode=mode)
            new_sims = cos_row(cls_adv.cpu(), C_CLS)
            # 1-based rank: number of docs strictly more similar, plus one.
            fr = (new_sims > new_sims[tgt_idx]).sum().item() + 1
            rec.append({
                "dataset":             "fever",
                "loss":                mode,
                "query":               qtxt,
                "orig_answer_id":      true_pos_id,
                "orig_answer_excerpt": true_pos_excerpt,
                "tgt_id":              tgt_id,
                "tgt_excerpt":         tgt_excerpt,
                "suffix":              suffix_str(suf),
                "suffix_len":          len(suf),
                "suffix_token_ids":    json.dumps(suf),
                "baseline_rank":       RANK_TARGET,
                "final_rank":          fr,
                "delta_rank":          RANK_TARGET - fr,
                "delta_cos":           new_sims[tgt_idx].item() - baseline_sim,
                "delta_ndcg":          dcg(fr) - dcg(RANK_TARGET),
                "success":             int(fr == 1),
                "iter_used":           iters,  # objective evaluations (res.nfev)
                # Placeholder columns: never filled in this function (always None).
                "pos_rank":            None,
                "pos_excerpt":         None,
            })
        # Checkpoint after each loss mode.
        pd.DataFrame(rec).to_csv(detail_csv, index=False, encoding="utf-8")

    df = pd.DataFrame(rec)
    df.to_csv(detail_csv, index=False, encoding="utf-8")

    agg = df.groupby("loss").agg(
        success_rate   = ("success",   "mean"),
        avg_iters      = ("iter_used", "mean"),
        avg_delta_rank = ("delta_rank","mean"),
        avg_delta_cos  = ("delta_cos", "mean"),
    ).reset_index()
    agg.to_csv(f"{out_dir}/loss_aggregate.csv", index=False, encoding="utf-8")

    h = agg.loc[agg.loss == "hinge", "success_rate"].values[0]
    c = agg.loc[agg.loss == "cos",   "success_rate"].values[0]
    print(f"✓ fever done  –  hinge success = {h:.2%},  cos success = {c:.2%}")
    del bert
    torch.cuda.empty_cache()
if __name__ == "__main__":
    # Entry point: FEVER hinge-vs-cos loss comparison (output directory created up front).
    os.makedirs("exp_results/loss_compare/fever", exist_ok=True)
    run_loss_compare_fever()



### Preparing FEVER ###


CLS-fever: 100%|████████████████████████████████████████████████████████████████████████| 63/63 [03:01<00:00,  2.89s/it]
fever-hinge: 100%|██████████████████████████████████████████████████████████████████| 100/100 [1:12:16<00:00, 43.37s/it]
fever-cos: 100%|████████████████████████████████████████████████████████████████████| 100/100 [1:22:05<00:00, 49.26s/it]

✓ fever done  –  hinge success = 26.00%,  cos success = 1.00%





In [42]:
import os
import random
import math
import json
import warnings

import numpy as np
import pandas as pd
import torch
import tqdm
from datasets import load_dataset
from transformers import BertTokenizer, BertModel
from scipy.optimize import differential_evolution

# --- Loss-comparison (hinge vs. cos) configuration: FiQA --------------------
warnings.filterwarnings("ignore")

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
N_DOCS = 1_000      # documents in the sampled corpus
N_Q = 100           # queries per dataset
TAIL_L = 5          # suffix length in tokens
RANK_TARGET = 100   # baseline rank of the document to promote
POP = 20            # differential_evolution popsize
MAXITER = 2_000     # differential_evolution maxiter
PATIENCE = 20       # stale DE generations before early stop
BATCH_CLS = 16      # batch size for corpus [CLS] encoding
# Shared WordPiece tokenizer; VOC bounds the token ids DE may pick for the suffix.
tok = BertTokenizer.from_pretrained("bert-base-uncased")
VOC = tok.vocab_size 
def enc(txts, dev):
    """Tokenize ``txts`` to padded/truncated 512-token BERT tensors on ``dev``."""
    tokenizer_kwargs = dict(
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    return tok(txts, **tokenizer_kwargs).to(dev)

def get_ids_msk(text, dev):
    """Return the 1-D ``input_ids`` and ``attention_mask`` (length 512) for a single text."""
    encoded = enc([text], dev)
    return tuple(encoded[key][0] for key in ("input_ids", "attention_mask"))

def cos_row(x, Y):
    """Cosine similarity of ``x`` ([1, d]) with every row of ``Y`` ([n, d]), on ``x``'s device."""
    Y_dev = Y.to(x.device)
    return torch.nn.functional.cosine_similarity(x.expand_as(Y_dev), Y_dev, dim=1)

def dcg(rank: int) -> float:
    """Gain of one relevant document at 1-based ``rank``: 1 / log2(rank + 1)."""
    denominator = math.log2(rank + 1)
    return 1.0 / denominator

def is_good_q(text: str) -> bool:
    """FiQA query filter: contains '?' and has 16..499 WordPiece tokens (excluding [CLS]/[SEP])."""
    has_question_mark = "?" in text
    if not has_question_mark:
        return False
    token_count = len(tok(text)["input_ids"]) - 2
    return 15 < token_count < 500

def suffix_str(ids):
    """Join the suffix's WordPiece tokens with spaces for logging/CSV output."""
    token_strings = tok.convert_ids_to_tokens(ids)
    return " ".join(token_strings)

def de_opt(ids, msk, CP, tgt_cls, loss_mode="hinge", L=None, topk=1):
    """Search a suffix with differential evolution that raises a target document's rank.

    The suffix occupies the last ``L`` positions of the 512-token input and their
    attention-mask entries are set to 1.

    Args:
        ids, msk: 1-D ``input_ids`` / ``attention_mask`` of one query (length 512,
            as returned by ``get_ids_msk``).
        CP: [N-1, dim] [CLS] embeddings of all non-target documents. ``cos_row``
            moves them to the model's device. Passing them there already avoids a
            copy per objective evaluation.
        tgt_cls: [1, dim] [CLS] embedding of the target document.
        loss_mode: one of two objectives.
            - ``"hinge"``: ``max(0, s_k - s_tgt)``, where ``s_k`` is the k-th
              highest distractor cosine.
            - ``"cos"``: ``-s_tgt`` (distractors are ignored).
        L: suffix length; defaults to ``TAIL_L``.
        topk: k used by the hinge loss (default 1: beat the best distractor).

    Returns:
        ``(suffix_ids, cls, nfev)``:
        - ``suffix_ids`` are the chosen token ids.
        - ``cls`` is the [1, dim] [CLS] of the patched query.
        - ``nfev`` is the number of objective evaluations (``res.nfev``).

    Raises:
        ValueError: unknown ``loss_mode``, or the suffix slots overlap real query tokens.
    """
    if loss_mode not in ("hinge", "cos"):
        # Previously any other string silently fell back to the hinge loss.
        raise ValueError(f"unknown loss_mode: {loss_mode!r}")
    if L is None:
        L = TAIL_L
    pos = list(range(512 - L, 512))
    bounds = [(0, VOC - 1)] * L
    if bool((msk[pos] != 0).any()):
        raise ValueError(f"suffix of length {L} would overwrite query tokens")

    def encode_with_suffix(suffix):
        # Patch the tail slots, enable attention on them, return the [CLS] vector.
        p, m = ids.clone(), msk.clone()
        for i, t in zip(pos, suffix):
            p[i] = t
        m[pos] = 1
        # Inference only: no autograd graph is needed for a scalar read via .item().
        with torch.no_grad():
            return bert(input_ids=p.unsqueeze(0),
                        attention_mask=m.unsqueeze(0)).last_hidden_state[:, 0, :]

    def loss_from_cls(c):
        # c: [1, dim] [CLS] of the patched query.
        sim = torch.nn.functional.cosine_similarity(c, tgt_cls)[0]
        if loss_mode == "cos":
            return -sim.item()
        kth = torch.topk(cos_row(c, CP), topk).values[-1]
        return max(0.0, (kth - sim).item())

    def obj(v):
        # DE searches a continuous box; round each coordinate to a token id.
        return loss_from_cls(encode_with_suffix([int(round(x)) for x in v]))

    best, stale = 1e9, 0

    def cb(xk, convergence=None):
        # Accept `convergence` positionally (new SciPy) or as a keyword (old SciPy).
        # Returning True stops DE. With the "cos" loss the value is negative, so
        # only the PATIENCE rule can trigger an early stop.
        nonlocal best, stale
        cur = obj(xk)
        if cur < best:
            best, stale = cur, 0
        else:
            stale += 1
        return (cur == 0) or (stale >= PATIENCE)

    res = differential_evolution(
        obj,
        bounds,
        popsize=POP,
        maxiter=MAXITER,
        tol=0,
        polish=False,
        seed=SEED,
        callback=cb
    )
    suf = [int(round(x)) for x in res.x]
    return suf, encode_with_suffix(suf), res.nfev

def load_qrels_fiqa():
    """Map each FiQA query id to the list of positively judged doc ids (test-split qrels).

    Column names are detected heuristically: a column containing "query", one
    containing "corpus" or "doc", and one named score/label/relevance.

    Raises:
        RuntimeError: if the columns cannot be detected or no positive judgement exists.
    """
    qrels = load_dataset("BeIR/fiqa-qrels")["test"]
    columns = list(qrels[0].keys())
    r_qid_key = next((k for k in columns if "query" in k.lower()), None)
    r_doc_key = next((k for k in columns if "corpus" in k.lower() or "doc" in k.lower()), None)
    r_score_key = next((k for k in columns if k.lower() in {"score","label","relevance"}), None)
    if not (r_qid_key and r_doc_key and r_score_key):
        raise RuntimeError("fiqa: NOT ACCEPT。")

    pos_dict = {}
    # Read whole columns at once instead of decoding one dict per row.
    for qid, did, score in zip(qrels[r_qid_key], qrels[r_doc_key], qrels[r_score_key]):
        if int(score) <= 0:
            continue
        pos_dict.setdefault(str(qid), []).append(str(did))
    if not pos_dict:
        raise RuntimeError("fiqa: qrels NO！")
    return pos_dict

def prepare_fiqa():
    """Build the FiQA pool: N_Q queries, N_DOCS documents, and their [CLS] embeddings.

    The document set is the first positive document of every sampled query plus
    random filler documents.

    Returns:
        ``(docs, qs, pos_dict, pos_text, C_CLS)``:
        - docs: list of {"id", "text"} dicts, positives first (sorted by id), then fillers.
        - qs: list of {"id", "text"} queries, sampled with replacement when fewer
          than N_Q candidates pass the filter.
        - pos_dict: qrels mapping from ``load_qrels_fiqa``.
        - pos_text: positive doc id -> first 120 characters, newlines flattened.
        - C_CLS: [N_DOCS, hidden] document [CLS] embeddings computed on CPU.

    Raises:
        RuntimeError: fields cannot be detected, no query qualifies, or the corpus
            has too few documents to fill N_DOCS.
    """
    print("\n### Preparing FiQA-2018 ###")
    queries = load_dataset("BeIR/fiqa", "queries", split="queries")
    corpus  = load_dataset("BeIR/fiqa", "corpus",  split="corpus")
    pos_dict = load_qrels_fiqa()
    sample_q = queries[0]
    qid_key  = next((k for k in sample_q.keys() if "id" in k.lower()), None)
    text_key = next((k for k in sample_q.keys() if k.lower() in {"text","query","body"}), None)
    if not (qid_key and text_key):
        raise RuntimeError("fiqa: NO QUERY。")
    sample_d = corpus[0]
    doc_id_key   = next((k for k in sample_d.keys() if "id" in k.lower()), None)
    doc_text_key = next((k for k in sample_d.keys()
                         if "text" in k.lower() or "body" in k.lower() or "passage" in k.lower()), None)
    if not (doc_id_key and doc_text_key):
        raise RuntimeError("fiqa: NO CORPUS。")
    cand = [
        {"id": str(q[qid_key]), "text": q[text_key]}
        for q in queries
        if (str(q[qid_key]) in pos_dict) and isinstance(q.get(text_key), str) and is_good_q(q[text_key])
    ]
    if len(cand) == 0:
        raise RuntimeError("fiqa: OUT OF query！")
    if len(cand) < N_Q:
        qs = random.choices(cand, k=N_Q)
    else:
        qs = random.sample(cand, N_Q)

    # Sorted, so the document order does not depend on per-process string-hash
    # randomisation (set iteration order varies between interpreter runs).
    pos_ids = sorted({pos_dict[q["id"]][0] for q in qs})
    pos_set = set(pos_ids)
    # Read only the id column instead of decoding every corpus row.
    all_doc_ids = [str(x) for x in corpus[doc_id_key]]
    other_ids = [did for did in all_doc_ids if did not in pos_set]
    if len(other_ids) < N_DOCS - len(pos_ids):
        raise RuntimeError(f"fiqa: NO（,NEED {N_DOCS - len(pos_ids)}, ONLY {len(other_ids)}）。")
    fillers = random.sample(other_ids, N_DOCS - len(pos_ids))
    sel_ids = pos_ids + fillers
    # id -> row index (last occurrence wins). Fetch texts only for selected rows
    # instead of building an id -> text dict for the whole corpus.
    id2row = {did: i for i, did in enumerate(all_doc_ids)}
    sel_texts = corpus.select([id2row[did] for did in sel_ids])[doc_text_key]
    docs = [{"id": did, "text": text} for did, text in zip(sel_ids, sel_texts)]
    pos_text = {d["id"]: d["text"][:120].replace("\n", " ") for d in docs[:len(pos_ids)]}

    cpu_bert = BertModel.from_pretrained("bert-base-uncased").eval().to("cpu")
    CLS_list = []
    with torch.no_grad():
        for i in tqdm.tqdm(range(0, N_DOCS, BATCH_CLS), desc="CLS-fiqa"):
            batch_texts = [item["text"] for item in docs[i : i + BATCH_CLS]]
            CLS_list.append(cpu_bert(**enc(batch_texts, "cpu")).last_hidden_state[:, 0, :])
    C_CLS = torch.cat(CLS_list, dim=0)
    del cpu_bert
    torch.cuda.empty_cache()

    return docs, qs, pos_dict, pos_text, C_CLS
def run_loss_compare_fiqa():
    """Compare the hinge and cos DE objectives on FiQA; write detail and aggregate CSVs.

    For every sampled query, the document ranked RANK_TARGET under the clean query
    is attacked with a TAIL_L-token suffix; success means it reaches rank 1.

    Outputs (in exp_results/loss_compare/fiqa/):
    - loss_detail.csv: one row per (loss mode, query). It is rewritten after each
      mode as a checkpoint.
    - loss_aggregate.csv: per-mode means.
    """
    out_dir = "exp_results/loss_compare/fiqa"
    os.makedirs(out_dir, exist_ok=True)
    detail_csv = f"{out_dir}/loss_detail.csv"

    docs, qs, pos_dict, pos_text, C_CLS = prepare_fiqa()

    global bert
    # Half precision only on GPU; several CPU kernels lack float16 support.
    dtype = torch.float16 if DEVICE == "cuda" else torch.float32
    bert = BertModel.from_pretrained("bert-base-uncased",
                                     torch_dtype=dtype).to(DEVICE).eval()

    rec = []
    for mode in ("hinge", "cos"):
        for q in tqdm.tqdm(qs, desc=f"fiqa-{mode}"):
            qtxt = q["text"]
            qid  = q["id"]
            # Tokenise once; reuse for the clean forward pass and for DE.
            q_enc = enc([qtxt], DEVICE)
            ids, msk = q_enc["input_ids"][0], q_enc["attention_mask"][0]
            with torch.no_grad():
                qcls = bert(**q_enc).last_hidden_state[:, 0, :]

            sims = cos_row(qcls.cpu(), C_CLS)
            order = torch.argsort(sims, descending=True)

            true_pos_id      = pos_dict[qid][0]
            true_pos_excerpt = pos_text[true_pos_id]

            tgt_idx       = int(order[RANK_TARGET - 1].item())
            tgt_id        = docs[tgt_idx]["id"]
            tgt_excerpt   = docs[tgt_idx]["text"][:120].replace("\n", " ")
            baseline_sim  = sims[tgt_idx].item()
            tgt_cls = C_CLS[tgt_idx : tgt_idx + 1].to(DEVICE)
            # Distractors moved to the model device once, not per objective evaluation.
            CP = torch.cat([C_CLS[:tgt_idx], C_CLS[tgt_idx + 1:]], dim=0).to(DEVICE)

            suf, cls_adv, iters = de_opt(ids, msk, CP, tgt_cls, loss_mode=mode)
            new_sims = cos_row(cls_adv.cpu(), C_CLS)
            # 1-based rank: number of docs strictly more similar, plus one.
            fr = (new_sims > new_sims[tgt_idx]).sum().item() + 1
            rec.append({
                "dataset":            "fiqa",
                "loss":               mode,
                "query":              qtxt,

                "orig_answer_id":      true_pos_id,
                "orig_answer_excerpt": true_pos_excerpt,
                "tgt_id":             tgt_id,
                "tgt_excerpt":        tgt_excerpt,

                "suffix":             suffix_str(suf),
                "suffix_len":         len(suf),
                "suffix_token_ids":   json.dumps(suf),
                "baseline_rank":      RANK_TARGET,
                "final_rank":         fr,
                "delta_rank":         RANK_TARGET - fr,
                "delta_cos":          new_sims[tgt_idx].item() - baseline_sim,
                "delta_ndcg":         dcg(fr) - dcg(RANK_TARGET),
                "success":            int(fr == 1),
                "iter_used":          iters,  # objective evaluations (res.nfev)
                # Placeholder columns: never filled in this function (always None).
                "pos_rank":           None,
                "pos_excerpt":        None,
            })
        # Checkpoint after each loss mode.
        pd.DataFrame(rec).to_csv(detail_csv, index=False, encoding="utf-8")

    df = pd.DataFrame(rec)
    df.to_csv(detail_csv, index=False, encoding="utf-8")
    agg = df.groupby("loss").agg(
        success_rate   = ("success",   "mean"),
        avg_iters      = ("iter_used", "mean"),
        avg_delta_rank = ("delta_rank","mean"),
        avg_delta_cos  = ("delta_cos", "mean"),
    ).reset_index()
    agg.to_csv(f"{out_dir}/loss_aggregate.csv", index=False, encoding="utf-8")

    h = agg.loc[agg.loss == "hinge", "success_rate"].values[0]
    c = agg.loc[agg.loss == "cos",   "success_rate"].values[0]
    print(f"✓ fiqa done  –  hinge success = {h:.2%},  cos success = {c:.2%}")
    del bert
    torch.cuda.empty_cache()
if __name__ == "__main__":
    # Entry point: FiQA hinge-vs-cos loss comparison (output directory created up front).
    os.makedirs("exp_results/loss_compare/fiqa", exist_ok=True)
    run_loss_compare_fiqa()



### Preparing FiQA-2018 ###


CLS-fiqa: 100%|█████████████████████████████████████████████████████████████████████████| 63/63 [03:24<00:00,  3.25s/it]
fiqa-hinge: 100%|███████████████████████████████████████████████████████████████████| 100/100 [1:17:11<00:00, 46.31s/it]
fiqa-cos: 100%|█████████████████████████████████████████████████████████████████████| 100/100 [1:09:47<00:00, 41.88s/it]

✓ fiqa done  –  hinge success = 23.00%,  cos success = 2.00%





In [None]:

# NOTE(review): this cell uses plt, FILES, COLORS and summary_df, which are
# created by the data-loading cell further down. Run that cell first (or move
# this cell below it) so Restart & Run All works.
fig, ax = plt.subplots(figsize=(8, 5))
for ds_name in FILES.keys():
    for topk, linestyle, marker in [(1, '-', 'o'), (10, '--', 's')]:
        is_curve = (summary_df['Dataset'] == ds_name) & (summary_df['TopK'] == topk)
        sub = summary_df.loc[is_curve].sort_values('suffix_len')
        ax.plot(sub['suffix_len'], sub['mean_delta_rank'],
                label=f"{ds_name} Top-{topk}", color=COLORS[ds_name],
                linestyle=linestyle, marker=marker, linewidth=2)

# Shade the L >= 5 region annotated as having near-zero marginal gain.
ax.axvline(x=4.5, color='gray', linestyle='--', linewidth=1)
ymin, ymax = ax.get_ylim()
ax.fill_betweenx([ymin, ymax], 4.5, 10, color='gray', alpha=0.15)
ax.text(5.2, ymax * 0.9, 'Marginal gains ≈ 0 for L ≥ 5', fontsize=9, color='gray')

ax.set_xlabel('Suffix Length (L)')
ax.set_ylabel('Mean ΔRank')
ax.set_title('Mean ΔRank vs. Suffix Length – MSMARCO/Fiqa/nq (Top-1 & Top-10)')
ax.set_xticks(range(1, 11))
ax.grid(alpha=0.3)
ax.legend(loc='upper left', fontsize=8)
fig.tight_layout()

fig.savefig('delta_rank.png', dpi=300)
plt.show()

# Success rate (fraction of queries whose target reached the Top-K) per suffix length.
fig, ax = plt.subplots(figsize=(8, 5))
for ds_name in FILES.keys():
    for topk, linestyle, marker in [(1, '-', 'o'), (10, '--', 's')]:
        is_curve = (summary_df['Dataset'] == ds_name) & (summary_df['TopK'] == topk)
        sub = summary_df.loc[is_curve].sort_values('suffix_len')
        ax.plot(sub['suffix_len'], sub['success_rate'],
                label=f"{ds_name} Top-{topk}", color=COLORS[ds_name],
                linestyle=linestyle, marker=marker, linewidth=2)

ax.set_xlabel('Suffix Length (L)')
ax.set_ylabel('Success Rate')
ax.set_title('Success Rate vs. Suffix Length – MSMARCO/Fiqa/nq (Top-1 & Top-10)')
ax.set_xticks(range(1, 11))
ax.grid(alpha=0.3)
ax.legend(loc='lower right', fontsize=8)
fig.tight_layout()
fig.savefig('success_rate.png', dpi=300)
plt.show()
# Marginal gain per suffix length. marginal_dict maps (dataset, Top-K) to a frame
# with 'suffix_len' and 'marginal_gain' columns; presumably it is built from
# compute_marginal_gain in a later cell (not visible here) — run that first.
fig, ax = plt.subplots(figsize=(8, 5))
for ds_name in FILES.keys():
    for topk, linestyle, marker in [(1, '-', 'o'), (10, '--', 's')]:
        mg_df = marginal_dict[(ds_name, topk)]
        ax.plot(mg_df['suffix_len'], mg_df['marginal_gain'],
                label=f"{ds_name} Top-{topk}", color=COLORS[ds_name],
                linestyle=linestyle, marker=marker, linewidth=2)

ax.axhline(y=0, color='gray', linestyle='-', linewidth=1)
ax.axvline(x=5, color='red', linestyle='--', linewidth=1)
ax.text(5.2, ax.get_ylim()[1] * 0.8, 'L = 5 cutoff', color='red', fontsize=9)

ax.set_xlabel('Suffix Length (L)')
ax.set_ylabel('Marginal Gain (ΔRank difference)')
ax.set_title('Marginal Gain vs. Suffix Length – MSMARCO/Fiqa/nq (Top-1 & Top-10)')
ax.set_xticks(range(1, 11))
ax.grid(alpha=0.3)
ax.legend(loc='upper right', fontsize=8)
fig.tight_layout()

fig.savefig('marginal_gain.png', dpi=300)
plt.show()


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pathlib
# Per-dataset ablation CSVs. Every entry follows the same
# exp_results/{dataset}/ablation/ablation_results.csv layout; the Fiqa entry
# previously lacked its dataset directory. Keys double as plot labels.
FILES = {
    'MSMARCO': 'exp_results/msmarco/ablation/ablation_results.csv',
    'Fiqa':    'exp_results/fiqa/ablation/ablation_results.csv',
    'nq':      'exp_results/nq/ablation/ablation_results.csv'
}
# One fixed colour per dataset so its Top-1 and Top-10 curves share a hue.
COLORS = {
    'MSMARCO': 'tab:blue',
    'Fiqa':    'tab:orange',
    'nq':      'tab:green'
}
def load_and_truncate(file_path, dataset_name, topk_focus, keep_per_group=100):
    """Load one ablation CSV and return the rows of a single Top-K setting.

    At most ``keep_per_group`` rows (in file order) are kept for every
    (Top-K, suffix_len) group, so each dataset contributes a bounded, equal
    number of queries per cell of the summary.

    Args:
        file_path: path to an ``ablation_results.csv``.
        dataset_name: label stored in the added ``Dataset`` column.
        topk_focus: Top-K value to keep, matched against the CSV's
            ``topK`` or ``top_k`` column.
        keep_per_group: maximum rows kept per (Top-K, suffix_len) group.

    Returns:
        A new DataFrame (fresh 0..n-1 index) with extra ``Dataset`` and
        ``TopK`` columns.

    Raises:
        FileNotFoundError: if ``file_path`` does not exist.
        KeyError: if the CSV has neither a ``topK`` nor a ``top_k`` column.
    """
    if not pathlib.Path(file_path).exists():
        raise FileNotFoundError(f"File not found: {file_path}")

    df = pd.read_csv(file_path)
    topk_col = next((c for c in ('topK', 'top_k') if c in df.columns), None)
    if topk_col is None:
        raise KeyError(f"Cannot find 'topK' or 'top_k' column in {file_path}")

    # groupby().head() keeps the first N rows of each group in file order and
    # keeps the grouping columns; groupby().apply() is dropping grouping
    # columns from its input in newer pandas, which broke the filter below.
    df_trunc = (
        df
        .assign(Dataset=dataset_name)
        .groupby([topk_col, 'suffix_len'])
        .head(keep_per_group)
    )

    df_focus = df_trunc[df_trunc[topk_col] == topk_focus].reset_index(drop=True)
    df_focus['TopK'] = topk_focus
    return df_focus
# One truncated frame per (dataset, Top-K ∈ {1, 10}), capped at 100 rows per
# (Top-K, suffix length), stacked into a single long table.
data_frames = [
    load_and_truncate(path, ds_name, topk_focus=topk, keep_per_group=100)
    for ds_name, path in FILES.items()
    for topk in [1, 10]
]
combined = pd.concat(data_frames, ignore_index=True)

def compute_summary(df):
    """Average the per-query ablation metrics per (Dataset, TopK, suffix_len).

    Returns a flat DataFrame (grouping keys as columns) with the means of
    ``delta_rank``, ``success``, ``delta_cos`` and ``delta_ndcg`` named
    ``mean_delta_rank``, ``success_rate``, ``mean_delta_cos`` and
    ``mean_delta_ndcg``.
    """
    group_keys = ['Dataset', 'TopK', 'suffix_len']
    named_means = {
        'mean_delta_rank': ('delta_rank', 'mean'),
        'success_rate':    ('success',    'mean'),
        'mean_delta_cos':  ('delta_cos',  'mean'),
        'mean_delta_ndcg': ('delta_ndcg', 'mean'),
    }
    grouped = df.groupby(group_keys, as_index=False)
    return grouped.agg(**named_means)

summary_df = combined.pipe(compute_summary)  # one row per (Dataset, TopK, suffix_len)

def compute_marginal_gain(summary_df, metric_key):
    """Add the step-to-step change of ``metric_key`` along suffix length.

    Args:
        summary_df: rows of a single (Dataset, TopK) series; not modified.
        metric_key: column whose successive differences are taken.

    Returns:
        A new frame sorted by ``suffix_len`` (fresh 0..n-1 index) with a
        ``marginal_gain`` column: NaN for the first row, then
        ``metric[i] - metric[i-1]``. An empty input yields an empty frame
        instead of the length-mismatch ValueError the manual list raised.
    """
    out = summary_df.sort_values('suffix_len').reset_index(drop=True)
    out['marginal_gain'] = out[metric_key].diff()
    return out
# (dataset, Top-K) -> frame with the step-to-step change in mean ΔRank.
marginal_dict = {}
for ds_name in FILES:
    for topk in (1, 10):
        in_group = (summary_df['Dataset'] == ds_name) & (summary_df['TopK'] == topk)
        sub = summary_df.loc[in_group, ['suffix_len', 'mean_delta_rank']].copy()
        mg = compute_marginal_gain(sub, 'mean_delta_rank')
        marginal_dict[ds_name, topk] = mg

# Mean ΔRank against suffix length for every dataset × Top-K, with the
# L ≥ 5 region shaded and labelled as the plateau.
fig, ax = plt.subplots(figsize=(8, 5))
for ds_name in FILES.keys():
    for topk, linestyle, marker in [(1, '-', 'o'), (10, '--', 's')]:
        sub = summary_df[
            (summary_df['Dataset'] == ds_name) &
            (summary_df['TopK'] == topk)
        ].sort_values('suffix_len')
        ax.plot(
            sub['suffix_len'],
            sub['mean_delta_rank'],
            label=f"{ds_name} Top-{topk}",
            color=COLORS[ds_name],
            linestyle=linestyle,
            marker=marker,
            linewidth=2
        )

ax.axvline(x=4.5, color='gray', linestyle='--', linewidth=1)
# axvspan spans the full plot height in axes coordinates and does not feed
# y-autoscaling. Filling between the *current* ylim values let the margins
# grow past the band.
ax.axvspan(4.5, 10, color='gray', alpha=0.15)
# x in data units, y as an axes fraction: the label sits at 90 % of the plot
# height. `ymax * 0.9` in data units fell below the axes when ymin > 0.9*ymax.
ax.text(5.2, 0.9, 'Marginal gains ≈ 0 for L ≥ 5', fontsize=9, color='gray',
        transform=ax.get_xaxis_transform())

ax.set_xlabel('Suffix Length (L)')
ax.set_ylabel('Mean ΔRank')
ax.set_title('Mean ΔRank vs. Suffix Length – MSMARCO/Fiqa/nq (Top-1 & Top-10)')
ax.set_xticks(range(1, 11))
ax.grid(alpha=0.3)
ax.legend(loc='upper left', fontsize=8)
fig.tight_layout()
plt.show()
# Success rate (mean of the per-query `success` flag) against suffix length,
# one line per dataset × Top-K; displayed only.
fig, ax = plt.subplots(figsize=(8, 5))
for ds_name in FILES.keys():
    for topk, linestyle, marker in [(1, '-', 'o'), (10, '--', 's')]:
        in_group = (summary_df['Dataset'] == ds_name) & (summary_df['TopK'] == topk)
        sub = summary_df.loc[in_group].sort_values('suffix_len')
        ax.plot(
            sub['suffix_len'], sub['success_rate'],
            label=f"{ds_name} Top-{topk}", color=COLORS[ds_name],
            linestyle=linestyle, marker=marker, linewidth=2,
        )

ax.set_xlabel('Suffix Length (L)')
ax.set_ylabel('Success Rate')
ax.set_title('Success Rate vs. Suffix Length – MSMARCO/Fiqa/nq (Top-1 & Top-10)')
ax.set_xticks(range(1, 11))
ax.grid(alpha=0.3)
ax.legend(loc='lower right', fontsize=8)
fig.tight_layout()
plt.show()
# Step-to-step change in mean ΔRank per suffix length (from marginal_dict),
# with a zero line and a marker at L = 5; displayed only.
fig, ax = plt.subplots(figsize=(8, 5))
for ds_name in FILES.keys():
    for topk, linestyle, marker in [(1, '-', 'o'), (10, '--', 's')]:
        mg_df = marginal_dict[(ds_name, topk)]
        ax.plot(
            mg_df['suffix_len'], mg_df['marginal_gain'],
            label=f"{ds_name} Top-{topk}", color=COLORS[ds_name],
            linestyle=linestyle, marker=marker, linewidth=2,
        )

ax.axhline(y=0, color='gray', linestyle='-', linewidth=1)
ax.axvline(x=5, color='red', linestyle='--', linewidth=1)
top = ax.get_ylim()[1]
ax.text(5.2, top * 0.8, 'L = 5 cutoff', color='red', fontsize=9)

ax.set_xlabel('Suffix Length (L)')
ax.set_ylabel('Marginal Gain (ΔRank difference)')
ax.set_title('Marginal Gain vs. Suffix Length – MSMARCO/Fiqa/nq (Top-1 & Top-10)')
ax.set_xticks(range(1, 11))
ax.grid(alpha=0.3)
ax.legend(loc='upper right', fontsize=8)
fig.tight_layout()
plt.show()


In [None]:


import os
import random
import math
import json
import warnings

import numpy as np
import pandas as pd
import torch
import tqdm
from datasets import load_dataset
from transformers import BertTokenizer, BertModel
from scipy.optimize import differential_evolution

warnings.filterwarnings("ignore")

# ─────────────── Global configuration ───────────────
SEED = 42
# Seed every RNG this cell draws from; the three generators are independent,
# so the seeding order is irrelevant.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

N_DOCS = 1_000      # size of the sampled document pool
N_Q = 100           # number of sampled queries
TAIL_L = 5          # adversarial suffix length in tokens
RANK_TARGET = 100   # baseline rank of the document to promote
POP = 20            # scipy differential_evolution `popsize`
MAXITER = 2_000     # upper bound on DE generations
PATIENCE = 20       # early stop after this many non-improving callbacks
BATCH_CLS = 16      # batch size when embedding the document pool

# Shared bert-base-uncased tokenizer; token ids in [0, VOC) form the DE search space.
tok = BertTokenizer.from_pretrained("bert-base-uncased")
VOC = tok.vocab_size
def enc(txts, dev):
    """Tokenize a list of strings into fixed-length (512) BERT inputs on ``dev``.

    Every sequence is truncated or padded to exactly 512 tokens; ``de_opt``
    relies on that fixed length when it writes suffix ids into the last slots.
    """
    batch = tok(
        txts,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding="max_length",
    )
    return batch.to(dev)

def get_ids_msk(text, dev):
    """Return the 512-long ``input_ids`` and ``attention_mask`` rows for one text."""
    encoded = enc([text], dev)
    ids_row = encoded["input_ids"][0]
    mask_row = encoded["attention_mask"][0]
    return ids_row, mask_row

def cos_row(x, Y):
    """Cosine similarity between the row vector ``x`` and each row of ``Y``.

    ``Y`` is moved to ``x``'s device first; the result is a 1-D tensor with
    one entry per row of ``Y``.
    """
    rows = Y.to(x.device)
    query = x.expand_as(rows)
    return torch.nn.functional.cosine_similarity(query, rows, dim=1)

def dcg(rank: int) -> float:
    """Discounted gain of one relevant item at 1-based ``rank``: 1 / log2(rank + 1)."""
    discount = math.log2(rank + 1)
    return 1.0 / discount

def is_good_q(text: str) -> bool:
    """Accept queries whose token count, minus 2 for special tokens, lies strictly in (10, 500)."""
    n_body_tokens = len(tok(text)["input_ids"]) - 2  # discount [CLS]/[SEP]
    return 10 < n_body_tokens < 500

def suffix_str(ids):
    """Render token ids as their WordPiece strings joined by single spaces."""
    pieces = tok.convert_ids_to_tokens(ids)
    return " ".join(pieces)

def de_opt(ids, msk, CP, tgt_cls, loss_mode="hinge"):
    """Search a TAIL_L-token suffix with differential evolution.

    The candidate ids overwrite the last TAIL_L positions of the 512-long
    encoded query, and the attention mask is switched on there. The [CLS]
    vector of that sequence (global ``bert``) is scored against the target.

    Args:
        ids, msk: 1-D input_ids / attention_mask of length 512 (``get_ids_msk``).
        CP: CLS vectors of the competing (non-target) documents, one per row.
        tgt_cls: CLS vector of the target document, shape [1, dim].
        loss_mode: "cos" -> negative cosine to the target; anything else ->
            hinge max(0, best competitor cosine - target cosine).

    Returns:
        (suffix token ids, [1, dim] CLS of the query with that suffix,
        scipy ``nfev``). The last value counts objective evaluations, not DE
        generations.
    """
    L = TAIL_L
    # NOTE(review): enc() pads to 512, so these slots sit after the query's
    # [SEP] and a run of masked [PAD]s, not directly after the query text.
    # Re-tokenising `query + suffix_str(suf)` would not reproduce this input;
    # confirm this is the intended threat model.
    pos = list(range(512 - L, 512))
    bounds = [(0, VOC - 1)] * L
    # Copy the competitors to the model device once instead of inside cos_row
    # on every objective evaluation.
    CP_dev = CP.to(ids.device)

    def cls_with_suffix(tokens):
        # Overwrite the tail slots with the suffix and un-mask them.
        p, m = ids.clone(), msk.clone()
        p[pos] = torch.as_tensor(tokens, dtype=p.dtype, device=p.device)
        m[pos] = 1
        # Pure inference: without no_grad every evaluation records an autograd graph.
        with torch.no_grad():
            return bert(input_ids=p.unsqueeze(0),
                        attention_mask=m.unsqueeze(0)).last_hidden_state[:, 0, :]

    def loss_from_cls(c):
        # c: [1, dim] CLS vector of the query with the candidate suffix.
        sim = torch.nn.functional.cosine_similarity(c, tgt_cls)[0]
        if loss_mode == "cos":
            return -sim.item()
        # hinge: lead of the strongest competitor over the target (0 once the
        # target matches or beats it).
        kth = torch.topk(cos_row(c, CP_dev), 1).values[-1]
        return max(0.0, (kth - sim).item())

    def obj(v):
        return loss_from_cls(cls_with_suffix([int(round(x)) for x in v]))

    best, stale = 1e9, 0

    # Named `convergence` so it also works with SciPy versions that pass that
    # value by keyword (newer versions pass it positionally).
    def cb(xk, convergence=None):
        nonlocal best, stale
        cur = obj(xk)
        if cur < best:
            best, stale = cur, 0
        else:
            stale += 1
        # Returning True halts DE: hinge already satisfied or patience exhausted.
        return (cur == 0) or (stale >= PATIENCE)

    res = differential_evolution(
        obj,
        bounds,
        popsize=POP,
        maxiter=MAXITER,
        tol=0,
        polish=False,
        seed=SEED,
        callback=cb
    )
    suf = [int(round(x)) for x in res.x]
    cls = cls_with_suffix(suf)
    return suf, cls, res.nfev

def load_qrels_scifact():
    """Map each SciFact query id to the doc ids judged relevant (score > 0).

    The query / doc / score column names are detected from the first row of
    the ``BeIR/scifact-qrels`` train split. Raises RuntimeError when they
    cannot be detected or when no positive judgement exists.
    """
    qrels_all = load_dataset("BeIR/scifact-qrels", split="train")
    columns = list(qrels_all[0].keys())

    def first_column(pred):
        return next((k for k in columns if pred(k.lower())), None)

    r_qid_key = first_column(lambda k: "query" in k)
    r_doc_key = first_column(lambda k: "corpus" in k or "doc" in k)
    r_score_key = first_column(lambda k: k in {"score", "label", "relevance"})
    if not all((r_qid_key, r_doc_key, r_score_key)):
        # message: "could not detect the qrels fields"
        raise RuntimeError("scifact: 无法检测到 qrels 字段。")

    pos_dict = {}
    for row in qrels_all:
        if int(row[r_score_key]) <= 0:
            continue
        pos_dict.setdefault(str(row[r_qid_key]), []).append(str(row[r_doc_key]))
    if not pos_dict:
        # message: "no positive judgements in qrels"
        raise RuntimeError("scifact: qrels 中没有任何正例！")
    return pos_dict

def prepare_scifact():
    """Sample SciFact queries and a document pool, and embed the pool with BERT.

    Queries need at least one positive qrel and must pass ``is_good_q``. N_Q of
    them are drawn, with replacement when fewer candidates exist. The pool holds
    the first positive doc of every sampled query (sorted by id) followed by
    random filler docs, N_DOCS in total. [CLS] vectors are computed on CPU.

    Returns:
        docs: list of {"id", "text"} forming the pool, aligned with C_CLS rows.
        qs: sampled queries as {"id", "text"}.
        pos_dict: query id -> relevant doc ids (``load_qrels_scifact``).
        pos_text: positive doc id -> first 120 chars of its text, newlines as spaces.
        C_CLS: tensor with one [CLS] row per entry of ``docs``.

    Raises:
        RuntimeError: if column names cannot be detected, no query qualifies,
            or the corpus cannot fill the pool.
    """
    queries = load_dataset("BeIR/scifact", "queries", split="queries")
    corpus  = load_dataset("BeIR/scifact", "corpus",  split="corpus")
    pos_dict = load_qrels_scifact()

    sample_q = queries[0]
    qid_key  = next((k for k in sample_q.keys() if "id" in k.lower()), None)
    text_key = next((k for k in sample_q.keys() if k.lower() in {"text","query","body"}), None)
    if not (qid_key and text_key):
        raise RuntimeError("scifact: not able to detect。")

    sample_d = corpus[0]
    doc_id_key   = next((k for k in sample_d.keys() if "id" in k.lower()), None)
    doc_text_key = next((k for k in sample_d.keys()
                         if "text" in k.lower() or "body" in k.lower() or "passage" in k.lower()), None)
    if not (doc_id_key and doc_text_key):
        raise RuntimeError("scifact: corpus error。")

    cand = [
        {"id": str(q[qid_key]), "text": q[text_key]}
        for q in queries
        if (str(q[qid_key]) in pos_dict) and isinstance(q.get(text_key), str) and is_good_q(q[text_key])
    ]
    if len(cand) == 0:
        raise RuntimeError("scifact: no query！")
    # Too few candidates: sample with replacement (duplicate queries possible).
    qs = random.choices(cand, k=N_Q) if len(cand) < N_Q else random.sample(cand, N_Q)

    # Sorted: iterating a set of str follows per-process hash randomisation,
    # which reordered docs / C_CLS between runs despite the fixed SEED.
    pos_ids = sorted({pos_dict[q["id"]][0] for q in qs})
    pos_set = set(pos_ids)

    # Single pass over the corpus (it used to be iterated twice).
    all_doc_ids, id2text = [], {}
    for d in corpus:
        did = str(d[doc_id_key])
        all_doc_ids.append(did)
        id2text[did] = d[doc_text_key]

    other_ids = [did for did in all_doc_ids if did not in pos_set]
    if len(other_ids) < N_DOCS - len(pos_ids):
        # message: "not enough filler documents (need ..., only ...)"
        raise RuntimeError(f"scifact: 填充文档不足（需要 {N_DOCS - len(pos_ids)}, 只有 {len(other_ids)}）。")
    fillers = random.sample(other_ids, N_DOCS - len(pos_ids))
    sel_ids = pos_ids + fillers
    docs = [{"id": did, "text": id2text[did]} for did in sel_ids]
    pos_text = { did: id2text[did][:120].replace("\n", " ") for did in pos_ids }

    cpu_bert = BertModel.from_pretrained("bert-base-uncased").eval().to("cpu")
    CLS_list = []
    with torch.no_grad():
        for i in tqdm.tqdm(range(0, len(docs), BATCH_CLS), desc="CLS-scifact"):
            batch_texts = [item["text"] for item in docs[i : i + BATCH_CLS]]
            out = cpu_bert(**enc(batch_texts, "cpu")).last_hidden_state[:, 0, :]
            CLS_list.append(out)
    C_CLS = torch.cat(CLS_list, dim=0)
    del cpu_bert
    torch.cuda.empty_cache()

    return docs, qs, pos_dict, pos_text, C_CLS
def run_loss_compare_scifact():
    """Compare the hinge and cosine DE losses on SciFact and write CSV results.

    For every sampled query and each loss mode, the pool document at baseline
    rank RANK_TARGET is attacked with ``de_opt``. Per-query rows go to
    ``loss_detail.csv`` and per-loss means to ``loss_aggregate.csv`` in
    ``exp_results/loss_compare/scifact``. The two success rates (target
    reaching rank 1) are printed.
    """
    out_dir = "exp_results/loss_compare/scifact"
    os.makedirs(out_dir, exist_ok=True)
    docs, qs, pos_dict, pos_text, C_CLS = prepare_scifact()
    global bert
    # Half precision only on GPU. On the CPU fallback of DEVICE, fp16 BERT
    # kernels may be unsupported or very slow, so load fp32 there.
    model_dtype = torch.float16 if DEVICE == "cuda" else torch.float32
    bert = BertModel.from_pretrained("bert-base-uncased",
                                     torch_dtype=model_dtype).to(DEVICE).eval()

    rec = []
    for mode in ("hinge", "cos"):
        for q in tqdm.tqdm(qs, desc=f"scifact-{mode}"):
            qtxt = q["text"]
            qid  = q["id"]
            ids, msk = get_ids_msk(qtxt, DEVICE)
            with torch.no_grad():
                qcls = bert(**enc([qtxt], DEVICE)).last_hidden_state[:, 0, :]

            # Baseline ranking of the whole pool for the clean query.
            sims = cos_row(qcls.cpu(), C_CLS)
            order = torch.argsort(sims, descending=True)
            true_pos_id      = pos_dict[qid][0]
            true_pos_excerpt = pos_text[true_pos_id]

            # Target = the document currently at 1-based rank RANK_TARGET.
            tgt_idx       = int(order[RANK_TARGET - 1].item())
            tgt_id        = docs[tgt_idx]["id"]
            tgt_excerpt   = docs[tgt_idx]["text"][:120].replace("\n", " ")
            baseline_sim  = sims[tgt_idx].item()
            # NOTE(review): these two columns are never populated.
            pos_rank    = None
            pos_excerpt = None
            tgt_cls = C_CLS[tgt_idx : tgt_idx + 1].to(DEVICE)
            # Competitors = every pool document except the target.
            CP = torch.cat([C_CLS[:tgt_idx], C_CLS[tgt_idx+1:]], dim=0)

            suf, cls_adv, iters = de_opt(ids, msk, CP, tgt_cls, loss_mode=mode)
            # Final rank = 1 + number of documents strictly more similar than the target.
            new_sims = cos_row(cls_adv.cpu(), C_CLS)
            fr = (new_sims > new_sims[tgt_idx]).sum().item() + 1
            entry = {
                "dataset":             "scifact",
                "loss":                mode,
                "query":               qtxt,
                "orig_answer_id":      true_pos_id,
                "orig_answer_excerpt": true_pos_excerpt,
                "tgt_id":              tgt_id,
                "tgt_excerpt":         tgt_excerpt,
                "suffix":              suffix_str(suf),
                "suffix_len":          len(suf),
                "suffix_token_ids":    json.dumps(suf),
                "baseline_rank":       RANK_TARGET,
                "final_rank":          fr,
                "delta_rank":          RANK_TARGET - fr,
                "delta_cos":           new_sims[tgt_idx].item() - baseline_sim,
                "delta_ndcg":          dcg(fr) - dcg(RANK_TARGET),
                "success":             int(fr == 1),
                # de_opt returns scipy's nfev (objective evaluations) here.
                "iter_used":           iters,
                "pos_rank":            pos_rank,
                "pos_excerpt":         pos_excerpt,
            }
            rec.append(entry)
    df = pd.DataFrame(rec)
    df.to_csv(f"{out_dir}/loss_detail.csv", index=False, encoding="utf-8")
    agg = df.groupby("loss").agg(
        success_rate   = ("success",   "mean"),
        avg_iters      = ("iter_used", "mean"),
        avg_delta_rank = ("delta_rank","mean"),
        avg_delta_cos  = ("delta_cos", "mean"),
    ).reset_index()
    agg.to_csv(f"{out_dir}/loss_aggregate.csv", index=False, encoding="utf-8")

    h = agg.loc[agg.loss == "hinge", "success_rate"].values[0]
    c = agg.loc[agg.loss == "cos",   "success_rate"].values[0]
    print(f"✓ scifact done  –  hinge success = {h:.2%},  cos success = {c:.2%}")
    del bert
    torch.cuda.empty_cache()
if __name__ == "__main__":
    # run_loss_compare_scifact() creates its own output directory, so no
    # separate makedirs (with a second copy of the path literal) is needed.
    run_loss_compare_scifact()
