In [6]:
import os
import sys
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torchvision.models import resnet50
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import ray
import wids
import dataclasses
import time
from collections import deque

def enumerate_report(seq, delta, growth=1.0):
    """Enumerate ``seq`` and flag the items at which a report is due.

    Yields ``(index, item, report)`` triples. ``report`` is True when more
    than the current interval (in seconds, wall-clock) has elapsed since the
    last flagged item. The reference time starts at 0, and the interval is
    multiplied by ``growth`` after every yielded item.
    """
    previous_report = 0
    interval = delta
    for index, element in enumerate(seq):
        timestamp = time.time()
        due = timestamp - previous_report > interval
        if due:
            previous_report = timestamp
        yield index, element, due
        # The interval is scaled once per item, not once per report.
        interval *= growth


In [7]:
# URL patterns for the fake-ImageNet training and validation shards.
bucket = "https://storage.googleapis.com/webdataset/fake-imagenet"
trainset_url = f"{bucket}/imagenet-train-{{000000..001281}}.tar"
valset_url = f"{bucket}/imagenet-val-{{000000..000049}}.tar"

# Pick a local cache directory, except on Colab where no cache directory is set.
if 'google.colab' not in sys.modules:
    cache_dir = "./_cache"
    print(f"not running in colab, caching data locally in {cache_dir}")
else:
    cache_dir = None
    print("running on colab, streaming data directly from storage")

def make_dataset_train():
    """Build the indexed training dataset yielding ``(image, label)`` pairs.

    The dataset is described by the ``imagenet-train.json`` index. Each raw
    sample's ``".jpg"`` entry is passed through the training augmentation
    pipeline and returned together with its ``".cls"`` entry.

    Returns:
        The ``wids.ShardListDataset`` with the sample transform attached.
    """
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    def make_sample(sample):
        """Map a raw sample to ``(transformed image, label)``."""
        image = sample[".jpg"]
        label = sample[".cls"]
        return transform_train(image), label

    # Honour the notebook-level cache_dir (None on Colab, "./_cache" otherwise)
    # instead of a hard-coded path, so the behaviour matches the message
    # printed when cache_dir was chosen.
    # NOTE(review): with cache_dir=None wids presumably falls back to its own
    # default cache location -- confirm against the wids documentation.
    trainset = wids.ShardListDataset(
        "gs://webdataset/fake-imagenet/imagenet-train.json",
        cache_dir=cache_dir,
        keep=True,
    )
    trainset = trainset.add_transform(make_sample)

    return trainset


def make_dataloader_train():
    """Create the training ``DataLoader``.

    Batches of 32 samples are drawn from the training dataset through a
    shuffling ``wids.DistributedChunkedSampler`` (chunk size 1000), using
    4 loader workers.
    """
    trainset = make_dataset_train()
    return DataLoader(
        trainset,
        batch_size=32,
        sampler=wids.DistributedChunkedSampler(trainset, chunksize=1000, shuffle=True),
        num_workers=4,
    )


def make_dataloader(split="train"):
    """Return the dataloader for ``split`` (``"train"`` or ``"val"``).

    Raises:
        ValueError: if ``split`` is neither ``"train"`` nor ``"val"``.
    """
    if split == "train":
        return make_dataloader_train()
    if split != "val":
        raise ValueError(f"unknown split {split}")
    return make_dataloader_val()

# Smoke test: fetch one training batch and print the shapes of its two parts.
sample = next(iter(make_dataloader(split="train")))
print(*(sample[k].shape for k in (0, 1)))

not running in colab, caching data locally in ./_cache


gs://webdataset/fake-imagenet/imagenet-train.json base: gs://webdataset/fake-imagenet name: imagenet-train nfiles: 1282 nbytes: 31242280960 samples: 128200 cache: ./_cache


torch.Size([32, 3, 224, 224]) torch.Size([32])


In [8]:
@dataclasses.dataclass
class Args:
    """Hyperparameters and distributed-setup options passed to ``train``."""

    # Training schedule.
    epochs: int = 1
    maxsteps: int = 10**18  # compared against the running count of processed samples
    # Optimizer settings.
    lr: float = 1e-3
    momentum: float = 0.9
    # Process-group configuration.
    rank: int = 0
    world_size: int = 2
    backend: str = "nccl"
    master_addr: str = "localhost"
    master_port: str = "12355"  # kept as a string: it is written to os.environ
    # Progress reporting: interval in seconds and its growth factor.
    report_s: float = 15.0
    report_growth: float = 1.1

In [9]:
def train(rank, args):
    """Train a ResNet-50 on the training dataloader, optionally under DDP.

    Args:
        rank: Rank of this process within the process group, or ``None`` to
            train in a single process without initializing
            ``torch.distributed``.
        args: Configuration object; this function reads ``master_addr``,
            ``master_port``, ``backend``, ``world_size``, ``lr``,
            ``momentum``, ``epochs``, ``report_s`` and ``maxsteps``.

    Returns:
        None. Progress is printed to stderr; running loss and accuracy are
        averaged over the most recent 100 batches. Training stops early once
        the number of processed samples exceeds ``args.maxsteps``.
    """
    distributed = rank is not None

    # Set up distributed PyTorch.
    if distributed:
        os.environ['MASTER_ADDR'] = args.master_addr
        os.environ['MASTER_PORT'] = args.master_port
        print(f"rank {rank} initializing process group", file=sys.stderr)
        dist.init_process_group(backend=args.backend, rank=rank, world_size=args.world_size)
        print(f"rank {rank} done initializing process group", file=sys.stderr)

    try:
        # Define the model, loss function, and optimizer
        model = resnet50(pretrained=False).cuda()
        if distributed:
            model = DistributedDataParallel(model)
        loss_fn = nn.CrossEntropyLoss()
        # Pass the configured momentum; previously args.momentum was ignored.
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

        # Data loading code
        trainloader = make_dataloader(split='train')

        # `steps` counts processed samples (it grows by the batch size).
        losses, accuracies, steps = deque(maxlen=100), deque(maxlen=100), 0

        # Training loop
        for epoch in range(args.epochs):
            for i, data, verbose in enumerate_report(trainloader, args.report_s):
                inputs, labels = data[0].cuda(), data[1].cuda()

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                outputs = model(inputs)

                # update statistics
                loss = loss_fn(outputs, labels)
                accuracy = (outputs.argmax(1) == labels).float().mean()  # calculate accuracy
                losses.append(loss.item())
                accuracies.append(accuracy.item())

                if verbose and len(losses) > 0:
                    avgloss = sum(losses)/len(losses)
                    avgaccuracy = sum(accuracies)/len(accuracies)
                    print(f"rank {rank} epoch {epoch:5d}/{i:9d} loss {avgloss:8.3f} acc {avgaccuracy:8.3f} {steps:9d}", file=sys.stderr)
                loss.backward()
                optimizer.step()
                steps += len(labels)
                if steps > args.maxsteps:
                    print("finished training (maxsteps)", steps, args.maxsteps, file=sys.stderr)
                    return

        print("finished Training", steps)
    finally:
        # Tear down the default process group on every exit path (normal end,
        # maxsteps early return, or exception) so its resources are released
        # and train() can be called again with a rank in the same process.
        if distributed and dist.is_initialized():
            dist.destroy_process_group()

In [10]:
%%script true
args = Args()
args.epochs = 1
args.maxsteps = 100000
train(None, args)

In [11]:
# Start a Ray instance only if this kernel has not already initialized one,
# so the cell can be re-run safely.
if not ray.is_initialized():
    ray.init()
# Display the number of GPUs Ray currently reports as available.
ray.available_resources()['GPU']

2023-12-12 02:32:13,935	INFO worker.py:1664 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m


2.0

In [12]:
@ray.remote(num_gpus=1)
def train_remote(rank, args):
    """Ray remote task that runs ``train(rank, args)`` with one GPU requested.

    The result of ``train`` is not returned, so the task's result is ``None``.
    """
    # Ray will automatically set CUDA_VISIBLE_DEVICES for each task.
    train(rank, args)

def distributed_training(world_size=2):
    """Launch one ``train_remote`` task per rank and wait for all of them.

    Args:
        world_size: Requested number of training processes. It is capped at
            the number of whole GPUs Ray currently reports as available.

    Raises:
        RuntimeError: if the resulting world size is less than one (no GPU
            available, or a non-positive ``world_size`` was requested).
    """
    args = Args()
    # Ray reports resource amounts as floats (e.g. 2.0), and the 'GPU' entry
    # may be missing when no GPU is available. Convert to a whole number so
    # the result is usable with range() and as a process-group world size.
    num_gpus = int(ray.available_resources().get('GPU', 0))
    args.world_size = min(int(world_size), num_gpus)
    if args.world_size < 1:
        raise RuntimeError(
            f"cannot start distributed training: requested world_size={world_size}, "
            f"available GPUs={num_gpus}"
        )
    results = ray.get([train_remote.remote(i, args) for i in range(args.world_size)])
    print(results)

# Run distributed training with a requested world size of 2.
distributed_training(2)

[36m(train_remote pid=2094790)[0m rank 0 initializing process group
[36m(train_remote pid=2094789)[0m rank 1 done initializing process group
[36m(train_remote pid=2094790)[0m gs://webdataset/fake-imagenet/imagenet-train.json base: gs://webdataset/fake-imagenet name: imagenet-train nfiles: 1282 nbytes: 31242280960 samples: 128200 cache: ./_cache
[36m(train_remote pid=2094789)[0m rank 1 initializing process group
[36m(train_remote pid=2094790)[0m rank 0 done initializing process group
[36m(train_remote pid=2094790)[0m rank 0 epoch     0/        0 loss    6.811 acc    0.000         0
[36m(train_remote pid=2094789)[0m gs://webdataset/fake-imagenet/imagenet-train.json base: gs://webdataset/fake-imagenet name: imagenet-train nfiles: 1282 nbytes: 31242280960 samples: 128200 cache: ./_cache
[36m(train_remote pid=2094789)[0m rank 1 epoch     0/        0 loss    6.931 acc    0.000         0
[36m(train_remote pid=2094790)[0m rank 0 epoch     0/        2 loss    6.704 acc    0.02

In [None]:
# Run distributed training again with the default requested world size (2).
distributed_training()