# build_seed_runner による FNmodel 推定デモ
FNmodel\_nh\_100.ipynb と同一の設定を再利用し、`build_seed_runner` の挙動を確認する実験ノートブックです。以下では環境初期化・設定定義・推定実行・結果保存の流れで進めます。

In [6]:
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pandas as pd
import jax
import jax.numpy as jnp

# Resolve the repository root. When the kernel was started from the
# notebooks/ directory, the root is that directory's parent.
_cwd = Path.cwd().resolve()
REPO_ROOT = _cwd.parent if _cwd.name == "notebooks" else _cwd
del _cwd

# Put the repository root at the front of sys.path (only once) so that the
# degenerate_diffusion package imported in later cells can be found.
_root_entry = str(REPO_ROOT)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)
del _root_entry

# Switch JAX to 64-bit floating point, then create the base PRNG key (seed 42).
jax.config.update("jax_enable_x64", True)
base_prng_key = jax.random.PRNGKey(42)

print(f"Project root detected: {REPO_ROOT}")

Project root detected: /Users/yanoshouta/dev/degenerate-diffusion


## 1. 環境および依存関係の初期化
上のセルで Python 標準ライブラリと必要な外部ライブラリをインポートし、プロジェクトルートのパス解決、JAX の 64 ビット精度（`jax_enable_x64`）の有効化、乱数キーの初期化を済ませました。

In [15]:
from typing import Dict, Tuple

from sympy import Array, symbols

from degenerate_diffusion.evaluation.likelihood_evaluator import LikelihoodEvaluator
from degenerate_diffusion.processes.degenerate_diffusion_process import DegenerateDiffusionProcess

@dataclass
class NotebookSettings:
    """Bundle of model objects and experiment parameters shared across cells.

    Being a plain dataclass, the generated ``__init__`` accepts the fields
    positionally in declaration order, so the order below must not change.
    """

    # Symbolic diffusion model object.
    model: Any
    # Likelihood evaluator built from ``model``.
    evaluator: Any
    # True parameter values as a (theta1, theta2, theta3) tuple.
    true_theta: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
    # One (low, high) pair per component of each parameter group.
    bounds_theta1: list[tuple[float, float]]
    bounds_theta2: list[tuple[float, float]]
    bounds_theta3: list[tuple[float, float]]
    # Time-related settings, forwarded unchanged to the seed-runner config.
    t_max: float
    h: float
    burn_out: float
    dt: float
    # Loop index -> tuple of three string codes, one per parameter group
    # (presumably the per-group estimation method; confirm in build_seed_runner).
    loop_plan: dict[int, tuple[str, str, str]]
    # Starting values (theta1, theta2, theta3) for the stage-0 estimation.
    initial_theta_stage0: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]

# --- FNmodel definition (same settings as FNmodel_nh_100.ipynb).
# Symbols for the two state components and the three parameter groups.
x_sym, y_sym = symbols("x, y")
theta_10, theta_20, theta_21 = symbols("theta_10 theta_20 theta_21")
theta_30, theta_31 = symbols("theta_30 theta_31")

# Array-wrapped state and parameter vectors passed to the process.
x, y = Array([x_sym]), Array([y_sym])
theta_1, theta_2, theta_3 = (
    Array([theta_10]),
    Array([theta_20, theta_21]),
    Array([theta_30, theta_31]),
)

# Model coefficients. NOTE(review): presumably A is the drift of x, B its
# diffusion coefficient and H the drift of y; confirm against
# DegenerateDiffusionProcess. theta_30 is the denominator of H; its lower
# bound below (0.01) keeps it away from zero.
A = Array([theta_20 * y_sym - x_sym + theta_21])
B = Array([[theta_10]])
H = Array([(y_sym - y_sym**3 - x_sym + theta_31) / theta_30])

FNmodel = DegenerateDiffusionProcess(
    x=x,
    y=y,
    theta_1=theta_1,
    theta_2=theta_2,
    theta_3=theta_3,
    A=A,
    B=B,
    H=H,
)
FN_likelihood = LikelihoodEvaluator(FNmodel)

# True parameter values (theta1, theta2, theta3), also exposed individually.
true_theta = (
    jnp.array([0.3]),
    jnp.array([1.5, 0.8]),
    jnp.array([0.1, 0.0]),
)
true_theta1, true_theta2, true_theta3 = true_theta

# Time settings.
t_max, burn_out, h, dt = 100.0, 50.0, 0.05, 0.001

# Search bounds: one (low, high) pair per parameter component.
bounds_theta1 = [(0.1, 0.5)]
bounds_theta2 = [(0.5, 2.5), (0.5, 1.5)]
bounds_theta3 = [(0.01, 0.3), (-1.0, 1.0)]

# Loop 1 uses code "B" for all three parameter groups; loops 2-4 use "M".
loop_plan: Dict[int, tuple[str, str, str]] = {1: ("B", "B", "B")}
loop_plan.update({loop_idx: ("M", "M", "M") for loop_idx in range(2, 5)})

# Starting values (theta1, theta2, theta3) for the stage-0 estimation.
initial_theta_stage0 = (
    jnp.array([0.2]),
    jnp.array([0.5, 0.5]),
    jnp.array([0.2, 0.1]),
)

settings = NotebookSettings(
    model=FNmodel,
    evaluator=FN_likelihood,
    true_theta=true_theta,
    initial_theta_stage0=initial_theta_stage0,
    loop_plan=loop_plan,
    bounds_theta1=bounds_theta1,
    bounds_theta2=bounds_theta2,
    bounds_theta3=bounds_theta3,
    t_max=t_max,
    h=h,
    burn_out=burn_out,
    dt=dt,
)

print(settings)

NotebookSettings(model=DegenerateDiffusionProcess(x=[x], y=[y], theta_1=[theta_10], theta_2=[theta_20, theta_21], theta_3=[theta_30, theta_31], A=SymbolicArtifact(expr=[theta_20*y + theta_21 - x], func=<function _lambdifygenerated at 0x345c045e0>), B=SymbolicArtifact(expr=[[theta_10]], func=<function _lambdifygenerated at 0x345c04ea0>), H=SymbolicArtifact(expr=[(theta_31 - x - y**3 + y)/theta_30], func=<function _lambdifygenerated at 0x345c059e0>)), evaluator=<degenerate_diffusion.evaluation.likelihood_evaluator.LikelihoodEvaluator object at 0x3333b5190>, true_theta=(Array([0.3], dtype=float64), Array([1.5, 0.8], dtype=float64), Array([0.1, 0. ], dtype=float64)), bounds_theta1=[(0.1, 0.5)], bounds_theta2=[(0.5, 2.5), (0.5, 1.5)], bounds_theta3=[(0.01, 0.3), (-1.0, 1.0)], t_max=100.0, h=0.05, burn_out=50.0, dt=0.001, loop_plan={1: ('B', 'B', 'B'), 2: ('M', 'M', 'M'), 3: ('M', 'M', 'M'), 4: ('M', 'M', 'M')}, initial_theta_stage0=(Array([0.2], dtype=float64), Array([0.5, 0.5], dtype=float

## 2. FNmodel\_nh\_100 設定の定義と確認
上のセルでは、元ノートブック FNmodel\_nh\_100.ipynb と同一のパラメータをコード上で直接定義し、データクラス `NotebookSettings` にまとめて以降のセルで再利用します（内容は `print(settings)` で確認しています）。

## 3. build\_seed\_runner による推定実行
定義した設定を `SeedRunnerConfig` に流し込み、`build_seed_runner` が返す関数を `jax.vmap` でシード方向にベクトル化したうえで `jax.jit` でコンパイルし、複数シードをまとめて処理します。処理時間も合わせて計測します（各バッチサイズでの初回呼び出しには JIT コンパイル時間が含まれます）。

In [None]:
import importlib

import degenerate_diffusion.estimation.parameter_estimator as param_est_new
import degenerate_diffusion.estimation.loop_estimation_algorithm as loop_alg_new

# Reload both modules so that package edits are picked up without restarting
# the kernel. parameter_estimator is reloaded first so that, if
# loop_estimation_algorithm depends on it (presumably), its reload sees the
# fresh definitions.
param_est_new = importlib.reload(param_est_new)
loop_alg_new = importlib.reload(loop_alg_new)

SeedRunnerConfig = loop_alg_new.SeedRunnerConfig
build_seed_runner = loop_alg_new.build_seed_runner

seed_runner_config = SeedRunnerConfig(
    true_theta=settings.true_theta,
    t_max=settings.t_max,
    h=settings.h,
    burn_out=settings.burn_out,
    dt=settings.dt,
    bounds_theta1=settings.bounds_theta1,
    bounds_theta2=settings.bounds_theta2,
    bounds_theta3=settings.bounds_theta3,
    newton_kwargs={},
    nuts_kwargs={},
    one_step_kwargs={},
)

seed_runner = build_seed_runner(
    evaluator=settings.evaluator,
    model=settings.model,
    plan=settings.loop_plan,
    config=seed_runner_config,
)

# Vectorise over the seed axis only: in_axes=(0, None) maps over the seeds
# while every seed shares the same initial-parameter tuple.
vectorized_seed_runner = jax.jit(jax.vmap(seed_runner, in_axes=(0, None)))

seeds = jnp.arange(10, dtype=jnp.int32)
start = time.perf_counter()
batch_outputs = vectorized_seed_runner(seeds, settings.initial_theta_stage0)
# JAX dispatches device work asynchronously. Without blocking here, `elapsed`
# would cover JIT compilation plus dispatch but miss most of the execution.
batch_outputs = jax.block_until_ready(batch_outputs)
elapsed = time.perf_counter() - start
print(
    f"Processed {seeds.shape[0]} seeds in {elapsed:.2f} s "
    "(includes JIT compilation for this batch size)"
)

# Copy the four result batches to host memory as NumPy arrays.
(theta1_stage0_batch, theta1_final_batch, theta2_stage0_batch, theta3_final_batch) = jax.device_get(
    batch_outputs
)

In [17]:

# Re-run with 100 seeds. jax.jit caches compilations per input shape, so a
# batch size different from the previous run (10 seeds) triggers a retrace and
# recompile; the measured time therefore includes that compilation.
seeds = jnp.arange(100, dtype=jnp.int32)
start = time.perf_counter()
batch_outputs = vectorized_seed_runner(seeds, settings.initial_theta_stage0)
# Wait for the asynchronously dispatched computation to finish so that the
# measurement covers the actual work, not just dispatch.
batch_outputs = jax.block_until_ready(batch_outputs)
elapsed = time.perf_counter() - start
print(
    f"Processed {seeds.shape[0]} seeds in {elapsed:.2f} s "
    "(includes JIT compilation for this batch size)"
)

# Copy the four result batches to host memory as NumPy arrays.
(theta1_stage0_batch, theta1_final_batch, theta2_stage0_batch, theta3_final_batch) = jax.device_get(
    batch_outputs
)

In [21]:
# Display the per-seed theta1 stage-0 estimates (the notebook renders the last
# expression of a cell as its output). The output below suggests a shape of
# (n_seeds, 3, 1); NOTE(review): the meaning of the middle axis is not
# established here. Confirm against build_seed_runner's return value.
theta1_stage0_batch

array([[[0.33725348],
        [0.297337  ],
        [0.29701802]],

       [[0.33440347],
        [0.29923722],
        [0.29837553]],

       [[0.33437596],
        [0.2921036 ],
        [0.29165718]],

       [[0.34427096],
        [0.30700734],
        [0.30550945]],

       [[0.34330503],
        [0.30556112],
        [0.30463679]],

       [[0.34959392],
        [0.30521088],
        [0.30371557]],

       [[0.34246125],
        [0.30738399],
        [0.30684657]],

       [[0.33836054],
        [0.30041334],
        [0.29991873]],

       [[0.35304427],
        [0.30674733],
        [0.30506069]],

       [[0.35378073],
        [0.30870785],
        [0.3073566 ]],

       [[0.34692207],
        [0.30389093],
        [0.30312424]],

       [[0.34494084],
        [0.30156897],
        [0.30087986]],

       [[0.34500673],
        [0.30149996],
        [0.30006914]],

       [[0.33201057],
        [0.30021576],
        [0.29957932]],

       [[0.33554448],
        [0.29859215],
    

In [23]:
# Summarise the estimates across seeds: mean and standard deviation along
# axis 0, the seed axis produced by the jax.vmap-ed runner.
def _mean_std(batch):
    """Return ``(mean, std)`` of *batch* along axis 0.

    ``jnp.std`` is called with its default ``ddof=0``, i.e. the population
    standard deviation.
    """
    return jnp.mean(batch, axis=0), jnp.std(batch, axis=0)


mean_theta1_stage0, std_theta1_stage0 = _mean_std(theta1_stage0_batch)
print(f"theta1_stage0_batch - mean: {mean_theta1_stage0}, std: {std_theta1_stage0}")

mean_theta1, std_theta1 = _mean_std(theta1_final_batch)
print(f"theta1_final_batch - mean: {mean_theta1}, std: {std_theta1}")

mean_theta2_stage0, std_theta2_stage0 = _mean_std(theta2_stage0_batch)
print(f"theta2_stage0_batch - mean: {mean_theta2_stage0}, std: {std_theta2_stage0}")

mean_theta3_final, std_theta3_final = _mean_std(theta3_final_batch)
print(f"theta3_final_batch - mean: {mean_theta3_final}, std: {std_theta3_final}")

theta1_stage0_batch - mean: [[0.34047469]
 [0.30139725]
 [0.3005843 ]], std: [[0.00787739]
 [0.00450327]
 [0.00443061]]
theta1_final_batch - mean: [[0.37907151]
 [0.30443826]
 [0.30926322]], std: [[0.01097523]
 [0.00380774]
 [0.00368141]]
theta2_stage0_batch - mean: [[1.46467633 0.78169733]
 [1.49715044 0.80184337]
 [1.50580466 0.80753478]], std: [[0.0553575  0.04368912]
 [0.05685349 0.04436981]
 [0.05754481 0.04483153]]
theta3_final_batch - mean: [[1.03743734e-01 7.00264980e-04]
 [1.01261343e-01 9.61729061e-05]
 [9.99124625e-02 2.30205263e-05]], std: [[0.00031848 0.00050003]
 [0.00024943 0.0004967 ]
 [0.00026521 0.00050811]]
