In [2]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
from utils.data_loading import BasicDataset
from utils.dice_score import multiclass_dice_coeff, dice_coeff
from torch.utils.data import DataLoader, random_split
from unet import UNet
import os
import segmentation_models_pytorch as smp

In [20]:
def evaluate(net, dataloader, device):
    net.eval()
    num_val_batches = len(dataloader)
    dice_score = 0

    # iterate over the validation set
    for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False):
        image, mask_true = batch['image'], batch['mask']
        # move images and labels to correct device and type
        image = image.to(device=device, dtype=torch.float32)
        mask_true = mask_true.to(device=device, dtype=torch.long)

        mask_true = F.one_hot(mask_true, 5).permute(0,3, 1, 2).float()
        with torch.no_grad():
            # predict the mask
            mask_pred = net(image)
            # convert to one-hot format
            if 5 == 1:
                mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
                # compute the Dice score
                dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False)
            else:
                mask_pred = F.one_hot(mask_pred.argmax(dim=1), 5).permute(0,3, 1, 2).float()
                # compute the Dice score, ignoring background
                dice_score += multiclass_dice_coeff(mask_pred[:, 1:, ...], mask_true[:, 1:, ...], reduce_batch_first=False)

           

    net.train()

    # Fixes a potential division by zero error
    if num_val_batches == 0:
        return dice_score
    return dice_score / num_val_batches

In [21]:
def evaluateloss(net, dataloader, device):
    net.eval()
    num_val_batches = len(dataloader)
    dice_score = 0

    # iterate over the validation set
    for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False):
        image, mask_true = batch['image'], batch['mask']
        # move images and labels to correct device and type
        image = image.to(device=device, dtype=torch.float32)
        mask_true = mask_true.to(device=device, dtype=torch.long)
        mask_true = F.one_hot(mask_true, 5).permute(0,3, 1, 2).float()
        with torch.no_grad():
            # predict the mask
            mask_pred = net(image)
            # convert to one-hot format
            if net.n_classes == 1:
                mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
                # compute the Dice score
                dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False)
            else:
                mask_pred = F.one_hot(mask_pred.argmax(dim=1), 5).permute(0,3, 1, 2).float()
                # compute the Dice score, ignoring background
                dice_score += multiclass_dice_coeff(mask_pred[:, 1:, ...], mask_true[:, 1:, ...], reduce_batch_first=False)

           

    net.train()

    # Fixes a potential division by zero error
    if num_val_batches == 0:
        return dice_score
    return dice_score / num_val_batches

In [22]:
checkpointfol = '.\\checkpoints\\Resnet50higherlr\\'
dir_img = '.\\Dataset\\Validation\\Post\\Image512\\'
dir_mask = '.\\Dataset\\Validation\\Post\\Label512\\'
img_scale = 1
classes = 5
bilinear = False
device = 'cuda'

In [23]:
dataset = BasicDataset(dir_img, dir_mask, img_scale,mask_suffix = '.png')

In [24]:
val_loader = DataLoader(dataset, shuffle=True,batch_size=1, num_workers=1, pin_memory=True)

In [25]:
ch_list = []
result = []
for filename in os.listdir(checkpointfol):
    if filename.endswith('.pth'):
        ch_list.append(checkpointfol+filename)
for i in range(len(ch_list)):
    print('loading' + ch_list[i])
    net = smp.Unet(
        encoder_name='resnet50',
        encoder_depth=5,
        encoder_weights='imagenet',
        decoder_use_batchnorm=False,
        decoder_channels=(1024,512,256, 128, 64),
        decoder_attention_type=None,
        in_channels=3,
        classes=5,
        activation=None,
        aux_params=None
    )
    net.load_state_dict(torch.load(ch_list[i], map_location=device))
    net.to(device=device)
    result.append(evaluate(net,val_loader,device))
    print(result[i])

loading.\checkpoints\Resnet50higherlr\checkpoint_epoch1.pth


                                                                                                                       

tensor(0.1856, device='cuda:0')
loading.\checkpoints\Resnet50higherlr\checkpoint_epoch10.pth


                                                                                                                       

tensor(0.5278, device='cuda:0')
loading.\checkpoints\Resnet50higherlr\checkpoint_epoch11.pth


                                                                                                                       

tensor(0.5700, device='cuda:0')
loading.\checkpoints\Resnet50higherlr\checkpoint_epoch12.pth


                                                                                                                       

tensor(0.5754, device='cuda:0')
loading.\checkpoints\Resnet50higherlr\checkpoint_epoch13.pth


                                                                                                                       

tensor(0.5901, device='cuda:0')
loading.\checkpoints\Resnet50higherlr\checkpoint_epoch14.pth


                                                                                                                       

tensor(0.5913, device='cuda:0')
loading.\checkpoints\Resnet50higherlr\checkpoint_epoch15.pth


                                                                                                                       

tensor(0.5989, device='cuda:0')
loading.\checkpoints\Resnet50higherlr\checkpoint_epoch16.pth


                                                                                                                       

tensor(0.5963, device='cuda:0')
loading.\checkpoints\Resnet50higherlr\checkpoint_epoch17.pth


                                                                                                                       

tensor(0.5931, device='cuda:0')
loading.\checkpoints\Resnet50higherlr\checkpoint_epoch18.pth


                                                                                                                       

tensor(0.5969, device='cuda:0')
loading.\checkpoints\Resnet50higherlr\checkpoint_epoch19.pth


                                                                                                                       

tensor(0.5885, device='cuda:0')
loading.\checkpoints\Resnet50higherlr\checkpoint_epoch2.pth


                                                                                                                       

tensor(0.1836, device='cuda:0')
loading.\checkpoints\Resnet50higherlr\checkpoint_epoch20.pth


                                                                                                                       

tensor(0.5976, device='cuda:0')
loading.\checkpoints\Resnet50higherlr\checkpoint_epoch3.pth


                                                                                                                       

tensor(0.1773, device='cuda:0')
loading.\checkpoints\Resnet50higherlr\checkpoint_epoch4.pth


                                                                                                                       

tensor(0.1974, device='cuda:0')
loading.\checkpoints\Resnet50higherlr\checkpoint_epoch5.pth


                                                                                                                       

tensor(0.2291, device='cuda:0')
loading.\checkpoints\Resnet50higherlr\checkpoint_epoch6.pth


                                                                                                                       

tensor(0.2782, device='cuda:0')
loading.\checkpoints\Resnet50higherlr\checkpoint_epoch7.pth


                                                                                                                       

tensor(0.3189, device='cuda:0')
loading.\checkpoints\Resnet50higherlr\checkpoint_epoch8.pth


                                                                                                                       

tensor(0.4159, device='cuda:0')
loading.\checkpoints\Resnet50higherlr\checkpoint_epoch9.pth


                                                                                                                       

tensor(0.4922, device='cuda:0')




In [None]:
result