# 🧠 Tweet Topic Classification — Robust v3 (Colab Ready)

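This notebook trains a robust tweet topic classifier with:
- Test-time augmentation (TTA): each tweet is embedded under several text-cleaning variants and the embeddings are averaged
- Trimmed, bootstrap-bagged class centroids or Mahalanobis-distance prototypes
- Calibrated logistic regression
- kNN blending
- Per-class open-set thresholds (low-confidence predictions become "Other")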

Artifacts are saved under `./artifacts_v3`. Run the setup cell once to install dependencies, then run the remaining cells in order.


In [None]:
# Setup: install dependencies (Colab-safe)
import sys, subprocess, importlib

def pip_module_name(pkg):
    """Return the importable module name for a pip requirement string.

    The version specifier / extras / markers are stripped, known distributions whose
    import name differs are mapped (``scikit-learn`` -> ``sklearn``), and remaining
    hyphens become underscores (``sentence-transformers`` -> ``sentence_transformers``).
    """
    name = pkg
    for sep in ("==", ">=", "<=", "~=", "!=", "<", ">", "[", ";"):
        name = name.split(sep, 1)[0]
    name = name.strip()
    # Distributions whose import name is not just the dist name with '-' -> '_'.
    special = {"scikit-learn": "sklearn"}
    return special.get(name.lower(), name.replace("-", "_"))


def ensure(pkg, module=None):
    """pip-install the requirement ``pkg`` unless its module already imports.

    Args:
        pkg: pip requirement string, e.g. ``"scikit-learn>=1.3.0"``.
        module: import name to probe; derived from ``pkg`` via ``pip_module_name``
            when omitted. Only importability is checked, not the version.
    """
    try:
        importlib.import_module(module or pip_module_name(pkg))
    except ImportError:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", pkg])

# Packages (with minimum versions) installed on demand by ensure().
REQUIREMENTS = (
    "datasets>=2.14.0",
    "pandas>=2.0.0",
    "numpy>=1.23.0",
    "scikit-learn>=1.3.0",
    "sentence-transformers>=2.2.2",
    "emoji",
    "wordsegment",
    "matplotlib>=3.7.0",
    "scipy>=1.10.0",
    "joblib>=1.3.0",
)
for requirement in REQUIREMENTS:
    ensure(requirement)

print("Dependencies ensured.")


In [None]:
# Imports & reproducibility
import os, re, json, math, random
from typing import List, Dict, Tuple, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from datasets import load_dataset, DatasetDict
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix
from sklearn.preprocessing import normalize
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV
from sklearn.model_selection import train_test_split
from sklearn.neighbors import NearestNeighbors
from sklearn.covariance import LedoitWolf

from sentence_transformers import SentenceTransformer

# Reproducibility: seed the Python and NumPy RNGs and keep HF tokenizers single-threaded.
# NOTE(review): PYTHONHASHSEED is read at interpreter start-up, and numpy is imported above
# before the BLAS thread variables are set in set_single_thread() -- either setting may not
# influence the current process. Confirm if bit-exact reproducibility matters.
SEED = 42
os.environ.update({
    "PYTHONHASHSEED": str(SEED),
    "TOKENIZERS_PARALLELISM": "false",
})
random.seed(SEED)
np.random.seed(SEED)


def set_single_thread():
    """Set the common BLAS/OpenMP thread-count environment variables to "1"."""
    thread_vars = ("MKL_NUM_THREADS", "NUMEXPR_NUM_THREADS", "OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS")
    os.environ.update(dict.fromkeys(thread_vars, "1"))


set_single_thread()

print("Environment ready.")


In [None]:
# Configuration
# Module-level knobs read by the dataset, training and inference cells below.
DATASET_NAME = "cardiffnlp/tweet_topic_single"   # 6-label single-label task

# Sentence-embedding models; each gets its own run_<i> artifact dir and the runs are ranked by F1-macro.
# NOTE(review): E5 checkpoints are documented to expect "query: "/"passage: " text prefixes,
# which encode_texts does not add -- confirm whether that matters here.
EMBEDDER_CANDIDATES = [
    "intfloat/multilingual-e5-base",
    "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
]
BATCH_SIZE = 64        # encode() batch size for train/val/test embeddings
SHOW_PROGRESS = True   # show encode() progress bars for train/val/test

# TTA: each text is embedded once per cleaner variant and the embeddings are averaged
ENABLE_TTA = True
TTA_VARIANTS = ["light", "heavy", "nohashtag"]  # keys of CLEANERS

# Prototypes
USE_MAHALANOBIS = True     # True: Ledoit-Wolf Mahalanobis scorer; False: cosine centroids
TRIM_FRAC = 0.10           # fraction of samples farthest from the class mean dropped per centroid
BOOTSTRAP_CENTROIDS = 3    # bootstrap resamples averaged into the centroids (<= 1 disables bagging)

# Logistic Regression
LR_MAX_ITER = 1500
LR_C = 1.0
LR_CLASS_WEIGHT = "balanced"

# kNN
KNN_K = 5  # neighbours in the cosine kNN voter during training

# Calibration & Open-set
CALIBRATION_METHOD = "sigmoid"   # passed to CalibratedClassifierCV(method=...)
ENABLE_OPEN_SET = True           # recorded in meta.json as "open_set"
OPEN_SET_STRATEGY = "per_class"  # recorded in meta.json as "open_set_strategy"
TARGET_MIN_PRECISION = 0.80      # validation precision each per-class confidence threshold must reach

# Ensemble grid: candidate weights a (centroid) and b (LR); kNN gets c = 1 - a - b
ENSEMBLE_GRID = [0.0, 0.25, 0.5, 0.75, 1.0]

# Domain priors (optional): {label: weight}; probabilities are multiplied by the
# normalised weights and renormalised per row
DOMAIN_PRIORS = {}

# Artifacts
ARTIFACT_DIR = "./artifacts_v3"
os.makedirs(ARTIFACT_DIR, exist_ok=True)
print("Config ready.")


In [None]:
# Dataset loading

def discover_splits(ds: "DatasetDict"):
    """Pick (train, validation, test) split names from a DatasetDict-like mapping.

    Known naming schemes (``*_coling2022``, ``*_coling2022_random``, plain
    ``train/validation/test``) are tried first. A scheme matches when both its train
    and test splits exist; its validation split is used only if present.

    Otherwise ``train``/``test`` are used when present, and missing roles are filled
    with the first (train) / last (test) remaining split. Train and test are never the
    same split unless the dataset has only one split. ``validation`` is returned only
    when it is not already serving as train or test.

    Returns:
        (train_key, val_key or None, test_key)

    Raises:
        ValueError: if ``ds`` has no splits.
    """
    keys = list(ds.keys())
    if not keys:
        raise ValueError("dataset has no splits")
    prefer = [
        ("train_coling2022", "validation_coling2022", "test_coling2022"),
        ("train_coling2022_random", "validation_coling2022_random", "test_coling2022_random"),
        ("train", "validation", "test"),
    ]
    for tr, va, te in prefer:
        if tr in keys and te in keys:
            return tr, (va if va in keys else None), te

    tr = "train" if "train" in keys else None
    te = "test" if "test" in keys else None
    if tr is None:
        # First split that is not the test split, so we never train on test data.
        tr = next((k for k in keys if k != te), keys[0])
    if te is None:
        # Last split that is not the train split.
        te = next((k for k in reversed(keys) if k != tr), keys[-1])
    va = "validation" if "validation" in keys and "validation" not in (tr, te) else None
    return tr, va, te

# Fetch the dataset and resolve which splits / fields to use.
print("Loading dataset:", DATASET_NAME)
ds = load_dataset(DATASET_NAME)
train_key, val_key, test_key = discover_splits(ds)
train_ds = ds[train_key]
test_ds = ds[test_key]
val_ds = None if not val_key else ds[val_key]

# "label" => single-label ClassLabel column, "labels" => multi-hot column, else neither.
if "label" in train_ds.features:
    label_field = "label"
elif "labels" in train_ds.features:
    label_field = "labels"
else:
    label_field = None
text_field = "text"

# Human-readable class names, only available for a single-label ClassLabel feature.
label_names = None
if label_field == "label" and hasattr(train_ds.features[label_field], "names"):
    label_names = train_ds.features[label_field].names

print(f"Splits: train={train_key}, val={val_key}, test={test_key}")
print("Labels:", label_names or "multi-label or not provided")


In [None]:
# Preprocessing + TTA cleaners
import emoji
from wordsegment import load as ws_load, segment as ws_segment
# Initialise wordsegment once, up front, so ws_segment (used by split_hashtag
# below) can be called during cleaning.
ws_load()

def split_hashtag(ht: str) -> str:
    """Turn a ``#hashtag`` token into space-separated words using wordsegment.

    Values that are empty or do not start with ``#`` are returned unchanged. If
    segmentation raises, the hashtag body (without ``#``) is returned as-is.
    """
    if not ht or ht[0] != "#":
        return ht
    body = ht[1:]
    try:
        return " ".join(ws_segment(body))
    except Exception:  # best-effort: keep the raw word rather than abort cleaning
        return body

def clean_light(t: str) -> str:
    """Light tweet normalisation; hashtags are kept as their segmented words.

    Links and @mentions are removed, hashtags are expanded via ``split_hashtag``,
    emoji are blanked, character runs of 3+ are squeezed to 2, anything outside the
    whitelist becomes a space, and whitespace is collapsed.
    """
    text = str(t)
    for noise in (r"http\S+|www\.\S+", r"@\w+"):
        text = re.sub(noise, " ", text)
    text = re.sub(r"#\w+", lambda match: " " + split_hashtag(match.group()), text)
    text = emoji.replace_emoji(text, replace=" ")
    # "soooo" -> "soo"
    text = re.sub(r"(.)\1{2,}", r"\1\1", text)
    # Whitelist: ASCII letters/digits, Devanagari, whitespace, apostrophe.
    # NOTE(review): this also strips accented Latin and other scripts even though the
    # candidate embedders are multilingual -- confirm that is intended.
    text = re.sub(r"[^0-9A-Za-z\u0900-\u097F\s']", " ", text)
    return re.sub(r"\s+", " ", text).strip()

def clean_heavy(t: str) -> str:
    """``clean_light`` followed by lower-casing and removal of standalone "rt" tokens."""
    lowered = clean_light(t).lower()
    without_rt = re.sub(r"\b(rt)\b", " ", lowered)
    return re.sub(r"\s+", " ", without_rt).strip()

def clean_nohashtag(t: str) -> str:
    """Like ``clean_light`` but hashtags are dropped entirely instead of segmented."""
    text = str(t)
    # Links, @mentions and whole #hashtags are all blanked out.
    for noise in (r"http\S+|www\.\S+", r"@\w+", r"#\w+"):
        text = re.sub(noise, " ", text)
    text = emoji.replace_emoji(text, replace=" ")
    text = re.sub(r"(.)\1{2,}", r"\1\1", text)
    text = re.sub(r"[^0-9A-Za-z\u0900-\u097F\s']", " ", text)
    return re.sub(r"\s+", " ", text).strip()

# TTA variant name -> cleaning function (names referenced by TTA_VARIANTS).
CLEANERS = dict(light=clean_light, heavy=clean_heavy, nohashtag=clean_nohashtag)


In [None]:
# Convert to DataFrame

def ds_to_df(d):
    """Convert a dataset split (anything exposing ``to_pandas()``) to a DataFrame."""
    to_pandas = d.to_pandas
    return to_pandas()

df_train = ds_to_df(train_ds)
df_test = ds_to_df(test_ds)
df_val = None if val_ds is None else ds_to_df(val_ds)

is_multi_label = label_field == "labels"
if is_multi_label:
    def indices_from_multihot(vec):
        """Return the positions whose multi-hot entry equals 1."""
        return [pos for pos, flag in enumerate(vec) if int(flag) == 1]

    for frame in (df_train, df_test, df_val):
        if frame is not None:
            frame["labels_idx"] = frame[label_field].apply(indices_from_multihot)
elif "label_name" not in df_train.columns and label_names is not None:
    # Single-label split without a readable name column: derive it from ClassLabel names.
    for frame in (df_train, df_test, df_val):
        if frame is not None:
            frame["label_name"] = frame["label"].apply(lambda i: label_names[i])

print("Train size:", len(df_train), "Test size:", len(df_test), "Val size:", (0 if df_val is None else len(df_val)))


In [None]:
# Embedding utilities

def get_embedder(name: str):
    """Load and return the SentenceTransformer model identified by ``name``."""
    embedder = SentenceTransformer(name)
    return embedder

def encode_texts(model, texts: List[str], tta_variants=None, batch_size=64, show_progress=True) -> np.ndarray:
    """Embed ``texts`` once per TTA cleaning variant and return the element-wise mean.

    Variant names are looked up in CLEANERS; unknown names fall back to
    ``clean_light``, and an empty/None list means ``["light"]``. Each pass asks the
    model for L2-normalised embeddings; the averaged result is not re-normalised.
    """
    variants = tta_variants if tta_variants else ["light"]
    total = None
    for name in variants:
        clean = CLEANERS.get(name, clean_light)
        batch = model.encode(
            list(map(clean, texts)),
            batch_size=batch_size,
            show_progress_bar=show_progress,
            convert_to_numpy=True,
            normalize_embeddings=True,
        )
        total = batch if total is None else total + batch
    return total / float(len(variants))

def single_label_vectors(df, model, tta=True):
    """Embed ``df[text_field]`` and return ``(X, y)`` with string class labels.

    TTA variants are used only when both ``tta`` and ENABLE_TTA are true. Labels come
    from the ``label_name`` column if present; otherwise integer ``label`` values are
    mapped through ``label_names`` when available.
    """
    variants = TTA_VARIANTS if (tta and ENABLE_TTA) else ["light"]
    X = encode_texts(model, df[text_field].astype(str).tolist(), tta_variants=variants,
                     batch_size=BATCH_SIZE, show_progress=SHOW_PROGRESS)
    if "label_name" in df.columns:
        return X, np.array(df["label_name"].tolist())
    y = np.array(df["label"].tolist())
    if not isinstance(y[0], str) and label_names is not None:
        y = np.array([label_names[i] for i in y])
    return X, y


In [None]:
# Optional label descriptions (for prompt blending)
DEFAULT_LABEL_PROMPTS = {
    "arts & culture": "Arts culture literature museums theater painting design creativity art exhibitions books festivals",
    "business & entrepreneurs": "Business startups markets finance economy entrepreneurship investing commerce companies profits",
    "pop culture": "Celebrities movies TV music fandom memes entertainment viral trends pop culture gossip",
    "daily life": "Everyday life personal updates family friends work routine feelings opinions misc general chatter",
    "sports & gaming": "Sports cricket football basketball olympics esports gaming matches tournaments teams players",
    "science & technology": "Technology science research engineering gadgets software AI data innovation programming internet"
}
# Dataset label names may be spelled with underscores (e.g. "pop_culture") rather than
# spaces; register underscore aliases so DEFAULT_LABEL_PROMPTS.get(label) matches either
# spelling instead of silently falling back to the bare label.
DEFAULT_LABEL_PROMPTS.update(
    {key.replace(" ", "_"): text for key, text in list(DEFAULT_LABEL_PROMPTS.items())}
)


In [None]:
# Prototypes & utilities
from scipy.optimize import minimize_scalar


def trimmed_centroid(X: np.ndarray, frac=0.10):
    """Unit-norm centroid of the rows of X after dropping the ``frac`` farthest by cosine.

    Distances are measured to the untrimmed mean; at least one row is always kept.
    ``frac <= 0`` or ``frac >= 0.5`` disables trimming (plain normalised mean).

    Returns a zero vector when X has no rows (e.g. a class absent from a bootstrap
    resample). sklearn's ``normalize`` leaves all-zero rows at zero, so downstream
    averaging/normalisation is not poisoned by NaNs.
    """
    X = np.asarray(X)
    if X.shape[0] == 0:
        return np.zeros(X.shape[1], dtype=float)
    mu = X.mean(axis=0, keepdims=True)
    if frac <= 0 or frac >= 0.5:
        return normalize(mu)[0]
    d = 1 - (normalize(X) @ normalize(mu).T).ravel()  # cosine distance to the mean
    keep = max(1, math.ceil((1 - frac) * len(X)))
    idx = np.argsort(d)[:keep]
    return normalize(X[idx].mean(axis=0, keepdims=True))[0]


def compute_centroids_trimmed(X: np.ndarray, y: np.ndarray, labels: List[str], trim_frac=0.1) -> np.ndarray:
    """Return one L2-normalised trimmed centroid per label, stacked in ``labels`` order."""
    rows = [trimmed_centroid(X[y == lab], frac=trim_frac) for lab in labels]
    return normalize(np.vstack(rows))


def compute_mahalanobis_params(X: np.ndarray, y: np.ndarray, labels: List[str]):
    """Per-class means and shrunk precision matrices for Mahalanobis scoring.

    Classes with more than two samples get a Ledoit-Wolf precision estimate; smaller
    classes fall back to the identity (i.e. squared Euclidean distance).

    Returns:
        (means, precisions): ``means`` stacked row-per-label; ``precisions`` a list with
        one square matrix per label, both in ``labels`` order.
    """
    class_means = []
    class_precisions = []
    for lab in labels:
        members = X[y == lab]
        class_means.append(members.mean(axis=0))
        if members.shape[0] <= 2:
            class_precisions.append(np.eye(X.shape[1]))
        else:
            class_precisions.append(LedoitWolf().fit(members).precision_)
    return np.vstack(class_means), class_precisions


def cosine_logits(X, centroids):
    """Dot product of every row of X with every centroid (cosine if both are unit-norm)."""
    return np.matmul(X, centroids.T)


def mahalanobis_logits(X, means, precisions):
    """Negative squared Mahalanobis distance from every row of X to every class mean.

    Column k uses ``means[k]`` and ``precisions[k]``; larger (less negative) is closer.
    """
    n_classes = means.shape[0]
    scores = np.zeros((X.shape[0], n_classes))
    for k in range(n_classes):
        centred = X - means[k]
        scores[:, k] = -((centred @ precisions[k]) * centred).sum(axis=1)
    return scores


def temperature_scale_logits(logits: np.ndarray, T: float) -> np.ndarray:
    """Divide logits by temperature ``T``, clamped below at 1e-6 to avoid dividing by zero."""
    denom = max(T, 1e-6)
    return logits / denom


def softmax(L):
    """Numerically stable softmax over axis 1 (each row sums to 1)."""
    shifted = L - np.max(L, axis=1, keepdims=True)
    weights = np.exp(shifted)
    return weights / np.sum(weights, axis=1, keepdims=True)


In [None]:
# Training for one embedder

def _fit_temperature(logits: np.ndarray, y_idx: np.ndarray, t_min: float = 1e-3, t_max: float = 1e4) -> float:
    """Return the temperature T minimising mean NLL of ``y_idx`` under softmax(logits / T).

    The search runs over log(T) with a wide bracket. Cosine logits are bounded by 1 in
    magnitude and may need T well below 1, while negative squared Mahalanobis distances
    can be large and may need T well above 1; a narrow bracket such as (0.2, 5) can clip
    the optimum in either case. NLL is unimodal in log(T), so the bounded search is reliable.
    """
    rows = np.arange(len(y_idx))

    def nll(log_t):
        P = softmax(temperature_scale_logits(logits, math.exp(log_t)))
        return -np.log(P[rows, y_idx] + 1e-9).mean()

    res = minimize_scalar(nll, bounds=(math.log(t_min), math.log(t_max)), method="bounded")
    return float(math.exp(res.x))


def _prompt_blended_centroids(model, Xtr, ytr, labels, blend):
    """Trimmed (optionally bootstrap-bagged) class centroids mixed with label-prompt embeddings.

    Each row is normalize((1 - blend) * data_centroid + blend * prompt_embedding). The
    prompt text comes from DEFAULT_LABEL_PROMPTS, falling back to the label itself.
    """
    if BOOTSTRAP_CENTROIDS > 1:
        bags = []
        for b in range(BOOTSTRAP_CENTROIDS):
            # Fixed per-bag seeds keep the bagged centroids reproducible.
            idx = np.random.RandomState(42 + b).choice(len(Xtr), len(Xtr), replace=True)
            bags.append(compute_centroids_trimmed(Xtr[idx], ytr[idx], labels, trim_frac=TRIM_FRAC))
        centroids = normalize(np.mean(np.stack(bags, axis=0), axis=0))
    else:
        centroids = compute_centroids_trimmed(Xtr, ytr, labels, trim_frac=TRIM_FRAC)
    prompt_texts = [DEFAULT_LABEL_PROMPTS.get(l, l) for l in labels]
    P_emb = normalize(encode_texts(model, prompt_texts, tta_variants=["heavy"], batch_size=64, show_progress=False))
    return normalize((1 - blend) * centroids + blend * P_emb)


def _per_class_thresholds(P, y_true, labels, target_precision):
    """Per-label confidence thresholds that reach ``target_precision`` on validation.

    For each label, validation rows predicted as that label are sorted by descending top
    probability. The threshold is the lowest such probability whose running precision
    still meets the target; it is the maximum probability if the target is never met,
    and 0.0 if the label is never predicted.
    """
    thresholds = {}
    pred_lab = np.array([labels[i] for i in np.argmax(P, axis=1)])
    conf = P.max(axis=1)
    for lab in labels:
        mask = pred_lab == lab
        if not mask.any():
            thresholds[lab] = 0.0
            continue
        order = np.argsort(-conf[mask])
        probs = conf[mask][order]
        correct = (y_true[mask][order] == lab).astype(int)
        tp = np.cumsum(correct)
        fp = np.cumsum(1 - correct)
        precision = tp / np.maximum(tp + fp, 1)
        meet = np.flatnonzero(precision >= target_precision)
        thresholds[lab] = float(probs[meet[-1]] if len(meet) else probs.max())
    return thresholds


def _apply_open_set(P, labels, thresholds, reject_label="Other"):
    """Return argmax labels, replaced by ``reject_label`` where the top probability is below threshold.

    Work in an object array so the reject label is never truncated to the width of the
    longest class name (a fixed-width numpy string array would truncate it).
    """
    pred = np.array([labels[i] for i in np.argmax(P, axis=1)], dtype=object)
    conf = P.max(axis=1)
    limits = np.array([thresholds.get(lab, 0.0) for lab in pred], dtype=float)
    pred[conf < limits] = reject_label
    return pred.astype(str)


def run_for_embedder(embedder_name: str, save_prefix: str = "run"):
    """Train, evaluate and save the full ensemble for one sentence-embedding model.

    Steps:
      1. Encode train/val/test with TTA. Without a val split, 15% of train is held out.
      2. Build a prototype scorer: Mahalanobis, or prompt-blended cosine centroids.
         Fit its softmax temperature on validation NLL.
      3. Fit a calibrated logistic regression and build a cosine kNN voter.
      4. Grid-search the (centroid, LR, kNN) weights for validation macro-F1.
      5. Fit per-class open-set thresholds on validation. They are applied to test
         only when ENABLE_OPEN_SET is true.
      6. Report test metrics and write artifacts under ARTIFACT_DIR/save_prefix.

    Args:
        embedder_name: SentenceTransformer model id passed to ``get_embedder``.
        save_prefix: sub-directory of ARTIFACT_DIR that receives the artifacts.

    Returns:
        dict with ``embedder``, ``acc``, ``f1_macro``, ``f1_micro`` (test metrics) and ``run_dir``.
    """
    print(f"\n=== Embedder: {embedder_name} ===")
    model = get_embedder(embedder_name)

    Xtr, ytr = single_label_vectors(df_train, model, tta=True)
    Xte, yte = single_label_vectors(df_test,  model, tta=True)
    if df_val is not None:
        Xva, yva = single_label_vectors(df_val, model, tta=True)
    else:
        # No validation split: hold out 15% of train (stratified) for tuning.
        Xtr, Xva, ytr, yva = train_test_split(Xtr, ytr, test_size=0.15, stratify=ytr, random_state=42)

    labels = sorted(set(ytr))
    lab_to_idx = {l: i for i, l in enumerate(labels)}
    prompt_blend = 0.2  # weight of the label-prompt embedding in each cosine centroid

    # Prototype scorer. Cosine centroids (and prompt embeddings) are only built when used.
    if USE_MAHALANOBIS:
        centroids = None
        means, precisions = compute_mahalanobis_params(Xtr, ytr, labels)
        logits_va = mahalanobis_logits(Xva, means, precisions)
        logits_te = mahalanobis_logits(Xte, means, precisions)
    else:
        means, precisions = None, None
        centroids = _prompt_blended_centroids(model, Xtr, ytr, labels, prompt_blend)
        logits_va = cosine_logits(Xva, centroids)
        logits_te = cosine_logits(Xte, centroids)

    # Single softmax temperature fitted on validation; saved so inference can reuse it.
    yva_idx = np.array([lab_to_idx[v] for v in yva])
    T_cent = _fit_temperature(logits_va, yva_idx)
    P_va_cent = softmax(temperature_scale_logits(logits_va, T_cent))
    P_te_cent = softmax(temperature_scale_logits(logits_te, T_cent))

    # Logistic Regression + calibration
    lr = LogisticRegression(max_iter=LR_MAX_ITER, solver="saga", C=LR_C, penalty="l2",
                            class_weight=LR_CLASS_WEIGHT, random_state=42)
    cal = CalibratedClassifierCV(lr, method=CALIBRATION_METHOD, cv=5)
    cal.fit(Xtr, ytr)
    lr_classes = list(cal.classes_)
    cols = [lr_classes.index(l) for l in labels]  # reorder columns to `labels` order
    P_va_lr = cal.predict_proba(Xva)[:, cols]
    P_te_lr = cal.predict_proba(Xte)[:, cols]

    # kNN (cosine), similarity-weighted votes
    nbrs = NearestNeighbors(n_neighbors=min(KNN_K, len(Xtr)), metric="cosine")
    nbrs.fit(Xtr)
    dist_va, idx_va = nbrs.kneighbors(Xva, return_distance=True)
    dist_te, idx_te = nbrs.kneighbors(Xte, return_distance=True)

    def knn_probs(idx_mat, dist_mat):
        sim = 1.0 - dist_mat
        probs = np.zeros((idx_mat.shape[0], len(labels)), dtype=float)
        for i in range(idx_mat.shape[0]):
            weights = sim[i] / (sim[i].sum() + 1e-9)
            for lab, w in zip(ytr[idx_mat[i]], weights):
                probs[i, lab_to_idx[lab]] += w
        return probs

    P_va_knn = knn_probs(idx_va, dist_va)
    P_te_knn = knn_probs(idx_te, dist_te)

    # Ensemble search: weights (a, b, c) with c = 1 - a - b, scored by validation macro-F1.
    best_f1, ens_weights = -1.0, (1.0, 0.0, 0.0)
    for a in ENSEMBLE_GRID:
        for b in ENSEMBLE_GRID:
            c = 1.0 - a - b
            if c < 0 or c > 1:
                continue
            P_va = a * P_va_cent + b * P_va_lr + c * P_va_knn
            pred = [labels[i] for i in np.argmax(P_va, axis=1)]
            f1m = f1_score(yva, pred, average="macro")
            if f1m > best_f1:
                best_f1, ens_weights = f1m, (a, b, c)

    a, b, c = ens_weights
    P_te = a * P_te_cent + b * P_te_lr + c * P_te_knn
    P_va = a * P_va_cent + b * P_va_lr + c * P_va_knn

    # Domain priors (optional)
    if DOMAIN_PRIORS:
        prior_vec = np.array([DOMAIN_PRIORS.get(l, 1.0) for l in labels], dtype=float)
        prior_vec = prior_vec / prior_vec.sum()
        P_va = (P_va * prior_vec) / (P_va * prior_vec).sum(axis=1, keepdims=True)
        P_te = (P_te * prior_vec) / (P_te * prior_vec).sum(axis=1, keepdims=True)

    # Per-class thresholds are always fitted and saved; rejection only if open-set is on.
    thresholds = _per_class_thresholds(P_va, yva, labels, TARGET_MIN_PRECISION)
    pred_te = _apply_open_set(P_te, labels, thresholds if ENABLE_OPEN_SET else {})

    metric_labels = [l for l in labels if l != "Other"]
    acc = accuracy_score(yte, pred_te)
    macro = f1_score(yte, pred_te, average="macro", labels=metric_labels)
    micro = f1_score(yte, pred_te, average="micro", labels=metric_labels)

    print(f"Ensemble weights (Centroid, LR, kNN): {ens_weights}")
    print(f"Accuracy: {acc:.4f} | F1-macro: {macro:.4f} | F1-micro: {micro:.4f}")
    print(classification_report(yte, pred_te, digits=4))

    # Save artifacts
    run_dir = os.path.join(ARTIFACT_DIR, save_prefix)
    os.makedirs(run_dir, exist_ok=True)
    meta = {
        "embedder": embedder_name,
        "labels": labels,
        "ensemble_weights": ens_weights,
        "use_mahalanobis": bool(USE_MAHALANOBIS),
        "trim_frac": TRIM_FRAC,
        "bootstrap_centroids": BOOTSTRAP_CENTROIDS,
        "tta": bool(ENABLE_TTA),
        "tta_variants": TTA_VARIANTS,
        "calibration_method": CALIBRATION_METHOD,
        "open_set": bool(ENABLE_OPEN_SET),
        "open_set_strategy": OPEN_SET_STRATEGY,
        "per_class_thresholds": thresholds,
        "target_min_precision": TARGET_MIN_PRECISION,
        "label_prompt_blend": prompt_blend,
        "domain_priors": DOMAIN_PRIORS,
        "temperature": T_cent,   # softmax temperature for the prototype scorer
        "knn_k": KNN_K,
    }
    with open(os.path.join(run_dir, "meta.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)

    if USE_MAHALANOBIS:
        # One float array (classes, dim, dim): loads back without pickle and keeps
        # matmuls native (an object-dtype array would make them run on Python floats).
        np.savez(os.path.join(run_dir, "prototypes_mahalanobis.npz"), means=means, precisions=np.stack(precisions))
    else:
        np.savez(os.path.join(run_dir, "prototypes_cosine.npz"), centroids=centroids)

    try:
        import joblib
        joblib.dump(cal, os.path.join(run_dir, "calibrated_lr.joblib"))
    except Exception as e:
        print("Note: could not save calibrated LR:", e)

    np.save(os.path.join(run_dir, "Xtr.npy"), Xtr)
    np.save(os.path.join(run_dir, "ytr.npy"), ytr)
    np.save(os.path.join(run_dir, "P_val.npy"), P_va)
    np.save(os.path.join(run_dir, "y_val.npy"), yva)

    return {"embedder": embedder_name, "acc": acc, "f1_macro": macro, "f1_micro": micro, "run_dir": run_dir}


In [None]:
# Run across embedders and pick the best
# NOTE(review): runs are ranked on test-set F1-macro/accuracy, so the winner's reported
# test scores are optimistically biased -- consider ranking on validation metrics instead.
results = [
    run_for_embedder(emb, save_prefix=f"run_{i}")
    for i, emb in enumerate(EMBEDDER_CANDIDATES)
]

res_df = pd.DataFrame(results).sort_values(["f1_macro", "acc"], ascending=[False, False])
print(res_df)
best_dir = res_df.iloc[0]["run_dir"]
print("Best artifacts at:", best_dir)


In [None]:
# Inference utility (with exemplars)

def predict_texts(texts: List[str], run_dir: str, topk=3, k_exemplar=0):
    """Classify raw texts with the artifacts that ``run_for_embedder`` wrote to ``run_dir``.

    Re-creates the training pipeline:
      - the same TTA cleaning and embedder (via ``encode_texts``);
      - prototype probabilities, scaled by the saved temperature (1.0 if absent);
      - calibrated-LR and cosine-kNN probabilities;
      - the saved ensemble weights and optional domain priors;
      - per-class open-set thresholds, applied only when meta enables open-set.

    If the LR or kNN artifacts are missing, their ensemble weight is spread over the
    remaining components.

    Args:
        texts: raw input strings.
        run_dir: directory containing ``meta.json`` and the other run artifacts.
        topk: number of ranked labels/scores to return per text.
        k_exemplar: if > 0, also return the k most similar training rows per text.

    Returns:
        dict with ``pred`` (label or "Other"), ``topk_labels``, ``topk_scores`` and
        ``exemplars`` (per text, a list of ``(similarity, label)``; or None).
    """
    texts = list(texts)
    if not texts:
        return {"pred": [], "topk_labels": [], "topk_scores": [], "exemplars": [] if k_exemplar else None}

    with open(os.path.join(run_dir, "meta.json"), "r", encoding="utf-8") as f:
        meta = json.load(f)
    labels = meta["labels"]
    lab_to_idx = {l: i for i, l in enumerate(labels)}
    embedder = SentenceTransformer(meta["embedder"])

    # Prototype scorer: follow the mode recorded at training time (file presence as fallback).
    maha_path = os.path.join(run_dir, "prototypes_mahalanobis.npz")
    use_maha = bool(meta.get("use_mahalanobis", os.path.exists(maha_path)))
    if use_maha:
        # allow_pickle keeps runs that stored precisions as an object array loadable;
        # only load artifacts you produced yourself.
        with np.load(maha_path, allow_pickle=True) as dat:
            means = dat["means"]
            # Cast to float so the matmuls below run natively, not on Python objects.
            precisions = [np.asarray(p, dtype=float) for p in dat["precisions"]]
    else:
        with np.load(os.path.join(run_dir, "prototypes_cosine.npz")) as dat:
            centroids = dat["centroids"]

    # Calibrated LR (best-effort: the ensemble still works without it).
    cal = None
    lr_path = os.path.join(run_dir, "calibrated_lr.joblib")
    if os.path.exists(lr_path):
        import joblib
        try:
            # joblib.load unpickles: only load artifacts from a trusted run_dir.
            cal = joblib.load(lr_path)
        except Exception as e:
            print("Note: could not load calibrated LR:", e)

    # Saved training embeddings/labels, used for kNN and exemplars.
    Xtr = ytr = None
    xtr_path = os.path.join(run_dir, "Xtr.npy")
    ytr_path = os.path.join(run_dir, "ytr.npy")
    if os.path.exists(xtr_path) and os.path.exists(ytr_path):
        Xtr = np.load(xtr_path)
        ytr = np.load(ytr_path, allow_pickle=True)

    # Encode exactly as during training (shared cleaners + TTA averaging).
    tta_vars = meta.get("tta_variants", ["light"]) if meta.get("tta", False) else ["light"]
    X = encode_texts(embedder, texts, tta_variants=tta_vars, batch_size=64, show_progress=False)

    # Prototype probabilities with the training-time temperature.
    if use_maha:
        sims = mahalanobis_logits(X, means, precisions)
    else:
        sims = cosine_logits(X, centroids)
    P_cent = softmax(temperature_scale_logits(sims, float(meta.get("temperature", 1.0))))

    # LR probabilities, columns reordered to `labels`.
    P_lr = None
    if cal is not None:
        lr_classes = list(cal.classes_)
        P_lr = cal.predict_proba(X)[:, [lr_classes.index(l) for l in labels]]

    # kNN probabilities: similarity-weighted votes of the nearest training rows.
    P_knn = None
    if Xtr is not None and len(Xtr) > 0:
        k = min(int(meta.get("knn_k", 5)), len(Xtr))
        nbrs = NearestNeighbors(n_neighbors=k, metric="cosine").fit(Xtr)
        dist, nn_idx = nbrs.kneighbors(X, return_distance=True)
        sim = 1.0 - dist
        P_knn = np.zeros((len(X), len(labels)))
        for i in range(len(X)):
            w = sim[i] / (sim[i].sum() + 1e-9)
            for lab, ww in zip(ytr[nn_idx[i]], w):
                P_knn[i, lab_to_idx[lab]] += ww

    # Weighted ensemble.
    a, b, c = meta.get("ensemble_weights", [1.0, 0.0, 0.0])
    components = [(a, P_cent), (b, P_lr), (c, P_knn)]
    available = [(w, P_part) for w, P_part in components if P_part is not None]
    P = sum(w * P_part for w, P_part in available)
    if len(available) < len(components):
        # Thresholds were fitted on full-ensemble probabilities: redistribute the
        # missing components' weight so rows keep the same total mass.
        w_sum = sum(w for w, _ in available)
        P = P / w_sum if w_sum > 0 else P_cent

    # Domain priors
    pri = meta.get("domain_priors", {})
    if pri:
        pv = np.array([pri.get(l, 1.0) for l in labels])
        pv = pv / pv.sum()
        P = (P * pv) / (P * pv).sum(axis=1, keepdims=True)

    thresholds = meta.get("per_class_thresholds") if meta.get("open_set", True) else None

    order = np.argsort(-P, axis=1)[:, :topk]
    topk_labels = [[labels[j] for j in row] for row in order]
    topk_scores = [[float(P[i, j]) for j in row] for i, row in enumerate(order)]
    pred = [labels[row[0]] for row in order]

    if thresholds:
        maxp = P.max(axis=1)
        for i in range(len(pred)):
            if maxp[i] < thresholds.get(pred[i], 0.0):
                pred[i] = "Other"

    # Exemplars: most similar training rows by dot product.
    exemplars = None
    if k_exemplar and Xtr is not None:
        sim_tr = X @ Xtr.T
        exemplars = []
        for i in range(len(texts)):
            top = np.argsort(-sim_tr[i])[:k_exemplar]
            exemplars.append([(float(sim_tr[i, j]), str(ytr[j])) for j in top])

    return {"pred": pred, "topk_labels": topk_labels, "topk_scores": topk_scores, "exemplars": exemplars}


In [None]:
# CSV inference helper

def predict_on_csv(csv_path: str, run_dir: str, text_col="text", sample=None, topk=3, k_exemplar=0, out_path=None):
    """Run ``predict_texts`` over one text column of a CSV and save the predictions.

    Column resolution: an exact match on ``text_col`` is used first. Otherwise a
    case-insensitive match on ``text_col`` is tried, then one on ``"text"``. Missing
    values are dropped, and an optional random ``sample`` (seed 42) is taken.

    Args:
        csv_path: input CSV path.
        run_dir: artifact directory passed to ``predict_texts``.
        text_col: column holding the texts.
        sample: if set and smaller than the number of rows, classify a random sample.
        topk: number of ranked topics/scores per row.
        k_exemplar: if > 0, add an ``exemplars`` column listing similar training labels.
        out_path: output CSV path; defaults to ARTIFACT_DIR/predictions_user_csv_v3.csv.

    Returns:
        The output DataFrame, or None if the file/column is missing or has no texts.
    """
    if not os.path.exists(csv_path):
        print("CSV not found:", csv_path); return None
    dfu = pd.read_csv(csv_path)
    if text_col not in dfu.columns:
        lowered = {}
        for col in dfu.columns:
            lowered.setdefault(str(col).lower(), col)  # first column wins on case clashes
        text_col = lowered.get(str(text_col).lower(), lowered.get("text", text_col))
    if text_col not in dfu.columns:
        print("Column not found. Available:", dfu.columns.tolist()); return None
    dfx = dfu[[text_col]].dropna().rename(columns={text_col: "text"})
    if sample and len(dfx) > sample:
        dfx = dfx.sample(sample, random_state=42)
    if dfx.empty:
        print("No texts found in column:", text_col); return None

    out = predict_texts(dfx["text"].tolist(), run_dir, topk=topk, k_exemplar=k_exemplar)
    df_out = dfx.copy()
    df_out["pred_topic"] = out["pred"]
    df_out["topk_topics"] = ["; ".join(x) for x in out["topk_labels"]]
    df_out["topk_scores"] = [", ".join(f"{s:.3f}" for s in xs) for xs in out["topk_scores"]]
    if out.get("exemplars") is not None:
        df_out["exemplars"] = [
            "; ".join(f"{lab} ({sim:.3f})" for sim, lab in ex) for ex in out["exemplars"]
        ]
    if out_path is None:
        out_path = os.path.join(ARTIFACT_DIR, "predictions_user_csv_v3.csv")
    df_out.to_csv(out_path, index=False)
    print("Saved:", out_path)
    return df_out

# Example usage (uncomment and set your CSV path):
# USER_CSV = "/content/your.csv"
# _ = predict_on_csv(USER_CSV, best_dir, text_col="text", sample=500, topk=3, k_exemplar=3)
