In [1]:
import numpy as np
from bart_playground.params import Tree, Parameters
from bart_playground.priors import ComprehensivePrior
from bart_playground.moves import Move

In [2]:
class DummyMove(Move):
    """Minimal concrete ``Move`` for the checks in this notebook.

    The cells below only use it to supply ``trees_changed`` (and to carry a
    ``tree_prior`` attribute); both hooks are stubs that always succeed.
    """

    def is_feasible(self):
        # Stub: every proposal is treated as feasible.
        return True

    def try_propose(self, proposed, generator):
        # Stub: leaves `proposed` untouched and reports success.
        return True

## Grow

In [3]:
# Grow check: simulate splitting leaf node 1, then split it for real and
# compare leaf assignments, node counts, variables and marginal likelihood.
X = np.array([[0.1, 0.2],
              [0.4, 0.5],
              [0.8, 0.7],
              [0.3, 0.9]], dtype=np.float32)
y = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)

tree = Tree.new(dataX=X)
tree.split_leaf(node_id=0, var=0, threshold=0.5)

# Prepare Parameters and Priors
global_params = {"eps_sigma2": np.array([0.1], dtype=np.float32)}
params = Parameters([tree], global_params)
prior = ComprehensivePrior(n_trees=1)
likelihood = prior.likelihood
tree_prior = prior.tree_prior

move = DummyMove(params, trees_changed=np.array([0]))
move.tree_prior = tree_prior  # Needed for _calculate_simulated_likelihood

residuals = y - params.evaluate(all_except=move.trees_changed)

# Simulate splitting leaf node 1 (the root's left child, X_0 <= 0.5) on X_1 <= 0.3
sim_leaf_ids, sim_n, sim_vars = tree.simulate_split_leaf(node_id=1, var=1, threshold=0.3)

lkhd_sim = likelihood.calculate_simulated_likelihood(sim_leaf_ids, sim_n, residuals, global_params["eps_sigma2"][0])

print("Old leaf_ids:", tree.leaf_ids)
print("Old n:", tree.n)
print("Old vars:", tree.vars)

print("Simulated split leaf_ids:", sim_leaf_ids)
print("Simulated split n:", sim_n)
print("Simulated split vars:", sim_vars)

# Actually split leaf node 1 (in place)
tree.split_leaf(node_id=1, var=1, threshold=0.3)
print("Actual split leaf_ids:", tree.leaf_ids)
print("Actual split n:", tree.n)
print("Actual split vars:", tree.vars)

lkhd_true = likelihood.trees_log_marginal_lkhd(params, y, [0])

# Check if simulation matches actual split.
# np.array_equal also requires matching shapes (plain `==` could broadcast).
print("leaf_ids match:", np.array_equal(sim_leaf_ids, tree.leaf_ids))
print("n match:", np.array_equal(sim_n, tree.n))
print("vars match:", np.array_equal(sim_vars, tree.vars))
# rtol=0 so the tolerance really is an absolute 1e-8 (np.allclose defaults to rtol=1e-5).
print("Likelihood match (within 1e-8):", np.allclose(lkhd_sim, lkhd_true, rtol=0.0, atol=1e-8))

Old leaf_ids: [1 1 2 1]
Old n: [4 3 1 0 0 0 0 0]
Old vars: [ 0 -1 -1 -2 -2 -2 -2 -2]
Simulated split leaf_ids: [3 4 2 4]
Simulated split n: [4 3 1 1 2 0 0 0]
Simulated split vars: [ 0  1 -1 -1 -1 -2 -2 -2]
Actual split leaf_ids: [3 4 2 4]
Actual split n: [4 3 1 1 2 0 0 0]
Actual split vars: [ 0  1 -1 -1 -1 -2 -2 -2]
leaf_ids match: True
n match: True
vars match: True
Likelihood match (within 1e-8): True


## Prune

In [4]:
# Prune check: build a two-split tree, simulate pruning the split at node 1,
# then prune for real and compare structure arrays and marginal likelihood.
X = np.array([[0.1, 0.2],
              [0.4, 0.5],
              [0.8, 0.7],
              [0.3, 0.9]], dtype=np.float32)
y = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)

tree = Tree.new(dataX=X)
tree.split_leaf(node_id=0, var=0, threshold=0.5)
tree.split_leaf(node_id=1, var=1, threshold=0.3)

# Prepare Parameters and Priors
global_params = {"eps_sigma2": np.array([0.1], dtype=np.float32)}
params = Parameters([tree], global_params)
prior = ComprehensivePrior(n_trees=1)
likelihood = prior.likelihood
tree_prior = prior.tree_prior

move = DummyMove(params, trees_changed=np.array([0]))
move.tree_prior = tree_prior  # Needed for _calculate_simulated_likelihood

residuals = y - params.evaluate(all_except=move.trees_changed)

# Simulate pruning the split at node 1
sim_leaf_ids, sim_n, sim_vars = tree.simulate_prune_split(node_id=1)

# Calculate simulated likelihood
lkhd_sim = likelihood.calculate_simulated_likelihood(sim_leaf_ids, sim_n, residuals, global_params["eps_sigma2"][0])

print("Old leaf_ids:", tree.leaf_ids)
print("Old n:", tree.n)
print("Old vars:", tree.vars)

print("Simulated prune leaf_ids:", sim_leaf_ids)
print("Simulated prune n:", sim_n)
print("Simulated prune vars:", sim_vars)
print("Simulated prune likelihood:", lkhd_sim)

# Actually prune the split at node 1
tree.prune_split(node_id=1)
print("Actual prune leaf_ids:", tree.leaf_ids)
print("Actual prune n:", tree.n)
print("Actual prune vars:", tree.vars)

# Calculate true likelihood after actual prune
lkhd_true = likelihood.trees_log_marginal_lkhd(params, y, [0])
print("Actual prune likelihood:", lkhd_true)

# Check if simulation matches actual prune.
# np.array_equal also requires matching shapes (plain `==` could broadcast).
print("leaf_ids match:", np.array_equal(sim_leaf_ids, tree.leaf_ids))
print("n match:", np.array_equal(sim_n, tree.n))
print("vars match:", np.array_equal(sim_vars, tree.vars))
# rtol=0 so the tolerance really is an absolute 1e-8 (np.allclose defaults to rtol=1e-5).
print("Likelihood match (within 1e-8):", np.allclose(lkhd_sim, lkhd_true, rtol=0.0, atol=1e-8))

Old leaf_ids: [3 4 2 4]
Old n: [4 3 1 1 2 0 0 0]
Old vars: [ 0  1 -1 -1 -1 -2 -2 -2]
Simulated prune leaf_ids: [1 1 2 1]
Simulated prune n: [4 3 1 0 0 0 0 0]
Simulated prune vars: [ 0 -1 -1 -2 -2 -2 -2 -2]
Simulated prune likelihood: -80.20221761552283
Actual prune leaf_ids: [1 1 2 1]
Actual prune n: [4 3 1 0 0 0 0 0]
Actual prune vars: [ 0 -1 -1 -2 -2 -2 -2 -2]
Actual prune likelihood: -80.20221761552283
leaf_ids match: True
n match: True
vars match: True
Likelihood match (within 1e-8): True


## Change

In [5]:
# Change check: replace the root split (X_0 <= 0.5) with X_1 <= 0.6, comparing
# the simulated change with the real one and checking no used node is empty.
X = np.array([[0.1, 0.2],
              [0.4, 0.5],
              [0.8, 0.7],
              [0.3, 0.9]], dtype=np.float32)
y = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)

tree = Tree.new(dataX=X)
tree.split_leaf(node_id=0, var=0, threshold=0.5)
tree.split_leaf(node_id=1, var=1, threshold=0.3)

# Prepare Parameters and Priors
global_params = {"eps_sigma2": np.array([0.1], dtype=np.float32)}
params = Parameters([tree], global_params)
prior = ComprehensivePrior(n_trees=1)
likelihood = prior.likelihood
tree_prior = prior.tree_prior

move = DummyMove(params, trees_changed=np.array([0]))
move.tree_prior = tree_prior  # Needed for _calculate_simulated_likelihood

residuals = y - params.evaluate(all_except=move.trees_changed)

node_id = 0
sim_leaf_ids, sim_n, sim_vars = tree.simulate_change_split(node_id=node_id, var=1, threshold=0.6)
lkhd_sim = likelihood.calculate_simulated_likelihood(sim_leaf_ids, sim_n, residuals, global_params["eps_sigma2"][0])

print("Old tree structure:", tree)
print("Old leaf_ids:", tree.leaf_ids)
print("Old n:", tree.n)
print("Old vars:", tree.vars)

print("Simulated change leaf_ids:", sim_leaf_ids)
print("Simulated change n:", sim_n)
print("Simulated change vars:", sim_vars)
print("Simulated change likelihood:", lkhd_sim)

tree.change_split(node_id=node_id, var=1, threshold=0.6)
print("New tree structure:", tree)

# Calculate true likelihood after actual change
lkhd_true = likelihood.trees_log_marginal_lkhd(params, y, [0])
print("Actual change likelihood:", lkhd_true)

# Check if simulation matches actual change.
# np.array_equal also requires matching shapes (plain `==` could broadcast).
print("leaf_ids match:", np.array_equal(sim_leaf_ids, tree.leaf_ids))
print("n match:", np.array_equal(sim_n, tree.n))
print("vars match:", np.array_equal(sim_vars, tree.vars))
# rtol=0 so the tolerance really is an absolute 1e-8 (np.allclose defaults to rtol=1e-5).
print("Likelihood match (within 1e-8):", np.allclose(lkhd_sim, lkhd_true, rtol=0.0, atol=1e-8))

# Valid means: from node_id onward, every slot with var != -2 (the unused
# slots in the printed vars) holds at least one sample.
used_slots = sim_vars[node_id:] != -2
valid = not np.any(used_slots & (sim_n[node_id:] == 0))

print("Is the new tree structure valid?", valid)

Old tree structure: X_0 <= 0.500000000 (split, n = 4)
	X_1 <= 0.300000012 (split, n = 3)
		Val: nan (leaf, n = 1)
		Val: nan (leaf, n = 2)
	Val: nan (leaf, n = 1)
Old leaf_ids: [3 4 2 4]
Old n: [4 3 1 1 2 0 0 0]
Old vars: [ 0  1 -1 -1 -1 -2 -2 -2]
Simulated change leaf_ids: [3 4 2 2]
Simulated change n: [4 2 2 1 1 0 0 0]
Simulated change vars: [ 1  1 -1 -1 -1 -2 -2 -2]
Simulated change likelihood: -73.22003220417899
New tree structure: X_1 <= 0.600000024 (split, n = 4)
	X_1 <= 0.300000012 (split, n = 2)
		Val: nan (leaf, n = 1)
		Val: nan (leaf, n = 1)
	Val: nan (leaf, n = 2)
Actual change likelihood: -73.22003220417899
leaf_ids match: True
n match: True
vars match: True
Likelihood match (within 1e-8): True
Is the new tree structure valid? True


In [6]:
# Change check (structure only): replacing the root split with X_1 <= 0.4 sends
# every sample of node 1 to its first child, so this case is expected to leave
# an empty leaf and be reported as invalid.
X = np.array([[0.1, 0.2],
              [0.4, 0.5],
              [0.8, 0.7],
              [0.3, 0.9]], dtype=np.float32)

tree = Tree.new(dataX=X)

tree.split_leaf(node_id=0, var=0, threshold=0.5)

tree.split_leaf(node_id=1, var=1, threshold=0.3)

node_id = 0
sim_leaf_ids, sim_n, sim_vars = tree.simulate_change_split(node_id=node_id, var=1, threshold=0.4)

print("Old tree structure:", tree)
print("Old leaf_ids:", tree.leaf_ids)
print("Old n:", tree.n)
print("Old vars:", tree.vars)
print("New leaf_ids:", sim_leaf_ids)
print("New n:", sim_n)
print("New vars:", sim_vars)

tree.change_split(node_id=node_id, var=1, threshold=0.4)
print("New tree structure:", tree)
# Check if simulation matches actual change.
# np.array_equal also requires matching shapes (plain `==` could broadcast).
print("leaf_ids match:", np.array_equal(sim_leaf_ids, tree.leaf_ids))
print("n match:", np.array_equal(sim_n, tree.n))
print("vars match:", np.array_equal(sim_vars, tree.vars))

# Valid means: from node_id onward, every slot with var != -2 (the unused
# slots in the printed vars) holds at least one sample.
used_slots = sim_vars[node_id:] != -2
valid = not np.any(used_slots & (sim_n[node_id:] == 0))

print("Is the new tree structure valid?", valid)

Old tree structure: X_0 <= 0.500000000 (split, n = 4)
	X_1 <= 0.300000012 (split, n = 3)
		Val: nan (leaf, n = 1)
		Val: nan (leaf, n = 2)
	Val: nan (leaf, n = 1)
Old leaf_ids: [3 4 2 4]
Old n: [4 3 1 1 2 0 0 0]
Old vars: [ 0  1 -1 -1 -1 -2 -2 -2]
New leaf_ids: [3 2 2 2]
New n: [4 1 3 1 0 0 0 0]
New vars: [ 1  1 -1 -1 -1 -2 -2 -2]
New tree structure: X_1 <= 0.400000006 (split, n = 4)
	X_1 <= 0.300000012 (split, n = 1)
		Val: nan (leaf, n = 1)
		Val: nan (leaf, n = 0)
	Val: nan (leaf, n = 3)
leaf_ids match: True
n match: True
vars match: True
Is the new tree structure valid? False


## Swap

In [7]:
# Swap check: exchange the split rules of the root (node 0) and its left child
# (node 1), comparing simulated vs. actual structure and marginal likelihood.
X = np.array([[0.1, 0.2],
              [0.4, 0.5],
              [0.8, 0.7],
              [0.3, 0.9]], dtype=np.float32)
# Defined here (as in the other cells) so this cell does not depend on a `y`
# left over from an earlier cell.
y = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)

tree = Tree.new(dataX=X)

# Split root node (node 0)
tree.split_leaf(node_id=0, var=0, threshold=0.35)
# Split left child (node 1)
tree.split_leaf(node_id=1, var=1, threshold=0.6)

# Prepare Parameters and Priors
global_params = {"eps_sigma2": np.array([0.1], dtype=np.float32)}
params = Parameters([tree], global_params)
prior = ComprehensivePrior(n_trees=1)
likelihood = prior.likelihood
tree_prior = prior.tree_prior

move = DummyMove(params, trees_changed=np.array([0]))
move.tree_prior = tree_prior  # Needed for _calculate_simulated_likelihood

residuals = y - params.evaluate(all_except=move.trees_changed)

# Simulate swapping the split between root (0) and left child (1)
parent_id = 0
sim_leaf_ids, sim_n, sim_vars = tree.simulate_swap_split(parent_id=parent_id, child_id=1)
lkhd_sim = likelihood.calculate_simulated_likelihood(sim_leaf_ids, sim_n, residuals, global_params["eps_sigma2"][0])

print("Old leaf_ids:", tree.leaf_ids)
print("Old n:", tree.n)
print("Old vars:", tree.vars)

print("Simulated swap leaf_ids:", sim_leaf_ids)
print("Simulated swap n:", sim_n)
print("Simulated swap vars:", sim_vars)
print("Simulated swap likelihood:", lkhd_sim)

# Actually swap
tree.swap_split(parent_id=parent_id, child_id=1)
lkhd_true = likelihood.trees_log_marginal_lkhd(params, y, [0])
print("Actual swap leaf_ids:", tree.leaf_ids)
print("Actual swap n:", tree.n)
print("Actual swap vars:", tree.vars)
print("Actual swap likelihood:", lkhd_true)

# Check if simulation matches actual swap.
# np.array_equal also requires matching shapes (plain `==` could broadcast).
print("leaf_ids match:", np.array_equal(sim_leaf_ids, tree.leaf_ids))
print("n match:", np.array_equal(sim_n, tree.n))
print("vars match:", np.array_equal(sim_vars, tree.vars))
# rtol=0 so the tolerance really is an absolute 1e-8 (np.allclose defaults to rtol=1e-5).
print("likelihood match (within 1e-8):", np.allclose(lkhd_sim, lkhd_true, rtol=0.0, atol=1e-8))

# Valid means: from parent_id onward, every slot with var != -2 (the unused
# slots in the printed vars) holds at least one sample.
used_slots = sim_vars[parent_id:] != -2
valid = not np.any(used_slots & (sim_n[parent_id:] == 0))

print("Is the new tree structure valid?", valid)

Old leaf_ids: [3 2 2 4]
Old n: [4 2 2 1 1 0 0 0]
Old vars: [ 0  1 -1 -1 -1 -2 -2 -2]
Simulated swap leaf_ids: [3 4 2 2]
Simulated swap n: [4 2 2 1 1 0 0 0]
Simulated swap vars: [ 1  0 -1 -1 -1 -2 -2 -2]
Simulated swap likelihood: -73.22003220417899
Actual swap leaf_ids: [3 4 2 2]
Actual swap n: [4 2 2 1 1 0 0 0]
Actual swap vars: [ 1  0 -1 -1 -1 -2 -2 -2]
Actual swap likelihood: -73.22003220417899
leaf_ids match: True
n match: True
vars match: True
likelihood match (within 1e-8): True
Is the new tree structure valid? True


In [8]:
# Swap check (structure only): with root X_0 <= 0.5 and left child X_1 <= 0.3,
# swapping the two rules leaves an empty node, so this case is expected to be
# reported as invalid.
X = np.array([[0.1, 0.2],
              [0.4, 0.5],
              [0.8, 0.7],
              [0.3, 0.9]], dtype=np.float32)

tree = Tree.new(dataX=X)

# Split root node (node 0)
tree.split_leaf(node_id=0, var=0, threshold=0.5)
# Split left child (node 1)
tree.split_leaf(node_id=1, var=1, threshold=0.3)

# Simulate swapping the split between root (0) and left child (1)
parent_id = 0
sim_leaf_ids, sim_n, sim_vars = tree.simulate_swap_split(parent_id=parent_id, child_id=1)

print("Old leaf_ids:", tree.leaf_ids)
print("Old n:", tree.n)
print("Old vars:", tree.vars)
print("Simulated swap leaf_ids:", sim_leaf_ids)
print("Simulated swap n:", sim_n)
print("Simulated swap vars:", sim_vars)

# Actually swap
tree.swap_split(parent_id=parent_id, child_id=1)
print("Actual swap leaf_ids:", tree.leaf_ids)
print("Actual swap n:", tree.n)
print("Actual swap vars:", tree.vars)

# Check if simulation matches actual swap.
# np.array_equal also requires matching shapes (plain `==` could broadcast).
print("leaf_ids match:", np.array_equal(sim_leaf_ids, tree.leaf_ids))
print("n match:", np.array_equal(sim_n, tree.n))
print("vars match:", np.array_equal(sim_vars, tree.vars))

# Valid means: from parent_id onward, every slot with var != -2 (the unused
# slots in the printed vars) holds at least one sample.
used_slots = sim_vars[parent_id:] != -2
valid = not np.any(used_slots & (sim_n[parent_id:] == 0))

print("Is the new tree structure valid?", valid)

Old leaf_ids: [3 4 2 4]
Old n: [4 3 1 1 2 0 0 0]
Old vars: [ 0  1 -1 -1 -1 -2 -2 -2]
Simulated swap leaf_ids: [3 2 2 2]
Simulated swap n: [4 1 3 1 0 0 0 0]
Simulated swap vars: [ 1  0 -1 -1 -1 -2 -2 -2]
Actual swap leaf_ids: [3 2 2 2]
Actual swap n: [4 1 3 1 0 0 0 0]
Actual swap vars: [ 1  0 -1 -1 -1 -2 -2 -2]
leaf_ids match: True
n match: True
vars match: True
Is the new tree structure valid? False
