Skip to content

Commit d3d00ac

Browse files
authored
[scan] Support different output dtypes (#8530)
1 parent a2d3d4d commit d3d00ac

File tree

2 files changed

+28
-3
lines changed

2 files changed

+28
-3
lines changed

test/scan/test_scan.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,32 @@ def count_number_of_sines(partition_fn):
418 418
count_number_of_sines(min_cut_rematerialization_partition), 10)
419 419
self.assertEqual(count_number_of_sines(default_partition), 0)
420 420

421+
def test_scan_different_dtypes(self):
422+
"""Test that the combine function can output different dtypes."""
423+
424+
def fn(carry, x):
425+
bf16_value, f32_value = x
426+
y = (torch.sin(bf16_value), torch.sin(f32_value))
427+
return torch.sin(carry), y
428+
429+
init = torch.tensor([0.0, 0.0],
430+
requires_grad=True,
431+
device=self.device,
432+
dtype=torch.float16)
433+
bf16_xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
434+
requires_grad=True,
435+
device=self.device,
436+
dtype=torch.bfloat16)
437+
f32_xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
438+
requires_grad=True,
439+
device=self.device,
440+
dtype=torch.float32)
441+
final_carry, ys = self.run_test(fn, init, (bf16_xs, f32_xs))
442+
bf16_ys, f32_ys = ys
443+
self.assertEqual(final_carry.dtype, torch.float16)
444+
self.assertEqual(bf16_ys.dtype, torch.bfloat16)
445+
self.assertEqual(f32_ys.dtype, torch.float32)
446+
421 447

422 448
class PyTreeTest(TestBase):
423 449

torch_xla/experimental/scan.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -501,16 +501,15 @@ def make_fake_tensor(v: torch.Tensor) -> torch.Tensor:
501 501
fn_carry_out, fn_y_out = split(fn_outputs, carry_len)
502 502
assert carry_len + y_len == len(fn_outputs)
503 503
fn_carry_shapes = [v.shape for v in fn_carry_out]
504-
fn_y_shapes = [v.shape for v in fn_y_out]
505 504
for fn_carry_shape, init_leaf in zip(fn_carry_shapes, init):
506 505
assert fn_carry_shape == init_leaf.shape, f"`fn` must keep the `carry` shape unchanged. \
507 506
Got {fn_carry_shape} but expected {init_leaf.shape}"
508 507

509 508
builder = Builder('scan')
510 509
num_iters = next(iter(tree_iter(xs))).size(0)
511 510
ys = [
512-
torch.zeros((num_iters, *fn_y_shape), device=device)
513-
for fn_y_shape in fn_y_shapes
511+
torch.zeros((num_iters, *v.shape), device=device, dtype=v.dtype)
512+
for v in fn_y_out
514 513
]
515 514
# Start the `curr_iter` loop variable at zero.
516 515
zero = torch.tensor(0, device=device)

0 commit comments

Comments (0)