# WFC3 Guide Star Failure Classification Using Convolutional Neural Networks (CNNs)
---
The purpose of this notebook is to demonstrate how to use a DeepWFC3 machine learning model to identify whether a WFC3 image is affected by a guide star failure (GS fail). The models presented here are fully described in [WFC3 ISR 2024-03](https://www.stsci.edu/files/live/sites/www/files/home/hst/instrumentation/wfc3/documentation/instrument-science-reports-isrs/_documents/2024/WFC3-ISR-2024-03.pdf).

## Imports

This notebook assumes that you are running it in Jupyter inside the virtual environment defined in `environment.yml`. If you have not created that environment, close this notebook and run the following lines in a terminal window:

```
conda env create -f environment.yml
conda activate deepwfc3_env
```

We import the following libraries:

- `os` for handling paths
- `glob` for querying directories
- `numpy` for handling arrays
- `pandas` for handling dataframes
- `matplotlib` for plotting
- `astropy` for handling astronomical data
- `astroquery` for downloading astronomical data
- `ginga` for image scaling
- `torch` as our machine learning framework

We also import `data_process` for data reduction tasks and `model` to load the models.

In [None]:
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from astropy.io import fits
from astroquery.mast import Observations
from ginga.util.zscale import zscale

# Machine learning module
import torch
from torch import nn

# Data processing and augmentation module
import data_process

# Model module
from model import Model

## Download Example Data

We start by using `astroquery` to download sample images from MAST. The sample includes both nominal and GS fail images.

The rootnames of the guide star failures are:
- ied203fqq
- ie9m0xv1q
- ieou18fkq

The nominal images are:
- iec39axmq
- ientf1gjq
- ie3b36n5q
- ie3b40ljq

First, we build a table of the FITS files containing the image data we are interested in, to make sure that we download the correct files.

In [None]:
# IDs and rootnames of the example images we'll download
example_IDs = ['ied203*', 'ie9m0x020', 'ieou18020', 'iec39a010', 'ientf1010', 'ie3b36010', 'ie3b40010']
rootnames = ['ied203fqq', 'ie9m0xv1q', 'ieou18fkq', 'iec39axmq', 'ientf1gjq', 'ie3b36n5q', 'ie3b40ljq']

example_query = Observations.query_criteria(obs_id=example_IDs)
example_prods = Observations.get_product_list(example_query)
example_table = Observations.filter_products(example_prods, obs_id=rootnames, extension=['_flt.fits'])
example_table
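The check on the table can also be done programmatically by comparing the rootnames it contains against the expected list. The sketch below is a hypothetical helper, not part of `data_process`; it assumes each filename starts with the rootname followed by an underscore, as in `ied203fqq_flt.fits`, and it is demonstrated on a hand-written list of filenames.

```python
def compare_rootnames(filenames, expected_rootnames):
    """Return (missing, unexpected) rootnames for a list of product filenames.

    Hypothetical helper: assumes filenames look like '<rootname>_flt.fits'.
    """
    found = {name.split('_')[0] for name in filenames}
    expected = set(expected_rootnames)
    missing = sorted(expected - found)
    unexpected = sorted(found - expected)
    return missing, unexpected


# Illustration with a hand-written list of filenames
demo_files = ['ied203fqq_flt.fits', 'ie9m0xv1q_flt.fits']
missing, unexpected = compare_rootnames(
    demo_files, ['ied203fqq', 'ie9m0xv1q', 'ieou18fkq']
)
print(missing, unexpected)
```

In this notebook the helper could be called as `compare_rootnames(example_table['productFilename'], rootnames)`, assuming the product table exposes a `productFilename` column. Two empty lists would mean all seven expected files are present and nothing else is.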

Once we have confirmed that the files in the table are the correct images, and that the correct number of them are in the table (7), we download them from MAST.

In [None]:
# Download the images we want to use
downloads = Observations.download_products(example_table, mrp_only=False, cache=False)

downloads

## Process and augment the images

Next, we process and augment the downloaded example images using `data_process`. The function `log_image_process` performs the following procedure:
* Reads the SCI array of the `flt.fits` file.
* Sets all values in the image array that are less than 1 to 1.
* Scales the image data logarithmically.
* Resizes the image array to the dimensions (256, 256).
* Uses min/max scaling to scale the pixel values to be between 0 and 1.
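
The processing steps above can be sketched in NumPy as follows. This is not the `data_process` implementation: the function name `log_process_sketch` is hypothetical, base 10 is assumed for the logarithm, and block averaging stands in for the resize step (whose method is not described here), so the input sides must be multiples of the output size. It runs on a synthetic array rather than a real `flt.fits` file.

```python
import numpy as np


def log_process_sketch(sci, out_size=256):
    """Clip, log-scale, downsample, and min/max-scale a 2-D image array."""
    data = np.asarray(sci, dtype=float)
    if data.shape[0] % out_size or data.shape[1] % out_size:
        raise ValueError('image sides must be multiples of out_size in this sketch')

    # Set every value below 1 to 1 so the logarithm is defined and non-negative
    data = np.maximum(data, 1.0)

    # Logarithmic scaling (base 10 is an assumption)
    data = np.log10(data)

    # Stand-in for the resize step: average over equal-sized blocks
    fy = data.shape[0] // out_size
    fx = data.shape[1] // out_size
    data = data.reshape(out_size, fy, out_size, fx).mean(axis=(1, 3))

    # Min/max scaling to the range [0, 1]
    lo, hi = data.min(), data.max()
    if hi == lo:
        return np.zeros_like(data)
    return (data - lo) / (hi - lo)


# Synthetic 512 x 512 image standing in for a SCI array
rng = np.random.default_rng(0)
fake_sci = rng.normal(loc=50.0, scale=30.0, size=(512, 512))
processed = log_process_sketch(fake_sci)
print(processed.shape, processed.min(), processed.max())
```

In this sketch the choice of logarithm base does not affect the result, because changing the base multiplies every pixel by the same constant and min/max scaling divides that constant out.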

The function `augment` creates an augmented copy of the processed image by:
* Vertically flipping the image with a 50% probability.
* Horizontally flipping the image with a 50% probability.
* Rotating the image by a random angle between 0 and 360 degrees.
* Cropping the center of the image to (180, 180).
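
The augmentation steps can be sketched in the same way. The names `rotate_nearest` and `augment_sketch` are hypothetical, and the nearest-neighbor rotation written out here is only a stand-in: the real `augment` function may rotate and interpolate differently.

```python
import numpy as np


def rotate_nearest(image, angle_deg):
    """Rotate a square 2-D array about its center (nearest neighbor, zero fill)."""
    n = image.shape[0]
    center = (n - 1) / 2.0
    theta = np.deg2rad(angle_deg)
    rows, cols = np.indices(image.shape)
    y = rows - center
    x = cols - center

    # Inverse mapping: for each output pixel, locate the source pixel
    src_x = np.cos(theta) * x + np.sin(theta) * y + center
    src_y = -np.sin(theta) * x + np.cos(theta) * y + center
    src_r = np.rint(src_y).astype(int)
    src_c = np.rint(src_x).astype(int)

    # Pixels whose source falls outside the image stay at zero
    valid = (src_r >= 0) & (src_r < n) & (src_c >= 0) & (src_c < n)
    out = np.zeros_like(image)
    out[valid] = image[src_r[valid], src_c[valid]]
    return out


def augment_sketch(image, rng, crop=180):
    """Random flips, random rotation, then a center crop."""
    if rng.random() < 0.5:
        image = np.flipud(image)   # vertical flip
    if rng.random() < 0.5:
        image = np.fliplr(image)   # horizontal flip
    image = rotate_nearest(image, rng.uniform(0.0, 360.0))
    start = (image.shape[0] - crop) // 2
    return image[start:start + crop, start:start + crop]


rng = np.random.default_rng(0)
augmented = augment_sketch(np.ones((256, 256)), rng)
print(augmented.shape, augmented.min())
```

One property of these sizes is worth noting: the half-diagonal of a 180 x 180 crop is about 127.3 pixels, which is less than the 128-pixel half-width of the 256 x 256 image. The crop therefore contains no zero-filled corners, whatever the rotation angle.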

In [None]:
# Initialize lists of images and rootnames
original_images = []
processed_images = []
augmented_images = []
root = []

# Process the images
for path in downloads['Local Path']:
    # Get the processed and original version of the image
    proc_image, orig_image = data_process.log_image_process(path, True)
    processed_images.append(proc_image)
    original_images.append(orig_image)

    # Get an augmented version of the image
    aug_image = data_process.augment(proc_image)
    augmented_images.append(aug_image)

    # Get ordered list of rootnames for later
    base = os.path.basename(path)
    root.append(base.split('_')[0])

# Convert lists to arrays
proc_examples = np.array(processed_images)
aug_examples = np.array(augmented_images)

Now that we have the processed and augmented images, we look at the difference between the original, processed, and augmented images:

In [None]:
# Change the number to look at different images in the set
img = 1
vmin, vmax = zscale(original_images[img])

fig, axs = plt.subplots(1,3, figsize=[20,30])
# Plot the unprocessed image
axs[0].set_title('Original image (Zscaled)')
A = axs[0].imshow(original_images[img], vmin=vmin, vmax=vmax, cmap='gray', origin='lower')
fig.colorbar(A, ax=axs[0], shrink=0.15)

# Plot the processed image
axs[1].set_title('Processed image')
B = axs[1].imshow(processed_images[img], cmap='gray', origin='lower')
fig.colorbar(B, ax=axs[1], shrink=0.15)

# Plot the augmented image
axs[2].set_title('Augmented image')
C = axs[2].imshow(augmented_images[img], cmap='gray', origin='lower')
fig.colorbar(C, ax=axs[2], shrink=0.15)
plt.show()

## Load the models

Next, we load the model architectures and parameters, and set the models to evaluation mode.

In [None]:
# Get the model parameters
param_files = sorted(glob.glob('model_params/*.pt'))

param_files

In [None]:
# Initialize the models:
model1 = Model(sub_array_size=256)
model1.load_state_dict(torch.load(param_files[0]))
model1.eval();

model2 = Model(sub_array_size=180)
model2.load_state_dict(torch.load(param_files[1]))
model2.eval();

model3 = Model(sub_array_size=256)
model3.load_state_dict(torch.load(param_files[2]))
model3.eval();

## Predict image classifications

Now, we classify the images with the models. We first define `softmax`, an `nn.Softmax` layer that converts the output neuron activations to probabilities.

In [None]:
# Define softmax for getting prediction probs
softmax = nn.Softmax(dim=1)
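
For reference, softmax maps each row of activations $z$ to $e^{z_i} / \sum_j e^{z_j}$, so every row becomes a set of non-negative values that sum to 1. The NumPy sketch below mirrors that calculation for a 2-D input; the function name `softmax_rows` is hypothetical and the activations are made-up values, not model output.

```python
import numpy as np


def softmax_rows(activations):
    """Apply softmax along axis 1 of a 2-D array."""
    # Subtract the row maximum first for numerical stability
    shifted = activations - activations.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


# Made-up activations for two images and two classes
demo_activations = np.array([[2.0, 0.0],
                             [-1.0, 3.0]])
demo_probs = softmax_rows(demo_activations)

print(demo_probs.sum(axis=1))          # each row sums to 1
print(np.argmax(demo_probs, axis=1))   # predicted class per image
print(np.max(demo_probs, axis=1))      # probability of the predicted class
```

Softmax preserves the ordering within each row, so taking `argmax` of the raw activations gives the same class as taking `argmax` of the probabilities.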

In [None]:
# Model 1 predictions
model1_out = model1(torch.Tensor(proc_examples.reshape(proc_examples.shape[0], 1, proc_examples.shape[1], proc_examples.shape[1])))
model1_preds = np.argmax(model1_out.detach().numpy(),axis=1)
model1_pred_probs = np.max((softmax(model1_out)).detach().numpy(),axis=1)

# Model 2 predictions
model2_out = model2(torch.Tensor(aug_examples.reshape(aug_examples.shape[0], 1, aug_examples.shape[1], aug_examples.shape[1])))
model2_preds = np.argmax(model2_out.detach().numpy(),axis=1)
model2_pred_probs = np.max((softmax(model2_out)).detach().numpy(),axis=1)

# Model 3 predictions
model3_out = model3(torch.Tensor(proc_examples.reshape(proc_examples.shape[0], 1, proc_examples.shape[1], proc_examples.shape[1])))
model3_preds = np.argmax(model3_out.detach().numpy(),axis=1)
model3_pred_probs = np.max((softmax(model3_out)).detach().numpy(),axis=1)
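
The `reshape` calls above insert a channel axis of length 1, turning an array of shape (N, H, W) into (N, 1, H, W). This is the layout that PyTorch 2-D convolution layers expect for a batch of single-channel images. A minimal NumPy sketch, using a placeholder array of zeros in place of the processed images:

```python
import numpy as np

# Placeholder batch: 7 single-channel images of 256 x 256 pixels
batch = np.zeros((7, 256, 256), dtype=np.float32)

# Explicit reshape, in the spirit of the cell above
reshaped = batch.reshape(batch.shape[0], 1, batch.shape[1], batch.shape[2])

# Equivalent, shorter form: index with np.newaxis to add the channel axis
with_channel = batch[:, np.newaxis, :, :]

print(reshaped.shape, with_channel.shape)
```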

## Evaluate predictions

We compare our model predictions and probabilities with the actual classifications.

In [None]:
# Dataframe columns
col_names = [
    'Example Rootname', 'Correct Prediction',
    'Model 1 Prediction', 'Model 1 Prediction Probability',
    'Model 2 Prediction', 'Model 2 Prediction Probability',
    'Model 3 Prediction', 'Model 3 Prediction Probability'
]

# Correct predictions, listed in the same order as `root`
correct_preds = [1, 0, 1, 0, 0, 0, 1]

# Make dictionary keyed by the column names
all_data = {'Example Rootname': root,
            'Correct Prediction': correct_preds,
            'Model 1 Prediction': model1_preds,
            'Model 1 Prediction Probability': model1_pred_probs,
            'Model 2 Prediction': model2_preds,
            'Model 2 Prediction Probability': model2_pred_probs,
            'Model 3 Prediction': model3_preds,
            'Model 3 Prediction Probability': model3_pred_probs
}

# Create dataframe from dictionary, with columns in the order of col_names
pred_table = pd.DataFrame(all_data, columns=col_names)
pred_table
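
To check agreement programmatically rather than by eye, the fraction of matching labels can be computed for each model. The helper below is a hypothetical sketch; the labels in the demonstration are made up for illustration and are not output from the models in this notebook.

```python
import numpy as np


def accuracy_sketch(predictions, labels):
    """Return the fraction of predictions that match the labels."""
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)
    if predictions.shape != labels.shape:
        raise ValueError('predictions and labels must have the same shape')
    return float(np.mean(predictions == labels))


# Made-up labels: three of the four predictions agree
demo_accuracy = accuracy_sketch([1, 0, 1, 0], [1, 0, 0, 0])
print(demo_accuracy)
```

With the variables defined in this notebook, a call such as `accuracy_sketch(model1_preds, correct_preds)` would return 1.0 when a model classifies every example correctly.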

All three models classified every image correctly, each with a high prediction probability.

## Conclusions

Thank you for walking through this notebook. Now you should be more familiar with using our models to predict if WFC3 images are affected by guide star failures.

## About this Notebook

Authors: Megan Jones, Fred Dauphin, DeepHST

Created on: 2024-04-29

Updated on: 2024-11-23

## Citations

If you use `numpy`, `matplotlib`, `pandas`, `astropy`, `astroquery`, or `torch` for published research, please cite the authors. Follow these links for more information about citing each package:

- Citing [`numpy`](https://numpy.org/citing-numpy/)
- Citing [`matplotlib`](https://matplotlib.org/stable/project/citing.html)
- Citing [`pandas`](https://pandas.pydata.org/about/citing.html)
- Citing [`astropy`](https://www.astropy.org/acknowledging.html)
- Citing [`astroquery`](https://github.com/astropy/astroquery/blob/main/astroquery/CITATION)
- Citing [`torch`](https://arxiv.org/abs/1912.01703)