-from absl import logging
 import copy
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.optim as optim
 from torch.nn.parallel import DistributedDataParallel as DDP
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.xla_backend
-from torch_xla.experimental import pjrt
+
+
+# The following are helpers useful for debugging purposes.
+def comp_hook(state: object,
+              bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
+  """
+  Debug utils. Please refer to DistributedDataParallel.register_comm_hook to learn
+  how to use it.
+  """
18+ print ("comp_hook called." )
19+ fut = torch .futures .Future ()
20+ fut .set_result (bucket .buffer ())
21+ return fut
22+
23+
24+ def calculate_model_size (model ):
25+ """
26+ Debug utils. Calculate the given model's size in mb.
27+ """
28+ param_size = 0
29+ for param in model .parameters ():
30+ param_size += param .nelement () * param .element_size ()
31+ buffer_size = 0
32+ for buffer in model .buffers ():
33+ buffer_size += buffer .nelement () * buffer .element_size ()
34+
35+ size_all_mb = (param_size + buffer_size ) / 1024 ** 2
36+ print ('model size: {:.3f}MB' .format (size_all_mb ))
+
+
+class LargeNet(nn.Module):
+
+  def __init__(self):
+    super(LargeNet, self).__init__()
+    self.net1 = nn.Linear(10, 1000)
+    self.net2 = nn.Linear(1000, 1000)
+    self.net3 = nn.Linear(1000, 1000)
+    self.relu = nn.ReLU()
+    self.net4 = nn.Linear(1000, 10)
+
+  def forward(self, x):
+    output1 = self.relu(self.net1(x))
+    output2 = self.relu(self.net2(output1))
+    output3 = self.relu(self.net3(output2))
+    return self.net4(output3)
+
+
+class SmallNet(nn.Module):
+
+  def __init__(self):
+    super(SmallNet, self).__init__()
+    self.net = nn.Linear(10, 10)
+
+  def forward(self, x):
+    return self.net(x)
 
 
 def init_xla_backend(init_file: str):
@@ -40,17 +93,32 @@ def train_step(model, inputs, labels, optimizer, loss_fn):
   return loss
 
 
-def ddp_correctness(init_file: str):
+def ddp_correctness(init_file: str,
+                    *,
+                    use_large_net: bool = False,
+                    debug: bool = False):
   rank, world_size = init_xla_backend(init_file)
 
   device = xm.xla_device()
 
   # To make nn.Linear init same parameters across devices.
   torch.manual_seed(2022)
-  cpu_model = nn.Linear(10, 10)
+  # A lower range probably makes sense too. Anyway, stick to 100 as in the original PoC.
+  steps = 100
+  cpu_model = SmallNet()
+  if use_large_net:
+    steps = 5  # To save test time.
+    cpu_model = LargeNet()
+
   # TODO(@alanwaketan): Investigate whether we can omit the gradient_as_bucket_view option.
+  # bucket_cap_mb is set to 1 MB so that we can still have multiple all_reduces while
+  # avoiding models that are too large (the default bucket size is 25 MB).
+  # Note that DDP currently uses one bucket for the first iteration. See pytorch#73732.
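+  # For reference, LargeNet has roughly 2.02M float32 parameters (~7.7 MB), so a 1 MB
+  # bucket cap yields several gradient buckets, while SmallNet fits in a single bucket.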
   ddp_model = DDP(
-      copy.deepcopy(cpu_model).to(device), gradient_as_bucket_view=True)
+      copy.deepcopy(cpu_model).to(device),
+      gradient_as_bucket_view=True,
+      bucket_cap_mb=1)
+  # ddp_model.register_comm_hook(state=None, hook=comp_hook)
 
   cpu_optimizer = optim.SGD(cpu_model.parameters(), lr=1e-100)
   ddp_optimizer = optim.SGD(ddp_model.parameters(), lr=1e-100)
@@ -59,8 +127,7 @@ def ddp_correctness(init_file: str):
   local_batch_size = 2
   global_batch_size = local_batch_size * world_size
   offset = rank * local_batch_size
-  # Lower range probably makes sense too. Anyway, stick to 100 as the original PoC.
-  for step in range(100):
+  for step in range(steps):
     # To make torch.randn produce same results across devices.
     torch.manual_seed(2022 + step)
 
@@ -82,7 +149,9 @@ def ddp_correctness(init_file: str):
     # TODO(@alanwaketan): Investigate why the atol here is this low.
     assert torch.allclose(cpu_loss, ddp_loss, atol=1e-02)
     assert_all_close(cpu_model.parameters(), ddp_model.parameters())
-    # To display the below messages, set '--verbosity=1'.
-    logging.debug(
-        "iteration %d: cpu_loss = %f, ddp_loss = %f, cpu_model.parameters() ~= ddp_model.parameters()",
-        step, cpu_loss, ddp_loss)
+    # To display the messages below, pass '--debug'.
+    # We don't use FLAGS.debug here because this function often runs in different processes than the launcher.
+    if debug:
+      print(
+          f"iteration {step}: cpu_loss = {cpu_loss}, ddp_loss = {ddp_loss}, cpu_model.parameters() ~= ddp_model.parameters()"
+      )
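
For context, a minimal sketch of how a launcher might drive ddp_correctness via torch_xla's multiprocessing API. The _mp_fn helper, the temp-file handling, and the flag values below are assumptions; the actual launcher is not part of this diff.

import tempfile

import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index, init_file):
  # Each spawned process runs the correctness check against its own XLA device.
  ddp_correctness(init_file, use_large_net=False, debug=True)


if __name__ == '__main__':
  # A shared file, presumably consumed by init_xla_backend for process-group rendezvous.
  init_file = tempfile.mkstemp()[1]
  xmp.spawn(_mp_fn, args=(init_file,))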