In [None]:
def pipeline_parallel_1f1b(model, data_loader, tensor_shapes, device):
    num_warmup_microbatches = min(process_group_manager.pp_world_size - process_group_manager.pp_rank - 1, data_loader.num_local_micro_batches)
    num_microbatches_remaining = data_loader.num_local_micro_batches - num_warmup_microbatches
    logging_loss, input_tensors, output_tensors  = 0.0, [], []
    
    def _forward_step(input_tensor):
        batch = next(iter(data_loader))
        batch["hidden_states"] = input_tensor
        output_tensor = model.forward(batch, device)
        if process_group_manager.is_pipeline_last_stage:
            output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
            nonlocal logging_loss
            logging_loss += output_tensor.item()
        return output_tensor

    for _ in range(num_warmup_microbatches): # Warmup forward passes
        input_tensor = communicate(shapes=tensor_shapes, dtype=torch.float32, operation='recv_forward')
        output_tensor = _forward_step(input_tensor)
        communicate(tensor=output_tensor, operation='send_forward')
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

    if num_microbatches_remaining > 0:
        input_tensor = communicate(shapes=tensor_shapes, dtype=torch.float32, operation='recv_forward')
    
    for i in range(num_microbatches_remaining):  # 1F1B steady state
        output_tensor = _forward_step(input_tensor)
        output_tensor_grad = bidirectional_communicate('send_fwd_recv_bwd', output_tensor, tensor_shapes, torch.float32, device)
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)
        input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
        input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
        if i == num_microbatches_remaining - 1: # last iteration
            input_tensor = None
            communicate(tensor=input_tensor_grad, operation='send_backward')
        else:
            input_tensor = bidirectional_communicate('send_bwd_recv_fwd', input_tensor_grad, tensor_shapes, torch.float32, device)

    for _ in range(num_warmup_microbatches): # Cooldown backward passes
        input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
        output_tensor_grad = communicate(shapes=tensor_shapes, dtype=torch.float32, operation='recv_backward')
        input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
        communicate(tensor=input_tensor_grad, operation='send_backward')
    return logging_loss

In [2]:
def print_pipeline_timeline(S, M):
    # Total steps for 1F1B schedule
    T = 2 * M + 2 * (S - 1)
    
    # Prepare a matrix to hold the schedule
    # schedule[time][stage] = string describing what's happening
    schedule = [["-" for _ in range(S)] for _ in range(T)]
    
    # Fill in forward (F) and backward (B) operations
    for t in range(T):
        for s in range(S):
            # Check for forward pass:
            # f = t - s + 1
            f = t - s + 1
            # Forward is valid if 1 <= f <= M and t <= M + S - 2
            if 1 <= f <= M and t <= M + S - 2:
                schedule[t][s] = f"F({f})"
            else:
                # Check for backward pass:
                # Backward microbatch at stage s:
                # Derived formula: t = 2S + b - s - 2  => b = t + s + 2 - 2S
                b = t + s + 2 - 2 * S
                # Backward starts after pipeline is initially filled, so t >= S
                # Also need b in range [1..M]
                if 1 <= b <= M and t >= S:
                    schedule[t][s] = f"B({b})"
                # else it remains "-"
    
    # Print the timeline
    print(f"Pipeline Timeline ({S} stages, {M} microbatches):")
    # Header
    header = "Time\t" + "\t".join([f"Stage {s}" for s in range(S)])
    print(header)
    print("-" * (len(header) + 10))
    
    for t in range(T):
        row_str = f"{t}\t" + "\t".join(schedule[t])
        print(row_str)

# Example usage:
# S = number of stages, M = number of microbatches
# The example in the prompt mentioned 4 stages and 8 microbatches.
print_pipeline_timeline(S=4, M=8)


Pipeline Timeline (4 stages, 8 microbatches):
Time	Stage 0	Stage 1	Stage 2	Stage 3
----------------------------------------------
0	F(1)	-	-	-
1	F(2)	F(1)	-	-
2	F(3)	F(2)	F(1)	-
3	F(4)	F(3)	F(2)	F(1)
4	F(5)	F(4)	F(3)	F(2)
5	F(6)	F(5)	F(4)	F(3)
6	F(7)	F(6)	F(5)	F(4)
7	F(8)	F(7)	F(6)	F(5)
8	B(2)	F(8)	F(7)	F(6)
9	B(3)	B(4)	F(8)	F(7)
10	B(4)	B(5)	B(6)	F(8)
11	B(5)	B(6)	B(7)	B(8)
12	B(6)	B(7)	B(8)	-
13	B(7)	B(8)	-	-
14	B(8)	-	-	-
15	-	-	-	-
16	-	-	-	-
17	-	-	-	-
18	-	-	-	-
19	-	-	-	-
20	-	-	-	-
21	-	-	-	-


In [None]:
state = PipelineTrainBatchState()

        outputs = []
        batch = iter(batch)

        current_pp_rank = dist.get_rank(pg)

        with attach_pipeline_state_to_model(model=model, pipeline_state=state):
            # Init
            for _ in range(pg.size() - current_pp_rank - 1):
                micro_batch = next(batch)
                context = self._get_fwd_context(model=model)
                output = self.forward(context=context, state=state, micro_batch=micro_batch, model=model)

                # TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage"
                for _ in range(len(state.microbatches_activations_to_send)):
                    send_activation = state.microbatches_activations_to_send.popleft()
                    # Execute
                    send_activation()

                # We make `output` a dict
                if not isinstance(output, dict):
                    output = {"loss": output}

                # Send tensors
                # TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage"
                for _ in range(len(state.microbatches_activations_to_send)):
                    send_activation = state.microbatches_activations_to_send.popleft()
                    # Execute
                    send_activation()

                # Store the loss for each microbatch
                if not isinstance(output["loss"], TensorPointer):
                    output = {k: v.detach() for k, v in output.items()}
                outputs.append(output)

            for micro_batch in batch:
                context = self._get_fwd_context(model=model)
                output = self.forward(context=context, state=state, micro_batch=micro_batch, model=model)

                # We make `output` a dict
                if not isinstance(output, dict):
                    output = {"loss": output}

                # Store the loss for each microbatch
                if not isinstance(output["loss"], TensorPointer):
                    output = {k: v.detach() for k, v in output.items()}
                outputs.append(output)

                # One backward
                context = self._get_bwd_context(
                    model=model,
                    nb_backwards=state.nb_backwards,
                    grad_accumulator=grad_accumulator,
                )
                self.backward(context=context, state=state, grad_accumulator=grad_accumulator)

            # Check figure in paper: The remain blocks are all backward and there is only `pg.size() - current_pp_rank - 1` blocks left
            assert len(state.microbatches_activations_requiring_backward) == pg.size() - current_pp_rank - 1
            # No more activation to send/recv
            assert (
                len(state.microbatches_activations_to_send) == 0
            ), f"There are activations left for me to send still: {len(state.microbatches_activations_to_send)}"
            assert (
                len(state.microbatches_activations_to_recv) == 0
            ), f"There are activations left for me to recv still: {len(state.microbatches_activations_to_recv)}"

            # Close: compute backward for the rest
            # TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage"
            for _ in range(len(state.microbatches_grads_to_send)):
                send_grads = state.microbatches_grads_to_send.popleft()
                # Execute
                send_grads()
            for _ in range(len(state.microbatches_activations_requiring_backward)):
                context = self._get_bwd_context(
                    model=model,
                    nb_backwards=state.nb_backwards,
                    grad_accumulator=grad_accumulator,
                )
                self.backward(context=context, state=state, grad_accumulator=grad_accumulator)

                # TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage"
                for _ in range(len(state.microbatches_grads_to_send)):
                    send_grads = state.microbatches_grads_to_send.popleft()
                    # Execute
                    send_grads()

            # Make sure that micro batches are all fully consumed
            state.check_buffers_empty()