-
Notifications
You must be signed in to change notification settings - Fork 72
Add Random index-based clip sampler - part 1 #221
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c090e44
6b52cfd
b107382
63bbcfa
b3bb895
4d5da20
78eebd5
623f329
179de70
310da2b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| from ._implem import clips_at_random_indices | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,142 @@ | ||
| import random | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. random nit: why not call it implementation.py? It's not user-facing, so we can have long descriptive names here
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I personally don't find |
||
| from typing import List, Optional | ||
|
|
||
| import torch | ||
|
|
||
| from torchcodec.decoders import FrameBatch, SimpleVideoDecoder | ||
|
|
||
|
|
||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Below I factored-out some validation logic into functions. This isn't strictly necessary for this PR, but we'll need the exact same logic for the |
||
| def _validate_params( | ||
| *, decoder, num_clips, num_frames_per_clip, num_indices_between_frames | ||
| ): | ||
| if len(decoder) < 1: | ||
| raise ValueError( | ||
| f"Decoder must have at least one frame, found {len(decoder)} frames." | ||
| ) | ||
|
|
||
| if num_clips <= 0: | ||
| raise ValueError(f"num_clips ({num_clips}) must be strictly positive") | ||
| if num_frames_per_clip <= 0: | ||
| raise ValueError( | ||
| f"num_frames_per_clip ({num_frames_per_clip}) must be strictly positive" | ||
| ) | ||
| if num_indices_between_frames <= 0: | ||
| raise ValueError( | ||
| f"num_indices_between_frames ({num_indices_between_frames}) must be strictly positive" | ||
| ) | ||
NicolasHug marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| def _validate_sampling_range( | ||
| *, sampling_range_start, sampling_range_end, num_frames, clip_span | ||
| ): | ||
| if sampling_range_start < 0: | ||
| sampling_range_start = num_frames + sampling_range_start | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what if it's still negative after this?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would result in an error because If we don't error, we can either:
Neither sounds like a great intuitive option, so I it might be best to error, similar to what we decided in https://github.com/pytorch/torchcodec/pull/221/files#r1781347650 |
||
|
|
||
| if sampling_range_start >= num_frames: | ||
| raise ValueError( | ||
| f"sampling_range_start ({sampling_range_start}) must be smaller than " | ||
| f"the number of frames ({num_frames})." | ||
| ) | ||
|
|
||
| if sampling_range_end is None: | ||
| sampling_range_end = num_frames - clip_span + 1 | ||
| if sampling_range_start >= sampling_range_end: | ||
| raise ValueError( | ||
| f"We determined that sampling_range_end should be {sampling_range_end}, " | ||
| "but it is smaller than or equal to sampling_range_start " | ||
| f"({sampling_range_start})." | ||
| ) | ||
| else: | ||
| if sampling_range_end < 0: | ||
| # Support negative values so that -1 means last frame. | ||
| sampling_range_end = num_frames + sampling_range_end | ||
| sampling_range_end = min(sampling_range_end, num_frames) | ||
| if sampling_range_start >= sampling_range_end: | ||
| raise ValueError( | ||
| f"sampling_range_start ({sampling_range_start}) must be smaller than " | ||
| f"sampling_range_end ({sampling_range_end})." | ||
| ) | ||
|
|
||
| return sampling_range_start, sampling_range_end | ||
|
|
||
|
|
||
| def _get_clip_span(*, num_indices_between_frames, num_frames_per_clip): | ||
| """Return the span of a clip, i.e. the number of frames (or indices) | ||
| between the first and last frame in the clip, both included. | ||
|
|
||
| This isn't the same as the number of frames in a clip! | ||
| Example: f means a frame in the clip, x means a frame excluded from the clip | ||
| num_frames_per_clip = 4 | ||
| num_indices_between_frames = 1, clip = ffff , span = 4 | ||
| num_indices_between_frames = 2, clip = fxfxfxf , span = 7 | ||
| num_indices_between_frames = 3, clip = fxxfxxfxxf, span = 10 | ||
| """ | ||
| return num_indices_between_frames * (num_frames_per_clip - 1) + 1 | ||
|
|
||
|
|
||
| def clips_at_random_indices( | ||
| decoder: SimpleVideoDecoder, | ||
| *, | ||
| num_clips: int = 1, | ||
| num_frames_per_clip: int = 1, | ||
| num_indices_between_frames: int = 1, | ||
| sampling_range_start: int = 0, | ||
| sampling_range_end: Optional[int] = None, # interval is [start, end). | ||
| ) -> List[FrameBatch]: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did we agree to return a List here? I thought some users had a need to get a single tensor back? Would those users just stack it manually? Note that stacking could be an expensive operation.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When we discussed this, we decided to "let the implementation decide".
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Decoding frames to a single tensor doesn't involves a copy -- and that can be done for max speed if we want. Stacking them after the fact is inefficient. |
||
|
|
||
| _validate_params( | ||
| decoder=decoder, | ||
| num_clips=num_clips, | ||
| num_frames_per_clip=num_frames_per_clip, | ||
| num_indices_between_frames=num_indices_between_frames, | ||
| ) | ||
|
|
||
| clip_span = _get_clip_span( | ||
| num_indices_between_frames=num_indices_between_frames, | ||
| num_frames_per_clip=num_frames_per_clip, | ||
| ) | ||
|
|
||
| # TODO: We should probably not error. | ||
| if clip_span > len(decoder): | ||
| raise ValueError( | ||
| f"Clip span ({clip_span}) is larger than the number of frames ({len(decoder)})" | ||
| ) | ||
NicolasHug marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| sampling_range_start, sampling_range_end = _validate_sampling_range( | ||
| sampling_range_start=sampling_range_start, | ||
| sampling_range_end=sampling_range_end, | ||
| num_frames=len(decoder), | ||
| clip_span=clip_span, | ||
| ) | ||
|
|
||
| clip_start_indices = torch.randint( | ||
| low=sampling_range_start, high=sampling_range_end, size=(num_clips,) | ||
| ) | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (I'm putting the comment here so that we can still read the logic without a giant textbox in the middle of it.) I'm not sure if we want to compute and then use An alternative approach is that we don't do this, and just set Both options result in some quirky behavior. What is currently implemented makes it less likely (I think) that we'll ever return frames in the range
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think uniformly sampling the clip_starts in the range where they are valid matches my own intuition. Duplicate frames may not be good for training. If @scotts has a strong opinion, maybe we can expose an option but the current behavior seems good to me as the default.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After chatting offline we concluded the best way to handle this was to let the user choose the sampling range, similarly to the "equaly_spaced_in_time" sampler. |
||
| # We want to avoid seeking backwards, so we sort the clip start indices | ||
| # before decoding the frames, and then re-shuffle the clips afterwards. | ||
| # Backward seeks may still happen if there are overlapping clips, i.e. if a | ||
| # clip ends after the next one starts. | ||
| # TODO: We should use a different strategy to avoid backward seeks: | ||
| # - flatten all frames indices, irrespective of their clip | ||
| # - sort the indices and dedup | ||
| # - decode all frames in index order | ||
| # - re-arrange the frames back into their original clips | ||
| clip_start_indices = torch.sort(clip_start_indices).values | ||
| clips = [ | ||
| decoder.get_frames_at( | ||
| start=clip_start_index, | ||
| stop=clip_start_index + clip_span, | ||
| step=num_indices_between_frames, | ||
| ) | ||
| for clip_start_index in clip_start_indices | ||
| ] | ||
|
|
||
| # This an ugly way to shuffle the clips using pytorch RNG *without* | ||
| # affecting the python builtin RNG. | ||
| builtin_random_state = random.getstate() | ||
| random.seed(torch.randint(0, 2**32, (1,)).item()) | ||
| random.shuffle(clips) | ||
| random.setstate(builtin_random_state) | ||
|
Comment on lines
+137
to
+140
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if there is a contextmanager for this? I couldn't find it myself but maybe there is a way using the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's typically the kind of scenario where a CM can be useful indeed. I thought about writing one, but decided against it considering this logic is just 2 lines of code (the CM would be more), and we'll remove it soon when we address the TODO. |
||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is it necessary that we don't affect the random module's state? I assume we're trying not to affect the random values seen by models during actual training, but why do we want to do that?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Users may want to have strict control over the RNG stream of the builtin In general, a library hard-coding a seed for a global RNG stream is a big no no (whether it's the Python RNG, pytorch RNG, numpy, ...). If a library hard-codes a seed, it should be for a local, non-global stream.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right, I'm on board with not hardcoding a seed. I guess what I'm confused about is why we can't just call
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We're a pytorch library, so in general we want our RNG stream to come from pytorch, not from other RNGs.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ohhhh, got it, so we're porting over torch's RNG state to random. Since your last response was the aha moment for me, let's say something to that effect in a comment. :) |
||
| return clips | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| import random | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
|
|
||
| @pytest.fixture(autouse=True) | ||
| def prevent_leaking_rng(): | ||
| # Prevent each test from leaking the rng to all other test when they call | ||
| # torch.manual_seed() or random.seed(). | ||
scotts marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| torch_rng_state = torch.get_rng_state() | ||
| builtin_rng_state = random.getstate() | ||
| if torch.cuda.is_available(): | ||
| cuda_rng_state = torch.cuda.get_rng_state() | ||
|
|
||
| yield | ||
|
|
||
| torch.set_rng_state(torch_rng_state) | ||
| random.setstate(builtin_rng_state) | ||
| if torch.cuda.is_available(): | ||
| torch.cuda.set_rng_state(cuda_rng_state) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,210 @@ | ||
| import contextlib | ||
| import random | ||
| import re | ||
|
|
||
| import pytest | ||
| import torch | ||
| from torchcodec.decoders import FrameBatch, SimpleVideoDecoder | ||
| from torchcodec.samplers import clips_at_random_indices | ||
|
|
||
| from ..utils import assert_tensor_equal, NASA_VIDEO | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("num_indices_between_frames", [1, 5]) | ||
| def test_random_sampler(num_indices_between_frames): | ||
| decoder = SimpleVideoDecoder(NASA_VIDEO.path) | ||
| num_clips = 2 | ||
| num_frames_per_clip = 3 | ||
|
|
||
| clips = clips_at_random_indices( | ||
| decoder, | ||
| num_clips=num_clips, | ||
| num_frames_per_clip=num_frames_per_clip, | ||
| num_indices_between_frames=num_indices_between_frames, | ||
| ) | ||
|
|
||
| assert isinstance(clips, list) | ||
| assert len(clips) == num_clips | ||
| assert all(isinstance(clip, FrameBatch) for clip in clips) | ||
| expected_clip_data_shape = ( | ||
| num_frames_per_clip, | ||
| 3, | ||
| NASA_VIDEO.height, | ||
| NASA_VIDEO.width, | ||
| ) | ||
| assert all(clip.data.shape == expected_clip_data_shape for clip in clips) | ||
|
|
||
| # Check the num_indices_between_frames parameter by asserting that the | ||
| # "time" difference between frames in a clip is the same as the "index" | ||
| # distance. | ||
| avg_distance_between_frames_seconds = torch.concat( | ||
| [clip.pts_seconds.diff() for clip in clips] | ||
| ).mean() | ||
| assert avg_distance_between_frames_seconds == pytest.approx( | ||
| num_indices_between_frames / decoder.metadata.average_fps | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "sampling_range_start, sampling_range_end, assert_all_equal", | ||
| ( | ||
| (10, 11, True), | ||
| (10, 12, False), | ||
| ), | ||
| ) | ||
| def test_random_sampler_range( | ||
| sampling_range_start, sampling_range_end, assert_all_equal | ||
| ): | ||
| # Test the sampling_range_start and sampling_range_end parameters by | ||
| # asserting that all clips are equal if the sampling range is of size 1, | ||
| # and that they are not all equal if the sampling range is of size 2. | ||
|
|
||
| # When size=2 there's still a (small) non-zero probability of sampling the | ||
| # same indices for clip starts, so we hard-code a seed that works. | ||
| torch.manual_seed(0) | ||
|
|
||
| decoder = SimpleVideoDecoder(NASA_VIDEO.path) | ||
|
|
||
| clips = clips_at_random_indices( | ||
| decoder, | ||
| num_clips=10, | ||
| num_frames_per_clip=2, | ||
| sampling_range_start=sampling_range_start, | ||
| sampling_range_end=sampling_range_end, | ||
| ) | ||
|
|
||
| # This context manager is used to ensure that the call to | ||
| # assert_tensor_equal() below either passes (nullcontext) or fails | ||
| # (pytest.raises) | ||
| cm = ( | ||
| contextlib.nullcontext() | ||
| if assert_all_equal | ||
| else pytest.raises(AssertionError, match="Tensor-likes are not") | ||
| ) | ||
| with cm: | ||
NicolasHug marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| for clip in clips: | ||
| assert_tensor_equal(clip.data, clips[0].data) | ||
|
|
||
|
|
||
| def test_random_sampler_range_negative(): | ||
| # Test the passing negative values for sampling_range_start and | ||
| # sampling_range_end is the same as passing `len(decoder) - val` | ||
|
|
||
| decoder = SimpleVideoDecoder(NASA_VIDEO.path) | ||
|
|
||
| clips_1 = clips_at_random_indices( | ||
| decoder, | ||
| num_clips=10, | ||
| num_frames_per_clip=2, | ||
| sampling_range_start=len(decoder) - 100, | ||
| sampling_range_end=len(decoder) - 99, | ||
| ) | ||
|
|
||
| clips_2 = clips_at_random_indices( | ||
| decoder, | ||
| num_clips=10, | ||
| num_frames_per_clip=2, | ||
| sampling_range_start=-100, | ||
| sampling_range_end=-99, | ||
| ) | ||
|
|
||
| # There is only one unique clip in clips_1... | ||
| for clip in clips_1: | ||
| assert_tensor_equal(clip.data, clips_1[0].data) | ||
| # ... and it's the same that's in clips_2 | ||
| for clip in clips_2: | ||
| assert_tensor_equal(clip.data, clips_1[0].data) | ||
|
|
||
|
|
||
| def test_random_sampler_randomness(): | ||
| decoder = SimpleVideoDecoder(NASA_VIDEO.path) | ||
| num_clips = 5 | ||
|
|
||
| builtin_random_state_start = random.getstate() | ||
|
|
||
| torch.manual_seed(0) | ||
| clips_1 = clips_at_random_indices(decoder, num_clips=num_clips) | ||
|
|
||
| # Assert the clip starts aren't sorted, to make sure we haven't messed up | ||
| # the implementation. (This may fail if we're unlucky, but we hard-coded a | ||
| # seed, so it will always pass.) | ||
| clip_starts = [clip.pts_seconds.item() for clip in clips_1] | ||
| assert sorted(clip_starts) != clip_starts | ||
|
|
||
| # Call the same sampler again with the same seed, expect same results | ||
| torch.manual_seed(0) | ||
| clips_2 = clips_at_random_indices(decoder, num_clips=num_clips) | ||
| for clip_1, clip_2 in zip(clips_1, clips_2): | ||
| assert_tensor_equal(clip_1.data, clip_2.data) | ||
| assert_tensor_equal(clip_1.pts_seconds, clip_2.pts_seconds) | ||
| assert_tensor_equal(clip_1.duration_seconds, clip_2.duration_seconds) | ||
|
|
||
| # Call with a different seed, expect different results | ||
| torch.manual_seed(1) | ||
| clips_3 = clips_at_random_indices(decoder, num_clips=num_clips) | ||
| with pytest.raises(AssertionError, match="Tensor-likes are not"): | ||
| assert_tensor_equal(clips_1[0].data, clips_3[0].data) | ||
|
|
||
| # Make sure we didn't alter the builtin Python RNG | ||
| builtin_random_state_end = random.getstate() | ||
| assert builtin_random_state_start == builtin_random_state_end | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if we can fuzz test this to make sure we don't return errors mid-training. Can pytest help with passing in random values to the parameters (with certain constraints like positive-only values, etc. for things like frames_per_clip, etc.) to make sure we are robust to errors?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is the
?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| def test_random_sampler_errors(): | ||
| decoder = SimpleVideoDecoder(NASA_VIDEO.path) | ||
| with pytest.raises( | ||
| ValueError, match=re.escape("num_clips (0) must be strictly positive") | ||
| ): | ||
| clips_at_random_indices(decoder, num_clips=0) | ||
|
|
||
| with pytest.raises( | ||
| ValueError, match=re.escape("num_frames_per_clip (0) must be strictly positive") | ||
| ): | ||
| clips_at_random_indices(decoder, num_frames_per_clip=0) | ||
|
|
||
| with pytest.raises( | ||
| ValueError, | ||
| match=re.escape("num_indices_between_frames (0) must be strictly positive"), | ||
| ): | ||
| clips_at_random_indices(decoder, num_indices_between_frames=0) | ||
|
|
||
| with pytest.raises( | ||
| ValueError, | ||
| match=re.escape("Clip span (1000) is larger than the number of frames"), | ||
| ): | ||
| clips_at_random_indices(decoder, num_frames_per_clip=1000) | ||
|
|
||
| with pytest.raises( | ||
| ValueError, | ||
| match=re.escape("Clip span (1001) is larger than the number of frames"), | ||
| ): | ||
| clips_at_random_indices( | ||
| decoder, num_frames_per_clip=2, num_indices_between_frames=1000 | ||
| ) | ||
|
|
||
| with pytest.raises( | ||
| ValueError, match=re.escape("sampling_range_start (1000) must be smaller than") | ||
| ): | ||
| clips_at_random_indices(decoder, sampling_range_start=1000) | ||
|
|
||
| with pytest.raises( | ||
| ValueError, match=re.escape("sampling_range_start (4) must be smaller than") | ||
| ): | ||
| clips_at_random_indices(decoder, sampling_range_start=4, sampling_range_end=4) | ||
|
|
||
| with pytest.raises( | ||
| ValueError, match=re.escape("sampling_range_start (290) must be smaller than") | ||
| ): | ||
| clips_at_random_indices( | ||
| decoder, sampling_range_start=-100, sampling_range_end=-100 | ||
| ) | ||
|
|
||
| with pytest.raises( | ||
| ValueError, match="We determined that sampling_range_end should" | ||
| ): | ||
| clips_at_random_indices( | ||
| decoder, | ||
| num_frames_per_clip=10, | ||
| sampling_range_start=len(decoder) - 1, | ||
| sampling_range_end=None, | ||
| ) | ||

Uh oh!
There was an error while loading. Please reload this page.