In [None]:
!pip install optax yfinance lxml plotly hmmlearn "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [None]:
import jax
import jax.numpy as jnp
from jax import grad, vmap, jit
import numpy as np
import optax
from functools import partial
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Callable, Any
import pandas as pd
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from datetime import datetime, timedelta
import yfinance as yf
from hmmlearn.hmm import GaussianHMM
from sklearn.metrics import silhouette_score
import seaborn as sns
from sklearn.decomposition import PCA

# Set random seed for reproducibility
# NOTE(review): this only creates a JAX PRNG key. NumPy's global RNG (used by
# np.random.permutation during training) is not seeded here, and the model
# constructors below fall back to PRNGKey(0) unless a key is passed to them
# explicitly -- confirm whether `key` is meant to be threaded through.
key = jax.random.PRNGKey(42)

# KAN Layer implementation for market regime detection
class KANLayer:
    """Affine map followed by a learnable piecewise-linear activation per output unit.

    Each output unit owns a table of ``num_basis`` activation values sampled on a
    shared, evenly spaced grid over ``domain``; the forward pass linearly
    interpolates in that table.

    Attributes set by ``__init__``:
        weights:      (input_dim, output_dim) linear weights.
        biases:       (output_dim,) linear biases.
        grid_points:  (num_basis,) interpolation grid over ``domain``.
        activations:  (output_dim, num_basis) activation values on the grid.
        domain:       (low, high) range pre-activations are clipped to.
    """

    def __init__(self, input_dim: int, output_dim: int, num_basis: int = 30,
                 domain=(-3.0, 3.0), key=None):
        """Initialize a KAN layer with learnable activation functions.

        Args:
            input_dim: Number of input features.
            output_dim: Number of output units.
            num_basis: Number of grid points per activation table.
            domain: (low, high) range covered by the grid.
            key: JAX PRNG key; ``PRNGKey(0)`` is used when omitted.
        """
        if key is None:
            key = jax.random.PRNGKey(0)

        key1, key2, key3 = jax.random.split(key, 3)

        # Linear part of the layer.
        self.weights = jax.random.normal(key1, (input_dim, output_dim)) * 0.1
        self.biases = jax.random.normal(key2, (output_dim,)) * 0.01

        # Grid points for activation function representation
        self.grid_points = jnp.linspace(domain[0], domain[1], num_basis)

        # Give every output unit one of four starting shapes plus a little noise.
        activations_list = []
        for i in range(output_dim):
            # A PRNG key must only be consumed once: derive independent keys
            # for the shape choice and for the noise instead of reusing one.
            type_key, noise_key = jax.random.split(jax.random.fold_in(key3, i))
            init_type = int(jax.random.randint(type_key, (), 0, 4))

            if init_type == 0:  # Linear-like
                act = self.grid_points
            elif init_type == 1:  # ReLU-like
                act = jnp.maximum(0, self.grid_points)
            elif init_type == 2:  # Sigmoid-like
                act = 1.0 / (1.0 + jnp.exp(-self.grid_points))
            else:  # Tanh-like
                act = jnp.tanh(self.grid_points)

            # Add noise to break symmetry
            act = act + jax.random.normal(noise_key, (num_basis,)) * 0.05
            activations_list.append(act)

        # Stack into a matrix: (output_dim, num_basis)
        self.activations = jnp.stack(activations_list)

        # Store domain for clipping
        self.domain = domain

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Forward pass through the KAN layer.

        Args:
            x: Array of shape (batch_size, input_dim).

        Returns:
            Array of shape (batch_size, output_dim).
        """
        # Linear transformation, then clip into the range covered by the grid.
        z = jnp.dot(x, self.weights) + self.biases
        z_clipped = jnp.clip(z, self.domain[0], self.domain[1])

        # Index of the left grid point of the interval containing each value.
        # Done for the whole (batch, output) matrix at once rather than one
        # scatter update per output column.
        idx = jnp.searchsorted(self.grid_points, z_clipped) - 1
        idx = jnp.clip(idx, 0, self.grid_points.shape[0] - 2)

        x0 = self.grid_points[idx]
        x1 = self.grid_points[idx + 1]

        # Row j of `activations` belongs to output column j; broadcasting the
        # column index against `idx` picks the matching table per column.
        cols = jnp.arange(self.activations.shape[0])
        y0 = self.activations[cols, idx]
        y1 = self.activations[cols, idx + 1]

        # Linear interpolation between the two neighbouring grid values.
        t = (z_clipped - x0) / (x1 - x0)
        return y0 + t * (y1 - y0)

# Regime Detection KAN model
class RegimeDetectionKAN:
    """Stack of KAN layers that maps feature vectors to a latent embedding.

    The model is a chain of encoder ``KANLayer`` objects followed by one
    latent ``KANLayer``; ``params`` / ``update_params`` expose the layer
    arrays as a flat dictionary keyed by layer position.
    """

    def __init__(self, input_dim: int, latent_dim: int = 5, hidden_dims: List[int] = None,
                 num_basis: int = 30, domain=(-3.0, 3.0), key=None):
        """Initialize a KAN model for market regime detection.

        Args:
            input_dim: Number of input features.
            latent_dim: Size of the latent embedding.
            hidden_dims: Output sizes of the encoder layers; defaults to
                ``[64, 32]`` when omitted.
            num_basis: Grid points per activation table, forwarded to each layer.
            domain: Activation grid range, forwarded to each layer.
            key: JAX PRNG key; ``PRNGKey(0)`` is used when omitted.
        """
        # A list literal as a default would be one shared object across all
        # calls, so the default is materialised per call instead.
        if hidden_dims is None:
            hidden_dims = [64, 32]

        if key is None:
            key = jax.random.PRNGKey(0)

        # One key per encoder layer plus one for the latent layer.
        keys = jax.random.split(key, len(hidden_dims) + 1)

        # Initialize encoder layers
        self.encoder_layers = []
        prev_dim = input_dim

        for i, hidden_dim in enumerate(hidden_dims):
            layer = KANLayer(prev_dim, hidden_dim, num_basis, domain, keys[i])
            self.encoder_layers.append(layer)
            prev_dim = hidden_dim

        # Latent space representation layer (embedding space for regime detection)
        self.latent_layer = KANLayer(prev_dim, latent_dim, num_basis, domain, keys[-1])

        # Store latent dimension for later use
        self.latent_dim = latent_dim

    def encode(self, x: jnp.ndarray) -> jnp.ndarray:
        """Run ``x`` through every encoder layer and then the latent layer."""
        for layer in self.encoder_layers:
            x = layer(x)

        return self.latent_layer(x)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Forward pass through the model; identical to ``encode``."""
        return self.encode(x)

    @property
    def params(self):
        """Return the layer arrays as a flat dictionary.

        Keys are ``encoder_layer_{i}_{weights|biases|activations}`` for each
        encoder layer and ``latent_layer_{weights|biases|activations}``.
        """
        params = {}
        for i, layer in enumerate(self.encoder_layers):
            params[f'encoder_layer_{i}_weights'] = layer.weights
            params[f'encoder_layer_{i}_biases'] = layer.biases
            params[f'encoder_layer_{i}_activations'] = layer.activations

        params['latent_layer_weights'] = self.latent_layer.weights
        params['latent_layer_biases'] = self.latent_layer.biases
        params['latent_layer_activations'] = self.latent_layer.activations

        return params

    def update_params(self, params):
        """Write a flat dictionary (same keys as ``params``) back into the layers."""
        for i, layer in enumerate(self.encoder_layers):
            layer.weights = params[f'encoder_layer_{i}_weights']
            layer.biases = params[f'encoder_layer_{i}_biases']
            layer.activations = params[f'encoder_layer_{i}_activations']

        self.latent_layer.weights = params['latent_layer_weights']
        self.latent_layer.biases = params['latent_layer_biases']
        self.latent_layer.activations = params['latent_layer_activations']

# Contrastive loss for unsupervised regime detection
def contrastive_loss(latent_representations, temperature=0.1):
    """Contrastive loss that pulls temporally adjacent rows together.

    Rows are L2-normalised and compared by scaled dot product. Row ``t`` is
    treated as a positive pair with rows ``t - 1`` and ``t + 1`` (the input is
    assumed to be in temporal order); every other off-diagonal row is a
    negative.

    Args:
        latent_representations: Array of shape (batch_size, latent_dim).
        temperature: Divisor applied to the cosine similarities.

    Returns:
        Scalar loss: mean over rows of
        ``log(sum_neg exp(sim) + 1e-8) - mean_pos(sim)``.
    """
    eps = 1e-8

    # Normalize latent representations
    latent_norm = latent_representations / jnp.sqrt(
        jnp.sum(latent_representations**2, axis=1, keepdims=True) + eps
    )

    # Compute similarity matrix
    similarity = jnp.matmul(latent_norm, latent_norm.T) / temperature

    batch_size = latent_representations.shape[0]

    # Positive pairs are the first off-diagonals (t, t+1) and (t+1, t).
    positive_mask = jnp.zeros((batch_size, batch_size))
    pos_indices = jnp.arange(batch_size - 1)
    positive_mask = positive_mask.at[pos_indices, pos_indices + 1].set(1.0)
    positive_mask = positive_mask + positive_mask.T  # Symmetrize

    # Everything that is neither a positive pair nor the diagonal is negative.
    negative_mask = 1.0 - positive_mask - jnp.eye(batch_size)

    # Mean similarity over each row's positive partners.
    pos_similarity = jnp.sum(similarity * positive_mask, axis=1) / (
        jnp.sum(positive_mask, axis=1) + eps
    )

    # log(sum_neg exp(sim) + eps), evaluated with the row maximum factored out
    # so exp() never sees a large positive argument. Exponentiating the raw
    # similarities overflows (and then produces NaN through inf * 0 in the
    # mask) once 1/temperature exceeds the float range of exp. The shift is a
    # constant with respect to differentiation, hence stop_gradient.
    row_max = jax.lax.stop_gradient(jnp.max(similarity, axis=1))
    shifted_neg_sum = jnp.sum(
        jnp.exp(similarity - row_max[:, None]) * negative_mask, axis=1
    )
    neg_similarity = row_max + jnp.log(shifted_neg_sum + eps * jnp.exp(-row_max))

    # Contrastive loss
    loss = -jnp.mean(pos_similarity - neg_similarity)

    return loss

# Temporal smoothness loss to encourage smooth regime transitions
def temporal_smoothness_loss(latent_representations, lambda_smooth=1.0):
    """Penalise large jumps between consecutive latent rows.

    Args:
        latent_representations: Array of shape (num_steps, latent_dim), rows
            in temporal order.
        lambda_smooth: Weight applied to the penalty.

    Returns:
        ``lambda_smooth`` times the mean squared L2 distance between each row
        and the row that follows it.
    """
    # Row-to-row change: element t holds row[t + 1] - row[t].
    step_deltas = latent_representations[1:] - latent_representations[:-1]

    # Squared length of every step, then the average over all steps.
    per_step_sq_norm = jnp.sum(step_deltas**2, axis=1)
    mean_sq_step = jnp.mean(per_step_sq_norm)

    return lambda_smooth * mean_sq_step

# Helper function to extract values from a parameter dictionary
def extract_params(params, layer_idx, param_type):
    """Look up one encoder-layer entry in a flat parameter dictionary.

    Args:
        params: Dictionary keyed by ``encoder_layer_{index}_{type}`` strings.
        layer_idx: Encoder layer index used in the key.
        param_type: Suffix of the key (the model's own dictionary uses
            ``weights``, ``biases`` and ``activations``).

    Returns:
        The value stored under the constructed key.
    """
    return params[f'encoder_layer_{layer_idx}_{param_type}']

# JIT-compatible model application function
def _kan_layer_forward(x, weights, biases, activations, domain):
    """Apply one KAN layer: affine map, clip to ``domain``, interpolate activations."""
    low, high = domain

    # The grid size follows the activation table, so parameters created with
    # any number of basis points are interpolated on the matching grid.
    grid_points = jnp.linspace(low, high, activations.shape[1])

    # Linear transformation, clipped into the range the grid covers.
    z = jnp.clip(jnp.dot(x, weights) + biases, low, high)

    # Left grid index of the interval containing each entry, for all output
    # columns at once.
    idx = jnp.searchsorted(grid_points, z) - 1
    idx = jnp.clip(idx, 0, grid_points.shape[0] - 2)

    x0 = grid_points[idx]
    x1 = grid_points[idx + 1]

    # Row j of `activations` is the table for output column j.
    cols = jnp.arange(activations.shape[0])
    y0 = activations[cols, idx]
    y1 = activations[cols, idx + 1]

    # Linear interpolation
    t = (z - x0) / (x1 - x0)
    return y0 + t * (y1 - y0)


def apply_model_jit(params, x, num_encoder_layers, domain=(-3.0, 3.0)):
    """Functional forward pass over a flat parameter dictionary.

    Args:
        params: Dictionary with ``encoder_layer_{i}_{weights|biases|activations}``
            entries for ``i`` in ``range(num_encoder_layers)`` and
            ``latent_layer_{weights|biases|activations}`` entries.
        x: Input array of shape (batch_size, input_dim).
        num_encoder_layers: Number of encoder layers to apply; a Python int
            because it drives a Python loop.
        domain: (low, high) range of the activation grid. It must match the
            domain the layers were built with; the default is (-3.0, 3.0).

    Returns:
        Latent array of shape (batch_size, latent_dim).
    """
    for i in range(num_encoder_layers):
        # Keys are built with ordinary Python string formatting at trace time.
        x = _kan_layer_forward(
            x,
            params[f'encoder_layer_{i}_weights'],
            params[f'encoder_layer_{i}_biases'],
            params[f'encoder_layer_{i}_activations'],
            domain,
        )

    return _kan_layer_forward(
        x,
        params['latent_layer_weights'],
        params['latent_layer_biases'],
        params['latent_layer_activations'],
        domain,
    )

# Combined loss function for JIT compatibility
@partial(jit, static_argnums=(2,))
def regime_loss_fn_jit(params, X, num_encoder_layers, lambda_smooth=1.0, lambda_reg=0.001):
    """Total training loss: contrastive + temporal smoothness + activation curvature.

    Args:
        params: Flat parameter dictionary of the model.
        X: Batch of inputs, rows in temporal order.
        num_encoder_layers: Number of encoder layers (static under jit).
        lambda_smooth: Weight passed to the temporal smoothness loss.
        lambda_reg: Weight of the activation-curvature penalty.

    Returns:
        Scalar loss.
    """
    embeddings = apply_model_jit(params, X, num_encoder_layers)

    contrastive_term = contrastive_loss(embeddings)
    smoothness_term = temporal_smoothness_loss(embeddings, lambda_smooth)

    def curvature_penalty(table):
        # Discrete second difference along the grid axis of an activation table.
        second_diff = table[:, 2:] - 2 * table[:, 1:-1] + table[:, :-2]
        return jnp.mean(second_diff ** 2)

    # Encoder tables first, latent table last, accumulated in that order.
    table_names = [f'encoder_layer_{i}_activations' for i in range(num_encoder_layers)]
    table_names.append('latent_layer_activations')

    activation_reg = 0.0
    for name in table_names:
        activation_reg += curvature_penalty(params[name])

    return contrastive_term + smoothness_term + lambda_reg * activation_reg

# Training step with JIT compilation
@partial(jit, static_argnums=(2,))
def train_step_jit(params, X, num_encoder_layers, opt_state, lambda_smooth=1.0, lambda_reg=0.001):
    """Run one optimisation step and return the updated state.

    The optax transformation is read from the module-level name ``optimizer``,
    which ``train_regime_model`` assigns before calling this function.

    Args:
        params: Flat parameter dictionary.
        X: Batch of inputs.
        num_encoder_layers: Number of encoder layers (static under jit).
        opt_state: Current optimizer state.
        lambda_smooth: Forwarded to the loss.
        lambda_reg: Forwarded to the loss.

    Returns:
        Tuple ``(params, opt_state, loss_value)`` after the update.
    """
    def batch_objective(candidate_params):
        # Loss as a function of the parameters only; everything else is closed over.
        return regime_loss_fn_jit(candidate_params, X, num_encoder_layers, lambda_smooth, lambda_reg)

    loss_value, gradients = jax.value_and_grad(batch_objective)(params)
    param_updates, next_opt_state = optimizer.update(gradients, opt_state)
    next_params = optax.apply_updates(params, param_updates)
    return next_params, next_opt_state, loss_value

# Modified training function
def train_regime_model(model, X_train, num_epochs=100, batch_size=64, lambda_smooth=1.0, lambda_reg=0.001):
    """Train the regime detection model with Adam (learning rate 0.001).

    Each epoch visits contiguous, non-overlapping windows of ``batch_size``
    rows in a random order, so temporal order is preserved inside a batch.
    Rows beyond the last full window are not used.

    Args:
        model: Object exposing ``params``, ``encoder_layers`` and
            ``update_params`` (e.g. ``RegimeDetectionKAN``).
        X_train: Array of shape (num_samples, num_features), rows in temporal order.
        num_epochs: Number of passes over the batches.
        batch_size: Rows per batch; reduced to ``num_samples`` when the data
            set is smaller than one batch.
        lambda_smooth: Forwarded to the training step.
        lambda_reg: Forwarded to the training step.

    Returns:
        Tuple ``(params, losses)``: the trained parameter dictionary and the
        list of mean batch losses, one per epoch.

    Raises:
        ValueError: If ``X_train`` has no rows.
    """
    # The optimizer is published as a module global because the jitted
    # training step reads it from there.
    global optimizer

    num_samples = X_train.shape[0]
    if num_samples == 0:
        raise ValueError("X_train is empty; cannot train the regime model.")

    # With fewer rows than one batch the integer division below would yield
    # zero batches and the per-epoch average would divide by zero; train on a
    # single batch holding all rows instead.
    batch_size = min(batch_size, num_samples)

    params = model.params
    num_encoder_layers = len(model.encoder_layers)
    num_batches = num_samples // batch_size

    optimizer = optax.adam(learning_rate=0.001)
    opt_state = optimizer.init(params)

    losses = []

    for epoch in range(num_epochs):
        # Keep temporal order within batches, but shuffle which window comes
        # first. Every start is at most num_samples - batch_size by construction.
        starts = np.random.permutation(num_batches) * batch_size

        epoch_loss = 0.0

        for start_idx in starts:
            X_batch = X_train[start_idx:start_idx + batch_size]

            # num_encoder_layers is a static argument of the jitted step.
            params, opt_state, batch_loss = train_step_jit(
                params, X_batch, num_encoder_layers, opt_state, lambda_smooth, lambda_reg
            )

            epoch_loss += batch_loss

        epoch_loss /= num_batches
        losses.append(epoch_loss)

        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {epoch_loss:.6f}")

    # Update model with trained parameters
    model.update_params(params)
    return params, losses

# Detect regimes using the trained model
def detect_regimes(model, params, X, num_regimes=4):
    """Cluster the model's latent embeddings into regimes and smooth the labels.

    Args:
        model: Model exposing ``update_params`` and callable on ``X``.
        params: Parameter dictionary loaded into the model before encoding.
        X: Input features, one row per time step.
        num_regimes: Number of K-means clusters.

    Returns:
        Tuple ``(smoothed_regimes, latent_np)``: the majority-smoothed cluster
        labels and the latent embeddings as a NumPy array.
    """
    model.update_params(params)

    # Embed the inputs and hand a NumPy copy to scikit-learn.
    latent_np = np.array(model(X))

    clusterer = KMeans(n_clusters=num_regimes, random_state=42, n_init=10)
    raw_labels = clusterer.fit_predict(latent_np)

    # Majority vote over a centred window of 5 steps (truncated at both ends)
    # to suppress rapid switching between regimes.
    half_window = 5 // 2
    total = len(raw_labels)
    smoothed_regimes = np.copy(raw_labels)

    for pos in range(total):
        lo = max(0, pos - half_window)
        hi = min(total, pos + half_window + 1)
        smoothed_regimes[pos] = np.bincount(raw_labels[lo:hi]).argmax()

    return smoothed_regimes, latent_np

# Download and prepare financial data
def download_market_data(tickers, start_date, end_date, feature_window=20):
    """
    Download market data for a list of tickers and build windowed features.

    For every ticker the daily return, log return, 20-day rolling volatility
    of log returns, RSI, and the ratio of the 10-day to the 50-day moving
    average of the close are computed. Tickers that fail to download or
    return no rows are reported and skipped.

    Args:
        tickers: List of ticker symbols
        start_date: Start date for data download
        end_date: End date for data download
        feature_window: Number of consecutive days flattened into one sample

    Returns:
        X: Standardized feature array, one row per sample
        dates: Date of the day that follows each sample's window
        returns_dict: Dictionary of simple returns for each ticker

    Raises:
        ValueError: If nothing could be downloaded, the tickers share no
            dates, or too few complete rows remain for one window.
    """
    print(f"Downloading data for {len(tickers)} tickers...")
    data = {}
    returns_dict = {}

    for ticker in tickers:
        try:
            df = yf.download(ticker, start=start_date, end=end_date, progress=False)

            if len(df) == 0:
                print(f"No data found for {ticker}, skipping...")
                continue

            close = df['Close']

            # Simple and log returns
            df['Return'] = close.pct_change()
            df['Log_Return'] = np.log(close / close.shift(1))

            # Rolling volatility of the log returns
            df['Volatility'] = df['Log_Return'].rolling(window=20).std()

            # Momentum / trend indicators
            df['RSI'] = calculate_rsi(df['Close'])
            df['MA_10'] = df['Close'].rolling(window=10).mean()
            df['MA_50'] = df['Close'].rolling(window=50).mean()
            df['MA_Ratio'] = df['MA_10'] / df['MA_50']

            data[ticker] = df
            returns_dict[ticker] = df['Return']

            print(f"Downloaded {len(df)} days of data for {ticker}")

        except Exception as e:
            # Best effort: report the failure and carry on with other tickers.
            print(f"Error downloading data for {ticker}: {e}")

    if not data:
        raise ValueError("No data could be downloaded for any of the provided tickers. Please check the ticker symbols or date range.")

    # Keep only the dates on which every ticker has a row.
    common_dates = set.intersection(*(set(frame.index) for frame in data.values()))

    if not common_dates:
        raise ValueError("No common dates found across all tickers. Try using a different date range or set of tickers.")

    common_dates = sorted(common_dates)
    print(f"Found {len(common_dates)} common dates across all tickers")

    # One block of four feature columns per ticker, prefixed with the ticker.
    wanted_columns = ['Log_Return', 'Volatility', 'RSI', 'MA_Ratio']
    feature_dfs = []

    for ticker, frame in data.items():
        aligned = frame.loc[common_dates].copy()
        features = aligned[wanted_columns].copy()
        features.columns = [f"{ticker}_{col}" for col in features.columns]
        feature_dfs.append(features)

    # Side-by-side features; drop rows where any indicator is still warming up.
    combined_features = pd.concat(feature_dfs, axis=1).dropna()

    if len(combined_features) <= feature_window:
        raise ValueError(f"Not enough data ({len(combined_features)} days) for the specified feature window ({feature_window} days). Try using a longer date range.")

    # Sample i flattens the feature_window rows preceding row i and is stamped
    # with the date of row i.
    window_ends = range(feature_window, len(combined_features))
    X = [combined_features.iloc[end - feature_window:end].values.flatten() for end in window_ends]
    dates = [combined_features.index[end] for end in window_ends]

    if not X:
        raise ValueError("No data points could be created. Check for NaN values in your data or try a different date range.")

    # Standardize features column-wise
    X = StandardScaler().fit_transform(np.array(X))

    return X, dates, returns_dict

# Calculate RSI
def calculate_rsi(prices, window=14):
    """Calculate the Relative Strength Index (RSI) of a price series.

    Uses simple rolling means of gains and losses (``min_periods=1``).

    Args:
        prices: pandas Series of prices.
        window: Rolling window length.

    Returns:
        pandas Series of RSI values aligned with ``prices``; the first entry
        is NaN because no price change exists for it.
    """
    price_change = prices.diff()

    # Upward moves keep their size, downward moves become positive losses.
    mean_up = price_change.clip(lower=0).rolling(window=window, min_periods=1).mean()
    mean_down = (-price_change.clip(upper=0)).rolling(window=window, min_periods=1).mean()

    relative_strength = mean_up / mean_down

    return 100 - (100 / (1 + relative_strength))

# Visualize regimes
def visualize_regimes(latent_representations, detected_regimes, dates):
    """Scatter-plot the latent space in 2-D (t-SNE), coloured by regime.

    Args:
        latent_representations: Array of shape (num_samples, latent_dim).
        detected_regimes: Regime label per sample, used as the colour value.
        dates: Accepted for signature compatibility; not used by this plot.

    Returns:
        The matplotlib figure.
    """
    # Apply t-SNE for dimensionality reduction; perplexity must stay below
    # the number of samples.
    tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(latent_representations)-1))
    latent_2d = tsne.fit_transform(latent_representations)

    fig, ax = plt.subplots(figsize=(12, 8))

    # Plot detected regimes
    scatter = ax.scatter(latent_2d[:, 0], latent_2d[:, 1], c=detected_regimes, cmap='viridis', alpha=0.7, s=50)
    ax.set_title('Detected Market Regimes (KAN)', fontsize=15)
    ax.set_xlabel('t-SNE dimension 1', fontsize=12)
    ax.set_ylabel('t-SNE dimension 2', fontsize=12)
    plt.colorbar(scatter, ax=ax, label='Regime')

    # Legend markers take their colour from the scatter's own norm + colormap,
    # so each entry matches the points of that regime. Sampling the colormap
    # at label / count does not, because the scatter rescales the labels to
    # their min..max range.
    unique_regimes = np.unique(detected_regimes)
    legend_elements = [
        plt.Line2D([0], [0], marker='o', color='w',
                   markerfacecolor=scatter.to_rgba(regime),
                   markersize=10, label=f'Regime {regime}')
        for regime in unique_regimes
    ]
    ax.legend(handles=legend_elements, loc='upper right')

    plt.tight_layout()
    return fig

# Plot regime transitions
def plot_regime_transitions(detected_regimes, dates, returns_dict, ticker_to_plot=None):
    """Plot regime labels over time above one ticker's cumulative returns.

    The lower panel is shaded per regime segment and both panels get a dashed
    vertical line at every regime change.

    Args:
        detected_regimes: Regime label per date.
        dates: Dates aligned with ``detected_regimes``.
        returns_dict: Mapping of ticker to a Series of simple returns.
        ticker_to_plot: Ticker for the lower panel; the first key of
            ``returns_dict`` is used when it is ``None`` or unknown.

    Returns:
        The matplotlib figure.
    """
    # Ensure dates is a list
    if not isinstance(dates, list):
        dates = list(dates)

    regimes = np.asarray(detected_regimes)
    regime_series = pd.Series(regimes, index=dates)

    # Two stacked panels sharing the time axis.
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10), sharex=True, gridspec_kw={'height_ratios': [1, 3]})

    # Plot detected regimes
    ax1.plot(regime_series.index, regime_series.values, 'k-', linewidth=2)
    ax1.set_title('Detected Market Regimes (KAN)', fontsize=15)
    ax1.set_ylabel('Regime', fontsize=12)
    ax1.grid(True, alpha=0.3)

    # Only the labels that actually occur become y ticks.
    unique_regimes = np.unique(regimes)
    ax1.set_yticks(unique_regimes)

    # Plot returns for the specified ticker or the first ticker in the dict
    if ticker_to_plot is None or ticker_to_plot not in returns_dict:
        ticker_to_plot = list(returns_dict.keys())[0]

    returns = returns_dict[ticker_to_plot]
    cum_returns = (1 + returns).cumprod()

    ax2.plot(cum_returns.index, cum_returns.values, 'b-', linewidth=1.5)
    ax2.set_title(f'Cumulative Returns for {ticker_to_plot}', fontsize=15)
    ax2.set_xlabel('Date', fontsize=12)
    ax2.set_ylabel('Cumulative Return', fontsize=12)
    ax2.grid(True, alpha=0.3)

    # Positions where the label differs from the previous position.
    change_points = np.flatnonzero(np.diff(regimes) != 0) + 1
    for change_idx in change_points:
        change_date = dates[int(change_idx)]
        ax2.axvline(change_date, color='r', linestyle='--', alpha=0.5)
        ax1.axvline(change_date, color='r', linestyle='--', alpha=0.5)

    # Colours are looked up by label value, not by using the label as an
    # index into the colour array: the labels need not be 0..k-1 (a label can
    # vanish after smoothing), in which case positional indexing picks the
    # wrong colour or runs off the end of the array.
    colors = plt.cm.viridis(np.linspace(0, 1, len(unique_regimes)))
    color_by_regime = {label: color for label, color in zip(unique_regimes.tolist(), colors)}

    # Shade each maximal run of equal labels, from its first to its last date.
    if len(regimes) > 0:
        segment_starts = np.concatenate(([0], change_points))
        segment_ends = np.concatenate((change_points - 1, [len(regimes) - 1]))
        for seg_start, seg_end in zip(segment_starts, segment_ends):
            seg_start, seg_end = int(seg_start), int(seg_end)
            ax2.axvspan(dates[seg_start], dates[seg_end], alpha=0.2,
                        color=color_by_regime[regimes[seg_start].item()])

    # Add a legend for the regimes
    legend_elements = [
        plt.Line2D([0], [0], color=color_by_regime[label.item()], lw=8, alpha=0.5,
                   label=f'Regime {label}')
        for label in unique_regimes
    ]
    ax2.legend(handles=legend_elements, loc='lower right')

    plt.tight_layout()
    return fig

# Analyze regime characteristics
def analyze_regime_characteristics(detected_regimes, dates, returns_dict):
    """Summarise per-ticker return statistics for every detected regime.

    Args:
        detected_regimes: Regime label per date.
        dates: Dates aligned with ``detected_regimes``.
        returns_dict: Mapping of ticker to a Series of returns indexed by date.

    Returns:
        Dictionary keyed by regime label. Each value holds ``dates``,
        ``num_periods``, ``mean_returns``, ``volatility``, ``sharpe``
        (mean divided by standard deviation), ``correlation``,
        ``start_date`` and ``end_date``.
    """
    labels_by_date = pd.Series(detected_regimes, index=dates)

    # One column of returns per ticker, restricted to the labelled dates.
    returns_frame = pd.DataFrame(index=dates)
    for ticker, series in returns_dict.items():
        returns_frame[ticker] = series.reindex(dates)

    def summarize(regime):
        # Dates carrying this label and the returns observed on them.
        member_dates = labels_by_date[labels_by_date == regime].index
        member_returns = returns_frame.loc[member_dates]

        mean_returns = member_returns.mean()
        volatility = member_returns.std()
        has_dates = len(member_dates) > 0

        return {
            'dates': member_dates,
            'num_periods': len(member_dates),
            'mean_returns': mean_returns,
            'volatility': volatility,
            'sharpe': mean_returns / volatility,
            'correlation': member_returns.corr(),
            'start_date': member_dates[0] if has_dates else None,
            'end_date': member_dates[-1] if has_dates else None
        }

    return {regime: summarize(regime) for regime in np.unique(detected_regimes)}

# Visualize regime characteristics
def visualize_regime_characteristics(regime_stats, tickers):
    """Plot mean returns, volatility and correlations for every regime.

    One row of three panels is drawn per regime, in sorted regime order.

    Args:
        regime_stats: Dictionary keyed by regime; each value provides
            ``mean_returns``, ``volatility``, ``correlation``, ``start_date``
            and ``end_date``.
        tickers: Tickers to show, in display order.

    Returns:
        The matplotlib figure, or ``None`` when ``regime_stats`` is empty.
    """
    num_regimes = len(regime_stats)

    if num_regimes == 0:
        print("No regime statistics available to visualize")
        return None

    # squeeze=False keeps `axs` two-dimensional even for a single regime.
    fig, axs = plt.subplots(num_regimes, 3, figsize=(16, 5 * num_regimes), squeeze=False)

    tick_positions = np.arange(len(tickers))

    for i, regime in enumerate(sorted(regime_stats.keys())):
        stats = regime_stats[regime]
        returns_ax, vol_ax, corr_ax = axs[i, 0], axs[i, 1], axs[i, 2]

        # Plot mean returns
        returns_ax.bar(tickers, [stats['mean_returns'][ticker] for ticker in tickers])
        returns_ax.set_title(f'Regime {regime}: Mean Daily Returns')
        returns_ax.set_xlabel('Ticker')
        returns_ax.set_ylabel('Mean Return')
        returns_ax.grid(True, alpha=0.3)
        # The bar call already labels the ticks with the tickers; only rotate
        # them. Setting tick labels without fixing tick positions first is
        # discouraged by matplotlib and triggers a warning.
        returns_ax.tick_params(axis='x', labelrotation=45)

        # Add a horizontal line at y=0
        returns_ax.axhline(y=0, color='r', linestyle='-', alpha=0.5)

        # Add period information as text; the dates may be None for a regime
        # without any periods.
        start_date, end_date = stats['start_date'], stats['end_date']
        if start_date is not None and end_date is not None:
            date_info = f"Period: {start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')}"
            returns_ax.text(0.5, 0.98, date_info, transform=returns_ax.transAxes,
                            ha='center', va='top', bbox=dict(boxstyle='round', alpha=0.1))

        # Plot volatility
        vol_ax.bar(tickers, [stats['volatility'][ticker] for ticker in tickers])
        vol_ax.set_title(f'Regime {regime}: Volatility')
        vol_ax.set_xlabel('Ticker')
        vol_ax.set_ylabel('Volatility')
        vol_ax.grid(True, alpha=0.3)
        vol_ax.tick_params(axis='x', labelrotation=45)

        # Plot correlation matrix; tick positions are fixed before labelling.
        corr_matrix = stats['correlation'].loc[tickers, tickers]
        im = corr_ax.imshow(corr_matrix, cmap='coolwarm', vmin=-1, vmax=1)
        corr_ax.set_title(f'Regime {regime}: Correlation Matrix')
        corr_ax.set_xticks(tick_positions)
        corr_ax.set_yticks(tick_positions)
        corr_ax.set_xticklabels(tickers, rotation=45)
        corr_ax.set_yticklabels(tickers)
        plt.colorbar(im, ax=corr_ax)

        # Annotate each cell; strong correlations get white text for contrast.
        for row in range(len(tickers)):
            for col in range(len(tickers)):
                value = corr_matrix.iloc[row, col]
                corr_ax.text(col, row, f"{value:.2f}",
                             ha="center", va="center",
                             color="black" if abs(value) < 0.7 else "white")

    plt.tight_layout()
    return fig

# Compare with HMM-based regime detection
def detect_regimes_with_hmm(X, num_regimes=4):
    """Detect market regimes with a Gaussian Hidden Markov Model.

    Inputs with more than 20 columns are first projected onto 20 principal
    components. The predicted states are then majority-smoothed over a
    centred window of 5 steps.

    Args:
        X: Feature array of shape (num_samples, num_features).
        num_regimes: Number of hidden states.

    Returns:
        Array of smoothed state labels, one per row of ``X``.
    """
    hmm = GaussianHMM(n_components=num_regimes, covariance_type="full", random_state=42, n_iter=100)

    # Reduce dimensionality first (HMM can struggle with high-dimensional data)
    X_reduced = X
    if X.shape[1] > 20:
        X_reduced = PCA(n_components=min(20, X.shape[1])).fit_transform(X)

    hmm.fit(X_reduced)
    raw_states = hmm.predict(X_reduced)

    # Majority vote in a window truncated at both ends of the sequence.
    half_window = 5 // 2
    total = len(raw_states)
    smoothed_regimes = np.copy(raw_states)

    for pos in range(total):
        lo = max(0, pos - half_window)
        hi = min(total, pos + half_window + 1)
        smoothed_regimes[pos] = np.bincount(raw_states[lo:hi]).argmax()

    return smoothed_regimes

# Compare different regime detection methods
def compare_regime_methods(X, dates, returns_dict, ticker_to_plot=None):
    """Train a KAN regime detector, run the HMM baseline, and plot both.

    The function builds and trains a fresh ``RegimeDetectionKAN`` on ``X``,
    derives 4 regimes from it with ``detect_regimes``, derives 4 regimes with
    ``detect_regimes_with_hmm``, and draws a three-panel figure: KAN labels,
    HMM labels, and the cumulative return of one ticker shaded by KAN regime.

    Args:
        X: Feature matrix; ``X.shape[1]`` is used as the model input width.
        dates: Index labels used for the regime series (one per regime label).
        returns_dict: Mapping of ticker -> return series.
        ticker_to_plot: Ticker whose cumulative return is plotted. Falls back
            to the first key of ``returns_dict`` when ``None`` or unknown.

    Returns:
        Tuple ``(fig, model, trained_params, kan_regimes, hmm_regimes)``.
    """
    # Build and train the KAN model from scratch (nothing is reused).
    print("Initializing the KAN model for regime detection...")
    input_dim = X.shape[1]
    latent_dim = 5
    model = RegimeDetectionKAN(input_dim, latent_dim, hidden_dims=[64, 32])

    print("Training the KAN model...")
    trained_params, _ = train_regime_model(
        model, X, num_epochs=100, lambda_smooth=1.0, lambda_reg=0.001)

    print("Detecting regimes with KAN...")
    kan_regimes, latent = detect_regimes(model, trained_params, X, num_regimes=4)

    print("Detecting regimes with HMM...")
    hmm_regimes = detect_regimes_with_hmm(X, num_regimes=4)

    # Date-indexed views of both label sequences
    kan_series = pd.Series(kan_regimes, index=dates)
    hmm_series = pd.Series(hmm_regimes, index=dates)

    # Choose a ticker to plot
    if ticker_to_plot is None or ticker_to_plot not in returns_dict:
        ticker_to_plot = list(returns_dict.keys())[0]

    returns = returns_dict[ticker_to_plot]
    cum_returns = (1 + returns).cumprod()

    # Create figure with three subplots
    fig, axs = plt.subplots(3, 1, figsize=(16, 12), sharex=True, gridspec_kw={'height_ratios': [1, 1, 3]})

    unique_kan_regimes = np.unique(kan_regimes)

    # Plot KAN regimes
    axs[0].plot(kan_series.index, kan_series.values, 'b-', linewidth=2)
    axs[0].set_title('KAN Detected Regimes', fontsize=15)
    axs[0].set_ylabel('Regime', fontsize=12)
    axs[0].grid(True, alpha=0.3)
    axs[0].set_yticks(unique_kan_regimes)

    # Plot HMM regimes
    axs[1].plot(hmm_series.index, hmm_series.values, 'g-', linewidth=2)
    axs[1].set_title('HMM Detected Regimes', fontsize=15)
    axs[1].set_ylabel('Regime', fontsize=12)
    axs[1].grid(True, alpha=0.3)
    axs[1].set_yticks(np.unique(hmm_regimes))

    # Plot cumulative returns
    axs[2].plot(cum_returns.index, cum_returns.values, 'k-', linewidth=1.5)
    axs[2].set_title(f'Cumulative Returns for {ticker_to_plot}', fontsize=15)
    axs[2].set_xlabel('Date', fontsize=12)
    axs[2].set_ylabel('Cumulative Return', fontsize=12)
    axs[2].grid(True, alpha=0.3)

    # One colour per *distinct* regime label. The lookup is keyed by label, not
    # by position, so labels that are not exactly 0..k-1 (e.g. a regime that
    # never occurs) neither raise IndexError nor pick the wrong colour.
    kan_colors = plt.cm.viridis(np.linspace(0, 1, len(unique_kan_regimes)))
    color_for_regime = {
        int(label): colour for label, colour in zip(unique_kan_regimes, kan_colors)
    }

    # Restrict the regime labels to dates that also have a return observation
    kan_filtered = kan_series.loc[kan_series.index.intersection(cum_returns.index)]

    # Shade the return panel by regime; skip entirely if the indices don't overlap
    if not kan_filtered.empty:
        regime_changes = kan_filtered.diff().abs() > 0
        change_points = kan_filtered.index[regime_changes]

        # Segment boundaries: first date, every change point, last date
        all_points = [kan_filtered.index[0]] + list(change_points) + [kan_filtered.index[-1]]

        for start_date, end_date in zip(all_points[:-1], all_points[1:]):
            regime = int(kan_filtered.loc[start_date])
            axs[2].axvspan(start_date, end_date, alpha=0.2, color=color_for_regime[regime])

    # Legend entries for the KAN regimes
    kan_legend = [plt.Line2D([0], [0], color=color_for_regime[int(label)], lw=8, alpha=0.5,
                          label=f'KAN Regime {label}') for label in unique_kan_regimes]

    # Add vertical lines for major market events
    major_events = {
        # Add relevant dates here, for example:
        # '2020-03-23': 'COVID-19 Market Bottom',
        # '2022-01-03': '2022 Market Peak',
        # Customize based on your data period
    }

    for date_str, event in major_events.items():
        # Only a malformed date is tolerated; anything else should surface.
        try:
            event_date = pd.to_datetime(date_str)
        except (ValueError, TypeError) as exc:
            print(f"Skipping event {event!r}: cannot parse date {date_str!r} ({exc})")
            continue
        if event_date not in cum_returns.index:
            continue
        for ax in axs:
            ax.axvline(event_date, color='r', linestyle='--', alpha=0.7)
            ax.text(event_date, ax.get_ylim()[1]*0.95, event, rotation=90,
                   va='top', ha='right')

    # Add the legend
    axs[2].legend(handles=kan_legend, loc='lower right')

    plt.tight_layout()
    return fig, model, trained_params, kan_regimes, hmm_regimes

# Implement a simple trading strategy based on detected regimes
def regime_based_trading_strategy(detected_regimes, dates, returns_dict):
    """Backtest a per-ticker long/short rule keyed on the detected regime.

    For each ticker the mean return inside every regime is measured over the
    whole sample. The ticker is then held long (+1) on observations whose
    regime has a positive mean, short (-1) where the mean is negative, and
    flat (0) otherwise. Positions are therefore chosen from the same data on
    which they are evaluated.

    Args:
        detected_regimes: Regime label for each entry of ``dates``.
        dates: Index used for the regime series and for all outputs.
        returns_dict: Mapping of ticker -> return series; each series is
            reindexed onto ``dates``.

    Returns:
        Dict keyed by ticker. Each value holds ``'positions'`` and
        ``'strategy_returns'`` (arrays), ``'cum_strategy_returns'`` and
        ``'cum_buy_hold_returns'`` (series), and ``'regime_stats'`` (per-regime
        ``'mean_return'``, ``'volatility'`` and ``'sharpe'``).
    """
    regime_series = pd.Series(detected_regimes, index=dates)
    calendar = regime_series.index
    regime_labels = np.unique(detected_regimes)

    strategy_results = {}
    for ticker in returns_dict:
        # Align this ticker's returns with the regime calendar
        ticker_returns = returns_dict[ticker].reindex(calendar)

        # Per-regime summary statistics
        regime_stats = {}
        for label in regime_labels:
            in_regime = ticker_returns[regime_series == label]
            mu = in_regime.mean()
            sigma = in_regime.std()
            regime_stats[label] = {
                'mean_return': mu,
                'volatility': sigma,
                'sharpe': mu / sigma if sigma > 0 else 0,
            }

        # Sign of the regime's mean return decides the position:
        # +1 long, -1 short, 0 cash (also used when the mean is zero or NaN).
        direction = {}
        for label, stats in regime_stats.items():
            mean_return = stats['mean_return']
            if mean_return > 0:
                direction[label] = 1.0
            elif mean_return < 0:
                direction[label] = -1.0
            else:
                direction[label] = 0.0

        positions = np.array([direction[label] for label in regime_series], dtype=float)

        # Daily and cumulative performance
        strategy_returns = positions * ticker_returns.values
        cum_strategy_returns = (1 + pd.Series(strategy_returns, index=calendar)).cumprod()
        cum_buy_hold_returns = (1 + ticker_returns).cumprod()

        strategy_results[ticker] = {
            'positions': positions,
            'strategy_returns': strategy_returns,
            'cum_strategy_returns': cum_strategy_returns,
            'cum_buy_hold_returns': cum_buy_hold_returns,
            'regime_stats': regime_stats,
        }

    return strategy_results

# Visualize trading strategy results
def visualize_trading_strategy(strategy_results, ticker, detected_regimes, dates):
    """Plot regimes, positions and cumulative performance for one ticker.

    Draws three stacked panels sharing the x-axis: the regime labels, the
    long/cash/short position, and the cumulative return of the regime strategy
    against buy & hold. Every panel is shaded by regime, and a text box on the
    bottom panel reports total return, annualised Sharpe ratio (sqrt(252)
    scaling) and maximum drawdown for both series.

    Args:
        strategy_results: Mapping of ticker -> result dict with the keys
            ``'positions'``, ``'strategy_returns'``, ``'cum_strategy_returns'``
            and ``'cum_buy_hold_returns'``.
        ticker: Key of ``strategy_results`` to plot.
        detected_regimes: Regime label for each entry of ``dates``.
        dates: Index used for the regime, position and return series.

    Returns:
        The matplotlib figure.

    Raises:
        KeyError: If ``ticker`` is not present in ``strategy_results``.
    """
    results = strategy_results[ticker]
    regime_series = pd.Series(detected_regimes, index=dates)
    unique_regimes = np.unique(detected_regimes)

    # Create figure with three subplots
    fig, axs = plt.subplots(3, 1, figsize=(16, 12), sharex=True, gridspec_kw={'height_ratios': [1, 1, 3]})

    # Plot regimes
    axs[0].plot(regime_series.index, regime_series.values, 'k-', linewidth=2)
    axs[0].set_title('Detected Market Regimes', fontsize=15)
    axs[0].set_ylabel('Regime', fontsize=12)
    axs[0].grid(True, alpha=0.3)
    axs[0].set_yticks(unique_regimes)

    # Plot positions
    positions_series = pd.Series(results['positions'], index=dates)
    axs[1].plot(positions_series.index, positions_series.values, 'b-', linewidth=2)
    axs[1].set_title('Trading Positions', fontsize=15)
    axs[1].set_ylabel('Position', fontsize=12)
    axs[1].grid(True, alpha=0.3)
    axs[1].set_yticks([-1, 0, 1])
    axs[1].set_yticklabels(['Short', 'Cash', 'Long'])

    # Plot cumulative returns (the legend is attached once, further below)
    axs[2].plot(results['cum_strategy_returns'].index, results['cum_strategy_returns'].values, 'g-', linewidth=2, label='Regime Strategy')
    axs[2].plot(results['cum_buy_hold_returns'].index, results['cum_buy_hold_returns'].values, 'r--', linewidth=2, label='Buy & Hold')
    axs[2].set_title(f'Cumulative Returns for {ticker}', fontsize=15)
    axs[2].set_xlabel('Date', fontsize=12)
    axs[2].set_ylabel('Cumulative Return', fontsize=12)
    axs[2].grid(True, alpha=0.3)

    # One colour per *distinct* regime label. The lookup is keyed by label, not
    # by position, so labels that are not exactly 0..k-1 neither raise
    # IndexError nor pick the wrong colour.
    colors = plt.cm.viridis(np.linspace(0, 1, len(unique_regimes)))
    color_for_regime = {
        int(label): colour for label, colour in zip(unique_regimes, colors)
    }

    # Restrict regime labels to dates present in the strategy series
    regime_filtered = regime_series.loc[regime_series.index.intersection(results['cum_strategy_returns'].index)]

    # Shade every panel by regime; skip if the indices don't overlap
    if not regime_filtered.empty:
        regime_changes = regime_filtered.diff().abs() > 0
        change_points = regime_filtered.index[regime_changes]

        # Segment boundaries: first date, every change point, last date
        all_points = [regime_filtered.index[0]] + list(change_points) + [regime_filtered.index[-1]]

        for start_date, end_date in zip(all_points[:-1], all_points[1:]):
            regime = int(regime_filtered.loc[start_date])
            for ax in axs:
                ax.axvspan(start_date, end_date, alpha=0.2, color=color_for_regime[regime])

    # Combined legend: the two return curves plus one swatch per regime
    regime_legend = [plt.Line2D([0], [0], color=color_for_regime[int(label)], lw=8, alpha=0.5,
                          label=f'Regime {label}') for label in unique_regimes]
    axs[2].legend(handles=[
        plt.Line2D([0], [0], color='g', lw=2, label='Regime Strategy'),
        plt.Line2D([0], [0], color='r', linestyle='--', lw=2, label='Buy & Hold')
    ] + regime_legend, loc='lower right')

    # Calculate performance metrics
    strategy_total_return = results['cum_strategy_returns'].iloc[-1] - 1
    buy_hold_total_return = results['cum_buy_hold_returns'].iloc[-1] - 1

    strategy_returns_series = pd.Series(results['strategy_returns'], index=dates)
    # Recover per-period buy & hold returns from the cumulative curve
    buy_hold_returns_series = results['cum_buy_hold_returns'].pct_change().fillna(0)

    def annualized_sharpe(returns):
        # A zero or NaN volatility reports 0 instead of inf/NaN, matching the
        # per-regime Sharpe convention in regime_based_trading_strategy.
        sigma = returns.std()
        if not sigma > 0:
            return 0.0
        return returns.mean() / sigma * np.sqrt(252)

    def max_drawdown(returns):
        # Largest peak-to-trough loss of the compounded curve (a value <= 0)
        cum_returns = (1 + returns).cumprod()
        running_max = cum_returns.cummax()
        drawdown = (cum_returns / running_max) - 1
        return drawdown.min()

    strategy_sharpe = annualized_sharpe(strategy_returns_series)
    buy_hold_sharpe = annualized_sharpe(buy_hold_returns_series)

    strategy_max_dd = max_drawdown(strategy_returns_series)
    buy_hold_max_dd = max_drawdown(buy_hold_returns_series)

    # Add performance metrics as text
    metrics_text = (
        f"Strategy Return: {strategy_total_return:.2%}\n"
        f"Buy & Hold Return: {buy_hold_total_return:.2%}\n"
        f"Strategy Sharpe: {strategy_sharpe:.2f}\n"
        f"Buy & Hold Sharpe: {buy_hold_sharpe:.2f}\n"
        f"Strategy Max DD: {strategy_max_dd:.2%}\n"
        f"Buy & Hold Max DD: {buy_hold_max_dd:.2%}\n"
    )

    # Add text box with metrics
    props = dict(boxstyle='round', facecolor='white', alpha=0.7)
    axs[2].text(0.02, 0.98, metrics_text, transform=axs[2].transAxes, fontsize=10,
               verticalalignment='top', bbox=props)

    plt.tight_layout()
    return fig

# Main function to run the market regime detection
def main():
    """Run the full pipeline: download, detect regimes, analyse, backtest, plot.

    Downloads roughly five years of data for a fixed list of example tickers,
    compares KAN and HMM regime detection, summarises the KAN regimes, runs
    the regime-based strategy, and draws one strategy figure per ticker.

    Returns:
        Tuple ``(model, trained_params, kan_regimes, hmm_regimes, regime_stats,
        strategy_results, comparison_fig, characteristics_fig, strategy_fig)``
        where ``strategy_fig`` is the figure of the last ticker plotted
        (``None`` if no ticker had strategy results).
    """
    # Parameters
    tickers = ['SPY', 'QQQ', 'TLT', 'GLD', '^FTSE']  # Example tickers
    end_date = datetime.now()
    start_date = end_date - timedelta(days=5*365)  # 5 years of data
    window_size = 20

    # Download and prepare data
    print(f"Downloading data for {tickers} from {start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')}...")
    X, dates, returns_dict = download_market_data(tickers, start_date, end_date, feature_window=window_size)

    # Detect regimes using KAN and HMM; the first ticker is used for the plot
    print("Comparing regime detection methods...")
    comparison_fig, model, trained_params, kan_regimes, hmm_regimes = compare_regime_methods(
        X, dates, returns_dict, tickers[0])

    # Analyze regime characteristics
    print("Analyzing KAN regime characteristics...")
    regime_stats = analyze_regime_characteristics(kan_regimes, dates, returns_dict)

    # Visualize regime characteristics
    print("Visualizing regime characteristics...")
    characteristics_fig = visualize_regime_characteristics(regime_stats, tickers)

    # Implement trading strategy
    print("Implementing regime-based trading strategy...")
    strategy_results = regime_based_trading_strategy(kan_regimes, dates, returns_dict)

    # One strategy figure per ticker, driven by the list above so the two
    # cannot drift apart; the last figure created is the one returned.
    print("Visualizing trading strategy results...")
    strategy_fig = None
    for ticker in tickers:
        if ticker not in strategy_results:
            print(f"No strategy results for {ticker}; skipping plot.")
            continue
        strategy_fig = visualize_trading_strategy(strategy_results, ticker, kan_regimes, dates)

    # Return results
    return model, trained_params, kan_regimes, hmm_regimes, regime_stats, strategy_results, comparison_fig, characteristics_fig, strategy_fig

# Entry point: run the whole pipeline when this file is executed directly;
# the guard skips the call when the file is imported as a module.
if __name__ == "__main__":
    main()