diff --git a/econml/_cate_estimator.py b/econml/_cate_estimator.py index 3c453b291..c1f9e26d8 100644 --- a/econml/_cate_estimator.py +++ b/econml/_cate_estimator.py @@ -599,7 +599,6 @@ def effect(self, X=None, *, T0, T1): # of rows of T was not taken into account if X is None: eff = np.repeat(eff, shape(T0)[0], axis=0) - m = shape(eff)[0] dT = T1 - T0 einsum_str = 'myt,mt->my' if ndim(dT) == 1: @@ -852,7 +851,7 @@ def _expand_treatments(self, X=None, *Ts, transform=True): n_rows = 1 if X is None else shape(X)[0] outTs = [] for T in Ts: - if (ndim(T) == 0) and self._d_t_in and self._d_t_in[0] > 1: + if (ndim(T) == 0) and self._d_t_in and self._d_t_in[0] > 1 and T != 0: warn("A scalar was specified but there are multiple treatments; " "the same value will be used for each treatment. Consider specifying" "all treatments, or using the const_marginal_effect method.")