Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

introduce _check method for type checks on prototype transforms #6503

Merged
merged 10 commits into from
Oct 13, 2022

Conversation

pmeier
Copy link
Collaborator

@pmeier pmeier commented Aug 26, 2022

Most of our prototype transforms that override forward do this only to perform some input checks before passing them on to _get_params and _transform. This PR adds a _check method that makes this regular behavior, whereas overriding forward means some "special" behavior the API can't handle.

with _check in place, there are only three* transforms that still need to override forward:

These are objectively outlier and thus it is justified for them to override forward.

Using _check will also get rid of the boilerplate

def forward(self, *inputs):
    sample = inputs if len(inputs) > 1 else inputs[0]

    ...

    return super().forward(sample)

idiom.

torchvision/prototype/transforms/_geometry.py Outdated Show resolved Hide resolved
@@ -32,6 +35,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]

self._check(sample)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why aren't we flattening before this call? Also in this approach we can avoid the multiple flattening calls within has_any, has_all and the rest of the utility methods. Do you plan to do this on a follow up?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. My plan was to introduce _check here and go for #6760 in a follow-up PR.

Copy link
Collaborator

@vfdev-5 vfdev-5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm, thanks!

@pmeier pmeier merged commit e1b21f9 into pytorch:main Oct 13, 2022
@pmeier pmeier deleted the transforms-check branch October 13, 2022 14:46
facebook-github-bot pushed a commit that referenced this pull request Oct 17, 2022
…rms (#6503)

Summary:
* introduce _check method for type checks on prototype transforms

* cleanup

* Update torchvision/prototype/transforms/_geometry.py

* introduce _check on new transforms

* _check -> _check_inputs

* always check inputs in _RandomApplyTransform

Reviewed By: NicolasHug

Differential Revision: D40427467

fbshipit-source-id: eec7c1ac207955df9310212443769f8f7b146c6a

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants