In [2]:
# Function to determine the data root directory
import os
def get_data_root():
    """Return the directory under which datasets and outputs are stored.

    When the ``COLAB_GPU`` environment variable is set, Google Drive is
    mounted at ``/content/drive`` and a folder inside it is returned.
    Otherwise the relative ``./data/`` directory is returned.
    """
    if 'COLAB_GPU' not in os.environ:
        return './data/'

    # Imported lazily: this branch is only reached when COLAB_GPU is set.
    from google.colab import drive
    drive.mount('/content/drive')
    return '/content/drive/MyDrive/PhD/XFED result/Result XFED log/colab output/'

# Resolve the data root once; `data_root` is a module-level global that later
# cells read (e.g. as the `root` argument of the dataset constructors).
data_root = get_data_root()

Mounted at /content/drive


In [3]:
import os
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"  # or ":16:8"

import sys
import subprocess

def install(package):
    """Install ``package`` with pip, using the interpreter that runs this kernel.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    pip_command = [sys.executable, "-m", "pip", "install", package]
    subprocess.check_call(pip_command)

# Function to attempt to import a module, and install it if not present
def try_import(module_name, package_name=None):
    """Import ``module_name``, pip-installing it first when it is missing.

    Args:
        module_name: Import name; may be dotted (e.g. ``'matplotlib.pyplot'``).
        package_name: Name handed to pip when it differs from the import name
            (e.g. ``'scikit-learn'`` for ``sklearn``). Defaults to the
            top-level component of ``module_name``.

    Returns:
        The imported module. For a dotted name this is the submodule itself;
        the built-in ``__import__`` would return the top-level package
        instead, which is why ``importlib.import_module`` is used.

    Raises:
        subprocess.CalledProcessError: if the pip installation fails.
        ImportError: if the module still cannot be imported after installing.
    """
    import importlib

    try:
        return importlib.import_module(module_name)
    except ImportError:
        if package_name is None:
            # pip works on distributions, not submodules.
            package_name = module_name.split('.')[0]
        print(f"Installing {package_name}...")
        install(package_name)
        # Let the import system discover the freshly installed files.
        importlib.invalidate_caches()
        return importlib.import_module(module_name)

# Standard library imports (no need to install)
import logging
from datetime import datetime
from copy import deepcopy
import gc
import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Iterable, Union, Optional

# Third-party imports
torch = try_import('torch')
torchvision = try_import('torchvision')
import torchvision.transforms as transforms
import torch.optim as optim
import numpy as np
np = try_import('numpy')
# Import torch.nn as nn
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, Subset, Dataset, ConcatDataset
import torchvision.models as models
from torch.nn.functional import tanh, softmax

from torchvision.datasets import utils
from PIL import Image
import os.path
import shutil


# sklearn imports
sklearn = try_import('sklearn', 'scikit-learn')
from sklearn.model_selection import train_test_split
from sklearn.cluster import AgglomerativeClustering, KMeans
from sklearn.metrics import silhouette_score
from sklearn.metrics.pairwise import cosine_distances,euclidean_distances
from sklearn.metrics import pairwise_distances
import sklearn.metrics.pairwise as smp

import hdbscan

# Other third-party imports
plt = try_import('matplotlib.pyplot', 'matplotlib')
pd = try_import('pandas')



# Seed PyTorch's global RNG. Strict determinism is not enforced: the call on the
# next line is left disabled.
# torch.use_deterministic_algorithms(True, warn_only=True)
torch.manual_seed(0)

# Device configuration
# Count the CUDA devices visible to this process.
num_gpus = torch.cuda.device_count()
print(f"Number of GPUs available: {num_gpus}")

# Pick the primary device: the preferred GPU index, clamped to the highest
# index that actually exists, or the CPU when no GPU is visible.
if num_gpus > 0:
    desired_gpu_index = 3  # Preferred GPU; clamped on the next line if fewer GPUs exist
    device_index = min(desired_gpu_index, num_gpus - 1)  # Clamp to available range
    device = torch.device(f"cuda:{device_index}")
    torch.cuda.set_device(device)  # Make it the current CUDA device
    print(f"Using GPU: {device}")
else:
    device = torch.device("cpu")
    print("No GPUs available, using CPU.")


# Build `devices`: the names of every visible GPU except the indices listed in
# `unwanted_device_indices`, falling back to ['cpu'] when that list is empty.
unwanted_device_indices = []
available_device_indices = list(range(num_gpus))
devices = [f'cuda:{i}' for i in available_device_indices if i not in unwanted_device_indices]
if not devices:
    devices = ['cpu']
    # raise RuntimeError("Desired GPUs are not available.")
print(f"Devices: {devices}")




import multiprocessing

# Get the number of available CPU cores
num_cores = multiprocessing.cpu_count()

# THREAD_NUMBER is two per entry in `devices`, capped at the number of CPU cores.
THREAD_NUMBER = min(num_cores, 2*(len(devices)))
# THREAD_NUMBER = 20 # num_cores

print(f"Number of CPU cores available: {num_cores}")
print(f"THREAD_NUMBER set to: {THREAD_NUMBER}")



# Check GPU information
def check_gpu():
    """Print the output of ``nvidia-smi``, or a short notice when that fails."""
    try:
        completed = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, check=True)
        report = completed.stdout.decode('utf-8')
        print(report)
    except Exception:
        # Best effort only: any failure is reported with the same message.
        print('Not connected to a GPU or nvidia-smi not found.')

check_gpu()

# Check CPU information
def check_cpu():
    """Print the output of ``lscpu``, or a short notice when that fails."""
    try:
        completed = subprocess.run(['lscpu'], stdout=subprocess.PIPE, check=True)
        report = completed.stdout.decode('utf-8')
        print(report)
    except Exception:
        # Best effort only: any failure is reported with the same message.
        print('Could not retrieve CPU information.')

check_cpu()

Number of GPUs available: 1
Using GPU: cuda:0
Devices: ['cuda:0']
Number of CPU cores available: 12
THREAD_NUMBER set to: 2
Tue Apr 15 06:55:43 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   36C    P0             46W /  400W |       5MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+---------------

# Function to determine the data root directory

# Model Definition for different datasets

In [4]:
class FashionMNISTAlexNet(nn.Module):
    """AlexNet-style CNN for single-channel images with 10 output classes.

    The first classifier layer expects a flattened 256 x 6 x 6 feature map,
    which is what the convolutional stack yields for 227 x 227 inputs. The
    network ends in ``LogSoftmax``, so ``forward`` returns log-probabilities.
    """

    def __init__(self):
        super().__init__()

        conv_stack = [
            nn.Conv2d(1, 96, kernel_size=11, stride=4, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*conv_stack)

        dense_head = [
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, 10),
            nn.LogSoftmax(dim=1),
        ]
        self.classifier = nn.Sequential(*dense_head)

    def forward(self, x):
        feature_maps = self.features(x)
        # Collapse everything except the batch axis before the dense head.
        flat = feature_maps.view(feature_maps.size(0), -1)
        return self.classifier(flat)

class FeatureNorm(nn.Module):
    """Learnable input normalisation: one shared scale and a per-feature shift."""

    def __init__(self, feature_shape):
        super().__init__()
        # A single scalar gain, and one additive offset per input feature.
        self.gamma = nn.Parameter(torch.ones(1))
        self.beta = nn.Parameter(torch.zeros(1, feature_shape))

    def forward(self, x):
        # The einsum sums over gamma's single entry, i.e. it multiplies the
        # 2-D input by that scalar.
        scaled = torch.einsum('ni, j->ni', x, self.gamma)
        return scaled + self.beta

class purchase_fully_connected_IN(nn.Module):
    """Bias-free 600 -> 1024 -> 100 -> ``num_classes`` tanh MLP behind a FeatureNorm.

    ``forward`` returns raw logits (no activation on the output layer).
    """

    def __init__(self, num_classes):
        super().__init__()
        # Submodules are registered in the order fc1, fc2, fc3, norm.
        self.fc1 = nn.Linear(600, 1024, bias=False)
        self.fc2 = nn.Linear(1024, 100, bias=False)
        self.fc3 = nn.Linear(100, num_classes, bias=False)
        self.norm = FeatureNorm(600)

    def forward(self, x):
        hidden = self.norm(x)
        # Two tanh hidden layers, then a linear read-out.
        for layer in (self.fc1, self.fc2):
            hidden = torch.tanh(layer(hidden))
        return self.fc3(hidden)

class Purchase(torch.utils.data.Dataset):
    """Tabular dataset read from a single CSV file with a fixed 80/20 split.

    Column 0 of the CSV holds the label, which is shifted down by one on load
    (presumably the file uses 1-based labels — confirm against the data). The
    remaining columns are the features, cast to float32.

    Args:
        root: Path of the CSV file, passed directly to ``pd.read_csv``.
        train: If True expose the 80% training split, otherwise the 20% test split.
        download: Stored on the instance only; nothing is downloaded.
        transform: Optional callable applied to each feature vector in
            ``__getitem__``.

    Note: every instance re-reads and re-splits the CSV. The split is
    reproducible because ``random_state`` is fixed.
    """

    def __init__(self, root=data_root + 'dataset_purchase', train=True, download=True, transform=None):
        self.images = []
        self.root = root
        self.targets = []
        self.train = train
        self.download = download
        self.transform = transform

        x_train, x_test, y_train, y_test = self._train_test_split()

        if self.train:
            self._setup_dataset(x_train, y_train)
        else:
            self._setup_dataset(x_test, y_test)

    def _train_test_split(self):
        """Read the CSV and return ``(x_train, x_test, y_train, y_test)``."""
        df = pd.read_csv(self.root)

        img_names = df.iloc[:, 1:].to_numpy(dtype='f')
        img_label = df.iloc[:, 0].to_numpy() - 1
        x_train, x_test, y_train, y_test = train_test_split(img_names, img_label, train_size=0.8,
                                                            random_state=1)
        return x_train, x_test, y_train, y_test

    def _setup_dataset(self, x, y):
        """Store the selected split's features and labels on the instance."""
        self.images = x
        self.targets = y

    def __len__(self):
        return len(self.images)

    def __getitem__(self, item):
        img = self.images[item]
        label = self.targets[item]
        # Honour the transform passed to the constructor.
        if self.transform is not None:
            img = self.transform(img)
        return img, label

class ThreeLayerDNN(nn.Module):
    """Flatten -> Linear -> ReLU -> Linear -> ReLU -> Linear; returns raw logits."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Run the input through every stage in order.
        pipeline = (self.flatten, self.fc1, self.relu1, self.fc2, self.relu2, self.fc3)
        for stage in pipeline:
            x = stage(x)
        return x

class FourLayerDNN(nn.Module):
    """Fully connected 3072 -> 1024 -> 512 -> 256 -> 10 ReLU network.

    The input is flattened first, so 3 x 32 x 32 images are accepted directly.
    ``forward`` returns raw logits for 10 classes.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        # Three hidden Linear+ReLU pairs followed by the output layer.
        self.fc1 = nn.Linear(3 * 32 * 32, 1024)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(1024, 512)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(512, 256)
        self.relu3 = nn.ReLU()
        self.fc4 = nn.Linear(256, 10)

    def forward(self, x):
        out = self.flatten(x)
        hidden_blocks = (
            (self.fc1, self.relu1),
            (self.fc2, self.relu2),
            (self.fc3, self.relu3),
        )
        for linear, activation in hidden_blocks:
            out = activation(linear(out))
        return self.fc4(out)

class InputNorm(nn.Module):
    """Learnable input normalisation: per-channel scale plus per-pixel shift.

    Args:
        num_channel: Number of image channels; ``gamma`` holds one gain each.
        num_feature: Side length of the square images; ``beta`` has shape
            ``(num_channel, num_feature, num_feature)``.

    ``forward`` returns ``gamma * x + beta`` with ``gamma`` applied along the
    channel axis.
    """

    def __init__(self, num_channel, num_feature):
        super().__init__()
        self.num_channel = num_channel
        self.gamma = nn.Parameter(torch.ones(num_channel))
        self.beta = nn.Parameter(torch.zeros(num_channel, num_feature, num_feature))

    def forward(self, x):
        if self.num_channel == 1:
            # A single gain broadcasts over the whole input, so this branch
            # also accepts inputs without an explicit channel axis.
            return self.gamma * x + self.beta
        # Any other channel count: scale the third-from-last axis channel by
        # channel. A forward pass must always yield a tensor, never None.
        return torch.einsum('...ijk, i->...ijk', x, self.gamma) + self.beta

class mnist_fully_connected_IN(nn.Module):
    """Bias-free 784 -> 600 -> 100 -> ``num_classes`` ReLU MLP behind an InputNorm.

    The input is normalised as a 1 x 28 x 28 image, flattened to 784 values,
    and mapped to raw logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.hidden1 = 600
        self.hidden2 = 100
        self.fc1 = nn.Linear(28 * 28, self.hidden1, bias=False)
        self.fc2 = nn.Linear(self.hidden1, self.hidden2, bias=False)
        self.fc3 = nn.Linear(self.hidden2, num_classes, bias=False)
        # Registered as a submodule; forward() calls F.relu instead.
        self.relu = nn.ReLU(inplace=False)
        self.norm = InputNorm(1, 28)

    def forward(self, x):
        normalised = self.norm(x)
        hidden = normalised.view(-1, 28 * 28)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)

class CHMNISTDataset(Dataset):
    """Image-folder dataset whose labels are encoded in the file names.

    Every entry of ``image_folder`` is treated as an image. The label is the
    integer between the last underscore of the file name and the first dot
    after it (e.g. ``tile_0042_3.png`` -> 3).

    Args:
        image_folder: Directory containing the image files.
        transform: Optional callable applied to the grayscale PIL image.
    """

    def __init__(self, image_folder, transform=None):
        self.image_folder = image_folder
        self.transform = transform
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping stable across runs and machines.
        self.image_paths = [
            os.path.join(image_folder, file_name)
            for file_name in sorted(os.listdir(image_folder))
        ]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('L')  # Convert to grayscale
        # Parse the label from the file name only, so an underscore in a
        # directory name cannot leak into it.
        file_name = os.path.basename(image_path)
        label = int(file_name.split('_')[-1].split('.')[0])

        if self.transform:
            image = self.transform(image)

        return image, label


class HARLogisticRegression(nn.Module):
    """
    Multinomial logistic regression: one linear layer from ``input_dim``
    features to ``num_classes`` raw logits.
    """

    def __init__(self, input_dim, num_classes=6):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        # Logits only; no softmax or sigmoid is applied here.
        logits = self.linear(x)
        return logits


class LogisticRegressionModel(nn.Module):
    """Binary logistic regression: a single linear unit followed by a sigmoid."""

    def __init__(self, input_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, x):
        # Squash the linear score into (0, 1).
        score = self.linear(x)
        return score.sigmoid()

class FEMNISTDataset(torchvision.datasets.MNIST):
    """FEMNIST loaded from pre-processed ``.pt`` files under ``<root>/FEMNIST``.

    The processed train/test files are indexed as ``(data, targets, users)``.
    When either is missing the archive is downloaded if ``download=True``;
    otherwise a ``RuntimeError`` is raised. Items are returned as
    ``(image, target)`` with the image rebuilt as a 28 x 28 grayscale PIL image.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        # Deliberately skips MNIST.__init__ and initialises its parent class.
        super(torchvision.datasets.MNIST, self).__init__(root, transform=transform, target_transform=target_transform)
        self.download = download
        self.download_link = 'https://media.githubusercontent.com/media/GwenLegate/femnist-dataset-PyTorch/main/femnist.tar.gz'
        self.file_md5 = 'a8a28afae0e007f1acb87e37919a21db'
        self.train = train
        self.root = root
        self.training_file = f'{self.root}/FEMNIST/processed/femnist_train.pt'
        self.test_file = f'{self.root}/FEMNIST/processed/femnist_test.pt'
        self.user_list = f'{self.root}/FEMNIST/processed/femnist_user_keys.pt'

        have_both_splits = (
            os.path.exists(f'{self.root}/FEMNIST/processed/femnist_test.pt')
            and os.path.exists(f'{self.root}/FEMNIST/processed/femnist_train.pt')
        )
        if not have_both_splits:
            if not self.download:
                raise RuntimeError('Dataset not found, set parameter download=True to download')
            self.dataset_download()

        data_file = self.training_file if self.train else self.test_file

        data_targets_users = torch.load(data_file)
        self.data = torch.Tensor(data_targets_users[0])
        self.targets = torch.Tensor(data_targets_users[1])
        self.users = data_targets_users[2]
        self.user_ids = torch.load(self.user_list)

    def __getitem__(self, index):
        flat_pixels = self.data[index]
        target = int(self.targets[index])

        # Flattened sample -> 28x28 uint8 array -> grayscale PIL image.
        pixels = flat_pixels.view(28, 28).numpy().astype(np.uint8)
        img = Image.fromarray(pixels, mode='L')

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def dataset_download(self):
        """Download and extract the archive, then move the three .pt files to processed/."""
        raw_dir = f'{self.root}/FEMNIST/raw/'
        processed_dir = f'{self.root}/FEMNIST/processed/'
        for path in (raw_dir, processed_dir):
            if not os.path.exists(path):
                os.makedirs(path)

        # The archive keeps the last path component of the URL as its file name.
        filename = self.download_link.rsplit('/', 1)[-1]
        utils.download_and_extract_archive(self.download_link, download_root=raw_dir, filename=filename, md5=self.file_md5)

        for file in ('femnist_train.pt', 'femnist_test.pt', 'femnist_user_keys.pt'):
            shutil.move(os.path.join(raw_dir, file), processed_dir)

Loading the model and data for each supported dataset (FashionMNIST, FashionMNIST_3DNN, CIFAR-10, PURCHASE, CHMNIST, HAR, EMNIST, MNIST, CIFAR-100, SYMBIPREDICT)

In [5]:
from sklearn.preprocessing import LabelEncoder, StandardScaler
class SymbiPredictDataset(Dataset):
    """Wraps indexable features and labels, converting each item to tensors on access."""

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Conversion happens per item: float32 features, int64 label.
        sample = torch.tensor(self.data[idx], dtype=torch.float32)
        target = torch.tensor(self.labels[idx], dtype=torch.long)
        return sample, target

class TabularNet(nn.Module):
    """MLP for tabular data: ``input_dim`` -> 128 -> 64 -> ``num_classes`` logits."""

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, num_classes)

    def forward(self, x):
        # ReLU after each hidden layer; the output layer is left linear.
        hidden = self.fc1(x).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden)


class HARLocalDataset(Dataset):
    """
    A simple PyTorch Dataset for a subset of HAR data (features + labels).

    Args:
        features: Array-like of feature rows; stored as a float32 tensor.
        labels: Array-like of labels. Numeric labels are used as given.
            String/object labels are mapped to integer codes by their sorted
            unique values.

    Note: the codes for string labels come from the labels passed to *this*
    instance only. Two instances built from subsets with different label sets
    get different mappings, so encode the labels globally before splitting
    when several instances must agree.
    """
    def __init__(self, features, labels):
        # Convert to PyTorch tensors
        self.features = torch.tensor(features, dtype=torch.float32)

        # Accept plain Python sequences as well as NumPy arrays.
        if not isinstance(labels, torch.Tensor):
            labels = np.asarray(labels)

        # Encode non-numeric labels (object, unicode or bytes dtype) as
        # 0..K-1 following the sorted order of the unique values.
        if isinstance(labels, np.ndarray) and labels.dtype.kind in ('O', 'U', 'S'):
            _, labels = np.unique(labels, return_inverse=True)

        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]


def load_model(dataset: str):
    """Build the model and the train/test data for the given dataset name.

    Supported names: 'FashionMNIST', 'FashionMNIST_3DNN', 'CIFAR10',
    'PURCHASE', 'CHMNIST', 'HAR', 'EMNIST', 'MNIST', 'CIFAR100',
    'SYMBIPREDICT'.

    Args:
        dataset: One of the names above (case-sensitive).

    Returns:
        A tuple ``(model, train_data, testloader)``: the model moved to the
        module-level ``device``, the training ``Dataset`` (not wrapped in a
        DataLoader), and a ``DataLoader`` over the test split.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    if dataset == 'FashionMNIST':
        # Upscaled to 227x227, the input size FashionMNISTAlexNet's classifier expects.
        transform = transforms.Compose([
            transforms.Resize((227, 227)),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        train_data = torchvision.datasets.FashionMNIST(root=data_root, train=True, download=True, transform=transform)
        test_data = torchvision.datasets.FashionMNIST(root=data_root, train=False, download=True, transform=transform)
        testloader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=8)
        model = FashionMNISTAlexNet().to(device)

    elif dataset == 'FashionMNIST_3DNN':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        train_data = torchvision.datasets.FashionMNIST(root=data_root, train=True, download=True, transform=transform)
        test_data = torchvision.datasets.FashionMNIST(root=data_root, train=False, download=True, transform=transform)
        testloader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=8)
        model = ThreeLayerDNN(input_size=784, hidden_size=512, output_size=10).to(device)

    elif dataset == 'CIFAR10':
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        train_data = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform)
        test_data = torchvision.datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform)
        testloader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=8)

        # Pretrained AlexNet with a freshly initialised, narrower classifier head.
        model = models.alexnet(pretrained=True)
        model.classifier[1] = nn.Linear(9216, 4096)
        model.classifier[4] = nn.Linear(4096, 1024)
        model.classifier[6] = nn.Linear(1024, 10)
        model = model.to(device)

    elif dataset == 'PURCHASE':
        train_data = Purchase(train=True, download=True)
        test_data = Purchase(train=False, download=True)
        testloader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=8)

        model = purchase_fully_connected_IN(100).to(device)

    elif dataset == 'CHMNIST':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
        # NOTE(review): torchvision's MNIST is used as a stand-in for CHMNIST,
        # and this branch reads from './data' rather than `data_root`.
        train_data = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
        test_data = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
        testloader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=False, num_workers=2)

        # Swap in the single-channel first conv BEFORE moving to the device, so
        # the new layer does not stay on the CPU while the rest is on the GPU.
        # NOTE(review): the classifier head is left as loaded — confirm its
        # output width is what the training loop expects.
        model = models.mobilenet_v2(pretrained=True)
        model.features[0][0] = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        model = model.to(device)

    elif dataset == 'HAR':
        # CSV with 563 columns: 561 features, then 'subject', then 'Activity'.
        har_csv_path = os.path.join(data_root, 'har_all_in_one.csv')  # Adjust path as needed

        # header=0: the first row holds the column names.
        df = pd.read_csv(har_csv_path, header=0)

        X = df.iloc[:, :561].values      # shape (N, 561)
        subjects = df.iloc[:, 561].values
        # Encode the activity labels ONCE over the whole table. Every subject
        # then shares the same label -> index mapping, and the indices are
        # 0..num_classes-1 as CrossEntropyLoss requires, whether the column
        # holds strings or integers.
        activity = LabelEncoder().fit_transform(df.iloc[:, 562].values)

        # Standardize all features
        scaler = StandardScaler()
        X = scaler.fit_transform(X)

        # Per-subject train/test sets
        train_datasets = []
        test_datasets = []

        unique_subjects = np.unique(subjects)
        print("Subjects found:", unique_subjects)

        for subj_id in unique_subjects:
            # Gather rows for this subject
            subj_mask = (subjects == subj_id)
            X_sub = X[subj_mask]
            y_sub = activity[subj_mask]

            # 75/25 train/test for THIS subject
            X_train_sub, X_test_sub, y_train_sub, y_test_sub = train_test_split(
                X_sub, y_sub, test_size=0.25, random_state=42
            )

            train_datasets.append(HARLocalDataset(X_train_sub, y_train_sub))
            test_datasets.append(HARLocalDataset(X_test_sub, y_test_sub))

        # Concatenate the per-subject splits into one train set and one test set
        train_data = ConcatDataset(train_datasets)
        test_data = ConcatDataset(test_datasets)

        testloader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=4)

        # Logistic regression with one output per distinct activity
        input_dim = 561
        num_classes = len(np.unique(activity))
        model = HARLogisticRegression(input_dim, num_classes).to(device)

    elif dataset == 'EMNIST':
        transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

        # NOTE(review): this branch reads from './data' rather than `data_root`.
        train_data = torchvision.datasets.EMNIST(root='./data', split='byclass', train=True, download=True, transform=transform)
        test_data = torchvision.datasets.EMNIST(root='./data', split='byclass', train=False, download=True, transform=transform)
        testloader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=8)

        model = ThreeLayerDNN(input_size=28 * 28, hidden_size=512, output_size=62).to(device)

    elif dataset == 'MNIST':
        transform = transforms.Compose([
            transforms.ToTensor(),        # Convert image to PyTorch tensor
            transforms.Normalize((0.5,), (0.5,))  # Normalize grayscale values to [-1, 1]
        ])

        train_data = torchvision.datasets.MNIST(root=data_root, train=True, download=True, transform=transform)
        test_data = torchvision.datasets.MNIST(root=data_root, train=False, download=True, transform=transform)
        testloader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=8)

        # Fully connected network with a learnable input normalisation
        model = mnist_fully_connected_IN(10).to(device)

    elif dataset == 'CIFAR100':
        # Install and import CLIP
        try:
            import clip
        except ImportError:
            print("Installing CLIP...")
            install('git+https://github.com/openai/CLIP.git')
            import clip
        # Transformation matching CLIP's preprocessing
        transform = transforms.Compose([
            transforms.Resize((224, 224)),  # CLIP expects 224x224 input
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
        ])

        train_data = torchvision.datasets.CIFAR100(root=data_root, train=True, download=True, transform=transform)
        test_data = torchvision.datasets.CIFAR100(root=data_root, train=False, download=True, transform=transform)
        testloader = DataLoader(test_data, batch_size=32, shuffle=False, num_workers=8)

        # clip.load also returns its own preprocessing callable, unused here
        # because the explicit transform above is applied instead.
        model_clip, _ = clip.load("ViT-B/32", device=device)

        # Convert CLIP model to float32 to match other layers and data
        model_clip = model_clip.float()

        # Freeze the CLIP model's parameters (only the classifier is trained)
        for param in model_clip.parameters():
            param.requires_grad = False

        # A single linear layer on top of frozen CLIP image features
        class CLIP_DNN(nn.Module):
            def __init__(self, clip_model, num_classes=100):
                super(CLIP_DNN, self).__init__()
                self.clip_model = clip_model
                self.fc = nn.Linear(512, num_classes)  # CLIP ViT-B/32 gives 512-dimensional features

            def forward(self, images):
                with torch.no_grad():
                    # Extract image features using CLIP's image encoder (cast to float32)
                    image_features = self.clip_model.encode_image(images).float()
                return self.fc(image_features)

        model = CLIP_DNN(model_clip, num_classes=100)
        model = model.to(device)

    elif dataset == 'SYMBIPREDICT':
        csv_file = os.path.join(data_root, 'symbipredict_2022.csv')
        df = pd.read_csv(csv_file)

        # Encode target labels
        label_encoder = LabelEncoder()
        df['prognosis'] = label_encoder.fit_transform(df['prognosis'])

        # Separate features and labels
        X = df.drop(columns=['prognosis']).values
        y = df['prognosis'].values

        # Standardize features
        scaler = StandardScaler()
        X = scaler.fit_transform(X)

        # Split into training and testing sets
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

        train_data = SymbiPredictDataset(X_train, y_train)
        test_data = SymbiPredictDataset(X_test, y_test)

        # DataLoader for test data only
        testloader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=4)

        input_dim = X_train.shape[1]
        num_classes = len(label_encoder.classes_)
        model = TabularNet(input_dim=input_dim, num_classes=num_classes).to(device)

    else:
        raise ValueError("Dataset not supported")

    return model, train_data, testloader


Aggregation Rules (
FedAvg / Mean,
Median,
Trimmed Mean,
Multi-Krum,
Clipped Clustering,
SignGuard)

In [6]:
def tr_mean(all_updates: torch.Tensor) -> torch.Tensor:
    """Apply Trimmed Mean aggregation with 20% assumed attackers.

    Args:
        all_updates: Tensor whose first dimension indexes the clients.

    Returns:
        The mean over dim 0 after dropping, independently for every
        coordinate, the ``n`` smallest and ``n`` largest values, where
        ``n = round(0.2 * num_clients)``. When ``n`` is 0 (very few clients)
        the plain mean is returned. The result stays on the same device as
        ``all_updates`` in both cases.
    """
    sorted_updates = torch.sort(all_updates, dim=0)[0]
    num_clients = len(all_updates)
    n_attackers = round(0.2 * num_clients)
    # Trim only when something is left over after cutting both tails.
    if n_attackers != 0 and 2 * n_attackers < num_clients:
        sorted_updates = sorted_updates[n_attackers:-n_attackers]
    return torch.mean(sorted_updates, dim=0)

def multi_krum_optimized(local_updates: torch.Tensor):
    """
    Memory-conscious Multi-Krum aggregation assuming 20% Byzantine clients.

    Each client is scored by the sum of its squared L2 distances to its
    ``krum_limit`` nearest other clients. The ``krum_limit`` lowest-scoring
    clients are then averaged. Distances are computed one client at a time
    instead of building the full pairwise matrix.

    Parameters:
    - local_updates: A tensor of shape (num_clients, num_params) containing the flattened model updates from each client.
    Returns:
    - A tuple ``(aggregated_update, selected_indices)``: the averaged update of
      shape (num_params,) and a CPU tensor with the indices of the selected clients.
    Raises:
    - ValueError: if ``local_updates`` contains no clients.
    """
    num_clients = local_updates.size(0)
    if num_clients == 0:
        raise ValueError("multi_krum_optimized requires at least one client update")

    byzantine_client_num = int(num_clients * 0.2)  # Assuming 20% are byzantine clients
    # n - f - 2 is <= 0 for one or two clients; always keep at least one
    # client so topk() and the final mean stay well defined.
    krum_limit = max(1, num_clients - byzantine_client_num - 2)

    # Scores are kept on the CPU, so the returned indices are CPU tensors too.
    scores = torch.zeros(num_clients)

    for i in range(num_clients):
        # Squared L2 distances between client `i` and every client
        distances = torch.sum((local_updates - local_updates[i]) ** 2, dim=1)

        sorted_distances, _ = torch.sort(distances)

        # Skip position 0 (distance to itself, which is 0) and sum the next
        # `krum_limit` smallest distances.
        scores[i] = torch.sum(sorted_distances[1:krum_limit + 1])

    # Indices of the `krum_limit` clients with the lowest scores
    selected_indices = torch.topk(-scores, krum_limit, largest=True).indices

    # Average the updates of the selected clients
    aggregated_update = torch.mean(local_updates[selected_indices], dim=0)

    return aggregated_update, selected_indices

def clip_tensor_norm_(
    parameters: Union[torch.Tensor, Iterable[torch.Tensor]],
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
) -> torch.Tensor:
    """
    Scale tensors in place so that their combined norm does not exceed `max_norm`.

    The tensors themselves (not their `.grad`) are multiplied in place by
    `min(1, max_norm / (total_norm + 1e-6))`.

    Args:
      parameters: a single tensor or any iterable of tensors (generators are
        accepted). int64 tensors are skipped when computing a p-norm and are
        never scaled.
      max_norm: maximum allowed norm.
      norm_type: order of the norm; `float('inf')` selects the max-abs norm.
      error_if_nonfinite: raise instead of scaling when the norm is NaN/Inf.

    Returns:
      The total norm of the inputs measured before clipping
      (`torch.tensor(0.0)` when no tensors are given).

    Raises:
      RuntimeError: if `error_if_nonfinite` is set and the norm is non-finite.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    else:
        # Materialise once: the input is traversed several times below and
        # may be a one-shot generator.
        parameters = list(parameters)

    max_norm = float(max_norm)
    norm_type = float(norm_type)

    if len(parameters) == 0:
        return torch.tensor(0.0)

    device = parameters[0].device

    if norm_type == float('inf'):
        norms = [p.detach().abs().max().to(device) for p in parameters]
        total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
    else:
        total_norm = torch.norm(
            torch.cat(
                [
                    p.detach().view(-1).to(device)
                    for p in parameters
                    if p.dtype != torch.int64
                ]
            ),
            norm_type,
        )

    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f"The total norm of order {norm_type} for gradients from "
            "`parameters` is non-finite, so it cannot be clipped. To disable "
            "this error and scale the gradients by the non-finite norm anyway, "
            "set `error_if_nonfinite=False`"
        )

    # The small epsilon guards against division by zero; clamping at 1.0 means
    # tensors that are already within bounds are left unchanged.
    clip_coef = max_norm / (total_norm + 1e-6)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)

    for p in parameters:
        if p.dtype != torch.int64:
            p.mul_(clip_coef_clamped.to(p.device))

    # Previously the function fell off the end here and returned None despite
    # its annotation.
    return total_norm

def Clippedclustering(updates: torch.Tensor):
    """
    Clipped-clustering style aggregation.

    Steps:
      1) Clip (in place) every update whose L2 norm exceeds the median norm of
         this round (capped at tau = 1e5).
      2) Build the pairwise cosine-distance matrix of the clipped updates and
         split the clients into two groups with average-linkage agglomerative
         clustering; S1 is the larger group.
      3) Describe each update by its fractions of positive / negative / zero
         coordinates, split into two groups with KMeans; S2 is the larger group.
      4) Return the mean of the updates that lie in both S1 and S2.

    Args:
      updates: tensor indexed as (num_clients, num_params). NOTE: rows are
        clipped in place, so the caller's tensor is modified.

    Returns:
      Tensor of shape (num_params,): mean of the selected updates. If S1 and S2
      do not overlap, the mean of S1 is returned instead of a NaN vector.

    NOTE(review): the norm history only covers the current call, so the
    clipping threshold is the median of this round's norms — confirm that no
    cross-round history is intended.
    """
    tau = 1e5

    # Per-client L2 norms; their median (capped at tau) is the clip threshold
    l2norms = [torch.norm(update).item() for update in updates]
    threshold = min(np.median(l2norms), tau)

    # Clip tensor norms above the threshold (updates[idx] is a view, so the
    # in-place scaling lands in `updates`)
    for idx, l2 in enumerate(l2norms):
        if l2 > threshold:
            clip_tensor_norm_(updates[idx], threshold)

    num = len(updates)

    # Cosine distance = 1 - <u_i, u_j> / (|u_i| |u_j|). The rows must be
    # normalised: a raw dot product is unbounded and is not a distance, while
    # the NaN/Inf handling below assumes values in [0, 2].
    row_norms = torch.norm(updates, dim=1, keepdim=True)
    cos_sim = torch.mm(updates, updates.t()) / (row_norms * row_norms.t())
    dis_max = (1 - cos_sim).cpu().numpy()  # Convert to numpy for AgglomerativeClustering

    # Zero-norm rows give 0/0 = NaN: treat them as maximally distant
    dis_max = np.where(np.isinf(dis_max), 2.0, np.where(np.isnan(dis_max), 2.0, dis_max))
    # Remove floating-point spill outside [0, 2] and force exact zeros on the diagonal
    dis_max = np.clip(dis_max, 0.0, 2.0)
    np.fill_diagonal(dis_max, 0.0)

    # Hierarchical clustering
    clustering = AgglomerativeClustering(
        metric="precomputed", linkage="average", n_clusters=2
    )
    clustering.fit(dis_max)

    # Labels are 0/1, so their sum is the size of cluster 1
    flag = 1 if np.sum(clustering.labels_) > num // 2 else 0
    S1_idxs = [idx for idx, label in enumerate(clustering.labels_) if label == flag]

    # Vectorized sign-statistics features
    feature0 = (updates > 0).float().mean(dim=1)
    feature1 = (updates < 0).float().mean(dim=1)
    feature2 = (updates == 0).float().mean(dim=1)

    features = torch.stack([feature0, feature1, feature2], dim=1).cpu().numpy()

    # KMeans clustering
    kmeans = KMeans(n_clusters=2, random_state=0).fit(features)
    flag = 1 if np.sum(kmeans.labels_) > num // 2 else 0
    S2_idxs = [idx for idx, label in enumerate(kmeans.labels_) if label == flag]

    # Select intersection of both clustering methods
    selected_idxs = sorted(set(S1_idxs) & set(S2_idxs))
    if not selected_idxs:
        # Averaging an empty selection would return NaN and poison the global
        # model; fall back to the cosine-clustering majority.
        selected_idxs = S1_idxs

    # Return the mean of selected updates
    return torch.mean(updates[selected_idxs], dim=0)

def SignGuard(updates):
    """
    SignGuard-style aggregation: norm filtering combined with sign-statistics clustering.

    Steps:
      1) S1: keep clients whose L2 norm lies within [0.1 * M, 3.0 * M], where
         M is the median norm.
      2) S2: describe each update by its fractions of positive / negative /
         zero coordinates, run 2-means on these features and keep the larger
         cluster.
      3) Return the mean of the updates in S1 and S2.

    Args:
      updates: tensor whose first dimension indexes clients; all remaining
        dimensions are treated as parameters.

    Returns:
      The mean of the selected updates, moved to the module-level `device`.
      If S1 and S2 do not overlap, the norm-filtered set S1 is averaged
      instead; if that is empty as well (e.g. NaN norms), all updates are
      averaged.
    """
    num = updates.shape[0]
    # L2 norm of each client's update over every non-client dimension
    l2norms = torch.norm(updates, dim=tuple(range(1, updates.ndim)))

    # torch.median keeps the computation on the tensor's device
    M = torch.median(l2norms)
    L = 0.1
    R = 3.0

    # S1: clients whose norm is neither tiny nor huge relative to the median
    mask1 = (l2norms >= L * M) & (l2norms <= R * M)
    del l2norms, M

    # Sign statistics are computed on a CPU copy so that the temporary boolean
    # tensors do not occupy accelerator memory. reshape (rather than view)
    # also accepts non-contiguous input.
    updates_flat = updates.reshape(num, -1).cpu()
    num_para = updates_flat.size(1)

    positive_counts = (updates_flat > 0).sum(dim=1).float() / num_para
    negative_counts = (updates_flat < 0).sum(dim=1).float() / num_para
    zero_counts = (updates_flat == 0).sum(dim=1).float() / num_para

    features = torch.stack([positive_counts, negative_counts, zero_counts], dim=1).numpy()
    del updates_flat, positive_counts, negative_counts, zero_counts  # release the large CPU copy

    # Perform KMeans clustering
    labels = KMeans(n_clusters=2, random_state=0).fit(features).labels_
    del features

    # Put the labels on the same device as mask1; mixing the global `device`
    # with updates.device made `mask1 & mask2` fail when the two differed.
    labels = torch.from_numpy(labels).to(mask1.device)

    # Labels are 0/1, so their sum is the size of cluster 1
    flag = 1 if labels.sum() > num // 2 else 0

    # S2: members of the majority cluster
    mask2 = (labels == flag)
    del labels

    # Intersection of S1 and S2
    inter_mask = mask1 & mask2
    if not inter_mask.any():
        # An empty selection would average to NaN; degrade gracefully.
        inter_mask = mask1 if mask1.any() else torch.ones_like(mask1)
    del mask1, mask2

    # Compute and return the mean of the selected updates
    result = torch.mean(updates[inter_mask], dim=0)
    del inter_mask
    torch.cuda.empty_cache()

    return result.to(device)



##########################
# FLTrust aggregator
##########################
def FLTrust(global_model_data: torch.Tensor,
            local_params_all: torch.Tensor,
            server_params: torch.Tensor) -> torch.Tensor:
    """
    FLTrust aggregation driven by a server-side ("root") reference update.

    Args:
      global_model_data: previous global parameter vector (1D).
      local_params_all: (num_clients, num_params); row i holds client i's
        parameters after local training.
      server_params: parameters obtained by training the server's root model.

    Returns:
      The aggregated difference vector, to be added to `global_model_data`
      by the caller.
    """
    # Work with differences relative to the previous global model
    deltas = local_params_all - global_model_data
    root_delta = server_params - global_model_data

    # Trust score: cosine similarity to the root update, negatives clipped to 0
    trust = F.relu(F.cosine_similarity(deltas, root_delta.unsqueeze(0), dim=1))
    trust_total = trust.sum()
    if trust_total == 0:
        # No client is trusted at all: fall back to the plain mean of the client updates
        return torch.mean(deltas, dim=0)
    trust = trust / trust_total

    # Rescale every client update to the magnitude of the root update;
    # zero norms are replaced by 1e-9 to avoid dividing by zero
    delta_norms = torch.norm(deltas, dim=1, keepdim=True)
    root_norm = torch.norm(root_delta)
    delta_norms = delta_norms.masked_fill(delta_norms == 0, 1e-9)
    rescaled = deltas / delta_norms * root_norm

    # Trust-weighted sum over clients
    return (rescaled * trust.unsqueeze(1)).sum(dim=0)


import numpy as np
import torch
import sklearn.metrics.pairwise as smp
import hdbscan

def flame(
    local_updates: torch.Tensor,
    global_model: torch.nn.Module,
    device: torch.device,
    num_clients: int,
    lamda: float = 0.001
) -> torch.Tensor:
    """
    FLAME-style aggregation (Nguyen et al., USENIX Security 2022):
    cluster, clip, average.

    Pipeline:
      1) Pairwise cosine distances between the updates (computed in float64).
      2) HDBSCAN on the precomputed distances with
         min_cluster_size = max(2, num_clients // 2 + 1).
      3) The largest cluster is taken as benign. If every point is an outlier,
         or the largest cluster holds fewer than 50% of the clients, all
         clients are treated as benign.
      4) Benign updates whose L2 norm exceeds the median norm of *all* updates
         are scaled down to that median.
      5) The clipped benign updates are averaged.

    The Gaussian noise step of the paper is currently disabled (kept as a
    commented-out block at the end), so `lamda` has no effect at the moment.

    Args:
      local_updates: (num_clients, param_dim) tensor of client updates.
      global_model: not used by this function.
      device: only referenced by the disabled noise step.
      num_clients: number of rows of `local_updates` to consider.
      lamda: noise scale for the disabled noise step.

    Returns:
      aggregated_update: tensor of shape (param_dim,).
    """
    # --- clustering -------------------------------------------------------
    # HDBSCAN with a precomputed metric wants double precision input
    cosine_dist = smp.cosine_distances(local_updates.cpu().numpy().astype(np.float64))

    majority = max(2, (num_clients // 2) + 1)
    hdb = hdbscan.HDBSCAN(
        min_cluster_size=majority,
        min_samples=1,
        allow_single_cluster=True,
        metric='precomputed'
    )
    cluster_labels = hdb.fit(cosine_dist).labels_  # one label per client, -1 = outlier
    print("cluster_labels =", cluster_labels)
    print("Number of outliers =", np.sum(cluster_labels == -1))

    # --- benign set -------------------------------------------------------
    everyone = list(range(num_clients))
    top_label = cluster_labels.max()
    if top_label < 0:
        # nothing but outliers
        benign_ids = everyone
    else:
        sizes = [np.sum(cluster_labels == lbl) for lbl in range(top_label + 1)]
        max_cluster_size = max(sizes)
        # index() returns the first label that reaches the maximum size
        max_cluster_index = sizes.index(max_cluster_size)

        if max_cluster_size < 0.5 * num_clients:
            print("FLAME fallback: largest cluster < 50%, treat all as benign.")
            benign_ids = everyone
        else:
            benign_ids = [i for i in everyone if cluster_labels[i] == max_cluster_index]

    # --- norm clipping ----------------------------------------------------
    norms_np = torch.norm(local_updates, p=2, dim=1).cpu().numpy()
    clip_value = np.median(norms_np)
    print(f"clip_value (median norm) = {clip_value:.4f}")

    accepted_updates = []
    for i in benign_ids:
        vec = local_updates[i]
        if norms_np[i] > clip_value:
            # shrink the update so that its norm equals the median
            vec = vec * (clip_value / norms_np[i])
        accepted_updates.append(vec.unsqueeze(0))

    # --- averaging --------------------------------------------------------
    if accepted_updates:
        aggregated_update = torch.cat(accepted_updates, dim=0).mean(dim=0)
    else:
        print("FLAME fallback: no accepted updates => average all.")
        aggregated_update = torch.mean(local_updates, dim=0)

    # --- noise (disabled): std = lamda * clip_value -------------------------
    # if clip_value > 0.0:
    #     noise_std = lamda * clip_value
    #     noise = torch.normal(
    #         mean=0.0, std=noise_std,
    #         size=aggregated_update.shape,
    #         device=device
    #     )
    #     aggregated_update = aggregated_update + noise
    # else:
    #     print("FLAME note: clip_value = 0 => no noise added.")

    return aggregated_update



from scipy.fftpack import dct
from sklearn.metrics.pairwise import cosine_similarity

def freqfed_aggregator(
    local_updates: torch.Tensor,
    min_cluster_size: int = None,
    filter_fraction: float = 0.5
) -> torch.Tensor:
    """
    FreqFed Aggregation Function

    Each client update is transformed with a 1D DCT, only the low-frequency
    coefficients are kept, the clients are clustered with HDBSCAN on cosine
    distances in that space, and the updates of the largest cluster are
    averaged.

    Args:
      local_updates: shape (num_clients, param_dim) – each row is a flattened
        model update from a client.
      min_cluster_size: minimum cluster size used in HDBSCAN.
         None (default) => max(2, num_clients // 2 + 1).
      filter_fraction: fraction of DCT coefficients kept, counted from the
         low-frequency end. 0.5 => keep the first half. At least one
         coefficient is always kept.

    Returns:
      aggregated_update: shape (param_dim,), the mean of the updates in the
      largest cluster (or of all updates when HDBSCAN labels every client an
      outlier). The result is moved to the module-level `device`; this
      function takes no device argument.
    """
    num_clients = local_updates.size(0)
    # 1) Default cluster size: a strict majority of the clients
    if min_cluster_size is None:
        min_cluster_size = max(2, (num_clients // 2) + 1)

    # Convert local updates to CPU numpy for processing
    updates_np = local_updates.detach().cpu().numpy()  # shape: (num_clients, param_dim)

    # 2) Row-wise DCT in one call, then keep the low-frequency prefix.
    #    The cut-off is the same for every client, so it is computed once.
    spectra = dct(updates_np, norm='ortho', axis=-1)
    filter_length = max(1, int(spectra.shape[1] * filter_fraction))
    low_freq = spectra[:, :filter_length]

    # 3) Distance matrix = 1 - cosine similarity in the filtered space, built
    #    with a single vectorised call instead of one sklearn call per pair.
    #    HDBSCAN's precomputed metric expects float64.
    dist_matrix = 1.0 - cosine_similarity(low_freq).astype(np.float64)
    # Enforce the properties of a distance matrix despite rounding noise:
    # symmetric, within [0, 2], exact zeros on the diagonal.
    dist_matrix = (dist_matrix + dist_matrix.T) / 2.0
    np.clip(dist_matrix, 0.0, 2.0, out=dist_matrix)
    np.fill_diagonal(dist_matrix, 0.0)

    # 4) HDBSCAN clustering, metric=precomputed
    clusterer = hdbscan.HDBSCAN(
        min_cluster_size=min_cluster_size,
        min_samples=1,
        metric='precomputed',
        allow_single_cluster=True
    )
    cluster_labels = clusterer.fit_predict(dist_matrix)

    # optional debugging
    print("[FreqFed] cluster_labels =", cluster_labels)
    outliers = np.sum(cluster_labels == -1)
    print("[FreqFed] Number of outliers =", outliers)

    # 5) Largest cluster (label -1 marks outliers and is ignored). If every
    #    client is an outlier, fall back to all clients.
    clustered = cluster_labels[cluster_labels >= 0]
    if clustered.size == 0:
        benign_ids = list(range(num_clients))
    else:
        # argmax returns the first (smallest) label among equally large clusters
        best_label = np.bincount(clustered).argmax()
        benign_ids = np.flatnonzero(cluster_labels == best_label).tolist()

    # 6) Average the accepted updates
    accepted = local_updates[benign_ids, :]  # shape => (#accepted, param_dim)
    aggregated_update = accepted.mean(dim=0) # shape => (param_dim,)

    return aggregated_update.to(device)

# import torch
# import cupy as cp
# import numpy as np

# # We'll import the GPU-based HDBSCAN from cuML
# # (NOT the CPU-based hdbscan from the standard library.)
# from cuml.cluster import HDBSCAN as cuHDBSCAN

# def flame_cuml(
#     local_updates: torch.Tensor,
#     device: torch.device,
#     num_clients: int,
#     lamda: float = 0.001
# ) -> torch.Tensor:
#     """
#     FLAME aggregator with GPU usage, including:
#       - Compute pairwise distances in PyTorch on GPU
#       - Convert to cupy array
#       - Use cuML’s GPU-based HDBSCAN
#       - Then do fallback, clipping, and noise injection on GPU

#     Steps:
#       1) local_updates is on GPU
#       2) compute GPU-based cosine similarity => distance matrix
#       3) run cuML HDBSCAN with min_cluster_size=(num_clients//2 + 1)
#       4) fallback if largest cluster is < 50% or if all outliers
#       5) median-based norm clipping on GPU
#       6) average & add noise
#     """

#     # local_updates shape: (num_clients, param_dim), on device=GPU
#     # 1) Ensure local_updates is on GPU
#     local_updates = local_updates.to(device)

#     # 2) Compute pairwise cosine distance in PyTorch on GPU
#     #    a) normalize each row
#     normed = torch.nn.functional.normalize(local_updates, p=2, dim=1)
#     #    b) cos sim = normed @ normed.T
#     sim = normed @ normed.t()
#     #    c) dist = 1 - sim
#     dist_mat = 1.0 - sim
#     # dist_mat is shape (n, n), on GPU

#     # 2b) Convert dist_mat to a cuML-friendly CuPy array (float32 or float64)
#     # If you want float64, we can cast. Typically float32 might suffice.
#     dist_mat_cupy = cp.asarray(dist_mat.detach().cpu().numpy(), dtype=cp.float32)
#     # ^ Unfortunately, we do an intermediate .cpu().numpy() because
#     #   PyTorch -> CuPy direct conversion is not always trivial.
#     #   If you want a direct approach, you can do dist_mat_contig = dist_mat.contiguous()
#     #   Then memory pointer bridging. But the simplest is .cpu().numpy() => cp.asarray.

#     # 3) HDBSCAN on GPU
#     min_cluster = max(2, (num_clients // 2) + 1)
#     clusterer = cuHDBSCAN(
#         min_cluster_size=min_cluster,
#         min_samples=1,
#         metric='precomputed',
#         allow_single_cluster=True
#     )
#     # Fit
#     cluster_labels_cupy = clusterer.fit_predict(dist_mat_cupy)
#     # cluster_labels_cupy is a cupy array of shape (n,). Let's bring it back to CPU
#     cluster_labels = cluster_labels_cupy.get()  # shape (n,) in numpy

#     # Print debug
#     print("cluster_labels =", cluster_labels)
#     outliers_count = np.sum(cluster_labels == -1)
#     print("Number of outliers =", outliers_count)

#     # 4) Identify the largest cluster (unless all outliers => fallback)
#     if cluster_labels.max() < 0:
#         # all outliers => fallback => treat all as benign
#         benign_ids = list(range(num_clients))
#     else:
#         # find cluster with the most elements
#         max_cluster_index = None
#         max_cluster_size  = 0
#         for cl_idx in range(cluster_labels.max() + 1):
#             size_cl = np.sum(cluster_labels == cl_idx)
#             if size_cl > max_cluster_size:
#                 max_cluster_size = size_cl
#                 max_cluster_index = cl_idx

#         # pick all that belong to that cluster
#         benign_ids = [i for i in range(num_clients) if cluster_labels[i] == max_cluster_index]

#         # fallback if largest cluster is < 50%
#         if max_cluster_size < 0.5 * num_clients:
#             print("FLAME fallback: largest cluster < 50%, treat all as benign.")
#             benign_ids = list(range(num_clients))

#     # 5) Norm-clipping
#     # compute L2 norms in GPU
#     norms = torch.norm(local_updates, p=2, dim=1)  # shape (n,)
#     clip_value = norms.median()  # GPU median in newer PyTorch versions

#     # For older PyTorch, you might do norms.cpu().median().to(device)
#     # We'll assume you can do it fully on GPU if version >=1.7

#     print(f"clip_value (median norm) = {clip_value.item():.4f}")

#     accepted_updates = []
#     norms_cpu = norms.detach().cpu().numpy()
#     # We do scale on GPU:
#     for i in benign_ids:
#         # i-th row => local_updates[i]
#         # if norms[i] > clip_value => scale
#         if norms_cpu[i] > clip_value.item():
#             scale = clip_value.item() / norms_cpu[i]
#             clipped_vec = local_updates[i] * scale
#         else:
#             clipped_vec = local_updates[i]
#         accepted_updates.append(clipped_vec.unsqueeze(0))

#     if len(accepted_updates) == 0:
#         print("FLAME fallback: no accepted updates => average all.")
#         # just average everything
#         aggregated_update = torch.mean(local_updates, dim=0)
#     else:
#         merged_tensor = torch.cat(accepted_updates, dim=0)  # shape (k, param_dim)
#         aggregated_update = merged_tensor.mean(dim=0)

#     # 6) Add noise => lamda * clip_value
#     # If you want the paper's DP bound exactly, you can set
#     # lamda = (1/eps)*math.sqrt(2*math.log(1.25/delta)).
#     if clip_value.item() > 0.0:
#         noise_std = lamda * clip_value.item()
#         noise = torch.normal(mean=0.0, std=noise_std, size=aggregated_update.shape, device=device)
#         aggregated_update = aggregated_update + noise
#     else:
#         print("FLAME note: clip_value=0 => no noise added.")

#     return aggregated_update


def dnc_aggregator(local_updates: torch.Tensor,
                   num_clients: int,
                   num_adv: int,
                   subsample_frac: float = 0.2,
                   num_iters: int = 5,
                   fliter_frac: float = 1.0) -> torch.Tensor:
    """
    Divide-and-Conquer (DnC) aggregation with NumPy-based SVD.

    For `num_iters` rounds:
      * draw a random subset of `subsample_frac` of the parameters
        (without replacement, from NumPy's global RNG),
      * centre the sub-sampled updates around their mean,
      * score each client by its squared projection onto the top right
        singular vector of the centred matrix,
      * keep the `k_keep` clients with the lowest scores.
    Only clients kept in every round survive; their updates are averaged.

    Args:
      local_updates: (num_clients, num_params) tensor.
      num_clients: number of clients.
      num_adv: assumed number of adversaries.
      subsample_frac: fraction of parameters sampled per round.
      num_iters: number of scoring rounds.
      fliter_frac: multiplier on `num_adv` giving how many clients are
        dropped per round (k_keep = num_clients - fliter_frac * num_adv).

    Returns:
      (num_params,) tensor on the device of `local_updates`. If no client
      survives all rounds, the plain mean of all updates is returned.
    """
    device = local_updates.device
    # SVD is done in NumPy on a CPU copy
    updates_np = local_updates.cpu().numpy()
    num_param = updates_np.shape[1]

    # Clients kept per round; a non-positive value means "keep everyone"
    k_keep = int(num_clients - fliter_frac * num_adv)
    if k_keep <= 0:
        k_keep = num_clients

    # The sample size is the same every round; with nothing to sample no
    # filtering round is run at all
    param_count = int(subsample_frac * num_param)
    rounds = num_iters if param_count > 0 else 0

    survivors = set(range(num_clients))
    for _ in range(rounds):
        # random parameter subset -> (num_clients, param_count)
        cols = np.random.choice(np.arange(num_param),
                                param_count,
                                replace=False)
        sampled = updates_np[:, cols]
        centered = sampled - np.mean(sampled, axis=0)

        # first row of V^T = top right singular vector, shape (param_count,)
        _, _, vt = np.linalg.svd(centered, full_matrices=False)
        scores = np.dot(centered, vt[0]) ** 2  # shape (num_clients,)

        if k_keep < len(scores):
            # partial sort: indices of the k_keep smallest scores
            chosen = np.argpartition(scores, k_keep)[:k_keep]
        else:
            chosen = np.arange(len(scores))

        survivors = survivors.intersection(set(chosen))

    # Average the survivors, or everyone if nobody survived
    pool = local_updates[list(survivors)] if len(survivors) != 0 else local_updates
    aggregator = torch.mean(pool, dim=0)

    return aggregator.to(device)


def dnc_aggregator_torch(
    local_updates: torch.Tensor,
    num_clients: int,
    subsample_frac: float = 0.2,
    num_iters: int = 5,
    keep_frac: float = 0.81
) -> torch.Tensor:
    """
    DNC aggregator, all in Torch (runs on the GPU if local_updates is on GPU).

    For `num_iters` rounds a random subset of the parameters is drawn, the
    sub-sampled updates are centred, and every client is scored by its squared
    projection onto the top right singular vector. The `k_keep` lowest-scoring
    clients of each round are intersected across rounds, and the survivors
    are averaged. Scores are always computed over *all* clients, not only the
    ones that survived earlier rounds.

    Args:
      local_updates: shape (num_clients, param_dim).
      num_clients: total number of clients.
      subsample_frac: fraction of parameters to subsample per round.
      num_iters: how many repeated outlier computations.
      keep_frac: fraction of clients kept per round,
        k_keep = int(num_clients * keep_frac). The default 0.81 is the value
        that used to be hard-coded. A non-positive k_keep keeps everyone.

    Returns:
      Aggregated vector of shape (param_dim,). If no client survives all
      rounds, the plain mean over all clients is returned.
    """
    device = local_updates.device
    c, p = local_updates.shape

    k_keep = int(num_clients * keep_frac)
    if k_keep <= 0:
        k_keep = num_clients  # fallback: keep them all

    # True means "still in the candidate benign set"
    keep_mask = torch.ones(c, dtype=torch.bool, device=device)

    # The sample size does not change between rounds
    param_count = int(subsample_frac * p)

    for _ in range(num_iters):
        if param_count <= 0:
            break

        # 1) param_count distinct parameter indices
        chosen_params = torch.randperm(p, device=device)[:param_count]
        sub_updates = local_updates[:, chosen_params]  # (C, param_count)

        # 2) centre around the per-parameter mean over clients
        centered = sub_updates - sub_updates.mean(dim=0)

        # 3) torch.linalg.svd returns (U, S, Vh); row 0 of Vh is the top
        #    right singular vector, shape (param_count,)
        _, _, Vh = torch.linalg.svd(centered, full_matrices=False)
        top_vector = Vh[0, :]

        # 4) outlier score = squared projection onto the top vector, shape (C,)
        scores = torch.matmul(centered, top_vector).pow(2)

        # 5) indices of the k_keep smallest scores (everyone if k_keep >= C)
        if k_keep < c:
            idx = torch.topk(scores, k_keep, largest=False).indices
        else:
            idx = torch.arange(c, device=device)

        # 6) intersect this round's selection with the running benign set
        new_mask = torch.zeros(c, dtype=torch.bool, device=device)
        new_mask[idx] = True
        keep_mask = keep_mask & new_mask

    if keep_mask.sum().item() == 0:
        # nobody survived every round => naive average
        return local_updates.mean(dim=0)
    return local_updates[keep_mask, :].mean(dim=0)


Attacks (XFED)

In [7]:
def get_mu_pairwise_distance(global_models, MU_MULTIPLIER):
    """
    Deviation magnitude (mu) derived from how far the vectors spread around
    their centroid.

    mu = MU_MULTIPLIER * sqrt(mean_i ||v_i - centroid||^2), i.e. the
    root-mean-square distance of the vectors to their mean, scaled by
    MU_MULTIPLIER.

    Args:
      global_models: a list of 1D tensors, or a 2D tensor with one vector per row.
      MU_MULTIPLIER: scalar factor applied to the spread.

    Returns:
      A 0-dim tensor; `torch.tensor(0.0)` when fewer than two vectors are given.
    """
    num_models = len(global_models)
    if num_models <= 1:
        # A spread needs at least two vectors
        return torch.tensor(0.0)

    stacked = torch.vstack(global_models) if isinstance(global_models, list) else global_models

    # Euclidean distance of every vector to the mean vector
    centroid = torch.mean(stacked, dim=0)
    distances = torch.norm(stacked - centroid, dim=1)

    # Root of the mean squared distance
    spread = torch.sqrt(torch.dot(distances, distances) / num_models)
    return MU_MULTIPLIER * spread

def xfed_c(user_grads, n_attackers, dev_type, len_global, global_model_data, global_models, collab):
    """
    Collaborative XFED attack: replace the colluding attackers' updates with
    one crafted malicious update.

    The malicious update is `mean(attacker updates) - mu * deviation`, where
    `deviation` is chosen by `dev_type` and `mu` is
    `get_mu_pairwise_distance(attacker updates, 3)` when `len_global > 1`,
    otherwise 1.0.

    Args:
      user_grads: 2D tensor of client updates, one row per client; the first
        `n_attackers` rows belong to attackers. Modified in place.
      n_attackers: number of attacker rows.
      dev_type: one of 'C_XFED_unit_vec', 'Hybrid_XFED_unit_vec',
        'C_XFED_sign', 'Hybrid_XFED_sign', 'C_XFED_std'.
      len_global: only used to decide whether mu is derived from the attacker
        updates (> 1) or fixed to 1.0.
      global_model_data: not used by the current implementation.
      global_models: not used by the current implementation.
      collab: number of colluding attackers. 0 means all `n_attackers`
        collude; otherwise only rows [n_attackers - collab, n_attackers) do.

    Returns:
      `user_grads` (the same tensor object) with the colluding attackers' rows
      overwritten by the malicious update.

    Raises:
      ValueError: if `dev_type` is not one of the supported values.
    """
    # Rows of the colluding attackers
    start_idx = 0 if collab == 0 else n_attackers - collab
    all_updates = user_grads[start_idx:n_attackers]

    model_re = torch.mean(all_updates, dim=0).to(device)

    if dev_type in ('C_XFED_unit_vec', 'Hybrid_XFED_unit_vec'):
        deviation = model_re / torch.norm(model_re)
    elif dev_type in ('C_XFED_sign', 'Hybrid_XFED_sign'):
        sgn_vec = torch.sign(model_re)
        deviation = sgn_vec / torch.norm(sgn_vec)
    elif dev_type == 'C_XFED_std':
        # keep the deviation on the same device as model_re
        deviation = torch.std(all_updates, dim=0).to(model_re.device)
    else:
        # An unknown type used to surface later as an UnboundLocalError on
        # `deviation`; fail with a clear message instead.
        raise ValueError(f"Unknown deviation type: {dev_type}")

    if len_global > 1:
        mu = get_mu_pairwise_distance(all_updates, 3)
    else:
        mu = torch.tensor(1.0)

    # Malicious update: shift the attackers' mean against the deviation direction
    mal_update = model_re - deviation * mu

    del model_re, deviation
    torch.cuda.empty_cache()

    # Overwrite every colluding attacker's row in one broadcast assignment
    user_grads[start_idx:n_attackers] = mal_update

    return user_grads

    # Create the final stacked tensor of updates
    # Combine the malicious updates with the rest of the user_grads
    # mal_updates = mal_update.unsqueeze(0).repeat(n_attackers, *[1 for _ in mal_update.shape])
    # return torch.cat((mal_updates, user_grads[n_attackers:]), dim=0)




Attacks (VIRAT, FANG-TR-MEAN, FANG-KRUM, LIE)

In [8]:

def virat_min_max(user_grads, n_attackers, dev_type='VIRAT_unit_vec', epoch = 0, threshold=50):
    """Craft the VIRAT Min-Max malicious update and install it for every attacker.

    A scale ``lamda`` is searched by repeated halving so that the crafted
    vector ``mean - lamda * deviation`` is no farther (in squared L2 distance)
    from any attacker update than the two most distant attacker updates are
    from each other.

    Arguments:
    - user_grads: 2-D tensor of flattened updates; rows ``[0, n_attackers)``
      belong to the attackers.
    - n_attackers: number of attacker rows.
    - dev_type: 'VIRAT_unit_vec', 'VIRAT_sign' or 'VIRAT_std'.
    - epoch: round number, used only in the progress message.
    - threshold: starting value of ``lamda``.

    Returns:
    - A new tensor: ``n_attackers`` copies of the crafted update followed by
      the untouched remaining rows of ``user_grads``.

    Raises:
    - ValueError: if ``dev_type`` is not recognised.
    """
    attacker_updates = user_grads[:n_attackers].to(device)
    model_re = torch.mean(attacker_updates, dim=0).to(device)

    # Lazily evaluated so only the requested direction is computed.
    direction_builders = {
        'VIRAT_unit_vec': lambda: model_re / torch.norm(model_re),
        'VIRAT_sign': lambda: torch.sign(model_re),
        'VIRAT_std': lambda: torch.std(attacker_updates, dim=0),
    }
    if dev_type not in direction_builders:
        raise ValueError(f"Unknown deviation type: {dev_type}")
    deviation = direction_builders[dev_type]()

    lamda = torch.tensor([threshold], dtype=torch.float).to(device)
    threshold_diff = 1e-5
    lamda_fail = lamda.clone()
    lamda_succ = torch.tensor(0, dtype=torch.float).to(device)

    # Largest squared distance between any two attacker updates.
    row_maxima = [
        torch.max(torch.norm(attacker_updates - reference, dim=1) ** 2)
        for reference in attacker_updates
    ]
    max_distance = torch.max(torch.stack(row_maxima))
    del row_maxima

    while torch.abs(lamda_succ - lamda) > threshold_diff:
        candidate = model_re - lamda * deviation
        farthest = torch.max(torch.norm(attacker_updates - candidate, dim=1) ** 2)
        half_step = lamda_fail / 2

        if farthest <= max_distance:
            # Candidate is still inside the benign spread: remember it, push further.
            lamda_succ = lamda.clone()
            lamda = lamda + half_step
        else:
            lamda = lamda - half_step

        lamda_fail = half_step

    print(f"For round {epoch}, max_distance", max_distance, "lamda_succ", lamda_succ)
    mal_update = model_re - lamda_succ * deviation
    repeat_dims = [1] * mal_update.dim()
    mal_updates = mal_update.unsqueeze(0).repeat(n_attackers, *repeat_dims)
    return torch.cat((mal_updates, user_grads[n_attackers:]), dim=0)

def fang_attack_trmean_partial(user_grads, n_attackers):
    """Craft per-attacker malicious updates for the Fang trimmed-mean attack.

    For every coordinate two values are drawn with the same uniform random
    factor: one from ``[mean + 3*std, mean + 4*std]`` and one from
    ``[mean - 4*std, mean - 3*std]``, where mean/std are taken over the
    attacker rows. Their sum is kept on coordinates whose attacker mean is
    strictly positive; all other coordinates become zero.

    NOTE(review): both terms are gated by the same positive-sign mask, so
    non-positive coordinates are zeroed -- confirm this is the intended attack.

    Arguments:
    - user_grads: 2-D tensor of flattened updates; the first ``n_attackers``
      rows are the attackers'.
    - n_attackers: number of attacker rows (one independent draw per attacker).

    Returns:
    - A new tensor: the crafted attacker rows followed by the remaining rows
      of ``user_grads``.
    """
    attacker_updates = user_grads[:n_attackers]
    mean_vec = torch.mean(attacker_updates, 0)
    std_vec = torch.std(attacker_updates, 0)
    sign_vec = torch.sign(mean_vec)

    upper_lo = mean_vec + 3 * std_vec
    upper_hi = mean_vec + 4 * std_vec
    lower_lo = mean_vec - 4 * std_vec
    lower_hi = mean_vec - 3 * std_vec

    # One uniform factor per (coordinate, attacker), shared by both draws.
    draw_shape = [len(sign_vec), n_attackers]
    rand = torch.from_numpy(np.random.uniform(0, 1, draw_shape)).type(torch.FloatTensor).to(device)

    # Column-vector broadcasting tiles each bound across the attacker axis.
    upper_draw = upper_lo[:, None] + rand * (upper_hi - upper_lo)[:, None]
    lower_draw = lower_lo[:, None] + rand * (lower_hi - lower_lo)[:, None]

    positive_mask = (sign_vec > 0).type(torch.FloatTensor).to(device)[:, None]

    mal_update = (positive_mask * upper_draw + positive_mask * lower_draw).T

    return torch.cat((mal_update, user_grads[n_attackers:]), dim=0)

def compute_lambda_fang(all_updates, model_re, n_attackers):
    """Compute the starting scale for the Fang Krum attack from the given updates.

    The result is the sum of two terms, both divided by ``sqrt(d)`` where
    ``d`` is the update dimension:
    - the smallest Krum-style score (sum of each row's
      ``n_rows - 2 - n_attackers`` nearest pairwise distances), further
      divided by ``n_rows - n_attackers - 1``;
    - the largest distance between any row and ``model_re``.

    Arguments:
    - all_updates: 2-D tensor, one flattened update per row.
    - model_re: reference vector (same length as one row).
    - n_attackers: number of attackers assumed in the score.

    Returns:
    - 0-dim tensor holding the computed value.
    """
    n_rows, dim = all_updates.shape

    # Row i holds the L2 distance from every update to update i.
    pairwise = torch.stack([torch.norm(all_updates - row, dim=1) for row in all_updates])

    # Exact zeros (self-distances) are pushed to the end of the sort order.
    pairwise[pairwise == 0] = 10000
    ordered, _ = torch.sort(pairwise, dim=1)

    n_neighbours = n_rows - 2 - n_attackers
    best_score = torch.min(torch.sum(ordered[:, :n_neighbours], dim=1))

    root_dim = torch.sqrt(torch.Tensor([dim]))[0]
    score_term = best_score / ((n_rows - n_attackers - 1) * root_dim)
    spread_term = torch.max(torch.norm(all_updates - model_re, dim=1)) / root_dim

    return score_term + spread_term

def get_malicious_updates_fang(all_updates, model_re, deviation, n_attackers, benign_copies=4):
    """Search for a Fang-Krum malicious update that Krum selects.

    Starting from ``lamda = 100`` and halving each time, the candidate
    ``model_re - lamda * deviation`` is replicated ``n_attackers`` times and
    placed in front of ``benign_copies`` repetitions of ``all_updates``. The
    first candidate for which ``multi_krum_optimized`` selects one of the
    leading (malicious) rows is returned.

    Arguments:
    - all_updates: 2-D tensor of reference updates appended after the
      malicious rows.
    - model_re: vector the candidate is built around.
    - deviation: direction subtracted from ``model_re``.
    - n_attackers: number of copies of the candidate placed at the front.
    - benign_copies: how many times ``all_updates`` is repeated after the
      malicious rows (default 4, the previously hard-coded repetition count).

    Returns:
    - The accepted candidate, or, when no candidate is ever selected, the last
      candidate tried (the one built with the smallest ``lamda`` above 1e-5).
    """
    lamda = 100  # compute_lambda_fang(all_updates, model_re, n_attackers)
    threshold = 1e-5

    mal_update = None
    while lamda > threshold:
        mal_update = model_re - lamda * deviation

        # Malicious rows first, so their indices are 0..n_attackers-1.
        candidate_pool = torch.cat(
            [torch.stack([mal_update] * n_attackers)] + [all_updates] * benign_copies, 0
        )

        _, selected_indices = multi_krum_optimized(candidate_pool)

        if torch.any(selected_indices < n_attackers):
            return mal_update

        lamda *= 0.5

    # The loop always runs at least once (100 > 1e-5), so mal_update is set.
    return mal_update

def lie_attack(user_grads, n_attackers, z):
    """Build the LIE malicious update ``mean + z * std`` and install it for every attacker.

    Arguments:
    - user_grads: 2-D tensor of flattened updates; the first ``n_attackers``
      rows belong to the attackers.
    - n_attackers: number of attacker rows.
    - z: multiplier applied to the per-coordinate standard deviation.

    Returns:
    - A new tensor: ``n_attackers`` copies of the crafted update followed by
      the remaining rows of ``user_grads``.
    """
    attacker_updates = user_grads[:n_attackers]

    # Per-coordinate statistics over the attackers' own updates.
    mal_update = torch.mean(attacker_updates, dim=0) + z * torch.std(attacker_updates, dim=0)

    tiling = [n_attackers] + [1] * mal_update.dim()
    crafted_rows = mal_update.unsqueeze(0).repeat(*tiling)
    untouched_rows = user_grads[n_attackers:]
    return torch.cat((crafted_rows, untouched_rows), dim=0)


Code for calculating Z value

In [9]:
# Multiplier table for the LIE attack, keyed by (number of clients, number of attackers).
z_values = {
    (50, 3): 0.69847,
    (50, 5): 0.7054,
    (50, 8): 0.71904,
    (50, 10): 0.72575,
    (50, 12): 0.73891,
    (100, 20): 0.72907,
    (40, 8): 0.72575,
    (100, 5): 0.69497,
    (100, 10): 0.7054,
    (100, 15): 0.71566,
    (100, 25): 0.74215,
    (100, 30): 0.75804,
}

# Worked example for one (n, m) pair: prints the ratio (n - m - s) / (n - m).
# NOTE(review): the printed ratio is not itself an entry of z_values; presumably
# the table entry is derived from it in a separate step -- confirm.
import math

n = 100  # presumably the total number of clients
m = 30   # presumably the number of attackers

s = math.floor(n / 2 + 1) - m
z = (n - m - s) / (n - m)
print(z)

0.7


Federated Learning Training

In [10]:
def train_server_model(global_model, root_data, batch_size, criterion, device, optim):
    """
    Run one optimisation step of a copy of the global model on the root dataset.

    The resulting parameter vector is the server-side reference that the
    FLTrust aggregation consumes. ``global_model`` itself is never modified.

    Arguments:
    - global_model: Current global model (deep-copied before training)
    - root_data: Trusted root dataset
    - batch_size: Batch size for the single training batch
    - criterion: Loss function (e.g., CrossEntropyLoss)
    - device: Device the copy is trained on
    - optim: Optimizer type; 'SGD' selects SGD(lr=0.5, momentum=0.9), any other
      value selects Adam(lr=0.001)

    Returns:
    - server_params: 1-D tensor with all parameters of the trained copy,
      concatenated in ``parameters()`` order
    """

    # Copy the current global model (server starts from this)
    server_model = deepcopy(global_model).to(device)
    # The caller can hand over a model that a preceding evaluation left in
    # eval() mode; the copy must be in training mode for the update step.
    server_model.train()

    # Select an optimizer for training
    if optim == 'SGD':
        server_optimizer = torch.optim.SGD(server_model.parameters(), lr=0.5, momentum=0.9)
    else:
        server_optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)

    # Load root dataset into DataLoader
    root_loader = DataLoader(root_data, batch_size=batch_size, shuffle=True, num_workers=0)

    # Only the first (shuffled) batch is used: a single optimisation step.
    first_batch = next(iter(root_loader), None)
    if first_batch is not None:
        inputs, targets = first_batch
        inputs, targets = inputs.to(device), targets.to(device)
        server_optimizer.zero_grad()
        outputs = server_model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(server_model.parameters(), max_norm=1.0)
        server_optimizer.step()

    # Convert trained server model to a 1D parameter vector
    server_params = torch.cat([p.data.view(-1) for p in server_model.parameters()])

    # Cleanup to free memory
    del server_model, root_loader
    torch.cuda.empty_cache()

    return server_params


def train_local_model(client_id, client_indices, global_model, train_data, batch_size, criterion, device, optimizer):
    """Train a copy of the global model on one random mini-batch of a client's data.

    Arguments:
    - client_id: identifier echoed back in the result.
    - client_indices: indices into ``train_data`` owned by this client.
    - global_model: model to start from (deep-copied, never modified).
    - train_data: dataset the indices refer to.
    - batch_size: upper bound on the number of samples drawn.
    - criterion: loss function.
    - device: device the copy is trained on.
    - optimizer: 'SGD' selects SGD(lr=0.5, momentum=0.9); anything else
      selects Adam(lr=0.001).

    Returns:
    - (client_id, local_params): the id and a 1-D tensor holding all trained
      parameters concatenated in ``parameters()`` order.
    """
    draw_count = min(batch_size, len(client_indices))
    sampled_indices = random.sample(client_indices, draw_count)
    sampled_data = Subset(train_data, sampled_indices)
    # The whole sample forms one batch, so the loop below performs a single step.
    sampled_loader = DataLoader(
        sampled_data,
        batch_size=len(sampled_indices),
        shuffle=False,
        num_workers=0,
    )

    local_model = deepcopy(global_model).to(device)
    weights = local_model.parameters()
    local_optimizer = (
        optim.SGD(weights, lr=0.5, momentum=0.9)
        if optimizer == 'SGD'
        else torch.optim.Adam(weights, lr=0.001)
    )

    for batch_inputs, batch_targets in sampled_loader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)
        local_optimizer.zero_grad()
        loss = criterion(local_model(batch_inputs), batch_targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(local_model.parameters(), max_norm=1.0)
        local_optimizer.step()

    # Flatten every parameter tensor into one vector for aggregation.
    flat_chunks = [weight.data.view(-1) for weight in local_model.parameters()]
    local_params = torch.cat(flat_chunks)

    # Release the per-client copies before returning.
    del local_model, sampled_data, sampled_loader
    torch.cuda.empty_cache()

    return client_id, local_params


def _apply_individual_xfed(local_models_data, n_individual, attack_type, name_prefix,
                           global_models, global_model_data, epoch, log_shapes=False):
    """Replace rows 0..n_individual-1 (in place) with independently crafted XFED updates."""
    for local_machine in range(n_individual):
        if attack_type == name_prefix + 'unit_vec':
            deviation = local_models_data[local_machine] / torch.norm(local_models_data[local_machine])
        elif attack_type == name_prefix + 'sign':
            sgn_vec = torch.sign(local_models_data[local_machine])
            deviation = sgn_vec / torch.norm(sgn_vec)
        else:
            raise ValueError("Invalid attack type")

        # Scale of the perturbation: taken from the stored global models once
        # at least two of them are available, otherwise 1.
        if len(global_models) > 1:
            mu = get_mu_pairwise_distance(global_models, MU_MULTIPLIER=MU_MULTIPLIER)
        else:
            mu = torch.tensor(1.0)

        delta = mu * deviation

        if epoch == 0:
            if log_shapes:
                print(f'local_models_data[local_machine] {local_models_data[local_machine].shape} delta {delta.shape}')
            # First round: perturb the attacker's own honest update.
            local_models_data[local_machine] -= delta
        else:
            # Later rounds: perturb the previous global vector instead.
            local_models_data[local_machine] = global_model_data - delta
        del deviation, delta, mu


def federated_learning(num_clients, aggregation, n_attackers, attack_type, dataset, n_round, batch_size, optim, cross_device, collab, consider_min_round):
    """Main federated learning loop.

    Each round: every client trains a copy of the global model on one random
    mini-batch, the first ``n_attackers`` rows of the stacked client vectors
    are replaced according to ``attack_type``, the rows are aggregated with
    ``aggregation`` and the result is written back into the global model.
    Accuracy is evaluated on selected rounds and appended to a log file under
    ``data_root``.

    Arguments:
    - num_clients: number of clients for the equal-size split; for
      ``dataset == 'HAR'`` it is replaced by the number of non-empty
      sub-datasets of the training set.
    - aggregation: 'MEAN', 'MEDIAN', 'KRUM', 'TR-MEAN', 'CC', 'SignGuard',
      'FLAME', 'DNC', 'FLTrust' or 'FreqFred'.
    - n_attackers: number of attacker rows (always the leading rows).
    - attack_type: 'XFED_*', 'VIRAT_*', 'LIE*', 'FANG_TR_MEAN', 'FANG_KRUM',
      'C_XFED_*' or 'Hybrid_XFED_*'.
    - dataset: dataset name forwarded to ``load_model``.
    - n_round: number of communication rounds.
    - batch_size: mini-batch size for client and server training.
    - optim: 'SGD' or anything else for Adam (forwarded to the training helpers).
    - cross_device: when True, only a random 20% of the client rows take part
      in each aggregation.
    - collab: number of colluding attackers for the Hybrid attack.
    - consider_min_round: evaluations from this round on enter the min/avg/max
      summary.

    Returns:
    - None. Results are printed and appended to the accuracy log file.

    Raises:
    - ValueError: unknown attack type (with ``n_attackers > 0``), unknown
      aggregation method, or NaN in the aggregated vector.
    """
    global_model, train_data, testloader = load_model(dataset)
    criterion = nn.CrossEntropyLoss()

    # Flattened copy of the initial global parameters.
    global_model_data = torch.cat([p.data.view(-1) for p in global_model.parameters()]).to(device)

    # Root dataset for the server: root_size samples, a bias_prob share of
    # which is drawn from biased_class.
    root_size = 100
    bias_prob = 0.1
    biased_class = 1
    n_biased = int(root_size * bias_prob)  # 100 * 0.1 = 10
    n_other = root_size - n_biased         # 90

    # 1) Gather indices for each group
    biased_indices = []
    other_indices = []
    for i in range(len(train_data)):
        _, label = train_data[i]
        if label == biased_class:
            biased_indices.append(i)
        else:
            other_indices.append(i)

    # 2) Sample from each group
    selected_biased = random.sample(biased_indices, n_biased)
    selected_others = random.sample(other_indices, n_other)

    # 3) Combine them and shuffle
    root_indices = selected_biased + selected_others
    random.shuffle(root_indices)

    # 4) Build a Subset for the root dataset
    root_data = Subset(train_data, root_indices)

    if dataset == 'HAR':
        # One client per non-empty sub-dataset of train_data.datasets; each
        # client owns the contiguous index range of its sub-dataset.
        clients_data_indices = []
        offset = 0
        for ds in train_data.datasets:
            length_ds = len(ds)
            if length_ds > 0:
                clients_data_indices.append(list(range(offset, offset + length_ds)))
            offset += length_ds
        num_clients = len(clients_data_indices)
        print("Number of clients (subjects):", num_clients)
    else:
        # Shuffle all indices and give every client an equal-sized slice.
        total_data_size = len(train_data)
        client_data_size = total_data_size // num_clients
        print("client_data_size", client_data_size)
        indices = list(range(total_data_size))
        random.shuffle(indices)
        clients_data_indices = [indices[i * client_data_size:(i + 1) * client_data_size] for i in range(num_clients)]

    global_models = []
    mn, sm, mx = 100, 0, 0
    evaluated_rounds = 0  # number of accuracies accumulated into sm

    run_tag = f'accuracy_{dataset}_{aggregation}_{attack_type}_{n_attackers}_mu{MU_MULTIPLIER}_cd_{cross_device}_collab{collab}'
    file_path = os.path.join(data_root, f'{run_tag}_log.txt')
    last_ten_percent = int(n_round * 0.89)

    with ThreadPoolExecutor(max_workers=THREAD_NUMBER) as executor:  # Adjust max_workers based on your system capabilities
        for epoch in range(n_round):
            # Server trains on root_data each round (its result is used by FLTrust)
            server_params = train_server_model(global_model, root_data, batch_size, criterion, device, optim)

            global_model.train()
            local_models_data_diff = []

            # Keep only the five most recent global parameter vectors
            if len(global_models) > 4:
                del global_models[0]

            futures = [
                executor.submit(
                    train_local_model,
                    client_id,
                    client_indices,
                    global_model,
                    train_data,
                    batch_size,
                    criterion,
                    devices[client_id % len(devices)],  # Round-robin over the available devices
                    optim
                )
                for client_id, client_indices in enumerate(clients_data_indices)
            ]

            # Collect results in submission order so that row i always belongs
            # to client i (completion order would make the attacker rows
            # nondeterministic). Each vector is moved to the main device so the
            # rows can be stacked even when clients trained on different GPUs.
            for future in futures:
                _, local_params = future.result()
                local_models_data_diff.append(local_params.to(device))

            for i in range(torch.cuda.device_count()):
                torch.cuda.set_device(i)
                torch.cuda.empty_cache()
            if num_gpus > 0:
                torch.cuda.set_device(device)
            print(f'For round {epoch}, training done')
            local_models_data = torch.stack(local_models_data_diff).to(device)
            del local_models_data_diff
            gc.collect()

            if attack_type.startswith('XFED'):
                _apply_individual_xfed(local_models_data, n_attackers, attack_type, 'XFED_',
                                       global_models, global_model_data, epoch, log_shapes=True)

            elif attack_type.startswith('VIRAT') and n_attackers > 0:
                local_models_data = virat_min_max(local_models_data, n_attackers, attack_type, epoch=epoch)

            elif attack_type.startswith('LIE') and n_attackers > 0:
                local_models_data = lie_attack(local_models_data, n_attackers, z_values[(num_clients, n_attackers)])

            elif attack_type =='FANG_TR_MEAN' and n_attackers > 0:
                local_models_data = fang_attack_trmean_partial(local_models_data, n_attackers)

            elif attack_type =='FANG_KRUM' and n_attackers > 0:
                attacker_grads = local_models_data[:n_attackers]
                agg_grads = torch.mean(attacker_grads, 0)
                deviation = torch.sign(agg_grads)
                mal_update = get_malicious_updates_fang(attacker_grads, agg_grads, deviation, n_attackers)
                mal_updates = mal_update.unsqueeze(0).repeat(n_attackers, *[1 for _ in mal_update.shape])
                # Install the crafted rows and continue with aggregation
                # (this used to `return`, ending the whole simulation here).
                local_models_data = torch.cat((mal_updates, local_models_data[n_attackers:]), dim=0)

            elif attack_type.startswith('C_XFED') and n_attackers > 0:
                local_models_data = xfed_c(local_models_data, n_attackers, attack_type, len(global_models), global_model_data, global_models, 0)

            elif attack_type.startswith('Hybrid_XFED') and n_attackers > 0:
                # The first (n_attackers - collab) attackers act independently,
                # the remaining ones are handled jointly by xfed_c.
                individual_attacker = n_attackers - collab
                _apply_individual_xfed(local_models_data, individual_attacker, attack_type, 'Hybrid_XFED_',
                                       global_models, global_model_data, epoch)
                local_models_data = xfed_c(local_models_data, n_attackers, attack_type, len(global_models), global_model_data, global_models, collab)

            elif n_attackers > 0:
                raise ValueError("Invalid attack type")

            print(f'For round {epoch}, attack done, Lenght of local_models_data:', len(local_models_data))
            print(f'Global model data device {global_model_data.device}')

            if cross_device == True:
                # Randomly keep 20% of the rows (at least one); sampling row
                # indices avoids converting the whole tensor to Python lists.
                num_clients_to_select = max(1, int(num_clients * (20 / 100)))
                selected_rows = random.sample(range(len(local_models_data)), num_clients_to_select)
                local_models_data = local_models_data[selected_rows]
                print(f'For round {epoch}, cross device done, Lenght of local_models_data:', len(local_models_data))

            # Aggregate model updates
            if aggregation == 'MEAN':
                global_model_data = torch.mean(local_models_data, dim=0)
            elif aggregation == 'MEDIAN':
                global_model_data = torch.median(local_models_data, dim=0)[0]
            elif aggregation == 'KRUM':
                global_model_data, _ = multi_krum_optimized(local_updates=local_models_data)
            elif aggregation == 'TR-MEAN':
                global_model_data = tr_mean(local_models_data)
            elif aggregation == 'CC':
                global_model_data = Clippedclustering(local_models_data)
            elif aggregation == 'SignGuard':
                global_model_data = SignGuard(local_models_data)
            elif aggregation == 'FLAME':
                global_model_data = flame(local_models_data, global_model, device, num_clients)
            elif aggregation == 'DNC':
                global_model_data = dnc_aggregator_torch(
                    local_models_data,  # shape (num_clients, param_dim)
                    num_clients=num_clients
                )
            elif aggregation == 'FLTrust':
                new_update = FLTrust(
                    global_model_data,      # old global param vector
                    local_models_data,      # final param vectors from clients
                    server_params           # final param vector from server root training
                )
                # FLTrust's result is applied as a difference on top of the old
                # global vector. Fold it in here so the stored vector, the NaN
                # check and the history all track the new model (previously
                # global_model_data was never advanced for FLTrust).
                global_model_data = global_model_data + new_update
            elif aggregation == 'FreqFred':
                global_model_data = freqfed_aggregator(local_models_data)
            else:
                raise ValueError("Invalid aggregation method")

            if torch.isnan(global_model_data).any():
                raise ValueError("NaN detected in model aggregation")

            # Write the flat vector back into the model, parameter by parameter
            start_idx = 0
            with torch.no_grad():
                for param in global_model.parameters():
                    param_size = param.numel()
                    param.copy_(global_model_data[start_idx:start_idx + param_size].view(param.shape))
                    start_idx += param_size

            global_models.append(global_model_data)

            print(f'For round {epoch}, aggregation done')
            if epoch >= last_ten_percent or epoch%20 == 0:
                # Evaluate global model
                global_model.eval()
                global_model = global_model.to(device)
                correct = 0
                total = 0
                with torch.no_grad():
                    for images, labels in testloader:
                        images, labels = images.to(device), labels.to(device)
                        outputs = global_model(images)
                        _, predicted = torch.max(outputs.data, 1)
                        total += labels.size(0)
                        correct += (predicted == labels).sum().item()

                accuracy = 100 * correct / total
                print(f'Time {datetime.now()}: Accuracy on round {epoch}, total {num_clients}, attackers {n_attackers}, attack_type {attack_type}, aggregation {aggregation} is: {accuracy:.2f} %')

                # Append accuracy to the log file in the data_root location
                with open(file_path, 'a') as f:
                    f.write(f'Time {datetime.now()}: Accuracy on round {epoch}, dataset {dataset}, total {num_clients}, attackers {n_attackers}, attack_type {attack_type}, aggregation {aggregation} is: {accuracy:.2f} %\n')

                if consider_min_round <= epoch:
                    mn = min(mn, accuracy)
                    sm += accuracy
                    mx = max(mx, accuracy)
                    evaluated_rounds += 1

            global_model = global_model.to('cpu')
            del local_models_data
            torch.cuda.empty_cache()
            gc.collect()

        # Average over the evaluations that were actually accumulated; this
        # equals n_round - consider_min_round only when every round from
        # consider_min_round on was evaluated.
        avg = sm / evaluated_rounds if evaluated_rounds else float('nan')
        print(f'{run_tag}: min {mn} max {mx} avg {avg}')

        # Append the summary to the log file in the data_root location
        with open(file_path, 'a') as f:
            f.write(f'Time {datetime.now()}: {run_tag}: avg {avg}% min {mn}% max {mx}%\n')

    # Final cleanup after training
    del global_model, train_data, testloader, global_models, criterion



Example Execution

In [None]:
# Example execution: run one simulation for every combination of the values
# listed in the loops below. MU_MULTIPLIER is deliberately a module-level loop
# variable because federated_learning reads it as a global.
#
# Other values that have been used in these loops:
#   attack_type: 'Hybrid_XFED_unit_vec', 'Hybrid_XFED_sign', 'XFED_unit_vec', 'XFED_sign',
#                'VIRAT_unit_vec', 'C_XFED_sign', 'LIE', 'FANG_TR_MEAN'
#   agg:         'FLAME', 'MEAN', 'MEDIAN', 'KRUM', 'TR-MEAN', 'SignGuard', 'CC'
#
# Configurations used for other datasets (kept for reference):
# federated_learning(num_clients=200, n_attackers=attackers, aggregation=agg, n_round=300, dataset='EMNIST', attack_type=attack_type, batch_size=256, optim="SGD", cross_device=False, collab=cb, consider_min_round=240)
# federated_learning(num_clients=40, n_attackers=attackers, aggregation=agg, n_round=250, dataset='FashionMNIST', attack_type=attack_type, batch_size=256, optim="SGD", cross_device=False, collab=0, consider_min_round=225)
# federated_learning(num_clients=40, n_attackers=attackers, aggregation=agg, n_round=250, dataset='FashionMNIST_3DNN', attack_type=attack_type, batch_size=256, optim="SGD", cross_device=False, collab=0, consider_min_round=225)
# federated_learning(num_clients=50, n_attackers=attackers, aggregation=agg, n_round=255, dataset='CIFAR10', attack_type=attack_type, batch_size=250)
# federated_learning(num_clients=100, n_attackers=attackers, aggregation=agg, n_round=1000, dataset='SVHN', attack_type=attack_type, batch_size=64, optim="SGD")
# federated_learning(num_clients=100, n_attackers=attackers, aggregation=agg, n_round=275, dataset='MNIST', attack_type=attack_type, batch_size=256, optim="SGD", cross_device=False, collab=cb, consider_min_round=248)
# federated_learning(num_clients=100, n_attackers=attackers, aggregation=agg, n_round=500, dataset='PURCHASE', attack_type=attack_type, batch_size=128, optim="SGD", cross_device=False, collab=cb, consider_min_round=450)
# federated_learning(num_clients=40, n_attackers=attackers, aggregation=agg, n_round=300, dataset='CIFAR100', attack_type=attack_type, batch_size=250, optim="Adam")
# federated_learning(num_clients=200, n_attackers=attackers, aggregation=agg, n_round=50, dataset='SYMBIPREDICT', attack_type=attack_type, batch_size=10, optim="SGD", collab=0)
# federated_learning(num_clients=200, n_attackers=attackers, aggregation=agg, n_round=50, dataset='SYMBIPREDICT', attack_type=attack_type, batch_size=10, optim="Adam", cross_device=False, collab=0)
#
# Optional CUDA memory profiling around a run:
#   torch.cuda.memory._record_memory_history()
#   torch.cuda.memory._dump_snapshot("cifar10.pickle")
for MU_MULTIPLIER in [2.5, 3, 3.5, 4]:
    for cb in [0]:
        for attack_type in ['XFED_unit_vec']:
            for agg in ['MEDIAN']:
                for attackers in [8]:
                    federated_learning(
                        num_clients=30,
                        n_attackers=attackers,
                        aggregation=agg,
                        n_round=1000,
                        dataset='HAR',
                        attack_type=attack_type,
                        batch_size=32,
                        optim="SGD",
                        cross_device=False,
                        collab=0,
                        consider_min_round=900,
                    )





Subjects found: [ 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
 25 26 27 28 29 30]
Number of clients (subjects): 30
For round 0, training done
local_models_data[local_machine] torch.Size([3372]) delta torch.Size([3372])
local_models_data[local_machine] torch.Size([3372]) delta torch.Size([3372])
local_models_data[local_machine] torch.Size([3372]) delta torch.Size([3372])
local_models_data[local_machine] torch.Size([3372]) delta torch.Size([3372])
local_models_data[local_machine] torch.Size([3372]) delta torch.Size([3372])
local_models_data[local_machine] torch.Size([3372]) delta torch.Size([3372])
local_models_data[local_machine] torch.Size([3372]) delta torch.Size([3372])
local_models_data[local_machine] torch.Size([3372]) delta torch.Size([3372])
For round 0, attack done, Lenght of local_models_data: 30
Global model data device cuda:0
For round 0, aggregation done
Time 2025-04-15 06:55:50.537235: Accuracy on round 0, total 30, attackers 8, attack_type XFED_u