Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix MLP.make_baseline() return type #40

Closed
wants to merge 3 commits into from

Conversation

jpgard
Copy link

@jpgard jpgard commented Nov 16, 2022

Return object of type cls, not MLP, in MLP.make_baseline(). Otherwise, child classes inheriting from MLP constructed using the .make_baseline() method always have type MLP (instead of the type of the child class).

Return object of type cls, not MLP, in MLP.make_baseline(). Otherwise, child classes inheriting from MLP constructed using the .make_baseline() method always have type MLP (instead of the type of the child class).
Fix MLP.make_baseline() return type
Return cls, not FTTransformer.
@jpgard
Copy link
Author

jpgard commented Nov 16, 2022

Note: also makes a similar change for FTTransformer.

@Yura52
Copy link
Collaborator

Yura52 commented Nov 16, 2022

Thank you for the proposal! In fact, this limitation was introduced intentionally, because I was not sure if it was a good thing to allow using make_baseline with inherited models. I wonder if you faced the limitation in practice or is it just something that you noticed while browsing the codebase?

@jpgard
Copy link
Author

jpgard commented Nov 16, 2022

This is something I encountered in practice, yes. For our research use cases when benchmarking tabular data algorithms, it is common for us to need a scikit-learn-style interface (i.e. methods such as fit, predict_proba, etc) for all models, which is why I needed to subclass these.

What made it particularly confusing is that the ResNet class does not have this limitation, only MLP and FTTransformer, so the behavior didn't seem consistent . Of course, if there is an intentional design choice behind this, feel free to ignore :), but I would be curious to know.

@Yura52
Copy link
Collaborator

Yura52 commented Nov 17, 2022

In your particular case I would create a separate class like Trainer that would implement the scikit-learn API and take an instance of torch.nn.Module as an argument in the constructor (i.e. in the spirit of skorch). As for ResNet, this is just a "bug" and was overlooked, thanks for noticing :)

@Yura52
Copy link
Collaborator

Yura52 commented Nov 17, 2022

Closing the PR then?

@jpgard
Copy link
Author

jpgard commented Nov 17, 2022

I suppose so, thanks for the clarification :)

@jpgard jpgard closed this Nov 17, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants