Skip to content

Commit f0f9126

Browse files
authored
Handle dynamo function without input (#5565) (#5577)
1 parent 02858b9 commit f0f9126

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

test/dynamo/test_dynamo.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,20 @@ def test_simple_model(self):
123123
res_cpu_3 = self.fn_simple(x + y, y * 3)
124124
self.assertTrue(torch.allclose(res_cpu_3, res_xla_dynamo_3.cpu()))
125125

126+
def test_fn_without_input(self):
127+
128+
def fn_without_input(device):
129+
constant = 0.835
130+
expanded = torch.full((4, 4), constant, device=device)
131+
arange = torch.arange(16, device=device).reshape(4, 4)
132+
return expanded + arange
133+
134+
device = xm.xla_device()
135+
compiled_fn = torch.compile(fn_without_input, backend='openxla')
136+
res_cpu = fn_without_input('cpu')
137+
res_xla_dynamo = compiled_fn(device)
138+
self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))
139+
126140
def test_simple_model_with_in_place_ops(self):
127141

128142
class TestModel(nn.Module):

torch_xla/core/dynamo_bridge.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,6 @@ def optimized_mod(*args):
368368
if len(args_and_out) == 0:
369369
return ()
370370

371-
assert len(args) > 0 # can not handle no args case for now
372371
graph_input = graph_input_matcher(args)
373372
start_ts = time.time()
374373
res = torch_xla._XLAC._run_cached_graph(graph_hash, graph_input)

0 commit comments

Comments
 (0)