# HGF Basics: Understanding the Hierarchical Gaussian Filter

This notebook introduces the Hierarchical Gaussian Filter (HGF), a hierarchical Bayesian model of learning under uncertainty in which an agent adapts how fast it learns to how volatile it believes the environment to be.

## What is the HGF?

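The HGF is a hierarchical Bayesian model with three levels:
- **Level 1**: Observations (what we see; in this notebook, binary outcomes)
- **Level 2**: Hidden state (what we believe about the world; here, the tendency of the outcome toward 1, on a log-odds scale)
- **Level 3**: Volatility (how quickly we think the level-2 state is changing)

The key insight is that beliefs at higher levels modulate learning at lower levels: the more volatile the agent believes the world to be, the more weight it gives each new prediction error at level 2.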
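To make the hierarchy concrete, here is a minimal sketch of the generative process the three levels describe, assuming the standard binary-HGF form: level 3 (log-volatility) follows a Gaussian random walk, level 2 (the outcome's log-odds) follows a Gaussian random walk whose step variance is $\exp(\kappa_1 x_3 + \omega_2)$, and level 1 is a Bernoulli draw from the sigmoid of level 2. The function `simulate_hgf_generative` and the level-3 step variance `var_3` are hypothetical illustrations, not part of `ara.hgf`.

```python
import numpy as np


def simulate_hgf_generative(n_trials, omega_2=-4.0, kappa_1=1.0, var_3=0.1, seed=0):
    """Hypothetical sketch: sample from a three-level binary HGF generative model."""
    rng = np.random.default_rng(seed)
    x3 = np.zeros(n_trials)            # level 3: log-volatility
    x2 = np.zeros(n_trials)            # level 2: log-odds of the outcome
    u = np.zeros(n_trials, dtype=int)  # level 1: binary observations
    x3_prev, x2_prev = 0.0, 0.0
    for k in range(n_trials):
        # Level 3 drifts with a fixed (assumed) step variance
        x3[k] = rng.normal(x3_prev, np.sqrt(var_3))
        # Level 2 drifts with a step variance set by the current level-3 state
        step_var_2 = np.exp(kappa_1 * x3[k] + omega_2)
        x2[k] = rng.normal(x2_prev, np.sqrt(step_var_2))
        # Level 1 is a coin flip whose bias is the sigmoid of level 2
        p = 1.0 / (1.0 + np.exp(-x2[k]))
        u[k] = rng.binomial(1, p)
        x3_prev, x2_prev = x3[k], x2[k]
    return u, x2, x3


u, x2, x3 = simulate_hgf_generative(200, seed=42)
```

When `x3` wanders upward, `x2` takes larger steps, so the outcome probability changes faster: this is the sense in which level 3 controls how quickly level 2 moves.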

In [None]:
# Setup
import sys
sys.path.insert(0, '../..')

import numpy as np
import matplotlib.pyplot as plt

from ara.hgf import (
    HGFAgent,
    HGFParams,
    VolatilitySwitchingTask,
    plot_beliefs,
    plot_prediction_errors,
    plot_hgf_dashboard,
)

# Style
plt.style.use('dark_background')
%matplotlib inline

## 1. Create a Task Environment

The **Volatility Switching Task** alternates between stable and volatile phases.
In stable phases, the true probability is constant.
In volatile phases, it changes rapidly.
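The phase structure described above can be sketched in a few lines of NumPy. This is a hypothetical stand-in, not the `VolatilitySwitchingTask` implementation. It assumes phase lengths drawn from a geometric distribution with mean `phase_length_mean`, stable phases that hold the outcome probability at `stable_prob`, and volatile phases that reverse between `stable_prob` and `1 - stable_prob` every `flip_every` trials (`flip_every` is an assumed parameter).

```python
import numpy as np


def make_volatility_switching_task(n_trials=200, stable_prob=0.8, phase_length_mean=30,
                                   flip_every=5, seed=42):
    """Hypothetical sketch of a volatility-switching task (not ara's implementation)."""
    rng = np.random.default_rng(seed)
    true_p = np.empty(n_trials)
    states = []
    volatile = False  # start in a stable phase
    t = 0
    while t < n_trials:
        # Assumed: phase length is geometric with the requested mean (always >= 1)
        length = min(rng.geometric(1.0 / phase_length_mean), n_trials - t)
        for i in range(length):
            if volatile:
                # Reverse the contingency every `flip_every` trials
                high = (i // flip_every) % 2 == 0
                true_p[t + i] = stable_prob if high else 1.0 - stable_prob
            else:
                true_p[t + i] = stable_prob
            states.append("volatile" if volatile else "stable")
        t += length
        volatile = not volatile
    # Binary observations sampled from the true probability on each trial
    observations = rng.binomial(1, true_p)
    return observations, true_p, states
```

For example, `obs, p, states = make_volatility_switching_task(seed=0)` returns 200 binary outcomes, the per-trial true probabilities, and a `"stable"`/`"volatile"` label per trial.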

In [None]:
# Create a volatility switching task
task = VolatilitySwitchingTask(
    n_trials=200,
    stable_prob=0.8,
    phase_length_mean=30,
)

# Generate task data
task_data = task.generate(seed=42)

print(f"Generated {task_data.n_trials} trials")
print(f"Task type: {task_data.task_type}")

In [None]:
# Visualize the task
fig, ax = plt.subplots(figsize=(12, 4))

trials = np.arange(task_data.n_trials)
ax.scatter(trials, task_data.observations, c='cyan', s=10, alpha=0.5, label='Observations')
ax.plot(trials, task_data.true_probabilities, 'orange', lw=2, label='True P')

# Mark volatility phases
for i, state in enumerate(task_data.volatility_states):
    if state == 'volatile':
        ax.axvspan(i-0.5, i+0.5, alpha=0.1, color='red')

ax.set_xlabel('Trial')
ax.set_ylabel('Probability')
ax.set_title('Volatility Switching Task')
ax.legend()
plt.tight_layout()

## 2. Create an HGF Agent

The agent has three key parameters:
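- **ω₂** (omega_2): Tonic log-volatility at level 2. Lower = smaller belief updates, so more stable beliefs.
- **κ₁** (kappa_1): Coupling strength between level 3 and level 2. Higher = learning rate more sensitive to changes in estimated volatility.
- **θ** (theta): Response precision (inverse temperature of the choice rule). Higher = more deterministic choices.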
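To see why a larger θ gives more deterministic choices, here is a minimal sketch of a unit-square sigmoid choice rule that treats θ as an inverse temperature applied to the predicted probability $\hat{\mu}_1$. This response model and the helper name `choice_probability` are assumptions for illustration; `HGFAgent` may use a different decision rule.

```python
import numpy as np


def choice_probability(mu_hat_1, theta):
    """P(choose option 1) under a unit-square sigmoid with inverse temperature theta.

    Hypothetical helper: assumes theta acts as an inverse temperature.
    Equivalent to sigmoid(theta * logit(mu_hat_1)).
    """
    mu_hat_1 = np.asarray(mu_hat_1, dtype=float)
    num = mu_hat_1 ** theta
    return num / (num + (1.0 - mu_hat_1) ** theta)


belief = 0.7  # agent predicts a 70% chance that the outcome is 1
for theta in (0.5, 1.0, 5.0):
    print(f"theta={theta}: P(choose 1) = {float(choice_probability(belief, theta)):.3f}")
```

With θ = 1 the choice probability equals the belief itself; raising θ pushes it toward 1 (for a belief above 0.5), while lowering θ pulls it toward 0.5, i.e. toward random responding.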

In [None]:
# Create a healthy baseline agent
agent = HGFAgent(
    omega_2=-4.0,  # Tonic log-volatility (moderate)
    kappa_1=1.0,   # Level 3 -> level 2 coupling (normal)
    theta=1.0,     # Response inverse temperature (balanced)
    n_levels=3,    # Use the 3-level HGF
)

print("Agent parameters:")
print(f"  ω₂ = {agent.params.omega_2}")
print(f"  κ₁ = {agent.params.kappa_1}")
print(f"  θ  = {agent.params.theta}")

## 3. Run the Agent Through the Task

In [None]:
# Run agent and get trajectory
trajectory = agent.run(task_data, generate_actions=True)

print(f"Trajectory: {trajectory.n_trials} trials")
print(f"Accuracy: {trajectory.compute_accuracy():.3f}")
print(f"Log-likelihood: {trajectory.compute_log_likelihood():.2f}")

## 4. Visualize Belief Dynamics

In [None]:
# Plot beliefs at each level
fig = plot_beliefs(trajectory, task_data=task_data)
plt.show()

In [None]:
# Plot prediction errors
fig = plot_prediction_errors(trajectory, levels=[1, 2])
plt.show()

In [None]:
# Full dashboard view
fig = plot_hgf_dashboard(trajectory, task_data=task_data)
plt.show()

## 5. Understanding the Key Equations

### Prediction Error (δ₁)
The sensory prediction error is simply:
$$\delta_1 = u - \hat{\mu}_1$$

where $u \in \{0, 1\}$ is the observation and $\hat{\mu}_1 = s(\mu_2^{(k-1)})$ is the predicted probability that $u = 1$, with $s$ the logistic sigmoid.

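### Belief Update (μ₂)
In the standard binary HGF, the hidden state is updated according to:
$$\mu_2^{(k)} = \mu_2^{(k-1)} + \sigma_2^{(k)} \cdot \delta_1^{(k)}$$

where $\sigma_2^{(k)}$ is the posterior variance at level 2. It combines the prior variance $\hat{\sigma}_2^{(k)}$ (how uncertain we are before seeing $u$) with the variance of the level-1 prediction, $\hat{\sigma}_1^{(k)} = \hat{\mu}_1^{(k)}(1 - \hat{\mu}_1^{(k)})$:
$$\sigma_2^{(k)} = \frac{1}{1/\hat{\sigma}_2^{(k)} + \hat{\sigma}_1^{(k)}}$$

A larger prior variance therefore means a larger step.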

### The κ₁ Coupling (Critical for Pathology)
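The prior variance grows with the estimated volatility:
$$\hat{\sigma}_2^{(k)} = \sigma_2^{(k-1)} + \exp(\kappa_1 \mu_3^{(k-1)} + \omega_2)$$

This is where **κ₁** becomes crucial. Higher κ₁ means a given change in the volatility estimate μ₃ produces a larger change in the prior variance, and therefore in the effective learning rate at level 2.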
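Putting the level-1 and level-2 equations together gives a single-trial update. The sketch below assumes the standard binary HGF, in which $\hat{\mu}_1 = s(\mu_2)$ and the step size is the posterior variance $\sigma_2 = 1/(1/\hat{\sigma}_2 + \hat{\sigma}_1)$ with $\hat{\sigma}_1 = \hat{\mu}_1(1 - \hat{\mu}_1)$. It holds μ₃ fixed and omits the level-3 update; `update_level_2` is a hypothetical helper, not part of `ara.hgf`.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def update_level_2(mu_2, sigma_2, u, mu_3, kappa_1, omega_2):
    """One binary-HGF level-2 update with mu_3 held fixed (hypothetical helper).

    mu_2, sigma_2: posterior mean and variance from the previous trial.
    """
    # Prediction at level 1 and sensory prediction error
    mu_hat_1 = sigmoid(mu_2)
    delta_1 = u - mu_hat_1
    # Prior variance at level 2: previous variance plus volatility-driven drift
    sigma_hat_2 = sigma_2 + np.exp(kappa_1 * mu_3 + omega_2)
    # Posterior variance combines the prior with the Bernoulli variance of the prediction
    sigma_hat_1 = mu_hat_1 * (1.0 - mu_hat_1)
    sigma_2_new = 1.0 / (1.0 / sigma_hat_2 + sigma_hat_1)
    # Mean update: prediction error scaled by the posterior variance
    mu_2_new = mu_2 + sigma_2_new * delta_1
    return mu_2_new, sigma_2_new


# Same observation and volatility estimate, two coupling strengths
for kappa_1 in (0.5, 2.0):
    new_mu, new_sigma = update_level_2(mu_2=0.0, sigma_2=1.0, u=1, mu_3=1.0,
                                       kappa_1=kappa_1, omega_2=-4.0)
    print(f"kappa_1={kappa_1}: mu_2 -> {new_mu:.4f}, sigma_2 -> {new_sigma:.4f}")
```

With μ₃ > 0, the larger κ₁ inflates the prior variance and produces the bigger step; with μ₃ < 0 the direction of the effect reverses.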

In [None]:
# Examine the precision dynamics
precisions = trajectory.get_precisions()

fig, axes = plt.subplots(2, 1, figsize=(12, 6), sharex=True)

trials = np.arange(trajectory.n_trials)

# Prior precision
axes[0].plot(trials, precisions['pi_hat_2'], 'purple', lw=2)
axes[0].set_ylabel('Prior Precision (π̂₂)')
axes[0].set_title('Prior Precision: how confident we are in our beliefs')

# Sensory precision
axes[1].plot(trials, precisions['pi_1'], 'cyan', lw=2)
axes[1].set_ylabel('Sensory Precision (π₁)')
axes[1].set_xlabel('Trial')
axes[1].set_title('Sensory Precision: how reliable the input is')

plt.tight_layout()

## 6. The Precision Ratio: Who Wins?

The key insight of predictive coding is that belief updates are a **precision-weighted average** of prior and sensory evidence.

$$\text{Update} \propto \frac{\pi_1}{\pi_1 + \hat{\pi}_2}$$

When sensory precision dominates (ratio → 1), we trust our senses.
When prior precision dominates (ratio → 0), we trust our beliefs.
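For a Gaussian belief, the precision-weighted average and the ratio above describe the same update. The worked sketch below uses arbitrary illustrative values for the prediction `mu_hat`, the precisions `pi_hat_2` and `pi_1`, and the observation `u`: computing the posterior mean as a weighted average gives the same result as adding the ratio times the prediction error to the prior mean.

```python
mu_hat = 0.0     # prior mean (prediction), illustrative value
pi_hat_2 = 4.0   # prior precision, illustrative value
u = 1.0          # observation, illustrative value
pi_1 = 1.0       # sensory precision, illustrative value

# Posterior mean as a precision-weighted average of prior and observation
posterior_avg = (pi_hat_2 * mu_hat + pi_1 * u) / (pi_hat_2 + pi_1)

# Same update written as a prediction-error correction scaled by the precision ratio
ratio = pi_1 / (pi_1 + pi_hat_2)
posterior_pe = mu_hat + ratio * (u - mu_hat)

print(f"ratio = {ratio:.2f}, posterior mean = {posterior_avg:.2f} / {posterior_pe:.2f}")
```

Here the prior is four times as precise as the input, so the ratio is 0.2 and only a fifth of the prediction error is absorbed.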

In [None]:
# Compute precision ratio
pi_ratio = precisions['pi_1'] / (precisions['pi_1'] + precisions['pi_hat_2'] + 1e-10)

fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(trials, pi_ratio, 'orange', lw=2)
ax.axhline(0.5, color='gray', linestyle='--', alpha=0.5)
ax.fill_between(trials, 0, pi_ratio, alpha=0.3, color='orange')

ax.set_xlabel('Trial')
ax.set_ylabel('π₁ / (π₁ + π̂₂)')
ax.set_ylim(0, 1)
ax.set_title('Precision Ratio: Sensory vs Prior Dominance')

# Annotate
ax.text(10, 0.9, '↑ Trust senses', fontsize=10, color='cyan')
ax.text(10, 0.1, '↓ Trust beliefs', fontsize=10, color='purple')

plt.tight_layout()

## Next Steps

1. **02_pathological_regimes.ipynb**: See how different parameter settings model psychiatric disorders
2. **03_parameter_fitting.ipynb**: Learn to recover parameters from behavioral data
3. **04_neural_correlates.ipynb**: Simulate EEG/fMRI correlates of prediction errors