In [1]:
%matplotlib inline

from importlib import reload

from functools import partial
from pprint import pprint
import random
from collections import deque
import numpy as np
from matplotlib import pyplot as plt

import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet50
from torch.utils.data import DataLoader
from torch import nn, optim
import wsds
reload(wsds)

import time


def enumerate_report(seq, delta, growth=1.0):
    """Enumerate ``seq`` while flagging when a progress report is due.

    Yields ``(index, item, report)`` triples. ``report`` is True when more
    than the current interval (in seconds of wall-clock time) has passed since
    the last item that was flagged, and False otherwise.

    Args:
        seq: Any iterable to walk through.
        delta: Initial minimum number of seconds between two flagged items.
        growth: Factor the interval is multiplied by after *every* item
            (not only after flagged ones); 1.0 keeps the interval constant.
    """
    previous_report = 0  # epoch 0, so a typical positive delta flags the first item
    interval = delta
    for index, element in enumerate(seq):
        timestamp = time.time()
        due = bool(timestamp - previous_report > interval)
        if due:
            previous_report = timestamp
        yield index, element, due
        interval *= growth

In [2]:
# parameters (everything a reader is expected to tune lives in this cell)
epochs = 3  # upper bound on passes over the training loader
max_steps = 100000  # training stops once the `steps` counter exceeds this; that counter grows by the batch length, so this is a sample budget
batch_size = 32  # batch size handed to both DataLoaders
bucket = "https://storage.googleapis.com/webdataset/fake-imagenet/"  # URL prefix interpolated in front of the shard index names in the dataset spec
num_workers = 4  # intended number of DataLoader worker processes
cache_dir = "./_cache"  # interpolated into the `cache_dir:` entries of the dataset spec

In [3]:
# The standard TorchVision transformations.

# Per-channel normalization statistics (the values commonly used with
# ImageNet-trained torchvision models). They are defined once so that the
# training and validation pipelines cannot silently drift apart.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

# Training pipeline: random crop + flip for augmentation, then tensor + normalize.
transform_train = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
    ]
)

# Validation pipeline: deterministic resize + center crop, same normalization.
transform_val = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
    ]
)


In [4]:
 # All the decoding and data augmentation is done in the make_sample function.

def make_sample(sample, val=False):
    """Turn one raw dataset sample into an ``(image_tensor, label)`` pair.

    Args:
        sample: Mapping for a single sample; after decoding it is read at the
            ".jpg" (image) and ".cls" (label) keys.
        val: When True the validation transform is applied, otherwise the
            training transform.

    Returns:
        Tuple of the transformed image and the label taken from the sample.
    """
    # Both decoders are called for their effect on `sample`; their return
    # values are not used.
    for decode in (wsds.decode_basic, wsds.decode_images_to_pil):
        decode(sample)

    # Pull out the decoded image and its class label.
    picture, target = sample[".jpg"], sample[".cls"]

    # Pick the preprocessing pipeline that matches the requested split.
    pipeline = transform_val if val else transform_train
    return pipeline(picture), target

In [5]:
# YAML dataset specification with one section per split (`train`, `val`).
# `bucket` is prepended to each shard index name and `cache_dir` fills the
# `cache_dir:` entries; presumably shards are downloaded into that directory
# and kept there (`keep_downloaded: true`) -- TODO confirm against wsds docs.
spec = f"""---
train:
    sequential:
        shards: {bucket}imagenet-train.json
        shuffle_size: 10000
        cache_dir: {cache_dir}
        keep_downloaded: true
val:
    sequential:
        shards: {bucket}imagenet-val.json
        cache_dir: {cache_dir}
        keep_downloaded: true
"""

# Build the training dataset; `make_sample` is passed as the transformation.
# No `which` argument is given here -- presumably the default selects the
# `train` section; TODO confirm.
trainset = wsds.SequentialDataset(spec, transformations=make_sample)
# Pull a single sample as a smoke test; as the last expression of the cell
# its value is displayed.
next(iter(trainset))

(tensor([[[ 0.3994,  0.3309,  0.4337,  ...,  0.1597,  0.2624,  0.3481],
          [ 0.4851,  0.4679,  0.5022,  ...,  0.4508,  0.5707,  0.6221],
          [ 0.4166,  0.4679,  0.4679,  ...,  0.3994,  0.3994,  0.5707],
          ...,
          [-0.6794, -0.4568, -0.5596,  ..., -0.6281, -0.6452, -0.6623],
          [-0.6965, -0.4911, -0.6281,  ..., -0.4911, -0.6281, -0.5767],
          [-0.7993, -1.0562, -1.1760,  ..., -0.6452, -0.5424, -0.6109]],
 
         [[ 0.0476,  0.0126,  0.1001,  ..., -0.0924, -0.0224,  0.0651],
          [ 0.1702,  0.1527,  0.1527,  ...,  0.2227,  0.2927,  0.3627],
          [ 0.0826,  0.1352,  0.1352,  ...,  0.1702,  0.1001,  0.3102],
          ...,
          [-0.8277, -0.6001, -0.7402,  ..., -0.8452, -0.8452, -0.8627],
          [-0.8803, -0.6702, -0.7927,  ..., -0.6877, -0.8277, -0.7752],
          [-0.9678, -1.2304, -1.3354,  ..., -0.8452, -0.7577, -0.8102]],
 
         [[ 0.0431,  0.0082,  0.0779,  ...,  0.0431,  0.1128,  0.1999],
          [ 0.1651,  0.1302,

In [6]:
# Validation dataset built from the same spec: `which="val"` presumably
# selects the `val` section (TODO confirm), and `val=True` is bound into
# make_sample so the validation transform is applied.
valset = wsds.SequentialDataset(spec, which="val", transformations=partial(make_sample, val=True))
# Pull a single sample as a smoke test; the value is displayed as cell output.
next(iter(valset))

(tensor([[[-1.1932, -1.1932, -1.1932,  ...,  0.3138,  0.3481,  0.3823],
          [-1.1932, -1.1589, -1.1760,  ...,  0.2796,  0.2796,  0.3309],
          [-1.1589, -1.1418, -1.1418,  ...,  0.2282,  0.2453,  0.3138],
          ...,
          [-0.9534, -1.0904, -1.2274,  ...,  0.8961,  0.8789,  0.7591],
          [-0.7650, -0.8849, -0.9705,  ...,  0.7419,  0.7248,  0.7762],
          [-0.7993, -0.8335, -0.8507,  ...,  0.7762,  0.7591,  0.7248]],
 
         [[-0.8978, -0.9328, -0.8978,  ...,  1.4307,  1.4657,  1.5007],
          [-0.9153, -0.8978, -0.8978,  ...,  1.3957,  1.4307,  1.4832],
          [-0.8978, -0.8803, -0.8803,  ...,  1.3606,  1.3957,  1.4657],
          ...,
          [-0.3025, -0.4426, -0.5826,  ...,  1.6933,  1.6758,  1.6057],
          [ 0.1176, -0.0574, -0.1625,  ...,  1.5707,  1.5532,  1.6057],
          [ 0.1001,  0.0476, -0.0224,  ...,  1.5707,  1.5882,  1.5882]],
 
         [[-1.3687, -1.3513, -1.3687,  ..., -1.2990, -1.2990, -1.2641],
          [-1.3687, -1.3339,

In [6]:
!ls ./_cache

imagenet-train-000496.tar  imagenet-train-000569.tar


In [None]:
# NOTE(review): this cell refers to a module named `wids`, but the cells
# visible in this notebook only import `wsds`; on a fresh kernel the
# `wids.` reference below would raise NameError unless `wids` is imported
# somewhere not shown. It is also unverified that these samplers accept the
# `wsds.SequentialDataset` built above -- TODO confirm before running.

# We also need a sampler for the training set. There are three
# special samplers in the `wids` package that work particularly
# well with sharded datasets:
# - `wids.ShardedSampler` shuffles shards and then samples in shards;
#   it guarantees that only one shard is used at a time
# - `wids.ChunkedSampler` samples by fixed-size chunks, shuffles
#   the chunks, and then the samples within each chunk
# - `wids.DistributedChunkedSampler` is like `ChunkedSampler` but
#   works with distributed training (it first divides the entire
#   dataset into per-node chunks, then the per-node chunks into
#   smaller chunks, then shuffles the smaller chunks)

# trainsampler = wids.ShardedSampler(trainset)
# trainsampler = wids.ChunkedSampler(trainset, chunksize=1000, shuffle=True)
trainsampler = wids.DistributedChunkedSampler(trainset, chunksize=1000, shuffle=True)

# Plot the first 2500 indices produced by the sampler to visualize its order.
plt.plot(list(trainsampler)[:2500])

# Note that the sampler shuffles within each shard before moving on to
# the next shard. Furthermore, on the first epoch, the sampler
# uses the shards in order, but on subsequent epochs, it shuffles
# them. This makes testing and debugging easier. If you don't like
# this behavior, you can use shufflefirst=True

# Set the sampler to epoch 0 before it is handed to the DataLoader.
trainsampler.set_epoch(0)

In [None]:
# Create data loaders for the training and validation datasets.
# The worker count comes from the `num_workers` parameter defined in the
# parameters cell rather than being hard-coded here, so that a single
# setting controls both loaders.

trainloader = DataLoader(
    trainset, batch_size=batch_size, num_workers=num_workers, sampler=trainsampler
)
valloader = DataLoader(
    valset, batch_size=batch_size, shuffle=False, num_workers=num_workers
)

# Fetch one training batch as a sanity check and show the batch shapes.
images, classes = next(iter(trainloader))
print(images.shape, classes.shape)

In [None]:
# The usual PyTorch model definition. We use a randomly initialized
# (i.e. not pretrained) ResNet50 model.

# Calling resnet50() without arguments loads no pretrained weights on both
# old and new torchvision releases; the explicit `pretrained=False` keyword
# is deprecated in torchvision >= 0.13 and only triggers warnings there.
model = resnet50()

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

# Move the model to the GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [None]:
# Rolling windows over the most recent 100 batches; the printed loss and
# accuracy are the means of these windows.
losses, accuracies = deque(maxlen=100), deque(maxlen=100)

# NOTE: despite its name, `steps` counts training *samples* processed so far
# (it grows by the batch length every iteration) and is what gets compared
# against `max_steps`.
steps = 0

# Train the model
for epoch in range(epochs):
    # enumerate_report yields (batch index, batch, verbose); `verbose` is True
    # only when more than 5 seconds have passed since the last report.
    for i, data, verbose in enumerate_report(trainloader, 5):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Batch accuracy, computed on the device the tensors already live on;
        # only the final scalar is transferred to the host via .item().
        pred = outputs.detach().argmax(dim=1, keepdim=True)
        correct = pred.eq(labels.view_as(pred)).sum().item()
        accuracy = correct / float(len(labels))

        losses.append(loss.item())
        accuracies.append(accuracy)
        steps += len(labels)

        # Report smoothed statistics once a few batches have been collected.
        if verbose and len(losses) > 5:
            print(
                "[%d, %5d] loss: %.5f correct: %.5f"
                % (epoch + 1, i + 1, np.mean(losses), np.mean(accuracies))
            )

        # Stop the current epoch once the sample budget is exhausted ...
        if steps > max_steps:
            break
    # ... and leave the epoch loop as well.
    if steps > max_steps:
        break

print("Finished Training")