From 83851c94212a9d3c6cc39f99e9a30bb25d4846d1 Mon Sep 17 00:00:00 2001 From: Keith Battocchi Date: Wed, 26 Apr 2023 12:38:18 -0400 Subject: [PATCH] Fix #760 Signed-off-by: Keith Battocchi Signed-off-by: kgao --- econml/_ortho_learner.py | 2 +- econml/tests/test_dml.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/econml/_ortho_learner.py b/econml/_ortho_learner.py index fdc9e7693..39a300ade 100644 --- a/econml/_ortho_learner.py +++ b/econml/_ortho_learner.py @@ -899,7 +899,7 @@ def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None): nuisances = [np.zeros((n_iters * n_splits,) + nuis.shape) for nuis in nuisance_temp] for it, nuis in enumerate(nuisance_temp): - nuisances[it][i * n_iters + j] = nuis + nuisances[it][j * n_iters + i] = nuis for it in range(len(nuisances)): nuisances[it] = np.mean(nuisances[it], axis=0) diff --git a/econml/tests/test_dml.py b/econml/tests/test_dml.py index 8105f7ec7..57b5c3ec4 100644 --- a/econml/tests/test_dml.py +++ b/econml/tests/test_dml.py @@ -1095,6 +1095,7 @@ def test_nuisance_scores(self): est.fit(y, T, X=X, W=W) assert len(est.nuisance_scores_t) == len(est.nuisance_scores_y) == mc_iters assert len(est.nuisance_scores_t[0]) == len(est.nuisance_scores_y[0]) == cv + est.score(y, T, X=X, W=W) def test_categories(self): dmls = [LinearDML, SparseLinearDML]