# In [None]:
"""
dxBTC Yield Model - Production-ready

Functions provided:
- stableswap D / swap approximations
- calculate_lp_fees_earned
- calculate_discount_recovery_apy
- calculate_gauge_base_rewards
- calculate_boost_multiplier
- calculate_yvfrBTC_yield
- calculate_dxBTC_APY
- simulate_full_system (runs one scenario)
- generate_parameter_space (helper)
- run_sensitivity_analysis (runs many scenarios, returns pandas DataFrame)

Usage examples are in `if __name__ == '__main__'`.

Notes:
- Units: BTC (or consistent asset units) for liquidity/volume. APYs are decimals (0.05 = 5%).
- Precision: uses float64; for on-chain parity use fixed-point arithmetic.
"""

from __future__ import annotations

from itertools import product
import math
from typing import Dict, Any, Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ----------------------------- Core financial helpers -----------------------------

def calculate_discount_recovery_apy(pLBTC_purchase_price: float, time_to_maturity_years: float) -> float:
    """
    Simple-interest APY from buying pLBTC below par and redeeming it at 1.0.

    APY = (1.0 - price) / time_to_maturity_years / price

    Returns 0.0 when ``time_to_maturity_years`` is zero or negative. A purchase
    price of exactly 0.0 raises ZeroDivisionError.
    """
    if time_to_maturity_years <= 0:
        return 0.0
    discount_to_par = 1.0 - pLBTC_purchase_price
    # Annualise the pull-to-par first, then express it relative to the capital paid.
    return float(discount_to_par / time_to_maturity_years / pLBTC_purchase_price)


def calculate_lp_fees_earned(
    LP_owned: float,
    total_LP: float,
    daily_volume: float,
    fee_rate: float = 0.0004,
    admin_fee: float = 0.5,
    lp_token_price_btc: float = 1.0,
) -> float:
    """
    Calculate the annualised swap-fee APY earned by an LP position.

    The position receives its pro-rata share (``LP_owned / total_LP``) of the
    LP-side daily fees (``daily_volume * fee_rate * (1 - admin_fee)``). That
    income is annualised over 365 days and divided by the position's value
    (``LP_owned * lp_token_price_btc``).

    Args:
        LP_owned: LP tokens held by the position.
        total_LP: Total LP tokens in the pool (presumably including ``LP_owned``).
        daily_volume: Daily swap volume, in the same unit as the LP value.
        fee_rate: Swap fee charged on volume (0.0004 = 4 bps).
        admin_fee: Fraction of swap fees that does not go to LPs.
        lp_token_price_btc: Value of one LP token. Defaults to 1.0 (LP token at
            par), which matches the original behaviour and the same parameter on
            ``calculate_gauge_base_rewards``.

    Returns:
        Fee APY as a decimal (0.05 = 5%). Returns 0.0 when the position, the pool
        or the LP token price is zero or negative.
    """
    # A non-positive position or pool has no meaningful fee yield; previously a
    # negative LP_owned produced a negative APY divided by a placeholder 1.0.
    if LP_owned <= 0 or total_LP <= 0 or lp_token_price_btc <= 0:
        return 0.0
    daily_fees_to_lp = daily_volume * fee_rate * (1.0 - admin_fee)
    lp_share = LP_owned / total_LP
    annual_fees = daily_fees_to_lp * lp_share * 365.0
    lp_value = LP_owned * lp_token_price_btc
    return float(annual_fees / lp_value)


def calculate_gauge_base_rewards(
    LP_staked: float,
    total_LP_staked: float,
    reward_token_per_day: float,
    reward_token_price_btc: float,
    lp_token_price_btc: float = 1.0,
) -> float:
    """
    Base (un-boosted) gauge reward APY for a staked LP position.

    The position earns its pro-rata share (``LP_staked / total_LP_staked``) of
    ``reward_token_per_day``, valued at ``reward_token_price_btc``. That income
    is annualised over 365 days and divided by the staked value
    (``LP_staked * lp_token_price_btc``).

    Args:
        LP_staked: LP tokens staked by the position.
        total_LP_staked: Total LP tokens staked in the gauge (presumably
            including ``LP_staked``).
        reward_token_per_day: Reward tokens emitted by the gauge per day.
        reward_token_price_btc: Price of one reward token in BTC.
        lp_token_price_btc: Price of one LP token in BTC.

    Returns:
        Reward APY as a decimal. Returns 0.0 when the stake, the gauge total or
        the LP price is zero or negative.
    """
    # LP_staked == 0 used to raise ZeroDivisionError, for example when
    # simulate_full_system passes a zero allocation.
    if LP_staked <= 0 or total_LP_staked <= 0 or lp_token_price_btc <= 0:
        return 0.0
    lp_share = LP_staked / total_LP_staked
    daily_rewards_token = reward_token_per_day * lp_share
    daily_rewards_btc = daily_rewards_token * reward_token_price_btc
    annual_rewards_btc = daily_rewards_btc * 365.0
    lp_value_btc = LP_staked * lp_token_price_btc
    return float(annual_rewards_btc / lp_value_btc)


def calculate_boost_multiplier(
    user_LP_staked: float,
    total_LP_staked: float,
    user_vxFROST: float,
    total_vxFROST: float,
    min_boost: float = 1.0,
    max_boost: float = 2.5,
) -> float:
    """
    Compute a gauge boost multiplier from the user's share of vxFROST.

    boost = min_boost + (user_vxFROST / total_vxFROST) * (max_boost - min_boost),
    clamped to the closed range [min_boost, max_boost].

    ``user_LP_staked`` and ``total_LP_staked`` are accepted for interface
    compatibility but are not used by this simplified linear model.

    Returns:
        The multiplier. Returns ``min_boost`` when ``total_vxFROST`` is zero or
        negative.
    """
    if total_vxFROST <= 0:
        return float(min_boost)
    vx_share = user_vxFROST / total_vxFROST
    boost = min_boost + vx_share * (max_boost - min_boost)
    # Clamp on both sides: a negative vxFROST input previously produced a
    # boost below min_boost, despite the documented clamp.
    return float(max(min_boost, min(boost, max_boost)))


# ----------------------------- Stableswap approximations -----------------------------

def calculate_D_two_asset(reserves: Tuple[float, float], A: float) -> float:
    """
    Solve the 2-asset stableswap invariant for D.

    Uses the same amplification convention as ``stableswap_get_y``: ``A`` plays
    the role of Curve's ``Ann`` in

        A * (x + y) + D = A * D + D**3 / (4 * x * y)

    The invariant is solved with Curve's Newton iteration, starting from
    D = x + y. That starting value is exact for a balanced pool and is what the
    previous approximation always returned.

    Special cases:
      * Non-positive ``A`` is treated as constant product, giving
        D = 2 * sqrt(x * y).
      * If either reserve is zero or negative, the invariant is undefined and
        x + y is returned.
    """
    x, y = reserves
    total = x + y
    if x <= 0 or y <= 0:
        return float(total)
    if A <= 0:
        return float(2.0 * math.sqrt(x * y))
    d = total
    for _ in range(255):
        # d_p = D**3 / (4 * x * y), accumulated stepwise as in Curve's get_D.
        d_p = d * d / (2.0 * x) * d / (2.0 * y)
        d_prev = d
        d = (A * total + 2.0 * d_p) * d / ((A - 1.0) * d + 3.0 * d_p)
        if abs(d - d_prev) <= 1e-12 * d:
            break
    return float(d)


def stableswap_get_y(x: float, A: float, D: float) -> float:
    """
    Solve the 2-asset stableswap invariant for the y reserve, given x and D.

    With ``A`` acting as Curve's ``Ann`` (the same convention as
    ``calculate_D_two_asset``), the invariant
    ``A*(x + y) + D = A*D + D**3 / (4*x*y)`` rearranges to the quadratic

        y**2 + (b - D) * y - c = 0,  where  b = x + D / A,  c = D**3 / (4 * A * x)

    The function returns its positive root. This is the closed form of the
    fixed point that Curve's ``get_y`` iterates towards.

    For ``A == 0`` the constant-product form y = D**2 / (4 * x) is used; it
    returns 0.0 when x == 0. When ``A > 0`` and ``x <= 0``, the degenerate value
    c = 0 is used, as in the original code.
    """
    if A == 0:
        # constant product approximation
        if x == 0:
            return 0.0
        return (D ** 2) / (4.0 * x)
    c = (D ** 3) / (4.0 * A * x) if x > 0 else 0.0
    k = x + D / A - D  # (b - D) from the quadratic above
    root = math.sqrt(k * k + 4.0 * c)
    if k >= 0:
        # Use the algebraically equivalent form 2c / (k + root) to avoid
        # cancellation in (root - k) when k is large and positive.
        denom = k + root
        return float(2.0 * c / denom) if denom > 0 else 0.0
    return float((root - k) / 2.0)


def stableswap_swap(amount_in: float, reserve_in: float, reserve_out: float, A: float, fee_rate: float = 0.0004) -> float:
    """
    Perform an approximate swap on a 2-asset stableswap-style pool.

    The fee is taken from the input amount (Curve charges it on the output, so
    this is an approximation). D is computed from the pre-swap reserves, and the
    output reserve is re-solved for the new input reserve.

    Args:
        amount_in: Amount of the input asset sold into the pool.
        reserve_in: Pool reserve of the input asset before the swap.
        reserve_out: Pool reserve of the output asset before the swap.
        A: Amplification coefficient, passed through to the D and y solvers.
        fee_rate: Fraction of ``amount_in`` withheld as the swap fee.

    Returns:
        The amount of the output asset received, never negative. Returns 0.0 for
        a zero or negative ``amount_in``, or for an empty pool side.
    """
    # Selling nothing (or a negative amount) must not yield output; also avoids
    # float noise from re-solving y at an unchanged x.
    if amount_in <= 0 or reserve_in <= 0 or reserve_out <= 0:
        return 0.0
    D = calculate_D_two_asset((reserve_in, reserve_out), A)
    amount_in_after_fee = amount_in * (1.0 - fee_rate)
    new_reserve_in = reserve_in + amount_in_after_fee
    new_reserve_out = stableswap_get_y(new_reserve_in, A, D)
    # Clamp so that approximation error in D or y never yields a negative output.
    return max(0.0, float(reserve_out - new_reserve_out))


# ----------------------------- Higher-level strategies -----------------------------

def calculate_pLBTC_LP_strategy_apy(
    frbtc_allocated: float,
    pLBTC_frBTC_pool_tvl: float,
    daily_volume_pLBTC_pool: float,
    pLBTC_discount: float = 0.04,
    time_to_maturity_years: float = 1.0,
    fee_rate: float = 0.0004,
    admin_fee: float = 0.5,
) -> Dict[str, float]:
    """
    APY breakdown for a strategy that provides pLBTC/frBTC liquidity.

    The allocated frBTC is treated as LP tokens valued at par. The fee APY comes
    from ``calculate_lp_fees_earned``. The discount APY comes from
    ``calculate_discount_recovery_apy`` at a purchase price of
    ``1 - pLBTC_discount``.

    NOTE(review): ``total_apy`` adds the full discount-recovery APY to the whole
    LP position, even though only part of an LP position is held as pLBTC.
    Confirm this is intended.

    Returns:
        dict with ``fee_apy``, ``discount_apy`` and ``total_apy`` (decimals).
        ``pLBTC_discount == 1`` (purchase price 0) raises ZeroDivisionError in
        the discount calculation.
    """
    # Pass the pool TVL through unchanged. calculate_lp_fees_earned already
    # returns 0 for an empty pool; the previous substitution of 1.0 made a
    # zero-TVL pool look like a 1-unit pool and inflated fee_apy.
    fee_apy = calculate_lp_fees_earned(
        frbtc_allocated,
        pLBTC_frBTC_pool_tvl,
        daily_volume_pLBTC_pool,
        fee_rate,
        admin_fee,
    )
    pLBTC_price = 1.0 - pLBTC_discount
    discount_apy = calculate_discount_recovery_apy(pLBTC_price, time_to_maturity_years)
    total_apy = fee_apy + discount_apy
    return {
        "fee_apy": float(fee_apy),
        "discount_apy": float(discount_apy),
        "total_apy": float(total_apy),
    }


def calculate_gauge_strategy_apy(
    frbtc_allocated: float,
    gauge_pool: Dict[str, Any],
    user_vxFROST_locked: float = 0.0,
) -> float:
    """
    APY from staking ``frbtc_allocated`` LP in a gauge, optionally boosted.

    ``gauge_pool`` keys read, with defaults used when missing:
      total_lp (1.0), daily_rewards (0.0), reward_price (0.0), total_vxFROST.
    LP tokens are valued at 1.0. The vxFROST boost is applied only when
    ``user_vxFROST_locked`` is truthy and the pool's ``total_vxFROST`` is
    positive.
    """
    total_lp = gauge_pool.get("total_lp", 1.0)
    pool_vxFROST = gauge_pool.get("total_vxFROST", 0.0)

    unboosted = calculate_gauge_base_rewards(
        LP_staked=frbtc_allocated,
        total_LP_staked=total_lp,
        reward_token_per_day=gauge_pool.get("daily_rewards", 0.0),
        reward_token_price_btc=gauge_pool.get("reward_price", 0.0),
        lp_token_price_btc=1.0,
    )

    # Guard clause: no lock, or no vxFROST in the pool, means no boost.
    if not user_vxFROST_locked or not pool_vxFROST > 0.0:
        return float(unboosted)

    multiplier = calculate_boost_multiplier(
        frbtc_allocated,
        total_lp,
        user_vxFROST_locked,
        pool_vxFROST,
    )
    return float(unboosted * multiplier)


def calculate_yvfrBTC_yield(total_frbtc: float, strategy_allocations: Dict[str, float], strategy_apys: Dict[str, float]) -> float:
    """
    Allocation-weighted APY of the yvfrBTC vault across its strategies.

    Each strategy contributes ``fraction * apy``. A strategy with no entry in
    ``strategy_apys`` is counted with an APY of 0.0. ``total_frbtc`` is accepted
    but not used.

    strategy_allocations: strategy name -> fraction of the vault (sum <= 1)
    strategy_apys: strategy name -> APY (decimal)
    """
    contributions = (
        fraction * float(strategy_apys.get(name, 0.0))
        for name, fraction in strategy_allocations.items()
    )
    return float(sum(contributions, 0.0))


def calculate_dxBTC_APY(dxBTC_tvl: float, yvfrBTC_strategy_apy: float, utilization_rate: float) -> float:
    """
    dxBTC APY: the yvfrBTC APY scaled by the utilization rate.

    ``dxBTC_tvl`` is accepted but does not affect the result.
    """
    passthrough_apy = yvfrBTC_strategy_apy * utilization_rate
    return float(passthrough_apy)


# ----------------------------- Simulation runner -----------------------------

def simulate_full_system(params: Dict[str, Any]) -> Dict[str, Any]:
    """
    Simulate the full dxBTC / yvfrBTC yield stack for a single parameter set.

    Steps:
      1. pLBTC/frBTC LP strategy APY (swap fees plus discount recovery).
      2. Per-gauge staking APY, optionally boosted by ``user_vxFROST``.
      3. Allocation-weighted yvfrBTC APY (the idle reserve earns 0).
      4. dxBTC base APY = yvfrBTC APY * utilization_rate.
      5. Deflation boost from burned futures premiums, added to the base APY.

    Params keys read (all optional; defaults in parentheses):
      dxBTC_tvl (0.0), yvfrBTC_tvl (falls back to dxBTC_tvl), utilization_rate (0.7),
      yvfrBTC_allocation_pLBTC (0.4), yvfrBTC_allocation_reserve (0.0),
      yvfrBTC_allocation_<gauge_name> for each gauge, falling back to
      yvfrBTC_allocation_DIESEL_gauge (0.0),
      pLBTC_pool_tvl (1000.0), pLBTC_daily_volume (10.0), pLBTC_discount (0.04),
      pLBTC_time_to_maturity (1.0), fee_rate (0.0004), admin_fee (0.5),
      gauges ({}): gauge name -> pool dict (total_lp, daily_rewards,
      reward_price, total_vxFROST),
      user_vxFROST (0.0), futures_utilization_rate (0.0), approx_prem_pct (0.03).

    Returns:
      dict with yvfrBTC_apy, dxBTC_base_apy, deflation_boost,
      effective_dxBTC_apy, and a ``breakdown`` sub-dict containing pLBTC_lp,
      gauges, strategy_allocations and strategy_apys.
    """
    dx_tvl = params.get("dxBTC_tvl", 0.0)

    # 1. pLBTC LP strategy
    frbtc_tvl = float(params.get("yvfrBTC_tvl", dx_tvl))
    pLBTC_alloc_fraction = float(params.get("yvfrBTC_allocation_pLBTC", 0.4))
    pLBTC_alloc_amount = pLBTC_alloc_fraction * frbtc_tvl

    pLBTC_lp_apy = calculate_pLBTC_LP_strategy_apy(
        frbtc_allocated=pLBTC_alloc_amount,
        pLBTC_frBTC_pool_tvl=params.get("pLBTC_pool_tvl", 1_000.0),
        daily_volume_pLBTC_pool=params.get("pLBTC_daily_volume", 10.0),
        pLBTC_discount=params.get("pLBTC_discount", 0.04),
        time_to_maturity_years=params.get("pLBTC_time_to_maturity", 1.0),
        fee_rate=params.get("fee_rate", 0.0004),
        admin_fee=params.get("admin_fee", 0.5),
    )

    # 2. Gauges. Resolve each gauge's allocation once, so that the APY
    # calculation and the weighting in step 3 always agree. Previously they
    # used different fallback defaults (0.3 and 0.0).
    gauges = params.get("gauges", {})
    legacy_gauge_alloc = params.get("yvfrBTC_allocation_DIESEL_gauge", 0.0)
    user_vxFROST = params.get("user_vxFROST", 0.0)
    gauge_allocations: Dict[str, Any] = {}
    gauge_apys: Dict[str, float] = {}
    for gauge_name, gauge_params in gauges.items():
        # NOTE(review): the legacy DIESEL key is used as the fallback for every
        # gauge without its own key. With several such gauges, allocations can
        # sum to more than 1. Confirm whether this is intended.
        alloc_fraction = params.get(f"yvfrBTC_allocation_{gauge_name}", legacy_gauge_alloc)
        gauge_allocations[gauge_name] = alloc_fraction
        frbtc_alloc = alloc_fraction * frbtc_tvl
        if frbtc_alloc > 0:
            gauge_apys[gauge_name] = calculate_gauge_strategy_apy(frbtc_alloc, gauge_params, user_vxFROST)
        else:
            # An unfunded gauge carries zero weight. Skipping it also avoids a
            # zero-stake division in the reward calculation.
            gauge_apys[gauge_name] = 0.0

    # 3. Weighted yvfrBTC APY (the idle reserve earns nothing)
    strategy_allocations = {
        "pLBTC_frBTC_LP": pLBTC_alloc_fraction,
        **gauge_allocations,
        "idle_reserve": params.get("yvfrBTC_allocation_reserve", 0.0),
    }
    strategy_apys = {
        "pLBTC_frBTC_LP": pLBTC_lp_apy["total_apy"],
        **gauge_apys,
        "idle_reserve": 0.0,
    }

    yvfrBTC_apy = calculate_yvfrBTC_yield(frbtc_tvl, strategy_allocations, strategy_apys)

    # 4. dxBTC APY
    dxBTC_apy = calculate_dxBTC_APY(dx_tvl, yvfrBTC_apy, params.get("utilization_rate", 0.7))

    # 5. ftrBTC dynamics: premiums burned as a fraction of dxBTC TVL (simple approximation)
    futures_util = params.get("futures_utilization_rate", 0.0)
    total_premiums_burned = futures_util * dx_tvl * params.get("approx_prem_pct", 0.03)
    deflation_boost = (total_premiums_burned / dx_tvl) if dx_tvl > 0 else 0.0

    effective_dxBTC_apy = dxBTC_apy + deflation_boost

    return {
        "yvfrBTC_apy": yvfrBTC_apy,
        "dxBTC_base_apy": dxBTC_apy,
        "deflation_boost": deflation_boost,
        "effective_dxBTC_apy": effective_dxBTC_apy,
        "breakdown": {
            "pLBTC_lp": pLBTC_lp_apy,
            "gauges": gauge_apys,
            "strategy_allocations": strategy_allocations,
            "strategy_apys": strategy_apys,
        },
    }


# ----------------------------- Utilities for batch runs -----------------------------

def generate_parameter_space(param_ranges: Dict[str, list]) -> list:
    """
    Expand ``param_ranges`` into one dict per point of their cartesian product.

    Keys keep the insertion order of ``param_ranges``. An empty mapping yields a
    single empty dict. Useful for sensitivity analysis.
    """
    names = list(param_ranges)
    axes = (param_ranges[name] for name in names)
    return [dict(zip(names, point)) for point in product(*axes)]


def run_sensitivity_analysis(base_params: Dict[str, Any], param_grid: Dict[str, list]) -> pd.DataFrame:
    """
    Run simulate_full_system for every combination in ``param_grid``.

    Each scenario is ``base_params`` overlaid with one grid point; it is a
    shallow copy, so nested dicts are shared. Each output row contains every
    scenario parameter, followed by effective_dxBTC_apy, dxBTC_base_apy and
    yvfrBTC_apy.
    """
    grid_keys = list(param_grid)
    records = []
    for point in product(*(param_grid[key] for key in grid_keys)):
        scenario = {**base_params, **dict(zip(grid_keys, point))}
        outcome = simulate_full_system(scenario)
        record = dict(scenario)
        record["effective_dxBTC_apy"] = outcome["effective_dxBTC_apy"]
        record["dxBTC_base_apy"] = outcome["dxBTC_base_apy"]
        record["yvfrBTC_apy"] = outcome["yvfrBTC_apy"]
        records.append(record)
    return pd.DataFrame(records)


# ----------------------------- Simple plotting helpers -----------------------------

def plot_apy_vs_volume(df: pd.DataFrame, volume_col: str = "pLBTC_daily_volume") -> None:
    """
    Line-plot effective dxBTC APY against ``volume_col`` and show the figure.

    ``df`` must contain ``volume_col`` and an ``effective_dxBTC_apy`` column, as
    produced by run_sensitivity_analysis. Rows are sorted by ``volume_col`` so
    the line is drawn in increasing order. The x-axis label is always
    "Daily Volume (BTC)".
    """
    _, ax = plt.subplots(figsize=(8, 5))
    ordered = df.sort_values(volume_col)
    ax.plot(ordered[volume_col], ordered["effective_dxBTC_apy"])
    ax.set_xlabel("Daily Volume (BTC)")
    ax.set_ylabel("Effective dxBTC APY")
    ax.set_title("APY vs Daily pLBTC/frBTC Volume")
    ax.grid(True)
    plt.show()


# ----------------------------- Example driver -----------------------------
if __name__ == "__main__":
    # Example default scenario: 1,000 BTC of TVL, 70% utilisation, and a
    # 40 / 30 / 30 split between the pLBTC LP, the DIESEL gauge and the idle reserve.
    diesel_gauge_pool = {
        "total_lp": 1000.0,
        "daily_rewards": 500.0,
        "reward_price": 0.5,
        "total_vxFROST": 50000.0,
    }
    base_params = dict(
        dxBTC_tvl=1000.0,
        yvfrBTC_tvl=1000.0,
        utilization_rate=0.70,
        yvfrBTC_allocation_pLBTC=0.40,
        yvfrBTC_allocation_DIESEL_gauge=0.30,
        yvfrBTC_allocation_reserve=0.30,
        pLBTC_pool_tvl=5000.0,
        pLBTC_daily_volume=50.0,
        pLBTC_discount=0.04,
        pLBTC_time_to_maturity=1.0,
        gauges={"DIESEL_frBTC": diesel_gauge_pool},
        user_vxFROST=0.0,
        futures_utilization_rate=0.10,
        approx_prem_pct=0.03,
    )

    result = simulate_full_system(base_params)
    print("Simulation result (base case):")
    print(result)

    # Sweep the pLBTC/frBTC pool's daily volume and plot the resulting APY curve.
    volume_sweep = [10.0, 50.0, 100.0, 500.0]
    param_grid = {"pLBTC_daily_volume": volume_sweep}
    df_res = run_sensitivity_analysis(base_params, param_grid)
    plot_apy_vs_volume(df_res, "pLBTC_daily_volume")

import plotly.graph_objects as go
import numpy as np

# --- Base parameters for the interactive pricing demo ---
volume_range = np.linspace(1_000, 100_000, 50)  # x-axis: 50 evenly spaced volumes
lp_size_default = 50_000
discount_default = 0.05
maturity_default = 12  # months (the slider below is labelled in months)
utilization_default = 0.5

# --- Pricing model (dummy for demo, plug your formula here) ---
def price_curve(volume, lp_size, discount, maturity, utilization):
    """
    Placeholder pricing function for the demo; replace it with the real model.

    Returns ``volume / lp_size`` scaled by
    ``(1 - discount) * (maturity / 12) * (1 + utilization)``. ``maturity`` is in
    months. The function works element-wise when ``volume`` is a NumPy array.
    """
    volume_per_lp = volume / lp_size
    tenor_years = maturity / 12
    scale = (1 - discount) * tenor_years * (1 + utilization)
    return volume_per_lp * scale

# --- Initial curve (every parameter at its default) ---
# NOTE(review): this cell runs at import time and calls fig.show(). Consider
# moving it under `if __name__ == "__main__":` if the module is ever imported.
initial_y = price_curve(
    volume_range,
    lp_size_default,
    discount_default,
    maturity_default,
    utilization_default,
)

fig = go.Figure()

fig.add_trace(
    go.Scatter(
        x=volume_range,
        y=initial_y,
        mode="lines",
        name="Price Curve",
    )
)

# Default for each price_curve parameter after `volume`, in positional call order.
_slider_defaults = {
    "lp_size": lp_size_default,
    "discount": discount_default,
    "maturity": maturity_default,
    "utilization": utilization_default,
}


def _make_slider(param, values, prefix, title_fmt, pad_top):
    """Build one Plotly slider that sweeps ``param`` over ``values``.

    Each step recomputes the curve with ``param`` set to the step value and
    every other parameter at its default. The sliders are therefore independent:
    moving one discards the effect of the others.

    The slider starts on the step that matches the parameter's default (or the
    first step if the default is not listed). This keeps the initial handle
    positions consistent with ``initial_y``; previously every slider started at
    step 0.
    """
    steps = []
    for value in values:
        curve_args = {**_slider_defaults, param: value}
        steps.append(
            dict(
                label=f"{value}",
                method="update",
                args=[
                    {"y": [price_curve(volume_range, *curve_args.values())]},
                    {"title": title_fmt.format(value)},
                ],
            )
        )
    default = _slider_defaults[param]
    active = values.index(default) if default in values else 0
    return dict(
        active=active,
        currentvalue={"prefix": prefix},
        pad={"t": pad_top},
        steps=steps,
    )


# --- Sliders ---
fig.update_layout(
    title="Interactive Pricing Model",
    xaxis_title="Volume",
    yaxis_title="Price",
    sliders=[
        _make_slider("lp_size", [10000, 30000, 50000, 70000, 100000], "LP Size: ", "LP Size = {}", 50),
        _make_slider("discount", [0.0, 0.05, 0.1, 0.15, 0.2], "Discount: ", "Discount = {}", 120),
        _make_slider("maturity", [3, 6, 12, 18, 24], "Maturity (months): ", "Maturity = {} months", 190),
        _make_slider("utilization", [0.1, 0.3, 0.5, 0.7, 1.0], "Utilization: ", "Utilization = {}", 260),
    ],
)

fig.show()

