# Import necessary libraries

In [1]:
# General
import os
import cv2
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import pickle
import time
import copy
import pandas as pd


# Pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from torchvision import datasets, transforms


# PySyft
import syft as sy
from syft.frameworks.torch.fl import utils
from syft.workers.websocket_client import WebsocketClientWorker

# Pre-processing the Data

In [2]:
# Every MNIST image is resized to IMG_SIZE x IMG_SIZE pixels
IMG_SIZE = 64
# Number of samples each client processes per training step
BATCH_SIZE = 100
# Learning rate of every client and edge SGD optimizer
LR = 0.001
# Seed torch's global RNG so the shuffled DataLoaders and the weight initialization of the
# models are reproducible on Restart & Run All
SEED = 42
torch.manual_seed(SEED)

In [3]:
# Folder where torchvision stores (and, if missing, downloads) MNIST. Set the MNIST_64_DIR
# environment variable to use another location instead of editing this machine-specific path.
MNIST_DIR = os.environ.get(
    "MNIST_64_DIR",
    r"/media/wilfredo/Willie931GB/EURECOM_SLU_Linux/II_SEMESTER/SLU/PAPER_KDD2022/EXPERIMENTS/PySyft/Datasets/MNIST_64",
)

# Resize every image to IMG_SIZE x IMG_SIZE, then convert it to a float tensor in [0, 1]
mnist_transform = transforms.Compose([transforms.Resize(IMG_SIZE), transforms.ToTensor()])

# NOTE: the names `train` / `test` are reused later for the training and testing functions,
# so the DataLoaders built from these datasets must be created before those cells run.
train = datasets.MNIST(MNIST_DIR, train=True, download=True, transform=mnist_transform)
test = datasets.MNIST(MNIST_DIR, train=False, download=True, transform=mnist_transform)

In [4]:
# One batch holds HALF of each split (integer division of the split length by 2). Only the
# first batch of each loader is taken below via next(iter(...)), so the other half of every
# split is never used.
training_data = torch.utils.data.DataLoader(train, batch_size=len(train) // 2, shuffle=True)
test_data = torch.utils.data.DataLoader(test, batch_size=len(test) // 2, shuffle=True)

# Create the CNN (loosely based on VGG11: stacked 3x3 convolutions with 2x2 max-pooling, but only 4 conv layers instead of 8)
Source: Page 3/14, Table 1, Configuration A, https://arxiv.org/pdf/1409.1556.pdf

## Client-side and edge-side models (the CNN is split after its second convolutional block)

In [5]:
class Net_client(nn.Module):
    """Client-side half of the split CNN.

    Two 3x3 convolutions, each followed by ReLU and 2x2 max-pooling. Takes single-channel
    images and returns 64-channel feature maps that are handed to the edge model. For
    64x64 inputs each sample comes out as 64 x 14 x 14 (64 -> 62 -> 31 -> 29 -> 14).
    """

    def __init__(self):
        super().__init__()
        # 1 input channel -> 32 feature maps, 3x3 kernel
        self.conv1 = nn.Conv2d(1, 32, 3)
        # 32 -> 64 feature maps, 3x3 kernel
        self.conv2 = nn.Conv2d(32, 64, 3)

    def convs(self, x):
        """Apply conv -> ReLU -> 2x2 max-pool, first with conv1 and then with conv2."""
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), (2, 2))
        return x

    def forward(self, x):
        """The client's part of the forward pass is just the convolutional stack."""
        return self.convs(x)


net_client = Net_client()

In [6]:
class Net_edge(nn.Module):
    """Edge-side half of the split CNN: conv3/conv4 followed by a fully connected head.

    Consumes the 64-channel activations produced by the client model and outputs a
    2-element softmax vector per sample.

    Args:
        to_linear: number of features the conv stack yields per sample, i.e. the input width
            of ``fc1``. The default 1024 matches IMG_SIZE = 64:
            64 -> conv1 62 -> pool 31 -> conv2 29 -> pool 14 (client side)
            -> conv3 12 -> pool 6 -> conv4 4 -> pool 2, and 256 * 2 * 2 = 1024.
            Pass a different value if the image size changes.
    """

    def __init__(self, to_linear=1024):
        super().__init__()

        # Continue where the client model stops: third and fourth convolutional layers
        self.conv3 = nn.Conv2d(64, 128, 3)
        self.conv4 = nn.Conv2d(128, 256, 3)

        # Flattened conv output width. It has to be known here because fc1 is built now.
        self._to_linear = to_linear
        self.fc1 = nn.Linear(self._to_linear, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 2)

    def convs(self, x):
        """Run conv3 and conv4, each followed by ReLU and 2x2 max-pooling."""
        # Convolutional layer 3 + activation + max-pooling
        x = self.conv3(x)
        x = F.relu(x)
        x = F.max_pool2d(x, (2, 2))
        # Convolutional layer 4 + activation + max-pooling
        x = self.conv4(x)
        x = F.relu(x)
        x = F.max_pool2d(x, (2, 2))

        # Only taken if _to_linear was reset to None after construction: infer the
        # flattened per-sample size (channels * height * width) from the first sample.
        if self._to_linear is None:
            self._to_linear = x[0].shape[0] * x[0].shape[1] * x[0].shape[2]
        return x

    def forward(self, x):
        """Finish the forward pass: convs, flatten, three linear layers, softmax."""
        x = self.convs(x)
        # Flatten each sample to a vector of _to_linear features
        x = x.view(-1, self._to_linear)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        # Turn the 2 logits into probabilities. NOTE(review): this notebook trains with MSE on
        # these probabilities rather than cross-entropy on the logits.
        x = F.softmax(x, dim = 1)
        return x


net_edge = Net_edge()

In [7]:
# Names used by the rest of the notebook for the freshly initialized client and edge halves
model_client, model_edge = net_client, net_edge

In [8]:
def count_parameters(model):
    """Return how many trainable scalar parameters (requires_grad=True) `model` has."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [9]:
# Display the layer summary of the client-side model
model_client

Net_client(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
)

In [10]:
# Display the layer summary of the edge-side model
model_edge

Net_edge(
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
  (conv4): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=1024, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=2, bias=True)
)

In [11]:
# Trainable parameters of the full split network = client half + edge half
total_parameters = sum(count_parameters(part) for part in (model_client, model_edge))
print(total_parameters)

1044482


# Establish your loss function

In [12]:
# Mean squared error between the edge model's 2-element softmax output and the
# 2-element targets built below
loss_function = nn.MSELoss()

# Separate the data into images and labels for training and testing (pixel values were already scaled to [0, 1] by `transforms.ToTensor()`)

In [13]:
# A DataLoader must be iterated to get its contents, so take the first batch of each loader.
# Images and labels MUST come from the same next() call: with shuffle=True every fresh iter()
# reshuffles the data, so two separate next(iter(...)) calls would pair each image with the
# label of a different sample. Labels are converted to floats for the MSE loss.
X, y_labels = next(iter(training_data))
y_unformatted = y_labels.type(torch.FloatTensor)
X_test, y_test_labels = next(iter(test_data))
y_test_unformatted = y_test_labels.type(torch.FloatTensor)

# The two other cases in this paper use 2-dimensional labels, while MNIST comes with a single
# class label per sample, e.g. (9). Turn it into (9, 0) by stacking it with a zero column.
# NOTE(review): for non-negative labels torch.argmax of such a pair is always 0 (ties resolve
# to the first index), and values above 1 cannot be reached by a softmax output. Confirm
# this is the intended label format.
y_unformatted_addition = torch.zeros(y_unformatted.size())
y_test_unformatted_addition = torch.zeros(y_test_unformatted.size())
# Stack along the last dimension (0 would stack vertically, -1 stacks side by side)
y = torch.stack((y_unformatted, y_unformatted_addition), -1)
y_test = torch.stack((y_test_unformatted, y_test_unformatted_addition), -1)


In [14]:
# The whole loaded training batch is used for training, and the loaded test batch serves as
# the testing (validation) data; no separate validation split is carved out of training.
train_X, train_y = X, y
test_X, test_y = X_test, y_test

# Federated Split Learning

## Establish the virtual workers, their data, their NNs, and their optimizers

In [15]:
# Hook PySyft into torch so tensors and modules gain the .send()/.move()/.get() methods used below
hook = sy.TorchHook(torch)

# Three clients that own the images and three edges that own the labels (edge k serves client k)
clients = [sy.VirtualWorker(hook, id=worker_id) for worker_id in ("client1", "client2", "client3")]
edges = [sy.VirtualWorker(hook, id=worker_id) for worker_id in ("edge1", "edge2", "edge3")]

# Individual names for direct access later on
client1, client2, client3 = clients
edge1, edge2, edge3 = edges

In [16]:
# Split the training data into contiguous shards: shard k (k = 0, 1, 2) covers
# [int(k * n / C), int((k + 1) * n / C)) with n = len(train_X) and C = len(clients).
# Image shards go to the clients (reshaped to N x 1 x IMG_SIZE x IMG_SIZE); the matching
# label shards go to the paired edges.
n_train = len(train_X)
n_clients = len(clients)
shard_bounds = [int(k * n_train / n_clients) for k in range(4)]
shard_slices = [slice(shard_bounds[k], shard_bounds[k + 1]) for k in range(3)]

train_X_client1, train_X_client2, train_X_client3 = (
    train_X[s].view(-1, 1, IMG_SIZE, IMG_SIZE) for s in shard_slices
)
train_y_edge1, train_y_edge2, train_y_edge3 = (train_y[s] for s in shard_slices)

In [17]:
# Start from empty workers, in case objects from a previous run are still stored on them
for worker in (client1, client2, client3, edge1, edge2, edge3):
    worker.clear_objects()

<VirtualWorker id:edge3 #objects:0>

In [18]:
# Model-centric FL: every client starts from the same client-side weights and every edge
# from the same edge-side weights, each as its own copy.
client1_model, client2_model, client3_model = (model_client.copy() for _ in range(3))
edge1_model, edge2_model, edge3_model = (model_edge.copy() for _ in range(3))

# One SGD optimizer per copy, bound to that copy's parameters
client1_optimizer, client2_optimizer, client3_optimizer = (
    optim.SGD(m.parameters(), lr=LR) for m in (client1_model, client2_model, client3_model)
)
edge1_optimizer, edge2_optimizer, edge3_optimizer = (
    optim.SGD(m.parameters(), lr=LR) for m in (edge1_model, edge2_model, edge3_model)
)

In [19]:
# Group the per-worker models and optimizers (client and edge objects only)
client_models = [client1_model, client2_model, client3_model]
client_optimizers = [client1_optimizer, client2_optimizer, client3_optimizer]
edge_models = [edge1_model, edge2_model, edge3_model]
edge_optimizers = [edge1_optimizer, edge2_optimizer, edge3_optimizer]

# One row per client/edge pair: [client, its model, its optimizer, edge, its model, its optimizer]
worker_collection = [
    [client, c_model, c_opt, edge, e_model, e_opt]
    for client, c_model, c_opt, edge, e_model, e_opt in zip(
        (client1, client2, client3), client_models, client_optimizers,
        (edge1, edge2, edge3), edge_models, edge_optimizers,
    )
]

## Federated Split Training

In [20]:
def train():
    """Run one epoch of federated split learning, then federated-average the halves.

    For every mini-batch of BATCH_SIZE samples, each client k runs its half of the network
    on its image shard. The detached activations are moved to edge k, which finishes the
    forward pass, computes the loss against its label shard and back-propagates. The
    gradient w.r.t. the activations is sent back to client k, which back-propagates
    through its own half. After the last batch the three client models and the three edge
    models are each combined with ``utils.federated_avg``.

    Uses (and updates through their optimizers) the module-level client*/edge* models,
    optimizers and workers, and the per-worker data shards.

    Returns:
        tuple: ``(fed_avg_client, fed_avg_edge, batch_times, total_epoch_time)``.
        ``batch_times`` holds one timing in seconds per batch. ``total_epoch_time`` is their
        sum plus the time spent on the two federated averages.
    """
    total_epoch_time = 0
    batch_times = []
    shard_size = int(len(train_X)/len(clients))
    for i in tqdm(range(0, shard_size, BATCH_SIZE)):

        # Send the models to their workers. They are fetched back with .get() at the end of
        # every batch so clear_objects() can wipe the workers (PySyft offers no easy way to
        # delete one specific object). A real deployment would keep the models remote.
        client1_model.send(client1)
        client2_model.send(client2)
        client3_model.send(client3)
        edge1_model.send(edge1)
        edge2_model.send(edge2)
        edge3_model.send(edge3)

        # Send this batch's images to the clients and the matching labels to the edges
        batch_X_1 = train_X_client1[i : i + BATCH_SIZE].send(client1)
        batch_X_2 = train_X_client2[i : i + BATCH_SIZE].send(client2)
        batch_X_3 = train_X_client3[i : i + BATCH_SIZE].send(client3)

        batch_y_1 = train_y_edge1[i : i + BATCH_SIZE].send(edge1)
        batch_y_2 = train_y_edge2[i : i + BATCH_SIZE].send(edge2)
        batch_y_3 = train_y_edge3[i : i + BATCH_SIZE].send(edge3)

        # Zero the gradients of all models on clients and edges
        client1_optimizer.zero_grad()
        client2_optimizer.zero_grad()
        client3_optimizer.zero_grad()
        edge1_optimizer.zero_grad()
        edge2_optimizer.zero_grad()
        edge3_optimizer.zero_grad()

        # Forward pass on the clients (perf_counter: monotonic, high-resolution interval clock)
        fp_client_start = time.perf_counter()
        intermediate_1 = client1_model(batch_X_1)
        intermediate_2 = client2_model(batch_X_2)
        intermediate_3 = client3_model(batch_X_3)
        fp_client_duration = time.perf_counter() - fp_client_start

        # Send the activations to the edges. .detach() splits the computational graph here;
        # requires_grad_() makes the edge compute the gradient w.r.t. these activations.
        data_for_edge1 = intermediate_1.detach().move(edge1).requires_grad_()
        data_for_edge2 = intermediate_2.detach().move(edge2).requires_grad_()
        data_for_edge3 = intermediate_3.detach().move(edge3).requires_grad_()

        # Finish the forward pass on the edges
        fp_edge_start = time.perf_counter()
        pred1 = edge1_model(data_for_edge1)
        pred2 = edge2_model(data_for_edge2)
        pred3 = edge3_model(data_for_edge3)
        fp_edge_duration = time.perf_counter() - fp_edge_start

        # Losses and back-propagation on the edges
        bp_edge_start = time.perf_counter()
        loss1 = loss_function(pred1, batch_y_1)
        loss2 = loss_function(pred2, batch_y_2)
        loss3 = loss_function(pred3, batch_y_3)

        loss1.backward()
        loss2.backward()
        loss3.backward()
        edge1_optimizer.step()
        edge2_optimizer.step()
        edge3_optimizer.step()
        bp_edge_duration = time.perf_counter() - bp_edge_start

        # Send the activation gradients back to the clients
        intermediate_1.move(client1)
        grad_intermediate_1 = data_for_edge1.grad.copy().move(client1)
        intermediate_2.move(client2)
        grad_intermediate_2 = data_for_edge2.grad.copy().move(client2)
        intermediate_3.move(client3)
        grad_intermediate_3 = data_for_edge3.grad.copy().move(client3)

        # Back-propagation on the clients
        bp_client_start = time.perf_counter()
        intermediate_1.backward(grad_intermediate_1)
        client1_optimizer.step()

        intermediate_2.backward(grad_intermediate_2)
        client2_optimizer.step()

        intermediate_3.backward(grad_intermediate_3)
        client3_optimizer.step()
        bp_client_duration = time.perf_counter() - bp_client_start

        # The stage times of the three sequentially simulated client/edge pairs are summed
        # and divided by the number of clients, i.e. the average time per pair for this batch
        total_batch_time = (fp_client_duration + fp_edge_duration
                            + bp_edge_duration + bp_client_duration)/len(clients)
        batch_times.append(total_batch_time)
        total_epoch_time += total_batch_time

        # Fetch the models back (only needed because the workers are wiped next)
        client1_model.get()
        client2_model.get()
        client3_model.get()
        edge1_model.get()
        edge2_model.get()
        edge3_model.get()

        # Remove this batch's objects from all workers before the next batch
        client1.clear_objects()
        client2.clear_objects()
        client3.clear_objects()
        edge1.clear_objects()
        edge2.clear_objects()
        edge3.clear_objects()

    # Federated averaging of the client halves and of the edge halves (also timed)
    fed_client_start = time.perf_counter()
    fed_avg_client = utils.federated_avg({
        "client1": client1_model,
        "client2": client2_model,
        "client3": client3_model
    })
    total_epoch_time += time.perf_counter() - fed_client_start

    fed_edge_start = time.perf_counter()
    fed_avg_edge = utils.federated_avg({
        "edge1": edge1_model,
        "edge2": edge2_model,
        "edge3": edge3_model
    })
    total_epoch_time += time.perf_counter() - fed_edge_start
    print("TOTAL TIME FOR THIS EPOCH = ", round(total_epoch_time, 4), " s")
    return fed_avg_client, fed_avg_edge, batch_times, total_epoch_time

## Function used for testing

In [21]:
def test(new_client_model, new_edge_model):
    """Evaluate the split model (client half, then edge half) on the global test split.

    Each sample of ``test_X`` is passed through ``new_client_model`` and then
    ``new_edge_model``. The arg-max of the output is compared with the arg-max of the
    matching row of ``test_y``, and the resulting accuracy is printed.

    Args:
        new_client_model: client-side module, e.g. the averaged model returned by ``train()``.
        new_edge_model: edge-side module, e.g. the averaged model returned by ``train()``.

    Returns:
        float: ``correct / total`` (0.0 when ``test_X`` is empty). Callers that ignore the
        return value are unaffected.
    """
    correct = 0
    total = 0

    # These models are copied by update_models() and trained again in the next epoch, so
    # remember their current mode and restore it afterwards instead of leaving them in eval mode.
    client_was_training = new_client_model.training
    edge_was_training = new_edge_model.training
    # eval() only switches layer behavior (e.g. dropout/batchnorm); it is torch.no_grad()
    # below that stops gradient tracking during testing.
    new_client_model.eval()
    new_edge_model.eval()
    try:
        with torch.no_grad():
            print("Initiated model testing:")
            for i in tqdm(range(len(test_X))):
                # Real class of the sample
                real_class = torch.argmax(test_y[i])

                # Prediction for the sample: client half, then edge half, first (only) row
                output = new_client_model(test_X[i].view(-1, 1, IMG_SIZE, IMG_SIZE))
                output = new_edge_model(output)[0]
                predicted_class = torch.argmax(output)

                if predicted_class == real_class:
                    correct += 1
                total += 1
    finally:
        new_client_model.train(client_was_training)
        new_edge_model.train(edge_was_training)

    # Guard against an empty test set instead of raising ZeroDivisionError
    accuracy = correct / total if total else 0.0
    print("Accuracy of the new model = ", round(accuracy, 3), " \n \n")
    return accuracy

In [22]:
def update_models(new_client_model, new_edge_model):
    """Hand the averaged models to every client and edge for the next epoch.

    Clears all objects stored on the six virtual workers and gives each client a fresh copy
    of ``new_client_model`` and each edge a fresh copy of ``new_edge_model``. It builds a new
    SGD optimizer (lr=LR) for every copy and rebinds the corresponding module-level names.
    The grouping lists (``client_models``, ``client_optimizers``, ``edge_models``,
    ``edge_optimizers`` and ``worker_collection``) are rebuilt as well, so none of them keeps
    pointing at the previous epoch's models or optimizers.

    Args:
        new_client_model: averaged client-side model (e.g. returned by ``train()``).
        new_edge_model: averaged edge-side model (e.g. returned by ``train()``).
    """
    global client1_model, client2_model, client3_model
    global edge1_model, edge2_model, edge3_model
    global client1_optimizer, client2_optimizer, client3_optimizer
    global edge1_optimizer, edge2_optimizer, edge3_optimizer
    global client_models, client_optimizers, edge_models, edge_optimizers
    global worker_collection

    # Clear the workers of any objects left over from the previous run
    for worker in (client1, client2, client3, edge1, edge2, edge3):
        worker.clear_objects()

    # Model-centric FL: every client (edge) starts from the same averaged weights
    client1_model = new_client_model.copy()
    client2_model = new_client_model.copy()
    client3_model = new_client_model.copy()

    edge1_model = new_edge_model.copy()
    edge2_model = new_edge_model.copy()
    edge3_model = new_edge_model.copy()

    # A fresh optimizer per copy, bound to that copy's parameters
    client1_optimizer = optim.SGD(client1_model.parameters(), lr=LR)
    client2_optimizer = optim.SGD(client2_model.parameters(), lr=LR)
    client3_optimizer = optim.SGD(client3_model.parameters(), lr=LR)

    edge1_optimizer = optim.SGD(edge1_model.parameters(), lr=LR)
    edge2_optimizer = optim.SGD(edge2_model.parameters(), lr=LR)
    edge3_optimizer = optim.SGD(edge3_model.parameters(), lr=LR)

    # Rebuild every grouping so none of them references the previous epoch's objects
    client_models = [client1_model, client2_model, client3_model]
    client_optimizers = [client1_optimizer, client2_optimizer, client3_optimizer]
    edge_models = [edge1_model, edge2_model, edge3_model]
    edge_optimizers = [edge1_optimizer, edge2_optimizer, edge3_optimizer]

    worker_collection = [[client1, client1_model, client1_optimizer, edge1, edge1_model, edge1_optimizer],
                         [client2, client2_model, client2_optimizer, edge2, edge2_model, edge2_optimizer],
                         [client3, client3_model, client3_optimizer, edge3, edge3_model, edge3_optimizer]
                        ]

In [23]:
# # Debugging helpers: list the ids of the objects stored on a worker, or remove one specific object
# client1.object_store._objects.keys()
# client1.object_store.rm_obj(obj_id=...)

# RUN THE MODEL

In [24]:
# Define your number of epochs
epochs = 5
epoch_times = []

# Train all workers for the set number of epochs
for epoch in range(epochs):
    
    # Start counting the time for this epoch
#     start_time = time.time()
    print(f"Epoch Number {epoch + 1}")
        
    # Train the individual models, and then obtain the federated averaged model
#     train_start_time = time.time()
    new_client_model, new_edge_model, batch_times, epoch_time = train()
#     train_total_time = time.time() - train_start_time
#     print("Total TRAIN time for epoch ", epoch, " = ", 
#           round(train_total_time/(len(clients) * 60), 2), " min")
    
    # Save the epoch times
    epoch_times.append(epoch_time)
    
    # Stop counting the time
#     total_time = time.time() - start_time
#     print('Time for this epoch', round(total_time/60, 2), " min")
    
    # Test your new model to keep a log of how good we're doing per epoch 
    test(new_client_model, new_edge_model)

    # Re-organize everything before starting next epoch
    update_models(new_client_model, new_edge_model)
    
    # Save the batch times
    df_batch = pd.DataFrame(batch_times)
    df_batch.to_csv("./Batch_times/STD_MNIST_FSL_epoch_" + str(epoch) + ".csv")

# OUTSIDE THE FOR LOOP
# Save the epoch times
df_epoch = pd.DataFrame(epoch_times)
df_epoch.to_csv("./Epoch_times/STD_MNIST_FSL.csv")
    
# Clean the global namespace after run is done
%reset -f

Epoch Number 1


100%|██████████| 100/100 [16:12<00:00,  9.72s/it]


TOTAL TIME FOR THIS EPOCH =  41.7176  s
Initiated model testing:


100%|██████████| 5000/5000 [00:26<00:00, 186.33it/s]


Accuracy of the new model =  0.9  
 

Epoch Number 2


100%|██████████| 100/100 [15:53<00:00,  9.54s/it]


TOTAL TIME FOR THIS EPOCH =  40.711  s
Initiated model testing:


100%|██████████| 5000/5000 [00:26<00:00, 190.68it/s]


Accuracy of the new model =  0.9  
 

Epoch Number 3


100%|██████████| 100/100 [15:45<00:00,  9.45s/it]


TOTAL TIME FOR THIS EPOCH =  40.1119  s
Initiated model testing:


100%|██████████| 5000/5000 [00:24<00:00, 203.40it/s]


Accuracy of the new model =  0.9  
 

Epoch Number 4


100%|██████████| 100/100 [15:47<00:00,  9.48s/it]


TOTAL TIME FOR THIS EPOCH =  40.1855  s
Initiated model testing:


100%|██████████| 5000/5000 [00:25<00:00, 196.90it/s]


Accuracy of the new model =  0.9  
 

Epoch Number 5


100%|██████████| 100/100 [15:48<00:00,  9.49s/it]


TOTAL TIME FOR THIS EPOCH =  40.3413  s
Initiated model testing:


100%|██████████| 5000/5000 [00:25<00:00, 198.91it/s]


Accuracy of the new model =  0.9  
 

