In [None]:
# Install GDCM from conda-forge; it is imported in the next cell (`import gdcm`),
# presumably so pydicom can decode compressed pixel data — TODO confirm still required.
!conda install -c conda-forge gdcm -y

In [None]:
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt
import seaborn as sns

import pydicom
import scipy.ndimage
import gdcm

from skimage import measure 
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from skimage.morphology import disk, opening, closing
from tqdm import tqdm

from IPython.display import HTML
from PIL import Image

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

from os import listdir,mkdir

In [None]:
listdir("../input/")


In [None]:
#basepath = "../input/osic-pulmonary-fibrosis-progression/"
# or if you are taking part in RSNA pulmonary embolism detection:
basepath = "../input/rsna-str-pulmonary-embolism-detection/"
listdir(basepath)

In [None]:
train = pd.read_csv(basepath + "train.csv")
test = pd.read_csv(basepath + "test.csv")


In [None]:
train.shape

In [None]:
train.head()

In [None]:
if basepath == "../input/osic-pulmonary-fibrosis-progression/":
    train["dcm_path"] = basepath + "train/" + train.Patient + "/"
else:
    train["dcm_path"] = basepath + "train/" + train.StudyInstanceUID + "/" + train.SeriesInstanceUID  

> #### Load the DICOM slices of a scan in sorted (anatomical) order

In [None]:
def load_scans(dcm_path):
    """Read every DICOM slice in one scan directory and return them in a stable order.

    Parameters
    ----------
    dcm_path : str
        Directory containing the ``.dcm`` files of a single scan.

    Returns
    -------
    list
        pydicom datasets. For OSIC they are ordered by descending numeric file name.
        Otherwise they are ordered by ascending z-coordinate of ImagePositionPatient.
    """
    files = listdir(dcm_path)
    if basepath == "../input/osic-pulmonary-fibrosis-progression/":
        # In this competition ImagePositionPatient has missing values, so sort by the
        # numeric file name instead (largest number first).
        # Built-in int: the `np.int` alias was removed in NumPy 1.24.
        sorted_file_nums = sorted((int(file.split(".")[0]) for file in files), reverse=True)
        slices = [pydicom.dcmread(dcm_path + "/" + str(file_num) + ".dcm") for file_num in sorted_file_nums]
    else:
        # Otherwise sort by the z-coordinate of ImagePositionPatient.
        slices = [pydicom.dcmread(dcm_path + "/" + file) for file in files]
        slices.sort(key=lambda x: float(x.ImagePositionPatient[2]))
    return slices

In [None]:
# First training scan, used as the running example below.
example = train["dcm_path"].iloc[0]
scans = load_scans(example)

Check the first slice of the example scan:

In [None]:
first_slice = scans[0]
first_slice

We need to transform the raw values to Hounsfield units (HU). The spectral composition of the X-rays depends on measurement settings such as the acquisition parameters and tube voltage. Normalizing to the values of water (0 HU) and air (-1000 HU) makes images from different measurements comparable.

In [None]:
# Compare raw pixel values with their HU-rescaled counterparts for the first few slices.
# min(): scans with fewer than 10 slices previously raised IndexError.
n_examples = min(10, len(scans))
fig, ax = plt.subplots(1,2,figsize=(20,5))
for n in range(n_examples):
    raw_values = scans[n].pixel_array.flatten()
    rescaled_values = raw_values * scans[n].RescaleSlope + scans[n].RescaleIntercept
    sns.distplot(raw_values, ax=ax[0])
    sns.distplot(rescaled_values, ax=ax[1])
ax[0].set_title("Raw pixel array distributions for {} examples".format(n_examples))
ax[1].set_title("HU unit distributions for {} examples".format(n_examples));

For some examples we can see raw values of -2000. They belong to images with a circular boundary: the region outside this circle (outside the scanner's field of view) is often set to -2000.

In [None]:
def set_outside_scanner_to_air(raw_pixelarrays):
    """Overwrite the "outside the scanner" filler value with raw air (0), in place.

    In OSIC, regions outside the scanner carry a raw default of -2000. Every raw value
    <= -1000, the midpoint between air (0) and that default, is set to 0. The input
    array itself is modified and then returned.
    """
    np.putmask(raw_pixelarrays, raw_pixelarrays <= -1000, 0)
    return raw_pixelarrays

In [None]:
def transform_to_hu(slices):
    """Convert the slices of one scan to a 3-D int16 volume in Hounsfield units.

    Parameters
    ----------
    slices : sequence of pydicom datasets
        Each must provide ``pixel_array``, ``RescaleSlope`` and ``RescaleIntercept``.

    Returns
    -------
    numpy.ndarray of int16, shape (n_slices, rows, cols)
        ``round(raw * slope + intercept)`` per slice, clipped to the int16 range.
        Raw values <= -1000 are first treated as outside the scanner and set to 0.
    """
    # Stack as int32, not int16. Unsigned 16-bit raw data above 32767 would otherwise
    # wrap to negative numbers and then be wiped out by the outside-scanner threshold.
    raw = np.stack([file.pixel_array for file in slices]).astype(np.int32)
    raw = set_outside_scanner_to_air(raw)

    int16_info = np.iinfo(np.int16)
    hu = np.empty(raw.shape, dtype=np.int16)
    for n, dcm in enumerate(slices):
        slope = float(dcm.RescaleSlope)
        intercept = float(dcm.RescaleIntercept)
        # Rescale in float and round once. The previous in-place int16 arithmetic
        # truncated scaled values and could silently overflow when adding the intercept.
        rescaled = np.rint(raw[n] * slope + intercept)
        hu[n] = np.clip(rescaled, int16_info.min, int16_info.max)
    return hu

Now all raw values of each slice are rescaled to Hounsfield units (HU).

In [None]:
hu_scans = transform_to_hu(scans)

In [None]:
fig, ax = plt.subplots(1,4,figsize=(20,3))
ax[0].set_title("Original CT-scan")
ax[0].imshow(scans[0].pixel_array, cmap="bone")
ax[1].set_title("Pixelarray distribution");
sns.distplot(scans[0].pixel_array.flatten(), ax=ax[1]);

ax[2].set_title("CT-scan in HU")
ax[2].imshow(hu_scans[0], cmap="bone")
ax[3].set_title("HU values distribution");
sns.distplot(hu_scans[0].flatten(), ax=ax[3]);

for m in [0,2]:
    ax[m].grid(False)

In [None]:
N = 1000

In [None]:
def get_window_value(feature):
    """Return a DICOM window setting (e.g. WindowWidth / WindowCenter) as a plain int.

    Multi-valued attributes (pydicom ``MultiValue``, or any list/tuple) store several
    windows; only the first one is used.

    Parameters
    ----------
    feature : number, numeric string, or a sequence of those

    Returns
    -------
    int
    """
    from collections.abc import Sequence

    # pydicom's MultiValue is a MutableSequence, so the ABC check covers it without naming
    # the pydicom type. Strings are Sequences too and must be passed to int() whole.
    if isinstance(feature, Sequence) and not isinstance(feature, (str, bytes)):
        feature = feature[0]
    # Built-in int: the `np.int` alias was removed in NumPy 1.24.
    return int(feature)

# Collect per-scan metadata from one DICOM file of each of the first N scans.
# Later plotting cells also use the plain lists directly
# (pixelspacing_r, pixelspacing_c, slice_thicknesses).
pixelspacing_r = []
pixelspacing_c = []
slice_thicknesses = []
patient_id = []
patient_pth = []
row_values = []
column_values = []
window_widths = []
window_levels = []

# One identifier per scan: Patient for OSIC, SeriesInstanceUID otherwise.
if basepath == "../input/osic-pulmonary-fibrosis-progression/":
    patients = train.Patient.unique()[0:N]
else:
    patients = train.SeriesInstanceUID.unique()[0:N]

for patient in patients:
    patient_id.append(patient)
    # Scan directory = dcm_path of the first train row belonging to this scan.
    if basepath == "../input/osic-pulmonary-fibrosis-progression/":
        path = train[train.Patient == patient].dcm_path.values[0]
    else:
        path = train[train.SeriesInstanceUID == patient].dcm_path.values[0]
    # Only the first listed file is read; assumes these header fields are identical
    # for all slices of a scan — TODO confirm.
    example_dcm = listdir(path)[0]
    patient_pth.append(path)
    dataset = pydicom.dcmread(path + "/" + example_dcm)
    
    # get_window_value keeps only the first entry of multi-valued window settings.
    window_widths.append(get_window_value(dataset.WindowWidth))
    window_levels.append(get_window_value(dataset.WindowCenter))
    
    spacing = dataset.PixelSpacing
    slice_thicknesses.append(dataset.SliceThickness)
    
    row_values.append(dataset.Rows)
    column_values.append(dataset.Columns)
    # spacing[0] is recorded as the row spacing, spacing[1] as the column spacing.
    pixelspacing_r.append(spacing[0])
    pixelspacing_c.append(spacing[1])
    
# One row per scan: size in pixels, pixel spacing and per-pixel area,
# slice thickness, scan directory and the (first) display window.
scan_properties = pd.DataFrame(data=patient_id, columns=["patient"])
scan_properties.loc[:, "rows"] = row_values
scan_properties.loc[:, "columns"] = column_values
scan_properties.loc[:, "area"] = scan_properties["rows"] * scan_properties["columns"]
scan_properties.loc[:, "pixelspacing_r"] = pixelspacing_r
scan_properties.loc[:, "pixelspacing_c"] = pixelspacing_c
scan_properties.loc[:, "pixelspacing_area"] = scan_properties.pixelspacing_r * scan_properties.pixelspacing_c
scan_properties.loc[:, "slice_thickness"] = slice_thicknesses
scan_properties.loc[:, "patient_pth"] = patient_pth
scan_properties.loc[:, "window_width"] = window_widths
scan_properties.loc[:, "window_level"] = window_levels
scan_properties.head()

The PixelSpacing attribute tells us how much physical distance is covered by one pixel.

Between patients the pixel spacing can differ, because of personal or institutional preferences of doctors and clinics and because of the scanner type. Consequently, if you compare the lungs in two images, the one that looks bigger is not necessarily larger in physical size!

In [None]:
# Histogram of the pixel spacing along each image axis.
fig, ax = plt.subplots(1,2,figsize=(20,5))
spacing_panels = [
    (pixelspacing_r, "Limegreen", "Pixel spacing distribution \n in row direction "),
    (pixelspacing_c, "Mediumseagreen", "Pixel spacing distribution \n in column direction"),
]
for spacing_ax, (spacing_values, bar_color, panel_title) in zip(ax, spacing_panels):
    sns.distplot(spacing_values, ax=spacing_ax, color=bar_color, kde=False)
    spacing_ax.set_title(panel_title)
    spacing_ax.set_ylabel("Counts in train")
    spacing_ax.set_xlabel("mm")

The pixel spacing varies a lot between scans. This needs to be addressed, for example by resampling to a common spacing.

In [None]:
# Number of scans per (rows, columns) size as a rows x columns table; missing combinations are 0.
counts = scan_properties.groupby(["rows", "columns"]).size().unstack().fillna(0)


fig, ax = plt.subplots(1,2,figsize=(20,5))
thickness_ax, size_ax = ax
sns.distplot(slice_thicknesses, color="orangered", kde=False, ax=thickness_ax)
thickness_ax.set_title("Slice thicknesses of all patients")
thickness_ax.set_xlabel("Slice thickness in mm")
thickness_ax.set_ylabel("Counts in train")

# Marker size encodes how many scans have that (rows, columns) size.
for n_rows in counts.index.values:
    for n_cols in counts.columns.values:
        size_ax.scatter(n_rows, n_cols, s=counts.loc[n_rows, n_cols], c="midnightblue")
size_ax.set_xlabel("rows")
size_ax.set_ylabel("columns")
size_ax.set_title("Pixel area of ct-scan per patient");

The slice thickness tells us how much distance is covered in the z-direction by one slice.

The pixel_array of raw values covers a specific area, given by the number of rows and columns.

Thinner slices in a CT scan mean more detail along the z-axis.

In [None]:
# Physical extent of each slice: pixel spacing (mm) times pixel count, then mm -> cm.
mm_to_cm = 0.1
scan_properties = scan_properties.assign(
    r_distance=lambda df: df.pixelspacing_r * df.rows,
    c_distance=lambda df: df.pixelspacing_c * df["columns"],
    area_cm2=lambda df: mm_to_cm * df.r_distance * mm_to_cm * df.c_distance,
    slice_volume_cm3=lambda df: mm_to_cm * df.slice_thickness * df.area_cm2,
)

Some images have extremely large slice areas and volumes.

In [None]:
fig, ax = plt.subplots(1,2,figsize=(20,5))
sns.distplot(scan_properties.area_cm2, ax=ax[0], color="purple")
sns.distplot(scan_properties.slice_volume_cm3, ax=ax[1], color="magenta")
ax[0].set_title("CT-slice area in $cm^{2}$")
ax[1].set_title("CT-slice volume in $cm^{3}$")
ax[0].set_xlabel("$cm^{2}$")
ax[1].set_xlabel("$cm^{3}$");

In [None]:
max_path = scan_properties[
    scan_properties.area_cm2 == scan_properties.area_cm2.max()].patient_pth.values[0]
min_path = scan_properties[
    scan_properties.area_cm2 == scan_properties.area_cm2.min()].patient_pth.values[0]

min_scans = load_scans(min_path)
min_hu_scans = transform_to_hu(min_scans)

max_scans = load_scans(max_path)
max_hu_scans = transform_to_hu(max_scans)

background_water_hu_scans = max_hu_scans.copy()

In [None]:
def set_manual_window(hu_image, custom_center, custom_width):
    """Clamp an image to the display window [center - width/2, center + width/2].

    Works on a copy, so the input is left untouched. Values are clamped by assignment,
    which means an integer image keeps its dtype and fractional bounds are truncated.
    """
    half_width = custom_width / 2
    lower, upper = custom_center - half_width, custom_center + half_width
    windowed = hu_image.copy()
    below = windowed < lower
    windowed[below] = lower
    above = windowed > upper
    windowed[above] = upper
    return windowed

Smallest and largest CT-slice area

In [None]:
fig, ax = plt.subplots(1,2,figsize=(20,10))
ax[0].imshow(set_manual_window(min_hu_scans[np.int(len(min_hu_scans)/2)], -700, 255), cmap="YlGnBu")
ax[1].imshow(set_manual_window(max_hu_scans[np.int(len(max_hu_scans)/2)], -700, 255), cmap="YlGnBu");
ax[0].set_title("CT-scan with small slice area")
ax[1].set_title("CT-scan with large slice area");
for n in range(2):
    ax[n].axis("off")

In [None]:
fig, ax = plt.subplots(1,2,figsize=(20,5))
sns.distplot(max_hu_scans[np.int(len(max_hu_scans)/2)].flatten(), kde=False, ax=ax[1])
ax[1].set_title("Large area image")
sns.distplot(min_hu_scans[np.int(len(min_hu_scans)/2)].flatten(), kde=False, ax=ax[0])
ax[0].set_title("Small area image")
ax[0].set_xlabel("HU values")
ax[1].set_xlabel("HU values");

We can see that the large image covers a lot of irrelevant area outside the body, so we could crop it.

In [None]:
max_path = scan_properties[
    scan_properties.slice_volume_cm3 == scan_properties.slice_volume_cm3.max()].patient_pth.values[0]
min_path = scan_properties[
    scan_properties.slice_volume_cm3 == scan_properties.slice_volume_cm3.min()].patient_pth.values[0]

min_scans = load_scans(min_path)
min_hu_scans = transform_to_hu(min_scans)

max_scans = load_scans(max_path)
max_hu_scans = transform_to_hu(max_scans)

In [None]:
fig, ax = plt.subplots(1,2,figsize=(20,10))
ax[0].imshow(set_manual_window(min_hu_scans[np.int(len(min_hu_scans)/2)], -700, 255), cmap="YlGnBu")
ax[1].imshow(set_manual_window(max_hu_scans[np.int(len(max_hu_scans)/2)], -700, 255), cmap="YlGnBu");
ax[0].set_title("CT-scan with small slice volume")
ax[1].set_title("CT-scan with large slice volume");
for n in range(2):
    ax[n].axis("off")

In [None]:
fig, ax = plt.subplots(1,2,figsize=(20,5))
sns.distplot(max_hu_scans[np.int(len(max_hu_scans)/2)].flatten(), kde=False, ax=ax[1])
ax[1].set_title("Large slice volume")
sns.distplot(min_hu_scans[np.int(len(min_hu_scans)/2)].flatten(), kde=False, ax=ax[0])
ax[0].set_title("Small slice volume")
ax[0].set_xlabel("HU values")
ax[1].set_xlabel("HU values");

3D-reconstruction of CT-scans

In [None]:
def plot_3d(image, threshold=700, color="navy"):
    """Render the iso-surface of a 3-D volume with matplotlib.

    Parameters
    ----------
    image : numpy.ndarray, 3-D
        Volume such as the output of ``transform_to_hu``.
    threshold : float, default 700
        Iso-value at which the surface is extracted (HU when ``image`` is an HU volume).
    color : str
        Face colour of the rendered mesh.
    """
    # Position the scan upright,
    # so the head of the patient would be at the top facing the camera
    p = image.transpose(2, 1, 0)

    # `marching_cubes_lewiner` was removed in scikit-image 0.19. `marching_cubes` uses the
    # Lewiner method by default and returns (verts, faces, normals, values) as before.
    verts, faces, _, _ = measure.marching_cubes(p, threshold)

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')

    # Fancy indexing: `verts[faces]` to generate a collection of triangles
    mesh = Poly3DCollection(verts[faces], alpha=0.2)
    mesh.set_facecolor(color)
    ax.add_collection3d(mesh)

    ax.set_xlim(0, p.shape[0])
    ax.set_ylim(0, p.shape[1])
    ax.set_zlim(0, p.shape[2])

    plt.show()

In [None]:
plot_3d(max_hu_scans)

In [None]:
old_distribution = max_hu_scans.flatten()

In [None]:
example = train.dcm_path.values[0]
scans = load_scans(example)
hu_scans = transform_to_hu(scans)

In [None]:
plot_3d(hu_scans)

In [None]:
plt.figure(figsize=(20,5))
sns.distplot(old_distribution, label="weak 3d plot", kde=False)
sns.distplot(hu_scans.flatten(), label="strong 3d plot", kde=False)
plt.title("HU value distribution")
plt.legend();

In [None]:
print(len(max_hu_scans), len(hu_scans))

Resampling the voxel size

We can resize, crop, blur and shift intensities

In [None]:
def resample(image, scan, new_spacing=(1, 1, 1)):
    """Resample a 3-D volume to a target voxel spacing.

    Parameters
    ----------
    image : numpy.ndarray, shape (slices, rows, cols)
        Volume to resample, e.g. from ``transform_to_hu``.
    scan : sequence of pydicom datasets
        Slices the volume was built from. The current spacing is read from the first one:
        SliceThickness for the slice axis, PixelSpacing for rows and columns.
    new_spacing : sequence of 3 numbers, default (1, 1, 1)
        Target spacing, in the same units as the DICOM attributes.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        The resampled volume, and the spacing actually achieved after rounding the new shape.
    """
    spacing = np.array([scan[0].SliceThickness] + list(scan[0].PixelSpacing), dtype=np.float32)

    # Tuple default instead of a mutable list; asarray keeps the original arithmetic.
    resize_factor = spacing / np.asarray(new_spacing)
    new_shape = np.round(image.shape * resize_factor)

    # Recompute factor and spacing so that they match the rounded shape exactly.
    rounded_resize_factor = new_shape / image.shape
    rounded_new_spacing = spacing / rounded_resize_factor

    # `scipy.ndimage.interpolation` is a deprecated namespace; zoom lives in scipy.ndimage.
    resampled = scipy.ndimage.zoom(image, rounded_resize_factor, mode='nearest')

    return resampled, rounded_new_spacing

In [None]:
img_resampled, spacing = resample(max_hu_scans, scans, [1,1,1])
print("Shape before resampling\t", max_hu_scans.shape)
print("Shape after resampling\t", img_resampled.shape)

In [None]:
plot_3d(img_resampled)

In [None]:
image_sizes = scan_properties.groupby(["rows", "columns"]).size().sort_values(ascending=False)
image_sizes

In [None]:
plt.figure(figsize=(8,8))
for n in counts.index.values:
    for m in counts.columns.values:
        plt.scatter(n, m, s=counts.loc[n,m], c="dodgerblue", alpha=0.7)
plt.xlabel("rows")
plt.ylabel("columns")
plt.title("Pixel area of ct-scan per patient");
plt.plot(np.arange(0,1400), '-.', c="purple", label="squared")
plt.plot(888 * np.ones(1400), '-.', c="crimson", label="888 rows");
plt.legend();

In [None]:
class ImageObserver:
    """Serve batches of first HU slices for all scans that share one image size.

    Usage: construct it, call ``select_group((rows, columns))``, then iterate ``get_loader()``.
    """

    def __init__(self, scan_properties, batch_size):
        # scan_properties: table with at least "rows", "columns" and "patient_pth" columns.
        self.scan_properties = scan_properties
        self.batch_size = batch_size

    def select_group(self, group=(512,512)):
        """Restrict the observer to scans whose (rows, columns) equal ``group``."""
        n_rows, n_cols = group[0], group[1]
        self.group = group
        self.name = "rows {}, columns {}".format(n_rows, n_cols)
        self.batch_shape = (self.batch_size, n_rows, n_cols)
        table = self.scan_properties
        same_size = (table["rows"] == n_rows) & (table["columns"] == n_cols)
        self.selection = table[same_size].copy()
        self.patient_pths = self.selection.patient_pth.unique()

    def get_loader(self):
        """Yield float arrays of shape ``batch_shape``, each row the first HU slice of one scan.

        The last batch is zero-padded when the number of scans is not a multiple of the batch size.
        """
        batch = np.zeros(self.batch_shape)
        filled = 0
        for path in self.patient_pths:
            batch[filled] = transform_to_hu(load_scans(path))[0]
            filled += 1
            if filled == self.batch_shape[0]:
                yield batch
                batch = np.zeros(self.batch_shape)
                filled = 0
        if filled:
            yield batch

In [None]:
my_choice = image_sizes.index.values[0]
print(my_choice)
to_display = 4

In [None]:
observer = ImageObserver(scan_properties, to_display)
observer.select_group(my_choice)
observer_iterator = observer.get_loader()

In [None]:
images = next(observer_iterator)

In [None]:
fig, ax = plt.subplots(1,to_display,figsize=(20,5))


for m in range(to_display):
    image = images[m]
    ax[m].imshow(set_manual_window(image, -500, 1000), cmap="YlGnBu")
    ax[m].set_title(observer.name)

In [None]:
scan_properties.shape

In [None]:
scan_properties.head(1)

In [None]:
def resize_scan(scan, new_shape):
    """Lanczos-resize one 2-D slice and return it as int16.

    ``scan`` is wrapped as a PIL 32-bit signed-integer ("I") image, so it should be an int32
    array. ``new_shape`` is passed unchanged to PIL's resize.
    """
    resized = Image.fromarray(scan, mode="I").resize(new_shape, resample=Image.LANCZOS)
    # Back from 32-bit PIL pixels to 16-bit integers.
    return np.array(resized, dtype=np.int16)

In [None]:
def crop_scan(scan, size=512):
    """Center-crop a 2-D slice to ``size`` x ``size`` and return it as int16.

    Where the slice is smaller than ``size`` along an axis, the missing border is filled
    with 0. When the margin is odd, the extra pixel is dropped from the bottom/right.

    Parameters
    ----------
    scan : numpy.ndarray, shape (rows, cols)
    size : int, default 512

    Returns
    -------
    numpy.ndarray of int16, shape (size, size)
    """
    # The previous PIL version built its crop box with the row count as x and the column
    # count as y, so non-square slices were cropped along the wrong axes.
    n_rows, n_cols = scan.shape
    top = (n_rows - size) // 2
    left = (n_cols - size) // 2
    # Overlap between the requested window and the actual image.
    r0, r1 = max(top, 0), min(top + size, n_rows)
    c0, c1 = max(left, 0), min(left + size, n_cols)
    cropped = np.zeros((size, size), dtype=np.int16)
    cropped[r0 - top:r1 - top, c0 - left:c1 - left] = scan[r0:r1, c0:c1]
    return cropped

In [None]:
def crop_and_resize(scan, new_shape):
    """Center-crop a 2-D slice to 512 x 512, then Lanczos-resize it; returns int16.

    ``scan`` is wrapped as a PIL 32-bit signed-integer ("I") image, so it should be an int32
    array. ``new_shape`` is passed unchanged to PIL's resize.
    """
    img = Image.fromarray(scan, mode="I")

    # PIL boxes are (left, upper, right, lower): x runs along columns, y along rows.
    # The previous version took x from the row count, which mis-cropped non-square slices.
    n_rows, n_cols = scan.shape
    left = (n_cols - 512) / 2
    right = (n_cols + 512) / 2
    top = (n_rows - 512) / 2
    bottom = (n_rows + 512) / 2

    img = img.crop((left, top, right, bottom))
    img = img.resize(new_shape, resample=Image.LANCZOS)

    # Back from 32-bit PIL pixels to 16-bit integers.
    cropped_resized_scan = np.array(img, dtype=np.int16)
    return cropped_resized_scan

In [None]:
def preprocess_to_hu_scans(scan_properties, my_shape, output_dir):
    """Convert every scan to HU, bring its slices to ``my_shape`` and save it as ``.npy``.

    Scans already at ``my_shape`` are saved unchanged. Other square scans are
    Lanczos-resized. Non-square scans are centre-cropped to 512 x 512, then resized
    unless the target is 512 x 512. Each volume goes to
    ``<output_dir>/<patient>_hu_scans.npy``.

    Parameters
    ----------
    scan_properties : pandas.DataFrame
        Needs ``patient`` and ``patient_pth`` columns; assumes one row per patient.
    my_shape : (int, int)
        Target in-plane size of every slice.
    output_dir : str
        Existing directory for the ``.npy`` files.
    """
    target = (my_shape[0], my_shape[1])
    # Iterate (patient, path) pairs directly instead of re-filtering the table per patient.
    pairs = zip(scan_properties.patient.values, scan_properties.patient_pth.values)
    for patient, pth in tqdm(pairs, total=len(scan_properties)):
        hu_scans = transform_to_hu(load_scans(pth))
        n_slices, n_rows, n_cols = hu_scans.shape

        if (n_rows, n_cols) == target:
            # Already at the desired size: save as-is. This branch used to `continue`, so
            # these scans (the most common case for 512 x 512) were never written at all.
            prepared_scans = hu_scans
        else:
            prepared_scans = np.zeros((n_slices, target[0], target[1]), dtype=np.int16)
            # The scans were not converted to jpeg, to keep all information; PIL's "I"
            # mode needs 32-bit integers.
            hu_scans = hu_scans.astype(np.int32)
            for s in range(n_slices):
                if n_rows == n_cols:
                    prepared_scans[s] = resize_scan(hu_scans[s], my_shape)
                elif target == (512, 512):
                    prepared_scans[s] = crop_scan(hu_scans[s])
                else:
                    prepared_scans[s] = crop_and_resize(hu_scans[s], my_shape)

        # save the prepared scans of patient:
        np.save(output_dir + "/" + patient + '_hu_scans', prepared_scans)

In [None]:
generate_512_512 = False

if generate_512_512:
    output_dir = "scans_512x512"
    mkdir(output_dir)
    my_shape = (512, 512)
    preprocess_to_hu_scans(scan_properties, my_shape, output_dir)

In [None]:
generate_224_224 = False

if generate_224_224:
    output_dir = "scans_224x224"
    mkdir(output_dir)
    my_shape = (224, 224)
    preprocess_to_hu_scans(scan_properties, my_shape, output_dir)

In [None]:
generate_128_128 = False

if generate_128_128:
    output_dir = "scans_128x128"
    mkdir(output_dir)
    my_shape = (128, 128)
    preprocess_to_hu_scans(scan_properties, my_shape, output_dir)

In [None]:
generate_64_64 = False

if generate_64_64:
    output_dir = "scans_64x64"
    mkdir(output_dir)
    my_shape = (64, 64)
    preprocess_to_hu_scans(scan_properties, my_shape, output_dir)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import random
import os
import glob
import pandas as pd 
from tqdm import tqdm
import sys
import glob
import cv2

import pydicom
from sklearn.utils import shuffle

import albumentations as A

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torch.optim.lr_scheduler import  ReduceLROnPlateau

In [None]:
# Make the copies of EfficientNet-PyTorch, pretrainedmodels and segmentation_models_pytorch
# found under ../input/ importable, then import segmentation_models_pytorch from there.
sys.path.append('../input/efficientnet-pytorch/EfficientNet-PyTorch-master')
sys.path.append('../input/pretrainedmodels/pretrainedmodels-0.7.4/')
sys.path.append('../input/segmentation-models-pytorch/')
import segmentation_models_pytorch as smp

In [None]:
# Lung-mask training data: per sample directory, an image (img.png) and its mask (post_label.png).
root_path = '../input/osiclungmask100'
img_path = sorted(glob.glob(root_path+'/*/img.png'))
mask_path = sorted(glob.glob(root_path+'/*/post_label.png'))

# Images and masks are paired purely by sorted position, so every image needs its mask
# in the same directory; fail early instead of training on mismatched pairs.
assert len(img_path) == len(mask_path), "each sample directory needs both img.png and post_label.png"
assert all(os.path.dirname(i) == os.path.dirname(m) for i, m in zip(img_path, mask_path)), \
    "image/mask files are not paired by directory"

imgpaths,maskpaths = shuffle(img_path,mask_path, random_state=0)

# 80/20 train/validation split; a single index keeps images and masks aligned.
n_train = int(len(imgpaths)*0.8)
train_images_path = imgpaths[:n_train]
train_masks_path = maskpaths[:n_train]
val_images_path = imgpaths[n_train:]
val_masks_path = maskpaths[n_train:]

# Light geometric augmentation, applied jointly to image and mask.
transform = A.Compose([
    A.Rotate(p=0.2,limit=30),
    A.HorizontalFlip(p=0.2),
    A.OneOf([
        A.GridDistortion(p=0.1,distort_limit=0.2),
        A.ElasticTransform(sigma=10, alpha=1,  p=0.1)
    ]),
])


In [None]:
batch = 8
lr = 0.0003
wd = 5e-4
epochs = 80
output_path = './'
# Fall back to CPU when no GPU is present instead of failing at the first .to(device).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
experiment_name = 'lung_Unet_densenet121'

In [None]:
class Data_Generate(Dataset):
    """Dataset of grayscale images, optionally with masks, for lung segmentation.

    Images and masks are read with OpenCV as single-channel arrays and divided by 255.
    With ``seg_paths``, an ``(image, mask)`` pair is returned. Without it, only the image
    is returned. Arrays are float32 with a leading channel axis.

    Parameters
    ----------
    img_paths : sequence of str
    seg_paths : sequence of str or None
        Mask paths aligned index-by-index with ``img_paths``.
    transform : callable or None
        Called as ``transform(image=..., mask=...)`` and must return a dict with ``'image'``
        and ``'mask'``. It is only applied when masks are present.
    """
    def __init__(self,img_paths,seg_paths=None,transform=None):
        self.img_paths = img_paths
        self.seg_paths = seg_paths
        self.transform = transform

    @staticmethod
    def _read_gray(path):
        """Read ``path`` as grayscale scaled by 1/255. Fails loudly where the old code hit None/255."""
        image = cv2.imread(path, 0)
        if image is None:
            raise FileNotFoundError("could not read image: {}".format(path))
        return image / 255

    def __getitem__(self,index):
        img = self._read_gray(self.img_paths[index])
        if self.seg_paths is None:
            return img[None,:,:].astype(np.float32)

        mask = self._read_gray(self.seg_paths[index])
        if self.transform is not None:
            # Use this dataset's own transform. The old code called the module-level
            # `transform`, silently ignoring the constructor argument.
            aug = self.transform(image=img, mask=mask)
            img = aug['image']
            mask = aug['mask']

        return img[None,:,:].astype(np.float32), mask[None,:,:].astype(np.float32)

    def __len__(self):
        return len(self.img_paths)

In [None]:
train_db = Data_Generate(train_images_path,train_masks_path,transform=transform)
train_loader = DataLoader(train_db, batch_size=batch, shuffle=True, num_workers=4)
val_db = Data_Generate(val_images_path,val_masks_path,transform=None)
val_loader = DataLoader(val_db, batch_size=batch, shuffle=False, num_workers=4)

In [None]:
f,ax = plt.subplots(4,4,figsize=(16,16))
for i in range(16):
    img = train_db[i][0]
    ax[i//4,i%4].imshow(img[0])

In [None]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Call the instance once per epoch with the validation loss and the model. The model's
    ``state_dict`` is saved to ``path`` whenever the loss improves by more than ``delta``.
    ``early_stop`` becomes True after ``patience`` consecutive calls without such an improvement.
    """
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # `np.Inf` was removed in NumPy 2.0; `np.inf` works on every version.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        # Higher score is better, so the loss is negated.
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [None]:
# U-Net with a DenseNet-121 encoder: one input channel, one sigmoid-activated output channel.
model = smp.Unet(
    'densenet121',
    classes=1,
    in_channels=1,
    activation='sigmoid',
    encoder_weights='imagenet',
).to(device)

# Only parameters that require gradients are handed to Adam.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = optim.Adam(trainable_params, lr=lr, weight_decay=wd)
# Halve the LR after 3 epochs without validation-loss improvement (floor 1e-8).
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, min_lr=1e-8, verbose=True)

criterion = smp.utils.losses.DiceLoss(eps=1.)
iou = smp.utils.metrics.IoU()
checkpoint_file = os.path.join(output_path, f'best_{experiment_name}.pth')
early_stopping = EarlyStopping(patience=6, verbose=True, path=checkpoint_file)


In [None]:
num_train_loader = len(train_loader)
num_val_loader = len(val_loader)
for epoch in range(epochs):
    # Plain Python floats. Accumulating the loss tensors themselves kept a reference to
    # every batch's autograd graph for the whole epoch and printed tensor reprs.
    train_losses,train_score,val_losses,val_score = 0.0,0.0,0.0,0.0
    model.train()

    for image, label in train_loader:
        image, label = image.to(device), label.to(device)
        out = model(image)
        loss = criterion(out, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_losses += loss.item()/num_train_loader
        train_score += iou(out.detach(), label).item()/num_train_loader

    model.eval()
    # Validation needs no autograd at all, including for the loss and the metric.
    with torch.no_grad():
        for image, label in val_loader:
            image, label = image.to(device), label.to(device)
            out = model(image)
            val_losses += criterion(out, label).item()/num_val_loader
            val_score += iou(out, label).item()/num_val_loader
    print('epoch {}/{}\t LR:{}\t train_loss:{}\t train_score:{}\t val_loss:{}\t val_score:{}' \
          .format(epoch+1, epochs, optimizer.param_groups[0]['lr'], train_losses, train_score, val_losses, val_score))
    scheduler.step(val_losses)

    early_stopping(val_losses, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break

In [None]:
class Test_Generate(Dataset):
    """Inference dataset: reads DICOM slices and returns min-max normalised float32 images.

    Each item has shape (1, 512, 512). Slices of any other size are resized with OpenCV.
    The image is quantised to 8 bit and divided by 255, like the PNG training images.
    """
    def __init__(self,img_paths):
        self.img_paths = img_paths

    def __getitem__(self,index):
        dicom = pydicom.dcmread(self.img_paths[index])
        # Work in float so that subtracting the minimum cannot overflow integer pixel data.
        slice_img = dicom.pixel_array.astype(np.float64)
        lo, hi = slice_img.min(), slice_img.max()
        # Min-max scale to [0, 1]. A constant slice used to divide by zero (NaNs); map it to 0.
        if hi > lo:
            slice_img = (slice_img - lo) / (hi - lo)
        else:
            slice_img = np.zeros_like(slice_img)
        slice_img = (slice_img*255).astype(np.uint8)
        # Check both dimensions: a non-square slice with 512 rows used to be left unresized.
        if slice_img.shape != (512, 512):
            slice_img = cv2.resize(slice_img,(512,512))

        slice_img = slice_img[None,:,:]
        slice_img = (slice_img/255).astype(np.float32)
        return slice_img

    def __len__(self):
        return len(self.img_paths)

In [None]:
dicom_root_path = '../input/osic-pulmonary-fibrosis-progression/train/*/*'
# Sorted and seeded, so the same 16 slices are drawn on every run (glob order is
# filesystem-dependent).
dicom_paths = sorted(glob.glob(dicom_root_path))
dicom_paths = random.Random(0).sample(dicom_paths,16)

test_db = Test_Generate(dicom_paths)
test_loader = DataLoader(test_db, batch_size=batch, shuffle=False, num_workers=0)

# Best checkpoint written by early stopping (same path as in the training setup).
# map_location also lets a GPU-trained checkpoint load on CPU.
checkpoint_file = os.path.join(output_path, f'best_{experiment_name}.pth')
model.load_state_dict(torch.load(checkpoint_file, map_location=device))
model.eval()

outs = []
with torch.no_grad():
    for image in test_loader:
        out = model(image.to(device)).cpu().numpy()
        # Binarise the sigmoid output. Drop only the channel axis: np.squeeze would also
        # drop the batch axis of a final one-slice batch and break np.concatenate.
        outs.append(np.where(out > 0.5, 1, 0)[:, 0])

outs = np.concatenate(outs)

In [None]:
# Pairs of (input slice, predicted mask) for the first half of the test predictions.
f,ax = plt.subplots(4,4,figsize=(16,16))
axes = ax.flatten()
for idx, predicted_mask in enumerate(outs[:len(outs)//2]):
    axes[2 * idx].imshow(test_db[idx][0])
    axes[2 * idx + 1].imshow(predicted_mask)