In [None]:
!pip install optax yfinance lxml plotly hmmlearn "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [None]:
from functools import partial
from typing import Any, Callable, Dict, List, Sequence, Tuple

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from jax import grad, jit, vmap

# Module-level PRNG key derived from a fixed seed, kept for reproducibility.
# NOTE(review): nothing else visible here reads `key`; the functions below
# create their own keys (PRNGKey(0), PRNGKey(123), PRNGKey(epoch)).
_GLOBAL_SEED = 42
key = jax.random.PRNGKey(_GLOBAL_SEED)

# Define the KAN Layer for yield curve modeling
class KANLayer:
    def __init__(self, input_dim: int, output_dim: int, num_basis: int = 20, 
                 domain=(-3.0, 3.0), key=None):
        """Initialize a KAN layer with learnable activation functions.
        
        Args:
            input_dim: Input dimension
            output_dim: Output dimension
            num_basis: Number of basis functions for learned activation
            domain: Domain over which the activation functions are defined
            key: JAX random key
        """
        if key is None:
            key = jax.random.PRNGKey(0)
        
        key1, key2 = jax.random.split(key)
        
        # Initialize weights for linear transformation
        self.weights = jax.random.normal(key1, (input_dim, output_dim)) * 0.1
        
        # Initialize biases
        self.biases = jax.random.normal(key2, (output_dim,)) * 0.01
        
        # Grid points for activation function representation
        self.grid_points = jnp.linspace(domain[0], domain[1], num_basis)
        
        # Initialize activation function values at grid points
        # For yield curves, we'll start with something that can model exponential decay
        # and other common yield curve shapes
        key3 = jax.random.split(key2)[0]
        
        # Initialize different activation patterns for different outputs
        activations_list = []
        for i in range(output_dim):
            if i % 3 == 0:  # Exponential-like decay (common in yield curves)
                act = jnp.exp(-0.5 * self.grid_points) * (self.grid_points > 0)
            elif i % 3 == 1:  # Hump-shaped (for modeling yield curve humps)
                act = jnp.exp(-0.5 * (self.grid_points - 1.0)**2)
            else:  # More flexible activation
                act = 0.5 * (1 + jnp.tanh(self.grid_points))
            
            # Add some random noise to break symmetry
            act = act + jax.random.normal(key3, (num_basis,)) * 0.01
            activations_list.append(act)
        
        # Stack into a matrix: (output_dim, num_basis)
        self.activations = jnp.stack(activations_list)
        
        # Store domain for clipping
        self.domain = domain
    
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Forward pass through the KAN layer.
        
        Args:
            x: Input tensor of shape (batch_size, input_dim)
            
        Returns:
            Output tensor of shape (batch_size, output_dim)
        """
        # Linear transformation
        z = jnp.dot(x, self.weights) + self.biases  # Shape: (batch_size, output_dim)
        
        # Apply learned activation functions
        # First, clip inputs to domain
        z_clipped = jnp.clip(z, self.domain[0], self.domain[1])
        
        # For each output dimension, interpolate activation function
        def apply_activation(z_i, i):
            """Apply the i-th activation function to z_i."""
            # Linear interpolation between grid points
            idx = jnp.searchsorted(self.grid_points, z_i) - 1
            idx = jnp.clip(idx, 0, len(self.grid_points) - 2)
            
            # Get surrounding points
            x0 = self.grid_points[idx]
            x1 = self.grid_points[idx + 1]
            y0 = self.activations[i, idx]
            y1 = self.activations[i, idx + 1]
            
            # Linear interpolation: y = y0 + (y1 - y0) * (x - x0) / (x1 - x0)
            t = (z_i - x0) / (x1 - x0)
            return y0 + t * (y1 - y0)
        
        # Apply activation function for each element in the batch and each output dimension
        output = jnp.zeros_like(z)
        for i in range(z.shape[1]):  # For each output dimension
            output = output.at[:, i].set(vmap(lambda z_i: apply_activation(z_i, i))(z[:, i]))
        
        return output

# Full KAN model for yield curve modeling
class YieldCurveKAN:
    def __init__(self, input_dim: int, output_dim: int, hidden_dims: List[int] = [64, 32], 
                 num_basis: int = 40, domain=(-3.0, 3.0), key=None):
        """Initialize a KAN model for yield curve modeling.
        
        Args:
            input_dim: Input dimension (typically factors affecting yield curve)
            output_dim: Output dimension (typically different tenors on the yield curve)
            hidden_dims: List of hidden dimensions
            num_basis: Number of basis functions for learned activations
            domain: Domain for activation functions
            key: JAX random key
        """
        if key is None:
            key = jax.random.PRNGKey(0)
        
        keys = jax.random.split(key, len(hidden_dims) + 1)
        
        # Initialize layers
        self.layers = []
        prev_dim = input_dim
        
        for i, hidden_dim in enumerate(hidden_dims):
            layer = KANLayer(prev_dim, hidden_dim, num_basis, domain, keys[i])
            self.layers.append(layer)
            prev_dim = hidden_dim
        
        # Final layer specifically designed for yield curve output
        self.output_layer = KANLayer(prev_dim, output_dim, num_basis, domain, keys[-1])
    
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Forward pass through the KAN model.
        
        Args:
            x: Input tensor of shape (batch_size, input_dim)
            
        Returns:
            Output tensor of shape (batch_size, output_dim) representing yield curves
        """
        for layer in self.layers:
            x = layer(x)
        
        # Apply output layer
        y = self.output_layer(x)
        
        # Ensure yield curve outputs are positive (yields are typically positive)
        # We use softplus for positivity with a small scaling factor to prevent zeros
        return jax.nn.softplus(y)
    
    @property
    def params(self):
        """Get model parameters as a flat dictionary."""
        params = {}
        for i, layer in enumerate(self.layers):
            params[f'layer_{i}_weights'] = layer.weights
            params[f'layer_{i}_biases'] = layer.biases
            params[f'layer_{i}_activations'] = layer.activations
        
        params['output_layer_weights'] = self.output_layer.weights
        params['output_layer_biases'] = self.output_layer.biases
        params['output_layer_activations'] = self.output_layer.activations
        
        return params
    
    def update_params(self, params):
        """Update model parameters from a flat dictionary."""
        for i, layer in enumerate(self.layers):
            layer.weights = params[f'layer_{i}_weights']
            layer.biases = params[f'layer_{i}_biases']
            layer.activations = params[f'layer_{i}_activations']
        
        self.output_layer.weights = params['output_layer_weights']
        self.output_layer.biases = params['output_layer_biases']
        self.output_layer.activations = params['output_layer_activations']

# Nelson-Siegel-Svensson model for yield curve generation
def nelson_siegel_svensson(t, beta0, beta1, beta2, beta3, tau1, tau2):
    """
    Nelson-Siegel-Svensson model for yield curve generation.
    
    Args:
        t: Array of tenors (time to maturity)
        beta0, beta1, beta2, beta3: Shape parameters
        tau1, tau2: Time constants
    
    Returns:
        Yield curve rates for each tenor
    """
    # Handle division by zero for small tenors
    t_safe = jnp.maximum(t, 1e-10)
    
    # Calculate terms
    term1 = 1.0
    term2 = (1.0 - jnp.exp(-t_safe / tau1)) / (t_safe / tau1)
    term3 = term2 - jnp.exp(-t_safe / tau1)
    term4 = (1.0 - jnp.exp(-t_safe / tau2)) / (t_safe / tau2) - jnp.exp(-t_safe / tau2)
    
    # Combine terms to get yield curve
    y = beta0 * term1 + beta1 * term2 + beta2 * term3 + beta3 * term4
    
    return y

# Generate synthetic yield curve data for training
def generate_yield_curve_data(num_samples=100000, num_tenors=10):
    """
    Generate synthetic yield curve data using Nelson-Siegel-Svensson model.
    
    Args:
        num_samples: Number of yield curves to generate
        num_tenors: Number of points on each yield curve
    
    Returns:
        X: Economic factors driving the yield curve
        Y: Yield curves
    """
    key = jax.random.PRNGKey(123)
    keys = jax.random.split(key, 8)
    
    # Define tenors (in years)
    tenors = jnp.array([0.25, 0.5, 1.0, 2.0, 3.0, 5.0, 7.0, 10.0, 20.0, 30.0])[:num_tenors]
    
    # Generate random parameters for Nelson-Siegel-Svensson model
    beta0 = jax.random.uniform(keys[0], (num_samples,), minval=0.01, maxval=0.05)  # Long-term rate
    beta1 = jax.random.uniform(keys[1], (num_samples,), minval=-0.03, maxval=0.03)  # Short-term component
    beta2 = jax.random.uniform(keys[2], (num_samples,), minval=-0.03, maxval=0.03)  # Medium-term component
    beta3 = jax.random.uniform(keys[3], (num_samples,), minval=-0.03, maxval=0.03)  # Second hump component
    tau1 = jax.random.uniform(keys[4], (num_samples,), minval=0.5, maxval=3.0)  # First time constant
    tau2 = jax.random.uniform(keys[5], (num_samples,), minval=3.0, maxval=10.0)  # Second time constant
    
    # Economic factors that drive yield curves (for input to our model)
    # In practice, these would be macroeconomic variables like inflation, GDP growth, etc.
    # Here we'll generate synthetic factors that correlate with our NSS parameters
    num_factors = 5  # Number of economic factors
    
    # Create correlations between economic factors and NSS parameters
    factor_noise = jax.random.normal(keys[6], (num_samples, num_factors)) * 0.2
    
    # Factor 1: Correlates with level (beta0)
    # Factor 2: Correlates with slope (beta1)
    # Factor 3: Correlates with curvature (beta2)
    # Factors 4-5: Additional factors with some correlation to all parameters
    
    factors = jnp.zeros((num_samples, num_factors))
    factors = factors.at[:, 0].set(beta0 * 20.0 + factor_noise[:, 0])  # Level factor
    factors = factors.at[:, 1].set(beta1 * 20.0 + factor_noise[:, 1])  # Slope factor
    factors = factors.at[:, 2].set(beta2 * 20.0 + factor_noise[:, 2])  # Curvature factor
    factors = factors.at[:, 3].set(beta3 * 10.0 + beta0 * 5.0 + factor_noise[:, 3])  # Mixed factor 1
    factors = factors.at[:, 4].set(tau1 * 0.5 - tau2 * 0.2 + factor_noise[:, 4])  # Mixed factor 2
    
    # Generate yield curves using Nelson-Siegel-Svensson model
    yield_curves = jnp.zeros((num_samples, num_tenors))
    
    for i in range(num_samples):
        yield_curves = yield_curves.at[i].set(
            nelson_siegel_svensson(tenors, beta0[i], beta1[i], beta2[i], beta3[i], tau1[i], tau2[i])
        )
    
    return factors, yield_curves, tenors

# Training functions
@jit
def loss_fn(params, X, Y):
    """Mean squared error loss function for yield curve prediction."""
    model = YieldCurveKAN(X.shape[1], Y.shape[1])
    model.update_params(params)
    pred = model(X)
    return jnp.mean((pred - Y) ** 2)

@jit
def train_step(params, X, Y, opt_state):
    """Run one optimizer update on a single mini-batch.

    Reads the module-level ``optimizer`` (an optax transformation) and
    ``loss_fn``.

    Returns:
        (new_params, new_opt_state, loss_value) where ``loss_value`` is the
        loss evaluated at the incoming ``params`` (before the update).
    """
    value_and_grad_fn = jax.value_and_grad(loss_fn)
    loss_value, grads = value_and_grad_fn(params, X, Y)

    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)

    return new_params, new_opt_state, loss_value

def train_model(X, Y, params, num_epochs=100, batch_size=64):
    """Train the model for a specified number of epochs."""
    num_samples = X.shape[0]
    num_batches = num_samples // batch_size
    
    losses = []
    
    # Initialize optimizer state
    opt_state = optimizer.init(params)
    
    for epoch in range(num_epochs):
        # Shuffle data
        perm = jax.random.permutation(jax.random.PRNGKey(epoch), num_samples)
        X_shuffled = X[perm]
        Y_shuffled = Y[perm]
        
        epoch_loss = 0.0
        
        for batch in range(num_batches):
            start_idx = batch * batch_size
            end_idx = start_idx + batch_size
            
            X_batch = X_shuffled[start_idx:end_idx]
            Y_batch = Y_shuffled[start_idx:end_idx]
            
            params, opt_state, batch_loss = train_step(params, X_batch, Y_batch, opt_state)
            epoch_loss += batch_loss
        
        epoch_loss /= num_batches
        losses.append(epoch_loss)
        
        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {epoch_loss:.6f}")
    
    return params, losses

# Post-training analysis and visualization
def analyze_yield_curve_model(model, params, X_test, Y_test, tenors):
    """Analyze and visualize the trained yield curve model."""
    # Update model with trained parameters
    model.update_params(params)
    
    # Make predictions
    predictions = model(X_test)
    
    # Calculate mean absolute error
    mae = jnp.mean(jnp.abs(predictions - Y_test))
    print(f"Mean Absolute Error: {mae:.6f}")
    
    # Plot actual vs. predicted yield curves for a few examples
    num_examples = min(5, X_test.shape[0])
    
    fig, axs = plt.subplots(num_examples, 1, figsize=(10, 3 * num_examples))
    if num_examples == 1:
        axs = [axs]
    
    for i in range(num_examples):
        axs[i].plot(tenors, Y_test[i], 'b-o', label='Actual')
        axs[i].plot(tenors, predictions[i], 'r--x', label='Predicted')
        axs[i].set_title(f'Yield Curve Example {i+1}')
        axs[i].set_xlabel('Tenor (years)')
        axs[i].set_ylabel('Yield (%)')
        axs[i].grid(True)
        axs[i].legend()
    
    plt.tight_layout()
    
    # Analyze the learned activation functions
    # These can provide insights into how the model is capturing yield curve shapes
    fig2, axs2 = plt.subplots(2, 2, figsize=(12, 8))
    
    # Plot a few activation functions from the output layer
    grid_points = model.output_layer.grid_points
    output_activations = model.output_layer.activations
    
    for i in range(min(4, output_activations.shape[0])):
        row, col = i // 2, i % 2
        axs2[row, col].plot(grid_points, output_activations[i])
        axs2[row, col].set_title(f'Learned Activation for Tenor {tenors[i]:.1f}y')
        axs2[row, col].set_xlabel('Input')
        axs2[row, col].set_ylabel('Activation')
        axs2[row, col].grid(True)
    
    plt.tight_layout()
    
    # Calculate and visualize principal components of yield curve dynamics
    from sklearn.decomposition import PCA
    
    # Convert to numpy for PCA
    Y_np = np.array(Y_test)
    
    # Fit PCA to the yield curves
    pca = PCA(n_components=3)
    pca.fit(Y_np)
    
    # Plot the principal components
    fig3, ax3 = plt.subplots(figsize=(10, 6))
    
    component_names = ['Level', 'Slope', 'Curvature']
    for i in range(3):
        ax3.plot(tenors, pca.components_[i], label=component_names[i])
    
    ax3.set_title('Principal Components of Yield Curve Dynamics')
    ax3.set_xlabel('Tenor (years)')
    ax3.set_ylabel('Loading')
    ax3.grid(True)
    ax3.legend()
    
    return fig, fig2, fig3, mae

# Example: Calculate yield curve derivatives for risk management
def calculate_yield_curve_derivatives(model, params, X):
    """Calculate derivatives of yield curves with respect to input factors."""
    # Update model with parameters
    model.update_params(params)
    
    # Create a function that returns the yield curve given input factors
    def yield_curve_fn(x):
        return model(x.reshape(1, -1))[0]
    
    # Calculate Jacobian (sensitivities to all factors)
    jacobian_fn = jax.jacfwd(yield_curve_fn)
    
    # Calculate sensitivities for each example
    sensitivities = jax.vmap(jacobian_fn)(X)
    
    # sensitivities shape: (num_samples, num_tenors, num_factors)
    return sensitivities

# Yield curve interpolation and extrapolation
def interpolate_yield_curve(model, params, X, target_tenors, original_tenors):
    """Interpolate yield curves to arbitrary tenors."""
    # First, get the predicted yield curves at the original tenors
    model.update_params(params)
    predicted_curves = model(X)
    
    # Helper function to interpolate a single yield curve
    def interpolate_single_curve(curve):
        # Use cubic spline interpolation
        from scipy.interpolate import CubicSpline
        cs = CubicSpline(original_tenors, curve)
        return cs(target_tenors)
    
    # Apply to each curve
    interpolated_curves = []
    for i in range(predicted_curves.shape[0]):
        curve = np.array(predicted_curves[i])
        interpolated = interpolate_single_curve(curve)
        interpolated_curves.append(interpolated)
    
    return np.array(interpolated_curves)

# Main function to run everything
def main():
    """Run the full synthetic yield-curve experiment end to end.

    Steps: generate NSS data, split it 80/20 into train/test, build a
    YieldCurveKAN, train it with Adam (lr 1e-3), analyze it on the test set,
    and compute factor sensitivities for the first five test samples.

    Returns:
        (model, trained_params, losses, fig1, fig2, fig3, sensitivities)
    """
    # train_step (jit-compiled) reads this module-level name.
    global optimizer

    print("Generating synthetic yield curve data...")
    X, Y, tenors = generate_yield_curve_data(num_samples=2000, num_tenors=10)

    n_train = int(0.8 * X.shape[0])
    X_train, Y_train = X[:n_train], Y[:n_train]
    X_test, Y_test = X[n_train:], Y[n_train:]

    print("Initializing KAN model for yield curve modeling...")
    n_factors, n_tenors = X.shape[1], Y.shape[1]
    model = YieldCurveKAN(n_factors, n_tenors, hidden_dims=[32, 16], num_basis=40)

    optimizer = optax.adam(0.001)

    print("Training model...")
    trained_params, losses = train_model(X_train, Y_train, model.params, num_epochs=100)

    print("Analyzing results...")
    fig1, fig2, fig3, _mae = analyze_yield_curve_model(
        model, trained_params, X_test, Y_test, tenors
    )

    print("Calculating yield curve sensitivities...")
    sensitivities = calculate_yield_curve_derivatives(model, trained_params, X_test[:5])

    return model, trained_params, losses, fig1, fig2, fig3, sensitivities


if __name__ == "__main__":
    main()