In [None]:
class CIFAR10Dataset(Dataset):
    """CIFAR-10 with a reproducible, stratified train/val split of the official training set.

    Images are converted to tensors with ``transforms.ToTensor()``. The
    ``'train'`` and ``'val'`` modes are disjoint subsets of the official
    training set (split with ``train_test_split``, stratified by label and
    seeded by ``random_state``); ``'test'`` is the official test set.

    Args:
        mode: Which split to expose: ``'train'``, ``'val'`` or ``'test'``.
        val_ratio: Fraction of the official training set held out for
            validation (passed as ``test_size`` to ``train_test_split``).
            Ignored when ``mode == 'test'``.
        transform: Optional callable applied to the image tensor in
            ``__getitem__``. Its output must support ``.to(torch.float32)``.
        root: Directory where CIFAR-10 is downloaded / cached.
        random_state: Seed for the train/val split; use the same value (and
            the same ``val_ratio``) for the ``'train'`` and ``'val'``
            instances so that they are complementary.

    Raises:
        ValueError: If ``mode`` is not ``'train'``, ``'val'`` or ``'test'``.
    """

    def __init__(self,
                 mode='train',
                 val_ratio=0.1,
                 transform=None,
                 root='./data',
                 random_state=42):
        # Explicit check instead of `assert`, which is stripped under `python -O`
        # (an invalid mode would then silently behave like 'test').
        if mode not in ('train', 'val', 'test'):
            raise ValueError("mode must be 'train', 'val' or 'test'")
        self.transform = transform

        # Base transform: image -> float tensor.
        base_transform = transforms.Compose([
            transforms.ToTensor()
        ])

        if mode == 'test':
            # The test split does not depend on the train/val partition, so the
            # training set does not need to be loaded at all.
            self.dataset = datasets.CIFAR10(root=root, train=False, download=True,
                                            transform=base_transform)
            return

        full_train = datasets.CIFAR10(root=root, train=True, download=True,
                                      transform=base_transform)

        # Stratified train/val split: each class keeps approximately the same
        # proportion in both subsets, and the fixed seed makes the split
        # identical across instances built with the same arguments.
        train_idx, val_idx = train_test_split(
            np.arange(len(full_train)),
            test_size=val_ratio,
            random_state=random_state,
            stratify=np.asarray(full_train.targets),
        )
        self.dataset = Subset(full_train, train_idx if mode == 'train' else val_idx)

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(x, y)``, or ``(x, transform(x), y)`` when a transform is set.

        ``x`` is the untransformed image tensor and ``y`` its label; the
        transformed view is cast to ``torch.float32``. Note that the tuple
        length depends on whether a transform was supplied.
        """
        x, y = self.dataset[idx]

        # `is not None` rather than truthiness: a valid transform may be falsy
        # (e.g. an empty `torch.nn.Sequential`, whose `len()` is 0).
        if self.transform is None:
            return x, y

        x_transformed = self.transform(x).to(torch.float32)
        return x, x_transformed, y