# Surprise Minimization – Experiment 02

This notebook provides several complementary numerical views of the PDE model

$$\partial_t S = -1 + \alpha\,\Delta S, \qquad \partial_t v = -v,$$

under Neumann-like boundary conditions. The views are:

* a **continuous surprise-minimizing field** (PDE),
* a **GPU-accelerated JAX implementation**, 
* a **phase diagram over $\alpha$**,
* a **cellular automaton (CA) style discrete analogue**, 
* a **3D extension**, and
* a **finite-element sketch (FEniCSx)**.

Conceptual framing:

* High curvature $\|\Delta S\|$ = epistemic tension / exploration.
* Transient curvature = **simulated danger** (controlled uncertainty).
* Long-time collapse of curvature = **inoculation against surprise**.
* Steady state with vanishing policy $v \to 0$ = **dark-room equilibrium**.
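
The policy equation $\partial_t v = -v$ is decoupled from $S$ and has the closed-form solution $v(t) = v_0 e^{-t}$. The sketch below is a minimal illustration, assuming an explicit Euler step; the helper name `euler_policy_decay` and the parameter values are illustrative and do not come from the notebook. It compares the numerical decay with the exact one.

```python
import math

def euler_policy_decay(v0, dt, steps):
    """Integrate dv/dt = -v with explicit Euler; return the trajectory as a list."""
    traj = [v0]
    v = v0
    for _ in range(steps):
        v = v + dt * (-v)   # each step multiplies v by (1 - dt)
        traj.append(v)
    return traj

# Illustrative values (assumptions, not taken from the notebook)
v0, dt_v, n_v = 1.0, 0.05, 200
traj = euler_policy_decay(v0, dt_v, n_v)
exact_final = v0 * math.exp(-dt_v * n_v)   # exact solution at t = n_v * dt_v
print(f"Euler: {traj[-1]:.3e}, exact: {exact_final:.3e}")
```

Both values are close to zero after ten time units, which is the $v \to 0$ limit referred to above.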


## 0. Imports and Common Utilities

We will use:

* `jax` / `jax.numpy` for GPU-accelerated PDE simulation.
* `matplotlib` for visualization.
* Standard Python/NumPy/SciPy for the CPU fallback and the CA comparison.
* Optional FEniCSx imports (will only work if you have FEniCSx installed).

In [ ]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

try:
    import jax
    import jax.numpy as jnp
    from jax import jit, vmap
    JAX_AVAILABLE = True
except ImportError:
    JAX_AVAILABLE = False
    print("JAX not available; PDE sections will fall back to NumPy CPU.")

try:
    from jax.scipy.signal import convolve2d as jax_convolve2d
except Exception:
    jax_convolve2d = None

# Matplotlib defaults
plt.rcParams['figure.figsize'] = (5, 4)
plt.rcParams['image.cmap'] = 'viridis'

---
## 1. JAX / GPU PDE Simulation (2D)

We simulate:

$$\partial_t S = -1 + \alpha\,\Delta S$$

with a 5-point Laplacian stencil and simple Neumann-like boundary via symmetric padding.
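
As a sanity check on the stencil and the boundary treatment, the following NumPy-only sketch rebuilds the 5-point Laplacian with `np.pad(..., mode='symmetric')` (an assumption standing in for the convolution-based version used in this notebook; `laplace_5pt` is a hypothetical name). It applies the operator to $f = x^2 + y^2$, whose Laplacian is exactly 4, and confirms that the mirrored boundary lets no net flux through.

```python
import numpy as np

def laplace_5pt(S, dx=1.0):
    """5-point Laplacian; symmetric padding mirrors the edge cells (zero flux)."""
    Sp = np.pad(S, 1, mode='symmetric')
    return (Sp[2:, 1:-1] + Sp[:-2, 1:-1] +
            Sp[1:-1, 2:] + Sp[1:-1, :-2] - 4.0 * S) / dx**2

n = 16
idx = np.arange(n, dtype=float)
Xq, Yq = np.meshgrid(idx, idx, indexing='ij')
f = Xq**2 + Yq**2                      # analytic Laplacian is 4 everywhere

lap_f = laplace_5pt(f)
interior_ok = np.allclose(lap_f[1:-1, 1:-1], 4.0)   # interior cells match the exact value
zero_flux = abs(lap_f.sum()) < 1e-9                 # the stencil sums to zero over the grid
print(interior_ok, zero_flux)
```

Only the interior is compared with the analytic value, because the mirrored edge cells enforce a zero normal derivative that $f$ itself does not satisfy.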

In [ ]:
nx = 120
dx = 1.0
dt = 0.05
alpha = 0.2
steps = 2000
plot_interval = 400

# CFL-ish warning
if alpha > 0 and dt > dx**2 / (4 * alpha):
    print("[Warning] dt may be too large for stable explicit diffusion.")

# 2D Laplacian kernel
lap_kernel_np = np.array([[0., 1., 0.],
                          [1., -4., 1.],
                          [0., 1., 0.]]) / dx**2

if JAX_AVAILABLE:
    lap_kernel = jnp.array(lap_kernel_np)
else:
    lap_kernel = lap_kernel_np

def laplace_np(S):
    from scipy.signal import convolve2d
    return convolve2d(S, lap_kernel_np, mode='same', boundary='symm')

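if JAX_AVAILABLE and jax_convolve2d is not None:
    @jit
    def laplace_jax(S):
        # jax.scipy.signal.convolve2d only supports boundary='fill',
        # so mirror the edges explicitly and convolve in 'valid' mode.
        Sp = jnp.pad(S, 1, mode='symmetric')
        return jax_convolve2d(Sp, lap_kernel, mode='valid')
else:
    laplace_jax = None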

def step_S_np(S, alpha, dt):
    return S + dt * (-1.0 + alpha * laplace_np(S))

if JAX_AVAILABLE and laplace_jax is not None:
    @jit
    def step_S_jax(S, alpha, dt):
        return S + dt * (-1.0 + alpha * laplace_jax(S))
else:
    step_S_jax = None

# Initialize S with some spatial structure
x = np.linspace(0, 1, nx)
X, Y = np.meshgrid(x, x)
S0 = 0.5 * np.sin(2 * np.pi * X) * np.sin(2 * np.pi * Y)

if JAX_AVAILABLE and step_S_jax is not None:
    S = jnp.array(S0)
else:
    S = S0.copy()

In [ ]:
curvature_history = []

for t in range(steps):
    if JAX_AVAILABLE and step_S_jax is not None:
        S = step_S_jax(S, alpha, dt)
        curv = jnp.sqrt(jnp.mean(laplace_jax(S)**2))
        curv_val = float(curv)
    else:
        S = step_S_np(S, alpha, dt)
        curv = np.sqrt(np.mean(laplace_np(S)**2))
        curv_val = float(curv)

    curvature_history.append(curv_val)

    if t % plot_interval == 0:
        plt.figure()
        plt.imshow(np.array(S))
        plt.title(f"S(x,y) at step {t}")
        plt.colorbar(label='S')
        plt.show()

plt.figure()
plt.plot(curvature_history)
plt.xlabel('Time step')
plt.ylabel('Epistemic curvature ||ΔS||')
plt.title('Curvature decay: exploration → dark-room collapse')
plt.grid(True)
plt.show()

Interpretation:

* Early steps: curvature is high → exploration, **simulated danger**.
* Over time: curvature decays and flattens → **inoculation against surprise**.
* Long-time limit: curvature tends to zero and $S$ becomes spatially uniform, while its mean keeps falling at unit rate because of the $-1$ source term → **dark-room equilibrium**.
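
Because the mirrored boundary lets no flux through, the discrete Laplacian sums to zero over the grid, so the spatial mean of $S$ should drop by exactly `dt` per step regardless of $\alpha$. The sketch below checks this on a small grid; it uses NumPy only, and `laplace_pad` and `mean_history` are hypothetical helpers rather than functions from the notebook.

```python
import numpy as np

def laplace_pad(S, dx=1.0):
    """5-point Laplacian with mirrored (zero-flux) edges."""
    Sp = np.pad(S, 1, mode='symmetric')
    return (Sp[2:, 1:-1] + Sp[:-2, 1:-1] +
            Sp[1:-1, 2:] + Sp[1:-1, :-2] - 4.0 * S) / dx**2

def mean_history(S_init, alpha, dt, steps):
    """Spatial mean of S after each explicit Euler step of dS/dt = -1 + alpha*Lap(S)."""
    S = S_init.copy()
    means = [S.mean()]
    for _ in range(steps):
        S = S + dt * (-1.0 + alpha * laplace_pad(S))
        means.append(S.mean())
    return np.array(means)

xs = np.linspace(0, 1, 32)
Xs, Ys = np.meshgrid(xs, xs)
S_init = 0.5 * np.sin(2 * np.pi * Xs) * np.sin(2 * np.pi * Ys)

for alpha_d in (0.05, 0.4):
    means = mean_history(S_init, alpha_d, dt=0.05, steps=200)
    # total drift should equal -steps * dt = -10 for either alpha
    print(alpha_d, means[-1] - means[0])
```

Diffusion only redistributes $S$; the uniform downward drift comes entirely from the source term.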

---
## 2. Phase Diagram over $\alpha$ (Batch Runs)

We sweep several values of $\alpha$ and compare how fast curvature collapses.

Idea:

* Small $\alpha$ → slow diffusion, curvature decays slowly (longer exploration).
* Large $\alpha$ → fast diffusion, quick collapse (fast inoculation).
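
The sweep can also be run for all $\alpha$ values at once. This is a minimal sketch, assuming NumPy broadcasting over a leading $\alpha$ axis; the helper names `laplace_batch` and `sweep_alphas` and the small grid are illustrative. A JAX version could use `vmap` in the same spirit.

```python
import numpy as np

def laplace_batch(S, dx=1.0):
    """5-point Laplacian on the last two axes of S (shape: n_alpha x ny x nx)."""
    Sp = np.pad(S, ((0, 0), (1, 1), (1, 1)), mode='symmetric')
    return (Sp[:, 2:, 1:-1] + Sp[:, :-2, 1:-1] +
            Sp[:, 1:-1, 2:] + Sp[:, 1:-1, :-2] - 4.0 * S) / dx**2

def sweep_alphas(S_init, alphas, dt, steps):
    """Evolve one copy of S_init per alpha; return RMS curvature, shape (steps, n_alpha)."""
    a = np.asarray(alphas, dtype=float)[:, None, None]      # broadcast over the grid
    S = np.repeat(S_init[None, :, :], len(alphas), axis=0)  # one field per alpha
    curv = np.empty((steps, len(alphas)))
    for t in range(steps):
        S = S + dt * (-1.0 + a * laplace_batch(S))
        curv[t] = np.sqrt(np.mean(laplace_batch(S)**2, axis=(1, 2)))
    return curv

xb = np.linspace(0, 1, 24)
Xb, Yb = np.meshgrid(xb, xb)
S_small = 0.5 * np.sin(2 * np.pi * Xb) * np.sin(2 * np.pi * Yb)

curv_batch = sweep_alphas(S_small, [0.05, 0.1, 0.2, 0.4], dt=0.05, steps=200)
print(curv_batch[-1])   # final curvature per alpha, ordered as in the input list
```

Stacking the fields keeps every $\alpha$ on the same time grid, which makes the resulting curves directly comparable.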

In [ ]:
alphas = [0.05, 0.1, 0.2, 0.4]
steps_phase = 600

phase_curv = {}

for a in alphas:
    if JAX_AVAILABLE and step_S_jax is not None:
        S_alpha = jnp.array(S0)
    else:
        S_alpha = S0.copy()

    curv_hist = []
    for t in range(steps_phase):
        if JAX_AVAILABLE and step_S_jax is not None:
            S_alpha = step_S_jax(S_alpha, a, dt)
            curv = jnp.sqrt(jnp.mean(laplace_jax(S_alpha)**2))
        else:
            S_alpha = step_S_np(S_alpha, a, dt)
            curv = np.sqrt(np.mean(laplace_np(S_alpha)**2))
        curv_hist.append(float(curv))
    phase_curv[a] = curv_hist

plt.figure(figsize=(6,4))
for a in alphas:
    plt.plot(phase_curv[a], label=f"alpha={a}")
plt.xlabel('Time step')
plt.ylabel('||ΔS||')
plt.title('Phase diagram over diffusion strength alpha')
plt.legend()
plt.grid(True)
plt.show()

You can interpret persistent curvature (high, slowly decaying curves) as a regime with longer-lasting **exploration**, and rapid decay as a regime with fast **dark-room convergence**.
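
To turn each curve into a single number per $\alpha$, one option is to fit a late-time exponential decay rate to the curvature history. This is a sketch under the assumption that the tail is dominated by one mode, so that $\log\|\Delta S\|$ is roughly linear in time. `fit_decay_rate` is a hypothetical helper, and it is exercised here on synthetic data rather than on the runs above.

```python
import numpy as np

def fit_decay_rate(curv_hist, dt, tail_fraction=0.5):
    """Least-squares slope of log(curvature) over the last tail_fraction of the run."""
    c = np.asarray(curv_hist, dtype=float)
    start = int(len(c) * (1.0 - tail_fraction))
    t = dt * np.arange(len(c))[start:]
    slope, _ = np.polyfit(t, np.log(c[start:]), 1)
    return -slope   # positive value = decay rate per unit time

# Synthetic check: a pure exponential with known rate 0.03 (illustrative value)
dt_demo = 0.05
t_demo = dt_demo * np.arange(600)
rate_est = fit_decay_rate(2.0 * np.exp(-0.03 * t_demo), dt_demo)
print(f"estimated rate: {rate_est:.4f}")
```

For a single Fourier mode with wavenumber $k$, the continuous model predicts a rate of $\alpha k^2$, so the fitted rate should scale roughly linearly with $\alpha$. With `dx = 1.0` the grid in Section 1 is `nx - 1` units wide, so the initial pattern has $k \approx 2\pi/(nx-1)$ per axis and the predicted rates are correspondingly small.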

---
## 3. PDE vs Cellular Automaton (CA) Comparison

We build a CA that approximates diffusion with a local rule and compare to the PDE evolution.

* PDE update: $S^{t+1} = S^t + dt(-1 + \alpha \Delta S^t)$
* CA update: $s_i^{t+1} = s_i^t + c \sum_{j \in N(i)} (s_j^t - s_i^t) - c_0$ with discrete constants.
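
Matching the two updates term by term gives $c = \alpha\,dt/dx^2$ and $c_0 = dt$. The sketch below writes the CA rule as an explicit sum over the four nearest neighbours (an assumed von Neumann neighbourhood with mirrored edges) and checks that it coincides with the stencil form of the PDE step. The function names `ca_step` and `pde_step` are illustrative.

```python
import numpy as np

def ca_step(s, c, c0):
    """CA rule: s_i += c * sum_j (s_j - s_i) - c0 over the 4 nearest neighbours."""
    sp = np.pad(s, 1, mode='symmetric')   # edge cells see a mirrored copy of themselves
    neighbours = (sp[2:, 1:-1], sp[:-2, 1:-1], sp[1:-1, 2:], sp[1:-1, :-2])
    return s + c * sum(nb - s for nb in neighbours) - c0

def pde_step(S, alpha, dt, dx=1.0):
    """Explicit Euler step of dS/dt = -1 + alpha * Lap(S) with the 5-point stencil."""
    Sp = np.pad(S, 1, mode='symmetric')
    lap = (Sp[2:, 1:-1] + Sp[:-2, 1:-1] +
           Sp[1:-1, 2:] + Sp[1:-1, :-2] - 4.0 * S) / dx**2
    return S + dt * (-1.0 + alpha * lap)

alpha_m, dt_m, dx_m = 0.2, 0.1, 1.0
c, c0 = alpha_m * dt_m / dx_m**2, dt_m    # constants that make the two rules agree

rng = np.random.default_rng(0)
s_test = rng.standard_normal((20, 20))
same = np.allclose(ca_step(s_test, c, c0), pde_step(s_test, alpha_m, dt_m, dx_m))
print(same)
```

The CA picture therefore adds no new dynamics; it only relabels the explicit finite-difference step as a neighbour-exchange rule.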

In [ ]:
# CA discretization on a coarse grid
nx_ca = 60
dx_ca = 1.0
dt_ca = 0.1
alpha_ca = 0.2

x_ca = np.linspace(0,1,nx_ca)
X_ca, Y_ca = np.meshgrid(x_ca, x_ca)
S_ca = 0.5 * np.sin(2 * np.pi * X_ca) * np.sin(2 * np.pi * Y_ca)

def laplace_ca(S):
    from scipy.signal import convolve2d
    return convolve2d(S, lap_kernel_np, mode='same', boundary='symm')

def step_ca(S, alpha, dt):
    return S + dt * (-1.0 + alpha * laplace_ca(S))

curv_ca = []
steps_ca = 600
for t in range(steps_ca):
    S_ca = step_ca(S_ca, alpha_ca, dt_ca)
    curv = np.sqrt(np.mean(laplace_ca(S_ca)**2))
    curv_ca.append(curv)
    if t % 200 == 0:
        plt.figure()
        plt.imshow(S_ca)
        plt.title(f"CA-like S at step {t}")
        plt.colorbar()
        plt.show()

plt.figure()
plt.plot(curv_ca)
plt.xlabel('Time step')
plt.ylabel('||ΔS|| (CA analogue)')
plt.title('CA-style curvature decay')
plt.grid(True)
plt.show()

This CA-like scheme is essentially the same local rule as the PDE discretization, just framed in CA language. It exhibits the same structure:

* initial heterogeneity → local exploration;
* diffusion of structure → flattening;
* long-time dark-room collapse.


---
## 4. 3D Domain Variant

We extend to $S(x,y,z,t)$ with a 3D Laplacian. This is more computationally expensive, but the structure is the same.

We use NumPy for clarity (you can replace with JAX easily).
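
The explicit scheme's time-step limit tightens with dimension: for the standard $(2d+1)$-point stencil in $d$ dimensions the usual bound is $dt \le dx^2 / (2d\,\alpha)$. A small sketch (the helper name `max_stable_dt` is illustrative) evaluates it for the parameter sets used in this notebook.

```python
def max_stable_dt(alpha, dx, dim):
    """Largest explicit-Euler step for diffusion with a (2*dim+1)-point stencil."""
    if alpha <= 0:
        return float('inf')   # no diffusion, hence no diffusive stability limit
    return dx**2 / (2 * dim * alpha)

# (label, alpha, dx, dt, dimension) as set in the 2D, CA and 3D cells
cases = [("2D PDE", 0.2, 1.0, 0.05, 2),
         ("2D CA", 0.2, 1.0, 0.1, 2),
         ("3D", 0.2, 1.0, 0.02, 3)]

for label, alpha_s, dx_s, dt_s, dim in cases:
    limit = max_stable_dt(alpha_s, dx_s, dim)
    print(f"{label}: dt = {dt_s}, limit = {limit:.3f}, stable = {dt_s <= limit}")
```

All three parameter sets sit well inside the bound, so any blow-up seen after changing `dx` or `alpha` is worth checking against this limit first.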

In [ ]:
nx3 = 40
dx3 = 1.0
dt3 = 0.02
alpha3 = 0.2
steps3 = 400

x3 = np.linspace(0,1,nx3)
X3, Y3, Z3 = np.meshgrid(x3, x3, x3, indexing='ij')
S3 = 0.5 * np.sin(2*np.pi*X3) * np.sin(2*np.pi*Y3) * np.sin(2*np.pi*Z3)

def laplace_3d(S):
    # 3D 7-point stencil (6 neighbours + centre) with Neumann-like boundaries via np.pad
    Sp = np.pad(S, 1, mode='edge')
    lap = (
        Sp[2:,1:-1,1:-1] + Sp[:-2,1:-1,1:-1] +
        Sp[1:-1,2:,1:-1] + Sp[1:-1,:-2,1:-1] +
        Sp[1:-1,1:-1,2:] + Sp[1:-1,1:-1,:-2]
        - 6*Sp[1:-1,1:-1,1:-1]
    ) / dx3**2
    return lap

curv3 = []

for t in range(steps3):
    S3 = S3 + dt3 * (-1.0 + alpha3 * laplace_3d(S3))
    lap3 = laplace_3d(S3)
    curv3.append(np.sqrt(np.mean(lap3**2)))
    if t % 100 == 0:
        mid_slice = S3[:,:,nx3//2]
        plt.figure()
        plt.imshow(mid_slice)
        plt.title(f"3D S slice z=mid at step {t}")
        plt.colorbar()
        plt.show()

plt.figure()
plt.plot(curv3)
plt.xlabel('Time step')
plt.ylabel('3D ||ΔS||')
plt.title('3D curvature decay')
plt.grid(True)
plt.show()

Even in 3D, the qualitative story is unchanged: transient curvature, then flattening, then dark-room equilibrium.

---
## 5. Finite-Element Sketch (FEniCSx / dolfinx)

Below is **illustrative code** for running the same PDE using FEniCSx.
It will only run if you have a FEniCSx environment (e.g. Docker image).

We do not execute it here, but it shows how to cast:

$$\partial_t S = -1 + \alpha \Delta S$$

into a finite-element variational form.
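
For reference, the weak form follows in two steps, assuming zero-flux (homogeneous Neumann) boundary conditions and an implicit Euler step of size $dt$. Multiply the time-discrete equation by a test function $w$ (named $w$ here to avoid confusion with the policy variable $v$) and integrate by parts over the domain $\Omega$:

$$\int_\Omega \frac{S^{n+1} - S^n}{dt}\, w \,\mathrm{d}x = -\int_\Omega w \,\mathrm{d}x - \alpha \int_\Omega \nabla S^{n+1} \cdot \nabla w \,\mathrm{d}x + \alpha \int_{\partial\Omega} \left(\partial_n S^{n+1}\right) w \,\mathrm{d}s.$$

The boundary integral vanishes because $\partial_n S = 0$ on $\partial\Omega$. Collecting the unknown $S^{n+1}$ on the left gives the linear problem solved at each step:

$$\int_\Omega S^{n+1} w \,\mathrm{d}x + \alpha\, dt \int_\Omega \nabla S^{n+1} \cdot \nabla w \,\mathrm{d}x = \int_\Omega S^{n} w \,\mathrm{d}x - dt \int_\Omega w \,\mathrm{d}x.$$

Note the plus sign on the gradient term: integration by parts flips the sign of $-\alpha\,\Delta S$.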

In [ ]:
fenics_available = False
try:
    import dolfinx
    import dolfinx.fem.petsc  # must be imported explicitly so that fem.petsc.LinearProblem resolves
    from dolfinx import fem, mesh
    import ufl
    from mpi4py import MPI
    fenics_available = True
    print("FEniCSx detected.")
except ImportError:
    print("FEniCSx (dolfinx) not available in this environment.")

if fenics_available:
    # Simple 2D unit square mesh
    domain = mesh.create_unit_square(MPI.COMM_WORLD, 40, 40)
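    # Newer dolfinx releases expose the factory fem.functionspace; fall back to the older class name otherwise
    make_space = getattr(fem, "functionspace", None) or fem.FunctionSpace
    V = make_space(domain, ("Lagrange", 1))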

    S = fem.Function(V)
    S_old = fem.Function(V)
    S_old.interpolate(lambda x: 0.5*np.sin(2*np.pi*x[0])*np.sin(2*np.pi*x[1]))
    S.x.array[:] = S_old.x.array

    alpha_fe = 0.2
    dt_fe = 0.05

    v = ufl.TestFunction(V)
    S_trial = ufl.TrialFunction(V)

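    # Variational form for implicit Euler:
    # (S^{n+1} - S^n)/dt = -1 + alpha Δ S^{n+1}
    # Integrating by parts turns -alpha*Δ S into +alpha*grad(S)·grad(v);
    # the boundary term vanishes for zero-flux (natural Neumann) conditions.
    F = ((S_trial - S_old)*v*ufl.dx
         + dt_fe*v*ufl.dx
         + dt_fe*alpha_fe*ufl.dot(ufl.grad(S_trial), ufl.grad(v))*ufl.dx)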
    a = ufl.lhs(F)
    L = ufl.rhs(F)

    problem = fem.petsc.LinearProblem(a, L, u=S)

    nsteps_fe = 50
    for n in range(nsteps_fe):
        S_old.x.array[:] = S.x.array
        S = problem.solve()
        # You can export S via XDMF for visualization in Paraview.
        
    print("FEniCSx simulation finished (no visualization in this notebook).")