In [2]:
import time
from pathlib import Path

import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import optax
import pandas as pd
import pdequinox as pdeqx
from matplotlib.animation import FuncAnimation

import src.implicax.implicax as implicax
from src.BTCS_Stepper import dataloader
from src.implicax.implicax.utilities import rollout
from src.prdp import numpy_ewma, should_refine

In [None]:
# Show the devices JAX can use; as the cell's last expression the list is displayed as output.
jax.devices()

In [4]:
# Resolution of the reference data: the random fields below are created with shape
# (1, N_REF, N_REF) and the reference solvers are constructed with N_REF.
N_REF = 97

# Generate data

### Generate random initial conditions: smoothed, normalized, divergence-free

In [None]:
NUM_SAMPLES = 205

# One PRNG key pair per sample: PRNGKey(i) is split into one key for the
# x-velocity field and one for the y-velocity field.
ic_key_pairs = [jax.random.split(jax.random.PRNGKey(i)) for i in range(NUM_SAMPLES)]
ic_field_shape = (1, N_REF, N_REF)
ic_x_set = jnp.stack([jax.random.normal(pair[0], ic_field_shape) for pair in ic_key_pairs])
ic_y_set = jnp.stack([jax.random.normal(pair[1], ic_field_shape) for pair in ic_key_pairs])

# Low-pass filter the white noise by pushing every sample through a heat solver.
heat_stepper = implicax.Heat2d(1.0, N_REF, 1.0, nu=3e-3)
smooth_batch = jax.vmap(heat_stepper)
ic_x_set = smooth_batch(ic_x_set)
ic_y_set = smooth_batch(ic_y_set)


def _standardize_per_sample(fields):
    """Remove the spatial mean of each field, then scale it to unit spatial std."""
    centered = fields - jnp.mean(fields, axis=(-1, -2), keepdims=True)
    return centered / jnp.std(centered, axis=(-1, -2), keepdims=True)


ic_x_set = _standardize_per_sample(ic_x_set)
ic_y_set = _standardize_per_sample(ic_y_set)

# The NS state has three channels (velocity-x, velocity-y, pressure);
# the pressure channel starts at zero.
ic_pressure_set = jnp.zeros_like(ic_x_set)
ic_set = jnp.concatenate((ic_x_set, ic_y_set, ic_pressure_set), axis=-3)

# Reference NS simulator (nu=1e-4); its make_incompressible method is applied to every
# sample to obtain a divergence-free initial condition.
# NOTE(review): the original comment stated Re=1000 — confirm against nu=1e-4.
ns_simulator = implicax.NavierStokes(1.0, N_REF, 0.1, nu=1e-4, maxiter_picard=1)
ic_set = jax.vmap(ns_simulator.make_incompressible)(ic_set)

assert ic_set.shape == (NUM_SAMPLES, 3, N_REF, N_REF)

### Generate trajectories

In [9]:
NUM_TIMESTEPS = 5

# Build the rollout function once, then map it over the batch of initial conditions.
reference_rollout = rollout(ns_simulator, NUM_TIMESTEPS, include_init=True)
trj_set = jax.vmap(reference_rollout)(ic_set)

# sanity check: (sample, time step incl. initial state, channel, space, space)
expected_trj_shape = (NUM_SAMPLES, NUM_TIMESTEPS + 1, 3, N_REF, N_REF)
assert trj_set.shape == expected_trj_shape


In [None]:
# Downsample the reference trajectories for the coarser source resolution:
# keep every second grid point along both spatial axes, starting at index 1.
trj_set_source = trj_set[:, :, :, 1::2, 1::2]
# x[1::2] keeps len(x) // 2 entries. Integer division keeps the expected shape a tuple
# of ints; the previous true division produced floats that only matched by value.
assert trj_set_source.shape == (NUM_SAMPLES, NUM_TIMESTEPS + 1, 3, N_REF // 2, N_REF // 2)
NDOF_SOURCE = trj_set_source.shape[3]

### Train:Validation split

In [11]:
# Train/validation split along the sample axis: the first NUM_TRAIN_SAMPLES
# trajectories are used for training, the remaining ones for validation.
NUM_TRAIN_SAMPLES = 200
train_set = trj_set_source[:NUM_TRAIN_SAMPLES]
val_set = trj_set_source[NUM_TRAIN_SAMPLES:]
# Fail early if the number of generated samples is changed without adapting the split.
assert len(train_set) == NUM_TRAIN_SAMPLES and len(val_set) > 0

# Learning NS

### Loss and test_loss for correction learning

$$
u_0 \xrightarrow{\mathcal{P}_s} u_{1,s} \xrightarrow{\mathbb{C}_{\theta}} u_{1,c} \xrightarrow{\mathcal{P}_s} u_{2,s} \xrightarrow{\mathbb{C}_{\theta}} u_{2,c}
$$

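The source solver $\mathcal{P}_s$ is coarser than the reference solver $\mathcal{P}_r$ because it uses about half the number of grid points per spatial dimension, so the corrector $\mathbb{C}_{\theta}$ has to learn the fine-scale physics that the coarse grid does not resolve.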


In [None]:
# Navier-Stokes solver on the coarse source grid (NDOF_SOURCE points). It is the default
# solver of loss_fn, val_loss and update_fn below.
# NOTE(review): restart=8 / maxiter_linsolve=120 presumably configure the inner (GMRES)
# linear solve — confirm against implicax.
ns_sim_halfspace = implicax.NavierStokes(1.0, NDOF_SOURCE, 0.1, nu=1e-4, maxiter_picard=1, restart=8, maxiter_linsolve=120)

In [14]:
# Loss fn for correction learning
@eqx.filter_jit
def loss_fn(model, data, coarse_sim_1=ns_sim_halfspace, coarse_sim_2=ns_sim_halfspace):
    """
    Two-step unrolled loss for correction learning.

    Starting from ``data[:, 0]``, each step applies a coarse solver and then adds the
    model's output (the learned correction) to the solver's prediction.

    Args:
        model: network mapping a coarse prediction to an additive correction.
        data: batch of trajectories; index 0 of axis 1 is the initial state,
            indices 1 and 2 are the targets after one and two steps.
        coarse_sim_1: solver used for the first step.
        coarse_sim_2: solver used for the second step.

    Returns:
        Sum of the mean squared errors of the corrected 1-step and 2-step predictions.
    """
    # Runs only while the function is being traced.
    print("compiling loss_fn")

    def corrected_step(coarse_sim, state):
        # The model only predicts the correction; it is added to the solver output here.
        coarse = jax.vmap(coarse_sim)(state)
        return coarse + jax.vmap(model)(coarse)

    after_one_step = corrected_step(coarse_sim_1, data[:, 0])
    after_two_steps = corrected_step(coarse_sim_2, after_one_step)

    mse_two_steps = jnp.mean((data[:, 2] - after_two_steps) ** 2)
    mse_one_step = jnp.mean((data[:, 1] - after_one_step) ** 2)
    return mse_two_steps + mse_one_step

In [15]:
@eqx.filter_jit
def val_loss(model, test_data, coarse_sim_1=ns_sim_halfspace, coarse_sim_2=ns_sim_halfspace):
    """Normalised MSE of the corrected rollout after 1, 2 and 5 steps.

    Args:
        model: the model to evaluate
        test_data: the test data, with shape (n_samples, n_steps, n_channels, N, N)
        coarse_sim_1: solver used for steps 1, 3 and 5
        coarse_sim_2: solver used for steps 2 and 4

    Returns:
        Array of shape (3, 3): rows are the 1-, 2- and 5-step errors, columns are the
        channels (velocity-x, velocity-y, pressure).
    """
    print("compiling val_loss function")
    solvers = (coarse_sim_1, coarse_sim_2)

    # Corrected rollout over five steps, alternating between the two solvers.
    state = test_data[:, 0]
    corrected = {}
    for step in range(1, 6):
        coarse = jax.vmap(solvers[(step - 1) % 2])(state)
        state = jax.vmap(model)(coarse) + coarse
        corrected[step] = state

    def normed_mse(prediction, target):
        # Squared spatial norm of the error relative to that of the target,
        # averaged over the samples -> one value per channel.
        error_norm = jnp.linalg.norm(prediction - target, axis=(-2, -1))
        target_norm = jnp.linalg.norm(target, axis=(-2, -1))
        return jnp.mean((error_norm**2 / target_norm**2), axis=0)

    reported_steps = (1, 2, 5)
    return jnp.vstack(tuple(normed_mse(corrected[k], test_data[:, k]) for k in reported_steps))

In [16]:
N_EPOCHS = 100
# optimizer = optax.adam(1e-3)
# optimizer = optax.adam(optax.exponential_decay(1e-3, 100, 0.85))
# Adam with a cosine-decayed learning rate: starts at 1e-3 and decays over N_EPOCHS*8
# optimizer steps; the last argument (0.1) is the schedule's `alpha`.
# NOTE(review): the factor 8 is presumably the number of minibatches per epoch
# (200 training samples / batch size 25) — keep in sync if either changes.
optimizer = optax.adam(optax.cosine_decay_schedule(1e-3, N_EPOCHS*8, 0.1))

@eqx.filter_jit
def update_fn(model, opt_state, data, coarse_sim_1=ns_sim_halfspace, coarse_sim_2=ns_sim_halfspace):
    """Perform one optimizer step on a minibatch.

    Args:
        model: the current model.
        opt_state: the current optimizer state.
        data: minibatch of trajectories, passed on to ``loss_fn``.
        coarse_sim_1: solver forwarded to ``loss_fn`` for the first step.
        coarse_sim_2: solver forwarded to ``loss_fn`` for the second step.

    Returns:
        Tuple ``(new_model, new_opt_state, loss)`` where ``loss`` is the value of
        ``loss_fn`` for ``model`` before the update.
    """
    print("compiling update_fn")
    loss, grad = eqx.filter_value_and_grad(loss_fn)(model, data, coarse_sim_1, coarse_sim_2)
    # Hand the optimizer only the array leaves of the model, matching the tree that
    # `optimizer.init` is given elsewhere in this notebook (non-array leaves such as
    # activation functions are not parameters).
    params = eqx.filter(model, eqx.is_array)
    updates, new_state = optimizer.update(grad, opt_state, params)
    new_model = eqx.apply_updates(model, updates)
    return new_model, new_state, loss

### Learning (test)

In [None]:
# --- model ---
model_resnet = pdeqx.arch.ClassicResNet(
    num_spatial_dims=2,
    in_channels=3,
    out_channels=3,
    hidden_channels=64,
    num_blocks=3,
    key=jax.random.PRNGKey(92),
)

# --- optimizer state (built from the array leaves of the model) ---
opt_state = optimizer.init(eqx.filter(model_resnet, eqx.is_array))

# --- metric histories; entry 0 belongs to the untrained model ---
loss_hist = [loss_fn(model_resnet, train_set)]
rel_error_hist = [val_loss(model_resnet, val_set)]

# --- training loop ---
BATCH_SIZE = 25
shuffle_key = jax.random.PRNGKey(42)

for epoch in range(N_EPOCHS):
    # fresh key per epoch for the dataloader
    shuffle_key, epoch_key = jax.random.split(shuffle_key)
    batch_losses = []
    for batch in dataloader(train_set, batch_size=BATCH_SIZE, key=epoch_key):
        model_resnet, opt_state, batch_loss = update_fn(model_resnet, opt_state, batch)
        batch_losses.append(batch_loss)

    # epoch summary: minibatch-averaged training loss and validation errors
    loss_hist.append(np.mean(np.array(batch_losses)))
    rel_error_hist.append(val_loss(model_resnet, val_set))
    print(f"Epoch {epoch+1}/{N_EPOCHS}: Loss = {loss_hist[-1]}, rel. error 1-step = {rel_error_hist[-1][0]}")

In [None]:
# Expect N_EPOCHS + 1 entries: the evaluation of the untrained model plus one per epoch.
len(rel_error_hist)

In [27]:
# Save model
# eqx.tree_serialise_leaves(f"results/saved_models/ns_spatial_correction_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.eqx", model_resnet)

In [None]:
# rel_error_hist: (epoch, step index [1-step, 2-step, 5-step], channel)
rel_error_hist = np.array(rel_error_hist)
fig, axs = plt.subplots(1, 4, figsize=(16, 4))

# Panel 0: training loss
loss_ax = axs[0]
loss_ax.set_title("Training loss (minibatch avg.)")
loss_ax.plot(loss_hist)
loss_ax.set_yscale("log")
loss_ax.set_ylim(1e-3, 1)

# Panels 1-3: validation nMSE of both velocity components, one panel per rollout length
panel_titles = (
    "1-step nMSE on val. set",
    "2-step nMSE on val. set",
    "5-step nMSE on val. set",
)
channel_labels = ("velocity-x", "velocity-y")
for step_idx, title in enumerate(panel_titles):
    panel = axs[step_idx + 1]
    panel.set_title(title)
    for channel_idx, channel_label in enumerate(channel_labels):
        panel.plot(rel_error_hist[:, step_idx, channel_idx], label=channel_label)
    panel.legend()
    panel.set_ylim(1e-3, 1)

for ax in axs:
    ax.set_yscale("log")
    ax.grid(True)

## Different constant inner iterations

In [24]:
# Seeds for the sweep below; each seed determines the model-initialisation and shuffle keys.
SEED_LIST = [1]
# SEED_LIST = [1, 2, 25, 50, 1000, 1337, 2668, 3999, 12345, 54321]

In [None]:
# Sweep: train one corrector per value of `maxiter_linsolve` of the coarse solver.
MAX_ITER_LIST    = [1,2,3,4,5,6,8,10,15,20,40,50,60, 80, 120]
RESTART = 8
SAVE_RESULTS = False
# NOTE: the optimizer's learning-rate schedule was built from the N_EPOCHS defined next to
# `optimizer`; changing the value here does not rebuild that schedule.
N_EPOCHS = 100
BATCH_SIZE = 25

for seed_count, seed in enumerate(SEED_LIST):
    key = jax.random.PRNGKey(seed)
    key, model_init_key = jax.random.split(key)

    # per-seed results, one entry per value in MAX_ITER_LIST
    losses_all_n = []
    errors_all_n = []
    # res_all_n = []

    for maxiter_count, max_iter in enumerate(MAX_ITER_LIST):

        print(f"\nTRAINING WITH: seed={seed} ({seed_count+1} of {len(SEED_LIST)}), max_iter={max_iter} ({maxiter_count+1} of {len(MAX_ITER_LIST)}), restart={RESTART}\n")

        # initialize the incompletely converged solver
        ns_sim_incomplete_halfspace = implicax.NavierStokes(1.0, NDOF_SOURCE, 0.1, nu=1e-4, maxiter_picard=1,
                                                            maxiter_linsolve=max_iter, restart=RESTART)

        # initialize model (same key for every max_iter -> identical initial weights)
        model_resnet = pdeqx.arch.ClassicResNet(
            2, 3, 3,
            hidden_channels=64,
            num_blocks=3,
            key=model_init_key,
        )

        # initialize optimizer
        opt_state = optimizer.init(eqx.filter(model_resnet, eqx.is_array))

        # initialize metrics (entry 0 belongs to the untrained model)
        loss_hist = [loss_fn(model_resnet, train_set, coarse_sim_2=ns_sim_incomplete_halfspace, coarse_sim_1=ns_sim_incomplete_halfspace)]
        # NOTE(review): validation uses val_loss' default solver (ns_sim_halfspace), not the
        # incomplete one used for training — presumably intentional; confirm.
        rel_error_hist = [val_loss(model_resnet, val_set)]
        # res_hist = [residuum_fn(model_resnet, train_set, ns_sim_incomplete_halfspace, ns_sim_incomplete_halfspace)]

        # training loop
        key, shuffle_key = jax.random.split(key)
        for epoch in range(N_EPOCHS):
            shuffle_key, subkey = jax.random.split(shuffle_key)
            loss_mini_batch = []
            for batch in dataloader(train_set, batch_size=BATCH_SIZE, key=subkey):
                model_resnet, opt_state, loss = update_fn(model_resnet, opt_state, batch,
                                                          coarse_sim_2=ns_sim_incomplete_halfspace,
                                                          coarse_sim_1=ns_sim_incomplete_halfspace)
                loss_mini_batch.append(loss)

            loss_hist.append(np.mean(np.array(loss_mini_batch)))
            rel_error_hist.append(val_loss(model_resnet, val_set))
            print(f"Epoch {epoch+1}/{N_EPOCHS}: Loss = {loss_hist[-1]}, 1-step rel. error = {rel_error_hist[-1][0]}")

        losses_all_n.append(loss_hist)
        errors_all_n.append(np.array(rel_error_hist))


    # save results (one pickle per seed)
    if SAVE_RESULTS:
        losses_all_n = np.array(losses_all_n)
        errors_all_n = np.array(errors_all_n) # shape (len(MAX_ITER_LIST), N_EPOCHS + 1, 3, 3)
        # res_all_n = np.array(res_all_n)
        df = pd.DataFrame({
            "max_restart": [RESTART] * len(MAX_ITER_LIST),
            "max_iter": MAX_ITER_LIST,
            "losses": list(losses_all_n),
            "1-step errors": list(errors_all_n[:,:,0]),
            "2-step errors": list(errors_all_n[:,:,1]),
            "5-step errors": list(errors_all_n[:,:,2]),
            "seed": seed
        })
        file_name = f"results/navier_stokes/maxiter_constant__seed_{seed}.pkl"
        # Create the output directory so a long training run is not lost at save time.
        Path(file_name).parent.mkdir(parents=True, exist_ok=True)
        df.to_pickle(file_name)
        

### Plot

In [None]:
fig, axs = plt.subplots(2, 3, figsize=(16, 8))

# axes position -> (index on the step axis, index on the channel axis, panel title)
error_panels = {
    (0, 1): (0, 0, "x-velocity 1-step nMSE"),
    (0, 2): (1, 0, "x-velocity 2-step nMSE"),
    (1, 1): (0, 1, "y-velocity 1-step nMSE"),
    (1, 2): (1, 1, "y-velocity 2-step nMSE"),
}

loss_ax = axs[0,0]
loss_ax.set_title("Training loss (minibatch avg.)")

# one curve per value of max_iter in every panel
for i, max_iter in enumerate(MAX_ITER_LIST):
    curve_label = f"{max_iter}"
    loss_ax.plot(losses_all_n[i], label=curve_label)
    for position, (step_idx, channel_idx, _) in error_panels.items():
        axs[position].plot(errors_all_n[i][:, step_idx, channel_idx], label=curve_label)
loss_ax.set_ylim(1e-5, 1)

for ax in axs.ravel():
    ax.set_yscale("log")
    ax.grid(True)

for position, (_, _, panel_title) in error_panels.items():
    axs[position].set_title(panel_title)
    axs[position].set_ylim(1e-3, 1)

# single legend, placed to the right of the top-right panel
axs[0,-1].legend(bbox_to_anchor=(1.05, 1), loc='upper left', title="inner iters")

## PRDP

In [15]:
# N_MIN: value n_inner (the solver's maxiter_linsolve) starts from in the PRDP run below.
# N_STEP: amount added to n_inner each time should_refine(...) returns True.
N_MIN, N_STEP = 1, 1

In [None]:
RESTART = 8  # GMRES
SAVE_RESULTS = False
BATCH_SIZE = 25

# SEED_LIST = [1, 2, 25, 50, 1000, 1337, 2668, 3999, 12345, 54321]
SEED_LIST = [1]


def _make_incomplete_solver(n_inner):
    """Build the coarse-grid NS solver with ``maxiter_linsolve=n_inner``."""
    return implicax.NavierStokes(
        1.0, NDOF_SOURCE, 0.1, nu=1e-4,
        maxiter_picard=1,
        maxiter_linsolve=n_inner,
        restart=RESTART,
    )


for seed_count, seed in enumerate(SEED_LIST):
    print(f"Training with seed={seed} ({seed_count+1} of {len(SEED_LIST)})")
    key = jax.random.PRNGKey(seed)

    # init model to be trained
    key, model_init_key = jax.random.split(key)
    model_resnet = pdeqx.arch.ClassicResNet(
        2, 3, 3,
        hidden_channels=64,
        num_blocks=3,
        key=model_init_key,
    )

    # init coarse-convergence solver; n_inner is increased during training
    n_inner = N_MIN
    ns_sim_incomplete_halfspace = _make_incomplete_solver(n_inner)

    # initialize optimizer
    opt_state = optimizer.init(eqx.filter(model_resnet, eqx.is_array))

    # initialize metrics (entry 0 belongs to the untrained model)
    loss_hist_prdp = [loss_fn(model_resnet, train_set, coarse_sim_2=ns_sim_incomplete_halfspace, coarse_sim_1=ns_sim_incomplete_halfspace)]
    error_hist_prdp = [val_loss(model_resnet, val_set)]
    n_inner_hist_prdp = [np.nan] # n_inner has no meaning at the zeroth epoch

    # initialize PRDP's Nmax checkpoint error
    should_refine.error_checkpoint = 100

    # training loop
    key, shuffle_key = jax.random.split(key)
    start_time = time.process_time()

    for epoch in range(N_EPOCHS):
        shuffle_key, subkey = jax.random.split(shuffle_key)
        loss_mini_batch = []
        for batch in dataloader(train_set, batch_size=BATCH_SIZE, key=subkey):
            model_resnet, opt_state, loss = update_fn(model_resnet, opt_state, batch,
                                                      coarse_sim_2=ns_sim_incomplete_halfspace,
                                                      coarse_sim_1=ns_sim_incomplete_halfspace)
            loss_mini_batch.append(loss)

        loss_hist_prdp.append(jnp.mean(jnp.array(loss_mini_batch)))
        error_hist_prdp.append(val_loss(model_resnet, val_set))
        print(f"Epoch {epoch+1}/{N_EPOCHS}, n_inner = {n_inner}, Loss = {loss_hist_prdp[-1]}, 1-step rel. error 'x' = {error_hist_prdp[-1][0][0]}")
        n_inner_hist_prdp.append(n_inner)

        # PRDP: decide from the history of the 5-step x-velocity validation error whether
        # the solver should be refined for the following epochs.
        if should_refine(
            np.array(error_hist_prdp)[:,2,0], # [all epochs][which-step][which-channel]
            stepping_threshold=0.95,
        ):
            n_inner += N_STEP
            ns_sim_incomplete_halfspace = _make_incomplete_solver(n_inner)

    # The start time was recorded but never evaluated before; report the CPU time of the
    # training loop (time.process_time, i.e. not wall-clock time).
    train_time_prdp = time.process_time() - start_time
    print(f"Training time (process_time) = {train_time_prdp:.2f} s")

    # SAVE
    loss_hist_prdp = np.array(loss_hist_prdp)
    error_hist_prdp = np.array(error_hist_prdp)

    if SAVE_RESULTS:
        df = pd.DataFrame({
            "losses": [loss_hist_prdp],
            "1-step errors": [error_hist_prdp[:,0]],
            "2-step errors": [error_hist_prdp[:,1]],
            "5-step errors": [error_hist_prdp[:,2]],
            "n_inner": [n_inner_hist_prdp],
            "max_restart": RESTART,
            "seed": seed,
            "max_iter": "PRDP",
            "auto_using": "fivesteperror"
        })
        file_name = f"results/ns_spatial_sep27_time/maxiter_auto_using_fivesteperror__seed_{seed}.pkl"
        # Create the output directory so a finished training run is not lost at save time.
        Path(file_name).parent.mkdir(parents=True, exist_ok=True)
        df.to_pickle(file_name)


In [None]:
# error_hist_prdp: (epoch, step index [1-step, 2-step, 5-step], channel)
error_hist_prdp = np.array(error_hist_prdp)
EMA_WINDOW = 6

fig, axs = plt.subplots(1, 3, figsize=(20, 5))

# Panel 0: training loss
axs[0].set_title("Training loss (minibatch avg.)")
axs[0].plot(loss_hist_prdp)
axs[0].set_yscale("log"); axs[0].grid()
axs[0].set_ylim(1e-3, 1)

# Panel 1: x-velocity validation errors plus the smoothed 5-step error
axs[1].set_title("X-velocity nMSE (val. set)")
axs[1].plot(error_hist_prdp[:,0,0], label="1-step nMSE | x-vel")
axs[1].plot(error_hist_prdp[:,1,0], label="2-step nMSE | x-vel", color="orange")
axs[1].plot(error_hist_prdp[:,2,0], label="5-step nMSE | x-vel", color="green")
# Explicit style so the smoothed curve cannot be confused with the raw curves.
axs[1].plot(numpy_ewma(error_hist_prdp[:,2,0], EMA_WINDOW), label="EMA", color="black", linestyle="--")
# Reference level: 5-step error of the untrained model.
axs[1].axhline(y = error_hist_prdp[0,2,0], color='grey', linestyle=':')
axs[1].set_yscale("log"); axs[1].grid()
# The legend is created after all labelled lines exist; previously it was created before
# the EMA line was plotted, so that line was missing from the legend.
axs[1].legend()

# Panel 2: history of the solver's maxiter_linsolve (n_inner) over the epochs
axs[2].set_title("GMRES maxiter")
axs[2].plot(n_inner_hist_prdp)
# Address the axes explicitly instead of relying on pyplot's current-axes state.
axs[2].grid()
