In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

In [2]:
# Import necessary libraries
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, random_split
import numpy as np
import torch.nn as nn
from torchvision import models
import torch.optim as optim

# Check for CUDA GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

# Transformations
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load CIFAR-100 dataset
dataset_train = datasets.CIFAR100(root='../data', train=True, download=True, transform=transform)
dataset_test = datasets.CIFAR100(root='../data', train=False, download=True, transform=transform)

Using device: cuda
Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Define superclasses and subclasses
superclasses = {
    0: [4, 30, 55, 72, 95],  # aquatic mammals
    1: [1, 32, 67, 73, 91],  # fish
    2: [54, 62, 70, 82, 92],  # flowers
    3: [9, 10, 16, 28, 61],  # food containers
    4: [0, 51, 53, 57, 83],  # fruit and vegetables
    5: [22, 39, 40, 86, 87],  # household electrical devices
    6: [5, 20, 25, 84, 94],  # household furniture
    7: [6, 7, 14, 18, 24],  # insects
    8: [3, 42, 43, 88, 97],  # large carnivores
    9: [12, 17, 37, 68, 76],  # large man-made outdoor things
    10: [23, 33, 49, 60, 71],  # large natural outdoor scenes
    11: [15, 19, 21, 31, 38],  # medium-sized mammals
    12: [34, 63, 64, 66, 75],  # non-insect invertebrates
    13: [26, 45, 77, 79, 99],  # people
    14: [2, 11, 35, 46, 98],  # reptiles
    15: [27, 29, 44, 78, 93],  # small mammals
    16: [36, 50, 65, 74, 80],  # trees
    17: [8, 13, 48, 58, 90],  # vehicles 1
    18: [41, 66, 69, 81, 89],  # vehicles 2
    19: [47, 50, 52, 56, 59],  # household furniture
}

# Add missing subclasses
superclasses[16].append(96)  # Add "willow tree" to the trees superclass
superclasses[3].append(85)   # Add "plate" to the food containers superclass

# Function to map subclass to its corresponding superclass
def get_superclass(subclass, superclasses):
    for superclass in superclasses.keys():
        subclasses=superclasses[superclass]
        if subclass in subclasses:
            return superclass
    return None

### preprocessing for 4 clients

In [4]:
# from torch.utils.data import Subset, DataLoader, random_split
# import numpy as np

# # Function to randomly split dataset by superclass for 4 clients
# def filter_dataset_randomly_by_superclass_4_clients(dataset, superclasses):
#     client_indices = [[] for _ in range(4)]  # Create an empty list for each client (4 clients)

#     # Shuffle and split images randomly for each superclass
#     for subclasses in superclasses.values():
#         superclass_indices = []
        
#         # Gather all images belonging to this superclass
#         for idx, (data, target) in enumerate(dataset):
#             if target in subclasses:
#                 superclass_indices.append(idx)
        
#         # If no images found for this superclass, continue to the next
#         if not superclass_indices:
#             continue
        
#         # Shuffle the indices for randomness
#         np.random.shuffle(superclass_indices)
        
#         # Split the indices between the 4 clients
#         split_size = len(superclass_indices) // 4
#         for i in range(4):
#             start = i * split_size
#             end = (i + 1) * split_size if i < 3 else len(superclass_indices)  # Ensure last client gets the remaining data
#             client_indices[i].extend(superclass_indices[start:end])
    
#     # Return 4 subsets: one for each client
#     return [Subset(dataset, indices) for indices in client_indices]

# def get_data_loaders(batch_size, num_workers, dataset_train, dataset_test, superclasses):
#     # Step 1: Split the dataset for 4 clients
#     client_datasets = filter_dataset_randomly_by_superclass_4_clients(dataset_train, superclasses)

#     # Step 2: Create validation splits for each client
#     val_split = 0.1  # 10% for validation

#     # Create training and validation splits for each client
#     train_val_splits = [random_split(client_dataset, [int(len(client_dataset) * (1 - val_split)), int(len(client_dataset) * val_split)]) for client_dataset in client_datasets]

#     # Separate train and validation datasets
#     train_datasets = [split[0] for split in train_val_splits]
#     val_datasets = [split[1] for split in train_val_splits]

#     # Step 3: Create DataLoaders for each client
#     client_loaders = [DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) for train_dataset in train_datasets]
#     val_loaders = [DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) for val_dataset in val_datasets]
#     test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

#     return client_loaders, val_loaders, test_loader

# # Parameters
# batch_size = 256
# num_workers = 4

# # Call the function to initialize DataLoaders for 4 clients
# client_loaders, val_loaders, test_loader = get_data_loaders(batch_size, num_workers, dataset_train, dataset_test, superclasses)

### preprocessing for 2 clients

In [None]:
# from torch.utils.data import Subset, DataLoader, random_split
# import numpy as np

# # Function to randomly split dataset by superclass for 2 clients
# def filter_dataset_randomly_by_superclass_2_clients(dataset, superclasses):
#     client_indices = [[] for _ in range(2)]
    
#     for subclasses in superclasses.values():
#         superclass_indices = []
        
#         # Collect indices for samples in the current superclass
#         for idx, (data, target) in enumerate(dataset):
#             if target in subclasses:
#                 # Map to superclass and assign label
#                 superclass_label = get_superclass(target, superclasses)
#                 if superclass_label is not None:
#                     dataset.targets[idx] = superclass_label  # Remap the label
#                     superclass_indices.append(idx)
        
#         # Split indices between the 2 clients
#         np.random.shuffle(superclass_indices)
#         split_size = len(superclass_indices) // 2
#         for i in range(2):
#             client_indices[i].extend(superclass_indices[i * split_size:(i + 1) * split_size])
    
#     return [Subset(dataset, indices) for indices in client_indices]


# def get_data_loaders(batch_size, num_workers, dataset_train, dataset_test, superclasses):
#     # Step 1: Split the dataset for 2 clients
#     client_datasets = filter_dataset_randomly_by_superclass_2_clients(dataset_train, superclasses)

#     # Step 2: Create validation splits for each client
#     val_split = 0.1  # 10% for validation

#     # Create training and validation splits for each client
#     train_val_splits = [random_split(client_dataset, [int(len(client_dataset) * (1 - val_split)), int(len(client_dataset) * val_split)]) for client_dataset in client_datasets]

#     # Separate train and validation datasets
#     train_datasets = [split[0] for split in train_val_splits]
#     val_datasets = [split[1] for split in train_val_splits]

#     # Step 3: Create DataLoaders for each client
#     client_loaders = [DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) for train_dataset in train_datasets]
#     val_loaders = [DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) for val_dataset in val_datasets]
#     test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

#     return client_loaders, val_loaders, test_loader

# # Parameters
# batch_size = 256
# num_workers = 4

# # Call the function to initialize DataLoaders for 2 clients
# client_loaders, val_loaders, test_loader = get_data_loaders(batch_size, num_workers, dataset_train, dataset_test, superclasses)

In [4]:
from torch.utils.data import Subset, DataLoader, random_split
import numpy as np

# Updated function to map CIFAR-100 subclasses to superclasses consistently across all data splits
def filter_dataset_by_superclass(dataset, superclasses):
    # Remap labels to superclasses
    for idx, (_, target) in enumerate(dataset):
        # Map CIFAR-100 label to its superclass
        superclass_label = get_superclass(target, superclasses)
        if superclass_label is not None:
            dataset.targets[idx] = superclass_label  # Update target label to superclass

    # Randomly split dataset indices for two clients
    client_indices = [[] for _ in range(2)]
    superclass_indices = [[] for _ in range(len(superclasses))]  # Separate by superclass for balanced splits

    for superclass, subclasses in superclasses.items():
        for idx, target in enumerate(dataset.targets):
            if target == superclass:
                superclass_indices[superclass - 1].append(idx)  # Group indices by superclass

    # Split each superclass evenly between the two clients
    for indices in superclass_indices:
        np.random.shuffle(indices)
        half = len(indices) // 2
        client_indices[0].extend(indices[:half])
        client_indices[1].extend(indices[half:])

    return [Subset(dataset, indices) for indices in client_indices]

# Updated data loader function to apply consistent label mapping across train, validation, and test sets
def get_data_loaders(batch_size, num_workers, dataset_train, dataset_test, superclasses):
    # Step 1: Map dataset_train labels to superclasses and split for 2 clients
    client_datasets = filter_dataset_by_superclass(dataset_train, superclasses)

    # Step 2: Create validation splits for each client
    val_split = 0.1  # 10% for validation

    train_val_splits = [
        random_split(client_dataset, [int(len(client_dataset) * (1 - val_split)), int(len(client_dataset) * val_split)])
        for client_dataset in client_datasets
    ]

    train_datasets = [split[0] for split in train_val_splits]
    val_datasets = [split[1] for split in train_val_splits]

    # Step 3: Remap test set labels to superclasses
    filter_dataset_by_superclass(dataset_test, superclasses)

    # Step 4: Create DataLoaders
    client_loaders = [
        DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        for train_dataset in train_datasets
    ]
    val_loaders = [
        DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        for val_dataset in val_datasets
    ]
    test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return client_loaders, val_loaders, test_loader

In [5]:
# Parameters for data loaders
batch_size = 256
num_workers = 4

# Run get_data_loaders to create client loaders, validation loaders, and test loader
client_loaders, val_loaders, test_loader = get_data_loaders(batch_size, num_workers, dataset_train, dataset_test, superclasses)

In [6]:
# Check label mapping and distribution for each client
from collections import Counter

# Verify client datasets
for i, client_loader in enumerate(client_loaders):
    print(f"\nClient {i+1} label distribution in train dataset:")
    all_labels = []
    for inputs, labels in client_loader:
        all_labels.extend(labels.tolist())
    label_counts = Counter(all_labels)
    print(sorted(label_counts.items()))  # This should show counts for labels [0, 19]

# Check label mapping and distribution for validation datasets
for i, val_loader in enumerate(val_loaders):
    print(f"\nClient {i+1} label distribution in validation dataset:")
    all_val_labels = []
    for inputs, labels in val_loader:
        all_val_labels.extend(labels.tolist())
    val_label_counts = Counter(all_val_labels)
    print(sorted(val_label_counts.items()))  # This should show counts for labels [0, 19]

# Check label mapping and distribution in test dataset
print("\nTest dataset label distribution:")
all_test_labels = []
for inputs, labels in test_loader:
    all_test_labels.extend(labels.tolist())
test_label_counts = Counter(all_test_labels)
print(sorted(test_label_counts.items()))  # This should show counts for labels [0, 19]


Client 1 label distribution in train dataset:
[(0, 1143), (1, 1125), (2, 1116), (3, 1350), (4, 1126), (5, 1119), (6, 1124), (7, 1125), (8, 1127), (9, 1113), (10, 1136), (11, 1149), (12, 1120), (13, 1128), (14, 1108), (15, 1139), (16, 1342), (17, 1126), (18, 899), (19, 885)]

Client 2 label distribution in train dataset:
[(0, 1135), (1, 1132), (2, 1120), (3, 1351), (4, 1117), (5, 1115), (6, 1131), (7, 1133), (8, 1115), (9, 1107), (10, 1127), (11, 1141), (12, 1132), (13, 1130), (14, 1128), (15, 1131), (16, 1340), (17, 1125), (18, 893), (19, 897)]

Client 1 label distribution in validation dataset:
[(0, 107), (1, 125), (2, 134), (3, 150), (4, 124), (5, 131), (6, 126), (7, 125), (8, 123), (9, 137), (10, 114), (11, 101), (12, 130), (13, 122), (14, 142), (15, 111), (16, 158), (17, 124), (18, 101), (19, 115)]

Client 2 label distribution in validation dataset:
[(0, 115), (1, 118), (2, 130), (3, 149), (4, 133), (5, 135), (6, 119), (7, 117), (8, 135), (9, 143), (10, 123), (11, 109), (12, 118),

###  preprocessing for 8 clients

In [4]:
# from torch.utils.data import Subset, DataLoader, random_split
# import numpy as np

# # Function to randomly split dataset by superclass for 8 clients
# def filter_dataset_randomly_by_superclass_8_clients(dataset, superclasses):
#     client_indices = [[] for _ in range(8)]  # Create an empty list for each client (8 clients)

#     # Shuffle and split images randomly for each superclass
#     for subclasses in superclasses.values():
#         superclass_indices = []
        
#         # Gather all images belonging to this superclass
#         for idx, (data, target) in enumerate(dataset):
#             if target in subclasses:
#                 superclass_indices.append(idx)
        
#         # If no images found for this superclass, continue to the next
#         if not superclass_indices:
#             continue
        
#         # Shuffle the indices for randomness
#         np.random.shuffle(superclass_indices)
        
#         # Split the indices between the 8 clients
#         split_size = len(superclass_indices) // 8
#         for i in range(8):
#             start = i * split_size
#             end = (i + 1) * split_size if i < 7 else len(superclass_indices)  # Ensure last client gets the remaining data
#             client_indices[i].extend(superclass_indices[start:end])
    
#     # Return 8 subsets: one for each client
#     return [Subset(dataset, indices) for indices in client_indices]

# def get_data_loaders(batch_size, num_workers, dataset_train, dataset_test, superclasses):
#     # Step 1: Split the dataset for 8 clients
#     client_datasets = filter_dataset_randomly_by_superclass_8_clients(dataset_train, superclasses)

#     # Step 2: Create validation splits for each client
#     val_split = 0.1  # 10% for validation

#     # Create training and validation splits for each client
#     train_val_splits = [random_split(client_dataset, [int(len(client_dataset) * (1 - val_split)), int(len(client_dataset) * val_split)]) for client_dataset in client_datasets]

#     # Separate train and validation datasets
#     train_datasets = [split[0] for split in train_val_splits]
#     val_datasets = [split[1] for split in train_val_splits]

#     # Step 3: Create DataLoaders for each client
#     client_loaders = [DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) for train_dataset in train_datasets]
#     val_loaders = [DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) for val_dataset in val_datasets]
#     test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

#     return client_loaders, val_loaders, test_loader

# # Parameters
# batch_size = 256
# num_workers = 4

# # Call the function to initialize DataLoaders for 8 clients
# client_loaders, val_loaders, test_loader = get_data_loaders(batch_size, num_workers, dataset_train, dataset_test, superclasses)


###  preprocessing for 16 clients

In [4]:
# from torch.utils.data import Subset, DataLoader, random_split
# import numpy as np

# # Function to randomly split dataset by superclass for 16 clients
# def filter_dataset_randomly_by_superclass_16_clients(dataset, superclasses):
#     client_indices = [[] for _ in range(16)]  # Create an empty list for each client (16 clients)

#     # Shuffle and split images randomly for each superclass
#     for subclasses in superclasses.values():
#         superclass_indices = []
        
#         # Gather all images belonging to this superclass
#         for idx, (data, target) in enumerate(dataset):
#             if target in subclasses:
#                 superclass_indices.append(idx)
        
#         # If no images found for this superclass, continue to the next
#         if not superclass_indices:
#             continue
        
#         # Shuffle the indices for randomness
#         np.random.shuffle(superclass_indices)
        
#         # Split the indices between the 16 clients
#         split_size = len(superclass_indices) // 16
#         for i in range(16):
#             start = i * split_size
#             end = (i + 1) * split_size if i < 15 else len(superclass_indices)  # Ensure last client gets the remaining data
#             client_indices[i].extend(superclass_indices[start:end])
    
#     # Return 16 subsets: one for each client
#     return [Subset(dataset, indices) for indices in client_indices]

# def get_data_loaders(batch_size, num_workers, dataset_train, dataset_test, superclasses):
#     # Step 1: Split the dataset for 16 clients
#     client_datasets = filter_dataset_randomly_by_superclass_16_clients(dataset_train, superclasses)

#     # Step 2: Create validation splits for each client
#     val_split = 0.1  # 10% for validation

#     # Create training and validation splits for each client
#     train_val_splits = [random_split(client_dataset, [int(len(client_dataset) * (1 - val_split)), int(len(client_dataset) * val_split)]) for client_dataset in client_datasets]

#     # Separate train and validation datasets
#     train_datasets = [split[0] for split in train_val_splits]
#     val_datasets = [split[1] for split in train_val_splits]

#     # Step 3: Create DataLoaders for each client
#     client_loaders = [DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) for train_dataset in train_datasets]
#     val_loaders = [DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) for val_dataset in val_datasets]
#     test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

#     return client_loaders, val_loaders, test_loader

# # Parameters
# batch_size = 256
# num_workers = 4

# # Call the function to initialize DataLoaders for 16 clients
# client_loaders, val_loaders, test_loader = get_data_loaders(batch_size, num_workers, dataset_train, dataset_test, superclasses)



# # from torch.utils.data import Subset, DataLoader, random_split
# # import numpy as np

# # # Function to randomly split dataset by superclass for 16 clients
# # def filter_dataset_randomly_by_superclass_16_clients(dataset, superclasses):
# #     client_indices = [[] for _ in range(16)]  # Create an empty list for each client (16 clients)

# #     # Shuffle and split images randomly for each superclass
# #     for subclasses in superclasses.values():
# #         superclass_indices = []
        
# #         # Gather all images belonging to this superclass
# #         for idx, (data, target) in enumerate(dataset):
# #             if target in subclasses:
# #                 superclass_indices.append(idx)
        
# #         # If no images found for this superclass, continue to the next
# #         if not superclass_indices:
# #             continue
        
# #         # Shuffle the indices for randomness
# #         np.random.shuffle(superclass_indices)
        
# #         # Split the indices between the 16 clients
# #         split_size = len(superclass_indices) // 16
# #         for i in range(16):
# #             start = i * split_size
# #             end = (i + 1) * split_size if i < 15 else len(superclass_indices)  # Ensure last client gets the remaining data
# #             client_indices[i].extend(superclass_indices[start:end])
    
# #     # Return 16 subsets: one for each client
# #     return [Subset(dataset, indices) for indices in client_indices]

# # def get_data_loaders(batch_size, num_workers, dataset_train, dataset_test, superclasses):
# #     # Step 1: Split the dataset for 16 clients
# #     client_datasets = filter_dataset_randomly_by_superclass_16_clients(dataset_train, superclasses)

# #     # Step 2: Create validation splits for each client
# #     val_split = 0.1  # 10% for validation

# #     # Create training and validation splits for each client
# #     train_val_splits = [random_split(client_dataset, [int(len(client_dataset) * (1 - val_split)), int(len(client_dataset) * val_split)]) for client_dataset in client_datasets]

# #     # Separate train and validation datasets
# #     train_datasets = [split[0] for split in train_val_splits]
# #     val_datasets = [split[1] for split in train_val_splits]

# #     # Step 3: Create DataLoaders for each client
# #     client_loaders = [DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) for train_dataset in train_datasets]
# #     val_loaders = [DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) for val_dataset in val_datasets]
# #     test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# #     return client_loaders, val_loaders, test_loader

# # # Parameters
# # batch_size = 256
# # num_workers = 4

# # # Call the function to initialize DataLoaders for 16 clients
# # client_loaders, val_loaders, test_loader = get_data_loaders(batch_size, num_workers, dataset_train, dataset_test, superclasses)

In [12]:
# Check the number of samples in each client's train and validation loaders
client_loader_sizes = [len(loader.dataset) for loader in client_loaders]
val_loader_sizes = [len(loader.dataset) for loader in val_loaders]

print("Client Train Loader Sizes:", client_loader_sizes)
print("Client Validation Loader Sizes:", val_loader_sizes)

# Check the number of samples in the test loader
print("Test Loader Size:", len(test_loader.dataset))

Client Train Loader Sizes: [22500, 22500]
Client Validation Loader Sizes: [2500, 2500]
Test Loader Size: 10000


In [13]:
from collections import Counter

def check_superclass_coverage(client_loaders, superclasses):
    for i, loader in enumerate(client_loaders):
        superclass_counts = Counter()

        for _, labels in loader:
            for label in labels:
                superclass_counts[label.item()] += 1

        print(f"Client {i + 1} superclass distribution:")
        for superclass in range(len(superclasses)):
            count = superclass_counts.get(superclass, 0)
            print(f"Superclass {superclass}: {count} samples")
            if count == 0:
                print(f"Warning: Client {i + 1} is missing samples for superclass {superclass}")
        print("\n")

# Checking superclass distribution for each client across train, validation, and test sets
print("Checking superclass coverage for each client...\n")

def print_superclass_distribution(client_datasets, stage):
    for i, loader in enumerate(client_datasets):
        superclass_counts = Counter()
        for _, targets in loader:
            superclass_counts.update(targets.tolist())

        print(f"\nClient {i + 1} {stage} superclass distribution:")
        for superclass, count in sorted(superclass_counts.items()):
            print(f"Superclass {superclass}: {count} samples")
        print("\n")

# Check distribution in train datasets
print_superclass_distribution(client_loaders, "train")

# Check distribution in validation datasets
print_superclass_distribution(val_loaders, "validation")

# Check distribution in test dataset (since test_loader is shared, it only needs one output)
print("Overall test superclass distribution:")
test_superclass_counts = Counter()
for _, targets in test_loader:
    test_superclass_counts.update(targets.tolist())

for superclass, count in sorted(test_superclass_counts.items()):
    print(f"Superclass {superclass}: {count} samples")

Checking superclass coverage for each client...


Client 1 train superclass distribution:
Superclass 0: 1143 samples
Superclass 1: 1125 samples
Superclass 2: 1116 samples
Superclass 3: 1350 samples
Superclass 4: 1126 samples
Superclass 5: 1119 samples
Superclass 6: 1124 samples
Superclass 7: 1125 samples
Superclass 8: 1127 samples
Superclass 9: 1113 samples
Superclass 10: 1136 samples
Superclass 11: 1149 samples
Superclass 12: 1120 samples
Superclass 13: 1128 samples
Superclass 14: 1108 samples
Superclass 15: 1139 samples
Superclass 16: 1342 samples
Superclass 17: 1126 samples
Superclass 18: 899 samples
Superclass 19: 885 samples



Client 2 train superclass distribution:
Superclass 0: 1135 samples
Superclass 1: 1132 samples
Superclass 2: 1120 samples
Superclass 3: 1351 samples
Superclass 4: 1117 samples
Superclass 5: 1115 samples
Superclass 6: 1131 samples
Superclass 7: 1133 samples
Superclass 8: 1115 samples
Superclass 9: 1107 samples
Superclass 10: 1127 samples
Superclass 11: 1141 s

In [14]:
import pickle

# Define the number of clients
num_clients = 2  # Change this to 4, 8, or 16 for other configurations

# Define the save path with the number of clients in the filename
save_path = f'/raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/models/client_datasets_{num_clients}_clients.pkl'

# Save split datasets
with open(save_path, 'wb') as f:
    pickle.dump([client_loaders, val_loaders, test_loader], f)

print(f"Datasets for {num_clients} clients saved successfully to {save_path}")

Datasets for 2 clients saved successfully to /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/models/client_datasets_2_clients.pkl


In [15]:
# Define the number of clients you want to load
num_clients = 2  # Adjust to the desired configuration (2, 4, 8, or 16 clients)

# Load the saved dataset splits
load_path = f'/raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/models/client_datasets_2_clients.pkl' # When loading, match the num_clients variable
with open(load_path, 'rb') as f:
    client_loaders, val_loaders, test_loader = pickle.load(f)

print(f"Datasets for {num_clients} clients loaded successfully from {load_path}")

Datasets for 2 clients loaded successfully from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/models/client_datasets_2_clients.pkl


###  Model Preparation

In [10]:
def prepare_model(num_classes=20, use_dropout=False, dropout_prob=0.2):
    """Load a pre-trained Resnet18 model and modify it for CIFAR100 with optional dropout."""
    model = models.resnet18(pretrained=True)
    num_ftrs = model.fc.in_features
    if use_dropout:
        model.fc = nn.Sequential(
            nn.Dropout(p=dropout_prob),
            nn.Linear(num_ftrs, num_classes)
        )
    else:
        model.fc = nn.Linear(num_ftrs, num_classes)
    return model

### Early Stopping
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

### 4.(a) Federated Training Functions (Federated Averaging (FedAvg))

In [None]:
# import torch
# import torch.optim as optim

# def train_client(model, train_loader, criterion, optimizer, epochs=1):
#     model.to(device)  
#     model.train()
#     for _ in range(epochs):
#         for inputs, labels in train_loader:
#             inputs, labels = inputs.to(device), labels.to(device)  
#             optimizer.zero_grad()
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
#     return model.state_dict()

# def federated_averaging(state_dicts):
#     avg_state_dict = {}
#     for key in state_dicts[0].keys():
#         avg_state_dict[key] = sum(state_dict[key] for state_dict in state_dicts) / len(state_dicts)
#     return avg_state_dict

# def train_federated_model(client_loaders, val_loaders, test_loader, num_clients, num_epochs, learning_rate=0.0001, patience=5, min_delta=0):
#     model = prepare_model(num_classes=20).to(device)
#     criterion = nn.CrossEntropyLoss()
#     early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
    
#     for round in range(num_epochs):
#         print(f"Starting federated learning round {round+1}/{num_epochs}...")
#         state_dicts = []
#         for i, client_loader in enumerate(client_loaders):
#             print(f"Training model for client {i+1}...")
#             client_model = prepare_model().to(device)
#             client_model.load_state_dict(model.state_dict())
#             optimizer = optim.Adam(client_model.parameters(), lr=learning_rate)
#             client_state_dict = train_client(client_model, client_loader, criterion, optimizer)
#             state_dicts.append(client_state_dict)

#         avg_state_dict = federated_averaging(state_dicts)
#         model.load_state_dict(avg_state_dict)
#         model.to(device)
        
#         # Validation phase (handling multiple validation loaders)
#         print("Validating model...")
#         model.eval()
#         val_running_loss = 0.0
#         val_correct = 0
#         val_total = 0
#         with torch.no_grad():
#             for val_loader in val_loaders:
#                 for inputs, labels in val_loader:
#                     inputs, labels = inputs.to(device), labels.to(device)
#                     outputs = model(inputs)
#                     loss = criterion(outputs, labels)
#                     val_running_loss += loss.item()
#                     _, predicted = torch.max(outputs.data, 1)
#                     val_total += labels.size(0)
#                     val_correct += (predicted == labels).sum().item()

#         val_loss = val_running_loss / sum(len(loader) for loader in val_loaders)
#         val_accuracy = 100 * val_correct / val_total

#         print(f'Federated Round {round+1}/{num_epochs}, '
#               f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

#         # Early stopping
#         early_stopping(val_loss)
#         if early_stopping.early_stop:
#             print("Early stopping")
#             break

#     # Test phase
#     print("Testing model...")
#     model.eval()
#     test_running_loss = 0.0
#     test_correct = 0
#     test_total = 0
#     with torch.no_grad():
#         for inputs, labels in test_loader:
#             inputs, labels = inputs.to(device), labels.to(device)
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
#             test_running_loss += loss.item()
#             _, predicted = torch.max(outputs.data, 1)
#             test_total += labels.size(0)
#             test_correct += (predicted == labels).sum().item()

#     test_loss = test_running_loss / len(test_loader)
#     test_accuracy = 100 * test_correct / test_total

#     print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

    
#     # Optionally save the final model
#     num_clients = 2  # Adjust this for 2, 4, 8, or 16 clients as needed
#     model_save_path = f'/raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/models/final_federated_model_{num_clients}_clients.pth'
#     torch.save(model.state_dict(), model_save_path)
#     print(f"Model saved to {model_save_path}")
    
#     return model, val_loss, val_accuracy, test_loss, test_accuracy

### 4.(b) Federated Training Functions (Adaptive Federated Optimization(AdaptFedOpt))

In [None]:
# import torch
# import torch.optim as optim

# def train_client(model, train_loader, criterion, optimizer, epochs=1):
#     model.to(device)  
#     model.train()
#     for _ in range(epochs):
#         for inputs, labels in train_loader:
#             inputs, labels = inputs.to(device), labels.to(device)  
#             optimizer.zero_grad()
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
#     return model.state_dict()

# def federated_averaging(state_dicts, global_model, beta=0.9):
#     avg_state_dict = global_model.state_dict()
#     for key in avg_state_dict.keys():
#         if avg_state_dict[key].dtype == torch.long:
#             avg_state_dict[key] = torch.zeros_like(avg_state_dict[key], dtype=torch.float32)
#             for state_dict in state_dicts:
#                 avg_state_dict[key] += state_dict[key].float() / len(state_dicts)
#             avg_state_dict[key] = avg_state_dict[key].long()  # Convert back to long if necessary
#         else:
#             avg_state_dict[key] = torch.zeros_like(avg_state_dict[key])
#             for state_dict in state_dicts:
#                 avg_state_dict[key] += state_dict[key] / len(state_dicts)
#             avg_state_dict[key] = beta * avg_state_dict[key] + (1 - beta) * global_model.state_dict()[key]
#     return avg_state_dict

# def train_federated_model(client_loaders, val_loaders, test_loader, num_clients, num_epochs, learning_rate=0.0001, patience=5, min_delta=0):
#     model = prepare_model(num_classes=20).to(device)
#     criterion = nn.CrossEntropyLoss()
#     early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
    
#     # Initialize server optimizer for global model
#     # Use the same learning rate as the clients for the server-side optimizer
#     server_optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)



#     for round in range(num_epochs):
#         print(f"Starting federated learning round {round+1}/{num_epochs}...")
#         state_dicts = []
#         for i, client_loader in enumerate(client_loaders):
#             print(f"Training model for client {i+1}...")
#             client_model = prepare_model().to(device)
#             client_model.load_state_dict(model.state_dict())
#             optimizer = optim.Adam(client_model.parameters(), lr=learning_rate)
#             client_state_dict = train_client(client_model, client_loader, criterion, optimizer)
#             state_dicts.append(client_state_dict)

#         avg_state_dict = federated_averaging(state_dicts, model, beta=0.9)
#         model.load_state_dict(avg_state_dict)

#         # Server-side optimization step
#         server_optimizer.zero_grad()
#         server_optimizer.step()  # Apply server optimizer to improve global model
        

#         # Validation phase (handling multiple validation loaders)
#         print("Validating model...")
#         model.eval()
#         val_running_loss = 0.0
#         val_correct = 0
#         val_total = 0
#         with torch.no_grad():
#             for val_loader in val_loaders:  # Iterate over multiple validation loaders
#                 for inputs, labels in val_loader:  # Iterate over batches within each loader
#                     inputs, labels = inputs.to(device), labels.to(device)
#                     outputs = model(inputs)
#                     loss = criterion(outputs, labels)
#                     val_running_loss += loss.item()
#                     _, predicted = torch.max(outputs.data, 1)
#                     val_total += labels.size(0)
#                     val_correct += (predicted == labels).sum().item()

#         val_loss = val_running_loss / sum(len(loader) for loader in val_loaders)
#         val_accuracy = 100 * val_correct / val_total

#         print(f'Federated Round {round+1}/{num_epochs}, '
#             f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')


#         # Early stopping
#         early_stopping(val_loss)
#         if early_stopping.early_stop:
#             print("Early stopping")
#             break

#     # Test phase
#     print("Testing model...")
#     model.eval()
#     test_running_loss = 0.0
#     test_correct = 0
#     test_total = 0
#     with torch.no_grad():
#         for inputs, labels in test_loader:
#             inputs, labels = inputs.to(device), labels.to(device)
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
#             test_running_loss += loss.item()
#             _, predicted = torch.max(outputs.data, 1)
#             test_total += labels.size(0)
#             test_correct += (predicted == labels).sum().item()

#     test_loss = test_running_loss / len(test_loader)
#     test_accuracy = 100 * test_correct / test_total

#     print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

#     return model, val_loss, val_accuracy, test_loss, test_accuracy

### 4.(c) Federated Training Functions (FedProx (Federated Proximal))

In [None]:
# import torch
# import torch.optim as optim

# def train_client(model, train_loader, criterion, optimizer, global_model, mu=0.01, epochs=1):
#     model.to(device)
#     model.train()
#     global_weights = global_model.state_dict()  # Get the global model's weights
    
#     for _ in range(epochs):
#         for inputs, labels in train_loader:
#             inputs, labels = inputs.to(device), labels.to(device)
#             optimizer.zero_grad()
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
            
#             # Add the proximal term (FedProx)
#             proximal_term = 0.0
#             for param, global_param in zip(model.parameters(), global_model.parameters()):
#                 proximal_term += (param - global_param).norm(2)
#             loss += (mu / 2) * proximal_term
            
#             loss.backward()
#             optimizer.step()
    
#     return model.state_dict()

# def federated_averaging(state_dicts):
#     avg_state_dict = {}
#     for key in state_dicts[0].keys():
#         avg_state_dict[key] = sum(state_dict[key] for state_dict in state_dicts) / len(state_dicts)
#     return avg_state_dict


# def train_federated_model(client_loaders, val_loaders, test_loader, num_clients, num_epochs, learning_rate=0.0001, patience=5, min_delta=0, mu=0.01):
#     model = prepare_model(num_classes=20).to(device)
#     criterion = nn.CrossEntropyLoss()
#     early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
    
#     for round in range(num_epochs):
#         print(f"Starting federated learning round {round+1}/{num_epochs}...")
#         state_dicts = []
#         for i, client_loader in enumerate(client_loaders):
#             print(f"Training model for client {i+1}...")
#             client_model = prepare_model().to(device)
#             client_model.load_state_dict(model.state_dict())
#             optimizer = optim.Adam(client_model.parameters(), lr=learning_rate)
#             client_state_dict = train_client(client_model, client_loader, criterion, optimizer, model, mu)
#             state_dicts.append(client_state_dict)
        
#         avg_state_dict = federated_averaging(state_dicts)
#         model.load_state_dict(avg_state_dict)
#         model.to(device)
        
#         # Validation phase (same as before)
#         print("Validating model...")
#         model.eval()
#         val_running_loss = 0.0
#         val_correct = 0
#         val_total = 0
#         with torch.no_grad():
#             for val_loader in val_loaders:
#                 for inputs, labels in val_loader:
#                     inputs, labels = inputs.to(device), labels.to(device)
#                     outputs = model(inputs)
#                     loss = criterion(outputs, labels)
#                     val_running_loss += loss.item()
#                     _, predicted = torch.max(outputs.data, 1)
#                     val_total += labels.size(0)
#                     val_correct += (predicted == labels).sum().item()

#         val_loss = val_running_loss / sum(len(loader) for loader in val_loaders)
#         val_accuracy = 100 * val_correct / val_total

#         print(f'Federated Round {round+1}/{num_epochs}, '
#               f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

#         # Early stopping
#         early_stopping(val_loss)
#         if early_stopping.early_stop:
#             print("Early stopping")
#             break

#     # Test phase (same as before)
#     print("Testing model...")
#     model.eval()
#     test_running_loss = 0.0
#     test_correct = 0
#     test_total = 0
#     with torch.no_grad():
#         for inputs, labels in test_loader:
#             inputs, labels = inputs.to(device), labels.to(device)
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
#             test_running_loss += loss.item()
#             _, predicted = torch.max(outputs.data, 1)
#             test_total += labels.size(0)
#             test_correct += (predicted == labels).sum().item()

#     test_loss = test_running_loss / len(test_loader)
#     test_accuracy = 100 * test_correct / test_total

#     print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

#     return model, val_loss, val_accuracy, test_loss, test_accuracy

### 4.(a) Federated Training Functions (Weighted Federated Averaging (weightedFedAvg))

In [16]:
import torch
import torch.optim as optim

def train_client(model, train_loader, criterion, optimizer, epochs=1):
    model.to(device)  
    model.train()
    for _ in range(epochs):
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)  
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
    return model.state_dict()

# Ensure that client_loader_sizes is already defined before this
client_data_sizes = client_loader_sizes

def weighted_federated_averaging(state_dicts, client_data_sizes):
    """Performs WeightedFedAvg aggregation of the model state_dicts based on client data sizes."""
    
    if len(state_dicts) != len(client_data_sizes):
        raise ValueError("Number of state dicts and client data sizes must match.")
    
    total_data = sum(client_data_sizes)
    avg_state_dict = {}
    
    # Initialize the avg_state_dict with zero values, and ensure they are float type
    for key in state_dicts[0].keys():
        avg_state_dict[key] = torch.zeros_like(state_dicts[0][key], dtype=torch.float32)  # Ensure float type
    
    # Weighted averaging of the state_dicts
    for i, state_dict in enumerate(state_dicts):
        weight = client_data_sizes[i] / total_data
        for key in state_dict.keys():
            avg_state_dict[key] += state_dict[key].float() * weight  # Convert state_dict to float before multiplication

    return avg_state_dict



def train_federated_model(client_loaders, val_loaders, test_loader, num_clients, num_epochs, learning_rate=0.0001, patience=5, min_delta=0):
    model = prepare_model(num_classes=20).to(device)
    criterion = nn.CrossEntropyLoss()
    early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
    
    for round in range(num_epochs):
        print(f"Starting federated learning round {round+1}/{num_epochs}...")
        state_dicts = []
        for i, client_loader in enumerate(client_loaders):
            print(f"Training model for client {i+1}...")
            client_model = prepare_model().to(device)
            client_model.load_state_dict(model.state_dict())
            optimizer = optim.Adam(client_model.parameters(), lr=learning_rate)
            client_state_dict = train_client(client_model, client_loader, criterion, optimizer)
            state_dicts.append(client_state_dict)

        avg_state_dict = weighted_federated_averaging(state_dicts, client_data_sizes)

        model.load_state_dict(avg_state_dict)
        model.to(device)
        
        # Validation phase (handling multiple validation loaders)
        print("Validating model...")
        model.eval()
        val_running_loss = 0.0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for val_loader in val_loaders:
                for inputs, labels in val_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    val_running_loss += loss.item()
                    _, predicted = torch.max(outputs.data, 1)
                    val_total += labels.size(0)
                    val_correct += (predicted == labels).sum().item()

        val_loss = val_running_loss / sum(len(loader) for loader in val_loaders)
        val_accuracy = 100 * val_correct / val_total

        print(f'Federated Round {round+1}/{num_epochs}, '
              f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

        # Early stopping
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    # Test phase
    print("Testing model...")
    model.eval()
    test_running_loss = 0.0
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            test_running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()

    test_loss = test_running_loss / len(test_loader)
    test_accuracy = 100 * test_correct / test_total

    print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

    return model, val_loss, val_accuracy, test_loss, test_accuracy

### Define the Logging Function

In [17]:
import csv
import os

def log_experiment_result(filename, num_clients, num_epochs, learning_rate, patience, min_delta, val_loss, val_accuracy, test_loss, test_accuracy):
    file_exists = os.path.isfile(filename)
    
    with open(filename, 'a', newline='') as csvfile:
        fieldnames = ['num_clients', 'num_epochs', 'learning_rate', 'patience', 'min_delta', 'val_loss', 'val_accuracy', 'test_loss', 'test_accuracy']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        
        if not file_exists:
            writer.writeheader()  # Write header only once
        
        writer.writerow({
            'num_clients': num_clients,
            'num_epochs': num_epochs,
            'learning_rate': learning_rate,
            'patience': patience,
            'min_delta': min_delta,
            'val_loss': val_loss,
            'val_accuracy': val_accuracy,
            'test_loss': test_loss,
            'test_accuracy': test_accuracy
        })

### Federated Learning Experiment Configuration

In [18]:
import pickle
import os

# Parameters
num_clients = 2  # Update this to match the number of clients
num_epochs = 40
learning_rate = 0.0001
patience = 5
min_delta = 0.01
batch_size = 256
num_workers = 4
log_file = '/raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/models/experiment_results.csv'

# Path for saving/loading pre-split datasets
dataset_save_path = f'/raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/models/client_datasets_2_clients.pkl'

# Try to load pre-split datasets
try:
    with open(dataset_save_path, 'rb') as f:
        client_loaders, val_loaders, test_loader = pickle.load(f)
    print("Pre-split datasets loaded successfully!")
except FileNotFoundError:
    # If not found, split the dataset and save it
    print("Pre-split datasets not found. Splitting dataset...")
    client_loaders, val_loaders, test_loader = get_data_loaders(batch_size, num_workers, dataset_train, dataset_test, superclasses)
    
    # Save the dataset splits after the first split
    os.makedirs(os.path.dirname(dataset_save_path), exist_ok=True)
    with open(dataset_save_path, 'wb') as f:
        pickle.dump([client_loaders, val_loaders, test_loader], f)
    print("Datasets saved successfully!")

# Start federated training
print("Starting federated training...")
model, val_loss, val_accuracy, test_loss, test_accuracy = train_federated_model(
    client_loaders, val_loaders, test_loader, num_clients, num_epochs, learning_rate, patience, min_delta)
print("Federated training complete.")

# Log the results
log_experiment_result(log_file, num_clients, num_epochs, learning_rate, patience, min_delta, val_loss, val_accuracy, test_loss, test_accuracy)
print("Experiment results logged.")

Pre-split datasets loaded successfully!
Starting federated training...




Starting federated learning round 1/40...
Training model for client 1...
Training model for client 2...
Validating model...
Federated Round 1/40, Val Loss: 1.4788, Val Accuracy: 54.42%
Starting federated learning round 2/40...
Training model for client 1...
Training model for client 2...
Validating model...
Federated Round 2/40, Val Loss: 1.2134, Val Accuracy: 61.76%
Starting federated learning round 3/40...
Training model for client 1...
Training model for client 2...
Validating model...
Federated Round 3/40, Val Loss: 1.1079, Val Accuracy: 65.28%
Starting federated learning round 4/40...
Training model for client 1...
Training model for client 2...
Validating model...
Federated Round 4/40, Val Loss: 1.0374, Val Accuracy: 67.82%
Starting federated learning round 5/40...
Training model for client 1...
Training model for client 2...
Validating model...
Federated Round 5/40, Val Loss: 0.9982, Val Accuracy: 68.60%
Starting federated learning round 6/40...
Training model for client 1...
Tr