In [10]:
import os
import random
import numpy as np
import torch
from torchvision import datasets, transforms

# ---- CONFIG ----
ROOT = "/Users/ishabhansali/Downloads/resnet_project/cleaned"  # Adjust if your notebook is not in RESNET_PROJECT root
TRAIN_DIR = os.path.join(ROOT, "Training")
TEST_DIR = os.path.join(ROOT, "Testing")
NUM_CLIENTS = 4
LOCAL_TEST_FRAC = 0.2
SEED = 77

# ---- REPRODUCIBILITY ----
def set_seed(s=SEED):
    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)
    torch.cuda.manual_seed_all(s)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
set_seed()

# ---- PATH CHECKS ----
if not os.path.exists(ROOT):
    raise FileNotFoundError(f"ROOT folder not found: {ROOT}")
if not os.path.exists(TRAIN_DIR):
    raise FileNotFoundError(f"Training directory not found: {TRAIN_DIR}")
if not os.path.exists(TEST_DIR):
    raise FileNotFoundError(f"Testing directory not found: {TEST_DIR}")

print("Training classes:", os.listdir(TRAIN_DIR))
print("Testing classes:", os.listdir(TEST_DIR))

# ---- TRANSFORM AND LOAD ----
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.ImageFolder(TRAIN_DIR, transform=transform)
test_dataset = datasets.ImageFolder(TEST_DIR, transform=transform)
print(f"Train samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")

# ---- IID PARTITION ----
num_items = len(train_dataset)
client_lengths = [num_items // NUM_CLIENTS] * NUM_CLIENTS
client_lengths[0] += num_items - sum(client_lengths)  # Add remainder to first client

from torch.utils.data import random_split

client_subsets = random_split(train_dataset, client_lengths, generator=torch.Generator().manual_seed(SEED))
print("Client samples:", [len(subset) for subset in client_subsets])

# ---- TEST SPLIT FOR EACH CLIENT ----
client_data = []
for c in client_subsets:
    n_test = int(len(c) * LOCAL_TEST_FRAC)
    n_train = len(c) - n_test
    tr, te = random_split(c, [n_train, n_test], generator=torch.Generator().manual_seed(SEED))
    client_data.append((tr, te))
    print(f"Client train: {n_train}, test: {n_test}")

# ---- SAMPLE ACCESS ----
image, label = client_data[0][0][0]  # First sample from client 0 train set
print("First client, first train label:", label)


Training classes: ['pituitary', 'notumor', 'glioma', 'meningioma']
Testing classes: ['pituitary', 'notumor', 'glioma', 'meningioma']
Train samples: 5712, Test samples: 1311
Client samples: [1428, 1428, 1428, 1428]
Client train: 1143, test: 285
Client train: 1143, test: 285
Client train: 1143, test: 285
Client train: 1143, test: 285
First client, first train label: 0


In [11]:
# Get class names from the underlying ImageFolder
class_names = train_dataset.classes

# Count samples per class for each client
for idx, (client_train, _) in enumerate(client_data):
    # Extract all labels for this client's training subset
    labels = [client_train[i][1] for i in range(len(client_train))]
    # Count occurrences of each label
    counts = {class_name: labels.count(i) for i, class_name in enumerate(class_names)}
    print(f"Client {idx}:")
    for class_name in class_names:
        print(f"  {class_name}: {counts[class_name]} samples")
    print("-" * 30)


Client 0:
  glioma: 243 samples
  meningioma: 294 samples
  notumor: 310 samples
  pituitary: 296 samples
------------------------------
Client 1:
  glioma: 303 samples
  meningioma: 242 samples
  notumor: 309 samples
  pituitary: 289 samples
------------------------------
Client 2:
  glioma: 259 samples
  meningioma: 265 samples
  notumor: 310 samples
  pituitary: 309 samples
------------------------------
Client 3:
  glioma: 253 samples
  meningioma: 260 samples
  notumor: 344 samples
  pituitary: 286 samples
------------------------------
