diff --git a/test/spmd/test_train_spmd_imagenet.py b/test/spmd/test_train_spmd_imagenet.py index 0c85b5c2f193..2e30961f76eb 100644 --- a/test/spmd/test_train_spmd_imagenet.py +++ b/test/spmd/test_train_spmd_imagenet.py @@ -228,7 +228,7 @@ def train_imagenet(): mesh_shape = (1, 1, num_devices // 2, 2) input_mesh = xs.Mesh(device_ids, mesh_shape, ('B', 'C', 'W', 'H')) print( - f'Sharding input images on spatial dimensions with mesh {mesh.get_logical_mesh()}' + f'Sharding input images on spatial dimensions with mesh {input_mesh.get_logical_mesh()}' ) writer = None