In [None]:
!pip install optax yfinance lxml plotly hmmlearn "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [None]:
import os
import jax
import jax.numpy as jnp
from jax import grad, vmap, jit
import numpy as np
import optax
from functools import partial
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Callable, Any
import pandas as pd
from scipy.optimize import minimize

# Module-level JAX PRNG key with a fixed seed (42) for reproducibility.
# NOTE(review): none of the definitions visible in this part of the file read
# this global; they each take their own `key`/`seed` argument — confirm it is
# consumed further down before relying on it.
key = jax.random.PRNGKey(42)

# KAN Layer implementation for portfolio optimization
class KANLayer:
    """Affine layer followed by per-unit learnable piecewise-linear activations.

    Every output unit owns a table of activation values sampled on a shared,
    evenly spaced grid. The activation is evaluated by linear interpolation
    between the two grid points that bracket the (clipped) pre-activation.

    Attributes:
        weights: (input_dim, output_dim) matrix of the affine map.
        biases: (output_dim,) bias vector.
        grid_points: (num_basis,) grid on which the activations are tabulated.
        activations: (output_dim, num_basis) table of activation values.
        domain: (low, high) range the pre-activations are clipped to.
    """

    def __init__(self, input_dim: int, output_dim: int, num_basis: int = 30,
                 domain=(-3.0, 3.0), key=None):
        """Initialize a KAN layer with learnable activation functions.

        Args:
            input_dim: Input dimension
            output_dim: Output dimension
            num_basis: Number of grid points used to tabulate each activation
            domain: Domain over which the activation functions are defined
            key: JAX random key; defaults to ``jax.random.PRNGKey(0)``
        """
        if key is None:
            key = jax.random.PRNGKey(0)

        key1, key2, key3 = jax.random.split(key, 3)

        # Small random affine map.
        self.weights = jax.random.normal(key1, (input_dim, output_dim)) * 0.1
        self.biases = jax.random.normal(key2, (output_dim,)) * 0.01

        # Grid points for activation function representation
        self.grid_points = jnp.linspace(domain[0], domain[1], num_basis)
        grid = self.grid_points

        activations_list = []
        for i in range(output_dim):
            # Use independent keys for the shape choice and for the noise:
            # consuming one key for both draws would correlate them.
            shape_key, noise_key = jax.random.split(jax.random.fold_in(key3, i))
            init_type = int(jax.random.randint(shape_key, (), 0, 4))

            if init_type == 0:  # Linear-like
                act = grid
            elif init_type == 1:  # ReLU-like
                act = jnp.maximum(0, grid)
            elif init_type == 2:  # Sigmoid-like
                act = 1.0 / (1.0 + jnp.exp(-grid))
            else:  # Tanh-like
                act = jnp.tanh(grid)

            # Add noise to break symmetry between units with the same shape
            act = act + jax.random.normal(noise_key, (num_basis,)) * 0.05
            activations_list.append(act)

        # Stack into a matrix: (output_dim, num_basis)
        self.activations = jnp.stack(activations_list)

        # Store domain for clipping
        self.domain = domain

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Forward pass through the KAN layer.

        Args:
            x: Input tensor of shape (batch_size, input_dim)

        Returns:
            Output tensor of shape (batch_size, output_dim)
        """
        # Affine map, then clip into the tabulated domain.
        z = jnp.dot(x, self.weights) + self.biases
        z_clipped = jnp.clip(z, self.domain[0], self.domain[1])

        grid = self.grid_points
        # Index of the left grid point of the bracketing interval, kept in
        # [0, num_basis - 2] so that `idx + 1` is always a valid grid index.
        idx = jnp.clip(jnp.searchsorted(grid, z_clipped) - 1, 0, len(grid) - 2)

        x0 = grid[idx]
        x1 = grid[idx + 1]

        # Gather each unit's own table values for all batch entries at once:
        # `unit` broadcasts against `idx`, so entry [b, j] reads activations[j, idx[b, j]].
        unit = jnp.arange(self.activations.shape[0])
        y0 = self.activations[unit, idx]
        y1 = self.activations[unit, idx + 1]

        # Linear interpolation inside the bracketing interval.
        t = (z_clipped - x0) / (x1 - x0)
        return y0 + t * (y1 - y0)
    
    

# Portfolio Optimization KAN model
class PortfolioKAN:
    """Stack of KAN layers mapping feature vectors to portfolio weights.

    The final layer's outputs are passed through a softmax, so every row of
    the result is non-negative and sums to one.

    Attributes:
        layers: List of hidden KANLayer instances, applied in order.
        output_layer: KANLayer producing one raw score per asset.
    """

    def __init__(self, input_dim: int, output_dim: int, hidden_dims: List[int] = None,
                 num_basis: int = 30, domain=(-3.0, 3.0), key=None):
        """Initialize a KAN model for portfolio optimization.

        Args:
            input_dim: Input dimension (market factors, asset features)
            output_dim: Output dimension (portfolio weights)
            hidden_dims: Hidden layer widths; ``None`` means ``[64, 32]``
            num_basis: Number of basis functions for learned activations
            domain: Domain for activation functions
            key: JAX random key; defaults to ``jax.random.PRNGKey(0)``
        """
        if key is None:
            key = jax.random.PRNGKey(0)

        # A None sentinel avoids a mutable list shared between calls; the
        # caller's sequence is copied so it is never aliased by the model.
        hidden_dims = [64, 32] if hidden_dims is None else list(hidden_dims)

        # One key per hidden layer plus one for the output layer.
        keys = jax.random.split(key, len(hidden_dims) + 1)

        # Initialize layers
        self.layers = []
        prev_dim = input_dim

        for i, hidden_dim in enumerate(hidden_dims):
            self.layers.append(KANLayer(prev_dim, hidden_dim, num_basis, domain, keys[i]))
            prev_dim = hidden_dim

        # Final layer for portfolio weights
        self.output_layer = KANLayer(prev_dim, output_dim, num_basis, domain, keys[-1])

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Forward pass through the KAN model.

        Args:
            x: Input tensor of shape (batch_size, input_dim)

        Returns:
            Output tensor of shape (batch_size, output_dim) representing portfolio weights
        """
        for layer in self.layers:
            x = layer(x)

        raw_weights = self.output_layer(x)

        # Softmax makes each row sum to 1 (fully invested).
        return jax.nn.softmax(raw_weights, axis=-1)

    @property
    def params(self):
        """Get model parameters as a flat dictionary.

        Keys are ``layer_{i}_weights|biases|activations`` for hidden layer
        ``i`` and ``output_layer_weights|biases|activations`` for the output
        layer.
        """
        params = {}
        for i, layer in enumerate(self.layers):
            params[f'layer_{i}_weights'] = layer.weights
            params[f'layer_{i}_biases'] = layer.biases
            params[f'layer_{i}_activations'] = layer.activations

        params['output_layer_weights'] = self.output_layer.weights
        params['output_layer_biases'] = self.output_layer.biases
        params['output_layer_activations'] = self.output_layer.activations

        return params

    def update_params(self, params):
        """Update model parameters from a flat dictionary.

        Args:
            params: Dictionary using the same keys that ``params`` produces.
        """
        for i, layer in enumerate(self.layers):
            layer.weights = params[f'layer_{i}_weights']
            layer.biases = params[f'layer_{i}_biases']
            layer.activations = params[f'layer_{i}_activations']

        self.output_layer.weights = params['output_layer_weights']
        self.output_layer.biases = params['output_layer_biases']
        self.output_layer.activations = params['output_layer_activations']

# Data generation and utilities for portfolio optimization
def generate_market_data(num_periods=1000, num_assets=10, num_factors=3, seed=42):
    """
    Generate synthetic market data with a factor structure.

    Asset returns are built as factor returns times per-asset exposures plus
    independent idiosyncratic noise. Five per-asset features are then produced
    for every period.

    Note: this function reseeds NumPy's global random state via
    ``np.random.seed(seed)``, so it affects later ``np.random`` calls made by
    the caller.
    
    Args:
        num_periods: Number of time periods
        num_assets: Number of assets
        num_factors: Number of risk factors
        seed: Random seed
        
    Returns:
        asset_returns: Asset returns, shape (num_periods, num_assets)
        factor_returns: Factor returns, shape (num_periods, num_factors)
        factor_exposures: Asset exposures to factors, shape (num_assets, num_factors)
        feature_data: Asset features, shape (num_periods, num_assets, 5); the
            last axis is ordered size, value ratio, momentum, volatility, quality
    """
    np.random.seed(seed)
    
    # Generate factor returns: one volatility per factor, zero-mean normal draws
    factor_volatility = np.random.uniform(0.01, 0.05, num_factors)
    factor_returns = np.random.normal(0, factor_volatility, (num_periods, num_factors))
    
    # Generate factor exposures for each asset
    factor_exposures = np.random.normal(0, 1, (num_assets, num_factors))
    
    # Idiosyncratic volatility for each asset
    idiosyncratic_vol = np.random.uniform(0.01, 0.05, num_assets)
    
    # Generate asset returns based on factor model plus idiosyncratic returns
    asset_returns = np.zeros((num_periods, num_assets))
    
    for i in range(num_periods):
        # Systematic returns from factors
        systematic_returns = np.dot(factor_returns[i], factor_exposures.T)
        
        # Idiosyncratic returns (one draw per asset, scaled by that asset's volatility)
        idiosyncratic_returns = np.random.normal(0, idiosyncratic_vol)
        
        # Total returns
        asset_returns[i] = systematic_returns + idiosyncratic_returns
    
    # Generate asset features (market cap, value, momentum, etc.)
    feature_data = np.zeros((num_periods, num_assets, 5))
    
    # Feature 0 - market cap (size): lognormal start, then a multiplicative random walk
    size = np.random.lognormal(10, 2, num_assets)
    for i in range(num_periods):
        # Market cap grows/shrinks slowly (the shock is applied from period 0 onwards)
        size = size * np.exp(np.random.normal(0, 0.01, num_assets))
        feature_data[i, :, 0] = size
    
    # Feature 1 - value ratio (e.g., book-to-market) - changes slowly
    value_ratio = np.random.normal(1, 0.5, num_assets)
    for i in range(num_periods):
        # Blend 98% of the previous value with 2% of a fresh draw; period 0 keeps the initial draw
        if i > 0:
            value_ratio = 0.98 * value_ratio + 0.02 * np.random.normal(1, 0.5, num_assets)
        feature_data[i, :, 1] = value_ratio
    
    # Feature 2 - momentum: compounded return over the 12 periods before i
    # (period i itself is excluded); zero until 12 periods of history exist
    for i in range(num_periods):
        if i >= 12:
            momentum = np.prod(1 + asset_returns[i-12:i], axis=0) - 1
        else:
            momentum = np.zeros(num_assets)
        feature_data[i, :, 2] = momentum
    
    # Feature 3 - volatility: standard deviation over the 20 periods before i;
    # falls back to the true idiosyncratic volatility until 20 periods exist
    for i in range(num_periods):
        if i >= 20:
            volatility = np.std(asset_returns[i-20:i], axis=0)
        else:
            volatility = idiosyncratic_vol
        feature_data[i, :, 3] = volatility
    
    # Feature 4 - quality: persistent synthetic score.
    # NOTE(review): the score is drawn independently of asset_returns here, so
    # no correlation with returns is built in by this function.
    quality = np.random.normal(0, 1, num_assets)
    for i in range(num_periods):
        # Quality is somewhat persistent: 95% previous value, 5% fresh draw
        if i > 0:
            quality = 0.95 * quality + 0.05 * np.random.normal(0, 1, num_assets)
        feature_data[i, :, 4] = quality
    
    return asset_returns, factor_returns, factor_exposures, feature_data

def prepare_portfolio_data(returns, features, lookback=20, test_split=0.2):
    """
    Build supervised samples for portfolio optimization and split them in time order.

    For every period ``t`` in ``[lookback, num_periods - 2]`` the input row
    holds, asset by asset, the ``lookback`` returns preceding ``t``, followed
    by every asset's feature vector at ``t``. The target is ``returns[t + 1]``.

    Args:
        returns: Asset returns (periods x assets)
        features: Asset features (periods x assets x features)
        lookback: Number of periods to include in each input window
        test_split: Fraction of data to use for testing

    Returns:
        X_train, y_train, X_test, y_test: Training and testing data as JAX arrays
    """
    num_periods, num_assets, _ = features.shape
    asset_ids = range(num_assets)

    inputs = []
    targets = []
    for t in range(lookback, num_periods - 1):
        # Asset-major block of trailing returns.
        trailing = [r for a in asset_ids for r in returns[t - lookback:t, a]]
        # Asset-major block of the current feature vectors.
        current = [f for a in asset_ids for f in features[t, a]]

        inputs.append(trailing + current)
        targets.append(returns[t + 1])

    X = np.array(inputs)
    y = np.array(targets)

    # Chronological split: the first (1 - test_split) share is used for training.
    cut = int(len(X) * (1 - test_split))

    return (
        jnp.array(X[:cut]),
        jnp.array(y[:cut]),
        jnp.array(X[cut:]),
        jnp.array(y[cut:]),
    )

# Loss functions for portfolio optimization
def sharpe_ratio_loss(portfolio_weights, returns, annualization_factor=252, risk_free_rate=0.0):
    """
    Negative annualized Sharpe ratio of the batch's portfolio returns.

    Args:
        portfolio_weights: Portfolio weights (batch_size, num_assets)
        returns: Asset returns (batch_size, num_assets)
        annualization_factor: Number of periods per year used to annualize
        risk_free_rate: Risk-free rate, subtracted from the annualized mean

    Returns:
        negative_sharpe: Negative Sharpe ratio (to minimize)
    """
    # One realized portfolio return per row of the batch.
    realized = jnp.sum(portfolio_weights * returns, axis=1)

    # Mean scales linearly with the horizon, volatility with its square root.
    annual_mean = jnp.mean(realized) * annualization_factor
    annual_vol = jnp.std(realized) * jnp.sqrt(annualization_factor)

    excess = annual_mean - risk_free_rate
    return -(excess / annual_vol)

def mean_variance_loss(portfolio_weights, returns, risk_aversion=1.0):
    """
    Mean-variance objective: variance penalty minus expected return.

    Args:
        portfolio_weights: Portfolio weights (batch_size, num_assets)
        returns: Asset returns (batch_size, num_assets)
        risk_aversion: Multiplier applied to the portfolio variance

    Returns:
        loss: ``-mean + risk_aversion * variance`` of the batch's portfolio returns
    """
    # One realized portfolio return per row of the batch.
    realized = jnp.sum(portfolio_weights * returns, axis=1)

    reward = jnp.mean(realized)
    risk = jnp.var(realized)

    return -reward + risk_aversion * risk

def cvar_loss(portfolio_weights, returns, alpha=0.05):
    """
    Conditional Value at Risk (CVaR) of the batch's portfolio returns.

    Args:
        portfolio_weights: Portfolio weights (batch_size, num_assets)
        returns: Asset returns (batch_size, num_assets)
        alpha: Tail fraction (e.g., 0.05 for 95% CVaR)

    Returns:
        cvar: Negated mean of the worst ``ceil(batch_size * alpha)`` portfolio returns
    """
    # One realized portfolio return per row, ordered worst first.
    ordered = jnp.sort(jnp.sum(portfolio_weights * returns, axis=1))

    # Number of observations that make up the tail.
    tail_count = int(jnp.ceil(len(ordered) * alpha))

    worst = ordered[:tail_count]
    return -jnp.mean(worst)

# Combined loss with regularization
def portfolio_loss_fn(params, model_apply_fn, X, returns, grid_points, loss_type='sharpe', 
                     lambda_l2=0.001, lambda_smooth=0.001, lambda_turnover=0.001,
                     previous_weights=None):
    """
    Portfolio objective plus L2, activation-smoothness and turnover penalties.

    Args:
        params: Flat parameter dictionary with three entries per layer
            (``*_weights``, ``*_biases``, ``*_activations``)
        model_apply_fn: Callable ``(params, X, grid_points) -> portfolio weights``
        X: Input features
        returns: Asset returns
        grid_points: Grid points for activation functions
        loss_type: Type of portfolio loss function ('sharpe', 'mean_variance', 'cvar')
        lambda_l2: L2 regularization strength
        lambda_smooth: Activation smoothness regularization strength
        lambda_turnover: Portfolio turnover regularization strength
        previous_weights: Previous portfolio weights for turnover calculation

    Returns:
        Scalar total loss.

    Raises:
        ValueError: If ``loss_type`` is not one of the supported names.
    """
    weights = model_apply_fn(params, X, grid_points)

    # Primary objective.
    if loss_type == 'sharpe':
        objective = sharpe_ratio_loss(weights, returns)
    elif loss_type == 'mean_variance':
        objective = mean_variance_loss(weights, returns)
    elif loss_type == 'cvar':
        objective = cvar_loss(weights, returns)
    else:
        raise ValueError(f"Unknown loss type: {loss_type}")

    def roughness(table):
        """Mean squared second difference along the grid axis."""
        curvature = table[:, 2:] - 2 * table[:, 1:-1] + table[:, :-2]
        return jnp.mean(curvature ** 2)

    # Three dictionary entries per layer; the output layer is handled last.
    hidden_count = len(params) // 3 - 1
    prefixes = [f'layer_{k}' for k in range(hidden_count)]
    prefixes.append('output_layer')

    weight_penalty = 0.0
    smooth_penalty = 0.0
    for prefix in prefixes:
        weight_penalty += jnp.sum(params[f'{prefix}_weights'] ** 2)
        smooth_penalty += roughness(params[f'{prefix}_activations'])

    # Average L1 distance to the previous weights, when they are supplied.
    if previous_weights is None:
        turnover_penalty = 0.0
    else:
        turnover_penalty = jnp.mean(jnp.sum(jnp.abs(weights - previous_weights), axis=1))

    return (objective
            + lambda_l2 * weight_penalty
            + lambda_smooth * smooth_penalty
            + lambda_turnover * turnover_penalty)

# Define model forward function separately for JIT
def _piecewise_linear_activation(z, activations, grid_points):
    """Evaluate each unit's tabulated activation at ``z`` by linear interpolation.

    Args:
        z: Pre-activations, shape (batch_size, num_units)
        activations: Activation tables, shape (num_units, num_grid_points)
        grid_points: Sorted 1-D grid on which the tables are sampled

    Returns:
        Array shaped like ``z`` holding the interpolated activation values.
    """
    z_clipped = jnp.clip(z, grid_points[0], grid_points[-1])

    # Left index of the bracketing interval, kept in range so idx + 1 is valid.
    idx = jnp.clip(jnp.searchsorted(grid_points, z_clipped) - 1, 0, len(grid_points) - 2)

    x0 = grid_points[idx]
    x1 = grid_points[idx + 1]

    # `unit` broadcasts against `idx`: entry [b, j] reads activations[j, idx[b, j]].
    unit = jnp.arange(activations.shape[0])
    y0 = activations[unit, idx]
    y1 = activations[unit, idx + 1]

    t = (z_clipped - x0) / (x1 - x0)
    return y0 + t * (y1 - y0)


def model_forward(params, X, grid_points):
    """Forward pass without the model object, just using parameters and inputs.

    Args:
        params: Flat dictionary with ``layer_{i}_weights|biases|activations``
            for each hidden layer and ``output_layer_weights|biases|activations``
        X: Inputs of shape (batch_size, input_dim)
        grid_points: Sorted 1-D grid shared by all activation tables

    Returns:
        Portfolio weights of shape (batch_size, output_dim); each row is a
        softmax output and therefore sums to 1.
    """
    # Three entries per layer; everything except the output layer is hidden.
    num_layers = (len(params) - 3) // 3
    layer_names = [f'layer_{i}' for i in range(num_layers)] + ['output_layer']

    z = X
    for name in layer_names:
        # Affine map followed by the layer's learned activation.
        z = jnp.dot(z, params[f'{name}_weights']) + params[f'{name}_biases']
        z = _piecewise_linear_activation(z, params[f'{name}_activations'], grid_points)

    # Apply softmax to ensure weights sum to 1
    return jax.nn.softmax(z, axis=-1)

def min_var_objective(weights, cov_matrix):
    """
    Portfolio variance, used as the minimum-variance objective.

    Args:
        weights: Portfolio weights
        cov_matrix: Covariance matrix of returns

    Returns:
        The quadratic form ``w' * cov_matrix * w``
    """
    w = np.array(weights)
    # Covariance applied to the weights first, then projected back onto them.
    cov_times_w = np.dot(cov_matrix, w)
    return np.dot(w.T, cov_times_w)

def weight_sum_constraint(weights):
    """
    Equality constraint that is satisfied when the portfolio is fully invested.

    Args:
        weights: Portfolio weights

    Returns:
        Sum of the weights minus one (0 when the weights sum to 1)
    """
    total_weight = np.sum(weights)
    return total_weight - 1.0

def neg_sharpe_ratio(weights, returns, cov_matrix, risk_free_rate=0.0):
    """
    Negated Sharpe ratio, suitable for a minimizer.

    Args:
        weights: Portfolio weights
        returns: Asset returns (must support ``.dot``)
        cov_matrix: Covariance matrix
        risk_free_rate: Risk-free rate subtracted from the mean portfolio return

    Returns:
        Negative of (mean portfolio return - risk_free_rate) / portfolio volatility
    """
    w = np.array(weights)

    # Average realized return of the weighted portfolio.
    mean_return = np.mean(returns.dot(w))
    excess = mean_return - risk_free_rate

    # Volatility from the covariance quadratic form.
    variance = np.dot(w.T, np.dot(cov_matrix, w))
    volatility = np.sqrt(variance)

    return -excess / volatility

def risk_parity_objective(weights, cov_matrix):
    """
    Squared deviation of each asset's risk contribution from an equal share.

    Args:
        weights: Portfolio weights
        cov_matrix: Covariance matrix

    Returns:
        Sum of squared differences from equal risk contribution
    """
    w = np.array(weights)
    asset_count = len(w)

    # Marginal risk is reused for both the volatility and the contributions.
    marginal = np.dot(cov_matrix, w)
    total_vol = np.sqrt(np.dot(w.T, np.dot(cov_matrix, w)))

    contributions = w * marginal / total_vol
    equal_share = total_vol / asset_count

    deviations = contributions - equal_share
    return np.sum(deviations ** 2)

def calculate_portfolio_metrics(returns, risk_free_rate=0.0):
    """
    Summarize a series of portfolio returns.

    Args:
        returns: Portfolio returns (one value per period)
        risk_free_rate: Risk-free rate subtracted from the annualized return

    Returns:
        Dictionary with keys 'annual_return', 'annual_volatility',
        'sharpe_ratio', 'max_drawdown', 'var_95' and 'cvar_95'
    """
    periods_per_year = 252  # Assuming daily returns

    # Annualized mean, volatility and their ratio.
    annual_return = np.mean(returns) * periods_per_year
    annual_volatility = np.std(returns) * np.sqrt(periods_per_year)
    sharpe_ratio = (annual_return - risk_free_rate) / annual_volatility

    # Largest relative fall of the compounded wealth curve below its running peak.
    wealth = np.cumprod(1 + returns)
    running_peak = np.maximum.accumulate(wealth)
    max_drawdown = np.max((running_peak - wealth) / running_peak)

    # 95% VaR: negated 5th percentile of the returns.
    var_95 = -np.percentile(returns, 5)

    # 95% CVaR: negated mean of returns at or below the VaR threshold,
    # falling back to the VaR itself when no return qualifies.
    in_tail = returns <= -var_95
    cvar_95 = -np.mean(returns[in_tail]) if np.any(in_tail) else var_95

    return dict(
        annual_return=annual_return,
        annual_volatility=annual_volatility,
        sharpe_ratio=sharpe_ratio,
        max_drawdown=max_drawdown,
        var_95=var_95,
        cvar_95=cvar_95,
    )

def compare_with_traditional_methods(asset_returns, features, test_start_idx, kan_weights, kan_metrics):
    """
    Benchmark the KAN portfolio against four classical allocation rules.

    Equal-weight, minimum-variance, maximum-Sharpe and risk-parity portfolios
    are fitted on ``asset_returns[:test_start_idx]`` and evaluated on the
    ``len(kan_weights)`` periods starting at ``test_start_idx``.

    Args:
        asset_returns: Asset returns
        features: Asset features (not used by this function)
        test_start_idx: Starting index of test data
        kan_weights: KAN portfolio weights (from model)
        kan_metrics: KAN performance metrics, keyed like the output of
            ``calculate_portfolio_metrics``

    Returns:
        fig: 2x2 bar-chart figure (return, volatility, Sharpe, drawdown)
        fig2: Cumulative-return line chart
        comparison_data: Dict metric name -> list of values in the order
            Equal Weight, Min Variance, Max Sharpe, Risk Parity, KAN Portfolio;
            all metrics except the Sharpe ratio are expressed in percent
    """
    # Test window matching the number of KAN weight rows.
    horizon = len(kan_weights)
    test_returns = asset_returns[test_start_idx:test_start_idx + horizon]

    # Plain NumPy copy of the KAN weights.
    kan_weights_np = np.array(kan_weights)

    num_assets = asset_returns.shape[1]

    def realized(weights):
        """Per-period return of a portfolio over the test window."""
        return np.sum(weights * test_returns, axis=1)

    # 1. Equal Weight
    equal_weight = np.ones(num_assets) / num_assets
    strategy_returns = {'Equal Weight': realized(equal_weight)}

    # Covariance estimated on the training window only.
    train_returns = asset_returns[:test_start_idx]
    cov_matrix = np.cov(train_returns.T)

    # Long-only, fully invested optimizations starting from equal weights.
    constraints = {'type': 'eq', 'fun': weight_sum_constraint}
    bounds = tuple((0, 1) for _ in range(num_assets))
    initial_weights = np.ones(num_assets) / num_assets

    # 2.-4. Minimum variance, maximum Sharpe, risk parity (solved in this order).
    optimized = (
        ('Min Variance', min_var_objective, (cov_matrix,)),
        ('Max Sharpe', neg_sharpe_ratio, (train_returns, cov_matrix)),
        ('Risk Parity', risk_parity_objective, (cov_matrix,)),
    )
    for label, objective, extra_args in optimized:
        solution = minimize(
            objective,
            initial_weights,
            args=extra_args,
            method='SLSQP',
            bounds=bounds,
            constraints=constraints,
        )
        strategy_returns[label] = realized(solution['x'])

    # Metrics for the classical portfolios, followed by the supplied KAN metrics.
    metrics_list = [calculate_portfolio_metrics(series) for series in strategy_returns.values()]
    metrics_list.append(kan_metrics)

    # Calculate KAN portfolio returns
    kan_portfolio_returns = realized(kan_weights_np)

    methods = ['Equal Weight', 'Min Variance', 'Max Sharpe', 'Risk Parity', 'KAN Portfolio']
    metric_names = ['annual_return', 'annual_volatility', 'sharpe_ratio', 'max_drawdown', 'var_95', 'cvar_95']
    comparison_data = {
        metric: [entry[metric] for entry in metrics_list]
        for metric in metric_names
    }

    # Create a figure for comparison
    fig, axs = plt.subplots(2, 2, figsize=(14, 10))

    # Everything except the Sharpe ratio is reported in percent.
    for metric in ('annual_return', 'annual_volatility', 'max_drawdown', 'var_95', 'cvar_95'):
        comparison_data[metric] = [value * 100 for value in comparison_data[metric]]

    panels = (
        (axs[0, 0], 'annual_return', 'Annual Return (%)', 'Return (%)'),
        (axs[0, 1], 'annual_volatility', 'Annual Volatility (%)', 'Volatility (%)'),
        (axs[1, 0], 'sharpe_ratio', 'Sharpe Ratio', 'Sharpe Ratio'),
        (axs[1, 1], 'max_drawdown', 'Maximum Drawdown (%)', 'Drawdown (%)'),
    )
    for axis, metric, title, y_label in panels:
        axis.bar(methods, comparison_data[metric])
        axis.set_title(title)
        axis.set_ylabel(y_label)
        axis.grid(True, alpha=0.3)
        axis.tick_params(axis='x', rotation=45)

    plt.tight_layout()

    # Create a second figure for cumulative returns comparison
    fig2, ax2 = plt.subplots(figsize=(12, 6))

    curves = list(strategy_returns.items())
    curves.append(('KAN Portfolio', kan_portfolio_returns))
    for label, series in curves:
        # Compounded growth minus the initial unit of capital.
        ax2.plot(np.cumprod(1 + series) - 1, label=label)

    ax2.set_title('Cumulative Returns Comparison')
    ax2.set_xlabel('Time')
    ax2.set_ylabel('Cumulative Return')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    return fig, fig2, comparison_data

# Training functions with fixed JIT compilation issues
def train_portfolio_model(model, X_train, y_train, num_epochs=100, batch_size=64,
                         loss_type='sharpe', lambda_l2=0.001, lambda_smooth=0.001,
                         lambda_turnover=0.001, learning_rate=0.001):
    """Train the portfolio optimization model with Adam on shuffled mini-batches.

    Args:
        model: PortfolioKAN instance; its parameters are overwritten with the
            trained values before returning
        X_train: Training inputs, first axis indexes samples
        y_train: Next-period asset returns aligned with ``X_train``
        num_epochs: Number of passes over the training data
        batch_size: Mini-batch size; capped at the number of samples. Samples
            that do not fill a whole batch are dropped in each epoch
        loss_type: 'sharpe', 'mean_variance' or 'cvar'
        lambda_l2: L2 regularization strength
        lambda_smooth: Activation smoothness regularization strength
        lambda_turnover: Turnover regularization strength; 0 disables it
        learning_rate: Adam learning rate

    Returns:
        params: Trained parameter dictionary
        losses: List with the mean batch loss of every epoch

    Raises:
        ValueError: If ``X_train`` is empty or ``batch_size`` is not positive.
    """
    num_samples = X_train.shape[0]
    if num_samples == 0:
        raise ValueError("X_train must contain at least one sample")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # With fewer samples than batch_size the original batch count was zero,
    # which ended in a division by zero; train on one smaller batch instead.
    batch_size = min(batch_size, num_samples)
    num_batches = num_samples // batch_size

    params = model.params
    # The output layer always exists, unlike model.layers[0], which is missing
    # when the model has no hidden layers. All layers are assumed to share the
    # same grid.
    grid_points = model.output_layer.grid_points

    # Loss closed over the fixed hyper-parameters. It is evaluated eagerly:
    # no jax.jit wrapper is applied here.
    loss_fn = lambda p, x, y, prev_w: portfolio_loss_fn(
        p, model_forward, x, y, grid_points, loss_type,
        lambda_l2, lambda_smooth, lambda_turnover, prev_w
    )

    value_and_grad_fn = jax.value_and_grad(loss_fn)

    # Initialize optimizer
    optimizer = optax.adam(learning_rate=learning_rate)
    opt_state = optimizer.init(params)

    losses = []
    previous_batch_weights = None

    for epoch in range(num_epochs):
        # Deterministic per-epoch shuffle (the epoch index is the seed).
        perm = jax.random.permutation(jax.random.PRNGKey(epoch), num_samples)
        X_shuffled = X_train[perm]
        y_shuffled = y_train[perm]

        epoch_loss = 0.0

        for batch in range(num_batches):
            start_idx = batch * batch_size
            end_idx = start_idx + batch_size

            X_batch = X_shuffled[start_idx:end_idx]
            y_batch = y_shuffled[start_idx:end_idx]

            # Compute loss and gradients
            loss_value, grads = value_and_grad_fn(params, X_batch, y_batch, previous_batch_weights)

            # Update parameters
            updates, opt_state = optimizer.update(grads, opt_state)
            params = optax.apply_updates(params, updates)

            # Weights of this batch become the turnover reference for the next one.
            # NOTE(review): consecutive batches hold different shuffled samples,
            # so this penalizes distance between unrelated rows — confirm intended.
            if lambda_turnover > 0:
                previous_batch_weights = model_forward(params, X_batch, grid_points)

            epoch_loss += loss_value

        epoch_loss /= num_batches
        losses.append(epoch_loss)

        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {epoch_loss:.6f}")

    # Update model with trained parameters
    model.update_params(params)

    return params, losses

# Portfolio evaluation functions
def evaluate_portfolio(model, params, X_test, y_test, risk_free_rate=0.0):
    """
    Evaluate portfolio performance on test data.

    The model is updated in place with ``params`` and then called on
    ``X_test`` to produce one row of asset weights per period. Rows are
    treated as consecutive periods (cumulative returns and turnover are
    computed along axis 0) and returns are assumed to be daily, so
    annualized figures use 252 periods per year.

    Args:
        model: Trained PortfolioKAN model; must expose ``update_params``
            and be callable on ``X_test``
        params: Model parameters (written back into ``model``)
        X_test: Test features
        y_test: Test returns, broadcastable against the model's output of
            shape (periods, assets)
        risk_free_rate: Risk-free rate, subtracted from the *annualized*
            return, so it must be expressed as an annual rate

    Returns:
        metrics: Dict with keys ``annual_return``, ``annual_volatility``,
            ``sharpe_ratio``, ``max_drawdown``, ``var_95``, ``cvar_95``,
            ``turnover`` and ``information_ratio``
        portfolio_returns: 1-D NumPy array of per-period portfolio returns
        portfolio_weights: Weights exactly as returned by the model
    """
    model.update_params(params)
    portfolio_weights = model(X_test)

    # Convert to numpy for evaluation
    weights_np = np.array(portfolio_weights)
    returns_np = np.array(y_test)

    # Per-period portfolio return: weighted sum of asset returns
    portfolio_returns = np.sum(weights_np * returns_np, axis=1)

    annualization_factor = 252  # Assuming daily returns

    # Annualized mean return and volatility (population std, ddof=0)
    annual_return = np.mean(portfolio_returns) * annualization_factor
    annual_volatility = np.std(portfolio_returns) * np.sqrt(annualization_factor)

    # Sharpe ratio. NOTE: a constant return series has zero volatility, in
    # which case NumPy yields inf/nan here (with a RuntimeWarning).
    sharpe_ratio = (annual_return - risk_free_rate) / annual_volatility

    # Maximum drawdown. The running peak must include the initial capital
    # of 1.0; otherwise a loss in the very first period(s) is measured
    # against itself and reported as zero drawdown.
    cumulative_returns = np.cumprod(1 + portfolio_returns)
    peak = np.maximum(np.maximum.accumulate(cumulative_returns), 1.0)
    drawdown = (peak - cumulative_returns) / peak
    max_drawdown = np.max(drawdown)

    # Value at Risk (VaR) at 95% confidence, reported as a positive loss
    loss_threshold = np.percentile(portfolio_returns, 5)
    var_95 = -loss_threshold

    # Conditional VaR: average loss over the periods at or below the
    # 5th-percentile return
    cvar_95 = -np.mean(portfolio_returns[portfolio_returns <= loss_threshold])

    # Turnover: mean L1 change in weights between consecutive periods.
    # With fewer than two periods there is no rebalancing, so report 0.0
    # instead of the NaN that the mean of an empty array would give.
    if weights_np.shape[0] > 1:
        turnover = np.mean(np.sum(np.abs(weights_np[1:] - weights_np[:-1]), axis=1))
    else:
        turnover = 0.0

    # Information Ratio vs equal weight benchmark
    equal_weight = np.ones_like(weights_np) / weights_np.shape[1]
    benchmark_returns = np.sum(equal_weight * returns_np, axis=1)
    active_returns = portfolio_returns - benchmark_returns
    information_ratio = np.mean(active_returns) / np.std(active_returns) * np.sqrt(annualization_factor)

    # Create metrics dictionary
    metrics = {
        'annual_return': annual_return,
        'annual_volatility': annual_volatility,
        'sharpe_ratio': sharpe_ratio,
        'max_drawdown': max_drawdown,
        'var_95': var_95,
        'cvar_95': cvar_95,
        'turnover': turnover,
        'information_ratio': information_ratio
    }

    return metrics, portfolio_returns, portfolio_weights

def visualize_portfolio_performance(portfolio_returns, equal_weight_returns, metrics):
    """
    Plot the strategy against the equal-weight benchmark.

    Draws two stacked panels: cumulative returns (with a text box
    summarising ``metrics``) on top and drawdowns below.

    Args:
        portfolio_returns: Per-period portfolio returns
        equal_weight_returns: Per-period equal weight benchmark returns
        metrics: Performance metrics dict; the keys read here are
            annual_return, annual_volatility, sharpe_ratio, max_drawdown,
            var_95, cvar_95, turnover and information_ratio

    Returns:
        fig: The matplotlib figure containing both panels
    """
    def _drawdown_series(cum_returns):
        # Wealth path relative to a starting value of 1; the running peak
        # is seeded with that starting value and the seed is then dropped.
        wealth = cum_returns + 1
        running_peak = np.maximum.accumulate(np.insert(wealth, 0, 1))[1:]
        return (running_peak - wealth) / running_peak

    def _draw_panel(ax, strategy_series, benchmark_series, title, y_label):
        # Shared styling for both panels: strategy in solid blue,
        # benchmark in dashed red.
        ax.plot(strategy_series, 'b-', label='KAN Portfolio')
        ax.plot(benchmark_series, 'r--', label='Equal Weight')
        ax.set_title(title)
        ax.set_xlabel('Time')
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Compound the per-period returns into cumulative returns
    strategy_cum = np.cumprod(1 + portfolio_returns) - 1
    benchmark_cum = np.cumprod(1 + equal_weight_returns) - 1

    fig, (returns_ax, drawdown_ax) = plt.subplots(2, 1, figsize=(12, 10))

    # Top panel: cumulative returns
    _draw_panel(returns_ax, strategy_cum, benchmark_cum,
                'Cumulative Returns', 'Cumulative Return')

    # Summary box with the headline metrics, anchored bottom-left in
    # axes coordinates
    summary = f"""
    Annual Return: {metrics['annual_return']:.2%}
    Annual Volatility: {metrics['annual_volatility']:.2%}
    Sharpe Ratio: {metrics['sharpe_ratio']:.2f}
    Maximum Drawdown: {metrics['max_drawdown']:.2%}
    95% VaR: {metrics['var_95']:.2%}
    95% CVaR: {metrics['cvar_95']:.2%}
    Turnover: {metrics['turnover']:.2f}
    Information Ratio: {metrics['information_ratio']:.2f}
    """
    box_style = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
    returns_ax.text(0.05, 0.05, summary, transform=returns_ax.transAxes,
                    fontsize=10, verticalalignment='bottom', bbox=box_style)

    # Bottom panel: drawdowns for both series
    strategy_dd = _drawdown_series(strategy_cum)
    benchmark_dd = _drawdown_series(benchmark_cum)
    _draw_panel(drawdown_ax, strategy_dd, benchmark_dd, 'Drawdowns', 'Drawdown')
    drawdown_ax.invert_yaxis()  # Flip the axis so deeper drawdowns point downward

    plt.tight_layout()
    return fig
