diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
index f83752ddc499..fc29f162790e 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -122,6 +122,33 @@ def powerSGD_hook(
     input_tensor = bucket.get_tensors()[0]
     device = input_tensor.device
     dtype = input_tensor.dtype
+
+    # Incorporate the error from the previous state into the gradients.
+    bucket_index = bucket.get_index()
+    input_tensor_cp = None
+    total_length = input_tensor.shape[0]
+    if state.use_error_feedback:
+        # The buckets can be rebuilt during training.
+        # In this case, the error tensor shape will not be aligned with the input tensor,
+        # and the error will be re-initialized as zeros.
+        if (
+            bucket_index in state.error_dict
+            and state.error_dict[bucket_index].shape[0] == total_length
+        ):
+            input_tensor.add_(state.error_dict[bucket_index])
+        else:
+            logging.info(
+                "A zero tensor of length {} that represents local error is created.".format(
+                    total_length
+                )
+            )
+            state.error_dict[bucket_index] = torch.zeros(total_length, device=device)
+
+        # Keep a copy of the input tensor,
+        # so that we can compute the local error caused by compression later,
+        # by comparing this copy and the input tensor updated after decompression.
+        input_tensor_cp = torch.clone(input_tensor).detach()
+
     # Unflatten the input tensor into per-parameter tensors, for layer-wise compression.
     tensors = [
         input_tensor[offset : offset + length].view(sizes)
@@ -242,6 +269,10 @@ def decompress(fut):
         for p, q, tensor in zip(ps, qs, high_rank_tensors):
             torch.matmul(p, q.t(), out=tensor)
             assert not torch.any(torch.isnan(tensor))
+
+        if state.use_error_feedback:
+            # Memorize the local errors.
+            state.error_dict[bucket_index] = input_tensor_cp - input_tensor
         return [input_tensor]
 
     return (
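
For context, below is a minimal standalone sketch (not part of the patch) of the error-feedback pattern the hook implements, stripped of DDP and bucket handling. The lossy_compress helper, tensor sizes, and training loop are illustrative assumptions, not the hook's API; the hook's actual compression is the layer-wise P * Q^T low-rank pipeline shown in the diff.

    import torch

    # Persistent residual, analogous to state.error_dict[bucket_index].
    error = torch.zeros(8)

    def lossy_compress(t):
        # Hypothetical stand-in for PowerSGD's low-rank approximation:
        # keep only the rank-1 component of a 2x4 view of the gradient.
        m = t.view(2, 4)
        u, s, vh = torch.linalg.svd(m, full_matrices=False)
        return (s[0] * torch.outer(u[:, 0], vh[0])).reshape(-1)

    for step in range(3):
        grad = torch.randn(8)
        grad += error                      # fold the previous residual back into the gradient
        grad_cp = grad.clone().detach()    # analogous to input_tensor_cp
        grad = lossy_compress(grad)        # stands in for compress + allreduce + decompress
        error = grad_cp - grad             # memorize what compression lost for the next step

The key invariant, mirrored by the patch, is that the copy is taken after the previous error is added back in, so the stored residual always measures exactly what the current compression step dropped.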