In [None]:
# Dataset
class SwissRollDataset(Dataset):
    """Labelled Swiss-roll point cloud split into train / val / test subsets.

    The full point set is generated once (deterministically via
    ``random_state``), labelled into four classes by binning the roll
    parameter ``t``, split with stratification, and only the requested subset
    is kept. Features are min-max rescaled to [0, 1] using a single global
    min/max (over all coordinates, not per feature) taken from the *training*
    split, so every split shares the same mapping. Validation/test points may
    therefore fall slightly outside [0, 1].

    Args:
        mode: Which split to expose: ``'train'``, ``'val'`` or ``'test'``.
        n_samples: Total number of points generated before splitting.
        noise: Noise level forwarded to ``make_swiss_roll``.
        val_ratio: Fraction of all samples assigned to the validation split.
        test_ratio: Fraction of all samples assigned to the test split.
        transform: Optional callable applied to a single feature row; when
            set, ``__getitem__`` returns ``(x, transform(x), y)``.
        random_state: Seed used for data generation and for both splits.
    """

    def __init__(self, mode='train', n_samples=1000, noise=0.1, val_ratio=0.1, test_ratio=0.3, transform=None, random_state=42):
        assert mode in ['train', 'val', 'test'], "mode must be 'train', 'val' or 'test'"
        self.mode = mode
        self.transform = transform

        # Generate the Swiss-roll points together with their roll parameter t.
        data, t = make_swiss_roll(n_samples=n_samples, noise=noise, random_state=random_state)
        data = data.astype(np.float32)

        # Class labels from binning t: t<6 -> 0, [6,9) -> 1, [9,12) -> 2, t>=12 -> 3.
        labels = np.zeros_like(t, dtype=np.int64)
        labels[(t >= 6) & (t < 9)] = 1
        labels[(t >= 9) & (t < 12)] = 2
        labels[t >= 12] = 3

        # Carve off val+test first, then split that remainder so the final
        # sizes correspond to val_ratio / test_ratio of the whole set.
        X_train, X_temp, y_train, y_temp = train_test_split(
            data, labels, test_size=val_ratio + test_ratio,
            random_state=random_state, stratify=labels)
        val_size = val_ratio / (val_ratio + test_ratio)
        X_val, X_test, y_val, y_test = train_test_split(
            X_temp, y_temp, test_size=1 - val_size,
            random_state=random_state, stratify=y_temp)

        splits = {
            'train': (X_train, y_train),
            'val': (X_val, y_val),
            'test': (X_test, y_test),
        }
        self.X, self.y = splits[mode]

        # Normalize with the *training* split's range for every mode. Rescaling
        # each split by its own min/max would map identical raw points to
        # different values across splits and leak held-out statistics.
        self.X = self.rescale(torch.from_numpy(self.X), 0, 1, ref=torch.from_numpy(X_train)).numpy()

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return ``(x, y)``, or ``(x, transform(x), y)`` when a transform is set.

        ``x`` and the transformed view are float32 tensors; ``y`` is a tensor
        built from the stored int64 label.
        """
        x, y = self.X[idx], self.y[idx]
        if self.transform:
            x_ = self.transform(x)
            if torch.is_tensor(x_):
                # Copy explicitly instead of torch.tensor(tensor), which warns.
                x_ = x_.detach().clone().to(torch.float32)
            else:
                x_ = torch.tensor(x_, dtype=torch.float32)
            return torch.tensor(x, dtype=torch.float32), x_, torch.tensor(y)
        return torch.tensor(x, dtype=torch.float32), torch.tensor(y)

    def rescale(self, x, lo, hi, ref=None):
        """Linearly map a tensor onto [lo, hi].

        Args:
            x: Tensor to rescale.
            lo: Lower bound of the target range.
            hi: Upper bound of the target range; must be greater than ``lo``.
            ref: Optional tensor whose global min/max define the source range.
                Defaults to ``x`` itself. Pass the training data to apply one
                consistent mapping to other splits.

        Returns:
            A new tensor (``x`` is not modified). If the source range has zero
            width (all reference values equal), every element maps to the
            midpoint of [lo, hi] rather than NaN.
        """
        assert(lo < hi), "[rescale] lo={0} must be smaller than hi={1}".format(lo,hi)
        if ref is None:
            ref = x
        old_min = torch.min(ref)
        old_width = torch.max(ref) - old_min
        old_center = old_min + (old_width / 2.)
        new_width = float(hi - lo)
        new_center = lo + (new_width / 2.)
        if old_width == 0:
            # Degenerate range: (x - c) * (w / 0) would be 0 * inf = NaN.
            return torch.full_like(x, new_center)
        # Shift to zero, scale to the new width, then shift to the new centre.
        return (x - old_center) * (new_width / old_width) + new_center