Merged
5 changes: 5 additions & 0 deletions functorch/_src/decompositions.py
@@ -204,6 +204,11 @@ def silu_backward(grad_output: Tensor, self: Tensor) -> Tensor:
return grad_output * sigmoid * (1 + self * (1 - sigmoid))


@register_decomposition(aten.trace)
def trace(x):
return torch.sum(torch.diagonal(x))


@register_decomposition(aten.softshrink_backward)
def softshrink_backward(grad_output: Tensor, self: Tensor, lambd: float) -> Tensor:
return torch.where((self >= -lambd) & (self <= lambd), grad_output.new_zeros(()), grad_output)
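For intuition, a small sanity check (a sketch, assuming `torch` and `functorch` are importable; `trace_decomp` below is a local illustrative copy, not the registered function) that the decomposition computes the same value as the native op while being written purely in terms of batchable primitives:

```python
import torch

def trace_decomp(x):
    # same expression as the decomposition registered above
    return torch.sum(torch.diagonal(x))

x = torch.randn(3, 3)
# the decomposition agrees with the native aten::trace kernel
assert torch.allclose(trace_decomp(x), torch.trace(x))
```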
12 changes: 12 additions & 0 deletions functorch/_src/eager_transforms.py
@@ -13,6 +13,8 @@
import torch.autograd.forward_ad as fwAD

from .vmap import vmap
from .decompositions import decomposition_table


from functorch._C import (
_wrap_for_grad,
@@ -1269,3 +1271,13 @@ def wrapped(*args, **kwargs):
finally:
_func_decrement_nesting()
return wrapped


def _register_jit_decomposition(decomp):
assert decomp in decomposition_table, f"could not find {decomp}"
decomp_fn = decomposition_table[decomp]
scripted_decomp_fn = torch.jit.script(decomp_fn)
torch.jit._register_decomposition(decomp, scripted_decomp_fn.graph)


_register_jit_decomposition(torch.ops.aten.trace.default)
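The helper scripts each Python decomposition and hands its TorchScript graph to the JIT decomposition registry. Roughly the same flow, sketched outside the registry (`torch.jit._register_decomposition` is an internal API, so only the scripting step is shown here):

```python
import torch

def trace(x: torch.Tensor) -> torch.Tensor:
    return torch.sum(torch.diagonal(x))

# scripting yields the graph that _register_jit_decomposition
# forwards to the decomposition registry for aten::trace
scripted = torch.jit.script(trace)
print(scripted.graph)
```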
10 changes: 9 additions & 1 deletion functorch/csrc/BatchRulesHelper.cpp
@@ -5,6 +5,7 @@
// LICENSE file in the root directory of this source tree.

#include <functorch/csrc/BatchRulesHelper.h>
#include <torch/csrc/jit/runtime/decomposition_registry.h>
#include <ATen/WrapDimUtils.h>

namespace at { namespace functorch {
@@ -63,7 +64,7 @@ VmapDimVector getPhysicalDims(const Tensor& tensor, bool has_batch_dim, IntArray
result.push_back(maybe_wrap_dim(d, rank)+1);
} else {
result.push_back(maybe_wrap_dim(d, rank));
}
}
}
return result;
}
@@ -132,4 +133,11 @@ void vmapIncompatibleInplaceError(const char* schema_name) {
"please file a bug report instead.");
}

void run_jit_decomposition(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
const auto& schema = op.schema();
// TODO: templatize based on op and keep static trace_exec
auto * trace_exec = torch::jit::GetDecompositionExecutor(schema);
trace_exec->run((*stack));
}

}}
6 changes: 6 additions & 0 deletions functorch/csrc/BatchRulesHelper.h
@@ -192,6 +192,12 @@ inline void handle_variadic_bdims(std::vector<std::pair<Tensor, optional<int64_t
#define VARIADIC_BDIMS_BOXED(op) \
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_variadic_bdims), &handle_variadic_bdims>>());

void run_jit_decomposition(const c10::OperatorHandle& op, torch::jit::Stack* stack);

#define RUN_JIT_DECOMPOSITION(op) \
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&run_jit_decomposition>());


using UnpackedBatchedTensor = std::tuple<Tensor,optional<int64_t>>;

inline void find_and_unpack_tensors(
6 changes: 1 addition & 5 deletions functorch/csrc/BatchRulesViews.cpp
@@ -151,10 +151,6 @@ std::tuple<Tensor,optional<int64_t>> _unsafe_view_batch_rule(
return std::make_tuple(at::_unsafe_view(self_, view_size), 0);
}

Tensor trace_decomp(const Tensor& self) {
return at::sum(at::diagonal(self));
}

std::tuple<Tensor,optional<int64_t>> flip_batch_rule(const Tensor& self, optional<int64_t> self_bdim, IntArrayRef dims) {
auto self_ = moveBatchDimToFront(self, self_bdim);
VmapDimVector new_dims;
@@ -511,7 +507,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT(chunk, chunk_batching_rule);
m.impl("flatten.using_ints", static_cast<decltype(&ATEN_FN2(flatten, using_ints))>(native::flatten));
VMAP_SUPPORT(flip, flip_batch_rule);
m.impl("trace", trace_decomp);
RUN_JIT_DECOMPOSITION(trace)
VMAP_SUPPORT(tril, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(tril)));
VMAP_SUPPORT(triu, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(triu)));
VMAP_SUPPORT(repeat, repeat_batch_rule);
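With the hand-written `trace_decomp` removed, `trace` under the batched key now runs through the scripted decomposition registered above. From Python the observable behavior should be unchanged, as in this sketch (assuming `functorch` is installed):

```python
import torch
from functorch import vmap, grad

# batched trace, dispatched through the registered JIT decomposition
xs = torch.randn(4, 3, 3)
assert torch.allclose(vmap(torch.trace)(xs),
                      torch.stack([torch.trace(m) for m in xs]))

# grad still composes: d(trace(x))/dx is the identity matrix
x = torch.randn(3, 3)
assert torch.allclose(grad(torch.trace)(x), torch.eye(3))
```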