@@ -836,6 +836,10 @@ def test_mark_sharding_ir(self):
 
     self.assertTrue(torch.allclose(expected, actual.cpu()))
 
+  def _check_sharding_annotation(self, tensor, sharding_annotation):
+    hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
+    self.assertIn(sharding_annotation, hlo)
+
   @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                        "Multiple devices required for autograd sharding test")
   def test_mark_sharding_autograd(self):
@@ -849,9 +853,56 @@ def test_mark_sharding_autograd(self):
     t = y.sum()
     # Backward pass
     t.backward()
-    hlo = torch_xla._XLAC._get_xla_tensors_hlo([z.grad])
-    sharding_annotation = 'sharding={devices=[1,%d]' % self.n_devices
-    self.assertIn(sharding_annotation, hlo)
+    self._check_sharding_annotation(z.grad,
+                                    'sharding={devices=[1,%d]' % self.n_devices)
+
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for autograd sharding test")
+  def test_mark_sharding_aot_compile(self):
+    mesh = self._get_mesh((self.n_devices,))
+
+    def my_fn(x):
+      z = torch.sin(x)
+      y = MarkShardingFunction.apply(z, mesh, (0,))
+      return y + 42
+
+    from functorch.compile import aot_function, make_boxed_func  # type: ignore
+
+    x = torch.randn(8)
+    x = x.to('xla').requires_grad_(True)
+
+    graphs = []
+
+    def get_graph(gm: torch.fx.GraphModule, _):
+      graphs.append(gm)
+      return make_boxed_func(gm)
+
+    y = aot_function(my_fn, get_graph)(x)
+    t = y.sum()
+    t.backward()
+    torch_xla.sync()
+
+    sharding_spec = '{devices=[%d]' % self.n_devices
+
+    # Check that the output has sharding.
+    self.assertIn(sharding_spec, torch_xla._XLAC._get_xla_sharding_spec(y))
+
+    # Check that the gradient has sharding.
+    self.assertIsNotNone(x.grad)
+    self.assertIn(sharding_spec, torch_xla._XLAC._get_xla_sharding_spec(x.grad))
+
+    # Check that the AOTAutograd-captured graphs each also contain a mark_sharding.
+    fwd, bwd = graphs
+
+    inp = torch.randn(8).to('xla').requires_grad_(False)
+    out, *residuals = fwd(inp)
+    self._check_sharding_annotation(out,
+                                    'sharding={devices=[%d]' % self.n_devices)
+
+    tangents = torch.randn(8).to('xla').requires_grad_(False)
+    out, = bwd(*residuals, tangents)
+    self._check_sharding_annotation(out,
+                                    'sharding={devices=[%d]' % self.n_devices)
 
   def test_sharded_tensor_aliasing(self):
     met.clear_all()