In [None]:
!pip install pydicom
import numpy as np 
import pandas as pd
import pydicom,os,cv2
from glob import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
plt.style.use('seaborn-whitegrid')
import warnings
warnings.simplefilter("ignore")
%matplotlib inline

warnings.filterwarnings("ignore")

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


Installing and importing the Python libraries used to clean, present, and augment the data that is fed to the model.

- Numpy was used to help process and clean the data
- Pandas was used to help create a dataframe that allowed for easy access to the metadata of the images from the training set 
- Pydicom was used to help obtain the metadata from each image, such as the patient's age, sex, and view position.
- Matplotlib was used to help present the data in graph formats: pie chart, bar graph, and line charts
- TensorFlow (Keras) was used to help create an AI model that can determine whether a set of lungs has pneumothorax or not based on x-rays

In [None]:
# Mount Google Drive inside the Colab VM so files stored there can be read.
from google.colab import drive
drive.mount('/content/drive')
import os
# Make the notebook folder the working directory; the relative paths below
# ('train_png', 'pneumothorax/...') are resolved against it.
os.chdir("/content/drive/My Drive/Colab Notebooks")
!ls
data_folder = 'train_png'
train_rle_path = '/content/drive/My Drive/Colab Notebooks/pneumothorax/train-rle.csv'

# Prints True when the 'train_png' folder exists in the working directory.
print(os.path.exists('train_png'))


In [None]:
from glob import glob

# DICOM files sit three directory levels below each split folder.
train = glob('pneumothorax/dicom-images-train/*/*/*.dcm')
train = sorted(train)
test = glob('pneumothorax/dicom-images-test/*/*/*.dcm')
test = sorted(test)

n_train_files, n_test_files = len(train), len(test)
print(f'Number of train dicom files in folder:{n_train_files}')
print(f'Number of test dicom files in folder:{n_test_files}')

# Run-length-encoded annotations, one row per (image, mask) pair.
train_rle_path = 'pneumothorax/train-rle.csv'
df = pd.read_csv(train_rle_path)
res = list(set(df))  # the CSV's column names, de-duplicated
print(df.shape)

n_unique_images = df["ImageId"].nunique()
# Rows whose ImageId already appeared on an earlier row.
n_repeated_rows = df[df.duplicated(subset=["ImageId"])].shape[0]
print(f'Total no of unique images in csv file: {n_unique_images}')
print(f'Images with duplicate Encoded pixels {n_repeated_rows}')

print("-1 means no Pneumothorax")
df.head()

In [None]:

#def convert_images(filename, outdir):
 #   ds = pydicom.read_file(str(filename))
  #  img = ds.pixel_array
   # img = cv2.resize(img, (128, 128))
    #cv2.imwrite(outdir + filename.split('/')[-1][:-4] + '.png', img)
#train_path = 'pneumothorax/dicom-images-train/'
#test_path = 'pneumothorax/dicom-images-train/'
#train_out_path = 'train_png/'
#test_out_path = 'test_png/'
#if not os.path.exists(test_out_path):
 #   os.makedirs(test_out_path)
#import os
#import cv2
#import glob2
#import pydicom
#from joblib import Parallel, delayed
#from tqdm import tqdm_notebook as tqdm
#train_dcm_list = glob2.glob(os.path.join(train_path, '**/*.dcm'))
#test_dcm_list = glob2.glob(os.path.join(test_path, '**/*.dcm'))
#res1 = Parallel(n_jobs=8, backend='threading')(delayed(
 #   convert_images)(i, test_out_path) for i in tqdm(test_dcm_list, total=len(test_dcm_list)))

In [None]:
import pydicom
import matplotlib.pyplot as plt

# Read the first training study once and reuse it for both the image and the
# header dump. `dcmread` is the supported reader; the older `read_file` alias
# is no longer available in current pydicom releases.
data = pydicom.dcmread(train[0])

#displaying the image
img = data.pixel_array
plt.imshow(img, cmap='bone')
plt.grid(False)

#displaying metadata
print(data)


In [None]:
def show_info(dataset, file_path=None):
    """Print a short, human-readable summary of a DICOM dataset's header.

    Parameters
    ----------
    dataset : object
        A DICOM dataset exposing the attributes read below (SOPClassUID,
        PatientName, PatientID, PatientAge, PatientSex, Modality,
        BodyPartExamined, ViewPosition) and supporting ``'Tag' in dataset``.
    file_path : str, optional
        Path shown on the "Filename" line. When omitted, the dataset's own
        ``filename`` attribute is used (``None`` is printed if it has none),
        so the function no longer depends on a notebook-level global.
    """
    if file_path is None:
        file_path = getattr(dataset, "filename", None)

    print("Filename......:", file_path)
    print("Storage type.....:", dataset.SOPClassUID)
    print()

    pat_name = dataset.PatientName
    display_name = pat_name.family_name + ", " + pat_name.given_name
    print("Patient's name......:", display_name)
    print("Patient id..........:", dataset.PatientID)
    print("Patient's Age.......:", dataset.PatientAge)
    print("Patient's Sex.......:", dataset.PatientSex)
    print("Modality............:", dataset.Modality)
    print("Body Part Examined..:", dataset.BodyPartExamined)
    print("View Position.......:", dataset.ViewPosition)

    # Image geometry is only reported when the dataset carries pixel data.
    if 'PixelData' in dataset:
        rows = int(dataset.Rows)
        cols = int(dataset.Columns)
        print("Image size.......: {rows:d} x {cols:d}, {size:d} bytes".format(
            rows=rows, cols=cols, size=len(dataset.PixelData)))
        if 'PixelSpacing' in dataset:
            print("Pixel spacing....:", dataset.PixelSpacing)
def plot_pixel(dataset, figsize=(10,10)):
    """Show ``dataset.pixel_array`` in a new figure using the 'bone' colormap.

    ``figsize`` is forwarded to ``plt.figure`` (width, height in inches).
    """
    plt.figure(figsize=figsize)
    pixels = dataset.pixel_array
    plt.imshow(pixels, cmap=plt.cm.bone)
    plt.show()

In [None]:
# Inspect the first training study: print its header summary, then draw it.
file_path = train[0]
data = pydicom.dcmread(file_path)
show_info(data)
plot_pixel(data)

In [None]:
#dataframe to ease the access
pd.reset_option('max_colwidth')

# Build the ImageId -> annotation lookup once instead of scanning the whole
# CSV for every DICOM file. As before, the FIRST row of each ImageId wins and
# the annotation is taken from the CSV's second column (selected by position).
first_rows = df.drop_duplicates(subset="ImageId")
labels_by_uid = dict(zip(first_rows["ImageId"], first_rows.iloc[:, 1]))

patients = []
missing = 0

for t in train:
    # Only header fields are needed, so skip decoding the pixel data.
    header = pydicom.dcmread(t, stop_before_pixels=True)
    patient = {}
    patient["UID"] = header.SOPInstanceUID
    uid_key = str(patient["UID"])
    if uid_key in labels_by_uid:
        patient["EncodedPixels"] = labels_by_uid[uid_key]
    else:
        # No CSV row for this image: the column stays empty (NaN in the frame).
        missing = missing + 1
    patient["Age"] = header.PatientAge
    patient["Sex"] = header.PatientSex
    patient["Modality"] = header.Modality
    patient["BodyPart"] = header.BodyPartExamined
    patient["ViewPosition"] = header.ViewPosition
    patients.append(patient)

print("missing labels: ", missing)
df_patients = pd.DataFrame(patients, columns=["UID", "EncodedPixels", "Age", "Sex", "Modality", "BodyPart", "ViewPosition"])
# Report the images that actually have a label (total minus the missing ones),
# so this line and the "missing labels" line add up to the number of files.
print("images with labels: ", df_patients.shape[0] - missing)
df_patients.head()

In [None]:
import matplotlib as mpl
import numpy as np

total = df_patients.shape[0]

# Row masks shared by all the counts below.
is_male = df_patients["Sex"] == "M"
is_female = df_patients["Sex"] == "F"
# "-1" marks an image without pneumothorax. Surrounding whitespace is stripped
# so both "-1" and " -1" are recognised; rows without a label are not healthy.
is_healthy = df_patients["EncodedPixels"].astype(str).str.strip() == "-1"

#gender
men = int(is_male.sum())
women = total - men  # every row that is not recorded as "M"
print(men, women)


#illness
healthy = int(is_healthy.sum())
ill = total - healthy
print(healthy, ill)

#gender + illness
men_h = int((is_male & is_healthy).sum())
men_ill = men - men_h
women_h = int((is_female & is_healthy).sum())
women_ill = women - women_h
print(men_h, men_ill, women_h, women_ill)


def _pct(count):
    """Return `count` as a percentage of all patients, one decimal, as text."""
    # Derived from the actual table size instead of a hard-coded divisor, so
    # the labels stay correct when the number of images changes.
    return str(round(100 * count / total, 1)) if total else "0.0"


perc = [_pct(men_ill) + "% \n ill", "healthy \n" + _pct(men_h) + "%", "healthy \n" + _pct(women_h) + "%", _pct(women_ill) + "% \n ill"]

fig, ax = plt.subplots(1, 3, figsize=(15, 5))

fig.suptitle("4.1 Gender and Pneumothorax distributions", fontsize=24, y=1.1)

mpl.rcParams['font.size'] = 12.0

#circle for donut chart (white disc drawn over the pie centre)
circle0 = plt.Circle( (0,0), 0.6, color = 'white')
circle2 = plt.Circle( (0,0), 0.6, color = 'white')

#men women
ax[0].pie([men, women], labels=["men", "women"], colors=["#42A5F5", "#E57373"], autopct='%1.1f%%', pctdistance=0.8, startangle=90)
ax[0].add_patch(circle0)
ax[0].axis('equal')

#gender healthy: outer ring = gender, inner ring = ill/healthy split per gender
mypie, _ = ax[2].pie([men, women], radius=1.3, labels=["men", "women"], colors=["#42A5F5", "#E57373"], startangle=90)
plt.setp( mypie, width=0.3, edgecolor='white')

mypie2, _ = ax[2].pie([ men_ill, men_h, women_h, women_ill], radius = 1.3 - 0.3, labels=perc, labeldistance=0.61,
                      colors = ["#FFB74D", "#9CCC65", "#9CCC65", "#FFB74D"], startangle=90)
plt.setp( mypie2, width=0.4, edgecolor='white')
plt.margins(0,0)

#healthy ill
ax[1].pie([healthy, ill], labels=["healthy", "ill"], colors=["#9CCC65", "#FFB74D"], autopct='%1.1f%%', pctdistance=0.8, startangle=135)
ax[1].add_patch(circle2)
ax[1].axis('equal')  

plt.tight_layout()
plt.show()

In [None]:
import numpy as np
# Prepare ages for the per-age histograms of all vs. ill men and women.

# Make the Age column numeric, then show every age in ascending order.
df_patients["Age"] = pd.to_numeric(df_patients["Age"])

sorted_ages = df_patients["Age"].to_numpy().copy()
sorted_ages.sort()
print(sorted_ages)

In [None]:
#calculating all and ill men and women histograms
# One-year bins with edges 0..99; np.histogram ignores ages outside that range.
bins = list(range(100))

# matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# later releases dropped the old names; use whichever this install provides.
whitegrid_style = next(
    (style for style in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid')
     if style in plt.style.available),
    'default')
plt.style.use(whitegrid_style)

is_man = df_patients["Sex"] == "M"
is_woman = df_patients["Sex"] == "F"
has_pneumothorax = df_patients["EncodedPixels"] != ' -1'


def _age_counts(row_mask):
    """Number of selected patients per one-year age bin."""
    return np.histogram(df_patients[row_mask]["Age"].values, bins=bins)[0]


all_men = _age_counts(is_man)
all_women = _age_counts(is_woman)

ill_men = _age_counts(is_man & has_pneumothorax)
ill_women = _age_counts(is_woman & has_pneumothorax)

fig, axes = plt.subplots(ncols=2, sharey=True, figsize=(17, 16))

fig.suptitle("4.3 The presence of Pneumothorax at particular ages and genders", fontsize=22, y=0.96)

# Left panel: men. The dark "ill" bars are drawn over the light "all" bars, so
# the part of a light bar that stays visible is the healthy share.
axes[0].margins(x=0.1, y=0.01)
m1 = axes[0].barh(bins[:-1], all_men, color='#90CAF9')
m2 = axes[0].barh(bins[:-1], ill_men, color='#0D47A1')
axes[0].set_title('Men', fontsize=18, pad=15)
axes[0].invert_xaxis()
axes[0].set(yticks=[i*5 for i in range(20)])
axes[0].tick_params(axis="y", labelsize=14)
axes[0].yaxis.tick_right()
axes[0].xaxis.tick_top()
axes[0].legend((m1[0], m2[0]), ('healthy', 'with Pneumothorax'), loc=2, prop={'size': 16})

# Reuse the left panel's tick positions so both panels share one scale.
locs = axes[0].get_xticks()

# Right panel: women, same overlay technique.
axes[1].margins(y=0.01)
w1 = axes[1].barh(bins[:-1], all_women, color='#EF9A9A')
w2 = axes[1].barh(bins[:-1], ill_women, color='#B71C1C')
axes[1].set_title('Women', fontsize=18, pad=15)
axes[1].xaxis.tick_top()
axes[1].set_xticks(locs)
axes[1].legend((w1[0], w2[0]), ('healthy', 'with Pneumothorax'), prop={'size': 17})

plt.show()

In [None]:
import matplotlib.pyplot as plt

# Ages 1..99, one x position per entry of Pneumo_men.
Age = list(range(1, 100))

# Hard-coded per-age counts of men with pneumothorax (99 values).
# NOTE(review): presumably copied from the `ill_men` histogram, whose bins
# start at age 0 - confirm whether these should be plotted against 0..98.
Pneumo_men = [0, 0, 0, 0, 0, 0, 1, 0, 0, 3, 4, 0, 0, 3, 4, 28, 55, 17, 15, 13, 21, 33, 15, 18, 30, 10, 20, 3, 27, 17, 39, 17, 16, 22, 24, 28, 2, 10, 8, 10, 23, 33, 12, 5, 13, 15, 28, 40, 15, 27, 14, 41, 29, 31, 30, 19, 24, 27, 43, 34, 41, 14, 24, 32, 52, 20, 28,  7,  7,  6, 22, 13, 11, 10,  1,  2,  2,  7, 12,  0,  8,  7,  0,  6,  0,  0,  0,  1,  0,  0,  0,  0,  0,  0,  0, 0, 0, 0, 0]

# The default figure size is read when a figure is created, so it has to be
# set BEFORE the first plotting call to affect this chart.
plt.rcParams['figure.figsize'] = [17,16]

plt.plot(Age, Pneumo_men, color='blue', marker='o')
plt.title('4.4 Presence of Pneumothorax in Men at Increasing Ages',fontsize=25)
plt.xlabel('Age', fontsize=20)
plt.ylabel('Number of Men with Pneumothorax', fontsize=20)
plt.grid(True)
plt.show()


In [None]:
import matplotlib.pyplot as plt

# Ages 1..99, one x position per entry of Pneumo_Women.
Women_Age = list(range(1, 100))

# Hard-coded per-age counts of women with pneumothorax.
# NOTE(review): presumably copied from the `ill_women` histogram, whose bins
# start at age 0 - confirm whether these should be plotted against 0..98.
Pneumo_Women = [0,  0,  0,  0,  0,  0,  0,  2,  0,  0,  1,  8,  1, 16,  3,  5,  5,  5,  5,  7, 14,  6,  7, 15, 39,  6,  8,  0,  9,  4, 23, 12, 14,  9, 10, 27, 23,  1,  8, 28, 11, 12, 16, 30, 12, 27, 20, 25, 43, 25, 19, 41, 20, 21, 10, 18, 32, 29, 22, 10, 33, 14, 17, 28, 20, 32, 22, 11, 21, 11, 6, 7, 7, 29, 23,  2,  5,  2,  5,  0,  0, 2,  0,  4,  0,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0]

# The default figure size is read when a figure is created, so it has to be
# set BEFORE the first plotting call to affect this chart.
plt.rcParams['figure.figsize'] = [17,16]

# Plot against this cell's own x values rather than a list defined elsewhere.
plt.plot(Women_Age, Pneumo_Women, color='red', marker='o')
plt.title('4.5 Presence of Pneumothorax in Women at Increasing Ages',fontsize=25)
plt.xlabel('Age', fontsize=20)
plt.ylabel('Number of Women with Pneumothorax', fontsize=20)
plt.grid(True)
plt.show()

In [None]:
def run_length_decode(rle, height=1024, width=1024, fill_value=1):
    """Decode a relative run-length string into a 2-D float32 mask.

    The string is a whitespace-separated list of ``offset length`` pairs. Each
    offset is counted from the end of the previous run (from 0 for the first
    run), over the mask flattened in column-major order.

    Parameters
    ----------
    rle : str
        Encoded runs. An empty string or the ``-1`` sentinel ("no mask")
        yields an all-zero mask.
    height, width : int
        Shape of the returned mask.
    fill_value : number
        Value written into the pixels covered by a run.

    Returns
    -------
    numpy.ndarray of shape ``(height, width)`` and dtype float32.

    Raises
    ------
    ValueError
        If the string holds an odd number of values or a non-integer token.
    """
    component = np.zeros(height * width, np.float32)

    # split() without an argument tolerates leading/trailing/repeated spaces.
    tokens = rle.split()
    if tokens and tokens != ['-1']:
        if len(tokens) % 2:
            raise ValueError(
                'run-length string must contain (offset, length) pairs, '
                'got {} values'.format(len(tokens)))
        pairs = np.array([int(tok) for tok in tokens]).reshape(-1, 2)
        start = 0
        for index, length in pairs:
            start = start + index      # offset relative to the previous run's end
            end = start + length
            component[start:end] = fill_value
            start = end

    # The runs index a column-major flattening, hence reshape + transpose.
    return component.reshape(width, height).T

def run_length_encode(component):
    """Encode a 2-D mask as a relative run-length string.

    Every pixel greater than 0 counts as foreground. The mask is flattened in
    column-major order and each run is written as ``offset length``, where the
    offset is counted from the end of the previous run (from 0 for the first).

    Parameters
    ----------
    component : numpy.ndarray
        2-D mask.

    Returns
    -------
    str
        Space-separated ``offset length`` pairs; the empty string when the
        mask has no foreground pixel.
    """
    flat = (np.asarray(component).T.flatten() > 0).astype(np.int8)

    # Pad with background on both sides so that a run touching the first or
    # the last pixel still produces both a rising and a falling edge.
    edges = np.diff(np.concatenate(([0], flat, [0])))
    start = np.where(edges == 1)[0]    # index of the first pixel of each run
    end = np.where(edges == -1)[0]     # index one past the last pixel of each run
    if start.size == 0:
        return ''

    length = end - start
    offsets = start.copy()
    offsets[1:] -= end[:-1]            # make offsets relative to the previous run's end

    # Interleave as offset0 length0 offset1 length1 ...
    rle = np.column_stack((offsets, length)).ravel()
    return ' '.join(str(r) for r in rle)

In [2]:
#Start of Model
import os, sys, math, re, gc, random
from time import time, strftime, gmtime
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from IPython.display import display, HTML

import tensorflow as tf
import tensorflow.keras.layers as L
import tensorflow.keras.backend as K
from sklearn.model_selection import KFold

# Fetch the helper code only once: the 'tpu_segmentation' folder doubles as the
# marker that the clone, its requirements and the utils file are already there.
if not os.path.isdir('tpu_segmentation'):
    !git clone -q https://github.com/reyvaz/tpu_segmentation.git
    !pip install -qr tpu_segmentation/requirements.txt >/dev/null
    !wget -q https://raw.githubusercontent.com/reyvaz/pneumothorax_detection/master/pneumothorax_utils.py
# Star imports: helper names used by later cells but not defined in this
# notebook (e.g. current_time_str) presumably come from these two modules.
from tpu_segmentation import *
from pneumothorax_utils import *

# Wall-clock reference for the whole run.
start_notebook = time()
print('Notebook started at: ', current_time_str())
print('Tensorflow version: ', tf.__version__)
# Only show TensorFlow log records of level ERROR and above.
tf.get_logger().setLevel('ERROR')

Notebook started at:  09:36 AM
Tensorflow version:  2.8.2


In [None]:
# Pick the distribution strategy once per kernel session. Referencing `tpu`
# raises NameError only on the first run of this cell; on a re-run the existing
# connection is kept and the TPU system is not initialised again.
try:
    tpu
except NameError:
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        print('Running on TPU ', tpu.master())
    except ValueError:
        # The resolver raises ValueError when no TPU can be located.
        tpu = None
        print('TPU not found')
    if tpu:
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        strategy = tf.distribute.TPUStrategy(tpu)
    else:
        # Default distribution strategy. Works on CPU and single GPU.
        strategy = tf.distribute.get_strategy()

In [None]:
# Google Cloud Storage bucket that holds the TFRecord files.
GCS_PATH = 'gs://kds-40bc87f88cf0cbbde64e128ce935b6696430f5ec557d0d961d95815c'
IMAGE_SIZE = [1024, 1024] # original size of the x-ray images
N_CLASSES = 1
N_CHANNELS = 1
# Number of devices the strategy replicates onto; used to scale batch sizes.
N_REPLICAS = strategy.num_replicas_in_sync
classes = ['No Disease', 'Disease']

TFRecs_gcs_path = GCS_PATH + '/tfrecs/'
# TFRecord files are split by file name into "disease" and "no-disease" sets.
TFRECS_TRAIN_RLE = tf.io.gfile.glob(TFRecs_gcs_path + '*train-disease*.tfrec')
TFRECS_TRAIN_NORLE = tf.io.gfile.glob(TFRecs_gcs_path + '*train-no-disease*.tfrec')

# Number of files (not examples) in each set.
N_TFRECS_MASK = len(TFRECS_TRAIN_RLE)
N_TFRECS_NOMASK = len(TFRECS_TRAIN_NORLE)

In [None]:
N_FOLDS = 5
skf = KFold(n_splits=N_FOLDS)

# folds[k] -> {'val': file indices, 'train': file indices}, with k starting at 1.
# The split is made over the indices of the "disease" TFRecord files.
folds = {}
for fold, (idxT, idxV) in enumerate(skf.split(np.arange(N_TFRECS_MASK)), start=1):
    folds[fold] = {'val': idxV, 'train': idxT}
del fold

def get_fold_file_lists(fold_num, folds=folds, use_unmasked = False):
    """Return ``(training_files, validation_files)`` for one fold.

    The fold's index arrays select files from TFRECS_TRAIN_RLE; with
    ``use_unmasked`` the same indices additionally select files from
    TFRECS_TRAIN_NORLE, appended after the masked ones.
    """
    # NOTE(review): the same index arrays address both file lists - assumes
    # both lists have the same number of files; confirm.
    train_idx = folds[fold_num]['train']
    val_idx = folds[fold_num]['val']

    file_lists = [TFRECS_TRAIN_RLE]
    if use_unmasked:
        file_lists.append(TFRECS_TRAIN_NORLE)

    TRAINING_FILENAMES = [files[i] for files in file_lists for i in train_idx]
    VALIDATION_FILENAMES = [files[i] for files in file_lists for i in val_idx]
    return TRAINING_FILENAMES, VALIDATION_FILENAMES

# Example counts per file set, as reported by the count_data_items helper
# (not defined in this notebook).
masked_examples = count_data_items(TFRECS_TRAIN_RLE)
unmasked_examples = count_data_items(TFRECS_TRAIN_NORLE)
# Unmasked examples per masked example.
class_ratio = unmasked_examples/masked_examples

print('Number of MASKED examples for training and validation:   ', masked_examples)
print('Number of NON MASKED examples for training and validation:', unmasked_examples)

In [None]:
# Passed as the parallelism / prefetch argument of the tf.data calls below so
# that tf.data chooses those values itself.
AUTO = tf.data.experimental.AUTOTUNE

def read_tfrecord(example, vars = ('image', 'rle', 'label')):
    """Parse one serialized example and return the requested fields as a list.

    ``vars`` names the fields to extract (any of 'img_id', 'image', 'rle',
    'label'); the returned list follows the order given in ``vars``.
    """
    schema = {
        'img_id': tf.io.FixedLenFeature([], tf.string),
        'image': tf.io.FixedLenFeature([], tf.string),
        'rle': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    # Restrict the schema to the requested fields before parsing.
    wanted = {name: schema[name] for name in vars}
    parsed = tf.io.parse_single_example(example, wanted)
    return [parsed[name] for name in wanted]
        
def load_dataset(filenames, ordered = False):
    """Open TFRecord files and parse every record with ``read_tfrecord``.

    Unless ``ordered`` is true, deterministic ordering is switched off on the
    dataset options so records may be delivered in any order.
    """
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    if ordered:
        return records.map(read_tfrecord, num_parallel_calls=AUTO)

    any_order = tf.data.Options()
    any_order.experimental_deterministic = False
    records = records.with_options(any_order)
    return records.map(read_tfrecord, num_parallel_calls=AUTO)

In [None]:
# Keep/drop predicate: returns False (drop) when the label equals 0 AND a fresh
# uniform random draw is below P; returns True (keep) in every other case.
# Inside this lambda the parameter `L` shadows the keras-layers alias `L`.
undersample_filter = lambda L, P: False if L == 0 and tf.random.uniform([]) < P else True
# randomly filters-out examples with label 0. L is the label, P is the rate to exclude

def decode_resize_inputs(inputs, target_size, image_size = IMAGE_SIZE,
                         n_channels = N_CHANNELS, n_classes = N_CLASSES):
    """Turn ``(jpeg_bytes, rle, *rest)`` into ``(image, mask, *rest)``.

    The image is JPEG-decoded and scaled by 1/255 as float32; the mask is
    built from the run-length string by ``build_mask_array`` at ``image_size``.
    Both are resized when ``target_size != image_size`` and finally reshaped
    to ``[*target_size, n_channels]`` / ``[*target_size, n_classes]``.
    Any elements after the first two are passed through unchanged.
    """
    image_data, rle = inputs[0], inputs[1]
    passthrough = inputs[2:]

    decoded = tf.image.decode_jpeg(image_data, channels=n_channels)
    image = tf.cast(decoded, tf.float32) / 255.0
    mask = build_mask_array(rle, image_size)

    # NOTE(review): plain Python comparison - a tuple never equals a list, so a
    # tuple target_size always takes the resize branch against the list default.
    needs_resize = target_size != image_size
    if needs_resize:
        image = tf.image.resize(image, target_size)
        mask = tf.image.resize(mask, target_size)

    image = tf.reshape(image, list(target_size) + [n_channels])
    mask = tf.reshape(mask, list(target_size) + [n_classes])
    return (image, mask, *passthrough)

def data_augment(inputs, target_size, 
                 n_channels = N_CHANNELS, n_classes = N_CLASSES, 
                 p1=0.50, p2=0.33, p3=0.33, p4=0.75):
    """Randomly augment one ``(image, mask, label, *rest)`` example.

    Each stage draws its own uniform random number:
      * p1 - threshold for the left/right flip of image and mask;
      * p2 - threshold for a random rotation; if that draw fails, a second
        draw against p2 decides on a random shear instead;
      * p3 - threshold for zoom-out-and-pan; if that draw fails, a second draw
        against ``p3 * 1.5`` decides on a zoom-in instead;
      * p4 - threshold for coarse dropout on the image.
    Random brightness and contrast are always applied to the image only.
    The geometric helpers (left_right_flip, random_rotate, random_shear,
    random_zoom_out_and_pan, image_mask_zoom_in, coarse_dropout) are not
    defined in this notebook.

    Returns ``(image, mask, label, *rest)`` with ``label`` and ``rest``
    passed through unchanged.
    """
    
    (image, mask, label), args = inputs[:3], inputs[3:]

    if tf.random.uniform([]) < p1:
        image, mask = left_right_flip(image, mask)

    # Rotation and shear are mutually exclusive.
    if tf.random.uniform([]) < p2:
        image, mask = random_rotate(image, target_size, n_channels, mask, 
                                    n_classes, 7.)
    elif tf.random.uniform([]) < p2:
        image, mask = random_shear(image, target_size, n_channels, mask)

    # Zoom-out and zoom-in are mutually exclusive.
    if tf.random.uniform([]) < p3: 
        image, mask = random_zoom_out_and_pan(image, target_size, mask, n_channels)
    elif tf.random.uniform([]) < p3*1.5: 
        image, mask = image_mask_zoom_in(image, mask, target_size, label, n_channels)
    
    # Photometric changes touch the image only; the mask stays as it is.
    image = tf.image.random_brightness(image, 0.1)
    image = tf.image.random_contrast(image, 0.7, 1.4)

    if tf.random.uniform([]) < p4: 
        image = coarse_dropout(image, target_size, n_channels, 
                               count_range=(20, 150), m_size = 0.01)
    return (image, mask, label, *args)

def final_reshape(inputs, target_size, target_var, make_rgb, augment,
                  n_channels = N_CHANNELS, n_classes = N_CLASSES):
    '''
    Produces the final (image, target) pair for one example.

    With `make_rgb` the grayscale image is expanded to 3 channels, and, if
    `augment` is also set, random hue and saturation changes are applied.
    `target_var` selects the target: 'mask' returns the reshaped mask,
    'label' returns the label cast to int32; any other value raises Exception.
    '''
    image, mask, label = inputs[:3]

    if make_rgb:
        image = tf.image.grayscale_to_rgb(image)
        n_channels = 3
        if augment:
            # Colour augmentations need the 3-channel image.
            image = tf.image.random_hue(image, 0.025)
            image = tf.image.random_saturation(image, 0.6, 1.4)

    image = tf.reshape(image, list(target_size) + [n_channels])

    if target_var == 'mask':
        return image, tf.reshape(mask, list(target_size) + [n_classes])
    if target_var == 'label':
        return image, tf.cast(label, tf.int32)
    raise Exception("target_var must be one of 'label' or 'mask'")

def describe_ds(x):
    """Print ``str(x)`` with every '<' and '>' character removed."""
    text = str(x)
    print(text.replace('<', '').replace('>', ''))

def print_description(train_ds, val_ds, steps_per_epoch, use_unmasked, 
                      p_undersample, n_train, n_valid):
    """Print the training dataset description, steps per epoch and example counts.

    The training count is tagged 'approx' when undersampling is active
    (``p_undersample`` truthy and not 1) together with ``use_unmasked``.
    ``val_ds`` is accepted but not used.
    """
    describe_ds(train_ds)
    print('Steps per epoch: ', steps_per_epoch)

    if p_undersample and p_undersample != 1 and use_unmasked:
        approx = 'approx'
    else:
        approx = ''

    print('Num train examples {} {}'.format(n_train, approx))
    print('Num valid examples {}'.format(n_valid))
    return None

In [None]:
def get_dataset(filenames, target_size, batch_size, target_var = 'mask',
                make_rgb = True, 
                augment = False,
                cache   = False,
                repeat  = False, 
                shuffle = False,
                ordered = False, 
                drop_remainder = False,
                p_undersample = False):
    """Build the batched tf.data pipeline for a list of TFRecord files.

    Stages, in order: load/parse -> optional undersampling filter -> decode and
    resize -> optional augmentation -> final reshape / target selection ->
    optional cache -> optional repeat -> optional shuffle -> batch -> prefetch.
    ``shuffle`` is used both as the on/off switch and as the shuffle buffer
    size; ``p_undersample`` is forwarded to ``undersample_filter`` together
    with the third element of each example.
    """
    def _keep(*example):
        return undersample_filter(example[2], p_undersample)

    def _decode(*example):
        return decode_resize_inputs(example, target_size)

    def _augment(*example):
        return data_augment(example, target_size)

    def _finalize(*example):
        return final_reshape(example, target_size, target_var, make_rgb, augment)

    pipeline = load_dataset(filenames, ordered)
    if p_undersample:
        pipeline = pipeline.filter(_keep)
    pipeline = pipeline.map(_decode, AUTO)
    if augment:
        pipeline = pipeline.map(_augment, AUTO)
    pipeline = pipeline.map(_finalize, AUTO)

    if cache:
        pipeline = pipeline.cache()
    if repeat:
        pipeline = pipeline.repeat()
    if shuffle:
        pipeline = pipeline.shuffle(shuffle, reshuffle_each_iteration=True)

    pipeline = pipeline.batch(batch_size, drop_remainder=drop_remainder)
    return pipeline.prefetch(AUTO)

def get_datasets(fold_num, target_size, imgs_per_replica, target_var = 'mask',
             use_unmasked = False, p_undersample = False, cache_val = True, print_descr = True):
    """Create the training and validation datasets of one fold.

    Parameters
    ----------
    fold_num : int
        Key into the module-level fold table.
    target_size : sequence of two ints
        Height and width the images/masks are resized to.
    imgs_per_replica : int
        Per-replica training batch size (multiplied by N_REPLICAS).
    target_var : {'mask', 'label'}
        Target returned by the datasets.
    use_unmasked : bool
        Also include the "no-disease" files.
    p_undersample : float or False
        Rate at which label-0 training examples are dropped.
    cache_val : bool
        Cache the validation dataset.
    print_descr : bool
        Print a short description of the result.

    Returns
    -------
    (datasets, steps) : two dicts keyed by 'train' and 'valid'.
    """
    TRAINING_FILENAMES, VALIDATION_FILENAMES = get_fold_file_lists(
        fold_num, use_unmasked=use_unmasked)
    
    train_batch_size = imgs_per_replica * N_REPLICAS

    n_train = count_data_items(TRAINING_FILENAMES)
    n_valid = count_data_items(VALIDATION_FILENAMES)

    if p_undersample and use_unmasked:
        # approx number of examples given random undersampling, val is never undersampled
        n_train = int(n_train * (1-(p_undersample*class_ratio)/(1+class_ratio)))

    train_steps = np.ceil(n_train/train_batch_size).astype(int)
    buffer_size = int(n_train*0.33)

    # Shuffle at file level too, in addition to the example-level buffer.
    random.shuffle(TRAINING_FILENAMES)
    train_dataset = get_dataset(TRAINING_FILENAMES, target_size, train_batch_size, 
                                target_var, 
                                augment = True, 
                                repeat = True, 
                                shuffle = buffer_size,
                                drop_remainder=True, 
                                p_undersample = p_undersample)
    
    val_batch_size = 11 * N_REPLICAS # constant validation dataset for dice metric comparison
    # 11 * N_REPLICAS minimizes wasted examples in all folds due to drop_remainder. 

    # Segmentation validation drops the incomplete last batch, so only FULL
    # batches count: floor(n / b). (ceil(n / b) - 1 gives the same number except
    # when n is an exact multiple of b, where it is one batch short.)
    drop_remainder_val = target_var == 'mask'
    if drop_remainder_val:
        val_steps = np.floor(n_valid/val_batch_size).astype(int)
    else:
        val_steps = np.ceil(n_valid/val_batch_size).astype(int)

    valid_dataset = get_dataset(VALIDATION_FILENAMES, target_size, val_batch_size, target_var, 
                                cache = cache_val, ordered = True,
                                drop_remainder = drop_remainder_val)
    
    if print_descr:
        print_description(train_dataset, valid_dataset, train_steps, 
                          use_unmasked, p_undersample, n_train, n_valid)
    steps = {'train': train_steps, 'valid': val_steps}
    datasets = {'train': train_dataset, 'valid': valid_dataset}
    return datasets, steps

In [None]:
# NOTE(review): PROJECT_DIR, ADD_PRETRAINED and weights_dir are not defined in
# any cell of this notebook - presumably provided by the star imports or by a
# removed cell; confirm before running on a fresh kernel.
zip_dest = PROJECT_DIR + 'pneumothorax-fold-1-weights.zip'
if ADD_PRETRAINED: 
    # Download the archive only when it is not already on disk.
    if not os.path.isfile(zip_dest):
        !gdown https://drive.google.com/uc?id=1ptjR8KYSg64CZOvp4vVH3GyTROvjETyk -O {zip_dest}
    # -n: never overwrite files that already exist in the weights folder.
    !unzip -qn {zip_dest} -d {weights_dir} 

In [None]:
# File names (without directory) of every .h5 file in the weights folder.
saved_weights = [w.split('/')[-1] for w in tf.io.gfile.glob(weights_dir+'*.h5')]

!mkdir -p {weights_dir}discarded

# Keep only the best file per '<prefix>_<size>' key; every other file is moved
# to the 'discarded' sub-folder.
weights_metrics = {}   # key -> best metric seen so far
weights_names = {}     # key -> file name holding that metric
for w in saved_weights:
    # File names must consist of exactly four '_'-separated parts.
    prefix, sz, _ , metric = w.split('_')
    prefix = '{}_{}'.format(prefix, sz)
    # The metric is the integer part of the last name component, scaled by 10e-6 (= 1e-5).
    metric_float = float(metric.split('.')[0])*10e-6
    if not prefix in weights_metrics:
        weights_metrics[prefix] = metric_float
        weights_names[prefix] = w
    elif metric_float > weights_metrics[prefix]:
        # New best for this key: retire the previous best file.
        !mv {weights_dir}{weights_names[prefix]} {weights_dir}discarded/
        weights_metrics[prefix] = metric_float
        weights_names[prefix] = w
    else: 
        !mv {weights_dir}{w} {weights_dir}discarded/

def check_and_save(history, model, fold_num, img_size, metric, metric_abbr = 'acc',
                   current_wname = 'weights.h5', weights_dir=weights_dir):
    current_metric = max(history.history[metric])
    size_str = str(img_size[0]) + 'x' + str(img_size[1])
    prefix = '{}-f{}_{}'.format(model.name.lower(), fold_num, size_str)
    metric_str = str(current_metric*10e4)[:5]
    weights_name = '{}_{}_{}.h5'.format(prefix, metric_abbr, metric_str)
    if not prefix in weights_metrics or current_metric >= weights_metrics[prefix]: 
        weights_names[prefix] = weights_name
        if prefix in weights_metrics:
            !rm -r {weights_dir}{prefix}*
        !cp {current_wname} {weights_dir}{weights_name}
        weights_metrics[prefix] = current_metric
        saved_weights.append(weights_name)
    return None

def create_df_row(w):
    """Split a saved-weights file name into the fields of one table row.

    Expected pattern:
    '<base>-<ver>-<type>-f<fold>_<H>x<W>_<abbr>_<score*1e5>.h5'

    Parameters
    ----------
    w : str
        File name of the saved weights.

    Returns
    -------
    list
        ``[key_id, backbone, fold_num, size, score, metric, model_type, w]``
        where ``size`` is an ``(H, W)`` tuple of ints and ``score`` a float.

    Raises
    ------
    ValueError
        If the name does not follow the pattern.
    """
    base, size, _, score = w.split('_')
    key_id = '{}_{}'.format(base, size)
    score = float(score.split('.')[0]) * 1e-5
    # Parse 'HxW' explicitly rather than evaluating text taken from a file name.
    size = tuple(int(dim) for dim in size.split('x'))
    base, efn_ver, model_type, fold_num = base.split('-')
    metric = 'accuracy' if 'bin' in model_type else 'avg image-wise dice'
    if 'bin' in model_type: model_type = 'binary' 
    elif 'pp' in model_type: model_type = 'unet++' 
    else: model_type = 'unet' 
    # fold_num looks like 'f3'; read every digit so folds >= 10 parse correctly.
    backbone, fold_num = base+efn_ver, int(fold_num[1:])
    return [key_id, backbone, fold_num, size, score, metric, model_type, w]

def get_best_weights(df, n_per_fold = 5, sort_by_col = 'score', 
                     group_by_col = 'fold', mode = 'max'):
    """Return the top rows of ``df`` within each group.

    Parameters
    ----------
    df : pandas.DataFrame
    n_per_fold : int
        Number of rows kept per group.
    sort_by_col : str
        Column that ranks the rows.
    group_by_col : str
        Column that defines the groups.
    mode : str
        'max' keeps the largest values, anything else the smallest.

    Returns
    -------
    pandas.DataFrame
        The selected rows, grouped in group-key order, with a fresh index.
    """
    ranked = df.groupby(group_by_col)[sort_by_col]
    best = ranked.nlargest(n_per_fold) if mode == 'max' else ranked.nsmallest(n_per_fold)
    # The per-group result is normally indexed by (group key, original label);
    # the original row label is the last level. Check the index type explicitly
    # instead of catching every exception.
    if isinstance(best.index, pd.MultiIndex):
        idxs = list(best.index.get_level_values(-1))
    else:
        idxs = list(best.index)
    return df.loc[idxs].reset_index(drop=True)

In [None]:
EPOCHS = 30
use_unmasked = True
p_undersample = 0.50 # rate to exclude x-rays labeled 0

# performance_monitor and lr_schedule_builder are not defined in this notebook.
metric_monitor = performance_monitor('val_accuracy', 'max')
loss = tf.keras.losses.BinaryCrossentropy(label_smoothing=0.25)

# NOTE(review): the meaning of the six schedule parameters is defined by
# lr_schedule_builder - document them here once confirmed.
lr_params =  [3e-4,  3e-4,  1e-6, 2, 4, 8e-1]
lr_sched = lr_schedule_builder(lr_params)
# Nadam optimizer with the Keras default hyper-parameters.
nadam = tf.keras.optimizers.Nadam() 

In [None]:
# Training grid: every combination of fold, image side length (pixels) and
# EfficientNet version listed here is trained.
folds_to_train = [1] 
dims_to_train = [512] 
efficientnet_versions = [1]

In [None]:
# Train one binary classifier per (fold, image size, EfficientNet version).
# NOTE(review): TRAINING and hline are not defined in this notebook -
# presumably they come from the star imports or a removed cell; confirm.
if TRAINING:
    for training_fold in folds_to_train:
        for dim in dims_to_train: 
            target_size = (dim, dim)
            print(hline)
            # Larger images: smaller per-replica batch and no validation cache.
            cache_val, imgs_per_replica = (True, 8) if dim <= 512 else (False, 4)
            datasets, steps = get_datasets(training_fold, target_size, imgs_per_replica , 
                                'label', use_unmasked, p_undersample, cache_val)
            INPUT_SHAPE  = (*target_size, 3)
            for efn_ver in efficientnet_versions:
                base_model = 'EfficientNet-B{}'.format(efn_ver)
                print(hline +'\nTraining {} on FOLD {} with image size {}\n'.format(
                        base_model, training_fold, target_size) + hline)
                
                with strategy.scope():                       
                    model = build_classifier(base_model, N_CLASSES, INPUT_SHAPE, 
                                                name_suffix='-bin')
                    # A fresh optimizer per model, created inside the strategy
                    # scope: no optimizer state (step counter, moment slots) is
                    # carried over from a previously trained model.
                    optimizer = tf.keras.optimizers.Nadam()
                    model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy', 'AUC']) 

                checkpoint = config_checkpoint(monitor ='val_accuracy', mode = 'max')
                train_begin = time()

                history = model.fit(datasets['train'], steps_per_epoch=steps['train'],
                        epochs = EPOCHS,
                        verbose = 0, 
                        callbacks=[lr_sched, metric_monitor, checkpoint],
                        validation_data = datasets['valid'])
                
                check_and_save(history, model, training_fold, target_size, 'val_accuracy')
                print('Time to train {} epochs: {} (mm:ss)\n'.format(
                    EPOCHS, time_passed(train_begin)))
                
                # Release the model and Keras' global state before the next run.
                del model
                K.clear_session()

    del training_fold, dim, target_size, datasets, INPUT_SHAPE, imgs_per_replica, history

In [None]:
# NOTE(review): `img_size`, `batch_size` and `model` are not defined in any
# cell above (the training loop deletes its `model`); they must be provided -
# e.g. by rebuilding the classifier and loading the saved weights - before this
# cell can run on a fresh kernel.

# `glob` is the function imported with `from glob import glob`, so it is called
# directly; sorting makes the prediction order reproducible.
test_fn = sorted(glob('/content/drive/MyDrive/Colab Notebooks/test_png/*/*.png'))

# Read each PNG as a single-channel image with OpenCV and resize it to the
# model's input side length.
x_test = np.array([
    cv2.resize(cv2.imread(fn, cv2.IMREAD_GRAYSCALE), (img_size, img_size))
    for fn in test_fn])

# Grayscale -> 3 identical channels, scaled to [0, 1] to match the training
# pipeline, which divides decoded images by 255.
x_test = np.repeat(x_test[..., None], 3, axis=-1).astype(np.float32) / 255.0
print(x_test.shape)
preds_test = model.predict(x_test,batch_size=batch_size)