<img align="left" src = https://project.lsst.org/sites/default/files/Rubin-O-Logo_0.png width=250 style="padding: 10px"> 
<br><b>AI0: Introduction to AI-based Image Classification with Tensorflow</b> <br>
Contact author: Brian Nord <br>
Last verified to run: YYYY-MM-DD <br>
LSST Science Pipelines version: ?? <br>
Container size: medium <br>
Targeted learning level: beginner <br>

**Description:** An introduction to the classification of images with AI-based classification algorithms.

**Skills:** Examine AI training data, prepare it for a classification task, perform classification with a neural network, and examine the diagnostics of the classification task.

**LSST Data Products:** None; MNIST data

**Packages:** numpy, matplotlib, sklearn, tensorflow

**Credits and Acknowledgments:** None

**Get Support:**
Find DP0-related documentation and resources at <a href="https://dp0.lsst.io">dp0.lsst.io</a>. Questions are welcome as new topics in the <a href="https://community.lsst.org/c/support/dp0">Support - Data Preview 0 Category</a> of the Rubin Community Forum. Rubin staff will respond to all questions posted there.

## 1. Introduction

This Jupyter Notebook introduces artificial intelligence (AI)-based image classification. It demonstrates how to perform a few key steps:
1. examine and prepare data for classification;
2. train an AI algorithm;
3. plot diagnostics of the training performance;
4. initially assess those diagnostics. 

AI is a class of algorithms for building statistical models. These algorithms primarily learn from data during training, as opposed to models that use analytic formulae or models that are based on physical reasoning. Machine learning is a subclass of AI algorithms that includes, e.g., random forests. Deep learning is a subclass of machine learning algorithms that is based on neural networks.

This notebook uses `tensorflow`, one of the two most commonly used `python` libraries for deep learning. `Tensorflow` is often easier to use because of how it handles data sets and the logic used for model building. However, it is typically also more difficult to develop custom network models with it. We use `tensorflow` first in this series of tutorials so that users who are new to deep learning can focus on learning AI. In later tutorials, we will use `pytorch` because it is more flexible and more commonly used in science applications.

This notebook uses [MNIST AI benchmarking data](https://en.wikipedia.org/wiki/MNIST_database). In a future notebook, we will use stars and galaxies drawn from DP0 data.

The data used in this notebook require a medium-sized RAM allocation (8Gi).

The end of this notebook contains a Glossary of Terms and a comment regarding usage of terms in AI contexts.

In [None]:
%reload_ext pycodestyle_magic
%flake8_on
import logging
logging.getLogger("flake8").setLevel(logging.FATAL)

### 1.1. Import Packages

[`numpy`](https://numpy.org/) is a widely used Python library for computations and mathematical operations on multi-dimensional arrays.

[`matplotlib`](https://matplotlib.org/) is a widely used Python plotting library.

[`tensorflow`](https://www.tensorflow.org) is a widely used library from Google for fast tensor operations --- often used for building neural network models. 

[`sklearn`](https://scikit-learn.org/stable/) is a library for machine learning.

In [None]:
import numpy as np
import os
import datetime

import matplotlib.pyplot as plt
from matplotlib.pyplot import cm

import tensorflow as tf

from sklearn.model_selection import train_test_split
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
from sklearn.metrics import roc_curve, roc_auc_score, auc, RocCurveDisplay
from sklearn.preprocessing import LabelBinarizer

### 1.2 Define Functions

In [None]:
def normalizeInputs(x_temp, input_minimum, input_maximum):
    """Normalize a datum that is an input to the neural network

    Parameters
    ----------
    x_temp: `numpy.array`
       image data
    input_minimum: `float`
       minimum value for normalization
    input_maximum: `float`
       maximum value for normalization

    Returns
    -------
    x_temp_norm: `numpy.array`
       normalized image data
    """
    # min-max normalization: divide by the range so the output spans [0, 1]
    x_temp_norm = (x_temp - input_minimum)/(input_maximum - input_minimum)
    return x_temp_norm

In [None]:
def createFileUidTimestamp():
    """Create a timestamp for a filename.

    Parameters
    ----------
    None

    Returns
    -------
    file_uid_timestamp : `string`
       String from date and time.
    """
    file_uid_timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    return file_uid_timestamp


In [None]:
def createFileName(file_prefix="", file_location="Data/Sandbox/",
                   file_suffix="", useuid=True, verbose=True):
    """Create a file name.

    Parameters
    ----------
    file_prefix: `string`
       prefix of file name
    file_location: `string`
       path to file
    file_suffix: `string`
       suffix/extension of file name
    useuid: 'bool'
       choose to use a unique id
    verbose: 'bool'
       choose to print the file name

    Returns
    -------
    file_final: `string`
        filename used for saving
    """
    if useuid:
        file_uid = createFileUidTimestamp()
    else:
        file_uid = ""

    file_final = file_location + file_prefix + "_" + file_uid + file_suffix

    if verbose:
        print(file_final)

    return file_final


In [None]:
def plotArrayImageExamples(x_tra, y_tra, num=10,
                           save_file=False,
                           file_prefix="prediction_histogram",
                           file_location="./",
                           file_suffix=".png"):
    """Plot an array of examples of images and labels

    Parameters
    ----------
    x_tra: `numpy.ndarray`
       training data images
    y_tra: `numpy.ndarray`
       training data labels
    num: `int`, optional
       number examples to plot
    file_prefix: `string`, optional
       prefix of file name
    file_location: `string`, optional
       path to file
    file_suffix: `string`, optional
       suffix/extension of file name

    Returns
    -------
    None
    """
    num_row = 2
    num_col = 5
    images = x_tra[:num]
    labels = y_tra[:num]

    fig, axes = plt.subplots(num_row, num_col,
                             figsize=(1.5*num_col, 2*num_row))
    for i in range(num):
        ax = axes[i//num_col, i%num_col]
        ax.imshow(images[i], cmap='gray')
        ax.set_title('Label: {}'.format(labels[i]))

    plt.tight_layout()

    if save_file:
        file_final = createFileName(file_prefix=file_prefix,
                                    file_location=file_location,
                                    file_suffix=file_suffix,
                                    useuid=True)
        plt.savefig(file_final, bbox_inches='tight')

    plt.show()

In [None]:
def plotROCMulticlassOnevsrest(y_tra, y_tes, y_pred_tes, label_target_list,
                               color_list,
                               save_file=False,
                               file_prefix="prediction_histogram",
                               file_location="./",
                               file_suffix=".png"):
    """Plot Receiver Operator Curve for one-vs-rest scenario

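    Parameters
    ----------
    y_tra: `numpy.ndarray`
       training data true labels
    y_tes: `numpy.ndarray`
       test data true labels
    y_pred_tes: `numpy.ndarray`
       test data predicted class probabilities
    label_target_list: `list`
       class labels for which to plot a curve
    color_list: `list`
       colors of the curves, one per target label
    save_file: `bool`, optional
       choose to save the figure
    file_prefix: `string`, optional
       prefix of file name
    file_location: `string`, optional
       path to file
    file_suffix: `string`, optional
       suffix/extension of file name

    Returns
    -------
    None
    """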
    fig, ax = plt.subplots(figsize=(6, 6))

    for label_target, color in zip(label_target_list, color_list):

        label_binarizer = LabelBinarizer().fit(y_tra)
        y_onehot_tes = label_binarizer.transform(y_tes)

        class_id = np.flatnonzero(label_binarizer.classes_ == label_target)[0]

        display = RocCurveDisplay.from_predictions(
            y_onehot_tes[:, class_id],
            y_pred_tes[:, class_id],
            name=f"{label_target} vs the rest",
            color=color,
            ax=ax,
            plot_chance_level=(class_id == 0)
        )

    _ = display.ax_.set(
        xlabel="False Positive Rate",
        ylabel="True Positive Rate",
        title="ROC: One-vs-Rest",
    )

    if save_file:
        file_final = createFileName(file_prefix=file_prefix,
                                    file_location=file_location,
                                    file_suffix=file_suffix,
                                    useuid=True)

        plt.savefig(file_final, bbox_inches='tight')

In [None]:
def plotROCMulticlassOnevsone(y_tra, y_tes, y_pred_tes, label_target_list,
                              color_list, save_file=False,
                              file_prefix="prediction_histogram",
                              file_location="./",
                              file_suffix=".png"):
    """Plot Receiver Operator Curve for one-vs-one scenario

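    Note: although named for the one-vs-one scenario, this function
    currently computes and labels one-vs-rest curves.

    Parameters
    ----------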
    y_tra: `numpy.ndarray`
       training data true labels
    y_tes: `numpy.ndarray`
       test data true labels
    y_pred_tes: `numpy.ndarray`
       test data predicted labels
    label_target_list: 'list'
    color_list: 'list'
    file_prefix: `string`, optional
       prefix of file name
    file_location: `string`, optional
       path to file
    file_suffix: `string`, optional
       suffix/extension of file name

    Returns
    -------
    None
    """
    fig, ax = plt.subplots(figsize=(6, 6))

    for label_target, color in zip(label_target_list, color_list):

        label_binarizer = LabelBinarizer().fit(y_tra)
        y_onehot_tes = label_binarizer.transform(y_tes)

        class_id = np.flatnonzero(label_binarizer.classes_ == label_target)[0]

        display = RocCurveDisplay.from_predictions(
            y_onehot_tes[:, class_id],
            y_pred_tes[:, class_id],
            name=f"{label_target} vs the rest",
            color=color,
            ax=ax,
            plot_chance_level=(class_id == 0)
        )

    _ = display.ax_.set(
        xlabel="False Positive Rate",
        ylabel="True Positive Rate",
        title="ROC: One-vs-Rest",
    )

    if save_file:
        file_final = createFileName(file_prefix=file_prefix,
                                    file_location=file_location,
                                    file_suffix=file_suffix,
                                    useuid=True)

        plt.savefig(file_final, bbox_inches='tight')

In [None]:
def plotArrayHistogramExamples(x_tra, y_tra, num=10,
                               save_file=False,
                               file_prefix="prediction_histogram",
                               file_location="./",
                               file_suffix=".png"):
    """Plot histograms of image pixel values

    Parameters
    ----------
    x_tra: `numpy.ndarray`
       training image data
    y_tra: `numpy.ndarray`
       training label data
    num: `int`, optional
       number of examples to show
    file_prefix: 'string', optional
       prefix of file name
    file_location: 'string', optional
       path to file
    file_suffix: 'string', optional
       suffix/extension of file name

    Returns
    -------
    None
    """
    n_bins = 10
    num_row = 2
    num_col = 5
    images = x_tra[:num]
    labels = y_tra[:num]

    fig, axes = plt.subplots(num_row, num_col,
                             figsize=(1.5*num_col, 2*num_row))

    for i in range(num):
        ax = axes[i//num_col, i%num_col]
        # flatten the 2D image so that all pixels enter a single histogram
        ax.hist(images[i].ravel(), bins=n_bins)
        ax.set_title('Label: {}'.format(labels[i]))

    plt.tight_layout()

    if save_file:
        file_final = createFileName(file_prefix=file_prefix,
                                    file_location=file_location,
                                    file_suffix=file_suffix,
                                    useuid=True)

        plt.savefig(file_final, bbox_inches='tight')

    plt.show()

In [None]:
def plotPredictionHistogram(y_prediction_a, y_prediction_b=None,
                            y_prediction_c=None, n_classes=None,
                            n_objects_a=None, n_colors=None,
                            title_a=None, title_b=None,
                            title_c=None, label_a=None,
                            label_b=None, label_c=None,
                            alpha=0.5, figsize=(12, 5),
                            save_file=False,
                            file_prefix="prediction_histogram",
                            file_location="./",
                            file_suffix=".png"):
    """Plot histogram of predicted labels

    Parameters
    ----------
    y_prediction_a: `numpy.ndarray`
    y_prediction_b: `numpy.ndarray`, optional
    y_prediction_c: `numpy.ndarray`, optional
    n_classes: `int`, optional
    n_objects_a: `int`, optional
    n_colors: `int`, optional
    title_a: `string`, optional
    title_b: `string`, optional
    title_c: `string`, optional
    label_a: `string`, optional
    label_b: `string`, optional
    label_c: `string`, optional
    alpha: `float`, optional
       transparency
    figsize: `tuple`, optional
       figure size
    file_prefix: `string`, optional
       prefix of file name
    file_location: `string`, optional
       path to file
    file_suffix: `string`, optional
       suffix/extension of file name

    Returns
    -------
    None
    """
    ndim = y_prediction_a.ndim

    if ndim == 2:
        fig, (axa, axb, axc) = plt.subplots(1, 3, figsize=figsize)
        fig.subplots_adjust(wspace=0.35)
    elif ndim == 1:
        fig, ax = plt.subplots(figsize=figsize)

    shape_a = np.shape(y_prediction_a)

    if n_objects_a is None:
        n_objects_a = shape_a[0]

    if ndim == 2:
        if n_classes is None:
            n_classes = shape_a[1]
        if n_colors is None:
            n_colors = n_classes
    elif ndim == 1:
        if n_colors is None:
            n_colors = 1

    if ndim == 2:
        colors = cm.Purples(np.linspace(0, 1, n_colors))
        xlabel = "Probability for Each Class"

        axa.set_ylim(0, n_objects_a)
        axa.set_xlabel(xlabel)
        axa.set_title(title_a)

        for i in np.arange(n_classes):
            axa.hist(y_prediction_a[:, i], alpha=alpha,
                     color=colors[i], label="'" + str(i) + "'")

        if y_prediction_b is not None:
            shape_b = np.shape(y_prediction_b)
            axb.set_ylim(0, shape_b[0])
            axb.set_xlabel(xlabel)
            axb.set_title(title_b)

            for i in np.arange(n_classes):
                axb.hist(y_prediction_b[:, i], alpha=alpha,
                         color=colors[i], label="'" + str(i) + "'")

        if y_prediction_c is not None:
            shape_c = np.shape(y_prediction_c)
            axc.set_ylim(0, shape_c[0])
            axc.set_xlabel(xlabel)
            axc.set_title(title_c)

            for i in np.arange(n_classes):
                axc.hist(y_prediction_c[:, i], alpha=alpha,
                         color=colors[i], label="'" + str(i) + "'")

    elif ndim == 1:
        ya, xa, _ = plt.hist(y_prediction_a, alpha=alpha, color='purple',
                             label=label_a)
        y_max_list = [max(ya)]

        if y_prediction_b is not None:
            yb, xb, _ = plt.hist(y_prediction_b, alpha=alpha, color='blue',
                                 label=label_b)
            y_max_list.append(max(yb))

        if y_prediction_c is not None:
            yc, xc, _ = plt.hist(y_prediction_c, alpha=alpha, color='green',
                                 label=label_c)
            y_max_list.append(max(yc))

        plt.ylim(0, np.max(y_max_list)*1.1)
        plt.xlabel("Top Choice-Class")

    plt.legend(loc='upper right')

    if save_file:
        file_final = createFileName(file_prefix=file_prefix,
                                    file_location=file_location,
                                    file_suffix=file_suffix,
                                    useuid=True)
        plt.savefig(file_final, bbox_inches='tight')

    plt.show()

In [None]:
def plotLossHistory(history, figsize=(8, 5),
                    save_file=False,
                    file_prefix="prediction_histogram",
                    file_location="./",
                    file_suffix=".png"):
    """Plot loss history of the model as function of epoch

    Parameters
    ----------
    history: `keras.src.callbacks.history.History`
       keras callback history object containing the losses at each epoch
    figsize: `tuple`, optional
       figure size
    file_prefix: `string`, optional
       prefix of file name
    file_location: `string`, optional
       path to file
    file_suffix: `string`, optional
       suffix/extension of file name

    Returns
    -------
    None
    """
    fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=figsize)

    loss_tra = np.array(history.history['loss'])
    loss_val = np.array(history.history['val_loss'])
    loss_dif = loss_val - loss_tra

    ax1.plot(loss_tra, label='Training')
    ax1.plot(loss_val, label='Validation')
    ax1.legend()

    ax2.plot(loss_dif, color='red', label='residual')
    ax2.axhline(y=0, color='grey', linestyle='dashed', label='zero bias')
    ax2.sharex(ax1)
    ax2.legend()

    ax1.set_title('Loss History')
    ax1.set_ylabel('Loss')
    ax2.set_ylabel('Loss Residual')
    ax2.set_xlabel('Epoch')
    plt.tight_layout()

    if save_file:
        file_final = createFileName(file_prefix=file_prefix,
                                    file_location=file_location,
                                    file_suffix=file_suffix,
                                    useuid=True)

        plt.savefig(file_final, bbox_inches='tight')

    plt.show()

In [None]:
def plotConfusionMatrix(cm_tra, cm_val, cm_tes, save_file=False,
                        file_prefix="prediction_histogram",
                        file_location="./",
                        file_suffix=".png"):
    """Plot the confusion matrix of predictions.

    Parameters
    ----------
    cm_tra: `numpy.ndarray`
       confusion matrix for the training data
    cm_val: `numpy.ndarray`
       confusion matrix for the validation data
    cm_tes: `numpy.ndarray`
       confusion matrix for the test data
    save_file: `bool`, optional
       choose to save the figure
    file_prefix: `string`, optional
       prefix of file name
    file_location: `string`, optional
       path to file
    file_suffix: `string`, optional
       suffix/extension of file name

    Returns
    -------
    None
    """

    cm_display_tra = ConfusionMatrixDisplay(confusion_matrix=cm_tra)
    cm_display_val = ConfusionMatrixDisplay(confusion_matrix=cm_val)
    cm_display_tes = ConfusionMatrixDisplay(confusion_matrix=cm_tes)

    fig, (axa, axb, axc) = plt.subplots(1, 3, figsize=(22, 5))

    cm_display_tra.plot(ax=axa)
    cm_display_val.plot(ax=axb)
    cm_display_tes.plot(ax=axc)

    axa.set_title("Training")
    axb.set_title("Validation")
    axc.set_title("Testing")

    if save_file:
        file_final = createFileName(file_prefix=file_prefix,
                                    file_location=file_location,
                                    file_suffix=file_suffix,
                                    useuid=True)

        plt.savefig(file_final, bbox_inches='tight')

    plt.show()

In [None]:
def plotArrayImageConfusion(x_tra, y_tra, y_pred_tra_topchoice,
                            title_main=None, num=10,
                            save_file=False,
                            file_prefix="prediction_histogram",
                            file_location="./",
                            file_suffix=".png"):
    """Plot images of example objects that are misclassified.

    Parameters
    ----------
    x_tra: `numpy.ndarray`
       training image data
    y_tra: `numpy.ndarray`
       training label data
    y_pred_tra_topchoice: `numpy.ndarray`
       top choice of the predicted labels
    title_main: `string`, optional
       title for the plot
    num: `int`, optional
       number of examples
    file_prefix: `string`, optional
       prefix of file name
    file_location: `string`, optional
       path to file
    file_suffix: `string`, optional
       suffix/extension of file name

    Returns
    -------
    None
    """
    num_row = 2
    num_col = 5
    images = x_tra[:num]
    labels_true = y_tra[:num]
    labels_pred = y_pred_tra_topchoice[:num]

    fig, axes = plt.subplots(num_row, num_col,
                             figsize=(1.5*num_col, 2*num_row))

    fig.patch.set_linewidth(5)
    fig.patch.set_edgecolor('cornflowerblue')

    for i in range(num):
        ax = axes[i//num_col, i%num_col]
        ax.imshow(images[i], cmap='gray')
        ax.set_title(r'True: {}'.format(labels_true[i]) + '\n'
                     + 'Pred: {}'.format(labels_pred[i]))

    fig.suptitle(title_main)
    plt.tight_layout()

    if save_file:
        file_final = createFileName(file_prefix=file_prefix,
                                    file_location=file_location,
                                    file_suffix=file_suffix,
                                    useuid=True)

        plt.savefig(file_final, bbox_inches='tight')

    plt.show()

In [None]:
def plotArrayHistogramConfusion(x_tra, y_tra, y_pred_tra_topchoice,
                                title_main=None, num=10,
                                save_file=False,
                                file_prefix="prediction_histogram",
                                file_location="./",
                                file_suffix=".png"):
    """Plot histograms of pixel values for images that are misclassified.

    Parameters
    ----------
    x_tra: `numpy.ndarray`
       training image data
    y_tra: `numpy.ndarray`
       training label data
    y_pred_tra_topchoice: `numpy.ndarray`
       top choice of the predicted labels
    title_main: `string`, optional
       title of plot
    num: `int`, optional
       number of examples
    file_prefix: `string`, optional
       prefix of file name
    file_location: `string`, optional
       path to file
    file_suffix: `string`, optional
       suffix/extension of file name

    Returns
    -------
    None
    """
    n_bins = 10
    num_row = 2
    num_col = 5
    images = x_tra[:num]
    labels_true = y_tra[:num]
    labels_pred = y_pred_tra_topchoice[:num]

    fig, axes = plt.subplots(num_row, num_col,
                             figsize=(1.5*num_col, 2*num_row))

    fig.patch.set_linewidth(5)
    fig.patch.set_edgecolor('cornflowerblue')

    for i in range(num):
        ax = axes[i//num_col, i%num_col]
        # flatten the 2D image so that all pixels enter a single histogram
        ax.hist(images[i].ravel(), bins=n_bins)
        ax.set_title(r'True: {}'.format(labels_true[i]) + '\n'
                     + 'Pred: {}'.format(labels_pred[i]))

    fig.suptitle(title_main)
    plt.tight_layout()

    if save_file:
        file_final = createFileName(file_prefix=file_prefix,
                                    file_location=file_location,
                                    file_suffix=file_suffix,
                                    useuid=True)

        plt.savefig(file_final, bbox_inches='tight')

    plt.show()

### 1.3 Define Paths for Data and Plots

Neural network training (i.e., model fitting) typically requires many numerical experiments to arrive at a good model. To facilitate the comparison of these experiments/models, it is helpful to organize data carefully. We set paths for the model weight parameters and diagnostic figures, and we set the variable `run_label` to identify each training run. We save these settings in a dictionary to facilitate passing information to plotting functions.

In [None]:
run_label = "Run000"

path_dict = {'run_label': run_label,
             'dir_data_model': "Data/Models/",
             'dir_data_figures': "Data/Figures/",
             'file_model_prefix': "Model",
             'file_figure_prefix': "Figure",
             'file_figure_suffix': ".png",
             'file_model_suffix': ".keras"
             }

# create the output directories if they do not already exist
os.makedirs(path_dict['dir_data_model'], exist_ok=True)
os.makedirs(path_dict['dir_data_figures'], exist_ok=True)

## 2. Load and Prepare data: MNIST Handwritten Digits

### 2.1. Download Dataset

The [`MNIST handwritten digits dataset`](https://ieeexplore.ieee.org/document/6296535) comprises 10 classes, one for each digit. MNIST is one of a few canonical AI benchmark data sets for image classification, which makes it useful for learning the basics of neural networks and other AI algorithms. `tensorflow` provides a simple function that downloads the MNIST data to your local server for free. It automatically returns the data split into training and test sets.

The **input** data are held in `x_`, while the **output** (aka, label) data are held in `y_`.

In [None]:
mnist = tf.keras.datasets.mnist
train_temp, test_temp = mnist.load_data()

### 2.2. Split Data into Train/Validation/Test

It is essential to split the data into separate sets for a proper 'blind' analysis and optimization of an AI model.

There are three primary data sets used in model optimization:

* **Training** (`_tra`) data is used directly by the algorithm to update the parameters of the AI model -- e.g., the weights on the edges between computational neurons in a neural network.
* **Validation** (`_val`) data is used indirectly to update the hyperparameters of the AI model -- e.g., the batch size, the learning rate, or the layers in the architecture of a neural network. Each time the neural network has completed training with the training data, a human compares diagnostics computed on the training data and on the validation data, and adjusts the hyperparameters accordingly.
* **Test(ing)** (`_tes`) data is only used when the model is trained and validated and will no longer be updated or further trained.

`tensorflow` automatically provides the data already split into training and test data sets. Therefore, we use the `sklearn` `train_test_split()` function to further split the training set into training and validation data sets.


In [None]:
fraction_validation = 0.25

In [None]:
# set the test data sets from the temp data at read-in
x_tes, y_tes = test_temp[0], test_temp[1]

In [None]:
# set the training and validation data sets from the temp data at read-in
# use the sklearn train_test_split function
x_tra, x_val, y_tra, y_val = train_test_split(train_temp[0], train_temp[1],
                                              test_size=fraction_validation,
                                              random_state=1)
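
The sizes that result from this split can be checked with a small stand-in. The sketch below is a NumPy-only illustration of shuffling indices and slicing off a validation fraction; it is not how `train_test_split` is implemented. It assumes the 60,000-image MNIST training set, and the names `n_objects`, `rng`, `indices_tra` and `indices_val` are introduced only for this example.

```python
import numpy as np

# Assumption: MNIST ships 60,000 training images; we split only their indices.
n_objects = 60000
fraction_validation = 0.25

# Shuffle the indices reproducibly (hypothetical seed, unrelated to sklearn's).
rng = np.random.default_rng(1)
indices = rng.permutation(n_objects)

# Slice off the validation fraction; the remainder is used for training.
n_val = int(round(fraction_validation * n_objects))
indices_val = indices[:n_val]
indices_tra = indices[n_val:]

print("training objects:", indices_tra.size)
print("validation objects:", indices_val.size)

# No object should appear in both sets.
overlap = np.intersect1d(indices_tra, indices_val)
print("objects shared between sets:", overlap.size)
```

Checking that the two sets do not overlap is a quick way to confirm that the validation data remain 'blind' to the training step.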

### 2.3. Normalize data

First, we make sure that the input data are floats. This allows us to perform computations on the real number line for the inputs.

Second, we normalize the data according to the maximum value in all the data sets. The inputs will all exist on a smaller range. This improves the stability of the training.

In [None]:
# set to floats
x_tra = x_tra.astype('float32')
x_val = x_val.astype('float32')
x_tes = x_tes.astype('float32')

# calculate min and max across all input images
input_minimum = np.min([np.min(x_tra), np.min(x_val), np.min(x_tes)])
input_maximum = np.max([np.max(x_tra), np.max(x_val), np.max(x_tes)])

print("Before")
print("min/max", np.min(x_tra), np.max(x_tra))

x_tra = normalizeInputs(x_tra, input_minimum, input_maximum)
x_val = normalizeInputs(x_val, input_minimum, input_maximum)
x_tes = normalizeInputs(x_tes, input_minimum, input_maximum)

print("After")
print("min/max", np.min(x_tra), np.max(x_tra))

# get shapes
image_shape = x_tra[0, :, :].shape
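
As a minimal sketch of what normalization does, the example below applies min-max scaling to a tiny stand-in array. The `pixels` values are hypothetical and chosen only to span the 8-bit range (0 to 255) of MNIST pixels; the variable names are introduced only for this example.

```python
import numpy as np

# Hypothetical 2x3 "image" with 8-bit pixel values, as in MNIST.
pixels = np.array([[0, 51, 102], [153, 204, 255]], dtype='float32')

pixel_min = pixels.min()
pixel_max = pixels.max()

# Min-max normalization maps the smallest value to 0 and the largest to 1.
pixels_norm = (pixels - pixel_min) / (pixel_max - pixel_min)

print("before:", pixels.min(), pixels.max())
print("after: ", pixels_norm.min(), pixels_norm.max())
```

When the minimum is 0, as it is for MNIST, this is the same as dividing by the maximum.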

### 2.4. Examine Raw Data

Review data shapes. 

The zeroth elements of the `x` and `y` shapes should match: this is the number of objects. The first and second elements of each `x` shape are the dimensions of the images in pixels; for MNIST they are equal because the images are square. The image size in part determines the depth of the neural network that can be created.

Print the data shapes to make sure you understand how many objects there are and what the number of pixels is for each image.

In [None]:
print('check data shapes')
print('x_train:', x_tra.shape)
print('y_train:', y_tra.shape)
print('x_valid:', x_val.shape)
print('y_valid:', y_val.shape)
print('x_test:', x_tes.shape)
print('y_test:', y_tes.shape)

Plot examples to gain visual familiarity. Do these all look like hand-written digits?

In [None]:
file_prefix = path_dict['file_figure_prefix'] + "_"\
                                              + "Example_Image_Array"\
                                              + "_" + path_dict['run_label']
plotArrayImageExamples(x_tra, y_tra,
                       file_prefix=file_prefix,
                       file_location=path_dict['dir_data_figures'],
                       file_suffix=path_dict['file_figure_suffix'])

Plot pixel distributions to further understand the data. Are they normalized? Do the distributions of the pixel values make sense according to what you see in the related images above?


In [None]:
file_prefix = path_dict['file_figure_prefix'] + "_"\
                                              + "Example_Histogram_Array"\
                                              + "_" + path_dict['run_label']

plotArrayHistogramExamples(x_tra, y_tra,
                           num=10,
                           file_prefix=file_prefix,
                           file_location=path_dict['dir_data_figures'],
                           file_suffix=path_dict['file_figure_suffix'])

## 3. Train Model: Dense Neural Network

### 3.1. Define Model Training Parameters

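The training settings defined in the next cell are:
* **optimizer**: the algorithm that updates the model weights from the gradients of the loss; `"sgd"` is stochastic gradient descent.
* **loss**: the function that measures the mismatch between predicted and true labels; it is the quantity minimized during training.
* **metrics**: quantities reported during training and evaluation but not used to update the weights; here, the accuracy (the fraction of objects classified correctly).
* **batch_size**: the number of training objects used for each weight update.
* **epochs**: the number of complete passes through the training data.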

In [None]:
epochs = 10
batch_size = 32
verbose = True
optimizer = "sgd"
loss = tf.keras.losses.SparseCategoricalCrossentropy()
metrics = ['accuracy']
dropout_rate = 0.3
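# note: the string "sgd" selects the optimizer with its Keras default
# settings; the two values below take effect only if they are passed
# explicitly to an optimizer object
learning_rate = 0.01
momentum = 0.9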
seed = 1000
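
To make the roles of `batch_size` and `epochs` concrete, the sketch below counts how many weight updates a training run performs. It assumes that 45,000 training images remain after a 25% validation split of the 60,000-image MNIST training set; `n_training`, `updates_per_epoch` and `updates_total` are names introduced only for this example.

```python
import math

# Assumption: 45,000 training images remain after a 25% validation split
# of the 60,000-image MNIST training set.
n_training = 45000
batch_size = 32
epochs = 10

# One weight update per batch; the last batch of an epoch may be smaller.
updates_per_epoch = math.ceil(n_training / batch_size)
updates_total = updates_per_epoch * epochs

print("updates per epoch:", updates_per_epoch)
print("updates in total:", updates_total)
```

A smaller batch size gives more updates per epoch, each based on fewer objects.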

Set the random seed for neural network weight initialization

In [None]:
tf.keras.utils.set_random_seed(seed)

### 3.2. Define Model

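The model defined in the next cell uses the following components:
* **Sequential model**: a model in which the layers are applied one after another, each taking the output of the previous layer as its input.
* **Layers**: the building blocks of the network; each layer transforms its input array into an output array.
* **Flatten layer**: reshapes each two-dimensional image into a one-dimensional vector so that it can be passed to a dense layer.
* **Dense layers**: layers in which every output neuron is connected to every input value.
* **Activation function**: a nonlinear function applied to the output of each neuron. The sigmoid function maps values to the range (0, 1); the rectified linear unit (relu) sets negative values to zero and leaves positive values unchanged.
* **Weights and biases**: the trainable parameters of each dense layer; each neuron computes a weighted sum of its inputs plus a bias.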

In [None]:
model_layers = [tf.keras.layers.Input(shape=image_shape),
                tf.keras.layers.Flatten(),
                tf.keras.layers.Dense(256, activation='sigmoid'),
                tf.keras.layers.Dense(64, activation='sigmoid'),
                tf.keras.layers.Dropout(dropout_rate),
                tf.keras.layers.Dense(10, activation='softmax')]

model = tf.keras.models.Sequential(model_layers)
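
The activation functions named in the model can be written in a few lines. The sketch below is a NumPy illustration of the standard formulas, not the `tensorflow` implementation; the `scores` values are hypothetical.

```python
import numpy as np


def sigmoid(z):
    """Squash each value into the range (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))


def relu(z):
    """Pass positive values through and set negative values to zero."""
    return np.maximum(z, 0.0)


def softmax(z):
    """Convert a vector of scores into probabilities that sum to one."""
    z_shifted = z - np.max(z)  # subtracting the maximum avoids overflow
    exp_z = np.exp(z_shifted)
    return exp_z / np.sum(exp_z)


scores = np.array([-2.0, 0.0, 3.0])  # hypothetical layer outputs
print("sigmoid:", sigmoid(scores))
print("relu:   ", relu(scores))
print("softmax:", softmax(scores), "sum =", softmax(scores).sum())
```

The softmax output of the final layer is what allows the ten outputs to be read as class probabilities.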

View a summary of the network architecture. Examine the shapes of the layers and the numbers of parameters. Too few parameters may prevent the model from being flexible enough to model the data. Too many parameters could lead to overfitting of the model and a high computational cost.

In [None]:
model.summary()
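
The parameter counts in the summary can be reproduced by hand. The sketch below assumes the 28x28-pixel MNIST images and the dense layer widths used above (256, 64, 10); the variable names are introduced only for this example.

```python
# Layer widths taken from the model above: 28x28 inputs, then 256, 64, 10.
n_inputs = 28 * 28
layer_widths = [256, 64, 10]

n_parameters_total = 0
n_previous = n_inputs
for width in layer_widths:
    # Each dense layer has one weight per input-output pair plus one bias
    # per output. Flatten and Dropout layers have no trainable parameters.
    n_parameters = n_previous * width + width
    print("Dense({}): {} parameters".format(width, n_parameters))
    n_parameters_total += n_parameters
    n_previous = width

print("total:", n_parameters_total)
```

Most of the parameters sit in the first dense layer, because it connects every pixel to every one of its neurons.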

### 3.3. Compile and Train Model

Compile the model with the model settings created earlier.

In [None]:
model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
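
To illustrate what the loss measures, the sketch below computes a cross-entropy for integer ("sparse") labels with NumPy. It follows the standard formula, the mean negative log-probability of the true class, and is not the `tensorflow` implementation, which may differ in details such as numerical clipping. The probabilities and labels are hypothetical.

```python
import numpy as np

# Hypothetical softmax outputs for 3 objects and 4 classes (rows sum to 1).
y_prob = np.array([[0.70, 0.10, 0.10, 0.10],
                   [0.25, 0.25, 0.25, 0.25],
                   [0.05, 0.05, 0.10, 0.80]])
# Integer ("sparse") labels: one class index per object.
y_true = np.array([0, 2, 3])

# Pick out the probability assigned to the true class of each object.
prob_true = y_prob[np.arange(y_true.size), y_true]

# Cross-entropy is the mean negative log-probability of the true class.
loss_value = -np.mean(np.log(prob_true))
print("probability of true class:", prob_true)
print("loss:", loss_value)
```

The loss approaches zero only when the model assigns a probability close to one to the true class of every object.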

Train (fit) the model. The output `history` contains the loss value of the training data and the validation data for each epoch.

In [None]:
history = model.fit(x_tra, y_tra,
                    batch_size=batch_size,
                    epochs=epochs,
                    validation_data=(x_val, y_val),
                    verbose=True)
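
The loss values stored in `history.history` can be inspected directly. The sketch below uses a stand-in dictionary with the same `'loss'` and `'val_loss'` keys; the numbers are made up for illustration and are not results from this notebook.

```python
import numpy as np

# Made-up loss values for illustration only; these are not results from
# this notebook. `history.history` is a dictionary with this layout.
history_dict = {'loss': [0.90, 0.50, 0.35, 0.28, 0.22],
                'val_loss': [0.80, 0.48, 0.38, 0.36, 0.37]}

loss_tra = np.array(history_dict['loss'])
loss_val = np.array(history_dict['val_loss'])

# Difference between validation and training loss at each epoch.
loss_dif = loss_val - loss_tra

# Epoch (counting from zero) at which the validation loss is lowest.
epoch_best = int(np.argmin(loss_val))

print("loss difference per epoch:", loss_dif)
print("epoch with lowest validation loss:", epoch_best)
```

A validation loss that stops decreasing while the training loss keeps falling is a common sign of overfitting.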

Save the model as a `.keras` zip archive so that it can be used later -- e.g., for comparison to other models.

In [None]:
file_prefix = path_dict['file_model_prefix'] + "_" + path_dict['run_label']
file_name_final = createFileName(file_prefix=file_prefix,
                                 file_location=path_dict['dir_data_model'],
                                 file_suffix=path_dict['file_model_suffix'],
                                 useuid=True,
                                 verbose=True)

model.save(file_name_final)

## 4. Diagnosing the Results of the Classification Model

### 4.1. Key Terms for Diagnostic Metrics

We use the following diagnostics to assess the status of the network optimization and efficacy. See also the [`sklearn` guide to model evaluation](https://scikit-learn.org/stable/modules/model_evaluation.html).


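* **Metrics**
    * Loss: The value of the loss function that the optimizer minimizes during training; lower is better.
    * Accuracy: Use as a rough indicator of model training progress/convergence for balanced datasets. For model performance, use only in combination with other metrics. Avoid for imbalanced datasets; consider another metric instead.
    * True positive rate (tpr, recall): Use when false negatives are more expensive than false positives.
    * False positive rate (fpr): Use when false positives are more expensive than false negatives.
    * Precision: Use when it's very important for positive predictions to be accurate.
* **Generalization Error**: The Generalization Error (GE) is the difference in loss when the model is applied to training data versus when it is applied to validation data and test data.
* **Confusion Matrix**: A table that counts, for every pair of true class and predicted class, the number of objects with that combination. Correct classifications lie on the diagonal.
* **Receiver Operating Characteristic (ROC) Curve**: The true positive rate plotted against the false positive rate as the classification threshold is varied.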
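
The metrics listed above can be computed from the counts of true positives, false positives, true negatives, and false negatives for one class. The counts in this sketch are made-up numbers chosen for illustration only.

```python
# Hypothetical counts for one class treated as "positive" (made-up numbers).
tp, fp, tn, fn = 90, 5, 895, 10

accuracy = (tp + tn) / (tp + fp + tn + fn)   # fraction classified correctly
recall = tp / (tp + fn)                      # true positive rate (tpr)
fpr = fp / (fp + tn)                         # false positive rate
precision = tp / (tp + fp)                   # purity of positive predictions

print(f"accuracy:  {accuracy:.3f}")
print(f"recall:    {recall:.3f}")
print(f"fpr:       {fpr:.4f}")
print(f"precision: {precision:.3f}")
```

Note that the accuracy is high here mostly because true negatives dominate the sample, which is why accuracy alone can mislead for imbalanced data.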


### 4.2. Classification Predictions

Predict classification probabilities on the training, validation and test sets.

In [None]:
y_pred_tra = model.predict(x_tra, verbose=True)
y_pred_val = model.predict(x_val, verbose=True)
y_pred_tes = model.predict(x_tes, verbose=True)

Identify what the top-choice class is for each object in the training, validation and test sets.

In [None]:
y_pred_tra_topchoice = y_pred_tra.argmax(axis=1)
y_pred_val_topchoice = y_pred_val.argmax(axis=1)
y_pred_tes_topchoice = y_pred_tes.argmax(axis=1)
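
To illustrate what `argmax(axis=1)` does, the sketch below applies it to a small, hypothetical array of class probabilities (3 objects, 4 classes). The numbers are invented for the example.

```python
import numpy as np

# Hypothetical probabilities for 3 objects and 4 classes; each row sums to 1.
y_prob = np.array([[0.10, 0.70, 0.15, 0.05],
                   [0.25, 0.25, 0.40, 0.10],
                   [0.90, 0.05, 0.03, 0.02]])

top_choice = y_prob.argmax(axis=1)   # index of the largest value in each row
top_prob = y_prob.max(axis=1)        # probability assigned to that class

for i, (c, p) in enumerate(zip(top_choice, top_prob)):
    print(f"object {i}: top-choice class {c} with probability {p:.2f}")
```

The top-choice class discards how confident the model is: the second object is assigned class 2 with a probability of only 0.40.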

In [None]:
print("10 probabilities for each object:", np.shape(y_pred_tra))
print("Top choice for each object:", np.shape(y_pred_tra_topchoice))

Plot histograms of the top-choice classes and of the class probabilities for the training, validation, and test sets.

In [None]:
file_prefix = path_dict['file_figure_prefix'] + "_"\
    + "Histograms_top_choice"\
    + "_" + path_dict['run_label']
plotPredictionHistogram(y_pred_tra_topchoice,
                        y_prediction_b=y_pred_val_topchoice,
                        y_prediction_c=y_pred_tes_topchoice,
                        label_a="Training Set",
                        label_b="Validation Set",
                        label_c="Testing Set",
                        figsize=(12, 5),
                        file_prefix=file_prefix,
                        file_location=path_dict['dir_data_figures'],
                        file_suffix=path_dict['file_figure_suffix'])

file_prefix = path_dict['file_figure_prefix'] + "_"\
    + "Histograms_class_probabilities"\
    + "_" + path_dict['run_label']
plotPredictionHistogram(y_pred_tra,
                        y_prediction_b=y_pred_val,
                        y_prediction_c=y_pred_tes,
                        title_a='Training Set',
                        title_b='Validation Set',
                        title_c='Testing Set',
                        figsize=(15, 4),
                        file_prefix=file_prefix,
                        file_location=path_dict['dir_data_figures'],
                        file_suffix=path_dict['file_figure_suffix'])

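Observations about these histograms:
1. The distributions have very similar shapes across the training, validation, and test sets. This is a good sign: it suggests that the model behaves consistently on all three data sets.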

### 4.3. Generalization Error

The primary task in optimizing a network is to minimize the Generalization Error. 

### 4.3.1. Loss History: History of Loss and Accuracy during Training

Plot the loss history for the validation and training sets. We reserve the test set for a 'blind' analysis.

In [None]:
file_prefix = path_dict['file_figure_prefix'] + "_"\
                                              + "LossHistory"\
                                              + "_"\
                                              + path_dict['run_label']
plotLossHistory(history,
                file_prefix=file_prefix,
                file_location=path_dict['dir_data_figures'],
                file_suffix=path_dict['file_figure_suffix'])
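
As a rough numerical companion to the loss-history plot, the gap between the validation loss and the training loss can be tracked per epoch. The loss values in this sketch are made up for illustration; in the notebook they would come from the training history.

```python
import numpy as np

# Hypothetical per-epoch losses (made-up numbers).
loss_tra = np.array([0.90, 0.50, 0.35, 0.30, 0.22])
loss_val = np.array([0.85, 0.55, 0.45, 0.47, 0.52])

gap = loss_val - loss_tra
for epoch, g in enumerate(gap, start=1):
    print(f"epoch {epoch}: validation - training loss = {g:+.2f}")

# A gap that keeps widening while the training loss keeps falling
# is a common sign of overfitting.
gap_is_growing = bool(np.all(np.diff(gap) > 0))
print("Gap grows every epoch:", gap_is_growing)
```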

### 4.3.2. Confusion Matrix: Bias in Trained Model?

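Compute the confusion matrices. The scikit-learn function `confusion_matrix` expects the true labels first and the predicted labels second, so that rows correspond to true classes and columns to predicted classes.

In [None]:
cm_tra = confusion_matrix(y_tra, y_pred_tra_topchoice)
cm_val = confusion_matrix(y_val, y_pred_val_topchoice)
cm_tes = confusion_matrix(y_tes, y_pred_tes_topchoice)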

Plot the confusion matrices for the training, validation, and test samples (left, middle, and right panels, respectively).

In [None]:
plotConfusionMatrix(cm_tra, cm_val, cm_tes)
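
To make the contents of a confusion matrix concrete, the sketch below builds one by hand with `numpy` for a small set of invented labels. It follows the scikit-learn convention, in which rows are true classes and columns are predicted classes.

```python
import numpy as np

# Hypothetical true and predicted labels for 7 objects and 3 classes.
y_true = np.array([0, 0, 1, 1, 2, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0, 2])
n_classes = 3

cm = np.zeros((n_classes, n_classes), dtype=int)
np.add.at(cm, (y_true, y_pred), 1)   # rows: true class, columns: predicted class
print(cm)

# Correct classifications lie on the diagonal.
accuracy = np.trace(cm) / cm.sum()
recall_per_class = np.diag(cm) / cm.sum(axis=1)
print("accuracy:", accuracy)
print("recall per class:", recall_per_class)
```

Large off-diagonal entries show which pairs of classes the model confuses most often.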

### 4.3.3. Investigating Errant Classifications: Look at the examples


Choose a digit (class) to examine. This is a user choice.

In [None]:
class_value = 2

Find all objects that have that class value. 
Obtain indices for the true positives (tp's), false positives (fp's), true negatives (tn's), and false negatives (fn's).

In [None]:
ind_class_tp_tra = np.where((y_tra == class_value)
                            & (y_pred_tra_topchoice == class_value))[0]

ind_class_fp_tra = np.where((y_tra != class_value)
                            & (y_pred_tra_topchoice == class_value))[0]

ind_class_tn_tra = np.where((y_tra != class_value)
                            & (y_pred_tra_topchoice != class_value))[0]

ind_class_fn_tra = np.where((y_tra == class_value)
                            & (y_pred_tra_topchoice != class_value))[0]
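
The same selection logic can be checked on a small, hypothetical set of labels. Every object falls into exactly one of the four categories, so the four counts should add up to the size of the data set.

```python
import numpy as np

# Hypothetical true and predicted labels for 8 objects (made-up values).
y_true = np.array([2, 2, 2, 5, 7, 2, 1, 3])
y_pred = np.array([2, 7, 2, 2, 7, 2, 1, 2])
class_value = 2

is_true = (y_true == class_value)    # object really belongs to the class
is_pred = (y_pred == class_value)    # model assigned the object to the class

n_tp = int(np.sum(is_true & is_pred))
n_fp = int(np.sum(~is_true & is_pred))
n_tn = int(np.sum(~is_true & ~is_pred))
n_fn = int(np.sum(is_true & ~is_pred))

print("tp, fp, tn, fn:", n_tp, n_fp, n_tn, n_fn)
```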

Plot example images of true positives, false positives, true negatives, and false negatives.

In [None]:
file_prefix = path_dict['file_figure_prefix'] + "_"\
    + "ExampleImages_TruePositives_on_class_"\
    + str(class_value) + "_" + path_dict['run_label']
plotArrayImageConfusion(x_tra[ind_class_tp_tra],
                        y_tra[ind_class_tp_tra],
                        y_pred_tra_topchoice[ind_class_tp_tra],
                        title_main="True Positives",
                        num=10,
                        file_prefix=file_prefix,
                        file_location=path_dict['dir_data_figures'],
                        file_suffix=path_dict['file_figure_suffix'])

file_prefix = path_dict['file_figure_prefix'] + "_"\
    + "ExampleImages_FalsePositives_on_class_"\
    + str(class_value) + "_" + path_dict['run_label']
plotArrayImageConfusion(x_tra[ind_class_fp_tra],
                        y_tra[ind_class_fp_tra],
                        y_pred_tra_topchoice[ind_class_fp_tra],
                        title_main="False Positives",
                        num=10,
                        file_prefix=file_prefix,
                        file_location=path_dict['dir_data_figures'],
                        file_suffix=path_dict['file_figure_suffix'])

file_prefix = path_dict['file_figure_prefix'] + "_"\
    + "ExampleImages_TrueNegatives_on_class_"\
    + str(class_value) + "_" + path_dict['run_label']
plotArrayImageConfusion(x_tra[ind_class_tn_tra],
                        y_tra[ind_class_tn_tra],
                        y_pred_tra_topchoice[ind_class_tn_tra],
                        title_main="True Negatives",
                        num=10,
                        file_prefix=file_prefix,
                        file_location=path_dict['dir_data_figures'],
                        file_suffix=path_dict['file_figure_suffix'])

file_prefix = path_dict['file_figure_prefix'] + "_"\
    + "ExampleImages_FalseNegatives_on_class_"\
    + str(class_value) + "_" + path_dict['run_label']
plotArrayImageConfusion(x_tra[ind_class_fn_tra],
                        y_tra[ind_class_fn_tra],
                        y_pred_tra_topchoice[ind_class_fn_tra],
                        title_main="False Negatives",
                        num=10,
                        file_prefix=file_prefix,
                        file_location=path_dict['dir_data_figures'],
                        file_suffix=path_dict['file_figure_suffix'])

Plot histograms of the image pixel values for true positives, false positives, true negatives, and false negatives.

In [None]:
file_prefix = path_dict['file_figure_prefix']\
    + "_" + "ExampleImages_TruePositives_on_class_"\
    + str(class_value) + "_" + path_dict['run_label']
plotArrayHistogramConfusion(x_tra[ind_class_tp_tra],
                            y_tra[ind_class_tp_tra],
                            y_pred_tra_topchoice[ind_class_tp_tra],
                            title_main="True Positives",
                            num=10,
                            file_prefix=file_prefix,
                            file_location=path_dict['dir_data_figures'],
                            file_suffix=path_dict['file_figure_suffix'])

file_prefix = path_dict['file_figure_prefix']\
    + "_" + "ExampleImages_FalsePositives_on_class_"\
    + str(class_value) + "_" + path_dict['run_label']
plotArrayHistogramConfusion(x_tra[ind_class_fp_tra],
                            y_tra[ind_class_fp_tra],
                            y_pred_tra_topchoice[ind_class_fp_tra],
                            title_main="False Positives",
                            num=10,
                            file_prefix=file_prefix,
                            file_location=path_dict['dir_data_figures'],
                            file_suffix=path_dict['file_figure_suffix'])

file_prefix = path_dict['file_figure_prefix']\
    + "_" + "ExampleImages_TrueNegatives_on_class_"\
    + str(class_value) + "_" + path_dict['run_label']
plotArrayHistogramConfusion(x_tra[ind_class_tn_tra],
                            y_tra[ind_class_tn_tra],
                            y_pred_tra_topchoice[ind_class_tn_tra],
                            title_main="True Negatives",
                            num=10,
                            file_prefix=file_prefix,
                            file_location=path_dict['dir_data_figures'],
                            file_suffix=path_dict['file_figure_suffix'])

file_prefix = path_dict['file_figure_prefix']\
    + "_" + "ExampleImages_FalseNegatives_on_class_"\
    + str(class_value) + "_" + path_dict['run_label']
plotArrayHistogramConfusion(x_tra[ind_class_fn_tra],
                            y_tra[ind_class_fn_tra],
                            y_pred_tra_topchoice[ind_class_fn_tra],
                            title_main="False Negatives",
                            num=10,
                            file_prefix=file_prefix,
                            file_location=path_dict['dir_data_figures'],
                            file_suffix=path_dict['file_figure_suffix'])

## 5. Exercises for the Learner

Each time you train a new model, re-run all the diagnostic plots.

1. How do the loss and accuracy histories change when batch size is small or large? Why?
2. Does the NN take more or less time (more or fewer epochs) to converge if the input image data are normalized or not normalized? Why?
3. How does the size of the training set affect the model's accuracy and loss -- keeping the number of epochs the same? Why?
4. How does the random seed for the weight initialization affect the model's accuracy and loss -- keeping the number of epochs the same?
5. Use the `time` module to estimate the time for the model fitting. Record that time. Increase and then decrease the number of weights in the NN by an order of magnitude. Train the NN for each of those models and record the times. How does the number of weights in the neural network affect the training time and model loss and accuracy?
6. Use the `time` module to estimate the time for the model fitting. Record that time. Increase and then decrease the number of layers in the NN. Train the NN for each of those models and record the times. How does the number of layers in the neural network affect the training time and model loss and accuracy?
7. Use the `time` module to estimate the time for the model fitting. Record that time. Add a convolutional layer to the NN. Train the new model and record the time. How does the convolutional layer affect the training time and model loss and accuracy?
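
One possible way to do the timing for the exercises above is sketched here. The helper `timed_call` and the stand-in workload `slow_sum` are hypothetical names introduced for this example; they are not part of the notebook.

```python
import time


def timed_call(func, *args, **kwargs):
    """Run func and return (result, elapsed_seconds)."""
    start = time.perf_counter()
    result = func(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return result, elapsed


# Stand-in workload; in the notebook this would be the model fitting.
def slow_sum(n):
    return sum(range(n))


result, seconds = timed_call(slow_sum, 100_000)
print(f"Elapsed time: {seconds:.4f} s")
```

In the notebook, the same helper could wrap the training call, e.g., `history, seconds = timed_call(model.fit, x_tra, y_tra, batch_size=batch_size, epochs=epochs, validation_data=(x_val, y_val))`.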

## 6. Glossary of neural network terms

1. network weights
2. deep learning
3. machine learning
4. learning
5. activation function
6. pool(ing)
7. convolution
8. layer
9. loss function
10. confusion matrix
11. epoch
12. batch size
13. learning rate
14. momentum
15. stochastic gradient descent
16. optimizer
17. receiver operating characteristic (ROC)
18. area under the curve (AUC)
19. training
20. validation
21. testing
22. class
23. hyperparameter (vs. parameter)

## 7. AI is math, not magic.

AI is firmly based in math, computer science, and statistics. Additionally, some of the approaches are inspired by concepts or notions in biology (e.g., the computational neuron) and in physics (e.g., the restricted Boltzmann machine).

Much of the jargon in AI is anthropomorphic, which can make it appear that something other than math is happening. For example, consider the following list of terms that are very often used in AI -- and what these terms actually mean mathematically.

1. learn $\rightarrow$ fit
2. hallucinate/lie $\rightarrow$ predict incorrectly
3. understand $\rightarrow$ model has converged
4. cheat $\rightarrow$ more efficiently guesses the best weight parameters of the model
5. believe $\rightarrow$ predict/infer based on statistical priors

When we over-anthropomorphize this mathematical tool, we obfuscate how it actually works, and that makes it harder to build and refine models. That is, AI models are not 'learning' or 'understanding'; they are large-parameter models that are being fit to data. The only learning that's happening is what we do with these models.