Skip to content

Commit

Permalink
Merge pull request #5 from profjsb/train
Browse files Browse the repository at this point in the history
Added model training and evaluation functionality
  • Loading branch information
profjsb committed Aug 2, 2019
2 parents b9c49de + 306a518 commit 817ee4d
Show file tree
Hide file tree
Showing 12 changed files with 476 additions and 55 deletions.
6 changes: 4 additions & 2 deletions deepCR/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
"""
from deepCR.model import deepCR
from deepCR.training import train
from deepCR.evaluate import roc

__all__ = model.__all__
# __all__ must list the public names as strings; putting the objects themselves
# here makes `from deepCR import *` fail with a TypeError.
__all__ = ['deepCR', 'train', 'roc']

__version__ = '0.1.4'
__version__ = '0.1.5'
47 changes: 47 additions & 0 deletions deepCR/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import numpy as np
from torch.utils.data import Dataset

class dataset(Dataset):
    """PyTorch dataset serving (image, mask, ignore) samples for deepCR-mask."""

    def __init__(self, image, mask, ignore=None, sky=None, aug_sky=(0, 0), part='train', f_val=0.1, seed=1):
        """ custom pytorch dataset class to load deepCR-mask training data
        :param image: image with CR; indexed along its first axis, one sample per entry
        :param mask: CR mask
        :param ignore: loss mask, e.g., bad pixel, saturation, etc. Defaults to zeros shaped like image.
        :param sky: (np.ndarray) [N,] sky background level. Defaults to zeros shaped like image.
        :param aug_sky: [negative number, positive number]. Each item gets a sky background of factor * sky added,
            with factor drawn uniformly between aug_sky[0] and aug_sky[1].
        :param part: 'train' keeps the first int(N * (1 - f_val)) samples, 'val' keeps the remaining ones.
            Any other value keeps the full data set.
        :param f_val: fraction (must be < 1) of the data set reserved as validation set.
        :param seed: fix numpy random seed to seed, for reproducibility. Note that this seeds NumPy's global
            random state; the train/val split itself is a plain contiguous slice and is not shuffled.
        """
        np.random.seed(seed)
        n_total = image.shape[0]
        assert f_val < 1
        f_train = 1 - f_val
        if sky is None:
            sky = np.zeros_like(image)
        if ignore is None:
            ignore = np.zeros_like(image)

        # Index of the first validation sample.
        n_train = int(n_total * f_train)
        if part == 'train':
            subset = np.s_[:n_train]
        elif part == 'val':
            subset = np.s_[n_train:]
        else:
            subset = None

        if subset is None:
            # Pass-through: keep the inputs exactly as given.
            self.image = image
            self.mask = mask
            self.ignore = ignore
            self.sky = sky
        else:
            self.image = image[subset]
            self.mask = mask[subset]
            self.ignore = ignore[subset]
            self.sky = sky[subset]

        # Store a private copy so instances never share (or mutate) a common default object.
        self.aug_sky = list(aug_sky)

    def __len__(self):
        """Return the number of samples in the selected part."""
        return self.image.shape[0]

    def __getitem__(self, i):
        """Return (sky-augmented image, mask, ignore) for sample i.

        A fresh augmentation factor is drawn with np.random.rand() on every access.
        """
        low, high = self.aug_sky[0], self.aug_sky[1]
        factor = low + np.random.rand() * (high - low)
        return self.image[i] + factor * self.sky[i], self.mask[i], self.ignore[i]
47 changes: 47 additions & 0 deletions deepCR/evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import numpy as np
from tqdm import tqdm as tqdm

from deepCR.util import maskMetric
from deepCR.dataset import dataset


# __all__ has to be a sequence of names; a bare string would be iterated
# character by character ('r', 'o', 'c') by `from deepCR.evaluate import *`.
__all__ = ['roc']


def _roc(model, data, thresholds):
    """ internal function called by roc

    Runs model.clean on every sample of data, binarises the predicted mask at each
    threshold and accumulates the output of maskMetric per threshold.

    :param model: object exposing clean(image, inpaint=False, binary=False)
    :param data: deepCR.dataset object; data[i] yields (image, mask, ignore)
    :param thresholds: np.ndarray of cut values applied to the predicted mask
    :return: tpr, fpr, both multiplied by 100, one entry per threshold
    """
    n_thresholds = thresholds.size
    # One row per threshold; the four columns are read back below as TP, TN, FP, FN.
    counts = np.zeros((n_thresholds, 4))
    for idx in tqdm(range(len(data))):
        sample = data[idx]
        predicted = model.clean(sample[0], inpaint=False, binary=False)
        truth = sample[1]
        ignore = sample[2]
        # Pixels flagged in `ignore` are zeroed in both prediction and ground truth.
        keep = 1 - ignore
        for k in range(n_thresholds):
            thresholded = np.array(predicted > thresholds[k]) * keep
            counts[k] += maskMetric(thresholded, truth * keep)
    true_pos, true_neg, false_pos, false_neg = counts.T
    tpr = true_pos / (true_pos + false_neg)
    fpr = false_pos / (false_pos + true_neg)
    return tpr * 100, fpr * 100


def roc(model, image, mask, ignore=None, thresholds=np.linspace(0.001, 0.999, 500)):
    """ evaluate model on test set with the ROC curve
    :param model: deepCR object
    :param image: np.ndarray((N, W, H)) image array
    :param mask: np.ndarray((N, W, H)) CR mask array
    :param ignore: np.ndarray((N, W, H)) bad pixel array incl. saturation, etc.
    :param thresholds: np.ndarray(M) grid of thresholds applied to the predicted mask; one ROC point per value
    :return: np.ndarray(M), np.ndarray(M): TPR and FPR, both scaled by 100
    """
    # dataset() defaults to part='train' with f_val=0.1, which would silently drop
    # the last 10% of the images. Any value other than 'train'/'val' selects its
    # pass-through branch, so the whole test set is evaluated.
    data = dataset(image=image, mask=mask, ignore=ignore, part='all')
    tpr, fpr = _roc(model, data, thresholds=thresholds)
    return tpr, fpr
83 changes: 41 additions & 42 deletions deepCR/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,46 +3,39 @@
from os import path, mkdir
import math
import numpy as np
import joblib

import torch
import torch.nn as nn
from torch import from_numpy
from joblib import Parallel, delayed
from joblib import dump, load
from joblib import wrap_non_picklable_objects

from tqdm import tqdm


from deepCR.unet import WrappedModel
from deepCR.unet import WrappedModel, UNet2Sigmoid
from deepCR.util import medmask
from learned_models import mask_dict, inpaint_dict, default_model_path

__all__ = ('deepCR', 'mask_dict', 'inpaint_dict', 'default_model_path')
# __all__ has to be a sequence of names; a bare string would be iterated
# character by character by `from deepCR.model import *`.
__all__ = ['deepCR']


class deepCR():

def __init__(self, mask='ACS-WFC-F606W-2-32', inpaint='ACS-WFC-F606W-2-32', device='CPU',
model_dir=default_model_path):
def __init__(self, mask='ACS-WFC-F606W-2-32', inpaint=None, device='CPU', hidden=32):

"""
Instantiation of deepCR with specified model configurations
Parameters
----------
mask : str
Name of deepCR-mask model to use.
inpaint : str
Name of the inpainting model to use. It can also be `medmask` which will then
use a simple 5x5 median mask sampling for inpainting
Either name of existing deepCR-mask model, or file path of your own model (incl. '.pth')
inpaint : (optional) str
Name of existing inpainting model to use. If left as None then by default use a simple 5x5 median mask
sampling for inpainting
device : str
One of 'CPU' or 'GPU'
model_dir : str
The location of the model directory with the mask/ and inpaint/
subdirectories. This defaults to where the pre-shipped
models live (in `learned_models/`)
hidden : int
Number of hidden channel for first deepCR-mask layer. Specify only if using custom deepCR-mask model.
Returns
-------
None
Expand All @@ -55,47 +48,52 @@ def __init__(self, mask='ACS-WFC-F606W-2-32', inpaint='ACS-WFC-F606W-2-32', devi
self.dtype = torch.FloatTensor
self.dint = torch.ByteTensor
wrapper = WrappedModel

if mask is not None:
if mask in mask_dict.keys():
self.scale = mask_dict[mask][2]
mask_path = default_model_path + '/mask/' + mask + '.pth'
self.maskNet = wrapper(mask_dict[mask][0](*mask_dict[mask][1]))
self.maskNet.type(self.dtype)
if device != 'GPU':
self.maskNet.load_state_dict(torch.load(model_dir + '/mask/' + mask + '.pth',
map_location='cpu'))
else:
self.maskNet.load_state_dict(torch.load(model_dir + '/mask/' + mask + '.pth'))

self.maskNet.eval()
for p in self.maskNet.parameters():
p.required_grad = False

if inpaint == 'medmask':
self.inpaintNet = None
else:
self.scale = 1
mask_path = mask
self.maskNet = wrapper(UNet2Sigmoid(1, 1, hidden))
self.maskNet.type(self.dtype)
if device != 'GPU':
self.maskNet.load_state_dict(torch.load(mask_path, map_location='cpu'))
else:
self.maskNet.load_state_dict(torch.load(mask_path))
self.maskNet.eval()
for p in self.maskNet.parameters():
p.required_grad = False

if inpaint is not None:
inpaint_path = default_model_path + '/inpaint/' + inpaint + '.pth'
self.inpaintNet = wrapper(inpaint_dict[inpaint][0](*inpaint_dict[inpaint][1])).type(self.dtype)
if device != 'GPU':
self.inpaintNet.load_state_dict(torch.load(model_dir+'/inpaint/' + inpaint+'.pth',
map_location='cpu'))
self.inpaintNet.load_state_dict(torch.load(inpaint_path, map_location='cpu'))
else:
self.inpaintNet.load_state_dict(torch.load(model_dir+'/inpaint/' + inpaint+'.pth'))
self.inpaintNet.load_state_dict(torch.load(inpaint_path))
self.inpaintNet.eval()
for p in self.inpaintNet.parameters():
p.required_grad = False

self.scale = mask_dict[mask][2]
else:
self.inpaintNet = None

def clean(self, img0, threshold=0.5, inpaint=True, binary=True, segment=False,
patch=256, parallel=False, n_jobs=-1):
"""
Identify cosmic rays in an input image, and (optionally) inpaint with the predicted cosmic ray mask
:param img0: (np.ndarray) 2D input image conforming to model requirements. For HST ACS/WFC, must be from _flc.fits and in units of electrons in native resolution.
:param img0: (np.ndarray) 2D input image conforming to model requirements. For HST ACS/WFC, must be from
_flc.fits and in units of electrons in native resolution.
:param threshold: (float; [0, 1]) applied to probabilistic mask to generate binary mask
:param inpaint: (bool) return clean, inpainted image only if True
:param binary: return binary CR mask if True. probabilistic mask if False
:param segment: (bool) if True, segment input image into chunks of patch * patch before performing CR rejection. Used for memory control.
:param patch: (int) Use 256 unless otherwise required. if segment==True, segment image into chunks of patch * patch.
:param segment: (bool) if True, segment input image into chunks of patch * patch before performing CR rejection.
Used for memory control.
:param patch: (int) Use 256 unless otherwise required. if segment==True, segment image into chunks of
patch * patch.
:param parallel: (bool) run in parallel if True and segment==True
:param n_jobs: (int) number of jobs to run in parallel, passed to `joblib.` Beware of memory overflow for larger n_jobs.
:param n_jobs: (int) number of jobs to run in parallel, passed to `joblib.` Beware of memory overflow for
larger n_jobs.
:return: CR mask and (optionally) clean inpainted image
"""

Expand Down Expand Up @@ -183,7 +181,8 @@ def clean_large_parallel(self, img0, threshold=0.5, inpaint=True, binary=True,
:param inpaint: return clean image only if True
:param binary: return binary mask if True. probabilistic mask otherwise.
:param patch: (int) Use 256 unless otherwise required. patch size to run deepCR on.
:param n_jobs: (int) number of jobs to run in parallel, passed to `joblib.` Beware of memory overflow for larger n_jobs.
:param n_jobs: (int) number of jobs to run in parallel, passed to `joblib.` Beware of memory overflow for
larger n_jobs.
:return: CR mask and (optionally) clean inpainted image
"""
folder = './joblib_memmap'
Expand Down
2 changes: 0 additions & 2 deletions deepCR/parts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class double_conv(nn.Module):
def __init__(self, in_ch, out_ch):
Expand Down
18 changes: 18 additions & 0 deletions deepCR/test/test_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import numpy as np
import pytest

from .. import dataset


def test_dataset():
    """Check the size of the 'train' split and the contents of its first sample."""
    inputs = np.random.rand(10, 32, 32)
    sky = np.random.rand(10)
    # aug_sky=[1, 1] pins the augmentation factor to exactly 1, so the expected
    # image is simply inputs + sky.
    data = dataset.dataset(
        image=inputs,
        mask=inputs,
        ignore=inputs,
        sky=sky,
        part='train',
        aug_sky=[1, 1],
        f_val=0.1,
    )
    sample = data[0]
    assert len(data) == 9
    expected_image = inputs[0] + sky[0]
    assert (sample[0] == expected_image).all()
    assert (sample[1] == inputs[0]).all()
    assert (sample[2] == inputs[0]).all()


if __name__ == '__main__':
    test_dataset()
16 changes: 16 additions & 0 deletions deepCR/test/test_evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import numpy as np
import pytest

import deepCR.evaluate as evaluate
from deepCR.model import deepCR


def test_eval():
    """Smoke-test evaluate.roc: one TPR value must come back per threshold."""
    model = deepCR()
    blank = np.zeros((10, 24, 24))
    grid = np.linspace(0, 1, 10)
    tpr, fpr = evaluate.roc(model, image=blank, mask=blank, thresholds=grid)
    assert tpr.shape == (10,)


if __name__ == '__main__':
    test_eval()
18 changes: 10 additions & 8 deletions deepCR/test/test_model.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
import os
import time

import numpy as np
import pytest

from .. import model
from deepCR.model import deepCR


def test_deepCR_serial():

mdl = model.deepCR(mask='ACS-WFC-F606W-2-32', device='CPU')
mdl = deepCR(mask='ACS-WFC-F606W-2-32', device='CPU')
in_im = np.ones((299, 299))
out = mdl.clean(in_im)
assert (out[0].shape, out[1].shape) == (in_im.shape, in_im.shape)

out = mdl.clean(in_im, inpaint=False)
assert out.shape == in_im.shape


def test_deepCR_parallel():

mdl = model.deepCR(mask='ACS-WFC-F606W-2-32', device='CPU')
mdl = deepCR(mask='ACS-WFC-F606W-2-32', device='CPU')
in_im = np.ones((299, 299))
out = mdl.clean(in_im, parallel=True)
assert (out[0].shape, out[1].shape) == (in_im.shape, in_im.shape)
Expand All @@ -38,15 +38,17 @@ def test_deepCR_parallel():
ser_runtime = time.time() - t0
assert par_runtime < ser_runtime


def test_seg():
mdl = model.deepCR(mask='ACS-WFC-F606W-2-32', inpaint='ACS-WFC-F606W-2-32', device='CPU')
in_im = np.ones((500, 1000))
mdl = deepCR(mask='ACS-WFC-F606W-2-32', inpaint='ACS-WFC-F606W-2-32', device='CPU')
in_im = np.ones((300, 500))
out = mdl.clean(in_im, segment=True)
assert (out[0].shape, out[1].shape) == (in_im.shape, in_im.shape)
out = mdl.clean(in_im, inpaint=False, segment=True)
assert out.shape == in_im.shape


if __name__ == '__main__':
test_seg()
test_deepCR_parallel()
test_deepCR_serial()
test_deepCR_parallel()
18 changes: 18 additions & 0 deletions deepCR/test/test_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import os

import numpy as np
import pytest

from deepCR.training import train


def test_train():
    """Smoke-test the trainer: build it on tiny blank data, train 2 epochs, save."""
    images = np.zeros((6, 64, 64))
    sky_levels = np.ones(6)
    trainer = train(
        image=images,
        mask=images,
        sky=sky_levels,
        aug_sky=[-0.9, 10],
        epoch=2,
        verbose=False,
    )
    trainer.train()
    trainer.save()


if __name__ == '__main__':
    test_train()

0 comments on commit 817ee4d

Please sign in to comment.