In [1]:
import jax.numpy as jnp
import jax
from jax import custom_jvp, jacfwd
import jax.numpy.linalg as LA
from jax import jvp
from functools import partial
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import matplotlib.cm as cm
import numpy as np

from src.poisson_problem_definition import *
from src.linear_solvers_scan import forward_solve_jacobi, forward_solve_SD, forward_solve, outer_optimisation, outer_optimisation_dynamic_inner
from src.poisson_utilities import plot_outer_history_2 #plot_contours, 
from src.prdp import should_refine
# plt.rcParams['figure.dpi'] = 200
# plt.rcParams['savefig.dpi'] = 200

# import plotly.io as pio
# pio.templates.default = "simple_white"

# Plotly figure config: options for the image-export ("toImageButtonOptions") entry.
plotly_config = dict(
    toImageButtonOptions=dict(
        format="png",  # one of png, svg, jpeg, webp
        filename="custom_image",
        height=500,
        width=800,
        scale=1.5,  # Multiply title/legend/axis/canvas sizes by this factor
    )
)

%load_ext autoreload
%autoreload 2

In [2]:
# SETTINGS FOR PLOTTING
# One plt.rc call per rc group. The former bare plt.rc("text") call set
# nothing, and the font size was set twice; both redundancies are removed.
plt.rc("font", family="sans-serif", size=9)
plt.rc("axes", labelsize=6, titlesize=8)
plt.rc("legend", fontsize=6, title_fontsize=7)  # legend entries / legend title
plt.rc("xtick", labelsize=6.5)  # a little smaller
plt.rc("ytick", labelsize=6.0)
plt.rc("lines", linewidth=1.0)

# Default figure size; later cells scale WIDTH / HEIGHT for individual figures.
WIDTH = 5.5
HEIGHT = 2.0
plt.rc("figure", figsize=(WIDTH, HEIGHT))

In [None]:
# Show the devices visible to JAX (used to check whether a GPU is in use).
print(jax.devices())

## Problem Definition

In [4]:
# Reference parameter value; the reference solution u_ref is computed at this theta.
THETA_REF = 2.0
# Right-hand-side function of the linear system (alternative kept commented below).
rhs_fn = rhs_sine_fn
# rhs_fn = rhs_step_fn

In [None]:
N_DOF = 30  # N_DOF from slider

# N_DOF + 2 equispaced points on [0, 1] with the two end points dropped.
x = jnp.linspace(0, 1, N_DOF+2)[1:-1]
print(f"x shape = {x.shape}")

A = A_matrix(N_DOF)

# Reference solution at THETA_REF (equivalently: LA.solve(A, rhs_fn(THETA_REF, x))).
u_ref = u_direct(THETA_REF, x, A, rhs_fn)
assert u_ref.shape == (N_DOF,)  # sanity check
# plt.plot(u_ref); plt.title(f"Exact solution, $\\theta$_ref = {THETA_REF}"); plt.grid();

### Custom JVP = linsolve(A, $\dot{b}$)

In [6]:
SOLVER = "jacobi" # "jacobi" or "SD"

# Map the solver name to the iterative linear solver used by u_solve.
# An unknown name raises immediately; without the else branch LINSOLVE would
# stay undefined (or keep a stale value from an earlier run of this cell).
if SOLVER == "jacobi":
    LINSOLVE = forward_solve_jacobi
elif SOLVER == "SD":
    LINSOLVE = forward_solve_SD
else:
    raise ValueError(f'Unknown SOLVER {SOLVER!r}; expected "jacobi" or "SD"')

@partial(custom_jvp, nondiff_argnums=(1, 2))
def u_solve(b, n_inner, u_init):
    """
    Run the selected solver LINSOLVE on the system with matrix A and right-hand side b.

    Only b is differentiable: n_inner and u_init are declared
    non-differentiable (nondiff_argnums), and the derivative rule is the one
    registered through u_solve.defjvp.
    """
    solver_output = LINSOLVE(A, b, n_inner, u_init)
    return solver_output

@u_solve.defjvp
def u_solve_jvp(n_inner, u_init, primals, tangents):
    """
    JVP rule for u_solve: the tangent output is obtained by running the same
    linear solver on the tangent of the right-hand side.

    - n_inner, u_init = the non-differentiable arguments of u_solve
    - primals = 1-tuple holding the differentiable input of u_solve (b)
    - tangents = 1-tuple holding the tangent of b w.r.t. the parameter for
      which the derivative was requested (theta)

    Returns (primal_out, tangent_out). Both solves use the same iteration
    count n_inner and the same initial guess u_init.
    """
    (rhs,) = primals
    (rhs_dot,) = tangents
    # Primal: solver output for <A, b>; tangent: solver output for <A, b_dot>.
    return (
        u_solve(rhs, n_inner, u_init),
        LINSOLVE(A, rhs_dot, n_inner, u_init),
    )

# Inner problem

In [7]:
# Number of inner solver iterations used for the Jacobian / primal convergence study below.
N_INNER = 1000   # max n_inner from slider

#### u_init = zeros

In [8]:
# Initial guess handed to the iterative solves of this section.
U_INIT = jnp.zeros(N_DOF)

# Reference: Jacobian of the direct solve with respect to theta.
reference_jacobian = jacfwd(u_direct)(THETA_REF, x, A, rhs_fn)
reference_norm = LA.norm(reference_jacobian)

# unrolled jacobian
unrolled_jacobian = jacfwd(forward_solve)(
    THETA_REF, A, rhs_fn, x, N_INNER, SOLVER, U_INIT, False
)

# implicit jacobian (differentiates through the custom JVP of u_solve)
def func(theta):
    """Solver output of u_solve for the right-hand side at theta."""
    return u_solve(rhs_fn(theta, x), N_INNER, U_INIT)  # (N_INNER+1, n_dof)

implicit_jacobian = jacfwd(func)(THETA_REF)

# jacobian suboptimalities: per-iteration error norm relative to the reference norm
implicit_jac_rel_subopt_zerosinit = LA.norm(implicit_jacobian - reference_jacobian, axis=1) / reference_norm
unrolled_jac_rel_subopt_zerosinit = LA.norm(unrolled_jacobian - reference_jacobian, axis=1) / reference_norm

# primal suboptimality
primal = func(THETA_REF)  # (N_INNER+1, n_dof)
u_rel_subopt_zerosinit = LA.norm(primal - u_ref, axis=1) / LA.norm(u_ref)

### Plot

In [None]:
print(f"theta_ref = {THETA_REF}, SOLVER = {SOLVER}")
lastiter = N_INNER
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(WIDTH, HEIGHT*1.5))

# Left panel: Jacobian suboptimality, unrolled vs. implicit differentiation.
# Only the zeros-initialised curves are drawn; ones/rand-initialised variants
# are not computed in this notebook.
ax1.plot(unrolled_jac_rel_subopt_zerosinit[:lastiter], "-",  color="tab:blue",   label="unrolled, $\\text{primal}_\\text{init} = \\text{zeros}$")
ax1.plot(implicit_jac_rel_subopt_zerosinit[:lastiter], "--", color="tab:orange", label="implicit, $\\text{tangent}_\\text{init} = \\text{zeros}$")

# Right panel: suboptimality of the primal iterates.
ax2.plot(u_rel_subopt_zerosinit[:lastiter],  label="$\\text{primal}_\\text{init} = \\text{zeros}$")

# Shared formatting: labels, legend, log-scaled y axis and grids.
for ax, ylabel in (
    (ax1, "Jacobian relative suboptimality"),
    (ax2, "Iterate relative suboptimality"),
):
    ax.set_ylabel(ylabel)
    ax.set_xlabel("# iterations")
    ax.legend(fontsize=8)
    ax.set_yscale("log")
    ax.grid(which='major', axis='both')
    ax.grid(which='minor', axis='x', linestyle=':', linewidth='0.5')
    ax.minorticks_on()

fig.suptitle(f"Poisson 1 param, solver={SOLVER}")
fig.tight_layout()
# fig.savefig(f"figures/poisson_1_param__{SOLVER}_suboptimalities.pdf", bbox_inches='tight')

In [None]:
## SANITY CHECK: IS U_SOLVE WORKING CORRECTLY
# Overlay the last iterate returned by u_solve on the direct solve of A u = b.
plt.rcParams['figure.dpi'] = 100
plt.plot(LA.solve(A, rhs_fn(THETA_REF, x)), label="u_ref")
# N_INNER iterations from a zero initial guess; [-1] picks the final iterate.
plt.plot(u_solve(rhs_fn(THETA_REF,x), N_INNER, jnp.zeros(N_DOF))[-1], 'o', label="$u_{init} = \\text{zeros}$")
plt.title("Sanity check: is u_solve working correctly")
plt.legend()  # both curves carry labels; without this call they were never displayed
plt.grid()

### Residuum plot

In [None]:
def primal_rel_residuum(primal):
    """
    Relative residual ||A u - b|| / ||b|| of one iterate u at THETA_REF.

    b is rhs_fn(THETA_REF, x), i.e. the right-hand side selected for the whole
    notebook (previously hard-coded to rhs_sine_fn). The denominator is the
    residual of the all-zeros iterate, which reduces to ||b||.

    primal: a single iterate; the function is mapped over the solver output
    with jax.vmap below.
    """
    rhs = rhs_fn(THETA_REF, x)  # evaluated once, shared by numerator and denominator
    res = LA.norm(A @ primal.T - rhs)
    res_init = LA.norm(rhs)
    return res / res_init

N_RESIDUUM_ITERS = 800  # inner iterations shown in the residuum plot
U_INIT = jnp.zeros(N_DOF)
primal_hist = forward_solve(THETA_REF, A, rhs_fn, x, N_RESIDUUM_ITERS, SOLVER, U_INIT)
residuum_hist = jax.vmap(primal_rel_residuum)(primal_hist)


plt.figure(figsize=(WIDTH*0.3, HEIGHT))
plt.plot(residuum_hist)#, "-o", markevery=int(N_INNER/5))
plt.yscale("log")
plt.grid(which='major', axis='both')
plt.grid(which='minor', axis='x', linestyle=':', linewidth='0.5')
plt.minorticks_on()
plt.xlabel("# iterations")
plt.ylabel("Primal relative residuum")
plt.title(f"Poisson 1 params | solver = {SOLVER}\nResiduum @ $\\theta$ = {THETA_REF}")
# plt.savefig(f"figures/poisson_1_param__{SOLVER}_primal_residuum.pdf", bbox_inches='tight')

# Outer Problem

In [7]:
from src.linear_solvers_scan import squareloss_fn
# Settings of the outer (parameter-fitting) problem.
THETA_INIT = 5.0        # from slider; starting value of theta for the outer optimisation
N_OUTER = 180           # from slider; number of outer iterations
LR = 275                # from slider; outer learning rate
N_INNER_MAX = 600       # largest number of inner solver iterations considered
# Initial guess for the inner linear solves (alternatives kept in the trailing comment).
U_INIT = jnp.zeros(N_DOF) # jnp.ones(n_dof) # jax.random.normal(jax.random.PRNGKey(0), shape=x.shape)

### Loss function 

In [None]:
DIFF_METHOD = "implicit"  # "unrolled" or "implicit"

# Build loss_fn(theta): squareloss_fn of the solver output for N_INNER_MAX
# inner iterations against u_ref. "unrolled" differentiates through
# forward_solve, "implicit" goes through u_solve and its custom JVP rule.
# An unknown DIFF_METHOD raises instead of silently leaving loss_fn undefined
# (or stale from an earlier run of this cell).
if DIFF_METHOD == "unrolled":
    loss_fn = lambda theta: squareloss_fn(
        forward_solve(theta, A, rhs_fn, x, N_INNER_MAX, SOLVER, U_INIT, False),
        u_ref 
    )
elif DIFF_METHOD == "implicit":
    @jax.jit
    def loss_fn(theta):
        """Loss of the u_solve output at theta against u_ref."""
        rhs_eval = rhs_fn(theta, x)
        u_history = u_solve(rhs_eval, N_INNER_MAX, U_INIT)
        return squareloss_fn(u_history, u_ref)
else:
    raise ValueError(
        f'Unknown DIFF_METHOD {DIFF_METHOD!r}; expected "unrolled" or "implicit"'
    )

# test
# NOTE(review): the message says level 100 but the entry read is index 101 — confirm which is intended.
print(f"loss at theta_init at refinement level 100 = {loss_fn(THETA_INIT)[101]}")

### Different constant n_inner

In [9]:
# Outer optimisation using the loss for N_INNER_MAX inner iterations.
theta_hist, loss_hist = outer_optimisation(
    loss_fn,
    init=THETA_INIT,
    outer_lr=LR,
    n_outer_iterations=N_OUTER,
    n_inner_iterations=N_INNER_MAX,
)

# Both histories hold one row per outer step (plus the initial state) and one
# column per inner refinement level.
expected_shape = (N_OUTER+1, N_INNER_MAX+1)
assert theta_hist.shape == expected_shape
assert loss_hist.shape == expected_shape

abs_theta_error = jnp.abs(theta_hist - THETA_REF)
theta_rel_error = abs_theta_error / jnp.abs(THETA_REF)

#### Line plot

In [None]:
# Plot loss and parameter relative error of the constant-n_inner runs.
# The return value is discarded by assigning it to `_`.
_ = plot_outer_history_2(
    loss_hist,
    theta_rel_error,
    "Loss",
    "Parameter relative error",
    log_flag = True,
    iterations_step = 50,
    suptitle=f"Poisson 1 param, Solver: {SOLVER}, Diff method: {DIFF_METHOD}",
    is_prdp_included=False
)

### Manually Scheduled n_inner

In [None]:
# Bounds of the inner-iteration schedule.
N_MIN = 25
N_STEP = None
N_MAX = N_INNER_MAX

hist_dynamic = outer_optimisation_dynamic_inner(
    loss_fn,
    n_outer_iterations=N_OUTER,
    lr=LR,
    theta_init=THETA_INIT,
    n_inner_init=N_MIN,
    n_inner_max=N_MAX,
    n_step=N_STEP,  # change inside the function itself
    min_theta_change=0,
)

# The result unpacks into the theta, loss and n_inner histories.
(theta_hist_dynamic, loss_hist_dynamic, n_inner_hist_dynamic) = hist_dynamic

In [None]:
# Clip the scheduled inner-iteration counts to [0, N_MAX]; reused by later cells.
n_inner_hist_dynamic_clipped = jnp.clip(n_inner_hist_dynamic, 0, N_MAX)

abs_theta_error_dynamic = jnp.abs(theta_hist_dynamic - THETA_REF)
theta_rel_error_dynamic = abs_theta_error_dynamic / jnp.abs(THETA_REF)

# Append the scheduled run as an extra last column next to the constant-n_inner runs.
dynamic_error_column = theta_rel_error_dynamic.reshape(-1,1)
dynamic_loss_column = loss_hist_dynamic.reshape(-1,1)
theta_rel_error_all = jnp.hstack([theta_rel_error, dynamic_error_column])
loss_hist_all       = jnp.hstack([loss_hist, dynamic_loss_column])

outer_plot_title = f"Poisson 1 param, Solver: {SOLVER}, Diff = {DIFF_METHOD}"
fig_outer = plot_outer_history_2(
    loss_hist_all,
    theta_rel_error_all,
    "Loss",
    "Parameter relative error",
    log_flag=True,
    iterations_step=50,
    suptitle=outer_plot_title,
)

In [21]:
# fig_outer.savefig(f"figures/poisson_1_param_{SOLVER}_unrolled__training_manual_PRDP.pdf", bbox_inches='tight')

### Plot for PR savings section in paper

In [None]:
# create color map
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was deprecated in
# Matplotlib 3.7 and removed in 3.9.
norm = Normalize(vmin=1, vmax=N_INNER_MAX)
cmap = plt.get_cmap("plasma_r")
PRDP_COLOR = "tab:green"

fig_pr_savings, axs = plt.subplots(1,2, figsize=(WIDTH*0.5, HEIGHT*0.9))

# plot errors for all n_inner (every 50th refinement level, coloured by level)
for i in np.arange(0, N_INNER_MAX+1, 50):
    color = cmap(norm(i+1))
    axs[0].plot(theta_rel_error_all[:,i], linewidth=1.0, color=color)
# plot errors for scheduled n_inner (stored as the last column)
axs[0].plot(theta_rel_error_all[:, -1], linewidth=1.0, color=PRDP_COLOR)
axs[0].set_yscale("log")
axs[0].set_xlabel("Outer iteration")
axs[0].set_ylabel("Parameter relative error")
axs[0].grid()
axs[0].set_xlim(0, None)

# plot n_inner vs outer iteration
# Hand-picked values: inner iterations of the fully converged run and the
# number of outer iterations shown for each of the two runs.
K_epsilon = 600
n_outer_for_k_epsilon = 125
converged_n_inner = np.full((n_outer_for_k_epsilon,), N_INNER_MAX)

n_outer_for_prdp = 145
prdp_n_inner = n_inner_hist_dynamic_clipped[:n_outer_for_prdp]

axs[1].plot(converged_n_inner, label="Fully converged", color=cmap(norm(K_epsilon)))
axs[1].plot(prdp_n_inner, label="Scheduled", color=PRDP_COLOR)
# Shaded area between the two curves = inner iterations saved by the schedule.
axs[1].fill_between(np.arange(n_outer_for_k_epsilon), prdp_n_inner[:n_outer_for_k_epsilon], converged_n_inner[:n_outer_for_k_epsilon], color=PRDP_COLOR, alpha=0.2, label="PR savings")
axs[1].set_xlabel("Outer iteration")
axs[1].set_ylabel("# physics solver iterations")
axs[1].grid()
axs[1].set_xlim(0, None)

# Add colorbar and legend
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
cbar = fig_pr_savings.colorbar(
    sm,
    ax=axs,
    orientation="horizontal",
    location='top',
    shrink=0.4,
    anchor=(0, 2)
)
cbar.set_label("# physics solver iterations")
cbar.set_ticks(ticks=[0, 300, 600])
fig_pr_savings.legend(loc='lower left', bbox_to_anchor=(0.65, 0.73), frameon=False)

fig_pr_savings.set_layout_engine("compressed")
# fig_pr_savings.savefig(f"result_plots/pr_savings_poisson_1_param_{SOLVER}_unrolled.pdf", bbox_inches='tight', pad_inches=0.01)

In [None]:
# Plot n_inner vs. n_outer
N_CONVERGENCE, N_OUTER_TOL = 600, 125
n_outer_scheduled = 145  # outer iterations shown / summed for the scheduled run

n_convergence_array = jnp.full((N_OUTER_TOL,), N_CONVERGENCE)
n_inner_hist_dynamic_tol = n_inner_hist_dynamic_clipped[:n_outer_scheduled]

# calculate savings: total inner iterations of each strategy
converged_cost = jnp.sum(n_convergence_array)
dynamic_cost = jnp.sum(n_inner_hist_dynamic_tol)
print(f"converged_cost = {converged_cost}, dynamic_cost = {dynamic_cost}")
print(f"savings = {1 - dynamic_cost / converged_cost}")

plt.figure(figsize=(4, 3))
plt.plot(n_inner_hist_dynamic_tol, label="Scheduled (progressively refined)", color="green")
# plt.axhline(N_CONVERGENCE, color="red", linestyle="--", label="Fully converged")
plt.plot(n_convergence_array, color="red", linestyle="--", label="Fully converged")

# Shade the gap below the converged level wherever the schedule uses fewer iterations.
below_converged = (n_inner_hist_dynamic_tol < N_CONVERGENCE)
plt.fill_between(
    range(len(n_inner_hist_dynamic_tol)),
    n_inner_hist_dynamic_tol,
    N_CONVERGENCE,
    where=below_converged,
    color='green',
    alpha=0.2,
    label = "Progressively refinement (PR) savings",
)
plt.xlabel("# outer iterations")
plt.ylabel("# inner iterations")
plt.grid()
plt.legend()
plt.xlim(0, n_outer_scheduled)
# plt.savefig(f"figures/poisson_1_param_{SOLVER}_unrolled__cost_manual_PRDP.pdf", bbox_inches='tight')

### PRDP

In [None]:
# Build loss_fn_prdp(theta, n_inner): loss of the LAST solver iterate for
# n_inner inner iterations against u_ref. n_inner is a static jit argument.
# An unknown DIFF_METHOD raises instead of silently leaving loss_fn_prdp
# undefined (or stale from an earlier run of this cell).
if DIFF_METHOD == "unrolled":
    @partial(jax.jit, static_argnames=['n_inner'])
    def loss_fn_prdp(theta, n_inner):
        """Loss of the final forward_solve iterate (differentiated by unrolling)."""
        u = forward_solve(theta, A, rhs_fn, x, n_inner, SOLVER, U_INIT, False)[-1]
        return squareloss_fn(u, u_ref)
elif DIFF_METHOD == "implicit":
    @partial(jax.jit, static_argnames=['n_inner'])
    def loss_fn_prdp(theta, n_inner):
        """Loss of the final u_solve iterate (differentiated via its custom JVP)."""
        rhs_eval = rhs_fn(theta, x)
        u = u_solve(rhs_eval, n_inner, U_INIT)[-1]
        return squareloss_fn(u, u_ref)
else:
    raise ValueError(
        f'Unknown DIFF_METHOD {DIFF_METHOD!r}; expected "unrolled" or "implicit"'
    )

# test
# NOTE(review): the message says level 100 but n_inner passed is 101 — confirm which is intended.
print(f"loss at theta_init at refinement level 100 = {loss_fn_prdp(THETA_INIT, 101)}")

def update_fn(theta, n_inner, lr):
    """
    One gradient-descent step on loss_fn_prdp with n_inner inner iterations.

    Returns (loss, theta_new), where the loss is evaluated at the updated theta.
    """
    gradient = jacfwd(loss_fn_prdp)(theta, n_inner)
    theta_new = theta - lr * gradient
    return loss_fn_prdp(theta_new, n_inner), theta_new

def theta_error_fn(theta):
    """Relative error of theta with respect to THETA_REF: |theta - THETA_REF| / |THETA_REF|."""
    abs_error = jnp.abs(theta - THETA_REF)
    return abs_error / jnp.abs(THETA_REF)

In [None]:
# PRDP params
# N_INNER_INIT: inner iterations at the start; N_INNER_STEP: increment applied
# each time should_refine returns a truthy value.
N_INNER_INIT, N_INNER_STEP = 25, 10
# TAU_STEP is passed to should_refine as stepping_threshold.
# NOTE(review): TAU_STOP is not used anywhere in this cell — confirm whether it is needed.
TAU_STEP, TAU_STOP = 0.92, 0.98
# Passed to should_refine as error_window.
ERROR_WINDOW = 2

# initialize model to be trained
theta = THETA_INIT

# initialize physics refinement
n_inner = N_INNER_INIT

# initialize metrics
# Each history starts with the value at the initial theta / n_inner, so the
# lists end up with N_OUTER + 1 entries.
loss_hist_prdp = [loss_fn_prdp(theta, n_inner)]
theta_error_hist_prdp = [theta_error_fn(theta)]
n_inner_hist_prdp = [N_INNER_INIT] 

# initialize PRDP's checkpoint error
# Sets an attribute on the imported should_refine function object; presumably
# should_refine reads it as state between calls — verify against src.prdp.
should_refine.error_checkpoint = 100

# training loop
for epoch in range(N_OUTER):
    # update model
    # update_fn returns the loss at the updated theta together with the new theta.
    loss, theta = update_fn(theta, int(n_inner), LR)
    loss_hist_prdp.append(loss)
    theta_error_hist_prdp.append(theta_error_fn(theta))
    n_inner_hist_prdp.append(int(n_inner))
    print(f"Epoch {epoch+1}/{N_OUTER}, n_inner  {int(n_inner)}, loss = {loss}, theta rel error = {theta_error_hist_prdp[-1]}")

    # PRDP
    # The whole error history so far is handed over; on a truthy result the
    # refinement level grows by N_INNER_STEP, capped at N_INNER_MAX.
    if should_refine(np.array(theta_error_hist_prdp), 
                     stepping_threshold=TAU_STEP,
                     error_window=ERROR_WINDOW,
    ):
        n_inner = min(n_inner + N_INNER_STEP, N_INNER_MAX)

# Convert the histories to arrays; n_inner_hist_prdp stays a Python list.
theta_error_hist_prdp = np.array(theta_error_hist_prdp)
loss_hist_prdp = np.array(loss_hist_prdp)

In [None]:
# Append the PRDP run as an extra last column next to the constant-n_inner runs.
# Note: this rebinds theta_rel_error_all, which the scheduled-n_inner cells also define.
theta_rel_error_all = np.hstack([theta_rel_error, theta_error_hist_prdp.reshape(-1,1)])
loss_history_all = np.hstack([loss_hist, loss_hist_prdp.reshape(-1,1)])

fig_outer_hist = plot_outer_history_2(
    loss_history_all, 
    theta_rel_error_all, 
    "loss", 
    "parameter relative error", 
    log_flag=True, 
    iterations_step=50,
    # Title fixed: this notebook fits a single parameter (it previously read
    # "Poisson 3 parameters").
    suptitle=f"Poisson 1 param, Solver: {SOLVER}, Diff: {DIFF_METHOD}", 
)


#### PRDP plot for paper

In [None]:

# create color map
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was deprecated in
# Matplotlib 3.7 and removed in 3.9.
norm = Normalize(vmin=1, vmax=N_INNER_MAX)
cmap = plt.get_cmap("plasma_r")
PRDP_COLOR = "tab:green"

fig_pr_savings, axs = plt.subplots(1,2, figsize=(WIDTH*0.5, HEIGHT*0.9))

# plot errors for all constant n_inner (every 50th refinement level, coloured by level)
for i in np.arange(0, N_INNER_MAX+1, 50):
    color = cmap(norm(i+1))
    axs[0].plot(theta_rel_error_all[:,i], linewidth=1.0, color=color)
# plot errors for prdp n_inner (stored as the last column)
axs[0].plot(theta_rel_error_all[:, -1], linewidth=1.0, color=PRDP_COLOR)
axs[0].set_yscale("log")
axs[0].set_xlabel("Outer iteration")
axs[0].set_ylabel("Parameter relative error")
axs[0].grid()
axs[0].set_xlim(0, None)

# plot n_inner vs outer iteration
# Hand-picked values: inner iterations of the fully converged run and the
# number of outer iterations counted for each of the two runs.
K_epsilon = 600
n_outer_for_k_epsilon = 125
converged_n_inner = np.full((n_outer_for_k_epsilon,), N_INNER_MAX)
cost_converged = K_epsilon * n_outer_for_k_epsilon
print(f"cost_converged = {cost_converged}")

n_outer_for_prdp = 170
prdp_n_inner = n_inner_hist_prdp[:n_outer_for_prdp]
cost_prdp = np.sum(prdp_n_inner)
print(f"cost_prdp = {cost_prdp}")
print(f"savings = {1 - cost_prdp / cost_converged}")

axs[1].plot(converged_n_inner, label="Fully converged", color=cmap(norm(K_epsilon)))
axs[1].plot(prdp_n_inner, label="PRDP", color=PRDP_COLOR)
# Shaded area between the two curves = inner iterations saved by PRDP.
axs[1].fill_between(np.arange(n_outer_for_k_epsilon), prdp_n_inner[:n_outer_for_k_epsilon], converged_n_inner[:n_outer_for_k_epsilon], color=PRDP_COLOR, alpha=0.2, label="PR savings")
axs[1].set_xlabel("Outer iteration")
axs[1].set_ylabel("# physics solver iterations")
axs[1].grid()
axs[1].set_xlim(0, None)
axs[1].set_xticks([0, n_outer_for_k_epsilon, n_outer_for_prdp])
axs[0].set_xticks([0, n_outer_for_k_epsilon, n_outer_for_prdp])

# Add colorbar and legend
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
cbar = fig_pr_savings.colorbar(
    sm,
    ax=axs,
    orientation="horizontal",
    location='top',
    shrink=0.4,
    anchor=(0, 2)
)
cbar.set_label("# physics solver iterations")
cbar.set_ticks(ticks=[0, 300, 600])
fig_pr_savings.legend(loc='lower left', bbox_to_anchor=(0.65, 0.73), frameon=False)

fig_pr_savings.set_layout_engine("compressed")
# fig_pr_savings.savefig(f"result_plots/poisson_1_param_{SOLVER}_{DIFF_METHOD}.pdf", bbox_inches='tight', pad_inches=0.01)