Skip to content

Commit

Permalink
[jit][tracer] allow traced modules to return dicts with tuple values …
Browse files Browse the repository at this point in the history
…when strict=False (#49568)

Summary:
Pull Request resolved: #49568

We have some inference use cases where the expected output of a module is of the form `{"key": (t1, t1)}` and are currently jit tracing the modules until we can reach jit script compatibility.

Test Plan: buck test mode/dev caffe2/test:jit -- 'test_trace_returning_complex_dict'

Reviewed By: houseroad

Differential Revision: D25624152

fbshipit-source-id: 5adef0e3c9d54cd31ad5fece4ac6530d541fd673
  • Loading branch information
bradleyhd authored and facebook-github-bot committed Dec 21, 2020
1 parent 46c9a0e commit 5b163e2
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 2 deletions.
43 changes: 43 additions & 0 deletions test/jit/test_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2279,3 +2279,46 @@ def forward(self, input: Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tenso
traced_module = torch.jit.trace(eager_module, input1)
self.assertEqual(traced_module(input1), eager_module(input1))
self.assertEqual(traced_module(input2), eager_module(input2))

def test_trace_returning_dict_with_tensor_tuples(self):
"""Tracing over a module returning a dictionary whose values are tuples of tensors
should work.
"""
class ReturnsDict(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(
self, k: torch.Tensor, v: torch.Tensor
) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
x = 2 * k
y = 3 * v
result = {
"imakey": (x, y)
}
return result

class ReturnsBadDict(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(
self, k: torch.Tensor, v: torch.Tensor
) -> Dict[str, Tuple[torch.Tensor, float]]:
x = 2 * k
result = {
"imakey": (x, 1)
}
return result

mod = ReturnsDict()
traced_module = torch.jit.trace(mod, [torch.ones(1), torch.ones(1)], strict=False)
out = traced_module(torch.ones(1), torch.ones(1))
expected = {
"imakey": (torch.tensor([2.]), torch.tensor([3.]))
}
self.assertEqual(out, expected)

with self.assertRaisesRegex(RuntimeError, "cannot be understood by the tracer, only outputs matching"):
mod = ReturnsBadDict()
traced_module = torch.jit.trace(mod, [torch.ones(1), torch.ones(1)], strict=False)
16 changes: 14 additions & 2 deletions torch/csrc/jit/frontend/tracer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,11 +284,23 @@ Value* TracingState::getOutput(const IValue& iv, size_t i) {
key_type->isSubtypeOf(TensorType::get());
bool value_type_valid = value_type->isSubtypeOf(TensorType::get());

// Support tuple values that contain only tensors
if (value_type->isSubtypeOf(AnyTupleType::get())) {
value_type_valid = true;
for (const auto& type : value_type->containedTypes()) {
if (!type->isSubtypeOf(TensorType::get())) {
value_type_valid = false;
break;
}
}
}

if (!key_type_valid || !value_type_valid) {
std::ostringstream os;
os << "output " << i << " (" << dict << ") of traced region "
<< "cannot be understood by the tracer, only dict[str, Tensor] "
<< "or dict[Tensor, Tensor] can be a dictionary output of a traced function";
<< "cannot be understood by the tracer, only outputs matching"
<< "dict[Union[str, Tensor], Union[Tensor, Tuple[Tensor, ...]]] "
<< "can be a dictionary output of a traced function";
throw std::runtime_error(os.str());
}
std::vector<Value*> keys;
Expand Down

0 comments on commit 5b163e2

Please sign in to comment.