In [1]:
def update_learning_rate(i, splitNN):
    """Hand ``splitNN`` a fresh set of SGD optimizers with a step-dependent rate.

    The learning rate is 0.3 while ``i <= 200``, 0.05 while ``i <= 400``,
    0.01 while ``i <= 500`` and 0.001 afterwards.  One ``optim.SGD`` is built
    for every entry of the module-level ``model_locations`` over the matching
    segment in the module-level ``models`` dict, and the resulting
    ``(optimizer, location)`` pairs are passed to ``splitNN.set_lr``.

    Args:
        i: Step counter used to pick the learning rate.
        splitNN: Object exposing ``set_lr(optimizers)``.

    Returns:
        None.
    """
    # (exclusive lower bound, rate) pairs, checked from the largest bound down
    # so the first match is the most specific one.
    schedule = ((500, 0.001), (400, 0.01), (200, 0.05))
    rate = next((lr for bound, lr in schedule if i > bound), 0.3)

    fresh_optimizers = []
    for worker in model_locations:
        sgd = optim.SGD(models[worker.id].parameters(), lr=rate)
        fresh_optimizers.append((sgd, worker))

    splitNN.set_lr(fresh_optimizers)

# --- Experiment configuration ---------------------------------------------
# Number of epochs in the first ("preprocessing") training loop of each test.
PREPROCESSING_EPOCHS = 5
# Number of epochs in the second training loop of each test; the per-epoch
# mean losses of this loop are what the data owners are scored on.
LEARNING_EPOCHS = 5
# NOTE(review): not referenced anywhere in the visible part of the file —
# confirm whether it is still needed.
SUBSET_UPDATE_PROB = 1
# Value passed to DiscreteSplitNN as its padding_method argument.
PADDING_METHOD = "zeros"
# Learning rate of the SGD optimizer created for every model segment.
LEARNING_RATE = 0.3
# Argument passed to every splitNN.group_testing() call.
GROUP_TESTING_ROUNDS = 5
# How many of the lowest-scoring owners are marked as selected (and have
# their counter incremented) at the end of each test.
TO_BE_SELECTED = 2
# Number of times the whole train-and-score experiment is repeated.
TESTS = 100
# Per-epoch probability of re-running group testing before training.
# random.random() is always < 1, so a value of 1 means "every epoch".
CHANGE_PROBABILITY = 1
# NOTE(review): never read or appended to in the visible code.
res = []

import sys
sys.path.append('../')

import matplotlib.pyplot as plt
import torch
from torchvision import datasets, transforms
from torch import nn, optim
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

import syft as sy
import random
from time import process_time

from src.psi.util import Client, Server
from src.discrete_splitnn import DiscreteSplitNN
from src.utils import add_ids
from src.discrete_distribute_data import DiscreteDistributeMNIST

# Hook PyTorch with PySyft; the hook is required to create the virtual
# workers below.
hook = sy.TorchHook(torch)

# Data preprocessing: convert images to tensors and normalise the single
# channel with mean 0.5 / std 0.5 (i.e. rescale from [0, 1] to [-1, 1]).
transform = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.5,), (0.5,)),
                              ])
# MNIST training split, downloaded into ./mnist if not already present.
trainset = datasets.MNIST('mnist', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64
                                          , shuffle=True)

# Create the virtual workers: four clients plus one server.
client_1 = sy.VirtualWorker(hook, id="client_1")
client_2 = sy.VirtualWorker(hook, id="client_2")
client_3 = sy.VirtualWorker(hook, id="client_3")
client_4 = sy.VirtualWorker(hook, id="client_4")
server = sy.VirtualWorker(hook, id= "server") 

# Workers that hold data, and every worker that hosts a model segment
# (the data owners plus the server).
data_owners = (client_1, client_2, client_3, client_4)
model_locations = [client_1, client_2, client_3, client_4, server]

# Distribute the training data across the four data owners.
# Presumably each owner receives a 28*7-value slice of every image, matching
# input_size below — verify against DiscreteDistributeMNIST.
distributed_trainloader = DiscreteDistributeMNIST(data_owners=data_owners, data_loader=trainloader)

# Seed torch's global RNG. This is done once here and is not repeated inside
# the TESTS loop.
torch.manual_seed(0)

# Define our model segments

# Input width of each client segment, in data_owners order.
input_size= [28*7, 28*7, 28*7, 28*7]
# Layer widths per segment. For clients the pair is (hidden width, output
# width); for the server it is (input width, hidden width).
# Presumably the server's 256 inputs are the four 64-wide client outputs
# concatenated — verify against DiscreteSplitNN.
hidden_sizes= {"client_1": [128, 64], "client_2":[128, 64], "client_3": [128, 64], "client_4":[128, 64], "server":[256, 128]}
# Number of output classes (the ten MNIST digits).
output_size = 10

# Accumulators shared across all tests, keyed by every entry of hidden_sizes
# (this includes "server").
# counters: how many times each key was picked at the end of a test.
counters = {own: 0 for own in hidden_sizes}    
# scores_lists: the score recorded for each key in each test.
scores_lists = {own: [] for own in hidden_sizes}

# Repeat the experiment TESTS times.  Each repetition builds fresh model
# segments, trains them with the split network, scores every owner tracked in
# splitNN.selected by its mean training loss, and records which owners have
# the lowest scores.
for _ in range(TESTS):

    # One segment per data owner: two Linear+ReLU layers whose widths come
    # from input_size / hidden_sizes.  Segments are created in data_owners
    # order, followed by the server, so layer initialisation order is the
    # same as when each segment is spelled out by hand.
    models = {}
    for owner, n_inputs in zip(data_owners, input_size):
        hidden_width, out_width = hidden_sizes[owner.id]
        models[owner.id] = nn.Sequential(
            nn.Linear(n_inputs, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, out_width),
            nn.ReLU(),
        )

    # Server segment: ends in LogSoftmax over the output classes.
    server_in, server_hidden = hidden_sizes["server"]
    models["server"] = nn.Sequential(
        nn.Linear(server_in, server_hidden),
        nn.ReLU(),
        nn.Linear(server_hidden, output_size),
        nn.LogSoftmax(dim=1),
    )

    # Create optimisers for each segment and link to them
    optimizers = [
        (optim.SGD(models[location.id].parameters(), lr=LEARNING_RATE), location)
        for location in model_locations
    ]

    # Move every segment to the worker that hosts it.
    for location in model_locations:
        models[location.id].send(location)

    # Instantiate a SplitNN class with our distributed segments and their
    # respective optimizers.
    # NOTE(review): n_selected is hard-coded to 2, which happens to equal
    # TO_BE_SELECTED — confirm whether the two are meant to be linked.
    splitNN = DiscreteSplitNN(models, server, data_owners, optimizers, distributed_trainloader, k=10, n_selected=2, padding_method=PADDING_METHOD)
    distributed_trainloader.generate_subdata()

    test_perf = []
    performance = []

    splitNN.group_testing(GROUP_TESTING_ROUNDS)

    # Phase 1: preprocessing epochs.  Losses are logged in `performance` but
    # are not used for scoring.
    for i in range(PREPROCESSING_EPOCHS):
        running_loss = 0
        if random.random() < CHANGE_PROBABILITY:
            splitNN.group_testing(GROUP_TESTING_ROUNDS)

        # iterate over each datapoint
        for _, data_ptr, label in distributed_trainloader.distributed_subdata:

            # send labels to server's location for training
            label = label.send(server)

            loss = splitNN.train(data_ptr, label)
            running_loss += loss

        performance.append((running_loss / len(distributed_trainloader.distributed_subdata)).item())

    # Phase 2: learning epochs.  The mean loss of each epoch is credited to
    # every owner that is flagged in splitNN.selected after that epoch.
    losses = {own: [] for own in splitNN.selected}

    for i in range(LEARNING_EPOCHS):
        running_loss = 0
        if random.random() < CHANGE_PROBABILITY:
            splitNN.group_testing(GROUP_TESTING_ROUNDS)

        # iterate over each datapoint
        for _, data_ptr, label in distributed_trainloader.distributed_subdata:

            # send labels to server's location for training
            label = label.send(server)

            loss = splitNN.train(data_ptr, label)
            running_loss += loss

        # Compute the epoch's mean loss once and reuse it for both logs.
        epoch_loss = (running_loss / len(distributed_trainloader.distributed_subdata)).item()
        performance.append(epoch_loss)

        for own in splitNN.selected:
            if splitNN.selected[own]:
                losses[own].append(epoch_loss)

    # Score = mean recorded loss.  The emptiness check must be per owner: an
    # owner that was never flagged has no losses and gets a score of 0
    # instead of triggering a division by zero.
    scores = {
        own: sum(own_losses) / len(own_losses) if own_losses else 0
        for own, own_losses in losses.items()
    }
    for own in splitNN.selected:
        splitNN.selected[own] = False
        scores_lists[own].append(scores[own])

    # Flag the TO_BE_SELECTED lowest-scoring owners (ties go to the first in
    # iteration order).  Bounding the loop by len(scores) and using min()
    # keeps this working when fewer owners are available or when a score is
    # inf/NaN, where a "< inf" scan would never pick a key.
    for _ in range(min(TO_BE_SELECTED, len(scores))):
        best = min(scores, key=scores.get)
        splitNN.selected[best] = True
        del scores[best]
        counters[best] += 1

    print(counters)

# Print, for every key of hidden_sizes, how often it was picked and summary
# statistics of the scores recorded for it across all tests.
print('===FINAL RESULTS===')
for own in hidden_sizes:
    own_scores = scores_lists[own]
    print(own)
    print('counter: ' + str(counters[own]))

    # A key with no recorded score (presumably "server", which is not a data
    # owner — verify) would otherwise crash on the division and on min/max.
    if not own_scores:
        print('no scores recorded')
        continue

    mean = sum(own_scores) / len(own_scores)
    print('mean: ' + str(mean))
    # Population standard deviation: square root of the mean squared
    # deviation.  Without the square root the value is the variance.
    variance = sum((score - mean) ** 2 for score in own_scores) / len(own_scores)
    sd = variance ** 0.5
    print('sd: ' + str(sd))
    print(own_scores)
    print('min: ' + str(min(own_scores)))
    print('max: ' + str(max(own_scores)))
    




{'client_1': 0, 'client_2': 1, 'client_3': 1, 'client_4': 0, 'server': 0}
{'client_1': 0, 'client_2': 2, 'client_3': 2, 'client_4': 0, 'server': 0}
{'client_1': 0, 'client_2': 3, 'client_3': 3, 'client_4': 0, 'server': 0}
{'client_1': 0, 'client_2': 4, 'client_3': 4, 'client_4': 0, 'server': 0}
{'client_1': 0, 'client_2': 5, 'client_3': 5, 'client_4': 0, 'server': 0}
{'client_1': 0, 'client_2': 6, 'client_3': 6, 'client_4': 0, 'server': 0}
{'client_1': 0, 'client_2': 7, 'client_3': 7, 'client_4': 0, 'server': 0}
{'client_1': 0, 'client_2': 8, 'client_3': 8, 'client_4': 0, 'server': 0}
{'client_1': 0, 'client_2': 9, 'client_3': 9, 'client_4': 0, 'server': 0}
{'client_1': 0, 'client_2': 10, 'client_3': 10, 'client_4': 0, 'server': 0}
{'client_1': 0, 'client_2': 11, 'client_3': 11, 'client_4': 0, 'server': 0}
{'client_1': 0, 'client_2': 12, 'client_3': 12, 'client_4': 0, 'server': 0}
{'client_1': 0, 'client_2': 13, 'client_3': 13, 'client_4': 0, 'server': 0}
{'client_1': 0, 'client_2': 14