From 4dfb08aa2a2b2ec77caac8c62df65f9fdf38411d Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Mon, 19 Aug 2024 06:43:38 -0300 Subject: [PATCH 1/4] Allow Y to be a tensor --- pymc_bart/pgbart.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index be4a8e8..5bcaf57 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -138,6 +138,12 @@ def __init__( # noqa: PLR0915 else: self.X = self.bart.X + if isinstance(self.bart.Y, Variable): + self.Y = self.bart.Y.eval() + else: + self.Y = self.bart.Y + + self.missing_data = np.any(np.isnan(self.X)) self.m = self.bart.m self.response = self.bart.response @@ -166,7 +172,7 @@ def __init__( # noqa: PLR0915 if rule is ContinuousSplitRule: self.X[:, idx] = jitter_duplicated(self.X[:, idx], np.nanstd(self.X[:, idx])) - init_mean = self.bart.Y.mean() + init_mean = self.Y.mean() self.num_observations = self.X.shape[0] self.num_variates = self.X.shape[1] self.available_predictors = list(range(self.num_variates)) @@ -174,18 +180,18 @@ def __init__( # noqa: PLR0915 # if data is binary self.leaf_sd = np.ones((self.trees_shape, self.leaves_shape)) - y_unique = np.unique(self.bart.Y) + y_unique = np.unique(self.Y) if y_unique.size == 2 and np.all(y_unique == [0, 1]): self.leaf_sd *= 3 / self.m**0.5 else: - self.leaf_sd *= self.bart.Y.std() / self.m**0.5 + self.leaf_sd *= self.Y.std() / self.m**0.5 self.running_sd = [ RunningSd((self.leaves_shape, self.num_observations)) for _ in range(self.trees_shape) ] self.sum_trees = np.full( - (self.trees_shape, self.leaves_shape, self.bart.Y.shape[0]), init_mean + (self.trees_shape, self.leaves_shape, self.Y.shape[0]), init_mean ).astype(config.floatX) self.sum_trees_noi = self.sum_trees - init_mean self.a_tree = Tree.new_tree( From cce5c57dea851ff44e39a68f59b22214a6e1fe01 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Aug 2024 09:43:54 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pymc_bart/pgbart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 5bcaf57..7efdd48 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -142,7 +142,7 @@ def __init__( # noqa: PLR0915 self.Y = self.bart.Y.eval() else: self.Y = self.bart.Y - + self.missing_data = np.any(np.isnan(self.X)) self.m = self.bart.m From 1f98c533a9964c25c0b5485461b50a33a717778c Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Mon, 19 Aug 2024 06:45:48 -0300 Subject: [PATCH 3/4] Update pgbart.py --- pymc_bart/pgbart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 7efdd48..474a5b8 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -138,7 +138,7 @@ def __init__( # noqa: PLR0915 else: self.X = self.bart.X - if isinstance(self.bart.Y, Variable): + if isinstance(self.bart.Y, Variable): self.Y = self.bart.Y.eval() else: self.Y = self.bart.Y From 09d5528688a603ac94ee0a544c5e634877ea8572 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Aug 2024 09:46:00 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pymc_bart/pgbart.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 474a5b8..91a9beb 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -143,7 +143,6 @@ def __init__( # noqa: PLR0915 else: self.Y = self.bart.Y - self.missing_data = np.any(np.isnan(self.X)) self.m = self.bart.m self.response = self.bart.response