In [None]:
import os
import pathlib
import torch

from PIL import Image
from torch.utils.data import Dataset
from typing import Tuple, Dict, List

class ImageFolderCustom(Dataset):
    def __init__(self, root_dir: str, partition:str='train', transform=None) -> None:
        self.targ_dir = os.path.join(root_dir, partition)
        self.paths = list(pathlib.Path(self.targ_dir).glob("*/*.jpeg"))
        self.readable_classes_dict = extract_readable_imagenet_labels(os.path.join(root_dir, 'LOC_synset_mapping.txt'))
        self.transform = transform
        self.classes, self.class_to_idx = find_classes(self.targ_dir, self.readable_classes_dict)

    def load_image(self, index: int) -> Image.Image:
        "Opens an image via a path and returns it."
        image_path = self.paths[index]
        return Image.open(image_path) 
    
    def __len__(self) -> int:
        "Returns the total number of samples."
        return len(self.paths)
    
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        "Returns one sample of data, data and label (X, y)."
        img = self.load_image(index)
        class_name = self.paths[index].parent.name # expects path in data_folder/class_name/image.jpeg
        readable_class_name = self.readable_classes_dict[class_name]
        class_idx = self.class_to_idx[readable_class_name]

        if self.transform:
            return self.transform(img), class_idx
        else:
            return img, class_idx
        
def extract_readable_imagenet_labels(file_path: os.path) -> dict:
    """
    Helper function for storing imagenet human read-able 
    class mappings. Mapping downloaded from 
    https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a 
    """
    class_dict = {}

    with open(file_path, 'r') as file:
        for line in file:
            words = line.strip().split()
            class_dict[words[0]] = words[1].rstrip(',') # Incase there are several readable labels which are comma separated. 

    return class_dict


def find_classes(directory: str, readable_classes_dict: dict) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folder names in a target directory.
    
    Assumes target directory is in standard image classification format.

    Args:
        directory (str): target directory to load classnames from.

    Returns:
        Tuple[List[str], Dict[str, int]]: (list_of_class_names, dict(class_name: idx...))
    
    Example:
        find_classes("food_images/train")
        >>> (["class_1", "class_2"], {"class_1": 0, ...})
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    readable_classes = [readable_classes_dict.get(key) for key in classes]
    
    if not readable_classes:
        raise FileNotFoundError(f"Couldn't find any classes in {directory}.")
        
    class_to_idx = {cls_name: i for i, cls_name in enumerate(readable_classes)}
    return readable_classes, class_to_idx

In [27]:
train_data = ImageFolderCustom(r'C:\data\imagenet')
import torchvision.transforms as transforms
val_data = ImageFolderCustom(r'C:\data\imagenet', 'train', transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ]))
val_data.class_to_idx = train_data.class_to_idx

In [23]:
print(val_data.targ_dir)
val_data.paths

C:\data\imagenet\val


[WindowsPath('C:/data/imagenet/val/n01440764/ILSVRC2012_val_00000293.JPEG'),
 WindowsPath('C:/data/imagenet/val/n01440764/ILSVRC2012_val_00002138.JPEG'),
 WindowsPath('C:/data/imagenet/val/n01440764/ILSVRC2012_val_00003014.JPEG'),
 WindowsPath('C:/data/imagenet/val/n01440764/ILSVRC2012_val_00006697.JPEG'),
 WindowsPath('C:/data/imagenet/val/n01440764/ILSVRC2012_val_00007197.JPEG'),
 WindowsPath('C:/data/imagenet/val/n01440764/ILSVRC2012_val_00009111.JPEG'),
 WindowsPath('C:/data/imagenet/val/n01440764/ILSVRC2012_val_00009191.JPEG'),
 WindowsPath('C:/data/imagenet/val/n01440764/ILSVRC2012_val_00009346.JPEG'),
 WindowsPath('C:/data/imagenet/val/n01440764/ILSVRC2012_val_00009379.JPEG'),
 WindowsPath('C:/data/imagenet/val/n01440764/ILSVRC2012_val_00009396.JPEG'),
 WindowsPath('C:/data/imagenet/val/n01440764/ILSVRC2012_val_00010306.JPEG'),
 WindowsPath('C:/data/imagenet/val/n01440764/ILSVRC2012_val_00011233.JPEG'),
 WindowsPath('C:/data/imagenet/val/n01440764/ILSVRC2012_val_00011993.JPEG'),

In [29]:
img, label = val_data.__getitem__(5)
img.shape

torch.Size([3, 224, 224])