# FLSim Tutorial: Image classification with CIFAR-10

## Introduction

In this tutorial, we will train a simple CNN image classifier on CIFAR-10 with federated learning using FLSim.

### Prerequisites

To get the most out of this tutorial, you should be comfortable with training machine learning models with **PyTorch** and familiar with the concept of **federated learning (FL)**. If you are unfamiliar with either of them or could use a refresher, please take a look at the following resources before proceeding with the tutorial:

- McMahan & Ramage (2017): [Federated Learning: Collaborative Machine Learning without Centralized Training Data](https://ai.googleblog.com/2017/04/federated-learning-collaborative.html). A short blog post from Google AI introducing the main idea of FL in a beginner-friendly way.
- McMahan et al. (2017): [Communication-Efficient Learning of Deep Networks from Decentralized Data](https://arxiv.org/pdf/1602.05629.pdf). This paper first proposed federated learning. The algorithm it describes is now known as federated averaging (or FedAvg for short).
- PyTorch has [extensive tutorials](https://pytorch.org/tutorials/) on their website. In particular, take a look at their [image classification tutorial using CIFAR-10](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html).

Now that you're familiar with PyTorch and FL, let's move on!

### Objectives 

By the end of this tutorial, we will have learnt how to

1. Build a data pipeline for federated learning with FLSim,
2. Create an image classification model compatible with FL training,
3. Create a metrics reporter to collect and report metrics,
4. Set hyperparameters for FL training, and
5. Launch an FL training flow using FLSim.

## Training an image classifier with FLSim

### Prerequisites
First, let us install flsim via pip with the command below:

In [1]:
!pip install --quiet flsim

Some useful parameters for later - no need to change these.

In [2]:
USE_CUDA = True
LOCAL_BATCH_SIZE = 32
EXAMPLES_PER_USER = 500
IMAGE_SIZE = 32

# suppress large outputs
VERBOSE = False

### 0. About the dataset

For this tutorial, we will use the [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html). The CIFAR-10 dataset consists of 60k color images of shape 3x32x32 (3 channels, 32x32 pixels) from 10 classes, with 6k images per class. 
There are 50k training images (5k training images per class) and 10k test images (1k test images per class).
The classes are ‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, and ‘truck’.

![img](https://pytorch.org/tutorials/_images/cifar10.png)

We can get the CIFAR-10 dataset from `torchvision.datasets`.

In [3]:
from torchvision.datasets.cifar import CIFAR10

### 1. Data pipeline

First, let's define how to build the data pipeline for federated learning:

1. We create data transforms and the training and test datasets; in this tutorial, the test set also serves as the eval set. This step is identical to preparing data in non-federated learning.

In [4]:
from torchvision import transforms

# 1. Create training, eval, and test datasets like in non-federated learning.
transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.4914, 0.4822, 0.4465), 
            (0.2023, 0.1994, 0.2010)
        ),
    ]
)
train_dataset = CIFAR10(
    root="./cifar10", train=True, download=True, transform=transform
)
test_dataset = CIFAR10(
    root="./cifar10", train=False, download=True, transform=transform
)

Files already downloaded and verified
Files already downloaded and verified
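To make the normalization step concrete: `transforms.ToTensor()` scales pixel values to the range [0, 1], and `transforms.Normalize(mean, std)` then maps each channel value `x` to `(x - mean) / std`. The pure-Python sketch below applies the same arithmetic to a single pixel. The helper `normalize_pixel` is a hypothetical name used only for illustration; it is not part of torchvision.

```python
# Per-channel statistics copied from the transform above.
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2023, 0.1994, 0.2010)


def normalize_pixel(rgb, mean=MEAN, std=STD):
    """Apply (x - mean) / std to one RGB pixel with values in [0, 1]."""
    return tuple((x - m) / s for x, m, s in zip(rgb, mean, std))


# A pixel exactly at the channel means maps to zero in every channel.
print(normalize_pixel(MEAN))
# A pure white pixel ends up roughly 2.5 to 2.8 standard deviations above zero.
print(normalize_pixel((1.0, 1.0, 1.0)))
```

After this step, each channel is roughly centered at zero with unit scale, which tends to make optimization better behaved.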



There are a few extra steps to enable training with federated learning. In particular, we need to

2. Create a sharder, which defines a mapping from examples in the training data to clients. In other words, a sharder groups rows of data into client datasets and returns a list of lists of examples. FLSim provides a number of sharding strategies, such as random or column-based sharding. 
In this tutorial, we use sequential sharding, which assigns the first `examples_per_shard` rows to user 0, the next `examples_per_shard` rows to user 1, and so on. 

3. Create a data loader, which will shard and batchify training, eval, and test data. For each dataset, the data loader first assigns rows to clients using the sharder and then splits each client's data into batches of size `batch_size`. We choose not to drop the last batch.

4. Lastly, wrap the data loader with a data provider and return it. The data provider creates clients from the groupings in the data loader and adds metadata (e.g. number of examples/batches). Our data is now formatted such that the trainer will accept it.

Note that the concept of a client or device only applies to the training data; the eval and test data are identical to those in non-federated learning.
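The sharding and batching steps described above can be sketched in plain Python. This is a minimal illustration of the idea, assuming rows are identified by their index; `sequential_shards` and `batchify` are hypothetical helpers and not FLSim functions.

```python
def sequential_shards(num_rows, examples_per_shard):
    """Assign consecutive row indices to clients, one shard per client."""
    return [
        list(range(start, min(start + examples_per_shard, num_rows)))
        for start in range(0, num_rows, examples_per_shard)
    ]


def batchify(rows, batch_size, drop_last=False):
    """Split one client's rows into batches of at most batch_size rows."""
    batches = [rows[i : i + batch_size] for i in range(0, len(rows), batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


# Values used in this tutorial: 50k training images, 500 per user, batches of 32.
shards = sequential_shards(num_rows=50_000, examples_per_shard=500)
client_batches = batchify(shards[0], batch_size=32)

print(len(shards))              # number of simulated clients
print(len(client_batches))      # batches held by client 0
print(len(client_batches[-1]))  # size of the last, partial batch
```

With these values, the sketch yields 100 clients, each holding 16 batches; because we do not drop the last batch, the final batch of each client contains the remaining 20 examples.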

In [5]:
from flsim.data.data_sharder import SequentialSharder
from flsim.utils.example_utils import DataLoader, DataProvider

# 2. Create a sharder, which maps samples in the training data to clients.
sharder = SequentialSharder(examples_per_shard=EXAMPLES_PER_USER)

# 3. Shard and batchify training, eval, and test data.
fl_data_loader = DataLoader(
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    test_dataset=test_dataset,
    sharder=sharder,
    batch_size=LOCAL_BATCH_SIZE,
    drop_last=False,
)

# 4. Wrap the data loader with a data provider.
data_provider = DataProvider(fl_data_loader)
print(f"\nClients in total: {data_provider.num_train_users()}")

Creating FL User: 100user [00:11,  8.64user/s]
Creating FL User: 20user [00:02,  9.36user/s]
Creating FL User: 20user [00:02,  9.54user/s]


Clients in total: 100





### 2. Create the model

Now, let's see how we can create a model that is compatible with FL-training.

1. First, we define a standard, non-FL image classification PyTorch `nn.Module`. In this tutorial, we use a simple CNN with 4 convolutional layers, a group norm, and a linear layer. 

2. Create a `torch.device` and choose where the model will be allocated (CUDA or CPU).

As with the data pipeline, these steps are identical to creating a model in non-federated learning. Note that, in contrast to the usual non-federated workflow, we haven't moved the model to the device yet.

In [6]:
import torch
from flsim.utils.example_utils import SimpleConvNet

# 1. Define our model, a simple CNN.
model = SimpleConvNet(in_channels=3, num_classes=10)

# 2. Choose where the model will be allocated.
cuda_enabled = torch.cuda.is_available() and USE_CUDA
device = torch.device("cuda:0" if cuda_enabled else "cpu")

model, device

(SimpleConvNet(
   (layers): ModuleList(
     (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
     (1-3): 3 x Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
   )
   (gn_relu): Sequential(
     (0): GroupNorm(32, 32, eps=1e-05, affine=True)
     (1): ReLU()
     (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
   )
   (dropout): Dropout(p=0, inplace=False)
   (fc): Linear(in_features=288, out_features=10, bias=True)
 ),
 device(type='cpu'))
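The `in_features=288` of the final linear layer can be checked by hand from the printed layer parameters. The sketch below uses the standard output-size formulas for `Conv2d` and `MaxPool2d`, and it assumes that every convolution is followed by the group norm, ReLU, and max-pool block; that ordering is an assumption about the forward pass, not something the printed module shows directly.

```python
def conv_out(size, kernel=3, stride=1, padding=2):
    """Spatial output size of a Conv2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2, stride=2):
    """Spatial output size of a MaxPool2d layer (no padding, floor mode)."""
    return (size - kernel) // stride + 1


size = 32  # IMAGE_SIZE
for _ in range(4):  # four convolutional layers
    # Assumption: each convolution is followed by GroupNorm/ReLU/MaxPool.
    size = pool_out(conv_out(size))

out_channels = 32
flat_features = out_channels * size * size
print(size, flat_features)
```

Under that assumption, the spatial size shrinks from 32 to 3 over the four layers, and 32 channels times 3x3 positions gives the 288 input features of the linear layer.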

As with the data pipeline, there are a few extra steps that we need to take to make sure that our model is compatible with FL. In particular, we need to

3. Wrap the PyTorch module with the FLSim `FLModel`, an abstraction of an FL-friendly model class that is accepted by the trainer and handles metric collection as well as the forward pass for both training and evaluation. We can recover our `nn.Module` by calling `FLModel.fl_get_module()`.

4. Move the model to GPU and enable CUDA if desired. `FLModel.fl_cuda()` internally calls `model.to(device)` to move the model to GPU.

In [7]:
from flsim.utils.example_utils import FLModel

# 3. Wrap the model with FLModel.
global_model = FLModel(model, device)
assert global_model.fl_get_module() is model

# 4. Move the model to GPU and enable CUDA if desired.
if cuda_enabled:
    global_model.fl_cuda()

### 3. Metrics Reporting

After having created our data pipeline and FL model, we will now create our metrics reporter. 
The metrics reporter allows us to collect, evaluate, and report relevant training, aggregation, and evaluation/test metrics as well as log them onto TensorBoard.



In [8]:
from flsim.interfaces.metrics_reporter import Channel
from flsim.utils.example_utils import MetricsReporter

# Create a metric reporter.
metrics_reporter = MetricsReporter([Channel.TENSORBOARD, Channel.STDOUT])

There are three functions that are of particular interest:

1. `compute_scores` computes the metrics of interest for both training and aggregation (if desired) as well as evaluation/test.

2. `create_eval_metrics` creates a dictionary that stores the value for each eval metric. 

3. `compare_metrics` compares the current eval metrics that are returned by `create_eval_metrics` to the best eval metrics so far.


For this tutorial, our only metric of interest is top-1 accuracy. In general, as with the data loading and model, you should write your own metrics reporter depending on the task. For example, if you are running an NLP task you may want to have your metrics reporter track perplexity as well.
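As a rough illustration of what such a reporter has to compute, the sketch below implements top-1 accuracy and a "is this the best result so far?" comparison in plain Python. The helpers `top1_accuracy` and `is_better` are hypothetical names that only mirror the roles of `compute_scores`, `create_eval_metrics`, and `compare_metrics`; they are not the FLSim implementations.

```python
def top1_accuracy(predicted_labels, true_labels):
    """Percentage of examples whose predicted label equals the true label."""
    if not true_labels:
        raise ValueError("need at least one example")
    correct = sum(p == t for p, t in zip(predicted_labels, true_labels))
    return 100.0 * correct / len(true_labels)


def is_better(current_metrics, best_metrics):
    """Return True if the current eval metrics beat the best seen so far."""
    if best_metrics is None:  # first evaluation: nothing to compare against
        return True
    return current_metrics["Accuracy"] > best_metrics["Accuracy"]


# Three of the four predictions below are correct.
eval_metrics = {"Accuracy": top1_accuracy([3, 5, 1, 1], [3, 5, 2, 1])}
print(eval_metrics)
print(is_better(eval_metrics, best_metrics={"Accuracy": 50.0}))
```

For a different task, only the scoring function and the comparison key would need to change, for example tracking perplexity and preferring lower values.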

In [9]:
import inspect

if VERBOSE:
    print(inspect.getsource(MetricsReporter.compute_scores))
    print(inspect.getsource(MetricsReporter.create_eval_metrics))
    print(inspect.getsource(MetricsReporter.compare_metrics))

### 4. Hyperparameters

We represent the hyperparameters for FL training as a JSON-style config and convert it to an OmegaConf config before passing it to the FL trainer.

In particular, we specify a FedAvg implementation with a server-side learning rate and momentum (`base_fed_avg_with_lr`) and 5 users per round.
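To make the aggregation step more concrete, here is a simplified sketch of a FedAvg-style server update with a server learning rate and momentum, written for a flat list of weights. It assumes that client deltas are weighted by the number of local examples and applied like an SGD-with-momentum step. The function `server_step` is hypothetical and is not FLSim's implementation.

```python
def server_step(global_w, client_ws, num_examples, lr=1.0, momentum=0.0, velocity=None):
    """One simplified server update on a flat list of weights.

    Each client delta is (global - local); deltas are averaged with weights
    proportional to the number of local examples, then applied like an
    SGD-with-momentum step.
    """
    total = sum(num_examples)
    avg_delta = [
        sum(n * (g - w[i]) for w, n in zip(client_ws, num_examples)) / total
        for i, g in enumerate(global_w)
    ]
    if velocity is None:
        velocity = [0.0] * len(global_w)
    velocity = [momentum * v + d for v, d in zip(velocity, avg_delta)]
    new_w = [g - lr * v for g, v in zip(global_w, velocity)]
    return new_w, velocity


# Two clients with 1 and 3 examples; lr=1 and momentum=0 reduce to plain FedAvg,
# i.e. the example-weighted mean of the client weights.
new_w, velocity = server_step([0.0], [[1.0], [3.0]], num_examples=[1, 3])
print(new_w)
```

With a learning rate of 1 and no momentum, the update simply replaces the global weights with the weighted average of the client weights; a larger learning rate or non-zero momentum moves the global model further along the averaged direction.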

In [10]:
import flsim.configs
from flsim.utils.config_utils import fl_config_from_json
from omegaconf import OmegaConf

json_config = {
    "trainer": {
        "_base_": "base_sync_trainer",
        "server": {
            "_base_": "base_sync_server",
            # there are different types of aggregators:
            # fed_avg doesn't require an lr, while others such as fed_avg_with_lr or fed_adam do
            "server_optimizer": {
                "_base_": "base_fed_avg_with_lr",
                "lr": 2.13,
                "momentum": 0.9
            },
            # type of user selection sampling
            "active_user_selector": {"_base_": "base_uniformly_random_active_user_selector"},
        },
        "client": {
            # number of client's local epoch
            "epochs": 1,
            "optimizer": {
                "_base_": "base_optimizer_sgd",
                # client's local learning rate
                "lr": 0.01,
                # client's local momentum
                "momentum": 0,
            },
        },
        # number of users per round for aggregation
        "users_per_round": 5,
        # total number of global epochs
        # total #rounds = ceil(total_users / users_per_round) * epochs
        "epochs": 1,
        # frequency of reporting train metrics
        "train_metrics_reported_per_epoch": 100,
        # frequency of evaluation per epoch
        "eval_epoch_frequency": 1,
        "do_eval": True,
        # should we report train metrics after global aggregation
        "report_train_metrics_after_aggregation": True,
    }
}
cfg = fl_config_from_json(json_config)
if VERBOSE: print(OmegaConf.to_yaml(cfg))

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with initialize(config_path=None):
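The round-count formula from the config comment can be checked with a few lines of arithmetic. The values below are taken from this tutorial (100 training clients, 5 users per round, 1 epoch); the variable names are only illustrative.

```python
import math

total_users = 100     # reported by the data provider
users_per_round = 5   # "users_per_round" in the trainer config
epochs = 1            # "epochs" in the trainer config

# total #rounds = ceil(total_users / users_per_round) * epochs
rounds_per_epoch = math.ceil(total_users / users_per_round)
total_rounds = rounds_per_epoch * epochs
print(rounds_per_epoch, total_rounds)
```

This gives 20 rounds in total, which matches the `0/20` shown by the progress bar when training starts.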


### 5. Training
Recall that we already built the data provider and created a model compatible with FL training. 
We also initialized a metrics reporter and set our desired hyperparameters.

Now, we only need to instantiate the trainer with the model and hyperparameter config we defined earlier to launch the FL training flow. We run FL training with the above JSON config and utilize `eval_score` to store the final evaluation metrics.

In [11]:
from hydra.utils import instantiate

# Instantiate the trainer.
trainer = instantiate(cfg.trainer, model=global_model, cuda_enabled=cuda_enabled)   

# Launch FL training.
final_model, eval_score = trainer.train(
    data_provider=data_provider,
    metrics_reporter=metrics_reporter,
    num_total_users=data_provider.num_train_users(),
    distributed_world_size=1
)

Round:   0%|          | 0/20 [00:00<?, ?round/s]

(epoch = 1, round = 1, global round = 1), Loss/Aggregation: 3.599871537089348
(epoch = 1, round = 1, global round = 1), Accuracy/Aggregation: 10.48
(epoch = 1, round = 2, global round = 2), Loss/Aggregation: 10.437564879655838
(epoch = 1, round = 2, global round = 2), Accuracy/Aggregation: 9.12
(epoch = 1, round = 3, global round = 3), Loss/Aggregation: 2.6442373767495155
(epoch = 1, round = 3, global round = 3), Accuracy/Aggregation: 22.84
(epoch = 1, round = 4, global round = 4), Loss/Aggregation: 10.234052801132203
(epoch = 1, round = 4, global round = 4), Accuracy/Aggregation: 10.0
(epoch = 1, round = 5, global round = 5), Loss/Aggregation: 3.218816262483597
(epoch = 1, round = 5, global round = 5), Accuracy/Aggregation: 11.0
(epoch = 1, round = 6, global round = 6), Loss/Aggregation: 2.4727012321352957
(epoch = 1, round = 6, global round = 6), Accuracy/Aggregation: 19.28
(epoch = 1, round = 7, global round = 7), Loss/Aggregation: 2.263486534357071
(epoch = 1, round = 7, global round = 7), Accuracy/Aggregation: 14.76
(epoch = 1, round = 8, global round = 8), Loss/Aggregation: 2.36442711353302
(epoch = 1, round = 8, global round = 8), Accuracy/Aggregation: 12.76
(epoch = 1, round = 9, global round = 9), Loss/Aggregation: 2.4869251787662505
(epoch = 1, round = 9, global round = 9), Accuracy/Aggregation: 10.36
(epoch = 1, round = 10, global round = 10), Loss/Aggregation: 2.3377160131931305
(epoch = 1, round = 10, global round = 10), Accuracy/Aggregation: 12.0
(epoch = 1, round = 11, global round = 11), Loss/Aggregation: 2.4670733541250227
(epoch = 1, round = 11, global round = 11), Accuracy/Aggregation: 14.6
(epoch = 1, round = 12, global round = 12), Loss/Aggregation: 2.4311916500329973
(epoch = 1, round = 12, global round = 12), Accuracy/Aggregation: 10.24
(epoch = 1, round = 13, global round = 13), Loss/Aggregation: 2.4580227971076964
(epoch = 1, round = 13, global round = 13), Accuracy/Aggregation: 11.96
(epoch = 1, round = 14, global round = 14), Loss/Aggregation: 2.341620495915413
(epoch = 1, round = 14, global round = 14), Accuracy/Aggregation: 13.0
(epoch = 1, round = 15, global round = 15), Loss/Aggregation: 2.285178118944168
(epoch = 1, round = 15, global round = 15), Accuracy/Aggregation: 14.32
(epoch = 1, round = 16, global round = 16), Loss/Aggregation: 2.2919827133417128
(epoch = 1, round = 16, global round = 16), Accuracy/Aggregation: 14.68
(epoch = 1, round = 17, global round = 17), Loss/Aggregation: 2.800315809249878
(epoch = 1, round = 17, global round = 17), Accuracy/Aggregation: 15.84
(epoch = 1, round = 18, global round = 18), Loss/Aggregation: 2.4768665820360183
(epoch = 1, round = 18, global round = 18), Accuracy/Aggregation: 12.28
(epoch = 1, round = 19, global round = 19), Loss/Aggregation: 2.322485500574112
(epoch = 1, round = 19, global round = 19), Accuracy/Aggregation: 15.32

Round:  95%|█████████▌| 19/20 [01:54<00:06,  6.03s/round]
Epoch:   0%|          | 0/1 [01:54<?, ?epoch/s]


After training finishes, we evaluate the model and report the accuracy on the test set before finishing this tutorial.


In [12]:
# We can now test our trained model.
trainer.test(
    data_provider=data_provider,
    metrics_reporter=MetricsReporter([Channel.STDOUT]),
)

Running (epoch = 1, round = 1, global round = 1) for Test
(epoch = 1, round = 1, global round = 1), Loss/Test: 1.4767778711393476
(epoch = 1, round = 1, global round = 1), Accuracy/Test: 46.96


{'Accuracy': 46.96}

## Summary

In this tutorial, we first showed how to get the data. We then built a data provider by sharding the data to simulate multiple client devices, each with their own data, and splitting each client's data into batches. 
We defined a simple CNN as our model, wrapped it with a model compatible with FL training, and moved it to GPU when CUDA is available. 
Lastly, we set the hyperparameters for FL training, launched the training flow, and evaluated our model.

### Additional resources

- For a more in-depth understanding of this tutorial, check out [example_utils.py](https://github.com/facebookresearch/FLSim/blob/main/flsim/utils/example_utils.py) where we define the data loader, data provider, simple CNN, `FLModel`, and metrics reporter that we use in this tutorial.

- [FLSim tutorials](https://github.com/facebookresearch/FLSim/tree/main/tutorials) - check out our other tutorial on sentiment classification.

- Kairouz et al. (2021): [Advances and Open Problems in Federated Learning](https://arxiv.org/pdf/1912.04977.pdf). As the title suggests, an in-depth overview of advances and open problems in FL.

- If you're interested in federated learning with differential privacy, take a look at [Opacus](https://opacus.ai/), a library that enables training PyTorch models with differential privacy. 
You can find a blog post introducing Opacus [here](https://ai.facebook.com/blog/introducing-opacus-a-high-speed-library-for-training-pytorch-models-with-differential-privacy/).

