In [1]:
import numpy as np
import sys
import matplotlib.pyplot as plt
from pyinstrument import Profiler
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import os
sys.path.append(os.path.dirname(os.getcwd()))
from bart_playground import *

import bartz

## Synthetic Data (from Wu et al. (2007))

In [2]:
# Seed NumPy's global RNG so the synthetic data set is the same on every run.
np.random.seed(42)

# Every region contributes this many points.
n = 100

# Each region is an axis-aligned box; the (low, high) pairs below are listed in
# x1, x2, x3 order and are sampled in that order, one region after another.

# Region 1: x1 in [0.1, 0.4], x2 in [0.1, 0.4], x3 in [0.6, 0.9]
x1_r1, x2_r1, x3_r1 = (
    np.random.uniform(low, high, n)
    for low, high in [(0.1, 0.4), (0.1, 0.4), (0.6, 0.9)]
)

# Region 2: x1 in [0.1, 0.4], x2 in [0.6, 0.9], x3 in [0.6, 0.9]
x1_r2, x2_r2, x3_r2 = (
    np.random.uniform(low, high, n)
    for low, high in [(0.1, 0.4), (0.6, 0.9), (0.6, 0.9)]
)

# Region 3: x1 in [0.6, 0.9], x2 in [0.1, 0.9], x3 in [0.1, 0.4]
x1_r3, x2_r3, x3_r3 = (
    np.random.uniform(low, high, n)
    for low, high in [(0.6, 0.9), (0.1, 0.9), (0.1, 0.4)]
)

# Join the regions feature by feature (region 1 first, region 3 last).
x1 = np.concatenate((x1_r1, x1_r2, x1_r3))
x2 = np.concatenate((x2_r1, x2_r2, x2_r3))
x3 = np.concatenate((x3_r1, x3_r2, x3_r3))

# Design matrix: one row per sample, columns ordered x1, x2, x3.
X = np.stack((x1, x2, x3), axis=-1)
# Response model: a piecewise-constant mean chosen by x1 and x2, plus noise.
def generate_y(x):
    """Simulate one noisy response for a single sample.

    Parameters
    ----------
    x : sequence
        Feature vector; only x[0] (x1) and x[1] (x2) are read.

    Returns
    -------
    float
        Mean 1 when x1 <= 0.5 and x2 <= 0.5, mean 3 when x1 <= 0.5 and
        x2 > 0.5, mean 5 otherwise, plus one Gaussian draw with variance
        0.25 taken from NumPy's global RNG.
    """
    first, second = x[0], x[1]
    if first <= 0.5:
        mean = 1 if second <= 0.5 else 3
    else:
        mean = 5
    # Exactly one normal draw per call, whichever branch was taken.
    return mean + np.random.normal(0, np.sqrt(0.25))

# Simulate one response per row of X, in row order (each call consumes one RNG draw).
y = np.array(list(map(generate_y, X)))

In [3]:
# Fixed-seed train/test split of the synthetic data. The multi-seed experiment
# further down re-splits X, y with random_state=seed and overwrites these four
# names, so this split only matters for cells run before that loop.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

## 4 Special Moves

In [4]:
# Proposal probabilities handed to the BART samplers below.
proposal_probs = dict(grow=0.5, prune=0.5)
# Two-move alternative, kept for reference:
#special_probs = {"birth": 0.5, "death": 0.5}
# Four special moves, each with the same probability.
special_probs = dict.fromkeys(("birth", "death", "break", "combine"), 0.25)

In [5]:
# Compare ChangeNumTreeBART against random forest, linear regression, bartz and
# DefaultBART over several train/test splits, recording train and test MSE.

# Experiment settings; none of them vary per seed, so they live outside the loop.
n_seeds = 5               # number of train/test splits (random_state = 0..n_seeds-1)
n_trees_ini = 50          # initial tree count (also used as theta_0_ini)
nskip = 10000             # burn-in iterations for ChangeNumTreeBART
ndpost = 1000             # posterior draws requested from every BART variant
theta_0_nskip_prop = 0.5
nskip_fixed = 2000        # burn-in iterations for bartz and DefaultBART


def _mse_by_model(models, X_eval, y_eval, precomputed):
    """Return {model name: MSE on (X_eval, y_eval)}.

    Names present in `precomputed` use the stored prediction vector instead of
    calling `model.predict(X_eval)`.
    """
    scores = {}
    for model_name, model in models.items():
        if model_name in precomputed:
            predictions = precomputed[model_name]
        else:
            predictions = model.predict(X_eval)
        scores[model_name] = mean_squared_error(y_eval, predictions)
    return scores


all_results_test = []
all_results_train = []

for seed in range(n_seeds):
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=seed)

    # NOTE(review): the BART samplers are not seeded here; if they accept a
    # seed / random_state argument, pass `seed` for full reproducibility.
    bart = ChangeNumTreeBART(
        ndpost=ndpost, nskip=nskip, n_trees=n_trees_ini,
        proposal_probs=proposal_probs, special_probs=special_probs,
        theta_0_ini=n_trees_ini, theta_0_min=10, theta_0_nskip_prop=theta_0_nskip_prop, theta_df=100,
        temperature=1.0, tree_num_prior_type="poisson", special_move_interval=5
    )
    bart.fit(X_train, y_train)

    # The competitors are sized with the tree count of the last trace entry.
    ntree = bart.trace[-1].n_trees
    # Seed the forest so the "rf" column is reproducible for a given split.
    rf = RandomForestRegressor(n_estimators=ntree, random_state=seed)
    lr = LinearRegression()
    rf.fit(X_train, y_train)
    lr.fit(X_train, y_train)

    # bartz is fed the transposed design matrix; its predictions are averaged
    # over axis 0 (presumably the posterior-draw axis — confirm against bartz docs).
    btz = bartz.BART.gbart(np.transpose(X_train), y_train, ntree=ntree, ndpost=ndpost, nskip=nskip_fixed, printevery=1000)
    btpred_all_test = btz.predict(np.transpose(X_test))
    btpred_test = np.mean(np.array(btpred_all_test), axis=0)
    btpred_all_train = btz.predict(np.transpose(X_train))
    btpred_train = np.mean(np.array(btpred_all_train), axis=0)

    bart_default = DefaultBART(ndpost=ndpost, nskip=nskip_fixed, n_trees=ntree, proposal_probs=proposal_probs)
    bart_default.fit(X_train, y_train)

    models = {
        "bart": bart,
        "rf": rf,
        "lr": lr,
        "btz": btz,
        "bart_default": bart_default
    }

    # Test set
    results_test = _mse_by_model(models, X_test, y_test, {"btz": btpred_test})
    all_results_test.append(results_test)

    # Train set
    results_train = _mse_by_model(models, X_train, y_train, {"btz": btpred_train})
    all_results_train.append(results_train)

# Report the actual number of evaluated seeds instead of a hard-coded count.
print(f"Test MSE for {len(all_results_test)} seeds:")
for i, res in enumerate(all_results_test):
    print(f"Seed {i}: {res}")

print(f"\nTrain MSE for {len(all_results_train)} seeds:")
for i, res in enumerate(all_results_train):
    print(f"Seed {i}: {res}")

Iterations: 100%|██████████| 11000/11000 [01:19<00:00, 139.23it/s]
INFO:2025-05-22 02:19:43,380:jax._src.xla_bridge:927: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:2025-05-22 02:19:43,383:jax._src.xla_bridge:927: Unable to initialize backend 'tpu': UNIMPLEMENTED: LoadPjrtPlugin is not implemented on windows yet.
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': UNIMPLEMENTED: LoadPjrtPlugin is not implemented on windows yet.


Iteration 1000/3000 P_grow=0.42 P_prune=0.58 A_grow=0.40 A_prune=0.29 (burnin)
Iteration 2000/3000 P_grow=0.67 P_prune=0.33 A_grow=0.12 A_prune=0.00 (burnin)
Iteration 3000/3000 P_grow=0.33 P_prune=0.67 A_grow=0.00 A_prune=0.38


Iterations: 100%|██████████| 3000/3000 [00:07<00:00, 376.57it/s]
Iterations: 100%|██████████| 11000/11000 [01:29<00:00, 123.25it/s]


Iteration 1000/3000 P_grow=0.60 P_prune=0.40 A_grow=0.33 A_prune=0.17 (burnin)
Iteration 2000/3000 P_grow=0.60 P_prune=0.40 A_grow=0.22 A_prune=0.00 (burnin)
Iteration 3000/3000 P_grow=0.33 P_prune=0.67 A_grow=0.20 A_prune=0.30


Iterations: 100%|██████████| 3000/3000 [00:12<00:00, 248.41it/s]
Iterations: 100%|██████████| 11000/11000 [01:21<00:00, 135.72it/s]


Iteration 1000/3000 P_grow=0.54 P_prune=0.46 A_grow=0.14 A_prune=0.00 (burnin)
Iteration 2000/3000 P_grow=0.54 P_prune=0.46 A_grow=0.29 A_prune=0.33 (burnin)
Iteration 3000/3000 P_grow=0.62 P_prune=0.38 A_grow=0.00 A_prune=0.60


Iterations: 100%|██████████| 3000/3000 [00:09<00:00, 327.69it/s]
Iterations: 100%|██████████| 11000/11000 [01:14<00:00, 146.70it/s]


Iteration 1000/3000 P_grow=0.45 P_prune=0.55 A_grow=0.00 A_prune=0.17 (burnin)
Iteration 2000/3000 P_grow=0.45 P_prune=0.55 A_grow=0.40 A_prune=0.33 (burnin)
Iteration 3000/3000 P_grow=0.45 P_prune=0.55 A_grow=0.20 A_prune=0.50


Iterations: 100%|██████████| 3000/3000 [00:08<00:00, 372.19it/s]
Iterations: 100%|██████████| 11000/11000 [01:10<00:00, 154.94it/s]


Iteration 1000/3000 P_grow=0.62 P_prune=0.38 A_grow=0.50 A_prune=0.33 (burnin)
Iteration 2000/3000 P_grow=0.50 P_prune=0.50 A_grow=0.12 A_prune=0.50 (burnin)
Iteration 3000/3000 P_grow=0.81 P_prune=0.19 A_grow=0.15 A_prune=0.67


Iterations: 100%|██████████| 3000/3000 [00:10<00:00, 277.00it/s]


Test MSE for 10 seeds:
Seed 0: {'bart': 0.38985630356073925, 'rf': 0.4052676755213467, 'lr': 0.7653693077399851, 'btz': 0.30705953946465425, 'bart_default': 0.3346751008392659}
Seed 1: {'bart': 0.2591057754461334, 'rf': 0.31865064720389175, 'lr': 0.6055878466068164, 'btz': 0.2777292494154197, 'bart_default': 0.2695558808005837}
Seed 2: {'bart': 0.30967085985221077, 'rf': 0.3446938795774862, 'lr': 0.6075788572945195, 'btz': 0.31552239761185535, 'bart_default': 0.3113181711748485}
Seed 3: {'bart': 0.24868843458976522, 'rf': 0.3108460682730246, 'lr': 0.557049629471702, 'btz': 0.24132353001669746, 'bart_default': 0.2508741656134269}
Seed 4: {'bart': 0.329682766860967, 'rf': 0.2819357385275089, 'lr': 0.7035209920712743, 'btz': 0.24874052833993887, 'bart_default': 0.2807518652596897}

Train MSE for 10 seeds:
Seed 0: {'bart': 0.17420845156102874, 'rf': 0.046037082458847446, 'lr': 0.6035991145784113, 'btz': 0.1857674821288368, 'bart_default': 0.18544787420294293}
Seed 1: {'bart': 0.19509541970