##### Draft 1

The passage handles backward dependencies by introducing virtual dependencies built from two primitive functions, Fork and Join. These dependencies tell the autograd engine the order in which backward tasks must run. Fork maps a tensor $x$ to a pair $(x, \varnothing)$, where $\varnothing$ is an empty tensor, and Join maps a pair $(x, \varnothing)$ back to the tensor $x$. Neither function changes the data; together they only add an edge to the autograd graph. With them, the dependency of $F_{i+1, j}$ upon $F_{i, j}$ can be expressed, and because autograd traverses the graph in reverse, it becomes a dependency of $B_{i, j}$ upon $B_{i+1, j}$ in the backward pass. This keeps the backward tasks on each partition in the order the pipeline schedule expects, instead of an order chosen arbitrarily by the autograd engine.






Here's a simple concrete example:
Suppose we have a neural network with 2 partitions ($f^1$ and $f^2$) and 2 micro-batches ($x_1$ and $x_2$), and write $F_{i,j}$ and $B_{i,j}$ for the forward and backward tasks of micro-batch $x_i$ on partition $f^j$. The forward tasks are $F_{1,1}, F_{1,2}, F_{2,1}$, and $F_{2,2}$, and the backward tasks are $B_{1,1}, B_{1,2}, B_{2,1}$, and $B_{2,2}$.

During the forward pass, once $F_{1,1}$ finishes, Fork is applied to its output, returning that output together with an empty "phony" tensor. Before $F_{2,1}$ starts, Join is applied to its input $x_2$ together with the phony tensor. Join does not wait for or copy any data in the forward pass; it only records, in the autograd graph, that $F_{2,1}$ comes after $F_{1,1}$.

Because autograd walks that graph in reverse, the backward pass must then finish $B_{2,1}$ before it can start $B_{1,1}$, which is exactly the order the pipeline needs on partition $f^1$.
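The two-micro-batch example can be checked on CPU with a small sketch. It re-declares `Fork` and `Join` in the same shape as the cells below, except that the phony tensor is created directly instead of being cached; a scalar `w` stands in for partition $f^1$, and `Tag` is a hypothetical helper whose backward only records its name. With the Fork/Join edge in place, `B_{2,1}` should be logged before `B_{1,1}`.

```python
import torch

order = []  # names of backward tasks, in the order autograd runs them


class Tag(torch.autograd.Function):
    """Hypothetical helper: identity in forward, logs `name` in backward."""

    @staticmethod
    def forward(ctx, x, name):
        ctx.name = name
        return x.clone()

    @staticmethod
    def backward(ctx, grad):
        order.append(ctx.name)
        return grad, None  # no gradient for the string argument


class Fork(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Pass x through and emit an empty "phony" tensor as a second output.
        return x.detach(), torch.empty(0, device=x.device)

    @staticmethod
    def backward(ctx, grad_x, grad_phony):
        # Runs only after gradients for BOTH outputs have arrived.
        return grad_x


class Join(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, phony):
        # Data passes through unchanged; phony only adds a graph edge.
        return x.detach()

    @staticmethod
    def backward(ctx, grad_x):
        return grad_x, None


w = torch.tensor(2.0, requires_grad=True)      # stand-in for theta^1
x1, x2 = torch.tensor(1.0), torch.tensor(2.0)  # micro-batches

y1 = Tag.apply(x1 * w, "B_{1,1}")                      # F_{1,1}
y1, phony = Fork.apply(y1)                             # Fork after F_{1,1}
y2 = Tag.apply(Join.apply(x2, phony) * w, "B_{2,1}")   # Join before F_{2,1}

(y1.sum() + y2.sum()).backward()
print(order)   # ['B_{2,1}', 'B_{1,1}']
print(w.grad)  # tensor(3.)
```

Without the `Fork`/`Join` pair, nothing in the graph links the two branches, so their relative order would be left to the autograd engine's own scheduling.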

### Fork and Join

In [None]:
import torch

In [None]:
_phonies = {}

In [None]:
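def get_phony(device: torch.device, requires_grad: bool) -> torch.Tensor:
    """Return a cached empty ("phony") tensor for this device and requires_grad flag."""
    key = (device, requires_grad)
    
    try:
        phony = _phonies[key]
    except KeyError:
        # Optionally create the phony on the device's default CUDA stream
        # (not needed on CPU):
        # stream = torch.cuda.default_stream(device)
        # with torch.cuda.stream(stream):
        phony = torch.empty(0, device=device, requires_grad=requires_grad)
        
        _phonies[key] = phony
    
    return phony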

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
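class Fork(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Pass `input` through unchanged and emit an empty phony tensor as a
        # second output; both outputs share this Fork node in the autograd graph.
        phony = get_phony(input.device, requires_grad=False)
        return input.detach(), phony.detach()
    
    @staticmethod
    def backward(ctx, grad_input, grad_grad):
        # Called only after gradients for both outputs have arrived, i.e. after
        # the task that consumed the phony (via Join) has run its backward.
        # grad_grad belongs to the phony and is discarded.
        return grad_input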

In [None]:
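class Join(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, phony):
        # `phony` is not used in the computation; taking it as an input is what
        # links this node to the Fork node that produced it.
        return input.detach()
    
    @staticmethod
    def backward(ctx, grad_input):
        # Pass the gradient through. The autograd engine still counts the edge
        # to the phony, so Fork's backward waits for this call.
        return grad_input, None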

In [None]:
def fork(input):
    if torch.is_grad_enabled() and input.requires_grad:
        input, phony = Fork.apply(input)
    else:
        phony = get_phony(input.device, requires_grad=False)
        
    return input, phony

In [None]:
def join(input, phony):
    if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad):
        input = Join.apply(input, phony)
    return input

In [None]:
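def depend(fork_from: torch.Tensor, join_to: torch.Tensor):
    # Make join_to's task depend on fork_from's task in the autograd graph, so
    # the backward of join_to runs before the backward of fork_from.
    # fork() and join() return NEW tensors, so they must be handed back to the
    # caller; reassigning the parameters alone has no effect outside.
    fork_from, phony = fork(fork_from)
    join_to = join(join_to, phony)
    return fork_from, join_to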

In [None]:
x1, x2 = torch.tensor(1), torch.tensor(2)
depend(x1, x2)

In [None]:
_phonies

{(device(type='cpu'), False): tensor([])}

In [None]:
x1

tensor(1)

- $f^j$: The $j$th partition of the neural network.
- $\theta^j$: The parameters associated with the $j$th partition.
- $x_i$: The $i$th micro-batch of training data.
- $F_{i,j}$: The forward pass task for the $i$th micro-batch on the $j$th partition.
- $B_{i,j}$: The backward pass task for the $i$th micro-batch on the $j$th partition.
- $x_i^j$: The intermediate output of the $i$th micro-batch after passing through the $j$th partition.
- $d x_i^j$: The gradient of the loss with respect to $x_i^j$.
- $g_i^j$: The gradient of the loss with respect to the $j$th partition's parameters ($\theta^j$), computed from the $i$th micro-batch.
- $F_{i, j}^{\prime}$: The recomputation of the forward pass task $F_{i, j}$ during the backward pass, which saves memory because the activations of $F_{i, j}$ need not be stored.
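Read together, these symbols fit the following relations. This is a sketch under two assumptions not stated above: $x_i^0 = x_i$, and the training loss $L$ is the sum of the per-micro-batch losses, so the parameter gradients of the micro-batches add up.

```latex
\begin{aligned}
F_{i,j}:&\quad x_i^j = f^j\!\left(x_i^{j-1};\ \theta^j\right), \qquad x_i^0 = x_i, \\
B_{i,j}:&\quad d x_i^{j-1} = \left(\frac{\partial x_i^j}{\partial x_i^{j-1}}\right)^{\!\top} d x_i^j,
\qquad g_i^j = \left(\frac{\partial x_i^j}{\partial \theta^j}\right)^{\!\top} d x_i^j, \\
&\quad \frac{\partial L}{\partial \theta^j} = \sum_i g_i^j .
\end{aligned}
```

$F_{i,j}^{\prime}$ recomputes $x_i^j$ right before $B_{i,j}$, so $B_{i,j}$ has the activations it needs without keeping them from the forward pass.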

Now return to the example with 2 partitions ($f^1$, $f^2$) and 2 micro-batches ($x_1$, $x_2$), which has forward tasks $F_{1,1}, F_{1,2}, F_{2,1}, F_{2,2}$ and backward tasks $B_{1,1}, B_{1,2}, B_{2,1}, B_{2,2}$.

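The dependency graph for this example can be drawn as a grid with one row per partition and one column per micro-batch. Edges between partitions for the same micro-batch are data dependencies; edges between micro-batches on the same partition come from the pipeline schedule in the forward pass and from the Fork/Join virtual dependencies in the backward pass. The direct dependencies are:

- $F_{1,1}$ must be completed before $F_{1,2}$ and $F_{2,1}$.
- $F_{2,1}$ must be completed before $F_{2,2}$.
- $F_{1,2}$ must be completed before $F_{2,2}$ and $B_{1,2}$.
- $F_{2,2}$ must be completed before $B_{2,2}$.
- $B_{2,2}$ must be completed before $B_{2,1}$ and $B_{1,2}$.
- $B_{1,2}$ must be completed before $B_{1,1}$.
- $B_{2,1}$ must be completed before $B_{1,1}$.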
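The same dependencies can be generated for any number of micro-batches $m$ and partitions $n$ with a short sketch that records, for each task, the tasks it waits for, and then asks Python's standard `graphlib.TopologicalSorter` for one valid execution order. The helper name `pipeline_dependencies` and the tuple encoding of tasks are illustrative, and recomputation ($F'_{i,j}$) is left out. Any order it returns runs $B_{2,1}$ before $B_{1,1}$ and $B_{2,2}$ before $B_{1,2}$.

```python
from graphlib import TopologicalSorter  # Python 3.9+


def pipeline_dependencies(m, n):
    """Hypothetical helper: map each task to the set of tasks it waits for.

    Tasks are ("F", i, j) or ("B", i, j) with 1-based micro-batch i and partition j.
    """
    deps = {}
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            f = deps.setdefault(("F", i, j), set())
            b = deps.setdefault(("B", i, j), set())
            if j > 1:
                f.add(("F", i, j - 1))  # data: same micro-batch, previous partition
            if i > 1:
                f.add(("F", i - 1, j))  # schedule: previous micro-batch, same partition
            if j < n:
                b.add(("B", i, j + 1))  # data: gradient from the next partition
            else:
                b.add(("F", i, n))      # backward starts after the last forward task
            if i < m:
                b.add(("B", i + 1, j))  # virtual (Fork/Join): later micro-batch first
    return deps


deps = pipeline_dependencies(2, 2)
order = list(TopologicalSorter(deps).static_order())
print(order)
```

`static_order()` may return any order consistent with the edges, so the printed sequence is one valid timeline, not necessarily the one a real pipeline follows.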

In [None]:
x1, x2 = torch.tensor(1), torch.tensor(2)

In [None]:
phony1, phony2 = torch.randn(1), torch.randn(2)

### Build Dependency

In [None]:
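# Excerpt from a pipeline's run() method, not runnable on its own: `self`,
# `spawn_workers`, `devices`, and `skip_trackers` come from the enclosing class.
# m = number of micro-batches, n = number of partitions.
with spawn_workers(devices) as (in_queues, out_queues):
    for schedule in clock_cycles(m, n):
        # fence: set up dependencies (e.g. depend()) and cross-device copies
        self.fence(schedule, skip_trackers)
        # compute: run every task (i, j) of this clock cycle on its worker
        self.compute(schedule, skip_trackers, in_queues, out_queues)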
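A minimal single-threaded sketch of this loop, assuming CPU tensors, no worker threads or CUDA streams, and `fence` reduced to nothing: each task $(i, j)$ simply applies partition $j$ to micro-batch $i$ in place. The three-module model and the batch shape are made up for illustration; the point is that following the clock-cycle schedule gives the same output as running the partitions one after another on the whole batch.

```python
import torch
import torch.nn as nn


def clock_cycles(m, n):
    # Same schedule as the clock_cycles cell: task (i, j) runs at clock k = i + j.
    for k in range(m + n - 1):
        yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))]


torch.manual_seed(0)
# Hypothetical partitions f^1, f^2, f^3 (modules that act on each row independently).
partitions = [nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2)]
batch = torch.randn(6, 4)
batches = list(batch.chunk(3))  # m = 3 micro-batches of 2 rows each

m, n = len(batches), len(partitions)
with torch.no_grad():
    for schedule in clock_cycles(m, n):
        # Single-threaded stand-in for fence() + compute(): the tasks of one clock
        # cycle touch different micro-batches, so real workers could run them in parallel.
        for i, j in schedule:
            batches[i] = partitions[j](batches[i])

    output = torch.cat(batches)
    reference = nn.Sequential(*partitions)(batch)

print(torch.allclose(output, reference, atol=1e-6))  # True
```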

In [None]:
def clock_cycles(m: int, n: int):
    """Generates schedules for each clock cycle."""
    # m: number of micro-batches
    # n: number of partitions
    # i: index of micro-batch
    # j: index of partition
    # k: clock number
    #
    # Example with m = 3, n = 3:
    # k (i,j) (i,j) (i,j)
    # - ----- ----- -----
    # 0 (0,0)
    # 1 (1,0) (0,1)
    # 2 (2,0) (1,1) (0,2)
    # 3       (2,1) (1,2)
    # 4             (2,2)
    for k in range(m+n-1):
        # Partition j works on micro-batch i = k - j; the range keeps 0 <= i < m and 0 <= j < n.
        yield [(k-j, j) for j in range(max(1+k-m, 0), min(1+k, n))]
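A quick, self-contained way to sanity-check this schedule is to confirm that every task appears exactly once, that no partition gets two tasks in the same clock cycle, and that the cycle count comes out as $m+n-1$. `check_schedule` is a hypothetical helper, and `clock_cycles` is re-declared with parameters `m` and `n` so the cell runs on its own.

```python
def clock_cycles(m, n):
    # Copy of the cell above, with the parameters named m and n.
    for k in range(m + n - 1):
        yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))]


def check_schedule(m, n):
    """Hypothetical helper: validate the schedule and return its number of clock cycles."""
    seen = {}
    cycles = 0
    for k, schedule in enumerate(clock_cycles(m, n)):
        cycles += 1
        busy = [j for _, j in schedule]
        assert len(busy) == len(set(busy)), "a partition got two tasks in one cycle"
        for i, j in schedule:
            assert (i, j) not in seen, "task scheduled twice"
            seen[(i, j)] = k
    assert set(seen) == {(i, j) for i in range(m) for j in range(n)}, "task missing"
    assert all(k == i + j for (i, j), k in seen.items())
    return cycles


print(check_schedule(10, 4))  # 13
print(check_schedule(3, 3))   # 5, matching the table in the docstring
```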

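Because at clock $k$ partition $j$ works on micro-batch $i = k - j$, the last task $(m-1, n-1)$ runs at clock $m+n-2$, so there are $m+n-1$ clock cycles in total. With 10 micro-batches and 4 partitions that is 13: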

In [None]:
n_microbatches = 10
n_partitions = 4

In [None]:
list(range(n_microbatches+n_partitions-1))

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]

In [None]:
schedules = list(clock_cycles(n_microbatches, n_partitions))

In [None]:
for k, schedule in enumerate(schedules):
    print(f"Clock cycle {k}: {schedule}")

Clock cycle 0: [(0, 0)]
Clock cycle 1: [(1, 0), (0, 1)]
Clock cycle 2: [(2, 0), (1, 1), (0, 2)]
Clock cycle 3: [(3, 0), (2, 1), (1, 2), (0, 3)]
Clock cycle 4: [(4, 0), (3, 1), (2, 2), (1, 3)]
Clock cycle 5: [(5, 0), (4, 1), (3, 2), (2, 3)]
Clock cycle 6: [(6, 0), (5, 1), (4, 2), (3, 3)]
Clock cycle 7: [(7, 0), (6, 1), (5, 2), (4, 3)]
Clock cycle 8: [(8, 0), (7, 1), (6, 2), (5, 3)]
Clock cycle 9: [(9, 0), (8, 1), (7, 2), (6, 3)]
Clock cycle 10: [(9, 1), (8, 2), (7, 3)]
Clock cycle 11: [(9, 2), (8, 3)]
Clock cycle 12: [(9, 3)]

In [None]:
for k in range(n_microbatches+n_partitions-1):
    print(k)

0
1
2
3
4
5
6
7
8
9
10
11
12
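To see the staircase pattern of these schedules at a glance, the same tuples can be laid out with one row per partition and one column per clock cycle, where `.` marks a cycle in which that partition is idle. `schedule_grid` is a hypothetical helper, and `clock_cycles` is re-declared so the cell runs on its own.

```python
def clock_cycles(m, n):
    # Same schedule as above: task (i, j) runs at clock k = i + j.
    for k in range(m + n - 1):
        yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))]


def schedule_grid(m, n):
    """Hypothetical helper: one text row per partition; column k shows its micro-batch at clock k."""
    rows = [["."] * (m + n - 1) for _ in range(n)]  # "." marks an idle slot
    for k, schedule in enumerate(clock_cycles(m, n)):
        for i, j in schedule:
            rows[j][k] = str(i)
    return [f"partition {j}: " + " ".join(row) for j, row in enumerate(rows)]


for line in schedule_grid(4, 3):
    print(line)
# partition 0: 0 1 2 3 . .
# partition 1: . 0 1 2 3 .
# partition 2: . . 0 1 2 3
```

Micro-batch indices of 10 or more take two characters, so columns stop lining up for larger $m$.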
