In [None]:
import os
# NOTE(review): per the PyTorch reproducibility notes, cuBLAS needs this workspace setting to give
# deterministic results when torch.use_deterministic_algorithms(True) is enabled. It presumably has
# to be set before CUDA is initialised, which is why it is assigned ahead of the torch import.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"  # or ":16:8"

import sys
import subprocess

def install(package):
    """Install ``package`` with pip into the interpreter that runs this kernel.

    Invoking ``sys.executable -m pip`` targets the kernel's own environment rather than
    whatever ``pip`` happens to be first on PATH. A non-zero pip exit status raises
    ``subprocess.CalledProcessError``.
    """
    pip_command = [sys.executable, "-m", "pip", "install", package]
    subprocess.check_call(pip_command)

# Function to attempt to import a module, and install it if not present
def try_import(module_name, package_name=None):
    """Import ``module_name``, pip-installing ``package_name`` first if the import fails.

    ``importlib.import_module`` is used so that a dotted name returns the submodule
    itself: ``try_import('matplotlib.pyplot', 'matplotlib')`` yields
    ``matplotlib.pyplot``. The bare ``__import__`` used previously returned the
    top-level ``matplotlib`` package instead.

    Args:
        module_name: Importable module name, possibly dotted (``'matplotlib.pyplot'``).
        package_name: pip distribution to install when it differs from the module
            name (e.g. ``'scikit-learn'`` for ``sklearn``). Defaults to the top-level
            package of ``module_name``.

    Returns:
        The imported module object.

    Raises:
        subprocess.CalledProcessError: if the pip install fails.
        ImportError: if the module still cannot be imported after installing.
    """
    import importlib

    try:
        return importlib.import_module(module_name)
    except ImportError:
        if package_name is None:
            # Installing 'matplotlib.pyplot' makes no sense; install the top-level package.
            package_name = module_name.split('.')[0]
        print(f"Installing {package_name}...")
        install(package_name)
        # The import system caches directory listings; refresh them so the freshly
        # installed package is visible to this already-running interpreter.
        importlib.invalidate_caches()
        return importlib.import_module(module_name)

# Standard library imports (no need to install)
import logging
from datetime import datetime
from copy import deepcopy
import gc
import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Iterable, Union, Optional

# Third-party imports
torch = try_import('torch')
torchvision = try_import('torchvision')
import torchvision.transforms as transforms
import torch.optim as optim
import numpy as np
np = try_import('numpy')
# Import torch.nn as nn
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, Subset, Dataset
import torchvision.models as models
from torch.nn.functional import tanh, softmax

from torchvision.datasets import utils
from PIL import Image
import os.path
import shutil


# sklearn imports
sklearn = try_import('sklearn', 'scikit-learn')
from sklearn.model_selection import train_test_split
from sklearn.cluster import AgglomerativeClustering, KMeans
from sklearn.metrics import silhouette_score
from sklearn.metrics.pairwise import cosine_distances,euclidean_distances
from sklearn.metrics import pairwise_distances

# Other third-party imports
plt = try_import('matplotlib.pyplot', 'matplotlib')
pd = try_import('pandas')

# Install and import CLIP
try:
    import clip
except ImportError:
    print("Installing CLIP...")
    install('git+https://github.com/openai/CLIP.git')
    import clip



# Reproducibility: seed every RNG the notebook draws from. Python's `random` drives the
# non-IID client partitioning, so seeding torch alone leaves the client splits random.
# Uncomment for deterministic CUDA kernels (relies on CUBLAS_WORKSPACE_CONFIG set at the top):
# torch.use_deterministic_algorithms(True, warn_only=True)
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device configuration
# Get the number of available GPUs
num_gpus = torch.cuda.device_count()
print(f"Number of GPUs available: {num_gpus}")

# Use the preferred GPU index when GPUs exist (clamped to the ones actually present);
# otherwise fall back to the CPU.
if num_gpus > 0:
    desired_gpu_index = 3  # preferred GPU; clamped below if the machine has fewer GPUs
    device_index = min(desired_gpu_index, num_gpus - 1)
    device = torch.device(f"cuda:{device_index}")
    torch.cuda.set_device(device)
    print(f"Using GPU: {device}")
else:
    device = torch.device("cpu")
    print("No GPUs available, using CPU.")

# All usable devices, minus any GPU indices listed in `unwanted_device_indices`;
# falls back to the CPU when no GPU remains.
unwanted_device_indices = []
available_device_indices = list(range(num_gpus))
devices = [f'cuda:{i}' for i in available_device_indices if i not in unwanted_device_indices]
if not devices:
    devices = ['cpu']
print(f"Devices: {devices}")




import multiprocessing

num_cores = multiprocessing.cpu_count()  # logical CPU cores on this host (reported only)

# One worker per entry in `devices`; presumably this sizes the thread pool used for
# per-client training. An earlier variant was min(num_cores, 2 * len(devices)).
THREAD_NUMBER = len(devices)

print("\n".join([
    f"Number of CPU cores available: {num_cores}",
    f"THREAD_NUMBER set to: {THREAD_NUMBER}",
]))



# Check GPU / CPU information
def _print_command_output(command, failure_message):
    """Run ``command``, print its UTF-8 decoded stdout, or print ``failure_message`` on any error."""
    try:
        print(subprocess.check_output(command).decode('utf-8'))
    except Exception:
        print(failure_message)


def check_gpu():
    """Print the ``nvidia-smi`` report, or a notice when it cannot be run."""
    _print_command_output(['nvidia-smi'], 'Not connected to a GPU or nvidia-smi not found.')


check_gpu()


def check_cpu():
    """Print the ``lscpu`` report, or a notice when it cannot be run."""
    _print_command_output(['lscpu'], 'Could not retrieve CPU information.')


check_cpu()

Number of GPUs available: 1
Using GPU: cuda:0
Devices: ['cuda:0']
Number of CPU cores available: 12
THREAD_NUMBER set to: 1
Sat Nov  9 19:46:12 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:00:04.0 Off |                    0 |
| N/A   34C    P0              49W / 400W |    717MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+---------------------------------

# Data root directory: mount Google Drive on Colab, otherwise use the local `./data/` folder

In [None]:
# Function to determine the data root directory
def get_data_root():
    """Return the directory that datasets and results live under.

    When the ``COLAB_GPU`` environment variable is present (Google Colab), Google Drive
    is mounted at ``/content/drive`` and a folder on it is returned. Everywhere else
    the local ``./data/`` directory is returned.
    """
    if 'COLAB_GPU' not in os.environ:
        return './data/'

    # Colab-only dependency, so it is imported lazily.
    from google.colab import drive
    drive.mount('/content/drive')
    return '/content/drive/MyDrive/PhD/XFED result/Result XFED log/colab output/'

# Get the appropriate data root directory
data_root = get_data_root()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Model and dataset definitions for the supported datasets

In [None]:
class FashionMNISTAlexNet(nn.Module):
    """AlexNet-style classifier for single-channel FashionMNIST images.

    The layer sizes assume inputs resized to 227x227 (as ``load_model`` does). The
    convolutional stack then produces a 256x6x6 feature map, which is flattened for
    the classifier. The network outputs log-probabilities (``LogSoftmax``) over 10 classes.
    """

    def __init__(self):
        super().__init__()

        def conv_relu(in_ch, out_ch, kernel, stride, pad):
            # Convolution followed by an in-place ReLU, as a pair to splice into a Sequential.
            return [nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=pad),
                    nn.ReLU(inplace=True)]

        def pool():
            return nn.MaxPool2d(kernel_size=3, stride=2)

        def dense_block(in_f, out_f):
            return [nn.Linear(in_f, out_f), nn.ReLU(inplace=True), nn.Dropout(0.5)]

        self.features = nn.Sequential(
            *conv_relu(1, 96, 11, 4, 0), pool(),
            *conv_relu(96, 256, 5, 1, 2), pool(),
            *conv_relu(256, 384, 3, 1, 1),
            *conv_relu(384, 384, 3, 1, 1),
            *conv_relu(384, 256, 3, 1, 1), pool(),
        )
        self.classifier = nn.Sequential(
            *dense_block(256 * 6 * 6, 4096),
            *dense_block(4096, 4096),
            nn.Linear(4096, 10),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        feature_map = self.features(x)
        return self.classifier(torch.flatten(feature_map, 1))

class FeatureNorm(nn.Module):
    """Learnable affine rescaling of feature vectors: ``gamma * x + beta``.

    ``gamma`` is a single scalar shared by all features, and ``beta`` is a per-feature
    bias of shape ``(1, feature_shape)``.

    Args:
        feature_shape: Size of the last (feature) dimension of the input.
    """

    def __init__(self, feature_shape):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(1))
        self.beta = nn.Parameter(torch.zeros(1, feature_shape))

    def forward(self, x):
        # einsum('ni, j->ni', x, gamma) with a one-element gamma is just x * gamma[0],
        # but it only accepted 2-D input. Broadcasting gives identical values and
        # gradients and also accepts extra leading dimensions (..., feature_shape).
        return x * self.gamma + self.beta

class purchase_fully_connected_IN(nn.Module):
    """Bias-free three-layer tanh MLP for the Purchase dataset (600 input features).

    Inputs first pass through a learnable ``FeatureNorm``. For a 2-D input of shape
    ``(batch, 600)`` it returns unnormalised logits of shape ``(batch, num_classes)``.
    """

    def __init__(self, num_classes):
        super().__init__()
        in_features, hidden_1, hidden_2 = 600, 1024, 100
        self.fc1 = nn.Linear(in_features, hidden_1, bias=False)
        self.fc2 = nn.Linear(hidden_1, hidden_2, bias=False)
        self.fc3 = nn.Linear(hidden_2, num_classes, bias=False)
        self.norm = FeatureNorm(in_features)

    def forward(self, x):
        hidden = torch.tanh(self.fc1(self.norm(x)))
        hidden = torch.tanh(self.fc2(hidden))
        return self.fc3(hidden)  # raw logits, no output activation

class Purchase(torch.utils.data.Dataset):
    """Purchase tabular dataset read from a single CSV file.

    The first CSV column holds the class label, which is shifted down by one here
    (presumably 1-based in the file). The remaining columns are read as float32
    features. Rows are split 80/20 into train/test with a fixed ``random_state=1``, so
    separately created train and test instances see complementary, reproducible splits.

    Args:
        root: Path to the CSV file (not a directory). The default is evaluated once,
            when the class is defined.
        train: Expose the training split if True, else the test split.
        download: Stored only; nothing is downloaded.
        transform: Optional callable applied to each feature vector in ``__getitem__``.
    """

    def __init__(self, root =data_root + 'dataset_purchase', train=True, download=True, transform = None):
        self.root = root
        self.train = train
        self.download = download
        self.transform = transform
        self.images = []
        self.targets = []

        x_train, x_test, y_train, y_test = self._train_test_split()

        if self.train:
            self._setup_dataset(x_train, y_train)
        else:
            self._setup_dataset(x_test, y_test)

    def _train_test_split(self):
        """Read the CSV and return ``(x_train, x_test, y_train, y_test)`` as numpy arrays."""
        df = pd.read_csv(self.root)

        img_names = df.iloc[:, 1:].to_numpy(dtype='f')
        img_label = df.iloc[:, 0].to_numpy() - 1
        x_train, x_test, y_train, y_test = train_test_split(img_names, img_label, train_size=0.8,
                                                            random_state=1)
        return x_train, x_test, y_train, y_test

    def _setup_dataset(self, x, y):
        self.images = x
        self.targets = y

    def __len__(self):
        return len(self.images)

    def __getitem__(self, item):
        img = self.images[item]
        label = self.targets[item]
        # `transform` used to be stored but never applied.
        if self.transform is not None:
            img = self.transform(img)
        return img, label

class ThreeLayerDNN(nn.Module):
    """Fully connected net: flatten -> Linear -> ReLU -> Linear -> ReLU -> Linear.

    Args:
        input_size: Number of features after flattening all non-batch dimensions.
        hidden_size: Width of both hidden layers.
        output_size: Number of output logits.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1, self.relu1 = nn.Linear(input_size, hidden_size), nn.ReLU()
        self.fc2, self.relu2 = nn.Linear(hidden_size, hidden_size), nn.ReLU()
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out = self.flatten(x)
        for linear, activation in ((self.fc1, self.relu1), (self.fc2, self.relu2)):
            out = activation(linear(out))
        return self.fc3(out)

class FourLayerDNN(nn.Module):
    """Four-layer ReLU MLP over flattened 3x32x32 inputs, producing 10 logits."""

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1, self.relu1 = nn.Linear(3 * 32 * 32, 1024), nn.ReLU()
        self.fc2, self.relu2 = nn.Linear(1024, 512), nn.ReLU()
        self.fc3, self.relu3 = nn.Linear(512, 256), nn.ReLU()
        self.fc4 = nn.Linear(256, 10)  # output layer for 10 classes

    def forward(self, x):
        out = self.flatten(x)
        hidden_stages = ((self.fc1, self.relu1), (self.fc2, self.relu2), (self.fc3, self.relu3))
        for linear, activation in hidden_stages:
            out = activation(linear(out))
        return self.fc4(out)

class InputNorm(nn.Module):
    """Learnable per-channel affine transform of image inputs: ``gamma[c] * x + beta``.

    Args:
        num_channel: Number of input channels ``C``; ``gamma`` holds one scale per channel.
        num_feature: Spatial side length; ``beta`` is a full
            ``(C, num_feature, num_feature)`` bias map.
    """

    def __init__(self, num_channel, num_feature):
        super().__init__()
        self.num_channel = num_channel
        self.gamma = nn.Parameter(torch.ones(num_channel))
        self.beta = nn.Parameter(torch.zeros(num_channel, num_feature, num_feature))

    def forward(self, x):
        # gamma viewed as (C, 1, 1) scales the third-from-last axis of x, presumably laid
        # out as (..., C, H, W) — TODO confirm for all callers. Previously only C == 1 and
        # C == 3 were handled and any other channel count silently returned None. This
        # form gives the same result for both of those cases.
        return x * self.gamma.view(-1, 1, 1) + self.beta

class mnist_fully_connected_IN(nn.Module):
    """Bias-free 784 -> 600 -> 100 -> num_classes ReLU MLP with a learnable InputNorm front end.

    Inputs are normalised with ``InputNorm(1, 28)``, then flattened to ``28 * 28``
    features. The network returns raw logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.hidden1, self.hidden2 = 600, 100
        in_features = 28 * 28
        self.fc1 = nn.Linear(in_features, self.hidden1, bias=False)
        self.fc2 = nn.Linear(self.hidden1, self.hidden2, bias=False)
        self.fc3 = nn.Linear(self.hidden2, num_classes, bias=False)
        self.relu = nn.ReLU(inplace=False)  # registered, but forward uses F.relu
        self.norm = InputNorm(1, 28)

    def forward(self, x):
        flat = self.norm(x).view(-1, 28 * 28)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return self.fc3(hidden)

class CHMNISTDataset(Dataset):
    """Grayscale image dataset backed by a flat folder of image files.

    The integer label is parsed from the file name: the text between the last ``_``
    and the following ``.`` (e.g. ``img_3.png`` -> 3).

    Args:
        image_folder: Directory whose entries are all image files.
        transform: Optional callable applied to each PIL image.
    """

    def __init__(self, image_folder, transform=None):
        self.image_folder = image_folder
        self.transform = transform
        # Sorted so that index -> file is reproducible; os.listdir order is arbitrary.
        self.image_paths = [os.path.join(image_folder, name) for name in sorted(os.listdir(image_folder))]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        # The context manager closes the file handle once the converted copy exists.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('L')  # grayscale
        # Parse from the basename so underscores in directory names cannot interfere.
        label = int(os.path.basename(image_path).split('_')[-1].split('.')[0])

        if self.transform:
            image = self.transform(image)

        return image, label


class FEMNISTDataset(torchvision.datasets.MNIST):
    """FEMNIST loaded from pre-processed ``.pt`` files.

    The dataset expects ``{root}/FEMNIST/processed/`` to contain ``femnist_train.pt``,
    ``femnist_test.pt`` and ``femnist_user_keys.pt``. When any of them is missing and
    ``download=True``, the archive is fetched and its files are moved there. Each split
    file is indexed as ``(data, targets, users)``. Samples are reshaped to 28x28 in
    ``__getitem__``.

    Args:
        root: Directory under which the ``FEMNIST/`` folder lives.
        train: Load the training split if True, else the test split.
        transform: Optional callable applied to the PIL image.
        target_transform: Optional callable applied to the integer label.
        download: Download the archive when processed files are missing.

    Raises:
        RuntimeError: if processed files are missing and ``download`` is False.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        # Deliberately skip MNIST.__init__ (it looks for MNIST files) and call the next
        # initialiser in the MRO instead.
        super(torchvision.datasets.MNIST, self).__init__(root, transform=transform, target_transform=target_transform)
        self.download = download
        self.download_link = 'https://media.githubusercontent.com/media/GwenLegate/femnist-dataset-PyTorch/main/femnist.tar.gz'
        self.file_md5 = 'a8a28afae0e007f1acb87e37919a21db'
        self.train = train
        self.root = root
        self.training_file = f'{self.root}/FEMNIST/processed/femnist_train.pt'
        self.test_file = f'{self.root}/FEMNIST/processed/femnist_test.pt'
        self.user_list = f'{self.root}/FEMNIST/processed/femnist_user_keys.pt'

        # All three files are loaded below, so all three must exist. The user-key file
        # used to be left out of this check, which caused a FileNotFoundError instead of
        # a download.
        required_files = (self.training_file, self.test_file, self.user_list)
        if not all(os.path.exists(path) for path in required_files):
            if self.download:
                self.dataset_download()
            else:
                raise RuntimeError('Dataset not found, set parameter download=True to download')

        data_file = self.training_file if self.train else self.test_file

        # NOTE(review): torch.load unpickles arbitrary objects; only load these files
        # from the trusted download source above.
        data_targets_users = torch.load(data_file)
        self.data = torch.Tensor(data_targets_users[0])
        self.targets = torch.Tensor(data_targets_users[1])
        self.users = data_targets_users[2]
        self.user_ids = torch.load(self.user_list)

    def __getitem__(self, index):
        """Return ``(image, label)``, with the flat sample reshaped to a 28x28 grayscale PIL image."""
        img, target = self.data[index], int(self.targets[index])

        # NOTE(review): astype(np.uint8) assumes pixel values already lie in 0..255 —
        # confirm against the processed files (values in [0, 1] would truncate to 0).
        img = img.view(28, 28).numpy().astype(np.uint8)

        # Convert to PIL Image in grayscale mode
        img = Image.fromarray(img, mode='L')

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target  # Return only img and target

    def dataset_download(self):
        """Download and extract the FEMNIST archive into ``raw/``, then move the ``.pt`` files to ``processed/``."""
        raw_dir = f'{self.root}/FEMNIST/raw/'
        processed_dir = f'{self.root}/FEMNIST/processed/'
        for path in (raw_dir, processed_dir):
            os.makedirs(path, exist_ok=True)

        filename = self.download_link.split('/')[-1]
        utils.download_and_extract_archive(self.download_link, download_root=raw_dir, filename=filename, md5=self.file_md5)

        for file in ('femnist_train.pt', 'femnist_test.pt', 'femnist_user_keys.pt'):
            # Move to an explicit destination file rather than the directory. That way a
            # stale copy from an earlier partial run is replaced instead of raising
            # shutil.Error ("Destination path already exists").
            shutil.move(os.path.join(raw_dir, file), os.path.join(processed_dir, file))

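Loading the model and data for each supported dataset (FashionMNIST, CIFAR-10, PURCHASE, CHMNIST, EMNIST, MNIST, CIFAR-100, FEMNIST, SYMBIPREDICT), plus non-IID partitioning of the training data across clients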

In [None]:
from sklearn.preprocessing import LabelEncoder, StandardScaler
class SymbiPredictDataset(Dataset):
    """Map-style dataset over in-memory tabular features and integer labels.

    Each item is returned as a ``(float32 feature tensor, int64 label tensor)`` pair.
    """

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row, target = self.data[idx], self.labels[idx]
        return torch.tensor(row, dtype=torch.float32), torch.tensor(target, dtype=torch.long)

class TabularNet(nn.Module):
    """Two-hidden-layer ReLU MLP (input_dim -> 128 -> 64 -> num_classes) returning raw logits."""

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, num_classes)

    def forward(self, x):
        hidden = torch.relu(self.fc2(torch.relu(self.fc1(x))))
        return self.fc3(hidden)


def load_model(dataset: str):
    """Build the model, training dataset and test DataLoader for ``dataset``.

    Supported names: 'FashionMNIST', 'CIFAR10', 'PURCHASE', 'CHMNIST', 'EMNIST',
    'MNIST', 'CIFAR100', 'FEMNIST', 'SYMBIPREDICT'. The model is moved to the
    module-level ``device``. Only the test split is wrapped in a DataLoader. The
    training dataset is returned as-is, presumably so callers can partition it across
    clients.

    Args:
        dataset: One of the supported dataset names above.

    Returns:
        A ``(model, train_data, testloader)`` tuple.

    Raises:
        ValueError: if ``dataset`` is not supported.
    """
    if dataset == 'FashionMNIST':
        # The AlexNet-style net expects 227x227 single-channel inputs.
        transform = transforms.Compose([
            transforms.Resize((227, 227)),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        train_data = torchvision.datasets.FashionMNIST(root=data_root, train=True, download=True, transform=transform)
        test_data = torchvision.datasets.FashionMNIST(root=data_root, train=False, download=True, transform=transform)
        testloader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=8)
        model = FashionMNISTAlexNet().to(device)

    elif dataset == 'CIFAR10':
        # ImageNet preprocessing to match the pretrained AlexNet backbone.
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        train_data = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform)
        test_data = torchvision.datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform)
        testloader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=8)

        model = models.alexnet(pretrained=True)
        # Replace the fully connected head with a freshly initialised 10-class one.
        model.classifier[1] = nn.Linear(9216, 4096)
        model.classifier[4] = nn.Linear(4096, 1024)
        model.classifier[6] = nn.Linear(1024, 10)
        model = model.to(device)

    elif dataset == 'PURCHASE':
        train_data = Purchase(train=True, download=True)
        test_data = Purchase(train=False, download=True)
        testloader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=8)

        model = purchase_fully_connected_IN(100).to(device)

    elif dataset == 'CHMNIST':
        # NOTE(review): this branch loads plain MNIST as a stand-in for CHMNIST, from
        # './data' rather than data_root like most other branches — confirm intended.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
        train_data = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
        test_data = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
        testloader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=2)

        model = models.mobilenet_v2(pretrained=True)
        # Swap the stem conv for a 1-channel one *before* moving to the device. Moving
        # first left this new layer on the CPU while the rest sat on the GPU.
        model.features[0][0] = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        # NOTE(review): the classifier head is left as-is (presumably the default
        # 1000-way ImageNet head) — confirm this is intended for 10 digit classes.
        model = model.to(device)

    elif dataset == 'EMNIST':
        transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
        # NOTE(review): reads from './data' rather than data_root — confirm intended.
        train_data = torchvision.datasets.EMNIST(root='./data', split='byclass', train=True, download=True, transform=transform)
        test_data = torchvision.datasets.EMNIST(root='./data', split='byclass', train=False, download=True, transform=transform)
        testloader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=8)

        # The 'byclass' split has 62 classes.
        model = ThreeLayerDNN(input_size=28 * 28, hidden_size=512, output_size=62).to(device)

    elif dataset == 'MNIST':
        transform = transforms.Compose([
            transforms.ToTensor(),                 # Convert image to PyTorch tensor
            transforms.Normalize((0.5,), (0.5,))   # Normalize grayscale values to [-1, 1]
        ])

        train_data = torchvision.datasets.MNIST(root=data_root, train=True, download=True, transform=transform)
        test_data = torchvision.datasets.MNIST(root=data_root, train=False, download=True, transform=transform)
        testloader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=8)

        # Bias-free MLP with a learnable input normalisation layer, trained from scratch.
        model = mnist_fully_connected_IN(10).to(device)

    elif dataset == 'CIFAR100':
        # CLIP preprocessing: 224x224 inputs with CLIP's normalisation statistics.
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
        ])

        train_data = torchvision.datasets.CIFAR100(root=data_root, train=True, download=True, transform=transform)
        test_data = torchvision.datasets.CIFAR100(root=data_root, train=False, download=True, transform=transform)
        testloader = DataLoader(test_data, batch_size=32, shuffle=False, num_workers=8)

        # Frozen CLIP ViT-B/32 image encoder; preprocessing is done by `transform` above.
        model_clip, _preprocess = clip.load("ViT-B/32", device=device)
        # Convert CLIP model to float32 to match other layers and data
        model_clip = model_clip.float()
        for param in model_clip.parameters():
            param.requires_grad = False

        class CLIP_DNN(nn.Module):
            """Linear classifier on top of frozen CLIP image features."""

            def __init__(self, clip_model, num_classes=100):
                super(CLIP_DNN, self).__init__()
                self.clip_model = clip_model
                self.fc = nn.Linear(512, num_classes)  # CLIP ViT-B/32 gives 512-dimensional features

            def forward(self, images):
                with torch.no_grad():
                    # Extract image features using CLIP's image encoder (cast to float32)
                    image_features = self.clip_model.encode_image(images).float()
                return self.fc(image_features)

        model = CLIP_DNN(model_clip, num_classes=100).to(device)

    elif dataset == 'FEMNIST':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
        train_data = FEMNISTDataset(root=data_root, train=True, download=True, transform=transform)
        test_data = FEMNISTDataset(root=data_root, train=False, download=True, transform=transform)
        testloader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=8)

        model = ThreeLayerDNN(input_size=28 * 28, hidden_size=256, output_size=62).to(device)

    elif dataset == 'SYMBIPREDICT':
        csv_file = os.path.join(data_root, 'symbipredict_2022.csv')
        df = pd.read_csv(csv_file)

        # Encode target labels as integers 0..K-1
        label_encoder = LabelEncoder()
        df['prognosis'] = label_encoder.fit_transform(df['prognosis'])

        X = df.drop(columns=['prognosis']).values
        y = df['prognosis'].values

        # Split first, then fit the scaler on the training rows only. Fitting on all rows
        # leaked test-set statistics into training. The split indices are unchanged:
        # they depend only on the row count and random_state.
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)

        train_data = SymbiPredictDataset(X_train, y_train)
        test_data = SymbiPredictDataset(X_test, y_test)
        testloader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=4)

        input_dim = X_train.shape[1]
        num_classes = len(label_encoder.classes_)
        model = TabularNet(input_dim=input_dim, num_classes=num_classes).to(device)

    else:
        raise ValueError("Dataset not supported")

    return model, train_data, testloader


import random
from collections import defaultdict
from torch.utils.data import Subset


def create_noniid_dataset_symbipredict(symbipredict_data, num_clients, degree_noniid_p=0.5, num_classes=100):
    """
    Create a non-IID client partition of the SymbiPredict dataset for federated learning.

    Clients are divided into one group per label; the group index is the label's rank
    among the sorted labels, which equals the label for labels 0..K-1. For each label:
    - A fraction ``degree_noniid_p`` of its samples is split into disjoint, equal chunks
      among that label's group of clients.
    - The rest of that label's samples go to uniformly random clients.

    Clients left with fewer than ``len(dataset) // num_clients`` samples are then topped
    up from group samples nobody holds yet. Every sample index is assigned to at most
    one client.

    Args:
    - symbipredict_data: Dataset exposing a ``labels`` sequence (e.g. SymbiPredictDataset).
    - num_clients: Total number of clients to split the data among.
    - degree_noniid_p: Fraction of each label's samples reserved for that label's client group.
    - num_classes: Unused; labels are read from the data. Kept for signature compatibility.

    Returns:
    - clients_data: A defaultdict mapping client id -> list of sample indices.
    """
    labels = symbipredict_data.labels
    clients_data = defaultdict(list)

    # Group sample indices by (integer) label and shuffle each group.
    label_indices = defaultdict(list)
    for idx, lbl in enumerate(labels):
        label_indices[int(lbl)].append(idx)
    if not label_indices:
        return clients_data
    sorted_labels = sorted(label_indices)
    for label in sorted_labels:
        random.shuffle(label_indices[label])

    # One client group per label; with fewer clients than labels the groups are empty.
    clients_per_group = num_clients // len(sorted_labels)

    # Part of each label's group share that no group client received (the division
    # remainder, or all of it when clients_per_group == 0); used only for top-up.
    unassigned = {}
    for group, label in enumerate(sorted_labels):
        label_data = label_indices[label]
        split_point = int(degree_noniid_p * len(label_data))
        group_share = label_data[:split_point]

        # Give every client of the group its own disjoint chunk. The old code handed
        # each client the same first slice, duplicating samples across clients.
        chunk = split_point // clients_per_group if clients_per_group else 0
        for i in range(clients_per_group):
            client_id = group * clients_per_group + i
            clients_data[client_id].extend(group_share[i * chunk:(i + 1) * chunk])
        unassigned[label] = group_share[chunk * clients_per_group:]

        # Distribute the rest of this label uniformly at random among all clients.
        for idx in label_data[split_point:]:
            clients_data[random.randint(0, num_clients - 1)].append(idx)

    # Top up clients below the average size, drawing only from samples nobody holds
    # yet. The old code popped indices that were already assigned, duplicating them.
    data_per_client = len(labels) // num_clients
    for client_id in range(num_clients):
        while len(clients_data[client_id]) < data_per_client:
            candidates = [label for label in sorted_labels if unassigned[label]]
            if not candidates:
                break
            clients_data[client_id].append(unassigned[random.choice(candidates)].pop())

    # Print the label distribution for each client
    for client_id, indices in clients_data.items():
        label_count = defaultdict(int)
        for idx in indices:
            label_count[int(labels[idx])] += 1
        print(f"Client {client_id}: {dict(label_count)}")

    return clients_data

# Example usage:
# clients_data = create_noniid_dataset_symbipredict(symbipredict_train_data, num_clients=100, degree_noniid_p=0.5)
# This will print the label distribution for each client and return the clients' data indices.


def create_noniid_dataset_purchase(purchase_data, num_clients, degree_noniid_p=0.5, num_classes=100):
    """
    Create a non-IID client partition of the Purchase dataset for federated learning.

    Clients are divided into one group per label; the group index is the label's rank
    among the sorted labels, which equals the label for labels 0..K-1. For each label:
    - A fraction ``degree_noniid_p`` of its samples is split into disjoint, equal chunks
      among that label's group of clients.
    - The rest of that label's samples go to uniformly random clients.

    Clients left with fewer than ``len(dataset) // num_clients`` samples are then topped
    up from group samples nobody holds yet. Every sample index is assigned to at most
    one client.

    Args:
    - purchase_data: Dataset exposing a ``targets`` sequence (e.g. Purchase).
    - num_clients: Total number of clients to split the data among.
    - degree_noniid_p: Fraction of each label's samples reserved for that label's client group.
    - num_classes: Unused; labels are read from the data. Kept for signature compatibility.

    Returns:
    - clients_data: A defaultdict mapping client id -> list of sample indices.
    """
    labels = purchase_data.targets
    clients_data = defaultdict(list)

    # Group sample indices by (integer) label and shuffle each group.
    label_indices = defaultdict(list)
    for idx, lbl in enumerate(labels):
        label_indices[int(lbl)].append(idx)
    if not label_indices:
        return clients_data
    sorted_labels = sorted(label_indices)
    for label in sorted_labels:
        random.shuffle(label_indices[label])

    # One client group per label; with fewer clients than labels the groups are empty.
    clients_per_group = num_clients // len(sorted_labels)

    # Part of each label's group share that no group client received (the division
    # remainder, or all of it when clients_per_group == 0); used only for top-up.
    unassigned = {}
    for group, label in enumerate(sorted_labels):
        label_data = label_indices[label]
        split_point = int(degree_noniid_p * len(label_data))
        group_share = label_data[:split_point]

        # Give every client of the group its own disjoint chunk. The old code handed
        # each client the same first slice, duplicating samples across clients.
        chunk = split_point // clients_per_group if clients_per_group else 0
        for i in range(clients_per_group):
            client_id = group * clients_per_group + i
            clients_data[client_id].extend(group_share[i * chunk:(i + 1) * chunk])
        unassigned[label] = group_share[chunk * clients_per_group:]

        # Distribute the rest of this label uniformly at random among all clients.
        for idx in label_data[split_point:]:
            clients_data[random.randint(0, num_clients - 1)].append(idx)

    # Top up clients below the average size, drawing only from samples nobody holds
    # yet. The old code popped indices that were already assigned, duplicating them.
    data_per_client = len(labels) // num_clients
    for client_id in range(num_clients):
        while len(clients_data[client_id]) < data_per_client:
            candidates = [label for label in sorted_labels if unassigned[label]]
            if not candidates:
                break
            clients_data[client_id].append(unassigned[random.choice(candidates)].pop())

    # Print the label distribution for each client
    for client_id, indices in clients_data.items():
        label_count = defaultdict(int)
        for idx in indices:
            label_count[int(labels[idx])] += 1
        print(f"Client {client_id}: {dict(label_count)}")

    return clients_data

# Example usage:
# clients_data = create_noniid_dataset_purchase(purchase_data_instance, num_clients=100, degree_noniid_p=0.5, num_classes=100)
# This will print the label distribution for each client and return the clients' data indices.

def create_noniid_dataset(train_data, num_clients, degree_noniid_p=0.5, num_classes=10):
    """Split the indices of `train_data` across clients in a label-skewed (non-IID) way.

    Clients are divided into `num_classes` groups of `num_clients // num_classes` clients
    (client ids ``label * clients_per_group + i``). For every label ``l``:

    * the first ``int(p * n_l)`` shuffled indices of label ``l`` are cut into equal,
      disjoint chunks, one per client of group ``l``;
    * the remaining indices of label ``l`` go to uniformly random clients.

    Indices left over by the equal chunking form a pool. That pool is used to top
    clients up towards ``len(train_data.data) // num_clients`` samples. No index is
    ever given to more than one client. If the pool runs out, some clients may stay
    below that target.

    Args:
        train_data: dataset exposing ``.data`` (only its length is used) and
            ``.targets`` (labels convertible to int in ``range(num_classes)``).
        num_clients: number of clients to create.
        degree_noniid_p: fraction ``p`` of each label reserved for that label's group.
        num_classes: number of distinct labels.

    Returns:
        ``defaultdict(list)`` mapping client id -> list of dataset indices. The label
        histogram of every client is also printed.
    """
    clients_data = defaultdict(list)

    labels = train_data.targets
    total_data_points = len(train_data.data)
    data_per_client = total_data_points // num_clients  # target size per client

    # Group dataset indices by label, then shuffle within each label.
    label_indices = {label: [] for label in range(num_classes)}
    for idx, lbl in enumerate(labels):
        label_indices[int(lbl)].append(idx)  # int() so tensor labels hash like ints
    for label in label_indices:
        random.shuffle(label_indices[label])

    clients_per_group = num_clients // num_classes
    unassigned = []  # indices not yet given to any client; used for the top-up phase

    for label in range(num_classes):
        label_data = label_indices[label]
        split_point = int(degree_noniid_p * len(label_data))

        # The p-fraction of label `label` goes to its group, one disjoint chunk per client.
        if clients_per_group > 0:
            chunk = split_point // clients_per_group
            for i in range(clients_per_group):
                client_id = label * clients_per_group + i
                clients_data[client_id].extend(label_data[i * chunk:(i + 1) * chunk])
            # Remainder of the integer division is kept for the top-up phase.
            unassigned.extend(label_data[clients_per_group * chunk:split_point])
        else:
            # Fewer clients than classes: there is no group to receive this share.
            unassigned.extend(label_data[:split_point])

        # The rest of the label is spread over uniformly random clients.
        for idx in label_data[split_point:]:
            clients_data[random.randint(0, num_clients - 1)].append(idx)

    # Top up small clients from the unassigned pool only (never re-using an index).
    # The loop is bounded by the pool size, so it always terminates.
    random.shuffle(unassigned)
    for client_id in range(num_clients):
        while len(clients_data[client_id]) < data_per_client and unassigned:
            clients_data[client_id].append(unassigned.pop())

    # Print the label distribution for each client.
    for client_id, indices in clients_data.items():
        label_count = defaultdict(int)
        for idx in indices:
            label_count[int(labels[idx])] += 1
        print(f"Client {client_id}: {dict(label_count)}")

    return clients_data

# Example usage (assuming you have a dataset like MNIST loaded as `train_data`):
# clients_data = create_noniid_dataset(train_data, num_clients=10, degree_noniid_p=0.5, num_classes=10)
# This will print the label distribution for each client and return the clients' data indices.


def get_data_loaders(clients_data, dataset, batch_size=64, specific_client_id=None):
    """
    Create per-client DataLoaders from a client -> indices mapping.

    Each loader serves a random sample (without replacement) of at most `batch_size`
    of the client's indices, unshuffled. Because the loader's batch size equals the
    sample size, it yields exactly one batch (or none if the client has no data).

    Args:
        clients_data: Mapping of client id -> list of dataset indices.
        dataset: Original PyTorch (map-style) dataset object.
        batch_size: Sample size per client and DataLoader batch size.
        specific_client_id: If provided, only this client's loader is built.
    Returns:
        The DataLoader for `specific_client_id` when it is given; otherwise a dict
        mapping every client id in `clients_data` to its DataLoader.
    """
    def _single_batch_loader(indices):
        # Draw a fresh random subset every call so each round sees different samples.
        sampled = random.sample(indices, min(batch_size, len(indices)))
        return DataLoader(Subset(dataset, sampled), batch_size=batch_size, shuffle=False, num_workers=0)

    if specific_client_id is not None:
        return _single_batch_loader(clients_data[specific_client_id])

    return {client_id: _single_batch_loader(indices) for client_id, indices in clients_data.items()}

Aggregation Rules (FedAvg / Mean, Median, Trimmed Mean, Multi-Krum, Clipped Clustering, SignGuard)

In [None]:
def tr_mean(all_updates: torch.Tensor) -> torch.Tensor:
    """Coordinate-wise trimmed mean assuming 20% attackers.

    Every coordinate is sorted across clients (dim 0) and the ``round(0.2 * n)``
    smallest and largest values are dropped before averaging. When that trim count
    is zero, or would remove every client, the plain mean of all clients is returned.
    """
    client_count = len(all_updates)
    trim = round(0.2 * client_count)
    ordered, _ = torch.sort(all_updates, dim=0)

    trimming_possible = trim != 0 and 2 * trim < client_count
    kept = ordered[trim:client_count - trim] if trimming_possible else ordered
    return torch.mean(kept, dim=0)

def multi_krum_optimized(local_updates: torch.Tensor):
    """
    Memory-conscious Multi-Krum aggregation, assuming 20% Byzantine clients.

    Each client's score is the sum of squared L2 distances to its `krum_limit`
    nearest other clients, with ``krum_limit = n - int(0.2 * n) - 2``. The
    `krum_limit` lowest-scoring clients are averaged. Distances are computed one row
    at a time instead of materialising the full pairwise matrix.

    Parameters:
    - local_updates: tensor of shape (num_clients, num_params), or a list/tuple of
      1-D tensors, which is stacked.
    Returns:
    - (aggregated_update, selected_indices): the mean of the selected rows, with shape
      (num_params,), and a LongTensor of the selected row indices.
    Raises:
    - ValueError if there are no client updates.
    """
    if isinstance(local_updates, (list, tuple)):
        local_updates = torch.stack(list(local_updates))

    num_clients = local_updates.size(0)
    if num_clients == 0:
        raise ValueError("multi_krum_optimized needs at least one client update")

    byzantine_client_num = int(num_clients * 0.2)
    # Clamp to >= 1: for very small populations n - f - 2 <= 0, which would select
    # nobody (topk with k=0) and make the mean NaN.
    krum_limit = max(1, num_clients - byzantine_client_num - 2)

    # Keep scores on the same device/dtype as the updates to avoid silent copies and downcasts.
    scores = torch.zeros(num_clients, device=local_updates.device, dtype=local_updates.dtype)

    for i in range(num_clients):
        # Squared L2 distances from client i to every client (including itself).
        distances = torch.sum((local_updates - local_updates[i]) ** 2, dim=1)
        sorted_distances, _ = torch.sort(distances)
        # Index 0 is the zero self-distance; sum the next krum_limit nearest neighbours.
        scores[i] = torch.sum(sorted_distances[1:krum_limit + 1])
        del distances, sorted_distances

    # Lowest scores are the most "central" clients.
    selected_indices = torch.topk(-scores, krum_limit, largest=True).indices
    aggregated_update = torch.mean(local_updates[selected_indices], dim=0)
    return aggregated_update, selected_indices

def clip_tensor_norm_(
    parameters: Union[torch.Tensor, Iterable[torch.Tensor]],
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
) -> torch.Tensor:
    """Scale tensor(s) in place so that their combined norm is at most `max_norm`.

    This mirrors ``torch.nn.utils.clip_grad_norm_``, but acts on the tensor values
    themselves rather than on ``.grad``. int64 tensors are ignored, both when
    computing the norm and when scaling.

    Args:
        parameters: a tensor or an iterable of tensors (generators are accepted).
        max_norm: maximum allowed total norm.
        norm_type: order of the norm; ``float('inf')`` uses the max-abs norm.
        error_if_nonfinite: raise if the total norm is NaN/inf instead of scaling by it.

    Returns:
        The total norm before clipping, as a 0-dim tensor (``tensor(0.0)`` if there is
        nothing to clip).

    Raises:
        RuntimeError: if `error_if_nonfinite` is set and the total norm is NaN or inf.
    """
    # Local import: the notebook only imports `math` in a later cell.
    import math

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    else:
        parameters = list(parameters)  # materialise generators so they can be reused

    max_norm = float(max_norm)
    norm_type = float(norm_type)

    # Apply the same int64 exclusion to the norm computation and to the scaling.
    tensors = [p for p in parameters if p.dtype != torch.int64]
    if not tensors:
        return torch.tensor(0.0)

    device = tensors[0].device

    if math.isinf(norm_type):
        norms = [t.detach().abs().max().to(device) for t in tensors]
        total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
    else:
        total_norm = torch.norm(
            torch.cat([t.detach().reshape(-1).to(device) for t in tensors]),
            norm_type,
        )

    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f"The total norm of order {norm_type} for gradients from "
            "`parameters` is non-finite, so it cannot be clipped. To disable "
            "this error and scale the gradients by the non-finite norm anyway, "
            "set `error_if_nonfinite=False`"
        )

    clip_coef = max_norm / (total_norm + 1e-6)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)  # never scale up

    # no_grad so in-place scaling also works on tensors that require grad.
    with torch.no_grad():
        for t in tensors:
            t.mul_(clip_coef_clamped.to(t.device))

    return total_norm

def Clippedclustering(updates: torch.Tensor):
    """Clipped-clustering robust aggregation.

    Steps:
      1. Clip every update whose L2 norm exceeds the median client norm (capped at
         ``tau = 1e5``) down to that threshold. NOTE: this modifies the rows of
         `updates` IN PLACE.
      2. Run agglomerative clustering (average linkage, 2 clusters) on the pairwise
         cosine distances. Keep the cluster holding more than ``num // 2`` clients
         (label 0 on a tie) as S1.
      3. Run KMeans (k=2) on each update's fractions of positive / negative / zero
         entries. Keep the majority cluster as S2, using the same rule.
      4. Average the updates in S1 ∩ S2. If the intersection is empty, fall back to
         S1 instead of returning NaN.

    Args:
        updates: 2-D tensor (num_clients, num_params).
    Returns:
        1-D tensor (num_params,) on the same device as `updates`.
    """
    tau = 1e5
    num = len(updates)

    # 1. Norm clipping to the (capped) median norm.
    l2norms = [torch.norm(update).item() for update in updates]
    threshold = min(np.median(l2norms), tau)
    for idx, l2 in enumerate(l2norms):
        if l2 > threshold:
            clip_tensor_norm_(updates[idx], threshold)  # updates[idx] is a view -> in place

    # 2. Cosine distances. Rows are L2-normalised first so 1 - <u_i, u_j> lies in
    #    [0, 2]; zero-norm rows give NaN, which is mapped to the maximum distance 2.0.
    unit = updates / torch.norm(updates, dim=1, keepdim=True)
    dis_max = (1 - torch.mm(unit, unit.t())).cpu().numpy()
    del unit
    dis_max = np.where(np.isfinite(dis_max), dis_max, 2.0)
    dis_max = np.clip(dis_max, 0.0, 2.0)  # absorb tiny negative rounding errors
    np.fill_diagonal(dis_max, 0.0)       # a precomputed distance matrix has a zero diagonal

    clustering = AgglomerativeClustering(
        metric="precomputed", linkage="average", n_clusters=2
    )
    clustering.fit(dis_max)
    flag = 1 if np.sum(clustering.labels_) > num // 2 else 0
    S1_idxs = {idx for idx, label in enumerate(clustering.labels_) if label == flag}

    # 3. Sign-statistics clustering.
    features = torch.stack(
        [
            (updates > 0).float().mean(dim=1),
            (updates < 0).float().mean(dim=1),
            (updates == 0).float().mean(dim=1),
        ],
        dim=1,
    ).cpu().numpy()
    kmeans = KMeans(n_clusters=2, random_state=0).fit(features)
    flag = 1 if np.sum(kmeans.labels_) > num // 2 else 0
    S2_idxs = {idx for idx, label in enumerate(kmeans.labels_) if label == flag}

    # 4. Intersection (sorted for a deterministic summation order), falling back to S1.
    selected_idxs = sorted(S1_idxs & S2_idxs) or sorted(S1_idxs)
    return torch.mean(updates[selected_idxs], dim=0)

def SignGuard(updates):
    """SignGuard-style robust aggregation.

    An update is kept only if it passes both filters:
      1. Norm filter: its L2 norm lies in ``[0.1 * M, 3.0 * M]``, where M is the
         median client norm.
      2. Sign filter: KMeans (k=2) on the fractions of positive / negative / zero
         entries of each update. The cluster holding more than ``num // 2`` clients
         (label 0 on a tie) is kept.

    Args:
        updates: tensor whose first dimension indexes clients, e.g.
            (num_clients, num_params).

    Returns:
        The mean of the kept updates, on the same device as `updates`. If the two
        filters share no client, the norm-filtered set alone is averaged instead of
        returning NaN. That set always contains the median-norm client when the norms
        are finite.
    """
    num = updates.shape[0]
    updates_flat = updates.reshape(num, -1)  # reshape also works for non-contiguous input

    # Filter 1: norm bounds relative to the median norm (lower median, an actual client).
    l2norms = torch.norm(updates_flat, dim=1)
    M = torch.median(l2norms)
    L = 0.1
    R = 3.0
    mask1 = (l2norms >= L * M) & (l2norms <= R * M)
    del l2norms, M

    # Filter 2: sign statistics. These are computed on the updates' device so only
    # the small (num, 3) feature matrix is copied to the CPU for KMeans.
    num_para = updates_flat.size(1)
    features = torch.stack(
        [
            (updates_flat > 0).sum(dim=1).float() / num_para,
            (updates_flat < 0).sum(dim=1).float() / num_para,
            (updates_flat == 0).sum(dim=1).float() / num_para,
        ],
        dim=1,
    ).cpu().numpy()
    del updates_flat

    labels = KMeans(n_clusters=2, random_state=0).fit(features).labels_
    flag = 1 if labels.sum() > num // 2 else 0
    # Put the mask on mask1's device (the updates' device), not on a global `device`.
    mask2 = torch.from_numpy(labels == flag).to(mask1.device)

    inter_mask = mask1 & mask2
    if not bool(inter_mask.any()):
        inter_mask = mask1

    result = torch.mean(updates[inter_mask], dim=0)
    torch.cuda.empty_cache()
    return result


Attacks (XFED)

In [None]:
def get_mu_pairwise_distance(global_models, MU_MULTIPLIER):
    """Return ``MU_MULTIPLIER * RMS`` distance of the models from their centroid.

    Despite the name, no pairwise distances are computed. The models are stacked
    (a list is passed through ``torch.vstack``; a tensor is used as-is), their
    centroid is taken, and the root-mean-square of the Euclidean distances to that
    centroid is scaled by `MU_MULTIPLIER`. With fewer than two models the result is
    ``tensor(0.0)``.
    """
    model_count = len(global_models)
    if model_count <= 1:
        return torch.tensor(0.0)

    stacked = torch.vstack(global_models) if isinstance(global_models, list) else global_models
    offsets = stacked - torch.mean(stacked, dim=0)
    dists = torch.norm(offsets, dim=1)
    rms_distance = torch.sqrt(torch.dot(dists, dists) / model_count)
    return MU_MULTIPLIER * rms_distance

def xfed_c(user_grads, n_attackers, dev_type, len_global, global_model_data, global_models):
    """Collusive XFED attack: replace the first `n_attackers` updates with one crafted vector.

    The crafted vector is ``model_re - mu * deviation``, where ``model_re`` is the mean
    of the attackers' own updates. `deviation` depends on `dev_type`:

    * ``'C_XFED_unit_vec'``: ``model_re / ||model_re||``
    * ``'C_XFED_sign'``: ``sign(model_re)``, normalised to unit length
    * ``'C_XFED_std'``: coordinate-wise std of the attackers' updates

    `mu` is ``get_mu_pairwise_distance(attacker_updates, 1)`` once more than one
    global model exists (``len_global > 1``), and 1.0 otherwise.

    Args:
        user_grads: (num_clients, ...) tensor, or a list of per-client tensors (stacked).
        n_attackers: number of leading rows controlled by the attacker.
        dev_type: one of the types listed above.
        len_global: number of stored global models.
        global_model_data, global_models: accepted for interface compatibility; unused.

    Returns:
        Tensor of all updates with the attacker rows replaced, on `user_grads`' device.
        If `n_attackers <= 0`, the (stacked) input is returned unchanged.

    Raises:
        ValueError: for an unknown `dev_type`.
    """
    if isinstance(user_grads, (list, tuple)):
        user_grads = torch.stack(list(user_grads))
    if n_attackers <= 0:
        return user_grads  # nothing to craft; avoids a NaN mean over zero rows

    all_updates = user_grads[:n_attackers]
    model_re = torch.mean(all_updates, dim=0)

    if dev_type == 'C_XFED_unit_vec':
        deviation = model_re / torch.norm(model_re)
    elif dev_type == 'C_XFED_sign':
        sgn_vec = torch.sign(model_re)
        deviation = sgn_vec / torch.norm(sgn_vec)
    elif dev_type == 'C_XFED_std':
        deviation = torch.std(all_updates, dim=0)
    else:
        raise ValueError(f"Unknown deviation type: {dev_type}")

    if len_global > 1:
        mu = get_mu_pairwise_distance(all_updates, 1)
    else:
        mu = torch.tensor(1.0)

    mal_update = model_re - mu.to(model_re.device) * deviation

    # Every attacker submits the same crafted vector; benign rows follow unchanged.
    mal_rows = mal_update.unsqueeze(0).expand(n_attackers, *mal_update.shape)
    return torch.cat((mal_rows, user_grads[n_attackers:]), dim=0)




Attacks (VIRAT, FANG-TR-MEAN, FANG-KRUM, LIE)

In [None]:

def virat_min_max(user_grads, n_attackers, dev_type='VIRAT_unit_vec', epoch = 0, threshold=50):
    """Implement VIRAT Min-Max attack.

    The crafted update is ``centre - lamda * direction``, where ``centre`` is the mean
    of the first `n_attackers` rows. `lamda` is found by a binary search that starts
    at `threshold` and ends once successive steps differ by at most 1e-5. A `lamda`
    succeeds when the crafted update's largest squared distance to any attacker row
    does not exceed the largest squared distance between two attacker rows. The last
    successful `lamda` is used, and the crafted vector replaces all attacker rows.

    Raises:
        ValueError: for an unknown `dev_type`.
    """
    attacker_rows = user_grads[:n_attackers].to(device)
    centre = torch.mean(attacker_rows, dim=0).to(device)

    direction_builders = {
        'VIRAT_unit_vec': lambda: centre / torch.norm(centre),
        'VIRAT_sign': lambda: torch.sign(centre),
        'VIRAT_std': lambda: torch.std(attacker_rows, dim=0),
    }
    if dev_type not in direction_builders:
        raise ValueError(f"Unknown deviation type: {dev_type}")
    direction = direction_builders[dev_type]()

    lamda = torch.tensor([threshold], dtype=torch.float).to(device)
    threshold_diff = 1e-5
    step = lamda.clone()
    lamda_succ = torch.tensor(0, dtype=torch.float).to(device)

    # Largest squared pairwise distance among the attackers' own updates.
    pairwise_sq = torch.stack([torch.norm((attacker_rows - row), dim=1) ** 2 for row in attacker_rows])
    max_distance = torch.max(pairwise_sq)
    del pairwise_sq

    while torch.abs(lamda_succ - lamda) > threshold_diff:
        candidate = centre - lamda * direction
        worst = torch.max(torch.norm(attacker_rows - candidate, dim=1) ** 2)
        if worst <= max_distance:
            lamda_succ = lamda.clone()
            lamda += step / 2
        else:
            lamda -= step / 2
        step /= 2

    print(f"For round {epoch}, max_distance", max_distance, "lamda_succ", lamda_succ)
    crafted = (centre - lamda_succ * direction)
    replicated = crafted.unsqueeze(0).repeat(n_attackers, *([1] * crafted.dim()))
    return torch.cat((replicated, user_grads[n_attackers:]), dim=0)

def fang_attack_trmean_partial(user_grads, n_attackers):
    """Fang et al. partial-knowledge attack against trimmed mean.

    The attackers see only their own updates (the first `n_attackers` rows). For each
    coordinate j, with mean mu_j and std sigma_j of those rows, every attacker
    independently samples a value that pushes against the sign of mu_j:

    * mu_j > 0:  uniform in [mu_j - 4 sigma_j, mu_j - 3 sigma_j]
    * otherwise: uniform in [mu_j + 3 sigma_j, mu_j + 4 sigma_j]

    The previous version multiplied both ranges by the same ``mu_j > 0`` mask. That
    produced ``max + min`` (about 2 mu_j) on positive coordinates and 0 elsewhere.

    Args:
        user_grads: (num_clients, num_params) tensor; its first `n_attackers` rows are replaced.
        n_attackers: number of attacker rows.

    Returns:
        Tensor of the same shape, with crafted attacker rows and the benign rows unchanged.
    """
    all_updates = user_grads[:n_attackers]
    model_re = torch.mean(all_updates, 0)
    model_std = torch.std(all_updates, 0)

    max_vector_low = (model_re + 3 * model_std)[:, None]
    max_vector_hig = (model_re + 4 * model_std)[:, None]
    min_vector_low = (model_re - 4 * model_std)[:, None]
    min_vector_hig = (model_re - 3 * model_std)[:, None]

    # One uniform draw per (coordinate, attacker). The numpy RNG is kept so that
    # np.random.seed still controls this attack.
    rand = torch.from_numpy(
        np.random.uniform(0, 1, [len(model_re), n_attackers])
    ).float().to(model_re.device)

    max_rand = max_vector_low + rand * (max_vector_hig - max_vector_low)  # (num_params, n_attackers)
    min_rand = min_vector_low + rand * (min_vector_hig - min_vector_low)

    push_down = (model_re > 0)[:, None]
    mal_update = torch.where(push_down, min_rand, max_rand).T  # (n_attackers, num_params)

    return torch.cat((mal_update, user_grads[n_attackers:]), dim=0)

def compute_lambda_fang(all_updates, model_re, n_attackers):
    """Upper bound on the Fang/Krum perturbation scale lambda.

    Returns ``term_1 + max_wre_dist``, where:

    * ``term_1``: the smallest Krum-style score (sum of the ``n_benign - 2 - n_attackers``
      smallest non-zero L2 distances from one update to the others), divided by
      ``(n_benign - n_attackers - 1) * sqrt(d)``;
    * ``max_wre_dist``: the largest L2 distance from an update to `model_re`, divided by ``sqrt(d)``.

    Zero distances (self-distances and exact duplicates) are replaced by 10000 before sorting.
    """
    n_benign, d = all_updates.shape
    dist = torch.stack([torch.norm((all_updates - row), dim=1) for row in all_updates])

    dist[dist == 0] = 10000
    dist, _ = torch.sort(dist, dim=1)
    scores = torch.sum(dist[:, :n_benign - 2 - n_attackers], dim=1)

    sqrt_d = torch.sqrt(torch.Tensor([d]))[0]
    term_1 = torch.min(scores) / ((n_benign - n_attackers - 1) * sqrt_d)
    max_wre_dist = torch.max(torch.norm((all_updates - model_re), dim=1)) / sqrt_d
    return (term_1 + max_wre_dist)

def get_malicious_updates_fang(all_updates, model_re, deviation, n_attackers):
    """Search for a Fang-style Krum attack vector ``model_re - lamda * deviation``.

    `lamda` starts at 100 and is halved while it exceeds 1e-5. For each value, a pool
    of `n_attackers` copies of the candidate followed by four copies of `all_updates`
    is passed to `multi_krum_optimized`. The first (largest) candidate for which
    Multi-Krum selects at least one attacker row is returned. If none succeeds, the
    last candidate tried is returned.

    NOTE(review): the benign pool is replicated four times; presumably this simulates
    a larger benign population — confirm.
    """
    lamda = 100  # compute_lambda_fang(all_updates, model_re, n_attackers)
    stop_below = torch.tensor(1e-5)

    candidate = None
    while lamda > stop_below:
        candidate = (model_re - lamda * deviation)
        pool = torch.cat([torch.stack([candidate] * n_attackers)] + [all_updates] * 4, 0)

        _, chosen = multi_krum_optimized(pool)
        if torch.any(chosen < n_attackers):
            return candidate

        lamda *= 0.5

    return candidate

def lie_attack(user_grads, n_attackers, z):
    """"Little Is Enough" attack: every attacker submits ``mean + z * std`` of the attackers' updates.

    Mean and (unbiased) std are taken coordinate-wise over the first `n_attackers`
    rows. Those rows are then replaced by the crafted vector; the remaining rows are
    appended unchanged.
    """
    attacker_block = user_grads[:n_attackers]
    crafted = attacker_block.mean(dim=0) + z * attacker_block.std(dim=0)

    replicated = crafted.unsqueeze(0).repeat(n_attackers, *([1] * crafted.dim()))
    return torch.cat((replicated, user_grads[n_attackers:]), dim=0)


Code for calculating the z value used by the LIE attack

In [None]:
# Precomputed LIE z values keyed by (num_clients, n_attackers). The 'LIE' attack in
# federated_learning looks these up, so every configuration that is run needs an entry.
z_values = {
    (50, 3): 0.69847,
    (50, 5): 0.7054,
    (50, 8): 0.71904,
    (200, 40): 0.7291,
    (50, 10): 0.72575,
    (50, 12): 0.73891,
    (100, 20): 0.72907,
    (40, 8): 0.72575,
    (100, 5): 0.69497,
    (100, 10): 0.7054,
    (100, 15): 0.71566,
    (100, 25): 0.74215,
    (100, 30): 0.75804,
}

# z value calculation for the LIE attack, with n clients of which m are malicious.
import math

n = 100  # total number of clients
m = 30   # number of malicious clients

s = math.floor(n / 2 + 1) - m
z = (n - m - s) / (n - m)
# NOTE(review): this prints the ratio (n - m - s) / (n - m) = 0.7 for (100, 30), while
# z_values[(100, 30)] is 0.75804. The table entries were evidently derived with an
# extra transform (possibly a normal CDF) -- confirm before adding new entries.
print(z)

0.7


Federated Learning Training

In [None]:

# Multiplier passed to get_mu_pairwise_distance by the XFED attack in federated_learning;
# it is also embedded in the accuracy-log filename.
MU_MULTIPLIER: int = 3

def train_local_model(client_id, dataset, clients_data, global_model, batch_size, criterion, device, optimizer):
    """Train a private copy of `global_model` for one client and return its flat parameters.

    The client's loader comes from ``get_data_loaders(..., client_id)``. The model is
    deep-copied onto `device` and trained with SGD (lr=0.5, momentum=0.9) when
    ``optimizer == 'SGD'``, otherwise Adam (lr=0.001). Gradients are clipped to norm
    1.0 at every step.

    Returns:
        (client_id, 1-D tensor of all model parameters concatenated, on `device`).
    """
    loader = get_data_loaders(clients_data, dataset, batch_size, client_id)

    model = deepcopy(global_model).to(device)
    opt = (
        optim.SGD(model.parameters(), lr=0.5, momentum=0.9)
        if optimizer == 'SGD'
        else torch.optim.Adam(model.parameters(), lr=0.001)
    )

    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        opt.zero_grad()
        loss = criterion(model(batch_x), batch_y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        opt.step()

    # Flatten parameters (detached via .data) for aggregation.
    flat_params = torch.cat([p.data.view(-1) for p in model.parameters()])

    del model, loader
    torch.cuda.empty_cache()

    return client_id, flat_params


# Aggregation rules accepted by federated_learning (validated before any training starts).
_AGGREGATION_METHODS = ('MEAN', 'MEDIAN', 'KRUM', 'TR-MEAN', 'CC', 'SignGuard')


def _collect_round_updates(futures):
    """Gather (client_id, flat_params) results and stack them on `device` in client-id order."""
    by_client = {}
    for future in as_completed(futures):
        client_id, local_params = future.result()
        # Clients may train on different GPUs; move every vector to the aggregation device
        # before stacking (torch.stack needs all inputs on one device).
        by_client[client_id] = local_params.to(device)
    # Order rows by client id, not completion order, so that "the first n_attackers rows"
    # always refers to the same clients.
    return torch.stack([by_client[client_id] for client_id in sorted(by_client)])


def _apply_attack(local_models_data, attack_type, n_attackers, num_clients, epoch,
                  global_models, global_model_data):
    """Return the round's update matrix with the first `n_attackers` rows attacked per `attack_type`."""
    if attack_type.startswith('XFED'):
        # Loop-invariant: the stored global models do not change during this loop.
        pairwise_distance = (
            get_mu_pairwise_distance(global_models, MU_MULTIPLIER=MU_MULTIPLIER)
            if len(global_models) > 1 else None
        )
        for local_machine in range(n_attackers):
            if attack_type == 'XFED_unit_vec':
                deviation = local_models_data[local_machine] / torch.norm(local_models_data[local_machine])
            elif attack_type == 'XFED_sign':
                sgn_vec = torch.sign(local_models_data[local_machine])
                deviation = sgn_vec / torch.norm(sgn_vec)
            else:
                raise ValueError("Invalid attack type")

            if pairwise_distance is not None:
                global_distance = torch.norm(local_models_data[local_machine] - global_model_data)
                # Earlier variants used max(global_distance, pairwise_distance) or their average.
                mu = pairwise_distance
                print(f"For round {epoch} and advNumber {local_machine}, mu: {mu}, global_distance: {global_distance}, pairwise_distance: {pairwise_distance}")
            else:
                mu = torch.tensor(1.0)

            delta = mu * deviation
            if epoch == 0:
                local_models_data[local_machine] -= delta
            else:
                local_models_data[local_machine] = global_model_data - delta
        return local_models_data

    if attack_type.startswith('C_XFED'):
        return xfed_c(local_models_data, n_attackers, attack_type, len(global_models), global_model_data, global_models)

    is_known = attack_type.startswith(('VIRAT', 'LIE')) or attack_type in ('FANG_TR_MEAN', 'FANG_KRUM')
    if not is_known:
        raise ValueError("Invalid attack type")
    if n_attackers <= 0:
        # No attackers: a valid attack type with nothing to replace.
        return local_models_data

    if attack_type.startswith('VIRAT'):
        return virat_min_max(local_models_data, n_attackers, attack_type, epoch=epoch)

    if attack_type.startswith('LIE'):
        key = (num_clients, n_attackers)
        if key not in z_values:
            raise KeyError(f"No LIE z value for (num_clients, n_attackers)={key}; add an entry to z_values")
        return lie_attack(local_models_data, n_attackers, z_values[key])

    if attack_type == 'FANG_TR_MEAN':
        return fang_attack_trmean_partial(local_models_data, n_attackers)

    # FANG_KRUM: local_models_data is a tensor, so slice it directly (no torch.stack)
    # and rebuild the matrix with torch.cat (a list + tensor concatenation is invalid).
    attacker_grads = local_models_data[:n_attackers]
    agg_grads = torch.mean(attacker_grads, 0)
    deviation = torch.sign(agg_grads)
    mal_update = get_malicious_updates_fang(attacker_grads, agg_grads, deviation, n_attackers)
    mal_rows = mal_update.unsqueeze(0).expand(n_attackers, *mal_update.shape)
    return torch.cat((mal_rows, local_models_data[n_attackers:]), dim=0)


def _aggregate(local_models_data, aggregation):
    """Combine the (num_clients, num_params) update matrix into one global parameter vector."""
    if aggregation == 'MEAN':
        return torch.mean(local_models_data, dim=0)
    if aggregation == 'MEDIAN':
        return torch.median(local_models_data, dim=0)[0]
    if aggregation == 'KRUM':
        return multi_krum_optimized(local_updates=local_models_data)[0]
    if aggregation == 'TR-MEAN':
        return tr_mean(local_models_data)
    if aggregation == 'CC':
        return Clippedclustering(local_models_data)
    if aggregation == 'SignGuard':
        return SignGuard(local_models_data)
    raise ValueError("Invalid aggregation method")


def _evaluate_accuracy(global_model, testloader):
    """Return the top-1 accuracy (in percent) of `global_model` on `testloader`."""
    global_model.eval()
    global_model.to(device)
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = global_model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total


def federated_learning(num_clients=10, aggregation='MEAN', n_attackers=3, attack_type='XFED_unit_vec', dataset='CIFAR10', n_round=1000, batch_size=64, optim='SGD', iid_degree=0.1):
    """Main federated learning loop.

    Each round:
      1. every client trains a copy of the global model in a thread pool;
      2. the first `n_attackers` client rows (in client-id order) are replaced
         according to `attack_type`;
      3. the rows are aggregated with `aggregation`, and the result is loaded into
         the global model.

    Accuracy is evaluated every 20 rounds and on every round in the final 20% of
    training. It is printed and appended to a log file under `data_root`.

    Relies on notebook globals: load_model, THREAD_NUMBER, devices, num_gpus, device,
    data_root, z_values, MU_MULTIPLIER.

    Raises:
        ValueError: for an unknown `aggregation` (checked before training) or attack
            type, or if the aggregate contains NaN.
    """
    if aggregation not in _AGGREGATION_METHODS:
        raise ValueError("Invalid aggregation method")

    global_model, train_data, testloader = load_model(dataset)
    criterion = nn.CrossEntropyLoss()

    # Step 1: Split the dataset among clients.
    if dataset == 'MNIST':
        clients_data = create_noniid_dataset(train_data, num_clients, degree_noniid_p=iid_degree, num_classes=10)
    elif dataset == 'SYMBIPREDICT':
        clients_data = create_noniid_dataset_symbipredict(train_data, num_clients, degree_noniid_p=iid_degree, num_classes=10)
    else:
        clients_data = create_noniid_dataset_purchase(train_data, num_clients, degree_noniid_p=iid_degree, num_classes=10)

    global_models, global_model_data = [], []
    eval_from_round = int(n_round * 0.80)  # evaluate on every round in the last 20%

    with ThreadPoolExecutor(max_workers=THREAD_NUMBER) as executor:
        for epoch in range(n_round):
            global_model.train()

            # Keep a sliding window of at most the 5 most recent global models.
            if len(global_models) > 4:
                del global_models[0]

            futures = [
                executor.submit(
                    train_local_model,
                    client_id,
                    train_data,
                    clients_data,
                    global_model,
                    batch_size,
                    criterion,
                    devices[client_id % len(devices)],  # round-robin over available devices
                    optim,
                )
                for client_id in range(num_clients)
            ]
            local_models_data = _collect_round_updates(futures)

            for i in range(torch.cuda.device_count()):
                torch.cuda.set_device(i)
                torch.cuda.empty_cache()
            if num_gpus > 0:
                torch.cuda.set_device(device)
            print(f'For round {epoch}, training done')
            gc.collect()

            local_models_data = _apply_attack(
                local_models_data, attack_type, n_attackers, num_clients, epoch,
                global_models, global_model_data,
            )
            print(f'For round {epoch}, attack done, Lenght of local_models_data:', len(local_models_data))

            global_model_data = _aggregate(local_models_data, aggregation)
            if torch.isnan(global_model_data).any():
                raise ValueError("NaN detected in model aggregation")

            # Load the aggregated flat vector back into the global model's parameters.
            start_idx = 0
            with torch.no_grad():
                for param in global_model.parameters():
                    param_size = param.numel()
                    param.copy_(global_model_data[start_idx:start_idx + param_size].view(param.shape))
                    start_idx += param_size

            global_models.append(global_model_data)
            print(f'For round {epoch}, aggregation done')

            if epoch >= eval_from_round or epoch % 20 == 0:
                accuracy = _evaluate_accuracy(global_model, testloader)
                print(f'Time {datetime.now()}: Accuracy on round {epoch}, total {num_clients}, attackers {n_attackers}, attack_type {attack_type}, aggregation {aggregation} is: {accuracy:.2f} %')

                file_path = os.path.join(data_root, f'accuracy_{dataset}_{aggregation}_{attack_type}_{n_attackers}_mu{MU_MULTIPLIER}_{iid_degree}_log.txt')
                with open(file_path, 'a') as f:
                    f.write(f'Time {datetime.now()}: Accuracy on round {epoch}, dataset {dataset}, total {num_clients}, attackers {n_attackers}, attack_type {attack_type}, aggregation {aggregation} is: {accuracy:.2f} %\n')

            del local_models_data
            torch.cuda.empty_cache()
            gc.collect()

    # Final cleanup after training.
    del global_model, train_data, testloader, global_models, criterion



Example Execution

In [None]:

# Example execution: sweep attack type x aggregation rule x attacker count x non-IID degree.
# NOTE(review): z_values has no (200, 20) entry, so the 'LIE' runs below
# (num_clients=200, attackers=20) will fail until one is added -- confirm the intended value.
for attack_type in ['VIRAT_unit_vec', 'LIE']: # 'C_XFED_unit_vec', 'XFED_unit_vec', 'XFED_sign', 'VIRAT_unit_vec', 'LIE', 'FANG_TR_MEAN'
    for agg in ['MEDIAN', 'KRUM', 'CC', 'SignGuard']: # 'MEAN', 'MEDIAN', 'KRUM', 'TR-MEAN', 'CC', 'SignGuard'
        for attackers in [20]:
            for iid_degree in [0.1, 0.3, 0.5, 0.7, 0.9]:
                # Configurations used for other datasets:
                # federated_learning(num_clients=200, n_attackers=attackers, aggregation=agg, n_round=400, dataset='EMNIST', attack_type=attack_type, batch_size=256, optim="SGD", iid_degree=iid_degree)
                # federated_learning(num_clients=62, n_attackers=attackers, aggregation=agg, n_round=500, dataset='FEMNIST', attack_type=attack_type, batch_size=256, optim="SGD", iid_degree=iid_degree)
                # federated_learning(num_clients=40, n_attackers=attackers, aggregation=agg, n_round=250, dataset='FashionMNIST', attack_type=attack_type, batch_size=256, optim="SGD", iid_degree=iid_degree)
                # federated_learning(num_clients=50, n_attackers=attackers, aggregation=agg, n_round=255, dataset='CIFAR10', attack_type=attack_type, batch_size=250, iid_degree=iid_degree)
                # federated_learning(num_clients=100, n_attackers=attackers, aggregation=agg, n_round=1000, dataset='SVHN', attack_type=attack_type, batch_size=64, optim="SGD", iid_degree=iid_degree)
                # federated_learning(num_clients=100, n_attackers=attackers, aggregation=agg, n_round=275, dataset='MNIST', attack_type=attack_type, batch_size=256, optim="SGD", iid_degree=iid_degree)
                # federated_learning(num_clients=100, n_attackers=attackers, aggregation=agg, n_round=500, dataset='PURCHASE', attack_type=attack_type, batch_size=128, optim="SGD", iid_degree=iid_degree)
                # federated_learning(num_clients=40, n_attackers=attackers, aggregation=agg, n_round=300, dataset='CIFAR100', attack_type=attack_type, batch_size=250, optim="Adam", iid_degree=iid_degree)
                # Pass the loop's iid_degree; without it all five runs used the default 0.1
                # and appended to the same log file.
                federated_learning(num_clients=200, n_attackers=attackers, aggregation=agg, n_round=50, dataset='SYMBIPREDICT', attack_type=attack_type, batch_size=256, optim="SGD", iid_degree=iid_degree)





Client 0: {0: 2, 1: 1, 3: 2, 5: 1, 6: 1, 7: 1, 9: 1, 13: 1, 14: 1, 17: 1, 20: 1, 24: 1, 25: 2, 29: 1, 30: 1, 31: 1, 40: 1}
Client 1: {0: 2, 4: 3, 11: 1, 12: 1, 17: 1, 19: 2, 21: 1, 23: 1, 24: 1, 26: 1, 28: 1, 31: 1, 34: 1, 36: 1, 38: 1, 39: 1}
Client 2: {0: 2, 2: 2, 4: 1, 5: 1, 7: 2, 8: 1, 11: 1, 18: 1, 19: 1, 22: 1, 26: 1, 27: 1, 29: 1, 30: 2, 31: 1, 33: 2, 34: 1, 35: 1, 36: 1, 39: 1, 40: 1}
Client 3: {0: 3, 2: 1, 7: 1, 9: 1, 19: 2, 24: 1, 29: 1, 30: 3, 31: 2, 39: 2, 32: 1, 15: 1}
Client 19: {0: 1, 4: 2, 6: 4, 8: 1, 10: 1, 13: 1, 14: 1, 20: 1, 21: 5, 28: 1, 31: 2, 33: 1, 38: 1, 40: 1}
Client 112: {0: 2, 4: 1, 6: 1, 8: 2, 12: 1, 21: 1, 22: 1, 27: 1, 28: 3, 29: 1, 39: 2, 17: 1, 33: 1, 2: 1}
Client 36: {0: 1, 1: 2, 3: 2, 8: 1, 9: 2, 11: 1, 13: 1, 14: 1, 28: 1, 30: 1, 37: 2, 39: 1, 4: 1, 26: 1, 32: 1}
Client 50: {0: 1, 3: 2, 6: 1, 7: 1, 8: 1, 9: 1, 10: 1, 12: 2, 16: 1, 18: 1, 19: 1, 23: 1, 25: 1, 26: 2, 27: 1, 31: 1, 32: 1, 34: 2}
Client 6: {0: 1, 1: 4, 7: 2, 10: 1, 16: 1, 20: 1, 21: 1, 2

In [None]:
# from collections import Counter

# def print_dataset_statistics(dataset):
#     """Prints dataset statistics including sample count, classes, and sample count per class.

#     Args:
#         dataset: The PyTorch dataset object.
#     """
#     labels = [label for _, label in dataset]  # Extract all labels
#     label_counts = Counter(labels)            # Count label occurrences

#     num_samples = len(dataset)                # Total number of samples
#     num_classes = len(label_counts)            # Number of unique classes

#     print(f"Dataset Statistics:")
#     print(f"  Number of samples: {num_samples}")
#     print(f"  Number of classes: {num_classes}")
#     print(f"  Sample count per class:")
#     for label, count in label_counts.items():
#         print(f"    Label {label}: {count} samples")

# # Example usage:
# global_model, train_data, testloader = load_model('FEMNIST')
# print_dataset_statistics(train_data)  # Assuming 'train_data' is your dataset