In [1]:
import os

import numpy as np
import pandas as pd

import pydicom
import cv2
import matplotlib.pyplot as plt

from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.losses import binary_crossentropy
from keras.utils import Sequence
from keras import backend as keras
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ModelCheckpoint, LearningRateScheduler

from glob import glob
from tqdm import tqdm

Using TensorFlow backend.


In [14]:
# Root of the Smart-Detect data tree. Set the SMART_DETECT_ROOT environment
# variable to relocate everything below instead of editing each path.
DATA_ROOT = os.environ.get("SMART_DETECT_ROOT", "/mnt/data/Smart-Detect")

INPUT_DIR = os.path.join(DATA_ROOT, "Segmentation")

SEGMENTATION_DIR = os.path.join(INPUT_DIR, "segmentation")

# Trained lung-segmentation U-Net (loaded in a later cell).
SEGMENTATION_MODEL = os.path.join(INPUT_DIR, "unet_lung_seg.hdf5")

# Where segmented images will be saved.
SEGMENTATION_RESULT = os.path.join(SEGMENTATION_DIR, "result")
SEGMENTATION_RESULT_TRAIN = os.path.join(DATA_ROOT, "Target_Class_2D", "regular_seg", "covid")
SEGMENTATION_RESULT_TEST = os.path.join(INPUT_DIR, "segtrain", "No_Findings")

# NOTE(review): SEGMENTATION_RESULT, SEGMENTATION_TEST_DIR and
# SEGMENTATION_TRAIN_DIR are not referenced by any cell shown here;
# candidates for removal once confirmed unused elsewhere.
SEGMENTATION_TEST_DIR = os.path.join(SEGMENTATION_DIR, "test")
SEGMENTATION_TRAIN_DIR = os.path.join(SEGMENTATION_DIR, "train", "image")

# The images to segment. The RSNA_ prefix is presumably carried over from an
# RSNA pipeline; these actually point at the COVID set and the ChestX-ray14
# "No_Findings" set.
RSNA_TRAIN_DIR = os.path.join(DATA_ROOT, "Target_Class_2D", "regular", "covid")
RSNA_TEST_DIR = os.path.join(DATA_ROOT, "ChestX-ray14", "train", "No_Findings")



In [5]:
def dice_coef(y_true, y_pred, smooth=1.0):
    """Soft Dice coefficient between two masks, using Keras backend ops.

    Note: ``keras`` in this notebook is the *backend* module
    (``from keras import backend as keras``), so ``keras.flatten`` /
    ``keras.sum`` are backend tensor ops.

    Both tensors are flattened in full, so the score is pooled over the whole
    batch rather than averaged per image.

    Args:
        y_true: ground-truth mask tensor.
        y_pred: predicted mask tensor with the same number of elements.
        smooth: additive term in numerator and denominator. It keeps the
            ratio defined (equal to 1) when both masks are empty. The default
            equals the constant originally hard-coded here.

    Returns:
        Scalar tensor ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    y_true_f = keras.flatten(y_true)
    y_pred_f = keras.flatten(y_pred)
    intersection = keras.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (keras.sum(y_true_f) + keras.sum(y_pred_f) + smooth)

def dice_coef_loss(y_true, y_pred):
    """Negated Dice coefficient, so minimising the loss maximises mask overlap.

    Registered under this name in ``custom_objects`` when the saved
    segmentation model is loaded.
    """
    overlap = dice_coef(y_true, y_pred)
    return -overlap

# The saved model references the custom loss/metric by name, so Keras must be
# handed the matching Python functions to deserialize it.
segmentation_model = load_model(
    SEGMENTATION_MODEL,
    custom_objects={
        "dice_coef_loss": dice_coef_loss,
        "dice_coef": dice_coef,
    },
)

segmentation_model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 512, 512, 1)  0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 512, 512, 32) 320         input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 512, 512, 32) 9248        conv2d_1[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 256, 256, 32) 0           conv2d_2[0][0]                   
____________________________________________________________________________________________

In [12]:
def image_to_train(img):
    """Turn a grayscale image into a single-sample batch for the U-Net.

    Intensities are inverted and scaled to [0, 1] (0 -> 1.0, 255 -> 0.0).
    Presumably the model was trained on inverted radiographs -- TODO confirm
    against the training pipeline.

    Args:
        img: array of intensities in 0..255, normally 2-D uint8 from
            ``cv2.imread(..., cv2.IMREAD_GRAYSCALE)``.

    Returns:
        Array of shape ``(1,) + img.shape + (1,)``. The dtype is float64 for
        integer input; float input keeps its float dtype.
    """
    # Subtract from a Python float so integer input is promoted to float first.
    # The previous ``img * -1`` on a uint8 array only worked through NumPy 1.x
    # value-based casting (to int16); under NumPy >= 2 (NEP 50) it raises
    # OverflowError because -1 does not fit in uint8.
    inverted = (255.0 - img) / 255.0
    # Add batch and channel axes: (H, W) -> (1, H, W, 1).
    return inverted[np.newaxis, ..., np.newaxis]

def train_to_image(npy, kernel_size=40, threshold=None):
    """Convert a U-Net prediction batch into a cleaned-up uint8 mask.

    Args:
        npy: model output. Only ``npy[0, :, :, 0]`` (first sample, first
            channel) is used. Values are presumably sigmoid probabilities in
            [0, 1] -- TODO confirm against the model's final layer.
        kernel_size: side length of the square all-ones structuring element
            used for a morphological opening (removes small bright specks)
            followed by a closing (fills small dark holes). Defaults to the
            original 40 px.
        threshold: if given, pixels with value >= ``threshold`` become 255 and
            all others 0 before the morphology step. The default ``None``
            keeps the original behaviour: values are scaled to 0..255 and
            truncated to uint8 without binarising.

    Returns:
        2-D uint8 array with the prediction's height and width.
    """
    prob = npy[0, :, :, 0]
    if threshold is None:
        # NOTE(review): used as a cv2.bitwise_and mask, any non-zero value
        # counts, i.e. every pixel with probability >= 1/255 is kept.
        # Pass threshold=0.5 for a stricter binary mask.
        img = (prob * 255.).astype(np.uint8)
    else:
        img = np.where(prob >= threshold, 255, 0).astype(np.uint8)
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    img = cv2.morphologyEx(img, cv2.MORPH_OPEN, kernel)
    img = cv2.morphologyEx(img, cv2.MORPH_CLOSE, kernel)
    return img

In [16]:
def segment_image(pid, img, save_to):
    """Mask everything outside the predicted lungs and save it as a PNG.

    The image is resized to 512x512, the segmentation model's input size, so
    the saved PNG is 512x512 rather than the original resolution.

    Args:
        pid: file stem for the output, written as ``<save_to>/<pid>.png``.
        img: 2-D grayscale image, e.g. from
            ``cv2.imread(..., cv2.IMREAD_GRAYSCALE)``.
        save_to: output directory. It is created if it does not exist.

    Raises:
        ValueError: if ``img`` is None (cv2.imread returns None on failure).
        IOError: if cv2.imwrite reports that the file was not written.
    """
    if img is None:
        raise ValueError("no image data for %s" % pid)

    img = cv2.resize(img, (512, 512))
    segm_ret = segmentation_model.predict(image_to_train(img), verbose=0)

    # Keep only pixels where the post-processed mask is non-zero.
    img = cv2.bitwise_and(img, img, mask=train_to_image(segm_ret))

    # cv2.imwrite signals failure (e.g. a missing directory) only through its
    # return value, so make sure the directory exists and check the result.
    os.makedirs(save_to, exist_ok=True)
    out_path = os.path.join(save_to, "%s.png" % pid)
    if not cv2.imwrite(out_path, img):
        raise IOError("cv2.imwrite failed for %s" % out_path)

def segment_directory(src_dir, dst_dir, pattern="*.png"):
    """Segment every image in ``src_dir`` matching ``pattern`` into ``dst_dir``.

    Each file is read as grayscale and passed to ``segment_image``, using the
    file name without extension as the output stem. Files that cv2.imread
    cannot decode are skipped and reported instead of aborting a long run.

    Args:
        src_dir: directory containing the input images.
        dst_dir: directory the segmented PNGs are written to.
        pattern: glob pattern selecting input files inside ``src_dir``.

    Returns:
        List of input paths that were skipped because they could not be read.
    """
    import warnings

    # Sorted so repeated runs process files in the same order.
    paths = sorted(glob(os.path.join(src_dir, pattern)))
    if not paths:
        # An empty match usually means a wrong directory or extension.
        warnings.warn("no files match %s" % os.path.join(src_dir, pattern))

    skipped = []
    for path in tqdm(paths):
        pid = os.path.splitext(os.path.basename(path))[0]
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            skipped.append(path)
            continue
        segment_image(pid, img, dst_dir)

    if skipped:
        warnings.warn("skipped %d unreadable file(s) in %s" % (len(skipped), src_dir))
    return skipped


skipped_train = segment_directory(RSNA_TRAIN_DIR, SEGMENTATION_RESULT_TRAIN, "*.png")

# The test set is globbed as *.jpg; enable when needed:
# skipped_test = segment_directory(RSNA_TEST_DIR, SEGMENTATION_RESULT_TEST, "*.jpg")



100%|██████████| 6724/6724 [44:38<00:00,  2.51it/s] 
