diff --git a/test/expect/TestJit.test_conv.expect b/test/expect/TestJit.test_conv.expect
index 30cbae10deeab..e2986e1f1aa43 100644
--- a/test/expect/TestJit.test_conv.expect
+++ b/test/expect/TestJit.test_conv.expect
@@ -1,6 +1,6 @@
 graph(%0 : Double(20, 16, 50, 40)
       %1 : Double(13, 16, 3, 3)) {
-  %2 : UNKNOWN_TYPE = Undefined()
-  %3 : Double(20, 13, 48, 38), %4 : Handle = CppOp[ConvForward](%0, %1, %2)
+  %2 : UNKNOWN_TYPE = Undefined(), scope: Conv2d
+  %3 : Double(20, 13, 48, 38), %4 : Handle = CppOp[ConvForward](%0, %1, %2), scope: Conv2d
   return (%3);
 }
diff --git a/test/expect/TestJit.test_dropout.expect b/test/expect/TestJit.test_dropout.expect
index 809b7c2e5306a..d6595183363a4 100644
--- a/test/expect/TestJit.test_dropout.expect
+++ b/test/expect/TestJit.test_dropout.expect
@@ -1,4 +1,4 @@
 graph(%0 : Double(2, 2)) {
-  %1 : Double(2, 2), %2 : Handle = ^Dropout(0.6, True, False)(%0)
+  %1 : Double(2, 2), %2 : Handle = ^Dropout(0.6, True, False)(%0), scope: Dropout
   return (%1);
 }
diff --git a/test/expect/TestJit.test_scopes.expect b/test/expect/TestJit.test_scopes.expect
new file mode 100644
index 0000000000000..1da68f46a1e60
--- /dev/null
+++ b/test/expect/TestJit.test_scopes.expect
@@ -0,0 +1,8 @@
+graph(%0 : Double(1)
+      %1 : Double(1)) {
+  %2 : Double(1) = add[alpha={1}](%0, %1)
+  %3 : Double(1) = mul(%0, %2), scope: Foo
+  %4 : Double(1) = tanh(%3), scope: Foo/Bar
+  %5 : Double(1) = sigmoid(%4), scope: Foo
+  return (%5);
+}
diff --git a/test/test_jit.py b/test/test_jit.py
index c33a887282a44..31c43fc585c8e 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -69,6 +69,23 @@ def f(x, y):
         torch._C._jit_pass_lint(trace)
         self.assertExpected(str(trace))

+    def test_scopes(self):
+        x = Variable(torch.Tensor([0.4]), requires_grad=True)
+        y = Variable(torch.Tensor([0.7]), requires_grad=True)
+
+        def f(x, y):
+            out = x + y
+            with torch.jit.scope('Foo', out):
+                out = x * out
+                with torch.jit.scope('Bar', out):
+                    out = torch.tanh(out)
+                out = torch.sigmoid(out)
+            return out
+
+        trace, z = torch.jit.trace(f, (x, y), nderivs=0)
+        torch._C._jit_pass_lint(trace)
+        self.assertExpected(str(trace))
+
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_lstm_fusion(self):
         input = Variable(torch.randn(3, 10).float().cuda())
diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp
index 94fab95928ef3..d3f94a92cac15 100644
--- a/torch/csrc/jit/ir.cpp
+++ b/torch/csrc/jit/ir.cpp
@@ -239,7 +239,15 @@ std::ostream& printNode(std::ostream & out, const Node * n, std::vector<const Node *> * groups) {
-  out << "(" << n->inputs() << ")\n";
+  out << "(" << n->inputs() << ")";
+  std::string scopeName = n->scopeName();
+  if (scopeName.empty()) {
+    out << "\n";
+  }
+  else {
+    out << ", ";
+    out << "scope: " << scopeName << "\n";
+  }
   return out;
 }
diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h
index 40c4ad30d72bb..78a7d150d4ede 100644
--- a/torch/csrc/jit/ir.h
+++ b/torch/csrc/jit/ir.h
@@ -70,6 +70,64 @@ struct SourceLocation {
   std::string python_traceback;
 };

+// Scope is a node of a trie that represents the tree of nested scopes.
+// Individual scopes are pushed and popped from Graph, which holds a
+// pointer to the current scope. Each Node in Graph holds a pointer
+// to the scope that was current when the node was created.
+// The trie never needs to shrink, it only grows until it is disposed
+// of when Graph is deallocated. Hence, pointers to scopes held by nodes
+// will always be valid as long as Graph is alive.
+struct Scope {
+private:
+  Scope* parent_;
+  Symbol name_;
+  std::vector<std::unique_ptr<Scope> > children_;
+public:
+  Scope() {
+    name_ = stringToSymbol("");
+    parent_ = NULL;
+  }
+  Scope(Scope* parent, Symbol name) {
+    name_ = name;
+    parent_ = parent;
+  }
+  Scope* push(Symbol name) {
+    children_.push_back(std::unique_ptr<Scope>(new Scope(this, name)));
+    return children_.back().get();
+  }
+  Scope* parent() {
+    if (parent_ == NULL) {
+      throw std::runtime_error("Cannot get parent from Scope with no parent");
+    }
+    return parent_;
+  }
+  bool isRoot() {
+    return parent_ == NULL;
+  }
+  Scope* getRoot() {
+    Scope* current = this;
+    while (current->parent_) {
+      current = current->parent_;
+    }
+    return current;
+  }
+  Symbol name() {
+    return name_;
+  }
+  std::string namesFromRoot(const std::string& separator="/") {
+    std::string out = std::string(symbolToString(this->name_));
+    if (this->isRoot()) {
+      return out;
+    }
+    Scope* parent = this->parent_;
+    while (!parent->isRoot()) {
+      out = std::string(symbolToString(parent->name_)) + separator + out;
+      parent = parent->parent_;
+    }
+    return out;
+  }
+};
+
 // the list types are intentionally simple, but we type-def
 // them here so if we need to change them, refactoring will be easier
 using node_list = std::vector<Node*>;
@@ -139,6 +197,9 @@ struct Value {
   const Node * node() const {
     return node_;
   }
+  Scope* scope();
+  void setScope(Scope* scope);
+  std::string scopeName() const;
   Graph * owningGraph();
   const Graph * owningGraph() const;
   // TODO: make this more const correct
@@ -197,6 +258,7 @@ struct Node : public Attributes<Node> {
   Graph* graph_;
   std::shared_ptr<SourceLocation> source_location_;
   size_t stage_;
+  Scope* scope_;
 protected:
   Node(Graph * graph_, NodeKind kind_); //defined after graph
 public:
@@ -223,6 +285,18 @@ struct Node : public Attributes<Node> {
     stage_ = s;
     return this;
   }
+  Scope* scope() {
+    return scope_;
+  }
+  void setScope(Scope* scope) {
+    scope_ = scope;
+  }
+  std::string scopeName() const {
+    if (scope_ == NULL) {
+      return "";
+    }
+    return scope_->namesFromRoot();
+  }
   // NB: This returns an ArrayRef; that means that it will
   // get invalidated if you resize inputs (e.g., using addInput)
   // We can't return a std::vector& because there's no
@@ -534,6 +608,7 @@ struct Node : public Attributes<Node> {
   // if you are going to preserve it.
   virtual void cloneFrom(Node * s) {
     setSourceLocation(s->getSourceLocation());
+    scope_ = s->scope_;
     copyAttributes(*s);
   }
 };
@@ -556,6 +631,9 @@ friend struct Value;

   size_t new_node_stage_;

+  std::shared_ptr<Scope> scope_root_;
+  Scope * current_scope_;
+
   // holds outputs in a way that can be reflected
   // as a Use object
   // also used as the beginning/end of the circular node list to avoid
@@ -564,11 +642,17 @@ friend struct Value;

   Node * const input_;
 public:
-  Graph()
+
+  Graph(std::shared_ptr<Scope> scope_root)
   : next_unique_(0)
   , new_node_stage_(0)
+  , scope_root_(scope_root)
+  , current_scope_(scope_root_.get())
   , output_(initOutput(create(kReturn, 0))), input_(create(kParam, 0)) {}
+  Graph()
+  : Graph( std::make_shared<Scope>()) {}
+
   at::ArrayRef<Value*> inputs() {
     return input_->outputs();
   }
@@ -621,6 +705,18 @@ friend struct Value;
   const Node * return_node() const {
     return output_;
   }
+  void push_scope(const std::string& scope_name) {
+    current_scope_ = current_scope_->push(stringToSymbol(scope_name));
+  }
+  void pop_scope() {
+    current_scope_ = current_scope_->parent();
+  }
+  Scope * current_scope() {
+    return current_scope_;
+  }
+  std::shared_ptr<Scope> scope_root() {
+    return scope_root_;
+  }
   Value * addInput(std::string name="") {
     Value * v = input_->addOutput();
     if (name != "") v->setUniqueName(name);
@@ -676,7 +772,8 @@ friend struct Value;
   }
   Node * createFusionGroup() {
     auto n = create(kFusionGroup, 0);
-    n->g_(kSubgraph,std::make_shared<Graph>());
+    auto subgraph = std::make_shared<Graph>(scope_root_);
+    n->g_(kSubgraph, subgraph);
     return n;
   }
   Node * createPythonOp(THPObjectPtr&& pyobj, const std::string & cconv, bool is_legacy, std::vector<VariableFlags> && var_flags, pyobj_list&& scalar_args);
@@ -759,6 +856,18 @@ inline Value::Value(Node * node_, size_t offset_)
   node_->graph_->all_values.emplace(this);
 }

+inline Scope* Value::scope() {
+  return node()->scope();
+}
+
+inline void Value::setScope(Scope* scope) {
+  node()->setScope(scope);
+}
+
+inline std::string Value::scopeName() const {
+  return node()->scopeName();
+}
+
 inline Graph * Value::owningGraph() {
   return node()->owningGraph();
 }
@@ -779,7 +888,8 @@ inline void Value::replaceAllUsesWith(Value * newValue) {
 inline Node::Node(Graph * graph_, NodeKind kind_) :
   kind_(kind_),
   graph_(graph_),
-  stage_(graph_->new_node_stage_) {
+  stage_(graph_->new_node_stage_),
+  scope_(graph_->current_scope_) {
   graph_->all_nodes.emplace(this);
 }
diff --git a/torch/csrc/jit/passes/onnx.cpp b/torch/csrc/jit/passes/onnx.cpp
index 783181664e11f..56e79b67d91e4 100644
--- a/torch/csrc/jit/passes/onnx.cpp
+++ b/torch/csrc/jit/passes/onnx.cpp
@@ -31,7 +31,7 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state) {
     throw std::logic_error("ToONNX: tracing state is expired");
   }

-  auto new_graph = std::make_shared<Graph>();
+  auto new_graph = std::make_shared<Graph>(state->graph->scope_root());
  std::unordered_map<void*, Value*> new_buffer_map;

  torch::autograd::SymbolicContext ctx;
@@ -137,6 +137,10 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state) {
      throw std::runtime_error(ss.str());
    }

+    for (auto& el: outputs) {
+      el->setScope(n->scope());
+    }
+
    setOutputs(op_name, n, outputs);
  };

@@ -208,6 +212,9 @@
    IR_IFM(node, CppOp)
      if (auto fn = std::dynamic_pointer_cast<autograd::HasSymbolic>(value->fn)) {
        auto outputs = fn->symbolic(&ctx, fmap(node->inputs(), envFn), node->getSourceLocation());
+        for (auto& el: outputs) {
+          el->setScope(node->scope());
+        }
        setOutputs(value->name(), node, outputs);
      } else {
        cloneNode(node);
diff --git a/torch/csrc/jit/python_tracer.cpp b/torch/csrc/jit/python_tracer.cpp
index 501e21a1ec012..ce16d202a8661 100644
--- a/torch/csrc/jit/python_tracer.cpp
+++ b/torch/csrc/jit/python_tracer.cpp
@@ -19,7 +19,7 @@ namespace torch { namespace jit {

 void initPythonTracerBindings(PyObject* module_) {
   auto m = py::handle(module_).cast<py::module>();
-  py::class_<TracingState, std::shared_ptr<TracingState>>(m, "TracingState")
+  py::class_<TracingState, std::shared_ptr<TracingState>>(m, "TracingState", py::dynamic_attr())
     // NB: no constructor; you have to get it from C++ code
     .def("__repr__", [](const TracingState& s) {
       std::ostringstream ss;
@@ -32,6 +32,14 @@ void initPythonTracerBindings(PyObject* module_) {
       ss << *s.graph;
       return ss.str();
     })
+    .def("push_scope", [](TracingState& s, const std::string& scope_name) {
+      ASSERT_UNEXPIRED("push_scope");
+      s.push_scope(scope_name);
+    })
+    .def("pop_scope", [](TracingState& s) {
+      ASSERT_UNEXPIRED("pop_scope");
+      s.pop_scope();
+    })
     .def("export", [](TracingState& s, const std::vector<at::Tensor>& initializers, int64_t onnx_opset_version) {
       ASSERT_UNEXPIRED("export");
       return py::bytes(ExportGraph(s.graph, initializers, onnx_opset_version));
@@ -52,6 +60,12 @@ void initPythonTracerBindings(PyObject* module_) {
   m.def("_tracer_exit", [](variable_list var_outputs) {
     tracer::exit(var_outputs);
   });
+  m.def("_get_tracing_state", [](const variable_list& vars) {
+    return getTracingState(vars);
+  });
+  m.def("_is_tracing", [](const variable_list& vars) {
+    return isTracing(vars);
+  });
 }

 }} // namespace torch::jit
diff --git a/torch/csrc/jit/tracer_state.h b/torch/csrc/jit/tracer_state.h
index f2a5bb17ffb29..08ea9cbf3a4ca 100644
--- a/torch/csrc/jit/tracer_state.h
+++ b/torch/csrc/jit/tracer_state.h
@@ -74,6 +74,14 @@ struct TracingState : public std::enable_shared_from_this<TracingState> {
   bool is_complete() const {
     return !is_expired() && graph->stage() == num_stages - 1;
   }
+
+  void push_scope(const std::string& scope_name) {
+    graph->push_scope(scope_name);
+  }
+
+  void pop_scope() {
+    graph->pop_scope();
+  }
 };

 struct ValueTracingStateElem {
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index 6634e4db8ba0c..54a658b100d2e 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -19,6 +19,30 @@
 _flatten = torch._C._jit_flatten


+# This global variable is set when we are tracing a *forwards* computation.
+# It is intended to be a cheap way to test if tracing has occurred, before
+# doing the slower path using `get_tracing_state` (below.)
+_tracing = False
+
+
+def get_tracing_state(args):
+    if not torch._C._is_tracing(args):
+        return None
+    return torch._C._get_tracing_state(args)
+
+
+@contextlib.contextmanager
+def scope(scope_name, *vars):
+    tracing_state = get_tracing_state(vars)
+    if tracing_state:
+        tracing_state.push_scope(scope_name)
+    try:
+        yield
+    finally:
+        if tracing_state:
+            tracing_state.pop_scope()
+
+
 def compile(arg=None, nderivs=1, optimize=True, enabled=True):
     """
     Decorator which marks a function or module class as eligible for
@@ -237,13 +261,16 @@ def __init__(self, inner, nderivs=0):
         self.nderivs = nderivs

     def forward(self, *args):
+        global _tracing
         in_vars = _flatten(args)
         # NOTE: use full state, because we need it for BatchNorm export
         # This differs from the compiler path, which doesn't support it at the moment.
         module_state = list(self.state_dict(keep_vars=True).values())
         trace = torch._C._tracer_enter(in_vars + module_state, self.nderivs)
+        _tracing = True
         out = self.inner(*args)
         out_vars = _flatten(out)
+        _tracing = False
         torch._C._tracer_exit(out_vars)
         return trace, out
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index e42c69b76d1ca..609598298d6d5 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -1,4 +1,4 @@
-from collections import OrderedDict
+from collections import OrderedDict, Iterable
 import functools
 import torch

@@ -319,10 +319,42 @@ def register_forward_hook(self, hook):
         self._forward_hooks[handle.id] = hook
         return handle

+    def _tracing_name(self, tracing_state):
+        if not tracing_state._traced_module_stack:
+            return None
+        module = tracing_state._traced_module_stack[-1]
+        for name, child in module.named_children():
+            if child is self:
+                return name
+        return None
+
+    def _slow_forward(self, *input, **kwargs):
+        input_vars = tuple(torch.autograd.function._iter_variables(input))
+        tracing_state = torch.jit.get_tracing_state(input_vars)
+        if not tracing_state:
+            return self.forward(*input, **kwargs)
+        if not hasattr(tracing_state, '_traced_module_stack'):
+            tracing_state._traced_module_stack = []
+        name = self._tracing_name(tracing_state)
+        if name:
+            tracing_state.push_scope('%s[%s]' % (self.__class__.__name__, name))
+        else:
+            tracing_state.push_scope(self.__class__.__name__)
+        tracing_state._traced_module_stack.append(self)
+        try:
+            result = self.forward(*input, **kwargs)
+        finally:
+            tracing_state.pop_scope()
+            tracing_state._traced_module_stack.pop()
+        return result
+
     def __call__(self, *input, **kwargs):
         for hook in self._forward_pre_hooks.values():
             hook(self, input)
-        result = self.forward(*input, **kwargs)
+        if torch.jit._tracing:
+            result = self._slow_forward(*input, **kwargs)
+        else:
+            result = self.forward(*input, **kwargs)
         for hook in self._forward_hooks.values():
             hook_result = hook(self, input, result)
             if hook_result is not None:
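
Usage sketch (not part of the patch): the new scope API, exercised the same way as test_scopes above. It assumes the pre-1.0 tracing interface this patch targets, where torch.jit.trace(f, args, nderivs=0) returns a (trace, output) pair and inputs are wrapped in Variable; torch.jit.scope only pushes a name onto the tracing state when tracing is active, so the same function also runs untraced.

import torch
from torch.autograd import Variable

x = Variable(torch.Tensor([0.4]), requires_grad=True)
y = Variable(torch.Tensor([0.7]), requires_grad=True)

def f(x, y):
    out = x + y                           # recorded with no scope
    with torch.jit.scope('Foo', out):     # finds the tracing state through `out`, pushes "Foo"
        out = x * out                     # printed as "..., scope: Foo"
        with torch.jit.scope('Bar', out):
            out = torch.tanh(out)         # printed as "..., scope: Foo/Bar"
        out = torch.sigmoid(out)          # back to "..., scope: Foo"
    return out

trace, z = torch.jit.trace(f, (x, y), nderivs=0)
print(trace)  # nodes carry the scope annotations shown in TestJit.test_scopes.expect

For modules, scopes are pushed automatically: when torch.jit._tracing is set, Module.__call__ routes through _slow_forward, which names the scope after the module class, adding the attribute name of the child when it can be recovered from the parent on the traced-module stack (a child stored as self.conv1 would give a hypothetical "Conv2d[conv1]"). That is how the plain "Conv2d" and "Dropout" scopes in the updated expect files are produced.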