diff --git a/test/test_image.py b/test/test_image.py index ea055b82715..26f72c87588 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -94,6 +94,8 @@ def test_decode_jpeg(img_path, pil_mode, mode): data = read_file(img_path) img_ljpeg = decode_image(data, mode=mode) + assert img_ljpeg.is_contiguous() + # Permit a small variation on pixel values to account for implementation # differences between Pillow and LibJPEG. abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item() @@ -173,6 +175,8 @@ def test_decode_png(img_path, pil_mode, mode): data = read_file(img_path) img_lpng = decode_image(data, mode=mode) + assert img_lpng.is_contiguous() + tol = 0 if pil_mode is None else 1 if PILLOW_VERSION >= (8, 3) and pil_mode == "LA": diff --git a/torchvision/csrc/io/image/cpu/decode_jpeg.cpp b/torchvision/csrc/io/image/cpu/decode_jpeg.cpp index c6e971c3b12..666a3e9e0a1 100644 --- a/torchvision/csrc/io/image/cpu/decode_jpeg.cpp +++ b/torchvision/csrc/io/image/cpu/decode_jpeg.cpp @@ -148,7 +148,7 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) { jpeg_finish_decompress(&cinfo); jpeg_destroy_decompress(&cinfo); - return tensor.permute({2, 0, 1}); + return tensor.permute({2, 0, 1}).contiguous(); } #endif diff --git a/torchvision/csrc/io/image/cpu/decode_png.cpp b/torchvision/csrc/io/image/cpu/decode_png.cpp index 0df55daed68..ba78ae6374f 100644 --- a/torchvision/csrc/io/image/cpu/decode_png.cpp +++ b/torchvision/csrc/io/image/cpu/decode_png.cpp @@ -224,7 +224,7 @@ torch::Tensor decode_png( } } png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - return tensor.permute({2, 0, 1}); + return tensor.permute({2, 0, 1}).contiguous(); } #endif diff --git a/torchvision/io/image.py b/torchvision/io/image.py index f835565016c..21dc7ecab6c 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -59,7 +59,7 @@ def write_file(filename: str, data: torch.Tensor) -> None: def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor: """ - Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor. + Decodes a PNG image into a 3 dimensional RGB or grayscale contiguous Tensor. Optionally converts the image to the desired format. The values of the output tensor are uint8 in [0, 255]. @@ -117,7 +117,7 @@ def decode_jpeg( input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, device: str = "cpu" ) -> torch.Tensor: """ - Decodes a JPEG image into a 3 dimensional RGB or grayscale Tensor. + Decodes a JPEG image into a 3 dimensional RGB or grayscale contiguous Tensor. Optionally converts the image to the desired format. The values of the output tensor are uint8 between 0 and 255. @@ -185,7 +185,7 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75): def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor: """ Detects whether an image is a JPEG or PNG and performs the appropriate - operation to decode the image into a 3 dimensional RGB or grayscale Tensor. + operation to decode the image into a 3 dimensional RGB or grayscale contiguous Tensor. Optionally converts the image to the desired format. The values of the output tensor are uint8 in [0, 255]. @@ -207,7 +207,7 @@ def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor: """ - Reads a JPEG or PNG image into a 3 dimensional RGB or grayscale Tensor. + Reads a JPEG or PNG image into a 3 dimensional RGB or grayscale contiguous Tensor. Optionally converts the image to the desired format. The values of the output tensor are uint8 in [0, 255].