In [1]:
import matplotlib
import os
import time
import random
import numpy as np
#from torchviz import make_dot

In [2]:
from network import *
from framework import *
from preprocessing import dataPreprocessing
##Networks
import monai

cpu


In [3]:
isles_data_root = '/Users/sebastianotalora/work/postdoc/data/ISLES/federated/'
exp_root = '/Users/sebastianotalora/work/tmi/fedem'
modality = 'Tmax'
batch_size = 2


In [4]:
clients=["center1", "center2", "center4"]
#from SCAFFOLD manuscript, global_learning_rate should be = sqrt(#Samples sites)
local_epoch, global_epoch = 1, 20
#no sampling
K=len(clients)

local_lr, global_lr = 0.00932, 1.7 #np.sqrt(K)

_, centers_data_loaders, all_test_loader, _ = dataPreprocessing(isles_data_root, modality, 4, 2)

#move center 3 at the end of the dataloaders
tmp = centers_data_loaders[2]
centers_data_loaders[2]=centers_data_loaders[3]
centers_data_loaders[3]=tmp

options = {'K': K, 'l_epoch': local_epoch, 'B': batch_size, 'g_epoch': global_epoch, 'clients': clients,
           'l_lr':local_lr, 'g_lr':global_lr, 'dataloader':centers_data_loaders, 'suffix': 'FedRod', 
           'scaffold_controls': False}

In [5]:
#network present in each client
class UNet_custom(monai.networks.nets.UNet):
    def __init__(self, spatial_dims, in_channels, out_channels, channels,
                 strides, kernel_size, num_res_units, name, scaff=False, fed_rod=False):
        #call parent constructor
        super(UNet_custom, self).__init__(spatial_dims=spatial_dims,
                                          in_channels=in_channels,
                                          out_channels=out_channels, 
                                          channels=channels,
                                          strides=strides,
                                          kernel_size=kernel_size, 
                                          num_res_units=num_res_units)

        self.name = name
        self.control = {}
        self.delta_control = {}
        self.delta_y = {}
        if fed_rod:
            #Unet params sets for FedRod
            self.encoder_generic = {}
            self.decoder_generic = {}
            self.decoder_personalized = {}

In [6]:
class FedRod(Fedem):
    def __init__(self, options):
        super(FedRod, self).__init__(options)
        self.writer = SummaryWriter(f"runs/llr{options['l_lr']}_glr{options['g_lr']}_le{options['l_epoch']}_ge{options['g_epoch']}_{options['K']}sites_"+"FEDROD"+options['suffix'])
        self.K = options['K']
        self.name_encoder_layers = ["model.0", "model.1.submodule.0", "model.1.submodule.1.submodule.2.0",
                                    "model.1.submodule.1.submodule.0", "model.1.submodule.1.submodule.1"]
        
        self.name_decoder_layers  = ['model.1.submodule.1.submodule.2.1',
                                    'model.1.submodule.2', 'model.2']

        #server model
        self.nn = UNet_custom(spatial_dims=2,
                             in_channels=1,
                             out_channels=1,
                             channels=(16, 32, 64, 128),
                             strides=(2, 2, 2),
                             kernel_size = (3,3),
                             num_res_units=2,
                             name='server',
                             scaff=False,
                             fed_rod=True).to(device)
        
        #Global encoder - decoder (inlcuding personalized) layers init
        for k, v in self.nn.named_parameters():
            for enc_layer in self.name_encoder_layers:
                if enc_layer in k:
                    self.nn.encoder_generic[k] = copy.deepcopy(v.data)
            for dec_layer in self.name_decoder_layers:
                if dec_layer in k:
                    self.nn.decoder_generic[k] = copy.deepcopy(v.data)
                    self.nn.decoder_personalized[k] = copy.deepcopy(v.data)
                    
        #print(self.nn.decoder_generic)
        #clients of the federation
        self.nns = []
        for i in range(len(options['clients'])):
            temp = copy.deepcopy(self.nn)
            temp.name = options['clients'][i]
            temp.encoder_generic = copy.deepcopy(self.nn.encoder_generic)
            temp.decoder_generic = copy.deepcopy(self.nn.decoder_generic)
            temp.decoder_personalized = copy.deepcopy(self.nn.decoder_personalized)            
            self.nns.append(temp)
            
    

    def aggregation(self, index, global_lr, **kwargs):
        s = 0.0
        for j in index:
            # normal
            s += self.nns[j].len
                
        # Agregating the generic encoder from clients encoders
        for j in index:
            for k, v in self.nn.named_parameters():
                for enc_layer in self.name_encoder_layers:
                    if enc_layer in k:
                        v.data += self.nns[j].encoder_generic[k]  / len(index) #check other weightings here

        # Agregating the generic decoder from clients decoders
        for j in index:
            for k, v in self.nn.named_parameters():
                for dec_layer in self.name_decoder_layers:
                    if dec_layer in k:
                        v.data += self.nns[j].decoder_generic[k]  / len(index)  #check other weightings here


    def train(self, ann, dataloader_train, local_epoch, local_lr):
        #First the generic encoder-decoder are updated       
        ann.train()
        ann.len = len(dataloader_train)
                
        x = copy.deepcopy(ann)
        loss_function = monai.losses.DiceLoss(sigmoid=True,include_background=False)
        #One option is to set here the weights to 0 before the optimizer receives the parametets
        optimizer = torch.optim.Adam(ann.parameters(), lr=local_lr)

        for epoch in range(local_epoch):
            for batch_data in dataloader_train:
                #(1)Optimization of the Generic path here equation (8) of the paper
                for k, v in ann.named_parameters(): #Transfering data from the generic head
                    for dec_layer in self.name_decoder_layers:
                        if dec_layer in k:
                            v.data = copy.deepcopy(ann.decoder_generic[k]) #"Swapping the heads"
                    v.requires_grad = True #deriving gradients to all the generic layers
                
                inputs, labels = batch_data[0][:,:,:,:,0].to(device), batch_data[1][:,:,:,:,0].to(device)
                y_pred_generic = ann(inputs)
                loss = loss_function(y_pred_generic, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                #print("=====================================================")
                #for k, v in ann.named_parameters():
                #    print(v.requires_grad)
                    
                #(2)Optimization of the Perzonalized path here equation (9) of the paper
                for k, v in ann.named_parameters():
                    for dec_layer in self.name_decoder_layers:
                        if dec_layer in k:
                            ann.decoder_generic[k] = copy.deepcopy(v.data) #Keeping the generic decoder data
                
                #(3) Keeping the generic output to add it later to the personalized
                output_generic = copy.deepcopy(y_pred_generic.detach().numpy())

                for k,v in ann.named_parameters():
                    for enc_layer_name in self.name_encoder_layers:
                        if enc_layer_name in k:
                            v.requires_grad = False

                for k, v in ann.named_parameters():
                    for dec_layer in self.name_decoder_layers:
                        if dec_layer in k:
                            v.data = copy.deepcopy(ann.decoder_personalized[k]) #"Swapping the heads"
                            v.requires_grad = True #Deriving fradients only wrt to the personalized head
               
                #for k, v in ann.named_parameters():
                #    print(v.requires_grad)

                output_personalized = ann(inputs) + torch.tensor(output_generic) #regularized personalized output
                loss = loss_function(output_personalized, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                #print("=====================================================")
                                            
        return ann, loss.item() 

    def global_test(self, aggreg_dataloader_test):
        model = self.nn
        model.eval()
        
        #test the global model on each individual dataloader
        for k, client in enumerate(self.nns):
            print("testing on", client.name, "dataloader")
            test(model, self.dataloaders[k][2])
        
        #test the global model on aggregated dataloaders
        print("testing on all the data")
        test(model, aggreg_dataloader_test)

In [7]:
centers_data_loaders[0]

(<torch.utils.data.dataloader.DataLoader at 0x7fa3db6c3520>,
 <torch.utils.data.dataloader.DataLoader at 0x7fa3d0e98d00>,
 <torch.utils.data.dataloader.DataLoader at 0x7fa3d0e86a00>)

In [8]:
fed_rod = FedRod(options)

In [9]:
net, loss_center =  fed_rod.train(fed_rod.nns[0], fed_rod.dataloaders[0][0], 10, 0.0001)
print(loss_center)



0.904084324836731
