Skip to content

Commit

Permalink
combining NPZ files for labeled ontology
Browse files Browse the repository at this point in the history
  • Loading branch information
ngreenwald committed Jun 7, 2020
1 parent 80f6fe8 commit 053bc5a
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 6 deletions.
59 changes: 53 additions & 6 deletions caliban_toolbox/pipeline.py
Expand Up @@ -32,11 +32,13 @@

def find_sparse_images(labeled_data, cutoff=100):
"""Gets coordinates of images that have very few cells
Args:
labeled_data: predictions used for counting number of cells
labeled_data: segmentation labels
cutoff: minimum number of cells per image
Returns:
numpy.array: index used to remove sparse images
numpy.array: index of images above the threshold
"""

unique_counts = []
Expand All @@ -50,7 +52,8 @@ def find_sparse_images(labeled_data, cutoff=100):


def save_stitched_npzs(stitched_channels, stitched_labels, save_dir):
"""Takes corrected labels and channels and saves to NPZ for caliban round 2 checking
"""Takes corrected labels and channels and saves them into NPZ format
Args:
stitched_channels: original channel data
stitched_labels: stitched labels
Expand All @@ -66,15 +69,59 @@ def save_stitched_npzs(stitched_channels, stitched_labels, save_dir):

def process_stitched_data(base_dir):
    """Takes stitched output and creates folder of NPZs for review

    Loads 'output/stitched_labels.xr' and 'channel_data.xr' from base_dir, creates a
    'stitched_npzs' folder inside base_dir, and passes the loaded data to
    save_stitched_npzs with that folder as the destination.

    Args:
        base_dir: directory to read from

    Raises:
        FileExistsError: if the 'stitched_npzs' folder already exists in base_dir
    """

    stitched_labels = xr.load_dataarray(os.path.join(base_dir, 'output', 'stitched_labels.xr'))
    channel_data = xr.load_dataarray(os.path.join(base_dir, 'channel_data.xr'))

    # os.makedirs is called without exist_ok, so an existing folder is an error
    stitched_folder = os.path.join(base_dir, 'stitched_npzs')
    os.makedirs(stitched_folder)

    save_stitched_npzs(stitched_channels=channel_data, stitched_labels=stitched_labels,
                       save_dir=stitched_folder)


def concatenate_npz_files(npz_list):
    """Joins the 'X' and 'y' entries of several NPZ files along the first axis

    Args:
        npz_list: list of NPZ files

    Returns:
        tuple: concatenated X data and y data
    """

    # read both keys of each file in a single pass over the input
    pairs = [(archive['X'], archive['y']) for archive in npz_list]

    X_data = np.concatenate([x_part for x_part, _ in pairs], axis=0)
    y_data = np.concatenate([y_part for _, y_part in pairs], axis=0)
    return X_data, y_data


def create_combined_npz(npz_dir, save_name):
    """Takes folder of corrected NPZs and combines together into single NPZ file

    The X and y arrays of every NPZ file in npz_dir are concatenated, in sorted
    filename order, and saved to npz_dir under save_name.

    Args:
        npz_dir: directory containing NPZ files
        save_name: name for combined NPZ file

    Raises:
        ValueError: if npz_dir is not a directory, or contains no NPZ files to combine
    """

    if not os.path.isdir(npz_dir):
        raise ValueError("Invalid directory name")

    # np.savez appends the extension when it is missing; normalize the name so the
    # output of a previous run is recognized and not fed back in as an input
    output_name = str(save_name)
    if not output_name.endswith('.npz'):
        output_name += '.npz'

    # sorted so the order of the combined data does not depend on os.listdir
    npz_filenames = sorted(name for name in os.listdir(npz_dir)
                           if name.endswith('.npz') and name != output_name)

    if not npz_filenames:
        raise ValueError("No NPZ files found in {}".format(npz_dir))

    npz_list = []
    try:
        for name in npz_filenames:
            npz_list.append(np.load(os.path.join(npz_dir, name)))

        X_concat, y_concat = concatenate_npz_files(npz_list=npz_list)
    finally:
        # each loaded NPZ keeps its file open until it is explicitly closed
        for npz in npz_list:
            npz.close()

    np.savez(os.path.join(npz_dir, save_name), X=X_concat, y=y_concat)
35 changes: 35 additions & 0 deletions caliban_toolbox/pipeline_test.py
Expand Up @@ -32,6 +32,16 @@
from caliban_toolbox import pipeline


def _make_npz_files(npz_num):
npz_list = []
for i in range(npz_num):
X_data = np.zeros((1, 256, 256, 2))
y_data = np.zeros((1, 256, 256, 1))
npz_list.append({'X': X_data, 'y': y_data})

return npz_list


def test_find_sparse_images():
images = np.zeros((10, 30, 30, 1))
sparse_indices = np.random.choice(range(10), 5, replace=False)
Expand Down Expand Up @@ -81,3 +91,28 @@ def test_process_stitched_data():
channels.to_netcdf(os.path.join(temp_dir, 'channel_data.xr'))

pipeline.process_stitched_data(temp_dir)


def test_concatenate_npz_files():
    """Checks the shapes returned when several single-image NPZs are concatenated"""
    count = 5
    X_concat, y_concat = pipeline.concatenate_npz_files(_make_npz_files(npz_num=count))

    expected_X_shape = (count, 256, 256, 2)
    expected_y_shape = (count, 256, 256, 1)

    assert X_concat.shape == expected_X_shape
    assert y_concat.shape == expected_y_shape


def test_create_combined_npz():
    """Saves several NPZ files to a temp folder and checks the shapes in the combined file"""
    npz_num = 5
    npz_list = _make_npz_files(npz_num)

    with tempfile.TemporaryDirectory() as temp_dir:
        for idx, npz in enumerate(npz_list):
            np.savez(os.path.join(temp_dir, 'test_npz_{}.npz'.format(idx)),
                     X=npz['X'], y=npz['y'])

        pipeline.create_combined_npz(npz_dir=temp_dir, save_name='combined.npz')

        # load inside a with block so the file is closed before the temp folder is removed
        with np.load(os.path.join(temp_dir, 'combined.npz')) as combined:
            assert combined['X'].shape == (npz_num, 256, 256, 2)
            assert combined['y'].shape == (npz_num, 256, 256, 1)

0 comments on commit 053bc5a

Please sign in to comment.