diff --git a/functorch/csrc/ATenDecompositions.h b/functorch/csrc/ATenDecompositions.h
new file mode 100644
index 000000000..9c121c796
--- /dev/null
+++ b/functorch/csrc/ATenDecompositions.h
@@ -0,0 +1,53 @@
+#pragma once
+
+#include <ATen/ATen.h>
+
+namespace at { namespace functorch {
+
+// TODO: Figure out how to delete all of these and replace them
+// with the "official" decompositions that are written in Python.
+
+inline Tensor nll_loss_backward_decomp(
+    const at::Tensor & grad_output, const at::Tensor & self,
+    const at::Tensor & target, const c10::optional<at::Tensor> & weight,
+    int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight) {
+
+  int64_t channel_dim = 1;
+  if (self.dim() < 2) {
+    channel_dim = 0;
+  }
+  auto self_ = self;
+  auto target_ = target.unsqueeze(channel_dim);
+
+  auto grad_output_ = grad_output;
+  if (reduction == Reduction::Mean) {
+    grad_output_ = grad_output_ / total_weight;
+  }
+
+  auto grad_input = at::zeros_like(self);
+  grad_input = at::scatter(grad_input, channel_dim, target_, -1.0);
+
+  if (grad_output_.dim() < grad_input.dim() && grad_output_.dim() > 0) {
+    grad_output_ = grad_output_.unsqueeze(channel_dim);
+  }
+
+  Tensor weight_;
+  if (weight && weight->defined()) {
+    auto shape = weight->sizes();
+    std::vector<int64_t> new_shape(self_.dim(), 1);
+    new_shape[channel_dim] = shape[0];
+    weight_ = weight->reshape(new_shape);
+    grad_output_ = grad_output_ * weight_;
+  }
+
+  bool has_ignore_index = ignore_index >= 0;
+  Tensor ignore_index_mask;
+  if (has_ignore_index) {
+    ignore_index_mask = target_ != ignore_index;
+    grad_output_ = grad_output_ * ignore_index_mask;
+  }
+
+  return grad_input * grad_output_;
+}
+
+}} // namespace at::functorch
diff --git a/functorch/csrc/DynamicLayer.cpp b/functorch/csrc/DynamicLayer.cpp
index 6d96d52aa..75b4656bc 100644
--- a/functorch/csrc/DynamicLayer.cpp
+++ b/functorch/csrc/DynamicLayer.cpp
@@ -15,6 +15,7 @@
 #include
 #include
 #include
+#include <functorch/csrc/ATenDecompositions.h>
 
 namespace at { namespace functorch {
 
@@ -159,12 +160,6 @@ struct SaveLocalDispatchKeySet {
   SaveLocalDispatchKeySet& operator=(const SaveLocalDispatchKeySet&) = delete;
 };
 
-static c10::impl::ForceDispatchKeyGuard
-restoreLocalDispatchKeySetRAII(const DynamicLayer& layer) {
-  auto tmp = layer.interpreter().getSavedLocalDispatchKeySet();
-  return c10::impl::ForceDispatchKeyGuard(tmp);
-}
-
 const std::vector<DynamicLayer>& getDynamicLayerStack() {
   return dynamicLayerStackAccessor();
 }
@@ -330,38 +325,6 @@ std::ostream& operator<< (std::ostream& os, const std::vector<DynamicLayer>& dls
   return os;
 }
 
-static bool allTensors(
-    ArrayRef<IValue> args,
-    std::function<bool(const Tensor&)> pred) {
-  for (const auto& ivalue : args) {
-    // Tensor?[] translates to a c10::List<optional<Tensor>> so we need to peek inside List
-    if (ivalue.isList()) {
-      for (const auto& elt : ivalue.toListRef()) {
-        if (elt.isTensor() && !pred(elt.toTensor())) {
-          return false;
-        }
-      }
-      continue;
-    }
-    if (ivalue.isTensorList()) {
-      for (const auto& elt : ivalue.toTensorList()) {
-        if (!pred(elt)) {
-          return false;
-        }
-      }
-      continue;
-    }
-    TORCH_INTERNAL_ASSERT(!ivalue.isGenericDict(), "No operators can accept GenericDict");
-    if (!ivalue.isTensor()) {
-      continue;
-    }
-    if (!pred(ivalue.toTensor())) {
-      return false;
-    }
-  }
-  return true;
-}
-
 bool isInplaceOp(const FunctionSchema& schema) {
   if (!schema.is_mutable() || schema.returns().size() != 1) {
     return false;
@@ -403,8 +366,72 @@ WithoutTop::~WithoutTop() {
   pushDynamicLayer(std::move(layer_));
 }
 
-void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+// NOTE: [forward-mode AD decompositions hack]
+//
+// The mechanism is: in DynamicLayerFrontMode, IF we are dispatching on the
+// jvp transform, AND we have a decomposition for the operation, then run
+// the decomposition.
+//
+// Let's break that down. There are a couple of moving pieces.
+//
+// 0. How do we know what transform we're dispatching on?
+//    Easy, check the top of the DynamicLayerStack and read the transform.
+//
+// 1. Next, we must identify when an operation (e.g. nll_loss_backward)
+//    gets dispatched to. The slow way to do this is to string-compare
+//    OperatorHandle::schema::name. We do something a little faster, which is:
+//    - register a special kernel to the DynamicLayerFrontMode key
+//      (see FALLBACK_WITH_ID)
+//    - that special kernel invokes dynamicLayerFrontFallbackOperator with
+//      a special enum value (ATenOpId) that identifies the operation.
+//
+// 2. Next, we need to call the decomposition. See call_decomposition_for_jvp.
+//    The decompositions are written in C++ right now, but we really want to just
+//    reuse the decompositions that we have in Python (because those are actually
+//    tested).
+
+// Ideally c10::OperatorHandle would have a field like this
+// to identify the operator.
+// The stuff here should map 1:1 with the operator name.
+// aten::nll_loss_backward -> nll_loss_backward
+// aten::add.Tensor -> add_Tensor
+enum class ATenOpId {
+  nll_loss_backward,
+  nll_loss2d_backward,
+};
+
+static void call_decomposition_for_jvp(
+    const c10::OperatorHandle& op,
+    torch::jit::Stack* stack,
+    ATenOpId op_id) {
+  switch (op_id) {
+    case ATenOpId::nll_loss2d_backward:
+    case ATenOpId::nll_loss_backward: {
+      ArrayRef<IValue> args = torch::jit::last(stack, 7);
+      auto result = nll_loss_backward_decomp(
+          args[0].toTensor(),
+          args[1].toTensor(),
+          args[2].toTensor(),
+          args[3].toOptional<at::Tensor>(),
+          args[4].toInt(),
+          args[5].toInt(),
+          args[6].toTensor()
+      );
+      torch::jit::pop(*stack, 7);
+      torch::jit::push(stack, result);
+      return;
+    }
+    default:
+      TORCH_INTERNAL_ASSERT(false);
+  }
+}
+
+static void dynamicLayerFrontFallbackOperator(
+    const c10::OperatorHandle& op,
+    torch::jit::Stack* stack,
+    optional<ATenOpId> maybe_op_id) {
   auto& dynamicLayerStack = dynamicLayerStackAccessor();
+  TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0);
 #ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
   if (c10::show_dispatch_trace_enabled()) {
     std::cout << dynamicLayerStack << std::endl;
@@ -412,11 +439,16 @@ void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack*
   }
 #endif
 
+  // Hack: if jvp and we have a decomposition registered, then do the decomposition
+  if (dynamicLayerStack.back().interpreter().key() == TransformType::Jvp &&
+      maybe_op_id.has_value()) {
+    return call_decomposition_for_jvp(op, stack, *maybe_op_id);
+  }
+
   // Save the current LocalDispatchKeySet (to the current DynamicLayer).
   // Upon exiting the current scope, that LocalDispatchKeySet gets restored.
   // When the current DynamicLayer dispatches to the next (inner) DynamicLayer,
   // it will also temporarily restore the saved LocalDispatchKeySet.
-  TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0);
   SaveLocalDispatchKeySet guard;
 
   // Unwrap escaped GradWrappers
@@ -432,6 +464,17 @@ restoreLocalDispatchKeySetRAII(const c10::impl::LocalDispatchKeySet& key_set) {
   return c10::impl::ForceDispatchKeyGuard(key_set);
 }
 
+void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  return dynamicLayerFrontFallbackOperator(op, stack, nullopt);
+}
+
+template <ATenOpId op_id>
+void dynamicLayerFrontFallbackWithOpId(
+    const c10::OperatorHandle& op,
+    torch::jit::Stack* stack) {
+  return dynamicLayerFrontFallbackOperator(op, stack, op_id);
+}
+
 void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
   auto& layer = dynamicLayerStackAccessor().back();
   auto restore_guard = restoreLocalDispatchKeySetRAII(layer.interpreter().getSavedLocalDispatchKeySet());
@@ -448,11 +491,17 @@ TORCH_LIBRARY_IMPL(_, FT_DYNAMIC_LAYER_BACK_MODE_KEY, m) {
   m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerBackFallback>());
 }
 
-// TORCH_LIBRARY_IMPL(aten, DynamicLayerFront, m) {
-//   m.impl("_unwrap_for_grad", native::_unwrap_for_grad);
-//   m.impl("dump_tensor", native::dump_tensor);
-//   m.impl("dlevel", native::dlevel);
-// }
+#define FALLBACK_WITH_ID(op) \
+  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallbackWithOpId<ATenOpId::op>>());
+
+#define FALLBACK_WITH_ID2(op, overload) \
+  m.impl(#op "." #overload, torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallbackWithOpId<ATenOpId::op##_##overload>>());
+
+TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) {
+  FALLBACK_WITH_ID(nll_loss_backward);
+  FALLBACK_WITH_ID(nll_loss2d_backward);
+}
+
 }
 } // namespace at
diff --git a/test/test_ops.py b/test/test_ops.py
index 9308b6b90..f76b59bea 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1162,7 +1162,6 @@ def test_vjpvmap(self, device, dtype, op):
         xfail('nn.functional.layer_norm', ''),
         xfail('nn.functional.logsigmoid', ''),
         xfail('nn.functional.mse_loss', ''),
-        xfail('nn.functional.nll_loss', ''),
         xfail('nn.functional.pad', 'circular'),
         xfail('nn.functional.prelu', ''),
         xfail('nn.functional.softmin', ''),
@@ -1254,7 +1253,31 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
             expected = (tree_unflatten(primals_out, spec), tree_unflatten(tangents_out, spec))
             return expected
 
-        expected = reference(primals, cotangents, primals_tangents, cotangents_tangents)
+        def double_backward_trick_reference(primals, cotangents, primals_tangents, cotangents_tangents):
+            def f(*all_inputs):
+                p = all_inputs[:len(primals)]
+                c = all_inputs[len(primals):]
+                _, vjp_fn = ref_vjp(fn, *p)
+                return vjp_fn(c)
+
+            flat_primals, _ = tree_flatten((primals, cotangents))
+            flat_tangents, _ = tree_flatten((primals_tangents, cotangents_tangents))
+            flat_primals = tuple(flat_primals)
+            flat_tangents = tuple(flat_tangents)
+
+            # doesn't actually invoke forward-mode AD; it does the
+            # "double backward trick"
+            result = torch.autograd.functional.jvp(f, flat_primals, v=flat_tangents)
+            return result
+
+        # HACK: obviously pytorch should also have the same coverage
+        FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH = {
+            'nn.functional.nll_loss',
+        }
+        if op.name in FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH:
+            expected = double_backward_trick_reference(primals, cotangents, primals_tangents, cotangents_tangents)
+        else:
+            expected = reference(primals, cotangents, primals_tangents, cotangents_tangents)
         self.assertEqual(result, expected)
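For reference, a minimal usage sketch of the path this patch enables (not part of the diff). It assumes the standalone functorch Python API of this era (functorch.grad, functorch.jvp); the function and tensor names below are illustrative only. Composing jvp over grad pushes aten::nll_loss_backward through DynamicLayerFrontMode under the Jvp transform, which is exactly the case redirected to nll_loss_backward_decomp above.

# Sketch only: jvp-over-grad of nll_loss exercises the new decomposition.
# All names/shapes here are illustrative, not from the patch.
import torch
import torch.nn.functional as F
from functorch import grad, jvp  # assumes the standalone functorch package

target = torch.tensor([0, 2, 4])      # class indices for 3 samples

def loss(log_probs):
    # F.nll_loss's backward is aten::nll_loss_backward, the op decomposed above
    return F.nll_loss(log_probs, target)

log_probs = torch.randn(3, 5)          # treated as log-probabilities
tangent = torch.randn(3, 5)            # direction for forward-mode AD

# grad(loss) runs nll_loss_backward; wrapping it in jvp means that, by the time
# the op is interpreted by the Jvp layer, the front-mode fallback calls the C++
# decomposition instead of relying on a (missing) forward-AD kernel.
primal_grad, tangent_grad = jvp(grad(loss), (log_probs,), (tangent,))
print(primal_grad.shape, tangent_grad.shape)  # torch.Size([3, 5]) twice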