In [None]:
import math
from dataclasses import dataclass
from typing import List, Tuple, Literal, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# from caas_jupyter_tools import display_dataframe_to_user


# Accepted values for VanillaOption.style.
ExerciseStyle = Literal["european", "american"]
# Accepted values for VanillaOption.type.
OptionType = Literal["call", "put"]


@dataclass
class Market:
    """Flat market inputs shared by the pricers (no dividend field).

    Fields, in constructor order:
        S0: spot price of the underlying.
        r: risk-free rate.
        sigma: annualized volatility.
        T: maturity, in years.
    """

    S0: float
    r: float
    sigma: float
    T: float

@dataclass
class VanillaOption:
    """Plain-vanilla option contract.

    Fields:
        K: strike.
        type: "call" or "put".
        style: "european" (default) or "american".

    Raises:
        ValueError: at construction, if ``type`` or ``style`` is not one of the values above.
    """

    K: float
    type: OptionType
    style: ExerciseStyle = "european"

    def __post_init__(self) -> None:
        # The pricers branch on `type == "call"` (anything else is treated as a put) and on
        # `style == "american"` (anything else is European), so a typo such as "Call" would
        # otherwise be priced silently as the wrong contract.
        if self.type not in ("call", "put"):
            raise ValueError(f"VanillaOption.type must be 'call' or 'put', got {self.type!r}")
        if self.style not in ("european", "american"):
            raise ValueError(
                f"VanillaOption.style must be 'european' or 'american', got {self.style!r}"
            )


def bs_price(mkt: Market, opt: VanillaOption) -> float:
    """Black-Scholes price of a European vanilla option, no dividends.

    Args:
        mkt: market data (``S0``, ``r``, ``sigma``, ``T``).
        opt: contract; only ``K`` and ``type`` are read. ``style`` is ignored, so an
            American option is priced as its European counterpart.

    Returns:
        The option premium: call if ``opt.type == "call"``, put otherwise.

    Degenerate inputs are handled in closed form:
      * ``T <= 0``: the option is at expiry, so the value is the intrinsic value at spot.
      * ``sigma <= 0`` with ``T > 0``: the underlying grows deterministically to
        ``S0 * exp(r*T)``, so the value is the discounted intrinsic of that forward,
        ``max(S0 - K*exp(-r*T), 0)`` for a call and ``max(K*exp(-r*T) - S0, 0)`` for a put.

    In the regular case ``math.log(S0 / K)`` requires ``S0 > 0`` and ``K > 0``.
    """
    S0, K, r, sigma, T = mkt.S0, opt.K, mkt.r, mkt.sigma, mkt.T
    is_call = opt.type == "call"

    if T <= 0:
        # Immediate expiry: undiscounted intrinsic value at spot.
        return max(0.0, S0 - K) if is_call else max(0.0, K - S0)

    discounted_K = K * math.exp(-r * T)
    if sigma <= 0:
        # Zero volatility: discounting the deterministic payoff at S0*exp(rT) gives
        # the intrinsic value against the discounted strike (not against K itself).
        return max(0.0, S0 - discounted_K) if is_call else max(0.0, discounted_K - S0)

    def norm_cdf(x: float) -> float:
        # erfc form avoids the cancellation in 1 + erf(x) for large negative x.
        return 0.5 * math.erfc(-x / math.sqrt(2.0))

    vol_sqrt_t = sigma * math.sqrt(T)
    d1 = (math.log(S0 / K) + (r + 0.5 * sigma ** 2) * T) / vol_sqrt_t
    d2 = d1 - vol_sqrt_t

    if is_call:
        return S0 * norm_cdf(d1) - discounted_K * norm_cdf(d2)
    return discounted_K * norm_cdf(-d2) - S0 * norm_cdf(-d1)


class TrinomialTree:
    """Recombining trinomial tree pricer for vanilla options, no dividends.

    Construction, with ``dt = T / N``:
      - the middle node of each layer is the forward ``S * exp(r * dt)`` of the
        previous layer's middle node;
      - adjacent nodes are spaced by the ratio ``alpha = exp(sigma * sqrt(3 * dt))``;
      - the up/mid/down probabilities are the closed-form no-dividend solution that
        matches the mean and second moment of the lognormal step, which gives
        ``p_up = p_down / alpha``.

    Layer ``i`` holds ``2*i + 1`` nodes indexed ``j = -i..i`` (array position ``j + i``).
    Node ``j`` at layer ``i`` branches to nodes ``j+1`` (up), ``j`` (mid) and ``j-1``
    (down) of layer ``i + 1``.
    """

    def __init__(self, mkt: Market, nb_steps: int):
        """Precompute step size, node spacing, branching probabilities and one-step discount.

        Args:
            mkt: market data; ``T`` and ``sigma`` must be strictly positive.
            nb_steps: number of time steps ``N`` (>= 1).

        Raises:
            ValueError: if ``nb_steps < 1``, ``T <= 0`` or ``sigma <= 0``, or if a
                branching probability falls outside [0, 1].
        """
        # Explicit raises rather than `assert`, which `python -O` strips.
        if nb_steps < 1:
            raise ValueError(f"nb_steps must be >= 1, got {nb_steps}")
        # T == 0 or sigma == 0 makes alpha == 1, which zeroes the probability
        # denominator below; T < 0 would take the sqrt of a negative number.
        if mkt.T <= 0 or mkt.sigma <= 0:
            raise ValueError(
                f"TrinomialTree needs T > 0 and sigma > 0 (got T={mkt.T}, sigma={mkt.sigma})"
            )

        self.mkt = mkt
        self.N = nb_steps
        self.dt = mkt.T / nb_steps
        self.alpha = math.exp(mkt.sigma * math.sqrt(3.0 * self.dt))

        # Risk-neutral probabilities (no-dividend closed form).
        a = self.alpha
        v = math.exp(mkt.sigma ** 2 * self.dt) - 1.0
        denom = (1.0 - a) * ((a ** -2) - 1.0)
        self.p_down = v / denom
        self.p_up = self.p_down / a
        self.p_mid = 1.0 - self.p_up - self.p_down

        # Reject genuinely invalid probabilities, tolerating ~1e-12 of rounding noise...
        for p in (self.p_up, self.p_mid, self.p_down):
            if p < -1e-12 or p > 1 + 1e-12:
                raise ValueError(f"Probability out of [0,1]: {p} (N={nb_steps})")

        # ...then clamp that noise so each probability lies exactly in [0, 1].
        self.p_up = max(0.0, min(1.0, self.p_up))
        self.p_mid = max(0.0, min(1.0, self.p_mid))
        self.p_down = max(0.0, min(1.0, self.p_down))

        self.df = math.exp(-mkt.r * self.dt)

    def build_price_grid(self) -> List[np.ndarray]:
        """Return the underlying prices, one float array per time step ``i = 0..N``.

        Layer ``i`` is ``S_mid_i * alpha**j`` for ``j = -i..i``, where ``S_mid_i`` is
        ``S0`` multiplied ``i`` times by the one-step growth factor ``exp(r * dt)``.
        """
        growth = math.exp(self.mkt.r * self.dt)
        s_mid = self.mkt.S0
        # Force float dtype: an integer S0 would otherwise yield an int64 root layer,
        # and anything shaped like it (e.g. via np.empty_like) would truncate prices.
        layers: List[np.ndarray] = [np.array([s_mid], dtype=float)]
        for i in range(1, self.N + 1):
            # The middle node moves to the forward of the previous middle node.
            s_mid = s_mid * growth
            layers.append(s_mid * self.alpha ** np.arange(-i, i + 1))
        return layers

    @staticmethod
    def _intrinsic(opt: VanillaOption, S: np.ndarray) -> np.ndarray:
        """Elementwise payoff: ``max(S - K, 0)`` for calls, ``max(K - S, 0)`` otherwise."""
        if opt.type == "call":
            return np.maximum(S - opt.K, 0.0)
        return np.maximum(opt.K - S, 0.0)

    def price(self, opt: VanillaOption) -> float:
        """Price ``opt`` at t = 0 by backward induction through the tree.

        The terminal value is the payoff on the last layer. For
        ``opt.style == "american"`` the continuation value is floored by the intrinsic
        value at every earlier layer, including t = 0; any other style is European.

        Returns:
            The option value at the root node.
        """
        S_layers = self.build_price_grid()
        V = self._intrinsic(opt, S_layers[-1])

        for i in reversed(range(self.N)):
            # Vectorised over the 2*i + 1 nodes of layer i: the node at array position k
            # reaches positions k+2 (up), k+1 (mid) and k (down) of layer i + 1.
            V = self.df * (self.p_up * V[2:] + self.p_mid * V[1:-1] + self.p_down * V[:-2])
            if opt.style == "american":
                V = np.maximum(V, self._intrinsic(opt, S_layers[i]))
        return float(V[0])


# demo: trinomial-tree prices vs the Black-Scholes closed form as the step count grows
mkt = Market(S0=100.0, r=0.02, sigma=0.25, T=1.0)
call = VanillaOption(K=100.0, type="call", style="european")
put  = VanillaOption(K=100.0, type="put",  style="european")

# The closed-form prices do not depend on the number of tree steps: compute them once.
bs_call = bs_price(mkt, call)
bs_put = bs_price(mkt, put)

for N in [5, 10, 25, 50, 100, 200, 2000]:
    tree = TrinomialTree(mkt, N)
    tri_call = tree.price(call)
    tri_put = tree.price(put)
    # Width 4 keeps the table aligned for every N in the list (up to 9999).
    print(f"N={N:4d}  Call: tri={tri_call:.6f}  BS={bs_call:.6f}  gap={tri_call-bs_call:+.6e} | "
          f"Put: tri={tri_put:.6f}  BS={bs_put:.6f}  gap={tri_put-bs_put:+.6e}")



N=  5  Call: tri=10.619955  BS=10.870558  gap=-2.506033e-01 | Put: tri=8.639823  BS=8.890426  gap=-2.506033e-01
N= 10  Call: tri=10.802193  BS=10.870558  gap=-6.836573e-02 | Put: tri=8.822060  BS=8.890426  gap=-6.836573e-02
N= 25  Call: tri=10.876008  BS=10.870558  gap=+5.449087e-03 | Put: tri=8.895875  BS=8.890426  gap=+5.449087e-03
N= 50  Call: tri=10.885986  BS=10.870558  gap=+1.542722e-02 | Put: tri=8.905853  BS=8.890426  gap=+1.542722e-02
N=100  Call: tri=10.882547  BS=10.870558  gap=+1.198899e-02 | Put: tri=8.902415  BS=8.890426  gap=+1.198899e-02
N=200  Call: tri=10.874959  BS=10.870558  gap=+4.400373e-03 | Put: tri=8.894826  BS=8.890426  gap=+4.400373e-03
N=2000  Call: tri=10.869786  BS=10.870558  gap=-7.723958e-04 | Put: tri=8.889653  BS=8.890426  gap=-7.723958e-04
