
Refutation & Overlap Error ("data_subset_refuter", "add_unobserved_common_cause", "assess_support_and_overlap_overrule") #1185

Closed
benTC74 opened this issue May 15, 2024 · 2 comments
Labels
question Further information is requested stale

Comments


benTC74 commented May 15, 2024

Hi All,

I am running into the following errors when performing support/overlap and refutation tests for a continuous treatment variable. The other features contain continuous, binary, and categorical data. Any help is much appreciated. Thank you!

  1. refute_estimate with the method "data_subset_refuter" fails with AssertionError: Input arrays have incompatible lengths: 154 and 192, even though most of the other refutation tests run with no errors at all.
res_subset = cf_model.refute_estimate(
    method_name="data_subset_refuter", subset_fraction=0.8, 
    num_simulations=5)
print(res_subset)

Error:
AssertionError Traceback (most recent call last)
Cell In[585], line 1
----> 1 res_subset = cf_model.refute_estimate(
2 method_name="data_subset_refuter", subset_fraction=0.8,
3 num_simulations=3)
4 print(res_subset)

File ~\AppData\Local\anaconda3\Lib\site-packages\econml\dowhy.py:221, in DoWhyWrapper.refute_estimate(self, method_name, **kwargs)
192 def refute_estimate(self, *, method_name, **kwargs):
193 """
194 Refute an estimated causal effect.
195
(...)
219 RefuteResult: an instance of the RefuteResult class
220 """
--> 221 return self.dowhy_.refute_estimate(
222 self.identified_estimand_, self.estimate_, method_name=method_name, **kwargs
223 )

File ~\AppData\Local\anaconda3\Lib\site-packages\dowhy\causal_model.py:459, in CausalModel.refute_estimate(self, estimand, estimate, method_name, show_progress_bar, **kwargs)
456 refuter_class = causal_refuters.get_class_object(method_name)
458 refuter = refuter_class(self._data, identified_estimand=estimand, estimate=estimate, **kwargs)
--> 459 res = refuter.refute_estimate(show_progress_bar)
460 return res

File ~\AppData\Local\anaconda3\Lib\site-packages\dowhy\causal_refuters\data_subset_refuter.py:48, in DataSubsetRefuter.refute_estimate(self, show_progress_bar)
47 def refute_estimate(self, show_progress_bar: bool = False):
---> 48 refute = refute_data_subset(
49 data=self._data,
50 target_estimand=self._target_estimand,
51 estimate=self._estimate,
52 subset_fraction=self._subset_fraction,
53 num_simulations=self._num_simulations,
54 random_state=self._random_state,
55 show_progress_bar=show_progress_bar,
56 n_jobs=self._n_jobs,
57 verbose=self._verbose,
58 )
60 refute.add_refuter(self)
61 return refute

File ~\AppData\Local\anaconda3\Lib\site-packages\dowhy\causal_refuters\data_subset_refuter.py:122, in refute_data_subset(data, target_estimand, estimate, subset_fraction, num_simulations, random_state, show_progress_bar, n_jobs, verbose, **_)
115 logger.info(
116 "Refutation over {} simulated datasets of size {} each".format(
117 subset_fraction, subset_fraction * len(data.index)
118 )
119 )
121 # Run refutation in parallel
--> 122 sample_estimates = Parallel(n_jobs=n_jobs, verbose=verbose)(
123 delayed(_refute_once)(data, target_estimand, estimate, subset_fraction, random_state)
124 for _ in tqdm(
125 range(num_simulations),
126 colour=CausalRefuter.PROGRESS_BAR_COLOR,
127 disable=not show_progress_bar,
128 desc="Refuting Estimates: ",
129 )
130 )
131 sample_estimates = np.array(sample_estimates)
133 refute = CausalRefutation(estimate.value, np.mean(sample_estimates), refutation_type="Refute: Use a subset of data")

File ~\AppData\Local\anaconda3\Lib\site-packages\joblib\parallel.py:1085, in Parallel.__call__(self, iterable)
1076 try:
1077 # Only set self._iterating to True if at least a batch
1078 # was dispatched. In particular this covers the edge
(...)
1082 # was very quick and its callback already dispatched all the
1083 # remaining jobs.
1084 self._iterating = False
-> 1085 if self.dispatch_one_batch(iterator):
1086 self._iterating = self._original_iterator is not None
1088 while self.dispatch_one_batch(iterator):

File ~\AppData\Local\anaconda3\Lib\site-packages\joblib\parallel.py:901, in Parallel.dispatch_one_batch(self, iterator)
899 return False
900 else:
--> 901 self._dispatch(tasks)
902 return True

File ~\AppData\Local\anaconda3\Lib\site-packages\joblib\parallel.py:819, in Parallel._dispatch(self, batch)
817 with self._lock:
818 job_idx = len(self._jobs)
--> 819 job = self._backend.apply_async(batch, callback=cb)
820 # A job can complete so quickly than its callback is
821 # called before we get here, causing self._jobs to
822 # grow. To ensure correct results ordering, .insert is
823 # used (rather than .append) in the following line
824 self._jobs.insert(job_idx, job)

File ~\AppData\Local\anaconda3\Lib\site-packages\joblib\_parallel_backends.py:208, in SequentialBackend.apply_async(self, func, callback)
206 def apply_async(self, func, callback=None):
207 """Schedule a func to be run"""
--> 208 result = ImmediateResult(func)
209 if callback:
210 callback(result)

File ~\AppData\Local\anaconda3\Lib\site-packages\joblib\_parallel_backends.py:597, in ImmediateResult.__init__(self, batch)
594 def __init__(self, batch):
595 # Don't delay the application, to avoid keeping the input
596 # arguments in memory
--> 597 self.results = batch()

File ~\AppData\Local\anaconda3\Lib\site-packages\joblib\parallel.py:288, in BatchedCalls.__call__(self)
284 def __call__(self):
285 # Set the default nested backend to self._backend but do not set the
286 # change the default number of processes to -1
287 with parallel_backend(self._backend, n_jobs=self._n_jobs):
--> 288 return [func(*args, **kwargs)
289 for func, args, kwargs in self.items]

File ~\AppData\Local\anaconda3\Lib\site-packages\joblib\parallel.py:288, in <listcomp>(.0)
284 def __call__(self):
285 # Set the default nested backend to self._backend but do not set the
286 # change the default number of processes to -1
287 with parallel_backend(self._backend, n_jobs=self._n_jobs):
--> 288 return [func(*args, **kwargs)
289 for func, args, kwargs in self.items]

File ~\AppData\Local\anaconda3\Lib\site-packages\dowhy\causal_refuters\data_subset_refuter.py:77, in _refute_once(data, target_estimand, estimate, subset_fraction, random_state)
74 new_data = data.sample(frac=subset_fraction, random_state=random_state)
76 new_estimator = estimate.estimator.get_new_estimator_object(target_estimand)
---> 77 new_estimator.fit(
78 new_data,
79 estimate.estimator._effect_modifier_names,
80 **new_estimator._econml_fit_params if isinstance(new_estimator, Econml) else {},
81 )
82 new_effect = new_estimator.estimate_effect(
83 new_data,
84 control_value=estimate.control_value,
85 treatment_value=estimate.treatment_value,
86 target_units=estimate.estimator._target_units,
87 )
88 return new_effect.value

File ~\AppData\Local\anaconda3\Lib\site-packages\dowhy\causal_estimators\econml.py:194, in Econml.fit(self, data, effect_modifier_names, **kwargs)
190 estimator_named_args = estimator_argspec.args + estimator_argspec.kwonlyargs
191 estimator_data_args = {
192 arg: named_data_args[arg] for arg in named_data_args.keys() if arg in estimator_named_args
193 }
--> 194 self.estimator.fit(**estimator_data_args, **kwargs)
196 return self

File ~\AppData\Local\anaconda3\Lib\site-packages\econml\dml\causal_forest.py:854, in CausalForestDML.fit(self, Y, T, X, W, sample_weight, groups, cache_values, inference)
852 if X is None:
853 raise ValueError("This estimator does not support X=None!")
--> 854 return super().fit(Y, T, X=X, W=W,
855 sample_weight=sample_weight, groups=groups,
856 cache_values=cache_values,
857 inference=inference)

File ~\AppData\Local\anaconda3\Lib\site-packages\econml\dml\_rlearner.py:422, in _RLearner.fit(self, Y, T, X, W, sample_weight, freq_weight, sample_var, groups, cache_values, inference)
385 """
386 Estimate the counterfactual model from data, i.e. estimates function :math:\\theta(\\cdot).
387
(...)
419 self: _RLearner instance
420 """
421 # Replacing fit from _OrthoLearner, to enforce Z=None and improve the docstring
--> 422 return super().fit(Y, T, X=X, W=W,
423 sample_weight=sample_weight, freq_weight=freq_weight, sample_var=sample_var, groups=groups,
424 cache_values=cache_values,
425 inference=inference)

File ~\AppData\Local\anaconda3\Lib\site-packages\econml\_cate_estimator.py:131, in BaseCateEstimator._wrap_fit.<locals>.call(self, Y, T, inference, *args, **kwargs)
129 inference.prefit(self, Y, T, *args, **kwargs)
130 # call the wrapped fit method
--> 131 m(self, Y, T, *args, **kwargs)
132 self._postfit(Y, T, *args, **kwargs)
133 if inference is not None:
134 # NOTE: we call inference fit after calling the main fit method

File ~\AppData\Local\anaconda3\Lib\site-packages\econml\_ortho_learner.py:755, in _OrthoLearner.fit(self, Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups, cache_values, inference, only_final, check_input)
752 assert not (self.discrete_treatment and self.treatment_featurizer), "Treatment featurization "
753 "is not supported when treatment is discrete"
754 if check_input:
--> 755 Y, T, Z, sample_weight, freq_weight, sample_var, groups = check_input_arrays(
756 Y, T, Z, sample_weight, freq_weight, sample_var, groups)
757 X, = check_input_arrays(
758 X, force_all_finite='allow-nan' if 'X' in self._gen_allowed_missing_vars() else True)
759 W, = check_input_arrays(
760 W, force_all_finite='allow-nan' if 'W' in self._gen_allowed_missing_vars() else True)

File ~\AppData\Local\anaconda3\Lib\site-packages\econml\utilities.py:599, in check_input_arrays(validate_len, force_all_finite, dtype, *args)
597 n = m
598 else:
--> 599 assert (m == n), "Input arrays have incompatible lengths: {} and {}".format(n, m)
600 args[i] = new_arg
601 return args

AssertionError: Input arrays have incompatible lengths: 154 and 192
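
For what it's worth, the two lengths look consistent with one array being subsampled while another stays at full length (assuming my original frame has 192 rows; pandas' df.sample(frac=0.8) draws round(0.8 * 192) rows):

n_rows, frac = 192, 0.8      # assumed size of my data, and the subset_fraction I pass
print(round(n_rows * frac))  # 154, matching the shorter length in the assertion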

  2. Same setup, but with the method "add_unobserved_common_cause" and plugin_reisz=True: it fails with IndexError: index 15 is out of bounds for axis 1 with size 8, even though the same method runs with no errors under its other parameter settings.
res_unobserved_nonparam_benchmark_reisz = cf_model.refute_estimate(
                                    method_name="add_unobserved_common_cause",
                                    simulation_method = 'non-parametric-partial-R2',
                                    benchmark_common_causes = ["GDP_Millions"],
                                    effect_fraction_on_treatment = 0.2,
                                    effect_fraction_on_outcome = 0.2,
                                    plugin_reisz=True
                                    )

Error:
IndexError Traceback (most recent call last)
Cell In[608], line 3
1 #Non param with benchmark and no assumptions on underlying DGP
----> 3 res_unobserved_nonparam_benchmark_reisz = cf_model.refute_estimate(
4 method_name="add_unobserved_common_cause",
5 simulation_method = 'non-parametric-partial-R2',
6 benchmark_common_causes = ["GDP_Millions"],
7 effect_fraction_on_treatment = 0.2,
8 effect_fraction_on_outcome = 0.2,
9 plugin_reisz=True
10 )

File ~\AppData\Local\anaconda3\Lib\site-packages\econml\dowhy.py:221, in DoWhyWrapper.refute_estimate(self, method_name, **kwargs)
192 def refute_estimate(self, *, method_name, **kwargs):
193 """
194 Refute an estimated causal effect.
195
(...)
219 RefuteResult: an instance of the RefuteResult class
220 """
--> 221 return self.dowhy_.refute_estimate(
222 self.identified_estimand_, self.estimate_, method_name=method_name, **kwargs
223 )

File ~\AppData\Local\anaconda3\Lib\site-packages\dowhy\causal_model.py:459, in CausalModel.refute_estimate(self, estimand, estimate, method_name, show_progress_bar, **kwargs)
456 refuter_class = causal_refuters.get_class_object(method_name)
458 refuter = refuter_class(self._data, identified_estimand=estimand, estimate=estimate, **kwargs)
--> 459 res = refuter.refute_estimate(show_progress_bar)
460 return res

File ~\AppData\Local\anaconda3\Lib\site-packages\dowhy\causal_refuters\add_unobserved_common_cause.py:147, in AddUnobservedCommonCause.refute_estimate(self, show_progress_bar)
134 return sensitivity_linear_partial_r2(
135 self._data,
136 self._estimate,
(...)
144 self.plot_estimate,
145 )
146 elif self.simulation_method == "non-parametric-partial-R2":
--> 147 return sensitivity_non_parametric_partial_r2(
148 self._estimate,
149 self.kappa_t,
150 self.kappa_y,
151 self.frac_strength_treatment,
152 self.frac_strength_outcome,
153 self.benchmark_common_causes,
154 self.plot_estimate,
155 self.alpha_s_estimator_list,
156 self.alpha_s_estimator_param_list,
157 self.g_s_estimator_list,
158 self.g_s_estimator_param_list,
159 self.plugin_reisz,
160 )
161 elif self.simulation_method == "e-value":
162 return sensitivity_e_value(
163 self._data,
164 self._target_estimand,
(...)
168 self.plot_estimate,
169 )

File ~\AppData\Local\anaconda3\Lib\site-packages\dowhy\causal_refuters\add_unobserved_common_cause.py:737, in sensitivity_non_parametric_partial_r2(estimate, kappa_t, kappa_y, frac_strength_treatment, frac_strength_outcome, benchmark_common_causes, plot_estimate, alpha_s_estimator_list, alpha_s_estimator_param_list, g_s_estimator_list, g_s_estimator_param_list, plugin_reisz)
718 return analyzer
720 analyzer = NonParametricSensitivityAnalyzer(
721 estimator=estimate.estimator,
722 observed_common_causes=estimate.estimator._observed_common_causes,
(...)
735 plugin_reisz=plugin_reisz,
736 )
--> 737 analyzer.check_sensitivity(plot=plot_estimate)
738 return analyzer

File ~\AppData\Local\anaconda3\Lib\site-packages\dowhy\causal_refuters\non_parametric_sensitivity_analyzer.py:128, in NonParametricSensitivityAnalyzer.check_sensitivity(self, plot)
126 for train, test in split_indices:
127 reisz_fn_fit = reisz_function.fit(X[train])
--> 128 self.alpha_s[test] = reisz_fn_fit.predict(X[test])
129 if self.plugin_reisz:
130 propensities[test] = reisz_fn_fit.propensity(X[test])

File ~\AppData\Local\anaconda3\Lib\site-packages\dowhy\causal_refuters\reisz.py:114, in PluginReisz.predict(self, X)
112 t = X[:, 0]
113 preds = self.propmodel.predict_proba(W)
--> 114 weights = [1 / preds[i, t[i].astype(int)] for i in range(preds.shape[0])]
115 weights = np.where(t == 0, -1, 1) * weights
116 return weights

File ~\AppData\Local\anaconda3\Lib\site-packages\dowhy\causal_refuters\reisz.py:114, in <listcomp>(.0)
112 t = X[:, 0]
113 preds = self.propmodel.predict_proba(W)
--> 114 weights = [1 / preds[i, t[i].astype(int)] for i in range(preds.shape[0])]
115 weights = np.where(t == 0, -1, 1) * weights
116 return weights

IndexError: index 15 is out of bounds for axis 1 with size 8
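
A tiny standalone snippet (synthetic numbers, not my data) reproduces the same failure mode: reisz.py builds the weights by indexing the predict_proba columns with the treatment value cast to int, which only makes sense for a discrete treatment encoded as 0..k-1; a continuous treatment value like 15.2 then indexes a column that does not exist:

import numpy as np

rng = np.random.default_rng(0)
preds = rng.dirichlet(np.ones(8), size=5)  # propensity output for 8 "classes"
t = np.array([0.3, 2.7, 15.2, 1.1, 7.9])   # continuous treatment values
try:
    weights = [1 / preds[i, int(t[i])] for i in range(preds.shape[0])]
except IndexError as e:
    print(e)  # index 15 is out of bounds for axis 1 with size 8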

  3. For the plot and the robustness value of the sensitivity analysis, do they automatically adjust when the upper bound should be used instead of the lower bound? In the following situation it makes more sense to use the upper bound, as the estimate is a negative value. However, when I specify plot_type="upper_confidence_bound" within the refute_estimate method, the plot comes out empty, unless I call the plotting explicitly outside of the method (as below). This makes me wonder whether the robustness value printed by the method depends on whether the lower or the upper bound is used, because the reported value is 0.
res_unobserved_nonparam_benchmark = cf_model.refute_estimate(
                                    method_name="add_unobserved_common_cause",
                                    simulation_method = 'non-parametric-partial-R2',
                                    benchmark_common_causes = ["GDP_Millions"],
                                    effect_fraction_on_treatment = 0.2,
                                    effect_fraction_on_outcome = 0.2)

res_unobserved_nonparam_benchmark.plot(plot_type = "upper_confidence_bound")

[attached image: sensitivity analysis plot]
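
To make concrete why I think the upper bound is the relevant one here, a toy calculation (theta_hat = -0.5 is made up): for a negative point estimate, only the upper end of the bias interval can reach zero, so a robustness value tied to the lower bound would not be meaningful:

theta_hat = -0.5  # made-up negative point estimate
for bias in (0.0, 0.25, 0.5, 0.75):
    lower, upper = theta_hat - bias, theta_hat + bias
    print(f"bias={bias:.2f}: bounds ({lower:+.2f}, {upper:+.2f}); "
          f"upper bound reaches 0: {upper >= 0}")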

  1. For "assess_support_and_overlap_overrule" method to evaluate overlap and support, I run into the error " ".
refute_experiment = assess_support_and_overlap_overrule(
    data=df_processed,
    backdoor_vars=df_processed.iloc[:, 4:].columns.tolist(),
    treatment_name="PharmacyDensity_Per100k",
    support_config=support_config,
    overlap_config=overlap_config,
)

print(refute_experiment)

Error:
ValueError Traceback (most recent call last)
Cell In[614], line 3
1 #As a preprocessing step instead of refutation
----> 3 refute_experiment = assess_support_and_overlap_overrule(
4 data=df_processed,
5 backdoor_vars=df_processed.iloc[:, 4:].columns.tolist(),
6 treatment_name="PharmacyDensity_Per100k",
7 support_config=support_config,
8 overlap_config=overlap_config,
9 )
11 # Observe how everyone is in the overlap set, so we do not learn any overlap rules
12 print(refute_experiment)

File ~\AppData\Local\anaconda3\Lib\site-packages\dowhy\causal_refuters\assess_overlap.py:125, in assess_support_and_overlap_overrule(data, backdoor_vars, treatment_name, cat_feats, overlap_config, support_config, overlap_eps, support_only, overlap_only, verbose)
93 """
94 Learn support and overlap rules using OverRule.
95
(...)
112 :param: verbose: bool: Enable verbose logging of optimization output, defaults to False
113 """
114 analyzer = OverruleAnalyzer(
115 backdoor_vars=backdoor_vars,
116 treatment_name=treatment_name,
(...)
123 verbose=verbose,
124 )
--> 125 analyzer.fit(data)
126 return analyzer

File ~\AppData\Local\anaconda3\Lib\site-packages\dowhy\causal_refuters\assess_overlap_overrule.py:209, in OverruleAnalyzer.fit(self, data)
206 self.overlap_indicator = supp
207 else:
208 # Assess overlap using propensity scores with cross-fitting
--> 209 self.raw_overlap_set = self._assess_propensity_overlap(X_supp, t_supp)
210 # Check if all supported units are considered to be in the overlap set
211 if np.all(self.raw_overlap_set):

File ~\AppData\Local\anaconda3\Lib\site-packages\dowhy\causal_refuters\assess_overlap_overrule.py:231, in OverruleAnalyzer._assess_propensity_overlap(self, X, t)
230 def _assess_propensity_overlap(self, X, t):
--> 231 prop_scores = cross_val_predict(self.prop_estimator, X, t.values.ravel(), method="predict_proba", cv=2)
232 prop_scores = prop_scores[:, 1] # Probability of treatment
233 overlap_set = np.logical_and(prop_scores < 1 - self.overlap_eps, prop_scores > self.overlap_eps).astype(int)

File ~\AppData\Local\anaconda3\Lib\site-packages\sklearn\model_selection\_validation.py:1036, in cross_val_predict(estimator, X, y, groups, cv, n_jobs, verbose, fit_params, pre_dispatch, method)
1033 # We clone the estimator to make sure that all the folds are
1034 # independent, and that it is pickle-able.
1035 parallel = Parallel(n_jobs=n_jobs, verbose=verbose, pre_dispatch=pre_dispatch)
-> 1036 predictions = parallel(
1037 delayed(_fit_and_predict)(
1038 clone(estimator), X, y, train, test, verbose, fit_params, method
1039 )
1040 for train, test in splits
1041 )
1043 inv_test_indices = np.empty(len(test_indices), dtype=int)
1044 inv_test_indices[test_indices] = np.arange(len(test_indices))

File ~\AppData\Local\anaconda3\Lib\site-packages\sklearn\utils\parallel.py:65, in Parallel.__call__(self, iterable)
60 config = get_config()
61 iterable_with_config = (
62 (_with_config(delayed_func, config), args, kwargs)
63 for delayed_func, args, kwargs in iterable
64 )
---> 65 return super().__call__(iterable_with_config)

File ~\AppData\Local\anaconda3\Lib\site-packages\joblib\parallel.py:1085, in Parallel.__call__(self, iterable)
1076 try:
1077 # Only set self._iterating to True if at least a batch
1078 # was dispatched. In particular this covers the edge
(...)
1082 # was very quick and its callback already dispatched all the
1083 # remaining jobs.
1084 self._iterating = False
-> 1085 if self.dispatch_one_batch(iterator):
1086 self._iterating = self._original_iterator is not None
1088 while self.dispatch_one_batch(iterator):

File ~\AppData\Local\anaconda3\Lib\site-packages\joblib\parallel.py:901, in Parallel.dispatch_one_batch(self, iterator)
899 return False
900 else:
--> 901 self._dispatch(tasks)
902 return True

File ~\AppData\Local\anaconda3\Lib\site-packages\joblib\parallel.py:819, in Parallel._dispatch(self, batch)
817 with self._lock:
818 job_idx = len(self._jobs)
--> 819 job = self._backend.apply_async(batch, callback=cb)
820 # A job can complete so quickly than its callback is
821 # called before we get here, causing self._jobs to
822 # grow. To ensure correct results ordering, .insert is
823 # used (rather than .append) in the following line
824 self._jobs.insert(job_idx, job)

File ~\AppData\Local\anaconda3\Lib\site-packages\joblib\_parallel_backends.py:208, in SequentialBackend.apply_async(self, func, callback)
206 def apply_async(self, func, callback=None):
207 """Schedule a func to be run"""
--> 208 result = ImmediateResult(func)
209 if callback:
210 callback(result)

File ~\AppData\Local\anaconda3\Lib\site-packages\joblib\_parallel_backends.py:597, in ImmediateResult.__init__(self, batch)
594 def __init__(self, batch):
595 # Don't delay the application, to avoid keeping the input
596 # arguments in memory
--> 597 self.results = batch()

File ~\AppData\Local\anaconda3\Lib\site-packages\joblib\parallel.py:288, in BatchedCalls.__call__(self)
284 def __call__(self):
285 # Set the default nested backend to self._backend but do not set the
286 # change the default number of processes to -1
287 with parallel_backend(self._backend, n_jobs=self._n_jobs):
--> 288 return [func(*args, **kwargs)
289 for func, args, kwargs in self.items]

File ~\AppData\Local\anaconda3\Lib\site-packages\joblib\parallel.py:288, in <listcomp>(.0)
284 def __call__(self):
285 # Set the default nested backend to self._backend but do not set the
286 # change the default number of processes to -1
287 with parallel_backend(self._backend, n_jobs=self._n_jobs):
--> 288 return [func(*args, **kwargs)
289 for func, args, kwargs in self.items]

File ~\AppData\Local\anaconda3\Lib\site-packages\sklearn\utils\parallel.py:127, in _FuncWrapper.__call__(self, *args, **kwargs)
125 config = {}
126 with config_context(**config):
--> 127 return self.function(*args, **kwargs)

File ~\AppData\Local\anaconda3\Lib\site-packages\sklearn\model_selection\_validation.py:1118, in _fit_and_predict(estimator, X, y, train, test, verbose, fit_params, method)
1116 estimator.fit(X_train, **fit_params)
1117 else:
-> 1118 estimator.fit(X_train, y_train, **fit_params)
1119 func = getattr(estimator, method)
1120 predictions = func(X_test)

File ~\AppData\Local\anaconda3\Lib\site-packages\sklearn\base.py:1151, in _fit_context.<locals>.decorator.<locals>.wrapper(estimator, *args, **kwargs)
1144 estimator._validate_params()
1146 with config_context(
1147 skip_parameter_validation=(
1148 prefer_skip_nested_validation or global_skip_validation
1149 )
1150 ):
-> 1151 return fit_method(estimator, *args, **kwargs)

File ~\AppData\Local\anaconda3\Lib\site-packages\sklearn\model_selection\_search.py:898, in BaseSearchCV.fit(self, X, y, groups, **fit_params)
892 results = self._format_results(
893 all_candidate_params, n_splits, all_out, all_more_results
894 )
896 return results
--> 898 self._run_search(evaluate_candidates)
900 # multimetric is determined here because in the case of a callable
901 # self.scoring the return type is only known after calling
902 first_test_score = all_out[0]["test_scores"]

File ~\AppData\Local\anaconda3\Lib\site-packages\sklearn\model_selection\_search.py:1419, in GridSearchCV._run_search(self, evaluate_candidates)
1417 def _run_search(self, evaluate_candidates):
1418 """Search all candidates in param_grid"""
-> 1419 evaluate_candidates(ParameterGrid(self.param_grid))

File ~\AppData\Local\anaconda3\Lib\site-packages\sklearn\model_selection\_search.py:857, in BaseSearchCV.fit.<locals>.evaluate_candidates(candidate_params, cv, more_results)
837 if self.verbose > 0:
838 print(
839 "Fitting {0} folds for each of {1} candidates,"
840 " totalling {2} fits".format(
841 n_splits, n_candidates, n_candidates * n_splits
842 )
843 )
845 out = parallel(
846 delayed(_fit_and_score)(
847 clone(base_estimator),
848 X,
849 y,
850 train=train,
851 test=test,
852 parameters=parameters,
853 split_progress=(split_idx, n_splits),
854 candidate_progress=(cand_idx, n_candidates),
855 **fit_and_score_kwargs,
856 )
--> 857 for (cand_idx, parameters), (split_idx, (train, test)) in product(
858 enumerate(candidate_params), enumerate(cv.split(X, y, groups))
859 )
860 )
862 if len(out) < 1:
863 raise ValueError(
864 "No fits were performed. "
865 "Was the CV iterator empty? "
866 "Were there no candidates?"
867 )

File ~\AppData\Local\anaconda3\Lib\site-packages\sklearn\model_selection\_split.py:377, in _BaseKFold.split(self, X, y, groups)
369 if self.n_splits > n_samples:
370 raise ValueError(
371 (
372 "Cannot have number of splits n_splits={0} greater"
373 " than the number of samples: n_samples={1}."
374 ).format(self.n_splits, n_samples)
375 )
--> 377 for train, test in super().split(X, y, groups):
378 yield train, test

File ~\AppData\Local\anaconda3\Lib\site-packages\sklearn\model_selection\_split.py:108, in BaseCrossValidator.split(self, X, y, groups)
106 X, y, groups = indexable(X, y, groups)
107 indices = np.arange(_num_samples(X))
--> 108 for test_index in self._iter_test_masks(X, y, groups):
109 train_index = indices[np.logical_not(test_index)]
110 test_index = indices[test_index]

File ~\AppData\Local\anaconda3\Lib\site-packages\sklearn\model_selection\_split.py:758, in StratifiedKFold._iter_test_masks(self, X, y, groups)
757 def _iter_test_masks(self, X, y=None, groups=None):
--> 758 test_folds = self._make_test_folds(X, y)
759 for i in range(self.n_splits):
760 yield test_folds == i

File ~\AppData\Local\anaconda3\Lib\site-packages\sklearn\model_selection\_split.py:720, in StratifiedKFold._make_test_folds(self, X, y)
718 min_groups = np.min(y_counts)
719 if np.all(self.n_splits > y_counts):
--> 720 raise ValueError(
721 "n_splits=%d cannot be greater than the"
722 " number of members in each class." % (self.n_splits)
723 )
724 if self.n_splits > min_groups:
725 warnings.warn(
726 "The least populated class in y has only %d"
727 " members, which is less than n_splits=%d."
728 % (min_groups, self.n_splits),
729 UserWarning,
730 )

ValueError: n_splits=5 cannot be greater than the number of members in each class.
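
A small standalone snippet (synthetic data) reproduces the same error: the overlap check runs a stratified cross-validation on the treatment labels, and once each treatment "class" has fewer members than n_splits, StratifiedKFold raises exactly this ValueError. With a continuous treatment, almost every value is its own class:

import numpy as np
from sklearn.model_selection import StratifiedKFold

y = np.repeat(np.arange(5), 2)  # 5 "classes" with 2 members each (< n_splits=5)
X = np.zeros((10, 2))
try:
    list(StratifiedKFold(n_splits=5).split(X, y))
except ValueError as e:
    print(e)  # n_splits=5 cannot be greater than the number of members in each class.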

@benTC74 benTC74 added the question Further information is requested label May 15, 2024
github-actions bot commented May 30, 2024

This issue is stale because it has been open for 14 days with no activity.

@github-actions github-actions bot added the stale label May 30, 2024

github-actions bot commented Jun 7, 2024

This issue was closed because it has been inactive for 7 days since being marked as stale.

@github-actions github-actions bot closed this as not planned Jun 7, 2024