Skip to content
This repository was archived by the owner on Jul 1, 2025. It is now read-only.
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions torch_glow/tests/functionality/to_glow_tuple_output_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,29 @@
import torch_glow


class Foo(torch.nn.Module):
class TwoTupleModule(torch.nn.Module):
def __init__(self):
super(Foo, self).__init__()
super(TwoTupleModule, self).__init__()

def forward(self, x):
y = 2 * x
return x, y
return (x, y)


class OneTupleModule(torch.nn.Module):
def __init__(self):
super(OneTupleModule, self).__init__()

def forward(self, x):
y = 2 * x
return (y,)


class TestToGlowTupleOutput(unittest.TestCase):
def test_to_glow_tuple_output(self):
def tuple_test_helper(self, ModType):
input = torch.randn(4)

model = Foo()
model = ModType()

spec = torch_glow.CompilationSpec()
spec.get_settings().set_glow_backend("Interpreter")
Expand All @@ -36,10 +45,19 @@ def test_to_glow_tuple_output(self):
lowered_model = torch_glow.to_glow(scripted_mod, {"forward": spec})

# Run Glow model
(gx, gy) = lowered_model(input)
g = lowered_model(input)

# Run reference model
(tx, ty) = model(input)
t = model(input)

self.assertEqual(type(g), type(t))
self.assertEqual(len(g), len(t))

for (gi, ti) in zip(g, t):
self.assertTrue(torch.allclose(gi, ti))

def test_to_glow_one_tuple_output(self):
self.tuple_test_helper(OneTupleModule)

assert torch.allclose(tx, gx)
assert torch.allclose(ty, gy)
def test_to_glow_two_tuple_output(self):
self.tuple_test_helper(TwoTupleModule)