# Implied Volatility Inference with JAX

This notebook demonstrates how to recover ("invert") the implied volatility of a European vanilla option using:

- Black–Scholes pricing implemented in JAX
- Automatic differentiation to obtain Vega (the gradient with respect to volatility)
- Newton–Raphson root-finding

We will:
1. Price a synthetic option and recover its volatility.
2. Compare autodiff Vega vs. analytic Vega.
3. Visualize the loss function and Newton iterations.
4. (Optional) Pull a real option chain with `yfinance` and compute a mini volatility smile.

> Educational, work-in-progress. Not production code or trading advice.


In [1]:
# autoreload
%load_ext autoreload
%autoreload 2

import sys
import warnings
from datetime import datetime, timezone

import numpy as np
import yfinance as yf
import pandas as pd
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp

sys.path.append("..")
from src.model.main import black_scholes, NR_for_sigma

Using Newton–Raphson and `jax.grad` to compute implied volatility.

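Given a market option premium $C_{\text{mkt}}$, we solve for $\sigma$ such that:

$$\mathrm{BS}(S, K, T, r, \sigma, q, \text{otype}) = C_{\text{mkt}}$$

Define $f(\sigma) = \mathrm{BS}(S, K, T, r, \sigma, q, \text{otype}) - C_{\text{mkt}}$. The Newton step is:

$$\sigma_{n+1} = \sigma_n - \frac{f(\sigma_n)}{f'(\sigma_n)}$$

Here $f'(\sigma) = \partial \mathrm{BS} / \partial \sigma$ (Vega) is supplied by automatic differentiation instead of a manually coded analytic derivative.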
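The `black_scholes` and `NR_for_sigma` helpers live in `src.model.main` and are not shown in this notebook. As a minimal, self-contained sketch of the same idea (not the project's actual implementation), the hypothetical helpers `bs_price` and `newton_iv` below price with `jax.scipy.stats.norm.cdf` and run the Newton update with `jax.grad` supplying $f'(\sigma)$. Enabling `jax_enable_x64` is an assumption made here so that tolerances near `1e-10` are meaningful, since JAX defaults to float32.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Assumption: float64 is needed for tight tolerances (JAX defaults to float32)
jax.config.update("jax_enable_x64", True)


def bs_price(S, K, T, r, sigma, q=0.0, otype="call"):
    """Black–Scholes price of a European option with continuous dividend yield q."""
    sqrt_t = jnp.sqrt(T)
    d1 = (jnp.log(S / K) + (r - q + 0.5 * sigma**2) * T) / (sigma * sqrt_t)
    d2 = d1 - sigma * sqrt_t
    if otype == "call":
        return S * jnp.exp(-q * T) * norm.cdf(d1) - K * jnp.exp(-r * T) * norm.cdf(d2)
    return K * jnp.exp(-r * T) * norm.cdf(-d2) - S * jnp.exp(-q * T) * norm.cdf(-d1)


def newton_iv(S, K, T, r, price, sigma0=0.3, q=0.0, otype="call", tol=1e-10, max_iter=50):
    """Hypothetical Newton–Raphson implied-vol solver; Vega comes from jax.grad."""
    f = lambda s: bs_price(S, K, T, r, s, q, otype) - price  # f(sigma)
    f_prime = jax.grad(f)                                    # f'(sigma) = Vega
    sigma = jnp.asarray(sigma0, dtype=jnp.float64)
    for _ in range(max_iter):
        f_val = f(sigma)
        if jnp.abs(f_val) < tol:                             # converged on price error
            break
        sigma = sigma - f_val / f_prime(sigma)               # Newton update
    return float(sigma)


target = bs_price(100.0, 110.0, 1.0, 0.05, 0.2)
print(newton_iv(100.0, 110.0, 1.0, 0.05, target))  # should be close to 0.2
```

Because the loss is a plain Python function of $\sigma$, `jax.grad(f)` yields Vega for calls and puts alike without a separate formula for each.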

We will later:
- Compare autodiff Vega vs closed-form Vega
- Plot f(σ) to visualize curvature and convergence behavior
- Extend to multiple strikes from a real option chain

In [2]:
S, K, T, r, q = 100, 110, 1, 0.05, 0.0
otype = "call"

true_sigma = 0.2
market_price = black_scholes(S, K, T, r, true_sigma, q, otype)
print(market_price)



6.040088129724232


In [3]:
inferred_sigma = NR_for_sigma(
    S,
    K,
    T,
    r,
    market_price,
    sigma_guess=0.3,
    q=q,
    otype=otype,
    tol=1e-7,
    max_iter=100,
    verbose=True,
)

print()
print(f"Inferred sigma: {inferred_sigma}")
print(f"True sigma: {true_sigma}")

Iter 0: sigma=0.3, loss=3.979989490331725
Iter 1: sigma=0.20023640478353788, loss=0.00935619522458353
Iter 2: sigma=0.20000000575577456, loss=2.2779081376711474e-07
Iter 3: sigma=0.19999999999999993, loss=7.105427357601002e-15

Inferred sigma: 0.19999999999999993
True sigma: 0.2
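The printed iterates show the quadratic convergence expected of Newton's method near a simple root: the error $e_n = |\sigma_n - \sigma^*|$ roughly squares at each step, $e_{n+1} \approx \left|\frac{f''(\sigma^*)}{2 f'(\sigma^*)}\right| e_n^2$. For this loss $f'$ is Vega and $f''$ is Volga $= \text{Vega} \cdot d_1 d_2 / \sigma$, so the constant reduces to $|d_1 d_2| / (2\sigma)$. The plain-Python check below uses the iterates copied from the log; the last one is already at machine precision, so it is left out of the ratios.

```python
import math

# Newton iterates copied from the verbose log (S=100, K=110, T=1, r=0.05, q=0, true sigma=0.2)
iterates = [0.3, 0.20023640478353788, 0.20000000575577456, 0.19999999999999993]
true_sigma = 0.2
errors = [abs(s - true_sigma) for s in iterates]

# Observed ratios e_{n+1} / e_n^2 (skip the final, round-off-dominated step)
ratios = [errors[i + 1] / errors[i] ** 2 for i in range(2)]

# Predicted asymptotic constant |f''/(2 f')| = |d1 * d2| / (2 * sigma)
S, K, T, r, q = 100.0, 110.0, 1.0, 0.05, 0.0
d1 = (math.log(S / K) + (r - q + 0.5 * true_sigma**2) * T) / (true_sigma * math.sqrt(T))
d2 = d1 - true_sigma * math.sqrt(T)
predicted = abs(d1 * d2) / (2 * true_sigma)

print("errors:", errors)
print("e_{n+1} / e_n^2:", ratios)
print("predicted constant:", predicted)
```

The second ratio should land close to the predicted constant; the first is further off because $\sigma_0 = 0.3$ is not yet in the asymptotic regime.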


### Analytic Vega vs Autodiff
We implement the closed-form Vega and compare it with the gradient produced by JAX as a correctness check.
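For reference, the closed form follows from differentiating the call price. With $\varphi$ and $\Phi$ the standard normal pdf and cdf:

$$d_1 = \frac{\ln(S/K) + (r - q + \tfrac{1}{2}\sigma^2)T}{\sigma\sqrt{T}}, \qquad d_2 = d_1 - \sigma\sqrt{T}$$

$$C = S e^{-qT}\Phi(d_1) - K e^{-rT}\Phi(d_2)$$

Differentiating with respect to $\sigma$, then using the identity $S e^{-qT}\varphi(d_1) = K e^{-rT}\varphi(d_2)$ and $d_1 - d_2 = \sigma\sqrt{T}$:

$$\frac{\partial C}{\partial \sigma} = S e^{-qT}\varphi(d_1)\frac{\partial d_1}{\partial \sigma} - K e^{-rT}\varphi(d_2)\frac{\partial d_2}{\partial \sigma} = S e^{-qT}\varphi(d_1)\left(\frac{\partial d_1}{\partial \sigma} - \frac{\partial d_2}{\partial \sigma}\right) = S e^{-qT}\varphi(d_1)\sqrt{T}$$

Put–call parity, $C - P = S e^{-qT} - K e^{-rT}$, does not depend on $\sigma$, so puts have the same Vega.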

In [None]:
# Closed-form Black–Scholes Vega (identical for calls and puts):
# Vega = S * exp(-q T) * phi(d1) * sqrt(T), with phi the standard normal pdf

def analytic_vega(S, K, T, r, sigma, q=0.0):
    d1 = (jnp.log(S / K) + (r - q + 0.5 * sigma**2) * T) / (sigma * jnp.sqrt(T))
    pdf = jnp.exp(-0.5 * d1**2) / jnp.sqrt(2 * jnp.pi)  # phi(d1)
    return S * jnp.exp(-q * T) * pdf * jnp.sqrt(T)

# Use autodiff on model price directly for comparison
price_wrt_sigma = jax.grad(lambda sig: black_scholes(S, K, T, r, sig, q, otype))

vega_ad = price_wrt_sigma(true_sigma)
vega_analytic = analytic_vega(S, K, T, r, true_sigma, q)

print(f"Autodiff Vega:  {vega_ad}")
print(f"Analytic Vega:  {vega_analytic}")
print(f"Abs diff:       {jnp.abs(vega_ad - vega_analytic)}")
print(f"Rel diff:       {jnp.abs(vega_ad - vega_analytic) / vega_analytic}")
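The comparison above checks a single point. The sketch below extends it to a grid of strikes and volatilities; `bs_call` and `vega_closed_form` are hypothetical stand-ins (a call-only pricer built on `jax.scipy.stats.norm`, and the same formula as `analytic_vega`) so that the block runs on its own. `jax.grad(..., argnums=4)` differentiates with respect to $\sigma$, and two nested `jax.vmap` calls map it over the grid.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

jax.config.update("jax_enable_x64", True)  # assumption: float64, matching the notebook


def bs_call(S, K, T, r, sigma, q=0.0):
    """Stand-in Black–Scholes call pricer (hypothetical, not the project's black_scholes)."""
    sqrt_t = jnp.sqrt(T)
    d1 = (jnp.log(S / K) + (r - q + 0.5 * sigma**2) * T) / (sigma * sqrt_t)
    d2 = d1 - sigma * sqrt_t
    return S * jnp.exp(-q * T) * norm.cdf(d1) - K * jnp.exp(-r * T) * norm.cdf(d2)


def vega_closed_form(S, K, T, r, sigma, q=0.0):
    d1 = (jnp.log(S / K) + (r - q + 0.5 * sigma**2) * T) / (sigma * jnp.sqrt(T))
    return S * jnp.exp(-q * T) * norm.pdf(d1) * jnp.sqrt(T)


S0, T0, r0, q0 = 100.0, 1.0, 0.05, 0.0
strikes = jnp.linspace(70.0, 130.0, 13)
sigmas = jnp.linspace(0.05, 0.80, 16)

# d(price)/d(sigma): sigma is positional argument 4
vega_ad_fn = jax.grad(bs_call, argnums=4)
# Map over sigma (inner), then over strike (outer) -> shape (len(strikes), len(sigmas))
over_sigma = jax.vmap(vega_ad_fn, in_axes=(None, None, None, None, 0, None))
over_grid = jax.vmap(over_sigma, in_axes=(None, 0, None, None, None, None))
vega_ad_grid = over_grid(S0, strikes, T0, r0, sigmas, q0)

# Closed form evaluated by broadcasting strikes (column) against sigmas (row)
vega_cf_grid = vega_closed_form(S0, strikes[:, None], T0, r0, sigmas[None, :], q0)

print("grid shape:", vega_ad_grid.shape)
print("max |autodiff - closed form|:", jnp.max(jnp.abs(vega_ad_grid - vega_cf_grid)))
```

Nesting `vmap` keeps each call to the differentiated function scalar-valued, which `jax.grad` requires; the closed form needs no such treatment because it already broadcasts.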

### Visualizing the Loss Function and Newton Steps
We'll sample the loss over a grid of sigma values and overlay the Newton iterates.

In [None]:
# Capture Newton iterations manually for plotting
sig0 = 0.30
sig_iter = []
loss_iter = []

def loss_fn(sigma):
    # f(sigma) = model price - market price
    return black_scholes(S, K, T, r, sigma, q, otype) - market_price

vega_fn = jax.grad(loss_fn)  # f'(sigma) = Vega, built once outside the loop

sigma = sig0
for i in range(12):  # hard cap on iterations
    f = loss_fn(sigma)
    sig_iter.append(float(sigma))
    loss_iter.append(float(f))
    if jnp.abs(f) < 1e-8:
        break
    sigma = sigma - f / vega_fn(sigma)  # Newton update

# Grid for visualization
sig_grid = jnp.linspace(0.05, 0.60, 200)
loss_grid = black_scholes(S, K, T, r, sig_grid, q, otype) - market_price

fig, ax = plt.subplots(1,2, figsize=(10,4))
ax[0].plot(sig_grid, loss_grid, label='loss(sigma)')
ax[0].axhline(0, color='k', lw=0.7)
ax[0].scatter(sig_iter, loss_iter, color='red', zorder=5, label='Newton path')
for i,(s,l) in enumerate(zip(sig_iter, loss_iter)):
    ax[0].annotate(str(i), (s,l), textcoords="offset points", xytext=(4,4), fontsize=8)
ax[0].set_xlabel('sigma')
ax[0].set_ylabel('loss')
ax[0].legend()

price_grid = black_scholes(S, K, T, r, sig_grid, q, otype)
ax[1].plot(sig_grid, price_grid, label='Price(sigma)')
ax[1].axhline(float(market_price), color='k', lw=0.7, linestyle='--', label='Market price')
ax[1].scatter([true_sigma],[float(market_price)], color='green', label='True sigma')
ax[1].scatter(sig_iter, black_scholes(S,K,T,r,jnp.array(sig_iter), q, otype), color='red', s=20)
ax[1].set_xlabel('sigma')
ax[1].set_ylabel('Option Price')
ax[1].legend()
plt.tight_layout()
plt.show()
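The curvature of the loss curve in the left panel is $f''(\sigma)$, i.e. Volga. Nesting `jax.grad` gives it directly, and it can be checked against the closed form $\text{Vega} \cdot d_1 d_2 / \sigma$. Since $d_1 d_2 = 0$ exactly when $\sigma^2 T = 2\,|\ln(S/K) + (r - q)T|$, for these parameters $f$ is convex below roughly $\sigma \approx 0.30$ and concave above it (so the initial guess $\sigma_0 = 0.3$ sits close to the inflection point). The self-contained sketch below re-declares the synthetic parameters and uses a hypothetical call-only pricer `call_price` in place of the project's `black_scholes`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

jax.config.update("jax_enable_x64", True)  # assumption: float64, as in the cells above

# Same synthetic parameters as the example above
S, K, T, r, q = 100.0, 110.0, 1.0, 0.05, 0.0


def d1_d2(sigma):
    d1 = (jnp.log(S / K) + (r - q + 0.5 * sigma**2) * T) / (sigma * jnp.sqrt(T))
    return d1, d1 - sigma * jnp.sqrt(T)


def call_price(sigma):
    d1, d2 = d1_d2(sigma)
    return S * jnp.exp(-q * T) * norm.cdf(d1) - K * jnp.exp(-r * T) * norm.cdf(d2)


vega_fn = jax.grad(call_price)  # f'(sigma)
volga_fn = jax.grad(vega_fn)    # f''(sigma): curvature of the loss


def volga_closed_form(sigma):
    d1, d2 = d1_d2(sigma)
    vega_cf = S * jnp.exp(-q * T) * norm.pdf(d1) * jnp.sqrt(T)
    return vega_cf * d1 * d2 / sigma


sig_grid = jnp.linspace(0.05, 0.60, 12)
volga_ad = jax.vmap(volga_fn)(sig_grid)
max_diff = jnp.max(jnp.abs(volga_ad - volga_closed_form(sig_grid)))

# Inflection point: d1 * d2 = 0  <=>  sigma^2 * T = 2 * |ln(S/K) + (r - q) T|
sigma_inflect = jnp.sqrt(2 * jnp.abs(jnp.log(S / K) + (r - q) * T) / T)

print("max |autodiff - closed form| Volga:", max_diff)
print("Volga at ends of grid:", volga_ad[0], volga_ad[-1])
print("inflection sigma:", sigma_inflect)
```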

### Mini Option Chain Example
Fetch a live chain (AAPL by default), take a small subset, and compute implied vols for mid prices (illustrative only).

In [None]:
from src.loader.main import get_option_chain

try:
    chain = get_option_chain("AAPL", clean=True)
    display(chain.head())
except Exception as e:
    print("Failed to load option chain:", e)
    chain = None

# Select a slice: the 8 calls with strikes closest to spot at the first expiry
if chain is not None and not chain.empty:
    first_exp = chain['expiry'].min()
    sub = chain[(chain['expiry'] == first_exp) & (chain['otype'] == 'call')].copy()
    sub['dist'] = (sub['K'] - sub['spot']).abs()
    sub = sub.nsmallest(8, 'dist').sort_values('K')  # small near-the-money sample
    results = []
    for _, row in sub.iterrows():
        S_row = float(row['spot'])
        K_row = float(row['K'])
        T_row = float(row['T'])
        mid = float(row['mid'])
        guess = 0.25  # crude initial guess for sigma
        try:
            impl = NR_for_sigma(
                S_row, K_row, T_row, r=0.05, market_price=mid,
                sigma_guess=guess, q=0.0, otype='call', tol=1e-6, max_iter=100,
            )
            iv = float(impl)
        except Exception:
            iv = float('nan')  # solver failed; record NaN
        results.append({"K": K_row, "T": T_row, "mid": mid, "implied_vol": iv})
    iv_df = pd.DataFrame(results)
    display(iv_df)

    # Simple smile plot
    if not iv_df.empty:
        plt.figure(figsize=(5,3))
        plt.plot(iv_df['K'], iv_df['implied_vol'], marker='o')
        plt.xlabel('Strike')
        plt.ylabel('Implied Vol')
        plt.title('Mini Call Smile (First Expiry)')
        plt.tight_layout()
        plt.show()
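Treating the quotes as European, as the cell above does, Newton–Raphson can only succeed when a call quote lies strictly inside the no-arbitrage bounds $\max(S e^{-qT} - K e^{-rT}, 0) < C < S e^{-qT}$. Outside them no $\sigma$ reproduces the price, and near the lower bound Vega is tiny, so the iteration can stall or diverge. A minimal pre-filter sketch follows; `call_has_iv` is a hypothetical helper, it also rejects $T \le 0$ (an expiring contract), and `r = 0.05` repeats the hard-coded rate assumption from the cell above.

```python
import jax.numpy as jnp


def call_has_iv(mid, S, K, T, r=0.05, q=0.0):
    """True where a European call quote lies strictly inside its no-arbitrage bounds."""
    mid, S, K, T = map(jnp.asarray, (mid, S, K, T))
    forward_intrinsic = S * jnp.exp(-q * T) - K * jnp.exp(-r * T)
    lower = jnp.maximum(forward_intrinsic, 0.0)  # price limit as sigma -> 0
    upper = S * jnp.exp(-q * T)                  # price limit as sigma -> infinity
    return (T > 0) & (mid > lower) & (mid < upper)


# Hypothetical quotes (not market data): spot 100, T = 1 year
mids = jnp.array([6.04, 20.0, 0.0, 101.0])
strikes = jnp.array([110.0, 80.0, 120.0, 90.0])
print(call_has_iv(mids, 100.0, strikes, 1.0))
```

Applied to the chain slice, the columns would be passed as NumPy arrays (e.g. `sub['mid'].to_numpy()`) and the resulting mask converted back with `np.asarray` before indexing `sub`.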

### Wrap-Up
We demonstrated:
- Autodiff-based Vega matches analytic Vega closely
- Newton–Raphson recovers σ = 0.2 to machine precision in three iterations from an initial guess of 0.3
- Visual diagnostics (loss curve) help understand convergence
- Extension to a small live option chain to build an embryonic smile

Next steps (future work): batch vectorized IV surface, robustness improvements (fallback solvers), and performance benchmarking of autodiff vs analytic Greeks.
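For the batch vectorized IV surface, one possible sketch (an assumption about how it could be done, not the project's plan) is a fixed-iteration Newton loop written with `jax.lax.fori_loop`, which `jax.vmap` can batch across strikes and `jax.jit` can compile. A fixed iteration count avoids data-dependent Python control flow, at the cost of no early exit and no safeguard against divergence for quotes with near-zero Vega, which would still need the fallback solvers mentioned above. `bs_call` and `newton_iv_fixed` are hypothetical helpers, and the prices are synthetic, generated from made-up vols.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

jax.config.update("jax_enable_x64", True)  # assumption: float64 throughout


def bs_call(S, K, T, r, sigma, q=0.0):
    sqrt_t = jnp.sqrt(T)
    d1 = (jnp.log(S / K) + (r - q + 0.5 * sigma**2) * T) / (sigma * sqrt_t)
    d2 = d1 - sigma * sqrt_t
    return S * jnp.exp(-q * T) * norm.cdf(d1) - K * jnp.exp(-r * T) * norm.cdf(d2)


def newton_iv_fixed(S, K, T, r, price, sigma0=0.3, q=0.0, n_iter=20):
    """Hypothetical fixed-iteration Newton solver for a single call quote."""
    f = lambda s: bs_call(S, K, T, r, s, q) - price
    f_prime = jax.grad(f)  # Vega via autodiff

    def step(_, s):
        return s - f(s) / f_prime(s)  # Newton update, no early exit

    return jax.lax.fori_loop(0, n_iter, step, jnp.full_like(price, sigma0))


# Vectorize over (strike, price) pairs and compile once
batched_iv = jax.jit(jax.vmap(newton_iv_fixed, in_axes=(None, 0, None, None, 0)))

S, T, r = 100.0, 1.0, 0.05
strikes = jnp.array([90.0, 95.0, 100.0, 105.0, 110.0])
true_vols = jnp.array([0.26, 0.23, 0.21, 0.20, 0.20])  # made-up smile
prices = jax.vmap(bs_call, in_axes=(None, 0, None, None, 0))(S, strikes, T, r, true_vols)

print(batched_iv(S, strikes, T, r, prices))
```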
