|
1 | 1 | #include "readpng_cpu.h"
|
2 | 2 |
|
3 |
| -// Comment |
4 | 3 | #include <ATen/ATen.h>
|
5 |
| -#include <string> |
6 | 4 |
|
7 | 5 | #if !PNG_FOUND
|
8 |
| -torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) { |
| 6 | +torch::Tensor decodePNG(const torch::Tensor& data, ImageReadMode mode) { |
9 | 7 | TORCH_CHECK(false, "decodePNG: torchvision not compiled with libPNG support");
|
10 | 8 | }
|
11 | 9 | #else
|
12 | 10 | #include <png.h>
|
13 | 11 | #include <setjmp.h>
|
14 | 12 |
|
15 |
| -torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) { |
| 13 | +torch::Tensor decodePNG(const torch::Tensor& data, ImageReadMode mode) { |
16 | 14 | // Check that the input tensor dtype is uint8
|
17 | 15 | TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
|
18 | 16 | // Check that the input tensor is 1-dimensional
|
19 | 17 | TORCH_CHECK(
|
20 | 18 | data.dim() == 1 && data.numel() > 0,
|
21 | 19 | "Expected a non empty 1-dimensional tensor");
|
22 |
| - TORCH_CHECK( |
23 |
| - channels >= 0 && channels <= 4, "Number of channels not supported"); |
24 | 20 |
|
25 | 21 | auto png_ptr =
|
26 | 22 | png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
|
@@ -74,75 +70,85 @@ torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) {
|
74 | 70 | TORCH_CHECK(retval == 1, "Could read image metadata from content.")
|
75 | 71 | }
|
76 | 72 |
|
77 |
| - int current_channels = png_get_channels(png_ptr, info_ptr); |
| 73 | + int channels = png_get_channels(png_ptr, info_ptr); |
78 | 74 |
|
79 |
| - if (channels > 0) { |
| 75 | + if (mode != IMAGE_READ_MODE_UNCHANGED) { |
80 | 76 | // TODO: consider supporting PNG_INFO_tRNS
|
81 | 77 | bool is_palette = (color_type & PNG_COLOR_MASK_PALETTE) != 0;
|
82 | 78 | bool has_color = (color_type & PNG_COLOR_MASK_COLOR) != 0;
|
83 | 79 | bool has_alpha = (color_type & PNG_COLOR_MASK_ALPHA) != 0;
|
84 | 80 |
|
85 |
| - switch (channels) { |
86 |
| - case 1: // Gray |
87 |
| - if (is_palette) { |
88 |
| - png_set_palette_to_rgb(png_ptr); |
89 |
| - has_alpha = true; |
90 |
| - } |
91 |
| - |
92 |
| - if (has_alpha) { |
93 |
| - png_set_strip_alpha(png_ptr); |
94 |
| - } |
95 |
| - |
96 |
| - if (has_color) { |
97 |
| - png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587); |
| 81 | + switch (mode) { |
| 82 | + case IMAGE_READ_MODE_GRAY: |
| 83 | + if (color_type != PNG_COLOR_TYPE_GRAY) { |
| 84 | + if (is_palette) { |
| 85 | + png_set_palette_to_rgb(png_ptr); |
| 86 | + has_alpha = true; |
| 87 | + } |
| 88 | + |
| 89 | + if (has_alpha) { |
| 90 | + png_set_strip_alpha(png_ptr); |
| 91 | + } |
| 92 | + |
| 93 | + if (has_color) { |
| 94 | + png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587); |
| 95 | + } |
| 96 | + channels = 1; |
98 | 97 | }
|
99 | 98 | break;
|
100 |
| - case 2: // Gray + Alpha |
101 |
| - if (is_palette) { |
102 |
| - png_set_palette_to_rgb(png_ptr); |
103 |
| - has_alpha = true; |
104 |
| - } |
105 |
| - |
106 |
| - if (!has_alpha) { |
107 |
| - png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER); |
108 |
| - } |
109 |
| - |
110 |
| - if (has_color) { |
111 |
| - png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587); |
| 99 | + case IMAGE_READ_MODE_GRAY_ALPHA: |
| 100 | + if (color_type != PNG_COLOR_TYPE_GRAY_ALPHA) { |
| 101 | + if (is_palette) { |
| 102 | + png_set_palette_to_rgb(png_ptr); |
| 103 | + has_alpha = true; |
| 104 | + } |
| 105 | + |
| 106 | + if (!has_alpha) { |
| 107 | + png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER); |
| 108 | + } |
| 109 | + |
| 110 | + if (has_color) { |
| 111 | + png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587); |
| 112 | + } |
| 113 | + channels = 2; |
112 | 114 | }
|
113 | 115 | break;
|
114 |
| - case 3: |
115 |
| - if (is_palette) { |
116 |
| - png_set_palette_to_rgb(png_ptr); |
117 |
| - has_alpha = true; |
118 |
| - } else if (!has_color) { |
119 |
| - png_set_gray_to_rgb(png_ptr); |
120 |
| - } |
121 |
| - |
122 |
| - if (has_alpha) { |
123 |
| - png_set_strip_alpha(png_ptr); |
| 116 | + case IMAGE_READ_MODE_RGB: |
| 117 | + if (color_type != PNG_COLOR_TYPE_RGB) { |
| 118 | + if (is_palette) { |
| 119 | + png_set_palette_to_rgb(png_ptr); |
| 120 | + has_alpha = true; |
| 121 | + } else if (!has_color) { |
| 122 | + png_set_gray_to_rgb(png_ptr); |
| 123 | + } |
| 124 | + |
| 125 | + if (has_alpha) { |
| 126 | + png_set_strip_alpha(png_ptr); |
| 127 | + } |
| 128 | + channels = 3; |
124 | 129 | }
|
125 | 130 | break;
|
126 |
| - case 4: |
127 |
| - if (is_palette) { |
128 |
| - png_set_palette_to_rgb(png_ptr); |
129 |
| - has_alpha = true; |
130 |
| - } else if (!has_color) { |
131 |
| - png_set_gray_to_rgb(png_ptr); |
132 |
| - } |
133 |
| - |
134 |
| - if (!has_alpha) { |
135 |
| - png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER); |
| 131 | + case IMAGE_READ_MODE_RGB_ALPHA: |
| 132 | + if (color_type != PNG_COLOR_TYPE_RGB_ALPHA) { |
| 133 | + if (is_palette) { |
| 134 | + png_set_palette_to_rgb(png_ptr); |
| 135 | + has_alpha = true; |
| 136 | + } else if (!has_color) { |
| 137 | + png_set_gray_to_rgb(png_ptr); |
| 138 | + } |
| 139 | + |
| 140 | + if (!has_alpha) { |
| 141 | + png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER); |
| 142 | + } |
| 143 | + channels = 4; |
136 | 144 | }
|
137 | 145 | break;
|
138 | 146 | default:
|
139 | 147 | png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
|
140 |
| - TORCH_CHECK(false, "Invalid number of output channels."); |
| 148 | + TORCH_CHECK(false, "Provided mode not supported"); |
141 | 149 | }
|
142 | 150 |
|
143 | 151 | png_read_update_info(png_ptr, info_ptr);
|
144 |
| - } else { |
145 |
| - channels = current_channels; |
146 | 152 | }
|
147 | 153 |
|
148 | 154 | auto tensor =
|
|
0 commit comments