# TryTC_BDMH_np

Template notes for running the continuous-time birth/death Metropolis–Hastings sampler using the current `triskel_mc` package layout.


## Model and sampler reminders
- Slot mapping packs individual $\phi_j$ vectors into the concatenated $\theta$ space.
- The default examples assume a uniform prior over $k \in \{2,3,4\}$; the log prior returns $-\infty$ for any $k$ outside that set.
- Birth/death rate density `qb` controls how new slots are proposed or removed.
- Product-space log likelihood uses the masked $\theta$ values; ensure active/inactive slots are handled consistently.
- Parallel tempering swaps should leave the pair of chains with swapped states and matching post-swap log probabilities.


In [None]:
import numpy as np
import jax.numpy as jnp

from triskel_mc import bd_mh_step_numpy as BD_MH
from triskel_mc.runner import run_ct_mcmc
from triskel_mc.states import TraceConfig


In [None]:
# Model-specific inputs for the sampler. Every `...` below is a placeholder
# that must be replaced with a real value before this cell is executed.
# The names mirror the keyword arguments of `run_ct_mcmc` in
# `triskel_mc.runner`; helper types and kernels live in
# `triskel_mc.bd_mh_step_numpy`.

# --- Run control ---
seed = 0
T_end = ...   # total continuous-time horizon
rho_mh = ...  # MH Poisson rate
betas = ...   # array of inverse temperatures (shape: C,)

# --- Initial states and theta layout ---
pt_init = ...      # PTState with initial theta samples (shape: C x W x D)
ps_init = ...      # PSState with initial phi/masks/logpi (shape: C x W x Kmax x d)
slot_slices = ...  # tuple of slices or boolean masks describing theta layout

# --- Birth/death proposal density ---
qb_eval_variant = "child"  # or "parent" depending on the model
qb_density_np = ...        # callable (phi, m, rest?) -> birth/death density per slot

# --- Log-density callables ---
log_p_k_np = ...          # callable (k array) -> log prior over k values
log_prior_phi_np = ...    # callable phi -> scalar
log_pseudo_phi_np = ...   # callable phi -> scalar
log_lik_masked_jax = ...  # callable (theta, m, rest?) -> scalar log-likelihood

# --- Pseudo-prior sampler ---
sample_pseudo_phi = ...  # callable () -> phi sample for refreshing inactive slots

# --- Optional linear algebra resources for MH moves (left unset here) ---
Ls = U = S = None

# Trace settings; set to None to skip detailed traces.
trace_cfg = TraceConfig(do_save=True)

# Launch the sampler with the inputs defined earlier in this cell. Every
# argument is passed by keyword, and the four returned values are unpacked
# as (pt_out, ps_out, events, trace).
# NOTE(review): the grouping comments below are inferred from argument names
# only -- confirm the exact semantics against `run_ct_mcmc` in
# `triskel_mc.runner`.
pt_out, ps_out, events, trace = run_ct_mcmc(
    # Run control and initial states.
    seed=seed,
    pt_init=pt_init,
    ps_init=ps_init,
    T_end=T_end,
    rho_mh=rho_mh,
    betas=betas,
    # Model callables and theta layout.
    qb_density_np=qb_density_np,
    qb_eval_variant=qb_eval_variant,
    log_prior_phi_np=log_prior_phi_np,
    log_pseudo_phi_np=log_pseudo_phi_np,
    log_p_k_np=log_p_k_np,
    log_lik_masked_jax=log_lik_masked_jax,
    slot_slices=slot_slices,
    bd_rate_scale=1.0,
    # Optional linear algebra resources (all None in this template).
    Ls=Ls,
    U=U,
    S=S,
    # Move toggles: presumably each flag enables one proposal type -- TODO confirm.
    do_stretch=True,
    do_rw_fullcov=True,
    do_rw_eigenline=False,
    do_rw_student_t=False,
    do_de=True,
    do_PTswap=True,
    # Move tuning constants.
    stretch_a=1.3,
    cross_rate=0.7,
    gamma_de=2.38,
    # Pseudo-prior sampler and trace settings.
    sample_pseudo_phi=sample_pseudo_phi,
    trace_cfg=trace_cfg,
)

# Quick sanity output: shape of the final theta array and the returned trace.
# The label previously contained mis-decoded bytes ("Î¸") in place of "θ".
print("Final θ shape:", pt_out.thetas.shape)
print("Stored run trace:", trace)
