In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import random
from torch.utils.data import DataLoader, Subset
from enum import Enum
from dataclasses import dataclass
from typing import Union


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.2.3 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "/Users/pabramov/opt/anaconda3/envs/fc/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/Users/pabramov/opt/anaconda3/envs/fc/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/Users/pabramov/opt/anaconda3/envs/fc/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/Users/pabramov/opt/anaconda3/envs/fc/lib/python3.10/site-packages/traitlets/config/application.

In [3]:
class BaseLoader:
    """Common base of the partitioning-strategy markers; exists for type hinting only."""

@dataclass
class IIDLoader(BaseLoader):
    """Marker selecting the IID strategy: a random, equal-sized split across clients."""

@dataclass
class NonIIDLoader(BaseLoader):
    """Marker selecting the class-based non-IID strategy.

    Attributes:
        per_client: how many classes are assigned to each client.
    """

    per_client: int  # number of classes per client
        
@dataclass
class SortedByBrightnessLoader(BaseLoader):
    """Marker selecting the feature-based non-IID strategy: images sorted by mean brightness."""
        
# Type alias covering every partitioning strategy defined above; used to
# annotate the `t` argument of get_cifar10_loaders.
LoaderType = Union[IIDLoader, NonIIDLoader, SortedByBrightnessLoader]
        
def get_cifar10_loaders(t: LoaderType, num_clients=10, batch_size=32, transform=None):
    if transform is None:
        transform = transforms.Compose([transforms.ToTensor()])
    
    trainset = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform=transform)
    
    train_data, train_labels = trainset.data, np.array(trainset.targets)
    test_data, test_labels = testset.data, np.array(testset.targets)
    
    indices = np.arange(len(train_labels))
    np.random.shuffle(indices)
    split_sizes = [len(indices) // num_clients] * num_clients
    split_sizes[-1] += len(indices) % num_clients  # Handle remainder
    client_splits = np.split(indices, np.cumsum(split_sizes)[:-1])
    for i, split in enumerate(client_splits):
        client_data_indices[i] = split.tolist()
            
    client_data_indices = {i: [] for i in range(num_clients)}
    
    match t:
        # Normal IID case
        case IIDLoader():
            indices = np.arange(len(train_labels))
            np.random.shuffle(indices)
            split_sizes = [len(indices) // num_clients] * num_clients
            split_sizes[-1] += len(indices) % num_clients  # Handle remainder
            client_splits = np.split(indices, np.cumsum(split_sizes)[:-1])
            for i, split in enumerate(client_splits):
                client_data_indices[i] = split.tolist()                
        
        # Class-based Non-IID: Assign 2-3 classes per client
        case NonIIDLoader(per_client):
            class_indices = {i: np.where(train_labels == i)[0] for i in range(10)}
            class_per_client = [random.sample(range(10), k=per_client) for _ in range(num_clients)]
            
            for i, classes in enumerate(class_per_client):
                for cls in classes:
                    client_data_indices[i].extend(np.random.choice(class_indices[cls], size=len(class_indices[cls])))
                                                                   
        # Feature-based (domain shift) Non-IID: Sort by brightness level
        case SortedByBrightnessLoader():
            brightness_levels = np.mean(train_data, axis=(1, 2, 3))  # Compute mean brightness per image
            sorted_indices = np.argsort(brightness_levels)
            split_sizes = [len(sorted_indices) // num_clients] * num_clients
            split_sizes[-1] += len(sorted_indices) % num_clients  # Handle remainder
            client_splits = np.split(sorted_indices, np.cumsum(split_sizes)[:-1])
            for i, split in enumerate(client_splits):
                client_data_indices[i] = split.tolist()
                                                                   
    client_loaders = {}
    for i in range(num_clients):
        subset = Subset(trainset, client_data_indices[i])
        client_loaders[i] = DataLoader(subset, batch_size=batch_size, shuffle=True)
    
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)
    
    return client_loaders, test_loader