2 changes: 1 addition & 1 deletion src/torchcodec/__init__.py
@@ -4,6 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

-from . import decoders  # noqa
+from . import decoders, samplers  # noqa

__version__ = "0.0.2.dev"
1 change: 1 addition & 0 deletions src/torchcodec/samplers/__init__.py
@@ -0,0 +1 @@
from ._implem import clips_at_random_indices
142 changes: 142 additions & 0 deletions src/torchcodec/samplers/_implem.py
@@ -0,0 +1,142 @@
import random
Contributor:
random nit: why not call it implementation.py? It's not user-facing, so we can have long descriptive names here

Author:
I personally don't find implementation substantially more descriptive than implem.
I'm OK to address, but only at the very end before merging, because the renaming would cause GitHub to mark ongoing comments as "outdated".

from typing import List, Optional

import torch

from torchcodec.decoders import FrameBatch, SimpleVideoDecoder


Author:
Below I factored out some validation logic into functions. This isn't strictly necessary for this PR, but we'll need the exact same logic for the clips_at_regular_indices() sampler, so I prefer extracting it now to minimize the diff of the clips_at_regular_indices() PR.

def _validate_params(
    *, decoder, num_clips, num_frames_per_clip, num_indices_between_frames
):
    if len(decoder) < 1:
        raise ValueError(
            f"Decoder must have at least one frame, found {len(decoder)} frames."
        )

    if num_clips <= 0:
        raise ValueError(f"num_clips ({num_clips}) must be strictly positive")
    if num_frames_per_clip <= 0:
        raise ValueError(
            f"num_frames_per_clip ({num_frames_per_clip}) must be strictly positive"
        )
    if num_indices_between_frames <= 0:
        raise ValueError(
            f"num_indices_between_frames ({num_indices_between_frames}) must be strictly positive"
        )


def _validate_sampling_range(
    *, sampling_range_start, sampling_range_end, num_frames, clip_span
):
    if sampling_range_start < 0:
        sampling_range_start = num_frames + sampling_range_start
Contributor:
what if it's still negative after this?

Author:
It would result in an error because sampling_range_end <= sampling_range_start.

If we don't error, we could either:

  • decide that it's equivalent to 1 (so that it can work when sampling_range_start is 0)
  • keep wrapping around (modulo).

Neither sounds like a great intuitive option, so it might be best to error, similar to what we decided in https://github.com/pytorch/torchcodec/pull/221/files#r1781347650
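For concreteness, a tiny standalone sketch of the wrap-around arithmetic under discussion (the 100-frame length is hypothetical):

num_frames = 100  # hypothetical video length

for start in (-10, -200):
    resolved = num_frames + start if start < 0 else start
    print(start, "->", resolved)
# -10 -> 90 (a valid index)
# -200 -> -100 (still negative after wrap-around: the case being discussed)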


    if sampling_range_start >= num_frames:
        raise ValueError(
            f"sampling_range_start ({sampling_range_start}) must be smaller than "
            f"the number of frames ({num_frames})."
        )

    if sampling_range_end is None:
        sampling_range_end = num_frames - clip_span + 1
        if sampling_range_start >= sampling_range_end:
            raise ValueError(
                f"We determined that sampling_range_end should be {sampling_range_end}, "
                "but it is smaller than or equal to sampling_range_start "
                f"({sampling_range_start})."
            )
    else:
        if sampling_range_end < 0:
            # Support negative values so that -1 means last frame.
            sampling_range_end = num_frames + sampling_range_end
        sampling_range_end = min(sampling_range_end, num_frames)
        if sampling_range_start >= sampling_range_end:
            raise ValueError(
                f"sampling_range_start ({sampling_range_start}) must be smaller than "
                f"sampling_range_end ({sampling_range_end})."
            )

    return sampling_range_start, sampling_range_end


def _get_clip_span(*, num_indices_between_frames, num_frames_per_clip):
    """Return the span of a clip, i.e. the number of frames (or indices)
    between the first and last frame in the clip, both included.

    This isn't the same as the number of frames in a clip!
    Example: f means a frame in the clip, x means a frame excluded from the clip
    num_frames_per_clip = 4
    num_indices_between_frames = 1, clip = ffff      , span = 4
    num_indices_between_frames = 2, clip = fxfxfxf   , span = 7
    num_indices_between_frames = 3, clip = fxxfxxfxxf, span = 10
    """
    return num_indices_between_frames * (num_frames_per_clip - 1) + 1
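As a quick sanity check, the docstring examples above line up with the formula:

assert _get_clip_span(num_indices_between_frames=1, num_frames_per_clip=4) == 4
assert _get_clip_span(num_indices_between_frames=2, num_frames_per_clip=4) == 7
assert _get_clip_span(num_indices_between_frames=3, num_frames_per_clip=4) == 10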


def clips_at_random_indices(
    decoder: SimpleVideoDecoder,
    *,
    num_clips: int = 1,
    num_frames_per_clip: int = 1,
    num_indices_between_frames: int = 1,
    sampling_range_start: int = 0,
    sampling_range_end: Optional[int] = None,  # interval is [start, end).
) -> List[FrameBatch]:
Contributor:
Did we agree to return a List here? I thought some users had a need to get a single tensor back? Would those users just stack it manually? Note that stacking could be an expensive operation.

Author:
When we discussed this, we decided to "let the implementation decide".
I do agree that stacked tensors make more sense in general, but we also don't want to stack on behalf of the user as it may introduce an unnecessary copy.
The current implementation leads to a list, but it's possible that by implementing the "better frame shuffling" strategy (left as a TODO at the bottom of this function), we'll end up with a stack. This is still TBD.

Contributor:
Decoding frames to a single tensor doesn't involve a copy -- and that can be done for max speed if we want.

Stacking them after the fact is inefficient.
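For reference, a user who does need a single tensor today could stack manually after sampling, at the cost of the extra copy discussed above; a minimal sketch, assuming decoder is a SimpleVideoDecoder and all clips share the same shape:

import torch

clips = clips_at_random_indices(decoder, num_clips=4, num_frames_per_clip=8)
# torch.stack copies: result shape is (num_clips, num_frames_per_clip, C, H, W).
batch = torch.stack([clip.data for clip in clips])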


    _validate_params(
        decoder=decoder,
        num_clips=num_clips,
        num_frames_per_clip=num_frames_per_clip,
        num_indices_between_frames=num_indices_between_frames,
    )

    clip_span = _get_clip_span(
        num_indices_between_frames=num_indices_between_frames,
        num_frames_per_clip=num_frames_per_clip,
    )

    # TODO: We should probably not error.
    if clip_span > len(decoder):
        raise ValueError(
            f"Clip span ({clip_span}) is larger than the number of frames ({len(decoder)})"
        )

    sampling_range_start, sampling_range_end = _validate_sampling_range(
        sampling_range_start=sampling_range_start,
        sampling_range_end=sampling_range_end,
        num_frames=len(decoder),
        clip_span=clip_span,
    )

    clip_start_indices = torch.randint(
        low=sampling_range_start, high=sampling_range_end, size=(num_clips,)
    )

Contributor (@scotts, Sep 24, 2024):
(I'm putting the comment here so that we can still read the logic without a giant textbox in the middle of it.)

I'm not sure if we want to compute and then use last_clip_start_index. From a probability perspective, I think we're actually making it much less likely that the frames in the range [last_clip_start_index, len(decoder)) are ever returned to the user. For all other frames, consider that they could be the start of a clip, or included in a clip. For the last set of frames, they're only included if we happen to select last_clip_start_index as a clip start.

An alternative approach is that we don't do this, and just set low=0, high=len(decoder). Then we let our clip-too-small policy from above handle the case when the clip start is in the range [last_clip_start_index, len(decoder)).

Both options result in some quirky behavior. What is currently implemented makes it less likely (I think) that we'll ever return frames in the range [last_clip_start_index, len(decoder)). My suggestion above means that selecting those frames is just as likely as the frames in the rest of the video, but we'll always have some duplicates - and maybe biasing the earlier part of the video is better than having duplicates. I'm not sure, sampling is actually a hard problem. ¯\_(ツ)_/¯

Contributor:
I think uniformly sampling the clip_starts in the range where they are valid matches my own intuition.

Duplicate frames may not be good for training.

If @scotts has a strong opinion, maybe we can expose an option but the current behavior seems good to me as the default.

Author:
After chatting offline we concluded the best way to handle this was to let the user choose the sampling range, similarly to the "equaly_spaced_in_time" sampler.

    # We want to avoid seeking backwards, so we sort the clip start indices
    # before decoding the frames, and then re-shuffle the clips afterwards.
    # Backward seeks may still happen if there are overlapping clips, i.e. if a
    # clip ends after the next one starts.
    # TODO: We should use a different strategy to avoid backward seeks:
    # - flatten all frames indices, irrespective of their clip
    # - sort the indices and dedup
    # - decode all frames in index order
    # - re-arrange the frames back into their original clips
    clip_start_indices = torch.sort(clip_start_indices).values
    clips = [
        decoder.get_frames_at(
            start=clip_start_index,
            stop=clip_start_index + clip_span,
            step=num_indices_between_frames,
        )
        for clip_start_index in clip_start_indices
    ]

    # This is an ugly way to shuffle the clips using the pytorch RNG *without*
    # affecting the python builtin RNG.
    builtin_random_state = random.getstate()
    random.seed(torch.randint(0, 2**32, (1,)).item())
    random.shuffle(clips)
    random.setstate(builtin_random_state)
Comment on lines +137 to +140
Contributor:
I wonder if there is a context manager for this? I couldn't find one myself, but maybe there is a way using the with statement here that auto saves/restores?

Author (@NicolasHug, Oct 1, 2024):
That's typically the kind of scenario where a CM can be useful indeed. I thought about writing one, but decided against it considering this logic is just 2 lines of code (the CM would be more), and we'll remove it soon when we address the TODO.
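For reference, a minimal sketch of what such a save/restore context manager could look like (_preserve_builtin_rng is a hypothetical name, not part of this PR):

import contextlib
import random

@contextlib.contextmanager
def _preserve_builtin_rng():
    # Save the builtin RNG state on entry and restore it on exit, so nothing
    # inside the block can leak into the global stream.
    state = random.getstate()
    try:
        yield
    finally:
        random.setstate(state)

# The shuffle above would then read:
# with _preserve_builtin_rng():
#     random.seed(torch.randint(0, 2**32, (1,)).item())
#     random.shuffle(clips)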


Contributor:
Why is it necessary that we don't affect the random module's state? I assume we're trying not to affect the random values seen by models during actual training, but why do we want to do that?

Author (@NicolasHug, Oct 1, 2024):
Users may want to have strict control over the RNG stream of the builtin random module, and we don't want to alter that in any way. Think of a user running different experiments with random.seed(1) and random.seed(2). We'd be altering the RNG streams of these executions to something completely different, rendering their experiment invalid without them knowing.

In general, a library hard-coding a seed for a global RNG stream is a big no-no (whether it's the Python RNG, pytorch RNG, numpy, ...). If a library hard-codes a seed, it should be for a local, non-global stream.

Contributor:
Right, I'm on board with not hardcoding a seed. I guess what I'm confused about is why we can't just call random.shuffle(clips) on its own. Why do we need to set a seed before shuffling? (Without that need, I assume it's okay to consume values in the RNG stream.)

Author:
We're a pytorch library, so in general we want our RNG stream to come from pytorch, not from other RNGs.
In other words we want torch.manual_seed(123) to affect the RNG of the sampler.

Contributor:
Ohhhh, got it, so we're porting over torch's RNG state to random. Since your last response was the aha moment for me, let's say something to that effect in a comment. :)

    return clips
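For reference, a rough sketch of the TODO strategy above (flatten, sort, dedup, decode once in index order, re-assemble); _decode_clips_sorted is hypothetical and get_frame_at is assumed to exist on the decoder:

def _decode_clips_sorted(decoder, clip_start_indices, clip_span, step):
    # Flatten all frame indices, irrespective of their clip.
    all_indices = [
        start + offset
        for start in clip_start_indices.tolist()
        for offset in range(0, clip_span, step)
    ]
    # Sort and dedup so each frame is decoded exactly once, in increasing
    # index order, which avoids backward seeks entirely.
    frames = {index: decoder.get_frame_at(index) for index in sorted(set(all_indices))}
    # Re-arrange the decoded frames back into their original clips.
    return [
        [frames[start + offset] for offset in range(0, clip_span, step)]
        for start in clip_start_indices.tolist()
    ]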
22 changes: 22 additions & 0 deletions test/conftest.py
@@ -0,0 +1,22 @@
import random

import pytest
import torch


@pytest.fixture(autouse=True)
def prevent_leaking_rng():
    # Prevent each test from leaking the rng to all other tests when they call
    # torch.manual_seed() or random.seed().

    torch_rng_state = torch.get_rng_state()
    builtin_rng_state = random.getstate()
    if torch.cuda.is_available():
        cuda_rng_state = torch.cuda.get_rng_state()

    yield

    torch.set_rng_state(torch_rng_state)
    random.setstate(builtin_rng_state)
    if torch.cuda.is_available():
        torch.cuda.set_rng_state(cuda_rng_state)
210 changes: 210 additions & 0 deletions test/samplers/test_samplers.py
@@ -0,0 +1,210 @@
import contextlib
import random
import re

import pytest
import torch
from torchcodec.decoders import FrameBatch, SimpleVideoDecoder
from torchcodec.samplers import clips_at_random_indices

from ..utils import assert_tensor_equal, NASA_VIDEO


@pytest.mark.parametrize("num_indices_between_frames", [1, 5])
def test_random_sampler(num_indices_between_frames):
    decoder = SimpleVideoDecoder(NASA_VIDEO.path)
    num_clips = 2
    num_frames_per_clip = 3

    clips = clips_at_random_indices(
        decoder,
        num_clips=num_clips,
        num_frames_per_clip=num_frames_per_clip,
        num_indices_between_frames=num_indices_between_frames,
    )

    assert isinstance(clips, list)
    assert len(clips) == num_clips
    assert all(isinstance(clip, FrameBatch) for clip in clips)
    expected_clip_data_shape = (
        num_frames_per_clip,
        3,
        NASA_VIDEO.height,
        NASA_VIDEO.width,
    )
    assert all(clip.data.shape == expected_clip_data_shape for clip in clips)

    # Check the num_indices_between_frames parameter by asserting that the
    # "time" difference between frames in a clip is the same as the "index"
    # distance.
    avg_distance_between_frames_seconds = torch.concat(
        [clip.pts_seconds.diff() for clip in clips]
    ).mean()
    assert avg_distance_between_frames_seconds == pytest.approx(
        num_indices_between_frames / decoder.metadata.average_fps
    )


@pytest.mark.parametrize(
    "sampling_range_start, sampling_range_end, assert_all_equal",
    (
        (10, 11, True),
        (10, 12, False),
    ),
)
def test_random_sampler_range(
    sampling_range_start, sampling_range_end, assert_all_equal
):
    # Test the sampling_range_start and sampling_range_end parameters by
    # asserting that all clips are equal if the sampling range is of size 1,
    # and that they are not all equal if the sampling range is of size 2.

    # When size=2 there's still a (small) non-zero probability of sampling the
    # same indices for clip starts, so we hard-code a seed that works.
    torch.manual_seed(0)

    decoder = SimpleVideoDecoder(NASA_VIDEO.path)

    clips = clips_at_random_indices(
        decoder,
        num_clips=10,
        num_frames_per_clip=2,
        sampling_range_start=sampling_range_start,
        sampling_range_end=sampling_range_end,
    )

    # This context manager is used to ensure that the call to
    # assert_tensor_equal() below either passes (nullcontext) or fails
    # (pytest.raises).
    cm = (
        contextlib.nullcontext()
        if assert_all_equal
        else pytest.raises(AssertionError, match="Tensor-likes are not")
    )
    with cm:
        for clip in clips:
            assert_tensor_equal(clip.data, clips[0].data)


def test_random_sampler_range_negative():
    # Test that passing negative values for sampling_range_start and
    # sampling_range_end is the same as passing `len(decoder) - val`.

    decoder = SimpleVideoDecoder(NASA_VIDEO.path)

    clips_1 = clips_at_random_indices(
        decoder,
        num_clips=10,
        num_frames_per_clip=2,
        sampling_range_start=len(decoder) - 100,
        sampling_range_end=len(decoder) - 99,
    )

    clips_2 = clips_at_random_indices(
        decoder,
        num_clips=10,
        num_frames_per_clip=2,
        sampling_range_start=-100,
        sampling_range_end=-99,
    )

    # There is only one unique clip in clips_1...
    for clip in clips_1:
        assert_tensor_equal(clip.data, clips_1[0].data)
    # ... and it's the same that's in clips_2
    for clip in clips_2:
        assert_tensor_equal(clip.data, clips_1[0].data)


def test_random_sampler_randomness():
    decoder = SimpleVideoDecoder(NASA_VIDEO.path)
    num_clips = 5

    builtin_random_state_start = random.getstate()

    torch.manual_seed(0)
    clips_1 = clips_at_random_indices(decoder, num_clips=num_clips)

    # Assert the clip starts aren't sorted, to make sure we haven't messed up
    # the implementation. (This may fail if we're unlucky, but we hard-coded a
    # seed, so it will always pass.)
    clip_starts = [clip.pts_seconds.item() for clip in clips_1]
    assert sorted(clip_starts) != clip_starts

    # Call the same sampler again with the same seed, expect same results
    torch.manual_seed(0)
    clips_2 = clips_at_random_indices(decoder, num_clips=num_clips)
    for clip_1, clip_2 in zip(clips_1, clips_2):
        assert_tensor_equal(clip_1.data, clip_2.data)
        assert_tensor_equal(clip_1.pts_seconds, clip_2.pts_seconds)
        assert_tensor_equal(clip_1.duration_seconds, clip_2.duration_seconds)

    # Call with a different seed, expect different results
    torch.manual_seed(1)
    clips_3 = clips_at_random_indices(decoder, num_clips=num_clips)
    with pytest.raises(AssertionError, match="Tensor-likes are not"):
        assert_tensor_equal(clips_1[0].data, clips_3[0].data)

    # Make sure we didn't alter the builtin Python RNG
    builtin_random_state_end = random.getstate()
    assert builtin_random_state_start == builtin_random_state_end

Contributor:
I wonder if we can fuzz test this to make sure we don't return errors mid-training.

Can pytest help with passing in random values to the parameters (with certain constraints like positive-only values, etc. for things like frames_per_clip, etc.) to make sure we are robust to errors?

Author:
There is the hypothesis package that I think does what you're suggesting, but I'm not sure how useful it would be in this specific instance. What do you have in mind with "make sure we don't return errors mid-training"?

Contributor:
I was requesting fuzz testing so we can be robust to errors and don't return errors in the middle of training (which is what your original philosophy was).

Now if we are switching to an errors mindset, you do seem to have tests for exact errors, so you can resolve this comment.
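For reference, a property-based test along these lines might look like the following with the hypothesis package (not part of this PR; the bounds are illustrative and chosen so the clip span always fits in the test video):

from hypothesis import given, strategies as st

@given(
    num_clips=st.integers(min_value=1, max_value=10),
    num_frames_per_clip=st.integers(min_value=1, max_value=10),
    num_indices_between_frames=st.integers(min_value=1, max_value=5),
)
def test_random_sampler_fuzz(num_clips, num_frames_per_clip, num_indices_between_frames):
    # Any strictly positive parameters whose clip span fits within the video
    # should sample without raising.
    decoder = SimpleVideoDecoder(NASA_VIDEO.path)
    clips = clips_at_random_indices(
        decoder,
        num_clips=num_clips,
        num_frames_per_clip=num_frames_per_clip,
        num_indices_between_frames=num_indices_between_frames,
    )
    assert len(clips) == num_clips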


def test_random_sampler_errors():
    decoder = SimpleVideoDecoder(NASA_VIDEO.path)
    with pytest.raises(
        ValueError, match=re.escape("num_clips (0) must be strictly positive")
    ):
        clips_at_random_indices(decoder, num_clips=0)

    with pytest.raises(
        ValueError, match=re.escape("num_frames_per_clip (0) must be strictly positive")
    ):
        clips_at_random_indices(decoder, num_frames_per_clip=0)

    with pytest.raises(
        ValueError,
        match=re.escape("num_indices_between_frames (0) must be strictly positive"),
    ):
        clips_at_random_indices(decoder, num_indices_between_frames=0)

    with pytest.raises(
        ValueError,
        match=re.escape("Clip span (1000) is larger than the number of frames"),
    ):
        clips_at_random_indices(decoder, num_frames_per_clip=1000)

    with pytest.raises(
        ValueError,
        match=re.escape("Clip span (1001) is larger than the number of frames"),
    ):
        clips_at_random_indices(
            decoder, num_frames_per_clip=2, num_indices_between_frames=1000
        )

    with pytest.raises(
        ValueError, match=re.escape("sampling_range_start (1000) must be smaller than")
    ):
        clips_at_random_indices(decoder, sampling_range_start=1000)

    with pytest.raises(
        ValueError, match=re.escape("sampling_range_start (4) must be smaller than")
    ):
        clips_at_random_indices(decoder, sampling_range_start=4, sampling_range_end=4)

    with pytest.raises(
        ValueError, match=re.escape("sampling_range_start (290) must be smaller than")
    ):
        clips_at_random_indices(
            decoder, sampling_range_start=-100, sampling_range_end=-100
        )

    with pytest.raises(
        ValueError, match="We determined that sampling_range_end should"
    ):
        clips_at_random_indices(
            decoder,
            num_frames_per_clip=10,
            sampling_range_start=len(decoder) - 1,
            sampling_range_end=None,
        )