In [1]:
import numpy as np
from dpipe.medim.visualize import slice3d
import torch
import os
import sys
import cv2
import copy
from matplotlib import pyplot as plt
sys.path.append('/nmnt/media/home/alex_samoylenko/Federated/FederatedUNet')

from FederatedUNet.model.model import UNet

In [2]:
model = UNet(n_channels=1, n_classes=1).float()

In [3]:
WEIGHTS_PATH = '/nmnt/media/home/alex_samoylenko/experiments/Federated/no_federated/model.pth'
weights = torch.load(WEIGHTS_PATH)
model.load_state_dict(weights)

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [15]:
EXP_NAME = 'no_federated'
TARGET_PATH = '/nmnt/x3-hdd/data/DA/CC359/Silver-standard-MLScaled'

VALID_PRED_PATH = f'/nmnt/x3-hdd/data/Federated/{EXP_NAME}/valid'
TEST_PRED_PATH = f'/nmnt/x3-hdd/data/Federated/{EXP_NAME}/test'
VALID_IMG_NAMES = [img_name[:-4] for img_name in os.listdir(VALID_PRED_PATH)]
TEST_IMG_NAMES = [img_name[:-4] for img_name in os.listdir(TEST_PRED_PATH)]

targets, preds = dict(), dict()

for img_name in VALID_IMG_NAMES:
    target = np.load(os.path.join(TARGET_PATH, img_name + '_ss.npy'))
    target = cv2.resize(target, (256, 170))
    targets[img_name] = target
    
    pred = np.load(os.path.join(VALID_PRED_PATH, img_name + '.npy'))
    pred = torch.sigmoid(torch.tensor(pred))
    preds[img_name] = pred.numpy()

In [5]:
def dice_score_slice(pred_slice, target_slice, threshold):
    diff = abs(pred_slice - target_slice)
    return diff[diff == 0].shape[0] / (pred_slice.shape[0] * pred_slice.shape[1])

def dice_score(pred, target, threshold=0.5):
    pred, target = copy.deepcopy(pred), copy.deepcopy(target)
    pred, target = np.moveaxis(pred, -1, 0), np.moveaxis(target, -1, 0)
    pred[pred <= threshold] = 0
    pred[pred > threshold] = 1
    return [dice_score_slice(pred_slice, target_slice, threshold) for pred_slice, target_slice in zip(pred, target)]

In [6]:
for threshold in [0.05, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7]:
    dices = []
    for img_name in VALID_IMG_NAMES:
        pred, target = preds[img_name], targets[img_name]
        dices.extend(dice_score(pred, target, threshold=threshold))
    print(f"THRESHOLD: {threshold}, MEAN_DICE: {np.mean(dices)}") # best is 0.6

THRESHOLD: 0.05, MEAN_DICE: 0.98740112901572
THRESHOLD: 0.3, MEAN_DICE: 0.992320739015047
THRESHOLD: 0.35, MEAN_DICE: 0.9926142829019886
THRESHOLD: 0.4, MEAN_DICE: 0.9928397616161893
THRESHOLD: 0.45, MEAN_DICE: 0.9930134392042578
THRESHOLD: 0.5, MEAN_DICE: 0.993140172099998
THRESHOLD: 0.55, MEAN_DICE: 0.9932250270157905
THRESHOLD: 0.6, MEAN_DICE: 0.9932658510995367
THRESHOLD: 0.65, MEAN_DICE: 0.9932597880672315
THRESHOLD: 0.7, MEAN_DICE: 0.9931984017426285


# Separate classes

In [7]:
from collections import defaultdict

In [8]:
class_to_img_name = defaultdict(list)
for img_name in VALID_IMG_NAMES:
    class_name = img_name.split('_')[1] + '-' + img_name.split('_')[4][0]
    class_to_img_name[class_name].append(img_name)

In [9]:
class_to_dice = {}
for _class in class_to_img_name:
    class_img_names = class_to_img_name[_class]
    dices = []
    for img_name in class_img_names:
        pred, target = preds[img_name], targets[img_name]
        dices.extend(dice_score(pred, target, threshold=0.6))
    class_to_dice[_class] = np.mean(dices)

In [10]:
class_to_dice

{'siemens-F': 0.99598435345818,
 'ge-M': 0.9921803193933824,
 'philips-F': 0.9914452278056657,
 'philips-M': 0.9910250209602188,
 'siemens-M': 0.9951132680855546,
 'ge-F': 0.9937355340695849}

# Test score

In [23]:
for img_name in TEST_IMG_NAMES:
    target = np.load(os.path.join(TARGET_PATH, img_name + '_ss.npy'))
    target = cv2.resize(target, (256, 170))
    targets[img_name] = target
    
    pred = np.load(os.path.join(TEST_PRED_PATH, img_name + '.npy'))
    pred = torch.sigmoid(torch.tensor(pred))
    preds[img_name] = pred.numpy()

In [24]:
class_to_img_name_test = defaultdict(list)
for img_name in TEST_IMG_NAMES:
    class_name = img_name.split('_')[1] + '-' + img_name.split('_')[4][0]
    class_to_img_name_test[class_name].append(img_name)

In [27]:
class_to_dice_test = {}
for _class in class_to_img_name_test:
    class_img_names = class_to_img_name_test[_class]
    dices = []
    for img_name in class_img_names:
        pred, target = preds[img_name], targets[img_name]
        dices.extend(dice_score(pred, target, threshold=0.6))
    class_to_dice_test[_class] = np.mean(dices)

In [28]:
class_to_dice_test

{'ge-M': 0.9922594599363183,
 'siemens-F': 0.9955436856138941,
 'philips-F': 0.9919339302956031,
 'philips-M': 0.9854507754090844,
 'siemens-M': 0.994167472490298,
 'ge-F': 0.9926221286549288}