In [None]:
import os
import csv
import cv2
import math
import time
import json
import glob
import shutil
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm as tqdm


In [None]:
# --- Inputs ---
# TODO: predicted mask results to evaluate (one sub-folder per split, e.g. test_pred_raw).
# NOTE: glob matching is case-sensitive on POSIX systems, so only '.PNG' files are found.
pred_all_crown_list = sorted(glob.glob('./result/*_pred_raw/*.PNG'))
assert pred_all_crown_list, 'wrong pred path'

# TODO: ground-truth label folder to compare against. Alternatives used previously:
#   '../data/label_CEJline_mask_385_512/'
#   '../data/7class_generate_tooth_mask/'
label_folder = '../3-Unet-mask-process/385_cej_4_pixel_512/CEJ_line_mask/'
assert glob.glob(label_folder + '/*.PNG'), 'wrong label path'

In [None]:
# --- Output ---
# TODO: destination of the per-image metrics CSV; an existing file at this path is overwritten.
_iou_csv_dir = "./IOU_csv"
output_csv_path = f"{_iou_csv_dir}/CEJ_385.csv"

In [None]:
def show_pixel_set(img_nparray):
    """Return a histogram of the distinct pixel values in an array.

    Args:
        img_nparray: array-like of pixel values (any shape; it is flattened).

    Returns:
        dict mapping each distinct value to the number of times it occurs.
    """
    values, counts = np.unique(img_nparray, return_counts=True)
    return {value: count for value, count in zip(values, counts)}
def count_iou(label_img_gray, predict_img_gray, val):
    """Intersection-over-Union of the pixels equal to ``val`` in two images.

    Each image is binarised to "pixel == val" and IoU = |A & B| / |A | B|.
    The metric is symmetric, so argument order does not affect the result.

    Args:
        label_img_gray: ground-truth image array.
        predict_img_gray: predicted image array (same shape as the label).
        val: pixel value that defines the foreground class.

    Returns:
        IoU in [0, 1]. NaN when neither image contains ``val`` (IoU is
        undefined); it is returned explicitly so no 0/0 RuntimeWarning fires.
    """
    label_mask = np.asarray(label_img_gray) == val
    predict_mask = np.asarray(predict_img_gray) == val
    union = np.count_nonzero(label_mask | predict_mask)
    if union == 0:
        # Keep NaN (not 0 or 1): pandas means/notna counts skip these images.
        return np.nan
    intersection = np.count_nonzero(label_mask & predict_mask)
    return intersection / union
def count_pixel_acc(label_img_gray, predict_img_gray, val):
    """Pixel accuracy of the binary "pixel == val" masks of two images.

    Both images are binarised to "pixel == val"; the accuracy is the fraction
    of label pixels where the two binary masks agree (1 - mismatch fraction).

    Args:
        label_img_gray: ground-truth image array.
        predict_img_gray: predicted image array (same shape as the label).
        val: pixel value that defines the foreground class.

    Returns:
        Accuracy in [0, 1], or NaN for an empty label image (no pixels to score).
    """
    label_mask = np.asarray(label_img_gray) == val
    predict_mask = np.asarray(predict_img_gray) == val
    total = label_mask.size
    if total == 0:
        return np.nan
    mismatches = np.count_nonzero(label_mask != predict_mask)
    return 1 - mismatches / total

In [None]:
# Compute per-image metrics for every prediction that has a matching label.
# Rows are collected in a list and turned into a DataFrame once at the end
# (DataFrame.append was removed in pandas 2.0 and is quadratic in a loop).
records = []
for pred_img_path in tqdm(pred_all_crown_list):
    img_name = os.path.basename(pred_img_path)
    # Parent folder name (e.g. 'test_pred_raw'). os.path handles both '/' and '\\'
    # separators, unlike a Windows-only split('\\').
    img_set_type = os.path.basename(os.path.dirname(pred_img_path))
    label_img_path = label_folder + '/' + img_name
    # Check existence before reading: cv2.imread returns None instead of raising.
    if not os.path.exists(label_img_path):
        print('no label image ', label_img_path)
        continue
    pre_img = cv2.imread(pred_img_path, 0)
    label_img = cv2.imread(label_img_path, 0)
    if pre_img is None or label_img is None:
        print('unreadable image ', pred_img_path, label_img_path)
        continue
    label_h, label_w = label_img.shape[:2]

    record_dict = {'image': img_name}
    # Swap only the first 'result' (the top-level folder) so file names that
    # happen to contain 'result' are left intact.
    post_img_pth = pred_img_path.replace('result', 'result_post_process', 1)
    assert post_img_pth != pred_img_path, 'wrong post_img_pth' + post_img_pth
    if os.path.exists(post_img_pth):
        post_img = cv2.imread(post_img_pth, 0)
        # cv2.resize takes dsize=(width, height), the reverse of numpy's (h, w) shape.
        post_img = cv2.resize(post_img.astype(np.uint8), (label_w, label_h), interpolation=cv2.INTER_NEAREST)
        record_dict['post_iou'] = count_iou(post_img, label_img, 255)
        record_dict['post_PA'] = count_pixel_acc(post_img, label_img, 255)
    record_dict['iou'] = count_iou(pre_img, label_img, 255)
    # Fraction of non-zero label pixels over the true pixel count (no square-image assumption).
    record_dict['ratio'] = np.count_nonzero(label_img) / label_img.size
    record_dict['PA'] = count_pixel_acc(pre_img, label_img, 255)
    record_dict['image_set'] = img_set_type
    records.append(record_dict)

record_iou = pd.DataFrame(records)
# to_csv does not create missing directories.
os.makedirs(os.path.dirname(output_csv_path) or '.', exist_ok=True)
record_iou.to_csv(output_csv_path)

In [None]:
def show_metrix(df):
    """Print summary statistics and plot 100-bin histograms of pixel accuracy and IoU.

    Args:
        df: DataFrame with numeric 'PA' and 'iou' columns.
    """
    columns = ('PA', 'iou')
    for column in columns:
        print(df[column].describe())
    fig, axes = plt.subplots(1, len(columns), figsize=(15, 3))
    for ax, column in zip(axes, columns):
        ax.hist(df[column], bins=100)
        ax.set_title(column)
    plt.show()
def calculate_metrics(df, type=""):
    """Aggregate per-image metrics into summary statistics.

    Args:
        df: per-image records with columns ``<type>iou``, ``<type>PA`` and,
            unless ``type == 'post_'``, ``<type>ratio``.
        type: column prefix; '' for raw predictions, 'post_' for post-processed.
            (Name kept for backward compatibility even though it shadows the builtin.)

    Returns:
        dict with NaN-skipping mean and population std (ddof=0) of IoU and PA,
        the mean label ratio (NaN for 'post_', which has no ratio column), and
        'z_count', the number of non-NaN IoU values.
    """
    iou_col = df[type + 'iou']
    pa_col = df[type + 'PA']
    ratio = np.nan if type == 'post_' else df[type + 'ratio'].mean()
    return {
        'iou_mean': iou_col.mean(),
        'iou_std': iou_col.std(ddof=0),
        'PA_mean': pa_col.mean(),
        'PA_std': pa_col.std(ddof=0),
        'ratio': ratio,
        'z_count': iou_col.notna().sum(),
    }
# One summary row per split (and per post-processed split when available).
# Rows are gathered in a list: DataFrame.append was removed in pandas 2.0.
summary_rows = []
for img_type_set in ['test', 'val', 'train']:
    split_df = record_iou[record_iou.image_set == f'{img_type_set}_pred_raw']
    # Call show_metrix(split_df) here to inspect the per-split histograms.
    # '1_name' is placed first so the split label leads each row.
    summary_rows.append({'1_name': img_type_set, **calculate_metrics(split_df)})
    if 'post_iou' in split_df.columns:
        summary_rows.append({'1_name': 'post_' + img_type_set, **calculate_metrics(split_df, 'post_')})
result_df = pd.DataFrame(summary_rows)
result_df.round(4)

In [None]:
# NOTE(review): generic file name; consider a path that identifies the experiment.
summary_csv_path = 'tmp.csv'
result_df.to_csv(summary_csv_path, index=False)