# The aim of this file is to inspect the behavior of the NN's params with respect to the gradient of the cost function w.r.t. u and $\delta$

In [None]:
import numpy as np
import os, subprocess, sys
import scipy.io
from scipy.linalg import solve_continuous_are
from scipy.special import softmax
from typing import Optional, Callable, Tuple, Dict, List
import time
import warnings
import json
import matplotlib.pyplot as plt

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.tensorboard import SummaryWriter
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    warnings.warn("PyTorch not available. GPU training will not be available.")
    
from ocslc.switched_linear_mpc import SwitchedLinearMPC as SwiLin_casadi

from src.switched_linear_torch import SwiLin
from src.training import SwiLinNN

## Set global configuration constants

In [None]:
# ---------------------------------------------------------------------------
# Problem-level constants
# ---------------------------------------------------------------------------
N_PHASES = 80
TIME_HORIZON = 10.0

# ---------------------------------------------------------------------------
# Neural-network dimensions
# ---------------------------------------------------------------------------
N_CONTROL_INPUTS = 1
N_STATES = 3
N_NN_INPUTS = 3
# Per phase the network emits every control input plus one extra value
# (the original note calls this extra slot "the mode").
N_NN_OUTPUTS = (N_CONTROL_INPUTS + 1) * N_PHASES

# ---------------------------------------------------------------------------
# CasADi-side options
# ---------------------------------------------------------------------------
MULTIPLE_SHOOTING = True
INTEGRATOR = 'exp'
HYBRID = False
PLOT = 'display'

## Compute cost gradient

In [None]:
# Gradient function
def gradient_func(params, indices=None, data=None):
    """Return the flattened gradient of the loss w.r.t. the network parameters.

    The function relies on the module-level names ``network``, ``X_train``,
    ``y_train`` and ``criterion``.

    Parameters
    ----------
    params
        Flat parameter vector, loaded into the network through
        ``network.set_flat_params``.
    indices : optional
        Index (or index tensor) selecting a mini-batch from ``X_train`` /
        ``y_train``. When ``None`` the full training set is used.
    data : optional
        Unused; kept so the call signature stays unchanged.

    Returns
    -------
    torch.Tensor
        1-D tensor concatenating ``param.grad`` of every parameter that
        received a gradient, in ``network.parameters()`` order.
    """
    network.set_flat_params(params)
    network.zero_grad()

    if indices is not None:
        X_batch, y_batch = X_train[indices], y_train[indices]
    else:
        X_batch, y_batch = X_train, y_train

    output = network(X_batch)

    # Jacobian of the (batch-summed) outputs w.r.t. the parameters: one
    # backward pass per output column, gradients cleared before each pass.
    has_output_axis = output.dim() > 1
    n_outputs = output.shape[1] if has_output_axis else 1
    jacobian_rows = []
    for i in range(n_outputs):
        network.zero_grad()
        column = output[:, i] if has_output_axis else output
        # retain_graph: the same graph is reused by the following passes.
        column.sum().backward(retain_graph=True)
        jacobian_rows.append(
            torch.cat([
                p.grad.view(-1).clone()
                for p in network.parameters()
                if p.grad is not None
            ])
        )

    # NOTE(review): the Jacobian is not used in the returned value yet; it
    # costs one backward pass per output. Shape: (n_outputs, n_params).
    jacobian = torch.stack(jacobian_rows)

    loss = criterion(output, y_batch)

    # backward() accumulates into .grad, so the gradient left over from the
    # last Jacobian row must be cleared; otherwise it is added to the loss
    # gradient returned below.
    network.zero_grad()
    loss.backward()

    return torch.cat([
        p.grad.view(-1)
        for p in network.parameters()
        if p.grad is not None
    ])

## Train Neural Network with Analytic Gradient

In [None]:
def _transform_network_output(output, n_phases, n_inputs, time_horizon, u_min=-1.0, u_max=1.0):
    """Map raw network outputs to bounded controls and scaled phase durations.

    The first ``n_phases * n_inputs`` columns are squashed into
    ``(u_min, u_max)`` with a tanh; the remaining columns go through a
    softmax and are multiplied by ``time_horizon``.

    Returns ``(controls, deltas)`` viewed as ``(B, n_phases, n_inputs)`` and
    ``(B, n_phases)`` where ``B = output.shape[0]``.
    """
    batch = output.shape[0]
    n_control_outputs = n_phases * n_inputs
    raw_controls = output[:, :n_control_outputs]
    raw_deltas = output[:, n_control_outputs:]

    # Soft clipping keeps the mapping differentiable everywhere.
    u_center = (u_max + u_min) / 2.0
    u_range = (u_max - u_min) / 2.0
    controls = u_center + u_range * torch.tanh(raw_controls)

    # Pin the LAST raw delta to zero by subtracting it from every column
    # (still differentiable) before the softmax.
    shifted_deltas = raw_deltas - raw_deltas[:, -1:]
    horizon = torch.tensor(time_horizon, device=output.device, dtype=output.dtype)
    deltas = F.softmax(shifted_deltas, dim=-1) * horizon

    return controls.view(batch, n_phases, n_inputs), deltas.view(batch, n_phases)


def _write_history_json(history, path):
    """Serialize ``history`` to ``path`` as JSON, creating the parent directory if any."""
    serial = {}
    for key, value in history.items():
        # Lists hold per-epoch numbers; cast so torch/NumPy scalars serialize.
        serial[key] = [float(x) for x in value] if isinstance(value, list) else value

    directory = os.path.dirname(path)
    if directory:
        # os.makedirs('') raises, so only create a directory when one is named.
        os.makedirs(directory, exist_ok=True)
    with open(path, 'w') as fh:
        json.dump(serial, fh, indent=2)


def train_neural_network_analytic_gradient(
        network: SwiLinNN,
        X_train: torch.Tensor,
        y_train: Optional[torch.Tensor] = None,
        X_val: Optional[torch.Tensor] = None,
        y_val: Optional[torch.Tensor] = None,
        optimizer: str = 'adam',
        learning_rate: float = 0.001,
        weight_decay: float = 1e-4,
        n_epochs: int = 100,
        batch_size: int = 32,
        device: str = 'cpu',
        # Resampling options: regenerate new random samples every N epochs
        resample_every: Optional[int] = None,
        resample_fn: Optional[Callable[[int], torch.Tensor]] = None,
        resample_val: bool = False,
        verbose: bool = True,
        tensorboard_logdir: Optional[str] = None,
        log_histograms: bool = False,
        save_history: bool = False,
        save_history_path: Optional[str] = None,
        save_model: bool = False,
        save_model_path: Optional[str] = None,
        early_stopping: bool = False,
        early_stopping_patience: int = 20,
        early_stopping_min_delta: float = 1e-6,
        early_stopping_monitor: str = 'val_loss',
    ) -> Tuple[torch.Tensor, Dict]:
    """
    Train the network by minimizing the mean batch cost functional.

    Each mini-batch of ``X_train`` is fed to ``network``; the raw output is
    turned into bounded controls and phase durations, passed together with
    the batch (used as initial states) to ``evaluate_cost_functional_batch``,
    and the mean of the returned costs is the loss. Gradients currently come
    from autograd (``loss.backward()``); see the TODO in the loop.

    Parameters
    ----------
    network : SwiLinNN
        The neural network to train. ``network.sys.n_inputs``,
        ``network.sys.time_horizon`` and ``network.n_phases`` are read.
    X_train : torch.Tensor
        Training input data (first dimension indexes samples).
    y_train : Optional[torch.Tensor], optional
        Not used by this function, by default None
    X_val : Optional[torch.Tensor], optional
        Validation input data, by default None
    y_val : Optional[torch.Tensor], optional
        Not used by this function, by default None
    optimizer : str, optional
        One of 'adam', 'sgd', 'rmsprop' (case-insensitive), by default 'adam'
    learning_rate : float, optional
        Learning rate, by default 0.001
    weight_decay : float, optional
        Weight decay passed to the optimizer, by default 1e-4
    n_epochs : int, optional
        Number of training epochs, by default 100
    batch_size : int, optional
        Batch size, by default 32
    device : str, optional
        Device to use ('cpu' or 'cuda'), by default 'cpu'
    resample_every : Optional[int], optional
        Regenerate training samples every N epochs (never at epoch 0),
        by default None
    resample_fn : Optional[Callable[[int], torch.Tensor]], optional
        Called with the epoch index; returns new training data or a
        ``(X_train, X_val)`` pair. When omitted and resampling is enabled, a
        default draws uniformly in [-5, 5] with the shape of ``X_train``.
    resample_val : bool, optional
        Whether to also replace validation data when ``resample_fn`` returns
        a pair, by default False
    verbose : bool, optional
        Whether to print training progress, by default True
    tensorboard_logdir : Optional[str], optional
        Directory for TensorBoard logs, by default None
    log_histograms : bool, optional
        Whether to log parameter histograms once per epoch, by default False
    save_history : bool, optional
        Whether to write the history to JSON after each epoch, by default False
    save_history_path : Optional[str], optional
        History file path; defaults to ``history.json`` in the TensorBoard
        directory, else ``training_history.json`` in the working directory.
    save_model : bool, optional
        Whether to call ``network.save`` after training, by default False
    save_model_path : Optional[str], optional
        Model file path; defaults to ``model_state_dict.pt`` in the
        TensorBoard directory, else in the working directory.
    early_stopping : bool, optional
        Whether to use early stopping, by default False
    early_stopping_patience : int, optional
        Epochs without improvement before stopping, by default 20
    early_stopping_min_delta : float, optional
        Minimum decrease counted as an improvement, by default 1e-6
    early_stopping_monitor : str, optional
        'val_loss' monitors the validation loss; any other value monitors the
        training loss. Falls back to 'train_loss' when no validation data is
        given. By default 'val_loss'

    Returns
    -------
    Tuple[torch.Tensor, Dict]
        The value of ``network.get_flat_params()`` after training and the
        history dict (``'train_loss'``, ``'val_loss'``, ``'epochs'`` and,
        with early stopping, ``'early_stopping'``).

    Raises
    ------
    ValueError
        If ``optimizer`` is unknown, ``batch_size`` is not positive, or
        ``X_train`` has no samples.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    network = network.to(device)
    X_train = X_train.to(device)
    if X_val is not None:
        X_val = X_val.to(device)

    n_samples = X_train.shape[0]
    if n_samples == 0:
        raise ValueError("X_train is empty; at least one training sample is required")
    n_inputs = network.sys.n_inputs

    # Fall back to a uniform resampler on fixed bounds [-5, 5] when resampling
    # is requested without a user-supplied function.
    resampling_enabled = resample_every is not None and resample_every > 0
    if resampling_enabled and resample_fn is None:
        def _default_resample_fn(epoch, shape=X_train.shape, dtype=X_train.dtype,
                                 device_str=device, xmin=-5.0, xmax=5.0):
            return torch.empty(shape, dtype=dtype, device=device_str).uniform_(xmin, xmax)

        resample_fn = _default_resample_fn

    # Initialize PyTorch optimizer
    optimizer_classes = {
        'adam': torch.optim.Adam,
        'sgd': torch.optim.SGD,
        'rmsprop': torch.optim.RMSprop,
    }
    optimizer_cls = optimizer_classes.get(optimizer.lower())
    if optimizer_cls is None:
        raise ValueError(f"Unknown optimizer '{optimizer}'. Supported: 'adam', 'sgd', 'rmsprop'")
    torch_optimizer = optimizer_cls(network.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        torch_optimizer,
        mode='min',
        factor=0.5,
        patience=10,
    )

    # Training history
    history = {
        'train_loss': [],
        'val_loss': [] if X_val is not None else None,
        'epochs': []
    }

    # Early stopping state (only consulted when early_stopping is True)
    best_loss = float('inf')
    best_epoch = 0
    patience_counter = 0
    best_model_state = None
    if early_stopping:
        if early_stopping_monitor == 'val_loss' and X_val is None:
            warnings.warn("Early stopping monitor is 'val_loss' but no validation data provided. Switching to 'train_loss'.")
            early_stopping_monitor = 'train_loss'

        if verbose:
            print(f"Early stopping enabled: monitoring '{early_stopping_monitor}' with patience={early_stopping_patience}, min_delta={early_stopping_min_delta}")

    # Setup TensorBoard writer if requested
    writer = SummaryWriter(log_dir=tensorboard_logdir) if tensorboard_logdir is not None else None

    # Determine history save path
    if save_history and save_history_path is None:
        if tensorboard_logdir is not None:
            save_history_path = os.path.join(tensorboard_logdir, 'history.json')
        else:
            save_history_path = os.path.join(os.getcwd(), 'training_history.json')

    progress_every = max(1, n_epochs // 10)
    # Running batch counter: stays unique across epochs even when the last
    # batch is partial or resampling changes the number of samples.
    global_step = 0
    # Keeps the post-loop bookkeeping well-defined when n_epochs == 0.
    epoch = -1

    try:
        # Training loop
        for epoch in range(n_epochs):
            epoch_loss = 0.0
            n_batches = 0

            # Optionally resample training (and validation) data
            if resampling_enabled and epoch > 0 and (epoch % resample_every) == 0:
                try:
                    new_data = resample_fn(epoch)
                    # support returning either X_train or (X_train, X_val)
                    if isinstance(new_data, (list, tuple)) and len(new_data) == 2:
                        new_X_train, new_X_val = new_data
                    else:
                        new_X_train, new_X_val = new_data, None

                    new_X_train = torch.as_tensor(new_X_train).to(device)
                    if new_X_train.shape[0] == 0:
                        raise ValueError("resample_fn returned an empty training set")
                    X_train = new_X_train
                    n_samples = X_train.shape[0]

                    if resample_val and new_X_val is not None:
                        X_val = torch.as_tensor(new_X_val).to(device)

                    if verbose:
                        print(f"Resampled training data at epoch {epoch + 1}")
                except Exception as e:
                    # Best effort: keep training on the previous data.
                    warnings.warn(f"Resampling failed at epoch {epoch + 1}: {e}")

            # Create random batches
            indices = torch.randperm(n_samples, device=device)

            for start_idx in range(0, n_samples, batch_size):
                batch_indices = indices[start_idx:start_idx + batch_size]
                X_batch = X_train[batch_indices]

                torch_optimizer.zero_grad()

                # Forward pass and output transformation
                output = network(X_batch)
                controls, deltas = _transform_network_output(
                    output, network.n_phases, n_inputs, network.sys.time_horizon
                )

                # Vectorized batch loss; the inputs act as initial states
                J_batch = evaluate_cost_functional_batch(network.sys, controls, deltas, X_batch)
                loss = J_batch.mean()

                # Backward pass
                # TODO: substitute this with the analytic gradient computation and the matrix from the NN
                loss.backward()

                # Gradient norm is only needed for TensorBoard logging
                grad_norm = None
                if writer is not None:
                    squared_norm = torch.tensor(0.0, device=device)
                    for p in network.parameters():
                        if p.grad is not None:
                            squared_norm = squared_norm + p.grad.detach().to(device).pow(2).sum()
                    grad_norm = torch.sqrt(squared_norm).item()

                torch_optimizer.step()

                loss_value = loss.item()
                if writer is not None:
                    writer.add_scalar('train/batch_loss', loss_value, global_step)
                    if grad_norm is not None:
                        writer.add_scalar('train/batch_grad_norm', grad_norm, global_step)

                epoch_loss += loss_value
                n_batches += 1
                global_step += 1

            # Average loss for the epoch
            avg_train_loss = epoch_loss / n_batches
            history['train_loss'].append(avg_train_loss)
            history['epochs'].append(epoch)

            # Validation loss
            if X_val is not None:
                with torch.no_grad():
                    val_output = network(X_val)
                    val_controls, val_deltas = _transform_network_output(
                        val_output, network.n_phases, n_inputs, network.sys.time_horizon
                    )
                    J_val = evaluate_cost_functional_batch(network.sys, val_controls, val_deltas, X_val)
                    avg_val_loss = J_val.mean().item()
                if history['val_loss'] is None:
                    # Validation data first appeared through resampling.
                    history['val_loss'] = []
                history['val_loss'].append(avg_val_loss)

            # Step the learning rate scheduler
            scheduler.step(avg_val_loss if X_val is not None else avg_train_loss)

            # Write epoch-level scalars to TensorBoard
            if writer is not None:
                writer.add_scalar('train/epoch_loss', avg_train_loss, epoch)
                writer.add_scalar('train/learning_rate', torch_optimizer.param_groups[0]['lr'], epoch)
                if X_val is not None:
                    writer.add_scalar('val/epoch_loss', avg_val_loss, epoch)
                if log_histograms:
                    for name, param in network.named_parameters():
                        writer.add_histogram(f'params/{name}', param.detach().cpu().numpy(), epoch)

            # Save history to disk each epoch if requested
            if save_history:
                try:
                    _write_history_json(history, save_history_path)
                except Exception as err:
                    # Don't interrupt training on save failure; warn instead
                    warnings.warn(f"Failed to save training history to {save_history_path}: {err}")

            # Print progress
            on_progress_epoch = (epoch + 1) % progress_every == 0
            if verbose and on_progress_epoch:
                if X_val is not None:
                    print(f"Epoch {epoch + 1}/{n_epochs} - Train Loss: {avg_train_loss:.6f} - Val Loss: {avg_val_loss:.6f}")
                else:
                    print(f"Epoch {epoch + 1}/{n_epochs} - Train Loss: {avg_train_loss:.6f}")

            # Early stopping check
            if early_stopping:
                current_loss = avg_val_loss if early_stopping_monitor == 'val_loss' else avg_train_loss

                if current_loss < best_loss - early_stopping_min_delta:
                    best_loss = current_loss
                    best_epoch = epoch
                    patience_counter = 0
                    # Snapshot on CPU so the copy is independent of later updates
                    best_model_state = {k: v.cpu().clone() for k, v in network.state_dict().items()}
                    if verbose and epoch > 0:
                        print(f"  → New best {early_stopping_monitor}: {best_loss:.6f}")
                else:
                    patience_counter += 1
                    if verbose and on_progress_epoch:
                        print(f"  → No improvement for {patience_counter} epoch(s)")

                if patience_counter >= early_stopping_patience:
                    if verbose:
                        print(f"\nEarly stopping triggered after {epoch + 1} epochs")
                        print(f"Best {early_stopping_monitor}: {best_loss:.6f} at epoch {best_epoch + 1}")

                    # Restore best model state
                    if best_model_state is not None:
                        network.load_state_dict(best_model_state)
                        if verbose:
                            print("Restored best model weights")

                    break
    finally:
        # Flush and release the TensorBoard event file even if training fails.
        if writer is not None:
            writer.close()

    # Get final parameters
    params_optimized = network.get_flat_params()

    # Optionally save the trained model parameters
    if save_model:
        if save_model_path is None:
            if tensorboard_logdir is not None:
                save_model_path = os.path.join(tensorboard_logdir, 'model_state_dict.pt')
            else:
                save_model_path = os.path.join(os.getcwd(), 'model_state_dict.pt')
        try:
            network.save(save_model_path)
            if verbose:
                print(f"Saved model state_dict to: {save_model_path}")
        except Exception as err:
            warnings.warn(f"Failed to save model to {save_model_path}: {err}")

    # Add early stopping info to history
    if early_stopping:
        history['early_stopping'] = {
            'triggered': patience_counter >= early_stopping_patience,
            'best_epoch': best_epoch,
            'best_loss': best_loss,
            'monitored_metric': early_stopping_monitor,
            'patience': early_stopping_patience,
            'final_epoch': epoch
        }

    # Print final losses (nothing to report when no epoch was run)
    if verbose and history['train_loss']:
        print(f"\nFinal Training Loss: {history['train_loss'][-1]:.6f}")
        if X_val is not None and history['val_loss']:
            print(f"Final Validation Loss: {history['val_loss'][-1]:.6f}")
        if early_stopping and history.get('early_stopping', {}).get('triggered', False):
            print(f"\nEarly stopping was triggered:")
            print(f"  Best {early_stopping_monitor}: {best_loss:.6f} at epoch {best_epoch + 1}")
            print(f"  Training stopped at epoch {epoch + 1}")

    return params_optimized, history