Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions torch_patches/16851.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
commit 4962f54ee5ee02edb5da45ea14b83a1beb25d30d
Author: Alex Şuhan <asuhan@google.com>
Date: Thu Feb 7 12:26:16 2019 -0800

Fix autodiff of nll_loss

diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp
index 8122aeafc..c43669d8d 100644
--- a/torch/csrc/jit/autodiff.cpp
+++ b/torch/csrc/jit/autodiff.cpp
@@ -129,7 +129,7 @@ bool isDifferentiable(Node* n) {
if (n->matches(
"aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
// TODO(asuhan): support weight
- return n->namedInput(attr::weight)->node()->kind() == prim::Undefined;
+ return n->namedInput(attr::weight)->node()->kind() == prim::None;
}

// linear blocks may appear as inputs to graph executors, but they are removed
@@ -717,7 +717,7 @@ class GradientHelper {
"aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
auto graph = node->owningGraph();
auto total_weight = graph->insertNode(graph->createUndefined());
- auto weight = graph->insertNode(graph->createUndefined());
+ auto weight = graph->insertNode(graph->createNone(TensorType::get()));
auto backward_value = graph->insert(
aten::nll_loss_backward,
{grads.at(0).value(),
1 change: 1 addition & 0 deletions torch_xla/csrc/translator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,7 @@ CreateTranslationHandlers() {
(*t)[at::aten::size] = TranslateSize;
(*t)[at::prim::Constant] = TranslateConstant;
(*t)[at::prim::Undefined] = TranslateUndefined;
(*t)[at::prim::None] = TranslateUndefined;
(*t)[at::aten::_grad_sum_to_size] = TranslateGradSumToSize;
(*t)[at::prim::ListConstruct] = TranslateNop;
return t;
Expand Down