[Gradient Compression] Implement the original layerwise PowerSGD #49417

Closed
wants to merge 4 commits into from
11 changes: 11 additions & 0 deletions torch/distributed/algorithms/ddp_comm_hooks/__init__.py
@@ -66,6 +66,17 @@ class DDPCommHookType(Enum):
comm_hook=powerSGD.powerSGD_hook,
matrix_approximation_rank=2,
)
# Batching can lead to faster training at the cost of accuracy.
BATCHED_POWER_SGD = partial(
_powerSGD_comm_hook_wrapper,
comm_hook=powerSGD.batched_powerSGD_hook,
matrix_approximation_rank=1,
)
BATCHED_POWER_SGD_RANK2 = partial(
_powerSGD_comm_hook_wrapper,
comm_hook=powerSGD.batched_powerSGD_hook,
matrix_approximation_rank=2,
)


def register_ddp_comm_hook(
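
For orientation, a minimal usage sketch (not part of this PR) of attaching the new batched hook to a DDP model, following the Example:: block in the hook docstrings below; the commented enum-based path assumes register_ddp_comm_hook(comm_hook_type, model, state) forwards state as the process group, as the other PowerSGD entries appear to do.

import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD

def attach_batched_powersgd(ddp_model, process_group=None):
    # Build the compression state and register the batched hook directly.
    state = powerSGD.PowerSGDState(
        process_group=process_group, matrix_approximation_rank=1
    )
    ddp_model.register_comm_hook(state, powerSGD.batched_powerSGD_hook)
    # Assumed alternative via the enum entry added above:
    # register_ddp_comm_hook(DDPCommHookType.BATCHED_POWER_SGD, ddp_model, process_group)
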
187 changes: 185 additions & 2 deletions torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -75,10 +75,193 @@ def powerSGD_hook(
bucket,
) -> torch.futures.Future:
"""
This DDP communication hook implements a simplified PowerSGD gradient compression
This DDP communication hook implements the original PowerSGD gradient compression
algorithm described in https://arxiv.org/abs/1905.13727.
Once gradient tensors are aggregated across all workers, this hook applies
compression as follows:
1) Views the input flattened 1D gradient tensor as two groups of per-parameter tensors:
high-rank tensors and vector-like rank-1 tensors (for biases).
2) Handles rank-1 tensors by allreducing them without compression:
2.1) Allocates contiguous memory for those rank-1 tensors,
and allreduces all the rank-1 tensors as a batch, without compression;
2.2) Copies the individual rank-1 tensors from the contiguous memory back to the input tensor.
3) Handles high-rank tensors by PowerSGD compression:
3.1) For each high-rank tensor M, creates two low-rank tensors P and Q for decomposing M,
such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
3.2) Computes each P in Ps, which is equal to MQ;
3.3) Allreduces Ps as a batch;
3.4) Orthogonalizes each P in Ps;
3.5) Computes each Q in Qs, which is approximately equal to M^TP;
3.6) Allreduces Qs as a batch;
3.7) Computes each M among all the high-rank tensors, which is approximately equal to PQ^T.

TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration --
one left multiplication and one right multiplication.
For warm start, can take one such step at a time, and alternate between them.

Arguments:
state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
Note that since the DDP comm hook only supports the single-process single-device mode at this time,
exactly one tensor is stored in this bucket.

Returns:
Future handler of the communication, which updates the gradients in place.

Example::
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
>>> ddp_model.register_comm_hook(state, powerSGD_hook)
"""
process_group = state.process_group
group_to_use = process_group if process_group is not None else dist.group.WORLD
world_size = (
process_group.size() if process_group is not None else dist.get_world_size()
)

# The input tensor is a flattened 1D tensor.
input_tensor = bucket.get_tensors()[0]
device = input_tensor.device
dtype = input_tensor.dtype
# Unflatten the input tensor into per-parameter tensors, for layer-wise compression.
tensors = [
input_tensor[offset : offset + length].view(sizes)
for offset, length, sizes in zip(
bucket.get_offsets(), bucket.get_lengths(), bucket.get_sizes_list()
)
]

# Step I: Handle rank-1 tensors.
# Allocate contiguous memory for rank-1 tensors so they can be allreduced efficiently without compression.
rank1_tensors = [tensor for tensor in tensors if tensor.ndimension() <= 1]
rank1_tensors_memory = (
torch.cat([tensor.view(-1) for tensor in rank1_tensors])
if rank1_tensors
else torch.tensor([], device=device)
)

# Step II: Handle high-rank tensors.
# Allocate contiguous memory for Ps and Qs to allreduce compressed high-rank tensors efficiently.
high_rank_tensors = [
tensor.view(tensor.shape[0], -1)
for tensor in tensors
if tensor.ndimension() > 1
]
total_Ps_size = 0
ps_memory = None # TODO(wayi): Store it in a dict of PowerSGDState for warm-up.
total_Qs_size = 0
qs_memory = None # TODO(wayi): Store it in a dict of PowerSGDState for warm-up.
for tensor in high_rank_tensors:
n, m = tensor.shape
matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
total_Ps_size += n * matrix_approximation_rank
total_Qs_size += m * matrix_approximation_rank
ps_memory = torch.empty(total_Ps_size, device=device, dtype=dtype)
qs_memory = torch.empty(total_Qs_size, device=device, dtype=dtype)

# Create Ps and Qs that point to the allocated memory.
ps = []
qs = []
p_idx = 0
q_idx = 0
for tensor in high_rank_tensors:
n, m = tensor.shape
matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
ps.append(
ps_memory[p_idx : p_idx + n * matrix_approximation_rank].view(
n, matrix_approximation_rank
)
)
qs.append(
qs_memory[q_idx : q_idx + m * matrix_approximation_rank].view(
m, matrix_approximation_rank
)
)
p_idx += n * matrix_approximation_rank
q_idx += m * matrix_approximation_rank

# Initialize and then orthogonalize Qs.
with torch.random.fork_rng(devices=[]):
Member: Would it be reasonable to dedupe other use cases of this forking in grad compression to a helper function?

Contributor Author: Had the same feeling. However, in the first implementation, it has a loop in it: for q in qs:, between setting the manual seed and filling random values. It can be a bit tricky. Let me try to do it in a separate refactoring PR.
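
A rough sketch (hypothetical; the actual refactoring was deferred to a follow-up PR, as noted above) of what such a shared helper could look like:

import torch

def _fill_random_deterministic(tensors, seed, device, dtype):
    # Fork the CPU RNG so the global seed is untouched, seed it identically on
    # every replica, and fill each tensor with CPU-sampled values that are then
    # moved to the target device.
    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(seed)
        for t in tensors:
            t.data = torch.randn(*t.shape, device="cpu", dtype=dtype).to(device)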

# Fork this RNG to avoid changing the seed globally and affecting the random sampling anywhere else in the training.
# The seed makes sure that the initial random values are the same across all the DDP replicas.
# This seed should differ at every step.
# Since it is very slow to fork RNG state across all the CUDA devices,
# only fork on CPU and then move the generated tensor to the CUDA device.
torch.manual_seed(state.rng.randint(1_000_000_000))
for q in qs:
q.data = torch.randn(
*q.shape,
device="cpu",
dtype=dtype,
).to(device)
_orthogonalize(q)

# Compute Ps.
for tensor, q, p in zip(high_rank_tensors, qs, ps):
torch.matmul(tensor, q, out=p)

# This allreduce is only applied to rank-1 tensors,
# so ideally it should be kicked off before the above computation on the high-rank tensors,
# to overlap more communication with computation.
# However, that currently requires a separate future chain.
allreduce_contiguous_rank1_tensors_fut = dist.all_reduce(
rank1_tensors_memory, group=group_to_use, async_op=True
).get_future()

def unpack_rank1_tensors_and_allreduce_ps(fut):
rank1_tensors_memory = fut.value()[0].div_(world_size)
Member: Is there a reason this isn't fut.wait()? Other calls seem to use fut.wait().

Contributor Author: Good question!

wait() is already called once in the precursor callback, in the return statement, and it should only be called once.

value() means just reading the value without blocking, and it's the user's responsibility to ensure a proper wait before retrieving the value.

There is a recent PR that changed the semantics of value and wait. See: https://github.com/pytorch/pytorch/pull/48505/files#r532577873

I think I should change more wait() calls into value() in the original PowerSGD implementation. Previously, these two functions were kind of equivalent on the NCCL backend.
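
A condensed sketch of the pattern discussed above, mirroring the callback chain used by this hook; the buffer names are placeholders, and the wait()/value() semantics follow the explanation and the linked PR rather than a definitive spec:

import torch.distributed as dist

fut = dist.all_reduce(first_buffer, group=group_to_use, async_op=True).get_future()

def stage1(fut):
    # The callback only runs after fut completes, so value() is a non-blocking read.
    first_result = fut.value()[0].div_(world_size)
    # Kick off the next allreduce; wait() is called exactly once, in the return.
    return [
        dist.all_reduce(second_buffer, group=group_to_use, async_op=True)
        .get_future()
        .wait()[0]
    ]

def stage2(fut):
    # Again just read the completed value produced by stage1.
    return [fut.value()[0]]

chained_fut = fut.then(stage1).then(stage2)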

idx = 0
for tensor in rank1_tensors:
tensor.copy_(rank1_tensors_memory[idx : idx + tensor.shape[0]])
idx += tensor.shape[0]

# Since these Ps will be orthogonalized later, no need to divide them by world size.
return [
dist.all_reduce(ps_memory, group=group_to_use, async_op=True)
.get_future()
.wait()[0]
]

def compute_qs(fut):
ps_memory = fut.wait()[0]
for p in ps:
_orthogonalize(p)

# Compute Qs.
for tensor, p, q in zip(high_rank_tensors, ps, qs):
torch.matmul(tensor.t(), p, out=q)

# Allreduce Qs.
return [
dist.all_reduce(qs_memory, group=group_to_use, async_op=True)
.get_future()
.wait()[0]
]

def decompress(fut):
qs_memory = fut.wait()[0].div_(world_size)

for p, q, tensor in zip(ps, qs, high_rank_tensors):
torch.matmul(p, q.t(), out=tensor)
assert not torch.any(torch.isnan(tensor))
return [input_tensor]

return (
allreduce_contiguous_rank1_tensors_fut.then(
unpack_rank1_tensors_and_allreduce_ps
)
.then(compute_qs)
.then(decompress)
)
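
To make steps 3.1-3.7 in the docstring above concrete, a standalone single-tensor sketch of one compression round, with the allreduce steps omitted and torch.linalg.qr standing in for this module's _orthogonalize helper (all names here are illustrative):

import torch

def powersgd_round(M, rank):
    # Rank-r approximation of one high-rank gradient tensor M.
    n, m = M.shape
    r = min(n, m, rank)
    # 3.1) Q is initialized from a standard normal distribution and orthogonalized.
    Q = torch.randn(m, r, dtype=M.dtype, device=M.device)
    Q, _ = torch.linalg.qr(Q)
    # 3.2) P = MQ (the hook then allreduces all Ps as a batch).
    P = M @ Q
    # 3.4) Orthogonalize P.
    P, _ = torch.linalg.qr(P)
    # 3.5) Q is approximately M^T P (also allreduced as a batch in the hook).
    Q = M.t() @ P
    # 3.7) Decompress: M is approximately P Q^T.
    return P @ Q.t()

M = torch.randn(8, 6)
M_approx = powersgd_round(M, rank=2)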


def batched_powerSGD_hook(
state: PowerSGDState,
bucket,
) -> torch.futures.Future:
"""
This DDP communication hook implements a simplified PowerSGD gradient compression
algorithm described in https://arxiv.org/abs/1905.13727.
Once gradient tensors are aggregated across all workers, this hook applies
compression to the flattened input tensor that batches per-parameter tensors as follows:
1) Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings;
2) Creates two low-rank tensors P and Q for decomposing M,
such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
@@ -105,7 +288,7 @@ def powerSGD_hook(

Example::
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
>>> ddp_model.register_comm_hook(state, powerSGD_hook)
>>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)
"""
process_group = state.process_group
group_to_use = process_group if process_group is not None else dist.group.WORLD
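As a reading aid for step 1 of the batched variant above, a small sketch (names are illustrative, not taken from the hook) of viewing a flattened 1D gradient as a zero-padded square matrix:

import math
import torch

def view_as_padded_square(flat):
    # Round the side length up so that side * side >= numel, then zero-pad the tail.
    side = int(math.ceil(math.sqrt(flat.numel())))
    padded = torch.zeros(side * side, dtype=flat.dtype, device=flat.device)
    padded[: flat.numel()] = flat
    return padded.view(side, side)

square = view_as_padded_square(torch.randn(10))  # 10 elements become a 4 x 4 matrix with 6 zeros of padding
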
47 changes: 47 additions & 0 deletions torch/testing/_internal/distributed/distributed_test.py
@@ -16,6 +16,7 @@
import torch
import torch.cuda
import torch.distributed as dist
import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel.distributed import _dump_DDP_relevant_env_vars
import torch.nn as nn
@@ -2819,6 +2820,52 @@ def test_DistributedDataParallel_non_default_stream(self):
msg=f"Expected gradient of {expected_grad} but got {avg} on rank {self.rank}",
)

@unittest.skipIf(
BACKEND != "nccl",
"Only NCCL backend support DistributedDataParallel",
)
@skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
@skip_if_rocm
def test_DistributedDataParallel_powerSGD_ddp_comm_hook(self):
stream = torch.cuda.Stream(self.rank)
rank = self.rank
with torch.cuda.stream(stream):
net = torch.nn.parallel.DistributedDataParallel(
torch.nn.Linear(1, 5).to(rank), device_ids=[rank]
)
process_group = torch.distributed.new_group([0, 1])
Member: Should we test with the process_group being the entire world as well?

Contributor Author: I just followed some examples like test_ddp_uneven_inputs_replicated_error in the same file, and I didn't see it as necessary to test both cases here.

I can change torch.distributed.new_group([0, 1]) to list(range(0, dist.get_world_size())), or just use group, _, rank = self._init_global_test().

I plan to rewrite the unit test in a separate PR, as mentioned in some other comments.
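
For reference, the whole-world alternative mentioned above would amount to something like the following (hypothetical test change, not part of this PR; dist is torch.distributed):

process_group = torch.distributed.new_group(ranks=list(range(dist.get_world_size())))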

state = powerSGD.PowerSGDState(
process_group=process_group, matrix_approximation_rank=1
)
net.register_comm_hook(state=state, hook=powerSGD.powerSGD_hook)
# NOTE: batched_powerSGD_hook cannot pass the following test, because it has lower accuracy.
for i in range(1000):
# Clear gradients manually.
grad = net.module.weight.grad
if grad is not None:
grad.requires_grad_(False)
grad.zero_()
# Forward + BW
batch = torch.tensor([rank]).float().cuda(rank)
loss = net(batch).sum()
loss.backward()
# For each worker, the gradient on the weight should be worker_rank.
grad = net.module.weight.grad
avg = grad.clone()
# All-reducing the gradient averages should give us the gradient
# average. If not, then one of the workers has not correctly
# written back the averaged gradient before this all-reduce call.
dist.all_reduce(avg)
world_size = int(os.environ["WORLD_SIZE"])
avg.div_(world_size)
expected_grad = sum(i for i in range(world_size)) / world_size
self.assertEqual(
avg[0, 0],
expected_grad,
msg=f"Expected gradient of {expected_grad} but got {avg} on rank {self.rank}",
)


@unittest.skipIf(BACKEND != 'nccl' and BACKEND != 'gloo',
"Only Nccl & Gloo backend support DistributedDataParallel")
@skip_if_no_gpu