Skip to content

Commit

Permalink
Bugfix: Ensure self.transform is called on frames within VideoFolderD…
Browse files Browse the repository at this point in the history
…ataset
  • Loading branch information
willprice committed Jan 10, 2019
1 parent d489b87 commit d45e7e9
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/torchvideo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ def __getitem__(
video_length = self.video_lengths[index]
frames_idx = self.sampler.sample(video_length)
frames = self._load_frames(frames_idx, video_file)
if self.transform is not None:
frames = self.transform(frames)

if self.labels is not None:
return frames, self.labels[index]
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy
import os
from pathlib import Path
from unittest.mock import Mock

import pytest
from pyfakefs.fake_filesystem import FakeFilesystem
Expand All @@ -11,6 +13,7 @@
DummyLabelSet,
LambdaLabelSet,
)
from torchvideo.internal.utils import frame_idx_to_list


@pytest.fixture
Expand Down Expand Up @@ -107,6 +110,25 @@ def test_labels_are_accessible(self, dataset_dir, fs, mock_frame_count):
assert len(dataset.labels) == video_count
assert all([label == i for i, label in enumerate(dataset.labels)])

def test_transform_is_called_if_provided(self, dataset_dir, fs, monkeypatch):
def _load_mock_frames(self, frames_idx, video_file):
frames_count = len(frame_idx_to_list(frames_idx))
return numpy.zeros((frames_count, 10, 20, 3))

monkeypatch.setattr(
torchvideo.datasets.VideoFolderDataset, "_load_frames", _load_mock_frames
)
video_count = 10
self.make_video_files(dataset_dir, fs, video_count)
mock_transform = Mock(side_effect=lambda frames: frames)
dataset = VideoFolderDataset(
dataset_dir, transform=mock_transform, frame_counter=lambda p: 20
)

frames = dataset[0]

mock_transform.assert_called_once_with(frames)

@staticmethod
def make_video_files(dataset_dir, fs, video_count):
for i in range(0, video_count):
Expand Down

0 comments on commit d45e7e9

Please sign in to comment.