In [1]:
# TODO: Add noise
%load_ext autoreload
%autoreload 2


# Centralised training: the old way of doing ML

Let's begin by creating a simple (but complete) training loop as it is commonly done in centralised setups. Starting our tutorial in this way will allow us to very clearly identify which parts of a typical ML pipeline are common to both centralised and federated training.

For this tutorial we'll design a image classification pipeline for [MNIST digits](https://en.wikipedia.org/wiki/MNIST_database) and using a simple CNN model as the network to train. The MNIST dataset is comprised of `28x28` greyscale images with digits from 0 to 9 (i.e. 10 classes in total)


## A dataset

Let's begin by constructing the dataset.

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '5,7'
# we naturally first need to import torch and torchvision

import sys
sys.path.append('.')
from opt import get_opts
import torch
from collections import defaultdict

from torch.utils.data import DataLoader, random_split, Subset
from datasets import dataset_dict

# models
from models.nerf import *
from models.rendering import *

# optimizer, scheduler, visualization
from utils import *

# losses
from losses import loss_dict

# metrics
from metrics import *
import copy

In [3]:
## This code is to convert filenames which are step size
# import os

# # Set the directory path
# directory = "/home/zt16/code/priv-nerf/nerfw_pl_priv/data/lego/res400_360view"

# # Set the starting number for renaming
# counter = 6

# # Iterate through each file in the directory
# for filename in os.listdir(directory):
#     # Check if the file is a regular file (not a directory)
#     if os.path.isfile(os.path.join(directory, filename)) and counter<100:
#         # Get the file extension (if any)
#         print(filename)
#         ext = os.path.splitext(filename)[1]
#         # Generate the new filename
#         new_filename = f"r_{counter:d}{ext}"
#         print(new_filename)
#         # Rename the file
#         os.rename(os.path.join(directory, filename), os.path.join(directory, new_filename))
#         # Increment the counter
#         counter += 1


## Load datasets

In [4]:
def setup_dataloader(hparams):
    dataset = dataset_dict[hparams.dataset_name]
    kwargs = {'root_dir': hparams.root_dir}
    # import pdb; pdb.set_trace()
    if hparams.dataset_name == 'phototourism':
        kwargs['img_downscale'] = hparams.img_downscale
        kwargs['val_num'] = hparams.num_gpus
        kwargs['use_cache'] = hparams.use_cache
        kwargs['use_mask'] = hparams.use_mask
    elif hparams.dataset_name == 'blender':
        kwargs['img_wh'] = tuple(hparams.img_wh)
        kwargs['perturbation'] = hparams.data_perturb
        kwargs['random_occ'] = not hparams.nonrandom_occ
        kwargs['occ_yaw'] = hparams.occ_yaw
        kwargs['yaw_threshold'] = hparams.yaw_threshold
        kwargs['all_img_occ'] = hparams.all_img_occ
    
    train_dataset = dataset(split='train', **kwargs)
#     full_dataset = dataset(split='train', **kwargs)

    
#     train_datasets = random_split(train_dataset, lengths, torch.Generator().manual_seed(42))
    img_sample_size = hparams.img_wh[0] * hparams.img_wh[1]
#     import pdb; pdb.set_trace()
    if hparams.public_dataset:
        kwargs_public = copy.deepcopy(kwargs)
#         kwargs_public['random_occ'] = True 
        kwargs_public['root_dir'] = hparams.public_root_dir
        public_dataset = dataset(split='train',**kwargs_public)
        public_train_ray_idx = []
        for ind in range(0,100,5):
            public_train_ray_idx.extend(list(range(ind*img_sample_size,(ind+1)*img_sample_size)))    
        public_train_dataset = Subset(public_dataset,public_train_ray_idx)
        public_train_loaders = DataLoader(public_train_dataset,shuffle=True,
                              num_workers=4,
                              batch_size=hparams.batch_size,
                              pin_memory=True)
#         import pdb; pdb.set_trace()
#         full_idx = set(list(range(img_sample_size*100)))
#         public_idx = set(public_train_ray_idx)
#         remaining_idx = list(full_idx-public_idx)
#         remaining_idx = [x for x in full_idx if x not in public_train_ray_idx]
        train_dataset_remaining = train_dataset
#         train_dataset_remaining = Subset(train_dataset,remaining_idx)
    else:
        train_dataset_remaining = train_dataset
        public_train_loaders = None
    
    #splitting the dataset
    partition_size = len(train_dataset_remaining) // hparams.num_clients
    lengths = [partition_size] * (hparams.num_clients)
    
#     import pdb; pdb.set_trace()
    train_datasets = []
    for ind in range(hparams.num_clients):
        train_datasets.append(Subset(train_dataset_remaining,range(ind*partition_size,ind*partition_size+partition_size)))
    val_dataset = dataset(split='val', **kwargs)
    
    train_loaders = []
    val_loaders = []
    
    for trainset in train_datasets:
        train_loaders.append(DataLoader(trainset,shuffle=True,
                          num_workers=4,
                          batch_size=hparams.batch_size,
                          pin_memory=True))
        val_loaders.append(DataLoader(val_dataset,
                          shuffle=False,
                          num_workers=4,
                          batch_size=1, # validate one image (H*W rays) at a time
                          pin_memory=True))
        
    
    

    return train_loaders, val_loaders, train_dataset, public_train_loaders

# Preparing the experiment

This tutorial is not so much about novel architectural designs so we keep things simple and make use of a typical CNN that is adequate for the MNIST image classification task.



In [5]:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self, hparams,full_dataset) -> None:
        super(Net,self).__init__()
        self.hparams = hparams
        self.full_dataset = full_dataset
#         self.loss = loss_dict['nerfw'](coef=1)

        self.models_to_train = []
        self.embedding_xyz = PosEmbedding(hparams.N_emb_xyz-1, hparams.N_emb_xyz)
        self.embedding_dir = PosEmbedding(hparams.N_emb_dir-1, hparams.N_emb_dir)
        self.embeddings = {'xyz': self.embedding_xyz,
                           'dir': self.embedding_dir}

        if hparams.encode_a:
            self.embedding_a = torch.nn.Embedding(hparams.N_vocab, hparams.N_a)
            self.embeddings['a'] = self.embedding_a
            self.models_to_train += [self.embedding_a]
        if hparams.encode_t:
            self.embedding_t = torch.nn.Embedding(hparams.N_vocab, hparams.N_tau)
            self.embeddings['t'] = self.embedding_t
            self.models_to_train += [self.embedding_t]

        self.nerf_coarse = NeRF('coarse',
                                in_channels_xyz=6*hparams.N_emb_xyz+3,
                                in_channels_dir=6*hparams.N_emb_dir+3)
        self.models = {'coarse': self.nerf_coarse}
        if hparams.N_importance > 0:
            self.nerf_fine = NeRF('fine',
                                  in_channels_xyz=6*hparams.N_emb_xyz+3,
                                  in_channels_dir=6*hparams.N_emb_dir+3,
                                  encode_appearance=hparams.encode_a,
                                  in_channels_a=hparams.N_a,
                                  encode_transient=hparams.encode_t,
                                  in_channels_t=hparams.N_tau,
                                  beta_min=hparams.beta_min)
            self.models['fine'] = self.nerf_fine
        self.models_to_train += [self.models]

    def forward(self, rays, ts):
        B = rays.shape[0]
        results = defaultdict(list)
        for i in range(0, B, self.hparams.chunk):
            rendered_ray_chunks = \
                render_rays(self.models,
                            self.embeddings,
                            rays[i:i+self.hparams.chunk],
                            ts[i:i+self.hparams.chunk],
                            self.hparams.N_samples,
                            self.hparams.use_disp,
                            self.hparams.perturb,
                            self.hparams.noise_std,
                            self.hparams.N_importance,
                            self.hparams.chunk, # chunk size is effective in val mode
                            self.full_dataset.white_back)

            for k, v in rendered_ray_chunks.items():
                results[k] += [v]

        for k, v in results.items():
            results[k] = torch.cat(v, 0)
        return results

We'll be training the model in a Federated setting. In order to do that, we need to define two functions:

* `train()` that will train the model given a dataloader.
* `test()` that will be used to evaluate the performance of the model on held-out data, e.g., a training set.

In [6]:
def train(net,trainloader,optimizer,scheduler,epochs,device):    
    criterion = loss_dict['nerfw'](coef=1)
    net.train()
    
    for _ in range(epochs):
        itera=0
        for batch in trainloader:

#             if itera>10:
#                 break
                
            rays, rgbs, ts = batch['rays'], batch['rgbs'], batch['ts']
            rays = rays.to(device)
            rgbs = rgbs.to(device)
            ts = ts.to(device)
            optimizer.zero_grad()
            results = net(rays,ts)
            loss_d = criterion(results,rgbs)
            loss = sum(l for l in loss_d.values())
            loss.backward()
            optimizer.step()
            itera+=1
            
                
        scheduler.step()
#         print('iteration done')
    return net

def test(net, valloader, device):
    criterion = loss_dict['nerfw'](coef=1)
    psnr_, loss = 0.0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in valloader:
#             batch.to(device)
            rays, rgbs, ts = batch['rays'], batch['rgbs'], batch['ts']
            rays = rays.to(device)
            rgbs = rgbs.to(device)
            ts = ts.to(device)
            rays = rays.squeeze() # (H*W, 3)
            rgbs = rgbs.squeeze() # (H*W, 3)
            ts = ts.squeeze() # (H*W)
            results = net(rays, ts)
            loss_d = criterion(results, rgbs)
            loss += sum(l for l in loss_d.values())
            typ = 'fine' if 'rgb_fine' in results else 'coarse'

            psnr_ += psnr(results[f'rgb_{typ}'], rgbs)
    val_psnr = psnr_ / len(valloader.dataset)
            
    return loss, val_psnr

In [7]:
import argparse
parser = argparse.ArgumentParser()

parser.add_argument('--root_dir', type=str, required=False,
                    help='root directory of dataset')
parser.add_argument('--dataset_name', type=str, default='blender',
                    choices=['blender', 'phototourism'],
                    help='which dataset to train/val')
# for blender
parser.add_argument('--data_perturb', nargs="+", type=str, default=[],
                    help='''what perturbation to add to data.
                            Available choices: [], ["color"], ["occ"] or ["color", "occ"]
                         ''')
parser.add_argument('--nonrandom_occ', default=False, action="store_true",
                    help='whether to use non-random occlusion')
parser.add_argument('--all_img_occ', default=False, action="store_true",
                    help='whether to add black occlusion to all images')
parser.add_argument('--occ_yaw', type=float, default=0.0,
                    help='yaw angle for selecting images for non-random occlusion')
parser.add_argument('--yaw_threshold', type=float, default=0.0,
                    help='threshold for selecting images for non-random occlusion')
parser.add_argument('--img_wh', nargs="+", type=int, default=[800, 800],
                    help='resolution (img_w, img_h) of the image')

# for phototourism
parser.add_argument('--img_downscale', type=int, default=1,
                    help='how much to downscale the images for phototourism dataset')
parser.add_argument('--use_cache', default=False, action="store_true",
                    help='whether to use ray cache (make sure img_downscale is the same)')
parser.add_argument('--use_mask', default=False, action="store_true",
                    help='use masked images')

# original NeRF parameters
parser.add_argument('--N_emb_xyz', type=int, default=10,
                    help='number of xyz embedding frequencies')
parser.add_argument('--N_emb_dir', type=int, default=4,
                    help='number of direction embedding frequencies')
parser.add_argument('--N_samples', type=int, default=64,
                    help='number of coarse samples')
parser.add_argument('--N_importance', type=int, default=128,
                    help='number of additional fine samples')
parser.add_argument('--use_disp', default=False, action="store_true",
                    help='use disparity depth sampling')
parser.add_argument('--perturb', type=float, default=1.0,
                    help='factor to perturb depth sampling points')
parser.add_argument('--noise_std', type=float, default=1.0,
                    help='std dev of noise added to regularize sigma')

# NeRF-W parameters
parser.add_argument('--N_vocab', type=int, default=100,
                    help='''number of vocabulary (number of images) 
                            in the dataset for nn.Embedding''')
parser.add_argument('--encode_a', default=False, action="store_true",
                    help='whether to encode appearance (NeRF-A)')
parser.add_argument('--N_a', type=int, default=48,
                    help='number of embeddings for appearance')
parser.add_argument('--encode_t', default=False, action="store_true",
                    help='whether to encode transient object (NeRF-U)')
parser.add_argument('--N_tau', type=int, default=16,
                    help='number of embeddings for transient objects')
parser.add_argument('--beta_min', type=float, default=0.1,
                    help='minimum color variance for each ray')

parser.add_argument('--batch_size', type=int, default=1024,
                    help='batch size')
parser.add_argument('--chunk', type=int, default=32*1024,
                    help='chunk size to split the input to avoid OOM')
parser.add_argument('--num_epochs', type=int, default=16,
                    help='number of training epochs')
parser.add_argument('--num_gpus', type=int, default=1,
                    help='number of gpus')

parser.add_argument('--ckpt_path', type=str, default=None,
                    help='pretrained checkpoint path to load')
parser.add_argument('--prefixes_to_ignore', nargs='+', type=str, default=['loss'],
                    help='the prefixes to ignore in the checkpoint state dict')

parser.add_argument('--optimizer', type=str, default='adam',
                    help='optimizer type',
                    choices=['sgd', 'adam', 'radam', 'ranger'])
parser.add_argument('--lr', type=float, default=5e-4,
                    help='learning rate')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='learning rate momentum')
parser.add_argument('--weight_decay', type=float, default=0,
                    help='weight decay')
parser.add_argument('--lr_scheduler', type=str, default='steplr',
                    help='scheduler type',
                    choices=['steplr', 'cosine', 'poly'])
#### params for warmup, only applied when optimizer == 'sgd' or 'adam'
parser.add_argument('--warmup_multiplier', type=float, default=1.0,
                    help='lr is multiplied by this factor after --warmup_epochs')
parser.add_argument('--warmup_epochs', type=int, default=0,
                    help='Gradually warm-up(increasing) learning rate in optimizer')
###########################
#### params for steplr ####
parser.add_argument('--decay_step', nargs='+', type=int, default=[20],
                    help='scheduler decay step')
parser.add_argument('--decay_gamma', type=float, default=0.1,
                    help='learning rate decay amount')
###########################
#### params for poly ####
parser.add_argument('--poly_exp', type=float, default=0.9,
                    help='exponent for polynomial learning rate decay')
###########################

parser.add_argument('--exp_name', type=str, default='exp',
                    help='experiment name')
parser.add_argument('--refresh_every', type=int, default=1,
                    help='print the progress bar every X steps')
# Federated Learning parameters
parser.add_argument('--num_clients', type=int, default=1,help='number of clients')
parser.add_argument('--num_rounds', type=int, default=10,help='number of rounds')
parser.add_argument('--public_dataset', default=False, action="store_true",
                    help='whether to use public dataset for training')
hparams = parser.parse_args(args=[])

In [8]:
hparams.dataset_name

'blender'

The code we have written so far is not specific to Federated Learning. Then, what are the key differences between Federated Learning and Centralised Training? If you could only pick you, probably you'd say:
* Federated Learning is distributed -- the model is trained on-device by the participating clients.
* Data remains private and is owned by a specific _client_ -- the data is never sent to the central server.

The are several more differences. But the above two are the main ones to always consider and that are common to all flavours of Federated Learning (e.g. _cross-device_ or _cross-silo_). The remaining of this tutorial is going to focus in transforming the code we have written so far for the centralised setting and construct a Federated Learning pipeline using Flower and PyTorch.

Let's begin! 🚀

In [9]:
import copy
hparams.dataset_name = 'blender'
hparams.root_dir = '/home/zt16/code/priv-nerf/nerfw_pl_priv/data/lego/res800_360view_IID_vertical_random'#'/home/zt16/code/priv-nerf/nerfw_pl_priv/data/lego/res400_360view'
hparams.public_root_dir = '/home/zt16/code/priv-nerf/nerfw_pl_priv/data/lego/res400_360view_random'
hparams.N_importance = 64
hparams.img_wh = [400, 400]
hparams.noise_std = 0
hparams.num_epochs = 5
hparams.batch_size = 1024 
hparams.optimizer = 'adam'
hparams.lr = 5e-4
hparams.lr_scheduler = 'cosine'
hparams.exp_name = 'lego_nerfU_nonIID_split_20clients_newdataset'
hparams.encode_t = True #False #True
hparams.beta_min = 0.1 
hparams.data_perturb = [] #["occ"]
hparams.num_clients = 20
hparams.num_rounds = 30
hparams.nonrandom_occ = True
hparams.all_img_occ = True
hparams.public_dataset = False
NUM_CLIENTS = hparams.num_clients
    


trainloaders,valloaders, full_dataset_central, public_train_loader = setup_dataloader(hparams)
hparams_public = copy.deepcopy(hparams)
hparams_central = copy.deepcopy(hparams)
hparams_public.encode_t = True
hparams_central.encode_t = False

add [] perturbation!


In [10]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# model_central = Net(hparams,full_dataset_central).to(device)
# # # Define the optimizer and schedulter
# optim = get_optimizer(hparams, model_central.models_to_train)
# scheduler = get_scheduler(hparams, optim)



In [11]:
# # do local training
# train(model_central, public_train_loader, optim, scheduler, epochs=20, device=device)
# loss, accuracy = test(model_central, valloaders[0], device=device)

In [12]:
# print(loss,accuracy)

In [13]:
from collections import OrderedDict
# _state_dict_all = model_central.state_dict()
# _mydict = model_central.state_dict().keys()

# _mystr_arr = ['nerf_fine.transient_','embedding_t']
# _transient_keys = []
# #         import pdb; pdb.set_trace()
# for mystr in _mystr_arr:
# #             avoid_keys.append(key.startswith(mystr) for key in mydict)
#     for key in _mydict:
#         if key.startswith(mystr):
#             _transient_keys.append(key)



# _state_dict_transient = OrderedDict({k: v for k, v in _state_dict_all.items() if k in _transient_keys })
# _state_dict_static = OrderedDict({k: v for k, v in _state_dict_all.items() if k not in _transient_keys })

# os.system(f'mkdir -p ckpts/{hparams.exp_name}')
# _ckpt_file = os.path.join(f'ckpts/{hparams.exp_name}/central_model_transient40.pth')
# _ckpt_file_static = os.path.join(f'ckpts/{hparams.exp_name}/central_model_static40.pth')
# torch.save(_state_dict_transient,_ckpt_file)
# torch.save(_state_dict_static,_ckpt_file_static)

In [14]:
# public_ckpt_file_static = os.path.join(f'ckpts/{hparams.exp_name}/central_model_static40.pth')
# public_static_state_dict = torch.load(public_ckpt_file_static)

# public_model_static_params = [val.cpu().numpy() for _,val in public_static_state_dict.items()]

In [15]:
# state_dict2=model_central2.state_dict()
# # for k, v in state_dict2.items():
# #     jk=1
# state_dict = OrderedDict({k: v for k, v in state_dict2.items() if k in ['ghgjhj'] })

In [16]:
len(trainloaders[0])

782

As you can see, the histogram of this partition is a bit different from the one we obtained at the beginning where we took the entire dataset into consideration. Because our data partitions are artificially constructed by sampling the MNIST dataset in an IID fashion, our Federated Learning example will not face sever _data heterogeneity_ issues (which is a fairly [active research topic](https://arxiv.org/abs/1912.04977)).

Let's next define how our FL clients will behave

## Defining a Flower Client

You can think of a client in FL as an entity that owns some data and trains a model using this data. The caveat is that the model is being trained _collaboratively_ in Federation by multiple clients (sometimes up to hundreds of thousands) and, in most instances of FL, is sent by a central server.

A Flower Client is a simple Python class with four distinct methods:

* `fit()`: With this method, the client does on-device training for a number of epochs using its own data. At the end, the resulting model is sent back to the server for aggregation.

* `evaluate()`: With this method, the server can evaluate the performance of the global model on the local validation set of a client. This can be used for instance when there is no centralised dataset on the server for validation/test. Also, this method can be use to asses the degree of personalisation of the model being federated.

* `set_parameters()`: This method takes the parameters sent by the server and uses them to initialise the parameters of the local model that is ML framework specific (e.g. TF, Pytorch, etc).

* `get_parameters()`: It extract the parameters from the local model and transforms them into a list of NumPy arrays. This ML framework-agnostic representation of the model will be sent to the server.

Let's start by importing Flower!

In [17]:
import flwr as fl

2023-09-19 11:03:46,068	INFO util.py:159 -- Outdated packages:
  ipywidgets==7.6.5 found, needs ipywidgets>=8
Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


Now let's defice our Flower Client class:

In [18]:

from typing import Dict, List, Tuple

import torch
from flwr.common import NDArrays, Scalar
import glob


class FlowerClient(fl.client.NumPyClient):
    def __init__(self, hparams, trainloader, vallodaer,full_dataset,client_number) -> None:
        super().__init__()

        self.trainloader = trainloader
        self.valloader = vallodaer
        self.full_dataset = full_dataset
        self.hparams = hparams
        self.model = Net(self.hparams,self.full_dataset)
        self.client_num = client_number
        # Determine device
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # figure out Transient params
        mydict = self.model.state_dict().keys()
#         avoid_transient = ['transient_rgb','transient_encoding','transient_sigma','transient_beta','embedding_t']
        mystr_arr = ['nerf_fine.transient_','embedding_t']
        transient_keys = []
#         import pdb; pdb.set_trace()
        for mystr in mystr_arr:
#             avoid_keys.append(key.startswith(mystr) for key in mydict)
            for key in mydict:
                if key.startswith(mystr):
                    transient_keys.append(key)
        self.transient_keys = transient_keys
                    
        self.model.to(self.device)  # send model to device
        self.epoch_num =1

    def set_parameters(self, parameters):
        """With the model paramters received from the server,
        overwrite the uninitialise model in this class with them."""
#         import pdb; pdb.set_trace()
        
#         print(avoid_keys)
        mykeys = self.model.state_dict().keys()
#         import pdb;pdb.set_trace()
        mykeys = [key for key in mykeys if key != 'embedding_t.weight'] #embedding_t is non learnable
        params_dict = zip(mykeys, parameters)
        state_dict_static = OrderedDict({k: torch.Tensor(v) for k, v in params_dict if k not in self.transient_keys })
        ckpt_file_dir = os.path.join(f'ckpts/{self.hparams.exp_name}/clients/client_{self.client_num:0>2d}/**')
        files = glob.glob(ckpt_file_dir)
        if len(files)>0:
            
            ckpt_file = max(files)
#             print("loading from saved ckpt: ",ckpt_file)
            state_dict_transient = torch.load(ckpt_file)
            state_dict_transient =  OrderedDict({k: v for k, v in state_dict_transient.items() if k in self.transient_keys })

            state_dict_static.update(state_dict_transient)
        
        self.model.load_state_dict(state_dict_static, strict=False)

    def get_parameters(self, config: Dict[str, Scalar]):
        """Extract all model parameters and conver them to a list of
        NumPy arryas. The server doesn't work with PyTorch/TF/etc."""
        return [val.cpu().numpy() for key, val in self.model.state_dict().items() if key not in self.transient_keys]

    def fit(self, parameters, config):
        """This method train the model using the parameters sent by the
        server on the dataset of this client. At then end, the parameters
        of the locally trained model are communicated back to the server"""
        
        # read from config
        server_round_, lr, epochs = config["round"], config["lr"], config["epochs"]
        
        # copy parameters sent by the server into client's local model
        self.set_parameters(parameters)

        # Define the optimizer and schedulter
        optim = get_optimizer(self.hparams, self.model.models_to_train)
        scheduler = get_scheduler(self.hparams, optim)

        # do local training
        train(self.model, self.trainloader, optim, scheduler, epochs=epochs, device=self.device)
                
        state_dict_all = self.model.state_dict()
        
        if hparams.encode_t:
            state_dict_transient = OrderedDict({k: v for k, v in state_dict_all.items() if k in self.transient_keys })
            ckpt_file = os.path.join(f'ckpts/{self.hparams.exp_name}/clients/client_{self.client_num:0>2d}/epoch_{server_round_:0>2d}.pth')
            torch.save(state_dict_transient,ckpt_file)
        
        state_dict_static = OrderedDict({k: v for k, v in state_dict_all.items() if k not in self.transient_keys })        
        ckpt_file_static = os.path.join(f'ckpts/{self.hparams.exp_name}/clients-static/client_{self.client_num:0>2d}/epoch_{server_round_:0>2d}.pth')        
        torch.save(state_dict_static,ckpt_file_static)

        # return the model parameters to the server as well as extra info (number of training examples in this case)
        return self.get_parameters({}), len(self.trainloader), {}

    def evaluate(self, parameters: NDArrays, config: Dict[str, Scalar]):
        """Evaluate the model sent by the server on this client's
        local validation set. Then return performance metrics."""
        self.set_parameters(parameters)
        loss, accuracy = test(self.model, self.valloader, device=self.device)
        # send statistics back to the server
        return float(loss), len(self.valloader), {"val_psnr": accuracy}

In [19]:
hparams.encode_t

True

In [20]:
# client1 = FlowerClient(hparams, trainloaders[0], valloaders[0],full_dataset_central,client_number=1) 

Spend a few minutes to inspect the `FlowerClient` class above. Please ask questions if there is something unclear !

Then keen-eyed among you might have realised that if we were to fuse the client's `fit()` and `evaluate()` methods, we'll end up with essentially the same as in the `run_centralised()` function we used in the Centralised Training part of this tutorial. And it is true!! In Federated Learning, the way clients perform local training makes use of the same principles as more traditional centralised setup. The key difference is that the dataset now is much smaller and it's never _"seen"_ by the entity running the FL workload (i.e. the central server).


Talking about the central server... we should define what strategy we want to make use of so the updated models sent from the clients back to the server at the end of the `fit()` method are aggregate.


## Choosing a Flower Strategy


A strategy sits at the core of the Federated Learning experiment. It is involved in all stages of a FL pipeline: sampling clients; sending the _global model_ to the clients so they can do `fit()`; receive the updated models from the clients and **aggregate** these to construct a new _global model_; define and execute global or federated evaluation; and more.

Flower comes with [many strategies built-in](https://github.com/adap/flower/tree/main/src/py/flwr/server/strategy) and more to be available in the next release (`1.5` already!). For this tutorial, let's use what is arguable the most popular strategy out there: `FedAvg`.

The way `FedAvg` works is simple but performs surprisingly well in practice. It is therefore one good strategy to start your experimentation. `FedAvg`, as its name implies, derives a new version of the _global model_ by taking the average of all the models sent by clients participating in the round. You can read all the details [in the paper](https://arxiv.org/abs/1602.05629).

Let's see how we can define `FedAvg` using Flower. We use one of the callbacks called `evaluate_fn` so we can easily evaluate the state of the global model using a small centralised testset. Note this functionality is user-defined since it requires a choice in terms of ML-framework. (if you recall, Flower is framework agnostic).

> This being said, centralised evaluation of the global model is only possible if there exists a centralised dataset that somewhat follows a similar distribution as the data that's spread across clients. In some cases having such centralised dataset for validation is not possible, so the only solution is to federate the evaluation of the _global model_. This is the default behaviour in Flower. If you don't specify teh `evaluate_fn` argument in your strategy, then, centralised global evaluation won't be performed.

In [21]:
def get_evaluate_fn(testloader):
    """This is a function that returns a function. The returned
    function (i.e. `evaluate_fn`) will be executed by the strategy
    at the end of each round to evaluate the stat of the global
    model."""

    def evaluate_fn(server_round: int, parameters, config):
        """This function is executed by the strategy it will instantiate
        a model and replace its parameters with those from the global model.
        The, the model will be evaluate on the test set (recall this is the
        whole MNIST test set)."""

        central_model = Net(hparams_central,full_dataset_central)

        # Determine device
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        central_model.to(device)  # send model to device

#         print("global model evaluation started")
        # set parameters to the model
#         print(central_model.state_dict().keys())
        params_dict = zip(central_model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        central_model.load_state_dict(state_dict, strict=True)
        
#         print("global model loaded")
        

        # call test
        loss, accuracy = test(central_model, testloader, device)
        return loss, {"val_psnr": accuracy}

    return evaluate_fn


# now we can define the strategy
# strategy = fl.server.strategy.FedAvg(
#     fraction_fit=0.1,
#     fraction_evaluate=0.1,
#     min_available_clients=100,
#     evaluate_fn=get_evaluate_fn(testloader), # Even this is not required
# )

We could now define a strategy just as shown (commented) above. Instead, let's see how additional (but entirely optional) functionality can be easily added to our strategy. We are going to define two additional auxiliary functions to: (1) be able to configure how clients do local training; and (2) define a function to aggregate the metrics that clients return after running their `evaluate` methods:

1. `fit_config()`. This is a function that will be executed inside the strategy when configuring a new `fit` round. This function is relatively simple and only requires as input argument the round at which the FL experiment is at. In this example we simply return a Python dictionary to specify the number of epochs and learning rate each client should made use of inside their `fit()` methods. A more versatile implementation would add more hyperparameters (e.g. the learning rate) and adjust them as the FL process advances (e.g. reducing the learning rate in later FL rounds).
2. `weighted_average()`: This is an optional function to pass to the strategy. It will be executed after an evaluation round (i.e. when client run `evaluate()`) and will aggregate the metrics clients return. In this example, we use this function to compute the weighted average accuracy of clients doing `evaluate()`.

In [22]:
from flwr.common import Metrics


def fit_config(server_round: int) -> Dict[str, Scalar]:
    """Return a configuration with static batch size and (local) epochs."""
    config = {
        "round": server_round,
        "epochs": 1,  # Number of local epochs done by clients
        "lr": 0.01,  # Learning rate to use by clients during fit()
    }
    return config


def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Aggregation function for (federated) evaluation metrics, i.e. those returned by
    the client's evaluate() method."""
    # Multiply accuracy of each client by number of examples used
    accuracies = [num_examples * m["val_psnr"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]

    # Aggregate and return custom metric (weighted average)
    return {"val_psnr": sum(accuracies) / sum(examples)}

Now we can define our strategy:

In [23]:
from typing import Callable, Union, Dict, List, Optional, Tuple
from flwr.server.client_proxy import ClientProxy
from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    MetricsAggregationFn,
    NDArrays,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
class SaveModelStrategy(fl.server.strategy.FedAvg):
    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate model weights using weighted average and store checkpoint"""

        # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures)

        if aggregated_parameters is not None:
            print(f"Saving round {server_round} aggregated_parameters...")

            # Convert `Parameters` to `List[np.ndarray]`
            aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(aggregated_parameters)

            model = Net(hparams_central,full_dataset_central)
            # Convert `List[np.ndarray]` to PyTorch`state_dict`
            params_dict = zip(model.state_dict().keys(), aggregated_ndarrays)
            state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
            model.load_state_dict(state_dict, strict=True)

            # Save the model
            torch.save(model.state_dict(), f"ckpts/{hparams_central.exp_name}/server/round_{server_round}.pth")

        return aggregated_parameters, aggregated_metrics

In [24]:
if not hparams.public_dataset:
    strategy = SaveModelStrategy(
        fraction_fit=1.,  # Sample 10% of available clients for training
        fraction_evaluate=0.1,  # Sample 5% of available clients for evaluation
        min_fit_clients=4,  # Never sample less than 10 clients for training
        min_evaluate_clients=1,  # Never sample less than 5 clients for evaluation
        min_available_clients=int(
            NUM_CLIENTS * 0.75
        ),  # Wait until at least 75 clients are available
        on_fit_config_fn=fit_config,
        evaluate_metrics_aggregation_fn=weighted_average,  # aggregates federated metrics
        evaluate_fn=get_evaluate_fn(valloaders[0]),  # global evaluation function
    )
else:
    print("Initial Parameters given")
    strategy = SaveModelStrategy(
        fraction_fit=1.,  # Sample 10% of available clients for training
        fraction_evaluate=0.1,  # Sample 5% of available clients for evaluation
        min_fit_clients=4,  # Never sample less than 10 clients for training
        min_evaluate_clients=1,  # Never sample less than 5 clients for evaluation
        min_available_clients=int(
            NUM_CLIENTS * 0.75
        ),  # Wait until at least 75 clients are available
        on_fit_config_fn=fit_config,
        evaluate_metrics_aggregation_fn=weighted_average,  # aggregates federated metrics
        evaluate_fn=get_evaluate_fn(valloaders[0]),  # global evaluation function
        initial_parameters=fl.common.ndarrays_to_parameters(public_model_static_params)#NOTE: finish this
    )

So far we have:
* created the dataset partitions (one for each client)
* defined the client class
* decided on a strategy to use

Now we just need to launch the Flower FL experiment... not so fast! just one final function: let's create another callback that the Simulation Engine will use in order to span VirtualClients. As you can see this is really simple: construct a FlowerClient object, assigning each their own data partition.

In [25]:
def generate_client_fn(trainloaders, valloaders,hparams,full_dataset):
    def client_fn(cid: str):
        """Returns a FlowerClient containing the cid-th data partition"""

        return FlowerClient(
            hparams=hparams, trainloader=trainloaders[int(cid)], vallodaer=valloaders[int(cid)],full_dataset=full_dataset,client_number=int(cid)
        )

    return client_fn


client_fn_callback = generate_client_fn(trainloaders, valloaders,hparams,full_dataset_central)

In [None]:
logger_filename = f'{hparams.exp_name}.txt'
fl.common.logger.configure(identifier="myFlowerExperiment", filename=logger_filename)
# With a dictionary, you tell Flower's VirtualClientEngine that each
# client needs exclusive access to these many resources in order to run
client_resources = {"num_cpus": 8, "num_gpus": 0.5}
os.system(f'mkdir -p ckpts/{hparams.exp_name}/server')
for i in range(NUM_CLIENTS):
    os.system(f'mkdir -p ckpts/{hparams.exp_name}/clients/client_{i:0>2d}')
    os.system(f'mkdir -p ckpts/{hparams.exp_name}/clients-static/client_{i:0>2d}')

history = fl.simulation.start_simulation(
    client_fn=client_fn_callback,  # a callback to construct a client
    num_clients=NUM_CLIENTS,  # total number of clients in the experiment
    config=fl.server.ServerConfig(num_rounds=hparams.num_rounds),  # let's run for 10 rounds
    strategy=strategy,  # the strategy that will orchestrate the whole FL pipeline
    client_resources=client_resources,
)

INFO flwr 2023-09-19 11:03:46,461 | app.py:175 | Starting Flower simulation, config: ServerConfig(num_rounds=30, round_timeout=None)
2023-09-19 11:03:49,006	INFO worker.py:1621 -- Started a local Ray instance.
INFO flwr 2023-09-19 11:03:50,198 | app.py:210 | Flower VCE: Ray initialized with resources: {'node:__internal_head__': 1.0, 'memory': 861846494208.0, 'object_store_memory': 200000000000.0, 'CPU': 128.0, 'node:10.129.96.28': 1.0, 'accelerator_type:A100': 1.0, 'GPU': 2.0}
INFO flwr 2023-09-19 11:03:50,199 | app.py:224 | Flower VCE: Resources for each Virtual Client: {'num_cpus': 8, 'num_gpus': 0.5}
INFO flwr 2023-09-19 11:03:50,217 | app.py:270 | Flower VCE: Creating VirtualClientEngineActorPool with 4 actors
INFO flwr 2023-09-19 11:03:50,218 | server.py:89 | Initializing global parameters
INFO flwr 2023-09-19 11:03:50,220 | server.py:276 | Requesting initial parameters from one random client
INFO flwr 2023-09-19 11:03:55,450 | server.py:280 | Received initial parameters from one 

Saving round 1 aggregated_parameters...


INFO flwr 2023-09-19 11:16:59,636 | server.py:125 | fit progress: (1, tensor(0.7406, device='cuda:0'), {'val_psnr': tensor(9.6018, device='cuda:0')}, 751.9026063049969)
DEBUG flwr 2023-09-19 11:16:59,643 | server.py:173 | evaluate_round 1: strategy sampled 2 clients (out of 20)
DEBUG flwr 2023-09-19 11:18:09,654 | server.py:187 | evaluate_round 1 received 2 results and 0 failures
DEBUG flwr 2023-09-19 11:18:09,657 | server.py:222 | fit_round 2: strategy sampled 20 clients (out of 20)
DEBUG flwr 2023-09-19 11:30:07,823 | server.py:236 | fit_round 2 received 20 results and 0 failures


Saving round 2 aggregated_parameters...


INFO flwr 2023-09-19 11:30:36,899 | server.py:125 | fit progress: (2, tensor(0.4134, device='cuda:0'), {'val_psnr': tensor(13.7340, device='cuda:0')}, 1569.1650289289973)
DEBUG flwr 2023-09-19 11:30:36,908 | server.py:173 | evaluate_round 2: strategy sampled 2 clients (out of 20)
DEBUG flwr 2023-09-19 11:31:48,227 | server.py:187 | evaluate_round 2 received 2 results and 0 failures
DEBUG flwr 2023-09-19 11:31:48,229 | server.py:222 | fit_round 3: strategy sampled 20 clients (out of 20)
DEBUG flwr 2023-09-19 11:43:46,446 | server.py:236 | fit_round 3 received 20 results and 0 failures


Saving round 3 aggregated_parameters...


INFO flwr 2023-09-19 11:44:15,750 | server.py:125 | fit progress: (3, tensor(0.1978, device='cuda:0'), {'val_psnr': tensor(16.3960, device='cuda:0')}, 2388.016537931995)
DEBUG flwr 2023-09-19 11:44:15,757 | server.py:173 | evaluate_round 3: strategy sampled 2 clients (out of 20)
DEBUG flwr 2023-09-19 11:45:25,972 | server.py:187 | evaluate_round 3 received 2 results and 0 failures
DEBUG flwr 2023-09-19 11:45:25,977 | server.py:222 | fit_round 4: strategy sampled 20 clients (out of 20)
DEBUG flwr 2023-09-19 11:57:33,827 | server.py:236 | fit_round 4 received 20 results and 0 failures


Saving round 4 aggregated_parameters...


INFO flwr 2023-09-19 11:58:02,998 | server.py:125 | fit progress: (4, tensor(0.1377, device='cuda:0'), {'val_psnr': tensor(17.7429, device='cuda:0')}, 3215.264642119)
DEBUG flwr 2023-09-19 11:58:03,008 | server.py:173 | evaluate_round 4: strategy sampled 2 clients (out of 20)
DEBUG flwr 2023-09-19 11:59:14,262 | server.py:187 | evaluate_round 4 received 2 results and 0 failures
DEBUG flwr 2023-09-19 11:59:14,265 | server.py:222 | fit_round 5: strategy sampled 20 clients (out of 20)
DEBUG flwr 2023-09-19 12:11:14,452 | server.py:236 | fit_round 5 received 20 results and 0 failures


Saving round 5 aggregated_parameters...


INFO flwr 2023-09-19 12:11:43,538 | server.py:125 | fit progress: (5, tensor(0.1058, device='cuda:0'), {'val_psnr': tensor(19.0254, device='cuda:0')}, 4035.8048292739986)
DEBUG flwr 2023-09-19 12:11:43,547 | server.py:173 | evaluate_round 5: strategy sampled 2 clients (out of 20)
DEBUG flwr 2023-09-19 12:12:54,287 | server.py:187 | evaluate_round 5 received 2 results and 0 failures
DEBUG flwr 2023-09-19 12:12:54,290 | server.py:222 | fit_round 6: strategy sampled 20 clients (out of 20)
DEBUG flwr 2023-09-19 12:24:52,502 | server.py:236 | fit_round 6 received 20 results and 0 failures


Saving round 6 aggregated_parameters...


INFO flwr 2023-09-19 12:25:22,023 | server.py:125 | fit progress: (6, tensor(0.0852, device='cuda:0'), {'val_psnr': tensor(20.3616, device='cuda:0')}, 4854.28932542399)
DEBUG flwr 2023-09-19 12:25:22,028 | server.py:173 | evaluate_round 6: strategy sampled 2 clients (out of 20)
DEBUG flwr 2023-09-19 12:26:33,186 | server.py:187 | evaluate_round 6 received 2 results and 0 failures
DEBUG flwr 2023-09-19 12:26:33,189 | server.py:222 | fit_round 7: strategy sampled 20 clients (out of 20)
DEBUG flwr 2023-09-19 12:38:33,379 | server.py:236 | fit_round 7 received 20 results and 0 failures


Saving round 7 aggregated_parameters...


INFO flwr 2023-09-19 12:39:02,970 | server.py:125 | fit progress: (7, tensor(0.0720, device='cuda:0'), {'val_psnr': tensor(21.4874, device='cuda:0')}, 5675.236932414002)
DEBUG flwr 2023-09-19 12:39:02,976 | server.py:173 | evaluate_round 7: strategy sampled 2 clients (out of 20)
DEBUG flwr 2023-09-19 12:40:14,021 | server.py:187 | evaluate_round 7 received 2 results and 0 failures
DEBUG flwr 2023-09-19 12:40:14,023 | server.py:222 | fit_round 8: strategy sampled 20 clients (out of 20)
DEBUG flwr 2023-09-19 12:52:11,725 | server.py:236 | fit_round 8 received 20 results and 0 failures


Saving round 8 aggregated_parameters...


INFO flwr 2023-09-19 12:52:41,040 | server.py:125 | fit progress: (8, tensor(0.0634, device='cuda:0'), {'val_psnr': tensor(22.3006, device='cuda:0')}, 6493.306253212999)
DEBUG flwr 2023-09-19 12:52:41,046 | server.py:173 | evaluate_round 8: strategy sampled 2 clients (out of 20)
DEBUG flwr 2023-09-19 12:53:51,959 | server.py:187 | evaluate_round 8 received 2 results and 0 failures
DEBUG flwr 2023-09-19 12:53:51,962 | server.py:222 | fit_round 9: strategy sampled 20 clients (out of 20)


Doing 10 rounds should take less than 2 minutes on a CPU-only Colab instance <-- Flower Simulation is fast! 🚀

You can then use the resturned `History` object to either save the results to disk or do some visualisation (or both of course, or neither if you like chaos). Below you can see how you can plot the centralised accuracy obtainined at the end of each round (including at the very beginning of the experiment) for the _global model_. This is want the function `evaluate_fn()` that we passed to the strategy reports.

In [None]:
# print(f"{history.metrics_centralized = }")

# global_accuracy_centralised = history.metrics_centralized["accuracy"]
# round = [data[0] for data in global_accuracy_centralised]
# acc = [100.0 * data[1] for data in global_accuracy_centralised]
# plt.plot(round, acc)
# plt.grid()
# plt.ylabel("Accuracy (%)")
# plt.xlabel("Round")
# plt.title("MNIST - IID - 100 clients with 10 clients per round")