diff --git a/extension/llm/custom_ops/TARGETS b/extension/llm/custom_ops/TARGETS
index 195df3bb931..ff3fde6e2cc 100644
--- a/extension/llm/custom_ops/TARGETS
+++ b/extension/llm/custom_ops/TARGETS
@@ -21,3 +21,16 @@ runtime.python_test(
         "//caffe2:torch",
     ],
 )
+
+runtime.python_test(
+    name = "test_preprocess_custom_ops",
+    srcs = [
+        "test_preprocess_custom_ops.py",
+    ],
+    preload_deps = [
+        ":preprocess_custom_ops_py",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
diff --git a/extension/llm/custom_ops/op_tile_crop.cpp b/extension/llm/custom_ops/op_tile_crop.cpp
new file mode 100644
index 00000000000..094bb9beec9
--- /dev/null
+++ b/extension/llm/custom_ops/op_tile_crop.cpp
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/extension/llm/custom_ops/op_tile_crop.h>
+#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+
+Tensor& tile_crop_out_impl(
+    RuntimeContext& ctx,
+    const Tensor& input, // NOLINT
+    const int64_t tile_size, // NOLINT
+    Tensor& out) {
+  (void)ctx;
+  // Stub kernel: the cropping logic is not implemented yet; `out` is returned unmodified.
+  return out;
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
+
+EXECUTORCH_LIBRARY(
+    preprocess,
+    "tile_crop.out",
+    torch::executor::native::tile_crop_out_impl);
diff --git a/extension/llm/custom_ops/op_tile_crop.h b/extension/llm/custom_ops/op_tile_crop.h
new file mode 100644
index 00000000000..445e443f9af
--- /dev/null
+++ b/extension/llm/custom_ops/op_tile_crop.h
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+
+namespace native {
+
+Tensor& tile_crop_out_impl(
+    RuntimeContext& ctx,
+    const Tensor& input,
+    const int64_t tile_size,
+    Tensor& out);
+
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/extension/llm/custom_ops/preprocess_custom_ops.py b/extension/llm/custom_ops/preprocess_custom_ops.py
new file mode 100644
index 00000000000..aea8c09b0ef
--- /dev/null
+++ b/extension/llm/custom_ops/preprocess_custom_ops.py
@@ -0,0 +1,110 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+
+from typing import List
+
+import torch
+
+from torch.library import impl, Library
+
+preprocess_op_lib = Library("preprocess", "DEF")
+
+# Register and define pad and its out variant.
+# Note: pad doesn't require an explicit meta kernel because
+# CompositeExplicitAutograd automatically registers the implementation to meta,
+# and meta kernels do not go through functionalization. The implementation
+# itself does not export due to issues during functionalization.
+# See: https://github.com/pytorch/pytorch/issues/120288
+preprocess_op_lib.define("pad(Tensor image, SymInt[] padding) -> Tensor")
+
+
+@impl(preprocess_op_lib, "pad", dispatch_key="CompositeExplicitAutograd")
+def pad_impl(
+    image: torch.Tensor,
+    padding: List[int],
+) -> torch.Tensor:
+    output = torch.empty(
+        [image.shape[0], image.shape[1] + padding[3], image.shape[2] + padding[1]],
+        dtype=image.dtype,
+        device=image.device,
+        requires_grad=False,
+    )
+    output = torch.fill(output, 0)
+    output.narrow(1, 0, image.shape[1]).narrow(2, 0, image.shape[2]).copy_(image)
+    return output
+
+
+preprocess_op_lib.define(
+    "pad.out(Tensor image, SymInt[] padding, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
+
+@impl(preprocess_op_lib, "pad.out", dispatch_key="CompositeExplicitAutograd")
+def pad_out_impl(
+    image: torch.Tensor,
+    padding: List[int],
+    out: torch.Tensor,
+) -> torch.Tensor:
+    out = torch.empty(
+        [image.shape[0], image.shape[1] + padding[3], image.shape[2] + padding[1]],
+        dtype=image.dtype,
+        device=image.device,
+        requires_grad=False,
+    )
+    out = torch.fill(out, 0)
+    out.narrow(1, 0, image.shape[1]).narrow(2, 0, image.shape[2]).copy_(image)
+    return out
+
+
+# Register and define tile_crop and its out variant.
+preprocess_op_lib.define("tile_crop(Tensor input, int tile_size) -> Tensor")
+
+
+@impl(preprocess_op_lib, "tile_crop", dispatch_key="CompositeExplicitAutograd")
+def tile_crop_impl(input: torch.Tensor, tile_size: int) -> torch.Tensor:
+    c = input.shape[0]
+    h = input.shape[1]
+    w = input.shape[2]
+    tiles_height = h // tile_size
+    tiles_width = w // tile_size
+    tile_cropped = input.view(c, tiles_height, tile_size, tiles_width, tile_size)
+    transposed = tile_cropped.permute(1, 3, 0, 2, 4)
+    tiles = transposed.contiguous().view(
+        tiles_height * tiles_width, c, tile_size, tile_size
+    )
+    return tiles
+
+
+preprocess_op_lib.define(
+    "tile_crop.out(Tensor input, int tile_size, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
+
+@impl(preprocess_op_lib, "tile_crop.out", dispatch_key="CompositeExplicitAutograd")
+def tile_crop_out_impl(
+    input: torch.Tensor, tile_size: int, out: torch.Tensor
+) -> torch.Tensor:
+    out = input.clone()
+    c = out.shape[0]
+    h = out.shape[1]
+    w = out.shape[2]
+    tiles_height = h // tile_size
+    tiles_width = w // tile_size
+    out = out.view(c, tiles_height, tile_size, tiles_width, tile_size)
+    out = out.permute(1, 3, 0, 2, 4)
+    out = out.contiguous().view(tiles_height * tiles_width, c, tile_size, tile_size)
+    return out
+
+
+# Register meta kernel to prevent export tracing into the tile_crop impl.
+@torch.library.register_fake("preprocess::tile_crop")
+def tile_crop(input: torch.Tensor, tile_size: int) -> torch.Tensor:
+    # Returned tensor is of size [n, 3, 224, 224], where n is the number of tiles.
+    # We should export with n = max_num_tiles. Set to 50 for now.
+    return torch.empty([50, input.size(0), 224, 224])
diff --git a/extension/llm/custom_ops/targets.bzl b/extension/llm/custom_ops/targets.bzl
index 8c38eb6a0a0..ee277515b1d 100644
--- a/extension/llm/custom_ops/targets.bzl
+++ b/extension/llm/custom_ops/targets.bzl
@@ -86,3 +86,36 @@ def define_common_targets():
             ":custom_ops",
         ],
     )
+
+    ## For preprocess
+    runtime.python_library(
+        name = "preprocess_custom_ops_py",
+        srcs = [
+            "preprocess_custom_ops.py",
+        ],
+        visibility = [
+            "//executorch/...",
+            "@EXECUTORCH_CLIENTS",
+        ],
+        deps = [
+            "//caffe2:torch",
+        ],
+    )
+
+    runtime.cxx_library(
+        name = "op_tile_crop",
+        srcs = ["op_tile_crop.cpp"],
+        exported_headers = ["op_tile_crop.h"],
+        exported_deps = [
+            "//executorch/runtime/kernel:kernel_includes",
+            "//executorch/extension/kernel_util:kernel_util",
+        ],
+        compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"],
+        visibility = [
+            "//executorch/...",
+            "@EXECUTORCH_CLIENTS",
+        ],
+        # @lint-ignore BUCKLINT link_whole
+        link_whole = True,
+        force_static = True,
+    )
diff --git a/extension/llm/custom_ops/test_preprocess_custom_ops.py b/extension/llm/custom_ops/test_preprocess_custom_ops.py
new file mode 100644
index 00000000000..50e149ece16
--- /dev/null
+++ b/extension/llm/custom_ops/test_preprocess_custom_ops.py
@@ -0,0 +1,54 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import unittest
+from typing import List, Tuple
+
+import torch
+
+from .preprocess_custom_ops import preprocess_op_lib  # noqa
+
+
+class PreprocessTest(unittest.TestCase):
+
+    def setUp(self):
+        # pad
+        self.pad_input = torch.ones(3, 200, 300)
+
+        # tile_crop
+        self.tile_size = 224
+
+    def _compare_pad(self, image: torch.Tensor, padding: List[int]) -> None:
+        output = torch.ops.preprocess.pad.default(image, padding)
+        output_ref = torch.nn.functional.pad(image, padding)
+        self.assertTrue(torch.allclose(output_ref, output, rtol=1e-6))
+
+    def _test_tile_crop(self, image: torch.Tensor, expected_shape: Tuple[int, ...]) -> None:
+        output = torch.ops.preprocess.tile_crop.default(image, self.tile_size)
+        self.assertTrue(output.shape == expected_shape)
+
+    def test_op_pad_without_padding(self):
+        self._compare_pad(self.pad_input, [0, 0, 0, 0])
+
+    def test_op_pad_with_right_bottom_padding(self):
+        self._compare_pad(self.pad_input, [0, 124, 0, 148])
+
+    def test_op_pad_with_right_padding(self):
+        self._compare_pad(self.pad_input, [0, 124, 0, 0])
+
+    def test_op_pad_with_bottom_padding(self):
+        self._compare_pad(self.pad_input, [0, 0, 0, 148])
+
+    def test_op_tile_crop_2x2(self):
+        self._test_tile_crop(torch.ones(3, 448, 448), (4, 3, 224, 224))
+
+    def test_op_tile_crop_1x3(self):
+        self._test_tile_crop(torch.ones(3, 224, 672), (3, 3, 224, 224))
+
+    def test_op_tile_crop_4x2(self):
+        self._test_tile_crop(torch.ones(3, 896, 448), (8, 3, 224, 224))
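
The pad op's padding list follows torch.nn.functional.pad's convention for a CHW image, [left, right, top, bottom], but the implementation copies the image into the top-left corner and grows the output only by padding[1] (right) and padding[3] (bottom); only zero left/top padding matches the reference, which is why the tests above use left = top = 0. A minimal eager-mode sketch, assuming the library is importable as executorch.extension.llm.custom_ops.preprocess_custom_ops (the repo path above):

import torch

# Importing the module registers the "preprocess" op library as a side effect.
from executorch.extension.llm.custom_ops import preprocess_custom_ops  # noqa: F401

image = torch.ones(3, 200, 300)  # CHW
padded = torch.ops.preprocess.pad.default(image, [0, 124, 0, 148])

# Height grows by padding[3] (bottom), width by padding[1] (right).
assert padded.shape == (3, 200 + 148, 300 + 124)
assert torch.equal(padded[:, :200, :300], image)  # source pixels sit top-left
assert padded[:, 200:, :].eq(0).all()  # bottom strip is zero-filled
assert padded[:, :, 300:].eq(0).all()  # right strip is zero-filled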
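The view/permute/view sequence in tile_crop splits H and W into (tiles_height, tile_size) and (tiles_width, tile_size), then orders tiles left-to-right, top-to-bottom. A toy check on a single-channel 4x4 image with tile_size 2, under the same import assumption as above:

import torch

from executorch.extension.llm.custom_ops import preprocess_custom_ops  # noqa: F401

x = torch.arange(16.0).reshape(1, 4, 4)  # one channel; rows are 0..3, 4..7, ...
tiles = torch.ops.preprocess.tile_crop.default(x, 2)
assert tiles.shape == (4, 1, 2, 2)  # 2x2 grid of 2x2 tiles

# Tile 0 is the top-left block; tile 1 is its right-hand neighbor.
assert torch.equal(tiles[0, 0], torch.tensor([[0.0, 1.0], [4.0, 5.0]]))
assert torch.equal(tiles[1, 0], torch.tensor([[2.0, 3.0], [6.0, 7.0]]))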
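The register_fake kernel is what lets torch.export avoid tracing the eager tile_crop body: export runs under fake tensors, so the graph keeps a single opaque preprocess.tile_crop node whose traced output takes the registered [50, c, 224, 224] shape instead of the data-dependent tile count. A sketch of that behavior, again assuming the import path above:

import torch

from executorch.extension.llm.custom_ops import preprocess_custom_ops  # noqa: F401


class TileCrop(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.preprocess.tile_crop.default(x, 224)


# Export traces with fake tensors, so the fake kernel (not the eager impl) runs.
ep = torch.export.export(TileCrop(), (torch.ones(3, 448, 448),))
node = next(
    n
    for n in ep.graph.nodes
    if n.op == "call_function" and "tile_crop" in str(n.target)
)
# The fake kernel pinned n to 50 tiles, even though a 448x448 input yields 4.
assert list(node.meta["val"].shape) == [50, 3, 224, 224]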