Commit
Showing 7 changed files with 185 additions and 78 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
Model zoo: available models | ||
=========================== | ||
|
||
Coming soon... |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,12 @@ | ||
API documentation | ||
API | ||
============== | ||
|
||
.. automodule:: deepCR | ||
:members: | ||
:undoc-members: | ||
:show-inheritance: | ||
|
||
.. automodule:: train | ||
:members: | ||
:undoc-members: | ||
:show-inheritance: |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
Quickstart: Training new deepCR models | ||
====================================== | ||
|
||
Training new models | ||
^^^^^^^^^^^^^^^^^^^ | ||
|
||
deepCR provides easy-to-use training functionality. Assume you have constructed a dataset that includes the | ||
following ``numpy`` arrays: | ||
|
||
image: np.ndarray (N,W,W). Array containing N input image chunks of W*W pixels. | ||
|
||
mask: np.ndarray (N,W,W). Array containing N ground truth CR mask chunks of W*W pixels. | ||
|
||
ignore: (optional) np.ndarray (N,W,W). Array flagging pixels on which the model should be neither trained nor evaluated. This | ||
typically includes bad pixels and saturated pixels, or any other artifact falsely included in ``mask``. | ||
|
||
sky: (optional) np.ndarray (N,). Array containing the sky background level of each image chunk. | ||
|
||
.. code-block:: python | ||
from deepCR import train | ||
trainer = train(image, mask, ignore=ignore, sky=sky, aug_sky=[-0.9, 3], name='mymodel', gpu=True, epoch=50, | ||
save_after=20, plot_every=10, use_tqdm=False) | ||
trainer.train() | ||
filename = trainer.save() # not necessary if save_after is specified | ||
The aug_sky argument enables data augmentation of the sky background: a random sky background in the range | ||
[aug_sky[0] * sky, aug_sky[1] * sky] is used for each input image. The sky array must be provided to use this functionality. | ||
The augmentation acts as a regularizer that lets the trained model adapt to a wider range of sky background levels or, equivalently, | ||
exposure times. It compensates for the fact that exposure times in the training set are discrete and limited. | ||
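The sky augmentation described above can be sketched in a few lines. This is only an illustration and not deepCR's implementation: the helper name ``augment_sky``, the uniform sampling, and the choice to add the sampled level to every pixel are assumptions made for the example.

```python
import random

def augment_sky(image, sky, aug_sky=(-0.9, 3.0), rng=random):
    """Add one randomly drawn sky level to every pixel of an image chunk.

    image: 2-D list of pixel values; sky: background level of this chunk.
    Hypothetical helper, not part of the deepCR API.
    """
    low, high = aug_sky
    # Assumption: the offset is drawn uniformly from [low * sky, high * sky].
    offset = rng.uniform(low * sky, high * sky)
    return [[pixel + offset for pixel in row] for row in image], offset

image = [[100.0, 102.0], [98.0, 101.0]]
augmented, offset = augment_sky(image, sky=100.0, rng=random.Random(0))
```

With ``aug_sky=[-0.9, 3]`` and a sky level of 100, the effective background of a chunk ranges from 10 to 400, which mimics much shorter and much longer exposures than the one actually observed.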
|
||
The save_after argument makes the trainer save the model at every epoch after save_after that achieves the lowest | ||
validation loss so far. If it is not specified, you must call trainer.save() to manually save the model at the last epoch. | ||
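The checkpointing rule can be made concrete with a small sketch. The function ``epochs_to_save`` is hypothetical; it assumes that "after save_after" means strictly later epochs and that the best loss is tracked over all epochs, including the early ones.

```python
def epochs_to_save(val_losses, save_after):
    """Return the 1-based epochs at which a checkpoint would be written.

    Hypothetical helper that mirrors the rule described in the text.
    """
    saved = []
    best = float("inf")
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            # Write a checkpoint only once the first save_after epochs are over.
            if epoch > save_after:
                saved.append(epoch)
    return saved

losses = [0.9, 0.5, 0.6, 0.4, 0.45, 0.3]
print(epochs_to_save(losses, save_after=2))  # [4, 6]
```

Epoch 2 improves on epoch 1 but falls inside the first ``save_after`` epochs, so under these assumptions the first checkpoint is written at epoch 4.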
|
||
After training, you can check whether the validation loss has reached its minimum with | ||
.. code-block:: python | ||
trainer.plot_loss() | ||
If the validation loss is still decreasing, you can continue training with | ||
.. code-block:: python | ||
trainer.train_continue(20) | ||
Do not call trainer.train() again to resume training; pass the number of additional epochs to trainer.train_continue() instead. | ||
|
||
Loading your new model | ||
^^^^^^^^^^^^^^^^^^^^^^ | ||
.. code-block:: python | ||
from deepCR import deepCR | ||
mdl = deepCR(mask='save_directory/my_model_epoch50.pth', hidden=32) | ||
You must specify the number of hidden channels in the first layer if it differs from the default (32). | ||
|
||
Testing your model | ||
^^^^^^^^^^^^^^^^^^ | ||
You should test your model on a separate test set, which ideally comes from different fields than the training | ||
set and covers a wide range of cases, e.g., exposure times. You may also evaluate your model separately on each of | ||
these cases. | ||
|
||
.. code-block:: python | ||
from deepCR import roc
import matplotlib.pyplot as plt
# call roc directly, matching the name imported above
tpr, fpr = roc(mdl, image=image, mask=mask, ignore=ignore)
plt.plot(fpr, tpr) | ||
plt.show() | ||
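To make clear what such a curve measures, here is a minimal pure-Python sketch of computing true and false positive rates at several thresholds while skipping ignored pixels. ``roc_points`` is a hypothetical helper; it works on flat per-pixel lists and returns fractions, which may differ from the conventions of deepCR's own ``roc``.

```python
def roc_points(prob, truth, ignore, thresholds):
    """Compute (tpr, fpr) lists from flat per-pixel lists.

    prob: predicted CR probabilities; truth: 1 for CR pixels, else 0;
    ignore: 1 for pixels excluded from the evaluation.
    """
    # Drop every pixel flagged in `ignore` before counting.
    kept = [(p, t) for p, t, ig in zip(prob, truth, ignore) if not ig]
    positives = sum(1 for _, t in kept if t)
    negatives = len(kept) - positives
    tpr, fpr = [], []
    for thr in thresholds:
        tp = sum(1 for p, t in kept if t and p > thr)
        fp = sum(1 for p, t in kept if not t and p > thr)
        tpr.append(tp / positives if positives else 0.0)
        fpr.append(fp / negatives if negatives else 0.0)
    return tpr, fpr

prob = [0.9, 0.8, 0.2, 0.6, 0.1, 0.95]
truth = [1, 1, 0, 0, 0, 1]
ignore = [0, 0, 0, 0, 0, 1]
tpr, fpr = roc_points(prob, truth, ignore, thresholds=[0.1, 0.5, 0.85])
```

Raising the threshold lowers the false positive rate, but it eventually costs true detections too, which is the trade-off the plotted curve displays.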
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
Quickstart: Using deepCR | ||
======================== | ||
|
||
Quick download of an HST ACS/WFC image: | ||
|
||
.. code-block:: bash | ||
# quote the URL so that the shell does not treat "?" as a glob character
wget -O jdba2sooq_flc.fits "https://mast.stsci.edu/api/v0.1/Download/file?uri=mast:HST/product/jdba2sooq_flc.fits"
For smaller images: | ||
|
||
.. code-block:: python | ||
from deepCR import deepCR | ||
from astropy.io import fits | ||
image = fits.getdata("jdba2sooq_flc.fits")[:512,:512] | ||
# create an instance of deepCR with specified model configuration | ||
mdl = deepCR(mask="ACS-WFC-F606W-2-32", | ||
inpaint="ACS-WFC-F606W-2-32", | ||
device="CPU") | ||
# apply to input image | ||
mask, cleaned_image = mdl.clean(image, threshold = 0.5) | ||
# the best threshold is the highest value that still generates a mask covering the full extent of each CR;
# choose the threshold by visualizing the outputs.
# note that deepCR-inpaint tends to overestimate pixel values if the mask does not fully cover the CR.
# if you only need CR mask you may skip image inpainting for shorter runtime | ||
mask = mdl.clean(image, threshold = 0.5, inpaint=False) | ||
# if you want probabilistic cosmic ray mask instead of binary mask | ||
prob_mask = mdl.clean(image, binary=False) | ||
For full-size WFC images (4k * 2k), specify **segment = True** to tell deepCR to split the input image into 256*256 patches and process one patch at a time. | ||
Otherwise, processing would take up more than 10 GB of memory. We recommend using segment = True for images larger than 1k * 1k on CPU. GPU memory limits may be stricter. | ||
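The idea behind segment mode can be illustrated by tiling an image into 256*256 patches. The generator ``patch_slices`` below is a hypothetical sketch; in particular, clipping the edge patches is an assumption, and deepCR may handle image borders differently (for example by padding).

```python
def patch_slices(height, width, patch=256):
    """Yield (row_slice, col_slice) pairs that tile a height x width image."""
    for top in range(0, height, patch):
        for left in range(0, width, patch):
            # Edge patches are clipped when the size is not a multiple of `patch`.
            yield (slice(top, min(top + patch, height)),
                   slice(left, min(left + patch, width)))

tiles = list(patch_slices(2048, 4096))
print(len(tiles))  # 128
```

Processing 128 small patches one at a time keeps peak memory bounded by the cost of a single patch rather than that of the full frame.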
|
||
.. code-block:: python | ||
image = fits.getdata("jdba2sooq_flc.fits") | ||
mask, cleaned_image = mdl.clean(image, threshold = 0.5, segment = True) | ||
(CPU only) Instead of **segment = True**, you can specify **parallel = True** to invoke the multi-threaded version of segment mode. This speeds up processing considerably, and you do not need to specify **segment = True** as well. | ||
|
||
.. code-block:: python | ||
image = fits.getdata("jdba2sooq_flc.fits") | ||
mask, cleaned_image = mdl.clean(image, threshold = 0.5, parallel = True, n_jobs=-1) | ||
**n_jobs=-1** makes use of all your CPU cores. | ||
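As a rough picture of what parallel segment mode involves, the sketch below distributes patches over a pool of worker threads. The helpers ``resolve_n_jobs`` and ``clean_patches`` are hypothetical, and the use of the standard library's ``ThreadPoolExecutor`` is an assumption made for illustration; deepCR's internals may differ.

```python
import os
from concurrent.futures import ThreadPoolExecutor

def resolve_n_jobs(n_jobs):
    """Map n_jobs=-1 to the number of CPU cores (assumed convention)."""
    return (os.cpu_count() or 1) if n_jobs == -1 else max(1, n_jobs)

def clean_patches(patches, clean_one, n_jobs=-1):
    """Apply clean_one to every patch using a pool of worker threads."""
    with ThreadPoolExecutor(max_workers=resolve_n_jobs(n_jobs)) as pool:
        # map() preserves input order, so results line up with patches.
        return list(pool.map(clean_one, patches))

# Stand-in "patches" and a stand-in cleaning function, for demonstration only.
results = clean_patches([1, 2, 3], lambda patch: patch * 2, n_jobs=2)
print(results)  # [2, 4, 6]
```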
|
||
Note that this won't speed things up if you're using a GPU! |