In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.optim.optimizer import Optimizer
import random
from torch.utils.data.dataset import Subset
from more_itertools import chunked
import numpy as np
import math



def distribute_data_dirichlet(
    targets, non_iid_alpha, n_workers, seed=0, num_auxiliary_workers=10
):
    """Partition sample indices across workers with a Dirichlet label skew.

    Code adapted from Tao Lin (partition_data.py), via
    https://github.com/epfml/relaysgd/tree/89719198ba227ebbff9a6bf5b61cb9baada167fd

    The shuffled (index, target) pairs are cut into consecutive chunks, one per
    group of at most ``num_auxiliary_workers`` workers. Within each chunk, every
    class is spread over that group's workers according to a
    Dirichlet(``non_iid_alpha``) draw. The whole draw is repeated until every
    worker in the group holds at least half of an even share of the chunk.

    Args:
        targets: per-sample class labels (list, array or CPU tensor). Labels
            are assumed to be the integers ``0 .. num_classes - 1``, since
            classes are enumerated with ``range(num_classes)``.
        non_iid_alpha: Dirichlet concentration parameter.
        n_workers: total number of workers.
        seed: seed of the ``np.random.RandomState`` used for shuffling and
            Dirichlet sampling.
        num_auxiliary_workers: maximum number of workers per group.

    Returns:
        A list of ``n_workers`` lists, each holding indices into ``targets``.
    """
    random_state = np.random.RandomState(seed=seed)

    targets = np.asarray(targets)
    num_indices = len(targets)
    num_classes = len(np.unique(targets))

    # Rows are (original_index, target); shuffling randomizes which samples
    # land in which group's chunk below.
    indices2targets = np.array(list(enumerate(targets)))
    random_state.shuffle(indices2targets)

    # Workers per group: full groups of `num_auxiliary_workers`, remainder last.
    num_splits = math.ceil(n_workers / num_auxiliary_workers)
    split_n_workers = [
        num_auxiliary_workers
        if idx < num_splits - 1
        else n_workers - num_auxiliary_workers * (num_splits - 1)
        for idx in range(num_splits)
    ]

    # Cut the shuffled rows into one contiguous chunk per group, sized in
    # proportion to the group's worker count; the last chunk also takes the
    # rounding remainder.
    splitted_targets = []
    from_index = 0
    for idx, group_n_workers in enumerate(split_n_workers):
        to_index = from_index + int(group_n_workers / n_workers * num_indices)
        end_index = num_indices if idx == num_splits - 1 else to_index
        splitted_targets.append(indices2targets[from_index:end_index])
        from_index = to_index

    idx_batch = []
    for chunk, _n_workers in zip(splitted_targets, split_n_workers):
        chunk_size = len(chunk)
        # Every worker of the group must end with at least half an even share.
        # NOTE(review): for very small `non_iid_alpha` many redraws may be
        # needed before this bound is met.
        min_required = int(0.50 * chunk_size / _n_workers)

        # Do-while: always draw at least once, so that chunks too small for the
        # bound to be positive are still distributed (the former `while` loop
        # was skipped entirely, leaving `_idx_batch` unbound or stale).
        while True:
            _idx_batch = [[] for _ in range(_n_workers)]
            min_size = 0
            for _class in range(num_classes):
                # Original indices (column 0) of this chunk's samples of `_class`.
                idx_class = chunk[chunk[:, 1] == _class, 0]

                proportions = random_state.dirichlet(
                    np.repeat(non_iid_alpha, _n_workers)
                )
                # Balance: workers already holding an even share get no more.
                proportions = np.array(
                    [
                        p * (len(idx_j) < chunk_size / _n_workers)
                        for p, idx_j in zip(proportions, _idx_batch)
                    ]
                )
                total = proportions.sum()
                if not total > 0:
                    # No eligible worker (or the draw underflowed): skip this
                    # class for this attempt. numpy division by zero returns
                    # NaN instead of raising, so the previous
                    # `except ZeroDivisionError` could never fire.
                    continue
                proportions = proportions / total
                split_points = (np.cumsum(proportions) * len(idx_class)).astype(int)[
                    :-1
                ]
                _idx_batch = [
                    idx_j + part.tolist()
                    for idx_j, part in zip(_idx_batch, np.split(idx_class, split_points))
                ]
                min_size = min(len(idx_j) for idx_j in _idx_batch)
            if min_size >= min_required:
                break
        idx_batch += _idx_batch
    return idx_batch


## https://github.com/epfml/relaysgd/tree/89719198ba227ebbff9a6bf5b61cb9baada167fd
def dirichlet_split(
    dataset,
    num_workers: int,
    alpha: float = 1,
    seed: int = 0,
    distribute_evenly: bool = True,
    num_auxiliary_workers: int = 10,
):
    """Split ``dataset`` into ``num_workers`` subsets with a Dirichlet label skew.

    Adapted from https://github.com/epfml/relaysgd/tree/89719198ba227ebbff9a6bf5b61cb9baada167fd

    Args:
        dataset: dataset exposing its per-sample labels as ``dataset.targets``.
        num_workers: number of subsets to produce.
        alpha: Dirichlet concentration forwarded to ``distribute_data_dirichlet``.
        seed: random seed forwarded to ``distribute_data_dirichlet``.
        distribute_evenly: if True, concatenate the per-worker index lists in
            worker order and re-cut them into ``num_workers`` contiguous pieces
            of nearly equal size, so subset sizes are balanced.
        num_auxiliary_workers: group size forwarded to
            ``distribute_data_dirichlet`` (previously hard-coded to 10).

    Returns:
        A list of ``Subset`` objects over ``dataset``.
    """
    indices_per_worker = distribute_data_dirichlet(
        dataset.targets,
        alpha,
        num_workers,
        num_auxiliary_workers=num_auxiliary_workers,
        seed=seed,
    )

    if distribute_evenly:
        # Force an integer dtype: concatenating an empty Python list would
        # otherwise promote the result to float64, which breaks indexing.
        all_indices = np.concatenate(
            [np.asarray(indices, dtype=np.int64) for indices in indices_per_worker]
        )
        indices_per_worker = np.array_split(all_indices, num_workers)

    return [Subset(dataset, indices) for indices in indices_per_worker]

    
def load_CIFAR10(n_node, alpha=1.0, batch=128, val_rate=0.2, seed=0):
    """Load CIFAR-10 and split its training set across nodes with Dirichlet label skew.

    The training set is partitioned into ``n_node`` subsets by
    ``dirichlet_split``. Every node keeps the same number of training samples,
    ``int((1 - val_rate) * smallest_node_size)``. All remaining samples of every
    node are pooled into one shared validation set.

    Args:
        n_node: number of nodes to split the training set across.
        alpha: Dirichlet concentration passed to ``dirichlet_split``.
        batch: unused; kept for backward compatibility.
        val_rate: fraction of the smallest node's data held out for validation.
        seed: seed for both the Dirichlet partition and the train/val split.

    Returns:
        dict with ``'train'`` (list of ``n_node`` per-node training datasets),
        ``'val'`` (pooled validation dataset, read with the non-augmenting test
        transform) and ``'test'`` (the CIFAR-10 test set).
    """
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        transforms.RandomErasing(),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    train_val_dataset = datasets.CIFAR10('../data',
                       train=True,
                       download=True,
                       transform=transform_train)

    # Same training images without augmentation; validation samples are read
    # through this view so evaluation is not perturbed by random crops, flips
    # or erasing. The files were already fetched by the call above.
    train_val_eval_dataset = datasets.CIFAR10('../data',
                       train=True,
                       transform=transform_test)

    test_dataset = datasets.CIFAR10('../data',
                       train=False,
                       transform=transform_test)

    # Split the training set into n_node datasets by a Dirichlet distribution.
    train_val_subset_list = dirichlet_split(train_val_dataset, n_node, alpha, seed=seed)

    # Every node keeps the same number of training samples.
    n_data = min(len(node_subset) for node_subset in train_val_subset_list)
    n_train = int((1.0 - val_rate) * n_data)

    # Seeded generator so the train/val split is reproducible for a given seed.
    generator = torch.Generator().manual_seed(seed)
    train_subset_list = []
    val_subsets = []
    for node_subset in train_val_subset_list:
        n_val = len(node_subset) - n_train
        train_part, val_part = torch.utils.data.random_split(
            node_subset, [n_train, n_val], generator=generator)
        train_subset_list.append(train_part)
        # Map positions inside the node subset back to dataset indices.
        val_indices = [int(node_subset.indices[i]) for i in val_part.indices]
        val_subsets.append(Subset(train_val_eval_dataset, val_indices))

    val_dataset = torch.utils.data.ConcatDataset(val_subsets)

    return {'train': train_subset_list, 'val': val_dataset, 'test': test_dataset}


def load_CIFAR100(n_node, alpha=1.0, batch=128, val_rate=0.2, seed=0):
    """Load CIFAR-100 and split its training set across nodes with Dirichlet label skew.

    The training set is partitioned into ``n_node`` subsets by
    ``dirichlet_split``. Every node keeps the same number of training samples,
    ``int((1 - val_rate) * smallest_node_size)``. All remaining samples of every
    node are pooled into one shared validation set.

    Args:
        n_node: number of nodes to split the training set across.
        alpha: Dirichlet concentration passed to ``dirichlet_split``.
        batch: unused; kept for backward compatibility.
        val_rate: fraction of the smallest node's data held out for validation.
        seed: seed for both the Dirichlet partition and the train/val split.

    Returns:
        dict with ``'train'`` (list of ``n_node`` per-node training datasets),
        ``'val'`` (pooled validation dataset, read with the non-augmenting test
        transform) and ``'test'`` (the CIFAR-100 test set).
    """
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        transforms.RandomErasing(),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    train_val_dataset = datasets.CIFAR100('../data',
                       train=True,
                       download=True,
                       transform=transform_train)

    # Same training images without augmentation; validation samples are read
    # through this view so evaluation is not perturbed by random crops, flips
    # or erasing. The files were already fetched by the call above.
    train_val_eval_dataset = datasets.CIFAR100('../data',
                       train=True,
                       transform=transform_test)

    test_dataset = datasets.CIFAR100('../data',
                       train=False,
                       transform=transform_test)

    # Split the training set into n_node datasets by a Dirichlet distribution.
    train_val_subset_list = dirichlet_split(train_val_dataset, n_node, alpha, seed=seed)

    # Every node keeps the same number of training samples.
    n_data = min(len(node_subset) for node_subset in train_val_subset_list)
    n_train = int((1.0 - val_rate) * n_data)

    # Seeded generator so the train/val split is reproducible for a given seed.
    generator = torch.Generator().manual_seed(seed)
    train_subset_list = []
    val_subsets = []
    for node_subset in train_val_subset_list:
        n_val = len(node_subset) - n_train
        train_part, val_part = torch.utils.data.random_split(
            node_subset, [n_train, n_val], generator=generator)
        train_subset_list.append(train_part)
        # Map positions inside the node subset back to dataset indices.
        val_indices = [int(node_subset.indices[i]) for i in val_part.indices]
        val_subsets.append(Subset(train_val_eval_dataset, val_indices))

    val_dataset = torch.utils.data.ConcatDataset(val_subsets)

    return {'train': train_subset_list, 'val': val_dataset, 'test': test_dataset}


def load_FMNIST(n_node, alpha=1.0, val_rate=0.2, seed=0):
    """Load Fashion-MNIST and split its training set across nodes with Dirichlet label skew.

    The training set is partitioned into ``n_node`` subsets by
    ``dirichlet_split``. Every node keeps the same number of training samples,
    ``int((1 - val_rate) * smallest_node_size)``. All remaining samples of every
    node are pooled into one shared validation set.

    Args:
        n_node: number of nodes to split the training set across.
        alpha: Dirichlet concentration passed to ``dirichlet_split``.
        val_rate: fraction of the smallest node's data held out for validation.
        seed: seed for both the Dirichlet partition and the train/val split.

    Returns:
        dict with ``'train'`` (list of ``n_node`` per-node training datasets),
        ``'val'`` (pooled validation dataset, read without the random-crop
        augmentation) and ``'test'`` (the Fashion-MNIST test set).
    """
    transform_train = transforms.Compose([
        transforms.RandomCrop(28, padding=4),
        transforms.ToTensor(),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_val_dataset = datasets.FashionMNIST('../data',
                       train=True,
                       download=True,
                       transform=transform_train)

    # Same training images without augmentation; validation samples are read
    # through this view so evaluation is not perturbed by random crops.
    # The files were already fetched by the call above.
    train_val_eval_dataset = datasets.FashionMNIST('../data',
                       train=True,
                       transform=transform_test)

    test_dataset = datasets.FashionMNIST('../data',
                       train=False,
                       transform=transform_test)

    # Split the training set into n_node datasets by a Dirichlet distribution.
    train_val_subset_list = dirichlet_split(train_val_dataset, n_node, alpha, seed=seed)

    # Every node keeps the same number of training samples.
    n_data = min(len(node_subset) for node_subset in train_val_subset_list)
    n_train = int((1.0 - val_rate) * n_data)

    # Seeded generator so the train/val split is reproducible for a given seed.
    generator = torch.Generator().manual_seed(seed)
    train_subset_list = []
    val_subsets = []
    for node_subset in train_val_subset_list:
        n_val = len(node_subset) - n_train
        train_part, val_part = torch.utils.data.random_split(
            node_subset, [n_train, n_val], generator=generator)
        train_subset_list.append(train_part)
        # Map positions inside the node subset back to dataset indices.
        val_indices = [int(node_subset.indices[i]) for i in val_part.indices]
        val_subsets.append(Subset(train_val_eval_dataset, val_indices))

    val_dataset = torch.utils.data.ConcatDataset(val_subsets)

    return {'train': train_subset_list, 'val': val_dataset, 'test': test_dataset}
                                

def datasets_to_loaders(datasets, batch_size=128, num_workers=2, pin_memory=False):
    """Wrap each split of ``datasets`` in a ``torch.utils.data.DataLoader``.

    Args:
        datasets: dict that must contain the keys ``"train"``, ``"all_train"``,
            ``"val"`` and ``"test"``. Each value must be a dataset accepted by
            ``DataLoader``; a missing key raises ``KeyError``. Inside this
            function the parameter shadows the ``torchvision.datasets`` module.
            NOTE(review): the ``load_*`` helpers in this notebook return no
            ``"all_train"`` key and a per-node list under ``"train"``, so their
            output cannot be passed here directly.
        batch_size: batch size of every loader.
        num_workers: worker processes per loader (previously hard-coded to 2).
        pin_memory: forwarded to every loader (previously hard-coded to False).

    Returns:
        dict with keys ``"train"``, ``"val"``, ``"all_train"`` and ``"test"``.
        The two training loaders shuffle; the evaluation loaders do not.
    """
    def make_loader(key, shuffle):
        # One place for the settings shared by every loader.
        return torch.utils.data.DataLoader(
            datasets[key],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )

    train_loader = make_loader("train", shuffle=True)
    all_train_loader = make_loader("all_train", shuffle=True)
    val_loader = make_loader("val", shuffle=False)
    test_loader = make_loader("test", shuffle=False)

    return {"train": train_loader, "val": val_loader, "all_train": all_train_loader, "test": test_loader}


In [10]:
n_nodes = 30
my_datasets = load_FMNIST(n_nodes, alpha=0.1, val_rate=0.1, seed=0)

# Sanity check: the per-node training splits plus the pooled validation set
# should account for every sample of the training set.
count = 0
for node_id in range(n_nodes):
    count += len(my_datasets["train"][node_id])
count += len(my_datasets["val"])
print(count)

# Label histogram of each node's training split (10 classes).
for node_id in range(n_nodes):
    counter = [0] * 10
    for data, label in my_datasets["train"][node_id]:
        counter[label] += 1
    print(counter)

# Label histogram of the pooled validation set.
print("validation")
counter = [0] * 10
for data, label in my_datasets["val"]:
    counter[label] += 1
print(counter)

60000
[180, 643, 14, 0, 963, 0, 0, 0, 0, 0]
[56, 131, 78, 0, 744, 791, 0, 0, 0, 0]
[355, 559, 103, 0, 0, 783, 0, 0, 0, 0]
[0, 1, 44, 1, 2, 4, 1098, 200, 450, 0]
[0, 0, 21, 558, 0, 0, 0, 0, 1221, 0]
[0, 171, 14, 603, 14, 0, 668, 327, 2, 1]
[0, 18, 936, 0, 0, 0, 0, 846, 0, 0]
[1, 3, 552, 417, 0, 225, 0, 80, 160, 362]
[169, 0, 35, 48, 82, 0, 1, 1, 0, 1464]
[1016, 284, 32, 165, 1, 4, 1, 296, 1, 0]
[1174, 265, 224, 4, 0, 0, 0, 50, 0, 83]
[151, 19, 64, 586, 101, 0, 0, 407, 472, 0]
[0, 0, 0, 290, 0, 290, 1220, 0, 0, 0]
[42, 39, 2, 917, 3, 0, 33, 0, 763, 1]
[0, 318, 0, 0, 39, 2, 501, 940, 0, 0]
[50, 648, 66, 6, 15, 0, 0, 312, 0, 703]
[0, 296, 0, 35, 1469, 0, 0, 0, 0, 0]
[0, 0, 23, 1, 178, 1434, 3, 0, 0, 161]
[247, 0, 0, 0, 0, 49, 1, 136, 487, 880]
[150, 224, 1426, 0, 0, 0, 0, 0, 0, 0]
[141, 5, 26, 0, 0, 1503, 0, 122, 3, 0]
[144, 13, 0, 0, 0, 0, 0, 13, 1630, 0]
[105, 350, 2, 168, 0, 0, 0, 849, 149, 177]
[285, 0, 0, 1515, 0, 0, 0, 0, 0, 0]
[653, 806, 1, 0, 0, 166, 3, 157, 14, 0]
[0, 405, 1184, 1

In [8]:
n_nodes = 3
my_datasets = load_CIFAR10(n_nodes, alpha=0.1, val_rate=0.1, seed=0)

# Sanity check: the per-node training splits plus the pooled validation set
# should account for every sample of the training set.
count = 0
for node_id in range(n_nodes):
    count += len(my_datasets["train"][node_id])
count += len(my_datasets["val"])
print(count)

# Label histogram of each node's training split (10 classes).
for node_id in range(n_nodes):
    counter = [0] * 10
    for data, label in my_datasets["train"][node_id]:
        counter[label] += 1
    print(counter)

# Label histogram of the pooled validation set.
print("validation")
counter = [0] * 10
for data, label in my_datasets["val"]:
    counter[label] += 1
print(counter)

Files already downloaded and verified
50000
[4477, 2507, 2807, 2, 486, 68, 136, 0, 0, 4516]
[1, 1516, 1068, 2344, 1, 4439, 1169, 4461, 0, 0]
[0, 475, 606, 2138, 4011, 1, 3183, 73, 4512, 0]
validation
[522, 502, 519, 516, 502, 492, 512, 466, 488, 484]


In [7]:
n_nodes = 3
my_datasets = load_CIFAR100(n_nodes, alpha=0.1, val_rate=0.1, seed=0)

# Sanity check: the per-node training splits plus the pooled validation set
# should account for every sample of the training set.
count = 0
for node_id in range(n_nodes):
    count += len(my_datasets["train"][node_id])
count += len(my_datasets["val"])
print(count)

# Label histogram of each node's training split (100 classes).
for node_id in range(n_nodes):
    counter = [0] * 100
    for data, label in my_datasets["train"][node_id]:
        counter[label] += 1
    print(counter)

Files already downloaded and verified
50000
[445, 0, 274, 0, 49, 6, 12, 0, 0, 448, 447, 395, 288, 76, 0, 259, 24, 0, 122, 38, 0, 0, 1, 450, 1, 0, 0, 425, 0, 0, 428, 456, 0, 293, 37, 46, 13, 0, 384, 15, 0, 0, 5, 438, 173, 0, 6, 196, 455, 0, 0, 196, 445, 448, 0, 437, 0, 0, 416, 3, 0, 0, 451, 342, 0, 353, 220, 0, 0, 0, 187, 0, 0, 360, 92, 0, 429, 0, 369, 208, 397, 0, 5, 83, 206, 3, 37, 0, 402, 20, 0, 0, 328, 0, 283, 378, 452, 456, 0, 288]
[1, 290, 111, 229, 0, 456, 122, 443, 461, 0, 5, 1, 148, 138, 105, 107, 0, 13, 171, 0, 0, 428, 445, 6, 451, 1, 0, 25, 0, 306, 15, 0, 454, 0, 412, 410, 432, 0, 61, 424, 434, 456, 0, 0, 283, 457, 202, 0, 0, 392, 448, 0, 0, 0, 0, 0, 0, 448, 0, 0, 220, 0, 0, 96, 0, 68, 0, 0, 110, 400, 266, 8, 0, 93, 358, 47, 0, 411, 63, 2, 50, 262, 348, 13, 230, 1, 416, 2, 53, 93, 176, 437, 110, 0, 176, 80, 1, 1, 447, 170]
[0, 155, 65, 215, 393, 1, 317, 6, 1, 1, 0, 48, 1, 236, 343, 71, 424, 430, 157, 413, 467, 21, 2, 1, 0, 450, 443, 1, 442, 142, 1, 1, 1, 157, 1, 0, 1, 450, 1,