In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

In [2]:
# Import necessary libraries
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, random_split
import numpy as np
import torch.nn as nn
from torchvision import models
import torch.optim as optim
import random

# Check for CUDA GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

Using device: cuda


In [3]:
# Transformations
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load CIFAR-100 dataset
dataset_train = datasets.CIFAR100(root='../data', train=True, download=True, transform=transform)
dataset_test = datasets.CIFAR100(root='../data', train=False, download=True, transform=transform)

# Define superclasses and subclasses
superclasses = {
    1: [4, 30, 55, 72, 95],  # aquatic mammals
    2: [1, 32, 67, 73, 91],  # fish
    3: [54, 62, 70, 82, 92],  # flowers
    4: [9, 10, 16, 28, 61],  # food containers
    5: [0, 51, 53, 57, 83],  # fruit and vegetables
    6: [22, 39, 40, 86, 87],  # household electrical devices
    7: [5, 20, 25, 84, 94],  # household furniture
    8: [6, 7, 14, 18, 24],  # insects
    9: [3, 42, 43, 88, 97],  # large carnivores
    10: [12, 17, 37, 68, 76],  # large man-made outdoor things
    11: [23, 33, 49, 60, 71],  # large natural outdoor scenes
    12: [15, 19, 21, 31, 38],  # medium-sized mammals
    13: [34, 63, 64, 66, 75],  # non-insect invertebrates
    14: [26, 45, 77, 79, 99],  # people
    15: [2, 11, 35, 46, 98],  # reptiles
    16: [27, 29, 44, 78, 93],  # small mammals
    17: [36, 50, 65, 74, 80],  # trees
    18: [8, 13, 48, 58, 90],  # vehicles 1
    19: [41, 66, 69, 81, 89],  # vehicles 2
    20: [47, 50, 52, 56, 59],  # household furniture
}

Files already downloaded and verified
Files already downloaded and verified


### 1. Preprocessing

In [4]:
# # Function to map subclass to its corresponding superclass
# def get_superclass(subclass, superclasses):
#     for superclass in superclasses.keys():
#         subclasses=superclasses[superclass]
#         if subclass in subclasses:
#             return superclass
#     return None

# # Function to filter dataset by superclasses using dataset.targets directly
# def filter_dataset_by_superclass(dataset, superclasses, selected_superclasses):
#     selected_indices = []
    
#     # Directly access dataset.targets, assuming the dataset is labeled (like CIFAR-100)
#     for idx, target in enumerate(dataset.targets):  # Access only targets, not full data
#         subclass = target
#         superclass = get_superclass(subclass, superclasses)
        
#         if superclass in selected_superclasses:
#             selected_indices.append(idx)

#     return Subset(dataset, selected_indices)

# # Define the superclasses for each client
# client1_superclasses = list(range(1, 11))
# client2_superclasses = list(range(11, 21))

# # Filter the dataset for each client
# client1_dataset = filter_dataset_by_superclass(dataset_train, superclasses, client1_superclasses)
# client2_dataset = filter_dataset_by_superclass(dataset_train, superclasses, client2_superclasses)

# # Print dataset sizes to verify
# print(f'Client 1 dataset size: {len(client1_dataset)}')
# print(f'Client 2 dataset size: {len(client2_dataset)}')

# # Create DataLoaders for each client
# batch_size = 256
# num_workers = 4

# client_loaders = [
#     DataLoader(client1_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers),
#     DataLoader(client2_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
# ]

# # Test DataLoader (common for all clients)
# test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# # Checking the first client loader to see if everything works correctly
# for images, labels in client_loaders[0]:
#     print(f'Client 1 - Batch size: {images.size(0)}')
#     print(f'Images shape: {images.shape}')
#     print(f'Labels shape: {labels.shape}')
#     print(f'Labels: {labels}')
#     break

# # Checking the second client loader to see if everything works correctly
# for images, labels in client_loaders[1]:
#     print(f'Client 2 - Batch size: {images.size(0)}')
#     print(f'Images shape: {images.shape}')
#     print(f'Labels shape: {labels.shape}')
#     print(f'Labels: {labels}')
#     break

# print("Preprocessing for federated learning with specific superclasses completed successfully.")

In [5]:
# Function to randomly split subclasses between two clients, ensuring all superclasses are present in both
def split_subclasses_for_both_clients(superclasses):
    client1_subclasses = []
    client2_subclasses = []
    
    for superclass, subclasses in superclasses.items():
        random.shuffle(subclasses)
        half = len(subclasses) // 2
        # Both clients get some subclasses from each superclass
        client1_subclasses.extend(subclasses[:half])
        client2_subclasses.extend(subclasses[half:])
        
    return client1_subclasses, client2_subclasses

# Function to filter dataset by subclasses
def filter_dataset_by_subclasses(dataset, subclasses):
    selected_indices = []
    for idx, target in enumerate(dataset.targets):
        if target in subclasses:
            selected_indices.append(idx)
    return Subset(dataset, selected_indices)

# Split subclasses between the two clients (ensuring both clients have all superclasses)
client1_subclasses, client2_subclasses = split_subclasses_for_both_clients(superclasses)

# Filter dataset for each client based on the subclasses they received
client1_dataset = filter_dataset_by_subclasses(dataset_train, client1_subclasses)
client2_dataset = filter_dataset_by_subclasses(dataset_train, client2_subclasses)

# Function to create train/validation split using PyTorch's random_split
def split_train_val(dataset, val_size=0.1):
    val_len = int(len(dataset) * val_size)
    train_len = len(dataset) - val_len
    return random_split(dataset, [train_len, val_len])

# Create train/validation split for each client using PyTorch's random_split
client1_train, client1_val = split_train_val(client1_dataset, val_size=0.1)
client2_train, client2_val = split_train_val(client2_dataset, val_size=0.1)

# Print dataset sizes to verify
print(f'Client 1 training set size: {len(client1_train)}')
print(f'Client 1 validation set size: {len(client1_val)}')
print(f'Client 2 training set size: {len(client2_train)}')
print(f'Client 2 validation set size: {len(client2_val)}')

# Set client_data_sizes based on the training set sizes
client_data_sizes = [len(client1_train), len(client2_train)]
print(f"Client data sizes: {client_data_sizes}")  # This should print [17550, 26550]


# Create DataLoaders for each client
batch_size = 128
num_workers = 4

client1_loader = DataLoader(client1_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
client1_val_loader = DataLoader(client1_val, batch_size=batch_size, shuffle=False, num_workers=num_workers)

client2_loader = DataLoader(client2_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
client2_val_loader = DataLoader(client2_val, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Test DataLoader (common for all clients)
test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Create a list of DataLoaders for both clients
client_loaders = [client1_loader, client2_loader]

# Similarly, create a list of validation loaders
val_loaders = [client1_val_loader, client2_val_loader]


# Test DataLoader (common for all clients)
test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

Client 1 training set size: 18000
Client 1 validation set size: 2000
Client 2 training set size: 26100
Client 2 validation set size: 2900
Client data sizes: [18000, 26100]


### 2. Model Preparation

In [6]:
# Function to map subclasses back to their superclasses
def get_superclass_mapping(superclasses):
    subclass_to_superclass = {}
    for superclass, subclasses in superclasses.items():
        for subclass in subclasses:
            subclass_to_superclass[subclass] = superclass
    return subclass_to_superclass

# Check if all superclasses are present in the dataset
def check_superclasses_present(dataset, subclass_to_superclass):
    superclasses_present = set()
    for idx in dataset.indices:
        subclass = dataset.dataset.targets[idx]
        superclass = subclass_to_superclass[subclass]
        superclasses_present.add(superclass)
    return superclasses_present

# Get mapping from subclasses to superclasses
subclass_to_superclass = get_superclass_mapping(superclasses)

# Check for Client 1
client1_superclasses_present = check_superclasses_present(client1_dataset, subclass_to_superclass)
print(f"Client 1 superclasses present: {sorted(client1_superclasses_present)}")

# Check for Client 2
client2_superclasses_present = check_superclasses_present(client2_dataset, subclass_to_superclass)
print(f"Client 2 superclasses present: {sorted(client2_superclasses_present)}")

Client 1 superclasses present: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
Client 2 superclasses present: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]


In [7]:
import torch.nn as nn
from torchvision import models

def prepare_model(num_classes=100, use_dropout=False, dropout_prob=0.2):
    """Load a pre-trained Resnet18 model and modify it for CIFAR100 with optional dropout."""
    model = models.resnet18(pretrained=True)
    num_ftrs = model.fc.in_features
    if use_dropout:
        model.fc = nn.Sequential(
            nn.Dropout(p=dropout_prob),
            nn.Linear(num_ftrs, num_classes)
        )
    else:
        model.fc = nn.Linear(num_ftrs, num_classes)
    return model

### 3. Early Stopping

In [8]:
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

### 4. Federated Training Functions (Federated Averaging (FedAvg))

In [9]:
import torch
import torch.optim as optim

def train_client(model, train_loader, criterion, optimizer, epochs=1):
    model.to(device)  
    model.train()
    for _ in range(epochs):
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)  
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
    return model.state_dict()

# def federated_averaging(state_dicts):
#     avg_state_dict = {}
#     for key in state_dicts[0].keys():
#         avg_state_dict[key] = sum(state_dict[key] for state_dict in state_dicts) / len(state_dicts)
#     return avg_state_dict

In [10]:
# Weighted Federated Averaging function
def weighted_federated_averaging(state_dicts, client_data_sizes):
    total_data_points = sum(client_data_sizes)  # Total number of data points across all clients
    print(f"Total data points: {total_data_points}")
    
    avg_state_dict = {}

    # Initialize with zeros
    for key in state_dicts[0].keys():
        avg_state_dict[key] = torch.zeros_like(state_dicts[0][key])

    # Weighted sum of client updates
    for i, state_dict in enumerate(state_dicts):
        client_weight = client_data_sizes[i] / total_data_points
        for key in state_dict.keys():
            # Only perform weighted sum in float for float-like tensors
            if avg_state_dict[key].dtype in [torch.float32, torch.float64]:
                avg_state_dict[key] += state_dict[key].float() * client_weight
            else:  # For integer types (e.g., LongTensors), just add them directly
                avg_state_dict[key] += state_dict[key]

    return avg_state_dict

In [11]:

# def train_federated_model(client_loaders, val_loaders, test_loader, num_clients, num_epochs, learning_rate=0.001, patience=5, min_delta=0):
#     model = prepare_model().to(device)
#     criterion = nn.CrossEntropyLoss()
#     early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
    
#     for round in range(num_epochs):
#         print(f"Starting federated learning round {round+1}/{num_epochs}...")
#         state_dicts = []
#         for i, client_loader in enumerate(client_loaders):
#             print(f"Training model for client {i+1}...")
#             client_model = prepare_model().to(device)
#             client_model.load_state_dict(model.state_dict())
#             optimizer = optim.Adam(client_model.parameters(), lr=learning_rate)
#             client_state_dict = train_client(client_model, client_loader, criterion, optimizer)
#             state_dicts.append(client_state_dict)

#         avg_state_dict = federated_averaging(state_dicts)
#         model.load_state_dict(avg_state_dict)
#         model.to(device)
        
#         # Validation phase
#         print("Validating model...")
#         model.eval()
#         val_running_loss = 0.0
#         val_correct = 0
#         val_total = 0
#         with torch.no_grad():
#             for val_loader in val_loaders:
#                 for inputs, labels in val_loader:
#                     inputs, labels = inputs.to(device), labels.to(device)
#                     outputs = model(inputs)
#                     loss = criterion(outputs, labels)
#                     val_running_loss += loss.item()
#                     _, predicted = torch.max(outputs.data, 1)
#                     val_total += labels.size(0)
#                     val_correct += (predicted == labels).sum().item()

#         val_loss = val_running_loss / sum(len(loader) for loader in val_loaders)
#         val_accuracy = 100 * val_correct / val_total

#         print(f'Federated Round {round+1}/{num_epochs}, '
#               f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

#         # Early stopping
#         early_stopping(val_loss)
#         if early_stopping.early_stop:
#             print("Early stopping")
#             break

In [12]:
#client_data_sizes = [len(client1_train), len(client2_train)]


# Modify train_federated_model to accept client_data_sizes and use weighted averaging
def train_federated_model(client_loaders, val_loaders, test_loader, num_clients, num_epochs, client_data_sizes, learning_rate=0.0001, patience=5, min_delta=0):
    
    # Debugging print to ensure the function is called and client_data_sizes is passed
    print(f'Train Federated Model function called with client_data_sizes: {client_data_sizes}')
    
    model = prepare_model().to(device)
    criterion = torch.nn.CrossEntropyLoss()
    early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
    
    for round in range(num_epochs):
        print(f"Starting federated learning round {round+1}/{num_epochs}...")
        state_dicts = []
        for i, client_loader in enumerate(client_loaders):
            print(f"Training model for client {i+1}...")
            client_model = prepare_model().to(device)
            client_model.load_state_dict(model.state_dict())
            optimizer = optim.Adam(client_model.parameters(), lr=learning_rate)
            client_state_dict = train_client(client_model, client_loader, criterion, optimizer)
            state_dicts.append(client_state_dict)
            
        # Check client_data_sizes before passing it to weighted_federated_averaging
        print(f'Client data sizes: {client_data_sizes}')

        # Use the new weighted federated averaging
        avg_state_dict = weighted_federated_averaging(state_dicts, client_data_sizes)
        model.load_state_dict(avg_state_dict)
        model.to(device)
        
        # Validation phase
        print("Validating model...")
        model.eval()
        val_running_loss = 0.0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for val_loader in val_loaders:
                for inputs, labels in val_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    val_running_loss += loss.item()
                    _, predicted = torch.max(outputs.data, 1)
                    val_total += labels.size(0)
                    val_correct += (predicted == labels).sum().item()

        val_loss = val_running_loss / sum(len(loader) for loader in val_loaders)
        val_accuracy = 100 * val_correct / val_total

        print(f'Federated Round {round+1}/{num_epochs}, '
              f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

        # Early stopping
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    # Test phase
    print("Testing model...")
    model.eval()
    test_running_loss = 0.0
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            test_running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()

    test_loss = test_running_loss / len(test_loader)
    test_accuracy = 100 * test_correct / test_total

    print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

    return model, val_loss, val_accuracy, test_loss, test_accuracy

### 4. Federated Training Functions (Adaptive Federated Optimization)

In [13]:
# import torch
# import torch.optim as optim

# def train_client(model, train_loader, criterion, optimizer, epochs=1):
#     model.to(device)  
#     model.train()
#     for _ in range(epochs):
#         for inputs, labels in train_loader:
#             inputs, labels = inputs.to(device), labels.to(device)  
#             optimizer.zero_grad()
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
#     return model.state_dict()

# def federated_averaging(state_dicts, global_model, beta=0.9):
#     avg_state_dict = global_model.state_dict()
#     for key in avg_state_dict.keys():
#         if avg_state_dict[key].dtype == torch.long:
#             avg_state_dict[key] = torch.zeros_like(avg_state_dict[key], dtype=torch.float32)
#             for state_dict in state_dicts:
#                 avg_state_dict[key] += state_dict[key].float() / len(state_dicts)
#             avg_state_dict[key] = avg_state_dict[key].long()  # Convert back to long if necessary
#         else:
#             avg_state_dict[key] = torch.zeros_like(avg_state_dict[key])
#             for state_dict in state_dicts:
#                 avg_state_dict[key] += state_dict[key] / len(state_dicts)
#             avg_state_dict[key] = beta * avg_state_dict[key] + (1 - beta) * global_model.state_dict()[key]
#     return avg_state_dict

# def train_federated_model(client_loaders, val_loader, test_loader, num_clients, num_epochs, learning_rate=0.001, patience=5, min_delta=0):
#     model = prepare_model().to(device)
#     criterion = nn.CrossEntropyLoss()
#     early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)

#     for round in range(num_epochs):
#         print(f"Starting federated learning round {round+1}/{num_epochs}...")
#         state_dicts = []
#         for i, client_loader in enumerate(client_loaders):
#             print(f"Training model for client {i+1}...")
#             client_model = prepare_model().to(device)
#             client_model.load_state_dict(model.state_dict())
#             optimizer = optim.Adam(client_model.parameters(), lr=learning_rate)
#             client_state_dict = train_client(client_model, client_loader, criterion, optimizer)
#             state_dicts.append(client_state_dict)

#         avg_state_dict = federated_averaging(state_dicts, model)  # Call to modified function
#         model.load_state_dict(avg_state_dict)
#         model.to(device)

#         # Validation phase
#         print("Validating model...")
#         model.eval()
#         val_running_loss = 0.0
#         val_correct = 0
#         val_total = 0
#         with torch.no_grad():
#             for inputs, labels in val_loader:
#                 inputs, labels = inputs.to(device), labels.to(device)
#                 outputs = model(inputs)
#                 loss = criterion(outputs, labels)
#                 val_running_loss += loss.item()
#                 _, predicted = torch.max(outputs.data, 1)
#                 val_total += labels.size(0)
#                 val_correct += (predicted == labels).sum().item()

#         val_loss = val_running_loss / len(val_loader)
#         val_accuracy = 100 * val_correct / val_total

#         print(f'Federated Round {round+1}/{num_epochs}, '
#               f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

#         # Early stopping
#         early_stopping(val_loss)
#         if early_stopping.early_stop:
#             print("Early stopping")
#             break

#     # Test phase
#     print("Testing model...")
#     model.eval()
#     test_running_loss = 0.0
#     test_correct = 0
#     test_total = 0
#     with torch.no_grad():
#         for inputs, labels in test_loader:
#             inputs, labels = inputs.to(device), labels.to(device)
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
#             test_running_loss += loss.item()
#             _, predicted = torch.max(outputs.data, 1)
#             test_total += labels.size(0)
#             test_correct += (predicted == labels).sum().item()

#     test_loss = test_running_loss / len(test_loader)
#     test_accuracy = 100 * test_correct / test_total

#     print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

#     return model, val_loss, val_accuracy, test_loss, test_accuracy

### 5. Define the Logging Function

In [14]:
import csv
import os


# Parameters
num_clients = 2
num_epochs = 40
learning_rate = 0.0001
patience = 5
min_delta = 0.01
batch_size = 128
num_workers = 4
log_file = '/raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/experiment_results.csv'


def log_experiment_result(filename, num_clients, num_epochs, learning_rate, patience, min_delta, val_loss, val_accuracy, test_loss, test_accuracy):
    file_exists = os.path.isfile(filename)
    
    with open(filename, 'a', newline='') as csvfile:
        fieldnames = ['num_clients', 'num_epochs', 'learning_rate', 'patience', 'min_delta', 'val_loss', 'val_accuracy', 'test_loss', 'test_accuracy']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        
        if not file_exists:
            writer.writeheader()  # Write header only once
        
        writer.writerow({
            'num_clients': num_clients,
            'num_epochs': num_epochs,
            'learning_rate': learning_rate,
            'patience': patience,
            'min_delta': min_delta,
            'val_loss': val_loss,
            'val_accuracy': val_accuracy,
            'test_loss': test_loss,
            'test_accuracy': test_accuracy
        })

In [15]:
# Correctly define client_data_sizes based on the size of each client's dataset
client_data_sizes = [len(client1_train), len(client2_train)]  # Size of the training sets for both clients|

print(client_data_sizes)

[18000, 26100]


### 6. Federated Learning Experiment Configuration

In [16]:
# Federated training
print("Starting federated training...")

# Call train_federated_model with correct parameters
model, val_loss, val_accuracy, test_loss, test_accuracy = train_federated_model(
    client_loaders, val_loaders, test_loader, num_clients, num_epochs, client_data_sizes, learning_rate, patience, min_delta
)

print("Federated training complete.")

# Log the results
log_experiment_result(
    log_file, num_clients, num_epochs, learning_rate, patience, min_delta, val_loss, val_accuracy, test_loss, test_accuracy
)

print("Experiment results logged.")

Starting federated training...
Train Federated Model function called with client_data_sizes: [18000, 26100]




Starting federated learning round 1/40...
Training model for client 1...
Training model for client 2...
Client data sizes: [18000, 26100]
Total data points: 44100
Validating model...
Federated Round 1/40, Val Loss: 3.3013, Val Accuracy: 31.16%
Starting federated learning round 2/40...
Training model for client 1...
Training model for client 2...
Client data sizes: [18000, 26100]
Total data points: 44100
Validating model...
Federated Round 2/40, Val Loss: 2.7653, Val Accuracy: 36.14%
Starting federated learning round 3/40...
Training model for client 1...
Training model for client 2...
Client data sizes: [18000, 26100]
Total data points: 44100
Validating model...
Federated Round 3/40, Val Loss: 2.5211, Val Accuracy: 39.41%
Starting federated learning round 4/40...
Training model for client 1...
Training model for client 2...
Client data sizes: [18000, 26100]
Total data points: 44100
Validating model...
Federated Round 4/40, Val Loss: 2.4000, Val Accuracy: 41.80%
Starting federated learn