https://towardsdatascience.com/a-primer-to-bayesian-additive-regression-tree-with-r-b9d0dbf704d

In [14]:
from pathlib import Path

import os
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import pymc_bart as pmb

from sklearn.model_selection import train_test_split

# Ask the inline matplotlib backend for "retina" (higher-resolution) figures.
%config InlineBackend.figure_format = "retina"

# Record the PyMC version in the notebook output for reproducibility.
print(f"Running on PyMC v{pm.__version__}")

Running on PyMC v5.3.1


In [15]:
# Project root: the parent of the current working directory
root_dir = os.path.dirname(os.path.abspath(''))
# Load the filtered dataset stored under <root>/data
csv_path = os.path.join(root_dir, 'data', 'filtered_data.csv')
df = pd.read_csv(csv_path)

In [9]:
# Fixed seed: applied to NumPy's global RNG here and passed to pm.sample later.
RANDOM_SEED = 5781
np.random.seed(RANDOM_SEED)
# Use the "arviz-darkgrid" plotting style for the figures in this notebook.
az.style.use("arviz-darkgrid")

In [10]:
# Prefer a local copy of the coal dataset; if it is missing, fall back to
# the file returned by pm.get_data.
local_coal_path = Path("..", "data", "coal.csv")
try:
    coal = np.loadtxt(local_coal_path)
except FileNotFoundError:
    coal = np.loadtxt(pm.get_data("coal.csv"))

In [11]:
# Discretize the coal data into bins of roughly four years each
years = int(coal.max() - coal.min())
bins = years // 4
hist, x_edges = np.histogram(coal, bins=bins)
# Bin centers: left edge plus half of the (first) bin width
bin_width = x_edges[1] - x_edges[0]
x_centers = x_edges[:-1] + bin_width / 2
# BART expects a 2D design matrix, so add a trailing column axis
x_data = x_centers[:, None]
# Response: the histogram counts per bin (not yet divided by the bin width)
y_data = hist

In [20]:
# Covariate columns of df that are passed to BART as the design matrix
feature_columns = [
    'hombre',
    'minoria_raza',
    'minoria_lgbt',
    'discapacidad',
    'enfermedad',
    'contacto_familia',
    'recibe_ayuda',
    'anios_educacion',
    'consume_drogas',
    'edad_promedio_inicio_consumo',
    'edad',
]
x_data = df[feature_columns]
# Response column modelled by the Poisson likelihood
y_data = df['anios_en_calle']

In [21]:
with pm.Model() as model_coal:
    # μ is obtained from the BART output through an exp link (next line), so
    # BART itself works on the log scale. The Y handed to BART must therefore be
    # on the log scale too; passing the raw response puts BART's data-based
    # initialisation on the wrong scale. log1p is used instead of log so that
    # any zero responses stay finite.
    mu_ = pmb.BART("μ_", X=x_data, Y=np.log1p(y_data), m=20)
    # Map back to the positive scale required for a Poisson mean.
    mu = pm.Deterministic("μ", pm.math.exp(mu_))
    y_pred = pm.Poisson("y_pred", mu=mu, observed=y_data)
    idata_coal = pm.sample(random_seed=RANDOM_SEED)

Multiprocess sampling (4 chains in 4 jobs)
PGBART: [μ_]


Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 60 seconds.


ContextualVersionConflict: (arviz 0.11.2 (/opt/miniconda3/envs/stats_env/lib/python3.10/site-packages), Requirement.parse('arviz>=0.13.0'), {'pymc'})

In [None]:
_, ax = plt.subplots(figsize=(10, 6))

# Posterior of the Poisson mean μ: one value per row of x_data / y_data.
# (The coal-example leftovers x_centers, coal and the division by the bin
# width 4 do not match the data the model was fitted on, so they are not used.)
rates = idata_coal.posterior["μ"]
rate_mean = rates.mean(dim=["draw", "chain"])

# With several covariates there is no single natural x-axis, so order the
# observations by their posterior mean and plot against that rank.
row_order = np.argsort(rate_mean.values, kind="stable")
obs_rank = np.arange(row_order.size)
rates_sorted = rates.values[..., row_order]  # (chain, draw, observation)

ax.plot(obs_rank, rate_mean.values[row_order], "w", lw=3)
ax.plot(obs_rank, np.asarray(y_data)[row_order], "k.")
# Default HDI band plus a 50% band drawn without an outline.
az.plot_hdi(obs_rank, rates_sorted, smooth=False, ax=ax)
az.plot_hdi(
    obs_rank,
    rates_sorted,
    hdi_prob=0.5,
    smooth=False,
    plot_kwargs={"alpha": 0},
    ax=ax,
)
ax.set_xlabel("observations (sorted by posterior mean of μ)")
ax.set_ylabel("μ / observed response");

In [None]:
# Two individual posterior draws of the rate, shown as step functions of the
# first covariate. x_data may be a DataFrame with several columns, which
# plt.step cannot pair with the two draw columns, so take column 0 as a 1-D
# array and sort the rows by it so the steps are drawn left to right.
x_first = np.asarray(x_data)[:, 0]
row_order = np.argsort(x_first, kind="stable")
draw_values = np.asarray(rates.sel(chain=0, draw=[3, 10]).T)  # (observation, draw)
plt.step(x_first[row_order], draw_values[row_order]);

In [None]:
# Trees stored on the BART op; bart_trees[0][i] is indexed below to plot the
# predictions of three individual trees.
bart_trees = mu_.owner.op.all_trees
# x_data may be a DataFrame: positional slicing (x_data[:, 0]) is invalid on it
# and iterating it yields column names, not rows. Work on a NumPy array instead.
x_values = np.asarray(x_data)
# Sort rows by the first covariate so each step plot is drawn left to right.
x_sorted = x_values[np.argsort(x_values[:, 0], kind="stable")]
for i in [0, 1, 2]:
    plt.step(x_sorted[:, 0], [bart_trees[0][i].predict(x) for x in x_sorted])