Skip to content

Commit

Permalink
[ONNX] Fix param names (#50764)
Browse files Browse the repository at this point in the history
Preserve name of parameters for ONNX.

Looks like output->copyMetadata(input) API is giving the same debugName to the output. So the name of the original input is changed. This update avoid the name change by just copying the type.
  • Loading branch information
neginraoof committed Jan 22, 2021
1 parent 9145d90 commit 891c0e7
Showing 1 changed file with 9 additions and 11 deletions.
20 changes: 9 additions & 11 deletions torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -571,29 +571,27 @@ static void PrepareForRemoveMutations(MutationRemover& mr, Block* b) {
<< "Warning: ONNX Preprocess - Removing mutation on block inputs. "
<< "This changes graph semantics." << std::endl;

Node* newNode = nullptr;
if (input->type()->kind() == TypeKind::ListType) {
// Create an aten::list to clone the list in graph inputs
auto newNode = node->owningGraph()->create(aten::list, 1);
newNode->output()->copyMetadata(input);
newNode = node->owningGraph()->create(aten::list, 1);
newNode->output()->setType(input->type());
newNode->addInput(input);
newNode->insertBefore(node);
node->replaceInput(index, newNode->output());
input->replaceAllUsesAfterNodeWith(node, newNode->output());
b->prependNode(newNode);
} else {
// Create an aten::clone to clone the tensor in graph inputs
auto newNode = node->owningGraph()->create(aten::clone, 1);
newNode->output()->copyMetadata(input);
newNode = node->owningGraph()->create(aten::clone, 1);
newNode->output()->setType(input->type());
newNode->addInput(input);

auto* noneNode = node->owningGraph()->create(prim::Constant);
noneNode->output()->setType(NoneType::get());
newNode->addInput(noneNode->output());

newNode->insertBefore(node);
b->prependNode(newNode);
noneNode->insertBefore(newNode);
node->replaceInput(index, newNode->output());
input->replaceAllUsesAfterNodeWith(node, newNode->output());
}
node->replaceInput(index, newNode->output());
input->replaceAllUsesAfterNodeWith(node, newNode->output());
}
}
}
Expand Down

0 comments on commit 891c0e7

Please sign in to comment.