-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5 from profjsb/train
Added model training and evaluation functionality
- Loading branch information
Showing
12 changed files
with
476 additions
and
55 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import numpy as np | ||
from torch.utils.data import Dataset | ||
|
||
class dataset(Dataset): | ||
def __init__(self, image, mask, ignore=None, sky=None, aug_sky=[0, 0], part='train', f_val=0.1, seed=1): | ||
""" custom pytorch dataset class to load deepCR-mask training data | ||
:param image: image with CR | ||
:param mask: CR mask | ||
:param ignore: loss mask, e.g., bad pixel, saturation, etc. | ||
:param sky: (np.ndarray) [N,] sky background level | ||
:param aug_sky: [negative number, positive number]. Add sky background by aug_sky[0] * sky to aug_sky[1] * sky. | ||
:param part: either 'train' or 'val'. split by 0.8, 0.2 | ||
:param f_val: percentage of dataset reserved as validation set. | ||
:param seed: fix numpy random seed to seed, for reproducibility. | ||
""" | ||
|
||
np.random.seed(seed) | ||
len = image.shape[0] | ||
assert f_val < 1 | ||
f_train = 1 - f_val | ||
if sky is None: | ||
sky = np.zeros_like(image) | ||
if ignore is None: | ||
ignore = np.zeros_like(image) | ||
if part == 'train': | ||
self.image = image[np.s_[:int(len * f_train)]] | ||
self.mask = mask[np.s_[:int(len * f_train)]] | ||
self.ignore = ignore[np.s_[:int(len * f_train)]] | ||
self.sky = sky[np.s_[:int(len * f_train)]] | ||
elif part == 'val': | ||
self.image = image[np.s_[int(len * f_train):]] | ||
self.mask = mask[np.s_[int(len * f_train):]] | ||
self.ignore = ignore[np.s_[int(len * f_train):]] | ||
self.sky = sky[np.s_[int(len * f_train):]] | ||
else: | ||
self.image = image | ||
self.mask = mask | ||
self.ignore = ignore | ||
self.sky = sky | ||
self.aug_sky = aug_sky | ||
|
||
def __len__(self): | ||
return self.image.shape[0] | ||
|
||
def __getitem__(self, i): | ||
a = (self.aug_sky[0] + np.random.rand() * (self.aug_sky[1] - self.aug_sky[0])) * self.sky[i] | ||
return self.image[i] + a, self.mask[i], self.ignore[i] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import numpy as np | ||
from tqdm import tqdm as tqdm | ||
|
||
from deepCR.util import maskMetric | ||
from deepCR.dataset import dataset | ||
|
||
|
||
__all__ = 'roc' | ||
|
||
|
||
def _roc(model, data, thresholds): | ||
""" internal function called by roc | ||
:param model: | ||
:param data: deepCR.dataset object | ||
:param thresholds: | ||
:return: tpr, fpr | ||
""" | ||
nROC = thresholds.size | ||
metric = np.zeros((nROC, 4)) | ||
for t in tqdm(range(len(data))): | ||
dat = data[t] | ||
pdt_mask = model.clean(dat[0], inpaint=False, binary=False) | ||
msk = dat[1] | ||
ignore = dat[2] | ||
for i in range(nROC): | ||
binary_mask = np.array(pdt_mask > thresholds[i]) * (1 - ignore) | ||
metric[i] += maskMetric(binary_mask, msk * (1 - ignore)) | ||
TP, TN, FP, FN = metric[:, 0], metric[:, 1], metric[:, 2], metric[:, 3] | ||
tpr = TP / (TP + FN) | ||
fpr = FP / (FP + TN) | ||
return tpr * 100, fpr * 100 | ||
|
||
|
||
def roc(model, image, mask, ignore=None, thresholds=np.linspace(0.001, 0.999, 500)): | ||
""" evaluate model on test set with the ROC curve | ||
:param model: deepCR object | ||
:param image: np.ndarray((N, W, H)) image array | ||
:param mask: np.ndarray((N, W, H)) CR mask array | ||
:param ignore: np.ndarray((N, W, H)) bad pixel array incl. saturation, etc. | ||
:param thresholds: np.ndarray(N) FPR grid on which to evaluate ROC curves | ||
:return: np.ndarray(N), np.ndarray(N): TPR and FPR | ||
""" | ||
data = dataset(image=image, mask=mask, ignore=ignore) | ||
tpr, fpr = _roc(model, data, thresholds=thresholds) | ||
return tpr, fpr |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from .. import dataset | ||
|
||
|
||
def test_dataset(): | ||
inputs = np.random.rand(10,32,32); sky = np.random.rand(10) | ||
data = dataset.dataset(image=inputs, mask=inputs, ignore=inputs, sky=sky, part='train', aug_sky=[1, 1], f_val=0.1) | ||
data0 = data[0] | ||
assert len(data) == 9 | ||
assert (data0[0] == inputs[0] + sky[0]).all() | ||
assert (data0[1] == inputs[0]).all() | ||
assert (data0[2] == inputs[0]).all() | ||
|
||
|
||
if __name__ == '__main__': | ||
test_dataset() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
import deepCR.evaluate as evaluate | ||
from deepCR.model import deepCR | ||
|
||
|
||
def test_eval(): | ||
mdl = deepCR() | ||
var = np.zeros((10,24,24)) | ||
tpr, fpr = evaluate.roc(mdl, image=var, mask=var, thresholds=np.linspace(0,1,10)) | ||
assert tpr.shape == (10,) | ||
|
||
|
||
if __name__ == '__main__': | ||
test_eval() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import os | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from deepCR.training import train | ||
|
||
|
||
def test_train(): | ||
inputs = np.zeros((6, 64, 64)) | ||
sky = np.ones(6) | ||
trainer = train(image=inputs, mask=inputs, sky=sky, aug_sky=[-0.9, 10], epoch=2, verbose=False) | ||
trainer.train() | ||
trainer.save() | ||
|
||
|
||
if __name__ == '__main__': | ||
test_train() |
Oops, something went wrong.