In [1]:
# For readability: disable warnings from libraries like matplotlib, etc.
import warnings
warnings.filterwarnings('ignore')

import os
# Make sure torch is imported somewhere above this cell:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import scipy.io
from scipy.interpolate import griddata
import time
from itertools import product, combinations
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.gridspec as gridspec
import scipy.sparse as sp
import scipy.sparse.linalg as la
from pyDOE import lhs
from matplotlib.colors import LogNorm
from matplotlib.ticker import LogLocator, FuncFormatter
from matplotlib.ticker import FormatStrFormatter
import copy
import pandas as pd

# --- Device Setup ---
print("CUDA available?", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))

# Select the most performant device available (CUDA > MPS > CPU)
device = (
    torch.device('cuda') if torch.cuda.is_available()
    else torch.device('mps') if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    else torch.device('cpu')
)
print("Using device:", device)


def u_true_numpy(X, T, logK):
    """Vectorised true solution: U = exp(-(pi^2) * T) * sin(pi * X)."""
    K = 10.0**logK    # convert log10(k) → k
    return np.exp(- K * (np.pi**2) * T) * np.sin(np.pi * X)

def net_u(x, t, logk, model):
    """
    NN input is (x, t, log10(k)).
    """
    X = torch.cat([x, t, logk], dim=1)  # If x and t are each shape (N, 1), then X becomes (N, 2).
    u = model(X)
    return u

# net_f computes the PDE residual
# If f ≈ 0 at collocation points, the NN satisfies the equation there
def net_f(x, t, logk, model):
    x.requires_grad_(True)
    t.requires_grad_(True)
    # logk.requires_grad_(True)
    
    u = net_u(x, t, logk, model)
    u_x = torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u), create_graph=True, retain_graph=True)[0]
    u_t = torch.autograd.grad(u, t, grad_outputs=torch.ones_like(u), create_graph=True)[0]
    u_xx = torch.autograd.grad(u_x, x, grad_outputs=torch.ones_like(u_x), create_graph=True)[0]
    
    # convert log10(k) -> k = 10^logk
    k = 10.0**logk
    
    f = u_t - k * u_xx
    return f

class XavierInit(nn.Module):
    def __init__(self, size):
        super(XavierInit, self).__init__()
        in_dim = size[0]
        out_dim = size[1]
        xavier_stddev = torch.sqrt(torch.tensor(2.0 / (in_dim + out_dim)))
        self.weight = nn.Parameter(torch.randn(in_dim, out_dim) * xavier_stddev)
        self.bias = nn.Parameter(torch.zeros(out_dim))

    def forward(self, x):
        return torch.matmul(x, self.weight) + self.bias

def initialize_NN(layers):
    weights = nn.ModuleList()
    num_layers = len(layers)
    for l in range(num_layers - 1):
        layer = XavierInit(size=[layers[l], layers[l + 1]]) # if there was no retutn, how do I get the weight and bias?
        weights.append(layer)
    return weights

class NeuralNet(nn.Module):
    def __init__(self, layers, lb, ub):
        super(NeuralNet, self).__init__()
        self.weights = initialize_NN(layers)
        # make lb/ub move with .to(device)
        self.register_buffer('lb', torch.as_tensor(lb, dtype=torch.float32))     # <<< CHANGED >>>
        self.register_buffer('ub', torch.as_tensor(ub, dtype=torch.float32))     # <<< CHANGED >>>
        # self.register_buffer('k', torch.tensor(k_init, dtype=torch.float32))     # <<< CHANGED >>>


    def forward(self, X):
        X = X.float()                                                            # <<< CHANGED >>>
        lb = self.lb.to(X.device)                                                # <<< CHANGED >>>
        ub = self.ub.to(X.device)                                                # <<< CHANGED >>>
        H = 2.0 * (X - self.lb) / (self.ub - self.lb) - 1.0
        for l in range(len(self.weights) - 1):
            H = torch.tanh(self.weights[l](H.float()))     # Is this already a calculation?
        Y = self.weights[-1](H)
        return Y

def train(nEpoch, X, u, X_f, X_val, model, learning_rate):
    criterion = nn.MSELoss()
    
    # ----- STAGE 1: start with Adam -----
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    # when to switch from Adam to L-BFGS
    switch_epoch = 100000
    used_lbfgs   = False   
    lbfgs_epochs   = 0.0           # <-- how many epochs of L-BFGS you want
    lbfgs_start_ep = None          # <-- will store the epoch where we switch

    # use the model's device
    dev = next(model.parameters()).device                                        # <<< CHANGED >>>

    x    = X[:, 0:1]
    t    = X[:, 1:2]
    logk = X[:, 2:3]
    # Collocation points (f points)
    x_f    = X_f[:, 0:1]
    t_f    = X_f[:, 1:2]
    logk_f = X_f[:, 2:3]
    # Validation points
    x_v    = X_val[:, 0:1]
    t_v    = X_val[:, 1:2]
    logk_v = X_val[:, 2:3]

    # True validation solution (analytic)
    u_v_true = u_true_numpy(x_v, t_v, logk_v)   # shape (N_val,)

    # create tensors ON THE SAME DEVICE
    x_tf      = torch.tensor(x,        dtype=torch.float32, device=dev, requires_grad=True)   # <<< CHANGED >>>
    t_tf      = torch.tensor(t,        dtype=torch.float32, device=dev, requires_grad=True)   # <<< CHANGED >>>
    logk_tf   = torch.tensor(logk,     dtype=torch.float32, device=dev, requires_grad=True)   # <<< CHANGED >>>
    u_tf      = torch.tensor(u,        dtype=torch.float32, device=dev)                       # <<< CHANGED >>>
    x_f_tf    = torch.tensor(x_f,      dtype=torch.float32, device=dev, requires_grad=True)   # <<< CHANGED >>>
    t_f_tf    = torch.tensor(t_f,      dtype=torch.float32, device=dev, requires_grad=True)   # <<< CHANGED >>>
    logk_f_tf = torch.tensor(logk_f,   dtype=torch.float32, device=dev, requires_grad=True)   # <<< CHANGED >>>
    x_v_tf    = torch.tensor(x_v,      dtype=torch.float32, device=dev)
    t_v_tf    = torch.tensor(t_v,      dtype=torch.float32, device=dev)
    logk_v_tf = torch.tensor(logk_v,   dtype=torch.float32, device=dev)
    u_true_tf = torch.tensor(u_v_true, dtype=torch.float32, device=dev).reshape(-1, 1)

    mse_v_hist  = []
    loss_values = []
    max_errors  = []   # <<< NEW: to store (epoch, max_abs_error) >>>

    patience  = 10000          # number of validations without improvement
    pat       = 0
    best_v    = float('inf')   # best validation MSE
    best_TL   = float('inf')   # best training loss corresponding to best_v
    best_max_err = float('inf')  # <<< NEW: best max |error| on val
    best_state = copy.deepcopy(model.state_dict())
    best_ep    = -1

    start_time  = time.time()
    total_start = time.time()        # total wall-clock timer

    for ep in range(nEpoch):

        # ----- Stop if we've done lbfgs_epochs of L-BFGS -----
        if used_lbfgs and lbfgs_start_ep is not None:
            if ep - lbfgs_start_ep >= lbfgs_epochs:
                print(f"Stopping after {lbfgs_epochs} LBFGS epochs at global epoch {ep}")
                break

        # ------------------------------
        # STAGE 1: Adam (ep < switch_epoch)
        # STAGE 2: L-BFGS (ep >= switch_epoch)
        # ------------------------------
        if ep < switch_epoch:
            # ----- Adam update -----
            optimizer.zero_grad()
            
            # Compute predictions for training data (u)
            u_pred = net_u(x_tf, t_tf, logk_tf, model)          # <<< CHANGED
            # Compute PDE residual at collocation points
            u_f_pred = net_f(x_f_tf, t_f_tf, logk_f_tf, model)  # <<< CHANGED
    
            loss_PDE  = criterion(u_f_pred, torch.zeros_like(u_f_pred))
            loss_data = criterion(u_tf, u_pred)
            loss = loss_PDE + 100 * loss_data
    
            loss.backward()
            optimizer.step()
            
        else:
            # ----- Switch to L-BFGS once -----
            if not used_lbfgs:
                # 1) Load the best Adam weights BEFORE creating LBFGS
                model.load_state_dict(best_state)
        
                # 2) Print which Adam state you're starting from
                print(
                    f"Switching to L-BFGS at epoch {ep} "
                    f"-> starting from Adam best at epoch {best_ep} "
                    f"(TrainLoss={best_TL:.3e}, Val MSE={best_v:.3e})"
                )
        
                # 3) Create the LBFGS optimiser on top of that state
                optimizer = torch.optim.LBFGS(
                    model.parameters(),
                    max_iter=20,          # internal LBFGS iterations per .step()
                    history_size=100,
                    line_search_fn=None
                )
                used_lbfgs = True
                lbfgs_start_ep = ep      # <-- remember when we switched
                

            # L-BFGS requires a closure that re-computes the loss
            def closure():
                optimizer.zero_grad()
                u_pred   = net_u(x_tf, t_tf, logk_tf, model)          # <<< CHANGED
                u_f_pred = net_f(x_f_tf, t_f_tf, logk_f_tf, model)    # <<< CHANGED

                loss_PDE  = criterion(u_f_pred, torch.zeros_like(u_f_pred))
                loss_data = criterion(u_tf, u_pred)
                loss      = loss_PDE + 100 * loss_data

                loss.backward()
                return loss

            loss = optimizer.step(closure)  # returns the loss from the last closure call

        
        # ----- validation -----
        model.eval()
        with torch.no_grad():
            u_v_pred = net_u(x_v_tf, t_v_tf, logk_v_tf, model)   # <<< CHANGED
            mse_v = criterion(u_v_pred, u_true_tf).item()
            mse_v_hist.append((ep, mse_v))

            # <<< NEW: max absolute error on validation set >>>
            abs_err = torch.abs(u_v_pred - u_true_tf)   # (N_val, 1)
            max_err = abs_err.max().item()              # scalar
            max_errors.append((ep, max_err))            # store (epoch, max_err)
            # -----------------------------------------------

        model.train()  # switch back


        # ----- early stopping on val -----
        # if mse_v < best_v:
        if loss.item() < best_TL:
            best_v       = mse_v
            best_TL      = loss.item()
            best_max_err = max_err          # <<< NEW: store max error at best state
            best_state   = copy.deepcopy(model.state_dict())
            best_ep      = ep
            # print(f"[Improved] Epoch {ep} | Best Val MSE: {best_v:.3e}")
            pat = 0
        else:
            pat += 1
            if pat >= patience:
                print(f"Early stopping at it={ep}, best Val MSE={best_v:.3e}")
                break
        
        # Print progress
        # - Before LBFGS: every 1000 epochs
        # - After LBFGS is enabled: every 100 epochs
        if (not used_lbfgs and ep % 5000 == 0) or (used_lbfgs and ep % 10 == 0):
            elapsed = time.time() - start_time
            print(f"Epochs: {ep:6d} | TrainLoss: {loss.item():.3e} "
                  f"| Val MSE: {mse_v:.3e} "
                  f"| Max Val |err|: {max_err:.3e} "   # <<< NEW
                  f"| Time: {elapsed:.2f}s")
            start_time = time.time()
            
        loss_values.append(loss.item())

    total_elapsed = time.time() - total_start
    print(f"Total training time: {total_elapsed:.2f} s")
    print(f"Best Val MSE: {best_v:.3e} at epoch {best_ep}")
    print(f"Best Max |err| on validation: {best_max_err:.3e}")   # <<< NEW

    model.load_state_dict(best_state)          # <- load best here

    return loss_values, mse_v_hist, max_errors, best_ep, best_TL, best_v, best_max_err




# -----------------------------------------------------------------------------
# 1) Build dataset for arbitrary (N_i, N_b, N_k, N_f)
# -----------------------------------------------------------------------------
def build_dataset(N_i, N_b, N_k, N_f, N_val,
                  x_min, x_max,
                  t_min, t_max,
                  k_min, k_max,
                  seed):
    """
    Build training and collocation sets for the parametric heat equation.

    Returns:
        X_u_train : (N_u, 3) array of data points (x, t, log10(k)) for IC + BC.
        u_train   : (N_u, 1) IC + BC at X_u_train.
        X_f_train : (N_f, 3) collocation points in (x, t, log10(k)).
        X_val     : (N_val, 3) validation points in (x, t, log10(k)).
        lb, ub    : lower/upper bounds for normalisation in the NN.
    """

    # --- k and logk ---
    # We sample k in [k_min, k_max] but represent it via log10(k).
    logk_min = np.log10(k_min)
    logk_max = np.log10(k_max)
    logk_vec = np.linspace(logk_min, logk_max, N_k)  # equally spaced in log10(k)
    k_vec    = 10.0**logk_vec                        # corresponding physical k

    # --- Initial condition: u(x,0;k) = sin(pi x) ---
    # x in [x_min, x_max], t=0, and all logk samples
    x_ic = np.linspace(x_min, x_max, N_i)
    t_ic = np.array([t_min])   # or [t_min]
    x_ic_g, t_ic_g, logk_ic_g = np.meshgrid(x_ic, t_ic, logk_vec, indexing='ij')

    x_u_ic    = x_ic_g.ravel()[:, None]
    t_u_ic    = t_ic_g.ravel()[:, None]
    logk_u_ic = logk_ic_g.ravel()[:, None]
    X_u_train_ic = np.hstack([x_u_ic, t_u_ic, logk_u_ic])

    # --- Boundary conditions: u(0,t;k)=0, u(1,t;k)=0 ---
    # We discretise t with 2*N_b points between t_min and t_max.
    t_line = np.linspace(t_min, t_max, 2 * N_b)
    x_bc_left  = np.array([x_min])
    x_bc_right = np.array([x_max])
    x_bc = np.concatenate([x_bc_left, x_bc_right], axis=0)  # [0, 1]

    x_bc_g, t_bc_g, logk_bc_g = np.meshgrid(x_bc, t_line, logk_vec, indexing='ij')
    x_u_bc    = x_bc_g.ravel()[:, None]
    t_u_bc    = t_bc_g.ravel()[:, None]
    logk_u_bc = logk_bc_g.ravel()[:, None]
    X_u_train_bc = np.hstack([x_u_bc, t_u_bc, logk_u_bc])

    # --- Combine IC and BC into one "data" set ---
    X_u_train = np.vstack([X_u_train_ic, X_u_train_bc]).astype(np.float32)

    # --- Analytic solution for those IC/BC points ---
    x_cal    = X_u_train[:, 0]
    t_cal    = X_u_train[:, 1]
    logk_cal = X_u_train[:, 2]
    k_cal    = 10.0**logk_cal

    # Closed-form solution: u(x,t;k) = exp(-k π² t) sin(π x)
    u_train = np.exp(-k_cal * (np.pi**2) * t_cal) * np.sin(np.pi * x_cal)
    u_train = u_train[:, None].astype(np.float32)

    # --- Collocation + validation via LHS in (x, t, logk) ---
    lb = np.array([x_min, t_min, logk_min], dtype=np.float32)
    ub = np.array([x_max, t_max, logk_max], dtype=np.float32)

    np.random.seed(seed)
    U_all = lhs(3, samples=N_f + N_val)   # Latin Hypercube in [0,1]^3
    X_all = lb + (ub - lb) * U_all        # map to [lb, ub] in (x, t, logk)
    X_f_train = X_all[:N_f]
    X_val     = X_all[N_f:]

    return X_u_train, u_train, X_f_train, X_val, lb, ub


# -----------------------------------------------------------------------------
# 2) Compute global relative L2 error for a trained model at a fixed k
# -----------------------------------------------------------------------------
def compute_rel_L2(model,
                   x_min, x_max,
                   t_min, t_max,
                   k_val=1.0,
                   Nx=100, Nt=100,
                   device=device):
    """
    Compute global relative L2 error of the model solution against the analytic
    solution on a regular (x,t) grid at a fixed physical k = k_val.

    rel_L2 = ||u_pred - u_true||_2 / ||u_true||_2
    """

    # Build regular grid in x, t
    x_test = np.linspace(x_min, x_max, Nx)
    t_test = np.linspace(t_min, t_max, Nt)
    logk_val = np.log10(k_val)  # convert to log10(k) for NN input

    T, X = np.meshgrid(t_test, x_test, indexing='ij')  # shape (Nt, Nx)
    LOGK = np.full_like(T, logk_val)                   # broadcast log10(k_val)

    # Flatten to (Nt*Nx, 1) column vectors
    x_flat    = X.ravel()[:, None]
    t_flat    = T.ravel()[:, None]
    logk_flat = LOGK.ravel()[:, None]

    # Stack into (Nt*Nx, 3) array and convert to torch tensor
    X_star    = np.hstack([x_flat, t_flat, logk_flat]).astype(np.float32)
    X_star_tf = torch.from_numpy(X_star).to(device)

    # NN prediction over the grid
    model.eval()
    with torch.no_grad():
        u_pred = model(X_star_tf).squeeze(1).cpu().numpy().reshape(T.shape)

    # Analytic solution on the same grid
    u_true = u_true_numpy(X, T, LOGK)

    # Global relative L2 error
    num = np.linalg.norm(u_pred - u_true)
    den = np.linalg.norm(u_true)
    rel_L2 = num / den if den > 0 else num

    return rel_L2


# -----------------------------------------------------------------------------
# 3) Plot training curves (loss, val MSE, max |error|) and save to file
# -----------------------------------------------------------------------------
def plot_training_curves(loss_values, mse_v_hist, max_errors,
                         N_i, N_b, N_k, N_f,
                         out_dir="sweep_results"):
    """
    Make the TrainLoss / ValMSE / Max|Error| vs epoch plot and save to file.

    Args:
        loss_values : list of train loss per epoch
        mse_v_hist  : list of (epoch, val_MSE)
        max_errors  : list of (epoch, max_abs_error) on validation
        N_i, N_b, N_k, N_f : configuration used (for filename)
        out_dir     : directory where the PNG is saved
    """

    os.makedirs(out_dir, exist_ok=True)

    # Training loss epochs
    ep_train = range(len(loss_values))

    # Validation MSE: unpack (epoch, mse)
    ep_val  = [int(i) for i, _ in mse_v_hist]
    mse_val = [
        (m.detach().cpu().item() if torch.is_tensor(m) else float(m))
        for _, m in mse_v_hist
    ]

    # Max absolute error: unpack (epoch, max_err)
    ep_max   = [int(i) for i, _ in max_errors]
    max_errs = [
        (m.detach().cpu().item() if torch.is_tensor(m) else float(m))
        for _, m in max_errors
    ]

    # Plot curves on log scale (since errors/loss typically span many orders)
    plt.figure(figsize=(8, 6))
    plt.plot(ep_train, loss_values, color='black', label='Train Loss')
    plt.plot(ep_val,   mse_val,     color='red',   label='Validation MSE')
    plt.plot(ep_max,   max_errs,    color='blue',  label='Max |Error| (Validation)')

    plt.xlabel('Iteration')
    plt.ylabel('Loss / Error')
    plt.yscale('log')
    plt.title('Training Loss, Validation MSE, and Max Validation Error vs Iterations')
    plt.legend()
    plt.tight_layout()

    # File name tagged with the configuration
    fname = f"train_Ni{N_i}_Nb{N_b}_Nk{N_k}_Nf{N_f}.png"
    fpath = os.path.join(out_dir, fname)
    plt.savefig(fpath, dpi=300, bbox_inches='tight')
    plt.close()

    return fpath


# -----------------------------------------------------------------------------
# 4) Run a single experiment for given (N_i, N_b, N_k, N_f)
# -----------------------------------------------------------------------------
def run_single_experiment(N_i, N_b, N_k, N_f, seed,
                          N_val=100,
                          Train_epochs=100000,
                          learning_rate=5e-4,
                          k_val_eval=1.0,
                          results_dir="sweep_results"):
    """
    Runs one full experiment:
      - builds dataset for given (N_i, N_b, N_k, N_f)
      - trains a fresh model
      - saves training curve plot
      - computes global rel L2 error at k = k_val_eval
      - returns a dict with all requested info
    """

    # 1) Build dataset (IC+BC data, collocation, validation, bounds)
    X_u_train, u_train, X_f_train, X_val, lb, ub = build_dataset(
        N_i=N_i, N_b=N_b, N_k=N_k, N_f=N_f, N_val=N_val,
        x_min=x_min, x_max=x_max, t_min=t_min, t_max=t_max,
        k_min=k_min, k_max=k_max, seed=seed
    )

    # 2) Initialise a new PINN model for this dataset
    model = NeuralNet(layers, lb, ub).to(device).float()

    # 3) Train and measure total wall-clock time for this configuration
    exp_start = time.time()
    loss_values, mse_v_hist, max_errors, best_ep, best_TL, best_v, best_max_err = train(
        Train_epochs,
        X_u_train,
        u_train,
        X_f_train,
        X_val,
        model,
        learning_rate
    )
    total_elapsed = time.time() - exp_start

    # 4) Compute global relative L2 error at a chosen k (e.g. k = 1.0)
    rel_L2 = compute_rel_L2(
        model,
        x_min=x_min, x_max=x_max,
        t_min=t_min, t_max=t_max,
        k_val=k_val_eval,
        Nx=100, Nt=100,
        device=device
    )

    # 5) Create and save the training curve plot for this run
    curve_path = plot_training_curves(
        loss_values, mse_v_hist, max_errors,
        N_i=N_i, N_b=N_b, N_k=N_k, N_f=N_f,
        out_dir=results_dir
    )

    # 6) Pack all information into a record dictionary
    record = {
        "N_i": N_i,
        "N_b": N_b,
        "N_k": N_k,
        "N_f": N_f,
        "total_elapsed": total_elapsed,
        "best_ep": best_ep,
        "best_TL": best_TL,
        "best_v": best_v,
        "best_max_err": best_max_err,
        "rel_L2": rel_L2,
        "loss_values": loss_values,
        "mse_v_hist": mse_v_hist,
        "max_errors": max_errors,
        "training_curve_path": curve_path,
    }

    return record


# -----------------------------------------------------------------------------
# 5) Sweep settings and loops for N_f, N_i, N_b, N_k
# -----------------------------------------------------------------------------
layers = [3, 50, 50, 50, 1]
x_min=0.0
x_max=1.0
t_min=0.0
t_max=0.25
k_min=0.2
k_max=2.0
seed=123
    
# Base values (same as your current defaults)
BASE_N_i = 101
BASE_N_b = 51
BASE_N_k = 51
BASE_N_f = 1000

# Training hyperparameters for all sweeps
Train_epochs = 100000
learning_rate = 0.0005
N_val = 100          # number of validation points from LHS
k_val_eval = 1.0     # k at which global rel L2 is computed

# ==========================
# Sweep 1: Collocation points N_f
# ==========================
Nf_list = [500, 1000, 2000, 4000]  # values to test for N_f

results_Nf = []

for N_f in Nf_list:
    print(f"\n=== Sweep N_f = {N_f} (N_i={BASE_N_i}, N_b={BASE_N_b}, N_k={BASE_N_k}) ===")
    rec = run_single_experiment(
        N_i=BASE_N_i,
        N_b=BASE_N_b,
        N_k=BASE_N_k,
        N_f=N_f,
        seed=123,
        N_val=N_val,
        Train_epochs=Train_epochs,
        learning_rate=learning_rate,
        k_val_eval=k_val_eval,
        results_dir="sweep_Nf",
    )
    results_Nf.append(rec)

# Convert list of records to DataFrame for easy inspection and saving
df_Nf = pd.DataFrame([
    {
        "N_i": r["N_i"], "N_b": r["N_b"], "N_k": r["N_k"], "N_f": r["N_f"],
        "total_elapsed": r["total_elapsed"],
        "best_ep": r["best_ep"],
        "best_TL": r["best_TL"],
        "best_v": r["best_v"],
        "best_max_err": r["best_max_err"],
        "rel_L2": r["rel_L2"],
        # Histories stored as objects (lists) – still useful inside Python
        "loss_values": r["loss_values"],
        "mse_v_hist": r["mse_v_hist"],
        "max_errors": r["max_errors"],
        "training_curve_path": r["training_curve_path"],
    }
    for r in results_Nf
])

print("\nN_f sweep summary:")
display(df_Nf)

# Save N_f sweep summary table as CSV
df_Nf.to_csv("sweep_Nf_summary.csv", index=False)

# ==========================
# Sweep 2: IC points N_i
# ==========================
Ni_list = [21, 51, 101, 201]  # values to test for N_i

results_Ni = []

for N_i in Ni_list:
    print(f"\n=== Sweep N_i = {N_i} (N_b={BASE_N_b}, N_k={BASE_N_k}, N_f={BASE_N_f}) ===")
    rec = run_single_experiment(
        N_i=N_i,
        N_b=BASE_N_b,
        N_k=BASE_N_k,
        N_f=BASE_N_f,
        N_val=N_val,
        Train_epochs=Train_epochs,
        learning_rate=learning_rate,
        k_val_eval=k_val_eval,
        results_dir="sweep_Ni",
        seed=123
    )
    results_Ni.append(rec)

df_Ni = pd.DataFrame([
    {
        "N_i": r["N_i"], "N_b": r["N_b"], "N_k": r["N_k"], "N_f": r["N_f"],
        "total_elapsed": r["total_elapsed"],
        "best_ep": r["best_ep"],
        "best_TL": r["best_TL"],
        "best_v": r["best_v"],
        "best_max_err": r["best_max_err"],
        "rel_L2": r["rel_L2"],
        "loss_values": r["loss_values"],
        "mse_v_hist": r["mse_v_hist"],
        "max_errors": r["max_errors"],
        "training_curve_path": r["training_curve_path"],
    }
    for r in results_Ni
])

print("\nN_i sweep summary:")
display(df_Ni)

df_Ni.to_csv("sweep_Ni_summary.csv", index=False)

# ==========================
# Sweep 3: BC points N_b
# ==========================
Nb_list = [11, 21, 51, 101]  # values to test for N_b

results_Nb = []

for N_b in Nb_list:
    print(f"\n=== Sweep N_b = {N_b} (N_i={BASE_N_i}, N_k={BASE_N_k}, N_f={BASE_N_f}) ===")
    rec = run_single_experiment(
        N_i=BASE_N_i,
        N_b=N_b,
        N_k=BASE_N_k,
        N_f=BASE_N_f,
        N_val=N_val,
        Train_epochs=Train_epochs,
        learning_rate=learning_rate,
        k_val_eval=k_val_eval,
        results_dir="sweep_Nb",
        seed=123
    )
    results_Nb.append(rec)

df_Nb = pd.DataFrame([
    {
        "N_i": r["N_i"], "N_b": r["N_b"], "N_k": r["N_k"], "N_f": r["N_f"],
        "total_elapsed": r["total_elapsed"],
        "best_ep": r["best_ep"],
        "best_TL": r["best_TL"],
        "best_v": r["best_v"],
        "best_max_err": r["best_max_err"],
        "rel_L2": r["rel_L2"],
        "loss_values": r["loss_values"],
        "mse_v_hist": r["mse_v_hist"],
        "max_errors": r["max_errors"],
        "training_curve_path": r["training_curve_path"],
    }
    for r in results_Nb
])

print("\nN_b sweep summary:")
display(df_Nb)

df_Nb.to_csv("sweep_Nb_summary.csv", index=False)

# ==========================
# Sweep 4: parameter samples N_k
# ==========================
Nk_list = [11, 21, 51, 101]  # values to test for N_k

results_Nk = []

for N_k in Nk_list:
    print(f"\n=== Sweep N_k = {N_k} (N_i={BASE_N_i}, N_b={BASE_N_b}, N_f={BASE_N_f}) ===")
    rec = run_single_experiment(
        N_i=BASE_N_i,
        N_b=BASE_N_b,
        N_k=N_k,
        N_f=BASE_N_f,
        N_val=N_val,
        Train_epochs=Train_epochs,
        learning_rate=learning_rate,
        k_val_eval=k_val_eval,
        results_dir="sweep_Nk",
        seed=123
    )
    results_Nk.append(rec)

df_Nk = pd.DataFrame([
    {
        "N_i": r["N_i"], "N_b": r["N_b"], "N_k": r["N_k"], "N_f": r["N_f"],
        "total_elapsed": r["total_elapsed"],
        "best_ep": r["best_ep"],
        "best_TL": r["best_TL"],
        "best_v": r["best_v"],
        "best_max_err": r["best_max_err"],
        "rel_L2": r["rel_L2"],
        "loss_values": r["loss_values"],
        "mse_v_hist": r["mse_v_hist"],
        "max_errors": r["max_errors"],
        "training_curve_path": r["training_curve_path"],
    }
    for r in results_Nk
])

print("\nN_k sweep summary:")
display(df_Nk)

df_Nk.to_csv("sweep_Nk_summary.csv", index=False)


CUDA available? True
Device: NVIDIA A2
Using device: cuda

=== Sweep N_f = 500 (N_i=101, N_b=51, N_k=51) ===
Epochs:      0 | TrainLoss: 1.495e+01 | Val MSE: 1.680e-01 | Max Val |err|: 9.973e-01 | Time: 4.00s
Epochs:   5000 | TrainLoss: 4.602e-03 | Val MSE: 4.847e-06 | Max Val |err|: 6.737e-03 | Time: 16.39s
Epochs:  10000 | TrainLoss: 1.468e-03 | Val MSE: 7.672e-07 | Max Val |err|: 2.644e-03 | Time: 15.66s
Epochs:  15000 | TrainLoss: 6.021e-04 | Val MSE: 4.079e-07 | Max Val |err|: 2.443e-03 | Time: 15.43s
Epochs:  20000 | TrainLoss: 7.047e-04 | Val MSE: 6.228e-07 | Max Val |err|: 2.187e-03 | Time: 15.23s
Epochs:  25000 | TrainLoss: 3.348e-04 | Val MSE: 6.046e-07 | Max Val |err|: 3.278e-03 | Time: 15.39s
Epochs:  30000 | TrainLoss: 1.367e-03 | Val MSE: 8.946e-06 | Max Val |err|: 5.359e-03 | Time: 15.22s
Epochs:  35000 | TrainLoss: 1.211e-03 | Val MSE: 3.416e-06 | Max Val |err|: 4.487e-03 | Time: 15.10s
Epochs:  40000 | TrainLoss: 1.937e-04 | Val MSE: 2.540e-07 | Max Val |err|: 2.498e-0

Unnamed: 0,N_i,N_b,N_k,N_f,total_elapsed,best_ep,best_TL,best_v,best_max_err,rel_L2,loss_values,mse_v_hist,max_errors,training_curve_path
0,101,51,51,500,325.788458,99834,7.4e-05,1.336201e-07,0.001386,0.000723,"[14.947761535644531, 14.056835174560547, 13.30...","[(0, 0.16799819469451904), (1, 0.1549804210662...","[(0, 0.9973016381263733), (1, 0.95520174503326...",sweep_Nf/train_Ni101_Nb51_Nk51_Nf500.png
1,101,51,51,1000,354.457757,99862,5e-05,2.421271e-07,0.002742,0.000775,"[22.66531753540039, 18.658302307128906, 15.583...","[(0, 0.11447107046842575), (1, 0.0983603522181...","[(0, 0.8015083074569702), (1, 0.77515494823455...",sweep_Nf/train_Ni101_Nb51_Nk51_Nf1000.png
2,101,51,51,2000,348.41604,99995,6.7e-05,5.998968e-08,0.000902,0.00061,"[18.694698333740234, 15.184746742248535, 12.93...","[(0, 0.15464738011360168), (1, 0.1395625472068...","[(0, 0.9652208089828491), (1, 0.93920427560806...",sweep_Nf/train_Ni101_Nb51_Nk51_Nf2000.png
3,101,51,51,4000,430.942955,99992,4.4e-05,4.694115e-08,0.000869,0.000558,"[19.242321014404297, 16.772865295410156, 14.82...","[(0, 0.13613589107990265), (1, 0.1285412907600...","[(0, 0.6762097477912903), (1, 0.69359844923019...",sweep_Nf/train_Ni101_Nb51_Nk51_Nf4000.png



=== Sweep N_i = 21 (N_b=51, N_k=51, N_f=1000) ===
Epochs:      0 | TrainLoss: 3.195e+01 | Val MSE: 2.088e-01 | Max Val |err|: 1.116e+00 | Time: 0.00s
Epochs:   5000 | TrainLoss: 4.133e-03 | Val MSE: 8.004e-06 | Max Val |err|: 7.970e-03 | Time: 16.02s
Epochs:  10000 | TrainLoss: 1.664e-03 | Val MSE: 4.840e-06 | Max Val |err|: 6.313e-03 | Time: 15.35s
Epochs:  15000 | TrainLoss: 6.792e-04 | Val MSE: 2.441e-06 | Max Val |err|: 4.060e-03 | Time: 14.91s
Epochs:  20000 | TrainLoss: 3.861e-04 | Val MSE: 5.122e-07 | Max Val |err|: 2.911e-03 | Time: 14.95s
Epochs:  25000 | TrainLoss: 2.886e-04 | Val MSE: 4.254e-07 | Max Val |err|: 2.095e-03 | Time: 15.10s
Epochs:  30000 | TrainLoss: 2.532e-04 | Val MSE: 5.089e-07 | Max Val |err|: 2.321e-03 | Time: 14.96s
Epochs:  35000 | TrainLoss: 3.319e-04 | Val MSE: 2.535e-06 | Max Val |err|: 3.432e-03 | Time: 16.01s
Epochs:  40000 | TrainLoss: 2.300e-04 | Val MSE: 4.570e-07 | Max Val |err|: 1.850e-03 | Time: 14.72s
Epochs:  45000 | TrainLoss: 5.845e-04 | V

Unnamed: 0,N_i,N_b,N_k,N_f,total_elapsed,best_ep,best_TL,best_v,best_max_err,rel_L2,loss_values,mse_v_hist,max_errors,training_curve_path
0,21,51,51,1000,303.029927,99931,4.7e-05,9.675382e-08,0.001159,0.00062,"[31.947650909423828, 26.775436401367188, 22.19...","[(0, 0.2087847888469696), (1, 0.18390968441963...","[(0, 1.1162387132644653), (1, 1.07166874408721...",sweep_Ni/train_Ni21_Nb51_Nk51_Nf1000.png
1,51,51,51,1000,307.747723,99643,3.4e-05,6.672719e-08,0.00131,0.00047,"[18.742252349853516, 15.543389320373535, 13.08...","[(0, 0.15273211896419525), (1, 0.1351156979799...","[(0, 0.9353998303413391), (1, 0.87457603216171...",sweep_Ni/train_Ni51_Nb51_Nk51_Nf1000.png
2,101,51,51,1000,312.963876,99872,4.4e-05,1.102093e-07,0.001748,0.000622,"[23.467952728271484, 19.485658645629883, 16.33...","[(0, 0.12841792404651642), (1, 0.1118596345186...","[(0, 0.8848451972007751), (1, 0.83105754852294...",sweep_Ni/train_Ni101_Nb51_Nk51_Nf1000.png
3,201,51,51,1000,359.869567,99885,5e-05,1.63944e-07,0.002081,0.000535,"[29.66835594177246, 24.679580688476562, 20.844...","[(0, 0.11580260843038559), (1, 0.1047088503837...","[(0, 0.9207725524902344), (1, 0.84888905286788...",sweep_Ni/train_Ni201_Nb51_Nk51_Nf1000.png



=== Sweep N_b = 11 (N_i=101, N_k=51, N_f=1000) ===
Epochs:      0 | TrainLoss: 3.549e+01 | Val MSE: 1.890e-01 | Max Val |err|: 7.980e-01 | Time: 0.01s
Epochs:   5000 | TrainLoss: 2.984e-02 | Val MSE: 4.766e-05 | Max Val |err|: 2.146e-02 | Time: 31.94s
Epochs:  10000 | TrainLoss: 1.118e-02 | Val MSE: 1.621e-05 | Max Val |err|: 1.530e-02 | Time: 31.66s
Epochs:  15000 | TrainLoss: 6.212e-03 | Val MSE: 8.647e-06 | Max Val |err|: 1.230e-02 | Time: 31.21s
Epochs:  20000 | TrainLoss: 3.680e-03 | Val MSE: 4.769e-06 | Max Val |err|: 9.313e-03 | Time: 30.47s
Epochs:  25000 | TrainLoss: 2.986e-03 | Val MSE: 1.070e-05 | Max Val |err|: 1.183e-02 | Time: 30.67s
Epochs:  30000 | TrainLoss: 1.105e-02 | Val MSE: 4.369e-06 | Max Val |err|: 9.423e-03 | Time: 30.85s
Epochs:  35000 | TrainLoss: 9.386e-04 | Val MSE: 1.416e-06 | Max Val |err|: 6.274e-03 | Time: 29.70s
Epochs:  40000 | TrainLoss: 6.587e-04 | Val MSE: 1.074e-06 | Max Val |err|: 5.478e-03 | Time: 30.36s
Epochs:  45000 | TrainLoss: 6.289e-02 | 

Unnamed: 0,N_i,N_b,N_k,N_f,total_elapsed,best_ep,best_TL,best_v,best_max_err,rel_L2,loss_values,mse_v_hist,max_errors,training_curve_path
0,101,11,51,1000,603.245504,99726,8.5e-05,2.45e-07,0.002273,0.001077,"[35.49047088623047, 31.64017677307129, 28.8268...","[(0, 0.18902820348739624), (1, 0.1723163574934...","[(0, 0.7979638576507568), (1, 0.75469136238098...",sweep_Nb/train_Ni101_Nb11_Nk51_Nf1000.png
1,101,21,51,1000,595.70262,99960,3.5e-05,8.86949e-08,0.001138,0.000763,"[44.46357345581055, 36.97004699707031, 30.7512...","[(0, 0.13326984643936157), (1, 0.1177368164062...","[(0, 1.0488758087158203), (1, 0.98042356967926...",sweep_Nb/train_Ni101_Nb21_Nk51_Nf1000.png
2,101,51,51,1000,663.738427,99922,3.6e-05,1.151673e-07,0.001656,0.000483,"[44.688880920410156, 37.15054702758789, 30.641...","[(0, 0.169343501329422), (1, 0.145464271306991...","[(0, 1.0646004676818848), (1, 1.01183891296386...",sweep_Nb/train_Ni101_Nb51_Nk51_Nf1000.png
3,101,101,51,1000,571.894666,99935,8.5e-05,2.524136e-07,0.002452,0.001005,"[11.859679222106934, 9.724340438842773, 8.4731...","[(0, 0.0977526530623436), (1, 0.08888407796621...","[(0, 0.8150274157524109), (1, 0.76064717769622...",sweep_Nb/train_Ni101_Nb101_Nk51_Nf1000.png



=== Sweep N_k = 11 (N_i=101, N_b=51, N_f=1000) ===
Epochs:      0 | TrainLoss: 3.067e+01 | Val MSE: 1.941e-01 | Max Val |err|: 6.626e-01 | Time: 0.00s
Epochs:   5000 | TrainLoss: 2.305e-03 | Val MSE: 4.358e-06 | Max Val |err|: 1.028e-02 | Time: 16.72s
Epochs:  10000 | TrainLoss: 9.259e-04 | Val MSE: 2.452e-06 | Max Val |err|: 4.786e-03 | Time: 15.49s
Epochs:  15000 | TrainLoss: 4.043e-04 | Val MSE: 6.797e-07 | Max Val |err|: 4.319e-03 | Time: 15.82s
Epochs:  20000 | TrainLoss: 3.693e-04 | Val MSE: 7.742e-07 | Max Val |err|: 3.417e-03 | Time: 15.58s
Epochs:  25000 | TrainLoss: 2.378e-04 | Val MSE: 4.292e-07 | Max Val |err|: 3.098e-03 | Time: 15.48s
Epochs:  30000 | TrainLoss: 5.748e-04 | Val MSE: 2.852e-06 | Max Val |err|: 4.676e-03 | Time: 14.73s
Epochs:  35000 | TrainLoss: 1.974e-04 | Val MSE: 3.760e-07 | Max Val |err|: 3.058e-03 | Time: 14.69s
Epochs:  40000 | TrainLoss: 1.551e-04 | Val MSE: 5.418e-07 | Max Val |err|: 2.782e-03 | Time: 14.82s
Epochs:  45000 | TrainLoss: 3.233e-03 | 

Unnamed: 0,N_i,N_b,N_k,N_f,total_elapsed,best_ep,best_TL,best_v,best_max_err,rel_L2,loss_values,mse_v_hist,max_errors,training_curve_path
0,101,51,11,1000,301.514683,99645,4.6e-05,1.557201e-07,0.002076,0.00064,"[30.67354965209961, 25.916240692138672, 21.867...","[(0, 0.1940578818321228), (1, 0.16086277365684...","[(0, 0.6626036763191223), (1, 0.60622304677963...",sweep_Nk/train_Ni101_Nb51_Nk11_Nf1000.png
1,101,51,21,1000,301.428975,99947,8.2e-05,2.134577e-07,0.002097,0.000837,"[31.9571475982666, 25.611705780029297, 20.4957...","[(0, 0.13916951417922974), (1, 0.1182712540030...","[(0, 1.0757888555526733), (1, 0.99708050489425...",sweep_Nk/train_Ni101_Nb51_Nk21_Nf1000.png
2,101,51,51,1000,314.278008,99448,5.1e-05,1.319239e-07,0.002512,0.000591,"[26.78264045715332, 22.83032989501953, 19.5598...","[(0, 0.1974918395280838), (1, 0.17030608654022...","[(0, 1.0612671375274658), (1, 0.99739074707031...",sweep_Nk/train_Ni101_Nb51_Nk51_Nf1000.png
3,101,51,101,1000,371.071773,99958,4.9e-05,7.216311e-08,0.001092,0.000653,"[20.62648582458496, 17.80280113220215, 15.5296...","[(0, 0.13663625717163086), (1, 0.1185233294963...","[(0, 0.7760787606239319), (1, 0.71771323680877...",sweep_Nk/train_Ni101_Nb51_Nk101_Nf1000.png


In [2]:
# t = data['t'].flatten()[:,None] # read in t and flatten into column vector
# x = data['x'].flatten()[:,None] # read in x and flatten into column vector
#  # Exact represents the exact solution to the problem, from the data provided
# Exact = np.real(data['usol']).T # Exact has structure of nt times nx

# print("usol shape (nt, nx) = ", Exact.shape)

# # We need to find all the x,t coordinate pairs in the domain
# X, T = np.meshgrid(x,t)

# # Flatten the coordinate grid into pairs of x,t coordinates
# X_star = np.hstack((X.flatten()[:,None], T.flatten()[:,None])) # coordinates x,t
# u_star = Exact.flatten()[:,None]   # corresponding solution value with each coordinate


# print("X has shape ", X.shape, ", X_star has shape ", X_star.shape, ", u_star has shape ", u_star.shape)

# # Domain bounds (-1,1)
# lb = X_star.min(axis=0)
# ub = X_star.max(axis=0)

# print("Lower bounds of x,t: ", lb)
# print("Upper bounds of x,t: ", ub)
# print('')
# print('The first few entries of X_star are:')
# print( X_star[0:5, :] )

# print('')
# print('The first few entries of u_star are:')
# print( u_star[0:5, :] )