-
Notifications
You must be signed in to change notification settings - Fork 6.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
introduce _check method for type checks on prototype transforms #6503
Conversation
@@ -32,6 +35,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: | |||
def forward(self, *inputs: Any) -> Any: | |||
sample = inputs if len(inputs) > 1 else inputs[0] | |||
|
|||
self._check(sample) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why aren't we flattening before this call? Also in this approach we can avoid the multiple flattening calls within has_any
, has_all
and the rest of the utility methods. Do you plan to do this on a follow up?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. My plan was to introduce _check
here and go for #6760 in a follow-up PR.
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
…nto transforms-check
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm, thanks!
…rms (#6503) Summary: * introduce _check method for type checks on prototype transforms * cleanup * Update torchvision/prototype/transforms/_geometry.py * introduce _check on new transforms * _check -> _check_inputs * always check inputs in _RandomApplyTransform Reviewed By: NicolasHug Differential Revision: D40427467 fbshipit-source-id: eec7c1ac207955df9310212443769f8f7b146c6a Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com> Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Most of our prototype transforms that override
forward
do this only to perform some input checks before passing them on to_get_params
and_transform
. This PR adds a_check
method that makes this regular behavior, whereas overridingforward
means some "special" behavior the API can't handle.with
_check
in place, there are only three* transforms that still need to overrideforward
:vision/torchvision/prototype/transforms/_augment.py
Line 188 in 7cc2c95
vision/torchvision/prototype/transforms/_auto_augment.py
Line 139 in 7cc2c95
vision/torchvision/prototype/transforms/_container.py
Line 10 in 7cc2c95
These are objectively outlier and thus it is justified for them to override
forward
.Using
_check
will also get rid of the boilerplateidiom.