## Torchmeta

In [1]:
from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader

# 5-way, 1-shot Omniglot meta-training tasks (1 test shot per class), built with
# torchmeta's `omniglot` helper and batched 32 tasks at a time.
tm_dataset1 = omniglot(
    "data",
    ways=5,
    shots=1,
    test_shots=1,
    meta_train=True,
    download=True,
)
tm_dataloader1 = BatchMetaDataLoader(tm_dataset1, batch_size=32)

# Keep the generic names bound to the same objects as well.
dataset, dataloader = tm_dataset1, tm_dataloader1

In [2]:
from torchmeta.datasets import Omniglot
from torchmeta.transforms import Categorical, ClassSplitter, Rotation
from torchvision.transforms import Compose, Resize, ToTensor
from torchmeta.utils.data import BatchMetaDataLoader


# 5-way Omniglot meta-training set assembled by hand from the dataset class,
# then split into one train and one test sample per class.
tm_dataset2 = ClassSplitter(
    Omniglot(
        "data",
        # Number of ways
        num_classes_per_task=5,
        # Resize the images to 28x28 and convert them to PyTorch tensors (from Torchvision)
        transform=Compose([Resize(28), ToTensor()]),
        # Transform the labels to integers (e.g. ("Glagolitic/character01", "Sanskrit/character14", ...) to (0, 1, ...))
        target_transform=Categorical(num_classes=5),
        # Create new virtual classes with rotated versions of the images (from Santoro et al., 2016)
        class_augmentations=[Rotation([90, 180, 270])],
        meta_train=True,
        download=True,
    ),
    shuffle=True,
    num_train_per_class=1,
    num_test_per_class=1,
)
tm_dataloader2 = BatchMetaDataLoader(tm_dataset2, batch_size=32)

# Keep the generic names bound to the same objects as well.
dataset, dataloader = tm_dataset2, tm_dataloader2

## Existing (in-repo) data loaders

In [4]:
from types import SimpleNamespace
from utils import init_dataset

# Settings object handed to the project-local `init_dataset`.
options = SimpleNamespace(
    **{
        "dataset": "omniglot",
        "num_cls": 5,
        "num_samples": 1,
        "iterations": 10000,
        "batch_size": 32,
    }
)

# `init_dataset` returns four objects, unpacked here under the
# train / val / trainval / test loader names.
(
    tr_dataloader,
    val_dataloader,
    trainval_dataloader,
    test_dataloader,
) = init_dataset(options)

# Second name for the training loader.
old_dataloader = tr_dataloader

== Dataset: Found 82240 items 
== Dataset: Found 4112 classes
== Dataset: Found 13760 items 
== Dataset: Found 688 classes
== Dataset: Found 96000 items 
== Dataset: Found 4800 classes
== Dataset: Found 33840 items 
== Dataset: Found 1692 classes


In [8]:
import torch
from datasets import OmniglotDataset

# Project-local Omniglot dataset in "train" mode, wrapped in a plain PyTorch
# DataLoader that yields batches of 32 items.
train_dataset = OmniglotDataset(mode="train")
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=32,
)

== Dataset: Found 82240 items 
== Dataset: Found 4112 classes


## Comparison

In [10]:
from tqdm import tqdm

In [11]:
%%time
# Time exactly `n_batches` batches from the existing training loader.
# `zip` consults `range(n_batches)` first and stops as soon as it is exhausted,
# so the loader is never asked for an extra batch (an `enumerate` + `break`
# loop would fetch batch number n_batches + 1 before breaking, and that fetch
# would be included in the measured time).
n_batches = 1000
for batch_idx, batch in tqdm(zip(range(n_batches), tr_dataloader), total=n_batches):
    pass  # only the cost of producing batches is being measured

100%|██████████| 1000/1000 [00:07<00:00, 125.38it/s]

CPU times: user 8 s, sys: 21.1 ms, total: 8.02 s
Wall time: 7.98 s





In [12]:
%%time
# Time exactly `n_batches` task batches from the helper-built torchmeta dataset,
# loaded with 8 worker processes.
# `zip` consults `range(n_batches)` first and stops as soon as it is exhausted,
# so the loader is never asked for an extra batch (an `enumerate` + `break`
# loop would fetch batch number n_batches + 1 before breaking, and that fetch
# would be included in the measured time).
n_batches = 1000
dataloader = BatchMetaDataLoader(tm_dataset1, batch_size=32, num_workers=8)
for batch_idx, batch in tqdm(zip(range(n_batches), dataloader), total=n_batches):
    pass  # only the cost of producing batches is being measured

100%|██████████| 1000/1000 [00:43<00:00, 23.07it/s]


CPU times: user 1.95 s, sys: 890 ms, total: 2.84 s
Wall time: 44.1 s


In [13]:
%%time
# Time exactly `n_batches` task batches from the hand-assembled torchmeta
# dataset, loaded with 8 worker processes.
# `zip` consults `range(n_batches)` first and stops as soon as it is exhausted,
# so the loader is never asked for an extra batch (an `enumerate` + `break`
# loop would fetch batch number n_batches + 1 before breaking, and that fetch
# would be included in the measured time).
n_batches = 1000
dataloader = BatchMetaDataLoader(tm_dataset2, batch_size=32, num_workers=8)
for batch_idx, batch in tqdm(zip(range(n_batches), dataloader), total=n_batches):
    pass  # only the cost of producing batches is being measured

100%|██████████| 1000/1000 [00:42<00:00, 23.38it/s]


CPU times: user 1.95 s, sys: 931 ms, total: 2.88 s
Wall time: 43.5 s
