diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
index 17414df3024d..e1d475a34425 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -1,3 +1,4 @@
+import logging
 import math
 
 import numpy as np
@@ -63,16 +64,13 @@ def __init__(
         # there will be differences between the gradients that are never synchronized.
         self.rng = np.random.RandomState(random_seed)
         # Since there is only a single state instance for all the input buckets,
-        # need to maintain a dictionary that maps each bucket to the local error.
-        # TODO(wayi): Currently the key is the (hashcode of) input tensor, which may change across steps,
-        # since the bucket can be rebuilt in the forward pass (to save peak memory usage).
-        # Need to add an index field to the input bucket of comm hook.
+        # need to maintain a dictionary that maps each bucket index to the local error.
        self.error_dict = {}
 
 
 def powerSGD_hook(
     state: PowerSGDState,
-    bucket: dist._GradBucket,
+    bucket,
 ) -> torch.futures.Future:
     """
     This DDP communication hook implements a simplified PowerSGD gradient compression
@@ -127,11 +125,26 @@ def powerSGD_hook(
     input_tensor[total_length:padded_total_length].fill_(0)
 
     # Incorporate the error from the previous state into the gradients.
+    bucket_index = bucket.get_index()
     if state.use_error_feedback:
-        if input_tensor in state.error_dict:
-            input_tensor.add_(state.error_dict[input_tensor])
+        # The buckets can be rebuilt during training.
+        # In this case, the error tensor shape will not be aligned with the input tensor,
+        # and the error will be re-initialized as zeros.
+        if (
+            bucket_index in state.error_dict
+            and state.error_dict[bucket_index].shape[0] == padded_total_length
+        ):
+            input_tensor.add_(state.error_dict[bucket_index])
         else:
-            state.error_dict[input_tensor] = torch.zeros(padded_total_length, device=device)
+            logging.info(
+                "A zero tensor of length {} that represents local error is created.".format(
+                    padded_total_length
+                )
+            )
+            state.error_dict[bucket_index] = torch.zeros(
+                padded_total_length, device=device
+            )
+
     # Keep a copy of the input tensor,
     # so that we can compute the local error caused by compression later,
     # by comparing this copy and the input tensor updated after decompression.
@@ -181,7 +194,7 @@ def decompress(fut):
 
         if state.use_error_feedback:
             # Memorize the local errors.
-            state.error_dict[input_tensor] = input_tensor_cp - input_tensor
+            state.error_dict[bucket_index] = input_tensor_cp - input_tensor
         ret = input_tensor.resize_(total_length)
         return [ret]
 
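The patch switches the `error_dict` key from the (mutable) input tensor to the stable bucket index, and re-initializes the stored error to zeros whenever a rebuilt bucket's padded length no longer matches the stored tensor. Below is a minimal standalone sketch of that error-feedback bookkeeping pattern. It is illustrative only: `ErrorFeedbackState`, `apply_error_feedback`, and `topk_half` are hypothetical stand-ins, not PyTorch APIs, and `compress` abstracts PowerSGD's low-rank compress/decompress round trip.

```python
import torch


class ErrorFeedbackState:
    """Minimal stand-in for PowerSGDState: one error tensor per bucket index."""

    def __init__(self):
        self.error_dict = {}


def apply_error_feedback(state, bucket_index, grad, compress):
    """Add the stored error, compress lossily, and memorize the new residual.

    `compress` is any lossy map from a 1-D tensor to a 1-D tensor of the
    same shape (a stand-in for PowerSGD's compression + decompression).
    """
    # Reuse the stored error only if its shape still matches; otherwise the
    # bucket was rebuilt, so re-initialize the error as zeros (this mirrors
    # the shape check the patch adds).
    if (
        bucket_index in state.error_dict
        and state.error_dict[bucket_index].shape[0] == grad.shape[0]
    ):
        grad.add_(state.error_dict[bucket_index])
    else:
        state.error_dict[bucket_index] = torch.zeros_like(grad)

    grad_cp = grad.clone()         # copy before lossy compression
    decompressed = compress(grad)  # lossy round trip
    # The residual feeds back into this bucket's gradient next iteration.
    state.error_dict[bucket_index] = grad_cp - decompressed
    return decompressed


# Example: a crude "compressor" that keeps only the top half by magnitude.
def topk_half(t):
    out = torch.zeros_like(t)
    idx = t.abs().topk(t.numel() // 2).indices
    out[idx] = t[idx]
    return out


state = ErrorFeedbackState()
g = torch.randn(8)
for _ in range(3):
    # Entries dropped by compression are carried over via error_dict and
    # re-added on the next call with the same bucket index.
    apply_error_feedback(state, bucket_index=0, grad=g.clone(), compress=topk_half)
```

Keying by bucket index matters because, as the removed TODO notes, buckets can be rebuilt in the forward pass to save peak memory: the input tensor object (and hence its hash) changes across steps, so tensor-keyed entries would miss on lookup and accumulate stale state, while the bucket index stays stable.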