In [90]:
import os
import glob
import cv2
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import albumentations as A
import onnxruntime

## Loading samples' path in a dataframe

In [91]:
PATH = './test-samples/'



mask_files = glob.glob(PATH + '*_mask*')
image_files = [file.replace('_mask', '') for file in mask_files]

def diagnosis(mask_path):
    return 1 if np.max(cv2.imread(mask_path)) > 0 else 0

samples_df = pd.DataFrame({"image_path": image_files,
                  "mask_path": mask_files,
                  "diagnosis": [diagnosis(x) for x in mask_files]})

samples_df

Unnamed: 0,image_path,mask_path,diagnosis
0,./test-samples/TCGA_CS_4942_19970222_11.tif,./test-samples/TCGA_CS_4942_19970222_11_mask.tif,1
1,./test-samples/TCGA_CS_4942_19970222_10.tif,./test-samples/TCGA_CS_4942_19970222_10_mask.tif,1
2,./test-samples/TCGA_CS_4941_19960909_15.tif,./test-samples/TCGA_CS_4941_19960909_15_mask.tif,1
3,./test-samples/TCGA_CS_4942_19970222_12.tif,./test-samples/TCGA_CS_4942_19970222_12_mask.tif,1


## Preprocessing Samples

In [138]:
def load_sample(idx):
    image = cv2.imread(samples_df.iloc[idx, 0])
    image = (np.array(image).astype(np.float32))/255.
    mask = cv2.imread(samples_df.iloc[idx, 1], 0)
#     mask = (np.array(mask).astype(np.float32))/255.
    
    test_transform = A.Compose([
                    A.Resize(width=128, height=128, p=1.0)
                    ])
    
    aug = test_transform(image=image, mask=mask)
    image = aug['image']
    mask = aug['mask']
            
    image = image.transpose((2,0,1))
    
    #image normalize
    mean_vec = np.array([0.485, 0.456, 0.406])
    std_vec = np.array([0.229, 0.224, 0.225])

    for i in range(image.shape[0]):
        image[i, :, :] = (image[i, :, :] - mean_vec[i]) / (std_vec[i])
    
    
    mask = np.expand_dims(mask, axis=-1).transpose((2,0,1))

    return image, mask

In [139]:
test_img, test_mask = load_sample(2)

In [140]:
test_img.shape

(3, 128, 128)

## Creating batch of single image

In [141]:
input_img = np.stack([test_img]*1)
input_img.shape

(1, 3, 128, 128)

## Loading the ONNX model

In [142]:
model_onnx = './checkpoints/brain-mri-unet.onnx'

session = onnxruntime.InferenceSession(model_onnx, None)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

## Predicting

In [143]:
result = session.run([output_name], {input_name: input_img})
# prediction=(np.argmax(np.array(result).squeeze(), axis=0))
# print(prediction)
result

[array([[[[5.12599945e-06, 8.55326653e-06, 2.90870667e-05, ...,
           1.80006027e-05, 9.59634781e-06, 7.45058060e-06],
          [3.54647636e-06, 4.05311584e-06, 2.53915787e-05, ...,
           1.39474869e-05, 3.96370888e-06, 6.04987144e-06],
          [1.71065331e-05, 3.76403332e-05, 6.13331795e-05, ...,
           5.33461571e-05, 1.85966492e-05, 1.66296959e-05],
          ...,
          [2.19941139e-05, 2.75373459e-05, 3.20672989e-05, ...,
           2.93850899e-05, 1.47521496e-05, 1.88946724e-05],
          [2.11596489e-05, 1.86264515e-05, 1.77323818e-05, ...,
           2.17258930e-05, 1.21593475e-05, 1.66594982e-05],
          [2.28285789e-05, 1.88648701e-05, 1.87158585e-05, ...,
           1.39474869e-05, 9.38773155e-06, 1.12652779e-05]]]],
       dtype=float32)]

In [144]:
arr = np.array(result).astype(np.float32)
arr = arr[0, 0, :, :, :]
arr

array([[[5.12599945e-06, 8.55326653e-06, 2.90870667e-05, ...,
         1.80006027e-05, 9.59634781e-06, 7.45058060e-06],
        [3.54647636e-06, 4.05311584e-06, 2.53915787e-05, ...,
         1.39474869e-05, 3.96370888e-06, 6.04987144e-06],
        [1.71065331e-05, 3.76403332e-05, 6.13331795e-05, ...,
         5.33461571e-05, 1.85966492e-05, 1.66296959e-05],
        ...,
        [2.19941139e-05, 2.75373459e-05, 3.20672989e-05, ...,
         2.93850899e-05, 1.47521496e-05, 1.88946724e-05],
        [2.11596489e-05, 1.86264515e-05, 1.77323818e-05, ...,
         2.17258930e-05, 1.21593475e-05, 1.66594982e-05],
        [2.28285789e-05, 1.88648701e-05, 1.87158585e-05, ...,
         1.39474869e-05, 9.38773155e-06, 1.12652779e-05]]], dtype=float32)

In [136]:
resultimg = Image.fromarray(arr.squeeze())
resultimg.show()

In [145]:
maskimg = test_mask[0, :, :]
maskimg = Image.fromarray(maskimg)
maskimg.show()