In [1]:
import numpy as np
from bart_playground.params import Tree, Parameters
from bart_playground.priors import ComprehensivePrior
from bart_playground.moves import Move
from bart_playground import *

from numba import njit
import numpy as np
import math

In [2]:
class DummyMove(Move):
    """Stub ``Move`` subclass whose two hooks always report success.

    Neither method reads instance state or its arguments; both simply
    return ``True``.
    """

    def is_feasible(self):
        """Return ``True`` unconditionally."""
        return True

    def try_propose(self, proposed, generator):
        """Return ``True`` unconditionally; ``proposed`` and ``generator`` are ignored."""
        return True

In [3]:
from sklearn.datasets import fetch_california_housing

# Fetch the California housing data as pandas objects, then convert to plain NumPy:
# X becomes a float array of the feature frame, y a flat 1-D array of the target.
data = fetch_california_housing(as_frame=True)
X = data.data.values.astype(float)
y = np.array(data.target).reshape(-1)

# Build a tree over X and split node 0 on variable 0 at that column's mean.
# The split call is the cell's last expression, so its return value is the cell output.
tree = Tree.new(dataX=X)
root_threshold = X[:, 0].mean()
tree.split_leaf(node_id=0, var=0, threshold=root_threshold)

True

In [4]:
# Derive the candidate split thresholds of every variable with DefaultPreprocessor
from bart_playground.util import DefaultPreprocessor
from numpy.random import default_rng

preprocessor = DefaultPreprocessor(max_bins=50)
preprocessor.fit(X, y)
possible_thresholds = preprocessor.thresholds

# Enumerate every (node_id, var, threshold) triple: each entry of tree.leaves,
# crossed with each column of X, crossed with each threshold listed for that column.
n_vars = X.shape[1]
all_candidates = [
    (leaf_id, col, cut)
    for leaf_id in tree.leaves
    for col in range(n_vars)
    for cut in possible_thresholds[col]
]

# Shuffle the list in place with a fixed seed so the candidate order is reproducible.
generator = default_rng(seed=4)
generator.shuffle(all_candidates)

In [5]:
# Pick the first N_SELECTED candidates whose split succeeds.
#
# The shuffled list is scanned from its end -- the same order repeated list.pop()
# calls would visit -- but without consuming it, so re-running this cell yields the
# same selection instead of silently moving on to the next batch of candidates.
N_SELECTED = 10

selected_candidates = []
for candidate in reversed(all_candidates):
    if len(selected_candidates) == N_SELECTED:
        break
    # The split is attempted on a copy, presumably so that `tree` itself is not modified.
    tree_copy = tree.copy()
    if tree_copy.split_leaf(*candidate):
        selected_candidates.append(candidate)

# Fail with an explicit message rather than an opaque "pop from empty list".
if len(selected_candidates) < N_SELECTED:
    raise RuntimeError(
        f"Only {len(selected_candidates)} of the requested {N_SELECTED} split candidates are feasible"
    )

selected_candidates

[(2, 0, 7.7197),
 (1, 4, 859.0),
 (2, 6, 36.28),
 (1, 1, 11.0),
 (1, 5, 2.7576301179736293),
 (1, 4, 302.0),
 (1, 7, -117.89),
 (2, 1, 19.0),
 (1, 0, 1.7734),
 (2, 7, -122.24)]

In [6]:
# Prepare Parameters and Priors
global_params = dict(eps_sigma2=np.array([0.1], dtype=np.float32))
params = Parameters([tree], global_params)

# One prior object supplies both the likelihood and the tree prior used below.
prior = ComprehensivePrior(n_trees=1)
likelihood, tree_prior = prior.likelihood, prior.tree_prior

# Stub move registered against tree index 0.
changed_trees = np.array([0])
move = DummyMove(params, trees_changed=changed_trees)
move.tree_prior = tree_prior  # Needed for _calculate_simulated_likelihood

# Residuals: y minus params.evaluate(...) with the move's changed tree(s) passed as all_except.
residuals = y - params.evaluate(all_except=move.trees_changed)

In [None]:
@njit(cache=True)
def _leaf_log_marginal_lkhd(sample_n, resid_sum, eps_sigma2, f_sigma2):
    noise_ratio = eps_sigma2 / f_sigma2
    logdet = math.log(sample_n / noise_ratio + 1.0)
    ls_resids = - (resid_sum ** 2) / sample_n
    ridge_bias = (resid_sum ** 2) / (sample_n * (sample_n / noise_ratio + 1.0))
    return - (logdet + (ls_resids + ridge_bias) / eps_sigma2) / 2

@njit(cache=True)
def fast_grow_likelihood_delta_numba(old_leaf_idx, tree_leaf_ids, tree_n, sim_leaf_ids, sim_n, residuals, eps_sigma2, f_sigma2):
    """Change in the leaf likelihood terms when leaf ``old_leaf_idx`` is split in two.

    Only observations currently assigned to ``old_leaf_idx`` are visited. Their
    residuals are summed once for the old leaf and once for whichever child
    (``2 * old_leaf_idx + 1`` or ``2 * old_leaf_idx + 2``) the simulated
    assignment places them in.

    Args:
        old_leaf_idx: Id of the leaf being split.
        tree_leaf_ids: Per-observation leaf id in the current tree.
        tree_n: Per-node observation counts of the current tree, indexed by node id.
        sim_leaf_ids: Per-observation leaf id after the simulated split.
        sim_n: Per-node observation counts after the simulated split.
        residuals: Per-observation residuals, indexed like ``tree_leaf_ids``.
        eps_sigma2: Forwarded to ``_leaf_log_marginal_lkhd``.
        f_sigma2: Forwarded to ``_leaf_log_marginal_lkhd``.

    Returns:
        ``(left term + right term) - old-leaf term``, each term computed by
        ``_leaf_log_marginal_lkhd``.
    """
    child_l = 2 * old_leaf_idx + 1
    child_r = child_l + 1

    # Running residual totals: the leaf being split, and each prospective child.
    total_old = 0.0
    total_l = 0.0
    total_r = 0.0

    n_obs = len(tree_leaf_ids)
    for k in range(n_obs):
        # Observations outside the split leaf are unaffected by the move.
        if tree_leaf_ids[k] != old_leaf_idx:
            continue
        r_k = residuals[k]
        total_old += r_k
        dest = sim_leaf_ids[k]
        if dest == child_l:
            total_l += r_k
        elif dest == child_r:
            total_r += r_k

    lkhd_old = _leaf_log_marginal_lkhd(tree_n[old_leaf_idx], total_old, eps_sigma2, f_sigma2)
    lkhd_l = _leaf_log_marginal_lkhd(sim_n[child_l], total_l, eps_sigma2, f_sigma2)
    lkhd_r = _leaf_log_marginal_lkhd(sim_n[child_r], total_r, eps_sigma2, f_sigma2)

    return (lkhd_l + lkhd_r) - lkhd_old

In [8]:
# Run numba once: call both likelihood routines a single time before the timed loops,
# so the first-call compilation of the @njit function is not counted in the benchmark.
warmup_threshold = X[:, 1].mean()
sim_leaf_ids, sim_n, sim_vars = tree.simulate_split_leaf(node_id=1, var=1, threshold=warmup_threshold)

warmup_eps_sigma2 = global_params["eps_sigma2"][0]
lkhd_sim = likelihood.calculate_simulated_likelihood(sim_leaf_ids, sim_n, residuals, warmup_eps_sigma2)
delta_lkhd = fast_grow_likelihood_delta_numba(
    1,
    tree.leaf_ids,
    tree.n,
    sim_leaf_ids,
    sim_n,
    residuals,
    warmup_eps_sigma2,
    likelihood.f_sigma2,
)

In [9]:
import time

# Number of times each sweep over the candidates is repeated so the total is measurable.
N_REPEATS = 1000

# Timings use time.perf_counter(): it is monotonic and has the highest available
# resolution for short durations, whereas time.time() is wall-clock time and can jump
# if the system clock is adjusted.

# First loop: calculate_simulated_likelihood
start1 = time.perf_counter()
for _ in range(N_REPEATS):  # Repeat to measure time
    lkhd_sim_list = []
    for node_id, var, threshold in selected_candidates:
        sim_leaf_ids, sim_n, sim_vars = tree.simulate_split_leaf(node_id=node_id, var=var, threshold=threshold)
        lkhd_sim = likelihood.calculate_simulated_likelihood(sim_leaf_ids, sim_n, residuals, global_params["eps_sigma2"][0])
        lkhd_sim_list.append(lkhd_sim)
end1 = time.perf_counter()
print(f"Time for calculate_simulated_likelihood: {end1 - start1:.4f} seconds")

# Second loop: marginal likelihood delta
start2 = time.perf_counter()
for _ in range(N_REPEATS):  # Repeat to measure time
    lkhd_new_list = []
    # Baseline likelihood of the unsplit tree; each candidate only adds its delta to it.
    lkhd_orig = likelihood.calculate_simulated_likelihood(tree.leaf_ids, tree.n, residuals, global_params["eps_sigma2"][0])
    for node_id, var, threshold in selected_candidates:
        # NOTE(review): only this loop casts the threshold to float32; the first loop
        # passes it through unchanged. If simulate_split_leaf compares in float64 the two
        # loops may assign boundary observations differently -- confirm this is intended.
        threshold = np.float32(threshold)
        sim_leaf_ids, sim_n, sim_vars = tree.simulate_split_leaf(node_id=node_id, var=var, threshold=threshold)
        delta_lkhd = fast_grow_likelihood_delta_numba(
            tree_leaf_ids=tree.leaf_ids,
            tree_n=tree.n,
            sim_leaf_ids=sim_leaf_ids,
            sim_n=sim_n,
            residuals=residuals,
            eps_sigma2=global_params["eps_sigma2"][0],
            f_sigma2=likelihood.f_sigma2,
            old_leaf_idx=node_id
        )
        lkhd_new = lkhd_orig + delta_lkhd
        lkhd_new_list.append(lkhd_new)
end2 = time.perf_counter()
print(f"Time for marginal likelihood delta: {end2 - start2:.4f} seconds")

# Optionally, print results for comparison (values from the last repetition of each loop)
for i, (lkhd_sim, lkhd_new) in enumerate(zip(lkhd_sim_list, lkhd_new_list)):
    print(f"Candidate {i+1}:")
    print(f"  calculate_simulated_likelihood: {lkhd_sim}")
    print(f"  marginal_lkhd (original+delta): {lkhd_new}")
    print(f"  Difference: {np.abs(lkhd_sim - lkhd_new)}\n")

Time for calculate_simulated_likelihood: 1.2007 seconds
Time for marginal likelihood delta: 1.0876 seconds
Candidate 1:
  calculate_simulated_likelihood: -85091.8702230717
  marginal_lkhd (original+delta): -85091.47150680376
  Difference: 0.3987162679404719

Candidate 2:
  calculate_simulated_likelihood: -98638.25670788652
  marginal_lkhd (original+delta): -98637.57046998572
  Difference: 0.6862379008089192

Candidate 3:
  calculate_simulated_likelihood: -98695.9849604097
  marginal_lkhd (original+delta): -98695.43013661774
  Difference: 0.5548237919574603

Candidate 4:
  calculate_simulated_likelihood: -98584.83316169672
  marginal_lkhd (original+delta): -98584.7302220859
  Difference: 0.10293961082061287

Candidate 5:
  calculate_simulated_likelihood: -95939.59621630819
  marginal_lkhd (original+delta): -95938.85592546873
  Difference: 0.7402908394578844

Candidate 6:
  calculate_simulated_likelihood: -98654.73172021899
  marginal_lkhd (original+delta): -98654.61535942854
  Differenc