diff --git a/test/expect/TestJit.test_default_values.expect b/test/expect/TestJit.test_default_values.expect
new file mode 100644
index 000000000000..0976528f0613
--- /dev/null
+++ b/test/expect/TestJit.test_default_values.expect
@@ -0,0 +1,19 @@
+graph(%x : Dynamic
+      %a : Dynamic = 0.5
+[ Variable[CPUDoubleType]{} ]
+      %b : Dynamic = 10
+[ Variable[CPULongType]{} ]
+      %c : Dynamic = 20
+[ Variable[CPULongType]{} ]
+      %d : Dynamic = 50
+[ Variable[CPULongType]{} ]) {
+  %5 : int = prim::Constant[value=1]()
+  %6 : Dynamic = aten::add(%x, %a, %5)
+  %7 : int = prim::Constant[value=1]()
+  %8 : Dynamic = aten::add(%6, %b, %7)
+  %9 : int = prim::Constant[value=1]()
+  %10 : Dynamic = aten::add(%8, %c, %9)
+  %11 : int = prim::Constant[value=1]()
+  %12 : Dynamic = aten::add(%10, %d, %11)
+  return (%12);
+}
diff --git a/test/test_jit.py b/test/test_jit.py
index e89174cb8d6c..6b021c0f28c8 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1961,6 +1961,19 @@ def loop_use_test(y):
         self.assertExpected(while_if_test.graph.pretty_print(), "while_if_test")
         self.assertExpected(loop_use_test.graph.pretty_print(), "loop_use_test")
 
+    def test_default_values(self):
+        outer_var = 20
+        outer_var2 = 30
+
+        @torch.jit.script
+        def fn(x, a=0.5, b=10, c=outer_var, d=outer_var + outer_var2):
+            return x + a + b + c + d
+
+        self.assertExpectedGraph(fn.graph)
+        self.assertEqual(
+            fn(torch.ones(1)),
+            torch.ones(1) + 0.5 + 10 + 20 + (20 + 30))
+
 
 class TestBatched(TestCase):
     # generate random examples and create an batchtensor with them
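For context before the C++ changes, here is a minimal usage sketch of what the new test exercises, assuming this patch is applied (`scale` is an illustrative stand-in for the test's `outer_var`). Defaults are evaluated once, at compilation time, and recorded on the graph inputs; trailing arguments can then be omitted at call time.

    import torch

    scale = 2  # captured when f is compiled, like outer_var in the test

    @torch.jit.script
    def f(x, a=0.5, b=scale):
        return x + a + b

    print(f(torch.ones(1)))                  # uses a=0.5 and b=2
    print(f(torch.ones(1), torch.zeros(1)))  # should override a positionally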
diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp
index 82b14fa0b683..fc3188539978 100644
--- a/torch/csrc/jit/ir.cpp
+++ b/torch/csrc/jit/ir.cpp
@@ -43,6 +43,7 @@ std::ostream& operator<<(std::ostream & out, const at::ArrayRef<const Value*> & nodes) {
   return out;
 }
 
+
 struct const_value_list_with_types {
   const ArrayRef<const Value*> values;
   bool use_newlines;
@@ -68,6 +69,15 @@ std::ostream& operator<<(std::ostream & out, const_value_list_with_types l) {
     printValueRef(out, n);
     out << " : ";
     out << *n->type();
+
+    // Print default value if one exists
+    const ParamValue* pv = dynamic_cast<const ParamValue*>(n);
+    if (pv != nullptr) {
+      if (pv->default_value()) {
+        out << " = " << *pv->default_value();
+      }
+    }
+
   }
   return out;
 }
diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h
index 062d0422c2be..2b62c2b2e7fc 100644
--- a/torch/csrc/jit/ir.h
+++ b/torch/csrc/jit/ir.h
@@ -171,6 +171,7 @@ using NodeKind = Symbol;
 struct Value {
   TH_DISALLOW_COPY_AND_ASSIGN(Value);
   Value(Node * node_, size_t offset_);
+  virtual ~Value() {}
 private:
   friend struct Node;
   friend struct Graph;
@@ -256,6 +257,18 @@ struct Value {
 };
 
+struct ParamValue : Value {
+  TH_DISALLOW_COPY_AND_ASSIGN(ParamValue);
+  ParamValue(Node * node_, size_t offset_, at::optional<IValue> default_value);
+
+  const at::optional<IValue>& default_value() const {
+    return default_value_;
+  }
+
+private:
+  at::optional<IValue> default_value_;
+};
+
 struct Node : public Attributes<Node> {
   TH_DISALLOW_COPY_AND_ASSIGN(Node);
   friend struct Graph;
@@ -510,6 +523,13 @@ struct Node : public Attributes<Node> {
     return outputs_.back();
   }
 
+  ParamValue* addParamOutput(at::optional<IValue> default_value) {
+    ParamValue* val = new ParamValue(this, outputs_.size(), default_value);
+    outputs_.push_back(val);
+    schema_ = nullptr;
+    return val;
+  }
+
   Value* insertOutput(size_t i) {
     schema_ = nullptr;
     outputs_.insert(outputs_.begin() + i, new Value(this, i));
@@ -796,6 +816,13 @@ struct Block {
     v->setUniqueName(name);
     return v;
   }
+  ParamValue* addParamInput(
+      at::optional<IValue> default_value,
+      std::string name = "") {
+    ParamValue* v = input_->addParamOutput(default_value);
+    v->setUniqueName(name);
+    return v;
+  }
   Value* insertInput(size_t i, std::string name = "") {
     Value* v = input_->insertOutput(i);
     v->setUniqueName(name);
@@ -949,6 +976,11 @@ friend struct Block;
   Value * addInput(std::string name="") {
     return block_->addInput(std::move(name));
   }
+  ParamValue* addParamInput(
+      at::optional<IValue> default_value,
+      std::string name = "") {
+    return block_->addParamInput(default_value, std::move(name));
+  }
   Value* insertInput(size_t i, std::string name = "") {
     return block_->insertInput(i, std::move(name));
   }
@@ -1268,6 +1300,12 @@ inline void Value::replaceAllUsesWith(Value * newValue) {
   }
 }
 
+inline ParamValue::ParamValue(
+    Node* node_,
+    size_t offset_,
+    at::optional<IValue> default_value_)
+    : Value(node_, offset_), default_value_(default_value_) {}
+
 inline Node::Node(Graph * graph_, NodeKind kind_) :
   kind_(kind_),
   graph_(graph_),
diff --git a/torch/csrc/jit/passes/canonicalize.cpp b/torch/csrc/jit/passes/canonicalize.cpp
index e5dda3ec4c58..e91bcc337ea0 100644
--- a/torch/csrc/jit/passes/canonicalize.cpp
+++ b/torch/csrc/jit/passes/canonicalize.cpp
@@ -9,7 +9,13 @@ std::shared_ptr<Graph> Canonicalize(const std::shared_ptr<Graph>& graph) {
   std::unordered_map<Value*, Value*> rn_env;
   auto rn_fn = [&](Value* v) { return rn_env.at(v); };
   for (auto* input : graph->inputs()) {
-    auto* r_input = r->addInput();
+    auto as_param_value = dynamic_cast<ParamValue*>(input);
+    Value* r_input = nullptr;
+    if (as_param_value != nullptr) {
+      r_input = r->addParamInput(as_param_value->default_value());
+    } else {
+      r_input = r->addInput();
+    }
     r_input->copyMetadata(input);
     r_input->setStage(input->stage());
     rn_env[input] = r_input;
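A quick way to observe the printer change above (a sketch, assuming the patch is applied): graph inputs created from defaulted parameters are ParamValues, so pretty-printing appends the default after the type, matching the expect file.

    import torch

    @torch.jit.script
    def f(x, a=0.5):
        return x + a

    # Expected to show something like:
    #   %a : Dynamic = 0.5
    #   [ Variable[CPUDoubleType]{} ]
    print(f.graph)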
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp
index b66b96dd5eb6..b78da7529e4e 100644
--- a/torch/csrc/jit/script/compiler.cpp
+++ b/torch/csrc/jit/script/compiler.cpp
@@ -784,12 +784,18 @@
     size_t arg_annotation_idx = 0;
     for(;it != end; ++it) {
       auto& name = (*it).ident().name();
+      auto& argument = schema.arguments.at(arg_annotation_idx++);
       // Add the input to the graph
-      Value *new_input = graph->addInput(name);
+      Value *new_input = nullptr;
+      if (argument.default_value) {
+        new_input = graph->addParamInput(*argument.default_value, name);
+      } else {
+        new_input = graph->addInput(name);
+      }
       environment_stack->setVar((*it).ident().range(), name, new_input);
 
       // Record the type for the schema and set the Type on the Value*
-      arguments.push_back(schema.arguments.at(arg_annotation_idx++));
+      arguments.push_back(argument);
       new_input->setType(arguments.back().type);
       ensureLegalType((*it).ident().range(), arguments.back().type);
     }
@@ -1931,9 +1937,24 @@ std::vector<Argument> parseArgsFromDecl(Decl decl, bool is_method) {
   size_t i = is_method ? 1 : 0;
   for (; i < decl.params().size(); ++i) {
     auto decl_arg = decl.params()[i];
-    auto arg = Argument(decl_arg.ident().name(), parseTypeFromExpr(decl_arg.type()),
-                        /*N =*/at::nullopt, /*default_value =*/at::nullopt,
-                        /*kwarg_only =*/false);
+    at::optional<IValue> default_value = at::nullopt;
+    if (decl_arg.has_default()) {
+      // If the Param has a default value, convert it to an IValue
+      Const default_const = decl_arg.default_value();
+      if (default_const.isFloatingPoint()) {
+        default_value = IValue(autograd::make_variable(
+            at::scalar_to_tensor(default_const.asFloatingPoint())));
+      } else {
+        default_value = IValue(autograd::make_variable(
+            at::scalar_to_tensor(default_const.asIntegral())));
+      }
+    }
+    auto arg = Argument(
+        decl_arg.ident().name(),
+        parseTypeFromExpr(decl_arg.type()),
+        /*N =*/at::nullopt,
+        /*default_value =*/default_value,
+        /*kwarg_only =*/false);
     retval.push_back(arg);
   }
   return retval;
diff --git a/torch/csrc/jit/script/compiler.h b/torch/csrc/jit/script/compiler.h
index deef6a5c2ca8..9d59ec3d6c14 100644
--- a/torch/csrc/jit/script/compiler.h
+++ b/torch/csrc/jit/script/compiler.h
@@ -157,7 +157,6 @@ TORCH_API std::shared_ptr<Graph> compileFunction(Def def, const Resolver& resolver);
 TORCH_API Value* packOutputs(Graph& g, at::ArrayRef<Value*> values);
 TORCH_API std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs);
 TORCH_API void ensureSizeMatches(SourceRange loc, size_t expected, size_t actual, const std::string& what);
-TORCH_API void ensureTensors(const SourceRange& range, at::ArrayRef<Value*> values);
 
 // try to match a list if inputs and keyword 'attributes' to this schema,
 // if it works return the flat list of positional inputs to the call
diff --git a/torch/csrc/jit/script/python_tree_views.cpp b/torch/csrc/jit/script/python_tree_views.cpp
index be67a262a3fa..15fb49037712 100644
--- a/torch/csrc/jit/script/python_tree_views.cpp
+++ b/torch/csrc/jit/script/python_tree_views.cpp
@@ -86,6 +86,9 @@ void initTreeViewBindings(PyObject *module) {
   py::class_<Param, TreeView>(m, "Param")
     .def(py::init([](const Expr& type, const Ident& name) {
       return Param::create(name.range(), name, type);
+    }))
+    .def(py::init([](const Expr& type, const Ident& name, const Const& default_value) {
+      return Param::create(name.range(), name, type, default_value);
     }));
   py::class_<Attribute, TreeView>(m, "Attribute")
     .def(py::init([](const Ident& name, const Expr& value) {
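On the conversion in parseArgsFromDecl: a floating-point default goes through asFloatingPoint() and at::scalar_to_tensor, producing a 0-dim double tensor, while an integral default produces a 0-dim long tensor; that is why the expect file shows CPUDoubleType for 0.5 and CPULongType for the integer defaults. A rough Python analogue, for illustration only:

    import torch

    float_default = torch.tensor(0.5, dtype=torch.float64)  # like asFloatingPoint()
    int_default = torch.tensor(10, dtype=torch.int64)       # like asIntegral()

    print(float_default.dim(), float_default.dtype)  # 0 torch.float64
    print(int_default.dim(), int_default.dtype)      # 0 torch.int64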
diff --git a/torch/csrc/jit/script/tree_views.h b/torch/csrc/jit/script/tree_views.h
index 162c33e68386..71f5f1596d2b 100644
--- a/torch/csrc/jit/script/tree_views.h
+++ b/torch/csrc/jit/script/tree_views.h
@@ -257,87 +257,6 @@ struct Expr : public TreeView {
   }
 };
 
-////////////////////////////////////////////////////////////////////////////////
-// Helper nodes (mostly for function arguments)
-////////////////////////////////////////////////////////////////////////////////
-
-struct Attribute : public TreeView {
-  explicit Attribute(const TreeRef& tree) : TreeView(tree) {
-    tree_->match(TK_ATTRIBUTE);
-  }
-  Ident name() const {
-    return Ident(subtree(0));
-  }
-  Expr value() const {
-    return Expr(subtree(1));
-  }
-  static Attribute create(const SourceRange& range, const Ident& name, const TreeRef& value) {
-    return Attribute(Compound::create(TK_ATTRIBUTE, range, {name, value}));
-  }
-};
-
-
-struct Param : public TreeView {
-  explicit Param(const TreeRef& tree) : TreeView(tree) {
-    tree_->match(TK_PARAM);
-  }
-  static Param create(const SourceRange& range, const Ident& ident, const Expr& type) {
-    return Param(Compound::create(TK_PARAM, range, {ident, type}));
-  }
-  Ident ident() const {
-    return Ident(subtree(0));
-  }
-  Expr type() const {
-    return Expr(subtree(1));
-  }
-  template<typename T>
-  T typeExpect() const {
-    return T(type());
-  }
-};
-
-////////////////////////////////////////////////////////////////////////////////
-// Top level definitions
-////////////////////////////////////////////////////////////////////////////////
-
-struct Decl : public TreeView {
-  explicit Decl(const TreeRef& tree) : TreeView(tree) {
-    tree->match(TK_DECL);
-  }
-  List<Param> params() const {
-    return List<Param>(subtree(0));
-  }
-  Maybe<Expr> return_type() const {
-    return Maybe<Expr>(subtree(1));
-  }
-  static Decl create(const SourceRange& range, const List<Param>& params, Maybe<Expr> return_type) {
-    return Decl(Compound::create(TK_DECL, range, {params, return_type}));
-  }
-};
-
-struct Def : public TreeView {
-  explicit Def(const TreeRef& tree) : TreeView(tree) {
-    tree->match(TK_DEF);
-  }
-  Ident name() const {
-    return Ident(subtree(0));
-  }
-  Decl decl() const {
-    return Decl(subtree(1));
-  }
-  List<Stmt> statements() const {
-    return List<Stmt>(subtree(2));
-  }
-  static Def create(
-      const SourceRange& range,
-      const Ident& name,
-      const Decl& decl,
-      const List<Stmt>& stmts) {
-    return Def(Compound::create(
-        TK_DEF, range, {name, decl, stmts}));
-  }
-};
-
 ////////////////////////////////////////////////////////////////////////////////
 // Statements
@@ -558,6 +477,105 @@ struct Const : public Expr {
   }
 };
 
+////////////////////////////////////////////////////////////////////////////////
+// Helper nodes (mostly for function arguments)
+////////////////////////////////////////////////////////////////////////////////
+
+struct Attribute : public TreeView {
+  explicit Attribute(const TreeRef& tree) : TreeView(tree) {
+    tree_->match(TK_ATTRIBUTE);
+  }
+  Ident name() const {
+    return Ident(subtree(0));
+  }
+  Expr value() const {
+    return Expr(subtree(1));
+  }
+  static Attribute create(const SourceRange& range, const Ident& name, const TreeRef& value) {
+    return Attribute(Compound::create(TK_ATTRIBUTE, range, {name, value}));
+  }
+};
+
+
+struct Param : public TreeView {
+  explicit Param(const TreeRef& tree) : TreeView(tree) {
+    tree_->match(TK_PARAM);
+  }
+  static Param create(
+      const SourceRange& range,
+      const Ident& ident,
+      const Expr& type) {
+    return Param(
+        Compound::create(TK_PARAM, range, {ident, type}));
+  }
+  static Param create(
+      const SourceRange& range,
+      const Ident& ident,
+      const Expr& type,
+      const Const& default_value) {
+    return Param(
+        Compound::create(TK_PARAM, range, {ident, type, default_value}));
+  }
+  Ident ident() const {
+    return Ident(subtree(0));
+  }
+  Expr type() const {
+    return Expr(subtree(1));
+  }
+  Const default_value() const {
+    return Const(subtree(2));
+  }
+  bool has_default() const {
+    return tree()->trees().size() == 3;
+  }
+  template<typename T>
+  T typeExpect() const {
+    return T(type());
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+// Top level definitions
+////////////////////////////////////////////////////////////////////////////////
+
+struct Decl : public TreeView {
+  explicit Decl(const TreeRef& tree) : TreeView(tree) {
+    tree->match(TK_DECL);
+  }
+  List<Param> params() const {
+    return List<Param>(subtree(0));
+  }
+  Maybe<Expr> return_type() const {
+    return Maybe<Expr>(subtree(1));
+  }
+  static Decl create(const SourceRange& range, const List<Param>& params, Maybe<Expr> return_type) {
+    return Decl(Compound::create(TK_DECL, range, {params, return_type}));
+  }
+};
+
+struct Def : public TreeView {
+  explicit Def(const TreeRef& tree) : TreeView(tree) {
+    tree->match(TK_DEF);
+  }
+  Ident name() const {
+    return Ident(subtree(0));
+  }
+  Decl decl() const {
+    return Decl(subtree(1));
+  }
+  List<Stmt> statements() const {
+    return List<Stmt>(subtree(2));
+  }
+  static Def create(
+      const SourceRange& range,
+      const Ident& name,
+      const Decl& decl,
+      const List<Stmt>& stmts) {
+    return Def(Compound::create(
+        TK_DEF, range, {name, decl, stmts}));
+  }
+};
+
 struct StringLiteral : public Expr {
   explicit StringLiteral(const TreeRef& tree) : Expr(tree) {
     tree_->matchNumSubtrees(TK_STRINGLITERAL, 1);
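Param now carries an optional third subtree, and has_default() keys off the child count. A small Python sketch of this "optional trailing child" pattern (ParamNode is a hypothetical stand-in, not the real tree type):

    class ParamNode(object):
        # children are [ident, type], plus the default when one was given
        def __init__(self, ident, type_, default=None):
            self.children = [ident, type_]
            if default is not None:
                self.children.append(default)

        def has_default(self):
            return len(self.children) == 3

        def default_value(self):
            # like Param::default_value(): only valid when has_default()
            return self.children[2]

    p = ParamNode('a', 'Tensor', '0.5')
    print(p.has_default(), p.default_value())  # True 0.5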
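The frames_up plumbing above exists so the frontend can capture the caller's frame and evaluate default expressions in it at compilation time. One observable consequence (a sketch, assuming names in the enclosing scope resolve through the captured frame):

    import torch

    k = 3

    @torch.jit.script
    def f(x, c=k + 1):
        return x + c

    k = 100  # no effect: the default was evaluated when f was compiled
    print(f(torch.zeros(1)))  # expected: tensor([ 4.])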
diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py
index 99d767797e1b..f444854ebe24 100644
--- a/torch/jit/frontend.py
+++ b/torch/jit/frontend.py
@@ -137,22 +137,23 @@ def _uses_true_division(fn):
                 '_uses_true_division: expected function or method, got {}'.format(type(fn)))
 
 
-def get_jit_ast(fn, is_method):
+def get_jit_ast(fn, is_method, frames_up=0):
     source = dedent(inspect.getsource(fn))
     py_ast = ast.parse(source)
     if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
         raise RuntimeError("expected a single top-level function")
     type_line = torch.jit.annotations.get_type_line(source)
-    ctx = SourceContext(source, _uses_true_division(fn))
+    ctx = SourceContext(source, inspect.stack()[frames_up + 1][0], _uses_true_division(fn))
     return build_def(ctx, py_ast.body[0], type_line, is_method)
 
 
 # Thin wrapper around SourceRangeFactory to store extra metadata
 # about the function-to-be-compiled.
 class SourceContext(SourceRangeFactory):
-    def __init__(self, source, uses_true_division=True):
+    def __init__(self, source, frame, uses_true_division=True):
         super(SourceContext, self).__init__(source)
         self.uses_true_division = uses_true_division
+        self.frame = frame
 
 
 class Builder(object):
@@ -182,19 +183,36 @@ def build_def(ctx, py_def, type_line, is_method):
                       build_stmts(ctx, body))
 
 
-_vararg_kwarg_err = ("Compiled functions can't take variable number of arguments, "
-                     "have default values for arguments, nor keyword-only arguments")
+_vararg_kwarg_err = ("Compiled functions can't take variable number of"
+                     " arguments or have keyword-only arguments")
 
 
 def build_param_list(ctx, py_args):
-    if py_args.vararg is not None or py_args.kwarg is not None or py_args.defaults:
+    num_no_default = len(py_args.args) - len(py_args.defaults)
+
+    def get_default_at(i):
+        return py_args.defaults[i - num_no_default] if i >= num_no_default else None
+
+    if py_args.vararg is not None or py_args.kwarg is not None:
         raise ValueError(_vararg_kwarg_err)
-    if not PY2 and (py_args.kw_defaults or py_args.kwonlyargs):
+    if not PY2 and py_args.kwonlyargs:
         raise ValueError(_vararg_kwarg_err)
-    return [build_param(ctx, arg) for arg in py_args.args]
+    return [build_param(ctx, arg, get_default_at(i)) for i, arg in enumerate(py_args.args)]
+
+
+def eval_default_arg(ctx, default):
+    if isinstance(default, ast.Num):
+        value = str(default.n)
+    else:
+        expr = compile(ast.Expression(default), '', 'eval')
+        value = str(eval(expr, ctx.frame.f_globals, ctx.frame.f_locals))
+
+    r_default = ctx.make_range(
+        default.lineno, default.col_offset, default.col_offset + len(value))
+    return (r_default, value)
 
 
-def build_param(ctx, py_arg):
+def build_param(ctx, py_arg, default):
     # NB: In Python3 py_arg is a pair of (str arg, expr? annotation)
     #     In Python2 py_arg is a Name (Expr subclass)
     name = py_arg.id if PY2 else py_arg.arg
@@ -203,7 +221,12 @@
         annotation_expr = build_expr(ctx, py_arg.annotation)
     else:
         annotation_expr = Var(Ident(r, 'Tensor'))
-    return Param(annotation_expr, Ident(r, name))
+
+    if default:
+        r_default, value = eval_default_arg(ctx, default)
+        return Param(annotation_expr, Ident(r, name), Const(r_default, value))
+    else:
+        return Param(annotation_expr, Ident(r, name))
 
 
 class StmtBuilder(Builder):
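For reference, the compile-and-eval technique eval_default_arg uses, in isolation (eval_in_caller is a hypothetical helper, not part of the patch, and it parses from a string for the sake of a self-contained demo):

    import ast
    import inspect

    def eval_in_caller(expr_src):
        # Like eval_default_arg: wrap an expression AST in 'eval' mode,
        # compile it, and evaluate it against the caller's frame so names
        # from the enclosing scope resolve.
        frame = inspect.stack()[1][0]
        node = ast.parse(expr_src, mode='eval')
        code = compile(node, '<default-arg>', 'eval')
        return eval(code, frame.f_globals, frame.f_locals)

    outer_var = 20
    print(eval_in_caller('outer_var + 30'))  # 50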