53 changes: 53 additions & 0 deletions functorch/csrc/ATenDecompositions.h
@@ -0,0 +1,53 @@
#pragma once

#include <ATen/ATen.h>

namespace at { namespace functorch {

// TODO: Figure out how to delete all of these and replace them
// with the "official" decompositions that are written in Python.

inline Tensor nll_loss_backward_decomp(
    const at::Tensor & grad_output, const at::Tensor & self,
    const at::Tensor & target, const c10::optional<at::Tensor> & weight,
    int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight) {

  int64_t channel_dim = 1;
  if (self.dim() < 2) {
    channel_dim = 0;
  }
  auto self_ = self;
  auto target_ = target.unsqueeze(channel_dim);

  auto grad_output_ = grad_output;
  if (reduction == Reduction::Mean) {
    grad_output_ = grad_output_ / total_weight;
  }

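  // Start from a gradient that is -1 at each target class and 0 elsewhere.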
  auto grad_input = at::zeros_like(self);
  grad_input = at::scatter(grad_input, channel_dim, target_, -1.0);

  if (grad_output_.dim() < grad_input.dim() && grad_output_.dim() > 0) {
    grad_output_ = grad_output_.unsqueeze(channel_dim);
  }

  Tensor weight_;
  if (weight && weight->defined()) {
    auto shape = weight->sizes();
    std::vector<int64_t> new_shape(self_.dim(), 1);
    new_shape[channel_dim] = shape[0];
    weight_ = weight->reshape(new_shape);
    grad_output_ = grad_output_ * weight_;
  }

  bool has_ignore_index = ignore_index >= 0;
  Tensor ignore_index_mask;
  if (has_ignore_index) {
    ignore_index_mask = target_ != ignore_index;
    grad_output_ = grad_output_ * ignore_index_mask;
  }

  return grad_input * grad_output_;
}

}} // namespace at::functorch
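As a quick intuition check, the decomposition above can be compared against the eager kernel. A minimal sketch (not part of this PR; check_nll_decomp is a hypothetical helper name, and a standard ATen build is assumed):

#include <ATen/ATen.h>
#include <functorch/csrc/ATenDecompositions.h>

void check_nll_decomp() {
  auto input = at::randn({4, 3}).log_softmax(1);    // log-probabilities, N=4, C=3
  auto target = at::randint(0, 3, {4}, at::kLong);
  auto fwd = at::nll_loss_forward(input, target, c10::nullopt, at::Reduction::Mean, /*ignore_index=*/-100);
  auto total_weight = std::get<1>(fwd);
  auto grad_output = at::ones({});                  // gradient of the scalar mean loss
  auto expected = at::nll_loss_backward(
      grad_output, input, target, c10::nullopt, at::Reduction::Mean, -100, total_weight);
  auto actual = at::functorch::nll_loss_backward_decomp(
      grad_output, input, target, c10::nullopt, at::Reduction::Mean, -100, total_weight);
  TORCH_CHECK(at::allclose(actual, expected));      // both give -1/N at each target class
}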
139 changes: 94 additions & 45 deletions functorch/csrc/DynamicLayer.cpp
@@ -15,6 +15,7 @@
#include <torch/csrc/autograd/variable.h>
#include <c10/util/irange.h>
#include <ATen/FuncTorchTLS.h>
#include <functorch/csrc/ATenDecompositions.h>

namespace at {
namespace functorch {
@@ -159,12 +160,6 @@ struct SaveLocalDispatchKeySet {
  SaveLocalDispatchKeySet& operator=(const SaveLocalDispatchKeySet&) = delete;
};

static c10::impl::ForceDispatchKeyGuard
restoreLocalDispatchKeySetRAII(const DynamicLayer& layer) {
  auto tmp = layer.interpreter().getSavedLocalDispatchKeySet();
  return c10::impl::ForceDispatchKeyGuard(tmp);
}

const std::vector<DynamicLayer>& getDynamicLayerStack() {
  return dynamicLayerStackAccessor();
}
@@ -330,38 +325,6 @@ std::ostream& operator<< (std::ostream& os, const std::vector<DynamicLayer>& dls
  return os;
}

static bool allTensors(
    ArrayRef<IValue> args,
    std::function<bool(const Tensor&)> pred) {
  for (const auto& ivalue : args) {
    // Tensor?[] translates to a c10::List<IValue> so we need to peek inside List
    if (ivalue.isList()) {
      for (const auto& elt : ivalue.toListRef()) {
        if (elt.isTensor() && !pred(elt.toTensor())) {
          return false;
        }
      }
      continue;
    }
    if (ivalue.isTensorList()) {
      for (const auto& elt : ivalue.toTensorList()) {
        if (!pred(elt)) {
          return false;
        }
      }
      continue;
    }
    TORCH_INTERNAL_ASSERT(!ivalue.isGenericDict(), "No operators can accept GenericDict");
    if (!ivalue.isTensor()) {
      continue;
    }
    if (!pred(ivalue.toTensor())) {
      return false;
    }
  }
  return true;
}

bool isInplaceOp(const FunctionSchema& schema) {
  if (!schema.is_mutable() || schema.returns().size() != 1) {
    return false;
@@ -403,20 +366,89 @@ WithoutTop::~WithoutTop() {
  pushDynamicLayer(std::move(layer_));
}

void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
// NOTE: [forward-mode AD decompositions hack]
//
// The mechanism is: in DynamicLayerFrontMode, IF we are dispatching on the
// jvp transform, AND we have a decomposition for the operation, then run
// the decomposition.
//
// Let's break that down. There are a couple of moving pieces.
//
// 0. How do we know what transform we're dispatching on?
// Easy, check the top of the DynamicLayerStack and read the transform.
//
// 1. Next, we must identify when an operation (e.g. nll_loss_backward)
// is being dispatched. The slow way to do this is to string-compare
// OperatorHandle::schema().name(). We do something a little faster, which is:
// - register a special kernel to the DynamicLayerFrontMode key
// (see FALLBACK_WITH_ID)
// - that special kernel invokes dynamicLayerFrontFallbackOperator with
// a special enum value (ATenOpId) that identifies the operation.
//
// 2. Next, we need to call the decomposition. See call_decomposition_for_jvp.
// The decompositions are written in C++ right now, but we really want to just
// reuse the decompositions that we have in Python (because those are actually
// tested).

// Ideally c10::OperatorHandle would have a field like this
// to identify the operator.
// The stuff here should map 1:1 with the operator name.
// aten::nll_loss_backward -> nll_loss_backward
// aten::add.Tensor -> add_Tensor
enum class ATenOpId {
  nll_loss_backward,
  nll_loss2d_backward,
};

static void call_decomposition_for_jvp(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack,
    ATenOpId op_id) {
  switch (op_id) {
    case ATenOpId::nll_loss2d_backward:
    case ATenOpId::nll_loss_backward: {
      ArrayRef<IValue> args = torch::jit::last(stack, 7);
      auto result = nll_loss_backward_decomp(
          args[0].toTensor(),
          args[1].toTensor(),
          args[2].toTensor(),
          args[3].toOptional<at::Tensor>(),  // weight is optional; plain toTensor() would throw on None
          args[4].toInt(),
          args[5].toInt(),
          args[6].toTensor()
      );
      torch::jit::pop(*stack, 7);
      torch::jit::push(stack, result);
      return;
    }
    default:
      TORCH_INTERNAL_ASSERT(false);
  }
}
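For context, the above follows the boxed-kernel calling convention: an op's arguments are the trailing IValues on the JIT stack, and the kernel must pop all of them and push exactly the declared returns. A minimal sketch for a single-Tensor op (illustrative only; boxed_neg_example is not part of this PR):

// Sketch: a boxed kernel for an op with one Tensor argument and one return.
static void boxed_neg_example(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  auto self = torch::jit::pop(*stack).toTensor();  // pop the only argument
  torch::jit::push(stack, at::neg(self));          // push the only return value
}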

static void dynamicLayerFrontFallbackOperator(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack,
    optional<ATenOpId> maybe_op_id) {
  auto& dynamicLayerStack = dynamicLayerStackAccessor();
  TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0);
#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
  if (c10::show_dispatch_trace_enabled()) {
    std::cout << dynamicLayerStack << std::endl;
    dump_local_tls();
  }
#endif

  // Hack: if we are dispatching on the jvp transform and a decomposition is
  // registered for this op, run the decomposition instead.
  if (dynamicLayerStack.back().interpreter().key() == TransformType::Jvp &&
      maybe_op_id.has_value()) {
    return call_decomposition_for_jvp(op, stack, *maybe_op_id);
  }

  // Save the current LocalDispatchKeySet (to the current DynamicLayer).
  // Upon exiting the current scope, that LocalDispatchKeySet gets restored.
  // When the current DynamicLayer dispatches to the next (inner) DynamicLayer,
  // it will also temporarily restore the saved LocalDispatchKeySet.
  SaveLocalDispatchKeySet guard;

// Unwrap escaped GradWrappers
@@ -432,6 +464,17 @@ restoreLocalDispatchKeySetRAII(const c10::impl::LocalDispatchKeySet& key_set) {
  return c10::impl::ForceDispatchKeyGuard(key_set);
}

void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  return dynamicLayerFrontFallbackOperator(op, stack, nullopt);
}

template <ATenOpId op_id>
void dynamicLayerFrontFallbackWithOpId(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) {
  return dynamicLayerFrontFallbackOperator(op, stack, op_id);
}

void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  auto& layer = dynamicLayerStackAccessor().back();
  auto restore_guard = restoreLocalDispatchKeySetRAII(layer.interpreter().getSavedLocalDispatchKeySet());
@@ -448,11 +491,17 @@ TORCH_LIBRARY_IMPL(_, FT_DYNAMIC_LAYER_BACK_MODE_KEY, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerBackFallback>());
}

// TORCH_LIBRARY_IMPL(aten, DynamicLayerFront, m) {
//   m.impl("_unwrap_for_grad", native::_unwrap_for_grad);
//   m.impl("dump_tensor", native::dump_tensor);
//   m.impl("dlevel", native::dlevel);
// }
#define FALLBACK_WITH_ID(op) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallbackWithOpId<ATenOpId::op>>());

#define FALLBACK_WITH_ID2(op, overload) \
  m.impl(#op "." #overload, torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallbackWithOpId<ATenOpId::op ## _ ## overload>>());

TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) {
  FALLBACK_WITH_ID(nll_loss_backward);
  FALLBACK_WITH_ID(nll_loss2d_backward);
}
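FALLBACK_WITH_ID2 is unused so far; per the naming scheme in the ATenOpId comment above, a hypothetical registration for an overloaded op would look like this (add.Tensor is only an example and is not registered by this PR):

// Hypothetical: decomposing aten::add.Tensor would need an `add_Tensor`
// enumerator in ATenOpId plus, inside the TORCH_LIBRARY_IMPL block above:
//   FALLBACK_WITH_ID2(add, Tensor);
// which expands to:
//   m.impl("add.Tensor", torch::CppFunction::makeFromBoxedFunction<
//       &dynamicLayerFrontFallbackWithOpId<ATenOpId::add_Tensor>>());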


}
} // namespace at
27 changes: 25 additions & 2 deletions test/test_ops.py
@@ -1162,7 +1162,6 @@ def test_vjpvmap(self, device, dtype, op):
        xfail('nn.functional.layer_norm', ''),
        xfail('nn.functional.logsigmoid', ''),
        xfail('nn.functional.mse_loss', ''),
        xfail('nn.functional.nll_loss', ''),
        xfail('nn.functional.pad', 'circular'),
        xfail('nn.functional.prelu', ''),
        xfail('nn.functional.softmin', ''),
@@ -1254,7 +1253,31 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
            expected = (tree_unflatten(primals_out, spec), tree_unflatten(tangents_out, spec))
            return expected

        expected = reference(primals, cotangents, primals_tangents, cotangents_tangents)
        def double_backward_trick_reference(primals, cotangents, primals_tangents, cotangents_tangents):
            def f(*all_inputs):
                p = all_inputs[:len(primals)]
                c = all_inputs[len(primals):]
                _, vjp_fn = ref_vjp(fn, *p)
                return vjp_fn(c)

            flat_primals, _ = tree_flatten((primals, cotangents))
            flat_tangents, _ = tree_flatten((primals_tangents, cotangents_tangents))
            flat_primals = tuple(flat_primals)
            flat_tangents = tuple(flat_tangents)

            # torch.autograd.functional.jvp doesn't actually invoke forward-mode
            # AD; it computes the JVP via the "double backward trick".
            result = torch.autograd.functional.jvp(f, flat_primals, v=flat_tangents)
            return result

        # HACK: obviously pytorch should also have the same coverage
        FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH = {
            'nn.functional.nll_loss',
        }
        if op.name in FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH:
            expected = double_backward_trick_reference(primals, cotangents, primals_tangents, cotangents_tangents)
        else:
            expected = reference(primals, cotangents, primals_tangents, cotangents_tangents)
        self.assertEqual(result, expected)
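For reference, the "double backward trick" used by torch.autograd.functional.jvp computes a JVP with two reverse-mode passes instead of true forward-mode AD. With J the Jacobian of f at the primals:

    g(u) = vjp(f, primals)(u) = J^T u
    <v, g(u)> = v^T J^T u = <J v, u>,  so  d/du <v, g(u)> = J v

Differentiating the VJP with respect to its cotangent u recovers the JVP, which is why it works as a reference even for ops where PyTorch lacks a forward-mode formula.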

