In [None]:
# Dataset
class DigitsDataset(Dataset):
    """PyTorch dataset over one stratified split of scikit-learn's digits data.

    The full digits set is split into train / val / test with two stratified
    ``train_test_split`` calls that share ``random_state``, so instances built
    with the same arguments but different ``mode`` see disjoint subsets.
    Features are the 64 pixel values (8x8 image flattened), divided by 16.

    Args:
        mode: which split to expose: 'train', 'val' or 'test'.
        val_ratio: fraction of all samples used for validation (must be > 0).
        test_ratio: fraction of all samples used for testing (must be > 0);
            ``val_ratio + test_ratio`` must be < 1.
        transform: optional callable applied to a *copy* of each feature row;
            when given, items are ``(x, transform(x), label)`` instead of
            ``(x, label)``.
        random_state: seed forwarded to both ``train_test_split`` calls.

    Raises:
        ValueError: on an unknown ``mode`` or out-of-range ratios.
    """

    _MODES = ('train', 'val', 'test')

    def __init__(self, mode='train', val_ratio=0.1, test_ratio=0.2, transform=None, random_state=42):
        # Validate before loading anything; `assert` would vanish under -O.
        if mode not in self._MODES:
            raise ValueError(f"mode must be 'train', 'val' or 'test', got {mode!r}")
        if val_ratio <= 0 or test_ratio <= 0 or val_ratio + test_ratio >= 1:
            raise ValueError(
                "val_ratio and test_ratio must be > 0 with val_ratio + test_ratio < 1, "
                f"got val_ratio={val_ratio!r}, test_ratio={test_ratio!r}")
        self.mode = mode
        self.transform = transform

        # Load the data: 64 features per sample (8x8 image flattened).
        digits = load_digits()
        X = digits.data.astype(np.float32)
        y = digits.target.astype(np.int64)

        # Train/Val/Test split: first carve off the held-out pool, then split
        # that pool into val and test, stratifying by label both times.
        X_train, X_temp, y_train, y_temp = train_test_split(
            X, y, test_size=val_ratio + test_ratio, stratify=y, random_state=random_state)

        X_val, X_test, y_val, y_test = train_test_split(
            X_temp, y_temp,
            test_size=self._test_fraction_of_holdout(val_ratio, test_ratio),
            stratify=y_temp, random_state=random_state)

        splits = {
            'train': (X_train, y_train),
            'val': (X_val, y_val),
            'test': (X_test, y_test),
        }
        X_selected, self.y = splits[mode]

        # X is already float32; dividing by a Python float keeps float32.
        self.X = X_selected / 16.0

    @staticmethod
    def _test_fraction_of_holdout(val_ratio, test_ratio):
        """Return the share of the held-out pool assigned to the test split.

        Computed directly as ``test / (val + test)`` rather than
        ``1 - val / (val + test)``. The subtraction adds rounding error that
        sklearn's ``ceil(test_size * n)`` can turn into an off-by-one. For
        example, ``1 - 0.1 / (0.1 + 0.2)`` lands slightly above 2/3, so a
        540-sample pool gets 361 test samples instead of 360.
        """
        return test_ratio / (val_ratio + test_ratio)

    @staticmethod
    def _to_float_tensor(value):
        """Return a float32 tensor copy of a transform's output.

        Tensors are detached and cloned explicitly (``torch.tensor`` on a tensor
        emits a copy-construct warning); anything else goes through
        ``torch.tensor``.
        """
        if isinstance(value, torch.Tensor):
            return value.detach().clone().to(torch.float32)
        return torch.tensor(value, dtype=torch.float32)

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return ``(x, label)``, or ``(x, transform(x), label)`` if a transform is set."""
        x = self.X[idx]
        y = self.y[idx]
        x_tensor = torch.tensor(x, dtype=torch.float32)
        y_tensor = torch.tensor(y, dtype=torch.long)

        if self.transform is None:
            return x_tensor, y_tensor

        # self.X[idx] is a view into the dataset; hand the transform a copy so
        # in-place augmentations cannot corrupt stored samples or the returned x.
        x_transformed = self.transform(x.copy())
        return x_tensor, self._to_float_tensor(x_transformed), y_tensor