In [None]:
!pip install optax yfinance lxml plotly hmmlearn "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [None]:
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import pandas as pd
from jax import grad, jit, vmap
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# Module-level PRNG key, seeded with 42 for reproducibility.
# NOTE(review): no code visible in this file reads this global -- the models fall
# back to PRNGKey(0) when no key is passed. Confirm whether it should be threaded
# into RiskFactorKAN(..., key=key).
key = jax.random.PRNGKey(seed=42)

# Define the KAN Layer for risk factor modeling
class KANLayer:
    """Dense layer followed by a learnable piecewise-linear activation per output unit.

    Each output unit owns a table of ``num_basis`` activation values sampled on a
    uniform grid over ``domain``. The pre-activation is clamped to ``domain`` and the
    unit's table is linearly interpolated at that point.
    """

    def __init__(self, input_dim: int, output_dim: int, num_basis: int = 30,
                 domain=(-3.0, 3.0), key=None):
        """Initialize a KAN layer with learnable activation functions.

        Args:
            input_dim: Input dimension
            output_dim: Output dimension
            num_basis: Number of grid points (table entries) per learned activation
            domain: (low, high) interval over which the activation tables are defined;
                pre-activations are clamped to it in the forward pass
            key: JAX random key; defaults to PRNGKey(0) when None
        """
        if key is None:
            key = jax.random.PRNGKey(0)

        key1, key2, key3 = jax.random.split(key, 3)

        # Linear transformation parameters
        self.weights = jax.random.normal(key1, (input_dim, output_dim)) * 0.1
        self.biases = jax.random.normal(key2, (output_dim,)) * 0.01

        # Grid points at which every activation table is sampled
        self.grid_points = jnp.linspace(domain[0], domain[1], num_basis)

        # Candidate initial shapes, stacked as (4, num_basis):
        # 0 = linear, 1 = ReLU, 2 = sigmoid, 3 = tanh.
        g = self.grid_points
        candidate_shapes = jnp.stack([
            g,
            jnp.maximum(0, g),
            1.0 / (1.0 + jnp.exp(-g)),
            jnp.tanh(g),
        ])

        activations_list = []
        for i in range(output_dim):
            unit_key = jax.random.fold_in(key3, i)
            # Separate keys for the shape choice and the noise: a JAX key must not be
            # consumed twice, or the two draws become statistically dependent.
            type_key, noise_key = jax.random.split(unit_key)
            init_type = jax.random.randint(type_key, (), 0, 4)
            # Dynamic indexing rather than a Python `if`, so construction also works
            # when the layer is built inside a jit-traced function.
            act = candidate_shapes[init_type]
            # Small noise breaks symmetry between units with the same initial shape
            noise = jax.random.normal(noise_key, (num_basis,)) * 0.05
            activations_list.append(act + noise)

        # Shape: (output_dim, num_basis)
        self.activations = jnp.stack(activations_list)

        # Interval used to clamp pre-activations in __call__
        self.domain = domain

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Forward pass through the KAN layer.

        Args:
            x: Input tensor of shape (batch_size, input_dim)

        Returns:
            Output tensor of shape (batch_size, output_dim)
        """
        z = jnp.dot(x, self.weights) + self.biases  # (batch_size, output_dim)

        # Clamp to the table's domain so out-of-range inputs saturate at the end
        # values instead of being extrapolated from the edge segment.
        z = jnp.clip(z, self.domain[0], self.domain[1])

        # Left grid index of the segment containing each z, kept in [0, num_points - 2]
        num_points = self.grid_points.shape[0]
        idx = jnp.searchsorted(self.grid_points, z) - 1
        idx = jnp.clip(idx, 0, num_points - 2)

        x0 = self.grid_points[idx]
        x1 = self.grid_points[idx + 1]

        # Column j of z is looked up in activation table j: `unit` (output_dim,)
        # broadcasts against `idx` (batch_size, output_dim).
        unit = jnp.arange(self.activations.shape[0])
        y0 = self.activations[unit, idx]
        y1 = self.activations[unit, idx + 1]

        t = (z - x0) / (x1 - x0)
        return y0 + t * (y1 - y0)

# Full KAN model for risk factor decomposition
class RiskFactorKAN:
    """Stack of KANLayer blocks mapping factor returns to per-asset returns."""

    def __init__(self, input_dim: int, output_dim: int, hidden_dims: Optional[List[int]] = None,
                 num_basis: int = 30, domain=(-3.0, 3.0), key=None):
        """Initialize a KAN model for risk factor decomposition.

        Args:
            input_dim: Input dimension (market factors)
            output_dim: Output dimension (portfolio or asset returns)
            hidden_dims: Sizes of the hidden layers; ``None`` means ``[64, 32]``.
                (A ``None`` sentinel avoids a shared mutable default list.)
            num_basis: Number of grid points per learned activation
            domain: Domain for activation functions
            key: JAX random key; defaults to PRNGKey(0) when None
        """
        if hidden_dims is None:
            hidden_dims = [64, 32]
        if key is None:
            key = jax.random.PRNGKey(0)

        # One key per hidden layer plus one for the output layer
        keys = jax.random.split(key, len(hidden_dims) + 1)

        self.layers = []
        prev_dim = input_dim
        for layer_key, hidden_dim in zip(keys, hidden_dims):
            self.layers.append(KANLayer(prev_dim, hidden_dim, num_basis, domain, layer_key))
            prev_dim = hidden_dim

        # Final layer producing one output per asset
        self.output_layer = KANLayer(prev_dim, output_dim, num_basis, domain, keys[-1])

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Forward pass through the KAN model.

        Args:
            x: Input tensor of shape (batch_size, input_dim)

        Returns:
            Output tensor of shape (batch_size, output_dim) representing asset returns
        """
        for layer in self.layers:
            x = layer(x)
        return self.output_layer(x)

    @property
    def params(self):
        """Return the parameters as a flat dict.

        Keys are ``layer_{i}_{weights|biases|activations}`` for hidden layer ``i``
        and ``output_layer_{weights|biases|activations}`` for the output layer.
        """
        params = {}
        for i, layer in enumerate(self.layers):
            params[f'layer_{i}_weights'] = layer.weights
            params[f'layer_{i}_biases'] = layer.biases
            params[f'layer_{i}_activations'] = layer.activations

        params['output_layer_weights'] = self.output_layer.weights
        params['output_layer_biases'] = self.output_layer.biases
        params['output_layer_activations'] = self.output_layer.activations

        return params

    def update_params(self, params):
        """Overwrite the layer parameters in place from a dict shaped like ``self.params``."""
        for i, layer in enumerate(self.layers):
            layer.weights = params[f'layer_{i}_weights']
            layer.biases = params[f'layer_{i}_biases']
            layer.activations = params[f'layer_{i}_activations']

        self.output_layer.weights = params['output_layer_weights']
        self.output_layer.biases = params['output_layer_biases']
        self.output_layer.activations = params['output_layer_activations']

# Generate synthetic market data for training
def generate_risk_factor_data(num_samples=20000, num_risk_factors=5, num_assets=20, seed=42,
                              num_sectors=5):
    """
    Generate synthetic market data for risk factor decomposition.

    Factor returns are i.i.d. normal draws, smoothed with an AR(1)-style recursion
    and then mixed so that their correlation matrix is tridiagonal with 0.3 on the
    off-diagonals. Asset returns are a linear factor model plus small quadratic
    (even-indexed factors) and threshold (odd-indexed factors) effects, plus
    idiosyncratic noise. All returns are finally scaled by 0.01.

    Random numbers come from a private ``np.random.RandomState(seed)``. That gives
    the same draws as seeding the global NumPy RNG, but leaves the global state untouched.

    Args:
        num_samples: Number of time periods (e.g., days)
        num_risk_factors: Number of underlying risk factors
        num_assets: Number of assets in the portfolio
        seed: Random seed
        num_sectors: Number of contiguous asset groups sharing a common exposure
            to one (non-market) factor. Sector clustering is skipped when there is
            only one factor.

    Returns:
        factor_returns: Risk factor returns, shape (num_samples, num_risk_factors)
        asset_returns: Asset returns, shape (num_samples, num_assets)
        factor_exposures: True factor exposures (betas), shape (num_assets, num_risk_factors)
    """
    rng = np.random.RandomState(seed)

    # Generate risk factor returns (e.g., market, size, value, momentum, volatility)
    factor_returns = rng.normal(0, 1, (num_samples, num_risk_factors))

    # Add autocorrelation (sequential: each row uses the already-smoothed previous row)
    for t in range(1, num_samples):
        factor_returns[t] = 0.2 * factor_returns[t - 1] + 0.8 * factor_returns[t]

    # Target correlation: identity with 0.3 between neighbouring factors
    correlation_matrix = np.eye(num_risk_factors)
    for i in range(num_risk_factors - 1):
        correlation_matrix[i, i + 1] = 0.3
        correlation_matrix[i + 1, i] = 0.3

    # With C = L @ L.T, rows x of identity covariance map to y = x @ L.T with
    # covariance C. Right-multiplying by L instead would give L.T @ L, which is not C.
    cholesky = np.linalg.cholesky(correlation_matrix)
    factor_returns = factor_returns @ cholesky.T

    # Factor exposures (betas) for each asset; the market factor gets a larger scale
    factor_exposures = rng.normal(0, 1, (num_assets, num_risk_factors))
    factor_exposures[:, 0] *= 1.5

    # Sector-like clustering. Sector factors are drawn from indices 1..k-1 (never the
    # market factor), so at least two factors are required.
    if num_risk_factors >= 2:
        sector_size = num_assets // max(num_sectors, 1)
        for s in range(num_sectors):
            start_idx = s * sector_size
            end_idx = start_idx + sector_size
            sector_factor = rng.randint(1, num_risk_factors)
            sector_exposure = rng.normal(0, 1)
            factor_exposures[start_idx:end_idx, sector_factor] = (
                sector_exposure + rng.normal(0, 0.3, end_idx - start_idx)
            )

    # Linear factor model: R[t, i] = sum_j beta[i, j] * F[t, j]
    asset_returns = factor_returns @ factor_exposures.T

    # Non-linear effects: quadratic for even-indexed factors, threshold above 1.0 for odd ones
    is_even = np.arange(num_risk_factors) % 2 == 0
    nonlinear = np.where(is_even, factor_returns ** 2, np.maximum(factor_returns - 1.0, 0.0))
    asset_returns = asset_returns + 0.1 * (nonlinear @ factor_exposures.T)

    # Idiosyncratic (asset-specific) returns with per-asset volatility
    idiosyncratic_vol = rng.uniform(0.5, 1.5, num_assets)
    idiosyncratic_returns = rng.normal(0, 1, (num_samples, num_assets))
    asset_returns = asset_returns + idiosyncratic_returns * idiosyncratic_vol

    # Scale unit-sized draws down to daily-return magnitudes. The resulting asset std
    # depends on the exposures and idiosyncratic vol; it is not exactly 1%.
    asset_returns = asset_returns * 0.01
    factor_returns = factor_returns * 0.01

    return jnp.array(factor_returns), jnp.array(asset_returns), jnp.array(factor_exposures)

# Training functions
@jit
def loss_fn(params, X, Y):
    """Mean squared error of the KAN model described by ``params`` on (X, Y).

    The architecture is rebuilt from ``params`` itself, so parameters trained with
    any ``hidden_dims`` / ``num_basis`` are evaluated with a matching structure:
    - the number of hidden layers comes from the consecutive ``layer_{i}_weights`` keys;
    - each hidden width comes from that weight matrix's second dimension;
    - ``num_basis`` comes from the length of the output activation table.

    ``domain`` cannot be recovered from ``params``, so the RiskFactorKAN default is used.

    Args:
        params: Flat parameter dict in the ``RiskFactorKAN.params`` layout.
        X: Inputs, shape (batch_size, num_factors).
        Y: Targets, shape (batch_size, num_assets).

    Returns:
        Scalar mean squared error.
    """
    # Dict keys and array shapes are static under jit, so this runs at trace time.
    num_hidden = 0
    while f"layer_{num_hidden}_weights" in params:
        num_hidden += 1
    hidden_dims = [params[f"layer_{i}_weights"].shape[1] for i in range(num_hidden)]
    num_basis = params["output_layer_activations"].shape[1]

    model = RiskFactorKAN(X.shape[1], Y.shape[1], hidden_dims=hidden_dims, num_basis=num_basis)
    model.update_params(params)
    pred = model(X)
    return jnp.mean((pred - Y) ** 2)

@jit
def train_step(params, X, Y, opt_state):
    """Perform one optimizer update on a mini-batch.

    Args:
        params: Current parameter dict.
        X: Mini-batch inputs.
        Y: Mini-batch targets.
        opt_state: Current optimizer state.

    Returns:
        (new_params, new_opt_state, loss), where ``loss`` is ``loss_fn`` evaluated at
        the incoming ``params`` (before the update is applied).

    Note:
        ``optimizer`` is a module-level global assigned in ``main``.
        NOTE(review): because this function is jit-compiled, the optimizer is presumably
        captured at first trace; re-assigning the global later may not take effect.
    """
    value_and_grad_fn = jax.value_and_grad(loss_fn)
    batch_loss, gradients = value_and_grad_fn(params, X, Y)
    param_updates, next_opt_state = optimizer.update(gradients, opt_state)
    next_params = optax.apply_updates(params, param_updates)
    return next_params, next_opt_state, batch_loss

def train_model(X, Y, params, num_epochs=100, batch_size=64):
    """Train the model for a specified number of epochs with mini-batch updates.

    Each epoch reshuffles the rows using a key derived from the epoch index, so runs
    are repeatable. It then iterates over full batches only; a trailing partial batch
    is dropped.

    Args:
        X: Inputs, shape (num_samples, num_factors).
        Y: Targets, shape (num_samples, num_assets).
        params: Initial parameter dict (e.g. ``RiskFactorKAN.params``).
        num_epochs: Number of passes over the data.
        batch_size: Rows per mini-batch.

    Returns:
        (params, losses): the trained parameters and, per epoch, the mean batch loss
        as a Python float.

    Raises:
        ValueError: if ``batch_size`` is not positive or exceeds the number of samples
            (no full batch could be formed, which would otherwise divide by zero).

    Note:
        Uses the module-level ``optimizer`` (assigned in ``main``) and ``train_step``.
    """
    num_samples = X.shape[0]
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    num_batches = num_samples // batch_size
    if num_batches == 0:
        raise ValueError(
            f"batch_size ({batch_size}) exceeds the number of samples ({num_samples}); "
            "no full batch can be formed"
        )

    losses = []
    opt_state = optimizer.init(params)

    for epoch in range(num_epochs):
        # Deterministic per-epoch shuffle
        perm = jax.random.permutation(jax.random.PRNGKey(epoch), num_samples)
        X_shuffled = X[perm]
        Y_shuffled = Y[perm]

        epoch_loss = 0.0
        for batch in range(num_batches):
            start = batch * batch_size
            stop = start + batch_size
            params, opt_state, batch_loss = train_step(
                params, X_shuffled[start:stop], Y_shuffled[start:stop], opt_state)
            epoch_loss += batch_loss

        # Convert once per epoch so the loss history holds plain floats
        epoch_loss = float(epoch_loss) / num_batches
        losses.append(epoch_loss)

        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {epoch_loss:.6f}")

    return params, losses

# Risk attribution and decomposition functions
def compute_risk_attribution(model, params, factor_returns, asset_returns):
    """
    Compute risk attribution using the trained KAN model.

    The Jacobian d(asset)/d(factor) is evaluated at every sample. The contribution of
    factor f to asset a at time t is ``jacobian[t, a, f] * factor_returns[t, f]``.
    Attribution is the variance over time of each contribution divided by the
    asset's total return variance.

    Args:
        model: Callable mapping a (batch, num_factors) array to predictions with at
            least ``num_assets`` columns, and exposing ``update_params(params)``.
        params: Model parameters, loaded into ``model`` before evaluation.
        factor_returns: Risk factor returns, shape (num_samples, num_factors).
        asset_returns: Asset returns, shape (num_samples, num_assets).

    Returns:
        risk_attribution: (num_assets, num_factors) contribution variance / total variance.
        avg_sensitivity: (num_assets, num_factors) Jacobian averaged over samples.
        factor_contributions: (num_samples, num_assets, num_factors).
    """
    model.update_params(params)

    num_assets = asset_returns.shape[1]

    def predict_one(factors):
        # Evaluate a single factor vector as a batch of one; keep only the asset columns.
        return model(factors.reshape(1, -1))[0, :num_assets]

    # Reverse-mode Jacobian per sample: shape (num_samples, num_assets, num_factors)
    per_sample_jac = jax.vmap(jax.jacrev(predict_one))(factor_returns)

    # Average sensitivity over all time periods: (num_assets, num_factors)
    avg_sensitivity = jnp.mean(per_sample_jac, axis=0)

    # A single broadcast instead of a per-(t, a) `.at[].set` loop, where each update
    # would produce a fresh copy of the whole (T, A, F) array.
    factor_contributions = per_sample_jac * factor_returns[:, None, :]

    factor_contrib_var = jnp.var(factor_contributions, axis=0)  # (num_assets, num_factors)
    total_var = jnp.var(asset_returns, axis=0)  # (num_assets,)
    risk_attribution = factor_contrib_var / total_var[:, None]

    return risk_attribution, avg_sensitivity, factor_contributions

# Visualization and analysis functions
def visualize_risk_decomposition(risk_attribution, true_exposures, factor_names=None, asset_names=None):
    """
    Visualize risk decomposition and compare with true factor exposures.

    Plots up to six assets. Each row shows the model attribution on the left and the
    squared, row-normalized true betas on the right. Both panels are fractions
    (a normalized row sums to 1).

    Args:
        risk_attribution: Computed risk attribution, shape (num_assets, num_factors)
        true_exposures: True factor exposures used to generate data, same shape
        factor_names: Names of risk factors (default "Factor 1", ...)
        asset_names: Names of assets (default "Asset 1", ...)

    Returns:
        The matplotlib Figure.
    """
    risk_attribution_np = np.array(risk_attribution)
    true_exposures_np = np.array(true_exposures)

    # Square true exposures to compare with variance-based attribution, then
    # normalize each asset's row to sum to 1.
    true_exposures_squared = true_exposures_np ** 2
    true_exposures_norm = true_exposures_squared / true_exposures_squared.sum(axis=1, keepdims=True)

    if factor_names is None:
        factor_names = [f"Factor {i+1}" for i in range(risk_attribution_np.shape[1])]

    if asset_names is None:
        asset_names = [f"Asset {i+1}" for i in range(risk_attribution_np.shape[0])]

    num_assets_to_plot = min(6, risk_attribution_np.shape[0])
    num_factors = risk_attribution_np.shape[1]

    # squeeze=False keeps `axs` 2-D so axs[i, j] also works when one asset is plotted.
    fig, axs = plt.subplots(num_assets_to_plot, 2, figsize=(15, 3 * num_assets_to_plot),
                            squeeze=False)

    colors = plt.cm.tab10(np.linspace(0, 1, num_factors))

    for i in range(num_assets_to_plot):
        # KAN-derived risk attribution
        axs[i, 0].bar(range(num_factors), risk_attribution_np[i], color=colors)
        axs[i, 0].set_title(f"{asset_names[i]}: KAN Risk Attribution")
        axs[i, 0].set_ylabel("Risk Contribution (fraction)")
        axs[i, 0].set_xticks(range(num_factors))
        axs[i, 0].set_xticklabels(factor_names)
        axs[i, 0].set_ylim(0, max(1.0, np.max(risk_attribution_np[i]) * 1.1))

        # True factor exposures (squared and normalized)
        axs[i, 1].bar(range(num_factors), true_exposures_norm[i], color=colors)
        axs[i, 1].set_title(f"{asset_names[i]}: True Risk Attribution")
        axs[i, 1].set_ylabel("Risk Contribution (fraction)")
        axs[i, 1].set_xticks(range(num_factors))
        axs[i, 1].set_xticklabels(factor_names)
        axs[i, 1].set_ylim(0, max(1.0, np.max(true_exposures_norm[i]) * 1.1))

    plt.tight_layout()
    return fig

def analyze_factor_activations(model, factor_returns):
    """
    Plot the learned activation tables of the first hidden layer (up to four units).

    Args:
        model: Trained RiskFactorKAN model; only ``model.layers[0]`` is read.
        factor_returns: Risk factor returns. Not used by the plot; kept so existing
            callers keep working.

    Returns:
        The matplotlib Figure.
    """
    first_layer = model.layers[0]
    grid_points = np.asarray(first_layer.grid_points)
    activations = np.asarray(first_layer.activations)

    num_units = min(4, activations.shape[0])

    # squeeze=False keeps `axs` 2-D, so indexing works even for a single unit.
    fig, axs = plt.subplots(num_units, 1, figsize=(10, 2.5 * num_units), squeeze=False)

    for i in range(num_units):
        ax = axs[i, 0]
        ax.plot(grid_points, activations[i])
        ax.set_title(f"Learned Activation for Hidden Unit {i+1}")
        ax.set_xlabel("Input")
        ax.set_ylabel("Activation")
        ax.grid(True)

    plt.tight_layout()
    return fig

def factor_correlation_analysis(model, params, factor_returns, asset_returns):
    """
    Correlate the first hidden layer's outputs with the true factor returns.

    The first hidden layer's unit outputs are treated as "KAN-derived factors". Their
    Pearson correlation with each true factor is shown as a heatmap.

    Args:
        model: Trained RiskFactorKAN model
        params: Model parameters, loaded into ``model`` first
        factor_returns: True factor returns, shape (num_samples, num_factors)
        asset_returns: Asset returns (not used by this analysis)

    Returns:
        (fig, correlations), where ``correlations`` has shape
        (num_first_layer_units, num_factors).
    """
    model.update_params(params)

    kan_np = np.array(model.layers[0](factor_returns))
    true_np = np.array(factor_returns)
    n_kan = kan_np.shape[1]
    n_true = true_np.shape[1]

    correlations = np.array(
        [[np.corrcoef(kan_np[:, u], true_np[:, f])[0, 1] for f in range(n_true)]
         for u in range(n_kan)],
        dtype=float,
    ).reshape(n_kan, n_true)

    fig, ax = plt.subplots(figsize=(10, 8))
    im = ax.imshow(correlations, cmap='coolwarm', vmin=-1, vmax=1)

    cbar = ax.figure.colorbar(im, ax=ax)
    cbar.ax.set_ylabel("Correlation", rotation=-90, va="bottom")

    ax.set_xlabel("True Factors")
    ax.set_ylabel("KAN-derived Factors")

    ax.set_xticks(np.arange(n_true))
    ax.set_yticks(np.arange(n_kan))
    ax.set_xticklabels([f"Factor {f+1}" for f in range(n_true)])
    ax.set_yticklabels([f"KAN Factor {u+1}" for u in range(n_kan)])

    plt.title("Correlation between KAN-derived Factors and True Factors")
    plt.tight_layout()

    return fig, correlations

def stress_test_analysis(model, params, factor_returns, asset_returns, stress_magnitude=3.0):
    """
    Perform a one-factor-at-a-time stress test on an equally weighted portfolio.

    Scenario i sets every factor to its sample mean, except factor i, which is moved
    up by ``stress_magnitude`` standard deviations. The model's predicted portfolio
    return under each scenario is compared with the prediction at the mean.

    Args:
        model: Trained RiskFactorKAN model
        params: Model parameters, loaded into ``model`` first
        factor_returns: Risk factor returns, shape (num_samples, num_factors)
        asset_returns: Asset returns; only the number of columns is used
        stress_magnitude: Magnitude of stress scenarios in standard deviations

    Returns:
        (fig, percentage_changes): a bar chart and the per-factor change relative to
        the baseline return, in percent. The baseline's absolute value is floored at
        1e-12 so that a zero baseline does not produce inf/nan. Values can still be
        very large when the baseline return is close to zero.
    """
    model.update_params(params)

    num_factors = factor_returns.shape[1]

    factor_means = jnp.mean(factor_returns, axis=0)
    factor_stds = jnp.std(factor_returns, axis=0)

    # Row i: the mean vector with entry i shifted by stress_magnitude * std_i
    stress_scenarios = factor_means + jnp.diag(stress_magnitude * factor_stds)

    stress_returns = model(stress_scenarios)

    # Equal-weighted portfolio
    num_assets = asset_returns.shape[1]
    portfolio_weights = jnp.ones(num_assets) / num_assets
    portfolio_stress_returns = jnp.dot(stress_returns, portfolio_weights)

    normal_return = jnp.dot(model(factor_means.reshape(1, -1))[0], portfolio_weights)

    denom = jnp.maximum(jnp.abs(normal_return), 1e-12)
    percentage_changes = (portfolio_stress_returns - normal_return) / denom * 100

    fig, ax = plt.subplots(figsize=(10, 6))
    bars = ax.bar(range(num_factors), percentage_changes)

    # Red for losses, green for gains
    for i, bar in enumerate(bars):
        if percentage_changes[i] < 0:
            bar.set_color('red')
        else:
            bar.set_color('green')

    ax.axhline(y=0, color='black', linestyle='-', alpha=0.3)
    ax.set_xlabel('Stressed Factor')
    ax.set_ylabel('Portfolio Return Change (%)')
    ax.set_title(f'Portfolio Stress Test (Factor Shocks: {stress_magnitude} std)')
    ax.set_xticks(range(num_factors))
    ax.set_xticklabels([f"Factor {i+1}" for i in range(num_factors)])
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    return fig, percentage_changes

# Main function to run everything
def main():
    """Run the end-to-end demo pipeline.

    The steps are: generate data, make a chronological 80/20 split, train a
    RiskFactorKAN, report test MSE, then run the attribution, activation,
    correlation and stress-test analyses.

    Returns:
        (model, trained_params, losses, risk_attribution,
         fig_risk, fig_activations, fig_corr, fig_stress)
    """
    print("Generating synthetic risk factor data...")
    factor_returns, asset_returns, true_exposures = generate_risk_factor_data(
        num_samples=2000, num_risk_factors=5, num_assets=20)

    # Chronological split: first 80% of rows train, the rest test
    split = int(0.8 * factor_returns.shape[0])
    X_train, X_test = factor_returns[:split], factor_returns[split:]
    Y_train, Y_test = asset_returns[:split], asset_returns[split:]

    print("Initializing KAN model for risk factor decomposition...")
    num_factors = factor_returns.shape[1]
    num_assets = asset_returns.shape[1]
    model = RiskFactorKAN(num_factors, num_assets, hidden_dims=[32, 16])

    # train_model / train_step read `optimizer` from module scope
    global optimizer
    optimizer = optax.adam(0.001)

    print("Training model...")
    trained_params, losses = train_model(X_train, Y_train, model.params, num_epochs=100)

    model.update_params(trained_params)
    test_mse = jnp.mean((model(X_test) - Y_test) ** 2)
    print(f"Test MSE: {test_mse:.6f}")

    print("Computing risk attribution...")
    risk_attribution, _sensitivity, _factor_contributions = compute_risk_attribution(
        model, trained_params, factor_returns, asset_returns)

    factor_names = ["Market", "Size", "Value", "Momentum", "Volatility"][:num_factors]

    print("Visualizing risk decomposition...")
    fig_risk = visualize_risk_decomposition(risk_attribution, true_exposures, factor_names)

    print("Analyzing factor activations...")
    model.update_params(trained_params)
    fig_activations = analyze_factor_activations(model, factor_returns)

    print("Analyzing factor correlations...")
    fig_corr, _correlations = factor_correlation_analysis(
        model, trained_params, factor_returns, asset_returns)

    print("Performing stress test analysis...")
    fig_stress, _stress_results = stress_test_analysis(
        model, trained_params, factor_returns, asset_returns)

    return (model, trained_params, losses, risk_attribution,
            fig_risk, fig_activations, fig_corr, fig_stress)


if __name__ == "__main__":
    main()