In [1]:
import os
# Restrict this kernel to physical GPU 6. The variable is set here, before the
# `import torch` in the next cell, so CUDA sees it when it initialises.
os.environ.update({"CUDA_VISIBLE_DEVICES": "6"})

In [4]:
# Step 1: Setting up FL with 8 Clients

import torch
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, random_split

# Set device (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

# Seed the RNGs used for splitting/shuffling so Restart & Run All reproduces
# the same train/validation/client partitions.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# ImageNet normalisation statistics (the ResNet18 used later is ImageNet-pretrained).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training-time augmentation: random crop/flip, upsampled to 224x224.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Deterministic evaluation transform: validation/test metrics must not depend on
# random crops or flips. CIFAR-100 images are 32x32, so they are resized to 224.
eval_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Kept under its original name for any code that still refers to `transform`.
transform = train_transform

# Load CIFAR-100. The training images are loaded twice: once with augmentation
# (training) and once with the eval transform (validation split).
dataset_train = datasets.CIFAR100(root='./data', train=True, download=True, transform=train_transform)
dataset_train_eval = datasets.CIFAR100(root='./data', train=True, download=True, transform=eval_transform)
dataset_test = datasets.CIFAR100(root='./data', train=False, download=True, transform=eval_transform)

# Total number of images after reserving 10% for validation (90% for training)
num_train_images = int(0.9 * len(dataset_train))  # 45,000 images for training
num_val_images = len(dataset_train) - num_train_images  # 5,000 images for validation

# Seeded split; the validation indices are then re-used on the eval-transform
# copy so validation images are not augmented.
split_generator = torch.Generator().manual_seed(SEED)
train_dataset, _val_split = random_split(
    dataset_train, [num_train_images, num_val_images], generator=split_generator
)
val_dataset = Subset(dataset_train_eval, _val_split.indices)

# Create DataLoaders for training, validation, and test sets
batch_size = 256
num_workers = 4

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Output to check the size of the training, validation, and test sets
print(f"Training Dataset Size: {len(train_loader.dataset)} images")
print(f"Validation Dataset Size: {len(val_loader.dataset)} images")
print(f"Test Dataset Size: {len(test_loader.dataset)} images")

Using device: cuda
Files already downloaded and verified
Files already downloaded and verified
Training Dataset Size: 45000 images
Validation Dataset Size: 5000 images
Test Dataset Size: 10000 images


In [7]:
# Step 2: Split the training dataset into NUM_CLIENTS client shards

NUM_CLIENTS = 8

# 45,000 training images / 8 clients = 5625 images per client. Any remainder is
# spread over the first clients so the lengths always sum to len(train_dataset),
# which random_split requires.
_base_size, _extra = divmod(len(train_dataset), NUM_CLIENTS)
client_sizes = [_base_size + (1 if i < _extra else 0) for i in range(NUM_CLIENTS)]
client_splits = torch.utils.data.random_split(train_dataset, client_sizes)

# Verify each client's shard size
for i, client_dataset in enumerate(client_splits):
    print(f"Client {i+1} Dataset Size: {len(client_dataset)} images")

Client 1 Dataset Size: 5625 images
Client 2 Dataset Size: 5625 images
Client 3 Dataset Size: 5625 images
Client 4 Dataset Size: 5625 images
Client 5 Dataset Size: 5625 images
Client 6 Dataset Size: 5625 images
Client 7 Dataset Size: 5625 images
Client 8 Dataset Size: 5625 images


In [8]:
# Step 3: Create DataLoaders for each client

batch_size = 256
num_workers = 4

# One shuffled DataLoader per client shard produced in Step 2
client_loaders = [
    DataLoader(client_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    for client_dataset in client_splits
]

# Also create a DataLoader for the validation set
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Output to verify one batch from Client 1
images, labels = next(iter(client_loaders[0]))
print(f"Client 1 First Batch - Images Shape: {images.shape}, Labels Shape: {labels.shape}")

# Verify each client's shard size. `client_splits` is deliberately NOT re-drawn
# with random_split here: that would leave `client_splits` describing different
# images than the ones wrapped by `client_loaders`.
for i, client_dataset in enumerate(client_splits):
    print(f"Client {i+1} Dataset Size: {len(client_dataset)} images")

Client 1 First Batch - Images Shape: torch.Size([256, 3, 224, 224]), Labels Shape: torch.Size([256])
Client 1 Dataset Size: 5625 images
Client 2 Dataset Size: 5625 images
Client 3 Dataset Size: 5625 images
Client 4 Dataset Size: 5625 images
Client 5 Dataset Size: 5625 images
Client 6 Dataset Size: 5625 images
Client 7 Dataset Size: 5625 images
Client 8 Dataset Size: 5625 images


In [6]:
# # Split the training dataset into 16 clients

# # Each client gets 2813 or 2812 images, reduce clients with 2813 images to 2812
# for i, client_dataset in enumerate(client_splits):
#     # If client has 2813 images, reduce it to 2812
#     if len(client_dataset) == 2813:
#         client_splits[i] = Subset(client_dataset, range(2812))

# # Verify that each client now has exactly 2812 images
# for i, client_dataset in enumerate(client_splits):
#     print(f"Client {i+1} Dataset Size: {len(client_dataset)} images")

In [12]:
# import numpy as np

# # New settings for 16 clients with 2812 images each
# initial_split_size = 1312  # First batch size
# num_incremental_batches = 3  # Number of incremental batches
# new_data_per_batch = 500  # New images per incremental batch
# replay_data_per_batch = 500  # Replay images per incremental batch

# # Iterate over all clients to create incremental splits
# client_splits_cl = []

# for client_idx, client_dataset in enumerate(client_splits):
#     # Initial split: 1312 images for the first batch, remaining for incremental batches
#     remaining_split_size = len(client_dataset) - initial_split_size  # Should be 1500
#     first_split, remaining_dataset = random_split(client_dataset, [initial_split_size, remaining_split_size])

#     # Split the remaining dataset into 3 incremental batches of 500 new images each
#     incremental_splits = random_split(remaining_dataset, [new_data_per_batch] * num_incremental_batches)
    
#     # Prepare the combined splits (replay strategy)
#     seen_indices_list = list(first_split.indices)
#     client_splits_for_cl = [first_split]  # First split goes in directly

#     # Create the 3 subsequent batches with 500 new + 500 replay images
#     for i, current_split in enumerate(incremental_splits):
#         current_split_indices = list(current_split.indices)
        
#         # Select 500 previously seen images (replay strategy)
#         previous_seen_indices = list(set(seen_indices_list) - set(current_split_indices))
#         additional_data_indices = np.random.choice(previous_seen_indices, replay_data_per_batch, replace=False)
        
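#         # Combine current split (500 new) and additional data (500 replay)
#         # NOTE(review): `current_split.indices` index into `remaining_dataset`, while
#         # `first_split.indices` (and so `seen_indices_list`) index into `client_dataset`.
#         # `Subset(client_dataset, combined_indices)` below therefore mixes two index
#         # spaces; map the new indices through `remaining_dataset.indices` first.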
#         combined_indices = np.concatenate([current_split_indices, additional_data_indices])
        
#         # Add current batch indices to seen indices list
#         seen_indices_list.extend(current_split_indices)
        
#         # Create a subset for the current batch with new + replay data
#         combined_split = Subset(client_dataset, combined_indices)
#         client_splits_for_cl.append(combined_split)
    
#     client_splits_cl.append(client_splits_for_cl)

# # Verify the incremental splits for Client 1
# for i, split in enumerate(client_splits_cl[0]):
#     print(f"Client 1 - Batch {i+1}: {len(split)} images")


Client 1 - Batch 1: 1312 images
Client 1 - Batch 2: 1000 images
Client 1 - Batch 3: 1000 images
Client 1 - Batch 4: 1000 images


In [None]:
# import os

# # Step 5: Train each client with Continual Learning (CL), test after training, and save models

# from torchvision import models
# import torch.optim as optim
# import torch.nn as nn
# from torch.utils.data import DataLoader

# # Model Preparation Function (ResNet18 for CIFAR-100)
# def prepare_model(num_classes=100, use_dropout=False, dropout_prob=0.2):
#     """Load a pre-trained Resnet18 model and modify it for CIFAR100 with optional dropout."""
#     model = models.resnet18(pretrained=True)
#     num_ftrs = model.fc.in_features
#     if use_dropout:
#         model.fc = nn.Sequential(
#             nn.Dropout(p=dropout_prob),
#             nn.Linear(num_ftrs, num_classes)
#         )
#     else:
#         model.fc = nn.Linear(num_ftrs, num_classes)
#     return model

# # EarlyStopping class
# class EarlyStopping:
#     def __init__(self, patience=5, min_delta=0):
#         self.patience = patience
#         self.min_delta = min_delta
#         self.counter = 0
#         self.best_loss = None
#         self.early_stop = False

#     def __call__(self, val_loss):
#         if self.best_loss is None:
#             self.best_loss = val_loss
#         elif val_loss > self.best_loss - self.min_delta:
#             self.counter += 1
#             if self.counter >= self.patience:
#                 self.early_stop = True
#         else:
#             self.best_loss = val_loss
#             self.counter = 0

# # Training function for Continual Learning
# def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=5, patience=5, min_delta=0):
#     early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
    
#     for epoch in range(num_epochs):
#         model.train()
#         running_loss = 0.0
#         correct = 0
#         total = 0

#         for inputs, labels in train_loader:
#             inputs, labels = inputs.to(device), labels.to(device)

#             optimizer.zero_grad()
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()

#             running_loss += loss.item()
#             _, predicted = torch.max(outputs.data, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()

#         train_loss = running_loss / len(train_loader)
#         train_accuracy = 100 * correct / total

#         # Validation phase
#         model.eval()
#         val_running_loss = 0.0
#         val_correct = 0
#         val_total = 0
#         with torch.no_grad():
#             for inputs, labels in val_loader:
#                 inputs, labels = inputs.to(device), labels.to(device)
#                 outputs = model(inputs)
#                 loss = criterion(outputs, labels)
#                 val_running_loss += loss.item()
#                 _, predicted = torch.max(outputs.data, 1)
#                 val_total += labels.size(0)
#                 val_correct += (predicted == labels).sum().item()

#         val_loss = val_running_loss / len(val_loader)
#         val_accuracy = 100 * val_correct / val_total

#         print(f'Epoch {epoch+1}/{num_epochs}, '
#               f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, '
#               f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

#         # Early stopping
#         early_stopping(val_loss)
#         if early_stopping.early_stop:
#             print("Early stopping triggered")
#             break
    
#     return model

# # Testing function to evaluate the model on the test set
# def test_model(model, test_loader):
#     model.eval()
#     test_correct = 0
#     test_total = 0
#     with torch.no_grad():
#         for inputs, labels in test_loader:
#             inputs, labels = inputs.to(device), labels.to(device)
#             outputs = model(inputs)
#             _, predicted = torch.max(outputs.data, 1)
#             test_total += labels.size(0)
#             test_correct += (predicted == labels).sum().item()
    
#     test_accuracy = 100 * test_correct / test_total
#     print(f'Test Accuracy: {test_accuracy:.2f}%')

# # Save function
# def save_model(model, client_idx, save_dir):
#     if not os.path.exists(save_dir):
#         os.makedirs(save_dir)
#     model_path = os.path.join(save_dir, f'client_{client_idx}_model.pth')
#     torch.save(model.state_dict(), model_path)
#     print(f'Model for Client {client_idx} saved at {model_path}')

# # Prepare the training and validation loaders for each client
# batch_size = 256
# num_workers = 4
# num_epochs = 5  # Set the number of epochs for each CL round

# # Define the directory to save the models (for 16 clients)
# save_dir = '/raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models_16clients/'

# # Iterate over all clients and perform incremental training (CL)
# for client_idx, client_batches in enumerate(client_splits_cl):
#     print(f"\nStarting Continual Learning for Client {client_idx + 1}...")
    
#     # Initialize a new model for each client
#     model = prepare_model().to(device)
#     criterion = nn.CrossEntropyLoss()
#     optimizer = optim.Adam(model.parameters(), lr=0.0001)

#     # Train on each batch incrementally
#     for batch_idx, batch in enumerate(client_batches):
#         print(f"Training on Batch {batch_idx + 1}/{len(client_batches)}")
        
#         # Create DataLoader for the current batch
#         train_loader = DataLoader(batch, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        
#         # Use the same validation set for all batches
#         model = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs)
    
#     # After training on all batches, test the model on the test set
#     print(f"Testing model for Client {client_idx + 1} after Continual Learning...")
#     test_model(model, test_loader)

#     # Save the model for each client in the 16 clients directory
#     save_model(model, client_idx + 1, save_dir)

# print("Continual Learning, testing, and saving models for all clients completed.")



Starting Continual Learning for Client 1...




Training on Batch 1/4
Epoch 1/5, Train Loss: 4.6841, Train Accuracy: 1.37%, Val Loss: 4.6223, Val Accuracy: 1.86%
Epoch 2/5, Train Loss: 4.3622, Train Accuracy: 5.95%, Val Loss: 4.4780, Val Accuracy: 4.16%
Epoch 3/5, Train Loss: 4.1142, Train Accuracy: 11.13%, Val Loss: 4.3373, Val Accuracy: 6.86%
Epoch 4/5, Train Loss: 3.9236, Train Accuracy: 17.91%, Val Loss: 4.1707, Val Accuracy: 9.42%
Epoch 5/5, Train Loss: 3.6981, Train Accuracy: 20.43%, Val Loss: 4.0286, Val Accuracy: 11.66%
Training on Batch 2/4
Epoch 1/5, Train Loss: 3.6990, Train Accuracy: 21.60%, Val Loss: 3.9337, Val Accuracy: 13.26%
Epoch 2/5, Train Loss: 3.5556, Train Accuracy: 24.70%, Val Loss: 3.8664, Val Accuracy: 14.02%
Epoch 3/5, Train Loss: 3.4217, Train Accuracy: 28.40%, Val Loss: 3.8002, Val Accuracy: 14.96%
Epoch 4/5, Train Loss: 3.2608, Train Accuracy: 32.90%, Val Loss: 3.7161, Val Accuracy: 15.52%
Epoch 5/5, Train Loss: 3.1224, Train Accuracy: 34.80%, Val Loss: 3.6615, Val Accuracy: 16.12%
Training on Batch 3/4


In [9]:
from torch.utils.data import DataLoader, random_split

# Function to create DataLoaders for each client
def create_client_loaders(dataset, num_clients=16, batch_size=256, val_split=0.1, num_workers=0, generator=None):
    """Split `dataset` into `num_clients` shards and create per-client train/validation loaders.

    Shard sizes differ by at most one sample: each gets ``len(dataset) // num_clients``
    and the remainder is spread over the first shards. The former equal-size split
    raised inside ``random_split`` whenever the length was not divisible. Each shard
    is split again into ``1 - val_split`` training and ``val_split`` validation data.

    Args:
        dataset: Map-style dataset to partition.
        num_clients: Number of client shards (1 <= num_clients <= len(dataset)).
        batch_size: Batch size for every returned DataLoader.
        val_split: Fraction of each shard held out for validation, in [0, 1).
        num_workers: DataLoader worker processes (0 = load in the main process,
            the DataLoader default previously relied on).
        generator: Optional ``torch.Generator`` for reproducible splits; ``None``
            uses torch's global RNG.

    Returns:
        ``(client_loaders, val_loaders)``: lists of shuffled training loaders and
        unshuffled validation loaders, one per client.

    Raises:
        ValueError: If `num_clients` or `val_split` is out of range.
    """
    if not 1 <= num_clients <= len(dataset):
        raise ValueError(f"num_clients must be in [1, {len(dataset)}], got {num_clients}")
    if not 0.0 <= val_split < 1.0:
        raise ValueError(f"val_split must be in [0, 1), got {val_split}")

    split_kwargs = {} if generator is None else {"generator": generator}

    base_size, remainder = divmod(len(dataset), num_clients)
    shard_sizes = [base_size + (1 if i < remainder else 0) for i in range(num_clients)]
    client_datasets = random_split(dataset, shard_sizes, **split_kwargs)

    client_loaders = []
    val_loaders = []
    for client_dataset in client_datasets:
        # Further split each client's dataset into train and validation sets
        train_size = int((1 - val_split) * len(client_dataset))
        val_size = len(client_dataset) - train_size
        train_dataset, val_dataset = random_split(client_dataset, [train_size, val_size], **split_kwargs)

        client_loaders.append(
            DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        )
        val_loaders.append(
            DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        )

    return client_loaders, val_loaders

# Partition the full CIFAR-100 training set (`dataset_train`) into 16 client shards.
train_loaders, val_loaders = create_client_loaders(dataset=dataset_train, num_clients=16)

In [10]:
import os
import torch.optim as optim
import torch.nn as nn
from torchvision import models

# Directory holding the per-client checkpoints written by the 16-client CL run.
_FL_WORKSPACE = '/raid/home/somayeh.shami/project/somayeh_workspace/federated_learning'
saved_models_dir = _FL_WORKSPACE + '/cl_models_16clients/'

# Model Preparation (ResNet18 for CIFAR-100)
def prepare_model(num_classes=100, use_dropout=False, dropout_prob=0.2, pretrained=True):
    """Build a ResNet18 whose final layer outputs `num_classes` logits.

    Args:
        num_classes: Output size of the replacement classification head.
        use_dropout: If True, the head is ``Dropout(dropout_prob) -> Linear``.
        dropout_prob: Dropout probability used when `use_dropout` is True.
        pretrained: Load ImageNet weights (default, as before). Pass False when the
            weights are about to be overwritten anyway (e.g. from a checkpoint).

    Returns:
        The modified ``torchvision`` ResNet18 (on CPU).
    """
    # torchvision >= 0.13 replaced `pretrained=` with the `weights=` enum API and
    # deprecated the old keyword; fall back to it only on older versions.
    if hasattr(models, "ResNet18_Weights"):
        weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        model = models.resnet18(weights=weights)
    else:
        model = models.resnet18(pretrained=pretrained)

    num_ftrs = model.fc.in_features
    if use_dropout:
        model.fc = nn.Sequential(
            nn.Dropout(p=dropout_prob),
            nn.Linear(num_ftrs, num_classes)
        )
    else:
        model.fc = nn.Linear(num_ftrs, num_classes)
    return model

# Function to load a model from a file
def load_client_model(client_idx, save_dir):
    """Rebuild the model architecture and load the checkpoint of client `client_idx`.

    Args:
        client_idx: Client id used in the file name ``client_{client_idx}_model.pth``
            (callers in this notebook use 1-based ids).
        save_dir: Directory containing the checkpoints.

    Returns:
        The model with the checkpoint's weights, on the global `device`.
    """
    model = prepare_model().to(device)
    model_path = os.path.join(save_dir, f'client_{client_idx}_model.pth')
    # map_location lets a checkpoint saved on one device (e.g. a specific GPU) be
    # restored on whatever `device` this kernel uses.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    print(f'Loaded model for Client {client_idx} from {model_path}')
    return model

# Function to average client models (FL aggregation)
def federated_averaging(state_dicts, weights=None):
    """Average model state_dicts key by key (FedAvg aggregation).

    Args:
        state_dicts: Non-empty list of state_dicts with identical keys and shapes.
        weights: Optional per-client weights (e.g. local sample counts). ``None``
            gives the plain unweighted mean, as before.

    Returns:
        A new dict with the (weighted) mean of every entry. Floating-point tensors
        keep their dtype. Integer/bool buffers (e.g. BatchNorm's
        ``num_batches_tracked``) are averaged, rounded and cast back to their
        original dtype instead of silently becoming floats.

    Raises:
        ValueError: If `state_dicts` is empty, or `weights` has the wrong length
            or does not sum to a positive value.
    """
    if not state_dicts:
        raise ValueError("federated_averaging needs at least one state_dict")
    if weights is None:
        weights = [1.0] * len(state_dicts)
    if len(weights) != len(state_dicts):
        raise ValueError("weights must have exactly one entry per state_dict")
    total_weight = float(sum(weights))
    if total_weight <= 0:
        raise ValueError("weights must sum to a positive value")
    coeffs = [float(w) / total_weight for w in weights]

    avg_state_dict = {}
    for key, reference in state_dicts[0].items():
        if torch.is_floating_point(reference):
            avg_state_dict[key] = sum(c * sd[key] for c, sd in zip(coeffs, state_dicts))
        else:
            # Average non-float buffers in float64, then restore the original dtype.
            mean = sum(c * sd[key].double() for c, sd in zip(coeffs, state_dicts))
            avg_state_dict[key] = mean.round().to(reference.dtype)
    return avg_state_dict

# Training function for each client after receiving the global model
def fine_tune_client(global_model, train_loader, val_loader, num_epochs=5):
    """Fine-tune a private copy of `global_model` on one client's data.

    The global model is deep-copied. Previously ``global_model.to(device)``
    returned the same module, so all clients in a round trained the shared model
    one after another. Every returned state_dict also pointed at the same tensors,
    which made the FedAvg average a no-op.

    Args:
        global_model: Current global model; it is not modified.
        train_loader: Client training data.
        val_loader: Client validation data, evaluated after every epoch.
        num_epochs: Local epochs (default 5, reduced to save time).

    Returns:
        The state_dict of the fine-tuned local copy.
    """
    import copy

    model = copy.deepcopy(global_model).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # Guard against empty loaders instead of raising ZeroDivisionError.
        train_loss = running_loss / num_batches if num_batches else 0.0
        train_accuracy = 100 * correct / total if total else 0.0
        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%')

        # Validation phase after each epoch
        model.eval()
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_accuracy = 100 * val_correct / val_total if val_total else 0.0
        print(f'Epoch {epoch+1}/{num_epochs}, Validation Accuracy: {val_accuracy:.2f}%')

    return model.state_dict()

# Step 6: Federated Learning (FL) after Continual Learning (CL)
def apply_federated_learning(cl_models, train_loaders, val_loaders, test_loader, num_clients=8, num_epochs=10, local_epochs=5):
    """Run FedAvg starting from the average of the Continual-Learning client models.

    The initial global model is the unweighted average of `cl_models`. Each of the
    `num_epochs` federated rounds then does three things:
      1. fine-tunes the current global model on the first `num_clients` clients;
      2. averages the returned weights back into the global model;
      3. reports test accuracy.

    Args:
        cl_models: Models from the CL stage; their state_dicts are averaged once to
            initialise the global model. May hold more models than `num_clients`
            (e.g. 16 CL checkpoints with 8 FL clients).
        train_loaders: Per-client training loaders, indexed by client.
        val_loaders: Per-client validation loaders, indexed by client.
        test_loader: Loader used to evaluate the global model after every round.
        num_clients: Number of clients participating in every round.
        num_epochs: Number of federated *rounds* (not local epochs).
        local_epochs: Local fine-tuning epochs per client per round (was
            hard-coded to 5, which remains the default).

    Returns:
        The global model after the last round.

    Raises:
        ValueError: If `cl_models` is empty or there are fewer per-client loaders
            than `num_clients`.
    """
    if not cl_models:
        raise ValueError("apply_federated_learning needs at least one CL model")
    if num_clients > min(len(train_loaders), len(val_loaders)):
        raise ValueError(
            f"num_clients={num_clients} but only {len(train_loaders)} train / "
            f"{len(val_loaders)} val loaders were given"
        )

    global_model = prepare_model().to(device)

    # Initialise the global model with the average of the CL client weights
    global_model.load_state_dict(federated_averaging([model.state_dict() for model in cl_models]))

    for round_idx in range(num_epochs):
        print(f'\n--- Federated Learning Round {round_idx + 1} ---')
        client_state_dicts = []

        for client_idx in range(num_clients):
            print(f'\nTraining client {client_idx + 1} with the global model')
            client_state_dict = fine_tune_client(
                global_model, train_loaders[client_idx], val_loaders[client_idx], num_epochs=local_epochs
            )
            client_state_dicts.append(client_state_dict)

        global_model.load_state_dict(federated_averaging(client_state_dicts))

        test_accuracy = test_global_model(global_model, test_loader)
        print(f'Test Accuracy after Round {round_idx + 1}: {test_accuracy:.2f}%')

    return global_model

# Function to test the global model
def test_global_model(model, test_loader):
    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()

    test_accuracy = 100 * test_correct / test_total
    return test_accuracy

# Stage 1: restore the checkpoints produced by the 16-client Continual Learning run.
NUM_CL_CHECKPOINTS = 16
cl_models = [load_client_model(idx, saved_models_dir) for idx in range(1, NUM_CL_CHECKPOINTS + 1)]

# Stage 2: re-partition the full training set for the 8 federated clients.
train_loaders, val_loaders = create_client_loaders(dataset_train, num_clients=8)

# Stage 3: average the CL checkpoints into a global model, run FL rounds, and test it.
global_model = apply_federated_learning(cl_models, train_loaders, val_loaders, test_loader)



Loaded model for Client 1 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models_16clients/client_1_model.pth
Loaded model for Client 2 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models_16clients/client_2_model.pth
Loaded model for Client 3 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models_16clients/client_3_model.pth
Loaded model for Client 4 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models_16clients/client_4_model.pth
Loaded model for Client 5 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models_16clients/client_5_model.pth
Loaded model for Client 6 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models_16clients/client_6_model.pth
Loaded model for Client 7 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models_16clients/client_7_model.pth
Loaded model for Cli

KeyboardInterrupt: 

In [None]:
import os
import torch.optim as optim
import torch.nn as nn
from torchvision import models

# Directory holding the per-client checkpoints written by the 8-client CL run.
_FL_WORKSPACE = '/raid/home/somayeh.shami/project/somayeh_workspace/federated_learning'
saved_models_dir = _FL_WORKSPACE + '/cl_models/'

# Model Preparation (ResNet18 for CIFAR-100)
def prepare_model(num_classes=100, use_dropout=False, dropout_prob=0.2, pretrained=True):
    """Build a ResNet18 whose final layer outputs `num_classes` logits.

    Args:
        num_classes: Output size of the replacement classification head.
        use_dropout: If True, the head is ``Dropout(dropout_prob) -> Linear``.
        dropout_prob: Dropout probability used when `use_dropout` is True.
        pretrained: Load ImageNet weights (default, as before). Pass False when the
            weights are about to be overwritten anyway (e.g. from a checkpoint).

    Returns:
        The modified ``torchvision`` ResNet18 (on CPU).
    """
    # torchvision >= 0.13 replaced `pretrained=` with the `weights=` enum API and
    # deprecated the old keyword; fall back to it only on older versions.
    if hasattr(models, "ResNet18_Weights"):
        weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        model = models.resnet18(weights=weights)
    else:
        model = models.resnet18(pretrained=pretrained)

    num_ftrs = model.fc.in_features
    if use_dropout:
        model.fc = nn.Sequential(
            nn.Dropout(p=dropout_prob),
            nn.Linear(num_ftrs, num_classes)
        )
    else:
        model.fc = nn.Linear(num_ftrs, num_classes)
    return model

# Function to load a model from a file
def load_client_model(client_idx, save_dir):
    """Rebuild the model architecture and load the checkpoint of client `client_idx`.

    Args:
        client_idx: Client id used in the file name ``client_{client_idx}_model.pth``
            (callers in this notebook use 1-based ids).
        save_dir: Directory containing the checkpoints.

    Returns:
        The model with the checkpoint's weights, on the global `device`.
    """
    model = prepare_model().to(device)
    model_path = os.path.join(save_dir, f'client_{client_idx}_model.pth')
    # map_location lets a checkpoint saved on one device (e.g. a specific GPU) be
    # restored on whatever `device` this kernel uses.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    print(f'Loaded model for Client {client_idx} from {model_path}')
    return model

# Function to average client models (FL aggregation)
def federated_averaging(state_dicts, weights=None):
    """Average model state_dicts key by key (FedAvg aggregation).

    Args:
        state_dicts: Non-empty list of state_dicts with identical keys and shapes.
        weights: Optional per-client weights (e.g. local sample counts). ``None``
            gives the plain unweighted mean, as before.

    Returns:
        A new dict with the (weighted) mean of every entry. Floating-point tensors
        keep their dtype. Integer/bool buffers (e.g. BatchNorm's
        ``num_batches_tracked``) are averaged, rounded and cast back to their
        original dtype instead of silently becoming floats.

    Raises:
        ValueError: If `state_dicts` is empty, or `weights` has the wrong length
            or does not sum to a positive value.
    """
    if not state_dicts:
        raise ValueError("federated_averaging needs at least one state_dict")
    if weights is None:
        weights = [1.0] * len(state_dicts)
    if len(weights) != len(state_dicts):
        raise ValueError("weights must have exactly one entry per state_dict")
    total_weight = float(sum(weights))
    if total_weight <= 0:
        raise ValueError("weights must sum to a positive value")
    coeffs = [float(w) / total_weight for w in weights]

    avg_state_dict = {}
    for key, reference in state_dicts[0].items():
        if torch.is_floating_point(reference):
            avg_state_dict[key] = sum(c * sd[key] for c, sd in zip(coeffs, state_dicts))
        else:
            # Average non-float buffers in float64, then restore the original dtype.
            mean = sum(c * sd[key].double() for c, sd in zip(coeffs, state_dicts))
            avg_state_dict[key] = mean.round().to(reference.dtype)
    return avg_state_dict

# Training function for each client after receiving the global model
def fine_tune_client(global_model, train_loader, val_loader, num_epochs=10):
    """Train a private copy of `global_model` on one client's data and return its weights.

    The global model is deep-copied. Previously ``global_model.to(device)``
    returned the same module, so all clients trained the shared model one after
    another. Every returned state_dict also aliased the same tensors, which made
    the FedAvg average a no-op.

    Args:
        global_model: Current global model; it is not modified.
        train_loader: Client training data for the current batch.
        val_loader: Accepted for interface compatibility; this variant does not
            run validation.
        num_epochs: Local training epochs.

    Returns:
        The state_dict of the fine-tuned local copy.
    """
    import copy

    model = copy.deepcopy(global_model).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)

    # Fine-tune the model for the client
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # Guard against empty loaders instead of raising ZeroDivisionError.
        train_loss = running_loss / num_batches if num_batches else 0.0
        train_accuracy = 100 * correct / total if total else 0.0
        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%')

    return model.state_dict()

# Step 6: Federated Learning (FL) after each round of Continual Learning (CL)
def apply_federated_learning_after_each_batch(client_splits_cl, num_clients=8, num_epochs=10, batch_size=256):
    """Interleave Continual Learning batches with FedAvg rounds.

    For every incremental batch index, each client fine-tunes the current global
    model on its own data for that batch. The client weights are then averaged
    into the global model, which is evaluated on the global `test_loader`.

    Args:
        client_splits_cl: Per-client lists of datasets, one dataset per CL batch;
            all participating clients must have the same number of batches.
        num_clients: Number of clients taking part (the first `num_clients`).
        num_epochs: Local epochs per client per batch. This used to be ignored in
            favour of a hard-coded 10, which remains the default.
        batch_size: DataLoader batch size for client training.

    Returns:
        The global model after the last batch.

    Raises:
        ValueError: If there are fewer client splits than `num_clients`, or the
            clients have different numbers of batches.
    """
    if num_clients < 1 or num_clients > len(client_splits_cl):
        raise ValueError(
            f"num_clients={num_clients} but {len(client_splits_cl)} client splits were given"
        )
    num_batches = len(client_splits_cl[0])
    if any(len(client_splits_cl[i]) != num_batches for i in range(num_clients)):
        raise ValueError("every participating client must have the same number of CL batches")

    # Initialize a global model
    global_model = prepare_model().to(device)

    for batch_idx in range(num_batches):
        print(f'\n--- Training and Federated Learning after Batch {batch_idx + 1} ---')

        client_state_dicts = []

        for client_idx in range(num_clients):
            print(f'\nTraining client {client_idx + 1} on Batch {batch_idx + 1}')

            # Create DataLoader for the current batch
            train_loader = DataLoader(client_splits_cl[client_idx][batch_idx], batch_size=batch_size, shuffle=True)

            # Fine-tune from the current global model; `val_loader` is the notebook-global
            # validation loader, passed through unchanged.
            client_state_dict = fine_tune_client(global_model, train_loader, val_loader, num_epochs=num_epochs)
            client_state_dicts.append(client_state_dict)

        # Perform federated averaging after this batch for all clients
        global_model.load_state_dict(federated_averaging(client_state_dicts))

        # Test the global model after each batch
        test_accuracy = test_global_model(global_model, test_loader)
        print(f'Test Accuracy after Batch {batch_idx + 1}: {test_accuracy:.2f}%')

    return global_model


# Function to test the global model
def test_global_model(model, test_loader):
    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()

    test_accuracy = 100 * test_correct / test_total
    return test_accuracy

# `client_splits_cl` is only produced by the (currently commented-out) incremental
# replay-split cell. Fail fast with an actionable message, before the expensive
# checkpoint loading below, instead of a bare NameError on a fresh kernel.
if 'client_splits_cl' not in globals():
    raise RuntimeError(
        "client_splits_cl is undefined: run the incremental-split (replay) cell "
        "before this one."
    )

# Load the models saved after Continual Learning (CL).
# NOTE(review): `cl_models`, `train_loaders` and `val_loaders` are not consumed by
# apply_federated_learning_after_each_batch below.
cl_models = [load_client_model(client_idx, saved_models_dir) for client_idx in range(1, 9)]

# Prepare the dataset splits for each client
train_loaders, val_loaders = create_client_loaders(dataset_train, num_clients=8)

# Apply Federated Learning and test the global model
global_model = apply_federated_learning_after_each_batch(client_splits_cl, num_clients=8, num_epochs=10)