[air - preprocessor] Add BatchMapper. #23700

Merged on Apr 14, 2022 (14 commits)

@xwjiang2010 (Contributor) commented Apr 4, 2022

Why are these changes needed?

Add BatchMapper preprocessor.
Update the semantics of preprocessor.fit() to allow fitting multiple times, following scikit-learn's example.
Introduce FitStatus to explicitly incorporate Chain case.
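
As an illustration of the multiple-fit semantics described above, here is a minimal sketch in which a later fit() simply overwrites the state learned by an earlier one, as scikit-learn estimators do. The `MeanCenterer` class and its list-of-floats input are hypothetical stand-ins chosen for this example; they are not Ray's Preprocessor API.

```python
from typing import List, Optional


class MeanCenterer:
    """Hypothetical stateful preprocessor used only for illustration."""

    def __init__(self) -> None:
        # Learned state; stays None until fit() has been called.
        self.mean_: Optional[float] = None

    def fit(self, values: List[float]) -> "MeanCenterer":
        # Refitting is allowed: previously learned state is overwritten.
        self.mean_ = sum(values) / len(values)
        return self

    def transform(self, values: List[float]) -> List[float]:
        if self.mean_ is None:
            raise RuntimeError("`fit` must be called before `transform`.")
        return [v - self.mean_ for v in values]


centerer = MeanCenterer()
centerer.fit([1.0, 2.0, 3.0])         # learns a mean of 2.0
first = centerer.transform([2.0])
centerer.fit([10.0, 20.0])            # second fit replaces the state
second = centerer.transform([15.0])
```

The second fit() neither raises nor accumulates; it replaces the learned mean, which is the behavior the summary attributes to scikit-learn.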

Related issue number

Checks

  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

            raise PreprocessorNotFittedException(
                "`fit` must be called before `transform_batch`."
            )
        return self._transform_batch(df)

    def should_fit(self):
Member

sounds more like can_fit(self) or fittable(self) to me.
btw why is it check_is_fitted() and not is_fitted() ...

Contributor Author

Hmmm, got it. We can decide on the naming. But semantics are basically:

  1. fittable is an inherent attribute of the "type" of the preprocessor. It also implies whether a fit method is meaningful at all throughout the entire lifetime of this preprocessor.
  2. should_fit/can_fit depends on the state a preprocessor is currently in (assuming it's fittable).

Exposing check_is_fitted alone is not enough, as you can see in trainer.py - it only checks for check_is_fitted in current impl, which leads to crash in the case of non-fittable preprocessors. That's why the proposal is to add should_fit.

check_is_fitted vs. is_fitted, or can vs. should - I don't have much preference. @clarkzinzow @matthewdeng maybe, as original authors of the API?
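
The distinction drawn in this comment can be sketched as follows. This is only an illustration under stated assumptions: the class layout, the `_fitted` flag, the `StatelessMapper` subclass, and the `maybe_fit` helper are hypothetical and do not reproduce the code in this PR; only the names `_is_fittable`, `check_is_fitted`, and `should_fit` come from the discussion.

```python
class Preprocessor:
    """Hypothetical base class mirroring the names used in this thread."""

    # Property of the *type*: is fitting meaningful for this preprocessor?
    _is_fittable = True

    def __init__(self) -> None:
        self._fitted = False  # assumed internal flag, for illustration

    def check_is_fitted(self) -> bool:
        return self._fitted

    def should_fit(self) -> bool:
        # Property of the current *state*: fittable and not fitted yet.
        return self._is_fittable and not self.check_is_fitted()

    def fit(self, data) -> "Preprocessor":
        self._fitted = True
        return self


class StatelessMapper(Preprocessor):
    # fit() is never meaningful for this type.
    _is_fittable = False


def maybe_fit(preprocessor: Preprocessor, data) -> None:
    # A caller that only checked `not check_is_fitted()` would also try to
    # fit the stateless mapper; should_fit() rules that case out.
    if preprocessor.should_fit():
        preprocessor.fit(data)


scaler = Preprocessor()
mapper = StatelessMapper()
maybe_fit(scaler, [1, 2, 3])
maybe_fit(mapper, [1, 2, 3])
```

After these calls the fittable preprocessor is fitted, while the stateless one was never touched.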

Member

agree with the change. I am just nit-picking the naming.
hope to get things named consistently. thanks :)

Contributor

I opted for should_fit over can_fit since it's not indicating an optional operation for fittable preprocessors, it's a necessary operation: if a fittable preprocessor is not fit before calling .transform(), it will fail. An argument could even be made for needs_fit.

Contributor

We can change check_is_fitted to is_fitted.

should_fit functionally is a bit strange to me, at least as a public API. In particular, I want to avoid the case where the user does something like:

if (preprocessor.should_fit()):
    preprocessor.fit()

It's not clear how to differentiate the case where the preprocessor is fitted from the case where the preprocessor was already fitted before.

Contributor Author

@matthewdeng hmm, mind elaborating a bit more?
so 3 cases:

  1. not fittable
  2. fittable and fitted
  3. fittable and not fitted yet

should_fit == case 3
is_fitted == case 2, conditioned on the preprocessor being fittable (cases 2 + 3)

Another approach that could work is to enforce "at most once" fitting semantics internally, so the caller doesn't have to call should_fit before fit. Which one do you prefer? Or are you proposing an alternative?

Member

I think I lack the context on why you want to bundle these 2 things together in the first place.
But in my mind, the most intuitive way is:

  • if calling fit() or fit_transform(), and not fittable, throw an exception.
  • if calling fit(), and already fitted, print a warning msg, and no-op.
  • if calling fit_transform(), and already fitted, print a warning msg, then proceed to transform().
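
The three rules proposed in this comment could look roughly like the sketch below. It is written under assumptions: `NotFittableError`, the `_fitted` flag, the `fit_calls` counter, and the identity `_transform` are hypothetical names introduced for illustration, and the sketch shows this one proposal rather than the behavior that was eventually merged.

```python
import warnings


class NotFittableError(Exception):
    """Hypothetical exception for preprocessors that have nothing to fit."""


class Preprocessor:
    _is_fittable = True

    def __init__(self) -> None:
        self._fitted = False
        self.fit_calls = 0  # counts real fits, for illustration only

    def _fit(self, data) -> None:
        self.fit_calls += 1

    def _transform(self, data):
        return data  # identity transform keeps the sketch small

    def fit(self, data) -> "Preprocessor":
        if not self._is_fittable:
            # Rule 1: fitting a non-fittable preprocessor is an error.
            raise NotFittableError("This preprocessor cannot be fitted.")
        if self._fitted:
            # Rule 2: already fitted, so warn and do nothing.
            warnings.warn("Preprocessor is already fitted; `fit` is a no-op.")
            return self
        self._fit(data)
        self._fitted = True
        return self

    def transform(self, data):
        return self._transform(data)

    def fit_transform(self, data):
        # Rule 3: fit() warns and no-ops when already fitted, after which
        # we proceed to transform(). Rule 1 still applies through fit().
        self.fit(data)
        return self.transform(data)


class Mapper(Preprocessor):
    _is_fittable = False


p = Preprocessor()
p.fit([1.0])
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out = p.fit_transform([2.0])
```

Here the second fit is skipped with a single warning, and the data still flows through transform().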

Contributor

Can we raise an exception in fit()/fit_transform() when already fitted instead? Logging is better than no logging, but I worry the behavior here isn't clear for users (I can see users thinking it should re-fit).

@clarkzinzow (Contributor) left a comment

Looking good, the biggest remaining things are:

  1. We need to modify Chain to work correctly with the new should_fit() API.
  2. I think that BatchMapper.fit() should be a no-op in order for Chain to be able to naively call .fit() and .fit_transform() on all of its preprocessors, which should be cleaner.

    def _fit(self, ds: Dataset) -> Preprocessor:
        for preprocessor in self.preprocessors[:-1]:
            ds = preprocessor.fit_transform(ds)
        self.preprocessors[-1].fit(ds)
        return self

    def fit_transform(self, ds: Dataset) -> Dataset:
        for preprocessor in self.preprocessors:
            ds = preprocessor.fit_transform(ds)
        return ds

    def _transform(self, ds: Dataset) -> Dataset:
        for preprocessor in self.preprocessors:
            ds = preprocessor.transform(ds)
        return ds

    def _transform_batch(self, df: DataBatchType) -> DataBatchType:
        for preprocessor in self.preprocessors:
            df = preprocessor.transform_batch(df)
        return df

    def check_is_fitted(self) -> bool:
        return all(p.check_is_fitted() for p in self.preprocessors)

    def __repr__(self):
        return f"<Chain preprocessors={self.preprocessors}>"
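
To make the second point concrete, here is a self-contained sketch of a chain that naively calls fit_transform() on every member, which works because the stateless mapper's fit() is a no-op. `ToyBatchMapper`, `ToyMaxScaler`, and `ToyChain` are hypothetical stand-ins that operate on lists of dicts instead of Ray Datasets; they are not the classes in this PR.

```python
from typing import Callable, Dict, List

Batch = List[Dict[str, float]]  # stand-in for a DataFrame batch


class ToyBatchMapper:
    """Stateless: applies a user function to each batch."""

    def __init__(self, fn: Callable[[Batch], Batch]) -> None:
        self.fn = fn

    def fit(self, batch: Batch) -> "ToyBatchMapper":
        return self  # no-op: there is nothing to learn

    def transform(self, batch: Batch) -> Batch:
        return self.fn(batch)

    def fit_transform(self, batch: Batch) -> Batch:
        return self.fit(batch).transform(batch)


class ToyMaxScaler:
    """Stateful: learns the column maximum, then divides by it."""

    def __init__(self, column: str) -> None:
        self.column = column
        self.max_ = None

    def fit(self, batch: Batch) -> "ToyMaxScaler":
        self.max_ = max(row[self.column] for row in batch)
        return self

    def transform(self, batch: Batch) -> Batch:
        return [{**row, self.column: row[self.column] / self.max_} for row in batch]

    def fit_transform(self, batch: Batch) -> Batch:
        return self.fit(batch).transform(batch)


class ToyChain:
    def __init__(self, *preprocessors) -> None:
        self.preprocessors = preprocessors

    def fit_transform(self, batch: Batch) -> Batch:
        # No special-casing of stateless members is needed.
        for preprocessor in self.preprocessors:
            batch = preprocessor.fit_transform(batch)
        return batch


chain = ToyChain(
    ToyBatchMapper(lambda rows: [{**r, "x": r["x"] * 2} for r in rows]),
    ToyMaxScaler("x"),
)
result = chain.fit_transform([{"x": 1}, {"x": 4}])
```

The mapper doubles each value before the scaler learns its maximum, so the chain needs no knowledge of which members are fittable.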

        Args:
            dataset: Input dataset.

        Returns:
            Preprocessor: The fitted Preprocessor with state attributes.
        """
        assert self._is_fittable, "One is expected to call `should_fit` before `fit`."
Contributor

Could also make this a no-op when not self._is_fittable, which would be more friendly to e.g. chain preprocessors.

Contributor Author

See comment above.

@clarkzinzow I see.
Looking at Chain preprocessor, _is_fittable is set to False. Are users supposed to overwrite this when constructing their Chain preprocessor?

@@ -60,6 +64,8 @@ def fit_transform(self, dataset: Dataset) -> Dataset:
        Returns:
            ray.data.Dataset: The transformed Dataset.
        """
        assert self._is_fittable, "One is expected to call `should_fit` before `fit`."
Member

this error message looks weird. why don't you check:
assert self.should_fit() here as well?

@xwjiang2010 (Contributor Author) commented Apr 6, 2022

@gjoliver @matthewdeng @clarkzinzow
A few updates:

  • introduces a FitStatus and fit_status() to incorporate some of the nuances for chained preprocessors.
  • throws explicit exceptions
  • check_is_fitted is now private

Comment on lines +33 to +34
        elif fitted_count > 0:
            return Preprocessor.FitStatus.PARTIALLY_FITTED
Contributor

Is this a valid state, and when would this happen? Is this just when a chain is created that contains some fitted and some unfitted preprocessors? Is that even a valid use case that we should allow?

Contributor Author

Correct. I don't think this is necessarily a valid state to be in, but one may construct a chain preprocessor incorrectly and end up in this mixed state.
Trying to be defensive and explicit here.
I am also open to having another error that warns explicitly about this mixed state, which should not happen.
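
For readers following the mixed-state question, a chain-level status could be derived roughly as sketched below. Only `FitStatus`, `PARTIALLY_FITTED`, and `fitted_count` appear in this thread; the other enum members and the `chain_fit_status` helper are assumptions made for this illustration and may differ from the merged code.

```python
from enum import Enum
from typing import List


class FitStatus(Enum):
    # Only PARTIALLY_FITTED is named in the thread; the rest are assumed.
    NOT_FITTABLE = "NOT_FITTABLE"
    NOT_FITTED = "NOT_FITTED"
    PARTIALLY_FITTED = "PARTIALLY_FITTED"
    FITTED = "FITTED"


def chain_fit_status(statuses: List[FitStatus]) -> FitStatus:
    """Combine the statuses of a chain's members into one status."""
    # Stateless members do not influence whether the chain is fitted.
    fittable = [s for s in statuses if s is not FitStatus.NOT_FITTABLE]
    if not fittable:
        return FitStatus.NOT_FITTABLE

    fitted_count = sum(s is FitStatus.FITTED for s in fittable)
    has_partial = any(s is FitStatus.PARTIALLY_FITTED for s in fittable)

    if fitted_count == len(fittable):
        return FitStatus.FITTED
    elif fitted_count > 0 or has_partial:
        # Mixed state: some fittable members are fitted, some are not.
        return FitStatus.PARTIALLY_FITTED
    return FitStatus.NOT_FITTED


mixed = chain_fit_status([FitStatus.FITTED, FitStatus.NOT_FITTED])
```

A caller can then decide explicitly how to treat the mixed state, for example by raising an error instead of fitting.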

@@ -192,7 +192,7 @@ def preprocess_datasets(self) -> None:

         if self.preprocessor:
             train_dataset = self.datasets.get(TRAIN_DATASET_KEY, None)
-            if train_dataset and not self.preprocessor.check_is_fitted():
+            if train_dataset:
Contributor

Are there valid use cases in which an already-fitted preprocessor may be passed and we'd rather no-op than error here?

Contributor Author

See @matthewdeng's preference about wanting explicit exception. :)
let's make a decision and stick to it.

Member

I think we should allow an already-fitted preprocessor, and basically no-op here.
Why do we want to require an unfitted preprocessor? What if the entire preprocessor is not fittable?

Contributor Author

we could do that. It's just @matthewdeng has this concern to not silently no-op (even with a warning msg):

Can we raise an exception in fit()/fit_transform() when already fitted instead? Logging is better than no logging, but I worry the behavior here isn't clear for users (I can see users thinking it should re-fit).

Member

Printing an info or warning msg sounds good.

Contributor

So I think that Preprocessor itself should error if .fit() is called on an already fitted preprocessor, but I was less sure about whether Train as a user of Preprocessor should let these exceptions happen. I think that @matthewdeng is right, we should error here to ensure that the user doesn't think that an overwriting or incremental fit is happening.
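
The stricter position argued in this comment might be sketched as follows, with the trainer letting the exception propagate. `PreprocessorAlreadyFittedError`, `StrictPreprocessor`, and the free function `preprocess_datasets` are hypothetical names for illustration; the function only loosely mirrors the simplified condition in the diff above, and this is one of the options under discussion, not necessarily the merged behavior.

```python
class PreprocessorAlreadyFittedError(Exception):
    """Hypothetical exception name used for this sketch."""


class StrictPreprocessor:
    def __init__(self) -> None:
        self.fitted = False

    def fit(self, dataset) -> "StrictPreprocessor":
        if self.fitted:
            # Refuse instead of silently skipping, so users never assume
            # that an overwriting or incremental fit took place.
            raise PreprocessorAlreadyFittedError(
                "`fit` cannot be called on an already fitted preprocessor."
            )
        self.fitted = True
        return self


def preprocess_datasets(preprocessor: StrictPreprocessor, train_dataset) -> None:
    # Fit whenever a training dataset is present and let any
    # already-fitted error surface to the user.
    if train_dataset:
        preprocessor.fit(train_dataset)


strict = StrictPreprocessor()
preprocess_datasets(strict, [1, 2, 3])
try:
    preprocess_datasets(strict, [1, 2, 3])
    second_call_raised = False
except PreprocessorAlreadyFittedError:
    second_call_raised = True
```

The trade-off is explicitness against convenience: passing a pre-fitted preprocessor becomes an error the user has to resolve.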

Member

What about a partially fitted chain? What are the user's options here?

Contributor Author

Synced offline.
@matthewdeng @gjoliver @clarkzinzow PTAL.

@clarkzinzow (Contributor) left a comment

LGTM, only nits! IMO good to merge after one other ML team reviewer approval.

@richardliaw richardliaw added this to the Ray AIR milestone Apr 8, 2022
@matthewdeng (Contributor) left a comment

New functionality looks great, thanks for iterating on this!

@matthewdeng (Contributor) left a comment

LGTM - can you update the PR summary and add a description (including the fit status changes)?

@amogkam amogkam merged commit 06a57b2 into ray-project:master Apr 14, 2022
@xwjiang2010 xwjiang2010 deleted the add_column branch July 26, 2023 19:49
6 participants