diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index b602542b3ed3..e0b7e15556eb 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -846,22 +846,40 @@ void Value::replaceAllUsesAfterNodeWith(const Node* node, Value* newValue) { uses_.end()); } -size_t findArgument(const FunctionSchema& the_schema, Symbol name) { - auto name_str = name.toUnqualString(); +size_t findArgument( + const FunctionSchema& the_schema, + const std::string& unqualName) { for (size_t i = 0; i < the_schema.arguments().size(); ++i) { const Argument* arg = &the_schema.arguments()[i]; - if (arg->name() == name_str) { + if (arg->name() == unqualName) { return i; } } throw std::runtime_error( - std::string("Couldn't find an argument called ") + name.toQualString()); + std::string("Couldn't find an argument called ") + unqualName); +} + +size_t findArgument(const FunctionSchema& the_schema, Symbol name) { + const auto unqualName = name.toUnqualString(); + return findArgument(the_schema, unqualName); } c10::optional Node::get(Symbol name) const { return toIValue(namedInput(name)); } +bool Node::hasNamedInput(const std::string& name) const { + for (const auto& argument : schema().arguments()) { + if (argument.name() == name) { + return true; + } + } + return false; +} + +Value* Node::namedInput(const std::string& unqualName) const { + return input(findArgument(schema(), unqualName)); +} Value* Node::namedInput(Symbol name) const { return input(findArgument(schema(), name)); } diff --git a/torch/csrc/jit/ir/ir.h b/torch/csrc/jit/ir/ir.h index 665bd9797b26..dbd9fb5ca755 100644 --- a/torch/csrc/jit/ir/ir.h +++ b/torch/csrc/jit/ir/ir.h @@ -414,6 +414,8 @@ struct TORCH_API Node { return inputs_.at(i); } + bool hasNamedInput(const std::string& unqualName) const; + Value* namedInput(const std::string& unqualName) const; Value* namedInput(Symbol name) const; c10::optional get(Symbol name) const;