19 changes: 19 additions & 0 deletions test/expect/TestJit.test_default_values.expect
@@ -0,0 +1,19 @@
graph(%x : Dynamic
      %a : Dynamic = 0.5
      [ Variable[CPUDoubleType]{} ]
      %b : Dynamic = 10
      [ Variable[CPULongType]{} ]
      %c : Dynamic = 20
      [ Variable[CPULongType]{} ]
      %d : Dynamic = 50
      [ Variable[CPULongType]{} ]) {
  %5 : int = prim::Constant[value=1]()
  %6 : Dynamic = aten::add(%x, %a, %5)
  %7 : int = prim::Constant[value=1]()
  %8 : Dynamic = aten::add(%6, %b, %7)
  %9 : int = prim::Constant[value=1]()
  %10 : Dynamic = aten::add(%8, %c, %9)
  %11 : int = prim::Constant[value=1]()
  %12 : Dynamic = aten::add(%10, %d, %11)
  return (%12);
}
13 changes: 13 additions & 0 deletions test/test_jit.py
@@ -1961,6 +1961,19 @@ def loop_use_test(y):
        self.assertExpected(while_if_test.graph.pretty_print(), "while_if_test")
        self.assertExpected(loop_use_test.graph.pretty_print(), "loop_use_test")

    def test_default_values(self):
        outer_var = 20
        outer_var2 = 30

        @torch.jit.script
        def fn(x, a=0.5, b=10, c=outer_var, d=outer_var + outer_var2):
            return x + a + b + c + d

        self.assertExpectedGraph(fn.graph)
        self.assertEqual(
            fn(torch.ones(1)),
            torch.ones(1) + 0.5 + 10 + 20 + (20 + 30))
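Worked out, the assertion checks that fn(torch.ones(1)) equals 1 + 0.5 + 10 + 20 + (20 + 30) = 81.5 elementwise: a and b come from literal defaults, while c and d show that a default may reference enclosing Python values, evaluated once when the function is scripted. The resulting graph is the one recorded in the expect file above.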


class TestBatched(TestCase):
    # generate random examples and create a batchtensor with them
10 changes: 10 additions & 0 deletions torch/csrc/jit/ir.cpp
@@ -43,6 +43,7 @@ std::ostream& operator<<(std::ostream & out, const at::ArrayRef<T> & nodes) {
  return out;
}


struct const_value_list_with_types {
  const ArrayRef<const Value*> values;
  bool use_newlines;
@@ -68,6 +69,15 @@ std::ostream& operator<<(std::ostream & out, const_value_list_with_types l) {
    printValueRef(out, n);
    out << " : ";
    out << *n->type();

    // Print default value if one exists
    const ParamValue* pv = dynamic_cast<const ParamValue*>(n);
    if (pv != nullptr) {
      if (pv->default_value()) {
        out << " = " << *pv->default_value();
      }
    }

  }
  return out;
}
38 changes: 38 additions & 0 deletions torch/csrc/jit/ir.h
@@ -171,6 +171,7 @@ using NodeKind = Symbol;
struct Value {
  TH_DISALLOW_COPY_AND_ASSIGN(Value);
  Value(Node * node_, size_t offset_);
  virtual ~Value() {}

private:
  friend struct Node;
  friend struct Graph;
@@ -256,6 +257,18 @@ struct Value {

};

struct ParamValue : Value {
  TH_DISALLOW_COPY_AND_ASSIGN(ParamValue);
  ParamValue(Node * node_, size_t offset_, at::optional<IValue> default_value);

  const at::optional<IValue>& default_value() const {
    return default_value_;
  }

private:
  at::optional<IValue> default_value_;
};

struct Node : public Attributes<Node> {
  TH_DISALLOW_COPY_AND_ASSIGN(Node);
  friend struct Graph;
Expand Down Expand Up @@ -510,6 +523,13 @@ struct Node : public Attributes<Node> {
    return outputs_.back();
  }

  ParamValue* addParamOutput(at::optional<IValue> default_value) {
    ParamValue* val = new ParamValue(this, outputs_.size(), default_value);
    outputs_.push_back(val);
    schema_ = nullptr;
    return val;
  }

  Value* insertOutput(size_t i) {
    schema_ = nullptr;
    outputs_.insert(outputs_.begin() + i, new Value(this, i));
@@ -796,6 +816,13 @@ struct Block {
    v->setUniqueName(name);
    return v;
  }
  ParamValue* addParamInput(
      at::optional<IValue> default_value,
      std::string name = "") {
    ParamValue* v = input_->addParamOutput(default_value);
    v->setUniqueName(name);
    return v;
  }
  Value* insertInput(size_t i, std::string name = "") {
    Value* v = input_->insertOutput(i);
    v->setUniqueName(name);
@@ -949,6 +976,11 @@ friend struct Block;
  Value * addInput(std::string name="") {
    return block_->addInput(std::move(name));
  }
  ParamValue* addParamInput(
      at::optional<IValue> default_value,
      std::string name = "") {
    return block_->addParamInput(default_value, std::move(name));
  }
  Value* insertInput(size_t i, std::string name = "") {
    return block_->insertInput(i, std::move(name));
  }
@@ -1268,6 +1300,12 @@ inline void Value::replaceAllUsesWith(Value * newValue) {
  }
}

inline ParamValue::ParamValue(
    Node* node_,
    size_t offset_,
    at::optional<IValue> default_value_)
    : Value(node_, offset_), default_value_(default_value_) {}

inline Node::Node(Graph * graph_, NodeKind kind_) :
  kind_(kind_),
  graph_(graph_),
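The design above attaches the default to a Value subclass and recovers it later through RTTI instead of widening Value itself (which is also why the virtual destructor had to be added: dynamic_cast needs a polymorphic base). A minimal standalone sketch of that pattern, with simplified stand-in types for illustration, not the PR's actual classes; the real ones carry at::optional<IValue> rather than a double:

#include <iostream>
#include <memory>
#include <optional>
#include <vector>

// Stand-ins for the PR's Value/ParamValue from ir.h, reduced to the pattern.
struct Value {
  virtual ~Value() = default;  // polymorphic base, so dynamic_cast works
};

struct ParamValue : Value {
  explicit ParamValue(std::optional<double> def) : default_value_(def) {}
  const std::optional<double>& default_value() const { return default_value_; }
 private:
  std::optional<double> default_value_;
};

int main() {
  std::vector<std::unique_ptr<Value>> inputs;
  inputs.push_back(std::make_unique<Value>());          // plain input, no default
  inputs.push_back(std::make_unique<ParamValue>(0.5));  // input with a default

  for (const auto& in : inputs) {
    std::cout << "%input";
    // The same dispatch the printer and canonicalizer use in this PR:
    // downcast, then print a default only if one is present.
    if (auto* pv = dynamic_cast<const ParamValue*>(in.get())) {
      if (pv->default_value()) {
        std::cout << " = " << *pv->default_value();
      }
    }
    std::cout << "\n";
  }
  return 0;
}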
8 changes: 7 additions & 1 deletion torch/csrc/jit/passes/canonicalize.cpp
@@ -9,7 +9,13 @@ std::shared_ptr<Graph> Canonicalize(const std::shared_ptr<Graph>& graph) {
  std::unordered_map<Value*, Value*> rn_env;
  auto rn_fn = [&](Value* v) { return rn_env.at(v); };
  for (auto* input : graph->inputs()) {
    auto* r_input = r->addInput();
    auto as_param_value = dynamic_cast<const ParamValue*>(input);
    Value* r_input = nullptr;
    if (as_param_value != nullptr) {
      r_input = r->addParamInput(as_param_value->default_value());
    } else {
      r_input = r->addInput();
    }
    r_input->copyMetadata(input);
    r_input->setStage(input->stage());
    rn_env[input] = r_input;
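A consequence of the subclass approach shows up here: every place that recreates graph inputs (the printer in ir.cpp, this pass, and the compiler below) must repeat the dynamic_cast dispatch, or defaults are silently dropped. One way that could be factored out, sketched against the ir.h declarations above; cloneInputInto is a hypothetical helper, not part of this PR:

// Hypothetical helper (not in this PR): copy one graph input into another
// graph, preserving a ParamValue's default when present. Assumes the
// Graph/Value/ParamValue declarations from torch/csrc/jit/ir.h above.
Value* cloneInputInto(Value* input, Graph& dst) {
  Value* r_input = nullptr;
  if (auto* pv = dynamic_cast<const ParamValue*>(input)) {
    r_input = dst.addParamInput(pv->default_value());
  } else {
    r_input = dst.addInput();
  }
  r_input->copyMetadata(input);
  r_input->setStage(input->stage());
  return r_input;
}

With such a helper, the loop body above would reduce to rn_env[input] = cloneInputInto(input, *r);.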
31 changes: 26 additions & 5 deletions torch/csrc/jit/script/compiler.cpp
@@ -784,12 +784,18 @@ struct to_ir {
    size_t arg_annotation_idx = 0;
    for(;it != end; ++it) {
      auto& name = (*it).ident().name();
      auto& argument = schema.arguments.at(arg_annotation_idx++);
      // Add the input to the graph
      Value *new_input = graph->addInput(name);
      Value *new_input = nullptr;
      if (argument.default_value) {
        new_input = graph->addParamInput(*argument.default_value, name);
      } else {
        new_input = graph->addInput(name);
      }
      environment_stack->setVar((*it).ident().range(), name, new_input);

      // Record the type for the schema and set the Type on the Value*
      arguments.push_back(schema.arguments.at(arg_annotation_idx++));
      arguments.push_back(argument);
      new_input->setType(arguments.back().type);
      ensureLegalType((*it).ident().range(), arguments.back().type);
    }
@@ -1931,9 +1937,24 @@ std::vector<Argument> parseArgsFromDecl(Decl decl, bool is_method) {
  size_t i = is_method ? 1 : 0;
  for (; i < decl.params().size(); ++i) {
    auto decl_arg = decl.params()[i];
    auto arg = Argument(decl_arg.ident().name(), parseTypeFromExpr(decl_arg.type()),
                        /*N =*/at::nullopt, /*default_value =*/at::nullopt,
                        /*kwarg_only =*/false);
    at::optional<IValue> default_value = at::nullopt;
    if (decl_arg.has_default()) {
      // If the Param has a default value, convert it to an IValue
      Const default_const = decl_arg.default_value();

      if (default_const.isFloatingPoint()) {
        default_value = IValue(autograd::make_variable(
            at::scalar_to_tensor(default_const.asFloatingPoint())));
      } else {
        default_value = IValue(autograd::make_variable(
            at::scalar_to_tensor(default_const.asIntegral())));
      }
    }
    auto arg = Argument(
        decl_arg.ident().name(),
        parseTypeFromExpr(decl_arg.type()),
        /*N =*/at::nullopt,
        /*default_value =*/default_value,
        /*kwarg_only =*/false);
    retval.push_back(arg);
  }
  return retval;
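The Const-to-IValue conversion above is what produces the tensor types in the expect file: each default is boxed as a zero-dimensional tensor, double for floating-point literals and long for integral ones. A sketch of that boxing with plain ATen calls, assuming a libtorch/ATen build from this era (at::scalar_to_tensor is the same helper the PR calls):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Floating-point default 0.5 becomes a 0-dim double tensor, matching
  // "%a : Dynamic = 0.5 / [ Variable[CPUDoubleType]{} ]" in the expect file.
  at::Tensor a = at::scalar_to_tensor(0.5);
  // Integral default 10 becomes a 0-dim long tensor, matching %b.
  at::Tensor b = at::scalar_to_tensor(static_cast<int64_t>(10));
  std::cout << a << "\n" << b << "\n";
  return 0;
}

The PR additionally wraps each tensor in autograd::make_variable, which is why the expect file prints them as Variable[...].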
1 change: 0 additions & 1 deletion torch/csrc/jit/script/compiler.h
@@ -157,7 +157,6 @@ TORCH_API std::shared_ptr<Graph> compileFunction(Def def, const Resolver& resolv
TORCH_API Value* packOutputs(Graph& g, at::ArrayRef<Value*> values);
TORCH_API std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs);
TORCH_API void ensureSizeMatches(SourceRange loc, size_t expected, size_t actual, const std::string& what);
TORCH_API void ensureTensors(const SourceRange& range, at::ArrayRef<Value*> values);

// try to match a list of inputs and keyword 'attributes' to this schema,
// if it works return the flat list of positional inputs to the call
3 changes: 3 additions & 0 deletions torch/csrc/jit/script/python_tree_views.cpp
@@ -86,6 +86,9 @@ void initTreeViewBindings(PyObject *module) {
  py::class_<Param, TreeView>(m, "Param")
      .def(py::init([](const Expr& type, const Ident& name) {
        return Param::create(name.range(), name, type);
      }))
      .def(py::init([](const Expr& type, const Ident& name, const Const& default_value) {
        return Param::create(name.range(), name, type, default_value);
      }));
  py::class_<Attribute, TreeView>(m, "Attribute")
      .def(py::init([](const Ident& name, const Expr& value) {