<a href="https://colab.research.google.com/github/tcivie/Bone_Marrow_Cells_Classification/blob/main/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Import dependencies and define constants for the project

In [2]:
import re
from subprocess import Popen, PIPE

import tensorflow as tf
from tensorflow import keras
from keras import layers
import matplotlib.pyplot as plt

import os

In [3]:
# Category names: each one is a sub-folder of DATA_PATH that is scanned for broken
# images, and their count sets the size of the model's output layer.
CATEGORIES = [
    "ABE", "ART", "BAS", "BLA", "EBO", "EOS", "FGC",
    "HAC", "KSC", "LYI", "LYT", "MMZ", "MON", "MYB",
    "NGB", "NGS", "NIF", "OTH", "PEB", "PLM", "PMO",
]
# Dataset root, resolved relative to the current working directory.
DATA_PATH = os.path.join(os.getcwd(), 'BM_cytomorphology_data')
# Image size requested from the dataset loader; also the first two model input dimensions.
IMAGE_SIZE = (250, 250)
# Number of images per batch produced by the dataset loader.
BATCH_SIZE = 32

Define check image script

In [4]:
def checkImage(path):
    """
    Run the external ``mogrify`` command (presumably ImageMagick's) on ``path``.

    :param path: file path or glob pattern handed to ``mogrify`` as its only argument
    :return: tuple ``(exitcode, out, err)`` - the process return code followed by
        whatever the process wrote to stdout and to stderr
    """
    command = ['mogrify', path]
    process = Popen(command, stdout=PIPE, stderr=PIPE)
    # communicate() waits for the process to end, so returncode is read afterwards.
    captured_stdout, captured_stderr = process.communicate()
    return process.returncode, captured_stdout, captured_stderr

Create Image filtering script (To filter out any broken images)

In [5]:
def filter_corrupted_images():
    """
    Filter corrupted images

    Walks every category folder under ``DATA_PATH``, runs ``checkImage`` on the
    ``*.jpg`` files of each directory and deletes the files named in its error output.

    :return: tuple ``(corrupted_images, total_images)`` - number of files deleted and
        number of ``.jpg`` files that were scanned
    """
    # Three capital letters, an underscore, then the shortest run of characters up to
    # ".jpg". The run cannot cross a path separator, so each match is one file name.
    name_pattern = re.compile(r'[A-Z]{3}_[^/\\]*?\.jpg')
    corrupted_images = 0
    total_images = 0
    for folder_name in CATEGORIES:
        for root, subdirs, files in os.walk(os.path.join(DATA_PATH, folder_name)):
            jpg_count = sum(1 for name in files if name.endswith('.jpg'))
            if jpg_count == 0:
                # Nothing for the glob to match (e.g. a folder holding only sub-folders).
                continue
            total_images += jpg_count
            code, output, error = checkImage(root + '/*.jpg')
            # Undecodable bytes in the tool's output must not abort the whole scan.
            error_text = error.decode("utf-8", errors="replace")
            if code == 0 and error_text == "":
                continue
            # The same file can be reported several times; delete and count it once.
            for file in dict.fromkeys(name_pattern.findall(error_text)):
                file_path = os.path.join(root, file)
                if not os.path.isfile(file_path):
                    # Reported name does not exist in this directory - nothing to delete.
                    continue
                os.remove(file_path)
                print('Removed:' + file_path)
                corrupted_images += 1
    return corrupted_images, total_images

Define dataset generation function

In [6]:
def generate_dataset():
    """
    Load the images under ``DATA_PATH`` and split them into two datasets.

    :return: tuple ``(train_ds, val_ds)`` - training dataset and validation dataset
    """
    loader_options = {
        "validation_split": 0.2,  # fraction of the data held out for validation (20%)
        "subset": "both",  # ask for the training and the validation split together
        "seed": 2905,
        "image_size": IMAGE_SIZE,
        "batch_size": BATCH_SIZE,
        "label_mode": 'categorical',
    }
    splits = tf.keras.preprocessing.image_dataset_from_directory(DATA_PATH, **loader_options)
    training_part, validation_part = splits
    return training_part, validation_part

Define data augmentation model


In [7]:
def gen_augmentation_model():
    """
    Build the augmentation model.

    :return: Sequential model with a single ``RandomFlip`` layer configured as
    "horizontal_and_vertical" (no rotation layer is included)
    """
    random_flip = layers.RandomFlip("horizontal_and_vertical")
    return keras.Sequential([random_flip])

Define data preprocessing function

In [8]:
def data_preprocessing(train_ds):
    """
    Apply the augmentation model to every element of the training data.

    :param train_ds: dataset of ``(img, label)`` pairs to augment
    :return: result of mapping the augmentation over ``train_ds``; labels pass through unchanged
    """
    augmenter = gen_augmentation_model()

    def _augment(img, label):
        # The augmentation model is always called in training mode.
        return augmenter(img, training=True), label

    return train_ds.map(_augment, num_parallel_calls=tf.data.AUTOTUNE)

Define model creation function

In [9]:
def make_model(input_shape):
    """
    Build the (uncompiled) convolutional classifier.

    :param input_shape: shape of a single input image, passed to ``keras.Input``
    :return: ``keras.Model`` from the input to a softmax layer with one unit per
        entry of ``CATEGORIES``
    """
    model_input = keras.Input(shape=input_shape)
    # Layers are created and applied in exactly this order.
    pipeline = [
        layers.Rescaling(1.0 / 255),  # scale pixel values by 1/255 (0..1 for 8-bit images)
        layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu'),
        layers.Conv2D(filters=64, kernel_size=(3, 3), activation='relu'),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dense(128, activation="relu"),
        layers.Dropout(0.5),
        layers.Dense(len(CATEGORIES), activation="softmax"),  # one output per category
    ]
    tensor = model_input
    for layer in pipeline:
        tensor = layer(tensor)
    return keras.Model(model_input, tensor)

Define training function

In [10]:
def train_model(model, train_ds, val_ds, epochs=25):
    """
    Compile ``model`` and fit it, writing a checkpoint file for every epoch.

    :param model: model to compile and train
    :param train_ds: training dataset passed to ``model.fit``
    :param val_ds: dataset passed to ``model.fit`` as ``validation_data``
    :param epochs: number of epochs to train for
    :return: whatever ``model.fit`` returns (for Keras models, the ``History`` object
        holding the per-epoch loss and metric values), so results can be inspected
    """
    callbacks = [
        # "{epoch}" is a placeholder in the file name, giving one file per epoch.
        keras.callbacks.ModelCheckpoint("save_at_{epoch}.keras")
    ]
    model.compile(
        optimizer=keras.optimizers.Adam(1e-3),  # (Adaptive Moment Estimation) optimization algorithm
        loss="categorical_crossentropy",
        metrics=["accuracy"]
    )
    # Hand the fit result back to the caller instead of discarding it.
    return model.fit(
        train_ds,
        epochs=epochs,
        callbacks=callbacks,
        validation_data=val_ds
    )

Main

In [11]:
# Delete broken images from the dataset folders before anything is loaded from them.
corrupted_images, total_images_checked = filter_corrupted_images()
# print("Total corrupted deleted: " + str(corrupted_images) + " Total images scanned: " + str(total_images_checked))
# Load the training and validation datasets.
train_ds, val_ds = generate_dataset()
# NOTE(review): plot_images is not defined in the visible part of the notebook.
# plot_images(train_ds, gen_augmentation_model())

# Apply data augmentation to the training images only (val_ds is left as loaded)
train_ds = data_preprocessing(train_ds)

# Prefetch both datasets with an automatically tuned buffer size; the intent noted by
# the author is to keep the GPU busy while the next batches are prepared.
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
val_ds = val_ds.prefetch(tf.data.AUTOTUNE)

# Input shape is IMAGE_SIZE plus a third dimension of size 3 (presumably the RGB channels).
model = make_model(input_shape=IMAGE_SIZE + (3,))
keras.utils.plot_model(model, show_shapes=True)

# Train with train_model's default number of epochs.
train_model(model, train_ds, val_ds)

ValueError: ignored