1515
1616MODEL_OPTS = {
1717 '--sharding' : {
18- 'choices' : ['batch' , 'megatron-lm' ],
18+ 'choices' : ['batch' , 'megatron-lm' , 'fsdp' ],
1919 'nargs' : '+' ,
2020 'default' : [],
2121 },
@@ -58,7 +58,6 @@ def forward(self, x):
5858
5959def train ():
6060 print ('===> Preparing data..' )
61- num_epochs = 18
6261 lr = 0.1
6362 train_loader = xu .SampleGenerator (
6463 data = (torch .zeros (FLAGS .batch_size , FLAGS .input_dim ),
@@ -78,6 +77,14 @@ def train():
7877 train_loader = pl .MpDeviceLoader (
7978 train_loader , device , input_sharding = xs .ShardingSpec (mesh , (0 , 1 )))
8079
80+ if 'fsdp' in FLAGS .sharding :
81+ train_loader = pl .MpDeviceLoader (
82+ train_loader , device , input_sharding = xs .ShardingSpec (mesh , (0 , 1 )))
83+ print ('Sharding model weights' )
84+ # Shard the weights according to their 0th dim
85+ xs .mark_sharding (model .fc1 .weight , mesh , (0 , 1 ))
86+ xs .mark_sharding (model .fc2 .weight , mesh , (0 , 1 ))
87+
8188 if 'megatron-lm' in FLAGS .sharding :
8289 print ('Sharding model weights' )
8390 # Shard the first layer's weights row-wise
0 commit comments