# ResNet-50 Training on (Fake) ImageNet with WebDataset

This notebook illustrates how to use WebDataset as the data source for a PyTorch training loop.

In [1]:
# imports

%matplotlib inline

from functools import partial
from pprint import pprint
import random
from collections import deque
import numpy as np
from matplotlib import pyplot as plt
import sys
import os

import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet50
from torch.utils.data import DataLoader
from torch import nn, optim

# helpers

import time

def enumerate_report(seq, delta, growth=1.0):
    """Enumerate ``seq`` and flag the items at which a progress report is due.

    Yields ``(index, item, report)`` triples. ``report`` is True when more
    than ``delta`` seconds have elapsed since the previously flagged item;
    for any finite ``delta`` the first item is always flagged. After every
    item ``delta`` is multiplied by ``growth``, so a value above 1.0 makes
    reports progressively rarer.

    Args:
        seq: Any iterable.
        delta: Minimum number of seconds between two flagged items.
        growth: Factor applied to ``delta`` after each item (1.0 keeps the
            interval constant).

    Yields:
        Tuples ``(index, item, report)`` with ``report`` a bool.
    """
    # A monotonic clock is immune to wall-clock adjustments; starting from
    # -inf makes the very first comparison succeed.
    last = float("-inf")
    for count, item in enumerate(seq):
        now = time.monotonic()
        report = now - last > delta
        if report:
            last = now
        yield count, item, report
        delta *= growth


In [2]:
# We usually abbreviate webdataset as wds
import webdataset as wds

# Data Loader Construction

In [3]:
# Where the training shards live. The URL keeps its literal brace range
# ({000000..001281}); it is handed to WebDataset unchanged.
bucket = "https://storage.googleapis.com/webdataset/fake-imagenet"
training_urls = f"{bucket}/imagenet-train-{{000000..001281}}.tar"

# NOTE(review): nothing in this notebook reads batch_size; the loader cell
# hard-codes a batch size of 64 instead.
batch_size = 32


In [4]:

# WebDataset is designed to work without any local storage. Use caching
# only if you are on a desktop with slow networking.

# Pick the data access mode: on Colab the shards are streamed straight from
# storage (cache_dir=None); anywhere else they are cached under ./_cache.
if 'google.colab' in sys.modules:
    cache_dir = None
    print("running on colab, streaming data directly from storage")
else:
    cache_dir = "./_cache"
    # Pure-Python equivalent of `mkdir -p`: portable across operating systems
    # and a no-op when the directory already exists.
    os.makedirs(cache_dir, exist_ok=True)
    print(f"not running in colab, caching data locally in {cache_dir}")

not running in colab, caching data locally in ./_cache


In [5]:
# The standard TorchVision transformations.

transform_train = transforms.Compose(
    [
        # Augmentation: random crop resized to 224x224, then a random flip.
        transforms.RandomResizedCrop(size=224),
        transforms.RandomHorizontalFlip(),
        # Convert the PIL image to a tensor and normalize each channel.
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def make_sample(sample, val=False):
    """Convert one decoded sample into an ``(image, label)`` training pair.

    Args:
        sample: Decoded sample dictionary; the image is read from key
            ``"jpg"`` and the label from key ``"cls"``.
        val: Must be False. Only the training path exists in this notebook.

    Returns:
        Tuple of the image after ``transform_train`` and the untouched label.
    """
    assert not val, "only implemented training dataset for this notebook"
    # Read both fields before transforming, so a missing key fails up front.
    picture, target = sample["jpg"], sample["cls"]
    return transform_train(picture), target


In [6]:

# Pipeline settings, named once so the places that need the batch size cannot
# drift apart. Note that this is independent of the `batch_size` variable
# defined next to the dataset URLs, which this pipeline does not read.
loader_batch_size = 64
shuffle_buffer = 1000
epoch_samples = 1282 * 100  # presumably 1282 shards x 100 samples each -- TODO confirm

# Create the dataset with shard and sample shuffling and decoding.
trainset = wds.WebDataset(training_urls, resampled=True, cache_dir=cache_dir, shardshuffle=True)
trainset = trainset.shuffle(shuffle_buffer).decode("pil").map(make_sample)

# Since this is an IterableDataset, batching happens in the dataset itself
# (hence batch_size=None on the loader).
# WebLoader is PyTorch DataLoader with some convenience methods.
trainset = trainset.batched(loader_batch_size)
trainloader = wds.WebLoader(trainset, batch_size=None, num_workers=4)

# Unbatch, shuffle between workers, then rebatch.
trainloader = trainloader.unbatched().shuffle(shuffle_buffer).batched(loader_batch_size)

# Since we are using resampling, the dataset is infinite; set an artificial
# epoch size, counted in loader items (batches at this point).
trainloader = trainloader.with_epoch(epoch_samples // loader_batch_size)


In [7]:
# Smoke test it.

# Switch on GOPEN_VERBOSE so shard opens are logged while one batch is pulled.
# The finally clause switches it back off even if loading the batch fails, so
# a failed smoke test does not leave verbose logging on for later cells.
os.environ["GOPEN_VERBOSE"] = "1"
try:
    images, classes = next(iter(trainloader))
    print(images.shape, classes.shape)
finally:
    os.environ["GOPEN_VERBOSE"] = "0"

GOPENGOPEN  https://storage.googleapis.com/webdataset/fake-imagenet/imagenet-train-000860.tarhttps://storage.googleapis.com/webdataset/fake-imagenet/imagenet-train-000233.tar  {}{}



pipe exit [0 2416815:2416868] (("curl -f -s -L 'https://storage.googleapis.com/webdataset/fake-imagenet/imagenet-train-000233.tar'",), {'shell': True, 'bufsize': 8192}) {}
pipe exit [0 2416803:2416867] (("curl -f -s -L 'https://storage.googleapis.com/webdataset/fake-imagenet/imagenet-train-000860.tar'",), {'shell': True, 'bufsize': 8192}) {}


torch.Size([64, 3, 224, 224]) torch.Size([64])


# PyTorch Training

This is a typical PyTorch training pipeline.

In [8]:
# The usual PyTorch model definition. We use a randomly initialized (not
# pretrained) ResNet50 model. Calling resnet50() without arguments loads no
# pretrained weights and avoids the deprecated `pretrained=` keyword.

model = resnet50()

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

# Move the model to the GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)



In [9]:
num_epochs = 5

# Sliding windows over the most recent 100 steps; the printed loss and
# accuracy are means over these windows, not per-step values.
losses, accuracies = deque(maxlen=100), deque(maxlen=100)

# Train the model
for epoch in range(num_epochs):
    # enumerate_report flags (verbose=True) a step at most once every 5 seconds.
    for i, (inputs, labels), verbose in enumerate_report(trainloader, 5):
        # each item from the loader is an (inputs, labels) batch
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Count correct top-1 predictions on the device the tensors already
        # live on; argmax carries no gradient, so no detach is needed.
        correct = (outputs.argmax(dim=1) == labels).sum().item()
        accuracy = correct / len(labels)

        losses.append(loss.item())
        accuracies.append(accuracy)

        # Wait for a few steps so the running means are not based on 1-2 values.
        if verbose and len(losses) > 5:
            print('[%d, %5d] loss: %.5f correct: %.5f' % (epoch + 1, i + 1, np.mean(losses), np.mean(accuracies)))

print('Finished Training')

[1,    10] loss: 4.76379 correct: 0.07656
[1,    15] loss: 4.60549 correct: 0.08333
[1,    29] loss: 4.22173 correct: 0.09052
[1,    43] loss: 3.97581 correct: 0.09339
[1,    56] loss: 3.83008 correct: 0.10100
[1,    69] loss: 3.63373 correct: 0.11164
[1,    82] loss: 3.47505 correct: 0.12957
[1,    95] loss: 3.35933 correct: 0.14062
[1,   108] loss: 3.11937 correct: 0.16156
[1,   121] loss: 2.83944 correct: 0.19672
[1,   134] loss: 2.66036 correct: 0.22766
[1,   148] loss: 2.44853 correct: 0.27234
[1,   162] loss: 2.28026 correct: 0.31734
[1,   172] loss: 2.19525 correct: 0.34359
[1,   186] loss: 2.08001 correct: 0.37703
[1,   199] loss: 1.98180 correct: 0.40766
[1,   213] loss: 1.87447 correct: 0.43688
[1,   226] loss: 1.80813 correct: 0.45969
[1,   239] loss: 1.73008 correct: 0.48750
[1,   252] loss: 1.65877 correct: 0.49906
[1,   266] loss: 1.59890 correct: 0.51203
[1,   280] loss: 1.53183 correct: 0.53031
[1,   293] loss: 1.45920 correct: 0.54719
[1,   300] loss: 1.43194 correct: 