In [1]:
import pandas as pd

# Load the pre-split feature matrices and label series, in the order
# X_train, X_test, y_train, y_test.
# NOTE: pd.read_pickle unpickles arbitrary Python objects -- only point it at
# files you created or otherwise trust.
X_train, X_test, y_train, y_test = (
    pd.read_pickle(split_path)
    for split_path in (
        'wine_X_train.pkl',
        'wine_X_test.pkl',
        'wine_y_train.pkl',
        'wine_y_test.pkl',
    )
)

### Randomized hyperparameter search over a soft-voting ensemble (XGBoost + random forest)

In [2]:
import numpy as np
from sklearn.model_selection import StratifiedKFold, RandomizedSearchCV
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from xgboost import XGBClassifier


def random_search(cv=None, n_iter=1, scoring=None, refit_metric=True, random_state=None):
    """Build an unfitted randomized hyperparameter search over a soft-voting ensemble.

    The ensemble is a ``VotingClassifier`` with ``voting="soft"`` over an
    ``XGBClassifier`` (step name ``"xgboost"``) and a ``RandomForestClassifier``
    (step name ``"rfc"``). Hyperparameters of both members are sampled via the
    ``<step>__<param>`` names that ``VotingClassifier`` exposes.

    Parameters
    ----------
    cv : int, CV splitter, iterable or None, default=None
        Passed to ``RandomizedSearchCV(cv=...)``. ``None`` means a fresh
        ``StratifiedKFold(2)`` per call (the previous default).
    n_iter : int, default=1
        Number of parameter settings sampled.
    scoring : str, callable, list, dict or None, default=None
        Passed to ``RandomizedSearchCV(scoring=...)``; ``None`` means ``"accuracy"``.
    refit_metric : bool or str, default=True
        Passed as ``refit``. With multi-metric ``scoring`` this must name the
        metric used to select ``best_estimator_``.
    random_state : int or None, default=None
        Seeds the parameter sampling and both ensemble members so a re-run
        reproduces the same search. ``None`` keeps the previous unseeded behaviour.

    Returns
    -------
    RandomizedSearchCV
        The configured search; call ``.fit(X, y)`` on it.
    """
    if cv is None:
        # Build the splitter per call rather than sharing one instance that was
        # created once, when the function was defined.
        cv = StratifiedKFold(2)
    if scoring is None:
        scoring = "accuracy"

    param_distributions = {
        "xgboost__n_estimators": np.arange(100, 1000, 100),
        # 0.01 .. 0.09: linspace avoids np.arange's unreliable float-step length,
        # and rounding strips floating-point noise from the sampled values.
        "xgboost__learning_rate": np.round(np.linspace(0.01, 0.09, 9), 2),
        "xgboost__max_depth": np.arange(3, 10, 1),
        "rfc__n_estimators": np.arange(100, 1000, 100),
        "rfc__max_depth": np.arange(3, 10, 1),
    }

    ensemble = VotingClassifier(
        estimators=[
            ("xgboost", XGBClassifier(random_state=random_state)),
            ("rfc", RandomForestClassifier(random_state=random_state)),
        ],
        voting="soft",
    )

    return RandomizedSearchCV(
        estimator=ensemble,
        param_distributions=param_distributions,
        cv=cv,
        n_iter=n_iter,
        verbose=0,
        scoring=scoring,
        refit=refit_metric,
        random_state=random_state,
    )


# Keep the factory callable under its own name: the next line rebinds
# `random_search` to the search object that the following cells use, which
# would otherwise make the factory unreachable for building another search.
make_random_search = random_search
random_search = make_random_search()

In [3]:
# Show the configured, not-yet-fitted search object (the cell's last
# expression, so Jupyter renders its repr).
random_search

In [4]:
# Run the search on the training split. `fit` returns the search itself, so the
# refitted winning ensemble can be read directly off the result (this needs
# refit enabled, as it is with random_search()'s default refit_metric=True).
best_estimator = random_search.fit(X_train, y_train).best_estimator_

### Test-set performance report and analysis

In [5]:
from pathlib import Path

from sklearn.metrics import classification_report
import joblib

MODEL_PATH = Path("best_model.joblib")

# Reuse a previously saved model if one exists; otherwise persist the model
# fitted above, so a fresh "Restart & Run All" does not fail on a missing file.
# NOTE(review): when the file already exists, the report below describes that
# saved model, not necessarily this run's `best_estimator`.
if not MODEL_PATH.exists():
    joblib.dump(best_estimator, MODEL_PATH)

# joblib.load unpickles arbitrary objects -- only load model files you trust.
loaded_model = joblib.load(MODEL_PATH)

# Predict on the held-out split, keeping y_test's index so labels line up.
y_pred = pd.Series(loaded_model.predict(X_test), index=y_test.index)

# Per-class precision / recall / F1 report. zero_division=0 makes explicit how
# classes that are never predicted (undefined precision) are scored -- the same
# 0.00 the default reports -- without emitting UndefinedMetricWarning.
report = classification_report(y_test, y_pred, zero_division=0)
print(report)

              precision    recall  f1-score   support

           3       0.00      0.00      0.00         3
           4       0.00      0.00      0.00         9
           5       0.69      0.72      0.71       138
           6       0.56      0.68      0.62       119
           7       0.79      0.49      0.61        47
           8       0.00      0.00      0.00         4

    accuracy                           0.64       320
   macro avg       0.34      0.32      0.32       320
weighted avg       0.63      0.64      0.62       320



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
