In [1]:
import sys

from utils_fedavg import get_optimizer
from weighting_schemes import average_weights, average_weights_beta, average_weights_softmax

import torch
from torch import nn, optim

import monai
import numpy as np
import nibabel as nib
from glob import glob
from matplotlib import pyplot as plt
import copy
from scipy.spatial import distance_matrix
from monai.transforms import (
    Activations,
    AsChannelFirstD,
    AddChannel,
    AsDiscrete,
    Compose,
    LoadImage,
    RandRotate90,
    RandSpatialCrop,
    ScaleIntensity,
    EnsureType,
    Resized
)

from monai.data import (
    ArrayDataset, GridPatchDataset, create_test_image_3d, PatchIter)
from monai.utils import first
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
from monai.data import DataLoader, decollate_batch
from natsort import natsorted
import umap

from torch.utils.tensorboard import SummaryWriter
from sklearn.preprocessing import StandardScaler
import sklearn


In [2]:
LOCATION = 'scan' #laptop
if LOCATION == 'scan':
    isles_data_root = '/str/data/ASAP/miccai22_data/isles/federated/'
    exp_root = '/home/otarola/miccai22/fedem/'

if LOCATION == 'laptop':
    isles_data_root = '/data/ASAP/miccai22_data/isles/federated/'

In [3]:
#Hyperparams cell
modality = 'Tmax'
batch_size = 2
num_epochs = 300
learning_rate = 0.000932#lrs[0] #To comment in the loop
weighting_scheme = 'FEDAVG'
beta_val=0.9

In [4]:
def get_train_valid_test_partitions(modality, isles_data_root, num_centers=4):
    centers_partitions = [[] for i in range(num_centers)]
    for center_num in range(1,num_centers+1):
        center_paths_train  = sorted(glob(isles_data_root+'center'+str(center_num)+'/train'+'/**/*'+modality+'*/*.nii'))
        center_paths_valid  = sorted(glob(isles_data_root+'center'+str(center_num)+'/valid'+'/**/*'+modality+'*/*.nii'))
        center_paths_test   = sorted(glob(isles_data_root+'center'+str(center_num)+'/test'+'/**/*'+modality+'*/*.nii'))
        center_lbl_paths_train  = sorted(glob(isles_data_root+'center'+str(center_num)+'/train'+'/**/*OT*/*nii'))
        center_lbl_paths_valid  = sorted(glob(isles_data_root+'center'+str(center_num)+'/valid'+'/**/*OT*/*nii'))
        center_lbl_paths_test  = sorted(glob(isles_data_root+'center'+str(center_num)+'/test'+'/**/*OT*/*nii'))
        centers_partitions[center_num-1] = [[center_paths_train,center_paths_valid,center_paths_test],[center_lbl_paths_train,center_lbl_paths_valid,center_lbl_paths_test]]
    return centers_partitions

In [5]:
partitions_paths = get_train_valid_test_partitions(modality, isles_data_root, 4)

In [6]:
len(partitions_paths[0][0][2]),len(partitions_paths[0][1][2]) #idx_order: center,img_label,partition

(4, 4)

In [7]:
#creating the dataloader for 10 ISLES volumes using the T_max and the CBF
#For cbf we are windowing 1-1024
#For tmax we'll window 0-60
#For CBV we'll window 0-200
if modality =='CBF':
    max_intensity = 1200
if modality =='CBV':
    max_intensity = 200
if modality =='Tmax' or modality =='MTT':
    max_intensity = 30

In [8]:
imtrans = Compose(
    [   LoadImage(image_only=True),
        #RandScaleIntensity( factors=0.1, prob=0.5),
        ScaleIntensity(minv=0.0, maxv=max_intensity),
        AddChannel(),
        RandRotate90( prob=0.5, spatial_axes=[0, 1]),
        RandSpatialCrop((224, 224,1), random_size=False),
        EnsureType(),
        #Resized
    ]
)

segtrans = Compose(
    [   LoadImage(image_only=True),
        AddChannel(),
        RandRotate90( prob=0.5, spatial_axes=[0, 1]),
        RandSpatialCrop((224, 224,1), random_size=False),
        EnsureType(),
        #Resized
    ]
)


imtrans_neutral = Compose(
    [   LoadImage(image_only=True),
        #RandScaleIntensity( factors=0.1, prob=0.5),
        ScaleIntensity(minv=0.0, maxv=max_intensity),
        AddChannel(),
        RandSpatialCrop((224, 224,1), random_size=False),
        EnsureType(),
        #Resized
    ]
)

segtrans_neutral = Compose(
    [   LoadImage(image_only=True),
        AddChannel(),
        RandSpatialCrop((224, 224,1), random_size=False),
        EnsureType(),
        #Resized
    ]
)

imtrans_test = Compose(
    [   LoadImage(image_only=True),
        ScaleIntensity(minv=0.0, maxv=max_intensity),
        AddChannel(),
        #RandSpatialCrop((224, 224,1), random_size=False), In test we would like to process ALL slices
        EnsureType(),
        #Resized
    ]
)

segtrans_test = Compose(
    [   LoadImage(image_only=True),
        AddChannel(),
        #RandSpatialCrop((224, 224,1), random_size=False),
        EnsureType(),
        #Resized
    ]
)


In [9]:
def center_dataloaders(partitions_paths_center, batch_size=2):#
    center_ds_train = ArrayDataset(partitions_paths_center[0][0], imtrans, partitions_paths_center[1][0], segtrans)
    center_train_loader   = torch.utils.data.DataLoader(
        center_ds_train, batch_size=batch_size, num_workers=1, pin_memory=torch.cuda.is_available()
    )

    center_ds_valid = ArrayDataset(partitions_paths_center[0][1], imtrans, partitions_paths_center[1][1], segtrans)
    center_valid_loader   = torch.utils.data.DataLoader(
        center_ds_valid, batch_size=batch_size, num_workers=1, pin_memory=torch.cuda.is_available()
    )

    center_ds_test = ArrayDataset(partitions_paths_center[0][2], imtrans_test, partitions_paths_center[1][2], segtrans_test)
    center_test_loader   = torch.utils.data.DataLoader(
        center_ds_test, batch_size=batch_size, num_workers=1, pin_memory=torch.cuda.is_available()
    )
    return center_train_loader, center_valid_loader, center_test_loader

In [10]:
centers_data_loaders = []
for i in range(len(partitions_paths)):#Adding all the centers data loaders
    centers_data_loaders.append(center_dataloaders(partitions_paths[i],batch_size))

In [11]:
len(centers_data_loaders)

4

In [12]:
partitions_test_imgs = [partitions_paths[i][0][2] for i in range(len(partitions_paths))]
partitions_test_lbls = [partitions_paths[i][1][2] for i in range(len(partitions_paths))]

partitions_valid_imgs = [partitions_paths[i][0][1] for i in range(len(partitions_paths))]
partitions_valid_lbls = [partitions_paths[i][1][1] for i in range(len(partitions_paths))]

#For selecting the model and testing in the heldout partition we collect the valid and test data from ALL centers
all_ds_test = ArrayDataset([i for l in partitions_test_imgs for i in l],
                            imtrans, [i for l in partitions_test_lbls for i in l],
                            segtrans)
all_test_loader   = torch.utils.data.DataLoader(
    all_ds_test, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available()
)


all_ds_valid = ArrayDataset([i for l in partitions_valid_imgs for i in l],
                            imtrans, [i for l in partitions_valid_lbls for i in l],
                            segtrans)
all_valid_loader   = torch.utils.data.DataLoader(
    all_ds_valid, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available()
)


In [13]:
trainloaders_lengths = [len(centers_data_loaders[i][0].dataset) for i in [0,1,3]] #We don't take the Siemens training case

beta = 0.9

weight_classes = [(1-beta)/(1-np.power(beta,length)) for length in trainloaders_lengths]
print(trainloaders_lengths)
print(weight_classes)

[15, 41, 9]
[0.1259273180814282, 0.10134821448516575, 0.16324411477092318]


In [14]:
dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

global_model = monai.networks.nets.UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128),
        strides=(2, 2, 2),
        kernel_size = (3,3),
        #dropout = 0.2,
        num_res_units=2,
).to(device)

loss_function = monai.losses.DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(global_model.parameters(), lr=learning_rate)



val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()
writer = SummaryWriter(comment=modality+"_"+weighting_scheme+"_LR_"+str(learning_rate)+"_BATCH_"+str(batch_size))


batch_data = next(iter(centers_data_loaders[0][0]))



cuda


In [15]:
inputs, labels = batch_data[0][:,:,:,:,0].to(device),batch_data[1][:,:,:,:,0].to(device)
#torch.swapaxes(batch_data[0][0], 1, -1).to(device), torch.swapaxes(batch_data[1][0], 1, -1).to(device).to(device)
print(inputs.shape,labels.shape)


global_model.train()

# copy weights
global_weights = global_model.state_dict()

#Testing that the model works in one iteration
optimizer.zero_grad()
outputs = global_model(inputs)
loss    = loss_function(outputs, labels)
loss.backward()
optimizer.step()

torch.Size([2, 1, 224, 224]) torch.Size([2, 1, 224, 224])


In [16]:
global_model.train()
epoch_loss = 0
train_loss, train_dice = [], []

In [17]:
def perform_one_local_epoch(train_loader, local_model, local_optimizer, loss_function, client_idx=0):
    batch_loss_client = []
    for batch_data in train_loader:
        inputs, labels = batch_data[0][:,:,:,:,0].to(device), batch_data[1][:,:,:,:,0].to(device)
        local_model.zero_grad()        
        outputs = local_model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        local_optimizer.step()
        batch_loss_client.append(loss.item())
    avg_loss_client = copy.deepcopy(sum(batch_loss_client) / len(batch_loss_client))
    print("Loss for client: " +str(client_idx)+" :" +str(avg_loss_client))
    return avg_loss_client   

In [18]:
for epoch in range(num_epochs):
    print("-" * 10)
    print(f"local epoch {epoch + 1}/{num_epochs}")

    local_weights, local_losses = [], []
    global_model.train()

    global_weights = global_model.state_dict()
    modelc1, modelc2, modelc4 = copy.deepcopy(global_model), copy.deepcopy(global_model), copy.deepcopy(global_model)
    optimizerc1 = torch.optim.Adam(modelc1.parameters(), learning_rate)
    optimizerc2 = torch.optim.Adam(modelc2.parameters(), learning_rate)
    optimizerc4 = torch.optim.Adam(modelc4.parameters(), learning_rate)

    modelc1.train()
    modelc2.train()
    modelc4.train()

    print(f"local epoch for train_loader 1: {epoch + 1}/{num_epochs}")
    loss_c1 = perform_one_local_epoch(centers_data_loaders[0][0], modelc1, optimizerc1, loss_function, client_idx=1)
    local_losses.append(loss_c1)
    print("Loss C1: " + str(local_losses[-1]))

    print(f"local epoch for train_loader 2: {epoch + 1}/{num_epochs}")
    loss_c2 = perform_one_local_epoch(centers_data_loaders[1][0], modelc2, optimizerc2, loss_function, client_idx=2)
    local_losses.append(loss_c2)
    print("Loss C2: " + str(local_losses[-1]))

    #C3 is the Siemens data loader for which we have only one data point
    print(f"local epoch for train_loader 4: {epoch + 1}/{num_epochs}")
    loss_c4 = perform_one_local_epoch(centers_data_loaders[3][0], modelc4, optimizerc4, loss_function, client_idx=4)
    local_losses.append(loss_c4)
    print("Loss C4: " + str(local_losses[-1]))

    loss_avg = sum(local_losses) / len(local_losses)
    train_loss.append(loss_avg)


    print(f"train_loss: {train_loss[-1]:.4f}")
    writer.add_scalar("train_loss", loss_avg, epoch)

    #Agregating the weights with the selected weighting scheme
    if weighting_scheme =='FEDAVG':
        global_weights = average_weights([copy.deepcopy(modelc1.state_dict()),copy.deepcopy(modelc2.state_dict()),copy.deepcopy(modelc4.state_dict())])
    if weighting_scheme =='BETA':
        global_weights = average_weights_beta([copy.deepcopy(modelc1.state_dict()),copy.deepcopy(modelc2.state_dict()),copy.deepcopy(modelc4.state_dict())],trainloaders_lengths,beta_val)
    if weighting_scheme =='SOFTMAX':
        global_weights = average_weights_softmax([copy.deepcopy(modelc1.state_dict()),copy.deepcopy(modelc2.state_dict()),copy.deepcopy(modelc4.state_dict())],trainloaders_lengths)

    
    # Update global weights with the averaged model weights.
    global_model.load_state_dict(global_weights)

    if (epoch + 1) % val_interval == 0:
        global_model.eval()
        with torch.no_grad():
            val_images = None
            val_labels = None
            val_outputs = None
            for val_data in all_valid_loader:
                val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                val_outputs = global_model(val_images[:,:,:,:,0])
                val_outputs = val_outputs>0.5 #This assumes one slice in the last dim
                dice_metric(y_pred=val_outputs, y=val_labels[:,:,:,:,0])
            # aggregate the final mean dice result
            metric = dice_metric.aggregate().item()
            # reset the status for next validation round
            dice_metric.reset()
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(global_model.state_dict(), modality+'_beta_'+str(beta)+'_'+weighting_scheme+'_best_metric_model_segmentation2d_array.pth')
                print("saved new best metric model")
            print(
                "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                    epoch + 1, metric, best_metric, best_metric_epoch
                )
            )
            writer.add_scalar("val_mean_dice", metric, epoch + 1)


----------
local epoch 1/300
local epoch for train_loader 1: 1/300
Loss for client: 1 :0.9819399565458298
Loss C1: 0.9819399565458298
local epoch for train_loader 2: 1/300
Loss for client: 2 :0.9443082894597735
Loss C2: 0.9443082894597735
local epoch for train_loader 4: 1/300
Loss for client: 4 :0.9876788020133972
Loss C4: 0.9876788020133972
train_loss: 0.9713
----------
local epoch 2/300
local epoch for train_loader 1: 2/300
Loss for client: 1 :0.975497879087925
Loss C1: 0.975497879087925
local epoch for train_loader 2: 2/300
Loss for client: 2 :0.9328742594945998
Loss C2: 0.9328742594945998
local epoch for train_loader 4: 2/300
Loss for client: 4 :0.9819965481758117
Loss C4: 0.9819965481758117
train_loss: 0.9635
saved new best metric model
current epoch: 2 current mean dice: 0.3677 best mean dice: 0.3677 at epoch 2
----------
local epoch 3/300
local epoch for train_loader 1: 3/300
Loss for client: 1 :0.9712964072823524
Loss C1: 0.9712964072823524
local epoch for train_loader 2: 3/300

Loss for client: 1 :0.9608909636735916
Loss C1: 0.9608909636735916
local epoch for train_loader 2: 21/300
Loss for client: 2 :0.9040542131378537
Loss C2: 0.9040542131378537
local epoch for train_loader 4: 21/300
Loss for client: 4 :0.9698583722114563
Loss C4: 0.9698583722114563
train_loss: 0.9449
----------
local epoch 22/300
local epoch for train_loader 1: 22/300
Loss for client: 1 :0.9606833234429359
Loss C1: 0.9606833234429359
local epoch for train_loader 2: 22/300
Loss for client: 2 :0.9032550227074396
Loss C2: 0.9032550227074396
local epoch for train_loader 4: 22/300
Loss for client: 4 :0.9694042325019836
Loss C4: 0.9694042325019836
train_loss: 0.9444
current epoch: 22 current mean dice: 0.4684 best mean dice: 0.4715 at epoch 14
----------
local epoch 23/300
local epoch for train_loader 1: 23/300
Loss for client: 1 :0.960291288793087
Loss C1: 0.960291288793087
local epoch for train_loader 2: 23/300
Loss for client: 2 :0.9005788621448335
Loss C2: 0.9005788621448335
local epoch for 

Loss for client: 2 :0.8724982766878038
Loss C2: 0.8724982766878038
local epoch for train_loader 4: 41/300
Loss for client: 4 :0.9605247735977173
Loss C4: 0.9605247735977173
train_loss: 0.9274
----------
local epoch 42/300
local epoch for train_loader 1: 42/300
Loss for client: 1 :0.9484355375170708
Loss C1: 0.9484355375170708
local epoch for train_loader 2: 42/300
Loss for client: 2 :0.8701164694059462
Loss C2: 0.8701164694059462
local epoch for train_loader 4: 42/300
Loss for client: 4 :0.9599639654159546
Loss C4: 0.9599639654159546
train_loss: 0.9262
current epoch: 42 current mean dice: 0.4668 best mean dice: 0.4954 at epoch 30
----------
local epoch 43/300
local epoch for train_loader 1: 43/300
Loss for client: 1 :0.9475731924176216
Loss C1: 0.9475731924176216
local epoch for train_loader 2: 43/300
Loss for client: 2 :0.8686673385756356
Loss C2: 0.8686673385756356
local epoch for train_loader 4: 43/300
Loss for client: 4 :0.9592228651046752
Loss C4: 0.9592228651046752
train_loss: 0.

Loss for client: 4 :0.944844913482666
Loss C4: 0.944844913482666
train_loss: 0.8988
----------
local epoch 62/300
local epoch for train_loader 1: 62/300
Loss for client: 1 :0.9272659495472908
Loss C1: 0.9272659495472908
local epoch for train_loader 2: 62/300
Loss for client: 2 :0.8221344153086344
Loss C2: 0.8221344153086344
local epoch for train_loader 4: 62/300
Loss for client: 4 :0.943990957736969
Loss C4: 0.943990957736969
train_loss: 0.8978
current epoch: 62 current mean dice: 0.4612 best mean dice: 0.4954 at epoch 30
----------
local epoch 63/300
local epoch for train_loader 1: 63/300
Loss for client: 1 :0.9258951172232628
Loss C1: 0.9258951172232628
local epoch for train_loader 2: 63/300
Loss for client: 2 :0.8166812360286713
Loss C2: 0.8166812360286713
local epoch for train_loader 4: 63/300
Loss for client: 4 :0.9429759025573731
Loss C4: 0.9429759025573731
train_loss: 0.8952
----------
local epoch 64/300
local epoch for train_loader 1: 64/300
Loss for client: 1 :0.92447542399168

Loss for client: 1 :0.8960410133004189
Loss C1: 0.8960410133004189
local epoch for train_loader 2: 82/300
Loss for client: 2 :0.7534278773126148
Loss C2: 0.7534278773126148
local epoch for train_loader 4: 82/300
Loss for client: 4 :0.9215640187263489
Loss C4: 0.9215640187263489
train_loss: 0.8570
current epoch: 82 current mean dice: 0.4898 best mean dice: 0.5056 at epoch 76
----------
local epoch 83/300
local epoch for train_loader 1: 83/300
Loss for client: 1 :0.8943749815225601
Loss C1: 0.8943749815225601
local epoch for train_loader 2: 83/300
Loss for client: 2 :0.7490462831088475
Loss C2: 0.7490462831088475
local epoch for train_loader 4: 83/300
Loss for client: 4 :0.9198361873626709
Loss C4: 0.9198361873626709
train_loss: 0.8544
----------
local epoch 84/300
local epoch for train_loader 1: 84/300
Loss for client: 1 :0.8924264907836914
Loss C1: 0.8924264907836914
local epoch for train_loader 2: 84/300
Loss for client: 2 :0.7458713139806475
Loss C2: 0.7458713139806475
local epoch fo

Loss for client: 2 :0.6728959168706622
Loss C2: 0.6728959168706622
local epoch for train_loader 4: 102/300
Loss for client: 4 :0.8929425001144409
Loss C4: 0.8929425001144409
train_loss: 0.8072
current epoch: 102 current mean dice: 0.4749 best mean dice: 0.5056 at epoch 76
----------
local epoch 103/300
local epoch for train_loader 1: 103/300
Loss for client: 1 :0.853816457092762
Loss C1: 0.853816457092762
local epoch for train_loader 2: 103/300
Loss for client: 2 :0.6724902093410492
Loss C2: 0.6724902093410492
local epoch for train_loader 4: 103/300
Loss for client: 4 :0.8911624550819397
Loss C4: 0.8911624550819397
train_loss: 0.8058
----------
local epoch 104/300
local epoch for train_loader 1: 104/300
Loss for client: 1 :0.8514696732163429
Loss C1: 0.8514696732163429
local epoch for train_loader 2: 104/300
Loss for client: 2 :0.6654774617581141
Loss C2: 0.6654774617581141
local epoch for train_loader 4: 104/300
Loss for client: 4 :0.8901310443878174
Loss C4: 0.8901310443878174
train_

Loss for client: 2 :0.5938899900232043
Loss C2: 0.5938899900232043
local epoch for train_loader 4: 122/300
Loss for client: 4 :0.8629769802093505
Loss C4: 0.8629769802093505
train_loss: 0.7559
current epoch: 122 current mean dice: 0.4764 best mean dice: 0.5056 at epoch 76
----------
local epoch 123/300
local epoch for train_loader 1: 123/300
Loss for client: 1 :0.8054928705096245
Loss C1: 0.8054928705096245
local epoch for train_loader 2: 123/300
Loss for client: 2 :0.5868817454292661
Loss C2: 0.5868817454292661
local epoch for train_loader 4: 123/300
Loss for client: 4 :0.8614255547523498
Loss C4: 0.8614255547523498
train_loss: 0.7513
----------
local epoch 124/300
local epoch for train_loader 1: 124/300
Loss for client: 1 :0.8102381341159344
Loss C1: 0.8102381341159344
local epoch for train_loader 2: 124/300
Loss for client: 2 :0.5803637249129159
Loss C2: 0.5803637249129159
local epoch for train_loader 4: 124/300
Loss for client: 4 :0.8588380217552185
Loss C4: 0.8588380217552185
trai

Loss for client: 2 :0.49983293811480206
Loss C2: 0.49983293811480206
local epoch for train_loader 4: 142/300
Loss for client: 4 :0.8316573023796081
Loss C4: 0.8316573023796081
train_loss: 0.6953
current epoch: 142 current mean dice: 0.4211 best mean dice: 0.5056 at epoch 76
----------
local epoch 143/300
local epoch for train_loader 1: 143/300
Loss for client: 1 :0.771959375590086
Loss C1: 0.771959375590086
local epoch for train_loader 2: 143/300
Loss for client: 2 :0.4962914543492453
Loss C2: 0.4962914543492453
local epoch for train_loader 4: 143/300
Loss for client: 4 :0.8406022548675537
Loss C4: 0.8406022548675537
train_loss: 0.7030
----------
local epoch 144/300
local epoch for train_loader 1: 144/300
Loss for client: 1 :0.7482394315302372
Loss C1: 0.7482394315302372
local epoch for train_loader 2: 144/300
Loss for client: 2 :0.4909639528819493
Loss C2: 0.4909639528819493
local epoch for train_loader 4: 144/300
Loss for client: 4 :0.8297368228435517
Loss C4: 0.8297368228435517
trai

Loss for client: 2 :0.3369400856040773
Loss C2: 0.3369400856040773
local epoch for train_loader 4: 162/300
Loss for client: 4 :0.7853735089302063
Loss C4: 0.7853735089302063
train_loss: 0.5877
current epoch: 162 current mean dice: 0.2684 best mean dice: 0.5056 at epoch 76
----------
local epoch 163/300
local epoch for train_loader 1: 163/300
Loss for client: 1 :0.6374366320669651
Loss C1: 0.6374366320669651
local epoch for train_loader 2: 163/300
Loss for client: 2 :0.32393640137854074
Loss C2: 0.32393640137854074
local epoch for train_loader 4: 163/300
Loss for client: 4 :0.7777342796325684
Loss C4: 0.7777342796325684
train_loss: 0.5797
----------
local epoch 164/300
local epoch for train_loader 1: 164/300
Loss for client: 1 :0.6217965818941593
Loss C1: 0.6217965818941593
local epoch for train_loader 2: 164/300
Loss for client: 2 :0.33183020920980544
Loss C2: 0.33183020920980544
local epoch for train_loader 4: 164/300
Loss for client: 4 :0.7752050399780274
Loss C4: 0.7752050399780274


Loss for client: 2 :0.19001045113518125
Loss C2: 0.19001045113518125
local epoch for train_loader 4: 182/300
Loss for client: 4 :0.7325782895088195
Loss C4: 0.7325782895088195
train_loss: 0.4782
current epoch: 182 current mean dice: 0.0906 best mean dice: 0.5056 at epoch 76
----------
local epoch 183/300
local epoch for train_loader 1: 183/300
Loss for client: 1 :0.4611539654433727
Loss C1: 0.4611539654433727
local epoch for train_loader 2: 183/300
Loss for client: 2 :0.18590626688230605
Loss C2: 0.18590626688230605
local epoch for train_loader 4: 183/300
Loss for client: 4 :0.7342528402805328
Loss C4: 0.7342528402805328
train_loss: 0.4604
----------
local epoch 184/300
local epoch for train_loader 1: 184/300
Loss for client: 1 :0.4732143208384514
Loss C1: 0.4732143208384514
local epoch for train_loader 2: 184/300
Loss for client: 2 :0.19013670086860657
Loss C2: 0.19013670086860657
local epoch for train_loader 4: 184/300
Loss for client: 4 :0.7291061520576477
Loss C4: 0.729106152057647

Loss for client: 2 :0.12681683188393003
Loss C2: 0.12681683188393003
local epoch for train_loader 4: 202/300
Loss for client: 4 :0.7166641116142273
Loss C4: 0.7166641116142273
train_loss: 0.4026
current epoch: 202 current mean dice: 0.0665 best mean dice: 0.5056 at epoch 76
----------
local epoch 203/300
local epoch for train_loader 1: 203/300
Loss for client: 1 :0.4306057542562485
Loss C1: 0.4306057542562485
local epoch for train_loader 2: 203/300
Loss for client: 2 :0.120218745299748
Loss C2: 0.120218745299748
local epoch for train_loader 4: 203/300
Loss for client: 4 :0.716621595621109
Loss C4: 0.716621595621109
train_loss: 0.4225
----------
local epoch 204/300
local epoch for train_loader 1: 204/300
Loss for client: 1 :0.3700650632381439
Loss C1: 0.3700650632381439
local epoch for train_loader 2: 204/300
Loss for client: 2 :0.12133794455301194
Loss C2: 0.12133794455301194
local epoch for train_loader 4: 204/300
Loss for client: 4 :0.7155116438865662
Loss C4: 0.7155116438865662
trai

Loss for client: 2 :0.10204102595647176
Loss C2: 0.10204102595647176
local epoch for train_loader 4: 222/300
Loss for client: 4 :0.7098575532436371
Loss C4: 0.7098575532436371
train_loss: 0.3854
current epoch: 222 current mean dice: 0.2366 best mean dice: 0.5056 at epoch 76
----------
local epoch 223/300
local epoch for train_loader 1: 223/300
Loss for client: 1 :0.327693197876215
Loss C1: 0.327693197876215
local epoch for train_loader 2: 223/300
Loss for client: 2 :0.10523437744095213
Loss C2: 0.10523437744095213
local epoch for train_loader 4: 223/300
Loss for client: 4 :0.70665642619133
Loss C4: 0.70665642619133
train_loss: 0.3799
----------
local epoch 224/300
local epoch for train_loader 1: 224/300
Loss for client: 1 :0.362260315567255
Loss C1: 0.362260315567255
local epoch for train_loader 2: 224/300
Loss for client: 2 :0.10465047614915031
Loss C2: 0.10465047614915031
local epoch for train_loader 4: 224/300
Loss for client: 4 :0.7104801416397095
Loss C4: 0.7104801416397095
train_

Loss for client: 2 :0.08901903459003993
Loss C2: 0.08901903459003993
local epoch for train_loader 4: 242/300
Loss for client: 4 :0.7070875883102417
Loss C4: 0.7070875883102417
train_loss: 0.3716
current epoch: 242 current mean dice: 0.1652 best mean dice: 0.5056 at epoch 76
----------
local epoch 243/300
local epoch for train_loader 1: 243/300
Loss for client: 1 :0.3049960844218731
Loss C1: 0.3049960844218731
local epoch for train_loader 2: 243/300
Loss for client: 2 :0.08757954693975903
Loss C2: 0.08757954693975903
local epoch for train_loader 4: 243/300
Loss for client: 4 :0.7077804625034332
Loss C4: 0.7077804625034332
train_loss: 0.3668
----------
local epoch 244/300
local epoch for train_loader 1: 244/300
Loss for client: 1 :0.3159152753651142
Loss C1: 0.3159152753651142
local epoch for train_loader 2: 244/300
Loss for client: 2 :0.09088447973841712
Loss C2: 0.09088447973841712
local epoch for train_loader 4: 244/300
Loss for client: 4 :0.7052481174468994
Loss C4: 0.705248117446899

Loss for client: 2 :0.07866448589733668
Loss C2: 0.07866448589733668
local epoch for train_loader 4: 262/300
Loss for client: 4 :0.7048829972743988
Loss C4: 0.7048829972743988
train_loss: 0.3626
current epoch: 262 current mean dice: 0.1542 best mean dice: 0.5056 at epoch 76
----------
local epoch 263/300
local epoch for train_loader 1: 263/300
Loss for client: 1 :0.2992759384214878
Loss C1: 0.2992759384214878
local epoch for train_loader 2: 263/300
Loss for client: 2 :0.08518570093881517
Loss C2: 0.08518570093881517
local epoch for train_loader 4: 263/300
Loss for client: 4 :0.7039399147033691
Loss C4: 0.7039399147033691
train_loss: 0.3628
----------
local epoch 264/300
local epoch for train_loader 1: 264/300
Loss for client: 1 :0.3040863238275051
Loss C1: 0.3040863238275051
local epoch for train_loader 2: 264/300
Loss for client: 2 :0.08834247220130194
Loss C2: 0.08834247220130194
local epoch for train_loader 4: 264/300
Loss for client: 4 :0.7055727183818817
Loss C4: 0.705572718381881

Loss for client: 2 :0.07481138479141962
Loss C2: 0.07481138479141962
local epoch for train_loader 4: 282/300
Loss for client: 4 :0.703129380941391
Loss C4: 0.703129380941391
train_loss: 0.3594
current epoch: 282 current mean dice: 0.2201 best mean dice: 0.5056 at epoch 76
----------
local epoch 283/300
local epoch for train_loader 1: 283/300
Loss for client: 1 :0.296327106654644
Loss C1: 0.296327106654644
local epoch for train_loader 2: 283/300
Loss for client: 2 :0.07697542224611555
Loss C2: 0.07697542224611555
local epoch for train_loader 4: 283/300
Loss for client: 4 :0.7040723741054535
Loss C4: 0.7040723741054535
train_loss: 0.3591
----------
local epoch 284/300
local epoch for train_loader 1: 284/300
Loss for client: 1 :0.295732282102108
Loss C1: 0.295732282102108
local epoch for train_loader 2: 284/300
Loss for client: 2 :0.07435358705974761
Loss C2: 0.07435358705974761
local epoch for train_loader 4: 284/300
Loss for client: 4 :0.7045413792133332
Loss C4: 0.7045413792133332
trai

In [20]:
checkpoint = torch.load('/home/otarola/miccai22/fedem/'+modality+'_beta_'+str(beta)+'_'+weighting_scheme+'_best_metric_model_segmentation2d_array.pth')
global_model.load_state_dict(checkpoint)
outputs = global_model(inputs)
print(modality)

count_volume = 0
dice_metric.reset()
metric_values_test = []
for test_data in all_test_loader:
    count_volume = count_volume+1
    cur_image, cur_label = test_data
    cur_outputs = []
    cur_labels  = []
    labels   = torch.tensor(cur_label).to(device)
    for ct_slice in range(cur_image.shape[-1]):
        cur_ct_slice = torch.tensor(cur_image[:,:,:,:,ct_slice]).to(device)        
        label    = labels[:,:,:,:,ct_slice]
        outputs = global_model(cur_ct_slice)

        cur_outputs.append(outputs.cpu().detach().numpy()>0.5)
        cur_labels.append(label.cpu().detach().numpy()>0.5)
    #print(torch.tensor(cur_outputs[-1]).shape)
    #print(torch.tensor(cur_labels[-1]).shape)
        dice_metric(y_pred=torch.tensor(cur_outputs[-1]), y=torch.tensor(cur_labels[-1]))

    # aggregate the final mean dice result
    metric = dice_metric.aggregate().item()
    dice_metric.reset()
    metric_values_test.append(metric)
print("AVG TEST DICE SCORE FOR LEARNING RATE "+str(learning_rate) + ": " + str(np.mean(metric_values_test)) + " - STD: " + str(np.std(metric_values_test)))
print(metric_values_test)

Tmax


  labels   = torch.tensor(cur_label).to(device)
  cur_ct_slice = torch.tensor(cur_image[:,:,:,:,ct_slice]).to(device)


AVG TEST DICE SCORE FOR LEARNING RATE 0.000932: 0.35972078268726665 - STD: 0.3079528126439487
[0.07033639401197433, 0.4547461271286011, 0.0, 0.04846785590052605, 0.8387826681137085, 0.8354261517524719, 0.24476110935211182, 0.8222819566726685, 0.36120402812957764, 0.8022418022155762, 0.33318814635276794, 0.11142061650753021, 0.2721012532711029, 0.0, 0.200853630900383]
