In [None]:
import sys
import os
# Make the sibling `utils` package importable from this notebook.
# NOTE: "__file__" is passed as a string literal (notebooks do not define
# __file__), so dirname() yields "" and abspath() resolves to the current
# working directory; the entry appended is therefore "<cwd>/..". This only
# works when the kernel's working directory is the notebook's own folder.
sys.path.append(os.path.join(os.path.abspath(os.path.dirname("__file__")), '..'))
from utils import data_handling, model_development
import matplotlib.pyplot as plt
from ipywidgets import interact
from sklearn.model_selection import train_test_split

In [None]:
##
# Configure the model training policy:
#
# NOTE(review): configure_training_policy() is defined in utils.model_development
# and takes no arguments here; what it actually configures is not visible from
# this notebook. Presumably it sets global training state and therefore has to
# run before any model is built - confirm against the helper's implementation.
model_development.configure_training_policy()

In [None]:
##
# Load and preprocess data:
#
# The loader hands back three parallel collections: the images, their class
# labels and their cell types.
# NOTE(review): the meaning of the positional `True` flag is defined by
# data_handling.load_data_from_file and is not visible here - confirm.
cell_images, cell_labels, cell_types = data_handling.load_data_from_file(
    '../data/labels.csv',
    True,
)

# Quick sanity check: image array shape plus the number of labels and types.
print([
    cell_images.shape,
    len(cell_labels),
    len(cell_types),
])

@interact(n = (0, cell_images.shape[0] - 1))
def display_data(n = 0):
    """Render the n-th loaded image in grayscale, titled with its class and type.

    The `interact` decorator exposes `n` as an integer control spanning every
    valid index of `cell_images` (0 .. number of images - 1), so the samples can
    be browsed interactively.

    Parameters
    ----------
    n : int
        Index of the sample to display; defaults to the first one.
    """
    _figure, axis = plt.subplots()
    axis.imshow(cell_images[n], cmap = 'gray')
    label, kind = cell_labels[n], cell_types[n]
    axis.set_title(f'Class: {label} - Type: {kind}')

In [None]:
##
# Prepare data for model training:
#
# Five independent splits of the same images/labels, one per random seed (1-5),
# each holding out 15% of the samples as a test set.
#
# train_test_split returns one (train, test) pair per input array, i.e. for two
# inputs the result is (X_train, X_test, Y_train, Y_test). All four values must
# be unpacked - binding only two names raises
# "ValueError: too many values to unpack". The held-out parts are kept as
# X_test_k / Y_test_k so they remain available for later evaluation.
X_train_1, X_test_1, Y_train_1, Y_test_1 = \
    train_test_split(cell_images, cell_labels, test_size = 0.15, random_state = 1)
X_train_2, X_test_2, Y_train_2, Y_test_2 = \
    train_test_split(cell_images, cell_labels, test_size = 0.15, random_state = 2)
X_train_3, X_test_3, Y_train_3, Y_test_3 = \
    train_test_split(cell_images, cell_labels, test_size = 0.15, random_state = 3)
X_train_4, X_test_4, Y_train_4, Y_test_4 = \
    train_test_split(cell_images, cell_labels, test_size = 0.15, random_state = 4)
X_train_5, X_test_5, Y_train_5, Y_test_5 = \
    train_test_split(cell_images, cell_labels, test_size = 0.15, random_state = 5)

In [None]:
##
# Train classification model (only the first phase):
#
# Build the optimizer and the network. The model is requested with ImageNet
# weights, a frozen convolutional base and 300x300 three-channel inputs.
# NOTE(review): the behaviour behind these flags lives in
# utils.model_development - confirm there before changing them.
optimizer = model_development.create_optimizer('nadam')
inception_resnetv2 = model_development.inception_resnetv2(
    input_shape = (300, 300, 3),
    weights = 'imagenet',
    freeze_convolutional_base = True,
    display_model_information = False,
)

# Run training phase 1 on the first train split (X_train_1 / Y_train_1),
# monitoring validation accuracy. The helper returns the training history
# together with the elapsed training time.
history, training_time = model_development.train_classification_model(
    training_phase = 1,
    model = inception_resnetv2,
    optimizer = optimizer,
    training_metrics = ['accuracy', 'Precision', 'Recall'],
    model_name = 'InceptionResNetv2',
    version = '1.00.00rt',
    X = X_train_1,
    Y = Y_train_1,
    metric_to_monitor = 'val_accuracy',
    no_of_epochs = 100,
    batch_size = 16,
    validation_split_ratio = 0.15,
)