# LLM Reranker for Recommenders

Inference-time steps that improve relevance without destabilizing the system:

- Query rewriting: normalize, expand intent
- Item rewriting: better item text for matching
- Reranking: LLM-style cross scoring over a small candidate set
- Hybrid scoring: combine model score + LLM score
- Structured outputs + strict parsing
- Caching, latency budgets, cost controls
- Fallback logic that preserves availability


## 0) Setup


In [1]:
import json
import re
import time
import zlib
from collections import OrderedDict

import numpy as np
import pandas as pd
from numpy.random import default_rng

# One seeded generator shared by the functions and cells below; because they
# all draw from the same stream, results depend on the order cells are run.
rng = default_rng(7)

def l2_normalize(x, axis=-1, eps=1e-12):
    """Scale ``x`` to unit L2 norm along ``axis``.

    Norms below ``eps`` are clamped to ``eps`` so an all-zero vector comes
    back as zeros instead of triggering a division by zero.
    """
    norms = np.linalg.norm(x, axis=axis, keepdims=True)
    safe_norms = np.maximum(norms, eps)
    return x / safe_norms

def ndcg_at_k(relevances, k=10):
    """Return NDCG@k for a ranked list of graded relevances.

    Args:
        relevances: Relevance of each result, in ranked order (best rank
            first). May be longer than ``k``.
        k: Cut-off rank. ``None`` means "use the whole list".

    Returns:
        DCG of the first ``k`` results divided by the DCG of the best possible
        ordering of *all* supplied relevances, truncated to ``k``. Returns 0.0
        for empty input, non-positive ``k``, or when nothing is relevant.
    """
    if k is not None and k <= 0:
        return 0.0
    all_rel = np.asarray(relevances, dtype=float)
    rel = all_rel[:k]
    if rel.size == 0:
        return 0.0
    discounts = 1.0 / np.log2(np.arange(2, rel.size + 2))
    dcg = (rel * discounts).sum()
    # The ideal ranking must be drawn from the full list: sorting only the
    # top-k slice would ignore relevant items that were ranked below k and
    # overstate the score.
    ideal = np.sort(all_rel)[::-1][:rel.size]
    idcg = (ideal * discounts).sum()
    return float(dcg / idcg) if idcg > 0 else 0.0

def sigmoid(x):
    """Logistic function ``1 / (1 + exp(-x))``; works on scalars and arrays."""
    denom = 1.0 + np.exp(-x)
    return 1.0 / denom

def tokenize(text):
    """Lower-case ``text`` and return its runs of ASCII letters/digits.

    Falsy input (``None``, ``""``) yields an empty list.
    """
    if not text:
        return []
    return re.findall(r"[a-z0-9]+", text.lower())


## 1) Synthetic catalog + baseline ranker


In [2]:
# Catalog with text and latent factors (baseline ranker sees embeddings, LLM sees text)
n_items = 2500  # number of synthetic catalog items
d_latent = 32  # dimensionality of the latent item/user vectors

# Product categories that titles are sampled from.
topics = [
    "wireless mouse", "gaming keyboard", "ergonomic chair", "standing desk",
    "noise cancelling headphones", "usb c hub", "4k monitor", "webcam",
    "notebook", "pen", "coffee machine", "air purifier", "yoga mat", "dumbbells",
]

def make_title():
    """Draw a random ``"<brand> <topic> <attribute>"`` title from the shared RNG.

    The draw order (topic, then brand, then attribute) is kept fixed so the
    seeded stream reproduces the same catalog.
    """
    topic = rng.choice(topics)
    brand = rng.choice(["Aster", "Norda", "Kite", "Vento", "Orion", "Mori", "Delta"])
    attribute = rng.choice(["pro", "mini", "max", "plus", "lite", "ultra"])
    return " ".join((brand, topic, attribute))

# One title per catalog item (consumes the shared RNG stream).
titles = [make_title() for _ in range(n_items)]

# Catalog-side vocabulary noise: each key may be followed by one of its
# alternatives when a description is generated by augment().
synonyms = {
    "wireless": ["cordless", "bt", "bluetooth"],
    "headphones": ["cans", "headset"],
    "standing": ["sit-stand", "adjustable"],
    "monitor": ["display", "screen"],
    "keyboard": ["keys", "mechanical keyboard"],
}

def augment(text):
    """Return ``text`` with random synonym insertions and an optional suffix.

    Every word that has an entry in ``synonyms`` is followed, with probability
    0.35, by one randomly chosen alternative. With probability 0.25 a usage
    phrase is appended at the end.
    """
    pieces = []
    for word in text.split():
        pieces.append(word)
        if word not in synonyms:
            continue
        # Only draw from the RNG for words that actually have synonyms.
        if rng.random() < 0.35:
            pieces.append(rng.choice(synonyms[word]))
    if rng.random() < 0.25:
        pieces.append(rng.choice(["for work", "for gaming", "for home office", "travel friendly"]))
    return " ".join(pieces)

# Noisy description for every title, then the catalog table itself.
descriptions = [augment(t) for t in titles]
items = pd.DataFrame({"item_id": np.arange(n_items), "title": titles, "text": descriptions})

# Latent item vectors used by baseline ranker
# (random unit-norm rows, one per item; generated independently of the text)
V = l2_normalize(rng.normal(size=(n_items, d_latent)).astype(np.float32))

def sample_user_intent():
    """Draw a random unit-norm float32 user vector of length ``d_latent``."""
    raw = rng.normal(size=(d_latent,))
    return l2_normalize(raw.astype(np.float32))

def retrieve_topk(user_vec, k=200):
    """Return the ``k`` catalog items with the highest dot-product score.

    Args:
        user_vec: User vector with the same dimensionality as the rows of ``V``.
        k: Number of items to return. Values larger than the catalog are
            clamped to the catalog size; ``k <= 0`` returns empty arrays.

    Returns:
        ``(item_ids, scores)`` sorted by descending score.
    """
    scores = (V @ user_vec).astype(np.float32)
    # Clamp k: `-0` is `0`, so the unclamped slice `[-k:]` would return the
    # whole catalog for k=0, and argpartition raises when k exceeds the size.
    k = min(int(k), scores.shape[0])
    if k <= 0:
        return np.empty(0, dtype=np.intp), np.empty(0, dtype=np.float32)
    # Partial selection first, then sort only the k survivors.
    idx = np.argpartition(scores, -k)[-k:]
    idx = idx[np.argsort(scores[idx])[::-1]]
    return idx, scores[idx]

# Preview the first catalog rows.
items.head()


Unnamed: 0,item_id,title,text
0,0,Orion dumbbells lite,Orion dumbbells lite
1,1,Orion yoga mat lite,Orion yoga mat lite for gaming
2,2,Norda air purifier pro,Norda air purifier pro
3,3,Norda noise cancelling headphones ultra,Norda noise cancelling headphones ultra
4,4,Aster yoga mat max,Aster yoga mat max


## 2) Mock LLM service with structured outputs, caching, and latency


In [3]:
class LRUCache:
    """Size-bounded key/value cache with least-recently-used eviction.

    Hit and miss counters are maintained by ``get`` so the hit rate can be
    reported through ``stats``.
    """

    def __init__(self, max_size=2048):
        self.max_size = int(max_size)
        self._d = OrderedDict()  # oldest entry first, most recent last
        self.hits = 0
        self.misses = 0

    def get(self, key):
        """Return the value stored for ``key`` or ``None`` on a miss.

        A hit marks the entry as most recently used.
        """
        if key not in self._d:
            self.misses += 1
            return None
        self.hits += 1
        self._d.move_to_end(key)
        return self._d[key]

    def put(self, key, value):
        """Insert or overwrite ``key`` and evict the oldest entry if over capacity."""
        self._d[key] = value
        self._d.move_to_end(key)
        over_capacity = len(self._d) > self.max_size
        if over_capacity:
            self._d.popitem(last=False)

    def stats(self):
        """Return current size, hit/miss counts and hit rate (0.0 before any lookup)."""
        lookups = self.hits + self.misses
        hit_rate = (self.hits / lookups) if lookups else 0.0
        return {
            "size": len(self._d),
            "hits": self.hits,
            "misses": self.misses,
            "hit_rate": hit_rate,
        }

# Module-level cache shared by all LLM relevance lookups in this notebook.
cache = LRUCache(max_size=4096)

def estimate_tokens(text):
    """Rough token estimate: three per word found by ``tokenize``, minimum 1."""
    word_count = len(tokenize(text))
    return max(1, 3 * word_count)

def mock_llm_call(prompt, temperature=0.0):
    """Simulate an LLM scoring call: sleep, maybe fail, else return JSON.

    Args:
        prompt: Prompt text; its length drives the simulated latency and its
            content drives the returned score.
        temperature: Values above 0.2 double the malformed-response rate
            (8% instead of 4%).

    Returns:
        A JSON string with ``relevance``, ``rationale`` and ``policy`` keys,
        or, on a simulated formatting failure, a non-JSON string.
    """
    # latency grows with prompt size
    tok = estimate_tokens(prompt)
    base_ms = 12 + 0.06 * tok
    jitter_ms = rng.normal(0, 3)
    time.sleep(max(0.0, (base_ms + jitter_ms) / 1000.0))

    # occasional formatting failure
    fail_rate = 0.04 if temperature <= 0.2 else 0.08
    if rng.random() < fail_rate:
        return "RESPONSE: score=0.83 rationale=looks good"  # invalid JSON

    # Score derived from a stable checksum of the prompt. The built-in hash()
    # is salted per process for str, so it would give different scores on
    # every interpreter restart even though the RNG is seeded.
    h = (zlib.crc32(prompt.encode("utf-8")) % 10_000) / 10_000.0
    score = 0.55 + 0.35 * h
    out = {
        "relevance": float(min(1.0, max(0.0, score))),
        "rationale": "Matches intent keywords; penalizes vague or mismatched terms.",
        "policy": {"allowed": True, "reason": ""},
    }
    return json.dumps(out)


## 3) Query rewriting: normalize intent, expand synonyms


In [4]:
# Function words dropped before matching.
STOP = {
    "the", "a", "an", "and", "or", "to", "for", "with", "of", "in", "on",
    "at", "is", "are",
}

# Surface form -> canonical term.
# NOTE: tokenize() splits on "-", so "sit-stand" can only match if the caller
# normalizes it before tokenizing.
QUERY_SYNONYMS = dict(
    [
        ("bt", "bluetooth"),
        ("cordless", "wireless"),
        ("cans", "headphones"),
        ("sit-stand", "standing"),
        ("screen", "monitor"),
        ("display", "monitor"),
        ("keys", "keyboard"),
        ("headset", "headphones"),
    ]
)

def rewrite_query(q):
    """Normalize a raw query: lower-case, drop stop words, canonicalize synonyms.

    Steps:
      1. Replace synonym keys that span several tokens (e.g. ``"sit-stand"``)
         while the text is still a string; ``tokenize`` would split them and
         the per-token lookup below could never match them.
      2. Tokenize, remove ``STOP`` words, map single-token synonyms.
      3. Randomly (p=0.5 each) expand "monitor" with "4k" and "headphones"
         with "noise cancelling" when those terms are absent.
      4. De-duplicate while keeping first-seen order.

    Returns:
        The rewritten query as a space-joined string.
    """
    text = (q or "").lower()
    for src, dst in QUERY_SYNONYMS.items():
        if len(tokenize(src)) > 1:
            # Alphanumeric boundaries keep us from rewriting inside a longer word.
            pattern = rf"(?<![a-z0-9]){re.escape(src)}(?![a-z0-9])"
            text = re.sub(pattern, lambda _m: dst, text)

    toks = [QUERY_SYNONYMS.get(t, t) for t in tokenize(text) if t not in STOP]
    if "monitor" in toks and "4k" not in toks and rng.random() < 0.5:
        toks.append("4k")
    if "headphones" in toks and "noise" not in toks and rng.random() < 0.5:
        toks.extend(["noise", "cancelling"])
    # dict preserves insertion order, so this is an order-stable de-dup.
    return " ".join(dict.fromkeys(toks))

# Demo: show a raw query next to its rewritten form.
q_raw = "Need a cordless headset for work calls"
q_rewritten = rewrite_query(q_raw)
q_raw, q_rewritten


('Need a cordless headset for work calls',
 'need wireless headphones work calls')

## 4) Item rewriting


In [5]:
def rewrite_item_text(title, text):
    """Build the compact item text shown to the LLM scorer.

    The description is lower-cased, multi-token synonyms such as
    ``"sit-stand"`` are canonicalized before tokenizing (``tokenize`` splits on
    the hyphen, so a per-token lookup alone can never match them), stop words
    are removed, single-token synonyms are mapped, and the result is capped at
    24 tokens.

    Returns:
        ``"<title>. <normalized tokens>"``.
    """
    lowered = (text or "").lower()
    for src, dst in QUERY_SYNONYMS.items():
        if len(tokenize(src)) > 1:
            # Alphanumeric boundaries keep us from rewriting inside a longer word.
            pattern = rf"(?<![a-z0-9]){re.escape(src)}(?![a-z0-9])"
            lowered = re.sub(pattern, lambda _m: dst, lowered)

    toks = [QUERY_SYNONYMS.get(t, t) for t in tokenize(lowered) if t not in STOP]
    toks = toks[:24]  # compress for cost
    return f"{title}. " + " ".join(toks)

# Precompute the rewritten text once per catalog item, then preview it.
items["rewritten_text"] = [rewrite_item_text(t, x) for t, x in zip(items["title"], items["text"])]
items[["title","text","rewritten_text"]].head(3)


Unnamed: 0,title,text,rewritten_text
0,Orion dumbbells lite,Orion dumbbells lite,Orion dumbbells lite. orion dumbbells lite
1,Orion yoga mat lite,Orion yoga mat lite for gaming,Orion yoga mat lite. orion yoga mat lite gaming
2,Norda air purifier pro,Norda air purifier pro,Norda air purifier pro. norda air purifier pro


## 5) Reranking with structured outputs + strict parsing

Pattern:
1. ranker produces candidates (K=100..500)
2. LLM reranks a smaller set (k=20..80)
3. combine scores (hybrid)
4. strict parser: if failure → fallback


In [6]:
def strict_parse_json(s):
    """Parse and validate one LLM scoring response.

    Args:
        s: Raw response text expected to hold a JSON object.

    Returns:
        The decoded dict, unchanged, once it has passed validation.

    Raises:
        ValueError: For every kind of invalid response: undecodable JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass), a
            non-string payload, a non-object top level, a missing,
            non-numeric, boolean or out-of-range ``relevance``, or a
            non-boolean ``policy.allowed``. Only ``ValueError`` is raised so
            callers need a single except clause to trigger their fallback.
    """
    try:
        obj = json.loads(s)
    except TypeError as exc:
        # e.g. s is None; surface it like any other malformed response.
        raise ValueError("response is not JSON text") from exc
    if not isinstance(obj, dict):
        raise ValueError("not a dict")
    if "relevance" not in obj:
        raise ValueError("missing relevance")

    raw_rel = obj["relevance"]
    if isinstance(raw_rel, bool):
        # float(True) == 1.0 would otherwise pass the range check.
        raise ValueError("relevance not numeric")
    try:
        rel = float(raw_rel)
    except (TypeError, ValueError) as exc:
        # null, lists and objects raise TypeError in float(); normalize it.
        raise ValueError("relevance not numeric") from exc
    # NaN fails both comparisons and is rejected here as well.
    if not (0.0 <= rel <= 1.0):
        raise ValueError("relevance out of range")

    policy = obj.get("policy")
    if isinstance(policy, dict) and "allowed" in policy:
        if not isinstance(policy["allowed"], bool):
            raise ValueError("policy.allowed not bool")
    return obj

def llm_relevance(query, item_text, latency_budget_ms=80, use_cache=True):
    """Score one (query, item) pair with the mock LLM, using the shared cache.

    Returns:
        ``(score, latency_ms, from_cache)``. Cache hits report a latency of
        0.0 and ``from_cache=True``.

    Raises:
        TimeoutError: The call took longer than ``latency_budget_ms``; the
            response is discarded and nothing is cached.
        ValueError: The response failed ``strict_parse_json``.
    """
    cache_key = (query, item_text)
    hit = cache.get(cache_key) if use_cache else None
    if hit is not None:
        return hit, 0.0, True

    started = time.perf_counter()
    prompt_parts = [
        "You are a relevance scorer for a recommender reranker. ",
        "Return JSON only with keys: relevance (0..1), rationale (string), policy.allowed (bool).\n",
        f"Query: {query}\n",
        f"Item: {item_text}\n",
    ]
    raw = mock_llm_call("".join(prompt_parts), temperature=0.0)
    latency_ms = (time.perf_counter() - started) * 1000.0

    # The budget is checked before parsing, so a late reply is a timeout even
    # if it is also malformed.
    if latency_ms > latency_budget_ms:
        raise TimeoutError(f"LLM budget exceeded: {latency_ms:.1f}ms > {latency_budget_ms}ms")

    score = float(strict_parse_json(raw)["relevance"])
    if use_cache:
        cache.put(cache_key, score)
    return score, latency_ms, False

def rerank_with_llm(query, cand_ids, base_scores, top_m=50, alpha=0.65, budget_ms=120, partial_ok=False):
    """Rerank the first ``top_m`` candidates with a hybrid base + LLM score.

    Args:
        query: Rewritten query text passed to the LLM scorer.
        cand_ids: Candidate item ids, best baseline score first.
        base_scores: Baseline scores aligned with ``cand_ids``.
        top_m: How many leading candidates to rescore.
        alpha: Weight of the min-max normalized baseline score; the LLM score
            gets ``1 - alpha``.
        budget_ms: Request-level wall-clock budget for the scoring loop.
        partial_ok: When False (default) any per-item failure or a blown
            request budget aborts the rerank by raising, as before. When
            True the rerank degrades instead: failed items and items left
            unscored once the budget runs out receive the mean LLM score of
            the items that were scored (0.0 if none were), so they are
            ordered by baseline score relative to each other.

    Returns:
        ``(ranked_ids, hybrid_scores, diagnostics)`` sorted by descending
        hybrid score.

    Raises:
        TimeoutError: Request or per-call budget exceeded (``partial_ok=False``).
        ValueError: Malformed LLM response (``partial_ok=False``), or
            ``cand_ids`` / ``base_scores`` of different lengths.
    """
    t0 = time.perf_counter()

    take = np.asarray(cand_ids)[:top_m]
    bs = np.asarray(base_scores)[:top_m]
    if len(take) != len(bs):
        raise ValueError("cand_ids and base_scores must have the same length")
    if len(take) == 0:
        # Nothing to rerank; min()/max() below would fail on an empty array.
        return take, np.zeros(0, dtype=np.float32), {
            "request_ms": (time.perf_counter() - t0) * 1000.0,
            "llm_calls": 0,
            "avg_llm_ms": 0.0,
            "cache_hit_rate_in_request": 0.0,
            "scored_items": 0,
            "item_errors": 0,
            "budget_exhausted": False,
        }

    llm_scores = np.zeros(len(take), dtype=np.float32)
    scored = np.zeros(len(take), dtype=bool)
    latencies = []
    cached_flags = []
    item_errors = 0
    budget_exhausted = False

    for i, item_id in enumerate(take):
        elapsed = (time.perf_counter() - t0) * 1000.0
        if elapsed > budget_ms:
            if not partial_ok:
                raise TimeoutError(f"Request budget exceeded during scoring: {elapsed:.1f}ms")
            budget_exhausted = True
            break
        try:
            s, lat, cached = llm_relevance(query, items.loc[item_id, "rewritten_text"], latency_budget_ms=80, use_cache=True)
        except (TimeoutError, ValueError):
            if not partial_ok:
                raise
            item_errors += 1
            continue
        llm_scores[i] = s
        scored[i] = True
        latencies.append(lat)
        cached_flags.append(cached)

    if not scored.all():
        # Only reachable with partial_ok=True. A neutral fill keeps unscored
        # items from being penalized as if the LLM had rated them 0.
        llm_scores[~scored] = float(llm_scores[scored].mean()) if scored.any() else 0.0

    bs_norm = (bs - bs.min()) / (bs.max() - bs.min() + 1e-9)
    hybrid = alpha * bs_norm + (1 - alpha) * llm_scores

    order = np.argsort(hybrid)[::-1]
    ranked = take[order]

    # Latency and call counts cover real LLM calls only; cache hits report 0 ms
    # and would otherwise make the mean NaN when every item is cached.
    live_latencies = [lat for lat, hit in zip(latencies, cached_flags) if not hit]
    diagnostics = {
        "request_ms": (time.perf_counter() - t0) * 1000.0,
        "llm_calls": len(live_latencies) + item_errors,
        "avg_llm_ms": float(np.mean(live_latencies)) if live_latencies else 0.0,
        "cache_hit_rate_in_request": float(np.mean(cached_flags)) if cached_flags else 0.0,
        "scored_items": int(scored.sum()),
        "item_errors": item_errors,
        "budget_exhausted": budget_exhausted,
    }
    return ranked, hybrid[order], diagnostics


## 6) Fallback logic


In [7]:
def serve_request(raw_query, user_vec, k_final=10):
    """Serve one request: retrieve, try the LLM rerank, fall back to baseline.

    Returns:
        Dict with ``mode`` (``"hybrid_rerank"`` or ``"fallback_baseline"``),
        the raw and rewritten query, the final item ids and diagnostics. On
        fallback the diagnostics hold the error message and ``request_ms`` is
        ``None``.
    """
    cand_ids, cand_scores = retrieve_topk(user_vec, k=140)
    q = rewrite_query(raw_query)

    try:
        ranked, _hybrid_scores, diag = rerank_with_llm(
            q, cand_ids, cand_scores, top_m=60, alpha=0.7, budget_ms=140
        )
    except (TimeoutError, json.JSONDecodeError, ValueError) as e:
        # Any scoring failure degrades to the untouched baseline order.
        mode = "fallback_baseline"
        final = cand_ids[:k_final]
        diag = {"error": str(e), "request_ms": None}
    else:
        mode = "hybrid_rerank"
        final = ranked[:k_final]

    response = {
        "mode": mode,
        "query_raw": raw_query,
        "query_used": q,
        "final_item_ids": final,
        "diagnostics": diag,
    }
    return response

# Demo: serve one request for a random user and show mode, diagnostics and top titles.
resp = serve_request("cordless headset for zoom calls", sample_user_intent(), k_final=10)
resp["mode"], resp["diagnostics"], items.loc[resp["final_item_ids"], ["title"]].head()


('fallback_baseline',
 {'error': 'Expecting value: line 1 column 1 (char 0)', 'request_ms': None},
                            title
 1895            Kite webcam plus
 2176  Norda ergonomic chair plus
 2460        Delta usb c hub mini
 1130           Aster webcam mini
 198        Mori air purifier pro)

## 7) Offline evaluation: does reranking help?


In [8]:
# Ground-truth intents for simulated requests; each one is also a catalog topic.
INTENTS = [
    "wireless mouse", "noise cancelling headphones", "standing desk", "4k monitor",
    "usb c hub", "webcam", "ergonomic chair", "gaming keyboard"
]

def sample_request():
    """Simulate one request.

    Returns:
        ``(noisy_query, true_intent, user_vec)``: the true intent with
        "wireless"/"headphones" possibly swapped for an alias and " for work"
        appended with probability 0.4, the intent itself, and a random user
        vector.
    """
    intent = rng.choice(INTENTS)
    # Both aliases are always drawn, even when the intent contains neither
    # word, so the RNG stream advances identically for every intent.
    wireless_alias = rng.choice(["wireless","cordless","bt"])
    headphones_alias = rng.choice(["headphones","cans","headset"])
    noisy = intent.replace("wireless", wireless_alias).replace("headphones", headphones_alias)
    suffix = " for work" if rng.random() < 0.4 else ""
    noisy += suffix
    user_vec = sample_user_intent()
    return noisy, intent, user_vec

def click_probability(intent, item_text, baseline_score_):
    """Simulated click probability for an item given the true intent.

    Combines the fraction of intent tokens present in the item text with the
    baseline score, then squashes the result through a sigmoid.
    """
    intent_tokens = set(tokenize(intent))
    item_tokens = set(tokenize(item_text))
    shared = intent_tokens.intersection(item_tokens)
    overlap = len(shared) / max(1, len(intent_tokens))
    z = 1.7 * overlap + 0.25 * baseline_score_
    logit = 3.0 * (z - 0.6)
    return float(sigmoid(logit))

def evaluate(n=500):
    """Compare baseline and served top-10 NDCG over ``n`` simulated requests.

    Returns:
        DataFrame with one row per request: ``baseline_ndcg@10``,
        ``hybrid_ndcg@10`` and the serving ``mode``.
    """

    def binary_labels(intent, item_ids, scores):
        # An item counts as relevant when its simulated click probability > 0.5.
        return [
            1.0 if click_probability(intent, items.loc[item_id, "rewritten_text"], float(score)) > 0.5 else 0.0
            for item_id, score in zip(item_ids, scores)
        ]

    rows = []
    for _ in range(n):
        q_raw, true_intent, uvec = sample_request()
        cand_ids, cand_scores = retrieve_topk(uvec, k=140)

        # Candidate ids are unique, so the score of the i-th id is simply the
        # i-th score; no lookup by id is needed.
        base_rel = binary_labels(true_intent, cand_ids[:10], cand_scores[:10])

        resp = serve_request(q_raw, uvec, k_final=10)
        served_ids = resp["final_item_ids"]
        hyb_rel = binary_labels(true_intent, served_ids, (V[item_id] @ uvec for item_id in served_ids))

        rows.append({
            "baseline_ndcg@10": ndcg_at_k(base_rel, k=10),
            "hybrid_ndcg@10": ndcg_at_k(hyb_rel, k=10),
            "mode": resp["mode"],
        })
    return pd.DataFrame(rows)

# Run the offline evaluation on 700 simulated requests and summarize the metrics.
df_eval = evaluate(n=700)
df_eval.describe()


Unnamed: 0,baseline_ndcg@10,hybrid_ndcg@10
count,700.0,700.0
mean,0.265505,0.265505
std,0.282183,0.282183
min,0.0,0.0
25%,0.0,0.0
50%,0.30103,0.30103
75%,0.445221,0.445221
max,1.0,1.0


In [9]:
# Per-request NDCG change from serving (hybrid or fallback) versus the baseline.
impr = (df_eval["hybrid_ndcg@10"] - df_eval["baseline_ndcg@10"])
pd.Series({
    "avg_baseline_ndcg@10": df_eval["baseline_ndcg@10"].mean(),
    "avg_hybrid_ndcg@10": df_eval["hybrid_ndcg@10"].mean(),
    "avg_delta": impr.mean(),
    "p50_delta": float(impr.median()),
    "p90_delta": float(impr.quantile(0.9)),
    # Share of requests that fell back to the baseline; at 1.0 the LLM rerank
    # was never applied and every delta above is necessarily 0.
    "fallback_rate": float((df_eval["mode"]=="fallback_baseline").mean()),
    "cache_stats": cache.stats(),
})


Unnamed: 0,0
avg_baseline_ndcg@10,0.265505
avg_hybrid_ndcg@10,0.265505
avg_delta,0.0
p50_delta,0.0
p90_delta,0.0
fallback_rate,1.0
cache_stats,"{'size': 4096, 'hits': 933, 'misses': 4850, 'h..."


## 8) Real-world constraints to carry over

- Only rerank small sets (top 20–100). Anything bigger is a cost/latency trap.
- Strict schema. No free-form interfaces across service boundaries.
- Budgets at two layers: request-level and LLM-call-level.
- A deterministic fallback to the baseline ranking is mandatory.
- Instrument: p95/p99 latency, cache hit rate, parse failure rate, timeout rate, rerank applied rate.
- Roll out behind a flag. Canary + rollback must be cheap.
