Skip to content

Commit

Permalink
Merge pull request #3710 from pycaret/fix_groups
Browse files Browse the repository at this point in the history
fix groups bug, cuml version bug and missing installed libraries bug
  • Loading branch information
tvdboom committed Aug 30, 2023
2 parents dda9365 + 79cedce commit ea88b2d
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 9 deletions.
11 changes: 8 additions & 3 deletions pycaret/internal/preprocess/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,19 +559,24 @@ def get_prepare_estimator_for_categoricals_type(estimator, estimators_dict):
return estimator, fit_params
if isinstance(estimator, estimators_dict["lightgbm"].class_def):
return "fit_params_categorical_feature"
elif isinstance(estimator, estimators_dict["catboost"].class_def):
elif "catboost" in estimators_dict and isinstance(
estimator, estimators_dict["catboost"].class_def
):
return "params_cat_features"
elif "xgboost" in estimators_dict and isinstance(
estimator, estimators_dict["xgboost"].class_def
):
return "ordinal"
elif isinstance(
estimator,
(
estimators_dict["xgboost"].class_def,
estimators_dict["rf"].class_def,
estimators_dict["et"].class_def,
estimators_dict["dt"].class_def,
estimators_dict["ada"].class_def,
estimators_dict.get(
"gbr",
estimators_dict.get("gbc", estimators_dict["xgboost"]),
estimators_dict.get("gbc", estimators_dict["rf"]),
).class_def,
),
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1171,7 +1171,7 @@ def _create_model_with_cv(
if hasattr(cv, "n_splits"):
fold = cv.n_splits
elif hasattr(cv, "get_n_splits"):
fold = cv.get_n_splits()
fold = cv.get_n_splits(groups=groups)
else:
raise ValueError(
"The cross validation class should implement a n_splits "
Expand Down
10 changes: 5 additions & 5 deletions pycaret/internal/pycaret_experiment/tabular_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,17 +346,17 @@ def _initialize_setup(
self.logger.info(f"cuml=={cuml_version}")

try:
import cuml.common.memory_utils
import cuml.internals.memory_utils

cuml.common.memory_utils.set_global_output_type("numpy")
cuml.internals.memory_utils.set_global_output_type("numpy")
except Exception:
self.logger.exception("Couldn't set cuML global output type")

if cuml_version is None or not version.parse(cuml_version) >= version.parse(
"22.10"
"23.08"
):
message = """cuML is outdated or not found. Required version is >=22.10.
Please visit https://rapids.ai/ for installation instructions."""
message = """cuML is outdated or not found. Required version is >=23.08.
Please visit https://rapids.ai/install for installation instructions."""
if use_gpu == "force":
raise ImportError(message)
else:
Expand Down

0 comments on commit ea88b2d

Please sign in to comment.