In [1]:
import copy
import math
import pickle
import shutil

import matplotlib.pyplot as plt
import numpy  as np
import pandas as pd
import scipy.io
import torch
import torch.nn.functional as F
from sklearn.metrics import f1_score, accuracy_score
from torch.utils.tensorboard import SummaryWriter
from tqdm import notebook

# GPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# reproducibility
np.random.seed(0)
torch.manual_seed(0)

# tensorboard
# %load_ext tensorboard
%reload_ext tensorboard
# if running in linux, please replace "rmdir /s /q" with "rm -rf" 
!rmdir /s /q runs
writer = SummaryWriter()

cuda


## Data preprocessing

In [2]:
# which data to use
use_affectnet = True
use_openface  = True

# read fold_ids.csv
fold_ids_df = pd.read_csv("../fold_ids.csv")

user_ids = fold_ids_df["participant_id"]
uuids    = fold_ids_df["uuid"]
mws      = fold_ids_df["MW"]

all_ids    = []
all_labels = []
all_data   = []

# not using data from all frames 
seq_length = 300               # 10 seconds * 30 FPS
sampling   = 10                # only 30 frames are used out of 300 frames
frame_idx  = np.arange(0, seq_length, sampling)   # rows kept from every clip


def load_affectnet_features(uuid):
    """Return the "features" array stored in the affectnet .mat file of one clip."""
    return scipy.io.loadmat("../features/affectnet/" + uuid + "_faces.mat")["features"]


def load_openface_features(uuid):
    """Return the openface table of one clip without its first five columns."""
    openface_pd = pd.read_csv("../features/Openface_features/" + uuid + ".csv")
    return openface_pd.drop(openface_pd.columns[:5], axis=1)      # drop first five columns as they are not needed


for user_id, uuid, mw in zip(user_ids, uuids, mws):
    if use_affectnet and not use_openface: # using affectnet features
        affectnet_features = load_affectnet_features(uuid)
        usable = affectnet_features.shape[0] == seq_length
        if usable:
            features = affectnet_features[frame_idx]

    elif use_openface and not use_affectnet: # using openface features
        openface_features = load_openface_features(uuid)
        usable = openface_features.shape[0] == seq_length
        if usable:
            features = openface_features.iloc[frame_idx]

    else: # using both features (this branch is also taken when both flags are False)
        affectnet_features = load_affectnet_features(uuid)
        openface_features  = load_openface_features(uuid)
        n_rows = affectnet_features.shape[0]
        # Drop clips whose two feature files disagree on the number of rows and
        # clips that are too short to be sub-sampled: indexing rows up to
        # seq_length - sampling on a shorter clip would raise an IndexError.
        # NOTE(review): the single-feature branches require exactly seq_length
        # rows, while longer clips are accepted (and truncated) here - confirm
        # which rule is intended.
        usable = n_rows == openface_features.shape[0] and n_rows >= seq_length
        if usable:
            combined_features = np.concatenate((affectnet_features, openface_features), axis = 1)
            features = combined_features[frame_idx]

    if not usable:
        continue

    all_ids.append(user_id)
    # label is 1.0 when the MW column equals 'n' and 0.0 for every other value
    # NOTE(review): confirm that 'n' is meant to be the positive class
    all_labels.append(1.0 if mw == 'n' else 0.0)
    all_data.append(features)

assert len(all_data) > 0, "no usable sample found - check the feature files and seq_length"

# list to torch tensors
all_labels = torch.tensor(all_labels)
all_data   = torch.tensor(np.array(all_data)).float()       # (df -> np array -> torch tensor) faster than (df -> torch tensor)

# data normalization: per-feature statistics over all samples and all frames
data_mean = all_data.mean(axis=(0,1))
data_std  = all_data.std (axis=(0,1))
all_data  = (all_data - data_mean) / (data_std + 1e-6)      # epsilon guards against constant features

# some prints
print(all_data.shape)
print(all_labels.unique(return_counts=True))

torch.Size([1209, 30, 2757])
(tensor([0., 1.]), tensor([1006,  203]))


### Helper functions

In [3]:
# average model parameters
def weighted_avg_params(params, weights = None):
    """Return the weighted average of several state dicts.

    @ params  : non-empty sequence of state dicts (key -> tensor) that share the same keys
    @ weights : one weight per state dict; None means that all state dicts count equally

    @ return  : a deep copy of params[0] in which every entry is replaced by
                sum_i(weights[i] * params[i][key]) / sum(weights)

    Raises ValueError when params is empty or when the number of weights does
    not match the number of state dicts. The input state dicts are not modified.
    """
    if len(params) == 0:
        raise ValueError("params must contain at least one state dict")
    if weights is None:
        weights = [1.0] * len(params)
    if len(weights) != len(params):
        # a silent mismatch would normalise by weights that were never applied
        raise ValueError("got {} weights for {} state dicts".format(len(weights), len(params)))

    total_weight = sum(weights)

    # the deep copy keeps the container type and any attributes attached to it
    params_avg = copy.deepcopy(params[0])
    for key in params_avg.keys():
        # accumulate out of place: an in-place `*=` with a float weight fails on integer tensors
        weighted_sum = params[0][key] * weights[0]
        for i in range(1, len(params)):
            weighted_sum = weighted_sum + params[i][key] * weights[i]
        params_avg[key] = torch.div(weighted_sum, total_weight)
    return params_avg

# get weights for each class
def get_weight_dict(labels):
    """Map every distinct label to its normalised inverse-frequency weight.

    @ labels : tensor of labels

    @ return : dict {label value: weight}; a raw weight is (number of labels) /
               (number of occurrences of that label), and the raw weights are
               divided by their sum so that the returned weights add up to one.
               The weights are 0-dim tensors.
    """
    classes, class_counts = labels.unique(return_counts=True)
    n_total = sum(class_counts)

    # inverse frequency: the rarer a class, the larger its raw weight
    raw_weights = {cls.item(): n_total / cnt for cls, cnt in zip(classes, class_counts)}

    # normalize weights
    norm = sum(raw_weights.values())
    return {cls: raw / norm for cls, raw in raw_weights.items()}

# for a vector of labels, return a vector of same length of weights
def get_weight_vector(weight_dict, labels):
    """Look up the class weight of every label.

    @ weight_dict : dict {label value: weight}
    @ labels      : tensor of labels; every value must be a key of weight_dict

    @ return      : tensor with one weight per label, on the same device as labels
    """
    per_sample = []
    for label in labels:
        per_sample.append(weight_dict[label.item()])
    return torch.tensor(per_sample, device = labels.device)

def weighted_avg(values, weights):
    """Return sum(value * weight) / sum(weights) for paired values and weights."""
    weighted_total = sum(v * w for v, w in zip(values, weights))
    return weighted_total / sum(weights)

### Dataset class

In [4]:
class DatasetMW(torch.utils.data.Dataset):
    """Dataset that yields (label, weight, feature) triples.

    @ labels      : tensor of labels, one per sample
    @ weight_dict : dict {label value: class weight}; it is read, never modified
    @ data        : per-sample features, indexed the same way as labels
    @ batch_size  : factor by which every class weight is multiplied
    """
    def __init__(self, labels, weight_dict, data, batch_size):
        self.labels      = labels
        self.data        = data
        # Scale the class weights into a NEW dict. The weights handed in are
        # tensors, so the former in-place `weight *= batch_size` changed the
        # caller's dict as well: a dict shared by several datasets was rescaled
        # once per dataset and its weights grew with every client.
        self.weight_dict = {label: weight * batch_size for label, weight in weight_dict.items()}
        self.weights = get_weight_vector(self.weight_dict, labels)
        
    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        label   = self.labels [idx]
        weight  = self.weights[idx]
        feature = self.data   [idx]     
        return label, weight, feature

## LSTM model

In [5]:
# LSTM model 
class LSTM(torch.nn.Module):
    """LSTM encoder followed by an MLP head that outputs one value per sequence.

    The initial hidden and cell states h0 / c0 are learnable parameters that
    are created with a fixed batch dimension, so forward() only accepts inputs
    whose batch dimension equals the batch_size given at construction.
    """
    def __init__(self, 
                 i_size  = all_data[0].shape[1] , # input dimension, for this project it should be 709 (openface features) or 2048 (affectnet features) or 2757 (both)
                 h_size  = 128                  , # hidden state dimension
                 n_layer = 1                    , # number of hidden layers
                 bi_dir  = True                 , # bi-directional LSTM or not
                 dropout = 0.5                  , # dropout rate between stacked LSTM layers (the MLP head is built without it)
                 batch_size = 4                 ,
                 using_hn   = True              , # use hidden state or output of LSTM as input for mlp
                ):
        super(LSTM, self).__init__()
        
        D = 2 if bi_dir else 1
        
        # learnable initial states, shape (D * n_layer, batch_size, h_size)
        self.h0 = torch.nn.Parameter(torch.randn(D * n_layer, batch_size, h_size), requires_grad=True)
        self.c0 = torch.nn.Parameter(torch.randn(D * n_layer, batch_size, h_size), requires_grad=True)
        
        # torch.nn.LSTM applies dropout only between stacked layers; with a single
        # layer a non-zero rate has no effect and only triggers a warning
        lstm_dropout = dropout if n_layer > 1 else 0.0
        self.lstm = torch.nn.LSTM(input_size = i_size, hidden_size = h_size, num_layers = n_layer, bidirectional = bi_dir, dropout = lstm_dropout, batch_first = True)
        # the head sees all final hidden states (using_hn) or the output of the last time step
        self.mlp  = MLP(D * n_layer * h_size) if using_hn else MLP(D * h_size)
        
        self.using_hn = using_hn
    
    def forward(self, x):
        """Run the LSTM on x (batch first) and return the MLP head output for every sequence."""
        output, (hn, cn) = self.lstm(x, (self.h0, self.c0))
        
        # using hidden state of LSTM as input to MLP
        if self.using_hn:
            # (D * n_layer, batch, h_size) -> (batch, D * n_layer * h_size)
            hn = hn.permute(1, 0, 2).flatten(1)
            hn = self.mlp(hn)
            return hn
        # using output of LSTM as input to MLP
        else:
            output = output[:, -1, :]          # many-to-one LSTM
            output = self.mlp(output)
            return output
    
# MLP model
class MLP(torch.nn.Module):
    """Two-layer head: Linear(input_size, 16) -> dropout -> ReLU -> Linear(16, 1) -> sigmoid."""
    def __init__(self, input_size):
        super().__init__()
        self.layer0  = torch.nn.Linear(input_size, 16)
        self.layer1  = torch.nn.Linear(16, 1)
        self.relu    = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()
        self.dropout = torch.nn.Dropout(p = 0.5)

    def forward(self, x):
        """Return one value in (0, 1) per input row."""
        hidden = self.layer0(x)
        hidden = self.dropout(hidden)      # dropout is applied before the non-linearity
        hidden = self.relu(hidden)
        return self.sigmoid(self.layer1(hidden))

## Client

### Client class

In [6]:
# client is an object, which contains its own dataset, dataloader, local model copy (for simplification, only local parameters) and training loop function
class Client(object):
    """One simulated federated-learning participant.

    @ client_id        : identifier, used as prefix of the tensorboard tags
    @ client_labels    : labels of this client
    @ client_data      : features of this client
    @ weight_dict      : dict {label value: class weight} handed to DatasetMW
    @ batch_size       : local batch size
    @ train_test_split : local train share; any value outside (0, 1) means not doing split
    @ global_model     : model that is deep-copied as local model, None creates a fresh LSTM
    @ learning_rate    : learning rate of the local Adam optimizer
    @ l2_regular       : weight decay of the local Adam optimizer
    @ client_epoch     : number of epochs run by one call of local_train
    @ reuse_optimizer  : whether the optimizer state is kept between calls of local_train

    Both data loaders drop the last incomplete batch, so every batch holds
    exactly batch_size samples.
    """
    def __init__(self, client_id, client_labels, client_data, weight_dict, batch_size, train_test_split, global_model, learning_rate, l2_regular, client_epoch, reuse_optimizer = True):
        super(Client, self).__init__()
        self.client_id     = client_id
        self.client_labels = client_labels
        self.client_data   = client_data
        self.weight_dict   = weight_dict
        self.batch_size    = batch_size
        self.client_epoch  = client_epoch
        self.learning_rate = learning_rate
        self.l2_regular    = l2_regular
        self.optim_state   = None
        self.reuse_optimizer   = reuse_optimizer
        self.train_epoch_count = 0
        self.test_epoch_count  = 0
        
        # dataset
        dataset = DatasetMW(self.client_labels, self.weight_dict, self.client_data, self.batch_size)

        # train test split
        self.client_split = 0 < train_test_split < 1
        if self.client_split:
            # the larger share is always used for training
            train_test_split = max(train_test_split, 1 - train_test_split)
            train_size = int(train_test_split * len(dataset))
            test_size  = len(dataset) - train_size
            self.train_dataset, self.test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
        else:
            self.train_dataset = dataset
            self.test_dataset  = None

        # dataloader
        self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size = batch_size, shuffle = True, drop_last=True)
        # without a split there is no test set: building a shuffling DataLoader
        # around None would fail, so the loader is only created when needed
        self.test_loader = None
        if self.client_split:
            self.test_loader = torch.utils.data.DataLoader(self.test_dataset, batch_size = batch_size, shuffle = True, drop_last=True)
        
        # client model
        if global_model is None:
            self.client_model = LSTM(batch_size = batch_size)
        else:
            self.client_model = copy.deepcopy(global_model)
        
        # recording loss, accuracy, and f1 score during training and test
        self.train_loss = []
        self.train_accu = []
        self.train_f1   = []
        self.test_loss  = []
        self.test_accu  = []
        self.test_f1    = []

    ############################### shared helpers ################################
    def _forward_batch(self, labels, weights, sequences):
        """Move one batch to `device` and return (labels, predictions, weighted BCE loss)."""
        labels    = labels   .to(device)
        weights   = weights  .to(device)
        sequences = sequences.to(device)
        predicts  = self.client_model(sequences).view(-1)
        loss = F.binary_cross_entropy(predicts, labels, reduction = 'mean', weight = weights)
        return labels, predicts, loss

    @staticmethod
    def _epoch_metrics(epoch_labels, epoch_predicts, sum_batch_loss, sum_batch_len):
        """Turn the per-batch records of one pass into (loss, accuracy, f1 score)."""
        if sum_batch_len == 0:
            raise RuntimeError("the data loader produced no batch: the dataset holds fewer samples than batch_size")
        epoch_label   = torch.cat(epoch_labels  ).detach().to('cpu').numpy()
        epoch_predict = torch.cat(epoch_predicts).detach().to('cpu').numpy()
        epoch_loss    = sum_batch_loss / sum_batch_len
        epoch_accu    = accuracy_score(epoch_label, epoch_predict)
        epoch_f1      = f1_score      (epoch_label, epoch_predict, zero_division = 0)
        return epoch_loss, epoch_accu, epoch_f1

    ############################### local train ###################################
    def local_train(self):
        """ local training
        
        @ return      : avg training loss per epoch, avg training accuracy per epoch, avg training f1 score per epoch
        """
        
        # model to GPU
        self.client_model.to(device)
        self.client_model.train()
        
        # optimizer, and load previous optimizer state
        optimizer = torch.optim.Adam(self.client_model.parameters(), lr = self.learning_rate, weight_decay = self.l2_regular)
        if self.reuse_optimizer and self.optim_state is not None:
            optimizer.load_state_dict(self.optim_state)
        
        for epoch in range(self.client_epoch):
            sum_batch_loss = 0
            sum_batch_len  = 0
            epoch_labels   = []
            epoch_predicts = []
            
            for batch_id, (labels, weights, sequences) in enumerate(self.train_loader):
                sum_batch_len += len(labels)
                optimizer.zero_grad()
                
                labels, predicts, loss = self._forward_batch(labels, weights, sequences)
                loss.backward()
                # torch.nn.utils.clip_grad_norm_(self.client_model.parameters(), 1.0)          # clip gradient
                optimizer.step()
                
                # batch losses are means, so they are re-weighted by the batch length
                sum_batch_loss += loss.item() * len(labels)
                epoch_labels  .append(labels)
                # detached so that the stored predictions do not keep the autograd graph alive
                epoch_predicts.append(predicts.detach().round())
            
            # calculate loss, accuracy, and f1 score
            epoch_loss, epoch_accu, epoch_f1 = self._epoch_metrics(epoch_labels, epoch_predicts, sum_batch_loss, sum_batch_len)
            
            # add to tensorboard
            writer.add_scalar(f"{self.client_id}/train/loss", epoch_loss, self.train_epoch_count)
            writer.add_scalar(f"{self.client_id}/train/accu", epoch_accu, self.train_epoch_count)
            writer.add_scalar(f"{self.client_id}/train/f1"  , epoch_f1  , self.train_epoch_count)
            writer.flush()
            self.train_epoch_count += 1
        
            self.train_loss.append(epoch_loss)
            self.train_accu.append(epoch_accu)
            self.train_f1  .append(epoch_f1  )
            
        # model to CPU and save optimizer state
        self.client_model.to('cpu')
        if self.reuse_optimizer:
            self.optim_state = optimizer.state_dict()
            
        # average over the epochs of this call only
        avg_epoch_loss = sum(self.train_loss[-self.client_epoch:]) / self.client_epoch
        avg_epoch_accu = sum(self.train_accu[-self.client_epoch:]) / self.client_epoch
        avg_epoch_f1   = sum(self.train_f1  [-self.client_epoch:]) / self.client_epoch

        return avg_epoch_loss, avg_epoch_accu, avg_epoch_f1
    
    ############################### local test ###################################
    def local_test(self):
        """ local test
        
        Requires a local train test split (AssertionError otherwise).
        
        @ return      : avg test loss, avg test accuracy, avg test f1 score
        """
        
        assert self.client_split, "local_test needs a train_test_split inside (0, 1)"
        
        self.client_model.to(device)
        self.client_model.eval()
        sum_batch_loss = 0
        sum_batch_len  = 0
        epoch_labels   = []
        epoch_predicts = []
        
        with torch.no_grad():
            for batch_id, (labels, weights, sequences) in enumerate(self.test_loader):
                sum_batch_len += len(labels)
                labels, predicts, loss = self._forward_batch(labels, weights, sequences)
                sum_batch_loss += loss.item() * len(labels)

                epoch_labels  .append(labels)
                epoch_predicts.append(predicts.round())

        epoch_loss, epoch_accu, epoch_f1 = self._epoch_metrics(epoch_labels, epoch_predicts, sum_batch_loss, sum_batch_len)

        self.test_loss.append(epoch_loss)
        self.test_accu.append(epoch_accu)
        self.test_f1  .append(epoch_f1)

        writer.add_scalar(f"{self.client_id}/test/loss", epoch_loss, self.test_epoch_count)
        writer.add_scalar(f"{self.client_id}/test/accu", epoch_accu, self.test_epoch_count)
        writer.add_scalar(f"{self.client_id}/test/f1"  , epoch_f1  , self.test_epoch_count)
        writer.flush()
        self.test_epoch_count += 1
        
        self.client_model.to('cpu')
        
        return epoch_loss, epoch_accu, epoch_f1

### get_clients function

In [7]:
def get_clients(batch_size   , 
                global_model , 
                learning_rate, 
                l2_regular   , 
                client_epoch ,
                reuse_optimizer ,
                train_test_split, 
                all_ids    = all_ids, 
                all_labels = all_labels, 
                all_data   = all_data,
                global_preprocessing = False):
    """ build one Client per distinct id
    
    @ batch_size, global_model, learning_rate, l2_regular, client_epoch,
      reuse_optimizer, train_test_split : forwarded unchanged to every Client
    @ all_ids    : one client id per sample
    @ all_labels : tensor of labels, aligned with all_ids
    @ all_data   : tensor of features, aligned with all_ids
    @ global_preprocessing : True  -> normalization statistics and class weights are computed on all samples
                             False -> they are computed separately on the samples of each client
    
    @ return     : list of clients, ordered like np.unique(all_ids)
    """

    clients = []
    
    # global data imbalance and global normalization
    if global_preprocessing:
        data_mean   = all_data.mean(axis=(0,1))
        data_std    = all_data.std (axis=(0,1))
        all_data    = (all_data - data_mean) / (data_std + 1e-6)
        global_weight_dict = get_weight_dict(all_labels)
    
    # data entries for each client
    ids_array = np.asarray(all_ids)
    for client_id in np.unique(ids_array):
        # Select the rows by id. Slicing with a running offset is only correct
        # when the samples of a client are stored contiguously AND the clients
        # appear in the sorted order returned by np.unique.
        row_idx = torch.as_tensor(np.flatnonzero(ids_array == client_id))
        client_labels = all_labels[row_idx]
        client_data   = all_data  [row_idx]
        
        if global_preprocessing:
            # every client gets its own copy, so a consumer that rescales the
            # weights in place cannot affect the other clients
            weight_dict = copy.deepcopy(global_weight_dict)
        else:
            # local data imbalance and local normalization
            client_mean = client_data.mean(axis=(0,1))
            client_std  = client_data.std (axis=(0,1))
            client_data = (client_data - client_mean) / (client_std + 1e-6)
            weight_dict = get_weight_dict(client_labels)
            
        client = Client(client_id     = client_id     , 
                        client_labels = client_labels , 
                        client_data   = client_data   ,
                        weight_dict   = weight_dict   ,
                        batch_size    = batch_size    , 
                        global_model  = global_model  , 
                        learning_rate = learning_rate , 
                        l2_regular    = l2_regular    , 
                        client_epoch  = client_epoch  ,
                        reuse_optimizer  = reuse_optimizer  ,
                        train_test_split = train_test_split ,
                       )
                
        clients.append(client)
            
    return clients

## Global update

In [8]:
def federated_learning(global_epoch     = 500   ,
                       client_epoch     = 10    ,
                       update_C         = 0.1   ,
                       batch_size       = 4     ,
                       learning_rate    = 1e-5  ,
                       l2_regular       = 1e-1  ,
                       reuse_optimizer  = True  ,
                       train_test_split = 0.8   ,
                       global_init      = True ,
                       global_preprocessing = False,
                       known_client_data_size = True,
                      ):
    """ simulation of federated learning
    
    @ global_epoch    : number of rounds for global update
    @ client_epoch    : number of rounds for local  update (parameter E in paper)
    @ update_C        : proportion of clients that are updated in each global round (parameter C in paper)
    @ batch_size      : local batch size
    @ learning_rate   : local learning rate
    @ l2_regular      : local l2 regularization
    @ reuse_optimizer : whether the state of local optimizer is re-used 
    @ train_test_split: local train test split, any value outside (0, 1) means not doing split;
                        without a split the local and global test steps are skipped
    
    @ global_init           : whether all local models have the same initialization
    @ global_preprocessing  : whether data is normalized globally or locally
    @ known_client_data_size: whether the global center knows how much data each client has
    
    Progress is printed and written to the module-level tensorboard writer; nothing is returned.
    """
    
    # global model
    global_model = LSTM(batch_size = batch_size)
    
    # determine how many clients are updated per global round
    num_client        = len(np.unique(all_ids, return_counts=True)[0])
    num_update_client = min(max(math.ceil(update_C * num_client), 1), num_client) # number of clients to update per round
    print("total number of clients:", num_client, " number of clients to be updated in each global round:", num_update_client)
    print()
    
    clients = get_clients(batch_size       = batch_size    , 
                          learning_rate    = learning_rate , 
                          l2_regular       = l2_regular    , 
                          client_epoch     = client_epoch  ,
                          reuse_optimizer  = reuse_optimizer ,
                          train_test_split = train_test_split,
                          global_model     = global_model if global_init else None,
                          global_preprocessing = global_preprocessing
                         )
    
    # global test dataset and data loader (only available when the clients hold a test split)
    has_test_split     = 0 < train_test_split < 1
    global_test_loader = None
    if has_test_split:
        test_datasets = [c.test_dataset for c in clients]
        global_test_dataset = torch.utils.data.ConcatDataset(test_datasets)
        global_test_loader  = torch.utils.data.DataLoader(global_test_dataset, batch_size = batch_size, shuffle = True, drop_last=True)
        
    # global update loop
    for epoch in notebook.tqdm(range(global_epoch)):
        client_ids     = np.random.choice(num_client, num_update_client, replace=False)
        client_weights = []
        client_params  = []
        client_train_losses  = []
        client_train_accus   = []
        client_train_f1s     = []
        client_test_losses   = []
        client_test_accus    = []
        client_test_f1s      = []
        
        # local update
        for client_id in client_ids:
            client = clients[client_id]
            client_weights.append(len(client.client_data) if known_client_data_size else 1.0)
            
            # local train
            loss, accu, f1 = client.local_train()
            client_params.append(client.client_model.state_dict())
            client_train_losses.append(loss)
            client_train_accus .append(accu)
            client_train_f1s   .append(f1)
            
            # local test
            if has_test_split:
                loss, accu, f1 = client.local_test()
                client_test_losses.append(loss)
                client_test_accus .append(accu)
                client_test_f1s   .append(f1)
        
        # print
        avg_client_train_loss = weighted_avg(values = client_train_losses, weights = client_weights)
        avg_client_train_accu = weighted_avg(values = client_train_accus , weights = client_weights)
        avg_client_train_f1   = weighted_avg(values = client_train_f1s   , weights = client_weights)
        print("global epoch: {:02d}".format(epoch), " updating clients:", client_ids)
        print("avg client train loss: {:.4f}".format(avg_client_train_loss), " avg client train accu: {:.4f}".format(avg_client_train_accu), " avg client train f1: {:.4f}".format(avg_client_train_f1))
        
        # add to tensorboard
        writer.add_scalar('global/train/avg_loss', avg_client_train_loss, epoch)
        writer.add_scalar('global/train/avg_accu', avg_client_train_accu, epoch)
        writer.add_scalar('global/train/avg_f1'  , avg_client_train_f1  , epoch)
        
        if has_test_split:
            avg_client_test_loss  = weighted_avg(values = client_test_losses , weights = client_weights)
            avg_client_test_accu  = weighted_avg(values = client_test_accus  , weights = client_weights)
            avg_client_test_f1    = weighted_avg(values = client_test_f1s    , weights = client_weights)
            print("avg client test  loss: {:.4f}".format(avg_client_test_loss ), " avg client test  accu: {:.4f}".format(avg_client_test_accu ), " avg client test  f1: {:.4f}".format(avg_client_test_f1 ))
            writer.add_scalar('global/test/avg_loss' , avg_client_test_loss , epoch)
            writer.add_scalar('global/test/avg_accu' , avg_client_test_accu , epoch)
            writer.add_scalar('global/test/avg_f1'   , avg_client_test_f1   , epoch)
        writer.flush()
        
        # global model aggregation
        # The current global model takes part in the average with the weight
        # len(all_data) (or 1.0), next to the models of the updated clients.
        # NOTE(review): this damps every update considerably - confirm that
        # including the global model in the average is intended.
        client_params .append(global_model.state_dict())
        client_weights.append(len(all_data) if known_client_data_size else 1.0)
        new_global_params = weighted_avg_params(params = client_params, weights = client_weights)
        
        # update global model
        global_model.load_state_dict(new_global_params)
        
        # update local models
        for client in clients:
            client.client_model.load_state_dict(new_global_params)
            
        # global test
        if has_test_split:
            global_model.to(device)
            global_model.eval()
            sum_batch_loss = 0
            sum_batch_len  = 0
            global_test_labels   = []
            global_test_predicts = []
            
            with torch.no_grad():
                for batch_id, (labels, weights, sequences) in enumerate(global_test_loader):
                    sum_batch_len += len(labels)
                    labels    = labels   .to(device)
                    weights   = weights  .to(device)
                    sequences = sequences.to(device)
                    predicts  = global_model(sequences).view(-1)

                    loss = F.binary_cross_entropy(predicts, labels, reduction = 'mean', weight = weights)
                    sum_batch_loss += loss.item() * len(labels)

                    predicts = predicts.round()
                    global_test_labels  .append(labels)
                    global_test_predicts.append(predicts)

            global_test_label   = torch.cat(global_test_labels  ).detach().to('cpu').numpy()
            global_test_predict = torch.cat(global_test_predicts).detach().to('cpu').numpy()
            global_test_loss    = sum_batch_loss / sum_batch_len
            global_test_accu    = accuracy_score(global_test_label, global_test_predict)
            global_test_f1      = f1_score      (global_test_label, global_test_predict, zero_division = 0)
            
            # print
            print("global test loss: {:.4f}".format(global_test_loss), " global test accu: {:.4f}".format(global_test_accu), " global test f1: {:.4f}".format(global_test_f1))
            print("global test predicts:", np.unique(global_test_predict, return_counts=True), "global test truth:", np.unique(global_test_label, return_counts=True))
            
            # add to tensorboard: the metrics of the aggregated global model
            # (the client training averages were logged under these tags before)
            writer.add_scalar('global/test/global_loss', global_test_loss, epoch)
            writer.add_scalar('global/test/global_accu', global_test_accu, epoch)
            writer.add_scalar('global/test/global_f1'  , global_test_f1  , epoch)
            writer.flush()
            
            global_model.to('cpu')
        
        print("======================================================================================================================================================================================================")

## Run and plots

In [9]:
%reload_ext tensorboard
!rmdir /s /q runs
writer = SummaryWriter()
%tensorboard --logdir runs

federated_learning(update_C     = 0.1,
                   client_epoch = 50, 
                   global_epoch = 1000,
                   
                   learning_rate   = 1e-5,
                   l2_regular      = 1e-1,
                   reuse_optimizer = True,
                   
                   global_init = True,
                   global_preprocessing = False,
                   known_client_data_size = True,
                  )

Reusing TensorBoard on port 6006 (pid 11308), started 4 days, 19:30:23 ago. (Use '!kill 11308' to kill it.)



total number of clients: 15  number of clients to be updated in each global round: 2



  0%|          | 0/1000 [00:00<?, ?it/s]

global epoch: 00  updating clients: [1 6]
avg client train loss: 0.4281  avg client train accu: 0.9067  avg client train f1: 0.1071
avg client test  loss: 0.5120  avg client test  accu: 0.8558  avg client test  f1: 0.0000
global test loss: 0.8273  global test accu: 0.8361  global test f1: 0.0000
global test predicts: (array([0.], dtype=float32), array([244], dtype=int64)) global test truth: (array([0., 1.], dtype=float32), array([204,  40], dtype=int64))
global epoch: 01  updating clients: [2 4]
avg client train loss: 0.7090  avg client train accu: 0.8563  avg client train f1: 0.2093
avg client test  loss: 0.8678  avg client test  accu: 0.7802  avg client test  f1: 0.0000
global test loss: 0.8232  global test accu: 0.8361  global test f1: 0.0000
global test predicts: (array([0.], dtype=float32), array([244], dtype=int64)) global test truth: (array([0., 1.], dtype=float32), array([204,  40], dtype=int64))
global epoch: 02  updating clients: [ 4 12]
avg client train loss: 1.0471  avg cli