In [None]:
!pip install optax yfinance lxml plotly hmmlearn "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [None]:
import jax
import jax.numpy as jnp
from jax import grad, vmap, jit, value_and_grad
from tqdm import trange, tqdm 
import numpy as np
import matplotlib.pyplot as plt
import optax
import time
import os
import pickle
from functools import partial
from typing import List, Dict, Tuple, Callable, Any

# Check for GPU availability and configure JAX accordingly
def check_compute_device():
    """Report whether JAX can see a GPU and return the backend name.

    Prints which backend was found and sets an environment variable as a side
    effect: XLA_PYTHON_CLIENT_PREALLOCATE='false' when a GPU is found,
    JAX_PLATFORM_NAME='cpu' otherwise.

    Returns:
        "gpu" if ``jax.devices('gpu')`` succeeds, otherwise "cpu".
    """
    # NOTE(review): JAX reads these variables when its backend is created; since
    # jax.devices() below has already initialised it, setting them here likely
    # has no effect on the current process — confirm, or set them before
    # importing jax.
    try:
        jax.devices('gpu')  # raises RuntimeError when no GPU backend exists
        print("GPU detected. Using GPU for computation.")
        os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'  # don't grab all GPU memory up front
        return "gpu"
    except RuntimeError:
        print("No GPU detected. Using CPU for computation.")
        os.environ['JAX_PLATFORM_NAME'] = 'cpu'
        return "cpu"

# Initialize computing device
# (check_compute_device prints the detected backend and returns "gpu" or "cpu")
device = check_compute_device()

# Set random seed for reproducibility
# NOTE(review): this global `key` is not referenced by the code in this section;
# the model and data helpers build their own PRNGKeys — confirm it is still needed.
key = jax.random.PRNGKey(42)


def interpolate_activation(z, activation_curve, grid_points):
    """Evaluate a piecewise-linear activation at ``z``.

    Args:
        z: Query value(s).
        activation_curve: Curve values at each grid point.
        grid_points: Sorted knot positions, same length as ``activation_curve``.

    Returns:
        The linearly interpolated curve value(s). Queries outside the grid are
        extrapolated along the first/last segment.
    """
    last_segment = len(grid_points) - 2
    # Left knot of the segment holding z, clamped so seg + 1 is always valid.
    seg = jnp.clip(jnp.searchsorted(grid_points, z) - 1, 0, last_segment)

    left_x, right_x = grid_points[seg], grid_points[seg + 1]
    left_y, right_y = activation_curve[seg], activation_curve[seg + 1]

    frac = (z - left_x) / (right_x - left_x)
    return left_y + frac * (right_y - left_y)


class KANLayer:
    """One Kolmogorov-Arnold layer.

    A linear map is followed by one learned, piecewise-linear activation curve
    per output unit. Each curve is stored as its values at fixed grid points.
    """

    def __init__(self, input_dim: int, output_dim: int, num_basis: int = 30,
                 domain=(-3.0, 3.0), key=None):
        """Initialize a KAN layer with learnable activation functions.

        Args:
            input_dim: Input dimension
            output_dim: Output dimension
            num_basis: Number of grid points (knots) per learned activation curve
            domain: (low, high) range of the activation grid; pre-activations
                are clipped to it in the forward pass
            key: JAX random key (defaults to PRNGKey(0))
        """
        if key is None:
            key = jax.random.PRNGKey(0)

        key1, key2, key3 = jax.random.split(key, 3)

        # Linear transformation parameters
        self.weights = jax.random.normal(key1, (input_dim, output_dim)) * 0.1
        self.biases = jax.random.normal(key2, (output_dim,)) * 0.01

        # Knots at which each activation curve is stored; values in between are
        # obtained by linear interpolation.
        self.grid_points = jnp.linspace(domain[0], domain[1], num_basis)

        # Initialize each output unit's curve with one of four shapes suitable for option pricing
        activations_list = []
        for i in range(output_dim):
            subkey = jax.random.fold_in(key3, i)
            # Independent keys for the shape choice and the noise: drawing both
            # from the same key makes the two random draws non-independent.
            type_key, noise_key = jax.random.split(subkey)
            init_type = int(jax.random.randint(type_key, (), 0, 4))

            if init_type == 0:  # Linear-like
                act = self.grid_points
            elif init_type == 1:  # ReLU-like (for positive payoffs)
                act = jnp.maximum(0, self.grid_points)
            elif init_type == 2:  # Sigmoid-like (useful for capturing CDF-like components)
                act = 1.0 / (1.0 + jnp.exp(-2.0 * self.grid_points))
            else:  # Tanh-like
                act = jnp.tanh(self.grid_points)

            # Add noise to break symmetry
            act = act + jax.random.normal(noise_key, (num_basis,)) * 0.05
            activations_list.append(act)

        # Stack into a matrix: (output_dim, num_basis)
        self.activations = jnp.stack(activations_list)

        # Store domain for clipping
        self.domain = domain

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Forward pass through the KAN layer.

        Args:
            x: Input of shape (batch_size, input_dim). A single 1-D sample of
               shape (input_dim,) is also accepted and yields (output_dim,).

        Returns:
            Output of shape (batch_size, output_dim)
        """
        z = jnp.dot(x, self.weights) + self.biases
        z_clipped = jnp.clip(z, self.domain[0], self.domain[1])

        grid = self.grid_points
        # Left knot of the segment holding each value, clamped so idx + 1 is valid.
        idx = jnp.clip(jnp.searchsorted(grid, z_clipped) - 1, 0, grid.shape[0] - 2)
        # Column j uses activation curve j, so pair every column index with its
        # knot index (broadcasts to z's shape). This replaces a per-column loop.
        cols = jnp.arange(z.shape[-1])
        x0, x1 = grid[idx], grid[idx + 1]
        y0, y1 = self.activations[cols, idx], self.activations[cols, idx + 1]

        t = (z_clipped - x0) / (x1 - x0)
        return y0 + t * (y1 - y0)

# Full KAN model for option pricing
class OptionPricingKAN:
    """Stack of KANLayers mapping option parameters to non-negative prices."""

    def __init__(self, input_dim: int, output_dim: int, hidden_dims: "List[int] | None" = None,
                 num_basis: int = 30, domain=(-3.0, 3.0), key=None):
        """Initialize a KAN model for option pricing.

        Args:
            input_dim: Input dimension (option parameters: S, K, T, r, sigma)
            output_dim: Output dimension (option prices: call, put)
            hidden_dims: List of hidden dimensions; None means [64, 32]. The
                list is copied, so later changes by the caller don't affect the model.
            num_basis: Number of grid points per learned activation curve
            domain: Domain for activation functions
            key: JAX random key (defaults to PRNGKey(0))
        """
        if key is None:
            key = jax.random.PRNGKey(0)
        # Copy per instance: a list default would be one object shared by every
        # model through self.hidden_dims.
        hidden_dims = [64, 32] if hidden_dims is None else list(hidden_dims)

        keys = jax.random.split(key, len(hidden_dims) + 1)

        # Hidden layers
        self.layers = []
        prev_dim = input_dim
        for i, hidden_dim in enumerate(hidden_dims):
            self.layers.append(KANLayer(prev_dim, hidden_dim, num_basis, domain, keys[i]))
            prev_dim = hidden_dim

        # Final layer for option prices
        self.output_layer = KANLayer(prev_dim, output_dim, num_basis, domain, keys[-1])

        # Store dimensions
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dims = hidden_dims
        self.num_basis = num_basis
        self.domain = domain

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Forward pass through the KAN model.

        Args:
            x: Array-like of shape (batch_size, input_dim), or a single sample
               of shape (input_dim,) which is treated as a batch of one.

        Returns:
            Array of shape (batch_size, output_dim) of option prices, floored at 0.
        """
        # asarray accepts NumPy arrays, JAX arrays/tracers and plain lists alike.
        x = jnp.asarray(x)
        if x.ndim == 1:
            x = x.reshape(1, -1)

        for layer in self.layers:
            x = layer(x)
        prices = self.output_layer(x)

        # Ensure option prices are non-negative
        return jnp.maximum(prices, 0.0)

    @property
    def params(self):
        """Model parameters as a flat dict.

        Keys are 'layer_{i}_*' for the hidden layers and 'output_layer_*' for
        the output layer, with suffixes weights/biases/activations/grid_points.
        """
        params = {}
        for i, layer in enumerate(self.layers):
            params[f'layer_{i}_weights'] = layer.weights
            params[f'layer_{i}_biases'] = layer.biases
            params[f'layer_{i}_activations'] = layer.activations
            params[f'layer_{i}_grid_points'] = layer.grid_points

        params['output_layer_weights'] = self.output_layer.weights
        params['output_layer_biases'] = self.output_layer.biases
        params['output_layer_activations'] = self.output_layer.activations
        params['output_layer_grid_points'] = self.output_layer.grid_points

        return params

    def update_params(self, params):
        """Load parameters from a flat dict laid out like `params`.

        Grid-point entries are optional; when missing, the layer keeps its grid.
        """
        for i, layer in enumerate(self.layers):
            layer.weights = params[f'layer_{i}_weights']
            layer.biases = params[f'layer_{i}_biases']
            layer.activations = params[f'layer_{i}_activations']
            if f'layer_{i}_grid_points' in params:
                layer.grid_points = params[f'layer_{i}_grid_points']

        self.output_layer.weights = params['output_layer_weights']
        self.output_layer.biases = params['output_layer_biases']
        self.output_layer.activations = params['output_layer_activations']
        if 'output_layer_grid_points' in params:
            self.output_layer.grid_points = params['output_layer_grid_points']

# Black-Scholes option pricing function for generating training data and comparison
def black_scholes(S, K, T, r, sigma, option_type='call'):
    """Price a European option with the Black-Scholes formula.

    Args:
        S: Current stock price
        K: Strike price
        T: Time to maturity in years (floored at 1e-10 so log/sqrt stay finite)
        r: Risk-free interest rate
        sigma: Volatility
        option_type: 'call' for a call; any other value prices a put

    Returns:
        Option price
    """
    norm_cdf = jax.scipy.stats.norm.cdf
    T = jnp.maximum(T, 1e-10)

    d1 = (jnp.log(S / K) + (r + 0.5 * sigma**2) * T) / (sigma * jnp.sqrt(T))
    d2 = d1 - sigma * jnp.sqrt(T)
    discounted_strike = K * jnp.exp(-r * T)

    if option_type != 'call':
        return discounted_strike * norm_cdf(-d2) - S * norm_cdf(-d1)
    return S * norm_cdf(d1) - discounted_strike * norm_cdf(d2)

# Calculate Greeks analytically for comparison
def calculate_greeks(S, K, T, r, sigma, option_type='call'):
    """Closed-form Black-Scholes Greeks.

    Args:
        S, K, T, r, sigma: Spot, strike, maturity (years, floored at 1e-10),
            risk-free rate and volatility.
        option_type: 'call' for a call; any other value is treated as a put.

    Returns:
        Dict with 'delta', 'gamma', 'theta' (per day, i.e. /365), 'vega' and
        'rho' (both per 1% change, i.e. /100).
    """
    cdf = jax.scipy.stats.norm.cdf
    T = jnp.maximum(T, 1e-10)
    sqrt_t = jnp.sqrt(T)

    d1 = (jnp.log(S / K) + (r + 0.5 * sigma**2) * T) / (sigma * sqrt_t)
    d2 = d1 - sigma * sqrt_t
    pdf_d1 = jnp.exp(-0.5 * d1**2) / jnp.sqrt(2 * jnp.pi)  # standard normal density at d1
    discount = jnp.exp(-r * T)
    decay_term = -S * pdf_d1 * sigma / (2 * sqrt_t)  # theta component shared by call and put

    if option_type == 'call':
        delta = cdf(d1)
        theta = decay_term - r * K * discount * cdf(d2)
        rho = K * T * discount * cdf(d2)
    else:
        delta = cdf(d1) - 1
        theta = decay_term + r * K * discount * cdf(-d2)
        rho = -K * T * discount * cdf(-d2)

    # Gamma and vega do not depend on the option type.
    gamma = pdf_d1 / (S * sigma * sqrt_t)
    vega = S * sqrt_t * pdf_d1

    return {
        'delta': delta,
        'gamma': gamma,
        'theta': theta / 365.0,  # Convert to daily theta
        'vega': vega / 100.0,    # Convert to 1% change
        'rho': rho / 100.0       # Convert to 1% change
    }

# Generate training data
def generate_training_data(num_samples=100000, random_key=None):
    """Sample random option parameters and label them with Black-Scholes prices.

    Args:
        num_samples: Number of samples to generate
        random_key: JAX random key (defaults to PRNGKey(123))

    Returns:
        X: (num_samples, 5) array with columns S, K, T, r, sigma
        Y: (num_samples, 2) array with columns call_price, put_price
    """
    if random_key is None:
        random_key = jax.random.PRNGKey(123)

    k_spot, k_strike, k_mat, k_rate, k_vol = jax.random.split(random_key, 5)
    shape = (num_samples,)

    # Sampling ranges for each input
    S = jax.random.uniform(k_spot, shape, minval=50.0, maxval=200.0)    # Stock price
    K = jax.random.uniform(k_strike, shape, minval=50.0, maxval=200.0)  # Strike price
    T = jax.random.uniform(k_mat, shape, minval=0.1, maxval=2.0)        # Time to maturity (years)
    r = jax.random.uniform(k_rate, shape, minval=0.01, maxval=0.08)     # Risk-free rate
    sigma = jax.random.uniform(k_vol, shape, minval=0.1, maxval=0.5)    # Volatility

    X = jnp.column_stack([S, K, T, r, sigma])

    # Label each sample with its call and put price (columns in that order).
    price_columns = [
        vmap(lambda s, k, t, rate, sig, kind=kind: black_scholes(s, k, t, rate, sig, kind))(S, K, T, r, sigma)
        for kind in ('call', 'put')
    ]
    Y = jnp.column_stack(price_columns)

    return X, Y


# Module-level KAN defaults for the functional `forward_pass` used in training;
# keep them consistent with the arguments passed to OptionPricingKAN.
KAN_DOMAIN = (-3.0, 3.0)  # (low, high) range of the activation grid
KAN_NUM_BASIS = 30  # number of grid points per activation curve
KAN_HIDDEN_DIMS = [64, 32]  # Match this to your model initialization


def _layer_grid(params, prefix):
    """Return the interpolation grid of layer `prefix`, with gradients blocked.

    Grid points are stored in the params dict, so they are handed to the
    optimizer. They are fixed knots, though: letting gradients move them could
    unsort the grid that `searchsorted` relies on. When the params carry no
    grid, fall back to the module-level KAN_DOMAIN / KAN_NUM_BASIS defaults.
    """
    grid = params.get(f'{prefix}_grid_points')
    if grid is None:
        grid = jnp.linspace(KAN_DOMAIN[0], KAN_DOMAIN[1], KAN_NUM_BASIS)
    return jax.lax.stop_gradient(grid)


def _kan_layer_forward(x, weights, biases, activations, grid_points):
    """Apply one KAN layer: affine map, clip to the grid range, per-column linear interpolation."""
    z = jnp.dot(x, weights) + biases
    z_clipped = jnp.clip(z, grid_points[0], grid_points[-1])
    # Left knot of each value's segment, clamped so idx + 1 is valid.
    idx = jnp.clip(jnp.searchsorted(grid_points, z_clipped) - 1, 0, grid_points.shape[0] - 2)
    # Column j uses activation curve j.
    cols = jnp.arange(z.shape[-1])
    x0, x1 = grid_points[idx], grid_points[idx + 1]
    y0, y1 = activations[cols, idx], activations[cols, idx + 1]
    t = (z_clipped - x0) / (x1 - x0)
    return y0 + t * (y1 - y0)


def forward_pass(params, X):
    """Forward pass of the KAN using only a flat params dict (jit/grad friendly).

    Args:
        params: Dict laid out like `OptionPricingKAN.params`, i.e.
            'layer_{i}_{weights,biases,activations[,grid_points]}' for the hidden
            layers and the same with the 'output_layer' prefix. The number of
            hidden layers is read from the keys, so any depth is supported.
        X: Array-like of shape (batch, input_dim), or (input_dim,) for one sample.

    Returns:
        Array of shape (batch, output_dim), floored at 0.
    """
    x = jnp.asarray(X)
    if x.ndim == 1:
        x = x.reshape(1, -1)

    num_hidden = 0
    while f'layer_{num_hidden}_weights' in params:
        num_hidden += 1
    prefixes = [f'layer_{i}' for i in range(num_hidden)] + ['output_layer']

    for prefix in prefixes:
        x = _kan_layer_forward(
            x,
            params[f'{prefix}_weights'],
            params[f'{prefix}_biases'],
            params[f'{prefix}_activations'],
            _layer_grid(params, prefix),
        )

    return jnp.maximum(x, 0.0)

def _second_difference_penalty(act, h):
    """Mean squared finite-difference second derivative along each row of `act`."""
    second_deriv = (act[:, 2:] - 2 * act[:, 1:-1] + act[:, :-2]) / (h ** 2)
    return jnp.mean(second_deriv ** 2)


@jit
def loss_fn(params, X, Y, lambda_smooth=0.001):
    """Loss function: MSE + smoothness penalty on activation second derivatives.

    Args:
        params: Flat KAN params dict (see `OptionPricingKAN.params`).
        X: Input batch passed to `forward_pass`.
        Y: Target prices, same shape as the model output.
        lambda_smooth: Weight of the smoothness penalty.

    Returns:
        Scalar loss.
    """
    pred = forward_pass(params, X)
    mse_loss = jnp.mean((pred - Y) ** 2)

    # Grid spacing (assumes uniform spacing shared by all layers). The grid is a
    # fixed set of knots, so block its gradient: otherwise the 1/h**2 factor
    # gives the first two output-layer grid points a nonzero gradient and the
    # optimizer moves them.
    grid = jax.lax.stop_gradient(params['output_layer_grid_points'])
    h = grid[1] - grid[0]

    smooth_reg = 0.0
    # Hidden layer smoothness
    for k in params:
        if k.endswith('_activations') and not k.startswith('output'):
            smooth_reg += _second_difference_penalty(params[k], h)
    # Output layer smoothness
    smooth_reg += _second_difference_penalty(params['output_layer_activations'], h)

    return mse_loss + lambda_smooth * smooth_reg


@partial(jit, static_argnames=('optimizer_update',))
def _train_step_impl(params, X, Y, opt_state, lambda_smooth, optimizer_update):
    """Jitted body of `train_step`.

    `optimizer_update` is a static argument, so each distinct optimizer gets
    its own compiled trace instead of silently reusing the first one.
    """
    loss_value, grads = value_and_grad(lambda p: loss_fn(p, X, Y, lambda_smooth))(params)
    updates, new_opt_state = optimizer_update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss_value


def train_step(params, X, Y, opt_state, lambda_smooth=0.001):
    """Single jitted training step using the optimizer stored in global `update_fn`.

    `train_model` assigns the global `update_fn`. The function is looked up at
    call time and passed as a static jit argument. Reading the global inside a
    jitted body would freeze whichever optimizer existed when it was first
    traced, and a later `train_model` call would reuse it.

    Returns:
        (new_params, new_opt_state, loss_value)
    """
    return _train_step_impl(params, X, Y, opt_state, lambda_smooth, optimizer_update=update_fn)


def train_model(model, X_train=None, Y_train=None, num_epochs=100, batch_size=64, lambda_smooth=0.001, num_samples=10000):
    """Train `model` with clipped Adam on shuffled mini-batches.

    Args:
        model: Object exposing a `params` dict and `update_params(params)`.
        X_train, Y_train: Training inputs/targets. If either is None, synthetic
            Black-Scholes data with `num_samples` rows is generated.
        num_epochs: Number of passes over the training data.
        batch_size: Mini-batch size; a trailing partial batch is dropped.
        lambda_smooth: Weight of the activation smoothness penalty in `loss_fn`.
        num_samples: Size of the generated data set (ignored when data is given).

    Returns:
        (model, params, losses), where `losses` holds the mean batch loss of
        each epoch as floats and `model` has been updated in place with `params`.

    Raises:
        ValueError: If there are fewer training samples than `batch_size`.
    """
    global update_fn  # still published so the module-level `train_step` stays usable

    if X_train is None or Y_train is None:
        print("Generating training data...")
        X_train, Y_train = generate_training_data(num_samples)

    params = model.params
    num_samples = X_train.shape[0]
    num_batches = num_samples // batch_size
    if num_batches == 0:
        # Otherwise the per-epoch mean divides by zero and the schedule has no steps.
        raise ValueError(
            f"batch_size ({batch_size}) is larger than the number of training samples ({num_samples})"
        )

    total_steps = num_epochs * num_batches
    peak_lr = 5e-5
    schedule_fn = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=peak_lr,
        warmup_steps=total_steps // 10,
        decay_steps=total_steps,
        # end_value must be below peak_value for the cosine phase to decay; a value
        # of 1e-4 (twice the peak) would make it ramp the learning rate up instead.
        end_value=peak_lr * 0.01,
    )

    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(learning_rate=schedule_fn),
    )
    opt_state = optimizer.init(params)
    update_fn = optimizer.update

    # A step jitted per call and bound to this call's optimizer. Nothing is read
    # from a global, so a trace cached by an earlier call can never apply that
    # call's optimizer or schedule here.
    @jit
    def step(params, opt_state, X_batch, Y_batch):
        loss_value, grads = value_and_grad(loss_fn)(params, X_batch, Y_batch, lambda_smooth)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state, loss_value

    losses = []

    print("Training model...")
    for epoch in trange(num_epochs, desc="Epochs"):
        perm = jax.random.permutation(jax.random.PRNGKey(epoch), num_samples)
        X_shuffled = X_train[perm]
        Y_shuffled = Y_train[perm]

        epoch_loss = 0.0
        for batch in range(num_batches):
            start_idx = batch * batch_size
            end_idx = start_idx + batch_size
            params, opt_state, batch_loss = step(
                params, opt_state, X_shuffled[start_idx:end_idx], Y_shuffled[start_idx:end_idx]
            )
            epoch_loss += batch_loss

        epoch_loss = float(epoch_loss) / num_batches
        losses.append(epoch_loss)

        tqdm.write(f"Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss:.6f}")

    model.update_params(params)
    return model, params, losses


# Calculate Greeks using automatic differentiation with the KAN model
def calculate_greeks_from_kan(model, S, K, T, r, sigma, option_idx=0):
    """
    Calculate option Greeks using automatic differentiation on the KAN model.

    Args:
        model: Callable mapping a 1-D input [S, K, T, r, sigma] to prices of
            shape (1, n_outputs), e.g. a trained OptionPricingKAN
        S, K, T, r, sigma: Option parameters (ints are accepted)
        option_idx: Index of the option (0 for call, 1 for put)

    Returns:
        Dictionary with 'delta', 'gamma', 'theta' (daily, -dV/dT / 365),
        'vega' and 'rho' (per 1% change, i.e. /100)
    """
    # jax.grad only differentiates floating-point inputs, so cast up front; this
    # lets callers pass plain ints such as S=100, K=100.
    x0 = jnp.asarray([S, K, T, r, sigma], dtype=float)

    def price(x):
        return model(x)[0, option_idx]

    grad_price = jax.grad(price)
    # One gradient gives every first-order Greek: [dV/dS, dV/dK, dV/dT, dV/dr, dV/dsigma]
    grads = grad_price(x0)
    # Gamma = d2V/dS2: the S-component of the gradient of delta
    gamma = jax.grad(lambda x: grad_price(x)[0])(x0)[0]

    return {
        'delta': grads[0],
        'gamma': gamma,
        'theta': -grads[2] / 365.0,   # Theta = -dV/dT (daily)
        'vega': grads[4] / 100.0,     # Vega per 1% change in sigma
        'rho': grads[3] / 100.0,      # Rho per 1% change in r
    }




def batched_greeks(model, X_batch, option_idx=0):
    """
    Compute Greeks in batch mode using autodiff and vmap.

    Args:
        model: Trained KAN model (any callable mapping (1, 5) inputs to (1, n_outputs) prices)
        X_batch: Array-like of shape (batch_size, 5) with columns [S, K, T, r, sigma]
        option_idx: 0 for call, 1 for put

    Returns:
        Dictionary of per-row Greeks, each of shape (batch_size,): 'delta',
        'gamma', 'theta' (daily), 'vega' and 'rho' (per 1% change)
    """
    # jax.grad needs floating-point inputs; this also accepts lists or int arrays.
    X_batch = jnp.asarray(X_batch, dtype=float)

    def price_fn(x):
        return model(x.reshape(1, -1))[0, option_idx]

    grad_fn = jax.grad(price_fn)
    # One gradient per row yields all first-order Greeks at once; the columns
    # are [dV/dS, dV/dK, dV/dT, dV/dr, dV/dsigma].
    grads = vmap(grad_fn)(X_batch)
    # d2V/dS2 is the S-component of the gradient of delta. Without [:, 0], the
    # whole 5-element Hessian row per option would be returned.
    gamma = vmap(jax.grad(lambda x: grad_fn(x)[0]))(X_batch)[:, 0]

    return {
        'delta': grads[:, 0],
        'gamma': gamma,
        'theta': -grads[:, 2] / 365.0,
        'vega': grads[:, 4] / 100.0,
        'rho': grads[:, 3] / 100.0,
    }


# Evaluate model performance
def evaluate_model(model, num_samples=1000):
    """Score the model on freshly generated Black-Scholes data and print a summary.

    The test set always uses PRNGKey(999), so repeated calls with the same
    `num_samples` evaluate on identical options.

    Args:
        model: Trained OptionPricingKAN model (any callable returning (n, 2) prices)
        num_samples: Number of test samples

    Returns:
        (metrics, X_test, Y_test, predictions). `metrics` holds 'mean_abs_error',
        'max_abs_error', 'mean_rel_error', 'r_squared_call' and 'r_squared_put'.
    """
    X_test, Y_test = generate_training_data(num_samples, jax.random.PRNGKey(999))
    predictions = model(X_test)

    abs_error = jnp.abs(predictions - Y_test)
    # Floor the denominator so zero-priced targets do not divide by zero.
    rel_error = abs_error / jnp.maximum(Y_test, 1e-10)

    # Per-column coefficient of determination (column 0: call, column 1: put).
    ss_total = jnp.sum((Y_test - jnp.mean(Y_test, axis=0)) ** 2, axis=0)
    ss_residual = jnp.sum((Y_test - predictions) ** 2, axis=0)
    r_squared = 1 - ss_residual / ss_total

    metrics = dict(
        mean_abs_error=jnp.mean(abs_error),
        max_abs_error=jnp.max(abs_error),
        mean_rel_error=jnp.mean(rel_error),
        r_squared_call=r_squared[0],
        r_squared_put=r_squared[1],
    )

    print("\nModel Evaluation:")
    print(f"Mean Absolute Error: {metrics['mean_abs_error']:.6f}")
    print(f"Max Absolute Error: {metrics['max_abs_error']:.6f}")
    print(f"Mean Relative Error: {metrics['mean_rel_error']:.2%}")
    print(f"R-squared (Call): {r_squared[0]:.6f}")
    print(f"R-squared (Put): {r_squared[1]:.6f}")

    return metrics, X_test, Y_test, predictions

# For specific option pricing
def price_specific_option(model, S, K, T, r, sigma):
    """Price one option with the KAN and compare it against Black-Scholes.

    Args:
        model: Trained OptionPricingKAN model
        S, K, T, r, sigma: Option parameters

    Returns:
        Dict with 'option_params' and, for 'call_option' and 'put_option', the
        KAN price, the Black-Scholes price, their difference, the KAN Greeks
        and the analytic Greeks. A summary is also printed.
    """
    kan_prices = model(jnp.array([[S, K, T, r, sigma]]))

    results = {
        'option_params': {
            'S': S,
            'K': K,
            'T': T,
            'r': r,
            'sigma': sigma
        }
    }
    # Output column 0 is the call price, column 1 the put price.
    for column, kind, result_key in ((0, 'call', 'call_option'), (1, 'put', 'put_option')):
        kan_price = kan_prices[0, column].item()
        bs_price = black_scholes(S, K, T, r, sigma, kind).item()
        results[result_key] = {
            'price': kan_price,
            'bs_price': bs_price,
            'price_diff': kan_price - bs_price,
            'greeks': calculate_greeks_from_kan(model, S, K, T, r, sigma, column),
            'bs_greeks': calculate_greeks(S, K, T, r, sigma, kind)
        }

    print("\nOption Pricing Results:")
    print(f"Stock Price (S): {S:.2f}")
    print(f"Strike Price (K): {K:.2f}")
    print(f"Time to Maturity (T): {T:.2f} years")
    print(f"Risk-free Rate (r): {r:.2%}")
    print(f"Volatility (σ): {sigma:.2%}")
    for heading, result_key in (("\nCall Option:", 'call_option'), ("\nPut Option:", 'put_option')):
        leg = results[result_key]
        print(heading)
        print(f"KAN Price: {leg['price']:.6f}")
        print(f"Black-Scholes Price: {leg['bs_price']:.6f}")
        print(f"Price Difference: {leg['price'] - leg['bs_price']:.6f}")

    return results

_TRAINING_MODES = ("train_and_evaluate", "train_only")
_PRETRAINED_MODES = ("evaluate_only", "price_option", "batch_pricing", "visualize_greeks",
                     "scan_volatility", "extreme_scenarios", "full_analysis")


def _new_option_pricing_kan(hidden_dims, num_basis):
    """Build an untrained OptionPricingKAN with 5 inputs (S, K, T, r, sigma) and 2 outputs (call, put)."""
    return OptionPricingKAN(
        input_dim=5,  # S, K, T, r, sigma
        output_dim=2,  # call and put prices
        hidden_dims=hidden_dims,
        num_basis=num_basis
    )


def run_kan_option_pricing(mode="train_and_evaluate",
                           model_file=None,
                           num_epochs=100,
                           batch_size=64,
                           hidden_dims=None,
                           num_basis=30,
                           option_params=None):
    """
    Main function to run KAN option pricing with various modes.

    Args:
        mode: Operation mode
            - "train_and_evaluate": Train a new model and evaluate it
            - "train_only": Train a new model and save it
            - "evaluate_only": Evaluate a pre-trained model
            - "price_option": Price a specific option with a pre-trained model
            - "batch_pricing": Run batch pricing benchmark
            - "visualize_greeks": Visualize option Greeks
            - "scan_volatility": Scan across volatilities
            - "extreme_scenarios": Test model on extreme scenarios
            - "full_analysis": Run full analysis pipeline
            For the pre-trained modes, a quick default model (30 epochs, 5000
            samples) is trained when no model file is given or loading fails.
        model_file: Path to save/load model
        num_epochs: Number of training epochs (training modes only)
        batch_size: Batch size for training
        hidden_dims: List of hidden layer dimensions (None means [64, 32])
        num_basis: Number of basis functions
        option_params: Dictionary of option parameters for pricing
            {'S': 100, 'K': 100, 'T': 0.5, 'r': 0.05, 'sigma': 0.2}

    Returns:
        (model, results): the model and a dict of mode-dependent results
        (losses, metrics, figures, ...).

    Raises:
        ValueError: If `mode` is not one of the modes listed above.
    """
    valid_modes = _TRAINING_MODES + _PRETRAINED_MODES
    if mode not in valid_modes:
        # Otherwise no model would be created and a later step would fail with
        # an unrelated UnboundLocalError.
        raise ValueError(f"Unknown mode {mode!r}; expected one of {valid_modes}")

    results = {}
    # Fresh list per call (a list default would be shared between calls/models).
    hidden_dims = [64, 32] if hidden_dims is None else list(hidden_dims)

    # Set default option parameters if not provided
    if option_params is None:
        option_params = {'S': 100, 'K': 100, 'T': 0.5, 'r': 0.05, 'sigma': 0.2}

    # Called for its printed device report; the returned name is not needed here.
    check_compute_device()

    if mode in _TRAINING_MODES:
        print("Initializing new KAN model...")
        model = _new_option_pricing_kan(hidden_dims, num_basis)

        print(f"Training model with {num_epochs} epochs and batch size {batch_size}...")
        model, _, losses = train_model(
            model=model,
            num_epochs=num_epochs,
            batch_size=batch_size,
            num_samples=10000
        )
        results['losses'] = losses

        if model_file:
            save_model(model, model_file)
    else:
        model = None
        if model_file:
            print(f"Loading model from {model_file}...")
            model = load_model(model_file)
            if model is None:
                print("Failed to load model. Training a default model instead...")
        else:
            print("No model file provided. Training a default model...")

        if model is None:
            model = _new_option_pricing_kan(hidden_dims, num_basis)
            model, _, losses = train_model(
                model=model,
                num_epochs=30,  # Quick training for default model
                batch_size=batch_size,
                num_samples=5000
            )
            results['losses'] = losses

    # Perform operations based on mode
    if mode in ["train_and_evaluate", "evaluate_only", "full_analysis"]:
        print("Evaluating model performance...")
        metrics, X_test, Y_test, predictions = evaluate_model(model)
        results['metrics'] = metrics
        results['test_data'] = (X_test, Y_test, predictions)

        print("Visualizing option pricing results...")
        results['pricing_fig'] = visualize_option_pricing(model, X_test, Y_test, predictions)

    if mode in ["price_option", "full_analysis"]:
        print("Pricing specific option...")
        results['option_results'] = price_specific_option(
            model,
            S=option_params['S'],
            K=option_params['K'],
            T=option_params['T'],
            r=option_params['r'],
            sigma=option_params['sigma']
        )

    if mode in ["batch_pricing", "full_analysis"]:
        print("Running batch pricing benchmark...")
        results['batch_results'] = batch_option_pricing(model, num_options=10000)

    if mode in ["visualize_greeks", "full_analysis"]:
        print("Visualizing option Greeks...")
        results['greeks_fig'] = visualize_greeks(model)

        print("Visualizing learned activation functions...")
        activation_fig, output_activation_fig = visualize_activations(model)
        results['activation_figs'] = (activation_fig, output_activation_fig)

    if mode in ["scan_volatility", "full_analysis"]:
        print("Scanning across different volatilities...")
        results['vol_scan_fig'] = scan_implied_volatility(
            model,
            S=option_params['S'],
            K=option_params['K'],
            T=option_params['T'],
            r=option_params['r']
        )

        print("Generating pricing grid...")
        results['grid_fig'] = pricing_moneyness_maturity_grid(
            model,
            S=option_params['S'],
            r=option_params['r'],
            sigma=option_params['sigma']
        )

    if mode in ["extreme_scenarios", "full_analysis"]:
        print("Analyzing extreme scenarios...")
        results['extreme_results'] = analyze_extreme_scenarios(model)

    # Call price vs. spot sweep, shown for every mode.
    param_range = jnp.linspace(50, 150, 100)  # Sweep from 50 to 150
    results['param_sweep_fig'] = plot_price_vs_param(model, 'S', param_range, option_type='call')
    plt.show()

    return model, results


# Save and load model functions
def save_model(model, filename):
    """
    Persist a model's parameter dictionary to disk with pickle.

    Only ``model.params`` is written, not the model object itself.

    Args:
        model: OptionPricingKAN model (anything exposing a ``params`` attribute)
        filename: Destination path for the pickle file

    Returns:
        True on success, False if opening/writing failed (the error is printed).
    """
    try:
        with open(filename, 'wb') as handle:
            pickle.dump(model.params, handle)
    except Exception as err:
        print(f"Error saving model: {err}")
        return False
    print(f"Model saved successfully to {filename}")
    return True

def load_model(filename):
    """
    Load a KAN model from a parameter pickle written by ``save_model``.

    Only the parameter dictionary is stored on disk, so the architecture is
    rebuilt from the parameter shapes before the weights are restored:

    * ``layer_{i}_weights`` for i = 0, 1, ... are the hidden layers. A weight
      matrix is read as ``(in_features, out_features)``: layer 0's first dim
      is the input width, and EVERY hidden layer's second dim is one entry of
      ``hidden_dims`` (including the last one, which feeds the output layer).
    * ``output_layer_weights.shape[1]`` is the output width.
    * ``layer_0_grid_points`` gives ``num_basis`` (its length) and the
      activation domain (its first and last values).

    Security note: ``pickle.load`` can execute arbitrary code embedded in the
    file. Only load model files that come from a trusted source.

    Args:
        filename: Path to the saved model

    Returns:
        Loaded OptionPricingKAN model, or None if the file does not exist or
        loading failed (the error is printed rather than raised).
    """
    try:
        if not os.path.exists(filename):
            print(f"Model file {filename} not found")
            return None

        with open(filename, 'rb') as f:
            params = pickle.load(f)

        input_dim = params['layer_0_weights'].shape[0]
        output_dim = params['output_layer_weights'].shape[1]

        # Count the consecutive hidden layers layer_0, layer_1, ...
        num_hidden_layers = 0
        while f'layer_{num_hidden_layers}_weights' in params:
            num_hidden_layers += 1

        # One hidden width per hidden layer. (The previous loop read
        # layer_{i-1} for i >= 1 and therefore dropped the last layer's width.)
        hidden_dims = [
            params[f'layer_{i}_weights'].shape[1] for i in range(num_hidden_layers)
        ]

        # Activation grid: its length is num_basis, its endpoints the domain.
        grid_points = params['layer_0_grid_points']
        num_basis = grid_points.shape[0]
        domain = (grid_points[0].item(), grid_points[-1].item())

        model = OptionPricingKAN(
            input_dim=input_dim,
            output_dim=output_dim,
            hidden_dims=hidden_dims,
            num_basis=num_basis,
            domain=domain
        )

        # Overwrite the freshly initialised parameters with the saved ones.
        model.update_params(params)

        print(f"Model loaded successfully from {filename}")
        return model
    except Exception as e:
        # Best-effort loader: report the problem and return None instead of raising.
        print(f"Error loading model: {e}")
        return None


# Visualize learned activation functions
def visualize_activations(model):
    """
    Visualize the learned activation functions in the model.

    The first figure has one row per hidden layer (``model.layers``) and shows
    the activation curves of up to four units per layer; panels without a unit
    are hidden. The second figure shows every unit of ``model.output_layer``;
    units 0 and 1 are titled Call and Put (the model's output column order),
    any further units are titled by index.

    Args:
        model: Trained OptionPricingKAN model

    Returns:
        Tuple ``(fig, fig_output)``: the hidden-layer figure and the
        output-layer figure.
    """
    max_cols = 4
    num_hidden_layers = len(model.layers)

    # squeeze=False keeps axs 2-D even when there is a single hidden layer.
    fig, axs = plt.subplots(num_hidden_layers, max_cols,
                            figsize=(15, 3 * num_hidden_layers), squeeze=False)

    for layer_idx, layer in enumerate(model.layers):
        units_shown = min(layer.activations.shape[0], max_cols)
        for col in range(max_cols):
            ax = axs[layer_idx, col]
            if col >= units_shown:
                # Layer narrower than the grid: hide the empty panel.
                ax.axis('off')
                continue
            ax.plot(layer.grid_points, layer.activations[col])
            ax.set_title(f"Layer {layer_idx+1}, Unit {col+1}")
            ax.grid(True)
            ax.axhline(y=0, color='gray', linestyle='--')
            ax.axvline(x=0, color='gray', linestyle='--')

    fig.tight_layout()

    # Output layer: one panel per output unit.
    output_layer = model.output_layer
    output_dim = output_layer.activations.shape[0]
    fig_output, axs_output = plt.subplots(1, output_dim, figsize=(12, 4), squeeze=False)

    option_labels = ("Call", "Put")
    for col in range(output_dim):
        ax = axs_output[0, col]
        ax.plot(output_layer.grid_points, output_layer.activations[col])
        if col < len(option_labels):
            ax.set_title(f"Output Layer, {option_labels[col]} Option")
        else:
            # Avoid labelling extra outputs as "Put".
            ax.set_title(f"Output Layer, Unit {col+1}")
        ax.grid(True)
        ax.axhline(y=0, color='gray', linestyle='--')
        ax.axvline(x=0, color='gray', linestyle='--')

    fig_output.tight_layout()

    # Return both figures, as documented (callers unpack two values).
    return fig, fig_output

# Analyze extreme scenarios
def analyze_extreme_scenarios(model, scenarios=None):
    """
    Analyze model performance in extreme scenarios.

    For every scenario the KAN call/put prices (model output columns 0 and 1)
    are compared with Black-Scholes, and a table of absolute and relative
    errors is printed. Relative errors divide by ``max(BS price, 1e-10)``, so
    near-worthless options (e.g. deep OTM) can show very large percentages.

    Args:
        model: Trained OptionPricingKAN model, called on a (1, 5) array of
            [S, K, T, r, sigma]
        scenarios: Optional list of dicts with keys 'name', 'S', 'K', 'T',
            'r', 'sigma'. Defaults to the built-in set of ten stress cases.

    Returns:
        List of per-scenario result dicts with keys 'name', 'params',
        'call' and 'put' (the latter two holding 'kan', 'bs', 'abs_error',
        'rel_error' as Python floats).
    """
    if scenarios is None:
        scenarios = [
            {'name': 'Deep ITM Call', 'S': 150, 'K': 100, 'T': 0.5, 'r': 0.05, 'sigma': 0.2},
            {'name': 'Deep OTM Call', 'S': 50, 'K': 100, 'T': 0.5, 'r': 0.05, 'sigma': 0.2},
            {'name': 'Deep ITM Put', 'S': 50, 'K': 100, 'T': 0.5, 'r': 0.05, 'sigma': 0.2},
            {'name': 'Deep OTM Put', 'S': 150, 'K': 100, 'T': 0.5, 'r': 0.05, 'sigma': 0.2},
            {'name': 'Near Expiry', 'S': 100, 'K': 100, 'T': 0.01, 'r': 0.05, 'sigma': 0.2},
            {'name': 'Long Maturity', 'S': 100, 'K': 100, 'T': 5.0, 'r': 0.05, 'sigma': 0.2},
            {'name': 'High Volatility', 'S': 100, 'K': 100, 'T': 0.5, 'r': 0.05, 'sigma': 0.8},
            {'name': 'Low Volatility', 'S': 100, 'K': 100, 'T': 0.5, 'r': 0.05, 'sigma': 0.05},
            {'name': 'High Interest Rate', 'S': 100, 'K': 100, 'T': 0.5, 'r': 0.15, 'sigma': 0.2},
            {'name': 'Negative Interest Rate', 'S': 100, 'K': 100, 'T': 0.5, 'r': -0.01, 'sigma': 0.2},
        ]

    results = []

    for scenario in scenarios:
        S, K, T, r, sigma = scenario['S'], scenario['K'], scenario['T'], scenario['r'], scenario['sigma']

        # KAN prices (column 0 = call, column 1 = put)
        prices = model(jnp.array([[S, K, T, r, sigma]]))
        kan_call = prices[0, 0].item()
        kan_put = prices[0, 1].item()

        # Black-Scholes reference prices
        bs_call = black_scholes(S, K, T, r, sigma, 'call').item()
        bs_put = black_scholes(S, K, T, r, sigma, 'put').item()

        # All four prices are host floats already; keep the error math there
        # instead of dispatching tiny jnp operations per scenario.
        call_error = abs(kan_call - bs_call)
        call_rel_error = call_error / max(bs_call, 1e-10)
        put_error = abs(kan_put - bs_put)
        put_rel_error = put_error / max(bs_put, 1e-10)

        results.append({
            'name': scenario['name'],
            'params': {
                'S': S,
                'K': K,
                'T': T,
                'r': r,
                'sigma': sigma
            },
            'call': {
                'kan': kan_call,
                'bs': bs_call,
                'abs_error': call_error,
                'rel_error': call_rel_error
            },
            'put': {
                'kan': kan_put,
                'bs': bs_put,
                'abs_error': put_error,
                'rel_error': put_rel_error
            }
        })

    # Size the name column to the longest scenario name (e.g. "Negative
    # Interest Rate" is 22 chars) so the table stays aligned.
    name_width = max([20] + [len(res['name']) for res in results])
    header = (f"{'Scenario':<{name_width}} {'Call Error':<12} {'Call Rel.Err':<12} "
              f"{'Put Error':<12} {'Put Rel.Err':<12}")

    print("\nExtreme Scenario Analysis:")
    print(header)
    print("-" * len(header))

    for res in results:
        print(f"{res['name']:<{name_width}} "
              f"{res['call']['abs_error']:<12.6f} {res['call']['rel_error']:<12.2%} "
              f"{res['put']['abs_error']:<12.6f} {res['put']['rel_error']:<12.2%}")

    return results


# Create pricing grid across moneyness and maturity
def pricing_moneyness_maturity_grid(model, S=100, r=0.05, sigma=0.2,
                                    moneyness_range=(0.7, 1.3),
                                    maturity_range=(0.1, 2.0),
                                    num_points=20):
    """
    Create a grid visualization of option prices across moneyness and maturity.

    KAN and Black-Scholes prices are evaluated on a ``num_points`` x
    ``num_points`` grid of moneyness (K/S, so K = S * moneyness) and time to
    maturity, and their absolute differences are plotted alongside.

    Args:
        model: Trained OptionPricingKAN model, called on an (N, 5) array of
            [S, K, T, r, sigma]; output column 0 is the call, 1 the put
        S: Fixed stock price
        r: Fixed risk-free rate
        sigma: Fixed volatility
        moneyness_range: (min, max) K/S ratio of the grid (default 0.7 to 1.3)
        maturity_range: (min, max) maturity in years (default 0.1 to 2.0)
        num_points: Grid resolution along each axis (default 20)

    Returns:
        Figure object with a 2x3 layout: call row on top, put row below;
        columns are KAN price, Black-Scholes price, absolute difference.
    """
    moneyness_values = jnp.linspace(moneyness_range[0], moneyness_range[1], num_points)  # K/S ratio
    maturity_values = jnp.linspace(maturity_range[0], maturity_range[1], num_points)     # Years

    moneyness_grid, maturity_grid = jnp.meshgrid(moneyness_values, maturity_values)
    grid_shape = moneyness_grid.shape

    # Price every grid point in ONE batched call instead of one model call
    # (and a host sync via .item()) per point.
    strikes = S * moneyness_grid.ravel()
    maturities = maturity_grid.ravel()
    ones = jnp.ones_like(strikes)
    inputs = jnp.stack([S * ones, strikes, maturities, r * ones, sigma * ones], axis=1)
    prices = jnp.asarray(model(inputs))
    call_prices = prices[:, 0].reshape(grid_shape)
    put_prices = prices[:, 1].reshape(grid_shape)

    # Black-Scholes, vectorised over (K, T) with vmap as in batch_option_pricing.
    bs_call_prices = vmap(lambda k, t: black_scholes(S, k, t, r, sigma, 'call'))(
        strikes, maturities).reshape(grid_shape)
    bs_put_prices = vmap(lambda k, t: black_scholes(S, k, t, r, sigma, 'put'))(
        strikes, maturities).reshape(grid_shape)

    call_diff = jnp.abs(call_prices - bs_call_prices)
    put_diff = jnp.abs(put_prices - bs_put_prices)

    fig, axs = plt.subplots(2, 3, figsize=(18, 10))
    panels = (
        (axs[0, 0], call_prices, 'Call Price (KAN)', 'viridis'),
        (axs[0, 1], bs_call_prices, 'Call Price (Black-Scholes)', 'viridis'),
        (axs[0, 2], call_diff, 'Call Price Difference', 'hot'),
        (axs[1, 0], put_prices, 'Put Price (KAN)', 'viridis'),
        (axs[1, 1], bs_put_prices, 'Put Price (Black-Scholes)', 'viridis'),
        (axs[1, 2], put_diff, 'Put Price Difference', 'hot'),
    )
    for ax, values, title, cmap in panels:
        contour = ax.contourf(moneyness_grid, maturity_grid, values, 20, cmap=cmap)
        ax.set_title(title)
        ax.set_xlabel('Moneyness (K/S)')
        ax.set_ylabel('Maturity (years)')
        plt.colorbar(contour, ax=ax)

    plt.tight_layout()
    return fig


# Visualize option Greeks
def visualize_greeks(model, S=100, K=100, T=0.5, r=0.05, sigma=0.2):
    """
    Plot delta, gamma and theta of the KAN model against Black-Scholes.

    The Greeks are evaluated at 100 evenly spaced stock prices between 0.5*K
    and 1.5*K, and the strike is marked with a dashed vertical line on every
    panel.

    Args:
        model: Trained OptionPricingKAN model
        S, K, T, r, sigma: Option parameters. ``S`` itself is not read; the
            stock price is swept over the range described above.

    Returns:
        Figure object with a 3x2 layout (rows: delta, gamma, theta;
        columns: call, put)
    """
    stock_prices = jnp.linspace(0.5 * K, 1.5 * K, 100)
    greek_names = ('delta', 'gamma', 'theta')

    # curves[(source, option)][greek] -> values along stock_prices
    curves = {
        (source, option): {greek: [] for greek in greek_names}
        for source in ('kan', 'bs')
        for option in ('call', 'put')
    }

    for price in stock_prices:
        # Evaluated in the order: KAN call, KAN put, BS call, BS put.
        sampled = {
            ('kan', 'call'): calculate_greeks_from_kan(model, price, K, T, r, sigma, 0),
            ('kan', 'put'): calculate_greeks_from_kan(model, price, K, T, r, sigma, 1),
            ('bs', 'call'): calculate_greeks(price, K, T, r, sigma, 'call'),
            ('bs', 'put'): calculate_greeks(price, K, T, r, sigma, 'put'),
        }
        for curve_key, greeks in sampled.items():
            for greek in greek_names:
                curves[curve_key][greek].append(greeks[greek])

    curves = {
        curve_key: {greek: jnp.array(values) for greek, values in per_greek.items()}
        for curve_key, per_greek in curves.items()
    }

    # (greek key, y-axis label, (call title, put title)) for each row
    panel_layout = (
        ('delta', 'Delta', ('Call Option - Delta', 'Put Option - Delta')),
        ('gamma', 'Gamma', ('Call Option - Gamma', 'Put Option - Gamma')),
        ('theta', 'Theta', ('Call Option - Theta (Daily)', 'Put Option - Theta (Daily)')),
    )

    fig, axs = plt.subplots(3, 2, figsize=(15, 12))
    for row, (greek, ylabel, titles) in enumerate(panel_layout):
        for col, option in enumerate(('call', 'put')):
            ax = axs[row, col]
            ax.plot(stock_prices, curves[('kan', option)][greek], 'b-', label='KAN')
            ax.plot(stock_prices, curves[('bs', option)][greek], 'r--', label='Black-Scholes')
            ax.set_title(titles[col])
            ax.set_xlabel('Stock Price')
            ax.set_ylabel(ylabel)
            ax.legend()
            ax.grid(True)
            ax.axvline(x=K, color='gray', linestyle='--')

    plt.tight_layout()
    return fig


# Visualize option pricing results
def visualize_option_pricing(model, X_test, Y_test, predictions, num_samples=100, seed=42):
    """
    Visualize option pricing predictions vs. true prices.

    A random subset of the test set is drawn without replacement, so no point
    is plotted twice. Four panels are drawn:
    - call true vs. predicted price
    - put true vs. predicted price
    - absolute call error vs. moneyness (S/K)
    - absolute call error vs. time to maturity

    Args:
        model: Trained OptionPricingKAN model (not read by this function)
        X_test: Test inputs; columns 0, 1, 2 are read as S, K, T
        Y_test: True option prices; column 0 call, column 1 put
        predictions: Model predictions with the same column layout as Y_test
        num_samples: Number of samples to plot (capped at the test-set size)
        seed: PRNG seed used to choose the subset (default 42)

    Returns:
        Figure object

    Raises:
        ValueError: If X_test contains no rows.
    """
    n_total = X_test.shape[0]
    if n_total == 0:
        raise ValueError("X_test is empty; nothing to visualize")

    # Sample WITHOUT replacement (randint could pick the same row repeatedly).
    n_plot = min(num_samples, n_total)
    idx = jax.random.choice(jax.random.PRNGKey(seed), n_total, shape=(n_plot,), replace=False)
    X_sample = X_test[idx]
    Y_sample = Y_test[idx]
    pred_sample = predictions[idx]

    # Extract option parameters and prices
    S = X_sample[:, 0]  # Stock price
    K = X_sample[:, 1]  # Strike price
    T = X_sample[:, 2]  # Time to maturity

    call_true = Y_sample[:, 0]
    call_pred = pred_sample[:, 0]
    put_true = Y_sample[:, 1]
    put_pred = pred_sample[:, 1]

    moneyness = S / K

    fig, axs = plt.subplots(2, 2, figsize=(15, 10))

    # Call option: True vs. Predicted, with the y = x reference line
    axs[0, 0].scatter(call_true, call_pred, alpha=0.6)
    min_val = min(call_true.min(), call_pred.min())
    max_val = max(call_true.max(), call_pred.max())
    axs[0, 0].plot([min_val, max_val], [min_val, max_val], 'r--')
    axs[0, 0].set_xlabel('True Call Price')
    axs[0, 0].set_ylabel('Predicted Call Price')
    axs[0, 0].set_title('Call Option: True vs. Predicted')
    axs[0, 0].grid(True)

    # Put option: True vs. Predicted, with the y = x reference line
    axs[0, 1].scatter(put_true, put_pred, alpha=0.6)
    min_val = min(put_true.min(), put_pred.min())
    max_val = max(put_true.max(), put_pred.max())
    axs[0, 1].plot([min_val, max_val], [min_val, max_val], 'r--')
    axs[0, 1].set_xlabel('True Put Price')
    axs[0, 1].set_ylabel('Predicted Put Price')
    axs[0, 1].set_title('Put Option: True vs. Predicted')
    axs[0, 1].grid(True)

    # Call option error vs. moneyness
    call_error = jnp.abs(call_pred - call_true)
    axs[1, 0].scatter(moneyness, call_error, alpha=0.6)
    axs[1, 0].set_xlabel('Moneyness (S/K)')
    axs[1, 0].set_ylabel('Absolute Error')
    axs[1, 0].set_title('Call Option: Error vs. Moneyness')
    axs[1, 0].grid(True)

    # Call option error vs. maturity
    axs[1, 1].scatter(T, call_error, alpha=0.6)
    axs[1, 1].set_xlabel('Time to Maturity (years)')
    axs[1, 1].set_ylabel('Absolute Error')
    axs[1, 1].set_title('Call Option: Error vs. Maturity')
    axs[1, 1].grid(True)

    plt.tight_layout()
    return fig


def plot_price_vs_param(model, param_name, param_range, fixed_params=None, option_type='call'):
    """
    Plot option price vs a single parameter by varying it over a range.

    Args:
        model: Trained OptionPricingKAN model, called on an (N, 5) array of
            [S, K, T, r, sigma]; output column 0 is the call, 1 the put
        param_name: One of 'S', 'K', 'T', 'r', 'sigma'
        param_range: 1D array of values to sweep for param_name
        fixed_params: Dict with fixed values for the other parameters. Missing
            keys fall back to S=100, K=100, T=0.5, r=0.05, sigma=0.2.
        option_type: 'call' or 'put' (case-insensitive)

    Returns:
        Matplotlib figure

    Raises:
        ValueError: If param_name or option_type is not one of the values above.
    """
    param_order = ('S', 'K', 'T', 'r', 'sigma')
    if param_name not in param_order:
        # Previously an unknown name was silently ignored, giving a flat plot.
        raise ValueError(f"param_name must be one of {param_order}, got {param_name!r}")

    option_key = option_type.lower()
    if option_key not in ('call', 'put'):
        # Previously anything other than 'call' silently plotted the put.
        raise ValueError(f"option_type must be 'call' or 'put', got {option_type!r}")
    option_idx = 0 if option_key == 'call' else 1

    # Start from the defaults so a partial fixed_params dict still works.
    base_params = {'S': 100, 'K': 100, 'T': 0.5, 'r': 0.05, 'sigma': 0.2}
    if fixed_params is not None:
        base_params.update(fixed_params)

    # One row per swept value, columns in the model's [S, K, T, r, sigma] order.
    X = jnp.array([
        [val if name == param_name else base_params[name] for name in param_order]
        for val in param_range
    ])
    prices = model(X)[:, option_idx]

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(param_range, prices, label=f'KAN {option_type} price')
    ax.set_xlabel(param_name)
    ax.set_ylabel('Option Price')
    ax.set_title(f'{option_type.capitalize()} Price vs {param_name}')
    ax.grid(True)
    ax.legend()
    plt.tight_layout()
    return fig

# Scan implied volatility
def scan_implied_volatility(model, S=100, K=100, T=0.5, r=0.05,
                            vol_range=(0.05, 0.5), num_points=50):
    """
    Scan model predictions across different volatility levels.

    Despite the name, no implied volatility is solved for. Sigma is swept over
    ``vol_range`` with the other inputs fixed. The KAN and Black-Scholes
    prices are compared at each sigma.

    Args:
        model: Trained OptionPricingKAN model, called on an (N, 5) array of
            [S, K, T, r, sigma]; output column 0 is the call, 1 the put
        S, K, T, r: Fixed option parameters
        vol_range: (min, max) volatility of the sweep (default 0.05 to 0.5)
        num_points: Number of volatilities in the sweep (default 50)

    Returns:
        Figure object (2x2: call/put prices on top, absolute errors below)
    """
    volatilities = jnp.linspace(vol_range[0], vol_range[1], num_points)

    # Price the whole sweep in one batched call instead of one model call
    # (plus a host sync via .item()) per volatility.
    ones = jnp.ones_like(volatilities)
    inputs = jnp.stack([S * ones, K * ones, T * ones, r * ones, volatilities], axis=1)
    prices = jnp.asarray(model(inputs))
    kan_call_prices = prices[:, 0]
    kan_put_prices = prices[:, 1]

    # Black-Scholes vectorised over sigma with vmap.
    bs_call_prices = vmap(lambda sig: black_scholes(S, K, T, r, sig, 'call'))(volatilities).reshape(-1)
    bs_put_prices = vmap(lambda sig: black_scholes(S, K, T, r, sig, 'put'))(volatilities).reshape(-1)

    call_errors = jnp.abs(kan_call_prices - bs_call_prices)
    put_errors = jnp.abs(kan_put_prices - bs_put_prices)

    fig, axs = plt.subplots(2, 2, figsize=(15, 10))

    # Top row: KAN vs Black-Scholes prices
    for ax, kan_vals, bs_vals, title, ylabel in (
        (axs[0, 0], kan_call_prices, bs_call_prices, 'Call Price vs. Volatility', 'Call Price'),
        (axs[0, 1], kan_put_prices, bs_put_prices, 'Put Price vs. Volatility', 'Put Price'),
    ):
        ax.plot(volatilities, kan_vals, 'b-', label='KAN')
        ax.plot(volatilities, bs_vals, 'r--', label='Black-Scholes')
        ax.set_title(title)
        ax.set_xlabel('Volatility (σ)')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    # Bottom row: absolute errors
    for ax, errors, title in (
        (axs[1, 0], call_errors, 'Call Price Absolute Error'),
        (axs[1, 1], put_errors, 'Put Price Absolute Error'),
    ):
        ax.plot(volatilities, errors, 'g-')
        ax.set_title(title)
        ax.set_xlabel('Volatility (σ)')
        ax.set_ylabel('Absolute Error')
        ax.grid(True)

    plt.tight_layout()
    return fig
            
# Batch option pricing benchmark
def batch_option_pricing(model, num_options=10000):
    """
    Benchmark batch option pricing performance.

    Batched KAN pricing is timed against a vmap-vectorised Black-Scholes on
    the same randomly generated options. JAX dispatches work asynchronously,
    so each timed region blocks on its results (``block_until_ready``) before
    the clock is read; otherwise only the dispatch would be measured. Both
    pricers also run once untimed first, so one-time tracing and compilation
    are excluded from the timings.

    Args:
        model: Trained OptionPricingKAN model
        num_options: Number of options for benchmarking

    Returns:
        Dictionary with keys 'num_options', 'kan_time', 'bs_time', 'kan_ops',
        'bs_ops', 'speed_ratio', 'mae' and 'mre'
    """
    def _ready(x):
        # Wait for a JAX array to finish computing; pass anything else through.
        return x.block_until_ready() if hasattr(x, 'block_until_ready') else x

    print(f"Benchmarking with {num_options} options...")
    X_test, _ = generate_training_data(num_options, jax.random.PRNGKey(888))
    columns = tuple(X_test[:, i] for i in range(5))  # S, K, T, r, sigma

    bs_call_fn = vmap(lambda s, k, t, r, sig: black_scholes(s, k, t, r, sig, 'call'))
    bs_put_fn = vmap(lambda s, k, t, r, sig: black_scholes(s, k, t, r, sig, 'put'))

    def run_kan():
        return _ready(model(X_test))

    def run_bs():
        return _ready(bs_call_fn(*columns)), _ready(bs_put_fn(*columns))

    # Untimed warm-up so compilation/tracing is not counted.
    run_kan()
    run_bs()

    start_time = time.perf_counter()
    kan_prices = run_kan()
    kan_time = time.perf_counter() - start_time

    start_time = time.perf_counter()
    bs_call_prices, bs_put_prices = run_bs()
    bs_time = time.perf_counter() - start_time

    # Throughput and relative speed
    kan_ops = num_options / kan_time
    bs_ops = num_options / bs_time
    speed_ratio = bs_time / kan_time

    # Accuracy of the KAN against Black-Scholes (relative error floors the
    # denominator at 1e-10, so near-zero prices can dominate the mean).
    bs_prices = jnp.column_stack([bs_call_prices, bs_put_prices])
    mae = jnp.mean(jnp.abs(kan_prices - bs_prices))
    mre = jnp.mean(jnp.abs(kan_prices - bs_prices) / jnp.maximum(bs_prices, 1e-10))

    results = {
        'num_options': num_options,
        'kan_time': kan_time,
        'bs_time': bs_time,
        'kan_ops': kan_ops,
        'bs_ops': bs_ops,
        'speed_ratio': speed_ratio,
        'mae': mae,
        'mre': mre
    }

    print("\nBatch Processing Benchmark:")
    print(f"Number of Options: {num_options}")
    print(f"KAN Time: {kan_time:.4f} seconds")
    print(f"Black-Scholes Time: {bs_time:.4f} seconds")
    print(f"KAN Options/Second: {kan_ops:.0f}")
    print(f"Black-Scholes Options/Second: {bs_ops:.0f}")
    print(f"Speed Ratio (BS/KAN): {speed_ratio:.2f}x")
    print(f"Mean Absolute Error: {mae:.6f}")
    print(f"Mean Relative Error: {mre:.2%}")

    return results


# Example usage in the main function
if __name__ == "__main__":
    # Script entry point: runs only when this file is executed directly.

    # ---- Configuration: adjust these values before running ----
    RUN_MODE = "train_and_evaluate"      # Choose from available modes
    MODEL_FILE = "kan_option_model.pkl"  # Set to None to skip saving/loading
    NUM_EPOCHS = 100
    BATCH_SIZE = 64
    HIDDEN_DIMS = [64, 32]
    NUM_BASIS = 30

    # Contract used by the single-option pricing steps
    OPTION_PARAMS = dict(
        S=100,      # Stock price
        K=100,      # Strike price
        T=0.5,      # Time to maturity (years)
        r=0.05,     # Risk-free rate
        sigma=0.2,  # Volatility
    )

    model, results = run_kan_option_pricing(
        option_params=OPTION_PARAMS,
        num_basis=NUM_BASIS,
        hidden_dims=HIDDEN_DIMS,
        batch_size=BATCH_SIZE,
        num_epochs=NUM_EPOCHS,
        model_file=MODEL_FILE,
        mode=RUN_MODE,
    )

    # ---- Summaries of whichever result sections were produced ----
    if 'metrics' in results:
        metrics = results['metrics']
        print("\nPerformance Summary:")
        print(f"Mean Absolute Error: {metrics['mean_abs_error']:.6f}")
        print(f"Mean Relative Error: {metrics['mean_rel_error']:.2%}")
        print(f"R-squared (Call): {metrics['r_squared_call']:.4f}")
        print(f"R-squared (Put): {metrics['r_squared_put']:.4f}")

    if 'batch_results' in results:
        batch = results['batch_results']
        print("\nBatch Processing Performance:")
        print(f"Speed Ratio (BS/KAN): {batch['speed_ratio']:.2f}x")
        print(f"KAN Options/Second: {batch['kan_ops']:.0f}")

    # Block until the user closes the plot windows.
    plt.show()
