# Monarch DDP Example with SkyPilot

This notebook demonstrates how to run PyTorch DDP (DistributedDataParallel) training with Monarch actors on cloud infrastructure provisioned by SkyPilot.

Adapted from the SLURM DDP example (`slurm_ddp.ipynb`).

## Prerequisites

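```bash
pip install torchmonarch-nightly
pip install "skypilot[kubernetes]"  # or "skypilot[aws]", "skypilot[gcp]", etc.
sky check  # Verify SkyPilot configuration
```

The local `monarch_skypilot` package (which provides `SkyPilotJob`) must also be importable from this notebook, e.g. by running the notebook from the directory that contains that package.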


## Imports and Setup


In [None]:
import os

# Configure Monarch/hyperactor timeouts. These must already be in the
# environment before `monarch` is imported (presumably they are read at
# import time), which is why this runs ahead of the imports below.
os.environ.update(
    {
        "HYPERACTOR_HOST_SPAWN_READY_TIMEOUT": "300s",
        "HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT": "300s",
        "HYPERACTOR_MESH_PROC_SPAWN_MAX_IDLE": "300s",
    }
)

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

import sky
from monarch.actor import Actor, current_rank, endpoint
from monarch.utils import setup_env_for_distributed
from torch.nn.parallel import DistributedDataParallel as DDP

# Import SkyPilotJob from local package
from monarch_skypilot import SkyPilotJob


## Define the Model and DDP Actor


In [None]:
class ToyModel(nn.Module):
    """Tiny two-layer MLP used as the model in the DDP demo.

    Computes ``Linear(10, 10) -> ReLU -> Linear(10, 5)``, so the last input
    dimension must be 10 and the last output dimension is 5.
    """

    def __init__(self):
        super().__init__()
        # Submodule names (net1 / relu / net2) determine the parameter and
        # state_dict keys ("net1.weight", ...), so they are kept stable.
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        hidden = self.net1(x)
        hidden = self.relu(hidden)
        return self.net2(hidden)


class DDPActor(Actor):
    """Monarch actor that runs one rank of PyTorch's basic DDP example.

    Adapted from: https://docs.pytorch.org/tutorials/intermediate/ddp_tutorial.html#basic-use-case

    The actor's mesh rank (``current_rank().rank``) is used as its
    ``torch.distributed`` rank. The endpoints read ``WORLD_SIZE`` and
    ``LOCAL_RANK`` from the process environment, so those (and the rendezvous
    variables needed by the default ``env://`` init method) must be set before
    ``setup`` is called. Presumably ``monarch.utils.setup_env_for_distributed``
    sets them; confirm there.
    """

    def __init__(self):
        # This actor's rank in the mesh; reused as the torch.distributed rank.
        self.rank = current_rank().rank

    @endpoint
    async def setup(self) -> str:
        """Initialize the default process group (gloo backend) for this rank.

        Idempotent: if the default group is already initialized (e.g. a
        notebook cell is re-run), it is left as-is instead of raising.

        Returns:
            A status line for this rank.
        """
        if dist.is_initialized():
            return (
                f"Rank {self.rank}: Distributed already initialized "
                f"(world_size={dist.get_world_size()})"
            )
        world_size = int(os.environ["WORLD_SIZE"])
        dist.init_process_group("gloo", rank=self.rank, world_size=world_size)
        return f"Rank {self.rank}: Initialized distributed (world_size={world_size})"

    @endpoint
    async def cleanup(self) -> str:
        """Destroy the default process group, if one is initialized.

        Calling this without a prior successful ``setup`` is a no-op rather
        than an error.

        Returns:
            A status line for this rank.
        """
        if not dist.is_initialized():
            return f"Rank {self.rank}: Distributed not initialized; nothing to clean up"
        dist.destroy_process_group()
        return f"Rank {self.rank}: Cleaned up distributed"

    @endpoint
    async def demo_basic(self) -> str:
        """Run one forward/backward/optimizer step of ``ToyModel`` under DDP.

        Uses GPU ``LOCAL_RANK`` when CUDA is available. Otherwise it falls
        back to CPU, so the demo still runs on CPU-only hosts. ``setup`` must
        have been called first.

        Returns:
            A status line containing this rank's loss for the step.
        """
        local_rank = int(os.environ["LOCAL_RANK"])
        if torch.cuda.is_available():
            device = torch.device("cuda", local_rank)
            # Single-device CUDA module: DDP takes exactly that device.
            device_ids = [local_rank]
        else:
            device = torch.device("cpu")
            # CPU modules must be wrapped with device_ids=None.
            device_ids = None

        model = ToyModel().to(device)
        ddp_model = DDP(model, device_ids=device_ids)

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        # Create inputs and labels directly on the model's device.
        outputs = ddp_model(torch.randn(20, 10, device=device))
        labels = torch.randn(20, 5, device=device)
        loss = loss_fn(outputs, labels)
        # DDP all-reduces gradients across ranks during backward().
        loss.backward()
        optimizer.step()

        return f"Rank {self.rank}: Training step complete (loss={loss.item():.4f})"


## Configuration

Configure your cloud provider, cluster size, and GPU type below:


In [None]:
# Configuration - modify these values as needed
CLOUD = "kubernetes"  # Options: kubernetes, aws, gcp, azure
NUM_HOSTS = 2
GPUS_PER_HOST = 1  # processes (one per GPU) spawned on each host
CLUSTER_NAME = "monarch-ddp"
GPU_TYPE = "H200"  # e.g., H100, A100, V100
# Per-node SkyPilot accelerator request, derived from GPUS_PER_HOST. This
# keeps the number of GPUs provisioned per host equal to the number of
# processes spawned per host. Each process uses GPU LOCAL_RANK, so a
# hard-coded ":1" would leave ranks without a GPU whenever GPUS_PER_HOST > 1.
ACCELERATOR = f"{GPU_TYPE}:{GPUS_PER_HOST}"

def get_cloud(cloud_name: str):
    """Return a new SkyPilot cloud instance for a case-insensitive cloud name.

    Args:
        cloud_name: One of "kubernetes", "aws", "gcp", "azure" (any casing).

    Raises:
        ValueError: If ``cloud_name`` is not a supported cloud.
    """
    cloud_classes = {
        "kubernetes": sky.Kubernetes,
        "aws": sky.AWS,
        "gcp": sky.GCP,
        "azure": sky.Azure,
    }
    cloud_cls = cloud_classes.get(cloud_name.lower())
    if cloud_cls is None:
        raise ValueError(f"Unknown cloud: {cloud_name}. Available: {list(cloud_classes.keys())}")
    return cloud_cls()

# Echo the effective configuration so it is recorded in the cell output.
print(
    "\n".join(
        ["Configuration:"]
        + [
            f"  {label}: {value}"
            for label, value in (
                ("Cloud", CLOUD),
                ("Hosts", NUM_HOSTS),
                ("GPUs per host", GPUS_PER_HOST),
                ("Accelerator", ACCELERATOR),
                ("Cluster name", CLUSTER_NAME),
            )
        ]
    )
)


## Create SkyPilot Job

Create a SkyPilot job that describes the cloud instances to provision. The cluster itself is launched in the next step, when `job.state()` is called:


In [None]:
# Per-node resource request: ACCELERATOR on the cloud selected by CLOUD.
job_resources = sky.Resources(
    cloud=get_cloud(CLOUD),
    accelerators=ACCELERATOR,
)

# A single mesh named "mesh0" spanning NUM_HOSTS hosts. The autostop settings
# act as a safety net in case the cleanup cell is never run.
# NOTE(review): these are presumably forwarded to SkyPilot autostop, where
# down=True tears the cluster down rather than merely stopping it. Confirm
# this in monarch_skypilot.
job = SkyPilotJob(
    cluster_name=CLUSTER_NAME,
    meshes={"mesh0": NUM_HOSTS},
    resources=job_resources,
    idle_minutes_to_autostop=10,
    down_on_autostop=True,
)

print(f"SkyPilot job created for cluster '{CLUSTER_NAME}'")


## Launch Cluster and Create Process Mesh


In [None]:
# Launch the cluster and get the job state. job.state() is the call that
# launches the SkyPilot cluster. It presumably blocks until the hosts are
# reachable, so expect this cell to be slow on a cold start.
print("Launching SkyPilot cluster...")
job_state = job.state()

# Create process mesh with GPUs: presumably GPUS_PER_HOST processes on each
# host of "mesh0". DDPActor uses LOCAL_RANK as its GPU index, so each host
# needs at least GPUS_PER_HOST GPUs; keep this in sync with ACCELERATOR.
print("Creating process mesh...")
proc_mesh = job_state.mesh0.spawn_procs({"gpus": GPUS_PER_HOST})
print(f"Process mesh extent: {proc_mesh.extent}")


## Spawn DDP Actors and Run Training


In [None]:
# Spawn DDP actors on the process mesh (presumably one DDPActor per process),
# all addressed through the single `ddp_actor` handle.
print("Spawning DDP actors...")
ddp_actor = proc_mesh.spawn("ddp_actor", DDPActor)

# Set up the distributed environment. This must complete before
# `ddp_actor.setup` is called, because DDPActor reads WORLD_SIZE / LOCAL_RANK
# from the process environment. Those are presumably populated here; confirm
# in monarch.utils.setup_env_for_distributed.
print("Setting up distributed environment...")
await setup_env_for_distributed(proc_mesh)


In [None]:
# Run the DDP example
print("Running DDP training...\n")

# Initialize distributed process group
print("[1] Initializing distributed process group...")
results = await ddp_actor.setup.call()
for coord, msg in results:
    print(f"    {msg}")

# Run the basic DDP training example
print("\n[2] Running DDP training step...")
results = await ddp_actor.demo_basic.call()
for coord, msg in results:
    print(f"    {msg}")

# Clean up distributed process group
print("\n[3] Cleaning up distributed process group...")
results = await ddp_actor.cleanup.call()
for coord, msg in results:
    print(f"    {msg}")

print("\n" + "=" * 60)
print("DDP example completed successfully!")
print("=" * 60)


## Cleanup

Tear down the SkyPilot cluster when done:


In [None]:
# Tear down the SkyPilot cluster once you are done with it. job.kill()
# presumably terminates the cluster's cloud instances. As a fallback, the job
# was created with idle_minutes_to_autostop=10 and down_on_autostop=True.
print("Cleaning up SkyPilot cluster...")
job.kill()
print("Done!")
