In [1]:
import os
import sys
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torchvision.models import resnet50
from torchvision import datasets, transforms
import ray
import webdataset as wds
import dataclasses
import time
from collections import deque

def enumerate_report(seq, delta, growth=1.0):
    last = 0
    count = 0
    for count, item in enumerate(seq):
        now = time.time()
        if now - last > delta:
            last = now
            yield count, item, True
        else:
            yield count, item, False
        delta *= growth


  from .autonotebook import tqdm as notebook_tqdm
2023-12-11 19:02:25,300	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [2]:
bucket = "https://storage.googleapis.com/webdataset/fake-imagenet"
trainset_url = bucket+"/imagenet-train-{000000..001281}.tar"
valset_url = bucket+"/imagenet-val-{000000..000049}.tar"

if 'google.colab' in sys.modules:
    cache_dir = None
    print("running on colab, streaming data directly from storage")
else:
    cache_dir = "./_cache"
    print(f"not running in colab, caching data locally in {cache_dir}")

def make_dataloader_train():
    transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    def make_sample(sample):
        return transform(sample["jpg"]), sample["cls"]
    trainset = wds.WebDataset(trainset_url, resampled=True, cache_dir=cache_dir)
    trainset = trainset.shuffle(1000).decode("pil").map(make_sample)
    trainset = trainset.batched(64)
    trainloader = wds.WebLoader(trainset, batch_size=None, num_workers=4)
    trainloader = trainloader.unbatched().batched(64).with_epoch(1282 * 100 // 64)
    return trainloader

def make_dataloader(split="train"):
    return getattr(sys.modules[__name__], f"make_dataloader_{split}")()

sample = next(iter(make_dataloader()))
print(sample[0].shape, sample[1].shape)

not running in colab, caching data locally in ./_cache
torch.Size([64, 3, 224, 224]) torch.Size([64])


In [3]:
@dataclasses.dataclass
class Args:
    epochs: int = 1
    maxsteps: int = int(1e18)
    lr: float = 0.001
    momentum: float = 0.9
    rank: int = 0
    world_size: int = 2
    backend: str = "nccl"
    master_addr: str = "localhost"
    master_port: str = "12355"
    report_s: float = 15.0
    report_growth: float = 1.1

In [4]:
def train(rank, args):
    # Set up distributed PyTorch.
    if rank is not None:
        os.environ['MASTER_ADDR'] = args.master_addr
        os.environ['MASTER_PORT'] = args.master_port
        dist.init_process_group(backend=args.backend, rank=rank, world_size=args.world_size)

    # Define the model, loss function, and optimizer
    model = resnet50(pretrained=False).cuda()
    if rank is not None:
        model = DistributedDataParallel(model)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)

    # Data loading code
    trainloader = make_dataloader(split='train')

    losses, accuracies, steps = deque(maxlen=100), deque(maxlen=100), 0

    # Training loop
    for epoch in range(args.epochs):
        for i, data, verbose in enumerate_report(trainloader, args.report_s):
            inputs, labels = data[0].cuda(), data[1].cuda()

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)

            # update statistics
            loss = loss_fn(outputs, labels)
            accuracy = (outputs.argmax(1) == labels).float().mean()  # calculate accuracy
            losses.append(loss.item())
            accuracies.append(accuracy.item())

            if verbose and len(losses) > 0:
                avgloss = sum(losses)/len(losses)
                avgaccuracy = sum(accuracies)/len(accuracies)
                print(f"rank {rank} epoch {epoch:5d}/{i:9d} loss {avgloss:8.3f} acc {avgaccuracy:8.3f} {steps:9d}", file=sys.stderr)
            loss.backward()
            optimizer.step()
            steps += len(labels)
            if steps > args.maxsteps:
                print("finished training (maxsteps)", steps, args.maxsteps, file=sys.stderr)
                return

    print("finished Training", steps)

In [5]:
args = Args()
args.epochs = 1
args.maxsteps = 100000
train(None, args)

rank None epoch     0/        0 loss    6.939 acc    0.000         0
rank None epoch     0/       94 loss    4.281 acc    0.070      6016
rank None epoch     0/      189 loss    3.021 acc    0.104     12096
rank None epoch     0/      284 loss    2.898 acc    0.103     18176
rank None epoch     0/      379 loss    2.862 acc    0.096     24256
rank None epoch     0/      475 loss    2.803 acc    0.110     30400
rank None epoch     0/      570 loss    2.758 acc    0.133     36480
rank None epoch     0/      665 loss    2.727 acc    0.132     42560
rank None epoch     0/      759 loss    2.693 acc    0.156     48576
rank None epoch     0/      854 loss    2.694 acc    0.154     54656
rank None epoch     0/      950 loss    2.702 acc    0.159     60800
rank None epoch     0/     1046 loss    2.664 acc    0.179     66944
rank None epoch     0/     1141 loss    2.627 acc    0.194     73024
rank None epoch     0/     1236 loss    2.597 acc    0.199     79104
rank None epoch     0/     1330 lo

In [6]:

@ray.remote(num_gpus=1)
def train_remote(args):
    # Ray will automatically set CUDA_VISIBLE_DEVICES for each task.
    train(0, args)

def distributed_training(world_size=2):
    ray.init()
    args = Args()
    num_gpus = ray.available_resources()['GPU']
    args.world_size = min(world_size, num_gpus)
    results = ray.get([train_remote.remote(args) for _ in range(args.world_size)])
    print(results)