Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More helpful error message when predicting with unfitted net #488

Merged
merged 2 commits into from Aug 12, 2019

Conversation

@BenjaminBossan
Copy link
Collaborator

commented Jun 23, 2019

Fixes #487

  • Use check_is_fitted similar to what is used in sklearn, except that
    a skorch.exceptions.NotInitializedError is raised.
  • Only assumes the presence of 'module_' so that nets remain hackable,
    except where a different attribute is specifically required.
  • Re-wrote existing checks to now use check_is_fitted.
BenjaminBossan
More helpful error message when predicting with unfitted net
* Use check_is_fitted similar to what is used in sklearn, except that
  a skorch.exceptions.NotInitializedError is raised.
* Only assumes the presence of 'module_' so that nets remain hackable,
  except where a different attribute is specifically required.
* Re-wrote existing checks to now use check_is_fitted.

@BenjaminBossan BenjaminBossan requested review from thomasjpfan and ottonemo Jun 23, 2019

@BenjaminBossan BenjaminBossan self-assigned this Jun 23, 2019

@@ -483,6 +484,31 @@ def get_map_location(target_device, fallback_device='cpu'):
return map_location


def check_is_fitted(estimator, attributes, msg=None, all_or_any=all):

This comment has been minimized.

Copy link
@thomasjpfan

thomasjpfan Jun 24, 2019

Member

We can also use the sklearn version directly:

from sklearn.utils.validation import check_is_fitted as sk_check_is_fitted

def check_is_fitted(...):
    try:
        sk_check_is_fitted(...)
    except NotFittedError as e:
        raise NotInitializedError(str(e))

This comment has been minimized.

Copy link
@BenjaminBossan

BenjaminBossan Jun 24, 2019

Author Collaborator

This doesn't quite work because the error message now says "...is not initialized yet...". Or do you believe this is not enough to justify not re-using sklearn's check_is_fitted?

Scratch that, I changed it since it's possible to pass the error message as an argument.

@@ -852,6 +854,27 @@ def fit(self, X, y=None, **fit_params):
self.partial_fit(X, y, **fit_params)
return self

def check_is_fitted(self, attributes=None, *args, **kwargs):

This comment has been minimized.

Copy link
@thomasjpfan

thomasjpfan Jun 24, 2019

Member

Is check_is_fitted public to signal that it can be overridden by subclasses?

This comment has been minimized.

Copy link
@BenjaminBossan

BenjaminBossan Jun 24, 2019

Author Collaborator

Yes, the idea was to make it easy for a user to change the attributes that they consider to be necessary for a net to count as fitted.

BenjaminBossan
skorch chech_is_fitted not calls sklearn check_is_fitted
Only changes error message and exception type.

@ottonemo ottonemo merged commit 7d275df into master Aug 12, 2019

2 checks passed

Travis CI - Branch Build Passed
Details
Travis CI - Pull Request Build Passed
Details
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
3 participants
You can’t perform that action at this time.