In [1]:
import sys
sys.path.append('../data/cifar100/')
import torch
import torch.nn as nn
import torch.optim as optim
from copy import deepcopy
import random
from torch.utils.data import Subset
from statistics import mean
#from cifar100_loader import load_cifar100
from models.model import LeNet5 #import the model
from Server import Server
import numpy as np
from torchvision import datasets, transforms

### Constants for FL training

In [2]:
# Constants for FL training
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)

# Seed every RNG imported by this notebook so that a Restart & Run All
# reproduces the same model initialisation. Whether client sampling inside
# `Server` draws from these generators is not visible here -- TODO confirm.
# 42 matches the `random_state` used for the train/validation split.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

NUM_CLIENTS = 100  # Total number of clients in the federation
FRACTION_CLIENTS = 0.1  # Fraction of clients selected per round (C)
LOCAL_STEPS = 4  # Number of local steps (J)
GLOBAL_ROUNDS = 2000  # Total number of communication rounds

BATCH_SIZE = 64  # Batch size for local training
LR = 1e-2  # Initial learning rate for local optimizers
MOMENTUM = 0.9  # Momentum for SGD optimizer
WEIGHT_DECAY = 1e-4  # Regularization term for local training

LOG_FREQUENCY = 10  # Frequency of logging training progress

cpu


# Loaders

In [3]:
import torch
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split

class CIFAR100DataLoader:
    """Build train / validation / test ``DataLoader`` objects for CIFAR-100.

    The official training split is divided into a training part and a
    validation part with a stratified, fixed-seed split. The training part
    uses the augmenting transform; the validation part and the official test
    split use the deterministic evaluation transform.

    Args:
        batch_size: Batch size used by all three loaders.
        validation_split: Fraction of the training split held out for
            validation. Must be strictly between 0 and 1.
        download: Forwarded to ``datasets.CIFAR100``.
        num_workers: Forwarded to every ``DataLoader``.
        pin_memory: Forwarded to every ``DataLoader``.
        root: Directory where the dataset is stored / downloaded.

    Attributes:
        train_loader, val_loader, test_loader: the three prepared loaders.

    Raises:
        ValueError: If ``validation_split`` is not in the open interval (0, 1).
    """

    # Normalisation statistics shared by the training and evaluation
    # pipelines (presumably the CIFAR-100 per-channel mean / std -- confirm).
    NORM_MEAN = [0.5071, 0.4867, 0.4408]
    NORM_STD = [0.2675, 0.2565, 0.2761]
    # Side length of the square crop fed to the model.
    CROP_SIZE = 24

    def __init__(self, batch_size=32, validation_split=0.1, download=True, num_workers=4, pin_memory=True,
                 root='./data'):
        # Fail fast: without this check an invalid fraction is only rejected by
        # train_test_split, after the dataset has been downloaded and loaded.
        if not 0 < validation_split < 1:
            raise ValueError(
                "validation_split must be strictly between 0 and 1, got {!r}".format(validation_split)
            )

        self.batch_size = batch_size
        self.validation_split = validation_split
        self.download = download
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.root = root

        # Training pipeline: random crop + flip for augmentation.
        self.train_transform = transforms.Compose([
            transforms.RandomCrop(self.CROP_SIZE, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.NORM_MEAN, std=self.NORM_STD)
        ])

        # Evaluation pipeline: deterministic centre crop, same normalisation.
        self.test_transform = transforms.Compose([
            transforms.CenterCrop(self.CROP_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.NORM_MEAN, std=self.NORM_STD)
        ])

        # Load datasets
        self.train_loader, self.val_loader, self.test_loader = self._prepare_loaders()

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using this object's loader settings."""
        return DataLoader(
            dataset, batch_size=self.batch_size, shuffle=shuffle,
            num_workers=self.num_workers, pin_memory=self.pin_memory
        )

    def _prepare_loaders(self):
        """Create the datasets, split train/validation and return the three loaders.

        Returns:
            Tuple ``(train_loader, val_loader, test_loader)``.
        """
        # Training split with augmentation applied.
        full_trainset = datasets.CIFAR100(
            root=self.root, train=True, download=self.download, transform=self.train_transform
        )

        # Stratified split of the sample indices so both parts keep the class
        # proportions; fixed random_state keeps the split identical across runs.
        indexes = list(range(len(full_trainset)))
        train_indexes, val_indexes = train_test_split(
            indexes,
            train_size=1 - self.validation_split,
            test_size=self.validation_split,
            random_state=42,
            stratify=full_trainset.targets,
            shuffle=True
        )

        train_loader = self._make_loader(Subset(full_trainset, train_indexes), shuffle=True)

        # The validation samples come from the same training split, but are read
        # through a second dataset object so they get the evaluation transform
        # instead of the augmenting one.
        full_trainset_val = datasets.CIFAR100(
            root=self.root, train=True, download=self.download, transform=self.test_transform
        )
        val_loader = self._make_loader(Subset(full_trainset_val, val_indexes), shuffle=False)

        # Official test split.
        testset = datasets.CIFAR100(
            root=self.root, train=False, download=self.download, transform=self.test_transform
        )
        test_loader = self._make_loader(testset, shuffle=False)

        return train_loader, val_loader, test_loader

    def __iter__(self):
        """Allows iteration over all loaders for unified access (train, val, test order)."""
        return iter([self.train_loader, self.val_loader, self.test_loader])

In [4]:
# 10% of the dataset kept for validation.
# CIFAR100DataLoader is the class defined in the cell above. It is deliberately
# not re-imported from `cifar100_loader` here: that import would silently
# replace the definition shown in this notebook with whatever the module holds.
data_loader = CIFAR100DataLoader(batch_size=32, validation_split=0.1, download=True, num_workers=4, pin_memory=True)
trainloader, validloader, testloader = data_loader.train_loader, data_loader.val_loader, data_loader.test_loader

print("Dimension of the training dataset:", len(trainloader.dataset))
print("Dimension of the validation dataset:", len(validloader.dataset))
print("Dimension of the test dataset:", len(testloader.dataset))

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Dimension of the training dataset: 45000
Dimension of the validation dataset: 5000
Dimension of the test dataset: 10000


# Training

### Initialize Model & Loss

In [5]:
# Loss for the 100-class classification task. NLLLoss takes log-probabilities
# as input -- assumes LeNet5 ends in a log-softmax layer; TODO confirm.
criterion = nn.NLLLoss()
# Untrained global model that the server will update round by round.
global_model = LeNet5()

### Run the training

In [6]:
# Directory handed to the server for its checkpoints.
CHECKPOINT_DIR = '../checkpoints/'

# Lowercase aliases of the config constants, kept so that any other cell
# referring to `lr` / `wd` keeps working.
lr = LR
wd = WEIGHT_DECAY

# `Server` is imported once in the first cell; it is not re-imported here.
server = Server(global_model, DEVICE, CHECKPOINT_DIR)

# Run federated learning. The call returns the updated global model followed by
# the validation/training accuracy and loss histories and the per-client
# selection counts (names taken from the unpacking below).
global_model, val_accuracies, val_losses, train_accuracies, train_losses, client_selection_count = server.train_federated(
    criterion=criterion,
    trainloader=trainloader,
    validloader=validloader,
    num_clients=NUM_CLIENTS,
    num_classes=100,
    rounds=GLOBAL_ROUNDS,
    lr=lr,
    momentum=MOMENTUM,
    batchsize=BATCH_SIZE,
    wd=wd,
    C=FRACTION_CLIENTS,
    local_steps=LOCAL_STEPS,
    log_freq=LOG_FREQUENCY,
    detailed_print=False
)

No checkpoint found, starting from epoch 1.
------------------------------------- Round 10 ------------------------------------------------
--> best validation accuracy: 0.01
--> training accuracy: 0.55
--> validation loss: 4.6051
--> training loss: 4.6067
Checkpoint saved: ../checkpoints/Federated/model_epoch_9_params_LR0.01_WD0.0001.pth
------------------------------ Round 9 terminated: model updated -----------------------------


------------------------------------- Round 20 ------------------------------------------------
--> best validation accuracy: 0.01
--> training accuracy: 0.98
--> validation loss: 4.6042
--> training loss: 4.6046
Checkpoint saved: ../checkpoints/Federated/model_epoch_19_params_LR0.01_WD0.0001.pth
------------------------------ Round 19 terminated: model updated -----------------------------




KeyboardInterrupt: 

# Test evaluation

### Run the test



In [None]:
def evaluate(model, dataloader, criterion):
    """Evaluate ``model`` on every batch of ``dataloader``.

    Defined here because no ``evaluate`` is defined or imported anywhere else
    in this notebook, so the call below would raise ``NameError`` on a fresh
    kernel.

    Args:
        model: ``torch.nn.Module`` mapping a batch of inputs to per-class scores.
        dataloader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss called as ``criterion(outputs, targets)``; assumed to
            return the mean loss over the batch (the PyTorch default reduction).

    Returns:
        Tuple ``(accuracy, loss)``: accuracy as a percentage in [0, 100] and
        the average per-sample loss.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    # Run on whatever device the model already lives on, so the model is not moved.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = model.training
    model.eval()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    try:
        with torch.no_grad():
            for inputs, targets in dataloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                n_batch = targets.size(0)
                # Re-weight the batch-mean loss by batch size so a smaller
                # final batch does not skew the overall average.
                loss_sum += criterion(outputs, targets).item() * n_batch
                n_correct += (outputs.argmax(dim=1) == targets).sum().item()
                n_seen += n_batch
    finally:
        # Leave the model in the train/eval mode it was in before the call.
        model.train(was_training)

    if n_seen == 0:
        raise ValueError("dataloader yielded no samples")
    return 100.0 * n_correct / n_seen, loss_sum / n_seen


accuracy = evaluate(global_model, testloader, criterion)[0]
print('\nTest Accuracy: {}'.format(accuracy))