# degenerate_diffusion を用いた閉形式推定ノートブック
このノートでは、旧 `degenerate_diffusion_general.ipynb` の閉形式推定ロジックを、`degenerate_diffusion` パッケージに実装されたコンポーネントを利用して再構成します。

## 1. ディレクトリとテンプレートの準備
出力を整理するために `notebooks/closed_form` ディレクトリとプレースホルダーのテンプレート notebook を用意します。

In [None]:
from pathlib import Path
import nbformat

# Output directory for the closed-form notebooks; created if it is missing.
NOTEBOOK_DIR = Path("notebooks") / "closed_form"
NOTEBOOK_DIR.mkdir(parents=True, exist_ok=True)

# Write an empty placeholder notebook once; an existing file is left untouched.
TEMPLATE_PATH = NOTEBOOK_DIR.joinpath("template_placeholder.ipynb")
if not TEMPLATE_PATH.exists():
    template_nb = nbformat.v4.new_notebook()
    template_nb["cells"] = []
    TEMPLATE_PATH.write_text(
        nbformat.writes(template_nb),
        encoding="utf-8",
    )

# Last expression of the cell: echoes the placeholder path.
TEMPLATE_PATH

## 2. インポートと基本設定
数値計算や可視化、`degenerate_diffusion` の主要クラスを読み込み、JAX を倍精度モードに切り替えます。

In [None]:
import math
import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns
import japanize_matplotlib  # noqa: F401
from tqdm.auto import tqdm

from sympy import Array, symbols

from degenerate_diffusion.processes.degenerate_diffusion_process_jax import DegenerateDiffusionProcess
from degenerate_diffusion.evaluation.likelihood_evaluator_jax import LikelihoodEvaluator
from degenerate_diffusion.estimation.parameter_estimator import newton_solve
from degenerate_diffusion.estimation.loop_estimation_algorithm import LoopEstimationAlgorithm

# Enable JAX double-precision (x64) mode.
jax.config.update("jax_enable_x64", True)
# Use seaborn's "whitegrid" theme for the plots in this notebook.
sns.set_theme(style="whitegrid")

## 3. モデル記述とシミュレーション設定
SymPy でモデルを定義し、`DegenerateDiffusionProcess` と `LikelihoodEvaluator` を構築、さらに推定で用いる真のパラメータや初期値・制約をまとめます。

In [None]:
# Symbolic state variables and parameter groups (SymPy), each wrapped in an Array.
x_sym = Array([symbols("x")])
y_sym = Array([symbols("y")])
theta1_sym = Array([symbols("theta10")])
theta2_sym = Array(symbols("theta20 theta21"))
theta3_sym = Array(symbols("theta30 theta31"))

# Coefficient expressions handed to DegenerateDiffusionProcess:
#   A = theta20 * y + theta21 - x
#   B = theta10                      (1 x 1)
#   H = (theta31 - x - y**3 + y) / theta30
# NOTE(review): presumably A/B are the drift/diffusion of x and H the drift of y —
# confirm against the DegenerateDiffusionProcess documentation.
A_expr = Array([theta2_sym[0] * y_sym[0] + theta2_sym[1] - x_sym[0]])
B_expr = Array([[theta1_sym[0]]])
H_expr = Array([(theta3_sym[1] - x_sym[0] - y_sym[0] ** 3 + y_sym[0]) / theta3_sym[0]])

# Process object used for simulation, and the evaluator built on top of it that
# supplies the quasi-likelihood evaluators used by the update functions.
process = DegenerateDiffusionProcess(
    x=x_sym,
    y=y_sym,
    theta_1=theta1_sym,
    theta_2=theta2_sym,
    theta_3=theta3_sym,
    A=A_expr,
    B=B_expr,
    H=H_expr,
 )
likelihood = LikelihoodEvaluator(process)

# Settings shared by the simulation and estimation cells.
sim_config = {
    "t_max": 100.0,  # forwarded to process.simulate as t_max
    "burn_in": 20.0,  # forwarded to process.simulate as burn_out
    "dt": 0.001,  # forwarded to process.simulate as dt
    "h": 0.05,  # forwarded to process.simulate as h; also the time-grid spacing and the h given to the estimators
    "num_seeds": 40,  # Monte Carlo seeds 0 .. num_seeds - 1
    "k0_max": 3,  # estimation stages run for k = 1 .. k0_max
}

# Parameter values (theta_1, theta_2, theta_3) used to simulate the data.
true_theta = (
    jnp.array([0.7]),
    jnp.array([0.4, 0.6]),
    jnp.array([1.2, 0.1]),
)
# Starting values (theta_1, theta_2, theta_3) for the estimation loop.
initial_theta = (
    jnp.array([0.9]),
    jnp.array([0.3, 0.5]),
    jnp.array([1.0, 0.0]),
 )

# (lower, upper) bound arrays per parameter group; later converted to
# per-component (low, high) pairs and passed to newton_solve.
parameter_bounds = {
    "theta1": (jnp.array([0.1]), jnp.array([2.0])),
    "theta2": (jnp.array([-1.5, -1.5]), jnp.array([1.5, 1.5])),
    "theta3": (jnp.array([0.2, -5.0]), jnp.array([5.0, 5.0])),
}
# Last expression of the cell: echoes the configuration.
sim_config

## 4. `degenerate_diffusion` からのサンプラー呼び出し
`DegenerateDiffusionProcess.simulate` を使って観測系列 $(x_t, y_t)$ を生成するユーティリティを定義します。

In [None]:
def simulate_observations(process: DegenerateDiffusionProcess, theta: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], config: dict, seed: int):
    """Simulate one observed path and return it as flat arrays.

    Calls ``process.simulate`` with ``theta`` as ``true_theta`` and the
    ``t_max``, ``h``, ``burn_in`` (passed as ``burn_out``) and ``dt`` entries
    of ``config``, using ``seed`` as the random seed.

    Returns:
        ``(time_grid, x_flat, y_flat)`` where ``time_grid[i] == i * config["h"]``
        has one entry per row of the simulated x series, and both series are
        reshaped to one dimension.
    """
    simulated = process.simulate(
        true_theta=theta,
        t_max=config["t_max"],
        h=config["h"],
        burn_out=config["burn_in"],
        dt=config["dt"],
        seed=seed,
    )
    x_series, y_series = simulated
    n_obs = x_series.shape[0]
    time_grid = jnp.arange(n_obs) * config["h"]
    # Flatten with reshape(-1) so callers get plain 1-D series
    # (the model defined in this notebook has scalar x and y).
    x_flat, y_flat = (jnp.asarray(series).reshape(-1) for series in (x_series, y_series))
    return time_grid, x_flat, y_flat


# Draw one path with seed 0 and preview its first rows as a table.
time_sample, x_sample, y_sample = simulate_observations(
    process, true_theta, sim_config, seed=0
)
pd.DataFrame(
    {
        label: np.asarray(series)
        for label, series in (("t", time_sample), ("x", x_sample), ("y", y_sample))
    }
).head()

## 5. 演算子 $L_0, \tilde{L}_0, L_1, L_2, L_3$ の関数化
`LikelihoodEvaluator` が内部で構築するシンボリック演算子を、数値計算で再利用できるように関数テーブルとしてまとめます。なお、このセルで実際にテーブル化するのは、$L_0$ を `x`・`y`・`H` に各次数 $m$ で適用した関数のみです。

In [None]:
from functools import lru_cache

# Default highest L_0 order to tabulate, derived from k0_max.
# NOTE(review): the "+ 3" margin is not explained in this notebook — confirm
# how many orders the estimators actually need.
max_operator_order = sim_config["k0_max"] + 3

@lru_cache(maxsize=None)
def get_L0_tables(max_order: int = max_operator_order):
    """Return cached tuples of ``L_0`` functions for ``x``, ``y`` and ``H``.

    For every key the tuple holds ``likelihood.generator.L_0(target, m).func``
    for ``m = 0, 1, ...``. The ``"x"`` tuple covers ``m`` up to ``max_order``;
    the ``"y"`` and ``"H"`` tuples go one order further (up to
    ``max_order + 1``). Results are memoized per ``max_order`` value.
    """
    apply_L0 = likelihood.generator.L_0

    def _tabulate(target, count: int):
        # One callable per order m = 0 .. count - 1.
        return tuple(apply_L0(target, m).func for m in range(count))

    return {
        "x": _tabulate(process.x, max_order + 1),
        "y": _tabulate(process.y, max_order + 2),
        "H": _tabulate(process.H.expr[0], max_order + 2),
    }

def evaluate_L0_component(
    table_key: str,
    order: int,
    x_val: jnp.ndarray,
    y_val: jnp.ndarray,
    theta: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],
):
    """Evaluate one tabulated ``L_0`` function at a point.

    Looks up ``get_L0_tables()[table_key][order]`` (default table size) and
    calls it with ``(x_val, y_val, theta_1, theta_2, theta_3)``, where the
    three parameter groups are unpacked from ``theta``.
    """
    l0_tables = get_L0_tables()
    theta_1_val, theta_2_val, theta_3_val = theta
    selected = l0_tables[table_key][order]
    return selected(x_val, y_val, theta_1_val, theta_2_val, theta_3_val)

# Check L0[H] (order 0) at a single sample data point, using the true parameters.
sample_index = 5
# Wrap the scalar observations in length-1 arrays before evaluating.
x_point = jnp.asarray([x_sample[sample_index]])
y_point = jnp.asarray([y_sample[sample_index]])
evaluate_L0_component("H", 0, x_point, y_point, true_theta)

## 6. 更新関数の再実装
`LikelihoodEvaluator` が提供する疑似尤度関数と `newton_solve` を組み合わせ、段階ごとの推定器更新関数 `update_1`, `update_2`, `update_3`, `update_1_prime` を実装します。

In [None]:
def _bounds_to_tuples(bounds_array: tuple[jnp.ndarray, jnp.ndarray]):
    lower, upper = bounds_array
    lower_np = np.asarray(lower, dtype=float)
    upper_np = np.asarray(upper, dtype=float)
    return [(low, high) for low, high in zip(lower_np, upper_np)]

# Per-component (low, high) pairs for each parameter group.
theta1_bounds, theta2_bounds, theta3_bounds = (
    _bounds_to_tuples(parameter_bounds[group_key])
    for group_key in ("theta1", "theta2", "theta3")
)

# Keyword arguments forwarded unchanged to every newton_solve call.
newton_options = dict(
    max_iters=2000,
    tol=1e-7,
    damping=0.3,
    log_interval=None,
)

def _solve_stage(evaluator, theta_bar, component, bounds, initial_guess):
    """Shared solver step used by all ``update_*`` functions.

    Args:
        evaluator: Quasi-likelihood evaluator called as
            ``evaluator(theta_val, theta_1_bar, theta_2_bar, theta_3_bar)``.
        theta_bar: Current ``(theta_1, theta_2, theta_3)`` estimates.
        component: Index (0, 1 or 2) of the group being updated; that entry of
            ``theta_bar`` is the starting point when ``initial_guess`` is None.
        bounds: Per-component ``(low, high)`` pairs passed to ``newton_solve``.
        initial_guess: Optional explicit starting point.

    Returns:
        The ``newton_solve`` solution converted to a JAX array.
    """
    theta_1_bar, theta_2_bar, theta_3_bar = theta_bar
    fallback = (theta_1_bar, theta_2_bar, theta_3_bar)[component]
    start = np.asarray(initial_guess if initial_guess is not None else fallback)

    # A named function instead of a lambda bound to a name (PEP 8).
    def objective(theta_val):
        return evaluator(jnp.asarray(theta_val), theta_1_bar, theta_2_bar, theta_3_bar)

    solution = newton_solve(objective, bounds, start, **newton_options)
    return jnp.asarray(solution)


def update_1(theta_bar: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], x_series: jnp.ndarray, y_series: jnp.ndarray, *, h: float, k: int, initial_guess: jnp.ndarray | None = None, my_setting: bool = True) -> jnp.ndarray:
    """Update ``theta_1`` with the l1 quasi-likelihood evaluator.

    Builds the evaluator from ``(x_series, y_series, h, k)`` (forwarding
    ``my_setting``) and solves within ``theta1_bounds``, starting from
    ``initial_guess`` or, if omitted, from ``theta_bar[0]``.
    """
    l1_eval = likelihood.make_quasi_likelihood_l1_evaluator(x_series, y_series, h, k, my_setting=my_setting)
    return _solve_stage(l1_eval, theta_bar, 0, theta1_bounds, initial_guess)


def update_1_prime(theta_bar: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], x_series: jnp.ndarray, y_series: jnp.ndarray, *, h: float, k: int, initial_guess: jnp.ndarray | None = None) -> jnp.ndarray:
    """Update ``theta_1`` with the l1-prime quasi-likelihood evaluator.

    Solves within ``theta1_bounds``, starting from ``initial_guess`` or, if
    omitted, from ``theta_bar[0]``.
    """
    l1p_eval = likelihood.make_quasi_likelihood_l1_prime_evaluator(x_series, y_series, h, k)
    return _solve_stage(l1p_eval, theta_bar, 0, theta1_bounds, initial_guess)


def update_2(theta_bar: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], x_series: jnp.ndarray, y_series: jnp.ndarray, *, h: float, k: int, initial_guess: jnp.ndarray | None = None) -> jnp.ndarray:
    """Update ``theta_2`` with the l2 quasi-likelihood evaluator.

    Solves within ``theta2_bounds``, starting from ``initial_guess`` or, if
    omitted, from ``theta_bar[1]``.
    """
    l2_eval = likelihood.make_quasi_likelihood_l2_evaluator(x_series, y_series, h, k)
    return _solve_stage(l2_eval, theta_bar, 1, theta2_bounds, initial_guess)


def update_3(theta_bar: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], x_series: jnp.ndarray, y_series: jnp.ndarray, *, h: float, k: int, initial_guess: jnp.ndarray | None = None) -> jnp.ndarray:
    """Update ``theta_3`` with the l3 quasi-likelihood evaluator.

    Solves within ``theta3_bounds``, starting from ``initial_guess`` or, if
    omitted, from ``theta_bar[2]``.
    """
    l3_eval = likelihood.make_quasi_likelihood_l3_evaluator(x_series, y_series, h, k)
    return _solve_stage(l3_eval, theta_bar, 2, theta3_bounds, initial_guess)

# Smoke test: a single theta_1 update from the initial values on the sample path (k = 2).
update_1(
    tuple(initial_theta[index] for index in range(3)),
    x_sample,
    y_sample,
    h=sim_config["h"],
    k=2,
)

## 7. メイン推定ループと $k_0$ 探索
Monte Carlo サンプルごとに疑似尤度を段階的に最大化し、未スケール推定量と正規化済み推定量を集計します。

In [None]:
def _flatten_theta(theta_tuple: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]) -> dict[str, float]:
    theta_1_val, theta_2_val, theta_3_val = theta_tuple
    return {
        "theta10": float(theta_1_val[0]),
        "theta20": float(theta_2_val[0]),
        "theta21": float(theta_2_val[1]),
        "theta30": float(theta_3_val[0]),
        "theta31": float(theta_3_val[1]),
    }

def _compute_scaled(theta_tuple: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], true_theta_tuple: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], n_transitions: int, h: float, k: int) -> dict[str, float]:
    """Return the estimation errors divided by their per-group rates.

    For group ``i`` (1, 2, 3) the error ``theta_i_hat - theta_i_true`` is
    divided by ``likelihood.a(n_transitions, h, k, i)``. The result uses the
    keys ``theta10_scaled``, ``theta20_scaled``, ``theta21_scaled``,
    ``theta30_scaled`` and ``theta31_scaled`` (in that order).
    """
    theta1_hat, theta2_hat, theta3_hat = theta_tuple
    theta1_true, theta2_true, theta3_true = true_theta_tuple
    rates = tuple(likelihood.a(n_transitions, h, k, group) for group in (1, 2, 3))
    errors = (
        theta1_hat - theta1_true,
        theta2_hat - theta2_true,
        theta3_hat - theta3_true,
    )
    scaled = tuple(error / rate for error, rate in zip(errors, rates))
    layout = (
        ("theta10_scaled", 0, 0),
        ("theta20_scaled", 1, 0),
        ("theta21_scaled", 1, 1),
        ("theta30_scaled", 2, 0),
        ("theta31_scaled", 2, 1),
    )
    return {name: float(scaled[group][position]) for name, group, position in layout}

def run_closed_form_estimation(config: dict, true_theta_tuple: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], initial_theta_tuple: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Run the staged estimation over all Monte Carlo seeds.

    For each seed in ``range(config["num_seeds"])`` a path is simulated with
    ``true_theta_tuple``. Starting from ``initial_theta_tuple``, stages
    ``k = 1 .. config["k0_max"]`` update theta_1, theta_2 and theta_3 in turn,
    each update seeing the groups already refreshed in the same stage. After
    the last stage, theta_1 is updated once more with ``update_1_prime``.

    Returns:
        ``(records_df, scaled_df)``: one row per seed and stage (``stage`` is
        ``"stage"``) plus one ``"prime"`` row per seed. The first frame holds
        the raw estimates, the second the normalized ones.
    """
    records: list[dict[str, float | int | str]] = []
    scaled_records: list[dict[str, float | int | str]] = []

    def _append_records(seed_id, stage_label, k_value, theta_now, n_obs_transitions, step):
        # Store the raw and the normalized version of one estimate.
        flat = _flatten_theta(theta_now)
        scaled = _compute_scaled(theta_now, true_theta_tuple, n_obs_transitions, step, k_value)
        records.append({"seed": seed_id, "stage": stage_label, "k": k_value, **flat})
        scaled_records.append({"seed": seed_id, "stage": stage_label, "k": k_value, **scaled})

    for seed in tqdm(range(config["num_seeds"]), desc="Monte Carlo seeds"):
        _, x_series, y_series = simulate_observations(process, true_theta_tuple, config, seed=seed)
        n_transitions = int(x_series.shape[0] - 1)
        current = tuple(jnp.asarray(component) for component in initial_theta_tuple)
        last_k = config["k0_max"]
        step = config["h"]

        for k in range(1, last_k + 1):
            theta1_hat = update_1(current, x_series, y_series, h=step, k=k)
            theta2_hat = update_2((theta1_hat, current[1], current[2]), x_series, y_series, h=step, k=k)
            theta3_hat = update_3((theta1_hat, theta2_hat, current[2]), x_series, y_series, h=step, k=k)
            current = (theta1_hat, theta2_hat, theta3_hat)
            _append_records(seed, "stage", k, current, n_transitions, step)

        # Final refinement of theta_1 at the highest stage.
        theta1_prime = update_1_prime(current, x_series, y_series, h=step, k=last_k)
        theta_final = (theta1_prime, current[1], current[2])
        _append_records(seed, "prime", last_k, theta_final, n_transitions, step)

    return pd.DataFrame.from_records(records), pd.DataFrame.from_records(scaled_records)

# Run the Monte Carlo study and preview the raw estimates.
raw_estimates_df, scaled_estimates_df = run_closed_form_estimation(
    config=sim_config,
    true_theta_tuple=true_theta,
    initial_theta_tuple=initial_theta,
)
raw_estimates_df.head()

## 8. 集計と統計指標
推定結果から平均・標準偏差・RMSE を計算し、真値との差を定量化します。

In [None]:
def summarize_raw_estimates(raw_df: pd.DataFrame, true_theta_tuple: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]) -> pd.DataFrame:
    """Summarize raw estimates per ``(stage, k)`` group.

    For each of ``theta10``, ``theta20``, ``theta21``, ``theta30`` and
    ``theta31`` the output row contains ``<name>_mean``, ``<name>_std``
    (sample standard deviation, ``ddof=1``) and ``<name>_rmse`` (root mean
    squared error against the matching entry of ``true_theta_tuple``).
    """
    layout = (
        ("theta10", 0, 0),
        ("theta20", 1, 0),
        ("theta21", 1, 1),
        ("theta30", 2, 0),
        ("theta31", 2, 1),
    )
    targets = {
        name: float(true_theta_tuple[block][position]) for name, block, position in layout
    }

    summary_rows: list[dict[str, float | int | str]] = []
    for (stage, k), subset in raw_df.groupby(["stage", "k"]):
        stats: dict[str, float | int | str] = {"stage": stage, "k": k}
        for name, target in targets.items():
            column = subset[name]
            squared_error = (column - target) ** 2
            stats[f"{name}_mean"] = float(column.mean())
            stats[f"{name}_std"] = float(column.std(ddof=1))
            stats[f"{name}_rmse"] = float(np.sqrt(squared_error.mean()))
        summary_rows.append(stats)
    return pd.DataFrame(summary_rows)

def summarize_scaled_estimates(scaled_df: pd.DataFrame) -> pd.DataFrame:
    """Summarize normalized estimates per ``(stage, k)`` group.

    For each ``*_scaled`` column the output row contains ``<name>_mean`` and
    ``<name>_std`` (sample standard deviation, ``ddof=1``).
    """
    scaled_names = tuple(
        f"{base}_scaled" for base in ("theta10", "theta20", "theta21", "theta30", "theta31")
    )
    summary_rows: list[dict[str, float | int | str]] = []
    for (stage, k), subset in scaled_df.groupby(["stage", "k"]):
        stats: dict[str, float | int | str] = {"stage": stage, "k": k}
        for name in scaled_names:
            column = subset[name]
            stats[f"{name}_mean"] = float(column.mean())
            stats[f"{name}_std"] = float(column.std(ddof=1))
        summary_rows.append(stats)
    return pd.DataFrame(summary_rows)

# Aggregate the Monte Carlo results per (stage, k).
raw_summary_df = summarize_raw_estimates(raw_estimates_df, true_theta)
scaled_summary_df = summarize_scaled_estimates(scaled_estimates_df)
# Show the raw summary rounded to 4 decimals.
raw_summary_df.round(4)

In [None]:
# Show the summary of the normalized estimates rounded to 4 decimals.
scaled_summary_df.round(4)

## 9. 可視化
推定量と正規化推定量の分布を `seaborn` のボックスプロットで確認します。

In [None]:
def plot_estimation_distributions(raw_df: pd.DataFrame, scaled_df: pd.DataFrame) -> None:
    """Draw box plots of the raw and the normalized estimates side by side.

    Args:
        raw_df: Raw estimates with ``seed``, ``stage``, ``k`` and the five
            ``theta*`` columns.
        scaled_df: Normalized estimates with ``seed``, ``stage``, ``k`` and
            the five ``theta*_scaled`` columns.

    The left panel shows the ``"prime"`` rows of ``raw_df`` per parameter.
    The right panel shows the ``"stage"`` rows of ``scaled_df`` per parameter,
    split by ``k``, with a dashed reference line at zero.
    """
    raw_long = raw_df.melt(
        id_vars=["seed", "stage", "k"],
        value_vars=["theta10", "theta20", "theta21", "theta30", "theta31"],
        var_name="parameter",
        value_name="estimate",
    )
    scaled_long = scaled_df.melt(
        id_vars=["seed", "stage", "k"],
        value_vars=["theta10_scaled", "theta20_scaled", "theta21_scaled", "theta30_scaled", "theta31_scaled"],
        var_name="parameter",
        value_name="normalized_estimate",
    )

    _, axes = plt.subplots(1, 2, figsize=(18, 6))

    # seaborn >= 0.13 deprecates `palette` without `hue`. Assigning the x
    # variable to hue (without dodging) keeps the same per-parameter colours.
    sns.boxplot(
        data=raw_long[raw_long["stage"] == "prime"],
        x="parameter",
        y="estimate",
        hue="parameter",
        dodge=False,
        ax=axes[0],
        palette="viridis",
    )
    # The hue legend only repeats the x tick labels; drop it if one was drawn.
    redundant_legend = axes[0].get_legend()
    if redundant_legend is not None:
        redundant_legend.remove()
    axes[0].set_title("最終 (prime) 推定量の分布")
    axes[0].tick_params(axis="x", rotation=30)

    sns.boxplot(
        data=scaled_long[scaled_long["stage"] == "stage"],
        x="parameter",
        y="normalized_estimate",
        hue="k",
        ax=axes[1],
        palette="Set2",
    )
    axes[1].axhline(0.0, linestyle="--", color="black", linewidth=1)
    axes[1].set_title("正規化推定量の分布 (各 k)")
    axes[1].tick_params(axis="x", rotation=30)
    axes[1].legend(title="k", bbox_to_anchor=(1.02, 1), loc="upper left")

    plt.tight_layout()
    plt.show()

# Render both panels from the Monte Carlo results.
plot_estimation_distributions(
    raw_df=raw_estimates_df,
    scaled_df=scaled_estimates_df,
)

---
このノートブックでは、旧来の手組み式を `degenerate_diffusion` パッケージの疑似尤度コンポーネントで再構築し、閉形式推定の流れを段階別に整理しました。