In [8]:
import pandas as pd
import os
import torch
from torch.utils.data import Dataset
from torchvision.io import read_image

In [12]:
class LegoDataset(Dataset):
    """Custom dataset class for Lego Minifigure dataset.

    Each sample is an ``(image, label)`` pair: ``image`` is what
    ``torchvision.io.read_image`` returns for the row's ``path`` (relative to
    the split directory), ``label`` is the row's ``class_id``. Either may be
    post-processed by ``transform`` / ``target_transform``.
    """

    def __init__(self, img_dir, test=False, transform=None, target_transform=None):
        """init method.

        Keyword arguments:
        img_dir -- the path to the root image directory of test and train data;
                   the ``test/`` or ``train/`` sub-directory of it is used
        test -- True to load test data (default: False)
        transform -- transform to apply to X
        target_transform -- transform to apply to y

        The split directory must contain ``metadata.csv`` plus ``test.csv``
        (test split) or ``index.csv`` (train split). The index file needs
        ``path`` and ``class_id`` columns; ``metadata.csv`` needs ``class_id``.
        Index rows whose ``class_id`` is absent from ``metadata.csv`` are
        dropped (inner join).

        Raises pandas.errors.MergeError if ``metadata.csv`` lists the same
        ``class_id`` more than once.
        """
        self.transform = transform
        self.target_transform = target_transform
        self.test = test
        self.img_dir = os.path.join(img_dir, "test/" if test else "train/")

        # read path names and class names from csv files
        meta_df = pd.read_csv(os.path.join(self.img_dir, "metadata.csv"))
        index_file = "test.csv" if test else "index.csv"
        index_df = pd.read_csv(os.path.join(self.img_dir, index_file))
        # many_to_one: a duplicated class_id in metadata.csv would otherwise
        # silently duplicate every image of that class in the dataset.
        self.full_df = index_df.merge(meta_df, on="class_id", validate="many_to_one")

    def __len__(self):
        """Number of (image, label) samples in this split."""
        return len(self.full_df)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at integer position ``idx``."""
        # Samplers yield plain ints, but accept a one-element tensor as well;
        # DataFrame.iloc rejects tensor indices.
        if torch.is_tensor(idx):
            idx = idx.item()
        row = self.full_df.iloc[idx]
        image = read_image(os.path.join(self.img_dir, row["path"]))
        # NOTE(review): raw class_id from the CSV; if ids are not 0-based,
        # remap before using a loss that expects targets in [0, C) — confirm.
        label = row["class_id"]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label
    
    

In [13]:
# load datasets
DATA_DIR = "data/"  # root containing the train/ and test/ split directories

train_dataset = LegoDataset(DATA_DIR, test=False)
test_dataset = LegoDataset(DATA_DIR, test=True)

# fail fast if an index CSV was empty or every row was dropped by the metadata join
assert len(train_dataset) > 0, "training split is empty"
assert len(test_dataset) > 0, "test split is empty"

In [18]:
# Peek at the first training sample; the image tensor is channels-first (C, H, W).
img, label = train_dataset[0]
img.size()

torch.Size([3, 512, 512])