Scopes 0.3.1 backport (#5153)
* Introduce scopes during tracing (#3016)

* Fix segfault during ONNX export

* Further fix to tracing scope (#4558)

* Set missing temporary scope in callPySymbolicMethod

* Use expected traces in all scope tests

* Fix tracking of tracing scopes during ONNX pass (#4524)

* Fix tracking of tracing scopes during ONNX pass

* Use ResourceGuard to manage setting a temporary current scope in Graph

* Add tests for ONNX pass scopes

* Remove unused num_classes argument

* Expose node scopeName to python (#4200)

* Inherit JIT scopes when cloning only when it's correct

It's correct only when the new graph owns the same scope tree
as the original one. We can end up with dangling pointers otherwise.

* Fixes after cherry-picking, still one test to go

* Fix for last failing test after scope cherry-pick

* Fix linting issue
lantiga authored and soumith committed Feb 9, 2018
1 parent 902d57b commit 2b47480
Showing 15 changed files with 310 additions and 19 deletions.
2 changes: 1 addition & 1 deletion test/expect/TestJit.test_batchnorm.expect
@@ -3,6 +3,6 @@ graph(%1 : Double(2, 2)
%3 : Double(2)
%4 : Double(2)
%5 : Double(2)) {
%7 : Double(2, 2), %8 : Handle = CppOp[N5torch8autograd16BatchNormForwardE](%1, %2, %3), uses = [[%0.i0], []];
%7 : Double(2, 2), %8 : Handle = CppOp[N5torch8autograd16BatchNormForwardE](%1, %2, %3), uses = [[%0.i0], []], scope: BatchNorm2d;
return (%7);
}
4 changes: 2 additions & 2 deletions test/expect/TestJit.test_conv.expect
@@ -1,6 +1,6 @@
graph(%1 : Double(20, 16, 50, 40)
%2 : Double(13, 16, 3, 3)) {
%4 : UNKNOWN_TYPE = Undefined(), uses = [%3.i2];
%5 : Double(20, 13, 48, 38), %6 : Handle = CppOp[ConvForward](%1, %2, %4), uses = [[%0.i0], []];
%4 : UNKNOWN_TYPE = Undefined(), uses = [%3.i2], scope: Conv2d;
%5 : Double(20, 13, 48, 38), %6 : Handle = CppOp[ConvForward](%1, %2, %4), uses = [[%0.i0], []], scope: Conv2d;
return (%5);
}
2 changes: 1 addition & 1 deletion test/expect/TestJit.test_dropout.expect
@@ -1,4 +1,4 @@
graph(%1 : Double(2, 2)) {
%3 : Double(2, 2), %4 : Handle = ^Dropout(0.6, True, False)(%1), uses = [[%0.i0], []];
%3 : Double(2, 2), %4 : Handle = ^Dropout(0.6, True, False)(%1), uses = [[%0.i0], []], scope: Dropout;
return (%3);
}
8 changes: 8 additions & 0 deletions test/expect/TestJit.test_scopes.expect
@@ -0,0 +1,8 @@
graph(%1 : Double(1)
%2 : Double(1)) {
%3 : Double(1) = add[alpha={1}](%1, %2), uses = [%4.i1];
%4 : Double(1) = mul(%1, %3), uses = [%5.i0], scope: Foo;
%5 : Double(1) = tanh(%4), uses = [%6.i0], scope: Foo/Bar;
%6 : Double(1) = sigmoid(%5), uses = [%0.i0], scope: Foo;
return (%6);
}
9 changes: 9 additions & 0 deletions test/expect/TestJit.test_scopes_identity_node.expect
@@ -0,0 +1,9 @@
graph(%1 : Double(1, 3, 227, 227)
%2 : Double(64, 3, 11, 11)
%3 : Double(64)) {
%5 : UNKNOWN_TYPE = Conv[kernel_shape=[11, 11], strides=[4, 4], pads=[2, 2, 2, 2], dilations=[1, 1], group=1](%1, %2), uses = [[%6.i0]], scope: Net/Sequential[features]/Conv2d[0];
%6 : Double(1, 64, 56, 56) = Add[broadcast=1, axis=1](%5, %3), uses = [%7.i0], scope: Net/Sequential[features]/Conv2d[0];
%7 : Double(1, 64, 56, 56) = Relu(%6), uses = [%8.i0], scope: Net/Sequential[features]/ReLU[1];
%8 : Double(1, 64, 27, 27) = MaxPool[kernel_shape=[3, 3], pads=[0, 0], strides=[2, 2]](%7), uses = [%0.i0], scope: Net/Sequential[features]/MaxPool2d[2];
return (%8);
}
5 changes: 5 additions & 0 deletions test/expect/TestJit.test_scopes_intermediate_node.expect
@@ -0,0 +1,5 @@
graph(%1 : Double(2)) {
%2 : Double(2) = Softmax[axis=0](%1), uses = [%3.i0], scope: Net;
%3 : Double(2) = Log(%2), uses = [%0.i0], scope: Net;
return (%3);
}
63 changes: 63 additions & 0 deletions test/test_jit.py
@@ -50,6 +50,12 @@ def LSTMCellC(*args, **kwargs):
class TestJit(TestCase):
maxDiff = None

def assertExpectedTrace(self, trace, *args, **kwargs):
torch._C._jit_pass_lint(trace)
torch._C._jit_pass_dce(trace)
torch._C._jit_pass_lint(trace)
self.assertExpected(str(trace), *args, **kwargs)

def test_simple(self):
x = Variable(torch.Tensor([0.4]), requires_grad=True)
y = Variable(torch.Tensor([0.7]), requires_grad=True)
@@ -61,6 +67,63 @@ def f(x, y):
torch._C._jit_pass_lint(trace)
self.assertExpected(str(trace))

def test_scopes(self):
x = Variable(torch.Tensor([0.4]), requires_grad=True)
y = Variable(torch.Tensor([0.7]), requires_grad=True)

def f(x, y):
out = x + y
with torch.jit.scope('Foo', out):
out = x * out
with torch.jit.scope('Bar', out):
out = torch.tanh(out)
out = torch.sigmoid(out)
return out

trace, z = torch.jit.trace(f, (x, y), nderivs=0)
torch._C._jit_pass_lint(trace)
self.assertExpected(str(trace))

def test_scopes_intermediate_node(self):

class Net(nn.Module):
def forward(self, x):
return F.log_softmax(x, dim=0)

net = Net()
t = Variable(torch.ones(2), requires_grad=True)
trace, _ = torch.jit.trace(net, (t, ))
torch.onnx._optimize_trace(trace)

self.assertExpectedTrace(trace)

def test_scopes_identity_node(self):

class Net(nn.Module):

def __init__(self):
super(Net, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
)

def forward(self, x):
x = self.features(x)
return x

model = Net()

t = Variable(torch.ones(1, 3, 227, 227), requires_grad=True)

with torch.onnx.set_training(model, False):
trace, _ = torch.jit.trace(model, (t, ))

torch.onnx._optimize_trace(trace)

self.assertExpectedTrace(trace)

@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
def test_lstm_fusion(self):
input = Variable(torch.randn(3, 10).cuda())
10 changes: 9 additions & 1 deletion torch/csrc/jit/ir.cpp
@@ -263,7 +263,15 @@ std::ostream& printNode(std::ostream & out, const Node * n, std::vector<const No
} else {
emitUses(out,n);
}
out << "];\n";
out << "]";
std::string scopeName = n->scopeName();
if (scopeName.empty()) {
out << ";\n";
}
else {
out << ", ";
out << "scope: " << scopeName << ";\n";
}
return out;
}

130 changes: 120 additions & 10 deletions torch/csrc/jit/ir.h
@@ -69,6 +69,64 @@ struct SourceLocation {
std::string python_traceback;
};

// Scope is a node of a trie that represents the tree of nested scopes.
// Individual scopes are pushed and popped from Graph, which holds a
// pointer to the current scope. Each Node in Graph holds a pointer
// to the scope that was current when the node was created.
// The trie never needs to shrink, it only grows until it is disposed
// of when Graph is deallocated. Hence, pointers to scopes held by nodes
// will always be valid as long as Graph is alive.
struct Scope {
private:
Scope* parent_;
Symbol name_;
std::vector<std::unique_ptr<Scope> > children_;
public:
Scope() {
name_ = stringToSymbol("");
parent_ = NULL;
}
Scope(Scope* parent, Symbol name) {
name_ = name;
parent_ = parent;
}
Scope* push(Symbol name) {
children_.push_back(std::unique_ptr<Scope>(new Scope(this, name)));
return children_.back().get();
}
Scope* parent() {
if (parent_ == NULL) {
throw std::runtime_error("Cannot get parent from Scope with no parent");
}
return parent_;
}
bool isRoot() {
return parent_ == NULL;
}
Scope* getRoot() {
Scope* current = this;
while (current->parent_) {
current = current->parent_;
}
return current;
}
Symbol name() {
return name_;
}
std::string namesFromRoot(const std::string& separator="/") {
std::string out = std::string(symbolToString(this->name_));
if (this->isRoot()) {
return out;
}
Scope* parent = this->parent_;
while (!parent->isRoot()) {
out = std::string(symbolToString(parent->name_)) + separator + out;
parent = parent->parent_;
}
return out;
}
};
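
For illustration only, here is a minimal standalone sketch of the trie behaviour described in the comment above. MiniScope is a hypothetical stand-in (not part of this patch) that uses std::string in place of Symbol so the snippet compiles on its own; its namesFromRoot mirrors the member of the same name and produces the Foo/Bar style paths that appear in the expected traces.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for Scope: same trie shape, std::string instead of Symbol.
struct MiniScope {
  MiniScope* parent;
  std::string name;  // the empty string marks the root
  std::vector<std::unique_ptr<MiniScope>> children;

  MiniScope() : parent(nullptr) {}
  MiniScope(MiniScope* p, std::string n) : parent(p), name(std::move(n)) {}

  // Children are owned by their parent, so every pointer returned by push()
  // stays valid for as long as the root is alive -- the property the comment
  // above relies on for the Scope* held by each Node.
  MiniScope* push(const std::string& n) {
    children.push_back(std::unique_ptr<MiniScope>(new MiniScope(this, n)));
    return children.back().get();
  }

  // Same walk as Scope::namesFromRoot: join names from the root down,
  // leaving out the unnamed root itself.
  std::string namesFromRoot(const std::string& sep = "/") const {
    std::string out = name;
    for (const MiniScope* p = parent; p != nullptr && p->parent != nullptr; p = p->parent) {
      out = p->name + sep + out;
    }
    return out;
  }
};

int main() {
  MiniScope root;                             // owned by the "graph"
  MiniScope* foo = root.push("Foo");          // push_scope("Foo")
  MiniScope* bar = foo->push("Bar");          //   nested push_scope("Bar")
  std::cout << foo->namesFromRoot() << "\n";  // prints: Foo
  std::cout << bar->namesFromRoot() << "\n";  // prints: Foo/Bar
}

The real Scope additionally interns names as Symbols and is only created through Graph, but the parent-owned children and the name-joining walk match the code above.
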

// the list types are intentionally simple, but we type-def
// them here so if we need to change them, refactoring will be easier
using node_list = std::vector<Node*>;
@@ -123,6 +181,7 @@ struct Node : public Attributes<Node> {
size_t stage_ = 0; // 0-forward, 1-backward, 2-double-backward,...
std::string debug_name_;
std::shared_ptr<SourceLocation> source_location_;
Scope* scope_;
protected:
TypePtr type_;
Node(Graph * graph_, NodeKind kind_); //defined after graph
@@ -188,6 +247,18 @@ struct Node : public Attributes<Node> {
size_t stage() const {
return stage_;
}
Scope* scope() {
return scope_;
}
void setScope(Scope* scope) {
scope_ = scope;
}
std::string scopeName() const {
if (scope_ == NULL) {
return "";
}
return scope_->namesFromRoot();
}
// NB: This returns an ArrayRef; that means that it will
// get invalidated if you resize inputs (e.g., using addInput)
// We can't return a std::vector<Node*>& because there's no
@@ -528,12 +599,7 @@ struct Node : public Attributes<Node> {
//
// NB: This does NOT clone stages. You're expected to set the stage correctly
// if you are going to preserve it.
virtual void cloneFrom(Node * s) {
if (s->hasType()) setType(s->type());
setDebugName(s->debugName());
setSourceLocation(s->getSourceLocation());
copyAttributes(*s);
}
virtual void cloneFrom(Node * s);
};

struct Graph {
@@ -551,18 +617,27 @@ friend struct Node;

size_t new_node_stage_;

std::shared_ptr<Scope> scope_root_;
Scope * current_scope_;

// holds outputs in a way that can be reflected
// as a Use object
// also used as the beginning/end of the circular node list to avoid
// having corner cases where the list is empty.
Node * const output_;

public:
Graph()

Graph(std::shared_ptr<Scope> scope_root)
: next_unique_(0)
, new_node_stage_(0)
, scope_root_(scope_root)
, current_scope_(scope_root_.get())
, output_(initOutput(create(kReturn))) {}

Graph()
: Graph( std::make_shared<Scope>()) {}

at::ArrayRef<Node*> inputs() {
return inputs_;
}
@@ -618,6 +693,29 @@ friend struct Node;
Node * addInput() {
return addInput(create(kParam));
}
void push_scope(const std::string& scope_name) {
current_scope_ = current_scope_->push(stringToSymbol(scope_name));
}
void pop_scope() {
current_scope_ = current_scope_->parent();
}
Scope * current_scope() {
return current_scope_;
}
void set_current_scope(Scope* scope) {
if (scope->getRoot() != scope_root_.get()) {
throw std::runtime_error("trying to set a scope as current that does not belong to the Graph's scope trie");
}
current_scope_ = scope;
}
ResourceGuard set_current_scope_temporary(Scope* scope) {
auto prev_scope = current_scope_;
this->set_current_scope(scope);
return ResourceGuard([prev_scope, this]() { this->current_scope_ = prev_scope; });
}
std::shared_ptr<Scope> scope_root() {
return scope_root_;
}
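
set_current_scope_temporary above leans on ResourceGuard, an RAII helper that runs a stored callback when it goes out of scope, so the previous current scope is restored automatically at the end of the caller's block. A minimal standalone sketch of that pattern follows; MiniGuard and the string-valued current_scope are hypothetical stand-ins, not the real ResourceGuard or Graph.

#include <functional>
#include <iostream>
#include <string>

// Hypothetical stand-in for ResourceGuard: store a callback, run it on destruction.
struct MiniGuard {
  explicit MiniGuard(std::function<void()> on_exit) : on_exit_(std::move(on_exit)) {}
  ~MiniGuard() { on_exit_(); }
  MiniGuard(const MiniGuard&) = delete;
  MiniGuard& operator=(const MiniGuard&) = delete;
 private:
  std::function<void()> on_exit_;
};

// Stand-in for Graph::current_scope_; a plain string keeps the sketch self-contained.
std::string current_scope = "Foo";

int main() {
  std::cout << current_scope << "\n";      // Foo
  {
    // Roughly what set_current_scope_temporary(bar_scope) does:
    std::string prev = current_scope;
    current_scope = "Foo/Bar";
    MiniGuard guard([prev]() { current_scope = prev; });
    std::cout << current_scope << "\n";    // Foo/Bar while the guard is alive
  }                                        // ~MiniGuard restores the previous scope
  std::cout << current_scope << "\n";      // Foo again
}

The ONNX pass further down uses exactly this pattern (scope_guard) so that nodes created by the symbolic functions are attributed to the scope of the node being converted, without having to reset the scope manually on every exit path.
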

Node * addInput(Node* n) {
JIT_ASSERT(n->kind() == kParam);
@@ -694,7 +792,8 @@ friend struct Node;
}
Node * createFusionGroup() {
auto n = create(kFusionGroup);
n->g_(kSubgraph,std::make_shared<Graph>());
auto subgraph = std::make_shared<Graph>(scope_root_);
n->g_(kSubgraph, subgraph);
return n;
}
Node * createPythonOp(THPObjectPtr&& pyobj, const std::string & cconv, bool is_legacy, pyobj_list&& scalar_args);
@@ -764,9 +863,10 @@ inline Node::Node(Graph * graph_, NodeKind kind_) :
graph_(graph_),
unique_(graph_->next_unique_++),
stage_(graph_->new_node_stage_),
scope_(graph_->current_scope_) ,
type_(getInitialType(kind_)) {
graph_->all_nodes.emplace(this);
}
graph_->all_nodes.emplace(this);
}

inline void Node::destroy() {
JIT_ASSERT(inGraphList());
@@ -788,6 +888,16 @@ inline Node* Node::makeMultireturn() {
return select;
}

inline void Node::cloneFrom(Node * s) {
if (s->hasType()) setType(s->type());
setDebugName(s->debugName());
setSourceLocation(s->getSourceLocation());
if (s->owningGraph()->scope_root_ == owningGraph()->scope_root_) {
scope_ = s->scope_;
}
copyAttributes(*s);
}

// Helper macros for constructing switch statements over Node types
// instead of heavy-weight visitors
// read 'between' these defines to see how they turn into a big switch
7 changes: 6 additions & 1 deletion torch/csrc/jit/passes/onnx.cpp
@@ -33,7 +33,7 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state) {
throw std::logic_error("ToONNX: tracing state is expired");
}

auto new_graph = std::make_shared<Graph>();
auto new_graph = std::make_shared<Graph>(state->graph->scope_root());
std::unordered_map<void*, Node*> new_buffer_map;

torch::autograd::SymbolicContext ctx;
@@ -159,6 +159,8 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state) {
py_inputs[input_nr++] = py::cast(envFn(input));
}

auto scope_guard = ctx.graph->set_current_scope_temporary(n->scope());

py::object raw_output = onnx.attr("_run_symbolic_function")(ctx.graph, n, py_inputs);

processSymbolicOutput(symbolToString(n->kind()), n, raw_output);
@@ -195,6 +197,8 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state) {
py_symbolic_args[input_nr++] = obj;
}

auto scope_guard = ctx.graph->set_current_scope_temporary(op->scope());

// Call the symbolic function
// Use a little trampoline function so we can give good error messages
// upon argument mismatch
@@ -218,6 +222,7 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state) {
// Selects are translated by multi-return nodes.
JIT_ASSERT(env.count(value) > 0);
IR_ELSEIFM(CppOp)
auto scope_guard = new_graph->set_current_scope_temporary(node->scope());
if (auto fn = std::dynamic_pointer_cast<autograd::HasSymbolic>(value->fn)) {
auto outputs = fn->symbolic(&ctx, fmap(node->inputs(), envFn));
setOutputs(value->name(), node, outputs);
1 change: 1 addition & 0 deletions torch/csrc/jit/python_ir.cpp
@@ -105,6 +105,7 @@ void initPythonIRBindings(PyObject * module_) {
node->setType(other->typeOption());
return node;
})
.NS(scopeName)
#define AS(name) def(#name,&Attributes<Node> :: name)
// methods from Attributes
.AS(copyAttributes)
