# Experiment 9.2: Sine-Gordon Ablations

In [None]:
import jax
import jax.numpy as jnp

jax.config.update("jax_default_matmul_precision", "highest")

from src.equations import sg_res as pde_res

from src.utils import _get_adam, _get_pde_collocs, _get_ic_collocs, model_eval, count_params, _get_colloc_indices, grad_norm
from src.wrappers import SGModel

import numpy as np
from jax import device_get

import optax
from flax import nnx

import time
import pickle

import os

# Output locations (created on first run if they do not exist yet)
results_dir = "results"
os.makedirs(results_dir, exist_ok=True)

# All ablation results are pickled into this single file
result_file = os.path.join(results_dir, "ablations_sg.pkl")

plots_dir = "plots"
os.makedirs(plots_dir, exist_ok=True)

# Nested store: RESULTS[ablation_name][run_idx] -> {'time': ..., 'l2': ...}
RESULTS = {}

# Base seed: used for collocation sampling; model init uses seed + per-run offset
seed = 42

## Data & Parameters

In [None]:
# Collocation pool over the unit square and initial-condition points
collocs_pool = _get_pde_collocs(ranges=[(0, 1), (0, 1)], sample_size=400)
ic_collocs = _get_ic_collocs(x_range=(0, 1), sample_size=2**6)
# IC target u = sin(pi * x), with column 1 of ic_collocs taken as x; shape (n_ic, 1)
ic_data = jnp.sin(jnp.pi * ic_collocs[:, 1])[:, None]

# Reference solution on a (t, x) grid
ref = np.load('data/sg.npz')
refsol = jnp.array(ref['usol'])
N_t, N_x = refsol.shape

t = ref['t'].flatten()
x = ref['x'].flatten()
# 'ij' indexing: T, X have shape (N_t, N_x), matching usol's row/column layout
T, X = jnp.meshgrid(t, x, indexing='ij')
# Evaluation points as an (N_t * N_x, 2) array of (t, x) pairs
coords = jnp.stack([T.ravel(), X.ravel()], axis=1)

In [None]:
# ---------------- Training schedule ----------------
num_epochs = 100_000          # optimizer steps per run

# Learning-rate schedule arguments (consumed by _get_adam)
learning_rate = 1e-3
decay_steps = 2000
decay_rate = 0.9
warmup_steps = 1000

# ---------------- Causal weighting ----------------
causal_tol = 1.0
num_chunks = 32
# Strictly lower-triangular ones: row i of (M @ chunk_losses) sums the losses of chunks j < i
M = jnp.tril(jnp.ones((num_chunks, num_chunks)), k=-1)

# ---------------- Grad norm ----------------
grad_mixing = 0.9             # mixing factor passed to grad_norm
f_grad_norm = 1000            # epochs between global-weight updates

# ---------------- RAD resampling ----------------
batch_size = 2**12            # collocation points per batch
f_resample = 2000             # epochs between resamples
rad_a = 1.0                   # exponent applied to |residual|
rad_c = 1.0                   # constant added to the normalised density

# ---------------- RBA ----------------
RBA_gamma = 0.999             # multiplier on the previous RBA weights
RBA_eta = 0.01                # scale of the residual-driven increment

# ---------------- Model ----------------
n_in = collocs_pool.shape[1]
n_out = 1
D = 5
period_axes = None
sine_D = 5
alpha = 0.0
beta = 0.0
init_scheme = {'type': 'glorot', 'gain': None, 'norm_pow': 0, 'distribution': 'uniform', 'sample_size': 10000}
n_hidden = 16
num_blocks = 6
num_blocks = 6

## Ablation 1: RBA Only (no causal weighting, no RAD, no grad norm)

In [None]:
# PDE Loss
def pde_loss(model, l_E, collocs):
    """RBA-weighted mean-squared PDE residual.

    The RBA weights are advanced first and then applied:
        l_E_new = RBA_gamma * l_E + RBA_eta * |r| / max|r|
        loss    = mean((l_E_new * r) ** 2)

    Returns:
        (loss, l_E_new). l_E_new has the same shape as the residuals.
    """
    # NOTE(review): l_E_new is built from the residuals without stop_gradient, so the
    # gradient of this loss also flows through the weight update. Confirm this is intended.
    res = pde_res(model, collocs)  # shape (batch_size, 1)
    mag = jnp.abs(res)
    l_E_next = RBA_gamma * l_E + RBA_eta * mag / jnp.max(mag)
    return jnp.mean((l_E_next * res) ** 2), l_E_next


# IC Loss
def ic_loss(model, l_I, ic_collocs, ic_data):
    """RBA-weighted initial-condition loss.

    Uses the same weight update as the PDE term:
        l_I_new = RBA_gamma * l_I + RBA_eta * |e| / max|e|, with e = model(ic_collocs) - ic_data

    Returns:
        (mean((l_I_new * e) ** 2), l_I_new)
    """
    err = model(ic_collocs) - ic_data
    mag = jnp.abs(err)
    l_I_next = RBA_gamma * l_I + RBA_eta * mag / jnp.max(mag)
    return jnp.mean((l_I_next * err) ** 2), l_I_next


@nnx.jit
def train_step(model, optimizer, collocs, ic_collocs, ic_data, λ_E, λ_I, l_E, l_I):
    """One optimizer step on λ_E * L_PDE + λ_I * L_IC, with RBA weights carried through.

    Returns:
        (total weighted loss, PDE grads, IC grads, updated l_E, updated l_I)
    """
    pde_grad_fn = nnx.value_and_grad(pde_loss, has_aux=True)
    ic_grad_fn = nnx.value_and_grad(ic_loss, has_aux=True)

    (loss_E, l_E_new), grads_E = pde_grad_fn(model, l_E, collocs)
    (loss_I, l_I_new), grads_I = ic_grad_fn(model, l_I, ic_collocs, ic_data)

    # Gradients are combined with the same global weights used for the reported loss
    combined = jax.tree_util.tree_map(lambda g_pde, g_ic: λ_E * g_pde + λ_I * g_ic, grads_E, grads_I)
    optimizer.update(combined)

    total = λ_E * loss_E + λ_I * loss_I
    return total, grads_E, grads_I, l_E_new, l_I_new

In [None]:
RESULTS["Ablation 1"] = {}

# Three independent runs; `run` offsets the model-initialisation seed.
for idx, run in enumerate([0, 7, 42]):

    RESULTS["Ablation 1"][idx] = {}

    # Initial collocation batch drawn from the pool
    # (px=None: presumably uniform sampling; confirm in src.utils._get_colloc_indices)
    sorted_indices = _get_colloc_indices(
        collocs_pool=collocs_pool, batch_size=batch_size, px=None, seed=seed
    )
    collocs = collocs_pool[sorted_indices]

    # RBA weights start at one: a per-pool-point array for the PDE term, sliced to the batch,
    # and one weight per IC point
    l_E_pool = jnp.ones((collocs_pool.shape[0], 1))
    l_E = l_E_pool[sorted_indices]
    l_I = jnp.ones((ic_collocs.shape[0], 1))

    # Fresh model for this run
    model = SGModel(
        n_in=n_in, n_out=n_out, n_hidden=n_hidden, num_blocks=num_blocks, D=D,
        init_scheme=init_scheme, alpha=alpha, beta=beta, ref=ref,
        period_axes=period_axes, rff_std=None, sine_D=sine_D, seed=seed + run,
    )
    if idx == 0:
        print(f"Initialized model with {count_params(model)} parameters.")

    # Optimizer (schedule built by src.utils._get_adam)
    opt_type = _get_adam(
        learning_rate=learning_rate, decay_steps=decay_steps,
        decay_rate=decay_rate, warmup_steps=warmup_steps,
    )
    optimizer = nnx.Optimizer(model, opt_type)

    # Global loss weights stay at 1 for the whole run (no grad norm in this ablation)
    λ_E = jnp.array(1.0, dtype=float)
    λ_I = jnp.array(1.0, dtype=float)

    tick = time.time()
    for epoch in range(num_epochs):
        loss, grads_E, grads_I, l_E, l_I = train_step(
            model, optimizer, collocs, ic_collocs, ic_data, λ_E, λ_I, l_E, l_I
        )
    tack = time.time()

    l2error = model_eval(model, coords, refsol)
    print(f"\tRun = {idx}\t L^2 = {l2error:.2e}\t Loss = {loss:.2e}\t Time = {(tack-tick)/60:.2f} mins")

    RESULTS["Ablation 1"][idx].update(time=tack - tick, l2=l2error.item())

In [None]:
# Checkpoint 1: persist every result collected so far
with open(result_file, mode="wb") as f:
    pickle.dump(RESULTS, f, protocol=pickle.DEFAULT_PROTOCOL)

## Ablation 2: No RBA

In [None]:
# PDE Loss
def pde_loss(model, collocs):
    """Causally weighted mean-squared PDE residual.

    The residuals are reshaped row-major into `num_chunks` equal chunks. Chunk i is
    weighted by exp(-causal_tol * sum of losses of chunks j < i), with the weights
    treated as constants (stop_gradient).
    """
    # assumes collocs are ordered in time so each chunk is a time slab (TODO confirm)
    chunk_losses = jnp.mean(pde_res(model, collocs).reshape(num_chunks, -1) ** 2, axis=1)

    # M is strictly lower-triangular, so M @ chunk_losses accumulates only earlier chunks
    causal_w = jax.lax.stop_gradient(jnp.exp(-causal_tol * (M @ chunk_losses)))

    return jnp.mean(causal_w * chunk_losses)


# IC Loss
def ic_loss(model, ic_collocs, ic_data):
    """Plain mean-squared error between model predictions and the IC targets."""
    mismatch = model(ic_collocs) - ic_data
    return jnp.mean(mismatch ** 2)


@nnx.jit
def train_step(model, optimizer, collocs, ic_collocs, ic_data, λ_E, λ_I):
    """One optimizer step on λ_E * L_PDE + λ_I * L_IC.

    Returns:
        (total weighted loss, PDE grads, IC grads). The per-term gradients are returned
        separately so the caller can pass them to `grad_norm`.
    """
    loss_E, grads_E = nnx.value_and_grad(pde_loss)(model, collocs)
    loss_I, grads_I = nnx.value_and_grad(ic_loss)(model, ic_collocs, ic_data)

    # Combine gradients with the same weights as the reported loss, then step
    optimizer.update(
        jax.tree_util.tree_map(lambda g_pde, g_ic: λ_E * g_pde + λ_I * g_ic, grads_E, grads_I)
    )

    return λ_E * loss_E + λ_I * loss_I, grads_E, grads_I


@nnx.jit
def get_RAD_indices(model, collocs_pool, old_indices):
    """Draw a new collocation batch with residual-based adaptive distribution (RAD).

    The sampling density over the whole pool is
        p ∝ |r|**rad_a / mean(|r|**rad_a) + rad_c
    normalised to sum to one.

    Args:
        model: network being trained.
        collocs_pool: full pool of candidate collocation points.
        old_indices: unused; kept for interface compatibility.

    Returns:
        Indices into `collocs_pool`, as produced by `_get_colloc_indices`.
    """
    # NOTE(review): `seed` is the same module-level value on every call, so each resample
    # presumably reuses one PRNG key. Confirm this is intended.
    magnitude = jnp.power(jnp.abs(pde_res(model, collocs_pool)), rad_a)
    density = magnitude / jnp.mean(magnitude) + rad_c
    probs = (density / jnp.sum(density))[:, 0]
    return _get_colloc_indices(collocs_pool=collocs_pool, batch_size=batch_size, px=probs, seed=seed)

In [None]:
RESULTS["Ablation 2"] = {}

# Ablation 2: causal weighting + grad norm + RAD resampling, without RBA.
for idx, run in enumerate([0, 7, 42]):

    RESULTS["Ablation 2"][idx] = {}

    # Initial collocation batch (no RBA weights in this ablation).
    # NOTE(review): every run uses the same `seed`, so presumably all runs start from the
    # same batch; only the model initialisation (seed + run) differs.
    sorted_indices = _get_colloc_indices(collocs_pool=collocs_pool, batch_size=batch_size, px=None, seed=seed)
    collocs = collocs_pool[sorted_indices]

    # Fresh model for this run
    model = SGModel(
        n_in=n_in, n_out=n_out, n_hidden=n_hidden, num_blocks=num_blocks, D=D,
        init_scheme=init_scheme, alpha=alpha, beta=beta, ref=ref,
        period_axes=period_axes, rff_std=None, sine_D=sine_D, seed=seed + run,
    )
    if idx == 0:
        print(f"Initialized model with {count_params(model)} parameters.")

    opt_type = _get_adam(
        learning_rate=learning_rate, decay_steps=decay_steps,
        decay_rate=decay_rate, warmup_steps=warmup_steps,
    )
    optimizer = nnx.Optimizer(model, opt_type)

    # Global loss weights, rebalanced periodically by grad_norm
    λ_E = jnp.array(1.0, dtype=float)
    λ_I = jnp.array(1.0, dtype=float)

    tick = time.time()
    for epoch in range(num_epochs):

        loss, grads_E, grads_I = train_step(model, optimizer, collocs, ic_collocs, ic_data, λ_E, λ_I)

        # Grad norm: update BOTH global weights from the latest per-term gradients.
        # The target must be the Latin-E `λ_E` passed to train_step. A look-alike Greek
        # `λ_Ε` (U+0395) is a different identifier and would leave λ_E frozen at 1.0.
        if epoch != 0 and epoch % f_grad_norm == 0:
            λ_E, λ_I = grad_norm(grads_E, grads_I, λ_E, λ_I, grad_mixing)

        # RAD: resample the collocation batch from the current residual distribution
        if epoch != 0 and epoch % f_resample == 0:
            sorted_indices = get_RAD_indices(model, collocs_pool, sorted_indices)
            collocs = collocs_pool[sorted_indices]

    tack = time.time()

    l2error = model_eval(model, coords, refsol)
    print(f"\tRun = {idx}\t L^2 = {l2error:.2e}\t Loss = {loss:.2e}\t Time = {(tack-tick)/60:.2f} mins")

    RESULTS["Ablation 2"][idx]['time'] = tack - tick
    RESULTS["Ablation 2"][idx]['l2'] = l2error.item()

In [None]:
# Checkpoint 2: persist every result collected so far
with open(result_file, mode="wb") as f:
    pickle.dump(RESULTS, f, protocol=pickle.DEFAULT_PROTOCOL)

## Ablation 3: No RBA & No RAD

In [None]:
# PDE Loss
def pde_loss(model, collocs):
    """Causally weighted mean-squared PDE residual.

    The residuals are reshaped row-major into `num_chunks` equal chunks. Chunk i is
    weighted by exp(-causal_tol * sum of losses of chunks j < i), with the weights
    treated as constants (stop_gradient).
    """
    # assumes collocs are ordered in time so each chunk is a time slab (TODO confirm)
    chunk_losses = jnp.mean(pde_res(model, collocs).reshape(num_chunks, -1) ** 2, axis=1)

    # M is strictly lower-triangular, so M @ chunk_losses accumulates only earlier chunks
    causal_w = jax.lax.stop_gradient(jnp.exp(-causal_tol * (M @ chunk_losses)))

    return jnp.mean(causal_w * chunk_losses)


# IC Loss
def ic_loss(model, ic_collocs, ic_data):
    """Plain mean-squared error between model predictions and the IC targets."""
    mismatch = model(ic_collocs) - ic_data
    return jnp.mean(mismatch ** 2)


@nnx.jit
def train_step(model, optimizer, collocs, ic_collocs, ic_data, λ_E, λ_I):
    """One optimizer step on λ_E * L_PDE + λ_I * L_IC.

    Returns:
        (total weighted loss, PDE grads, IC grads). The per-term gradients are returned
        separately so the caller can pass them to `grad_norm`.
    """
    loss_E, grads_E = nnx.value_and_grad(pde_loss)(model, collocs)
    loss_I, grads_I = nnx.value_and_grad(ic_loss)(model, ic_collocs, ic_data)

    # Combine gradients with the same weights as the reported loss, then step
    optimizer.update(
        jax.tree_util.tree_map(lambda g_pde, g_ic: λ_E * g_pde + λ_I * g_ic, grads_E, grads_I)
    )

    return λ_E * loss_E + λ_I * loss_I, grads_E, grads_I

In [None]:
RESULTS["Ablation 3"] = {}

# Ablation 3: causal weighting + grad norm on a fixed collocation batch (no RBA, no RAD).
for idx, run in enumerate([0, 7, 42]):

    RESULTS["Ablation 3"][idx] = {}

    # Collocation batch, kept fixed for the whole run (no resampling in this ablation).
    # NOTE(review): every run uses the same `seed`, so presumably all runs share this batch;
    # only the model initialisation (seed + run) differs.
    sorted_indices = _get_colloc_indices(collocs_pool=collocs_pool, batch_size=batch_size, px=None, seed=seed)
    collocs = collocs_pool[sorted_indices]

    # Fresh model for this run
    model = SGModel(
        n_in=n_in, n_out=n_out, n_hidden=n_hidden, num_blocks=num_blocks, D=D,
        init_scheme=init_scheme, alpha=alpha, beta=beta, ref=ref,
        period_axes=period_axes, rff_std=None, sine_D=sine_D, seed=seed + run,
    )
    if idx == 0:
        print(f"Initialized model with {count_params(model)} parameters.")

    opt_type = _get_adam(
        learning_rate=learning_rate, decay_steps=decay_steps,
        decay_rate=decay_rate, warmup_steps=warmup_steps,
    )
    optimizer = nnx.Optimizer(model, opt_type)

    # Global loss weights, rebalanced periodically by grad_norm
    λ_E = jnp.array(1.0, dtype=float)
    λ_I = jnp.array(1.0, dtype=float)

    tick = time.time()
    for epoch in range(num_epochs):

        loss, grads_E, grads_I = train_step(model, optimizer, collocs, ic_collocs, ic_data, λ_E, λ_I)

        # Grad norm: update BOTH global weights from the latest per-term gradients.
        # The target must be the Latin-E `λ_E` passed to train_step. A look-alike Greek
        # `λ_Ε` (U+0395) is a different identifier and would leave λ_E frozen at 1.0.
        if epoch != 0 and epoch % f_grad_norm == 0:
            λ_E, λ_I = grad_norm(grads_E, grads_I, λ_E, λ_I, grad_mixing)

    tack = time.time()

    l2error = model_eval(model, coords, refsol)
    print(f"\tRun = {idx}\t L^2 = {l2error:.2e}\t Loss = {loss:.2e}\t Time = {(tack-tick)/60:.2f} mins")

    RESULTS["Ablation 3"][idx]['time'] = tack - tick
    RESULTS["Ablation 3"][idx]['l2'] = l2error.item()

In [None]:
# Checkpoint 3: persist every result collected so far
with open(result_file, mode="wb") as f:
    pickle.dump(RESULTS, f, protocol=pickle.DEFAULT_PROTOCOL)

## Ablation 4: No RBA & No Causal

In [None]:
# PDE Loss
def pde_loss(model, collocs):
    """Unweighted (non-causal) mean-squared PDE residual over the whole batch."""
    return jnp.mean(pde_res(model, collocs) ** 2)


# IC Loss
def ic_loss(model, ic_collocs, ic_data):
    """Plain mean-squared error between model predictions and the IC targets."""
    mismatch = model(ic_collocs) - ic_data
    return jnp.mean(mismatch ** 2)


@nnx.jit
def train_step(model, optimizer, collocs, ic_collocs, ic_data, λ_E, λ_I):
    """One optimizer step on λ_E * L_PDE + λ_I * L_IC.

    Returns:
        (total weighted loss, PDE grads, IC grads). The per-term gradients are returned
        separately so the caller can pass them to `grad_norm`.
    """
    loss_E, grads_E = nnx.value_and_grad(pde_loss)(model, collocs)
    loss_I, grads_I = nnx.value_and_grad(ic_loss)(model, ic_collocs, ic_data)

    # Combine gradients with the same weights as the reported loss, then step
    optimizer.update(
        jax.tree_util.tree_map(lambda g_pde, g_ic: λ_E * g_pde + λ_I * g_ic, grads_E, grads_I)
    )

    return λ_E * loss_E + λ_I * loss_I, grads_E, grads_I


@nnx.jit
def get_RAD_indices(model, collocs_pool, old_indices):
    """Draw a new collocation batch with residual-based adaptive distribution (RAD).

    The sampling density over the whole pool is
        p ∝ |r|**rad_a / mean(|r|**rad_a) + rad_c
    normalised to sum to one.

    Args:
        model: network being trained.
        collocs_pool: full pool of candidate collocation points.
        old_indices: unused; kept for interface compatibility.

    Returns:
        Indices into `collocs_pool`, as produced by `_get_colloc_indices`.
    """
    # NOTE(review): `seed` is the same module-level value on every call, so each resample
    # presumably reuses one PRNG key. Confirm this is intended.
    magnitude = jnp.power(jnp.abs(pde_res(model, collocs_pool)), rad_a)
    density = magnitude / jnp.mean(magnitude) + rad_c
    probs = (density / jnp.sum(density))[:, 0]
    return _get_colloc_indices(collocs_pool=collocs_pool, batch_size=batch_size, px=probs, seed=seed)

In [None]:
RESULTS["Ablation 4"] = {}

# Ablation 4: grad norm + RAD resampling with a plain (non-causal) PDE loss, without RBA.
for idx, run in enumerate([0, 7, 42]):

    RESULTS["Ablation 4"][idx] = {}

    # Initial collocation batch (no RBA weights in this ablation).
    # NOTE(review): every run uses the same `seed`, so presumably all runs start from the
    # same batch; only the model initialisation (seed + run) differs.
    sorted_indices = _get_colloc_indices(collocs_pool=collocs_pool, batch_size=batch_size, px=None, seed=seed)
    collocs = collocs_pool[sorted_indices]

    # Fresh model for this run
    model = SGModel(
        n_in=n_in, n_out=n_out, n_hidden=n_hidden, num_blocks=num_blocks, D=D,
        init_scheme=init_scheme, alpha=alpha, beta=beta, ref=ref,
        period_axes=period_axes, rff_std=None, sine_D=sine_D, seed=seed + run,
    )
    if idx == 0:
        print(f"Initialized model with {count_params(model)} parameters.")

    opt_type = _get_adam(
        learning_rate=learning_rate, decay_steps=decay_steps,
        decay_rate=decay_rate, warmup_steps=warmup_steps,
    )
    optimizer = nnx.Optimizer(model, opt_type)

    # Global loss weights, rebalanced periodically by grad_norm
    λ_E = jnp.array(1.0, dtype=float)
    λ_I = jnp.array(1.0, dtype=float)

    tick = time.time()
    for epoch in range(num_epochs):

        loss, grads_E, grads_I = train_step(model, optimizer, collocs, ic_collocs, ic_data, λ_E, λ_I)

        # Grad norm: update BOTH global weights from the latest per-term gradients.
        # The target must be the Latin-E `λ_E` passed to train_step. A look-alike Greek
        # `λ_Ε` (U+0395) is a different identifier and would leave λ_E frozen at 1.0.
        if epoch != 0 and epoch % f_grad_norm == 0:
            λ_E, λ_I = grad_norm(grads_E, grads_I, λ_E, λ_I, grad_mixing)

        # RAD: resample the collocation batch from the current residual distribution
        if epoch != 0 and epoch % f_resample == 0:
            sorted_indices = get_RAD_indices(model, collocs_pool, sorted_indices)
            collocs = collocs_pool[sorted_indices]

    tack = time.time()

    l2error = model_eval(model, coords, refsol)
    print(f"\tRun = {idx}\t L^2 = {l2error:.2e}\t Loss = {loss:.2e}\t Time = {(tack-tick)/60:.2f} mins")

    RESULTS["Ablation 4"][idx]['time'] = tack - tick
    RESULTS["Ablation 4"][idx]['l2'] = l2error.item()

In [None]:
# Checkpoint 4: persist every result collected so far
with open(result_file, mode="wb") as f:
    pickle.dump(RESULTS, f, protocol=pickle.DEFAULT_PROTOCOL)

## Ablation 5: No RBA & No Grad Norm

In [None]:
# PDE Loss
def pde_loss(model, collocs):
    """Causally weighted mean-squared PDE residual.

    The residuals are reshaped row-major into `num_chunks` equal chunks. Chunk i is
    weighted by exp(-causal_tol * sum of losses of chunks j < i), with the weights
    treated as constants (stop_gradient).
    """
    # assumes collocs are ordered in time so each chunk is a time slab (TODO confirm)
    chunk_losses = jnp.mean(pde_res(model, collocs).reshape(num_chunks, -1) ** 2, axis=1)

    # M is strictly lower-triangular, so M @ chunk_losses accumulates only earlier chunks
    causal_w = jax.lax.stop_gradient(jnp.exp(-causal_tol * (M @ chunk_losses)))

    return jnp.mean(causal_w * chunk_losses)


# IC Loss
def ic_loss(model, ic_collocs, ic_data):
    """Plain mean-squared error between model predictions and the IC targets."""
    mismatch = model(ic_collocs) - ic_data
    return jnp.mean(mismatch ** 2)


@nnx.jit
def train_step(model, optimizer, collocs, ic_collocs, ic_data, λ_E, λ_I):
    """One optimizer step on λ_E * L_PDE + λ_I * L_IC.

    Returns:
        (total weighted loss, PDE grads, IC grads). The per-term gradients are returned
        separately; this ablation's loop does not use them.
    """
    loss_E, grads_E = nnx.value_and_grad(pde_loss)(model, collocs)
    loss_I, grads_I = nnx.value_and_grad(ic_loss)(model, ic_collocs, ic_data)

    # Combine gradients with the same weights as the reported loss, then step
    optimizer.update(
        jax.tree_util.tree_map(lambda g_pde, g_ic: λ_E * g_pde + λ_I * g_ic, grads_E, grads_I)
    )

    return λ_E * loss_E + λ_I * loss_I, grads_E, grads_I


@nnx.jit
def get_RAD_indices(model, collocs_pool, old_indices):
    """Draw a new collocation batch with residual-based adaptive distribution (RAD).

    The sampling density over the whole pool is
        p ∝ |r|**rad_a / mean(|r|**rad_a) + rad_c
    normalised to sum to one.

    Args:
        model: network being trained.
        collocs_pool: full pool of candidate collocation points.
        old_indices: unused; kept for interface compatibility.

    Returns:
        Indices into `collocs_pool`, as produced by `_get_colloc_indices`.
    """
    # NOTE(review): `seed` is the same module-level value on every call, so each resample
    # presumably reuses one PRNG key. Confirm this is intended.
    magnitude = jnp.power(jnp.abs(pde_res(model, collocs_pool)), rad_a)
    density = magnitude / jnp.mean(magnitude) + rad_c
    probs = (density / jnp.sum(density))[:, 0]
    return _get_colloc_indices(collocs_pool=collocs_pool, batch_size=batch_size, px=probs, seed=seed)

In [None]:
RESULTS["Ablation 5"] = {}

# Ablation 5: causal weighting + RAD resampling with fixed global weights (no RBA, no grad norm).
for idx, run in enumerate([0, 7, 42]):

    RESULTS["Ablation 5"][idx] = {}

    # Initial collocation batch
    # (px=None: presumably uniform sampling; confirm in src.utils._get_colloc_indices)
    sorted_indices = _get_colloc_indices(
        collocs_pool=collocs_pool, batch_size=batch_size, px=None, seed=seed
    )
    collocs = collocs_pool[sorted_indices]

    # Fresh model for this run
    model = SGModel(
        n_in=n_in, n_out=n_out, n_hidden=n_hidden, num_blocks=num_blocks, D=D,
        init_scheme=init_scheme, alpha=alpha, beta=beta, ref=ref,
        period_axes=period_axes, rff_std=None, sine_D=sine_D, seed=seed + run,
    )
    if idx == 0:
        print(f"Initialized model with {count_params(model)} parameters.")

    # Optimizer (schedule built by src.utils._get_adam)
    opt_type = _get_adam(
        learning_rate=learning_rate, decay_steps=decay_steps,
        decay_rate=decay_rate, warmup_steps=warmup_steps,
    )
    optimizer = nnx.Optimizer(model, opt_type)

    # Global loss weights stay at 1 for the whole run (no grad norm in this ablation)
    λ_E = jnp.array(1.0, dtype=float)
    λ_I = jnp.array(1.0, dtype=float)

    tick = time.time()
    for epoch in range(num_epochs):
        loss, grads_E, grads_I = train_step(
            model, optimizer, collocs, ic_collocs, ic_data, λ_E, λ_I
        )

        # RAD: every f_resample epochs (skipping epoch 0), redraw the batch from the residuals
        if epoch == 0 or epoch % f_resample != 0:
            continue
        sorted_indices = get_RAD_indices(model, collocs_pool, sorted_indices)
        collocs = collocs_pool[sorted_indices]
    tack = time.time()

    l2error = model_eval(model, coords, refsol)
    print(f"\tRun = {idx}\t L^2 = {l2error:.2e}\t Loss = {loss:.2e}\t Time = {(tack-tick)/60:.2f} mins")

    RESULTS["Ablation 5"][idx].update(time=tack - tick, l2=l2error.item())

In [None]:
# Checkpoint 5: persist every result collected so far
with open(result_file, mode="wb") as f:
    pickle.dump(RESULTS, f, protocol=pickle.DEFAULT_PROTOCOL)

## Analysis

In [None]:
# Reload the ablation results written by the checkpoints above.
# (pickle is only safe here because these files are produced locally by this project.)
with open(result_file, mode="rb") as f:
    RESULTS = pickle.load(f)

# Benchmark results for the full method, keyed by (alpha, beta); keep this configuration's entry
with open(os.path.join(results_dir, "benchmarks_sg.pkl"), mode="rb") as f:
    rbn = pickle.load(f)
rbn = rbn[(alpha, beta)]

In [None]:
def summarize_results(results, rbn):
    """Print the mean L^2 error and its standard error for the benchmark and each ablation.

    Args:
        results: mapping ``ablation name -> {run idx -> {'l2': float, ...}}``.
        rbn: mapping ``run key -> {'l2': float, ...}`` for the original (full) method.

    The standard error of the mean uses the sample standard deviation (ddof=1).
    With fewer than two runs it is undefined and is printed as ``nan``.
    """

    def _mean_sem(values):
        # Mean and standard error of the mean; SEM is nan when fewer than two values
        # (avoids NumPy's "degrees of freedom <= 0" warning).
        arr = np.asarray(values, dtype=float)
        mean = np.mean(arr)
        if arr.size < 2:
            return mean, float("nan")
        return mean, np.std(arr, ddof=1) / np.sqrt(arr.size)

    mean, sem = _mean_sem([entry['l2'] for entry in rbn.values()])
    print(f"Original:\t L^2 = {mean:.3e}  Error = {sem:.3e}")

    for ablation, runs in results.items():
        mean, sem = _mean_sem([entry['l2'] for entry in runs.values()])
        print(f"{ablation}:\t L^2 = {mean:.3e}  Error = {sem:.3e}")

In [None]:
summarize_results(results=RESULTS, rbn=rbn)