# WFC3 IR Blob Classification: Modeling and Evaluation

### This notebook walks the user through the modeling and evaluation pipeline for the WFC3 IR Blob Classifier. In this notebook, the user will:

### 3. Train a CNN to classify whether blobs are present in a subframe.
#### We train a convolutional neural network with the following initial hyperparameters:
#### - 2 convolutional layers (1 filter to 8 filters to 16 filters)
#### - 2 fully connected layers (16 * 64 * 64 neurons to 128 neurons to 2 neurons)
#### - 5x5 kernel
#### - 2x2 max pooling at the end of each convolutional layer
#### - Padding of 2 on each feature map, so that a 5x5 convolution leaves the feature map size unchanged; only the max pool that follows reduces it
#### - 15% and 30% dropout regularization
#### - Cross Entropy Loss
#### - Adam optimizer
#### - Batch size of 100
#### - 5 epochs
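
The layer sizes listed above can be checked with a little arithmetic. The sketch below is not the classifier itself: it only tracks feature map sizes, assuming a stride of 1 for the convolutions and the 256x256 subframes used later in this notebook. The helper names `conv_out` and `pool_out` are illustrative and do not come from the utility script.

```python
def conv_out(size, kernel=5, padding=2, stride=1):
    """Width/height of a square feature map after a convolution."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, window=2):
    """Width/height of a square feature map after a non-overlapping max pool."""
    return size // window


size = 256          # subframe width and height (SIZE later in this notebook)
channels = [8, 16]  # filters produced by the two convolutional layers

for n_filters in channels:
    after_conv = conv_out(size)  # 5x5 kernel with padding 2 keeps the size
    size = pool_out(after_conv)  # 2x2 max pool halves it
    print(f'{n_filters} filters: conv -> {after_conv}, pool -> {size}')

# Number of inputs to the first fully connected layer
flattened = channels[-1] * size * size
print('Inputs to the first fully connected layer:', flattened)
```

Under these assumptions the two pooling steps reduce 256 to 64, which is where the 16 * 64 * 64 figure in the list comes from.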

### 4. Evaluate the model's performance.
#### We use loss, accuracy, and confusion matrices as metrics for evaluating the model's performance. In addition, we check incorrect images and plot saliency maps to further investigate how and why the model makes certain predictions.

## Introduction (Pending...)

## Imports
### This notebook also uses functions defined in the complementary Python script, wfc3_ir_blob_class_utils.py. Please read the script's documentation, under Modeling, for more detail about its functionality.

In [None]:
import os
from glob import glob

import numpy as np
from matplotlib import pyplot as plt
from ginga.util.zscale import zscale
from astropy.io import fits

import torch  # used directly below (torch.Tensor, torch.save, torch.load)
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from sklearn import metrics
import seaborn as sns

In [None]:
%run wfc3_ir_blob_class_utils.py

## Load data from data processing notebook

In [None]:
SIZE = 256
PATH = 'saved_generated_datasets/'

In [None]:
# Load/combine non blob and blob sets into one training set
train_nb = np.load('{}training_non_blob.npz'.format(PATH))
train_b = np.load('{}training_blob.npz'.format(PATH))

train_image_set = np.concatenate((train_nb['image_set'], train_b['image_set']))
train_labels = np.concatenate((train_nb['labels'], train_b['labels']))

print('Training Image Set and Label Sizes:', train_image_set.shape, train_labels.shape)

In [None]:
# Load/combine non blob and blob sets into one validation set
val_nb = np.load('{}validation_non_blob.npz'.format(PATH))
val_b = np.load('{}validation_blob.npz'.format(PATH))

val_image_set = np.concatenate((val_nb['image_set'], val_b['image_set']))
val_labels = np.concatenate((val_nb['labels'], val_b['labels']))

print('Validation Image Set and Label Sizes:', val_image_set.shape, val_labels.shape)

In [None]:
# Load/combine non blob and blob sets into one test set
test_nb = np.load('{}test_non_blob.npz'.format(PATH))
test_b = np.load('{}test_blob.npz'.format(PATH))

test_image_set = np.concatenate((test_nb['image_set'], test_b['image_set']))
test_labels = np.concatenate((test_nb['labels'], test_b['labels']))

print('Test Image Set and Label Sizes:', test_image_set.shape, test_labels.shape)

## Check that the data sets contain what they should
### As of now, each data set should have 100 non-blob images followed by 100 blob images.
### Check:
#### - If index < 100, label is 0 and the image shown has no blobs
#### - If index >= 100, label is 1 and the image shown has blobs (they can be faint sometimes...)

In [None]:
random_index = np.random.randint(0, 200, size=3)

fig, axs = plt.subplots(1, 3, figsize=[15,5])
axs[0].imshow(train_image_set[random_index[0]], cmap='Greys')
axs[0].set_title('Train - Index: {}, Label: {}'.format(random_index[0], train_labels[random_index[0]]))
axs[1].imshow(val_image_set[random_index[1]], cmap='Greys')
axs[1].set_title('Val - Index: {}, Label: {}'.format(random_index[1], val_labels[random_index[1]]))
axs[2].imshow(test_image_set[random_index[2]], cmap='Greys')
axs[2].set_title('Test - Index: {}, Label: {}'.format(random_index[2], test_labels[random_index[2]]))
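
The label ordering described above can also be checked programmatically rather than by eye. This is a minimal sketch on stand-in labels; `check_label_order` is a hypothetical helper, not part of the utility script, and it assumes the non-blob images come first as stated.

```python
import numpy as np


def check_label_order(labels, n_non_blob=100):
    """Return True if the first n_non_blob labels are 0 and the rest are 1."""
    labels = np.asarray(labels)
    first_ok = np.all(labels[:n_non_blob] == 0)
    rest_ok = np.all(labels[n_non_blob:] == 1)
    return bool(first_ok and rest_ok)


# Stand-in for train_labels, val_labels, or test_labels
example_labels = np.concatenate((np.zeros(100, dtype=int), np.ones(100, dtype=int)))
print(check_label_order(example_labels))
```

The same call could be applied to `train_labels`, `val_labels`, and `test_labels` once they are loaded. It checks the labels only, so the images still need to be inspected visually.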

## Generate a smaller test set
### Randomly select subframes from a larger test set to make a smaller test set. In this example, the smaller test set is 100 random subframes drawn from the generated test set of 200 subframes.

In [None]:
test_image_set_small, test_labels_small = generate_test_data(test_image_set, test_labels, num=100)

In [None]:
test_image_set_small.shape, test_labels_small.shape
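
For readers without the utility script at hand, random subsampling of this kind can be sketched as follows. `sample_subset` is a hypothetical stand-in, not the script's `generate_test_data`, and it assumes sampling without replacement, which the script may or may not do.

```python
import numpy as np


def sample_subset(images, labels, num, seed=None):
    """Draw `num` image/label pairs at random, without replacement."""
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(images), size=num, replace=False)
    # The same indices are applied to both arrays so pairs stay aligned
    return images[idx], labels[idx]


# Small stand-in arrays: 200 tiny "images" with balanced labels
images = np.zeros((200, 4, 4))
labels = np.repeat([0, 1], 100)

small_images, small_labels = sample_subset(images, labels, num=100, seed=0)
print(small_images.shape, small_labels.shape)
```

Because the draw is random, the subset is not guaranteed to stay balanced between the two classes.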

## Test model functionality
### Before running any models, verify an image can be passed through the model and produce an output. If not, there's a bug somewhere in the model.

In [None]:
model = Classifier()
image = train_image_set[0].reshape(1,1,SIZE,SIZE)
model_output = model(torch.Tensor(image)).detach().numpy()

In [None]:
image.shape, model_output.shape

In [None]:
plt.imshow(image[0, 0], cmap='Greys')
model_output

## Count trainable parameters
### As models become more complex, the number of parameters increases and training takes more time. The default model has approximately 8M parameters.

In [None]:
count_parameters(model)
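
The "approximately 8M" figure can be reproduced by hand from the layer sizes listed at the top of this notebook. The sketch below assumes every layer has a bias term and no other trainable layers exist. The helper names are illustrative only.

```python
def conv_params(in_channels, out_channels, kernel=5):
    """Weights plus biases of a square 2D convolution."""
    return in_channels * out_channels * kernel * kernel + out_channels


def linear_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


layers = {
    'conv1': conv_params(1, 8),
    'conv2': conv_params(8, 16),
    'fc1': linear_params(16 * 64 * 64, 128),
    'fc2': linear_params(128, 2),
}

for name, count in layers.items():
    print(f'{name}: {count:,}')

total = sum(layers.values())
print(f'total: {total:,}')
```

Almost all of the parameters sit in the first fully connected layer, so reducing the flattened feature map size is the most direct way to shrink the model.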

## Establish baseline
### Baseline accuracy = 50%, the expected accuracy of randomly choosing labels when the two classes are balanced.
### If the model cannot outperform the baseline, then the model isn't learning. Try debugging the data and the model if the model's accuracy stays at baseline.
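
The 50% baseline can be illustrated with a quick simulation. This sketch assumes a balanced set of 200 labels, as described for the generated data sets, and averages the accuracy of random guessing over many trials.

```python
import numpy as np

rng = np.random.default_rng(0)

# Balanced stand-in labels: 100 non-blob (0) followed by 100 blob (1)
labels = np.repeat([0, 1], 100)

# 1000 independent rounds of guessing a label for every image
guesses = rng.integers(0, 2, size=(1000, labels.size))
accuracies = (guesses == labels).mean(axis=1)

print('Mean accuracy of random guessing:', accuracies.mean())
print('Spread across trials (std):', accuracies.std())
```

The spread across trials is a reminder that, with only a few hundred images, a single run a few percent above or below 50% is not strong evidence either way.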

## Set hyperparameters
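### A personal rule of thumb is for the batch size to be 1/100th the size of the training set. This allows the model to update itself 100 times each epoch. For the 200-image example training set, that rule gives a batch size of 2; the cell below uses a batch size of 1.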

In [None]:
dataloader_params = {
    'batch_size': 1,
    'shuffle': True,
    'num_workers': 0
    }
num_epochs = 5
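
The relationship between batch size and the number of updates per epoch is simple arithmetic. The sketch below assumes the final partial batch is kept, which is the default behavior of PyTorch's `DataLoader`. `updates_per_epoch` is an illustrative helper.

```python
import math


def updates_per_epoch(n_samples, batch_size):
    """Number of batches, and so optimizer steps, in one pass over the data."""
    return math.ceil(n_samples / batch_size)


n_train = 200  # size of the example training set in this notebook

for batch_size in (1, 2, 100):
    print(f'batch size {batch_size}: {updates_per_epoch(n_train, batch_size)} updates per epoch')
```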

## Train and validate model
### Since there isn't a substantial amount of data, the model may not learn to a high degree (i.e. accuracy may stay around baseline). However, the loss should decrease, which indicates that the model is learning a little and functioning properly!
### In addition, the loss function, confusion matrix, saliency map, and final testing cells below only illustrate how to run the code; their results on this small data set should not be over-interpreted.
### Training should take less than two minutes with the generated data set and default hyperparameters.

In [None]:
build_model_return = build_model(train_image_set, 
                                 train_labels, 
                                 val_image_set, 
                                 val_labels, 
                                 dataloader_params, 
                                 num_epochs)

In [None]:
model, lst_train_loss, lst_val_loss, lst_accuracy = build_model_return

## Plot loss functions

In [None]:
plot_metrics(num_epochs, lst_train_loss, lst_val_loss, lst_accuracy)

## Plot confusion matrix on test data

In [None]:
test_outputs_small, test_predictions_small, cm = confusion_matrix(model, 
                                                                  test_image_set_small, 
                                                                  test_labels_small)
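
For reference, a two-class confusion matrix can be built directly from labels and predictions. This is a minimal sketch on made-up values, using rows for true labels and columns for predicted labels; it is not the script's `confusion_matrix` function, which also runs the model and plots the result.

```python
import numpy as np


def binary_confusion_matrix(y_true, y_pred):
    """2x2 count matrix: rows are true labels, columns are predicted labels."""
    cm = np.zeros((2, 2), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    return cm


# Made-up labels and predictions, for illustration only
y_true = np.array([0, 0, 0, 1, 1, 1])
y_pred = np.array([0, 1, 0, 1, 1, 0])

cm = binary_confusion_matrix(y_true, y_pred)
accuracy = np.trace(cm) / cm.sum()  # correct predictions lie on the diagonal

print(cm)
print('Accuracy:', accuracy)
```

With label 1 meaning "blob", the top-right entry counts false positives (blobs reported where there are none) and the bottom-left entry counts missed blobs.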

## Analyze saliency maps
### Saliency maps show which features the model treats as most important when classifying subframes.

In [None]:
index = np.random.randint(test_labels_small.shape[0])
sal_map = saliency_map(model, test_image_set_small[index], test_labels_small[index], index)
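
One common way to compute a saliency map is to take the gradient of a class score with respect to the input pixels. The sketch below shows that technique on a toy linear model. It is an assumption that the script's `saliency_map` works along these lines; `gradient_saliency` and the toy model are illustrative only.

```python
import numpy as np
import torch
from torch import nn


def gradient_saliency(model, image, label):
    """Absolute gradient of the `label` score with respect to each pixel."""
    model.eval()
    x = torch.as_tensor(image, dtype=torch.float32).reshape(1, 1, *image.shape)
    x.requires_grad_(True)        # track gradients on the input itself
    score = model(x)[0, label]    # score of the class of interest
    score.backward()              # d(score) / d(pixels)
    return x.grad.abs()[0, 0].numpy()


# Toy stand-in for the classifier: flatten an 8x8 image, then one linear layer
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(8 * 8, 2))

toy_image = np.random.default_rng(0).normal(size=(8, 8))
sal = gradient_saliency(toy_model, toy_image, label=1)
print(sal.shape)
```

Pixels with large values are those where a small change would move the class score the most, which is why such maps are read as a rough indication of where the model is "looking".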

## Analyze incorrect images
### By analyzing trends in incorrectly classified images, we can determine what the model is struggling with and make appropriate adjustments to the data and model.

In [None]:
incorrect = check_incorrect_image(test_image_set_small, 
                                  test_labels_small, 
                                  test_outputs_small, 
                                  test_predictions_small)
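
Finding the misclassified subframes amounts to comparing predictions with labels. This is a minimal sketch on made-up arrays, not the script's `check_incorrect_image`; it also splits the errors into false positives and false negatives, which often show different trends.

```python
import numpy as np

# Made-up labels and predictions, for illustration only
labels = np.array([0, 0, 1, 1, 1])
predictions = np.array([0, 1, 1, 0, 1])

# Indices where the prediction disagrees with the label
incorrect_idx = np.flatnonzero(predictions != labels)

# Split the errors by the true label
false_positives = incorrect_idx[labels[incorrect_idx] == 0]  # non-blob called blob
false_negatives = incorrect_idx[labels[incorrect_idx] == 1]  # blob called non-blob

print('Incorrect indices:', incorrect_idx)
print('False positives:', false_positives)
print('False negatives:', false_negatives)
```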

## Final test: check predictions on full subframes

In [None]:
final_test = np.load('saved_generated_datasets/final_test.npz')

In [None]:
final_test_blob = final_test['blob']
final_test_non_blob = final_test['non_blob']
final_test_median = final_test['median']

In [None]:
indices = np.random.randint(0, final_test_blob.shape[0], 3)
sal_map_blob = saliency_map(model, final_test_blob[indices[0]], 1, indices[0])
sal_map_non_blob = saliency_map(model, final_test_non_blob[indices[1]], 0, indices[1])
sal_map_median = saliency_map(model, final_test_median[indices[2]], 1, indices[2])

## Save Model

In [None]:
torch.save(model.state_dict(), 'example_model.torch')

## Complete!

## Appendix: Load Model
### Once the model is saved, it can be loaded in any other notebook that has run the utility script.

In [None]:
%run wfc3_ir_blob_class_utils.py

In [None]:
model = Classifier()
model.load_state_dict(torch.load('example_model.torch'))
model.eval()