diff --git a/test/test_datasets_samplers.py b/test/test_datasets_samplers.py index 10d8704dbb1..be2fab8e0dd 100644 --- a/test/test_datasets_samplers.py +++ b/test/test_datasets_samplers.py @@ -2,7 +2,7 @@ import sys import os import torch -import unittest +import pytest from torchvision import io from torchvision.datasets.samplers import ( @@ -38,13 +38,13 @@ def get_list_of_videos(num_videos=5, sizes=None, fps=None): yield names -@unittest.skipIf(not io.video._av_available(), "this test requires av") -class Tester(unittest.TestCase): +@pytest.mark.skipif(not io.video._av_available(), reason="this test requires av") +class TestDatasetsSamplers: def test_random_clip_sampler(self): with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list: video_clips = VideoClips(video_list, 5, 5) sampler = RandomClipSampler(video_clips, 3) - self.assertEqual(len(sampler), 3 * 3) + assert len(sampler) == 3 * 3 indices = torch.tensor(list(iter(sampler))) videos = torch.div(indices, 5, rounding_mode='floor') v_idxs, count = torch.unique(videos, return_counts=True) @@ -55,10 +55,10 @@ def test_random_clip_sampler_unequal(self): with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list: video_clips = VideoClips(video_list, 5, 5) sampler = RandomClipSampler(video_clips, 3) - self.assertEqual(len(sampler), 2 + 3 + 3) + assert len(sampler) == 2 + 3 + 3 indices = list(iter(sampler)) - self.assertIn(0, indices) - self.assertIn(1, indices) + assert 0 in indices + assert 1 in indices # remove elements of the first video, to simplify testing indices.remove(0) indices.remove(1) @@ -72,7 +72,7 @@ def test_uniform_clip_sampler(self): with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list: video_clips = VideoClips(video_list, 5, 5) sampler = UniformClipSampler(video_clips, 3) - self.assertEqual(len(sampler), 3 * 3) + assert len(sampler) == 3 * 3 indices = torch.tensor(list(iter(sampler))) videos = torch.div(indices, 5, rounding_mode='floor') v_idxs, count = torch.unique(videos, return_counts=True) @@ -84,7 +84,7 @@ def test_uniform_clip_sampler_insufficient_clips(self): with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list: video_clips = VideoClips(video_list, 5, 5) sampler = UniformClipSampler(video_clips, 3) - self.assertEqual(len(sampler), 3 * 3) + assert len(sampler) == 3 * 3 indices = torch.tensor(list(iter(sampler))) assert_equal(indices, torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11])) @@ -100,7 +100,7 @@ def test_distributed_sampler_and_uniform_clip_sampler(self): group_size=3, ) indices = torch.tensor(list(iter(distributed_sampler_rank0))) - self.assertEqual(len(distributed_sampler_rank0), 6) + assert len(distributed_sampler_rank0) == 6 assert_equal(indices, torch.tensor([0, 2, 4, 10, 12, 14])) distributed_sampler_rank1 = DistributedSampler( @@ -110,9 +110,9 @@ def test_distributed_sampler_and_uniform_clip_sampler(self): group_size=3, ) indices = torch.tensor(list(iter(distributed_sampler_rank1))) - self.assertEqual(len(distributed_sampler_rank1), 6) + assert len(distributed_sampler_rank1) == 6 assert_equal(indices, torch.tensor([5, 7, 9, 0, 2, 4])) if __name__ == '__main__': - unittest.main() + pytest.main([__file__])