In [None]:
import json
import pandas as pd
import copy
import glob
import cv2
import matplotlib.pyplot as plt
import sys
import os
import numpy as np
sys.path.append(os.path.join('./','../pyunet'))
from lib.unet import UNet
from modules.train import Train
import torch

In [None]:
# Configuration for the UNet training run, followed by training itself.
# All tunable values live here so the run can be reproduced from this cell.
img_height     = 32
img_width      = 32
# Fall back to CPU so Restart & Run All works on machines without a GPU.
device         = 'cuda' if torch.cuda.is_available() else 'cpu'
gpu_index      = 0
input_img_dir  = "./images/covid19ctscan/small/images/"
input_mask_dir = "./images/covid19ctscan/small/masks/"
# NOTE(review): presumably the path Train saves the trained weights to — TODO confirm.
model_file     = "test.pth"
epochs         = 100
learning_rate  = 0.001
in_channels    = 3
# NOTE(review): presumably the number of segmentation classes — TODO confirm.
out_channels   = 4
batch_size     = 2
# NOTE(review): 'CE' presumably selects cross-entropy loss inside Train — TODO confirm.
loss_type      = 'CE'
# NOTE(review): the three flags below are NOT placed in `params`, so they
# currently have no effect on training. TODO confirm the keys Train expects
# and add them to `params` (or delete these variables).
is_normalized  = True
is_residual    = True
double_skip    = True
# Seed the RNGs before the model is built so weight initialisation (and any
# torch/numpy-driven shuffling) is reproducible across runs.
seed           = 42

torch.manual_seed(seed)
np.random.seed(seed)

params = {
    'img_height':     img_height,
    'img_width':      img_width,
    'device':         device,
    'gpu_index':      gpu_index,
    'input_img_dir':  input_img_dir,
    'input_mask_dir': input_mask_dir,
    'epochs':         epochs,
    'learning_rate':  learning_rate,
    'in_channels':    in_channels,
    'out_channels':   out_channels,
    'loss_type':      loss_type,
    'batch_size':     batch_size,
    'model_file':     model_file
}

cmd = Train(params=params)

cmd.execute()

# The trained network is read back from the Train command for the plots below.
model = cmd.model

In [None]:
# Plot every input image next to its ground-truth mask and the model's
# prediction: one row per image, columns Original / Ground Truth / Prediction.
from lib.utils import get_image, get_mask, get_predicted_img

# Images and masks are paired by sorted filename. This assumes both
# directories use matching, identically-sortable names.
images = sorted(glob.glob(os.path.join(input_img_dir, "*")))
masks  = sorted(glob.glob(os.path.join(input_mask_dir, "*")))

if not images:
    raise FileNotFoundError(f"No images found in {input_img_dir!r}")
if len(images) != len(masks):
    raise ValueError(
        f"Found {len(images)} images but {len(masks)} masks; "
        "every image needs exactly one mask"
    )

dim = (img_width, img_height)

col_names = [
    "Original",
    "Ground Truth",
    "Prediction"
]

num_images = len(images)
num_cols   = len(col_names)

# squeeze=False keeps `axes` 2-D even for a single image, so axes[row]
# always yields one Axes per column.
fig, axes = plt.subplots(
    nrows=num_images,
    ncols=num_cols,
    figsize=(num_cols * 4, num_images * 4),
    squeeze=False,
)

for ax, col in zip(axes[0], col_names):
    ax.set_title(col)

for row, (image_file, mask_file) in enumerate(zip(images, masks)):
    img  = get_image(image_file, dim)
    mask = get_mask(mask_file, dim)

    prediction = get_predicted_img(img, model)

    # Draw on the axes created above instead of re-creating them via plt.subplot.
    for ax, panel in zip(axes[row], (img, mask, prediction)):
        ax.imshow(panel)

plt.show()