Skip to content

Commit

Permalink
Add string versions of argument funcs in jit Node (#45464)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #45464

Usage of Symbols to find arguments requires one to generate a nonsense symbol for inputs which don't already have one. The intention of symbols appears to be something of an internalized string, but the namespace component doesn't apply to an argument. In order to access the arguments by name without adding new symbols, versions of those functions with std::string input was added. These can be proved valid based on the existing codepath. Additionally, a hasNamedInput convenience function was added to remove the necessity of a try/catch block in user code.

The primary motivation is to be able to easily handle the variable number of arguments in glow, so that the arange op may be implemented.

Reviewed By: eellison

Differential Revision: D23972315

fbshipit-source-id: 3e0b41910cf07e916186f1506281fb221725a91b
  • Loading branch information
spaugh authored and facebook-github-bot committed Oct 2, 2020
1 parent b234acd commit cdf93b0
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
26 changes: 22 additions & 4 deletions torch/csrc/jit/ir/ir.cpp
Expand Up @@ -846,22 +846,40 @@ void Value::replaceAllUsesAfterNodeWith(const Node* node, Value* newValue) {
uses_.end());
}

size_t findArgument(const FunctionSchema& the_schema, Symbol name) {
auto name_str = name.toUnqualString();
size_t findArgument(
const FunctionSchema& the_schema,
const std::string& unqualName) {
for (size_t i = 0; i < the_schema.arguments().size(); ++i) {
const Argument* arg = &the_schema.arguments()[i];
if (arg->name() == name_str) {
if (arg->name() == unqualName) {
return i;
}
}
throw std::runtime_error(
std::string("Couldn't find an argument called ") + name.toQualString());
std::string("Couldn't find an argument called ") + unqualName);
}

size_t findArgument(const FunctionSchema& the_schema, Symbol name) {
const auto unqualName = name.toUnqualString();
return findArgument(the_schema, unqualName);
}

c10::optional<IValue> Node::get(Symbol name) const {
return toIValue(namedInput(name));
}

bool Node::hasNamedInput(const std::string& name) const {
for (const auto& argument : schema().arguments()) {
if (argument.name() == name) {
return true;
}
}
return false;
}

Value* Node::namedInput(const std::string& unqualName) const {
return input(findArgument(schema(), unqualName));
}
Value* Node::namedInput(Symbol name) const {
return input(findArgument(schema(), name));
}
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/jit/ir/ir.h
Expand Up @@ -414,6 +414,8 @@ struct TORCH_API Node {
return inputs_.at(i);
}

bool hasNamedInput(const std::string& unqualName) const;
Value* namedInput(const std::string& unqualName) const;
Value* namedInput(Symbol name) const;

c10::optional<IValue> get(Symbol name) const;
Expand Down

0 comments on commit cdf93b0

Please sign in to comment.