Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pymc_bart/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class PGBART(ArrayStepShared):
default_blocked = False
generates_stats = True
stats_dtypes_shapes: dict[str, tuple[type, list]] = {
"variable_inclusion": (int, []),
"variable_inclusion": (object, []),
"tune": (bool, []),
}

Expand Down
68 changes: 34 additions & 34 deletions pymc_bart/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pylint: disable=too-many-branches
"""Utility function for variable selection and bart interpretability."""

import base64
import warnings
from collections.abc import Callable
from typing import Any, TypeVar
Expand Down Expand Up @@ -708,7 +709,7 @@ def get_variable_inclusion(idata, X, model=None, bart_var_name=None, labels=None
"""
n_vars = X.shape[1]
vi_xarray = idata["sample_stats"]["variable_inclusion"]
if "variable_inclusion_dim_0" in vi_xarray.coords:
if vi_xarray.variable_inclusion_dim_0.size > 1:
if model is None or bart_var_name is None:
raise ValueError(
"The InfereceData was generated from a model with multiple BART variables, \n"
Expand All @@ -727,13 +728,13 @@ def get_variable_inclusion(idata, X, model=None, bart_var_name=None, labels=None
n_vars = len(indices)

if hasattr(X, "columns") and hasattr(X, "to_numpy"):
labels = list(X.columns)
labels = list(X.columns[indices])

if labels is None:
labels = [str(i) for i in range(n_vars)]
labels = [str(i) for i in indices]

if to_kulprit:
return [labels[:idx] for idx in range(n_vars)]
return [labels[:idx] for idx in range(n_vars + 1)]
else:
return VI_norm[indices], labels

Expand Down Expand Up @@ -884,7 +885,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912

if method in ["VI", "backward_VI"]:
vi_xarray = idata["sample_stats"]["variable_inclusion"]
if "variable_inclusion_dim_0" in vi_xarray.coords:
if vi_xarray.variable_inclusion_dim_0.size > 1:
if model is None:
raise ValueError(
"The InfereceData was generated from a model with multiple BART variables, \n"
Expand Down Expand Up @@ -968,7 +969,9 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912

# Save values for plotting later
r2_mean[i_var - init] = max_r_2
r2_hdi[i_var - init] = array_stats.hdi(r_2_without_least_important_vars)
r2_hdi[i_var - init] = array_stats.hdi(
r_2_without_least_important_vars, prob=rcParams["stats.ci_prob"]
)
preds[i_var - init] = least_important_samples.squeeze()

# extend current list of least important variable
Expand Down Expand Up @@ -1282,37 +1285,34 @@ def _plot_hdi(x, y, smooth, color, alpha, smooth_kwargs, ax):
return ax


def _decode_vi(s: str, length: int) -> list[int]:
    """
    Decode a base64 string back into a variable inclusion vector.

    The base64 payload is read as a sequence of variable-length unsigned
    integers: each byte carries 7 value bits (least-significant group first)
    and its high bit (0x80) signals that another byte of the same integer
    follows. This is the inverse of ``_encode_vi``.

    Parameters
    ----------
    s : str
        Base64-encoded (ASCII) string.
    length : int
        Maximum number of integers to decode.

    Returns
    -------
    list[int]
        The decoded integers. Contains fewer than ``length`` items when the
        payload runs out of bytes first.
    """
    data = base64.b64decode(s)
    result: list[int] = []
    i = 0
    while len(result) < length and i < len(data):
        num = 0
        shift = 0
        # Accumulate 7-bit groups until a byte without the continuation bit.
        while i < len(data):
            byte = data[i]
            i += 1
            num |= (byte & 0x7F) << shift
            if not (byte & 0x80):
                break
            shift += 7
        result.append(num)
    return result


def _encode_vi(vec: list[int]) -> str:
    """
    Encode a variable inclusion vector as a base64 string.

    Each non-negative integer is written as a variable-length byte sequence:
    7 value bits per byte, least-significant group first, with the high bit
    (0x80) set on every byte except the last one of that integer. The
    concatenated bytes are then base64-encoded.

    Parameters
    ----------
    vec : list[int]
        Non-negative integers to encode.

    Returns
    -------
    str
        ASCII base64 string holding the encoded vector.

    Raises
    ------
    ValueError
        If ``vec`` contains a negative value, which this byte layout cannot
        represent (masking it would silently encode a different number).
    """
    result = bytearray()
    for num in vec:
        n = num
        if n < 0:
            raise ValueError(f"Cannot encode negative value {n}.")
        # Emit the low 7 bits with the continuation flag while more remain.
        while n > 127:
            result.append((n & 0x7F) | 0x80)
            n >>= 7
        result.append(n & 0x7F)
    return base64.b64encode(bytes(result)).decode("ascii")
Loading