In [9]:
import sys
sys.path.append('../')

%matplotlib inline
from data.loader_dirichlet import *

In [10]:


def distribute_data_dirichlet(
    targets, non_iid_alpha, n_workers, seed=0, num_auxiliary_workers=10
):
    # we refer https://github.com/epfml/relaysgd/tree/89719198ba227ebbff9a6bf5b61cb9baada167fd
    """Code adapted from Tao Lin (partition_data.py)"""
    random_state = np.random.RandomState(seed=seed)

    num_indices = len(targets)
    num_classes = len(np.unique(targets))

    indices2targets = np.array(list(enumerate(targets)))
    random_state.shuffle(indices2targets)

    # partition indices.
    from_index = 0
    splitted_targets = []
    num_splits = math.ceil(n_workers / num_auxiliary_workers)
    split_n_workers = [
        num_auxiliary_workers
        if idx < num_splits - 1
        else n_workers - num_auxiliary_workers * (num_splits - 1)
        for idx in range(num_splits)
    ]
    split_ratios = [_n_workers / n_workers for _n_workers in split_n_workers]
    for idx, ratio in enumerate(split_ratios):
        to_index = from_index + int(num_auxiliary_workers / n_workers * num_indices)
        splitted_targets.append(
            indices2targets[
                from_index : (num_indices if idx == num_splits - 1 else to_index)
            ]
        )
        from_index = to_index

    idx_batch = []
    for _targets in splitted_targets:
        # rebuild _targets.
        _targets = np.array(_targets)
        _targets_size = len(_targets)

        # use auxi_workers for this subset targets.
        _n_workers = min(num_auxiliary_workers, n_workers)
        n_workers = n_workers - num_auxiliary_workers

        # get the corresponding idx_batch.
        min_size = 0
        while min_size < int(0.50 * _targets_size / _n_workers):
            _idx_batch = [[] for _ in range(_n_workers)]
            for _class in range(num_classes):
                # get the corresponding indices in the original 'targets' list.
                idx_class = np.where(_targets[:, 1] == _class)[0]
                idx_class = _targets[idx_class, 0]

                # sampling.
                try:
                    proportions = random_state.dirichlet(
                        np.repeat(non_iid_alpha, _n_workers)
                    )
                    # balance
                    proportions = np.array(
                        [
                            p * (len(idx_j) < _targets_size / _n_workers)
                            for p, idx_j in zip(proportions, _idx_batch)
                        ]
                    )
                    proportions = proportions / proportions.sum()
                    proportions = (np.cumsum(proportions) * len(idx_class)).astype(int)[
                        :-1
                    ]
                    _idx_batch = [
                        idx_j + idx.tolist()
                        for idx_j, idx in zip(
                            _idx_batch, np.split(idx_class, proportions)
                        )
                    ]
                    sizes = [len(idx_j) for idx_j in _idx_batch]
                    min_size = min([_size for _size in sizes])
                except ZeroDivisionError:
                    pass
        idx_batch += _idx_batch
    return idx_batch


## https://github.com/epfml/relaysgd/tree/89719198ba227ebbff9a6bf5b61cb9baada167fd
def dirichlet_split(
        dataset,
        num_workers: int,
        alpha: float = 1,
        seed: int = 0,
        distribute_evenly: bool = True,
    ):
        indices_per_worker = distribute_data_dirichlet(
            dataset.targets, alpha, num_workers, num_auxiliary_workers=10, seed=seed
        )

        if distribute_evenly:
            indices_per_worker = np.array_split(
                np.concatenate(indices_per_worker), num_workers
            )

        return [
            Subset(dataset, indices)
            for indices in indices_per_worker
        ]

def datasets_to_loaders(datasets, batch_size=128):
    """
    datasets dict:
    """
    train_loader = torch.utils.data.DataLoader(
        datasets["train"],
        batch_size=batch_size,
        shuffle=True, num_workers=2, pin_memory=False, persistent_workers=True)

    all_train_loader = torch.utils.data.DataLoader(
        datasets["all_train"],
        batch_size=batch_size,
        shuffle=True, num_workers=2, pin_memory=False)

    
    val_loader = torch.utils.data.DataLoader(
        datasets["val"],
        batch_size=batch_size,
        shuffle=False, num_workers=2, pin_memory=False)

    test_loader = torch.utils.data.DataLoader(
        datasets["test"],
        batch_size=batch_size,
        shuffle=False, num_workers=2, pin_memory=False)

    return {"train": train_loader, "val": val_loader, "all_train": all_train_loader, "test": test_loader}


In [93]:
def load_FMNIST(n_node, alpha=1.0, val_rate=0.2, seed=0):
    """
    node_label : 
        the list of labes that each node has. (example. [[0,1],[1,2],[0,2]] (n_node=3, n_class=2))
    """

    train_val_dataset = datasets.FashionMNIST('../data',
                       train=True,
                       download=True,
                       transform=transforms.Compose([
                           transforms.RandomCrop(28, padding=4),
                           transforms.ToTensor()
                       ]))

    test_dataset = datasets.FashionMNIST('../data',
                       train=False,
                       transform=transforms.Compose([
                           transforms.ToTensor()
                       ]))
    
    print(f"train_val: {len(train_val_dataset)}, test: {len(test_dataset)}")
    
    # split datasets into n_node datasets by Dirichlet distribution. 
    train_val_subset_list = dirichlet_split(train_val_dataset, n_node, alpha, seed=seed)
        
    # the number of train datasets per node.
    n_data = min([len(train_val_subset_list[node_id]) for node_id in range(n_node)])
    n_train = int((1.0 - val_rate) * n_data)
    
    # choose validation datasets.
    val_dataset = None
    train_subset_list = []
    for node_id in range(n_node):
        n_val = len(train_val_subset_list[node_id]) - n_train
        a, b = torch.utils.data.random_split(train_val_subset_list[node_id], [n_train, n_val])
        train_subset_list.append(a)
        
        if val_dataset is None:
            val_dataset = b
        else:
            val_dataset += b
                  
    return {'train': train_subset_list, 'val': val_dataset, 'test': test_dataset}


In [98]:
n_nodes = 3
my_datasets = load_FMNIST(n_nodes, alpha=0.1, val_rate=0.1, seed=0)

count = 0

for key in my_datasets:
    if key == "train":
        for node_id in range(n_nodes):
            count += len(my_datasets[key][node_id])
    if key == "val":
        count += len(my_datasets[key])
print(count)

train_val: 60000, test: 10000
60000


In [99]:
for node_id in range(n_nodes):
    counter = [0 for _ in range(10)]
    
    for data, label in my_datasets["train"][node_id]:
        counter[label] += 1
    print(counter)

[5161, 0, 5394, 0, 5375, 776, 8, 164, 1122, 0]
[250, 5365, 7, 5413, 1, 4637, 887, 8, 864, 568]
[0, 0, 0, 0, 0, 0, 4523, 5223, 3430, 4824]


In [101]:
counter = [0 for _ in range(10)]
    
for data, label in my_datasets["val"]:
    counter[label] += 1
print(counter)

[589, 635, 599, 587, 624, 587, 582, 605, 584, 608]


In [20]:
n_train = int( 0.9 * len(my_datasets["train"][0]))
n_val = len(my_datasets["train"][0]) - int(0.9 * len(my_datasets["train"][0])) 


a, b = torch.utils.data.random_split(my_datasets["train"][0], [n_train, n_val])

In [26]:
loader = torch.utils.data.DataLoader(
        a+b,
        batch_size=1,
        shuffle=True, num_workers=2, pin_memory=False)


In [32]:
counter = [0 for i in range(10)]
for data, label in a:
    counter[label] += 1
print(counter)

[3804, 3, 3393, 69, 445, 0, 0, 0, 0, 0]


In [33]:
counter = [0 for i in range(10)]
for data, label in loader:
    counter[label.item()] += 1
print(counter)

[4245, 3, 3760, 74, 490, 0, 0, 0, 0, 0]


In [35]:
print(data.shape)

torch.Size([1, 1, 28, 28])


In [5]:

for i in range(7):
    label_count = [0 for _ in range(10)]

    for data, label in my_datasets["train"][i]:
        label_count[label] += 1

    print(label_count)

[4245, 3, 3760, 74, 490, 0, 0, 0, 0, 0]
[543, 50, 383, 17, 423, 6, 172, 1433, 10, 5535]
[0, 5, 1351, 1815, 0, 3840, 0, 1561, 0, 0]
[807, 0, 0, 2155, 938, 2042, 212, 2417, 0, 0]
[404, 3353, 0, 0, 3815, 0, 533, 1, 0, 465]
[1, 2589, 506, 393, 0, 0, 5082, 0, 0, 0]
[0, 0, 0, 1546, 334, 112, 1, 588, 5990, 0]


In [6]:
train_val_dataset = datasets.FashionMNIST('../data',
                       train=True,
                       download=True,
                       transform=transforms.Compose([
                           transforms.RandomCrop(28, padding=4),
                           transforms.ToTensor()
                       ]))

In [8]:
train_val_dataset.data

tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        ...,

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0,

In [6]:
my_datasets = load_FMNIST(7, batch=100, alpha=0.1, val_rate=0.1, seed=0)

In [8]:

for i in range(7):
    label_count = [0 for _ in range(10)]

    for data, label in my_datasets["train"][i]:
        label_count[label] += 1

    print(label_count)

[4245, 3, 3760, 74, 490, 0, 0, 0, 0, 0]
[543, 50, 383, 17, 423, 6, 172, 1433, 10, 5535]
[0, 5, 1351, 1815, 0, 3840, 0, 1561, 0, 0]
[807, 0, 0, 2155, 938, 2042, 212, 2417, 0, 0]
[404, 3353, 0, 0, 3815, 0, 533, 1, 0, 465]
[1, 2589, 506, 393, 0, 0, 5082, 0, 0, 0]
[0, 0, 0, 1546, 334, 112, 1, 588, 5990, 0]
