added tutorials to docs
kmzzhang committed Aug 2, 2019
1 parent 32c17a3 commit f749672
Showing 7 changed files with 185 additions and 78 deletions.
4 changes: 4 additions & 0 deletions deepCR/test/model_zoo.rst
@@ -0,0 +1,4 @@
Model zoo: available models
==============

Coming soon...
9 changes: 5 additions & 4 deletions deepCR/test/test_train.py
@@ -1,5 +1,3 @@
import os

import numpy as np
import pytest

@@ -9,9 +7,12 @@
def test_train():
inputs = np.zeros((6, 64, 64))
sky = np.ones(6)
trainer = train(image=inputs, mask=inputs, sky=sky, aug_sky=[-0.9, 10], epoch=2, verbose=False)
trainer = train(image=inputs, mask=inputs, sky=sky, aug_sky=[-0.9, 10], verbose=False, epoch=2, save_after=10)
trainer.train()
trainer.save()
filename = trainer.save()
trainer.load(filename)
trainer.train_continue(1)
assert trainer.epoch_mask == 3


if __name__ == '__main__':
68 changes: 37 additions & 31 deletions deepCR/training.py
@@ -23,7 +23,7 @@
class train():

def __init__(self, image, mask, ignore=None, sky=None, aug_sky=[0, 0], name='model', hidden=32, gpu=False, epoch=50,
batch_size=16, lr=0.005, auto_lr_decay=True, lr_decay_patience=4, lr_decay_factor=0.1, save_after=0,
batch_size=16, lr=0.005, auto_lr_decay=True, lr_decay_patience=4, lr_decay_factor=0.1, save_after=1e5,
plot_every=10, verbose=True, use_tqdm=False, use_tqdm_notebook=False, directory='./'):
""" This is the class for training deepCR-mask.
:param image: np.ndarray (N*W*W) training data: image array with CR.
@@ -91,6 +91,7 @@ def __init__(self, image, mask, ignore=None, sky=None, aug_sky=[0, 0], name='mod
self.every = plot_every
self.directory = directory
self.verbose = verbose
self.mode0_complete = False

if use_tqdm_notebook:
self.tqdm = tqdm_notebook
@@ -142,25 +143,29 @@ def train(self):
"""
if self.verbose:
print('Begin first {} epochs of training'.format(int(self.n_epochs * 0.4 + 0.5)))
print('Use batch statistics for batch norm; keep running mean')
print('Use batch activate statistics for batch normalization; keep running mean to be used after '
'these epochs')
print('')
self.train_initial(int(self.n_epochs * 0.4 + 0.5))

filename = self.save()
self.load(filename)
self.set_to_eval()
if self.verbose:
print('Continue onto next {} epochs of training'.format(self.n_epochs - int(self.n_epochs * 0.4 + 0.5)))
print('Batch normalization running statistics frozen and used')
print('')
self.train_continue(self.n_epochs - int(self.n_epochs * 0.4 + 0.5))

def train_initial(self, epochs):
self.network.train()
for epoch in self.tqdm(range(int(self.n_epochs * 0.4 + 0.5)), disable=self.disable_tqdm):
for epoch in self.tqdm(range(epochs), disable=self.disable_tqdm):
for t, dat in enumerate(self.TrainLoader):
self.optimize_network(dat)
self.epoch_mask += 1

if self.epoch_mask % self.every==0:
plt.figure(figsize=(10, 30))
plt.subplot(131)
plt.imshow(np.log(self.img0[0, 0].detach().cpu().numpy()), cmap='gray')
plt.title('epoch=%d'%self.epoch_mask)
plt.subplot(132)
plt.imshow(self.pdt_mask[0, 0].detach().cpu().numpy() > 0.5, cmap='gray')
plt.title('prediction > 0.5')
plt.subplot(133)
plt.imshow(self.mask[0, 0].detach().cpu().numpy(), cmap='gray')
plt.title('ground truth')
plt.show()
if self.epoch_mask % self.every == 0:
self.plot_example()

if self.verbose:
print('----------- epoch = %d -----------' % (self.epoch_mask))
@@ -177,27 +182,14 @@ def train(self):
if self.verbose:
print('')

filename = self.save()
self.load(filename)
self.set_to_eval()
if self.verbose:
print('Continue onto next {} epochs of training'.format(self.n_epochs - int(self.n_epochs * 0.4 + 0.5)))
print('Batch norm running statistics frozen and used')
print('')
for epoch in self.tqdm(range(self.n_epochs - int(self.n_epochs * 0.4 + 0.5)), disable=self.disable_tqdm):
def train_continue(self, epochs):
for epoch in self.tqdm(range(epochs), disable=self.disable_tqdm):
for t, dat in enumerate(self.TrainLoader):
self.optimize_network(dat)
self.epoch_mask += 1

if self.epoch_mask % self.every==0:
plt.figure(figsize=(10, 30))
plt.subplot(131)
plt.imshow(np.log(self.img0[0, 0].detach().cpu().numpy()), cmap='gray')
plt.subplot(132)
plt.imshow(self.pdt_mask[0, 0].detach().cpu().numpy()>0.5, cmap='gray')
plt.subplot(133)
plt.imshow(self.mask[0, 0].detach().cpu().numpy(), cmap='gray')
plt.show()
self.plot_example()

if self.verbose:
print('----------- epoch = %d -----------' % self.epoch_mask)
Expand All @@ -214,8 +206,22 @@ def train(self):
if self.verbose:
print('')

def plot_example(self):
plt.figure(figsize=(10, 30))
plt.subplot(131)
plt.imshow(np.log(self.img0[0, 0].detach().cpu().numpy()), cmap='gray')
plt.title('epoch=%d' % self.epoch_mask)
plt.subplot(132)
plt.imshow(self.pdt_mask[0, 0].detach().cpu().numpy() > 0.5, cmap='gray')
plt.title('prediction > 0.5')
plt.subplot(133)
plt.imshow(self.mask[0, 0].detach().cpu().numpy(), cmap='gray')
plt.title('ground truth')
plt.show()

def set_to_eval(self):
self.network.eval()

def optimize_network(self, dat):
self.set_input(*dat)
self.pdt_mask = self.network(self.img0)
7 changes: 6 additions & 1 deletion docs/deepCR.rst
@@ -1,7 +1,12 @@
API documentation
API
==============

.. automodule:: deepCR
:members:
:undoc-members:
:show-inheritance:

.. automodule:: train
:members:
:undoc-members:
:show-inheritance:
54 changes: 12 additions & 42 deletions docs/index.rst
@@ -22,8 +22,6 @@ Welcome to the documentation for `deepCR`. You will use `deepCR` to apply a lear

.. image:: https://raw.githubusercontent.com/profjsb/deepCR/master/imgs/postage-sm.jpg

This is the documentation for the installable package which implements the methods described in the paper: Zhang & Bloom (2019), submitted. Code to benchmark the model and to generate figures and tables in the paper can be found in the deepCR-paper Github repo: https://github.com/kmzzhang/deepCR-paper


Installation
^^^^^^^^^^^^
@@ -40,43 +38,25 @@ Or you can install from source:
cd deepCR/
pip install .
Quick Start
Currently available models
^^^^^^^^^^^

With Python >=3.5:

.. code-block:: python
from deepCR import deepCR
from astropy.io import fits
mask:

image = fits.getdata("*********_flc.fits")
mdl = deepCR(mask="ACS-WFC-F606W-2-32",
inpaint="ACS-WFC-F606W-2-32",
device="GPU")
mask, cleaned_image = mdl.clean(image, threshold = 0.5)
ACS-WFC-F606W-2-4

Note:
Input image must be in units of electrons
ACS-WFC-F606W-2-32(*)

To reduce memory consumption (recommended for images larger than 1k x 1k):
inpaint:

.. code-block:: python
ACS-WFC-F606W-2-32(*)

mask, cleaned_image = mdl.clean(image, threshold = 0.5, seg = 256)
which segments the input image into 256*256 patches, separately performs CR rejection on the patches, and stitches them back to the original image size.

Currently available models
^^^^^^^^^^^^^^^^^^^^^^^^^^
ACS-WFC-F606W-3-32

mask: ACS-WFC-F606W-2-4
ACS-WFC-F606W-2-32(*)
Recommended models are marked with (*). Larger numbers indicate larger capacity.

inpaint: ACS-WFC-F606W-2-32
ACS-WFC-F606W-3-32(*)
Note that trained models may have input unit or preprocessing requirements. For the ACS-WFC-F606W models, input images must come from *_flc.fits* files, which are in units of electrons.

The two numbers following the instrument configuration specify the model size, with larger numbers indicating a better-performing model at the expense of runtime. Recommended models are marked with (*). For benchmarks of these models, please refer to the original paper.

Limitations and Caveats
^^^^^^^^^^^^^^^^^^^^^^^
@@ -88,24 +68,14 @@ Contributing

We are very interested in getting bug fixes, new functionality, and new trained models from the community (especially for ground-based imaging and spectroscopy). Please fork this repo and issue a PR with your changes. It will be especially helpful if you add some tests for your changes.


How to Use This Guide
---------------------


If you run into any issues, please don't hesitate to `open an issue on GitHub
<https://github.com/profjsb/deepCR/issues>`_.

.. toctree::
:maxdepth: 4
:caption: Contents:

tutorial_use
tutorial_train
model_zoo
deepCR


Indices and tables
==================

* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
69 changes: 69 additions & 0 deletions docs/tutorial_train.rst
@@ -0,0 +1,69 @@
Quickstart: Training new deepCR models
==============

Training new models
^^^^^^^^^^^^^^^^^^^

deepCR provides easy-to-use training functionality. Assume you have constructed your dataset which includes the
following ``numpy`` arrays:

image: np.ndarray (N,W,W). Array containing N input image chunks of size W*W.

mask: np.ndarray (N,W,W). Array containing N ground truth CR mask chunks of size W*W.

ignore: (optional) np.ndarray (N,W,W). Array flagging pixels on which the model should not be trained or evaluated. This
typically includes bad pixels and saturated pixels, or any other artifacts falsely included in ``mask``.

sky: (optional) np.ndarray (N,). Array containing the sky background level for each image chunk.
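
For illustration, here is a minimal toy dataset with the expected shapes (mirroring the arrays used in the package's test suite); real training data should of course contain actual image chunks and CR masks:

.. code-block:: python
import numpy as np
# toy dataset for illustration only: 6 chunks of 64*64 pixels each
image = np.zeros((6, 64, 64))   # input image chunks containing CRs
mask = np.zeros((6, 64, 64))    # ground-truth CR masks
ignore = np.zeros((6, 64, 64))  # optional: pixels to exclude from training/evaluation
sky = np.ones(6)                # optional: sky background level per chunk

These arrays can then be passed directly to the trainer: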

.. code-block:: python
from deepCR import train
trainer = train(image, mask, ignore=ignore, sky=sky, aug_sky=[-0.9, 3], name='mymodel', gpu=True, epoch=50,
save_after=20, plot_every=10, use_tqdm=False)
trainer.train()
filename = trainer.save() # not necessary if save_after is specified
The aug_sky argument enables data augmentation of the sky background: a random sky background in the range
[aug_sky[0] * sky, aug_sky[1] * sky] is used for each input image. The sky array must be provided to use this functionality.
This serves as a regularizer that allows the trained model to adapt to a wider range of sky backgrounds, or equivalently
exposure times, and remedies the fact that the exposure times in the training set are discrete and limited.
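
Conceptually, the augmentation amounts to something like the sketch below (an illustration of the idea only, not necessarily the library's exact internals; ``augment_sky`` is a hypothetical helper):

.. code-block:: python
import numpy as np
def augment_sky(img, sky_level, aug_sky=(-0.9, 3)):
    # illustration only: draw one random scale factor from the aug_sky range
    scale = np.random.uniform(aug_sky[0], aug_sky[1])
    # and add that multiple of the sky background level to the image chunk
    return img + scale * sky_level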

The save_after argument tells the trainer, from epoch save_after onward, to save the model at every epoch that achieves
the lowest validation loss so far. If it is not specified, you have to call trainer.save() to manually save the model at the last epoch.
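
A manually saved checkpoint can also be loaded back into the trainer, as done in the package's test suite; a minimal sketch:

.. code-block:: python
filename = trainer.save()  # returns the filename of the saved model
trainer.load(filename)     # reload that checkpoint into the trainer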

After training, you can check that the validation loss has reached its minimum:
.. code-block:: python
trainer.plot_loss()
If the validation loss is still decreasing, you can continue training:
.. code-block:: python
trainer.train_continue(20)
Do not call trainer.train() again; specify the number of additional epochs to trainer.train_continue().

Loading your new model
^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: python
from deepCR import deepCR
mdl = deepCR(mask='save_directory/my_model_epoch50.pth', hidden=32)
It is necessary to specify the number of hidden channels in the first layer if it is not the default (32).
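
Once loaded, your model is used like the shipped models; a minimal sketch (the filename and image below are taken from the usage tutorial and are only examples):

.. code-block:: python
from astropy.io import fits
image = fits.getdata("jdba2sooq_flc.fits")[:512, :512]
mask = mdl.clean(image, threshold=0.5, inpaint=False)  # mask-only model, so skip inpainting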

Testing your model
^^^^^^^^^^^^^^^^^^
You should test your model on a separate test set, which ideally comes from different fields than the training
set and represents a wide range of cases, e.g., different exposure times. You may also test your model separately on
different subsets of cases.

.. code-block:: python
from deepCR import roc
import matplotlib.pyplot as plt
tpr, fpr = roc(mdl, image=image, mask=mask, ignore=ignore)
plt.plot(fpr, tpr)
plt.show()
52 changes: 52 additions & 0 deletions docs/tutorial_use.rst
@@ -0,0 +1,52 @@
Quickstart: Using deepCR
==============

Quick download of an HST ACS/WFC image:

.. code-block:: bash
wget -O jdba2sooq_flc.fits "https://mast.stsci.edu/api/v0.1/Download/file?uri=mast:HST/product/jdba2sooq_flc.fits"
For smaller images:

.. code-block:: python
from deepCR import deepCR
from astropy.io import fits
image = fits.getdata("jdba2sooq_flc.fits")[:512,:512]
# create an instance of deepCR with specified model configuration
mdl = deepCR(mask="ACS-WFC-F606W-2-32",
inpaint="ACS-WFC-F606W-2-32",
device="CPU")
# apply to input image
mask, cleaned_image = mdl.clean(image, threshold = 0.5)
# the best threshold is the highest value that generates a mask covering the full extent of the CR
# choose the threshold by visualizing the outputs.
# note that deepCR-inpaint would overestimate if the mask does not fully cover the CR.
# if you only need the CR mask, you may skip image inpainting for shorter runtime
mask = mdl.clean(image, threshold = 0.5, inpaint=False)
# if you want a probabilistic cosmic-ray mask instead of a binary mask
prob_mask = mdl.clean(image, binary=False)
For full-size WFC images (4k * 2k), you should specify **segment = True** to tell deepCR to segment the input image into 256*256 patches and process one patch at a time.
Otherwise processing would take up more than 10 GB of memory. We recommend segment = True for images larger than 1k * 1k on CPU. GPU memory limits may be stricter.

.. code-block:: python
image = fits.getdata("jdba2sooq_flc.fits")
mask, cleaned_image = mdl.clean(image, threshold = 0.5, segment = True)
(CPU only) In place of **segment = True**, you can specify **parallel = True** to invoke the multi-threaded version of segment mode; you don't need to specify **segment = True** again. This will speed things up considerably.

.. code-block:: python
image = fits.getdata("jdba2sooq_flc.fits")
mask, cleaned_image = mdl.clean(image, threshold = 0.5, parallel = True, n_jobs=-1)
**n_jobs=-1** makes use of all your CPU cores.

Note that this won't speed things up if you're using GPU!
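
If you do have a GPU available, a sketch of the equivalent setup might look like the following (the device="GPU" option follows the configuration shown elsewhere in these docs; actual availability depends on your installation):

.. code-block:: python
mdl = deepCR(mask="ACS-WFC-F606W-2-32",
             inpaint="ACS-WFC-F606W-2-32",
             device="GPU")
image = fits.getdata("jdba2sooq_flc.fits")
mask, cleaned_image = mdl.clean(image, threshold=0.5, segment=True)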
