**Laddu's doing**

In [None]:
# IADE final stable — train + eval (Colab-ready)
# Uncomment and run once if required:
# !pip install -q sentence-transformers faiss-cpu tqdm scikit-learn

import os, json, random, math, sys
from tqdm import tqdm
from typing import List, Dict
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from torch.optim import AdamW
from sentence_transformers import util

# -------------------------- Config --------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

BACKBONE = "sentence-transformers/all-MiniLM-L6-v2"  # change if you have larger GPU
PROJ_DIM = 384
BATCH_SIZE = 12
EPOCHS = 3
LR = 2e-5
WEIGHT_DECAY = 0.01
TEMPERATURE = 0.05
MAX_LENGTH = 256
MAX_HARD_NEGS = 2
DEV_FRAC = 0.1
SAVE_PATH = "iade_best_model.pt"
PRINT_EVERY = 200

DATA_PATH = "query-doc.json"
QUERIES_PATH = "final_sorted.jsonl"

# -------------------------- Data helpers ---------------------------
def load_qdoc(path: str):
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    examples = []
    for e in data:
        qid = e.get("query_id")
        docs = e.get("documents", [])
        positives = [d for d in docs if d.get("type", "").lower() == "positive"]
        hard_negs = [d for d in docs if d.get("type", "").lower() == "hard_negative"]
        other_negs = [d for d in docs if d.get("type", "").lower() not in ("positive", "hard_negative")]
        if len(positives) == 0:
            continue
        examples.append({"query_id": qid, "positives": positives, "hard_negs": hard_negs, "other_negs": other_negs})
    return examples

def load_queries(path: str):
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f]

def build_qdoc_map(qdoc_list):
    d = {}
    for e in qdoc_list:
        qid = e.get("query_id")
        positives = [{"text": doc.get("text",""), "doc_id": doc.get("doc_id")} for doc in e.get("documents", []) if doc.get("type","").lower()=="positive"]
        hard_negs = [{"text": doc.get("text",""), "doc_id": doc.get("doc_id")} for doc in e.get("documents", []) if doc.get("type","").lower()=="hard_negative"]
        other_negs = [{"text": doc.get("text",""), "doc_id": doc.get("doc_id")} for doc in e.get("documents", []) if doc.get("type","").lower() not in ("positive","hard_negative")]
        d[qid] = {"positives": positives, "hard_negs": hard_negs, "other_negs": other_negs}
    return d

# -------------------------- Dataset --------------------------------
class IADataset(Dataset):
    def __init__(self, queries_records: List[Dict], qdoc_map: Dict, max_hard_negs: int = 2):
        self.queries = queries_records
        self.qdoc = qdoc_map
        self.max_hard_negs = max_hard_negs

    def __len__(self): return len(self.queries)

    def _extract_positive_text(self, r, qid):
        pos_field = r.get("positive_doc", "")
        pos_text = ""
        if isinstance(pos_field, dict):
            pos_text = pos_field.get("text","")
        elif isinstance(pos_field, list) and len(pos_field)>0:
            first = pos_field[0]
            if isinstance(first, dict):
                pos_text = first.get("text","")
            else:
                pos_text = str(first)
        elif isinstance(pos_field, str):
            pos_text = pos_field
        pos_text = pos_text.strip()
        if not pos_text:
            candidates = self.qdoc.get(qid, {}).get("positives", [])
            if candidates:
                pos_text = candidates[0].get("text","").strip()
        return pos_text

    def __getitem__(self, idx):
        r = self.queries[idx]
        qid = r["query_id"]
        q_text = r.get("query","")
        q_ins = r.get("instructed_query", q_text)
        q_rev = r.get("reversed_query", q_text + " not")
        pos_text = self._extract_positive_text(r, qid)

        # hard negatives:
        hns = []
        for d in self.qdoc.get(qid, {}).get("hard_negs", []):
            if isinstance(d, dict) and d.get("text"):
                hns.append(d["text"])
            elif isinstance(d, str) and d.strip():
                hns.append(d.strip())
        hns = hns[:self.max_hard_negs]

        return {
            "query_id": qid,
            "query": q_text,
            "instructed_query": q_ins,
            "reversed_query": q_rev,
            "positive": pos_text,
            "hard_negs": hns,
            "attributes": r.get("attributes", {})
        }

# -------------------------- Model ----------------------------------
class IADE(nn.Module):
    def __init__(self, backbone_name=BACKBONE, proj_dim=PROJ_DIM, device=DEVICE):
        super().__init__()
        self.device = device
        self.backbone = AutoModel.from_pretrained(backbone_name)
        self.tokenizer = AutoTokenizer.from_pretrained(backbone_name)
        hidden = self.backbone.config.hidden_size
        self.proj = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, proj_dim))
        self.dropout = nn.Dropout(0.1)
        # expose output dim for safe empty-tensor creation:
        self.out_dim = proj_dim

    def forward_encode(self, texts: List[str], max_length=MAX_LENGTH):
        # return tensor on self.device, shape (N, out_dim)
        if not texts:
            return torch.zeros((0, self.out_dim), device=self.device)
        tok = self.tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
        tok = {k: v.to(self.device) for k,v in tok.items()}
        out = self.backbone(**tok, return_dict=True)
        mask = tok["attention_mask"].unsqueeze(-1).float()
        mean_pooled = (out.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        proj = self.proj(self.dropout(mean_pooled))
        return F.normalize(proj, dim=-1)

    def encode(self, texts: List[str], batch_size=32):
        # batched; returns tensor on device
        if not texts:
            return torch.zeros((0, self.out_dim), device=self.device)
        parts = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i+batch_size]
            if not batch:
                continue
            with torch.no_grad():
                emb = self.forward_encode(batch)
            parts.append(emb)
        return torch.cat(parts, dim=0).to(self.device)

# -------------------------- Loss + Collate -------------------------
def info_nce_loss(q_emb, p_emb, temperature=TEMPERATURE):
    sims = torch.matmul(q_emb, p_emb.t()) / temperature
    labels = torch.arange(q_emb.size(0), device=sims.device, dtype=torch.long)
    return F.cross_entropy(sims, labels)

def collate_train(samples: List[Dict]):
    B = len(samples)
    queries_orig = [s["query"] for s in samples]
    queries_ins = [s["instructed_query"] for s in samples]
    queries_rev = [s["reversed_query"] for s in samples]
    positives = [s["positive"] if s["positive"] else " " for s in samples]
    hard_negs = [hn for s in samples for hn in s["hard_negs"]]
    passages = positives + hard_negs
    return {"B": B, "queries_orig": queries_orig, "queries_ins": queries_ins, "queries_rev": queries_rev, "passages": passages}

# -------------------------- Metrics (same formulas) ------------------------------
def find_rank_in_ranking(docid, ranking):
    for i, (d, _) in enumerate(ranking):
        if d == docid:
            return i + 1
    return len(ranking) + 1

def compute_mSICR(Rori_rank, Rins_rank, Rrev_rank, Sori, Sins, Srev):
    return int((Rins_rank < Rori_rank and Sins > Sori) and (Rrev_rank > Rori_rank and Srev < Sori))

def compute_mWISE(Rori_rank, Rins_rank, Rrev_rank, m, sat_count, vio_count, K=10):
    delta_ins = Rori_rank - Rins_rank
    delta_rev = Rrev_rank - Rori_rank
    reward = (sat_count / m) * (1 - math.sqrt(max(delta_ins, 0) / K)) * (1.0 / math.sqrt(max(Rins_rank, 1)))
    penalty = - (vio_count / m)
    return reward + penalty

def compute_MDCR(attrs: Dict, pos_doc_text: str, model_for_mdcr: IADE):
    if not attrs:
        return 0.0, 0
    pos_emb = model_for_mdcr.encode([pos_doc_text])  # on device
    sim_scores = []
    for attr_name, attr_value in attrs.items():
        attr_desc = f"The document should reflect {attr_name} = {attr_value}."
        attr_emb = model_for_mdcr.encode([attr_desc])
        sim = float(util.cos_sim(attr_emb, pos_emb)[0,0])
        sim_scores.append(sim)
    mdcr_soft = float(np.mean(sim_scores))
    threshold = max(0.45, mdcr_soft - 0.05)
    mdcr_strict = int(all(s >= threshold for s in sim_scores))
    return mdcr_soft, mdcr_strict

# -------------------------- Evaluation (robust) -------------------------------
def evaluate_model_on_queries(model: IADE, queries_list: List[Dict], qdoc_map: Dict, top_k=10, debug_sample_n=3):
    model.eval()

    # build corpus (unique) and mapping
    corpus_texts = []
    doc_ids = []
    qid_to_text2docid = {}
    for qid, groups in qdoc_map.items():
        qid_to_text2docid[qid] = {}
        for doc in groups["positives"] + groups["hard_negs"] + groups["other_negs"]:
            text = doc.get("text","").strip()
            if not text: continue
            key = " ".join(text.lower().split())
            if key in qid_to_text2docid[qid]:
                continue
            docid = f"{qid}_{doc.get('doc_id', len(doc_ids))}"
            qid_to_text2docid[qid][key] = docid
            doc_ids.append(docid)
            corpus_texts.append(key)

    if not corpus_texts:
        return {"mSICR":0,"mWISE":0,"MDCR_soft":0,"MDCR_strict":0}

    corpus_embeddings = model.encode(corpus_texts, batch_size=64)  # on device
    results = []
    skipped = 0

    # sample debug queries to display retrievals
    debug_samples = random.sample(queries_list, min(debug_sample_n, len(queries_list)))

    for q in tqdm(queries_list, desc="Evaluating"):
        try:
            qid = q["query_id"]
            pos_field = q.get("positive_doc", "")
            if isinstance(pos_field, dict):
                pos_text_raw = pos_field.get("text","")
            elif isinstance(pos_field, list) and len(pos_field)>0:
                first = pos_field[0]
                pos_text_raw = first.get("text","") if isinstance(first, dict) else str(first)
            else:
                pos_text_raw = str(pos_field) if pos_field is not None else ""

            pos_text_key = " ".join(str(pos_text_raw).strip().lower().split())
            if not pos_text_key:
                # try qdoc map first positive
                cand = qdoc_map.get(qid, {}).get("positives", [])
                if cand:
                    pos_text_key = " ".join(cand[0].get("text","").strip().lower().split())

            if not pos_text_key:
                skipped += 1
                continue

            # find doc id: exact normalized match -> substring match -> semantic fallback
            pos_doc = qid_to_text2docid.get(qid, {}).get(pos_text_key)
            if not pos_doc:
                # substring match
                for k, did in qid_to_text2docid.get(qid, {}).items():
                    if pos_text_key in k or k in pos_text_key:
                        pos_doc = did
                        break

            if not pos_doc:
                # semantic fallback: find best corpus doc overall
                pos_emb = model.encode([pos_text_key])  # device
                sims = util.cos_sim(pos_emb, corpus_embeddings)[0]
                best_idx = int(torch.argmax(sims))
                pos_doc = doc_ids[best_idx]

            # compute embeddings for queries
            qori_emb = model.encode([q.get("query","")])
            qins_emb = model.encode([q.get("instructed_query", q.get("query",""))])
            qrev_emb = model.encode([q.get("reversed_query", q.get("query",""))])

            sims_ori = util.cos_sim(qori_emb, corpus_embeddings)[0]
            sims_ins = util.cos_sim(qins_emb, corpus_embeddings)[0]
            sims_rev = util.cos_sim(qrev_emb, corpus_embeddings)[0]

            topk_ori = torch.topk(sims_ori, k=min(top_k, sims_ori.shape[0]))
            topk_ins = torch.topk(sims_ins, k=min(top_k, sims_ins.shape[0]))
            topk_rev = torch.topk(sims_rev, k=min(top_k, sims_rev.shape[0]))

            Rori = [(doc_ids[i], float(sims_ori[i])) for i in topk_ori.indices]
            Rins = [(doc_ids[i], float(sims_ins[i])) for i in topk_ins.indices]
            Rrev = [(doc_ids[i], float(sims_rev[i])) for i in topk_rev.indices]

            Rori_rank = find_rank_in_ranking(pos_doc, Rori)
            Rins_rank = find_rank_in_ranking(pos_doc, Rins)
            Rrev_rank = find_rank_in_ranking(pos_doc, Rrev)

            pos_emb = model.encode([pos_text_key])
            Sori = float(util.cos_sim(qori_emb, pos_emb)[0,0])
            Sins = float(util.cos_sim(qins_emb, pos_emb)[0,0])
            Srev = float(util.cos_sim(qrev_emb, pos_emb)[0,0])

            attrs = q.get("attributes", {})
            m = len(attrs) if attrs else 1
            mdcr_soft, mdcr_strict = compute_MDCR(attrs, pos_text_key, model)
            sat_count = int(round(mdcr_soft * m))
            vio_count = max(0, m - sat_count)

            msicr = compute_mSICR(Rori_rank, Rins_rank, Rrev_rank, Sori, Sins, Srev)
            mwise = compute_mWISE(Rori_rank, Rins_rank, Rrev_rank, m, sat_count, vio_count)

            results.append({"mSICR": msicr, "mWISE": mwise, "MDCR_soft": mdcr_soft, "MDCR_strict": mdcr_strict})

        except Exception as e:
            skipped += 1
            # continue quietly but note
            # print("Eval skip", q.get("query_id"), e)
            continue

    if not results:
        return {"mSICR": 0, "mWISE": 0, "MDCR_soft": 0, "MDCR_strict": 0}

    df = pd.DataFrame(results)
    aggregated = {"mSICR": float(df["mSICR"].mean()), "mWISE": float(df["mWISE"].mean()),
                  "MDCR_soft": float(df["MDCR_soft"].mean()), "MDCR_strict": float(df["MDCR_strict"].mean())}
    if skipped:
        print(f"Evaluation: skipped {skipped} queries due to missing/invalid positives.")
    # show a few debug retrievals
    print("Sample retrievals (debug):")
    for s in random.sample(queries_list, min(3, len(queries_list))):
        qtxt = s.get("query","")[:120]
        q_emb = model.encode([s.get("query","")])
        sims = util.cos_sim(q_emb, corpus_embeddings)[0]
        top = torch.topk(sims, k=3)
        print(" Q:", qtxt)
        for idx in top.indices.tolist():
            print("   ", corpus_texts[idx][:140])
    return aggregated

# -------------------------- Training + run -----------------------------------------
def run_training_and_eval():
    qdoc = load_qdoc(DATA_PATH)
    queries = load_queries(QUERIES_PATH)
    print("Loaded:", len(qdoc), "qdoc entries and", len(queries), "queries")
    qdoc_map = build_qdoc_map(qdoc)
    random.shuffle(queries)
    n_dev = max(1, int(len(queries) * DEV_FRAC))
    dev_qs, train_qs = queries[:n_dev], queries[n_dev:]
    print("Train:", len(train_qs), "Dev:", len(dev_qs))

    train_ds = IADataset(train_qs, qdoc_map, MAX_HARD_NEGS)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=lambda x: x)

    model = IADE().to(DEVICE)
    optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    best_msicr = -1.0
    for ep in range(EPOCHS):
        model.train()
        running_loss = 0.0
        for step, batch in enumerate(tqdm(train_loader, desc=f"Epoch {ep+1}/{EPOCHS}")):
            coll = collate_train(batch)
            q_ori = model.forward_encode(coll["queries_orig"])
            q_ins = model.forward_encode(coll["queries_ins"])
            q_rev = model.forward_encode(coll["queries_rev"])
            p_emb = model.forward_encode(coll["passages"])

            # Ensure p_emb has at least B rows (if no hns, still positives >= B)
            if p_emb.size(0) < q_ori.size(0):
                # pad with tiny noise vectors to avoid shape mismatch (rare)
                pad = torch.randn((q_ori.size(0)-p_emb.size(0), p_emb.size(1)), device=p_emb.device) * 1e-6
                p_emb = torch.cat([p_emb, pad], dim=0)

            loss = info_nce_loss(q_ori, p_emb) + info_nce_loss(q_ins, p_emb) - 0.3 * info_nce_loss(q_rev, p_emb)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            running_loss += float(loss.item())
            if (step + 1) % PRINT_EVERY == 0:
                print(f"Epoch {ep+1} step {step+1} avg_loss {running_loss/(step+1):.4f}")

        metrics = evaluate_model_on_queries(model, dev_qs, qdoc_map)
        print(f"Epoch {ep+1} metrics:", metrics)
        if metrics["mSICR"] > best_msicr:
            best_msicr = metrics["mSICR"]
            torch.save(model.state_dict(), SAVE_PATH)
            print("Saved best model:", SAVE_PATH)

    # final eval
    model.load_state_dict(torch.load(SAVE_PATH, map_location=DEVICE))
    final_metrics = evaluate_model_on_queries(model, queries, qdoc_map)
    print("Final metrics:", final_metrics)
    return model, final_metrics

# -------------------------- Run -----------------------------------------
if __name__ == "__main__":
    model, metrics = run_training_and_eval()
    print(json.dumps({"model": "IADE_finetuned", **metrics}, indent=2))


Loaded: 515 qdoc entries and 9596 queries
Train: 8637 Dev: 959


Epoch 1/3:  28%|██▊       | 201/720 [00:30<01:20,  6.48it/s]

Epoch 1 step 200 avg_loss -2.1241


Epoch 1/3:  56%|█████▌    | 401/720 [00:59<00:46,  6.83it/s]

Epoch 1 step 400 avg_loss -4.7169


Epoch 1/3:  83%|████████▎ | 601/720 [01:28<00:15,  7.68it/s]

Epoch 1 step 600 avg_loss -5.7139


Epoch 1/3: 100%|██████████| 720/720 [01:45<00:00,  6.84it/s]


Epoch 1 metrics: {'mSICR': 0, 'mWISE': 0, 'MDCR_soft': 0, 'MDCR_strict': 0}
Saved best model: iade_best_model.pt


Epoch 2/3:  28%|██▊       | 201/720 [00:28<01:15,  6.90it/s]

Epoch 2 step 200 avg_loss -7.5195


Epoch 2/3:  56%|█████▌    | 401/720 [00:57<00:49,  6.47it/s]

Epoch 2 step 400 avg_loss -7.4396


Epoch 2/3:  83%|████████▎ | 601/720 [01:26<00:17,  6.91it/s]

Epoch 2 step 600 avg_loss -7.3020


Epoch 2/3: 100%|██████████| 720/720 [01:42<00:00,  7.00it/s]


Epoch 2 metrics: {'mSICR': 0, 'mWISE': 0, 'MDCR_soft': 0, 'MDCR_strict': 0}


Epoch 3/3:  28%|██▊       | 201/720 [00:28<01:21,  6.33it/s]

Epoch 3 step 200 avg_loss -7.5716


Epoch 3/3:  56%|█████▌    | 400/720 [00:57<00:42,  7.46it/s]

Epoch 3 step 400 avg_loss -7.1250


Epoch 3/3:  83%|████████▎ | 601/720 [01:25<00:17,  6.86it/s]

Epoch 3 step 600 avg_loss -7.4320


Epoch 3/3: 100%|██████████| 720/720 [01:42<00:00,  7.01it/s]


Epoch 3 metrics: {'mSICR': 0, 'mWISE': 0, 'MDCR_soft': 0, 'MDCR_strict': 0}
Final metrics: {'mSICR': 0, 'mWISE': 0, 'MDCR_soft': 0, 'MDCR_strict': 0}
{
  "model": "IADE_finetuned",
  "mSICR": 0,
  "mWISE": 0,
  "MDCR_soft": 0,
  "MDCR_strict": 0
}


#Van doing

In [None]:
!pip -q install transformers accelerate peft sentencepiece

import os, json, random, math, time
from dataclasses import dataclass
from typing import List, Dict, Any
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from contextlib import nullcontext

from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup

try:
    from peft import LoraConfig, get_peft_model, TaskType
    PEFT_AVAILABLE = True
except Exception:
    PEFT_AVAILABLE = False

# ---------------- Config ----------------
SEED = 1337
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

FINAL_JSONL = "final_sorted.jsonl"
QDOC_JSON   = "query-doc.json"

BASE_MODEL   = "microsoft/deberta-v3-small"
MAX_LEN      = 256
USE_LORA     = True and PEFT_AVAILABLE
LR_MAIN      = 3e-4 if USE_LORA else 2e-5
BATCH_SIZE   = 12 if torch.cuda.is_available() else 4
GRAD_ACCUM   = 1
EPOCHS       = 2
WARMUP_RATIO = 0.06

MARGIN       = 0.2
W_REVERSE    = 1.8
LAMBDA_ATTR  = 0.35

MAX_TRAIN_PAIRS = 12000
MAX_VAL_PAIRS   = 2000

OUTPUT_DIR   = "reranker-fast-fix"
CKPT_EVERY   = 800
RESUME       = True
ATTR_KEYS    = ["audience","format","language","length","source","keyword"]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------------- IO helpers ----------------
def load_jsonl(path):
    rows=[]
    with open(path,"r",encoding="utf-8") as f:
        for i, line in enumerate(f,1):
            line=line.strip()
            if not line: continue
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError:
                print(f"Skipping malformed line {i} in {path}")
    return rows

with open(QDOC_JSON,"r",encoding="utf-8") as f:
    qdoc = json.load(f)
entries = load_jsonl(FINAL_JSONL)
print(f"Loaded entries: {len(entries)} | query-doc groups: {len(qdoc)}")

qid_to_docs = {e["query_id"]: [d.get("text","") for d in e.get("documents",[])] for e in qdoc}

# ---------------- Normalization ----------------
def get_doc_text(d):
    if isinstance(d, dict): return str(d.get("text","")).strip()
    if isinstance(d, list) and len(d)>0:
        first=d[0]
        return (first.get("text","").strip() if isinstance(first,dict) else str(first).strip())
    if isinstance(d, str): return d.strip()
    return ""

def weak_attr_vec(text:str, attrs:Dict[str,Any]) -> Dict[str,int]:
    lt = (text or "").lower()
    return {k: int(str(v).lower() in lt) for k,v in (attrs or {}).items()}

def build_pairs(records:List[Dict[str,Any]]):
    pairs=[]
    for e in records:
        orig = e.get("query",""); ins = e.get("instructed_query",""); rev = e.get("reversed_query","")
        pos = get_doc_text(e.get("positive_doc")); neg = get_doc_text(e.get("hard_negative_doc"))
        attrs = e.get("attributes",{}) or {}
        if not (orig and ins and rev and pos and neg): continue
        pos_attr = weak_attr_vec(pos, attrs); neg_attr = weak_attr_vec(neg, attrs)
        pairs.append(dict(query=ins,  pos=pos, neg=neg, mode="instructed", pos_attr=pos_attr))
        pairs.append(dict(query=orig, pos=pos, neg=neg, mode="original",   pos_attr=pos_attr))
        pairs.append(dict(query=rev,  pos=neg, neg=pos, mode="reverse",    pos_attr=neg_attr))
    random.shuffle(pairs)
    return pairs

pairs = build_pairs(entries)
print("Total trainable pairs (raw):", len(pairs))

if MAX_TRAIN_PAIRS:
    pairs = pairs[:MAX_TRAIN_PAIRS + MAX_VAL_PAIRS]
split = int(0.85 * len(pairs))
train_pairs, val_pairs = pairs[:split], pairs[split:]
if MAX_VAL_PAIRS: val_pairs = val_pairs[:MAX_VAL_PAIRS]
print(f"Using: train={len(train_pairs)} | val={len(val_pairs)}")

# ---------------- Dataset & Collator ----------------
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)

@dataclass
class Item:
    query:str; pos:str; neg:str; mode:str; pos_attr:Dict[str,int]

class PairwiseDataset(Dataset):
    def __init__(self, items, max_len=256):
        self.items=[Item(**x) for x in items]; self.max_len=max_len
    def __len__(self): return len(self.items)
    def __getitem__(self, i): return self.items[i]

def pack(query, doc, max_len=256):
    text = f"[QUERY]\n{query}\n[DOCUMENT]\n{doc}"
    return tokenizer(text, max_length=max_len, truncation=True, padding=False, return_tensors="pt")

class Collator:
    def __init__(self, max_len=256): self.max_len=max_len
    def __call__(self, batch:List[Item]):
        qpos = [pack(b.query, b.pos, self.max_len) for b in batch]
        qneg = [pack(b.query, b.neg, self.max_len) for b in batch]
        def pad(key, seqs):
            return torch.nn.utils.rnn.pad_sequence(
                [s[key].squeeze(0) for s in seqs],
                batch_first=True, padding_value=tokenizer.pad_token_id
            )
        pos = {"input_ids": pad("input_ids", qpos), "attention_mask": pad("attention_mask", qpos)}
        neg = {"input_ids": pad("input_ids", qneg), "attention_mask": pad("attention_mask", qneg)}
        modes=[b.mode for b in batch]
        attr_t = torch.zeros(len(batch), len(ATTR_KEYS), dtype=torch.float32)
        attr_m = torch.zeros_like(attr_t)
        for i,b in enumerate(batch):
            for j,k in enumerate(ATTR_KEYS):
                if k in b.pos_attr:
                    attr_m[i,j]=1.0
                    attr_t[i,j]=float(b.pos_attr[k])
        return pos, neg, modes, attr_t, attr_m

train_ds = PairwiseDataset(train_pairs, MAX_LEN)
val_ds   = PairwiseDataset(val_pairs,   MAX_LEN)
collate  = Collator(MAX_LEN)

# ---------------- Model ----------------
class CrossEncoderWithAttr(nn.Module):
    def __init__(self, base=BASE_MODEL, dropout=0.1, grad_ckpt=True):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base)
        if grad_ckpt and hasattr(self.encoder, "gradient_checkpointing_enable"):
            self.encoder.gradient_checkpointing_enable()
        # VERY IMPORTANT: allow gradient flow from inputs for checkpointing
        if hasattr(self.encoder, "enable_input_require_grads"):
            self.encoder.enable_input_require_grads()
        self.config  = self.encoder.config
        h = self.encoder.config.hidden_size
        self.dp = nn.Dropout(dropout)
        self.rank = nn.Linear(h,1)
        self.attr = nn.Linear(h,len(ATTR_KEYS))
    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None, **kwargs):
        if "labels" in kwargs: kwargs.pop("labels")
        out = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask,
            token_type_ids=token_type_ids, **kwargs
        )
        cls = self.dp(out.last_hidden_state[:,0,:])
        return self.rank(cls).squeeze(-1), self.attr(cls)

model = CrossEncoderWithAttr(BASE_MODEL, grad_ckpt=True)

# Attach LoRA; ensure inputs require grads on wrapper too
if USE_LORA:
    try:
        lconf = LoraConfig(
            task_type=TaskType.FEATURE_EXTRACTION,
            r=16, lora_alpha=32, lora_dropout=0.05,
            target_modules=["query_proj","key_proj","value_proj","dense","out_proj"]
        )
        model = get_peft_model(model, lconf)
        # For PEFT, also enable input grads on the wrapper (important with ckpt)
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        print("LoRA attached (FEATURE_EXTRACTION).")
    except Exception as e:
        print(f"LoRA attach failed; continuing without LoRA: {e}")
        USE_LORA=False

model = model.to(device)

# ---- Freeze everything except LoRA adapters + heads ----
def mark_trainable_modules(m):
    trainable, total = 0, 0
    for n,p in m.named_parameters():
        total += p.numel()
        # Heads must be trainable; LoRA adapters have "lora_" in their param names
        if ("rank" in n) or ("attr" in n) or ("lora_" in n):
            p.requires_grad_(True); trainable += p.numel()
        else:
            p.requires_grad_(False)
    return trainable, total

trainable, total = mark_trainable_modules(model)
print(f"Trainable params: {trainable/1e6:.2f}M / {total/1e6:.2f}M")

# ---------------- Loss, Optim, Sched ----------------
def bce_with_mask(logits, targets, mask):
    bce = nn.BCEWithLogitsLoss(reduction="none")(logits, targets)
    denom = torch.clamp(mask.sum(), min=1.0)
    return (bce * mask).sum() / denom

def batch_loss(batch):
    pos,neg,modes,attr_t,attr_m = batch
    pos = {k:v.to(device) for k,v in pos.items()}
    neg = {k:v.to(device) for k,v in neg.items()}
    attr_t = attr_t.to(device); attr_m = attr_m.to(device)
    s_pos, a_pos = model(**pos)
    s_neg, _     = model(**neg)
    diff = s_pos - s_neg
    hinge = torch.clamp(MARGIN - diff, min=0.0)
    w = torch.tensor([W_REVERSE if m=="reverse" else 1.0 for m in modes], device=device)
    rank_loss = (hinge * w).mean()
    attr_loss = bce_with_mask(a_pos, attr_t, attr_m)
    loss = rank_loss + LAMBDA_ATTR * attr_loss
    return loss, {"rank_loss": rank_loss.item(), "attr_loss": attr_loss.item()}

# Optim only over trainable params
opt = AdamW([p for p in model.parameters() if p.requires_grad], lr=LR_MAIN, weight_decay=0.01)

train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate,
    pin_memory=True, num_workers=2, persistent_workers=True
)
val_loader = DataLoader(
    val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate,
    pin_memory=True, num_workers=1, persistent_workers=True
)

total_steps = max(1, (len(train_loader) * EPOCHS) // GRAD_ACCUM)
warm_steps  = int(total_steps * WARMUP_RATIO)
sched = get_linear_schedule_with_warmup(opt, num_warmup_steps=warm_steps, num_training_steps=total_steps)
scaler = torch.amp.GradScaler(device="cuda") if torch.cuda.is_available() else None

# ---------------- Checkpoint utils ----------------
def save_ckpt(step):
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, f"model_step{step}.bin"))
    torch.save({"step": step, "opt": opt.state_dict(), "sched": sched.state_dict()},
               os.path.join(OUTPUT_DIR, f"state_step{step}.pt"))

def latest_ckpt():
    if not os.path.isdir(OUTPUT_DIR): return None
    cands = [f for f in os.listdir(OUTPUT_DIR) if f.startswith("state_step")]
    if not cands: return None
    return max(int(f.split("state_step")[1].split(".")[0]) for f in cands)

def load_ckpt(step):
    mpath = os.path.join(OUTPUT_DIR, f"model_step{step}.bin")
    spath = os.path.join(OUTPUT_DIR, f"state_step{step}.pt")
    if os.path.exists(mpath):
        sd = torch.load(mpath, map_location="cpu")
        model.load_state_dict(sd, strict=False)
        model.to(device)
    if os.path.exists(spath):
        s = torch.load(spath, map_location="cpu")
        opt.load_state_dict(s["opt"])
        sched.load_state_dict(s["sched"])
    print(f"Resumed from step {step}")

# ---------------- Eval ----------------
def evaluate():
    model.eval()
    losses=[]
    with torch.no_grad():
        for batch in val_loader:
            ctx = torch.autocast(device_type="cuda", dtype=torch.bfloat16) if torch.cuda.is_available() else nullcontext()
            if torch.cuda.is_available():
                with ctx:
                    loss, _ = batch_loss(batch)
            else:
                loss,_ = batch_loss(batch)
            losses.append(loss.item())
    return float(np.mean(losses)) if losses else 0.0

# ---------------- Train (resume-safe) ----------------
best_val = float("inf")
global_step = 0
if RESUME:
    ck = latest_ckpt()
    if ck is not None:
        load_ckpt(ck)
        global_step = ck

for epoch in range(1, EPOCHS+1):
    model.train()
    epoch_rank, epoch_attr, nb = 0.0, 0.0, 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}")
    for batch in pbar:
        ctx = torch.autocast(device_type="cuda", dtype=torch.bfloat16) if torch.cuda.is_available() else nullcontext()
        if torch.cuda.is_available():
            with ctx:
                loss, logs = batch_loss(batch)
            (scaler.scale(loss/GRAD_ACCUM) if scaler else (loss/GRAD_ACCUM)).backward()
        else:
            loss, logs = batch_loss(batch)
            (loss/GRAD_ACCUM).backward()

        epoch_rank += logs["rank_loss"]; epoch_attr += logs["attr_loss"]; nb += 1

        if (nb % GRAD_ACCUM) == 0:
            if scaler: scaler.step(opt); scaler.update()
            else: opt.step()
            sched.step()
            opt.zero_grad(set_to_none=True)
            global_step += 1
            pbar.set_postfix(loss=float(loss.item()), rank=epoch_rank/nb, attr=epoch_attr/nb)

            if global_step % CKPT_EVERY == 0:
                save_ckpt(global_step)

    val_loss = evaluate()
    print(f"\n>> Val loss: {val_loss:.4f}")
    if val_loss < best_val:
        best_val = val_loss
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "pytorch_model.bin"))
        from transformers import AutoTokenizer as _AT
        _AT.from_pretrained(BASE_MODEL, use_fast=True).save_pretrained(OUTPUT_DIR)
        print(f"Saved BEST to {OUTPUT_DIR} (val={best_val:.4f})")

# ---------------- Inference helper ----------------
@torch.no_grad()
def rerank_scores(query_text: str, docs: List[str], batch_size=32, max_len=256):
    model.eval()
    scores=[]
    for i in range(0, len(docs), batch_size):
        chunk = docs[i:i+batch_size]
        toks = tokenizer(
            [f"[QUERY]\n{query_text}\n[DOCUMENT]\n{d}" for d in chunk],
            max_length=max_len, truncation=True, padding=True, return_tensors="pt"
        ).to(device)
        s,_ = model(**toks)
        scores.extend(s.detach().float().cpu().tolist())
    order = np.argsort(-np.array(scores))
    return [(docs[i], float(scores[i])) for i in order]

# ---------------- Smoke test ----------------
if len(entries) > 0:
    sample = entries[0]
    q = sample.get("instructed_query") or sample.get("query") or ""
    cand = []
    cand.append(get_doc_text(sample.get("positive_doc")))
    cand.append(get_doc_text(sample.get("hard_negative_doc")))
    pool = qid_to_docs.get(sample.get("query_id"), [])[:6]
    for d in pool:
        if d and d not in cand: cand.append(d)
    cand = [d for d in cand if d]
    ranked = rerank_scores(q, cand, batch_size=16, max_len=MAX_LEN)[:5]
    print("\nTop-5 (smoke test):")
    for i,(d,s) in enumerate(ranked,1):
        print(f"{i:>2}. score={s:.3f} | {d[:90].replace('\\n',' ')}...")
else:
    print("No entries to test.")


Loaded entries: 9596 | query-doc groups: 515
Total trainable pairs (raw): 28632
Using: train=11900 | val=2000


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/578 [00:00<?, ?B/s]

spm.model:   0%|          | 0.00/2.46M [00:00<?, ?B/s]



pytorch_model.bin:   0%|          | 0.00/286M [00:00<?, ?B/s]

LoRA attached (FEATURE_EXTRACTION).
Trainable params: 1.33M / 142.64M


Epoch 1/2:   0%|          | 0/992 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/286M [00:00<?, ?B/s]

Epoch 1/2: 100%|██████████| 992/992 [17:06<00:00,  1.03s/it, attr=0.351, loss=0.113, rank=0.144]



>> Val loss: 0.1220
Saved BEST to reranker-fast-fix (val=0.1220)


Epoch 2/2: 100%|██████████| 992/992 [17:09<00:00,  1.04s/it, attr=0.318, loss=0.0899, rank=0.0112]



>> Val loss: 0.1113
Saved BEST to reranker-fast-fix (val=0.1113)

Top-5 (smoke test):
 1. score=2.329 | A good coding standard should encompass several key components to ensure code quality and ...
 2. score=2.162 | A comprehensive coding standard is essential for maintaining high-quality code. It should ...
 3. score=2.000 | A comprehensive coding standard should encompass several key elements to ensure that the c...
 4. score=0.172 | When creating a coding standard, it's important to consider various aspects. It should lis...
 5. score=0.044 | A coding standard is a set of guidelines that can include various aspects of coding. It mi...


In [None]:
# ================================================================
# Evaluate trained reranker on your dataset pools (mSICR/mWISE/MDCR)
# Uses: reranker-fast-fix/, /mnt/data/final_sorted.jsonl, /mnt/data/query-doc.json
# ================================================================
!pip -q install transformers peft sentencepiece

import json, numpy as np, torch, torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
from contextlib import nullcontext

MODEL_DIR     = "reranker-fast-fix"
BASE_MODEL    = "microsoft/deberta-v3-small"
FINAL_JSONL   = "final_sorted.jsonl"
QDOC_JSON     = "query-doc.json"
ATTR_KEYS     = ["audience","format","language","length","source","keyword"]
DEVICE        = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_LEN       = 256
BATCH_SIZE    = 32
MAX_EVAL      = 1000    # set to None for full set

# ---------- IO ----------
def load_jsonl(path):
    rows=[]
    with open(path,"r",encoding="utf-8") as f:
        for i,l in enumerate(f,1):
            l=l.strip()
            if not l: continue
            try: rows.append(json.loads(l))
            except: pass
    return rows

with open(QDOC_JSON,"r",encoding="utf-8") as f:
    qdoc = json.load(f)
entries = load_jsonl(FINAL_JSONL)

# ---------- normalize ----------
def get_doc_text(d):
    if isinstance(d, dict): return str(d.get("text","")).strip()
    if isinstance(d, list) and len(d)>0:
        first=d[0]
        return (first.get("text","").strip() if isinstance(first,dict) else str(first).strip())
    if isinstance(d, str): return d.strip()
    return ""

# Build pools: qid -> list of (doc_id, text)
qid_pools = {}
for group in qdoc:
    qid = group["query_id"]
    pool=[]
    for d in group.get("documents", []):
        pool.append((d.get("doc_id", f"{qid}_unk"), d.get("text","").strip()))
    qid_pools[qid] = pool

# ---------- model (same head as training) ----------
class CrossEncoderWithAttr(nn.Module):
    def __init__(self, base=BASE_MODEL):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base)
        self.config  = self.encoder.config
        h = self.encoder.config.hidden_size
        self.dp = nn.Dropout(0.1)
        self.rank = nn.Linear(h,1)
        self.attr = nn.Linear(h, len(ATTR_KEYS))
    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None, **kwargs):
        if "labels" in kwargs: kwargs.pop("labels", None)
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, **kwargs)
        cls = self.dp(out.last_hidden_state[:,0,:])
        return self.rank(cls).squeeze(-1), self.attr(cls)

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
model = CrossEncoderWithAttr(BASE_MODEL)
state = torch.load(f"{MODEL_DIR}/pytorch_model.bin", map_location="cpu")
model.load_state_dict(state, strict=False)
model.to(DEVICE); model.eval()

@torch.no_grad()
def score_batch(query, docs):
    toks = tokenizer(
        [f"[QUERY]\n{query}\n[DOCUMENT]\n{d}" for d in docs],
        max_length=MAX_LEN, truncation=True, padding=True, return_tensors="pt"
    ).to(DEVICE)
    s,_ = model(**toks)
    return s.detach().float().cpu().numpy()

def rank_within_pool(query, pool_texts):
    scores=[]
    for i in range(0, len(pool_texts), BATCH_SIZE):
        chunk = pool_texts[i:i+BATCH_SIZE]
        scores.extend(score_batch(query, chunk))
    order = np.argsort(-np.array(scores))
    return order, np.array(scores)

def find_pos_index_in_pool(pos_text, pool_texts):
    # exact match first
    try:
        return pool_texts.index(pos_text)
    except ValueError:
        # fallback: highest token-overlap
        def tokset(t): return set(t.lower().split())
        pt = tokset(pos_text)
        best_i, best_j = -1, -1.0
        for i, d in enumerate(pool_texts):
            s = len(pt & tokset(d)) / (len(pt | tokset(d)) + 1e-9)
            if s > best_j:
                best_j, best_i = s, i
        return best_i if best_i >= 0 else None

# ---------- evaluation ----------
def evaluate(entries, limit=None):
    msicr, mwise, mdcr_s, mdcr_soft = [], [], [], []
    n = len(entries) if (limit is None) else min(limit, len(entries))
    for e in tqdm(entries[:n], desc="Evaluating reranker"):
        qid = e.get("query_id")
        pool = qid_pools.get(qid, [])
        if not pool: continue
        doc_ids, doc_texts = zip(*pool)
        doc_texts = list(doc_texts)

        pos_text = get_doc_text(e.get("positive_doc"))
        if not pos_text: continue
        pos_idx  = find_pos_index_in_pool(pos_text, doc_texts)
        if pos_idx is None: continue
        pos_docid = doc_ids[pos_idx]

        q_ori = e.get("query","")
        q_ins = e.get("instructed_query","")
        q_rev = e.get("reversed_query","")
        if not (q_ori and q_ins and q_rev): continue

        # Rankings within the pool
        ord_ori, _ = rank_within_pool(q_ori, doc_texts)
        ord_ins, _ = rank_within_pool(q_ins, doc_texts)
        ord_rev, _ = rank_within_pool(q_rev, doc_texts)

        def rank_of(idx, order):
            # order holds indices into doc_texts in ranked order
            return int(np.where(order == idx)[0][0]) + 1 if idx in order else len(order) + 1

        r_ori = rank_of(pos_idx, ord_ori)
        r_ins = rank_of(pos_idx, ord_ins)
        r_rev = rank_of(pos_idx, ord_rev)

        # mSICR
        msicr.append(int((r_ins < r_ori) and (r_rev > r_ori)))

        # mWISE
        m = max(1, len((e.get("attributes") or {}).keys()))
        delta_ins = r_ori - r_ins
        delta_rev = r_rev - r_ori
        mwise.append((delta_ins - delta_rev)/m)

        # MDCR (soft/strict): simple attribute presence in pos_text
        attrs = e.get("attributes",{}) or {}
        if attrs:
            lt = pos_text.lower()
            scores = []
            for k,v in attrs.items():
                scores.append(1.0 if str(v).lower() in lt else 0.0)
            soft = float(np.mean(scores))
            thr = max(0.45, soft - 0.05)  # same heuristic you used
            mdcr_soft.append(soft)
            mdcr_s.append(int(all(s >= thr for s in scores)))
        else:
            mdcr_soft.append(0.0); mdcr_s.append(0)

    def avg(x): return float(np.mean(x)) if x else 0.0
    return {
        "count": len(msicr),
        "mSICR": avg(msicr),
        "mWISE": avg(mwise),
        "MDCR_strict": avg(mdcr_s),
        "MDCR_soft": avg(mdcr_soft),
    }

metrics = evaluate(entries, limit=MAX_EVAL)
print("\n📊 Reranker metrics (subset):")
for k,v in metrics.items(): print(f"{k}: {v:.4f}" if isinstance(v,float) else f"{k}: {v}")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Evaluating reranker: 100%|██████████| 1000/1000 [09:13<00:00,  1.81it/s]


📊 Reranker metrics (subset):
count: 1000
mSICR: 0.0400
mWISE: -0.3048
MDCR_strict: 0.0120
MDCR_soft: 0.1610





New one

In [None]:
# ================================================================
# Fast Beyond-Content Reranker++ (T4-safe)
# - List-wise (instructed), pos-only margins (original/reversed)
# - LoRA adapters, small DeBERTa, AMP FP16, no torch.compile
# - Works well on Tesla T4 / Colab
#
# Required files in CWD:
#   - final_sorted.jsonl
#   - query-doc.json
# ================================================================
!pip -q install transformers peft accelerate sentencepiece rank-bm25

import os, json, random, numpy as np
from dataclasses import dataclass
from typing import List, Dict, Any, Tuple
from tqdm import tqdm
from contextlib import nullcontext

import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from rank_bm25 import BM25Okapi

# -------- Optional: LoRA (speeds training by tuning fewer params)
try:
    from peft import LoraConfig, get_peft_model, TaskType
    PEFT_AVAILABLE = True
except Exception:
    PEFT_AVAILABLE = False

# ---------------- CONFIG (T4-safe defaults) ----------------
SEED = 1337
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

FINAL_JSONL = "final_sorted.jsonl"
QDOC_JSON   = "query-doc.json"

BASE_MODEL   = "microsoft/deberta-v3-xsmall"   # quick backbone; upgrade later to -small/-base once stable
USE_LORA     = True and PEFT_AVAILABLE
MAX_LEN      = 192
BATCH_SIZE   = 12
GRAD_ACCUM   = 1
EPOCHS       = 1
LR_MAIN      = 4e-4 if USE_LORA else 3e-5
WARMUP_RATIO = 0.06
NUM_WORKERS  = 0                   # 0 is simplest/most stable on Colab

K_CAND       = 6                   # pos + hard_neg + (K-2) others
BM25_HARDK   = 4
MAX_TRAIN    = 10000               # cap for speed; set None for full
MAX_VAL      = 1200

# Margins / weights
MARGIN_IO    = 0.6                 # instructed > original
MARGIN_OR    = 0.6                 # original   > reversed
MARGIN_INS   = 0.25                # instructed pos > each neg (intra-list)
W_LIST       = 1.0
W_PAIR       = 1.0
LAMBDA_ATTR  = 0.25                # small aux weight
W_REVERSE_BOOST = 2.0

OUTPUT_DIR   = "reranker_beyond_listwise_fast_t4"
RESUME       = False
ATTR_KEYS    = ["audience","format","language","length","source","keyword"]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
GPU_NAME = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu"
AMP_DTYPE = torch.float16 if torch.cuda.is_available() else None

# Speed/stability knobs (good for T4)
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision("medium")

# ---------------- IO ----------------
def load_jsonl(path):
    rows=[]
    with open(path,"r",encoding="utf-8") as f:
        for i,l in enumerate(f,1):
            l=l.strip()
            if not l: continue
            try: rows.append(json.loads(l))
            except: pass
    return rows

assert os.path.exists(FINAL_JSONL), f"Missing {FINAL_JSONL} in current directory."
assert os.path.exists(QDOC_JSON),   f"Missing {QDOC_JSON} in current directory."

with open(QDOC_JSON,"r",encoding="utf-8") as f:
    qdoc = json.load(f)
entries = load_jsonl(FINAL_JSONL)
print(f"Loaded entries: {len(entries)} | query-doc groups: {len(qdoc)} | GPU: {GPU_NAME}")

# ---------------- Pools & per-q BM25 ----------------
def norm_text(d):
    if isinstance(d, dict): return str(d.get("text","")).strip()
    if isinstance(d, list) and len(d)>0:
        first=d[0]
        return (first.get("text","").strip() if isinstance(first,dict) else str(first).strip())
    if isinstance(d, str): return d.strip()
    return ""

qid_to_pool = {}
for group in qdoc:
    qid = group["query_id"]
    docs = []
    for d in group.get("documents", []):
        t = d.get("text","").strip()
        if t: docs.append(t)
    ded, seen = [], set()
    for t in docs:
        key = t[:200]
        if key in seen: continue
        ded.append(t); seen.add(key)
    qid_to_pool[qid] = ded

def tokenize(s): return s.lower().split()
bm25_by_qid = {}
for qid, pool in qid_to_pool.items():
    if len(pool) < 2: continue
    tokenized = [tokenize(t) for t in pool]
    bm25_by_qid[qid] = (BM25Okapi(tokenized), pool, tokenized)

# ---------------- Weak attribute labels (string-match heuristic) ----------------
def weak_attr_vec(text:str, attrs:Dict[str,Any]) -> np.ndarray:
    lt = (text or "").lower()
    return np.array([
        1.0 if (k in (attrs or {}) and str(attrs[k]).lower() in lt) else 0.0
        for k in ATTR_KEYS
    ], dtype=np.float32)

# ---------------- Items ----------------
@dataclass
class TrainingItem:
    qid:str; q_ori:str; q_ins:str; q_rev:str
    pos_text:str; neg_text:str; attrs:Dict[str,Any]

def build_items(records:List[Dict[str,Any]])->List[TrainingItem]:
    items=[]
    for e in records:
        qid = e.get("query_id")
        q_ori = e.get("query",""); q_ins = e.get("instructed_query",""); q_rev = e.get("reversed_query","")
        pos  = norm_text(e.get("positive_doc")); neg  = norm_text(e.get("hard_negative_doc"))
        if qid in qid_to_pool and q_ori and q_ins and q_rev and pos and neg:
            items.append(TrainingItem(qid,q_ori,q_ins,q_rev,pos,neg,e.get("attributes",{}) or {}))
    random.shuffle(items)
    return items

items_all = build_items(entries)
if MAX_TRAIN: items_all = items_all[:MAX_TRAIN + (MAX_VAL or 0)]
split = int(0.85 * len(items_all))
train_items, val_items = items_all[:split], items_all[split:]
if MAX_VAL: val_items = val_items[:MAX_VAL]
print(f"Using items: train={len(train_items)} | val={len(val_items)}")

# ---------------- Tokenizer ----------------
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)

def pack_pairs(queries:List[str], docs:List[str]):
    texts = [f"[QUERY]\n{q}\n[DOCUMENT]\n{d}" for q,d in zip(queries, docs)]
    return tokenizer(texts, max_length=MAX_LEN, truncation=True, padding=True, return_tensors="pt")

# ---------------- Dataset & Collator ----------------
class ListwiseDataset(Dataset):
    def __init__(self, items:List[TrainingItem], k:int=6):
        self.items = items; self.k = k
    def __len__(self): return len(self.items)
    def _bm25_negs(self, qid:str, query:str, avoid:set, want:int)->List[str]:
        pack = bm25_by_qid.get(qid)
        if not pack: return []
        bm25, pool, _ = pack
        scores = bm25.get_scores(tokenize(query))
        order = np.argsort(scores)[::-1]
        out=[]
        for idx in order:
            doc = pool[idx]
            if doc in avoid: continue
            out.append(doc)
            if len(out) >= want: break
        return out
    def __getitem__(self, i):
        x = self.items[i]
        pool = qid_to_pool.get(x.qid, [])
        cand = [x.pos_text]; avoid = set([x.pos_text])
        if x.neg_text and x.neg_text not in avoid:
            cand.append(x.neg_text); avoid.add(x.neg_text)
        need = max(0, K_CAND - len(cand))
        bmnegs = self._bm25_negs(x.qid, x.q_ins, avoid, want=min(need, BM25_HARDK))
        for d in bmnegs:
            if d not in avoid: cand.append(d); avoid.add(d)
        if len(cand) < K_CAND:
            extras = [d for d in pool if d not in avoid]
            random.shuffle(extras); cand.extend(extras[:K_CAND - len(cand)])
        cand = cand[:K_CAND]
        return {
            "qid": x.qid, "q_ori": x.q_ori, "q_ins": x.q_ins, "q_rev": x.q_rev,
            "cands": cand, "pos_index": 0, "pos_attr": weak_attr_vec(x.pos_text, x.attrs)
        }

class FastCollator:
    """INS: full K; ORI/REV: positive-only (for speed)."""
    def __init__(self, k:int=6): self.k=k
    def __call__(self, batch):
        B = len(batch); K = self.k
        docs_ins, queries_ins = [], []
        pos_docs, queries_ori_pos, queries_rev_pos = [], [], []
        pos_indices = []; pos_attr = []
        for b in batch:
            cands = b["cands"]
            if len(cands) < K: cands = cands + [cands[-1]]*(K-len(cands))
            cands = cands[:K]
            docs_ins.extend(cands); queries_ins.extend([b["q_ins"]]*K)
            pos_doc = cands[0]
            pos_docs.append(pos_doc); queries_ori_pos.append(b["q_ori"]); queries_rev_pos.append(b["q_rev"])
            pos_indices.append(0); pos_attr.append(b["pos_attr"])
        tok_ins     = pack_pairs(queries_ins,     docs_ins)
        tok_pos_ori = pack_pairs(queries_ori_pos, pos_docs)
        tok_pos_rev = pack_pairs(queries_rev_pos, pos_docs)
        return {
            "tok_ins": {k:v for k,v in tok_ins.items()},
            "tok_pos_ori": {k:v for k,v in tok_pos_ori.items()},
            "tok_pos_rev": {k:v for k,v in tok_pos_rev.items()},
            "pos_indices": torch.tensor(pos_indices, dtype=torch.long),
            "pos_attr": torch.tensor(np.stack(pos_attr), dtype=torch.float32)
        }

# ---------------- Model ----------------
class CrossEncoderWithAttr(nn.Module):
    def __init__(self, base=BASE_MODEL, dropout=0.1, grad_ckpt=True):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base)
        if grad_ckpt and hasattr(self.encoder, "gradient_checkpointing_enable"):
            self.encoder.gradient_checkpointing_enable()
        if hasattr(self.encoder, "enable_input_require_grads"):
            self.encoder.enable_input_require_grads()
        h = self.encoder.config.hidden_size
        self.dp = nn.Dropout(dropout)
        self.rank = nn.Linear(h,1)
        self.attr = nn.Linear(h,len(ATTR_KEYS))
    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, **kwargs):
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, **kwargs)
        cls = self.dp(out.last_hidden_state[:,0,:])
        return self.rank(cls).squeeze(-1), self.attr(cls)

model = CrossEncoderWithAttr(BASE_MODEL, grad_ckpt=True)
if USE_LORA:
    try:
        lconf = LoraConfig(
            task_type=TaskType.FEATURE_EXTRACTION,
            r=16, lora_alpha=32, lora_dropout=0.05,
            target_modules=["query_proj","key_proj","value_proj","dense","out_proj"]
        )
        model = get_peft_model(model, lconf)
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        print("LoRA attached (FEATURE_EXTRACTION).")
    except Exception as e:
        print(f"LoRA attach failed; continuing without LoRA: {e}")
        USE_LORA=False

model = model.to(device)

# Freeze backbone; train only LoRA + heads
def mark_trainable(m):
    trainable, total = 0, 0
    for n,p in m.named_parameters():
        total += p.numel()
        if ("lora_" in n) or ("rank" in n) or ("attr" in n):
            p.requires_grad_(True); trainable += p.numel()
        else:
            p.requires_grad_(False)
    print(f"Trainable params: {trainable/1e6:.2f}M / {total/1e6:.2f}M")
mark_trainable(model)

# ---------------- Losses ----------------
def listwise_ce(scores_bk:torch.Tensor, pos_idx:torch.Tensor)->torch.Tensor:
    return nn.CrossEntropyLoss()(scores_bk, pos_idx)

def bce_simple(logits, targets):  # average BCE
    return nn.BCEWithLogitsLoss()(logits, targets)

# ---------------- Dataloaders ----------------
train_ds = ListwiseDataset(train_items, k=K_CAND)
val_ds   = ListwiseDataset(val_items,   k=K_CAND)
collate  = FastCollator(k=K_CAND)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  collate_fn=collate, num_workers=NUM_WORKERS)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate, num_workers=NUM_WORKERS)

# ---------------- Optim & Sched ----------------
opt = AdamW([p for p in model.parameters() if p.requires_grad], lr=LR_MAIN, weight_decay=0.01)
total_steps = max(1, (len(train_loader) * EPOCHS) // GRAD_ACCUM)
warm_steps  = int(total_steps * WARMUP_RATIO)
sched = get_linear_schedule_with_warmup(opt, num_warmup_steps=warm_steps, num_training_steps=total_steps)

# ---------------- Helpers ----------------
def forward_scores(tok_dict):
    tok = {k: v.to(device) for k,v in tok_dict.items()}
    s, a = model(**tok)
    return s, a

def compute_losses(batch, K:int)->Tuple[torch.Tensor, dict]:
    B = batch["pos_indices"].shape[0]
    pos_idx = batch["pos_indices"].to(device)

    # 1) INS list-wise over K
    s_ins_flat, a_ins_flat = forward_scores(batch["tok_ins"])       # [B*K], [B*K, A]
    s_ins = s_ins_flat.view(B, K)
    L_list = listwise_ce(s_ins, pos_idx)

    # Attribute aux on positive
    A = len(ATTR_KEYS)
    a_ins = a_ins_flat.view(B, K, A)
    a_ins_pos = a_ins[torch.arange(B), pos_idx, :]
    attr_t = batch["pos_attr"].to(device)
    L_attr = bce_simple(a_ins_pos, attr_t)

    # 2) POS-only margins for ORI/REV
    s_ori_pos, _ = forward_scores(batch["tok_pos_ori"])             # [B]
    s_rev_pos, _ = forward_scores(batch["tok_pos_rev"])             # [B]
    s_ins_pos = s_ins[torch.arange(B), pos_idx]                     # [B]

    L_io = torch.clamp(MARGIN_IO - (s_ins_pos - s_ori_pos), min=0).mean()
    L_or = torch.clamp(MARGIN_OR - (s_ori_pos - s_rev_pos), min=0).mean() * W_REVERSE_BOOST

    # Intra-instructed margin (pos > each neg)
    neg_mask = torch.ones_like(s_ins, device=device); neg_mask[torch.arange(B), pos_idx] = 0.0
    L_intra = torch.clamp(MARGIN_INS - (s_ins_pos.unsqueeze(1) - s_ins), min=0) * neg_mask
    denom = neg_mask.sum(dim=1).clamp(min=1.0)
    L_intra = (L_intra.sum(dim=1) / denom).mean()

    loss = W_LIST*L_list + W_PAIR*(L_io + L_or + L_intra) + LAMBDA_ATTR*L_attr
    logs = dict(L_list=L_list.item(), L_io=L_io.item(), L_or=L_or.item(), L_intra=L_intra.item(), L_attr=L_attr.item())
    return loss, logs

def evaluate(val_loader, K:int)->float:
    model.eval(); losses=[]
    with torch.no_grad():
        for batch in val_loader:
            if torch.cuda.is_available():
                with torch.autocast(device_type="cuda", dtype=AMP_DTYPE):
                    l,_ = compute_losses(batch, K)
            else:
                l,_ = compute_losses(batch, K)
            losses.append(l.item())
    return float(np.mean(losses)) if losses else 0.0

# ---------------- Train ----------------
best_val = float("inf")
print("Starting training…")
for epoch in range(1, EPOCHS+1):
    model.train()
    pbar=tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}")
    meter={"L_list":0,"L_io":0,"L_or":0,"L_intra":0,"L_attr":0,"n":0}
    for batch in pbar:
        if torch.cuda.is_available():
            with torch.autocast(device_type="cuda", dtype=AMP_DTYPE):
                loss, logs = compute_losses(batch, K_CAND)
        else:
            loss, logs = compute_losses(batch, K_CAND)

        (loss/GRAD_ACCUM).backward()
        torch.nn.utils.clip_grad_norm_([p for p in model.parameters() if p.requires_grad], 1.0)

        opt.step(); sched.step(); opt.zero_grad(set_to_none=True)

        for k in ["L_list","L_io","L_or","L_intra","L_attr"]: meter[k]+=logs[k]
        meter["n"]+=1
        show = {k: round(meter[k]/meter["n"],4) for k in ["L_list","L_io","L_or","L_intra","L_attr"]}
        show["loss"]=round(loss.item(),4); pbar.set_postfix(show)

    v = evaluate(val_loader, K_CAND)
    print(f">> Val (combined loss): {v:.4f}")
    if v < best_val:
        best_val = v
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "pytorch_model.bin"))
        tokenizer.save_pretrained(OUTPUT_DIR)
        print(f"Saved BEST to {OUTPUT_DIR} (val={best_val:.4f})")

print("Done.")


Loaded entries: 9596 | query-doc groups: 515 | GPU: Tesla T4
Using items: train=8112 | val=1200
LoRA attached (FEATURE_EXTRACTION).
Trainable params: 1.33M / 72.01M
Starting training…


Epoch 1/1:  25%|██▌       | 169/676 [03:38<09:52,  1.17s/it, L_list=1.51, L_io=0.251, L_or=0.366, L_intra=0.236, L_attr=0.342, loss=1.61]