In [1]:
import os
# Expose only physical GPU 6 to this process. This only takes effect if it runs
# before CUDA is initialized in the kernel, so keep it in the first cell.
os.environ.update({"CUDA_VISIBLE_DEVICES": "6"})

In [2]:
# Step 1: Setting up FL with 8 Clients

import torch
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, random_split

# Fixed seed so the train/val split, the later client splits (which draw from
# torch's global RNG) and DataLoader shuffling are reproducible across runs.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Set device (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

# ImageNet normalization statistics (the backbone used later is ImageNet-pretrained).
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Training-time augmentation for CIFAR-100 (images are upscaled to 224x224).
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# Deterministic preprocessing for validation/test. Random crops and flips at
# evaluation time would make the reported accuracy noisy and pessimistic.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# Load CIFAR-100. The training files are opened twice so the validation subset
# can use `eval_transform` while the training subset keeps its augmentation.
dataset_train = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
dataset_train_eval = datasets.CIFAR100(root='./data', train=True, download=False, transform=eval_transform)
dataset_test = datasets.CIFAR100(root='./data', train=False, download=True, transform=eval_transform)

# Reserve 10% of the training set for validation (45,000 train / 5,000 val).
num_train_images = int(0.9 * len(dataset_train))
num_val_images = len(dataset_train) - num_train_images

train_dataset, val_subset = random_split(
    dataset_train,
    [num_train_images, num_val_images],
    generator=torch.Generator().manual_seed(SEED),
)
# Same images as the random validation split, but served with eval preprocessing.
val_dataset = Subset(dataset_train_eval, val_subset.indices)

# Create DataLoaders for training, validation, and test sets
batch_size = 256
num_workers = 4

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Output to check the size of the training, validation, and test sets
print(f"Training Dataset Size: {len(train_loader.dataset)} images")
print(f"Validation Dataset Size: {len(val_loader.dataset)} images")
print(f"Test Dataset Size: {len(test_loader.dataset)} images")

Using device: cuda
Files already downloaded and verified
Files already downloaded and verified
Training Dataset Size: 45000 images
Validation Dataset Size: 5000 images
Test Dataset Size: 10000 images


In [3]:
# Step 2: Split the training dataset into 8 clients

# Split the training images evenly across the clients (45,000 / 8 = 5,625 each).
# Sizes are derived from the dataset length so the split stays valid if the
# train/val ratio changes; any remainder goes to the first clients.
num_clients = 8
base_client_size, client_remainder = divmod(len(train_dataset), num_clients)
client_sizes = [base_client_size + (1 if k < client_remainder else 0) for k in range(num_clients)]
client_splits = torch.utils.data.random_split(train_dataset, client_sizes)

# Verify each client's dataset size
for i, client_dataset in enumerate(client_splits):
    print(f"Client {i+1} Dataset Size: {len(client_dataset)} images")

Client 1 Dataset Size: 5625 images
Client 2 Dataset Size: 5625 images
Client 3 Dataset Size: 5625 images
Client 4 Dataset Size: 5625 images
Client 5 Dataset Size: 5625 images
Client 6 Dataset Size: 5625 images
Client 7 Dataset Size: 5625 images
Client 8 Dataset Size: 5625 images


In [4]:
# Step 3: Create DataLoaders for each client

batch_size, num_workers = 256, 4

# One shuffled DataLoader per client split
client_loaders = list(
    DataLoader(split, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    for split in client_splits
)

# Validation DataLoader (rebuilt here with the same settings as in Step 1)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Sanity check: pull a single batch from the first client and report its shapes
images, labels = next(iter(client_loaders[0]))
print(f"Client 1 First Batch - Images Shape: {images.shape}, Labels Shape: {labels.shape}")

Client 1 First Batch - Images Shape: torch.Size([256, 3, 224, 224]), Labels Shape: torch.Size([256])


In [9]:
import numpy as np
from torch.utils.data import random_split, ConcatDataset

# Split each client's 5625 images into 1 initial batch and 6 incremental batches
initial_split_size = 2625  # First batch size
num_incremental_batches = 6  # Number of incremental batches
new_data_per_batch = 500  # New images per incremental batch
replay_size = 500  # Replay size (previously seen data to include)


def build_incremental_splits(client_dataset, initial_size=initial_split_size,
                             num_batches=num_incremental_batches,
                             new_per_batch=new_data_per_batch, replay=replay_size):
    """Split one client's data into an initial batch plus replay-augmented incremental batches.

    Returns a list of ``1 + num_batches`` datasets. The first is a random split of
    `initial_size` images. Each incremental step then concatenates `new_per_batch`
    unseen images with `replay` images sampled without replacement from the unique
    images seen in all earlier steps.

    Raises:
        ValueError: if the client's data does not split exactly into
            ``initial_size + num_batches * new_per_batch`` images, or if fewer than
            `replay` images have been seen when a replay sample is drawn.
    """
    remaining_size = len(client_dataset) - initial_size  # 3000 for 5625-image clients
    if remaining_size != num_batches * new_per_batch:
        raise ValueError(
            f"Client has {len(client_dataset)} images; expected "
            f"{initial_size} + {num_batches} x {new_per_batch}"
        )
    first_split, remaining_dataset = random_split(client_dataset, [initial_size, remaining_size])
    incremental_splits = random_split(remaining_dataset, [new_per_batch] * num_batches)

    batches = [first_split]
    # Replay is drawn from the pool of *unique* images seen so far, not from the
    # previous batches. Those batches already contain replayed copies, so sampling
    # from them could repeat an image inside one replay sample and over-weight the
    # earliest data. Rebuilding a flat ConcatDataset also keeps nesting shallow.
    seen_parts = [first_split]
    for new_split in incremental_splits:
        seen = ConcatDataset(seen_parts)
        if len(seen) < replay:
            raise ValueError(f"Only {len(seen)} seen images available for a replay of {replay}")
        replay_data, _ = random_split(seen, [replay, len(seen) - replay])
        batches.append(ConcatDataset([new_split, replay_data]))
        seen_parts.append(new_split)
    return batches


# Build the incremental (continual-learning) splits for every client
client_splits_cl = [build_incremental_splits(client_dataset) for client_dataset in client_splits]

# Verify the incremental splits for Client 1
for i, split in enumerate(client_splits_cl[0]):
    print(f"Client 1 - Batch {i+1}: {len(split)} images")

Client 1 - Batch 1: 2625 images
Client 1 - Batch 2: 1000 images
Client 1 - Batch 3: 1000 images
Client 1 - Batch 4: 1000 images
Client 1 - Batch 5: 1000 images
Client 1 - Batch 6: 1000 images
Client 1 - Batch 7: 1000 images


In [10]:
from torch.utils.data import DataLoader, random_split

# Function to create DataLoaders for each client
def create_client_loaders(dataset, num_clients=8, batch_size=128, val_split=0.1, num_workers=8):
    """Split `dataset` across `num_clients` clients and build train/validation loaders for each.

    The dataset is randomly partitioned into `num_clients` disjoint shards whose sizes
    differ by at most one (the first ``len(dataset) % num_clients`` shards get one extra
    sample), so lengths not divisible by `num_clients` are supported. Each shard is
    then randomly split into a training part and a `val_split` validation part.

    Args:
        dataset: map-style dataset to partition.
        num_clients: number of clients (shards).
        batch_size: batch size for both loaders.
        val_split: fraction of each shard held out for validation, in ``[0, 1)``.
        num_workers: DataLoader worker processes per loader.

    Returns:
        ``(train_loaders, val_loaders)``: two lists of length `num_clients`. Training
        loaders shuffle; validation loaders do not. Both use ``pin_memory=True``.

    Raises:
        ValueError: if `num_clients` < 1 or `val_split` is outside ``[0, 1)``.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")
    if not 0 <= val_split < 1:
        raise ValueError(f"val_split must be in [0, 1), got {val_split}")

    client_loaders = []
    val_loaders = []

    # random_split requires the lengths to sum to len(dataset), so spread the remainder.
    base_size, remainder = divmod(len(dataset), num_clients)
    shard_sizes = [base_size + (1 if i < remainder else 0) for i in range(num_clients)]
    client_datasets = random_split(dataset, shard_sizes)

    for client_dataset in client_datasets:
        # Further split each client's dataset into train and validation sets
        train_size = int((1 - val_split) * len(client_dataset))
        val_size = len(client_dataset) - train_size
        train_dataset, val_dataset = random_split(client_dataset, [train_size, val_size])

        # pin_memory speeds up host-to-GPU copies of each batch
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                  num_workers=num_workers, pin_memory=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                                num_workers=num_workers, pin_memory=True)

        client_loaders.append(train_loader)
        val_loaders.append(val_loader)

    return client_loaders, val_loaders

# Per-client train/validation loaders built over the full CIFAR-100 training set
# (`dataset_train` from Step 1).
train_loaders, val_loaders = create_client_loaders(
    dataset_train,
    num_clients=8,
    batch_size=128,
    num_workers=8,
)

In [12]:
import os
import torch.optim as optim
import torch.nn as nn
from torchvision import models
from torch.cuda.amp import autocast, GradScaler  # For mixed-precision training

# Directory holding the client checkpoints saved after continual learning.
# Set the CL_MODELS_DIR environment variable to run on another machine; the
# default is the original location.
saved_models_dir = os.environ.get(
    'CL_MODELS_DIR',
    '/raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models/',
)

# Model Preparation (ResNet18 for CIFAR-100)
def prepare_model(num_classes=100, use_dropout=False, dropout_prob=0.2):
    """Build an ImageNet-pretrained ResNet-18 whose classifier is replaced for `num_classes`.

    Args:
        num_classes: output size of the new final linear layer.
        use_dropout: if True, the head is ``Dropout(dropout_prob)`` followed by the linear layer.
        dropout_prob: dropout probability; only used when `use_dropout` is True.

    Returns:
        The modified torchvision ResNet-18. Callers move it to the target device.
    """
    # NOTE(review): torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`.
    backbone = models.resnet18(pretrained=True)
    head = nn.Linear(backbone.fc.in_features, num_classes)
    if use_dropout:
        head = nn.Sequential(nn.Dropout(p=dropout_prob), head)
    backbone.fc = head
    return backbone

# Function to load a model from a file
def load_client_model(client_idx, save_dir):
    """Rebuild the ResNet-18 architecture and load ``client_{client_idx}_model.pth`` from `save_dir`.

    Args:
        client_idx: client number used in the checkpoint file name.
        save_dir: directory containing the checkpoints.

    Returns:
        The model on `device` with the checkpoint's weights loaded.
    """
    model = prepare_model().to(device)
    model_path = os.path.join(save_dir, f'client_{client_idx}_model.pth')
    # map_location remaps tensors saved on another GPU (or on GPU when running on
    # CPU) onto `device`; without it, loading fails when that device is unavailable.
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    print(f'Loaded model for Client {client_idx} from {model_path}')
    return model

# Server-side FedAdam configuration. `optimizer_state` starts empty and is
# created by `federated_opt` on its first call.
optimizer_state = None
beta_1, beta_2 = 0.9, 0.99  # decay rates of the first / second moment estimates
epsilon = 1e-8  # stabilizer added to the denominator
eta = 0.001  # server-side learning rate

# Function for adaptive federated optimization (FedAdam)
def federated_opt(state_dicts, global_model, optimizer_state,
                  server_lr=None, b1=None, b2=None, eps=None):
    """Aggregate client weights into new global weights with FedAdam.

    For each trainable parameter, the pseudo-gradient is the mean client update
    ``delta = mean_i(w_i) - w_global``. Adam first and second moments of ``delta``
    persist across rounds in `optimizer_state`, with bias correction by step count.
    The server steps *towards* the clients:
    ``w_global + lr * m_hat / (sqrt(v_hat) + eps)``.

    Buffers are not gradients. Floating buffers (e.g. BatchNorm running statistics)
    are averaged directly, and non-floating buffers (e.g. ``num_batches_tracked``)
    are copied from the first client.

    Args:
        state_dicts: non-empty list of client state dicts with the same keys as `global_model`.
        global_model: current global model; it is read, not modified.
        optimizer_state: state returned by a previous call, or None to start fresh.
        server_lr, b1, b2, eps: optional overrides for the module-level
            `eta`, `beta_1`, `beta_2` and `epsilon`.

    Returns:
        ``(new_state_dict, optimizer_state)``. Load the first into the global model
        and pass the second back in on the next round.

    Raises:
        ValueError: if `state_dicts` is empty.
    """
    if not state_dicts:
        raise ValueError("federated_opt needs at least one client state dict")

    # Resolve hyperparameters; module-level globals are only read when not overridden.
    lr = eta if server_lr is None else server_lr
    beta1 = beta_1 if b1 is None else b1
    beta2 = beta_2 if b2 is None else b2
    tau = epsilon if eps is None else eps

    # Snapshot once: state_dict() builds a new dict on every call.
    global_state = global_model.state_dict()
    param_keys = {name for name, _ in global_model.named_parameters()}

    if optimizer_state is None:
        optimizer_state = {"step": 0, "m": {}, "v": {}}
    optimizer_state["step"] += 1
    step = optimizer_state["step"]
    bias1 = 1 - beta1 ** step
    bias2 = 1 - beta2 ** step

    new_state = {}
    with torch.no_grad():
        for key in state_dicts[0].keys():
            global_value = global_state[key]
            client_values = [sd[key] for sd in state_dicts]

            if key in param_keys:
                delta = torch.stack(client_values).mean(dim=0) - global_value

                m = optimizer_state["m"].get(key)
                v = optimizer_state["v"].get(key)
                if m is None:
                    m = torch.zeros_like(global_value)
                    v = torch.zeros_like(global_value)
                # The second moment is its own persistent state (not derived from m),
                # so v >= 0 and sqrt(v_hat) is always defined.
                m = beta1 * m + (1 - beta1) * delta
                v = beta2 * v + (1 - beta2) * delta.pow(2)
                optimizer_state["m"][key] = m
                optimizer_state["v"][key] = v

                m_hat = m / bias1
                v_hat = v / bias2
                # delta points from the global model towards the clients, so add the step.
                new_state[key] = global_value + lr * m_hat / (v_hat.sqrt() + tau)
            elif torch.is_floating_point(global_value):
                new_state[key] = torch.stack(client_values).mean(dim=0)
            else:
                new_state[key] = client_values[0].clone()

    return new_state, optimizer_state

# Training function for each client after receiving the global model
def fine_tune_client(global_model, train_loader, val_loader, num_epochs=3, use_mixed_precision=True):
    """Fine-tune a private copy of `global_model` on one client's data and return its weights.

    The global model is deep-copied before training. Training it in place would make
    every client continue from the previous client's weights. The returned state
    dicts would also all alias the same tensors, so the server would see identical
    clients and compute a zero update.

    Args:
        global_model: current global model; left unchanged.
        train_loader: the client's training DataLoader.
        val_loader: the client's validation DataLoader, evaluated once after training
            (skipped if None).
        num_epochs: number of local epochs.
        use_mixed_precision: use CUDA autocast + GradScaler; ignored when `device` is not CUDA.

    Returns:
        The state dict of the fine-tuned copy.
    """
    import copy

    model = copy.deepcopy(global_model).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)

    # torch.cuda.amp only applies on CUDA; on CPU it would just warn and disable itself.
    use_amp = use_mixed_precision and device.type == 'cuda'
    scaler = GradScaler() if use_amp else None

    for epoch in range(num_epochs):
        model.train()
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            if use_amp:
                with autocast():
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_accuracy = 100 * correct / total if total else 0.0
        print(f'Epoch {epoch+1}/{num_epochs}, Train Accuracy: {train_accuracy:.2f}%')

    if val_loader is not None:
        val_accuracy = test_global_model(model, val_loader)
        print(f'Validation Accuracy: {val_accuracy:.2f}%')

    return model.state_dict()

# Step 6: Federated Learning (FL) after Continual Learning (CL)
def apply_federated_learning(cl_models, train_loaders, val_loaders, test_loader, num_clients=8, num_epochs=3,
                             local_epochs=3):
    """Merge the CL client models into a global model, then run FedAdam rounds.

    Args:
        cl_models: client models after continual learning; assumed to share the
            `prepare_model()` architecture.
        train_loaders, val_loaders: per-client DataLoaders, indexed by client.
        test_loader: DataLoader used to report global test accuracy after each round.
        num_clients: number of clients trained per round.
        num_epochs: number of federated rounds (name kept for backward compatibility).
        local_epochs: local fine-tuning epochs per client per round.

    Returns:
        The final global model (also stored in the module-level `global_model`).

    Raises:
        ValueError: if `cl_models` is empty or there are fewer loaders than `num_clients`.
    """
    global global_model, optimizer_state

    if not cl_models:
        raise ValueError("apply_federated_learning needs at least one CL client model")
    if num_clients > min(len(train_loaders), len(val_loaders)):
        raise ValueError(f"num_clients={num_clients} exceeds the number of client loaders")

    global_model = prepare_model().to(device)

    # Initialise the global model as the plain average (FedAvg) of the CL models.
    # A single FedAdam step from the fresh model moves each weight by only about
    # the server learning rate, which would effectively discard the CL weights,
    # including the newly initialised classifier head.
    state_dicts = [model.state_dict() for model in cl_models]
    init_state = {}
    with torch.no_grad():
        for key, ref in state_dicts[0].items():
            if torch.is_floating_point(ref):
                init_state[key] = torch.stack([sd[key] for sd in state_dicts]).mean(dim=0)
            else:
                init_state[key] = ref.clone()
    global_model.load_state_dict(init_state)

    # Fresh server optimizer state: moments left over from an earlier call belong
    # to a different global model.
    optimizer_state = None

    for round_idx in range(num_epochs):
        print(f'\n--- Federated Learning Round {round_idx + 1} ---')
        client_state_dicts = []

        for client_idx in range(num_clients):
            print(f'\nTraining client {client_idx + 1} with the global model')
            client_state_dict = fine_tune_client(global_model, train_loaders[client_idx],
                                                 val_loaders[client_idx], num_epochs=local_epochs)
            client_state_dicts.append(client_state_dict)

        # Federated optimization after each round using FedOpt
        avg_state_dict, optimizer_state = federated_opt(client_state_dicts, global_model, optimizer_state)
        global_model.load_state_dict(avg_state_dict)

        test_accuracy = test_global_model(global_model, test_loader)
        print(f'Test Accuracy after Round {round_idx + 1}: {test_accuracy:.2f}%')

    return global_model

# Function to test the global model
def test_global_model(model, test_loader):
    """Return the top-1 accuracy of `model`, in percent, over every batch of `test_loader`.

    Puts `model` in eval mode (and leaves it there), disables gradient tracking, and
    moves each batch to the module-level `device`.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for batch_inputs, batch_labels in test_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            predictions = model(batch_inputs).data.max(1)[1]
            n_seen += batch_labels.size(0)
            n_correct += (predictions == batch_labels).sum().item()
    return 100 * n_correct / n_seen

# Rebuild the eight client models saved after continual learning (files are numbered 1..8).
cl_models = [load_client_model(idx, saved_models_dir) for idx in range(1, 8 + 1)]

# Fresh random per-client train/val loaders over `dataset_train`. These use the same
# settings as the earlier create_client_loaders call (pin_memory=True,
# num_workers defaults to 8), so this re-splits the data rather than using fewer workers.
train_loaders, val_loaders = create_client_loaders(
    dataset_train,
    num_clients=8,
    batch_size=128,
)

# Run FL starting from the CL models; test accuracy is printed after every round.
global_model = apply_federated_learning(
    cl_models, train_loaders, val_loaders, test_loader
)

Loaded model for Client 1 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models/client_1_model.pth
Loaded model for Client 2 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models/client_2_model.pth
Loaded model for Client 3 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models/client_3_model.pth
Loaded model for Client 4 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models/client_4_model.pth
Loaded model for Client 5 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models/client_5_model.pth
Loaded model for Client 6 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models/client_6_model.pth
Loaded model for Client 7 from /raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models/client_7_model.pth
Loaded model for Client 8 from /raid/home/somayeh.shami/project/somayeh_workspace/federate

KeyboardInterrupt: 