In [1]:
import matplotlib
import os
import time
import random
import numpy as np
from torchviz import make_dot

In [2]:
from network import *
from framework import *
from preprocessing import dataPreprocessing
##Networks
import monai

cuda


In [3]:
isles_data_root = '/str/data/ASAP/miccai22_data/isles/federated/'
exp_root = '/home/otarola/miccai22/fedem/'
modality = 'Tmax'
batch_size = 2


In [4]:
clients=["center1", "center2", "center4"]
#from SCAFFOLD manuscript, global_learning_rate should be = sqrt(#Samples sites)
local_epoch, global_epoch = 2, 20
#no sampling
K=len(clients)

local_lr, global_lr = 0.00932, 1.7 #np.sqrt(K)

_, centers_data_loaders, all_test_loader, _ = dataPreprocessing(isles_data_root, modality, 4, 2)

#move center 3 at the end of the dataloaders
tmp = centers_data_loaders[2]
centers_data_loaders[2]=centers_data_loaders[3]
centers_data_loaders[3]=tmp

options = {'K': K, 'l_epoch': local_epoch, 'B': batch_size, 'g_epoch': global_epoch, 'clients': clients,
           'l_lr':local_lr, 'g_lr':global_lr, 'dataloader':centers_data_loaders, 'suffix': 'FedRod', 
           'scaffold_controls': False}

In [5]:
#network present in each client
class UNet_custom(monai.networks.nets.UNet):
    def __init__(self, spatial_dims, in_channels, out_channels, channels,
                 strides, kernel_size, num_res_units, name, scaff=False, fed_rod=False):
        #call parent constructor
        super(UNet_custom, self).__init__(spatial_dims=spatial_dims,
                                          in_channels=in_channels,
                                          out_channels=out_channels, 
                                          channels=channels,
                                          strides=strides,
                                          kernel_size=kernel_size, 
                                          num_res_units=num_res_units)

        self.name = name
        if scaff:
            #control variables for SCAFFOLD
            self.control = {}
            self.delta_control = {}
            self.delta_y = {}
        if fed_rod:
            #Unet params sets for FedRod
            self.encoder_generic = {}
            self.decoder_generic = {}
            self.decoder_personalized = {}
            if options['scaffold_controls']:
                #In case we want FedRod + Scaffold controls
                self.control = {}
                self.delta_control = {}
                self.delta_y = {}

In [6]:
class FedRod(Fedem):
    def __init__(self, options):
        super(FedRod, self).__init__(options)
        self.writer = SummaryWriter(f"runs/llr{options['l_lr']}_glr{options['g_lr']}_le{options['l_epoch']}_ge{options['g_epoch']}_{options['K']}sites_"+"FEDROD"+options['suffix'])
        self.K = options['K']
        self.name_encoder_layers = ["model.0", "model.1.submodule.0", "model.1.submodule.1.submodule.2.0",
                                    "model.1.submodule.1.submodule.0", "model.1.submodule.1.submodule.1"]
        
        self.name_decoder_layers = name_decoder_layers = ['model.1.submodule.1.submodule.2.1',
                                                          'model.1.submodule.2', 'model.2']

        #server model
        self.nn = UNet_custom(spatial_dims=2,
                             in_channels=1,
                             out_channels=1,
                             channels=(16, 32, 64, 128),
                             strides=(2, 2, 2),
                             kernel_size = (3,3),
                             num_res_units=2,
                             name='server',
                             scaff=False,
                             fed_rod=True).to(device)
        
        #Global encoder - decoder (inlcuding personalized) layers init
        for k, v in self.nn.named_parameters():
            for enc_layer in self.name_encoder_layers:
                if enc_layer in k:
                    self.nn.encoder_generic[k] = v.data
            for dec_layer in self.name_decoder_layers:
                if dec_layer in k:
                    self.nn.decoder_generic[k] = v.data
                    self.nn.decoder_personalized[k] = v.data
                    
        #print(self.nn.decoder_generic)
        #clients of the federation
        self.nns = []
        for i in range(len(options['clients'])):
            temp = copy.deepcopy(self.nn)
            temp.name = options['clients'][i]
            temp.encoder_generic = copy.deepcopy(self.nn.encoder_generic)
            temp.decoder_generic = copy.deepcopy(self.nn.decoder_generic)
            temp.decoder_personalized = copy.deepcopy(self.nn.decoder_personalized)            
            self.nns.append(temp)
            
    

    def aggregation(self, index, global_lr, **kwargs):
        s = 0.0
        for j in index:
            # normal
            s += self.nns[j].len
                
        # Agregating the generic encoder from clients encoders
        for j in index:
            for k, v in self.nn.named_parameters():
                for enc_layer in self.name_encoder_layers:
                    if enc_layer in k:
                        v.data += self.nns[j].encoder_generic[k]  / len(index) #check other weightings here

        # Agregating the generic decoder from clients decoders
        for j in index:
            for k, v in self.nn.named_parameters():
                for dec_layer in self.name_decoder_layers:
                    if dec_layer in k:
                        v.data += self.nns[j].decoder_generic[k]  / len(index)  #check other weightings here


    def train(self, ann, dataloader_train, local_epoch, local_lr):
        #First the generic encoder-decoder are updated       
        ann.train()
        ann.len = len(dataloader_train)
                
        x = copy.deepcopy(ann)
        loss_function = monai.losses.DiceLoss(sigmoid=True,include_background=False)
        #One option is to set here the weights to 0 before the optimizer receives the parametets
        optimizer = torch.optim.Adam(ann.parameters(), lr=local_lr)

        for epoch in range(local_epoch):
            for batch_data in dataloader_train:
                #Optimization of the Generic path here
                inputs, labels = batch_data[0][:,:,:,:,0].to(device), batch_data[1][:,:,:,:,0].to(device)
                y_pred = ann(inputs)
                loss = loss_function(y_pred, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step(self.nn.control, ann.control) #performing SGD on the control variables
                
                #Optimization of the Personalized path here. No gradient derivation w.r.t. generic enc. and dec.
                optimizer.zero_grad()
                y_pred_personalized = ann_two_branch(inputs)
                loss = loss_function(y_pred, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step(self.nn.control, ann.control) #performing SGD on the control variables
                            
        # update c
        # c+ <- ci - c + 1/(E * lr) * (x-yi)
        temp = {}
        for k, v in ann.named_parameters():
            temp[k] = v.data.clone()
        for k, v in x.named_parameters():
            ann.control[k] = ann.control[k] - self.nn.control[k] + (v.data - temp[k]) / (local_epoch * local_lr)
            ann.delta_y[k] = temp[k] - v.data
            ann.delta_control[k] = ann.control[k] - x.control[k]
        return theta_m, psi_m, loss.item()

    def global_test(self, aggreg_dataloader_test):
        model = self.nn
        model.eval()
        
        #test the global model on each individual dataloader
        for k, client in enumerate(self.nns):
            print("testing on", client.name, "dataloader")
            test(model, self.dataloaders[k][2])
        
        #test the global model on aggregated dataloaders
        print("testing on all the data")
        test(model, aggreg_dataloader_test)

In [7]:
fed_rod = FedRod(options)

In [8]:
fed_rod.nn.named_parameters

<bound method Module.named_parameters of UNet_custom(
  (model): Sequential(
    (0): ResidualUnit(
      (conv): Sequential(
        (unit0): Convolution(
          (conv): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (adn): ADN(
            (N): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (D): Dropout(p=0.0, inplace=False)
            (A): PReLU(num_parameters=1)
          )
        )
        (unit1): Convolution(
          (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (adn): ADN(
            (N): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (D): Dropout(p=0.0, inplace=False)
            (A): PReLU(num_parameters=1)
          )
        )
      )
      (residual): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
    (1): SkipConnection(
      (submodule): Sequential(
        (0): ResidualUnit(

In [9]:
yhat = fed_rod.nn(torch.zeros([1, 1, 56, 224]).to(device))

In [10]:
for k,v in fed_rod.nn.named_parameters():
#    if "model.1.submodule.1.submodule.1" in k or "model.0" in k:
    print(k)
    print(v.shape)
    print("=====================")

model.0.conv.unit0.conv.weight
torch.Size([16, 1, 3, 3])
model.0.conv.unit0.conv.bias
torch.Size([16])
model.0.conv.unit0.adn.A.weight
torch.Size([1])
model.0.conv.unit1.conv.weight
torch.Size([16, 16, 3, 3])
model.0.conv.unit1.conv.bias
torch.Size([16])
model.0.conv.unit1.adn.A.weight
torch.Size([1])
model.0.residual.weight
torch.Size([16, 1, 3, 3])
model.0.residual.bias
torch.Size([16])
model.1.submodule.0.conv.unit0.conv.weight
torch.Size([32, 16, 3, 3])
model.1.submodule.0.conv.unit0.conv.bias
torch.Size([32])
model.1.submodule.0.conv.unit0.adn.A.weight
torch.Size([1])
model.1.submodule.0.conv.unit1.conv.weight
torch.Size([32, 32, 3, 3])
model.1.submodule.0.conv.unit1.conv.bias
torch.Size([32])
model.1.submodule.0.conv.unit1.adn.A.weight
torch.Size([1])
model.1.submodule.0.residual.weight
torch.Size([32, 16, 3, 3])
model.1.submodule.0.residual.bias
torch.Size([32])
model.1.submodule.1.submodule.0.conv.unit0.conv.weight
torch.Size([64, 32, 3, 3])
model.1.submodule.1.submodule.0.conv

In [11]:

name_encoder_layers = ["model.0", "model.1.submodule.0", "model.1.submodule.1.submodule.0","model.1.submodule.1.submodule.1"]

for k,v in fed_rod.nn.named_parameters():
    for enc_layer_name in name_encoder_layers:
        if enc_layer_name in k:
            print(k)
            print(v.shape)
            print("=====================")


model.0.conv.unit0.conv.weight
torch.Size([16, 1, 3, 3])
model.0.conv.unit0.conv.bias
torch.Size([16])
model.0.conv.unit0.adn.A.weight
torch.Size([1])
model.0.conv.unit1.conv.weight
torch.Size([16, 16, 3, 3])
model.0.conv.unit1.conv.bias
torch.Size([16])
model.0.conv.unit1.adn.A.weight
torch.Size([1])
model.0.residual.weight
torch.Size([16, 1, 3, 3])
model.0.residual.bias
torch.Size([16])
model.1.submodule.0.conv.unit0.conv.weight
torch.Size([32, 16, 3, 3])
model.1.submodule.0.conv.unit0.conv.bias
torch.Size([32])
model.1.submodule.0.conv.unit0.adn.A.weight
torch.Size([1])
model.1.submodule.0.conv.unit1.conv.weight
torch.Size([32, 32, 3, 3])
model.1.submodule.0.conv.unit1.conv.bias
torch.Size([32])
model.1.submodule.0.conv.unit1.adn.A.weight
torch.Size([1])
model.1.submodule.0.residual.weight
torch.Size([32, 16, 3, 3])
model.1.submodule.0.residual.bias
torch.Size([32])
model.1.submodule.1.submodule.0.conv.unit0.conv.weight
torch.Size([64, 32, 3, 3])
model.1.submodule.1.submodule.0.conv

In [12]:
model1 = copy.deepcopy(fed_rod.nn)
model2 = copy.deepcopy(fed_rod.nn)

In [13]:
fed_rod.nn

UNet_custom(
  (model): Sequential(
    (0): ResidualUnit(
      (conv): Sequential(
        (unit0): Convolution(
          (conv): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (adn): ADN(
            (N): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (D): Dropout(p=0.0, inplace=False)
            (A): PReLU(num_parameters=1)
          )
        )
        (unit1): Convolution(
          (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (adn): ADN(
            (N): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (D): Dropout(p=0.0, inplace=False)
            (A): PReLU(num_parameters=1)
          )
        )
      )
      (residual): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
    (1): SkipConnection(
      (submodule): Sequential(
        (0): ResidualUnit(
          (conv): Sequential(
          

In [14]:
name_encoder_layers = ["model.0", "model.1.submodule.0", "model.1.submodule.1.submodule.2.0",
                            "model.1.submodule.1.submodule.0", "model.1.submodule.1.submodule.1"]

name_decoder_layers = name_decoder_layers = ['model.1.submodule.1.submodule.2.1',
                                                  'model.1.submodule.2', 'model.2']



for k,v in model2.named_parameters():
    for enc_layer_name in name_encoder_layers:
        if enc_layer_name in k:
            v.data = torch.zeros(v.shape).to(device)
            v.requires_grad = False

for k,v in model2.named_parameters():
    for dec_layer_name in name_decoder_layers:
        if dec_layer_name in k:
            v.data = torch.ones(v.shape).to(device)



In [15]:
for k,v in fed_rod.nn.named_parameters():
    print(k)
    print(v.shape)
    print(v.sum())    
    print("=====================")


model.0.conv.unit0.conv.weight
torch.Size([16, 1, 3, 3])
tensor(1.3818, device='cuda:0', grad_fn=<SumBackward0>)
model.0.conv.unit0.conv.bias
torch.Size([16])
tensor(0.1102, device='cuda:0', grad_fn=<SumBackward0>)
model.0.conv.unit0.adn.A.weight
torch.Size([1])
tensor(0.2500, device='cuda:0', grad_fn=<SumBackward0>)
model.0.conv.unit1.conv.weight
torch.Size([16, 16, 3, 3])
tensor(1.5985, device='cuda:0', grad_fn=<SumBackward0>)
model.0.conv.unit1.conv.bias
torch.Size([16])
tensor(0.0059, device='cuda:0', grad_fn=<SumBackward0>)
model.0.conv.unit1.adn.A.weight
torch.Size([1])
tensor(0.2500, device='cuda:0', grad_fn=<SumBackward0>)
model.0.residual.weight
torch.Size([16, 1, 3, 3])
tensor(4.1116, device='cuda:0', grad_fn=<SumBackward0>)
model.0.residual.bias
torch.Size([16])
tensor(-0.7261, device='cuda:0', grad_fn=<SumBackward0>)
model.1.submodule.0.conv.unit0.conv.weight
torch.Size([32, 16, 3, 3])
tensor(-1.9756, device='cuda:0', grad_fn=<SumBackward0>)
model.1.submodule.0.conv.unit0.c

In [16]:
for k,v in model1.named_parameters():
    print(k)
    print(v.shape)
    print(v.sum())    
    print("=====================")

model.0.conv.unit0.conv.weight
torch.Size([16, 1, 3, 3])
tensor(1.3818, device='cuda:0', grad_fn=<SumBackward0>)
model.0.conv.unit0.conv.bias
torch.Size([16])
tensor(0.1102, device='cuda:0', grad_fn=<SumBackward0>)
model.0.conv.unit0.adn.A.weight
torch.Size([1])
tensor(0.2500, device='cuda:0', grad_fn=<SumBackward0>)
model.0.conv.unit1.conv.weight
torch.Size([16, 16, 3, 3])
tensor(1.5985, device='cuda:0', grad_fn=<SumBackward0>)
model.0.conv.unit1.conv.bias
torch.Size([16])
tensor(0.0059, device='cuda:0', grad_fn=<SumBackward0>)
model.0.conv.unit1.adn.A.weight
torch.Size([1])
tensor(0.2500, device='cuda:0', grad_fn=<SumBackward0>)
model.0.residual.weight
torch.Size([16, 1, 3, 3])
tensor(4.1116, device='cuda:0', grad_fn=<SumBackward0>)
model.0.residual.bias
torch.Size([16])
tensor(-0.7261, device='cuda:0', grad_fn=<SumBackward0>)
model.1.submodule.0.conv.unit0.conv.weight
torch.Size([32, 16, 3, 3])
tensor(-1.9756, device='cuda:0', grad_fn=<SumBackward0>)
model.1.submodule.0.conv.unit0.c

In [17]:
name_decoder_layers = ['model.1.submodule.1.submodule.2.1','model.1.submodule.2', 'model.2']

In [18]:
for k,v in model2.named_parameters():
    print(k)
    print(v.shape)
    print(v.sum())    
    print("=====================")

model.0.conv.unit0.conv.weight
torch.Size([16, 1, 3, 3])
tensor(0., device='cuda:0')
model.0.conv.unit0.conv.bias
torch.Size([16])
tensor(0., device='cuda:0')
model.0.conv.unit0.adn.A.weight
torch.Size([1])
tensor(0., device='cuda:0')
model.0.conv.unit1.conv.weight
torch.Size([16, 16, 3, 3])
tensor(0., device='cuda:0')
model.0.conv.unit1.conv.bias
torch.Size([16])
tensor(0., device='cuda:0')
model.0.conv.unit1.adn.A.weight
torch.Size([1])
tensor(0., device='cuda:0')
model.0.residual.weight
torch.Size([16, 1, 3, 3])
tensor(0., device='cuda:0')
model.0.residual.bias
torch.Size([16])
tensor(0., device='cuda:0')
model.1.submodule.0.conv.unit0.conv.weight
torch.Size([32, 16, 3, 3])
tensor(0., device='cuda:0')
model.1.submodule.0.conv.unit0.conv.bias
torch.Size([32])
tensor(0., device='cuda:0')
model.1.submodule.0.conv.unit0.adn.A.weight
torch.Size([1])
tensor(0., device='cuda:0')
model.1.submodule.0.conv.unit1.conv.weight
torch.Size([32, 32, 3, 3])
tensor(0., device='cuda:0')
model.1.submod