diff --git a/kernels/portable/cpu/op_grid_sampler_2d.cpp b/kernels/portable/cpu/op_grid_sampler_2d.cpp new file mode 100644 index 00000000000..57155b3c01b --- /dev/null +++ b/kernels/portable/cpu/op_grid_sampler_2d.cpp @@ -0,0 +1,480 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include +#include + +namespace torch { +namespace executor { +namespace native { + +using executorch::aten::ArrayRef; +using executorch::aten::SizesType; +using std::optional; + +namespace { +template +void grid_sample_2d_bilinear_kernel_impl_nchw( + const Tensor& in, + const Tensor& grid, + GridSamplerPadding padding_mode, + bool align_corners, + Tensor& out) { + const auto in_data = in.const_data_ptr(); + auto out_data = out.mutable_data_ptr(); + + // Grid has shape [N, H_out, W_out, 2] + // Last dimension contains (x, y) normalized coordinates in [-1, 1] + const auto grid_data = grid.const_data_ptr(); + + const int64_t N = in.size(0); + const int64_t C = in.size(1); + const int64_t inp_H = in.size(2); + const int64_t inp_W = in.size(3); + + const int64_t out_H = out.size(2); + const int64_t out_W = out.size(3); + + // Process each batch + for (const auto n : c10::irange(N)) { + const auto grid_offset = n * grid.strides()[0]; + const auto in_batch_offset = n * in.strides()[0]; + const auto out_batch_offset = n * out.strides()[0]; + + // Process each channel + for (const auto c : c10::irange(C)) { + const auto in_channel_offset = in_batch_offset + c * in.strides()[1]; + const auto out_channel_offset = out_batch_offset + c * out.strides()[1]; + + // Process each output pixel + for (const auto h : c10::irange(out_H)) { + for (const auto w : c10::irange(out_W)) { + // Get grid coordinates for this output position + // grid[n, h, w] contains (x, y) + const int64_t grid_idx = + grid_offset + h * 
grid.strides()[1] + w * grid.strides()[2]; + const CTYPE x = grid_data[grid_idx]; + const CTYPE y = grid_data[grid_idx + grid.strides()[3]]; + + // Compute source coordinates in pixel space + const CTYPE ix = grid_sampler_compute_source_index( + x, inp_W, padding_mode, align_corners); + const CTYPE iy = grid_sampler_compute_source_index( + y, inp_H, padding_mode, align_corners); + + // Get corner pixel coordinates + const int64_t ix_nw = static_cast(std::floor(ix)); + const int64_t iy_nw = static_cast(std::floor(iy)); + const int64_t ix_ne = ix_nw + 1; + const int64_t iy_ne = iy_nw; + const int64_t ix_sw = ix_nw; + const int64_t iy_sw = iy_nw + 1; + const int64_t ix_se = ix_nw + 1; + const int64_t iy_se = iy_nw + 1; + + // Get interpolation weights + const CTYPE nw_weight = (ix_se - ix) * (iy_se - iy); + const CTYPE ne_weight = (ix - ix_sw) * (iy_sw - iy); + const CTYPE sw_weight = (ix_ne - ix) * (iy - iy_ne); + const CTYPE se_weight = (ix - ix_nw) * (iy - iy_nw); + + // Compute output value for this channel + CTYPE out_val = 0; + + // Add contribution from each corner if within bounds + if (padding_mode == GridSamplerPadding::Zeros) { + // For zeros padding, only sample if within bounds + if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) { + out_val += in_data + [in_channel_offset + iy_nw * in.strides()[2] + + ix_nw * in.strides()[3]] * + nw_weight; + } + if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) { + out_val += in_data + [in_channel_offset + iy_ne * in.strides()[2] + + ix_ne * in.strides()[3]] * + ne_weight; + } + if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) { + out_val += in_data + [in_channel_offset + iy_sw * in.strides()[2] + + ix_sw * in.strides()[3]] * + sw_weight; + } + if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) { + out_val += in_data + [in_channel_offset + iy_se * in.strides()[2] + + ix_se * in.strides()[3]] * + se_weight; + } + } else { + // For border/reflection padding, clip corner indices to valid range + // Even though source 
coordinates are clipped, adding 1 can push + // corners out of bounds + const int64_t ix_nw_safe = clip_coordinates(ix_nw, inp_W); + const int64_t iy_nw_safe = clip_coordinates(iy_nw, inp_H); + const int64_t ix_ne_safe = clip_coordinates(ix_ne, inp_W); + const int64_t iy_ne_safe = clip_coordinates(iy_ne, inp_H); + const int64_t ix_sw_safe = clip_coordinates(ix_sw, inp_W); + const int64_t iy_sw_safe = clip_coordinates(iy_sw, inp_H); + const int64_t ix_se_safe = clip_coordinates(ix_se, inp_W); + const int64_t iy_se_safe = clip_coordinates(iy_se, inp_H); + out_val = in_data + [in_channel_offset + iy_nw_safe * in.strides()[2] + + ix_nw_safe * in.strides()[3]] * + nw_weight + + in_data + [in_channel_offset + iy_ne_safe * in.strides()[2] + + ix_ne_safe * in.strides()[3]] * + ne_weight + + in_data + [in_channel_offset + iy_sw_safe * in.strides()[2] + + ix_sw_safe * in.strides()[3]] * + sw_weight + + in_data + [in_channel_offset + iy_se_safe * in.strides()[2] + + ix_se_safe * in.strides()[3]] * + se_weight; + } + + // Write output in NCHW order + const int64_t out_idx = + out_channel_offset + h * out.strides()[2] + w * out.strides()[3]; + out_data[out_idx] = out_val; + } + } + } + } +} + +template +void grid_sample_2d_nearest_kernel_impl_nchw( + const Tensor& in, + const Tensor& grid, + GridSamplerPadding padding_mode, + bool align_corners, + Tensor& out) { + const auto in_data = in.const_data_ptr(); + auto out_data = out.mutable_data_ptr(); + + // Grid has shape [N, H_out, W_out, 2] + // Last dimension contains (x, y) normalized coordinates in [-1, 1] + const auto grid_data = grid.const_data_ptr(); + + const int64_t N = in.size(0); + const int64_t C = in.size(1); + const int64_t inp_H = in.size(2); + const int64_t inp_W = in.size(3); + + const int64_t out_H = out.size(2); + const int64_t out_W = out.size(3); + + // Process each batch + for (const auto n : c10::irange(N)) { + const auto grid_offset = n * grid.strides()[0]; + const auto in_batch_offset = n * 
in.strides()[0]; + const auto out_batch_offset = n * out.strides()[0]; + + // Process each channel + for (const auto c : c10::irange(C)) { + const auto in_channel_offset = in_batch_offset + c * in.strides()[1]; + const auto out_channel_offset = out_batch_offset + c * out.strides()[1]; + + // Process each output pixel + for (const auto h : c10::irange(out_H)) { + for (const auto w : c10::irange(out_W)) { + // Get grid coordinates for this output position + // grid[n, h, w] contains (x, y) + const int64_t grid_idx = + grid_offset + h * grid.strides()[1] + w * grid.strides()[2]; + const CTYPE x = grid_data[grid_idx]; + const CTYPE y = grid_data[grid_idx + grid.strides()[3]]; + + // Compute source coordinates in pixel space + const CTYPE ix = grid_sampler_compute_source_index( + x, inp_W, padding_mode, align_corners); + const CTYPE iy = grid_sampler_compute_source_index( + y, inp_H, padding_mode, align_corners); + + // Get nearest pixel coordinates + // Use nearbyint (not round) to match ATen's rounding behavior. + // nearbyint uses the current rounding mode (typically round-to-even), + // which matches PyTorch's (ATen's) behavior. In contrast, std::round + // always rounds halfway cases away from zero, ignoring the rounding mode. 
See: + // aten/src/ATen/native/GridSampler.cpp + int64_t ix_nearest = static_cast(std::nearbyint(ix)); + int64_t iy_nearest = static_cast(std::nearbyint(iy)); + + // Compute output value for this channel + CTYPE out_val = 0; + + // Check bounds and sample + if (padding_mode == GridSamplerPadding::Zeros) { + // For zeros padding, only sample if within bounds + if (within_bounds_2d(iy_nearest, ix_nearest, inp_H, inp_W)) { + out_val = in_data + [in_channel_offset + iy_nearest * in.strides()[2] + + ix_nearest * in.strides()[3]]; + } + } else { + // For border/reflection padding, clip coordinates after rounding + // Rounding can push coordinates out of bounds even after + // grid_sampler_compute_source_index + int64_t ix_clipped = clip_coordinates(ix_nearest, inp_W); + int64_t iy_clipped = clip_coordinates(iy_nearest, inp_H); + out_val = in_data + [in_channel_offset + iy_clipped * in.strides()[2] + + ix_clipped * in.strides()[3]]; + } + + // Write output in NCHW order + const int64_t out_idx = + out_channel_offset + h * out.strides()[2] + w * out.strides()[3]; + out_data[out_idx] = out_val; + } + } + } + } +} + +template +void grid_sample_2d_bicubic_kernel_impl_nchw( + const Tensor& in, + const Tensor& grid, + GridSamplerPadding padding_mode, + bool align_corners, + Tensor& out) { + const auto in_data = in.const_data_ptr(); + auto out_data = out.mutable_data_ptr(); + + // Grid has shape [N, H_out, W_out, 2] + // Last dimension contains (x, y) normalized coordinates in [-1, 1] + const auto grid_data = grid.const_data_ptr(); + + const int64_t N = in.size(0); + const int64_t C = in.size(1); + const int64_t inp_H = in.size(2); + const int64_t inp_W = in.size(3); + + const int64_t out_H = out.size(2); + const int64_t out_W = out.size(3); + + // Process each batch + for (const auto n : c10::irange(N)) { + const auto grid_offset = n * grid.strides()[0]; + const auto in_batch_offset = n * in.strides()[0]; + const auto out_batch_offset = n * out.strides()[0]; + + // Process each 
channel + for (const auto c : c10::irange(C)) { + const auto in_channel_offset = in_batch_offset + c * in.strides()[1]; + const auto out_channel_offset = out_batch_offset + c * out.strides()[1]; + + // Process each output pixel + for (const auto h : c10::irange(out_H)) { + for (const auto w : c10::irange(out_W)) { + // Get grid coordinates for this output position + // grid[n, h, w] contains (x, y) + const int64_t grid_idx = + grid_offset + h * grid.strides()[1] + w * grid.strides()[2]; + const CTYPE x = grid_data[grid_idx]; + const CTYPE y = grid_data[grid_idx + grid.strides()[3]]; + + // Compute source coordinates in pixel space + // For bicubic, we need raw unnormalized coordinates without padding + // applied. Padding is applied later, when fetching individual pixels + // from the 4x4 neighborhood + CTYPE ix = grid_sampler_unnormalize(x, inp_W, align_corners); + CTYPE iy = grid_sampler_unnormalize(y, inp_H, align_corners); + + // Get the integer part and fractional part + int64_t ix_0 = static_cast(std::floor(ix)); + int64_t iy_0 = static_cast(std::floor(iy)); + CTYPE tx = ix - ix_0; + CTYPE ty = iy - iy_0; + + // Bicubic interpolation uses a 4x4 grid of pixels + // Get the 16 pixel coordinates + int64_t ix_m1 = ix_0 - 1; + int64_t ix_p1 = ix_0 + 1; + int64_t ix_p2 = ix_0 + 2; + + int64_t iy_m1 = iy_0 - 1; + int64_t iy_p1 = iy_0 + 1; + int64_t iy_p2 = iy_0 + 2; + + // Helper lambda to safely get pixel value with bounds checking + // (note: its int64_t iy/ix parameters shadow the CTYPE iy/ix above) + auto get_value_bounded = [&](int64_t iy, int64_t ix) -> CTYPE { + if (padding_mode == GridSamplerPadding::Zeros) { + if (within_bounds_2d(iy, ix, inp_H, inp_W)) { + return in_data + [in_channel_offset + iy * in.strides()[2] + + ix * in.strides()[3]]; + } + return static_cast(0); + } else if (padding_mode == GridSamplerPadding::Border) { + // For border padding, clip coordinates to valid range + int64_t iy_safe = + std::max(static_cast(0), std::min(iy, inp_H - 1)); + int64_t ix_safe = + std::max(static_cast(0), std::min(ix, inp_W - 1)); + 
return in_data + [in_channel_offset + iy_safe * in.strides()[2] + + ix_safe * in.strides()[3]]; + } else { + // For reflection padding, reflect coordinates at boundaries + CTYPE iy_reflected = static_cast(iy); + CTYPE ix_reflected = static_cast(ix); + + if (align_corners) { + iy_reflected = + reflect_coordinates(iy_reflected, 0, 2 * (inp_H - 1)); + ix_reflected = + reflect_coordinates(ix_reflected, 0, 2 * (inp_W - 1)); + } else { + iy_reflected = + reflect_coordinates(iy_reflected, -1, 2 * inp_H - 1); + ix_reflected = + reflect_coordinates(ix_reflected, -1, 2 * inp_W - 1); + } + + // Clip to ensure we're in bounds (reflection + clip for safety) + int64_t iy_safe = + static_cast(clip_coordinates(iy_reflected, inp_H)); + int64_t ix_safe = + static_cast(clip_coordinates(ix_reflected, inp_W)); + + return in_data + [in_channel_offset + iy_safe * in.strides()[2] + + ix_safe * in.strides()[3]]; + } + }; + + // Get the 4x4 grid of pixels + // For each row, interpolate in x-direction + CTYPE coefficients[4]; + + // Row -1 + CTYPE p0 = get_value_bounded(iy_m1, ix_m1); + CTYPE p1 = get_value_bounded(iy_m1, ix_0); + CTYPE p2 = get_value_bounded(iy_m1, ix_p1); + CTYPE p3 = get_value_bounded(iy_m1, ix_p2); + coefficients[0] = cubic_interp1d(p0, p1, p2, p3, tx); + + // Row 0 + p0 = get_value_bounded(iy_0, ix_m1); + p1 = get_value_bounded(iy_0, ix_0); + p2 = get_value_bounded(iy_0, ix_p1); + p3 = get_value_bounded(iy_0, ix_p2); + coefficients[1] = cubic_interp1d(p0, p1, p2, p3, tx); + + // Row +1 + p0 = get_value_bounded(iy_p1, ix_m1); + p1 = get_value_bounded(iy_p1, ix_0); + p2 = get_value_bounded(iy_p1, ix_p1); + p3 = get_value_bounded(iy_p1, ix_p2); + coefficients[2] = cubic_interp1d(p0, p1, p2, p3, tx); + + // Row +2 + p0 = get_value_bounded(iy_p2, ix_m1); + p1 = get_value_bounded(iy_p2, ix_0); + p2 = get_value_bounded(iy_p2, ix_p1); + p3 = get_value_bounded(iy_p2, ix_p2); + coefficients[3] = cubic_interp1d(p0, p1, p2, p3, tx); + + // Interpolate in y-direction + CTYPE out_val 
= cubic_interp1d( + coefficients[0], + coefficients[1], + coefficients[2], + coefficients[3], + ty); + + // Write output in NCHW order + const int64_t out_idx = + out_channel_offset + h * out.strides()[2] + w * out.strides()[3]; + out_data[out_idx] = out_val; + } + } + } + } +} + +} // namespace + +// Signatures are auto-generated, so disable pass-by-value lint. +// NOLINTBEGIN(facebook-hte-ConstantArgumentPassByValue, +// facebook-hte-ParameterMightThrowOnCopy) +Tensor& grid_sampler_2d_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners, + Tensor& out) { + // Check arguments and resize output tensor + ET_KERNEL_CHECK_MSG( + ctx, + check_grid_sampler_2d_args_and_resize_out(input, grid, out) == Error::Ok, + InvalidArgument, + out, + "Failed to validate arguments and resize output tensor"); + + // Convert integer mode parameters to enums + GridSamplerInterpolation mode = + static_cast(interpolation_mode); + GridSamplerPadding padding = static_cast(padding_mode); + + // Validate mode and padding values + ET_KERNEL_CHECK( + ctx, + mode == GridSamplerInterpolation::Bilinear || + mode == GridSamplerInterpolation::Nearest || + mode == GridSamplerInterpolation::Bicubic, + InvalidArgument, + out); + + ET_KERNEL_CHECK( + ctx, + padding == GridSamplerPadding::Zeros || + padding == GridSamplerPadding::Border || + padding == GridSamplerPadding::Reflection, + InvalidArgument, + out); + + // Dispatch to appropriate implementation based on dtype + ET_SWITCH_REALHBF16_TYPES( + input.scalar_type(), ctx, "grid_sampler_2d.out", CTYPE, [&]() { + // Dispatch to appropriate interpolation mode + switch (mode) { + case GridSamplerInterpolation::Bilinear: + grid_sample_2d_bilinear_kernel_impl_nchw( + input, grid, padding, align_corners, out); + break; + case GridSamplerInterpolation::Nearest: + grid_sample_2d_nearest_kernel_impl_nchw( + input, grid, padding, align_corners, out); + break; + 
case GridSamplerInterpolation::Bicubic: + grid_sample_2d_bicubic_kernel_impl_nchw( + input, grid, padding, align_corners, out); + break; + } + }); + + return out; +} +// NOLINTEND(facebook-hte-ConstantArgumentPassByValue, +// facebook-hte-ParameterMightThrowOnCopy) + +} // namespace native +} // namespace executor +} // namespace torch \ No newline at end of file diff --git a/kernels/portable/cpu/util/grid_sampler_2d_util.cpp b/kernels/portable/cpu/util/grid_sampler_2d_util.cpp new file mode 100644 index 00000000000..d8856094e55 --- /dev/null +++ b/kernels/portable/cpu/util/grid_sampler_2d_util.cpp @@ -0,0 +1,79 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace torch { +namespace executor { + +Error check_grid_sampler_2d_args_and_resize_out( + const Tensor& input, + const Tensor& grid, + Tensor& out) { + // Input must be 4D (N, C, H, W) + ET_CHECK_OR_RETURN_ERROR( + input.dim() == 4, + InvalidArgument, + "Input must be 4D, got %zu dimensions", + static_cast(input.dim())); + + ET_CHECK_OR_RETURN_ERROR( + tensor_is_default_dim_order(input), + InvalidArgument, + "Input must be in NCHW format"); + + // Grid must be 4D (N, H_out, W_out, 2) + ET_CHECK_OR_RETURN_ERROR( + grid.dim() == 4, + InvalidArgument, + "Grid must be 4D, got %zu dimensions", + static_cast(grid.dim())); + + ET_CHECK_OR_RETURN_ERROR( + grid.size(3) == 2, + InvalidArgument, + "Grid last dimension must be 2, got %ld", + static_cast(grid.size(3))); + + // Batch sizes must match + ET_CHECK_OR_RETURN_ERROR( + input.size(0) == grid.size(0), + InvalidArgument, + "Input and grid batch sizes must match, got input=%ld, grid=%ld", + static_cast(input.size(0)), + static_cast(grid.size(0))); + + // Input and grid must have same dtype + ET_CHECK_OR_RETURN_ERROR( + tensors_have_same_dtype(input, grid), + 
InvalidArgument, + "Input and grid must have same dtype"); + + // Input and output must have the same dtype + ET_CHECK_OR_RETURN_ERROR( + tensors_have_same_dtype(input, out), + InvalidArgument, + "Input and output must have the same dtype"); + + // Resize output tensor to [N, C, H_out, W_out] + std::array out_sizes = { + static_cast(input.size(0)), + static_cast(input.size(1)), + static_cast(grid.size(1)), + static_cast(grid.size(2))}; + + Error err = resize_tensor(out, {out_sizes.data(), 4}); + ET_CHECK_OR_RETURN_ERROR( + err == Error::Ok, InvalidArgument, "Failed to resize output tensor"); + + return Error::Ok; +} + +} // namespace executor +} // namespace torch \ No newline at end of file diff --git a/kernels/portable/cpu/util/grid_sampler_2d_util.h b/kernels/portable/cpu/util/grid_sampler_2d_util.h new file mode 100644 index 00000000000..bff8923cb52 --- /dev/null +++ b/kernels/portable/cpu/util/grid_sampler_2d_util.h @@ -0,0 +1,160 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +namespace torch { +namespace executor { + +// Ported from aten/src/ATen/native/GridSampler.h +// note that these need to be in the SAME ORDER as the enum in GridSampler.h +// as they are mapped to integer values (0, 1, 2) in this order +enum class GridSamplerInterpolation { Bilinear, Nearest, Bicubic }; +enum class GridSamplerPadding { Zeros, Border, Reflection }; + +// Ported from aten/src/ATen/native/GridSampler.h +// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value, +// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5). 
+// if align_corners: -1 and +1 get sent to the centers of the corner pixels +// -1 --> 0 +// +1 --> (size - 1) +// scale_factor = (size - 1) / 2 +// if not align_corners: -1 and +1 get sent to the image edges +// -1 --> -0.5 +// +1 --> (size - 1) + 0.5 == size - 0.5 +// scale_factor = size / 2 +template +inline scalar_t +grid_sampler_unnormalize(scalar_t coord, int64_t size, bool align_corners) { + if (align_corners) { + // unnormalize coord from [-1, 1] to [0, size - 1] + return ((coord + 1) / 2) * (size - 1); + } else { + // unnormalize coord from [-1, 1] to [-0.5, size - 0.5] + return ((coord + 1) * size - 1) / 2; + } +} + +// Ported from aten/src/ATen/native/GridSampler.h +// Clips coordinates to between 0 and clip_limit - 1 +template +inline scalar_t clip_coordinates(scalar_t in, int64_t clip_limit) { + return std::min( + static_cast(clip_limit - 1), + std::max(in, static_cast(0))); +} + +// Ported from aten/src/ATen/native/GridSampler.h +// Reflects coordinates until they fall between low and high (inclusive). +// The bounds are passed as twice their value so that half-integer values +// can be represented as ints. +template +inline scalar_t +reflect_coordinates(scalar_t in, int64_t twice_low, int64_t twice_high) { + if (twice_low == twice_high) { + return static_cast(0); + } + scalar_t min = static_cast(twice_low) / 2; + scalar_t span = static_cast(twice_high - twice_low) / 2; + in = std::fabs(in - min); + // `fmod` returns same sign as `in`, which is positive after the `fabs` above. 
+ scalar_t extra = std::fmod(in, span); + int flips = static_cast(std::floor(in / span)); + if (flips % 2 == 0) { + return extra + min; + } else { + return span - extra + min; + } +} + +// Ported from aten/src/ATen/native/GridSampler.h +// Computes the pixel source index value for a grid coordinate +template +inline scalar_t grid_sampler_compute_source_index( + scalar_t coord, + int64_t size, + GridSamplerPadding padding_mode, + bool align_corners) { + coord = grid_sampler_unnormalize(coord, size, align_corners); + if (padding_mode == GridSamplerPadding::Border) { + // clip coordinates to image borders + coord = clip_coordinates(coord, size); + } else if (padding_mode == GridSamplerPadding::Reflection) { + // reflect coordinates by image borders + if (align_corners) { + coord = reflect_coordinates(coord, 0, 2 * (size - 1)); + } else { + coord = reflect_coordinates(coord, -1, 2 * size - 1); + } + coord = clip_coordinates(coord, size); + } + return coord; +} + +// Ported from aten/src/ATen/native/GridSampler.h +// Check if coordinates are within bounds [0, limit-1] +template +inline bool within_bounds_2d(scalar_t h, scalar_t w, int64_t H, int64_t W) { + return h >= 0 && h < H && w >= 0 && w < W; +} + +// Ported from aten/src/ATen/native/UpSample.h +// Cubic convolution function 1 (for points within 1 unit of the point) +template +inline scalar_t cubic_convolution1(scalar_t x, scalar_t A) { + return ((A + 2) * x - (A + 3)) * x * x + 1; +} + +// Ported from aten/src/ATen/native/UpSample.h +// Cubic convolution function 2 (for points between 1 and 2 units from the +// point) +template +inline scalar_t cubic_convolution2(scalar_t x, scalar_t A) { + return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; +} + +// Ported from aten/src/ATen/native/UpSample.h +// Computes the 4 cubic interpolation coefficients for a given position t in [0, +// 1] +template +inline void get_cubic_upsample_coefficients(scalar_t coeffs[4], scalar_t t) { + // Standard bicubic interpolation uses alpha 
= -0.75 + scalar_t A = static_cast(-0.75); + + scalar_t x1 = t; + coeffs[0] = cubic_convolution2(x1 + static_cast(1.0), A); + coeffs[1] = cubic_convolution1(x1, A); + + scalar_t x2 = static_cast(1.0) - t; + coeffs[2] = cubic_convolution1(x2, A); + coeffs[3] = cubic_convolution2(x2 + static_cast(1.0), A); +} + +// Ported from aten/src/ATen/native/UpSample.h +// Performs 1D cubic interpolation given 4 points and a position t in [0, 1] +template +inline scalar_t +cubic_interp1d(scalar_t x0, scalar_t x1, scalar_t x2, scalar_t x3, scalar_t t) { + scalar_t coeffs[4]; + get_cubic_upsample_coefficients(coeffs, t); + + return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]; +} + +// Argument checking and output tensor resizing for grid_sampler_2d +Error check_grid_sampler_2d_args_and_resize_out( + const Tensor& input, + const Tensor& grid, + Tensor& out); + +} // namespace executor +} // namespace torch \ No newline at end of file diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl index 84d0712c033..4db9bbafa69 100644 --- a/kernels/portable/cpu/util/targets.bzl +++ b/kernels/portable/cpu/util/targets.bzl @@ -36,6 +36,7 @@ def define_common_targets(): "//executorch/kernels/portable/cpu/util:elementwise_util", "//executorch/kernels/portable/cpu/util:upsample_util", "//executorch/kernels/portable/cpu/util:vectorized_math", + "//executorch/kernels/portable/cpu/util:grid_sampler_2d_util", ], visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"], ) @@ -342,6 +343,16 @@ def define_common_targets(): ], ) + runtime.cxx_library( + name = "grid_sampler_2d_util", + srcs = ["grid_sampler_2d_util.cpp"], + exported_headers = ["grid_sampler_2d_util.h"], + deps = [ + "//executorch/runtime/kernel:kernel_includes", + ], + visibility = ["//executorch/kernels/portable/cpu/..."], + ) + # Utility functions that can be used by operators that perform reduction for aten_mode in get_aten_mode_options(): suffix = "_aten" if aten_mode else "" 
diff --git a/kernels/portable/functions.yaml b/kernels/portable/functions.yaml index cea8a115e1b..07ec35059da 100644 --- a/kernels/portable/functions.yaml +++ b/kernels/portable/functions.yaml @@ -427,6 +427,11 @@ - arg_meta: null kernel_name: torch::executor::glu_out +- op: grid_sampler_2d.out + kernels: + - arg_meta: null + kernel_name: torch::executor::grid_sampler_2d_out + - op: gt.Scalar_out kernels: - arg_meta: null diff --git a/kernels/portable/test/TARGETS b/kernels/portable/test/TARGETS index c42f54075b9..b659d6c093b 100644 --- a/kernels/portable/test/TARGETS +++ b/kernels/portable/test/TARGETS @@ -19,6 +19,7 @@ runtime.cxx_library( ], deps = [ "//executorch/extension/aten_util:aten_bridge", + "//executorch/kernels/portable/cpu:op_grid_sampler_2d", "//executorch/kernels/portable/cpu:op_upsample_bilinear2d", "//executorch/kernels/portable/cpu:op_upsample_bilinear2d_aa", "//executorch/kernels/portable/cpu:op_upsample_nearest2d", diff --git a/kernels/portable/test/op_grid_sampler_2d_test.py b/kernels/portable/test/op_grid_sampler_2d_test.py new file mode 100644 index 00000000000..4791da57d58 --- /dev/null +++ b/kernels/portable/test/op_grid_sampler_2d_test.py @@ -0,0 +1,234 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +import itertools +import unittest + +import torch + + +class GridSampler2dTest(unittest.TestCase): + def run_grid_sampler_test( + self, + inp: torch.Tensor, + grid: torch.Tensor, + mode: str = "bilinear", + padding_mode: str = "zeros", + align_corners: bool = False, + atol: float = 1e-5, + ) -> None: + """Test grid_sampler_2d against PyTorch's reference implementation.""" + # PyTorch reference + aten_result = torch.nn.functional.grid_sample( + inp, + grid, + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + ) + + # Convert mode strings to integers for et_test + mode_map = {"bilinear": 0, "nearest": 1, "bicubic": 2} + padding_map = {"zeros": 0, "border": 1, "reflection": 2} + + # ExecuTorch implementation + et_result = torch.zeros_like(aten_result) + et_result = torch.ops.et_test.grid_sampler_2d( + inp, + grid, + interpolation_mode=mode_map[mode], + padding_mode=padding_map[padding_mode], + align_corners=align_corners, + out=et_result, + ) + + self.assertTrue( + torch.allclose(et_result, aten_result, atol=atol, rtol=1e-5), + msg=f"Mode: {mode}, Padding: {padding_mode}, Align: {align_corners}\n" + f"ET: {et_result}\n" + f"ATen: {aten_result}\n" + f"Error: {(et_result.to(torch.float) - aten_result.to(torch.float)).abs().max()}", + ) + + def test_grid_sampler_2d_all_modes_f32(self): + """Test all combinations of interpolation, padding, and align_corners.""" + N = [1, 2] + C = [1, 3] + H_IN = [4, 8] + W_IN = [4, 8] + H_OUT = [3, 6] + W_OUT = [3, 6] + MODES = ["bilinear", "nearest", "bicubic"] + PADDING_MODES = ["zeros", "border", "reflection"] + ALIGN_CORNERS = [True, False] + + for ( + n, + c, + h_in, + w_in, + h_out, + w_out, + mode, + padding_mode, + align_corners, + ) in itertools.product( + N, C, H_IN, W_IN, H_OUT, W_OUT, MODES, PADDING_MODES, ALIGN_CORNERS + ): + # Create input tensor + input_tensor = torch.randn(n, c, h_in, w_in, dtype=torch.float32) + + # Create a random grid (normally distributed, so not confined to [-1, 1]) + grid = torch.randn(n, 
h_out, w_out, 2, dtype=torch.float32) + # Clamp grid to [-2, 2]; values outside [-1, 1] sample out of bounds + grid = torch.clamp(grid, -2, 2) + + self.run_grid_sampler_test( + input_tensor, + grid, + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + atol=1e-4, # Slightly relaxed tolerance for bicubic + ) + + def test_grid_sampler_2d_bilinear_specific_cases(self): + """Test bilinear mode with specific edge cases.""" + # Test with identity grid (should return same as input) + input_tensor = torch.randn(1, 3, 4, 4) + y = torch.linspace(-1, 1, 4) + x = torch.linspace(-1, 1, 4) + grid_y, grid_x = torch.meshgrid(y, x, indexing="ij") + grid = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) + + self.run_grid_sampler_test( + input_tensor, + grid, + mode="bilinear", + padding_mode="zeros", + align_corners=True, + ) + + def test_grid_sampler_2d_nearest_specific_cases(self): + """Test nearest mode with specific patterns.""" + # Create a checkerboard pattern + input_tensor = torch.zeros(1, 1, 4, 4) + input_tensor[0, 0, ::2, ::2] = 1.0 + input_tensor[0, 0, 1::2, 1::2] = 1.0 + + # Sample at grid points + y = torch.linspace(-1, 1, 6) + x = torch.linspace(-1, 1, 6) + grid_y, grid_x = torch.meshgrid(y, x, indexing="ij") + grid = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) + + self.run_grid_sampler_test( + input_tensor, + grid, + mode="nearest", + padding_mode="zeros", + align_corners=False, + ) + + def test_grid_sampler_2d_padding_modes(self): + """Test different padding modes with out-of-bounds coordinates.""" + input_tensor = torch.randn(1, 2, 5, 5) + + # Create grid with some out-of-bounds coordinates + grid = torch.tensor( + [ + [ + [[-1.5, -1.5], [-0.5, -0.5], [0.5, 0.5], [1.5, 1.5]], + [[-1.0, 1.5], [0.0, 0.0], [1.0, -1.5], [2.0, 2.0]], + ] + ], + dtype=torch.float32, + ) + + for padding_mode in ["zeros", "border", "reflection"]: + for align_corners in [True, False]: + self.run_grid_sampler_test( + input_tensor, + grid, + mode="bilinear", + padding_mode=padding_mode, + 
align_corners=align_corners, + ) + + def test_grid_sampler_2d_bicubic_smoothness(self): + """Test bicubic interpolation for smooth gradients.""" + # Create a smooth gradient + input_tensor = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4) + + # Create a fine grid for upsampling + y = torch.linspace(-1, 1, 7) + x = torch.linspace(-1, 1, 7) + grid_y, grid_x = torch.meshgrid(y, x, indexing="ij") + grid = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) + + self.run_grid_sampler_test( + input_tensor, + grid, + mode="bicubic", + padding_mode="zeros", + align_corners=True, + atol=1e-4, + ) + + def test_grid_sampler_2d_align_corners_comparison(self): + """Compare align_corners=True vs False.""" + input_tensor = torch.randn(1, 1, 8, 8) + + # Create grid at corner positions + grid = torch.tensor( + [ + [ + [[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]], + ] + ], + dtype=torch.float32, + ) + + for mode in ["bilinear", "nearest", "bicubic"]: + # Test with align_corners=True + self.run_grid_sampler_test( + input_tensor, + grid, + mode=mode, + padding_mode="zeros", + align_corners=True, + ) + + # Test with align_corners=False + self.run_grid_sampler_test( + input_tensor, + grid, + mode=mode, + padding_mode="zeros", + align_corners=False, + ) + + def test_grid_sampler_2d_batch_processing(self): + """Test with multiple batches.""" + batch_sizes = [1, 2, 4] + for batch_size in batch_sizes: + input_tensor = torch.randn(batch_size, 3, 6, 6) + grid = torch.randn(batch_size, 4, 4, 2) + grid = torch.clamp(grid, -1.5, 1.5) + + self.run_grid_sampler_test( + input_tensor, + grid, + mode="bilinear", + padding_mode="border", + align_corners=False, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/kernels/portable/test/register_ops_aot_for_test.cpp b/kernels/portable/test/register_ops_aot_for_test.cpp index d13fe9d56ed..cee7a56f460 100644 --- a/kernels/portable/test/register_ops_aot_for_test.cpp +++ b/kernels/portable/test/register_ops_aot_for_test.cpp @@ 
-101,6 +101,35 @@ Tensor& _upsample_bilinear2d_aa_out_no_context( return ret; } + +Tensor& grid_sampler_2d_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners, + Tensor& out); + +Tensor& grid_sampler_2d_out_no_context( + const Tensor& input, + const Tensor& grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners, + Tensor& out) { + KernelRuntimeContext ctx; + auto& ret = grid_sampler_2d_out( + ctx, input, grid, interpolation_mode, padding_mode, align_corners, out); + + if (ctx.failure_state() != Error::Ok) { + throw std::runtime_error( + std::string("Kernel failed with error: ") + + std::to_string((int)ctx.failure_state())); + } + + return ret; +} // NOLINTEND(facebook-hte-ConstantArgumentPassByValue, // facebook-hte-ParameterMightThrowOnCopy) @@ -114,6 +143,9 @@ TORCH_LIBRARY(et_test, m) { m.def( "_upsample_bilinear2d_aa.out(Tensor input, SymInt[] output_size, bool align_corners, float? scale_h, float? scale_w, *, Tensor(a!) out) -> Tensor(a!)", WRAP_TO_ATEN(_upsample_bilinear2d_aa_out_no_context, 5)); + m.def( + "grid_sampler_2d.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) 
out) -> Tensor(a!)", + WRAP_TO_ATEN(grid_sampler_2d_out_no_context, 5)); } } // namespace native diff --git a/kernels/portable/test/targets.bzl b/kernels/portable/test/targets.bzl index 918d2b29fef..c52b59b7b4f 100644 --- a/kernels/portable/test/targets.bzl +++ b/kernels/portable/test/targets.bzl @@ -52,6 +52,19 @@ def define_common_targets(): ], ) + python_unittest( + name = "op_grid_sampler_2d_test", + srcs = [ + "op_grid_sampler_2d_test.py", + ], + preload_deps = [ + ":aot_ops_test_lib", + ], + deps = [ + "//caffe2:torch", + ], + ) + op_test(name = "op_allclose_test") op_test(name = "op_div_test") op_test(name = "op_gelu_test") diff --git a/kernels/portable/test/test_grid_sampler_2d_executorch.py b/kernels/portable/test/test_grid_sampler_2d_executorch.py new file mode 100644 index 00000000000..4b635e14842 --- /dev/null +++ b/kernels/portable/test/test_grid_sampler_2d_executorch.py @@ -0,0 +1,458 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Test grid_sampler_2d by exporting to ExecuTorch and comparing with PyTorch. 
+""" + +import itertools +import sys +import unittest + +import torch +import torch.nn as nn +from executorch.exir import to_edge +from executorch.runtime import Runtime +from torch.export import export + + +class GridSampleModule(nn.Module): + """Wrapper module for grid_sample operation.""" + + def __init__( + self, + mode: str = "bilinear", + padding_mode: str = "zeros", + align_corners: bool = False, + ): + super().__init__() + self.mode = mode + self.padding_mode = padding_mode + self.align_corners = align_corners + + def forward(self, input: torch.Tensor, grid: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.grid_sample( + input, + grid, + mode=self.mode, + padding_mode=self.padding_mode, + align_corners=self.align_corners, + ) + + +class GridSampler2DExecutorchTest(unittest.TestCase): + """Test ExecuTorch grid_sampler_2d implementation.""" + + def run_executorch_test( + self, + input_tensor: torch.Tensor, + grid: torch.Tensor, + mode: str = "bilinear", + padding_mode: str = "zeros", + align_corners: bool = False, + atol: float = 1e-5, + rtol: float = 1e-5, + ) -> None: + """Export to ExecuTorch and compare with PyTorch reference.""" + + # Create module + model = GridSampleModule(mode, padding_mode, align_corners) + model.eval() + + # PyTorch reference + with torch.no_grad(): + pytorch_output = model(input_tensor, grid) + + try: + # Export to ExecuTorch + example_inputs = (input_tensor, grid) + + # Export the model + exported_program = export(model, example_inputs) + + # Convert to edge IR + edge_program = to_edge(exported_program) + + # Get ExecuTorch program + executorch_program = edge_program.to_executorch() + + # Run through ExecuTorch + runtime = Runtime.get() + fwd_method = runtime.load_program(executorch_program.buffer).load_method( + "forward" + ) + if fwd_method is None: + self.fail("Failed to load forward method") + executorch_output = fwd_method.execute((input_tensor, grid))[0] + + # Compare results + self.assertTrue( + 
executorch_output.shape == pytorch_output.shape, + msg=f"Shape mismatch: ET={executorch_output.shape} vs PT={pytorch_output.shape}", + ) + + if not torch.allclose( + executorch_output, pytorch_output, atol=atol, rtol=rtol + ): + max_diff = (executorch_output - pytorch_output).abs().max().item() + mean_diff = (executorch_output - pytorch_output).abs().mean().item() + self.fail( + f"\nMode: {mode}, Padding: {padding_mode}, Align: {align_corners}\n" + f"Max difference: {max_diff:.6e}\n" + f"Mean difference: {mean_diff:.6e}\n" + f"Tolerance (atol): {atol:.6e}\n" + f"ExecuTorch output:\n{executorch_output}\n" + f"PyTorch output:\n{pytorch_output}\n" + ) + + except Exception as e: + self.fail( + f"Failed to export or run model:\n" + f"Mode: {mode}, Padding: {padding_mode}, Align: {align_corners}\n" + f"Error: {str(e)}" + ) + + def test_all_mode_combinations(self): + """Test all combinations of interpolation modes, padding modes, and align_corners.""" + print("\n" + "=" * 70) + print("Testing all mode combinations") + print("=" * 70) + + modes = ["bilinear", "nearest", "bicubic"] + padding_modes = ["zeros", "border", "reflection"] + align_corners_options = [True, False] + + # Test parameters + batch_size = 2 + channels = 3 + height_in = 5 + width_in = 5 + height_out = 4 + width_out = 4 + + test_count = 0 + for mode, padding, align in itertools.product( + modes, padding_modes, align_corners_options + ): + with self.subTest(mode=mode, padding=padding, align=align): + # Create random input + input_tensor = torch.randn( + batch_size, channels, height_in, width_in, dtype=torch.float32 + ) + + # Create grid with some values in [-1, 1] range + grid = torch.randn( + batch_size, height_out, width_out, 2, dtype=torch.float32 + ) + grid = torch.clamp(grid, -1.2, 1.2) # Include some out-of-bounds + + # Bicubic may have slightly larger numerical errors + atol = 1e-4 if mode == "bicubic" else 1e-5 + + self.run_executorch_test( + input_tensor, grid, mode, padding, align, atol=atol + ) + 
test_count += 1 + print(f" ✓ {mode}/{padding}/align={align}") + + print(f"✓ Passed {test_count} mode combination tests") + + def test_batch_sizes(self): + """Test various batch sizes.""" + print("\n" + "=" * 70) + print("Testing various batch sizes") + print("=" * 70) + + batch_sizes = [1, 2, 4] + + for batch_size in batch_sizes: + with self.subTest(batch_size=batch_size): + input_tensor = torch.randn(batch_size, 3, 6, 6, dtype=torch.float32) + grid = torch.randn(batch_size, 4, 4, 2, dtype=torch.float32) + grid = torch.clamp(grid, -1, 1) + + self.run_executorch_test(input_tensor, grid, "bilinear", "zeros", False) + print(f" ✓ batch_size={batch_size}") + + print(f"✓ Passed {len(batch_sizes)} batch size tests") + + def test_input_sizes(self): + """Test various input and output sizes.""" + print("\n" + "=" * 70) + print("Testing various input/output sizes") + print("=" * 70) + + test_cases = [ + # (H_in, W_in, H_out, W_out) + (4, 4, 4, 4), # Same size + (8, 8, 4, 4), # Downsampling + (4, 4, 8, 8), # Upsampling + (10, 5, 7, 3), # Non-square, different aspect ratios + ] + + for h_in, w_in, h_out, w_out in test_cases: + with self.subTest(h_in=h_in, w_in=w_in, h_out=h_out, w_out=w_out): + input_tensor = torch.randn(1, 2, h_in, w_in, dtype=torch.float32) + grid = torch.randn(1, h_out, w_out, 2, dtype=torch.float32) + grid = torch.clamp(grid, -1, 1) + + self.run_executorch_test(input_tensor, grid, "bilinear", "zeros", False) + print(f" ✓ {h_in}x{w_in} -> {h_out}x{w_out}") + + print(f"✓ Passed {len(test_cases)} size variation tests") + + def test_identity_grid(self): + """Test with identity grid (should return approximately same as input).""" + print("\n" + "=" * 70) + print("Testing identity grid") + print("=" * 70) + + input_tensor = torch.randn(1, 3, 4, 4, dtype=torch.float32) + + # Create identity grid + y = torch.linspace(-1, 1, 4) + x = torch.linspace(-1, 1, 4) + grid_y, grid_x = torch.meshgrid(y, x, indexing="ij") + grid = torch.stack([grid_x, grid_y], 
dim=-1).unsqueeze(0) + + for mode in ["bilinear", "nearest", "bicubic"]: + with self.subTest(mode=mode): + atol = 1e-4 if mode == "bicubic" else 1e-5 + self.run_executorch_test( + input_tensor, grid, mode, "zeros", True, atol=atol + ) + print(f" ✓ {mode}") + + print("✓ Passed identity grid tests") + + def test_corner_coordinates(self): + """Test sampling at corner coordinates with different align_corners settings.""" + print("\n" + "=" * 70) + print("Testing corner coordinates") + print("=" * 70) + + input_tensor = torch.randn(1, 1, 8, 8, dtype=torch.float32) + + # Grid sampling at corners + grid = torch.tensor( + [ + [ + [[-1.0, -1.0], [-1.0, 1.0]], + [[1.0, -1.0], [1.0, 1.0]], + ] + ], + dtype=torch.float32, + ) + + for align_corners in [True, False]: + for mode in ["bilinear", "nearest"]: + with self.subTest(align_corners=align_corners, mode=mode): + self.run_executorch_test( + input_tensor, grid, mode, "zeros", align_corners + ) + print(f" ✓ {mode}/align={align_corners}") + + print("✓ Passed corner coordinate tests") + + def test_out_of_bounds(self): + """Test behavior with out-of-bounds coordinates for different padding modes.""" + print("\n" + "=" * 70) + print("Testing out-of-bounds coordinates") + print("=" * 70) + + input_tensor = torch.randn(1, 2, 5, 5, dtype=torch.float32) + + # Grid with out-of-bounds coordinates + grid = torch.tensor( + [ + [ + [[-1.5, -1.5], [-0.5, -0.5], [0.5, 0.5], [1.5, 1.5]], + [[-1.0, 1.5], [0.0, 0.0], [1.0, -1.5], [2.0, 2.0]], + ] + ], + dtype=torch.float32, + ) + + for padding_mode in ["zeros", "border", "reflection"]: + for mode in ["bilinear", "nearest"]: + with self.subTest(padding_mode=padding_mode, mode=mode): + self.run_executorch_test( + input_tensor, grid, mode, padding_mode, False + ) + print(f" ✓ {mode}/{padding_mode}") + + print("✓ Passed out-of-bounds tests") + + def test_single_channel(self): + """Test with single channel input.""" + print("\n" + "=" * 70) + print("Testing single channel input") + print("=" * 70) + 
+ input_tensor = torch.randn(1, 1, 6, 6, dtype=torch.float32) + grid = torch.randn(1, 4, 4, 2, dtype=torch.float32) + grid = torch.clamp(grid, -1, 1) + + for mode in ["bilinear", "nearest", "bicubic"]: + with self.subTest(mode=mode): + atol = 1e-4 if mode == "bicubic" else 1e-5 + self.run_executorch_test( + input_tensor, grid, mode, "zeros", False, atol=atol + ) + print(f" ✓ {mode}") + + print("✓ Passed single channel tests") + + def test_different_dtypes(self): + """Test with different data types (float16, bfloat16).""" + print("\n" + "=" * 70) + print("Testing different dtypes") + print("=" * 70) + + dtypes = [torch.float16, torch.bfloat16] + + for dtype in dtypes: + with self.subTest(dtype=dtype): + input_tensor = torch.randn(1, 2, 4, 4, dtype=dtype) + grid = torch.randn(1, 3, 3, 2, dtype=dtype) + grid = torch.clamp(grid, -1, 1) + + # Use larger tolerance for float16/bfloat16 + atol = 1e-2 if dtype == torch.bfloat16 else 5e-3 + + self.run_executorch_test( + input_tensor, grid, "bilinear", "zeros", False, atol=atol + ) + print(f" ✓ {dtype}") + + print("✓ Passed dtype tests") + + def test_very_small_inputs(self): + """Test with very small input sizes.""" + print("\n" + "=" * 70) + print("Testing very small inputs") + print("=" * 70) + + test_cases = [ + # (H_in, W_in, H_out, W_out, description) + (1, 1, 1, 1, "1x1 input, 1x1 output"), + (1, 1, 2, 2, "1x1 input, 2x2 output"), + (2, 2, 1, 1, "2x2 input, 1x1 output"), + (2, 2, 2, 2, "2x2 input, 2x2 output"), + (2, 2, 3, 3, "2x2 input, 3x3 output"), + (3, 3, 1, 1, "3x3 input, single pixel output"), + ] + + for h_in, w_in, h_out, w_out, desc in test_cases: + with self.subTest(desc=desc): + input_tensor = torch.randn(1, 2, h_in, w_in, dtype=torch.float32) + grid = torch.randn(1, h_out, w_out, 2, dtype=torch.float32) + grid = torch.clamp(grid, -1, 1) + + self.run_executorch_test(input_tensor, grid, "bilinear", "zeros", False) + print(f" ✓ {desc}") + + print("✓ Passed very small input tests") + + def 
test_exact_boundary_coordinates(self): + """Test with grid coordinates exactly at boundaries.""" + print("\n" + "=" * 70) + print("Testing exact boundary coordinates") + print("=" * 70) + + input_tensor = torch.randn(1, 2, 5, 5, dtype=torch.float32) + + # Test grid with exact boundary values + grids = [ + # All corners + torch.tensor( + [[[[-1.0, -1.0], [-1.0, 1.0]], [[1.0, -1.0], [1.0, 1.0]]]], + dtype=torch.float32, + ), + # Center + torch.tensor([[[[0.0, 0.0]]]], dtype=torch.float32), + # Edges + torch.tensor( + [[[[-1.0, 0.0], [1.0, 0.0], [0.0, -1.0], [0.0, 1.0]]]], + dtype=torch.float32, + ), + ] + + for i, grid in enumerate(grids): + for mode in ["bilinear", "nearest", "bicubic"]: + for align_corners in [True, False]: + with self.subTest(grid=i, mode=mode, align_corners=align_corners): + atol = 1e-4 if mode == "bicubic" else 1e-5 + self.run_executorch_test( + input_tensor, grid, mode, "zeros", align_corners, atol=atol + ) + print(f" ✓ grid {i}/{mode}/align={align_corners}") + + print("✓ Passed exact boundary coordinate tests") + + def test_out_of_bounds_values_in_grid(self): + """Test with out of bounds values in grid.""" + print("\n" + "=" * 70) + print("Testing special values in grid") + print("=" * 70) + + input_tensor = torch.randn(1, 2, 4, 4, dtype=torch.float32) + + test_cases = [ + # (grid, description) + ( + torch.tensor([[[[10.0, 10.0], [-10.0, -10.0]]]], dtype=torch.float32), + "Very large coordinates (far out of bounds)", + ), + ( + torch.tensor( + [[[[2.0, 0.0], [0.0, 2.0], [-2.0, 0.0], [0.0, -2.0]]]], + dtype=torch.float32, + ), + "Moderately out of bounds coordinates", + ), + ] + + for grid, desc in test_cases: + with self.subTest(desc=desc): + # Test with zeros padding (most common for out-of-bounds) + self.run_executorch_test(input_tensor, grid, "bilinear", "zeros", False) + print(f" ✓ {desc}") + + print("✓ Passed special value tests") + + +def main(): + """Run the tests.""" + print("\n" + "=" * 70) + print("ExecuTorch grid_sampler_2d Test 
Suite") + print("Testing via model export and ExecuTorch runtime") + print("=" * 70) + + # Run tests with verbose output + suite = unittest.TestLoader().loadTestsFromTestCase(GridSampler2DExecutorchTest) + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(suite) + + # Print summary + print("\n" + "=" * 70) + print("Test Summary") + print("=" * 70) + print(f"Tests run: {result.testsRun}") + print(f"Failures: {len(result.failures)}") + print(f"Errors: {len(result.errors)}") + + if result.wasSuccessful(): + print("\n✓ All tests passed!") + return 0 + else: + print("\n✗ Some tests failed") + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/kernels/test/CMakeLists.txt b/kernels/test/CMakeLists.txt index 2e488b109c1..6e0590e5c62 100644 --- a/kernels/test/CMakeLists.txt +++ b/kernels/test/CMakeLists.txt @@ -185,6 +185,7 @@ set(all_test_sources "op_ge_test.cpp" "op_gelu_test.cpp" "op_glu_test.cpp" + "op_grid_sampler_2d_test.cpp" "op_gt_test.cpp" "op_hardtanh_test.cpp" "op_index_put_test.cpp" diff --git a/kernels/test/op_grid_sampler_2d_test.cpp b/kernels/test/op_grid_sampler_2d_test.cpp new file mode 100644 index 00000000000..9b444ddd917 --- /dev/null +++ b/kernels/test/op_grid_sampler_2d_test.cpp @@ -0,0 +1,517 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include // Declares the operator +#include +#include +#include +#include +#include + +#include + +using executorch::aten::ScalarType; +using executorch::aten::Tensor; +using torch::executor::testing::SupportedFeatures; +using torch::executor::testing::TensorFactory; + +class OpGridSampler2dTest : public OperatorTest { + protected: + Tensor& op_grid_sampler_2d_out( + const Tensor& input, + const Tensor& grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners, + Tensor& out) { + return torch::executor::aten::grid_sampler_2d_outf( + context_, + input, + grid, + interpolation_mode, + padding_mode, + align_corners, + out); + } + + template + std::enable_if_t, void> + test_grid_sampler_2d_dtype() { + TensorFactory tf; + + // Simple test: 2x2 input, identity-like grid + const auto input = tf.make({1, 1, 2, 2}, {1, 2, 3, 4}); + const auto grid = tf.make( + {1, 2, 2, 2}, + { + -0.5, + -0.5, // Top-left quadrant + 0.5, + -0.5, // Top-right quadrant + -0.5, + 0.5, // Bottom-left quadrant + 0.5, + 0.5 // Bottom-right quadrant + }); + auto out = tf.zeros({1, 1, 2, 2}); + + op_grid_sampler_2d_out( + input, + grid, + 0, // bilinear + 0, // zeros padding + false, + out); + + // Output should be close to input for this nearly-identity grid + EXPECT_TENSOR_CLOSE(out, input); + } + + template + std::enable_if_t, void> + test_grid_sampler_2d_dtype() { + // not supported + return; + } +}; + +// +// Bilinear interpolation tests +// + +TEST_F(OpGridSampler2dTest, BilinearSimple) { + TensorFactory tf; + + // 2x2 input, sample at exact pixel centers + const auto input = tf.make({1, 1, 2, 2}, {1.0, 2.0, 3.0, 4.0}); + + // Grid: sample at pixel centers in normalized coords [-1, 1] + // For 2x2 with align_corners=false: + // pixel (0,0) is at normalized (-0.5, -0.5) + // pixel (1,1) is at normalized (0.5, 0.5) + const auto grid = tf.make( + {1, 2, 2, 2}, + { + -0.5, + -0.5, // Sample pixel (0,0) -> 1.0 + 0.5, + -0.5, // Sample pixel (1,0) -> 2.0 + -0.5, + 
0.5, // Sample pixel (0,1) -> 3.0 + 0.5, + 0.5 // Sample pixel (1,1) -> 4.0 + }); + auto out = tf.zeros({1, 1, 2, 2}); + + op_grid_sampler_2d_out( + input, + grid, + 0, // bilinear + 0, // zeros padding + false, + out); + + const auto expected = tf.make({1, 1, 2, 2}, {1.0, 2.0, 3.0, 4.0}); + EXPECT_TENSOR_CLOSE(out, expected); +} + +TEST_F(OpGridSampler2dTest, BilinearInterpolation) { + TensorFactory tf; + + // 2x2 input + const auto input = tf.make({1, 1, 2, 2}, {1.0, 2.0, 3.0, 4.0}); + + // Sample at center of image (should be average of all pixels) + const auto grid = tf.make({1, 1, 1, 2}, {0.0, 0.0}); + auto out = tf.zeros({1, 1, 1, 1}); + + op_grid_sampler_2d_out( + input, + grid, + 0, // bilinear + 0, // zeros padding + false, + out); + + // Center should be close to 2.5 (average of 1,2,3,4) + const auto expected = tf.make({1, 1, 1, 1}, {2.5}); + EXPECT_TENSOR_CLOSE(out, expected); +} + +TEST_F(OpGridSampler2dTest, BilinearAlignCorners) { + TensorFactory tf; + + const auto input = tf.make({1, 1, 2, 2}, {1.0, 2.0, 3.0, 4.0}); + + // With align_corners=true, corners map exactly to pixel centers + const auto grid = tf.make( + {1, 2, 2, 2}, + { + -1.0, + -1.0, // Top-left corner -> pixel (0,0) -> 1.0 + 1.0, + -1.0, // Top-right corner -> pixel (1,0) -> 2.0 + -1.0, + 1.0, // Bottom-left corner -> pixel (0,1) -> 3.0 + 1.0, + 1.0 // Bottom-right corner -> pixel (1,1) -> 4.0 + }); + auto out = tf.zeros({1, 1, 2, 2}); + + op_grid_sampler_2d_out( + input, + grid, + 0, // bilinear + 0, // zeros padding + true, // align_corners + out); + + const auto expected = tf.make({1, 1, 2, 2}, {1.0, 2.0, 3.0, 4.0}); + EXPECT_TENSOR_EQ(out, expected); +} + +// +// Nearest neighbor tests +// + +TEST_F(OpGridSampler2dTest, NearestSimple) { + TensorFactory tf; + + const auto input = tf.make({1, 1, 2, 2}, {1.0, 2.0, 3.0, 4.0}); + + // Sample near pixel centers (should snap to nearest pixel) + const auto grid = tf.make( + {1, 2, 2, 2}, + { + -0.6, + -0.6, // Near (0,0) -> 1.0 + 0.4, + 
-0.4, // Near (1,0) -> 2.0 + -0.3, + 0.3, // Near (0,1) -> 3.0 + 0.6, + 0.6 // Near (1,1) -> 4.0 + }); + auto out = tf.zeros({1, 1, 2, 2}); + + op_grid_sampler_2d_out( + input, + grid, + 1, // nearest + 0, // zeros padding + false, + out); + + const auto expected = tf.make({1, 1, 2, 2}, {1.0, 2.0, 3.0, 4.0}); + EXPECT_TENSOR_EQ(out, expected); +} + +// +// Bicubic interpolation tests +// + +TEST_F(OpGridSampler2dTest, BicubicSimple) { + TensorFactory tf; + + // Larger input for bicubic (needs 4x4 neighborhood) + const auto input = tf.make( + {1, 1, 4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + + // Sample at center + const auto grid = tf.make({1, 1, 1, 2}, {0.0, 0.0}); + auto out = tf.zeros({1, 1, 1, 1}); + + op_grid_sampler_2d_out( + input, + grid, + 2, // bicubic + 0, // zeros padding + false, + out); + + // Bicubic at center should be close to 8.5 (average of middle pixels) + // Note: The tolerance of 0.5 is intentionally large because the expected + // value (8.5) is a rough estimate (average of the middle pixels), not the + // exact bicubic interpolation result. Bicubic interpolation can produce + // values that differ from this average due to its mathematical properties. 
+ const auto expected = tf.make({1, 1, 1, 1}, {8.5}); + EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 0, 0.5); +} + +// +// Padding mode tests +// + +TEST_F(OpGridSampler2dTest, ZerosPaddingOutOfBounds) { + TensorFactory tf; + + const auto input = tf.make({1, 1, 2, 2}, {1.0, 2.0, 3.0, 4.0}); + + // Sample way outside the image bounds + const auto grid = tf.make( + {1, 2, 2, 2}, + { + -2.0, + -2.0, // Far outside + 2.0, + 2.0, // Far outside + -0.5, + -0.5, // Inside + 0.5, + 0.5 // Inside + }); + auto out = tf.zeros({1, 1, 2, 2}); + + op_grid_sampler_2d_out( + input, + grid, + 0, // bilinear + 0, // zeros padding + false, + out); + + // Out-of-bounds samples should be 0, in-bounds samples should match + const auto expected = tf.make({1, 1, 2, 2}, {0.0, 0.0, 1.0, 4.0}); + EXPECT_TENSOR_CLOSE(out, expected); +} + +TEST_F(OpGridSampler2dTest, BorderPaddingOutOfBounds) { + TensorFactory tf; + + const auto input = tf.make({1, 1, 2, 2}, {1.0, 2.0, 3.0, 4.0}); + + // Sample outside bounds + const auto grid = tf.make( + {1, 1, 2, 2}, + { + -2.0, + -2.0, // Should clamp to top-left pixel -> 1.0 + 2.0, + 2.0 // Should clamp to bottom-right pixel -> 4.0 + }); + auto out = tf.zeros({1, 1, 1, 2}); + + op_grid_sampler_2d_out( + input, + grid, + 0, // bilinear + 1, // border padding + false, + out); + + const auto expected = tf.make({1, 1, 1, 2}, {1.0, 4.0}); + EXPECT_TENSOR_CLOSE(out, expected); +} + +TEST_F(OpGridSampler2dTest, ReflectionPadding) { + TensorFactory tf; + + const auto input = tf.make({1, 1, 3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + + // Sample with reflection padding + const auto grid = tf.make({1, 1, 1, 2}, {0.0, 0.0}); + auto out = tf.zeros({1, 1, 1, 1}); + + op_grid_sampler_2d_out( + input, + grid, + 0, // bilinear + 2, // reflection padding + false, + out); + + // Center pixel should be 5 + const auto expected = tf.make({1, 1, 1, 1}, {5.0}); + EXPECT_TENSOR_CLOSE(out, expected); +} + +// +// Multi-channel and batch tests +// + +TEST_F(OpGridSampler2dTest, 
MultiChannel) { + TensorFactory tf; + + // 2 channels + const auto input = tf.make( + {1, 2, 2, 2}, + {1, + 2, // Channel 0 + 3, + 4, + 5, + 6, // Channel 1 + 7, + 8}); + + const auto grid = tf.make({1, 1, 1, 2}, {0.0, 0.0}); + auto out = tf.zeros({1, 2, 1, 1}); + + op_grid_sampler_2d_out( + input, + grid, + 0, // bilinear + 0, // zeros padding + false, + out); + + // Each channel should average its 4 pixels + const auto expected = tf.make({1, 2, 1, 1}, {2.5, 6.5}); + EXPECT_TENSOR_CLOSE(out, expected); +} + +TEST_F(OpGridSampler2dTest, MultiBatch) { + TensorFactory tf; + + // 2 batches + const auto input = tf.make( + {2, 1, 2, 2}, + {1, + 2, // Batch 0 + 3, + 4, + 5, + 6, // Batch 1 + 7, + 8}); + + const auto grid = tf.make( + {2, 1, 1, 2}, + { + 0.0, + 0.0, // Batch 0 samples center + 0.0, + 0.0 // Batch 1 samples center + }); + auto out = tf.zeros({2, 1, 1, 1}); + + op_grid_sampler_2d_out( + input, + grid, + 0, // bilinear + 0, // zeros padding + false, + out); + + // Each batch averages its 4 pixels + const auto expected = tf.make({2, 1, 1, 1}, {2.5, 6.5}); + EXPECT_TENSOR_CLOSE(out, expected); +} + +// +// Dtype tests +// + +TEST_F(OpGridSampler2dTest, DType) { +#define TEST_ENTRY(ctype, dtype) \ + test_grid_sampler_2d_dtype(); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); +#undef TEST_ENTRY +} + +// +// Error case tests +// + +TEST_F(OpGridSampler2dTest, InvalidInputRankDies) { + TensorFactory tf; + + // Input must be 4D + const auto input = tf.ones({1, 2, 2}); + const auto grid = tf.make({1, 1, 1, 2}, {0.0, 0.0}); + auto out = tf.zeros({1, 1, 1, 1}); + + ET_EXPECT_KERNEL_FAILURE( + context_, op_grid_sampler_2d_out(input, grid, 0, 0, false, out)); +} + +TEST_F(OpGridSampler2dTest, InvalidGridRankDies) { + TensorFactory tf; + + const auto input = tf.ones({1, 1, 2, 2}); + // Grid must be 4D + const auto grid = tf.make({1, 1, 2}, {0.0, 0.0}); + auto out = tf.zeros({1, 1, 1, 1}); + + ET_EXPECT_KERNEL_FAILURE( + context_, op_grid_sampler_2d_out(input, grid, 0, 0, 
false, out)); +} + +TEST_F(OpGridSampler2dTest, GridLastDimMustBe2Dies) { + TensorFactory tf; + + const auto input = tf.ones({1, 1, 2, 2}); + // Grid's last dimension must be 2 (x, y coordinates) + const auto grid = tf.ones({1, 1, 1, 3}); + auto out = tf.zeros({1, 1, 1, 1}); + + ET_EXPECT_KERNEL_FAILURE( + context_, op_grid_sampler_2d_out(input, grid, 0, 0, false, out)); +} + +TEST_F(OpGridSampler2dTest, BatchSizeMismatchDies) { + TensorFactory tf; + + // Batch size must match between input and grid + const auto input = tf.ones({1, 1, 2, 2}); + const auto grid = tf.make({2, 1, 1, 2}, {0.0, 0.0, 0.0, 0.0}); + auto out = tf.zeros({1, 1, 1, 1}); + + ET_EXPECT_KERNEL_FAILURE( + context_, op_grid_sampler_2d_out(input, grid, 0, 0, false, out)); +} + +TEST_F(OpGridSampler2dTest, MismatchedDTypeDies) { + TensorFactory tf; + TensorFactory tf_long; + + const auto input = tf.ones({1, 1, 2, 2}); + const auto grid = tf.make({1, 1, 1, 2}, {0.0, 0.0}); + // Output dtype must match input dtype + auto out = tf_long.zeros({1, 1, 1, 1}); + + ET_EXPECT_KERNEL_FAILURE( + context_, op_grid_sampler_2d_out(input, grid, 0, 0, false, out)); +} + +TEST_F(OpGridSampler2dTest, GridDTypeMismatchDies) { + TensorFactory tf; + TensorFactory tf_double; + + const auto input = tf.ones({1, 1, 2, 2}); + // Grid dtype must match input dtype + const auto grid = tf_double.make({1, 1, 1, 2}, {0.0, 0.0}); + auto out = tf.zeros({1, 1, 1, 1}); + + ET_EXPECT_KERNEL_FAILURE( + context_, op_grid_sampler_2d_out(input, grid, 0, 0, false, out)); +} + +TEST_F(OpGridSampler2dTest, InvalidInterpolationModeDies) { + TensorFactory tf; + + const auto input = tf.ones({1, 1, 2, 2}); + const auto grid = tf.make({1, 1, 1, 2}, {0.0, 0.0}); + auto out = tf.zeros({1, 1, 1, 1}); + + // Invalid interpolation mode (valid: 0=bilinear, 1=nearest, 2=bicubic) + ET_EXPECT_KERNEL_FAILURE( + context_, op_grid_sampler_2d_out(input, grid, 3, 0, false, out)); +} + +TEST_F(OpGridSampler2dTest, InvalidPaddingModeDies) { + TensorFactory tf; + 
+ const auto input = tf.ones({1, 1, 2, 2}); + const auto grid = tf.make({1, 1, 1, 2}, {0.0, 0.0}); + auto out = tf.zeros({1, 1, 1, 1}); + + // Invalid padding mode (valid: 0=zeros, 1=border, 2=reflection) + ET_EXPECT_KERNEL_FAILURE( + context_, op_grid_sampler_2d_out(input, grid, 0, 3, false, out)); +} \ No newline at end of file diff --git a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl index c42be80010b..03c3329a0bf 100644 --- a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl +++ b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl @@ -628,6 +628,13 @@ ATEN_OPS = ( "//executorch/runtime/core/exec_aten/util:tensor_util", ], ), + op_target( + name = "op_grid_sampler_2d", + deps = [ + "//executorch/kernels/portable/cpu/util:grid_sampler_2d_util", + "//executorch/runtime/core/exec_aten/util:tensor_util", + ], + ), op_target( name = "op_gt", deps = [