In [None]:
import time

# Device configuration and core PyTorch setup
import torch
import torch.nn as nn
import torch.nn.functional as F
# Global floating-point precision used for every tensor created in this notebook
DTYPE = torch.float32
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without a usable CUDA device instead of crashing.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == "cuda":
    # Initialise CUDA eagerly and allocate one dummy tensor so that the
    # one-off context start-up cost is paid here rather than inside the
    # timed training section further down.
    torch.cuda.init()
    torch.rand(1, device=device)
print(f"Using device: {device}")

from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import matplotlib.pyplot as plt
from SA_PINN_ACTO import PINN_BDNK_1D
from IC_1D import IC_BDNK

# Hyperparameters for network architecture and training schedule
Nl, Nn = 10, 70  # forwarded as Nl= / Nn= to PINN_BDNK_1D (presumably layer count and width -- confirm in SA_PINN_ACTO)
t_end = 10.0  # upper bound of the time coordinate; the sampled time range is [0, t_end]
L = 50.0  # half-width of the spatial domain; x is sampled in [-L, L]
adam_epochs = 35_000  # number of Adam epochs passed to train_adam
lr_net = 5e-3  # initial Adam learning rate for the network parameters
lr_mask = 4e-2  # Adam learning rate for the self-adaptive collocation logits

# BDNK simulation configuration and background field setup
# NOTE(review): this star import runs AFTER the constants above, so any
# same-named symbol exported by BDNK_Functions (e.g. `L`, `t_end`) would
# silently replace them. Keep this statement order unless that is ruled out.
from BDNK_Functions import *
BDNK_simulation = 2  # selector handed to setup_external_Tv; meaning is defined outside this notebook
setup_external_Tv(BDNK_simulation, L)  # presumably provided by the star import above -- TODO confirm

# Sampling parameters and domain sampling
N_colloc = 50000  # number of Latin-hypercube collocation points in (t, x)
# In the ACTO case, the N_ic points below are not collocation points at which an IC residual is computed,
# but points at which the exact initial condition is evaluated (and later interpolated)
N_ic = 1000

def lhs_box(n, low, high, rng=np.random):
    """Draw ``n`` Latin-hypercube samples from the box ``[low, high]``.

    Each dimension is split into ``n`` equal-width strata; every stratum
    receives exactly one sample, placed uniformly at random inside it, and
    the strata are shuffled independently per dimension.

    Parameters
    ----------
    n : int
        Number of samples.
    low, high : float or 1-D array-like
        Lower / upper corner of the box. Scalars are treated as a
        one-dimensional box. Both must have the same length.
    rng : module or random generator, optional
        Source of randomness. Accepts the legacy interface (the
        ``numpy.random`` module or a ``RandomState``) as well as a
        ``numpy.random.Generator`` such as ``np.random.default_rng(seed)``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n, D)`` and dtype float64, ``D = len(low)``.

    Raises
    ------
    ValueError
        If ``low`` and ``high`` are not one-dimensional with equal length.
    """
    low = np.atleast_1d(np.asarray(low, float))
    high = np.atleast_1d(np.asarray(high, float))
    if low.ndim != 1 or low.shape != high.shape:
        raise ValueError("low and high must be 1-D sequences of equal length")

    # The legacy API exposes rand(n); Generator only exposes random(n).
    # Prefer rand when present so existing seeded runs draw the same stream.
    uniform = rng.rand if hasattr(rng, "rand") else rng.random

    D = low.size
    H = np.empty((n, D), float)
    for j in range(D):
        # permutation picks the stratum, uniform jitters inside the stratum
        P = (rng.permutation(n) + uniform(n)) / n
        H[:, j] = low[j] + P * (high[j] - low[j])
    return H
    
# Latin-hypercube collocation points over (t, x) in [0, t_end] x [-L, L], stored as float32
X_colloc_np = lhs_box(
    N_colloc,
    low=np.array([0.0, -L]),
    high=np.array([t_end, L]),
).astype(np.float32)

# Initial-condition sampling grid: midpoints of N_ic equal-width cells on [-L, L], paired with t = 0
x_edges = np.linspace(-L, L, N_ic+1)
x_ic = (0.5 * (x_edges[:-1] + x_edges[1:]))[:, None]
t_ic = np.zeros_like(x_ic)
X_ic = np.concatenate((t_ic, x_ic), axis=1)

# Evaluate the initial condition once on that grid (third return value is not used here)
X_ic_t = torch.tensor(X_ic, dtype=DTYPE, device=device)
J0_ic_t, alpha_ic_t, _ = IC_BDNK(X_ic_t, L)

# Per-field scale = largest absolute IC value, floored at 1e-12 to avoid division by zero
with torch.no_grad():
    sJ0, sA = (
        field.abs().max().clamp_min(1e-12).item()
        for field in (J0_ic_t, alpha_ic_t)
    )
print(f"[scales] sJ0={sJ0:.3e}, sA={sA:.3e}")

# Order the IC samples by x so that they can serve as knots for 1D interpolation
x_ic_torch = X_ic_t[:, 1:2].contiguous().view(-1)
x_sorted, idx_sort = x_ic_torch.sort()
J0_sorted = J0_ic_t.view(-1).index_select(0, idx_sort)
alpha_sorted = alpha_ic_t.view(-1).index_select(0, idx_sort)

@torch.no_grad()
def _torch_lin_interp_1d(xq: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    xq_flat = xq.view(-1)
    xq_clamped = xq_flat.clamp(min=x[0], max=x[-1])
    idx_hi = torch.searchsorted(x, xq_clamped, right=True)
    idx_hi = idx_hi.clamp(min=1, max=x.numel() - 1)
    idx_lo = idx_hi - 1
    x0 = x[idx_lo]; x1 = x[idx_hi]
    y0 = y[idx_lo]; y1 = y[idx_hi]
    denom = (x1 - x0)
    denom = torch.where(denom.abs() > 0, denom, torch.ones_like(denom))
    w = (xq_clamped - x0) / denom
    yq = y0 + w * (y1 - y0)
    return yq.view_as(xq)

# Initial condition functions passed to the neural network (physical scale)
def J0_ic_func(x_phys: torch.Tensor) -> torch.Tensor:
    """Interpolate the tabulated J0 initial condition at positions ``x_phys``.

    The knots (``x_sorted``, ``J0_sorted``) are converted to the device and
    dtype of ``x_phys`` only when they differ. The result is a column
    tensor of shape ``(N, 1)`` in physical (unscaled) units.
    """
    already_matching = (
        x_sorted.device == x_phys.device and x_sorted.dtype == x_phys.dtype
    )
    if already_matching:
        knots_x, knots_y = x_sorted, J0_sorted
    else:
        target = dict(device=x_phys.device, dtype=x_phys.dtype)
        knots_x = x_sorted.to(**target)
        knots_y = J0_sorted.to(**target)
    return _torch_lin_interp_1d(x_phys.view(-1), knots_x, knots_y).view(-1, 1)

def alpha_ic_func(x_phys: torch.Tensor) -> torch.Tensor:
    """Interpolate the tabulated alpha initial condition at positions ``x_phys``.

    The knots (``x_sorted``, ``alpha_sorted``) are converted to the device
    and dtype of ``x_phys`` only when they differ. The result is a column
    tensor of shape ``(N, 1)`` in physical (unscaled) units.
    """
    already_matching = (
        x_sorted.device == x_phys.device and x_sorted.dtype == x_phys.dtype
    )
    if already_matching:
        knots_x, knots_y = x_sorted, alpha_sorted
    else:
        target = dict(device=x_phys.device, dtype=x_phys.dtype)
        knots_x = x_sorted.to(**target)
        knots_y = alpha_sorted.to(**target)
    return _torch_lin_interp_1d(x_phys.view(-1), knots_x, knots_y).view(-1, 1)

# Scaled initial condition functions (used internally by the model)
def J0_ic_func_scaled(x_phys: torch.Tensor) -> torch.Tensor:
    """Return the interpolated J0 initial condition divided by the scale ``sJ0``."""
    physical_J0 = J0_ic_func(x_phys)
    return physical_J0 / sJ0

def alpha_ic_func_scaled(x_phys: torch.Tensor) -> torch.Tensor:
    """Return the interpolated alpha initial condition divided by the scale ``sA``."""
    physical_alpha = alpha_ic_func(x_phys)
    return physical_alpha / sA
    
# Extra collocation points on the initial line: 500 equally spaced x values at t = 0
x0_line = torch.linspace(-L, L, 500, dtype=DTYPE, device=device)[:, None]
X0 = torch.cat((torch.zeros_like(x0_line), x0_line), dim=1)

# Full collocation set = Latin-hypercube samples followed by the t = 0 line
X_colloc = torch.cat(
    (torch.tensor(X_colloc_np, dtype=DTYPE, device=device), X0),
    dim=0,
)

# Self-adaptive collocation mask: one trainable logit per collocation point, initialised to zero
pde_logits = torch.nn.Parameter(
    torch.zeros((X_colloc.shape[0], 1), dtype=DTYPE, device=device)
)
def current_masks(detach: bool = False):
    """Return the positive per-point PDE weights ``softplus(pde_logits)``.

    With ``detach=True`` the returned tensor is cut from the autograd graph,
    which is what logging wants; otherwise gradients flow back to the logits.
    """
    weights = F.softplus(pde_logits)
    if detach:
        return weights.detach()
    return weights

# Model instantiation and domain normalization
# lb / ub are the lower and upper corners of the (t, x) domain, handed to the
# model constructor (presumably used for input normalisation -- confirm in SA_PINN_ACTO).
lb = torch.tensor([0.0, -L], dtype=DTYPE, device=device)
ub = torch.tensor([t_end,  L], dtype=DTYPE, device=device)
model = PINN_BDNK_1D(Nl=Nl, Nn=Nn, lb=lb, ub=ub).to(device).to(DTYPE)
# Attach the *scaled* initial-condition callables (values divided by sJ0 / sA) to the model.
model.J0_ic_func    = J0_ic_func_scaled
model.alpha_ic_func = alpha_ic_func_scaled
# Store the matching scale factors on the model via in-place copy.
# NOTE(review): an in-place copy_ outside no_grad only works if sJ0 / sA are
# buffers or tensors that do not require grad -- confirm in PINN_BDNK_1D.
model.sJ0.copy_(torch.tensor(sJ0, dtype=DTYPE, device=device))
model.sA.copy_(torch.tensor(sA,  dtype=DTYPE, device=device))

# Weight initialization
def glorot_normal_all_linear(m):
    """Initialiser for ``Module.apply``: Xavier-normal weights, zero biases.

    Only ``nn.Linear`` modules are touched; every other module type is left
    unchanged. Always returns ``None``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_normal_(m.weight, gain=1.0)
    if m.bias is None:
        return
    nn.init.zeros_(m.bias)
model.apply(glorot_normal_all_linear)

# Optimizers and learning-rate scheduler setup
# Network parameters: Adam, with the learning rate reduced when the monitored loss plateaus.
optimizer_theta = torch.optim.Adam(
    model.parameters(),
    lr=lr_net,
    betas=(0.9, 0.95),
)
scheduler = ReduceLROnPlateau(
    optimizer_theta,
    mode='min',
    factor=0.4,
    patience=700,
    threshold=1e-4,
    min_lr=lr_net/100,
)
# Collocation logits: separate Adam instance that ascends (maximize=True) on the same loss.
optimizer_mask = torch.optim.Adam(
    [pde_logits],
    lr=lr_mask,
    betas=(0.7, 0.85),
    maximize=True,
)

# Setting up and executing Adam pre-training
def train_adam(model, optimizer_theta, optimizer_mask, epochs, print_every):
    """Adam pre-training of the self-adaptive PINN on the masked PDE residual.

    Each epoch evaluates ``model.pde_residual`` on the module-level
    ``X_colloc``, weights the two residual components with the mask from
    ``current_masks`` and steps both optimizers on the same loss. The
    module-level ``scheduler`` is stepped on the unmasked mean squared
    residual, which is also the quantity used to select the best weights.

    Parameters
    ----------
    model : torch.nn.Module
        Network exposing ``pde_residual(X)``; columns 0 and 1 of its result
        are used as the two residual components.
    optimizer_theta : torch.optim.Optimizer
        Optimizer for the network parameters.
    optimizer_mask : torch.optim.Optimizer
        Optimizer for the collocation logits.
    epochs : int
        Number of epochs to run.
    print_every : int
        Progress is printed every ``print_every`` epochs and at the last one.

    Returns
    -------
    tuple
        ``(best_loss, loss_history, best_state)``: the lowest unmasked loss,
        the unmasked loss of every epoch, and a CPU copy of the model
        ``state_dict`` that produced ``best_loss`` (``None`` if no epoch ran).
        On return the model already holds ``best_state``.

    Raises
    ------
    RuntimeError
        If the masked loss becomes NaN or infinite.
    """
    print("Starting Adam pre-training (SA-PINN with hard IC)...")
    best_loss, best_state = float('inf'), None
    loss_history = []

    for epoch in range(1, epochs + 1):
        optimizer_theta.zero_grad()
        optimizer_mask.zero_grad()

        R = model.pde_residual(X_colloc)
        R1, R2 = R[:, 0:1], R[:, 1:2]

        R_sq = R1**2 + R2**2
        pde_mask = current_masks(detach=False)
        # Same quantity as (mask * sqrt(R_sq))**2, written without the sqrt:
        # sqrt has an infinite derivative at 0, which turns into NaN
        # gradients wherever the residual is exactly zero.
        L_pde = (pde_mask.pow(2) * R_sq).mean()

        loss = L_pde
        if not torch.isfinite(loss): raise RuntimeError("Non-finite loss detected.")

        L_pde_phys   = R_sq.mean()
        L_total_phys = L_pde_phys

        ltp = L_total_phys.detach().item()
        loss_history.append(ltp)
        if ltp < best_loss:
            # Snapshot BEFORE the optimizer step: ltp was measured with the
            # current weights, so these are the weights that belong to it.
            best_loss = ltp
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

        loss.backward()
        optimizer_theta.step()
        scheduler.step(ltp)
        optimizer_mask.step()

        if epoch % print_every == 0 or epoch == epochs:
            with torch.no_grad():
                m_pde = current_masks(detach=True)
            print(f"Adam Epoch {epoch}/{epochs} | "
                  f"Total={loss:.3e}, PDE={L_pde.item():.3e}, <pde_mask>={m_pde.mean().item():.2f} | "
                  f"Unmasked: Total={L_total_phys:.3e}, PDE={L_pde_phys.item():.3e} | "
                  f"lr_net={optimizer_theta.param_groups[0]['lr']:.3e}")

    if best_state is not None:
        model.load_state_dict(best_state)
        # Re-evaluate the residual with the restored weights for the summary
        # line; grad mode is forced on for the residual evaluation.
        with torch.enable_grad():
            R = model.pde_residual(X_colloc)
            L_pde_phys = ((R[:,0:1]**2 + R[:,1:2]**2)).mean()
        print(f"\nAdam finished. Best loss = {best_loss:.3e} | PDE={L_pde_phys.item():.4e}")

    return best_loss, loss_history, best_state
    
# Run Adam pre-training; the two timestamps bracket it for the timing report at the end.
start_training = time.time()
adam_loss, adam_loss_history, best_state = train_adam(model, optimizer_theta, optimizer_mask, adam_epochs, print_every=1000)
adam_training_finished = time.time()

# train_adam already restores best_state before returning when it is not None;
# reloading here makes the L-BFGS starting point explicit.
model.load_state_dict(best_state)

# Setting up and executing L-BFGS fine-tuning
with torch.enable_grad():
    # In-place flag change: X_colloc requires grad from here on, in later cells too.
    X_colloc.requires_grad_(True)
    R0 = model.pde_residual(X_colloc)
    # Unmasked mean squared residual at the L-BFGS starting point
    init_lbfgs_loss = ((R0[:,0:1]**2 + R0[:,1:2]**2)).mean().detach().item()
print(f"LBFGS init unmasked PDE loss (from best Adam): {init_lbfgs_loss:.3e}")

# Normalise the objective so L-BFGS starts from a loss of about 1;
# the 1e-30 floor guards against division by zero.
loss_scale = 1.0 / max(init_lbfgs_loss, 1e-30)

# A single .step() call runs up to max_iter iterations / max_eval closure
# evaluations, with a fixed step size of lr (no line search).
optimizer_lbfgs = torch.optim.LBFGS(
    model.parameters(),
    lr=1.0,
    max_iter=1000,
    max_eval=1000,
    history_size=100,
    line_search_fn=None,
    tolerance_grad=1e-10,
    tolerance_change=1e-12,
)

# Shared state filled in by closure(): best unscaled loss with its weights,
# and the unscaled loss of every closure evaluation.
best = {"loss": float("inf"), "state": None}
inner_curve = []

def closure():
    """L-BFGS closure: evaluate the unmasked PDE loss and backpropagate it.

    Records the unscaled loss in ``inner_curve`` and keeps a CPU snapshot of
    the best weights seen so far in ``best``. Returns the loss multiplied by
    ``loss_scale``. Raises ``RuntimeError("L-BFGS_NAN")`` as soon as the
    unscaled loss is NaN or infinite.
    """
    optimizer_lbfgs.zero_grad(set_to_none=True)

    residual = model.pde_residual(X_colloc.requires_grad_(True))
    first, second = residual[:, 0:1], residual[:, 1:2]
    mse = (first.pow(2) + second.pow(2)).mean()

    if not torch.isfinite(mse):
        print(f"NaN/Inf detected at iter {len(inner_curve)}. Exiting L-BFGS.")
        raise RuntimeError("L-BFGS_NAN")

    mse_value = float(mse)
    if mse_value < best["loss"]:
        # Snapshot the weights that produced this new minimum
        snapshot = {}
        for key, tensor in model.state_dict().items():
            snapshot[key] = tensor.detach().cpu().clone()
        best["loss"] = mse_value
        best["state"] = snapshot
    inner_curve.append(mse_value)

    scaled = mse * loss_scale
    scaled.backward()
    return scaled

lbfgs_training_started = time.time()
try:
    final_loss = optimizer_lbfgs.step(closure)
    # LBFGS.step() returns the loss of its FIRST closure evaluation, i.e. the
    # starting loss. The last value recorded by closure() is the loss at the
    # last iterate that was actually evaluated, which is what "final" means.
    final_raw = inner_curve[-1] if inner_curve else float("nan")
except RuntimeError as e:
    # closure() signals a NaN/Inf loss with this marker; anything else is a real error.
    if "L-BFGS_NAN" in str(e):
        print("L-BFGS terminated early due to NaN/Inf.")
        final_raw = float("nan")
    else:
        raise
end_training = time.time()

# Restore best L-BFGS state and record loss history
if best["state"] is not None:
    model.load_state_dict(best["state"])

lbfgs_loss_history = inner_curve

print(f"LBFGS finished. final_raw={final_raw:.4e} | best_raw={best['loss']:.4e} | inner_calls={len(inner_curve)}")

print(f"\nTotal training time: {end_training - start_training:.4f} seconds."
      f"\nAdam: {adam_training_finished - start_training:.4f} seconds. "
      f"L-BFGS: {end_training - lbfgs_training_started:.4f} seconds.")

In [None]:
from Plotting import *

# Scatter of the collocation set (no separate IC / BC point sets in the hard-IC setup)
plot_collocation_points(X_colloc, X_ic=None, X_bc_L=None, X_bc_R=None, L=L, t_end=t_end)

# NumPy copies of the tabulated initial condition, used for plotting
J0_ic_np    = J0_ic_t.detach().cpu().numpy()
alpha_ic_np = alpha_ic_t.detach().cpu().numpy()

# Evaluation grid: 200 cell midpoints in x on [-L, L] and 200 time levels on [0, t_end]
x_edges_eval = np.linspace(-L, L, 200+1)
x_eval = 0.5 * (x_edges_eval[:-1] + x_edges_eval[1:])
t_eval = np.linspace(0, t_end, 200)

plot_results(model, t_eval=t_eval, x_eval=x_eval, alpha_ic=alpha_ic_np, J0_ic=J0_ic_np)

# Grad mode is forced on for the residual plot
with torch.enable_grad():
    plot_pde_residuals(model, t_eval, x_eval)

# plot_combined_loss_history receives the L-BFGS curve as a one-element list under this key
lbfgs_history = {"all_inner_per_epoch": [lbfgs_loss_history]}
plot_combined_loss_history(adam_loss_history, lbfgs_history)