# Feature Engineering with SHAP values Experiment 1

SHAP Images right after poisoning attack

rounds [1,2,10,75,200]


## Google Colab

In [None]:
from google.colab import drive
drive.flush_and_unmount()
drive.mount('/content/drive', force_remount=True)

import sys
sys.path.append('/content/drive/My Drive/Colab Notebooks')
sys.path.append('/content/drive/My Drive/Colab Notebooks/federated_learning')

!pip install shap==0.40.0

## Experiment Setup

In [None]:
from federated_learning.utils import SHAPUtil, experiment_util, Visualizer
from federated_learning import ClientPlane, Configuration, ObserverConfiguration
from federated_learning.server import Server
from datetime import datetime

In [None]:
# Google Drive location of a trained MNIST model file. No other cell visible in
# this notebook reads this constant.
GOOGLE_COLAB_MODEL_PATH = "/content/drive/My Drive/Colab Notebooks/temp/models/MNISTtrained.model"

## MNIST
For MNIST we experiment with the label flips
(1) 5 → 4,
(2) 1 → 7, and
(3) 3 → 8.

In [None]:
import os

from federated_learning.dataset import MNISTDataset
from federated_learning.nets import MNISTCNN

# Configuration for the MNIST experiments: MNIST dataset and CNN.
# POISONED_CLIENTS starts at 0 here; the experiment cells set it to 1 later.
config = Configuration()
config.POISONED_CLIENTS, config.DATA_POISONING_PERCENTAGE = 0, 1
config.DATASET, config.NETWORK = MNISTDataset, MNISTCNN
config.MODELNAME = config.MNIST_NAME

# Observer settings for this experiment run.
observer_config = ObserverConfiguration()
observer_config.experiment_type, observer_config.experiment_id = "shap_fl_poisoned", 1
observer_config.test = False
observer_config.datasetObserverConfiguration = "MNIST"

# Not referenced by any other cell visible in this notebook.
neutral_label = 2

In [None]:
# Google Colab settings: temp folder on the mounted Drive, dataset folders
# under /content/data, and no VM URL.
config.VM_URL = "none"
config.TEMP = '/content/drive/My Drive/Colab Notebooks/temp'
config.MNIST_DATASET_PATH = '/content/data/mnist'
config.FMNIST_DATASET_PATH = '/content/data/fmnist'
config.CIFAR10_DATASET_PATH = '/content/data/cifar10'

In [None]:
# Instantiate the dataset class selected in the configuration.
data = config.DATASET(config)
# One SHAP helper, built from the test dataloader, is shared by the server,
# the client plane and the visualizer created below.
shap_util = SHAPUtil(data.test_dataloader) 
server = Server(config, observer_config,data.train_dataloader, data.test_dataloader, shap_util)
client_plane = ClientPlane(config, observer_config, data, shap_util)
visualizer = Visualizer(shap_util)

### Experimental Setup

In [None]:
import numpy as np
import copy
import torch
import os

# Rounds for which the server model is written to disk; the SHAP experiment
# cells later in the notebook load these files again.
snapshot_rounds = (2, 5, 10, 75, 100, 200)

for i in range(200):
    if (i + 1) in snapshot_rounds:
        # NOTE: the snapshot is written before run_round(..., i + 1) is called,
        # so MNIST_round_N.model holds the model as it enters round N.
        file = "./temp/models/ex5/MNIST_round_{}.model".format(i + 1)
        # exist_ok replaces the separate exists() check and is safe to repeat.
        os.makedirs(os.path.dirname(file), exist_ok=True)
        torch.save(server.net.state_dict(), file)
    experiment_util.set_rounds(client_plane, server, i + 1)
    experiment_util.run_round(client_plane, server, i + 1)

## alpha(5,4)

In [None]:
# Label pair for the alpha(5,4) experiment: FROM_LABEL 5, TO_LABEL 4.
config.FROM_LABEL, config.TO_LABEL = 5, 4

In [None]:
import torch

# SHAP values are computed for the two labels of the configured flip.
shap_images = [config.FROM_LABEL, config.TO_LABEL]
# "<from>_<to>" tag used in the result folder and file names; it is "5_4" for
# the labels set in the previous cell, which matches the former hard-coded paths.
flip_tag = "{}_{}".format(config.FROM_LABEL, config.TO_LABEL)
# Rounds for which a server snapshot was saved by the training cell.
snapshot_rounds = [2, 5, 10, 75, 100, 200]
result_templates = {
    "server": "./results/ex5/MNIST/{tag}/{rep}_round_{rnd}_MNIST_server_lf_{tag}_ex5.pdf",
    "clean": "./results/ex5/MNIST/{tag}/{rep}_round_{rnd}_MNIST_clean_client_lf_{tag}_ex5.pdf",
    "clean_compare": "./results/ex5/MNIST/{tag}/{rep}_round_{rnd}_MNIST_clean_client_compare_lf_{tag}_ex5.pdf",
    "poisoned": "./results/ex5/MNIST/{tag}/{rep}_round_{rnd}_MNIST_poisoned_client_lf_{tag}_ex5.pdf",
    # The double underscore is kept so file names match results of earlier runs.
    "poisoned_compare": "./results/ex5/MNIST/{tag}/{rep}_round_{rnd}__MNIST_poisoned_client_compare_lf_{tag}_ex5.pdf",
}

for repetition in range(5):
    for round_no in snapshot_rounds:
        model_file = "./temp/models/ex5/MNIST_round_{}.model".format(round_no)
        paths = {
            role: template.format(tag=flip_tag, rep=repetition, rnd=round_no)
            for role, template in result_templates.items()
        }
        # Read the snapshot once; load_state_dict copies the values into the
        # network, so the same dict can restore the server twice below.
        snapshot_state = torch.load(model_file)

        # Start from the saved server model and a reset client plane.
        server.net = MNISTCNN()
        server.net.load_state_dict(snapshot_state)
        client_plane.reset_default_client_nets()
        client_plane.reset_poisoning_attack()

        # Baseline: test metrics and SHAP values of the restored server model.
        server.test()
        recall, precision, accuracy = server.analize_test()
        print("Original", recall, precision, accuracy)
        server_shap = server.get_shap_values(shap_images)
        visualizer.plot_shap_values(server_shap, paths["server"], indices=shap_images)

        # Configure one poisoned client, then pick one clean and one poisoned client.
        config.POISONED_CLIENTS = 1
        experiment_util.update_configs(client_plane, server, config, observer_config)
        client_plane.poison_clients()
        clean_clients = experiment_util.select_random_clean(client_plane, config, 1)
        poisoned_clients = experiment_util.select_poisoned(client_plane, 1)

        # Clean client: train from the server parameters and compare its SHAP
        # values with the server's.
        client_plane.update_clients(server.get_nn_parameters())
        print("Client Clean {}".format(round_no + 1))
        clean_client = client_plane.clients[clean_clients[0]]
        clean_client.train(round_no + 1)
        clean_client_shap = clean_client.get_shap_values(shap_images)
        visualizer.plot_shap_values(clean_client_shap, paths["clean"], indices=shap_images)
        visualizer.compare_shap_values(clean_client_shap, server_shap, paths["clean_compare"], indices=shap_images)

        # Poisoned client: restore the server snapshot again before handing
        # its parameters to the clients.
        server.net = MNISTCNN()
        server.net.load_state_dict(snapshot_state)
        client_plane.update_clients(server.get_nn_parameters())
        print("Client Poisoned {}".format(round_no + 1))
        poisoned_client = client_plane.clients[poisoned_clients[0]]
        poisoned_client.train(round_no + 1)
        poisoned_client_shap = poisoned_client.get_shap_values(shap_images)
        visualizer.plot_shap_values(poisoned_client_shap, paths["poisoned"], indices=shap_images)
        visualizer.compare_shap_values(poisoned_client_shap, server_shap, paths["poisoned_compare"], indices=shap_images)

        # Reset after every run as the other experiment cells do, so the client
        # plane is not left poisoned when this cell finishes.
        client_plane.reset_default_client_nets()
        client_plane.reset_poisoning_attack()
        print("Round {} finished".format(repetition + 1))

## alpha(1,7)

In [None]:
# Label pair for the alpha(1,7) experiment: FROM_LABEL 1, TO_LABEL 7.
config.FROM_LABEL, config.TO_LABEL = 1, 7

In [None]:
import torch

# SHAP values are computed for the two labels of the configured flip.
shap_images = [config.FROM_LABEL, config.TO_LABEL]
# "<from>_<to>" tag used in the result folder and file names; it is "1_7" for
# the labels set in the previous cell, which matches the former hard-coded paths.
flip_tag = "{}_{}".format(config.FROM_LABEL, config.TO_LABEL)
# Rounds for which a server snapshot was saved by the training cell.
snapshot_rounds = [2, 5, 10, 75, 100, 200]
result_templates = {
    "server": "./results/ex5/MNIST/{tag}/{rep}_round_{rnd}_MNIST_server_lf_{tag}_ex5.pdf",
    "clean": "./results/ex5/MNIST/{tag}/{rep}_round_{rnd}_MNIST_clean_client_lf_{tag}_ex5.pdf",
    "clean_compare": "./results/ex5/MNIST/{tag}/{rep}_round_{rnd}_MNIST_clean_client_compare_lf_{tag}_ex5.pdf",
    "poisoned": "./results/ex5/MNIST/{tag}/{rep}_round_{rnd}_MNIST_poisoned_client_lf_{tag}_ex5.pdf",
    # The double underscore is kept so file names match results of earlier runs.
    "poisoned_compare": "./results/ex5/MNIST/{tag}/{rep}_round_{rnd}__MNIST_poisoned_client_compare_lf_{tag}_ex5.pdf",
}

# NOTE(review): only repetition index 4 is executed here, while the other
# experiment cells use range(5) -- confirm whether this is intentional.
for repetition in range(4, 5):
    for round_no in snapshot_rounds:
        model_file = "./temp/models/ex5/MNIST_round_{}.model".format(round_no)
        paths = {
            role: template.format(tag=flip_tag, rep=repetition, rnd=round_no)
            for role, template in result_templates.items()
        }
        # Read the snapshot once; load_state_dict copies the values into the
        # network, so the same dict can restore the server twice below.
        snapshot_state = torch.load(model_file)

        # Start from the saved server model and a reset client plane. The reset
        # at the start (present in the other experiment cells) keeps the first
        # run from inheriting client state left behind by a previous cell.
        server.net = MNISTCNN()
        server.net.load_state_dict(snapshot_state)
        client_plane.reset_default_client_nets()
        client_plane.reset_poisoning_attack()

        # Baseline: test metrics and SHAP values of the restored server model.
        server.test()
        recall, precision, accuracy = server.analize_test()
        print("Original", recall, precision, accuracy)
        server_shap = server.get_shap_values(shap_images)
        visualizer.plot_shap_values(server_shap, paths["server"], indices=shap_images)

        # Configure one poisoned client, then pick one clean and one poisoned client.
        config.POISONED_CLIENTS = 1
        experiment_util.update_configs(client_plane, server, config, observer_config)
        client_plane.poison_clients()
        clean_clients = experiment_util.select_random_clean(client_plane, config, 1)
        poisoned_clients = experiment_util.select_poisoned(client_plane, 1)

        # Clean client: train from the server parameters and compare its SHAP
        # values with the server's.
        client_plane.update_clients(server.get_nn_parameters())
        print("Client Clean {}".format(round_no + 1))
        clean_client = client_plane.clients[clean_clients[0]]
        clean_client.train(round_no + 1)
        clean_client_shap = clean_client.get_shap_values(shap_images)
        visualizer.plot_shap_values(clean_client_shap, paths["clean"], indices=shap_images)
        visualizer.compare_shap_values(clean_client_shap, server_shap, paths["clean_compare"], indices=shap_images)

        # Poisoned client: restore the server snapshot again before handing
        # its parameters to the clients.
        server.net = MNISTCNN()
        server.net.load_state_dict(snapshot_state)
        client_plane.update_clients(server.get_nn_parameters())
        print("Client Poisoned {}".format(round_no + 1))
        poisoned_client = client_plane.clients[poisoned_clients[0]]
        poisoned_client.train(round_no + 1)
        poisoned_client_shap = poisoned_client.get_shap_values(shap_images)
        visualizer.plot_shap_values(poisoned_client_shap, paths["poisoned"], indices=shap_images)
        visualizer.compare_shap_values(poisoned_client_shap, server_shap, paths["poisoned_compare"], indices=shap_images)

        # Leave the client plane reset when the run is done.
        client_plane.reset_default_client_nets()
        client_plane.reset_poisoning_attack()
        print("Round {} finished".format(repetition + 1))

In [None]:
# Label pair for the experiment cell that follows: FROM_LABEL 3, TO_LABEL 8.
config.FROM_LABEL, config.TO_LABEL = 3, 8

In [None]:
import torch

# SHAP values are computed for the two labels of the configured flip.
shap_images = [config.FROM_LABEL, config.TO_LABEL]
# "<from>_<to>" tag used in the result folder and file names; it is "3_8" for
# the labels set in the previous cell, which matches the former hard-coded paths.
flip_tag = "{}_{}".format(config.FROM_LABEL, config.TO_LABEL)
# Rounds for which a server snapshot was saved by the training cell.
snapshot_rounds = [2, 5, 10, 75, 100, 200]
result_templates = {
    "server": "./results/ex5/MNIST/{tag}/{rep}_round_{rnd}_MNIST_server_lf_{tag}_ex5.pdf",
    "clean": "./results/ex5/MNIST/{tag}/{rep}_round_{rnd}_MNIST_clean_client_lf_{tag}_ex5.pdf",
    "clean_compare": "./results/ex5/MNIST/{tag}/{rep}_round_{rnd}_MNIST_clean_client_compare_lf_{tag}_ex5.pdf",
    "poisoned": "./results/ex5/MNIST/{tag}/{rep}_round_{rnd}_MNIST_poisoned_client_lf_{tag}_ex5.pdf",
    # The double underscore is kept so file names match results of earlier runs.
    "poisoned_compare": "./results/ex5/MNIST/{tag}/{rep}_round_{rnd}__MNIST_poisoned_client_compare_lf_{tag}_ex5.pdf",
}

for repetition in range(5):
    for round_no in snapshot_rounds:
        model_file = "./temp/models/ex5/MNIST_round_{}.model".format(round_no)
        paths = {
            role: template.format(tag=flip_tag, rep=repetition, rnd=round_no)
            for role, template in result_templates.items()
        }
        # Read the snapshot once; load_state_dict copies the values into the
        # network, so the same dict can restore the server twice below.
        snapshot_state = torch.load(model_file)

        # Start from the saved server model and a reset client plane.
        server.net = MNISTCNN()
        server.net.load_state_dict(snapshot_state)
        client_plane.reset_default_client_nets()
        client_plane.reset_poisoning_attack()

        # Baseline: test metrics and SHAP values of the restored server model.
        server.test()
        recall, precision, accuracy = server.analize_test()
        print("Original", recall, precision, accuracy)
        server_shap = server.get_shap_values(shap_images)
        visualizer.plot_shap_values(server_shap, paths["server"], indices=shap_images)

        # Configure one poisoned client, then pick one clean and one poisoned client.
        config.POISONED_CLIENTS = 1
        experiment_util.update_configs(client_plane, server, config, observer_config)
        client_plane.poison_clients()
        clean_clients = experiment_util.select_random_clean(client_plane, config, 1)
        poisoned_clients = experiment_util.select_poisoned(client_plane, 1)

        # Clean client: train from the server parameters and compare its SHAP
        # values with the server's.
        client_plane.update_clients(server.get_nn_parameters())
        print("Client Clean {}".format(round_no + 1))
        clean_client = client_plane.clients[clean_clients[0]]
        clean_client.train(round_no + 1)
        clean_client_shap = clean_client.get_shap_values(shap_images)
        visualizer.plot_shap_values(clean_client_shap, paths["clean"], indices=shap_images)
        visualizer.compare_shap_values(clean_client_shap, server_shap, paths["clean_compare"], indices=shap_images)

        # Poisoned client: restore the server snapshot again before handing
        # its parameters to the clients.
        server.net = MNISTCNN()
        server.net.load_state_dict(snapshot_state)
        client_plane.update_clients(server.get_nn_parameters())
        print("Client Poisoned {}".format(round_no + 1))
        poisoned_client = client_plane.clients[poisoned_clients[0]]
        poisoned_client.train(round_no + 1)
        poisoned_client_shap = poisoned_client.get_shap_values(shap_images)
        visualizer.plot_shap_values(poisoned_client_shap, paths["poisoned"], indices=shap_images)
        visualizer.compare_shap_values(poisoned_client_shap, server_shap, paths["poisoned_compare"], indices=shap_images)

        # Leave the client plane reset when the run is done.
        client_plane.reset_default_client_nets()
        client_plane.reset_poisoning_attack()
        print("Round {} finished".format(repetition + 1))

## FashionMNIST
For Fashion-MNIST we experiment with the label flips
(1) 5: Sandal → 4: Coat,
(2) 1: Trouser → 3: Dress, and
(3) 8: Bag → 9: Ankle Boot.

Class names by label index: ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot']


In [None]:
import os

from federated_learning.dataset import FMNISTDataset
from federated_learning.nets import FMNISTCNN

# Configuration for the Fashion-MNIST experiments: FMNIST dataset and CNN.
# POISONED_CLIENTS starts at 0 here; the experiment cell sets it to 1 later.
config = Configuration()
config.POISONED_CLIENTS, config.DATA_POISONING_PERCENTAGE = 0, 1
config.DATASET, config.NETWORK = FMNISTDataset, FMNISTCNN
config.MODELNAME = config.FMNIST_NAME

# Observer settings for this experiment run.
observer_config = ObserverConfiguration()
observer_config.experiment_type, observer_config.experiment_id = "shap_fl_poisoned", 1
observer_config.test = False
# NOTE(review): still tagged "MNIST" although this section uses Fashion-MNIST;
# confirm whether the observer expects "FMNIST" here.
observer_config.datasetObserverConfiguration = "MNIST"

# Not referenced by any other cell visible in this notebook.
neutral_label = 2

In [None]:
# Google Colab settings: temp folder on the mounted Drive, dataset folders
# under /content/data, and no VM URL.
config.VM_URL = "none"
config.TEMP = '/content/drive/My Drive/Colab Notebooks/temp'
config.MNIST_DATASET_PATH = '/content/data/mnist'
config.FMNIST_DATASET_PATH = '/content/data/fmnist'
config.CIFAR10_DATASET_PATH = '/content/data/cifar10'

In [None]:
# Instantiate the dataset class selected in the configuration.
data = config.DATASET(config)
# One SHAP helper, built from the test dataloader, is shared by the server,
# the client plane and the visualizer created below.
shap_util = SHAPUtil(data.test_dataloader) 
server = Server(config, observer_config,data.train_dataloader, data.test_dataloader, shap_util)
client_plane = ClientPlane(config, observer_config, data, shap_util)
visualizer = Visualizer(shap_util)

In [None]:
import copy
import numpy as np

# Run rounds 1..199 and keep an independent copy of the resulting server
# parameters.
# NOTE(review): the snapshot cell that follows runs rounds 1..200 on the same
# server object; if both cells are executed, the server is trained in both.
for round_no in range(1, 200):
    experiment_util.set_rounds(client_plane, server, round_no)
    experiment_util.run_round(client_plane, server, round_no)
print("Run 199 finished")
old_params = copy.deepcopy(server.get_nn_parameters())

In [None]:
import numpy as np
import copy
import torch
import os

# Rounds for which the server model is written to the mounted Drive; the SHAP
# experiment cell later in the notebook loads these files again.
snapshot_rounds = (2, 5, 10, 75, 100, 200)

for i in range(200):
    if (i + 1) in snapshot_rounds:
        # NOTE: the snapshot is written before run_round(..., i + 1) is called,
        # so the file for round N holds the model as it enters round N.
        # The "MNIST" file name is kept because the experiment cell loads
        # exactly this path.
        file = "/content/drive/My Drive/Colab Notebooks/temp/models/ex5/MNIST_round_{}.model".format(i + 1)
        # exist_ok replaces the separate exists() check and is safe to repeat.
        os.makedirs(os.path.dirname(file), exist_ok=True)
        torch.save(server.net.state_dict(), file)
    experiment_util.set_rounds(client_plane, server, i + 1)
    experiment_util.run_round(client_plane, server, i + 1)

In [None]:
# Label pair for the experiment cell that follows: FROM_LABEL 5, TO_LABEL 4.
config.FROM_LABEL, config.TO_LABEL = 5, 4

In [None]:
import torch

# SHAP values are computed for the two labels of the configured flip.
shap_images = [config.FROM_LABEL, config.TO_LABEL]
# "<from>_<to>" tag used in the result folder and file names.
flip_tag = "{}_{}".format(config.FROM_LABEL, config.TO_LABEL)
# Rounds for which a server snapshot was saved by the training cell.
snapshot_rounds = [2, 5, 10, 75, 100, 200]
# NOTE(review): folder and file names still say "MNIST" although this section
# runs Fashion-MNIST; they are kept so existing result files stay addressable.
result_templates = {
    "server": "/content/drive/My Drive/Colab Notebooks/results/ex5/MNIST/{tag}/{rep}_round_{rnd}_MNIST_server_lf_{tag}_ex5.pdf",
    "clean": "/content/drive/My Drive/Colab Notebooks/results/ex5/MNIST/{tag}/{rep}_round_{rnd}_MNIST_clean_client_lf_{tag}_ex5.pdf",
    "clean_compare": "/content/drive/My Drive/Colab Notebooks/results/ex5/MNIST/{tag}/{rep}_round_{rnd}_MNIST_clean_client_compare_lf_{tag}_ex5.pdf",
    "poisoned": "/content/drive/My Drive/Colab Notebooks/results/ex5/MNIST/{tag}/{rep}_round_{rnd}_MNIST_poisoned_client_lf_{tag}_ex5.pdf",
    # The double underscore is kept so file names match results of earlier runs.
    "poisoned_compare": "/content/drive/My Drive/Colab Notebooks/results/ex5/MNIST/{tag}/{rep}_round_{rnd}__MNIST_poisoned_client_compare_lf_{tag}_ex5.pdf",
}

for repetition in range(5):
    for round_no in snapshot_rounds:
        model_file = "/content/drive/My Drive/Colab Notebooks/temp/models/ex5/MNIST_round_{}.model".format(round_no)
        paths = {
            role: template.format(tag=flip_tag, rep=repetition, rnd=round_no)
            for role, template in result_templates.items()
        }
        # Read the snapshot once; load_state_dict copies the values into the
        # network, so the same dict can restore the server twice below.
        snapshot_state = torch.load(model_file)

        # Rebuild the server network from the class selected in the
        # configuration (FMNISTCNN for this section) instead of MNISTCNN,
        # then start from a reset client plane.
        server.net = config.NETWORK()
        server.net.load_state_dict(snapshot_state)
        client_plane.reset_default_client_nets()
        client_plane.reset_poisoning_attack()

        # Baseline: test metrics and SHAP values of the restored server model.
        server.test()
        recall, precision, accuracy = server.analize_test()
        print("Original", recall, precision, accuracy)
        server_shap = server.get_shap_values(shap_images)
        visualizer.plot_shap_values(server_shap, paths["server"], indices=shap_images)

        # Configure one poisoned client, then pick one clean and one poisoned client.
        config.POISONED_CLIENTS = 1
        experiment_util.update_configs(client_plane, server, config, observer_config)
        client_plane.poison_clients()
        clean_clients = experiment_util.select_random_clean(client_plane, config, 1)
        poisoned_clients = experiment_util.select_poisoned(client_plane, 1)

        # Clean client: train from the server parameters and compare its SHAP
        # values with the server's.
        client_plane.update_clients(server.get_nn_parameters())
        print("Client Clean {}".format(round_no + 1))
        clean_client = client_plane.clients[clean_clients[0]]
        clean_client.train(round_no + 1)
        clean_client_shap = clean_client.get_shap_values(shap_images)
        visualizer.plot_shap_values(clean_client_shap, paths["clean"], indices=shap_images)
        visualizer.compare_shap_values(clean_client_shap, server_shap, paths["clean_compare"], indices=shap_images)

        # Poisoned client: restore the server snapshot again before handing
        # its parameters to the clients.
        server.net = config.NETWORK()
        server.net.load_state_dict(snapshot_state)
        client_plane.update_clients(server.get_nn_parameters())
        print("Client Poisoned {}".format(round_no + 1))
        poisoned_client = client_plane.clients[poisoned_clients[0]]
        poisoned_client.train(round_no + 1)
        poisoned_client_shap = poisoned_client.get_shap_values(shap_images)
        visualizer.plot_shap_values(poisoned_client_shap, paths["poisoned"], indices=shap_images)
        visualizer.compare_shap_values(poisoned_client_shap, server_shap, paths["poisoned_compare"], indices=shap_images)

        # Leave the client plane reset when the run is done.
        client_plane.reset_default_client_nets()
        client_plane.reset_poisoning_attack()
        print("Round {} finished".format(repetition + 1))

## alpha(3,8)