diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 72764b2..05e8f0d 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -118,7 +118,7 @@ class PGBART(ArrayStepShared): default_blocked = False generates_stats = True stats_dtypes_shapes: dict[str, tuple[type, list]] = { - "variable_inclusion": (int, []), + "variable_inclusion": (object, []), "tune": (bool, []), } diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index dfb5eac..e7b0f6e 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -1,6 +1,7 @@ # pylint: disable=too-many-branches """Utility function for variable selection and bart interpretability.""" +import base64 import warnings from collections.abc import Callable from typing import Any, TypeVar @@ -708,7 +709,7 @@ def get_variable_inclusion(idata, X, model=None, bart_var_name=None, labels=None """ n_vars = X.shape[1] vi_xarray = idata["sample_stats"]["variable_inclusion"] - if "variable_inclusion_dim_0" in vi_xarray.coords: + if vi_xarray.variable_inclusion_dim_0.size > 1: if model is None or bart_var_name is None: raise ValueError( "The InfereceData was generated from a model with multiple BART variables, \n" @@ -727,13 +728,13 @@ def get_variable_inclusion(idata, X, model=None, bart_var_name=None, labels=None n_vars = len(indices) if hasattr(X, "columns") and hasattr(X, "to_numpy"): - labels = list(X.columns) + labels = list(X.columns[indices]) if labels is None: - labels = [str(i) for i in range(n_vars)] + labels = [str(i) for i in indices] if to_kulprit: - return [labels[:idx] for idx in range(n_vars)] + return [labels[:idx] for idx in range(n_vars + 1)] else: return VI_norm[indices], labels @@ -884,7 +885,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 if method in ["VI", "backward_VI"]: vi_xarray = idata["sample_stats"]["variable_inclusion"] - if "variable_inclusion_dim_0" in vi_xarray.coords: + if vi_xarray.variable_inclusion_dim_0.size > 1: if model is None: raise ValueError( "The InfereceData was generated from a model with multiple BART variables, \n" @@ -968,7 +969,9 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 # Save values for plotting later r2_mean[i_var - init] = max_r_2 - r2_hdi[i_var - init] = array_stats.hdi(r_2_without_least_important_vars) + r2_hdi[i_var - init] = array_stats.hdi( + r_2_without_least_important_vars, prob=rcParams["stats.ci_prob"] + ) preds[i_var - init] = least_important_samples.squeeze() # extend current list of least important variable @@ -1282,37 +1285,34 @@ def _plot_hdi(x, y, smooth, color, alpha, smooth_kwargs, ax): return ax -def _decode_vi(n: int, length: int) -> list[int]: - """ - Decode the variable inclusion from the BART model. - """ - bits = bin(n)[2:] - vi_list: list[int] = [] +def _decode_vi(s: str, length: int) -> list[int]: + """Decode base64 string back to vector.""" + data = base64.b64decode(s) + result: list[int] = [] i = 0 - while len(vi_list) < length: - # Count prefix ones - prefix_len = 0 - while bits[i] == "1": - prefix_len += 1 + while len(result) < length and i < len(data): + num = 0 + shift = 0 + while i < len(data): + byte = data[i] i += 1 - i += 1 # skip the '0' - b = bits[i : i + prefix_len] - vi_list.append(int(b, 2)) - i += prefix_len - return vi_list + num |= (byte & 0x7F) << shift + if not (byte & 0x80): + break + shift += 7 + result.append(num) + return result -def _encode_vi(vec: npt.NDArray) -> int: +def _encode_vi(vec: list[int]) -> str: """ - Encode variable inclusion vector into a single integer. - - The encoding is done by converting each element of the vector into a binary string, - where each element contributes a prefix of '1's followed by a '0' and its binary representation. - The final result is the integer representation of the concatenated binary string. + Encode vector to base64 string. """ - bits = "" - for x in vec: - b = bin(x)[2:] - prefix = "1" * len(b) + "0" - bits += prefix + b - return int(bits, 2) + result = bytearray() + for num in vec: + n = num + while n > 127: + result.append((n & 0x7F) | 0x80) + n >>= 7 + result.append(n & 0x7F) + return base64.b64encode(bytes(result)).decode("ascii")