In [1]:
import numpy as np

def iid(images, labels, num_users):

    num_items = int(len(images)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(images))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users


In [2]:
import torch
import cv2
import os
import skimage.measure

def get_dataset(folder, fileLists, labelLists, num, num_users):
    
    images = np.empty([1,128, 128, 1], dtype = int) 
    
    file_names = fileLists[0:num]
    for i in range(len(file_names)):
        read_image = cv2.imread(os.path.join(folder,file_names[i]), 0)
        read_image = skimage.measure.block_reduce(read_image, (8,8), np.mean)
        read_image = np.expand_dims(read_image,axis = (0,3))
        images = np.append(images, read_image, axis = 0)
        
    images = np.uint8(images)
    images = np.delete(images, (0), axis = 0)
    images = np.swapaxes(images, 2, 3)
    images = np.swapaxes(images, 1, 2)

    labels = labelLists[0:num]
    for i in range(len(labels)):
        if sum(labels[i][1::]) >= 1.0:
            labels[i][1] = 1.0
            labels[i][0] = 0.0
    labels = np.delete(labels, np.s_[2::], axis = 1)

    user_groups = iid(images, labels, num_users)

    return images, labels, user_groups



In [3]:
from torch import nn
import torch.nn.functional as F

class CNN(nn.Module):

    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 2, 8, 8)
        self.fc1 = nn.Linear(2*8*8, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = x.view(-1, x.shape[1]*x.shape[-2]*x.shape[-1])
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.softmax(x, dim=1)

In [4]:
def average_weights(w):

    w_avg = copy.deepcopy(w[0])
    for key in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[key] += w[i][key]
        w_avg[key] = torch.div(w_avg[key], len(w))
    return w_avg



In [5]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset


class DatasetSplit(Dataset):
    def __init__(self, images, labels, idxs):
        self.imageset = images
        self.labelset = labels
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image = self.imageset[self.idxs[item]]
        label = self.labelset[self.idxs[item]]
        return torch.FloatTensor(image), torch.FloatTensor(label)

    
    
class LocalUpdate(object):
    def __init__(self, images, labels, idxs):
        self.trainloader, self.testloader = self.train_test(images, labels, list(idxs))
        self.device = 'cpu'
        self.criterion = nn.CrossEntropyLoss().to(self.device)
    
    def train_test(self, images, labels, idxs):
        idxs_train = idxs[:int(0.9*len(idxs))]
        idxs_test = idxs[int(0.9*len(idxs)):]

        trainloader = DataLoader(DatasetSplit(images, labels, idxs_train),
                                 batch_size=int(len(idxs_train)/2), shuffle=True)
        testloader = DataLoader(DatasetSplit(images, labels, idxs_test),
                                batch_size=len(idxs_test), shuffle=False)
        return trainloader, testloader

    def update_weights(self, model, local_ep, lr, global_round):
        model.train()
        epoch_loss = []
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.5)

        for iter in range(local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)
                model.zero_grad()
                log_probs = model(images)
                loss = self.criterion(log_probs, torch.max(labels, 1)[1])
                loss.backward()
                optimizer.step()
                print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        global_round, iter, batch_idx * len(images),
                        len(self.trainloader.dataset),
                        100. * batch_idx / len(self.trainloader), loss.item()))
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))
        return model.state_dict(), sum(epoch_loss) / len(epoch_loss)

    def inference(self, model):
        model.eval()
        loss, total, correct = 0.0, 0.0, 0.0
        for batch_idx, (images, labels) in enumerate(self.testloader):
            images, labels = images.to(self.device), labels.to(self.device)
            outputs = model(images)
            batch_loss = self.criterion(outputs, torch.max(labels, 1)[1])
            loss += batch_loss.item()
        return loss



In [6]:
images, labels, user_groups = get_dataset(folder = r'G:/images_001/', 
                                  fileLists = np.load('fileLists.npy').tolist(), 
                                  labelLists = np.load('labels_hotEnocded.npy'), 
                                  num = 300, 
                                  num_users = 5)


In [7]:
import copy
from tqdm import tqdm


device = 'cpu'
epochs = 5
frac = 0.4
num_users = 5
num_classes = 2
print_every = 2

global_model = CNN()
global_model.to(device)
global_model.train()
print(global_model)
global_weights = global_model.state_dict()


train_loss = []


for epoch in tqdm(range(epochs)):
    local_weights, local_losses = [], []
    print(f'\n | Global Training Round : {epoch+1} |\n')

    global_model.train()
    m = max(int(frac * num_users), 1)
    idxs_users = np.random.choice(range(num_users), m, replace=False)

    for idx in idxs_users:
        local_model = LocalUpdate(images, labels, idxs=user_groups[idx])
        w, loss = local_model.update_weights(model=copy.deepcopy(global_model),
                                             local_ep=5, lr=0.01, global_round=epoch)
        local_weights.append(copy.deepcopy(w))
        local_losses.append(copy.deepcopy(loss))

        # update global weights
        global_weights = average_weights(local_weights)

        # update global weights
        global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # Calculate avg training accuracy over all users at every epoch
        list_loss = []
        global_model.eval()
        for c in range(num_users):
            local_model = LocalUpdate(images, labels, idxs=user_groups[idx])
            loss = local_model.inference(model=global_model)
            list_loss.append(loss)

        # print global training loss after every 'i' rounds
        if (epoch+1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')



  0%|                                                                                            | 0/5 [00:00<?, ?it/s]

CNN(
  (conv1): Conv2d(1, 2, kernel_size=(8, 8), stride=(8, 8))
  (fc1): Linear(in_features=128, out_features=8, bias=True)
  (fc2): Linear(in_features=8, out_features=2, bias=True)
)

 | Global Training Round : 1 |


 20%|████████████████▊                                                                   | 1/5 [00:00<00:01,  3.57it/s]



 | Global Training Round : 2 |


 40%|█████████████████████████████████▌                                                  | 2/5 [00:00<00:00,  3.71it/s]


 
Avg Training Stats after 2 global rounds:
Training Loss : 0.4360292742649714
 
Avg Training Stats after 2 global rounds:
Training Loss : 0.41228181347250936

 | Global Training Round : 3 |


 60%|██████████████████████████████████████████████████▍                                 | 3/5 [00:00<00:00,  3.66it/s]



 | Global Training Round : 4 |

 
Avg Training Stats after 4 global rounds:

 80%|███████████████████████████████████████████████████████████████████▏                | 4/5 [00:01<00:00,  3.73it/s]


Training Loss : 0.3738128481166703
 
Avg Training Stats after 4 global rounds:
Training Loss : 0.3685587618499994

 | Global Training Round : 5 |


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  3.69it/s]





