In [None]:
import json
import os
import numpy as np
from pprint import pprint
from copy import deepcopy
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.utils.data as Data
from torchvision.ops import box_iou
from torchvision.ops import nms

from tqdm import tqdm
import cv2
from ensemble_boxes import non_maximum_weighted
import visualization as gc1

In [None]:
# OpenCV font face used by make_bbox when writing a detection score next to its box.
font = cv2.FONT_HERSHEY_SIMPLEX

def make_show(img):
    """Min-max scale a 2-D image to 0..255 and stack it into 3 identical channels.

    Parameters
    ----------
    img : np.ndarray
        2-D intensity image of shape (H, W).

    Returns
    -------
    np.ndarray
        (H, W, 3) int16 array with values in [0, 255], suitable for cv2 drawing.
        A constant image (max == min) maps to all zeros instead of the NaN-derived
        garbage that 0/0 would otherwise produce.
    """
    img = np.asarray(img)
    lo = img.min()
    value_range = img.max() - lo
    if value_range > 0:
        scaled = ((img - lo) / value_range) * 255
    else:
        # Constant (or NaN-containing) image: avoid 0/0 and an undefined int cast.
        scaled = np.zeros(img.shape)
    img = scaled.astype('int16')
    # Replicate the single channel so coloured boxes can be drawn on it.
    return img[:, :, np.newaxis].repeat(3, axis=2)

def make_bbox(img, anno, TP=1, score=None):
    """Draw one [x1, y1, x2, y2] box on an image with OpenCV.

    Parameters
    ----------
    img : np.ndarray
        Image to draw on (e.g. the output of make_show); cv2 draws onto it and the
        result is returned.
    anno : sequence
        Box corners [x1, y1, x2, y2]; each value is cast to int pixel coordinates.
    TP : int
        1 draws the box in (0, 255, 0); any other value draws it in (0, 0, 255) and,
        when ``score`` is given, writes the score near the box's top-right corner.
    score : float or None
        Score to print for a non-TP box. If None, no text is drawn.

    Returns
    -------
    np.ndarray
        The image with the box (and optional score text) drawn.
    """
    color = (0, 255, 0) if TP == 1 else (0, 0, 255)
    pt1 = (int(anno[0]), int(anno[1]))
    pt2 = (int(anno[2]), int(anno[3]))

    img = cv2.rectangle(img, pt1, pt2, color, 1)
    if TP != 1 and score is not None:
        # cv2 requires plain Python ints for the text origin, so reuse the cast
        # corners (x2, y1) rather than the raw, possibly float, annotation values.
        cv2.putText(img, '%.4f' % score, (pt2[0], pt1[1]), font, 0.7, (255, 0, 0), 2)

    return img

In [None]:
def compute_ap(recall, precision):
    """Area under a precision/recall curve, computed the py-faster-rcnn way.

    Adapted from https://github.com/rbgirshick/py-faster-rcnn.

    Args:
        recall: recall values (list or 1-D array), index-aligned with ``precision``.
        precision: precision value at each recall point.

    Returns:
        Sum over every recall step of (recall increment) x (envelope precision).
    """
    # Pad with sentinels: recall spans 0 -> 1 and precision is 0 at both ends.
    rec = np.concatenate(([0.], recall, [1.]))
    prec = np.concatenate(([0.], precision, [0.]))

    # Precision envelope: every entry becomes the max over itself and all later entries.
    envelope = np.maximum.accumulate(prec[::-1])[::-1]

    # Only positions where recall actually changes contribute area.
    steps = np.where(rec[1:] != rec[:-1])[0]
    return np.sum((rec[steps + 1] - rec[steps]) * envelope[steps + 1])

In [None]:
model_path = '../test_output/N4_All/'
# Show the available experiment output folders so a `key` can be chosen below.
available_runs = sorted(os.listdir(model_path))
pprint(available_runs)

In [None]:
# Experiment sub-folder of '../test_output/N4_All/' to visualise (pick from the listing above).
key = 'mrDs_bs16_iou03_size892'

In [None]:
# Output folder of the selected experiment; derived from model_path so the two
# cannot drift apart if the root directory changes.
result_path = model_path + key + '/'
print(result_path)
result_list = sorted(os.listdir(result_path))

In [None]:
# Destination for the rendered comparison images. makedirs also creates missing
# parent folders (os.mkdir would raise if '../view/test_SEG/' did not exist yet)
# and is a no-op when the folder is already there.
view_path = '../view/test_SEG/%s/' % key
os.makedirs(view_path, exist_ok=True)

In [None]:
NMS_IOU = 0.1  # IoU threshold passed to torchvision NMS below
TOP_K = 3      # number of highest post-NMS scores averaged per patient

# Per-fold result archives are the files in result_path whose names start with 'K'.
result_list = sorted([each for each in os.listdir(result_path) if each[0] == 'K'])

# K_pid_score['<fold>'][pid] -> mean of the TOP_K highest post-NMS detection scores
# collected over all slices of that pid (pid = case name without its last
# '_'-separated part). Every pid is seeded with a single 0 score, so a pid with no
# detections scores 0 and a pid with fewer than TOP_K detections has that 0
# included in its mean.
K_pid_score = {}

for k_idx in range(5):
    tmp_result_npz = [each for each in result_list if each.startswith('K%s' % k_idx)][0]
    print(tmp_result_npz)
    # allow_pickle=True (kept from the original) unpickles stored objects:
    # only load trusted result files. The context manager closes the archive.
    with np.load(result_path + tmp_result_npz, allow_pickle=True) as tmp_result_file:
        case_list = tmp_result_file['case']
        all_detections = tmp_result_file['det']
    assert len(case_list) == len(all_detections), 'case/det length mismatch in %s' % tmp_result_npz

    pid_scores = {}
    for case, detections in zip(case_list, all_detections):
        pid = '_'.join(case.split('_')[:-1])
        score_list = pid_scores.setdefault(pid, [0])

        # Nothing to suppress; also avoids column-indexing an empty detection array.
        if len(detections) == 0:
            continue

        # Columns 0-3 are the box, column 4 the score.
        keep = nms(torch.tensor(detections[:, :4]), torch.tensor(detections[:, 4]), NMS_IOU)
        score_list.extend(detections[keep.numpy(), 4])

    K_pid_score[str(k_idx)] = {
        pid: np.mean(sorted(score_list)[-TOP_K:]) for pid, score_list in pid_scores.items()
    }

In [None]:
result_list = sorted([each for each in os.listdir(result_path) if each[0] == 'K'])

s_th = 0.05            # draw a detection only if (fused box score x patient score) >= s_th
NMS_IOU = 0.1          # IoU threshold for non_maximum_weighted box fusion
cls_th = 0             # NOTE(review): not used in this cell
N_FOLDS = 5
SKIP_BOX_THR = 0.0001  # passed to non_maximum_weighted as skip_box_thr


def _fold_result_name(k_idx):
    """Return the first file in result_list whose name starts with 'K<k_idx>'."""
    return [each for each in result_list if each.startswith('K%s' % k_idx)][0]


def _load_result_arrays(file_name, *npz_keys):
    """Read the named arrays from a result .npz in result_path, closing the file."""
    # allow_pickle=True (as before) unpickles stored objects: only load trusted files.
    with np.load(result_path + file_name, allow_pickle=True) as npz:
        return [npz[npz_key] for npz_key in npz_keys]


def _fuse_boxes(detections, height, width, iou_thr):
    """Merge overlapping boxes with ensemble_boxes.non_maximum_weighted.

    detections: rows of [x1, y1, x2, y2, ..., score] in pixel coordinates.
    Returns an (M, 5) array of [x1, y1, x2, y2, fused_score] in pixel coordinates.
    """
    # ensemble_boxes expects [0, 1]-normalized coordinates. Normalize by the real
    # image size (x by width, y by height) rather than a hard-coded 448, so boxes on
    # images larger than 448 px are not clipped to the 448 border.
    scale = np.array([width, height, width, height], dtype=np.float64)
    boxes_list = [detections[:, :4] / scale]
    scores_list = [detections[:, -1]]
    labels_list = np.ones_like(scores_list)
    boxes, fused_scores, _ = non_maximum_weighted(boxes_list, scores_list, labels_list,
                                                  iou_thr=iou_thr, skip_box_thr=SKIP_BOX_THR)
    return np.concatenate([boxes * scale, fused_scores[:, np.newaxis]], axis=1)


# Case list and ground-truth boxes come from fold 0; like the original, this
# assumes every fold file stores the same cases in the same order.
case_list, all_annotations = _load_result_arrays(_fold_result_name(0), 'case', 'anno')
# Load each fold's detections once, instead of re-reading every archive per slice.
fold_detections = [_load_result_arrays(_fold_result_name(k_idx), 'det')[0]
                   for k_idx in range(N_FOLDS)]

for i in tqdm(range(len(case_list))):
    case_slice = case_list[i]
    seg_anno_name = '_'.join(case_slice.split('_')[:-1])
    with np.load('../data/N4_All_img/%s.npz' % case_slice) as case_slice_array:
        fl_image = case_slice_array['FL']
    height, width = fl_image.shape[:2]

    # Base panel: the scaled image with every ground-truth box drawn (TP style).
    image_list = [make_show(fl_image)]
    for anno in all_annotations[i]:
        image_list = [make_bbox(image, anno) for image in image_list]

    # One copy of the base panel per fold, with that fold's fused detections drawn
    # in the non-TP style together with their patient-weighted scores.
    draw_list = []
    for k_idx in range(N_FOLDS):
        tmp_image_list = deepcopy(image_list)
        detections = fold_detections[k_idx][i]
        if len(detections) != 0:
            for d in _fuse_boxes(detections, height, width, NMS_IOU):
                tmp_score = d[4] * K_pid_score[str(k_idx)][seg_anno_name]
                if tmp_score < s_th:
                    continue
                anno = d[:4].astype('int16')
                tmp_image_list = [make_bbox(image, anno, TP=0, score=tmp_score)
                                  for image in tmp_image_list]
        draw_list.append(tmp_image_list)

    # Fold panels side by side, fold 0 on the left.
    for_show = [np.concatenate(each, axis=1) for each in draw_list]
    for_show_final = np.concatenate(for_show, axis=1)
    out_path = view_path + '%s.png' % case_slice
    # cv2.imwrite reports failure via its return value rather than raising.
    if not cv2.imwrite(out_path, for_show_final):
        raise OSError('cv2.imwrite failed to write %s' % out_path)