# DeRAG Prompt Attack Tutorial  
**Part 1 – Retrieval-Augmented Generation (RAG)**  
Jerry Wang  
August 2025

This notebook demonstrates black-box prompt injection attacks on Retrieval-Augmented Generation (RAG) systems using Differential Evolution (DE). Inspired by the one-pixel attack in vision, the goal is to craft a short adversarial suffix that re-ranks retrieval results so that a chosen incorrect passage is retrieved instead of the correct one.

DeRAG treats the entire RAG pipeline as a black box: we do not require gradient access, model internals, or retriever weights. Instead, we evolve a population of suffix candidates and measure their ranking success through similarity metrics, aiming to minimize the number of tokens needed for a successful attack.
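
To make the search loop concrete, here is a minimal pure-Python sketch of a DE/rand/1/bin loop over integer token IDs. Everything in it is an assumption for illustration: `toy_loss` stands in for the black-box retriever score, and `VOCAB`, `SECRET`, `pop`, `F`, and `CR` are hypothetical settings, not the values used in the experiments. The notebook itself relies on `scipy.optimize.differential_evolution` rather than a hand-written loop.

```python
import random

VOCAB = 1000             # hypothetical vocabulary size
SECRET = [17, 256, 901]  # hypothetical suffix the toy black box "prefers"

def toy_loss(suffix):
    """Stand-in for the black-box score: 0 means the attack succeeded."""
    return sum(abs(a - b) for a, b in zip(suffix, SECRET)) / VOCAB

def de_search(loss, length, pop=20, gens=200, F=0.5, CR=0.7, seed=0):
    rng = random.Random(seed)
    population = [[rng.randrange(VOCAB) for _ in range(length)] for _ in range(pop)]
    scores = [loss(ind) for ind in population]
    history = [min(scores)]
    for _ in range(gens):
        for i in range(pop):
            # Pick three distinct individuals other than i.
            a, b, c = rng.sample([j for j in range(pop) if j != i], 3)
            j_rand = rng.randrange(length)  # guarantees at least one mutated position
            trial = []
            for j in range(length):
                if rng.random() < CR or j == j_rand:
                    v = population[a][j] + F * (population[b][j] - population[c][j])
                    v = min(VOCAB - 1, max(0, int(round(v))))  # clip to valid token IDs
                else:
                    v = population[i][j]
                trial.append(v)
            s = loss(trial)
            if s <= scores[i]:  # greedy selection: keep the better of parent and trial
                population[i], scores[i] = trial, s
        history.append(min(scores))
        if history[-1] == 0:
            break
    best = min(range(pop), key=scores.__getitem__)
    return population[best], scores[best], history

best_suffix, best_loss, history = de_search(toy_loss, length=len(SECRET))
print(best_suffix, best_loss)
```

Because a trial only replaces its parent when it scores no worse, the best loss in the population never increases from one generation to the next. Only loss values are needed, never gradients.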

In theory, robust RAG models should be resistant to such small manipulations. However, our results reveal that even a few appended tokens (often ≤ 5) can reliably mislead state-of-the-art retrievers. This demonstrates a fundamental vulnerability in many LLM-based QA systems.

To learn more, please refer to our paper:
**“DeRAG: Black-box Adversarial Attacks on Retrieval-Augmented Generation Applications via Prompt Injection”**  
Presented at KDD Workshop on Prompt Optimization (2025).  
[GitHub Repo](https://github.com/pen9rum/Rag_attack_DeRag)

Let's get started.

---

## Imports  
Ensure that you have the following packages installed (in a notebook cell, prefix the command with `!`). The code below also imports `torch`, `numpy`, and `pandas`, which this command does not install.
```bash
pip install -q --upgrade datasets transformers scipy ir-datasets tqdm
```


## DeRAG Attack Across Four Datasets

This notebook implements and benchmarks DeRAG, a black-box prompt injection attack using Differential Evolution (DE), across four benchmark datasets: **MS MARCO**, **FiQA**, **FEVER**, and **SciFact**. All four are part of the [BEIR benchmark suite](https://arxiv.org/abs/2104.08663), covering domains from financial QA to fact verification.

For each dataset, we randomly sample 1,000 documents and 100 queries. Then, for every query, we select a target "incorrect" document (in the code below, a uniformly random document from the sampled corpus; the query is skipped if that document exceeds the tokenizer's length limit) and attempt to promote it into the Top-1, Top-10, or Top-20 retrieved results using several attack methods.
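
The "promote into the Top-K" goal can be written as a hinge loss: the gap between the K-th best competing similarity and the target's similarity, clipped at zero. The sketch below mirrors the `max(0, kth - sim)` expression used in the code cells, but on plain Python lists with made-up similarity values; the function name `topk_hinge` is hypothetical.

```python
def topk_hinge(competitor_sims, target_sim, k):
    """Zero when the target's similarity is at least that of the k-th best competitor."""
    kth = sorted(competitor_sims, reverse=True)[k - 1]
    return max(0.0, kth - target_sim)

# Made-up cosine similarities between one query and five non-target documents.
competitors = [0.91, 0.84, 0.80, 0.62, 0.55]

print(topk_hinge(competitors, target_sim=0.70, k=1))  # target must beat 0.91
print(topk_hinge(competitors, target_sim=0.85, k=2))  # target already beats the 2nd best
```

A loss of zero means at most K-1 competitors score strictly higher than the target, so the target ranks within the top K.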

The evaluated attack methods include:
- `none`: Baseline retrieval without perturbation.
- `random`: Uniformly sampling suffix tokens.
- `greedy`: HotFlip-style greedy search.
- `ggpp`: Gradient-Guided Prompt Perturbation (white-box).
- `DE_fixed`: Fixed-length DE without early stopping.
- `DE_seq`: Sequential DE with incremental suffix lengths.
- `DE_fixed_stop`: Fixed-length DE with early stopping.
- `DE_seq_stop`: Sequential DE with early stopping (default method).
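
The `_seq` and `_stop` variants differ only in their outer control flow. As a rough sketch (the names `sequential_attack`, `PlateauStopper`, and `toy_optimize` are hypothetical, and `optimize` stands for any black-box optimizer that returns a suffix and its loss), sequential search tries lengths 1, 2, ... and stops at the first success, while early stopping halts a run once the best loss has not improved for `patience` consecutive checks:

```python
class PlateauStopper:
    """Signals a stop after `patience` consecutive checks without improvement."""
    def __init__(self, patience):
        self.patience, self.best, self.stale = patience, float("inf"), 0

    def should_stop(self, loss):
        if loss < self.best:
            self.best, self.stale = loss, 0
        else:
            self.stale += 1
        return loss == 0 or self.stale >= self.patience

def sequential_attack(optimize, max_len):
    """Try suffix lengths 1..max_len; return (suffix, length, success)."""
    suffix = []
    for length in range(1, max_len + 1):
        suffix, loss = optimize(length)
        if loss == 0:
            return suffix, length, True
    return suffix, max_len, False

# Toy optimizer: pretends that only suffixes of length >= 3 can succeed.
def toy_optimize(length):
    return [0] * length, (0.0 if length >= 3 else 0.1)

print(sequential_attack(toy_optimize, max_len=5))  # → ([0, 0, 0], 3, True)

stopper = PlateauStopper(patience=2)
print([stopper.should_stop(x) for x in (0.5, 0.4, 0.4, 0.4)])  # → [False, False, False, True]
```

Trying short suffixes first is what lets the sequential variants report the smallest successful token count rather than always using the maximum length.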

Each method is evaluated using multiple metrics:
- `Success@K`: Whether the target document appears in the top-K results.
- `Token Used`: Number of tokens in the adversarial suffix.
- `Iteration Count`: Optimization steps required (for the DE variants, the code records the number of objective evaluations, `res.nfev`).
- `ΔMRR`: Change in Mean Reciprocal Rank.
- `ΔnDCG@20`: Change in ranking quality at cutoff 20.
- `ΔCosine`: Change in semantic similarity between the query and the target document.
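
For a single target document, the ranking metrics reduce to simple functions of its rank before and after the attack. The sketch below follows the formulas in the code cells (rank = number of strictly higher similarities + 1, gain = 1/log2(rank + 1) within the cutoff); the helper names and similarity values are made up for illustration.

```python
import math

def rank_of_target(competitor_sims, target_sim):
    """1-based rank of the target among all documents (ties favour the target)."""
    return sum(s > target_sim for s in competitor_sims) + 1

def gain_at(rank, cutoff=20):
    """Discounted gain of a single relevant item, zero beyond the cutoff."""
    return 1 / math.log2(rank + 1) if rank <= cutoff else 0.0

def deltas(rank_before, rank_after, cutoff=20):
    """Return (delta MRR, delta nDCG@cutoff) for one target document."""
    d_mrr = 1 / rank_after - 1 / rank_before
    d_ndcg = gain_at(rank_after, cutoff) - gain_at(rank_before, cutoff)
    return d_mrr, d_ndcg

competitors = [0.9, 0.8, 0.7, 0.6]
before = rank_of_target(competitors, target_sim=0.65)  # three documents score higher
after = rank_of_target(competitors, target_sim=0.95)   # no document scores higher
d_mrr, d_ndcg = deltas(before, after)
print(before, after, d_mrr, d_ndcg)
```

With exactly one relevant item, the ideal DCG is 1, so the discounted gain already equals nDCG and no extra normalization is needed.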

To run the experiments, call `run_all()` once for each ranking threshold:

```python
for K in (1, 10, 20):
    run_all(topk_loss=K, save_dir=f"results_top{K}")
```


In [6]:
import os, random, math, tqdm, warnings, numpy as np, pandas as pd, torch
from datasets import load_dataset
from transformers import BertTokenizer, BertModel
from scipy.optimize import differential_evolution

warnings.filterwarnings("ignore")
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

DEVICE     = "cuda" if torch.cuda.is_available() else "cpu"
N_DOCS     = 1_000
N_Q        = 100
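TAIL_L     = 5         # maximum adversarial suffix length, in tokens
BUDGET     = 150       # search budget: iterations for random/greedy/ggpp; DE runs BUDGET//20 generations
PATIENCE   = 20        # early stopping: DE generations without improvement before halting
BATCH_CLS  = 16        # batch size when encoding the corpus into [CLS] vectors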

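def cos_row(x, Y):
    # Cosine similarity between one vector x (1 x d) and every row of Y (n x d).
    return torch.nn.functional.cosine_similarity(
        x.expand_as(Y.to(x.device)), Y.to(x.device), dim=1)

# Discounted gain of a single relevant item at a 1-based rank.
def dcg(rank): return 1 / math.log2(rank + 1)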

def load_subset():
    c = load_dataset("BeIR/scifact", "corpus",  split="corpus")
    q = load_dataset("BeIR/scifact", "queries", split="queries")
    docs = random.sample(list(c), N_DOCS)
    qs   = random.sample(list(q), N_Q)
    return [d["text"] for d in docs], [x["text"] for x in qs]

def ggpp_full(tok, bert, ids, msk, CP, tgt_cls, tgt_txt,
              cap_L=5, max_epoch=150, topk_loss=20):
    body = tok(tgt_txt, add_special_tokens=False,
               truncation=True, max_length=510)["input_ids"]
    body += [tok.unk_token_id] * max(0, cap_L - len(body))

    base = bert(**tok(tgt_txt, return_tensors="pt",
                      truncation=True, max_length=512).to(DEVICE)
               ).last_hidden_state[0, 0]
    imp = []
    for i in range(cap_L):
        tmp = body.copy(); tmp[i] = tok.mask_token_id
        tens = torch.tensor([tok.cls_token_id]+tmp+[tok.sep_token_id]
                           ).unsqueeze(0).to(DEVICE)
        emb = bert(input_ids=tens,
                   attention_mask=torch.ones_like(tens)).last_hidden_state[0, 0]
        imp.append(1 - torch.nn.functional.cosine_similarity(base, emb, dim=0).item())
    prefix = [body[i] for i in np.argsort(imp)[-cap_L:]]

    pos   = list(range(502, 502+cap_L))
    patch = ids.clone()
    for p, v in zip(pos, prefix): patch[p] = v
    am = msk.clone(); am[pos] = 1
    W  = bert.embeddings.word_embeddings.weight

    def loss_fn(cls_vec):
        kth=torch.topk(cos_row(cls_vec,CP),topk_loss).values[-1]
        sim=torch.nn.functional.cosine_similarity(cls_vec,tgt_cls)[0]
        return torch.relu(kth-sim)

    best_cls=bert(input_ids=patch.unsqueeze(0),
                  attention_mask=am.unsqueeze(0)).last_hidden_state[:,0,:]
    best_loss=loss_fn(best_cls).item(); used_iter=0

    for _ in range(max_epoch):
        used_iter+=1
        emb=bert.embeddings.word_embeddings(
            patch.unsqueeze(0)).detach().clone().requires_grad_(True)
        cls=bert(inputs_embeds=emb,
                 attention_mask=am.unsqueeze(0)).last_hidden_state[:,0,:]
        loss=loss_fn(cls); loss.backward()
        if loss.item()==0: best_cls=cls; break
        grad=emb.grad[0,pos]; score=torch.matmul(grad,W.t())
        cand_ids=score.topk(5,largest=False,dim=1).indices.cpu()
        improved=False
        for mask in range(1,1<<cap_L):
            cand=patch.clone()
            for i in range(cap_L):
                if mask&(1<<i): cand[pos[i]]=random.choice(cand_ids[i]).item()
            cls2=bert(input_ids=cand.unsqueeze(0),
                      attention_mask=am.unsqueeze(0)).last_hidden_state[:,0,:]
            l2=loss_fn(cls2).item()
            if l2<best_loss:
                best_loss=l2; patch=cand; best_cls=cls2; improved=True
            if l2==0: break
        if not improved: break
    return best_cls,cap_L,used_iter


def run_all(topk_loss, save_dir):
    os.makedirs(save_dir, exist_ok=True)

    tok  = BertTokenizer.from_pretrained("bert-base-uncased")
    bert = BertModel.from_pretrained("bert-base-uncased").eval().to(DEVICE)
    VOC  = tok.vocab_size
    def enc(t): return tok(t, padding="max_length", truncation=True,
                           max_length=512, return_tensors="pt").to(DEVICE)

    docs, qs = load_subset()
    CLS=[]
    with torch.inference_mode():
        for i in tqdm.tqdm(range(0, len(docs), BATCH_CLS), desc="CLS"):
            bt = enc(docs[i:i+BATCH_CLS])
            cls= bert(**bt).last_hidden_state[:, 0, :].cpu()
            CLS.append(cls)
            del bt; torch.cuda.empty_cache()
    C_CLS = torch.cat(CLS)

    METHODS = ("none", "random", "greedy", "ggpp",
               "DE_fixed", "DE_seq", "DE_fixed_stop", "DE_seq_stop")
    rec=[]

    for qtxt in tqdm.tqdm(qs, desc=f"Top-{topk_loss}"):
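        # Pick a random target document; skip the query if the target is too long for BERT.
        # Skipped queries are why the final row count is below N_Q * len(METHODS).
        tgt = random.randrange(len(docs))
        tgt_txt = docs[tgt]
        if len(tok(tgt_txt)["input_ids"]) > 510: continue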
        tgt_cls = C_CLS[tgt:tgt+1].to(DEVICE)
        CP = C_CLS[[i for i in range(len(docs)) if i != tgt]]

        qenc = enc(qtxt); ids = qenc["input_ids"][0]; msk = qenc["attention_mask"][0]
        with torch.no_grad():
            qcls = bert(**qenc).last_hidden_state[:, 0, :]

        base = torch.cat([
            cos_row(qcls, CP).cpu(),    
            torch.nn.functional.cosine_similarity(qcls.cpu(), tgt_cls.cpu())
        ])
        rank_b = (base > base[-1]).sum().item() + 1

        for mtd in METHODS:
            success=False; used_L=0; used_iter=0; adv_cls=qcls
            if mtd=="random":
                best_loss=1e9
                for _ in range(BUDGET):
                    patch=ids.clone()
                    for p in range(502, 502+TAIL_L):
                        patch[p] = random.randrange(VOC)
                    am = msk.clone(); am[502:502+TAIL_L] = 1
                    cls = bert(input_ids=patch.unsqueeze(0),
                               attention_mask=am.unsqueeze(0)).last_hidden_state[:,0,:]
                    kth = torch.topk(cos_row(cls, CP), topk_loss).values[-1]
                    loss= max(0., (kth - torch.nn.functional.cosine_similarity(
                                        cls, tgt_cls)[0]).item())
                    used_iter += 1
                    if loss < best_loss:
                        best_loss=loss; adv_cls=cls; success=(loss==0)
                    if success: break
                used_L=TAIL_L
            elif mtd=="greedy":
                pos=list(range(502, 502+TAIL_L)); patch=ids.clone()
                best_loss=1e9
                for _ in range(BUDGET):
                    used_iter += 1; improved=False
                    for p in pos:
                        best_id=patch[p].item()
                        for cand in random.sample(range(VOC),512):
                            patch[p]=cand
                            am = msk.clone(); am[pos]=1
                            cls=bert(input_ids=patch.unsqueeze(0),
                                     attention_mask=am.unsqueeze(0)).last_hidden_state[:,0,:]
                            kth=torch.topk(cos_row(cls,CP),topk_loss).values[-1]
                            loss=max(0.,(kth-torch.nn.functional.cosine_similarity(
                                              cls,tgt_cls)[0]).item())
                            if loss<best_loss:
                                best_loss=loss; best_id=cand; adv_cls=cls; improved=True
                            if loss==0: success=True; break
                        patch[p]=best_id
                        if success: break
                    if success or not improved: break
                used_L=TAIL_L
            elif mtd=="ggpp":
                adv_cls, used_L, used_iter = ggpp_full(
                    tok, bert, ids, msk, CP, tgt_cls, tgt_txt,
                    cap_L=TAIL_L, max_epoch=BUDGET, topk_loss=topk_loss)
                kth = torch.topk(cos_row(adv_cls, CP), topk_loss).values[-1]
                success = (kth <= torch.nn.functional.cosine_similarity(
                                      adv_cls, tgt_cls)[0])
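            def de_run(L, max_iter, plateau):
                # The suffix is written into padded positions 502..502+L-1 of the 512-token query.
                pos=list(range(502, 502+L)); bounds=[(0, VOC-1)]*L
                # SciPy's popsize is a multiplier: the population holds popsize*L candidates.
                gens = max_iter//20 if max_iter else 1000
                stop = [0]; best=[1e9]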
                def obj(v):
                    v=[int(round(x)) for x in v]
                    patch=ids.clone()
                    for p,t in zip(pos,v): patch[p]=t
                    am=msk.clone(); am[pos]=1
                    cls=bert(input_ids=patch.unsqueeze(0),
                             attention_mask=am.unsqueeze(0)).last_hidden_state[:,0,:]
                    kth=torch.topk(cos_row(cls,CP),topk_loss).values[-1]
                    sim=torch.nn.functional.cosine_similarity(cls,tgt_cls)[0]
                    return max(0.,(kth-sim).item())
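                # Accept the second argument positionally or as `convergence=`, whichever SciPy passes.
                def cb(xk, convergence=None):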
                    if not plateau: return False
                    cur=obj(xk)
                    stop[0] = 0 if cur < best[0] else stop[0]+1
                    best[0] = min(best[0], cur)
                    return stop[0] >= PATIENCE or cur == 0
                res = differential_evolution(obj, bounds, popsize=20,
                                             maxiter=gens, tol=0,
                                             polish=False, seed=SEED,
                                             callback=cb)
                v=[int(round(x)) for x in res.x]
                patch=ids.clone()
                for p,t in zip(pos,v): patch[p]=t
                am=msk.clone(); am[pos]=1
                cls=bert(input_ids=patch.unsqueeze(0),
                         attention_mask=am.unsqueeze(0)).last_hidden_state[:,0,:]
                return cls, res.nfev, res.fun==0

            if mtd=="DE_fixed":
                adv_cls, used_iter, success = de_run(TAIL_L, BUDGET, plateau=False); used_L=TAIL_L
            if mtd=="DE_fixed_stop":
                adv_cls, used_iter, success = de_run(TAIL_L, None, plateau=True); used_L=TAIL_L
            if mtd=="DE_seq":
                for L in range(1, TAIL_L+1):
                    cls,iters,ok = de_run(L, BUDGET, plateau=False)
                    used_iter += iters
                    if ok: adv_cls=cls; used_L=L; success=True; break
                else: used_L=TAIL_L
            if mtd=="DE_seq_stop":
                for L in range(1, TAIL_L+1):
                    cls,iters,ok = de_run(L, None, plateau=True)
                    used_iter += iters
                    if ok: adv_cls=cls; used_L=L; success=True; break
                else: used_L=TAIL_L
            sims = torch.cat([
                cos_row(adv_cls.cpu(), CP),      
                torch.nn.functional.cosine_similarity(
                    adv_cls.cpu(), tgt_cls.cpu())
            ])
            rank_a = (sims > sims[-1]).sum().item() + 1
            d_mrr  = 1/rank_a - 1/rank_b
            d_ndcg = ((dcg(rank_a) if rank_a<=20 else 0) -
                      (dcg(rank_b) if rank_b<=20 else 0))
            d_cos  = (torch.nn.functional.cosine_similarity(
                        adv_cls, tgt_cls)[0] -
                      torch.nn.functional.cosine_similarity(
                        qcls, tgt_cls)[0]).item()

            rec.append(dict(
                top_k      = topk_loss,
                method     = mtd,
                success    = int(rank_a <= topk_loss),
                token_used = used_L,
                iter_used  = used_iter,
                delta_mrr  = d_mrr,
                delta_ndcg = d_ndcg,
                delta_cos  = d_cos
            ))
        torch.cuda.empty_cache()  

    pd.DataFrame(rec).to_csv(f"{save_dir}/records.csv", index=False)
    print(f"✓ {save_dir}  rows = {len(rec)}")
for K in (1, 10, 20):
    run_all(topk_loss=K, save_dir=f"results_top{K}")
    torch.cuda.empty_cache()


CLS: 100%|██████████| 63/63 [00:13<00:00,  4.58it/s]
Token indices sequence length is longer than the specified maximum sequence length for this model (575 > 512). Running this sequence through the model will result in indexing errors
Top-1: 100%|██████████| 100/100 [13:45:01<00:00, 495.02s/it]
✓ results_top1  rows = 768

CLS: 100%|██████████| 63/63 [00:16<00:00,  3.78it/s]
Token indices sequence length is longer than the specified maximum sequence length for this model (652 > 512). Running this sequence through the model will result in indexing errors
Top-10: 100%|██████████| 100/100 [8:25:42<00:00, 303.42s/it]
✓ results_top10  rows = 768

CLS: 100%|██████████| 63/63 [00:17<00:00,  3.69it/s]
Token indices sequence length is longer than the specified maximum sequence length for this model (899 > 512). Running this sequence through the model will result in indexing errors
Top-20: 100%|██████████| 100/100 [6:48:30<00:00, 245.10s/it]

✓ results_top20  rows = 736
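
Each run writes one row per (query, method) pair to `records.csv`. Queries whose target document exceeds the length limit are skipped, which is why the row counts above fall below 100 × 8 = 800 (for example, 768 = 96 × 8). One quick way to summarize such a file is sketched below with the standard `csv` module. The column names match the `rec.append(dict(...))` call in the cell above, but the three sample rows are made up so that the snippet runs without the real results, and `summarize` is a hypothetical helper.

```python
import csv
import io
from collections import defaultdict

# Made-up stand-in for results_top1/records.csv (same columns as the notebook writes).
SAMPLE = """top_k,method,success,token_used,iter_used,delta_mrr,delta_ndcg,delta_cos
1,none,0,0,0,0.0,0.0,0.0
1,DE_seq_stop,1,2,840,0.9,0.8,0.05
1,DE_seq_stop,0,5,4200,0.1,0.0,0.02
"""

def summarize(csv_file):
    """Per-method success rate and mean suffix length over successful attacks."""
    rows = defaultdict(list)
    for r in csv.DictReader(csv_file):
        rows[r["method"]].append(r)
    summary = {}
    for method, rs in rows.items():
        wins = [r for r in rs if r["success"] == "1"]
        summary[method] = {
            "n": len(rs),
            "success_rate": len(wins) / len(rs),
            "mean_tokens_on_success": (
                sum(int(r["token_used"]) for r in wins) / len(wins) if wins else None),
        }
    return summary

summary = summarize(io.StringIO(SAMPLE))
for method, stats in summary.items():
    print(method, stats)
```

For the real files, pass `open("results_top1/records.csv", newline="")` instead of the `StringIO` object.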





In [9]:

import os, random, math, tqdm, warnings, numpy as np, pandas as pd, torch
from datasets import load_dataset
from transformers import BertTokenizer, BertModel
from scipy.optimize import differential_evolution

warnings.filterwarnings("ignore")
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

DEVICE     = "cuda" if torch.cuda.is_available() else "cpu"
N_DOCS     = 1_000
N_Q        = 100
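TAIL_L     = 5          # maximum adversarial suffix length, in tokens
BUDGET     = 150        # search budget: iterations for random/greedy/ggpp; DE runs BUDGET//20 generations
PATIENCE   = 20         # early stopping: DE generations without improvement before halting
BATCH_CLS  = 16         # batch size when encoding the corpus into [CLS] vectors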

def cos_row(x, Y):
    return torch.nn.functional.cosine_similarity(
        x.expand_as(Y.to(x.device)), Y.to(x.device), dim=1)

def dcg(rank): return 1 / math.log2(rank + 1)

def load_subset():
    c = load_dataset("BeIR/fiqa", "corpus",  split="corpus")
    q = load_dataset("BeIR/fiqa", "queries", split="queries")
    docs = random.sample(list(c), N_DOCS)
    qs   = random.sample(list(q), N_Q)
    return [d["text"] for d in docs], [x["text"] for x in qs]

def ggpp_full(tok, bert, ids, msk, CP, tgt_cls, tgt_txt,
              cap_L=5, max_epoch=150, topk_loss=20):
    body = tok(tgt_txt, add_special_tokens=False,
               truncation=True, max_length=510)["input_ids"]
    body += [tok.unk_token_id] * max(0, cap_L - len(body))

    base = bert(**tok(tgt_txt, return_tensors="pt",
                      truncation=True, max_length=512).to(DEVICE)
               ).last_hidden_state[0, 0]
    imp = []
    for i in range(cap_L):
        tmp = body.copy(); tmp[i] = tok.mask_token_id
        tens = torch.tensor([tok.cls_token_id]+tmp+[tok.sep_token_id]
                           ).unsqueeze(0).to(DEVICE)
        emb = bert(input_ids=tens,
                   attention_mask=torch.ones_like(tens)).last_hidden_state[0, 0]
        imp.append(1 - torch.nn.functional.cosine_similarity(base, emb, dim=0).item())
    prefix = [body[i] for i in np.argsort(imp)[-cap_L:]]

    pos   = list(range(502, 502+cap_L))
    patch = ids.clone()
    for p, v in zip(pos, prefix): patch[p] = v
    am = msk.clone(); am[pos] = 1
    W  = bert.embeddings.word_embeddings.weight

    def loss_fn(cls_vec):
        kth=torch.topk(cos_row(cls_vec,CP),topk_loss).values[-1]
        sim=torch.nn.functional.cosine_similarity(cls_vec,tgt_cls)[0]
        return torch.relu(kth-sim)

    best_cls=bert(input_ids=patch.unsqueeze(0),
                  attention_mask=am.unsqueeze(0)).last_hidden_state[:,0,:]
    best_loss=loss_fn(best_cls).item(); used_iter=0

    for _ in range(max_epoch):
        used_iter+=1
        emb=bert.embeddings.word_embeddings(
            patch.unsqueeze(0)).detach().clone().requires_grad_(True)
        cls=bert(inputs_embeds=emb,
                 attention_mask=am.unsqueeze(0)).last_hidden_state[:,0,:]
        loss=loss_fn(cls); loss.backward()
        if loss.item()==0: best_cls=cls; break
        grad=emb.grad[0,pos]; score=torch.matmul(grad,W.t())
        cand_ids=score.topk(5,largest=False,dim=1).indices.cpu()
        improved=False
        for mask in range(1,1<<cap_L):
            cand=patch.clone()
            for i in range(cap_L):
                if mask&(1<<i): cand[pos[i]]=random.choice(cand_ids[i]).item()
            cls2=bert(input_ids=cand.unsqueeze(0),
                      attention_mask=am.unsqueeze(0)).last_hidden_state[:,0,:]
            l2=loss_fn(cls2).item()
            if l2<best_loss:
                best_loss=l2; patch=cand; best_cls=cls2; improved=True
            if l2==0: break
        if not improved: break
    return best_cls,cap_L,used_iter

def run_all(topk_loss, save_dir):
    os.makedirs(f"{save_dir}/fiqa", exist_ok=True)

    tok  = BertTokenizer.from_pretrained("bert-base-uncased")
    bert = BertModel.from_pretrained("bert-base-uncased").eval().to(DEVICE)
    VOC  = tok.vocab_size
    def enc(t): return tok(t, padding="max_length", truncation=True,
                           max_length=512, return_tensors="pt").to(DEVICE)

    docs, qs = load_subset()
    CLS=[]
    with torch.inference_mode():
        for i in tqdm.tqdm(range(0, len(docs), BATCH_CLS), desc="CLS"):
            bt = enc(docs[i:i+BATCH_CLS])
            cls= bert(**bt).last_hidden_state[:, 0, :].cpu()
            CLS.append(cls)
            del bt; torch.cuda.empty_cache()
    C_CLS = torch.cat(CLS)

    # "greedy" is not run in this cell, so its branch further down is never reached.
    METHODS = ("none", "random", "ggpp",
               "DE_fixed", "DE_seq", "DE_fixed_stop", "DE_seq_stop")
    rec=[]
    for qtxt in tqdm.tqdm(qs, desc=f"Top-{topk_loss}"):
        tgt = random.randrange(len(docs))
        tgt_txt = docs[tgt]
        if len(tok(tgt_txt)["input_ids"]) > 510: continue
        tgt_cls = C_CLS[tgt:tgt+1].to(DEVICE)
        CP = C_CLS[[i for i in range(len(docs)) if i != tgt]]

        qenc = enc(qtxt); ids = qenc["input_ids"][0]; msk = qenc["attention_mask"][0]
        with torch.no_grad():
            qcls = bert(**qenc).last_hidden_state[:, 0, :]
        base = torch.cat([
            cos_row(qcls, CP).cpu(),     # move to CPU
            torch.nn.functional.cosine_similarity(qcls.cpu(), tgt_cls.cpu())
        ])
        rank_b = (base > base[-1]).sum().item() + 1

        for mtd in METHODS:
            success=False; used_L=0; used_iter=0; adv_cls=qcls
            if mtd=="random":
                best_loss=1e9
                for _ in range(BUDGET):
                    patch=ids.clone()
                    for p in range(502, 502+TAIL_L):
                        patch[p] = random.randrange(VOC)
                    am = msk.clone(); am[502:502+TAIL_L] = 1
                    cls = bert(input_ids=patch.unsqueeze(0),
                               attention_mask=am.unsqueeze(0)).last_hidden_state[:,0,:]
                    kth = torch.topk(cos_row(cls, CP), topk_loss).values[-1]
                    loss= max(0., (kth - torch.nn.functional.cosine_similarity(
                                        cls, tgt_cls)[0]).item())
                    used_iter += 1
                    if loss < best_loss:
                        best_loss=loss; adv_cls=cls; success=(loss==0)
                    if success: break
                used_L=TAIL_L
            elif mtd=="greedy":
                pos=list(range(502, 502+TAIL_L)); patch=ids.clone()
                best_loss=1e9
                for _ in range(BUDGET):
                    used_iter += 1; improved=False
                    for p in pos:
                        best_id=patch[p].item()
                        for cand in random.sample(range(VOC),512):
                            patch[p]=cand
                            am = msk.clone(); am[pos]=1
                            cls=bert(input_ids=patch.unsqueeze(0),
                                     attention_mask=am.unsqueeze(0)).last_hidden_state[:,0,:]
                            kth=torch.topk(cos_row(cls,CP),topk_loss).values[-1]
                            loss=max(0.,(kth-torch.nn.functional.cosine_similarity(
                                              cls,tgt_cls)[0]).item())
                            if loss<best_loss:
                                best_loss=loss; best_id=cand; adv_cls=cls; improved=True
                            if loss==0: success=True; break
                        patch[p]=best_id
                        if success: break
                    if success or not improved: break
                used_L=TAIL_L
            elif mtd=="ggpp":
                adv_cls, used_L, used_iter = ggpp_full(
                    tok, bert, ids, msk, CP, tgt_cls, tgt_txt,
                    cap_L=TAIL_L, max_epoch=BUDGET, topk_loss=topk_loss)
                kth = torch.topk(cos_row(adv_cls, CP), topk_loss).values[-1]
                success = (kth <= torch.nn.functional.cosine_similarity(
                                      adv_cls, tgt_cls)[0])
            def de_run(L, max_iter, plateau):
                pos=list(range(502, 502+L)); bounds=[(0, VOC-1)]*L
                gens = max_iter//20 if max_iter else 1000
                stop = [0]; best=[1e9]
                def obj(v):
                    v=[int(round(x)) for x in v]
                    patch=ids.clone()
                    for p,t in zip(pos,v): patch[p]=t
                    am=msk.clone(); am[pos]=1
                    cls=bert(input_ids=patch.unsqueeze(0),
                             attention_mask=am.unsqueeze(0)).last_hidden_state[:,0,:]
                    kth=torch.topk(cos_row(cls,CP),topk_loss).values[-1]
                    sim=torch.nn.functional.cosine_similarity(cls,tgt_cls)[0]
                    return max(0.,(kth-sim).item())
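                # Accept the second argument positionally or as `convergence=`, whichever SciPy passes.
                def cb(xk, convergence=None):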
                    if not plateau: return False
                    cur=obj(xk)
                    stop[0] = 0 if cur < best[0] else stop[0]+1
                    best[0] = min(best[0], cur)
                    return stop[0] >= PATIENCE or cur == 0
                res = differential_evolution(obj, bounds, popsize=20,
                                             maxiter=gens, tol=0,
                                             polish=False, seed=SEED,
                                             callback=cb)
                v=[int(round(x)) for x in res.x]
                patch=ids.clone()
                for p,t in zip(pos,v): patch[p]=t
                am=msk.clone(); am[pos]=1
                cls=bert(input_ids=patch.unsqueeze(0),
                         attention_mask=am.unsqueeze(0)).last_hidden_state[:,0,:]
                return cls, res.nfev, res.fun==0

            if mtd=="DE_fixed":
                adv_cls, used_iter, success = de_run(TAIL_L, BUDGET, plateau=False); used_L=TAIL_L
            if mtd=="DE_fixed_stop":
                adv_cls, used_iter, success = de_run(TAIL_L, None, plateau=True); used_L=TAIL_L
            if mtd=="DE_seq":
                for L in range(1, TAIL_L+1):
                    cls,iters,ok = de_run(L, BUDGET, plateau=False)
                    used_iter += iters
                    if ok: adv_cls=cls; used_L=L; success=True; break
                else: used_L=TAIL_L
            if mtd=="DE_seq_stop":
                for L in range(1, TAIL_L+1):
                    cls,iters,ok = de_run(L, None, plateau=True)
                    used_iter += iters
                    if ok: adv_cls=cls; used_L=L; success=True; break
                else: used_L=TAIL_L
            sims = torch.cat([
                cos_row(adv_cls.cpu(), CP),       
                torch.nn.functional.cosine_similarity(
                    adv_cls.cpu(), tgt_cls.cpu())
            ])
            rank_a = (sims > sims[-1]).sum().item() + 1
            d_mrr  = 1/rank_a - 1/rank_b
            d_ndcg = ((dcg(rank_a) if rank_a<=20 else 0) -
                      (dcg(rank_b) if rank_b<=20 else 0))
            d_cos  = (torch.nn.functional.cosine_similarity(
                        adv_cls, tgt_cls)[0] -
                      torch.nn.functional.cosine_similarity(
                        qcls, tgt_cls)[0]).item()

            rec.append(dict(
                top_k      = topk_loss,
                method     = mtd,
                success    = int(rank_a <= topk_loss),
                token_used = used_L,
                iter_used  = used_iter,
                delta_mrr  = d_mrr,
                delta_ndcg = d_ndcg,
                delta_cos  = d_cos
            ))
        torch.cuda.empty_cache()   
    pd.DataFrame(rec).to_csv(f"{save_dir}/fiqa/records.csv", index=False)
    print(f"✓ {save_dir}  rows = {len(rec)}")

for K in (1, 10, 20):
    run_all(topk_loss=K, save_dir=f"results_top{K}")
    torch.cuda.empty_cache()


CLS: 100%|██████████| 63/63 [00:12<00:00,  4.89it/s]
Token indices sequence length is longer than the specified maximum sequence length for this model (1585 > 512). Running this sequence through the model will result in indexing errors
Top-1: 100%|██████████| 100/100 [10:01:57<00:00, 361.18s/it]
✓ results_top1  rows = 665

CLS: 100%|██████████| 63/63 [00:12<00:00,  4.87it/s]
Token indices sequence length is longer than the specified maximum sequence length for this model (664 > 512). Running this sequence through the model will result in indexing errors
Top-10: 100%|██████████| 100/100 [7:19:14<00:00, 263.55s/it]
✓ results_top10  rows = 686

CLS: 100%|██████████| 63/63 [00:13<00:00,  4.79it/s]
Token indices sequence length is longer than the specified maximum sequence length for this model (560 > 512). Running this sequence through the model will result in indexing errors
Top-20: 100%|██████████| 100/100 [7:03:33<00:00, 254.14s/it]

✓ results_top20  rows = 693





In [32]:

import os, random, math, tqdm, warnings, numpy as np, pandas as pd, torch
from datasets import load_dataset
from transformers import BertTokenizer, BertModel
from scipy.optimize import differential_evolution

warnings.filterwarnings("ignore")
SEED = 41
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

DEVICE     = "cuda" if torch.cuda.is_available() else "cpu"
N_DOCS     = 1_000
N_Q        = 100
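TAIL_L     = 5         # maximum adversarial suffix length, in tokens
BUDGET     = 150       # search budget: iterations for random/greedy/ggpp; DE runs BUDGET//20 generations
PATIENCE   = 20        # early stopping: DE generations without improvement before halting
BATCH_CLS  = 16        # batch size when encoding the corpus into [CLS] vectors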

def cos_row(x, Y):
    return torch.nn.functional.cosine_similarity(
        x.expand_as(Y.to(x.device)), Y.to(x.device), dim=1)

def dcg(rank): return 1 / math.log2(rank + 1)
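
def load_subset():
    # MS MARCO v1.1 validation split (about 10k examples)
    ds = load_dataset("microsoft/ms_marco", "v1.1", split="validation")
    all_passages, all_queries = [], []
    for ex in ds:
        all_queries.append(ex["query"])
        # On the Hugging Face version, "passages" is a dict of parallel lists;
        # extend with the passage strings rather than with the dict's keys.
        all_passages.extend(ex["passages"]["passage_text"])

    docs = random.sample(all_passages, N_DOCS)
    qs   = random.sample(all_queries, N_Q)
    return docs, qs
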
def ggpp_full(tok, bert, ids, msk, CP, tgt_cls, tgt_txt,
              cap_L=5, max_epoch=150, topk_loss=20):
    body = tok(tgt_txt, add_special_tokens=False,
               truncation=True, max_length=510)["input_ids"]
    body += [tok.unk_token_id] * max(0, cap_L - len(body))

    base = bert(**tok(tgt_txt, return_tensors="pt",
                      truncation=True, max_length=512).to(DEVICE)
               ).last_hidden_state[0, 0]
    imp = []
    for i in range(cap_L):
        tmp = body.copy(); tmp[i] = tok.mask_token_id
        tens = torch.tensor([tok.cls_token_id]+tmp+[tok.sep_token_id]
                           ).unsqueeze(0).to(DEVICE)
        emb = bert(input_ids=tens,
                   attention_mask=torch.ones_like(tens)).last_hidden_state[0, 0]
        imp.append(1 - torch.nn.functional.cosine_similarity(base, emb, dim=0).item())
    prefix = [body[i] for i in np.argsort(imp)[-cap_L:]]

    pos   = list(range(502, 502+cap_L))
    patch = ids.clone()
    for p, v in zip(pos, prefix): patch[p] = v
    am = msk.clone(); am[pos] = 1
    W  = bert.embeddings.word_embeddings.weight
    def loss_fn(cls_vec):
        kth=torch.topk(cos_row(cls_vec,CP),topk_loss).values[-1]
        sim=torch.nn.functional.cosine_similarity(cls_vec,tgt_cls)[0]
        return torch.relu(kth-sim)

    best_cls=bert(input_ids=patch.unsqueeze(0),
                  attention_mask=am.unsqueeze(0)).last_hidden_state[:,0,:]
    best_loss=loss_fn(best_cls).item(); used_iter=0

    for _ in range(max_epoch):
        used_iter+=1
        emb=bert.embeddings.word_embeddings(
            patch.unsqueeze(0)).detach().clone().requires_grad_(True)
        cls=bert(inputs_embeds=emb,
                 attention_mask=am.unsqueeze(0)).last_hidden_state[:,0,:]
        loss=loss_fn(cls); loss.backward()
        if loss.item()==0: best_cls=cls; break
        grad=emb.grad[0,pos]; score=torch.matmul(grad,W.t())
        cand_ids=score.topk(5,largest=False,dim=1).indices.cpu()
        improved=False
        for mask in range(1,1<<cap_L):
            cand=patch.clone()
            for i in range(cap_L):
                if mask&(1<<i): cand[pos[i]]=random.choice(cand_ids[i]).item()
            cls2=bert(input_ids=cand.unsqueeze(0),
                      attention_mask=am.unsqueeze(0)).last_hidden_state[:,0,:]
            l2=loss_fn(cls2).item()
            if l2<best_loss:
                best_loss=l2; patch=cand; best_cls=cls2; improved=True
            if l2==0: break
        if not improved: break
    return best_cls,cap_L,used_iter
def run_all(topk_loss, save_dir):
    os.makedirs(f"{save_dir}/msmarco", exist_ok=True)

    tok  = BertTokenizer.from_pretrained("bert-base-uncased")
    bert = BertModel.from_pretrained("bert-base-uncased").eval().to(DEVICE)
    VOC  = tok.vocab_size
    def enc(t): return tok(t, padding="max_length", truncation=True,
                           max_length=512, return_tensors="pt").to(DEVICE)
    docs, qs = load_subset()
    CLS=[]
    with torch.inference_mode():
        for i in tqdm.tqdm(range(0, len(docs), BATCH_CLS), desc="CLS"):
            bt = enc(docs[i:i+BATCH_CLS])
            cls= bert(**bt).last_hidden_state[:, 0, :].cpu()
            CLS.append(cls)
            del bt; torch.cuda.empty_cache()
    C_CLS = torch.cat(CLS)

    # "greedy" is implemented in the loop below but left out of this run
    METHODS = ("none", "random", "ggpp",
               "DE_fixed", "DE_seq", "DE_fixed_stop", "DE_seq_stop")
    rec=[]
    for qtxt in tqdm.tqdm(qs, desc=f"Top-{topk_loss}"):
        tgt = random.randrange(len(docs))
        tgt_txt = docs[tgt]
        if len(tok(tgt_txt)["input_ids"]) > 510: continue
        tgt_cls = C_CLS[tgt:tgt+1].to(DEVICE)
        CP = C_CLS[[i for i in range(len(docs)) if i != tgt]]

        qenc = enc(qtxt); ids = qenc["input_ids"][0]; msk = qenc["attention_mask"][0]
        with torch.no_grad():
            qcls = bert(**qenc).last_hidden_state[:, 0, :]
        base = torch.cat([
            cos_row(qcls, CP).cpu(),     # move to CPU
            torch.nn.functional.cosine_similarity(qcls.cpu(), tgt_cls.cpu())
        ])
        rank_b = (base > base[-1]).sum().item() + 1

        for mtd in METHODS:
            success=False; used_L=0; used_iter=0; adv_cls=qcls
            if mtd=="random":
                best_loss=1e9
                for _ in range(BUDGET):
                    patch=ids.clone()
                    for p in range(502, 502+TAIL_L):
                        patch[p] = random.randrange(VOC)
                    am = msk.clone(); am[502:502+TAIL_L] = 1
                    cls = bert(input_ids=patch.unsqueeze(0),
                               attention_mask=am.unsqueeze(0)).last_hidden_state[:,0,:]
                    kth = torch.topk(cos_row(cls, CP), topk_loss).values[-1]
                    loss= max(0., (kth - torch.nn.functional.cosine_similarity(
                                        cls, tgt_cls)[0]).item())
                    used_iter += 1
                    if loss < best_loss:
                        best_loss=loss; adv_cls=cls; success=(loss==0)
                    if success: break
                used_L=TAIL_L
            elif mtd=="greedy":
                pos=list(range(502, 502+TAIL_L)); patch=ids.clone()
                best_loss=1e9
                for _ in range(BUDGET):
                    used_iter += 1; improved=False
                    for p in pos:
                        best_id=patch[p].item()
                        for cand in random.sample(range(VOC),512):
                            patch[p]=cand
                            am = msk.clone(); am[pos]=1
                            cls=bert(input_ids=patch.unsqueeze(0),
                                     attention_mask=am.unsqueeze(0)).last_hidden_state[:,0,:]
                            kth=torch.topk(cos_row(cls,CP),topk_loss).values[-1]
                            loss=max(0.,(kth-torch.nn.functional.cosine_similarity(
                                              cls,tgt_cls)[0]).item())
                            if loss<best_loss:
                                best_loss=loss; best_id=cand; adv_cls=cls; improved=True
                            if loss==0: success=True; break
                        patch[p]=best_id
                        if success: break
                    if success or not improved: break
                used_L=TAIL_L
            elif mtd=="ggpp":
                adv_cls, used_L, used_iter = ggpp_full(
                    tok, bert, ids, msk, CP, tgt_cls, tgt_txt,
                    cap_L=TAIL_L, max_epoch=BUDGET, topk_loss=topk_loss)
                kth = torch.topk(cos_row(adv_cls, CP), topk_loss).values[-1]
                success = (kth <= torch.nn.functional.cosine_similarity(
                                      adv_cls, tgt_cls)[0])
            def de_run(L, max_iter, plateau):
                pos=list(range(502, 502+L)); bounds=[(0, VOC-1)]*L
                gens = max_iter//20 if max_iter else 1000
                stop = [0]; best=[1e9]
                def obj(v):
                    v=[int(round(x)) for x in v]
                    patch=ids.clone()
                    for p,t in zip(pos,v): patch[p]=t
                    am=msk.clone(); am[pos]=1
                    cls=bert(input_ids=patch.unsqueeze(0),
                             attention_mask=am.unsqueeze(0)).last_hidden_state[:,0,:]
                    kth=torch.topk(cos_row(cls,CP),topk_loss).values[-1]
                    sim=torch.nn.functional.cosine_similarity(cls,tgt_cls)[0]
                    return max(0.,(kth-sim).item())
                def cb(xk, convergence=None):  # accepts SciPy's positional or keyword call
                    if not plateau: return False
                    cur=obj(xk)
                    stop[0] = 0 if cur < best[0] else stop[0]+1
                    best[0] = min(best[0], cur)
                    return stop[0] >= PATIENCE or cur == 0
                res = differential_evolution(obj, bounds, popsize=20,
                                             maxiter=gens, tol=0,
                                             polish=False, seed=SEED,
                                             callback=cb)
                v=[int(round(x)) for x in res.x]
                patch=ids.clone()
                for p,t in zip(pos,v): patch[p]=t
                am=msk.clone(); am[pos]=1
                cls=bert(input_ids=patch.unsqueeze(0),
                         attention_mask=am.unsqueeze(0)).last_hidden_state[:,0,:]
                return cls, res.nfev, res.fun==0

            if mtd=="DE_fixed":
                adv_cls, used_iter, success = de_run(TAIL_L, BUDGET, plateau=False); used_L=TAIL_L
            if mtd=="DE_fixed_stop":
                adv_cls, used_iter, success = de_run(TAIL_L, None, plateau=True); used_L=TAIL_L
            if mtd=="DE_seq":
                for L in range(1, TAIL_L+1):
                    cls,iters,ok = de_run(L, BUDGET, plateau=False)
                    used_iter += iters
                    if ok: adv_cls=cls; used_L=L; success=True; break
                else: used_L=TAIL_L
            if mtd=="DE_seq_stop":
                for L in range(1, TAIL_L+1):
                    cls,iters,ok = de_run(L, None, plateau=True)
                    used_iter += iters
                    if ok: adv_cls=cls; used_L=L; success=True; break
                else: used_L=TAIL_L
            sims = torch.cat([
                cos_row(adv_cls.cpu(), CP),       # move to CPU
                torch.nn.functional.cosine_similarity(
                    adv_cls.cpu(), tgt_cls.cpu())
            ])
            rank_a = (sims > sims[-1]).sum().item() + 1
            d_mrr  = 1/rank_a - 1/rank_b
            d_ndcg = ((dcg(rank_a) if rank_a<=20 else 0) -
                      (dcg(rank_b) if rank_b<=20 else 0))
            d_cos  = (torch.nn.functional.cosine_similarity(
                        adv_cls, tgt_cls)[0] -
                      torch.nn.functional.cosine_similarity(
                        qcls, tgt_cls)[0]).item()

            rec.append(dict(
                top_k      = topk_loss,
                method     = mtd,
                success    = int(rank_a <= topk_loss),
                token_used = used_L,
                iter_used  = used_iter,
                delta_mrr  = d_mrr,
                delta_ndcg = d_ndcg,
                delta_cos  = d_cos
            ))
        torch.cuda.empty_cache()   # free cached GPU memory after each query

    pd.DataFrame(rec).to_csv(f"{save_dir}/msmarco/records.csv", index=False)
    print(f"✓ {save_dir}  rows = {len(rec)}")
for K in (1, 10, 20):
    run_all(topk_loss=K, save_dir=f"results_top{K}")
    torch.cuda.empty_cache()


CLS: 100%|██████████████████████████████████████████████████████████████████████████████| 63/63 [00:09<00:00,  6.38it/s]
Top-1: 100%|████████████████████████████████████████████████████████████████████████| 100/100 [1:31:07<00:00, 54.67s/it]


✓ results_top1  rows = 700


CLS: 100%|██████████████████████████████████████████████████████████████████████████████| 63/63 [00:12<00:00,  4.89it/s]
Top-10: 100%|███████████████████████████████████████████████████████████████████████| 100/100 [1:03:15<00:00, 37.95s/it]


✓ results_top10  rows = 700


CLS: 100%|██████████████████████████████████████████████████████████████████████████████| 63/63 [00:09<00:00,  6.66it/s]
Top-20: 100%|███████████████████████████████████████████████████████████████████████| 100/100 [1:09:40<00:00, 41.80s/it]

✓ results_top20  rows = 700
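
All attack variants in the cell above score a candidate suffix with the same hinge objective: the gap between the k-th highest corpus similarity and the target's similarity, clipped at zero. The following is a minimal pure-Python sketch of that objective on toy 2-D vectors. `cosine` and `topk_hinge` are illustrative names, not functions from the notebook, and the vectors are made up for the example.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def topk_hinge(query_vec, corpus_vecs, target_vec, k):
    """max(0, k-th best corpus similarity - target similarity)."""
    sims = sorted((cosine(query_vec, d) for d in corpus_vecs), reverse=True)
    return max(0.0, sims[k - 1] - cosine(query_vec, target_vec))

# Toy data (assumed for illustration only).
query = [1.0, 0.0]
corpus = [[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]  # competitors, target excluded
target = [1.0, 2.0]

for k in (1, 2, 3):
    print(k, round(topk_hinge(query, corpus, target, k), 4))
```

A loss of zero means that at most k-1 competing documents score strictly higher than the target, so the target lands inside the top k. That is the condition the `success` field in each record checks through `rank_a <= topk_loss`.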


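Because `de_run` passes `popsize=20` and `maxiter=gens` to `differential_evolution`, the number of objective calls is larger than `BUDGET` alone suggests. SciPy's documentation gives the upper bound (without polishing) as `(maxiter + 1) * popsize * N`, where `N` is the number of parameters. The sketch below applies that bound to the settings used here. `de_max_evals` is a hypothetical helper, and the plateau callback can end a run well before the bound is reached.

```python
def de_max_evals(n_tokens, max_iter, popsize=20):
    """Upper bound on objective calls for one de_run(n_tokens, max_iter, ...)."""
    gens = max_iter // 20 if max_iter else 1000  # same rule as in de_run
    return (gens + 1) * popsize * n_tokens

BUDGET, TAIL_L = 150, 5

# DE_fixed: a single run at the full suffix length.
fixed = de_max_evals(TAIL_L, BUDGET)

# DE_seq: worst case, every length from 1 to TAIL_L is tried in turn.
sequential = sum(de_max_evals(L, BUDGET) for L in range(1, TAIL_L + 1))

print(fixed, sequential)
```

Note that `iter_used` for the DE methods stores `res.nfev`, so it counts objective evaluations rather than generations, and is not directly comparable with the trial count of `random` or the epoch count of `ggpp`.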



In [3]:

import os, random, math, warnings, tqdm, numpy as np, pandas as pd, torch
from datasets import load_dataset
from transformers import BertTokenizer, BertModel
from scipy.optimize import differential_evolution

warnings.filterwarnings("ignore")
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

DEVICE     = "cuda" if torch.cuda.is_available() else "cpu"
N_DOCS     = 1_000
N_Q        = 100
TAIL_L     = 5
BUDGET     = 150
PATIENCE   = 20
BATCH_CLS  = 16
def cos_row(x, Y):
    return torch.nn.functional.cosine_similarity(
        x.expand_as(Y.to(x.device)), Y.to(x.device), dim=1)

def dcg(rank): 
    return 1 / math.log2(rank + 1)
def load_subset():
    corpus  = load_dataset("BeIR/fever",   "corpus",  split="corpus")
    queries = load_dataset("BeIR/fever",   "queries", split="queries")

    docs = random.sample(list(corpus),  N_DOCS)
    qs   = random.sample(list(queries), N_Q)

    return [d["text"] for d in docs], [q["text"] for q in qs]

def ggpp_full(tok, bert, ids, msk, CP, tgt_cls, tgt_txt,
              cap_L=5, max_epoch=150, topk_loss=20):

    body = tok(tgt_txt, add_special_tokens=False,
               truncation=True, max_length=510)["input_ids"]
    body += [tok.unk_token_id] * max(0, cap_L - len(body))

    base = bert(**tok(tgt_txt, return_tensors="pt",
                      truncation=True, max_length=512).to(DEVICE)
               ).last_hidden_state[0, 0]
    imp = []
    for i in range(cap_L):
        tmp = body.copy(); tmp[i] = tok.mask_token_id
        tens = torch.tensor([tok.cls_token_id] + tmp + [tok.sep_token_id]
                           ).unsqueeze(0).to(DEVICE)
        emb = bert(input_ids=tens,
                   attention_mask=torch.ones_like(tens)).last_hidden_state[0, 0]
        imp.append(1 - torch.nn.functional.cosine_similarity(base, emb, dim=0).item())

    prefix = [body[i] for i in np.argsort(imp)[-cap_L:]]

    pos   = list(range(502, 502 + cap_L))
    patch = ids.clone()
    for p, v in zip(pos, prefix):
        patch[p] = v
    am = msk.clone(); am[pos] = 1
    W  = bert.embeddings.word_embeddings.weight

    def loss_fn(cls_vec):
        kth = torch.topk(cos_row(cls_vec, CP), topk_loss).values[-1]
        sim = torch.nn.functional.cosine_similarity(cls_vec, tgt_cls)[0]
        return torch.relu(kth - sim)

    best_cls  = bert(input_ids=patch.unsqueeze(0),
                     attention_mask=am.unsqueeze(0)).last_hidden_state[:, 0, :]
    best_loss = loss_fn(best_cls).item()
    used_iter = 0

    for _ in range(max_epoch):
        used_iter += 1
        emb = bert.embeddings.word_embeddings(
            patch.unsqueeze(0)).detach().clone().requires_grad_(True)
        cls = bert(inputs_embeds=emb,
                   attention_mask=am.unsqueeze(0)).last_hidden_state[:, 0, :]
        loss = loss_fn(cls)
        loss.backward()
        if loss.item() == 0:
            best_cls = cls
            break
        grad  = emb.grad[0, pos]
        score = torch.matmul(grad, W.t())
        cand_ids = score.topk(5, largest=False, dim=1).indices.cpu()
        improved = False
        for mask in range(1, 1 << cap_L):
            cand = patch.clone()
            for i in range(cap_L):
                if mask & (1 << i):
                    cand[pos[i]] = random.choice(cand_ids[i]).item()
            cls2 = bert(input_ids=cand.unsqueeze(0),
                        attention_mask=am.unsqueeze(0)).last_hidden_state[:, 0, :]
            l2 = loss_fn(cls2).item()
            if l2 < best_loss:
                best_loss = l2
                patch     = cand
                best_cls  = cls2
                improved  = True
            if l2 == 0:
                break
        if not improved:
            break
    return best_cls, cap_L, used_iter
def run_all(topk_loss: int, save_dir: str):

    os.makedirs(f"{save_dir}/fever", exist_ok=True)

    tok  = BertTokenizer.from_pretrained("bert-base-uncased")
    bert = BertModel.from_pretrained("bert-base-uncased").eval().to(DEVICE)
    VOC  = tok.vocab_size
    enc  = lambda t: tok(t, padding="max_length", truncation=True,
                         max_length=512, return_tensors="pt").to(DEVICE)

    docs, qs = load_subset()
    CLS = []
    with torch.inference_mode():
        for i in tqdm.tqdm(range(0, len(docs), BATCH_CLS), desc="CLS"):
            bt  = enc(docs[i:i + BATCH_CLS])
            cls = bert(**bt).last_hidden_state[:, 0, :].cpu()
            CLS.append(cls)
            del bt
            torch.cuda.empty_cache()
    C_CLS = torch.cat(CLS)

    METHODS = (
        "none", "random", "ggpp",
        "DE_fixed", "DE_seq", "DE_fixed_stop", "DE_seq_stop"
    )

    rec = []
    for qtxt in tqdm.tqdm(qs, desc=f"Top-{topk_loss}"):
        tgt      = random.randrange(len(docs))
        tgt_txt  = docs[tgt]
        if len(tok(tgt_txt)["input_ids"]) > 510:
            continue
        tgt_cls  = C_CLS[tgt:tgt + 1].to(DEVICE)
        CP       = C_CLS[[i for i in range(len(docs)) if i != tgt]]

        qenc = enc(qtxt)
        ids, msk = qenc["input_ids"][0], qenc["attention_mask"][0]
        with torch.no_grad():
            qcls = bert(**qenc).last_hidden_state[:, 0, :]

        base = torch.cat([
            cos_row(qcls, CP).cpu(),
            torch.nn.functional.cosine_similarity(qcls.cpu(), tgt_cls.cpu())
        ])
        rank_b = (base > base[-1]).sum().item() + 1
        for mtd in METHODS:
            success   = False
            used_L    = 0
            used_iter = 0
            adv_cls   = qcls

            # ----- none -----
            if mtd == "none":
                used_L  = 0
                success = rank_b <= topk_loss

            # ----- random -----
            elif mtd == "random":
                best_loss = 1e9
                for _ in range(BUDGET):
                    patch = ids.clone()
                    for p in range(502, 502 + TAIL_L):
                        patch[p] = random.randrange(VOC)
                    am  = msk.clone(); am[502:502 + TAIL_L] = 1
                    cls = bert(input_ids=patch.unsqueeze(0),
                               attention_mask=am.unsqueeze(0)).last_hidden_state[:, 0, :]
                    kth  = torch.topk(cos_row(cls, CP), topk_loss).values[-1]
                    loss = max(0., (kth - torch.nn.functional.cosine_similarity(
                        cls, tgt_cls)[0]).item())
                    used_iter += 1
                    if loss < best_loss:
                        best_loss = loss
                        adv_cls   = cls
                        success   = (loss == 0)
                    if success:
                        break
                used_L = TAIL_L

            elif mtd == "ggpp":
                adv_cls, used_L, used_iter = ggpp_full(
                    tok, bert, ids, msk, CP, tgt_cls, tgt_txt,
                    cap_L=TAIL_L, max_epoch=BUDGET, topk_loss=topk_loss
                )
                kth     = torch.topk(cos_row(adv_cls, CP), topk_loss).values[-1]
                success = (kth <= torch.nn.functional.cosine_similarity(
                    adv_cls, tgt_cls)[0])
            def de_run(L, max_iter, plateau):
                pos    = list(range(502, 502 + L))
                bounds = [(0, VOC - 1)] * L
                gens   = max_iter // 20 if max_iter else 1000
                stop   = [0]; best = [1e9]

                def obj(v):
                    v = [int(round(x)) for x in v]
                    patch = ids.clone()
                    for p, t in zip(pos, v):
                        patch[p] = t
                    am  = msk.clone(); am[pos] = 1
                    cls = bert(input_ids=patch.unsqueeze(0),
                               attention_mask=am.unsqueeze(0)).last_hidden_state[:, 0, :]
                    kth = torch.topk(cos_row(cls, CP), topk_loss).values[-1]
                    sim = torch.nn.functional.cosine_similarity(cls, tgt_cls)[0]
                    return max(0., (kth - sim).item())

                def cb(xk, convergence=None):  # accepts SciPy's positional or keyword call
                    if not plateau:
                        return False
                    cur = obj(xk)
                    stop[0] = 0 if cur < best[0] else stop[0] + 1
                    best[0] = min(best[0], cur)
                    return stop[0] >= PATIENCE or cur == 0

                res = differential_evolution(
                    obj, bounds, popsize=20, maxiter=gens,
                    tol=0, polish=False, seed=SEED, callback=cb
                )

                v = [int(round(x)) for x in res.x]
                patch = ids.clone()
                for p, t in zip(pos, v):
                    patch[p] = t
                am  = msk.clone(); am[pos] = 1
                cls = bert(input_ids=patch.unsqueeze(0),
                           attention_mask=am.unsqueeze(0)).last_hidden_state[:, 0, :]
                return cls, res.nfev, res.fun == 0
            if mtd == "DE_fixed":
                adv_cls, used_iter, success = de_run(TAIL_L, BUDGET, plateau=False)
                used_L = TAIL_L
            elif mtd == "DE_fixed_stop":
                adv_cls, used_iter, success = de_run(TAIL_L, None, plateau=True)
                used_L = TAIL_L
            elif mtd == "DE_seq":
                for L in range(1, TAIL_L + 1):
                    cls, iters, ok = de_run(L, BUDGET, plateau=False)
                    used_iter += iters
                    if ok:
                        adv_cls = cls; used_L = L; success = True; break
                else:
                    used_L = TAIL_L
            elif mtd == "DE_seq_stop":
                for L in range(1, TAIL_L + 1):
                    cls, iters, ok = de_run(L, None, plateau=True)
                    used_iter += iters
                    if ok:
                        adv_cls = cls; used_L = L; success = True; break
                else:
                    used_L = TAIL_L
            sims = torch.cat([
                cos_row(adv_cls.cpu(), CP),
                torch.nn.functional.cosine_similarity(
                    adv_cls.cpu(), tgt_cls.cpu())
            ])
            rank_a = (sims > sims[-1]).sum().item() + 1
            d_mrr  = 1 / rank_a - 1 / rank_b
            d_ndcg = ((dcg(rank_a) if rank_a <= 20 else 0) -
                      (dcg(rank_b) if rank_b <= 20 else 0))
            d_cos  = (torch.nn.functional.cosine_similarity(
                adv_cls, tgt_cls)[0] -
                      torch.nn.functional.cosine_similarity(
                qcls, tgt_cls)[0]).item()

            rec.append(dict(
                top_k      = topk_loss,
                method     = mtd,
                success    = int(rank_a <= topk_loss),
                token_used = used_L,
                iter_used  = used_iter,
                delta_mrr  = d_mrr,
                delta_ndcg = d_ndcg,
                delta_cos  = d_cos
            ))
        torch.cuda.empty_cache()

    pd.DataFrame(rec).to_csv(f"{save_dir}/fever/records.csv", index=False)
    print(f"✓ {save_dir}  rows={len(rec)}")

if __name__ == "__main__":
    for K in (1, 10, 20):
        run_all(topk_loss=K, save_dir=f"results_top{K}")
        torch.cuda.empty_cache()


CLS: 100%|██████████████████████████████████████████████████████████████████████████████| 63/63 [00:15<00:00,  3.97it/s]
Token indices sequence length is longer than the specified maximum sequence length for this model (674 > 512). Running this sequence through the model will result in indexing errors
Top-1: 100%|███████████████████████████████████████████████████████████████████████| 100/100 [9:43:08<00:00, 349.89s/it]


✓ results_top1  rows=686


CLS: 100%|██████████████████████████████████████████████████████████████████████████████| 63/63 [00:16<00:00,  3.90it/s]
Token indices sequence length is longer than the specified maximum sequence length for this model (631 > 512). Running this sequence through the model will result in indexing errors
Top-10: 100%|██████████████████████████████████████████████████████████████████████| 100/100 [7:30:15<00:00, 270.15s/it]


✓ results_top10  rows=693


CLS: 100%|██████████████████████████████████████████████████████████████████████████████| 63/63 [00:15<00:00,  3.98it/s]
Token indices sequence length is longer than the specified maximum sequence length for this model (583 > 512). Running this sequence through the model will result in indexing errors
Top-20: 100%|██████████████████████████████████████████████████████████████████████| 100/100 [6:13:48<00:00, 224.28s/it]

✓ results_top20  rows=686
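
The row counts printed above fall below 700 because the query loop skips any query whose randomly chosen target document exceeds 510 tokens (as counted by `tok(tgt_txt)`), and each processed query contributes one row per method. A small sketch of that arithmetic follows; `skipped_queries` is an illustrative helper and not part of the notebook.

```python
N_Q = 100        # queries sampled per run
N_METHODS = 7    # entries in METHODS

def skipped_queries(rows, n_q=N_Q, n_methods=N_METHODS):
    """Infer how many queries were skipped from the number of CSV rows."""
    processed, remainder = divmod(rows, n_methods)
    if remainder:
        raise ValueError("row count is not a multiple of the number of methods")
    return n_q - processed

# Row counts reported for Top-1, Top-10, and Top-20 on FEVER.
for rows in (686, 693, 686):
    print(rows, skipped_queries(rows))
```

Since skipped queries are dropped for every method at once, the per-method comparison stays balanced, but success rates are computed over slightly fewer than 100 queries.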


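Each record stores how far the attack moved the target in the ranking. As a sketch of how `rank_a`, `delta_mrr`, and `delta_ndcg` relate, the pure-Python version below mirrors the tensor expressions in `run_all`. `rank_of_target` and `metric_deltas` are hypothetical names, and the similarity values are invented for illustration.

```python
import math

def rank_of_target(corpus_sims, target_sim):
    """1-based rank: one plus the number of documents scoring strictly higher."""
    return sum(s > target_sim for s in corpus_sims) + 1

def dcg(rank):
    """Discounted gain of a single relevant item at the given rank."""
    return 1 / math.log2(rank + 1)

def metric_deltas(rank_before, rank_after, cutoff=20):
    """Change in reciprocal rank and in DCG@cutoff for the target document."""
    d_mrr = 1 / rank_after - 1 / rank_before
    gain_after = dcg(rank_after) if rank_after <= cutoff else 0.0
    gain_before = dcg(rank_before) if rank_before <= cutoff else 0.0
    return d_mrr, gain_after - gain_before

# Toy similarities (assumed): the target starts behind two documents.
print(rank_of_target([0.9, 0.5, 0.3], 0.4))

# A target promoted from rank 50 to rank 1.
print(metric_deltas(50, 1))
```

With a single target document the ideal DCG is 1, so the DCG difference equals the nDCG difference. A target that starts outside the cutoff contributes zero before the attack.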

