diff --git a/caliban_toolbox/pipeline.py b/caliban_toolbox/pipeline.py index e910e18..c3f7547 100644 --- a/caliban_toolbox/pipeline.py +++ b/caliban_toolbox/pipeline.py @@ -55,8 +55,8 @@ def save_stitched_npzs(stitched_channels, stitched_labels, save_dir): """Takes corrected labels and channels and saves them into NPZ format Args: - stitched_channels: original channel data - stitched_labels: stitched labels + stitched_channels: xarray containing original channel data + stitched_labels: xarray containing stitched labels """ for i in range(stitched_channels.shape[0]): @@ -67,23 +67,6 @@ def save_stitched_npzs(stitched_channels, stitched_labels, save_dir): np.savez(save_path, X=X, y=y) -def process_stitched_data(base_dir): - """Takes stitched output and creates folder of NPZs for review - - Args: - base_dir: directory to read from - """ - - stitched_labels = xr.load_dataarray(os.path.join(base_dir, 'output', 'stitched_labels.xr')) - channel_data = xr.load_dataarray(os.path.join(base_dir, 'channel_data.xr')) - - stitched_folder = os.path.join(base_dir, 'stitched_npzs') - os.makedirs(stitched_folder) - - save_stitched_npzs(stitched_channels=channel_data, stitched_labels=stitched_labels, - save_dir=stitched_folder) - - def concatenate_npz_files(npz_list): """Takes a list of NPZ files and combines the X and y keys of each together diff --git a/caliban_toolbox/pipeline_test.py b/caliban_toolbox/pipeline_test.py index dd1aeda..c41f078 100644 --- a/caliban_toolbox/pipeline_test.py +++ b/caliban_toolbox/pipeline_test.py @@ -75,24 +75,6 @@ def test_save_stitched_npzs(): assert np.all(np.isin(npzs, npz_names)) -def test_process_stitched_data(): - channels = np.zeros((4, 100, 100, 2)) - labels = np.zeros((4, 100, 100, 1)) - - coords_labels = [['fov1', 'fov2', 'fov3', 'fov4'], range(100), range(100), [0]] - coords_channels = [['fov1', 'fov2', 'fov3', 'fov4'], range(100), range(100), range(2)] - dims = ['fovs', 'rows', 'cols', 'channels'] - labels = xr.DataArray(labels, coords=coords_labels, dims=dims) - channels = xr.DataArray(channels, coords=coords_channels, dims=dims) - - with tempfile.TemporaryDirectory() as temp_dir: - os.makedirs(os.path.join(temp_dir, 'output')) - labels.to_netcdf(os.path.join(temp_dir, 'output', 'stitched_labels.xr')) - channels.to_netcdf(os.path.join(temp_dir, 'channel_data.xr')) - - pipeline.process_stitched_data(temp_dir) - - def test_concatenate_npz_files(): npz_num = 5 npz_list = _make_npz_files(npz_num=npz_num) diff --git a/caliban_toolbox/reshape_data.py b/caliban_toolbox/reshape_data.py index 77cb33b..23fe5b1 100644 --- a/caliban_toolbox/reshape_data.py +++ b/caliban_toolbox/reshape_data.py @@ -193,6 +193,9 @@ def reconstruct_image_stack(crop_dir, verbose=True): Args: crop_dir: full path to directory with cropped images verbose: flag to control print statements + + Returns: + stitched_images: xarray containing the stitched image stack """ # sanitize inputs @@ -225,4 +228,4 @@ def reconstruct_image_stack(crop_dir, verbose=True): stitched_xr = xr.DataArray(data=image_stack, coords=coordinate_labels, dims=dimension_labels) - stitched_xr.to_netcdf(os.path.join(crop_dir, 'stitched_images.xr')) + return stitched_xr diff --git a/caliban_toolbox/reshape_data_test.py b/caliban_toolbox/reshape_data_test.py index 9149abe..669057a 100644 --- a/caliban_toolbox/reshape_data_test.py +++ b/caliban_toolbox/reshape_data_test.py @@ -178,9 +178,7 @@ def test_reconstruct_image_stack(): io_utils.save_npzs_for_caliban(X_data=X_cropped, y_data=y_cropped, original_data=X_data, log_data=log_data, save_dir=temp_dir) - reshape_data.reconstruct_image_stack(crop_dir=temp_dir) - - stitched_imgs = xr.open_dataarray(os.path.join(temp_dir, 'stitched_images.xr')) + stitched_imgs = reshape_data.reconstruct_image_stack(crop_dir=temp_dir) # dims are the same assert np.all(stitched_imgs.shape == y_data.shape) @@ -219,8 +217,7 @@ def test_reconstruct_image_stack(): blank_labels="include", save_format="npz", verbose=False) - reshape_data.reconstruct_image_stack(temp_dir) - stitched_imgs = xr.open_dataarray(os.path.join(temp_dir, 'stitched_images.xr')) + stitched_imgs = reshape_data.reconstruct_image_stack(temp_dir) assert np.all(stitched_imgs.shape == y_data.shape) assert np.all(np.equal(stitched_imgs[0, :, 0, 0, 0, 0, 0], tags)) @@ -272,8 +269,7 @@ def test_reconstruct_image_stack(): blank_labels="include", save_format="npz", verbose=False) - reshape_data.reconstruct_image_stack(temp_dir) - stitched_imgs = xr.open_dataarray(os.path.join(temp_dir, 'stitched_images.xr')) + stitched_imgs = reshape_data.reconstruct_image_stack(temp_dir) assert np.all(stitched_imgs.shape == y_data.shape)