In [1]:
%matplotlib inline
import io

import torch
import numpy as np
from ipywidgets import interact
from ipywidgets.widgets import Dropdown, FileUpload, Output
from matplotlib import pyplot as plt
from segmentation_models_pytorch import Unet, Linknet, FPN, PSPNet
from torchvision.transforms import functional as tf
from PIL import Image
# Switch matplotlib to interactive mode for the rest of the notebook.
plt.ion()

<matplotlib.pyplot._IonContext at 0x7f78d84f3750>

In [2]:
def preprocess(image):
    """Turn a raw image array into a standardised single-channel model input.

    The array is reduced to one 2-D channel, resized to 512x512, converted to
    a float tensor and standardised with its own mean and standard deviation.

    Parameters
    ----------
    image : array-like
        Either a 2-D single-channel array, or a 3-D array that is
        channels-first (leading axis of size 1 or 3) or channels-last
        (trailing axis of size 1-4; a trailing size of 2 or 4 is treated as
        carrying an alpha channel). The notebook's caller passes the result
        of ``np.asarray(PIL.Image)``.

    Returns
    -------
    torch.Tensor
        The preprocessed image with a leading batch axis added
        (``1 x 1 x 512 x 512`` for a single-channel input).
    """
    # Only 3-D arrays carry a channel axis. Without this guard a 2-D image
    # whose height or width happens to be 1 or 3 would be sliced down to 1-D.
    if image.ndim == 3:
        if image.shape[0] in [1, 3]:
            # Channels-first: keep the first channel.
            image = image[0]
        elif image.shape[-1] in [2, 4]:
            # Channels-last with alpha (LA / RGBA): skip the alpha channel and
            # keep the last colour channel, matching the RGB branch below.
            image = image[..., -2]
        elif image.shape[-1] in [1, 3]:
            # Channels-last: keep the last channel.
            image = image[..., -1]
    image = tf.to_pil_image(image)
    image = tf.resize(image, [512, 512])
    image = tf.to_tensor(image)
    # Per-image standardisation using the image's own statistics.
    image = tf.normalize(image, image.mean(), image.std())
    return image.unsqueeze(0)

def predict(image, experiment_name, architecture_name, encoder, encoder_weights):
    """Segment ``image`` with a released checkpoint and return a binary mask.

    Parameters
    ----------
    image : array-like
        Raw image; forwarded unchanged to ``preprocess``.
    experiment_name : str
        Experiment prefix of the checkpoint file name.
    architecture_name : str
        One of ``'Unet'``, ``'Linknet'``, ``'FPN'`` or ``'PSPNet'``.
    encoder : str
        Encoder name passed to the architecture constructor and used in the
        checkpoint file name.
    encoder_weights : str or None
        Used in the checkpoint file name (``None`` is formatted as ``'None'``).

    Returns
    -------
    tuple
        ``(prediction, preprocessed_image)``: ``prediction`` is a boolean
        NumPy array (first sample, first channel of the model output,
        thresholded at 0.5) and ``preprocessed_image`` is the tensor that was
        fed to the model.

    Raises
    ------
    ValueError
        If ``architecture_name`` is not one of the supported names.
    """
    architectures = {'Unet': Unet, 'Linknet': Linknet, 'FPN': FPN, 'PSPNet': PSPNet}
    try:
        architecture = architectures[architecture_name]
    except KeyError:
        raise ValueError(
            f'Unknown architecture_name {architecture_name!r}; expected one of {sorted(architectures)}'
        ) from None
    # Build the network without pretrained encoder weights: the strict
    # load_state_dict below replaces every parameter and buffer, so fetching
    # the ImageNet weights first would be a wasted download. The
    # encoder_weights argument still selects which checkpoint is loaded.
    model = architecture(encoder, encoder_weights=None, activation='sigmoid', in_channels=1).to('cpu')
    checkpoint = f'https://github.com/pbizopoulos/comprehensive-comparison-of-deep-learning-models-for-lung-and-covid-19-lesion-segmentation-in-ct/releases/latest/download/{experiment_name}-{architecture_name}-{encoder}-{encoder_weights}.pt'
    model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, map_location='cpu'))
    model.eval()
    preprocessed_image = preprocess(image)
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        prediction = model(preprocessed_image)
    prediction = prediction[0, 0].numpy()
    # The model ends in a sigmoid, so 0.5 is the decision threshold.
    prediction = prediction > 0.5
    return prediction, preprocessed_image

In [3]:
# Choices offered by the dropdown widgets; each combination maps to one
# released checkpoint file name.
experiment_name_list = [
    'lung-segmentation',
    'lesion-segmentation-a',
]
architecture_name_list = [
    'Unet',
    'Linknet',
    'FPN',
    'PSPNet',
]
encoder_list = [
    # VGG
    'vgg11', 'vgg13', 'vgg19',
    # ResNet
    'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
    # DenseNet
    'densenet121', 'densenet161', 'densenet169', 'densenet201',
    # ResNeXt / DPN
    'resnext50_32x4d', 'dpn68', 'dpn98',
    # Other single-variant backbones
    'mobilenet_v2', 'xception', 'inceptionv4',
    # EfficientNet
    'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3',
    'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b6',
]
encoder_weights_list = [
    None,
    'imagenet',
]

# Output widget that predict_and_plot draws its figures into; it is displayed
# once, below the interactive controls.
output = Output()

def predict_and_plot(image, experiment_name, architecture_name, encoder, encoder_weights):
    """Run segmentation on the uploaded (or default) image and plot the result.

    Parameters
    ----------
    image : dict or sequence of dict
        Value of the ``FileUpload`` widget. ipywidgets 7 supplies a dict
        mapping file names to file-info dicts; ipywidgets 8 supplies a tuple
        of file-info dicts. Each file-info dict must have a ``'content'``
        entry holding the raw file bytes. When empty, the bundled sample
        image ``./covid-19-pneumonia-4.jpg`` is used instead.
    experiment_name, architecture_name, encoder, encoder_weights
        Forwarded unchanged to ``predict``.

    Draws two panels into the module-level ``output`` widget: the
    preprocessed input, and the same input with the predicted mask overlaid.
    """
    if image:
        # Accept both FileUpload value formats and use the last uploaded file
        # (the widget is configured for a single file).
        uploads = list(image.values()) if isinstance(image, dict) else list(image)
        pil_image = Image.open(io.BytesIO(uploads[-1]['content']))
    else:
        pil_image = Image.open('./covid-19-pneumonia-4.jpg')
    # np.asarray reads the pixel data, so the image (and, for the default
    # path, its file handle) can be closed straight afterwards.
    with pil_image:
        image_array = np.asarray(pil_image)
    prediction, preprocessed_image = predict(image_array, experiment_name, architecture_name, encoder, encoder_weights)
    with output:
        output.clear_output(wait=True)
        fig, (ax_input, ax_overlay) = plt.subplots(1, 2)
        ax_input.imshow(preprocessed_image[0, 0], cmap='gray')
        ax_input.set_title('Preprocessed input')
        ax_overlay.imshow(preprocessed_image[0, 0], cmap='gray')
        ax_overlay.imshow(prediction, cmap='Reds', alpha=0.5)
        ax_overlay.set_title('Predicted mask overlay')
        plt.show()
        # Release the figure explicitly in case the backend did not.
        plt.close(fig)
# Wire the controls to predict_and_plot: one upload button plus one dropdown
# per model choice. interact re-runs the function whenever a control changes.
p = interact(
    predict_and_plot,
    image=FileUpload(accept='image/*', multiple=False),
    experiment_name=Dropdown(options=experiment_name_list),
    architecture_name=Dropdown(options=architecture_name_list),
    encoder=Dropdown(options=encoder_list),
    encoder_weights=Dropdown(options=encoder_weights_list),
)
# Show the widget that receives the plots.
display(output)

interactive(children=(FileUpload(value={}, accept='image/*', description='Upload'), Dropdown(description='expe…

Output()