Specialize optional tensor inputs to graphs in the JIT (#18360)
Summary:
This specializes optional tensor inputs to either a DimensionedTensorType or, when None is passed,
UndefinedTensor (aka AutogradZeroTensorType).
This works because we already record different argument specs, and thus build separate execution plans, for the two cases.
It also enhances shape analysis: unwrapped optional tensors now carry a DimensionedTensorType with the appropriate shape, requires-grad flag, and so on.
Combined with "if-pruning" (which I understand #18259 works towards), this also yields much nicer concrete graphs.
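For illustration, here is a minimal sketch (not part of the commit) of the user-visible effect, mirroring the test added below in test_jit.py; the function name maybe_add is made up:

    import torch
    from typing import Optional

    @torch.jit.script
    def maybe_add(x, y):
        # type: (Tensor, Optional[Tensor]) -> Tensor
        if y is not None:
            # After this change the unwrapped y carries a DimensionedTensorType,
            # so shape analysis can propagate its shape and requires_grad.
            return x + y
        else:
            return x

    t = torch.ones(2, 2)
    maybe_add(t, t)     # spec sees a defined tensor -> y specialized to DimensionedTensorType
    maybe_add(t, None)  # different spec, separate plan -> y specialized to UndefinedTensor

As in the new test, maybe_add.graph_for(t, t) and maybe_add.graph_for(t, None) would then show the second input typed as DimensionedTensorType and UndefinedTensor, respectively.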
Pull Request resolved: #18360

Differential Revision: D14590577

Pulled By: soumith

fbshipit-source-id: cac204a506d1d38b15703cbcc67a6b75fd4979f4
t-vi authored and facebook-github-bot committed Mar 24, 2019
1 parent 32d0e7e commit 7cc7ed1
Showing 3 changed files with 36 additions and 1 deletion.
17 changes: 17 additions & 0 deletions test/test_jit.py
@@ -4986,6 +4986,23 @@ def named_var_and(x, y):
            if y is not None and x_none:
                print(x + y)  # noqa: T484

    def test_optional_tensor(self):
        @torch.jit.script
        def fn(x):
            # type: (Optional[Tensor]) -> int
            if x is None:
                return 1
            else:
                return 0

        fn(None)
        g = fn.graph_for(None)
        self.assertEqual(list(g.inputs())[0].type().str(), 'UndefinedTensor')
        t = torch.ones(1)
        fn(t)
        g = fn.graph_for(t)
        self.assertEqual(list(g.inputs())[0].type().kind(), 'DimensionedTensorType')

    def test_while_write_outer_then_read(self):
        def func(a, b):
            while bool(a < 10):

3 changes: 2 additions & 1 deletion torch/csrc/jit/argument_spec.h
@@ -153,7 +153,8 @@ struct ArgumentSpec {
 
  private:
   TypePtr fillType(TypePtr original, size_t& offset) const {
-    if (original->isSubtypeOf(TensorType::get())) {
+    if (original->isSubtypeOf(TensorType::get())
+        || original->isSubtypeOf(OptionalType::ofTensor())) {
       auto& arg = args.at(offset++);
       if (!arg.defined())
         return AutogradZeroTensorType::get();

17 changes: 17 additions & 0 deletions torch/csrc/jit/passes/shape_analysis.cpp
@@ -513,6 +513,16 @@ class ShapePropagator {
        }
        return;
      }
      case prim::unchecked_unwrap_optional: {
        // we know we cannot have None as input, so we can always pass
        // on the type.
        if(auto ot = node->input()->type()->cast<OptionalType>()) {
          node->output()->setType(ot->getElementType());
        } else {
          node->output()->setType(node->input()->type());
        }
        return;
      }
      case prim::ConstantChunk: {
        Value* tensor = node->input();
        if (auto type = tensor->type()->cast<DimensionedTensorType>()) {
@@ -529,10 +539,17 @@
        return;
      }
      case aten::_unwrap_optional: {
        // if we have None as input, we need to leave the output alone
        auto input_ivalue = toIValue(node->input());
        if (input_ivalue && input_ivalue->isNone()) {
          return;
        }
        if(auto ot = node->input()->type()->cast<OptionalType>()) {
          node->output()->setType(ot->getElementType());
        } else {
          node->output()->setType(node->input()->type());
        }
        return;
      }
      default:
        break; // fall-through