In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np

In [4]:
def mnistIID(dataset, num_users):
    """
    Split a dataset into IID, non-overlapping index shards, one per user.

    Each user receives ``len(dataset) // num_users`` distinct indices, drawn
    uniformly at random without replacement from the indices not yet given to
    an earlier user. Leftover indices are not assigned to anyone. This happens
    when the dataset size is not divisible by ``num_users``.

    Args:
        dataset: Any sized object; only ``len(dataset)`` is used.
        num_users: Number of users to create shards for.

    Returns:
        dict mapping user id (0 .. num_users - 1) to a set of dataset indices
        (NumPy integer scalars, as returned by ``np.random.choice``).

    Note:
        The global NumPy RNG is re-seeded with the user id before each draw.
        This makes the split reproducible, but it also overwrites any global
        NumPy seed the caller had set.
    """
    num_images = len(dataset) // num_users
    users_dict, indices = {}, list(range(len(dataset)))
    for i in range(num_users):
        # Per-user seed so the same partition is produced on every run.
        np.random.seed(i)
        users_dict[i] = set(np.random.choice(indices, num_images, replace=False))
        # Drop this user's picks so no image is shared between two users.
        indices = list(set(indices) - users_dict[i])
    return users_dict
    

In [6]:
def load_datasets(num_users, data_dir='./data'):
    """
    Load (downloading if needed) MNIST and split train and test sets IID across users.

    Args:
        num_users: Number of users; forwarded to ``mnistIID`` for both splits.
        data_dir: Directory where MNIST is stored / downloaded to. Defaults to
            ``'./data'``, the location this function previously hard-coded.

    Returns:
        Tuple ``(train_dataset, test_dataset, train_group, test_group)``.
        The two groups are the per-user index dicts returned by ``mnistIID``.
    """
    # Normalize(mean, std); the values are presumably the MNIST training-set
    # pixel mean/std -- NOTE(review): confirm provenance.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_dataset = datasets.MNIST(data_dir, train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(data_dir, train=False, transform=transform, download=True)
    train_group = mnistIID(train_dataset, num_users)
    test_group = mnistIID(test_dataset, num_users)
    return train_dataset, test_dataset, train_group, test_group

In [7]:
class FedDataset(Dataset):
    """
    View of ``dataset`` restricted to the positions listed in ``index``.

    Gives one federated user access to only its own shard of samples. Items of
    the wrapped dataset are expected to be ``(image, label)`` pairs.
    """

    def __init__(self, dataset, index):
        self.dataset = dataset
        # Indices may arrive as a set of NumPy integers (e.g. from mnistIID);
        # store them as a list of plain ints so they can be indexed positionally.
        self.index = [int(i) for i in index]

    def __len__(self):
        return len(self.index)

    def __getitem__(self, item):
        """Return the ``item``-th sample of this shard as (float32 image, label tensor)."""
        image, label = self.dataset[self.index[item]]
        # as_tensor avoids a copy when `image` is already a tensor, and float32
        # keeps the dtype the torch.Tensor(...) constructor would have produced.
        # The label must not go through torch.Tensor(int): that allocates an
        # uninitialized tensor of that *length* instead of wrapping the value.
        return torch.as_tensor(image, dtype=torch.float32), torch.as_tensor(label)

In [9]:
def get_actual_images(dataset, index, batch_size, shuffle=True):
    """
    Build a DataLoader over the subset of ``dataset`` selected by ``index``.

    Args:
        dataset: Source dataset whose items are ``(image, label)`` pairs.
        index: Iterable of integer positions into ``dataset``, for example one
            user's shard from ``mnistIID``.
        batch_size: Number of samples per batch.
        shuffle: Whether to reshuffle every epoch. Defaults to True, which
            matches the previous behaviour. Pass False for deterministic
            evaluation loaders.

    Returns:
        ``torch.utils.data.DataLoader`` over a ``FedDataset`` view of the shard.
    """
    user_subset = FedDataset(dataset, index)
    return DataLoader(user_subset, batch_size=batch_size, shuffle=shuffle)