Skip to content

Commit

Permalink
Add webp decoder (#8527)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Aug 15, 2024
1 parent edb1c33 commit be7cdf1
Show file tree
Hide file tree
Showing 18 changed files with 212 additions and 42 deletions.
3 changes: 3 additions & 0 deletions .github/scripts/export_IS_M1_CONDA_BUILD_JOB.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/sh

# Flags the current CI job as the M1 conda build job.
# test/smoke_test.py reads this variable: when it equals "1", a build that is
# not compiled against libjpeg-turbo only prints a message instead of raising.
export IS_M1_CONDA_BUILD_JOB=1
1 change: 1 addition & 0 deletions .github/scripts/setup-env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ conda create \
python="${PYTHON_VERSION}" pip \
ninja cmake \
libpng \
libwebp \
'ffmpeg<4.3'
conda activate ci
conda install --quiet --yes libjpeg-turbo -c pytorch
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/build-conda-m1.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ jobs:
test-infra-repository: pytorch/test-infra
test-infra-ref: main
build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
env-var-script: ./.github/scripts/export_IS_M1_CONDA_BUILD_JOB.sh
pre-script: ${{ matrix.pre-script }}
post-script: ${{ matrix.post-script }}
package-name: ${{ matrix.package-name }}
Expand Down
17 changes: 17 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ option(WITH_CUDA "Enable CUDA support" OFF)
option(WITH_MPS "Enable MPS support" OFF)
option(WITH_PNG "Enable features requiring LibPNG." ON)
option(WITH_JPEG "Enable features requiring LibJPEG." ON)
# Libwebp is disabled by default, which means enabling it from cmake is largely
# untested. Since building from cmake is a very low priority anyway, this is OK.
# If you're a user and you need this, please open an issue (and a PR!).
option(WITH_WEBP "Enable features requiring LibWEBP." OFF)

if(WITH_CUDA)
enable_language(CUDA)
Expand All @@ -32,6 +36,11 @@ if (WITH_JPEG)
find_package(JPEG REQUIRED)
endif()

# When WebP support is requested, define WEBP_FOUND for every target in this
# directory (the C++ sources gate the libwebp code path on `#if WEBP_FOUND`)
# and require the WebP package to be discoverable.
# NOTE(review): libwebp presumably installs its package config under the name
# "WebP"; confirm that find_package(WEBP ...) resolves on case-sensitive
# file systems and that it populates WEBP_LIBRARIES / WEBP_INCLUDE_DIRS.
if (WITH_WEBP)
add_definitions(-DWEBP_FOUND)
find_package(WEBP REQUIRED)
endif()

function(CUDA_CONVERT_FLAGS EXISTING_TARGET)
get_property(old_flags TARGET ${EXISTING_TARGET} PROPERTY INTERFACE_COMPILE_OPTIONS)
if(NOT "${old_flags}" STREQUAL "")
Expand Down Expand Up @@ -104,6 +113,10 @@ if (WITH_JPEG)
target_link_libraries(${PROJECT_NAME} PRIVATE ${JPEG_LIBRARIES})
endif()

if (WITH_WEBP)
target_link_libraries(${PROJECT_NAME} PRIVATE ${WEBP_LIBRARIES})
endif()

set_target_properties(${PROJECT_NAME} PROPERTIES
EXPORT_NAME TorchVision
INSTALL_RPATH ${TORCH_INSTALL_PREFIX}/lib)
Expand All @@ -118,6 +131,10 @@ if (WITH_JPEG)
include_directories(${JPEG_INCLUDE_DIRS})
endif()

if (WITH_WEBP)
include_directories(${WEBP_INCLUDE_DIRS})
endif()

set(TORCHVISION_CMAKECONFIG_INSTALL_DIR "share/cmake/TorchVision" CACHE STRING "install path for TorchVisionConfig.cmake")

configure_package_config_file(cmake/TorchVisionConfig.cmake.in
Expand Down
6 changes: 6 additions & 0 deletions docs/source/io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ videos.
Images
------

Torchvision currently supports decoding JPEG, PNG, WEBP and GIF images. JPEG
decoding can also be done on CUDA GPUs.

For encoding, JPEG (CPU and CUDA) and PNG are supported.

.. autosummary::
:toctree: generated/
:template: function.rst
Expand All @@ -20,6 +25,7 @@ Images
decode_jpeg
write_jpeg
decode_gif
decode_webp
encode_png
decode_png
write_png
Expand Down
9 changes: 6 additions & 3 deletions packaging/pre_build_script.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/bin/bash

if [[ "$(uname)" == Darwin ]]; then
# Uninstall Conflicting jpeg brew formulae
jpeg_packages=$(brew list | grep jpeg)
Expand All @@ -12,8 +13,10 @@ if [[ "$(uname)" == Darwin ]]; then
fi

if [[ "$(uname)" == Darwin || "$OSTYPE" == "msys" ]]; then
# Install libpng from Anaconda (defaults)
conda install libpng -yq
conda install libpng libwebp -yq
# Installing webp also installs a non-turbo jpeg, so we uninstall jpeg stuff
# before re-installing them
conda uninstall libjpeg-turbo libjpeg -y
conda install -yq ffmpeg=4.2 libjpeg-turbo -c pytorch

# Copy binaries to be included in the wheel distribution
Expand All @@ -29,7 +32,7 @@ else
conda install -yq ffmpeg=4.2 libjpeg-turbo -c pytorch-nightly
fi

yum install -y libjpeg-turbo-devel freetype gnutls
yum install -y libjpeg-turbo-devel libwebp-devel freetype gnutls
pip install auditwheel
fi

Expand Down
2 changes: 2 additions & 0 deletions packaging/torchvision/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ requirements:
- {{ compiler('c') }} # [win]
- libpng
- libjpeg-turbo
- libwebp
- ffmpeg >=4.2.2, <5.0.0 # [linux]

host:
Expand All @@ -28,6 +29,7 @@ requirements:
- libpng
- ffmpeg >=4.2.2, <5.0.0 # [linux]
- libjpeg-turbo
- libwebp
- pillow >=5.3.0, !=8.3.*
- pytorch-mutex 1.0 {{ build_variant }} # [not osx ]
{{ environ.get('CONDA_PYTORCH_CONSTRAINT', 'pytorch') }}
Expand Down
18 changes: 18 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
DEBUG = os.getenv("DEBUG", "0") == "1"
USE_PNG = os.getenv("TORCHVISION_USE_PNG", "1") == "1"
USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1"
USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1"
USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1"
NVCC_FLAGS = os.getenv("NVCC_FLAGS", None)
USE_FFMPEG = os.getenv("TORCHVISION_USE_FFMPEG", "1") == "1"
Expand All @@ -41,6 +42,7 @@
print(f"{DEBUG = }")
print(f"{USE_PNG = }")
print(f"{USE_JPEG = }")
print(f"{USE_WEBP = }")
print(f"{USE_NVJPEG = }")
print(f"{NVCC_FLAGS = }")
print(f"{USE_FFMPEG = }")
Expand Down Expand Up @@ -308,6 +310,22 @@ def make_image_extension():
else:
warnings.warn("Building torchvision without JPEG support")

if USE_WEBP:
webp_found, webp_include_dir, webp_library_dir = find_library(header="webp/decode.h")
if webp_found:
print("Building torchvision with WEBP support")
print(f"{webp_include_dir = }")
print(f"{webp_library_dir = }")
if webp_include_dir is not None and webp_library_dir is not None:
# if those are None it means they come from standard paths that are already in the search paths, which we don't need to re-add.
include_dirs.append(webp_include_dir)
library_dirs.append(webp_library_dir)
webp_library = "libwebp" if sys.platform == "win32" else "webp"
libraries.append(webp_library)
define_macros += [("WEBP_FOUND", 1)]
else:
warnings.warn("Building torchvision without WEBP support")

if USE_NVJPEG and torch.cuda.is_available():
nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists()

Expand Down
Binary file added test/assets/fakedata/logos/rgb_pytorch.webp
Binary file not shown.
22 changes: 16 additions & 6 deletions test/smoke_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""Run smoke tests"""

import os
import sys
from pathlib import Path

import torch
import torchvision
from torchvision.io import decode_jpeg, read_file, read_image
from torchvision.io import decode_jpeg, decode_webp, read_file, read_image
from torchvision.models import resnet50, ResNet50_Weights


SCRIPT_DIR = Path(__file__).parent


Expand All @@ -25,6 +27,9 @@ def smoke_test_torchvision_read_decode() -> None:
img_png = read_image(str(SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png"))
if img_png.shape != (4, 471, 354):
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
img_webp = read_image(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch.webp"))
if img_webp.shape != (3, 100, 100):
raise RuntimeError(f"Unexpected shape of img_webp: {img_webp.shape}")


def smoke_test_torchvision_decode_jpeg(device: str = "cpu"):
Expand Down Expand Up @@ -77,11 +82,16 @@ def main() -> None:
print(f"torchvision: {torchvision.__version__}")
print(f"torch.cuda.is_available: {torch.cuda.is_available()}")

# Turn 1.11.0aHASH into 1.11 (major.minor only)
version = ".".join(torchvision.__version__.split(".")[:2])
if version >= "0.16":
print(f"{torch.ops.image._jpeg_version() = }")
assert torch.ops.image._is_compiled_against_turbo()
print(f"{torch.ops.image._jpeg_version() = }")
if not torch.ops.image._is_compiled_against_turbo():
msg = "Torchvision wasn't compiled against libjpeg-turbo"
if os.getenv("IS_M1_CONDA_BUILD_JOB") == "1":
# When building the conda package on M1, it's difficult to enforce
# that we build against turbo due to interactions with the libwebp
# package. So we just accept it, instead of raising an error.
print(msg)
else:
raise ValueError(msg)

smoke_test_torchvision()
smoke_test_torchvision_read_decode()
Expand Down
29 changes: 23 additions & 6 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
decode_image,
decode_jpeg,
decode_png,
decode_webp,
encode_jpeg,
encode_png,
ImageReadMode,
Expand Down Expand Up @@ -861,16 +862,32 @@ def test_decode_gif(tmpdir, name, scripted):
torch.testing.assert_close(tv_frame, pil_frame, atol=0, rtol=0)


def test_decode_gif_errors():
@pytest.mark.parametrize("decode_fun", (decode_gif, decode_webp))
def test_decode_gif_webp_errors(decode_fun):
encoded_data = torch.randint(0, 256, (100,), dtype=torch.uint8)
with pytest.raises(RuntimeError, match="Input tensor must be 1-dimensional"):
decode_gif(encoded_data[None])
decode_fun(encoded_data[None])
with pytest.raises(RuntimeError, match="Input tensor must have uint8 data type"):
decode_gif(encoded_data.float())
decode_fun(encoded_data.float())
with pytest.raises(RuntimeError, match="Input tensor must be contiguous"):
decode_gif(encoded_data[::2])
with pytest.raises(RuntimeError, match=re.escape("DGifOpenFileName() failed - 103")):
decode_gif(encoded_data)
decode_fun(encoded_data[::2])
if decode_fun is decode_gif:
expected_match = re.escape("DGifOpenFileName() failed - 103")
else:
expected_match = "WebPDecodeRGB failed."
with pytest.raises(RuntimeError, match=expected_match):
decode_fun(encoded_data)


@pytest.mark.parametrize("decode_fun", (decode_webp, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_webp(decode_fun, scripted):
    """Decode the WebP test asset, eagerly and scripted, via both entry points.

    Checks that the result is a 3x100x100 image whose batched view is
    contiguous in channels-last memory format.
    """
    webp_path = next(get_images(FAKEDATA_DIR, ".webp"))
    encoded_bytes = read_file(webp_path)

    # Optionally exercise the TorchScript-compiled version of the decoder.
    decoder = torch.jit.script(decode_fun) if scripted else decode_fun

    img = decoder(encoded_bytes)
    assert img.shape == (3, 100, 100)
    assert img[None].is_contiguous(memory_format=torch.channels_last)


if __name__ == "__main__":
Expand Down
39 changes: 27 additions & 12 deletions torchvision/csrc/io/image/cpu/decode_image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "decode_gif.h"
#include "decode_jpeg.h"
#include "decode_png.h"
#include "decode_webp.h"

namespace vision {
namespace image {
Expand All @@ -20,29 +21,43 @@ torch::Tensor decode_image(
data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor");

auto err_msg =
"Unsupported image file. Only jpeg, png and gif are currently supported.";

auto datap = data.data_ptr<uint8_t>();

const uint8_t jpeg_signature[3] = {255, 216, 255}; // == "\xFF\xD8\xFF"
TORCH_CHECK(data.numel() >= 3, err_msg);
if (memcmp(jpeg_signature, datap, 3) == 0) {
return decode_jpeg(data, mode, apply_exif_orientation);
}

const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG"
TORCH_CHECK(data.numel() >= 4, err_msg);
if (memcmp(png_signature, datap, 4) == 0) {
return decode_png(data, mode, apply_exif_orientation);
}

const uint8_t gif_signature_1[6] = {
0x47, 0x49, 0x46, 0x38, 0x39, 0x61}; // == "GIF89a"
const uint8_t gif_signature_2[6] = {
0x47, 0x49, 0x46, 0x38, 0x37, 0x61}; // == "GIF87a"

if (memcmp(jpeg_signature, datap, 3) == 0) {
return decode_jpeg(data, mode, apply_exif_orientation);
} else if (memcmp(png_signature, datap, 4) == 0) {
return decode_png(data, mode, apply_exif_orientation);
} else if (
memcmp(gif_signature_1, datap, 6) == 0 ||
TORCH_CHECK(data.numel() >= 6, err_msg);
if (memcmp(gif_signature_1, datap, 6) == 0 ||
memcmp(gif_signature_2, datap, 6) == 0) {
return decode_gif(data);
} else {
TORCH_CHECK(
false,
"Unsupported image file. Only jpeg, png and gif ",
"are currently supported.");
}

const uint8_t webp_signature_begin[4] = {0x52, 0x49, 0x46, 0x46}; // == "RIFF"
const uint8_t webp_signature_end[7] = {
0x57, 0x45, 0x42, 0x50, 0x56, 0x50, 0x38}; // == "WEBPVP8"
TORCH_CHECK(data.numel() >= 15, err_msg);
if ((memcmp(webp_signature_begin, datap, 4) == 0) &&
(memcmp(webp_signature_end, datap + 8, 7) == 0)) {
return decode_webp(data);
}

TORCH_CHECK(false, err_msg);
}

} // namespace image
Expand Down
40 changes: 40 additions & 0 deletions torchvision/csrc/io/image/cpu/decode_webp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#include "decode_webp.h"

#if WEBP_FOUND
#include "webp/decode.h"
#endif // WEBP_FOUND

namespace vision {
namespace image {

#if !WEBP_FOUND
// Fallback compiled when WEBP_FOUND is not set: torchvision was built without
// libwebp, so this always raises via TORCH_CHECK instead of decoding `data`.
torch::Tensor decode_webp(const torch::Tensor& data) {
TORCH_CHECK(
false, "decode_webp: torchvision not compiled with libwebp support");
}
#else

// Decodes a WebP-encoded byte buffer into an RGB image tensor.
//
// `encoded_data` must be a contiguous, 1-dimensional uint8 tensor holding the
// raw bytes of the WebP file. Returns a uint8 tensor of shape
// (3, height, width) that is a permuted view over HWC storage (i.e.
// channels-last memory layout). Raises (TORCH_CHECK) if the input does not
// satisfy the constraints above or if libwebp fails to decode it.
torch::Tensor decode_webp(const torch::Tensor& encoded_data) {
  TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous.");
  TORCH_CHECK(
      encoded_data.dtype() == torch::kU8,
      "Input tensor must have uint8 data type, got ",
      encoded_data.dtype());
  TORCH_CHECK(
      encoded_data.dim() == 1,
      "Input tensor must be 1-dimensional, got ",
      encoded_data.dim(),
      " dims.");

  int width = 0;
  int height = 0;
  // WebPDecodeRGB allocates the output buffer itself; the caller owns it and
  // must release it with WebPFree().
  uint8_t* decoded_data = WebPDecodeRGB(
      encoded_data.data_ptr<uint8_t>(), encoded_data.numel(), &width, &height);
  TORCH_CHECK(decoded_data != nullptr, "WebPDecodeRGB failed.");

  // Hand ownership of the libwebp buffer to the tensor: without a deleter,
  // from_blob() never frees the memory and every decoded image would leak.
  auto out = torch::from_blob(
      decoded_data,
      {height, width, 3},
      /*deleter=*/[decoded_data](void*) { WebPFree(decoded_data); },
      torch::kUInt8);

  return out.permute({2, 0, 1}); // return CHW, channels-last
}
#endif // WEBP_FOUND

} // namespace image
} // namespace vision
11 changes: 11 additions & 0 deletions torchvision/csrc/io/image/cpu/decode_webp.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#pragma once

#include <torch/types.h>

namespace vision {
namespace image {

// Decodes the WebP-encoded bytes in `data` and returns the decoded image as a
// tensor. Raises if torchvision was compiled without libwebp support or if
// decoding fails.
C10_EXPORT torch::Tensor decode_webp(const torch::Tensor& data);

} // namespace image
} // namespace vision
1 change: 1 addition & 0 deletions torchvision/csrc/io/image/image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ static auto registry =
.op("image::encode_png", &encode_png)
.op("image::decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor",
&decode_jpeg)
.op("image::decode_webp", &decode_webp)
.op("image::encode_jpeg", &encode_jpeg)
.op("image::read_file", &read_file)
.op("image::write_file", &write_file)
Expand Down
1 change: 1 addition & 0 deletions torchvision/csrc/io/image/image.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "cpu/decode_image.h"
#include "cpu/decode_jpeg.h"
#include "cpu/decode_png.h"
#include "cpu/decode_webp.h"
#include "cpu/encode_jpeg.h"
#include "cpu/encode_png.h"
#include "cpu/read_write_file.h"
Expand Down
Loading

0 comments on commit be7cdf1

Please sign in to comment.