Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for 16 bits png images #4657

Merged
merged 10 commits into from
Oct 21, 2021
Merged

Add support for 16 bits png images #4657

merged 10 commits into from
Oct 21, 2021

Conversation

NicolasHug
Copy link
Member

@NicolasHug NicolasHug commented Oct 19, 2021

Closes #4107
Closes #2218

This PR adds support for 16 bits pngs. Since pytorch doesn't support the uint16 dtype, we return int32 tensors instead (we indicate in the doc that we will be returning uint16 tensors in the future, if pytorch start supporting those).

Among other things, this will enable training RAFT on the Kitti dataset, which currently can only be done by relying on openCV.

PIL support for 16 bits png is a bit limited and buggy, especially for grayscale images (python-pillow/Pillow#3011). PIL also automatically converts the 16bits values to uint8, loosing tons of precision. This makes it hard to test. For this reason I only added test for one RGB image and one RGBA image. According to a few ad-hoc tests, grayscale images are decoded properly (unlike for PIL).
Also, for all 200 Kitti-Flow ground-truth flow images, this code returns the exact same values as the cv2 version.

This code takes about the same time as cv2 to decode a 1567 x 1965 RGBA image. PIL is a lot faster but I assume that this is because they downcast everything to uint8:

torchvision
228 ms ± 3.81 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
PIL
114 µs ± 9.46 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
CV2
234 ms ± 15.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Note: I observe the same relative performance on 8 bits images: torchvision == cv2 >> PIL

@NicolasHug NicolasHug marked this pull request as ready for review October 20, 2021 17:54
@NicolasHug NicolasHug changed the title WIP Add support for 16 bits png images Add support for 16 bits png images Oct 20, 2021
{int64_t(height), int64_t(width), channels},
bit_depth <= 8 ? torch::kU8 : torch::kI32);

if (bit_depth <= 8) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this if block is unchanged and corresponds to the original code. I just renamed ptr into t_ptr, because the other block uses too many pointers for ptr to be explicit enough

@@ -11,6 +11,11 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
}
#else

bool is_little_endian() {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// We're reading a 16bits png, but pytorch doesn't support uint16.
// So we read each row in a 16bits tmp_buffer which we then cast into
// a int32 tensor instead.
if (is_little_endian()) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fmassa I eventually realized that this was a much cleaner and simpler way to handle the endianness. The rest takes care of itself when we cast the uint16 value into a int32_t a few lines below

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome!

Copy link
Member

@fmassa fmassa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a ton for adding support for 16-bit PNGs!

I have one concern about current implementation, otherwise the rest LGTM!

Comment on lines 198 to 199
uint16_t* tmp_buffer =
(uint16_t*)malloc(num_pixels_per_row * sizeof(uint16_t));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This leads to a memory leak in the end of the function.

If you malloc, you need to free after it's used. But you'll need to handle a few corner cases in the freeing size (what if png_read_row fails?).

I think it would be easier to just allocate the buffer via PyTorch torch::empty (or raw data via at::DataPtr via at::getCPUAllocator()->allocate(length);, but I think torch::empty is easier to use, up to you)

@@ -61,7 +61,12 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE
"""
Decodes a PNG image into a 3 dimensional RGB Tensor.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 between 0 and 255.
The values of the output tensor are uint8 between 0 and 255, except for
16-bits pngs which are int32 tensors.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also mention the range for 16-bit pngs, which is from 0-65k?

// We're reading a 16bits png, but pytorch doesn't support uint16.
// So we read each row in a 16bits tmp_buffer which we then cast into
// a int32 tensor instead.
if (is_little_endian()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome!

@NicolasHug NicolasHug mentioned this pull request Oct 21, 2021
12 tasks
@NicolasHug
Copy link
Member Author

NicolasHug commented Oct 21, 2021

Thanks for the review!

Test failures are unrelated: #4683

Good to go @fmassa ?

Copy link
Member

@fmassa fmassa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

auto tmp_buffer_tensor = torch::empty(
{int64_t(num_pixels_per_row * sizeof(uint16_t))}, torch::kU8);
uint16_t* tmp_buffer =
(uint16_t*)tmp_buffer_tensor.accessor<uint8_t, 1>().data();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit because it was already like this before: you can just do tmp_buffer_tensor.data_ptr<uint8_t, 1>()

@fmassa fmassa merged commit e32f543 into pytorch:main Oct 21, 2021
facebook-github-bot pushed a commit that referenced this pull request Oct 26, 2021
Summary:
* WIP

* cleaner code

* Add tests

* Add docs

* Assert dtype

* put back check

* Address comments

Reviewed By: NicolasHug

Differential Revision: D31916334

fbshipit-source-id: 8877266f6e533e8c45c5f202e535944a9a939376

Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
cyyever pushed a commit to cyyever/vision that referenced this pull request Nov 16, 2021
* WIP

* cleaner code

* Add tests

* Add docs

* Assert dtype

* put back check

* Address comments

Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support pngs with more than 8 bits Loading 16bit png images
3 participants