# In [None]:
import sys
sys.path.insert(0, "../../")

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from mixture.model import MixtureLognormalModel

# Load the quoted option chain and the rate/dividend term structure.
# Paths are relative to this notebook (presumably two levels below the repo root,
# matching the sys.path entry above).
data = pd.read_csv("../../data/sample_option_data.csv")
term = pd.read_csv("../../data/sample_term_structure.csv")

# Sorted, de-duplicated strike and maturity axes of the price grid.
Ks = np.unique(data["Strike"].values)
Ts = np.sort(data["Maturity"].unique())
# Hard-coded spot; it is not read from either CSV.
# NOTE(review): confirm it matches the spot the sample quotes were built with.
S0 = 100.0

# Mid price per (Maturity, Strike), laid out as a len(Ts) x len(Ks) matrix
# (rows = maturities, columns = strikes). A single groupby avoids rescanning the
# whole frame once per (T, K) cell; .first() takes the first (non-NaN) quote
# when a pair is duplicated.
mid = 0.5 * (data["Bid"] + data["Ask"])
market_prices = (
    mid.groupby([data["Maturity"], data["Strike"]])
    .first()
    .unstack("Strike")
    .reindex(index=Ts, columns=Ks)
    .to_numpy(dtype=float)
)

# The fit below is handed the full rectangular matrix, so every cell needs a
# price. Report the offending cells instead of failing with an opaque IndexError
# (absent quote) or passing a silent NaN through to the calibration.
gaps = np.argwhere(np.isnan(market_prices))
if gaps.size:
    cells = [(Ts[ti], Ks[kj]) for ti, kj in gaps[:10]]
    raise ValueError(
        f"{len(gaps)} (Maturity, Strike) cells have no usable mid price, "
        f"e.g. {cells}; the option chain must form a complete grid."
    )

# Interpolate the term structure onto the option maturities. np.interp assumes
# its sample points are increasing (it does not check), and NaN knots make the
# result unreliable. So first drop rows without a Time (presumably rows that only
# carry dividend data) and sort by Time. Outside the curve's range np.interp
# holds the end values flat.
curve = term.dropna(subset=["Time"]).sort_values("Time")
r = np.interp(Ts, curve["Time"], curve["r"])
q = np.interp(Ts, curve["Time"], curve["q"])
b = np.interp(Ts, curve["Time"], curve["b"])

# Discrete cash dividends as an (n, 2) array of [div_time, div_amount] rows;
# rows missing either field are dropped.
cash_divs = term[["div_time", "div_amount"]].dropna().to_numpy()

# Fit a two-component MixtureLognormalModel to the market mid-price grid, then
# price the same strikes/maturities with that fitted model so the two surfaces
# can be compared. Both calls share the call/put flag and the dividend schedule.
model = MixtureLognormalModel(n_components=2)
shared_kwargs = {"is_call": True, "cash_divs": cash_divs}
model.fit_to_prices(market_prices, Ks, Ts, S0, r, q, b, **shared_kwargs)

# price_and_greeks returns a pair; only the first element (the prices) is used.
fitted_prices, _ = model.price_and_greeks(S0, Ks, Ts, r, q, b, **shared_kwargs)

# Overlay market mids (markers) and model prices (dashed) for each maturity,
# one pair of curves per row of the price matrices, against strike.
fig, ax = plt.subplots(figsize=(10, 5))
for idx, maturity in enumerate(Ts):
    ax.plot(Ks, market_prices[idx], "o", label=f"Market T={maturity}")
    ax.plot(Ks, fitted_prices[idx], "--", label=f"Fitted T={maturity}")
ax.set(
    title="Mixture Lognormal Fit (Market Data)",
    xlabel="Strike",
    ylabel="Price",
)
ax.legend()
ax.grid(True)
plt.show()
