In [None]:
# Dataset
class SpiralDataset(Dataset):
    """Noisy 3-D spiral point cloud, split into stratified train/val/test subsets.

    Points follow ``(t*cos t, t*sin t, t)`` for ``n_samples`` values of ``t``
    evenly spaced over ``[0, 4*pi]``, with Gaussian noise added to every
    coordinate. The label is the quarter of the ``t`` range a point came from
    (classes 0..3).

    Each instance regenerates the full dataset from ``random_state`` and then
    keeps one split. Building the ``'train'``, ``'val'`` and ``'test'``
    datasets with the same arguments and a fixed ``random_state`` therefore
    yields disjoint subsets of the *same* noisy points.

    Features are min-max scaled with statistics from the TRAINING split only.
    All splits share one scale; val/test values may fall slightly outside
    ``[0, 1]``.

    Args:
        mode: Split to expose: ``'train'``, ``'val'`` or ``'test'``.
        n_samples: Total number of points before splitting.
        noise: Standard deviation of the Gaussian noise added per coordinate.
        val_ratio: Fraction of all points assigned to the validation split.
        test_ratio: Fraction of all points assigned to the test split.
        transform: Optional callable. When given, ``__getitem__`` returns
            ``(x, transform(x), y)`` instead of ``(x, y)``.
        random_state: Seed for the noise and for both stratified splits.

    Raises:
        ValueError: If ``mode`` is not one of the supported splits.
    """

    _MODES = ('train', 'val', 'test')

    def __init__(self, mode='train', n_samples=1000, noise=0.8, val_ratio=0.1, test_ratio=0.2, transform=None, random_state=42):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if mode not in self._MODES:
            raise ValueError("mode must be 'train', 'val' or 'test'")
        self.mode = mode
        self.transform = transform

        X, labels = self._generate_spiral(n_samples, noise, random_state)

        # Split into Train / (Val + Test), then divide the held-out part.
        # Both splits are stratified so every subset keeps the class balance.
        X_train, X_temp, y_train, y_temp = train_test_split(
            X, labels, test_size=val_ratio + test_ratio,
            random_state=random_state, stratify=labels)
        val_size = val_ratio / (val_ratio + test_ratio)
        X_val, X_test, y_val, y_test = train_test_split(
            X_temp, y_temp, test_size=1 - val_size,
            random_state=random_state, stratify=y_temp)

        splits = {
            'train': (X_train, y_train),
            'val': (X_val, y_val),
            'test': (X_test, y_test),
        }
        X_split, y_split = splits[mode]

        # ---- Torch conversion ONCE ----
        self.X = torch.from_numpy(X_split).float()
        self.y = torch.from_numpy(y_split).long()

        # ---- Min–Max normalization ----
        # The statistics come from the training split, so val/test are scaled
        # exactly like the data the model is fit on (no leakage, one shared
        # scale). A single scalar min/max over all coordinates is used.
        self.x_min = float(X_train.min())
        self.x_max = float(X_train.max())
        self.X = (self.X - self.x_min) / (self.x_max - self.x_min + 1e-8)

    @staticmethod
    def _generate_spiral(n_samples, noise, random_state):
        """Return ``(X, labels)``: float32 points of shape (n_samples, 3) and int64 labels."""
        t = np.linspace(0, 4 * np.pi, n_samples)
        coords = np.stack([t * np.cos(t), t * np.sin(t), t], axis=1)

        # Seeded generator so the noise is reproducible and identical in every
        # split instance (the global np.random state ignored random_state).
        if isinstance(random_state, np.random.RandomState):
            rng = random_state
        else:
            rng = np.random.RandomState(random_state)
        coords += noise * rng.randn(*coords.shape)

        # Label = which quarter of [0, 4*pi] the parameter t falls in.
        labels = np.zeros_like(t, dtype=int)
        labels[(t >= np.pi) & (t < 2 * np.pi)] = 1
        labels[(t >= 2 * np.pi) & (t < 3 * np.pi)] = 2
        labels[t >= 3 * np.pi] = 3

        return coords.astype(np.float32), labels.astype(np.int64)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        """Return ``(x, y)``, or ``(x, transform(x), y)`` when a transform is set."""
        x, y = self.X[idx], self.y[idx]
        if self.transform is not None:
            return x, self.transform(x), y
        return x, y
