diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
index 0c7fba859ea35..16baadd3c1c33 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -180,36 +180,34 @@ def powerSGD_hook(
     Once gradient tensors are aggregated across all workers, this hook applies
     compression as follows:
 
-    1. Views the input flattened 1D gradient tensor as two groups of per-parameter tensors: high-rank tensors and vector-like rank-1 tensors (for biases).
+    1. Views the input flattened 1D gradient tensor as a list of per-parameter tensors, and divides all the tensors into two groups:
 
-    2. Divides all the tensors into two groups:
+        1.1 The tensors that should be compressed before allreduce, because the compression can give enough saving in bandwidth.
 
-        2.1 High-rank tensors that can have enough saving in bandwidth after the compression should be compressed before allreduce.
+        1.2 Rest of the tensors will be directly allreduced without compression, including all the vector tensors (for biases).
 
-        2.2 Rest of the tensors will be directly allreduced without compression (this group is referred to as rank-1 tensors below).
+    2. Handles uncompressed tensors:
 
-    3. Handles rank-1 tensors by allreducing them without compression:
+        2.1. Allocate contiguous memory for those uncompressed tensors, and allreduces all the uncompressed tensors as a batch, without compression;
 
-        3.1. Allocate contiguous memory for those rank-1 tensors, and allreduces all the rank-1 tensors as a batch, without compression;
+        2.2. Copies the individual uncompressed tensors from the contiguous memory back to the input tensor.
 
-        3.2. Copies the individual rank-1 tensors from the contiguous memory back to the input tensor.
+    3. Handles the tensors that should be compressed by PowerSGD compression:
 
-    4. Handles high-rank tensors by PowerSGD compression:
-
-        4.1. For each high-rank tensor M, creates two low-rank tensors P and Q for decomposing M,
+        3.1. For each tensor M, creates two low-rank tensors P and Q for decomposing M,
         such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
 
-        4.2. Computes each P in Ps, which is equal to MQ;
+        3.2. Computes each P in Ps, which is equal to MQ;
 
-        4.3. Allreduces Ps as a batch;
+        3.3. Allreduces Ps as a batch;
 
-        4.4. Orthogonalizes each P in Ps;
+        3.4. Orthogonalizes each P in Ps;
 
-        4.5. Computes each Q in Qs, which is approximately equal to M^TP;
+        3.5. Computes each Q in Qs, which is approximately equal to M^TP;
 
-        4.6. Allreduces Qs as a batch;
+        3.6. Allreduces Qs as a batch;
 
-        4.7. Computes each M among all the high-rank tensors, which is approximately equal to PQ^T.
+        3.7. Computes each M among all the compressed tensors, which is approximately equal to PQ^T.
 
     Note that this communication hook enforces vanilla allreduce for the first ``state.start_powerSGD_iter`` iterations.
     This not only gives the user more control over the tradeoff between speedup and accuracy,
@@ -274,43 +272,32 @@ def powerSGD_hook(
 
     # Step I: Divide all the tensors into two groups,
     # one will be compressed before allreduce and the other will be directly allreduced without compression.
-    rank1_tensors, high_rank_tensors, high_rank_tensors_to_compress = [], [], []
-    for tensor in tensors:
-        if tensor.ndimension() <= 1:
-            rank1_tensors.append(tensor)
-        else:
-            high_rank_tensors.append(tensor.view(tensor.shape[0], -1))
-
+    tensors_to_compress, uncompressed_tensors = [], []
     total_Ps_size = 0
     total_Qs_size = 0
-
-    # Treat high-rank tensors that do not gain compression benefit as rank-1 tensors
-
-    while len(high_rank_tensors):
-        tensor = high_rank_tensors.pop()
-        n, m = tensor.shape
+    for tensor in tensors:
+        matrix = tensor.view(tensor.shape[0], -1)
+        n, m = matrix.shape
         matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
-
         if _should_compress(
             n, m, matrix_approximation_rank, state.min_compression_rate
         ):
-            high_rank_tensors_to_compress.append(tensor)
+            tensors_to_compress.append(matrix)
             total_Ps_size += n * matrix_approximation_rank
             total_Qs_size += m * matrix_approximation_rank
         else:
-            rank1_tensors.append(tensor.view(-1))
+            uncompressed_tensors.append(tensor)
 
-    # Step II: Handle rank-1 tensors (including the high-rank tensors that not worth compression).
-    # Allocate contiguous memory for rank-1 tensors to allreduce them without compression efficiently.
-    rank1_tensors_memory = (
-        torch.cat([tensor.view(-1) for tensor in rank1_tensors])
-        if rank1_tensors
+    # Step II: Handle uncompressed tensors.
+    # Allocate contiguous memory for these tensors to allreduce efficiently.
+    uncompressed_tensors_memory = (
+        torch.cat([tensor.view(-1) for tensor in uncompressed_tensors])
+        if uncompressed_tensors
         else torch.tensor([], device=device, dtype=dtype)
     )
 
-    # Step III: Handle high-rank tensors that should be compressed.
-    # Allocate contiguous memory for Ps and Qs to allreduce compressed high-rank tensors efficiently.
-
+    # Step III: Handle the tensors that should be compressed.
+    # Allocate contiguous memory for Ps and Qs to allreduce efficiently.
     # If warm-start is enabled, reuse Ps and Qs from the previous iteration if possible.
     # The memory spaces of Ps and Qs need to be allocated in the first iteration when PowerSGD is applied.
     need_randomize_qs = False
@@ -336,7 +323,7 @@ def powerSGD_hook(
     qs = []
     p_idx = 0
     q_idx = 0
-    for tensor in high_rank_tensors_to_compress:
+    for tensor in tensors_to_compress:
         n, m = tensor.shape
         matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
         ps.append(
@@ -376,22 +363,22 @@ def powerSGD_hook(
         _orthogonalize(q)
 
     # Compute Ps.
-    for tensor, q, p in zip(high_rank_tensors_to_compress, qs, ps):
+    for tensor, q, p in zip(tensors_to_compress, qs, ps):
         torch.matmul(tensor, q, out=p)
 
-    # This allreduce is only applied to rank-1 tensors,
-    # so it should have been kicked off before the above computation on the high-rank tensors to hide more communication costs.
+    # This allreduce is only applied to uncompressed tensors,
+    # so it should have been kicked off before the above computation on the compressed tensors to hide more communication costs.
     # However, this somehow requires a separate future chain at this time.
-    allreduce_contiguous_rank1_tensors_fut = dist.all_reduce(
-        rank1_tensors_memory, group=group_to_use, async_op=True
+    allreduce_contiguous_uncompressed_tensors_fut = dist.all_reduce(
+        uncompressed_tensors_memory, group=group_to_use, async_op=True
     ).get_future()
 
-    def unpack_rank1_tensors_and_allreduce_ps(fut):
-        rank1_tensors_memory = fut.value()[0].div_(world_size)
+    def unpack_uncompressed_tensors_and_allreduce_ps(fut):
+        uncompressed_tensors_memory = fut.value()[0].div_(world_size)
         idx = 0
-        for tensor in rank1_tensors:
-            tensor.copy_(rank1_tensors_memory[idx : idx + tensor.shape[0]])
-            idx += tensor.shape[0]
+        for tensor in uncompressed_tensors:
+            tensor.copy_(uncompressed_tensors_memory[idx : idx + tensor.numel()].view_as(tensor))
+            idx += tensor.numel()
 
         # Since these Ps will be orthogonalized later, no need to divide them by world size.
         return [
@@ -408,7 +395,7 @@ def compute_qs(fut):
             _orthogonalize(p)
 
         # Compute Qs.
-        for tensor, p, q in zip(high_rank_tensors_to_compress, ps, qs):
+        for tensor, p, q in zip(tensors_to_compress, ps, qs):
             torch.matmul(tensor.t(), p, out=q)
 
         # TODO: The above procedure does two matmul+allreduce steps per iteration --
@@ -427,7 +414,7 @@ def compute_qs(fut):
 
     def decompress(fut):
         state.q_memory_dict[bucket_index] = fut.value()[0].div_(world_size)
-        for p, q, tensor in zip(ps, qs, high_rank_tensors_to_compress):
+        for p, q, tensor in zip(ps, qs, tensors_to_compress):
             torch.matmul(p, q.t(), out=tensor)
         if torch.cuda.is_available():
             torch.cuda.synchronize(device)
@@ -444,8 +431,8 @@ def decompress(fut):
         return [input_tensor]
 
     return (
-        allreduce_contiguous_rank1_tensors_fut.then(
-            unpack_rank1_tensors_and_allreduce_ps
+        allreduce_contiguous_uncompressed_tensors_fut.then(
+            unpack_uncompressed_tensors_and_allreduce_ps
        )
        .then(compute_qs)
        .then(decompress)
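For anyone who wants to sanity-check the renamed compression path outside of DDP, below is a minimal single-process sketch of the math the hook performs after this change. It is an illustration only: `should_compress` and `powersgd_round_trip` are hypothetical names, `should_compress` is a simplified stand-in for the module's `_should_compress` criterion (not the exact formula), `torch.linalg.qr` stands in for the file's `_orthogonalize` helper, and the allreduce/future chaining is omitted entirely.

```python
import torch


def should_compress(n, m, rank, min_compression_rate):
    # Compression replaces n*m numbers with (n + m)*rank numbers; compress only
    # if the saving meets the requested rate. (Simplified stand-in, assumption.)
    return (n + m) * rank * min_compression_rate < n * m


def powersgd_round_trip(grads, matrix_approximation_rank=1, min_compression_rate=2):
    # Step I: split per-parameter gradients into compressed / uncompressed groups.
    tensors_to_compress, uncompressed_tensors = [], []
    for tensor in grads:
        matrix = tensor.view(tensor.shape[0], -1)
        n, m = matrix.shape
        rank = min(n, m, matrix_approximation_rank)
        if should_compress(n, m, rank, min_compression_rate):
            tensors_to_compress.append(matrix)
        else:
            uncompressed_tensors.append(tensor)

    # Step II (communication omitted): the hook flattens these into one
    # contiguous buffer and allreduces them as a single batch.
    uncompressed_tensors_memory = (
        torch.cat([t.view(-1) for t in uncompressed_tensors])
        if uncompressed_tensors
        else torch.tensor([])
    )

    # Step III: low-rank round trip M -> (P, Q) -> P Q^T for each compressed matrix.
    for matrix in tensors_to_compress:
        n, m = matrix.shape
        rank = min(n, m, matrix_approximation_rank)
        q, _ = torch.linalg.qr(torch.randn(m, rank))  # Q: standard normal, orthogonalized
        p = matrix @ q                                # P = M Q      (allreduced as a batch in the hook)
        p, _ = torch.linalg.qr(p)                     # orthogonalize P
        q = matrix.t() @ p                            # Q ~= M^T P   (allreduced as a batch in the hook)
        matrix.copy_(p @ q.t())                       # decompress:  M ~= P Q^T

    return uncompressed_tensors_memory


# A 64x64 weight gradient is worth compressing at rank 4; the bias vector is not.
grads = [torch.randn(64, 64), torch.randn(64)]
powersgd_round_trip(grads, matrix_approximation_rank=4)
```

Running this replaces the weight-like gradient with its rank-4 approximation in place while leaving the bias vector untouched, which mirrors the new grouping in Step I: vectors (and any matrix whose `(n + m) * rank` cost is not worth it) now flow through the uncompressed batch instead of a special rank-1 path.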