# HALT: Hallucination Assessment via Log-probs as Time Series

**Paper**: Shapiro, Taneja & Goel (2026) — arXiv:2602.02888

## Overview

HALT is a lightweight hallucination detector that treats the top-k token log-probabilities produced by an LLM as a **multivariate time series** and classifies each response as hallucinated or faithful using a compact **Bidirectional GRU**.

### Key design decisions
| Component | Detail |
|---|---|
| Input features | Top-20 log-probs + 5 engineered uncertainty statistics = **25-dim** vector per token |
| Input projection | LayerNorm → 2-layer MLP → 128-dim |
| Encoder | BiGRU, hidden=256, layers=5, dropout=0.4 → output 512-dim per step |
| Pooling | **Top-q** (q=0.15): average the 15% of timesteps with largest ℓ₂ norm |
| Classifier | Linear(512 → 1), BCEWithLogitsLoss |
| Optimizer | Adam, lr=4.41e-4, weight_decay=2.34e-6 |

### Tensor shape table
```
raw_logprobs  : (B, T, 20)
eng_features  : (B, T, 5)   [avg_logprob, rank_proxy, entropy_overall, entropy_alts, dec_entropy_delta]
x             : (B, T, 25)  [eng_features ∥ raw_logprobs]
x_proj        : (B, T, 128) after MLP projection
gru_out       : (B, T, 512) bidirectional hidden states
pooled        : (B, 512)    top-q pooling
logit         : (B,)        classifier output, squeezed from (B, 1)
```
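The sketch below walks the shape table stage by stage. It pushes random tensors through stand-ins for each stage, built from plain `torch.nn` modules rather than the notebook's `HALT` class. The inputs are placeholders, not real log-probs. Every sequence has the same length, so no packing or masking is needed and top-q pooling reduces to a plain `topk` over per-step ℓ₂ norms.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, TOP_K, N_ENG = 4, 30, 20, 5          # batch, timesteps, top-k, engineered features
PROJ, HIDDEN, LAYERS, Q = 128, 256, 5, 0.15

raw_logprobs = -torch.rand(B, T, TOP_K) * 10          # (B, T, 20) placeholder log-probs
eng_features = torch.randn(B, T, N_ENG)               # (B, T, 5)  placeholder features
x = torch.cat([eng_features, raw_logprobs], dim=-1)   # (B, T, 25)

projection = nn.Sequential(                           # LayerNorm -> 2-layer GELU MLP
    nn.LayerNorm(TOP_K + N_ENG),
    nn.Linear(TOP_K + N_ENG, PROJ),
    nn.GELU(),
    nn.Linear(PROJ, PROJ),
)
gru = nn.GRU(PROJ, HIDDEN, num_layers=LAYERS, batch_first=True,
             bidirectional=True, dropout=0.4)
classifier = nn.Linear(2 * HIDDEN, 1)

with torch.no_grad():
    x_proj = projection(x)                                    # (B, T, 128)
    gru_out, _ = gru(x_proj)                                  # (B, T, 512)
    k = max(1, math.ceil(Q * T))                              # timesteps kept per sequence
    top_idx = gru_out.norm(dim=-1).topk(k, dim=1).indices     # (B, k)
    top_states = gru_out.gather(1, top_idx.unsqueeze(-1).expand(-1, -1, 2 * HIDDEN))
    pooled = top_states.mean(dim=1)                           # (B, 512)
    logit = classifier(pooled).squeeze(-1)                    # (B,)

for name, t in [("x", x), ("x_proj", x_proj), ("gru_out", gru_out),
                ("pooled", pooled), ("logit", logit)]:
    print(f"{name:8s}: {tuple(t.shape)}")
```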

In [None]:
# ── Environment / installs ────────────────────────────────────────────────────
# Uncomment if running in a fresh environment
# !pip install torch torchvision --quiet

In [None]:
# ── Imports ───────────────────────────────────────────────────────────────────
import math
import random
from dataclasses import dataclass
from typing import List, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from torch.utils.data import Dataset, DataLoader

# Reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

In [None]:
# ── Config ────────────────────────────────────────────────────────────────────
@dataclass
class HALTConfig:
    # Feature dimensions
    top_k: int = 20                  # number of top log-probs per token
    n_eng_features: int = 5          # engineered uncertainty features
    input_dim: int = 25              # must equal top_k + n_eng_features

    # Projection MLP
    proj_dim: int = 128

    # BiGRU encoder
    hidden_dim: int = 256            # per-direction hidden size
    n_gru_layers: int = 5
    gru_dropout: float = 0.4
    bidirectional: bool = True

    # Top-q pooling
    top_q: float = 0.15              # fraction of timesteps to pool

    # Classifier
    out_norm: bool = False           # LayerNorm before linear (disabled in best setting)

    # Training
    batch_size: int = 512
    lr: float = 4.41e-4
    weight_decay: float = 2.34e-6
    max_epochs: int = 100
    early_stop_patience: int = 15
    lr_scheduler_patience: int = 3
    lr_scheduler_factor: float = 0.5
    grad_clip_max_norm: float = 1.0

    @property
    def gru_output_dim(self) -> int:
        """BiGRU concatenates forward + backward: 256 * 2 = 512."""
        return self.hidden_dim * (2 if self.bidirectional else 1)


cfg = HALTConfig()
print(cfg)
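`HALTConfig` stores `input_dim` separately from `top_k` and `n_eng_features`. If you change `top_k` without updating `input_dim`, for example for a provider that returns fewer alternatives, the mismatch only surfaces later as a LayerNorm shape error. A `__post_init__` check is one way to fail early. `FeatureDims` below is a hypothetical stand-in for the three relevant fields, not part of the original notebook.

```python
from dataclasses import dataclass


@dataclass
class FeatureDims:                     # hypothetical helper, not in the original notebook
    top_k: int = 20
    n_eng_features: int = 5
    input_dim: int = 25

    def __post_init__(self):
        # Reject configs whose declared input width disagrees with its parts
        expected = self.top_k + self.n_eng_features
        if self.input_dim != expected:
            raise ValueError(
                f"input_dim={self.input_dim}, but top_k + n_eng_features = {expected}"
            )


ok = FeatureDims()
try:
    FeatureDims(top_k=10)              # input_dim left at 25, but 10 + 5 = 15
    mismatch_caught = False
except ValueError as err:
    mismatch_caught = True
    print(err)
```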

## Feature Engineering

Given the top-k log-probability matrix `logprobs` of shape `(T, k)` for a single response, with column 0 holding the selected (sampled) token, we derive **5 scalar features per timestep**.

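| # | Name | Definition (per timestep *t*) |
|---|---|---|
| 1 | `avg_logprob` | mean of the k top-k log-probs |
| 2 | `rank_proxy` | 1 + number of alternatives whose log-prob exceeds the selected token's |
| 3 | `entropy_overall` | Shannon entropy H(p̃) of the top-k log-probs renormalised with a softmax |
| 4 | `entropy_alts` | H(p̃_alts): entropy of the k - 1 alternatives, renormalised to sum to 1 |
| 5 | `dec_entropy_delta` | change from *t* - 1 to *t* in the binary decision entropy between the selected token and the best alternative (0 at the first token) |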
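Written out per timestep $t$, the definitions implemented in `extract_features()` are given below. The notation is ours, and the paper's equation numbering may differ. $\ell_{t,1}$ is the selected token's log-prob (column 0 in the code) and $\ell_{t,2..k}$ are the alternatives. The small $\varepsilon$ terms the code adds for numerical safety are omitted.

$$
\begin{aligned}
\tilde p_{t,j} &= \frac{\exp(\ell_{t,j} - m_t)}{\sum_{i=1}^{k}\exp(\ell_{t,i} - m_t)}, \quad m_t = \max_{i}\ell_{t,i} \\
\bar\ell_t &= \frac{1}{k}\sum_{j=1}^{k}\ell_{t,j} \\
r_t &= 1 + \sum_{j=2}^{k}\mathbb{1}\!\left[\ell_{t,j} > \ell_{t,1}\right] \\
H_t &= -\sum_{j=1}^{k}\tilde p_{t,j}\log\tilde p_{t,j} \\
H^{\text{alt}}_t &= -\sum_{j=2}^{k} q_{t,j}\log q_{t,j}, \quad q_{t,j} = \frac{\tilde p_{t,j}}{\sum_{i=2}^{k}\tilde p_{t,i}} \\
p^{c}_t &= \frac{e^{\ell_{t,1}}}{e^{\ell_{t,1}} + e^{\ell_{t,b_t}}} = \sigma\!\left(\ell_{t,1} - \ell_{t,b_t}\right), \quad b_t = \arg\max_{j \ge 2}\ell_{t,j} \\
H^{\text{dec}}_t &= -p^{c}_t\log p^{c}_t - \left(1 - p^{c}_t\right)\log\left(1 - p^{c}_t\right) \\
\Delta H^{\text{dec}}_t &= H^{\text{dec}}_t - H^{\text{dec}}_{t-1}, \quad \Delta H^{\text{dec}}_1 = 0
\end{aligned}
$$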

In [None]:
# ── Feature Extraction ────────────────────────────────────────────────────────

def safe_entropy(probs: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    """Shannon entropy along the last dim. Shape: (..., k) → (...)."""
    return -(probs * (probs + eps).log()).sum(dim=-1)


def renorm_top_k(logprobs: torch.Tensor) -> torch.Tensor:
    """
    Numerically stable softmax over top-k log-probs (Eq. 4).
    logprobs: (T, k)  →  p_tilde: (T, k)
    """
    m = logprobs.max(dim=-1, keepdim=True).values          # (T, 1)
    exp_l = (logprobs - m).exp()                            # (T, k)
    return exp_l / exp_l.sum(dim=-1, keepdim=True)          # (T, k)


def extract_features(logprobs: torch.Tensor) -> torch.Tensor:
    """
    Extract 5 engineered uncertainty features for each token.

    Args:
        logprobs: (T, k)  — top-k log-probs; column 0 is the selected token.

    Returns:
        features: (T, 5)  — [avg_logprob, rank_proxy, entropy_overall,
                                  entropy_alts, dec_entropy_delta]
    """
    T, k = logprobs.shape

    # ── 1. Average log-probability (Eq. 5) ───────────────────────────────────
    avg_logprob = logprobs.mean(dim=-1, keepdim=True)              # (T, 1)

    # ── 2. Rank proxy (Eq. 6) ────────────────────────────────────────────────
    selected_lp = logprobs[:, 0:1]                                  # (T, 1)
    alts_lp     = logprobs[:, 1:]                                   # (T, k-1)
    rank_proxy  = 1.0 + (alts_lp > selected_lp).float().sum(dim=-1, keepdim=True)  # (T, 1)

    # ── Renormalised distribution (Eq. 4) ───────────────────────────────────
    p_tilde = renorm_top_k(logprobs)                                # (T, k)

    # ── 3. Overall entropy (Eq. 7) ──────────────────────────────────────────
    entropy_overall = safe_entropy(p_tilde).unsqueeze(-1)           # (T, 1)

    # ── 4. Alternatives-only entropy (Eq. 8-9) ──────────────────────────────
    p_alts       = p_tilde[:, 1:]                                   # (T, k-1)
    p_alts_norm  = p_alts / (p_alts.sum(dim=-1, keepdim=True) + 1e-9)  # (T, k-1)
    entropy_alts = safe_entropy(p_alts_norm).unsqueeze(-1)          # (T, 1)

    # ── 5. Decision entropy delta (Eq. 10-13) ───────────────────────────────
    best_alt_lp  = alts_lp.max(dim=-1).values                      # (T,)
    # Binary probability of selected vs best alternative.
    # exp(a) / (exp(a) + exp(b)) == sigmoid(a - b), which cannot underflow to 0/0
    pc           = torch.sigmoid(selected_lp.squeeze(-1) - best_alt_lp)  # (T,)
    h_dec        = -(pc * (pc + 1e-9).log()
                     + (1 - pc) * (1 - pc + 1e-9).log())            # (T,)
    # Temporal delta (Eq. 13); pad t=0 with 0
    delta_h_dec  = torch.cat([
        torch.zeros(1, dtype=h_dec.dtype, device=logprobs.device),
        h_dec[1:] - h_dec[:-1]
    ]).unsqueeze(-1)                                                  # (T, 1)

    # ── Concatenate all 5 features ───────────────────────────────────────────
    features = torch.cat([
        avg_logprob,       # (T, 1)
        rank_proxy,        # (T, 1)
        entropy_overall,   # (T, 1)
        entropy_alts,      # (T, 1)
        delta_h_dec        # (T, 1)
    ], dim=-1)             # (T, 5)

    return features


def build_input_sequence(logprobs: torch.Tensor) -> torch.Tensor:
    """
    Build the full enriched feature sequence l̃_{1:T} (Section 3.2).

    Args:
        logprobs: (T, k)  — raw top-k log-probabilities

    Returns:
        x_tilde: (T, k+5) — [engineered_features ∥ raw_logprobs]
    """
    eng = extract_features(logprobs)                # (T, 5)
    return torch.cat([eng, logprobs], dim=-1)        # (T, 25)


# ── Quick shape test ─────────────────────────────────────────────────────────
dummy_logprobs = -torch.rand(37, 20) * 10           # T=37 tokens, k=20
dummy_x        = build_input_sequence(dummy_logprobs)
print(f"logprobs shape : {dummy_logprobs.shape}")   # (37, 20)
print(f"x_tilde shape  : {dummy_x.shape}")          # (37, 25)
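The binary decision probability can be written as a ratio of exponentials, but that form underflows in float32 when both log-probs are very negative. The algebraically equivalent form σ(ℓ_sel − ℓ_alt) depends only on the gap between the two log-probs and does not. The check below compares the two forms on made-up inputs. It also confirms a property of `entropy_overall`: a flat top-k distribution has entropy log k.

```python
import math

import torch


def decision_prob_ratio(sel_lp, alt_lp, eps=1e-9):
    # Direct exp ratio: both exps underflow to 0 in float32 below about -104
    return sel_lp.exp() / (sel_lp.exp() + alt_lp.exp() + eps)


def decision_prob_sigmoid(sel_lp, alt_lp):
    # exp(a) / (exp(a) + exp(b)) == sigmoid(a - b): depends only on the gap
    return torch.sigmoid(sel_lp - alt_lp)


sel = torch.tensor([-0.1, -120.0])    # selected-token log-probs
alt = torch.tensor([-2.5, -121.0])    # best-alternative log-probs

ratio = decision_prob_ratio(sel, alt)
stable = decision_prob_sigmoid(sel, alt)
print(ratio)    # second entry collapses to 0
print(stable)   # second entry equals sigmoid(1.0)

# Sanity check for entropy_overall: a flat top-k distribution has entropy log(k)
k = 20
p_flat = torch.softmax(torch.zeros(k), dim=-1)
h_flat = -(p_flat * p_flat.log()).sum().item()
print(h_flat, math.log(k))
```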

## Dataset

The `HALTDataset` accepts a list of `(logprobs_tensor, label)` pairs where:
- `logprobs_tensor` : `(T_i, 20)` — variable-length top-20 log-prob sequence for response *i*
- `label` : `int` — 1 = hallucinated, 0 = faithful

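Each sample is converted to its 25-dim feature sequence by `build_input_sequence()` when it is fetched. The collate function pads sequences to the longest length in the batch. It returns the padded features, the labels, the original lengths (needed for packing), and a boolean mask that is True at valid positions.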
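The padding and mask logic used in `halt_collate_fn` can be checked on three toy sequences of lengths 3, 5 and 2. Random 25-dim features stand in for real ones. `pad_sequence` fills the tail of each shorter sequence with zeros, and comparing a position index against each length gives the validity mask.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

torch.manual_seed(0)
seqs = [torch.randn(3, 25), torch.randn(5, 25), torch.randn(2, 25)]  # variable T_i
lengths = torch.tensor([s.shape[0] for s in seqs])                     # (B,) = [3, 5, 2]

x_padded = pad_sequence(seqs, batch_first=True)             # (B, T_max, 25) = (3, 5, 25)
positions = torch.arange(x_padded.shape[1]).unsqueeze(0)    # (1, T_max)
mask = positions < lengths.unsqueeze(1)                     # (B, T_max), True = valid

print(mask.tolist())
# [[True, True, True, False, False],
#  [True, True, True, True, True],
#  [True, True, False, False, False]]
```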

In [None]:
# ── Dataset & DataLoader ──────────────────────────────────────────────────────

class HALTDataset(Dataset):
    """
    Each sample is a tuple:
        logprobs : Tensor (T_i, top_k)  — raw log-probabilities
        label    : int                  — 1 hallucinated, 0 faithful
    """
    def __init__(self, samples: List[Tuple[torch.Tensor, int]]):
        self.samples = samples

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
        logprobs, label = self.samples[idx]         # (T_i, k), int
        x_tilde = build_input_sequence(logprobs)    # (T_i, 25)
        length  = x_tilde.shape[0]
        return x_tilde, torch.tensor(label, dtype=torch.float32), length


def halt_collate_fn(
    batch: List[Tuple[torch.Tensor, torch.Tensor, int]]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Pads variable-length sequences to batch-max length.

    Returns:
        x_padded : (B, T_max, 25)  padded feature sequences
        labels   : (B,)            binary labels
        lengths  : (B,)            original sequence lengths (for packing)
        mask     : (B, T_max)      True at valid (non-padded) positions
    """
    x_list, labels, lengths = zip(*batch)
    lengths  = torch.tensor(lengths, dtype=torch.long)
    labels   = torch.stack(labels)                             # (B,)
    x_padded = pad_sequence(x_list, batch_first=True)          # (B, T_max, 25)

    B, T_max, _ = x_padded.shape
    mask = torch.arange(T_max).unsqueeze(0) < lengths.unsqueeze(1)  # (B, T_max)

    return x_padded, labels, lengths, mask


def make_dataloader(samples, batch_size: int, shuffle: bool) -> DataLoader:
    ds = HALTDataset(samples)
    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=halt_collate_fn,
        drop_last=False,
    )


# ── Synthetic data for demonstration ─────────────────────────────────────────
def make_synthetic_dataset(n: int = 1000, top_k: int = 20,
                            min_len: int = 10, max_len: int = 150) -> List[Tuple[torch.Tensor, int]]:
    """
    Creates synthetic log-prob sequences with a simple hallucination signal:
    hallucinated responses have injected entropy spikes at random positions.
    """
    samples = []
    for _ in range(n):
        T     = random.randint(min_len, max_len)
        label = random.randint(0, 1)
        # Base: top token gets high prob, alternatives decay
        lp = torch.zeros(T, top_k)
        for i in range(top_k):
            lp[:, i] = -float(i) * (0.5 + torch.rand(T) * 0.5)
        if label == 1:
            # Inject flatter distributions at random spike positions
            n_spikes = random.randint(1, max(1, T // 10))
            spike_pos = random.sample(range(T), min(n_spikes, T))
            for pos in spike_pos:
                lp[pos] = -torch.rand(top_k) * 2.0     # flatter = more uniform
        samples.append((lp, label))
    return samples


train_samples = make_synthetic_dataset(n=800)
val_samples   = make_synthetic_dataset(n=100)
test_samples  = make_synthetic_dataset(n=100)

train_loader = make_dataloader(train_samples, cfg.batch_size, shuffle=True)
val_loader   = make_dataloader(val_samples,   cfg.batch_size, shuffle=False)
test_loader  = make_dataloader(test_samples,  cfg.batch_size, shuffle=False)

# Shape sanity check
xb, lb, lensb, maskb = next(iter(train_loader))
print(f"x_padded : {xb.shape}")
print(f"labels   : {lb.shape}")
print(f"lengths  : {lensb.shape}")
print(f"mask     : {maskb.shape}")

## HALT Model

Architecture (Figure 1 + Appendix B):
```
x̃_{1:T}  (B, T, 25)
    │
    ▼  LayerNorm + 2-layer MLP  (GELU)
x_proj   (B, T, 128)
    │
    ▼  pack_padded_sequence
    ▼  BiGRU  (hidden=256, layers=5, dropout=0.4)
    ▼  pad_packed_sequence
H        (B, T, 512)
    │
    ▼  Top-q pooling  (q=0.15, score=||h_t||₂)
pooled   (B, 512)
    │
    ▼  [optional LayerNorm]
    ▼  Linear(512 → 1)  →  squeeze(-1)
logit    (B,)
```
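Most of the model's capacity sits in the five-layer BiGRU. Per layer and direction, PyTorch's `nn.GRU` holds these weights:
- `weight_ih` of shape (3H, input);
- `weight_hh` of shape (3H, H);
- two bias vectors of size 3H.

Layers after the first receive the 2H-dim concatenated output of the layer below. The sketch counts parameters analytically and checks the result against an actual `nn.GRU` with the notebook's sizes.

```python
import torch.nn as nn


def gru_param_count(input_size: int, hidden: int, layers: int, bidirectional: bool = True) -> int:
    dirs = 2 if bidirectional else 1
    total = 0
    for layer in range(layers):
        in_dim = input_size if layer == 0 else hidden * dirs    # deeper layers see both directions
        per_dir = 3 * hidden * in_dim + 3 * hidden * hidden + 2 * 3 * hidden
        total += dirs * per_dir
    return total


analytic = gru_param_count(128, 256, 5)
gru = nn.GRU(128, 256, num_layers=5, batch_first=True, bidirectional=True, dropout=0.4)
actual = sum(p.numel() for p in gru.parameters())
print(f"{analytic:,} {actual:,}")   # 5,323,776 5,323,776
```

Adding the input projection (50 + 3,328 + 16,512) and the 513-parameter classifier gives 5,344,179 parameters in total (5.34M) for the default config.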

In [None]:
# ── Input Projection (LayerNorm + MLP) ───────────────────────────────────────

class InputProjection(nn.Module):
    """
    Projects enriched feature vectors from input_dim → proj_dim.
    Applied independently to each timestep.

    Input  : (B, T, input_dim=25)
    Output : (B, T, proj_dim=128)
    """
    def __init__(self, input_dim: int, proj_dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(input_dim)
        self.mlp  = nn.Sequential(
            nn.Linear(input_dim, proj_dim),
            nn.GELU(),
            nn.Linear(proj_dim, proj_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, input_dim)
        x = self.norm(x)        # (B, T, input_dim)
        x = self.mlp(x)         # (B, T, proj_dim)
        return x


# ── Top-q Pooling ─────────────────────────────────────────────────────────────

class TopQPooling(nn.Module):
    """
    Averages the top-q fraction of timesteps scored by their ℓ₂ norm.
    Padded positions are excluded via a boolean mask.

    Input  : H    (B, T, D),  mask (B, T)  [True = valid]
    Output : pooled (B, D)
    """
    def __init__(self, q: float = 0.15):
        super().__init__()
        self.q = q

    def forward(self, H: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # H    : (B, T, D)
        # mask : (B, T)  — True at valid positions
        B, T, D = H.shape

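        scores = H.norm(dim=-1)                                   # (B, T)  ℓ₂ norm
        # Exclude padding from the selection
        scores = scores.masked_fill(~mask, float("-inf"))        # (B, T)

        # K_b = ceil(q * valid_length_b) per sequence
        valid_lengths = mask.sum(dim=1).float()                    # (B,)
        K_per_seq     = (self.q * valid_lengths).ceil().clamp(min=1).long()  # (B,)
        K_max         = int(K_per_seq.max().item())                # batch max K

        # Top-K_max indices, sorted by descending score
        _, top_idx = scores.topk(K_max, dim=1)                     # (B, K_max)

        # Gather hidden states at those timesteps
        idx_exp    = top_idx.unsqueeze(-1).expand(-1, -1, D)       # (B, K_max, D)
        top_states = H.gather(1, idx_exp)                          # (B, K_max, D)

        # Keep only the first K_b ranks of each sequence, so a sequence never
        # averages in extra (possibly padded) states chosen to fill K_max
        keep = torch.arange(K_max, device=H.device).unsqueeze(0) < K_per_seq.unsqueeze(1)  # (B, K_max)
        top_states = top_states * keep.unsqueeze(-1).to(H.dtype)

        # Average over the K_b kept states
        pooled = top_states.sum(dim=1) / K_per_seq.unsqueeze(1).to(H.dtype)  # (B, D)
        return pooled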


# ── Full HALT Model ───────────────────────────────────────────────────────────

class HALT(nn.Module):
    """
    HALT: Hallucination Assessment via Log-probs as Time series.

    Forward pass:
        x       : (B, T, 25)  padded enriched feature sequences
        lengths : (B,)        actual sequence lengths
        mask    : (B, T)      True at valid positions

    Returns:
        logit   : (B,)        unnormalised hallucination score
    """
    def __init__(self, config: HALTConfig):
        super().__init__()
        self.config = config

        # Input projection
        self.projection = InputProjection(
            input_dim=config.input_dim,
            proj_dim=config.proj_dim,
        )

        # Bidirectional GRU
        self.gru = nn.GRU(
            input_size=config.proj_dim,
            hidden_size=config.hidden_dim,
            num_layers=config.n_gru_layers,
            batch_first=True,
            bidirectional=config.bidirectional,
            dropout=config.gru_dropout if config.n_gru_layers > 1 else 0.0,
        )

        # Top-q pooling
        self.pool = TopQPooling(q=config.top_q)

        # Optional output LayerNorm
        self.out_norm = (
            nn.LayerNorm(config.gru_output_dim) if config.out_norm else nn.Identity()
        )

        # Classification head
        self.classifier = nn.Linear(config.gru_output_dim, 1)

    def forward(
        self,
        x: torch.Tensor,           # (B, T, 25)
        lengths: torch.Tensor,     # (B,)
        mask: torch.Tensor         # (B, T)  bool
    ) -> torch.Tensor:             # (B,)

        # ── 1. Project input ─────────────────────────────────────────────────
        x_proj = self.projection(x)                    # (B, T, 128)

        # ── 2. Pack → BiGRU → Unpack ─────────────────────────────────────────
        packed        = pack_padded_sequence(
            x_proj, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        gru_packed, _ = self.gru(packed)
        # total_length restores the original padded length T, even when the
        # longest sequence in the batch is shorter than x.shape[1]
        gru_out, _    = pad_packed_sequence(
            gru_packed, batch_first=True, total_length=x.shape[1]
        )                                                # (B, T, 512)

        # ── 3. Top-q pooling ─────────────────────────────────────────────────
        pooled = self.pool(gru_out, mask)               # (B, 512)

        # ── 4. Classifier ────────────────────────────────────────────────────
        pooled = self.out_norm(pooled)                  # (B, 512)
        logit  = self.classifier(pooled).squeeze(-1)   # (B,)
        return logit


# ── Instantiate and verify ────────────────────────────────────────────────────
model = HALT(cfg).to(device)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {n_params:,}  ({n_params/1e6:.2f}M)")
print(model)
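Top-q pooling should pick K_b = ⌈q · T_b⌉ timesteps per sequence, where T_b is that sequence's own valid length. A single batch-wide K, computed from the longest sequence, causes two problems for short sequences. They average more than a q fraction of their timesteps. Once K exceeds a sequence's valid length, the extra picks are padded positions. The self-contained `top_q_pool` below sketches per-sequence selection. The checks confirm that a sequence's pooled vector depends neither on what its padded positions contain nor on which other sequences share its batch.

```python
import torch


def top_q_pool(H: torch.Tensor, mask: torch.Tensor, q: float = 0.15) -> torch.Tensor:
    """Average the ceil(q * T_b) highest-norm valid timesteps of each sequence.

    H: (B, T, D) hidden states; mask: (B, T) bool, True at valid positions.
    """
    B, T, D = H.shape
    scores = H.norm(dim=-1).masked_fill(~mask, float("-inf"))            # (B, T)
    k_per_seq = (q * mask.sum(dim=1).float()).ceil().clamp(min=1).long()  # (B,)
    k_max = int(k_per_seq.max())
    top_idx = scores.topk(k_max, dim=1).indices                           # (B, k_max), best first
    top_states = H.gather(1, top_idx.unsqueeze(-1).expand(-1, -1, D))     # (B, k_max, D)
    keep = torch.arange(k_max, device=H.device).unsqueeze(0) < k_per_seq.unsqueeze(1)
    summed = (top_states * keep.unsqueeze(-1).to(H.dtype)).sum(dim=1)     # (B, D)
    return summed / k_per_seq.unsqueeze(1).to(H.dtype)


torch.manual_seed(0)
H = torch.randn(2, 30, 8)
lengths = torch.tensor([30, 12])
mask = torch.arange(30).unsqueeze(0) < lengths.unsqueeze(1)               # (2, 30)
pooled = top_q_pool(H, mask)                                              # (2, 8)

# Filling sequence 1's padded tail with large values must not change the result
H_noisy = H.clone()
H_noisy[1, 12:] = 1e3
pooled_noisy = top_q_pool(H_noisy, mask)

# Pooling sequence 1 on its own, with no padding, must give the same vector
solo = top_q_pool(H[1:2, :12], torch.ones(1, 12, dtype=torch.bool))
print(torch.allclose(pooled, pooled_noisy), torch.allclose(pooled[1], solo[0], atol=1e-6))
```

Under a batch-wide K of 5, sequence 1 (12 valid steps, own K of 2) would instead average its top 5 states. Its pooled vector would then change with the composition of the batch.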

In [None]:
# ── Forward pass shape verification ──────────────────────────────────────────
with torch.no_grad():
    xb_d    = xb.to(device)
    lensb_d = lensb.to(device)
    maskb_d = maskb.to(device)
    logit_test = model(xb_d, lensb_d, maskb_d)
    print(f"Input  x     : {xb_d.shape}")
    print(f"Output logit : {logit_test.shape}")
    assert logit_test.shape == (xb_d.shape[0],), "Shape mismatch!"
    print("✓ Forward pass OK")

## Training Utilities

Loss, optimiser, LR scheduler, and evaluation metrics.

In [None]:
# ── Loss & Optimiser ──────────────────────────────────────────────────────────

criterion = nn.BCEWithLogitsLoss()

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=cfg.lr,
    weight_decay=cfg.weight_decay,
)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",          # maximise macro-F1
    factor=cfg.lr_scheduler_factor,
    patience=cfg.lr_scheduler_patience,
)

print("Criterion  :", criterion)
print("Optimiser  :", optimizer)
print("Scheduler  :", scheduler)
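In `mode="max"`, `ReduceLROnPlateau` treats a call as non-improving unless the metric beats the best value so far by its relative threshold (1e-4 by default). It lowers the learning rate only once the count of consecutive non-improving calls exceeds `patience`. The toy run below uses a single dummy parameter and a flat validation F1 to show when halving happens with `patience=3` and `factor=0.5`.

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))                 # dummy parameter
opt = torch.optim.Adam([param], lr=1e-3)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt, mode="max", factor=0.5, patience=3
)

lrs = []
for val_f1 in [0.60, 0.60, 0.60, 0.60, 0.60]:   # flat metric: only epoch 1 counts as an improvement
    opt.step()                                   # no-op here (no gradients); keeps the usual call order
    sched.step(val_f1)
    lrs.append(opt.param_groups[0]["lr"])
print(lrs)   # [0.001, 0.001, 0.001, 0.001, 0.0005]
```

The rate is cut on the fourth consecutive non-improving epoch (patience + 1), long before early stopping's patience of 15 runs out.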

In [None]:
# ── Metrics ───────────────────────────────────────────────────────────────────

def macro_f1(
    logits: torch.Tensor,
    labels: torch.Tensor,
    threshold: float = 0.5
) -> float:
    """
    Macro-averaged F1 over the two classes (hallucinated / faithful).
    Primary metric in the paper.
    """
    preds = (logits.sigmoid() >= threshold).long()
    y     = labels.long()

    f1_scores = []
    for cls in [0, 1]:
        tp = ((preds == cls) & (y == cls)).sum().float()
        fp = ((preds == cls) & (y != cls)).sum().float()
        fn = ((preds != cls) & (y == cls)).sum().float()
        prec   = tp / (tp + fp + 1e-9)
        recall = tp / (tp + fn + 1e-9)
        f1     = 2 * prec * recall / (prec + recall + 1e-9)
        f1_scores.append(f1.item())

    return float(np.mean(f1_scores))


def accuracy(logits: torch.Tensor, labels: torch.Tensor, threshold: float = 0.5) -> float:
    preds = (logits.sigmoid() >= threshold).long()
    return (preds == labels.long()).float().mean().item()
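The notebook follows the paper in using macro-F1 as the primary metric, with accuracy as secondary. The difference matters when classes are imbalanced. Take a made-up 90/10 faithful/hallucinated split and a degenerate detector that always answers "faithful". It scores 0.90 accuracy but only about 0.47 macro-F1, because its F1 on the hallucinated class is 0. The sketch recomputes both numbers with `macro_f1_from_preds`, a hypothetical stand-alone helper that mirrors the per-class loop in `macro_f1`.

```python
import torch


def macro_f1_from_preds(preds: torch.Tensor, labels: torch.Tensor) -> float:
    scores = []
    for cls in (0, 1):
        tp = ((preds == cls) & (labels == cls)).sum().item()
        fp = ((preds == cls) & (labels != cls)).sum().item()
        fn = ((preds != cls) & (labels == cls)).sum().item()
        prec = tp / (tp + fp) if tp + fp else 0.0       # define 0/0 as 0
        rec = tp / (tp + fn) if tp + fn else 0.0
        scores.append(2 * prec * rec / (prec + rec) if prec + rec else 0.0)
    return sum(scores) / len(scores)


labels = torch.tensor([0] * 90 + [1] * 10)
preds = torch.zeros(100, dtype=torch.long)   # degenerate detector: always "faithful"
acc = (preds == labels).float().mean().item()
mf1 = macro_f1_from_preds(preds, labels)
print(f"accuracy={acc:.3f}  macro-F1={mf1:.3f}")   # accuracy=0.900  macro-F1=0.474
```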

In [None]:
# ── Train Step ────────────────────────────────────────────────────────────────

def train_step(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    grad_clip: float,
    dev: torch.device,
) -> Tuple[float, float]:
    """
    One epoch of training.

    Returns:
        avg_loss : float
        macro_f1 : float
    """
    model.train()
    total_loss = 0.0
    all_logits, all_labels = [], []

    for x, labels, lengths, mask in loader:
        x, labels = x.to(dev), labels.to(dev)
        lengths, mask = lengths.to(dev), mask.to(dev)

        optimizer.zero_grad()
        logits = model(x, lengths, mask)               # (B,)
        loss   = criterion(logits, labels)             # scalar
        loss.backward()

        nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        total_loss += loss.item() * labels.size(0)
        all_logits.append(logits.detach().cpu())
        all_labels.append(labels.cpu())

    all_logits = torch.cat(all_logits)
    all_labels = torch.cat(all_labels)
    avg_loss   = total_loss / len(all_labels)
    mf1        = macro_f1(all_logits, all_labels)
    return avg_loss, mf1


# ── Eval Step ─────────────────────────────────────────────────────────────────

@torch.no_grad()
def eval_step(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    dev: torch.device,
) -> Tuple[float, float, float]:
    """
    Evaluation on a given DataLoader.

    Returns:
        avg_loss : float
        macro_f1 : float
        acc      : float
    """
    model.eval()
    total_loss = 0.0
    all_logits, all_labels = [], []

    for x, labels, lengths, mask in loader:
        x, labels = x.to(dev), labels.to(dev)
        lengths, mask = lengths.to(dev), mask.to(dev)

        logits = model(x, lengths, mask)               # (B,)
        loss   = criterion(logits, labels)

        total_loss += loss.item() * labels.size(0)
        all_logits.append(logits.cpu())
        all_labels.append(labels.cpu())

    all_logits = torch.cat(all_logits)
    all_labels = torch.cat(all_labels)
    avg_loss   = total_loss / len(all_labels)
    mf1        = macro_f1(all_logits, all_labels)
    acc        = accuracy(all_logits, all_labels)
    return avg_loss, mf1, acc

## Training Loop

Training uses early stopping (patience = 15) and ReduceLROnPlateau (patience = 3); both monitor **macro-F1** on the validation set.
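The early-stopping rule in `train()` resets its counter only on a strict improvement (`va_f1 > best_val_f1`), so matching the best score does not count. Training stops once `early_stop_patience` consecutive epochs pass without an improvement. The hypothetical `EarlyStopper` below isolates that rule so it can be traced on a short made-up F1 history with `patience=3`.

```python
from dataclasses import dataclass


@dataclass
class EarlyStopper:                  # hypothetical helper mirroring the loop's counter
    patience: int = 15
    best: float = float("-inf")
    bad_epochs: int = 0

    def update(self, metric: float) -> bool:
        """Record one validation score; return True when training should stop."""
        if metric > self.best:       # strict improvement, as in train()
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopper(patience=3)
f1_history = [0.50, 0.62, 0.61, 0.62, 0.60]   # epoch 4 ties the best and still counts as bad
stopped_at = None
for epoch, f1 in enumerate(f1_history, start=1):
    if stopper.update(f1):
        stopped_at = epoch
        break
print(stopped_at, stopper.best)   # 5 0.62
```

The scheduler applies its own relative threshold (1e-4 by default), so the two rules can disagree about whether a very small gain counts as an improvement.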

In [None]:
# ── Main Training Loop ────────────────────────────────────────────────────────

def train(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler,
    criterion: nn.Module,
    config: HALTConfig,
    dev: torch.device,
) -> dict:
    """
    Full training run with early stopping.

    Returns:
        history : dict with train/val loss and F1 per epoch
    """
    best_val_f1     = -1.0
    epochs_no_imprv = 0
    best_state      = None

    history = {
        "train_loss": [], "train_f1": [],
        "val_loss": [],   "val_f1": [],
    }

    for epoch in range(1, config.max_epochs + 1):
        tr_loss, tr_f1 = train_step(
            model, train_loader, optimizer, criterion, config.grad_clip_max_norm, dev
        )
        va_loss, va_f1, va_acc = eval_step(model, val_loader, criterion, dev)

        scheduler.step(va_f1)

        history["train_loss"].append(tr_loss)
        history["train_f1"].append(tr_f1)
        history["val_loss"].append(va_loss)
        history["val_f1"].append(va_f1)

        print(
            f"Epoch {epoch:3d}/{config.max_epochs} | "
            f"Train loss={tr_loss:.4f} F1={tr_f1:.4f} | "
            f"Val loss={va_loss:.4f} F1={va_f1:.4f} Acc={va_acc:.4f}"
        )

        # Early stopping
        if va_f1 > best_val_f1:
            best_val_f1     = va_f1
            epochs_no_imprv = 0
            best_state      = {k: v.clone() for k, v in model.state_dict().items()}
        else:
            epochs_no_imprv += 1
            if epochs_no_imprv >= config.early_stop_patience:
                print(f"Early stopping at epoch {epoch} (best val macro-F1={best_val_f1:.4f})")
                break

    # Restore best checkpoint
    if best_state is not None:
        model.load_state_dict(best_state)
        print(f"\nRestored best model  (val macro-F1={best_val_f1:.4f})")

    return history

In [None]:
# ── Run Training ──────────────────────────────────────────────────────────────
history = train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    config=cfg,
    dev=device,
)

In [None]:
# ── Test Evaluation ───────────────────────────────────────────────────────────
test_loss, test_f1, test_acc = eval_step(model, test_loader, criterion, device)
print(f"\n=== Test Results ===")
print(f"  Loss      : {test_loss:.4f}")
print(f"  Macro-F1  : {test_f1:.4f}")
print(f"  Accuracy  : {test_acc:.4f}")

In [None]:
# ── Training Curves ───────────────────────────────────────────────────────────
try:
    import matplotlib.pyplot as plt

    epochs = range(1, len(history["train_loss"]) + 1)

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    axes[0].plot(epochs, history["train_loss"], label="Train")
    axes[0].plot(epochs, history["val_loss"],   label="Val")
    axes[0].set_xlabel("Epoch")
    axes[0].set_ylabel("BCE Loss")
    axes[0].set_title("Loss")
    axes[0].legend()

    axes[1].plot(epochs, history["train_f1"], label="Train")
    axes[1].plot(epochs, history["val_f1"],   label="Val")
    axes[1].set_xlabel("Epoch")
    axes[1].set_ylabel("Macro-F1")
    axes[1].set_title("Macro-F1")
    axes[1].legend()

    plt.suptitle("HALT Training Curves", fontsize=14)
    plt.tight_layout()
    plt.savefig("halt_training_curves.png", dpi=120)
    plt.show()
    print("Saved: halt_training_curves.png")
except ImportError:
    print("matplotlib not installed — skipping plots.")

In [None]:
# ── Inference on a single response ───────────────────────────────────────────

@torch.no_grad()
def predict_hallucination(
    logprobs: torch.Tensor,
    model: nn.Module,
    dev: torch.device,
    threshold: float = 0.5,
) -> dict:
    """
    Predict hallucination probability for a single response.

    Args:
        logprobs : (T, 20)  — raw top-20 log-probs from the LLM
        model    : trained HALT model
        dev      : device
        threshold: decision threshold

    Returns:
        dict with probability and binary prediction
    """
    model.eval()

    x_tilde  = build_input_sequence(logprobs)           # (T, 25)
    x_batch  = x_tilde.unsqueeze(0).to(dev)             # (1, T, 25)
    lengths  = torch.tensor([x_tilde.shape[0]]).to(dev) # (1,)
    mask     = torch.ones(1, x_tilde.shape[0], dtype=torch.bool, device=dev)  # (1, T)

    logit = model(x_batch, lengths, mask)               # (1,)
    prob  = logit.sigmoid().item()

    return {
        "hallucination_probability": prob,
        "is_hallucinated": prob >= threshold,
        "threshold": threshold,
    }


# Demo: predict on a single synthetic sample
sample_logprobs, sample_label = test_samples[0]
result = predict_hallucination(sample_logprobs, model, device)

print(f"True label              : {'hallucinated' if sample_label == 1 else 'faithful'}")
print(f"Hallucination probability: {result['hallucination_probability']:.4f}")
print(f"Predicted               : {'hallucinated' if result['is_hallucinated'] else 'faithful'}")

In [None]:
# ── Model Checkpoint Save/Load ────────────────────────────────────────────────

from dataclasses import asdict

checkpoint_path = "halt_checkpoint.pt"

# Store the config as a plain dict so the checkpoint also loads with
# torch.load(..., weights_only=True), the default since PyTorch 2.6.
torch.save({
    "model_state_dict"    : model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "config"              : asdict(cfg),
    "history"             : history,
    "test_macro_f1"       : test_f1,
}, checkpoint_path)

print(f"Checkpoint saved to: {checkpoint_path}")

# Reload demo: rebuild the dataclass from the stored dict
ckpt           = torch.load(checkpoint_path, map_location=device)
model_reloaded = HALT(HALTConfig(**ckpt["config"])).to(device)
model_reloaded.load_state_dict(ckpt["model_state_dict"])
print("Checkpoint loaded successfully.")

## Summary

| Component | Implementation |
|---|---|
| **Feature extractor** | `extract_features()` — 5 uncertainty signals from top-k log-probs |
| **Input builder** | `build_input_sequence()` — (T, 25) per response |
| **Dataset** | `HALTDataset` + `halt_collate_fn` — handles variable lengths |
| **Projection** | `InputProjection` — LayerNorm + 2-layer GELU MLP (25→128) |
| **Encoder** | BiGRU (hidden=256, layers=5, dropout=0.4) → (B,T,512) |
| **Pooling** | `TopQPooling` (q=0.15, ℓ₂-score) → (B,512) |
| **Classifier** | `Linear(512→1)` + BCEWithLogitsLoss |
| **Optimiser** | Adam, lr=4.41e-4, wd=2.34e-6 |
| **Scheduler** | ReduceLROnPlateau (factor=0.5, patience=3, mode=max) |
| **Early stopping** | patience=15 on val macro-F1 |
| **Metric** | Macro-F1 (primary), accuracy (secondary) |

To use with a **real LLM API**, replace `make_synthetic_dataset` with a function that:
1. Sends prompts to the LLM (e.g., via OpenAI or vLLM)
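2. Extracts the top-20 log-probs of each generated token from the API response metadata into a `(T, 20)` tensor, with the sampled token's log-prob in column 0 (the layout `extract_features()` assumes)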
3. Pairs each response with a hallucination label from your annotation pipeline

The resulting detector is **model-specific**: it learns the calibration bias of the LLM whose log-probs it was trained on.
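Step 2 depends on the provider's response format. The sketch below assumes the per-token records of OpenAI's Chat Completions `logprobs.content` field, requested with `top_logprobs=20` and already converted to plain dicts. Each record has a `token`, a `logprob` and a `top_logprobs` list of `{token, logprob}` entries. Check these field names against the current API reference.

The sketch puts the sampled token's log-prob in column 0 and fills missing alternatives with a constant `PAD_LOGPROB`. That fill value is our assumption, not something the paper specifies. `content_to_logprob_matrix` is a hypothetical helper name, and the two-token input is made up.

```python
import torch

PAD_LOGPROB = -100.0   # assumption: filler for missing alternatives, not from the paper


def content_to_logprob_matrix(content, top_k=20):
    """Hypothetical helper: per-token logprob records -> (T, top_k) float tensor.

    Column 0 holds the sampled token's log-prob; columns 1..top_k-1 hold the
    remaining candidates in descending order, padded with PAD_LOGPROB.
    """
    rows = []
    for entry in content:
        alts = sorted(
            (c["logprob"] for c in entry["top_logprobs"] if c["token"] != entry["token"]),
            reverse=True,
        )[: top_k - 1]
        alts += [PAD_LOGPROB] * (top_k - 1 - len(alts))
        rows.append([entry["logprob"]] + alts)
    return torch.tensor(rows, dtype=torch.float32)


# Made-up two-token response in the assumed record layout
content = [
    {"token": "Paris", "logprob": -0.01, "top_logprobs": [
        {"token": "Paris", "logprob": -0.01}, {"token": "Lyon", "logprob": -5.2}]},
    {"token": ".", "logprob": -0.3, "top_logprobs": [
        {"token": ".", "logprob": -0.3}, {"token": "!", "logprob": -1.5},
        {"token": ",", "logprob": -2.0}]},
]
lp = content_to_logprob_matrix(content)
print(lp.shape)   # torch.Size([2, 20])
```

Heavy padding pulls `avg_logprob` down sharply. If a provider regularly returns fewer than 20 candidates, it may be better to lower `top_k` (and `input_dim`) than to pad.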