# Training DLN model on both datasets


15-04-2025


In [15]:
import numpy as np
import pandas as pd
from datetime import datetime
from pathlib import Path

In [16]:
import matplotlib.pyplot as plt
import seaborn as sns

# Notebook-wide matplotlib text sizes: 14 pt for body text, axis labels,
# axes titles and legends; 10 pt for the tick labels on both axes.
plt.rcParams.update(
    {
        "font.size": 14,
        "axes.labelsize": 14,
        "axes.titlesize": 14,
        "legend.fontsize": 14,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
    }
)

In [17]:
import tensorflow as tf
import keras

# Report the TensorFlow version and the GPUs it can see, starting from a clean
# Keras global state.
print(tf.__version__)
keras.backend.clear_session()
print(tf.config.list_physical_devices("GPU"))

# Seed Python's `random`, NumPy and the TensorFlow backend in one call.
# `tf.random.set_seed` alone leaves the first two unseeded.
SEED = 42
keras.utils.set_random_seed(SEED)

2.18.0
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [18]:
from keras.models import Sequential

from keras.layers import (
    RandomRotation,
    RandomTranslation,
    RandomFlip,
    RandomContrast,
    Dense,
    ReLU,
    BatchNormalization,
    GlobalAveragePooling2D,
)
from keras.regularizers import L2

from keras.losses import CategoricalCrossentropy
from keras.applications.inception_v3 import (
    preprocess_input as inceptionV3_preprocess_input,
)
from keras.applications.resnet_v2 import preprocess_input as resnet50v2_preprocess_input
from keras.applications.mobilenet_v2 import (
    preprocess_input as mobilenetv2_preprocess_input,
)
from keras.applications.efficientnet import (
    preprocess_input as efficientnetb0_preprocess_input,
)
from sklearn.metrics import (
    balanced_accuracy_score,
    precision_score,
    f1_score,
    recall_score,
    ConfusionMatrixDisplay,
)

Defining custom functions

In [19]:
def get_confusion_matrix(model_name, y_test, y_pred, class_names, output_dir=None):
    """
    Plot a row-normalised confusion matrix and save it as a PNG.

    Args:
        model_name: Name used in the plot title and in the output file name.
        y_test: True labels, passed to ``ConfusionMatrixDisplay.from_predictions``.
        y_pred: Predicted labels, in the same order as ``y_test``.
        class_names: Display labels for the matrix axes.
        output_dir: Optional directory for the PNG. When omitted the file is
            written relative to the current working directory, as before.

    Returns:
        None. The figure is saved as ``confusion_matrix_{model_name}.png``,
        shown, and then closed.
    """
    # Apply the 12 pt font only while this figure is being built. Calling
    # `plt.rc` after plotting would miss this figure and silently change the
    # font size of every later plot in the notebook.
    with plt.rc_context({"font.size": 12}):
        fig, ax = plt.subplots(figsize=(8, 8))
        ConfusionMatrixDisplay.from_predictions(
            y_test,
            y_pred,
            ax=ax,
            xticks_rotation="vertical",
            colorbar=False,
            normalize="true",
            display_labels=class_names,
        )
        ax.set_title(f"Confusion Matrix {model_name}")

        file_name = f"confusion_matrix_{model_name}.png"
        target = Path(file_name) if output_dir is None else Path(output_dir) / file_name
        # "tight" keeps the vertical tick labels inside the saved image.
        fig.savefig(target, bbox_inches="tight")

    plt.show()
    # Close explicitly so figures do not pile up when called once per model.
    plt.close(fig)


def plot_history(model_name, history, metrics, output_dir=None):
    """
    Plot training-history curves for one model and save the figure as a PNG.

    Args:
        model_name: Name used as the prefix of the output file name.
        history: Mapping from metric name to a per-epoch sequence of values
            (e.g. ``History.history`` from ``model.fit``).
        metrics: Sequence of metric names to draw on the same axes.
        output_dir: Optional directory for the PNG. Defaults to ``output/``
            next to the current working directory, which follows the
            ``Path.cwd().parent`` convention used elsewhere in this notebook.
            NOTE(review): confirm that this matches the intended location.

    Returns:
        None. The figure is saved, shown, and then closed.
    """
    # An explicit figure avoids drawing onto whatever axes happen to be current.
    fig, ax = plt.subplots()
    for metric in metrics:
        sns.lineplot(data=history[metric], label=metric, ax=ax)
    ax.set_xlabel("epochs")
    ax.set_ylabel("metric")
    ax.legend(loc="upper left", bbox_to_anchor=(1, 1))

    target_dir = Path.cwd().parent / "output" if output_dir is None else Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    # Join the metric names so the file name carries no list brackets or quotes.
    metrics_tag = "_".join(metrics)
    fig.savefig(
        target_dir / f"{model_name}_{metrics_tag}_history.png",
        bbox_inches="tight",
    )
    plt.show()
    plt.close(fig)

## Load data

loading Imagestream

In [20]:
# Sentinel that lets tf.data pick the parallelism level wherever it is passed
# as `num_parallel_calls` or to `prefetch`.
AUTOTUNE = tf.data.AUTOTUNE
# Number of samples per batch when the datasets are batched for training.
BATCH_SIZE = 32
# Side length in pixels that images are resized to on load (IMG_SIZE x IMG_SIZE).
IMG_SIZE = 224


def expand_ds(train_ds):
    """
    Enlarge a training dataset with augmented copies of itself.

    Four single-layer augmentation pipelines (rotation, translation, flip and
    contrast) are each mapped over ``train_ds``. The four augmented datasets
    and the untouched original are then combined with ``interleave`` using
    ``cycle_length=1``, so the input datasets are consumed one after another
    rather than mixed element by element.

    Args:
        train_ds: A ``tf.data.Dataset`` yielding ``(image, label)`` pairs.

    Returns:
        A ``tf.data.Dataset`` containing the augmented copies followed by the
        original elements.
    """
    augmenters = [
        Sequential([RandomRotation(factor=0.15)]),
        Sequential([RandomTranslation(height_factor=0.1, width_factor=0.1)]),
        Sequential([RandomFlip()]),
        Sequential([RandomContrast(factor=0.1)]),
    ]

    def _augmented_copy(augmenter):
        # Apply one augmentation pipeline to every image; labels pass through.
        return train_ds.map(
            lambda image, label: (augmenter(image, training=True), label),
            num_parallel_calls=AUTOTUNE,
        )

    variants = []
    for augmenter in augmenters:
        variants.append(_augmented_copy(augmenter))
    # The unmodified data comes last.
    variants.append(train_ds)

    dataset_of_datasets = tf.data.Dataset.from_tensor_slices(variants)
    return dataset_of_datasets.interleave(
        lambda variant: variant,
        cycle_length=1,
        num_parallel_calls=tf.data.AUTOTUNE,
    )

In [21]:
def add_source_channel(image_tensor, source_id):
    """
    Append a constant channel that encodes the source dataset of an image.

    Args:
        image_tensor: A channels-last image tensor, e.g. ``[H, W, C]``. A
            leading batch axis (``[B, H, W, C]``) is also accepted.
        source_id: The dataset identifier value to fill the new channel with.

    Returns:
        A tensor of the same dtype with one extra trailing channel in which
        every entry equals ``source_id``.
    """
    # Derive the new channel from a one-channel slice of the tensor itself,
    # not from its static `.shape`, so this also works when spatial dimensions
    # are unknown at trace time or when a batch axis is present.
    source_channel = tf.ones_like(image_tensor[..., :1]) * source_id
    return tf.concat([image_tensor, source_channel], axis=-1)


# Adapter for use inside `Dataset.map` on (image, label) pairs
def add_source_to_dataset(image, label, source_id):
    """
    Tag the image of one ``(image, label)`` pair with its source dataset.

    Args:
        image: Image tensor forwarded to ``add_source_channel``.
        label: Label belonging to the image; returned unchanged.
        source_id: Dataset identifier written into the extra channel.

    Returns:
        A ``(image_with_source_channel, label)`` tuple.
    """
    return add_source_channel(image, source_id), label

In [22]:
# ImageStream images live in one sub-directory per class under this folder.
path_in = Path.cwd().parent / "resources/imagestream"

# Source id written into the extra channel of every ImageStream image.
IMAGESTREAM_SOURCE_ID = 1

# 80/20 train/validation split of unbatched grayscale images with one-hot labels.
train_ds_im, test_ds_im = keras.utils.image_dataset_from_directory(
    path_in,
    labels="inferred",
    label_mode="categorical",
    class_names=[
        "discocyte",
        "holly_leaf",
        "granular",
        "sickle",
        "echinocyte",
    ],
    color_mode="grayscale",
    batch_size=None,
    image_size=(IMG_SIZE, IMG_SIZE),
    shuffle=True,
    seed=93,
    validation_split=0.2,
    subset="both",
    data_format="channels_last",
    verbose=True,
)
# Read the class names before `.map`, which returns a plain dataset without
# that attribute.
class_names = test_ds_im.class_names

# Replicate the single gray channel to three channels.
train_ds_im, test_ds_im = [
    ds.map(lambda x, y: (tf.image.grayscale_to_rgb(x), y), num_parallel_calls=AUTOTUNE)
    for ds in (train_ds_im, test_ds_im)
]
# Only the training split is augmented.
train_ds_im = expand_ds(train_ds_im)

# Add the source channel to BOTH splits. The models take 4-channel input, so
# an untagged 3-channel test split cannot be batched with the 4-channel
# CytPix test images or fed to the network.
train_ds_im, test_ds_im = [
    ds.map(
        lambda x, y: add_source_to_dataset(x, y, IMAGESTREAM_SOURCE_ID),
        num_parallel_calls=AUTOTUNE,
    )
    for ds in (train_ds_im, test_ds_im)
]

Found 18237 files belonging to 5 classes.
Using 14590 files for training.
Using 3647 files for validation.


In [23]:
# CytPix images live in one sub-directory per class under this folder.
path_in_cp = Path.cwd().parent / "resources/cytpix/augmented"

# Source id written into the extra channel of every CytPix image. It must
# differ from the id used for ImageStream (1), otherwise the channel is
# identical for every image and carries no information about the source.
# NOTE(review): the value passes through each model's `preprocess_input`
# together with the pixel data - confirm that 0 vs 1 stays distinguishable
# enough after that scaling.
CYTPIX_SOURCE_ID = 0

# 80/20 train/validation split of unbatched grayscale images with one-hot labels.
train_ds_cp, test_ds_cp = keras.utils.image_dataset_from_directory(
    path_in_cp,
    labels="inferred",
    label_mode="categorical",
    class_names=[
        "discocyte",
        "holly_leaf",
        "granular",
        "sickle",
        "echinocyte",
    ],
    color_mode="grayscale",
    batch_size=None,
    image_size=(IMG_SIZE, IMG_SIZE),
    shuffle=True,
    seed=93,
    validation_split=0.2,
    subset="both",
    data_format="channels_last",
    verbose=True,
)

# Replicate the single gray channel to three channels.
train_ds_cp, test_ds_cp = [
    ds.map(lambda x, y: (tf.image.grayscale_to_rgb(x), y), num_parallel_calls=AUTOTUNE)
    for ds in (train_ds_cp, test_ds_cp)
]
# Only the training split is augmented.
train_ds_cp = expand_ds(train_ds_cp)

# Add the source channel to both splits.
train_ds_cp, test_ds_cp = [
    ds.map(
        lambda x, y: add_source_to_dataset(x, y, CYTPIX_SOURCE_ID),
        num_parallel_calls=AUTOTUNE,
    )
    for ds in (train_ds_cp, test_ds_cp)
]

Found 15000 files belonging to 5 classes.
Using 12000 files for training.
Using 3000 files for validation.


In [24]:
# Training data: draw from both sources with equal probability, then shuffle.
# NOTE(review): a 25000-element buffer of 224x224x4 float images is large -
# confirm that it fits in host memory.
train_ds = tf.data.Dataset.sample_from_datasets(
    [train_ds_im, train_ds_cp], weights=[0.5, 0.5], seed=42
).shuffle(buffer_size=25000)

# Evaluation data: chain the two test splits in a fixed order. Random sampling
# and shuffling gain nothing for validation metrics, and they change the
# element order on every pass over the dataset. That breaks any evaluation
# that collects labels and predictions in separate passes.
test_ds = test_ds_im.concatenate(test_ds_cp)

## Set up models to compare

In [11]:
def learning_rate_schedule(epoch, lr):
    """
    Three-phase learning-rate schedule.

    Args:
        epoch: Index of the epoch about to start.
        lr: Current learning rate.

    Returns:
        ``lr * exp(0.5)`` for epochs below 5, ``lr`` unchanged for epochs
        5 to 14, and ``lr * exp(-0.1)`` from epoch 15 onwards.
    """
    if 5 <= epoch < 15:
        # Plateau phase: keep the rate as it is.
        return lr
    # Grow during the first five epochs, decay after the plateau.
    exponent = 0.5 if epoch < 5 else -0.1
    return (lr * tf.math.exp(exponent)).numpy()


# Stop once the validation loss has not improved for three epochs and roll
# back to the best weights seen.
earlystopper = keras.callbacks.EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=3,
    restore_best_weights=True,
    verbose=3,
)

lr_scheduler = keras.callbacks.LearningRateScheduler(learning_rate_schedule)

# Adam starting at a learning rate of 0.01; the scheduler adjusts it per epoch.
optimizer = keras.optimizers.Adam(learning_rate=0.01)

In [12]:
# Upper bound on the number of training epochs passed to `model.fit`.
EPOCHS = 200

In [13]:
# One row per architecture, in training order:
# (name, model constructor, matching input-preprocessing function).
_model_specs = [
    ("ResNet50V2", keras.applications.ResNet50V2, resnet50v2_preprocess_input),
    ("MobileNetV2", keras.applications.MobileNetV2, mobilenetv2_preprocess_input),
    ("EfficientNetB0", keras.applications.EfficientNetB0, efficientnetb0_preprocess_input),
    ("InceptionV3", keras.applications.InceptionV3, inceptionV3_preprocess_input),
]

# Lookup tables keyed by architecture name.
preprocess_input_dict = {name: preprocess for name, _, preprocess in _model_specs}
models_dict = {name: constructor for name, constructor, _ in _model_specs}

# Filled in by the training loop: per-model test scores and fit histories.
results = dict()
history_dict = dict()

# Date-stamped (yymmdd) output folder next to the working directory.
_run_stamp = datetime.now().strftime('%y%m%d')
path_out = Path.cwd().parent / f"output/{_run_stamp}_output"
path_out.mkdir(parents=True, exist_ok=True)

In [None]:
# Snapshot the shared optimizer's settings once. Each model gets its own
# optimizer built from this snapshot, so the shared instance is never trained
# with and its configuration stays as defined.
optimizer_config = optimizer.get_config()

for model_name, model_class in models_dict.items():
    print(f"Training {model_name}...")

    # Backbone trained from scratch on 4-channel input (RGB + source channel).
    # `classes` / `classifier_activation` are omitted because they only apply
    # when `include_top=True`. `pooling` is the real `None`, not the string.
    base_model = model_class(
        include_top=False,
        weights=None,
        input_tensor=None,
        input_shape=(IMG_SIZE, IMG_SIZE, 4),
        pooling=None,
    )
    # print(base_model.summary())
    base_model.trainable = True

    # Head: three Dense -> BatchNorm -> ReLU stages applied to the backbone's
    # feature map, then global average pooling and a 5-way softmax.
    model = Sequential()
    model.add(base_model)

    model.add(Dense(base_model.output_shape[-1], kernel_regularizer=L2(0.01)))
    model.add(BatchNormalization())
    model.add(ReLU())
    model.add(Dense((base_model.output_shape[-1] // 2), kernel_regularizer=L2(0.01)))
    model.add(BatchNormalization())
    model.add(ReLU())
    model.add(Dense(124, kernel_regularizer=L2(0.01)))
    model.add(BatchNormalization())
    model.add(ReLU())
    model.add(GlobalAveragePooling2D())
    model.add(Dense(5, activation="softmax"))

    # Architecture-specific input scaling.
    # NOTE(review): this is applied to all four channels, so the source
    # channel is rescaled along with the pixel data - confirm this is intended.
    preprocess_input = preprocess_input_dict[model_name]

    train_ds_processed = (
        train_ds.map(lambda x, y: (preprocess_input(x), y), num_parallel_calls=AUTOTUNE)
        .batch(BATCH_SIZE)
        .prefetch(AUTOTUNE)
    )
    test_ds_processed = (
        test_ds.map(lambda x, y: (preprocess_input(x), y), num_parallel_calls=AUTOTUNE)
        .batch(BATCH_SIZE)
        .prefetch(AUTOTUNE)
    )

    print(f"Preprocessed {model_name} data")

    # Fresh optimizer per model. A Keras optimizer is tied to the variables it
    # was first built with, so it cannot be reused for another model. Reuse
    # would also carry the scheduled learning rate and the Adam moments over
    # from the previous model.
    model_optimizer = type(optimizer).from_config(optimizer_config)

    with tf.device("GPU:0"):
        model.compile(
            optimizer=model_optimizer,
            loss=CategoricalCrossentropy(from_logits=False),
            metrics=["accuracy"],
        )
        print(f"Compiled {model_name} model")

        history = model.fit(
            train_ds_processed,
            validation_data=test_ds_processed,
            callbacks=[earlystopper, lr_scheduler],
            epochs=EPOCHS,
            validation_freq=1,
        )
        model.save(path_out / f"{model_name}.keras")

    # Collect labels and predictions in the SAME pass over the test set. Two
    # separate passes only line up if the dataset yields an identical order
    # each time, which a shuffled or randomly sampled dataset does not.
    y_true_batches, y_pred_batches = [], []
    for x_batch, y_batch in test_ds_processed:
        y_true_batches.append(y_batch.numpy())
        y_pred_batches.append(model.predict_on_batch(x_batch))
    # One-hot labels / class probabilities -> class indices.
    y_test = np.concatenate(y_true_batches, axis=0).argmax(axis=1)
    y_pred = np.concatenate(y_pred_batches, axis=0).argmax(axis=1)

    accuracy = balanced_accuracy_score(y_test, y_pred)
    f1_score_model = f1_score(y_test, y_pred, average="weighted")
    precision = precision_score(y_test, y_pred, average="weighted")
    recall = recall_score(y_test, y_pred, average="weighted")

    scores = {
        "test_balanced_accuracy": accuracy,
        "test_f1_weighted": f1_score_model,
        "test_precision_weighted": precision,
        "test_recall_weighted": recall,
    }

    results[model_name] = {"scores": scores}
    history_dict[model_name] = {"history": history.history}
    get_confusion_matrix(model_name, y_test, y_pred, class_names)

In [None]:
# Convert results to a DataFrame with one row per model and one column per
# score, then write it next to the saved models.
results_df = pd.DataFrame({k: v["scores"] for k, v in results.items()}).T
# The stray closing parenthesis that made this line a SyntaxError is removed.
results_df.to_csv(path_out / "models_results.csv", index=True)

In [None]:
# Plot loss and accuracy curves for every model that has a recorded history.
for model_name, entry in history_dict.items():
    recorded = entry["history"]
    n_epochs = len(recorded["loss"])

    # Work on a shallow copy so the stored history is not modified and
    # re-running this cell gives the same figures.
    aligned = dict(recorded)
    for key in ("val_loss", "val_accuracy"):
        values = recorded[key]
        # Validation may be logged less often than training. Repeat each
        # value so both curves span the same number of epochs. With
        # validation after every epoch the factor is 1 and nothing changes.
        repeat = max(1, n_epochs // max(1, len(values)))
        aligned[key] = [val for val in values for _ in range(repeat)]

    plot_history(model_name, aligned, ["loss", "val_loss"])
    plot_history(model_name, aligned, ["accuracy", "val_accuracy"])

In [None]:
# for image, label in train_ds_im.take(1):
#     print("Image shape:", image.shape)
#     print("Label shape:", label.shape)

# image, label = next(iter(train_ds_im))
# plt.imshow(image.numpy().astype("uint8"))

# def add_source_channel(image_tensor, source_id):
#     """
#     Adds a channel to the image that encodes the source dataset

#     Args:
#         image_tensor: A tensor of shape [H, W, C]
#         source_id: The dataset identifier value to fill the new channel with

#     Returns:
#         A tensor with the bottom row black or white based on source_id
#         """
#     # Assuming [H, W, C] format (TensorFlow standard)
#     H, W, C = image_tensor.shape

#     bottom_row_color_by_source_id = tf.zeros((1, W, C), dtype=image_tensor.dtype) if source_id == 1 else tf.ones((1, W, C), dtype=image_tensor.dtype)
#     augmented_tensor = tf.concat([image_tensor[:-1, :, :], bottom_row_color_by_source_id], axis=0)

#     return augmented_tensor

# aug_image = add_source_channel(image, 0)
# plt.imshow(aug_image.numpy().astype("uint8"))