# In [None]:
import apebench
import exponax as ex

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import CenteredNorm
import jax

# Sample locations handed to the initial-condition functions: 160 values
# on [0, 1], laid out as a single row of shape (1, 160).
# NOTE(review): linspace includes the endpoint 1.0; if the solver treats the
# domain as periodic, the last sample coincides with the first -- confirm
# whether endpoint=False was intended.
x = np.linspace(0, 1, 160)[np.newaxis, :]

# Draw one random sine-wave initial condition (fixed PRNG seed 2) and plot
# it as a quick visual sanity check.
sine_wave_family = ex.ic.RandomSineWaves1d(
    num_spatial_dims=1,
    domain_extent=1,
    cutoff=5,
    amplitude_range=(-1, 1),
    max_one=True,
)
sine_ic_gen = sine_wave_family.gen_ic_fun(key=jax.random.PRNGKey(2))
IC = sine_ic_gen(x)
plt.plot(IC.T)
plt.show()

def make_ic(ic_key, cutoff=5):
    """Sample a random sine-wave initial condition on the module-level grid ``x``.

    Args:
        ic_key: Integer seed passed to ``jax.random.PRNGKey``.
        cutoff: Forwarded as ``cutoff`` to ``ex.ic.RandomSineWaves1d``.

    Returns:
        The generated initial-condition function evaluated at ``x``.
    """
    wave_family = ex.ic.RandomSineWaves1d(
        num_spatial_dims=1,
        domain_extent=1,
        cutoff=cutoff,
        amplitude_range=(-1, 1),
        max_one=True,
    )
    ic_fun = wave_family.gen_ic_fun(key=jax.random.PRNGKey(ic_key))
    return ic_fun(x)

def make_trajectory(ic, g_2=-2, g_4=-18, d_1=-1, n=100):
    """Roll out the conservative Kuramoto-Sivashinsky reference stepper from ``ic``.

    Args:
        ic: Initial state handed to the rollout function.
        g_2: Passed as ``diffusion_gamma`` to the scenario.
        g_4: Passed as ``hyp_diffusion_gamma`` to the scenario.
        d_1: Passed as ``convection_delta`` to the scenario.
        n: Number of steps given to ``ex.rollout``.

    Returns:
        Whatever ``ex.rollout`` produces for the scenario's reference stepper.
    """
    scenario = apebench.scenarios.difficulty.KuramotoSivashinskyConservative(
        diffusion_gamma=g_2,
        hyp_diffusion_gamma=g_4,
        convection_delta=d_1,
    )
    reference_stepper = scenario.get_ref_stepper()
    rollout_fn = ex.rollout(reference_stepper, n=n)
    return rollout_fn(ic)

def plot_traj_ic(traj, ic, suptitle="", maxv=3):
    """Draw a trajectory heatmap next to a line plot of its first and last states.

    The left panel shows ``traj`` (axis 1 squeezed out, then transposed) as an
    image on a colour scale centred at zero with half-range ``maxv``.  The
    right panel overlays ``ic[0]`` and the last entry of ``traj`` with the
    y-axis limited to ``[-maxv, maxv]``.

    Args:
        traj: Trajectory array; must have a length-1 axis at position 1.
            (presumably shaped (steps, 1, points) -- TODO confirm)
        ic: Initial condition; its first row is plotted.
        suptitle: Figure-level title.
        maxv: Symmetric value limit shared by the colour scale and the y-axis.

    Returns:
        The created matplotlib figure.
    """
    fig, axes = plt.subplots(1, 2, width_ratios=[3, 2], figsize=(13, 5))
    heat_ax, line_ax = axes
    fig.suptitle(suptitle)

    # Left panel: heatmap of the whole trajectory.
    heatmap = heat_ax.imshow(
        traj.squeeze(1).T,
        cmap='coolwarm',
        norm=CenteredNorm(halfrange=maxv),
        aspect='auto',
    )
    fig.colorbar(heatmap, ax=heat_ax)

    # Right panel: initial condition versus final state.
    final_state = traj[-1]
    line_ax.plot(ic[0], label='IC')
    line_ax.plot(final_state.T, label='T')
    line_ax.set_ylim(-maxv, maxv)
    line_ax.legend(loc='upper right', fontsize=6)

    fig.tight_layout()
    return fig

# --- Parameter sweep over (g_2, g_4) -----------------------------------------
# For every pair on an N x N grid, roll out a trajectory from one fixed initial
# condition and record (a) whether every value stayed finite and (b) for the
# finite cases, the standard deviation along axis 0 averaged over the rest.
N = 20
interval_mins = [-15, -18*49]
stable_table = np.zeros((N, N), dtype=bool)
std_table = np.zeros((N, N))

# g_2: linearly spaced from interval_mins[0] towards 0 (0 itself excluded).
g2_s = np.linspace(interval_mins[0], 0, N, endpoint=False)
# g_4: negative values, base-2 log-spaced from interval_mins[1] towards -1
# (-1 itself excluded).
g4_s = -np.logspace(np.log2(-interval_mins[1]), 0, N, endpoint=False, base=2)

IC = make_ic(40)
for row, g_2 in enumerate(g2_s):
    for col, g_4 in enumerate(g4_s):
        # Progress indicator: one printed line per g_2 row.
        line_end = '\n' if col == N - 1 else "|"
        print(row, col, end=line_end)

        traj = make_trajectory(IC, g_2, g_4)
        stable_table[row, col] = jax.numpy.isfinite(traj).all()
        if stable_table[row, col]:
            std_table[row, col] = jax.numpy.std(traj, axis=0).mean()

# --- Heatmaps: finiteness mask (left) and averaged std (right) ---------------
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 5), sharex=True, sharey=True)
ax1.imshow(stable_table, cmap='Greys')
im = ax2.imshow(std_table, cmap='OrRd')
fig.colorbar(im, ax=ax2, shrink=0.8, label='avg_std')

# Ticks sit on cell edges (index - 0.5), every N // 5 cells plus the far edge.
# NOTE(review): the far-edge tick is labelled with the last sampled value, not
# the excluded interval endpoint -- confirm this is the intended labelling.
# NOTE(review): positions and labels only line up in count when N is a
# multiple of 5 (as it is here).
tick_stride = N // 5
edge_ticks = np.arange(0, N + 1)[::tick_stride] - 0.5
g2_labels = list(g2_s[::tick_stride]) + [g2_s[-1]]
g4_labels = [value.round(2) for value in list(g4_s[::tick_stride]) + [g4_s[-1]]]
ax1.set_yticks(edge_ticks, g2_labels, rotation=45)
ax1.set_xticks(edge_ticks, g4_labels, rotation=45)
ax1.set_ylabel('g_2')
ax1.set_xlabel('g_4')

plt.show()

def animate(g_2=-2, g_4=-18, d_1=-1, ic_key=1, maxv=3):
    """Simulate one parameter setting and return its summary figure.

    Builds the initial condition for ``ic_key``, rolls out 100 steps with the
    given ``g_2`` / ``g_4`` / ``d_1`` and plots the result with
    ``plot_traj_ic``, using the parameter values as the figure title.

    Args:
        g_2: Forwarded to ``make_trajectory``.
        g_4: Forwarded to ``make_trajectory``.
        d_1: Forwarded to ``make_trajectory``.
        ic_key: Seed forwarded to ``make_ic``.
        maxv: Value limit forwarded to ``plot_traj_ic``.

    Returns:
        The figure produced by ``plot_traj_ic``.
    """
    initial_state = make_ic(ic_key)
    rollout = make_trajectory(initial_state, g_2, g_4, d_1, n=100)
    return plot_traj_ic(
        rollout,
        initial_state,
        suptitle=f"g_2={g_2}, g_4={g_4}, d_1={d_1}",
        maxv=maxv,
    )

# Render a single hand-picked parameter setting.
fig = animate(g_2=-12, g_4=-320, d_1=-1, ic_key=1, maxv=6)
display(fig)
# The figure was created through pyplot and stays registered there until it
# is closed.  Close it once shown so it is not kept alive (and, with an
# inline notebook backend, not drawn a second time when the cell finishes).
plt.close(fig)

# prompt: i want animated slider of showing image with three sliders, i already defined function to return figure with 3 arguments

import ipywidgets as widgets
from IPython.display import display, clear_output

# `animate` is defined earlier in this file; the sliders below feed its arguments.

# Lower slider bounds reuse the limits of the parameter sweep.
g2_lower, g4_lower = interval_mins

# One slider per argument of `animate` that is exposed interactively.
g2_slider = widgets.FloatSlider(
    description='g_2:', value=-2, min=g2_lower, max=0, step=0.5,
)
g4_slider = widgets.FloatSlider(
    description='g_4:', value=-18, min=g4_lower, max=0, step=0.5,
)
d1_slider = widgets.FloatSlider(
    description='d_1:', value=-1, min=-10, max=0, step=0.1,
)
ic_slider = widgets.IntSlider(
    description='ic_key:', value=1, min=1, max=100, step=1,
)

# Output area that holds the rendered figure.
output = widgets.Output()

def update_plot(change):
    """Re-run the simulation with the current slider values and redraw it.

    Intended as an observer callback for the sliders.  The current values are
    read directly from the slider widgets, so ``change`` is not used.

    Args:
        change: Change notification supplied by the observed widget (ignored).
    """
    with output:
        clear_output(wait=True)  # Replace the previous plot instead of stacking
        fig = animate(
            g_2=g2_slider.value,
            g_4=g4_slider.value,
            d_1=d1_slider.value,
            ic_key=ic_slider.value,
        )
        display(fig)
        # Each call creates a new pyplot figure; without closing it, figures
        # pile up in pyplot's registry on every slider movement.
        plt.close(fig)


# Redraw whenever any slider's value changes.
for slider in (g2_slider, g4_slider, d1_slider, ic_slider):
    slider.observe(update_plot, 'value')

display(g2_slider, g4_slider, d1_slider, ic_slider, output)

# Initial plot, so the output area is not empty before the first interaction.
with output:
    fig = animate(
        g_2=g2_slider.value,
        g_4=g4_slider.value,
        d_1=d1_slider.value,
        ic_key=ic_slider.value,
    )
    display(fig)
    # Close the figure once it has been rendered into the output widget so it
    # does not stay registered with pyplot (and is not drawn again outside the
    # widget by an inline notebook backend).
    plt.close(fig)


