From f4e1438d89ef4c191edc0feff17e2bd114c95b3e Mon Sep 17 00:00:00 2001 From: Jack Montgomery Date: Mon, 26 Oct 2020 12:43:23 -0700 Subject: [PATCH] Add single element tuple output from to_backend/to_glow Summary: Support single element tuples in to_backend Reviewed By: andrewmillspaugh Differential Revision: D24539869 fbshipit-source-id: 67135e92b294cec9a8cbf5b342163f27e956ba46 --- .../to_glow_tuple_output_test.py | 36 ++++++++++++++----- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/torch_glow/tests/functionality/to_glow_tuple_output_test.py b/torch_glow/tests/functionality/to_glow_tuple_output_test.py index 25eef07860..1f15c4355e 100644 --- a/torch_glow/tests/functionality/to_glow_tuple_output_test.py +++ b/torch_glow/tests/functionality/to_glow_tuple_output_test.py @@ -6,20 +6,29 @@ import torch_glow -class Foo(torch.nn.Module): +class TwoTupleModule(torch.nn.Module): def __init__(self): - super(Foo, self).__init__() + super(TwoTupleModule, self).__init__() def forward(self, x): y = 2 * x - return x, y + return (x, y) + + +class OneTupleModule(torch.nn.Module): + def __init__(self): + super(OneTupleModule, self).__init__() + + def forward(self, x): + y = 2 * x + return (y,) class TestToGlowTupleOutput(unittest.TestCase): - def test_to_glow_tuple_output(self): + def tuple_test_helper(self, ModType): input = torch.randn(4) - model = Foo() + model = ModType() spec = torch_glow.CompilationSpec() spec.get_settings().set_glow_backend("Interpreter") @@ -36,10 +45,19 @@ def test_to_glow_tuple_output(self): lowered_model = torch_glow.to_glow(scripted_mod, {"forward": spec}) # Run Glow model - (gx, gy) = lowered_model(input) + g = lowered_model(input) # Run reference model - (tx, ty) = model(input) + t = model(input) + + self.assertEqual(type(g), type(t)) + self.assertEqual(len(g), len(t)) + + for (gi, ti) in zip(g, t): + self.assertTrue(torch.allclose(gi, ti)) + + def test_to_glow_one_tuple_output(self): + self.tuple_test_helper(OneTupleModule) - assert torch.allclose(tx, gx) - assert torch.allclose(ty, gy) + def test_to_glow_two_tuple_output(self): + self.tuple_test_helper(TwoTupleModule)