diff --git a/test/expect/TestJit.test_batchnorm.expect b/test/expect/TestJit.test_batchnorm.expect
index a0325e85167b67e..d1377cd85ca8b0b 100644
--- a/test/expect/TestJit.test_batchnorm.expect
+++ b/test/expect/TestJit.test_batchnorm.expect
@@ -3,6 +3,6 @@ graph(%1 : Double(2, 2)
       %3 : Double(2)
       %4 : Double(2)
       %5 : Double(2)) {
-  %7 : Double(2, 2), %8 : Handle = CppOp[N5torch8autograd16BatchNormForwardE](%1, %2, %3), uses = [[%0.i0], []];
+  %7 : Double(2, 2), %8 : Handle = CppOp[N5torch8autograd16BatchNormForwardE](%1, %2, %3), uses = [[%0.i0], []], scope: BatchNorm2d;
   return (%7);
 }
diff --git a/test/expect/TestJit.test_conv.expect b/test/expect/TestJit.test_conv.expect
index a8cfc1daea6492b..ca2eb44443acffe 100644
--- a/test/expect/TestJit.test_conv.expect
+++ b/test/expect/TestJit.test_conv.expect
@@ -1,6 +1,6 @@
 graph(%1 : Double(20, 16, 50, 40)
       %2 : Double(13, 16, 3, 3)) {
-  %4 : UNKNOWN_TYPE = Undefined(), uses = [%3.i2];
-  %5 : Double(20, 13, 48, 38), %6 : Handle = CppOp[ConvForward](%1, %2, %4), uses = [[%0.i0], []];
+  %4 : UNKNOWN_TYPE = Undefined(), uses = [%3.i2], scope: Conv2d;
+  %5 : Double(20, 13, 48, 38), %6 : Handle = CppOp[ConvForward](%1, %2, %4), uses = [[%0.i0], []], scope: Conv2d;
   return (%5);
 }
diff --git a/test/expect/TestJit.test_dropout.expect b/test/expect/TestJit.test_dropout.expect
index e1c1c4921c94aa0..efab5a4560063aa 100644
--- a/test/expect/TestJit.test_dropout.expect
+++ b/test/expect/TestJit.test_dropout.expect
@@ -1,4 +1,4 @@
 graph(%1 : Double(2, 2)) {
-  %3 : Double(2, 2), %4 : Handle = ^Dropout(0.6, True, False)(%1), uses = [[%0.i0], []];
+  %3 : Double(2, 2), %4 : Handle = ^Dropout(0.6, True, False)(%1), uses = [[%0.i0], []], scope: Dropout;
   return (%3);
 }
diff --git a/test/expect/TestJit.test_scopes.expect b/test/expect/TestJit.test_scopes.expect
new file mode 100644
index 000000000000000..24e433218cac75e
--- /dev/null
+++ b/test/expect/TestJit.test_scopes.expect
@@ -0,0 +1,8 @@
+graph(%1 : Double(1)
+      %2 : Double(1)) {
+  %3 : Double(1) = add[alpha={1}](%1, %2), uses = [%4.i1];
+  %4 : Double(1) = mul(%1, %3), uses = [%5.i0], scope: Foo;
+  %5 : Double(1) = tanh(%4), uses = [%6.i0], scope: Foo/Bar;
+  %6 : Double(1) = sigmoid(%5), uses = [%0.i0], scope: Foo;
+  return (%6);
+}
diff --git a/test/expect/TestJit.test_scopes_identity_node.expect b/test/expect/TestJit.test_scopes_identity_node.expect
new file mode 100644
index 000000000000000..6434d0d603634f9
--- /dev/null
+++ b/test/expect/TestJit.test_scopes_identity_node.expect
@@ -0,0 +1,9 @@
+graph(%1 : Double(1, 3, 227, 227)
+      %2 : Double(64, 3, 11, 11)
+      %3 : Double(64)) {
+  %5 : UNKNOWN_TYPE = Conv[kernel_shape=[11, 11], strides=[4, 4], pads=[2, 2, 2, 2], dilations=[1, 1], group=1](%1, %2), uses = [[%6.i0]], scope: Net/Sequential[features]/Conv2d[0];
+  %6 : Double(1, 64, 56, 56) = Add[broadcast=1, axis=1](%5, %3), uses = [%7.i0], scope: Net/Sequential[features]/Conv2d[0];
+  %7 : Double(1, 64, 56, 56) = Relu(%6), uses = [%8.i0], scope: Net/Sequential[features]/ReLU[1];
+  %8 : Double(1, 64, 27, 27) = MaxPool[kernel_shape=[3, 3], pads=[0, 0], strides=[2, 2]](%7), uses = [%0.i0], scope: Net/Sequential[features]/MaxPool2d[2];
+  return (%8);
+}
diff --git a/test/expect/TestJit.test_scopes_intermediate_node.expect b/test/expect/TestJit.test_scopes_intermediate_node.expect
new file mode 100644
index 000000000000000..99c86c769ccb70e
--- /dev/null
+++ b/test/expect/TestJit.test_scopes_intermediate_node.expect
@@ -0,0 +1,5 @@
+graph(%1 : Double(2)) {
+  %2 : Double(2) = Softmax[axis=0](%1), uses = [%3.i0], scope: Net;
+  %3 : Double(2) = Log(%2), uses = [%0.i0], scope: Net;
+  return (%3);
+}
diff --git a/test/test_jit.py b/test/test_jit.py
index b2a8d6335176cdf..c57d223e3cff3e8 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -50,6 +50,12 @@ def LSTMCellC(*args, **kwargs):
 class TestJit(TestCase):
     maxDiff = None
 
+    def assertExpectedTrace(self, trace, *args, **kwargs):
+        torch._C._jit_pass_lint(trace)
+        torch._C._jit_pass_dce(trace)
+        torch._C._jit_pass_lint(trace)
+        self.assertExpected(str(trace), *args, **kwargs)
+
     def test_simple(self):
         x = Variable(torch.Tensor([0.4]), requires_grad=True)
         y = Variable(torch.Tensor([0.7]), requires_grad=True)
@@ -61,6 +67,63 @@ def f(x, y):
         torch._C._jit_pass_lint(trace)
         self.assertExpected(str(trace))
 
+    def test_scopes(self):
+        x = Variable(torch.Tensor([0.4]), requires_grad=True)
+        y = Variable(torch.Tensor([0.7]), requires_grad=True)
+
+        def f(x, y):
+            out = x + y
+            with torch.jit.scope('Foo', out):
+                out = x * out
+                with torch.jit.scope('Bar', out):
+                    out = torch.tanh(out)
+                out = torch.sigmoid(out)
+            return out
+
+        trace, z = torch.jit.trace(f, (x, y), nderivs=0)
+        torch._C._jit_pass_lint(trace)
+        self.assertExpected(str(trace))
+
+    def test_scopes_intermediate_node(self):
+
+        class Net(nn.Module):
+            def forward(self, x):
+                return F.log_softmax(x, dim=0)
+
+        net = Net()
+        t = Variable(torch.ones(2), requires_grad=True)
+        trace, _ = torch.jit.trace(net, (t, ))
+        torch.onnx._optimize_trace(trace)
+
+        self.assertExpectedTrace(trace)
+
+    def test_scopes_identity_node(self):
+
+        class Net(nn.Module):
+
+            def __init__(self):
+                super(Net, self).__init__()
+                self.features = nn.Sequential(
+                    nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
+                    nn.ReLU(inplace=True),
+                    nn.MaxPool2d(kernel_size=3, stride=2),
+                )
+
+            def forward(self, x):
+                x = self.features(x)
+                return x
+
+        model = Net()
+
+        t = Variable(torch.ones(1, 3, 227, 227), requires_grad=True)
+
+        with torch.onnx.set_training(model, False):
+            trace, _ = torch.jit.trace(model, (t, ))
+
+        torch.onnx._optimize_trace(trace)
+
+        self.assertExpectedTrace(trace)
+
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_lstm_fusion(self):
         input = Variable(torch.randn(3, 10).cuda())
diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp
index 5a7e269a86ba1c3..de73a693341ff9b 100644
--- a/torch/csrc/jit/ir.cpp
+++ b/torch/csrc/jit/ir.cpp
@@ -263,7 +263,15 @@ std::ostream& printNode(std::ostream & out, const Node * n, std::vector<const Node *> * groups) {
+  std::string scopeName = n->scopeName();
+  if (scopeName.empty()) {
+    out << ";\n";
+  }
+  else {
+    out << ", ";
+    out << "scope: " << scopeName << ";\n";
+  }
   return out;
 }
diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h
index 6940ed77206c4c2..9bf0ae44d8d04a9 100644
--- a/torch/csrc/jit/ir.h
+++ b/torch/csrc/jit/ir.h
@@ -69,6 +69,64 @@ struct SourceLocation {
   std::string python_traceback;
 };
 
+// Scope is a node of a trie that represents the tree of nested scopes.
+// Individual scopes are pushed and popped from Graph, which holds a
+// pointer to the current scope. Each Node in Graph holds a pointer
+// to the scope that was current when the node was created.
+// The trie never needs to shrink, it only grows until it is disposed
+// of when Graph is deallocated. Hence, pointers to scopes held by nodes
+// will always be valid as long as Graph is alive.
+struct Scope {
+private:
+  Scope* parent_;
+  Symbol name_;
+  std::vector<std::unique_ptr<Scope> > children_;
+public:
+  Scope() {
+    name_ = stringToSymbol("");
+    parent_ = NULL;
+  }
+  Scope(Scope* parent, Symbol name) {
+    name_ = name;
+    parent_ = parent;
+  }
+  Scope* push(Symbol name) {
+    children_.push_back(std::unique_ptr<Scope>(new Scope(this, name)));
+    return children_.back().get();
+  }
+  Scope* parent() {
+    if (parent_ == NULL) {
+      throw std::runtime_error("Cannot get parent from Scope with no parent");
+    }
+    return parent_;
+  }
+  bool isRoot() {
+    return parent_ == NULL;
+  }
+  Scope* getRoot() {
+    Scope* current = this;
+    while (current->parent_) {
+      current = current->parent_;
+    }
+    return current;
+  }
+  Symbol name() {
+    return name_;
+  }
+  std::string namesFromRoot(const std::string& separator="/") {
+    std::string out = std::string(symbolToString(this->name_));
+    if (this->isRoot()) {
+      return out;
+    }
+    Scope* parent = this->parent_;
+    while (!parent->isRoot()) {
+      out = std::string(symbolToString(parent->name_)) + separator + out;
+      parent = parent->parent_;
+    }
+    return out;
+  }
+};
+
 // the list types are intentionally simple, but we type-def
 // them here so if we need to change them, refactoring will be easier
 using node_list = std::vector<Node*>;
@@ -123,6 +181,7 @@ struct Node : public Attributes<Node> {
   size_t stage_ = 0; // 0-forward, 1-backward, 2-double-backward,...
   std::string debug_name_;
   std::shared_ptr<SourceLocation> source_location_;
+  Scope* scope_;
 protected:
   TypePtr type_;
   Node(Graph * graph_, NodeKind kind_); //defined after graph
@@ -188,6 +247,18 @@ struct Node : public Attributes<Node> {
   size_t stage() const {
     return stage_;
   }
+  Scope* scope() {
+    return scope_;
+  }
+  void setScope(Scope* scope) {
+    scope_ = scope;
+  }
+  std::string scopeName() const {
+    if (scope_ == NULL) {
+      return "";
+    }
+    return scope_->namesFromRoot();
+  }
   // NB: This returns an ArrayRef; that means that it will
   // get invalidated if you resize inputs (e.g., using addInput)
   // We can't return a std::vector& because there's no
@@ -528,12 +599,7 @@ struct Node : public Attributes<Node> {
   //
   // NB: This does NOT clone stages. You're expected to set the stage correctly
   // if you are going to preserve it.
-  virtual void cloneFrom(Node * s) {
-    if (s->hasType()) setType(s->type());
-    setDebugName(s->debugName());
-    setSourceLocation(s->getSourceLocation());
-    copyAttributes(*s);
-  }
+  virtual void cloneFrom(Node * s);
 };
 
 struct Graph {
@@ -551,6 +617,9 @@ friend struct Node;
 
   size_t new_node_stage_;
 
+  std::shared_ptr<Scope> scope_root_;
+  Scope * current_scope_;
+
   // holds outputs in a way that can be reflected
   // as a Use object
   // also used as the beginning/end of the circular node list to avoid
@@ -558,11 +627,17 @@ friend struct Node;
   Node * const output_;
 
 public:
-  Graph()
+
+  Graph(std::shared_ptr<Scope> scope_root)
   : next_unique_(0)
   , new_node_stage_(0)
+  , scope_root_(scope_root)
+  , current_scope_(scope_root_.get())
   , output_(initOutput(create(kReturn))) {}
 
+  Graph()
+  : Graph( std::make_shared<Scope>()) {}
+
   at::ArrayRef<Node*> inputs() {
     return inputs_;
   }
@@ -618,6 +693,29 @@ friend struct Node;
   Node * addInput() {
     return addInput(create(kParam));
   }
+  void push_scope(const std::string& scope_name) {
+    current_scope_ = current_scope_->push(stringToSymbol(scope_name));
+  }
+  void pop_scope() {
+    current_scope_ = current_scope_->parent();
+  }
+  Scope * current_scope() {
+    return current_scope_;
+  }
+  void set_current_scope(Scope* scope) {
+    if (scope->getRoot() != scope_root_.get()) {
+      throw std::runtime_error("trying to set a scope as current that does not belong to the Graph's scope trie");
+    }
+    current_scope_ = scope;
+  }
+  ResourceGuard set_current_scope_temporary(Scope* scope) {
+    auto prev_scope = current_scope_;
+    this->set_current_scope(scope);
+    return ResourceGuard([prev_scope, this]() { this->current_scope_ = prev_scope; });
+  }
+  std::shared_ptr<Scope> scope_root() {
+    return scope_root_;
+  }
 
   Node * addInput(Node* n) {
     JIT_ASSERT(n->kind() == kParam);
@@ -694,7 +792,8 @@ friend struct Node;
   }
   Node * createFusionGroup() {
     auto n = create(kFusionGroup);
-    n->g_(kSubgraph,std::make_shared<Graph>());
+    auto subgraph = std::make_shared<Graph>(scope_root_);
+    n->g_(kSubgraph, subgraph);
     return n;
   }
   Node * createPythonOp(THPObjectPtr&& pyobj, const std::string & cconv, bool is_legacy, pyobj_list&& scalar_args);
@@ -764,9 +863,10 @@ inline Node::Node(Graph * graph_, NodeKind kind_) :
   graph_(graph_),
   unique_(graph_->next_unique_++),
   stage_(graph_->new_node_stage_),
+  scope_(graph_->current_scope_)
   , type_(getInitialType(kind_)) {
-  graph_->all_nodes.emplace(this);
-}
+    graph_->all_nodes.emplace(this);
+  }
 
 inline void Node::destroy() {
   JIT_ASSERT(inGraphList());
@@ -788,6 +888,16 @@ inline Node* Node::makeMultireturn() {
   return select;
 }
 
+inline void Node::cloneFrom(Node * s) {
+  if (s->hasType()) setType(s->type());
+  setDebugName(s->debugName());
+  setSourceLocation(s->getSourceLocation());
+  if (s->owningGraph()->scope_root_ == owningGraph()->scope_root_) {
+    scope_ = s->scope_;
+  }
+  copyAttributes(*s);
+}
+
 // Helper macros for constructing switch statements over Node types
 // instead of heavy-weight visitors
 // read 'between' these defines to see how they turn into a big switch
diff --git a/torch/csrc/jit/passes/onnx.cpp b/torch/csrc/jit/passes/onnx.cpp
index 1615942dd5e036c..ab7f79faafc6763 100644
--- a/torch/csrc/jit/passes/onnx.cpp
+++ b/torch/csrc/jit/passes/onnx.cpp
@@ -33,7 +33,7 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state) {
     throw std::logic_error("ToONNX: tracing state is expired");
   }
 
-  auto new_graph = std::make_shared<Graph>();
+  auto new_graph = std::make_shared<Graph>(state->graph->scope_root());
   std::unordered_map new_buffer_map;
   torch::autograd::SymbolicContext ctx;
 
@@ -159,6 +159,8 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state) {
       py_inputs[input_nr++] = py::cast(envFn(input));
     }
+    auto scope_guard = ctx.graph->set_current_scope_temporary(n->scope());
+
     py::object raw_output = onnx.attr("_run_symbolic_function")(ctx.graph, n, py_inputs);
 
     processSymbolicOutput(symbolToString(n->kind()), n, raw_output);
@@ -195,6 +197,8 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state) {
       py_symbolic_args[input_nr++] = obj;
     }
+    auto scope_guard = ctx.graph->set_current_scope_temporary(op->scope());
+
     // Call the symbolic function
     // Use a little trampoline function so we can give good error messages
     // upon argument mismatch
@@ -218,6 +222,7 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state) {
       // Selects are translated by multi-return nodes.
       JIT_ASSERT(env.count(value) > 0);
     IR_ELSEIFM(CppOp)
+      auto scope_guard = new_graph->set_current_scope_temporary(node->scope());
       if (auto fn = std::dynamic_pointer_cast(value->fn)) {
         auto outputs = fn->symbolic(&ctx, fmap(node->inputs(), envFn));
         setOutputs(value->name(), node, outputs);
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp
index 8ceefb6ce04cc7f..ff34a51b940106d 100644
--- a/torch/csrc/jit/python_ir.cpp
+++ b/torch/csrc/jit/python_ir.cpp
@@ -105,6 +105,7 @@ void initPythonIRBindings(PyObject * module_) {
       node->setType(other->typeOption());
       return node;
     })
+    .NS(scopeName)
 #define AS(name) def(#name,&Attributes<Node> :: name)
     // methods from Attributes
     .AS(copyAttributes)
diff --git a/torch/csrc/jit/python_tracer.cpp b/torch/csrc/jit/python_tracer.cpp
index 501e21a1ec0127f..ce16d202a866177 100644
--- a/torch/csrc/jit/python_tracer.cpp
+++ b/torch/csrc/jit/python_tracer.cpp
@@ -19,7 +19,7 @@ namespace torch { namespace jit {
 void initPythonTracerBindings(PyObject* module_) {
   auto m = py::handle(module_).cast<py::module>();
 
-  py::class_<TracingState,std::shared_ptr<TracingState>>(m, "TracingState")
+  py::class_<TracingState,std::shared_ptr<TracingState>>(m, "TracingState", py::dynamic_attr())
     // NB: no constructor; you have to get it from C++ code
     .def("__repr__", [](const TracingState& s) {
       std::ostringstream ss;
@@ -32,6 +32,14 @@ void initPythonTracerBindings(PyObject* module_) {
       ss << *s.graph;
       return ss.str();
     })
+    .def("push_scope", [](TracingState& s, const std::string& scope_name) {
+      ASSERT_UNEXPIRED("push_scope");
+      s.push_scope(scope_name);
+    })
+    .def("pop_scope", [](TracingState& s) {
+      ASSERT_UNEXPIRED("pop_scope");
+      s.pop_scope();
+    })
     .def("export", [](TracingState& s, const std::vector<at::Tensor>& initializers, int64_t onnx_opset_version) {
       ASSERT_UNEXPIRED("export");
       return py::bytes(ExportGraph(s.graph, initializers, onnx_opset_version));
@@ -52,6 +60,12 @@ void initPythonTracerBindings(PyObject* module_) {
   m.def("_tracer_exit", [](variable_list var_outputs) {
     tracer::exit(var_outputs);
   });
+  m.def("_get_tracing_state", [](const variable_list& vars) {
+    return getTracingState(vars);
+  });
+  m.def("_is_tracing", [](const variable_list& vars) {
+    return isTracing(vars);
+  });
 }
 
 }} // namespace torch::jit
diff --git a/torch/csrc/jit/tracer_state.h b/torch/csrc/jit/tracer_state.h
index 75438b3437791da..aed17306edb7e2d 100644
--- a/torch/csrc/jit/tracer_state.h
+++ b/torch/csrc/jit/tracer_state.h
@@ -80,6 +80,14 @@ struct TracingState : public std::enable_shared_from_this<TracingState> {
   bool is_complete() const {
     return !is_expired() && graph->stage() == num_stages - 1;
   }
+
+  void push_scope(const std::string& scope_name) {
+    graph->push_scope(scope_name);
+  }
+
+  void pop_scope() {
+    graph->pop_scope();
+  }
 };
 
 struct ValueTracingStateElem {
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index bee2ac4a9005a3c..b566ea639cbe3d7 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -31,6 +31,30 @@ def __repr__(self):
 
 VOLATILE = Placeholder("VOLATILE")
 
+# This global variable is set when we are tracing a *forwards* computation.
+# It is intended to be a cheap way to test if tracing has occurred, before
+# doing the slower path using `get_tracing_state` (below.)
+_tracing = False
+
+
+def get_tracing_state(args):
+    if not torch._C._is_tracing(args):
+        return None
+    return torch._C._get_tracing_state(args)
+
+
+@contextlib.contextmanager
+def scope(scope_name, *vars):
+    tracing_state = get_tracing_state(vars)
+    if tracing_state:
+        tracing_state.push_scope(scope_name)
+    try:
+        yield
+    finally:
+        if tracing_state:
+            tracing_state.pop_scope()
+
+
 def compile(arg=None, **kwargs):
     """
     Decorator which marks a function or module class as eligible for
@@ -227,6 +251,8 @@ def __init__(self, inner, nderivs=0):
         self.nderivs = nderivs
 
     def forward(self, *args, **kwargs):
+        global _tracing
+
         # TODO: Possible optimization: use the unflattened
         # output so we don't unflatten it when we get out
         # NB: Not a method because _raw_trace can't deal
@@ -238,7 +264,9 @@ def traced_inner(in_vars, in_struct):
         kw_items = list(kwargs.items())
         kw_items.sort()
         in_vars, in_struct = _flatten((args, tuple(kw_items)), self.state_dict(keep_vars=True).values())
+        _tracing = True
         trace, (out_vars, out_struct) = traced_inner(in_vars, in_struct)
+        _tracing = False
         out, unmatched = _unflatten(out_vars, out_struct)
         assert len(unmatched) == 0
         return trace, out
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 9b2309f2ec1593f..e3ce3c77ed28935 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -1,4 +1,4 @@
-from collections import OrderedDict
+from collections import OrderedDict, Iterable
 import functools
 
 import torch
@@ -319,10 +319,42 @@ def register_forward_hook(self, hook):
         self._forward_hooks[handle.id] = hook
         return handle
 
+    def _tracing_name(self, tracing_state):
+        if not tracing_state._traced_module_stack:
+            return None
+        module = tracing_state._traced_module_stack[-1]
+        for name, child in module.named_children():
+            if child is self:
+                return name
+        return None
+
+    def _slow_forward(self, *input, **kwargs):
+        input_vars = tuple(torch.autograd.function._iter_variables(input))
+        tracing_state = torch.jit.get_tracing_state(input_vars)
+        if not tracing_state:
+            return self.forward(*input, **kwargs)
+        if not hasattr(tracing_state, '_traced_module_stack'):
+            tracing_state._traced_module_stack = []
+        name = self._tracing_name(tracing_state)
+        if name:
+            tracing_state.push_scope('%s[%s]' % (self.__class__.__name__, name))
+        else:
+            tracing_state.push_scope(self.__class__.__name__)
+        tracing_state._traced_module_stack.append(self)
+        try:
+            result = self.forward(*input, **kwargs)
+        finally:
+            tracing_state.pop_scope()
+            tracing_state._traced_module_stack.pop()
+        return result
+
     def __call__(self, *input, **kwargs):
         for hook in self._forward_pre_hooks.values():
             hook(self, input)
-        result = self.forward(*input, **kwargs)
+        if torch.jit._tracing:
+            result = self._slow_forward(*input, **kwargs)
+        else:
+            result = self.forward(*input, **kwargs)
         for hook in self._forward_hooks.values():
             hook_result = hook(self, input, result)
             if hook_result is not None:
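
Below is a brief usage sketch of the scope API this patch introduces, adapted from the test_scopes test and the TestJit.test_scopes_identity_node expect file above; the tensor values, scope names, and the print call are illustrative only and not part of the patch.

import torch
from torch.autograd import Variable

# Explicit scopes around traced operations (mirrors test_scopes above).
def f(x, y):
    out = x + y                            # traced with no scope
    with torch.jit.scope('Foo', out):      # pushes "Foo" onto the tracing state
        out = x * out                      # recorded as scope: Foo
        with torch.jit.scope('Bar', out):
            out = torch.tanh(out)          # recorded as scope: Foo/Bar
        out = torch.sigmoid(out)           # back to scope: Foo
    return out

x = Variable(torch.Tensor([0.4]), requires_grad=True)
y = Variable(torch.Tensor([0.7]), requires_grad=True)
trace, z = torch.jit.trace(f, (x, y), nderivs=0)
print(trace)  # node lines now carry "scope: Foo" / "scope: Foo/Bar" annotations

When an nn.Module is traced, Module.__call__ dispatches to the new _slow_forward, which pushes a scope for each submodule automatically (ClassName or ClassName[child_name]), so no explicit torch.jit.scope calls are needed; the resulting scope paths look like Net/Sequential[features]/Conv2d[0], as in the identity-node expect file.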