In [5]:
from nvidia.dali.pipeline import pipeline_def
import nvidia.dali.types as types
import nvidia.dali.fn as fn
from nvidia.dali.plugin.pytorch import DALIGenericIterator
import os

In [6]:
# Number of samples produced per pipeline iteration.
batch_size = 256
# Root directory handed to the DALI file reader. The original location stays
# the default; set IMAGENET_TRAIN_DIR to run on a machine with a different layout.
image_dir = os.environ.get("IMAGENET_TRAIN_DIR", "/home/jajal/datasets/imagenet/train")

In [7]:
@pipeline_def
def simple_pipeline():
    """Define a minimal DALI graph: read files, decode them on the CPU.

    Files are read from the module-level ``image_dir`` with random shuffling
    enabled. Batch size, thread count and device id are not fixed here; they
    are supplied when the pipeline is instantiated.

    Returns:
        A pair ``(images, labels)``: the decoded images and the labels
        emitted by the file reader.
    """
    encoded, labels = fn.readers.file(
        file_root=image_dir,
        random_shuffle=True,
    )
    decoded = fn.decoders.image(encoded, device="cpu")
    return decoded, labels

In [8]:
# Instantiate the pipeline with the configured batch size, two worker threads
# and device 0, then build it so it is ready to run.
pipe = simple_pipeline(
    batch_size=batch_size,
    num_threads=2,
    device_id=0,
)
pipe.build()

In [10]:
# Run a single pipeline iteration and print the raw result: the outputs
# returned by simple_pipeline (images, labels).
# NOTE(review): printing the whole result dumps the pixel arrays; consider
# inspecting only a summary (e.g. number of samples / shapes) instead.
pipe_out = pipe.run()
print(pipe_out)

(TensorListCPU(
    [[[[  3   1   4]
      [  6   0   2]
      ...
      [  5   0   4]
      [  3   0   7]]

     [[  0   4   0]
      [ 90 122  46]
      ...
      [ 45  74  18]
      [  1   1   1]]

     ...

     [[  2   0   1]
      [123 125 124]
      ...
      [126 126 126]
      [  3   3   3]]

     [[  3   1   2]
      [  1   1   1]
      ...
      [  0   0   0]
      [  2   2   2]]]


    [[[10 15 11]
      [10 15 11]
      ...
      [ 8 13  9]
      [11 16 12]]

     [[10 15 11]
      [10 15 11]
      ...
      [ 8 13  9]
      [10 15 11]]

     ...

     [[ 9 14 10]
      [10 15 11]
      ...
      [ 9 14 10]
      [ 9 14 10]]

     [[ 8 13  9]
      [10 15 11]
      ...
      [ 9 14 10]
      [10 15 11]]]


    ...


    [[[149 143 143]
      [139 136 157]
      ...
      [138 120  34]
      [126 136  75]]

     [[243 241 246]
      [152 155 160]
      ...
      [141 127  56]
      [ 85  85  47]]

     ...

     [[ 67  79  79]
      [ 68  79  75]
      ...
      [217 199 16

In [1]:
from typing import List

import torch as ch
import torchvision

from ffcv.fields import IntField, RGBImageField
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.pipeline.operation import Operation
from ffcv.transforms import RandomHorizontalFlip, Cutout, \
    RandomTranslate, Convert, ToDevice, ToTensor, ToTorchImage
from ffcv.transforms.common import Squeeze
from ffcv.writer import DatasetWriter

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fetch both CIFAR-10 splits into ./tmp (downloading when necessary), keyed by
# split name; 'train' maps to train=True and 'test' to train=False.
datasets = {
    split: torchvision.datasets.CIFAR10('./tmp', train=(split == 'train'), download=True)
    for split in ('train', 'test')
}

# Serialise every split to its own .beton file with an RGB image field and an
# integer label field.
for name, ds in datasets.items():
    beton_path = f'./tmp/cifar_{name}.beton'
    writer = DatasetWriter(
        beton_path,
        {'image': RGBImageField(), 'label': IntField()},
    )
    writer.from_indexed_dataset(ds)

Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 50000/50000 [00:00<00:00, 248258.00it/s]
100%|██████████| 10000/10000 [00:00<00:00, 49670.59it/s]


In [6]:
# Note that statistics are w.r.t. the uint8 range, [0, 255].
CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]

BATCH_SIZE = 512

# Directory holding the cifar_<split>.beton files. This is the same relative
# location the DatasetWriter cell writes to, so the loaders read exactly what
# was written instead of depending on a machine-specific absolute path.
BETON_DIR = './tmp'

# One FFCV Loader per split, keyed by split name.
loaders = {}
for name in ['train', 'test']:
    label_pipeline: List[Operation] = [IntDecoder(), ToTensor(), Squeeze()]
    image_pipeline: List[Operation] = [SimpleRGBImageDecoder()]

    # Add image transforms and normalization
    if name == 'train':
        # Augmentations are applied to the training split only.
        image_pipeline.extend([
            RandomHorizontalFlip(),
            RandomTranslate(padding=2),
            Cutout(8, tuple(map(int, CIFAR_MEAN))), # Note Cutout is done before normalization.
        ])
    image_pipeline.extend([
        ToTensor(),
        ToTorchImage(),
        Convert(ch.float16),
        torchvision.transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])

    # Create loaders; only the training loader drops the last incomplete batch.
    loaders[name] = Loader(f'{BETON_DIR}/cifar_{name}.beton',
                            batch_size=BATCH_SIZE,
                            num_workers=8,
                            order=OrderOption.RANDOM,
                            drop_last=(name == 'train'),
                            pipelines={'image': image_pipeline,
                                       'label': label_pipeline})

In [7]:
# Display the dict of constructed loaders (one entry per split).
loaders

{'train': <ffcv.loader.loader.Loader at 0x7f9ae8138f10>,
 'test': <ffcv.loader.loader.Loader at 0x7f9bd01f5eb0>}