diff --git a/torch_glow/tests/functionality/to_glow_tuple_output_test.py b/torch_glow/tests/functionality/to_glow_tuple_output_test.py index 25eef07860..1f15c4355e 100644 --- a/torch_glow/tests/functionality/to_glow_tuple_output_test.py +++ b/torch_glow/tests/functionality/to_glow_tuple_output_test.py @@ -6,20 +6,29 @@ import torch_glow -class Foo(torch.nn.Module): +class TwoTupleModule(torch.nn.Module): def __init__(self): - super(Foo, self).__init__() + super(TwoTupleModule, self).__init__() def forward(self, x): y = 2 * x - return x, y + return (x, y) + + +class OneTupleModule(torch.nn.Module): + def __init__(self): + super(OneTupleModule, self).__init__() + + def forward(self, x): + y = 2 * x + return (y,) class TestToGlowTupleOutput(unittest.TestCase): - def test_to_glow_tuple_output(self): + def tuple_test_helper(self, ModType): input = torch.randn(4) - model = Foo() + model = ModType() spec = torch_glow.CompilationSpec() spec.get_settings().set_glow_backend("Interpreter") @@ -36,10 +45,19 @@ def test_to_glow_tuple_output(self): lowered_model = torch_glow.to_glow(scripted_mod, {"forward": spec}) # Run Glow model - (gx, gy) = lowered_model(input) + g = lowered_model(input) # Run reference model - (tx, ty) = model(input) + t = model(input) + + self.assertEqual(type(g), type(t)) + self.assertEqual(len(g), len(t)) + + for (gi, ti) in zip(g, t): + self.assertTrue(torch.allclose(gi, ti)) + + def test_to_glow_one_tuple_output(self): + self.tuple_test_helper(OneTupleModule) - assert torch.allclose(tx, gx) - assert torch.allclose(ty, gy) + def test_to_glow_two_tuple_output(self): + self.tuple_test_helper(TwoTupleModule)