
Commit d01880c

[SPMD] Add FSDP sharding for test_train_spmd_linear_model.py (#5299)
Summary: This diff adds FSDP sharding for test_train_spmd_linear_model.py.

Test Plan: PJRT_DEVICE=TPU XLA_USE_SPMD=1 python test/spmd/test_train_spmd_linear_model.py --sharding fsdp
1 parent: df680da

1 file changed: +9 -2 lines

test/spmd/test_train_spmd_linear_model.py

Lines changed: 9 additions & 2 deletions

@@ -15,7 +15,7 @@
 
 MODEL_OPTS = {
     '--sharding': {
-        'choices': ['batch', 'megatron-lm'],
+        'choices': ['batch', 'megatron-lm', 'fsdp'],
         'nargs': '+',
         'default': [],
     },
@@ -58,7 +58,6 @@ def forward(self, x):
 
 def train():
   print('===> Preparing data..')
-  num_epochs = 18
   lr = 0.1
   train_loader = xu.SampleGenerator(
       data=(torch.zeros(FLAGS.batch_size, FLAGS.input_dim),
@@ -78,6 +77,14 @@ def train():
     train_loader = pl.MpDeviceLoader(
         train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1)))
 
+  if 'fsdp' in FLAGS.sharding:
+    train_loader = pl.MpDeviceLoader(
+        train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1)))
+    print('Sharding model weights')
+    # Shard the weights according to their 0th dim
+    xs.mark_sharding(model.fc1.weight, mesh, (0, 1))
+    xs.mark_sharding(model.fc2.weight, mesh, (0, 1))
+
   if 'megatron-lm' in FLAGS.sharding:
     print('Sharding model weights')
     # Shard the first layer's weights row-wise
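For context, the new fsdp branch uses the experimental XLA SPMD mark_sharding API to split each weight along its 0th dim across the device mesh, with partition spec (0, 1) mapping weight dims to mesh axes. The following is a minimal, self-contained sketch of that idea, not code from this commit: the toy model, layer sizes, and (num_devices, 1) mesh construction are assumptions for illustration, and it presumes a torch_xla build where torch_xla.runtime.global_runtime_device_count() is available. Run it under XLA_USE_SPMD=1.

# Minimal sketch (assumptions noted above): shard 2D weights row-wise
# across a (num_devices, 1) mesh, mirroring the fsdp branch's (0, 1) spec.
import numpy as np
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

# Build a (num_devices, 1) device mesh; axis 0 is the sharded axis.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1))

# Toy two-layer model (illustrative; not the test's model).
model = nn.Sequential(nn.Linear(128, 64), nn.Linear(64, 10)).to(xm.xla_device())

# Partition spec (0, 1): weight dim 0 -> mesh axis 0 (split across devices),
# weight dim 1 -> mesh axis 1 (size 1, so effectively unsharded).
for layer in model:
  xs.mark_sharding(layer.weight, mesh, (0, 1))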
