# Setup

In [None]:
# Standard library imports
import math
import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

# Third-party library imports
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE
from sklearn.metrics import f1_score, roc_auc_score
from sklearn.preprocessing import minmax_scale

# PyTorch imports
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from torch.utils.data import DataLoader, Dataset, random_split

The code imports standard library modules: `math` for mathematical operations, `random` for generating random numbers, `dataclass` for defining data classes, and typing modules (`Dict`, `List`, `Optional`, `Tuple`) for type hints.

It imports third-party libraries: `matplotlib.pyplot` for plotting graphs, `numpy` for numerical computations, `sklearn.manifold.TSNE` for dimensionality reduction using t-SNE, `sklearn.metrics.f1_score` and `roc_auc_score` for evaluating classification models, and `sklearn.preprocessing.minmax_scale` for scaling features to a range.

It imports PyTorch components: `torch` as the main package, `torch.nn` for neural network layers and functions, utilities for handling variable-length sequences (`pack_padded_sequence`, `pad_packed_sequence`, `pad_sequence`) from `torch.nn.utils.rnn`, and data loading tools (`DataLoader`, `Dataset`, `random_split`) from `torch.utils.data`.


In [None]:
# general settings
class CFG:
    img_dim1 = 20
    img_dim2 = 10
    SEED = 42
    device = "cuda" if torch.cuda.is_available() else "cpu"

# display style 
plt.style.use("seaborn-v0_8")
plt.rcParams["figure.figsize"] = (CFG.img_dim1, CFG.img_dim2)

print(f"Using device: {CFG.device}")
print(f"PyTorch Version: {torch.__version__}")

# fix randomness (insofar as possible ;-)
SEED = CFG.SEED  # module-level alias; reused later as a default argument
random.seed(CFG.SEED)
np.random.seed(CFG.SEED)
torch.manual_seed(CFG.SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(CFG.SEED)


The code defines configuration constants and sets up reproducible random number generation for a machine learning project.

A class named `CFG` is created with four attributes:  
- `img_dim1` set to 20, the Matplotlib figure width in inches (it feeds `figure.figsize`, not an image shape),  
- `img_dim2` set to 10, the figure height in inches,  
- `SEED` set to 42 for reproducibility,  
- `device` assigned `"cuda"` if CUDA is available (indicating GPU support), otherwise `"cpu"`.  

Matplotlib's display settings are then configured: the style is set to `seaborn-v0_8`, and the figure size is updated using values from `CFG`.  

The script prints which computing device it will use and the installed PyTorch version.  

Random number generation is seeded consistently across Python's built-in `random` module, NumPy, and PyTorch (both CPU and GPU when CUDA is available), using the value 42 to ensure reproducible results.
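
As a small illustration of what seeding provides, the sketch below uses only Python's `random` module. The helper `draw` is hypothetical and not part of the notebook; it shows that re-creating a generator with the same seed replays the same draws.

```python
import random


def draw(seed: int, n: int = 3) -> list:
    """Return n pseudo-random floats from a generator seeded with `seed`."""
    rng = random.Random(seed)  # independent generator; global state is untouched
    return [rng.random() for _ in range(n)]


first = draw(42)
second = draw(42)
other = draw(43)

print(first == second)  # same seed, same sequence
print(first == other)   # different seed, different sequence
```

On a GPU, bitwise reproducibility can additionally depend on backend settings (for example cuDNN algorithm selection), which the cell above does not change.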

# Utils

## Config

In [None]:
@dataclass
class HALTConfig:
    # Feature dimensions (Sec 3.2 + Fig.1)
    top_k: int = 20
    n_stat_features: int = 5  # [avg_logp, rank_proxy, h_overall, h_alts, delta_h_dec]
    input_dim: int = 25       # top_k + n_stat_features

    # Projection MLP (Sec 3.2, Appendix B)
    proj_dim: int = 128

    # Bidirectional GRU (Appendix B)
    hidden_dim: int = 256
    num_layers: int = 5
    dropout: float = 0.4

    # Top-q salient pooling (Eq Sec 3.2)
    top_q: float = 0.15

    # Training (Appendix B)
    batch_size: int = 512
    max_epochs: int = 5 # 100
    lr: float = 4.41e-4
    weight_decay: float = 2.34e-6
    lr_patience: int = 3
    lr_factor: float = 0.5
    early_stop_patience: int = 15
    max_grad_norm: float = 1.0

The `HALTConfig` class defines hyperparameters and architectural settings for a HALT model.  

Feature dimensions are configured as follows:  
- `top_k` is set to 20, the number of top-`k` log-probabilities kept per generated token.  
- `n_stat_features` is set to 5, representing five additional statistical features: average log-probability, rank proxy, overall entropy (`h_overall`), alternatives-only entropy (`h_alts`), and the temporal change in binary decision entropy (`delta_h_dec`).  
- `input_dim` is 25, the sum of `top_k` and `n_stat_features`. It is written as a literal default, so it must be updated by hand if either of the other two values changes.  

A projection MLP has `proj_dim` set to 128, mapping input features into a shared embedding space.  

A bidirectional GRU is configured with `hidden_dim` of 256, `num_layers` of 5, and `dropout` of 0.4 for regularization.  

The `top_q` parameter is set to 0.15, indicating that the top 15% of salient elements are selected during pooling as described in Section 3.2.  

Training settings include:  
- `batch_size` of 512,  
- `max_epochs` set to 5 for fast iteration on the synthetic data (the inline comment records 100 as the value for longer runs; a real dataset needs more than 5 epochs),  
- `lr` (learning rate) at $4.41 \times 10^{-4}$,  
- `weight_decay` (L2 regularization coefficient) at $2.34 \times 10^{-6}$,  
- `lr_patience` of 3 epochs before reducing the learning rate,  
- `lr_factor` of 0.5 to scale down the learning rate when plateauing,  
- `early_stop_patience` of 15 epochs before halting training if no improvement occurs,  
- `max_grad_norm` of 1.0 for gradient clipping to stabilize training.
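
Because `input_dim` is a literal default in the class above, changing `top_k` would not update it automatically. One possible way to keep the values in sync is sketched below with a hypothetical `FeatureDims` dataclass; the class name and the use of `__post_init__` are assumptions for illustration, not part of the notebook.

```python
from dataclasses import dataclass, field


@dataclass
class FeatureDims:
    """Hypothetical helper: keeps input_dim consistent with its two parts."""
    top_k: int = 20
    n_stat_features: int = 5
    input_dim: int = field(init=False)  # derived, never passed by the caller

    def __post_init__(self) -> None:
        # Recomputed on every construction, so it cannot drift out of sync.
        self.input_dim = self.top_k + self.n_stat_features


dims = FeatureDims()
smaller = FeatureDims(top_k=10)
print(dims.input_dim, smaller.input_dim)
```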

## Feature Extraction

 (Sec 3.2)

In [None]:

EPS = 1e-9

def compute_engineered_features(logprobs_k: torch.Tensor) -> torch.Tensor:

    T, K = logprobs_k.shape

    # Stable truncated distribution over top-k (Eq 4)
    mt = logprobs_k.max(dim=1, keepdim=True).values
    exp_s = torch.exp(logprobs_k - mt)
    p_tilde = exp_s / (exp_s.sum(dim=1, keepdim=True) + EPS)

    # 1. AvgLogP (Eq 5)
    avg_logp = logprobs_k.mean(dim=1)  # [T]

    # 2. RankProxy (Eq 6)
    sel = logprobs_k[:, :1]            # [T, 1]
    rank_proxy = 1.0 + (logprobs_k[:, 1:] > sel).float().sum(dim=1)  # [T]

    # 3. Overall entropy over truncated distribution (Eq 7)
    h_overall = -(p_tilde * torch.log(p_tilde + EPS)).sum(dim=1)  # [T]

    # 4. Alternatives-only entropy (Eq 8-9)
    p_alts = p_tilde[:, 1:]                            # [T, K-1]
    p_alts_n = p_alts / (p_alts.sum(dim=1, keepdim=True) + EPS)
    h_alts = -(p_alts_n * torch.log(p_alts_n + EPS)).sum(dim=1)   # [T]

    # 5. Temporal delta of binary decision entropy (Eq 10-13)
    best_alt = logprobs_k[:, 1:].max(dim=1).values     # [T]
    sel_lp = logprobs_k[:, 0]                          # [T]

    pc_num = torch.exp(sel_lp)
    pc_den = torch.exp(sel_lp) + torch.exp(best_alt)   # [T]
    pc = pc_num / (pc_den + EPS)
    pc = pc.clamp(EPS, 1.0 - EPS)                      # avoid log(0)
    h_dec = -(pc * torch.log(pc) + (1 - pc) * torch.log(1 - pc))  # [T]

    delta_h_dec = h_dec.clone()
    delta_h_dec[1:] = h_dec[1:] - h_dec[:-1]

    return torch.stack([avg_logp, rank_proxy, h_overall, h_alts, delta_h_dec], dim=1)


The `compute_engineered_features` function computes five engineered features from a sequence of log-probabilities (`logprobs_k`) for each time step. The input tensor has shape `[T, K]`, where `T` is the number of time steps and `K` is the number of top candidates.

First, a numerically stable truncated distribution is computed:  
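- The maximum log-probability per time step (`mt`) is subtracted before exponentiation, yielding `exp_s`. Because log-probabilities are non-positive, the practical risk is underflow of every term to zero rather than overflow; after the shift the largest term is exactly `exp(0) = 1`.  
- Normalized probabilities (`p_tilde`) are obtained by dividing `exp_s` by its row sum, with a small epsilon (`EPS = 1e-9`) added to the denominator for numerical stability.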
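
The effect of the max-subtraction can be seen on a single row. The sketch below is a pure-Python stand-in for the tensor code (the function name `truncated_softmax` and the extreme example values are assumptions chosen to make the effect visible).

```python
import math


def truncated_softmax(logprobs, eps=1e-9):
    """Softmax over one row of log-probs, shifted by the row maximum."""
    m = max(logprobs)
    exps = [math.exp(lp - m) for lp in logprobs]  # largest term is exp(0) = 1
    total = sum(exps) + eps
    return [e / total for e in exps]


row = [-1000.0, -1001.0, -1002.0]

# Without the shift every term underflows to 0.0, so the row cannot be normalized.
naive_sum = sum(math.exp(lp) for lp in row)

# With the shift the relative weights survive.
p_tilde = truncated_softmax(row)
print(naive_sum, [round(p, 4) for p in p_tilde])
```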

Five features are then computed:

1. **AvgLogP**: The mean of the original log-probabilities across the `K` candidates at each time step.

2. **RankProxy**: A rank-based metric for the selected token, whose log-probability sits in column 0. For each time step it counts how many of the remaining `K-1` values are strictly greater than the selected token's value, then adds 1, so a value of 1 means no alternative outranks the selected token.

3. **Overall entropy (`h_overall`)**: Shannon entropy computed over the truncated distribution `p_tilde`, measuring uncertainty across all `K` candidates.

4. **Alternatives-only entropy (`h_alts`)**: Shannon entropy computed over the distribution restricted to the alternatives (columns 1 to `K-1`, excluding the selected token). The probabilities of the alternatives are renormalized before the entropy is computed.

5. **Temporal delta of binary decision entropy (`delta_h_dec`)**:  
   - The probability of the selected token versus the best alternative is computed via a binary softmax: `pc = exp(sel_lp) / (exp(sel_lp) + exp(best_alt))`, where `sel_lp` is the selected token's log-probability and `best_alt` is the maximum among the alternatives.  
   - This `pc` is clamped to `[EPS, 1 - EPS]` to avoid `log(0)`, and the binary entropy `h_dec` is computed.  
   - The temporal difference (`delta_h_dec`) is obtained by differencing consecutive entropy values; the first time step has no predecessor and keeps `h_dec[0]` unchanged.

The function returns a tensor of shape `[T, 5]`, stacking the five features along the feature dimension.
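
A hand-sized example may help with the rank proxy and the decision-entropy delta. The sketch below re-implements those two features in plain Python for three time steps with three candidates each; the helper names and the example values are illustrative assumptions, with column 0 holding the selected token as in the function above.

```python
import math


def rank_proxy(row):
    """1 + number of alternatives strictly above the selected token (column 0)."""
    selected, alts = row[0], row[1:]
    return 1 + sum(1 for a in alts if a > selected)


def decision_entropy(row, eps=1e-9):
    """Binary entropy of selected token vs. best alternative."""
    selected, best_alt = row[0], max(row[1:])
    pc = math.exp(selected) / (math.exp(selected) + math.exp(best_alt) + eps)
    pc = min(max(pc, eps), 1.0 - eps)  # clamp to avoid log(0)
    return -(pc * math.log(pc) + (1 - pc) * math.log(1 - pc))


def temporal_delta(values):
    """First entry is kept as-is; later entries are consecutive differences."""
    return [values[0]] + [b - a for a, b in zip(values, values[1:])]


rows = [
    [-0.1, -3.0, -4.0],  # confident: selected token dominates
    [-1.2, -1.2, -2.5],  # tie between selected token and best alternative
    [-2.0, -0.5, -3.0],  # an alternative beats the selected token
]
ranks = [rank_proxy(r) for r in rows]
h_dec = [decision_entropy(r) for r in rows]
delta = temporal_delta(h_dec)
print(ranks, [round(h, 3) for h in h_dec], [round(d, 3) for d in delta])
```

The tied row reaches the maximum binary entropy of `ln 2`, and a tie does not raise the rank because the comparison is strict.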

In [None]:
def build_halt_input(logprobs_k: torch.Tensor) -> torch.Tensor:
    """
    Concatenate engineered stats with raw log-probs -> [T, 25].

    Order: [stats (5) || raw_logprobs (20)] -> Fig. 1 / Sec 3.2.
    """
    stats = compute_engineered_features(logprobs_k)     # [T, 5]
    return torch.cat([stats, logprobs_k], dim=1)        # [T, 25]


The `build_halt_input` function constructs the model input by concatenating engineered statistical features with raw log-probabilities along the feature dimension.

It takes `logprobs_k`, a tensor of shape `[T, K]` (where `K = 20`), representing log-probabilities over top-`K` candidates at each of `T` time steps.

First, it computes the five engineered features (AvgLogP, RankProxy, Overall Entropy, Alternatives-Only Entropy, and Temporal Delta of Decision Entropy) by calling `compute_engineered_features`, resulting in a tensor of shape `[T, 5]`.

Then, it concatenates these features with the original raw log-probabilities (`logprobs_k`) along dimension 1, yielding a final input tensor of shape `[T, 25]`, which matches the `input_dim` specified in `HALTConfig`.

The concatenation order is explicitly: five engineered statistics followed by twenty raw log-probabilities.

## Synthetic Dataset Generator 

(Sec 4.1)

In [None]:
def _synthetic_logprobs(
    seq_len: int,
    top_k: int,
    hallucinated: bool,
    rng: np.random.Generator
) -> np.ndarray:
    """
    Synthesize "realistic" top-k log-prob matrix for hallucinated vs correct tokens.

    Hallucinated: flatter distribution (small gaps, high noise)
    Correct: peaked distribution (large gaps, low noise)

    Returns:
        [seq_len, top_k] float32 log-probs; column 0 holds the base value.
        Alternatives trend downward, but the added noise can reorder them.
    """
    result = np.zeros((seq_len, top_k), dtype=np.float32)
    for t in range(seq_len):
        if hallucinated:
            base = rng.uniform(-4.0, -0.3)      # lower mean (more uncertainty)
            gap  = rng.uniform(0.05, 0.3)       # narrow gaps (indistinguishable)
        else:
            base = rng.uniform(-1.5, -0.1)      # higher mean (confident)
            gap  = rng.uniform(0.5, 2.0)        # large gaps (confident ranking)
        # Alternatives: base - gap*(1..K-1) + small noise
        alts = base - gap * np.arange(1, top_k)
        alts += rng.uniform(-0.5, 0.5, top_k - 1)
        result[t] = np.concatenate([[base], alts]).astype(np.float32)
    return result

The `_synthetic_logprobs` function generates synthetic log-probability matrices for evaluation or training purposes, simulating two distinct scenarios: hallucinated and correct token generations.

Inputs:
- `seq_len`: number of time steps (sequence length),
- `top_k`: number of top candidates per time step,
- `hallucinated`: boolean flag indicating whether to generate hallucinated (uncertain) or correct (confident) log-probabilities,
- `rng`: NumPy random number generator for reproducibility.

For each time step `t` from 0 to `seq_len-1`, the function creates a row of `top_k` log-probabilities:

- When `hallucinated=True`, it uses a lower baseline log-probability (`base` sampled uniformly from [-4.0, -0.3]) and small gaps (`gap` in [0.05, 0.3]), producing flatter distributions with indistinguishable candidates and higher uncertainty.
- When `hallucinated=False`, it uses a higher baseline (`base` in [-1.5, -0.1]) and larger gaps (`gap` in [0.5, 2.0]), yielding peaked distributions where the top candidate is clearly favored.

For alternatives (indices 1 to `top_k-1`), log-probabilities are constructed as `base - gap * i` for `i = 1..(K-1)`, with uniform noise from `[-0.5, 0.5]` added to each. The first log-probability is set to `base`, and all values are concatenated and cast to float32.

The output is a `[seq_len, top_k]` array whose rows trend downward but are not guaranteed to be sorted: the noise amplitude (0.5) can exceed the gap, so neighbouring alternatives may swap places, and in the hallucinated case (`gap` at most 0.3) an alternative can even exceed `base`. In the correct case (`gap` at least 0.5) the first column remains the row maximum.
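
To see how the noise term interacts with the gap, the sketch below samples rows with the same ranges and checks how often column 0 is the row maximum. It is a pure-Python stand-in: `random.Random` replaces the NumPy generator, so the random stream differs from the notebook's, and the helper names are assumptions. In the correct regime the first alternative is at most `base - 0.5 + 0.5 = base`, while in the hallucinated regime the gap (at most 0.3) is smaller than the noise amplitude.

```python
import random


def sample_row(top_k, hallucinated, rng):
    """Pure-Python stand-in for one row of `_synthetic_logprobs`."""
    if hallucinated:
        base, gap = rng.uniform(-4.0, -0.3), rng.uniform(0.05, 0.3)
    else:
        base, gap = rng.uniform(-1.5, -0.1), rng.uniform(0.5, 2.0)
    alts = [base - gap * i + rng.uniform(-0.5, 0.5) for i in range(1, top_k)]
    return [base] + alts


def selected_is_max(row):
    return row[0] >= max(row[1:])


def fraction_selected_max(hallucinated, n_rows=500, top_k=20, seed=0):
    rng = random.Random(seed)
    hits = sum(
        selected_is_max(sample_row(top_k, hallucinated, rng)) for _ in range(n_rows)
    )
    return hits / n_rows


frac_correct = fraction_selected_max(hallucinated=False)
frac_halluc = fraction_selected_max(hallucinated=True)
print(f"column 0 is the row maximum: correct={frac_correct:.2f}, "
      f"hallucinated={frac_halluc:.2f}")
```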

In [None]:
class SyntheticLogProbDataset(Dataset):
    """
    Synthetic dataset: (feature_sequence, label) pairs.
    One sample = one LLM response as top-k log-prob time series.
    """
    def __init__(
        self,
        n_samples: int = 4000,
        top_k: int = 20,
        min_len: int = 10,
        max_len: int = 150,
        hallucination_rate: float = 0.5,
        seed: int = SEED,
    ):
        rng = np.random.default_rng(seed)
        self.samples: List[Tuple[torch.Tensor, int]] = []
        for _ in range(n_samples):
            label   = int(rng.random() < hallucination_rate)
            seq_len = int(rng.integers(min_len, max_len + 1))
            raw     = _synthetic_logprobs(seq_len, top_k, bool(label), rng)
            x       = build_halt_input(torch.from_numpy(raw))
            self.samples.append((x, label))

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int, int]:
        x, label = self.samples[idx]
        return x, label, x.shape[0]   # tensor, label, length




The `SyntheticLogProbDataset` class implements a PyTorch Dataset for generating synthetic data used to train or evaluate the HALT model.

Each sample corresponds to one simulated LLM response, represented as a time series of top-`K` log-probabilities with an associated binary label: 1 for hallucinated responses, 0 for correct ones.

In the constructor:
- A NumPy random number generator is initialized with the given seed for reproducibility.
- For each of `n_samples` (default 4000) iterations:
  - A binary label is sampled: hallucinated (`label=1`) with probability `hallucination_rate` (default 0.5), otherwise correct (`label=0`).
  - A random sequence length is drawn uniformly from `[min_len, max_len]`, i.e., between 10 and 150 tokens.
  - Synthetic log-probabilities are generated using `_synthetic_logprobs`, with the `hallucinated` flag set to the sampled label.
  - The raw log-prob matrix (shape `[seq_len, top_k]`) is converted to a PyTorch tensor and passed through `build_halt_input` to form the full input features (shape `[seq_len, 25]`).
  - The resulting tensor `x` and its label are stored as a pair; the sequence length is not stored separately and is read from `x.shape[0]` on access.

The `__len__` method returns the total number of samples.

The `__getitem__` method retrieves a single sample as a tuple:  
- `x`: the full feature tensor of shape `[seq_len, 25]`,  
- `label`: integer (0 or 1),  
- `x.shape[0]`: the actual sequence length for that sample (needed e.g., for padding-aware batching).

In [None]:
def collate_fn(batch):
    xs, labels, lengths = zip(*batch)

    # Sort the batch by descending sequence length
    lengths_t = torch.tensor(lengths, dtype=torch.long)
    idx       = lengths_t.argsort(descending=True).tolist()

    xs_sorted   = [xs[i] for i in idx]
    lengths_s   = lengths_t[idx]

    # Zero-pad every sequence to the longest one in the batch: [B, T_max, 25]
    padded = pad_sequence(xs_sorted, batch_first=True, padding_value=0.0)

    # Labels follow the same order as the padded sequences
    labels_t = torch.tensor([labels[i] for i in idx], dtype=torch.float)

    return padded, labels_t, lengths_s

The `collate_fn` function processes a batch of samples from the `SyntheticLogProbDataset` to prepare it for model training or evaluation.

Each input sample is a tuple `(x, label, length)`, where:
- `x` is a 2D tensor of shape `[seq_len, 25]`,
- `label` is an integer (0 or 1),
- `length` is the original sequence length.

The function groups inputs by unpacking the batch into three lists: `xs`, `labels`, and `lengths`.

It then sorts the samples in descending order of sequence length using `argsort`. A sorted batch is what `pack_padded_sequence` expects when called with `enforce_sorted=True`; the `HALT` forward pass sorts again on its own, so this step is a convenience rather than a requirement. The indices `idx` record the sorting order.

The sorted samples are reordered: `xs_sorted` contains tensors ordered by decreasing length, and `lengths_s` stores the corresponding sorted lengths.

Padding is applied via `pad_sequence`, which pads shorter sequences to match the length of the longest sequence in the batch. Padding is done with zeros (`padding_value=0.0`), and `batch_first=True` ensures the output tensor has shape `[batch_size, max_seq_len, 25]`.

Labels are extracted in the same sorted order and converted to a float tensor.

The function returns a tuple:
- `padded`: padded input tensor,
- `labels_t`: label tensor (same order as `padded`),
- `lengths_s`: sorted sequence lengths tensor (descending), used for unpadded length handling in models like GRUs.

This collate function lets variable-length sequences be stacked into a single batch tensor while keeping each sequence's true length available to the model.
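
The sort-then-pad logic can be traced on a toy batch. In the sketch below, plain Python lists stand in for tensors and each time step is a single number instead of a 25-dimensional vector; `sort_and_pad` is a hypothetical helper that assumes a non-empty batch.

```python
def sort_and_pad(sequences, labels, pad_value=0.0):
    """Order a batch by descending length and right-pad to the longest sequence."""
    order = sorted(range(len(sequences)), key=lambda i: len(sequences[i]), reverse=True)
    lengths = [len(sequences[i]) for i in order]
    max_len = lengths[0]  # longest sequence comes first after sorting
    padded = [
        sequences[i] + [pad_value] * (max_len - len(sequences[i])) for i in order
    ]
    sorted_labels = [float(labels[i]) for i in order]  # same order as `padded`
    return padded, sorted_labels, lengths


seqs = [[1.0, 2.0], [3.0, 4.0, 5.0, 6.0], [7.0]]
labels = [0, 1, 0]
padded, sorted_labels, lengths = sort_and_pad(seqs, labels)
print(padded)
print(sorted_labels, lengths)
```

The returned lengths are what allow downstream code to tell real time steps from padding.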

## Model Blocks 
(Sec 3.2 + Appendix B)

In [None]:
class InputProjection(nn.Module):
    """
    LayerNorm → Linear → GELU → Linear  (Sec 3.2)
    [B,T,25] → [B,T,128]
    """
    def __init__(self, input_dim: int, proj_dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(input_dim)
        self.fc1  = nn.Linear(input_dim, proj_dim)
        self.act  = nn.GELU()
        self.fc2  = nn.Linear(proj_dim, proj_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(self.norm(x))))



The `InputProjection` module performs a feedforward transformation on the input sequence to map features from dimensionality 25 to 128 (as defined in `HALTConfig`), following the architecture described in Section 3.2 and Appendix B.

It consists of four sequential operations:

1. **Layer Normalization (`nn.LayerNorm`)**: Applied across the feature dimension (last axis), normalizing each time step independently to stabilize training.
2. **First Linear Layer (`fc1`)**: A linear transformation mapping the input features (dim=25) to a projection dimension of 128.
3. **GELU Activation**: A smooth nonlinearity applied element-wise to introduce nonlinearity after the first linear layer.
4. **Second Linear Layer (`fc2`)**: A second linear transformation that keeps the dimension at 128. The block has no residual connection.

The forward pass takes an input tensor `x` of shape `[B, T, 25]` (batch × time steps × features) and returns a tensor of shape `[B, T, 128]`. LayerNorm is applied first to normalize inputs per time step before projection, which can improve convergence and generalization.
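
What normalizing "per time step" means can be shown on one feature vector. The sketch below mirrors the arithmetic of a layer normalization with its affine parameters at their initial values (scale 1, shift 0) and the biased variance estimate; the function name, the shortened 8-value vector, and the example numbers are assumptions.

```python
import math


def layer_norm(features, eps=1e-5):
    """Normalize one time step's feature vector to zero mean and unit variance."""
    mean = sum(features) / len(features)
    var = sum((f - mean) ** 2 for f in features) / len(features)  # biased variance
    return [(f - mean) / math.sqrt(var + eps) for f in features]


# Features on different scales, e.g. entropies near 1 next to negative log-probs.
step = [-0.2, 1.0, 0.9, 0.7, 0.1, -0.3, -1.5, -2.8]
normed = layer_norm(step)

mean_after = sum(normed) / len(normed)
var_after = sum((v - mean_after) ** 2 for v in normed) / len(normed)
print(round(mean_after, 6), round(var_after, 4))
```

Each time step is normalized using only its own features, so padding or other time steps in the batch do not affect the result.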

In [None]:
class TopQPooling(nn.Module):
    """
    Average top-q salient timesteps (Sec 3.2).
    [B,T,D], lengths → [B,D]
    """
    def __init__(self, q: float = 0.15):
        super().__init__()
        self.q = q

    def forward(self, H: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        B, T, D = H.shape
        mask   = torch.arange(T, device=H.device).unsqueeze(0) < lengths.unsqueeze(1)
        scores = H.norm(dim=2).masked_fill(~mask, -1e9)  # L2 norm + padding mask
        k_vals = (lengths.float() * self.q).ceil().long().clamp(min=1)

        pooled = torch.zeros(B, D, device=H.device)
        for b in range(B):
            top_idx = scores[b].topk(k_vals[b].item()).indices
            pooled[b] = H[b, top_idx].mean(dim=0)
        return pooled

The `TopQPooling` module implements top-`q` salient temporal pooling as described in Section 3.2.

Given an input tensor `H` of shape `[B, T, D]` (batch × time steps × hidden dimension) and a tensor `lengths` indicating the actual sequence lengths (excluding padding), the module selects a subset of the most salient timesteps and computes their average representation.

Operations:

1. **Masking**: A boolean mask identifies valid (non-padded) positions using the lengths.

2. **Scoring**: For each timestep, salience is measured as its L2 norm across the feature dimension (`H.norm(dim=2)`), resulting in a scalar salience score per timestep. Padded positions are assigned a very negative value (`-1e9`) so they are never selected.

3. **Top-k selection**: For each sequence `b`, the number of top timesteps to select, `k_vals[b]`, is computed as `ceil(q * length[b])`, where `q` (default 0.15) is the fraction of timesteps to retain. This value is clamped to at least 1.

4. **Pooling**: For each batch element, the indices of the top-`k_vals[b]` timesteps (by salience) are selected using `topk`. The corresponding hidden representations are averaged along the timestep dimension to produce a single pooled vector of size `D`.

The output tensor has shape `[B, D]`, where each row is the mean of the top-`q` most salient timesteps according to their L2 norms.

This operation focuses the representation on the most informative parts of each sequence, discarding padding and low-salience steps.
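
The four operations above can be followed on a single short sequence. The sketch below is a plain-Python version for one sequence whose padding has already been removed, so no mask is needed; `top_q_pool` and the example vectors are illustrative assumptions.

```python
import math


def top_q_pool(hidden, q=0.15):
    """Average the ceil(q * T) time steps with the largest L2 norm (at least one)."""
    k = max(1, math.ceil(len(hidden) * q))
    norms = [math.sqrt(sum(v * v for v in h)) for h in hidden]
    top = sorted(range(len(hidden)), key=lambda t: norms[t], reverse=True)[:k]
    dim = len(hidden[0])
    pooled = [sum(hidden[t][d] for t in top) / k for d in range(dim)]
    return pooled, k


# Four time steps with L2 norms 5, 1, 10 and 1.
hidden = [[3.0, 4.0], [0.0, 1.0], [6.0, 8.0], [1.0, 0.0]]

pooled, k = top_q_pool(hidden, q=0.5)       # keeps the two largest-norm steps
pooled_default, k_default = top_q_pool(hidden)  # q=0.15 keeps a single step
print(pooled, k)
print(pooled_default, k_default)
```

With `q=0.5` the steps with norms 10 and 5 are averaged; with the default `q=0.15` only the largest-norm step survives because `ceil(4 * 0.15) = 1`.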

In [None]:
class HALT(nn.Module):
    """
    HALT: Hallucination Assessment via Log-probs as Time series
    (Sec 3.2 + Appendix B)

    Input:
        x:       [B, T, D=25] padded feature sequences
        lengths: [B]          original sequence lengths
    Output:
        logits:  [B]          raw hallucination scores (apply sigmoid)
    """
    def __init__(self, config: HALTConfig):
        super().__init__()
        self.cfg = config
        self.projection = InputProjection(config.input_dim, config.proj_dim)

        # Bidirectional GRU encoder (Appendix B)
        self.gru = nn.GRU(
            input_size=config.proj_dim,
            hidden_size=config.hidden_dim,
            num_layers=config.num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=config.dropout if config.num_layers > 1 else 0.0,
        )
        gru_out_dim = config.hidden_dim * 2   # 512
        self.pooler = TopQPooling(q=config.top_q)
        self.classifier = nn.Linear(gru_out_dim, 1)

        # Proper weight initialization (Appendix B)
        self._init_weights()

    def _init_weights(self):
        # GRU weights: xavier for input-hidden, orthogonal for hidden-hidden
        for name, p in self.gru.named_parameters():
            if "weight_ih" in name:
                nn.init.xavier_uniform_(p)
            elif "weight_hh" in name:
                nn.init.orthogonal_(p)
            elif "bias" in name:
                nn.init.zeros_(p)

        # Classifier
        nn.init.xavier_uniform_(self.classifier.weight)
        if self.classifier.bias is not None:
            nn.init.zeros_(self.classifier.bias)

    def forward(self, x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape
        proj    = self.projection(x)                          # [B,T,128]
        lengths_cpu = lengths.cpu()  # pack_padded_sequence requires CPU

        # Sort by descending length (needed because enforce_sorted=True below)
        idx_sorted = torch.sort(lengths_cpu, descending=True).indices
        lengths_sorted = lengths_cpu[idx_sorted]
        proj_sorted    = proj[idx_sorted]

        packed  = pack_padded_sequence(proj_sorted, lengths_sorted,
                                       batch_first=True, enforce_sorted=True)
        out_pk, _ = self.gru(packed)
        H, _    = pad_packed_sequence(out_pk, batch_first=True,
                                      total_length=T)        # [B,T,512]

        # Restore original order for pooling & loss
        inv_idx = idx_sorted.argsort()
        H_orig  = H[inv_idx]

        pooled  = self.pooler(H_orig, lengths)               # [B,512]
        logits  = self.classifier(pooled).squeeze(1)         # [B]
        return logits


The `HALT` class implements the full hallucination assessment model, as described in Section 3.2 and Appendix B.

**Architecture overview**:  
- Input: padded sequence `x` of shape `[B, T, 25]` and corresponding `lengths` tensor (actual unpadded lengths per sequence).  
- Output: logits of shape `[B]`, representing unnormalized hallucination scores (sigmoid applied externally during inference).

**Components**:  

1. **Input Projection (`InputProjection`)**:  
   Maps the 25-dimensional input (5 engineered statistics plus 20 raw log-probabilities) to a 128-dimensional space via LayerNorm → Linear → GELU → Linear, applied independently at every time step.

2. **Bidirectional GRU Encoder**:  
   Processes sequences with a 5-layer bidirectional GRU (hidden size 256, dropout 0.4). `pack_padded_sequence` is called with `enforce_sorted=True`, which needs the batch ordered by descending length, and the original order must be restored afterwards, so the forward pass:  
   - Moves `lengths` to the CPU (where `pack_padded_sequence` expects them) and sorts the batch by descending length,  
   - Packs and passes through the GRU,  
   - Unpacks back to full length (`[B, T, 512]`, i.e. 2×256 for the bidirectional output),  
   - Restores original batch order via inverse indexing.

3. **Top-q Pooling (`TopQPooling`)**:  
   Selects the top 15% of timesteps (by L2 norm) per sequence and averages their GRU hidden states, producing a fixed-size `[B, 512]` representation.

4. **Classifier**:  
   A single linear layer maps the 512-dimensional pooled representation to a scalar logit.

**Weight Initialization**:  
- Input-to-hidden weights in the GRU use Xavier initialization, hidden-to-hidden weights use orthogonal initialization (to mitigate vanishing gradients), and biases are zero-initialized.  
- The classifier weights use Xavier initialization; bias is zero if present.

The model outputs raw logits, where higher values indicate stronger predicted hallucination likelihood.
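
The sort-and-restore step in the forward pass relies on the fact that the argsort of a permutation is its inverse. A minimal sketch with Python lists (the `argsort` helper and the toy batch are assumptions standing in for the tensor operations):

```python
def argsort(values, reverse=False):
    """Indices that would sort `values` (stable)."""
    return sorted(range(len(values)), key=lambda i: values[i], reverse=reverse)


lengths = [3, 7, 5]
batch = ["a", "b", "c"]  # stand-ins for the three sequences

idx_sorted = argsort(lengths, reverse=True)    # [1, 2, 0]
sorted_batch = [batch[i] for i in idx_sorted]  # ['b', 'c', 'a']

inv_idx = argsort(idx_sorted)                  # [2, 0, 1]
restored = [sorted_batch[i] for i in inv_idx]  # ['a', 'b', 'c']

print(idx_sorted, inv_idx, restored)
```

Restoring the order matters because the labels and the `lengths` passed to the pooling layer are still in the original batch order.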

## Training & Evaluation Utilities 
(Appendix B)

In [None]:
def train_epoch(model, loader, optimizer, criterion, max_grad_norm: float):
    model.train()
    total_loss, n = 0.0, 0
    for x_pad, labels, lengths in loader:
        x_pad = x_pad.to(CFG.device)
        labels = labels.to(CFG.device)
        lengths = lengths.to(CFG.device)

        optimizer.zero_grad()
        logits = model(x_pad, lengths)
        loss   = criterion(logits, labels)

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        total_loss += loss.item()
        n += 1
    return total_loss / max(n, 1)

The `train_epoch` function performs one full training pass over the dataset.

- Sets the model to training mode (`model.train()`), enabling operations like dropout.
- Initializes the cumulative loss and a batch counter.

For each batch `(x_pad, labels, lengths)` from the data loader:
- Moves all tensors to the configured device (GPU or CPU).
- Clears previous gradients with `optimizer.zero_grad()`.
- Computes model predictions (`logits`) by calling the model with padded inputs and lengths.
- Calculates loss between logits and labels using the provided criterion (e.g., BCEWithLogitsLoss).
- Performs backpropagation via `loss.backward()`.
- Clips gradients to the specified maximum norm (`max_grad_norm = 1.0`) to prevent exploding gradients.
- Updates model parameters via `optimizer.step()`.

Accumulates the per-batch loss and increments the counter. Returns the average training loss over all batches.

Note: Loss averaging uses `max(n, 1)` to avoid division by zero in edge cases where the loader yields no batches.
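
Gradient clipping by global norm rescales all gradients by a common factor when their combined L2 norm exceeds the limit, and leaves them alone otherwise. The sketch below shows that arithmetic on a flat list of gradient values; the function name and the small `eps` in the denominator are assumptions of this sketch, not a description of the library internals.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale all gradients by one factor so their global L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total_norm + eps))  # never scales gradients up
    return [g * scale for g in grads], total_norm


clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
untouched, _ = clip_by_global_norm([0.3, 0.4], max_norm=1.0)

norm_after = math.sqrt(sum(g * g for g in clipped))
print(norm_before, round(norm_after, 4), untouched)
```

The direction of the gradient is preserved; only its length changes.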

In [None]:
@torch.no_grad()
def evaluate(model, loader, criterion, device, threshold=0.5):
    model.eval()
    all_logits, all_labels = [], []
    total_loss, n = 0.0, 0

    for x_pad, labels, lengths in loader:
        x_pad = x_pad.to(device)
        labels = labels.to(device)
        lengths = lengths.to(device)

        logits = model(x_pad, lengths)
        loss   = criterion(logits, labels)

        total_loss += loss.item()
        n += 1
        all_logits.extend(logits.cpu().tolist())
        all_labels.extend(labels.cpu().int().tolist())

    probs    = torch.sigmoid(torch.tensor(all_logits)).numpy()
    preds    = (probs >= threshold).astype(int)
    labels_np = np.array(all_labels)

    macro_f1  = f1_score(labels_np, preds, average="macro", zero_division=0)
    accuracy  = (preds == labels_np).mean()
    try:
        auroc   = roc_auc_score(labels_np, probs)
    except ValueError:  # e.g., single class
        auroc   = 0.5

    return {
        "loss": total_loss / max(n, 1),
        "macro_f1": macro_f1,
        "auroc": auroc,
        "accuracy": accuracy,
    }



The `evaluate` function computes model performance metrics on a validation or test dataset in evaluation mode.

- Disables gradient computation (`@torch.no_grad()`) and sets the model to evaluation mode (`model.eval()`), turning off dropout and other training-specific behaviors.

- Iterates through the data loader, moving inputs, labels, and lengths to the specified device.
- Computes logits and loss (without updating model weights).
- Stores all logits and true labels on CPU as Python lists.

After collecting all predictions:
- Applies the sigmoid function to logits to obtain probabilities.
- Converts probabilities to binary predictions using a threshold (default 0.5).
- Computes four metrics:
  - **Average loss** over the dataset,
  - **Macro F1 score**, averaging per-class F1 and handling zero-division cases by returning 0,
  - **AUROC** (Area Under the ROC Curve); defaults to 0.5 if evaluation is impossible (e.g., labels are all one class),
  - **Accuracy**, the proportion of correct binary predictions.

Returns a dictionary containing these metrics.
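
For readers unfamiliar with macro-F1, the sketch below computes it by hand for the binary case: F1 is calculated separately for class 0 and class 1 and the two values are averaged with equal weight. The helper names are assumptions, and the sketch always averages over both classes, which can differ from a library implementation when one class is entirely absent.

```python
def f1_for_class(labels, preds, cls):
    """F1 for one class, returning 0.0 when the score is undefined."""
    tp = sum(1 for y, p in zip(labels, preds) if y == cls and p == cls)
    fp = sum(1 for y, p in zip(labels, preds) if y != cls and p == cls)
    fn = sum(1 for y, p in zip(labels, preds) if y == cls and p != cls)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def macro_f1(labels, preds, classes=(0, 1)):
    """Unweighted mean of the per-class F1 scores."""
    return sum(f1_for_class(labels, preds, c) for c in classes) / len(classes)


probs = [0.9, 0.4, 0.2, 0.1]
labels = [1, 1, 0, 0]
preds = [int(p >= 0.5) for p in probs]  # threshold at 0.5, as in evaluate()

score = macro_f1(labels, preds)
print(preds, round(score, 4))
```

Here class 1 scores 2/3 (one of two positives found, no false alarms) and class 0 scores 0.8, giving a macro-F1 of about 0.733.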

In [None]:

def train_halt(
    model,
    train_loader,
    val_loader,
    optimizer,
    scheduler,
    criterion,
    cfg: HALTConfig,
) -> float:
    """Training loop with early stopping & LR scheduling."""
    best_f1, best_state, patience_ctr = -float("inf"), None, 0

    for epoch in range(1, cfg.max_epochs + 1):
        train_loss = train_epoch(model, train_loader, optimizer, criterion,
                                 cfg.max_grad_norm)

        val_metrics = evaluate(model, val_loader, criterion, CFG.device)
        scheduler.step(val_metrics["macro_f1"])

        if epoch % 5 == 0 or epoch == 1:
            print(
                f"Epoch {epoch:3d}/{cfg.max_epochs} | "
                f"train_loss={train_loss:.4f}, val_loss={val_metrics['loss']:.4f} | "
                f"macro-F1={val_metrics['macro_f1']:.4f}, AUROC={val_metrics['auroc']:.4f}"
            )

        # Early stopping based on macro-F1
        if val_metrics["macro_f1"] > best_f1:
            best_f1 = val_metrics["macro_f1"]
            best_state = {k: v.clone().cpu() for k, v in model.state_dict().items()}
            patience_ctr = 0
        else:
            patience_ctr += 1
            if patience_ctr >= cfg.early_stop_patience:
                print(f"Early stopping at epoch {epoch}. Best val macro-F1: {best_f1:.4f}")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return best_f1


The `train_halt` function executes the complete training loop for the HALT model, incorporating learning rate scheduling and early stopping.

- Initializes tracking variables:  
  - `best_f1` (best macro-F1 score seen, initialized to negative infinity),  
  - `best_state` (to store the model weights corresponding to the best macro-F1),  
  - `patience_ctr` (counts consecutive epochs without improvement).

- For each epoch from 1 to `cfg.max_epochs`:  
  - Trains for one epoch using `train_epoch`, passing the gradient norm limit from config.  
  - Evaluates on the validation set using `evaluate`, returning metrics including loss and macro-F1.  
  - Updates the learning rate scheduler based on validation macro-F1 (assumes a metric-aware scheduler like `ReduceLROnPlateau`).  
  - Every 5 epochs (and always at epoch 1), prints training progress: epoch number, training loss, validation loss, macro-F1, and AUROC.  
  - Checks for improvement in macro-F1:  
    - If improved, updates `best_f1`, saves a CPU copy of the current model state dictionary, and resets patience counter.  
    - If not improved, increments patience; when it reaches `cfg.early_stop_patience` (15 by default), prints a message and breaks out of training loop.

- After training, loads the best-performing model weights (from epoch with highest macro-F1) back into the model.  
- Returns the best validation macro-F1 score observed.

Together, early stopping guards against overfitting, and the scheduler lowers the learning rate when validation macro-F1 plateaus.
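The patience rule described above can be replayed on a plain list of scores. This is a minimal sketch rather than the notebook's training loop: `run_early_stopping` and the example scores are hypothetical, and no model or optimizer is involved.

```python
from typing import List, Tuple


def run_early_stopping(val_scores: List[float], patience: int) -> Tuple[float, int, int]:
    """Replay the patience rule on per-epoch validation scores.

    Returns (best_score, best_epoch, last_epoch_run), with epochs counted from 1.
    """
    best_score, best_epoch, patience_ctr = -float("inf"), 0, 0
    last_epoch = 0
    for epoch, score in enumerate(val_scores, start=1):
        last_epoch = epoch
        if score > best_score:
            # Strict improvement: remember this epoch and reset the counter.
            best_score, best_epoch, patience_ctr = score, epoch, 0
        else:
            patience_ctr += 1
            if patience_ctr >= patience:
                break  # stop here; the weights from best_epoch would be restored
    return best_score, best_epoch, last_epoch


scores = [0.60, 0.72, 0.71, 0.72, 0.70, 0.69, 0.80]
print(run_early_stopping(scores, patience=3))   # → (0.72, 2, 5)
```

Note that a tie (epoch 4 repeats 0.72) does not count as an improvement, so training stops at epoch 5 and never reaches the later 0.80. A larger patience trades extra epochs for a chance to escape such plateaus.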

## Inference & Interpretability 
(Appendix C.2)

In [None]:
FEATURE_NAMES = (
    ["avg_logp", "rank_proxy", "h_overall", "h_alts", "delta_h_dec"]
    + [f"logprob_{i+1}" for i in range(20)]
)



def predict_hallucination(logprobs_matrix: np.ndarray, model: nn.Module,
                          device: str = CFG.device, threshold: float = 0.5) -> dict:
    """
    Predict hallucination for a single LLM response.

    Args:
        logprobs_matrix: [T, top_k], top-k log-probs per token.
                         Column 0 = selected token, rest = alternatives sorted desc.

    Returns:
        dict with prob_hallucinated, prediction, logit
    """
    model.eval()
    x = build_halt_input(torch.from_numpy(logprobs_matrix.astype(np.float32)))
    x = x.unsqueeze(0).to(device)                      # [1,T,25]
    lengths = torch.tensor([logprobs_matrix.shape[0]], dtype=torch.long).to(device)

    with torch.no_grad():
        logit = model(x, lengths).item()
    prob = torch.sigmoid(torch.tensor(logit)).item()

    return {"prob_hallucinated": prob, "prediction": int(prob >= threshold), "logit": logit}

The `predict_hallucination` function runs inference on a single LLM response to assess whether it contains hallucinations.

**Input**:  
- `logprobs_matrix`: a 2D NumPy array of shape `[T, top_k]`, where each row contains the top-`K` log-probabilities for a token, with column 0 holding the selected (sampled) token and columns 1 onward holding alternatives sorted in descending order.  
- `model`: the trained HALT model.  
- Optional arguments: computation device (`device`) and decision threshold (default 0.5).

**Process**:  
- Sets the model to evaluation mode.  
- Converts the NumPy input to a PyTorch tensor of type float32 and passes it through `build_halt_input` to construct the full 25-dimensional input (5 engineered features + 20 log-probabilities).  
- Adds a batch dimension to form shape `[1, T, 25]`.  
- Constructs a length tensor with value `[T]` (the actual sequence length).  
- Moves both tensors to the specified device.  
- Disables gradient computation, runs forward pass through the model to obtain logits, and converts the scalar logit to a Python float.  
- Applies sigmoid to convert the logit into a probability of hallucination.

**Output**:  
A dictionary with three keys:  
- `"prob_hallucinated"`: estimated probability (float between 0 and 1),  
- `"prediction"`: binary decision (0 for no hallucination, 1 for hallucination),  
- `"logit"`: raw model output before sigmoid.  

This enables straightforward interpretation and use of the model's confidence and decision for individual samples.
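The last two steps, sigmoid then threshold, can be illustrated without the model. This is a small sketch in which `decide` is a hypothetical helper and the logits are made-up values, not HALT outputs.

```python
import math


def decide(logit: float, threshold: float = 0.5) -> dict:
    """Map a raw logit to a probability and a thresholded 0/1 decision."""
    # Numerically stable logistic sigmoid (avoids overflow for large |logit|).
    if logit >= 0:
        prob = 1.0 / (1.0 + math.exp(-logit))
    else:
        e = math.exp(logit)
        prob = e / (1.0 + e)
    return {"prob_hallucinated": prob, "prediction": int(prob >= threshold), "logit": logit}


for logit in (-2.0, 0.0, 2.0):
    print(decide(logit))

# A stricter threshold flags fewer responses: sigmoid(1.0) is about 0.73.
print(decide(1.0, threshold=0.8)["prediction"])   # → 0
```

With the default threshold of 0.5, the decision is equivalent to checking whether the logit is non-negative. Raising the threshold trades recall for precision on the hallucinated class.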

In [None]:
def compute_feature_importance(model, loader, device, n_features=25) -> Dict[str, float]:
    """
    Feature importance via Gradient × Input (Appendix C.2).
    Safely handles model grad state.
    """
    # 🔑 Ensure model is in train mode AND parameters require grad
    model.train()
    
    for p in model.parameters():
        if not p.requires_grad:
            print("⚠️  Enabling grad on parameter:", p.shape)
        p.requires_grad_(True)

    feat_imp = torch.zeros(n_features, device=device)

    for x_pad, labels, lengths in loader:
        # Ensure x_pad requires grad
        x_pad = torch.nn.Parameter(x_pad.to(device), requires_grad=True)
        labels = labels.to(device)

        logits = model(x_pad, lengths)  # [B]
        
        # Compute gradient of sum (scalar) w.r.t. inputs
        loss = logits.sum()
        grads = torch.autograd.grad(
            outputs=loss,
            inputs=x_pad,
            retain_graph=False,
            create_graph=False
        )[0]

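        B, T, D = grads.shape
        # The mask is built on `device`, so lengths must live there too
        lengths_dev = lengths.to(device)
        mask    = (torch.arange(T, device=device).unsqueeze(0) < lengths_dev.unsqueeze(1)).unsqueeze(2)
        # detach() keeps the accumulated scores out of the autograd graph
        grad_input = (grads.abs() * x_pad.detach().abs()) * mask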

        feat_imp += grad_input.sum(dim=(0, 1))
    
    model.eval()
    feat_imp /= feat_imp.sum() + EPS
    importance_dict = dict(zip(FEATURE_NAMES, feat_imp.tolist()))
    return dict(sorted(importance_dict.items(), key=lambda kv: kv[1], reverse=True))




The `compute_feature_importance` function estimates feature importance using the *Gradient × Input* method, as described in Appendix C.2.

- Sets the model to training mode and ensures all parameters have gradient computation enabled (`requires_grad=True`), as importance is derived from gradients w.r.t. inputs.

- Initializes a tensor `feat_imp` of zeros (length 25) on the specified device to accumulate importance scores across batches.

For each batch in the data loader:
- Wraps the padded input tensor `x_pad` as a `torch.nn.Parameter` with `requires_grad=True`, enabling gradient computation.
- Moves labels to the device and computes model logits.
- Defines a scalar loss as the sum of logits, then computes gradients of this loss w.r.t. the input `x_pad` using `torch.autograd.grad`, producing a gradient tensor of shape `[B, T, D]`.

- Constructs a boolean mask to identify valid (non-padded) timesteps using `lengths`.  
- Computes element-wise absolute gradient × absolute input (`grad_input`), masked to exclude padding. This is the unsigned form of the Gradient × Input saliency measure: a feature scores highly only when both its gradient and its input magnitude are large.

- Aggregates `grad_input` across batch and time dimensions (sum over axes 0 and 1), accumulating per-feature importance.

After processing all batches:
- Sets the model back to evaluation mode.  
- Normalizes accumulated importance scores so they sum to 1 (using `EPS = 1e-9` for numerical stability).  
- Zips the normalized scores with `FEATURE_NAMES` (the ordered list of feature names: 5 engineered features followed by `logprob_1` through `logprob_20`) to form a dictionary.  
- Sorts the dictionary in descending order by importance value and returns it.

The result is a mapping from feature name to normalized importance score, allowing interpretation of which inputs most strongly influenced hallucination predictions.
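To make the accumulation concrete, here is a toy version with an analytic gradient. It is a sketch under stated assumptions: the scorer is a hypothetical linear model, `logit_b = sum over valid t of w · x[b][t]`, not HALT, so the gradient of the summed logits with respect to `x[b][t][d]` is simply `w[d]` and no autograd is needed.

```python
from typing import List


def grad_x_input_linear(x: List[List[List[float]]], lengths: List[int],
                        w: List[float], eps: float = 1e-9) -> List[float]:
    """Normalized |gradient| * |input| importance for a toy linear scorer."""
    n_features = len(w)
    importance = [0.0] * n_features
    for sample, length in zip(x, lengths):
        for t, step in enumerate(sample):
            if t >= length:
                continue  # padding mask: ignore timesteps beyond the true length
            for d in range(n_features):
                # Gradient w.r.t. x[b][t][d] is w[d] for this scorer.
                importance[d] += abs(w[d]) * abs(step[d])
    total = sum(importance) + eps
    return [v / total for v in importance]


# Two sequences padded to T=3; the second has only 2 valid steps.
x = [
    [[1.0, 2.0], [1.0, 0.0], [1.0, 2.0]],
    [[-1.0, 1.0], [1.0, 1.0], [9.0, 9.0]],   # last step is padding
]
scores = grad_x_input_linear(x, lengths=[3, 2], w=[2.0, -1.0])
print([round(s, 3) for s in scores])   # → [0.625, 0.375]
```

The padded step `[9.0, 9.0]` does not contribute, which is the role the boolean mask plays in the batched tensor version.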

## Shape sanity & backward pass check (debug/coverage)

In [None]:
def _shape_sanity_check(model, cfg: HALTConfig):
    B_t, T_t = 4, 80
    x   = torch.randn(B_t, T_t, cfg.input_dim).to(CFG.device)
    lens = torch.tensor([80, 65, 40, 20], dtype=torch.long).to(CFG.device)
    lbls = torch.zeros(B_t).to(CFG.device)

    model.train()      # cuDNN RNN backward is only supported in training mode
    model.zero_grad()  # start from clean gradients so the check below is meaningful
    out = model(x, lens)
    loss = nn.BCEWithLogitsLoss()(out, lbls)
    loss.backward()

    has_grads = all(p.grad is not None for p in model.parameters() if p.requires_grad)
    model.eval()
    print(f"  Input shape:   {list(x.shape)}")
    print(f"  Output shape:  {list(out.shape)}")
    print(f"  Loss:          {loss.item():.4f}")
    print(f"  Gradients OK? {has_grads} {'✓' if has_grads else '✗'}")



The `_shape_sanity_check` function verifies that the model, data pipeline, and loss computation are correctly integrated and support gradient flow.

- Constructs a small dummy batch:  
  - Input `x` of shape `[4, 80, input_dim]`, with random values,  
  - Sequence `lengths` set to `[80, 65, 40, 20]` (decreasing to match the sorting requirement of `pack_padded_sequence`),  
  - Binary labels initialized to zeros (`[4]`).

- Runs the model on this batch to obtain outputs, computes binary cross-entropy loss between outputs and labels, then calls `loss.backward()` to propagate gradients.

- Checks whether all trainable parameters (those with `requires_grad=True`) have non-`None` gradients after backpropagation.

- Prints:  
  - Input tensor shape,  
  - Output tensor shape (should be `[4]`),  
  - Computed loss value (rounded to four decimals),  
  - A verification status indicating whether gradients were computed successfully.

This check ensures the model runs end-to-end without shape mismatches, loss computation errors, or broken gradient paths, and confirms that the training setup is ready for actual use.

## Visualisations

In [None]:
plt.switch_backend('agg')  # Non-interactive backend

def plot_feature_importance(
    importance_dict: dict,
    output_path: str = "feature_importance.png",
    top_k: int = 10,
    title: str = "Feature Importance (Gradient × Input)",
):
    """Bar plot of feature importance."""
    features = list(importance_dict.keys())
    scores   = list(importance_dict.values())

    # Sort descending
    indices = np.argsort(scores)[::-1][:top_k]
    features_top = [features[i] for i in indices]
    scores_top   = [scores[i] for i in indices]

    # Normalize
    scores_top = minmax_scale(scores_top, feature_range=(0.2, 1.0))

    fig, ax = plt.subplots(figsize=(9, 5))
    bars = ax.barh(features_top[::-1], scores_top[::-1],
                   color="#2C7BB6", edgecolor="black", linewidth=0.5)

    ax.set_xlabel("Normalized Importance Score")
    ax.set_title(title, fontsize=12, pad=10)
    ax.grid(axis="x", alpha=0.3)
    ax.tick_params(axis='y', labelsize=10)

    # Annotate
    for i, v in enumerate(scores_top[::-1]):
        ax.text(v + 0.02, i, f"{v:.3f}", va="center", fontsize=9)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    print(f"✅ Saved feature importance plot → {output_path}")
    plt.close()

The `plot_feature_importance` function creates a horizontal bar chart of the top-k features by importance.

- Inputs:  
  - `importance_dict`: dictionary mapping feature names to normalized scores,  
  - `output_path`: filename for saving the plot (default `"feature_importance.png"`),  
  - `top_k`: number of top features to display (default 10),  
  - `title`: plot title.

- It extracts keys and values, sorts features in descending order of importance, then selects the top `top_k`.  
- Scores are rescaled to the range [0.2, 1.0] using min-max scaling so that even the smallest of the displayed features keeps a visible bar.  
- A horizontal bar chart is drawn: features ordered from highest (top) to lowest (bottom), with bars colored in blue (`#2C7BB6`) and black edges.  
- An x-axis label, title, and vertical grid lines (`axis="x"`) are added for readability; y-axis tick labels use font size 10.  
- Each bar is annotated with its rescaled score (three decimals), positioned slightly to the right of the bar. Because of the rescaling, these numbers reflect relative standing among the displayed features rather than the original importance values.  
- The layout is tightened and the figure is saved at 300 DPI with a tight bounding box so that labels are not clipped.  
- The figure is closed to free memory, and a success message prints.

This visualization highlights the relative contribution of each input feature to hallucination prediction.
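The effect of the rescaling step is easy to see on a handful of numbers. This is a plain-Python sketch of a linear min-max map; `rescale` is a hypothetical stand-in written for illustration, and its handling of the all-equal case is an assumption rather than a statement about `minmax_scale`.

```python
from typing import List


def rescale(values: List[float], lo: float = 0.2, hi: float = 1.0) -> List[float]:
    """Linearly map values so that the minimum lands on lo and the maximum on hi."""
    v_min, v_max = min(values), max(values)
    if v_max == v_min:
        return [lo for _ in values]  # degenerate case: all values equal
    return [lo + (v - v_min) * (hi - lo) / (v_max - v_min) for v in values]


top_scores = [0.50, 0.30, 0.20]   # made-up importance values
print([round(v, 3) for v in rescale(top_scores)])   # → [1.0, 0.467, 0.2]
```

The largest score always maps to 1.0 and the smallest to 0.2, so bar lengths show ordering and relative spacing but not the original magnitudes.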

In [None]:

def plot_feature_correlation(
    model,
    loader,
    device,
    output_path: str = "feature_correlation.png"
):
    """
    Correlation heatmap between engineered features and model logits.
    Uses gradient-safe input path via compute_feature_importance pattern.
    """
    # Training mode mirrors compute_feature_importance; gradients are not used below
    model.train()
    X, Y = [], []

    with torch.enable_grad():
        for x_pad, labels, lengths in loader:
            # Wrap input in Parameter (like compute_feature_importance)
            x_pad = torch.nn.Parameter(x_pad.to(device), requires_grad=True)
            logits = model(x_pad, lengths)

            # Use only first 5 engineered features (avg over time)
            feat_means = x_pad[:, :, :5].mean(dim=1).detach().cpu().numpy()
            X.append(feat_means)

            Y.extend(logits.detach().cpu().numpy())

    model.eval()

    X = np.vstack(X)
    Y = np.array(Y)

    corr_matrix = []
    for i in range(5):
        feat_i = X[:, i]
        corr = np.corrcoef(feat_i, Y)[0, 1]
        corr_matrix.append(corr)

    fig, ax = plt.subplots(figsize=(5.5, 3.2))
    im = ax.imshow([corr_matrix], cmap="RdBu_r", vmin=-1, vmax=1)

    ax.set_yticks([0])
    ax.set_yticklabels(["Logits"], fontsize=10)
    ax.set_xticks(range(5))
    ax.set_xticklabels([
        "AvgLogP", "RankProxy",
        "Hoverall", "Halts", "ΔHdec"
    ], fontsize=10)
    ax.set_title("Feature ↔ Hallucination Logits Correlation", pad=8)

    cbar = plt.colorbar(im, shrink=0.75)
    cbar.set_label("Pearson Correlation", fontsize=9)

    # Annotations
    for i in range(5):
        color = "black" if abs(corr_matrix[i]) < 0.7 else "white"
        ax.text(i, 0, f"{corr_matrix[i]:+.2f}", ha="center", va="center",
                color=color, fontsize=10)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300)
    print(f"✅ Saved correlation heatmap → {output_path}")
    plt.close()



The `plot_feature_correlation` function generates a heatmap showing Pearson correlations between the five engineered features and model logits (hallucination scores).

- In training mode with gradient enabled, it iterates through the data loader, wrapping inputs in `torch.nn.Parameter` to preserve gradient compatibility.
- For each batch, it extracts the first five columns of the input tensor (engineered features: AvgLogP, RankProxy, h_overall, h_alts, delta_h_dec), averages them across time steps to obtain per-sample feature values, and collects logits.
- After gathering all data, the model is set back to evaluation mode.

Pearson correlations are then computed: for each of the five engineered features, the correlation between its per-sample time average and the logit is calculated and stored in a list of five coefficients.

A 1×5 heatmap is plotted:
- Uses diverging colormap (`RdBu_r`) with range [−1, 1] for intuitive interpretation of positive/negative correlations.
- Y-axis labels the single row as “Logits”.
- X-axis labels each feature with abbreviated names (e.g., “Hoverall”, “Halts”, “ΔHdec”).
- A colorbar indicates correlation strength.
- Correlation values are annotated on the heatmap: black text for low absolute correlation (|r| < 0.7), white for higher magnitude to ensure readability.

The figure is saved at 300 DPI, and a confirmation message prints upon completion.
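For reference, the coefficient in each heatmap cell is the standard Pearson correlation. The sketch below computes it in plain Python on made-up numbers; `pearson` is a hypothetical helper, and the feature and logit values are illustrative, not notebook results.

```python
import math
from typing import Sequence


def pearson(a: Sequence[float], b: Sequence[float]) -> float:
    """Pearson correlation of two equal-length sequences.

    Raises ZeroDivisionError if either sequence is constant.
    """
    n = len(a)
    mean_a, mean_b = sum(a) / n, sum(b) / n
    cov = sum((u - mean_a) * (v - mean_b) for u, v in zip(a, b))
    var_a = sum((u - mean_a) ** 2 for u in a)
    var_b = sum((v - mean_b) ** 2 for v in b)
    return cov / math.sqrt(var_a * var_b)


# Per-sample time averages of one feature, paired with the model's logits.
avg_logp = [-0.2, -0.5, -1.1, -1.6]
logits = [-2.0, -1.0, 1.5, 2.5]
print(f"{pearson(avg_logp, logits):+.2f}")   # → -0.99
```

A value near −1 here would mean that lower average log-probability goes with higher hallucination logits. Correlation over time-averaged features only captures linear, sequence-level relationships.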

In [None]:

def extract_embeddings_for_t_sne(model, loader, device, n_samples=2000):
    """
    Extract pooled embeddings (detached, on CPU) together with their true labels.
    Uses same forward as compute_feature_importance to ensure consistency.
    """
    model.train()  # same mode as compute_feature_importance; gradients stay off below
    all_embs, all_labels = [], []

    with torch.no_grad():  # Safe: embeddings don't need gradients
        for x_pad, labels, lengths in loader:
            # Use same input as feature importance: wrap in Parameter but detach
            x_pad = torch.nn.Parameter(x_pad.to(device), requires_grad=False)  # no grad needed here
            logits = model(x_pad, lengths)

            # Extract pooled embeddings via same pipeline:
            proj = model.projection(x_pad)
            lengths_cpu = lengths.cpu()
            idx_sorted  = torch.sort(lengths_cpu, descending=True).indices
            proj_sorted = proj[idx_sorted]
            lengths_sorted = lengths_cpu[idx_sorted]

            packed  = torch.nn.utils.rnn.pack_padded_sequence(
                proj_sorted, lengths_sorted, batch_first=True, enforce_sorted=True
            )
            out_pk, _ = model.gru(packed)
            H, _    = torch.nn.utils.rnn.pad_packed_sequence(out_pk, batch_first=True)
            inv_idx = idx_sorted.argsort()
            H_orig  = H[inv_idx]

            pooled = model.pooler(H_orig, lengths)
            all_embs.append(pooled.detach().cpu().numpy())
            # Use true labels from dataset (not prediction)
            all_labels.extend(labels.int().cpu().tolist())

            if len(all_embs) * loader.batch_size >= n_samples:
                break

    model.eval()
    embeddings = np.vstack(all_embs)
    labels_arr   = np.array(all_labels[:len(embeddings)])
    return embeddings, labels_arr



The `extract_embeddings_for_t_sne` function extracts contextual embeddings from the model for downstream dimensionality reduction (e.g., t-SNE), ensuring consistency with the model's internal processing pipeline.

- Sets the model to training mode (to match conditions used during gradient-based analyses, though `torch.no_grad()` is later applied since embeddings are not needed for backprop).
- Iterates through the data loader, wrapping inputs as `Parameter` with gradient disabled (`requires_grad=False`) for safety.
- Performs the same forward path used in `HALT.forward`:  
  - Input projection (`projection` layer),  
  - Sorting by sequence length and packing,  
  - Bidirectional GRU encoding,  
  - Unpacking and restoring original order,  
  - Top-`q` pooling (via `model.pooler`) to obtain pooled representations.
- Embeddings and true labels are collected on CPU in NumPy format.  
- Loop breaks once at least `n_samples` (default 2000) samples are collected, in case the loader yields more.

- Returns two NumPy arrays:  
  - `embeddings`: shape `[N, D]` (D = 512),  
  - `labels_arr`: binary labels matching the embedded samples.

This ensures embeddings reflect exactly how they are used in the model, without modification or approximation.
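The sort-then-restore step is the part most prone to off-by-one mistakes, so here it is in isolation. This is a minimal plain-Python sketch in which `sort_by_length` is a hypothetical helper and strings stand in for per-sample tensors.

```python
from typing import List, Tuple


def sort_by_length(lengths: List[int]) -> Tuple[List[int], List[int]]:
    """Return (idx_sorted, inv_idx) for a descending-length sort."""
    idx_sorted = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    # inv_idx[i] is the position of original sample i inside the sorted batch,
    # i.e. the argsort of idx_sorted.
    inv_idx = sorted(range(len(idx_sorted)), key=lambda j: idx_sorted[j])
    return idx_sorted, inv_idx


lengths = [40, 80, 20, 65]
batch = ["a", "b", "c", "d"]

idx_sorted, inv_idx = sort_by_length(lengths)
sorted_batch = [batch[i] for i in idx_sorted]      # longest sequence first
restored = [sorted_batch[j] for j in inv_idx]      # back to the original order

print(idx_sorted)     # → [1, 3, 0, 2]
print(sorted_batch)   # → ['b', 'd', 'a', 'c']
print(restored)       # → ['a', 'b', 'c', 'd']
```

Indexing the sorted batch with the argsort of the sort indices undoes the permutation, which is what `H[inv_idx]` does for the GRU outputs.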

In [None]:

def plot_t_sne_embeddings(
    model,
    loader,
    device,
    output_path: str = "tsne_embeddings.png",
    n_samples: int = 1500,
    perplexity: float = 30.0
):
    """
    T-SNE visualization of GRU embeddings, colored by **true class** (hallucinated vs correct).
    Uses same data loading and embedding extraction logic as feature importance.
    """
    print(f"🧠 Extracting {n_samples} embeddings for T-SNE...")
    embeddings, labels = extract_embeddings_for_t_sne(model, loader, device, n_samples=n_samples)

    # Subsample if needed (T-SNE O(N^2))
    if len(embeddings) > n_samples:
        idx = np.random.choice(len(embeddings), size=n_samples, replace=False)
        embeddings = embeddings[idx]
        labels     = labels[idx]

    # T-SNE
    tsne = TSNE(n_components=2, perplexity=perplexity,
                random_state=42, max_iter=500, metric="euclidean")
    embs_2d = tsne.fit_transform(embeddings)

    # Plot
    fig, ax = plt.subplots(figsize=(6.5, 5))

    # Label convention: 0 = correct, 1 = hallucinated (as in predict_hallucination)
    class_names = ["Correct", "Hallucinated"]
    colors      = ["#2B83BA", "#D7191C"]

    for cls in [0, 1]:
        mask = labels == cls
        ax.scatter(
            embs_2d[mask, 0], embs_2d[mask, 1],
            c=colors[cls], label=class_names[cls],
            alpha=0.6, s=25, edgecolor="none"
        )

    ax.set_xlabel("T-SNE Component 1", fontsize=10)
    ax.set_ylabel("T-SNE Component 2", fontsize=10)
    ax.set_title("GRU Embedding Space (T-SNE Projection)", fontsize=12, pad=10)
    ax.legend(loc="best", fontsize=9)

    # Remove ticks
    ax.set_xticks([])
    ax.set_yticks([])

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    print(f"✅ Saved T-SNE plot → {output_path}")
    plt.close()



The `plot_t_sne_embeddings` function visualizes the model's GRU-derived embeddings in 2D using t-SNE, colored by ground-truth class (hallucinated vs. correct).

- Calls `extract_embeddings_for_t_sne` to obtain embeddings and true labels for up to `n_samples` (default 1500). If more are collected, random subsampling is applied to match `n_samples`, ensuring computational feasibility for t-SNE (quadratic complexity).

- Applies t-SNE with:  
  - 2 components,  
  - Perplexity (default 30.0),  
  - Fixed random seed (42) for reproducibility,  
  - Maximum 500 iterations,  
  - Euclidean distance metric.

- Produces a scatter plot:  
  - Hallucinated samples (label=1) in red (`#D7191C`),  
  - Correct samples (label=0) in blue (`#2B83BA`).  
  - Points are semi-transparent (alpha=0.6) with size 25 and no edge color.

- Axis labels and title describe the visualization; legend is placed automatically. Ticks are removed for clarity.

- The figure is saved at 300 DPI with tight bounding box, and a success message prints.

This plot provides intuition about whether hallucinated and correct samples separate in the embedding space, revealing model behavior at the representation level.

In [None]:

def plot_sensitivity_analysis(
    model,
    loader,
    device,
    n_samples: int = 50,
    output_path: str = "sensitivity_analysis.png"
):
    """Perturbation sensitivity of the logit to noise in each of the 5 engineered features."""
    # Eval mode, so that dropout (if any) does not add noise between forward passes
    model.eval()

    results = {f: [] for f in range(5)}

    with torch.no_grad():  # forward passes only; no gradients needed
        for x_pad, _, lengths in loader:
            if len(results[0]) >= n_samples * 5:
                break

            x_clean = x_pad.to(device)
            lengths_dev = lengths.to(device)

            # Baseline logits (without perturbation)
            base_logits = model(x_clean, lengths_dev)

            for feat_idx in range(5):
                # Perturb one feature column with Gaussian noise at half its per-sequence std
                x_pert = x_clean.clone()
                std_feat = x_pert[:, :, feat_idx].std(dim=1, keepdim=True) + 1e-8
                noise = torch.randn_like(x_pert[:, :, feat_idx]) * std_feat * 0.5
                x_pert[:, :, feat_idx] += noise

                pert_logits = model(x_pert, lengths_dev)

                delta_logits = torch.abs(pert_logits - base_logits).cpu().numpy()
                results[feat_idx].extend(delta_logits.tolist())

    mean_delta = [np.mean(results[i]) for i in range(5)]
    std_delta  = [np.std(results[i]) for i in range(5)]

    features_names = ["AvgLogP", "RankProxy", "Hoverall", "Halts", "ΔHdec"]
    fig, ax = plt.subplots(figsize=(8, 4))

    x_pos = np.arange(len(features_names))
    ax.bar(x_pos, mean_delta, yerr=std_delta,
           color="#1B9E77", capsize=5, alpha=0.8, edgecolor="black")

    ax.set_xticks(x_pos)
    ax.set_xticklabels(features_names, fontsize=10)
    ax.set_ylabel("Logit Δ (Perturbation Sensitivity)", fontsize=10)
    ax.set_title("Sensitivity Analysis: Engineered Features", pad=8)
    ax.grid(axis="y", alpha=0.3)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300)
    print(f"✅ Saved sensitivity plot → {output_path}")
    plt.close()


# Run Forrest run

## Prep work 

In [None]:

cfg = HALTConfig()
print(f"\n⚙️ Configuration:\n{cfg}")
print(f"💾 Device: {CFG.device}")

# ── Dataset
full_ds = SyntheticLogProbDataset(n_samples=4000, top_k=cfg.top_k, seed=SEED)
n_total = len(full_ds)
n_train = int(0.70 * n_total)
n_val   = int(0.15 * n_total)
n_test  = n_total - n_train - n_val

train_set, val_set, test_set = random_split(
    full_ds,
    [n_train, n_val, n_test],
    generator=torch.Generator().manual_seed(SEED),
)

loader_kw = dict(collate_fn=collate_fn, num_workers=0)
train_loader = DataLoader(train_set, batch_size=cfg.batch_size,
                            shuffle=True,  **loader_kw)
val_loader   = DataLoader(val_set,   batch_size=cfg.batch_size,
                            shuffle=False, **loader_kw)
test_loader  = DataLoader(test_set,  batch_size=cfg.batch_size,
                            shuffle=False, **loader_kw)

print(f"\n📊 Dataset split:")
print(f"  Train: {len(train_set):>5} | Val: {len(val_set):>5} | Test: {len(test_set):>5}")


Configuration is instantiated from `HALTConfig`, and key settings, including the computing device, are printed.

A synthetic dataset of 4000 samples is generated using `SyntheticLogProbDataset`. The dataset is split into training (70%), validation (15%), and test (15%) subsets using `random_split`, with shuffling controlled by a deterministic random generator seeded at 42 to ensure reproducibility.

Data loaders are created for each subset:
- Training loader uses `shuffle=True` to randomize batches.
- Validation and test loaders use `shuffle=False` to maintain consistent ordering.
- All loaders use the previously defined `collate_fn` and disable multiprocessing (`num_workers=0`).

Finally, the sizes of the train/validation/test splits are printed.
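The split arithmetic can be checked on its own. This is a small sketch in which `split_sizes` is a hypothetical helper mirroring the three assignments in the cell above.

```python
from typing import Tuple


def split_sizes(n_total: int, train_frac: float = 0.70, val_frac: float = 0.15) -> Tuple[int, int, int]:
    """Truncate the train and validation sizes; the test set absorbs the remainder."""
    n_train = int(train_frac * n_total)
    n_val = int(val_frac * n_total)
    n_test = n_total - n_train - n_val
    return n_train, n_val, n_test


print(split_sizes(4000))   # → (2800, 600, 600)
print(split_sizes(1001))   # → (700, 150, 151)
```

Computing the test size by subtraction guarantees that the three parts always sum to `n_total`, which `random_split` requires when given integer lengths.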

In [None]:
model = HALT(cfg).to(CFG.device)

for p in model.parameters():
    p.requires_grad_(True)

n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\n🧠 Model: HALT ({n_params:,} params ≈ {n_params/1e6:.1f}M)")

# ── Optimizer & scheduler
criterion = nn.BCEWithLogitsLoss().to(CFG.device)
optimizer  = torch.optim.Adam(model.parameters(), lr=cfg.lr,
                                weight_decay=cfg.weight_decay)
scheduler  = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",
    factor=cfg.lr_factor,
    patience=cfg.lr_patience,
)


The HALT model is instantiated with the configuration and moved to the configured device (`CFG.device`). All parameters are explicitly set to require gradients (though this is typically redundant after initialization, it ensures no gradient-disabled layers remain).

The total number of trainable parameters is counted and printed in both raw count and millions (e.g., "1.2M").

The loss function is set to `BCEWithLogitsLoss`, appropriate for binary classification with raw logits. It is moved to the device.

The optimizer is Adam, configured with:
- Learning rate from config (`lr = 4.41e-4`),
- L2 weight decay (`weight_decay = 2.34e-6`).

The learning rate scheduler is `ReduceLROnPlateau`, which:
- Monitors validation macro-F1 (`mode="max"`),
- Reduces learning rate by a factor of `lr_factor` (0.5) when no improvement occurs for `lr_patience` (3) epochs.

This setup prepares the model and training infrastructure for the main training loop.
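To build intuition for when the learning rate actually drops, the plateau rule can be replayed in plain Python. This is a simplified sketch, not the PyTorch implementation. `simulate_plateau_lr` is a hypothetical helper that treats any strict increase as an improvement and ignores the improvement threshold, cooldown, and minimum learning rate that the real scheduler also supports.

```python
from typing import List


def simulate_plateau_lr(metrics: List[float], lr: float, factor: float, patience: int) -> List[float]:
    """Return the learning rate in effect after each epoch (max mode, simplified)."""
    best, bad_epochs, history = -float("inf"), 0, []
    for m in metrics:
        if m > best:
            best, bad_epochs = m, 0
        else:
            bad_epochs += 1
        # The cut happens once the bad-epoch count exceeds patience,
        # i.e. on the (patience + 1)-th consecutive epoch without improvement.
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0
        history.append(lr)
    return history


val_f1 = [0.50, 0.60, 0.60, 0.60, 0.60, 0.60, 0.61]   # made-up validation scores
history = simulate_plateau_lr(val_f1, lr=4.41e-4, factor=0.5, patience=3)
print(history)
```

With `patience=3`, epochs 3 to 5 are tolerated and the rate is halved at epoch 6, the fourth consecutive epoch without improvement.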

## Model

In [None]:
print("\n" + "=" * 72)
print("🚀 Training HALT")
print("=" * 72)

best_val_f1 = train_halt(model, train_loader, val_loader,
                            optimizer, scheduler, criterion, cfg)

print(f"\n Best validation macro-F1: {best_val_f1:.4f}")


The `train_halt` function is called with the model, data loaders, optimizer, learning rate scheduler, loss function, and configuration. It runs the full training loop, including early stopping based on validation macro-F1, and returns the highest macro-F1 score observed on the validation set.

After training completes, the best validation macro-F1 is printed with four decimal places.

In [None]:
# Final test evaluation
print("\n" + "=" * 72)
print("🧪 Final Test Evaluation")
print("=" * 72)

test_metrics = evaluate(model, test_loader, criterion, CFG.device)
print(f"  Macro-F1 : {test_metrics['macro_f1']:.4f}")
print(f"  AUROC    : {test_metrics['auroc']:.4f}")
print(f"  Accuracy : {test_metrics['accuracy']:.4f}")
print(f"  Loss     : {test_metrics['loss']:.4f}")


After training completes, a final evaluation is performed on the test set.

A header signals the start of final test evaluation. The `evaluate` function is called with the trained model (which now holds the weights from its best validation macro-F1 checkpoint), the test data loader, loss function, and device.

The resulting test metrics are printed: macro-F1 score, AUROC, accuracy, and average loss, each formatted to four decimal places.

In [None]:

# ── Example inference
print("\n" + "=" * 72)
print("🔍 Example Inference")
print("=" * 72)

rng_demo = np.random.default_rng(123)
for label_str, is_hall in [("Hallucinated", True), ("Correct", False)]:
    demo = _synthetic_logprobs(60, cfg.top_k, is_hall, rng_demo)
    res  = predict_hallucination(demo, model, CFG.device)
    pred_str = "HALLUCINATED" if res["prediction"] else "NOT HALLUCINATED"
    print(f"[{label_str:>13}]  P(hallucinated)={res['prob_hallucinated']:.4f} "
            f"→ {pred_str}")

A demo inference section is run to illustrate model behavior on synthetic examples.

A NumPy random generator seeded at 123 is created for reproducibility. For two cases, marked as "Hallucinated" and "Correct", a 60-step synthetic log-probability sequence is generated using `_synthetic_logprobs`, with the hallucination flag (`is_hall`) set accordingly.

Each sequence is passed to `predict_hallucination`, which returns the hallucination probability, binary prediction, and logit.

The output is formatted to show:
- The ground-truth label ("Hallucinated" or "Correct"),
- The predicted probability of hallucination (four decimals),
- The resulting prediction ("HALLUCINATED" or "NOT HALLUCINATED").

This offers a quick spot check of whether the model assigns high hallucination probability to a synthetic hallucinated sequence and low probability to a correct one. Two samples are not a substitute for the test-set metrics above.

In [None]:

# ── Feature importance (Appendix C.2)
print("\n" + "=" * 72)
print("📊 Feature Importance (Gradient×Input, top-10)")
print("=" * 72)

print("🔍 Testing gradient flow...")
dummy_x = torch.randn(2, 5, cfg.input_dim, device=CFG.device, requires_grad=True)
dummy_lens = torch.tensor([5, 3], device=CFG.device)

model.train()  # cuDNN RNN backward is only supported in training mode
logits = model(dummy_x, dummy_lens)
print(f" logits.requires_grad? {logits.requires_grad}")
loss = nn.BCEWithLogitsLoss()(logits, torch.zeros(2, device=CFG.device))
loss.backward()
print(f"dummy x.grad exists? {dummy_x.grad is not None}")
assert dummy_x.grad is not None, "Break in gradient flow!"
model.zero_grad()  # discard parameter gradients left over from this test
model.eval()
print("✅ Gradient test PASSED")



A feature importance analysis is prepared, preceded by a gradient flow validation step.

- A small dummy input tensor (`dummy_x`) of shape `[2, 5, 25]` is created with `requires_grad=True`, and corresponding lengths `[5, 3]`.
- The model is switched to training mode, because on GPU the cuDNN backward pass of recurrent layers is only available in training mode.
- Forward pass computes logits; binary cross-entropy loss is computed against zero labels.
- Backward propagation is performed, after which the leftover parameter gradients are cleared and the model is returned to evaluation mode.

The script checks that:
- Output logits has `requires_grad=True`, confirming the forward path preserves gradients,
- Input tensor `dummy_x` has a non-null gradient after backpropagation.

An assertion ensures that gradients reach the input; without them, Gradient × Input could not be computed. A confirmation message is printed upon success.

This pre-check verifies that the model supports gradient-based feature importance computation before proceeding with `compute_feature_importance`.

In [None]:

importance = compute_feature_importance(model, val_loader, CFG.device,
                                        n_features=cfg.input_dim)
print(f"{'Feature':<22}  {'Importance':>10}")
print("-" * 36)
for name, score in list(importance.items())[:10]:
    print(f"{name:<22}  {score:>10.4f}")

Feature importance is computed on the validation set using gradient-based analysis (Gradient × Input).

The `compute_feature_importance` function is called with the trained model, validation loader, device, and feature count (25). The function aggregates per-feature gradient magnitudes weighted by input magnitude across all non-padded timesteps and normalizes them to sum to 1.

The top 10 most important features are printed in descending order of importance:
- Each line shows the feature name (left-aligned, up to 22 characters) and its normalized importance score (right-aligned, formatted to four decimal places).  
- A header row and a separator line precede the output for readability.

This highlights which engineered or raw log-probability features contribute most to hallucination detection.

In [None]:

# ── Shape & backward pass check
print("\n" + "=" * 72)
print("🔍 Shape Sanity Check & Backward Pass")
print("=" * 72)
_shape_sanity_check(model, cfg)

# ── Save model (optional)
torch.save(model.state_dict(), "halt_model.pt")
print("\n💾 Model saved to halt_model.pt")

A final shape and gradient flow verification is performed by calling `_shape_sanity_check`, which tests that the model processes a small batch correctly, produces valid outputs, computes loss without errors, and successfully propagates gradients to all parameters.

After confirming correctness, the trained model's state dictionary is saved to disk as `"halt_model.pt"` using `torch.save`.

A confirmation message indicates successful model persistence.
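Since only the state dictionary is saved, reloading requires constructing the model first and then loading the weights into it. The round trip can be sketched as follows, using a small `nn.Linear` as a hypothetical stand-in for the trained model and a temporary path rather than `halt_model.pt`:

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for the trained model.
net = nn.Linear(25, 1)

# Save only the parameters (state dict), as the notebook does.
path = os.path.join(tempfile.mkdtemp(), "halt_model_demo.pt")
torch.save(net.state_dict(), path)

# To restore: build a model with the same architecture, then load the weights.
restored = nn.Linear(25, 1)
state = torch.load(path, map_location="cpu")
restored.load_state_dict(state)
restored.eval()  # switch to inference mode before evaluating

print("Weights match:", torch.equal(net.weight, restored.weight))
```

Passing `map_location="cpu"` lets a checkpoint written on a GPU machine be loaded on a CPU-only one.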

## Visualisations

In [None]:
importance = compute_feature_importance(model, val_loader, CFG.device)
plot_feature_importance(importance, output_path="feature_importance.png")

plot_feature_correlation(model, val_loader, CFG.device)


plot_t_sne_embeddings(
    model, val_loader, CFG.device,
    output_path="tsne_embeddings.png",
    n_samples=1500
)

plot_sensitivity_analysis(model, val_loader, CFG.device)

Additional diagnostic visualizations are generated to analyze model behavior:

- **Feature importance plot**: The `plot_feature_importance` function is called with the computed importance dictionary and saves a figure as `"feature_importance.png"`.

- **Feature correlation analysis**: The `plot_feature_correlation` function is invoked to visualize pairwise correlations among engineered features and raw log-probabilities.

- **t-SNE embedding visualization**: The `plot_t_sne_embeddings` function generates a 2D t-SNE plot of model representations for validation samples (up to 1500 samples), saving the result as `"tsne_embeddings.png"`. This helps assess whether hallucinated and correct samples separate in the embedding space.

- **Sensitivity analysis**: The `plot_sensitivity_analysis` function is called to evaluate how input perturbations affect predictions, likely to assess robustness or to identify vulnerable features.

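The two calls that pass `output_path` save their figures as PNG files; the other two are called without an explicit path and rely on whatever defaults the functions define. Together, the plots support visual interpretation of model characteristics beyond scalar metrics.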
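As a rough illustration of the t-SNE step, the sketch below projects a matrix of embeddings to two dimensions with scikit-learn. The embeddings here are synthetic (two Gaussian clusters standing in for correct and hallucinated samples), and the `perplexity` value is an arbitrary choice for this small example; how `plot_t_sne_embeddings` extracts representations from the model is not shown in this section.

```python
import numpy as np
from sklearn.manifold import TSNE

rng = np.random.default_rng(42)

# Synthetic stand-in for model embeddings: two classes, 16 dimensions each.
emb_correct = rng.normal(loc=0.0, scale=1.0, size=(40, 16))
emb_halluc = rng.normal(loc=4.0, scale=1.0, size=(40, 16))
embeddings = np.vstack([emb_correct, emb_halluc])
labels = np.array([0] * 40 + [1] * 40)

# perplexity must be smaller than the number of samples.
tsne = TSNE(n_components=2, perplexity=15, init="pca", random_state=42)
coords = tsne.fit_transform(embeddings)

print(coords.shape)
# The 2D points can then be drawn with, for example,
# plt.scatter(coords[:, 0], coords[:, 1], c=labels).
```

Distances between well-separated clusters in a t-SNE plot are not quantitatively meaningful, so the plot is best read as a qualitative check on class separation.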