In [1]:
import os
import sys
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torchvision.models import resnet50
from torchvision import datasets, transforms
import ray
import webdataset as wds
import dataclasses
import time
from collections import deque

def enumerate_report(seq, delta, growth=1.0):
    """Enumerate ``seq`` and flag the items at which a progress report is due.

    Behaves like ``enumerate`` but yields ``(index, item, report)`` triples.
    ``report`` is True when more than ``delta`` seconds of wall-clock time
    (``time.time()``) have elapsed since the previously flagged item. The
    reference timestamp starts at 0, so with any realistic ``delta`` the very
    first item is flagged.

    Args:
        seq: any iterable.
        delta: minimum number of seconds between two flagged items.
        growth: factor by which ``delta`` is multiplied each time an item is
            flagged, so reports can become progressively rarer. The default
            of 1.0 keeps the interval constant.

    Yields:
        Tuples ``(index, item, report)``.
    """
    last = 0
    for count, item in enumerate(seq):
        now = time.time()
        report = now - last > delta
        if report:
            last = now
            # Stretch the interval once per *report*. Growing it once per
            # item would make it explode exponentially with the number of
            # items and silence all reports after a few hundred steps.
            delta *= growth
        yield count, item, report


  from .autonotebook import tqdm as notebook_tqdm
2023-12-12 00:06:33,290	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [None]:
# Base URL of the fake-ImageNet shards; the train/val URLs below name their
# shard ranges in brace notation.
bucket = "https://storage.googleapis.com/webdataset/fake-imagenet"
trainset_url = f"{bucket}/imagenet-train-{{000000..001281}}.tar"
valset_url = f"{bucket}/imagenet-val-{{000000..000049}}.tar"

# On Colab the data is streamed (no cache directory); everywhere else the
# shards are cached in a local directory.
cache_dir = None if 'google.colab' in sys.modules else "./_cache"
if cache_dir is None:
    print("running on colab, streaming data directly from storage")
else:
    print(f"not running in colab, caching data locally in {cache_dir}")

def make_dataloader_train(batch_size=64, num_workers=4):
    """Create a DataLoader for training on the ImageNet dataset using WebDataset.

    Args:
        batch_size: number of samples per batch, used both for the batching
            done inside the dataset and for the final re-batching.
        num_workers: number of DataLoader worker processes.

    Returns:
        trainloader: a WebLoader yielding ``(images, labels)`` batches.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    # Standard Imagenet transformations.
    transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    # Define a function to apply the transformations to a sample.
    def make_sample(sample):
        return transform(sample["jpg"]), sample["cls"]

    # Create a WebDataset from the URL specified by trainset_url.
    # We're using the resampled version of the dataset, which works correctly with
    # distributed processing.
    trainset = wds.WebDataset(trainset_url, resampled=True, cache_dir=cache_dir)

    # Shuffle the samples inline, decode the images to PIL, and apply the transformations.
    trainset = trainset.shuffle(1000).decode("pil").map(make_sample)

    # For IterableDataset, we need to perform the batching in the dataset.
    trainset = trainset.batched(batch_size)

    # For IterableDataset, the batch_size needs to be None in the DataLoader.
    # (WebLoader is just a wrapper around DataLoader with some additional methods.)
    trainloader = wds.WebLoader(trainset, batch_size=None, num_workers=num_workers)

    # Nominal epoch length, measured in batches (not a number of epochs):
    # 1282 shards (000000..001281), presumably 100 samples each -- TODO confirm
    # the per-shard sample count against the dataset.
    batches_per_epoch = 1282 * 100 // batch_size

    # Unbatch the per-worker batches, re-batch them into groups of batch_size
    # samples, and end each pass over the loader after batches_per_epoch batches.
    trainloader = trainloader.unbatched().batched(batch_size).with_epoch(batches_per_epoch)

    # Return the DataLoader
    return trainloader

def make_dataloader(split="train"):
    """Return the dataloader for the requested split.

    Args:
        split: "train" or "val".

    Raises:
        ValueError: for any other split name.
    """
    if split not in ("train", "val"):
        raise ValueError(f"unknown split {split}")
    if split == "val":
        # NOTE(review): make_dataloader_val is not defined in the visible part
        # of this notebook -- confirm it exists before requesting "val".
        return make_dataloader_val()
    return make_dataloader_train()

# Try it out: pull one batch from the default (training) dataloader and print
# the shapes of its first two components as a quick sanity check.
sample = next(iter(make_dataloader()))
print(sample[0].shape, sample[1].shape)

In [None]:
@dataclasses.dataclass
class Args:
    """Training hyperparameters and distributed-setup options passed to ``train``."""

    # --- training schedule ---
    epochs: int = 1
    # train() stops once the number of processed samples exceeds this value;
    # the default is large enough to act as "no limit".
    maxsteps: int = 10**18

    # --- optimizer ---
    lr: float = 0.001
    momentum: float = 0.9

    # --- torch.distributed setup ---
    # Default rank; train() receives the actual rank as an explicit argument.
    rank: int = 0
    world_size: int = 2
    backend: str = "nccl"
    master_addr: str = "localhost"
    master_port: str = "12355"

    # --- progress reporting ---
    # Seconds between progress reports, and the intended growth factor of
    # that interval.
    report_s: float = 15.0
    report_growth: float = 1.1

In [None]:
def train(rank, args):
    """Train a ResNet-50 on the training dataloader, optionally as one DDP rank.

    Args:
        rank: rank of this process in the distributed job, or None to train
            in a single process without initializing ``torch.distributed``.
        args: an ``Args`` instance. Reads ``master_addr``, ``master_port``,
            ``backend``, ``world_size``, ``lr``, ``momentum``, ``epochs``,
            ``report_s`` and ``maxsteps``.

    Returns:
        None. Progress is written to stderr; training ends after
        ``args.epochs`` passes over the loader or as soon as more than
        ``args.maxsteps`` samples have been processed.
    """
    distributed = rank is not None

    # Set up distributed PyTorch.
    if distributed:
        os.environ['MASTER_ADDR'] = args.master_addr
        os.environ['MASTER_PORT'] = args.master_port
        print(f"rank {rank} initializing process group", file=sys.stderr)
        dist.init_process_group(backend=args.backend, rank=rank, world_size=args.world_size)
        print(f"rank {rank} done initializing process group", file=sys.stderr)

    try:
        # Define the model, loss function, and optimizer.
        # resnet50() without arguments uses randomly initialized weights; this
        # avoids the `pretrained=` keyword, which newer torchvision deprecates.
        model = resnet50().cuda()
        if distributed:
            model = DistributedDataParallel(model)
        loss_fn = nn.CrossEntropyLoss()
        # Honor the configured momentum instead of silently ignoring it.
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

        # Data loading code
        trainloader = make_dataloader(split='train')

        # Running windows over the last 100 batches; `steps` counts samples.
        losses, accuracies, steps = deque(maxlen=100), deque(maxlen=100), 0

        # Training loop
        for epoch in range(args.epochs):
            # NOTE(review): args.report_growth is not forwarded here, so the
            # reporting interval stays constant -- confirm that is intended.
            for i, data, verbose in enumerate_report(trainloader, args.report_s):
                inputs, labels = data[0].cuda(), data[1].cuda()

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward pass
                outputs = model(inputs)

                # update statistics
                loss = loss_fn(outputs, labels)
                accuracy = (outputs.argmax(1) == labels).float().mean()  # calculate accuracy
                losses.append(loss.item())
                accuracies.append(accuracy.item())

                # The windows are never empty here: both were appended to above.
                if verbose:
                    avgloss = sum(losses)/len(losses)
                    avgaccuracy = sum(accuracies)/len(accuracies)
                    print(f"rank {rank} epoch {epoch:5d}/{i:9d} loss {avgloss:8.3f} acc {avgaccuracy:8.3f} {steps:9d}", file=sys.stderr)

                # backward + optimize
                loss.backward()
                optimizer.step()
                steps += len(labels)
                if steps > args.maxsteps:
                    print("finished training (maxsteps)", steps, args.maxsteps, file=sys.stderr)
                    return

        print("finished Training", steps)
    finally:
        # Always release the process group (normal end, early return, or
        # error) so the process can initialize a new one later.
        if distributed and dist.is_initialized():
            dist.destroy_process_group()

In [None]:
%%script true
# Single-process run (rank=None, so torch.distributed is not initialized),
# limited to one epoch and 100000 samples. The `%%script true` magic on the
# first line of this cell keeps it from executing; remove it to run the cell.
args = Args(epochs=1, maxsteps=100000)
train(None, args)

In [None]:
# Initialize Ray once; re-running this cell must not call ray.init() again.
if not ray.is_initialized():
    ray.init()
# Number of GPUs Ray can currently schedule on. The resource dict has no 'GPU'
# entry when none is available, so fall back to 0 instead of raising KeyError.
ray.available_resources().get('GPU', 0)

In [None]:
@ray.remote(num_gpus=1)
def train_remote(rank, args):
    """Ray task wrapper around ``train`` that requests one GPU per invocation.

    Args:
        rank: forwarded to ``train`` unchanged.
        args: forwarded to ``train`` unchanged.

    Returns:
        None; the return value of ``train`` is not propagated.
    """
    # Ray will automatically set CUDA_VISIBLE_DEVICES for each task.
    train(rank, args)

def distributed_training(world_size=2):
    """Run ``train`` as several cooperating Ray tasks, one GPU per task.

    The effective world size is ``world_size`` capped at the number of GPUs
    Ray currently reports as available. The list of task results is printed.

    Args:
        world_size: requested number of training processes.

    Raises:
        RuntimeError: if Ray reports no available GPU.
    """
    args = Args()
    # Ray reports resource amounts as floats and leaves out the 'GPU' key when
    # no GPU is available. Convert to int so that world_size stays an integer
    # (range() and the process-group setup need one) and default to 0.
    num_gpus = int(ray.available_resources().get('GPU', 0))
    if num_gpus < 1:
        raise RuntimeError("no GPUs available for distributed training")
    args.world_size = min(world_size, num_gpus)
    # Launch one task per rank and block until all of them have finished.
    results = ray.get([train_remote.remote(i, args) for i in range(args.world_size)])
    print(results)

distributed_training(2)

In [None]:
# Launch distributed training again, this time relying on the function's
# default world_size.
distributed_training()