In [3]:
from model import UNet
from dataset2d import Dataset2D
import torch
from torch.utils.data import DataLoader
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import os
from imageio import imwrite
from PIL import Image
import model_hyper_parameters as config
from tqdm.auto import tqdm

In [4]:
class DiceScore(nn.Module):
    """Dice coefficient between a thresholded prediction and a reference mask.

    Only channel index 1 (along dim 1) of both tensors takes part in the
    score. The prediction is soft-maxed over dim 1 and binarised at 0.5
    before the overlap is measured, so the result carries no gradient
    with respect to ``inputs``.
    """

    def __init__(self, weight=None, size_average=True):
        """Create the scorer.

        ``weight`` and ``size_average`` are accepted but never read by
        this class.
        """
        super().__init__()
        self.normalization = nn.Softmax(dim=1)

    def forward(self, inputs, targets, smooth=1e-4):
        """Return the Dice score as a 0-dim tensor.

        Args:
            inputs: raw network outputs; soft-maxed over dim 1 here.
            targets: reference tensor whose dim 1 has at least two
                channels; channel 1 is used as the reference mask.
            smooth: constant added to numerator and denominator so that
                two empty masks score 1 instead of dividing by zero.
        """
        probabilities = self.normalization(inputs)

        # Hard 0/1 mask for channel 1 of the prediction, flattened.
        predicted_mask = torch.where(probabilities[:, 1:2, ...] > 0.5, 1.0, 0.0).reshape(-1)
        # Matching channel of the reference, flattened the same way.
        reference_mask = targets[:, 1:2, ...].reshape(-1)

        overlap = (predicted_mask * reference_mask).sum()
        mask_total = predicted_mask.sum() + reference_mask.sum()

        return (2. * overlap + smooth) / (mask_total + smooth)

In [5]:
# Folder that receives the rendered overlay images; created up front so the
# later image writes have somewhere to land.
os.makedirs('pred_img', exist_ok=True)

# Three-component colours painted onto the overlay (presumably RGB order --
# confirm against the image writer): one for the reference segmentation,
# one for the network prediction.
DEFAULT_KIDNEY_COLOR = [255, 0, 0]
DEFAULT_PRED_COLOR = [0, 0, 255]

# Weight of the colour layer when it is blended with the grayscale image.
ALPHA = 0.3

# Shared scorer instance reused for every evaluated slice.
dicescore = DiceScore()

In [None]:
# Build the network and wrap it in DataParallel *before* restoring weights,
# so load_state_dict sees the wrapped module's parameter names.
# NOTE(review): presumably the checkpoint was saved from a DataParallel
# model -- confirm.
unet = torch.nn.DataParallel(
    UNet(64, 5, use_xavier=True, use_batchNorm=True, dropout=0.5, retain_size=True, nbCls=2)
)
unet.to(config.DEVICE)

weights_path = './final_result/unet.pt'
if torch.cuda.is_available():
    # `devices` / `device_num` are not read by any cell visible here; they
    # are kept in case other cells rely on them.
    devices = 'gpu'
    device_num = torch.cuda.device_count()
    print('CUDA Available!')
    unet.load_state_dict(torch.load(weights_path))
else:
    devices = 'cpu'
    device_num = 0
    print('CUDA is unavailable, using CPU instead!')
    print('Warning: using CPU might require several hours')
    # Remap storages to the CPU while loading, since no GPU is present.
    unet.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))

In [None]:
# Evaluate the network slice by slice on the validation split, accumulate the
# Dice score, and write one colour overlay PNG per slice into ./pred_img/.
NUM_SLICES=7922  # files are addressed as 00000.npy ... 07921.npy
IMG_SIZE=512     # every slice is reshaped to IMG_SIZE x IMG_SIZE

total_dice=0.0
root_dir=os.path.join('.','data_npy','valid')
softmax=nn.Softmax(dim=1)

# Switching to eval mode is loop-invariant, so do it once.
unet.eval()

# Inference only: without no_grad every forward pass would build and keep an
# autograd graph that is never used.
with torch.no_grad():
    for iters in tqdm(range(NUM_SLICES)):
        cid_name='{:05d}.npy'.format(iters)
        img_np=np.load(os.path.join(root_dir,'image',cid_name))
        seg_np=np.load(os.path.join(root_dir,'segmentation',cid_name))

        logits=unet(torch.tensor(img_np.reshape((1,1,IMG_SIZE,IMG_SIZE)),dtype=torch.float32)).cpu()
        dice=dicescore(logits,torch.tensor(seg_np.reshape((1,2,IMG_SIZE,IMG_SIZE))))
        # Accumulate as a Python float rather than a float32 tensor.
        total_dice+=dice.item()

        # Binary prediction mask for channel 1, shape (1, IMG_SIZE, IMG_SIZE).
        pred=np.where(softmax(logits)[:,1,...].numpy()>0.5,1,0)

        img_np=img_np.reshape((1,IMG_SIZE,IMG_SIZE))
        seg_np=seg_np[1,...].reshape((1,IMG_SIZE,IMG_SIZE))

        # Grayscale image replicated to three channels.
        # NOTE(review): assumes the stored image is scaled to [0, 1] -- confirm.
        img=255*img_np
        img=np.stack((img,img,img),axis=-1)

        # Colour layer: reference mask first, then prediction, so the
        # prediction colour wins wherever the two overlap.
        seg_color=np.zeros(seg_np.shape+(3,),dtype=np.float32)
        seg_color[seg_np==1]=DEFAULT_KIDNEY_COLOR
        seg_color[pred==1]=DEFAULT_PRED_COLOR

        # Blend only where either mask is set; the trailing axis broadcasts
        # the 2-D mask over the three colour channels.
        highlighted=np.logical_or(seg_np>0,pred>0)[...,np.newaxis]
        overlayed=np.where(
            highlighted,
            np.round(ALPHA*seg_color+(1-ALPHA)*img).astype(np.uint8),
            np.round(img).astype(np.uint8)
        )
        imwrite('./pred_img/'+'{:05d}_{:.2f}%.png'.format(iters,dice*100),overlayed[0])
print('Image Generated Finished, Average F1 Score: {:.3f}%'.format((total_dice/NUM_SLICES)*100))