In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

### Step 1: Setting up the Environment for FL with 8 Clients
##### Load CIFAR-100 Dataset: We apply standard transformations and load the CIFAR-100 dataset for training and testing.
##### Validation Split: We split the dataset into 90% training and 10% validation.
##### Output: After running the code, you should see the size of the training set (45,000 images) and validation set (5,000 images).

In [2]:
# Step 1: Setting up FL with 8 Clients

import torch
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, random_split

# Set device (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

# Transformations for CIFAR-100
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load CIFAR-100 dataset (train and test)
dataset_train = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
dataset_test = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)

# Total number of images after reserving 10% for validation (90% for training)
num_train_images = int(0.9 * len(dataset_train))  # 45,000 images for training
num_val_images = len(dataset_train) - num_train_images  # 5,000 images for validation

# Randomly split the dataset into training and validation sets
train_dataset, val_dataset = random_split(dataset_train, [num_train_images, num_val_images])

# Create DataLoaders for training, validation, and test sets
batch_size = 256
num_workers = 4

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Output to check the size of the training, validation, and test sets
print(f"Training Dataset Size: {len(train_loader.dataset)} images")
print(f"Validation Dataset Size: {len(val_loader.dataset)} images")
print(f"Test Dataset Size: {len(test_loader.dataset)} images")

Using device: cuda
Files already downloaded and verified
Files already downloaded and verified
Training Dataset Size: 45000 images
Validation Dataset Size: 5000 images
Test Dataset Size: 10000 images


### Step 2: Split Data for 8 Clients
#### Splitting for 8 Clients: We use random_split to split the training dataset into 8 subsets, with each subset containing 5625 images.
#### Output: After running the code, you should see the size of each client’s dataset (all should have 5625 images).

In [4]:
# # Step 2: Split the training dataset into 8 clients

# # Each client will get 5625 images
# client_splits = torch.utils.data.random_split(train_dataset, [5625] * 8)

# # Verify that each client has 5625 images
# for i, client_dataset in enumerate(client_splits):
#     print(f"Client {i+1} Dataset Size: {len(client_dataset)} images")

Client 1 Dataset Size: 5625 images
Client 2 Dataset Size: 5625 images
Client 3 Dataset Size: 5625 images
Client 4 Dataset Size: 5625 images
Client 5 Dataset Size: 5625 images
Client 6 Dataset Size: 5625 images
Client 7 Dataset Size: 5625 images
Client 8 Dataset Size: 5625 images


### Step 2: Split Data for 4 Clients

In [3]:
# Step 2: Split the training dataset into 4 clients

# Each client will get 11250 images
client_splits = torch.utils.data.random_split(train_dataset, [11250] * 4)

# Verify that each client has 11250 images
for i, client_dataset in enumerate(client_splits):
    print(f"Client {i+1} Dataset Size: {len(client_dataset)} images")


Client 1 Dataset Size: 11250 images
Client 2 Dataset Size: 11250 images
Client 3 Dataset Size: 11250 images
Client 4 Dataset Size: 11250 images


### Step 3: Prepare DataLoaders for Each Client
#### Creating DataLoaders: We create a DataLoader for each client to handle training, and another DataLoader for validation.
#### Batch Verification: After running the code, you should see the shape of a batch of images and labels from Client 1’s dataset.

In [4]:
# Step 3: Create DataLoaders for each client

batch_size = 256
num_workers = 4

# Create a DataLoader for each client
client_loaders = [DataLoader(client_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) for client_dataset in client_splits]

# Also create a DataLoader for the validation set
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Output to verify one batch from Client 1
images, labels = next(iter(client_loaders[0]))
print(f"Client 1 First Batch - Images Shape: {images.shape}, Labels Shape: {labels.shape}")

Client 1 First Batch - Images Shape: torch.Size([256, 3, 224, 224]), Labels Shape: torch.Size([256])


### Step 4: Applying Incremental Splitting with Replay
#### First Batch: We start by splitting 2625 images from each client’s dataset as the first batch.
#### 6 Incremental Batches: For each of the next 6 batches, we use 500 new images plus 500 replay images (from previously seen data).
#### Replay Strategy: Each incremental batch includes 500 images from previously seen batches to help the model retain earlier knowledge.

In [6]:
# import numpy as np
# from torch.utils.data import random_split, ConcatDataset

# # Split each client's 5625 images into 1 initial batch and 6 incremental batches
# initial_split_size = 2625  # First batch size
# num_incremental_batches = 6  # Number of incremental batches
# new_data_per_batch = 500  # New images per incremental batch
# replay_size = 500  # Replay size (previously seen data to include)

# # Iterate over all clients to create incremental splits
# client_splits_cl = []

# for client_idx, client_dataset in enumerate(client_splits):
#     # Initial split: 2625 images for the first batch, remaining for incremental batches
#     remaining_split_size = len(client_dataset) - initial_split_size  # Should be 3000
#     first_split, remaining_dataset = random_split(client_dataset, [initial_split_size, remaining_split_size])

#     # Split the remaining dataset into 6 incremental batches of 500 new images each
#     incremental_splits = random_split(remaining_dataset, [new_data_per_batch] * num_incremental_batches)
    
#     # Store the initial split as the first batch
#     client_splits_for_cl = [first_split]
    
#     # Create incremental batches with replay technique
#     for i in range(num_incremental_batches):
#         # Get the previous seen data from all earlier batches
#         previous_data = ConcatDataset(client_splits_for_cl)
        
#         # Select 500 random previously seen images (for replay)
#         replay_data, _ = random_split(previous_data, [replay_size, len(previous_data) - replay_size])
        
#         # Create the new batch with 500 new images and 500 replay images
#         new_batch = ConcatDataset([incremental_splits[i], replay_data])
        
#         # Add this new batch to the client's splits
#         client_splits_for_cl.append(new_batch)
    
#     client_splits_cl.append(client_splits_for_cl)

# # Verify the incremental splits for Client 1
# for i, split in enumerate(client_splits_cl[0]):
#     print(f"Client 1 - Batch {i+1}: {len(split)} images")


Client 1 - Batch 1: 2625 images
Client 1 - Batch 2: 1000 images
Client 1 - Batch 3: 1000 images
Client 1 - Batch 4: 1000 images
Client 1 - Batch 5: 1000 images
Client 1 - Batch 6: 1000 images
Client 1 - Batch 7: 1000 images


In [5]:
import numpy as np
from torch.utils.data import random_split, ConcatDataset

# Split each client's 11,250 images into 1 initial batch and 6 incremental batches
initial_split_size = 5250  # First batch size 
num_incremental_batches = 6  # Number of incremental batches
new_data_per_batch = 1000  # New images per incremental batch (adjusted to fit the total size)
replay_size = 1000  # Replay size (previously seen data to include, adjusted to be 1000)

# Iterate over all clients to create incremental splits
client_splits_cl = []

for client_idx, client_dataset in enumerate(client_splits):
    # Initial split: 5125 images for the first batch, remaining for incremental batches
    remaining_split_size = len(client_dataset) - initial_split_size  # Should be 6125
    first_split, remaining_dataset = random_split(client_dataset, [initial_split_size, remaining_split_size])

    # Split the remaining dataset into 6 incremental batches of 1000 new images each
    incremental_splits = random_split(remaining_dataset, [new_data_per_batch] * num_incremental_batches)
    
    # Store the initial split as the first batch
    client_splits_for_cl = [first_split]
    
    # Create incremental batches with replay technique
    for i in range(num_incremental_batches):
        # Get the previous seen data from all earlier batches
        previous_data = ConcatDataset(client_splits_for_cl)
        
        # Select 1000 random previously seen images (for replay)
        replay_data, _ = random_split(previous_data, [replay_size, len(previous_data) - replay_size])
        
        # Create the new batch with 1000 new images and 1000 replay images
        new_batch = ConcatDataset([incremental_splits[i], replay_data])
        
        # Add this new batch to the client's splits
        client_splits_for_cl.append(new_batch)
    
    client_splits_cl.append(client_splits_for_cl)

# Verify the incremental splits for Client 1
for i, split in enumerate(client_splits_cl[0]):
    print(f"Client 1 - Batch {i+1}: {len(split)} images")

Client 1 - Batch 1: 5250 images
Client 1 - Batch 2: 2000 images
Client 1 - Batch 3: 2000 images
Client 1 - Batch 4: 2000 images
Client 1 - Batch 5: 2000 images
Client 1 - Batch 6: 2000 images
Client 1 - Batch 7: 2000 images


### Step 5: Train the Model with Continual Learning (CL)
#### Model Preparation: For each client, we initialize a ResNet18 model with 100 classes (as in CIFAR-100).
#### Training Function: The model is trained on each incremental batch in the Continual Learning (CL) setup. After training on one batch, the next batch is introduced.
#### Validation: After each training phase, we evaluate the model using the validation set.
#### Early Stopping: The training process includes early stopping to prevent overfitting if the validation loss doesn’t improve after a certain number of epochs.
#### Output: After running the code, we should see the training and validation progress for each client and each batch.

In [7]:
import os

# Step 5: Train each client with Continual Learning (CL), test after training, and save models

from torchvision import models
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader

# Model Preparation Function (ResNet18 for CIFAR-100)
def prepare_model(num_classes=100, use_dropout=False, dropout_prob=0.2):
    """Load a pre-trained Resnet18 model and modify it for CIFAR100 with optional dropout."""
    model = models.resnet18(pretrained=True)
    num_ftrs = model.fc.in_features
    if use_dropout:
        model.fc = nn.Sequential(
            nn.Dropout(p=dropout_prob),
            nn.Linear(num_ftrs, num_classes)
        )
    else:
        model.fc = nn.Linear(num_ftrs, num_classes)
    return model

# EarlyStopping class
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

# Training function for Continual Learning
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=5, patience=5, min_delta=0):
    early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
    
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_loss = running_loss / len(train_loader)
        train_accuracy = 100 * correct / total

        # Validation phase
        model.eval()
        val_running_loss = 0.0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_running_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_loss = val_running_loss / len(val_loader)
        val_accuracy = 100 * val_correct / val_total

        print(f'Epoch {epoch+1}/{num_epochs}, '
              f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, '
              f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

        # Early stopping
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break
    
    return model

# Testing function to evaluate the model on the test set
def test_model(model, test_loader):
    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()
    
    test_accuracy = 100 * test_correct / test_total
    print(f'Test Accuracy: {test_accuracy:.2f}%')

# Save function
def save_model(model, client_idx, save_dir):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    model_path = os.path.join(save_dir, f'client_{client_idx}_model.pth')
    torch.save(model.state_dict(), model_path)
    print(f'Model for Client {client_idx} saved at {model_path}')

# Prepare the training and validation loaders for each client
batch_size = 256
num_workers = 4
num_epochs = 5  # Set the number of epochs for each CL round

# Define the directory to save the models
save_dir = '/raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models_4clients/'

# Iterate over all clients and perform incremental training (CL)
for client_idx, client_batches in enumerate(client_splits_cl):
    print(f"\nStarting Continual Learning for Client {client_idx + 1}...")

    # Initialize the model once for each client, and then update it for each batch
    model = prepare_model().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)

    # Train on each batch incrementally, continuing from the last updated model
    for batch_idx, batch in enumerate(client_batches):
        print(f"Training on Batch {batch_idx + 1}/{len(client_batches)}")

        # Create DataLoader for the current batch
        train_loader = DataLoader(batch, batch_size=batch_size, shuffle=True, num_workers=num_workers)

        # Use the same validation set for all batches
        model = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs)

    # After training on all batches, test the model on the test set
    print(f"Testing model for Client {client_idx + 1} after Continual Learning...")
    test_model(model, test_loader)

    # Save the model for each client
    save_model(model, client_idx + 1, save_dir)

print("Continual Learning, testing, and saving models for all clients completed.")


Starting Continual Learning for Client 1...




Training on Batch 1/7
Epoch 1/5, Train Loss: 4.4896, Train Accuracy: 4.61%, Val Loss: 4.1334, Val Accuracy: 9.90%
Epoch 2/5, Train Loss: 3.7782, Train Accuracy: 17.90%, Val Loss: 3.5704, Val Accuracy: 20.00%
Epoch 3/5, Train Loss: 3.2899, Train Accuracy: 27.98%, Val Loss: 3.2131, Val Accuracy: 26.96%
Epoch 4/5, Train Loss: 2.9340, Train Accuracy: 35.83%, Val Loss: 2.9628, Val Accuracy: 32.02%
Epoch 5/5, Train Loss: 2.6534, Train Accuracy: 41.41%, Val Loss: 2.7683, Val Accuracy: 34.18%
Training on Batch 2/7
Epoch 1/5, Train Loss: 2.6293, Train Accuracy: 39.90%, Val Loss: 2.6840, Val Accuracy: 36.14%
Epoch 2/5, Train Loss: 2.4704, Train Accuracy: 44.50%, Val Loss: 2.6461, Val Accuracy: 37.56%
Epoch 3/5, Train Loss: 2.3059, Train Accuracy: 49.10%, Val Loss: 2.6259, Val Accuracy: 36.84%
Epoch 4/5, Train Loss: 2.2154, Train Accuracy: 51.15%, Val Loss: 2.5799, Val Accuracy: 37.80%
Epoch 5/5, Train Loss: 2.0843, Train Accuracy: 54.00%, Val Loss: 2.5581, Val Accuracy: 38.76%
Training on Batch 

In [20]:
import torch
import os
import torch.nn as nn
from torchvision import models
from collections import OrderedDict

# Path to saved client models after CL
saved_models_dir = '/raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models_4clients/'

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Prepare model and dynamically adjust the final layer based on saved state_dict dimensions
def prepare_model(state_dict, num_clients=4):
    """Initialize ResNet18 model and modify the output layer dynamically based on state_dict."""
    model = models.resnet18(pretrained=True)
    num_ftrs = model.fc.in_features
    
    # Dynamically set final layer based on saved model's state_dict
    model.fc = nn.Linear(num_ftrs, state_dict['fc.weight'].shape[0])
    model.load_state_dict(state_dict)  # Load adjusted state_dict
    return model.to(device)

# Load CL models for 4 clients
def load_cl_models(num_clients=4):
    cl_models = []
    for client_idx in range(1, num_clients + 1):
        model_path = os.path.join(saved_models_dir, f'client_{client_idx}_model.pth')
        state_dict = torch.load(model_path)
        model = prepare_model(state_dict)
        cl_models.append(model)
        print(f'Loaded and adjusted model for Client {client_idx} from {model_path}')
    return cl_models

# Federated Averaging function
def federated_averaging(state_dicts):
    avg_state_dict = OrderedDict()
    for key in state_dicts[0].keys():
        avg_state_dict[key] = sum(state_dict[key] for state_dict in state_dicts) / len(state_dicts)
    return avg_state_dict

# Fine-tune each client's model using FedAvg
def apply_federated_learning(cl_models, num_epochs=10):
    global_model = cl_models[0]  # Start with Client 1's model structure
    for round in range(num_epochs):
        print(f'\n--- Federated Learning Round {round + 1} ---')
        state_dicts = [model.state_dict() for model in cl_models]

        # Perform FedAvg
        avg_state_dict = federated_averaging(state_dicts)
        global_model.load_state_dict(avg_state_dict)

        # Update each client with the global model
        for client_idx, client_model in enumerate(cl_models):
            client_model.load_state_dict(global_model.state_dict())
            print(f'Updated Client {client_idx + 1} model with global model state.')

    return global_model



# Load models saved after CL
cl_models = load_cl_models(num_clients=4)

# Apply FL with FedAvg on the saved CL models
global_model = apply_federated_learning(cl_models, num_epochs=10)

# Test the global model (assuming `test_loader` is defined)
test_accuracy = test_global_model(global_model, test_loader)
print(f'Test Accuracy of the global model: {test_accuracy:.2f}%')

Loaded and adjusted model for Client 1 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models_4clients/client_1_model.pth
Loaded and adjusted model for Client 2 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models_4clients/client_2_model.pth
Loaded and adjusted model for Client 3 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models_4clients/client_3_model.pth
Loaded and adjusted model for Client 4 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models_4clients/client_4_model.pth

--- Federated Learning Round 1 ---
Updated Client 1 model with global model state.
Updated Client 2 model with global model state.
Updated Client 3 model with global model state.
Updated Client 4 model with global model state.

--- Federated Learning Round 2 ---
Updated Client 1 model with global model state.
Updated Client 2 model with global model state.
Updated Client 3 model with glob

In [None]:
# from torch.utils.data import DataLoader, random_split

# # Function to create DataLoaders for each client
# def create_client_loaders(dataset, num_clients=4, batch_size=256, val_split=0.1, num_workers=4):
#     """Split dataset into `num_clients` parts and create train/validation loaders."""
#     client_loaders = []
#     val_loaders = []
    
#     # Split the dataset randomly into `num_clients` equal parts
#     client_datasets = random_split(dataset, [len(dataset) // num_clients] * num_clients)
    
#     for client_dataset in client_datasets:
#         # Further split each client's dataset into train and validation sets
#         train_size = int((1 - val_split) * len(client_dataset))
#         val_size = len(client_dataset) - train_size
#         train_dataset, val_dataset = random_split(client_dataset, [train_size, val_size])
        
#         # Create DataLoaders for the client with pin_memory and optimizations
#         train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 
#                                   num_workers=num_workers, pin_memory=True)
#         val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, 
#                                 num_workers=num_workers, pin_memory=True)
        
#         client_loaders.append(train_loader)
#         val_loaders.append(val_loader)
    
#     return client_loaders, val_loaders

# # Assuming `dataset_train` is your training dataset
# train_loaders, val_loaders = create_client_loaders(dataset_train, num_clients=4, batch_size=256, num_workers=4)

### Step 6a: BASELINE(FL After Full Continual Learning" (FL-FCL))- Apply Federated Learning (FL) After Continual Learning (CL)
#### Federated Averaging: We collect the models from each client after CL, average their weights, and load the averaged weights into the global model.
#### Test the Global Model: After aggregation, the global model is tested on the test set, and the test accuracy is reported.

In [None]:
# import os
# import torch.optim as optim
# import torch.nn as nn
# from torchvision import models
# from torch.cuda.amp import autocast, GradScaler  # For mixed-precision training

# # Path where client models are saved
# saved_models_dir = '/raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models_4clients/'

# # Model Preparation (ResNet18 for CIFAR-100)
# def prepare_model(num_classes=100, use_dropout=False, dropout_prob=0.2):
#     """Load a pre-trained Resnet18 model and modify it for CIFAR100 with optional dropout."""
#     model = models.resnet18(pretrained=True)
#     num_ftrs = model.fc.in_features
#     if use_dropout:
#         model.fc = nn.Sequential(
#             nn.Dropout(p=dropout_prob),
#             nn.Linear(num_ftrs, num_classes)
#         )
#     else:
#         model.fc = nn.Linear(num_ftrs, num_classes)
#     return model

# # Function to load a model from a file
# def load_client_model(client_idx, save_dir):
#     model = prepare_model().to(device)
#     model_path = os.path.join(save_dir, f'client_{client_idx}_model.pth')
#     model.load_state_dict(torch.load(model_path))
#     print(f'Loaded model for Client {client_idx} from {model_path}')
#     return model

# # Function to average client models (FL aggregation)
# def federated_averaging(state_dicts):
#     avg_state_dict = {}
#     for key in state_dicts[0].keys():
#         avg_state_dict[key] = sum(state_dict[key] for state_dict in state_dicts) / len(state_dicts)
#     return avg_state_dict

# # Training function for each client after receiving the global model
# def fine_tune_client(global_model, train_loader, val_loader, num_epochs=510, use_mixed_precision=True):
#     model = global_model.to(device)
#     criterion = nn.CrossEntropyLoss()
#     optimizer = optim.Adam(model.parameters(), lr=0.001)
    
#     # Optional: Mixed precision training
#     scaler = GradScaler() if use_mixed_precision else None

#     for epoch in range(num_epochs):
#         model.train()
#         running_loss = 0.0
#         correct = 0
#         total = 0

#         for inputs, labels in train_loader:
#             inputs, labels = inputs.to(device), labels.to(device)

#             optimizer.zero_grad()

#             if use_mixed_precision:
#                 with autocast():
#                     outputs = model(inputs)
#                     loss = criterion(outputs, labels)
#             else:
#                 outputs = model(inputs)
#                 loss = criterion(outputs, labels)
            
#             if use_mixed_precision:
#                 scaler.scale(loss).backward()
#                 scaler.step(optimizer)
#                 scaler.update()
#             else:
#                 loss.backward()
#                 optimizer.step()

#             running_loss += loss.item()
#             _, predicted = torch.max(outputs.data, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()

#         train_accuracy = 100 * correct / total
#         print(f'Epoch {epoch+1}/{num_epochs}, Train Accuracy: {train_accuracy:.2f}%')

#     return model.state_dict()

# # Step 6: Federated Learning (FL) after Continual Learning (CL)
# def apply_federated_learning(cl_models, train_loaders, val_loaders, test_loader, num_clients=4, num_epochs=10):
#     global_model = prepare_model().to(device)

#     # Collect state dicts from all clients after CL
#     state_dicts = [model.state_dict() for model in cl_models]

#     # Perform federated averaging
#     avg_state_dict = federated_averaging(state_dicts)
#     global_model.load_state_dict(avg_state_dict)

#     for round in range(num_epochs):
#         print(f'\n--- Federated Learning Round {round + 1} ---')
#         client_state_dicts = []

#         for client_idx in range(num_clients):
#             print(f'\nTraining client {client_idx + 1} with the global model')
#             client_state_dict = fine_tune_client(global_model, train_loaders[client_idx], val_loaders[client_idx], num_epochs=10)
#             client_state_dicts.append(client_state_dict)

#         # Federated averaging after each round
#         avg_state_dict = federated_averaging(client_state_dicts)
#         global_model.load_state_dict(avg_state_dict)

#         test_accuracy = test_global_model(global_model, test_loader)
#         print(f'Test Accuracy after Round {round + 1}: {test_accuracy:.2f}%')

#     return global_model

# # Function to test the global model
# def test_global_model(model, test_loader):
#     model.eval()
#     test_correct = 0
#     test_total = 0
#     with torch.no_grad():
#         for inputs, labels in test_loader:
#             inputs, labels = inputs.to(device), labels.to(device)
#             outputs = model(inputs)
#             _, predicted = torch.max(outputs.data, 1)
#             test_total += labels.size(0)
#             test_correct += (predicted == labels).sum().item()

#     test_accuracy = 100 * test_correct / test_total
#     return test_accuracy

# # Load models saved after CL
# cl_models = [load_client_model(client_idx, saved_models_dir) for client_idx in range(1, 5)]

# # Prepare DataLoader splits for each client with pin_memory=True and fewer workers
# train_loaders, val_loaders = create_client_loaders(dataset_train, num_clients=4, batch_size=256)

# # Apply FL and test the global model
# global_model = apply_federated_learning(cl_models, train_loaders, val_loaders, test_loader)

Loaded model for Client 1 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models_4clients/client_1_model.pth
Loaded model for Client 2 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models_4clients/client_2_model.pth
Loaded model for Client 3 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models_4clients/client_3_model.pth
Loaded model for Client 4 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models_4clients/client_4_model.pth

--- Federated Learning Round 1 ---

Training client 1 with the global model
Epoch 1/10, Train Accuracy: 32.78%
Epoch 2/10, Train Accuracy: 43.17%
Epoch 3/10, Train Accuracy: 48.58%
Epoch 4/10, Train Accuracy: 51.16%
Epoch 5/10, Train Accuracy: 55.93%
Epoch 6/10, Train Accuracy: 56.71%
Epoch 7/10, Train Accuracy: 59.83%
Epoch 8/10, Train Accuracy: 61.53%
Epoch 9/10, Train Accuracy: 63.68%
Epoch 10/10, Train Accuracy: 65.08%

Training client

## Step 6b: (FL After Each Continual Learning Round (FL-CL)) - Apply Federated Learning (FL) After Each CL Round
- **Federated Averaging**: After each client completes a round of CL (training on a batch), we perform FL by averaging the client models.
- **Global Model Updates**: The aggregated global model is redistributed to all clients before continuing with the next batch.
- **Testing**: The global model is tested after each FL round to evaluate its performance.

In [16]:
import os
import torch.optim as optim
import torch.nn as nn
from torchvision import models

# Path where client models are saved
saved_models_dir = '/raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models_4clients/'

# Model Preparation (ResNet18 for CIFAR-100)
def prepare_model(num_classes=100, use_dropout=False, dropout_prob=0.2):
    """Load a pre-trained Resnet18 model and modify it for CIFAR100 with optional dropout."""
    model = models.resnet18(pretrained=True)
    num_ftrs = model.fc.in_features
    if use_dropout:
        model.fc = nn.Sequential(
            nn.Dropout(p=dropout_prob),
            nn.Linear(num_ftrs, num_classes)
        )
    else:
        model.fc = nn.Linear(num_ftrs, num_classes)
    return model

# Function to load a model from a file
def load_client_model(client_idx, save_dir):
    model = prepare_model().to(device)
    model_path = os.path.join(save_dir, f'client_{client_idx}_model.pth')
    model.load_state_dict(torch.load(model_path))
    print(f'Loaded model for Client {client_idx} from {model_path}')
    return model

# Function to average client models (FL aggregation)
def federated_averaging(state_dicts):
    avg_state_dict = {}
    for key in state_dicts[0].keys():
        avg_state_dict[key] = sum(state_dict[key] for state_dict in state_dicts) / len(state_dicts)
    return avg_state_dict

# Training function for each client after receiving the global model
def fine_tune_client(global_model, train_loader, val_loader, num_epochs=10):  
    model = global_model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    # Fine-tune the model for the client
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_accuracy = 100 * correct / total
        print(f'Epoch {epoch+1}/{num_epochs}, Train Accuracy: {train_accuracy:.2f}%')

    return model.state_dict()  # Return the fine-tuned state_dict


# Step 6: Federated Learning (FL) after each batch of Continual Learning (CL)
def apply_federated_learning_after_each_batch(client_splits_cl, val_loaders, test_loader, num_clients=4, num_epochs=10):
    # Initialize a global model
    global_model = prepare_model().to(device)

    # Iterate over each batch of CL for all clients
    for batch_idx in range(len(client_splits_cl[0])):  
        print(f'\n--- Training and Federated Learning after Batch {batch_idx + 1} ---')

        client_state_dicts = []

        for client_idx in range(num_clients):
            print(f'\nTraining client {client_idx + 1} on Batch {batch_idx + 1}')

            # Create DataLoader for the current batch
            train_loader = DataLoader(client_splits_cl[client_idx][batch_idx], batch_size=128, shuffle=True, pin_memory=True)

            # Fine-tune client model with the current global model
            client_state_dict = fine_tune_client(global_model, train_loader, val_loaders[client_idx], num_epochs=num_epochs)
            client_state_dicts.append(client_state_dict)

        # Perform federated averaging after this batch for all clients
        avg_state_dict = federated_averaging(client_state_dicts)
        global_model.load_state_dict(avg_state_dict)

        # Optionally test the global model after each batch
        test_accuracy = test_global_model(global_model, test_loader)
        print(f'Test Accuracy after Batch {batch_idx + 1}: {test_accuracy:.2f}%')

    return global_model

# Load models saved after Continual Learning (CL)
cl_models = [load_client_model(client_idx, saved_models_dir) for client_idx in range(1, 5)]

# Apply Federated Learning after each batch and test the global model
global_model = apply_federated_learning_after_each_batch(client_splits_cl, val_loaders, test_loader, num_clients=4, num_epochs=10)

Loaded model for Client 1 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models_4clients/client_1_model.pth
Loaded model for Client 2 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models_4clients/client_2_model.pth
Loaded model for Client 3 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models_4clients/client_3_model.pth
Loaded model for Client 4 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models_4clients/client_4_model.pth

--- Training and Federated Learning after Batch 1 ---

Training client 1 on Batch 1
Epoch 1/10, Train Accuracy: 10.21%
Epoch 2/10, Train Accuracy: 19.81%
Epoch 3/10, Train Accuracy: 27.90%
Epoch 4/10, Train Accuracy: 30.06%
Epoch 5/10, Train Accuracy: 34.17%
Epoch 6/10, Train Accuracy: 37.47%
Epoch 7/10, Train Accuracy: 40.29%
Epoch 8/10, Train Accuracy: 44.40%
Epoch 9/10, Train Accuracy: 45.30%
Epoch 10/10, Train Accuracy: 44.88%

Trainin

### Step6C - FedAdam

In [14]:
# import os
# import torch
# import torch.optim as optim
# import torch.nn as nn
# from torchvision import models
# from torch.cuda.amp import autocast, GradScaler

# # Global device setting
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # Path where client models are saved
# saved_models_dir = '/raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models/'

# # Model Preparation (ResNet18 for CIFAR-100)
# def prepare_model(num_classes=100, use_dropout=False, dropout_prob=0.2):
#     """Load a pre-trained ResNet18 model and modify it for CIFAR-100 with optional dropout."""
#     model = models.resnet18(pretrained=True)
#     num_ftrs = model.fc.in_features
#     if use_dropout:
#         model.fc = nn.Sequential(
#             nn.Dropout(p=dropout_prob),
#             nn.Linear(num_ftrs, num_classes)
#         )
#     else:
#         model.fc = nn.Linear(num_ftrs, num_classes)
#     return model

# # Function to load a model from a file
# def load_client_model(client_idx, save_dir):
#     model = prepare_model().to(device)
#     model_path = os.path.join(save_dir, f'client_{client_idx}_model.pth')
#     try:
#         model.load_state_dict(torch.load(model_path, map_location=device))
#         print(f'Loaded model for Client {client_idx} from {model_path}')
#     except FileNotFoundError:
#         print(f'Error: Model for Client {client_idx} not found at {model_path}')
#     return model

# # Initialize optimizer state variables for FedAdam
# optimizer_state = None
# optimizer_variance = None
# beta_1 = 0.8  # Adjusted momentum term
# beta_2 = 0.9  # Adjusted variance term
# epsilon = 1e-8
# eta = 0.01  # Increased learning rate for server-side optimization

# # Function for adaptive federated optimization (FedAdam)
# def adaptive_federated_optimization(state_dicts, global_model, optimizer_state, optimizer_variance):
#     global beta_1, beta_2, epsilon, eta

#     # Initialize optimizer state and variance if not already done
#     if optimizer_state is None or optimizer_variance is None:
#         optimizer_state = {key: torch.zeros_like(param, device=device) for key, param in global_model.state_dict().items()}
#         optimizer_variance = {key: torch.zeros_like(param, device=device) for key, param in global_model.state_dict().items()}

#     avg_state_dict = {}

#     # Compute federated optimization using the FedAdam algorithm
#     for key in state_dicts[0].keys():
#         # Compute the average of the updates (delta_w)
#         delta_w = sum([state_dict[key] - global_model.state_dict()[key] for state_dict in state_dicts]) / len(state_dicts)
        
#         # Update optimizer state (momentum) and variance
#         optimizer_state[key] = beta_1 * optimizer_state[key] + (1 - beta_1) * delta_w
#         optimizer_variance[key] = beta_2 * optimizer_variance[key] + (1 - beta_2) * (delta_w ** 2)
        
#         # Compute bias-corrected updates
#         m_hat = optimizer_state[key] / (1 - beta_1)
#         v_hat = optimizer_variance[key] / (1 - beta_2)
        
#         # Apply the FedAdam update rule with numerical stability
#         avg_state_dict[key] = global_model.state_dict()[key] - eta * m_hat / (torch.sqrt(v_hat) + epsilon)

#     return avg_state_dict, optimizer_state, optimizer_variance

# # Training function for each client after receiving the global model
# def fine_tune_client(global_model, train_loader, val_loader, num_epochs=3, use_mixed_precision=True):
#     model = global_model.to(device)
#     criterion = nn.CrossEntropyLoss()
#     optimizer = optim.Adam(model.parameters(), lr=0.001)
    
#     # Optional: Mixed precision training
#     scaler = GradScaler() if use_mixed_precision else None

#     for epoch in range(num_epochs):
#         model.train()
#         running_loss = 0.0
#         correct = 0
#         total = 0

#         for inputs, labels in train_loader:
#             inputs, labels = inputs.to(device), labels.to(device)

#             optimizer.zero_grad()

#             if use_mixed_precision:
#                 with autocast():
#                     outputs = model(inputs)
#                     loss = criterion(outputs, labels)
#             else:
#                 outputs = model(inputs)
#                 loss = criterion(outputs, labels)
            
#             if use_mixed_precision:
#                 scaler.scale(loss).backward()
#                 scaler.step(optimizer)
#                 scaler.update()
#             else:
#                 loss.backward()
#                 optimizer.step()

#             running_loss += loss.item()
#             _, predicted = torch.max(outputs.data, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()

#         train_accuracy = 100 * correct / total
#         print(f'Epoch {epoch+1}/{num_epochs}, Train Accuracy: {train_accuracy:.2f}%')

#     return model.state_dict()

# # Step 6: Federated Learning (FL) after Continual Learning (CL)
# def apply_federated_learning(cl_models, train_loaders, val_loaders, test_loader, num_clients=8, num_epochs=3):
#     global global_model, optimizer_state, optimizer_variance

#     global_model = prepare_model().to(device)

#     # Collect state dicts from all clients after CL
#     state_dicts = [model.state_dict() for model in cl_models]

#     # Perform adaptive federated optimization (FedAdam)
#     avg_state_dict, optimizer_state, optimizer_variance = adaptive_federated_optimization(state_dicts, global_model, optimizer_state, optimizer_variance)
#     global_model.load_state_dict(avg_state_dict)

#     for round in range(num_epochs):
#         print(f'\n--- Federated Learning Round {round + 1} ---')
#         client_state_dicts = []

#         for client_idx in range(num_clients):
#             print(f'\nTraining client {client_idx + 1} with the global model')
#             client_state_dict = fine_tune_client(global_model, train_loaders[client_idx], val_loaders[client_idx], num_epochs=3)
#             client_state_dicts.append(client_state_dict)

#         # Federated optimization after each round using FedAdam
#         avg_state_dict, optimizer_state, optimizer_variance = adaptive_federated_optimization(client_state_dicts, global_model, optimizer_state, optimizer_variance)
#         global_model.load_state_dict(avg_state_dict)

#         test_accuracy = test_global_model(global_model, test_loader)
#         print(f'Test Accuracy after Round {round + 1}: {test_accuracy:.2f}%')

#     return global_model

# # Function to test the global model
# def test_global_model(model, test_loader):
#     model.eval()
#     test_correct = 0
#     test_total = 0
#     with torch.no_grad():
#         for inputs, labels in test_loader:
#             inputs, labels = inputs.to(device), labels.to(device)
#             outputs = model(inputs)
#             _, predicted = torch.max(outputs.data, 1)
#             test_total += labels.size(0)
#             test_correct += (predicted == labels).sum().item()

#     test_accuracy = 100 * test_correct / test_total
#     return test_accuracy

# # Load models saved after CL
# cl_models = [load_client_model(client_idx, saved_models_dir) for client_idx in range(1, 9)]

# # Prepare DataLoader splits for each client with pin_memory=True and fewer workers
# train_loaders, val_loaders = create_client_loaders(dataset_train, num_clients=8, batch_size=256)

# # Apply FL and test the global model
# global_model = apply_federated_learning(cl_models, train_loaders, val_loaders, test_loader)



Loaded model for Client 1 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models/client_1_model.pth
Loaded model for Client 2 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models/client_2_model.pth
Loaded model for Client 3 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models/client_3_model.pth
Loaded model for Client 4 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models/client_4_model.pth
Loaded model for Client 5 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models/client_5_model.pth
Loaded model for Client 6 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models/client_6_model.pth
Loaded model for Client 7 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models/client_7_model.pth
Loaded model for Client 8 from /raid/home/somayeh.shami/project/somayeh_workspace/federate

KeyboardInterrupt: 

### FCL with FedAvg, FedAdam, and Adaptive FedOpt

In [None]:
import os
import torch
import torch.optim as optim
import torch.nn as nn
from torchvision import models
from torch.cuda.amp import autocast, GradScaler

# Global device setting
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path where client models are saved
saved_models_dir = '/raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models/'

# Model Preparation (ResNet18 for CIFAR-100)
def prepare_model(num_classes=100, use_dropout=False, dropout_prob=0.2):
    model = models.resnet18(pretrained=True)
    num_ftrs = model.fc.in_features
    if use_dropout:
        model.fc = nn.Sequential(
            nn.Dropout(p=dropout_prob),
            nn.Linear(num_ftrs, num_classes)
        )
    else:
        model.fc = nn.Linear(num_ftrs, num_classes)
    return model

# Function to load a model from a file
def load_client_model(client_idx, save_dir):
    model = prepare_model().to(device)
    model_path = os.path.join(save_dir, f'client_{client_idx}_model.pth')
    model.load_state_dict(torch.load(model_path))
    print(f'Loaded model for Client {client_idx} from {model_path}')
    return model

# Function to average client models (FedAvg aggregation)
def federated_averaging(state_dicts):
    avg_state_dict = {}
    for key in state_dicts[0].keys():
        avg_state_dict[key] = sum(state_dict[key] for state_dict in state_dicts) / len(state_dicts)
    return avg_state_dict

# Function for adaptive federated optimization (FedAdam)
def fed_adam(state_dicts, global_model, optimizer_state, optimizer_variance, beta_1=0.9, beta_2=0.99, eta=0.01, epsilon=1e-8):
    if optimizer_state is None or optimizer_variance is None:
        optimizer_state = {key: torch.zeros_like(param) for key, param in global_model.state_dict().items()}
        optimizer_variance = {key: torch.zeros_like(param) for key, param in global_model.state_dict().items()}
    
    avg_state_dict = {}
    for key in state_dicts[0].keys():
        delta_w = sum([state_dict[key] - global_model.state_dict()[key] for state_dict in state_dicts]) / len(state_dicts)
        optimizer_state[key] = beta_1 * optimizer_state[key] + (1 - beta_1) * delta_w
        optimizer_variance[key] = beta_2 * optimizer_variance[key] + (1 - beta_2) * (delta_w ** 2)

        m_hat = optimizer_state[key] / (1 - beta_1)
        v_hat = optimizer_variance[key] / (1 - beta_2)
        avg_state_dict[key] = global_model.state_dict()[key] - eta * m_hat / (torch.sqrt(v_hat) + epsilon)

    return avg_state_dict, optimizer_state, optimizer_variance

# Function for Adaptive Federated Optimization (AdapFedOpt)
def adaptive_fedopt(state_dicts, global_model, learning_rates, beta_1=0.9, beta_2=0.99, epsilon=1e-8):
    if learning_rates is None:
        learning_rates = {key: torch.ones_like(param) * 0.01 for key, param in global_model.state_dict().items()}  # Initial learning rates

    avg_state_dict = {}
    optimizer_state = {key: torch.zeros_like(param) for key, param in global_model.state_dict().items()}
    optimizer_variance = {key: torch.zeros_like(param) for key, param in global_model.state_dict().items()}
    
    for key in state_dicts[0].keys():
        delta_w = sum([state_dict[key] - global_model.state_dict()[key] for state_dict in state_dicts]) / len(state_dicts)
        optimizer_state[key] = beta_1 * optimizer_state[key] + (1 - beta_1) * delta_w
        optimizer_variance[key] = beta_2 * optimizer_variance[key] + (1 - beta_2) * (delta_w ** 2)

        m_hat = optimizer_state[key] / (1 - beta_1)
        v_hat = optimizer_variance[key] / (1 - beta_2)
        
        # Use adaptive learning rate for each parameter
        avg_state_dict[key] = global_model.state_dict()[key] - learning_rates[key] * m_hat / (torch.sqrt(v_hat) + epsilon)
        
        # Update learning rates for next round (can be adapted further as needed)
        learning_rates[key] *= 0.9  # Reduce the learning rate over time as an example

    return avg_state_dict, learning_rates

# Training function for each client
def fine_tune_client(global_model, train_loader, val_loader, num_epochs=3, use_mixed_precision=True):
    model = global_model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    
    scaler = GradScaler() if use_mixed_precision else None

    for epoch in range(num_epochs):
        model.train()
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            if use_mixed_precision:
                with autocast():
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
            else:
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            
            if use_mixed_precision:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_accuracy = 100 * correct / total
        print(f'Epoch {epoch+1}/{num_epochs}, Train Accuracy: {train_accuracy:.2f}%')

    return model.state_dict()

# Federated Learning (FL) after Continual Learning (CL)
def apply_federated_learning(cl_models, train_loaders, val_loaders, test_loader, num_clients=8, num_epochs=3, method='fedavg'):
    global global_model, optimizer_state, optimizer_variance, learning_rates

    global_model = prepare_model().to(device)

    # Collect state dicts from all clients after CL
    state_dicts = [model.state_dict() for model in cl_models]

    # Initialize learning rates for Adaptive FedOpt
    learning_rates = None if method != 'adaptive_fedopt' else {key: torch.ones_like(param) * 0.01 for key, param in global_model.state_dict().items()}

    # Perform federated aggregation using the chosen method
    if method == 'fedavg':
        avg_state_dict = federated_averaging(state_dicts)
    elif method == 'fedadam':
        avg_state_dict, optimizer_state, optimizer_variance = fed_adam(state_dicts, global_model, optimizer_state, optimizer_variance)
    elif method == 'adaptive_fedopt':
        avg_state_dict, learning_rates = adaptive_fedopt(state_dicts, global_model, learning_rates)

    global_model.load_state_dict(avg_state_dict)

    for round in range(num_epochs):
        print(f'\n--- Federated Learning Round {round + 1} ---')
        client_state_dicts = []

        for client_idx in range(num_clients):
            print(f'\nTraining client {client_idx + 1} with the global model')
            client_state_dict = fine_tune_client(global_model, train_loaders[client_idx], val_loaders[client_idx], num_epochs=3)
            client_state_dicts.append(client_state_dict)

        # Perform aggregation using the chosen method
        if method == 'fedavg':
            avg_state_dict = federated_averaging(client_state_dicts)
        elif method == 'fedadam':
            avg_state_dict, optimizer_state, optimizer_variance = fed_adam(client_state_dicts, global_model, optimizer_state, optimizer_variance)
        elif method == 'adaptive_fedopt':
            avg_state_dict, learning_rates = adaptive_fedopt(client_state_dicts, global_model, learning_rates)

        global_model.load_state_dict(avg_state_dict)

        test_accuracy = test_global_model(global_model, test_loader)
        print(f'Test Accuracy after Round {round + 1}: {test_accuracy:.2f}%')

    return global_model

# Function to test the global model
def test_global_model(model, test_loader):
    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()

    test_accuracy = 100 * test_correct / test_total
    return test_accuracy

# Load models saved after CL
cl_models = [load_client_model(client_idx, saved_models_dir) for client_idx in range(1, 9)]

# Prepare DataLoader splits for each client with pin_memory=True and fewer workers
train_loaders, val_loaders = create_client_loaders(dataset_train, num_clients=8, batch_size=256)

# Apply FL and test the global model using one of the methods: 'fedavg', 'fedadam'
global_model = apply_federated_learning(cl_models, train_loaders, val_loaders, test_loader, method='fedadam')
