# Importing libraries

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from copy import deepcopy

import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch

import torchsummary
from torchsummary import summary


from create_MNIST_datasets import get_MNIST, plot_samples
from alibi.confidence import TrustScore


In [None]:
import syft as sy

# Loading the server's test and training data

In [None]:
# Server-side datasets: the held-out test set used for trust scoring and the
# server's training data. NOTE(review): torch.load unpickles arbitrary objects,
# so only load files from a trusted source. train_data is not used in the
# visible notebook cells.
testing_set = torch.load('server_t.pth')
train_data = torch.load('server.pth')

In [None]:
class autoencoder(nn.Module):
    """Small convolutional autoencoder for single-channel images.

    With a (N, 1, 28, 28) batch, ``encoder`` produces a (N, 4, 3, 3) code
    through three conv -> ReLU -> 2x2 max-pool stages, and ``decoder``
    upsamples it back to (N, 1, 28, 28) with transposed convolutions.
    ``forward`` squashes the reconstruction into [0, 1] with a sigmoid.

    Keep the class name, the ``encoder``/``decoder`` attribute names and the
    layer order unchanged: they determine the ``state_dict`` keys
    (``encoder.0.weight`` ...), and ``auto_encoder.pth`` is read with
    ``torch.load``, which (if it stores the whole module) resolves this
    class by name.
    """

    def __init__(self):
        super().__init__()
        # Encoder stages: 3x3 conv with same padding, ReLU, then 2x2 max-pool.
        # Channels 1 -> 16 -> 8 -> 4; spatial size 28 -> 14 -> 7 -> 3.
        enc_layers = []
        for c_in, c_out in ((1, 16), (16, 8), (8, 4)):
            enc_layers += [
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]
        self.encoder = nn.Sequential(*enc_layers)

        # Decoder: transposed convolutions upsample 3 -> 7 -> 14 -> 28, then a
        # 3x3 conv collapses the 32 feature maps into one output channel.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(4, 8, 3, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(8, 16, 2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 32, 2, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 1, 3, padding=1),
        )

    def forward(self, x):
        """Return the sigmoid-squashed reconstruction of ``x``."""
        return torch.sigmoid(self.decoder(self.encoder(x)))


# Randomly initialised instance; the next cell replaces it with the model
# loaded from 'auto_encoder.pth'.
model = autoencoder()

In [None]:
# Replace the fresh instance with the autoencoder saved in 'auto_encoder.pth'
# (presumably trained beforehand). Only its `encoder` is used later, to embed
# test samples for the trust score. NOTE(review): torch.load unpickles
# arbitrary objects; only load trusted files.
model = torch.load('auto_encoder.pth')

# Connection establishment with client0

In [None]:
# Join client 0's Duet session. loopback=True presumably means the client
# runs on this same machine -- TODO confirm for real deployments.
duet_0 = sy.join_duet(loopback=True)

# Connection establishment with client1

In [None]:
# Join client 1's Duet session. loopback=True presumably means the client
# runs on this same machine -- TODO confirm for real deployments.
duet_1 = sy.join_duet(loopback=True)

# Connection establishment with client2

In [None]:
# Join client 2's Duet session. loopback=True presumably means the client
# runs on this same machine -- TODO confirm for real deployments.
duet_2 = sy.join_duet(loopback=True)

# Model definition

In [None]:
class CNN(nn.Module):
    """Single linear layer + softmax classifier over 10 classes.

    Each sample is flattened to 784 values (e.g. a 1x28x28 image) and mapped
    to class probabilities. Despite its name it contains no convolutions; the
    class name and the ``fc`` attribute are kept unchanged because saved
    models (presumably 'initial_model.pth') may reference them when loaded.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(784, 10)

    def forward(self, x):
        """Return class probabilities of shape (batch, 10).

        ``x`` is flattened from dim 1 onward, so every sample must contain
        exactly 784 values.
        """
        x = torch.flatten(x, 1)
        x = self.fc(x)
        # Normalise over the class dimension explicitly: calling softmax
        # without ``dim`` is deprecated and emits a UserWarning.
        return F.softmax(x, dim=1)


In [None]:
# Initial global model, plus one independent deep copy per client; each copy
# later receives that client's updated parameter vector.
model_server = torch.load('initial_model.pth')
model_server_0, model_server_1, model_server_2 = (
    deepcopy(model_server) for _ in range(3)
)

# Initial model parameters

In [None]:
# Flatten all parameters of the initial global model into a single 1-D tensor
# so they can be shipped to each client as one object.
parameters = torch.nn.utils.parameters_to_vector(model_server.parameters())


# Sending the initial model parameters to the clients

In [None]:
# Send the flattened initial parameters to client 0.
parameters.send(duet_0)

In [None]:
# Send the flattened initial parameters to client 1.
parameters.send(duet_1)

In [None]:
# Send the flattened initial parameters to client 2.
parameters.send(duet_2)

# Model update from client0

In [None]:
# Inspect the objects in client 0's Duet store (presumably rendered as a
# pandas table in the cell output).
duet_0.store.pandas

In [None]:
# Presumably fails until client 0 approves access: print the reason instead
# of letting the exception abort the notebook.
try:
    duet_0.store[0].get()
except Exception as err:
    print(err)

In [None]:
# Ask client 0 to approve access to the first object in its store (the
# updated parameters needed for aggregation).
duet_0.store[0].request(reason="Please approve,updated parameters are needeed for aggregating")

In [None]:
# After approval, download client 0's updated parameter vector.
y0 = duet_0.store[0].get()

In [None]:
# Copy client 0's parameter vector into its private model copy. The same call
# is repeated before client 0's trust-score evaluation further down.
torch.nn.utils.vector_to_parameters(y0,model_server_0.parameters())

# Model update from client1

In [None]:
# Inspect the objects in client 1's Duet store (presumably rendered as a
# pandas table in the cell output).
duet_1.store.pandas

In [None]:
# Presumably fails until client 1 approves access: print the reason instead
# of letting the exception abort the notebook.
try:
    duet_1.store[0].get()
except Exception as err:
    print(err)

In [None]:
# Ask client 1 to approve access to the first object in its store (the
# updated parameters needed for aggregation).
duet_1.store[0].request(reason="Please approve,updated parameters are needeed for aggregating")

In [None]:
# After approval, download client 1's updated parameter vector.
y1 = duet_1.store[0].get()

# Model update from client2

In [None]:
# Inspect the objects in client 2's Duet store (presumably rendered as a
# pandas table in the cell output).
duet_2.store.pandas

In [None]:
# Presumably fails until client 2 approves access: print the reason instead
# of letting the exception abort the notebook.
try:
    duet_2.store[0].get()
except Exception as err:
    print(err)

In [None]:
# Ask client 2 to approve access to the first object in its store (the
# updated parameters needed for aggregation).
duet_2.store[0].request(reason="Please approve,updated parameters are needeed for aggregating")

In [None]:
# After approval, download client 2's updated parameter vector.
y2 = duet_2.store[0].get()

In [None]:
def trust_score(test_samples, model_update, model):
    """Print and return the alibi trust score of ``model_update`` on test data.

    For every ``(features, labels)`` batch drawn from ``test_samples``:

    1. ``model.encoder`` embeds ``features``;
    2. a ``TrustScore`` is fitted on those embeddings and the batch's true
       ``labels`` (10 classes);
    3. ``model_update``'s outputs for the same batch are scored against the
       fitted trust score with ``k=5``.

    The per-sample scores and their batch average are printed for each batch.

    Args:
        test_samples: Iterable of ``(features, labels)`` pairs (presumably a
            ``DataLoader`` over the server test set -- TODO confirm).
        model_update: Classifier whose outputs on ``features`` are scored.
        model: Module exposing an ``encoder`` sub-module used for embeddings.

    Returns:
        float: mean trust score over every scored sample, or ``nan`` when
        ``test_samples`` yields no batches.
    """
    ts = TrustScore()
    all_scores = []
    # Pure inference: no autograd graph is needed for these forward passes.
    with torch.no_grad():
        for features, labels in test_samples:
            encoded = model.encoder(features).cpu().numpy()
            predictions = model_update(features).cpu().numpy()
            # Fit and score the *current* batch so that every batch is
            # evaluated on its own data.
            ts.fit(encoded, np.asarray(labels), classes=10)
            score, _closest_class = ts.score(encoded, predictions, k=5)
            print(score)
            print(np.average(score))
            all_scores.append(np.asarray(score))
    if not all_scores:
        return float("nan")
    return float(np.mean(np.concatenate(all_scores)))

# Trust score calculation for local model 0

In [None]:
# Load client 0's received parameter vector into model_server_0 before scoring
# (repeats an earlier cell; harmless).
torch.nn.utils.vector_to_parameters(y0,model_server_0.parameters())

In [None]:
# Score client 0's local model on the server test set, using the
# autoencoder's encoder for the embeddings.
trust_score_local_model_0 = trust_score(testing_set, model_server_0, model)

# Trust score calculation for local model 1

In [None]:
# Load client 1's received parameter vector into model_server_1 before scoring.
torch.nn.utils.vector_to_parameters(y1,model_server_1.parameters())

In [None]:
# Score client 1's local model on the server test set, using the
# autoencoder's encoder for the embeddings.
trust_score_local_model_1 = trust_score(testing_set, model_server_1, model)

# Trust score calculation for local model 2

In [None]:
# Load client 2's received parameter vector into model_server_2 before scoring.
torch.nn.utils.vector_to_parameters(y2,model_server_2.parameters())

In [None]:
# Score client 2's local model on the server test set, using the
# autoencoder's encoder for the embeddings.
trust_score_local_model_2 = trust_score(testing_set, model_server_2, model)

# Trust score calculation for global model

In [None]:
# Aggregate: unweighted mean of the three clients' parameter vectors, loaded
# into the global model. NOTE(review): every client is weighted equally,
# regardless of how much local data it trained on.
y = (y1+y2+y0)/3
torch.nn.utils.vector_to_parameters(y,model_server.parameters())

In [None]:
# Score the aggregated global model the same way as the local models.
trust_score_global_model = trust_score(testing_set, model_server, model)