From f730f2597ecb1be943448e92552f4daa08c2040f Mon Sep 17 00:00:00 2001 From: Cheng Chang Date: Tue, 3 Nov 2020 14:44:33 -0800 Subject: [PATCH 1/5] [NNC] Implement Cond in LLVM codegen (#47256) Summary: Generate LLVM IR for statements such as ``` if (...) { .... } else { .... } ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/47256 Test Plan: added unit tests to test_llvm.cpp Reviewed By: nickgg Differential Revision: D24699080 Pulled By: cheng-chang fbshipit-source-id: 83b0cebcd242828263eb6052483f0924b5f091ce --- test/cpp/tensorexpr/test_llvm.cpp | 100 +++++++++++++++++++++ test/cpp/tensorexpr/tests.h | 3 + torch/csrc/jit/tensorexpr/llvm_codegen.cpp | 40 ++++++++- 3 files changed, 142 insertions(+), 1 deletion(-) diff --git a/test/cpp/tensorexpr/test_llvm.cpp b/test/cpp/tensorexpr/test_llvm.cpp index 7f4e1a0afc24..61bdf9d9b974 100644 --- a/test/cpp/tensorexpr/test_llvm.cpp +++ b/test/cpp/tensorexpr/test_llvm.cpp @@ -278,6 +278,106 @@ void testLLVMIfThenElseTest() { ASSERT_EQ(b_buffer[0], 42); } +// if (x < 10) x = x + 1 +void testLLVMCondNoFalseBlockTest() { + KernelScope kernel_scope; + + Placeholder x(BufHandle("X", {1}, kInt)); + auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); + auto cond = Cond::make(cmp, x.store({0}, x.load(0) + 1), nullptr); + + for (int32_t x_value : {0, 10, 20}) { + std::vector x_buffer = {x_value}; + std::vector args({x_buffer.data()}); + LLVMCodeGen cg(cond, {x}); + ASSERT_EQ(cg.value(args), 0); + if (x_value < 10) { + ASSERT_EQ(x_buffer[0], x_value + 1); + } else { + ASSERT_EQ(x_buffer[0], x_value); + } + } +} + +// if (x < 10) { +// x = x + 1; +// } else { +// x = x - 1; +// } +void testLLVMCondTest() { + KernelScope kernel_scope; + + Placeholder x(BufHandle("X", {1}, kInt)); + auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); + auto cond = + Cond::make(cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1)); + auto block = Block::make({ + cond, + x.store({0}, x.load(0) * 2), + }); + + for (int32_t x_value : {0, 10, 20}) { + std::vector x_buffer = {x_value}; + std::vector args({x_buffer.data()}); + LLVMCodeGen cg(block, {x}); + ASSERT_EQ(cg.value(args), 0); + if (x_value < 10) { + ASSERT_EQ(x_buffer[0], (x_value + 1) * 2); + } else { + ASSERT_EQ(x_buffer[0], (x_value - 1) * 2); + } + } +} + +// if (x < 10) { +// if (x > 5) { +// x = x + 1; +// } else { +// x = x - 1; +// } +// } else { +// if (x <= 15) { +// x = x + 2; +// } else { +// x = x - 2; +// } +// } +void testLLVMCondNestedTest() { + KernelScope kernel_scope; + + Placeholder x(BufHandle("X", {1}, kInt)); + auto true_cmp = + CompareSelect::make(x.load(0), 5, CompareSelectOperation::kGT); + auto true_cond = Cond::make( + true_cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1)); + auto false_cmp = + CompareSelect::make(x.load(0), 15, CompareSelectOperation::kLE); + auto false_cond = Cond::make( + false_cmp, x.store({0}, x.load(0) + 2), x.store({0}, x.load(0) - 2)); + auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); + auto cond = Cond::make(cmp, true_cond, false_cond); + + for (int32_t x_value : {0, 8, 15, 20}) { + std::vector x_buffer = {x_value}; + std::vector args({x_buffer.data()}); + LLVMCodeGen cg(cond, {x}); + ASSERT_EQ(cg.value(args), 0); + if (x_value < 10) { + if (x_value > 5) { + ASSERT_EQ(x_buffer[0], x_value + 1); + } else { + ASSERT_EQ(x_buffer[0], x_value - 1); + } + } else { + if (x_value <= 15) { + ASSERT_EQ(x_buffer[0], x_value + 2); + } else { + ASSERT_EQ(x_buffer[0], x_value - 2); + } + } + } +} + void testLLVMVecLoadStoreTest() { KernelScope kernel_scope; Placeholder a(BufHandle("A", {1}, kInt)); diff --git a/test/cpp/tensorexpr/tests.h b/test/cpp/tensorexpr/tests.h index 6f7b50790477..72fedbc2a95a 100644 --- a/test/cpp/tensorexpr/tests.h +++ b/test/cpp/tensorexpr/tests.h @@ -506,6 +506,9 @@ namespace jit { _(LLVMEmptyStmt) \ _(LLVMEliminatedStmt) \ _(LLVMIfThenElseTest) \ + _(LLVMCondNoFalseBlockTest) \ + _(LLVMCondTest) \ + _(LLVMCondNestedTest) \ _(LLVMVectorizerLoadStoreTest) \ _(LLVMSimpleReduction) \ _(LLVMRFactorReduction) diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index 0400b6f14143..e692237b7c6f 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -1689,7 +1689,45 @@ void LLVMCodeGenImpl::visit(const Let* v) { } void LLVMCodeGenImpl::visit(const Cond* v) { - throw unimplemented_lowering(v); + // Even if true_stmt and false_stmt are nullptr, + // in case condition is a function call with side effect, + // we still evaluate it. + v->condition()->accept(this); + + if (!v->true_stmt() && !v->false_stmt()) { + return; + } + assert(v->true_stmt()); + + llvm::Value* condition = value_; + llvm::Value* c = irb_.CreateICmpNE( + condition, llvm::ConstantInt::get(condition->getType(), 0)); + llvm::BasicBlock* then_block = + llvm::BasicBlock::Create(getContext(), "then", fn_); + llvm::BasicBlock* else_block = nullptr; + if (v->false_stmt()) { + else_block = llvm::BasicBlock::Create(getContext(), "else", fn_); + } + llvm::BasicBlock* end_block = + llvm::BasicBlock::Create(getContext(), "end", fn_); + + if (else_block) { + irb_.CreateCondBr(c, then_block, else_block); + } else { + irb_.CreateCondBr(c, then_block, end_block); + } + + irb_.SetInsertPoint(then_block); + v->true_stmt()->accept(this); + irb_.CreateBr(end_block); + + if (else_block) { + irb_.SetInsertPoint(else_block); + v->false_stmt()->accept(this); + irb_.CreateBr(end_block); + } + + irb_.SetInsertPoint(end_block); } void LLVMCodeGenImpl::optimize(llvm::Module& M) { From 2b5433dee615cba787e5bd1a0c6c61b4c8cd6cd0 Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Tue, 3 Nov 2020 15:17:04 -0800 Subject: [PATCH 2/5] [Pytorch][Annotation] Update inlined callstack with module instance info (#46729) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46729 Test Plan: Imported from OSS Reviewed By: ZolotukhinM Differential Revision: D24493220 Pulled By: cccclai fbshipit-source-id: f37834157e6f69bbe87f73a7d3d38a94ece6017d --- torch/csrc/jit/ir/ir.cpp | 20 ++++++++++++++++-- torch/csrc/jit/ir/scope.cpp | 24 +++++++++++++++++++++ torch/csrc/jit/ir/scope.h | 42 +++++++++++++++++++++++++++++++++++++ 3 files changed, 84 insertions(+), 2 deletions(-) diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index e0b7e15556eb..a85f1ef69b1e 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -1897,6 +1897,19 @@ std::vector inlineCallTo( std::unordered_map new_callstack_entries; + c10::optional module_instance_info = c10::nullopt; + if (to_replace->kind() == prim::CallMethod) { + auto class_type_ptr = to_replace->input(0)->type()->cast(); + if (to_replace->input(0)->node()->kind() == prim::GetAttr) { + module_instance_info = c10::make_optional(ModuleInstanceInfo( + class_type_ptr, to_replace->input(0)->node()->s(attr::name))); + } else { + std::string instance_name_unknown("INSTANCE_NAME_UNKNOWN"); + module_instance_info = c10::make_optional( + ModuleInstanceInfo(class_type_ptr, instance_name_unknown)); + } + } + // TODO: We might need to use nodes_map instead of value_map. Otherwise, we // are missing nodes without outputs (e.g. prim::Print). std::unordered_set updated_nodes; @@ -1915,11 +1928,14 @@ std::vector inlineCallTo( if (new_node_cs) { new_callstack_entries[raw_callstack_ptr] = c10::make_intrusive( - *new_node_cs, callee, to_replace->sourceRange()); + *new_node_cs, + callee, + to_replace->sourceRange(), + module_instance_info); } else { new_callstack_entries[raw_callstack_ptr] = c10::make_intrusive( - callee, to_replace->sourceRange()); + callee, to_replace->sourceRange(), module_instance_info); } } new_node->setCallStack(new_callstack_entries.at(raw_callstack_ptr)); diff --git a/torch/csrc/jit/ir/scope.cpp b/torch/csrc/jit/ir/scope.cpp index 900722427225..3901ce1038bf 100644 --- a/torch/csrc/jit/ir/scope.cpp +++ b/torch/csrc/jit/ir/scope.cpp @@ -89,6 +89,14 @@ InlinedCallStackPtr InlinedCallStack::intrusive_from_this() { InlinedCallStack::InlinedCallStack(Function* fn, SourceRange source_range) : fn_(fn), source_range_(std::move(source_range)) {} +InlinedCallStack::InlinedCallStack( + Function* fn, + SourceRange source_range, + c10::optional module_instance_info) + : fn_(fn), + source_range_(std::move(source_range)), + module_instance_info_(std::move(module_instance_info)) {} + InlinedCallStack::InlinedCallStack( InlinedCallStackPtr callee, Function* fn, @@ -97,6 +105,16 @@ InlinedCallStack::InlinedCallStack( fn_(fn), source_range_(std::move(source_range)) {} +InlinedCallStack::InlinedCallStack( + InlinedCallStackPtr callee, + Function* fn, + SourceRange source_range, + c10::optional module_instance_info) + : callee_(std::move(callee)), + fn_(fn), + source_range_(std::move(source_range)), + module_instance_info_(std::move(module_instance_info)) {} + c10::optional InlinedCallStack::callee() const { return callee_; } @@ -110,5 +128,11 @@ std::vector InlinedCallStack::vec() { } return r; } + +ModuleInstanceInfo::ModuleInstanceInfo( + c10::ClassTypePtr module_type, + std::string instance_name) + : module_type_(std::move(module_type)), + instance_name_(std::move(instance_name)) {} } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/ir/scope.h b/torch/csrc/jit/ir/scope.h index d75f3e060f36..784c2942c263 100644 --- a/torch/csrc/jit/ir/scope.h +++ b/torch/csrc/jit/ir/scope.h @@ -1,5 +1,6 @@ #pragma once #include +#include #include #include #include @@ -51,6 +52,32 @@ struct TORCH_API Scope : public c10::intrusive_ptr_target { struct Function; struct InlinedCallStack; +/** + * ModuleInstanceInfo is a structure to include the module type and instance + * name. It also provide public methods to get the pointer to module type and + * instance name. + * + * This structure is mainly used as a private member in InlinedCallStack, such + * that one can follow the callstack to find the relevant module hierarchy. + */ +struct ModuleInstanceInfo { + private: + c10::ClassTypePtr module_type_{nullptr}; + std::string instance_name_; + + public: + ModuleInstanceInfo(c10::ClassTypePtr module_type, std::string instance_name); + c10::ClassTypePtr class_type() { + return module_type_; + } + c10::ClassTypePtr class_type() const { + return module_type_; + } + std::string instance_name() const { + return instance_name_; + } +}; + /** * InlinedCallStack is an element in a list representing callstack of functions * that have been inlined. @@ -81,6 +108,8 @@ struct InlinedCallStack; */ using InlinedCallStackPtr = c10::intrusive_ptr; using InlinedCallStackEntry = std::pair; +using InlinedCallStackWithModuleInfo = + std::tuple>; struct TORCH_API InlinedCallStack : public c10::intrusive_ptr_target { private: @@ -88,17 +117,30 @@ struct TORCH_API InlinedCallStack : public c10::intrusive_ptr_target { Function* fn_; SourceRange source_range_; InlinedCallStackPtr intrusive_from_this(); + c10::optional module_instance_info_; public: // Constructor for a leaf callstack node. InlinedCallStack(Function* fn, SourceRange source_range); + // Constructor for a leaf callstack node. + InlinedCallStack( + Function* fn, + SourceRange source_range, + c10::optional module_instance_info); + // Constructor for an inner callstack node. InlinedCallStack( InlinedCallStackPtr callee, Function* fn, SourceRange source_range); + InlinedCallStack( + InlinedCallStackPtr callee, + Function* fn, + SourceRange source_range, + c10::optional module_instance_info); + // Return next element in the callstack list. c10::optional callee() const; From 63978556fd0127dbf5c0e07628015ae54117898c Mon Sep 17 00:00:00 2001 From: Kshiteej K Date: Tue, 3 Nov 2020 15:22:30 -0800 Subject: [PATCH 3/5] [numpy] `torch.a{cosh, sinh}` : promote integer inputs to float (#47152) Summary: Reference https://github.com/pytorch/pytorch/issues/42515 Pull Request resolved: https://github.com/pytorch/pytorch/pull/47152 Reviewed By: mrshenli Differential Revision: D24681083 Pulled By: mruberry fbshipit-source-id: 246e2272536cf912a2575bfaaa831c3eceec034c --- aten/src/ATen/native/UnaryOps.cpp | 8 ++++---- aten/src/ATen/native/cuda/UnaryGeometricKernels.cu | 4 ++-- .../testing/_internal/common_methods_invocations.py | 12 ++++++++---- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index 8c3fc182e646..7d57651005d5 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -351,8 +351,8 @@ Tensor& cosh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out( Tensor cosh(const Tensor& self) { return unary_op_impl(self, at::cosh_out); } Tensor& cosh_(Tensor& self) { return unary_op_impl_(self, at::cosh_out); } -Tensor& acosh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, acosh_stub); } -Tensor acosh(const Tensor& self) { return unary_op_impl(self, at::acosh_out); } +Tensor& acosh_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, acosh_stub); } +Tensor acosh(const Tensor& self) { return unary_op_impl_float(self, acosh_stub); } Tensor& acosh_(Tensor& self) { return unary_op_impl_(self, at::acosh_out); } // arccosh, alias for acosh @@ -360,8 +360,8 @@ Tensor& arccosh_out(Tensor& result, const Tensor& self) { return at::acosh_out(r Tensor arccosh(const Tensor& self) { return at::acosh(self); } Tensor& arccosh_(Tensor& self) { return at::acosh_(self); } -Tensor& asinh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, asinh_stub); } -Tensor asinh(const Tensor& self) { return unary_op_impl(self, at::asinh_out); } +Tensor& asinh_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, asinh_stub); } +Tensor asinh(const Tensor& self) { return unary_op_impl_float(self, asinh_stub); } Tensor& asinh_(Tensor& self) { return unary_op_impl_(self, at::asinh_out); } // arcsinh, alias for asinh diff --git a/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu b/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu index 6bf0bdc3ea89..bc64b7eb2268 100644 --- a/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu +++ b/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu @@ -75,7 +75,7 @@ void tanh_kernel_cuda(TensorIterator& iter) { } void acosh_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "acosh_cuda", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "acosh_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::acosh(a); }); @@ -83,7 +83,7 @@ void acosh_kernel_cuda(TensorIterator& iter) { } void asinh_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "asinh_cuda", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "asinh_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::asinh(a); }); diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index b588ded49494..277124466101 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -237,8 +237,10 @@ def sample_inputs(self, device, dtype, requires_grad=False): UnaryUfuncInfo('acosh', ref=np.arccosh, domain=(1, float('inf')), - dtypesIfCPU=floating_types(), - dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + dtypes=all_types_and(torch.bool), + dtypesIfCPU=all_types_and(torch.bool), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + promotes_integers_to_float=True, decorators=(precisionOverride({torch.bfloat16: 5e-2}),), test_inplace_grad=False), UnaryUfuncInfo('asin', @@ -255,8 +257,10 @@ def sample_inputs(self, device, dtype, requires_grad=False): # NOTE: derivative for inplace asinh is not implemented UnaryUfuncInfo('asinh', ref=np.arcsinh, - dtypesIfCPU=floating_types(), - dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + dtypes=all_types_and(torch.bool), + dtypesIfCPU=all_types_and(torch.bool), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + promotes_integers_to_float=True, decorators=(precisionOverride({torch.bfloat16: 5e-2}),), test_inplace_grad=False), UnaryUfuncInfo('atan', From f41f3e3cd16a51e19441d1d51ddbee53771d4c6c Mon Sep 17 00:00:00 2001 From: pomelyu Date: Tue, 3 Nov 2020 15:27:20 -0800 Subject: [PATCH 4/5] Implement bicubic grid sampler (#44780) Summary: Fix https://github.com/pytorch/pytorch/issues/44601 I added bicubic grid sampler in both cpu and cuda side, but haven't in AVX2 There is a [colab notebook](https://colab.research.google.com/drive/1mIh6TLLj5WWM_NcmKDRvY5Gltbb781oU?usp=sharing) show some test results. The notebook use bilinear for test, since I could only use distributed version of pytorch in it. You could just download it and modify the `mode_torch=bicubic` to show the results. There are some duplicate code about getting and setting values, since the helper function used in bilinear at first clip the coordinate beyond boundary, and then get or set the value. However, in bicubic, there are more points should be consider. I could refactor that part after making sure the overall calculation are correct. Thanks Pull Request resolved: https://github.com/pytorch/pytorch/pull/44780 Reviewed By: mrshenli Differential Revision: D24681114 Pulled By: mruberry fbshipit-source-id: d39c8715e2093a5a5906cb0ef040d62bde578567 --- aten/src/ATen/native/GridSampler.cpp | 109 ++++++++- aten/src/ATen/native/GridSampler.h | 95 +++++++- .../src/ATen/native/cpu/GridSamplerKernel.cpp | 214 ++++++++++++++++++ aten/src/ATen/native/cuda/GridSampler.cu | 100 +++++++- aten/src/ATen/native/cuda/GridSampler.cuh | 88 ++++++- test/test_nn.py | 77 ++++++- .../api/include/torch/nn/functional/vision.h | 4 +- torch/nn/functional.py | 25 +- 8 files changed, 665 insertions(+), 47 deletions(-) diff --git a/aten/src/ATen/native/GridSampler.cpp b/aten/src/ATen/native/GridSampler.cpp index 59242e0e6c03..667cbe8f07b3 100644 --- a/aten/src/ATen/native/GridSampler.cpp +++ b/aten/src/ATen/native/GridSampler.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -422,11 +423,11 @@ Tensor _grid_sampler_2d_cpu_fallback(const Tensor& input, const Tensor& grid, for (int64_t w = 0; w < out_W; ++w) { // get the corresponding input x, y, z co-ordinates from grid scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW; - scalar_t ix = *grid_ptr_NHW; - scalar_t iy = grid_ptr_NHW[grid_sCoor]; + scalar_t x = *grid_ptr_NHW; + scalar_t y = grid_ptr_NHW[grid_sCoor]; - ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners); - iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners); + scalar_t ix = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners); + scalar_t iy = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners); if (interpolation_mode == GridSamplerInterpolation::Bilinear) { // get corner pixel values from (x, y) @@ -483,6 +484,43 @@ Tensor _grid_sampler_2d_cpu_fallback(const Tensor& input, const Tensor& grid, *out_ptr_NCHW = static_cast(0); } } + } else if (interpolation_mode == GridSamplerInterpolation::Bicubic) { + // grid_sampler_compute_source_index will "clip the value" of idx depends on the padding, + // which would cause calculation to be wrong, + // for example x = -0.1 -> ix = 0 for zero padding, but in bicubic ix = floor(x) = -1 + // There would be more problem in reflection padding, since the -1 and +1 direction is not fixed in boundary condition + ix = grid_sampler_unnormalize(x, inp_W, align_corners); + iy = grid_sampler_unnormalize(y, inp_H, align_corners); + + scalar_t ix_nw = std::floor(ix); + scalar_t iy_nw = std::floor(iy); + + const scalar_t tx = ix - ix_nw; + const scalar_t ty = iy - iy_nw; + + scalar_t *inp_ptr_NC = inp_ptr_N; + scalar_t *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW; + for (int64_t c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) { + scalar_t coefficients[4]; + + // Interpolate 4 values in the x directon + for (int64_t i = 0; i < 4; ++i) { + coefficients[i] = cubic_interp1d( + get_value_bounded(inp_ptr_NC, ix_nw - 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + get_value_bounded(inp_ptr_NC, ix_nw + 0, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + get_value_bounded(inp_ptr_NC, ix_nw + 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + get_value_bounded(inp_ptr_NC, ix_nw + 2, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + tx); + } + + // Interpolate in the y direction + *out_ptr_NCHW = cubic_interp1d( + coefficients[0], + coefficients[1], + coefficients[2], + coefficients[3], + ty); + } } } } @@ -547,13 +585,13 @@ _grid_sampler_2d_cpu_fallback_backward(const Tensor& grad_output, for (int64_t w = 0; w < out_W; ++w, gGrid_ptr_NHW += gGrid_sW /* grad_grid is contiguous */ ) { // get the corresponding input x, y co-ordinates from grid scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW; - scalar_t ix = *grid_ptr_NHW; - scalar_t iy = grid_ptr_NHW[grid_sCoor]; + scalar_t x = *grid_ptr_NHW; + scalar_t y = grid_ptr_NHW[grid_sCoor]; // multipliers for gradients on ix, iy scalar_t gix_mult, giy_mult; - ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult); - iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult); + scalar_t ix = grid_sampler_compute_source_index_set_grad(x, inp_W, padding_mode, align_corners, &gix_mult); + scalar_t iy = grid_sampler_compute_source_index_set_grad(y, inp_H, padding_mode, align_corners, &giy_mult); if (interpolation_mode == GridSamplerInterpolation::Bilinear) { // get corner pixel values from (x, y) @@ -628,6 +666,55 @@ _grid_sampler_2d_cpu_fallback_backward(const Tensor& grad_output, safe_add_2d(gInp_ptr_NC, iy_nearest, ix_nearest, gInp_sH, gInp_sW, inp_H, inp_W, *gOut_ptr_NCHW); } + } else if (interpolation_mode == GridSamplerInterpolation::Bicubic) { + + ix = grid_sampler_unnormalize_set_grad(x, inp_W, align_corners, &gix_mult); + iy = grid_sampler_unnormalize_set_grad(y, inp_H, align_corners, &giy_mult); + + scalar_t ix_nw = std::floor(ix); + scalar_t iy_nw = std::floor(iy); + + const scalar_t tx = ix - ix_nw; + const scalar_t ty = iy - iy_nw; + + scalar_t x_coeffs[4]; + scalar_t y_coeffs[4]; + scalar_t x_coeffs_grad[4]; + scalar_t y_coeffs_grad[4]; + + get_cubic_upsample_coefficients(x_coeffs, tx); + get_cubic_upsample_coefficients(y_coeffs, ty); + get_cubic_coefficients_grad(x_coeffs_grad, tx); + get_cubic_coefficients_grad(y_coeffs_grad, ty); + + scalar_t gix = static_cast(0); + scalar_t giy = static_cast(0); + + scalar_t *gOut_ptr_NCHW = gOut_ptr + n * gOut_sN + h * gOut_sH + w * gOut_sW; + scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN; + scalar_t *inp_ptr_NC = inp_ptr_N; + + for (int64_t c = 0; c < C; ++c, gOut_ptr_NCHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC+= inp_sC) { + scalar_t gOut = *gOut_ptr_NCHW; + + for (int64_t i = 0; i < 4; ++i) { + for (int64_t j = 0; j < 4; ++j) { + + // set input gradient + add_value_bounded(gInp_ptr_NC, ix_nw - 1 + i, iy_nw - 1 + j, + inp_W, inp_H, gInp_sW, gInp_sH, gOut * x_coeffs[i] * y_coeffs[j], padding_mode, align_corners); + + // set grid gradient + scalar_t val = get_value_bounded(inp_ptr_NC, ix_nw - 1 + i, iy_nw - 1 + j, + inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners); + + gix -= val * x_coeffs_grad[i] * y_coeffs[j] * gOut; + giy -= val * y_coeffs_grad[j] * x_coeffs[i] * gOut; + } + } + } + gGrid_ptr_NHW[0] = gix_mult * gix; + gGrid_ptr_NHW[1] = giy_mult * giy; } } } @@ -640,6 +727,7 @@ _grid_sampler_2d_cpu_fallback_backward(const Tensor& grad_output, Tensor grid_sampler_2d_cpu(const Tensor& input, const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + // AVX gather instructions use signed 32-bit offsets to gather float values. // Check for possible overflow and fallback to scalar implementation if (input.scalar_type() != kDouble) { @@ -682,6 +770,7 @@ Tensor grid_sampler_3d_cpu(const Tensor& input, const Tensor& grid, std::tuple grid_sampler_2d_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + // AVX gather instructions use signed 32-bit offsets to gather float values. // Check for possible overflow and fallback to scalar implementation if (input.scalar_type() != kDouble) { @@ -757,6 +846,10 @@ Tensor grid_sampler(const Tensor& input, const Tensor& grid, grid.size(-1) == input.dim() - 2, "grid_sampler(): expected grid to have size ", input.dim() - 2, " in last " "dimension, but got grid with sizes ", grid.sizes()); + TORCH_CHECK( + !(input.dim() == 5 && static_cast(interpolation_mode) == GridSamplerInterpolation::Bicubic), + "grid_sampler(): bicubic interpolation only supports 4D input" + ); for (int64_t i = 2; i < input.dim(); i++) { TORCH_CHECK(input.size(i) > 0, "grid_sampler(): expected input to have non-empty spatial dimensions, " diff --git a/aten/src/ATen/native/GridSampler.h b/aten/src/ATen/native/GridSampler.h index ebafc9727061..effc322c0d3a 100644 --- a/aten/src/ATen/native/GridSampler.h +++ b/aten/src/ATen/native/GridSampler.h @@ -7,7 +7,7 @@ namespace at { namespace native { namespace detail { - enum class GridSamplerInterpolation {Bilinear, Nearest}; + enum class GridSamplerInterpolation {Bilinear, Nearest, Bicubic}; enum class GridSamplerPadding {Zeros, Border, Reflection}; } // namespace detail @@ -139,14 +139,12 @@ static inline scalar_t reflect_coordinates_set_grad(scalar_t in, int64_t twice_l } } -// Computes the pixel source index value for a grid coordinate -template -static inline scalar_t grid_sampler_compute_source_index( - scalar_t coord, - int64_t size, - GridSamplerPadding padding_mode, - bool align_corners) { - coord = grid_sampler_unnormalize(coord, size, align_corners); +// Mapping the out-of-boundary points back into boundary +// This would only affect padding_mode=border or reflection +template +static inline scalar_t compute_coordinates(scalar_t coord, int64_t size, + GridSamplerPadding padding_mode, + bool align_corners) { if (padding_mode == GridSamplerPadding::Border) { // clip coordinates to image borders coord = clip_coordinates(coord, size); @@ -163,6 +161,18 @@ static inline scalar_t grid_sampler_compute_source_index( return coord; } +// Computes the pixel source index value for a grid coordinate +template +static inline scalar_t grid_sampler_compute_source_index( + scalar_t coord, + int64_t size, + GridSamplerPadding padding_mode, + bool align_corners) { + coord = grid_sampler_unnormalize(coord, size, align_corners); + coord = compute_coordinates(coord, size, padding_mode, align_corners); + return coord; +} + // grid_sampler_compute_source_index_set_grad works similarly to // grid_sampler_compute_source_index except that it also returns the // `d output / d input` via pointer argument `grad_in`. @@ -202,6 +212,30 @@ static inline bool within_bounds_3d(int64_t d, int64_t h, int64_t w, int64_t D, return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W; } +template +static inline scalar_t get_value_bounded( + scalar_t* data, + scalar_t x, + scalar_t y, + int64_t W, + int64_t H, + int64_t sW, + int64_t sH, + GridSamplerPadding padding_mode, + bool align_corners) { + + x = compute_coordinates(x, W, padding_mode, align_corners); + y = compute_coordinates(y, H, padding_mode, align_corners); + + int64_t ix = static_cast(x); + int64_t iy = static_cast(y); + + if (within_bounds_2d(iy, ix, H, W)) { + return data[iy * sH + ix * sW]; + } + return static_cast(0); +} + template static inline void safe_add_2d(scalar_t *data, int64_t h, int64_t w, int64_t sH, int64_t sW, int64_t H, int64_t W, @@ -221,4 +255,47 @@ static inline void safe_add_3d(scalar_t *data, int64_t d, int64_t h, int64_t w, } } +template +static inline void add_value_bounded( + scalar_t* data, + scalar_t x, + scalar_t y, + int64_t W, + int64_t H, + int64_t sW, + int64_t sH, + scalar_t delta, + GridSamplerPadding padding_mode, + bool align_corners) { + + x = compute_coordinates(x, W, padding_mode, align_corners); + y = compute_coordinates(y, H, padding_mode, align_corners); + + int64_t ix = static_cast(x); + int64_t iy = static_cast(y); + + safe_add_2d(data, iy, ix, sH, sW, H, W, delta); +} + +// Calculate the differential of the cubic convolution, i.e. `d coeff / d x` +template +static inline void get_cubic_coefficients_grad( + scalar_t coeffs[4], + scalar_t t) { + + // Must be the same as forward calculation in + // aten/src/ATen/native/UpSample.h:get_cubic_upsample_coefficients + scalar_t A = -0.75; + + scalar_t x; + x = -1 - t; // 1 < x = |-1 - tx| < 2 + coeffs[0] = (-3 * A * x - 10 * A ) * x - 8 * A; + x = -t; // x = |0 - tx| <= 1 + coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x; + x = 1 - t; // x = |1 - tx| <= 1 + coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x; + x = 2 - t; // 1 < x = |2 - tx| < 2 + coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A; +} + }} // namespace at::native diff --git a/aten/src/ATen/native/cpu/GridSamplerKernel.cpp b/aten/src/ATen/native/cpu/GridSamplerKernel.cpp index 4dfe644b89a4..ece2d527e899 100644 --- a/aten/src/ATen/native/cpu/GridSamplerKernel.cpp +++ b/aten/src/ATen/native/cpu/GridSamplerKernel.cpp @@ -354,6 +354,10 @@ struct ComputeLocation return unnormalize(in); } + inline Vec compute_coordinates(const Vec &in) const { + return in; + } + inline std::pair apply_get_grad(const Vec &in) const { return std::make_pair(unnormalize(in), Vec(scaling_factor)); } @@ -374,6 +378,10 @@ struct ComputeLocation return clip_coordinates(unnormalize(in)); } + inline Vec compute_coordinates(const Vec &in) const { + return clip_coordinates(in); + } + inline std::pair apply_get_grad(const Vec &in) const { Vec res, grad_clip; std::tie(res, grad_clip) = clip_coordinates_get_grad(unnormalize(in)); @@ -400,6 +408,12 @@ struct ComputeLocation return res; } + inline Vec compute_coordinates(const Vec &in) const { + auto res = reflect_coordinates(in); + res = clip_coordinates(res); + return res; + } + inline std::pair apply_get_grad(const Vec &in) const { Vec res, grad_refl, grad_clip, grad(scaling_factor); std::tie(res, grad_refl) = reflect_coordinates_get_grad(unnormalize(in)); @@ -764,6 +778,202 @@ struct ApplyGridSample +struct ApplyGridSample { + using Vec = Vec256; + using integer_t = int_same_size_t; + using iVec = Vec256; + + const int64_t inp_H; + const int64_t inp_W; + const int64_t inp_sH; + const int64_t inp_sW; + const int64_t C; + const int64_t inp_sC; + const ComputeLocation compute_H; + const ComputeLocation compute_W; + const bool must_in_bound = padding != GridSamplerPadding::Zeros; + + // constant used in cubic convolution + // could be -0.5 or -0.75, use the same value in UpSampleBicubic2d.h + const Vec A = Vec(-0.75); + + ApplyGridSample(const TensorAccessor& input) + : inp_H(input.size(2)) + , inp_W(input.size(3)) + , inp_sH(input.stride(2)) + , inp_sW(input.stride(3)) + , C(input.size(1)) + , inp_sC(input.stride(1)) + , compute_H(input.size(2)) + , compute_W(input.size(3)) {} + + // Calculate the cubic convolution coefficient + inline void get_cubic_coefficients(Vec (&coeffs)[4], const Vec& tx) const { + Vec x; + x = tx + Vec(1); // 1 < x = |-1 - tx| < 2 + coeffs[0] = ((A * x - Vec(5) * A) * x + Vec(8) * A) * x - Vec(4) * A; + x = tx; // x = |0 - tx| <= 1 + coeffs[1] = ((A + Vec(2)) * x - (A + Vec(3))) * x * x + Vec(1); + x = Vec(1) - tx; // x = |1 - tx| <= 1 + coeffs[2] = ((A + Vec(2)) * x - (A + Vec(3))) * x * x + Vec(1); + x = Vec(2) - tx; // 1 < x = |2 - tx| < 2 + coeffs[3] = ((A * x - Vec(5) * A) * x + Vec(8) * A) * x - Vec(4) * A; + } + + // Calculate the differential of the cubic convolution, i.e. `d coeff / d x` + inline void get_cubic_coefficients_grad(Vec (&coeffs)[4], const Vec& tx) const { + Vec x; + x = Vec(-1) - tx; // 1 < x = |-1 - tx| < 2 + coeffs[0] = (Vec(-3) * A * x - Vec(10) * A ) * x - Vec(8) * A; + x = Vec(0) - tx; // x = |0 - tx| <= 1 + coeffs[1] = (Vec(-3) * (A + Vec(2)) * x - Vec(2) * (A + Vec(3))) * x; + x = Vec(1) - tx; // x = |1 - tx| <= 1 + coeffs[2] = (Vec(3) * (A + Vec(2)) * x - Vec(2) * (A + Vec(3))) * x; + x = Vec(2) - tx; // 1 < x = |2 - tx| < 2 + coeffs[3] = (Vec(3) * A * x - Vec(10) * A) * x + Vec(8) * A; + } + + inline Vec get_value_bounded(const scalar_t* data, const Vec& x, const Vec& y) const { + auto ix = convert_to_int_of_same_size(compute_W.compute_coordinates(x)); + auto iy = convert_to_int_of_same_size(compute_H.compute_coordinates(y)); + + auto mask_x = must_in_bound ? iVec(-1) : (ix > iVec(-1)) & (ix < iVec(inp_W)); + auto mask_y = must_in_bound ? iVec(-1) : (iy > iVec(-1)) & (iy < iVec(inp_H)); + auto mask = cast(mask_x & mask_y); + + auto offset = iy * iVec(inp_sH) + ix * iVec(inp_sW); + + auto val = mask_gather(Vec(0), data, offset, mask); + return val; + } + + inline void add_value_bounded(scalar_t* data, int64_t len, const Vec& x, const Vec&y, + const Vec& delta) const { + + auto ix = convert_to_int_of_same_size(compute_W.compute_coordinates(x)); + auto iy = convert_to_int_of_same_size(compute_H.compute_coordinates(y)); + + auto mask_x = must_in_bound ? iVec(-1) : (ix > iVec(-1)) & (ix < iVec(inp_W)); + auto mask_y = must_in_bound ? iVec(-1) : (iy > iVec(-1)) & (iy < iVec(inp_H)); + auto mask = cast(mask_x & mask_y); + + auto i_gInp_offset = iy * iVec(inp_W) + ix; + integer_t i_gInp_offset_arr[iVec::size()]; + i_gInp_offset.store(i_gInp_offset_arr); + + integer_t mask_arr[iVec::size()]; + mask.store(mask_arr); + + scalar_t gInp_corner_arr[Vec::size()]; + delta.store(gInp_corner_arr); + + mask_scatter_add(gInp_corner_arr, data, i_gInp_offset_arr, mask_arr, len); + } + + inline void forward(TensorAccessor& out_slice, + const TensorAccessor& inp_slice, + int64_t offset, const Vec& grid_x, const Vec& grid_y, + int64_t len) const { + + auto x = compute_W.unnormalize(grid_x); + auto y = compute_H.unnormalize(grid_y); + + auto ix = x.floor(); + auto iy = y.floor(); + + Vec coeff_x[4]; + Vec coeff_y[4]; + get_cubic_coefficients(coeff_x, x - ix); + get_cubic_coefficients(coeff_y, y - iy); + + #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE) + # pragma unroll + #endif + for (int64_t c = 0; c < C; ++c) { + auto inp_slice_C_ptr = inp_slice[c].data(); + + // Interpolate the 4 values in the x direction + Vec interp_x[4]; + for (int64_t i = 0; i < 4; ++i) { + interp_x[i] = + coeff_x[0] * get_value_bounded(inp_slice_C_ptr, ix - Vec(1), iy + Vec(-1 + i)) + + coeff_x[1] * get_value_bounded(inp_slice_C_ptr, ix + Vec(0), iy + Vec(-1 + i)) + + coeff_x[2] * get_value_bounded(inp_slice_C_ptr, ix + Vec(1), iy + Vec(-1 + i)) + + coeff_x[3] * get_value_bounded(inp_slice_C_ptr, ix + Vec(2), iy + Vec(-1 + i)); + } + + // Interpolate the 4 values in the y direction + auto interpolated = coeff_y[0] * interp_x[0] + coeff_y[1] * interp_x[1] + + coeff_y[2] * interp_x[2] + coeff_y[3] * interp_x[3]; + interpolated.store(out_slice[c].data() + offset, len); + } + } + + inline void backward(TensorAccessor& gInp_slice, + TensorAccessor& gGrid_slice, + const TensorAccessor& gOut_slice, + const TensorAccessor& inp_slice, + int64_t offset, const Vec& grid_x, const Vec& grid_y, + int64_t len) const { + + Vec x = compute_W.unnormalize(grid_x); + Vec y = compute_H.unnormalize(grid_y); + Vec gx_mult = Vec(compute_W.scaling_factor); + Vec gy_mult = Vec(compute_H.scaling_factor); + + auto ix = x.floor(); + auto iy = y.floor(); + + Vec coeff_x[4]; + Vec coeff_y[4]; + get_cubic_coefficients(coeff_x, x - ix); + get_cubic_coefficients(coeff_y, y - iy); + + Vec coeff_x_grad[4]; + Vec coeff_y_grad[4]; + get_cubic_coefficients_grad(coeff_x_grad, x - ix); + get_cubic_coefficients_grad(coeff_y_grad, y - iy); + + auto gx = Vec(0), gy = Vec(0); + #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE) + # pragma unroll + #endif + for (int64_t c = 0; c < C; ++c) { + auto inp_slice_C_ptr = inp_slice[c].data(); + auto gInp_slice_C_ptr = gInp_slice[c].data(); + auto gOut = Vec::loadu(gOut_slice[c].data() + offset, len); + + for (int64_t i = 0; i < 4; ++i) { + for (int64_t j = 0; j < 4; ++j) { + auto xx = ix + Vec(-1 + i); + auto yy = iy + Vec(-1 + j); + + add_value_bounded(gInp_slice_C_ptr, len, xx, yy, gOut * coeff_x[i] * coeff_y[j]); + + auto val = get_value_bounded(inp_slice_C_ptr, xx, yy); + gx = gx - val * gOut * coeff_x_grad[i] * coeff_y[j]; + gy = gy - val * gOut * coeff_y_grad[j] * coeff_x[i]; + } + } + } + + gx = gx * gx_mult; + gy = gy * gy_mult; + + constexpr int64_t step = Vec::size(); + auto interleaved_gGrid = interleave2(gx, gy); + auto gGrid_ptr = gGrid_slice.data() + offset * 2; + std::get<0>(interleaved_gGrid).store(gGrid_ptr, + std::min(len * 2, step)); + std::get<1>(interleaved_gGrid).store(gGrid_ptr + step, + std::max(static_cast(0), len * 2 - step)); + } +}; + // ~~~~~~~~~~~~~~~~~~ grid_sample_2d_grid_slice_iterator ~~~~~~~~~~~~~~~~~~~~~~ // Function to apply a vectorized function on a grid slice tensor (without batch // dimension). @@ -940,11 +1150,13 @@ Tensor grid_sampler_2d_cpu_kernel_impl(const Tensor& input, const Tensor& grid, switch (static_cast(interpolation_mode)) { HANDLE_INTERP(GridSamplerInterpolation::Bilinear, true); HANDLE_INTERP(GridSamplerInterpolation::Nearest, true); + HANDLE_INTERP(GridSamplerInterpolation::Bicubic, true); } } else { switch (static_cast(interpolation_mode)) { HANDLE_INTERP(GridSamplerInterpolation::Bilinear, false); HANDLE_INTERP(GridSamplerInterpolation::Nearest, false); + HANDLE_INTERP(GridSamplerInterpolation::Bicubic, false); } } }); @@ -1014,11 +1226,13 @@ grid_sampler_2d_backward_cpu_kernel_impl(const Tensor& grad_output_, switch (static_cast(interpolation_mode)) { HANDLE_INTERP(GridSamplerInterpolation::Bilinear, true); HANDLE_INTERP(GridSamplerInterpolation::Nearest, true); + HANDLE_INTERP(GridSamplerInterpolation::Bicubic, true); } } else { switch (static_cast(interpolation_mode)) { HANDLE_INTERP(GridSamplerInterpolation::Bilinear, false); HANDLE_INTERP(GridSamplerInterpolation::Nearest, false); + HANDLE_INTERP(GridSamplerInterpolation::Bicubic, false); } } }); diff --git a/aten/src/ATen/native/cuda/GridSampler.cu b/aten/src/ATen/native/cuda/GridSampler.cu index 023167109af2..7674e9137238 100644 --- a/aten/src/ATen/native/cuda/GridSampler.cu +++ b/aten/src/ATen/native/cuda/GridSampler.cu @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -50,11 +51,11 @@ namespace { const index_t grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; // get the corresponding input x, y co-ordinates from grid - scalar_t ix = grid.data[grid_offset]; - scalar_t iy = grid.data[grid_offset + grid_sCoor]; + scalar_t x = grid.data[grid_offset]; + scalar_t y = grid.data[grid_offset + grid_sCoor]; - ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners); - iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners); + scalar_t ix = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners); + scalar_t iy = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners); if (interpolation_mode == GridSamplerInterpolation::Bilinear) { // get NE, NW, SE, SW pixel values from (x, y) @@ -105,6 +106,38 @@ namespace { *out_ptr_NCHW = static_cast(0); } } + } else if (interpolation_mode == GridSamplerInterpolation::Bicubic) { + + ix = grid_sampler_unnormalize(x, inp_W, align_corners); + iy = grid_sampler_unnormalize(y, inp_H, align_corners); + + scalar_t ix_nw = ::floor(ix); + scalar_t iy_nw = ::floor(iy); + + const scalar_t tx = ix - ix_nw; + const scalar_t ty = iy - iy_nw; + + auto inp_ptr_NC = input.data + n * inp_sN; + auto out_ptr_NCHW = output.data + n * out_sN + h * out_sH + w * out_sW; + for (index_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) { + scalar_t coefficients[4]; + + for (index_t i = 0; i < 4; ++i) { + coefficients[i] = cubic_interp1d( + get_value_bounded(inp_ptr_NC, ix_nw - 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + get_value_bounded(inp_ptr_NC, ix_nw + 0, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + get_value_bounded(inp_ptr_NC, ix_nw + 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + get_value_bounded(inp_ptr_NC, ix_nw + 2, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + tx); + } + + *out_ptr_NCHW = cubic_interp1d( + coefficients[0], + coefficients[1], + coefficients[2], + coefficients[3], + ty); + } } } } @@ -300,13 +333,13 @@ namespace { const auto grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; // get the corresponding input x, y co-ordinates from grid - scalar_t ix = grid.data[grid_offset]; - scalar_t iy = grid.data[grid_offset + grid_sCoor]; + scalar_t x = grid.data[grid_offset]; + scalar_t y = grid.data[grid_offset + grid_sCoor]; // multipliers for gradients on ix and iy scalar_t gix_mult, giy_mult; - ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult); - iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult); + scalar_t ix = grid_sampler_compute_source_index_set_grad(x, inp_W, padding_mode, align_corners, &gix_mult); + scalar_t iy = grid_sampler_compute_source_index_set_grad(y, inp_H, padding_mode, align_corners, &giy_mult); if (interpolation_mode == GridSamplerInterpolation::Bilinear) { // get NE, NW, SE, SW pixel values from (x, y) @@ -387,6 +420,57 @@ namespace { scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW; gGrid_ptr_NHW[0] = static_cast(0); gGrid_ptr_NHW[1] = static_cast(0); + } else if (interpolation_mode == GridSamplerInterpolation::Bicubic) { + + ix = grid_sampler_unnormalize_set_grad(x, inp_W, align_corners, &gix_mult); + iy = grid_sampler_unnormalize_set_grad(y, inp_H, align_corners, &giy_mult); + + scalar_t ix_nw = ::floor(ix); + scalar_t iy_nw = ::floor(iy); + + const scalar_t tx = ix - ix_nw; + const scalar_t ty = iy - iy_nw; + + scalar_t x_coeffs[4]; + scalar_t y_coeffs[4]; + scalar_t x_coeffs_grad[4]; + scalar_t y_coeffs_grad[4]; + + get_cubic_upsampling_coefficients(x_coeffs, tx); + get_cubic_upsampling_coefficients(y_coeffs, ty); + get_cubic_coefficients_grad(x_coeffs_grad, tx); + get_cubic_coefficients_grad(y_coeffs_grad, ty); + + scalar_t gix = static_cast(0); + scalar_t giy = static_cast(0); + + scalar_t *gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW; + scalar_t *gInp_ptr_NC = grad_input.data + n * gInp_sN; + scalar_t *inp_ptr_NC = input.data + n * inp_sN; + + for (index_t c = 0; c < C; ++c, gOut_ptr_NCHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC+= inp_sC) { + scalar_t gOut = *gOut_ptr_NCHW; + + for (index_t i = 0; i < 4; ++i) { + for (index_t j = 0; j < 4; ++j) { + + // set input gradient + add_value_bounded(gInp_ptr_NC, ix_nw - 1 + i, iy_nw - 1 + j, inp_W, inp_H, + gInp_sW, gInp_sH, gOut * x_coeffs[i] * y_coeffs[j], padding_mode, align_corners); + + // set grid gradient + scalar_t val = get_value_bounded(inp_ptr_NC, ix_nw - 1 + i, iy_nw - 1 + j, + inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners); + + gix -= val * x_coeffs_grad[i] * y_coeffs[j] * gOut; + giy -= val * y_coeffs_grad[j] * x_coeffs[i] * gOut; + } + } + } + + scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW; + gGrid_ptr_NHW[0] = gix_mult * gix; + gGrid_ptr_NHW[1] = giy_mult * giy; } } } diff --git a/aten/src/ATen/native/cuda/GridSampler.cuh b/aten/src/ATen/native/cuda/GridSampler.cuh index 4a94a3fda1bb..0c4acd1be41c 100644 --- a/aten/src/ATen/native/cuda/GridSampler.cuh +++ b/aten/src/ATen/native/cuda/GridSampler.cuh @@ -7,7 +7,7 @@ namespace at { namespace native { namespace detail { - enum class GridSamplerInterpolation {Bilinear, Nearest}; + enum class GridSamplerInterpolation {Bilinear, Nearest, Bicubic}; enum class GridSamplerPadding {Zeros, Border, Reflection}; } // namespace detail @@ -153,15 +153,11 @@ scalar_t safe_downgrade_to_int_range(scalar_t x){ return x; } -// Computes the pixel source index value for a grid coordinate -template +template static __forceinline__ __device__ -scalar_t grid_sampler_compute_source_index( - scalar_t coord, - int size, - GridSamplerPadding padding_mode, - bool align_corners) { - coord = grid_sampler_unnormalize(coord, size, align_corners); +scalar_t compute_coordinates(scalar_t coord, int size, + GridSamplerPadding padding_mode, + bool align_corners) { if (padding_mode == GridSamplerPadding::Border) { // clip coordinates to image borders coord = clip_coordinates(coord, size); @@ -176,7 +172,20 @@ scalar_t grid_sampler_compute_source_index( coord = clip_coordinates(coord, size); } - coord = safe_downgrade_to_int_range(coord); + coord = safe_downgrade_to_int_range(coord); + return coord; +} + +// Computes the pixel source index value for a grid coordinate +template +static __forceinline__ __device__ +scalar_t grid_sampler_compute_source_index( + scalar_t coord, + int size, + GridSamplerPadding padding_mode, + bool align_corners) { + coord = grid_sampler_unnormalize(coord, size, align_corners); + coord = compute_coordinates(coord, size, padding_mode, align_corners); return coord; } @@ -224,6 +233,25 @@ bool within_bounds_3d(int d, int h, int w, int D, int H, int W) { return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W; } +template +static __forceinline__ __device__ +scalar_t get_value_bounded( + scalar_t *data, scalar_t x, scalar_t y, int W, int H, int sW, int sH, + GridSamplerPadding padding_mode, + bool align_corners) { + + x = compute_coordinates(x, W, padding_mode, align_corners); + y = compute_coordinates(y, H, padding_mode, align_corners); + + int ix = static_cast(x); + int iy = static_cast(y); + + if (within_bounds_2d(iy, ix, H, W)) { + return data[iy * sH + ix * sW]; + } + return static_cast(0); +} + template static __forceinline__ __device__ void safe_add_2d(scalar_t *data, int h, int w, @@ -244,4 +272,44 @@ void safe_add_3d(scalar_t *data, int d, int h, int w, } } +template +static __forceinline__ __device__ +void add_value_bounded( + scalar_t* data, scalar_t x, scalar_t y, int W, int H, int sW, int sH, + scalar_t delta, + GridSamplerPadding padding_mode, + bool align_corners) { + + x = compute_coordinates(x, W, padding_mode, align_corners); + y = compute_coordinates(y, H, padding_mode, align_corners); + + int ix = static_cast(x); + int iy = static_cast(y); + + safe_add_2d(data, iy, ix, sH, sW, H, W, delta); +} + +// Calculate the differential of the cubic convolution, i.e. `d coeff / d x` +template +static __forceinline__ __device__ +void get_cubic_coefficients_grad( + scalar_t coeffs[4], + scalar_t t) { + + // Must be the same as forward calculation in + // aten/src/ATen/native/cuda/UpSample.cuh:get_cubic_upsample_coefficients + scalar_t A = -0.75; + + scalar_t x; + x = -1 - t; // 1 < x = |-1 - tx| < 2 + coeffs[0] = (-3 * A * x - 10 * A ) * x - 8 * A; + x = -t; // x = |0 - tx| <= 1 + coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x; + x = 1 - t; // x = |1 - tx| <= 1 + coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x; + x = 2 - t; // 1 < x = |2 - tx| < 2 + coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A; +} + + }} // namespace at::native diff --git a/test/test_nn.py b/test/test_nn.py index 15b877391c14..2ce752aa0eb8 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -7167,6 +7167,9 @@ def test_grid_sample_error_checking(self): with self.assertRaisesRegex(RuntimeError, "expected input to have non-empty spatial dimensions"): F.grid_sample(torch.empty(1, 1, 0, 2), grid, align_corners=False) + with self.assertRaisesRegex(RuntimeError, "bicubic interpolation only supports 4D input"): + F.grid_sample(torch.empty(1, 1, 2, 2, 2), torch.empty(1, 1, 1, 1, 3), mode='bicubic') + if TEST_CUDA: with self.assertRaisesRegex(RuntimeError, "expected input and grid to be on same device"): F.grid_sample(input.cuda(), grid, align_corners=False) @@ -7299,8 +7302,8 @@ def get_grid(device='cpu', data=None): self.assertEqual(out_fallback, out_cpu.float(), atol=1e-5, rtol=5e-5) out_fallback.backward(gradients.float()) - self.assertEqual(input_fallback.grad, input_cpu.grad.float(), atol=1e-5, rtol=5e-5) - self.assertEqual(grid_fallback.grad, grid_cpu.grad.float(), atol=1e-5, rtol=5e-5) + self.assertEqual(input_fallback.grad, input_cpu.grad.float(), atol=1e-4, rtol=5e-5) + self.assertEqual(grid_fallback.grad, grid_cpu.grad.float(), atol=1e-4, rtol=5e-5) if TEST_CUDA: input_cuda = input_cpu.detach().transpose(0, 1).cuda().transpose(0, 1).requires_grad_() @@ -7378,7 +7381,7 @@ def get_grid(device='cpu', data=None): W = random.randint(3, IW + 2) test_shape(0, C, IH, IW, H, W, mode, padding_mode, align_corners) - for mode in ('bilinear', 'nearest'): + for mode in ('bilinear', 'nearest', 'bicubic'): for padding_mode in ('zeros', 'border', 'reflection'): for align_corners in (True, False): # test known input on CPU @@ -7446,6 +7449,37 @@ def get_grid(device='cpu', data=None): [1., 8., 5., 8., 9.]]).view(1, 1, 2, 5) else: raise AssertionError("missing groundtruth test for padding mode '{}'".format(padding_mode)) + elif mode == 'bicubic': + if padding_mode == 'zeros': + if align_corners: + groundtruth = torch.tensor( + [[-0.10424726, 7.1400003, 5.0000, 5.7842274, 9.0000], + [2.4492188, 7.4814040, 5.0000, 6.0277520, 0.0000]]).view(1, 1, 2, 5) + else: + groundtruth = torch.tensor( + [[0.00000, 7.6287503, 1.0625, 5.5977230, 5.3270264], + [0.40625, 8.0288770, 1.0625, 5.9375067, -0.3515625]]).view(1, 1, 2, 5) + elif padding_mode == 'border': + if align_corners: + groundtruth = torch.tensor( + [[1.1520010, 6.0599990, 5.0000, 4.870930, 9.0000000], + [2.1328125, 6.4258375, 5.0000, 5.076003, 8.8671875]]).view(1, 1, 2, 5) + else: + groundtruth = torch.tensor( + [[0.894531, 6.6050020, 4.625, 4.7138715, 9.800781], + [0.906250, 7.2822485, 4.625, 5.0000052, 10.00000]]).view(1, 1, 2, 5) + elif padding_mode == 'reflection': + if align_corners: + groundtruth = torch.tensor( + [[3.1822524, 6.239998, 5.0000, 4.8709273, 9.00000], + [1.7812500, 6.703594, 5.0000, 5.0760007, 8.21875]]).view(1, 1, 2, 5) + else: + groundtruth = torch.tensor( + [[2.7993753, 6.6050020, 4.25, 4.7138715, 10.269531], + [0.8125000, 7.2822485, 4.25, 5.0000052, 9.332031]]).view(1, 1, 2, 5) + else: + raise AssertionError("missing groundtruth test for padding mode '{}'".format(padding_mode)) + else: raise AssertionError("missing groundtruth test for interpolation mode '{}'".format(mode)) output = F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode, @@ -7501,11 +7535,42 @@ def get_grid(device='cpu', data=None): groundtruth = torch.tensor( [[[[-0., -0.], [-0., 0.], [-0., -0.], [-0., 0.]], [[0., 0.], [0., 0.], [0., 0.], [0., 0.]]]]).view(1, 2, 4, 2) + elif mode == 'bicubic': + if padding_mode == 'zeros': + if align_corners: + groundtruth = torch.tensor( + [[[[-4.5, -6.], [-4.5, 6.], [2.725679, 0.740878], [2.725679, -0.740878]], + [[1.5, 0.], [1.5, 0.], [1.927921, -0.05688], [1.927921, 0.05688]]]]).view(1, 2, 4, 2) + else: + groundtruth = torch.tensor( + [[[[-5.859375, -5.888672], [-5.859375, 5.888672], [-5.6250, -7.5000], [-5.6250, 7.5000]], + [[-0.234375, -0.263672], [-0.234375, 0.263672], [1.8750, 0.], [1.8750, 0.]]]] + ).view(1, 2, 4, 2) + elif padding_mode == 'border': + if align_corners: + groundtruth = torch.tensor( + [[[[1.5, 0.], [1.5, 0.], [1.74, 0.], [1.74, 0.]], + [[1.5, 0.], [1.5, 0.], [1.74, 0.], [1.74, 0.]]]]).view(1, 2, 4, 2) + else: + groundtruth = torch.tensor( + [[[[-0.46875, 0.], [-0.46875, 0.], [1.8750, 0.], [1.8750, 0.]], + [[-0.46875, 0.], [-0.46875, 0.], [1.8750, 0.], [1.8750, 0.]]]]).view(1, 2, 4, 2) + elif padding_mode == 'reflection': + if align_corners: + groundtruth = torch.tensor( + [[[[0., 0.], [0., 0.], [1.92, 0.], [1.92, 0.]], + [[0., 0.], [0., 0.], [1.92, 0.], [1.92, 0.]]]]).view(1, 2, 4, 2) + else: + groundtruth = torch.tensor( + [[[[0., 0.], [0., 0.], [1.875, 0.], [1.875, 0.]], + [[0., 0.], [0., 0.], [1.875, 0.], [1.875, 0.]]]]).view(1, 2, 4, 2) + else: + raise AssertionError("missing gradient groundtruth test for padding mode '{}'".format(padding_mode)) else: raise AssertionError("missing gradient groundtruth test for interpolation mode '{}'".format(mode)) F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners).sum().backward() - self.assertEqual(grid.grad, groundtruth, + self.assertEqual(grid.grad, groundtruth, atol=1e-5, rtol=0, msg="gradient groundtruth comparison failed for mode={}, " "padding_mode={}".format(mode, padding_mode)) @@ -7516,7 +7581,7 @@ def get_grid(device='cpu', data=None): F.GRID_SAMPLE_INTERPOLATION_MODES[mode], F.GRID_SAMPLE_PADDING_MODES[padding_mode], align_corners).sum().backward() - self.assertEqual(grid.grad, groundtruth) + self.assertEqual(grid.grad, groundtruth, atol=1e-5, rtol=0) # do gradcheck N = random.randint(2, 8) @@ -11075,7 +11140,7 @@ def test_grid_sample_large_index_2d(self, device, dtype): sum(i * s for i, s in zip(large_view.size(), large_view.stride())) >= 2 ** 31, msg="View must use 64-bit indexing") for mode, padding_mode, align_corners in itertools.product( - ('nearest', 'bilinear'), ('zeros', 'border', 'reflection'), (True, False)): + ('nearest', 'bilinear', 'bicubic'), ('zeros', 'border', 'reflection'), (True, False)): a = F.grid_sample( small_image, coords, mode=mode, padding_mode=padding_mode, align_corners=align_corners) diff --git a/torch/csrc/api/include/torch/nn/functional/vision.h b/torch/csrc/api/include/torch/nn/functional/vision.h index e1041cb21d8c..1fe084d02c79 100644 --- a/torch/csrc/api/include/torch/nn/functional/vision.h +++ b/torch/csrc/api/include/torch/nn/functional/vision.h @@ -61,8 +61,10 @@ inline Tensor grid_sample( if (c10::get_if(&mode)) { mode_enum = 0; - } else { /// mode == 'nearest' + } else if (c10::get_if(&mode)) { mode_enum = 1; + } else { /// mode == 'bicubic' + mode_enum = 2; } if (c10::get_if(&padding_mode)) { diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 6253f9ddda7b..031d974d4973 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -3227,6 +3227,7 @@ def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 GRID_SAMPLE_INTERPOLATION_MODES = { 'bilinear': 0, 'nearest': 1, + 'bicubic': 2, } GRID_SAMPLE_PADDING_MODES = { @@ -3293,8 +3294,9 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corner grid (Tensor): flow-field of shape :math:`(N, H_\text{out}, W_\text{out}, 2)` (4-D case) or :math:`(N, D_\text{out}, H_\text{out}, W_\text{out}, 3)` (5-D case) mode (str): interpolation mode to calculate output values - ``'bilinear'`` | ``'nearest'``. Default: ``'bilinear'`` - Note: When ``mode='bilinear'`` and the input is 5-D, the interpolation mode + ``'bilinear'`` | ``'nearest'`` | ``'bicubic'``. Default: ``'bilinear'`` + Note: ``mode='bicubic'`` supports only 4-D input. + When ``mode='bilinear'`` and the input is 5-D, the interpolation mode used internally will actually be trilinear. However, when the input is 4-D, the interpolation mode will legitimately be bilinear. padding_mode (str): padding mode for outside grid values @@ -3324,6 +3326,17 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corner The default behavior up to version 1.2.0 was ``align_corners = True``. Since then, the default behavior has been changed to ``align_corners = False``, in order to bring it in line with the default for :func:`interpolate`. + + .. note:: + ``mode='bicubic'`` is implemented using the `cubic convolution algorithm`_ with :math:`\alpha=-0.75`. + The constant :math:`\alpha` might be different from packages to packages. + For example, `PIL`_ and `OpenCV`_ use -0.5 and -0.75 respectively. + This algorithm may "overshoot" the range of values it's interpolating. + For example, it may produce negative values or values greater than 255 when interpolating input in [0, 255]. + Clamp the results with :func: `torch.clamp` to ensure they are within the valid range. + .. _`cubic convolution algorithm`: https://en.wikipedia.org/wiki/Bicubic_interpolation + .. _`PIL`: https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/src/libImaging/Resample.c#L51 + .. _`OpenCV`: https://github.com/opencv/opencv/blob/f345ed564a06178670750bad59526cfa4033be55/modules/imgproc/src/resize.cpp#L908 """ if not torch.jit.is_scripting(): tens_ops = (input, grid) @@ -3331,9 +3344,9 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corner return handle_torch_function( grid_sample, tens_ops, input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners) - if mode != 'bilinear' and mode != 'nearest': + if mode != 'bilinear' and mode != 'nearest' and mode != 'bicubic': raise ValueError("nn.functional.grid_sample(): expected mode to be " - "'bilinear' or 'nearest', but got: '{}'".format(mode)) + "'bilinear', 'nearest' or 'bicubic', but got: '{}'".format(mode)) if padding_mode != 'zeros' and padding_mode != 'border' and padding_mode != 'reflection': raise ValueError("nn.functional.grid_sample(): expected padding_mode " "to be 'zeros', 'border', or 'reflection', " @@ -3341,8 +3354,10 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corner if mode == 'bilinear': mode_enum = 0 - else: # mode == 'nearest' + elif mode == 'nearest': mode_enum = 1 + else: # mode == 'bicubic' + mode_enum = 2 if padding_mode == 'zeros': padding_mode_enum = 0 From 31ebac3eb795c163f9e90288ad4e5d50e8adbb74 Mon Sep 17 00:00:00 2001 From: Zafar Date: Tue, 3 Nov 2020 15:33:13 -0800 Subject: [PATCH 5/5] [quant] Quantized flip dispatch (#46235) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46235 Test Plan: Imported from OSS Reviewed By: vkuzo Differential Revision: D24689161 Pulled By: z-a-f fbshipit-source-id: 6833c2639b29ea5f6c81c880b8928c5a1951c7b8 --- .../src/ATen/native/TensorTransformations.cpp | 33 ++++++++++++++----- aten/src/ATen/native/native_functions.yaml | 2 +- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/aten/src/ATen/native/TensorTransformations.cpp b/aten/src/ATen/native/TensorTransformations.cpp index 1b86b3f2d634..fdee519c4bd0 100644 --- a/aten/src/ATen/native/TensorTransformations.cpp +++ b/aten/src/ATen/native/TensorTransformations.cpp @@ -61,15 +61,30 @@ Tensor flip_cpu(const Tensor& self, IntArrayRef dims) { } } - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Bool, in_tensor.scalar_type(), "flip_cpu", [&] { - flip_cpu_kernel( - total_dims, - stride_contiguous_v, - flip_dims_b, - in_tensor, - out_tensor - ); - }); + if (in_tensor.is_quantized()) { + AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(in_tensor.scalar_type(), + "flip_quantized_cpu", [&] { + flip_cpu_kernel( + total_dims, + stride_contiguous_v, + flip_dims_b, + in_tensor, + out_tensor + ); + }); + } else { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Bool, + in_tensor.scalar_type(), + "flip_cpu", [&] { + flip_cpu_kernel( + total_dims, + stride_contiguous_v, + flip_dims_b, + in_tensor, + out_tensor + ); + }); + } return out_tensor; } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 398aa7474eab..349256477df3 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3882,7 +3882,7 @@ use_c10_dispatcher: full variants: function, method dispatch: - CPU: flip_cpu + CPU, QuantizedCPU: flip_cpu CUDA: flip_cuda - func: fliplr(Tensor self) -> Tensor