In [5]:
# imports

%matplotlib inline

import os
import random
import sys
from collections import deque
from functools import partial
from pprint import pprint

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import webdataset as wds
from matplotlib import pyplot as plt
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.models import resnet50


In [6]:
# helpers

import time

def enumerate_report(seq, delta, growth=1.0):
    """Enumerate `seq`, flagging items at which a progress report is due.

    Yields `(index, item, due)` triples. `due` is True when more than `delta`
    seconds of wall-clock time (`time.time()`) have elapsed since the last
    item that was flagged, and False otherwise. After every item, `delta` is
    multiplied by `growth`, so a growth factor above 1.0 makes reports
    progressively less frequent.

    The "last report" timestamp starts at 0, so for any ordinary positive
    `delta` the very first item is flagged.
    """
    previous_report = 0
    for index, element in enumerate(seq):
        timestamp = time.time()
        due = timestamp - previous_report > delta
        if due:
            # Restart the interval from the moment this report was issued.
            previous_report = timestamp
        yield index, element, due
        delta *= growth

In [7]:
# The standard TorchVision transformations.

def _with_tensor_normalize(*pil_steps):
    """Compose `pil_steps` followed by the shared ToTensor + Normalize tail."""
    return transforms.Compose([
        *pil_steps,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


# Training pipeline: RandomResizedCrop(224) and RandomHorizontalFlip, then the shared tail.
transform_train = _with_tensor_normalize(
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
)

# Validation pipeline: Resize(256) and CenterCrop(224), then the shared tail.
transform_val = _with_tensor_normalize(
    transforms.Resize(256),
    transforms.CenterCrop(224),
)

# The dataset yields dictionaries. This small function maps each sample
# dictionary to an (augmented image, label) pair.

def make_sample(sample, val=False):
    """Convert one sample dictionary into a `(transformed image, label)` pair.

    Args:
        sample: mapping with a "jpg" entry (the image) and a "cls" entry
            (the label).
        val: when truthy, apply `transform_val`; otherwise apply
            `transform_train`.

    Returns:
        A 2-tuple of the transformed image and the unmodified label.
    """
    # Read both entries before transforming so a missing key fails early.
    picture, target = sample["jpg"], sample["cls"]
    pipeline = transform_val if val else transform_train
    return pipeline(picture), target


In [8]:
# Build the training/validation datasets and loaders. When caching is enabled,
# shards are downloaded incrementally into the cache directory.

bucket = "https://storage.googleapis.com/webdataset/fake-imagenet"

# Single source of truth for the batch size used by every pipeline stage below.
batch_size = 64

# WebDataset is designed to work without any local storage. If you have a fast
# data server (AIStore, object store, cloud storage for cloud jobs), you don't need
# to cache the data locally. If you are using a local filesystem, you can use
# the cache_dir argument to cache the data locally. By default, shards will just
# be downloaded into the cache directory with their original names. You can share
# the cache directory between multiple jobs and between the webdataset and wids
# libraries. You can't share the cache directory between different datasets
# if there are name conflicts between the files of those datasets.

if 'google.colab' in sys.modules:
    cache_dir = None
    print("running on colab, streaming data directly from storage")
else:
    cache_dir = "./_cache"
    # Create the cache directory portably (no shell escape), and only when
    # caching is actually used.
    os.makedirs(cache_dir, exist_ok=True)
    print(f"not running in colab, caching data locally in {cache_dir}")

# First, we create the datasets. We use resampled=True to make training completely
# independent of the number of workers.

trainset = wds.WebDataset(bucket+"/imagenet-train-{000000..001281}.tar", resampled=True, cache_dir=cache_dir)
valset = wds.WebDataset(bucket+"/imagenet-val-{000000..000049}.tar", cache_dir=cache_dir)

# We shuffle the training data, decode the images using the PIL decoder,
# and then apply make_sample. The validation data is not shuffled.

trainset = trainset.shuffle(1000).decode("pil").map(make_sample)
valset = valset.decode("pil").map(partial(make_sample, val=True))

# For IterableDataset, the batching needs to be carried out in the dataset itself.

trainset = trainset.batched(batch_size)
valset = valset.batched(batch_size)

# Make sure it works so far

print(repr(next(iter(trainset)))[:100])

# Create the dataloaders. WebLoader is just a convenient wrapper around
# DataLoader with some useful methods. batch_size=None because the datasets
# already yield complete batches.

trainloader = wds.WebLoader(trainset, batch_size=None, num_workers=4)
valloader = wds.WebLoader(valset, batch_size=None, num_workers=4)

# For training, we unbatch the dataset, set the number of samples per epoch, then
# rebatch. This reproduces approximately the original epoch size and it shuffles
# between shards.
# 1282 is the number of training shards (see the shard range above); the factor
# 100 is presumably the number of samples per shard -- TODO confirm.

trainloader = trainloader.unbatched().batched(batch_size).with_epoch(1282 * 100 // batch_size)

# Again, make sure it works so far.

images, classes = next(iter(trainloader))
print(images.shape, classes.shape)



not running in colab, caching data locally in ./_cache


[tensor([[[[-0.4568, -0.4226, -0.4054,  ..., -0.4397, -0.3883, -0.4226],
          [-0.4226, -0.4226
torch.Size([64, 3, 224, 224]) torch.Size([64])


In [9]:
# The usual PyTorch model definition. We use an uninitialized ResNet50 model.
# resnet50() with no arguments loads no pretrained weights; this avoids the
# `pretrained=` keyword, which recent torchvision releases deprecate, while
# still working on older releases.

model = resnet50()

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

# Move the model to the GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)



In [10]:
num_epochs = 5

# Rolling windows over the most recent 100 batches; the reported loss and
# accuracy are the means of these windows.
losses, accuracies = deque(maxlen=100), deque(maxlen=100)

# Put the model in training mode explicitly, so re-running this cell after
# another cell has called model.eval() still trains correctly.
model.train()

# Train the model
for epoch in range(num_epochs):
    # enumerate_report sets `verbose` when more than 5 seconds have passed
    # since the last flagged batch.
    for i, data, verbose in enumerate_report(trainloader, 5):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Batch accuracy, computed on the device the outputs live on so that
        # only the scalar count is copied back to the host.
        pred = outputs.detach().argmax(dim=1)
        correct = pred.eq(labels).sum().item()
        accuracy = correct / float(len(labels))

        losses.append(loss.item())
        accuracies.append(accuracy)

        # Skip reports until the window holds more than 5 batches.
        if verbose and len(losses) > 5:
            print('[%d, %5d] loss: %.5f correct: %.5f' % (epoch + 1, i + 1, np.mean(losses), np.mean(accuracies)))

print('Finished Training')

[1,    13] loss: 4.68286 correct: 0.05168
[1,    40] loss: 4.24923 correct: 0.07500
[1,    67] loss: 3.94172 correct: 0.08675
[1,    94] loss: 3.61655 correct: 0.11503
[1,   121] loss: 3.05803 correct: 0.17438
[1,   149] loss: 2.56617 correct: 0.25484
[1,   175] loss: 2.17077 correct: 0.33797
[1,   202] loss: 1.91672 correct: 0.40672
[1,   229] loss: 1.73701 correct: 0.46203
[1,   255] loss: 1.61650 correct: 0.49906
[1,   282] loss: 1.47733 correct: 0.53547
[1,   309] loss: 1.36030 correct: 0.56609
[1,   336] loss: 1.27227 correct: 0.59391
[1,   363] loss: 1.23422 correct: 0.60672
[1,   389] loss: 1.20162 correct: 0.62094
[1,   416] loss: 1.16268 correct: 0.63297
[1,   443] loss: 1.09108 correct: 0.65938
[1,   469] loss: 1.03599 correct: 0.67359
[1,   495] loss: 1.00575 correct: 0.68234
[1,   522] loss: 0.99541 correct: 0.68297
[1,   548] loss: 0.99640 correct: 0.68188
[1,   575] loss: 0.98761 correct: 0.68234
[1,   602] loss: 0.97867 correct: 0.69172
[1,   628] loss: 0.96813 correct: 

KeyboardInterrupt: 