Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JIT] add tuple keyword #25474

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 19 additions & 0 deletions test/test_jit.py
Expand Up @@ -3311,6 +3311,25 @@ def stuff(x):
a = (torch.rand(3), torch.rand(3))
self.checkScript(stuff, (a,))

def test_tuple_keyword(self):
def bar():
f = tuple((1, 2)) # noqa: C409
return f

self.checkScript(bar, ())

def foo():
return tuple(1, 2)

self.checkScriptRaisesRegex(foo, (), Exception,
"1 argument")

def cant_infer_size():
return tuple([1, 2, 3]) # noqa: C409

with self.assertRaisesRegex(Exception, "cannot statically infer the expected"):
torch.jit.script(cant_infer_size)

def test_tuple_create_return(self):
def stuff2(x):
# type: (int) -> Tuple[Tensor, Tensor]
Expand Down
12 changes: 11 additions & 1 deletion torch/csrc/jit/script/compiler.cpp
Expand Up @@ -348,6 +348,7 @@ struct Environment {
if (!retval) {
static std::unordered_map<std::string, SugaredValuePtr> globals = {
{"print", std::make_shared<PrintValue>()},
{"tuple", std::make_shared<TupleCallValue>()},
{"float",
makeMagic(
"__float__",
Expand Down Expand Up @@ -2173,6 +2174,14 @@ struct to_ir {
auto out = graph->insertNode(graph->createUninitialized(type))
->setSourceRange(loc);
return std::make_shared<SimpleValue>(out->output());
} else if (auto tuple_call = dynamic_cast<TupleCallValue*>(sv.get())) {
checkApplyExpr(apply, loc, /*expected_inputs*/ 1);
auto arg = emitSugaredExpr(apply.inputs()[0], 1);
auto inputs = arg->asTuple(apply.range(), method);
auto inp_values = fmap(
inputs, [&](SugaredValuePtr sv) { return sv->asValue(loc, method); });
return std::make_shared<SimpleValue>(
graph->insertNode(graph->createTuple(inp_values))->output());
} else if (auto isinstance = dynamic_cast<IsInstanceValue*>(sv.get())) {
// NOTE: for `isinstance` builtin call in JIT, we only check the static
// types on the inputs to evaluate, and insert the corresponding constant
Expand Down Expand Up @@ -3086,7 +3095,8 @@ CompilationUnit::CompilationUnit(const std::string& source)
define(c10::nullopt, source, nativeResolver(), nullptr);
}

c10::QualifiedName CompilationUnit::mangle(const c10::QualifiedName& name) const {
c10::QualifiedName CompilationUnit::mangle(
const c10::QualifiedName& name) const {
static const std::string manglePrefix = "___torch_mangle_";
std::vector<std::string> atoms = name.atoms();

Expand Down
8 changes: 8 additions & 0 deletions torch/csrc/jit/script/sugared_value.h
Expand Up @@ -439,6 +439,14 @@ struct TORCH_API IsInstanceValue : SugaredValue {
}
};

// matched against for special handling of tuple() call
struct TORCH_API TupleCallValue : SugaredValue {
TupleCallValue() = default;
std::string kind() const override {
return "tuple";
}
};

// matched against for special handling of range expressions
struct TORCH_API RangeValue : SugaredValue {
RangeValue(const SourceRange& loc, Function& m, std::vector<Value*> inputs);
Expand Down