In [1]:
from model import UNet
from makedataset import makeDataset
import torch
from torch.utils.data import DataLoader
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import os
from imageio import imwrite
from PIL import Image
import config
from tqdm.auto import tqdm
from eff_unet import EffUNet

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class DiceScore(nn.Module):
    """Hard Dice coefficient for the foreground channel of a 2-class output.

    Logits are softmax-normalised over dim=1, foreground channel 1 is
    binarised at > 0.5, and the result is compared against channel 1 of the
    target tensor over the whole batch. ``weight`` and ``size_average`` are
    accepted for interface compatibility but are not used.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()
        self.normalization = nn.Softmax(dim=1)

    def forward(self, inputs, targets, smooth=1e-4):
        """Return the scalar (0-dim tensor) Dice score.

        Args:
            inputs: raw logits laid out as (N, C, ...) with C >= 2.
            targets: tensor with the same layout; channel 1 is taken as the
                foreground mask (presumably one-hot 0/1 — TODO confirm).
            smooth: additive constant keeping the ratio defined (== 1) when
                both prediction and target are empty.
        """
        fg_target = targets[:, 1:2, ...].flatten()

        probabilities = self.normalization(inputs)
        fg_hit = probabilities[:, 1:2, ...] > 0.5
        fg_pred = torch.where(fg_hit, 1.0, 0.0).flatten()

        overlap = (fg_pred * fg_target).sum()
        total = fg_pred.sum() + fg_target.sum()
        return (2. * overlap + smooth) / (total + smooth)

In [3]:
# Overlay colours (R, G, B): ground truth in red, prediction in blue,
# blended over the image with weight ALPHA.
DEFAULT_KIDNEY_COLOR = [255, 0, 0]
DEFAULT_PRED_COLOR = [0, 0, 255]
ALPHA = 0.3

# Directory that receives the per-slice overlay PNGs.
os.makedirs('pred_img3', exist_ok=True)

# Metric instance used to score each validation slice.
dicescore = DiceScore()

In [4]:
from monai.networks.nets.segresnet import SegResNet
# Architectures tried previously (kept for reference):
#   UNet(64, 5, use_xavier=True, use_batchNorm=True, dropout=0.5, retain_size=True, nbCls=2)
#   EffUNet(1, 5, use_xavier=True, use_batchNorm=True, dropout=0.5, retain_size=True, nbCls=2)
#   SegResNet(spatial_dims=2, init_filters=16, in_channels=1, out_channels=2, dropout_prob=0.2)
unet = EffUNet(1, 5, use_xavier=True, use_batchNorm=True, dropout=0.2, retain_size=True, nbCls=2)

cuda_ok = torch.cuda.is_available()
# NOTE(review): `devices` and `device_num` are not read in the cells shown here.
devices = 'gpu' if cuda_ok else 'cpu'
device_num = torch.cuda.device_count() if cuda_ok else 0

# Wrap before loading: the checkpoint's state_dict keys are expected to match
# the DataParallel-wrapped model.
unet = torch.nn.DataParallel(unet)
unet.to(config.DEVICE)

if cuda_ok:
    print('CUDA Available!')
    map_location = None  # torch.load default: restore tensors where they were saved
else:
    print('CUDA is unavailable, using CPU instead!')
    print('Warning: using CPU might require several hours')
    map_location = torch.device('cpu')

# torch.load unpickles the file; only load trusted checkpoints.
unet.load_state_dict(torch.load('./final_result13/unet_40.pt', map_location=map_location))

2024-02-06 16:12:05.805021: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


CUDA Available!


In [5]:
# Evaluate every validation slice: accumulate the hard Dice score and save a
# colour overlay named "<index>_<dice>%.png" into ./pred_img3/.
IMG_SIZE = 512        # each slice is reshaped to IMG_SIZE x IMG_SIZE
VIS_THRESHOLD = 0.7   # sigmoid threshold for the mask drawn in the overlay


def make_overlay(img_np, seg_mask, pred_mask, alpha):
    """Blend ground-truth and predicted masks over a grayscale slice.

    Args:
        img_np: grayscale image of shape (1, H, W); multiplied by 255, so
            presumably scaled to [0, 1] (TODO confirm).
        seg_mask: ground-truth mask of shape (1, H, W); pixels equal to 1 are
            painted DEFAULT_KIDNEY_COLOR.
        pred_mask: binary prediction mask of shape (1, H, W); pixels equal to
            1 are painted DEFAULT_PRED_COLOR (overriding the GT colour).
        alpha: blend weight of the mask colour against the image.

    Returns:
        uint8 array of shape (1, H, W, 3). Pixels where either mask is > 0 are
        alpha-blended; all other pixels are the rounded grayscale image.
    """
    img = 255 * img_np
    img = np.stack((img, img, img), axis=-1)

    seg_color = np.zeros(seg_mask.shape + (3,), dtype=np.float32)
    seg_color[np.equal(seg_mask, 1)] = DEFAULT_KIDNEY_COLOR
    seg_color[np.equal(pred_mask, 1)] = DEFAULT_PRED_COLOR

    covered = np.logical_or(np.greater(seg_mask, 0), np.greater(pred_mask, 0))
    covered_rgb = np.stack((covered, covered, covered), axis=-1)
    return np.where(
        covered_rgb,
        np.round(alpha * seg_color + (1 - alpha) * img).astype(np.uint8),
        np.round(img).astype(np.uint8),
    )


total_dice = 0.0
root_dir = os.path.join('.', 'kits19', 'valid')
softmax = nn.Softmax(dim=1)
# The dataset object only supplies the slice count; slices are read directly
# from root_dir by zero-padded index below.
test_dataset = makeDataset(kind='valid', location='kits19')
len_test = len(test_dataset)

unet.eval()  # set once instead of on every iteration
with torch.no_grad():  # inference only: do not build autograd graphs
    for iters in tqdm(range(len_test)):
        cid_name = '{:05d}.npy'.format(iters)
        img_np = np.load(os.path.join(root_dir, 'image', cid_name))
        # Indexed as seg_np[1, ...] below, so presumably stored as
        # (2, H, W) with channel 1 = foreground (TODO confirm).
        seg_np = np.load(os.path.join(root_dir, 'segmentation', cid_name))

        model_in = torch.tensor(img_np.reshape((1, 1, IMG_SIZE, IMG_SIZE)), dtype=torch.float32)
        pred = unet(model_in).cpu()
        target = torch.tensor(seg_np.reshape((1, 2, IMG_SIZE, IMG_SIZE)))
        dice = dicescore(pred.clone(), target).item()
        total_dice += dice

        # NOTE(review): the Dice above binarises softmax(pred)[:, 1] at 0.5,
        # but the overlay binarises sigmoid(pred)[:, 1] at VIS_THRESHOLD, so
        # the drawn mask can differ from the scored one. Confirm intent.
        pred_mask = np.where(torch.sigmoid(pred)[:, 1, ...].numpy() > VIS_THRESHOLD, 1, 0)

        overlayed = make_overlay(
            img_np.reshape((1, IMG_SIZE, IMG_SIZE)),
            seg_np[1, ...].reshape((1, IMG_SIZE, IMG_SIZE)),
            pred_mask,
            ALPHA,
        )
        out_name = '{:05d}_{:.2f}%.png'.format(iters, dice * 100)
        imwrite(os.path.join('.', 'pred_img3', out_name), overlayed[0])

# Average over the actual number of slices (was a hard-coded 7922).
print('Image Generated Finished, Average F1 Score: {:.3f}%'.format((total_dice / max(len_test, 1)) * 100))

 14%|█▍        | 1133/7922 [01:08<06:53, 16.43it/s]


KeyboardInterrupt: 