In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
# Import necessary libraries
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import torch.nn as nn
from torchvision import models
import torch.optim as optim

# Check for CUDA GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

Using device: cuda


In [3]:
import numpy as np
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms

# Transformations
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load CIFAR-100 dataset
dataset_train = datasets.CIFAR100(root='../data', train=True, download=True, transform=transform)
dataset_test = datasets.CIFAR100(root='../data', train=False, download=True, transform=transform)

# Define superclasses and subclasses
superclasses = {
    1: [4, 30, 55, 72, 95],  # aquatic mammals
    2: [1, 32, 67, 73, 91],  # fish
    3: [54, 62, 70, 82, 92],  # flowers
    4: [9, 10, 16, 28, 61],  # food containers
    5: [0, 51, 53, 57, 83],  # fruit and vegetables
    6: [22, 39, 40, 86, 87],  # household electrical devices
    7: [5, 20, 25, 84, 94],  # household furniture
    8: [6, 7, 14, 18, 24],  # insects
    9: [3, 42, 43, 88, 97],  # large carnivores
    10: [12, 17, 37, 68, 76],  # large man-made outdoor things
    11: [23, 33, 49, 60, 71],  # large natural outdoor scenes
    12: [15, 19, 21, 31, 38],  # medium-sized mammals
    13: [34, 63, 64, 66, 75],  # non-insect invertebrates
    14: [26, 45, 77, 79, 99],  # people
    15: [2, 11, 35, 46, 98],  # reptiles
    16: [27, 29, 44, 78, 93],  # small mammals
    17: [36, 50, 65, 74, 80],  # trees
    18: [8, 13, 48, 58, 90],  # vehicles 1
    19: [41, 66, 69, 81, 89],  # vehicles 2
    20: [47, 50, 52, 56, 59],  # household furniture
}

# Function to map subclass to its corresponding superclass
def get_superclass(subclass, superclasses):
    for superclass, subclasses in superclasses.items():
        if subclass in subclasses:
            return superclass
    return None

# Function to filter dataset by superclasses
def filter_dataset_by_superclass(dataset, superclasses, selected_superclasses):
    selected_indices = []
    for idx, (data, target) in enumerate(dataset):
        subclass = target
        superclass = get_superclass(subclass, superclasses)
        if superclass in selected_superclasses:
            selected_indices.append(idx)
    return Subset(dataset, selected_indices)

# # Define the superclasses for each client
# client1_superclasses = list(range(1, 13))  # Superclasses 1-12 for Client 1
# client2_superclasses = list(range(8, 21))  # Superclasses 9-20 for Client 2


# Superclass ranges for each client
client1_superclasses = list(range(1, 10))   # Superclasses 1 to 9
client2_superclasses = list(range(7, 16))   # Superclasses 7 to 15
client3_superclasses = list(range(13, 21))  # Superclasses 13 to 20

# Filter the dataset for each client
client1_dataset = filter_dataset_by_superclass(dataset_train, superclasses, client1_superclasses)
client2_dataset = filter_dataset_by_superclass(dataset_train, superclasses, client2_superclasses)
client3_dataset = filter_dataset_by_superclass(dataset_train, superclasses, client3_superclasses)

# Print dataset sizes to verify
print(f'Client 1 dataset size: {len(client1_dataset)}')
print(f'Client 2 dataset size: {len(client2_dataset)}')
print(f'Client 3 dataset size: {len(client3_dataset)}')

# Create DataLoaders for each client and the validation set
batch_size = 256
num_workers = 4


# Split each client dataset into training and validation sets
val_split = 0.1  # 10% for validation

# For client1
val_size_client1 = int(val_split * len(client1_dataset))
train_size_client1 = len(client1_dataset) - val_size_client1
train_dataset_client1, val_dataset_client1 = random_split(client1_dataset, [train_size_client1, val_size_client1])

# For client2
val_size_client2 = int(val_split * len(client2_dataset))
train_size_client2 = len(client2_dataset) - val_size_client2
train_dataset_client2, val_dataset_client2 = random_split(client2_dataset, [train_size_client2, val_size_client2])

# For client3
val_size_client3 = int(val_split * len(client3_dataset))
train_size_client3 = len(client3_dataset) - val_size_client3
train_dataset_client3, val_dataset_client3 = random_split(client3_dataset, [train_size_client3, val_size_client3])

# Create DataLoaders for training and validation
client_loaders = [
    DataLoader(train_dataset_client1, batch_size=batch_size, shuffle=True, num_workers=num_workers),
    DataLoader(train_dataset_client2, batch_size=batch_size, shuffle=True, num_workers=num_workers),
    DataLoader(train_dataset_client3, batch_size=batch_size, shuffle=True, num_workers=num_workers),   
]

val_loaders = [
    DataLoader(val_dataset_client1, batch_size=batch_size, shuffle=False, num_workers=num_workers),
    DataLoader(val_dataset_client2, batch_size=batch_size, shuffle=False, num_workers=num_workers),
    DataLoader(val_dataset_client3, batch_size=batch_size, shuffle=False, num_workers=num_workers)
]

# Test DataLoader (common for all clients)
test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Checking the first client loader to see if everything works correctly
for images, labels in client_loaders[0]:
    print(f'Client 1 - Batch size: {images.size(0)}')
    print(f'Images shape: {images.shape}')
    print(f'Labels shape: {labels.shape}')
    print(f'Labels: {labels}')
    break

# Checking the second client loader to see if everything works correctly
for images, labels in client_loaders[1]:
    print(f'Client 2 - Batch size: {images.size(0)}')
    print(f'Images shape: {images.shape}')
    print(f'Labels shape: {labels.shape}')
    print(f'Labels: {labels}')
    break

# Checking the third client loader to see if everything works correctly
for images, labels in client_loaders[2]:
    print(f'Client 2 - Batch size: {images.size(0)}')
    print(f'Images shape: {images.shape}')
    print(f'Labels shape: {labels.shape}')
    print(f'Labels: {labels}')
    break

# Checking the first validation loader
for images, labels in val_loaders[0]:
    print(f'Client 1 Validation - Batch size: {images.size(0)}')
    print(f'Images shape: {images.shape}')
    print(f'Labels shape: {labels.shape}')
    print(f'Labels: {labels}')
    break

# Checking the second validation loader
for images, labels in val_loaders[1]:
    print(f'Client 2 Validation - Batch size: {images.size(0)}')
    print(f'Images shape: {images.shape}')
    print(f'Labels shape: {labels.shape}')
    print(f'Labels: {labels}')
    break

# Checking the third validation loader
for images, labels in val_loaders[2]:
    print(f'Client 3 Validation - Batch size: {images.size(0)}')
    print(f'Images shape: {images.shape}')
    print(f'Labels shape: {labels.shape}')
    print(f'Labels: {labels}')
    break

print("Preprocessing for federated learning with specific superclasses and validation sets completed successfully.")

Files already downloaded and verified
Files already downloaded and verified
Client 1 dataset size: 22500
Client 2 dataset size: 22500
Client 3 dataset size: 19000
Client 1 - Batch size: 256
Images shape: torch.Size([256, 3, 224, 224])
Labels shape: torch.Size([256])
Labels: tensor([73, 82,  1,  5, 40, 28,  9, 86, 28, 73, 20, 24, 97, 51, 91, 16, 22, 30,
        40, 55,  5, 18, 70, 42, 83, 57, 25, 18, 54,  9,  7,  4, 94, 73, 94, 82,
         4,  1, 83, 70, 62,  1, 54, 83,  6, 67, 25,  5, 67, 97, 86, 53, 16,  6,
        91, 54,  4, 53, 28, 83,  7, 67,  9, 25, 30, 28, 24, 92,  9, 97, 94, 55,
         1, 30, 18, 16, 95, 70, 70,  5, 42, 20, 32, 86, 83, 82, 67, 82, 43, 28,
        32,  9, 51,  4, 10,  0, 43,  3,  0, 30, 67, 67, 43,  5, 14,  1,  0, 91,
        22, 32, 84,  9, 24,  7, 54, 95, 32, 32, 88,  9, 57, 32, 18, 91,  4, 73,
        16,  4, 28, 54, 16, 24, 92, 32, 53,  3,  5, 73, 25, 22, 67, 82, 40, 73,
        72,  1, 86, 57,  4, 24, 20, 25, 73, 61,  3, 92, 73, 39, 82, 61,  5,  4,
     

In [4]:
# Function to count occurrences of each superclass in a dataset
def count_superclasses(dataset, superclasses):
    superclass_counts = {k: 0 for k in superclasses.keys()}
    
    for _, target in dataset:
        subclass = target
        superclass = get_superclass(subclass, superclasses)
        if superclass is not None:
            superclass_counts[superclass] += 1
            
    return superclass_counts

# After filtering the dataset for each client
client1_superclass_counts = count_superclasses(client1_dataset, superclasses)
client2_superclass_counts = count_superclasses(client2_dataset, superclasses)
client3_superclass_counts = count_superclasses(client3_dataset, superclasses)

# Print the distribution of superclasses for each client
print("Client 1 superclass distribution:")
for superclass, count in client1_superclass_counts.items():
    print(f"Superclass {superclass}: {count} images")

print("\nClient 2 superclass distribution:")
for superclass, count in client2_superclass_counts.items():
    print(f"Superclass {superclass}: {count} images")
    
print("\nClient 3 superclass distribution:")
for superclass, count in client3_superclass_counts.items():
    print(f"Superclass {superclass}: {count} images")
    
    
# Count and print the distribution of superclasses for a given dataset
def print_superclass_distribution(dataset_name, dataset, superclasses):
    counts = count_superclasses(dataset, superclasses)
    print(f"{dataset_name} superclass distribution:")
    for superclass, count in counts.items():
        print(f"Superclass {superclass}: {count} images")
    print()

# After filtering the dataset for each client and validation sets
print_superclass_distribution("Client 1 Training", client1_dataset, superclasses)
print_superclass_distribution("Client 1 Validation", val_dataset_client1, superclasses)
print_superclass_distribution("Client 2 Training", client2_dataset, superclasses)
print_superclass_distribution("Client 2 Validation", val_dataset_client2, superclasses)
print_superclass_distribution("Client 3 Training", client3_dataset, superclasses)
print_superclass_distribution("Client 3 Validation", val_dataset_client3, superclasses)

Client 1 superclass distribution:
Superclass 1: 2500 images
Superclass 2: 2500 images
Superclass 3: 2500 images
Superclass 4: 2500 images
Superclass 5: 2500 images
Superclass 6: 2500 images
Superclass 7: 2500 images
Superclass 8: 2500 images
Superclass 9: 2500 images
Superclass 10: 0 images
Superclass 11: 0 images
Superclass 12: 0 images
Superclass 13: 0 images
Superclass 14: 0 images
Superclass 15: 0 images
Superclass 16: 0 images
Superclass 17: 0 images
Superclass 18: 0 images
Superclass 19: 0 images
Superclass 20: 0 images

Client 2 superclass distribution:
Superclass 1: 0 images
Superclass 2: 0 images
Superclass 3: 0 images
Superclass 4: 0 images
Superclass 5: 0 images
Superclass 6: 0 images
Superclass 7: 2500 images
Superclass 8: 2500 images
Superclass 9: 2500 images
Superclass 10: 2500 images
Superclass 11: 2500 images
Superclass 12: 2500 images
Superclass 13: 2500 images
Superclass 14: 2500 images
Superclass 15: 2500 images
Superclass 16: 0 images
Superclass 17: 0 images
Supercl

### 2. Model Preparation

In [5]:
import torch.nn as nn
from torchvision import models

def prepare_model(num_classes=100, use_dropout=False, dropout_prob=0.2):
    """Load a pre-trained Resnet18 model and modify it for CIFAR100 with optional dropout."""
    model = models.resnet18(pretrained=True)
    num_ftrs = model.fc.in_features
    if use_dropout:
        model.fc = nn.Sequential(
            nn.Dropout(p=dropout_prob),
            nn.Linear(num_ftrs, num_classes)
        )
    else:
        model.fc = nn.Linear(num_ftrs, num_classes)
    return model

### 3. Early Stopping

In [6]:
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

### 4. Federated Training Functions (Federated Averaging (FedAvg))

In [7]:
import torch
import torch.optim as optim

def train_client(model, train_loader, criterion, optimizer, epochs=1):
    model.to(device)  
    model.train()
    for _ in range(epochs):
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)  
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
    return model.state_dict()

def federated_averaging(state_dicts):
    avg_state_dict = {}
    for key in state_dicts[0].keys():
        avg_state_dict[key] = sum(state_dict[key] for state_dict in state_dicts) / len(state_dicts)
    return avg_state_dict

def train_federated_model(client_loaders, val_loaders, test_loader, num_clients, num_epochs, learning_rate=0.001, patience=5, min_delta=0):
    model = prepare_model().to(device)
    criterion = nn.CrossEntropyLoss()
    early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
    
    for round in range(num_epochs):
        print(f"Starting federated learning round {round+1}/{num_epochs}...")
        state_dicts = []
        for i, client_loader in enumerate(client_loaders):
            print(f"Training model for client {i+1}...")
            client_model = prepare_model().to(device)
            client_model.load_state_dict(model.state_dict())
            optimizer = optim.Adam(client_model.parameters(), lr=learning_rate)
            client_state_dict = train_client(client_model, client_loader, criterion, optimizer)
            state_dicts.append(client_state_dict)

        avg_state_dict = federated_averaging(state_dicts)
        model.load_state_dict(avg_state_dict)
        model.to(device)
        
        # Validation phase (handling multiple validation loaders)
        print("Validating model...")
        model.eval()
        val_running_loss = 0.0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for val_loader in val_loaders:
                for inputs, labels in val_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    val_running_loss += loss.item()
                    _, predicted = torch.max(outputs.data, 1)
                    val_total += labels.size(0)
                    val_correct += (predicted == labels).sum().item()

        val_loss = val_running_loss / sum(len(loader) for loader in val_loaders)
        val_accuracy = 100 * val_correct / val_total

        print(f'Federated Round {round+1}/{num_epochs}, '
              f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

        # Early stopping
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    # Test phase
    print("Testing model...")
    model.eval()
    test_running_loss = 0.0
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            test_running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()

    test_loss = test_running_loss / len(test_loader)
    test_accuracy = 100 * test_correct / test_total

    print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

    return model, val_loss, val_accuracy, test_loss, test_accuracy

### 4. Federated Training Functions (Adaptive Federated Optimization)

In [8]:
# import torch
# import torch.optim as optim

# def train_client(model, train_loader, criterion, optimizer, epochs=1):
#     model.to(device)  
#     model.train()
#     for _ in range(epochs):
#         for inputs, labels in train_loader:
#             inputs, labels = inputs.to(device), labels.to(device)  
#             optimizer.zero_grad()
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
#     return model.state_dict()

# def federated_averaging(state_dicts, global_model, beta=0.9):
#     avg_state_dict = global_model.state_dict()
#     for key in avg_state_dict.keys():
#         if avg_state_dict[key].dtype == torch.long:
#             avg_state_dict[key] = torch.zeros_like(avg_state_dict[key], dtype=torch.float32)
#             for state_dict in state_dicts:
#                 avg_state_dict[key] += state_dict[key].float() / len(state_dicts)
#             avg_state_dict[key] = avg_state_dict[key].long()  # Convert back to long if necessary
#         else:
#             avg_state_dict[key] = torch.zeros_like(avg_state_dict[key])
#             for state_dict in state_dicts:
#                 avg_state_dict[key] += state_dict[key] / len(state_dicts)
#             avg_state_dict[key] = beta * avg_state_dict[key] + (1 - beta) * global_model.state_dict()[key]
#     return avg_state_dict

# def train_federated_model(client_loaders, val_loader, test_loader, num_clients, num_epochs, learning_rate=0.001, patience=5, min_delta=0):
#     model = prepare_model().to(device)
#     criterion = nn.CrossEntropyLoss()
#     early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)

#     for round in range(num_epochs):
#         print(f"Starting federated learning round {round+1}/{num_epochs}...")
#         state_dicts = []
#         for i, client_loader in enumerate(client_loaders):
#             print(f"Training model for client {i+1}...")
#             client_model = prepare_model().to(device)
#             client_model.load_state_dict(model.state_dict())
#             optimizer = optim.Adam(client_model.parameters(), lr=learning_rate)
#             client_state_dict = train_client(client_model, client_loader, criterion, optimizer)
#             state_dicts.append(client_state_dict)

#         avg_state_dict = federated_averaging(state_dicts, model)  # Call to modified function
#         model.load_state_dict(avg_state_dict)
#         model.to(device)

#         # Validation phase
#         print("Validating model...")
#         model.eval()
#         val_running_loss = 0.0
#         val_correct = 0
#         val_total = 0
#         with torch.no_grad():
#             for inputs, labels in val_loader:
#                 inputs, labels = inputs.to(device), labels.to(device)
#                 outputs = model(inputs)
#                 loss = criterion(outputs, labels)
#                 val_running_loss += loss.item()
#                 _, predicted = torch.max(outputs.data, 1)
#                 val_total += labels.size(0)
#                 val_correct += (predicted == labels).sum().item()

#         val_loss = val_running_loss / len(val_loader)
#         val_accuracy = 100 * val_correct / val_total

#         print(f'Federated Round {round+1}/{num_epochs}, '
#               f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

#         # Early stopping
#         early_stopping(val_loss)
#         if early_stopping.early_stop:
#             print("Early stopping")
#             break

#     # Test phase
#     print("Testing model...")
#     model.eval()
#     test_running_loss = 0.0
#     test_correct = 0
#     test_total = 0
#     with torch.no_grad():
#         for inputs, labels in test_loader:
#             inputs, labels = inputs.to(device), labels.to(device)
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
#             test_running_loss += loss.item()
#             _, predicted = torch.max(outputs.data, 1)
#             test_total += labels.size(0)
#             test_correct += (predicted == labels).sum().item()

#     test_loss = test_running_loss / len(test_loader)
#     test_accuracy = 100 * test_correct / test_total

#     print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

#     return model, val_loss, val_accuracy, test_loss, test_accuracy

### 5. Define the Logging Function

In [9]:
import csv
import os

def log_experiment_result(filename, num_clients, num_epochs, learning_rate, patience, min_delta, val_loss, val_accuracy, test_loss, test_accuracy):
    file_exists = os.path.isfile(filename)
    
    with open(filename, 'a', newline='') as csvfile:
        fieldnames = ['num_clients', 'num_epochs', 'learning_rate', 'patience', 'min_delta', 'val_loss', 'val_accuracy', 'test_loss', 'test_accuracy']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        
        if not file_exists:
            writer.writeheader()  # Write header only once
        
        writer.writerow({
            'num_clients': num_clients,
            'num_epochs': num_epochs,
            'learning_rate': learning_rate,
            'patience': patience,
            'min_delta': min_delta,
            'val_loss': val_loss,
            'val_accuracy': val_accuracy,
            'test_loss': test_loss,
            'test_accuracy': test_accuracy
        })

In [10]:
print("Validation Loaders: ", val_loaders)


Validation Loaders:  [<torch.utils.data.dataloader.DataLoader object at 0x7f5b744e4c40>, <torch.utils.data.dataloader.DataLoader object at 0x7f5b744e4b20>, <torch.utils.data.dataloader.DataLoader object at 0x7f5b744dcd90>]


### 6. Federated Learning Experiment Configuration

In [11]:
# Parameters
num_clients = 3
num_epochs = 40
learning_rate = 0.001
patience = 5
min_delta = 0.01
batch_size = 256
num_workers = 4
log_file = '/raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/experiment_results.csv'


# Initialize Data Loaders
#client_loaders, val_loader = get_data_loaders(batch_size, num_workers, num_clients)
#test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

print("Starting federated training...")
model, val_loss, val_accuracy, test_loss, test_accuracy = train_federated_model(client_loaders, val_loaders, test_loader, num_clients, num_epochs, learning_rate, patience, min_delta)
print("Federated training complete.")

# Log the results
log_experiment_result(log_file, num_clients, num_epochs, learning_rate, patience, min_delta, val_loss, val_accuracy, test_loss, test_accuracy)
print("Experiment results logged.")

Starting federated training...




Starting federated learning round 1/40...
Training model for client 1...
Training model for client 2...
Training model for client 3...
Validating model...
Federated Round 1/40, Val Loss: 3.7057, Val Accuracy: 20.50%
Starting federated learning round 2/40...
Training model for client 1...
Training model for client 2...
Training model for client 3...
Validating model...
Federated Round 2/40, Val Loss: 3.4868, Val Accuracy: 26.89%
Starting federated learning round 3/40...
Training model for client 1...
Training model for client 2...
Training model for client 3...
Validating model...
Federated Round 3/40, Val Loss: 3.5577, Val Accuracy: 28.75%
Starting federated learning round 4/40...
Training model for client 1...
Training model for client 2...
Training model for client 3...
Validating model...
Federated Round 4/40, Val Loss: 3.3289, Val Accuracy: 31.38%
Starting federated learning round 5/40...
Training model for client 1...
Training model for client 2...
Training model for client 3...
V