new stitching scheme
kmzzhang committed Jul 23, 2019
1 parent 7a971dd commit 1d9b24f
Showing 3 changed files with 102 additions and 103 deletions.
33 changes: 23 additions & 10 deletions README.md
@@ -3,11 +3,14 @@

## deepCR: Deep Learning Based Cosmic Ray Removal for Astronomical Images

Identify and remove cosmic rays from astronomical images using trained convolutional neural networks. Fast on both CPU and GPU.
This is the installable package which implements the methods described in the paper: Zhang & Bloom (2019), submitted.
Identify and remove cosmic rays from astronomical images using trained convolutional neural networks.
Fast on both CPU and GPU.

This is the installable package which implements the methods described in the paper: Zhang & Bloom (2019), submitted to AAS Journals.

Code to reproduce benchmarking results in the paper is at: https://github.com/kmzzhang/deepCR-paper

If you use this package, please cite Zhang & Bloom (2019): www.arxiv.org/XXX
This repo is under active development.

<img src="https://raw.githubusercontent.com/profjsb/deepCR/master/imgs/postage-sm.jpg" width="90%">

@@ -33,49 +36,59 @@ With Python >=3.5:
from deepCR import deepCR
from astropy.io import fits
image = fits.getdata("example_flc.fits")

# create an instance of deepCR with specified model configuration
mdl = deepCR(mask="ACS-WFC-F606W-2-32",
inpaint="ACS-WFC-F606W-3-32",
device="CPU")
# apply to input image
# you should examine cleaned_image to verify that proper threshold is used
mask, cleaned_image = mdl.clean(image, threshold = 0.5)
# examine those outputs to choose an adequate threshold

# for quicker mask prediction, skip image inpainting
# if you only need CR mask you may skip image inpainting for shorter runtime
mask = mdl.clean(image, threshold = 0.5, inpaint=False)

# for probabilistic cosmic ray mask (instead of binary mask)
# if you want probabilistic cosmic ray mask instead of binary mask
prob_mask = mdl.clean(image, binary=False)
```
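
The relationship between the probabilistic and binary masks in the calls above can be sketched with plain NumPy. This is only an illustration, not deepCR's internal code: the helper name `binarize` is hypothetical, and the strict `>` comparison is an assumption about how the threshold is applied.

```python
import numpy as np


def binarize(prob_mask, threshold=0.5):
    """Turn a probabilistic mask into a binary one (hypothetical helper).

    Assumption: a pixel is flagged when its probability is strictly
    greater than the threshold.
    """
    return (prob_mask > threshold).astype(np.uint8)


# Toy probabilistic mask standing in for the output of binary=False.
prob_mask = np.array([[0.10, 0.60],
                      [0.50, 0.95]])

loose = binarize(prob_mask, threshold=0.5)   # flags more pixels
strict = binarize(prob_mask, threshold=0.9)  # flags fewer pixels
```

Requesting the probabilistic mask once and comparing a few thresholds this way is one way to choose an adequate threshold without re-running the model.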

To reduce memory consumption (recommended for image larger than 1k x 1k), deepCR could segment input image to smaller patches of seg*seg, perform CR rejection one by one, and stitch the outputs back to original size. Runtime does not increase by much.
To reduce memory consumption (recommended for images larger than 1k x 1k), you can tell deepCR to segment the input image into smaller patches of seg * seg.
```python
mask, cleaned_image = mdl.clean(image, threshold = 0.5, seg = 256)
mask = mdl.clean(image, threshold = 0.5, seg = 256)
```
We recommend using patch size (seg) no smaller than 64.
deepCR performs CR rejection on the small patches one at a time. It then stitches the predictions back to the original image size.

We recommend using patch size (seg) no smaller than 64. seg is by default 256.
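
The patch-and-stitch idea can be sketched as follows. This is a minimal NumPy illustration under stated assumptions, not the package's implementation: `apply_in_patches` and `toy_detector` are hypothetical names, and the toy detector stands in for the neural network. Patches at the image edge are simply clipped to the image bounds.

```python
import math

import numpy as np


def apply_in_patches(image, fn, patch=256):
    """Apply fn to patch x patch tiles of a 2D image and stitch the results.

    Hypothetical helper for illustration; fn must return an array with
    the same shape as the tile it receives.
    """
    height, width = image.shape
    out = np.zeros((height, width), dtype=np.float32)
    n_rows = math.ceil(height / patch)
    n_cols = math.ceil(width / patch)
    for i in range(n_rows):
        for j in range(n_cols):
            # Clip the last row and column of tiles to the image bounds.
            r0, r1 = i * patch, min((i + 1) * patch, height)
            c0, c1 = j * patch, min((j + 1) * patch, width)
            out[r0:r1, c0:c1] = fn(image[r0:r1, c0:c1])
    return out


def toy_detector(tile):
    """Stand-in for a per-patch model: flags pixels brighter than 5."""
    return (tile > 5).astype(np.float32)


# 299 is not a multiple of 256, so the edge tiles are smaller than patch.
image = (np.arange(299 * 299) % 11).astype(np.float32).reshape(299, 299)
stitched = apply_in_patches(image, toy_detector, patch=256)
```

Because the toy detector works pixel by pixel, the stitched result equals the result on the whole image. A convolutional model sees less context near patch borders, so its stitched output may differ slightly from a single full-image pass.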

### Currently available models

mask:

ACS-WFC-F606W-2-4

ACS-WFC-F606W-2-32(*)

inpaint:

ACS-WFC-F606W-2-32

ACS-WFC-F606W-3-32(*)

Recommended models are marked in (*).

Input images should come from _flc.fits files which are in units of electrons.
The two numbers, e.g., 2-32, specify model hyperparameters. Larger numbers indicate larger capacity and better performance.
Larger numbers indicate larger capacity and better performance.

### API Documentation

Documentation is under development at: https://deepcr.readthedocs.io/en/latest/deepCR.html

### Limitations and Caveats

In the current release, the included models have been built and tested only on Hubble Space Telescope (HST) ACS/WFC images in the F606W filter. Applying them to other HST detectors is discouraged. Users should exercise caution when applying the models to other filters of ACS/WFC.
In the current release, the included models have been trained and tested only on Hubble Space Telescope (HST) ACS/WFC images in the F606W filter. They may work well on nearby ACS/WFC filters, though users should exercise caution.

The ACS/WFC models are not expected to work optimally on other HST detectors, though we'd be interested to know if you find additional use cases for them.

### Contributing

144 changes: 73 additions & 71 deletions deepCR/model.py
@@ -23,34 +23,31 @@

class deepCR():

def __init__(self, mask='ACS-WFC-F606W-2-32', inpaint='medmask', device='CPU',
def __init__(self, mask='ACS-WFC-F606W-2-32', inpaint='ACS-WFC-F606W-3-32', device='CPU',
model_dir=default_model_path):

"""model class instantiation for deepCR. Here is the
declaration of the learned mask and inpainting models
to be used on images
Parameters
----------
mask : str
Name of deepCR-mask model to use. This is one of the keys in
`mask_dict`
inpaint : str
Name of the inpainting model to use. This is one of the keys in
`inpaint_dict`. It can also be `medmask` which will then
use a simple 5x5 median mask for inpainting
device : str
One of 'CPU' or 'GPU'
model_dir : str
The location of the model directory with the mask/ and inpaint/
subdirectories. This defaults to where the pre-shipped
models live (in `learned_models/`)
Returns
-------
None
"""
if(device == 'GPU'):
Instantiation of deepCR with specified model configurations
Parameters
----------
mask : str
Name of deepCR-mask model to use.
inpaint : str
Name of the inpainting model to use. It can also be `medmask` which will then
use a simple 5x5 median mask sampling for inpainting
device : str
One of 'CPU' or 'GPU'
model_dir : str
The location of the model directory with the mask/ and inpaint/
subdirectories. This defaults to where the pre-shipped
models live (in `learned_models/`)
Returns
-------
None
"""
if device == 'GPU':
self.dtype = torch.cuda.FloatTensor
self.dint = torch.cuda.ByteTensor
wrapper = nn.DataParallel
@@ -85,29 +82,34 @@ def __init__(self, mask='ACS-WFC-F606W-2-32', inpaint='medmask', device='CPU',
for p in self.inpaintNet.parameters():
p.required_grad = False

def clean(self, img0, threshold=0.5, inpaint=True, binary=True,
seg=0, parallel=False, n_jobs=-1):
def clean(self, img0, threshold=0.5, inpaint=True, binary=True, segment=False,
patch=256, parallel=False, n_jobs=-1):
"""
Identify cosmic rays in an input image, and (optionally) inpaint with the predicted cosmic ray mask
:param img0: (np.ndarray) 2D input image conforming to model requirements. For HST ACS/WFC, must be from _flc.fits and in units of electrons in native resolution.
:param threshold: (float) applied to probabilistic mask to generate binary mask
:param inpaint: (bool) return clean, inpainted image only if True
:param binary: return binary CR mask if True. probabilistic mask if False
:param seg: for large input images, blocksize to apply models on
:param parallel: run in parallel if True and seg > 0
:param n_jobs: number of jobs to run in parallel, passed to `joblib`
:return: mask or binary mask; or None if internal call
:param segment: (bool) if True, segment input image into chunks of patch * patch before performing CR rejection. Used for memory control.
:param patch: (int) Use 256 unless otherwise required. if segment==True, segment image into chunks of patch * patch.
:param parallel: (bool) run in parallel if True and segment==True
:param n_jobs: (int) number of jobs to run in parallel, passed to `joblib.` Beware of memory overflow for larger n_jobs.
:return: CR mask and (optionally) clean inpainted image
"""
if seg==0:

# data pre-processing
img0 = img0.astype(np.float32) / 100

if not segment:
return self.clean_(img0, threshold=threshold,
inpaint=inpaint, binary=binary)
else:
if not parallel:
return self.clean_large(img0, threshold=threshold,
inpaint=inpaint, binary=binary, seg=seg)
inpaint=inpaint, binary=binary, patch=patch)
else:
return self.clean_large_parallel(img0, threshold=threshold,
inpaint=inpaint, binary=binary, seg=seg,
inpaint=inpaint, binary=binary, patch=patch,
n_jobs=n_jobs)

def clean_(self, img0, threshold=0.5, inpaint=True, binary=True):
@@ -120,10 +122,8 @@ def clean_(self, img0, threshold=0.5, inpaint=True, binary=True):
:param threshold: for creating binary mask from probabilistic mask
:param inpaint: return clean image only if True
:param binary: return binary mask if True. probabilistic mask otherwise.
:return: mask or binary mask; or None if internal call
:return: CR mask and (optionally) clean inpainted image
"""
# data proprocessing
img0 = img0.astype(np.float32) / 100

shape = img0.shape
pad_x = 4 - shape[0] % 4
@@ -156,8 +156,6 @@ def clean_(self, img0, threshold=0.5, inpaint=True, binary=True):
img0 = img0.detach().cpu().view(shape[0], shape[1]).numpy()
img1 = medmask(img0, binary_mask)
inpainted = img1 * binary_mask + img0 * (1 - binary_mask)


if binary:
return binary_mask[pad_x:, pad_y:], inpainted[pad_x:, pad_y:] * 100
else:
@@ -173,21 +171,29 @@ def clean_(self, img0, threshold=0.5, inpaint=True, binary=True):
return mask[pad_x:, pad_y:]

def clean_large_parallel(self, img0, threshold=0.5, inpaint=True, binary=True,
seg=256, n_jobs=-1):

patch=256, n_jobs=-1):
"""
given input image
return cosmic ray mask and (optionally) clean image
mask could be binary or probabilistic
:param img0: (np.ndarray) 2D input image
:param threshold: for creating binary mask from probabilistic mask
:param inpaint: return clean image only if True
:param binary: return binary mask if True. probabilistic mask otherwise.
:param patch: (int) Use 256 unless otherwise required. patch size to run deepCR on.
:param n_jobs: (int) number of jobs to run in parallel, passed to `joblib.` Beware of memory overflow for larger n_jobs.
:return: CR mask and (optionally) clean inpainted image
"""
folder = './joblib_memmap'
try:
mkdir(folder)
except FileExistsError:
pass


im_shape = img0.shape
img0_dtype = img0.dtype
hh = int(math.ceil(im_shape[0]/seg))
ww = int(math.ceil(im_shape[1]/seg))

img0 = np.pad(img0, 3, 'constant')
hh = int(math.ceil(im_shape[0]/patch))
ww = int(math.ceil(im_shape[1]/patch))

img0_filename_memmap = path.join(folder, 'img0_memmap')
dump(img0, img0_filename_memmap)
@@ -205,24 +211,23 @@ def clean_large_parallel(self, img0, threshold=0.5, inpaint=True, binary=True,
shape=im_shape, mode='w+')

@wrap_non_picklable_objects
def fill_values(i, j, img0, img1, mask, seg, inpaint, threshold, binary):
img = img0[i * seg:(i + 1) * seg + 6, j * seg:(j + 1) * seg + 6]
def fill_values(i, j, img0, img1, mask, patch, inpaint, threshold, binary):
img = img0[i * patch: min((i + 1) * patch, im_shape[0]), j * patch: min((j + 1) * patch, im_shape[1])]
if inpaint:
mask_, clean_ = self.clean_(img, threshold=threshold, inpaint=True, binary=binary)
mask[i*seg:(i+1)*seg, j*seg:(j+1)*seg] = mask_[3:-3, 3:-3]
img1[i*seg:(i+1)*seg, j*seg:(j+1)*seg] = clean_[3:-3, 3:-3]
mask[i * patch: min((i + 1) * patch, im_shape[0]), j * patch: min((j + 1) * patch, im_shape[1])] = mask_
img1[i * patch: min((i + 1) * patch, im_shape[0]), j * patch: min((j + 1) * patch, im_shape[1])] = clean_
else:
mask_ = self.clean_(img, threshold=threshold, inpaint=False, binary=binary)
mask[i*seg:(i+1)*seg, j*seg:(j+1)*seg] = mask_[3:-3, 3:-3]
mask[i * patch: min((i + 1) * patch, im_shape[0]), j * patch: min((j + 1) * patch, im_shape[1])] = mask_

results = Parallel(n_jobs=n_jobs, verbose=0)\
(delayed(fill_values)(i, j, img0, img1, mask, seg, inpaint, threshold, binary)
for i in range(hh) for j in range(ww))
(delayed(fill_values)(i, j, img0, img1, mask, patch, inpaint, threshold, binary)
for i in range(hh) for j in range(ww))

mask = np.array(mask)
if inpaint:
img1 = np.array(img1)

try:
shutil.rmtree(folder)
except:
@@ -233,10 +238,8 @@ def fill_values(i, j, img0, img1, mask, seg, inpaint, threshold, binary):
else:
return mask


def clean_large(self, img0, threshold=0.5, inpaint=True, binary=True,
seg=256):

patch=256):
"""
given input image
return cosmic ray mask and (optionally) clean image
@@ -248,37 +251,36 @@ def clean_large(self, img0, threshold=0.5, inpaint=True, binary=True,
:return: mask or binary mask; or None if internal call
"""
im_shape = img0.shape
hh = int(math.ceil(im_shape[0]/seg))
ww = int(math.ceil(im_shape[1]/seg))
hh = int(math.ceil(im_shape[0]/patch))
ww = int(math.ceil(im_shape[1]/patch))

img0 = np.pad(img0, 3, 'constant')
img1 = np.zeros((im_shape[0], im_shape[1]))
mask = np.zeros((im_shape[0], im_shape[1]))

if inpaint:
for i in tqdm(range(hh)):
for i in range(hh):
for j in range(ww):
img = img0[i * seg:(i + 1) * seg + 6, j * seg:(j + 1) * seg + 6]
img = img0[i * patch: min((i + 1) * patch, im_shape[0]), j * patch: min((j + 1) * patch, im_shape[1])]
mask_, clean_ = self.clean_(img, threshold=threshold, inpaint=True, binary=binary)
mask[i*seg:(i+1)*seg, j*seg:(j+1)*seg] = mask_[3:-3, 3:-3]
img1[i*seg:(i+1)*seg, j*seg:(j+1)*seg] = clean_[3:-3, 3:-3]
mask[i * patch: min((i + 1) * patch, im_shape[0]), j * patch: min((j + 1) * patch, im_shape[1])] = mask_
img1[i * patch: min((i + 1) * patch, im_shape[0]), j * patch: min((j + 1) * patch, im_shape[1])] = clean_
return mask, img1

else:
for i in tqdm(range(hh)):
for i in range(hh):
for j in range(ww):
img = img0[i * seg:(i + 1) * seg + 6, j * seg:(j + 1) * seg + 6]
img = img0[i * patch: min((i + 1) * patch, im_shape[0]), j * patch: min((j + 1) * patch, im_shape[1])]
mask_ = self.clean_(img, threshold=threshold, inpaint=False, binary=binary)
mask[i*seg:(i+1)*seg, j*seg:(j+1)*seg] = mask_[3:-3, 3:-3]
mask[i * patch: min((i + 1) * patch, im_shape[0]), j * patch: min((j + 1) * patch, im_shape[1])] = mask_
return mask

def inpaint(self, img0, mask):

"""
inpaint parts of an image given an inpaint mask
:param img0 (np.ndarray): 2D input image
:param mask (np.ndarray): 2D input mask
:return: inpainted image
inpaint img0 under mask
:param img0: (np.ndarray) input image
:param mask: (np.ndarray) inpainting mask
:return: inpainted clean image
"""
img0 = img0.astype(np.float32) / 100
mask = mask.astype(np.float32)
28 changes: 6 additions & 22 deletions deepCR/test/test_model.py
@@ -21,9 +21,8 @@ def test_deepCR_serial():
def test_deepCR_parallel():

mdl = model.deepCR(mask='ACS-WFC-F606W-2-32', device='CPU')
in_im = np.ones((512, 512))

out = mdl.clean(in_im, parallel=True, seg=256, n_jobs=-1)
in_im = np.ones((299, 299))
out = mdl.clean(in_im, parallel=True, seg=True)
assert (out[0].shape, out[1].shape) == (in_im.shape, in_im.shape)

# Is the serial runtime slower than the parallel runtime on a big image?
@@ -32,34 +31,19 @@ def test_deepCR_parallel():
if os.cpu_count() > 2:
in_im = np.ones((3096, 2000))
t0 = time.time()
out = mdl.clean(in_im, inpaint=False, parallel=True, seg=256)
out = mdl.clean(in_im, inpaint=False, parallel=True, seg=True)
par_runtime = time.time() - t0
assert out.shape == in_im.shape

t0 = time.time()
out = mdl.clean(in_im, inpaint=False, parallel=False, seg=256)
out = mdl.clean(in_im, inpaint=False, parallel=False, seg=True)
ser_runtime = time.time() - t0
# assert False, f"par={par_runtime}, ser={ser_runtime}"
assert par_runtime < ser_runtime


def test_seg():
mdl = model.deepCR(mask='ACS-WFC-F606W-2-32', inpaint='ACS-WFC-F606W-2-32', device='CPU')
in_im = np.ones((4000, 4000))
out = mdl.clean(in_im, seg=256)
out = mdl.clean(in_im, seg=True)
assert (out[0].shape, out[1].shape) == (in_im.shape, in_im.shape)
out = mdl.clean(in_im, inpaint=False, seg=256)
assert out.shape == in_im.shape

"""
def test_consistency():
mdl = model.deepCR(mask='ACS-WFC-F606W-2-32', inpaint='ACS-WFC-F606W-2-32', device='CPU')
in_im = np.ones((1000, 1000))
out1 = mdl.clean(in_im, inpaint=False, binary=False)
out2 = mdl.clean(in_im, seg=256, inpaint=False, binary=False)
print(out1[0][:10], type(out1[0][0]))
print(out2[0][:10], type(out2[0][0]))
assert (out1 == out2).all()
"""
#if __name__ == '__main__':
# test_consistency()
assert out.shape == in_im.shape
