13 changes: 13 additions & 0 deletions extension/llm/custom_ops/TARGETS
@@ -21,3 +21,16 @@ runtime.python_test(
"//caffe2:torch",
],
)

runtime.python_test(
    name = "test_preprocess_custom_ops",
    srcs = [
        "test_preprocess_custom_ops.py",
    ],
    preload_deps = [
        ":preprocess_custom_ops_py",
    ],
    deps = [
        "//caffe2:torch",
    ],
)
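
Not part of the diff: with this target added, the new tests could be run under Buck with something along the lines of buck2 test //executorch/extension/llm/custom_ops:test_preprocess_custom_ops (the exact cell prefix depends on the local Buck configuration).
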
33 changes: 33 additions & 0 deletions extension/llm/custom_ops/op_tile_crop.cpp
@@ -0,0 +1,33 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
#include <executorch/extension/llm/custom_ops/op_tile_crop.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

Tensor& tile_crop_out_impl(
    RuntimeContext& ctx,
    const Tensor& input, // NOLINT
    const int64_t tile_size, // NOLINT
    Tensor& out) {
  // Stub implementation: performs no computation and returns out unchanged.
  (void)ctx;
  return out;
}

} // namespace native
} // namespace executor
} // namespace torch

EXECUTORCH_LIBRARY(
    preprocess,
    "tile_crop.out",
    torch::executor::native::tile_crop_out_impl);
26 changes: 26 additions & 0 deletions extension/llm/custom_ops/op_tile_crop.h
@@ -0,0 +1,26 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {

namespace native {

Tensor& tile_crop_out_impl(
    RuntimeContext& ctx,
    const Tensor& input,
    const int64_t tile_size,
    Tensor& out);

} // namespace native
} // namespace executor
} // namespace torch
110 changes: 110 additions & 0 deletions extension/llm/custom_ops/preprocess_custom_ops.py
@@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


from typing import List

import torch

from torch.library import impl, Library

preprocess_op_lib = Library("preprocess", "DEF")

# Register and define pad and its out variant.
# Note: pad doesn't require an explicit meta kernel, because
# CompositeExplicitAutograd automatically registers the implementation to meta,
# and meta kernels do not go through functionalization. (The implementation
# itself does not export directly, due to issues during functionalization.)
# See: https://github.com/pytorch/pytorch/issues/120288
preprocess_op_lib.define("pad(Tensor image, SymInt[] padding) -> Tensor")


@impl(preprocess_op_lib, "pad", dispatch_key="CompositeExplicitAutograd")
def pad_impl(
    image: torch.Tensor,
    padding: List[int],
) -> torch.Tensor:
    output = torch.empty(
        [image.shape[0], image.shape[1] + padding[3], image.shape[2] + padding[1]],
        dtype=image.dtype,
        device=image.device,
        requires_grad=False,
    )
    output = torch.fill(output, 0)
    output.narrow(1, 0, image.shape[1]).narrow(2, 0, image.shape[2]).copy_(image)
    return output


preprocess_op_lib.define(
    "pad.out(Tensor image, SymInt[] padding, *, Tensor(a!) out) -> Tensor(a!)"
)


@impl(preprocess_op_lib, "pad.out", dispatch_key="CompositeExplicitAutograd")
def pad_out_impl(
    image: torch.Tensor,
    padding: List[int],
    out: torch.Tensor,
) -> torch.Tensor:
    # Note: this composite impl allocates a fresh tensor and rebinds out to it;
    # the tensor passed in as out is not written to in place.
    out = torch.empty(
        [image.shape[0], image.shape[1] + padding[3], image.shape[2] + padding[1]],
        dtype=image.dtype,
        device=image.device,
        requires_grad=False,
    )
    out = torch.fill(out, 0)
    out.narrow(1, 0, image.shape[1]).narrow(2, 0, image.shape[2]).copy_(image)
    return out

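Not part of the diff: a minimal usage sketch of the pad op above, mirroring the unit tests at the end of this change. It assumes this module has already been imported so the ops are registered; note that only the right (padding[1]) and bottom (padding[3]) entries are applied, so the left/top entries are expected to be 0.

import torch
# padding follows the torch.nn.functional.pad order [left, right, top, bottom]
image = torch.ones(3, 200, 300)
padded = torch.ops.preprocess.pad.default(image, [0, 124, 0, 148])
assert padded.shape == (3, 348, 424)              # 200 + 148 rows, 300 + 124 cols
assert torch.equal(padded[:, :200, :300], image)  # original pixels preserved
assert padded[:, 200:, :].abs().sum() == 0        # bottom padding is zero-filled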

# Register and define tile_crop and its out variant.
preprocess_op_lib.define("tile_crop(Tensor input, int tile_size) -> Tensor")


@impl(preprocess_op_lib, "tile_crop", dispatch_key="CompositeExplicitAutograd")
def tile_crop_impl(input: torch.Tensor, tile_size: int) -> torch.Tensor:
    c = input.shape[0]
    h = input.shape[1]
    w = input.shape[2]
    tiles_height = h // tile_size
    tiles_width = w // tile_size
    # Split H and W into (grid index, within-tile index), move the two grid
    # dimensions to the front, and flatten them into a single tile dimension.
    tile_cropped = input.view(c, tiles_height, tile_size, tiles_width, tile_size)
    transposed = tile_cropped.permute(1, 3, 0, 2, 4)
    tiles = transposed.contiguous().view(
        tiles_height * tiles_width, c, tile_size, tile_size
    )
    return tiles

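Not part of the diff: a worked example of the view/permute tiling above, again assuming the ops are registered by importing this module. For a 3 x 448 x 448 image with tile_size 224, the view produces (3, 2, 224, 2, 224), the permute reorders it to (2, 2, 3, 224, 224), and the final reshape yields 4 tiles ordered row-major over the 2 x 2 grid.

import torch
img = torch.arange(3 * 448 * 448, dtype=torch.float32).reshape(3, 448, 448)
tiles = torch.ops.preprocess.tile_crop.default(img, 224)
assert tiles.shape == (4, 3, 224, 224)
# Row-major tile order: index 1 is the top-right quadrant of the image.
assert torch.equal(tiles[1], img[:, :224, 224:])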

preprocess_op_lib.define(
    "tile_crop.out(Tensor input, int tile_size, *, Tensor(a!) out) -> Tensor(a!)"
)


@impl(preprocess_op_lib, "tile_crop.out", dispatch_key="CompositeExplicitAutograd")
def tile_crop_out_impl(
    input: torch.Tensor, tile_size: int, out: torch.Tensor
) -> torch.Tensor:
    # Note: like pad.out above, this rebinds out to a new tensor built from a
    # clone of the input; the tensor passed in as out is not mutated.
    out = input.clone()
    c = out.shape[0]
    h = out.shape[1]
    w = out.shape[2]
    tiles_height = h // tile_size
    tiles_width = w // tile_size
    out = out.view(c, tiles_height, tile_size, tiles_width, tile_size)
    out = out.permute(1, 3, 0, 2, 4)
    out = out.contiguous().view(tiles_height * tiles_width, c, tile_size, tile_size)
    return out


# Register meta kernel to prevent export tracing into the tile_crop impl.
@torch.library.register_fake("preprocess::tile_crop")
def tile_crop(input: torch.Tensor, tile_size: int) -> torch.Tensor:
    # Returned tensor is of size [n, 3, 224, 224], where n is the number of tiles.
    # We should export with n = max_num_tiles. Set 50 for now.
    return torch.empty([50, input.size(0), 224, 224])
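
Not part of the diff: a hedged sketch of how the fake kernel is exercised during export. Per the note above, torch.export traces through the fake kernel rather than the eager view/permute implementation; the wrapper module and shapes here are illustrative only and untested.

import torch

class TileCrop(torch.nn.Module):  # hypothetical wrapper, for illustration
    def forward(self, img: torch.Tensor) -> torch.Tensor:
        return torch.ops.preprocess.tile_crop.default(img, 224)

ep = torch.export.export(TileCrop(), (torch.ones(3, 448, 448),))
# The custom op remains a call node in the exported graph; its traced output
# shape comes from the fake kernel above, not from running the real impl.
print(ep.graph_module.graph)
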
33 changes: 33 additions & 0 deletions extension/llm/custom_ops/targets.bzl
@@ -86,3 +86,36 @@ def define_common_targets():
":custom_ops",
],
)

## For preprocess
runtime.python_library(
name = "preprocess_custom_ops_py",
srcs = [
"preprocess_custom_ops.py",
],
visibility = [
"//executorch/...",
"@EXECUTORCH_CLIENTS",
],
deps = [
"//caffe2:torch",
],
)

runtime.cxx_library(
name = "op_tile_crop",
srcs = ["op_tile_crop.cpp"],
exported_headers = ["op_tile_crop.h"],
exported_deps = [
"//executorch/runtime/kernel:kernel_includes",
"//executorch/extension/kernel_util:kernel_util",
],
compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"],
visibility = [
"//executorch/...",
"@EXECUTORCH_CLIENTS",
],
# @lint-ignore BUCKLINT link_whole
link_whole = True,
force_static = True,
)
54 changes: 54 additions & 0 deletions extension/llm/custom_ops/test_preprocess_custom_ops.py
@@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import unittest
from typing import List, Tuple

import torch

from .preprocess_custom_ops import preprocess_op_lib # noqa


class PreprocessTest(unittest.TestCase):

    def setUp(self):
        # pad
        self.pad_input = torch.ones(3, 200, 300)

        # tile_crop
        self.tile_size = 224

    def _compare_pad(self, image: torch.Tensor, padding: List[int]) -> None:
        output = torch.ops.preprocess.pad.default(image, padding)
        output_ref = torch.nn.functional.pad(image, padding)
        self.assertTrue(torch.allclose(output_ref, output, 1e-6))

    def _test_tile_crop(self, image: torch.Tensor, expected_shape: Tuple[int, ...]) -> None:
        output = torch.ops.preprocess.tile_crop.default(image, self.tile_size)
        self.assertTrue(output.shape == expected_shape)

    def test_op_pad_without_padding(self):
        self._compare_pad(self.pad_input, [0, 0, 0, 0])

    def test_op_pad_with_right_bottom_padding(self):
        self._compare_pad(self.pad_input, [0, 124, 0, 148])

    def test_op_pad_with_right_padding(self):
        self._compare_pad(self.pad_input, [0, 124, 0, 0])

    def test_op_pad_with_bottom_padding(self):
        self._compare_pad(self.pad_input, [0, 0, 0, 148])

    def test_op_tile_crop_2x2(self):
        self._test_tile_crop(torch.ones(3, 448, 448), (4, 3, 224, 224))

    def test_op_tile_crop_1x3(self):
        self._test_tile_crop(torch.ones(3, 224, 672), (3, 3, 224, 224))

    def test_op_tile_crop_4x2(self):
        self._test_tile_crop(torch.ones(3, 896, 448), (8, 3, 224, 224))
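
Not part of the diff: outside Buck, these tests can likely be run with python -m unittest executorch.extension.llm.custom_ops.test_preprocess_custom_ops, assuming the executorch Python package layout makes this module importable as a package (the relative import of preprocess_custom_ops means it cannot be run as a standalone script).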