In [5]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from finmc.models.heston import HestonMC
from finmc.models.localvol import LVMC

from sankey_plot import plot_sankey

# Define ticker and spot price
ticker = "SPX"
spot = 2900

# Common Monte Carlo settings shared by both models: 100k paths, a time step of
# 1/250 (presumably one trading day if time is in years), and a fixed SEED.
mc_settings = {"PATHS": 100_000, "TIMESTEP": 1 / 250, "SEED": 1}

# --- Heston Model Dataset ---
heston_dataset = {
    "MC": mc_settings,
    "BASE": "USD",
    "ASSETS": {
        # NOTE(review): presumably a single (time, rate) point -- confirm the
        # finmc ZERO_RATES layout.
        "USD": ("ZERO_RATES", np.array([[2.0, 0.05]])),
        # Presumably (time, forward level) rows: spot at t=0, 2% higher at t=1.
        ticker: ("FORWARD", np.array([[0.0, spot], [1.0, spot * 1.02]])),
    },
    "HESTON": {
        "ASSET": ticker,
        "INITIAL_VAR": 0.04,
        "LONG_VAR": 0.04,
        "MEANREV": 10,
        "VOL_OF_VOL": 1,
        "CORRELATION": -1.0,
    },
}

# --- Local volatility parameters derived from the Heston model ---
# These module-level values are read by local_vol() below.
lm = heston_dataset["HESTON"]["MEANREV"]
eta = heston_dataset["HESTON"]["VOL_OF_VOL"]
vl = heston_dataset["HESTON"]["LONG_VAR"]
vi = heston_dataset["HESTON"]["INITIAL_VAR"]

# With correlation rho, writing sqrt(v) dW = dx + v/2 dt in the variance SDE
# turns its drift into lm*vl - (lm - rho*eta/2)*v.  The formulas below hard-code
# rho = -1, giving an adjusted mean reversion _lm = lm + eta/2 towards the
# adjusted long-run variance _vl = vl * lm / _lm.
_lm = lm + eta / 2
_vl = vl * lm / _lm

# The adjustment above (and the shape term in local_vol) is only valid for
# perfect negative correlation; warn instead of silently mismatching models.
if heston_dataset["HESTON"]["CORRELATION"] != -1.0:
    warnings.warn(
        "local_vol is derived assuming Heston CORRELATION == -1; the local "
        "volatility model will not match the Heston model for other values."
    )

def local_vol(points):
    """Local volatility function derived from the Heston parameters.

    The local variance is approximated as linear in ``x``::

        var(t, x) = max(atm(t) + shape(t) * x, 0.001)

    where ``atm(t) = (vi - _vl) * exp(-_lm t) + _vl`` and
    ``shape(t) = -eta * (1 - exp(-_lm t)) / (_lm t)`` (with its t -> 0 limit
    ``-eta``). The square root of the floored variance is returned.

    Reads the module-level values ``vi``, ``_vl``, ``_lm`` and ``eta``.

    Args:
        points: A ``(t, x_vec)`` pair. ``t`` must be a scalar (it is used in an
            ``if`` test). ``x_vec`` is presumably the array of log-moneyness
            values supplied by LVMC -- TODO confirm against finmc docs.

    Returns:
        A newly allocated array of local volatilities with the shape of
        ``x_vec``; every entry is at least ``sqrt(0.001)``.
    """
    (t, x_vec) = points
    atm = (vi - _vl) * np.exp(-_lm * t) + _vl
    if t < 1e-6:
        # (1 - exp(-k t)) / (k t) -> 1 as t -> 0; use the limit to avoid 0/0.
        shape = -eta
    else:
        # expm1(-k t) == exp(-k t) - 1 == -(1 - exp(-k t)), computed without
        # the cancellation that 1 - exp(...) suffers when k t is small.
        shape = eta * np.expm1(-_lm * t) / (_lm * t)

    # x_vec * shape allocates a fresh array, so the in-place ops below never
    # touch the caller's x_vec.
    out = x_vec * shape
    np.add(out, atm, out=out)
    np.maximum(0.001, out, out=out)  # Variance floor: values below 0.001 become 0.001
    np.sqrt(out, out=out)
    return out

# Local volatility model dataset. It reuses the same MC settings dict and the
# very same ASSETS mapping object as the Heston dataset; the volatility is
# supplied by the local_vol callable.
lv_dataset = dict(
    MC=mc_settings,
    BASE="USD",
    ASSETS=heston_dataset["ASSETS"],
    LV=dict(ASSET=ticker, VOL=local_vol),
)

# Build one Monte Carlo engine per model from its dataset.
heston_model = HestonMC(heston_dataset)
lv_model = LVMC(lv_dataset)

# Times passed to plot_sankey, and two bin edges placed a factor of 1.1
# below and above spot (symmetric in log terms).
times = [0, 0.5, 1.0]
bins = [spot / 1.1, spot * 1.1]

# --- Generate and Plot Sankey Diagrams ---
for banner, model in (
    ("Sankey plot for Heston Model...", heston_model),
    ("Sankey plot for Local Volatility Model...", lv_model),
):
    print(banner)
    plot_sankey(model, ticker, times, bins)


Sankey plot for Heston Model...


Sankey plot for Local Volatility Model...
