In [1]:
import jax
import jax.numpy as jnp
import optax
import equinox as eqx
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.colors import Normalize
from scipy.stats import t
from matplotlib.ticker import MaxNLocator
import re
import numpy as np
# import pandas as pd
# import seaborn as sns
import glob
from time import process_time

from src.BTCS_Stepper import BTCS_Stepper, RandomTruncatedFourierSeries, rollout, dataloader
from src.prdp import should_refine

# plt.rcParams['figure.dpi'] = 200
plt.rcParams["font.family"] = "serif"

# add magic comments for autoreload
%load_ext autoreload
%autoreload 2

In [None]:
# Show the devices JAX has available (quick check of which backend is in use)
jax.devices()

In [3]:
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
NUM_INITIAL_CONDITIONS = 200    # training set size
N_TEST_SAMPLES = 5              # validation set size

BATCH_SIZE = 25
ic_generator = RandomTruncatedFourierSeries(domain_extent=1.0, num_modes=5)

SOLVER_NAME = "jacobi"     # "jacobi" or "SD"
U_INIT_STRING = "zeroinit" # "randinit" or "zeroinit" or "onesinit"

# Solver-specific settings: problem size, reference stepper, the inner linear
# solver used inside the training loss, and the default number of epochs.
if SOLVER_NAME == "jacobi":
    N_DOF = 30
    btcs_stepper = BTCS_Stepper(num_points=N_DOF)
    LINSOLVER = btcs_stepper.jacobi_dynamic
    N_EPOCHS = 100
elif SOLVER_NAME == "SD":
    N_DOF = 50
    btcs_stepper = BTCS_Stepper(num_points=N_DOF)
    LINSOLVER = btcs_stepper.sd_dynamic
    N_EPOCHS = 200
else:
    # Fail loudly instead of leaving N_DOF / LINSOLVER undefined (or stale from
    # a previous run of this cell).
    raise ValueError(f"Invalid value for SOLVER_NAME: {SOLVER_NAME}")

# Initial guess handed to the iterative linear solver.
if U_INIT_STRING == "zeroinit":
    U_INIT = jnp.zeros(N_DOF)
elif U_INIT_STRING == "randinit":
    U_INIT = jax.random.normal(jax.random.PRNGKey(0), shape=(N_DOF,))
elif U_INIT_STRING == "onesinit":
    U_INIT = jnp.ones(N_DOF)
else:
    raise ValueError(f"Invalid value for U_INIT_STRING: {U_INIT_STRING}")

# Interior points of a uniform grid on [0, 1] (both boundary points dropped).
grid = jnp.linspace(0, 1, N_DOF+2)[1:-1]

@eqx.filter_jit
def val_loss(m, val_data):
    """Compute normalized one-step and two-step errors on the validation set.

    The model is rolled out for two steps from the first state of every
    validation trajectory. For each step, the squared L2 error against the
    reference state is divided by the squared L2 norm of that same reference
    state and then averaged over all samples.

    Args:
        m: the model to evaluate
        val_data: the validation data, with shape (n_samples, n_steps, n_dof);
            only time indices 0, 1 and 2 are read.

    Returns:
        Array holding two values: the normalized mean squared error after one
        step and after two steps.
    """
    # Executed only while the function is being traced, so it signals (re)compilation.
    print("compiling val_loss()")
    val_ic_set = val_data[:, 0]
    pred_trajectories = jax.vmap(rollout(m, 2, include_init=True))(val_ic_set)
    pred_1_errors = jnp.linalg.norm(pred_trajectories[:, 1] - val_data[:, 1], axis=1) # normed over n_dof
    pred_2_errors = jnp.linalg.norm(pred_trajectories[:, 2] - val_data[:, 2], axis=1) # normed over n_dof

    # Normalize each error by the reference state it is measured against
    # (time index 1 for the one-step error, time index 2 for the two-step error).
    data_1_norms = jnp.linalg.norm(val_data[:, 1], axis=1) # norm over n_dof
    data_2_norms = jnp.linalg.norm(val_data[:, 2], axis=1) # norm over n_dof

    pred_1_mse_normalized = jnp.mean((pred_1_errors**2 / data_1_norms**2), axis=0) # mean over all samples
    pred_2_mse_normalized = jnp.mean((pred_2_errors**2 / data_2_norms**2), axis=0) # mean over all samples

    return jnp.hstack((pred_1_mse_normalized, pred_2_mse_normalized))

### Generate Train and Validation data

In [None]:
def _draw_ic_samples(prng_key, n_samples):
    """Sample `n_samples` random initial conditions and evaluate them on `grid`.

    Returns the per-sample PRNG keys, the generated initial-condition
    functions and their values on the grid.
    """
    sample_keys = jax.random.split(prng_key, n_samples)
    sample_funs = jax.vmap(ic_generator)(sample_keys)
    sample_values = jax.vmap(lambda fun: fun(grid))(sample_funs)
    return sample_keys, sample_funs, sample_values


# Training set: initial conditions and their reference two-step rollouts
key = jax.random.PRNGKey(1337)
ic_keys, ic_funs, ic_set = _draw_ic_samples(key, NUM_INITIAL_CONDITIONS)
train_set = jax.vmap(rollout(btcs_stepper, 2, include_init=True))(ic_set)

# Validation set: drawn from a different seed, rolled out for 100 steps
key = jax.random.PRNGKey(1338)
test_ic_keys, test_ic_funs, test_ic_set = _draw_ic_samples(key, N_TEST_SAMPLES)
val_data_trjs = jax.vmap(rollout(btcs_stepper, 100, include_init=True))(test_ic_set)

### Linear-solver residual history

In [None]:
def relative_residuum_hist(state, n_inner):
    """Residual-norm history of the linear solve, relative to the norm of `state`.

    Args:
        state: the state passed to `btcs_stepper.residuum_history`
        n_inner: number of solver iterations to record

    Returns:
        One relative residual norm per recorded iterate; presumably of shape
        (n_iter+1,), i.e. including the initial guess -- TODO confirm against
        `BTCS_Stepper.residuum_history`.
    """
    residuals = btcs_stepper.residuum_history(state, SOLVER_NAME, n_inner)
    state_norm = jnp.linalg.norm(state)
    residual_norms = jnp.linalg.norm(residuals, axis=1)  # one norm per iterate
    return residual_norms / state_norm

# Relative residual history for every training sample (evaluated on time index 1
# of each training trajectory), summarized by its mean and standard deviation.
# Alternative input: the initial conditions themselves
# res_hist_all = jax.vmap(relative_residuum_hist, in_axes=(0, None))(ic_set, 50) # (n_samples, n_iter+1)
pred_1_set = train_set[:, 1]
res_hist_all = jax.vmap(relative_residuum_hist, in_axes=(0, None))(
    pred_1_set, 50
)  # (n_samples, n_iter+1)
res_hist_mean = res_hist_all.mean(axis=0)
res_hist_std = res_hist_all.std(axis=0)

fig, ax = plt.subplots(figsize=(3,3))
ax.plot(res_hist_mean, label="mean")
ax.fill_between(
    range(res_hist_mean.shape[0]),
    res_hist_mean - res_hist_std,
    res_hist_mean + res_hist_std,
    alpha=0.2,
    label="std",
)
ax.set(
    yscale="log",
    xlabel="# iterations",
    ylabel="Primal relative residual",
    title=f"Heat 1D, Residuums, solver={SOLVER_NAME}\n Average of {NUM_INITIAL_CONDITIONS} initial conditions",
)

# Grid: dashed minor + major lines along x, solid major lines on both axes
ax.grid(which='major', axis='y')
ax.minorticks_on()
ax.grid(which='both', axis='x', linestyle='--', linewidth=0.5)
ax.grid(which='major', axis='both', linestyle='-', linewidth=1.0)

# fig.savefig(f"figures/heat_1d__primal_residuum__{SOLVER_NAME}.pdf", bbox_inches="tight")

### Define custom vjp

In [6]:
@eqx.filter_custom_vjp
def linsolver(state, n_iter, u_init):
    """Run the configured linear solver; gradients w.r.t. `state` use the custom rule below."""
    return LINSOLVER(state, n_iter, u_init)


@linsolver.def_fwd
def linsolver_fwd(perturbed, state, n_iter, u_init):
    """Forward pass of the custom VJP: same result as `linsolver`, no residuals stored."""
    return linsolver(state, n_iter, u_init), None


@linsolver.def_bwd
def linsolver_bwd(res, g, perturbed, state, n_iter, u_init):
    """Backward pass: apply the same iterative solver to the incoming cotangent.

    The exact adjoint solves A^T v = g. Re-using LINSOLVER here approximates
    v = A^{-1} g, which matches the adjoint only if the system matrix is
    symmetric -- NOTE(review): confirm this holds for BTCS_Stepper.
    """
    print("using custom vjp")
    # g is the gradient of the loss w.r.t. the solver output, shape (n_dof,)
    cotangent_state = LINSOLVER(g, n_iter, u_init)
    return cotangent_state

In [8]:
SOLVER_JACOBIANS = "implicit" # "unrolled" or "implicit"

# "unrolled": differentiate through the solver iterations themselves
# "implicit": use the custom-VJP wrapper around the solver
_INNER_SOLVERS = {
    "unrolled": LINSOLVER,
    "implicit": linsolver,
}
if SOLVER_JACOBIANS not in _INNER_SOLVERS:
    raise ValueError(f"Invalid value for SOLVER_JACOBIANS: {SOLVER_JACOBIANS}")
inner_solve = _INNER_SOLVERS[SOLVER_JACOBIANS]

# Outer loop

In [9]:
# Optimizer: Adam with an exponentially decaying learning rate
# (starts at 1e-3, multiplied by 0.94 per 100 optimizer steps)
optimizer = optax.adam(
    learning_rate=optax.exponential_decay(
        init_value=1e-3,
        transition_steps=100,
        decay_rate=0.94,
    )
)

# Loss Function
@eqx.filter_jit
def loss_fn(model, data, inner_iterations):
    """Training loss: model step followed by one inner linear solve.

    Args:
        model: the network, mapping a state of shape (n_dof,) to the next state
        data: batch of trajectories; time index 0 is used as input and time
            index 2 as target
        inner_iterations: number of iterations given to the inner solver

    Returns:
        Mean squared error between the two-step prediction and the target,
        averaged over the batch and over space.
    """
    # Executed only while the function is being traced, so it signals (re)compilation.
    print("compiling loss_fn")
    ic = data[:, 0]
    target = data[:, 2]
    prediction_1 = jax.vmap(model)(ic) # batched forward pass # (batch_size, n_dof)
    # second step: solver applied to the model output, with U_INIT as initial guess
    prediction_2 = jax.vmap(inner_solve, in_axes=(0, None, None))(prediction_1, inner_iterations, U_INIT)
    return jnp.mean((prediction_2 - target)**2) # MSE over batches as well as space

@eqx.filter_jit
def update_fn(model, state, data, inner_iterations):
    """Perform one optimizer step on a mini-batch.

    Args:
        model: current model
        state: current optimizer state
        data: mini-batch passed on to `loss_fn`
        inner_iterations: number of inner solver iterations used in the loss

    Returns:
        Tuple of (updated model, updated optimizer state, loss on this batch).
    """
    print("compiling update_fn")
    value_and_grad = eqx.filter_value_and_grad(loss_fn)
    batch_loss, grads = value_and_grad(model, data, inner_iterations)
    param_updates, next_opt_state = optimizer.update(grads, state, model)
    return eqx.apply_updates(model, param_updates), next_opt_state, batch_loss


## Loop over seeds_list, n_inner_list

In [None]:
SAVE_RESULTS = False
# SEED_LIST = [1, 2, 25, 50, 1000, 1337, 2668, 3999, 12345, 54321]
SEED_LIST = [1]
N_INNER_LIST = [1,2,3,4,5,6,10,15,20,25]

# NOTE: this overrides the solver-specific N_EPOCHS set in the configuration cell.
N_EPOCHS = 100
print(f"SOLVER NAME: {SOLVER_NAME} , SOLVER JACOBIAN = {SOLVER_JACOBIANS}, U_INIT = {U_INIT[:2]}...{U_INIT[-2:]}, N_DOF = {N_DOF}")

for seed_count, seed in enumerate(SEED_LIST):

    print(f"Training with seed {seed} ({seed_count+1} of {len(SEED_LIST)})")
    key = jax.random.PRNGKey(seed)
    key, model_init_key = jax.random.split(key)

    # Per-seed metrics, one entry per value in N_INNER_LIST
    losses_all_n = []
    errors_all_n = []
    time_all_n = []

    # Loop over n_inner
    for n_inner in N_INNER_LIST:
        print(f"\nTraining with {n_inner} inner iterations\n")

        # Initialize model. The same key is used for every n_inner, so all
        # runs of one seed start from identical weights.
        model_MLP = eqx.nn.MLP(
            in_size=N_DOF, out_size=N_DOF,
            width_size=64, depth=3,
            activation=jax.nn.relu,
            key=model_init_key)

        # initialize optimizer
        opt_state = optimizer.init(eqx.filter(model_MLP, eqx.is_array))

        # Metrics before training (index 0 of each history)
        loss_history = [loss_fn(model_MLP, train_set, n_inner)]
        error_history = [val_loss(model_MLP, val_data_trjs)]

        # Training Loop
        key, shuffle_key = jax.random.split(key)

        # CPU time of the whole training loop for this n_inner. It includes the
        # per-epoch validation and any compilation triggered inside the loop.
        train_start = process_time()
        for epoch in range(N_EPOCHS):
            shuffle_key, subkey = jax.random.split(shuffle_key)
            loss_mini_batch = []
            for batch in dataloader(train_set, key=subkey, batch_size=BATCH_SIZE):
                model_MLP, opt_state, loss = update_fn(model_MLP, opt_state, batch, n_inner)
                loss_mini_batch.append(loss)

            loss_history.append(np.mean(loss_mini_batch))
            error_history.append(val_loss(model_MLP, val_data_trjs))

            print(f"Epoch {epoch+1}/{N_EPOCHS}, loss: {loss_history[-1]}, rel error: {error_history[-1]}")
        time_all_n.append(process_time() - train_start)

        losses_all_n.append(loss_history)
        errors_all_n.append(np.array(error_history))

    losses_all_n = np.array(losses_all_n) # shape (len(N_INNER_LIST), N_EPOCHS+1)
    errors_all_n = np.array(errors_all_n) # shape (len(N_INNER_LIST), N_EPOCHS+1, 2)

    # save results
    if SAVE_RESULTS:
        # Imported lazily: only needed when results are actually written
        # (the top-level pandas import of this notebook is commented out).
        import os
        import pandas as pd

        # One row per entry of N_INNER_LIST
        df = pd.DataFrame({
            "max_iter": N_INNER_LIST,
            "losses": list(losses_all_n),
            "1-step errors": list(errors_all_n[:,:,0]),
            "2-step errors": list(errors_all_n[:,:,1]),
            "time": time_all_n,
            "seed": seed,
        })
        file_name = f"results/heat_1d_sep29_{SOLVER_NAME}_{SOLVER_JACOBIANS}_time/maxiter_constant__seed_{seed}.pkl"
        # Create the output directory so a finished run is not lost on a missing folder
        os.makedirs(os.path.dirname(file_name), exist_ok=True)
        df.to_pickle(file_name)

## Loop over seeds_list, use PRDP

In [None]:
# PRDP schedule: start with N_MIN inner iterations and add N_STEP on every refinement
N_MIN, N_STEP = 1,1

# SEED_LIST = [1, 2, 25, 50, 1000, 1337, 2668, 3999, 12345, 54321]
SEED_LIST = [1]
# NOTE: this overrides the solver-specific N_EPOCHS set in the configuration cell.
N_EPOCHS = 100
SAVE_RESULTS = False
print(f"PRDP: SOLVER_NAME: {SOLVER_NAME} , SOLVER_JACOBIAN = {SOLVER_JACOBIANS}, U_INIT = {U_INIT[:2]}...{U_INIT[-2:]}")

for seed_count, seed in enumerate(SEED_LIST):
    print(f"Training with seed {seed} ({seed_count+1}/{len(SEED_LIST)})")
    key = jax.random.PRNGKey(seed)

    # init model to be trained
    key, model_init_key = jax.random.split(key)
    model_mlp_prdp = eqx.nn.MLP(
        in_size=N_DOF, out_size=N_DOF,
        width_size=64, depth=3,
        activation=jax.nn.relu,
        key=model_init_key
    )

    # initialize optimizer
    opt_state = optimizer.init(eqx.filter(model_mlp_prdp, eqx.is_array))

    # initialize metrics (index 0 of each history is the state before training)
    n_inner_tracker = N_MIN
    loss_hist_prdp = [loss_fn(model_mlp_prdp, train_set, n_inner_tracker)]
    error_hist_prdp = [val_loss(model_mlp_prdp, val_data_trjs)]
    n_inner_hist_prdp = [np.nan] # no value at zeroth epoch, but need same list length as loss_hist

    # Reset PRDP's checkpoint error, which is stored as an attribute on the
    # should_refine function and would otherwise carry over between seeds.
    should_refine.error_checkpoint = 100

    # Training Loop
    key, shuffle_key = jax.random.split(key)
    for epoch in range(N_EPOCHS):
        shuffle_key, subkey = jax.random.split(shuffle_key)
        loss_mini_batch = []
        for batch in dataloader(train_set, key=subkey, batch_size=BATCH_SIZE):
            model_mlp_prdp, opt_state, loss = update_fn(model_mlp_prdp, opt_state, batch,
                                                        n_inner_tracker)
            loss_mini_batch.append(loss)

        loss_hist_prdp.append(np.mean(loss_mini_batch))
        error_hist_prdp.append(val_loss(model_mlp_prdp, val_data_trjs))
        n_inner_hist_prdp.append(n_inner_tracker)

        print(f"Epoch {epoch+1}/{N_EPOCHS}, n_inner: {n_inner_tracker}, loss: {loss_hist_prdp[-1]}, error: {error_hist_prdp[-1]}")

        # PRDP: decide from the two-step validation error history whether to
        # increase the number of inner iterations. The three numeric arguments
        # are passed positionally; see src.prdp.should_refine for their meaning.
        if should_refine(np.array(error_hist_prdp)[:, 1],  # [:,1] is the two-step error history
                         0.98, 0.9, 8):
            n_inner_tracker += N_STEP

    # SAVE
    loss_hist_prdp = np.array(loss_hist_prdp)
    error_hist_prdp = np.array(error_hist_prdp)

    if SAVE_RESULTS:
        # Imported lazily: only needed when results are actually written
        # (the top-level pandas import of this notebook is commented out).
        import os
        import pandas as pd

        # Single-row frame: each history is stored as one object in its column
        df = pd.DataFrame({
            "losses": [loss_hist_prdp],
            "1-step errors": [error_hist_prdp[:,0]],
            "2-step errors": [error_hist_prdp[:,1]],
            "n_inner": [n_inner_hist_prdp],
            "max_iter": "PRDP",
            "auto_using": "two-step-error",
            "seed": seed,
        })
        file_name = f"results/heat_1d_sep29_{SOLVER_NAME}_{SOLVER_JACOBIANS}_time/maxiter_auto__seed_{seed}.pkl"
        # Create the output directory so a finished run is not lost on a missing folder
        os.makedirs(os.path.dirname(file_name), exist_ok=True)
        df.to_pickle(file_name)

In [None]:
# Plot the two-step validation error of the last seed trained with PRDP.
# Index 0 on the x-axis is the untrained model.
fig, ax = plt.subplots()
ax.plot(error_hist_prdp[:, 1])
ax.set_yscale('log')
ax.set_xlabel("Epoch")
ax.set_ylabel("Two-step normalized validation MSE")
ax.set_title(f"PRDP training, solver={SOLVER_NAME}")
ax.grid()