In [None]:
import sys

# Make the parent directory importable. '..' is resolved relative to the
# kernel's current working directory; presumably the parent holds the project
# packages this notebook imports (`datasets`, `nets`) — confirm against repo layout.
sys.path.append('..')

In [None]:
import os

import cv2
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torchvision.transforms import Compose
from tqdm import tqdm

import datasets.multi_segment_dataset as ms
from datasets.transforms import multi_segment_transforms as MST
from nets.seg2d.multiviewunet import MultiViewUNet, UNet

%matplotlib inline

In [None]:
# Inference device. 'cuda:1' (the second GPU) only exists when at least two GPUs
# are visible; otherwise fall back to the CPU so the checkpoint can still be
# loaded and the notebook runs (slowly) on other machines.
device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cpu'

In [None]:
# Locations of the test data, the test record list and the trained weights.
DATA_ROOT = '/root/workspace/DeepRadiology/target_segmentation/'
TEST_RECORDS = [
    '../records/胶质恶性细胞瘤/new_paper/test_multi_front_records.csv'
]
SEGMENT_CKPT_PATH = '../ckpts/new_paper/multi_segment/resnet50_mean_try2/95.pth'

# Test-time preprocessing pipeline (project transforms): filter/normalize with
# parameters (250, 50) — presumably window width / level — convert array to
# image, center-crop to 288x288, then convert to a tensor.
trans = Compose([
    MST.FilterAndNormalizeByWL(250, 50),
    MST.Arr2Image(),
    MST.CenterCrop([288, 288]),
    MST.ToTensor(),
])
dataset = ms.MultiSegmentDataset(data_root=DATA_ROOT,
                                 sample_records=TEST_RECORDS,
                                 multi_fusion=True,
                                 transforms=trans)

# Network configuration; it has to describe the same architecture the
# checkpoint below was trained with, or load_state_dict will fail.
cfg = dict(
    data=dict(num_classes=2, color_channels=1),
    network=dict(
        backbone='resnet50',
        norm_type='batchnorm',
        multi_view_unet=dict(
            modality_num=3,
            operator='mean',
            sum_weight=[0.9, 0.05, 0.05],
            dropout=0.1,
            num_layers=2,
        ),
    ),
)

# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
segment_ckpt = torch.load(SEGMENT_CKPT_PATH, map_location=device)
segment_net = MultiViewUNet(cfg)
segment_net.load_state_dict(segment_ckpt)
segment_net = segment_net.to(device).eval()

In [None]:
def segment_volume(probs, labels, metric_idxs):
    """Collect per-class overlap statistics between a prediction and a label map.

    Args:
        probs: predicted label map (class index per element, e.g. an argmax
            output), compared element-wise against each class index.
        labels: ground-truth label map, same shape as ``probs``.
        metric_idxs: iterable of class indices to measure.

    Returns:
        numpy array with one row per entry of ``metric_idxs``; each row is
        ``[|pred| + |label|, |pred & label|, |label|, |pred|]``. The first entry
        is the Dice denominator (sum of both volumes), not a set union.
    """
    def _class_row(class_idx):
        pred_mask = probs == class_idx
        true_mask = labels == class_idx
        overlap = (pred_mask * true_mask).sum()
        true_vol = true_mask.sum()
        pred_vol = pred_mask.sum()
        return np.array([pred_vol + true_vol, overlap, true_vol, pred_vol])

    return np.array([_class_row(class_idx) for class_idx in metric_idxs])


def segment_metrics(volume):
    """Convert accumulated overlap statistics into Dice, TPVF and PPV scores.

    Args:
        volume: 2-D array with one row per class, laid out as produced by
            ``segment_volume``: ``[|pred| + |label|, intersection, |label|, |pred|]``.

    Returns:
        numpy array with one row ``[dice, TPVF, PPV]`` per class, where
        dice = 2*intersection / (|pred| + |label|), TPVF = intersection / |label|
        (sensitivity) and PPV = intersection / |pred| (precision). Any score
        whose denominator is zero is reported as 0.
    """
    def _ratio(num, den):
        # Empty masks would divide by zero; report 0 for that score instead.
        return 0 if den == 0 else float(num) / den

    return np.array([
        np.array([
            _ratio(2 * row[1], row[0]),  # Dice
            _ratio(row[1], row[2]),      # TPVF
            _ratio(row[1], row[3]),      # PPV
        ])
        for row in volume
    ])

In [None]:
# Evaluate the segmentation network on every test sample. Per-class volume
# statistics from `segment_volume` are summed over the whole dataset, so the
# final metrics are pooled over all samples rather than averaged per sample.
METRIC_IDXS = [1]   # label values to score
SHOW_PLOTS = False  # True: show image / label / prediction overlays per sample

dice_volume_meter = np.zeros([len(METRIC_IDXS), 4])
with torch.no_grad():
    for idx in tqdm(range(len(dataset)), desc='evaluating'):
        sample = dataset[idx]
        img = sample['image'].to(device)
        # The label is only consumed by numpy, so keep it on the CPU instead of
        # copying it to the GPU and immediately back.
        mask = sample['label']
        mask_arr = mask.cpu().numpy().astype('uint8')

        logit = segment_net(img)
        prob = torch.softmax(logit, dim=1)
        # Class prediction per element for batch item 0 (class axis is dim 0
        # after dropping the batch dimension).
        pred = torch.argmax(prob[0], 0).cpu().numpy()
        dice_volume_meter += segment_volume(pred, mask_arr, METRIC_IDXS)

        if SHOW_PLOTS:
            # Only pay for the GPU->CPU image copy when actually plotting.
            img_arr = img.cpu().numpy()[0, 0]
            overlays = {
                'image': None,
                'label': np.ma.masked_where(mask_arr == 0, mask_arr),
                'prediction': np.ma.masked_where(pred == 0, pred),
            }
            fig, axes = plt.subplots(1, 3, figsize=(9, 9))
            for ax, (title, overlay) in zip(axes, overlays.items()):
                ax.imshow(img_arr, cmap='gray', interpolation='none')
                if overlay is not None:
                    ax.imshow(overlay,
                              cmap='cool',
                              interpolation='none',
                              alpha=0.7)
                ax.set_title(title)
                ax.axis('off')
            rel_path = '/'.join(
                dataset.records.iloc[idx].base_ct_path.split('/')[-3:])
            fig.suptitle('{}: {}'.format(idx, rel_path))
            plt.show()
            plt.close(fig)  # avoid accumulating open figures inside the loop
            print(segment_metrics(dice_volume_meter))

dice_metrics = segment_metrics(dice_volume_meter)
dice_metrics