**Rough draft of the Neuromancer UTIntegrator and the sigma-point / inverse-sigma-point Nodes**

## Data

Refer to `simulators.py` and `datasets.py` to generate time-series datasets of moments and to create the sigma-points dataset. These are written for plain PyTorch, however; we are still looking for a way to handle them in Neuromancer.

To showcase the SPINODE method on existing Neuromancer examples, the code below runs PSL systems, generates Monte Carlo simulations, and then adds the moment initial conditions and rollouts to the DictDataset.

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader

# ---- Euler-Maruyama step for generic ODE_Autonomous drift: x_{k+1} = x_k + f(x_k,t_k)*dt + sigma*sqrt(dt)*eta ----
def em_rollout_system(sys, T, dt, sigma=0.0, x0=None, seed=None):
    """
    Simulate a single Euler-Maruyama trajectory with additive isotropic noise.

    Update rule: x_{k+1} = x_k + f(t_k, x_k)*dt + sigma*sqrt(dt)*eta_k,
    with eta_k drawn from a standard normal at every step.

    sys:   object exposing `equations(t, x)` (the drift) and, when x0 is None,
           an `x0` attribute (intended: an ODE_Autonomous from neuromancer.psl.base)
    T:     number of integration steps
    dt:    time step
    sigma: scalar noise std (additive, isotropic); 0 => deterministic
    x0:    initial state (array-like with nx entries); None => use sys.x0
    seed:  seed for the NumPy Generator; None => fresh entropy
    returns: ndarray [T+1, nx]
    """
    # default_rng(None) already draws fresh entropy, so one call covers both cases
    rng = np.random.default_rng(seed)

    start = np.array(sys.x0 if x0 is None else x0, dtype=float)
    nx = start.size

    traj = np.zeros((T + 1, nx), dtype=float)
    traj[0] = start

    noise_scale = sigma * np.sqrt(dt)
    t = 0.0
    for k in range(T):
        # drift evaluated at the current time and state
        drift = np.array(sys.equations(t, traj[k]), dtype=float)   # [nx]
        # a draw is taken every step, even when sigma == 0
        kick = noise_scale * rng.standard_normal(size=nx)
        traj[k + 1] = traj[k] + drift * dt + kick
        t += dt
    return traj

class DictDataset4D:
    """
    Small map-style dataset whose items are dicts of tensors.

    Every array passed in is converted to a float tensor. The intended layout is
      X:  [B, T, nx, E]
      x0: [B, 1, nx, E]  (optional convenience key)
    and the dataset length is the leading dimension of 'X'.
    """
    def __init__(self, arrays: dict, name='dataset'):
        tensors = {}
        for key, value in arrays.items():
            tensors[key] = torch.as_tensor(value).float()
        self.arrays = tensors
        self.name = name
        # number of items = leading (batch) dimension of 'X'
        self.N = tensors['X'].shape[0]

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        # index every stored tensor along its first axis
        item = {}
        for key, tensor in self.arrays.items():
            item[key] = tensor[idx]
        return item

    @staticmethod
    def collate_fn(batch_list):
        """Stack a list of item dicts into one dict with a new leading batch axis."""
        first = batch_list[0]
        return {
            key: torch.stack([item[key] for item in batch_list], dim=0)
            for key in first.keys()
        }
def get_mc_data(system_class,
                nsteps_total: int,
                nsteps_per_batch: int,
                dt: float,
                batch_size: int,
                nx: int = None,
                experiments: int = 64,
                sigma: float = 0.05,
                x0_sampler=None,
                seed: int = 0,
                return_loaders: bool = False):
    """
    Build train/dev Monte Carlo datasets holding 4-D arrays [B, T, nx, E].

    For each split, `experiments` Euler-Maruyama trajectories of length
    nsteps_total are simulated and cut into non-overlapping windows of
    nsteps_per_batch steps; leftover steps at the end are dropped.

      - system_class: callable returning a system with `x0` and `equations(t, x)`
                      (e.g., a PSL system class such as VanDerPol)
      - nsteps_total: total rollout length to generate per split
      - nsteps_per_batch: T for each dataset item (segment length)
      - dt: time step
      - batch_size: DataLoader batch size (only used when return_loaders=True)
      - nx: state dimension; None => inferred from the system's x0
      - experiments: E (number of MC particles per segment)
      - sigma: additive noise std for EM
      - x0_sampler: function (E)-> array [E, nx] for initial states; default = system default, jittered
      - seed: RNG seed; both splits draw from the same generator, so they differ
      - return_loaders: if True, wrap each split in a DataLoader

    Returns:
      (train_data, dev_data): DictDataset4D objects with
          X:  [B, T, nx, E]   where B = nsteps_total // nsteps_per_batch
          x0: [B, 1, nx, E]   first sample of each window
      or, when return_loaders=True, (train_loader, dev_loader) built from those
      datasets with DictDataset4D.collate_fn (train shuffled, dev not).

    Raises:
      ValueError: if the requested window does not fit into the rollout.
    """
    if nsteps_per_batch < 1 or nsteps_total < nsteps_per_batch:
        raise ValueError(
            "need 1 <= nsteps_per_batch <= nsteps_total, got "
            f"nsteps_per_batch={nsteps_per_batch}, nsteps_total={nsteps_total}"
        )

    rng = np.random.default_rng(seed)

    # instantiate the system (named `system` to avoid shadowing the `sys` module)
    system = system_class()
    if nx is None:
        nx = len(np.array(system.x0, dtype=float))

    # default x0 sampler: jitter around system's default
    if x0_sampler is None:
        def x0_sampler(E):
            base = np.array(system.x0, dtype=float)
            jitter = 0.05 * rng.standard_normal(size=(E, nx))
            return base[None, :] + jitter

    # --- helper: ensemble rollout to [T+1, nx, E] ---
    def ensemble_rollout(T, E):
        X_all = np.zeros((T + 1, nx, E), dtype=float)
        X0s = x0_sampler(E)                              # [E, nx]
        for e in range(E):
            # each particle gets its own seed drawn from the shared generator
            X_all[:, :, e] = em_rollout_system(system, T=T, dt=dt, sigma=sigma,
                                               x0=X0s[e], seed=rng.integers(1 << 30))
        return X_all

    # --- helper: one split as a dict-style dataset ---
    def make_split(name):
        X_all = ensemble_rollout(T=nsteps_total, E=experiments)   # [T+1, nx, E]
        nbatch = nsteps_total // nsteps_per_batch
        T_used = nbatch * nsteps_per_batch

        # non-overlapping windows [b*T : (b+1)*T] -> [B, T, nx, E]
        X_arr = X_all[:T_used].reshape(nbatch, nsteps_per_batch, nx, experiments)
        x0_arr = X_arr[:, :1]                                      # [B, 1, nx, E]
        return DictDataset4D({'X': X_arr, 'x0': x0_arr}, name=name)

    train_data = make_split('train')
    dev_data = make_split('dev')

    if not return_loaders:
        return train_data, dev_data

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True,
                              collate_fn=DictDataset4D.collate_fn)
    dev_loader = DataLoader(dev_data, batch_size=batch_size, shuffle=False,
                            collate_fn=DictDataset4D.collate_fn)
    return train_loader, dev_loader



In [None]:
# Example: build a 4-D MC dataset for VanDerPol
from types import SimpleNamespace

from neuromancer import psl
system = psl.systems['VanDerPol']

# `system` is passed as system_class; get_mc_data instantiates it itself
batch_size = 16
train_data, dev_data = get_mc_data(
    system_class=system,
    nsteps_total=2000,      # total steps per split
    nsteps_per_batch=200,   # window length T
    dt=0.01,
    batch_size=batch_size,
    experiments=64,         # E particles
    sigma=0.05,             # additive noise
    seed=42
)

# get_mc_data returns datasets, so the loaders are built here
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True,
                          collate_fn=DictDataset4D.collate_fn)
dev_loader = DataLoader(dev_data, batch_size=batch_size, shuffle=False,
                        collate_fn=DictDataset4D.collate_fn)

# Peek at one batch
batch = next(iter(train_loader))
for k, v in batch.items():
    print(k, v.shape)       # X: [B, T, nx, E], x0: [B, 1, nx, E]


## Neuromancer Codes

Please see `create_modular_ut_system`. Data flow is 
```
mu, var, u → sigma_points, W → sigma_points_next → mu_next, var_next
```




In [None]:


import torch
import torch.nn as nn
import numpy as np
from typing import Dict, List, Optional, Tuple
from torch.utils.data import Dataset, DataLoader
from dataclasses import dataclass

import neuromancer as nm
from neuromancer.system import Node, System
from neuromancer.modules import blocks
from neuromancer.dataset import DictDataset
from neuromancer.constraint import variable
from neuromancer.loss import PenaltyLoss
from neuromancer.problem import Problem
from neuromancer.trainer import Trainer
from neuromancer.dynamics.integrators import DiffEqIntegrator

# ==============================================================================
# UTILITY FUNCTIONS
# ==============================================================================

def cholesky_psd(A: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Return the Cholesky factor of `A + eps * I`.

    The diagonal jitter `eps` is added before factorisation; the identity is
    created on A's device with A's dtype and sized from A's last dimension.
    """
    n = A.shape[-1]
    jitter = eps * torch.eye(n, device=A.device, dtype=A.dtype)
    return torch.linalg.cholesky(A + jitter)

def ut_weights_torch(n: int, alpha=1e-3, beta=2.0, kappa=0.0,
                     device="cpu", dtype=torch.float32) -> Tuple[float, torch.Tensor, torch.Tensor]:
    """Compute unscented-transform weights for an n-dimensional variable.

    Returns (lam, Wm, Wc) where lam = alpha^2 * (n + kappa) - n and Wm / Wc are
    length-(2n+1) tensors of mean and covariance weights. Entry 0 belongs to
    the centre point; the remaining 2n entries share the value 1 / (2 (n + lam)).
    """
    lam = alpha**2 * (n + kappa) - n
    num_points = 2 * n + 1

    # start from zeros, fill the 2n symmetric-point weights, then set the centre
    Wm = torch.zeros(num_points, device=device, dtype=dtype)
    Wm[1:] = 1.0/(2*(n+lam))
    Wc = Wm.clone()

    Wm[0] = lam/(n+lam)
    # the covariance weight of the centre carries the extra (1 - alpha^2 + beta) term
    Wc[0] = lam/(n+lam) + (1 - alpha**2 + beta)
    return lam, Wm, Wc

# ==============================================================================
# UTINTEGRATOR (YOUR CODE)
# ==============================================================================



class UTIntegrator(DiffEqIntegrator):
    """DiffEqIntegrator extension that propagates unscented-transform sigma points.

    Each sigma point packs, along its last axis, the columns
    [x (nx) | u (nu) | w (nx, only when include_diffusion)]. The state part is
    advanced with the base-class integrator and, if diffusion is enabled, the
    noise columns are added as sigma_val * sqrt(h) * w.
    """

    def __init__(self,
                 block_rhs: nn.Module,
                 dt: float = 0.01,
                 method: str = 'dopri5',
                 rtol: float = 1e-5,
                 atol: float = 1e-7,
                 sigma_val: float = 0.1,
                 include_diffusion: bool = True,
                 nx: int = 1,
                 nu: int = 1,
                 interp_u=None):
        """
        Args:
            block_rhs: right-hand-side module handed to the base integrator as `block`
            dt: integration step, forwarded as `h`
            method, rtol, atol: solver settings forwarded to the base class
            sigma_val: diffusion strength (scalar, or one value per state)
            include_diffusion: whether sigma points carry noise columns
            nx: state dimension
            nu: control dimension (0 => no control columns)
            interp_u: forwarded unchanged to the base class
        """
        # NOTE(review): confirm that DiffEqIntegrator.__init__ accepts rtol/atol
        # in the installed neuromancer version.
        super().__init__(block=block_rhs, interp_u=interp_u, h=dt, method=method, rtol=rtol, atol=atol)

        self.sigma_val = sigma_val
        self.include_diffusion = include_diffusion
        self.nx = nx
        self.nu = nu

    def integrate(self, sigma_points: torch.Tensor, *args) -> torch.Tensor:
        """Advance sigma points [B, S, D] by one step and return states [B, S, nx].

        Extra positional arguments are accepted and ignored.

        Raises:
            ValueError: if sigma_points is not 3-D or its last axis is too short
                to hold the state, control and (optional) noise columns.
        """
        if sigma_points.dim() != 3:
            raise ValueError(
                f"sigma_points must have shape [B, S, D], got {tuple(sigma_points.shape)}"
            )
        B, S, D = sigma_points.shape
        device, dtype = sigma_points.device, sigma_points.dtype

        # Without this check a missing noise block yields an empty slice that
        # broadcasts silently into an empty result.
        needed = self.nx + self.nu + (self.nx if self.include_diffusion else 0)
        if D < needed:
            raise ValueError(
                f"sigma_points last dimension is {D}, but nx={self.nx}, nu={self.nu}, "
                f"include_diffusion={self.include_diffusion} require at least {needed}"
            )

        ptr = 0
        x = sigma_points[..., ptr:ptr + self.nx]  # [B, S, nx]
        ptr += self.nx
        u = sigma_points[..., ptr:ptr + self.nu] if self.nu > 0 else None  # [B, S, nu] or None
        ptr += self.nu
        w = sigma_points[..., ptr:ptr + self.nx] if self.include_diffusion else None  # [B, S, nx] or None

        # Flatten batch and sigma-point axes so the base integrator sees a 2-D batch
        x_flat = x.reshape(-1, self.nx)  # [B*S, nx]
        u_flat = u.reshape(-1, self.nu) if u is not None else None  # [B*S, nu] or None

        if u_flat is not None:
            x_next_flat = super().integrate(x_flat, u_flat)
        else:
            x_next_flat = super().integrate(x_flat)
        # reshape (not view): the solver output is not guaranteed to be contiguous
        x_next = x_next_flat.reshape(B, S, self.nx)  # [B, S, nx]

        if self.include_diffusion:
            sigma_val = torch.as_tensor(self.sigma_val, device=device, dtype=dtype)
            if sigma_val.ndim == 0:
                sigma_val = sigma_val.repeat(self.nx)
            g2 = 0.5 * (sigma_val ** 2)  # [nx]
            sqrt_term = torch.sqrt(2.0 * g2 * self.h)  # [nx], equals |sigma_val| * sqrt(h)
            x_next = x_next + sqrt_term[None, None, :] * w  # [B, S, nx]

        return x_next  # [B, S, nx]

# ==============================================================================
# SIGMA POINT NODES
# ==============================================================================

class SigmaPointGeneratorNode(nn.Module):
    """Generate unscented-transform sigma points from a mean and a diagonal variance.

    With include_diffusion=True the state is augmented with unit-variance noise
    (one noise dimension per state dimension), giving S = 4*nx + 1 points laid
    out as [x (nx) | u (nu) | w (nx)]. With include_diffusion=False the mean is
    simply copied S = 2*nx + 1 times and laid out as [x (nx) | u (nu)].

    The module can be wrapped in a neuromancer Node (called with positional
    tensors) or called directly with a dict that is read through `input_keys`.
    """

    def __init__(self,
                 include_diffusion: bool = True,
                 alpha: float = 1e-3,
                 beta: float = 2.0,
                 kappa: float = 0.0,
                 input_keys: Optional[List[str]] = None,
                 output_keys: Optional[List[str]] = None):
        """
        Args:
            include_diffusion: append noise sigma-point columns
            alpha, beta, kappa: unscented-transform scaling parameters
            input_keys: dict keys for (mean, variance, control); default ['mu', 'var', 'u']
            output_keys: names of the two outputs; default ['sigma_points', 'W']
        """
        # nn.Module.__init__ takes no positional arguments; keys are plain attributes
        super().__init__()
        self.input_keys = list(input_keys) if input_keys is not None else ['mu', 'var', 'u']
        self.output_keys = list(output_keys) if output_keys is not None else ['sigma_points', 'W']
        self.include_diffusion = include_diffusion
        self.alpha = alpha
        self.beta = beta
        self.kappa = kappa

    def aug_sigma_1d_xw(self, mu: torch.Tensor, var: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate augmented sigma points for one scalar state plus scalar unit noise.

        mu and var are scalar tensors; returns (S [5, 2], Wm [5]) where column 0
        of S is the state and column 1 the noise.
        """
        device, dtype = mu.device, mu.dtype
        nx = nw = 1; n = nx + nw
        lam, Wm, _ = ut_weights_torch(n, self.alpha, self.beta, self.kappa, device, dtype)

        Saug = torch.zeros(n, n, device=device, dtype=dtype)
        Saug[0, 0] = var  # state variance
        Saug[1, 1] = 1.0  # unit noise variance

        L = cholesky_psd((n + lam) * Saug)
        S = torch.zeros(2*n+1, 2, device=device, dtype=dtype)

        S[0, 0] = mu; S[0, 1] = 0.0
        for i in range(n):
            S[1+i,    0] = mu + L[0, i]; S[1+i,    1] =  L[1, i]
            S[1+n+i,  0] = mu - L[0, i]; S[1+n+i,  1] = -L[1, i]

        return S, Wm

    def det_sigma_copy(self, mu_vec: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate deterministic sigma points: 2*nx+1 copies of the mean.

        mu_vec is [..., nx]; returns (S [..., 2*nx+1, nx], Wm [2*nx+1]).
        """
        nx = mu_vec.shape[-1]
        _, Wm, _ = ut_weights_torch(nx, self.alpha, self.beta, self.kappa,
                                  mu_vec.device, mu_vec.dtype)
        S = mu_vec.unsqueeze(-2).repeat([1] * (mu_vec.dim()-1) + [2*nx+1, 1])
        return S, Wm

    def _aug_sigma_diag(self, mu: torch.Tensor, var: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Batched augmented sigma points for diagonal state covariance and unit noise.

        mu, var: [B, nx]. Returns (S [B, 4*nx+1, 2*nx], Wm [4*nx+1]); the first
        nx columns of S are state values, the last nx are noise values. For
        nx == 1 the points match aug_sigma_1d_xw in value and order.
        """
        nx = mu.shape[-1]
        n = 2 * nx
        lam, Wm, _ = ut_weights_torch(n, self.alpha, self.beta, self.kappa, mu.device, mu.dtype)

        diag = torch.cat([var, torch.ones_like(var)], dim=-1)                 # [B, n]
        L = cholesky_psd((n + lam) * torch.diag_embed(diag))                  # [B, n, n]

        center = torch.cat([mu, torch.zeros_like(mu)], dim=-1).unsqueeze(1)   # [B, 1, n]
        offsets = L.transpose(-1, -2)   # row i holds column i of L
        S = torch.cat([center, center + offsets, center - offsets], dim=1)    # [B, 2n+1, n]
        return S, Wm

    def forward(self, data, var=None, u=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build sigma points and their mean weights.

        Args:
            data: either the mean tensor [B, nx] (Node-style positional call) or
                a dict from which mean, variance and control are read via input_keys
            var: variance [B, nx]; required when include_diffusion is True
            u: control [B, nu]; None => no control columns

        Returns:
            (sigma_points [B, S, D], W [B, S])

        Raises:
            ValueError: if include_diffusion is True and no variance is given.
        """
        if isinstance(data, dict):
            mu = data[self.input_keys[0]]
            var = data[self.input_keys[1]]
            u = data[self.input_keys[2]] if len(self.input_keys) > 2 else None
        else:
            mu = data

        batch_size, nx = mu.shape[0], mu.shape[-1]

        if self.include_diffusion:
            if var is None:
                raise ValueError("var is required when include_diffusion=True")
            points, Wm = self._aug_sigma_diag(mu, var)       # [B, S, 2*nx]
            x_sigma, w_sigma = points[..., :nx], points[..., nx:]
        else:
            x_sigma, Wm = self.det_sigma_copy(mu)            # [B, S, nx]
            w_sigma = None

        num_points = x_sigma.shape[1]
        columns = [x_sigma]
        if u is not None:
            # the same control is attached to every sigma point of a sample
            columns.append(u.unsqueeze(1).expand(-1, num_points, -1))
        if w_sigma is not None:
            columns.append(w_sigma)

        sigma_points = torch.cat(columns, dim=-1)            # [B, S, D]
        W = Wm.unsqueeze(0).repeat(batch_size, 1)            # [B, S]
        return sigma_points, W
        



class InverseSigmaPointNode(nn.Module):
    """Compute mean (and optionally variance) from weighted sigma points.

    Works in two calling styles:
      * dict in -> dict out, using `input_keys` / `output_keys`;
      * positional tensors in -> tensor (or tuple of tensors) out, which is the
        form needed when the module is wrapped in a neuromancer Node.
    """

    def __init__(self,
                 compute_variance: bool = False,
                 input_keys: Optional[List[str]] = None,
                 output_keys: Optional[List[str]] = None):
        """
        Args:
            compute_variance: also return the weighted variance
            input_keys: dict keys for (sigma points, weights);
                default ['sigma_points_next', 'W']
            output_keys: dict keys for the results; default ['mu_next'], with
                'var_next' appended when compute_variance is True and no second
                key was supplied
        """
        # nn.Module.__init__ takes no positional arguments; keys are plain attributes
        super().__init__()
        self.input_keys = list(input_keys) if input_keys is not None else ['sigma_points_next', 'W']
        keys = list(output_keys) if output_keys is not None else ['mu_next']
        if compute_variance and len(keys) < 2:
            keys.append('var_next')
        self.output_keys = keys
        self.compute_variance = compute_variance

    def forward(self, data, W=None):
        """Reduce sigma points to moments.

        Args:
            data: sigma points [B, S, nx], or a dict holding sigma points and
                weights under `input_keys`
            W: weights [B, S]; required for the tensor calling style

        Returns:
            dict call: {output_keys[0]: mu_next, (output_keys[1]: var_next)}
            tensor call: mu_next [B, nx], or (mu_next, var_next) when
            compute_variance is True

        Raises:
            ValueError: if called with a tensor but without weights.
        """
        called_with_dict = isinstance(data, dict)
        if called_with_dict:
            sigma_points_next = data[self.input_keys[0]]  # [B, S, nx]
            W = data[self.input_keys[1]]                  # [B, S]
        else:
            sigma_points_next = data
            if W is None:
                raise ValueError("W is required when sigma points are passed as a tensor")

        # Weighted mean over the sigma-point axis
        W_expanded = W.unsqueeze(-1)  # [B, S, 1]
        mu_next = (W_expanded * sigma_points_next).sum(dim=1)  # [B, nx]

        var_next = None
        if self.compute_variance:
            # NOTE(review): the same weights are used for mean and variance; the
            # unscented transform normally uses separate covariance weights - confirm.
            diff = sigma_points_next - mu_next.unsqueeze(1)  # [B, S, nx]
            var_next = (W_expanded * diff**2).sum(dim=1)     # [B, nx]

        if called_with_dict:
            result = {self.output_keys[0]: mu_next}
            if var_next is not None:
                result[self.output_keys[1]] = var_next
            return result

        return mu_next if var_next is None else (mu_next, var_next)

# ==============================================================================
# SYSTEM CREATION
# ==============================================================================

def create_modular_ut_system(
    dynamics_net: nn.Module,
    nx: int = 1,
    nu: int = 1,
    dt: float = 0.01,
    sigma_val: float = 0.1,
    include_diffusion: bool = True,
    compute_variance: bool = False,
    method: str = 'dopri5',
    rtol: float = 1e-5,
    atol: float = 1e-7
) -> System:
    """
    Create complete modular UT system.

    Three nodes are chained:
        mu, var, u -> sigma_points, W -> sigma_points_next -> mu, [var]

    The reconstructed moments are written back under the keys 'mu' (and 'var'
    when compute_variance is True).
    NOTE(review): reusing the input keys is presumably meant for multi-step
    System rollouts - confirm against the training setup.

    Args:
        dynamics_net: Neural network f(x, u) -> dx, used as the integrator's right-hand side
        nx: State dimension
        nu: Control dimension
        dt: Integration time step
        sigma_val: Diffusion strength
        include_diffusion: Whether to include stochastic diffusion
        compute_variance: Whether to compute output variance
        method: ODE integration method
        rtol: Relative tolerance
        atol: Absolute tolerance

    Returns:
        Complete System: mu, var, u -> mu, [var]
    """
    # Node 1: Generate sigma points
    sigma_gen_func = SigmaPointGeneratorNode(include_diffusion=include_diffusion)
    sigma_gen_node = Node(sigma_gen_func,
                          input_keys=['mu', 'var', 'u'],
                          output_keys=['sigma_points', 'W'],
                          name='sigma_generator')

    # Node 2: Apply dynamics with UTIntegrator (the caller's network is the RHS)
    ut_integrator = UTIntegrator(
        block_rhs=dynamics_net,
        dt=dt,
        method=method,
        rtol=rtol,
        atol=atol,
        sigma_val=sigma_val,
        include_diffusion=include_diffusion,
        nx=nx,
        nu=nu
    )
    # 'W' is passed along with the sigma points; UTIntegrator.integrate ignores extras
    integrator_node = Node(ut_integrator,
                           input_keys=['sigma_points', 'W'],
                           output_keys=['sigma_points_next'],
                           name='ut_integrator')

    # Node 3: Reconstruct moments (wrap the instance, not the class)
    inverse_sigma = InverseSigmaPointNode(compute_variance=compute_variance)
    moment_keys = ['mu', 'var'] if compute_variance else ['mu']
    inverse_sigma_node = Node(inverse_sigma,
                              input_keys=['sigma_points_next', 'W'],
                              output_keys=moment_keys,
                              name='inverse_sigma')

    # Assemble system
    nodes = [sigma_gen_node, integrator_node, inverse_sigma_node]
    system = System(nodes, name='modular_ut_system')

    return system

# ==============================================================================
# CONFIGURATION
# ==============================================================================

@dataclass
class UTSystemConfig:
    """Configuration for UT system.

    Plain container of defaults, grouped into system, integration, network and
    training settings. Nothing in this class validates or consumes the values;
    how each field is used is up to the code that reads the config.
    """
    # System parameters
    nx: int = 1              # State dimension
    nu: int = 1              # Control dimension
    dt: float = 0.01         # Time step
    sigma_val: float = 0.1   # Diffusion strength
    include_diffusion: bool = True   # Whether stochastic diffusion is included
    compute_variance: bool = False   # Whether the output variance is computed
    
    # Integration parameters
    method: str = 'dopri5'   # ODE method
    rtol: float = 1e-5       # Relative tolerance
    atol: float = 1e-7       # Absolute tolerance
    
    # Network parameters
    hidden_dim: int = 64     # Hidden layer size
    n_layers: int = 2        # Number of hidden layers
    
    # Training parameters
    batch_size: int = 128    # Samples per training batch
    epochs: int = 20         # Number of training epochs
    lr: float = 1e-3         # Learning rate
    train_frac: float = 0.7  # presumably the fraction of data used for training - confirm
    val_frac: float = 0.15   # presumably the fraction of data used for validation - confirm
    device: str = "cpu"      # Device string, e.g. "cpu"
