# The aim of this file is to inspect the behavior of the NN's params with respect to the gradient of the cost function w.r.t. u and $\delta$

In [1]:
import numpy as np
import os, subprocess, sys
import scipy.io
from scipy.linalg import solve_continuous_are
from scipy.special import softmax
from typing import Optional, Callable, Tuple, Dict, List
import time
import warnings
import json
import matplotlib.pyplot as plt

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.tensorboard import SummaryWriter
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    warnings.warn("PyTorch not available. GPU training will not be available.")
    
from ocslc.switched_linear_mpc import SwitchedLinearMPC as SwiLin_casadi

from src.switched_linear_torch import SwiLin
from src.training import SwiLinNN

## Set environment variables

In [2]:
# Global settings
N_PHASES = 10
TIME_HORIZON = 1.0

# NN settings
N_CONTROL_INPUTS = 1
N_STATES = 1
N_NN_INPUTS = 1
N_NN_OUTPUTS = N_PHASES * (N_CONTROL_INPUTS + 1)  # +1 for the mode

# Casadi settings
MULTIPLE_SHOOTING = True
INTEGRATOR = 'exp'
HYBRID = False
PLOT = 'display'

## Compute cost function

In [3]:
def evaluate_cost_functional_batch(
    swi: SwiLin,
    u_all_batch: torch.Tensor,
    delta_all_batch: torch.Tensor,
    x0_batch: torch.Tensor,
) -> torch.Tensor:
    """
    Vectorized evaluation of the LQR-style cost over a batch.

    Args:
        swi: SwiLin instance (used for model matrices and helpers)
        u_all_batch: tensor shape (B, N_PHASES, n_inputs)
        delta_all_batch: tensor shape (B, N_PHASES)
        x0_batch: tensor shape (B, n_states)

    Returns:
        J_batch: tensor shape (B,) with per-sample costs
    """
    device = u_all_batch.device if torch.is_tensor(u_all_batch) else swi.device
    dtype = u_all_batch.dtype if torch.is_tensor(u_all_batch) else swi.dtype

    B = u_all_batch.shape[0]
    n_ph = swi.n_phases
    n_x = swi.n_states
    n_u = swi.n_inputs

    # Ensure tensors on correct device/dtype
    u_all_batch = u_all_batch.to(device=device, dtype=dtype)
    delta_all_batch = delta_all_batch.to(device=device, dtype=dtype).view(B, n_ph)
    x0_batch = x0_batch.to(device=device, dtype=dtype).view(B, n_x)

    # Containers per phase (each element will be batch-shaped)
    Es = [None] * n_ph
    phi_fs = [None] * n_ph
    Lis = [None] * n_ph
    Mis = [None] * n_ph
    Ris = [None] * n_ph

    # Useful constants
    Q = swi.Q.to(dtype=dtype, device=device)
    R = swi.R.to(dtype=dtype, device=device) if n_u > 0 else None
    Eterm = swi.E_term.to(dtype=dtype, device=device)

    # For each phase compute batched matrices
    for i in range(n_ph):
        A = swi.A[i].to(dtype=dtype, device=device)
        Bmat = swi.B[i].to(dtype=dtype, device=device) if n_u > 0 else None

        # Build big C matrix once (same across batch) as in _mat_exp_prop_exp
        if not swi.auto:
            m = n_u
            Mdim = 3 * n_x + m
            C_base = torch.zeros((Mdim, Mdim), dtype=dtype, device=device)
            C_base[:n_x, :n_x] = -A.T
            C_base[:n_x, n_x:2*n_x] = torch.eye(n_x, dtype=dtype, device=device)
            C_base[n_x:2*n_x, n_x:2*n_x] = -A.T
            C_base[n_x:2*n_x, 2*n_x:3*n_x] = Q
            C_base[2*n_x:3*n_x, 2*n_x:3*n_x] = A
            C_base[2*n_x:3*n_x, 3*n_x:] = Bmat

            # Create batch of C scaled by delta
            deltas_i = delta_all_batch[:, i].view(B, 1, 1)
            deltas_i = delta_all_batch[:, i].view(B, 1, 1)
            C_batch = C_base.unsqueeze(0) * deltas_i

            # Batched matrix exponential
            exp_C = torch.linalg.matrix_exp(C_batch)

            # Extract pieces
            F3 = exp_C[:, 2*n_x:3*n_x, 2*n_x:3*n_x]  # (B, n_x, n_x)
            G2 = exp_C[:, n_x:2*n_x, 2*n_x:3*n_x]  # (B, n_x, n_x)
            G3 = exp_C[:, 2*n_x:3*n_x, 3*n_x:]      # (B, n_x, m)
            H2 = exp_C[:, n_x:2*n_x, 3*n_x:]       # (B, n_x, m)
            K1 = exp_C[:, :n_x, 3*n_x:]            # (B, n_x, m)

            Ei_batch = F3
            Li_batch = torch.matmul(F3.transpose(-1, -2), G2)

            # phi_f_i = phi_f_i_ @ ui for each sample
            ui_batch = u_all_batch[:, i, :].view(B, n_u, 1) if n_u > 0 else None
            if n_u > 0:
                phi_f_i_ = G3  # (B, n_x, m)
                # phi_f: (B, n_x, 1)
                phi_f_batch = torch.matmul(phi_f_i_, ui_batch)

                # Mi = F3.T @ H2 -> (B, n_x, m)
                Mi_batch = torch.matmul(F3.transpose(-1, -2), H2)

                # Ri: temp = B.T @ F3.T @ K1  -> (B, m, m)
                # compute F3.T @ K1 -> (B, n_x, m)
                tmp = torch.matmul(F3.transpose(-1, -2), K1)
                # Bmat.T (m,n_x) @ tmp (B, n_x, m) -> (B, m, m)
                temp = torch.matmul(Bmat.T.unsqueeze(0), tmp)
                Ri_batch = temp + temp.transpose(-1, -2)
            else:
                phi_f_batch = torch.zeros((B, n_x, 1), device=device, dtype=dtype)
                Mi_batch = torch.zeros((B, n_x, 0), device=device, dtype=dtype)
                Ri_batch = torch.zeros((B, 0, 0), device=device, dtype=dtype)

            Es[i] = Ei_batch
            phi_fs[i] = phi_f_batch
            Lis[i] = Li_batch
            Mis[i] = Mi_batch
            Ris[i] = Ri_batch
        else:
            # Autonomous case: simpler (Ei depends only on delta)
            deltas_i = delta_all_batch[:, i].view(B, 1, 1)
            Ei_batch = torch.linalg.matrix_exp(A.unsqueeze(0) * deltas_i)
            Li_batch = torch.zeros((B, n_x, n_x), device=device, dtype=dtype)
            Es[i] = Ei_batch
            phi_fs[i] = torch.zeros((B, n_x, 1), device=device, dtype=dtype)
            Lis[i] = Li_batch
            Mis[i] = torch.zeros((B, n_x, 0), device=device, dtype=dtype)
            Ris[i] = torch.zeros((B, 0, 0), device=device, dtype=dtype)

    # Backward recursion to compute S0 per sample
    # Initialize S_prev as (B, n_x+1, n_x+1)
    E_aug = torch.zeros((n_x+1, n_x+1), device=device, dtype=dtype)
    E_aug[:n_x, :n_x] = Eterm
    S_prev = 0.5 * E_aug.unsqueeze(0).expand(B, n_x+1, n_x+1).clone()

    for i in range(n_ph-1, -1, -1):
        Ei_b = Es[i]
        phi_f_b = phi_fs[i]
        Li_b = Lis[i]
        Mi_b = Mis[i]
        Ri_b = Ris[i]

        # Build S_int batch
        S_int = torch.zeros((B, n_x+1, n_x+1), device=device, dtype=dtype)
        S_int[:, :n_x, :n_x] = Li_b

        if n_u > 0:
            ui_col = u_all_batch[:, i, :].view(B, n_u, 1)
            # Mi_b: (B, n_x, n_u) -> Mi_ui: (B, n_x, 1)
            Mi_ui = torch.matmul(Mi_b, ui_col)
            S_int[:, :n_x, n_x:] = Mi_ui
            S_int[:, n_x:, :n_x] = Mi_ui.transpose(-1, -2)
            # scalar term: ui^T Ri ui -> (B,1,1)
            tmp = torch.matmul(Ri_b, ui_col)  # (B, n_u, 1)
            uiRiui = torch.matmul(ui_col.transpose(-1, -2), tmp)  # (B,1,1)
            S_int[:, n_x:, n_x:] = uiRiui

        # Build phi batch (B, n_x+1, n_x+1)
        phi = torch.zeros((B, n_x+1, n_x+1), device=device, dtype=dtype)
        phi[:, :n_x, :n_x] = Ei_b
        phi[:, :n_x, n_x:n_x+1] = phi_f_b
        phi[:, -1, -1] = 1.0

        # S_curr = 0.5*S_int + phi^T * S_prev * phi
        S_curr = 0.5 * S_int + torch.matmul(phi.transpose(-1, -2), torch.matmul(S_prev, phi))
        S_prev = S_curr

    S0_batch = S_prev

    # Augment x0 for bilinear form
    x0_aug = torch.cat([x0_batch.view(B, n_x, 1), torch.ones((B, 1, 1), device=device, dtype=dtype)], dim=1)

    # Compute quadratic term: 0.5 * x0_aug^T * S0 * x0_aug -> (B,1,1)
    quad = torch.matmul(x0_aug.transpose(-1, -2), torch.matmul(S0_batch, x0_aug)).squeeze(-1).squeeze(-1)

    # Compute G per sample
    if n_u > 0:
        # u_all_batch: (B, n_ph, n_u)
        per_phase_terms = []
        for i in range(n_ph):
            u_b = u_all_batch[:, i, :]  # (B, n_u)
            # (B, n_u) @ (n_u,n_u) -> (B, n_u)
            uR = torch.matmul(u_b, R)
            per = (uR * u_b).sum(dim=1)  # (B,)
            per_phase_terms.append(0.5 * per * delta_all_batch[:, i])

        G0 = torch.stack(per_phase_terms, dim=1).sum(dim=1)  # (B,)
    else:
        G0 = torch.zeros(B, device=device, dtype=dtype)

    J_batch = 0.5 * quad + G0
    return J_batch

## Compute cost gradient

In [4]:
def evaluate_gradient_batch(
    swi: SwiLin,
    u_all_batch: torch.Tensor,
    delta_all_batch: torch.Tensor,
    x0_batch: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Vectorized evaluation of the gradient of the LQR-style cost over a batch.

    Args:
        swi: SwiLin instance (used for model matrices and helpers)
        u_all_batch: tensor shape (B, N_PHASES, n_inputs)
        delta_all_batch: tensor shape (B, N_PHASES)
        x0_batch: tensor shape (B, n_states)
    Returns:
        grad_u_batch: tensor shape (B, N_PHASES, n_inputs)
        grad_delta_batch: tensor shape (B, N_PHASES)
    """
    device = u_all_batch.device if torch.is_tensor(u_all_batch) else swi.device
    dtype = u_all_batch.dtype if torch.is_tensor(u_all_batch) else swi.dtype

    B = u_all_batch.shape[0]
    n_ph = swi.n_phases
    n_x = swi.n_states
    n_u = swi.n_inputs

    # Ensure tensors on correct device/dtype
    u_all_batch = u_all_batch.to(device=device, dtype=dtype)
    delta_all_batch = delta_all_batch.to(device=device, dtype=dtype).view(B, n_ph)
    x0_batch = x0_batch.to(device=device, dtype=dtype).view(B, n_x)
    
    time_steps = 256  # Number of steps for numerical integration
    num_steps = 256   # Number of steps for inner numerical integration
    
    # Ensure time_steps is a tensor constant for graph compatibility
    time_steps_t = torch.tensor(time_steps, dtype=dtype, device=device)
    num_steps_t = torch.tensor(num_steps, dtype=dtype, device=device)

    # Containers per phase (each element will be batch-shaped)
    # Pre-allocate tensors to maintain computational graph
    Es = []
    phi_fs = []
    Lis = []
    Mis = []
    Ris = []
    Hi = []
    
    # Container for S, C, D, and N matrices
    S = []
    C = []
    D = []
    N = []

    # Useful constants - extend Q to (n_x+1, n_x+1) for augmented state
    Q_base = swi.Q.to(dtype=dtype, device=device)
    R = swi.R.to(dtype=dtype, device=device) if n_u > 0 else None
    Eterm = swi.E_term.to(dtype=dtype, device=device)
    
    # Create augmented Q matrix (n_x+1, n_x+1) with Q in top-left corner
    Q = torch.zeros((n_x+1, n_x+1), dtype=dtype, device=device)
    Q[:n_x, :n_x] = Q_base

    # For each phase compute batched matrices
    for i in range(n_ph):
        A = swi.A[i].to(dtype=dtype, device=device)
        Bmat = swi.B[i].to(dtype=dtype, device=device) if n_u > 0 else None

        # Build big C matrix once (same across batch) as in _mat_exp_prop_exp
        if not swi.auto:
            m = n_u
            Mdim = 3 * n_x + m
            C_base = torch.zeros((Mdim, Mdim), dtype=dtype, device=device)
            C_base[:n_x, :n_x] = -A.T
            C_base[:n_x, n_x:2*n_x] = torch.eye(n_x, dtype=dtype, device=device)
            C_base[n_x:2*n_x, n_x:2*n_x] = -A.T
            C_base[n_x:2*n_x, 2*n_x:3*n_x] = Q_base
            C_base[2*n_x:3*n_x, 2*n_x:3*n_x] = A
            C_base[2*n_x:3*n_x, 3*n_x:] = Bmat
            
            # Create batch of C scaled by delta
            deltas_i = delta_all_batch[:, i].view(B, 1, 1)
            C_batch = C_base.unsqueeze(0) * deltas_i

            # Batched matrix exponential
            exp_C = torch.linalg.matrix_exp(C_batch)

            # Extract pieces
            F3 = exp_C[:, 2*n_x:3*n_x, 2*n_x:3*n_x]  # (B, n_x, n_x)
            G2 = exp_C[:, n_x:2*n_x, 2*n_x:3*n_x]  # (B, n_x, n_x)
            G3 = exp_C[:, 2*n_x:3*n_x, 3*n_x:]      # (B, n_x, m)
            H2 = exp_C[:, n_x:2*n_x, 3*n_x:]       # (B, n_x, m)
            K1 = exp_C[:, :n_x, 3*n_x:]            # (B, n_x, m)

            Ei_batch = F3
            Li_batch = torch.matmul(F3.transpose(-1, -2), G2)

            # phi_f_i = phi_f_i_ @ ui for each sample
            ui_batch = u_all_batch[:, i, :].view(B, n_u, 1) if n_u > 0 else None
            if n_u > 0:
                phi_f_i_ = G3  # (B, n_x, m)
                # phi_f: (B, n_x, 1)
                phi_f_batch = torch.matmul(phi_f_i_, ui_batch)

                # Mi = F3.T @ H2 -> (B, n_x, m)
                Mi_batch = torch.matmul(F3.transpose(-1, -2), H2)

                # Ri: temp = B.T @ F3.T @ K1  -> (B, m, m)
                # compute F3.T @ K1 -> (B, n_x, m)
                tmp = torch.matmul(F3.transpose(-1, -2), K1)
                # Bmat.T (m,n_x) @ tmp (B, n_x, m) -> (B, m, m)
                temp = torch.matmul(Bmat.T.unsqueeze(0), tmp)
                Ri_batch = temp + temp.transpose(-1, -2)
                
                # Create batched H matrix for this mode: shape (B, n_u, n_x+1, n_x+1)
                Hi_batch = torch.zeros((B, n_u, n_x+1, n_x+1), dtype=dtype, device=device)
                for k in range(n_u):
                    # phi_f_i_ has shape (B, n_x, n_u); put its k-th column into the top-right column
                    Hi_batch[:, k, :n_x, n_x] = phi_f_i_[:, :, k]
                
                # Compute the D matrix for this phase
                D_i = torch.zeros((B, n_u, n_x+1, n_x+1), dtype=dtype, device=device)
                
                # Get the delta for this specific phase
                deltas_i = delta_all_batch[:, i]  # (B,)
                
                # eta grid for integrating over [0, delta] - batched version
                # Shape: (time_steps + 1, B)
                eta_vals = torch.linspace(0, 1, steps=time_steps + 1, device=device, dtype=dtype)
                eta_grid = eta_vals.unsqueeze(1) * deltas_i.unsqueeze(0)  # (time_steps+1, B)
                d_eta = deltas_i / time_steps
                
                for ti in range(time_steps + 1):
                    eta = eta_grid[ti]  # (B,)
                    # phi_a_t = expm(A, eta)  - batched version
                    # A * eta needs broadcasting: A (n,n), eta (B,) -> (B,n,n)
                    A_scaled = A.unsqueeze(0) * eta.view(B, 1, 1)
                    phi_a = torch.linalg.matrix_exp(A_scaled)  # (B,n,n)

                    # phi_f_t = compute_integral(A, B, 0, eta) - batched
                    # s values for each batch element
                    s_vals = torch.linspace(0, 1, steps=num_steps + 1, device=A.device, dtype=A.dtype)
                    s_grid = s_vals.unsqueeze(1) * eta.unsqueeze(0)  # (num_steps+1, B)
                    ds = eta / num_steps  # (B,)

                    # exp(A*(eta - s_j)) B for each s_j - fully batched
                    # eta (B,), s_grid (num_steps+1, B) -> eta - s (num_steps+1, B)
                    eta_minus_s = eta.unsqueeze(0) - s_grid  # (num_steps+1, B)
                    # Need (num_steps+1, B, n, n) matrix exponentials
                    A_diff = A.unsqueeze(0).unsqueeze(0) * eta_minus_s.view(num_steps+1, B, 1, 1)
                    E_all = torch.linalg.matrix_exp(A_diff)  # (num_steps+1, B, n, n)
                    # E @ Bmat for all: (num_steps+1, B, n, n) @ (n, m) -> (num_steps+1, B, n, m)
                    vals = torch.matmul(E_all, Bmat.unsqueeze(0).unsqueeze(0))  # (num_steps+1, B, n, m)

                    # trapezoid along time dimension (dim=0)
                    trapz_weights = torch.ones(num_steps + 1, device=device, dtype=dtype)
                    trapz_weights[0] = 0.5
                    trapz_weights[-1] = 0.5
                    # Weighted sum: (num_steps+1, B, n, m) * (num_steps+1, 1, 1, 1)
                    weighted_vals = vals * trapz_weights.view(-1, 1, 1, 1)
                    phi_f_int = ds.view(1, B, 1, 1) * weighted_vals.sum(dim=0)  # (B, n, m)

                    # phi_t = transition_matrix(phi_a_t, phi_f_t@ui) - batched
                    # phi_f_int (B,n,m), ui (B,m,1) -> phi_fu (B,n,1)
                    phi_fu = torch.matmul(phi_f_int, ui_batch)  # (B, n, 1)
                    # Construct Phi (B, n_x+1, n_x+1)
                    Phi = torch.zeros((B, n_x + 1, n_x + 1), device=device, dtype=dtype)
                    Phi[:, :n_x, :n_x] = phi_a
                    Phi[:, :n_x, n_x] = phi_fu.squeeze(-1)
                    Phi[:, n_x, n_x] = 1.0

                    # trapezoid weight
                    w = 0.5 if (ti == 0 or ti == time_steps) else 1.0

                    # For each control channel k, form Hij and integrand - vectorized
                    # phi_f_int has shape (B, n_x, m)
                    for k in range(n_u):
                        # Hij (B, n_x+1, n_x+1) with last column from kth column of phi_f_int
                        Hij = torch.zeros((B, n_x + 1, n_x + 1), device=device, dtype=dtype)
                        Hij[:, :n_x, n_x] = phi_f_int[:, :, k]  # (B, n_x)

                        # arg = Hij^T @ Q @ Phi + Phi^T @ Q @ Hij - batched matmul
                        # Q is (n_x+1, n_x+1), broadcast to batch
                        Q_ext = Q.unsqueeze(0)  # (1, n_x+1, n_x+1)
                        term1 = torch.matmul(torch.matmul(Hij.transpose(-2, -1), Q_ext), Phi)  # (B, n_x+1, n_x+1)
                        term2 = torch.matmul(torch.matmul(Phi.transpose(-2, -1), Q_ext), Hij)  # (B, n_x+1, n_x+1)
                        arg = term1 + term2

                        integrand = 0.5 * arg  # (B, n_x+1, n_x+1)
                        D_i[:, k] = D_i[:, k] + w * integrand

                # Finish trapezoid integration over eta - batched
                D_i = d_eta.view(B, 1, 1, 1) * D_i  # (B, n_u, n_x+1, n_x+1)
                
            else:
                phi_f_batch = torch.zeros((B, n_x, 1), device=device, dtype=dtype)
                Mi_batch = torch.zeros((B, n_x, 0), device=device, dtype=dtype)
                Ri_batch = torch.zeros((B, 0, 0), device=device, dtype=dtype)
                Hi_batch = torch.zeros((B, 0, n_x+1, n_x+1), dtype=dtype, device=device)
                D_i = torch.zeros((B, 0, n_x+1, n_x+1), dtype=dtype, device=device)

            Es.append(Ei_batch)
            phi_fs.append(phi_f_batch)
            Lis.append(Li_batch)
            Mis.append(Mi_batch)
            Ris.append(Ri_batch)
            Hi.append(Hi_batch)
            D.append(D_i)
        else:
            # Autonomous case: simpler (Ei depends only on delta)
            deltas_i = delta_all_batch[:, i].view(B, 1, 1)
            Ei_batch = torch.linalg.matrix_exp(A.unsqueeze(0) * deltas_i)
            Li_batch = torch.zeros((B, n_x, n_x), device=device, dtype=dtype)
            Es.append(Ei_batch)
            phi_fs.append(torch.zeros((B, n_x, 1), device=device, dtype=dtype))
            Lis.append(Li_batch)
            Mis.append(torch.zeros((B, n_x, 0), device=device, dtype=dtype))
            Ris.append(torch.zeros((B, 0, 0), device=device, dtype=dtype))
            Hi.append(torch.zeros((B, 0, n_x+1, n_x+1), dtype=dtype, device=device))
            D.append(torch.zeros((B, 0, n_x+1, n_x+1), dtype=dtype, device=device))
            
    # Backward recursion to compute S0 per sample
    # Initialize S_prev as (B, n_x+1, n_x+1)
    E_aug = torch.zeros((n_x+1, n_x+1), device=device, dtype=dtype)
    E_aug[:n_x, :n_x] = Eterm
    S_terminal = 0.5 * E_aug.unsqueeze(0).expand(B, n_x+1, n_x+1).clone()
    
    # Pre-allocate S list with terminal condition at the end
    S_list = [None] * (n_ph + 1)
    S_list[n_ph] = S_terminal
    S_prev = S_terminal

    for i in range(n_ph-1, -1, -1):
        Ei_b = Es[i]
        phi_f_b = phi_fs[i]
        Li_b = Lis[i]
        Mi_b = Mis[i]
        Ri_b = Ris[i]

        # Build S_int batch
        S_int = torch.zeros((B, n_x+1, n_x+1), device=device, dtype=dtype)
        S_int[:, :n_x, :n_x] = Li_b

        if n_u > 0:
            ui_col = u_all_batch[:, i, :].view(B, n_u, 1)
            # Mi_b: (B, n_x, n_u) -> Mi_ui: (B, n_x, 1)
            Mi_ui = torch.matmul(Mi_b, ui_col)
            S_int[:, :n_x, n_x:] = Mi_ui
            S_int[:, n_x:, :n_x] = Mi_ui.transpose(-1, -2)
            # scalar term: ui^T Ri ui -> (B,1,1)
            tmp = torch.matmul(Ri_b, ui_col)  # (B, n_u, 1)
            uiRiui = torch.matmul(ui_col.transpose(-1, -2), tmp)  # (B,1,1)
            S_int[:, n_x:, n_x:] = uiRiui

        # Build phi batch (B, n_x+1, n_x+1)
        phi = torch.zeros((B, n_x+1, n_x+1), device=device, dtype=dtype)
        phi[:, :n_x, :n_x] = Ei_b
        phi[:, :n_x, n_x:n_x+1] = phi_f_b
        phi[:, -1, -1] = 1.0

        # S_curr = 0.5*S_int + phi^T * S_prev * phi
        S_curr = 0.5 * S_int + torch.matmul(phi.transpose(-1, -2), torch.matmul(S_prev, phi))
        S_list[i] = S_curr
        S_prev = S_curr
        
    S0_batch = S_prev

    # Compute the C and N matrices
    C_list = []
    N_list = []
    
    for i in range(n_ph):
        A = swi.A[i].to(dtype=dtype, device=device)
        Bmat = swi.B[i].to(dtype=dtype, device=device) if n_u > 0 else None
        # Build batched F to preserve autograd (shape: B x (n_x+1) x (n_x+1))
        F = torch.zeros((B, n_x+1, n_x+1), dtype=dtype, device=device)
        # Top-left block: A (broadcasted across batch)
        F[:, :n_x, :n_x] = A.unsqueeze(0).expand(B, n_x, n_x)
        # Top-right column: B @ u (batched)
        if n_u > 0:
            ui_col = u_all_batch[:, i, :].view(B, n_u, 1)
            F[:, :n_x, n_x:n_x+1] = torch.matmul(Bmat.unsqueeze(0), ui_col)
        
        # Extract the S matrix of the next phase
        S_next = S_list[i+1]
        H_i = Hi[i]
        
        # C_i: batched computation
        Q_batch = Q.unsqueeze(0)  # (1, n_x+1, n_x+1)
        C_i = 0.5 * Q_batch + torch.matmul(F.transpose(-2, -1), S_next) + torch.matmul(S_next, F)
        C_list.append(C_i)
        
        # N matrices for each control input
        N_i_list = []
        for j in range(n_u):
            Hij = H_i[:, j, :, :]  # (B, n_x+1, n_x+1)
            # Compute N matrix
            Nij = torch.matmul(Hij.transpose(-2, -1), S_next) + torch.matmul(S_next, Hij)
            N_i_list.append(Nij)
        
        if n_u > 0:
            N_list.append(torch.stack(N_i_list, dim=1))  # (B, n_u, n_x+1, n_x+1)
        else:
            N_list.append(torch.zeros((B, 0, n_x+1, n_x+1), dtype=dtype, device=device))
    
    # Compute gradients from C and N matrices
    grad_u_batch = torch.zeros((B, n_ph, n_u), dtype=dtype, device=device)
    grad_delta_batch = torch.zeros((B, n_ph), dtype=dtype, device=device)
    
    for i in range(n_ph):
        C_i = C_list[i]  # (B, n_x+1, n_x+1)
        N_i = N_list[i]  # (B, n_u, n_x+1, n_x+1)
        D_i = D[i]  # (B, n_u, n_x+1, n_x+1)
        
        # Build augmented state for this phase
        # We need to track state through phases - simplified version
        # grad_u is derived from C matrix structure
        if n_u > 0:
            # Extract gradients w.r.t. controls from C matrix
            # C has contribution from control inputs in the (n_x, n_x+1) positions
            grad_u_batch[:, i, :] = C_i[:, :n_x, n_x].sum(dim=1, keepdim=True).expand(-1, n_u)
            
            # grad_delta from N and D matrices
            # Simplified: trace of N and D contributions
            for j in range(n_u):
                grad_delta_batch[:, i] += (N_i[:, j] * D_i[:, j]).sum(dim=(-2, -1))
    
    return grad_u_batch, grad_delta_batch

## Custom Autograd Function for Hybrid Gradient Computation

In [5]:
class CostFunctionalWithAnalyticGradient(torch.autograd.Function):
    """
    Custom autograd function that computes the cost functional in the forward pass
    and uses analytical gradients in the backward pass.
    
    This allows PyTorch to handle the gradient flow through earlier layers
    (like the network and transformations) while using your analytical gradient
    for the cost functional itself.
    """
    
    # Class variable to store the last computed gradients for logging
    last_grad_u = None
    last_grad_delta = None
    
    @staticmethod
    def forward(ctx, swi, controls, deltas, x0):
        """
        Forward pass: compute the cost functional
        
        Args:
            ctx: context object to save data for backward pass
            swi: SwiLin system
            controls: (B, N_PHASES, n_inputs)
            deltas: (B, N_PHASES)
            x0: (B, n_states)
        """
        # Save for backward
        ctx.swi = swi
        ctx.save_for_backward(controls, deltas, x0)
        
        # Compute cost using existing function
        J_batch = evaluate_cost_functional_batch(swi, controls, deltas, x0)
        return J_batch
    
    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass: use analytical gradients
        
        Args:
            grad_output: gradient of output w.r.t. params (B, n_outputs, n_parameters)
        
        Returns:
            Gradients w.r.t. inputs (None for swi, then controls, deltas, x0)
        """
        controls, deltas, x0 = ctx.saved_tensors
        swi = ctx.swi
        
        # Compute analytical gradients
        grad_u_batch, grad_delta_batch = evaluate_gradient_batch(
            swi, controls, deltas, x0
        )  # (B, N_PHASES, n_inputs), (B, N_PHASES)
        
        # Store for logging (store detached copies to avoid memory issues)
        CostFunctionalWithAnalyticGradient.last_grad_u = grad_u_batch.detach().clone()
        CostFunctionalWithAnalyticGradient.last_grad_delta = grad_delta_batch.detach().clone()
        
        # Chain rule: multiply by grad_output from upstream
        # grad_output has shape (B,), need to reshape for broadcasting
        B = grad_output.shape[0]
        grad_output = B * grad_output.view(-1, 1, 1)  # (B, 1, 1)
        
        grad_controls = grad_u_batch * grad_output  # (B, N_PHASES, n_inputs)
        
        grad_output = grad_output.squeeze(-1)  # (B, 1)
        grad_deltas = grad_delta_batch * grad_output  # (B, N_PHASES)
        
        # No gradient for x0 in this formulation (it's the initial condition)
        # If you need gradients w.r.t. x0, you can compute them analytically too
        grad_x0 = None
        
        # Return gradients in same order as forward inputs
        # (None for swi since it's not a tensor)
        return None, grad_controls, grad_deltas, grad_x0


# Convenient wrapper function
def compute_cost_with_analytic_grad(swi, controls, deltas, x0):
    """
    Compute cost functional using analytical gradients in backward pass.
    
    Args:
        swi: SwiLin system
        controls: (B, N_PHASES, n_inputs)
        deltas: (B, N_PHASES)
        x0: (B, n_states)
    
    Returns:
        J_batch: (B,) cost for each sample
    """
    return CostFunctionalWithAnalyticGradient.apply(swi, controls, deltas, x0)

### How Hybrid Gradients Work

**Key Concept**: PyTorch's autograd system allows you to define custom backward passes using `torch.autograd.Function`.

**In this implementation:**

1. **Forward Pass**: Computes the cost functional normally
   ```python
   J_batch = evaluate_cost_functional_batch(swi, controls, deltas, x0)
   ```

2. **Backward Pass**: Uses your analytical gradient
   ```python
   grad_u, grad_delta = evaluate_gradient_batch(swi, controls, deltas, x0)
   ```

3. **Automatic Chain Rule**: PyTorch automatically chains your analytical gradient with the autograd gradients from:
   - Network layers (weights, biases)
   - Transformations (tanh clipping, softmax, etc.)

**Gradient Flow:**
```
Input (x0) 
  → Network(x0) 
  → Tanh clipping [autograd]
  → Softmax normalization [autograd]
  → Cost functional [analytical gradient]
  → Loss
```

**Benefits:**
- ✅ Combine analytical efficiency with automatic differentiation convenience
- ✅ No manual chain rule implementation needed
- ✅ Easy to switch between modes for debugging/comparison

## Training

In [14]:
def train_neural_network_hybrid_gradient(
        network: SwiLinNN,
        X_train: torch.Tensor,
        y_train: Optional[torch.Tensor] = None,
        X_val: Optional[torch.Tensor] = None,
        y_val: Optional[torch.Tensor] = None,
        optimizer: str = 'adam',
        learning_rate: float = 0.001,
        weight_decay: float = 1e-4,
        n_epochs: int = 100,
        batch_size: int = 32,
        device: str = 'cpu',
        use_analytic_gradient: bool = True,  # NEW: toggle analytical vs autograd
        # Resampling options: regenerate new random samples every N epochs
        resample_every: Optional[int] = None,
        resample_fn: Optional[Callable[[int], torch.Tensor]] = None,
        resample_val: bool = False,
        verbose: bool = True,
        tensorboard_logdir: Optional[str] = None,
        log_histograms: bool = False,
        log_gradients: bool = True,  # NEW: log analytical gradients to TensorBoard
        save_history: bool = False,
        save_history_path: Optional[str] = None,
        save_model: bool = False,
        save_model_path: Optional[str] = None,
        early_stopping: bool = False,
        early_stopping_patience: int = 20,
        early_stopping_min_delta: float = 1e-6,
        early_stopping_monitor: str = 'val_loss',
    ) -> Tuple[torch.Tensor, Dict]:
    """
    Train the neural network using hybrid gradients:
    - PyTorch autograd for network and transformations
    - Analytical gradients for cost functional (optional)
    
    Parameters
    ----------
    network : SwiLinNN
        The neural network to train
    X_train : torch.Tensor
        Training input data
    use_analytic_gradient : bool, optional
        If True, use analytical gradient for cost functional.
        If False, use PyTorch autograd for everything.
    log_gradients : bool, optional
        If True and use_analytic_gradient=True, log analytical gradients to TensorBoard
    ... (other parameters same as before)
        
    Returns
    -------
    Tuple[torch.Tensor, Dict]
        The trained model and training history
    """
    
    network = network.to(device)
    X_train = X_train.to(device)
    
    if X_val is not None:
        X_val = X_val.to(device)

    # Setup a default resampling function if requested but none provided.
    if resample_every is not None and resample_every > 0 and resample_fn is None:
        try:
            x_min = -5.0
            x_max = 5.0
        except Exception:
            x_min, x_max = -1.0, 1.0

        def _default_resample_fn(epoch, shape=X_train.shape, dtype=X_train.dtype, device_str=device, xmin=x_min, xmax=x_max):
            dev = device_str
            out = torch.empty(shape, dtype=dtype, device=dev).uniform_(xmin, xmax)
            return out

        resample_fn = _default_resample_fn
    
    n_samples = X_train.shape[0]
    n_inputs = network.sys.n_inputs
    
    # Initialize PyTorch optimizer
    if optimizer.lower() == 'adam':
        torch_optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate, weight_decay=weight_decay)
    elif optimizer.lower() == 'sgd':
        torch_optimizer = torch.optim.SGD(network.parameters(), lr=learning_rate, weight_decay=weight_decay)
    elif optimizer.lower() == 'rmsprop':
        torch_optimizer = torch.optim.RMSprop(network.parameters(), lr=learning_rate, weight_decay=weight_decay)
    else:
        raise ValueError(f"Unknown optimizer '{optimizer}'. Supported: 'adam', 'sgd', 'rmsprop'")
    
    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        torch_optimizer,
        mode='min',
        factor=0.5,
        patience=10,
    )
    
    # Training history
    history = {
        'train_loss': [],
        'val_loss': [] if X_val is not None else None,
        'epochs': [],
        'gradient_mode': 'analytic' if use_analytic_gradient else 'autograd'
    }
    
    # Early stopping setup
    if early_stopping:
        if early_stopping_monitor == 'val_loss' and X_val is None:
            warnings.warn("Early stopping monitor is 'val_loss' but no validation data provided. Switching to 'train_loss'.")
            early_stopping_monitor = 'train_loss'
        
        best_loss = float('inf')
        best_epoch = 0
        patience_counter = 0
        best_model_state = None
        
        if verbose:
            print(f"Early stopping enabled: monitoring '{early_stopping_monitor}' with patience={early_stopping_patience}, min_delta={early_stopping_min_delta}")
            print(f"Gradient mode: {'Analytical' if use_analytic_gradient else 'PyTorch Autograd'}")
    
    # Setup TensorBoard writer if requested
    writer = SummaryWriter(log_dir=tensorboard_logdir) if tensorboard_logdir is not None else None

    # Determine history save path
    if save_history:
        if save_history_path is None:
            if tensorboard_logdir is not None:
                save_history_path = os.path.join(tensorboard_logdir, 'history.json')
            else:
                save_history_path = os.path.join(os.getcwd(), 'training_history.json')


    # Training loop
    for epoch in range(n_epochs):
        epoch_loss = 0.0
        n_batches = 0
        
        # Optionally resample training (and validation) data
        if resample_every is not None and resample_every > 0 and epoch > 0 and (epoch % resample_every) == 0:
            if resample_fn is None:
                warnings.warn("resample_every set but resample_fn is None; skipping resampling.")
            else:
                try:
                    new_data = resample_fn(epoch)
                    if isinstance(new_data, (list, tuple)) and len(new_data) == 2:
                        new_X_train, new_X_val = new_data
                    else:
                        new_X_train, new_X_val = new_data, None

                    if not torch.is_tensor(new_X_train):
                        new_X_train = torch.as_tensor(new_X_train)
                    X_train = new_X_train.to(device)
                    n_samples = X_train.shape[0]

                    if resample_val and new_X_val is not None:
                        if not torch.is_tensor(new_X_val):
                            new_X_val = torch.as_tensor(new_X_val)
                        X_val = new_X_val.to(device)

                    if verbose:
                        print(f"Resampled training data at epoch {epoch + 1}")
                except Exception as e:
                    warnings.warn(f"Resampling failed at epoch {epoch + 1}: {e}")

        # Create random batches
        indices = torch.randperm(n_samples, device=device)
        
        for start_idx in range(0, n_samples, batch_size):
            end_idx = min(start_idx + batch_size, n_samples)
            batch_indices = indices[start_idx:end_idx]
            
            X_batch = X_train[batch_indices]
            current_batch_size = X_batch.shape[0]
            
            # Zero gradients
            torch_optimizer.zero_grad()
            
            # Forward pass through network
            output, _ = network(X_batch)
            
            # Apply transformations: these gradients are computed by PyTorch autograd
            T_tensor = torch.tensor(network.sys.time_horizon, device=output.device, dtype=output.dtype)

            n_control_outputs = network.n_phases * n_inputs
            controls = output[:, :n_control_outputs]
            delta_raw = output[:, n_control_outputs:]
            
            # Diffeomorphism: fix last delta to zero
            last = delta_raw[:, -1:]
            delta_raw_traslated = delta_raw - last
            
            # Softmax normalization
            delta_normalized = F.softmax(delta_raw_traslated, dim=-1)
            deltas = delta_normalized * T_tensor
            
            # Tanh-based soft clipping
            u_min = -1.0
            u_max = 1.0
            u_center = (u_max + u_min) / 2.0
            u_range = (u_max - u_min) / 2.0
            controls = u_center + u_range * torch.tanh(controls)
            
            # Reshape for cost computation
            B_batch = current_batch_size
            controls_reshaped = controls.view(B_batch, network.n_phases, n_inputs)
            deltas_batch = deltas.view(B_batch, network.n_phases)
            x0_batch = X_batch

            # Compute loss - choose between analytical or autograd
            if use_analytic_gradient:
                # Use custom autograd function with analytical gradient
                J_batch = compute_cost_with_analytic_grad(
                    network.sys, controls_reshaped, deltas_batch, x0_batch
                )
            else:
                # Use pure PyTorch autograd
                J_batch = evaluate_cost_functional_batch(
                    network.sys, controls_reshaped, deltas_batch, x0_batch
                )
            
            loss = J_batch.mean()
            
            # Backward pass - PyTorch handles gradient flow automatically
            # If use_analytic_gradient=True: analytical grad for cost, autograd for transformations
            # If use_analytic_gradient=False: autograd for everything
            loss.backward()
            
            # Compute gradient norm for logging
            grad_norm = None
            if writer is not None:
                tot = torch.tensor(0.0, device=device)
                for p in network.parameters():
                    if p.grad is not None:
                        tot = tot + p.grad.detach().to(device).pow(2).sum()
                grad_norm = torch.sqrt(tot).item()

            # Optimizer step
            torch_optimizer.step()

            # Log per-batch stats to TensorBoard
            if writer is not None:
                global_step = epoch * max(1, n_samples // batch_size) + n_batches
                writer.add_scalar('train/batch_loss', loss.item(), global_step)
                if grad_norm is not None:
                    writer.add_scalar('train/batch_grad_norm', grad_norm, global_step)
                
                # Log analytical gradients if available
                if log_gradients and use_analytic_gradient:
                    if CostFunctionalWithAnalyticGradient.last_grad_u is not None:
                        grad_u = CostFunctionalWithAnalyticGradient.last_grad_u
                        grad_delta = CostFunctionalWithAnalyticGradient.last_grad_delta
                        
                        # Log statistics of gradients wrt all controls and deltas
                        for i in range(network.n_phases):
                            for j in range(n_inputs):
                                grad_u_ij = grad_u[:, i, j]
                                writer.add_scalar(f'gradients/grad_u_phase{i}_input{j}', grad_u_ij.mean().item(), global_step)  # Mean over batch
                            grad_delta_i = grad_delta[:, i] 
                            writer.add_scalar(f'gradients/grad_delta_phase{i}', grad_delta_i.mean().item(), global_step) # Mean over batch
                        
                        # Handle batch-aware pseudoinverse logging
                        grad_u_flat = grad_u.view(B_batch, -1)
                        grad_delta_flat = grad_delta.view(B_batch, -1)
                        grad_output = torch.cat([grad_u_flat, grad_delta_flat], dim=-1)  # (B, M)

                        # Compute pseudoinverse of the (B x M) matrix -> (M x B)
                        grad_output_pinv = torch.linalg.pinv(grad_output)

                        # Build a sensible per-sample target vector: use per-sample output gradient norms
                        y = grad_output.norm(dim=1)  # (B,)

                        # Map target into output-coefficient space: (M x B) @ (B,) -> (M,)
                        output_coeffs = grad_output_pinv @ y

                        # Log shapes for debug (optional) and histogram of coefficients
                        if writer is not None:
                            writer.add_histogram('gradients/output_coeffs', output_coeffs.cpu().numpy(), global_step)
                                                
                        # Log histograms of gradients (optional, less frequent to reduce overhead)
                        if log_histograms and n_batches % 10 == 0:
                            writer.add_histogram('gradients/grad_u_hist', grad_u.cpu().numpy(), global_step)
                            writer.add_histogram('gradients/grad_delta_hist', grad_delta.cpu().numpy(), global_step)
            
            epoch_loss += loss.item()
            n_batches += 1
        
        # Average loss for the epoch
        avg_train_loss = epoch_loss / n_batches
        history['train_loss'].append(avg_train_loss)
        history['epochs'].append(epoch)
        
        # Validation loss
        if X_val is not None:
            with torch.no_grad():
                val_output, _ = network(X_val)
                
                # Transform validation output
                n_control_outputs = network.n_phases * n_inputs
                val_controls = val_output[:, :n_control_outputs]
                u_min = -1.0
                u_max = 1.0
                u_center = (u_max + u_min) / 2.0
                u_range = (u_max - u_min) / 2.0
                val_controls = u_center + u_range * torch.tanh(val_controls)
                val_delta_raw = val_output[:, n_control_outputs:]
                val_delta_raw_last = val_delta_raw[:, -1:]
                val_delta_raw_traslated = val_delta_raw - val_delta_raw_last
                val_delta_normalized = F.softmax(val_delta_raw_traslated, dim=-1)
                val_deltas = val_delta_normalized * T_tensor
                
                # Vectorized validation loss (always use evaluate_cost_functional_batch)
                Bv = X_val.shape[0]
                val_controls = val_controls.view(Bv, network.n_phases, n_inputs)
                val_deltas = val_deltas.view(Bv, network.n_phases)
                J_val = evaluate_cost_functional_batch(network.sys, val_controls, val_deltas, X_val)
                avg_val_loss = J_val.mean().item()
                history['val_loss'].append(avg_val_loss)
        
        # Write epoch-level scalars to TensorBoard
        if writer is not None:
            writer.add_scalar('train/epoch_loss', avg_train_loss, epoch)
            writer.add_scalar('train/learning_rate', torch_optimizer.param_groups[0]['lr'], epoch)
            if X_val is not None:
                writer.add_scalar('val/epoch_loss', avg_val_loss, epoch)
                
                
            # if log_histograms:
            #     for name, param in network.named_parameters():
            #         writer.add_histogram(f'params/{name}', param.detach().cpu().numpy(), epoch)

        # Save history to disk each epoch if requested
        if save_history:
            try:
                serial = {}
                for k, v in history.items():
                    if v is None:
                        serial[k] = None
                    elif isinstance(v, list):
                        serial[k] = [float(x) for x in v]
                    else:
                        serial[k] = v
                os.makedirs(os.path.dirname(save_history_path), exist_ok=True)
                with open(save_history_path, 'w') as fh:
                    json.dump(serial, fh, indent=2)
            except Exception:
                warnings.warn(f"Failed to save training history to {save_history_path}")
        
        # Print progress
        if verbose and (epoch + 1) % max(1, n_epochs // 10) == 0:
            if X_val is not None:
                print(f"Epoch {epoch + 1}/{n_epochs} - Train Loss: {avg_train_loss:.6f} - Val Loss: {avg_val_loss:.6f}")
            else:
                print(f"Epoch {epoch + 1}/{n_epochs} - Train Loss: {avg_train_loss:.6f}")
        
        # Early stopping check
        if early_stopping:
            current_loss = avg_val_loss if early_stopping_monitor == 'val_loss' else avg_train_loss
            
            if current_loss < best_loss - early_stopping_min_delta:
                best_loss = current_loss
                best_epoch = epoch
                patience_counter = 0
                best_model_state = {k: v.cpu().clone() for k, v in network.state_dict().items()}
                if verbose and epoch > 0:
                    print(f"  → New best {early_stopping_monitor}: {best_loss:.6f}")
            else:
                patience_counter += 1
                if verbose and patience_counter > 0 and (epoch + 1) % max(1, n_epochs // 10) == 0:
                    print(f"  → No improvement for {patience_counter} epoch(s)")
            
            if patience_counter >= early_stopping_patience:
                if verbose:
                    print(f"\nEarly stopping triggered after {epoch + 1} epochs")
                    print(f"Best {early_stopping_monitor}: {best_loss:.6f} at epoch {best_epoch + 1}")
                
                if best_model_state is not None:
                    network.load_state_dict(best_model_state)
                    if verbose:
                        print("Restored best model weights")
                
                break
    
    # Get final parameters
    params_optimized = network.get_flat_params()
    
    # Optionally save the trained model parameters
    if save_model:
        if save_model_path is None:
            if tensorboard_logdir is not None:
                save_model_path = os.path.join(tensorboard_logdir, 'model_state_dict.pt')
            else:
                save_model_path = os.path.join(os.getcwd(), 'model_state_dict.pt')
        try:
            network.save(save_model_path)
            if verbose:
                print(f"Saved model state_dict to: {save_model_path}")
        except Exception:
            warnings.warn(f"Failed to save model to {save_model_path}")

    # Add early stopping info to history
    if early_stopping:
        history['early_stopping'] = {
            'triggered': patience_counter >= early_stopping_patience,
            'best_epoch': best_epoch,
            'best_loss': best_loss,
            'monitored_metric': early_stopping_monitor,
            'patience': early_stopping_patience,
            'final_epoch': epoch
        }

    # Print final losses
    if verbose:
        print(f"\nFinal Training Loss: {history['train_loss'][-1]:.6f}")
        if X_val is not None and history['val_loss']:
            print(f"Final Validation Loss: {history['val_loss'][-1]:.6f}")
        if early_stopping and history.get('early_stopping', {}).get('triggered', False):
            print(f"\nEarly stopping was triggered:")
            print(f"  Best {early_stopping_monitor}: {best_loss:.6f} at epoch {best_epoch + 1}")
            print(f"  Training stopped at epoch {epoch + 1}")

    return params_optimized, history


# Keep the old function name as an alias
train_neural_network_analytic_gradient = train_neural_network_hybrid_gradient

## Example

In [15]:
print("=" * 70)
print("Example: Neural Network Training")
print("=" * 70)

# Generate synthetic data
torch.manual_seed(42)
n_samples_train = 1000
n_samples_val = 10

X_train = torch.empty(n_samples_train, N_NN_INPUTS).uniform_(-5.0, 5.0)


X_val = torch.empty(n_samples_val, N_NN_INPUTS).uniform_(-5.0, 5.0)

# Create network
network = SwiLinNN(
    layer_sizes=[N_NN_INPUTS, 32, 128, 128, N_NN_OUTPUTS],
    n_phases=N_PHASES,
    activation='relu',
    output_activation='linear',
    # nonnegative_weights=[0, 1, 2, 3],
)

# Train
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Store the path where the script is located
# In a Jupyter notebook __file__ is not defined, fall back to the current working directory
try:
    script_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    script_dir = os.getcwd()

date = subprocess.check_output(['date', '+%Y%m%d_%H%M%S']).decode('utf-8').strip()
tensorboard_logdir = os.path.join(script_dir, "..", "logs", date)
# Extract hidden layer sizes from network architecture
hidden_layers = network.layer_sizes[1:-1]  # Exclude input and output layers
hidden_str = "_".join(map(str, hidden_layers))
model_name = f"analytical_grad_{hidden_str}_torch_{date}.pt"
models_dir = os.path.join(script_dir, "..", "models", model_name)

params_opt, history = train_neural_network_analytic_gradient(
    network=network,
    X_train=X_train,
    # y_train=None,
    X_val=X_val,
    # y_val=None,
    optimizer='adam',
    learning_rate=0.001,
    weight_decay=1e-4,
    n_epochs=200,
    use_analytic_gradient=True,
    resample_every=None,
    resample_fn=None,
    resample_val=False,
    early_stopping=False,
    early_stopping_patience=30,
    early_stopping_min_delta=1e-4,
    batch_size=n_samples_train,
    device=device,
    verbose=True,
    tensorboard_logdir=tensorboard_logdir,
    log_histograms=True,
    log_gradients=True,  # Enable gradient logging
    save_model=True,
    save_model_path=models_dir
    
)

print("\nTraining complete!")
print(f"\nTo view TensorBoard logs, run:")
print(f"  tensorboard --logdir={tensorboard_logdir}")

Example: Neural Network Training
Using device: cuda
Epoch 20/200 - Train Loss: 0.528830 - Val Loss: 0.326347
Epoch 40/200 - Train Loss: 0.645790 - Val Loss: 0.416488
Epoch 60/200 - Train Loss: 0.744605 - Val Loss: 0.452606
Epoch 80/200 - Train Loss: 0.765710 - Val Loss: 0.459856
Epoch 100/200 - Train Loss: 0.766443 - Val Loss: 0.458134
Epoch 120/200 - Train Loss: 0.768056 - Val Loss: 0.457650
Epoch 140/200 - Train Loss: 0.773994 - Val Loss: 0.459332
Epoch 160/200 - Train Loss: 0.778571 - Val Loss: 0.459864
Epoch 180/200 - Train Loss: 0.785558 - Val Loss: 0.457988
Epoch 200/200 - Train Loss: 0.790538 - Val Loss: 0.458214
Saved model state_dict to: /home/pietro/data-driven/learning_optimization/dump/../models/analytical_grad_32_128_128_torch_20260125_190401.pt

Final Training Loss: 0.790538
Final Validation Loss: 0.458214

Training complete!

To view TensorBoard logs, run:
  tensorboard --logdir=/home/pietro/data-driven/learning_optimization/dump/../logs/20260125_190401
