Conversation

datumbox (Contributor)

Fixes #3891

NicolasHug (Member) left a comment:
LGTM, I just have a few nit comments; feel free to address them or not. Overall this doesn't seem to be related to parametrization: this PR mainly changes a bunch of if/else branches into a dict lookup.
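The if/else-to-dict-lookup change described above can be sketched roughly as follows (the model names and values here are illustrative assumptions, not the actual torchvision test code):

```python
# Before: per-model branching (illustrative values; not the actual test code).
def num_classes_if_else(model_name):
    if model_name == 'inception_v3':
        return 1000
    elif model_name == 'fasterrcnn_resnet50_fpn':
        return 91
    return 50

# After: the same information expressed as a dict lookup with a default.
_NUM_CLASSES = {'inception_v3': 1000, 'fasterrcnn_resnet50_fpn': 91}

def num_classes_dict(model_name):
    return _NUM_CLASSES.get(model_name, 50)

print(num_classes_dict('resnet18'))  # 50
```

The dict form keeps all per-model configuration in one place and avoids growing a branch per model.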

# The following contains configuration parameters for all models which are used by
# the _test_*_model methods.
_model_params = {
'default_classification': {
NicolasHug (Member):
nit: instead of declaring a global dict, should we declare one dict in each of _test_classification_model, _test_segmentation_model, etc.? This would make the code easier to read, since the relevant info would sit next to where it's used.

datumbox (Contributor, Author):
Sounds good. This is how I had it originally, but I thought I'd move everything together. I'll roll it back.
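The per-method dict the reviewer suggests would look roughly like this sketch (the default and override values are assumptions based on the snippets quoted in this thread, not the real torchvision config):

```python
# Sketch of the suggested approach: each _test_* method keeps its own
# defaults and per-model overrides, instead of reading a global _model_params.
# The values below are assumptions, not the actual torchvision test config.
_CLASSIFICATION_DEFAULTS = {'num_classes': 50, 'input_shape': (1, 3, 224, 224)}

def classification_params(model_name):
    # Per-model overrides live next to the code that consumes them.
    overrides = {
        'inception_v3': {'input_shape': (1, 3, 299, 299)},
    }
    return {**_CLASSIFICATION_DEFAULTS, **overrides.get(model_name, {})}

print(classification_params('inception_v3')['input_shape'])  # (1, 3, 299, 299)
```

Keeping each dict inside its test helper trades a single source of truth for locality: a reader of _test_classification_model no longer has to scroll to a module-level table.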

'default_segmentation': {
'num_classes': 10,
'pretrained_backbone': False,
'input_shape': (1, 3, 32, 32),
NicolasHug (Member), May 26, 2021:
nit: it looks like we don't need to tweak input_shape for the segmentation and detection models, so I would leave it out for those. In fact, I would even leave it out for the classification models, because we can just move

input_shape = (1, 3, 299, 299) if model_name == 'inception_v3' else (1, 3, 224, 224)

into _test_classification_model; strictly speaking, input_shape isn't a model param, unlike the rest.

datumbox (Contributor, Author):

I need to reduce the size of some object detection models to speed them up (follow-up PR), so I thought I'd always include input_shape in the dictionary. The segmentation models won't need any overrides at this point.

Given this info, do you still recommend removing it from the params?

NicolasHug (Member):

What I meant to say is that this:

input_shape = kwargs.pop('input_shape')

is a little awkward, so we might want to remove those entries from the dict. If we do, we wouldn't need to remove e.g. the input_shape = (1, 3, 32, 32) entry at line 105.

But again, this is a nit, and perhaps I'm missing something, so feel free to leave it as-is.

@datumbox datumbox merged commit 2ab9359 into pytorch:master May 26, 2021
@datumbox datumbox deleted the tests/improve_model_tests branch May 26, 2021 16:35
facebook-github-bot pushed a commit that referenced this pull request Jun 10, 2021
Summary:
* Improve model parameterization on tests.

* Code review changes.

Reviewed By: NicolasHug

Differential Revision: D29027295

fbshipit-source-id: f3f78bc63fb4dfc2f50d931afa3bfa60fca632fb
Successfully merging this pull request may close: Improve unit-test parameterization for models

3 participants