Allow tracing with fork/wait (#15184)
Summary:
There is still a limitation on this: if a script module is somewhere
in the trace, its inputs/outputs can only be tensors or tuples of
tensors.

resolves #15052
Pull Request resolved: #15184

Differential Revision: D13457691

Pulled By: highker

fbshipit-source-id: 8fe46afc41357a0eb8eadd83f687b31d074deb0e
highker authored and facebook-github-bot committed Dec 18, 2018
1 parent bd958cd commit e37a221
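
As a quick illustration of what the commit enables, here is a minimal sketch (not part of the commit; module names are illustrative, and it relies on the internal torch.jit._fork/torch.jit._wait helpers that the new test below also uses):

import torch
import torch.nn as nn

class Forker(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, x):
        # type: (Tensor) -> Tensor
        # Fork an asynchronous task and wait on its future from script code.
        future = torch.jit._fork(torch.neg, x)
        return torch.jit._wait(future)

class Outer(nn.Module):
    def __init__(self):
        super(Outer, self).__init__()
        self.forker = Forker()

    def forward(self, x):
        # The script module (and the fork/wait inside it) can now be traced through.
        return self.forker(x) + x

x = torch.rand(3, 3)
traced = torch.jit.trace(Outer(), (x,))
print(traced.graph)  # expected to contain a prim::fork node, as the new test asserts

The script module's inputs and outputs in this sketch are plain tensors, which is exactly the case the limitation described in the summary still requires.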
Showing 5 changed files with 99 additions and 9 deletions.
3 changes: 3 additions & 0 deletions aten/src/ATen/core/jit_type.h
@@ -532,6 +532,9 @@ struct CAFFE2_API FutureType : public SingleElementType<TypeKind::FutureType, Fu
ss << "Future[" << getElementType()->python_str() << "]";
return ss.str();
}
TypePtr createWithContained(std::vector<TypePtr> contained_types) const override {
return create(contained_types.at(0));
}
private:
FutureType(TypePtr elem) : SingleElementType(elem) {}
};
53 changes: 53 additions & 0 deletions test/test_jit.py
@@ -11222,6 +11222,59 @@ def wait_script(x1, x2, x3):
self.assertEqual(y2, foo2(x1, x2))
self.assertEqual(y3, foo3(x1, x2, x3))

def test_async_script_trace(self):
class Traced(nn.Module):
def __init__(self):
super(Traced, self).__init__()

def forward(self, x):
return tuple([torch.neg(x), x])

class Module(torch.jit.ScriptModule):
def __init__(self):
super(Module, self).__init__(False)
x = torch.rand(3, 3)
self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True)

@torch.jit.script_method
def forward(self, x):
# type: (Tensor) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]
future1 = torch.jit._fork(self.traced, x)
future2 = torch.jit._fork(torch.neg, x)

tensor_tuple = torch.jit._wait(future1)
tensor_single = torch.jit._wait(future2)

tensor_list = []
tensor_list.append(tensor_tuple[0])
tensor_list.append(tensor_single)

# return a nested structure of tensors
return (tensor_list, tensor_tuple, tensor_tuple[1])

class Tuple(nn.Module):
def __init__(self):
super(Tuple, self).__init__()
self.module = Module()

def forward(self, x):
z = torch.neg(x)
y = self.module(x)
list = [z, y[0][0], y[0][1], y[1][0], y[1][1], y[2]]
return tuple(list)

x = torch.rand(3, 3)
module = torch.jit.trace(Tuple(), (x), _force_outplace=True)

# Make sure we have forks
self.assertGraphContainsExactly(module.graph, kind='prim::fork', num_kind_nodes=2)
# Make sure 1 ::neg is in the root graph and 2 ::negs are in the subgraphs
self.assertGraphContainsExactly(module.graph, kind='aten::neg', num_kind_nodes=1)
self.assertGraphContainsExactly(module.graph, kind='aten::neg', num_kind_nodes=3, consider_subgraphs=True)

y = torch.neg(x)
self.assertEqual(module(x), tuple([y, y, y, y, x, x]))


for test in autograd_method_tests:
add_autograd_test(*test)
11 changes: 6 additions & 5 deletions torch/csrc/jit/graph_executor.cpp
@@ -513,7 +513,11 @@ struct GraphExecutorImpl {
// NB: we could just run the fallback in here and call it a day, but that would loose all
// the control flow information we have in the graph. Thus, we run the fallback to
// get the correct output values, but we will override the tracing states later.
-  getOrCompileFallback().run(stack);
+  {
+    // No need to trace a script module.
+    ResourceGuard guard(tracer::pauseTracing());
+    getOrCompileFallback().run(stack);
+  }

// Traces always have types propagated through them, so we make sure to
// also propagate types through the graph we are inserting here.
@@ -527,10 +531,7 @@

auto outputs = last(stack, num_outputs);
for (size_t i = 0; i < outputs.size(); ++i) {
-  // We can't attach tracing states to scalars, so we have to skip them here
-  // TODO: Should we reinterpret them as scalar tensors instead?
-  if (!outputs[i].isTensor()) continue;
-  tracer::setValueTrace(outputs[i].toTensor(), output_values[i]);
+  tracer::setValueTrace(outputs[i], output_values[i]);
}
}

27 changes: 27 additions & 0 deletions torch/csrc/jit/tracer.cpp
@@ -37,6 +37,33 @@ thread_local std::shared_ptr<TracingState> tracing_state;

} // namespace detail

void setValueTrace(const IValue &v, Value *value) {
if (v.isTensor()) {
auto var = v.toTensor();
JIT_ASSERT(var.defined());
getTracingState()->value_map[var] = value;
} else if (v.isTensorList()) {
auto& outputs = v.toTensorList()->elements();
auto graph = getTracingState()->graph;
Node * unpack_node = graph->appendNode(graph->create(prim::ListUnpack, {value}, outputs.size()));
for (size_t i = 0; i < outputs.size(); ++i) {
setValueTrace(outputs[i], unpack_node->outputs()[i]);
}
} else if (v.isTuple()) {
auto& outputs = v.toTuple()->elements();
auto graph = getTracingState()->graph;
Node * unpack_node = graph->appendNode(graph->create(prim::TupleUnpack, {value}, outputs.size()));
for (size_t i = 0; i < outputs.size(); ++i) {
setValueTrace(outputs[i], unpack_node->outputs()[i]);
}
} else {
std::ostringstream os;
os << "Tracer cannot set value trace for type " << v.tagKind() << ". "
<< "Supported types are tensor, tensor list, and tuple of tensors.";
throw std::runtime_error(os.str());
}
}

void addInputs(Node *n, const char * name, int64_t value) {
using ArgumentStash = jit::tracer::ArgumentStash;
if (ArgumentStash::hasValue(name)) {
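
The recursive setValueTrace above is what allows non-tensor outputs of a script module to participate in a trace: tensors are mapped directly, tensor lists and tuples are unpacked through prim::ListUnpack/prim::TupleUnpack and handled element by element, and any other type hits the runtime_error at the end. A hedged Python-level sketch of the resulting behaviour (names are illustrative, not from the commit):

import torch
import torch.nn as nn

class ReturnsPair(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, x):
        # type: (Tensor) -> Tuple[Tensor, Tensor]
        return torch.neg(x), x

class Wrapper(nn.Module):
    def __init__(self):
        super(Wrapper, self).__init__()
        self.inner = ReturnsPair()

    def forward(self, x):
        # The tuple returned by the script module gets trace values attached
        # to each element (via a TupleUnpack node), so both can be used below.
        a, b = self.inner(x)
        return a + b

x = torch.rand(3, 3)
traced = torch.jit.trace(Wrapper(), (x,))
# Tensor, tensor-list, and tuple-of-tensor outputs are supported here; other
# output types would raise "Tracer cannot set value trace for type ...".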
14 changes: 10 additions & 4 deletions torch/csrc/jit/tracer.h
@@ -32,16 +32,22 @@ TORCH_API void setRecordSourceLocation(void (*v)(Node*));
// Having finished adding a new 'node' to the graph IR 'setValueTrace' associates
// this node with an output variable, so that further operations involving this
// variable know which node in the IR to reference.
-inline void setValueTrace(const Variable& var, Value *value) {
-  JIT_ASSERT(var.defined());
-  getTracingState()->value_map[var] = value;
-}
+TORCH_API void setValueTrace(const IValue& v, Value* value);

inline void delValueTrace(const Variable& var) {
JIT_ASSERT(var.defined());
getTracingState()->value_map.erase(var);
}

+inline std::function<void()> pauseTracing() {
+  std::shared_ptr<tracer::TracingState> state = getTracingState();
+  tracer::setTracingState(nullptr);
+
+  return [state]() {
+    tracer::setTracingState(state);
+  };
+}

// Given a variable 'var', return the 'node' which represents the instruction
// which computes the value of this variable in the IR.
// Here, we interpret untraced variables as constants that are just embedded
