In [1]:
import torch
from torchvision import datasets, transforms
import os

In [2]:
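# Disabled: download and extract hymenoptera_data.zip from the PyTorch
# tutorial site. The cells below extract and use train_inf.zip instead.
#!wget https://download.pytorch.org/tutorial/hymenoptera_data.zip
#!unzip hymenoptera_data.zip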


In [3]:
!unzip /content/train_inf.zip

Archive:  /content/train_inf.zip
   creating: train_inf/train/
   creating: train_inf/train/ants_5/
  inflating: train_inf/train/ants_5/0013035.jpg  
  inflating: train_inf/train/ants_5/5650366_e22b7e1065.jpg  
  inflating: train_inf/train/ants_5/6240329_72c01e663e.jpg  
  inflating: train_inf/train/ants_5/6240338_93729615ec.jpg  
  inflating: train_inf/train/ants_5/6743948_2b8c096dda.jpg  
   creating: train_inf/train/bees_5/
  inflating: train_inf/train/bees_5/16838648_415acd9e3f.jpg  
  inflating: train_inf/train/bees_5/17209602_fe5a5a746f.jpg  
  inflating: train_inf/train/bees_5/21399619_3e61e5bb6f.jpg  
  inflating: train_inf/train/bees_5/29494643_e3410f0d37.jpg  
  inflating: train_inf/train/bees_5/36900412_92b81831ad.jpg  


In [4]:
from torch.utils.data import DataLoader
class InfiniteDataLoader(DataLoader):
    """A ``DataLoader`` that starts over instead of stopping.

    The loader is its own iterator: ``iter(loader)`` returns the loader and
    ``next(loader)`` returns the next batch. When the current pass over the
    dataset ends, a new pass is opened and its first batch is returned, so
    callers can keep requesting batches.

    Constructor arguments are forwarded unchanged to ``DataLoader``.
    If a newly opened pass yields nothing (e.g. an empty dataset),
    ``StopIteration`` still propagates to the caller.
    """

    def __init__(self, *args, **kwargs):
        """Create the loader and open the first pass over the dataset."""
        super().__init__(*args, **kwargs)
        # Iterator for the pass in progress; replaced whenever it runs out.
        self.dataset_iterator = super().__iter__()

    def __iter__(self):
        """Return the loader itself, since it implements ``__next__``."""
        return self

    def __next__(self):
        """Return the next batch, opening a new pass if the current one ended."""
        try:
            return next(self.dataset_iterator)
        except StopIteration:
            # The current pass is finished: open another one and serve its
            # first batch.
            self.dataset_iterator = super().__iter__()
            return next(self.dataset_iterator)

In [5]:
# Preprocessing pipelines keyed by split name.
#   'train': random crop resized to 224 plus random horizontal flip.
#   'val':   resize to 256 followed by a 224 centre crop.
# Both pipelines then convert to a tensor and normalise each channel with the
# same mean/std (these values match the commonly used ImageNet statistics;
# confirm against the model you intend to feed).
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(size=224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]),
    'val': transforms.Compose([
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]),
}

In [6]:
# Root of the extracted train_inf archive; each split (e.g. 'train') is a
# sub-folder of this directory.
data_dir = '/content/train_inf'

In [7]:
# One ImageFolder per split, rooted at <data_dir>/<split> and using that
# split's preprocessing pipeline. Only 'train' is built here.
image_datasets = {
    split: datasets.ImageFolder(
        root=os.path.join(data_dir, split),
        transform=data_transforms[split],
    )
    for split in ('train',)
}

In [8]:
# Wrap each split's dataset in an InfiniteDataLoader: shuffled batches of two
# images, loaded by four worker processes.
dataloaders = {
    split: InfiniteDataLoader(
        image_datasets[split],
        batch_size=2,
        shuffle=True,
        num_workers=4,
    )
    for split in ('train',)
}



In [9]:
# Number of images in each split's dataset.
dataset_sizes = {
    split: len(image_datasets[split])
    for split in ('train',)
}

In [10]:
# Display the image count per split.
dataset_sizes

{'train': 10}

In [11]:
# Class labels discovered by ImageFolder for the training split.
class_names = image_datasets['train'].classes


In [12]:
# Display the class labels.
class_names

['ants_5', 'bees_5']

In [13]:
# Draw ten batches from the training loader. With 10 images and a batch size
# of 2, one pass yields five batches, so this loop crosses a pass boundary and
# shows the loader restarting. The loader is its own iterator, so next() can
# be called on it directly.
for i in range(10):
    batch_images, batch_labels = next(dataloaders['train'])
    print(f'Batch {i+1}: {batch_images.shape}, {batch_labels.shape}')

Batch 1: torch.Size([2, 3, 224, 224]), torch.Size([2])
Batch 2: torch.Size([2, 3, 224, 224]), torch.Size([2])
Batch 3: torch.Size([2, 3, 224, 224]), torch.Size([2])
Batch 4: torch.Size([2, 3, 224, 224]), torch.Size([2])
Batch 5: torch.Size([2, 3, 224, 224]), torch.Size([2])
Batch 6: torch.Size([2, 3, 224, 224]), torch.Size([2])
Batch 7: torch.Size([2, 3, 224, 224]), torch.Size([2])
Batch 8: torch.Size([2, 3, 224, 224]), torch.Size([2])
Batch 9: torch.Size([2, 3, 224, 224]), torch.Size([2])
Batch 10: torch.Size([2, 3, 224, 224]), torch.Size([2])
