In [None]:
# Load IPython's autoreload extension and use mode 2 so edited modules are
# re-imported automatically before each cell executes.
%load_ext autoreload
%autoreload 2
# Route the builtin breakpoint() to ipdb instead of the default pdb.
# (Keep this comment on its own line: text after a magic is part of its argument.)
%env PYTHONBREAKPOINT=ipdb.set_trace
import torch
import torch._inductor.metrics as metrics

# Make CUDA the default device for tensors created without an explicit
# `device=` argument (the torch.zeros / torch.ones / torch.randn calls below).
torch.set_default_device("cuda")

# import torch
from controller import (
    DeviceMesh,
    active_mesh,
    active_stream,
    Stream,
    fetch_shard,
    Future,
    Stream,
    get_active_stream,
)
from torch.fx.experimental.proxy_tensor import make_fx
from controller._testing import simulator_mesh, example_mesh
from controller.simulator import set_meta

In [None]:
from monarch import explicit_autograd
from torch.utils._pytree import tree_map

def f(x):
    """Apply cosine twice and return ``cos(cos(x))``."""
    inner = x.cos()
    return inner.cos()

# Run `f` through explicit_autograd: the call returns the forward result
# together with a callback that performs the backward pass on demand.
x = torch.zeros(5, requires_grad=True)
out, bw_callback = explicit_autograd(f)(x)
print(out)
print(bw_callback.values)

# The saved activations can be freely manipulated: here they are moved to the
# CPU and then back to the GPU before the backward pass runs.
bw_callback.values = tree_map(lambda saved: saved.to('cpu'), bw_callback.values)
bw_callback.values = tree_map(lambda saved: saved.to('cuda'), bw_callback.values)

# Cross-check: the callback's gradient for an all-ones upstream gradient must
# match what regular autograd accumulates into x.grad for f(x).sum().
gradOut = torch.ones(5)
(gradX,) = bw_callback(gradOut)
f(x).sum().backward()
assert torch.allclose(x.grad, gradX)

In [None]:
# GPU count intended for the example mesh. Only the commented-out line below
# references it; the later cells visible here pass explicit `num_gpus` values.
NUM_GPUS = 4
# device_mesh = example_mesh(hosts=1, gpus=NUM_GPUS)

In [None]:
def call_layer(w, x):
    """Apply one layer: the matrix product of input ``x`` with weight ``w``.

    Args:
        w: weight matrix of the layer (right operand of the product).
        x: input activations (left operand of the product).

    Returns:
        ``torch.mm(x, w)``.
    """
    out = torch.mm(x, w)
    return out

def initialize(num_gpus, stages_per_gpu, hidden_dim=4, local_batch_size=4, use_real=False):
    """Build a one-host mesh, its per-GPU sub-meshes, layer weights and a dataloader.

    Args:
        num_gpus: number of GPUs in the mesh; one sub-mesh is created per GPU.
        stages_per_gpu: number of weight matrices created for each GPU.
        hidden_dim: every weight is ``hidden_dim x hidden_dim``.
        local_batch_size: number of rows in each batch the dataloader returns.
        use_real: build the mesh with ``example_mesh`` when true, otherwise
            with ``simulator_mesh``.

    Returns:
        ``(device_mesh, meshes, layers, dataloader)`` where ``meshes[i]`` is
        ``device_mesh(gpu=i)``, ``layers[i]`` holds the weights created while
        ``meshes[i]`` was the active mesh, and ``dataloader()`` returns a fresh
        random ``local_batch_size x hidden_dim`` batch created under ``meshes[0]``.
    """
    make_mesh = example_mesh if use_real else simulator_mesh
    device_mesh = make_mesh(hosts=1, gpus=num_gpus)
    meshes = [device_mesh(gpu=gpu_idx) for gpu_idx in range(num_gpus)]

    # Create each GPU's weights while that GPU's sub-mesh is active.
    layers = []
    for mesh in meshes:
        with active_mesh(mesh):
            layers.append([
                torch.randn(hidden_dim, hidden_dim, requires_grad=True)
                for _ in range(stages_per_gpu)
            ])

    def dataloader():
        # A stand-in dataloader: random input, always produced on the first sub-mesh.
        with active_mesh(meshes[0]):
            batch = torch.randn(local_batch_size, hidden_dim, requires_grad=True)
        return batch

    return device_mesh, meshes, layers, dataloader
    
def fill_drain(num_microbatches, meshes, layers, dataloader):
    """Run a fill-drain pipeline schedule: every forward first, then every backward.

    Args:
        num_microbatches: how many microbatches to pull from ``dataloader``.
        meshes: one mesh per pipeline stage, in forward order.
        layers: ``layers[stage]`` is the list of weights applied on that stage.
        dataloader: zero-argument callable returning the next microbatch.

    Each layer call goes through ``explicit_autograd(call_layer)``, which
    returns the layer output and a backward callback. The callbacks are stored
    under ``(microbatch, stage, layer index)`` and released right after use.
    The forward output of a microbatch is what gets passed into its first
    backward callback. Returns ``None``.
    """
    num_stages = len(meshes)
    bw_callbacks = {}
    forward_outputs = []

    # Fill: push every microbatch through all stages, front to back.
    for mb_num in range(num_microbatches):
        mb = dataloader()  # loads the microbatch
        with set_meta(str(mb_num)):
            for stage, mesh in enumerate(meshes):
                with active_mesh(mesh):
                    # Move the microbatch onto the stage that is about to run.
                    mb = mb.to_mesh(mesh)
                    for layer_idx, layer in enumerate(layers[stage]):
                        mb, bw_callback = explicit_autograd(call_layer)(layer, mb)
                        # Keep the callback around for the drain phase.
                        bw_callbacks[(mb_num, stage, layer_idx)] = bw_callback
        forward_outputs.append(mb)

    # Drain: walk the stages (and the layers within each stage) back to front.
    for mb_num, mb in enumerate(forward_outputs):
        with set_meta(str(mb_num)):
            for stage in reversed(range(num_stages)):
                with active_mesh(meshes[stage]):
                    for layer_idx in reversed(range(len(layers[stage]))):
                        key = (mb_num, stage, layer_idx)
                        grad_weight, mb = bw_callbacks[key](mb)
                        # Drop the callback as soon as it has been used.
                        del bw_callbacks[key]
                    if stage > 0:
                        # Hand the result over to the previous stage.
                        mb = mb.to_mesh(meshes[stage - 1])

def fill_drain_interleaved(num_microbatches, meshes, layers, dataloader):
    """Run an interleaved fill-drain schedule, one layer "loop" at a time.

    In loop ``k`` every microbatch visits every stage in order and is run
    through that stage's ``k``-th layer only, so a microbatch passes over all
    stages once per loop. The backward pass then replays the loops and stages
    in reverse order.

    Args:
        num_microbatches: how many microbatches to pull from ``dataloader``.
        meshes: one mesh per pipeline stage, in forward order.
        layers: ``layers[stage]`` is the list of weights on that stage. The
            number of loops is taken from ``len(layers[0])``, so every stage is
            assumed to hold at least that many layers.
        dataloader: zero-argument callable returning the next microbatch.

    Each microbatch is loaded exactly once, during the first loop. Backward
    callbacks are stored under ``(microbatch, stage, loop)`` and released right
    after use. As in ``fill_drain``, a microbatch's final forward output is
    what gets passed into its first backward callback, and the value returned
    by one callback is carried into the next one for the same microbatch.
    Returns ``None``.
    """
    pp_stages = len(meshes)
    # Number of layers hosted per stage, i.e. how many times each microbatch
    # loops over the whole pipeline.
    layers_per_stage = len(layers[0]) if layers else 0
    bw_callbacks = {}
    # microbatches[i] always holds the latest value for microbatch i: the
    # activation during the forward pass, the propagated value during backward.
    microbatches = []

    # Forward: loop-major, then microbatch, then stage.
    for loop in range(layers_per_stage):
        for mb_num in range(num_microbatches):
            with set_meta(str(mb_num)):
                if loop == 0:
                    # Load once per microbatch (not once per stage).
                    microbatches.append(dataloader())
                mb = microbatches[mb_num]
                for stage in range(pp_stages):
                    with active_mesh(meshes[stage]):
                        # Also covers the wrap-around from the last stage back
                        # to the first one between two loops.
                        mb = mb.to_mesh(meshes[stage])
                        layer = layers[stage][loop]
                        mb, bw_callback = explicit_autograd(call_layer)(layer, mb)
                        # Keep the callback around for the backward pass.
                        bw_callbacks[(mb_num, stage, loop)] = bw_callback
                microbatches[mb_num] = mb

    # Backward: loops and stages in reverse order.
    for loop in range(layers_per_stage - 1, -1, -1):
        for mb_num in range(num_microbatches):
            with set_meta(str(mb_num)):
                mb = microbatches[mb_num]
                for stage in range(pp_stages - 1, -1, -1):
                    with active_mesh(meshes[stage]):
                        mb = mb.to_mesh(meshes[stage])
                        key = (mb_num, stage, loop)
                        gradWeight, mb = bw_callbacks[key](mb)
                        # Drop the callback as soon as it has been used.
                        del bw_callbacks[key]
                # Carry the result over to this microbatch's next (earlier) loop.
                microbatches[mb_num] = mb
                        

In [None]:
from controller.simulator import simulate_commands, chrome_events, visualize_events, analyze_events
# Set up a 4-GPU pipeline with 2 layers per GPU (use_real is left at its
# default, so initialize builds the mesh with simulator_mesh) and run the
# fill-drain schedule over 4 microbatches.
device_mesh, meshes, layers, dataloader = initialize(
    num_gpus=4,
    stages_per_gpu=2,
    local_batch_size=32,
)
fill_drain(
    num_microbatches=4,
    meshes=meshes,
    layers=layers,
    dataloader=dataloader,
)

# Feed the worker commands recorded by the mesh's backend to the simulator,
# then pass the resulting events to the visualisation and analysis helpers.
events = simulate_commands(device_mesh.client.backend.worker_commands)
visualize_events(events)
analyze_events(events)

In [None]:
from ipywidgets import *

def make_range(name, values):
    """Build a horizontal ``SelectionSlider`` restricted to the given values.

    Args:
        name: label text; shown as ``"<name>:"`` in the widget description.
        values: the selectable options, in slider order.

    Returns:
        The configured ``widgets.SelectionSlider`` (enabled, with readout, and
        with ``continuous_update=True``).
    """
    return widgets.SelectionSlider(
        options=values,
        description=f'{name}:',
        disabled=False,
        continuous_update=True,
        orientation='horizontal',
        readout=True,
        readout_format='d'
    )
def update(num_gpus, batch_size_per_gpu, use_real=False):
    """Run the fill-drain pipeline for one configuration and show the result.

    Args:
        num_gpus: number of GPUs (pipeline stages). The 8 layers of the model
            are split evenly across them with integer division, so layers are
            dropped when 8 is not divisible by ``num_gpus``.
        batch_size_per_gpu: microbatches per GPU; the pipeline runs
            ``num_gpus * batch_size_per_gpu`` microbatches of local batch size 1.
        use_real: forwarded to ``initialize``; when true the simulation,
            visualisation and analysis steps are skipped.

    Prints the stages-per-GPU count and the pipeline batch size. Returns ``None``.
    """
    num_layers = 8
    stages_per_gpu = num_layers // num_gpus
    num_microbatches = num_gpus * batch_size_per_gpu
    print(f"Stages per GPU: {stages_per_gpu}")
    print(f"batch size for pipeline: {num_microbatches}")

    device_mesh, meshes, layers, dataloader = initialize(
        num_gpus=num_gpus,
        stages_per_gpu=stages_per_gpu,
        local_batch_size=1,
        use_real=use_real,
    )
    fill_drain(
        num_microbatches=num_microbatches,
        meshes=meshes,
        layers=layers,
        dataloader=dataloader,
    )

    if use_real:
        # The simulator helpers below are only run for the non-real mesh.
        return
    events = simulate_commands(device_mesh.client.backend.worker_commands)
    visualize_events(events)
    analyze_events(events)

# Hook `update` up to interactive controls: a selection slider over 2/4/8 GPUs
# and an integer slider (1-10) for the per-GPU batch size.
interact(
    update,
    num_gpus=make_range('num gpus', [2, 4, 8]),
    batch_size_per_gpu=IntSlider(value=1, min=1, max=10, step=1),
)

In [None]:
# Same pipeline with use_real=True: initialize builds the mesh with
# example_mesh and update skips the simulate/visualize/analyze steps.
update(
    num_gpus=8,
    batch_size_per_gpu=4,
    use_real=True,
)