[AIR][Train] Multilabel classification with SklearnTrainer #32732

Closed
srimantacse opened this issue Feb 22, 2023 · 5 comments · Fixed by #42814
Labels
bug Something that is supposed to be working; but isn't P2 Important issue, but not time-critical train Ray Train Related Issue

Comments

@srimantacse

Description

Ray 2.x has support for SklearnTrainer.
Q1: Does it support a multilabel classification approach such as OneVsRestClassifier?
The SklearnTrainer constructor has a `label_column` argument. Does it accept a list of labels? I tried, but it did not work.

Use case

No response

@srimantacse srimantacse added enhancement Request for new feature and/or capability triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Feb 22, 2023
@justinvyu justinvyu changed the title Multilabel classification [AIR][Train] Multilabel classification with SklearnTrainer Feb 22, 2023
@justinvyu
Contributor

Yes, it does support multilabel classification. Here's an example you can build off of:

import ray

from ray.air.config import ScalingConfig
from ray.train.sklearn import SklearnTrainer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.ensemble import RandomForestClassifier

# One feature "x" and a single label column "y" with values 0, 1, 2
train_dataset = ray.data.from_items([{"x": x, "y": x % 3} for x in range(32)])
trainer = SklearnTrainer(
    estimator=OneVsRestClassifier(RandomForestClassifier()),
    datasets={"train": train_dataset},
    label_column="y",
    scaling_config=ScalingConfig(trainer_resources={"CPU": 4}),
)
result = trainer.fit()

By the way, what was the error you were encountering?

@justinvyu justinvyu added question Just a question :) train Ray Train Related Issue air and removed enhancement Request for new feature and/or capability triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Feb 22, 2023
@srimantacse
Author

srimantacse commented Feb 23, 2023

@justinvyu thanks for the clarification.
However, I want multilabel, not multiclass.

The dataset looks like this:

================================
feature 1, feature 2, feature 3, feature 4, target 1, target 2, target 3
0,1,1,0,0,0,1
0,1,1,0,1,0,1
=================================

I merged the target columns into a single `target` column:
[0, 0, 1]
[1, 0, 1]
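The thread doesn't show how the target columns were merged. As a minimal sketch, assuming pandas and using the two example rows above, one common way to do it looks like this; note that the result is an `object`-dtype column whose cells are Python lists, not a 2D numeric array:

```python
import pandas as pd

# The two example rows from the CSV above
df = pd.DataFrame(
    [[0, 1, 1, 0, 0, 0, 1], [0, 1, 1, 0, 1, 0, 1]],
    columns=["feature 1", "feature 2", "feature 3", "feature 4",
             "target 1", "target 2", "target 3"],
)
target_cols = ["target 1", "target 2", "target 3"]

# Collapse the three 0/1 target columns into one list-valued "target" column
df["target"] = df[target_cols].values.tolist()
df = df.drop(columns=target_cols)

print(df["target"].dtype)     # object
print(df["target"].tolist())  # [[0, 0, 1], [1, 0, 1]]
```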

I used a similar approach. The code looks like this:

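import ray
from ray.air.config import ScalingConfig
from ray.train.sklearn import SklearnTrainer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import LinearSVC

# train_dataset and valid_dataset are Ray Datasets with the feature columns
# and a "target" column holding one [target 1, target 2, target 3] list per row
trainer = SklearnTrainer(
    estimator=OneVsRestClassifier(LinearSVC(), n_jobs=2),
    label_column="target",
    datasets={"train": train_dataset, "valid": valid_dataset},
    cv=5,
    scaling_config=ScalingConfig(trainer_resources={"CPU": 2}),
)
result = trainer.fit()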

This is the error I am getting:

---> 17 result = trainer.fit()

File ~/miniconda3/envs/tmo/lib/python3.8/site-packages/ray/train/base_trainer.py:360, in BaseTrainer.fit(self)
    358     result = result_grid[0]
    359     if result.error:
--> 360         raise result.error
    361 except TuneError as e:
    362     raise TrainingFailedError from e

RayTaskError(ValueError): ray::_Inner.train() (pid=6739, ip=172.31.68.115, repr=SklearnTrainer)
...
13996   [0, 0, 1]
13997   [1, 0, 1]
Name: target, Length: 14000, dtype: object,)

@srimantacse
Author

Please check the above comment.

@justinvyu
Contributor

justinvyu commented Feb 24, 2023

Hi @srimantacse,

Got it, I misunderstood the original question. The problem is that sklearn does not accept a pandas Series of per-row label arrays as a multilabel target.

The problem

Multilabel classification algorithms provided by sklearn (e.g., OneVsRestClassifier) expect the labels as a 2D array (or sparse matrix) of shape (num_samples, num_labels), where each column is a 0/1 indicator for one label.
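To illustrate the difference, here is a small sketch using `type_of_target` from `sklearn.utils.multiclass`, the helper sklearn uses internally to decide which label format it has been given. The labels are the two example rows from earlier in the thread; the Series-of-arrays version stands in for what SklearnTrainer passes along:

```python
import numpy as np
import pandas as pd
from sklearn.utils.multiclass import type_of_target

# Labels as a 2D (num_samples, num_labels) 0/1 matrix
y_matrix = np.array([[0, 0, 1], [1, 0, 1]])
print(y_matrix.shape)            # (2, 3)
print(type_of_target(y_matrix))  # multilabel-indicator

# The same labels as a pandas Series whose cells are 1D arrays (one per row)
y_series = pd.Series([np.array([0, 0, 1]), np.array([1, 0, 1])], name="target")
print(type_of_target(y_series))  # not "multilabel-indicator"
```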

SklearnTrainer converts the label column to a pandas Series in which each row is a NumPy array, which is not the format sklearn expects. This needs to be fixed in SklearnTrainer.

I was able to get it working by adding one line in SklearnTrainer, right after the training labels (`y_train`) are extracted:

y_train = np.vstack(y_train.to_numpy())
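A minimal sketch of what that one line does, assuming (as described above) that the labels arrive as a pandas Series whose rows are 1D NumPy arrays of equal length:

```python
import numpy as np
import pandas as pd

# Labels in the shape described above: a Series whose rows are 1D arrays
y_train = pd.Series([np.array([0, 0, 1]), np.array([1, 0, 1])], name="target")
print(y_train.to_numpy().shape)  # (2,): a 1D object array, not (2, 3)

# Stack the per-row arrays into one (num_samples, num_labels) matrix
y_train = np.vstack(y_train.to_numpy())
print(y_train.shape)  # (2, 3)
```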

I'm not sure this is the most robust solution. Would you like to open a PR for this? I can help guide you through it! If not, I can also put it on my backlog.
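On the robustness question, here is a hedged sketch of a slightly more defensive version. The helper name `to_label_matrix` is hypothetical and not part of Ray; it only stacks when the rows actually hold sequences, so single-label targets pass through unchanged, and ragged label lists fail loudly instead of silently:

```python
import numpy as np
import pandas as pd


def to_label_matrix(y: pd.Series) -> np.ndarray:
    """Hypothetical helper (not a Ray API): convert a label column for sklearn.

    Rows holding a list/tuple/array are stacked into a 2D
    (num_samples, num_labels) matrix; scalar labels are returned as a 1D array.
    """
    if len(y) > 0 and isinstance(y.iloc[0], (list, tuple, np.ndarray)):
        # np.vstack raises ValueError if the rows have different lengths
        return np.vstack(y.to_numpy())
    return y.to_numpy()


print(to_label_matrix(pd.Series([0, 1, 2])).shape)              # (3,)
print(to_label_matrix(pd.Series([[0, 0, 1], [1, 0, 1]])).shape)  # (2, 3)
```

Checking only the first row is a simplification; a real fix might validate every row or use the dataset schema instead.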

Workaround

As a workaround for now, instead of using SklearnTrainer, you can use a regular Tune training function that sets the necessary thread-count environment variables itself:

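import os

from ray import tune
from sklearn.multiclass import OneVsRestClassifier

num_cpus = 2  # number of CPUs each trial should use (adjust as needed)


def train_fn(config):
    # see https://scikit-learn.org/stable/computing/parallelism.html
    os.environ["OMP_NUM_THREADS"] = str(num_cpus)
    os.environ["MKL_NUM_THREADS"] = str(num_cpus)
    os.environ["OPENBLAS_NUM_THREADS"] = str(num_cpus)
    os.environ["BLIS_NUM_THREADS"] = str(num_cpus)

    dataset = ...  # load features and labels
    X_train, y_train = ...  # y_train as a 2D (num_samples, num_labels) array
    estimator = OneVsRestClassifier(...)
    estimator.fit(X_train, y_train)


# Reserve num_cpus CPUs per trial so the thread settings above match
tuner = tune.Tuner(tune.with_resources(train_fn, {"cpu": num_cpus}))
results = tuner.fit()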
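For the fitting step inside `train_fn`, here is a hedged, Ray-free sketch of multilabel training with the estimator from the original question. It uses synthetic data from sklearn's `make_multilabel_classification` as a stand-in for the real dataset (an assumption, since the real features and labels aren't shown); that function already returns the labels as a 2D indicator matrix:

```python
from sklearn.datasets import make_multilabel_classification
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import LinearSVC

# Synthetic stand-in for the real data: 4 features, 3 binary targets
X, Y = make_multilabel_classification(
    n_samples=200, n_features=4, n_classes=3, random_state=0
)
print(Y.shape)  # (200, 3): already a 2D label-indicator matrix

# One LinearSVC per label column; dual=False is a choice made for this sketch
estimator = OneVsRestClassifier(LinearSVC(dual=False))
estimator.fit(X, Y)

pred = estimator.predict(X[:5])
print(pred.shape)  # (5, 3): one 0/1 prediction per label
```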

Let me know if that makes sense and works for what you're trying to do.

@srimantacse
Author

Thanks a lot @justinvyu
The suggestion above works: I changed that part manually and now it runs.

One question: do I need to add any tag when opening a PR?

@justinvyu justinvyu added bug Something that is supposed to be working; but isn't P2 Important issue, but not time-critical and removed question Just a question :) labels Feb 27, 2023
@justinvyu justinvyu self-assigned this Apr 19, 2023
@anyscalesam anyscalesam removed the air label Oct 28, 2023
3 participants