In [1]:
import importlib
import os
import random
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import flwr as fl
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets

importlib.import_module(".common", "flwr.dataset.utils")


  from .autonotebook import tqdm as notebook_tqdm


<module 'flwr.dataset.utils.common' from 'c:\\Users\\MetBook\\Documents\\UNIVERSITA\\MAGISTRALE\\SD21\\PROGETTO\\FLWR\\peer_reviewed_fl\\envs\\lib\\site-packages\\flwr\\dataset\\utils\\common.py'>

In [2]:
# --- Experiment configuration -------------------------------------------------
DATASET = "CIFAR10"  # possible values: "CIFAR10" or "CIFAR100"
NUM_CLIENTS = 50  # number of simulated clients (one data partition each)
NUM_ROUNDS = 10  # presumably the number of federated rounds -- confirm it is wired into the simulation
LOCAL_EPOCHS = 2  # local epochs per client from round 2 onward (earlier rounds use 1, see fit_config)
BATCH_SIZE = 128  # batch size of every DataLoader built by load_datasets
LR = 0.1  # base learning rate sent to the clients by fit_config
MILESTONES = [60, 120, 160]  # rounds at which fit_config decays the learning rate
LR_DECAY = 0.2  # factor applied to LR once for every milestone already reached
W_DECAY = 5e-04  # presumably the optimizer weight decay -- confirm where it is consumed
FRACTION_FIT = 1 / 3  # passed to FedAvg as fraction_fit
FRACTION_EVAL = 0  # passed to FedAvg as fraction_eval
SEED = 0  # seed handed to set_seed
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Training on {DEVICE}")


Training on cuda


In [3]:
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Args:
        seed: Value handed unchanged to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


In [4]:
def load_datasets(num_clients: int, dataset: str = "CIFAR10", src: str = "."):
    """Download CIFAR and build per-client train/val loaders plus a test loader.

    The training set is split into ``num_clients`` partitions with Flower's
    ``create_lda_partitions`` helper; each partition is then split 90/10 into
    a local training set and a local validation set.

    Args:
        num_clients: Number of partitions (simulated clients) to create.
        dataset: Name of the torchvision dataset class, "CIFAR10" or "CIFAR100".
        src: Directory under which the ``data`` folder is created/read.

    Returns:
        A ``(trainloaders, valloaders, testloader)`` tuple where the first two
        are lists with one ``DataLoader`` per client and the last one iterates
        over the full test split.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    if dataset not in ["CIFAR10", "CIFAR100"]:
        raise ValueError(
            "Unknown dataset! Admissible values are: 'CIFAR10' or 'CIFAR100'."
        )
    # Download and transform CIFAR dataset (train and test)
    augmentation = [
        transforms.Pad(4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32),
    ]
    transform = [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    dataset_cls = getattr(datasets, dataset)
    data_root = os.path.join(src, "./data")
    # Augmentation is applied to the training split only.
    trainset = dataset_cls(
        data_root,
        train=True,
        download=True,
        transform=transforms.Compose([*augmentation, *transform]),
    )
    testset = dataset_cls(
        data_root,
        train=False,
        download=True,
        transform=transforms.Compose(transform),
    )

    # Split training set into `num_clients` partitions to simulate different
    # local datasets. Sample indices stand in for the features so that each
    # partition can later be turned into a `Subset` of `trainset`.
    targets = np.array(trainset.targets)
    idxs = np.array(range(len(targets)))
    train_partitions, _ = fl.dataset.utils.common.create_lda_partitions(
        [idxs, targets],
        # Honour the argument instead of the notebook-level NUM_CLIENTS global.
        num_partitions=num_clients,
        concentration=1 / 3,
        accept_imbalanced=False,
    )

    # Split each partition into train/val and create DataLoader
    trainloaders = []
    valloaders = []
    for partition in train_partitions:
        # The first element of a partition holds the sample indices.
        ds = Subset(trainset, partition[0])
        len_val = len(ds) // 10  # 10 % validation set
        len_train = len(ds) - len_val
        ds_train, ds_val = random_split(ds, [len_train, len_val], torch.Generator())
        trainloaders.append(DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=BATCH_SIZE))
    testloader = DataLoader(testset, batch_size=BATCH_SIZE)
    return trainloaders, valloaders, testloader


In [5]:
class Net(nn.Module):
    """Small CNN: two conv+pool stages followed by three linear layers.

    The flattening step assumes the second pooling stage yields 16 feature
    maps of size 5x5 (the case for 3x32x32 inputs).
    """

    def __init__(self, num_classes: int) -> None:
        """Create the layers; ``num_classes`` sets the size of the output layer."""
        super().__init__()
        # Layers are created in this exact order so the state_dict layout
        # (and the order in which random init is drawn) stays stable.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        flat_features = 16 * 5 * 5
        self.fc1 = nn.Linear(flat_features, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw class scores (no softmax) for a batch of images."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Flatten everything except the (inferred) batch dimension.
        x = x.view(-1, self.fc1.in_features)
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)


def get_parameters(net) -> List[np.ndarray]:
    """Return every ``state_dict`` entry of ``net`` as a NumPy array, in order."""
    arrays = []
    for tensor in net.state_dict().values():
        arrays.append(tensor.cpu().numpy())
    return arrays


def set_parameters(net, parameters: List[np.ndarray]):
    """Load ``parameters`` into ``net``.

    Arrays are paired positionally with the keys of ``net.state_dict()`` and
    loaded with ``strict=True``, so a missing entry makes the load fail.
    """
    new_state = OrderedDict()
    for key, array in zip(net.state_dict().keys(), parameters):
        new_state[key] = torch.Tensor(array)
    net.load_state_dict(new_state, strict=True)


def train(
    net,
    trainloader,
    epochs: int,
    lr: float = 0.1,
    momentum: float = 0.9,
    weight_decay: float = 5e-4,
):
    """Train the network on the training set.

    Runs ``epochs`` passes of SGD over ``trainloader`` with a cross-entropy
    loss and prints the average loss and accuracy after each epoch.

    Args:
        net: Model to optimise in place; expected to already live on DEVICE.
        trainloader: Iterable of ``(images, labels)`` batches exposing a
            ``dataset`` attribute (used to average the epoch loss).
        epochs: Number of passes over ``trainloader``.
        lr: SGD learning rate.
        momentum: SGD momentum.
        weight_decay: SGD weight decay.
    """
    criterion = torch.nn.CrossEntropyLoss().to(DEVICE)
    optimizer = torch.optim.SGD(
        net.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay
    )
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            # Single forward pass: the same outputs feed the loss and the
            # accuracy metric (previously the network was evaluated twice).
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss.item() * labels.size(0)
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net: nn.Module, testloader: DataLoader, return_dict=None):
    """Compute the average cross-entropy loss and accuracy over ``testloader``.

    Args:
        net: Model to evaluate; it is switched to eval mode.
        testloader: Loader yielding ``(images, labels)`` batches.
        return_dict: Accepted for interface compatibility; never read.

    Returns:
        ``(loss, accuracy)`` where ``loss`` is the sample-weighted sum of batch
        losses divided by ``len(testloader.dataset)`` and ``accuracy`` is the
        fraction of correctly classified samples.
    """
    loss_fn = torch.nn.CrossEntropyLoss().to(DEVICE)
    loss_sum = 0.0
    n_seen = 0
    n_correct = 0
    net.eval()
    with torch.no_grad():
        for batch_images, batch_labels in testloader:
            batch_images = batch_images.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)
            logits = net(batch_images)
            batch_size = batch_labels.size(0)
            # The criterion returns the batch mean; rescale to a batch sum.
            loss_sum += loss_fn(logits, batch_labels).item() * batch_size
            predictions = torch.max(logits.data, 1)[1]
            n_seen += batch_size
            n_correct += (predictions == batch_labels).sum().item()
    mean_loss = loss_sum / len(testloader.dataset)
    return mean_loss, n_correct / n_seen


In [6]:
class FlowerClient(fl.client.NumPyClient):
    """Flower client that trains and evaluates one model on one data partition."""

    def __init__(self, cid, net, trainloader, valloader):
        """Store the client id, move ``net`` to DEVICE and keep both loaders."""
        self.cid = cid
        self.net = net.to(DEVICE)
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self):
        """Return the current local model parameters as NumPy arrays."""
        print(f"[Client {self.cid}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Train locally starting from ``parameters``.

        ``config`` must provide ``current_round``, ``local_epochs`` and ``lr``.

        Returns:
            The updated parameters, the number of local training samples and
            an empty metrics dict.
        """
        # Read values from config
        current_round = config["current_round"]
        local_epochs = config["local_epochs"]
        lr = config["lr"]
        # Use values provided by the config
        print(f"[Client {self.cid}, round {current_round}] fit, config: {config}")
        set_parameters(self.net, parameters)
        # Forward the server-provided learning rate; it used to be read from
        # the config and then dropped, so `train` always ran with its default.
        train(self.net, self.trainloader, local_epochs, lr=lr)
        # Report the number of samples (not batches): the server uses this
        # count to weight the client's update during aggregation.
        return get_parameters(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Evaluate ``parameters`` on the local validation split.

        Returns:
            The loss, the number of validation samples and a metrics dict
            holding the accuracy.
        """
        print(f"[Client {self.cid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        # Number of samples, not batches, for weighted aggregation.
        num_examples = len(self.valloader.dataset)
        return float(loss), num_examples, {"accuracy": float(accuracy)}


In [8]:
# Seed the Python, NumPy and PyTorch RNGs before anything random happens.
set_seed(SEED)

# Build one train/val loader pair per client plus the shared test loader.
trainloaders, valloaders, testloader = load_datasets(
    num_clients=NUM_CLIENTS,
    dataset=DATASET,
    src="../../../",
)
# The number of classes is taken from the distinct labels of the test split.
NUM_CLASSES = len(np.unique(testloader.dataset.targets))

# Instantiate the global model, snapshot its parameters and report the
# number of trainable parameters.
net = Net(NUM_CLASSES).to(DEVICE)
params = get_parameters(net)
num_trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(num_trainable)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../../.././data\cifar-10-python.tar.gz


170499072it [00:30, 5608281.93it/s]                               


Extracting ../../.././data\cifar-10-python.tar.gz to ../../.././data
Files already downloaded and verified
62006


In [9]:
def client_fn(cid) -> FlowerClient:
    """Build the client for partition ``cid`` with a freshly initialised model.

    ``cid`` is converted with ``int`` to index the per-client loader lists.
    """
    model = Net(NUM_CLASSES).to(DEVICE)
    partition = int(cid)
    return FlowerClient(
        cid,
        model,
        trainloaders[partition],
        valloaders[partition],
    )


def evaluate(
    weights: fl.common.Weights,
) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
    """Server-side evaluation of ``weights`` on the shared test loader.

    Returns:
        ``(loss, {"accuracy": accuracy})`` as computed by ``test``.
    """
    model = Net(NUM_CLASSES).to(DEVICE)
    # Load the parameters under evaluation into a fresh model instance.
    set_parameters(model, weights)
    test_loss, test_accuracy = test(model, testloader)
    metrics = {"accuracy": test_accuracy}
    return test_loss, metrics


def fit_config(rnd: int):
    """Build the training config sent to clients for round ``rnd``.

    The learning rate starts at ``LR`` and is multiplied by ``LR_DECAY`` once
    for every entry of ``MILESTONES`` that ``rnd`` has reached (skipped when
    either constant is ``None``). Rounds before round 2 use a single local
    epoch; later rounds use ``LOCAL_EPOCHS``.

    Returns:
        A dict with the keys ``current_round``, ``local_epochs`` and ``lr``.
    """
    decay_enabled = not (MILESTONES is None or LR_DECAY is None)
    lr = LR
    if decay_enabled:
        milestones_reached = sum(1 for milestone in MILESTONES if rnd >= milestone)
        lr *= LR_DECAY ** milestones_reached

    if rnd < 2:
        local_epochs = 1
    else:
        local_epochs = LOCAL_EPOCHS

    return {
        "current_round": rnd,
        "local_epochs": local_epochs,
        "lr": lr,
    }


# FedAvg strategy: client sampling is driven by the FRACTION_* constants, the
# global model starts from `net`'s current parameters, and the server evaluates
# each aggregated model with `evaluate`.
strategy_kwargs = dict(
    fraction_fit=FRACTION_FIT,
    fraction_eval=FRACTION_EVAL,
    min_fit_clients=int(FRACTION_FIT * NUM_CLIENTS),
    min_eval_clients=int(FRACTION_EVAL * NUM_CLIENTS),
    min_available_clients=NUM_CLIENTS,
    initial_parameters=fl.common.weights_to_parameters(get_parameters(net)),
    on_fit_config_fn=fit_config,
    eval_fn=evaluate,
)
strategy = fl.server.strategy.FedAvg(**strategy_kwargs)

# Run the federated simulation. The number of rounds comes from the NUM_ROUNDS
# constant of the configuration cell instead of a duplicated literal, so
# changing the configuration actually changes the run.
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    num_rounds=NUM_ROUNDS,
    strategy=strategy,
    client_resources={"num_cpus": 1, "num_gpus": 1},
    ray_init_args={"local_mode": True, "include_dashboard": False},
)


INFO flower 2022-06-02 23:42:41,003 | app.py:155 | Ray initialized with resources: {'object_store_memory': 191444582.0, 'CPU': 8.0, 'memory': 382889166.0, 'GPU': 1.0, 'node:127.0.0.1': 1.0}
INFO flower 2022-06-02 23:42:41,007 | app.py:171 | Starting Flower simulation running: Config(num_rounds=10, round_timeout=None)
INFO flower 2022-06-02 23:42:41,016 | server.py:84 | Initializing global parameters
INFO flower 2022-06-02 23:42:41,018 | server.py:252 | Using initial parameters provided by strategy
INFO flower 2022-06-02 23:42:41,022 | server.py:86 | Evaluating initial parameters
