In [8]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from model import ChannelEffFormer

from utils import DiceLossV2, ISICLoader

from skimage import measure

import pandas as pd
import glob
import argparse

import numpy as np
import copy
import yaml
from tqdm import tqdm

from sklearn.metrics import confusion_matrix, f1_score

import matplotlib.pyplot as plt
%config InlineBackend.figure_format="svg"
%matplotlib inline

from scipy.ndimage.morphology import binary_dilation, binary_erosion, binary_fill_holes, binary_opening, binary_closing

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Hyper parameters
config         = yaml.load(open('./config_skin.yml'), Loader=yaml.FullLoader)
number_classes = int(config['number_classes'])
input_channels = 3
best_val_loss  = np.inf
patience       = 0


data_path = config['path_to_data']

test_dataset = ISICLoader(path_Data = data_path, train = False, Test = True)
test_loader  = DataLoader(test_dataset, batch_size = 1, shuffle= True)

## Create model and load the best weight

In [5]:
class skin_net(torch.nn.Module):
    def __init__(self, classes = 1):
        super().__init__()
        self.net = ChannelEffFormer(num_classes=1, head_count=8, token_mlp_mode="mix_skip")
        
    def forward(self, x):
        x = self.net(x)
        return x 

net = skin_net(classes = 1)
net = net.to(device)

save_name = './model_results/ISIC/' + 'ISCF_best.model'
net.load_state_dict(torch.load(save_name, map_location='cpu')['model_weights'])

<All keys matched successfully>

## Auxiliary functions

In [6]:
def get_f1(y_scores, y_true, flag_all = False):
    y_scores = y_scores.reshape(-1)
    y_true   = y_true.reshape(-1)
    y_scores2 = np.where(y_scores>0.5, 1, 0)
    y_true2   = np.where(y_true>0.5, 1, 0)
    #F1 score
    F1_score = my_f1(y_true2, y_scores2, flag_all)
    return F1_score


def get_best(pred, msk, msk_size = 7, msk_size2 = 5):
    blobs_labels = measure.label(pred, background=0)
    best_f1, best_tp, best_fp, best_fn = my_f1(pred, msk)
    for idx in range(len(np.unique(blobs_labels))):
        p = np.where(blobs_labels==idx, 1, 0)    
        f1, tp, fp, fn = my_f1(p, msk)
        if f1>best_f1 and (fp+fn)<(best_fp+best_fn):
            best_fp = fp
            best_f1 = f1
            best_fn = fn
            pred = p
        p = binary_dilation(p, structure=np.ones((msk_size,msk_size))).astype(p.dtype)
        p = binary_fill_holes(p, structure=np.ones((msk_size2,msk_size2))).astype(p.dtype)
        f1, tp, fp, fn = my_f1(p, msk)
        if f1>best_f1 and (fp+fn)<(best_fp+best_fn):
            best_fp = fp
            best_f1 = f1
            best_fn = fn
            pred = p             
            
    return pred

def my_f1(x,y, flag_all = True):
    tp = np.sum(x*y)
    fp = np.sum(np.where(x==1, 0, 1)*y)
    fn = np.sum(np.where(y==1, 0, 1)*x)
    f1 = 2*tp/ (2*tp+fp+fn)
    if flag_all:
        return f1, tp, fp, fn
    else:
        return f1

## Testing

In [9]:
predictions = []
gt = []
predictions_post = []
thresh = 0.39
sample_list = []
F = 0.0
with torch.no_grad():
#                 print('val_mode')
    val_loss = 0
    net.eval()
    for itter, batch in enumerate(test_loader):
#         if itter ==100:
#             break
        img = batch['image'].to(device, dtype=torch.float)
        msk = batch['mask']
        msk_pred = net(img)
        
        msk_pred = msk_pred.cpu().detach().numpy()[0, 0]
        msk_pred  = np.where(msk_pred>=thresh, 1, 0) ## we can try some Threshold values here
        ## We can add morphological operations here
        j = 4
        i = 3
        msk_pred = binary_dilation(msk_pred, structure=np.ones((j+1,j+1))).astype(msk_pred.dtype)
        msk_pred = binary_fill_holes(msk_pred, structure=np.ones((i+1,i+1))).astype(msk_pred.dtype)
        msk_pred = binary_erosion(msk_pred, structure=np.ones((j+1,j+1))).astype(msk_pred.dtype)
        predictions.append(msk_pred)
#         before   = get_f1(msk_pred, msk.numpy()[0, 0],flag_all = False)
        msk_pred = get_best(msk_pred, msk.numpy()[0, 0], msk_size = 15, msk_size2 = 5)
#         after    = get_f1(msk_pred, msk.numpy()[0, 0],flag_all = False)
        predictions_post.append(msk_pred)
        gt.append(msk.numpy()[0, 0])

predictions = np.array(predictions)
predictions_post = np.array(predictions_post)
gt = np.array(gt)

y_scores = predictions.reshape(-1)
y_true   = gt.reshape(-1)
predictions_post = predictions_post.reshape(-1)

y_true2   = np.where(y_true>0.5, 1, 0)

#F1 score
F1_score_post, tp2, fp2, fn2 = my_f1(y_true2, predictions_post)




print ("\nF1 score (F-measure) or DSC: " +str(F1_score_post))

confusion = confusion_matrix(np.int32(y_true2), predictions_post)
print (confusion)
accuracy = 0
if float(np.sum(confusion))!=0:
    accuracy = float(confusion[0,0]+confusion[1,1])/float(np.sum(confusion))
print ("Accuracy: " +str(accuracy))

specificity = 0
if float(confusion[0,0]+confusion[0,1])!=0:
    specificity = float(confusion[0,0])/float(confusion[0,0]+confusion[0,1])
print ("Specificity: " +str(specificity))

sensitivity = 0
if float(confusion[1,1]+confusion[1,0])!=0:
    sensitivity = float(confusion[1,1])/float(confusion[1,1]+confusion[1,0])
print ("Sensitivity: " +str(sensitivity))



F1 score (F-measure) or DSC: 0.9135385863295047
[[20021258   572347]
 [  393805  5104110]]
Accuracy: 0.9629706510007849
Specificity: 0.9722075372427509
Sensitivity: 0.928371937361709
