# build_seed_runner による Linearmodel 推定デモ
linearmodel.ipynb の設定をそのまま利用し、`build_seed_runner` を用いた推定フローを再現します。

## 1. 環境初期化
`sys.path` の整備と JAX の構成、乱数キーの初期化を行います。

In [1]:
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import pandas as pd

# Detect the project root: when the working directory is named "notebooks",
# use its parent. Then make sure the root is on sys.path.
REPO_ROOT = Path.cwd().resolve()
REPO_ROOT = REPO_ROOT.parent if REPO_ROOT.name == "notebooks" else REPO_ROOT
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

# Set the jax_enable_x64 flag, then create the base PRNG key from seed 42.
enable_x64 = True
jax.config.update("jax_enable_x64", enable_x64)
base_prng_key = jax.random.PRNGKey(42)

print(f"Project root detected: {REPO_ROOT}")

Project root detected: /Users/yanoshouta/dev/degenerate-diffusion


## 2. linearmodel 設定読込とモデル生成
linearmodel.ipynb と同一のモデル定義・真値パラメータ・ループ設定をデータクラスで扱いやすく整形します。

In [None]:
from typing import Dict, Tuple

from sympy import Array, symbols

from degenerate_diffusion.evaluation.likelihood_evaluator import LikelihoodEvaluator
from degenerate_diffusion.processes.degenerate_diffusion_process import DegenerateDiffusionProcess


@dataclass
class NotebookSettings:
    """Hold the estimation settings used by the Linearmodel notebook.

    This is a plain record: it adds no validation or behaviour beyond what
    ``@dataclass`` generates. Annotations use builtin generics throughout
    (``tuple[...]``, ``dict[...]``, ``list[...]``) for consistency.
    """

    # Model object and its likelihood evaluator; typed ``Any`` because the
    # concrete classes live in the project package.
    model: Any
    evaluator: Any
    # Triple of parameter arrays, ordered (theta1, theta2, theta3).
    true_theta: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
    # One pair of floats per parameter component
    # (presumably (lower, upper) -- confirm against the estimator).
    bounds_theta1: list[tuple[float, float]]
    bounds_theta2: list[tuple[float, float]]
    bounds_theta3: list[tuple[float, float]]
    # Scalar simulation/estimation settings; their precise meaning is defined
    # by the estimation library that consumes them.
    t_max: float
    h: float
    burn_out: float
    dt: float
    # Maps an integer key to a triple of string codes.
    loop_plan: dict[int, tuple[str, str, str]]
    # Triple of starting parameter arrays, ordered like ``true_theta``.
    initial_theta_stage0: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]


# --- Linearmodel definition (same settings as linearmodel.ipynb)
# Scalar sympy symbols for the two state variables and the four parameters.
x_sym, y_sym = symbols("x, y")
theta_10, theta_20, theta_21 = symbols("theta_10 theta_20 theta_21")
theta_30 = symbols("theta_30")

# Wrap the symbols in sympy Arrays, the form passed to the process below.
x = Array([x_sym])
y = Array([y_sym])
theta_1 = Array([theta_10])
theta_2 = Array([theta_20, theta_21])
theta_3 = Array([theta_30])

# Symbolic expressions handed to the process as A, B and H.
A = Array([-theta_20 * x_sym - theta_21 * y_sym])
B = Array([[theta_10]])
H = Array([theta_30 * x_sym])

linearmodel = DegenerateDiffusionProcess(
    x=x,
    y=y,
    theta_1=theta_1,
    theta_2=theta_2,
    theta_3=theta_3,
    A=A,
    B=B,
    H=H,
)
linearmodel_likelihood = LikelihoodEvaluator(linearmodel)

# True parameter values; every component is 1.0.
true_theta1 = jnp.array([1.0])
true_theta2 = jnp.array([1.0, 1.0])
true_theta3 = jnp.array([1.0])
true_theta = (true_theta1, true_theta2, true_theta3)

# One (0.1, 2.0) pair per parameter component
# (presumably (lower, upper) search bounds -- confirm against the estimator).
bounds_theta1 = [(0.1, 2.0)]
bounds_theta2 = [(0.1, 2.0), (0.1, 2.0)]
bounds_theta3 = [(0.1, 2.0)]

# Plan keyed by an integer, each value a triple of string codes.
# NOTE(review): the meaning of the "B" / "M" codes is defined by build_seed_runner
# and is not visible here.
loop_plan: dict[int, tuple[str, str, str]] = {
    1: ("B", "B", "B"),
    2: ("M", "M", "M"),
    # Further entries are currently disabled:
    # 3: ("M", "M", "M"),
    # 4: ("M", "M", "M"),
    # 5: ("M", "M", "M"),
}

# Starting values for the three parameter blocks; every component is 0.4.
initial_theta_stage0 = (
    jnp.array([0.4]),
    jnp.array([0.4, 0.4]),
    jnp.array([0.4]),
)

# Bundle everything above into a single settings record.
settings = NotebookSettings(
    model=linearmodel,
    evaluator=linearmodel_likelihood,
    true_theta=true_theta,
    bounds_theta1=bounds_theta1,
    bounds_theta2=bounds_theta2,
    bounds_theta3=bounds_theta3,
    t_max=100.0,
    h=0.5,
    burn_out=50.0,
    dt=0.001,
    loop_plan=loop_plan,
    initial_theta_stage0=initial_theta_stage0,
)

# Bare expression so the notebook displays the settings repr.
settings

NotebookSettings(model=DegenerateDiffusionProcess(x=[x], y=[y], theta_1=[theta_10], theta_2=[theta_20, theta_21], theta_3=[theta_30], A=SymbolicArtifact(expr=[-theta_20*x - theta_21*y], func=<function _lambdifygenerated at 0x15f0f0180>), B=SymbolicArtifact(expr=[[theta_10]], func=<function _lambdifygenerated at 0x31e5e8f40>), H=SymbolicArtifact(expr=[theta_30*x], func=<function _lambdifygenerated at 0x31cbc5080>)), evaluator=<degenerate_diffusion.evaluation.likelihood_evaluator.LikelihoodEvaluator object at 0x16101bcd0>, true_theta=(Array([1.], dtype=float64), Array([1., 1.], dtype=float64), Array([1.], dtype=float64)), bounds_theta1=[(0.1, 2.0)], bounds_theta2=[(0.1, 2.0), (0.1, 2.0)], bounds_theta3=[(0.1, 2.0)], t_max=100.0, h=0.5, burn_out=50.0, dt=0.001, loop_plan={1: ('M', 'M', 'M'), 2: ('M', 'M', 'M')}, initial_theta_stage0=(Array([1.], dtype=float64), Array([1., 1.], dtype=float64), Array([1.], dtype=float64)))

## 3. SeedRunnerConfig と推定関数組立
推定器モジュールを再読み込みし、`SeedRunnerConfig` を linearmodel の設定で初期化して seed runner を構築します。

In [23]:
import importlib

import degenerate_diffusion.estimation.parameter_estimator as parameter_estimator_mod
import degenerate_diffusion.estimation.loop_estimation_algorithm as loop_algo_mod

# Reload both estimation modules so the current on-disk code is used
# without restarting the kernel.
parameter_estimator_mod = importlib.reload(parameter_estimator_mod)
loop_algo_mod = importlib.reload(loop_algo_mod)

# Take the entry points from the freshly reloaded module object.
SeedRunnerConfig, build_seed_runner = (
    loop_algo_mod.SeedRunnerConfig,
    loop_algo_mod.build_seed_runner,
)

# These config fields are copied one-to-one from the notebook settings.
_settings_fields = (
    "true_theta",
    "t_max",
    "h",
    "burn_out",
    "dt",
    "bounds_theta1",
    "bounds_theta2",
    "bounds_theta3",
)
seed_runner_config = SeedRunnerConfig(
    **{field_name: getattr(settings, field_name) for field_name in _settings_fields},
    # The per-method keyword dictionaries are left empty.
    newton_kwargs={},
    nuts_kwargs={},
    one_step_kwargs={},
)

seed_runner = build_seed_runner(
    config=seed_runner_config,
    plan=settings.loop_plan,
    model=settings.model,
    evaluator=settings.evaluator,
)

# in_axes=(0, None): map over the leading axis of the first argument and pass
# the second argument unchanged to every call; the batched function is jitted.
_batched_seed_runner = jax.vmap(seed_runner, in_axes=(0, None))
vectorized_seed_runner = jax.jit(_batched_seed_runner)

# Bare expression so the notebook displays the runner repr.
seed_runner

<PjitFunction of <function build_seed_runner.<locals>.runner at 0x32050e0c0>>

## 4. シード一括実行と結果取得
ベクトル化済みの seed runner を用いて複数シードを一括評価し、JAX デバイスからホストへ値を取り出します。

In [24]:
# Seeds 0..399, evaluated in one vectorised call.
seeds = jnp.arange(400, dtype=jnp.int32)

# JAX dispatches computations asynchronously: the call below can return before
# the results exist. The timer is therefore stopped only after
# ``jax.device_get`` has copied the values to the host, which requires the
# computation to have finished. The first call also includes JIT compilation.
start = time.perf_counter()
theta10_batch, theta1_batch, theta2_batch, theta3_batch = jax.device_get(
    vectorized_seed_runner(seeds, settings.initial_theta_stage0)
)
elapsed = time.perf_counter() - start

print(f"Processed {len(seeds)} seeds in {elapsed:.3f} seconds.")


Processed 400 seeds in 8.812 seconds.


## 5. 結果保存と要約統計
推定結果を CSV に書き出し、各パラメータの平均値と標準偏差を算出します。

In [25]:
def _batch_to_columns(batch: jnp.ndarray, base_name: str) -> dict[str, jnp.ndarray]:
    """Flatten a per-seed batch into named one-column-per-entry slices.

    Args:
        batch: Either a 2-D ``(seed, k)`` array or a 3-D ``(seed, k, i)`` array.
        base_name: Prefix used for every generated column name.

    Returns:
        A dict mapping ``"{base_name}_{i}_{k + 1}"`` to ``batch[:, k, i]``.
        For 2-D input ``i`` is fixed to ``0`` and the value is ``batch[:, k]``.
        Keys are inserted with ``i`` as the outer loop and ``k`` as the inner one.

    Raises:
        ValueError: If ``batch`` is neither 2-D nor 3-D.
    """
    rank = batch.ndim

    if rank == 2:
        # Axis 1 is k; the name uses a 1-based k and a constant i of 0.
        return {
            f"{base_name}_0_{step}": batch[:, step - 1]
            for step in range(1, batch.shape[1] + 1)
        }

    if rank == 3:
        n_steps, n_components = batch.shape[1:]
        return {
            f"{base_name}_{comp}_{step + 1}": batch[:, step, comp]
            for comp in range(n_components)
            for step in range(n_steps)
        }

    raise ValueError(f"Expected (seed, k, i) array, got shape {batch.shape}")




# Collect the seed index plus one flat column per estimate, in the order
# theta10, theta1, theta2, theta3.
column_data: dict[str, jnp.ndarray] = {"seed": seeds}
for _base_name, _batch in (
    ("theta10", theta10_batch),
    ("theta1", theta1_batch),
    ("theta2", theta2_batch),
    ("theta3", theta3_batch),
):
    column_data.update(_batch_to_columns(_batch, _base_name))

df_results = pd.DataFrame(column_data)
df_results = df_results.set_index("seed")

# The CSV path is relative to the current working directory.
output_path = Path("linearmodel_seed_runner_results.csv")
df_results.to_csv(output_path)

# Per-column mean and standard deviation, one row per parameter column.
_stats = df_results.agg(["mean", "std"]).T
summary_table = _stats.rename_axis("parameter").reset_index()

print(f"Saved aggregated results to {output_path.resolve()}")
print(f"DataFrame shape: {df_results.shape}")

# Bare expression so the notebook displays the summary table.
summary_table

Saved aggregated results to /Users/yanoshouta/dev/degenerate-diffusion/notebooks/linearmodel_seed_runner_results.csv
DataFrame shape: (400, 5)


Unnamed: 0,parameter,mean,std
0,theta10_0_1,0.979596,0.050456
1,theta1_0_1,1.014658,0.042576
2,theta2_0_1,0.974489,0.113216
3,theta2_1_1,0.764403,0.106031
4,theta3_0_1,1.002482,0.026395


## 6. 推定推移の可視化
サンプル分布と反復平均の推移を可視化して推定挙動を確認します。

In [None]:
# The batch-run cell binds theta10_batch, theta1_batch, theta2_batch and
# theta3_batch. This cell previously referenced *_stage0_batch / *_final_batch
# names that no visible cell defines, which raised NameError.
# NOTE(review): mapping assumed here -- confirm against build_seed_runner:
#   theta10_batch -> "theta1_stage0", theta1_batch -> "theta1_final",
#   theta2_batch  -> "theta2_stage0", theta3_batch -> "theta3_final".
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Top row: histograms of all theta1 estimates, pooled over seeds.
axes[0, 0].hist(theta10_batch.flatten(), bins=40, alpha=0.7, color="tab:blue")
axes[0, 0].set_title("theta1_stage0 distribution")
axes[0, 0].set_xlabel("value")
axes[0, 0].set_ylabel("frequency")

axes[0, 1].hist(theta1_batch.flatten(), bins=40, alpha=0.7, color="tab:orange")
axes[0, 1].set_title("theta1_final distribution")
axes[0, 1].set_xlabel("value")
axes[0, 1].set_ylabel("frequency")

# Average over axis 0 (the seed axis) to get one value per remaining index.
mean_theta1_stage0_traj = jnp.mean(theta10_batch, axis=0)
mean_theta1_final_traj = jnp.mean(theta1_batch, axis=0)
mean_theta2_stage0_traj = jnp.mean(theta2_batch, axis=0)
mean_theta3_final_traj = jnp.mean(theta3_batch, axis=0)

# Bottom row: seed-averaged values plotted against a 1-based iteration index.
iterations = range(1, mean_theta1_stage0_traj.shape[0] + 1)
axes[1, 0].plot(iterations, mean_theta1_stage0_traj, marker="o", label="theta1_stage0")
axes[1, 0].plot(iterations, mean_theta1_final_traj, marker="s", label="theta1_final")
axes[1, 0].set_title("theta1 mean trajectory")
axes[1, 0].set_xlabel("iteration k")
axes[1, 0].set_ylabel("mean value")
axes[1, 0].legend()

# The theta2 mean is indexed [:, component], so it must be 2-D here.
axes[1, 1].plot(iterations, mean_theta2_stage0_traj[:, 0], marker="o", label="theta2_stage0[0]")
axes[1, 1].plot(iterations, mean_theta2_stage0_traj[:, 1], marker="s", label="theta2_stage0[1]")
axes[1, 1].plot(iterations, mean_theta3_final_traj, marker="^", label="theta3_final")
axes[1, 1].set_title("theta2/theta3 mean trajectory")
axes[1, 1].set_xlabel("iteration k")
axes[1, 1].set_ylabel("mean value")
axes[1, 1].legend()

plt.tight_layout()
plt.show()