From d45e7e9bd336c774427549b2e9c2f34794d3e4d1 Mon Sep 17 00:00:00 2001 From: Will Price Date: Thu, 10 Jan 2019 14:04:11 +0000 Subject: [PATCH] Bugfix: Ensure self.transform is called on frames within VideoFolderDataset --- src/torchvideo/datasets/__init__.py | 2 ++ tests/unit/test_dataset.py | 22 ++++++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/src/torchvideo/datasets/__init__.py b/src/torchvideo/datasets/__init__.py index c316114..05fa189 100644 --- a/src/torchvideo/datasets/__init__.py +++ b/src/torchvideo/datasets/__init__.py @@ -254,6 +254,8 @@ def __getitem__( video_length = self.video_lengths[index] frames_idx = self.sampler.sample(video_length) frames = self._load_frames(frames_idx, video_file) + if self.transform is not None: + frames = self.transform(frames) if self.labels is not None: return frames, self.labels[index] diff --git a/tests/unit/test_dataset.py b/tests/unit/test_dataset.py index cf6ec6e..3cbef01 100644 --- a/tests/unit/test_dataset.py +++ b/tests/unit/test_dataset.py @@ -1,5 +1,7 @@ +import numpy import os from pathlib import Path +from unittest.mock import Mock import pytest from pyfakefs.fake_filesystem import FakeFilesystem @@ -11,6 +13,7 @@ DummyLabelSet, LambdaLabelSet, ) +from torchvideo.internal.utils import frame_idx_to_list @pytest.fixture @@ -107,6 +110,25 @@ def test_labels_are_accessible(self, dataset_dir, fs, mock_frame_count): assert len(dataset.labels) == video_count assert all([label == i for i, label in enumerate(dataset.labels)]) + def test_transform_is_called_if_provided(self, dataset_dir, fs, monkeypatch): + def _load_mock_frames(self, frames_idx, video_file): + frames_count = len(frame_idx_to_list(frames_idx)) + return numpy.zeros((frames_count, 10, 20, 3)) + + monkeypatch.setattr( + torchvideo.datasets.VideoFolderDataset, "_load_frames", _load_mock_frames + ) + video_count = 10 + self.make_video_files(dataset_dir, fs, video_count) + mock_transform = Mock(side_effect=lambda frames: frames) + dataset = VideoFolderDataset( + dataset_dir, transform=mock_transform, frame_counter=lambda p: 20 + ) + + frames = dataset[0] + + mock_transform.assert_called_once_with(frames) + @staticmethod def make_video_files(dataset_dir, fs, video_count): for i in range(0, video_count):