In [None]:
!pip install optax yfinance lxml plotly hmmlearn "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [None]:
import jax
import jax.numpy as jnp
from jax import grad, vmap, jit
import numpy as np
import optax
from functools import partial
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Callable, Any
import pandas as pd
import yfinance as yf
from datetime import datetime, timedelta
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from sklearn.decomposition import PCA
from scipy.stats import gaussian_kde
import warnings
warnings.filterwarnings('ignore')

# Module-wide JAX PRNG key derived from a fixed seed (42) so runs are reproducible.
key = jax.random.PRNGKey(seed=42)

###########################################
# KAN Model Implementation
###########################################

# KAN Layer implementation
class KANLayer:
    """Kolmogorov-Arnold-style layer: a linear map followed by one learnable,
    piecewise-linear activation function per output unit.

    Each activation is stored as its values at ``num_basis`` evenly spaced grid
    points covering ``domain``. In the forward pass the pre-activations are
    clipped to ``domain`` and each activation is evaluated by linear
    interpolation between the two neighbouring grid values.

    Attributes:
        weights: ``(input_dim, output_dim)`` linear weights.
        biases: ``(output_dim,)`` biases.
        grid_points: ``(num_basis,)`` grid over ``domain``.
        activations: ``(output_dim, num_basis)`` activation values at the grid points.
        domain: ``(low, high)`` range that pre-activations are clipped to.
    """

    def __init__(self, input_dim: int, output_dim: int, num_basis: int = 30,
                 domain=(-3.0, 3.0), key=None):
        """Initialize a KAN layer with learnable activation functions.

        Args:
            input_dim: Number of input features.
            output_dim: Number of output units (one activation function each).
            num_basis: Number of grid points per activation function.
            domain: ``(low, high)`` range of the activation grid.
            key: JAX PRNG key; ``PRNGKey(0)`` when omitted.

        Raises:
            ValueError: If ``num_basis < 2`` (interpolation needs two grid points).
        """
        if num_basis < 2:
            raise ValueError(f"num_basis must be >= 2, got {num_basis}")
        if key is None:
            key = jax.random.PRNGKey(0)

        key1, key2, key3 = jax.random.split(key, 3)

        # Small random linear weights and biases.
        self.weights = jax.random.normal(key1, (input_dim, output_dim)) * 0.1
        self.biases = jax.random.normal(key2, (output_dim,)) * 0.01

        # Grid points on which the activation functions are represented.
        self.grid_points = jnp.linspace(domain[0], domain[1], num_basis)

        # Candidate initial shapes; one is drawn at random per output unit.
        init_shapes = (
            lambda g: g,                                # linear-like (trend components)
            lambda g: jnp.maximum(0, g),                # ReLU-like (one-sided responses)
            lambda g: 1.0 / (1.0 + jnp.exp(-2.0 * g)),  # sigmoid-like (regime transitions)
            jnp.tanh,                                   # tanh-like (cyclical components)
            lambda g: jnp.sin(g * 2.0),                 # wave-like (seasonal patterns)
        )

        activations_list = []
        for i in range(output_dim):
            unit_key = jax.random.fold_in(key3, i)
            # Separate keys for the shape choice and the noise: drawing both
            # from one key makes the two draws statistically dependent.
            type_key, noise_key = jax.random.split(unit_key)
            init_type = int(jax.random.randint(type_key, (), 0, len(init_shapes)))
            act = init_shapes[init_type](self.grid_points)
            # Add noise to break symmetry between units with the same shape.
            act = act + jax.random.normal(noise_key, (num_basis,)) * 0.05
            activations_list.append(act)

        # Matrix of activation values: (output_dim, num_basis).
        self.activations = jnp.stack(activations_list)

        # Pre-activations are clipped to this range in __call__.
        self.domain = domain

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Forward pass through the KAN layer.

        Args:
            x: Array whose last axis has size ``input_dim``, e.g.
                ``(batch_size, input_dim)`` or ``(input_dim,)``.

        Returns:
            Array with the same leading shape and last axis ``output_dim``.
        """
        # Linear transformation.
        z = jnp.dot(x, self.weights) + self.biases

        # Clamp to the grid so values outside ``domain`` map to the edge values
        # of each activation instead of being linearly extrapolated.
        z = jnp.clip(z, self.domain[0], self.domain[1])

        grid = self.grid_points
        # Left neighbour index of each value; clipped so idx + 1 stays on the grid.
        idx = jnp.clip(jnp.searchsorted(grid, z) - 1, 0, grid.shape[0] - 2)
        x0 = grid[idx]
        x1 = grid[idx + 1]

        # units (output_dim,) broadcasts against idx (..., output_dim), so
        # element [..., j] reads activation j at its own grid index.
        units = jnp.arange(self.activations.shape[0])
        y0 = self.activations[units, idx]
        y1 = self.activations[units, idx + 1]

        t = (z - x0) / (x1 - x0)
        return y0 + t * (y1 - y0)

# Full KAN model for time series forecasting
class TimeSeriesKAN:
    """Feed-forward stack of KANLayer hidden layers plus a KANLayer output layer.

    Parameters are exposed as a flat dict (``params``) so they can be passed to
    JAX transformations / optax, and written back with ``update_params``.
    """

    def __init__(self, input_dim: int, output_dim: int, hidden_dims: Tuple[int, ...] = (64, 32),
                 num_basis: int = 30, domain=(-3.0, 3.0), key=None):
        """Initialize a KAN model for time series forecasting.

        Args:
            input_dim: Number of input features.
            output_dim: Number of outputs (e.g. forecast horizon).
            hidden_dims: Widths of the hidden layers; any sequence of ints
                (list or tuple). The default is a tuple so it cannot be
                mutated and shared across instances.
            num_basis: Grid points per activation function.
            domain: ``(low, high)`` activation grid range.
            key: JAX PRNG key; ``PRNGKey(0)`` when omitted.
        """
        if key is None:
            key = jax.random.PRNGKey(0)

        hidden_dims = list(hidden_dims)
        keys = jax.random.split(key, len(hidden_dims) + 1)

        # Hidden layers, each consuming the previous layer's width.
        self.layers = []
        prev_dim = input_dim
        for layer_key, hidden_dim in zip(keys, hidden_dims):
            self.layers.append(KANLayer(prev_dim, hidden_dim, num_basis, domain, layer_key))
            prev_dim = hidden_dim

        # Final layer for time series prediction.
        self.output_layer = KANLayer(prev_dim, output_dim, num_basis, domain, keys[-1])

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Forward pass through the KAN model."""
        for layer in self.layers:
            x = layer(x)
        return self.output_layer(x)

    def _named_layers(self):
        """Yield ``(prefix, layer)`` pairs in parameter-dict order."""
        for i, layer in enumerate(self.layers):
            yield f'layer_{i}', layer
        yield 'output_layer', self.output_layer

    @property
    def params(self):
        """Model parameters as a flat dict: ``{prefix}_weights/_biases/_activations``."""
        params = {}
        for prefix, layer in self._named_layers():
            params[f'{prefix}_weights'] = layer.weights
            params[f'{prefix}_biases'] = layer.biases
            params[f'{prefix}_activations'] = layer.activations
        return params

    def update_params(self, params):
        """Write parameters from a flat dict (same keys as ``params``) into the layers."""
        for prefix, layer in self._named_layers():
            layer.weights = params[f'{prefix}_weights']
            layer.biases = params[f'{prefix}_biases']
            layer.activations = params[f'{prefix}_activations']

###########################################
# Data Processing Functions
###########################################

# Function to fetch S&P 500 data and top stocks by market cap
def fetch_sp500_top_stocks(n=10):
    """
    Return the tickers of the ``n`` largest S&P 500 companies by market cap.

    The ranking comes from a hard-coded snapshot (no network access). A ranked
    summary of the selected companies is printed as a side effect.
    """
    print(f"Fetching top {n} stocks from S&P 500 by market cap...")

    # Static snapshot (ticker, name, market cap in USD), largest first.
    snapshot = (
        ("AAPL", "Apple Inc.", 3_200_000_000_000),
        ("MSFT", "Microsoft Corporation", 3_100_000_000_000),
        ("NVDA", "NVIDIA Corporation", 2_500_000_000_000),
        ("AMZN", "Amazon.com Inc.", 1_900_000_000_000),
        ("GOOGL", "Alphabet Inc.", 1_850_000_000_000),
        ("META", "Meta Platforms Inc.", 1_200_000_000_000),
        ("BRK-B", "Berkshire Hathaway Inc.", 900_000_000_000),
        ("LLY", "Eli Lilly and Company", 850_000_000_000),
        ("AVGO", "Broadcom Inc.", 730_000_000_000),
        ("TSLA", "Tesla, Inc.", 700_000_000_000),
    )

    selected = snapshot[:n]
    for rank, (ticker, name, cap) in enumerate(selected, start=1):
        print(f"{rank}. {name} ({ticker}): Market Cap ${cap / 1e9:.2f}B")

    return [ticker for ticker, _, _ in selected]

# Function to download and prepare stock data
def _add_technical_features(stock_data):
    """Add return, volatility, volume, momentum, moving-average, RSI and Bollinger columns to ``stock_data`` (in place) and return it."""
    # Simple and log returns.
    stock_data['Returns'] = stock_data['Close'].pct_change()
    stock_data['LogReturns'] = np.log(stock_data['Close'] / stock_data['Close'].shift(1))

    # Volatility: 20-day rolling standard deviation of returns.
    stock_data['Volatility'] = stock_data['Returns'].rolling(window=20).std()

    # Trading volume features.
    stock_data['VolumeChange'] = stock_data['Volume'].pct_change()
    stock_data['VolumeSMA'] = stock_data['Volume'].rolling(window=10).mean()
    stock_data['VolumeRatio'] = stock_data['Volume'] / stock_data['VolumeSMA']

    # Price momentum over several horizons.
    for window in [5, 10, 20, 60]:
        stock_data[f'Momentum_{window}d'] = stock_data['Close'].pct_change(periods=window)

    # Moving averages and their ratios.
    for window in [10, 20, 50, 200]:
        stock_data[f'MA_{window}d'] = stock_data['Close'].rolling(window=window).mean()
    stock_data['MA_ratio_10_50'] = stock_data['MA_10d'] / stock_data['MA_50d']
    stock_data['MA_ratio_50_200'] = stock_data['MA_50d'] / stock_data['MA_200d']

    # RSI (Relative Strength Index) from 14-day simple averages of gains/losses.
    delta = stock_data['Close'].diff()
    gain = delta.where(delta > 0, 0).fillna(0)
    loss = -delta.where(delta < 0, 0).fillna(0)
    avg_gain = gain.rolling(window=14).mean()
    avg_loss = loss.rolling(window=14).mean()
    rs = avg_gain / avg_loss
    stock_data['RSI'] = 100 - (100 / (1 + rs))

    # Bollinger Bands (20-day, 2 standard deviations) and their relative width.
    stock_data['BB_middle'] = stock_data['Close'].rolling(window=20).mean()
    stock_data['BB_std'] = stock_data['Close'].rolling(window=20).std()
    stock_data['BB_upper'] = stock_data['BB_middle'] + 2 * stock_data['BB_std']
    stock_data['BB_lower'] = stock_data['BB_middle'] - 2 * stock_data['BB_std']
    stock_data['BB_width'] = (stock_data['BB_upper'] - stock_data['BB_lower']) / stock_data['BB_middle']

    return stock_data


def download_stock_data(tickers, start_date='2018-01-01', end_date=None):
    """
    Download historical stock data for the given tickers and add technical features.

    Args:
        tickers: A ticker symbol or a list of ticker symbols.
        start_date: First date to download (``YYYY-MM-DD``).
        end_date: Last date (exclusive, per yfinance); defaults to today.

    Returns:
        Dict mapping ticker -> DataFrame with the downloaded price/volume
        columns plus the engineered features, with rows containing NaN or
        infinite values removed. Tickers that fail or yield no usable rows
        are reported on stdout and omitted.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(tickers, str):
        tickers = [tickers]

    if end_date is None:
        end_date = datetime.now().strftime('%Y-%m-%d')

    print(f"Downloading stock data from {start_date} to {end_date}...")

    # Download data for all tickers at once.
    data = yf.download(tickers, start=start_date, end=end_date, auto_adjust=True)

    ticker_data = {}
    for ticker in tickers:
        try:
            # Recent yfinance versions return (Price, Ticker) MultiIndex columns
            # even for a single ticker; older ones return flat columns then.
            if isinstance(data.columns, pd.MultiIndex):
                stock_data = data.xs(ticker, level=1, axis=1).copy()
            else:
                stock_data = data.copy()

            stock_data = _add_technical_features(stock_data)

            # pct_change / ratios yield +-inf (e.g. zero volume), which dropna()
            # alone keeps; treat them as missing so they cannot poison training.
            stock_data = stock_data.replace([np.inf, -np.inf], np.nan).dropna()

            if stock_data.empty:
                print(f"No usable data for {ticker}")
                continue

            ticker_data[ticker] = stock_data
        except Exception as e:
            # Best-effort per ticker: report and continue with the others.
            print(f"Error processing {ticker}: {e}")

    return ticker_data

# Data preparation for stock price forecasting
def prepare_stock_data(stock_data, lookback_window=20, forecast_horizon=5, test_split=0.2):
    """
    Prepare stock data for supervised learning with sliding windows.

    For every time step ``t`` that has ``lookback_window`` past rows and
    ``forecast_horizon`` future rows available, the input row is the
    feature-major flattening ``[f0[t-L..t-1], f1[t-L..t-1], ...]`` and the
    target is ``Returns[t..t+H-1]``. Samples are split chronologically.

    Args:
        stock_data: DataFrame containing all columns in ``feature_cols``.
        lookback_window: Number of past rows per input window (L).
        forecast_horizon: Number of future returns to predict (H).
        test_split: Fraction of samples (the most recent) used for testing.

    Returns:
        ``(X_train, y_train, X_test, y_test, feature_cols)``; X arrays have
        shape ``(n, len(feature_cols) * L)`` and y arrays ``(n, H)``, including
        when there are zero samples.
    """
    # Features to use for stock price prediction.
    feature_cols = [
        'Returns', 'LogReturns', 'Volatility', 'VolumeChange', 'VolumeRatio',
        'Momentum_5d', 'Momentum_10d', 'Momentum_20d', 'Momentum_60d',
        'MA_ratio_10_50', 'MA_ratio_50_200', 'RSI', 'BB_width'
    ]

    # Target is future returns.
    target_col = 'Returns'

    features = stock_data[feature_cols].to_numpy(dtype=float)  # (T, F)
    target = stock_data[target_col].to_numpy(dtype=float)      # (T,)

    # Last valid t is len - H (its target window ends at the final row).
    num_samples = max(0, len(stock_data) - lookback_window - forecast_horizon + 1)
    t = np.arange(lookback_window, lookback_window + num_samples)

    past = t[:, None] + np.arange(-lookback_window, 0)[None, :]   # (N, L) row indices
    future = t[:, None] + np.arange(forecast_horizon)[None, :]    # (N, H) row indices

    # (N, L, F) -> (N, F, L) -> feature-major flat rows. Explicit width keeps
    # the shape correct when N == 0 (reshape cannot infer -1 for empty arrays).
    X = features[past].transpose(0, 2, 1).reshape(num_samples, len(feature_cols) * lookback_window)
    y = target[future]

    # Chronological train/test split.
    train_size = int(len(X) * (1 - test_split))
    X_train, X_test = X[:train_size], X[train_size:]
    y_train, y_test = y[:train_size], y[train_size:]

    # Convert to JAX arrays.
    X_train = jnp.array(X_train)
    y_train = jnp.array(y_train)
    X_test = jnp.array(X_test)
    y_test = jnp.array(y_test)

    return X_train, y_train, X_test, y_test, feature_cols

###########################################
# Model Training Functions
###########################################

@partial(jit, static_argnums=(1,))
def loss_fn(params, model, X, y, lambda_smooth=0.001):
    """
    Loss function for time series forecasting with smoothness regularization.

    ``model`` is a static jit argument (Python objects cannot be traced); it is
    hashed by identity, so one compilation is cached per model object. Its
    parameters are temporarily swapped for ``params`` to run the forward pass
    and then restored, so no tracers leak into the model.

    Args:
        params: Flat parameter dict (keys as produced by ``model.params``).
        model: Object with ``params``, ``update_params``, ``layers`` and ``__call__``.
        X: Input batch.
        y: Target batch, same shape as ``model(X)``.
        lambda_smooth: Weight of the activation curvature penalty.

    Returns:
        Scalar ``MSE + lambda_smooth * smoothness``.
    """
    saved = model.params
    model.update_params(params)
    try:
        pred = model(X)
    finally:
        # Put the concrete arrays back; leaving traced values on the model
        # would break any later use of it outside this trace.
        model.update_params(saved)

    # MSE loss.
    mse_loss = jnp.mean((pred - y) ** 2)

    def curvature_penalty(act):
        """Mean squared second finite difference along the grid axis."""
        second_deriv = act[:, 2:] - 2 * act[:, 1:-1] + act[:, :-2]
        return jnp.mean(second_deriv ** 2)

    # Smoothness regularization for hidden-layer and output-layer activations.
    smooth_reg = 0.0
    for layer_idx in range(len(model.layers)):
        smooth_reg += curvature_penalty(params[f'layer_{layer_idx}_activations'])
    smooth_reg += curvature_penalty(params['output_layer_activations'])

    return mse_loss + lambda_smooth * smooth_reg

# Non-jitted version that takes optimizer as parameter
def train_step(params, optimizer, opt_state, model, X, y, lambda_smooth=0.001):
    """Single optimization step on ``loss_fn``.

    Not jitted itself; the optax ``optimizer`` is passed in as an argument.

    Returns:
        ``(new_params, new_opt_state, loss_value)``.
    """
    loss_value, grads = jax.value_and_grad(
        lambda p: loss_fn(p, model, X, y, lambda_smooth)
    )(params)
    # Passing params lets transforms that need them (e.g. weight decay in
    # optax.adamw) work; transforms that ignore params are unaffected.
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

def train_model(model, X_train, y_train, num_epochs=100, batch_size=64, lambda_smooth=0.001):
    """
    Train the model for a specified number of epochs.

    The optimized loss is the MSE of ``model(X)`` plus ``lambda_smooth`` times a
    curvature penalty (mean squared second difference) on every learned
    activation table. The model's own forward pass is used, so the function
    being trained is the same one later used for prediction.

    Args:
        model: TimeSeriesKAN model
        X_train: Training features
        y_train: Training targets
        num_epochs: Number of training epochs
        batch_size: Batch size for training (the remainder that does not fill
            a full batch is skipped each epoch; data is reshuffled per epoch)
        lambda_smooth: Smoothness regularization strength

    Returns:
        Trained model and loss history (one float per epoch: mean batch loss).

    Raises:
        ValueError: If ``X_train`` contains no samples.
    """
    params = model.params
    num_samples = X_train.shape[0]
    if num_samples == 0:
        raise ValueError("X_train is empty; nothing to train on.")
    num_batches = max(1, num_samples // batch_size)
    total_steps = num_epochs * num_batches

    # Learning rate schedule with warm-up (first 10% of steps) and cosine decay.
    schedule_fn = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=0.001,
        warmup_steps=total_steps // 10,
        decay_steps=total_steps,
        end_value=0.0001
    )

    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),  # Gradient clipping
        optax.adam(learning_rate=schedule_fn)
    )
    opt_state = optimizer.init(params)

    def curvature_penalty(act):
        """Mean squared second difference along the grid axis (0 if < 3 points)."""
        if act.shape[-1] < 3:
            return 0.0
        second_deriv = act[:, 2:] - 2 * act[:, 1:-1] + act[:, :-2]
        return jnp.mean(second_deriv ** 2)

    def objective(p, X, y, smooth_strength):
        """Loss of the real model evaluated at parameters ``p``."""
        saved = model.params
        model.update_params(p)
        try:
            pred = model(X)
        finally:
            # Restore concrete arrays so traced values never stay on ``model``.
            model.update_params(saved)
        mse_loss = jnp.mean((pred - y) ** 2)
        smooth_reg = sum(
            curvature_penalty(v) for k, v in p.items() if k.endswith('_activations')
        )
        return mse_loss + smooth_strength * smooth_reg

    @jit
    def train_step_jit(p, state, X, y, smooth_strength):
        """One optimizer update; ``model`` is reached through the closure."""
        loss_value, grads = jax.value_and_grad(objective)(p, X, y, smooth_strength)
        updates, state = optimizer.update(grads, state, p)
        return optax.apply_updates(p, updates), state, loss_value

    losses = []

    for epoch in range(num_epochs):
        # Deterministic per-epoch shuffle.
        perm = jax.random.permutation(jax.random.PRNGKey(epoch), num_samples)
        X_shuffled = X_train[perm]
        y_shuffled = y_train[perm]

        epoch_loss_sum = 0.0
        for batch in range(num_batches):
            start_idx = batch * batch_size
            end_idx = min(start_idx + batch_size, num_samples)
            params, opt_state, batch_loss = train_step_jit(
                params, opt_state,
                X_shuffled[start_idx:end_idx], y_shuffled[start_idx:end_idx],
                lambda_smooth,
            )
            epoch_loss_sum += batch_loss

        epoch_loss = float(epoch_loss_sum) / num_batches
        losses.append(epoch_loss)

        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {epoch_loss:.6f}")

    # Update model with trained parameters.
    model.update_params(params)
    return model, losses

###########################################
# Evaluation Functions
###########################################

def evaluate_forecasts(model, X_test, y_test):
    """
    Score multi-step forecasts against held-out targets.

    Args:
        model: Callable mapping ``X_test`` to predictions shaped like ``y_test``.
        X_test: Test inputs.
        y_test: Targets of shape ``(num_samples, horizon)``.

    Returns:
        ``(metrics, predictions)``. ``metrics`` holds per-step arrays 'mse',
        'mae', 'mape' (percent; NaN when every target at that step is ~0) and
        'directional_accuracy' (percent of matching signs), plus the step
        averages 'overall_mse', 'overall_mae' and
        'overall_directional_accuracy'. ``predictions`` is a numpy array.
    """
    forecast_raw = model(X_test)
    truth = np.array(y_test)
    forecast = np.array(forecast_raw)

    n_steps = truth.shape[1]
    metric_names = ('mse', 'mae', 'mape', 'directional_accuracy')
    results = {name: np.zeros(n_steps) for name in metric_names}

    for step in range(n_steps):
        actual = truth[:, step]
        guess = forecast[:, step]
        residual = actual - guess

        results['mse'][step] = np.mean(residual ** 2)
        results['mae'][step] = np.mean(np.abs(residual))

        # Percentage error is undefined for (near-)zero targets; skip them.
        usable = np.abs(actual) > 1e-10
        if usable.any():
            relative = (actual[usable] - guess[usable]) / actual[usable]
            results['mape'][step] = np.mean(np.abs(relative)) * 100
        else:
            results['mape'][step] = np.nan

        same_sign = np.sign(actual) == np.sign(guess)
        results['directional_accuracy'][step] = np.mean(same_sign) * 100

    results['overall_mse'] = np.mean(results['mse'])
    results['overall_mae'] = np.mean(results['mae'])
    results['overall_directional_accuracy'] = np.mean(results['directional_accuracy'])

    return results, forecast

###########################################
# Interactive Visualization Functions
###########################################

def create_interactive_forecast_plot(ticker, stock_data, X_test, y_test, predictions, test_start_idx, sample_idx=0):
    """
    Create an interactive plotly visualization of stock forecast vs actual values.

    The predicted and actual return paths of one test sample are compounded
    (``price * cumprod(1 + r)``) from the close of the row before the forecast
    window, and drawn over up to 30 rows of context before and 10 rows after.

    Args:
        ticker: Stock ticker symbol
        stock_data: Original stock data DataFrame (needs a 'Close' column and a date index)
        X_test: Test features (not used for plotting; kept for interface compatibility)
        y_test: Test targets, shape (num_samples, horizon)
        predictions: Model predictions, same shape as y_test
        test_start_idx: Starting index of test data in original time series
        sample_idx: Index of the sample to visualize

    Returns:
        Plotly figure object

    Raises:
        ValueError: If ``stock_data`` has fewer than two rows.
    """
    y_test_np = np.array(y_test)
    predictions_np = np.array(predictions)
    horizon = y_test_np.shape[1]
    n_rows = len(stock_data)
    if n_rows < 2:
        raise ValueError("stock_data needs at least two rows to plot a forecast")

    # Actual time index for this forecast, kept inside the data.
    time_idx = test_start_idx + sample_idx
    if time_idx >= n_rows:
        time_idx = n_rows - horizon - 1
    # The start price is read from row time_idx - 1; a value <= 0 here would
    # silently wrap around to the last row via negative indexing.
    time_idx = max(time_idx, 1)

    forecast_time_idx = np.arange(time_idx, min(time_idx + horizon, n_rows))

    # Context window: 30 rows before and 10 rows after the forecast.
    context_start = max(0, time_idx - 30)
    context_end = min(n_rows, time_idx + horizon + 10)

    context_dates = stock_data.index[context_start:context_end]
    forecast_dates = stock_data.index[forecast_time_idx]
    context_close = stock_data['Close'].iloc[context_start:context_end]

    # Compound returns from the last close before the forecast window.
    start_price = float(stock_data['Close'].iloc[time_idx - 1])
    forecast_returns = np.array(predictions_np[sample_idx][:len(forecast_time_idx)])
    actual_returns = np.array(y_test_np[sample_idx][:len(forecast_time_idx)])
    forecast_prices = start_price * np.cumprod(1 + forecast_returns)
    actual_future_prices = start_price * np.cumprod(1 + actual_returns)

    fig = go.Figure()

    fig.add_trace(go.Scatter(
        x=context_dates,
        y=context_close,
        mode='lines',
        name='Historical Price',
        line=dict(color='blue')
    ))

    fig.add_trace(go.Scatter(
        x=forecast_dates,
        y=forecast_prices,
        mode='lines',
        name='Forecasted Price',
        line=dict(color='red', dash='dash')
    ))

    fig.add_trace(go.Scatter(
        x=forecast_dates,
        y=actual_future_prices,
        mode='lines',
        name='Actual Future Price',
        line=dict(color='green')
    ))

    forecast_start_date = stock_data.index[time_idx]

    # Vertical extent of the forecast-start marker.
    all_prices = np.concatenate([
        context_close.values,
        forecast_prices,
        actual_future_prices
    ])
    y_min = np.min(all_prices) * 0.99
    y_max = np.max(all_prices) * 1.01

    # Vertical line drawn as a scatter trace; hoverinfo='skip' keeps this
    # helper trace out of the unified hover label.
    fig.add_trace(go.Scatter(
        x=[forecast_start_date, forecast_start_date],
        y=[y_min, y_max],
        mode='lines',
        line=dict(color='gray', width=2, dash='dash'),
        showlegend=False,
        hoverinfo='skip'
    ))

    fig.add_annotation(
        x=forecast_start_date,
        y=y_max,
        text="Forecast Start",
        showarrow=False,
        yshift=10
    )

    fig.update_layout(
        title=f'{ticker} Price Forecast',
        xaxis_title='Date',
        yaxis_title='Price ($)',
        hovermode='x unified',
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.01
        ),
        template='plotly_white'
    )

    # Range selector buttons and slider.
    fig.update_layout(
        xaxis=dict(
            rangeselector=dict(
                buttons=list([
                    dict(count=1, label="1m", step="month", stepmode="backward"),
                    dict(count=3, label="3m", step="month", stepmode="backward"),
                    dict(count=6, label="6m", step="month", stepmode="backward"),
                    dict(step="all")
                ])
            ),
            rangeslider=dict(visible=True),
            type="date"
        )
    )

    return fig
    
def create_interactive_feature_plot(ticker, stock_data, forecast_start_idx, features=('Close', 'Volume', 'Volatility', 'RSI')):
    """
    Create an interactive plotly visualization of multiple stock features.

    One subplot per entry of ``features`` (shared x axis); entries that are not
    columns of ``stock_data`` leave their subplot empty. A dashed vertical line
    marks ``stock_data.index[forecast_start_idx]`` in every populated subplot,
    and the range selector/slider is attached to the bottom x axis.

    Args:
        ticker: Stock ticker symbol (used in titles).
        stock_data: DataFrame indexed by date.
        forecast_start_idx: Row position of the forecast start.
        features: Column names to plot (any sequence; the default is an
            immutable tuple so it cannot be mutated across calls).

    Returns:
        Plotly figure object.
    """
    features = list(features)
    num_features = len(features)
    fig = make_subplots(rows=num_features, cols=1, shared_xaxes=True,
                        subplot_titles=[f"{ticker} - {feat}" for feat in features],
                        vertical_spacing=0.05)

    colors = ['blue', 'orange', 'green', 'red']

    # (subplot row, column name) for the features that actually exist.
    present = [(row, feature) for row, feature in enumerate(features, start=1)
               if feature in stock_data.columns]

    for row, feature in present:
        fig.add_trace(
            go.Scatter(
                x=stock_data.index,
                y=stock_data[feature],
                mode='lines',
                name=feature,
                line=dict(color=colors[(row - 1) % len(colors)])
            ),
            row=row, col=1
        )

    forecast_start_date = stock_data.index[forecast_start_idx]

    for row, feature in present:
        y_values = np.asarray(stock_data[feature], dtype=float)
        # NaN-aware extremes: a single NaN would otherwise make the line vanish.
        y_min = float(np.nanmin(y_values))
        y_max = float(np.nanmax(y_values))

        # 5% padding; a constant series still gets a visible marker.
        y_range = y_max - y_min
        pad = 0.05 * y_range if y_range > 0 else max(abs(y_max) * 0.05, 1.0)
        y_min -= pad
        y_max += pad

        # Vertical line as a scatter trace, excluded from the unified hover.
        fig.add_trace(
            go.Scatter(
                x=[forecast_start_date, forecast_start_date],
                y=[y_min, y_max],
                mode='lines',
                line=dict(color='gray', width=1, dash='dash'),
                showlegend=False,
                hoverinfo='skip'
            ),
            row=row, col=1
        )

    fig.update_layout(
        height=250 * num_features,
        width=1000,
        title_text=f"{ticker} - Key Features Analysis",
        showlegend=False,
        hovermode='x unified',
        template='plotly_white'
    )

    # With shared x axes the bottom subplot owns xaxis{N} ("xaxis" when N == 1).
    range_slider_axis = f"xaxis{num_features}" if num_features > 1 else "xaxis"

    layout_update = {
        range_slider_axis: dict(
            rangeselector=dict(
                buttons=list([
                    dict(count=1, label="1m", step="month", stepmode="backward"),
                    dict(count=3, label="3m", step="month", stepmode="backward"),
                    dict(count=6, label="6m", step="month", stepmode="backward"),
                    dict(step="all")
                ])
            ),
            rangeslider=dict(visible=True),
            type="date"
        )
    }

    fig.update_layout(**layout_update)

    return fig

def create_interactive_trading_simulation_plot(ticker, stock_data, portfolio_values, buy_hold_values, test_start_idx, performance):
    """
    Create an interactive plotly visualization of trading simulation results.

    Top panel: strategy vs buy-and-hold portfolio value. Bottom panel: each
    series' drawdown from its running peak, in percent. A metrics summary is
    annotated above the plot and the range selector/slider sits on the bottom
    (shared) x axis.

    Args:
        ticker: Stock ticker symbol (used in the title).
        stock_data: DataFrame whose index supplies the dates.
        portfolio_values: Strategy portfolio values, aligned with consecutive
            rows of ``stock_data`` starting at ``test_start_idx``.
        buy_hold_values: Buy-and-hold portfolio values over the same rows.
        test_start_idx: Row of ``stock_data`` matching the first value.
        performance: Dict providing 'strategy_return', 'strategy_sharpe',
            'strategy_mdd', 'buy_hold_return', 'buy_hold_sharpe', 'buy_hold_mdd'.

    Returns:
        Plotly figure object.
    """
    # numpy, plotly.graph_objects and make_subplots come from module-level imports.
    dates = stock_data.index[test_start_idx:test_start_idx + len(portfolio_values)]

    def calculate_drawdowns(values):
        """Percent decline of each value below its running maximum (0 where the peak is 0)."""
        values = np.asarray(values, dtype=float)
        peaks = np.maximum.accumulate(values)
        drawdowns = np.divide(peaks - values, peaks,
                              out=np.zeros_like(values), where=peaks != 0)
        return drawdowns * 100

    strategy_drawdowns = calculate_drawdowns(portfolio_values)
    bh_drawdowns = calculate_drawdowns(buy_hold_values)

    fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
                        vertical_spacing=0.05,
                        subplot_titles=("Portfolio Value ($)", "Drawdown (%)"),
                        row_heights=[0.7, 0.3])

    # Portfolio values.
    fig.add_trace(go.Scatter(
        x=dates,
        y=portfolio_values,
        mode='lines',
        name='KAN Trading Strategy',
        line=dict(color='blue', width=2)
    ), row=1, col=1)

    fig.add_trace(go.Scatter(
        x=dates,
        y=buy_hold_values,
        mode='lines',
        name='Buy & Hold Strategy',
        line=dict(color='red', width=2, dash='dash')
    ), row=1, col=1)

    # Drawdowns.
    fig.add_trace(go.Scatter(
        x=dates,
        y=strategy_drawdowns,
        mode='lines',
        name='KAN Strategy Drawdown',
        line=dict(color='blue', width=1)
    ), row=2, col=1)

    fig.add_trace(go.Scatter(
        x=dates,
        y=bh_drawdowns,
        mode='lines',
        name='Buy & Hold Drawdown',
        line=dict(color='red', width=1, dash='dash')
    ), row=2, col=1)

    fig.update_layout(
        title_text=f"{ticker} - Trading Simulation Results",
        hovermode='x unified',
        template='plotly_white',
        height=700,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )

    metrics_text = (
        f"KAN Strategy: Return={performance['strategy_return']:.2f}%, "
        f"Sharpe={performance['strategy_sharpe']:.2f}, "
        f"MDD={performance['strategy_mdd']:.2f}%<br>"
        f"Buy & Hold: Return={performance['buy_hold_return']:.2f}%, "
        f"Sharpe={performance['buy_hold_sharpe']:.2f}, "
        f"MDD={performance['buy_hold_mdd']:.2f}%"
    )

    fig.add_annotation(
        x=0.5,
        y=1.05,
        xref="paper",
        yref="paper",
        text=metrics_text,
        showarrow=False,
        font=dict(size=12),
        align="center",
        bgcolor="rgba(255, 255, 255, 0.8)",
        bordercolor="black",
        borderwidth=1,
        borderpad=4
    )

    # With shared x axes the bottom subplot owns "xaxis2"; attaching the slider
    # to "xaxis" would place it between the two panels.
    fig.update_layout(
        xaxis2=dict(
            rangeselector=dict(
                buttons=list([
                    dict(count=1, label="1m", step="month", stepmode="backward"),
                    dict(count=3, label="3m", step="month", stepmode="backward"),
                    dict(count=6, label="6m", step="month", stepmode="backward"),
                    dict(step="all")
                ])
            ),
            rangeslider=dict(visible=True),
            type="date"
        )
    )

    return fig


def create_interactive_kan_activations_plot(model, ticker, feature_cols, lookback_window):
    """
    Create interactive visualizations of the first KAN layer.

    Builds two figures from ``model.layers[0]``:

    1. A heatmap of per-feature importance for each hidden unit, computed as the
       sum of absolute first-layer weights over all lags of a feature, then
       normalized so that each unit's column sums to 1 (all-zero columns stay 0).
    2. Line plots of the learned activation values over ``grid_points`` for up
       to six evenly spaced units, each annotated with its three most important
       features (most important first).

    Args:
        model: Trained model exposing ``layers[0]`` with ``weights`` of shape
            (input_dim, num_units), ``activations`` (one row per unit) and
            ``grid_points``.
        ticker: Stock ticker symbol used in figure titles.
        feature_cols: Names of the input features.
        lookback_window: Number of lags per feature in the flattened input.

    Returns:
        Tuple ``(importance_fig, activations_fig)`` of Plotly figures.
    """
    first_layer = model.layers[0]
    weights = np.array(first_layer.weights)
    activations = np.array(first_layer.activations)
    grid_points = np.array(first_layer.grid_points)

    num_features = len(feature_cols)
    num_units = weights.shape[1]
    abs_weights = np.abs(weights)

    # Assumes the flattened input is feature-major (all lags of feature 0, then
    # feature 1, ...) -- TODO confirm against prepare_stock_data. A slice that
    # runs past the end of the weight matrix is shorter/empty and adds nothing.
    feature_importance = np.zeros((num_features, num_units))
    for f_idx in range(num_features):
        start = f_idx * lookback_window
        feature_importance[f_idx] = abs_weights[start:start + lookback_window].sum(axis=0)

    # Normalize each unit's column to sum to 1. A unit whose weights are all zero
    # would give 0/0 = NaN (breaking the heatmap and argsort), so keep it at 0.
    column_totals = feature_importance.sum(axis=0, keepdims=True)
    feature_importance = np.divide(
        feature_importance,
        column_totals,
        out=np.zeros_like(feature_importance),
        where=column_totals > 0,
    )

    # Create a heatmap for feature importance
    fig1 = go.Figure(data=go.Heatmap(
        z=feature_importance,
        x=[f'Unit {i}' for i in range(num_units)],
        y=feature_cols,
        colorscale='Viridis',
        colorbar=dict(title='Normalized Importance'),
        hovertemplate='Feature: %{y}<br>Unit: %{x}<br>Importance: %{z:.3f}<extra></extra>'
    ))

    fig1.update_layout(
        title=f"{ticker} - Feature Importance Based on First Layer Weights",
        xaxis_title="Hidden Units",
        yaxis_title="Features",
        template='plotly_white',
        height=600,
        width=1000
    )

    # Select up to six evenly spaced units to visualize
    num_units_to_show = min(6, activations.shape[0])
    selected_units = np.linspace(0, activations.shape[0] - 1, num_units_to_show, dtype=int)

    fig2 = make_subplots(rows=num_units_to_show, cols=1,
                         subplot_titles=[f"Unit {i} Activation Function" for i in selected_units],
                         vertical_spacing=0.05)

    for i, unit_idx in enumerate(selected_units):
        fig2.add_trace(
            go.Scatter(
                x=grid_points,
                y=activations[unit_idx],
                mode='lines',
                name=f'Unit {unit_idx} Activation',
                line=dict(color='royalblue', width=2)
            ),
            row=i + 1, col=1
        )

        if unit_idx < feature_importance.shape[1]:
            # Three largest importances, listed from most to least important
            top_features = np.argsort(feature_importance[:, unit_idx])[::-1][:3]
            top_feature_names = [feature_cols[f] for f in top_features]

            # Subplot 1 uses axis ids "x"/"y"; subplot k>1 uses "xk"/"yk"
            fig2.add_annotation(
                x=0.01,
                y=0.95,
                xref=f"x{i+1} domain" if i > 0 else "x domain",
                yref=f"y{i+1} domain" if i > 0 else "y domain",
                text=f"Top Features: {', '.join(top_feature_names)}",
                showarrow=False,
                align="left",
                bgcolor="rgba(255, 255, 255, 0.8)",
                bordercolor="black",
                borderwidth=1,
                font=dict(size=10)
            )

    fig2.update_layout(
        height=250 * num_units_to_show,
        width=1000,
        title_text=f"{ticker} - Learned Activation Functions",
        showlegend=False,
        template='plotly_white'
    )

    return fig1, fig2

def _drawdown_series_pct(values):
    """Return the running drawdown of *values* in percent below the running peak.

    Points whose running peak is zero are reported as 0 rather than dividing by zero.
    """
    values = np.asarray(values, dtype=float)
    peaks = np.maximum.accumulate(values)
    drawdowns = np.divide(peaks - values, peaks, out=np.zeros_like(values), where=peaks != 0)
    return drawdowns * 100


def create_interactive_trading_simulation_plot(ticker, stock_data, portfolio_values, buy_hold_values, test_start_idx, performance):
    """
    Create an interactive plotly visualization of trading simulation results.

    Top panel: strategy vs. buy & hold portfolio value. Bottom panel: running
    drawdown (%) of each. Headline metrics from ``performance`` are shown in an
    annotation, and the shared date axis gets range-selector buttons plus a
    range slider.

    Args:
        ticker: Stock ticker symbol used in the title.
        stock_data: DataFrame whose index supplies the dates.
        portfolio_values: Strategy portfolio value per step.
        buy_hold_values: Buy & hold portfolio value per step.
        test_start_idx: Position in ``stock_data.index`` of the first value.
        performance: Dict with ``strategy_return``, ``strategy_sharpe``,
            ``strategy_mdd``, ``buy_hold_return``, ``buy_hold_sharpe`` and
            ``buy_hold_mdd``.

    Returns:
        Plotly figure.
    """
    portfolio_values = np.asarray(portfolio_values, dtype=float)
    buy_hold_values = np.asarray(buy_hold_values, dtype=float)

    dates = stock_data.index[test_start_idx:test_start_idx + len(portfolio_values)]
    # The date slice is shorter than the series when it runs off the end of the
    # index; trim every series to a common length so x and y stay aligned.
    n_points = min(len(dates), len(portfolio_values), len(buy_hold_values))
    dates = dates[:n_points]
    portfolio_values = portfolio_values[:n_points]
    buy_hold_values = buy_hold_values[:n_points]

    strategy_drawdowns = _drawdown_series_pct(portfolio_values)
    bh_drawdowns = _drawdown_series_pct(buy_hold_values)

    fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
                        vertical_spacing=0.05,
                        subplot_titles=("Portfolio Value ($)", "Drawdown (%)"),
                        row_heights=[0.7, 0.3])

    # Portfolio values
    fig.add_trace(go.Scatter(
        x=dates,
        y=portfolio_values,
        mode='lines',
        name='KAN Trading Strategy',
        line=dict(color='blue', width=2)
    ), row=1, col=1)

    fig.add_trace(go.Scatter(
        x=dates,
        y=buy_hold_values,
        mode='lines',
        name='Buy & Hold Strategy',
        line=dict(color='red', width=2, dash='dash')
    ), row=1, col=1)

    # Drawdowns
    fig.add_trace(go.Scatter(
        x=dates,
        y=strategy_drawdowns,
        mode='lines',
        name='KAN Strategy Drawdown',
        line=dict(color='blue', width=1)
    ), row=2, col=1)

    fig.add_trace(go.Scatter(
        x=dates,
        y=bh_drawdowns,
        mode='lines',
        name='Buy & Hold Drawdown',
        line=dict(color='red', width=1, dash='dash')
    ), row=2, col=1)

    fig.update_layout(
        title_text=f"{ticker} - Trading Simulation Results",
        hovermode='x unified',
        template='plotly_white',
        height=700,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )

    metrics_text = (
        f"KAN Strategy: Return={performance['strategy_return']:.2f}%, "
        f"Sharpe={performance['strategy_sharpe']:.2f}, "
        f"MDD={performance['strategy_mdd']:.2f}%<br>"
        f"Buy & Hold: Return={performance['buy_hold_return']:.2f}%, "
        f"Sharpe={performance['buy_hold_sharpe']:.2f}, "
        f"MDD={performance['buy_hold_mdd']:.2f}%"
    )

    fig.add_annotation(
        x=0.5,
        y=1.05,
        xref="paper",
        yref="paper",
        text=metrics_text,
        showarrow=False,
        font=dict(size=12),
        align="center",
        bgcolor="rgba(255, 255, 255, 0.8)",
        bordercolor="black",
        borderwidth=1,
        borderpad=4
    )

    range_buttons = [
        dict(count=1, label="1m", step="month", stepmode="backward"),
        dict(count=3, label="3m", step="month", stepmode="backward"),
        dict(count=6, label="6m", step="month", stepmode="backward"),
        dict(step="all")
    ]
    # The buttons stay on the top axis. The slider goes on the bottom (shared)
    # axis so it renders below the drawdown panel instead of on top of it.
    fig.update_xaxes(rangeselector=dict(buttons=range_buttons), row=1, col=1)
    fig.update_xaxes(rangeslider=dict(visible=True), row=2, col=1)
    fig.update_xaxes(type="date")

    return fig

def _comparison_axis_style(metric_name):
    """Return ``(number_format, suffix, axis_title)`` used to display *metric_name*.

    ``number_format`` is valid both as a d3-format spec (Plotly ``tickformat``)
    and as a Python format spec. Percent metrics are already in percent units,
    so a literal ``%`` suffix is used instead of d3's ``%`` type (which would
    multiply by 100).
    """
    name = metric_name.lower()
    if 'return' in name:
        return '.2f', '%', 'Return (%)'
    if 'accuracy' in name:
        return '.2f', '%', 'Accuracy (%)'
    if 'sharpe' in name:
        return '.2f', '', 'Sharpe Ratio'
    return '.6f', '', 'Value'


def create_interactive_comparison_plot(tickers, metrics, metric_name='directional_accuracy'):
    """
    Create an interactive bar chart comparing one metric across stocks.

    Bars above the median value are green, the rest red (higher is treated as
    better). Bar labels and y-axis ticks are formatted per metric type.

    Args:
        tickers: Ticker symbols, in bar order.
        metrics: Mapping ticker -> dict of metric values; every ticker must
            contain ``metric_name``.
        metric_name: Key of the metric to compare.

    Returns:
        Plotly figure.

    Raises:
        KeyError: If a ticker or its ``metric_name`` entry is missing.
    """
    values = [float(metrics[ticker][metric_name]) for ticker in tickers]

    # Median computed once rather than per bar
    median_value = np.median(values)
    colors = ['lightgreen' if v > median_value else 'lightcoral' for v in values]

    number_format, suffix, y_title = _comparison_axis_style(metric_name)

    fig = go.Figure()
    fig.add_trace(
        go.Bar(
            x=tickers,
            y=values,
            marker_color=colors,
            text=[f"{v:{number_format}}{suffix}" for v in values],
            textposition='auto',
        )
    )

    fig.update_layout(
        title=f"Comparison of {metric_name.replace('_', ' ').title()} Across Stocks",
        xaxis_title="Stock",
        yaxis_title=y_title,
        template='plotly_white',
        height=500,
        yaxis=dict(tickformat=number_format, ticksuffix=suffix),
    )

    return fig

def _simulate_threshold_strategy(pred_returns, actual_returns, threshold, transaction_cost, initial_value):
    """Replay a long/flat/short threshold strategy alongside buy & hold.

    Returns ``(portfolio_values, buy_hold_values)``, each starting at
    ``initial_value`` and holding one extra entry per step.
    """
    portfolio_value = initial_value
    buy_hold_value = initial_value
    portfolio_values = [portfolio_value]
    buy_hold_values = [buy_hold_value]
    position = 0  # 0: no position, 1: long, -1: short

    for pred_return, current_return in zip(pred_returns, actual_returns):
        buy_hold_value *= (1 + current_return)
        buy_hold_values.append(buy_hold_value)

        if pred_return > threshold:
            new_position = 1
        elif pred_return < -threshold:
            new_position = -1
        else:
            new_position = 0

        # Proportional cost whenever the position changes
        if new_position != position:
            portfolio_value *= (1 - transaction_cost)

        # NOTE(review): the return at step i is credited to the position chosen
        # at step i-1 (from pred[i-1]). If pred[i] already forecasts actual[i],
        # this lags the signal by one step -- confirm the intended timing
        # against prepare_stock_data.
        if position == 1:
            portfolio_value *= (1 + current_return)
        elif position == -1:
            portfolio_value *= (1 - current_return)

        position = new_position
        portfolio_values.append(portfolio_value)

    return portfolio_values, buy_hold_values


def _annualized_sharpe(values, periods_per_year=252):
    """Annualized Sharpe ratio (risk-free rate 0) of a value series.

    Returns 0.0 when there are fewer than two values or the per-step returns
    have zero/non-finite volatility (e.g. a strategy that never trades), where
    the ratio would otherwise be 0/0 = NaN.
    """
    values = np.asarray(values, dtype=np.float64)
    if values.size < 2:
        return 0.0
    step_returns = np.diff(values) / values[:-1]
    volatility = np.std(step_returns)
    if volatility == 0 or not np.isfinite(volatility):
        return 0.0
    return float(np.mean(step_returns) / volatility * np.sqrt(periods_per_year))


def _max_drawdown_pct(values):
    """Largest peak-to-trough decline of *values*, in percent (0.0 if empty)."""
    values = np.asarray(values, dtype=np.float64)
    if values.size == 0:
        return 0.0
    peaks = np.maximum.accumulate(values)
    return float(np.max((peaks - values) / peaks) * 100)


def perform_trading_simulation(model, X_test, y_test, stock_data, test_start_idx, ticker,
                               *, threshold=0.0005, transaction_cost=0.0001, initial_value=10000):
    """
    Perform a trading simulation based on model predictions.

    The first-horizon forecast (column 0) is the trading signal: go long when the
    predicted return exceeds ``threshold``, go short when it is below
    ``-threshold``, and stay flat otherwise. A proportional ``transaction_cost``
    is charged on every position change. The strategy is compared against buy &
    hold, which compounds the same actual returns.

    Args:
        model: Trained TimeSeriesKAN model (called on ``X_test``).
        X_test: Test features
        y_test: Test targets; column 0 is compounded as a simple return
            (assumes simple, not log/normalized, returns -- TODO confirm).
        stock_data: Original stock data (its index supplies plot dates).
        test_start_idx: Starting index of test data in original time series
        ticker: Stock ticker symbol
        threshold: Minimum |predicted return| needed to hold a position
            (default 5 basis points).
        transaction_cost: Fraction of value paid per position change
            (default 1 basis point).
        initial_value: Starting value of both portfolios.

    Returns:
        ``(fig, performance)``. ``performance`` holds total returns (%),
        annualized Sharpe ratios (0.0 when volatility is zero) and maximum
        drawdowns (%) for the strategy and buy & hold.
    """
    # Compound in float64 to avoid float32 drift over many steps
    predictions_np = np.asarray(model(X_test), dtype=np.float64)
    y_test_np = np.asarray(y_test, dtype=np.float64)

    horizon_0_preds = predictions_np[:, 0]
    horizon_0_actual = y_test_np[:, 0]

    portfolio_values, buy_hold_values = _simulate_threshold_strategy(
        horizon_0_preds, horizon_0_actual, threshold, transaction_cost, initial_value)

    performance = {
        'strategy_return': (portfolio_values[-1] / initial_value - 1) * 100,
        'buy_hold_return': (buy_hold_values[-1] / initial_value - 1) * 100,
        'strategy_sharpe': _annualized_sharpe(portfolio_values),
        'buy_hold_sharpe': _annualized_sharpe(buy_hold_values),
        'strategy_mdd': _max_drawdown_pct(portfolio_values),
        'buy_hold_mdd': _max_drawdown_pct(buy_hold_values)
    }

    fig = create_interactive_trading_simulation_plot(ticker, stock_data, portfolio_values, buy_hold_values,
                                                     test_start_idx, performance)

    return fig, performance


def _kde_scaled(samples, x_grid, scale):
    """Evaluate a Gaussian KDE of *samples* on *x_grid*, multiplied by *scale*.

    With ``scale = len(samples) * bin_width`` the curve is in expected counts per
    bin, so it overlays a count histogram of that bin width.
    """
    return gaussian_kde(samples)(x_grid) * scale


def create_interactive_returns_distribution_plot(ticker, y_test, predictions):
    """
    Create an interactive visualization comparing the distribution of actual vs predicted returns.

    Both histograms share one set of bin edges so their bars are directly
    comparable. Each KDE curve is converted to expected counts per bin, rather
    than stretched to a common peak, so relative densities stay honest. Dashed
    vertical lines mark the means, and an annotation reports mean and std.

    Args:
        ticker: Stock ticker symbol
        y_test: Actual test returns (numpy array or JAX array); column 0 is used
        predictions: Predicted returns (numpy array or JAX array); column 0 is used

    Returns:
        Plotly figure object
    """
    y_test_np = np.array(y_test)
    predictions_np = np.array(predictions)

    # First forecast horizon only
    actual_returns = y_test_np[:, 0]
    predicted_returns = predictions_np[:, 0]

    # Shared bins across both series so bar widths (and KDE scaling) agree
    bin_edges = np.histogram_bin_edges(np.concatenate([actual_returns, predicted_returns]), bins=50)
    bin_width = float(bin_edges[1] - bin_edges[0])
    shared_xbins = dict(start=float(bin_edges[0]), size=bin_width)

    fig = go.Figure()

    fig.add_trace(go.Histogram(
        x=actual_returns,
        name='Actual Returns',
        marker_color='blue',
        opacity=0.7,
        xbins=shared_xbins
    ))

    fig.add_trace(go.Histogram(
        x=predicted_returns,
        name='Predicted Returns',
        marker_color='red',
        opacity=0.7,
        xbins=shared_xbins
    ))

    # Tallest bar, used to size the mean marker lines
    hist_actual, _ = np.histogram(actual_returns, bins=bin_edges)
    hist_pred, _ = np.histogram(predicted_returns, bins=bin_edges)
    max_hist_height = max(hist_actual.max(), hist_pred.max())

    # KDE over the combined range. Each curve is attempted independently so a
    # degenerate series (e.g. constant predictions -> singular KDE) does not
    # suppress the other.
    x_range = np.linspace(bin_edges[0], bin_edges[-1], 1000)
    kde_specs = [
        (actual_returns, 'Actual Returns Density', 'darkblue'),
        (predicted_returns, 'Predicted Returns Density', 'darkred'),
    ]
    for samples, trace_name, color in kde_specs:
        try:
            y_kde = _kde_scaled(samples, x_range, len(samples) * bin_width)
        except Exception as e:
            print(f"Warning: Could not create KDE plots: {e}")
            continue
        fig.add_trace(go.Scatter(
            x=x_range,
            y=y_kde,
            mode='lines',
            name=trace_name,
            line=dict(color=color, width=2)
        ))

    actual_mean = np.mean(actual_returns)
    actual_std = np.std(actual_returns)
    pred_mean = np.mean(predicted_returns)
    pred_std = np.std(predicted_returns)

    fig.add_trace(go.Scatter(
        x=[actual_mean, actual_mean],
        y=[0, max_hist_height],
        mode='lines',
        name='Actual Mean',
        line=dict(color='blue', width=2, dash='dash')
    ))

    fig.add_trace(go.Scatter(
        x=[pred_mean, pred_mean],
        y=[0, max_hist_height],
        mode='lines',
        name='Predicted Mean',
        line=dict(color='red', width=2, dash='dash')
    ))

    fig.update_layout(
        title=f'{ticker} - Distribution of Returns',
        xaxis_title='Return',
        yaxis_title='Frequency',
        barmode='overlay',
        hovermode='x unified',
        template='plotly_white',
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="right",
            x=0.99
        )
    )

    stats_text = (
        f"Actual: Mean={actual_mean:.4f}, Std={actual_std:.4f}<br>"
        f"Predicted: Mean={pred_mean:.4f}, Std={pred_std:.4f}"
    )

    fig.add_annotation(
        x=0.01,
        y=0.97,
        xref="paper",
        yref="paper",
        text=stats_text,
        showarrow=False,
        font=dict(size=12),
        align="left",
        bgcolor="rgba(255, 255, 255, 0.8)",
        bordercolor="black",
        borderwidth=1,
        borderpad=4
    )

    return fig


###########################################
# Main Function to Run Stock Forecasting
###########################################

def forecast_top_sp500_stocks(num_stocks=3, lookback_window=20, forecast_horizon=5, num_epochs=50):
    """
    Main function to forecast stock prices for top S&P 500 stocks.

    For each ticker: prepare features, train a TimeSeriesKAN, evaluate the
    forecasts, build the interactive figures and run a trading simulation. A
    ticker that raises at any step is reported and skipped. When at least two
    tickers complete, cross-stock comparison figures are added.

    Args:
        num_stocks: Number of top stocks to analyze
        lookback_window: Number of past time steps to use as features
        forecast_horizon: Number of future time steps to predict
        num_epochs: Number of training epochs

    Returns:
        Dictionary keyed by ticker for every stock that completed (model,
        metrics, performance, predictions, figures), plus a ``'comparison'``
        entry with cross-stock figures when two or more stocks completed.
    """
    # Get top S&P 500 stocks (using our hardcoded list of top companies)
    top_tickers = fetch_sp500_top_stocks(num_stocks)
    top_tickers = top_tickers[:num_stocks]

    # Download two years of daily data
    end_date = datetime.now()
    start_date = end_date - timedelta(days=2*365)
    stock_data = download_stock_data(top_tickers, start_date=start_date.strftime('%Y-%m-%d'))

    results = {}
    all_metrics = {}

    for ticker in top_tickers:
        try:
            print(f"\nProcessing {ticker}...")

            ticker_data = stock_data[ticker]

            X_train, y_train, X_test, y_test, feature_cols = prepare_stock_data(
                ticker_data, lookback_window, forecast_horizon)

            # Position of the first test sample in the original series
            test_start_idx = len(ticker_data) - len(X_test) - forecast_horizon

            if len(X_train) < 100 or len(X_test) < 20:
                print(f"Not enough data for {ticker}, skipping...")
                continue

            input_dim = X_train.shape[1]  # Number of features * lookback window
            output_dim = y_train.shape[1]  # Forecast horizon

            model_key = jax.random.PRNGKey(42)
            model = TimeSeriesKAN(input_dim, output_dim, hidden_dims=[64, 32], key=model_key)

            print(f"Training model for {ticker} with {num_epochs} epochs...")
            model, losses = train_model(model, X_train, y_train, num_epochs=num_epochs, batch_size=32)

            fig_loss = go.Figure()
            fig_loss.add_trace(go.Scatter(
                y=losses,
                mode='lines',
                name='Training Loss',
                line=dict(color='blue', width=2)
            ))
            fig_loss.update_layout(
                title=f"{ticker} - Training Loss",
                xaxis_title="Epoch",
                yaxis_title="Loss",
                template='plotly_white'
            )

            print(f"Evaluating forecasts for {ticker}...")
            metrics, predictions = evaluate_forecasts(model, X_test, y_test)

            print(f"Overall MSE: {metrics['overall_mse']:.6f}")
            print(f"Overall Directional Accuracy: {metrics['overall_directional_accuracy']:.2f}%")

            # Collected locally and only published to all_metrics once every step
            # has succeeded, so comparison plots never see a ticker that is
            # missing the trading metrics (previously a KeyError after training).
            ticker_metrics = {
                'mse': metrics['overall_mse'],
                'mae': metrics['overall_mae'],
                'directional_accuracy': metrics['overall_directional_accuracy']
            }

            print(f"Creating interactive visualizations for {ticker}...")

            forecast_fig = create_interactive_forecast_plot(ticker, ticker_data, X_test, y_test, predictions, test_start_idx)

            features_fig = create_interactive_feature_plot(ticker, ticker_data, test_start_idx,
                                                           features=['Close', 'Volatility', 'RSI', 'MA_ratio_50_200'])

            returns_dist_fig = create_interactive_returns_distribution_plot(ticker, y_test, predictions)

            importance_fig, activations_fig = create_interactive_kan_activations_plot(model, ticker, feature_cols, lookback_window)

            print(f"Simulating trading strategy for {ticker}...")
            trading_fig, performance = perform_trading_simulation(model, X_test, y_test, ticker_data, test_start_idx, ticker)

            ticker_metrics.update({
                'strategy_return': performance['strategy_return'],
                'buy_hold_return': performance['buy_hold_return'],
                'strategy_sharpe': performance['strategy_sharpe']
            })

            results[ticker] = {
                'model': model,
                'metrics': metrics,
                'performance': performance,
                'predictions': predictions,
                'figures': {
                    'loss': fig_loss,
                    'forecast': forecast_fig,
                    'features': features_fig,
                    'returns_dist': returns_dist_fig,
                    'importance': importance_fig,
                    'activations': activations_fig,
                    'trading': trading_fig
                }
            }
            all_metrics[ticker] = ticker_metrics

            print(f"Completed processing for {ticker}")

        except Exception as e:
            # Top-level per-ticker boundary: report and move on to the next stock
            print(f"Error processing {ticker}: {e}")

    # Comparative plots over the tickers that completed every step
    if len(all_metrics) > 1:
        print("\nCreating comparison plots across stocks...")
        compared_tickers = list(all_metrics)

        results['comparison'] = {
            'accuracy': create_interactive_comparison_plot(compared_tickers, all_metrics, 'directional_accuracy'),
            'returns': create_interactive_comparison_plot(compared_tickers, all_metrics, 'strategy_return'),
            'sharpe': create_interactive_comparison_plot(compared_tickers, all_metrics, 'strategy_sharpe')
        }

    print("\nForecasting complete!")
    return results


def run_demo():
    """
    Run a lightweight demo on downloaded AAPL data (no KAN training).

    Builds the figures from ``create_sample_interactive_plots``. Under IPython
    they are displayed inline. Outside IPython, or if inline display fails, each
    figure is saved as ``demo_<name>.html``, falling back to ``demo_<name>.png``.

    Returns:
        Dict mapping plot name to Plotly figure.
    """
    print("Running demo with sample data for AAPL...")

    import plotly.io as pio

    # Inline notebook rendering under IPython, static PNG renderer otherwise
    try:
        from IPython import get_ipython
        in_ipython = get_ipython() is not None
    except ImportError:
        in_ipython = False
    pio.renderers.default = 'notebook' if in_ipython else 'png'

    ticker = "AAPL"
    sample_figures = create_sample_interactive_plots(ticker=ticker)

    for name, fig in sample_figures.items():
        print(f"\nDisplaying {name} plot...")
        if in_ipython:
            try:
                from IPython.display import display
                display(fig)
                continue
            except Exception as e:
                print(f"Error displaying figure: {e}")

        # Fall back to saving: interactive HTML first, then a static image
        try:
            fig.write_html(f"demo_{name}.html")
            print(f"Saved demo_{name}.html")
        except Exception:
            try:
                fig.write_image(f"demo_{name}.png")
                print(f"Saved demo_{name}.png")
            except Exception:
                print(f"Could not save figure for {name}")

    print("\nDemo complete! To run the full analysis with KAN model training,")
    print("call forecast_top_sp500_stocks() function.")

    return sample_figures
    

def create_sample_interactive_plots(ticker="AAPL", start_date="2022-01-01"):
    """
    Download sample stock data and create interactive plots.
    """
    import numpy as np
    import pandas as pd
    import yfinance as yf
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots
    from datetime import datetime

    end_date = datetime.now().strftime('%Y-%m-%d')
    
    # Download stock data
    print(f"Downloading stock data for {ticker} from {start_date} to {end_date}...")
    stock_data = yf.download(ticker, start=start_date, end=end_date, auto_adjust=True)
    
    # Calculate features
    stock_data['Returns'] = stock_data['Close'].pct_change()
    stock_data['LogReturns'] = np.log(stock_data['Close'] / stock_data['Close'].shift(1))
    stock_data['Volatility'] = stock_data['Returns'].rolling(window=20).std()
    
    # Calculate RSI
    delta = stock_data['Close'].diff()
    gain = delta.where(delta > 0, 0).fillna(0)
    loss = -delta.where(delta < 0, 0).fillna(0)
    
    avg_gain = gain.rolling(window=14).mean()
    avg_loss = loss.rolling(window=14).mean()
    
    rs = avg_gain / avg_loss
    stock_data['RSI'] = 100 - (100 / (1 + rs))
    
    # Calculate moving averages
    stock_data['MA_50d'] = stock_data['Close'].rolling(window=50).mean()
    stock_data['MA_200d'] = stock_data['Close'].rolling(window=200).mean()
    
    # Drop NaN values
    stock_data = stock_data.dropna()
    
    # Create a list of plots
    figures = {}
    
    # 1. Stock price with moving averages
    fig1 = go.Figure()
    
    fig1.add_trace(go.Scatter(
        x=stock_data.index,
        y=stock_data['Close'],
        mode='lines',
        name='Price',
        line=dict(color='blue')
    ))
    
    fig1.add_trace(go.Scatter(
        x=stock_data.index,
        y=stock_data['MA_50d'],
        mode='lines',
        name='50-day MA',
        line=dict(color='orange')
    ))
    
    fig1.add_trace(go.Scatter(
        x=stock_data.index,
        y=stock_data['MA_200d'],
        mode='lines',
        name='200-day MA',
        line=dict(color='green')
    ))
    
    fig1.update_layout(
        title=f'{ticker} Stock Price with Moving Averages',
        xaxis_title='Date',
        yaxis_title='Price ($)',
        hovermode='x unified',
        template='plotly_white',
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.01
        )
    )
    
    # Add range slider
    fig1.update_layout(
        xaxis=dict(
            rangeselector=dict(
                buttons=list([
                    dict(count=1, label="1m", step="month", stepmode="backward"),
                    dict(count=3, label="3m", step="month", stepmode="backward"),
                    dict(count=6, label="6m", step="month", stepmode="backward"),
                    dict(step="all")
                ])
            ),
            rangeslider=dict(visible=True),
            type="date"
        )
    )
    
    figures['price_ma'] = fig1
    
    # 2. Multiple features plot
    features_to_plot = ['Close', 'Volume', 'Volatility', 'RSI']
    
    # Create subplots: one for each feature
    num_features = len(features_to_plot)
    fig2 = make_subplots(rows=num_features, cols=1, shared_xaxes=True,
                        subplot_titles=[f"{ticker} - {feat}" for feat in features_to_plot],
                        vertical_spacing=0.05)
    
    # Set forecast starting point for visualization
    forecast_start_idx = len(stock_data) - 30
    
    # Color map for features
    colors = ['blue', 'orange', 'green', 'red']
    
    # Add traces for each feature
    for i, feature in enumerate(features_to_plot):
        if feature in stock_data.columns:
            fig2.add_trace(
                go.Scatter(
                    x=stock_data.index,
                    y=stock_data[feature],
                    mode='lines',
                    name=feature,
                    line=dict(color=colors[i % len(colors)])
                ),
                row=i+1, col=1
            )
    
    # Add a vertical line at the forecast starting point
    forecast_date = stock_data.index[forecast_start_idx]
    
    for i in range(num_features):
        # Get y range for this subplot
        y_values = np.array(fig2.data[i]['y'])
        y_min = float(np.min(y_values))
        y_max = float(np.max(y_values))
        
        # Add some padding
        y_range = y_max - y_min
        y_min = y_min - 0.05 * y_range
        y_max = y_max + 0.05 * y_range
        
        fig2.add_trace(
            go.Scatter(
                x=[forecast_date, forecast_date],
                y=[y_min, y_max],
                mode='lines',
                line=dict(color='gray', width=1, dash='dash'),
                showlegend=False
            ),
            row=i+1, col=1
        )
    
    # Update layout
    fig2.update_layout(
        height=250 * num_features,
        width=1000,
        title_text=f"{ticker} - Key Features Analysis",
        showlegend=False,
        hovermode='x unified',
        template='plotly_white'
    )
    
    # Add range slider to the bottom subplot only
    fig2.update_layout(
        xaxis4=dict(
            rangeselector=dict(
                buttons=list([
                    dict(count=1, label="1m", step="month", stepmode="backward"),
                    dict(count=3, label="3m", step="month", stepmode="backward"),
                    dict(count=6, label="6m", step="month", stepmode="backward"),
                    dict(step="all")
                ])
            ),
            rangeslider=dict(visible=True),
            type="date"
        )
    )
    
    figures['features'] = fig2
    
    # 3. Simplified forecast plot (to avoid the complex calculations)
    
    # Set the start point for our forecast
    forecast_idx = len(stock_data) - 30
    horizon = 5
    
    # Get a slice of actual data for demonstration
    historical_slice = stock_data.iloc[forecast_idx-20:forecast_idx]
    future_slice = stock_data.iloc[forecast_idx:forecast_idx+horizon]
    
    # Create a very simple forecast - just add a small random deviation to actual prices
    np.random.seed(42)
    future_prices = future_slice['Close'].values
    forecast_prices = future_prices * (1 + np.random.normal(0, 0.01, size=len(future_prices)))
    
    # Create a simple forecast plot
    fig3 = go.Figure()
    
    # Add historical price
    fig3.add_trace(go.Scatter(
        x=historical_slice.index,
        y=historical_slice['Close'],
        mode='lines',
        name='Historical Price',
        line=dict(color='blue')
    ))
    
    # Add actual future price
    fig3.add_trace(go.Scatter(
        x=future_slice.index,
        y=future_slice['Close'],
        mode='lines',
        name='Actual Future Price',
        line=dict(color='green')
    ))
    
    # Add forecasted price
    fig3.add_trace(go.Scatter(
        x=future_slice.index,
        y=forecast_prices,
        mode='lines',
        name='Forecasted Price',
        line=dict(color='red', dash='dash')
    ))
    
    # FIXED: Convert the data to numpy arrays before numeric operations
    forecast_date = future_slice.index[0]
    
    # Convert to numpy arrays first - FIXED
    historical_prices = np.array(historical_slice['Close'].values)
    future_prices = np.array(future_slice['Close'].values)
    
    # Combine arrays using numpy concatenate
    combined_prices = np.concatenate([historical_prices, future_prices])
    
    # Now find min and max
    y_min = float(np.min(combined_prices)) * 0.99  # 1% margin below
    y_max = float(np.max(combined_prices)) * 1.01  # 1% margin above
    
    # Add vertical line as a scatter trace
    fig3.add_trace(go.Scatter(
        x=[forecast_date, forecast_date],
        y=[y_min, y_max],
        mode='lines',
        line=dict(color='gray', width=2, dash='dash'),
        showlegend=False
    ))
    
    # Add annotation for forecast start
    fig3.add_annotation(
        x=forecast_date,
        y=y_max,
        text="Forecast Start",
        showarrow=False,
        yshift=10
    )
    
    # Update layout
    fig3.update_layout(
        title=f'{ticker} Price Forecast (Demo)',
        xaxis_title='Date',
        yaxis_title='Price ($)',
        hovermode='x unified',
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.01
        ),
        template='plotly_white'
    )
    
    # Add range slider
    fig3.update_layout(
        xaxis=dict(
            rangeselector=dict(
                buttons=list([
                    dict(count=1, label="1m", step="month", stepmode="backward"),
                    dict(count=3, label="3m", step="month", stepmode="backward"),
                    dict(count=6, label="6m", step="month", stepmode="backward"),
                    dict(step="all")
                ])
            ),
            rangeslider=dict(visible=True),
            type="date"
        )
    )
    
    figures['forecast'] = fig3
    
    # 4. Returns distribution
    # Create a simple distribution plot for returns
    
    # Get recent returns 
    recent_returns = stock_data['Returns'].iloc[-100:].values
    
    # Create simulated predicted returns with slightly different distribution
    np.random.seed(42)
    predicted_returns = recent_returns + np.random.normal(0, 0.001, size=len(recent_returns))
    
    # Create histograms
    fig4 = go.Figure()
    
    fig4.add_trace(go.Histogram(
        x=recent_returns,
        name="Actual Returns",
        marker_color='blue',
        opacity=0.7,
        nbinsx=50
    ))
    
    fig4.add_trace(go.Histogram(
        x=predicted_returns,
        name="Predicted Returns",
        marker_color='red',
        opacity=0.7,
        nbinsx=50
    ))
    
    # Update layout
    fig4.update_layout(
        title_text=f"{ticker} - Distribution of Returns",
        xaxis_title_text="Return",
        yaxis_title_text="Count",
        bargap=0.1,
        hovermode='x unified',
        template='plotly_white',
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="right",
            x=0.99
        )
    )
    
    figures['returns_dist'] = fig4
    
    # 5. Trading simulation
    # Create a simplified trading simulation
    
    # Simple simulated trading performance
    initial_value = 10000
    portfolio_values = [initial_value]
    buy_hold_values = [initial_value]
    
    # Get some returns for simulation
    simulation_returns = stock_data['Returns'].iloc[-100:].values
    
    # Simulate a more realistic trading strategy
    for i in range(len(simulation_returns)):
        # Portfolio with strategy (slightly better than buy & hold for demonstration)
        portfolio_values.append(portfolio_values[-1] * (1 + simulation_returns[i] * 1.2))
        
        # Buy & hold
        buy_hold_values.append(buy_hold_values[-1] * (1 + simulation_returns[i]))
    
    # Create dates for the portfolio values
    portfolio_dates = stock_data.index[-100-1:]  # Extra day for initial value
    
    # Calculate drawdowns
    def calculate_drawdowns(values):
        peaks = np.maximum.accumulate(values)
        drawdowns = (peaks - values) / peaks * 100
        return drawdowns
    
    strategy_drawdowns = calculate_drawdowns(portfolio_values)
    bh_drawdowns = calculate_drawdowns(buy_hold_values)
    
    # Performance metrics
    strategy_return = (portfolio_values[-1] / initial_value - 1) * 100
    buy_hold_return = (buy_hold_values[-1] / initial_value - 1) * 100
    
    # Create drawdown subplot
    fig5 = make_subplots(rows=2, cols=1, shared_xaxes=True, 
                        vertical_spacing=0.05,
                        subplot_titles=("Portfolio Value ($)", "Drawdown (%)"),
                        row_heights=[0.7, 0.3])
    
    # Add traces to first subplot (portfolio values)
    fig5.add_trace(go.Scatter(
        x=portfolio_dates, 
        y=portfolio_values,
        mode='lines',
        name='KAN Trading Strategy',
        line=dict(color='blue', width=2)
    ), row=1, col=1)
    
    fig5.add_trace(go.Scatter(
        x=portfolio_dates, 
        y=buy_hold_values,
        mode='lines',
        name='Buy & Hold Strategy',
        line=dict(color='red', width=2, dash='dash')
    ), row=1, col=1)
    
    # Add traces to second subplot (drawdowns)
    fig5.add_trace(go.Scatter(
        x=portfolio_dates, 
        y=strategy_drawdowns,
        mode='lines',
        name='KAN Strategy Drawdown',
        line=dict(color='blue', width=1)
    ), row=2, col=1)
    
    fig5.add_trace(go.Scatter(
        x=portfolio_dates, 
        y=bh_drawdowns,
        mode='lines',
        name='Buy & Hold Drawdown',
        line=dict(color='red', width=1, dash='dash')
    ), row=2, col=1)
    
    # Update layout
    fig5.update_layout(
        title_text=f"{ticker} - Trading Simulation Results (Demo)",
        hovermode='x unified',
        template='plotly_white',
        height=700,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )
    
    # Add annotations for key metrics
    metrics_text = (
        f"Strategy: Return={strategy_return:.2f}%, "
        f"Buy & Hold: Return={buy_hold_return:.2f}%"
    )
    
    fig5.add_annotation(
        x=0.5,
        y=1.05,
        xref="paper",
        yref="paper",
        text=metrics_text,
        showarrow=False,
        font=dict(size=12),
        align="center",
        bgcolor="rgba(255, 255, 255, 0.8)",
        bordercolor="black",
        borderwidth=1,
        borderpad=4
    )
    
    # Add range slider
    fig5.update_layout(
        xaxis=dict(
            rangeselector=dict(
                buttons=list([
                    dict(count=1, label="1m", step="month", stepmode="backward"),
                    dict(count=3, label="3m", step="month", stepmode="backward"),
                    dict(count=6, label="6m", step="month", stepmode="backward"),
                    dict(step="all")
                ])
            ),
            rangeslider=dict(visible=True),
            type="date"
        )
    )
    
    figures['trading'] = fig5
    
    return figures


def _show_or_save_figure(fig, base_name, description, is_notebook):
    """Display a plotly figure inline when possible, otherwise save it to disk.

    Args:
        fig: Plotly figure to show or save.
        base_name: File stem for saved output (``<base_name>.html`` or the
            fallback ``<base_name>.png``).
        description: Human-readable label used in the error message.
        is_notebook: True when an IPython shell is active, so inline display
            via ``IPython.display.display`` is possible.

    Failures are reported with ``print`` and never raised. If inline display or
    the HTML write fails, a PNG export is attempted as a fallback. That export
    relies on plotly's static-image backend (e.g. kaleido), so it may fail too;
    if it does, its error is reported.
    """
    try:
        if is_notebook:
            # Explicit import instead of relying on IPython's injected builtin.
            from IPython.display import display
            display(fig)
        else:
            # No inline display available: persist the figure instead.
            fig.write_html(f"{base_name}.html")
            print(f"Saved {base_name}.html")
    except Exception as e:
        print(f"Error displaying {description}: {e}")
        # Best-effort fallback: a static image instead of the interactive view.
        try:
            fig.write_image(f"{base_name}.png")
            print(f"Saved {base_name}.png instead")
        except Exception as img_err:
            print(f"Could not save image for {base_name}: {img_err}")


def run_full_analysis(num_stocks=2, num_epochs=20):
    """Run the full stock forecasting analysis and show or save every figure.

    Args:
        num_stocks: Number of top S&P 500 stocks to forecast.
        num_epochs: Number of training epochs passed to
            ``forecast_top_sp500_stocks``.

    Returns:
        The mapping returned by ``forecast_top_sp500_stocks``. Every key other
        than ``'comparison'`` is treated as a ticker whose value may carry a
        ``'figures'`` dict of name -> figure. ``'comparison'``, if present, maps
        names directly to figures.

    Side effects:
        Sets the global plotly default renderer. Inside an IPython session
        (``get_ipython()`` returns a shell), figures are displayed inline.
        Otherwise they are written to ``<ticker>_<name>.html`` /
        ``comparison_<name>.html`` in the working directory.
    """
    import sys
    import plotly.io as pio

    print(f"Running full analysis for top {num_stocks} S&P 500 stocks...")

    # Plotly ships a dedicated 'colab' renderer for Google Colab. The classic
    # 'notebook' renderer does not display there, so forcing 'notebook'
    # unconditionally would blank every figure on Colab.
    pio.renderers.default = 'colab' if 'google.colab' in sys.modules else 'notebook'

    results = forecast_top_sp500_stocks(
        num_stocks=num_stocks,
        lookback_window=20,
        forecast_horizon=5,
        num_epochs=num_epochs
    )

    # Detect an interactive IPython shell; without IPython we run as a script.
    try:
        from IPython import get_ipython
        is_notebook = get_ipython() is not None
    except ImportError:
        is_notebook = False

    # Per-ticker figures.
    for ticker, ticker_results in results.items():
        if ticker == 'comparison':
            continue
        print(f"\nDisplaying plots for {ticker}...")
        for name, fig in ticker_results.get('figures', {}).items():
            _show_or_save_figure(
                fig,
                base_name=f"{ticker}_{name}",
                description=f"figure for {ticker}_{name}",
                is_notebook=is_notebook,
            )

    # Cross-ticker comparison figures, if the forecaster produced any.
    if 'comparison' in results:
        print("\nDisplaying comparison plots...")
        for name, fig in results['comparison'].items():
            _show_or_save_figure(
                fig,
                base_name=f"comparison_{name}",
                description=f"comparison figure for {name}",
                is_notebook=is_notebook,
            )

    print("\nAnalysis complete!")
    return results

if __name__ == "__main__":
    # Script entry point -- keep exactly one of the calls below active.
    #
    # Quick demo, no model training:
    #     run_demo()
    #
    # Full analysis, top 2 stocks x 20 epochs (moderate runtime):
    #     run_full_analysis(num_stocks=2, num_epochs=20)
    #
    # Full analysis, top 5 stocks x 50 epochs (longest runtime) -- currently active:
    run_full_analysis(num_stocks=5, num_epochs=50)