diff --git a/functorch/_src/decompositions.py b/functorch/_src/decompositions.py
index ad1cd195c..ea3353f5d 100644
--- a/functorch/_src/decompositions.py
+++ b/functorch/_src/decompositions.py
@@ -204,6 +204,11 @@ def silu_backward(grad_output: Tensor, self: Tensor) -> Tensor:
     return grad_output * sigmoid * (1 + self * (1 - sigmoid))
 
 
+@register_decomposition(aten.trace)
+def trace(x):
+    return torch.sum(torch.diagonal(x))
+
+
 @register_decomposition(aten.softshrink_backward)
 def softshrink_backward(grad_output: Tensor, self: Tensor, lambd: float) -> Tensor:
     return torch.where((self >= -lambd) & (self <= lambd), grad_output.new_zeros(()), grad_output)
diff --git a/functorch/_src/eager_transforms.py b/functorch/_src/eager_transforms.py
index fab2b8cbd..65577d9a9 100644
--- a/functorch/_src/eager_transforms.py
+++ b/functorch/_src/eager_transforms.py
@@ -13,6 +13,8 @@ import torch.autograd.forward_ad as fwAD
 
 from .vmap import vmap
+from .decompositions import decomposition_table
+
 
 from functorch._C import (
     _wrap_for_grad,
@@ -1269,3 +1271,13 @@ def wrapped(*args, **kwargs):
         finally:
             _func_decrement_nesting()
     return wrapped
+
+
+def _register_jit_decomposition(decomp):
+    assert decomp in decomposition_table, f"could not find {decomp}"
+    decomp_fn = decomposition_table[decomp]
+    scripted_decomp_fn = torch.jit.script(decomp_fn)
+    torch.jit._register_decomposition(decomp, scripted_decomp_fn.graph)
+
+
+_register_jit_decomposition(torch.ops.aten.trace.default)
diff --git a/functorch/csrc/BatchRulesHelper.cpp b/functorch/csrc/BatchRulesHelper.cpp
index 81a39d243..adf5c4b4a 100644
--- a/functorch/csrc/BatchRulesHelper.cpp
+++ b/functorch/csrc/BatchRulesHelper.cpp
@@ -5,6 +5,7 @@
 // LICENSE file in the root directory of this source tree.
 
 #include
+#include
 #include
 
 namespace at { namespace functorch {
@@ -63,7 +64,7 @@ VmapDimVector getPhysicalDims(const Tensor& tensor, bool has_batch_dim, IntArray
       result.push_back(maybe_wrap_dim(d, rank)+1);
     } else {
       result.push_back(maybe_wrap_dim(d, rank));
-    } 
+    }
   }
   return result;
 }
@@ -132,4 +133,11 @@ void vmapIncompatibleInplaceError(const char* schema_name) {
       "please file a bug report instead.");
 }
 
+void run_jit_decomposition(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  const auto& schema = op.schema();
+  // TODO: templatize based on op and keep static trace_exec
+  auto * trace_exec = torch::jit::GetDecompositionExecutor(schema);
+  trace_exec->run((*stack));
+}
+
 }}
diff --git a/functorch/csrc/BatchRulesHelper.h b/functorch/csrc/BatchRulesHelper.h
index 4a2a93451..1703669cc 100644
--- a/functorch/csrc/BatchRulesHelper.h
+++ b/functorch/csrc/BatchRulesHelper.h
@@ -192,6 +192,12 @@ inline void handle_variadic_bdims(std::vector>());
 
+void run_jit_decomposition(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+
+#define RUN_JIT_DECOMPOSITION(op) \
+  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&run_jit_decomposition>());
+
+
 using UnpackedBatchedTensor = std::tuple>;
 
 inline void find_and_unpack_tensors(
diff --git a/functorch/csrc/BatchRulesViews.cpp b/functorch/csrc/BatchRulesViews.cpp
index dca1d9e18..711afec35 100644
--- a/functorch/csrc/BatchRulesViews.cpp
+++ b/functorch/csrc/BatchRulesViews.cpp
@@ -151,10 +151,6 @@ std::tuple> _unsafe_view_batch_rule(
   return std::make_tuple(at::_unsafe_view(self_, view_size), 0);
 }
 
-Tensor trace_decomp(const Tensor& self) {
-  return at::sum(at::diagonal(self));
-}
-
 std::tuple> flip_batch_rule(const Tensor& self, optional self_bdim, IntArrayRef dims) {
   auto self_ = moveBatchDimToFront(self, self_bdim);
   VmapDimVector new_dims;
@@ -511,7 +507,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VMAP_SUPPORT(chunk, chunk_batching_rule);
   m.impl("flatten.using_ints", static_cast(native::flatten));
   VMAP_SUPPORT(flip, flip_batch_rule);
-  m.impl("trace", trace_decomp);
+  RUN_JIT_DECOMPOSITION(trace)
  VMAP_SUPPORT(tril, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(tril)));
   VMAP_SUPPORT(triu, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(triu)));
   VMAP_SUPPORT(repeat, repeat_batch_rule);
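
As a rough illustration (not part of the diff): the change removes the hand-written C++ trace_decomp and instead routes aten::trace under vmap through the scripted Python decomposition (sum of the diagonal), registered via _register_jit_decomposition. A minimal sketch of the behavior this preserves, assuming the functorch vmap API:

import torch
from functorch import vmap

x = torch.randn(8, 4, 4)
# trace now decomposes to torch.sum(torch.diagonal(x)), which vmap can batch,
# so trace of each 4x4 matrix in the batch is computed without a dedicated kernel.
batched_traces = vmap(torch.trace)(x)
assert torch.allclose(batched_traces, torch.stack([m.trace() for m in x]))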