diff --git a/references/video_classification/sampler.py b/references/video_classification/sampler.py
index c5a879ffa1a..b92dad013c6 100644
--- a/references/video_classification/sampler.py
+++ b/references/video_classification/sampler.py
@@ -87,3 +87,38 @@ def __iter__(self):
 
     def __len__(self):
         return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)
+
+
+class RandomClipSampler(torch.utils.data.Sampler):
+    """
+    Samples at most `max_clips_per_video` clips for each video randomly
+
+    Arguments:
+        video_clips (VideoClips): video clips to sample from
+        max_clips_per_video (int): maximum number of clips to be sampled per video
+    """
+    def __init__(self, video_clips, max_clips_per_video):
+        if not isinstance(video_clips, torchvision.datasets.video_utils.VideoClips):
+            raise TypeError("Expected video_clips to be an instance of VideoClips, "
+                            "got {}".format(type(video_clips)))
+        self.video_clips = video_clips
+        self.max_clips_per_video = max_clips_per_video
+
+    def __iter__(self):
+        idxs = []
+        s = 0
+        # select at most max_clips_per_video for each video, randomly
+        for c in self.video_clips.clips:
+            length = len(c)
+            size = min(length, self.max_clips_per_video)
+            sampled = torch.randperm(length)[:size] + s
+            s += length
+            idxs.append(sampled)
+        idxs = torch.cat(idxs)
+        # shuffle all clips randomly
+        perm = torch.randperm(len(idxs))
+        idxs = idxs[perm].tolist()
+        return iter(idxs)
+
+    def __len__(self):
+        return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)
diff --git a/references/video_classification/train.py b/references/video_classification/train.py
index 0f04475eade..f532f121e70 100644
--- a/references/video_classification/train.py
+++ b/references/video_classification/train.py
@@ -13,7 +13,7 @@
 from torchvision import transforms
 
 import utils
-from sampler import DistributedSampler, UniformClipSampler
+from sampler import DistributedSampler, UniformClipSampler, RandomClipSampler
 from scheduler import WarmupMultiStepLR
 import transforms as T
 
@@ -184,7 +184,7 @@ def main(args):
        dataset_test.video_clips.compute_clips(args.clip_len, 1, frame_rate=15)
 
    print("Creating data loaders")
-    train_sampler = torchvision.datasets.video_utils.RandomClipSampler(dataset.video_clips, args.clips_per_video)
+    train_sampler = RandomClipSampler(dataset.video_clips, args.clips_per_video)
    test_sampler = UniformClipSampler(dataset_test.video_clips, args.clips_per_video)
    if args.distributed:
        train_sampler = DistributedSampler(train_sampler)
diff --git a/test/test_datasets_video_utils.py b/test/test_datasets_video_utils.py
index 1a0fb3c6681..da8b7ef55bb 100644
--- a/test/test_datasets_video_utils.py
+++ b/test/test_datasets_video_utils.py
@@ -4,7 +4,7 @@
 import unittest
 
 from torchvision import io
-from torchvision.datasets.video_utils import VideoClips, unfold, RandomClipSampler
+from torchvision.datasets.video_utils import VideoClips, unfold
 
 from common_utils import get_tmp_dir
 
@@ -80,10 +80,11 @@ def test_video_clips(self):
                 self.assertEqual(video_idx, v_idx)
                 self.assertEqual(clip_idx, c_idx)
 
+    @unittest.skip("Moved to reference scripts for now")
     def test_video_sampler(self):
         with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
             video_clips = VideoClips(video_list, 5, 5)
-            sampler = RandomClipSampler(video_clips, 3)
+            sampler = RandomClipSampler(video_clips, 3)  # noqa: F821
             self.assertEqual(len(sampler), 3 * 3)
             indices = torch.tensor(list(iter(sampler)))
             videos = indices // 5
@@ -91,10 +92,11 @@ def test_video_sampler(self):
             self.assertTrue(v_idxs.equal(torch.tensor([0, 1, 2])))
             self.assertTrue(count.equal(torch.tensor([3, 3, 3])))
 
+    @unittest.skip("Moved to reference scripts for now")
     def test_video_sampler_unequal(self):
         with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
             video_clips = VideoClips(video_list, 5, 5)
-            sampler = RandomClipSampler(video_clips, 3)
+            sampler = RandomClipSampler(video_clips, 3)  # noqa: F821
             self.assertEqual(len(sampler), 2 + 3 + 3)
             indices = list(iter(sampler))
             self.assertIn(0, indices)
diff --git a/torchvision/datasets/video_utils.py b/torchvision/datasets/video_utils.py
index 1ebec7df1e9..4b81f64ec77 100644
--- a/torchvision/datasets/video_utils.py
+++ b/torchvision/datasets/video_utils.py
@@ -1,7 +1,6 @@
 import bisect
 import math
 import torch
-import torch.utils.data
 from torchvision.io import read_video_timestamps, read_video
 
 from .utils import tqdm
@@ -214,38 +213,3 @@ def get_clip(self, idx):
             info["video_fps"] = self.frame_rate
         assert len(video) == self.num_frames, "{} x {}".format(video.shape, self.num_frames)
         return video, audio, info, video_idx
-
-
-class RandomClipSampler(torch.utils.data.Sampler):
-    """
-    Samples at most `max_video_clips_per_video` clips for each video randomly
-
-    Arguments:
-        video_clips (VideoClips): video clips to sample from
-        max_clips_per_video (int): maximum number of clips to be sampled per video
-    """
-    def __init__(self, video_clips, max_clips_per_video):
-        if not isinstance(video_clips, VideoClips):
-            raise TypeError("Expected video_clips to be an instance of VideoClips, "
-                            "got {}".format(type(video_clips)))
-        self.video_clips = video_clips
-        self.max_clips_per_video = max_clips_per_video
-
-    def __iter__(self):
-        idxs = []
-        s = 0
-        # select at most max_clips_per_video for each video, randomly
-        for c in self.video_clips.clips:
-            length = len(c)
-            size = min(length, self.max_clips_per_video)
-            sampled = torch.randperm(length)[:size] + s
-            s += length
-            idxs.append(sampled)
-        idxs = torch.cat(idxs)
-        # shuffle all clips randomly
-        perm = torch.randperm(len(idxs))
-        idxs = idxs[perm].tolist()
-        return iter(idxs)
-
-    def __len__(self):
-        return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)
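
Usage note: a minimal sketch (not part of the patch) of how the relocated RandomClipSampler is driven, mirroring the wiring in train.py above. The video paths and clip parameters are hypothetical placeholders; any list of decodable video files would do. It assumes the script runs from references/video_classification so that `sampler` resolves to the file modified above.

    from torchvision.datasets.video_utils import VideoClips

    from sampler import RandomClipSampler  # the reference-script copy added above

    # Hypothetical inputs: index clips of 16 frames, taken one frame apart.
    video_paths = ["videos/a.mp4", "videos/b.mp4"]
    video_clips = VideoClips(video_paths, clip_length_in_frames=16, frames_between_clips=1)

    # At most 5 random clips per video; all selected clip indices are then
    # shuffled together across videos.
    train_sampler = RandomClipSampler(video_clips, max_clips_per_video=5)
    assert len(train_sampler) == sum(min(len(c), 5) for c in video_clips.clips)

    # The sampler yields flat clip indices that VideoClips.get_clip() understands,
    # which is why train.py can pass it as `sampler=` to a DataLoader over a
    # clip-indexed video dataset.
    for clip_idx in train_sampler:
        video, audio, info, video_idx = video_clips.get_clip(clip_idx)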