In [1]:
import os
# Restrict this kernel to physical GPU 7. CUDA reads CUDA_VISIBLE_DEVICES when it is first
# initialized, so this must run before torch touches the GPU (torch is imported in the next
# cell). Inside this process the selected GPU is then addressed as "cuda" / "cuda:0".
# NOTE(review): hard-coded device index; adjust for the machine this runs on.
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

In [2]:
# Import necessary libraries
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, random_split
import numpy as np
import torch.nn as nn
from torchvision import models
import torch.optim as optim

# Seed the global RNGs so the client partition (np.random.shuffle) and the train/validation
# splits (torch's generator in random_split) are reproducible across Restart & Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Check for CUDA GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

# ImageNet normalization statistics, matching the ImageNet-pretrained ResNet-18 used below.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training-time transform: random crop/flip augmentation, upsampled to 224x224.
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Evaluation transform: deterministic resize only. Evaluating on random crops/flips makes
# test metrics noisy and not comparable between runs.
eval_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Load CIFAR-100 dataset
dataset_train = datasets.CIFAR100(root='../data', train=True, download=True, transform=transform)
dataset_test = datasets.CIFAR100(root='../data', train=False, download=True, transform=eval_transform)

Using device: cuda
Files already downloaded and verified
Files already downloaded and verified


### This code performs two different preprocessing methods for splitting the CIFAR-100 dataset between two clients, simulating two sorting plants for plastic recycling:

### Preprocessing Method 1: Major/Minor Superclass Split
- Each client handles **different superclasses** as "major" and "minor."
- Client 1 gets superclasses 1-10 as major and 11-20 as minor.
- Client 2 gets superclasses 11-20 as major and 1-10 as minor.
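- Intended to keep the clients' data non-overlapping, so that no object (image) is shared between clients, reflecting real-world sorting conditions. **Caveat:** as implemented, both clients take the *first* N images of each superclass in dataset order, so one client's minor samples are drawn from the same images as the other client's major samples (with the 50/50 counts the two clients end up with identical image sets).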

In [3]:
# # # Transformations
# # transform = transforms.Compose([
# #     transforms.RandomResizedCrop(224),
# #     transforms.RandomHorizontalFlip(),
# #     transforms.ToTensor(),
# #     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# # ])

# # # Load CIFAR-100 dataset
# # dataset_train = datasets.CIFAR100(root='../data', train=True, download=True, transform=transform)
# # dataset_test = datasets.CIFAR100(root='../data', train=False, download=True, transform=transform)


# # Define superclasses and subclasses
# superclasses = {
#     1: [4, 30, 55, 72, 95],  # aquatic mammals
#     2: [1, 32, 67, 73, 91],  # fish
#     3: [54, 62, 70, 82, 92],  # flowers
#     4: [9, 10, 16, 28, 61],  # food containers
#     5: [0, 51, 53, 57, 83],  # fruit and vegetables
#     6: [22, 39, 40, 86, 87],  # household electrical devices
#     7: [5, 20, 25, 84, 94],  # household furniture
#     8: [6, 7, 14, 18, 24],  # insects
#     9: [3, 42, 43, 88, 97],  # large carnivores
#     10: [12, 17, 37, 68, 76],  # large man-made outdoor things
#     11: [23, 33, 49, 60, 71],  # large natural outdoor scenes
#     12: [15, 19, 21, 31, 38],  # medium-sized mammals
#     13: [34, 63, 64, 66, 75],  # non-insect invertebrates
#     14: [26, 45, 77, 79, 99],  # people
#     15: [2, 11, 35, 46, 98],  # reptiles
#     16: [27, 29, 44, 78, 93],  # small mammals
#     17: [36, 50, 65, 74, 80],  # trees
#     18: [8, 13, 48, 58, 90],  # vehicles 1
#     19: [41, 66, 69, 81, 89],  # vehicles 2
#     20: [47, 50, 52, 56, 59],  # household furniture
# }

# # Function to map subclass to its corresponding superclass
# def get_superclass(subclass, superclasses):
#     for superclass, subclasses in superclasses.items():
#         if subclass in subclasses:
#             return superclass
#     return None

# # Function to filter dataset by superclasses
# def filter_dataset_by_superclass(dataset, superclasses, major_classes, minor_classes):
#     major_indices = []
#     minor_indices = []
#     for idx, (data, target) in enumerate(dataset):
#         subclass = target
#         superclass = get_superclass(subclass, superclasses)
#         if superclass in major_classes and major_classes[superclass] > 0:
#             major_indices.append(idx)
#             major_classes[superclass] -= 1
#         elif superclass in minor_classes and minor_classes[superclass] > 0:
#             minor_indices.append(idx)
#             minor_classes[superclass] -= 1

#     major_subset = Subset(dataset, major_indices)
#     minor_subset = Subset(dataset, minor_indices)
#     return major_subset, minor_subset



# ###### use for 80/20 split
# # # Define the number of images to sample from each superclass for both clients
# # major_count = 2000
# # minor_count = 500


# # ###### use for 50/50 split
# # # 50/50 Split for Major and Minor Classes
# # # Each superclass will contribute 50% to major and 50% to minor superclasses
# major_count = 1250  # 50% of 2500 (total images per superclass)
# minor_count = 1250  # 50% of 2500



# # Define which superclasses are major for each client
# client1_major_classes = {k: major_count for k in range(1, 11)}
# client1_minor_classes = {k: minor_count for k in range(11, 21)}

# client2_major_classes = {k: major_count for k in range(11, 21)}
# client2_minor_classes = {k: minor_count for k in range(1, 11)}

# # Filter dataset for each client
# client1_major, client1_minor = filter_dataset_by_superclass(dataset_train, superclasses, client1_major_classes, client1_minor_classes)
# client2_major, client2_minor = filter_dataset_by_superclass(dataset_train, superclasses, client2_major_classes, client2_minor_classes)

# # Combine major and minor subsets for each client
# client1_indices = list(client1_major.indices) + list(client1_minor.indices)
# client2_indices = list(client2_major.indices) + list(client2_minor.indices)

# client1_dataset = Subset(dataset_train, client1_indices)
# client2_dataset = Subset(dataset_train, client2_indices)

# # Split a portion of client1_dataset for validation
# val_split = 0.1  # 10% for validation


# # val_size = int(val_split * len(client1_dataset))
# # train_size = len(client1_dataset) - val_size

# # train_dataset, val_dataset = random_split(client1_dataset, [train_size, val_size])

# # # Update client1_dataset to exclude validation data
# # client1_dataset = train_dataset


# # For client1
# val_size_client1 = int(val_split * len(client1_dataset))
# train_size_client1 = len(client1_dataset) - val_size_client1
# train_dataset_client1, val_dataset_client1 = random_split(client1_dataset, [train_size_client1, val_size_client1])

# # Update client1_dataset to exclude validation data
# client1_dataset = train_dataset_client1

# # For client2
# val_size_client2 = int(val_split * len(client2_dataset))
# train_size_client2 = len(client2_dataset) - val_size_client2
# train_dataset_client2, val_dataset_client2 = random_split(client2_dataset, [train_size_client2, val_size_client2])

# # Update client2_dataset to exclude validation data
# client2_dataset = train_dataset_client2


# # Create DataLoaders for each client and the validation set
# batch_size = 64
# num_workers = 4

# client_loaders = [
#     DataLoader(client1_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers),
#     DataLoader(client2_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
# ]

# val_loaders = [
#     DataLoader(val_dataset_client1, batch_size=batch_size, shuffle=False, num_workers=num_workers),
#     DataLoader(val_dataset_client2, batch_size=batch_size, shuffle=False, num_workers=num_workers)
# ]

# test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# # Print statements to verify the data distribution
# print(f'Client 1 dataset size (after validation split): {len(client1_dataset)}')
# print(f'Client 2 dataset size (after validation split): {len(client2_dataset)}')
# print(f'Client 1 validation dataset size: {len(val_dataset_client1)}')
# print(f'Client 2 validation dataset size: {len(val_dataset_client2)}')
# print(f'Test dataset size: {len(dataset_test)}')

# # Check some samples from each client to ensure the data is correctly assigned
# for i, (images, labels) in enumerate(client_loaders[0]):
#     print(f'Client 1 batch {i+1}:')
#     print(f'Images shape: {images.shape}')
#     print(f'Labels: {labels}')
#     break

# for i, (images, labels) in enumerate(client_loaders[1]):
#     print(f'Client 2 batch {i+1}:')
#     print(f'Images shape: {images.shape}')
#     print(f'Labels: {labels}')
#     break

# print("Preprocessing Method 1: Major/Minor Superclass Split completed successfully.")

### Preprocessing Method 2: Random Subclass Image Assignment
- Both clients handle **all 20 superclasses**.
- Images within each superclass are **randomly split** between the clients.
- Ensures that both clients see all subclasses, but the specific images within each subclass are unique to each client.

In [3]:
# # Define superclasses and subclasses
# superclasses = {
#     1: [4, 30, 55, 72, 95],  # aquatic mammals
#     2: [1, 32, 67, 73, 91],  # fish
#     3: [54, 62, 70, 82, 92],  # flowers
#     4: [9, 10, 16, 28, 61],  # food containers
#     5: [0, 51, 53, 57, 83],  # fruit and vegetables
#     6: [22, 39, 40, 86, 87],  # household electrical devices
#     7: [5, 20, 25, 84, 94],  # household furniture
#     8: [6, 7, 14, 18, 24],  # insects
#     9: [3, 42, 43, 88, 97],  # large carnivores
#     10: [12, 17, 37, 68, 76],  # large man-made outdoor things
#     11: [23, 33, 49, 60, 71],  # large natural outdoor scenes
#     12: [15, 19, 21, 31, 38],  # medium-sized mammals
#     13: [34, 63, 64, 66, 75],  # non-insect invertebrates
#     14: [26, 45, 77, 79, 99],  # people
#     15: [2, 11, 35, 46, 98],  # reptiles
#     16: [27, 29, 44, 78, 93],  # small mammals
#     17: [36, 50, 65, 74, 80],  # trees
#     18: [8, 13, 48, 58, 90],  # vehicles 1
#     19: [41, 66, 69, 81, 89],  # vehicles 2
#     20: [47, 50, 52, 56, 59],  # household furniture
# }

# # Function to map subclass to its corresponding superclass
# def get_superclass(subclass, superclasses):
#     for superclass, subclasses in superclasses.items():
#         if subclass in subclasses:
#             return superclass
#     return None

# # Function to randomly split dataset by superclass
# def filter_dataset_randomly_by_superclass(dataset, superclasses):
#     client1_indices = []
#     client2_indices = []
    
#     # Shuffle and split images randomly for each superclass
#     for superclass, subclasses in superclasses.items():
#         superclass_indices = []
        
#         # Gather all images belonging to this superclass
#         for idx, (data, target) in enumerate(dataset):
#             if target in subclasses:
#                 superclass_indices.append(idx)
        
#         # If no images found for this superclass, continue to the next
#         if not superclass_indices:
#             continue
        
#         # Shuffle the indices for randomness
#         np.random.shuffle(superclass_indices)
        
#         # Split the indices between the two clients
#         half_split = len(superclass_indices) // 2
#         client1_indices.extend(superclass_indices[:half_split])
#         client2_indices.extend(superclass_indices[half_split:])
    
#     # Return two subsets: one for each client
#     return Subset(dataset, client1_indices), Subset(dataset, client2_indices)



# # Apply the function to split the dataset randomly for both clients
# client1_dataset, client2_dataset = filter_dataset_randomly_by_superclass(dataset_train, superclasses)

# # Split a portion of client1_dataset for validation
# val_split = 0.1  # 10% for validation

# # For client1
# val_size_client1 = int(val_split * len(client1_dataset))
# train_size_client1 = len(client1_dataset) - val_size_client1
# train_dataset_client1, val_dataset_client1 = random_split(client1_dataset, [train_size_client1, val_size_client1])

# # For client2
# val_size_client2 = int(val_split * len(client2_dataset))
# train_size_client2 = len(client2_dataset) - val_size_client2
# train_dataset_client2, val_dataset_client2 = random_split(client2_dataset, [train_size_client2, val_size_client2])

# # Create DataLoaders for each client and the validation set
# batch_size = 256
# num_workers = 4

# client_loaders = [
#     DataLoader(train_dataset_client1, batch_size=batch_size, shuffle=True, num_workers=num_workers),
#     DataLoader(train_dataset_client2, batch_size=batch_size, shuffle=True, num_workers=num_workers)
# ]

# val_loaders = [
#     DataLoader(val_dataset_client1, batch_size=batch_size, shuffle=False, num_workers=num_workers),
#     DataLoader(val_dataset_client2, batch_size=batch_size, shuffle=False, num_workers=num_workers)
# ]

# test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# # Print statements to verify the data distribution
# print(f'Client 1 dataset size (after validation split): {len(train_dataset_client1)}')
# print(f'Client 2 dataset size (after validation split): {len(train_dataset_client2)}')
# print(f'Client 1 validation dataset size: {len(val_dataset_client1)}')
# print(f'Client 2 validation dataset size: {len(val_dataset_client2)}')
# print(f'Test dataset size: {len(dataset_test)}')

# # Check some samples from each client to ensure the data is correctly assigned
# for i, (images, labels) in enumerate(client_loaders[0]):
#     print(f'Client 1 batch {i+1}:')
#     print(f'Images shape: {images.shape}')
#     print(f'Labels: {labels}')
#     break

# for i, (images, labels) in enumerate(client_loaders[1]):
#     print(f'Client 2 batch {i+1}:')
#     print(f'Images shape: {images.shape}')
#     print(f'Labels: {labels}')
#     break

# print("Preprocessing Method 2: Random Subclass Image Assignment completed successfully.")

Client 1 dataset size (after validation split): 22500
Client 2 dataset size (after validation split): 22500
Client 1 validation dataset size: 2500
Client 2 validation dataset size: 2500
Test dataset size: 10000
Client 1 batch 1:
Images shape: torch.Size([256, 3, 224, 224])
Labels: tensor([ 7, 34, 28, 86,  8, 74, 87, 53, 66,  3, 80, 68, 13, 61, 92, 10,  0, 99,
        23, 83, 50, 14, 50, 60, 36, 76, 62, 83, 15, 60, 56, 61,  4, 30, 64, 44,
        51,  7, 88, 16, 77, 20, 52, 75, 23, 82, 44, 15, 92, 82, 93, 36, 49, 92,
        44, 34, 29,  3, 93, 94, 17, 71, 24, 32, 34, 99, 42,  0, 70, 61,  4, 61,
        88, 41, 76, 84, 74, 79,  4, 48, 93, 34, 54, 76, 27, 53, 29, 63, 52, 48,
        95, 90, 55,  9, 23,  9, 19, 11, 52, 90,  8, 59, 82, 73, 77, 66, 71, 13,
        63, 19, 94, 15, 97, 31, 12, 29,  7, 14, 32, 69, 86, 70,  4,  3, 68, 61,
         7, 76,  3,  2, 93, 18, 83, 75, 11, 54, 63, 64, 41, 31,  5, 50, 14, 18,
        37, 87, 51, 66, 50, 15, 90, 57, 28, 18, 18, 66, 97, 75, 27, 64,  8, 78

In [4]:
from torch.utils.data import Subset, DataLoader, random_split
import numpy as np

# Official CIFAR-100 coarse ("superclass") -> fine-label mapping, using torchvision's fine-label
# indices (classes in alphabetical order). Keys 1-20 follow the dataset's coarse-label order.
# Every fine label 0-99 appears exactly once, so the superclasses partition the dataset.
superclasses = {
    1: [4, 30, 55, 72, 95],  # aquatic mammals: beaver, dolphin, otter, seal, whale
    2: [1, 32, 67, 73, 91],  # fish: aquarium_fish, flatfish, ray, shark, trout
    3: [54, 62, 70, 82, 92],  # flowers: orchid, poppy, rose, sunflower, tulip
    4: [9, 10, 16, 28, 61],  # food containers: bottle, bowl, can, cup, plate
    5: [0, 51, 53, 57, 83],  # fruit and vegetables: apple, mushroom, orange, pear, sweet_pepper
    6: [22, 39, 40, 86, 87],  # household electrical devices: clock, keyboard, lamp, telephone, television
    7: [5, 20, 25, 84, 94],  # household furniture: bed, chair, couch, table, wardrobe
    8: [6, 7, 14, 18, 24],  # insects: bee, beetle, butterfly, caterpillar, cockroach
    9: [3, 42, 43, 88, 97],  # large carnivores: bear, leopard, lion, tiger, wolf
    10: [12, 17, 37, 68, 76],  # large man-made outdoor things: bridge, castle, house, road, skyscraper
    11: [23, 33, 49, 60, 71],  # large natural outdoor scenes: cloud, forest, mountain, plain, sea
    12: [15, 19, 21, 31, 38],  # large omnivores and herbivores: camel, cattle, chimpanzee, elephant, kangaroo
    13: [34, 63, 64, 66, 75],  # medium-sized mammals: fox, porcupine, possum, raccoon, skunk
    14: [26, 45, 77, 79, 99],  # non-insect invertebrates: crab, lobster, snail, spider, worm
    15: [2, 11, 35, 46, 98],  # people: baby, boy, girl, man, woman
    16: [27, 29, 44, 78, 93],  # reptiles: crocodile, dinosaur, lizard, snake, turtle
    17: [36, 50, 65, 74, 80],  # small mammals: hamster, mouse, rabbit, shrew, squirrel
    18: [47, 52, 56, 59, 96],  # trees: maple_tree, oak_tree, palm_tree, pine_tree, willow_tree
    19: [8, 13, 48, 58, 90],  # vehicles 1: bicycle, bus, motorcycle, pickup_truck, train
    20: [41, 69, 81, 85, 89],  # vehicles 2: lawn_mower, rocket, streetcar, tank, tractor
}


def filter_dataset_randomly_by_superclass(dataset, superclasses, num_clients=4, rng=None):
    """Randomly split ``dataset`` among clients, stratified by superclass.

    For every superclass (in ``superclasses`` order) the indices of all images whose fine
    label belongs to it are shuffled and cut into ``num_clients`` contiguous chunks of
    ``len // num_clients`` images; the last client also receives the remainder. Every
    client therefore gets a share of every superclass, and each image goes to at most one
    client.

    Args:
        dataset: Indexable dataset yielding ``(data, target)`` pairs. If it exposes a
            ``targets`` sequence (as torchvision's CIFAR datasets do), labels are read from
            it directly instead of loading and transforming every image.
        superclasses: Mapping ``superclass_id -> list of fine labels``. Each fine label
            must appear under at most one superclass.
        num_clients: Number of clients to split into (default 4).
        rng: Optional ``numpy.random.Generator`` used for shuffling; when ``None`` the
            global ``np.random`` state is used.

    Returns:
        A list of ``num_clients`` ``Subset`` objects, one per client.

    Raises:
        ValueError: If ``num_clients < 1`` or a fine label is listed under two superclasses
            (which would hand the same images to two clients).
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")

    # Invert the mapping once, rejecting labels listed under more than one superclass.
    label_to_superclass = {}
    for superclass, subclasses in superclasses.items():
        for label in subclasses:
            if label in label_to_superclass:
                raise ValueError(
                    f"Fine label {label} is listed under superclasses "
                    f"{label_to_superclass[label]} and {superclass}"
                )
            label_to_superclass[label] = superclass

    # Read labels without running the (expensive, random) image transform when possible.
    targets = getattr(dataset, "targets", None)
    if targets is None:
        targets = [target for _, target in dataset]

    # Group dataset indices by superclass in a single pass over the labels.
    indices_by_superclass = {superclass: [] for superclass in superclasses}
    for idx, target in enumerate(targets):
        superclass = label_to_superclass.get(int(target))
        if superclass is not None:
            indices_by_superclass[superclass].append(idx)

    shuffle = np.random.shuffle if rng is None else rng.shuffle
    client_indices = [[] for _ in range(num_clients)]
    for superclass_indices in indices_by_superclass.values():
        # Skip superclasses with no images in this dataset.
        if not superclass_indices:
            continue
        shuffle(superclass_indices)
        split_size = len(superclass_indices) // num_clients
        for i in range(num_clients):
            start = i * split_size
            # The last client takes the remainder so no image is dropped.
            end = (i + 1) * split_size if i < num_clients - 1 else len(superclass_indices)
            client_indices[i].extend(superclass_indices[start:end])

    return [Subset(dataset, indices) for indices in client_indices]

# Apply the function to split the dataset randomly across the clients (4 by default)
client_datasets = filter_dataset_randomly_by_superclass(dataset_train, superclasses)

# Split a portion of each client dataset for validation
val_split = 0.1  # 10% for validation

# A dedicated seeded generator makes the train/validation split reproducible, independent of
# any other use of torch's global RNG.
split_generator = torch.Generator().manual_seed(0)
train_val_splits = []
for client_dataset in client_datasets:
    # Derive the train size from the val size so the two lengths always sum to
    # len(client_dataset); truncating both products separately can undercount by one,
    # which makes random_split raise.
    client_val_size = int(len(client_dataset) * val_split)
    client_train_size = len(client_dataset) - client_val_size
    train_val_splits.append(
        random_split(client_dataset, [client_train_size, client_val_size], generator=split_generator)
    )

# Separate train and validation datasets.
# NOTE(review): the validation subsets index into dataset_train, so they are loaded with the
# training-time random crop/flip transform rather than a deterministic one.
train_datasets = [train_part for train_part, _ in train_val_splits]
val_datasets = [val_part for _, val_part in train_val_splits]

# Create DataLoaders for each client and the validation set
batch_size = 256
num_workers = 4

client_loaders = [DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) for train_dataset in train_datasets]
val_loaders = [DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) for val_dataset in val_datasets]
test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Print statements to verify the data distribution
for i, (train_dataset, val_dataset) in enumerate(zip(train_datasets, val_datasets)):
    print(f'Client {i+1} dataset size (after validation split): {len(train_dataset)}')
    print(f'Client {i+1} validation dataset size: {len(val_dataset)}')

print(f'Test dataset size: {len(dataset_test)}')

# Check some samples from each client to ensure the data is correctly assigned
for i, client_loader in enumerate(client_loaders):
    for j, (images, labels) in enumerate(client_loader):
        print(f'Client {i+1} batch {j+1}:')
        print(f'Images shape: {images.shape}')
        print(f'Labels: {labels}')
        break  # Only display the first batch for verification

print("Preprocessing Method for 4 Clients completed successfully.")


Client 1 dataset size (after validation split): 11250
Client 1 validation dataset size: 1250
Client 2 dataset size (after validation split): 11250
Client 2 validation dataset size: 1250
Client 3 dataset size (after validation split): 11250
Client 3 validation dataset size: 1250
Client 4 dataset size (after validation split): 11250
Client 4 validation dataset size: 1250
Test dataset size: 10000
Client 1 batch 1:
Images shape: torch.Size([256, 3, 224, 224])
Labels: tensor([33, 57, 89, 91, 66, 92, 67, 55, 35, 42, 41, 69, 21, 67, 63, 18, 13, 62,
        41, 20, 54, 32, 46, 17, 48, 74, 94, 81, 59, 47,  5, 35, 72,  8, 27, 45,
        34,  7, 92, 54, 30,  1, 15, 50, 98,  5,  5, 19, 83, 66, 66, 52, 30, 31,
        14, 92, 42, 14, 14, 99, 76, 56, 39, 49, 84, 20, 84, 60, 68, 77, 83, 44,
        78, 11, 70, 77, 17, 82, 40, 45, 80, 43, 39, 73, 88, 47, 50, 42, 75, 13,
        84, 10, 66, 44, 19, 82, 86, 42, 23,  2, 47, 98, 93, 10, 22, 75,  0, 21,
        99, 21, 81, 73, 26, 53, 19, 29, 33, 14, 97, 

In [5]:
# Sanity-check one batch from client 1's training loader and from its validation loader.
for title, preview_loader in (('Client 1', client_loaders[0]), ('Client 1 Validation', val_loaders[0])):
    for batch_images, batch_labels in preview_loader:
        print(f'{title} - Images shape: {batch_images.shape}, Labels: {batch_labels.shape}')
        break  # the first batch is enough to confirm shapes

Client 1 - Images shape: torch.Size([256, 3, 224, 224]), Labels: torch.Size([256])
Client 1 Validation - Images shape: torch.Size([256, 3, 224, 224]), Labels: torch.Size([256])


### 2. Model Preparation

In [12]:
def prepare_model(num_classes=100, use_dropout=False, dropout_prob=0.2, pretrained=True):
    """Build a ResNet-18 whose final ``fc`` layer is replaced by a ``num_classes``-way head.

    Args:
        num_classes: Output size of the replacement classification layer (100 for CIFAR-100).
        use_dropout: If True, put ``nn.Dropout(dropout_prob)`` in front of the new linear layer.
        dropout_prob: Dropout probability used when ``use_dropout`` is True.
        pretrained: Load ImageNet weights for the backbone. Pass False when the weights are
            about to be overwritten (e.g. by ``load_state_dict``) to skip the checkpoint load.

    Returns:
        The model; callers move it to the target device.
    """
    # torchvision >= 0.13 replaced ``pretrained=`` with the ``weights=`` enum and warns on the
    # old keyword; fall back to it only on older releases that lack the enum.
    if hasattr(models, "ResNet18_Weights"):
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        model = models.resnet18(weights=weights)
    else:
        model = models.resnet18(pretrained=pretrained)

    num_ftrs = model.fc.in_features
    if use_dropout:
        model.fc = nn.Sequential(
            nn.Dropout(p=dropout_prob),
            nn.Linear(num_ftrs, num_classes)
        )
    else:
        model.fc = nn.Linear(num_ftrs, num_classes)
    return model

### 3. Early Stopping

In [13]:
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

### 4. Federated Training Functions (Federated Averaging (FedAvg))

In [7]:
# import torch
# import torch.optim as optim

# def train_client(model, train_loader, criterion, optimizer, epochs=1):
#     model.to(device)  
#     model.train()
#     for _ in range(epochs):
#         for inputs, labels in train_loader:
#             inputs, labels = inputs.to(device), labels.to(device)  
#             optimizer.zero_grad()
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
#     return model.state_dict()

# def federated_averaging(state_dicts):
#     avg_state_dict = {}
#     for key in state_dicts[0].keys():
#         avg_state_dict[key] = sum(state_dict[key] for state_dict in state_dicts) / len(state_dicts)
#     return avg_state_dict

# def train_federated_model(client_loaders, val_loaders, test_loader, num_clients, num_epochs, learning_rate=0.001, patience=5, min_delta=0):
#     model = prepare_model().to(device)
#     criterion = nn.CrossEntropyLoss()
#     early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
    
#     for round in range(num_epochs):
#         print(f"Starting federated learning round {round+1}/{num_epochs}...")
#         state_dicts = []
#         for i, client_loader in enumerate(client_loaders):
#             print(f"Training model for client {i+1}...")
#             client_model = prepare_model().to(device)
#             client_model.load_state_dict(model.state_dict())
#             optimizer = optim.Adam(client_model.parameters(), lr=learning_rate)
#             client_state_dict = train_client(client_model, client_loader, criterion, optimizer)
#             state_dicts.append(client_state_dict)

#         avg_state_dict = federated_averaging(state_dicts)
#         model.load_state_dict(avg_state_dict)
#         model.to(device)
        
#         # Validation phase (handling multiple validation loaders)
#         print("Validating model...")
#         model.eval()
#         val_running_loss = 0.0
#         val_correct = 0
#         val_total = 0
#         with torch.no_grad():
#             for val_loader in val_loaders:
#                 for inputs, labels in val_loader:
#                     inputs, labels = inputs.to(device), labels.to(device)
#                     outputs = model(inputs)
#                     loss = criterion(outputs, labels)
#                     val_running_loss += loss.item()
#                     _, predicted = torch.max(outputs.data, 1)
#                     val_total += labels.size(0)
#                     val_correct += (predicted == labels).sum().item()

#         val_loss = val_running_loss / sum(len(loader) for loader in val_loaders)
#         val_accuracy = 100 * val_correct / val_total

#         print(f'Federated Round {round+1}/{num_epochs}, '
#               f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

#         # Early stopping
#         early_stopping(val_loss)
#         if early_stopping.early_stop:
#             print("Early stopping")
#             break

#     # Test phase
#     print("Testing model...")
#     model.eval()
#     test_running_loss = 0.0
#     test_correct = 0
#     test_total = 0
#     with torch.no_grad():
#         for inputs, labels in test_loader:
#             inputs, labels = inputs.to(device), labels.to(device)
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
#             test_running_loss += loss.item()
#             _, predicted = torch.max(outputs.data, 1)
#             test_total += labels.size(0)
#             test_correct += (predicted == labels).sum().item()

#     test_loss = test_running_loss / len(test_loader)
#     test_accuracy = 100 * test_correct / test_total

#     print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

#     return model, val_loss, val_accuracy, test_loss, test_accuracy

### 4. Federated Training Functions (FedAvg with server-side blending toward the previous global model, `beta=0.9`)

In [19]:
import torch
import torch.optim as optim

def train_client(model, train_loader, criterion, optimizer, epochs=1):
    """Run ``epochs`` passes of local training for one client and return its weights.

    The model is moved to the module-level ``device`` and updated in place by ``optimizer``.
    The returned ``model.state_dict()`` shares storage with the model's parameters; it is
    not a copy.
    """
    model.to(device)
    model.train()
    for _epoch in range(epochs):
        for batch_inputs, batch_labels in train_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            optimizer.zero_grad()
            batch_loss = criterion(model(batch_inputs), batch_labels)
            batch_loss.backward()
            optimizer.step()
    return model.state_dict()

def federated_averaging(state_dicts, global_model, beta=0.9):
    """Average client weights and blend the result with the current global weights.

    For every entry of ``global_model.state_dict()``:

    * floating-point tensors: ``beta * mean(client values) + (1 - beta) * global value``;
    * non-floating tensors (e.g. BatchNorm's ``num_batches_tracked`` counter): the client
      mean computed in float32 and truncated back to the tensor's own dtype, without blending.

    Args:
        state_dicts: Non-empty list of client state dicts with the same keys as ``global_model``.
        global_model: Model whose current weights are the blending target; it is not modified.
        beta: Weight on the client average (1.0 gives plain FedAvg).

    Returns:
        A new dict suitable for ``global_model.load_state_dict``.

    Raises:
        ValueError: If ``state_dicts`` is empty.
    """
    if not state_dicts:
        raise ValueError("federated_averaging needs at least one client state dict")

    num_clients = len(state_dicts)
    # Snapshot the global weights once instead of re-collecting the state dict for every key.
    global_state = global_model.state_dict()
    avg_state_dict = {}
    for key, global_value in global_state.items():
        if torch.is_floating_point(global_value):
            client_mean = torch.zeros_like(global_value)
            for state_dict in state_dicts:
                client_mean += state_dict[key] / num_clients
            avg_state_dict[key] = beta * client_mean + (1 - beta) * global_value
        else:
            # Integer buffers cannot accumulate a fractional mean in place, so average in
            # float32 and cast back to the buffer's own dtype (not only torch.long).
            client_mean = torch.zeros_like(global_value, dtype=torch.float32)
            for state_dict in state_dicts:
                client_mean += state_dict[key].float() / num_clients
            avg_state_dict[key] = client_mean.to(global_value.dtype)
    return avg_state_dict

def _evaluate(model, loaders, criterion):
    """Return ``(mean per-sample loss, accuracy %)`` of ``model`` over every batch of ``loaders``.

    Each batch's loss is weighted by its size. This assumes ``criterion`` uses mean
    reduction (the default for the ``nn.CrossEntropyLoss`` used below).

    Raises:
        ValueError: If the loaders yield no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for loader in loaders:
            for inputs, labels in loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                # Weight by batch size so a short final batch is not over-counted.
                loss_sum += criterion(outputs, labels).item() * labels.size(0)
                correct += (outputs.argmax(dim=1) == labels).sum().item()
                total += labels.size(0)
    if total == 0:
        raise ValueError("evaluation loaders yielded no samples")
    return loss_sum / total, 100 * correct / total


def train_federated_model(client_loaders, val_loader, test_loader, num_clients, num_epochs, learning_rate=0.001, patience=5, min_delta=0):
    """Train a global model with federated averaging, early-stop on validation loss, then test.

    Each round, every client starts from a copy of the global model and trains one local
    epoch with a fresh Adam optimizer. The client weights are then merged by
    ``federated_averaging`` (client mean blended with the previous global weights).

    Args:
        client_loaders: One training DataLoader per client.
        val_loader: A validation DataLoader or a list of them (e.g. one per client); all are
            pooled into a single validation score.
        test_loader: DataLoader for the held-out test set.
        num_clients: Kept for interface compatibility; the number of clients trained is
            ``len(client_loaders)``.
        num_epochs: Maximum number of federated rounds.
        learning_rate: Adam learning rate used on each client.
        patience: Passed to ``EarlyStopping``.
        min_delta: Passed to ``EarlyStopping``.

    Returns:
        ``(model, val_loss, val_accuracy, test_loss, test_accuracy)``. The model is restored to
        the round with the best validation loss, and the validation metrics are that round's.
        The validation metrics are NaN if no round ran.
    """
    import copy

    # Accept one loader or a list of them, so this function never depends on a notebook global.
    val_loaders = [val_loader] if isinstance(val_loader, DataLoader) else list(val_loader)

    model = prepare_model().to(device)
    criterion = nn.CrossEntropyLoss()
    early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)

    best_state = copy.deepcopy(model.state_dict())
    val_loss = val_accuracy = float('nan')

    for round_idx in range(num_epochs):
        print(f"Starting federated learning round {round_idx+1}/{num_epochs}...")
        state_dicts = []
        for i, client_loader in enumerate(client_loaders):
            print(f"Training model for client {i+1}...")
            # Copy the global model instead of rebuilding it from the pretrained checkpoint
            # for every client in every round.
            client_model = copy.deepcopy(model)
            optimizer = optim.Adam(client_model.parameters(), lr=learning_rate)
            state_dicts.append(train_client(client_model, client_loader, criterion, optimizer))

        model.load_state_dict(federated_averaging(state_dicts, model))
        model.to(device)

        print("Validating model...")
        round_val_loss, round_val_accuracy = _evaluate(model, val_loaders, criterion)
        print(f'Federated Round {round_idx+1}/{num_epochs}, '
              f'Val Loss: {round_val_loss:.4f}, Val Accuracy: {round_val_accuracy:.2f}%')

        early_stopping(round_val_loss)
        # counter == 0 right after the call means this round is the first or a new best.
        if early_stopping.counter == 0:
            val_loss, val_accuracy = round_val_loss, round_val_accuracy
            best_state = copy.deepcopy(model.state_dict())
        if early_stopping.early_stop:
            print("Early stopping")
            break

    # Test the best-validation weights, not the last (up to `patience` rounds worse) ones.
    model.load_state_dict(best_state)

    print("Testing model...")
    test_loss, test_accuracy = _evaluate(model, [test_loader], criterion)
    print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

    return model, val_loss, val_accuracy, test_loss, test_accuracy

### FedProx (Federated Proximal) 

In [None]:
# import torch
# import torch.optim as optim

# def train_client_prox(model, train_loader, criterion, optimizer, global_model, mu=0.01, epochs=1):
#     model.to(device)
#     model.train()
#     global_weights = global_model.state_dict()  # Get the global model's weights
    
#     for _ in range(epochs):
#         for inputs, labels in train_loader:
#             inputs, labels = inputs.to(device), labels.to(device)
#             optimizer.zero_grad()
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
            
#             # Add the proximal term (FedProx)
#             proximal_term = 0.0
#             for param, global_param in zip(model.parameters(), global_model.parameters()):
#                 proximal_term += (param - global_param).norm(2)
#             loss += (mu / 2) * proximal_term
            
#             loss.backward()
#             optimizer.step()
    
#     return model.state_dict()

# def federated_averaging(state_dicts):
#     avg_state_dict = {}
#     for key in state_dicts[0].keys():
#         avg_state_dict[key] = sum(state_dict[key] for state_dict in state_dicts) / len(state_dicts)
#     return avg_state_dict


# def train_federated_model(client_loaders, val_loaders, test_loader, num_clients, num_epochs, learning_rate=0.01, patience=5, min_delta=0, mu=0.01):
#     model = prepare_model().to(device)
#     criterion = nn.CrossEntropyLoss()
#     early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
    
#     for round in range(num_epochs):
#         print(f"Starting federated learning round {round+1}/{num_epochs}...")
#         state_dicts = []
#         for i, client_loader in enumerate(client_loaders):
#             print(f"Training model for client {i+1}...")
#             client_model = prepare_model().to(device)
#             client_model.load_state_dict(model.state_dict())
#             optimizer = optim.Adam(client_model.parameters(), lr=learning_rate)
#             client_state_dict = train_client_prox(client_model, client_loader, criterion, optimizer, model, mu)
#             state_dicts.append(client_state_dict)
        
#         avg_state_dict = federated_averaging(state_dicts)
#         model.load_state_dict(avg_state_dict)
#         model.to(device)
        
#         # Validation phase (same as before)
#         print("Validating model...")
#         model.eval()
#         val_running_loss = 0.0
#         val_correct = 0
#         val_total = 0
#         with torch.no_grad():
#             for val_loader in val_loaders:
#                 for inputs, labels in val_loader:
#                     inputs, labels = inputs.to(device), labels.to(device)
#                     outputs = model(inputs)
#                     loss = criterion(outputs, labels)
#                     val_running_loss += loss.item()
#                     _, predicted = torch.max(outputs.data, 1)
#                     val_total += labels.size(0)
#                     val_correct += (predicted == labels).sum().item()

#         val_loss = val_running_loss / sum(len(loader) for loader in val_loaders)
#         val_accuracy = 100 * val_correct / val_total

#         print(f'Federated Round {round+1}/{num_epochs}, '
#               f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

#         # Early stopping
#         early_stopping(val_loss)
#         if early_stopping.early_stop:
#             print("Early stopping")
#             break

#     # Test phase (same as before)
#     print("Testing model...")
#     model.eval()
#     test_running_loss = 0.0
#     test_correct = 0
#     test_total = 0
#     with torch.no_grad():
#         for inputs, labels in test_loader:
#             inputs, labels = inputs.to(device), labels.to(device)
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
#             test_running_loss += loss.item()
#             _, predicted = torch.max(outputs.data, 1)
#             test_total += labels.size(0)
#             test_correct += (predicted == labels).sum().item()

#     test_loss = test_running_loss / len(test_loader)
#     test_accuracy = 100 * test_correct / test_total

#     print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

#     return model, val_loss, val_accuracy, test_loss, test_accuracy

### 5. Define the Logging Function

In [20]:
import csv
import os


def log_experiment_result(filename, num_clients, num_epochs, learning_rate, patience, min_delta, val_loss, val_accuracy, test_loss, test_accuracy):
    """Append one experiment's settings and metrics as a single row of a CSV file.

    The header row is written only when the file is empty (newly created or
    zero-length). Repeated calls therefore accumulate one row per experiment
    under a single header.

    Args:
        filename: Path of the results CSV. Missing parent directories are created.
        num_clients: Number of federated clients used in the run.
        num_epochs: Configured number of epochs / federated rounds.
        learning_rate: Optimizer learning rate used for the run.
        patience: Early-stopping patience setting used for the run.
        min_delta: Early-stopping minimum-improvement setting used for the run.
        val_loss: Validation loss reported by the run.
        val_accuracy: Validation accuracy reported by the run.
        test_loss: Test loss reported by the run.
        test_accuracy: Test accuracy reported by the run.
    """
    fieldnames = ['num_clients', 'num_epochs', 'learning_rate', 'patience', 'min_delta', 'val_loss', 'val_accuracy', 'test_loss', 'test_accuracy']
    row = {
        'num_clients': num_clients,
        'num_epochs': num_epochs,
        'learning_rate': learning_rate,
        'patience': patience,
        'min_delta': min_delta,
        'val_loss': val_loss,
        'val_accuracy': val_accuracy,
        'test_loss': test_loss,
        'test_accuracy': test_accuracy,
    }

    # Create the results directory if needed. Otherwise a missing directory
    # would crash here, after the (long) training run has already finished.
    parent_dir = os.path.dirname(filename)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(filename, 'a', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        # Append mode positions the stream at end-of-file, so position 0 means
        # the file is empty. Unlike an os.path.isfile() check made before
        # opening, this also writes the header into an existing zero-length file.
        if csvfile.tell() == 0:
            writer.writeheader()
        writer.writerow(row)

### 6. Federated Learning Experiment: Configuration, Training, and Logging

In [21]:
import os

# ---- Experiment configuration ----
num_clients = 4
num_epochs = 40       # number of federated rounds; early stopping may end training sooner
learning_rate = 0.001
patience = 5          # early-stopping patience
min_delta = 0.01      # early-stopping minimum improvement threshold
batch_size = 256      # not used directly below; retained for data-loader construction
num_workers = 4       # not used directly below; retained for data-loader construction
# Results CSV. Set the FL_LOG_FILE environment variable to log elsewhere
# instead of editing this machine-specific default path.
log_file = os.environ.get(
    'FL_LOG_FILE',
    '/raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/experiment_results.csv',
)

# client_loaders, val_loaders and test_loader are built in an earlier cell.
# Fail fast if that split was made for a different number of clients.
# Otherwise the logged num_clients would not match the data actually used.
assert len(client_loaders) == num_clients, (
    f'client_loaders has {len(client_loaders)} entries but num_clients={num_clients}; '
    're-run the data-loader cell with the same num_clients'
)

print("Starting federated training...")
model, val_loss, val_accuracy, test_loss, test_accuracy = train_federated_model(client_loaders, val_loaders, test_loader, num_clients, num_epochs, learning_rate, patience, min_delta)
print("Federated training complete.")

# Log the results
log_experiment_result(log_file, num_clients, num_epochs, learning_rate, patience, min_delta, val_loss, val_accuracy, test_loss, test_accuracy)
print("Experiment results logged.")

Starting federated training...
Starting federated learning round 1/40...
Training model for client 1...
Training model for client 2...
Training model for client 3...
Training model for client 4...
Validating model...
Federated Round 1/40, Val Loss: 2.9437, Val Accuracy: 30.94%
Starting federated learning round 2/40...
Training model for client 1...
Training model for client 2...
Training model for client 3...
Training model for client 4...
Validating model...
Federated Round 2/40, Val Loss: 2.1303, Val Accuracy: 44.36%
Starting federated learning round 3/40...
Training model for client 1...
Training model for client 2...
Training model for client 3...
Training model for client 4...
Validating model...
Federated Round 3/40, Val Loss: 1.9388, Val Accuracy: 48.56%
Starting federated learning round 4/40...
Training model for client 1...
Training model for client 2...
Training model for client 3...
Training model for client 4...
Validating model...
Federated Round 4/40, Val Loss: 1.8062, Va