11
11
12
12
13
13
def ddp_setup ():
14
- torch .cuda .set_device (int (os .environ ["LOCAL_RANK" ]))
15
- init_process_group (backend = "nccl" )
14
+ rank = int (os .environ ["LOCAL_RANK" ])
15
+ if torch .accelerator .is_available ():
16
+ device = torch .device (f"{ torch .accelerator .current_accelerator ()} :{ rank } " )
17
+ torch .accelerator .set_device_index (rank )
18
+ print (f"Running on rank { rank } on device { device } " )
19
+ else :
20
+ print (f"Multi-GPU environment not detected" )
21
+
22
+ backend = torch .distributed .get_default_backend_for_device (rank )
23
+ torch .distributed .init_process_group (backend = backend , rank = rank , device_id = rank )
24
+
16
25
17
26
class Trainer :
18
27
def __init__ (
@@ -38,7 +47,7 @@ def __init__(
38
47
self .model = DDP (self .model , device_ids = [self .local_rank ])
39
48
40
49
def _load_snapshot (self , snapshot_path ):
41
- loc = f"cuda: { self . local_rank } "
50
+ loc = str ( torch . accelerator . current_accelerator ())
42
51
snapshot = torch .load (snapshot_path , map_location = loc )
43
52
self .model .load_state_dict (snapshot ["MODEL_STATE" ])
44
53
self .epochs_run = snapshot ["EPOCHS_RUN" ]
@@ -104,8 +113,8 @@ def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str
104
113
if __name__ == "__main__" :
105
114
import argparse
106
115
parser = argparse .ArgumentParser (description = 'simple distributed training job' )
107
- parser .add_argument ('total_epochs' , type = int , help = 'Total epochs to train the model' )
108
- parser .add_argument ('save_every' , type = int , help = 'How often to save a snapshot' )
116
+ parser .add_argument ('total_epochs' , default = 50 , type = int , help = 'Total epochs to train the model' )
117
+ parser .add_argument ('save_every' , default = 5 , type = int , help = 'How often to save a snapshot' )
109
118
parser .add_argument ('--batch_size' , default = 32 , type = int , help = 'Input batch size on each device (default: 32)' )
110
119
args = parser .parse_args ()
111
120
0 commit comments