# Data creator: stratified MNIST splits (shared global test set + per-client train/test sets)


In [None]:
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit
from tensorflow.keras.datasets import mnist
import os
from collections import Counter

# --- Colab setup: mount Google Drive so the generated parquet files persist ---
from google.colab import drive
drive.mount('/content/drive')

# Destination folder for every subset written below; created if it is missing.
SAVE_DIR = '/content/drive/MyDrive/models_cc/subsets'
os.makedirs(SAVE_DIR, exist_ok=True)

# --- Load MNIST from Keras ---
# load_data() returns a (train, test) pair of (images, labels). Only the first
# pair, (X, y), is split further below.
# NOTE(review): the second pair is bound to X_val / y_val but never used in this
# cell, so the official MNIST test split is ignored — confirm that is intended.
(X, y), (X_val, y_val) = mnist.load_data()
X = X.reshape(len(X), 28 * 28)  # one row per image, 28 * 28 = 784 feature columns
y = np.asarray(y, dtype=int)

# --- Step 1: Global Test Set ---
# Hold out a label-stratified set of 10,000 rows shared by every client. The rows
# not held out (X_remaining / y_remaining) are what Step 2 distributes to clients.
sss = StratifiedShuffleSplit(n_splits=1, test_size=10000, random_state=42)
[(train_idx, global_test_idx)] = sss.split(X, y)  # n_splits=1 -> exactly one pair

X_remaining = X[train_idx]
y_remaining = y[train_idx]
X_global_test = X[global_test_idx]
y_global_test = y[global_test_idx]

# Feature columns first, then the label as the final column.
df_global = pd.DataFrame(X_global_test).assign(label=y_global_test)
df_global.to_parquet(os.path.join(SAVE_DIR, 'mnist_global_test.parquet'), index=False)
print("✅ Saved global test set")

# --- Step 2: Stratified split into N_CLIENTS (5) disjoint clients ---
# StratifiedKFold.split yields (other_folds_idx, held_out_fold_idx) pairs.
# Across all iterations the held-out folds partition the remaining samples into
# N_CLIENTS disjoint, label-stratified shards, so each client owns one of them.
# Taking the FIRST element instead would give every client the other
# N_CLIENTS - 1 folds ((N_CLIENTS - 1) / N_CLIENTS of the pool each), with
# clients overlapping heavily. That would invalidate any per-client evaluation.
N_CLIENTS = 5
skf = StratifiedKFold(n_splits=N_CLIENTS, shuffle=True, random_state=42)

for client_id, (_, client_idx) in enumerate(skf.split(X_remaining, y_remaining), 1):
    X_client, y_client = X_remaining[client_idx], y_remaining[client_idx]

    # --- Step 3: Local train/test split for each client (stratified) ---
    # 80/20 split of the client's shard. Seeding with client_id keeps each split
    # reproducible but different per client. Local index names avoid clobbering
    # the Step 1 `train_idx`.
    sss_local = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=client_id)
    local_train_idx, local_test_idx = next(sss_local.split(X_client, y_client))

    X_train, y_train = X_client[local_train_idx], y_client[local_train_idx]
    X_test, y_test = X_client[local_test_idx], y_client[local_test_idx]

    # Plain-int keys sorted by label read better than a Counter keyed by np.int64.
    test_counts = dict(sorted(Counter(y_test.tolist()).items()))
    print(f"Client {client_id}: {len(y_train)} train / {len(y_test)} test samples")
    print(f"Client {client_id} test class counts:", test_counts)

    # Save to Parquet inside Drive (feature columns + a trailing 'label' column)
    df_train = pd.DataFrame(X_train)
    df_train['label'] = y_train
    df_train.to_parquet(os.path.join(SAVE_DIR, f'mnist_client{client_id}_train.parquet'), index=False)

    df_test = pd.DataFrame(X_test)
    df_test['label'] = y_test
    df_test.to_parquet(os.path.join(SAVE_DIR, f'mnist_client{client_id}_test.parquet'), index=False)

    print(f"✅ Saved client {client_id} train/test sets")

print("\n🚀 All stratified splits saved to Google Drive!")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✅ Saved global test set
Client 1 test class counts: Counter({np.int64(1): 899, np.int64(7): 835, np.int64(3): 818, np.int64(2): 794, np.int64(9): 793, np.int64(0): 790, np.int64(6): 789, np.int64(8): 780, np.int64(4): 779, np.int64(5): 723})
✅ Saved client 1 train/test sets
Client 2 test class counts: Counter({np.int64(1): 899, np.int64(7): 835, np.int64(3): 818, np.int64(2): 794, np.int64(9): 793, np.int64(0): 790, np.int64(6): 789, np.int64(8): 780, np.int64(4): 779, np.int64(5): 723})
✅ Saved client 2 train/test sets
Client 3 test class counts: Counter({np.int64(1): 899, np.int64(7): 835, np.int64(3): 818, np.int64(2): 794, np.int64(9): 793, np.int64(0): 790, np.int64(6): 789, np.int64(8): 780, np.int64(4): 779, np.int64(5): 723})
✅ Saved client 3 train/test sets
Client 4 test class counts: Counter({np.int64(1): 899, np.int64(7): 835, np.int64(3): 818, np.