In [1]:
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt

In [3]:
import os
import cv2
import glob
import numpy as np
from tqdm import tqdm

import torch

In [4]:
'''
The following metric code is heavily borrowed from https://github.com/Mhaiyang/ICCV2019_MirrorNet, 
and computes results comparable with those reported in "Where Is My Mirror? (2019, ICCV)", "Don't Hit Me! Glass Detection in Real-world Scenes (2020, CVPR)" and "Progressive Mirror Detection (2020, CVPR)"
'''
def compute_iou(predict_mask, gt_mask):
    """
    Intersection-over-Union between a predicted mask and a ground-truth mask.

    (1/n_cl) * sum_i(n_ii / (t_i + sum_j(n_ji) - n_ii))
    Here, n_cl = 1 as we have only one class (mirror).

    Every non-zero value is treated as foreground, so {0, 1}, {0, 255} and
    boolean masks all give the same result. All three counts are pixel
    counts, so they are mutually consistent.

    Args:
        predict_mask: predicted mask, array-like, broadcastable to gt_mask.
        gt_mask: ground-truth mask.

    Returns:
        IoU in [0, 1]. Returns 0 when either mask has no foreground pixels.
        This includes the case where both masks are empty, which matches
        the reference implementation this code was ported from.
    """
    pred = np.asarray(predict_mask).astype(bool)
    gt = np.asarray(gt_mask).astype(bool)

    n_ij = np.count_nonzero(pred)  # predicted foreground pixels
    t_i = np.count_nonzero(gt)     # ground-truth foreground pixels
    if n_ij == 0 or t_i == 0:
        return 0

    n_ii = np.count_nonzero(pred & gt)  # intersection
    return n_ii / (t_i + n_ij - n_ii)

def compute_acc_image(predict_mask, gt_mask):
    """
    Pixel accuracy: fraction of pixels whose foreground/background label in
    predict_mask agrees with gt_mask.

    Every non-zero value is treated as foreground. The denominator is the
    number of ground-truth pixels, not the sum of their values, so 0/255
    masks are scored correctly.

    Args:
        predict_mask: predicted mask, array-like, broadcastable to gt_mask.
        gt_mask: ground-truth mask.

    Returns:
        (TP + TN) / number of pixels, a float in [0, 1].

    Raises:
        ValueError: if gt_mask contains no pixels.
    """
    pred = np.asarray(predict_mask).astype(bool)
    gt = np.asarray(gt_mask).astype(bool)
    if gt.size == 0:
        raise ValueError("compute_acc_image: gt_mask is empty")

    tp = np.count_nonzero(pred & gt)
    tn = np.count_nonzero(~pred & ~gt)
    return (tp + tn) / gt.size

def compute_mae(predict_mask, gt_mask):
    """
    Mean absolute error between a predicted map and the ground truth.

    Integer and boolean inputs are promoted to float64 before subtracting.
    In NumPy, ``uint8 - uint8`` wraps around and ``bool - bool`` raises, so
    raw 8-bit masks would otherwise give wrong results or crash.
    Floating-point inputs are used as given, so float32 maps in [0, 1]
    are scored exactly as before.

    Args:
        predict_mask: predicted map, array-like, broadcastable to gt_mask.
        gt_mask: ground-truth map.

    Returns:
        The mean of |predict_mask - gt_mask| as a Python float.
    """
    def _to_float(x):
        # Leave floating arrays untouched; promote ints/bools to avoid wraparound.
        x = np.asarray(x)
        return x if np.issubdtype(x.dtype, np.floating) else x.astype(np.float64)

    diff = _to_float(predict_mask) - _to_float(gt_mask)
    return np.mean(np.abs(diff)).item()

def compute_ber(predict_mask, gt_mask):
    """
    Balanced error rate: 1 - (TP / N_p + TN / N_n) / 2.

    Every non-zero value is treated as foreground. N_p and N_n are pixel
    counts, consistent with how TP and TN are counted, so 0/255 masks are
    scored correctly.

    Args:
        predict_mask: predicted mask, array-like, broadcastable to gt_mask.
        gt_mask: ground-truth mask.

    Returns:
        BER as a float in [0, 1]. Returns nan when gt_mask has no
        foreground or no background pixels, because one of the two rates
        is 0/0 there. Averaging nan values yields nan, so such images
        must be handled by the caller.
    """
    pred = np.asarray(predict_mask).astype(bool)
    gt = np.asarray(gt_mask).astype(bool)

    n_p = np.count_nonzero(gt)  # ground-truth foreground pixels
    n_n = gt.size - n_p         # ground-truth background pixels

    tp = np.count_nonzero(pred & gt)
    tn = np.count_nonzero(~pred & ~gt)

    tpr = tp / n_p if n_p else float("nan")
    tnr = tn / n_n if n_n else float("nan")
    return 1 - 0.5 * (tpr + tnr)

class FMeasure():
    """
    Accumulates per-image precision and recall over a fixed grid of
    thresholds and reports the F-measure for every threshold.

    With n = div_n * count_per_div, the thresholds are ``i / n`` for
    ``i = 0 .. n-1``. They are evaluated in ``div_n`` chunks of
    ``count_per_div`` thresholds each. This bounds the size of the
    temporary ``[count_per_div, h, w]`` comparison tensor.
    """

    def __init__(self, div_n=1, count_per_div=256, device=None):
        """
        Args:
            div_n: number of threshold chunks.
            count_per_div: number of thresholds evaluated per chunk.
            device: torch device holding the accumulators. Defaults to
                "cuda" when available, otherwise "cpu".

        Raises:
            ValueError: if div_n * count_per_div is not positive.
        """
        self.div_n, self.count_per_div = div_n, count_per_div
        n_th = div_n * count_per_div
        if n_th <= 0:
            raise ValueError("div_n * count_per_div must be positive")
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        self.prec = torch.zeros(n_th, dtype=torch.float32, device=self.device)
        self.recall = torch.zeros(n_th, dtype=torch.float32, device=self.device)
        self.count = torch.zeros(n_th, dtype=torch.float32, device=self.device)
        # Thresholds i / n_th, shaped [n_th, 1, 1] so each slice broadcasts
        # against an [h, w] map to give an [n, h, w] stack of binary maps.
        self.th = (torch.arange(n_th, dtype=torch.float32, device=self.device) / n_th).view(n_th, 1, 1)

    def update(self, predicted, gt):
        """
        Add one image's precision/recall at every threshold.

        Args:
            predicted: 2-D numpy array (prediction map). It is min-max
                normalised per image before thresholding.
            gt: 2-D numpy array with the same shape. It is multiplied into
                the TP count as-is, so binary {0, 1} values are presumably
                expected (TODO: confirm against the caller).
        """
        with torch.no_grad():
            pred = torch.from_numpy(predicted).to(self.device)
            pred = pred - torch.min(pred)
            peak = torch.max(pred)
            # A constant map would give 0/0 = nan. Leaving it at zero gives
            # the same outcome (no pixel exceeds any threshold >= 0) without nan.
            if peak > 0:
                pred = pred / peak

            gt_t = torch.from_numpy(gt).to(self.device)
            gt_total = torch.sum(gt_t)  # loop-invariant recall denominator

            for i in range(self.div_n):
                sl = slice(i * self.count_per_div, (i + 1) * self.count_per_div)
                # [1, h, w] > [count_per_div, 1, 1] -> [count_per_div, h, w]
                binarized = (pred.unsqueeze(0) > self.th[sl]).float()
                tp = torch.sum(binarized * gt_t, dim=(1, 2))
                self.prec[sl] += tp / (torch.sum(binarized, dim=(1, 2)) + 1e-31)
                self.recall[sl] += tp / (gt_total + 1e-31)
                self.count[sl] += 1

    def get(self, beta2=0.3):
        """
        Return F_beta for every threshold as a numpy array of length
        div_n * count_per_div:
        (1 + beta2) * P * R / (beta2 * P + R), where P and R are precision
        and recall averaged over all update() calls.

        Args:
            beta2: beta squared, the weight of precision relative to recall.
        """
        avg_p = self.prec / (self.count + 1e-31)
        avg_r = self.recall / (self.count + 1e-31)
        fmeasures = (1 + beta2) * avg_p * avg_r / (beta2 * avg_p + avg_r + 1e-31)
        return fmeasures.cpu().numpy()
    
def crf_refine(img, annos):
    """
    Refine a soft foreground map with a dense CRF (pydensecrf).

    Args:
        img: uint8 color image whose first two dimensions match annos.
            The caller in this notebook passes RGB channel order.
            It is used by the bilateral pairwise term.
        annos: HxW uint8 map in [0, 255]; larger values mean more
            likely foreground.

    Returns:
        HxW uint8 map: the CRF output for the foreground label, scaled to
        [0, 255].

    Raises:
        AssertionError: on wrong dtypes or mismatched spatial shapes.
        ImportError: if pydensecrf is not installed.
    """
    # Validate before importing the optional dependency so bad inputs fail fast.
    assert img.dtype == np.uint8
    assert annos.dtype == np.uint8
    assert img.shape[:2] == annos.shape

    # Import under the name used below. The previous alias, `crf`, made
    # every `dcrf.` reference raise NameError.
    import pydensecrf.densecrf as dcrf

    def _sigmoid(x):
        return 1 / (1 + np.exp(-x))

    EPSILON = 1e-8  # keeps log() finite at probabilities 0 and 1

    M = 2  # number of labels: background (0) / foreground (1)
    tau = 1.05  # divisor applied to the unary energies
    height, width = img.shape[:2]
    d = dcrf.DenseCRF2D(width, height, M)

    anno_norm = annos / 255.

    # Unary energies: negative log-probability of each label, rescaled.
    n_energy = -np.log((1.0 - anno_norm + EPSILON)) / (tau * _sigmoid(1 - anno_norm))
    p_energy = -np.log(anno_norm + EPSILON) / (tau * _sigmoid(anno_norm))

    U = np.zeros((M, height * width), dtype='float32')
    U[0, :] = n_energy.flatten()
    U[1, :] = p_energy.flatten()
    d.setUnaryEnergy(U)

    d.addPairwiseGaussian(sxy=3, compat=3)
    d.addPairwiseBilateral(sxy=60, srgb=5, rgbim=img, compat=5)

    # One inference iteration; row 1 corresponds to the label given p_energy.
    infer = np.array(d.inference(1)).astype('float32')
    res = infer[1, :] * 255
    return res.reshape(height, width).astype('uint8')

In [7]:
# --- Evaluation configuration ---
use_crf = False       # refine each prediction with crf_refine() before scoring
datacategory = "msd"  # "msd" (mirror dataset) or "glass" (glass detection dataset)
# datacategory = "glass"

# Only these two datasets have ground-truth locations configured; fail fast on a typo.
if datacategory not in ("msd", "glass"):
    raise ValueError(f"datacategory must be 'msd' or 'glass', got {datacategory!r}")

# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/ICCV2019_MirrorNet/ckpt/MirrorNet/MirrorNet_160_nocrf/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/ICCV2019_MirrorNet/ckpt/MirrorNet/MirrorNet_160_crf/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/demo/msd/output/20201016-1-3/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/msd/20201016-1-12/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/fanet_merged_for_released/result/msd_with_fakemix/mirrornet_test/*.png")
# glob returns files in arbitrary order; sort so every run evaluates in the same order.
paths = sorted(glob.glob("/workspace/cpfs-data/transparent_material_detection/fanet_merged_for_released/result/msd_with_fakemix/pmd_test/*.png"))

# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/PMD/result/MSD_nocrf/*.png")
# paths = glob.glob("./result/MSD_crf/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/msd/20201016-1-8/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/msd/20201016-1-8-1/*.png") # test size = (416,416)
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/msd/20201016-1-8-1-1/*.png") # test size = (416,416)
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/msd/20201016-1-8-2/*.png") # test size = (416,416)
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/msd/20201016-1-8-3/*.png") # test size = (416,416)
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/msd/20201016-1-8-2-1/*.png") # test size = (416,416)
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/msd_with_rotate/20201016-1-8-1/*.png") # test size = (384,384)
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/msd_with_rotate/20201016-1-8-2/*.png") # test size = (384,384)
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/msd/20201016-1-9/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/msd/20201016-1-10/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/msd/20201016-1-14/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/msd/20201016-1-15/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/msd/20201016-1-15-1/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/msd/20201016-1-16/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/msd/20201016-1-16-1/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/msd/20201016-1-16-1-1/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/msd/20201016-1-16-2/*.png")

# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/GDD/GDNet_200/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/demo/glass/output/20201016-2-8/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-22/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-23/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-23-1/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-24/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-24-1/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-25/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-26/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-27/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-27-1/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-27-1-1/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-27-1-2/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-27-1-3/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-27-1-2-1/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-27-1-2-1-1/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-27-1-2-2/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-27-1-2-2-1/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-27-1-2-3/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-27-1-2-4/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-27-1-2-4-1/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-27-1-2-5/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-27-1-2-6/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-27-1-2-7/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-27-1-2-7-1/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-27-1-2-8/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-27-1-2-8-1/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-27-1-2-9/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-27-1-2-9-1/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-27-1-2-10/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-27-1-2-10-1/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-27-1-2-11/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-27-2/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-27-1-3-1/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/transsegmoredatasets_v2/results/glass/20201016-2-27-1-3-2/*.png")
# paths = glob.glob("/workspace/cpfs-data/transparent_material_detection/fanet_merged_for_released/result/glass_with_fakemix/test/*.png")

# --- Score every prediction in `paths` against its ground-truth mask ---
if datacategory not in ("msd", "glass"):
    raise ValueError(f"datacategory must be 'msd' or 'glass', got {datacategory!r}")
if not paths:
    raise RuntimeError("no prediction files found; check the glob pattern used to build `paths`")

# Dataset root is loop-invariant; resolve it once.
dataset_dir = "MSD" if datacategory == "msd" else "glass_detection"
data_root = os.path.join("/workspace", "cpfs-data", "datas", dataset_dir, "test")

accu, miou, mae, mber = 0, 0, 0, 0
fmeasure = FMeasure(div_n=8, count_per_div=32)
for single_path in tqdm(paths):
    img_name = os.path.basename(single_path)
    gt_path = os.path.join(data_root, "mask", img_name)

    # cv2.imread returns None (rather than raising) for missing/unreadable files.
    predicted = cv2.imread(single_path, cv2.IMREAD_GRAYSCALE)
    if predicted is None:
        raise FileNotFoundError(f"cannot read prediction: {single_path}")
    gt = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
    if gt is None:
        raise FileNotFoundError(f"cannot read ground truth: {gt_path}")

    # Score at ground-truth resolution.
    if gt.shape[:2] != predicted.shape[:2]:
        predicted = cv2.resize(predicted, (gt.shape[1], gt.shape[0]), interpolation=cv2.INTER_LINEAR)

    if use_crf:
        image_path = os.path.join(data_root, "image", img_name.replace(".png", ".jpg"))
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            raise FileNotFoundError(f"cannot read image for CRF: {image_path}")
        image = image[..., ::-1]  # OpenCV loads BGR; reverse channels to RGB
        crf_res = crf_refine(np.ascontiguousarray(image), np.ascontiguousarray(predicted))
        predicted_p = (crf_res / 255.0).astype(np.float32)
    else:
        predicted_p = (predicted / 255.0).astype(np.float32)

    # Binary masks at 0.5 for the mask-based metrics; soft maps for MAE / F-measure.
    predicted_mask = (predicted_p >= 0.5).astype(np.int32)
    gt_p = (gt / 255.0).astype(np.float32)
    gt_mask = (gt_p >= 0.5).astype(np.int32)

    accu += compute_acc_image(predicted_mask, gt_mask)
    miou += compute_iou(predicted_mask, gt_mask)
    mber += compute_ber(predicted_mask, gt_mask)
    mae += compute_mae(predicted_p, gt_p)
    fmeasure.update(predicted_p, gt_p)

n_images = len(paths)
print(f"accu:{accu/n_images:5.4}")
print(f"mIoU:{miou/n_images:5.4}")
print(f"mae:{mae/n_images:5.4}")
print(f"mber:{mber/n_images:5.4}")
print(f"max fmeasure:{np.max(fmeasure.get())}")

100%|██████████| 955/955 [00:38<00:00, 24.69it/s]

accu:0.9539
mIoU:0.8194
mae:0.04781
mber:0.06576
max fmeasure:0.8908935785293579



