@@ -418,6 +418,32 @@ def count_number_of_sines(partition_fn):
418418 count_number_of_sines (min_cut_rematerialization_partition ), 10 )
419419 self .assertEqual (count_number_of_sines (default_partition ), 0 )
420420
421+ def test_scan_different_dtypes (self ):
422+ """Test that the combine function can output different dtypes."""
423+
424+ def fn (carry , x ):
425+ bf16_value , f32_value = x
426+ y = (torch .sin (bf16_value ), torch .sin (f32_value ))
427+ return torch .sin (carry ), y
428+
429+ init = torch .tensor ([0.0 , 0.0 ],
430+ requires_grad = True ,
431+ device = self .device ,
432+ dtype = torch .float16 )
433+ bf16_xs = torch .tensor ([[1.0 , 2.0 ], [3.0 , 4.0 ], [5.0 , 6.0 ]],
434+ requires_grad = True ,
435+ device = self .device ,
436+ dtype = torch .bfloat16 )
437+ f32_xs = torch .tensor ([[1.0 , 2.0 ], [3.0 , 4.0 ], [5.0 , 6.0 ]],
438+ requires_grad = True ,
439+ device = self .device ,
440+ dtype = torch .float32 )
441+ final_carry , ys = self .run_test (fn , init , (bf16_xs , f32_xs ))
442+ bf16_ys , f32_ys = ys
443+ self .assertEqual (final_carry .dtype , torch .float16 )
444+ self .assertEqual (bf16_ys .dtype , torch .bfloat16 )
445+ self .assertEqual (f32_ys .dtype , torch .float32 )
446+
421447
422448class PyTreeTest (TestBase ):
423449
0 commit comments