In [None]:
import time

# Device configuration and core PyTorch setup
import torch
import torch.nn as nn
import torch.nn.functional as F
# Working precision for every tensor created in this notebook.
DTYPE = torch.float32
# Use CUDA when a device is visible, otherwise fall back to CPU. Hard-coding
# 'cuda' made the check below always true and crashed on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == "cuda":
    # Warm-up: create the CUDA context and run one tiny kernel up front
    # (presumably so start-up cost is not counted in the timed training).
    torch.cuda.init()
    torch.rand(1, device=device)
print(f"Using device: {device}")

from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import matplotlib.pyplot as plt
from SA_PINN_ACTO import PINN_BDNK_1D
from IC_1D import IC_BDNK

# Network size passed to PINN_BDNK_1D (presumably: number of layers and
# neurons per layer -- confirm against the model class).
Nl = 10
Nn = 70

# Space-time domain: t in [0, t_end], x in [-L, L].
t_end = 20.0
L = 50.0

# Adam schedule: number of epochs, learning rate of the network weights and
# learning rate of the self-adaptive mask logits.
adam_epochs = 40000
lr_net = 0.002
lr_mask = 0.08

# Number of sample points: PDE collocation, initial condition, boundary.
N_colloc = 20000
N_ic = 1000
N_bc = 1000

def lhs_box(n, low, high, rng=np.random):
    """Draw ``n`` Latin-hypercube samples inside an axis-aligned box.

    Every coordinate axis is cut into ``n`` equal strata. Each stratum gets
    exactly one sample, placed uniformly at random inside it. The strata are
    shuffled independently per axis.

    Parameters
    ----------
    n : int
        Number of samples (rows of the result).
    low, high : array_like
        Lower / upper corners of the box. Scalars are treated as a 1-D box.
        Both must have the same number of elements.
    rng : numpy.random module, RandomState or Generator
        Source of randomness. Defaults to the global ``np.random`` state, so
        draws interleave with other ``np.random`` calls in the script.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n, D)`` and dtype float64, ``D = low.size``.

    Raises
    ------
    ValueError
        If ``low`` and ``high`` have different sizes.
    """
    low = np.atleast_1d(np.asarray(low, dtype=float)).ravel()
    high = np.atleast_1d(np.asarray(high, dtype=float)).ravel()
    if low.shape != high.shape:
        raise ValueError(
            f"low and high must have the same size, got {low.size} and {high.size}"
        )
    # The legacy API (np.random module / RandomState) exposes `rand`;
    # numpy.random.Generator exposes `random`. Both sample uniformly in [0, 1).
    uniform = getattr(rng, "rand", None)
    if uniform is None:
        uniform = rng.random
    D = low.size
    H = np.empty((n, D), dtype=float)
    for j in range(D):
        # Shuffled stratum index + uniform jitter inside the stratum, in [0, 1).
        P = (rng.permutation(n) + uniform(n)) / n
        H[:, j] = low[j] + P * (high[j] - low[j])
    return H

# Latin-hypercube collocation points (columns: t, x) in [0, t_end] x [-L, L].
# lhs_box defaults to the global np.random state, so these draws and the
# boundary times drawn later share one RNG stream (statement order matters).
# NOTE(review): no RNG seed is set in the code visible here, so sampling is
# not reproducible across runs -- confirm whether that is intended.
X_colloc_np = lhs_box(N_colloc, low=np.array([0.0, -L]), high=np.array([t_end, L])).astype(np.float32)

# Construction of initial condition sampling grid: the N_ic cell midpoints of
# a uniform partition of [-L, L], all at t = 0.
x_edges = np.linspace(-L, L, N_ic+1)
# Column vector of shape (N_ic, 1).
x_ic = (0.5 * (x_edges[:-1] + x_edges[1:])).reshape(-1, 1)
t_ic = np.zeros_like(x_ic)
# Shape (N_ic, 2), columns (t, x).
X_ic = np.hstack((t_ic, x_ic))

X_ic_t = torch.tensor(X_ic, dtype=DTYPE, device=device)
# Two IC profiles evaluated on the IC grid by the project helper IC_BDNK.
# alpha1st is interpolated and installed as model.alpha_ic_func (scaled);
# alpha2nd is the target passed to model.loss_ic during training.
alpha1st_ic_t, alpha2nd_ic_t = IC_BDNK(X_ic_t, L)

# Amplitude scale sA = max |alpha1st| over the IC grid, floored at 1e-12 so
# the division in the *_scaled IC functions can never be a division by zero.
with torch.no_grad():
    sA = alpha1st_ic_t.abs().max().clamp_min(1e-12).item()
print(f"[scales] sA={sA:.3e}")

# Sort the IC samples by x so they can serve as ascending knots for the 1-D
# linear interpolation (torch.searchsorted requires a sorted sequence).
# x_ic is already ascending here, so the permutation is the identity.
# NOTE(review): .view(-1) assumes IC_BDNK returns exactly one value per IC
# point -- confirm.
x_ic_torch = X_ic_t[:, 1:2].contiguous().view(-1)
x_sorted, idx_sort = torch.sort(x_ic_torch)
alpha1st_sorted = alpha1st_ic_t.view(-1)[idx_sort]
alpha2nd_sorted = alpha2nd_ic_t.view(-1)[idx_sort]

@torch.no_grad()
def _torch_lin_interp_1d(xq: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    xq_flat = xq.view(-1)
    xq_clamped = xq_flat.clamp(min=x[0], max=x[-1])
    idx_hi = torch.searchsorted(x, xq_clamped, right=True)
    idx_hi = idx_hi.clamp(min=1, max=x.numel() - 1)
    idx_lo = idx_hi - 1
    x0 = x[idx_lo]; x1 = x[idx_hi]
    y0 = y[idx_lo]; y1 = y[idx_hi]
    denom = (x1 - x0)
    denom = torch.where(denom.abs() > 0, denom, torch.ones_like(denom))
    w = (xq_clamped - x0) / denom
    yq = y0 + w * (y1 - y0)
    return yq.view_as(xq)

# Initial condition functions passed to the neural network (physical scale).
def _interp_sorted_ic(x_phys: torch.Tensor, y_sorted: torch.Tensor) -> torch.Tensor:
    """Linearly interpolate IC values ``y_sorted`` (knots ``x_sorted``) at ``x_phys``.

    The knot tables are cast to the device/dtype of ``x_phys`` when they
    differ. Any input shape is accepted. The result is a column of shape
    ``(N, 1)`` with ``N = x_phys.numel()``.
    """
    if x_sorted.device != x_phys.device or x_sorted.dtype != x_phys.dtype:
        xk = x_sorted.to(device=x_phys.device, dtype=x_phys.dtype)
        yk = y_sorted.to(device=x_phys.device, dtype=x_phys.dtype)
    else:
        xk, yk = x_sorted, y_sorted
    # reshape (not view) so non-contiguous inputs such as transposed slices work.
    yq = _torch_lin_interp_1d(x_phys.reshape(-1), xk, yk)
    return yq.reshape(-1, 1)


def alpha1st_ic_func(x_phys: torch.Tensor) -> torch.Tensor:
    """First IC profile from IC_BDNK, interpolated at ``x_phys``; shape (N, 1)."""
    return _interp_sorted_ic(x_phys, alpha1st_sorted)


def alpha2nd_ic_func(x_phys: torch.Tensor) -> torch.Tensor:
    """Second IC profile from IC_BDNK, interpolated at ``x_phys``; shape (N, 1)."""
    return _interp_sorted_ic(x_phys, alpha2nd_sorted)

# Scaled initial condition functions: the physical-scale IC divided by the
# amplitude scale sA (alpha1st_ic_func_scaled is installed on the model).
def alpha1st_ic_func_scaled(x_phys: torch.Tensor) -> torch.Tensor:
    """Return ``alpha1st_ic_func(x_phys)`` normalized by ``sA``."""
    unscaled = alpha1st_ic_func(x_phys)
    return unscaled.div(sA)


def alpha2nd_ic_func_scaled(x_phys: torch.Tensor) -> torch.Tensor:
    """Return ``alpha2nd_ic_func(x_phys)`` normalized by ``sA``."""
    unscaled = alpha2nd_ic_func(x_phys)
    return unscaled.div(sA)

# Boundary points at x = -L and x = +L. Both sides share the same N_bc random
# times, drawn from the global np.random state after the collocation LHS.
# NOTE(review): X_bc_L / X_bc_R do not enter any loss in the visible training
# code (they are only saved) -- confirm BCs are enforced elsewhere or unneeded.
xL = -L * np.ones((N_bc, 1))
xR =  L * np.ones((N_bc, 1))
t_bc = np.random.uniform(0.0, t_end, size=(N_bc, 1))
X_bc_L = np.hstack((t_bc, xL))
X_bc_R = np.hstack((t_bc, xR))

# Collocation tensor, augmented with 500 extra points on the t = 0 line.
# It therefore holds N_colloc + 500 rows, and one mask logit is created per row.
X_colloc = torch.tensor(X_colloc_np, dtype=DTYPE, device=device)
x0_line = torch.linspace(-L, L, 500, dtype=DTYPE, device=device).unsqueeze(1)
X0 = torch.cat([torch.zeros_like(x0_line), x0_line], dim=1)
X_colloc = torch.cat([X_colloc, X0], dim=0)

X_bc_L = torch.tensor(X_bc_L, dtype=DTYPE, device=device)
X_bc_R = torch.tensor(X_bc_R, dtype=DTYPE, device=device)

# Grid presumably meant for a mass-conservation diagnostic (names suggest so):
# x at the IC midpoints, Nt times in [0, t_end].
# NOTE(review): x_mass and t_mass are not referenced elsewhere in the visible code.
Nt = 100
x_mass = torch.tensor(x_ic.flatten(), dtype=DTYPE, device=device)
t_mass = torch.linspace(0, t_end, Nt, dtype=DTYPE, device=device)

# Self-adaptive collocation mask: one trainable logit per collocation point
# (initialized to 0), mapped through softplus to a positive residual weight.
pde_logits = torch.nn.Parameter(
    torch.zeros((X_colloc.shape[0], 1), dtype=DTYPE, device=device)
)


def current_masks(detach: bool = False):
    """Return the positive per-point PDE weights ``softplus(pde_logits)``.

    With ``detach=True`` the weights are cut from the autograd graph (e.g. for
    logging). Otherwise gradients flow back into ``pde_logits``.
    """
    weights = F.softplus(pde_logits)
    if detach:
        return weights.detach()
    return weights

# Model instantiation and domain normalization.
# lb / ub are the lower / upper corners of the (t, x) domain handed to the
# model (presumably used to normalize its inputs -- confirm in PINN_BDNK_1D).
lb = torch.tensor([0.0, -L], dtype=DTYPE, device=device)
ub = torch.tensor([t_end,  L], dtype=DTYPE, device=device)
model = PINN_BDNK_1D(Nl=Nl, Nn=Nn, lb=lb, ub=ub).to(device).to(DTYPE)
# Install the scaled, interpolated first IC profile as the model's IC
# function (the training log refers to this as a "hard IC").
model.alpha_ic_func = alpha1st_ic_func_scaled
# Overwrite the model's existing `sA` tensor in place with the amplitude scale
# (presumably a registered buffer, so it travels with the state_dict -- confirm).
model.sA.copy_(torch.tensor(sA, dtype=DTYPE, device=device))

# Weight initialization (applied via Module.apply, one submodule at a time).
def glorot_normal_all_linear(m):
    """Xavier-normal (gain 1) weights and zero biases for ``nn.Linear``; other modules untouched."""
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_normal_(m.weight, gain=1.0)
    if m.bias is not None:
        nn.init.zeros_(m.bias)
# Re-initialize every nn.Linear in the model (Xavier-normal weights, zero
# biases), overriding whatever initialization PINN_BDNK_1D applied.
model.apply(glorot_normal_all_linear)

# Optimizers and learning-rate scheduler setup.
# - optimizer_theta: network weights only (pde_logits is not a model parameter).
# - scheduler: stepped once per Adam epoch with the unmasked loss, so
#   `patience` is counted in epochs; the LR never drops below lr_net / 100.
# - optimizer_mask: gradient *ascent* (maximize=True) on the mask logits,
#   i.e. the self-adaptive min-max update.
optimizer_theta = torch.optim.Adam(model.parameters(), lr=lr_net, betas=(0.85, 0.92))
scheduler = ReduceLROnPlateau(optimizer_theta, mode='min', factor=0.6, patience=2500, threshold=1e-4, min_lr=lr_net/100)
optimizer_mask  = torch.optim.Adam([pde_logits], lr=lr_mask, betas=(0.7, 0.85), maximize=True)

# Setting up and executing Adam pre-training
def train_adam(model, optimizer_theta, optimizer_mask, epochs, print_every):
    """Run self-adaptive (SA-PINN) Adam pre-training and restore the best weights.

    Each epoch performs one simultaneous update:
      * network parameters (``optimizer_theta``) descend the masked loss
        ``mean((mask * |R|)^2) + L_ic``;
      * mask logits (``optimizer_mask``) take a step on the same loss. In this
        notebook that optimizer is created with ``maximize=True``, so the step
        is an ascent.
    The LR scheduler and best-model tracking use the *unmasked* loss
    ``mean(R^2) + L_ic``.

    Reads the module-level ``X_colloc``, ``X_ic_t``, ``alpha2nd_ic_t``,
    ``scheduler`` and ``current_masks``.

    Args:
        model: PINN exposing ``pde_residual(X)`` and ``loss_ic(X, target)``.
        optimizer_theta: optimizer over the network parameters.
        optimizer_mask: optimizer over the mask logits.
        epochs: number of Adam epochs.
        print_every: log every this many epochs (and on the last one).

    Returns:
        ``(best_loss, loss_history, best_state)``:
          * ``best_loss`` is the lowest unmasked loss seen;
          * ``loss_history`` is the per-epoch unmasked loss;
          * ``best_state`` is a CPU copy of the ``state_dict`` that achieved
            ``best_loss`` (``None`` if ``epochs < 1``).

    Raises:
        RuntimeError: if the masked loss becomes non-finite.
    """
    print("Starting Adam pre-training (SA-PINN with hard IC)...")
    best_loss, best_state = float('inf'), None
    loss_history = []

    for epoch in range(1, epochs + 1):
        optimizer_theta.zero_grad()
        optimizer_mask.zero_grad()

        # Masked (self-adaptive) PDE loss: the per-point weights multiply |R|.
        R = model.pde_residual(X_colloc)
        Rnorm = R.abs()
        pde_mask = current_masks(detach=False)
        L_pde = (pde_mask * Rnorm).pow(2).mean()

        L_ic = model.loss_ic(X_ic_t, alpha2nd_ic_t)
        loss = L_pde + L_ic
        if not torch.isfinite(loss):
            raise RuntimeError("Non-finite loss detected.")

        # Unmasked ("physical") loss: drives the scheduler and best-model tracking.
        L_pde_phys = (R**2).mean()
        L_total_phys = L_pde_phys + L_ic
        ltp = L_total_phys.detach().item()
        loss_history.append(ltp)

        # Snapshot the weights BEFORE stepping: ltp was measured with these
        # weights, so best_state must hold them. Saving the post-step weights
        # would restore a model whose loss is not best_loss.
        if ltp < best_loss:
            best_loss = ltp
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

        loss.backward()
        optimizer_theta.step()
        scheduler.step(ltp)
        optimizer_mask.step()

        if epoch % print_every == 0 or epoch == epochs:
            with torch.no_grad():
                m_pde = current_masks(detach=True)
            print(f"Adam Epoch {epoch}/{epochs} | "
                  f"Total={loss.item():.3e}, PDE={L_pde.item():.3e}, <pde_mask>={m_pde.mean().item():.2f} | "
                  f"Unmasked: Total={L_total_phys.item():.3e}, PDE={L_pde_phys.item():.3e}, IC={L_ic.item():.3e} | "
                  f"lr_net={optimizer_theta.param_groups[0]['lr']:.3e}")

    # Restore the best weights and re-evaluate their loss components for the report.
    if best_state is not None:
        model.load_state_dict(best_state)
        with torch.enable_grad():
            R = model.pde_residual(X_colloc)
            L_pde_phys = (R**2).mean()
            L_ic = model.loss_ic(X_ic_t, alpha2nd_ic_t)
        print(f"\nAdam finished. Best loss = {best_loss:.3e} | PDE={L_pde_phys.item():.4e}, IC={L_ic.item():.4e}")

    return best_loss, loss_history, best_state

# Run Adam pre-training. The wall-clock timestamps bracket only the train_adam call.
start_training = time.time()
adam_loss, adam_loss_history, best_state = train_adam(model, optimizer_theta, optimizer_mask, adam_epochs, print_every=1000)
adam_training_finished = time.time()

# train_adam already loads its best state before returning, so this reload is
# redundant but harmless. NOTE(review): with adam_epochs == 0, best_state is
# None and this call would fail.
model.load_state_dict(best_state)

# Setting up and executing L-BFGS fine-tuning.
# Unmasked loss (PDE MSE + IC) at the restored Adam weights. Despite the
# "PDE loss" label in the print below, it also includes the IC term.
with torch.enable_grad():
    X_colloc_tmp = X_colloc.clone().detach().requires_grad_(True)
    R0   = model.pde_residual(X_colloc_tmp)
    L_IC = model.loss_ic(X_ic_t, alpha2nd_ic_t)
    init_lbfgs_loss = (R0**2).mean().detach().item() + L_IC.detach().item()

print(f"LBFGS init unmasked PDE loss (from best Adam): {init_lbfgs_loss:.3e}")

# The closure returns raw_loss * loss_scale, so the L-BFGS objective starts
# near 1. tolerance_grad / tolerance_change act on this normalized objective.
loss_scale = 1.0 / max(init_lbfgs_loss, 1e-30)

# A single .step() call runs up to max_iter iterations / max_eval closure calls.
# NOTE(review): lr=1.0 with line_search_fn=None takes unguarded quasi-Newton
# steps; line_search_fn="strong_wolfe" is the usual safer choice -- confirm.
optimizer_lbfgs = torch.optim.LBFGS(
    model.parameters(),
    lr=1.0,
    max_iter=3000,
    max_eval=3000,
    history_size=3000,
    line_search_fn=None,
    tolerance_grad=1e-10,
    tolerance_change=1e-12,
)

# Lowest unscaled loss seen by the closure, plus a CPU snapshot of its weights.
best = {"loss": float("inf"), "state": None}
# Unscaled loss of every closure evaluation, in call order.
inner_curve = []

def closure():
    """L-BFGS objective: (unmasked PDE MSE + IC loss) * ``loss_scale``.

    Side effects on module state:
      * marks ``X_colloc`` as requiring grad (in place);
      * keeps the lowest unscaled loss and a CPU copy of its weights in ``best``;
      * appends the unscaled loss to ``inner_curve``.
    Raises ``RuntimeError("L-BFGS_NAN")`` on a non-finite loss so the caller
    can abort the run.
    """
    optimizer_lbfgs.zero_grad(set_to_none=True)
    pts = X_colloc.requires_grad_(True)

    residual = model.pde_residual(pts)
    ic_term = model.loss_ic(X_ic_t, alpha2nd_ic_t)
    objective = residual.pow(2).mean() + ic_term

    if not torch.isfinite(objective):
        print(f"NaN/Inf detected at iter {len(inner_curve)}. Exiting L-BFGS.")
        raise RuntimeError("L-BFGS_NAN")

    scaled = objective * loss_scale
    value = float(objective)

    if value < best["loss"]:
        best["loss"] = value
        snapshot = {}
        for name, tensor in model.state_dict().items():
            snapshot[name] = tensor.detach().cpu().clone()
        best["state"] = snapshot
    inner_curve.append(value)

    scaled.backward()
    return scaled

lbfgs_training_started = time.time()
try:
    # torch.optim.LBFGS.step returns the loss of its FIRST closure evaluation
    # (the starting weights), not the loss after the last iteration. It is
    # kept as `final_loss` only for reference.
    final_loss = optimizer_lbfgs.step(closure)
    # The closure records every unscaled loss it evaluates, so the last entry
    # is the most recent evaluation. If nothing was recorded, fall back to the
    # value returned by step().
    if inner_curve:
        final_raw = inner_curve[-1]
    else:
        final_raw = float(final_loss.item()) / loss_scale
except RuntimeError as e:
    if "L-BFGS_NAN" in str(e):
        print("L-BFGS terminated early due to NaN/Inf.")
        final_raw = float("nan")
    else:
        raise
end_training = time.time()

# Restore best L-BFGS state and record loss history.
# `best` is filled by the closure. If it never stored a state, the current
# weights are kept.
if best["state"] is not None:
    model.load_state_dict(best["state"])

# Re-evaluate the unmasked PDE and IC losses at the restored weights for the report.
# NOTE(review): total_best is not used elsewhere in the visible code.
with torch.enable_grad():
    X_eval = X_colloc.clone().detach().requires_grad_(True)
    R_eval = model.pde_residual(X_eval)
    L_pde_best = (R_eval**2).mean().detach().item()
    L_ic_best = model.loss_ic(X_ic_t, alpha2nd_ic_t).detach().item()
    total_best = L_pde_best + L_ic_best

# Per-closure-call unscaled losses, consumed by the plotting and save cells.
lbfgs_loss_history = inner_curve

# "Iterations" is len(inner_curve), i.e. the number of closure evaluations.
print(f"LBFGS finished. Final total loss={final_raw:.4e} | Best total loss={best['loss']:.4e} | "
      f"PDE={L_pde_best:.4e} | IC={L_ic_best:.4e} | Iterations={len(inner_curve)}")

# "Total" spans from the start of Adam to the end of L-BFGS, so it also counts
# the L-BFGS setup evaluation between the two phases.
print(f"\nTotal training time: {end_training - start_training:.4f} seconds."
      f"\nAdam: {adam_training_finished - start_training:.4f} seconds. "
      f"L-BFGS: {end_training - lbfgs_training_started:.4f} seconds.")

In [None]:
from Plotting import *

# Plot the collocation set only (IC and BC points are not passed).
plot_collocation_points(X_colloc, X_ic=None, X_bc_L=None, X_bc_R=None, L=L, t_end=t_end)

# Physical-scale first IC profile, used as the reference IC in the plots.
alpha_ic_t = alpha1st_ic_t

# Analytic J^0 initial profile on the IC grid:
#   J0(x) = d * exp(-(f * x / L)^2) + g
# NOTE(review): presumably this must match the initial data built by IC_BDNK;
# the constants are hard-coded here -- confirm they agree.
x_vals_ic = X_ic_t[:, 1:2]
d, f, g = 0.05, 10.0, 1.05
J0_ic_t = (d * torch.exp(- (f * x_vals_ic / L)**2) + g)

# Host copies of the IC profiles; the save cell writes these to the npz.
J0_ic_np    = J0_ic_t.detach().cpu().numpy()
alpha_ic_np = alpha_ic_t.detach().cpu().numpy()

# Evaluation grid: 200 cell midpoints in x, 200 uniformly spaced times in [0, t_end].
x_edges_eval = np.linspace(-L, L, 200+1)
x_eval = 0.5 * (x_edges_eval[:-1] + x_edges_eval[1:])
t_eval = np.linspace(0, t_end, 200)

# NOTE(review): alpha_ic / J0_ic repeat the conversions already stored in
# alpha_ic_np / J0_ic_np.
plot_results(
    model,
    t_eval=t_eval,
    x_eval=x_eval,
    alpha_ic=alpha_ic_t.detach().cpu().numpy(),
    J0_ic=J0_ic_t.detach().cpu().numpy()
)

# Run with autograd enabled, presumably because the residuals need input derivatives.
with torch.enable_grad():
    plot_pde_residuals(model, t_eval, x_eval)

# Wrap the single L-BFGS loss curve in the dict layout read by the plotting
# helper (and later by the save cell).
lbfgs_history = {"all_inner_per_epoch": [lbfgs_loss_history]}
plot_combined_loss_history(adam_loss_history, lbfgs_history)

In [None]:
# Save all results

import os, time
from typing import Dict, Any

def _as_numpy(x):
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)

@torch.no_grad()
def _eval_on_grid(model, t_eval: np.ndarray, x_eval: np.ndarray, batch_size: int = 32768):
    """
    Evaluate the trained model on the tensor-product grid t_eval × x_eval.

    Computes:
      - r_pde_grid : model.pde_residual(...), evaluated in batches of
                     `batch_size` points; shape (Nt, Nx, k), where k is the
                     number of residual columns pde_residual returns;
      - pred_grid  : the network output α(t, x), shape (Nt, Nx, 1);
      - J0_grid    : J^0(t, x) from J0_func, shape (Nt, Nx);
      - n_grid     : n(t, x) from n_from_alpha_func, shape (Nt, Nx).
    Grid points are flattened with t as the slow axis (meshgrid indexing='ij').

    The function runs under torch.no_grad(). Autograd is re-enabled only where
    derivatives w.r.t. the inputs are needed. The model is evaluated in eval
    mode, and its previous train/eval mode is restored on exit, even if an
    error is raised.

    Returns a dict of NumPy arrays, to be dumped in the npz.
    """
    was_training = model.training
    model.eval()
    try:
        t_eval = np.asarray(t_eval, dtype=np.float32).ravel()
        x_eval = np.asarray(x_eval, dtype=np.float32).ravel()
        Nt, Nx = len(t_eval), len(x_eval)

        TT, XX = np.meshgrid(t_eval, x_eval, indexing='ij')
        TX = np.stack([TT.reshape(-1), XX.reshape(-1)], axis=1)

        # PDE residual in batches. Each batch is detached right away so its
        # autograd graph is freed instead of being kept alive until the
        # final concatenation.
        TX_t = torch.tensor(TX, dtype=DTYPE, device=device, requires_grad=True)
        r_list = []
        with torch.enable_grad():
            for i in range(0, TX_t.shape[0], batch_size):
                r_list.append(model.pde_residual(TX_t[i:i+batch_size]).detach())
        r_pde_grid = torch.cat(r_list, dim=0).view(Nt, Nx, -1).cpu().numpy()

        # Network output and derived fields, evaluated on the full grid at once.
        TX_t2 = torch.tensor(TX, dtype=DTYPE, device=device, requires_grad=True)
        with torch.enable_grad():
            alpha = model(TX_t2)

            grad_alpha = torch.autograd.grad(
                alpha, TX_t2,
                grad_outputs=torch.ones_like(alpha),
                create_graph=True, retain_graph=True
            )[0]
            # Column 0 of the input gradient is the t-derivative.
            alpha_t = grad_alpha[:, 0:1]

            t_tensor = TX_t2[:, 0:1]
            x_tensor = TX_t2[:, 1:2]
            T_tensor = T_func(t_tensor, x_tensor)
            v_tensor = v_func(t_tensor, x_tensor)

            n_tensor  = n_from_alpha_func(alpha, T_tensor)
            J0_tensor = J0_func(T_tensor, v_tensor, alpha, alpha_t, TX_t2)

        alpha_grid = alpha.view(Nt, Nx, 1).detach().cpu().numpy()
        J0_grid    = J0_tensor.view(Nt, Nx).detach().cpu().numpy().astype(np.float32)
        n_grid     = n_tensor.view(Nt, Nx).detach().cpu().numpy().astype(np.float32)

        return {
            "t_eval": t_eval,
            "x_eval": x_eval,
            "TX": TX.astype(np.float32),
            "r_pde_grid": r_pde_grid,
            "pred_grid": alpha_grid,
            "pred_names": np.array(["alpha"], dtype=object),
            "J0_grid": J0_grid,
            "n_grid":  n_grid,
        }
    finally:
        # Restore the caller's training mode.
        if was_training:
            model.train()

def _compute_residual_fields(model, t_eval: np.ndarray, x_eval: np.ndarray):
    """Evaluate the residual fields R1, R2 and R0 on the grid ``t_eval x x_eval``.

    The residuals are rebuilt with autograd (not via ``model.pde_residual``)
    from the module-level helpers ``T_func``, ``v_func``, ``gamma_func``,
    ``n_from_alpha_func``, ``sigma_func``, ``lambd_func``, ``J0_func``,
    ``N_0_func`` and ``Jx_func``. These are presumably provided by the
    ``Plotting`` wildcard import; confirm.

      * R1 = d_t J0 + d_x Jx
      * R2 = alpha_t + N_0
      * R0 = d_t(gamma n) + d_x(gamma n v)
             + d_t(gamma^2 lambd T h) + d_x(gamma^2 v lambd T h)
             - d_t(sigma T Wt) - d_x(sigma T Wx),   with h = alpha_t + v alpha_x

    The computation uses the device/dtype of the model's first parameter and
    runs in eval mode. The model's previous train/eval mode is restored on
    exit, even if an error is raised.

    Returns:
        dict with ``"R1_grid"``, ``"R2_grid"`` and ``"R0_grid"``, each an
        ``(Nt, Nx)`` NumPy array (t is the slow axis, ``indexing='ij'``).
    """
    was_training = model.training
    model.eval()
    try:
        p = next(model.parameters())
        dev, dty = p.device, p.dtype

        t_eval = np.asarray(t_eval, dtype=np.float64).ravel()
        x_eval = np.asarray(x_eval, dtype=np.float64).ravel()
        Nt, Nx = len(t_eval), len(x_eval)

        # Flattened (t, x) grid as a leaf tensor, so autograd can
        # differentiate w.r.t. both coordinates.
        tt, xx = np.meshgrid(t_eval, x_eval, indexing='ij')
        tx = np.column_stack([tt.ravel(), xx.ravel()])
        tx_tensor = torch.tensor(tx, dtype=dty, device=dev, requires_grad=True)

        def grad(u):
            # Derivatives of u w.r.t. (t, x) via a ones-vector VJP: column 0
            # is d/dt, column 1 is d/dx. This assumes each row of u depends only
            # on its own row of tx. create_graph=True allows differentiating again.
            return torch.autograd.grad(
                u, tx_tensor, grad_outputs=torch.ones_like(u),
                create_graph=True, retain_graph=True
            )[0]

        with torch.set_grad_enabled(True):
            alpha = model(tx_tensor)

            t = tx_tensor[:, 0:1]
            x = tx_tensor[:, 1:2]

            T     = T_func(t, x)
            v     = v_func(t, x)
            gamma = gamma_func(v)

            n     = n_from_alpha_func(alpha, T)
            sigma = sigma_func(alpha, T)
            lambd = lambd_func(sigma)

            a_g     = grad(alpha)
            alpha_t = a_g[:, 0:1]
            alpha_x = a_g[:, 1:2]
            N_x     = -alpha_x

            J0 = J0_func(T, v, alpha, alpha_t, tx_tensor)
            N_0 = N_0_func(lambd, sigma, T, J0, n, N_x, v)
            Jx  = Jx_func(n, sigma, lambd, T, N_x, N_0, v)

            # R1
            J0_t = grad(J0)[:, 0:1]
            Jx_x = grad(Jx)[:, 1:2]
            R1   = J0_t + Jx_x

            # R2
            R2   = alpha_t + N_0

            # R0
            helper   = alpha_t + v * alpha_x
            d_gn_dt  = grad(gamma * n)[:, 0:1]
            d_gnv_dx = grad(gamma * n * v)[:, 1:2]
            d_lt_dt  = grad((gamma**2) * lambd * T * helper)[:, 0:1]
            d_lx_dx  = grad((gamma**2) * v * lambd * T * helper)[:, 1:2]
            Wt       = -alpha_t + (gamma**2) * helper
            Wx       =  alpha_x + (gamma**2) * v * helper
            d_st_dt  = grad(sigma * T * Wt)[:, 0:1]
            d_sx_dx  = grad(sigma * T * Wx)[:, 1:2]
            R0 = d_gn_dt + d_gnv_dx + d_lt_dt + d_lx_dx - d_st_dt - d_sx_dx

        def to_grid(Tv):
            return Tv.detach().cpu().numpy().reshape(Nt, Nx)

        return {
            "R1_grid": to_grid(R1),
            "R2_grid": to_grid(R2),
            "R0_grid": to_grid(R0),
        }
    finally:
        # Restore the caller's training mode (the evaluation itself runs in eval mode).
        if was_training:
            model.train()

# Evaluate the residual fields and prediction grids on the plotting grid
# (t_eval / x_eval come from the plotting cell).
res_fields = _compute_residual_fields(model, t_eval=t_eval, x_eval=x_eval)
grid_dump = _eval_on_grid(model, t_eval=t_eval, x_eval=x_eval)
# Output file: ./pinn_runs/PINN_run_dump_<YYYYmmdd-HHMMSS>.npz
run_id = time.strftime("%Y%m%d-%H%M%S")
save_dir = os.path.abspath("./pinn_runs")
os.makedirs(save_dir, exist_ok=True)
dump_path = os.path.join(save_dir, f"PINN_run_dump_{run_id}.npz")

# lbfgs_history is built in the plotting cell with only the
# "all_inner_per_epoch" key, so the .get(...) fields below are saved as
# empty arrays.
lbfgs_all_inner = np.array(lbfgs_history["all_inner_per_epoch"], dtype=object) if "all_inner_per_epoch" in lbfgs_history else np.array([], dtype=object)
lbfgs_inner_curve_np = np.asarray(lbfgs_loss_history, dtype=np.float64)

# NOTE(review): object-dtype arrays (pred_names, lbfgs_all_inner_per_epoch)
# are presumably stored pickled by NumPy, so reading them back probably needs
# np.load(..., allow_pickle=True) -- confirm.
np.savez_compressed(
    dump_path,
    # Run metadata.
    run_id=run_id, dtype=str(DTYPE), device_type=str(device.type),
    t_end=float(t_end), L=float(L),

    # Training point sets and IC profiles (physical scale).
    X_colloc=_as_numpy(X_colloc),
    X_bc_L=_as_numpy(X_bc_L) if 'X_bc_L' in globals() and X_bc_L is not None else np.array([]),
    X_bc_R=_as_numpy(X_bc_R) if 'X_bc_R' in globals() and X_bc_R is not None else np.array([]),
    x_ic=_as_numpy(x_ic), t_ic=_as_numpy(t_ic),
    alpha_ic=_as_numpy(alpha_ic_np), J0_ic=_as_numpy(J0_ic_np),

    # Grid evaluations from _eval_on_grid.
    t_eval=grid_dump["t_eval"], x_eval=grid_dump["x_eval"], TX=grid_dump["TX"],
    pred_grid=grid_dump["pred_grid"] if grid_dump["pred_grid"] is not None else np.array([]),
    pred_names=grid_dump["pred_names"] if grid_dump["pred_names"] is not None else np.array([], dtype=object),
    J0_grid=grid_dump["J0_grid"],
    n_grid=grid_dump["n_grid"],
    r_pde_grid=grid_dump["r_pde_grid"],

    # Residual components from _compute_residual_fields.
    R1_grid=res_fields["R1_grid"],
    R2_grid=res_fields["R2_grid"],
    R0_grid=res_fields["R0_grid"],

    # Loss histories.
    adam_loss_history=np.asarray(adam_loss_history, dtype=np.float64),
    lbfgs_final_per_epoch=np.asarray(lbfgs_history.get("final_per_epoch", []), dtype=np.float64),
    lbfgs_best_inner_per_epoch=np.asarray(lbfgs_history.get("best_inner_per_epoch", []), dtype=np.float64),
    lbfgs_num_closure_calls=np.asarray(lbfgs_history.get("num_closure_calls", []), dtype=np.int32),
    lbfgs_all_inner_per_epoch=lbfgs_all_inner,

    lbfgs_inner_curve=lbfgs_inner_curve_np,
)
print(f"[save] Dumped replot data to: {dump_path}")