
Make RandomHorizontalFlip torchscriptable #2282

Merged
merged 3 commits into from Jun 4, 2020

Conversation

fmassa
Member

@fmassa fmassa commented Jun 2, 2020

This is the first PR in a series towards making the torchvision transforms torchscriptable.
The majority of the work has already been done in the functional_tensor.py file, which for now hasn't been properly exposed to the user.

The objective of this PR is twofold:

  • make functional and transforms seamlessly support Tensor arguments
  • make functional and transforms be torchscriptable

In order to minimize code changes, I decided for now to move (one at a time) the Pillow transforms to a separate file, and in functional.py dispatch to either functional_tensor.py or functional_pil.py.
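The dispatch described above might look something like the following minimal sketch. This is illustrative only, not the merged torchvision code; the helper `_hflip_tensor` and the exact PIL call are assumptions for the example:

```python
import torch
from torch import Tensor


def _hflip_tensor(img: Tensor) -> Tensor:
    # Tensor path (functional_tensor.py): flip along the last (width) dim.
    return img.flip(-1)


def hflip(img):
    # Hypothetical dispatcher in functional.py: Tensor inputs go to the
    # tensor implementation; anything else is assumed to be a PIL Image
    # (PIL's Image.FLIP_LEFT_RIGHT constant is 0).
    if isinstance(img, Tensor):
        return _hflip_tensor(img)
    return img.transpose(0)
```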

Another thing that might be worth discussing: the torchvision transforms now inherit from nn.Module. I think this makes things simpler wrt torchscript, and as such might be preferred over keeping them as before. In particular, it will make it possible to save the transforms into a standalone file and to have custom __str__ implementations.
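As a rough sketch of what an nn.Module-based transform could look like (the class name follows torchvision, but the body here is illustrative rather than the merged code):

```python
import torch
from torch import nn, Tensor


class RandomHorizontalFlip(nn.Module):
    """Flip the input horizontally with probability p.

    Using torch.rand for the coin flip (rather than Python's random)
    keeps the forward method scriptable.
    """

    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p

    def forward(self, img: Tensor) -> Tensor:
        if torch.rand(1) < self.p:
            return img.flip(-1)
        return img
```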

cc @eellison for feedback on replacing standard classes with classes inheriting from nn.Module for a more seamless support for torchscript.

This is a re-do of #2278, but this time from a branch from pytorch/vision, so that I can try stacking PRs together

Contributor

@eellison eellison left a comment


Changes LGTM. How common is it that the transforms would need to be saved in their own file? Also, what do you mean exactly by the custom __str__ implementations?

I guess generally I would say they should only be nn.Modules if they would be nn.Modules without considering TorchScript. Not sure how that applies here exactly.

@fmassa
Member Author

fmassa commented Jun 3, 2020

How common is it that the transforms would need to be saved in their own file?

This is a good question. @vreis do you have any insights here?

I would expect that most models in the future might have a self.transform attribute, which could be a standard object or an nn.Module. For example, in Faster R-CNN I implemented it as an nn.Module instance

class GeneralizedRCNNTransform(nn.Module):

in which case it wouldn't need to be saved in a standalone file.

That being said, it seemed easier to add torchscript support by converting the transforms to inherit from nn.Module. Without it, I would face errors such as

Traceback (most recent call last):
  File "test/test_transforms_tensor.py", line 33, in test_random_horizontal_flip
    scripted_fn = torch.jit.script(f)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 1306, in script
    qualified_name = _qualified_name(obj)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/_jit_internal.py", line 687, in _qualified_name
    name = obj.__name__
AttributeError: 'RandomHorizontalFlip' object has no attribute '__name__'

Note that we do have __call__ methods implemented in the transforms, and at least in the past those were not supported in torchscript (maybe this has been fixed?). Converting to an nn.Module keeps backwards compatibility while only requiring a forward method to be implemented.
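The contrast described here can be sketched as follows. The class names are hypothetical; the point is that torch.jit.script rejects an instance of a plain class with __call__ (with the AttributeError shown in the traceback above, on the torch version in use), while an nn.Module with a forward method scripts cleanly:

```python
import torch
from torch import nn, Tensor


class PlainFlip:
    # Old-style transform: a plain class exposing __call__.
    def __call__(self, img: Tensor) -> Tensor:
        return img.flip(-1)


class ModuleFlip(nn.Module):
    # nn.Module version: same logic, but in forward.
    def forward(self, img: Tensor) -> Tensor:
        return img.flip(-1)


# Scripting the plain instance fails; the module scripts cleanly.
try:
    torch.jit.script(PlainFlip())
    print("plain class scripted (unexpected)")
except Exception as e:
    print("plain class failed:", type(e).__name__)

scripted = torch.jit.script(ModuleFlip())
```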

@vreis

vreis commented Jun 3, 2020

This looks good to me.

How common is it that the transforms would need to be saved in their own file?

This is a good question. @vreis do you have any insights here?

We don't have that as a strict requirement, but right now the transforms are tied to the dataset abstraction rather than the model so I think it makes sense to generate a separate file.

What would be nice is if the final torchscript model took raw JPEGs and did everything needed to them (decoding, transforms, inference). But having separate files doesn't preclude us from doing that.

* Make RandomVerticalFlip torchscriptable

* Fix lint
@fmassa fmassa merged commit 11a39aa into master Jun 4, 2020
@fmassa fmassa deleted the hflip-script branch June 4, 2020 15:54
@fmassa fmassa mentioned this pull request Jun 5, 2020
16 tasks
@fmassa fmassa mentioned this pull request Jun 22, 2020