# FedLib: Simulating Multi-task Federated Learning using FedLib virtual Federated environment

## Importing supporting libraries
This notebook demonstrates a model implementation on the PyTorch back-end.

First, we import the supporting libraries.

In [None]:
import torch
import numpy as np
import copy
from fedlib.utils import get_logger
from fedlib.ve.mtfl import MTFLEnv
from fedlib.lib import Server, Client
from fedlib.networks import resnet20
from fedlib.lib.sampler import random_sampler
from fedlib.lib.algo.torch.mtfl import Trainer
from fedlib.datasets import partition_data, get_dataloader,get_client_dataloader


## Define arguments
Here we define the arguments. To keep the example intuitive, this demo stores all the parameters in a dictionary in the following code block.
We also provide APIs that let you define your arguments in a `*.yaml` file.

In [None]:
# Shared configuration dictionary. It is later unpacked as keyword arguments
# into both Server(...) and Client(...), so every key here reaches both.
logger = get_logger()
args = {
    "n_clients": 10,
    # Use the GPU when one is visible to PyTorch, otherwise fall back to CPU.
    "device": torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'sample_fn': random_sampler,
    'trainer': Trainer(logger),
    'communicator': None,
    # Placeholder: replaced by the global test loader once the data is loaded.
    "test_dataset": None,
    "partition": "noniid-labeldir",
    "dataset": "mnist",
    "datadir": "./data",
    "beta": 0.5,
    "batch_size": 64,
    "lr": 0.01,
    "optimizer": "SGD",
    "lr_scheduler": "ExponentialLR",
}

Load the test dataset for the server and pass it in as an argument.

In [None]:
# Split the training data across the clients according to the configured
# partition scheme; net_dataidx_map is reused later to build per-client loaders.
(X_train, y_train, X_test, y_test,
 net_dataidx_map, traindata_cls_counts) = partition_data(
    args["dataset"],
    args["datadir"],
    args['partition'],
    args['n_clients'],
    beta=args['beta'],
)
n_classes = np.unique(y_train).size

# Global (un-partitioned) loaders and datasets.
# NOTE(review): the trailing positional 32 is presumably the test batch size;
# confirm against get_dataloader's signature.
(train_dl_global, test_dl_global,
 train_ds_global, test_ds_global) = get_dataloader(
    args["dataset"], args["datadir"], args["batch_size"], 32
)

# The server receives the global test loader through the shared args dict.
args["test_dataset"] = test_dl_global

## Define the model architecture
The model must contain an `encoder`, a `decoder`, and a `predictor`.

In [None]:
from torch import nn

class autoencoder(nn.Module):
    """Convolutional auto-encoder with an additional linear prediction head.

    Sub-modules (shapes are for a ``(b, 1, 28, 28)`` input batch):
        encoder:   ``(b, 1, 28, 28)`` -> ``(b, 8, 2, 2)``
        predictor: linear map from the 32 flattened encoder features to 10 outputs
        decoder:   ``(b, 8, 2, 2)`` -> ``(b, 1, 28, 28)``

    ``forward`` only chains ``encoder`` and ``predictor``; ``decoder`` is
    exposed as an attribute and has to be called explicitly.
    """

    def __init__(self):
        super().__init__()

        # Modules are built in encoder -> predictor -> decoder order so that
        # parameter initialisation and state_dict key order stay stable.
        downsample = [
            nn.Conv2d(1, 16, kernel_size=3, stride=3, padding=1),   # b, 16, 10, 10
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),                  # b, 16, 5, 5
            nn.Conv2d(16, 8, kernel_size=3, stride=2, padding=1),   # b, 8, 3, 3
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=1),                  # b, 8, 2, 2
        ]
        self.encoder = nn.Sequential(*downsample)

        # 8 channels * 2 * 2 spatial positions = 32 input features.
        self.predictor = nn.Linear(32, 10, bias=True)

        upsample = [
            nn.ConvTranspose2d(8, 16, kernel_size=3, stride=2),             # b, 16, 5, 5
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(16, 8, kernel_size=5, stride=3, padding=1),  # b, 8, 15, 15
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(8, 1, kernel_size=2, stride=2, padding=1),   # b, 1, 28, 28
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*upsample)

    def forward(self, x):
        """Encode ``x``, flatten per sample, and return the predictor output."""
        features = self.encoder(x)
        batch_size = features.size(0)
        flattened = features.view(batch_size, -1)
        return self.predictor(flattened)



## Test the model

In [None]:
# Smoke test: push a random batch of ten 1x28x28 inputs through each part of
# the model and print the input, reconstruction and prediction shapes.
model = autoencoder()
x = torch.rand(10, 1, 28, 28)

representation = model.encoder(x)       # latent representation
x_ = model.decoder(representation)      # reconstruction from the latent
pred = model(x)                         # encoder followed by predictor

print(x.shape, x_.shape, pred.shape)

## Create server and clients objects
Here we use the arguments defined above to create the server and the clients.

In [None]:

# The server is handed the encoder sub-module of `model` as its global model.
# This is the same object as `model.encoder`, not a copy.
args["global_model"] = model.encoder
server = Server(**args)
clients = {}

# One data loader per client, built from the partition index map.
# NOTE(review): the positional 32 is presumably the test batch size; confirm
# against get_client_dataloader's signature.
data_loaders = get_client_dataloader(args["dataset"], args["datadir"], args['batch_size'], 32, net_dataidx_map)

# Two losses: MSE under "criterion_rep" and cross-entropy under
# "criterion_pred". They are added to args only after the server was built,
# so only the clients receive them.
criterion_pred = torch.nn.CrossEntropyLoss()
criterion_rep = torch.nn.MSELoss()

args["criterion"] = {
    "criterion_rep": criterion_rep,
    "criterion_pred": criterion_pred,
}

# `client_id` rather than `id`, so the builtin id() is not shadowed for the
# rest of the notebook session.
for client_id in range(args["n_clients"]):
    args["id"] = client_id
    args["trainloader"] = data_loaders[client_id]
    # Every client trains its own independent deep copy of the full model.
    args["model"] = copy.deepcopy(model)
    clients[client_id] = Client(**args)




## Create simulator

The simulator simulates a virtual federated learning environment and runs the server and clients on a single device.

In [None]:
# Wire the server and the client dictionary into the virtual MTFL environment.
simulator = MTFLEnv(
    server=server,
    clients=clients,
    communication_rounds=10,
    n_clients=10,
    sample_rate=0.1,
)

## Run simulator
Start the simulation through the user-facing API `simulator.run`.

In [None]:
# Launch the simulation configured above, passing local_epochs=2.
simulator.run(local_epochs=2)

In [None]:
a =[(1,0.010016),(2,0.010016),(3,0.010617),(4,0.019431),(5,0.023638),(6,0.034355),(7,0.034555),(8,0.037560),(9,0.041667),(10,0.046274),(11,0.043069),(12,0.055889),(13,0.061899),(14,0.063802),(15,0.068910),(16,0.075921),(17,0.075521),(18,0.073117),(19,0.078025),(20,0.087740),(21,0.082332),(22,0.086839),(23,0.084836),(24,0.086438),(25,0.086038),(26,0.103265),(27,0.102664),(28,0.102865),(29,0.096454),(30,0.098357),(31,0.106771),(32,0.108273),(33,0.107472),(34,0.108173),(35,0.110477),(36,0.116086),(37,0.121695),(38,0.118590),(39,0.121494),(40,0.121394),(41,0.124700),(42,0.124800),(43,0.128105),(44,0.127103),(45,0.133514),(46,0.125701),(47,0.135317),(48,0.130609),(49,0.133313),(50,0.144631),(51,0.144431),(52,0.140224),(53,0.143530),(54,0.138822),(55,0.141126),(56,0.144732),(57,0.150942),(58,0.150841),(59,0.151342),(60,0.157151),(61,0.159655),(62,0.155849),(63,0.158454),(64,0.155349),(65,0.158353),(66,0.165465),(67,0.161759),(68,0.166266),(69,0.162460),(70,0.165665),(71,0.166066),(72,0.168570),(73,0.164363),(74,0.161558),(75,0.170172),(76,0.170873),(77,0.174980),(78,0.173478),(79,0.170873),(80,0.173277),(81,0.182692),(82,0.179688),(83,0.176182),(84,0.174379),(85,0.178686),(86,0.180889),(87,0.179487),(88,0.174780),(89,0.184595),(90,0.178285),(91,0.181090),(92,0.182692),(93,0.179988),(94,0.186098),(95,0.184696),(96,0.179587),(97,0.185196),(98,0.189002),(99,0.183293),(100,0.188101),(101,0.189002),(102,0.182091),(103,0.185196),(104,0.187400),(105,0.184595),(106,0.196414),(107,0.189103),(108,0.191707),(109,0.193910),(110,0.192808),(111,0.195613),(112,0.193610),(113,0.183594),(114,0.195813),(115,0.198217),(116,0.195713),(117,0.197416),(118,0.193309),(119,0.203726),(120,0.198518),(121,0.201422),(122,0.207632),(123,0.201322),(124,0.196815),(125,0.201522),(126,0.202123),(127,0.203425),(128,0.206030),(129,0.203225),(130,0.205529),(131,0.206530),(132,0.208033),(133,0.206731),(134,0.204127),(135,0.206731),(136,0.205529),(137,0.209034),(138,0.207732),(139,0.207732),(140,0.210938),(141
,0.210537),(142,0.209936),(143,0.210437),(144,0.213942),(145,0.211438),(146,0.212440),(147,0.212440),(148,0.216346),(149,0.215645),(150,0.215745),(151,0.211739),(152,0.212941),(153,0.213742),(154,0.213842),(155,0.208934),(156,0.212139),(157,0.210136),(158,0.214543),(159,0.219451),(160,0.217147),(161,0.209836),(162,0.215244),(163,0.214944),(164,0.218249),(165,0.217648),(166,0.218249),(167,0.220152),(168,0.218650),(169,0.217748),(170,0.217648),(171,0.210437),(172,0.217548),(173,0.219050),(174,0.222957),(175,0.224960),(176,0.217448),(177,0.221454),(178,0.216747),(179,0.222957),(180,0.220553),(181,0.225060),(182,0.223858),(183,0.223458),(184,0.216647),(185,0.222556),(186,0.217248),(187,0.224259),(188,0.221554),(189,0.221254),(190,0.220052),(191,0.223157),(192,0.220453),(193,0.224659),(194,0.224960),(195,0.219050),(196,0.215745),(197,0.216446),(198,0.223357),(199,0.223458),(200,0.220853),(201,0.224159),(202,0.221354),(203,0.223958),(204,0.224058),(205,0.217949),(206,0.221154),(207,0.223958),(208,0.216947),(209,0.222456),(210,0.228165),(211,0.230369),(212,0.223357),(213,0.218950),(214,0.221354),(215,0.225761),(216,0.221655),(217,0.225461),(218,0.226963),(219,0.229167),(220,0.227764),(221,0.225861),(222,0.230268),(223,0.222857),(224,0.224058),(225,0.224459),(226,0.231571),(227,0.229267),(228,0.226863),(229,0.232071),(230,0.227865),(231,0.228065),(232,0.224259),(233,0.229968),(234,0.228466),(235,0.230569),(236,0.231671),(237,0.235276),(238,0.228866),(239,0.230369),(240,0.229768),(241,0.227664),(242,0.225761),(243,0.227264),(244,0.235877),(245,0.233073),(246,0.231771),(247,0.231671),(248,0.236478),(249,0.229067),(250,0.224259),(251,0.230469),(252,0.229768),(253,0.228466),(254,0.237580),(255,0.238081),(256,0.235276),(257,0.239483),(258,0.232372),(259,0.234375),(260,0.236278),(261,0.229267),(262,0.234375),(263,0.233674),(264,0.238081),(265,0.233974),(266,0.227965),(267,0.236078),(268,0.237380),(269,0.233474),(270,0.239083),(271,0.235777),(272,0.237179),(273,0.239083),(274,0.23
6478),(275,0.234675),(276,0.236078),(277,0.236679),(278,0.235477),(279,0.240585),(280,0.232772),(281,0.237580),(282,0.234976),(283,0.237580),(284,0.234776),(285,0.238482),(286,0.236979),(287,0.244191),(288,0.238682),(289,0.234375),(290,0.242588),(291,0.242688),(292,0.237079),(293,0.240284),(294,0.245393),(295,0.240084),(296,0.243690),(297,0.231470),(298,0.244591),(299,0.241787),(300,0.242288),(301,0.238281),(302,0.241486),(303,0.235877),(304,0.237480),(305,0.238081),(306,0.238582),(307,0.241486),(308,0.242188),(309,0.240385),(310,0.235978),(311,0.242688),(312,0.241486),(313,0.238782),(314,0.238782),(315,0.238081),(316,0.239083),(317,0.242188),(318,0.241086),(319,0.241587),(320,0.243089),(321,0.247596),(322,0.245593),(323,0.242989),(324,0.239083),(325,0.242588),(326,0.239483),(327,0.243790),(328,0.241186),(329,0.248297),(330,0.246595),(331,0.244992),(332,0.242388),(333,0.244191),(334,0.243590),(335,0.243590),(336,0.250901),(337,0.250501),(338,0.244591),(339,0.248898),(340,0.248498),(341,0.246695),(342,0.246394),(343,0.246194),(344,0.246695),(345,0.243490),(346,0.242388),(347,0.242388),(348,0.244391),(349,0.244191),(350,0.247296),(351,0.246294),(352,0.250401),(353,0.239784),(354,0.249099),(355,0.246595),(356,0.252103),(357,0.245192),(358,0.250401),(359,0.251402),(360,0.248698),(361,0.246094),(362,0.248898),(363,0.247296),(364,0.244892),(365,0.244391),(366,0.246494),(367,0.253205),(368,0.245893),(369,0.248598),(370,0.245893),(371,0.244792),(372,0.243089),(373,0.253606),(374,0.250801),(375,0.252905),(376,0.246895),(377,0.245493),(378,0.252003),(379,0.250701),(380,0.251803),(381,0.252804),(382,0.253706),(383,0.253906),(384,0.249399),(385,0.244291),(386,0.246194),(387,0.246695),(388,0.248998),(389,0.252204),(390,0.256110),(391,0.256611),(392,0.253405),(393,0.250901),(394,0.251803),(395,0.251202),(396,0.250501),(397,0.250100),(398,0.256410),(399,0.249900),(400,0.255509),(401,0.245292),(402,0.256410),(403,0.252905),(404,0.255008),(405,0.254808),(406,0.252604),(407,0.251502)
,(408,0.253906),(409,0.254808),(410,0.255809),(411,0.253906),(412,0.253105),(413,0.246595),(414,0.251803),(415,0.251603),(416,0.252204),(417,0.254107),(418,0.257712),(419,0.253305),(420,0.260116),(421,0.258714),(422,0.251903),(423,0.252604),(424,0.254607),(425,0.258814),(426,0.255208),(427,0.256711),(428,0.260317),(429,0.255809),(430,0.252504),(431,0.259014),(432,0.259916),(433,0.254607),(434,0.256711),(435,0.254507),(436,0.251903),(437,0.245893),(438,0.261518),(439,0.256210),(440,0.256210),(441,0.256611),(442,0.255308),(443,0.251002),(444,0.259215),(445,0.258213),(446,0.254207),(447,0.252404),(448,0.256811),(449,0.259315),(450,0.257011),(451,0.257712),(452,0.259215),(453,0.257512),(454,0.252804),(455,0.260317),(456,0.258714),(457,0.259716),(458,0.261518),(459,0.263522),(460,0.258514),(461,0.255909),(462,0.256210),(463,0.256510),(464,0.261318),(465,0.260417),(466,0.252704),(467,0.256310),(468,0.260417),(469,0.258113),(470,0.257512),(471,0.253706),(472,0.250601),(473,0.258013),(474,0.259215),(475,0.259315),(476,0.253405),(477,0.260016),(478,0.257412),(479,0.254307),(480,0.257412),(481,0.258814),(482,0.256110),(483,0.267628),(484,0.262620),(485,0.257512),(486,0.256110),(487,0.258313),(488,0.255709),(489,0.260016),(490,0.257913),(491,0.261619),(492,0.266526),(493,0.263221),(494,0.263722),(495,0.262220),(496,0.263522),(497,0.260016),(498,0.258914),(499,0.264924),(500,0.261018)]



# Print only the second element of every (index, value) pair, one per line.
for _step, value in a:
    print(value)

In [None]:

l = [12.133472, 11.051832, 14.131328, 14.266, 13.902304, 13.031408, 12.846048, 14.419544, 14.98216, 14.098432, 13.224656, 13.642752, 12.103536, 13.72976, 14.13528, 12.516888, 14.698832, 13.97496, 14.78284, 15.706896, 15.643472, 14.685232, 12.855568, 12.993528, 15.17564, 12.004216, 12.179992, 15.179248, 13.713744, 16.09056, 15.052384, 13.680304, 14.986608, 13.562448, 16.899536, 14.891776, 13.510264, 15.33912, 15.10424, 13.361208, 12.860784, 13.34252, 15.169296, 14.427352, 12.432376, 14.166144, 14.353368, 13.100632, 13.174512, 14.316856, 13.071, 12.27112, 12.698736, 15.270704, 11.845976, 12.505952, 11.827648, 16.17976, 15.293376, 13.908896, 14.55068, 12.569904, 14.696304, 14.809072, 13.670112, 14.40056, 14.148088, 14.937336, 14.249408, 12.347672, 13.689808, 15.529656, 13.90156, 14.52136, 15.370016, 12.95068, 13.760152, 13.640736, 13.384896, 13.940432, 12.288096, 16.022944, 13.67304, 16.3048, 14.555064, 15.112504, 13.870744, 14.602264, 13.578128, 15.910176, 13.262008, 13.884064, 14.88256, 13.72896, 12.365976, 12.424472, 11.237752, 13.702424, 15.667224, 14.390168]
# Arithmetic mean of the values in `l` (displayed as the cell output).
sum(l)/len(l)

In [None]:
from torch.utils.data import Dataset

class TinyImageNetDataset(Dataset):
    """Map-style dataset yielding ``{'image': ..., 'label': ...}`` dictionaries.

    Sample records come from ``TinyImageNetPaths(root_dir, download).paths[mode]``;
    element 0 of a record is the image path and element ``label_idx`` (1) is
    the label. With ``preload=True`` all images are read into one float32
    array of shape ``(samples_num, 64, 64, 3)`` at construction time;
    otherwise each image is read from disk in ``__getitem__``.

    NOTE(review): ``TinyImageNetPaths``, ``tqdm``, ``imageio`` and
    ``_add_channels`` are neither defined nor imported in the visible part of
    this file; they must be in scope before the class is instantiated.

    Args:
        root_dir: Passed through to ``TinyImageNetPaths``.
        mode: Key into ``paths``; labels are omitted when it is ``'test'``.
        preload: Read all images into memory up front.
        load_transform: Optional iterable of callables applied once, after
            preloading, as ``lt(img_data, label_data)``. Each returns at least
            ``(img_data, label_data)``; an optional third element is a dict
            merged into ``self.transform_results``. Ignored when ``preload``
            is false.
        transform: Optional callable applied to every sample dict.
        download: Passed through to ``TinyImageNetPaths``.
        max_samples: If given, keep a random subset of at most this many
            samples (drawn with ``np.random.permutation``).
    """

    def __init__(self, root_dir, mode='train', preload=True, load_transform=None,
                 transform=None, download=False, max_samples=None):
        tinp = TinyImageNetPaths(root_dir, download)
        self.mode = mode
        self.label_idx = 1  # from [image, id, nid, box]
        self.preload = preload
        self.transform = transform
        self.transform_results = dict()

        self.IMAGE_SHAPE = (64, 64, 3)

        self.img_data = []
        self.label_data = []

        self.max_samples = max_samples
        self.samples = tinp.paths[mode]
        self.samples_num = len(self.samples)

        if self.max_samples is not None:
            self.samples_num = min(self.max_samples, self.samples_num)
            self.samples = np.random.permutation(self.samples)[:self.samples_num]

        if self.preload:
            load_desc = "Preloading {} data...".format(mode)
            self.img_data = np.zeros((self.samples_num,) + self.IMAGE_SHAPE,
                                     dtype=np.float32)
            # `np.int` was only an alias of the builtin int and was removed in
            # NumPy 1.24; the builtin yields the same dtype on every version.
            self.label_data = np.zeros((self.samples_num,), dtype=int)
            for idx in tqdm(range(self.samples_num), desc=load_desc):
                s = self.samples[idx]
                img = imageio.imread(s[0])
                img = _add_channels(img)
                self.img_data[idx] = img
                # The 'test' split has no labels; label_data stays all zeros.
                if mode != 'test':
                    self.label_data[idx] = s[self.label_idx]

            if load_transform:
                for lt in load_transform:
                    result = lt(self.img_data, self.label_data)
                    self.img_data, self.label_data = result[:2]
                    if len(result) > 2:
                        self.transform_results.update(result[2])

    def __len__(self):
        """Return the number of samples kept after the ``max_samples`` cut."""
        return self.samples_num

    def __getitem__(self, idx):
        """Return the sample dict for ``idx`` (``label`` is None in test mode).

        NOTE(review): the non-preload branch does not pass the image through
        ``_add_channels`` the way preloading does; confirm this is intended.
        """
        if self.preload:
            img = self.img_data[idx]
            lbl = None if self.mode == 'test' else self.label_data[idx]
        else:
            s = self.samples[idx]
            img = imageio.imread(s[0])
            lbl = None if self.mode == 'test' else s[self.label_idx]
        sample = {'image': img, 'label': lbl}

        if self.transform:
            sample = self.transform(sample)
        return sample

In [None]:
import torchvision
from torchvision import transforms, datasets
# Folder-per-class dataset rooted at 'data/test'; each image is resized to 32
# and converted to a tensor when it is accessed.
train_ds = datasets.ImageFolder(
    root='data/test',
    transform=transforms.Compose([transforms.Resize(32), transforms.ToTensor()]),
)

In [None]:
from torchvision.datasets import MNIST, CIFAR10, CIFAR100, SVHN, FashionMNIST


In [None]:
import torch
# Quick sanity probes on the dataset; results are not displayed because
# neither expression is the last one in the cell.
len(train_ds)
train_ds[0][0].shape

# Sequential loader producing batches of two samples.
dl = torch.utils.data.DataLoader(dataset=train_ds, batch_size=2)

In [None]:
# Show the shape of the first image batch only, then stop iterating.
for images, labels in dl:
    print(images.shape)
    break

In [None]:
# CIFAR-10 training split read from "../src/data"; images are only converted
# to tensors and no target transform is applied.
transform_train = transforms.Compose([transforms.ToTensor()])
cifar = CIFAR10(
    root="../src/data",
    train=True,
    transform=transform_train,
    target_transform=None,
)


In [None]:
# Compare the transformed first image with the shape of its raw stored array.
first_image = cifar[0][0]
print(first_image)
print(cifar.data[0].shape)

In [None]:
# Sequential loader over CIFAR-10 in batches of two.
dl_cifar = torch.utils.data.DataLoader(dataset=cifar, batch_size=2)

In [None]:
# Show the shape of the first CIFAR-10 image batch only, then stop iterating.
for images, labels in dl_cifar:
    print(images.shape)
    break

In [None]:
import numpy as np
# Index-list selection demo: pick the elements at positions 1, 3 and 5.
a = np.array([1, 2, 3, 4, 5, 9])
idx = [1, 3, 5]
b = np.take(a, idx)
b

In [None]:
import torch.utils.data as data
import os
class TinyImageNet_truncated(data.Dataset):
    """In-memory subset of an ``ImageFolder`` dataset stored under ``root``.

    The images are read from ``<root>/train`` or ``<root>/val`` through
    ``torchvision.datasets.ImageFolder`` and fully materialised at
    construction time. ``transform`` is therefore applied once, while
    loading, and not again in ``__getitem__``.

    Args:
        root: Directory containing the ``train`` and ``val`` sub-folders.
        dataidxs: Optional iterable of sample indices to keep, in that order.
            ``None`` keeps every sample.
        train: Read ``train`` when true, ``val`` otherwise.
        transform: Passed to ``ImageFolder`` and applied at load time.
        target_transform: Stored but currently not applied.
        download: Stored but currently unused.

    Attributes:
        data: List of transformed images, one per kept index.
        target: List of class indices aligned one-to-one with ``data``.
    """

    def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False):

        self.root = root
        self.dataidxs = dataidxs
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.download = download

        self.data, self.target = self.__build_truncated_dataset__()

    def __build_truncated_dataset__(self):
        """Load the selected samples and return ``(images, labels)`` as lists."""
        split = "train" if self.train else "val"
        data_dir = os.path.join(self.root, split)
        ti_dataobj = torchvision.datasets.ImageFolder(data_dir,
                                                      transform=self.transform)

        if self.dataidxs is not None:
            indices = self.dataidxs
        else:
            indices = range(len(ti_dataobj))

        # Collect image and label together so that position k of `target`
        # always belongs to position k of `samples`. Previously the labels
        # were taken un-truncated from ImageFolder.targets, which paired the
        # selected images with the labels of samples 0..len(dataidxs)-1.
        samples = []
        target = []
        for i in indices:
            img, label = ti_dataobj[i]
            samples.append(img)
            target.append(label)

        return samples, target

    def truncate_channel(self, index):
        """Zero channels 1 and 2 of the samples listed in ``index``.

        NOTE(review): ``self.data`` is a Python list, so the tuple indexing
        below raises ``TypeError``. The intended per-sample channel layout is
        not visible here, so the code is left as-is; confirm it before use.
        """
        for i in range(index.shape[0]):
            gs_index = index[i]
            self.data[gs_index, :, :, 1] = 0.0
            self.data[gs_index, :, :, 2] = 0.0

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        # The image was already transformed at load time, and no target
        # transform is applied, so this is a plain lookup.
        return self.data[index], self.target[index]

    def __len__(self):
        """Return the number of materialised samples."""
        return len(self.data)



In [None]:
# Materialise three hand-picked training samples (indices 127, 236 and 888),
# resized to 32 and converted to tensors.
ti_data = TinyImageNet_truncated(
    root="data",
    dataidxs=[127, 236, 888],
    train=True,
    transform=transforms.Compose([transforms.Resize(32), transforms.ToTensor()]),
)

In [None]:
# Positions in ti_data.target whose class index is 0 (shown as cell output).
a = np.array(ti_data.target)
is_class_zero = a == 0
np.nonzero(is_class_zero)[0]

In [None]:
# Sequential loader over the truncated dataset in batches of two.
dl_ti = torch.utils.data.DataLoader(dataset=ti_data, batch_size=2)

In [None]:
import torchvision
from torchvision import transforms, datasets
from torch.utils.data.sampler import SubsetRandomSampler

# Alternative to materialising a subset: keep the full ImageFolder dataset and
# let a sampler restrict the loader to the same three indices. The sampler
# draws them in random order.
train_ds = datasets.ImageFolder(
    root='data/train',
    transform=transforms.Compose([transforms.Resize(32), transforms.ToTensor()]),
)
dl_ti2 = torch.utils.data.DataLoader(
    dataset=train_ds,
    batch_size=2,
    sampler=SubsetRandomSampler([888, 236, 127]),
)

In [None]:
# Compare the number of batches produced by the two loaders.
print(len(dl_ti),len(dl_ti2))

In [None]:
# Print every image batch from the truncated-dataset loader.
for images, labels in dl_ti:
    print(images)

In [None]:
# Print every image batch from the sampler-restricted loader.
for images, labels in dl_ti2:
    print(images)