# Linearmodel Seed Runner Experiment
新しい result_processing ヘルパーを用いて、linearmodel のベイズ/M 推定結果を CSV と要約統計に変換します。

In [1]:
import sys
import time
from pathlib import Path

import jax
import jax.numpy as jnp

REPO_ROOT = Path.cwd().resolve()
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

jax.config.update("jax_enable_x64", True)
print(f"Project root: {REPO_ROOT}")

Project root: /Users/yanoshouta/dev/degenerate-diffusion/experiment


In [2]:
from dataclasses import dataclass
from typing import Any, Dict, Tuple
from sympy import Array, symbols
from degenerate_diffusion.evaluation.likelihood_evaluator import LikelihoodEvaluator
from degenerate_diffusion.processes.degenerate_diffusion_process import DegenerateDiffusionProcess


@dataclass
class LinearmodelSettings:
    model: Any
    evaluator: Any
    true_theta: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
    bounds_theta1: list[tuple[float, float]]
    bounds_theta2: list[tuple[float, float]]
    bounds_theta3: list[tuple[float, float]]
    t_max: float
    h: float
    burn_out: float
    dt: float
    loop_plan: Dict[int, tuple[str, str, str]]
    initial_theta_stage0: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]

In [None]:
x_sym, y_sym = symbols("x, y")
theta_10, theta_20, theta_21 = symbols("theta_10 theta_20 theta_21")
theta_30 = symbols("theta_30")
x = Array([x_sym])
y = Array([y_sym])
theta_1 = Array([theta_10])
theta_2 = Array([theta_20, theta_21])
theta_3 = Array([theta_30])
A = Array([-theta_20 * x_sym - theta_21 * y_sym])
B = Array([[theta_10]])
H = Array([theta_30 * x_sym])


linearmodel = DegenerateDiffusionProcess(
    x=x,
    y=y,
    theta_1=theta_1,
    theta_2=theta_2,
    theta_3=theta_3,
    A=A,
    B=B,
    H=H,
)
linearmodel_likelihood = LikelihoodEvaluator(linearmodel)


true_theta = (
    jnp.array([1.0]),
    jnp.array([1.0, 1.0]),
    jnp.array([1.0]),
)
bounds_theta1 = [(0.1, 2.0)]
bounds_theta2 = [(0.1, 2.0), (0.1, 2.0)]
bounds_theta3 = [(0.1, 2.0)]
loop_plan: Dict[int, tuple[str, str, str]] = {
    1: ("M", "M", "M"),
    2: ("M", "M", "M"),
    # 3: ("M", "M", "M"),
    # 4: ("M", "M", "M"),
    # 5: ("M", "M", "M"),
    # 6: ("M", "M", "M"),
}
initial_theta_stage0 = (
    jnp.array([0.4]),
    jnp.array([0.1, 0.2]),
    jnp.array([0.4]),
)
settings = LinearmodelSettings(
    model=linearmodel,
    evaluator=linearmodel_likelihood,
    true_theta=true_theta,
    bounds_theta1=bounds_theta1,
    bounds_theta2=bounds_theta2,
    bounds_theta3=bounds_theta3,
    t_max=100.0,
    h=0.5,
    burn_out=50.0,
    dt=0.001,
    loop_plan=loop_plan,
    initial_theta_stage0=initial_theta_stage0,
)
settings

LinearmodelSettings(model=DegenerateDiffusionProcess(x=[x], y=[y], theta_1=[theta_10], theta_2=[theta_20, theta_21], theta_3=[theta_30], A=SymbolicArtifact(expr=[-theta_20*x - theta_21*y], func=<function _lambdifygenerated at 0x15bddb060>), B=SymbolicArtifact(expr=[[theta_10]], func=<function _lambdifygenerated at 0x15bdda480>), H=SymbolicArtifact(expr=[theta_30*x], func=<function _lambdifygenerated at 0x15bddbc40>)), evaluator=<degenerate_diffusion.evaluation.likelihood_evaluator.LikelihoodEvaluator object at 0x13d3b8790>, true_theta=(Array([1.], dtype=float64), Array([1., 1.], dtype=float64), Array([1.], dtype=float64)), bounds_theta1=[(0.1, 2.0)], bounds_theta2=[(0.1, 2.0), (0.1, 2.0)], bounds_theta3=[(0.1, 2.0)], t_max=100.0, h=0.5, burn_out=50.0, dt=0.01, loop_plan={1: ('M', 'M', 'M'), 2: ('M', 'M', 'M')}, initial_theta_stage0=(Array([0.4], dtype=float64), Array([0.1, 0.2], dtype=float64), Array([0.4], dtype=float64)))

In [10]:
import importlib
import degenerate_diffusion.estimation.parameter_estimator as parameter_estimator_mod
import degenerate_diffusion.estimation.loop_estimation_algorithm as loop_algo_mod
from result_processing import process_and_save

parameter_estimator_mod = importlib.reload(parameter_estimator_mod)
loop_algo_mod = importlib.reload(loop_algo_mod)
SeedRunnerConfig = loop_algo_mod.SeedRunnerConfig
build_seed_runner = loop_algo_mod.build_seed_runner

seed_runner_config = SeedRunnerConfig(
    true_theta=settings.true_theta,
    t_max=settings.t_max,
    h=settings.h,
    burn_out=settings.burn_out,
    dt=settings.dt,
    bounds_theta1=settings.bounds_theta1,
    bounds_theta2=settings.bounds_theta2,
    bounds_theta3=settings.bounds_theta3,
    newton_kwargs={},
    nuts_kwargs={},
    one_step_kwargs={},
)


seed_runner = build_seed_runner(
    evaluator=settings.evaluator,
    model=settings.model,
    plan=settings.loop_plan,
    config=seed_runner_config,
)


vectorized_seed_runner = jax.jit(jax.vmap(seed_runner, in_axes=(0, None)))

vectorized_seed_runner

<PjitFunction of <function build_seed_runner.<locals>.runner at 0x15bdd9ee0>>

In [11]:
seeds = jnp.arange(400, dtype=jnp.int32)


start = time.perf_counter()


theta10_batch, theta1_batch, theta2_batch, theta3_batch = vectorized_seed_runner(
    seeds,
    settings.initial_theta_stage0,
)


elapsed = time.perf_counter() - start


theta10_batch, theta1_batch, theta2_batch, theta3_batch = jax.device_get(
    (theta10_batch, theta1_batch, theta2_batch, theta3_batch),
)


print(f"Processed {len(seeds)} seeds in {elapsed:.3f} seconds.")

Processed 400 seeds in 9.949 seconds.


In [12]:
batches = {
    "theta10": theta10_batch,
    "theta1": theta1_batch,
    "theta2": theta2_batch,
    "theta3": theta3_batch,
}


df_results, summary_table, output_path = process_and_save(
    seeds,
    batches,
    output_path=Path("experiment/linearmodel_seed_runner_results_test.csv"),
)


print(f"Saved full results to {output_path.resolve()}")


summary_table

Saved full results to /Users/yanoshouta/dev/degenerate-diffusion/experiment/experiment/linearmodel_seed_runner_results_test.csv


Unnamed: 0,parameter,mean,std
0,theta10_0_1,0.988459,0.053027
1,theta1_0_1,1.02269,0.040233
2,theta2_0_1,0.982008,0.10508
3,theta2_1_1,0.775977,0.111386
4,theta3_0_1,1.005427,0.026509


In [13]:
summary_table.to_csv(Path("experiment/linearmodel_summary_table.csv"), index=False)

In [14]:
from result_processing import build_theta_latex_table

latex = build_theta_latex_table(
    "experiment/linearmodel_summary_table.csv",
    {
        "theta10": [1.0],
        "theta1": [1.0],
        "theta2": [1.0, 1.0],
        "theta3": [1.0],
    },
)
print(latex)

\begin{tabular}{cccccc}
\hline
$k$ & $\theta^{{k,0}}_{1,0}$ & $\theta^{{k}}_{1,0}$ & $\theta^{{k}}_{2,0}$ & $\theta^{{k}}_{2,1}$ & $\theta^{{k}}_{3,0}$ \
\hline
true & 1.000 & 1.000 & 1.000 & 1.000 & 1.000 \\
1 & 0.988 (0.053) & 1.023 (0.040) & 0.982 (0.105) & 0.776 (0.111) & 1.005 (0.027) \\
\hline
\end{tabular}
