diff --git a/torch/distributed/pipelining/_PipelineStage.py b/torch/distributed/pipelining/_PipelineStage.py
index b30d99366caf5..db0340677b172 100644
--- a/torch/distributed/pipelining/_PipelineStage.py
+++ b/torch/distributed/pipelining/_PipelineStage.py
@@ -7,6 +7,7 @@
 import torch
 import torch.distributed as dist
 import torch.fx as fx
+import torch.nn as nn
 from torch._subclasses.fake_tensor import FakeTensor
 from torch.distributed._composable.fsdp.fully_shard import FSDPModule
 from torch.fx.node import map_aggregate
@@ -55,11 +56,11 @@ def __repr__(self):
 
 
 def _make_tensor_from_meta(
-    example: FakeTensor,
+    example: Union[torch.Tensor, FakeTensor],
     device: torch.device,
 ) -> torch.Tensor:
     """
-    Create a real tensor from a fake tensor.
+    Create a real tensor from a real or fake example tensor.
     """
     return torch.empty(
         example.size(),
@@ -142,7 +143,7 @@ def __init__(
         self.log_prefix = f"[Stage {self.stage_index}]"
 
         # Forward infra
-        self.args_recv_info: Dict[int, Tuple[InputInfo]] = {}
+        self.args_recv_info: Dict[int, Tuple[InputInfo, ...]] = {}
         self.set_requires_grad: Dict[int, bool] = {}
 
         self.act_send_info: Dict[int, List] = {}
@@ -211,7 +212,7 @@ def _create_grad_recv_info(
 
     def _get_recv_ops(
         self,
-        recv_infos: Tuple[InputInfo],
+        recv_infos: Tuple[InputInfo, ...],
     ) -> List[dist.P2POp]:
         """
         Helper function shared by `get_fwd_recv_ops` and `get_bwd_recv_ops`.
@@ -239,7 +240,7 @@ def get_fwd_recv_ops(self) -> List[dist.P2POp]:
         Returns a list of ops that are needed to receive the input arguments
         for this stage.
         """
-        recv_infos: Tuple[InputInfo] = self.args_recv_info[self.fwd_chunk_id]
+        recv_infos: Tuple[InputInfo, ...] = self.args_recv_info[self.fwd_chunk_id]
 
         # In case there is backward pass, set requires_grad for receive buffers
         # before first forward
@@ -360,7 +361,7 @@ def clear_runtime_states(self) -> None:
 
     def _map_tensor_from_recv_info(
         self,
-        recv_infos: Tuple[InputInfo],
+        recv_infos: Tuple[InputInfo, ...],
     ):
         """
         Map tensors from recv infos to a list.
@@ -819,3 +820,402 @@ def __init__(
         # Get my pipe info
         pipe_info = pipe.info()
         super().__init__(stage_module, stage_index, pipe_info, device, group)
+
+
+# Manual PipelineStage functions and definition
+
+METADATA_TENSOR_LEN = 100
+PLACEHOLDER_VAL = -1
+
+
+def create_empty_tensors(
+    tensor: Union[torch.Tensor, List[torch.Tensor]], device: torch.device
+) -> List[torch.Tensor]:
+    """
+    Creates a list of empty tensors with the same properties (like shape and dtype) as the input tensor(s),
+    and places them on the specified device.
+    Args:
+        tensor (Union[torch.Tensor, List[torch.Tensor]]): The input tensor(s).
+        device (torch.device): The device where the new tensors will be placed.
+    Returns:
+        List[torch.Tensor]: A list of empty tensors with the same properties as the input tensor(s).
+    """
+    if isinstance(tensor, torch.Tensor):
+        return [torch.empty_like(tensor, device=device)]
+    elif isinstance(tensor, (list, tuple)):
+        return [torch.empty_like(t, device=device) for t in tensor]
+    raise TypeError(f"Unsupported type {type(tensor)} cannot create empty tensors")
+
+
+def create_metadata_tensor(
+    tensors: Optional[List[torch.Tensor]] = None,
+    device: Optional[torch.device] = torch.device("cpu"),
+) -> torch.Tensor:
+    """
+    Create a metadata tensor that can be sent over the wire.
+    This tensor contains the number of dimensions and the shape of each tensor being sent.
+
+    The data is of format [num_dims, dim1, dim2, ...].
+    If `tensors` is None, a tensor containing only placeholder values is returned.
+
+    Args:
+        tensors: A list of tensors; each tensor is converted into its shape
+            dimensions and these dimensions are concatenated.
+        device: The device where the metadata tensor will be created.
+            Defaults to CPU.
+
+    """
+    metadata_tensor = torch.full(
+        (METADATA_TENSOR_LEN,),
+        PLACEHOLDER_VAL,
+        dtype=torch.int32,
+        device=device,
+    )
+    if tensors:
+        # Create a list of tensors containing the number of dimensions and the shape of each tensor
+        data = [
+            # data is of format [num_dims, dim1, dim2, ...]
+            torch.tensor(
+                [len(tensor.shape)] + list(tensor.shape),
+                dtype=torch.int32,
+                device=device,
+            )
+            for tensor in tensors
+        ]
+        # Concatenate the data into a single tensor
+        data_tensor = torch.cat(data)
+        dt_shape = data_tensor.shape[0]
+        if dt_shape > METADATA_TENSOR_LEN:
+            raise ValueError(
+                f"Metadata tensor size ({dt_shape}) exceeds maximum allowed length ({METADATA_TENSOR_LEN})."
+            )
+        metadata_tensor[:dt_shape] = data_tensor
+    return metadata_tensor
+
+
+def extract_metadata_from_tensor(tensor: torch.Tensor) -> List[torch.Size]:
+    """
+    Extract the number of dimensions and the shape of each tensor from a metadata tensor.
+    """
+    metadata: List[torch.Size] = []
+    i = 0
+    while i < len(tensor) and tensor[i] != PLACEHOLDER_VAL:
+        num_dims = int(tensor[i].item())
+        shape = torch.Size(tensor[i + 1 : i + 1 + num_dims].tolist())
+        metadata.append(shape)
+        i += num_dims + 1
+    return metadata
+
+
+def get_stage_shapes(
+    stage_modules: List[nn.Module],
+    stage_ids: List[int],
+    num_stages: int,
+    rank: int,
+    world_size: int,
+    device: torch.device,
+    microbatch: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
+):
+    """
+    Performs a dry run through all the pipeline stages (a rank can have multiple pipeline stages in the case of
+    virtual pipelining) and returns the shapes of the inputs and outputs of each module.
+    Only the first stage needs to pass in a microbatch.
+
+    Every rank must call get_stage_shapes, otherwise the program will hang.
+
+    Args:
+        stage_modules: The module chunks assigned to this rank. The length should be 1 for any
+            non-interleaved schedule and >1 for any interleaved schedule.
+        stage_ids: The ids of the stages assigned to this rank.
+        num_stages: Total number of stages.
+        rank: Rank of the current process.
+        world_size: Number of processes participating in the pipeline.
+        device: Device where the tensors are allocated.
+
+    Returns a dictionary containing the following keys:
+        "inputs": Shapes of the inputs to the module
+        "outputs": Shapes of the outputs of the module
+    """
+
+    stage_id_to_shapes: Dict[int, Dict[str, List[torch.Size]]] = {}
+    for stage_id, model in zip(stage_ids, stage_modules):
+        input_shape_metadata_tensor = create_metadata_tensor(device=device)
+        # TODO: Assumes prev_stage == rank - 1 and next_stage == rank + 1
+        prev_rank = (rank - 1) % world_size
+        next_rank = (rank + 1) % world_size
+        shapes = {}
+
+        # first stage doesn't receive anything and uses a microbatch
+        if stage_id == 0:
+            if microbatch is None:
+                raise RuntimeError("Microbatch is required for first stage")
+            example_fwd_inputs = microbatch
+            if isinstance(example_fwd_inputs, torch.Tensor):
+                example_fwd_inputs = [example_fwd_inputs]
+        else:
+            # other stages must receive shape information
+            # TODO: send/recv should take a group, rather than use the default group
+            dist.recv(input_shape_metadata_tensor, prev_rank)
+            metadata = extract_metadata_from_tensor(input_shape_metadata_tensor)
+            example_fwd_inputs = [
+                torch.empty(shape_list, device=device) for shape_list in metadata
+            ]
+        shapes["inputs"] = [fwd_input.shape for fwd_input in example_fwd_inputs]
+
+        # perform forward
+        # TODO: if forward fails raise a more descriptive error explaining which stage failed
+        fwd_outputs = model(*example_fwd_inputs)
+        fwd_outputs = create_empty_tensors(fwd_outputs, device)
+        shapes["outputs"] = [fwd_output.shape for fwd_output in fwd_outputs]
+
+        # send shape dims
+        if stage_id != num_stages - 1:
+            output_shape_metadata_tensor = create_metadata_tensor(
+                fwd_outputs, device=device
+            )
+            dist.send(output_shape_metadata_tensor, next_rank)
+        stage_id_to_shapes[stage_id] = shapes
+    logger.info(stage_id_to_shapes)
+    return stage_id_to_shapes
+
+
+class ManualPipelineStage(PipelineStageBase):
+    """
+    A class representing a pipeline stage in a pipeline parallelism setup.
+    This class is created manually by providing an example input (and optionally output),
+    as opposed to the PipelineStage class that is output by pipeline().
+    This class extends the `PipelineStageBase` class and can similarly be used
+    in a `PipelineSchedule`.
+    Args:
+        submodule (nn.Module): The PyTorch module wrapped by this stage.
+        stage_index (int): The ID of this stage.
+        num_stages (int): The total number of stages.
+        device (torch.device): The device where this stage is located.
+        num_microbatches (int): The number of microbatches to use.
+        input_args (Union[torch.Tensor, List[torch.Tensor]]): The input arguments for the submodule.
+        output_args (Union[torch.Tensor, List[torch.Tensor]], optional): The output arguments for the submodule.
+        group (dist.ProcessGroup, optional): The process group for distributed training. If None, the default group is used.
+    """
+
+    def __init__(
+        self,
+        submodule: nn.Module,
+        stage_index: int,
+        num_stages: int,
+        device: torch.device,
+        num_microbatches: int,
+        input_args: Union[torch.Tensor, List[torch.Tensor]],
+        output_args: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
+        group: Optional[dist.ProcessGroup] = None,
+    ):
+        super().__init__(
+            submodule, stage_index, num_stages, device, num_microbatches, group
+        )
+        self.submod.to(self.device)
+        # When we materialize the model partition on cuda, we call reset_parameters() if it is available
+        self.inputs: List[torch.Tensor] = []
+        self.outputs: List[torch.Tensor] = []
+
+        self.inputs = create_empty_tensors(input_args, device)
+
+        if output_args is None:
+            logger.info("output_args not provided, performing forward using input_args")
+            self.outputs = self.submod(*self.inputs)
+            # create buffers for the output so that the data is in the correct
+            # shape in order to use in p2p op (send)
+            self.outputs = create_empty_tensors(self.outputs, device)
+        else:
+            self.outputs = create_empty_tensors(output_args, device)
+
+        # These are the buffers used in backward send/recv; they are allocated later
+        self.outputs_grad: List[torch.Tensor] = []
+
+        def stage_global_rank(peer_rank):
+            return (
+                peer_rank
+                if self.group is None
+                else dist.get_global_rank(self.group, peer_rank)
+            )
+
+        self.prev_stage = stage_global_rank((self.group_rank - 1) % self.group_size)
+        self.next_stage = stage_global_rank((self.group_rank + 1) % self.group_size)
+
+        # Receive info during forward
+        # TODO: create args_recv_info lazily? (same needed for PipelineStage)
+        for chunk_id in range(self.chunks):
+            self.set_requires_grad[chunk_id] = False
+            if not self.is_first:
+                # We assume that we always receive from stage - 1
+                recv_infos = tuple(
+                    [
+                        RecvInfo(
+                            f"recv_for_{self.stage_index}_from_{self.stage_index - 1}",
+                            self.stage_index - 1,
+                            _make_tensor_from_meta(inp, self.device),
+                        )
+                        for inp in self.inputs
+                    ]
+                )
+
+                self.args_recv_info[chunk_id] = recv_infos
+            else:
+                self.args_recv_info[chunk_id] = tuple(
+                    [RootArgPlaceholder() for _ in self.inputs]
+                )
+
+        # Send info during forward for each activation
+        # only need the rank that is being sent to
+        self.act_send_info: Dict[int, List] = {}
+        for idx in range(len(self.outputs)):
+            # We assume we always send to stage + 1
+            if not self.is_last:
+                self.act_send_info[idx] = [self.stage_index + 1]
+            else:
+                self.act_send_info[idx] = []
+
+        logger.debug(
+            f"finished pipeline stage init, {self.stage_index=}, {self.is_first=}, "  # noqa: G004
+            f"{self.is_last=}, {self.num_stages=}, "
+            f"inputs: {[inp.shape for inp in self.inputs]}, "
+            f"output: {[output.shape for output in self.outputs]}"
+        )
+
+    def _create_grad_recv_info(
+        self,
+        act_send_info: Dict,
+    ) -> Tuple[RecvInfo, ...]:
+        grad_recv_info: Tuple[RecvInfo, ...] = ()
+        if not self.is_last:
+            # Receiving gradients from multiple sources is not supported
+            # hence we only take the first destination
+            grad_recv_info = tuple(
+                [
+                    RecvInfo(
+                        f"recv_grad_for_{self.stage_index}_from_{dst_list[0]}",
+                        dst_list[0],
+                        _make_tensor_from_meta(self.outputs[idx], self.device),
+                    )
+                    for idx, dst_list in act_send_info.items()
+                ]
+            )
+        return grad_recv_info
+
+    def init_p2p_neighbors(self):
+        """
+        Set up p2p communicators between the previous and next stages
+        by sending a dummy tensor.
+
+        If this is used, it must be called on all pipeline stages.
+        """
+        ops = []
+        recv_tensor = torch.zeros(1, device=self.device)
+        send_tensor = torch.ones(1, device=self.device)
+        # forward
+        if not self.is_first:
+            ops.append(dist.P2POp(dist.irecv, recv_tensor, self.prev_stage, self.group))
+        if not self.is_last:
+            ops.append(dist.P2POp(dist.isend, send_tensor, self.next_stage, self.group))
+
+        # backward
+        if not self.is_first:
+            ops.append(dist.P2POp(dist.isend, send_tensor, self.prev_stage, self.group))
+        if not self.is_last:
+            ops.append(dist.P2POp(dist.irecv, recv_tensor, self.next_stage, self.group))
+
+        # Post the dummy sends/recvs and wait so the communicators are initialized
+        if ops:
+            works = dist.batch_isend_irecv(ops)
+            for work in works:
+                work.wait()
+        return True
+
+
+def validate_stage_shapes(pipeline_stages: List[ManualPipelineStage]):
+    """
+    Check that the buffer shapes match between stages as expected, by performing an all_gather
+    across all stages.
+    """
+    if len(pipeline_stages) == 0:
+        raise ValueError("No pipeline stages provided.")
+
+    virtual_pipeline_size = len(pipeline_stages)
+    all_inputs = []
+    all_outputs = []
+    world_size = pipeline_stages[0].group_size
+    num_stages = pipeline_stages[0].num_stages
+
+    # perform all gathers between all stages
+    for virtual_id, stage in enumerate(pipeline_stages):
+        stage_id: int = stage.stage_index
+        rank = stage.group_rank
+        # check that world_size and num_stages are consistent across all stages
+        if stage.group_size != world_size:
+            raise ValueError(
+                f"Stage id {stage_id} has world size ({stage.group_size}) "
+                f"which does not match world size ({world_size}) of other stages."
+            )
+        if stage.num_stages != num_stages:
+            raise ValueError(
+                f"Stage id {stage_id} has num stages ({stage.num_stages}) "
+                f"which does not match num stages ({num_stages}) of other stages."
+            )
+
+        pg_rank = dist.get_rank(stage.group)
+        if rank != pg_rank:
+            raise ValueError(
+                f"Rank {rank} is not equal to process group rank {pg_rank}"
+            )
+
+        if num_stages % world_size != 0:
+            raise ValueError(
+                f"Number of stages ({num_stages}) must be a multiple of the world_size ({world_size})"
+            )
+
+        # all-gather each rank's inputs
+        tensor_list = [
+            create_metadata_tensor(device=stage.device) for _ in range(stage.group_size)
+        ]
+        expected_inputs = stage.inputs
+        stage_input = create_metadata_tensor(expected_inputs, device=stage.device)
+        dist.all_gather(tensor_list, stage_input)
+        stage_input_shapes = [
+            extract_metadata_from_tensor(tensor) for tensor in tensor_list
+        ]
+
+        # all-gather each rank's outputs
+        tensor_list = [
+            create_metadata_tensor(device=stage.device) for _ in range(stage.group_size)
+        ]
+        expected_outputs = stage.outputs
+        stage_output = create_metadata_tensor(expected_outputs, device=stage.device)
+        dist.all_gather(tensor_list, stage_output)
+        stage_output_shapes = [
+            extract_metadata_from_tensor(tensor) for tensor in tensor_list
+        ]
+
+        logger.debug(
+            f"Rank: {pg_rank}, "  # noqa: G004
+            f"Stage id: {stage_id}, "
+            f"Stage num stages: {stage.num_stages}, "
+            f"Stage rank: {rank}, "
+            f"Stage world size: {world_size}, "
+            f"Stage {virtual_id * world_size}-{(virtual_id + 1) * world_size - 1} input shapes: {stage_input_shapes}, "  # noqa: G003
+            f"Stage {virtual_id * world_size}-{(virtual_id + 1) * world_size - 1} output shapes: {stage_output_shapes}"  # noqa: G003
+        )
+
+        all_inputs.extend(stage_input_shapes)
+        all_outputs.extend(stage_output_shapes)
+
+    # log only rank 0's view, they will all be equivalent
+    if pg_rank == 0:
+        logger.info(
+            f"all stage inputs: {all_inputs}, "  # noqa: G004
+            f"all stage outputs: {all_outputs}"
+        )
+
+    # Check that the output of stage i matches the input of stage i + 1
+    for i in range(virtual_pipeline_size * world_size - 1):
+        if (out := all_outputs[i]) != (inp := all_inputs[i + 1]):
+            raise ValueError(
+                f"Stage id {i} output shape {out} does not match stage id {i + 1} input shape {inp}."
+            )
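
Usage sketch (illustrative only, not part of the patch): one way ManualPipelineStage and validate_stage_shapes could be wired together, assuming torch.distributed is already launched with one pipeline stage per rank; the Linear module, tensor shapes, and microbatch count below are hypothetical.

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed.pipelining._PipelineStage import (
        ManualPipelineStage,
        validate_stage_shapes,
    )

    dist.init_process_group(backend="nccl")
    rank, world_size = dist.get_rank(), dist.get_world_size()
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")

    # This rank's model chunk and one example microbatch (hypothetical sizes).
    stage_module = nn.Linear(1024, 1024).to(device)
    example_input = torch.randn(8, 1024, device=device)

    stage = ManualPipelineStage(
        stage_module,
        stage_index=rank,
        num_stages=world_size,
        device=device,
        num_microbatches=4,
        input_args=example_input,
    )

    # All ranks check that adjacent stages agree on buffer shapes.
    validate_stage_shapes([stage])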