In [1]:
!pip install gdown

In [2]:
!gdown --id 1C6mFAW5gwZtVgM8IQOj3S4dqJlIJN2Xu

In [4]:
!unzip covid19_classification_dataset.zip

In [3]:
import os
import numpy as np
from glob import glob
import tensorflow as tf
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
from tensorflow.keras.applications.resnet_v2 import preprocess_input, ResNet50V2

In [8]:
# --- Run configuration -------------------------------------------------------
# Seed the global NumPy RNG (it shuffles the file list in
# get_img_path_and_labels) and TensorFlow's global seed (tf.data shuffling and
# other unseeded TF random ops), so Restart & Run All reproduces the same data
# order. This alone does not make GPU training bit-for-bit deterministic.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

BATCH_SIZE = 128
# NOTE: "HEGIHT"/"WIDHT" are misspelled, but later cells reference these exact
# names, so they are intentionally kept.
IMG_HEGIHT  = 256
IMG_WIDHT   = 256
IMG_CHANNEL = 3
BUFFER_SIZE = BATCH_SIZE*5  # shuffle-buffer size, in dataset elements
# Dataset root (presumably extracted by the unzip cell); the trained model is
# saved here as well.
MODEL_PATH = "./covid19_classification_dataset"
TRAIN_PATH = os.path.join(MODEL_PATH, "train")
VAL_PATH = os.path.join(MODEL_PATH, "val")
# Binary label map: COVID is the positive class (1), NORMAL the negative (0).
CLASS_DICT = {
    'COVID':1,
    'NORMAL':0
}
# Random-rotation augmentation layer (factor 0.1 = fraction of a full turn, per
# the Keras docs). Only referenced from commented-out code in random_flip.
rot_layer = tf.keras.layers.experimental.preprocessing.RandomRotation(0.1)

In [9]:
def load(image_path, label):
    """Read one PNG from disk and return it resized, as float32, with its label.

    Args:
        image_path: path (Python str or scalar string tensor) of a PNG file.
        label: returned unchanged.

    Returns:
        (image, label). `image` has shape (IMG_HEGIHT, IMG_WIDHT, IMG_CHANNEL)
        and dtype float32. Pixel values are still on the decoded 0-255 scale,
        so no normalisation happens here.
    """
    raw_bytes = tf.io.read_file(image_path)
    decoded = tf.image.decode_png(raw_bytes, channels=IMG_CHANNEL)
    resized = tf.image.resize(decoded, (IMG_HEGIHT, IMG_WIDHT))
    # The single image is cast explicitly so its dtype is guaranteed float32.
    return tf.cast(resized, tf.float32), label

In [10]:
# Sanity-check `load` on one known training image.
sample_path = os.path.join(TRAIN_PATH, "COVID", "COVID-1021.png")
test_image, label = load(sample_path, CLASS_DICT['COVID'])
print(test_image.shape)

# `load` returns float32 pixels on a 0-255 scale; matplotlib expects float RGB
# in [0, 1], so rescale for display only.
sample_class = {idx: name for name, idx in CLASS_DICT.items()}[label]
fig, ax = plt.subplots()
ax.set_title(f"{sample_class} (label={label})")
ax.imshow(test_image / 255.0)
ax.axis('off');

In [11]:
@tf.function()
def random_flip(image):
    """Training-time augmentation hook. It is currently a pass-through.

    Earlier experiments applied a 50% left-right flip and a small random
    rotation here; both are disabled. The image is returned unchanged, so
    training and validation see identically prepared inputs.
    """
    return image


def processing_image(image):
    """Normalise an image the way the ResNet50V2 backbone expects.

    This is a thin wrapper around `keras.applications.resnet_v2.preprocess_input`.
    Per the Keras docs, that function rescales 0-255 pixels to the [-1, 1] range.
    """
    normalised = preprocess_input(image)
    return normalised


def load_image_train(image, label):
    """tf.data mapper for the training split: file path -> (model input, label).

    Decodes the file with `load`, passes it through the `random_flip`
    augmentation hook, then applies the backbone normalisation.
    """
    decoded, label = load(image, label)
    return processing_image(random_flip(decoded)), label


def load_image_val(image, label):
    """tf.data mapper for the validation split: decode and normalise, no augmentation."""
    decoded_image, decoded_label = load(image, label)
    return processing_image(decoded_image), decoded_label

In [12]:
def get_img_path_and_labels(path, num_pos=-1, num_neg=-1):
    """Collect PNG paths and labels for both classes of one split, shuffled together.

    Images are read from `path/<class name>/*.png`. The positive class is the
    CLASS_DICT key whose value is 1 and the negative class the key whose value
    is 0, so the result does not depend on the dict's insertion order.

    Args:
        path: split directory (e.g. TRAIN_PATH) with one sub-folder per class.
        num_pos: keep at most this many positive images; -1 keeps all.
        num_neg: keep at most this many negative images; -1 keeps all.

    Returns:
        (image_paths, labels): NumPy arrays of equal length, permuted with the
        global NumPy RNG. Labels are the CLASS_DICT values (1 / 0).

    Raises:
        ValueError: if num_pos or num_neg is below -1.
        FileNotFoundError: if neither class folder contains any PNG.
    """
    if num_pos < -1:
        raise ValueError(f'Incorrect num_pos {num_pos}')
    if num_neg < -1:
        raise ValueError(f'Incorrect num_neg {num_neg}')

    name_of = {label: name for name, label in CLASS_DICT.items()}
    pos_name, neg_name = name_of[1], name_of[0]

    # sorted(): glob order is filesystem-dependent, which would make the subset
    # kept by num_pos/num_neg vary between machines.
    pos_paths = sorted(glob(os.path.join(path, pos_name, '*.png')))
    neg_paths = sorted(glob(os.path.join(path, neg_name, '*.png')))
    if not pos_paths and not neg_paths:
        raise FileNotFoundError(
            f'No PNG images found in {os.path.join(path, pos_name)} '
            f'or {os.path.join(path, neg_name)}')

    if num_pos != -1:
        pos_paths = pos_paths[:num_pos]
    if num_neg != -1:
        neg_paths = neg_paths[:num_neg]

    image_paths = np.array(pos_paths + neg_paths, dtype=str)
    labels = np.array([CLASS_DICT[pos_name]] * len(pos_paths)
                      + [CLASS_DICT[neg_name]] * len(neg_paths), dtype=np.int64)

    print("TOTAL SAMPLES: ", len(image_paths))
    print('POSITIVE CLASS: ', int(np.count_nonzero(labels == 1)))
    print('NEGATIVE CLASS: ', int(np.count_nonzero(labels == 0)))

    # Same RNG draw as arange + np.random.shuffle.
    order = np.random.permutation(len(labels))
    return image_paths[order], labels[order]

In [38]:
# (paths, labels) tuple; this name is kept because later cells index into it.
train_imgs_path = get_img_path_and_labels(TRAIN_PATH, num_pos=-1)
train_paths, train_labels = train_imgs_path

train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_paths, train_labels))
    # Shuffle the cheap (path, label) pairs *before* decoding. A buffer that
    # spans the whole split gives a full reshuffle every epoch. Shuffling after
    # `map` would buffer decoded 256x256x3 float32 images (~768 KiB each).
    .shuffle(buffer_size=max(1, len(train_paths)), reshuffle_each_iteration=True)
    .map(load_image_train, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    # Overlap input preparation with the training step.
    .prefetch(tf.data.AUTOTUNE)
)

In [39]:
# (paths, labels) tuple for the validation split.
val_imgs_path = get_img_path_and_labels(VAL_PATH, num_pos=-1)
val_dataset = (
    tf.data.Dataset.from_tensor_slices((val_imgs_path[0], val_imgs_path[1]))
    # Parallel map keeps element order by default, so evaluation stays deterministic.
    .map(load_image_val, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

In [57]:
# list out keys and values separately (kept for use by later cells)
key_list = list(CLASS_DICT.keys())
val_list = list(CLASS_DICT.values())
label_to_name = dict(zip(val_list, key_list))

for batch_idx, (ims, las) in enumerate(train_dataset.take(2)):
    fig = plt.figure(figsize=(15, 20))
    # A (last) batch may hold fewer than 9 images.
    for i in range(min(9, len(las))):
        ax = fig.add_subplot(3, 3, i + 1)
        ax.set_title(label_to_name[int(las[i])])
        # Pixels were normalised to [-1, 1] by the ResNet-V2 preprocess_input;
        # map back to [0, 1] so imshow does not clip the negative half to black.
        ax.imshow((ims[i] + 1.0) / 2.0)
        ax.axis('off')
    # One file per batch, so the second figure no longer overwrites the first.
    fig.savefig('./input.jpg' if batch_idx == 0 else f'./input_{batch_idx}.jpg')
    plt.show()

In [72]:
import seaborn as sns

# Check the class balance of the training split.
name_for_label = {idx: name for name, idx in CLASS_DICT.items()}
train_class_names = [name_for_label[int(lbl)] for lbl in train_imgs_path[1]]

fig, ax = plt.subplots()
# Pass the values as `x=`: since seaborn 0.12 the first positional parameter
# is `data`, not the x variable.
sns.countplot(x=train_class_names,
              order=sorted(CLASS_DICT, key=CLASS_DICT.get),  # label 0 first, then 1
              ax=ax)
ax.set_title('Covid19 X-ray Datasets')
ax.set_xlabel('class')
ax.set_ylabel('number of training images')
ax.grid()
fig.savefig('./imbalance.jpg')
plt.show()

In [20]:
def build_model(input_shape, final_act=None, weights=None):
    """Build a single-output binary classifier on a ResNet50V2 backbone.

    Args:
        input_shape: (height, width, channels) of the input images.
        final_act: activation of the single output unit (e.g. 'sigmoid').
            None keeps the output linear, i.e. the model emits logits.
        weights: forwarded to ResNet50V2 (None = random init,
            'imagenet' = pretrained weights).

    Returns:
        An uncompiled `tf.keras.Model` mapping images to a (batch, 1) output.
    """
    # Convolutional ResNet50V2 backbone without its classification head.
    base_model = ResNet50V2(input_shape=input_shape,
                            include_top=False,
                            weights=weights)

    inputs = tf.keras.Input(shape=input_shape)
    x = base_model(inputs)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    # Dense(activation=None) is linear (Keras' default), so one construction
    # covers both the logit and the activated case.
    outputs = tf.keras.layers.Dense(1, activation=final_act)(x)
    return tf.keras.Model(inputs, outputs)

In [21]:
# ImageNet-initialised backbone with a sigmoid head, so outputs are
# probabilities (the evaluation below thresholds them at 0.5).
pret_input_shape = (IMG_HEGIHT, IMG_WIDHT, IMG_CHANNEL)
pret_model = build_model(
    input_shape=pret_input_shape,
    final_act='sigmoid',
    weights='imagenet',
)

In [22]:
# Metrics tracked by compile()/fit() for the binary classifier. The four
# confusion-matrix counts come from a (name, class) table; the remaining
# metrics follow in the same order.
_count_metrics = (
    ('tp', tf.keras.metrics.TruePositives),
    ('fp', tf.keras.metrics.FalsePositives),
    ('tn', tf.keras.metrics.TrueNegatives),
    ('fn', tf.keras.metrics.FalseNegatives),
)
METRICS = [metric_cls(name=metric_name) for metric_name, metric_cls in _count_metrics]
METRICS += [
    tf.keras.metrics.BinaryAccuracy(name='accuracy'),
    tf.keras.metrics.Precision(name='precision'),
    tf.keras.metrics.Recall(name='recall'),
    tf.keras.metrics.AUC(name='auc'),              # ROC AUC (AUC's default curve)
    tf.keras.metrics.AUC(name='prc', curve='PR'),  # precision-recall curve
]
del _count_metrics

In [23]:
base_learning_rate = 0.0001

# tfa's SigmoidFocalCrossEntropy defaults to reduction=NONE (one loss per
# example). fit() then differentiates that un-reduced vector, which amounts to
# summing over the batch. Ask for Keras' usual batch-mean reduction explicitly
# so the loss scale does not depend on BATCH_SIZE. from_logits keeps its
# default (False) because the model ends in a sigmoid.
loss_func = tfa.losses.SigmoidFocalCrossEntropy(
    reduction=tf.keras.losses.Reduction.AUTO,
)

pret_model.compile(
    loss=loss_func,
    optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
    metrics=METRICS,
)

In [24]:
# Train for 10 epochs (nothing in this notebook freezes the backbone) and
# evaluate on the validation split after each epoch. Per-epoch values end up in
# pret_model_history.history under the METRICS names (and 'val_' + name).
pret_model_history = pret_model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=10,
)

In [25]:
# Save as HDF5 next to the dataset. NOTE(review): reloading presumably needs
# `custom_objects` for the tfa focal loss (or `compile=False`); confirm.
pret_model_file = os.path.join(MODEL_PATH, 'pret_model.h5')
pret_model.save(pret_model_file)

In [55]:
# Training vs. validation curves: accuracy (left) and loss (right).
train_history = pret_model_history.history
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))
for ax, metric in zip(axs, ("accuracy", "loss")):
    epoch_idx = range(len(train_history[metric]))
    ax.plot(epoch_idx, train_history[metric], linewidth=5, label="training")
    ax.plot(epoch_idx, train_history[f"val_{metric}"], linewidth=5, label="validation")
    ax.set_xlabel("epochs")
    ax.set_ylabel(metric)
    # Per-axes legend: plt.legend() would only label the last panel.
    ax.legend()
fig.savefig('./accuracy_loss.jpg')

In [56]:
# Training vs. validation curves: precision (left) and recall (right).
pr_history = pret_model_history.history
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))
for ax, metric in zip(axs, ("precision", "recall")):
    epoch_idx = range(len(pr_history[metric]))
    ax.plot(epoch_idx, pr_history[metric], linewidth=5, label="training")
    ax.plot(epoch_idx, pr_history[f"val_{metric}"], linewidth=5, label="validation")
    ax.set_xlabel("epochs")
    ax.set_ylabel(metric)
    # Per-axes legend: plt.legend() would only label the last panel.
    ax.legend()
fig.savefig('./precision_recall.jpg')

In [29]:
def get_correct_results(model, dataset, threshold=0.5):
    """Count correct binary predictions over a batched dataset, overall and per class.

    Args:
        model: callable as `model(x, training=False)` (e.g. a Keras model)
            returning one score per example, shaped (batch,) or (batch, 1).
        dataset: iterable of (x, y) batches with integer 0/1 labels `y`.
        threshold: a score strictly greater than this is predicted positive.
            0.5 suits a sigmoid output; use 0.0 for a model that emits logits.

    Returns:
        dict of plain ints: 'total_correct', 'total_correct_pos',
        'total_correct_neg', 'num_samples', 'num_pos', 'num_neg'.
        The three "correct" counts are also printed.
    """
    counts = dict.fromkeys(
        ('total_correct', 'total_correct_pos', 'total_correct_neg',
         'num_samples', 'num_pos', 'num_neg'),
        0,
    )

    for x, y in dataset:
        y_true = np.asarray(y).astype(np.int64).reshape(-1)
        # Call the model directly: model.predict() starts a full inference loop
        # (with a progress bar) for every single batch.
        scores = np.asarray(model(x, training=False)).reshape(-1)
        y_pred = (scores > threshold).astype(np.int64)

        correct = y_pred == y_true
        is_pos = y_true == 1
        is_neg = y_true == 0
        counts['total_correct'] += int(correct.sum())
        counts['total_correct_pos'] += int(correct[is_pos].sum())
        counts['total_correct_neg'] += int(correct[is_neg].sum())
        counts['num_samples'] += int(y_true.size)
        counts['num_pos'] += int(is_pos.sum())
        counts['num_neg'] += int(is_neg.sum())

    print('total_correct: ', counts['total_correct'])
    print('total_correct_pos: ', counts['total_correct_pos'])
    print('total_correct_neg: ', counts['total_correct_neg'])
    return counts

In [30]:
# Per-class correctness on the validation split (counts are printed).
val_counts = get_correct_results(model=pret_model, dataset=val_dataset)