diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp
index 44dc9ad163ca..cfa4da9f383e 100644
--- a/torch_xla/csrc/ir.cpp
+++ b/torch_xla/csrc/ir.cpp
@@ -102,12 +102,14 @@ void Node::ReplaceAllUsesWith(NodePtr node, size_t index) {
 }
 
 XlaOpVector Node::ReturnOp(xla::XlaOp op, LoweringContext* loctx) const {
+  XLA_CHECK_EQ(num_outputs(), 1);
   loctx->AssignOutputOp(Output(this), op);
   return XlaOpVector({std::move(op)});
 }
 
 XlaOpVector Node::ReturnOps(tensorflow::gtl::ArraySlice<xla::XlaOp> ops,
                             LoweringContext* loctx) const {
+  XLA_CHECK_EQ(num_outputs(), ops.size());
   XlaOpVector result;
   for (size_t i = 0; i < ops.size(); ++i) {
     loctx->AssignOutputOp(Output(this, i), ops[i]);
diff --git a/torch_xla/csrc/ops/conv2d_backward.cpp b/torch_xla/csrc/ops/conv2d_backward.cpp
index a289c675804c..56b11c0732dc 100644
--- a/torch_xla/csrc/ops/conv2d_backward.cpp
+++ b/torch_xla/csrc/ops/conv2d_backward.cpp
@@ -49,7 +49,7 @@ Conv2dBackward::Conv2dBackward(
     : Node(ir::OpKind(at::aten::thnn_conv2d_backward),
            {grad_output, input, weight},
            NodeOutputShape(grad_output, input, weight, stride, padding),
-           /*num_outputs=*/1, xla::util::MHash(stride, padding)),
+           /*num_outputs=*/3, xla::util::MHash(stride, padding)),
       stride_(stride.begin(), stride.end()),
       padding_(padding.begin(), padding.end()),
       precision_(MakePrecisionConfig(use_full_conv_precision)) {}