Extend ClassifierChain to multi-output problems #9245

Open
jnothman opened this issue Jun 29, 2017 · 11 comments · May be fixed by #21942
Labels
Enhancement · Moderate (anything that requires some knowledge of conventions and best practices) · module:multioutput

Comments

@jnothman
Member

ClassifierChain currently supports multilabel classification. It should be straightforward to extend it to multi-output (as long as it only chains on predict) except for implementing ClassifierChain.{predict_proba,decision_function} which will take some care.
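
To illustrate what "chaining on predict" could look like for multi-output targets, here is a minimal sketch written outside of ClassifierChain. The helper names `fit_chain` and `predict_chain` are hypothetical and not part of scikit-learn. The sketch assumes integer-encoded labels and feeds earlier labels to later estimators as raw numeric features, which is a simplification.

```python
import numpy as np
from sklearn.base import clone
from sklearn.tree import DecisionTreeClassifier


def fit_chain(base_estimator, X, Y):
    """Fit one clone per output; output k sees X plus the true labels 0..k-1."""
    estimators = []
    for k in range(Y.shape[1]):
        X_aug = np.hstack([X, Y[:, :k]])
        estimators.append(clone(base_estimator).fit(X_aug, Y[:, k]))
    return estimators


def predict_chain(estimators, X):
    """Predict outputs in order, feeding earlier predictions to later estimators."""
    # Assumption: labels are integers, so an int array can hold the predictions.
    Y_pred = np.empty((X.shape[0], len(estimators)), dtype=int)
    for k, est in enumerate(estimators):
        X_aug = np.hstack([X, Y_pred[:, :k]])
        Y_pred[:, k] = est.predict(X_aug)
    return Y_pred


rng = np.random.RandomState(0)
X = rng.rand(30, 4)
# Two outputs: the first has 3 classes, the second has 2.
Y = np.column_stack([np.arange(30) % 3, np.arange(30) % 2])

chain = fit_chain(DecisionTreeClassifier(random_state=0), X, Y)
Y_pred = predict_chain(chain, X)
print(Y_pred.shape)
```

Because only hard predictions are passed along, each estimator can be an ordinary multiclass classifier. Combining per-output probabilities or decision values is the part that needs more care.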

@jnothman added the Enhancement, Moderate, and Need Contributor labels on Jun 29, 2017
@siebenHeaven

@jnothman I am new to contributing to this project and would like to start here. As I understand it, ClassifierChain currently predicts whether a given instance belongs to a class, then passes that prediction to the next estimator in the chain to check the other classes. So would extending it to support multi-output mean implementing a new parameter which, if set, tells it not to check for further classes?

@jnothman
Member Author

jnothman commented Jul 2, 2017 via email

@siebenHeaven

OK! So right now, the base estimator that is passed to ClassifierChain is a binary classifier. What changes would be needed to make multi-class classifiers work? Would it need a separate class, or would the changes go into the current ClassifierChain class?
Also, what dataset can this multi-output multiclass support be tested on?
(Sorry for the late reply, I had some work at university :) )

@jnothman
Member Author

No, use the same class. Basically, you just need to ensure that the predict_proba and decision_function output conform to what you get from a multi-output DecisionTreeClassifier. It doesn't look like we have any standard datasets here. You could take a look at sklearn/tree/tests/test_tree.py:test_multioutput.
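
For reference, a small sketch of the output format of a multi-output DecisionTreeClassifier, which is the format mentioned above as the one to conform to. The toy data is made up for illustration only.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.RandomState(0)
X = rng.rand(20, 4)
# Two outputs: the first has 3 classes, the second has 2.
Y = np.column_stack([np.arange(20) % 3, np.arange(20) % 2])

tree = DecisionTreeClassifier(random_state=0).fit(X, Y)

# predict returns one column per output.
pred = tree.predict(X)

# predict_proba returns a list with one array per output,
# each of shape (n_samples, n_classes_for_that_output).
proba = tree.predict_proba(X)
for k, p in enumerate(proba):
    print(k, p.shape, tree.classes_[k])
```

The list form is needed because each output can have a different number of classes, so the per-output arrays cannot always be stacked into a single rectangular array.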

@Johayon
Contributor

Johayon commented Feb 9, 2018

I could take it up, if no one is currently working on it.

@jnothman
Member Author

jnothman commented Feb 20, 2018 via email

@agamemnonc
Contributor

agamemnonc commented Oct 31, 2018

@jnothman just to clarify, by multi-output you mean multi-class, right? Because I feel that the convention that is currently used is that multi-output == multi-label.

Anyway, I confirm that this is currently causing issues, especially given that the predict_proba outputs of MultiOutputClassifier and ClassifierChain are not compatible: the former returns a list of length n_outputs where each element has shape (n_samples, n_classes), whereas the latter returns an array of shape (n_samples, n_outputs).

For multi-output binary problems, this is OK, as it is assumed that the method returns the probability of the positive class for each output. However, in the multioutput-multiclass case, what predict_proba returns makes no sense, since there is no way to know which class the probabilities correspond to. And there are no warning messages to let the user know that the results may not make sense. Therefore, I would personally classify this issue as a bug.
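
The shape mismatch can be seen on a small multi-output binary problem. This is a sketch with made-up data and LogisticRegression chosen arbitrarily as the base estimator; it only covers the binary case, where both estimators work, and reflects the behaviour described in this thread.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import ClassifierChain, MultiOutputClassifier

rng = np.random.RandomState(0)
X = rng.rand(40, 5)
# Three binary outputs, built so that every column contains both classes.
idx = np.arange(40)
Y = np.column_stack([idx % 2, (idx // 2) % 2, (idx // 4) % 2])

moc = MultiOutputClassifier(LogisticRegression()).fit(X, Y)
chain = ClassifierChain(LogisticRegression(), random_state=0).fit(X, Y)

# List of n_outputs arrays, each of shape (n_samples, 2).
moc_proba = moc.predict_proba(X)
# Single array of shape (n_samples, n_outputs): positive-class probabilities.
chain_proba = chain.predict_proba(X)

print(len(moc_proba), moc_proba[0].shape)
print(chain_proba.shape)
```

In the binary case, column k of the chain output plays the same role as `moc_proba[k][:, 1]`, although the values differ because the chain conditions on earlier outputs. With more than two classes per output, a single column per output has no such interpretation.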

@Johayon if you have decided not to work on this any more, pls let me know and I would be happy to take it up.

@jnothman
Member Author

jnothman commented Nov 3, 2018 via email

@henrif94

I am still getting a No Loop Matching Error for .predict_proba().
Any updates on this issue?

@agamemnonc
Contributor

There is a PR underway to address this issue (#14654).

It is currently stalled due to conflicts with master, among other things, but I hope to be able to address this in the next few weeks.

@lucyleeow
Member

lucyleeow commented Mar 28, 2024

Looking deeper into this, Y is currently of shape (n_samples, n_classes) (multi-label binarized). Other estimators that support multi-label, multi-output have Y of shape (n_samples, n_outputs). I could not find an example where Y is a list of arrays of shape (n_samples, n_classes).

If we want to implement this, do we ask Y to be:

  1. (n_samples, n_outputs) (not backwards compatible)
  2. List of (n_samples, n_classes) - ? complex?
  3. Support both to be backwards compatible - difficult to maintain?
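
To make the two array-shaped target formats concrete, here is a small sketch with made-up labels. It uses MultiLabelBinarizer to build the binarized multi-label form and `type_of_target` to show how scikit-learn classifies each array.

```python
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.utils.multiclass import type_of_target

# Multi-label binarized: one 0/1 column per class, shape (n_samples, n_classes).
label_sets = [{"a", "b"}, {"c"}, {"a"}]
mlb = MultiLabelBinarizer()
Y_multilabel = mlb.fit_transform(label_sets)

# Multi-output multiclass: one column per output, shape (n_samples, n_outputs),
# where each column holds a class label for that output.
Y_multioutput = np.array([[0, 2], [1, 0], [2, 1]])

kind_multilabel = type_of_target(Y_multilabel)
kind_multioutput = type_of_target(Y_multioutput)
print(Y_multilabel.shape, kind_multilabel)
print(Y_multioutput.shape, kind_multioutput)
```

Both are 2D arrays, so an estimator that accepts both would have to tell them apart by their contents rather than by their shape.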

cc @ogrisel since you reviewed #21942

9 participants