In [1]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
from utils.data_loading import BasicDataset
from utils.dice_score import multiclass_dice_coeff, dice_coeff
from torch.utils.data import DataLoader, random_split
from unet import UNet
import os

In [2]:
def evaluate(net, dataloader, device):
    net.eval()
    num_val_batches = len(dataloader)
    dice_score = 0

    # iterate over the validation set
    for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False):
        image, mask_true = batch['image'], batch['mask']
        # move images and labels to correct device and type
        image = image.to(device=device, dtype=torch.float32)
        mask_true = mask_true.to(device=device, dtype=torch.long)

        mask_true = F.one_hot(mask_true, net.n_classes).permute(0,3, 1, 2).float()
        with torch.no_grad():
            # predict the mask
            mask_pred = net(image)
            # convert to one-hot format
            if net.n_classes == 1:
                mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
                # compute the Dice score
                dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False)
            else:
                mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0,3, 1, 2).float()
                # compute the Dice score, ignoring background
                dice_score += multiclass_dice_coeff(mask_pred[:, 1:, ...], mask_true[:, 1:, ...], reduce_batch_first=False)

           

    net.train()

    # Fixes a potential division by zero error
    if num_val_batches == 0:
        return dice_score
    return dice_score / num_val_batches

In [3]:
def evaluateloss(net, dataloader, device):
    net.eval()
    num_val_batches = len(dataloader)
    dice_score = 0

    # iterate over the validation set
    for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False):
        image, mask_true = batch['image'], batch['mask']
        # move images and labels to correct device and type
        image = image.to(device=device, dtype=torch.float32)
        mask_true = mask_true.to(device=device, dtype=torch.long)
        mask_true = F.one_hot(mask_true, net.n_classes).permute(0,3, 1, 2).float()
        with torch.no_grad():
            # predict the mask
            mask_pred = net(image)
            # convert to one-hot format
            if net.n_classes == 1:
                mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
                # compute the Dice score
                dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False)
            else:
                mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0,3, 1, 2).float()
                # compute the Dice score, ignoring background
                dice_score += multiclass_dice_coeff(mask_pred[:, 1:, ...], mask_true[:, 1:, ...], reduce_batch_first=False)

           

    net.train()

    # Fixes a potential division by zero error
    if num_val_batches == 0:
        return dice_score
    return dice_score / num_val_batches

In [4]:
checkpointfol = '.\\checkpoints\\Base-UNET-Focal1LossRotate\\'
dir_img = '.\\Dataset\\Validation\\Post\\Image512\\'
dir_mask = '.\\Dataset\\Validation\\Post\\Label512\\'
img_scale = 1
classes = 5
bilinear = False
device = 'cuda'

In [5]:
dataset = BasicDataset(dir_img, dir_mask, img_scale,mask_suffix = '.png')

In [6]:
val_loader = DataLoader(dataset, shuffle=True,batch_size=1, num_workers=1, pin_memory=True)

In [7]:
ch_list = []
result = []
for filename in os.listdir(checkpointfol):
    if filename.endswith('.pth'):
        ch_list.append(checkpointfol+filename)
for i in range(len(ch_list)):
    print('loading' + ch_list[i])
    net = UNet(n_channels=3, n_classes = classes, bilinear=bilinear)
    net.load_state_dict(torch.load(ch_list[i], map_location=device))
    net.to(device=device)
    result.append(evaluate(net,val_loader,device))
    print(result[i])

loading.\checkpoints\Base-UNET-Focal1LossRotate\checkpoint_epoch1.pth


                                                                                                                       

tensor(0.4561, device='cuda:0')
loading.\checkpoints\Base-UNET-Focal1LossRotate\checkpoint_epoch10.pth


                                                                                                                       

tensor(0.5987, device='cuda:0')
loading.\checkpoints\Base-UNET-Focal1LossRotate\checkpoint_epoch11.pth


                                                                                                                       

tensor(0.5905, device='cuda:0')
loading.\checkpoints\Base-UNET-Focal1LossRotate\checkpoint_epoch12.pth


                                                                                                                       

tensor(0.5983, device='cuda:0')
loading.\checkpoints\Base-UNET-Focal1LossRotate\checkpoint_epoch13.pth


                                                                                                                       

tensor(0.6001, device='cuda:0')
loading.\checkpoints\Base-UNET-Focal1LossRotate\checkpoint_epoch14.pth


                                                                                                                       

tensor(0.5991, device='cuda:0')
loading.\checkpoints\Base-UNET-Focal1LossRotate\checkpoint_epoch15.pth


                                                                                                                       

tensor(0.5986, device='cuda:0')
loading.\checkpoints\Base-UNET-Focal1LossRotate\checkpoint_epoch16.pth


                                                                                                                       

tensor(0.5992, device='cuda:0')
loading.\checkpoints\Base-UNET-Focal1LossRotate\checkpoint_epoch17.pth


                                                                                                                       

tensor(0.6024, device='cuda:0')
loading.\checkpoints\Base-UNET-Focal1LossRotate\checkpoint_epoch18.pth


                                                                                                                       

tensor(0.6067, device='cuda:0')
loading.\checkpoints\Base-UNET-Focal1LossRotate\checkpoint_epoch19.pth


                                                                                                                       

tensor(0.5971, device='cuda:0')
loading.\checkpoints\Base-UNET-Focal1LossRotate\checkpoint_epoch2.pth


                                                                                                                       

tensor(0.5922, device='cuda:0')
loading.\checkpoints\Base-UNET-Focal1LossRotate\checkpoint_epoch20.pth


                                                                                                                       

tensor(0.5978, device='cuda:0')
loading.\checkpoints\Base-UNET-Focal1LossRotate\checkpoint_epoch21.pth


                                                                                                                       

tensor(0.6014, device='cuda:0')
loading.\checkpoints\Base-UNET-Focal1LossRotate\checkpoint_epoch22.pth


                                                                                                                       

tensor(0.5986, device='cuda:0')
loading.\checkpoints\Base-UNET-Focal1LossRotate\checkpoint_epoch3.pth


                                                                                                                       

tensor(0.5831, device='cuda:0')
loading.\checkpoints\Base-UNET-Focal1LossRotate\checkpoint_epoch4.pth


                                                                                                                       

tensor(0.5942, device='cuda:0')
loading.\checkpoints\Base-UNET-Focal1LossRotate\checkpoint_epoch5.pth


                                                                                                                       

tensor(0.5932, device='cuda:0')
loading.\checkpoints\Base-UNET-Focal1LossRotate\checkpoint_epoch6.pth


                                                                                                                       

tensor(0.5952, device='cuda:0')
loading.\checkpoints\Base-UNET-Focal1LossRotate\checkpoint_epoch7.pth


                                                                                                                       

tensor(0.5960, device='cuda:0')
loading.\checkpoints\Base-UNET-Focal1LossRotate\checkpoint_epoch8.pth


                                                                                                                       

tensor(0.5967, device='cuda:0')
loading.\checkpoints\Base-UNET-Focal1LossRotate\checkpoint_epoch9.pth


                                                                                                                       

tensor(0.5969, device='cuda:0')




In [8]:
result

[tensor(0.4561, device='cuda:0'),
 tensor(0.5987, device='cuda:0'),
 tensor(0.5905, device='cuda:0'),
 tensor(0.5983, device='cuda:0'),
 tensor(0.6001, device='cuda:0'),
 tensor(0.5991, device='cuda:0'),
 tensor(0.5986, device='cuda:0'),
 tensor(0.5992, device='cuda:0'),
 tensor(0.6024, device='cuda:0'),
 tensor(0.6067, device='cuda:0'),
 tensor(0.5971, device='cuda:0'),
 tensor(0.5922, device='cuda:0'),
 tensor(0.5978, device='cuda:0'),
 tensor(0.6014, device='cuda:0'),
 tensor(0.5986, device='cuda:0'),
 tensor(0.5831, device='cuda:0'),
 tensor(0.5942, device='cuda:0'),
 tensor(0.5932, device='cuda:0'),
 tensor(0.5952, device='cuda:0'),
 tensor(0.5960, device='cuda:0'),
 tensor(0.5967, device='cuda:0'),
 tensor(0.5969, device='cuda:0')]