In [1]:
import numpy as np
import sys
from pyinstrument import Profiler
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import os
sys.path.append(os.path.dirname(os.getcwd()))
from bart_playground import *

import bartz

In [2]:
proposal_probs = {"grow" : 0.5,
                  "prune" : 0.5}
generator = DataGenerator(n_samples=160, n_features=2, noise=0.1, random_seed=42)
X, y = generator.generate(scenario="piecewise_flat")
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
np.set_printoptions(suppress=True)
print(y_train[:12])

[ 0.50327821  0.60672224  0.26898966  0.55211673  0.50693811  0.66162097
 -0.64127659  0.65112284  0.03487759  0.23276531  0.44055996  0.38216964]


In [3]:
dgen = DataGenerator(n_samples=100, n_features=3, noise=0.1, random_seed=42)
X, y = dgen.generate(scenario="linear")
data = Dataset(X, y, X)

rng = np.random.default_rng()

In [4]:
rng = np.random.default_rng(316)
rng.choice(np.arange(0, 201),size=2)

array([  1, 200])

In [5]:
processor = DefaultPreprocessor(max_bins=100)

In [6]:
prior = NTreePrior(n_trees=100,generator=rng)


In [7]:
sampler = NtreeSampler(prior,proposal_probs=proposal_probs,generator=rng)
#data = processor.fit_transform(X_train,y_train)
sampler.add_data(data)
sampler.run(100)

Running iteration 0
<bart_playground.moves.Combine object at 0x00000181087CA920>
<bart_playground.moves.Break object at 0x0000018105F06CE0>
<bart_playground.moves.Combine object at 0x0000018105F98790>
<bart_playground.moves.Break object at 0x0000018105F06CE0>
<bart_playground.moves.Combine object at 0x000001810A5DAD10>
<bart_playground.moves.Combine object at 0x0000018105F06CE0>
<bart_playground.moves.Break object at 0x000001810A5DAD10>
<bart_playground.moves.Break object at 0x000001810A5DAD10>
<bart_playground.moves.Combine object at 0x000001810A648760>
<bart_playground.moves.Break object at 0x000001810A6B7280>
Running iteration 10
<bart_playground.moves.Combine object at 0x000001810A6B7490>
<bart_playground.moves.Combine object at 0x000001810A6B7D60>
<bart_playground.moves.Break object at 0x000001810A6B7E80>
<bart_playground.moves.Break object at 0x000001810B1C1A50>
<bart_playground.moves.Combine object at 0x000001810B1C2DD0>
<bart_playground.moves.Break object at 0x000001810B1C2DD0>

[<bart_playground.params.Parameters at 0x18105f98b20>,
 <bart_playground.params.Parameters at 0x1810a648a90>,
 <bart_playground.params.Parameters at 0x1810a648ca0>,
 <bart_playground.params.Parameters at 0x1810a64a6b0>,
 <bart_playground.params.Parameters at 0x1810a64b6a0>,
 <bart_playground.params.Parameters at 0x1810a6b46d0>,
 <bart_playground.params.Parameters at 0x1810a6b5c30>,
 <bart_playground.params.Parameters at 0x1810a6b7130>,
 <bart_playground.params.Parameters at 0x1810b1c02e0>,
 <bart_playground.params.Parameters at 0x1810b1c15d0>,
 <bart_playground.params.Parameters at 0x1810b1c2890>,
 <bart_playground.params.Parameters at 0x1810b1c3a00>,
 <bart_playground.params.Parameters at 0x1810b218970>,
 <bart_playground.params.Parameters at 0x1810b21a050>,
 <bart_playground.params.Parameters at 0x1810b21b6d0>,
 <bart_playground.params.Parameters at 0x1810b26cd90>,
 <bart_playground.params.Parameters at 0x1810b26dba0>,
 <bart_playground.params.Parameters at 0x1810b26f460>,
 <bart_pla

In [8]:
bart = ChangeNumTreeBART(ndpost=300, nskip=100, n_trees=100, proposal_probs=proposal_probs)
bart.fit(X_train, y_train)

Running iteration 0
<bart_playground.moves.Break object at 0x000001810DA37E20>
<bart_playground.moves.Combine object at 0x000001810DA37EE0>
<bart_playground.moves.Combine object at 0x000001810DA37C70>
<bart_playground.moves.Combine object at 0x000001810DA37E20>
<bart_playground.moves.Combine object at 0x000001810DA37E20>
<bart_playground.moves.Break object at 0x000001810DA8E650>
<bart_playground.moves.Break object at 0x000001810DA37E20>
<bart_playground.moves.Break object at 0x000001810DA8EFE0>
<bart_playground.moves.Break object at 0x000001810DAE47F0>
<bart_playground.moves.Break object at 0x000001810DAE73D0>
Running iteration 10
<bart_playground.moves.Break object at 0x000001810DAE73D0>
<bart_playground.moves.Combine object at 0x000001810DB33940>
<bart_playground.moves.Break object at 0x000001810DB33580>
<bart_playground.moves.Combine object at 0x000001810DB333A0>
<bart_playground.moves.Combine object at 0x000001810DB33850>
<bart_playground.moves.Combine object at 0x000001810DFCD720>

In [9]:
profiler = Profiler()
profiler.start()
bart = ChangeNumTreeBART(ndpost=300, nskip=100, n_trees=100, proposal_probs=proposal_probs)
bart.fit(X_train, y_train)
profiler.stop()
profiler.print()

Running iteration 0
<bart_playground.moves.Break object at 0x0000018105AFB070>
<bart_playground.moves.Combine object at 0x0000018105F070A0>
<bart_playground.moves.Combine object at 0x0000018105F070A0>
<bart_playground.moves.Combine object at 0x000001810DA8DCF0>
<bart_playground.moves.Combine object at 0x000001810A5D9720>
<bart_playground.moves.Break object at 0x000001810DA8DCF0>
<bart_playground.moves.Break object at 0x000001810A5D9720>
<bart_playground.moves.Break object at 0x0000018118F82530>
<bart_playground.moves.Break object at 0x000001810DA8DCF0>
<bart_playground.moves.Break object at 0x000001810DA8DCF0>
Running iteration 10
<bart_playground.moves.Break object at 0x0000018118FF7730>
<bart_playground.moves.Combine object at 0x000001810DA8DCF0>
<bart_playground.moves.Break object at 0x000001810DA8DCF0>
<bart_playground.moves.Combine object at 0x00000181199C3A30>
<bart_playground.moves.Combine object at 0x00000181199C2DD0>
<bart_playground.moves.Combine object at 0x00000181199C35B0>

In [10]:
rf = RandomForestRegressor()
lr = LinearRegression()
rf.fit(X_train, y_train)
lr.fit(X_train, y_train)

btz = bartz.BART.gbart(np.transpose(X_train), y_train, ntree=100, ndpost=200, nskip=100)
btpred_all = btz.predict(np.transpose(X_test))
btpred = np.mean(np.array(btpred_all), axis=0)

Iteration 100/300 P_grow=0.55 P_prune=0.45 A_grow=0.36 A_prune=0.36 (burnin)
Iteration 200/300 P_grow=0.57 P_prune=0.43 A_grow=0.35 A_prune=0.37
Iteration 300/300 P_grow=0.57 P_prune=0.43 A_grow=0.39 A_prune=0.40


In [11]:
bart.predict(X_test)

TypeError: list indices must be integers or slices, not list

In [None]:

models = {"bart" : bart, 
          "rf" : rf, 
          "lr" : lr,
          "btz" : btz}
results = {}
for model_name, model in models.items():
    if model_name == "btz":
        results[model_name] = mean_squared_error(y_test, btpred)
    else:
        results[model_name] = mean_squared_error(y_test, model.predict(X_test))
results

TypeError: list indices must be integers or slices, not list