# Testing BCF Incremental Updating with Multiple Treatment Arms

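This notebook demonstrates incremental updating (`update_fit`) of a BCF (Bayesian Causal Forest) model with multiple treatment arms. A model is first fit on a small initial sample. It is then updated with new data either in one batch or one observation at a time. Both updated models are compared on a held-out test set against the initial model and against a model refit from scratch on all the data.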

In [1]:
import numpy as np
from tqdm import tqdm

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

from bart_playground.bcf.bcf import BCF
from bart_playground import *

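## Generate synthetic data with multiple treatment arms

Arm 1 is assigned with probability 0.4. Arm 2 is assigned with probability 0.3 only among units not in arm 1, so each unit receives at most one treatment. Arm 1 shifts the outcome by +0.5 and arm 2 by −0.5. The 500 samples are split into 20 initial-training, 380 incremental, and 100 test observations.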

In [2]:
# Generate data
n_samples = 500
n_features = 2
n_treat_arms = 2

# Split sizes: a small initial training set, a held-out test set, and the
# remainder (n_samples - n_init - n_test) used as incremental data.
n_init = 20
n_test = 100

# Generate covariates and the base (untreated) outcome
generator = DataGenerator(n_samples=n_samples, n_features=n_features, noise=0.1, random_seed=42)
X, y_base = generator.generate(scenario="piecewise_flat")

# Generate treatment indicators. Arm 2 is only drawn among units NOT in arm 1,
# so the arms are mutually exclusive (each unit gets at most one treatment).
z_rng = np.random.default_rng(0)
z1 = z_rng.binomial(1, 0.4, n_samples).astype(bool)
z2 = ~z1 & (z_rng.binomial(1, 0.3, n_samples) == 1)
z = np.column_stack((z1, z2))
assert z.shape == (n_samples, n_treat_arms), "one indicator column per treatment arm"
assert not (z1 & z2).any(), "treatment arms must be mutually exclusive"

# Add constant treatment effects: +0.5 for arm 1, -0.5 for arm 2
y = y_base + z[:, 0] * 0.5 - z[:, 1] * 0.5

# Split into initial training, incremental and test sets.
# Integer sizes are used instead of fractions such as 20/96: sklearn converts a
# float test_size to a count with ceil(), so a float product landing just above
# a whole number would shift the split by one sample.
X_init, X_temp, y_init, y_temp, z_init, z_temp = train_test_split(X, y, z, train_size=n_init, random_state=42)
X_inc, X_test, y_inc, y_test, z_inc, z_test = train_test_split(X_temp, y_temp, z_temp, test_size=n_test, random_state=42)
assert X_init.shape[0] == n_init and X_test.shape[0] == n_test
assert X_inc.shape[0] == n_samples - n_init - n_test

print(f"Initial training set size: {X_init.shape[0]}")
print(f"Incremental data set size: {X_inc.shape[0]}")
print(f"Test set size: {X_test.shape[0]}")

# Preview the data
np.set_printoptions(suppress=True)
print("\nSample initial training data:")
print("X:", X_init[:3])
print("y:", y_init[:3])
print("z:", z_init[:3])

Initial training set size: 20
Incremental data set size: 380
Test set size: 100

Sample initial training data:
X: [[0.14617324 0.82466419]
 [0.53606804 0.51422287]
 [0.65292213 0.95534943]]
y: [0.54945567 0.53520345 0.52929338]
z: [[False False]
 [False False]
 [False False]]


## Initialize and fit BCF model with initial data

In [3]:
# Initialize BCF model (fit on the small initial sample only)
bcf = BCF(
    n_treat_arms=n_treat_arms,          # Number of treatment arms
    n_mu_trees=100,                     # Number of prognostic effect trees
    n_tau_trees=[50] * n_treat_arms,    # Treatment effect trees; list length tracks n_treat_arms (presumably one entry per arm)
    ndpost=100,                         # Posterior samples
    nskip=100,                          # Burn-in iterations
    random_state=42
)

# Fit model with initial training data
bcf.fit(X_init, y_init, z_init)

Iterations: 100%|██████████| 200/200 [00:06<00:00, 28.98it/s]


## Update model with incremental data

In [4]:
# Both updated models use the same settings and seed as the initial model, so
# the only difference between them is how the incremental data is fed in.
bcf_params = dict(
    n_treat_arms=n_treat_arms,
    n_mu_trees=100,
    n_tau_trees=[50] * n_treat_arms,  # list length tracks n_treat_arms
    ndpost=100,
    nskip=100,
    random_state=42,
)

# Method 1: fit on the initial data, then update with all incremental data at once
bcf_batch_update = BCF(**bcf_params)
bcf_batch_update.fit(X_init, y_init, z_init)
bcf_batch_update.update_fit(X_inc, y_inc, z_inc, add_ndpost=100, add_nskip=100)

# Method 2: fit on the initial data, then update one sample at a time
bcf_incremental = BCF(**bcf_params)
bcf_incremental.fit(X_init, y_init, z_init)

inc_n = X_inc.shape[0]
for i in tqdm(range(inc_n)):
    # Slice with i:(i+1) so X and z stay 2-D and y stays 1-D
    bcf_incremental.update_fit(
        X_inc[i:(i+1), :],
        y_inc[i:(i+1)],
        z_inc[i:(i+1), :],
        # add_ndpost/add_nskip presumably extend ndpost/nskip for this update:
        # one extra draw and no extra burn-in per new observation
        add_ndpost=1,
        add_nskip=0,
        quietly=True
    )
    

Iterations: 100%|██████████| 200/200 [00:03<00:00, 61.55it/s]
Iterations: 100%|██████████| 200/200 [00:03<00:00, 56.88it/s]
Iterations: 100%|██████████| 200/200 [00:03<00:00, 61.47it/s]
100%|██████████| 380/380 [00:11<00:00, 31.84it/s]


## Evaluate all models on the held-out test set

In [5]:
# Held-out predictions from each model, in the order: initial, batch-updated, incremental
before_update_preds, batch_preds, incremental_preds = (
    fitted.predict(X_test, z_test)
    for fitted in (bcf, bcf_batch_update, bcf_incremental)
)

# Test-set mean squared error for each set of predictions
before_update_mse, batch_mse, incremental_mse = (
    mean_squared_error(y_test, preds)
    for preds in (before_update_preds, batch_preds, incremental_preds)
)

print(f"MSE before any updates: {before_update_mse:.6f}")
print(f"MSE after batch update: {batch_mse:.6f}")
print(f"MSE after incremental updates: {incremental_mse:.6f}")

MSE before any updates: 0.060028
MSE after batch update: 0.025059
MSE after incremental updates: 0.025569


## Examine model internals

In [6]:
# Print the same mu (prognostic) tree from each model. Index 10 is arbitrary;
# trace[-1] is presumably each model's most recent stored posterior draw.
mu_tree_idx = 10
snapshots = (
    ("Sample mu tree before updates:", bcf),
    ("\nSample mu tree after batch update:", bcf_batch_update),
    ("\nSample mu tree after incremental updates:", bcf_incremental),
)
for heading, fitted in snapshots:
    print(heading)
    print(fitted.trace[-1].mu_trees[mu_tree_idx])

# NOTE(review): in the saved output, the incrementally updated tree's root split
# reports n = 21 while its two leaves report n = 267 + 133 = 400. The split-node
# count appears not to be refreshed by update_fit — TODO investigate.

Sample mu tree before updates:
X_3 <= 0.150 (split, n = 20)
	Val: -0.026 (leaf, n = 11)
	Val: -0.022 (leaf, n = 9)

Sample mu tree after batch update:
X_3 <= 0.229 (split, n = 400)
	X_0 <= 0.205 (split, n = 328)
		Val: 0.026 (leaf, n = 60)
		Val: -0.009 (leaf, n = 268)
	X_2 <= 0.400 (split, n = 72)
		Val: -0.040 (leaf, n = 65)
		Val: 0.045 (leaf, n = 7)

Sample mu tree after incremental updates:
X_3 <= 0.148 (split, n = 21)
	Val: -0.032 (leaf, n = 267)
	Val: 0.014 (leaf, n = 133)


## Fit a new model with all data at once for comparison

In [7]:
# Create full dataset: initial + incremental data (the test set stays held out)
X_full = np.vstack([X_init, X_inc])
y_full = np.concatenate([y_init, y_inc])
z_full = np.vstack([z_init, z_inc])
assert X_full.shape[0] == y_full.shape[0] == z_full.shape[0]

# Fit a new model from scratch with all data, using the same settings as the other models
bcf_full = BCF(
    n_treat_arms=n_treat_arms,
    n_mu_trees=100,
    n_tau_trees=[50] * n_treat_arms,  # list length tracks n_treat_arms
    ndpost=100,
    nskip=100,
    random_state=42
)
bcf_full.fit(X_full, y_full, z_full)

# Get predictions and MSE
full_preds = bcf_full.predict(X_test, z_test)
full_mse = mean_squared_error(y_test, full_preds)

print(f"MSE with model fit on all data at once: {full_mse:.6f}")

Iterations: 100%|██████████| 200/200 [00:03<00:00, 54.16it/s]


MSE with model fit on all data at once: 0.021381


## Compare all approaches

In [8]:
# Pair each approach's label with its test-set MSE, in the order the models were built
labels = ["Initial training only", "Batch update", "Incremental updates", "Full data from scratch"]
scores = [before_update_mse, batch_mse, incremental_mse, full_mse]
results = dict(zip(labels, scores))

for name, mse in results.items():
    print(f"{name}: {mse:.6f}")

Initial training only: 0.060028
Batch update: 0.025059
Incremental updates: 0.025569
Full data from scratch: 0.021381
