In [None]:
import jax.numpy as jnp

import matplotlib.pyplot as plt
import tornadox
from matplotlib import animation

import pnmol

from IPython.display import HTML

In [None]:
# Re-import modules that changed on disk before executing each cell, so edits to
# imported packages take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Discretize a 1D PDE

In [None]:
# Spatially discretized 1D Burgers problem from pnmol.
# To run the 1D heat problem instead, use `pnmol.pde_problems.heat_1d()`.
discretized_pde = pnmol.pde_problems.burgers_1d()

In [None]:

# Number of derivatives the solver tracks (passed as `num_derivatives`).
nu = 3

# Step-size rules. The solver below uses the error-controlled adaptive rule
# (abstol = reltol = 1e-3); `constant_steps` (fixed step 0.01) is an alternative.
adaptive_steps = tornadox.step.AdaptiveSteps(abstol=1e-3, reltol=1e-3)
constant_steps = tornadox.step.ConstantSteps(0.01)

In [None]:
# Solve the discretized PDE (passed as the IVP) with tornadox's ReferenceEK1,
# using `nu` derivatives and adaptive step-size selection.
ek1 = tornadox.ek1.ReferenceEK1(
    num_derivatives=nu,
    steprule=adaptive_steps,
)
sol = ek1.solve(ivp=discretized_pde)

In [None]:
# Projection applied to each solver state below to obtain the plotted solution
# values (presumably selects the zeroth derivative — confirm in tornadox).
E0 = ek1.iwp.projection_matrix(0)

# Sanity-check dimensions: mean and covariance at time index 3 (assumes at least
# four stored time points), then the projection matrix — one shape per line.
print(*(arr.shape for arr in (sol.mean[3], sol.cov[3], E0)), sep="\n")

In [None]:
# Matplotlib interprets `animation.embed_limit` in megabytes (not bytes), so
# 200 caps the embedded JavaScript/HTML animation at 200 MB.
plt.rcParams["animation.embed_limit"] = 200

grid = discretized_pde.spatial_grid

# Figure and axes reused by `animate`. `animate` clears and redraws `ax` on every
# frame, including the first (FuncAnimation draws frame 0 when no init_func is
# given), so nothing is plotted here up front.
fig, ax = plt.subplots(figsize=(20, 8))

# Fixed y-limits shared by every frame so the view does not jump between frames.
ax1ylim = [-0.2, 1.2]
ax.set_ylim(ax1ylim)

# Close the figure so the notebook displays only the embedded animation,
# not an extra static copy of the axes.
plt.close(fig)


def animate(i):
    """Draw animation frame ``i``: the PN mean at ``sol.t[i]`` with a +/- 2 std band.

    Clears and redraws the module-level axes ``ax`` using the solver output
    ``sol``, the projection ``E0``, the spatial ``grid`` and the fixed y-limits
    ``ax1ylim`` (all defined in earlier cells).

    Args:
        i: Index into the solver's stored time points ``sol.t``.
    """
    points = grid.points.squeeze()
    mean = E0 @ sol.mean[i]

    # Marginal variances of the projected state are diag(E0 @ cov @ E0^T).
    # Projecting the std vector (E0 @ sqrt(diag(cov))) only agrees with this
    # when E0 is a pure 0/1 selection matrix.
    var = jnp.diag(E0 @ sol.cov[i] @ E0.T)
    # Clamp tiny negative round-off so sqrt never produces NaN gaps in the band.
    std = jnp.sqrt(jnp.maximum(var, 0.0))

    ax.cla()
    ax.set_title(f"t={sol.t[i]}")
    ax.plot(points, mean, color="C0", label="PN solution")
    ax.fill_between(
        points,
        mean - 2 * std,
        mean + 2 * std,
        color="C0",
        alpha=0.2,
    )
    # Labels and limits must be re-applied each frame because cla() resets them.
    ax.set_xlabel("x")
    ax.set_ylim(ax1ylim)
    ax.legend()
    
# One frame per stored solver time point, 100 ms between frames and a 4 s pause
# before repeating; rendered inline as interactive JavaScript/HTML.
anim = animation.FuncAnimation(
    fig,
    func=animate,
    frames=len(sol.t),
    interval=100,
    repeat_delay=4000,
    blit=False,
)
HTML(anim.to_jshtml())