Concept for Dusty's test of identifying whether layers are in steady state:

1. Run with some basal slip and make sure the layers are approximately in steady state.
2. Keep the layers (at t0) unchanged, but either make the basal slip start earlier or increase its magnitude (or both).

Little pieces:
- [ ] Plot mismatch between flow direction and local layer slope
- [ ] Create way to save layers
- [X] Refactor lambdify_and_vectorize_if_needed out and use it here too
- [X] Clean path to turning off noise

In [None]:
# IPython magics: automatically reload edited modules (such as the flowline_ode
# package imported in the next cell) before each cell runs, so library changes
# take effect without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scipy
import scipy.constants
import scipy.interpolate
import sympy
from sympy import *
import pickle
import datetime
from tqdm import tqdm
import time

from flowline_ode.plots_setup import *
from flowline_ode.sia_model import *
from flowline_ode.noise import *
from flowline_ode.finite_differences import *

In [None]:
# Wall-clock time (seconds since the epoch) captured when the notebook starts running
t_start_notebook = time.time()

### Problem setup

In [None]:
import os

n = 2 # Flow exponent for the actual model

# Prefix for every file this notebook writes: outputs/<YYYYmmdd_HHMMSS>_n<n>
output_results_base = f"outputs/{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}_n{n:.1f}"
# Figures and the results pickle are written under this prefix; savefig/open raise
# FileNotFoundError if the directory is missing, so create it up front.
os.makedirs(os.path.dirname(output_results_base), exist_ok=True)
print(f"{output_results_base}")

In [None]:
# --- Domain extent (meters) ---
domain_x, domain_z = 100000, 3000

# Regular grids, used wherever the model has to be discretized
dx, dz = 100, 25
xs, zs = np.arange(0, domain_x, dx), np.arange(0, domain_z, dz)

# Symbolic coordinates for the sympy model
x, z = sympy.symbols('x z', real=True, positive=True)

# Ice-surface elevation as a symbolic expression of x.
# Geometries tried previously:
#   domain_z - ((x / 18000.0)**3.0)
#   domain_z - (((x+10e3) / 18000.0)**3.0)
surface_sym = domain_z - (((x+5e3) / 22000.0)**3.0)
surface = lambdify_and_vectorize_if_needed(x, surface_sym)

# Surface slope ds/dx: differentiated symbolically, then made numerically callable
ds_dx_sym = sympy.diff(surface_sym, x)
ds_dx_lambdify = lambdify_and_vectorize_if_needed(x, ds_dx_sym)
def ds_dx(x):
    """Evaluate the surface slope ds/dx at ``x``.

    If the lambdified derivative returns a plain scalar, it is broadcast
    against ``x`` (by adding ``0 * x``) so the result matches the input's shape.
    """
    slope = ds_dx_lambdify(x)
    if not np.isscalar(slope):
        return slope
    return slope + (0 * x)

# Visual sanity check: surface elevation and its slope along the flowline
fig, _ = plot_surface(xs, surface, ds_dx)  # second return value is not used here

fig.savefig(f"{output_results_base}_surface_and_slope.png")
plt.show()

### Generate SIA-based velocity field

This is a basic shallow-ice-approximation (SIA) velocity field, closely following the course notes at https://ocw.hokudai.ac.jp/wp-content/uploads/2016/02/DynamicsOfIce-2005-Note-all.pdf

The inputs (surface(x) and ds/dx(x)) are passed in as symbolic expressions, and the outputs are returned as symbolic expressions as well.


In [None]:
# Basal sliding speed [m/s]: a logistic ramp in x up to 20 m/yr.
# Alternatives tried previously:
#   (20/year) / (1 + exp(-0.0005*(x-75000)))   -> layer dip example
#   (20/year) / (1 + exp(-0.0001*(x-75000)))   -> gradual transition to basal sliding
#   0                                          -> no sliding
basal_velocity = ((20/scipy.constants.year) / (1 + sympy.exp(-0.0006*(x-80000)))) # Layer dip example -- modified

# Reference model without basal sliding (sometimes used for layers that
# represent an earlier steady state)
u_zerobasal, w_zerobasal, du_dx_zerobasal = sia_model(x, z, surface_sym, ds_dx_sym, n=n)
# Model with basal sliding: the velocity field used from here on
u, w, du_dx = sia_model(x, z, surface_sym, ds_dx_sym, basal_velocity_sym=basal_velocity, n=n)

# Horizontal and vertical velocity fields
fig, _ = plot_velocity(xs, zs, x, z, u, w, domain_z=domain_z, surface=surface)
fig.savefig(f"{output_results_base}_velocity_fields.png")
plt.show()

In [None]:
# Verification: surface and bed horizontal velocity along the flowline
fig, _ = plot_surface_bed_velocity(xs, x, z, u, surface_sym)

fig.savefig(f"{output_results_base}_surface_bed_velocity.png")
plt.show()


### Generate synthetic layers

In [None]:
# Coarse x grid for layer tracing (the layers are smooth, so 200 points suffice)
xs_layers = np.linspace(0, domain_x, 200)
xs_layers_initial = xs_layers

# Layers are generated by advecting from just below the surface until, at
# x = depth_target_x, each one reaches its target elevation.
# (An older approach instead advected layers started at fixed offsets
#  100 + 150*idx m below the surface for log-spaced ages of 1 to 1e4 years.)
depth_target_x = 70e3  # meters
depth_targets = np.linspace(surface(depth_target_x) - 100, 500, 16)
print(f"depth_targets = {depth_targets}")


def first_layer_fn(x):
    """Starting layer geometry: 10 m below the ice surface."""
    return surface(x) - 10


# Numerical velocity fields of the no-sliding reference model
u_zerobasal_fn = lambdify_and_vectorize_if_needed((x, z), u_zerobasal)
w_zerobasal_fn = lambdify_and_vectorize_if_needed((x, z), w_zerobasal)

def simulate_layers_with_target_spacing(u_fn, w_fn, depth_targets, depth_target_x,
                                        xs_layers, first_layer_fn,
                                        initial_time_step = 1*scipy.constants.year,
                                        max_iters_per_layer = 50):
    """Build a stack of layers by repeatedly advecting one layer down to target elevations.

    Starting from ``first_layer_fn``, the current layer is advected through the
    velocity field (``u_fn``, ``w_fn``) with a growing time step until its
    elevation at ``depth_target_x`` is at or below ``depth_targets[idx]``. That
    layer is recorded, and advection continues from it toward the next target.

    Parameters
    ----------
    u_fn, w_fn : callable
        Horizontal / vertical velocity as functions of (x, z); passed to ``advect_layer``.
    depth_targets : array-like
        Target layer elevations [m] at ``depth_target_x``, in production order.
    depth_target_x : float
        x position [m] where layer elevations are compared with the targets.
    xs_layers : array-like
        x grid on which layers are advected.
    first_layer_fn : callable
        Starting layer geometry z(x).
    initial_time_step : float
        First advection time step [s]. The step grows by 20% after every
        advection and is divided by 3 after each recorded layer.
    max_iters_per_layer : int
        Maximum number of advection steps per layer. If a target is not reached
        within this many steps, a message is printed and the layer is recorded
        anyway.

    Returns
    -------
    layers_t0 : list of callable
        One recorded layer per target.
    layer_ages_years : numpy.ndarray
        Cumulative advection time [years] at which each layer was recorded.
    """
    time_step = initial_time_step
    layer_ages_years = np.nan * np.ones_like(depth_targets)
    age_years = 0
    layers_t0 = []
    layer = first_layer_fn  # each layer continues from the previously recorded one
    for idx, depth in enumerate(depth_targets):
        # inf sentinel: always take at least one step (no dependence on the domain height)
        layer_z_at_start = np.inf
        layer_iters = 0
        while layer_z_at_start > depth:
            # Checked before stepping so at most max_iters_per_layer advections happen
            if layer_iters >= max_iters_per_layer:
                print(f"[Layer {idx}] Max iters reached, breaking")
                break

            l_res = advect_layer(u_fn, w_fn, xs_layers, layer, [time_step])
            age_years += time_step/scipy.constants.year
            layer = l_res[-1]
            layer_z_at_start = layer(depth_target_x)
            layer_iters += 1
            if layer_iters < 3 or layer_iters % 20 == 0:
                print(f"[Layer {idx}] layer_z_at_start = {layer_z_at_start}, time_step = {time_step/scipy.constants.year}, layer_iters = {layer_iters}")

            time_step *= 1.2
        # Restart the next layer with a smaller step than the one reached here
        time_step /= 3
        layer_ages_years[idx] = age_years

        layers_t0.append(layer)

    return layers_t0, layer_ages_years

# Layers simulated without the basal slip
layers_t0_noslip, layer_ages_years = simulate_layers_with_target_spacing(u_zerobasal_fn, w_zerobasal_fn, depth_targets, depth_target_x,
                                                                    xs_layers, first_layer_fn)

# Numerical versions of the actual (with basal slip) velocity field. Lambdified once
# here instead of twice per layer inside the loop below.
u_fn = lambdify_and_vectorize_if_needed((x, z), u)
w_fn = lambdify_and_vectorize_if_needed((x, z), w)

# Then advect the no-slip layers with the actual model for a fixed 100 years
xs_layers_t0 = np.arange(0, domain_x, 10)
layers_t0 = [advect_layer(u_fn, w_fn, xs_layers_t0, layer, [scipy.constants.year * 100])[-1]
             for layer in layers_t0_noslip]

# Layers that would exist if the current velocity field were in steady state
layers_current_velocity_steady_state, _ = simulate_layers_with_target_spacing(u_fn, w_fn, depth_targets, depth_target_x,
                                                                    xs_layers, first_layer_fn)

# Uncomment to analyse the steady-state layers instead:
#layers_t0 = layers_current_velocity_steady_state

In [None]:
# t0 layers drawn over the velocity-magnitude field
fig, ax = plot_velocity_magnitude(
    xs, zs, x, z, u, w, surface, domain_x,
    layers_t0, xs_layers,
)
fig.savefig(f"{output_results_base}_layers.png")
plt.show()

In [None]:
# Same view for the layers generated as if the current velocity field were in steady state
fig, ax = plot_velocity_magnitude(xs, zs, x, z, u, w, surface, domain_x, layers_current_velocity_steady_state, xs_layers)
# Save and show like the other figure cells so this can be compared with the t0 layer figure
fig.savefig(f"{output_results_base}_layers_current_velocity_steady_state.png")
plt.show()


### Advect layers by 1 year to simulate layer motion

In [None]:
# Advect every t0 layer forward by one year with the actual velocity field.
# The velocity field is lambdified once here instead of twice per layer.
u_fn_actual = lambdify_and_vectorize_if_needed((x, z), u)
w_fn_actual = lambdify_and_vectorize_if_needed((x, z), w)

xs_layers_t1 = np.arange(0, domain_x, 10)
layers_t1 = [advect_layer(u_fn_actual, w_fn_actual, xs_layers_t1, layer, [scipy.constants.year * 1])[-1]
             for layer in layers_t0]

In [None]:
# Vertical displacement of one layer over the one-year advection step
layer_idx = 10
xs_layers_t1 = np.arange(0, domain_x, 10)

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(
    xs_layers_t1 / 1e3,
    layers_t1[layer_idx](xs_layers_t1) - layers_t0[layer_idx](xs_layers_t1),
    'r-',
    label='t1 - t0',
)
ax.set(title='Layer advection', xlabel='x [km]', ylabel='z [m]')
ax.legend()
ax.grid(True)

### Set up finite difference approximations of layer deformation-related partial derivatives

In [None]:
# --- Measurement noise and smoothing switches ---
add_simulated_noise = False
apply_stacking = False
apply_gp_smoothing = False  # only takes effect when apply_stacking is True (see below)

# Acquisition parameters used by the noise and stacking simulations
snr_db = 10
velocity = 50 # m/s
prf = 10 # Hz
center_frequency = 60e6 # Hz
pulse_spacing = velocity / prf # m
pulses_x = np.arange(0, domain_x, pulse_spacing)

rng = np.random.default_rng(275209073368752189122200994498511502265)

if add_simulated_noise:
    layers_t0_measured = add_noise_to_layers(layers_t0, snr_db, domain_x, velocity=velocity, prf=prf, center_frequency=center_frequency, rng=rng)
    layers_t1_measured = add_noise_to_layers(layers_t1, snr_db, domain_x, velocity=velocity, prf=prf, center_frequency=center_frequency, rng=rng)

else:
    print("WARNING: Noise disabled in this simulation")
    layers_t0_measured = layers_t0
    layers_t1_measured = layers_t1

# Noise filtering

from sklearn.gaussian_process.kernels import WhiteKernel, Matern, RBF
# NOTE(review): the initial length_scale (1e3/domain_x = 0.01) lies outside
# length_scale_bounds=(1e3, 50e3). One of them is probably in different units
# (normalized vs. meters) -- confirm against how gp_smoothing scales x.
kernel = RBF(length_scale=1e3/domain_x, length_scale_bounds=(1e3, 50e3)) + WhiteKernel(1e-5, noise_level_bounds=(1e-10, 1))

if apply_gp_smoothing and not apply_stacking:
    # GP smoothing is nested in the stacking branch (it uses the stacking aperture as
    # its x spacing); without this message the flag would be ignored silently.
    print("WARNING: apply_gp_smoothing has no effect unless apply_stacking is enabled")

if apply_stacking:
    aperture_length = 200

    layers_t0_smoothed = simulate_stacking(layers_t0_measured, domain_x, kernel_length_m=aperture_length, velocity=velocity, prf=prf)
    layers_t1_smoothed = simulate_stacking(layers_t1_measured, domain_x, kernel_length_m=aperture_length, velocity=velocity, prf=prf)

    if apply_gp_smoothing:
        layers_t0_smoothed = gp_smoothing(layers_t0_smoothed, domain_x, x_spacing=aperture_length, initialize_with_prior_kernel=False,
                                        kernel=kernel)
        layers_t1_smoothed = gp_smoothing(layers_t1_smoothed, domain_x, x_spacing=aperture_length, initialize_with_prior_kernel=False,
                                        kernel=kernel)

else:
    layers_t0_smoothed = layers_t0_measured
    layers_t1_smoothed = layers_t1_measured

In [None]:
layer_idx = 10

# Compare on the model grid, excluding its last 100 points (largest x)
xs_tmp = xs[:-100]

# Position errors of the raw measurements and of the smoothed layers vs. the truth [m]
layer_t0_diff_meas = layers_t0_measured[layer_idx](xs_tmp) - layers_t0[layer_idx](xs_tmp)
layer_t1_diff_meas = layers_t1_measured[layer_idx](xs_tmp) - layers_t1[layer_idx](xs_tmp)

layer_t0_diff = layers_t0_smoothed[layer_idx](xs_tmp) - layers_t0[layer_idx](xs_tmp)
layer_t1_diff = layers_t1_smoothed[layer_idx](xs_tmp) - layers_t1[layer_idx](xs_tmp)

# Slope (dz/dx) errors of the smoothed layers -- dimensionless [m/m]
layer_t0_deriv_diff = np.gradient(layers_t0_smoothed[layer_idx](xs_tmp), xs_tmp) - np.gradient(layers_t0[layer_idx](xs_tmp), xs_tmp)
layer_t1_deriv_diff = np.gradient(layers_t1_smoothed[layer_idx](xs_tmp), xs_tmp) - np.gradient(layers_t1[layer_idx](xs_tmp), xs_tmp)

fig, axs_tmp = plt.subplots(3, 1, figsize=(12, 8))

axs_tmp[0].plot(xs_tmp/1e3, layer_t0_diff_meas, 'r-', label='t0')
axs_tmp[0].plot(xs_tmp/1e3, layer_t1_diff_meas, 'b-', label='t1')
axs_tmp[0].set_title(f'[Layer {layer_idx}] Difference in layer position (measured - true)\nRMS Error: t0 {np.sqrt(np.mean(layer_t0_diff_meas**2)):.2e} m, t1 {np.sqrt(np.mean(layer_t1_diff_meas**2)):.2e} m')

axs_tmp[1].plot(xs_tmp/1e3, layer_t0_diff, 'r-', label='t0')
axs_tmp[1].plot(xs_tmp/1e3, layer_t1_diff, 'b-', label='t1')
axs_tmp[1].set_title(f'[Layer {layer_idx}] Difference in layer position (smoothed - true)\nRMS Error: t0 {np.sqrt(np.mean(layer_t0_diff**2)):.2e} m, t1 {np.sqrt(np.mean(layer_t1_diff**2)):.2e} m')

axs_tmp[2].plot(xs_tmp/1e3, layer_t0_deriv_diff, 'r-', label='t0')
axs_tmp[2].plot(xs_tmp/1e3, layer_t1_deriv_diff, 'b-', label='t1')
# Slope errors are dimensionless, so their RMS is reported in m/m (not m)
axs_tmp[2].set_title(f'[Layer {layer_idx}] Difference in layer derivative (smoothed - true)\nRMS Error: t0 {np.sqrt(np.mean(layer_t0_deriv_diff**2)):.2e} m/m, t1 {np.sqrt(np.mean(layer_t1_deriv_diff**2)):.2e} m/m')

for ax in axs_tmp:
    ax.grid()
    ax.legend()  # the t0/t1 labels above are otherwise never shown
axs_tmp[-1].set_xlabel('x [km]')

fig.tight_layout()


In [None]:
# Zoomed view (70-80 km) of one t0 layer: truth vs. smoothed, with the raw
# measurement overlaid faintly on top
layer_idx = 10

fig, ax = plt.subplots(figsize=(20, 4))

xs_tmp = np.arange(70e3, 80e3, 10)

ax.plot(xs_tmp/1e3, layers_t0[layer_idx](xs_tmp), label='t0')
ax.plot(xs_tmp/1e3, layers_t0_smoothed[layer_idx](xs_tmp), label='t0 smoothed')

# Pin the y-limits fitted to the two curves above, so the noisy measurement
# plotted next cannot rescale the axis
y_bot, y_top = ax.get_ylim()
ax.set_ylim(bottom=y_bot, top=y_top)

ax.plot(xs_tmp/1e3, layers_t0_measured[layer_idx](xs_tmp), label='t0 measured', alpha=0.2)
ax.set(xlabel='x [km]', ylabel='z [m]')

ax.legend()

In [None]:
# Finite-difference helpers built from the (possibly noisy/smoothed) t0 and t1 layers.
# Judging by their names: layer slope dl/dx, layer vertical rate dl/dt, and the
# z-derivatives of each (see flowline_ode.finite_differences for the definitions).
(layer_dl_dx,
 layer_dl_dt,
 layer_d2l_dxdz,
 layer_d2l_dtdz) = create_layer_finite_difference_fns(layers_t0_smoothed, layers_t1_smoothed)

# Ground-truth counterparts, when needed, come straight from the noise-free layers:
# slope as a central difference of layers_t0[i](x +/- dx/2), and dl/dt as
# layers_t1[i](x) - layers_t0[i](x) over the one-year step.

In [None]:
# Running extremes of layer slope [deg]. Starting at -/+inf (not 0) means the
# reported max/min are correct even when every slope has the same sign.
max_layer_slope = -np.inf
min_layer_slope = np.inf

fig, ax = plt.subplots(figsize=(8, 4))
vmin, vmax = -10, 10
sc = None
for idx, layer in enumerate(layers_t0):
    layer_slopes = np.arctan(layer_dl_dx(xs, idx))*(180/np.pi)
    sc = ax.scatter(xs/1e3, layer(xs), c=layer_slopes, s=2, vmin=vmin, vmax=vmax, cmap='coolwarm')

    # nan-aware so a few undefined slopes do not poison the extremes
    max_layer_slope = max(max_layer_slope, np.nanmax(layer_slopes))
    min_layer_slope = min(min_layer_slope, np.nanmin(layer_slopes))
fig.colorbar(sc, ax=ax, label='Layer slope [deg]')
ax.set_title('Layer slope')
ax.grid(True)
ax.set_xlim(0, domain_x/1e3)
ax.set_ylim(0, surface(0))
ax.set_xlabel('x [km]')
ax.set_ylabel('z [m]')
fig.savefig(f"{output_results_base}_layer_slope.png")
plt.show()

print(f"Max layer slope: {max_layer_slope}")
print(f"Min layer slope: {min_layer_slope}")

### Method of Characteristics Solution

In [None]:
def du_dtau(tau, u, layer_idx):
    """Right-hand side of the characteristic ODE for horizontal velocity along a layer.

    Evaluates du/dtau = -(d2l/dxdz) * u - d2l/dtdz at position ``tau`` using the
    finite-difference derivatives of layer ``layer_idx``.
    """
    slope_z_derivative = layer_d2l_dxdz(tau, layer_idx)
    rate_z_derivative = layer_d2l_dtdz(tau, layer_idx)
    return -slope_z_derivative * u - rate_z_derivative

# solve_ivp lives in scipy.integrate; `import scipy` alone does not guarantee the
# submodule is loaded on every SciPy version, so import it explicitly.
import scipy.integrate

# x position [m] where integration of every layer's ODE starts
start_pos_x = 100
# Integration along a layer stops where the layer first drops below this elevation [m]
min_layer_elevation = 200

layer_solutions = {}

# Only interior layers (indices 1 .. len-2) are solved
for idx in tqdm(np.arange(1, len(layers_t0)-1)):
    layer = layers_t0[idx]
    
    u0 = 0 # More realistic (and effectively equivalent for this example) to assume no specific knowledge of horizontal velocity at depth
    #u0 = u.subs([(x, start_pos_x), (z, layer(start_pos_x))]).evalf() * scipy.constants.year
    
    # If you're extracting points from the solution (for rheology estimates, for example), set max_step to no more than the spacing at which you'll extract points
    # (but if you're just plotting velocity, this will run a lot faster if you let the solver pick a max step size)
    below_guardrail = xs_layers[layer(xs_layers) < min_layer_elevation]
    layer_end = np.min(below_guardrail) if len(below_guardrail) > 0 else domain_x
    layer_solutions[idx] = scipy.integrate.solve_ivp(du_dtau, [start_pos_x, layer_end], np.array([u0]), args=(idx,), dense_output=True) # , max_step=100
    if layer_solutions[idx].status != 0:
        print(f"Layer {idx} failed to solve")
        print(layer_solutions[idx].message)

In [None]:
def layer_solution_velocity(layer_idx, x, solutions=None):
    """Evaluate the characteristic-ODE velocity solution for one layer.

    Parameters
    ----------
    layer_idx : int
        Key of the layer in ``solutions``.
    x : float or 1-D array-like
        Positions at which to evaluate the dense-output solution.
    solutions : dict, optional
        Mapping layer index -> ``scipy.integrate.solve_ivp`` result computed with
        ``dense_output=True``. Defaults to the notebook-level ``layer_solutions``.

    Returns
    -------
    float or numpy.ndarray
        The solution at ``x``. Positions outside the integrated interval
        ``[sol.t_min, sol.t_max]`` are NaN, because the dense output there is an
        extrapolation rather than a solution.
    """
    if solutions is None:
        solutions = layer_solutions
    sol = solutions[layer_idx].sol

    x_arr = np.asarray(x, dtype=float)  # also accepts plain lists
    res = np.asarray(sol(x_arr)[0], dtype=float)
    outside = (x_arr < sol.t_min) | (x_arr > sol.t_max)

    if res.ndim == 0:
        return np.nan if outside else res[()]
    res[outside] = np.nan  # res is a fresh array from the interpolant; safe to modify
    return res

In [None]:
fig, (ax, ax_err) = plt.subplots(2, 1, figsize=(8, 8), sharex=True, sharey=True)
sc = None

# True horizontal velocity u(x, z) [m/s] as a numerical function. Lambdified once
# here rather than once per layer inside the error loop below.
u_true_fn = lambdify_and_vectorize_if_needed((x, z), u)

# Solution plot: recovered horizontal velocity along each solved layer
for layer_idx in layer_solutions.keys():
    xs_tmp = xs
    sc = ax.scatter(xs_tmp/1e3, layers_t0[layer_idx](xs_tmp), c=layer_solution_velocity(layer_idx, xs_tmp), vmin=0, vmax=40, s=2, cmap='viridis')
fig.colorbar(sc, ax=ax, label='Horizontal Velocity [m/yr]')

ax.grid(True)
ax.set_ylabel('z [m]')

# Diagnostic figure: per-layer error traces along x
fig_tmp, ax_tmp = plt.subplots(figsize=(8, 4))

# Error plot

mse_accumulator = 0
mse_sample_count = 0

for layer_idx in layer_solutions.keys():
    xs_tmp = xs
    err = layer_solution_velocity(layer_idx, xs_tmp) - (u_true_fn(xs_tmp, layers_t0[layer_idx](xs_tmp)) * scipy.constants.year)
    sc = ax_err.scatter(xs_tmp/1e3, layers_t0[layer_idx](xs_tmp), c=err, vmin=-2, vmax=2, s=2, cmap='coolwarm_r')

    # Track mean squared error, excluding the first 10 grid points (smallest x).
    # Work on a copy so `err` itself is not modified.
    err_valid = err.copy()
    err_valid[:10] = np.nan
    ax_tmp.plot(xs_tmp/1e3, err_valid, label=f"Layer {layer_idx}")

    err_finite = err_valid[~np.isnan(err_valid)]
    mse_accumulator += np.sum(err_finite**2)
    mse_sample_count += len(err_finite)

fig.colorbar(sc, ax=ax_err, label='Horizontal Velocity Error [m/yr]')

ax_err.grid(True)
ax_err.set_xlabel('x [km]')
ax_err.set_ylabel('z [m]')

ax_tmp.set_xlabel('x [km]')
ax_tmp.set_ylabel('Horizontal velocity error [m/yr]')
ax_tmp.grid(True)

ax.set_title('Horizontal velocity solution')
ax_err.set_title('Error in horizontal velocity estimate')
fig.savefig(f"{output_results_base}_layer_solutions.png")
plt.show()

if mse_sample_count > 0:
    print(f"Mean squared error: {mse_accumulator / mse_sample_count}")
else:
    print("Mean squared error: no valid samples")

In [None]:
# Number of layers; as the cell's last expression, the notebook displays this value
len(layers_t0)

In [None]:
# Reference curves: true surface and bed velocities from the SIA model
fig, ax = plot_surface_bed_velocity(xs, x, z, u, surface_sym)

# Overlay the ODE solution along the deepest solved layer (index len(layers_t0) - 2)
sc = ax.plot(
    xs / 1e3,
    layer_solution_velocity(len(layers_t0) - 2, xs),
    label='Bottom layer velocity [m/yr]',
)
ax.legend()

fig.savefig(f"{output_results_base}_surface_bed_botlayer_velocity.png", dpi=500)
plt.show()

In [None]:
# Plot each layer solution at x = plot_pos_x as a function of depth
fig, (ax_u, ax_dwdz) = plt.subplots(1,2, figsize=(8, 8), sharey=True)

plot_pos_x = 85e3

# True horizontal velocity profile u(z) [m/s] at plot_pos_x
u_at_plot_pos = lambdify_and_vectorize_if_needed(z, u.subs(x, plot_pos_x))(zs)

for i, layer_idx in enumerate(layer_solutions.keys()):
    # Label only the first point so the series appears once in the legend
    # (independent of which key the solutions start at)
    lbl = 'ODE Solutions' if i == 0 else None
    ax_u.scatter([layer_solution_velocity(layer_idx, plot_pos_x)], [layers_t0[layer_idx](plot_pos_x)], label=lbl, c='r')

ax_u.plot(u_at_plot_pos*scipy.constants.year, zs, 'k--', label='True')
ax_u.set_title(f'Horizontal Velocity\nn = {n}')
ax_u.set_xlabel('Horizontal velocity [m/yr]')
ax_u.set_ylabel('z [m]')
ax_u.grid()
ax_u.legend()
ax_u.set_xlim(0, 1.2 * np.nanmax(u_at_plot_pos*scipy.constants.year))

# Vertical strain rate
dwdz_at_plot_pos = lambdify_and_vectorize_if_needed(z, sympy.diff(w, z).subs(x, plot_pos_x))(zs)

for i, layer_idx in enumerate(layer_solutions.keys()):
    lbl_ode = 'ODE Solutions' if i == 0 else None
    lbl_zeroapprox = 'Zero Slope Approximation' if i == 0 else None

    # Recovered from the layer ODE solutions: dw/dz = -du/dx, with du/dx taken as
    # a 1 m backward difference of the solution
    du_dx_ode_at_plot_pos = layer_solution_velocity(layer_idx, plot_pos_x) - layer_solution_velocity(layer_idx, plot_pos_x-1)
    ax_dwdz.scatter([-1*du_dx_ode_at_plot_pos], [layers_t0[layer_idx](plot_pos_x)], label=lbl_ode, c='r')

    # Zero slope approximation: central difference of dl/dt across the neighbouring layers
    dwdz_zs_approx = (layer_dl_dt(plot_pos_x, layer_idx+1) - layer_dl_dt(plot_pos_x, layer_idx-1)) / (layers_t0[layer_idx+1](plot_pos_x) - layers_t0[layer_idx-1](plot_pos_x))
    ax_dwdz.scatter([dwdz_zs_approx], [layers_t0[layer_idx](plot_pos_x)], label=lbl_zeroapprox, c='b', marker='x')

ax_dwdz.plot(dwdz_at_plot_pos*scipy.constants.year, zs, 'k--', label='True')
ax_dwdz.set_title(f'Vertical Strain Rate\nn = {n}')
ax_dwdz.set_xlabel('Vertical strain rate [m/(yr*m)]')
ax_dwdz.grid()
ax_dwdz.legend()

fig.savefig(f"{output_results_base}_layer_solutions_at_xpos.png")
plt.show()

### Estimate effective stress

In [None]:
# Earlier sketch of the du/dz and effective-stress estimates (not executed: every
# line below is commented out; the working version is in the next cell).
# Estimate du_dz
# dudz_central_diff = (layer_sols[layer_idx-1](x_pos) - layer_sols[layer_idx+1](x_pos)) / ( seconds_per_year * (layers_t0[layer_idx-1](x_pos)[1] - layers_t0[layer_idx+1](x_pos)[1]))

# # Estimate effective stress under SIA
# dist_to_surf = surface(x_pos) - z_pos
# rheology_eff_stress = ρ * g * dist_to_surf * -dsdx(x_pos) # (3.88)

In [None]:
rho = 918  # ice density [kg/m^3]
g = 9.8  # gravitational acceleration [m/s^2]

# x range over which strain rate and stress are sampled
visc_start_x = 5e3
visc_end_x = 95e3

xs_visc = np.linspace(visc_start_x, visc_end_x, 100)

# Estimate du/dz (central difference across the neighbouring layers) and the SIA
# effective stress rho * g * (s - z) * (-ds/dx) for every layer over the domain.
# Column layer_idx-1 holds layer layer_idx. Column 0 (layer 1) is never filled
# because layer 0 has no ODE solution, so the arrays start as NaN rather than 0:
# unfilled entries then drop out of the log-log plot (and the saved results)
# instead of becoming log10(0) = -inf.
du_dz_central_diff = np.full((len(xs_visc), len(layer_solutions)-1), np.nan)
eff_stress = np.full_like(du_dz_central_diff, np.nan)

for idx, x_pos in enumerate(xs_visc):
    for layer_idx in np.arange(2, len(layers_t0)-2):
        du_dz_central_diff[idx, layer_idx-1] = (layer_solution_velocity(layer_idx-1, x_pos) - layer_solution_velocity(layer_idx+1, x_pos)) / (layers_t0[layer_idx-1](x_pos) - layers_t0[layer_idx+1](x_pos))
        dist_to_surf = surface(x_pos) - layers_t0[layer_idx](x_pos)
        eff_stress[idx, layer_idx-1] = rho * g * dist_to_surf * -1 * ds_dx(x_pos)

# Plot a scatter plot of log10(du_dz_central_diff) vs log10(eff_stress)
fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(np.log10(eff_stress), np.log10(du_dz_central_diff), s=2, label=f'n = {n}')
ax.set_aspect('equal')
ax.set_xlabel('log(effective stress)')
ax.set_ylabel('log(strain rate)')
ax.legend()
ax.grid()
fig.savefig(f"{output_results_base}_stress_strain_rate.png")
plt.show()

### Save results to a pickle

In [None]:
# Persist the model setup, layers and derived estimates. The pickle shares the
# timestamp/n prefix used for every figure saved by this notebook.
filename = f"{output_results_base}.pickle"
with open(filename, 'wb') as f:
    pickle.dump(
        {
            'n': n,
            'xs': xs,
            'zs': zs,
            'domain_x': domain_x,
            'domain_z': domain_z,
            'surface': surface_sym,
            'x': x,
            'z': z,
            'u': u,
            'w': w,
            'ds_dx': ds_dx_sym,
            'layers_t0': layers_t0,
            'layers_t1': layers_t1,
            'layer_solutions': layer_solutions,
            'eff_stress': eff_stress,
            'du_dz_central_diff': du_dz_central_diff,
        },
        f,
    )

print(f"Saved to {filename}")

### Find vertical velocity from horizontal velocity

In [None]:
# Zero slope approximation: sample dl/dt and dl/dx along every solved layer
# (downstream of start_pos_x), interpolate onto the regular (xs, zs) grid, and
# take the vertical gradient of dl/dt as the vertical strain rate.
xs_tmp = xs[xs > start_pos_x]

interp_x = np.zeros((len(layer_solutions), len(xs_tmp)))
interp_z, interp_dl_dt, interp_dl_dx = (np.zeros_like(interp_x) for _ in range(3))
for idx, layer_idx in enumerate(layer_solutions):
    interp_x[idx] = xs_tmp
    interp_z[idx] = layers_t0[layer_idx](xs_tmp)
    interp_dl_dt[idx] = layer_dl_dt(xs_tmp, layer_idx)
    interp_dl_dx[idx] = layer_dl_dx(xs_tmp, layer_idx)

X, Z = np.meshgrid(xs_tmp, zs)
dl_dt_grid, dl_dx_grid = (
    scipy.interpolate.griddata((interp_x.flatten(), interp_z.flatten()), values.flatten(), (X, Z), method='linear')
    for values in (interp_dl_dt, interp_dl_dx)
)
dw_dz_zeroslope_grid = np.gradient(dl_dt_grid, zs, axis=0)


In [None]:
# Interpolate the per-layer ODE solutions onto the regular (xs, zs) grid
xs_tmp = xs[xs > start_pos_x]
interp_x = np.zeros((len(layer_solutions), len(xs_tmp)))
interp_z, interp_u, interp_w = (np.zeros_like(interp_x) for _ in range(3))
for idx, layer_idx in enumerate(layer_solutions):
    interp_x[idx] = xs_tmp
    interp_z[idx] = layers_t0[layer_idx](xs_tmp)
    interp_u[idx] = layer_solution_velocity(layer_idx, xs_tmp)
    # Vertical velocity from the layer kinematics: w = dl/dt + u * dl/dx
    interp_w[idx] = layer_dl_dt(xs_tmp, layer_idx) + interp_u[idx]*layer_dl_dx(xs_tmp, layer_idx)

# Gridded horizontal and vertical velocity from the layer-line solutions
X, Z = np.meshgrid(xs_tmp, zs)
u_mol_grid, w_mol_grid = (
    scipy.interpolate.griddata((interp_x.flatten(), interp_z.flatten()), values.flatten(), (X, Z), method='linear')
    for values in (interp_u, interp_w)
)

# Mask: 1 strictly between the second and the second-to-last layer, NaN everywhere else
l1 = layers_t0[1](xs_tmp)
l2 = layers_t0[-2](xs_tmp)
mask = np.where((Z < np.maximum(l1, l2)) & (Z > np.minimum(l1, l2)), 1.0, np.nan)

# Vertical strain rate from incompressibility: dw/dz = -du/dx
dw_dz_mol_grid = -np.gradient(u_mol_grid, xs_tmp, axis=1)

fig, axs = plt.subplots(5,2, figsize=(16, 12), sharex=True)
((ax_U, ax_U_err), (ax_W, ax_W_err), (ax_W_zeroslope, ax_W_zeroslope_err), (ax_dWdz, ax_dWdz_err), (ax_dWdz_zeroslope, ax_dWdz_zeroslope_err)) = axs

# Error colour scales span +/- this fraction of the maximum true value
err_clb_pct_of_max = 0.25

# True velocities on the grid from the symbolic model [m/s], with m/yr versions
# computed once instead of repeatedly inside each plotting call
U_tmp = lambdify_and_vectorize_if_needed((x, z), u)(X, Z)
W_tmp = lambdify_and_vectorize_if_needed((x, z), w)(X, Z)
U_true_yr = U_tmp*scipy.constants.year
W_true_yr = W_tmp*scipy.constants.year

# Horizontal velocity

pcm_U = ax_U.pcolormesh(X/1e3, Z, mask * u_mol_grid, cmap='viridis')
fig.colorbar(pcm_U, ax=ax_U, label='Horizontal velocity [m/yr]')
ax_U.set_title('Horizontal velocity\n(interpolated from layer ODEs)')
ax_U.set_ylabel('z [m]')

u_err_lim = err_clb_pct_of_max*np.nanmax(U_true_yr)
pcm_U_err = ax_U_err.pcolormesh(X/1e3, Z, mask * (u_mol_grid - U_true_yr), cmap='coolwarm',
                                vmin=-u_err_lim,
                                vmax=u_err_lim)
fig.colorbar(pcm_U_err, ax=ax_U_err, label='Error in horizontal velocity [m/yr]')
ax_U_err.set_title('Error in horizontal velocity\n(ODE interpolation - true)')

rms_err_horizontal = np.sqrt(np.nanmean((mask * (u_mol_grid - U_true_yr))**2))
print(f"RMS error in horizontal velocity: {rms_err_horizontal:.2e} m/yr")

# Vertical velocity

pcm_W = ax_W.pcolormesh(X/1e3, Z, mask * w_mol_grid, cmap='viridis')
fig.colorbar(pcm_W, ax=ax_W, label='Vertical velocity [m/yr]')
ax_W.set_title('Vertical velocity\n(interpolated from layer ODEs)')
ax_W.set_ylabel('z [m]')

w_err_lim = err_clb_pct_of_max*np.nanmax(np.abs(W_true_yr))
pcm_W_err = ax_W_err.pcolormesh(X/1e3, Z, mask * (w_mol_grid - W_true_yr), cmap='coolwarm',
                                vmin=-w_err_lim,
                                vmax=w_err_lim)
fig.colorbar(pcm_W_err, ax=ax_W_err, label='Error in vertical velocity [m/yr]')
ax_W_err.set_title('Error in vertical velocity\n(ODE interpolation - true)')

rms_err_vertical = np.sqrt(np.nanmean((mask * (w_mol_grid - W_true_yr))**2))
print(f"RMS error in vertical velocity: {rms_err_vertical:.2e} m/yr")

# Vertical velocity (zero layer slope approximation: dl/dt used directly as w),
# on the same colour scales as the ODE panels

vmin_ref, vmax_ref = pcm_W.get_clim()
pcm_W_zeroslope = ax_W_zeroslope.pcolormesh(X/1e3, Z, mask * dl_dt_grid, cmap='viridis', vmin=vmin_ref, vmax=vmax_ref)
fig.colorbar(pcm_W_zeroslope, ax=ax_W_zeroslope, label='Vertical velocity [m/yr]')
ax_W_zeroslope.set_title('Vertical velocity\n(zero slope approximation)')
ax_W_zeroslope.set_ylabel('z [m]')

vmin_ref, vmax_ref = pcm_W_err.get_clim()
pcm_W_zeroslope_err = ax_W_zeroslope_err.pcolormesh(X/1e3, Z, mask * (dl_dt_grid - W_true_yr), cmap='coolwarm', vmin=vmin_ref, vmax=vmax_ref)
fig.colorbar(pcm_W_zeroslope_err, ax=ax_W_zeroslope_err, label='Error in vertical velocity [m/yr]')
ax_W_zeroslope_err.set_title('Error in vertical velocity\n(zero slope approximation - true)')

# Same RMS metric as the ODE-based panels, so the two methods can be compared directly
rms_err_vertical_zeroslope = np.sqrt(np.nanmean((mask * (dl_dt_grid - W_true_yr))**2))
print(f"RMS error in vertical velocity (zero slope approximation): {rms_err_vertical_zeroslope:.2e} m/yr")

# --- Vertical strain rate ---

# Reference: true dw/dz from the symbolic model, with w scaled to m/yr first
dw_dz_sym = sympy.diff(scipy.constants.year * w, z)
dw_dz_lambdify = lambdify_and_vectorize_if_needed((x, z), dw_dz_sym)
dw_dz_true = dw_dz_lambdify(X, Z)

# ODE-derived strain rate on a colour scale symmetric about zero
dwdz_abs_max = np.nanmax(np.abs(mask * dw_dz_mol_grid))
pcm_dWdz = ax_dWdz.pcolormesh(X/1e3, Z, mask * dw_dz_mol_grid, cmap='coolwarm', vmin=-dwdz_abs_max, vmax=dwdz_abs_max)
fig.colorbar(pcm_dWdz, ax=ax_dWdz, label='Vertical strain rate [m/(yr*m)]')
ax_dWdz.set(title='Vertical strain rate\n(interpolated from layer ODEs)', xlabel='x [km]', ylabel='z [m]')

# Its error, drawn on the same symmetric scale as the values themselves
vmin_ref, vmax_ref = pcm_dWdz.get_clim()
dwdz_abs_max = np.maximum(np.abs(vmin_ref), np.abs(vmax_ref))
pcm_dWdz_err = ax_dWdz_err.pcolormesh(X/1e3, Z, mask * (dw_dz_mol_grid - dw_dz_true), cmap='coolwarm', vmin=-dwdz_abs_max, vmax=dwdz_abs_max)
fig.colorbar(pcm_dWdz_err, ax=ax_dWdz_err, label='Error in vertical strain rate\n[m/(yr*m)]')
ax_dWdz_err.set(title='Error in vertical strain rate\n(ODE interpolation - true)', xlabel='x [km]', ylabel='z [m]')

# Zero slope approximation, reusing the colour limits of the ODE panels above
vmin_ref, vmax_ref = pcm_dWdz.get_clim()
pcm_dWdz_zeroslope = ax_dWdz_zeroslope.pcolormesh(X/1e3, Z, mask * dw_dz_zeroslope_grid, cmap='coolwarm', vmin=vmin_ref, vmax=vmax_ref)
fig.colorbar(pcm_dWdz_zeroslope, ax=ax_dWdz_zeroslope, label='Vertical strain rate [m/(yr*m)]')
ax_dWdz_zeroslope.set(title='Vertical strain rate\n(zero slope approximation)', xlabel='x [km]', ylabel='z [m]')

vmin_ref, vmax_ref = pcm_dWdz_err.get_clim()
pcm_dWdz_zeroslope_err = ax_dWdz_zeroslope_err.pcolormesh(X/1e3, Z, mask * (dw_dz_zeroslope_grid - dw_dz_true), cmap='coolwarm', vmin=vmin_ref, vmax=vmax_ref)
fig.colorbar(pcm_dWdz_zeroslope_err, ax=ax_dWdz_zeroslope_err, label='Error in vertical strain rate\n[m/(yr*m)]')
ax_dWdz_zeroslope_err.set(title='Error in vertical strain rate\n(zero slope approximation - true)', xlabel='x [km]', ylabel='z [m]')

# Overlay every t0 layer as a thin dashed black line on each panel for reference
for layer in layers_t0:
    for ax in axs.flat:
        ax.plot(xs_layers / 1e3, layer(xs_layers), 'k--', linewidth=0.5)

fig.tight_layout()
fig.savefig(f"{output_results_base}_results_overview.png")
plt.show()

In [None]:
def plot_layer_flow_misalignment(layer_solution_velocity, layer_dl_dt, layer_dl_dx, layers, xs):
    """Scatter-plot the angle between the recovered flow direction and the local layer slope.

    For every layer except the first and last, the flow angle is atan2(w, u),
    with u from ``layer_solution_velocity`` and w = dl/dt + u * dl/dx; the layer
    angle is atan(dl/dx). Their difference is plotted in degrees on a fixed
    +/-15 deg scale, with values below/above the scale drawn black/orange.

    Parameters
    ----------
    layer_solution_velocity : callable
        ``(layer_idx, xs) -> u`` along a layer.
    layer_dl_dt, layer_dl_dx : callable
        ``(xs, layer_idx) -> dl/dt`` and ``-> dl/dx`` along a layer.
    layers : sequence of callable
        Layer geometries z(x); index i corresponds to ``layer_idx == i``.
    xs : 1-D array
        x positions [m] to evaluate.

    Returns
    -------
    fig, ax
        The created Matplotlib figure and axes.
    """
    import copy

    angles_flow = np.nan * np.zeros((len(layers), len(xs)))
    angles_layers = np.nan * np.zeros((len(layers), len(xs)))

    # Skip the top and bottom layers (no ODE solution exists for them)
    for idx in range(1, len(layers) - 1):
        slope = layer_dl_dx(xs, idx)

        # Layer angles
        angles_layers[idx, :] = np.arctan2(slope, 1)

        # Flow angles
        u_layer = layer_solution_velocity(idx, xs)
        w_layer = layer_dl_dt(xs, idx) + u_layer*slope
        angles_flow[idx, :] = np.arctan2(w_layer, u_layer)

    flow_misalignment = angles_flow - angles_layers

    vmin, vmax = -15, 15

    # Copy the colormap before setting under/over colours so the shared 'coolwarm'
    # colormap is not modified (Matplotlib < 3.6 hands out the registered instance)
    cmap = copy.copy(plt.get_cmap('coolwarm'))
    cmap.set_under('black')
    cmap.set_over('orange')

    fig, ax = plt.subplots(figsize=(8, 4))
    sc = None
    for idx in range(len(layers)):
        sc = ax.scatter(xs/1e3, layers[idx](xs), c=np.rad2deg(flow_misalignment[idx, :]), vmin=vmin, vmax=vmax, s=2, cmap=cmap)

    fig.colorbar(sc, ax=ax, label='Flow angle - Layer slope [deg]', extend='both')

    ax.set_title('Flow misalignment')
    ax.set_xlabel('x [km]')
    ax.set_ylabel('z [m]')
    ax.grid(True)

    return fig, ax

# Flow-direction vs. layer-slope misalignment for the t0 layers; saved like the other figures
fig, ax = plot_layer_flow_misalignment(layer_solution_velocity, layer_dl_dt, layer_dl_dx, layers_t0, xs)
fig.savefig(f"{output_results_base}_layer_flow_misalignment.png")
plt.show()

In [None]:
def plot_layer_vs_layer_misalignment(layers_reference, layers_compare, xs, cmap='coolwarm'):
    """Plot the slope-angle difference between two sets of layers.

    The along-x slope dl/dx of every layer in ``layers_compare`` is sampled
    along that layer and linearly interpolated over the scattered (x, z)
    points onto the geometry of each layer in ``layers_reference``. The
    plotted quantity is ``atan(compare slope) - atan(reference slope)`` in
    degrees, drawn as a coloured scatter along the reference layers with
    colour limits symmetric about zero. Points outside the convex hull of the
    comparison layers have no interpolated slope (NaN).

    Parameters
    ----------
    layers_reference : sequence of callables
        ``layer(xs)`` gives elevation [m]; misalignment is shown along these layers.
    layers_compare : sequence of callables
        ``layer(xs)`` gives elevation [m]; their slope field is interpolated.
    xs : array_like
        1-D horizontal positions [m] (plotted in km); ``np.gradient`` needs at least two.
    cmap : str or Colormap, optional
        Diverging colormap. A copy is modified, so the registered colormap is untouched.

    Returns
    -------
    fig, ax
        The created matplotlib figure and axes.
    """
    import copy

    n_x = len(xs)

    # Sample each comparison layer and its along-x slope dl/dx
    compare_x = np.zeros((len(layers_compare), n_x))
    compare_z = np.zeros_like(compare_x)
    compare_dl_dx = np.zeros_like(compare_x)
    for idx, layer in enumerate(layers_compare):
        compare_x[idx, :] = xs
        compare_z[idx, :] = layer(xs)
        compare_dl_dx[idx, :] = np.gradient(compare_z[idx, :], xs)

    # Piecewise-linear interpolant of the comparison slope over the scattered (x, z) samples
    compare_dl_dx_interp = scipy.interpolate.LinearNDInterpolator(
        (compare_x.ravel(), compare_z.ravel()), compare_dl_dx.ravel())

    angles_reference = np.full((len(layers_reference), n_x), np.nan)
    angles_compare = np.full((len(layers_reference), n_x), np.nan)

    for idx, layer in enumerate(layers_reference):
        ref_zs = layer(xs)
        # Layer angles - reference
        angles_reference[idx, :] = np.arctan2(np.gradient(ref_zs, xs), 1)
        # Layer angles - comparison slope field evaluated on the reference layer
        angles_compare[idx, :] = np.arctan2(compare_dl_dx_interp(xs, ref_zs), 1)

    misalignment_deg = np.rad2deg(angles_compare - angles_reference)

    # Colour limits symmetric about zero, spanning the largest absolute misalignment
    lim = np.nanmax(np.abs(misalignment_deg))
    vmin, vmax = -lim, lim

    # Private colormap copy so under/over colours do not leak into the
    # registered colormap (and hence into other plots using the same name)
    cmap_obj = copy.copy(plt.get_cmap(cmap))
    cmap_obj.set_under('black')
    cmap_obj.set_over('orange')

    fig, ax = plt.subplots(figsize=(8, 4))
    sc = None
    for idx, layer in enumerate(layers_reference):
        sc = ax.scatter(xs/1e3, layer(xs), c=misalignment_deg[idx, :],
                        vmin=vmin, vmax=vmax, s=2, cmap=cmap_obj)

    # With no reference layers there is no mappable to attach a colorbar to
    if sc is not None:
        fig.colorbar(sc, ax=ax, label='Simulated steady state - observed layer angle [deg]')

    ax.set_title('Layer misalignment from steady state')
    ax.set_xlabel('x [km]')
    ax.set_ylabel('z [m]')
    ax.grid(True)

    return fig, ax
    
# Compare t0 ("observed") layer slopes with the simulated steady-state layers, as labelled
# by the colorbar of plot_layer_vs_layer_misalignment
fig, ax = plot_layer_vs_layer_misalignment(layers_t0, layers_current_velocity_steady_state, xs)
# Overlay the ice surface for reference and zoom to x = 50-100 km
ax.plot(xs_layers/1e3, surface(xs_layers), 'k')
ax.set_xlim(50,100)

fig.savefig(f"{output_results_base}_layer_misalignment.png", dpi=500)

In [None]:
fig, axs = plt.subplots(1,2, figsize=(16, 4), sharex=True)
ax_U, ax_U_pct_surf = axs

# Horizontal velocity

pcm_U = ax_U.pcolormesh(X/1e3, Z, mask * u_mol_grid, cmap='viridis')
fig.colorbar(pcm_U, ax=ax_U, label='Horizontal velocity [m/yr]')
ax_U.set_title('Horizontal velocity\n(interpolated from layer ODEs)')
ax_U.set_ylabel('z [m]')
ax_U.set_xlabel('x [km]')

# Horizontal velocity normalised by the true surface velocity.
# The analytic `u` is evaluated 0.1 m below the surface (presumably to stay just inside
# the ice) and converted to m/yr. It is evaluated on `xs_tmp`.
# NOTE(review): the division broadcasts over the last axis of u_mol_grid, so this assumes
# xs_tmp matches the x-columns of the (X, Z) grid. Confirm, especially since later cells
# reassign xs_tmp.
surface_u = lambdify_and_vectorize_if_needed(x, u.subs(z, surface_sym-0.1))(xs_tmp) * scipy.constants.year

pcm_U_pct_surf = ax_U_pct_surf.pcolormesh(X/1e3, Z, mask * u_mol_grid / surface_u, cmap='viridis', vmax=1, vmin=0.6)
fig.colorbar(pcm_U_pct_surf, ax=ax_U_pct_surf, label='Horizontal velocity / Surface velocity [diml]')
ax_U_pct_surf.set_title('Horizontal velocity / surface velocity\n(ODE interpolation / true surface)')
ax_U_pct_surf.set_ylabel('z [m]')
ax_U_pct_surf.set_xlabel('x [km]')


# Draw layer lines on all plots for reference
for layer in layers_t0:
    for ax in axs.flatten():
        ax.plot(xs_layers/1e3, layer(xs_layers), 'k--', linewidth=0.5)

fig.tight_layout()
plt.show()

In [None]:
fig, axs = plt.subplots(2,2, figsize=(12, 6), sharex=True)
((ax_W, ax_W_err), (ax_W_zeroslope, ax_W_zeroslope_err)) = axs

# True vertical velocity from the analytic field `w`, evaluated on the (X, Z) grid
# (multiplied by scipy.constants.year below to get m/yr). Both error fields need it,
# so it must be computed before either is formed.
W_tmp = lambdify_and_vectorize_if_needed((x, z), w)(X, Z)

# Errors (masked) against the true field:
#  - ODE interpolation:        w = dl/dt + u * dl/dx from the layer solutions
#  - zero-slope approximation: w ~ dl/dt
ode_error_W = mask * (w_mol_grid - W_tmp*scipy.constants.year)
zero_slope_error_W = mask * (dl_dt_grid - W_tmp*scipy.constants.year)

# Vertical velocity

pcm_W = ax_W.pcolormesh(X/1e3, Z, mask * w_mol_grid, cmap='viridis', vmax=0)
fig.colorbar(pcm_W, ax=ax_W, label='Vertical velocity [m/yr]')
ax_W.set_title('Vertical velocity\n(interpolated from ODE solutions)')
ax_W.set_ylabel('z [m]')

# Both error panels share a symmetric colour scale set by the (larger) zero-slope error
pcm_W_err = ax_W_err.pcolormesh(X/1e3, Z, ode_error_W, cmap='coolwarm',
                                vmin=-1*np.nanmax(np.abs(zero_slope_error_W)),
                                vmax=np.nanmax(np.abs(zero_slope_error_W)))
fig.colorbar(pcm_W_err, ax=ax_W_err, label='Error in vertical velocity [m/yr]')
ax_W_err.set_title('Error in vertical velocity\n(ODE interpolation - true)')

rms_err_vertical = np.sqrt(np.nanmean(ode_error_W**2))
print(f"RMS error in vertical velocity: {rms_err_vertical:.2e} m/yr")

# Vertical velocity (zero layer slope approximation)

vmin_ref, vmax_ref = pcm_W.get_clim()
pcm_W_zeroslope = ax_W_zeroslope.pcolormesh(X/1e3, Z, mask * dl_dt_grid, cmap='viridis', vmin=vmin_ref, vmax=vmax_ref)
fig.colorbar(pcm_W_zeroslope, ax=ax_W_zeroslope, label='Vertical velocity [m/yr]')
ax_W_zeroslope.set_title('Vertical velocity\n(zero slope approximation)')
ax_W_zeroslope.set_ylabel('z [m]')
ax_W_zeroslope.set_xlabel('x [km]')

vmin_ref, vmax_ref = pcm_W_err.get_clim()
pcm_W_zeroslope_err = ax_W_zeroslope_err.pcolormesh(X/1e3, Z, zero_slope_error_W, cmap='coolwarm', vmin=vmin_ref, vmax=vmax_ref)
fig.colorbar(pcm_W_zeroslope_err, ax=ax_W_zeroslope_err, label='Error in vertical velocity [m/yr]')
ax_W_zeroslope_err.set_title('Error in vertical velocity\n(zero slope approximation - true)')
ax_W_zeroslope_err.set_xlabel('x [km]')

# sharex=True: this limit applies to all four panels
ax_W_zeroslope_err.set_xlim(50,100)

# Draw layer lines (and the ice surface) on all plots for reference
for ax in axs.flatten():
    for layer in layers_t0:
        ax.plot(xs_layers/1e3, layer(xs_layers), 'k--', linewidth=0.5)

    ax.plot(xs_layers/1e3, surface(xs_layers), 'k')

# Subfigure labels (left-aligned titles, alongside the centred panel titles)
subfigure_labels = ['(a)', '(b)', '(c)', '(d)']
for idx, ax in enumerate(axs.flatten()):
    ax.set_title(subfigure_labels[idx], loc='left', fontweight='bold')

# Add layer line legend
import matplotlib.lines as mlines
layer_line = mlines.Line2D([], [], linewidth=0.5, linestyle='--', color='k', label='Layers')
ax_W_zeroslope.legend(handles=[layer_line], bbox_to_anchor=(0.25, -0.08))

fig.tight_layout()
fig.savefig(f"{output_results_base}_vertical_velocity_error.png", dpi=500)
plt.show()

In [None]:
fig_hv_err, axs_hv_err = plt.subplots(2,2, figsize=(12, 5), sharex=True)
# NOTE(review): ax0/ax1 (bottom row) are not otherwise used in this cell; they only receive
# the layer lines below. Placeholder panels? Confirm.
((ax_U, ax_U_err), (ax0, ax1)) = axs_hv_err

# Despite the name, this is a fraction (0.05 = 5%), not a percent. The error colour scale
# spans +/- this fraction of the maximum true horizontal velocity.
err_clb_pct_of_max = 0.05

# Horizontal velocity

pcm_U = ax_U.pcolormesh(X/1e3, Z, mask * u_mol_grid, cmap='viridis')
fig_hv_err.colorbar(pcm_U, ax=ax_U, label='Horizontal velocity [m/yr]')
# NOTE(review): "25 dB SNR" is hard-coded. Confirm it matches the noise setting used to
# produce u_mol_grid.
ax_U.set_title('Horizontal velocity\n(interpolated from ODE solutions, 25 dB SNR)')
ax_U.set_ylabel('z [m]')

# True horizontal velocity from the analytic field `u` on the (X, Z) grid. It is multiplied
# by scipy.constants.year (seconds per year) below to compare in m/yr.
U_tmp = lambdify_and_vectorize_if_needed((x, z), u)(X, Z)
pcm_U_err = ax_U_err.pcolormesh(X/1e3, Z, mask * (u_mol_grid - U_tmp*scipy.constants.year), cmap='coolwarm',
                                vmin=-1*err_clb_pct_of_max*np.nanmax(U_tmp*scipy.constants.year),
                                vmax=err_clb_pct_of_max*np.nanmax(U_tmp*scipy.constants.year))
fig_hv_err.colorbar(pcm_U_err, ax=ax_U_err, label='Error in horizontal velocity [m/yr]')

# RMS over unmasked grid points (masked points are NaN-skipped by nanmean, assuming mask uses NaN)
rms_err_horizontal = np.sqrt(np.nanmean((mask * (u_mol_grid - U_tmp*scipy.constants.year))**2))
print(f"RMS error in horizontal velocity: {rms_err_horizontal:.2e} m/yr")

ax_U_err.set_title(f'Error in horizontal velocity\n({rms_err_horizontal:.2e} m/yr RMS)')



# Draw layer lines on all plots for reference
for layer in layers_t0:
    for ax in axs_hv_err.flatten():
        ax.plot(xs_layers/1e3, layer(xs_layers), 'k--', linewidth=0.5)

fig_hv_err.tight_layout()
fig_hv_err.savefig(f"{output_results_base}_horizontal_velocity_error.png")
plt.show()

### Explore a single layer solution

In [None]:
layer_idx = 13

fig, axs = plt.subplots(4, 1, figsize=(10, 12), sharex=True)
ax_U, ax_W, ax_deriv, ax_stab = axs

sol = layer_solutions[layer_idx]

# Quantities along the chosen layer at the ODE solution points.
# The state is read as row 0 of sol.y (as in the w reconstruction), with x positions in sol.t.
xs_tmp = sol.t
u_sol = sol.y[0, :]
z_layer = layers_t0[layer_idx](xs_tmp)
dl_dt_sol = layer_dl_dt(xs_tmp, layer_idx)
u_dl_dx_sol = u_sol * layer_dl_dx(xs_tmp, layer_idx)

# Horizontal Velocity
ax_U.set_title(f'Horizontal Velocity at Layer {layer_idx}')
ax_U.scatter(xs_tmp/1e3, u_sol, s=2, label=f'ODE solution for layer {layer_idx}', c='red')
# truth: analytic u on the layer, converted with scipy.constants.year (seconds per year)
ax_U.plot(xs_tmp/1e3, lambdify_and_vectorize_if_needed((x,z), u)(xs_tmp, z_layer) * scipy.constants.year, 'k--', alpha=0.5, label='True')

# Vertical Velocity, reconstructed as w = dl/dt + u * dl/dx, plus its two contributions
ax_W.set_title(f'Vertical Velocity at Layer {layer_idx}')
ax_W.scatter(xs_tmp/1e3, dl_dt_sol + u_dl_dx_sol, s=2, label='w(x)')
ax_W.plot(xs_tmp/1e3, lambdify_and_vectorize_if_needed((x,z), w)(xs_tmp, z_layer) * scipy.constants.year, 'k--', alpha=0.5, label='True')

ax_W.scatter(xs_tmp/1e3, dl_dt_sol, s=2, label='dl/dt')
ax_W.scatter(xs_tmp/1e3, u_dl_dx_sol, s=2, label='u * dl/dx')

ax_W.set_ylim(-4,4)

# Plot the derivative (right-hand side of the layer ODE)
ax_deriv.set_title(f'du/dtau at Layer {layer_idx}')
ax_deriv.scatter(xs_tmp/1e3, du_dtau(xs_tmp, u_sol, layer_idx), s=2, label="du_dtau(x) (ODE Function)")

# Stability criterion. Labelled so the shared legend() call below has an entry to show
# (an unlabelled axes makes matplotlib warn "No artists with labels found").
ax_stab.set_title(f'd2l_dxdz at Layer {layer_idx} (stability)')
stab_crit = layer_d2l_dxdz(xs_tmp, layer_idx)
ax_stab.scatter(xs_tmp/1e3, stab_crit, s=2, label='d2l/dxdz')

# Properties for all plots
for ax in axs:
    ax.grid(True)
    ax.legend()
    ax.set_xlim(0, domain_x/1e3)

plt.show()

In [None]:
# Wall-clock time since the notebook's start marker, reported in minutes
_elapsed_minutes = (time.time() - t_start_notebook) / 60
print(f"Total elapsed time: {_elapsed_minutes:.2f} minutes")