# WFC3 Figure 8 Ghost Classification using GoogLeNet
---

This notebook demonstrates how to use a DeepWFC3 model to predict whether a WFC3 image contains a figure 8 ghost.

## Imports <a id="imports"></a>

If you are running this notebook in Jupyter, it assumes that you have already created the virtual environment defined in `environment.yml`. If you have not, close this notebook and run the following commands in a terminal window:

`conda env create -f environment.yml`

`conda activate deepwfc3_env`

We import the following libraries:
- *numpy* for handling arrays
- *matplotlib* for plotting
- *torch* as our machine learning framework

We also import functions from `utils.py` to process images into ImageNet format, load the model, and plot saliency maps.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn

import torchvision
from torchvision import transforms

from utils import process_image, load_wfc3_uvis_figure8_model, saliency_map

## Load Data

`examples.npz` is a compressed numpy file containing two WFC3 images that have already been passed through the processing pipeline (see (insert py file name here)). The first image is a null image (no figure 8 ghost) of the galaxy N5643 (idgg69pmq), and the second is of the globular cluster NGC-6752 (ibhf01sjq), which contains a figure 8 ghost.

We load the images using `np.load()`.

In [None]:
example = np.load('examples.npz')['examples']
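It can help to inspect the archive before using it. The sketch below is self-contained: it writes a synthetic stand-in file (`examples_sketch.npz`, a hypothetical name filled with random data, not the real WFC3 images) and reads it back to show how `np.load()` exposes the stored arrays by key. The `(2, 256, 256)` shape is an assumption for illustration; check `example.shape` on the real file.

```python
import os
import tempfile

import numpy as np

# Synthetic stand-in for examples.npz: two random 256x256 "images".
# The real file's shapes and dtypes may differ.
rng = np.random.default_rng(0)
fake = rng.random((2, 256, 256)).astype(np.float32)

path = os.path.join(tempfile.mkdtemp(), 'examples_sketch.npz')
np.savez_compressed(path, examples=fake)    # stored under the key 'examples'

with np.load(path) as archive:              # NpzFile closes the file on exit
    print(archive.files)                    # list of stored keys
    example_sketch = archive['examples']    # loads the array into memory

print(example_sketch.shape, example_sketch.dtype)
```

Indexing with `['examples']` returns an ordinary in-memory array, so it stays usable after the archive is closed. The variable is named `example_sketch` so it does not overwrite the real `example` loaded above.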

## Scale Data to ImageNet 

Since our model was pretrained on [ImageNet](https://www.image-net.org/), we need to preprocess our examples to match the format and statistics of the ImageNet images. We do this by:
- min-max scaling each image so its pixel values range from 0 to 1
- stacking three copies of each image to act as "RGB channels"
- center cropping to 224x224 pixels
- normalizing each channel by subtracting the ImageNet means $\mu=(0.485, 0.456, 0.406)$ and dividing by the standard deviations $\sigma=(0.229, 0.224, 0.225)$

See the PyTorch Hub [GoogLeNet documentation](https://pytorch.org/hub/pytorch_vision_googlenet/) for more information.

In [None]:
example_0_process = process_image(example[0])
example_1_process = process_image(example[1])
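The four preprocessing steps listed above can be sketched in plain NumPy and PyTorch. `process_image_sketch` is a hypothetical helper written for illustration, not the implementation in `utils.py`. In particular, the crop offsets, the handling of constant images, and the `(1, 3, 224, 224)` batched output shape are assumptions.

```python
import numpy as np
import torch

# ImageNet per-channel statistics used for normalization
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def process_image_sketch(image, crop=224):
    """Hypothetical stand-in for utils.process_image (assumes H, W >= crop)."""
    image = np.asarray(image, dtype=np.float32)

    # 1. Min-max scale to [0, 1]; a constant image becomes all zeros
    lo, hi = image.min(), image.max()
    scaled = (image - lo) / (hi - lo) if hi > lo else np.zeros_like(image)

    # 2. Three identical copies act as "RGB channels" -> shape (3, H, W)
    rgb = np.stack([scaled] * 3, axis=0)

    # 3. Center crop to crop x crop pixels
    _, h, w = rgb.shape
    top, left = (h - crop) // 2, (w - crop) // 2
    rgb = rgb[:, top:top + crop, left:left + crop]

    # 4. Per-channel normalization: (x - mean) / std
    rgb = (rgb - IMAGENET_MEAN[:, None, None]) / IMAGENET_STD[:, None, None]

    # Add a batch dimension -> (1, 3, crop, crop)
    return torch.from_numpy(rgb).unsqueeze(0)


x = process_image_sketch(np.random.rand(256, 256))
print(x.shape)  # torch.Size([1, 3, 224, 224])
```

Stacking three identical channels lets a single-band WFC3 image fit a network that expects RGB input. After normalization, the three channels differ only because each has its own mean and standard deviation.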

## Load Model

Our model uses the [GoogLeNet](https://arxiv.org/pdf/1409.4842.pdf) architecture. The convolutional layers are frozen, and the fully connected layers were retrained using the dataset described in (insert ISR and hyperlink here). We modify the fully connected layers by appending two 1024-neuron layers, each with ReLU activation and a dropout rate of 0.5. Finally, we append a 2-neuron output layer (one neuron per class) with a dropout rate of 0.2 and linear activation.

The model is saved as `wfc3_uvis_figure8_model.torch` and can be loaded using `load_wfc3_uvis_figure8_model()`.

In [None]:
model = load_wfc3_uvis_figure8_model('wfc3_uvis_figure8_model.torch')
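The classifier head described above might look like the following in PyTorch. This is a hypothetical reconstruction, not the saved model's definition. The 1024 input features match the output of GoogLeNet's final pooling layer, but the order of the Linear, ReLU, and Dropout layers and the placement of the 0.2 dropout before the output layer are assumptions. `make_figure8_head` and `freeze` are illustrative names.

```python
import torch
from torch import nn


def make_figure8_head(in_features=1024, hidden=1024, n_classes=2):
    """Hypothetical head: layer order and dropout placement are assumptions."""
    return nn.Sequential(
        nn.Linear(in_features, hidden), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(hidden, hidden), nn.ReLU(), nn.Dropout(0.5),
        nn.Dropout(0.2), nn.Linear(hidden, n_classes),  # linear output, no activation
    )


def freeze(module):
    """Disable gradients so these (pretrained) weights are not updated."""
    for p in module.parameters():
        p.requires_grad = False


# Possible way to attach it to torchvision's GoogLeNet (assumption, not run here):
#   net = torchvision.models.googlenet(...)  # ImageNet-pretrained weights
#   freeze(net)                              # freeze the convolutional layers
#   net.fc = make_figure8_head()             # the new head stays trainable

head = make_figure8_head()
head.eval()                                  # disable dropout for a deterministic check
out = head(torch.randn(4, 1024))             # 4 fake GoogLeNet feature vectors
print(out.shape)                             # torch.Size([4, 2])
```

Replacing `fc` after freezing the backbone means only the new head's parameters have `requires_grad=True`, which is one common way to retrain just the fully connected layers.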

## Predict Examples

To classify the examples, we pass each processed image to `model()`, which returns the values of the two output neurons. The index of the larger output is the predicted class.

In [None]:
pred_0 = model(example_0_process)
pred_1 = model(example_1_process)

In [None]:
print(f'Example 0 Output Neurons: {pred_0}')
print(f'Example 1 Output Neurons: {pred_1}')
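Raw output neuron values are easier to read as a class index plus probabilities. The sketch below assumes a softmax over the two outputs (softmax does not change which neuron is largest) and uses a stand-in module with fixed outputs in place of the real model. `predict_class` and `FixedLogits` are hypothetical names. Which index corresponds to a figure 8 ghost is not stated here, so check the training labels before interpreting the result.

```python
import torch
from torch import nn


def predict_class(model, x):
    """Return (predicted indices, softmax probabilities) for a batch x."""
    model.eval()                      # turn off dropout so outputs are deterministic
    with torch.no_grad():             # gradients are not needed just to predict
        logits = model(x)             # shape (batch, 2): the two output neurons
    probs = torch.softmax(logits, dim=1)
    return logits.argmax(dim=1), probs


class FixedLogits(nn.Module):
    """Stand-in model (NOT the DeepWFC3 model) that returns fixed outputs."""
    def forward(self, x):
        return torch.tensor([[2.0, -1.0], [0.5, 3.0]])


preds, probs = predict_class(FixedLogits(), torch.zeros(2, 3, 224, 224))
print(preds)  # tensor([0, 1])
```

Calling `model.eval()` matters because dropout left active at inference time makes repeated predictions on the same image differ.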

## View Saliency Maps

We can view the [saliency maps](https://arxiv.org/pdf/1312.6034.pdf) that our model produces for the examples with `saliency_map()`, which prints the prediction probabilities and plots both the original image and the saliency map.

In [None]:
sm_0 = saliency_map(model, example_0_process)
sm_1 = saliency_map(model, example_1_process)
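For reference, the basic gradient saliency map from the linked paper takes the absolute gradient of the predicted class score with respect to the input pixels and keeps the maximum over the color channels. `vanilla_saliency` below is a hypothetical sketch of that technique, not the code in `utils.saliency_map()`, and the tiny convolutional network is a stand-in for the real model.

```python
import torch
from torch import nn


def vanilla_saliency(model, x):
    """Gradient saliency (Simonyan et al. 2013) for a (1, C, H, W) input."""
    model.eval()                                  # deterministic forward pass
    x = x.clone().detach().requires_grad_(True)   # track gradients w.r.t. pixels
    scores = model(x)                             # (1, n_classes) unnormalized scores
    target = int(scores.argmax(dim=1))            # predicted class index
    model.zero_grad()
    scores[0, target].backward()                  # d(score) / d(input)
    # Max |gradient| over channels -> one saliency value per pixel, shape (H, W)
    saliency = x.grad.abs().max(dim=1).values[0]
    return saliency, target


# Usage with a small stand-in network (NOT the DeepWFC3 model)
torch.manual_seed(0)
toy = nn.Sequential(nn.Conv2d(3, 2, kernel_size=3, padding=1),
                    nn.AdaptiveAvgPool2d(1), nn.Flatten())
image = torch.randn(1, 3, 32, 32)
sal, cls = vanilla_saliency(toy, image)
print(sal.shape)  # torch.Size([32, 32])
```

Taking the gradient of the unnormalized score rather than the softmax probability follows the original paper; bright pixels in the map are those whose small changes would most change the predicted class score.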

## Conclusions <a id="con"></a>

Thank you for walking through this notebook. You should now be more familiar with using our model to predict whether WFC3 images contain figure 8 ghosts.

## About this Notebook <a id="about"></a>

**Author:** Fred Dauphin, DeepWFC3

**Updated on:** 2022-01-14

## Citations <a id="cite"></a>

If you use `numpy`, `matplotlib`, or `torch` for published research, please cite the authors. Follow these links for more information about citing `numpy`, `matplotlib`, and `torch`:

* [Citing `numpy`](https://numpy.org/doc/stable/license.html)
* [Citing `matplotlib`](https://matplotlib.org/stable/users/project/license.html#:~:text=Matplotlib%20only%20uses%20BSD%20compatible,are%20acceptable%20in%20matplotlib%20toolkits.)
* [Citing `torch`](https://github.com/pytorch/pytorch/blob/master/LICENSE)