Adding Uniform temporal Subsampling for Video #6812
Conversation
I'm ok with having custom tests for this now, but they only cover correctness at this point. Unless this is super urgent, we should really implement a KernelInfo and DispatcherInfo for it to also cover stuff like JIT. @datumbox if you don't have time to add that LMK and I will push a commit.
def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temporal_dim: int = -4) -> torch.Tensor:
    # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
    t_max = video.size(temporal_dim) - 1
    indices = torch.linspace(0, t_max, num_samples, device=video.device).clamp_(0, t_max).long()
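For context, a minimal usage sketch of the kernel (this is an illustration, not code from the PR; it assumes the kernel goes on to select the computed indices along temporal_dim, as the referenced pytorchvideo implementation does, and uses the (..., T, C, H, W) layout implied by the default temporal_dim=-4):

import torch

video = torch.rand(2, 16, 3, 224, 224)  # (batch, T, C, H, W) -> T sits at dim -4
out = uniform_temporal_subsample_video(video, num_samples=8)
assert out.shape == (2, 8, 3, 224, 224)  # 8 evenly spaced frames out of 16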
Not sure if this is just for parity, but is there a reason to create the linspace in int32 to later convert it to int64?
Plus, maybe this is just personal preference, but I would appreciate us being "explicit" about the types. Suggested change:
-    indices = torch.linspace(0, t_max, num_samples, device=video.device).clamp_(0, t_max).long()
+    indices = torch.linspace(0, t_max, num_samples, device=video.device).clamp_(0, t_max).to(torch.int64)
Plusplus, the docs for torch.linspace state: "Creates a one-dimensional tensor of size steps whose values are evenly spaced from start to end, inclusive."
What is the clamp for if the function by definition does not return values outside this range?
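For what it's worth, a quick check (my own illustration, not from the PR) confirms that torch.linspace is endpoint-inclusive, so on these inputs the clamp should be a no-op; presumably it is only kept as a guard in the ported code:

import torch

t_max = 9
values = torch.linspace(0, t_max, 5)          # tensor([0.0000, 2.2500, 4.5000, 6.7500, 9.0000])
assert values.max().item() == t_max           # the last element is exactly `end`
assert values.clamp(0, t_max).equal(values)   # clamping changes nothing here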
Let's be careful about how we are porting this. Here is why casting later is not the same:
t_max = 9
num_samples = 8
indices = torch.linspace(0, t_max, num_samples, dtype=torch.int64)
indices2 = torch.linspace(0, t_max, num_samples).clamp(0, t_max).long()
assert indices.equal(indices2), f"({t_max}, {num_samples})\n{indices}\n{indices2}"
Result:
assert indices.equal(indices2), f"({t_max}, {num_samples})\n{indices}\n{indices2}"
AssertionError: (9, 8)
tensor([0, 1, 2, 3, 5, 6, 7, 8])
tensor([0, 1, 2, 3, 5, 6, 7, 9])
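In other words (my own illustration, not part of the PR), the float-then-cast path truncates each evenly spaced float value independently, which is the behavior the port preserves:

import torch

t_max, num_samples = 9, 8
# floor(i * t_max / (num_samples - 1)) for each index, which the trailing .long() effectively applies
expected = torch.tensor([int(i * t_max / (num_samples - 1)) for i in range(num_samples)])
indices2 = torch.linspace(0, t_max, num_samples).clamp(0, t_max).long()
assert expected.equal(indices2)  # tensor([0, 1, 2, 3, 5, 6, 7, 9])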
IMO, this looks like a bug in torch.linspace, as its docs state that the last element is end (= t_max).
@pmeier Thanks for the review. I've addressed the comments to unblock. We should review to what degree these nits can be covered by the linter, or adopt simpler conventions in the future. Regardless, I've made the necessary changes to cut down the back-and-forth, as we need this ASAP to assist the migration of a couple of internal teams. Concerning the tests, I was hoping for more pointers from you on where I'm supposed to add them (the current structure has a lot of abstraction and it's not obvious). Concerning JIT testing, I have a test but it's probably in the wrong place. You are welcome to push a commit to this branch to improve the tests, as you are more familiar with the new test infra, but if that is going to take longer, we should do it in a separate PR.
@datumbox I've added the tests I wanted and over-explained them in my comments below. To run these tests, you can use
pytest test/test_prototype_transforms_functional.py -k "uniform_temporal_subsample"
This will run ~200 tests. LMK if any of my explanations are unclear or if I missed something that you want to know.
There is one relevant test failure:
_____________ TestUniformTemporalSubsample.test__transform[inpt2] ______________
Traceback (most recent call last):
File "/home/runner/work/vision/vision/test/test_prototype_transforms.py", line 1917, in test__transform
output = transform(inpt)
File "/opt/hostedtoolcache/Python/3.7.15/x64/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1363, in _call_impl
return forward_call(*input, **kwargs)
File "/home/runner/work/vision/vision/torchvision/prototype/transforms/_transform.py", line 40, in forward
for inpt in flat_inputs
File "/home/runner/work/vision/vision/torchvision/prototype/transforms/_transform.py", line 40, in <listcomp>
for inpt in flat_inputs
File "/home/runner/work/vision/vision/torchvision/prototype/transforms/_temporal.py", line 16, in _transform
return F.uniform_temporal_subsample(inpt, self.num_samples, temporal_dim=self.temporal_dim)
File "/home/runner/work/vision/vision/torchvision/prototype/transforms/functional/_temporal.py", line 20, in uniform_temporal_subsample
raise ValueError("Video inputs must have temporal_dim equivalent to -4")
ValueError: Video inputs must have temporal_dim equivalent to -4
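For reference, a rough reconstruction of the kind of guard that raises this error (the names and structure below are my guesses based on the traceback, not the actual implementation in torchvision/prototype/transforms/functional/_temporal.py):

import torch
from torchvision.prototype import features  # prototype namespace at the time of this PR

def uniform_temporal_subsample(inpt: torch.Tensor, num_samples: int, temporal_dim: int = -4) -> torch.Tensor:
    # Video features use a fixed (..., T, C, H, W) layout, so a non-default
    # temporal_dim only makes sense for plain tensors.
    if isinstance(inpt, features.Video) and temporal_dim != -4:
        raise ValueError("Video inputs must have temporal_dim equivalent to -4")
    return uniform_temporal_subsample_video(inpt, num_samples, temporal_dim=temporal_dim)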
Otherwise LGTM if CI is green. Thanks Vasilis!
Fixed the bug mentioned above for you. PR is now ready to be merged from my side unless you have any objections about the stuff I added.
Summary:
* Adding temporal sampling kernel and dispatcher.
* Adding the UniformTemporalSubsample class.
* Add it on init
* Adding tests.
* Addressing comments.
* Reverting proposal as it led to different results.
* add more tests for uniform_temporal_subsample
* cleanup
* fix logic
* fix logic
* make test more strict
* lint
* Update torchvision/prototype/transforms/functional/_temporal.py
* remove pytorchvideo again per request

Reviewed By: YosuaMichael
Differential Revision: D40722910
fbshipit-source-id: 68af13821890d1784f47ddb7cfbfea409b6ee6a0
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Resolves #6768
Adding temporal sampling Transforms similar to the ones in PyTorchVideo.
cc @vfdev-5 @bjuncek @pmeier