<a href="https://colab.research.google.com/github/tushar-semwal/fedperf/blob/main/Santiago/Shakespeare/FedMed.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# FedPerf - Shakespeare + FedMed algorithm

## Setup & Dependencies Installation

In [None]:
%%capture
!pip install torchsummaryX unidecode

In [None]:
%load_ext tensorboard

import copy
from functools import reduce
import json
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
import random
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import time
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import Sampler
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
from torchsummaryX import summary as summaryx
from torchvision import transforms, utils, datasets
from tqdm.notebook import tqdm
from unidecode import unidecode

%matplotlib inline

# Check assigned GPU
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

# set manual seed for reproducibility
RANDOM_SEED = 42

# general reproducibility
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)

# gpu training specific
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

Mon Apr 26 02:09:39 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   52C    P8    10W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Mount GDrive

In [None]:
BASE_DIR = '/content/drive/MyDrive/FedPerf/shakespeare/FedMed'

In [None]:
# Persist results on Google Drive when running inside Colab; anywhere else
# (or if mounting fails) fall back to the current working directory.
try:
    from google.colab import drive
    drive.mount('/content/drive')
    os.makedirs(BASE_DIR, exist_ok=True)
except Exception as err:  # not a bare `except`: KeyboardInterrupt/SystemExit still propagate
    print("WARNING: Results won't be stored on GDrive")
    # Report why the fallback was taken instead of hiding the cause.
    print(f'Reason: {type(err).__name__}: {err}')
    BASE_DIR = './'



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Loading Dataset

In [None]:
!rm -Rf data
!mkdir -p data scripts

In [None]:
GENERATE_DATASET = False  # If False, download the dataset provided by the q-FFL paper
DATA_DIR = 'data/'
# Dataset generation params
SAMPLES_FRACTION = 1.  # If using an already generated dataset
# SAMPLES_FRACTION = 0.2  # Fraction of total samples in the dataset - FedProx default script
# SAMPLES_FRACTION = 0.05  # Fraction of total samples in the dataset - qFFL
TRAIN_FRACTION = 0.8  # Train set size
MIN_SAMPLES = 0  # Min samples per client (for filtering purposes) - FedProx
# MIN_SAMPLES = 64  # Min samples per client (for filtering purposes) - qFFL

In [None]:
# Download raw dataset
# !wget https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt -O data/shakespeare.txt
!wget --adjust-extension http://www.gutenberg.org/files/100/100-0.txt -O data/shakespeare.txt

--2021-04-26 02:09:39--  http://www.gutenberg.org/files/100/100-0.txt
Resolving www.gutenberg.org (www.gutenberg.org)... 152.19.134.47, 2610:28:3090:3000:0:bad:cafe:47
Connecting to www.gutenberg.org (www.gutenberg.org)|152.19.134.47|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5757108 (5.5M) [text/plain]
Saving to: ‘data/shakespeare.txt’


2021-04-26 02:09:39 (28.5 MB/s) - ‘data/shakespeare.txt’ saved [5757108/5757108]



In [None]:
if not GENERATE_DATASET:
    !rm -Rf data/train data/test
    !gdown --id 1n46Mftp3_ahRi1Z6jYhEriyLtdRDS1tD  # Download Shakespeare dataset used by the FedProx paper
    !unzip shakespeare.zip
    !mv -f shakespeare_paper/train data/
    !mv -f shakespeare_paper/test data/
    !rm -R shakespeare_paper/ shakespeare.zip


Downloading...
From: https://drive.google.com/uc?id=1n46Mftp3_ahRi1Z6jYhEriyLtdRDS1tD
To: /content/shakespeare.zip
0.00B [00:00, ?B/s]2.96MB [00:00, 42.7MB/s]
Archive:  shakespeare.zip
   creating: shakespeare_paper/
   creating: shakespeare_paper/test/
  inflating: shakespeare_paper/test/all_data_niid_2_keep_0_test_8.json  
   creating: shakespeare_paper/train/
  inflating: shakespeare_paper/train/all_data_niid_2_keep_0_train_8.json  


In [None]:
# Build the character vocabulary from the raw text.
# The position of a character inside `corpus` is used as its class label, so
# the order has to be stable: iterating a bare `set` of strings may differ
# between Python processes (string hash randomisation), which would remap the
# labels after a kernel restart. Sorting makes the mapping deterministic.
with open('data/shakespeare.txt', 'r', encoding='utf-8') as f:
    data = list(unidecode(f.read()))
corpus = sorted(set(data))
print('Corpus Length:', len(corpus))

Corpus Length: 90


#### Dataset Preprocessing script

In [None]:
%%capture
if GENERATE_DATASET:
    # Download dataset generation scripts
    !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/data/shakespeare/preprocess/preprocess_shakespeare.py -O scripts/preprocess_shakespeare.py
    !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/data/shakespeare/preprocess/shake_utils.py -O scripts/shake_utils.py
    !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/data/shakespeare/preprocess/gen_all_data.py -O scripts/gen_all_data.py

    # Download data preprocessing scripts
    !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/utils/sample.py -O scripts/sample.py
    !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/utils/remove_users.py -O scripts/remove_users.py

In [None]:
# Running scripts
if GENERATE_DATASET:
    !mkdir -p data/raw_data data/all_data data/train data/test
    !python scripts/preprocess_shakespeare.py data/shakespeare.txt data/raw_data
    !python scripts/gen_all_data.py

#### Dataset class

In [None]:
class ShakespeareDataset(Dataset):
    """Next-character prediction dataset over a fixed character vocabulary.

    Sample ``i`` pairs the input string ``x[i]`` with the character ``y[i]``;
    both are encoded as positions in ``corpus``.

    Args:
        x: sequence of input strings.
        y: sequence of target characters, aligned with ``x``.
        corpus: list of characters; a character's index is its label.
        seq_length: stored as ``self.seq_length`` for reference only, it is
            not used to truncate or pad the samples.

    Note:
        The character lookup table is built once in ``__init__``; mutating
        ``self.corpus`` afterwards does not update it.
    """

    def __init__(self, x, y, corpus, seq_length):
        super(ShakespeareDataset, self).__init__()
        self.x = x
        self.y = y
        self.corpus = corpus
        self.corpus_size = len(self.corpus)
        self.seq_length = seq_length
        # O(1) character -> label lookup instead of a linear `list.index`
        # scan per character. `setdefault` keeps the first occurrence, which
        # is what `list.index` would return if the corpus had duplicates.
        self._char_to_label = {}
        for label, char in enumerate(self.corpus):
            self._char_to_label.setdefault(char, label)

    def __len__(self):
        return len(self.x)

    def __repr__(self):
        return f'{self.__class__} - (length: {self.__len__()})'

    def __getitem__(self, i):
        """Return ``(input indexes as int32 tensor, target label as int)``."""
        input_value = self.text2charindxs(self.x[i])
        target_value = self.get_label_from_char(self.y[i])
        return input_value, target_value

    def text2charindxs(self, text):
        """Encode ``text`` as a 1-D int32 tensor of corpus indexes."""
        return torch.tensor([self.get_label_from_char(c) for c in text],
                            dtype=torch.int32)

    def get_label_from_char(self, c):
        """Return the label of character ``c``.

        Raises:
            ValueError: if ``c`` is not part of the corpus (same exception
                type ``list.index`` raises).
        """
        try:
            return self._char_to_label[c]
        except KeyError:
            raise ValueError(f'{c!r} is not in corpus') from None

    def get_char_from_label(self, l):
        """Return the character stored at position ``l`` of the corpus."""
        return self.corpus[l]

##### Federated Dataset

In [None]:
class ShakespeareFedDataset(ShakespeareDataset):
    """Shakespeare dataset that can build a ``DataLoader`` over itself."""

    def __init__(self, x, y, corpus, seq_length):
        super().__init__(x, y, corpus, seq_length)

    def dataloader(self, batch_size, shuffle=True):
        """Return a ``DataLoader`` over this dataset, loading in the main process."""
        loader_options = dict(batch_size=batch_size, shuffle=shuffle, num_workers=0)
        return DataLoader(self, **loader_options)


## Partitioning & Data Loaders

### IID

In [None]:
def iid_partition_(dataset, clients):
  """
  I.I.D. partitioning of data over clients.

  The sample indexes are shuffled once and dealt out in equal, disjoint
  chunks of ``len(dataset) // clients`` items. Whatever the integer division
  leaves over is not assigned to any client.

  params:
    - dataset (torch.utils.Dataset): Dataset (only ``len(dataset)`` is used)
    - clients (int): Number of Clients to split the data between

  returns:
    - Dictionary mapping the client number (0 .. clients-1) to the set of
      sample indexes assigned to that client

  raises:
    - ValueError: if ``clients`` is not a positive number
  """
  if clients <= 0:
    raise ValueError(f'clients must be a positive integer, got {clients}')

  num_items_per_client = len(dataset) // clients

  # A single permutation replaces the "rebuild the set of remaining indexes"
  # step that used to cost O(len(dataset)) for every client.
  shuffled_idxs = np.random.permutation(len(dataset))

  client_dict = {}
  for i in range(clients):
    start = i * num_items_per_client
    client_dict[i] = set(shuffled_idxs[start:start + num_items_per_client])

  return client_dict

In [None]:
def iid_partition(corpus, seq_length=80, val_split=False):
    """Build the pooled train set, an IID client split of it, and the test set.

    Reads the first ``*.json`` file listed under ``DATA_DIR/train`` and
    ``DATA_DIR/test``. Both files are accessed through the keys ``users``,
    ``num_samples`` and ``user_data[author]['x' | 'y']``.

    Authors are visited in file order. An author with fewer than
    ``MIN_SAMPLES`` samples is skipped, and authors are added until
    ``SAMPLES_FRACTION`` of the total training samples has been collected.

    Args:
        corpus: character vocabulary handed to ``ShakespeareDataset``.
        seq_length: forwarded to ``ShakespeareDataset``.
        val_split: when True only the first half of each author's
            (fraction-limited) test samples is kept; the second half is
            dropped, no validation dataset is built here.

    Returns:
        ``(train_ds, data_dict, test_ds)`` where ``data_dict`` maps a client
        number to a set of indexes into ``train_ds``. The number of clients is
        the number of users in the train file, skipped users included.
    """
    train_file = [os.path.join(DATA_DIR, 'train', f) for f in os.listdir(f'{DATA_DIR}/train') if f.endswith('.json')][0]
    test_file = [os.path.join(DATA_DIR, 'test', f) for f in os.listdir(f'{DATA_DIR}/test') if f.endswith('.json')][0]

    with open(train_file, 'r') as file:
        data_train = json.loads(unidecode(file.read()))

    with open(test_file, 'r') as file:
        data_test = json.loads(unidecode(file.read()))

    x_train, y_train = [], []
    x_test, y_test = [], []

    users = list(zip(data_train['users'], data_train['num_samples']))

    total_samples = int(sum(data_train['num_samples']) * SAMPLES_FRACTION)
    print('Objective', total_samples, '/', sum(data_train['num_samples']))
    sample_count = 0

    for author_id, samples in users:

        if sample_count >= total_samples:
            print('Max samples reached', sample_count, '/', total_samples)
            break

        if samples < MIN_SAMPLES:
            print('SKIP', author_id, samples)
            continue

        # Take the whole author unless that would exceed the budget, in which
        # case take only what is still missing. (The earlier expression
        # `sample_count + samples - total_samples` was the overshoot, not the
        # remaining budget, so the objective could be exceeded.)
        max_samples = min(samples, total_samples - sample_count)
        sample_count += max_samples

        author_train = data_train['user_data'][author_id]
        x_train.extend(author_train['x'][:max_samples])
        y_train.extend(author_train['y'][:max_samples])

        author_data = data_test['user_data'][author_id]
        test_size = int(len(author_data['x']) * SAMPLES_FRACTION)
        if val_split:
            # Keep only the first half for testing.
            test_size = int(test_size / 2)

        x_test.extend(author_data['x'][:test_size])
        y_test.extend(author_data['y'][:test_size])

    train_ds = ShakespeareDataset(x_train, y_train, corpus, seq_length)
    test_ds = ShakespeareDataset(x_test, y_test, corpus, seq_length)

    data_dict = iid_partition_(train_ds, clients=len(users))

    return train_ds, data_dict, test_ds

### Non-IID

In [None]:
def noniid_partition(corpus, seq_length=80, val_split=False):
    """Build one dataset per author (non-IID clients) plus a pooled test set.

    Reads the first ``*.json`` file listed under ``DATA_DIR/train`` and
    ``DATA_DIR/test``. Both files are accessed through the keys ``users``,
    ``num_samples`` and ``user_data[author]['x' | 'y']``.

    Authors are visited in file order. An author with fewer than
    ``MIN_SAMPLES`` samples is skipped, and authors are added until
    ``SAMPLES_FRACTION`` of the total training samples has been collected.

    Args:
        corpus: character vocabulary handed to ``ShakespeareFedDataset``.
        seq_length: forwarded to ``ShakespeareFedDataset``.
        val_split: when True each author's (fraction-limited) test samples
            are split in two: the first half goes to the pooled test set and
            the second half becomes that author's validation set.

    Returns:
        ``(data_dict, test_ds)`` where ``data_dict[author_id]`` is
        ``{'train_ds': ..., 'val_ds': ...}`` (``val_ds`` is None when
        ``val_split`` is False).
    """
    train_file = [os.path.join(DATA_DIR, 'train', f) for f in os.listdir(f'{DATA_DIR}/train') if f.endswith('.json')][0]
    test_file = [os.path.join(DATA_DIR, 'test', f) for f in os.listdir(f'{DATA_DIR}/test') if f.endswith('.json')][0]

    with open(train_file, 'r') as file:
        data_train = json.loads(unidecode(file.read()))

    with open(test_file, 'r') as file:
        data_test = json.loads(unidecode(file.read()))

    data_dict = {}
    x_test, y_test = [], []

    users = list(zip(data_train['users'], data_train['num_samples']))

    total_samples = int(sum(data_train['num_samples']) * SAMPLES_FRACTION)
    print('Objective', total_samples, '/', sum(data_train['num_samples']))
    sample_count = 0

    for author_id, samples in users:

        if sample_count >= total_samples:
            print('Max samples reached', sample_count, '/', total_samples)
            break

        if samples < MIN_SAMPLES:
            print('SKIP', author_id, samples)
            continue

        # Take the whole author unless that would exceed the budget, in which
        # case take only what is still missing. (The earlier expression
        # `sample_count + samples - total_samples` was the overshoot, not the
        # remaining budget, so the objective could be exceeded.)
        max_samples = min(samples, total_samples - sample_count)
        sample_count += max_samples

        author_train = data_train['user_data'][author_id]
        train_ds = ShakespeareFedDataset(author_train['x'][:max_samples],
                                         author_train['y'][:max_samples],
                                         corpus, seq_length)

        val_ds = None
        author_data = data_test['user_data'][author_id]
        test_size = int(len(author_data['x']) * SAMPLES_FRACTION)
        if val_split:
            half = int(test_size / 2)
            x_test += author_data['x'][:half]
            y_test += author_data['y'][:half]
            # Both slices stop at `test_size` so inputs and targets keep the
            # same length (x used to run to the end of the list while y
            # stopped at `test_size`).
            x_val = author_data['x'][half:test_size]
            y_val = author_data['y'][half:test_size]
            val_ds = ShakespeareFedDataset(x_val, y_val, corpus, seq_length)
        else:
            x_test += author_data['x'][:test_size]
            y_test += author_data['y'][:test_size]

        data_dict[author_id] = {
            'train_ds': train_ds,
            'val_ds': val_ds
        }

    test_ds = ShakespeareFedDataset(x_test, y_test, corpus, seq_length)

    return data_dict, test_ds

## Models

### Shakespeare LSTM

In [None]:
class ShakespeareLSTM(nn.Module):
    """Embedding -> stacked LSTM -> linear classifier over ``classes`` outputs.

    Args:
        input_dim: stored on the instance; not used by any layer.
        embedding_dim: size of each character embedding.
        hidden_dim: LSTM hidden state size.
        classes: vocabulary size; used both as the number of embeddings and
            as the number of output logits.
        lstm_layers: number of stacked LSTM layers.
        dropout: dropout between LSTM layers (forced to 0 for a single layer).
        batch_first: forwarded to ``nn.LSTM``.
    """

    def __init__(self, input_dim, embedding_dim, hidden_dim, classes, lstm_layers=2, dropout=0.1, batch_first=True):
        super(ShakespeareLSTM, self).__init__()
        self.input_dim = input_dim
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.classes = classes
        self.no_layers = lstm_layers

        self.embedding = nn.Embedding(num_embeddings=self.classes,
                                      embedding_dim=self.embedding_dim)
        self.lstm = nn.LSTM(input_size=self.embedding_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=self.no_layers,
                            batch_first=batch_first,
                            dropout=dropout if self.no_layers > 1 else 0.)
        self.fc = nn.Linear(hidden_dim, self.classes)

    def forward(self, x, hc=None):
        """Return one row of ``classes`` logits per input sequence.

        ``x`` holds character indexes with the batch on dimension 0 (the
        reshape below relies on that). ``hc`` is an optional ``(h0, c0)``
        pair passed to the LSTM; None lets the LSTM start from zeros.
        """
        batch_size = x.size(0)
        x_emb = self.embedding(x)
        out, (ht, ct) = self.lstm(x_emb.view(batch_size, -1, self.embedding_dim), hc)
        # ht[-1] is the final hidden state of the top LSTM layer.
        dense = self.fc(ht[-1])
        return dense

    def init_hidden(self, batch_size):
        """Return zero-filled ``(h0, c0)`` for ``batch_size`` sequences.

        The tensors are created with the dtype and device of the model's
        parameters, so they can be passed to ``forward`` even after the model
        has been moved to a GPU.
        """
        reference = next(self.parameters())
        shape = (self.no_layers, batch_size, self.hidden_dim)
        return reference.new_zeros(shape), reference.new_zeros(shape)


#### Model Summary

In [None]:
batch_size = 10
seq_length = 80 # mcmahan17a, fedprox, qFFL

shakespeare_lstm = ShakespeareLSTM(input_dim=seq_length,
                                   embedding_dim=8,  # mcmahan17a, fedprox, qFFL
                                   hidden_dim=256,  # mcmahan17a, fedprox impl
                                #    hidden_dim=100,  # fedprox paper
                                   classes=len(corpus),
                                   lstm_layers=2,
                                   dropout=0.1,  # TODO:
                                   batch_first=True
                                   )

if torch.cuda.is_available():
  shakespeare_lstm.cuda()

# Dummy batch of character indexes, only used to trace the model for the summary.
x_sample = torch.zeros((batch_size, seq_length),
                       dtype=torch.long,
                       device=(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')))

x_sample[0][0] = 1

print("\nShakespeare LSTM SUMMARY")
# summaryx prints the table itself and also returns it; binding the return
# value avoids printing/displaying the same table a second time.
_ = summaryx(shakespeare_lstm, x_sample)


Shakespeare LSTM SUMMARY
            Kernel Shape   Output Shape  Params  Mult-Adds
Layer                                                     
0_embedding      [8, 90]    [10, 80, 8]     720        720
1_lstm                 -  [10, 80, 256]  798720     794624
2_fc           [256, 90]       [10, 90]   23130      23040
----------------------------------------------------------
                      Totals
Total params          822570
Trainable params      822570
Non-trainable params       0
Mult-Adds             818384
            Kernel Shape   Output Shape  Params  Mult-Adds
Layer                                                     
0_embedding      [8, 90]    [10, 80, 8]     720        720
1_lstm                 -  [10, 80, 256]  798720     794624
2_fc           [256, 90]       [10, 90]   23130      23040


## FedMed Algorithm

### Plot Utils

In [None]:
from sklearn.metrics import f1_score

In [None]:
def plot_scores(history, exp_id, title, suffix):
    """Plot test accuracy and both F1 curves over rounds, saving each as a JPG.

    ``history`` is a sequence of dicts holding the keys ``accuracy``,
    ``f1_macro`` and ``f1_weighted``. Every figure is written under
    ``BASE_DIR/exp_id`` and then shown.
    """
    # (history key, y-axis label, file name stem, extra savefig options)
    panels = (
        ('accuracy', 'Test Accuracy', 'Test_Accuracy', {'dpi': 300}),
        ('f1_macro', 'Test F1 (macro)', 'Test_F1_Macro', {}),
        ('f1_weighted', 'Test F1 (weighted)', 'Test_F1_Weighted', {}),
    )

    # Pull every series out up front, before any figure is created.
    series = {}
    for key, _, _, _ in panels:
        series[key] = [entry[key] for entry in history]

    for key, ylabel, stem, save_options in panels:
        fig, ax = plt.subplots()
        ax.plot(series[key], 'tab:orange')
        ax.set(xlabel='Rounds', ylabel=ylabel, title=title)
        ax.grid()
        fig.savefig(f'{BASE_DIR}/{exp_id}/{stem}_{suffix}.jpg', format='jpg', **save_options)
        plt.show()


def plot_losses(history, exp_id, title, suffix):
    """Plot train and test loss over rounds, saving each figure as a JPG.

    ``history`` is a sequence of dicts holding the keys ``loss`` (test loss)
    and ``train_loss``. Every figure is written under ``BASE_DIR/exp_id`` and
    then shown.
    """
    test_curve = [entry['loss'] for entry in history]
    train_curve = [entry['train_loss'] for entry in history]

    # Train loss is drawn first, then test loss.
    panels = (
        (train_curve, 'Train Loss', 'Train_Loss'),
        (test_curve, 'Test Loss', 'Test_Loss'),
    )
    for curve, ylabel, stem in panels:
        fig, ax = plt.subplots()
        ax.plot(curve, 'tab:orange')
        ax.set(xlabel='Rounds', ylabel=ylabel, title=title)
        ax.grid()
        fig.savefig(f'{BASE_DIR}/{exp_id}/{stem}_{suffix}.jpg', format='jpg')
        plt.show()


### Local Training (Client Update)

In [None]:
class CustomDataset(Dataset):
  """View of ``dataset`` restricted to the sample indexes in ``idxs``."""

  def __init__(self, dataset, idxs):
    # Materialise the indexes so they can be addressed by position.
    self.idxs = [idx for idx in idxs]
    self.dataset = dataset

  def __len__(self):
    return len(self.idxs)

  def __getitem__(self, item):
    # Translate the position in this view into an index of the wrapped dataset.
    source_index = self.idxs[item]
    data, label = self.dataset[source_index]
    return data, label

In [None]:
class ClientUpdate(object):
  """Local training of one client with SGD (momentum 0.5) and cross-entropy.

  Params:
    - dataset:        dataset to train on. If it exposes a ``dataloader``
                      method that is used to build the loader, otherwise the
                      dataset is restricted to ``idxs`` through CustomDataset
    - batchSize:      mini-batch size of the local loader
    - learning_rate:  SGD learning rate
    - epochs:         number of passes over the local data
    - idxs:           sample indexes of this client (ignored when ``dataset``
                      provides its own ``dataloader``)
  """

  def __init__(self, dataset, batchSize, learning_rate, epochs, idxs):
    # Build the loader from the `batchSize` argument; the notebook-level
    # `batch_size` global must not leak into the client configuration.
    if hasattr(dataset, 'dataloader'):
        self.train_loader = dataset.dataloader(batch_size=batchSize, shuffle=True)
    else:
        self.train_loader = DataLoader(CustomDataset(dataset, idxs), batch_size=batchSize, shuffle=True)

    self.learning_rate = learning_rate
    self.epochs = epochs

  def train(self, model):
    """Train ``model`` in place and return ``(state_dict, mean epoch loss)``.

    The loss of an epoch is the sample-weighted mean over its batches; the
    returned value is the plain mean of the per-epoch losses.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=self.learning_rate, momentum=0.5)

    # Feed the batches to wherever the model lives rather than assuming that
    # an available GPU means the model is on it.
    device = next(model.parameters()).device

    e_loss = []
    for epoch in range(1, self.epochs+1):

      train_loss = 0.0

      model.train()
      for data, labels in self.train_loader:
        data, labels = data.to(device), labels.to(device)

        # clear the gradients
        optimizer.zero_grad()
        # make a forward pass
        output = model(data)

        loss = criterion(output, labels)
        # do a backwards pass
        loss.backward()
        # perform a single optimization step
        optimizer.step()
        # update training loss (weighted by the batch size)
        train_loss += loss.item()*data.size(0)

      # average losses
      train_loss = train_loss/len(self.train_loader.dataset)
      e_loss.append(train_loss)

    total_loss = sum(e_loss)/len(e_loss)

    return model.state_dict(), total_loss

### Server Side Training

In [None]:
def training(model, rounds, batch_size, lr, ds, data_dict, test_ds, C, K, E, M, plt_title, plt_color, classes, history=None, eval_every=1, tb_logger=None):
  """
  Server-side loop of FedMed: federated training in which the global model of
  each round is the coordinate-wise median of the client models.

  Every round samples ``max(int(C*K), 1)`` clients without replacement, trains
  a copy of the global model on each of them, aggregates the returned weights
  with a per-entry median, loads the result into ``model`` and evaluates it on
  ``test_ds``.

  Params:
    - model:       PyTorch model to train (updated in place)
    - rounds:      Number of communication rounds
    - batch_size:  Batch size for client training (evaluation uses twice this)
    - lr:          Learning rate used for client training
    - ds:          Pooled training dataset; when it is falsy the per-client
                   datasets ``data_dict[key]['train_ds']`` are used instead
    - data_dict:   Client partition: index sets into ``ds`` when ``ds`` is
                   given, otherwise dicts holding a ``'train_ds'`` entry
    - test_ds:     Dataset used to evaluate the global model after each round
    - C:           Fraction of clients randomly chosen on each round
    - K:           Total number of clients
    - E:           Number of local epochs each client runs per round
    - M:           Not used by the current implementation
    - plt_title:   Title of the two summary plots
    - plt_color:   Not used by the current implementation
    - classes:     Class labels; forwarded to ``testing``
    - history:     Optional list to which the per-round score dicts are
                   appended; a new list is created when omitted
    - eval_every:  Not used: the model is evaluated after every round
    - tb_logger:   Optional TensorBoard writer for the loss and score curves
  Returns:
    - model:       Trained model on the server
    - history:     List with one score dict per round (``testing`` output
                   plus ``'train_loss'``)
  """
  # A fresh list per call: a mutable default would be shared between calls
  # and accumulate the rounds of earlier experiments.
  if history is None:
    history = []

  start = time.time()

  # training loss
  train_loss = []
  # test accuracy
  test_acc = []

  users_id = list(data_dict.keys())

  for curr_round in range(1, rounds+1):
    w, local_loss = [], []

    m = max(int(C*K), 1)

    S_t = np.random.choice(range(K), m, replace=False)
    print('Clients: {}/{} -> {}'.format(len(S_t), K, S_t))

    for i in tqdm(range(len(S_t))):
      k = S_t[i]
      key = users_id[k]
      ds_ = ds if ds else data_dict[key]['train_ds']
      idxs = data_dict[key] if ds else None
      print(f'Client {k}: {len(idxs) if idxs else len(ds_)} samples')
      local_update = ClientUpdate(dataset=ds_, batchSize=batch_size, learning_rate=lr, epochs=E, idxs=idxs)
      # Each client starts from its own copy of the current global model.
      weights, loss = local_update.train(model=copy.deepcopy(model))

      w.append(copy.deepcopy(weights))
      local_loss.append(copy.deepcopy(loss))

    print('Computing median...')

    # Coordinate-wise median over the client models. For an even number of
    # clients torch.median returns the lower of the two middle values.
    target = copy.deepcopy(w[0])
    for name in target.keys():
      stacked = torch.stack([client_weights[name].data for client_weights in w])
      target[name].data = torch.median(stacked, dim=0)[0]

    # move the updated weights to our model state dict
    model.load_state_dict(target)

    # loss
    loss_avg = sum(local_loss) / len(local_loss)
    print('Round: {}... \tAverage Loss: {}'.format(curr_round, round(loss_avg, 3)))
    train_loss.append(loss_avg)
    if tb_logger:
        tb_logger.add_scalar(f'Train/Loss', loss_avg, curr_round)

    # testing
    test_scores = testing(model, test_ds, batch_size * 2, nn.CrossEntropyLoss(), len(classes), classes)
    test_scores['train_loss'] = loss_avg
    test_accuracy = test_scores['accuracy']
    history.append(test_scores)

    if tb_logger:
        tb_logger.add_scalar(f'Test/Loss', test_scores['loss'], curr_round)
        tb_logger.add_scalars(f'Test/Scores', {
            'accuracy': test_scores['accuracy'], 'f1_macro': test_scores['f1_macro'], 'f1_weighted': test_scores['f1_weighted']
        }, curr_round)

    test_acc.append(test_accuracy)

  end = time.time()

  # plot train loss
  fig, ax = plt.subplots()
  x_axis = np.arange(1, rounds+1)
  y_axis = np.array(train_loss)
  ax.plot(x_axis, y_axis)

  ax.set(xlabel='Number of Rounds', ylabel='Train Loss', title=plt_title)
  ax.grid()

  # plot test accuracy
  fig1, ax1 = plt.subplots()
  x_axis1 = np.arange(1, rounds+1)
  y_axis1 = np.array(test_acc)
  ax1.plot(x_axis1, y_axis1)

  ax1.set(xlabel='Number of Rounds', ylabel='Test Accuracy', title=plt_title)
  ax1.grid()

  print("Training Done! Total time taken to Train: {}".format(end-start))

  return model, history

### Testing Loop

In [None]:
def testing(model, dataset, bs, criterion, num_classes, classes, print_all=False):
  """
  Evaluate ``model`` on ``dataset`` and return loss, accuracy and F1 scores.

  Params:
    - model:        PyTorch model; it is switched to eval mode and left there
    - dataset:      Dataset yielding ``(data, label)`` pairs
    - bs:           Batch size of the evaluation loader
    - criterion:    Loss function called as ``criterion(output, labels)``
    - num_classes:  Number of classes (size of the per-class counters)
    - classes:      Class names, only used when ``print_all`` is True
    - print_all:    Also print the accuracy of every class
  Returns:
    - dict with the keys ``loss`` (sample-weighted mean), ``accuracy``,
      ``f1_macro`` and ``f1_weighted``
  Raises:
    - ValueError: if ``dataset`` yields no samples
  """
  test_loss = 0.0
  correct_class = np.zeros(num_classes)
  total_class = np.zeros(num_classes)
  y_true_batches, y_hat_batches = [], []

  test_loader = DataLoader(dataset, batch_size=bs)
  # Evaluate on whatever device the model lives on.
  device = next(model.parameters()).device

  model.eval()
  print('running validation...')
  # No gradients are needed for evaluation; skipping the autograd graph
  # saves memory and time.
  with torch.no_grad():
    for data, labels in tqdm(test_loader):
      data, labels = data.to(device), labels.to(device)

      output = model(data)
      loss = criterion(output, labels)
      test_loss += loss.item()*data.size(0)

      _, pred = torch.max(output, 1)

      # Flattened to 1-D so that a final batch holding a single sample is
      # handled like any other batch.
      labels_np = labels.view_as(pred).cpu().numpy().reshape(-1)
      pred_np = pred.cpu().numpy().reshape(-1)

      # For F1Score
      y_true_batches.append(labels_np)
      y_hat_batches.append(pred_np)

      # Per-class counters; np.add.at accumulates correctly when a class
      # appears several times in the same batch.
      np.add.at(total_class, labels_np, 1)
      np.add.at(correct_class, labels_np, (pred_np == labels_np).astype(correct_class.dtype))

  if not y_true_batches:
    raise ValueError('testing() received an empty dataset')

  y_true = np.concatenate(y_true_batches)
  y_hat = np.concatenate(y_hat_batches)

  # avg test loss
  test_loss = test_loss/len(test_loader.dataset)
  print("Test Loss: {:.6f}\n".format(test_loss))

  # Avg F1 Score
  f1_macro = f1_score(y_true, y_hat, average='macro')
  # F1-Score -> weighted to consider class imbalance
  f1_weighted =  f1_score(y_true, y_hat, average='weighted')
  print("F1 Score: {:.6f} (macro) {:.6f} (weighted) %\n".format(f1_macro, f1_weighted))

  # print test accuracy
  if print_all:
    for i in range(num_classes):
        if total_class[i]>0:
            print('Test Accuracy of %5s: %2d%% (%2d/%2d)' %
                    (classes[i], 100 * correct_class[i] / total_class[i],
                    np.sum(correct_class[i]), np.sum(total_class[i])))
        else:
            print('Test Accuracy of %5s: N/A (no training examples)' % (classes[i]))

  overall_accuracy = np.sum(correct_class) / np.sum(total_class)

  print('\nFinal Test  Accuracy: {:.3f} ({}/{})'.format(overall_accuracy, np.sum(correct_class), np.sum(total_class)))

  return {'loss': test_loss, 'accuracy': overall_accuracy, 'f1_macro': f1_macro, 'f1_weighted': f1_weighted}

## Experiments

In [None]:
# NOTE(review): evaluating this line raises NameError (see the recorded output).
# Presumably a deliberate stop so that "Run all" does not launch the
# experiments below -- confirm before removing.
FAIL-ON-PURPOSE

NameError: ignored

In [None]:
seq_length = 80  # mcmahan17a, fedprox, qFFL
embedding_dim = 8  # mcmahan17a, fedprox, qFFL
# hidden_dim = 100  # fedprox paper
hidden_dim = 256  # mcmahan17a, fedprox impl
num_classes = len(corpus)
classes = list(range(num_classes))
lstm_layers = 2  # mcmahan17a, fedprox, qFFL
dropout = 0.1  # TODO


In [None]:
class Hyperparameters:
    """Settings of one federated training experiment."""

    def __init__(self, total_clients):
        self.rounds = 50          # communication rounds
        self.C = 0.5              # fraction of clients sampled per round
        self.K = total_clients    # total number of clients
        self.E = 1                # local epochs per round; alternative value: 20
        self.M = 0.01  # FedMed
        self.batch_size = 10      # local mini-batch size
        self.lr = 0.8             # local learning rate

In [None]:
exp_log = dict()

### IID

In [None]:
train_ds, data_dict, test_ds = iid_partition(corpus, seq_length, val_split=True)  # Not using val_ds but makes train eval periods faster

total_clients = len(data_dict.keys())
'Total users:', total_clients

Objective 413629 / 413629


('Total users:', 143)

In [None]:
hparams = Hyperparameters(total_clients)
hparams.__dict__

{'C': 0.5,
 'E': 1,
 'K': 143,
 'M': 0.01,
 'batch_size': 10,
 'lr': 0.8,
 'rounds': 50}

In [None]:
# Sweeping parameter
PARAM_NAME = 'clients_fraction'
PARAM_VALUE = hparams.C
exp_id = f'{PARAM_NAME}/{PARAM_VALUE}'
exp_id

'clients_fraction/0.5'

In [None]:
EXP_DIR = f'{BASE_DIR}/{exp_id}'
os.makedirs(EXP_DIR, exist_ok=True)

# tb_logger = SummaryWriter(log_dir)
# print(f'TBoard logger created at: {log_dir}')

title = 'LSTM FedMed on IID'

In [None]:
def run_experiment(run_id):
    """Train a fresh ShakespeareLSTM on the IID split, evaluate it and save the run log.

    Reads the notebook-level configuration: the model sizes, ``hparams``,
    ``train_ds``, ``data_dict``, ``test_ds``, ``corpus``, ``title`` and
    ``EXP_DIR``.

    Args:
        run_id: identifier of this repetition; used in the printed summary and
            in the name of the pickle file written to ``EXP_DIR``.

    Returns:
        The history returned by ``training`` for this run.
    """
    # The model is built inside the function, so every run starts from a new instance.
    shakespeare_lstm = ShakespeareLSTM(input_dim=seq_length,
                                       embedding_dim=embedding_dim,
                                       hidden_dim=hidden_dim,
                                       classes=num_classes,
                                       lstm_layers=lstm_layers,
                                       dropout=dropout,
                                       batch_first=True)

    if torch.cuda.is_available():
        shakespeare_lstm.cuda()

    test_history = []

    lstm_iid_trained, test_history = training(shakespeare_lstm,
                                              hparams.rounds, hparams.batch_size, hparams.lr,
                                              train_ds,
                                              data_dict,
                                              test_ds,
                                              hparams.C, hparams.K, hparams.E, hparams.M,
                                              title, "green",
                                              corpus,  # classes
                                              history=test_history,
                                              # tb_logger=tb_writer
                                              )

    # The evaluation batch size is read from hparams, like every other
    # hyperparameter in this function, instead of from a separate notebook
    # global named `batch_size`.
    final_scores = testing(lstm_iid_trained, test_ds, hparams.batch_size * 2,
                           nn.CrossEntropyLoss(), len(corpus), corpus)
    print(f'\n\n========================================================\n\n')
    print(f'Final scores for Exp {run_id} \n {final_scores}')

    log = {
        'history': test_history,
        'hyperparams': hparams.__dict__
    }

    # One result file per run, so a later run does not overwrite an earlier one.
    with open(f'{EXP_DIR}/results_iid_{run_id}.pkl', 'wb') as file:
        pickle.dump(log, file)

    return test_history


In [None]:
# Repeat the experiment and keep the history returned by each run.
exp_history = []
for run_id in range(2):  # TOTAL RUNS
    print(f'============== RUNNING EXPERIMENT #{run_id} ==============')
    run_history = run_experiment(run_id)
    exp_history.append(run_history)
    print(f'\n\n========================================================\n\n')

Clients: 71/143 -> [  4  12 100  50  46  39 127 112   8  93  58 125  66  78  79  27 142 115
 124  52  49  20  86  14  57  22  17 120  83 116  28  77  40  54  26  59
 134  23  98 139  72  32  64  70 107 122 135   9 141  84  44 138  92 130
 128 114  67 102 108 129 104  81 103  42   7   0  43 101  18  36  37]


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))

Client 4: 2892 samples


  self.dropout, self.training, self.bidirectional, self.batch_first)


Client 12: 2892 samples
Client 100: 2892 samples
Client 50: 2892 samples
Client 46: 2892 samples
Client 39: 2892 samples
Client 127: 2892 samples
Client 112: 2892 samples
Client 8: 2892 samples
Client 93: 2892 samples
Client 58: 2892 samples
Client 125: 2892 samples
Client 66: 2892 samples
Client 78: 2892 samples
Client 79: 2892 samples
Client 27: 2892 samples
Client 142: 2892 samples
Client 115: 2892 samples
Client 124: 2892 samples
Client 52: 2892 samples
Client 49: 2892 samples
Client 20: 2892 samples
Client 86: 2892 samples
Client 14: 2892 samples
Client 57: 2892 samples
Client 22: 2892 samples
Client 17: 2892 samples
Client 120: 2892 samples
Client 83: 2892 samples
Client 116: 2892 samples
Client 28: 2892 samples
Client 77: 2892 samples
Client 40: 2892 samples
Client 54: 2892 samples
Client 26: 2892 samples
Client 59: 2892 samples
Client 134: 2892 samples
Client 23: 2892 samples
Client 98: 2892 samples
Client 139: 2892 samples
Client 72: 2892 samples
Client 32: 2892 samples
Client

HBox(children=(FloatProgress(value=0.0, max=2586.0), HTML(value='')))


Test Loss: 2.928030

F1 Score: 0.015211 (macro) 0.119871 (weighted) %


Final Test  Accuracy: 0.230 (11898.0/51704.0)
Clients: 71/143 -> [ 37  96 117  38  66  56  67  31  17  43  25  72  49  77  60  14  22  97
 109  98 140  87  13 129 119 112  51 138  23  89  24  11 134 105   9  71
  58  91  85  42  55  29 128  39 120  36  46  70  64  93 113 141  65   2
  50  33 123   3  62  95 107  30  76 118 137  27  28  54 126 101 125]


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))

Client 37: 2892 samples
Client 96: 2892 samples
Client 117: 2892 samples
Client 38: 2892 samples
Client 66: 2892 samples
Client 56: 2892 samples
Client 67: 2892 samples
Client 31: 2892 samples
Client 17: 2892 samples
Client 43: 2892 samples
Client 25: 2892 samples
Client 72: 2892 samples
Client 49: 2892 samples
Client 77: 2892 samples
Client 60: 2892 samples
Client 14: 2892 samples
Client 22: 2892 samples
Client 97: 2892 samples
Client 109: 2892 samples
Client 98: 2892 samples
Client 140: 2892 samples
Client 87: 2892 samples
Client 13: 2892 samples
Client 129: 2892 samples
Client 119: 2892 samples
Client 112: 2892 samples
Client 51: 2892 samples
Client 138: 2892 samples
Client 23: 2892 samples
Client 89: 2892 samples
Client 24: 2892 samples
Client 11: 2892 samples
Client 134: 2892 samples
Client 105: 2892 samples
Client 9: 2892 samples
Client 71: 2892 samples
Client 58: 2892 samples
Client 91: 2892 samples
Client 85: 2892 samples
Client 42: 2892 samples
Client 55: 2892 samples
Client 2

HBox(children=(FloatProgress(value=0.0, max=2586.0), HTML(value='')))


Test Loss: 2.642471

F1 Score: 0.032418 (macro) 0.185810 (weighted) %


Final Test  Accuracy: 0.277 (14344.0/51704.0)
Clients: 71/143 -> [ 49  28  12 133   4  90  27  89  77  65 136  81  39 103 121  67  32  13
   0 116  47  58  37 139  24 130  19  25  76  66  96  61   2 126  53  71
   5  42 101  46 129  91  93 107 132  56  80  29  55  92 102  85 111 138
  30  26   1 104  44  88  31  94  33  22 100  34  69  14  79  84 106]


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))

Client 49: 2892 samples
Client 28: 2892 samples
Client 12: 2892 samples
Client 133: 2892 samples
Client 4: 2892 samples
Client 90: 2892 samples
Client 27: 2892 samples
Client 89: 2892 samples
Client 77: 2892 samples
Client 65: 2892 samples
Client 136: 2892 samples
Client 81: 2892 samples
Client 39: 2892 samples
Client 103: 2892 samples
Client 121: 2892 samples
Client 67: 2892 samples
Client 32: 2892 samples
Client 13: 2892 samples
Client 0: 2892 samples
Client 116: 2892 samples
Client 47: 2892 samples
Client 58: 2892 samples
Client 37: 2892 samples
Client 139: 2892 samples
Client 24: 2892 samples
Client 130: 2892 samples
Client 19: 2892 samples
Client 25: 2892 samples
Client 76: 2892 samples
Client 66: 2892 samples
Client 96: 2892 samples
Client 61: 2892 samples
Client 2: 2892 samples
Client 126: 2892 samples
Client 53: 2892 samples
Client 71: 2892 samples
Client 5: 2892 samples
Client 42: 2892 samples
Client 101: 2892 samples
Client 46: 2892 samples
Client 129: 2892 samples
Client 91:

HBox(children=(FloatProgress(value=0.0, max=2586.0), HTML(value='')))


Test Loss: 2.487344

F1 Score: 0.048434 (macro) 0.233902 (weighted) %


Final Test  Accuracy: 0.320 (16528.0/51704.0)
Clients: 71/143 -> [ 61  34  24  18 122  91  12  46  16  21   5  93  72 137  26 133   3  66
  44 142  78  47  23  19   7  27 120 131 128  57  65 109  92  82 112  54
 139 101 124  10  43  89  14 141 105  84 125  71  48   8  52  50  28  69
  83  90  22  33  51  29 123   9  99 114  11  13 115  39  45  79  59]


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))

Client 61: 2892 samples
Client 34: 2892 samples
Client 24: 2892 samples
Client 18: 2892 samples
Client 122: 2892 samples
Client 91: 2892 samples
Client 12: 2892 samples
Client 46: 2892 samples
Client 16: 2892 samples
Client 21: 2892 samples
Client 5: 2892 samples
Client 93: 2892 samples
Client 72: 2892 samples
Client 137: 2892 samples
Client 26: 2892 samples
Client 133: 2892 samples
Client 3: 2892 samples
Client 66: 2892 samples
Client 44: 2892 samples
Client 142: 2892 samples
Client 78: 2892 samples
Client 47: 2892 samples
Client 23: 2892 samples
Client 19: 2892 samples
Client 7: 2892 samples
Client 27: 2892 samples
Client 120: 2892 samples
Client 131: 2892 samples
Client 128: 2892 samples
Client 57: 2892 samples
Client 65: 2892 samples
Client 109: 2892 samples
Client 92: 2892 samples
Client 82: 2892 samples
Client 112: 2892 samples
Client 54: 2892 samples
Client 139: 2892 samples
Client 101: 2892 samples
Client 124: 2892 samples
Client 10: 2892 samples
Client 43: 2892 samples
Client 

HBox(children=(FloatProgress(value=0.0, max=2586.0), HTML(value='')))


Test Loss: 2.399509

F1 Score: 0.060108 (macro) 0.253967 (weighted) %


Final Test  Accuracy: 0.334 (17284.0/51704.0)
Clients: 71/143 -> [ 23  79  75   4 106 133 102  44 131  51  91  52  16   7 137  10  13  20
  61  95  64  83  73  63   6  62  12  81  65  28 134 111 100  57   1  96
  71  38 116 140   3  42   9  31  21 119  25  55  18 115  80  98  27  72
 126  74 103 138  45  82 130  66   2  59   8 109  22  88  46  53  19]


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))

Client 23: 2892 samples
Client 79: 2892 samples
Client 75: 2892 samples
Client 4: 2892 samples
Client 106: 2892 samples
Client 133: 2892 samples
Client 102: 2892 samples
Client 44: 2892 samples
Client 131: 2892 samples
Client 51: 2892 samples
Client 91: 2892 samples
Client 52: 2892 samples
Client 16: 2892 samples
Client 7: 2892 samples
Client 137: 2892 samples
Client 10: 2892 samples
Client 13: 2892 samples
Client 20: 2892 samples
Client 61: 2892 samples
Client 95: 2892 samples
Client 64: 2892 samples
Client 83: 2892 samples
Client 73: 2892 samples
Client 63: 2892 samples
Client 6: 2892 samples
Client 62: 2892 samples
Client 12: 2892 samples
Client 81: 2892 samples
Client 65: 2892 samples
Client 28: 2892 samples
Client 134: 2892 samples
Client 111: 2892 samples
Client 100: 2892 samples
Client 57: 2892 samples
Client 1: 2892 samples
Client 96: 2892 samples
Client 71: 2892 samples
Client 38: 2892 samples
Client 116: 2892 samples
Client 140: 2892 samples
Client 3: 2892 samples
Client 42: 

HBox(children=(FloatProgress(value=0.0, max=2586.0), HTML(value='')))


Test Loss: 2.332980

F1 Score: 0.072108 (macro) 0.279542 (weighted) %


Final Test  Accuracy: 0.349 (18059.0/51704.0)
Clients: 71/143 -> [ 59  13  85 116  88 126  71 119 137  68  36 104  58 135  69  27  37   3
 107  79 130  22  14 136  28  29  20  35  99  30 100  42  73  54  24 101
  51 108  94 118   1   5  91  64  65  46  34  74 113 112  67  15  86  50
  78  87  19  33  41  63   4  26  92 109 129  16 110 142 103  90  44]


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))

Client 59: 2892 samples
Client 13: 2892 samples
Client 85: 2892 samples
Client 116: 2892 samples
Client 88: 2892 samples
Client 126: 2892 samples
Client 71: 2892 samples
Client 119: 2892 samples
Client 137: 2892 samples
Client 68: 2892 samples
Client 36: 2892 samples
Client 104: 2892 samples
Client 58: 2892 samples
Client 135: 2892 samples
Client 69: 2892 samples
Client 27: 2892 samples
Client 37: 2892 samples
Client 3: 2892 samples
Client 107: 2892 samples
Client 79: 2892 samples
Client 130: 2892 samples
Client 22: 2892 samples
Client 14: 2892 samples
Client 136: 2892 samples
Client 28: 2892 samples
Client 29: 2892 samples
Client 20: 2892 samples
Client 35: 2892 samples
Client 99: 2892 samples
Client 30: 2892 samples
Client 100: 2892 samples
Client 42: 2892 samples
Client 73: 2892 samples
Client 54: 2892 samples
Client 24: 2892 samples
Client 101: 2892 samples
Client 51: 2892 samples
Client 108: 2892 samples
Client 94: 2892 samples
Client 118: 2892 samples
Client 1: 2892 samples
Clien

HBox(children=(FloatProgress(value=0.0, max=2586.0), HTML(value='')))


Test Loss: 2.274479

F1 Score: 0.078728 (macro) 0.298398 (weighted) %


Final Test  Accuracy: 0.361 (18663.0/51704.0)
Clients: 71/143 -> [ 85  65  34  71  31  28  88  13  48  21  53  87 114 100  35 130  24   5
  92  15  49  19  16 111  57  33  91  98  42  14  26   8 101  10 131 117
   1 133 105 137  58  32  41   0  12  73   9  82 103  76   7  44 140  90
 132  56 102  81 125 122  55 108  59   3 126  43   4 115  54  95  30]


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))

Client 85: 2892 samples
Client 65: 2892 samples
Client 34: 2892 samples
Client 71: 2892 samples
Client 31: 2892 samples
Client 28: 2892 samples
Client 88: 2892 samples
Client 13: 2892 samples
Client 48: 2892 samples
Client 21: 2892 samples
Client 53: 2892 samples
Client 87: 2892 samples
Client 114: 2892 samples
Client 100: 2892 samples
Client 35: 2892 samples
Client 130: 2892 samples
Client 24: 2892 samples
Client 5: 2892 samples
Client 92: 2892 samples
Client 15: 2892 samples
Client 49: 2892 samples
Client 19: 2892 samples
Client 16: 2892 samples
Client 111: 2892 samples
Client 57: 2892 samples
Client 33: 2892 samples
Client 91: 2892 samples
Client 98: 2892 samples
Client 42: 2892 samples
Client 14: 2892 samples
Client 26: 2892 samples
Client 8: 2892 samples
Client 101: 2892 samples
Client 10: 2892 samples
Client 131: 2892 samples
Client 117: 2892 samples
Client 1: 2892 samples
Client 133: 2892 samples
Client 105: 2892 samples
Client 137: 2892 samples
Client 58: 2892 samples
Client 32

HBox(children=(FloatProgress(value=0.0, max=2586.0), HTML(value='')))


Test Loss: 2.221982

F1 Score: 0.085083 (macro) 0.313551 (weighted) %


Final Test  Accuracy: 0.374 (19360.0/51704.0)
Clients: 71/143 -> [  8 101  75  73  29  62  94  26 140 108  43  83   7 112  63  33 132  23
  25 124  22 135  52 118   9  11 115 138  35  77  64 114 105  21 128 126
  51  91 123  95  28  70 127   4  60  39  16 122  46  90 141  81 133  88
  31  74 129 100  41  67  40 107  36  66  38  12 116  30  61 131  24]


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))

Client 8: 2892 samples
Client 101: 2892 samples
Client 75: 2892 samples
Client 73: 2892 samples
Client 29: 2892 samples
Client 62: 2892 samples
Client 94: 2892 samples
Client 26: 2892 samples
Client 140: 2892 samples
Client 108: 2892 samples
Client 43: 2892 samples
Client 83: 2892 samples
Client 7: 2892 samples
Client 112: 2892 samples
Client 63: 2892 samples
Client 33: 2892 samples
Client 132: 2892 samples
Client 23: 2892 samples
Client 25: 2892 samples
Client 124: 2892 samples
Client 22: 2892 samples
Client 135: 2892 samples
Client 52: 2892 samples
Client 118: 2892 samples
Client 9: 2892 samples
Client 11: 2892 samples
Client 115: 2892 samples
Client 138: 2892 samples
Client 35: 2892 samples
Client 77: 2892 samples
Client 64: 2892 samples
Client 114: 2892 samples
Client 105: 2892 samples
Client 21: 2892 samples
Client 128: 2892 samples
Client 126: 2892 samples
Client 51: 2892 samples
Client 91: 2892 samples
Client 123: 2892 samples
Client 95: 2892 samples
Client 28: 2892 samples
Clie

HBox(children=(FloatProgress(value=0.0, max=2586.0), HTML(value='')))


Test Loss: 2.178158

F1 Score: 0.088131 (macro) 0.324965 (weighted) %


Final Test  Accuracy: 0.383 (19809.0/51704.0)
Clients: 71/143 -> [ 39  10  19 113  77  50 100 117 142  96   0 141  61  84  43 122   9  98
  80  73 112  71  32 106  24 128  35  78  33  63 108  66 111  53 125  46
 126 138  85 123 127  74  89  95 109 132 139  21  51  48  58  62 129 136
  49   5 103  94 124  93  99  69  25   3  34   4 104 102  72  56 119]


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))

Client 39: 2892 samples
Client 10: 2892 samples
Client 19: 2892 samples
Client 113: 2892 samples
Client 77: 2892 samples
Client 50: 2892 samples
Client 100: 2892 samples
Client 117: 2892 samples
Client 142: 2892 samples
Client 96: 2892 samples
Client 0: 2892 samples
Client 141: 2892 samples
Client 61: 2892 samples
Client 84: 2892 samples
Client 43: 2892 samples
Client 122: 2892 samples
Client 9: 2892 samples
Client 98: 2892 samples
Client 80: 2892 samples
Client 73: 2892 samples
Client 112: 2892 samples
Client 71: 2892 samples
Client 32: 2892 samples
Client 106: 2892 samples
Client 24: 2892 samples
Client 128: 2892 samples
Client 35: 2892 samples
Client 78: 2892 samples
Client 33: 2892 samples
Client 63: 2892 samples
Client 108: 2892 samples
Client 66: 2892 samples
Client 111: 2892 samples
Client 53: 2892 samples
Client 125: 2892 samples
Client 46: 2892 samples
Client 126: 2892 samples
Client 138: 2892 samples
Client 85: 2892 samples
Client 123: 2892 samples
Client 127: 2892 samples
Cl

HBox(children=(FloatProgress(value=0.0, max=2586.0), HTML(value='')))


Test Loss: 2.141748

F1 Score: 0.090447 (macro) 0.330370 (weighted) %


Final Test  Accuracy: 0.387 (20010.0/51704.0)
Clients: 71/143 -> [ 80  93 129  65   1   7  74  35 130  53  17 104  95  58  96  75  28   0
 128  40  89   6  48  18  87 131  14  70 114  79  47  27  26 123  77   8
  67  29  73 108  92  59 122  52 142  33  15  45 126  36  88 115  25 109
  44  78 133 141 135  54  71 121  11  12  23 111  85 118 119   4 125]


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))

Client 80: 2892 samples
Client 93: 2892 samples
Client 129: 2892 samples
Client 65: 2892 samples
Client 1: 2892 samples
Client 7: 2892 samples
Client 74: 2892 samples
Client 35: 2892 samples
Client 130: 2892 samples
Client 53: 2892 samples
Client 17: 2892 samples
Client 104: 2892 samples
Client 95: 2892 samples
Client 58: 2892 samples
Client 96: 2892 samples
Client 75: 2892 samples
Client 28: 2892 samples
Client 0: 2892 samples
Client 128: 2892 samples
Client 40: 2892 samples
Client 89: 2892 samples
Client 6: 2892 samples
Client 48: 2892 samples
Client 18: 2892 samples
Client 87: 2892 samples
Client 131: 2892 samples
Client 14: 2892 samples
Client 70: 2892 samples
Client 114: 2892 samples
Client 79: 2892 samples
Client 47: 2892 samples
Client 27: 2892 samples
Client 26: 2892 samples
Client 123: 2892 samples
Client 77: 2892 samples
Client 8: 2892 samples
Client 67: 2892 samples
Client 29: 2892 samples
Client 73: 2892 samples
Client 108: 2892 samples
Client 92: 2892 samples
Client 59: 28

HBox(children=(FloatProgress(value=0.0, max=2586.0), HTML(value='')))


Test Loss: 2.108632

F1 Score: 0.098451 (macro) 0.344847 (weighted) %


Final Test  Accuracy: 0.397 (20541.0/51704.0)
Clients: 71/143 -> [137  17  51  74  47 113 101   8  40 132  60 120 107  43  96  48 131   7
  33 139  16  31 125  12   0  58   6  22 140  45  94  98  59  97   5 112
 128  23 118  93  34  63  64  28 110  83  70 138  26  19  75 102  72 105
 116  10  35  25  20  30  76 109  82 103  67  91 130  86   3  95  71]


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))

Client 137: 2892 samples
Client 17: 2892 samples
Client 51: 2892 samples
Client 74: 2892 samples
Client 47: 2892 samples
Client 113: 2892 samples
Client 101: 2892 samples
Client 8: 2892 samples
Client 40: 2892 samples
Client 132: 2892 samples
Client 60: 2892 samples
Client 120: 2892 samples
Client 107: 2892 samples
Client 43: 2892 samples
Client 96: 2892 samples
Client 48: 2892 samples
Client 131: 2892 samples
Client 7: 2892 samples
Client 33: 2892 samples
Client 139: 2892 samples
Client 16: 2892 samples
Client 31: 2892 samples
Client 125: 2892 samples
Client 12: 2892 samples
Client 0: 2892 samples
Client 58: 2892 samples
Client 6: 2892 samples
Client 22: 2892 samples
Client 140: 2892 samples
Client 45: 2892 samples
Client 94: 2892 samples
Client 98: 2892 samples
Client 59: 2892 samples
Client 97: 2892 samples
Client 5: 2892 samples
Client 112: 2892 samples
Client 128: 2892 samples
Client 23: 2892 samples
Client 118: 2892 samples
Client 93: 2892 samples
Client 34: 2892 samples
Client 6

HBox(children=(FloatProgress(value=0.0, max=2586.0), HTML(value='')))


Test Loss: 2.081649

F1 Score: 0.100926 (macro) 0.348966 (weighted) %


Final Test  Accuracy: 0.403 (20846.0/51704.0)
Clients: 71/143 -> [  8 141  73 116  15   0  18   3 122 100  50  13 102  53  91  23  42 111
  36  98  72  88 103  30  75  48  85  31  52 120   5  41 101  84 128 131
 105 104  82  34  60 117  19  24  61  58 109  86  62 133 132  96  35   2
  43 118  38  57  76  11  20  68 115  74 108 142 124 130 134  55 123]


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))

Client 8: 2892 samples
Client 141: 2892 samples
Client 73: 2892 samples
Client 116: 2892 samples
Client 15: 2892 samples
Client 0: 2892 samples
Client 18: 2892 samples
Client 3: 2892 samples
Client 122: 2892 samples
Client 100: 2892 samples
Client 50: 2892 samples
Client 13: 2892 samples
Client 102: 2892 samples
Client 53: 2892 samples
Client 91: 2892 samples
Client 23: 2892 samples
Client 42: 2892 samples
Client 111: 2892 samples
Client 36: 2892 samples
Client 98: 2892 samples
Client 72: 2892 samples
Client 88: 2892 samples
Client 103: 2892 samples
Client 30: 2892 samples
Client 75: 2892 samples
Client 48: 2892 samples
Client 85: 2892 samples
Client 31: 2892 samples
Client 52: 2892 samples
Client 120: 2892 samples
Client 5: 2892 samples
Client 41: 2892 samples
Client 101: 2892 samples
Client 84: 2892 samples
Client 128: 2892 samples
Client 131: 2892 samples
Client 105: 2892 samples
Client 104: 2892 samples
Client 82: 2892 samples
Client 34: 2892 samples
Client 60: 2892 samples
Client 

HBox(children=(FloatProgress(value=0.0, max=2586.0), HTML(value='')))


Test Loss: 2.052279

F1 Score: 0.105824 (macro) 0.358565 (weighted) %


Final Test  Accuracy: 0.407 (21037.0/51704.0)
Clients: 71/143 -> [ 55 118  20  49 140  15 115  22  88 105   3  78  73  35 125  64 122 129
 119  89 108 128  43  16  21  57  13 139  32  48 138 131  42 100  91  40
   4  30  37   1  18   9  17  46  93 127  79  70  74  92  59  72 112  87
  58  50  80 130  27  14 113  60   2  63  54 123  76 103  69  77 114]


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))

Client 55: 2892 samples
Client 118: 2892 samples
Client 20: 2892 samples
Client 49: 2892 samples
Client 140: 2892 samples
Client 15: 2892 samples
Client 115: 2892 samples
Client 22: 2892 samples
Client 88: 2892 samples
Client 105: 2892 samples
Client 3: 2892 samples
Client 78: 2892 samples
Client 73: 2892 samples
Client 35: 2892 samples
Client 125: 2892 samples
Client 64: 2892 samples
Client 122: 2892 samples
Client 129: 2892 samples
Client 119: 2892 samples
Client 89: 2892 samples
Client 108: 2892 samples
Client 128: 2892 samples
Client 43: 2892 samples
Client 16: 2892 samples
Client 21: 2892 samples
Client 57: 2892 samples
Client 13: 2892 samples
Client 139: 2892 samples
Client 32: 2892 samples
Client 48: 2892 samples
Client 138: 2892 samples
Client 131: 2892 samples
Client 42: 2892 samples
Client 100: 2892 samples
Client 91: 2892 samples
Client 40: 2892 samples
Client 4: 2892 samples
Client 30: 2892 samples
Client 37: 2892 samples
Client 1: 2892 samples
Client 18: 2892 samples
Clien

HBox(children=(FloatProgress(value=0.0, max=2586.0), HTML(value='')))


Test Loss: 2.028951

F1 Score: 0.110174 (macro) 0.363117 (weighted) %


Final Test  Accuracy: 0.413 (21360.0/51704.0)
Clients: 71/143 -> [  6  68  65  63  18  98  72  27  20  45 137  31  51  71  32   8  50 105
  66 142 141 112 118 109 110 130 101  29  56  28 127  62  23  58 116  78
  52  12  69  30  49 123 131  25  53 124   9  59  38  24  61 122  84  33
  85  60  14 102  95 115 100 132  54  90  94  42   5   4  48  81 121]


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))

Client 6: 2892 samples
Client 68: 2892 samples
Client 65: 2892 samples
Client 63: 2892 samples
Client 18: 2892 samples
Client 98: 2892 samples
Client 72: 2892 samples
Client 27: 2892 samples
Client 20: 2892 samples
Client 45: 2892 samples
Client 137: 2892 samples
Client 31: 2892 samples
Client 51: 2892 samples
Client 71: 2892 samples
Client 32: 2892 samples
Client 8: 2892 samples
Client 50: 2892 samples
Client 105: 2892 samples
Client 66: 2892 samples
Client 142: 2892 samples
Client 141: 2892 samples
Client 112: 2892 samples
Client 118: 2892 samples
Client 109: 2892 samples
Client 110: 2892 samples
Client 130: 2892 samples
Client 101: 2892 samples
Client 29: 2892 samples
Client 56: 2892 samples
Client 28: 2892 samples
Client 127: 2892 samples
Client 62: 2892 samples
Client 23: 2892 samples
Client 58: 2892 samples
Client 116: 2892 samples
Client 78: 2892 samples
Client 52: 2892 samples
Client 12: 2892 samples
Client 69: 2892 samples
Client 30: 2892 samples
Client 49: 2892 samples
Client

HBox(children=(FloatProgress(value=0.0, max=2586.0), HTML(value='')))


Test Loss: 2.007242

F1 Score: 0.117510 (macro) 0.366815 (weighted) %


Final Test  Accuracy: 0.419 (21644.0/51704.0)
Clients: 71/143 -> [114  87 138  14  57  22  68  18  51  72  40  79  65  39  46  67  83  16
 139 104  52  80  20 106   1 105  26  33  53  94 102  44 133  43  37  12
  69   4  30 135 100  27   9 132  59 125  99  34 103  36 107  54  58 101
  56  63 115  32 112  60 109  88  45  89  82   3  13   8  70  74 131]


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))

Client 114: 2892 samples
Client 87: 2892 samples
Client 138: 2892 samples
Client 14: 2892 samples
Client 57: 2892 samples
Client 22: 2892 samples
Client 68: 2892 samples
Client 18: 2892 samples
Client 51: 2892 samples
Client 72: 2892 samples
Client 40: 2892 samples
Client 79: 2892 samples
Client 65: 2892 samples
Client 39: 2892 samples
Client 46: 2892 samples
Client 67: 2892 samples
Client 83: 2892 samples
Client 16: 2892 samples
Client 139: 2892 samples
Client 104: 2892 samples
Client 52: 2892 samples
Client 80: 2892 samples
Client 20: 2892 samples
Client 106: 2892 samples
Client 1: 2892 samples
Client 105: 2892 samples
Client 26: 2892 samples
Client 33: 2892 samples
Client 53: 2892 samples
Client 94: 2892 samples
Client 102: 2892 samples
Client 44: 2892 samples
Client 133: 2892 samples
Client 43: 2892 samples
Client 37: 2892 samples
Client 12: 2892 samples
Client 69: 2892 samples
Client 4: 2892 samples
Client 30: 2892 samples
Client 135: 2892 samples
Client 100: 2892 samples
Client 2

HBox(children=(FloatProgress(value=0.0, max=2586.0), HTML(value='')))


Test Loss: 1.983722

F1 Score: 0.122407 (macro) 0.375355 (weighted) %


Final Test  Accuracy: 0.423 (21859.0/51704.0)
Clients: 71/143 -> [105 107  65   1  95  34  10 120  76  22  59  75  30  98  17  91 117  33
  81  58  20 123   4  71  35   2  92  41  77  12 140  63 103  94  21  83
  73 126 138 113 101  99  32  18 121  25 104   8 137  43  78  60  42  28
  15  55  46  38 131  57   9 111 116 132 112 135  64  39   6  88  40]


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))

Client 105: 2892 samples
Client 107: 2892 samples
Client 65: 2892 samples
Client 1: 2892 samples
Client 95: 2892 samples
Client 34: 2892 samples
Client 10: 2892 samples
Client 120: 2892 samples
Client 76: 2892 samples
Client 22: 2892 samples
Client 59: 2892 samples
Client 75: 2892 samples
Client 30: 2892 samples
Client 98: 2892 samples
Client 17: 2892 samples
Client 91: 2892 samples
Client 117: 2892 samples
Client 33: 2892 samples
Client 81: 2892 samples
Client 58: 2892 samples
Client 20: 2892 samples
Client 123: 2892 samples
Client 4: 2892 samples
Client 71: 2892 samples
Client 35: 2892 samples
Client 2: 2892 samples
Client 92: 2892 samples
Client 41: 2892 samples
Client 77: 2892 samples
Client 12: 2892 samples
Client 140: 2892 samples
Client 63: 2892 samples
Client 103: 2892 samples
Client 94: 2892 samples
Client 21: 2892 samples
Client 83: 2892 samples
Client 73: 2892 samples
Client 126: 2892 samples
Client 138: 2892 samples
Client 113: 2892 samples
Client 101: 2892 samples
Client 9

HBox(children=(FloatProgress(value=0.0, max=2586.0), HTML(value='')))


Test Loss: 1.964554

F1 Score: 0.130020 (macro) 0.382373 (weighted) %


Final Test  Accuracy: 0.428 (22134.0/51704.0)
Clients: 71/143 -> [ 12  21  76   0  65 135   8  61  78  11  74  50 101 122  90 107  91 130
  42 133  17  66  46  60  57 141 102  34  85 103  56 124  49 105  98  82
  47   4 138   3 109 112 113  97  86  95  87  58  13  38  72  80 119 139
  89  24  25 114 132 123  37  41 140 127 128  68   1 129   6  77 110]


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))

Client 12: 2892 samples
Client 21: 2892 samples
Client 76: 2892 samples
Client 0: 2892 samples
Client 65: 2892 samples
Client 135: 2892 samples
Client 8: 2892 samples
Client 61: 2892 samples
Client 78: 2892 samples
Client 11: 2892 samples
Client 74: 2892 samples
Client 50: 2892 samples
Client 101: 2892 samples
Client 122: 2892 samples
Client 90: 2892 samples
Client 107: 2892 samples
Client 91: 2892 samples
Client 130: 2892 samples
Client 42: 2892 samples
Client 133: 2892 samples
Client 17: 2892 samples
Client 66: 2892 samples
Client 46: 2892 samples
Client 60: 2892 samples
Client 57: 2892 samples
Client 141: 2892 samples
Client 102: 2892 samples
Client 34: 2892 samples
Client 85: 2892 samples
Client 103: 2892 samples
Client 56: 2892 samples
Client 124: 2892 samples
Client 49: 2892 samples
Client 105: 2892 samples
Client 98: 2892 samples
Client 82: 2892 samples
Client 47: 2892 samples
Client 4: 2892 samples
Client 138: 2892 samples
Client 3: 2892 samples
Client 109: 2892 samples
Client 

In [None]:
# File the histories of all runs, together with the hyperparameters, under the
# experiment title.
exp_log[title] = dict(history=exp_history, hyperparams=vars(hparams))

In [None]:
# Average the run histories: turn each history into a DataFrame, add the
# frames element-wise in run order, then divide by the number of runs.
df = reduce(
    lambda total, frame: total + frame,
    (pd.json_normalize(run_history) for run_history in exp_history),
)

df_avg = df / len(exp_history)
avg_history = df_avg.to_dict(orient='records')

In [None]:
# Plot the scores from the run-averaged history.
plot_scores(
    history=avg_history,
    exp_id=exp_id,
    title=title,
    suffix='IID',
)

In [None]:
# Plot the losses from the run-averaged history.
plot_losses(
    history=avg_history,
    exp_id=exp_id,
    title=title,
    suffix='IID',
)

In [None]:
# Write the collected experiment log into the experiment directory.
results_path = f'{EXP_DIR}/results_iid.pkl'
with open(results_path, 'wb') as results_file:
    pickle.dump(exp_log, results_file)

### Non-IID

In [None]:
# Reset the log to an empty dict before the Non-IID section.
exp_log = {}

In [None]:
# Non-IID split: the partition call yields the per-client data and the test set.
noniid_split = noniid_partition(corpus, seq_length=seq_length, val_split=True)
data_dict, test_ds = noniid_split

# One entry in data_dict per client.
total_clients = len(data_dict)
'Total users:', total_clients

In [None]:
# Rebuild the hyperparameter set for the Non-IID client count and display it.
hparams = Hyperparameters(total_clients)
vars(hparams)

In [None]:
# Sweeping parameter: the experiment id has the form "<parameter name>/<value>".
PARAM_NAME = 'clients_fraction'
PARAM_VALUE = hparams.C
exp_id = '{}/{}'.format(PARAM_NAME, PARAM_VALUE)
exp_id

In [None]:
title = 'LSTM FedMed on Non-IID'

# Output directory of this experiment; created when it does not exist yet.
EXP_DIR = '{}/{}'.format(BASE_DIR, exp_id)
os.makedirs(EXP_DIR, exist_ok=True)

# TensorBoard logging is currently disabled:
# tb_logger = SummaryWriter(log_dir)
# print(f'TBoard logger created at: {log_dir}')

In [None]:
def run_experiment(run_id):
    """Train and evaluate one repetition of the experiment and save its log.

    Builds a fresh ``ShakespeareLSTM``, trains it through ``training`` using
    the notebook-level ``hparams``, ``data_dict`` and ``test_ds``, evaluates
    the returned model with ``testing``, prints the final scores and pickles
    the per-run log to ``{EXP_DIR}/results_niid_{run_id}.pkl``.

    Args:
        run_id: Identifier of this repetition; used in the printed summary
            and in the name of the pickle file.

    Returns:
        The history object returned by ``training`` for this run.
    """
    model = ShakespeareLSTM(input_dim=seq_length,
                            embedding_dim=embedding_dim,
                            hidden_dim=hidden_dim,
                            classes=num_classes,
                            lstm_layers=lstm_layers,
                            dropout=dropout,
                            batch_first=True)

    if torch.cuda.is_available():
        model.cuda()

    # A fresh list per run so histories of different runs never share state.
    trained_model, test_history = training(model,
                                           hparams.rounds, hparams.batch_size, hparams.lr,
                                           None,
                                           data_dict,
                                           test_ds,
                                           hparams.C, hparams.K, hparams.E, hparams.M,
                                           title, "green",
                                           corpus,  # classes
                                           history=[],
                                           )

    # Use the batch size from hparams (the same source training uses and the
    # one stored in the log) rather than an unrelated module-level variable.
    final_scores = testing(trained_model, test_ds, hparams.batch_size * 2,
                           nn.CrossEntropyLoss(), len(corpus), corpus)
    print(f'\n\n========================================================\n\n')
    print(f'Final scores for Exp {run_id} \n {final_scores}')

    log = {
        'history': test_history,
        'hyperparams': hparams.__dict__
    }

    # Save each run separately so a later crash does not lose finished runs.
    with open(f'{EXP_DIR}/results_niid_{run_id}.pkl', 'wb') as file:
        pickle.dump(log, file)

    return test_history


In [None]:
# Repeat the experiment and keep one history per run for later averaging.
exp_history = []
for run_id in range(2):  # TOTAL RUNS
    print(f'============== RUNNING EXPERIMENT #{run_id} ==============')
    run_history = run_experiment(run_id)
    exp_history.append(run_history)
    print(f'\n\n========================================================\n\n')

In [None]:
# Record every run's history together with the hyperparameters in effect,
# keyed by the experiment title.
exp_log[title] = dict(
    history=exp_history,
    hyperparams=hparams.__dict__,
)

In [None]:
# Flatten each run's history into a DataFrame, add the frames element-wise,
# then divide by the number of runs to get the mean across runs.
run_frames = [pd.json_normalize(run_history) for run_history in exp_history]
df = reduce(lambda total, frame: total + frame, run_frames) if run_frames else None

df_avg = df / len(exp_history)
avg_history = df_avg.to_dict(orient='records')

In [None]:
# Plot the scores from the run-averaged history (plot_scores is defined elsewhere in
# this notebook); the 'nonIID' suffix presumably tags the figure/output — confirm in plot_scores.
plot_scores(history=avg_history, exp_id=exp_id, title=title, suffix='nonIID')

In [None]:
# Plot the losses from the run-averaged history (plot_losses is defined elsewhere in
# this notebook); the 'nonIID' suffix presumably tags the figure/output — confirm in plot_losses.
plot_losses(history=avg_history, exp_id=exp_id, title=title, suffix='nonIID')

### Pickle Experiment Results

In [None]:
# Persist the aggregated non-IID experiment log next to the per-run artifacts.
results_path = f'{EXP_DIR}/results_niid.pkl'
with open(results_path, 'wb') as results_file:
    pickle.dump(exp_log, results_file)