In [2]:
# ============================================================
# Tail-Patch 8 方法 × Top-{1,10,20}
#  - 包含 plateau-stop 版本 (DE_fixed_stop, DE_seq_stop)
# ============================================================

!pip -q install datasets transformers scipy tqdm --upgrade


In [4]:
# ================================================================
#  Tail-Patch Attack Experiments
#  - 8 methods × {Top-1, Top-10, Top-20}
#  - corpus 1 000 docs , 100 queries  (BeIR/scifact)
#  - BERT-base-uncased encoder
# ================================================================


import os, random, math, tqdm, numpy as np, pandas as pd, torch, warnings
from datasets import load_dataset
from transformers import BertTokenizer, BertModel
from scipy.optimize import differential_evolution

warnings.filterwarnings("ignore")
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

DEVICE     = "cuda" if torch.cuda.is_available() else "cpu"
N_DOCS     = 1_000
N_Q        = 100
TAIL_L     = 5         # 固定補丁長度
BUDGET     = 150       # 固定預算 (= f-eval 或 epoch 次數)
PATIENCE   = 20        # plateau 20 代早停
BATCH_CLS  = 16        # 計算 CLS 時的批次，大幅影響顯存

# ---------------------------------------------------------------
def cos_row(x, Y):
    return torch.nn.functional.cosine_similarity(
        x.expand_as(Y.to(x.device)), Y.to(x.device), dim=1)

def dcg(rank): return 1 / math.log2(rank + 1)

def load_subset():
    corpus  = load_dataset("BeIR/scifact", "corpus",  split="corpus")
    queries = load_dataset("BeIR/scifact", "queries", split="queries")
    docs = random.sample(list(corpus),  N_DOCS)
    qs   = random.sample(list(queries), N_Q)
    return [d["text"] for d in docs], [q["text"] for q in qs]

# ------------------------- GGPP ---------------------------------
def ggpp_full(tok, bert, ids, msk, CP, tgt_cls, tgt_txt,
              cap_L=5, max_epoch=150, topk_loss=20):
    # --- Algorithm-1: token importance for prefix init ---
    body = tok(tgt_txt, add_special_tokens=False,
               truncation=True, max_length=510)["input_ids"]
    body += [tok.unk_token_id] * max(0, cap_L - len(body))
    base = bert(**tok(tgt_txt, return_tensors="pt",
                      truncation=True, max_length=512).to(DEVICE)
                ).last_hidden_state[0, 0]
    imp = []
    for i in range(cap_L):
        tmp = body.copy(); tmp[i] = tok.mask_token_id
        tens = torch.tensor([tok.cls_token_id]+tmp+[tok.sep_token_id]
                            ).unsqueeze(0).to(DEVICE)
        emb = bert(input_ids=tens,
                   attention_mask=torch.ones_like(tens)).last_hidden_state[0, 0]
        imp.append(1 - torch.nn.functional.cosine_similarity(base, emb, dim=0).item())
    prefix = [body[i] for i in np.argsort(imp)[-cap_L:]]

    # --- prepare patch & helper ---
    pos = list(range(502, 502+cap_L))
    patch = ids.clone()
    for p, v in zip(pos, prefix): patch[p] = v
    am = msk.clone(); am[pos] = 1
    W = bert.embeddings.word_embeddings.weight

    def loss_fn(cls_vec):
        kth = torch.topk(cos_row(cls_vec, CP), topk_loss).values[-1]
        sim = torch.nn.functional.cosine_similarity(cls_vec, tgt_cls)[0]
        return max(0., (kth - sim).item())

    best_cls = bert(input_ids=patch.unsqueeze(0),
                    attention_mask=am.unsqueeze(0)).last_hidden_state[:, 0, :]
    best_loss = loss_fn(best_cls); used_iter = 0

    # --- gradient-guided local search ---
    for _ in range(max_epoch):
        used_iter += 1
        emb = bert.embeddings.word_embeddings(
            patch.unsqueeze(0)).detach().clone().requires_grad_(True)
        cls = bert(inputs_embeds=emb,
                   attention_mask=am.unsqueeze(0)).last_hidden_state[:, 0, :]
        loss = loss_fn(cls); loss.backward()
        if loss == 0: best_cls = cls; break

        grad = emb.grad[0, pos]                     # (L, 768)
        score = torch.matmul(grad, W.t())           # (L, V)
        cand_ids = score.topk(5, largest=False, dim=1).indices.cpu()

        improved = False
        for mask in range(1, 1 << cap_L):           # 子集窮舉
            cand = patch.clone()
            for i in range(cap_L):
                if mask & (1 << i):
                    cand[pos[i]] = random.choice(cand_ids[i]).item()
            cls2 = bert(input_ids=cand.unsqueeze(0),
                        attention_mask=am.unsqueeze(0)).last_hidden_state[:, 0, :]
            l2 = loss_fn(cls2)
            if l2 < best_loss:
                best_loss = l2; patch = cand; best_cls = cls2; improved = True
            if l2 == 0: break
        if not improved: break
    return best_cls, cap_L, used_iter
# ---------------------------------------------------------------

# ---------------------- 主流程 -----------------------
def run_all(topk_loss, save_dir):
    os.makedirs(save_dir, exist_ok=True)

    tok = BertTokenizer.from_pretrained("bert-base-uncased")
    bert = BertModel.from_pretrained("bert-base-uncased").eval().to(DEVICE)
    VOC  = tok.vocab_size
    def enc(txt):
        return tok(txt, padding="max_length", truncation=True,
                   max_length=512, return_tensors="pt").to(DEVICE)

    # ------- prepare corpus CLS vectors (on CPU) -------
    docs, qs = load_subset()
    CLS = []
    with torch.inference_mode():
        for i in tqdm.tqdm(range(0, len(docs), BATCH_CLS), desc="CLS"):
            bt = enc(docs[i:i+BATCH_CLS])
            cls = bert(**bt).last_hidden_state[:, 0, :].cpu()
            CLS.append(cls)
            del bt; torch.cuda.empty_cache()
    C_CLS = torch.cat(CLS)

    METHODS = (
        "none", "random", "greedy", "ggpp",
        "DE_fixed", "DE_seq",
        "DE_fixed_stop", "DE_seq_stop"
    )
    records = []

    # ------- iterate queries -------
    for qtxt in tqdm.tqdm(qs, desc=f"Top-{topk_loss}"):
        tgt = random.randrange(len(docs))
        tgt_txt = docs[tgt]
        if len(tok(tgt_txt)["input_ids"]) > 510: continue
        tgt_cls = C_CLS[tgt:tgt+1].to(DEVICE)
        CP = C_CLS[[i for i in range(len(docs)) if i != tgt]]

        qenc = enc(qtxt); ids = qenc["input_ids"][0]; msk = qenc["attention_mask"][0]
        with torch.no_grad():
            qcls = bert(**qenc).last_hidden_state[:, 0, :]
        base = torch.cat([cos_row(qcls, CP),
                          torch.nn.functional.cosine_similarity(
                              qcls.cpu(), tgt_cls.cpu())])
        rank_b = (base > base[-1]).sum().item() + 1

        for mtd in METHODS:
            success = False; used_L = 0; used_iter = 0; adv_cls = qcls

            # === none ===
            if mtd == "none":
                pass

            # === random ===
            elif mtd == "random":
                best_loss = 1e9
                for _ in range(BUDGET):
                    patch = ids.clone()
                    for p in range(502, 502+TAIL_L):
                        patch[p] = random.randrange(VOC)
                    am = msk.clone(); am[502:502+TAIL_L] = 1
                    cls = bert(input_ids=patch.unsqueeze(0),
                               attention_mask=am.unsqueeze(0)
                               ).last_hidden_state[:, 0, :]
                    kth = torch.topk(cos_row(cls, CP), topk_loss).values[-1]
                    loss = max(0., (kth - torch.nn.functional.cosine_similarity(
                                         cls, tgt_cls)[0]).item())
                    used_iter += 1
                    if loss < best_loss:
                        best_loss = loss; adv_cls = cls; success = (loss == 0)
                    if success: break
                used_L = TAIL_L

            # === greedy (HotFlip-like) ===
            elif mtd == "greedy":
                pos = list(range(502, 502+TAIL_L)); patch = ids.clone()
                best_loss = 1e9
                for _ in range(BUDGET):
                    used_iter += 1; improved = False
                    for p in pos:
                        best_id = patch[p].item()
                        for cand in random.sample(range(VOC), 512):
                            patch[p] = cand
                            am = msk.clone(); am[pos] = 1
                            cls = bert(input_ids=patch.unsqueeze(0),
                                       attention_mask=am.unsqueeze(0)).last_hidden_state[:, 0, :]
                            kth = torch.topk(cos_row(cls, CP), topk_loss).values[-1]
                            loss = max(0., (kth - torch.nn.functional.cosine_similarity(
                                                 cls, tgt_cls)[0]).item())
                            if loss < best_loss:
                                best_loss = loss; best_id = cand; adv_cls = cls; improved = True
                            if loss == 0: success = True; break
                        patch[p] = best_id
                        if success: break
                    if success or not improved: break
                used_L = TAIL_L

            # === ggpp ===
            elif mtd == "ggpp":
                adv_cls, used_L, used_iter = ggpp_full(
                    tok, bert, ids, msk, CP, tgt_cls, tgt_txt,
                    cap_L=TAIL_L, max_epoch=BUDGET, topk_loss=topk_loss
                )
                kth = torch.topk(cos_row(adv_cls, CP), topk_loss).values[-1]
                success = (kth <= torch.nn.functional.cosine_similarity(
                                      adv_cls, tgt_cls)[0])

            # === 共用 DE 執行器 ===
            def de_run(L, max_iter, plateau):
                pos = list(range(502, 502+L)); bounds = [(0, VOC-1)] * L
                gens = max_iter // 20 if max_iter else 1000
                stop = [0]; best = [1e9]

                def obj(v):
                    v = [int(round(x)) for x in v]
                    patch = ids.clone()
                    for p, t in zip(pos, v): patch[p] = t
                    am = msk.clone(); am[pos] = 1
                    cls = bert(input_ids=patch.unsqueeze(0),
                               attention_mask=am.unsqueeze(0)
                               ).last_hidden_state[:, 0, :]
                    kth = torch.topk(cos_row(cls, CP), topk_loss).values[-1]
                    sim = torch.nn.functional.cosine_similarity(cls, tgt_cls)[0]
                    return max(0., (kth - sim).item())

                def cb(xk, _):
                    if not plateau: return False
                    cur = obj(xk)
                    stop[0] = 0 if cur < best[0] else stop[0] + 1
                    best[0] = min(best[0], cur)
                    return stop[0] >= PATIENCE or cur == 0

                res = differential_evolution(obj, bounds, popsize=20,
                                             maxiter=gens, tol=0,
                                             polish=False, seed=SEED,
                                             callback=cb)
                v = [int(round(x)) for x in res.x]
                patch = ids.clone()
                for p, t in zip(pos, v): patch[p] = t
                am = msk.clone(); am[pos] = 1
                cls = bert(input_ids=patch.unsqueeze(0),
                           attention_mask=am.unsqueeze(0)
                           ).last_hidden_state[:, 0, :]
                return cls, res.nfev, res.fun == 0

            # === DE_fixed ===
            if mtd == "DE_fixed":
                adv_cls, used_iter, success = de_run(TAIL_L, BUDGET, plateau=False)
                used_L = TAIL_L

            # === DE_fixed_stop ===
            if mtd == "DE_fixed_stop":
                adv_cls, used_iter, success = de_run(TAIL_L, None, plateau=True)
                used_L = TAIL_L

            # === DE_seq ===
            if mtd == "DE_seq":
                for L in range(1, TAIL_L+1):
                    cls, iters, ok = de_run(L, BUDGET, plateau=False)
                    used_iter += iters
                    if ok:
                        adv_cls = cls; used_L = L; success = True; break
                else: used_L = TAIL_L

            # === DE_seq_stop ===
            if mtd == "DE_seq_stop":
                for L in range(1, TAIL_L+1):
                    cls, iters, ok = de_run(L, None, plateau=True)
                    used_iter += iters
                    if ok:
                        adv_cls = cls; used_L = L; success = True; break
                else: used_L = TAIL_L

            # ========== 指標 ==========
            sims = torch.cat([cos_row(adv_cls.cpu(), CP),
                              torch.nn.functional.cosine_similarity(
                                  adv_cls.cpu(), tgt_cls.cpu())])
            rank_a = (sims > sims[-1]).sum().item() + 1
            d_mrr  = 1/rank_a - 1/rank_b
            d_ndcg = ((dcg(rank_a) if rank_a <= 20 else 0) -
                      (dcg(rank_b) if rank_b <= 20 else 0))
            d_cos  = (torch.nn.functional.cosine_similarity(
                        adv_cls, tgt_cls)[0] -
                      torch.nn.functional.cosine_similarity(
                        qcls, tgt_cls)[0]).item()

            records.append(dict(
                top_k       = topk_loss,
                method      = mtd,
                success     = int(rank_a <= topk_loss),
                token_used  = used_L,
                iter_used   = used_iter,
                delta_mrr   = d_mrr,
                delta_ndcg  = d_ndcg,
                delta_cos   = d_cos
            ))
        torch.cuda.empty_cache()   # 每 query 釋放暫存

    pd.DataFrame(records).to_csv(f"{save_dir}/records.csv", index=False)
    print(f"✓ {save_dir}  rows = {len(records)}")

# ------------------ 執行 3 組 k -------------------
for K in (1, 10, 20):
    run_all(topk_loss=K, save_dir=f"results_top{K}")
    torch.cuda.empty_cache()          # 下一輪前清顯存


CLS: 100%|██████████████████████████████████████████████████████████████████████████████| 63/63 [00:28<00:00,  2.23it/s]
Top-1:   0%|                                                                                    | 0/100 [00:00<?, ?it/s]


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument tensors in method wrapper_CUDA_cat)