diff --git a/caliban_toolbox/pipeline.py b/caliban_toolbox/pipeline.py
index 380ac01..e910e18 100644
--- a/caliban_toolbox/pipeline.py
+++ b/caliban_toolbox/pipeline.py
@@ -32,11 +32,13 @@
 
 def find_sparse_images(labeled_data, cutoff=100):
     """Gets coordinates of images that have very few cells
+
     Args:
-        labeled_data: predictions used for counting number of cells
+        labeled_data: segmentation labels
         cutoff: minimum number of cells per image
+
     Returns:
-        numpy.array: index used to remove sparse images
+        numpy.array: index of images above the threshold
     """
 
     unique_counts = []
@@ -50,7 +52,8 @@ def find_sparse_images(labeled_data, cutoff=100):
 
 
 def save_stitched_npzs(stitched_channels, stitched_labels, save_dir):
-    """Takes corrected labels and channels and saves to NPZ for caliban round 2 checking
+    """Takes corrected labels and channels and saves them into NPZ format
+
     Args:
         stitched_channels: original channel data
         stitched_labels: stitched labels
@@ -66,6 +69,7 @@ def save_stitched_npzs(stitched_channels, stitched_labels, save_dir):
 
 def process_stitched_data(base_dir):
     """Takes stitched output and creates folder of NPZs for review
+
     Args:
         base_dir: directory to read from
     """
@@ -73,8 +77,51 @@ def process_stitched_data(base_dir):
     stitched_labels = xr.load_dataarray(os.path.join(base_dir, 'output', 'stitched_labels.xr'))
     channel_data = xr.load_dataarray(os.path.join(base_dir, 'channel_data.xr'))
 
-    correction_folder = os.path.join(base_dir, 'ready_to_correct')
-    os.makedirs(correction_folder)
+    stitched_folder = os.path.join(base_dir, 'stitched_npzs')
+    os.makedirs(stitched_folder)
 
     save_stitched_npzs(stitched_channels=channel_data, stitched_labels=stitched_labels,
-                       save_dir=correction_folder)
+                       save_dir=stitched_folder)
+
+
+def concatenate_npz_files(npz_list):
+    """Takes a list of NPZ files and combines the X and y keys of each together
+
+    Args:
+        npz_list: list of NPZ files
+
+    Returns:
+        tuple: concatenated X data and y data
+    """
+
+    X_data = []
+    y_data = []
+    for npz in npz_list:
+        X_data.append(npz['X'])
+        y_data.append(npz['y'])
+    X_data = np.concatenate(X_data, axis=0)
+    y_data = np.concatenate(y_data, axis=0)
+    return X_data, y_data
+
+
+def create_combined_npz(npz_dir, save_name):
+    """Takes folder of corrected NPZs and combines together into single NPZ file
+
+    Args:
+        npz_dir: directory containing NPZ files
+        save_name: name for combined NPZ file
+
+    Raises: ValueError if invalid directory name
+    """
+
+    if not os.path.isdir(npz_dir):
+        raise ValueError("Invalid directory name")
+
+    npz_filenames = os.listdir(npz_dir)
+    npz_filenames = [file for file in npz_filenames if '.npz' in file]
+
+    npz_list = [np.load(os.path.join(npz_dir, npz)) for npz in npz_filenames]
+
+    X_concat, y_concat = concatenate_npz_files(npz_list=npz_list)
+
+    np.savez(os.path.join(npz_dir, save_name), X=X_concat, y=y_concat)
diff --git a/caliban_toolbox/pipeline_test.py b/caliban_toolbox/pipeline_test.py
index bec5089..dd1aeda 100644
--- a/caliban_toolbox/pipeline_test.py
+++ b/caliban_toolbox/pipeline_test.py
@@ -32,6 +32,16 @@
 from caliban_toolbox import pipeline
 
 
+def _make_npz_files(npz_num):
+    npz_list = []
+    for i in range(npz_num):
+        X_data = np.zeros((1, 256, 256, 2))
+        y_data = np.zeros((1, 256, 256, 1))
+        npz_list.append({'X': X_data, 'y': y_data})
+
+    return npz_list
+
+
 def test_find_sparse_images():
     images = np.zeros((10, 30, 30, 1))
     sparse_indices = np.random.choice(range(10), 5, replace=False)
@@ -81,3 +91,28 @@ def test_process_stitched_data():
         channels.to_netcdf(os.path.join(temp_dir,
                                         'channel_data.xr'))
         pipeline.process_stitched_data(temp_dir)
+
+
+def test_concatenate_npz_files():
+    npz_num = 5
+    npz_list = _make_npz_files(npz_num=npz_num)
+    X_concat, y_concat = pipeline.concatenate_npz_files(npz_list)
+
+    assert X_concat.shape == (npz_num, 256, 256, 2)
+    assert y_concat.shape == (npz_num, 256, 256, 1)
+
+
+def test_create_combined_npz():
+    npz_list = _make_npz_files(5)
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        for idx, npz in enumerate(npz_list):
+            np.savez(os.path.join(temp_dir, 'test_npz_{}.npz'.format(idx)),
+                     X=npz['X'], y=npz['y'])
+
+        pipeline.create_combined_npz(npz_dir=temp_dir, save_name='combined.npz')
+
+        combined = np.load(os.path.join(temp_dir, 'combined.npz'))
+
+        assert combined['X'].shape == (5, 256, 256, 2)
+        assert combined['y'].shape == (5, 256, 256, 1)
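
Usage sketch (not part of the patch): a minimal end-to-end example of the functions added above, assuming they are importable from caliban_toolbox.pipeline and that the experiment directory already contains the channel_data.xr and output/stitched_labels.xr files the pipeline expects. The experiment path and the combined file name below are hypothetical placeholders.

    import os
    import numpy as np
    from caliban_toolbox import pipeline

    # hypothetical experiment directory holding channel_data.xr and output/stitched_labels.xr
    base_dir = '/path/to/experiment'

    # split the stitched channels and labels into per-image NPZs for review;
    # per the patch above, these are written to base_dir/stitched_npzs
    pipeline.process_stitched_data(base_dir)

    # once the NPZs have been corrected, combine them into a single file
    npz_dir = os.path.join(base_dir, 'stitched_npzs')
    pipeline.create_combined_npz(npz_dir=npz_dir, save_name='combined.npz')

    # the combined file stores the concatenated arrays under the 'X' and 'y' keys
    combined = np.load(os.path.join(npz_dir, 'combined.npz'))
    print(combined['X'].shape, combined['y'].shape)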