# 4. `tf.data` input pipelines for PyTorch: CNN on ImageNet

With ResNet50 and `batch_size=64`, this gives a throughput of 222.69 images/sec, the same value obtained with [the synthetic benchmark](https://github.com/eth-cscs/pytorch-training/blob/master/cnn_synthetic_benchmark/cnn_distr.py).

With `batch_size=64`, GPU memory usage is about 14355 MiB. This is roughly the same as [the synthetic benchmark](https://github.com/eth-cscs/pytorch-training/blob/master/cnn_synthetic_benchmark/cnn_distr.py) with the same `batch_size` (about 14205 MiB).

With ResNet101 and `batch_size=64`, the throughput is about 132.05 images/sec, compared with about 113.46 images/sec for [the synthetic benchmark](https://github.com/eth-cscs/pytorch-training/blob/master/cnn_synthetic_benchmark/cnn_distr.py). Both use a similar amount of GPU memory (about 10635 MiB).
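
To put these numbers in perspective, a throughput can be turned into the time of one training step: seconds per step = `batch_size` / (images/sec). The short calculation below uses only the figures quoted above; the helper name `seconds_per_step` is ours, for illustration.

```python
def seconds_per_step(batch_size, images_per_sec):
    """Wall-clock time of one training step implied by a measured throughput."""
    return batch_size / images_per_sec

batch_size = 64
resnet50 = seconds_per_step(batch_size, 222.69)          # tf.data and synthetic agree
resnet101_tfdata = seconds_per_step(batch_size, 132.05)  # tf.data pipeline
resnet101_synth = seconds_per_step(batch_size, 113.46)   # synthetic benchmark

print(f"ResNet50:  {resnet50:.3f} s/step")
print(f"ResNet101: {resnet101_tfdata:.3f} s/step (tf.data), "
      f"{resnet101_synth:.3f} s/step (synthetic)")
print(f"ResNet101 tf.data vs synthetic: {132.05 / 113.46:.2f}x")
```

So a ResNet101 step takes roughly 0.48 s with the `tf.data` pipeline here, against roughly 0.56 s in the synthetic benchmark.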

The idea used in this notebook comes from the Kaggle notebook [TF-DS for PyTorch](https://www.kaggle.com/hirotaka0122/tf-ds-for-pytorch).

In [1]:
import glob
import time

import tensorflow as tf
import tensorflow_datasets as tfds
import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models

In [2]:
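# Let TensorFlow allocate GPU memory on demand instead of reserving most of it
# up front, so that PyTorch can still use the same GPU.
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)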

In [3]:
list_of_files = sorted(glob.glob('/scratch/snx3000/datasets/imagenet/ILSVRC2012_1k/train/*'))

In [4]:
def decode(serialized_example):
    """Parse one TFRecord example into a (CHW float32 image, 0-based label) pair."""
    features = tf.io.parse_single_example(
        serialized_example,
        features={
            'image/encoded': tf.io.FixedLenFeature([], tf.string),
            'image/class/label': tf.io.FixedLenFeature([], tf.int64),
        })
    image = tf.image.decode_jpeg(features['image/encoded'], channels=3)
    # Center-crop or zero-pad to 224x224 (the image content is not rescaled)
    image = tf.image.resize_with_crop_or_pad(image, 224, 224)
    image = tf.transpose(image, (2, 0, 1))  # HWC -> CHW: RGB channels to the front, as PyTorch expects
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)  # uint8 -> float32 in [0, 1]
    label = features['image/class/label']  # already int64
    return image, label - 1  # shift the stored 1-based labels to 0-based class indices
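
For readers more at home in NumPy, the steps of `decode` that follow JPEG decoding can be mirrored as below. This is an illustrative re-implementation written for this notebook, not TensorFlow's code, and both function names are hypothetical: `crop_or_pad` follows the documented behavior of `tf.image.resize_with_crop_or_pad` (center crop along a dimension that is larger than the target, centered zero padding along one that is smaller), and `to_chw_float` mimics the transpose followed by `tf.image.convert_image_dtype` from `uint8` to `float32`.

```python
import numpy as np

def crop_or_pad(image, target_h, target_w):
    """Center-crop and/or zero-pad an HWC image to (target_h, target_w).

    Illustrative NumPy analogue of tf.image.resize_with_crop_or_pad.
    """
    h, w = image.shape[:2]
    # Crop offsets: nonzero only where the image is larger than the target
    top = max((h - target_h) // 2, 0)
    left = max((w - target_w) // 2, 0)
    cropped = image[top:top + min(h, target_h), left:left + min(w, target_w)]
    # Pad offsets: nonzero only where the image is smaller than the target
    ch, cw = cropped.shape[:2]
    pad_top = (target_h - ch) // 2
    pad_left = (target_w - cw) // 2
    out = np.zeros((target_h, target_w) + image.shape[2:], dtype=image.dtype)
    out[pad_top:pad_top + ch, pad_left:pad_left + cw] = cropped
    return out

def to_chw_float(image):
    """HWC uint8 -> CHW float32 scaled to [0, 1]."""
    return np.transpose(image, (2, 0, 1)).astype(np.float32) / 255.0

# A fake 300x180 white RGB image: the height is cropped, the width is padded
img = np.full((300, 180, 3), 255, dtype=np.uint8)
x = to_chw_float(crop_or_pad(img, 224, 224))
print(x.shape)  # (3, 224, 224)
```

Because the resize only crops or pads, images smaller than 224 pixels along a side end up with black borders rather than being scaled up.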

In [8]:
batch_size = 64

dataset = tf.data.TFRecordDataset(list_of_files)
dataset = dataset.map(decode, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.batch(batch_size)
# Prefetch last, so whole batches are prepared while the model trains on the current one
dataset = dataset.prefetch(tf.data.AUTOTUNE)
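
`prefetch` decouples producing batches from consuming them: while the training step runs, a background worker keeps a small buffer of ready batches. The sketch below illustrates that idea with a plain Python thread and a bounded `queue.Queue`. It is a simplified stand-in for illustration only: the `prefetch` and `batches` functions are hypothetical, and this is not how `tf.data` works internally (it also tunes the buffer size when given `AUTOTUNE`). Here the prefetcher wraps an iterator of batches, so complete batches are buffered.

```python
import queue
import threading

_END = object()  # sentinel marking the end of the stream

def prefetch(iterable, buffer_size=2):
    """Yield items from `iterable`, producing up to `buffer_size` ahead in a thread."""
    buf = queue.Queue(maxsize=buffer_size)

    def producer():
        # (propagating exceptions from this thread is omitted for brevity)
        for item in iterable:
            buf.put(item)  # blocks while the buffer is full
        buf.put(_END)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buf.get()  # blocks until the producer has something ready
        if item is _END:
            return
        yield item

def batches(n, batch_size):
    """Toy input pipeline: consecutive integers grouped into lists of batch_size."""
    data = list(range(n))
    for i in range(0, n, batch_size):
        yield data[i:i + batch_size]

out = list(prefetch(batches(10, 4)))
print(out)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```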

In [9]:
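# Benchmark on the first 50 batches only (50 x 64 = 3200 images)
dataset = dataset.take(50)
dataset_np = tfds.as_numpy(dataset)  # iterate over NumPy arrays instead of tf.Tensors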

In [10]:
device = 0
model = models.resnet101()
model.to(device);

In [11]:
optimizer = optim.SGD(model.parameters(), lr=0.01)

In [12]:
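def benchmark_step(model, imgs, labels):
    """One training step on a batch of NumPy arrays produced by tf.data."""
    imgs = torch.tensor(imgs)      # copy the NumPy batch into a CPU tensor
    labels = torch.tensor(labels)  # int64 class indices, as F.cross_entropy expects
    optimizer.zero_grad()
    output = model(imgs.to(device))
    loss = F.cross_entropy(output, labels.to(device))
    loss.backward()
    optimizer.step()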
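
`F.cross_entropy` takes raw logits of shape `(N, C)` and integer class indices in `0..C-1`, which is why `decode` shifts the stored labels by one. The NumPy function below spells out the quantity it computes with its default settings (the batch mean of the negative log-softmax at the target index). It is a reference sketch for illustration, not PyTorch's implementation, and the function name is ours.

```python
import numpy as np

def cross_entropy(logits, labels):
    """Mean negative log-likelihood of integer `labels` under softmax(logits)."""
    # Subtract the row max for numerical stability before exponentiating
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # labels index the class columns directly, so they must lie in 0..C-1
    return -log_probs[np.arange(len(labels)), labels].mean()

logits = np.zeros((2, 1000))          # uniform predictions over 1000 classes
labels = np.array([0, 999])           # 0-based ImageNet class indices
print(cross_entropy(logits, labels))  # log(1000), about 6.9078
```

With 1-based labels, the last class (1000) would fall outside the 1000 logit columns; the NumPy sketch raises `IndexError` in that case.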

In [14]:
t0 = time.time()
for step, (imgs, labels) in enumerate(dataset_np):
    benchmark_step(model, imgs, labels)
torch.cuda.synchronize()  # CUDA runs asynchronously: wait for queued GPU work before stopping the clock

dt = time.time() - t0
imgs_sec = batch_size * (step + 1) / dt
print(f' * throughput: {imgs_sec:.2f} images/sec')

 * throughput: 132.29 images/sec
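
The timing above also counts the first iterations, which can be slower because of one-time setup costs (for example CUDA initialization and the input pipeline filling its buffers). A common refinement is to discard a few warm-up steps and time only the rest. The helper below is a hypothetical, framework-agnostic sketch: `step_fn` stands for one training step, and `sync` is called before reading the clock so that asynchronous GPU work is included.

```python
import time

def measure_throughput(step_fn, batches, batch_size, warmup=5, sync=lambda: None):
    """Images/sec over all batches after the first `warmup` ones.

    step_fn: runs one training step on a batch.
    sync:    waits for pending asynchronous work (e.g. torch.cuda.synchronize
             on a GPU); a no-op by default.
    """
    timed_steps = 0
    t0 = None
    for i, batch in enumerate(batches):
        if i == warmup:
            sync()                    # make sure warm-up work has finished
            t0 = time.perf_counter()  # start the clock after warm-up
        step_fn(batch)
        if i >= warmup:
            timed_steps += 1
    if timed_steps == 0:
        raise ValueError("not enough batches to time after warm-up")
    sync()  # include any still-pending work in the measurement
    dt = time.perf_counter() - t0
    return batch_size * timed_steps / dt

# Toy step that sleeps ~10 ms, so throughput should be a bit below 64 / 0.01 = 6400
rate = measure_throughput(lambda b: time.sleep(0.01), range(15), batch_size=64, warmup=5)
print(f"{rate:.0f} images/sec")
```

In this notebook it could be called as `measure_throughput(lambda b: benchmark_step(model, *b), dataset_np, batch_size, sync=torch.cuda.synchronize)`, at the cost of the warm-up batches not being counted.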
