-
Notifications
You must be signed in to change notification settings - Fork 6.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
only flatten a pytree once #6767
Conversation
@@ -437,7 +437,7 @@ def test__get_params(self, fill, side_range, mocker): | |||
image = mocker.MagicMock(spec=features.Image) | |||
h, w = image.spatial_size = (24, 32) | |||
|
|||
params = transform._get_params(image) | |||
params = transform._get_params([image]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All changes in this file stem from the fact that _get_params
previously handled pytree objects, but now only handles flattened ones.
@@ -1197,6 +1197,7 @@ def test_assertions(self, transform_cls): | |||
[ | |||
[transforms.Pad(2), transforms.RandomCrop(28)], | |||
[lambda x: 2.0 * x, transforms.Pad(2), transforms.RandomCrop(28)], | |||
[transforms.Pad(2), lambda x: 2.0 * x, transforms.RandomCrop(28)], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've accidentally included this while playing with the container flattening. It is useful nevertheless to have the "foreign" transform also in the middle of the pipeline. Let me know if I should revert.
@@ -639,7 +639,7 @@ def test_random_apply(self, p): | |||
prototype_transform = prototype_transforms.RandomApply( | |||
[ | |||
prototype_transforms.Resize(256), | |||
legacy_transforms.CenterCrop(224), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was just plain wrong. It never triggered because the input is a single image. Found this while playing with the container flattening.
@@ -31,16 +31,17 @@ def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]: | |||
key = keys[int(torch.randint(len(keys), ()))] | |||
return key, dct[key] | |||
|
|||
def _extract_image_or_video( | |||
def _flatten_and_extract_image_or_video( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've opted to group the flattening and extraction as well as the unflattening and insertion since that yielded the cleanest results.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM @pmeier. I like the approach. It keeps this detail internal and minimizes the overhead. Let's see how this improved our speed in real world applications. I'll do some training later to confirm.
Reviewed By: NicolasHug Differential Revision: D40427480 fbshipit-source-id: 0552b32b56e5292a64060fcddde46feca4137b6a
Closes #6760 minus the container handling as stated in #6760 (comment).