diff --git a/test/test_image.py b/test/test_image.py index e7e5b8b197d..2e427af26af 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -326,7 +326,8 @@ def test_decode_jpeg_cuda(mode, img_path, scripted): @pytest.mark.parametrize('cuda_device', ('cuda', 'cuda:0', torch.device('cuda'))) def test_decode_jpeg_cuda_device_param(cuda_device): """Make sure we can pass a string or a torch.device as device param""" - data = read_file(next(get_images(IMAGE_ROOT, ".jpg"))) + path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if 'cmyk' not in path) + data = read_file(path) decode_jpeg(data, device=cuda_device)