[Gradient Compression] Warm-start of PowerSGD #49451

Closed
wants to merge 8 commits
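For context, a minimal usage sketch (not part of this PR's diff) of how the powerSGD_hook with the new warm_start option would typically be registered on a DDP model; the process-group fallback comment is an assumption based on the hook's existing default behavior.

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD

# Assumes the usual env:// initialization variables (MASTER_ADDR, RANK, WORLD_SIZE, ...) are set.
dist.init_process_group("nccl")
model = DDP(nn.Linear(1024, 1024).cuda(), device_ids=[torch.cuda.current_device()])

state = powerSGD.PowerSGDState(
    process_group=None,           # assumption: None falls back to the default (WORLD) group
    matrix_approximation_rank=1,  # rank of the low-rank P/Q factors
    use_error_feedback=True,      # compensate the compression error locally
    warm_start=True,              # reuse P/Q buffers across iterations (added in this PR)
)
model.register_comm_hook(state, powerSGD.powerSGD_hook)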
25 changes: 16 additions & 9 deletions test/distributed/test_c10d.py
@@ -291,7 +291,7 @@ def _create_store(self):
return store

def test_address_already_in_use(self):
if sys.platform == 'win32':
if sys.platform == "win32":
err_msg_reg = "Only one usage of each socket address*"
else:
err_msg_reg = "^Address already in use$"
@@ -339,6 +339,7 @@ def _test_numkeys_delkeys(self, fs):
def test_numkeys_delkeys(self):
self._test_numkeys_delkeys(self._create_store())


class PrefixTCPStoreTest(TestCase, StoreTestBase):
def setUp(self):
super(PrefixTCPStoreTest, self).setUp()
@@ -3803,15 +3804,21 @@ def _test_powerSGD_ddp_comm_hook_nccl(self, gradient_as_bucket_view=False):
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)

# Get GPU model with the hook registered.
state = powerSGD.PowerSGDState(
process_group=process_group, matrix_approximation_rank=1
)
gpu_model = self._gpu_model_with_ddp_comm_hook(
process_group, powerSGD.powerSGD_hook, gradient_as_bucket_view, state
)
# Test the hook with different algorithmic configs.
for use_error_feedback, warm_start in product([True, False], [True, False]):
state = powerSGD.PowerSGDState(
process_group=process_group,
matrix_approximation_rank=1,
use_error_feedback=use_error_feedback,
warm_start=warm_start,
)
for hook in [powerSGD.powerSGD_hook, powerSGD.batched_powerSGD_hook]:
gpu_model = self._gpu_model_with_ddp_comm_hook(
process_group, hook, gradient_as_bucket_view, state
)

# check whether the grads are equal to what DDP without hook would return.
self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2))
# check whether the grads are equal to what DDP without hook would return.
self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2))

def _test_builtin_ddp_comm_hooks_nccl(self, gradient_as_bucket_view=False):
"""
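The hook change below moves the flattened P and Q buffers from per-call locals into per-bucket dictionaries on PowerSGDState, so they can be reused across iterations when warm_start is enabled. A simplified sketch of that allocate-or-reuse pattern (illustration only; the function name is hypothetical, the actual logic appears in the diff below):

import torch

def get_or_allocate(memory_dict, bucket_index, size, device, dtype, warm_start):
    # (Re)allocate if warm-start is off, the bucket is new, or the bucket was
    # rebuilt during training and the required flattened size changed.
    if (
        not warm_start
        or bucket_index not in memory_dict
        or memory_dict[bucket_index].shape[0] != size
    ):
        memory_dict[bucket_index] = torch.empty(size, device=device, dtype=dtype)
    return memory_dict[bucket_index]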
171 changes: 126 additions & 45 deletions torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -35,15 +35,19 @@ class PowerSGDState(object):
"process_group",
"matrix_approximation_rank",
"use_error_feedback",
"warm_start",
"rng",
"error_dict",
"p_memory_dict",
"q_memory_dict",
]

def __init__(
self,
process_group,
matrix_approximation_rank=1,
use_error_feedback=True,
warm_start=True,
random_seed=0,
):
self.process_group = process_group
@@ -59,6 +63,12 @@ def __init__(
# sometimes it is possible to converge to the optima without error feedback.
# See: http://proceedings.mlr.press/v54/yurtsever17a/yurtsever17a.pdf
self.use_error_feedback = use_error_feedback
# Warm-start reuses P(s) and Q(s) from the previous iteration.
# This can improve the approximation quality and hence improve the accuracy.
# Additionally, by avoiding the initialization of these low-rank tensors at every step,
# this can also accelerate training.
# However, this is at the cost of extra memory.
self.warm_start = warm_start
# The purpose of this RNG is to generate different random seeds for initializing Q across iterations,
# but in the same order for all the DDP replicas.
# Different random seeds across iterations indicate different 'projections' of the gradients at different SGD steps.
@@ -68,6 +78,8 @@ def __init__(
# Since there is only a single state instance for all the input buckets,
# need to maintain a dictionary that maps each bucket index to the local error.
self.error_dict = {}
self.p_memory_dict = {}
self.q_memory_dict = {}


def powerSGD_hook(
@@ -174,16 +186,35 @@ def powerSGD_hook(
if tensor.ndimension() > 1
]
total_Ps_size = 0
ps_memory = None # TODO(wayi): Store it in a dict of PowerState for warm-up.
total_Qs_size = 0
qs_memory = None # TODO(wayi): Store it in a dict of PowerState for warm-up.
for tensor in high_rank_tensors:
n, m = tensor.shape
matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
total_Ps_size += n * matrix_approximation_rank
total_Qs_size += m * matrix_approximation_rank
ps_memory = torch.empty(total_Ps_size, device=device, dtype=dtype)
qs_memory = torch.empty(total_Qs_size, device=device, dtype=dtype)
# Reuse Ps and Qs from the previous iteration if possible.
# The memory spaces of Ps and Qs need to be (re)allocated at the beginning,
# as well as later whenever the buckets are rebuilt during training.
if (
not state.warm_start
or bucket_index not in state.p_memory_dict
or state.p_memory_dict[bucket_index].shape[0] != total_Ps_size
or state.q_memory_dict[bucket_index].shape[0] != total_Qs_size
):
# If warm-start is disabled, low-rank tensors will be initialized at every step.
# Only log this if warm-start is enabled, to avoid spamming.
if state.warm_start:
logging.info(
"Allocating contiguous memory of length {} for Ps, and of length {} for Qs, respectively.".format(
total_Ps_size, total_Qs_size
)
)
state.p_memory_dict[bucket_index] = torch.empty(
total_Ps_size, device=device, dtype=dtype
)
state.q_memory_dict[bucket_index] = torch.empty(
total_Qs_size, device=device, dtype=dtype
)

# Create Ps and Qs that point to the allocated memory.
ps = []
@@ -194,14 +225,14 @@ def powerSGD_hook(
n, m = tensor.shape
matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
ps.append(
ps_memory[p_idx : p_idx + n * matrix_approximation_rank].view(
n, matrix_approximation_rank
)
state.p_memory_dict[bucket_index][
p_idx : p_idx + n * matrix_approximation_rank
].view(n, matrix_approximation_rank)
)
qs.append(
qs_memory[q_idx : q_idx + m * matrix_approximation_rank].view(
m, matrix_approximation_rank
)
state.q_memory_dict[bucket_index][
q_idx : q_idx + m * matrix_approximation_rank
].view(m, matrix_approximation_rank)
)
p_idx += n * matrix_approximation_rank
q_idx += m * matrix_approximation_rank
@@ -242,13 +273,15 @@ def unpack_rank1_tensors_and_allreduce_ps(fut):

# Since these Ps will be orthogonized later, no need to divide them by world size.
return [
dist.all_reduce(ps_memory, group=group_to_use, async_op=True)
dist.all_reduce(
state.p_memory_dict[bucket_index], group=group_to_use, async_op=True
)
.get_future()
.wait()[0]
]

def compute_qs(fut):
ps_memory = fut.wait()[0]
state.p_memory_dict[bucket_index] = fut.wait()[0]
for p in ps:
_orthogonalize(p)

@@ -258,13 +291,15 @@ def compute_qs(fut):

# Allreduce Qs.
return [
dist.all_reduce(qs_memory, group=group_to_use, async_op=True)
dist.all_reduce(
state.q_memory_dict[bucket_index], group=group_to_use, async_op=True
)
.get_future()
.wait()[0]
]

def decompress(fut):
qs_memory = fut.wait()[0].div_(world_size)
state.q_memory_dict[bucket_index] = fut.wait()[0].div_(world_size)

for p, q, tensor in zip(ps, qs, high_rank_tensors):
torch.matmul(p, q.t(), out=tensor)
@@ -274,6 +309,10 @@ def decompress(fut):
if state.use_error_feedback:
# Memorize the local errors.
state.error_dict[bucket_index] = input_tensor_cp - input_tensor
if not state.warm_start:
state.p_memory_dict.clear()
state.q_memory_dict.clear()

return [input_tensor]

return (
@@ -367,56 +406,98 @@ def batched_powerSGD_hook(
input_tensor_cp = torch.clone(input_tensor).detach()
matrix = input_tensor.view(square_side_length, square_side_length)

def create_low_rank_tensor(fill_random_values, rng):
"Returns a low-rank 2D tensor of square_side_length * matrix_approximation_rank."
if fill_random_values:
with torch.random.fork_rng(devices=[]):
# Fork this RNG to avoid changing the seed globally and affecting the random sampling anywhere else in the training.
# The seed makes sure that the initial random values are the same across all the DDP replicas.
# Such seed should differ at every step.
# Since it is very slow to fork RNG state across all the CUDA devices,
# only fork on CPU and then move the generated tensor to the CUDA device.
torch.manual_seed(rng.randint(1_000_000_000))
return torch.randn(
# Reuse P and Q from the previous iteration if possible.
# The memory spaces of P and Q need to be (re)allocated at the beginning,
# as well as later whenever the buckets are rebuilt during training.
if (
not state.warm_start
or bucket_index not in state.p_memory_dict
or state.p_memory_dict[bucket_index].shape
!= (square_side_length, state.matrix_approximation_rank)
):
# If warm-start is disabled, low-rank tensors will be initialized at every step.
# Only log this if warm-start is enabled, to avoid spamming.
if state.warm_start:
logging.info(
"Initializing low-rank tensors P and Q, each of which has a shape of {} x {}.".format(
square_side_length, state.matrix_approximation_rank
)
)

def create_low_rank_tensor(fill_random_values, rng):
"Returns a low-rank 2D tensor of square_side_length * matrix_approximation_rank."
if fill_random_values:
with torch.random.fork_rng(devices=[]):
# Fork this RNG to avoid changing the seed globally and affecting the random sampling
# anywhere else in the training.
# The seed makes sure that the initial random values are the same across all the DDP replicas.
# Such a seed should differ at every step.
# Since it is very slow to fork RNG state across all the CUDA devices,
# only fork on CPU and then move the generated tensor to the CUDA device.
torch.manual_seed(rng.randint(1_000_000_000))
return torch.randn(
square_side_length,
state.matrix_approximation_rank,
device="cpu",
dtype=input_tensor.dtype,
).to(device)
else:
return torch.empty(
square_side_length,
state.matrix_approximation_rank,
device="cpu",
device=device,
dtype=input_tensor.dtype,
).to(device)
else:
return torch.empty(
square_side_length,
state.matrix_approximation_rank,
device=device,
dtype=input_tensor.dtype,
)
)

p = create_low_rank_tensor(fill_random_values=False, rng=state.rng)
q = create_low_rank_tensor(fill_random_values=True, rng=state.rng)
_orthogonalize(q, 0)
state.p_memory_dict[bucket_index] = create_low_rank_tensor(
fill_random_values=False, rng=state.rng
)
state.q_memory_dict[bucket_index] = create_low_rank_tensor(
fill_random_values=True, rng=state.rng
)
_orthogonalize(state.q_memory_dict[bucket_index], 0)

torch.matmul(matrix, q, out=p)
allreduce_p_fut = dist.all_reduce(p, group=group_to_use, async_op=True).get_future()
torch.matmul(
matrix, state.q_memory_dict[bucket_index], out=state.p_memory_dict[bucket_index]
)
allreduce_p_fut = dist.all_reduce(
state.p_memory_dict[bucket_index], group=group_to_use, async_op=True
).get_future()

def compute_q(fut):
p = fut.value()[0]
_orthogonalize(p, 0)
state.p_memory_dict[bucket_index] = fut.value()[0]
_orthogonalize(state.p_memory_dict[bucket_index], 0)

torch.matmul(matrix.t(), p, out=q)
torch.matmul(
matrix.t(),
state.p_memory_dict[bucket_index],
out=state.q_memory_dict[bucket_index],
)

return [
dist.all_reduce(q, group=group_to_use, async_op=True).get_future().wait()[0]
dist.all_reduce(
state.q_memory_dict[bucket_index], group=group_to_use, async_op=True
)
.get_future()
.wait()[0]
]

def decompress(fut):
q = fut.value()[0].div_(world_size)
torch.matmul(p, q.t(), out=matrix)
state.q_memory_dict[bucket_index] = fut.value()[0].div_(world_size)
torch.matmul(
state.p_memory_dict[bucket_index],
state.q_memory_dict[bucket_index].t(),
out=matrix,
)

if state.use_error_feedback:
# Memorize the local errors.
state.error_dict[bucket_index] = input_tensor_cp - input_tensor
if torch.cuda.is_available():
torch.cuda.synchronize()
if not state.warm_start:
state.p_memory_dict.clear()
state.q_memory_dict.clear()
ret = input_tensor.resize_(total_length)
return [ret]

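For reference, both hooks perform a single power-iteration step per bucket to form a low-rank approximation M ≈ P Q^T. A standalone sketch of that step (illustration only; the real hooks use their own _orthogonalize helper and all-reduce P and Q across ranks between these steps):

import torch

torch.manual_seed(0)
n, m, rank = 64, 32, 1
M = torch.randn(n, m)              # stands in for a flattened gradient matrix

Q = torch.randn(m, rank)           # random projection, same seed on every rank
Q, _ = torch.linalg.qr(Q)          # orthogonalize Q
P = M @ Q                          # P step (all-reduced across ranks in the hook)
P, _ = torch.linalg.qr(P)          # orthogonalize P
Q = M.t() @ P                      # Q step (all-reduced and averaged in the hook)

M_hat = P @ Q.t()                  # low-rank reconstruction written back into the bucket
print("relative error:", ((M - M_hat).norm() / M.norm()).item())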