
[FSDP2] The evil record_stream in c10d causes FSDP2 to over-allocate GPU memory #147168

@leonardo0lyj


Hey Andrew @awgu,

As a big fan of FSDP2, I found a potential bug 😄

Demands:

  • No inter-stream memory fragmentation (incurred by copies across streams)
  • Explicit prefetch
  • CPU runs ahead of the GPU by a lot

_set_unshard_async_op(True)

To satisfy these demands, FSDP2 has to turn on _set_unshard_async_op(True) together with explicit prefetching via set_modules_to_forward_prefetch and set_modules_to_backward_prefetch.
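
For reference, a minimal sketch of that configuration on a toy block list (the layer sizes and names are made up for illustration; the full repro is at the end of this issue):

# Minimal sketch (illustrative only; see the full repro at the end of this issue).
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", (world_size,))  # world_size assumed to be defined
model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(4)]).cuda()

blocks = [fully_shard(block, mesh=mesh, reshard_after_forward=True) for block in model]
for prev, nxt in zip(blocks[:-1], blocks[1:]):
    prev.set_modules_to_forward_prefetch([nxt])   # explicitly prefetch the next block
    nxt.set_modules_to_backward_prefetch([prev])  # and the previous block in backward

root = fully_shard(model, mesh=mesh)
root._set_unshard_async_op(True)  # unshard all-gathers issued with async_op=True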

Memory Over-Allocation

Then memory over-allocation happens like this:

[screenshot: GPU memory over-allocation with _set_unshard_async_op(True)]

with memory traces:

[screenshots: memory trace snapshots]

Root Cause

These memory over-allocations are caused by the evil tensor.record_stream(ncclStream). Although FSDP2 tries to avoid this evil inherited from FSDP1, record_stream is still embedded in every c10d collective when async_op=True, so FSDP2 still suffers over-allocation through c10d.
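
For context, here is a small standalone sketch (my own illustration, not FSDP2 or c10d code) of why record_stream over-allocates: once a block is recorded on a side stream, the caching allocator will not reuse it until that stream's pending work finishes, so when the CPU runs far ahead each iteration grabs a fresh block and peak memory grows (exact numbers depend on timing):

# Standalone illustration of the record_stream effect (not FSDP2/c10d code).
import torch

side = torch.cuda.Stream()

def step():
    t = torch.empty(256 * 1024**2, device="cuda")  # ~1 GiB fp32 buffer
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        t.mul_(2)              # stand-in for a collective running on the comm stream
    t.record_stream(side)      # what c10d does internally when async_op=True
    # When t goes out of scope, the allocator defers reuse of its block until the
    # side stream's kernels finish; if the CPU is far ahead, the next step()
    # cannot reuse it and allocates a new ~1 GiB block instead.

for _ in range(4):
    step()
torch.cuda.synchronize()
print(f"peak reserved: {torch.cuda.max_memory_reserved() / 2**30:.2f} GiB")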

Candidate Solutions

I am not sure how we can avoid record_stream even when async_op=True.

IMO, the candidate solutions are:

  1. Make TORCH_NCCL_AVOID_RECORD_STREAMS=True the default value, getting rid of record_stream in c10d. (Safety should be fine without record_stream, since a collective with async_op=True usually starts and ends on the allocating stream, or users already know how to synchronize streams manually.) A sketch of trying this today is shown after this list.

  2. Expose avoiding record_stream as an advanced option on each collective, e.g. dist.all_gather(..., _avoid_record_stream=True). This narrows the scope of the TORCH_NCCL_AVOID_RECORD_STREAMS environment variable to each specific collective.

  3. Use only dist.all_gather(async_op=False) in FSDP2, but change the returned current_stream to the all_gather_stream, so that all-gather still allocates/frees on the current stream while it runs on all_gather_stream and overlaps with the current stream, just like async_op=True:

def get_all_gather_streams(
    self, async_op: bool, training_state: TrainingState
) -> tuple[torch.Stream, torch.Stream]:
    if not async_op and training_state in (
        TrainingState.FORWARD,
        TrainingState.PRE_BACKWARD,
    ):
        # Use separate streams for implicit prefetching
        return self.all_gather_copy_in_stream, self.all_gather_stream

    # Use separate streams for explicit prefetching!
    current_stream = self.device_handle.current_stream()
    return current_stream, self.all_gather_stream  # Change this!
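
For what it's worth, option 1 can already be approximated per run today via the existing environment variable; a sketch is below (exactly when ProcessGroupNCCL reads the flag is an implementation detail, so I set it before creating the process group to be safe):

# Sketch of trying option 1 today with the existing environment variable.
import os
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"  # set before the NCCL PG is created

import torch.distributed as dist
dist.init_process_group("nccl")  # c10d then avoids record_stream on collective tensors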

Which do you prefer?

(Let us make FSDP great again 😄)

Code

P.S. Here is the code to reproduce the over-allocation:

import functools
from datetime import datetime

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from torch.profiler import record_function
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


class MLP(nn.Module):
    def __init__(self, hidden_dim: int, bias: bool = False):
        super().__init__()
        self.fc1 = nn.Linear(hidden_dim, hidden_dim, bias=bias)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim, bias=bias)

    def forward(self, x):
        x = self.fc1(x)
        x = self.gelu(x)
        x = self.fc2(x)
        return x


class MultiMLP(nn.Module):
    def __init__(self, hidden_dim: int, bias: bool = False, layers: int = 4):
        super().__init__()
        self.pre_norm = nn.LayerNorm(hidden_dim, bias=bias)
        self.mlps = nn.ModuleList([MLP(hidden_dim, bias) for _ in range(layers)])
        self.post_norm = nn.LayerNorm(hidden_dim, bias=bias)

    def forward(self, x):
        x = self.pre_norm(x)
        for mlp in self.mlps:
            x = x + mlp(x)
        x = self.post_norm(x)
        return x

class TestMemory(DTensorTestBase):
    @with_comms
    def test_over_allocation(self):
        mesh = init_device_mesh("cuda", (self.world_size,))
        device = torch.device("cuda")
        hidden_dim = 10240
        total_bsz = 16

        # ----- init model --------
        torch.manual_seed(0)
        model = MultiMLP(hidden_dim=hidden_dim).to(device).to(torch.float32)

        # --------  fsdp2 wrap --------
        fully_shard_fn = functools.partial(
            fully_shard,
            mesh=mesh,
            reshard_after_forward=True,
        )

        last_fsdp_module = None
        for module in model.modules():
            if isinstance(module, MLP):
                fully_shard_fn(module)
                if last_fsdp_module is not None:
                    last_fsdp_module.set_modules_to_forward_prefetch([module])
                    module.set_modules_to_backward_prefetch([last_fsdp_module])
                last_fsdp_module = module
        fsdp_model = fully_shard_fn(model)
        fsdp_model._set_unshard_async_op(True)

        optim = torch.optim.Adam(fsdp_model.parameters())

        # ----- init data -----
        torch.manual_seed(self.rank)
        bsz = total_bsz // self.world_size

        # --------  training loop --------
        torch.distributed.barrier()
        torch.cuda.synchronize(self.rank)
        
        train_iter = 4
        for iter in range(train_iter):
            # torch.distributed.barrier()
            # torch.cuda.synchronize(self.rank)

            if self.rank == 0 and iter == train_iter - 1:
                torch.cuda.memory._record_memory_history(max_entries=int(1E6))

            with record_function("## zero grad ##"):
                optim.zero_grad()

            input = torch.randn((bsz, hidden_dim), device="cuda")

            with record_function(f"## forward ##"):
                output = fsdp_model(input)
                loss = output.mean()

            with record_function(f"## backward ##"):
                loss.backward()

            with record_function("## optimizer step ##"):
                optim.step()

            if self.rank == 0 and iter == train_iter - 1:
                timestamp = datetime.now().strftime("%b_%d_%H_%M_%S")
                file_name = f"mem_{timestamp}"
                torch.cuda.memory._dump_snapshot(f"{file_name}.pickle")
                torch.cuda.memory._record_memory_history(enabled=None)

        torch.distributed.barrier()
        torch.cuda.synchronize(self.rank)
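
To run the repro with torch's internal multi-process test harness, I believe a standard run_tests() entry point is enough (sketch below):

# Sketch of a test entry point; DTensorTestBase spawns one process per rank.
if __name__ == "__main__":
    from torch.testing._internal.common_utils import run_tests

    run_tests()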

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @zhaojuanmao @mrshenli @rohan-varma @chauhang
