In [1]:
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize
from torch.utils.data import DataLoader, Subset
import torch
import numpy as np


class Mnist:
    """ Mnist dataset class.
    """
    def __init__(self, train_fraction):
        """ Constructor method.

            Parameters:
            train_fraction (float): Fraction of training data to use.
        """

        transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
        self.test_data = MNIST(
            root='data', 
            train=False, 
            transform=transform, 
            download=True)

        self.train_data = self._sample_train_data(train_fraction, transform)
        self.n_samples = len(self.train_data)

    def _sample_train_data(self, train_fraction, transform):
        """ Sample a chosen fraction of the training dataset to use.

            Parameters:
            train_fraction  (float): Fraction of training data to use.
            transform       (torchvision.transforms.Compose): Collection of transforms to apply on data.

            Returns:
            torch.utils.data.Subset: Subset of training data.
        """
        train_data = MNIST(
            root='data', 
            train=True, 
            transform=transform, 
            download=True)
        n_samples = len(train_data.targets)
        index_limit = int(train_fraction * n_samples)
        chosen_indices = np.random.choice(torch.arange(n_samples), size=index_limit, replace=False)
        print(f"\nUsing {index_limit} training samples\n", flush=True)

        return Subset(train_data, chosen_indices)

    def get_train_data_loaders(self, n_clients, distribution, alpha, batch_size):
        """ Get list of client training data loaders sampled from dirichlet distribution.

            Parameters:
            n_clients       (int): Number of clients.
            distribution    (str): iid/non-iid distributed data.
            alpha           (float): Concentration parameter for dirichlet distribution.
            batch_size      (int): Batch size for loading training data.

            Returns List[torch.utils.data.DataLoader]
        """
        labels = np.array([y for (_, y) in self.train_data])
        n_classes = len(np.unique(labels))
        partition_matrix = np.ones((n_classes, n_clients))

        # iid TODO: Make random slices.
        if distribution == "iid":
            partition_matrix /= n_clients
            size = int(np.floor(len(labels)/n_clients))
            client_data_loaders = []
            client_indices = np.random.choice(torch.arange(self.n_samples), size=(n_clients, size), replace=False)
            for i in range(n_clients):
                client_data_loaders.append(DataLoader(Subset(self.train_data, client_indices[i]), batch_size))
        # non-iid
        else:
            class_indices = []
            for i in range(10):
                class_indices.append(np.array(range(len(labels)))[labels == i])
            valid_pm = False
            while not valid_pm:
                partition_matrix = np.random.dirichlet((alpha, )*n_clients, n_classes)
                valid_pm = all(np.sum(partition_matrix, axis=0) > 0.01)
            #print("class_indices", class_indices)
            local_sets_indices = [[] for _ in range(n_clients)]
            for each_class in range(n_classes):
                sample_size = len(class_indices[each_class])
                for client in range(n_clients):
                    np.random.shuffle(class_indices[each_class])
                    local_size = int(np.floor(partition_matrix[each_class, client] * sample_size))
                    local_sets_indices[client] += list(class_indices[each_class][:local_size])
                    class_indices[each_class] = class_indices[each_class][local_size:]

            client_data_loaders = []
            print("local sets", local_sets_indices)
            for client_indices in local_sets_indices:
                np.random.shuffle(client_indices)
                client_data_loaders.append(DataLoader(Subset(self.train_data, client_indices), batch_size))

        return partition_matrix, client_data_loaders
    
    def get_test_data_loader(self, batch_size):
        """ Get test data loader.

            Parameters:
            batch_size      (int): Batch size for loading test data.
        """
        return DataLoader(self.test_data, batch_size)

In [49]:
dataset = Mnist(0.1)


Using 6000 training samples



In [50]:
partition_matrix, client_data_loaders = dataset.get_train_data_loaders(5, "niid", 0.5, 64)

local sets [[4752, 2148, 2938, 5146, 2070, 1354, 4351, 5308, 2897, 5657, 4055, 2418, 4635, 4070, 2587, 5235, 1762, 2484, 4975, 1200, 2329, 3265, 809, 793, 3722, 2300, 4586, 1751, 4412, 1867, 3000, 5020, 2008, 2725, 836, 2204, 3137, 1171, 5041, 3421, 4665, 891, 2541, 15, 1921, 3056, 2921, 654, 937, 1538, 942, 620, 494, 87, 3200, 396, 3633, 1789, 1717, 2284, 559, 5396, 3670, 5766, 5420, 663, 3024, 2961, 3591, 501, 5963, 5155, 1298, 4348, 2426, 5916, 2771, 1121, 232, 5097, 4047, 2508, 4804, 4265, 1501, 5123, 5435, 5874, 5521, 1959, 5879, 4567, 619, 2858, 4619, 4003, 4693, 3059, 5100, 962, 3823, 3476, 3686, 5389, 2200, 1685, 3337, 5507, 2206, 7, 4516, 2843, 4218, 4147, 4757, 2563, 303, 3289, 1817, 2975, 5239, 4381, 5047, 1192, 5151, 3339, 1257, 3595, 3597, 5156, 5145, 973, 1583, 4692, 5749, 1145, 668, 4043, 3055, 2912, 3522, 3713, 2354, 5433, 1532, 3040, 3160, 5150, 5405, 4451, 4849, 877, 2349, 1966, 454, 506, 4395, 2399, 5562, 2051, 340, 4732, 3422, 2599, 4303, 3348, 2207, 4133, 5589, 584

In [51]:
dataiter = iter(client_data_loaders[0])
images, labels = dataiter.next()
labels

tensor([9, 8, 7, 7, 2, 7, 8, 4, 5, 4, 2, 6, 9, 6, 2, 3, 1, 8, 8, 2, 4, 7, 5, 7,
        6, 2, 1, 1, 7, 6, 3, 4, 8, 9, 5, 6, 3, 8, 8, 6, 3, 2, 3, 8, 9, 3, 7, 6,
        2, 6, 6, 6, 8, 6, 4, 6, 2, 8, 8, 6, 8, 6, 3, 2])

In [46]:
labels = [1,5,7,0,1,3,5,7,1,3,5,6,2,4,8,9,5,3]
class_indices = []
class_indices.append(np.arange(len(labels))[labels==1])

In [47]:
labels == 1

False

In [2]:
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
class Cifar10_Cnn(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [10]:
model = Cifar10_Cnn()
weights = model.state_dict()
w0 = OrderedDict()
w1 = OrderedDict()
for name in weights.keys():
    w0[name] = torch.zeros(weights[name].shape)
    w1[name] = w0[name] + weights[name]*0.5
print(weights['conv1.weight'])
print(w1['conv1.weight'])

tensor([[[[ 0.0769, -0.0933,  0.0995, -0.0623,  0.0479],
          [ 0.1013,  0.0325, -0.0209,  0.0397, -0.0666],
          [-0.0200,  0.0374,  0.0962, -0.0965, -0.0690],
          [ 0.1149,  0.0072,  0.0460, -0.0353,  0.0835],
          [ 0.0973,  0.0898,  0.0231,  0.0134, -0.0685]],

         [[-0.0728, -0.0114,  0.0644,  0.1108, -0.0391],
          [ 0.0169,  0.0186,  0.0681,  0.0599,  0.0752],
          [-0.0526,  0.0656, -0.1116,  0.0346,  0.0353],
          [-0.0030, -0.0433,  0.0897, -0.0486,  0.0883],
          [-0.0404, -0.0723, -0.0764,  0.0584, -0.0129]],

         [[ 0.1076, -0.0255,  0.0807, -0.0447, -0.0900],
          [ 0.0081,  0.0747,  0.1066,  0.0553,  0.0371],
          [-0.0990, -0.0687, -0.1150, -0.0316,  0.0558],
          [-0.0724, -0.0372, -0.1031,  0.0054, -0.0729],
          [ 0.1033,  0.0786, -0.0414, -0.0488, -0.0164]]],


        [[[-0.0390,  0.0378,  0.0002, -0.0556, -0.0778],
          [-0.0125, -0.0175, -0.0474, -0.0375,  0.1058],
          [-0.1083, -0.