In [1]:
import matplotlib.pyplot as plt
from ctypes import *
from typing import List
from skimage import io
import numpy as np
import xir
import vart
import vitis_ai_library
import os
import math

import time
import sys

from functools import partial

from unet.dataset import ImageToImage2D, JointTransform2D
from unet.metrics import jaccard_index, f1_score, LogNLLLoss
from unet.utils import MetricList, Logger

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import functional as F

In [2]:
# Directory of the evaluation data; handed to ImageToImage2D when the dataset is built.
dataset_dir = "./images_folder"
# Presumably the output folder of the compiled model; not referenced by the cells shown here.
results_path = "./build_test/comp_model/"
# Compiled xmodel file; val_epoch passes this path to xir.Graph.deserialize.
model_path = "./build_test/comp_model/UNet2D_compiled.xmodel"

In [9]:
def val_epoch(model, dataset, n_batch=1, metric_list=None, loss_function=None):
    """
    Run one validation pass over ``dataset`` with a compiled xmodel.

    The graph stored at ``model`` is deserialized with ``xir.Graph.deserialize``
    and executed through a ``vitis_ai_library`` graph runner, one sample per
    inference call.

    Args:
        model: path of the serialized graph, passed to xir.Graph.deserialize
        dataset: dataset whose items start with (image, mask); any further
            item fields are ignored. It is iterated through a torch DataLoader.
        n_batch: accepted for interface compatibility but currently unused;
            the loader below always uses a batch size of 1.
        metric_list: object providing reset(), __call__(prediction, target)
            and get_results(normalize=...). A fresh MetricList({}) is created
            when omitted.
        loss_function: callable(prediction, target) whose result exposes
            .item(). A fresh LogNLLLoss() is created when omitted.

    Returns:
        logs: dictionary containing 'val_loss' (loss averaged over the number
            of batches) plus whatever metric_list.get_results returns when
            normalized by the number of batches.

    Raises:
        ValueError: if the dataset yields no batches.
    """
    # Build the defaults per call: default expressions in the signature are
    # evaluated only once, so every call would otherwise share one instance.
    if metric_list is None:
        metric_list = MetricList({})
    if loss_function is None:
        loss_function = LogNLLLoss()

    metric_list.reset()
    running_val_loss = 0.0
    n_batches = 0

    graph = xir.Graph.deserialize(model)
    runner = vitis_ai_library.GraphRunner.create_graph_runner(graph)
    # Input tensor descriptors supply the shapes the runner expects; the output
    # buffers are reused for every inference call.
    input_tensors = runner.get_input_tensors()
    output_tensor_buffers = runner.get_outputs()

    # NOTE(review): batch size is fixed at 1 regardless of n_batch, presumably
    # because each sample is reshaped to the runner's own input dims - confirm.
    for X_batch, y_batch, *_ in DataLoader(dataset, batch_size=1):
        # NOTE(review): assumes the dataset yields pixel values in [0, 255]; if
        # the transform already scales to [0, 1] this division is applied twice
        # - TODO confirm against unet.dataset.
        input_image = np.asarray(X_batch, dtype=np.float32) / 255

        # Feed the same image to every input tensor, reshaped to that tensor's dims.
        input_data = [input_image.reshape(tensor.dims) for tensor in input_tensors]
        job_id = runner.execute_async(input_data, output_tensor_buffers)
        runner.wait(job_id)

        # NOTE(review): converting to uint8 drops any sign or fractional part of
        # the raw output - verify the runner's output dtype and scaling.
        output = np.array(output_tensor_buffers[0], np.uint8)
        # Move the last axis to position 1 (presumably NHWC -> NCHW) so the
        # prediction layout matches what the loss and metrics are given.
        mask = np.transpose(output, (0, 3, 1, 2)).astype(np.float32)
        y_out = torch.from_numpy(mask)

        val_loss = loss_function(y_out, y_batch)
        running_val_loss += val_loss.item()
        metric_list(y_out, y_batch)
        n_batches += 1

    if n_batches == 0:
        # Without this guard the averaging below would fail on an empty dataset.
        raise ValueError("val_epoch received an empty dataset: nothing to validate")

    logs = {'val_loss': running_val_loss / n_batches,
            **metric_list.get_results(normalize=n_batches)}

    return logs

In [10]:
# Joint image/mask transform used for evaluation.
# NOTE(review): p_flip=0.5 presumably enables random flipping, which would make
# repeated evaluations non-deterministic - confirm this is intended.
tf_val = JointTransform2D(
    crop=(256, 256),
    p_flip=0.5,
    color_jitter_params=None,
    long_mask=True,
)
predict_dataset = ImageToImage2D(dataset_dir, tf_val)

# The logger is created here but not used by the cells shown in this notebook.
logger = Logger(verbose=True)

# Metrics recorded during validation, keyed by the name reported in the logs.
metric_list = MetricList(dict(
    jaccard=partial(jaccard_index),
    f1=partial(f1_score),
))

In [12]:
# Evaluate the compiled model on predict_dataset and print the returned loss/metric dictionary.
logs = val_epoch(model=model_path, dataset=predict_dataset, metric_list=metric_list)
print(logs)

{'val_loss': 1.7047607799074542, 'jaccard': 0.019970975973323656, 'f1': 0.4682012813304787}
