Skip to content

Commit

Permalink
reverted model convention
Browse files Browse the repository at this point in the history
  • Loading branch information
kmzzhang committed Jul 25, 2019
1 parent e9d0195 commit b9c49de
Show file tree
Hide file tree
Showing 7 changed files with 7 additions and 12 deletions.
2 changes: 1 addition & 1 deletion deepCR/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

class deepCR():

def __init__(self, mask='ACS-WFC-2-32', inpaint='ACS-WFC-2-32', device='CPU',
def __init__(self, mask='ACS-WFC-F606W-2-32', inpaint='ACS-WFC-F606W-2-32', device='CPU',
model_dir=default_model_path):

"""
Expand Down
6 changes: 3 additions & 3 deletions deepCR/test/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

def test_deepCR_serial():

mdl = model.deepCR(mask='ACS-WFC-2-32', device='CPU')
mdl = model.deepCR(mask='ACS-WFC-F606W-2-32', device='CPU')
in_im = np.ones((299, 299))
out = mdl.clean(in_im)
assert (out[0].shape, out[1].shape) == (in_im.shape, in_im.shape)
Expand All @@ -18,7 +18,7 @@ def test_deepCR_serial():

def test_deepCR_parallel():

mdl = model.deepCR(mask='ACS-WFC-2-32', device='CPU')
mdl = model.deepCR(mask='ACS-WFC-F606W-2-32', device='CPU')
in_im = np.ones((299, 299))
out = mdl.clean(in_im, parallel=True)
assert (out[0].shape, out[1].shape) == (in_im.shape, in_im.shape)
Expand All @@ -39,7 +39,7 @@ def test_deepCR_parallel():
assert par_runtime < ser_runtime

def test_seg():
mdl = model.deepCR(mask='ACS-WFC-2-32', inpaint='ACS-WFC-2-32', device='CPU')
mdl = model.deepCR(mask='ACS-WFC-F606W-2-32', inpaint='ACS-WFC-F606W-2-32', device='CPU')
in_im = np.ones((500, 1000))
out = mdl.clean(in_im, segment=True)
assert (out[0].shape, out[1].shape) == (in_im.shape, in_im.shape)
Expand Down
11 changes: 3 additions & 8 deletions learned_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,11 @@

__all__ = ('mask_dict', 'inpaint_dict')

mask_dict = {'ACS-WFC-2-32': [UNet2Sigmoid, (1, 1, 32), 100],
'ACS-WFC-2-4': [UNet2Sigmoid, (1, 1, 4), 100],
'ACS-WFC-F606W-2-32': [UNet2Sigmoid, (1, 1, 32), 100],
mask_dict = {'ACS-WFC-F606W-2-32': [UNet2Sigmoid, (1, 1, 32), 100],
'ACS-WFC-F606W-2-4': [UNet2Sigmoid, (1, 1, 4), 100],
'example_model': [UNet2Sigmoid, (1, 1, 32), 100]}

inpaint_dict = {'ACS-WFC-3-32': [UNet3, (2, 1, 32)],
'ACS-WFC-2-32': [UNet2, (2, 1, 32)],
'ACS-WFC-F606W-3-32': [UNet3, (2, 1, 32)],
'ACS-WFC-F606W-2-32': [UNet2, (2, 1, 32)]
}
inpaint_dict = {'ACS-WFC-F606W-3-32': [UNet3, (2, 1, 32)],
'ACS-WFC-F606W-2-32': [UNet2, (2, 1, 32)]}

default_model_path = path.join(path.dirname(__file__))
Binary file removed learned_models/inpaint/ACS-WFC-2-32.pth
Binary file not shown.
Binary file removed learned_models/inpaint/ACS-WFC-3-32.pth
Binary file not shown.
Binary file removed learned_models/mask/ACS-WFC-2-32.pth
Binary file not shown.
Binary file removed learned_models/mask/ACS-WFC-2-4.pth
Binary file not shown.

0 comments on commit b9c49de

Please sign in to comment.