# FedLib: Simulating MTFL using the FedLib virtual federated environment

## Importing supporting libraries
This notebook demonstrates a model implementation on the PyTorch back end.

First, we import the supporting libraries.

In [1]:
import torch
import numpy as np
import copy
from fedlib.utils import get_logger
from fedlib.ve.mtfl import MTFLEnv
from fedlib.lib import Server, Client
from fedlib.networks import resnet20
from fedlib.lib.sampler import random_sampler
from fedlib.lib.algo.torch.mtfl import Trainer
from fedlib.datasets import partition_data, get_dataloader,get_client_dataloader


## Define arguments
Here we define the arguments. To keep the example intuitive, this demo stores all the parameters in a dictionary in the following code block.
We also provide APIs that let you define your arguments in a `*.yaml` file.

In [2]:
# Experiment configuration, collected in one dictionary that the later cells
# unpack into the Server and Client constructors.
logger = get_logger()
args = {
    # --- federation setup ---
    "n_clients": 10,
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "sample_fn": random_sampler,
    "trainer": Trainer(logger),
    "communicator": None,
    # Placeholder: filled in once the global test loader has been built.
    "test_dataset": None,
    # --- data ---
    "partition": "noniid-labeldir",
    "dataset": "mnist",
    "datadir": "./data",
    "beta": 0.5,
    # --- optimisation ---
    "batch_size": 64,
    "lr": 0.01,
    "optimizer": "SGD",
    "lr_scheduler": "ExponentialLR",
}

Load the test dataset for the server and pass it in as an argument.

In [3]:
# Partition the dataset for the configured number of clients. Besides the raw
# train/test arrays, the call yields the per-client index map and the
# per-client class counts.
(
    X_train,
    y_train,
    X_test,
    y_test,
    net_dataidx_map,
    traindata_cls_counts,
) = partition_data(
    args["dataset"],
    args["datadir"],
    args["partition"],
    args["n_clients"],
    beta=args["beta"],
)
n_classes = np.unique(y_train).size

# Global loaders/datasets; only the test loader is handed to the server.
train_dl_global, test_dl_global, train_ds_global, test_ds_global = get_dataloader(
    args["dataset"],
    args["datadir"],
    args["batch_size"],
    32,  # NOTE(review): presumably the test batch size -- confirm against get_dataloader
)
args["test_dataset"] = test_dl_global

INFO:root:Data statistics: {0: {0: 1511, 1: 1125, 2: 1939, 3: 159, 4: 1941}, 1: {0: 1378, 1: 1366, 2: 737, 3: 2768}, 2: {0: 267, 1: 378, 2: 2572, 3: 307, 4: 588, 5: 20, 6: 50, 7: 780, 8: 44, 9: 226}, 3: {0: 206, 1: 130, 2: 216, 3: 527, 4: 1110, 5: 1342, 6: 36, 7: 2066, 8: 317, 9: 55}, 4: {0: 176, 1: 759, 2: 107, 3: 48, 4: 89, 5: 726, 6: 770, 7: 496, 8: 1510, 9: 177}, 5: {0: 876, 2: 87, 3: 518, 4: 1295, 5: 52, 6: 12, 7: 737, 8: 2233, 9: 2996}, 6: {0: 233, 1: 688, 2: 39, 3: 549, 4: 158, 5: 318, 6: 593, 7: 680, 8: 365, 9: 1812}, 7: {0: 476, 1: 175, 3: 2, 5: 690, 6: 949, 7: 433, 8: 992, 9: 633}, 8: {0: 405, 1: 279, 2: 253, 3: 63, 4: 248, 5: 1397, 6: 1642, 7: 1073, 8: 389, 9: 50}, 9: {0: 395, 1: 1842, 2: 8, 3: 1190, 4: 413, 5: 876, 6: 1866, 8: 1}}


## Define Model Architecture
The model must contain an `encoder`, a `decoder`, and a `predictor`.

In [4]:
from torch import nn
class autoencoder(nn.Module):
    """Convolutional autoencoder with an additional classification head.

    The model exposes three sub-modules:

    * ``encoder``   -- maps a ``(b, 1, 28, 28)`` image batch to a
      ``(b, 8, 2, 2)`` representation.
    * ``predictor`` -- linear layer mapping the flattened representation
      (8 * 2 * 2 = 32 features) to ``n_classes`` logits.
    * ``decoder``   -- maps the representation back to ``(b, 1, 28, 28)``.

    The shape comments below assume single-channel 28x28 inputs.

    Args:
        n_classes: Number of logits produced by ``predictor``. Defaults to
            10, which reproduces the original hard-coded behaviour.
    """

    def __init__(self, n_classes: int = 10):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=3, padding=1),  # b, 16, 10, 10
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2),  # b, 16, 5, 5
            nn.Conv2d(16, 8, 3, stride=2, padding=1),  # b, 8, 3, 3
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=1)  # b, 8, 2, 2
        )
        # 32 input features == 8 channels * 2 * 2 spatial positions.
        self.predictor = nn.Linear(in_features=32, out_features=n_classes, bias=True)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 16, 3, stride=2),  # b, 16, 5, 5
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1),  # b, 8, 15, 15
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 2, stride=2, padding=1),  # b, 1, 28, 28
            nn.Tanh()
        )

    def forward(self, x):
        """Return the predictor logits for ``x``.

        Only ``encoder`` and ``predictor`` take part in the forward pass;
        ``decoder`` is not called here.

        Args:
            x: Image batch accepted by ``encoder``.

        Returns:
            Tensor of shape ``(batch, n_classes)``.
        """
        x = self.encoder(x)
        # Flatten everything except the batch dimension for the linear head.
        x = x.view(x.size(0), -1)
        x = self.predictor(x)
        return x

## Test the model

In [5]:
# Smoke-test the model on a random batch of ten 1x28x28 inputs: run the
# encoder/decoder round trip and the full prediction path, then show shapes.
model = autoencoder()
x = torch.rand([10, 1, 28, 28])
representation = model.encoder(x)
x_ = model.decoder(representation)
pred = model(x)
print(x.shape, x_.shape, pred.shape)

# Fail loudly (instead of relying on eyeballing the printout) if the
# architecture stops producing the expected shapes.
assert x_.shape == x.shape, "decoder output must match the input shape"
assert pred.shape == (x.shape[0], model.predictor.out_features), \
    "prediction must be (batch, number of predictor outputs)"

torch.Size([10, 1, 28, 28]) torch.Size([10, 1, 28, 28]) torch.Size([10, 10])


## Create server and clients objects
Here we use the arguments defined above to create the server and the clients.

In [6]:

# The encoder of `model` is what gets passed to the server as the global model.
args["global_model"] = model.encoder
server = Server(**args)

# One training loader per client, looked up by client id in the loop below.
data_loaders = get_client_dataloader(
    args["dataset"],
    args["datadir"],
    args["batch_size"],
    32,  # NOTE(review): presumably the test batch size -- confirm against get_client_dataloader
    net_dataidx_map,
)

criterion_pred = torch.nn.CrossEntropyLoss()
criterion_rep = torch.nn.MSELoss()

# Set after the Server has been built, so only the clients receive it.
args["criterion"] = {
    "criterion_rep": criterion_rep,
    "criterion_pred": criterion_pred,
}

clients = {}
# `client_id` rather than `id`: the loop variable lives in the notebook's
# global namespace and must not shadow the built-in `id()` for later cells.
for client_id in range(args["n_clients"]):
    args["id"] = client_id
    args["trainloader"] = data_loaders[client_id]
    # Each client is given its own deep copy of the full model.
    args["model"] = copy.deepcopy(model)
    clients[client_id] = Client(**args)




## Create simulator

The simulator creates a virtual federated learning environment and runs the server and the clients on a single device.

In [7]:
# Take the client count from the shared configuration so it cannot drift
# from the number of clients that were actually created.
simulator = MTFLEnv(
    server=server,
    clients=clients,
    communication_rounds=10,
    n_clients=args["n_clients"],
    sample_rate=0.1,
)

## Run simulator
Start training through the user-facing API, `simulator.run`.

In [8]:
# Launch the simulation with local_epochs=2; in the log below each optimised
# client reports two epochs (0 and 1) per round.
simulator.run(local_epochs=2)

INFO:root:*******starting rounds 1 optimization******
INFO:root:optimize the 4-th clients
INFO:root:Update Epoch: 0 	Loss: 2.458877
INFO:root:Update Epoch: 0 	Loss: 2.433209
INFO:root:Update Epoch: 0 	Loss: 2.365509
INFO:root:Update Epoch: 0 	Loss: 2.318931
INFO:root:Update Epoch: 0 	Loss: 2.289380
INFO:root:Update Epoch: 0 	Loss: 2.270249
INFO:root:Update Epoch: 0 	Loss: 2.171733
INFO:root:Update Epoch: 0 	Loss: 2.149680
INFO:root:Epoch: 0	Loss: 2.291434
INFO:root:Update Epoch: 1 	Loss: 2.060631
INFO:root:Update Epoch: 1 	Loss: 2.088791
INFO:root:Update Epoch: 1 	Loss: 2.019820
INFO:root:Update Epoch: 1 	Loss: 1.965572
INFO:root:Update Epoch: 1 	Loss: 2.101911
INFO:root:Update Epoch: 1 	Loss: 1.849362
INFO:root:Update Epoch: 1 	Loss: 1.978142
INFO:root:Update Epoch: 1 	Loss: 2.040374
INFO:root:Epoch: 1	Loss: 2.177208
INFO:root:*******starting rounds 2 optimization******
INFO:root:optimize the 9-th clients
INFO:root:Update Epoch: 0 	Loss: 2.450994
INFO:root:Update Epoch: 0 	Loss: 2.417

KeyboardInterrupt: 