Skip to content

Commit

Permalink
PR: Make JPEG/PNG reading ops return images in CHW format (#2680)
Browse files Browse the repository at this point in the history
* Make JPEG/PNG return images in CHW format

* Use int array
  • Loading branch information
andfoy committed Sep 18, 2020
1 parent c4dcfb0 commit 5e4a9f6
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 2 deletions.
4 changes: 4 additions & 0 deletions test/test_image.py
Expand Up @@ -30,12 +30,14 @@ class ImageTester(unittest.TestCase):
def test_read_jpeg(self):
for img_path in get_images(IMAGE_ROOT, ".jpg"):
img_pil = torch.load(img_path.replace('jpg', 'pth'))
img_pil = img_pil.permute(2, 0, 1)
img_ljpeg = read_jpeg(img_path)
self.assertTrue(img_ljpeg.equal(img_pil))

def test_decode_jpeg(self):
for img_path in get_images(IMAGE_ROOT, ".jpg"):
img_pil = torch.load(img_path.replace('jpg', 'pth'))
img_pil = img_pil.permute(2, 0, 1)
size = os.path.getsize(img_path)
img_ljpeg = decode_jpeg(torch.from_file(img_path, dtype=torch.uint8, size=size))
self.assertTrue(img_ljpeg.equal(img_pil))
Expand Down Expand Up @@ -68,12 +70,14 @@ def test_read_png(self):
# Check across .png
for img_path in get_images(IMAGE_DIR, ".png"):
img_pil = torch.from_numpy(np.array(Image.open(img_path)))
img_pil = img_pil.permute(2, 0, 1)
img_lpng = read_png(img_path)
self.assertTrue(img_lpng.equal(img_pil))

def test_decode_png(self):
for img_path in get_images(IMAGE_DIR, ".png"):
img_pil = torch.from_numpy(np.array(Image.open(img_path)))
img_pil = img_pil.permute(2, 0, 1)
size = os.path.getsize(img_path)
img_lpng = decode_png(torch.from_file(img_path, dtype=torch.uint8, size=size))
self.assertTrue(img_lpng.equal(img_pil))
Expand Down
2 changes: 1 addition & 1 deletion torchvision/csrc/cpu/image/readjpeg_cpu.cpp
Expand Up @@ -137,7 +137,7 @@ torch::Tensor decodeJPEG(const torch::Tensor& data) {

jpeg_finish_decompress(&cinfo);
jpeg_destroy_decompress(&cinfo);
return tensor;
return tensor.permute({2, 0, 1});
}

#endif // JPEG_FOUND
2 changes: 1 addition & 1 deletion torchvision/csrc/cpu/image/readpng_cpu.cpp
Expand Up @@ -79,6 +79,6 @@ torch::Tensor decodePNG(const torch::Tensor& data) {
ptr += bytes;
}
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
return tensor;
return tensor.permute({2, 0, 1});
}
#endif // PNG_FOUND

0 comments on commit 5e4a9f6

Please sign in to comment.