
Using make_scorer() for a GridSearchCV scoring parameter in a clustering task #17631

Closed
imanirajian opened this issue Jun 18, 2020 · 9 comments


@imanirajian

imanirajian commented Jun 18, 2020

* Workflow:

1- Consider the make_scorer() call below for a clustering metric:

from sklearn.metrics import homogeneity_score, make_scorer

def score_func(y_true, y_pred, **kwargs):
    return homogeneity_score(y_true, y_pred)
scorer = make_scorer(score_func)

2- Consider the simple function optics():

# "optics" algorithm for clustering
# ---
import numpy as np
from sklearn.cluster import OPTICS
from sklearn.model_selection import GridSearchCV

def optics(data, labels):
    # data: A dataframe with two columns (x, y)
    base_opt = OPTICS()
    # min_samples must be an int > 1 or a float between 0 and 1
    grid_search_params = {"min_samples": np.arange(2, 10),
                          "metric": ["cityblock", "cosine", "euclidean", "l1", "l2", "manhattan"],
                          "cluster_method": ["xi", "dbscan"],
                          "algorithm": ["auto", "ball_tree", "kd_tree", "brute"]}

    grid_search_cv = GridSearchCV(estimator=base_opt,
                                  param_grid=grid_search_params,
                                  scoring=scorer)  # scorer from step 1

    grid_search_cv.fit(data)  # no labels passed, so the scorer receives no y_true
    # best_estimator_ is already refit on the full data (refit=True by default)
    opt = grid_search_cv.best_estimator_
    preds = opt.labels_

    # return clusters corresponding to (x, y) pairs according to "optics" algorithm
    return preds

Running optics() raised this error:
TypeError: _score() missing 1 required positional argument: 'y_true'

Even when using grid_search_cv.fit(data, labels) instead of grid_search_cv.fit(data), another exception was raised:
AttributeError: 'OPTICS' object has no attribute 'predict'

I think we cannot use make_scorer() with GridSearchCV for a clustering task.
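A minimal sketch of what the scorer itself does, assuming a recent scikit-learn and toy data from make_blobs (not the reporter's data): an object built by make_scorer is called as scorer(estimator, X, y) and calls estimator.predict(X) before applying the metric, so it works on a fitted KMeans but raises AttributeError on OPTICS, which only exposes labels_ and fit_predict().

```python
from sklearn.cluster import KMeans, OPTICS
from sklearn.datasets import make_blobs
from sklearn.metrics import homogeneity_score, make_scorer

scorer = make_scorer(homogeneity_score)  # response method defaults to predict
X, y = make_blobs(n_samples=100, centers=3, random_state=0)

# KMeans has predict(), so the scorer works on a fitted KMeans
km = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X)
km_score = scorer(km, X, y)  # homogeneity_score(y, km.predict(X))
print(km_score)

# OPTICS has no predict(), so the same scorer fails
opt = OPTICS().fit(X)
opt_error = None
try:
    scorer(opt, X, y)
except AttributeError as exc:
    opt_error = exc
print(type(opt_error).__name__, opt_error)
```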


* Proposed solution:

GridSearchCV.fit() should automatically handle the type of the estimator passed to its constructor; for example, for a clustering estimator it should use labels_ instead of predict() for scoring.

@amueller
Member

There are maybe 2 or 3 issues here; let me try to unpack them:

  • You usually cannot use homogeneity_score for evaluating clustering because it requires ground truth, which you usually don't have for clustering (this is the missing y_true issue).
  • If you actually have ground truth, current GridSearchCV doesn't really allow evaluating on the training set, as it uses cross-validation. You could probably hack the CV splitter to use the full data both as training and test set to sort-of get around this, but it's a bit ugly.
  • By default make_scorer uses predict, which OPTICS doesn't have. So indeed that could be seen as a limitation of make_scorer but it's not really the core issue. You could provide a custom callable that calls fit_predict.
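A sketch of the two workarounds from the list above, assuming toy make_blobs data and a hypothetical scorer name fit_predict_homogeneity: GridSearchCV accepts any callable with the signature (estimator, X, y) as scoring, and its cv parameter accepts an iterable of (train, test) index pairs, so a single split that uses the full data for both sides gives the "evaluate on the training set" behaviour.

```python
import numpy as np
from sklearn.cluster import OPTICS
from sklearn.datasets import make_blobs
from sklearn.metrics import homogeneity_score
from sklearn.model_selection import GridSearchCV


def fit_predict_homogeneity(estimator, X, y):
    """Hypothetical scorer with the (estimator, X, y) signature GridSearchCV accepts.

    OPTICS has no predict(), so this re-clusters X with fit_predict() and
    compares the resulting labels with the ground truth y.
    """
    labels_pred = estimator.fit_predict(X)
    return homogeneity_score(y, labels_pred)


X, y = make_blobs(n_samples=150, centers=3, random_state=0)

# One "split" whose training and test sets are both the full data set
# (the CV-splitter hack: there is no held-out data at all).
all_idx = np.arange(len(X))
full_data_split = [(all_idx, all_idx)]

search = GridSearchCV(
    OPTICS(),
    param_grid={"min_samples": [5, 10, 20], "xi": [0.05, 0.1]},
    scoring=fit_predict_homogeneity,
    cv=full_data_split,
)
search.fit(X, y)
print(search.best_params_, search.best_score_)
```

This only makes sense when ground-truth labels exist; the fit done by GridSearchCV on the "training" side is redone by the scorer.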

(meeting now I'll update with related issues afterwards)

@imanirajian
Author

@amueller

  • I've tried all clustering metrics from sklearn.metrics.
  • It should work in either case, with or without ground truth.
  • This is a limitation, and either make_scorer() or GridSearchCV.fit() should handle this scenario.

@adrinjalali
Member

As @amueller mentioned, having the scorer call fit_predict is probably not what you want, since it would ignore your training set. So an algorithm such as OPTICS may not be a good example for this use case.

Consider this code:

# %%
from sklearn.metrics import homogeneity_score, make_scorer

def score_func(y_true, y_pred, **kwargs):
    return homogeneity_score(y_true, y_pred)
scorer = make_scorer(score_func)

# %%
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.cluster import OPTICS
from sklearn.datasets import make_classification

X, y = make_classification()

base_opt = OPTICS()
grid_search_params = {"min_samples":np.arange(10),
                      "metric":["cityblock", "cosine", "euclidean", "l1", "l2", "manhattan"],
                      "cluster_method":["xi", "dbscan"],
                      "algorithm":["auto", "ball_tree", "kd_tree", "brute"]}

grid_search_cv = GridSearchCV(estimator=base_opt,
                              param_grid=grid_search_params,
                              scoring=scorer)

grid_search_cv.fit(X, y)

It'll raise:

AttributeError: 'OPTICS' object has no attribute 'predict'

which is very sensible, since predict is not really defined for OPTICS. Now if you replace it with KMeans:

from sklearn.cluster import KMeans

base_opt = KMeans()
grid_search_params = {"n_clusters": np.arange(2, 10)}  # 0 is not a valid n_clusters

grid_search_cv = GridSearchCV(estimator=base_opt,
                              param_grid=grid_search_params,
                              scoring=scorer)

grid_search_cv.fit(X, y)

it works fine, since predict is well-defined for KMeans.

Now in case we don't have the labels, we could have something like:

import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.cluster import OPTICS
from sklearn.datasets import make_classification
from sklearn.metrics import silhouette_score, make_scorer
scorer = make_scorer(silhouette_score)
X, y = make_classification()

base_opt = OPTICS()
grid_search_params = {"min_samples":np.arange(10),
                      "metric":["cityblock", "cosine", "euclidean", "l1", "l2", "manhattan"],
                      "cluster_method":["xi", "dbscan"],
                      "algorithm":["auto", "ball_tree", "kd_tree", "brute"]}

grid_search_cv = GridSearchCV(estimator=base_opt,
                              param_grid=grid_search_params,
                              scoring=scorer)

grid_search_cv.fit(X)

This raises:

TypeError: _score() missing 1 required positional argument: 'y_true'

I think we should either support this case, or raise a more informative error. WDYT @amueller ?
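For the no-labels case, one option (a sketch, not an official API) is a plain callable scorer whose y argument defaults to None: when fit() is called without y, GridSearchCV calls the scorer as scorer(estimator, X). The name silhouette_scorer, the single full-data split, and returning -1.0 for degenerate clusterings are assumptions made for illustration. Note that silhouette_score treats the OPTICS noise label -1 as one more cluster.

```python
import numpy as np
from sklearn.cluster import OPTICS
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score
from sklearn.model_selection import GridSearchCV


def silhouette_scorer(estimator, X, y=None):
    """Hypothetical label-free scorer; y is ignored and may be None."""
    labels = estimator.fit_predict(X)
    n_labels = len(np.unique(labels))
    # silhouette_score needs 2 <= n_labels <= n_samples - 1; treat the rest as worst
    if not 1 < n_labels < len(X):
        return -1.0
    return silhouette_score(X, labels)


X, _ = make_blobs(n_samples=150, centers=3, random_state=0)
all_idx = np.arange(len(X))

search = GridSearchCV(
    OPTICS(),
    param_grid={"min_samples": [5, 10, 20], "metric": ["euclidean", "manhattan"]},
    scoring=silhouette_scorer,
    cv=[(all_idx, all_idx)],  # single split: train and test are both the full data
)
search.fit(X)  # no y, so the scorer is called as silhouette_scorer(estimator, X)
print(search.best_params_, search.best_score_)
```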

@imanirajian
Author

@adrinjalali @amueller
The same issue holds for DBSCAN.
I think GridSearchCV() should support clustering estimators as well.

So far, I've written the code below:

import itertools

import numpy as np
from sklearn.cluster import OPTICS
from sklearn.metrics import homogeneity_score, silhouette_score
from sklearn.model_selection import KFold

def score_func(X, y_true, y_pred, **kwargs):
    s1 = homogeneity_score(y_true, y_pred)
    s2 = 0
    # silhouette_score is only defined for 2 <= n_labels <= n_samples - 1
    if 1 < len(set(y_pred)) < len(y_pred):
        s2 = silhouette_score(X, y_pred)
    return s1 + s2

# w.r.t. score_func(), greater is better
def compare_scores(s, t):
    return t > s

# w.r.t. compare_scores(), the initial (worst possible) score
def init_score():
    return -np.inf

# ----------------------------------

def custom_grid_search_cv(X, y_true, estimator, params, n_splits, shuffle, scoring):
    X = np.asarray(X)                    # (x, y) coordinates
    y_true = np.asarray(y_true).ravel()  # ground-truth cluster labels

    best_score = init_score()
    best_params = None

    # Grid search over every combination of parameter values
    for values in itertools.product(*params.values()):
        p = dict(zip(params, values))
        try:
            estimator.set_params(**p)
            current_score = 0
            # KFold CV (shuffle must be passed by keyword in recent scikit-learn)
            k_fold = KFold(n_splits=n_splits, shuffle=shuffle)
            # Training indices are unused: each test fold is clustered and scored on its own
            for _, test_indices in k_fold.split(X):
                X_test, y_test = X[test_indices], y_true[test_indices]
                y_pred_test = estimator.fit_predict(X_test)
                current_score += scoring(X_test, y_test, y_pred_test)
            current_score /= n_splits
        except Exception:
            # Invalid combination (e.g. metric="cosine" with algorithm="kd_tree"): skip it
            continue

        if compare_scores(best_score, current_score):
            best_params = p
            best_score = current_score

    if best_params is None:
        raise ValueError("No valid parameter combination in the grid.")
    estimator.set_params(**best_params)

    return estimator

# ----------------------------------

# "optics" algorithm for clustering
# ---
def optics(data, labels):
    # data: A dataframe with two columns (x, y); labels: ground-truth clusters
    base_opt = OPTICS()
    # min_samples must be an int > 1 or a float between 0 and 1
    grid_search_params = {"min_samples": list(np.arange(2, 10)) + [0.1, 0.2, 0.4, 0.5, 0.6, 0.8],
                          "metric": ["cityblock", "cosine", "euclidean", "l1", "l2", "manhattan"],
                          "cluster_method": ["xi", "dbscan"],
                          "algorithm": ["auto", "ball_tree", "kd_tree", "brute"]}

    opt = custom_grid_search_cv(X=data,
                                y_true=labels,
                                estimator=base_opt,
                                params=grid_search_params,
                                n_splits=3,
                                shuffle=True,
                                scoring=score_func)
    #opt = base_opt # uncomment this if you don't want the grid search
    opt.fit(data)

    # return clusters corresponding to (x, y) pairs according to "optics" algorithm
    return opt.labels_

Thanks.

@amueller
Member

amueller commented Jul 11, 2020

I think GridSearchCV() should support clustering estimators as well.

You could do what you're doing in your code with GridSearchCV by using a custom splitter and custom scorer. But tbh I think that's a very strange thing to do. What is the motivation of using cross-validation in this setting?
And the way you define the training and test scores is confusing, if not wrong. There is no notion of a training and test set in your code.

The main question is "What do you want to do" and I don't see an answer to that in your post. Saying "GridSearchCV should support clustering estimators as well." is not really a meaningful statement unless you say what you'd expect it to do.
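The custom splitter plus custom scorer mentioned above could look roughly like this sketch (toy data, a reduced grid, and a hypothetical scorer name fold_scorer). It mirrors custom_grid_search_cv: KFold provides the splits and the scorer re-clusters each test fold with fit_predict, adding homogeneity and silhouette as score_func does. GridSearchCV still fits each candidate on the training fold first, and that fit is simply thrown away, which illustrates the point that the training set plays no role.

```python
import numpy as np
from sklearn.cluster import OPTICS
from sklearn.datasets import make_blobs
from sklearn.metrics import homogeneity_score, silhouette_score
from sklearn.model_selection import GridSearchCV, KFold


def fold_scorer(estimator, X, y):
    """Hypothetical scorer mirroring score_func: cluster the test fold, then score it."""
    y_pred = estimator.fit_predict(X)  # overwrites the fit on the training fold
    score = homogeneity_score(y, y_pred)
    # silhouette_score is only defined for 2 <= n_labels <= n_samples - 1
    if 1 < len(np.unique(y_pred)) < len(y_pred):
        score += silhouette_score(X, y_pred)
    return score


X, y = make_blobs(n_samples=180, centers=3, random_state=0)

search = GridSearchCV(
    OPTICS(),
    param_grid={"min_samples": [5, 10], "xi": [0.05, 0.1]},
    scoring=fold_scorer,
    cv=KFold(n_splits=3, shuffle=True, random_state=0),
)
search.fit(X, y)
print(search.best_params_, search.best_score_)
```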

@amueller
Copy link
Member

TypeError: _score() missing 1 required positional argument: 'y_true'

We can raise a better error message there.

AttributeError: 'OPTICS' object has no attribute 'predict'

I think that's an appropriate error message.

Btw, there is a lot of discussion here:
#4301

And @jnothman has thought about this in some depth, I think.

@imanirajian
Author

imanirajian commented Jul 12, 2020

@amueller

But tbh I think that's a very strange thing to do. What is the motivation of using cross-validation in this setting?

Motivation: Search the parameter space to find the best parameter choice for an OPTICS (or DBSCAN) model.

...what you'd expect it to do.

What do you want to do

Goal: Find the best parameters (w.r.t. the parameter grid grid_search_params) for a clustering estimator, with or without labels (in my case I have labels).

There is no notion of training and test set in your code

...
for k, (train_indices, test_indices) in enumerate(k_fold.split(data)):
    data_train = data.iloc[train_indices]
    data_test = data.iloc[test_indices]
...

And the way you define training and test score are confusing

Consider this:

def score_func(y_true, y_pred, **kwargs):
    return homogeneity_score(y_true, y_pred)

My custom_grid_search_cv logic:
~ For each possible choice of parameters p from the parameter grid:
~~ Apply p to the estimator.
~~ For i=1...K, use the i-th of the K folds (the current test set) to fit the estimator, get the estimator's labels (its predictions), and compute a clustering metric to judge the model's prediction strength on the i-th fold.
~~ Averaging the metric over all folds yields the score of p.
~~ If the score of p is better than the best score so far, store p as best_params.
~ Apply best_params to the estimator and return that estimator.

@shejmori

shejmori commented Mar 14, 2024

Take a look at this automation-machine-learning.ipynb notebook.

This is the part that may help you:

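from sklearn.base import clone
from sklearn.metrics import silhouette_score

# `model` (a GridSearchCV over a Pipeline with a 'preprocessor' step) and
# `X_train` are defined earlier in the notebook.
model.fit(X=X_train)

best_score     = float('-inf')  # silhouette_score can be negative
best_params    = {}
best_estimator = None

for params in model.cv_results_['params']:
    # clone so each candidate is a separate, unfitted pipeline
    current_model = clone(model.estimator).set_params(**params)

    # check the candidate pipeline, not the GridSearchCV object
    if hasattr(current_model, 'fit_predict'):
        predictions = current_model.fit_predict(X_train)
    else:
        current_model.fit(X_train)
        if hasattr(current_model, 'predict'):
            predictions = current_model.predict(X_train)
        else:
            predictions = current_model.steps[-1][-1].labels_

    # features as seen by the clustering step
    original = current_model.named_steps['preprocessor'].transform(X_train)

    try:
        score = silhouette_score(original, predictions)
    except ValueError:
        continue  # e.g. fewer than 2 clusters found
    if score > best_score:
        best_score     = score
        best_params    = params
        best_estimator = current_model
        extra_msg      = "Labels"
        extra_val      = best_estimator.steps[-1][-1].labels_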

@adrinjalali
Member

I think we can close this, since there doesn't seem to be a way for us to meaningfully support this, and users can rather easily write a loop to go over the candidates and find the best estimator (rather than a train/test split kind of thing, which doesn't make sense for many clustering algorithms).
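A sketch of such a loop without any train/test split, assuming silhouette_score as the label-free criterion and a hypothetical helper named best_clustering: ParameterGrid enumerates the combinations and clone gives each candidate a fresh, unfitted estimator. As above, silhouette_score counts the OPTICS noise label -1 as a cluster of its own.

```python
import numpy as np
from sklearn.base import clone
from sklearn.cluster import OPTICS
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score
from sklearn.model_selection import ParameterGrid


def best_clustering(estimator, param_grid, X):
    """Hypothetical helper: fit every candidate on all of X, keep the best silhouette."""
    best_score, best_params, best_est = -np.inf, None, None
    for params in ParameterGrid(param_grid):
        est = clone(estimator).set_params(**params)
        labels = est.fit_predict(X)
        n_labels = len(np.unique(labels))
        if not 1 < n_labels < len(X):
            continue  # silhouette_score needs 2 <= n_labels <= n_samples - 1
        score = silhouette_score(X, labels)
        if score > best_score:
            best_score, best_params, best_est = score, params, est
    return best_est, best_params, best_score


X, _ = make_blobs(n_samples=150, centers=3, cluster_std=0.5, random_state=0)
best_est, best_params, best_score = best_clustering(
    OPTICS(), {"min_samples": [5, 10, 20], "xi": [0.05, 0.1]}, X
)
print(best_params, best_score)
```

If ground truth is available, homogeneity_score (or any other supervised clustering metric) can replace silhouette_score inside the loop.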

adrinjalali closed this as not planned on Apr 10, 2024