In [7]:
def print_1f1b_timeline(num_stages=4, num_microbatches=8):
    timeline = []
    max_time = 2 * num_microbatches + num_stages - 2
    
    # Initialize empty timeline
    for _ in range(max_time):
        timeline.append(['-'] * num_stages)
    
    # Fill in all forward passes first
    for mb in range(1, num_microbatches + 1):
        start_time = mb - 1
        for stage in range(num_stages):
            if start_time + stage < max_time:
                timeline[start_time + stage][stage] = f'F({mb})'
    
    # Fill in backward passes
    warmup_steps = num_stages - 1  # Time needed for first microbatch to complete forward
    for mb in range(1, num_microbatches + 1):
        start_time = warmup_steps + mb
        for stage in range(num_stages - 1, -1, -1):
            if start_time + (num_stages - 1 - stage) < max_time:
                timeline[start_time + (num_stages - 1 - stage)][stage] = f'B({mb})'
    
    # Print timeline
    print(f"\nPipeline Timeline ({num_stages} stages, {num_microbatches} microbatches):")
    print("Time", end="\t")
    for i in range(num_stages):
        print(f"Stage {i}", end="\t")
    print()
    print("-" * 50)
    
    for t in range(max_time):
        print(f"{t:2d}", end="\t")
        for s in range(num_stages):
            print(f"{timeline[t][s]:6}", end="\t")
        print()

print_1f1b_timeline()


Pipeline Timeline (4 stages, 8 microbatches):
Time	Stage 0	Stage 1	Stage 2	Stage 3	
--------------------------------------------------
 0	F(1)  	-     	-     	-     	
 1	F(2)  	F(1)  	-     	-     	
 2	F(3)  	F(2)  	F(1)  	-     	
 3	F(4)  	F(3)  	F(2)  	F(1)  	
 4	F(5)  	F(4)  	F(3)  	B(1)  	
 5	F(6)  	F(5)  	B(1)  	B(2)  	
 6	F(7)  	B(1)  	B(2)  	B(3)  	
 7	B(1)  	B(2)  	B(3)  	B(4)  	
 8	B(2)  	B(3)  	B(4)  	B(5)  	
 9	B(3)  	B(4)  	B(5)  	B(6)  	
10	B(4)  	B(5)  	B(6)  	B(7)  	
11	B(5)  	B(6)  	B(7)  	B(8)  	
12	B(6)  	B(7)  	B(8)  	-     	
13	B(7)  	B(8)  	-     	-     	
14	B(8)  	-     	-     	-     	
15	-     	-     	-     	-     	
16	-     	-     	-     	-     	
17	-     	-     	-     	-     	


In [9]:
def generate_1f1b_pipeline_schedule(num_stages, num_microbatches):
    """
    Generate and print a 1F1B pipeline schedule with both forward and backward passes.
    
    Args:
        num_stages (int): Number of pipeline stages.
        num_microbatches (int): Number of microbatches.
    """
    S = num_stages
    M = num_microbatches
    
    # Total time steps:
    # Forward fill: (S - 1)
    # Forward steady: M
    # Backward steady: M
    # Backward drain: (S - 1)
    # Total = 2 * (M + S - 1)
    total_steps = 2 * (M + S - 1)
    
    # Create a schedule matrix: schedule[time_step][stage]
    # Each cell will be a string: "F(m)" for forward microbatch m, "B(m)" for backward microbatch m, or "-"
    schedule = []
    for t in range(total_steps):
        row = []
        for i in range(S):
            # Determine if this time-step and stage is doing forward or backward or idle.
            
            # Check forward pass:
            # forward microbatch m_f at stage i at time t => t = (m_f - 1) + i
            # => m_f = t - i + 1
            m_f = t - i + 1
            # Forward is valid if 1 <= m_f <= M and also occurs before backward starts dominating the timeline.
            # Actually forward can continue until it has fed all microbatches. The last forward microbatch starts at stage 0 at time M-1.
            # The largest time for forward at stage i is (M - 1) + i.
            forward_valid = (1 <= m_f <= M) and (t <= (M - 1) + i)
            
            # Check backward pass:
            # backward microbatch m_b at stage i:
            # t = (S - 1 + M) + (m_b - 1) + ((S - 1) - i)
            # => m_b = t - ((S - 1 + M) + ((S - 1) - i)) + 1
            # Let's solve this cleanly:
            # offset = (S - 1 + M) + (S - 1 - i) = 2*S - 2 - i + M
            offset = (S - 1 + M) + ((S - 1) - i)
            m_b = t - offset + 1
            backward_valid = (1 <= m_b <= M) and (t >= (S - 1 + M))
            
            # In a proper 1F1B schedule, a stage at a given time either does forward or backward, not both.
            # By construction, there should be no overlap in a correct 1F1B schedule.
            # The known pattern ensures that forward tasks for a given stage will be done before backward tasks for that same microbatch reach it.
            
            if forward_valid and not backward_valid:
                row.append(f"F({m_f})")
            elif backward_valid and not forward_valid:
                row.append(f"B({m_b})")
            else:
                row.append("-")
        
        schedule.append(row)
    
    # Print the schedule in a nice tabular format
    header = ["Time"] + [f"Stage {i}" for i in range(S)]
    col_widths = [max(len(str(x)) for x in col) for col in zip(*([header] + schedule))]
    
    header_str = "  ".join(h.ljust(w) for h, w in zip(header, col_widths))
    print(header_str)
    print("-" * len(header_str))
    
    for t, row in enumerate(schedule):
        row_str = "  ".join(str(cell).ljust(w) for cell, w in zip([t] + row, col_widths))
        print(row_str)

# Example usage:
if __name__ == "__main__":
    # Example: 4 stages, 8 microbatches
    # This will print a 1F1B schedule showing both forward (F(m)) and backward (B(m)) passes.
    generate_1f1b_pipeline_schedule(num_stages=4, num_microbatches=8)


Time  Stage 0  Stage 1  Stage 2
-------------------------------
0     F(1)     -        -      
1     F(2)     F(1)     -      
2     F(3)     F(2)     F(1)   
3     F(4)     F(3)     F(2)   
4     F(5)     F(4)     F(3)   
5     F(6)     F(5)     F(4)   
6     F(7)     F(6)     F(5)   
7     F(8)     F(7)     F(6)   
8     -        F(8)     F(7)   
9     -        -        F(8)   
10    -        -        -      
11    -        -        -      
12    -        -        B(1)   
13    -        B(1)     B(2)   
14    B(1)     B(2)     B(3)   
15    B(2)     B(3)     B(4)   
16    B(3)     B(4)     B(5)   
17    B(4)     B(5)     B(6)   
18    B(5)     B(6)     B(7)   
19    B(6)     B(7)     B(8)   
20    B(7)     B(8)     -      
21    B(8)     -        -      
