In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import ConcatDataset, Dataset, DataLoader
from PIL import Image
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [2]:
class CustomImageFolderDataset(Dataset):
    """Merge one or more ImageFolder-style roots into a single labelled dataset.

    Each root is scanned with ``torchvision.datasets.ImageFolder``. Class names
    are unioned across all roots and re-indexed in sorted order, so the same
    folder name maps to the same label regardless of which root it came from.

    Args:
        *paths: Root directories passed to ``ImageFolder`` (each is expected to
            hold one sub-folder per class).
        transform: Optional callable applied to each loaded image
            (e.g. a ``torchvision.transforms`` pipeline ending in ``ToTensor``).
        target_transform: Optional callable applied to each integer label.
        loader: Optional callable ``path -> image``. Defaults to opening the
            file with PIL and converting it to RGB.

    Attributes:
        classes: Sorted list of every class name found across ``paths``.
        class_to_idx: Mapping from class name to its index in ``classes``.
        imgs: List of ``(image_path, label)`` pairs using the merged indices.
    """

    def __init__(self, *paths, transform=None, target_transform=None, loader=None):
        self.datasets = [datasets.ImageFolder(root=path) for path in paths]
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader if loader is not None else self._pil_rgb_loader

        # Union of class names over all roots; sorting keeps the indices stable.
        all_classes = sorted({cls for ds in self.datasets for cls in ds.classes})
        self.classes = all_classes
        self.class_to_idx = {cls: idx for idx, cls in enumerate(all_classes)}

        # Each ImageFolder numbers its own classes; translate that local label
        # into the merged (global) label via the class name.
        self.imgs = [
            (img_path, self.class_to_idx[ds.classes[target]])
            for ds in self.datasets
            for img_path, target in ds.imgs
        ]

    @staticmethod
    def _pil_rgb_loader(path):
        """Open ``path`` with PIL and return it converted to RGB."""
        # Open the file ourselves inside a ``with`` so the handle is always
        # released; ``convert`` returns a fully loaded copy, so closing the
        # file afterwards is safe.
        with open(path, "rb") as f:
            img = Image.open(f)
            return img.convert("RGB")

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``, applying any transforms."""
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Total number of images across all roots."""
        return len(self.imgs)

In [3]:
# Root of the Kaggle "fruit-recognition" input, merged into one labelled dataset.
# NOTE(review): the saved output of the `custom_dataset.classes` cell is
# ['test', 'train'], i.e. the split folders directly under this root are being
# treated as class labels. The path(s) passed to CustomImageFolderDataset should
# presumably be the directories that contain one sub-folder per fruit class
# (e.g. the individual split directories) — TODO confirm the on-disk layout and
# pass those roots instead.
data_dir = '/kaggle/input/fruit-recognition'
custom_dataset = CustomImageFolderDataset(data_dir)

In [4]:
# Inspect the merged label set: one entry per class folder found under the roots.
# NOTE(review): the saved output below lists split names ('test', 'train'), not
# fruit classes, which suggests the root passed when building `custom_dataset`
# is one level too high — verify before training on these labels.
custom_dataset.classes

['test', 'train']