Pipeline parallelism is a parallelization strategy that distributes the execution of a neural network's forward and backward passes across multiple devices. It does this by dividing the network into a sequence of consecutive subnetworks (partitions) and placing each partition on a different device. Each device stores and computes only its own partition, so a model that does not fit on a single device can still be trained. Splitting each batch into micro-batches then lets the devices work on different micro-batches at the same time.
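
To make "dividing the network into partitions" concrete, here is a minimal sketch that splits an ordered list of layers into `n` contiguous partitions, one per device. It assumes a network is just a list of layer callables and balances partitions by layer count only. The helpers `partition_layers` and `compose` are hypothetical names used for illustration; real systems usually balance partitions by measured compute or memory instead.

```python
from typing import Callable, List, Sequence

Layer = Callable[[float], float]


def partition_layers(layers: Sequence[Layer], n: int) -> List[List[Layer]]:
    """Split an ordered list of layers into n contiguous, roughly equal-sized partitions."""
    if not 1 <= n <= len(layers):
        raise ValueError("need 1 <= n <= number of layers")
    base, extra = divmod(len(layers), n)
    partitions, start = [], 0
    for j in range(n):
        # The first `extra` partitions get one additional layer.
        size = base + (1 if j < extra else 0)
        partitions.append(list(layers[start:start + size]))
        start += size
    return partitions


def compose(partition: Sequence[Layer]) -> Layer:
    """Turn one partition (a list of layers) into a single function f^j."""
    def f(x: float) -> float:
        for layer in partition:
            x = layer(x)
        return x
    return f


# Toy "network" of 5 layers (layer k adds k), split across 3 devices.
layers = [lambda x, k=k: x + k for k in range(5)]
parts = partition_layers(layers, 3)
print([len(p) for p in parts])  # [2, 2, 1]

# Running the partitions in order reproduces the full network: 1 + (0+1+2+3+4) = 11.
x = 1.0
for f_j in map(compose, parts):
    x = f_j(x)
print(x)  # 11.0
```

Contiguity matters: because partition $j$ consumes the output of partition $j-1$, each device only has to send activations to its neighbor.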

In pipeline parallelism, the forward and backward passes are decomposed into smaller tasks based on the micro-batches and partitions of the network. Let's break down how the forward and backward passes are decomposed.

Forward Pass Decomposition:

- Divide the input batch $x$ into $m$ micro-batches $x_1, \cdots, x_m$.
- Sequentially apply the partitions $f^1, \cdots, f^n$ (so that $f = f^n \circ \cdots \circ f^1$) to each micro-batch $x_i$. This yields the tasks $F_{i, j}$, where $x_i^0 = x_i$ and $x_i^j = f^j(x_i^{j-1})$ for $i = 1, \cdots, m$ and $j = 1, \cdots, n$.
- Obtain the output $f(x)$ by concatenating the final outputs of all micro-batches, $x_i^n = f(x_i)$.
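
The forward decomposition can be sketched in plain Python. This is an illustration only, assuming toy partitions that act elementwise on lists of numbers; `split_microbatches` and `forward` are hypothetical names. It runs the tasks one after another, so it shows the task structure rather than the concurrency: each inner-loop iteration is one task $F_{i,j}$, and concatenating the micro-batch outputs gives the same result as running the whole batch through $f$.

```python
from typing import Callable, List

Partition = Callable[[List[float]], List[float]]


def split_microbatches(x: List[float], m: int) -> List[List[float]]:
    """Divide a batch into at most m contiguous micro-batches (the last one may be smaller)."""
    size = max(1, -(-len(x) // m))  # ceiling division, at least 1
    return [x[k:k + size] for k in range(0, len(x), size)]


def forward(partitions: List[Partition], x: List[float], m: int) -> List[float]:
    outputs: List[float] = []
    for x_i in split_microbatches(x, m):  # micro-batch x_i, with x_i^0 = x_i
        for f_j in partitions:            # task F_{i,j}: x_i^j = f^j(x_i^{j-1})
            x_i = f_j(x_i)
        outputs.extend(x_i)               # collect x_i^n = f(x_i)
    return outputs                        # the concatenation is f(x)


partitions: List[Partition] = [
    lambda v: [2 * a for a in v],  # f^1
    lambda v: [a + 1 for a in v],  # f^2
    lambda v: [a * a for a in v],  # f^3
]
batch = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]

# Reference: push the whole batch through f = f^3 o f^2 o f^1 at once.
full = batch
for f_j in partitions:
    full = f_j(full)

print(forward(partitions, batch, m=4) == full)  # True
```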

Backward Pass Decomposition:

- Compute the gradient of the loss with respect to each micro-batch output, $dx_i^n$.
- Propagate each gradient backward through the partitions in reverse order, $j = n, \cdots, 1$. This yields the tasks $B_{i, j}$, where $dx_i^{j-1} = \partial_x f^j(dx_i^j)$ and $g_i^j = \partial_{\theta^j} f^j(dx_i^j)$ (vector-Jacobian products of $f^j$ evaluated at $x_i^{j-1}$) for $i = 1, \cdots, m$ and $j = 1, \cdots, n$.
- Compute the gradient of the loss with respect to the parameters $\theta^j$ of each partition by summing over micro-batches, $g^j = \sum_{i=1}^m g_i^j$.
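
The backward decomposition can be checked by hand on a toy model. This sketch assumes each partition is $f^j(v) = \theta^j v$ with a single scalar parameter, applied elementwise, and that the loss is the plain sum of all outputs, so every gradient can be written without an autograd library. `backward_grads` is a hypothetical helper. It shows that summing the per-micro-batch gradients $g_i^j$ reproduces the full-batch gradient (the case $m = 1$).

```python
from typing import List


def backward_grads(thetas: List[float], x: List[float], m: int) -> List[float]:
    """Return g^j for f^j(v) = theta_j * v (elementwise) and loss = sum of outputs."""
    n = len(thetas)
    size = max(1, -(-len(x) // m))  # micro-batch size (ceiling division)
    g = [0.0] * n                   # g^j, accumulated over micro-batches
    for start in range(0, len(x), size):
        # Forward tasks F_{i,j}: keep the activations x_i^0, ..., x_i^n for the backward pass.
        acts = [x[start:start + size]]
        for theta in thetas:
            acts.append([theta * a for a in acts[-1]])
        # dL/dx_i^n is 1 for every element because the loss is a plain sum.
        dx = [1.0] * len(acts[-1])
        # Backward tasks B_{i,j} in reverse partition order (0-based j = n-1, ..., 0).
        for j in reversed(range(n)):
            g_ij = sum(d * a for d, a in zip(dx, acts[j]))  # g_i^j, using the saved input x_i^{j-1}
            g[j] += g_ij                                    # g^j = sum over i of g_i^j
            dx = [thetas[j] * d for d in dx]                # dx_i^{j-1}
    return g


thetas = [2.0, 3.0, 0.5]
batch = [1.0, 2.0, 3.0, 4.0]
print(backward_grads(thetas, batch, m=4))  # [15.0, 10.0, 60.0]
print(backward_grads(thetas, batch, m=4) == backward_grads(thetas, batch, m=1))  # True
```

The values match the closed form: with $\sum x = 10$, $\partial L / \partial \theta^1 = \theta^2 \theta^3 \cdot 10 = 15$, and similarly for the other two partitions.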

Pipeline parallelism exploits this task structure. Each partition $f^j$ lives on its own device $j$, so all tasks $F_{i,j}$ and $B_{i,j}$ with the same partition index $j$ run on device $j$, while tasks with different partition indices can run concurrently on different micro-batches. Unlike data parallelism, no device holds a full copy of the model. The tasks have data dependencies that restrict the order in which they may run: $F_{i,j}$ needs the output of $F_{i,j-1}$, and $B_{i,j}$ needs the gradient produced by $B_{i,j+1}$ as well as the activations saved by $F_{i,j}$.
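
These dependency rules can be checked mechanically. The sketch below assumes tasks are identified as `("F", i, j)` or `("B", i, j)` with the 1-based indices used in the text; `prerequisites` and `respects_dependencies` are hypothetical helpers. It shows that one valid order is to run all forward tasks grouped by clock cycle and then all backward tasks in exactly the reverse order, while starting with the backward tasks is rejected.

```python
from typing import Iterable, List, Set, Tuple

Task = Tuple[str, int, int]  # ("F" or "B", micro-batch i, partition j), 1-based


def prerequisites(task: Task, n: int) -> List[Task]:
    kind, i, j = task
    if kind == "F":
        return [("F", i, j - 1)] if j > 1 else []  # needs x_i^{j-1}
    deps: List[Task] = [("F", i, j)]               # activations saved by the forward task
    if j < n:
        deps.append(("B", i, j + 1))               # gradient dx_i^j from the next partition
    return deps


def respects_dependencies(order: Iterable[Task], n: int) -> bool:
    done: Set[Task] = set()
    for task in order:
        if any(dep not in done for dep in prerequisites(task, n)):
            return False
        done.add(task)
    return True


m, n = 4, 3
# Forward tasks sorted by clock cycle (i + j), then by partition within a cycle.
forward_order = sorted(
    (("F", i, j) for i in range(1, m + 1) for j in range(1, n + 1)),
    key=lambda t: (t[1] + t[2], t[2]),
)
backward_order = [("B", i, j) for _, i, j in reversed(forward_order)]

print(respects_dependencies(forward_order + backward_order, n))  # True
print(respects_dependencies(backward_order + forward_order, n))  # False
```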

In summary, pipeline parallelism decomposes the forward and backward passes into tasks indexed by micro-batch and partition, places each partition on its own device, and overlaps the tasks of different micro-batches across devices. This enables efficient training of large neural networks across multiple devices.


##### Example 1

In [1]:
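def clock_cycle(m, n):
    """Yield, for each clock cycle k, the (micro-batch, partition) pairs that run at k (0-based)."""
    for k in range(m+n-1):
        # Partition j works on micro-batch k-j; keep only pairs with 0 <= k-j < m and 0 <= j < n.
        yield [(k-j, j) for j in range(max(1+k-m, 0), min(1+k, n))]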
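
A quick sanity check of the generator, with the function repeated so the block runs on its own: every task $(i, j)$ should be scheduled exactly once, and no partition should appear twice in the same clock cycle (a device can only run one task at a time). The values `m = 4`, `n = 3` match the cells below.

```python
def clock_cycle(m, n):
    # Same generator as above: 0-based (micro-batch, partition) pairs per clock cycle.
    for k in range(m + n - 1):
        yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))]


m, n = 4, 3
schedule = list(clock_cycle(m, n))

# Every task (micro-batch i, partition j) is scheduled exactly once.
tasks = [task for cycle in schedule for task in cycle]
assert sorted(tasks) == [(i, j) for i in range(m) for j in range(n)]

# No partition (device) is asked to run two tasks in the same clock cycle.
assert all(len({j for _, j in cycle}) == len(cycle) for cycle in schedule)

print(len(schedule), len(tasks))  # 6 12
```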

In [2]:
m = 4
n = 3

In [47]:
n_microbatches = m = 4

In [48]:
n_partritions = n = 3

In [53]:
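for clock_idx in range(n_microbatches+n_partritions-1):
    # Lowest active partition: the one holding the last micro-batch (0-based index n_microbatches-1).
    from_partrition = max(clock_idx+1-n_microbatches, 0)
    # Exclusive upper bound: micro-batch 0 has reached partition clock_idx, and there are only n_partritions partitions.
    to_partrition = min(clock_idx+1, n_partritions)
    
    print(f"from_partrition={from_partrition}") # debug print
    print(f"to_partrition={to_partrition}") # debug print

    results = []
    for partrition_idx in range(from_partrition, to_partrition):
        print(f"partrition_idx = {partrition_idx}") # debug print
        # Micro-batch i enters partition 0 at clock i, so partition j holds micro-batch clock_idx-j.
        microbatch_idx = clock_idx-partrition_idx
        # Store as 1-based (micro-batch, partition) to match the F_{i,j} notation.
        result = (microbatch_idx+1, partrition_idx+1)
        results.append(result)
    
    print(results)
    
    print("---------") # debug print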

from_partrition=0
to_partrition=1
partrition_idx = 0
[(1, 1)]
---------
from_partrition=0
to_partrition=2
partrition_idx = 0
partrition_idx = 1
[(2, 1), (1, 2)]
---------
from_partrition=0
to_partrition=3
partrition_idx = 0
partrition_idx = 1
partrition_idx = 2
[(3, 1), (2, 2), (1, 3)]
---------
from_partrition=0
to_partrition=3
partrition_idx = 0
partrition_idx = 1
partrition_idx = 2
[(4, 1), (3, 2), (2, 3)]
---------
from_partrition=1
to_partrition=3
partrition_idx = 1
partrition_idx = 2
[(4, 2), (3, 3)]
---------
from_partrition=2
to_partrition=3
partrition_idx = 2
[(4, 3)]
---------


- `from_partrition`: the lowest partition index that is active in clock cycle `clock_idx`.
    - Micro-batch `i` (0-based) enters partition 0 at clock `i` and advances one partition per clock, so at clock `clock_idx` it sits in partition `clock_idx - i`.
    - The last micro-batch, `i = n_microbatches - 1`, lags furthest behind, which gives `clock_idx - (n_microbatches - 1) = clock_idx + 1 - n_microbatches`.
    - `max(x, 0)`: we take the max with 0 because no partition index is below 0; the expression is negative while the last micro-batch has not yet entered the pipeline.

- `to_partrition`: one past the highest partition index that is active in clock cycle `clock_idx`, i.e. an exclusive upper bound, as `range` expects.
    - `clock_idx + 1`: the first micro-batch, `i = 0`, is furthest ahead, at partition `clock_idx`, so the exclusive bound is `clock_idx + 1`.
    - `min(x, n_partritions)`: we take the min with `n_partritions` (the total number of partitions) because no partition index goes past the last one, `n_partritions - 1`.


The range from `from_partrition` to `to_partrition` therefore covers every partition that is busy in clock cycle `clock_idx`, and each of those partitions processes exactly one micro-batch in that cycle.
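
One way to convince yourself the bounds are right is to compare them against a brute-force search: partition `j` is active at clock `k` exactly when `k - j` is a valid micro-batch index. The helper names below are hypothetical; the check runs over many small values of `m` and `n`.

```python
def active_partitions(k, m, n):
    """Partitions busy at clock k, using the from/to bounds explained above."""
    return list(range(max(k + 1 - m, 0), min(k + 1, n)))


def active_partitions_brute_force(k, m, n):
    """Partitions j that hold some micro-batch i = k - j with 0 <= i < m."""
    return [j for j in range(n) if 0 <= k - j < m]


for m in range(1, 6):
    for n in range(1, 6):
        for k in range(m + n - 1):
            assert active_partitions(k, m, n) == active_partitions_brute_force(k, m, n)
print("bounds match")
```

Algebraically, $0 \le k - j < m$ and $0 \le j < n$ together give $\max(k + 1 - m, 0) \le j < \min(k + 1, n)$, which is exactly the half-open range used in the loop.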



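- `(clock_idx-j, j)`: the (micro-batch, partition) pair for the task that partition `j` runs in clock cycle `clock_idx`. Micro-batch `i` reaches partition `j` at clock `i + j`, so the micro-batch on partition `j` at clock `clock_idx` is `clock_idx - j`. Both indices are 0-based here, whereas the loop above prints 1-based pairs.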
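
Laying the same schedule out as a grid makes the diagonal structure visible: each device starts one clock cycle after the previous one, and the idle cells in the corners are where devices wait (the pipeline "bubble"). This is a sketch; `schedule_grid` is a hypothetical helper, and micro-batches are shown 1-based.

```python
def schedule_grid(m, n):
    """Rows are partitions (devices), columns are clock cycles; '.' marks an idle device."""
    cycles = m + n - 1
    grid = [["." for _ in range(cycles)] for _ in range(n)]
    for k in range(cycles):
        for j in range(max(1 + k - m, 0), min(1 + k, n)):
            grid[j][k] = str(k - j + 1)  # partition j works on micro-batch k-j at clock k
    return [f"device {j + 1}: " + " ".join(row) for j, row in enumerate(grid)]


for line in schedule_grid(4, 3):
    print(line)
# device 1: 1 2 3 4 . .
# device 2: . 1 2 3 4 .
# device 3: . . 1 2 3 4
```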

In [21]:
for k in range(m+n-1):
    print([(k-j, j) for j in range(max(1+k-m, 0), min(1+k, n))])

[(0, 0)]
[(1, 0), (0, 1)]
[(2, 0), (1, 1), (0, 2)]
[(3, 0), (2, 1), (1, 2)]
[(3, 1), (2, 2)]
[(3, 2)]
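
These pairs describe the forward tasks. For the backward pass, each $B_{i,j}$ needs the gradient from $B_{i,j+1}$, so one simple choice, assumed here rather than taken from the cells above, is to replay the same clock cycles in reverse order: the task on partition $j+1$ then always runs one clock cycle before the task on partition $j$. `backward_clock_cycle` is a hypothetical helper; indices are 0-based.

```python
def clock_cycle(m, n):
    # Forward schedule: 0-based (micro-batch, partition) pairs per clock cycle.
    for k in range(m + n - 1):
        yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))]


def backward_clock_cycle(m, n):
    """Backward tasks B_{i,j}: the forward clock cycles replayed in reverse order."""
    yield from reversed(list(clock_cycle(m, n)))


backward = list(backward_clock_cycle(4, 3))
for cycle in backward:
    print(cycle)

# Check: B_{i,j+1} always runs in an earlier cycle than B_{i,j}.
when = {task: t for t, cycle in enumerate(backward) for task in cycle}
assert all(when[(i, j + 1)] < when[(i, j)] for (i, j) in when if (i, j + 1) in when)
```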


In [13]:
for i in range(2, 4):
    print(i)

2
3


##### Draft 2

In [None]:
import torch