In [None]:
# standard library
# used to deep copy the global models to local models
import copy
# used to create the results directory before writing to it
import os
# used in printing to a file
import sys
from collections import Counter

# numerics / plotting
import matplotlib.pyplot as plt
import numpy as np

# GMM
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import train_test_split

# tracks progress of each loops
from tqdm import tqdm

# pytorch
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, random_split, Subset, ConcatDataset
from torchvision import datasets, transforms

In [None]:
# Master seed used by the seeded NumPy generators in this and later cells.
seed = 1368297158
rng = np.random.default_rng(seed)
# One pre-drawn seed per federated round (main() indexes seeds[epoch]), so the
# malicious clients chosen in each round are reproducible.
seeds = rng.integers(0, 2**32 - 1, size=1000)

In [None]:
def sample(train_dataset, test_dataset, num_clients, alpha):
    """Split a train/test dataset pair into non-IID client shards.

    For each of the 10 classes a proportion vector over ``num_clients`` is drawn
    from Dirichlet(``alpha``) using the global ``seed``. The same proportions are
    used to split the train and the test dataset, so every client's test split
    follows its train label distribution.

    Expects torchvision-style datasets exposing ``.data`` and a tensor ``.targets``
    (as MNIST / FashionMNIST do) -- TODO confirm for any other dataset.

    Returns:
        list of ``(train_subset, test_subset)`` tuples, one per client.
    """
    num_classes = 10

    def split_by_class(dataset, proportions):
        # Bucket sample positions by label, in ascending position order.
        labels = dataset.targets.tolist()
        buckets = [[] for _ in range(num_classes)]
        for position in range(len(dataset.data)):
            buckets[labels[position]].append(position)

        per_client = [[] for _ in range(num_clients)]
        for cls, bucket in enumerate(buckets):
            # A fresh generator with the same seed is used for every class.
            shuffler = np.random.default_rng(seed)
            shuffler.shuffle(bucket)
            # num_clients - 1 cut points -> exactly num_clients chunks.
            cut_points = (np.cumsum(proportions[cls]) * len(bucket)).astype(int)[:-1]
            for client_id, chunk in enumerate(np.split(np.array(bucket), cut_points)):
                per_client[client_id].extend(chunk)

        return [Subset(dataset, client_positions) for client_positions in per_client]

    # proportions[c][k] = share of class c assigned to client k
    proportion_rng = np.random.default_rng(seed)
    proportions = proportion_rng.dirichlet(np.repeat(alpha, num_clients), size=num_classes)

    train_subsets = split_by_class(train_dataset, proportions)
    test_subsets = split_by_class(test_dataset, proportions)

    return list(zip(train_subsets, test_subsets))

In [None]:
def getDataset(dataset):
    """Load the normalized train and test splits of a torchvision dataset.

    Args:
        dataset: ``"MNIST"`` or ``"FashionMNIST"``.

    Returns:
        ``(train_dataset, test_dataset)``. Both splits are downloaded into ``./data``
        if missing and normalized with the dataset-specific mean/std.

    Raises:
        ValueError: for an unsupported dataset name (previously an UnboundLocalError).
    """
    if dataset == "MNIST":
        dataset_cls, mean, std = datasets.MNIST, 0.1307, 0.3081
    elif dataset == "FashionMNIST":
        dataset_cls, mean, std = datasets.FashionMNIST, 0.2860, 0.3527
    else:
        raise ValueError("unknown dataset: {!r}".format(dataset))

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((mean,), (std,))
    ])
    train_dataset = dataset_cls(root='./data', train=True, download=True, transform=transform)
    # train=False: the held-out split. MNIST previously reloaded the training split here,
    # so every "test" metric was computed on training data.
    test_dataset = dataset_cls(root='./data', train=False, download=True, transform=transform)

    return train_dataset, test_dataset

In [None]:
# convolutional neural network
class MNIST(nn.Module):
    """Small two-conv CNN that returns per-class log-probabilities (10 classes).

    Sized for 1x28x28 inputs: after two 5x5 convs and two 2x2 pools the feature map
    is 20x4x4 = 320 values, matching ``fc1``.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are part of the state-dict keys; keep them stable.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        features = F.relu(F.max_pool2d(self.conv1(x), 2))
        features = self.conv2_drop(self.conv2(features))
        features = F.relu(F.max_pool2d(features, 2))
        flat = torch.flatten(features, 1)
        hidden = F.dropout(F.relu(self.fc1(flat)), training=self.training)
        return F.log_softmax(self.fc2(hidden), dim=1)
    

In [None]:
class FashionMNIST(nn.Module):
    """Two conv/BatchNorm/ReLU/pool stages followed by a linear classifier.

    Returns raw logits (10 classes). ``fc`` expects 32x7x7 features, i.e. 1x28x28 inputs.
    """

    @staticmethod
    def _stage(in_channels, out_channels):
        """One 5x5 conv (padding 2 keeps the spatial size) -> BatchNorm -> ReLU -> 2x2 max-pool."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=5, padding=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

    def __init__(self):
        super().__init__()
        # Attribute names / Sequential indices are part of the state-dict keys.
        self.layer1 = self._stage(1, 16)
        self.layer2 = self._stage(16, 32)
        self.fc = nn.Linear(7 * 7 * 32, 10)

    def forward(self, x):
        features = self.layer2(self.layer1(x))
        return self.fc(torch.flatten(features, 1))

In [None]:
def Model(dataset):
    """Build a fresh, randomly initialised network for ``dataset``.

    Returns an ``MNIST`` or ``FashionMNIST`` instance, or ``None`` for any other name.
    """
    architectures = {"MNIST": MNIST, "FashionMNIST": FashionMNIST}
    builder = architectures.get(dataset)
    return None if builder is None else builder()

In [None]:
def trainModel(model, dataset, malicious, attack, epochs = 10, device = "cuda"):
    """Train ``model`` on one client's data and return its ``state_dict()``.

    Honest clients run plain SGD. A malicious client applies ``attack`` instead:

    * ``"random weights"``: no training; every parameter is overwritten with values
      drawn uniformly between the smallest and largest current parameter value.
    * ``"random label flipping"``: each label is shifted by a random offset in
      [1, 9] (mod 10).
    * ``"cyclic label flipping"``: each label ``y`` becomes ``(y + 5) % 10``.
    * ``"targeted label flipping"``: label 1 becomes 7. The historical misspelling
      ``"taregeted label flipping"`` is also accepted.
    * ``"out of distribution"``: images are replaced with standard-normal noise.

    Args:
        model: the network to train (callers pass a deep copy of the global model).
        dataset: the client's training dataset of ``(image, label)`` pairs.
        malicious: whether this client applies ``attack``.
        attack: attack name (see above); ignored when ``malicious`` is False.
        epochs: number of local passes over ``dataset``.
        device: torch device to train on.

    Returns:
        ``model.state_dict()`` after local training (or after randomisation).
    """
    targeted_names = ("targeted label flipping", "taregeted label flipping")

    def randomWeights(model):
        # Global min/max over all parameters, then resample every parameter uniformly.
        min_weight = float('inf')
        max_weight = float('-inf')
        for param in model.parameters():
            min_weight = min(min_weight, param.data.min().item())
            max_weight = max(max_weight, param.data.max().item())
        for param in model.parameters():
            param.data.uniform_(min_weight, max_weight)

    if malicious and attack == "random weights":
        randomWeights(model)
        return model.state_dict()

    data_loader = DataLoader(dataset, batch_size = 8, shuffle = True)
    criterion = nn.CrossEntropyLoss().to(device)
    model.to(device)
    model.train()

    optimizer = torch.optim.SGD(model.parameters(), lr = 0.0001, momentum=0.5)
    # NOTE(review): scheduler.step() runs once per batch below, so the learning rate
    # decays by 0.995 per batch rather than per epoch -- confirm this is intended.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.995)

    random_flip = malicious and attack == "random label flipping"
    for _ in range(epochs):
        # Re-seeded every epoch, so each epoch draws the same offset sequence
        # (batch contents still differ because the loader reshuffles).
        rng = np.random.default_rng(seed) if random_flip else None
        for images, labels in data_loader:
            # poisoning (the loader yields fresh label tensors, so in-place edits are local)
            if malicious:
                if random_flip:
                    for i in range(len(labels)):
                        labels[i] = (labels[i] + rng.integers(1, 10)) % 10
                elif attack == "cyclic label flipping":
                    labels = (labels + 5) % 10
                elif attack in targeted_names:
                    labels[labels == 1] = 7
                elif attack == "out of distribution":
                    images = torch.randn_like(images)

            images, labels = images.to(device), labels.to(device)

            model.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()

            optimizer.step()
            scheduler.step()

    return model.state_dict()

In [None]:
def testModel(model, dataset, malicious = False, attack = "random label flipping", device = "cuda"):
    criterion = nn.NLLLoss().to(device)
    
    data_loader = DataLoader(dataset, batch_size = 8, shuffle=False)

    model.to(device)
    model.eval()
    loss, total, correct = 0.0, 0.0, 0.0
    
    for batch_idx, (images, labels) in enumerate(data_loader):
        
        if(malicious):
            if(attack == "cyclic label flipping"):
                for i in range(len(labels)):
                    labels[i] = (labels[i] + 5)% 10
            elif(attack == "taregeted label flipping"):
                    for i in range(len(labels)):
                        if (labels[i] == 1):
                            labels[i] = 7
                        
        images, labels = images.to(device), labels.to(device)


        outputs = model(images)
        batch_loss = F.cross_entropy(outputs, labels, reduction='sum')
        loss += batch_loss.item()


        _, pred_labels = torch.max(outputs, 1)
        pred_labels = pred_labels.view(-1)
        correct += torch.sum(torch.eq(pred_labels, labels)).item()
        total += len(labels)

    accuracy = correct/total
    return accuracy, loss

In [None]:
def aggregate(weights, bias = 1):
    """Average a list of state dicts, counting the first one ``bias`` times.

    The result is ``(bias * weights[0] + weights[1] + ... + weights[-1]) /
    (len(weights) + bias - 1)``, computed key by key. With ``bias=1`` this is the
    plain FedAvg mean. The inputs are not modified (the first dict is deep-copied).
    """
    averaged = copy.deepcopy(weights[0])
    divisor = len(weights) + bias - 1
    for name in averaged:
        # The deep copy already counts weights[0] once; add the remaining bias - 1 copies.
        for _ in range(bias - 1):
            averaged[name] += weights[0][name]
        for other in weights[1:]:
            averaged[name] += other[name]
        averaged[name] = torch.div(averaged[name], divisor)
    return averaged

In [None]:
def SpyShield(
    dataset,
    clients_count, 
    clients_per_model,
    clients_local_weights,
    bias,
    clients_datasets, 
    malicious_clients,  
    attack, 
    criterion = "accuracy"):
    """Select honest clients by cross-testing small aggregated sub-models.

    Every client i is the "testee" of ``models_count`` groups. Each group is client i
    plus ``clients_per_model - 1`` other clients drawn without replacement, and no
    partner repeats across i's groups. Each group's weights are averaged with
    ``aggregate`` (client i counted ``bias`` times). The result is evaluated on the
    test split of a "tester" client that is not a member of that group.

    Testers are ranked by the mean score they reported, and roughly the better half
    (minus two) is kept as reliable. Testees are then ranked using only scores from
    reliable testers, and the same rule picks the honest set.

    Args:
        dataset: dataset name passed to ``Model`` to build the evaluation network.
        clients_count: number of participating clients.
        clients_per_model: group size, including the testee itself.
        clients_local_weights: list of state dicts, indexed by client id.
        bias: multiplicity of the testee's own weights in each group average.
        clients_datasets: per-client ``(train_subset, test_subset)``; only the test
            subset (index 1) is used here.
        malicious_clients: ids of poisoned clients. A malicious tester evaluates with
            poisoned labels via ``testModel``. The ids also feed the diagnostic plot.
        attack: attack name forwarded to ``testModel``.
        criterion: ``"accuracy"`` (higher is better); any other value selects the
            summed loss (lower is better).

    Returns:
        ``(honest, reliable)``: ids of testees selected as honest, and ids of testers
        judged reliable.
    """
    
    # Groups per testee: each needs clients_per_model - 1 distinct partners drawn from
    # the other clients_count - 1 clients, without reuse across the testee's groups.
    models_count = (clients_count - 1) // (clients_per_model - 1)
    
    def group(clients_count, clients_per_model):
        """Build each client's groups and assign a tester to every group.

        Returns an object array of shape (clients_count, models_count) where
        ``groupings[i][j] == (members, tester)``. ``members`` is an int array whose
        first element is i, and ``tester`` is a client id not in ``members``.
        """
        # grouping
        models_count = (clients_count - 1) // (clients_per_model - 1)
        groups  = np.empty([clients_count, models_count], dtype = object)
        groupings  = np.empty([clients_count, models_count], dtype = object)
        # NOTE(review): unseeded generator -- groupings (and therefore the selected
        # clients) are not reproducible across runs, unlike the seeded generators elsewhere.
        rng = np.random.default_rng()
        for i in range(clients_count):
            remaining = np.delete(np.arange(clients_count), i)
            for j in range(models_count):
                    # client i followed by partners not yet used in i's earlier groups
                    group = np.concatenate(([i],rng.choice(remaining, size= clients_per_model - 1, replace=False)))
                    groups[i][j] = group
                    remaining = np.delete(remaining, np.where(np.isin(remaining, group[1:])))

        # mapping: give every group a tester from outside the group, picking the
        # eligible client with the fewest assignments so far (ties broken randomly by
        # the shuffle). The whole assignment is retried until every client tests
        # exactly models_count groups.
        # NOTE(review): nothing bounds the number of retries.
        while(True):
            frequency = np.zeros(clients_count, dtype=int)
            for i in range(clients_count):
                for j in range(models_count):
                    # (client id, assignments so far); the comprehension's `i` is local to it
                    pairs = [(i, frequency[i]) for i in range(len(frequency))]
                    filtered = []
                    for pair in pairs:
                        if pair[0] not in groups[i][j]:
                            filtered.append(pair)
                    rng.shuffle(filtered)
                    assignee = min(filtered, key=lambda x: x[1])[0]
                    groupings[i][j] = (groups[i][j], assignee)
                    # update the frequency
                    frequency[assignee] += 1;
            if(np.all(frequency == models_count)):
                break
        return groupings

    def test(
        dataset,
        clients_count, 
        models_count, 
        groupings, 
        clients_local_weights,
        bias,
        clients_datasets, 
        malicious_clients,
        attack, 
        criterion):
        """Score every group's aggregated model on its tester's test split.

        Returns a float array of shape (clients_count, models_count). It holds the
        accuracy when ``criterion == "accuracy"``, otherwise the summed loss, both as
        returned by ``testModel``.
        """
        # initialize the results array
        results = np.empty([clients_count, models_count])
        for i in tqdm(range(clients_count)):
            for j in range(models_count):
                # aggregate weights of the clients in the jth group of the ith client
                # (client i, the first member, is counted `bias` times)
                aggregated_weights = aggregate(weights = [clients_local_weights[index] for index in groupings[i][j][0] ], bias = bias)
                # initialize the model using the aggregated weights
                model = Model(dataset)
                model.load_state_dict(aggregated_weights)
                # test the model at the respective tester client; a malicious tester
                # evaluates with its labels poisoned by testModel.
                # NOTE(review): device is not passed, so testModel's default ("cuda") is used.
                results[i][j] = testModel(
                    model,
                    clients_datasets[groupings[i][j][1]][1],
                    groupings[i][j][1] in malicious_clients,
                    attack)[0 if criterion == "accuracy" else 1]  
        return results;
    
    def evaluateTesters(clients_count, models_count, groupings, results, criterion):
        """Return the ids of testers considered reliable.

        Testers are ranked by the mean of the scores they reported. With accuracy the
        top ``clients_count - (int(clients_count/2) + 2)`` are kept; with loss the
        lowest ``int(clients_count/2) - 2`` are kept.
        """
        # prepare the results to test for reliability
        # initialize results_per_tester
        results_per_tester = [[] for i in range(clients_count)]
        # populate results_per_tester (each tester has exactly models_count entries,
        # which group() guarantees)
        for i in range(clients_count):
            for j in range(models_count):
                results_per_tester[groupings[i][j][1]].append(results[i][j])
        # initialize mean_result_per_tester
        mean_result_per_tester = [np.mean(results) for results in results_per_tester]
        # index mean_result_per_tester
        indexed_mean_results_per_tester = [[index, mean_result_per_tester[index]] for index in range(clients_count)]
        # sort indexed_mean_results_per_tester (ascending score)
        indexed_mean_results_per_tester = sorted(indexed_mean_results_per_tester, key = lambda x: x[1])
        # testers in the second half of mean_result are considered reliable if the criterion is accuracy.
        # testers in the first half of mean_result are considered reliable if the criterion is loss.
        reliable = [index for index,_ in (indexed_mean_results_per_tester[int(clients_count/2) + 2:] if criterion == "accuracy" else indexed_mean_results_per_tester[: int(clients_count/2) - 2])]        
        
        return reliable

    def evaluateTestees(clients_count, models_count, groupings, reliable, results, criterion):
        """Return the ids of testees selected as honest, and plot their scores.

        Each testee's score is the mean of the results reported by reliable testers
        only. The same ranking/slicing rule as ``evaluateTesters`` picks the honest set.
        """
        # initialize results_per_testee
        results_per_testee = [[] for i in range(clients_count)]
        for i in range(clients_count):
            for j in range(models_count):
                if(groupings[i][j][1] in reliable):
                    results_per_testee[i].append(results[i][j])
        # initialize mean_result_per_testee
        # NOTE(review): a testee with no reliable tester gets 0. For criterion == "loss"
        # that ranks it as the best (lowest-loss) client, so it would be selected as honest.
        mean_result_per_testee = [(np.mean(results) if len(results) > 0 else 0 ) for results in results_per_testee]
        # index mean_result_per_testee
        indexed_mean_results_per_testee = [[index, mean_result_per_testee[index]] for index in range(clients_count)]
        # sort indexed_mean_results_per_testee (ascending score)
        indexed_mean_results_per_testee = sorted(indexed_mean_results_per_testee, key = lambda x: x[1])
        # testees in the second half of mean_result are considered honest if the criterion is accuracy.
        # testees in the first half of mean_result are considered honest if the criterion is loss.
        honest = [index for index,_ in (indexed_mean_results_per_testee[int(clients_count/2) + 2: ] if criterion == "accuracy" else indexed_mean_results_per_testee[: int(clients_count/2) - 2])]


        # analysis (`malicious_clients` is read from the enclosing SpyShield call)
        mean_result_per_poisoned_testee = [mean_result_per_testee[i] for i in range(clients_count) if i in malicious_clients]
        
        # NOTE(review): the first series (labelled 'honest' by plot) contains every
        # testee, poisoned ones included.
        plot(mean_result_per_testee, mean_result_per_poisoned_testee, criterion)

        return honest

    
    groupings = group(clients_count, clients_per_model)
    
    results = test(
        dataset,
        clients_count, 
        models_count, 
        groupings, 
        clients_local_weights,
        bias,
        clients_datasets, 
        malicious_clients, 
        attack, 
        criterion)

    reliable = evaluateTesters(clients_count, models_count, groupings, results, criterion)
    
    indices = evaluateTestees(clients_count, models_count, groupings, reliable, results, criterion)


    return indices, reliable

In [None]:
def _flat_parameter_keys(state_dict, dataset=None):
    """Return the state-dict keys of trainable parameters, in ``model.parameters()`` order.

    The architecture is ``Model(dataset)``. When ``dataset`` is None it is inferred by
    matching ``state_dict``'s keys against the known architectures.
    """
    candidates = [dataset] if dataset is not None else ["FashionMNIST", "MNIST"]
    for name in candidates:
        template = Model(name)
        if template is not None and set(template.state_dict().keys()) == set(state_dict.keys()):
            return [key for key, _ in template.named_parameters()]
    raise ValueError("could not match the state dict to a known architecture")


def MultiKrum(weights, d=1, dataset=None, param_keys=None):
    """Return the indices of the ``d`` client updates with the lowest Krum-style score.

    Each update is flattened over its trainable parameters (BatchNorm buffers are
    excluded, as before). Its score is the sum of the ``n - int(n/2)`` smallest
    Euclidean (not squared) distances to all updates, including its own zero
    self-distance.

    Args:
        weights: list of state dicts, one per client.
        d: number of indices to return.
        dataset: architecture name for ``Model``. If None, it is inferred from the
            state-dict keys. (Previously "FashionMNIST" was hard-coded, which failed for
            MNIST runs.)
        param_keys: explicit list of keys to flatten; overrides ``dataset``.

    Returns:
        list of up to ``d`` client indices, best score first (stable on ties).
    """
    n = len(weights)
    if param_keys is None and n > 0:
        param_keys = _flat_parameter_keys(weights[0], dataset)

    # Flatten every update once (instead of building two models per pair).
    flats = [
        torch.cat([state[key].detach().flatten().cpu().float() for key in param_keys])
        for state in weights
    ]

    # Pairwise Euclidean distances; the matrix is symmetric with a zero diagonal.
    distances = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            dist = torch.linalg.vector_norm(flats[i] - flats[j]).item()
            distances[i][j] = dist
            distances[j][i] = dist

    neighbours = n - int(n / 2)
    scores = [(i, np.sum(np.sort(distances[i])[:neighbours])) for i in range(n)]
    scores.sort(key=lambda item: item[1])
    return [index for index, _ in scores[:d]]

In [None]:
def Krum(weights):
    """Return the single client index that ``MultiKrum`` ranks best."""
    selected = MultiKrum(weights)
    return selected[0]

In [None]:
def Median(weights):
    """Coordinate-wise median of a list of state dicts.

    Every tensor entry is replaced by the median of that entry across clients, with
    the original shape kept. For an even number of clients ``torch.median`` returns the
    lower of the two middle values, not their average.

    (The previous version also built an unused ``Model("FashionMNIST")`` on every call.)
    """
    n = len(weights)
    median_weights = {}
    for key, reference in weights[0].items():
        stacked = torch.stack([weights[i][key].flatten() for i in range(n)])
        median_weights[key] = torch.median(stacked, dim=0).values.view_as(reference)
    return median_weights

In [None]:
def TrimmedMean(weights, trim_ratio=0.25):
    """Coordinate-wise trimmed mean of a list of state dicts.

    For every tensor entry the ``int(trim_ratio * n)`` smallest and largest values
    across the ``n`` clients are dropped, and the rest are averaged (as float).

    The previous slice ``[trim_n + 1 : n - trim_n - 1]`` dropped one extra value on
    each side. It trimmed even with ``trim_ratio=0`` and returned NaN for n=4.

    Raises:
        ValueError: if ``trim_ratio`` is negative or leaves no values to average.
    """
    n = len(weights)
    trim_n = int(trim_ratio * n)
    if trim_n < 0 or n - 2 * trim_n <= 0:
        raise ValueError(
            "trim_ratio={} leaves no values to average for {} clients".format(trim_ratio, n))

    trimmed_mean_weights = {}
    for key, reference in weights[0].items():
        stacked = torch.stack([state[key].flatten() for state in weights])
        # Sort each coordinate across clients, then drop trim_n from both ends.
        sorted_params, _ = torch.sort(stacked, dim=0)
        kept = sorted_params[trim_n : n - trim_n].float()
        trimmed_mean_weights[key] = kept.mean(dim=0).view_as(reference)

    return trimmed_mean_weights

In [None]:
def plot(data1, data2, criterion):
    """Overlay two histograms of per-client scores and show the figure.

    Args:
        data1: first series, drawn in blue and labelled 'honest'. Note that the
            visible caller in SpyShield passes the scores of all testees here.
        data2: second series, drawn in red and labelled 'poisoned'.
        criterion: ``'loss'`` lets matplotlib choose each series' range. Anything
            else is treated as accuracy, with bins fixed over [0, 1].
    """
    fig, ax = plt.subplots()

    # Accuracy lies in [0, 1]. The previous upper bound of 0.9 silently dropped every
    # value above 0.9 from both histograms.
    hist_range = None if criterion == 'loss' else (0, 1)
    ax.hist(data1, label='honest', range=hist_range, bins=20, color='b', hatch='O', align='right')
    ax.hist(data2, alpha=0.8, label='poisoned', range=hist_range, bins=20, color='r', hatch='.', align='right')

    ax.legend()
    ax.set_ylabel('clients count')
    ax.set_xlabel(criterion)

    plt.show()

In [None]:
def main(dataset, clients_count, malicious_clients_count,  attack, algorithm, algorithm_parameters, device = "cuda", rounds = 250):
    """Run a federated-learning simulation and return the final model and per-round metrics.

    The dataset is split into 150 non-IID clients (Dirichlet alpha 0.5), and
    ``clients_count`` of them are chosen with the global ``seed``. In each round:

    1. ``malicious_clients_count`` clients are drawn as malicious using ``seeds[round]``.
    2. Every client trains a copy of the global model.
    3. The updates are combined with ``algorithm``.
    4. The new global model is evaluated on every client's test split.

    Args:
        dataset: ``"MNIST"`` or ``"FashionMNIST"``.
        clients_count: number of participating clients.
        malicious_clients_count: number of malicious clients per round.
        attack: attack name understood by ``trainModel`` / ``testModel``.
        algorithm: one of ``"FedAvg"``, ``"SpyShield"``, ``"Krum"``, ``"MultiKrum"``,
            ``"Median"``, ``"TrimmedMean"``.
        algorithm_parameters: dict of algorithm settings. SpyShield reads
            ``clients_per_model``, ``bias`` and ``criterion``; MultiKrum reads ``d``.
        device: torch device used for local training and global evaluation.
        rounds: number of federated rounds (default 250); at most ``len(seeds)``.

    Returns:
        ``(global_model, metrics)`` where ``metrics[r] == [mean_accuracy, mean_loss, stats]``.
        ``mean_loss`` is the mean over clients of each client's summed test loss.

    Raises:
        ValueError: for an unknown ``algorithm`` or ``rounds > len(seeds)``.
    """
    known_algorithms = ("FedAvg", "SpyShield", "Krum", "MultiKrum", "Median", "TrimmedMean")
    if algorithm not in known_algorithms:
        raise ValueError("unknown algorithm: {!r}".format(algorithm))
    if rounds > len(seeds):
        raise ValueError("rounds={} exceeds the {} pre-generated round seeds".format(rounds, len(seeds)))
    # Accept the legacy misspelling used by older trainModel/testModel versions too.
    label_flipping_attacks = ("cyclic label flipping", "targeted label flipping", "taregeted label flipping")

    train_dataset, test_dataset = getDataset(dataset)

    # Fixed pool of 150 non-IID clients; the participants are a seeded shuffle of it.
    clients_datasets = sample(train_dataset, test_dataset, 150, 0.5)
    rng = np.random.default_rng(seed)
    rng.shuffle(clients_datasets)
    clients_datasets = clients_datasets[0: clients_count]

    global_model = Model(dataset)
    global_model.to(device)

    metrics = []
    for epoch in range(rounds):
        print("epoch: ", epoch + 1)
        # Reproducible choice of this round's malicious clients.
        rng = np.random.default_rng(seeds[epoch])
        malicious_clients = rng.choice(np.arange(clients_count), size = malicious_clients_count, replace = False)

        clients_local_weights = []
        for i in tqdm(range(clients_count)):
            clients_local_weights.append(trainModel(
                copy.deepcopy(global_model),
                clients_datasets[i][0],
                i in malicious_clients,
                attack,
                device = device))

        stats = {}
        if algorithm == "FedAvg":
            global_model.load_state_dict(aggregate(clients_local_weights))

        elif algorithm == "SpyShield":
            indices, reliable = SpyShield(
                dataset,
                clients_count,
                algorithm_parameters["clients_per_model"],
                clients_local_weights,
                algorithm_parameters["bias"],
                clients_datasets,
                malicious_clients,
                attack,
                algorithm_parameters["criterion"])
            global_model.load_state_dict(aggregate([clients_local_weights[i] for i in indices]))

            # Malicious testers that were correctly excluded from the reliable set.
            if attack in label_flipping_attacks:
                stats["true_positives_unreliable"] = sum(
                    1 for i in range(clients_count) if i not in reliable and i in malicious_clients)
            # Malicious clients that were correctly excluded from aggregation.
            stats["true_positives_poisoned"] = sum(
                1 for i in range(clients_count) if i not in indices and i in malicious_clients)

        elif algorithm == "Krum":
            index = Krum(clients_local_weights)
            global_model.load_state_dict(clients_local_weights[index])
            stats["is_selected_poisoned"] = index in malicious_clients

        elif algorithm == "MultiKrum":
            indices = MultiKrum(clients_local_weights, algorithm_parameters["d"])
            global_model.load_state_dict(aggregate([clients_local_weights[i] for i in indices]))
            # Malicious clients that slipped into the aggregated set.
            stats["false_negative_poisoned"] = sum(1 for index in indices if index in malicious_clients)

        elif algorithm == "Median":
            global_model.load_state_dict(Median(clients_local_weights))

        elif algorithm == "TrimmedMean":
            global_model.load_state_dict(TrimmedMean(clients_local_weights))

        accuracies, losses = [], []
        for i in range(clients_count):
            accuracy, loss = testModel(global_model, clients_datasets[i][1], malicious = False, device = device)
            accuracies.append(accuracy)
            losses.append(loss)

        mean_accuracy = np.mean(accuracies)
        # Mean over all clients (previously only the last client's loss was used).
        mean_loss = np.mean(losses)

        metrics.append([mean_accuracy, mean_loss, stats])
        # Print only this round's metrics instead of the ever-growing list.
        print(metrics[-1])

    return global_model, metrics

In [None]:
# Alternative setting: FedAvg, Krum, Median and TrimmedMean read no algorithm parameters.
# algorithm_parameters = {}

In [None]:
# Alternative setting for MultiKrum: "d" is the number of selected updates that are averaged.
# algorithm_parameters = {"d":5}

In [None]:
# SpyShield settings: group size (testee included), how many times the testee's own
# weights count in its group average, and the ranking criterion ("accuracy" or "loss").
algorithm_parameters = dict(
    clients_per_model=4,
    bias=9,
    criterion="loss",
)

In [None]:
# Experiment configuration; the keys mirror main()'s parameter names.
config = dict(
    dataset="FashionMNIST",
    clients_count=50,
    malicious_clients_count=10,
    attack="cyclic label flipping",
    algorithm="SpyShield",
    algorithm_parameters=algorithm_parameters,
)

In [None]:
# Echo the configuration so the run's settings appear in the cell output.
print(config)

In [None]:
# Root directory for result files.
folder = "results/" 

In [None]:
# Show the results path built so far.
print(folder)

In [None]:
# Rebuilt from config (not `folder +=`) so re-running this cell does not append the
# clients-count segment twice.
folder = "results/" + str(config["clients_count"]) + "/"

In [None]:
# Show the results path built so far.
print(folder)

In [None]:
# Full results directory: results/<clients_count>/<attack>/, or .../no-attack/ when
# there are no malicious clients. Built from config alone so re-running is idempotent.
folder = "results/" + str(config["clients_count"]) + "/" + (
    config["attack"] + "/" if config["malicious_clients_count"] > 0 else "no-attack/"
)

In [None]:
# Show the final results directory.
print(folder)

In [None]:
# One results file per algorithm, e.g. "SpyShield.txt".
file = "".join([config["algorithm"], ".txt"])

In [None]:
# Full path of the results file that the last cell appends to.
print(folder + file)

In [None]:
# Run every federated round of the experiment; expect this cell to be slow.
global_model, metrics = main(
    dataset=config["dataset"],
    clients_count=config["clients_count"],
    malicious_clients_count=config["malicious_clients_count"],
    attack=config["attack"],
    algorithm=config["algorithm"],
    algorithm_parameters=config["algorithm_parameters"],
)

In [None]:
# The results directory is not guaranteed to exist. Create it so a long run is not
# lost to a FileNotFoundError at the very end.
os.makedirs(folder, exist_ok=True)
# Append this run's config and per-round metrics, separated by a blank line.
with open(folder + file, 'a') as results_file:
    results_file.write("config =" + str(config) + "\n")
    results_file.write("metrics =" + str(metrics) + "\n")
    results_file.write("\n")