From 598e34f65390423be83daee6111388d6be822d7a Mon Sep 17 00:00:00 2001 From: Lucy Qiu Date: Wed, 7 Aug 2024 17:57:51 -0700 Subject: [PATCH] Introduce preprocess custom ops (#4548) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/4548 For preprocess, we have two custom ops, pad and reshape. These are used to hide the implementation and any data dependent issues from export, allowing us to export the model. The model definition is in D60051533, though we're trying to upstream it to torchtune: https://github.com/pytorch/torchtune/pull/1269 This diff adds: - Python impl + tests - C++ function signatures + buck setup for reshape using [EXECUTORCH_LIBRARY](https://pytorch.org/executorch/stable/kernel-library-custom-aten-kernel.html#custom-ops-c-api) Differential Revision: D60491675 --- extension/llm/custom_ops/TARGETS | 13 +++ extension/llm/custom_ops/op_tile_crop.cpp | 33 ++++++ extension/llm/custom_ops/op_tile_crop.h | 26 +++++ .../llm/custom_ops/preprocess_custom_ops.py | 110 ++++++++++++++++++ extension/llm/custom_ops/targets.bzl | 33 ++++++ .../custom_ops/test_preprocess_custom_ops.py | 54 +++++++++ 6 files changed, 269 insertions(+) create mode 100644 extension/llm/custom_ops/op_tile_crop.cpp create mode 100644 extension/llm/custom_ops/op_tile_crop.h create mode 100644 extension/llm/custom_ops/preprocess_custom_ops.py create mode 100644 extension/llm/custom_ops/test_preprocess_custom_ops.py diff --git a/extension/llm/custom_ops/TARGETS b/extension/llm/custom_ops/TARGETS index 195df3bb931..ff3fde6e2cc 100644 --- a/extension/llm/custom_ops/TARGETS +++ b/extension/llm/custom_ops/TARGETS @@ -21,3 +21,16 @@ runtime.python_test( "//caffe2:torch", ], ) + +runtime.python_test( + name = "test_preprocess_custom_ops", + srcs = [ + "test_preprocess_custom_ops.py", + ], + preload_deps = [ + ":preprocess_custom_ops_py", + ], + deps = [ + "//caffe2:torch", + ], +) diff --git a/extension/llm/custom_ops/op_tile_crop.cpp b/extension/llm/custom_ops/op_tile_crop.cpp new file mode 100644 index 00000000000..094bb9beec9 --- /dev/null +++ b/extension/llm/custom_ops/op_tile_crop.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +namespace torch { +namespace executor { +namespace native { + +Tensor& tile_crop_out_impl( + RuntimeContext& ctx, + const Tensor& input, // NOLINT + const int64_t tile_size, // NOLINT + Tensor& out) { + (void)ctx; + return out; +} + +} // namespace native +} // namespace executor +} // namespace torch + +EXECUTORCH_LIBRARY( + preprocess, + "tile_crop.out", + torch::executor::native::tile_crop_out_impl); diff --git a/extension/llm/custom_ops/op_tile_crop.h b/extension/llm/custom_ops/op_tile_crop.h new file mode 100644 index 00000000000..445e443f9af --- /dev/null +++ b/extension/llm/custom_ops/op_tile_crop.h @@ -0,0 +1,26 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace torch { +namespace executor { + +namespace native { + +Tensor& tile_crop_out_impl( + RuntimeContext& ctx, + const Tensor& input, + const int64_t tile_size, + Tensor& out); + +} // namespace native +} // namespace executor +} // namespace torch diff --git a/extension/llm/custom_ops/preprocess_custom_ops.py b/extension/llm/custom_ops/preprocess_custom_ops.py new file mode 100644 index 00000000000..aea8c09b0ef --- /dev/null +++ b/extension/llm/custom_ops/preprocess_custom_ops.py @@ -0,0 +1,110 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + + +from typing import List + +import torch + +from torch.library import impl, Library + +preprocess_op_lib = Library("preprocess", "DEF") + +# Register and define pad and out variant. +# Note: pad doesn't require an explicit meta kernel because +# CompositeExplicitAutograd automatically registers the implementation to meta, +# and meta kernels do not go through functionalization. The implementation +# does not export due to issues during functionalization. +# See: https://github.com/pytorch/pytorch/issues/120288 +preprocess_op_lib.define("pad(Tensor image, SymInt[] padding) -> Tensor") + + +@impl(preprocess_op_lib, "pad", dispatch_key="CompositeExplicitAutograd") +def pad_impl( + image: torch.Tensor, + padding: List[int], +) -> torch.Tensor: + output = torch.empty( + [image.shape[0], image.shape[1] + padding[3], image.shape[2] + padding[1]], + dtype=image.dtype, + device=image.device, + requires_grad=False, + ) + output = torch.fill(output, 0) + output.narrow(1, 0, image.shape[1]).narrow(2, 0, image.shape[2]).copy_(image) + return output + + +preprocess_op_lib.define( + "pad.out(Tensor image, SymInt[] padding, *, Tensor(a!) out) -> Tensor(a!)" +) + + +@impl(preprocess_op_lib, "pad.out", dispatch_key="CompositeExplicitAutograd") +def pad_out_impl( + image: torch.Tensor, + padding: List[int], + out: torch.Tensor, +) -> torch.Tensor: + out = torch.empty( + [image.shape[0], image.shape[1] + padding[3], image.shape[2] + padding[1]], + dtype=image.dtype, + device=image.device, + requires_grad=False, + ) + out = torch.fill(out, 0) + out.narrow(1, 0, image.shape[1]).narrow(2, 0, image.shape[2]).copy_(image) + return out + + +# Register and define tile_crop and out variant. +preprocess_op_lib.define("tile_crop(Tensor input, int tile_size) -> Tensor") + + +@impl(preprocess_op_lib, "tile_crop", dispatch_key="CompositeExplicitAutograd") +def tile_crop_impl(input: torch.Tensor, tile_size: int) -> torch.Tensor: + c = input.shape[0] + h = input.shape[1] + w = input.shape[2] + tiles_height = h // tile_size + tiles_width = w // tile_size + tile_cropped = input.view(c, tiles_height, tile_size, tiles_width, tile_size) + transposed = tile_cropped.permute(1, 3, 0, 2, 4) + tiles = transposed.contiguous().view( + tiles_height * tiles_width, c, tile_size, tile_size + ) + return tiles + + +preprocess_op_lib.define( + "tile_crop.out(Tensor input, int tile_size, *, Tensor(a!) out) -> Tensor(a!)" +) + + +@impl(preprocess_op_lib, "tile_crop.out", dispatch_key="CompositeExplicitAutograd") +def tile_crop_out_impl( + input: torch.Tensor, tile_size: int, out: torch.Tensor +) -> torch.Tensor: + out = input.clone() + c = out.shape[0] + h = out.shape[1] + w = out.shape[2] + tiles_height = h // tile_size + tiles_width = w // tile_size + out = out.view(c, tiles_height, tile_size, tiles_width, tile_size) + out = out.permute(1, 3, 0, 2, 4) + out = out.contiguous().view(tiles_height * tiles_width, c, tile_size, tile_size) + return out + + +# Register meta kernel to prevent export tracing into the tile_crop impl. +@torch.library.register_fake("preprocess::tile_crop") +def tile_crop(output: torch.Tensor, tile_size: int) -> torch.Tensor: + # Returned tensor is of size [n, 3, 224, 224], where n is the number of tiles. + # We should export with n = max_num_tiles. Set 50 for now. + return torch.empty([50, output.size(0), 224, 224]) diff --git a/extension/llm/custom_ops/targets.bzl b/extension/llm/custom_ops/targets.bzl index 8c38eb6a0a0..ee277515b1d 100644 --- a/extension/llm/custom_ops/targets.bzl +++ b/extension/llm/custom_ops/targets.bzl @@ -86,3 +86,36 @@ def define_common_targets(): ":custom_ops", ], ) + + ## For preprocess + runtime.python_library( + name = "preprocess_custom_ops_py", + srcs = [ + "preprocess_custom_ops.py", + ], + visibility = [ + "//executorch/...", + "@EXECUTORCH_CLIENTS", + ], + deps = [ + "//caffe2:torch", + ], + ) + + runtime.cxx_library( + name = "op_tile_crop", + srcs = ["op_tile_crop.cpp"], + exported_headers = ["op_tile_crop.h"], + exported_deps = [ + "//executorch/runtime/kernel:kernel_includes", + "//executorch/extension/kernel_util:kernel_util", + ], + compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"], + visibility = [ + "//executorch/...", + "@EXECUTORCH_CLIENTS", + ], + # @lint-ignore BUCKLINT link_whole + link_whole = True, + force_static = True, + ) diff --git a/extension/llm/custom_ops/test_preprocess_custom_ops.py b/extension/llm/custom_ops/test_preprocess_custom_ops.py new file mode 100644 index 00000000000..50e149ece16 --- /dev/null +++ b/extension/llm/custom_ops/test_preprocess_custom_ops.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import unittest +from typing import List, Tuple + +import torch + +from .preprocess_custom_ops import preprocess_op_lib # noqa + + +class PreprocessTest(unittest.TestCase): + + def setUp(self): + # pad + self.pad_input = torch.ones(3, 200, 300) + + # tile_crop + self.tile_size = 224 + + def _compare_pad(self, image: torch.Tensor, padding: List[int]) -> None: + output = torch.ops.preprocess.pad.default(image, padding) + output_ref = torch.nn.functional.pad(image, padding) + self.assertTrue(torch.allclose(output_ref, output, 1e-6)) + + def _test_tile_crop(self, image: torch.Tensor, expected_shape: Tuple[int]) -> None: + output = torch.ops.preprocess.tile_crop.default(image, self.tile_size) + self.assertTrue(output.shape == expected_shape) + + def test_op_pad_without_padding(self): + self._compare_pad(self.pad_input, [0, 0, 0, 0]) + + def test_op_pad_with_right_bottom_padding(self): + self._compare_pad(self.pad_input, [0, 124, 0, 148]) + + def test_op_pad_with_right_padding(self): + self._compare_pad(self.pad_input, [0, 124, 0, 0]) + + def test_op_pad_with_bottom_padding(self): + self._compare_pad(self.pad_input, [0, 0, 0, 148]) + + def test_op_tile_crop_2x2(self): + self._test_tile_crop(torch.ones(3, 448, 448), (4, 3, 224, 224)) + + def test_op_tile_crop_1x3(self): + self._test_tile_crop(torch.ones(3, 224, 672), (3, 3, 224, 224)) + + def test_op_tile_crop_4x2(self): + self._test_tile_crop(torch.ones(3, 896, 448), (8, 3, 224, 224))