In [None]:
import matplotlib.pyplot as plt
import numpy as np
import skimage as sk
import h5py
import pickle

from ipywidgets import (
    interact,
    interactive,
    fixed,
    interact_manual,
    FloatSlider,
    IntSlider,
    Dropdown,
    IntText,
    SelectMultiple,
    Select,
    IntRangeSlider,
    FloatRangeSlider,
)
from skimage import filters, transform
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.collections import PolyCollection
from paulssonlab.deaton.trenchripper.trenchripper.kymograph import kymograph_cluster
from paulssonlab.deaton.trenchripper.trenchripper.segment import fluo_segmentation
from paulssonlab.deaton.trenchripper.trenchripper.utils import (
    kymo_handle,
    pandas_hdf5_handler,
)

In [None]:
import pandas as pd

In [None]:
# Metadata for the 2019-05-31 validation dataset; edit this path to point at another run.
METADATA_PATH = "/n/scratch2/de64/2019-05-31_validation_data/metadata.hdf5"
meta_handle = pandas_hdf5_handler(METADATA_PATH)
test = meta_handle.read_df("global", read_metadata=True)

In [None]:
# Preview FOVs 2 and 36, second-level index labels 0 through 5 (label slices are inclusive).
# Other selections follow the same pattern, e.g. test.loc[pd.IndexSlice[:, [3, 4]], :]
fov_subset = [2, 36]
test.loc[pd.IndexSlice[fov_subset, 0:5], :]

In [None]:
class kymograph_interactive(kymograph_cluster):
    """Widget-driven front end for choosing kymograph-generation parameters.

    Each ``*_interactive`` method displays ipywidgets controls that call the
    matching ``preview_*`` method. Previews plot intermediate results and record
    the chosen values in ``self.final_params``, which ``write_param_file``
    pickles to ``<headpath>/kymograph.par``.

    Processing helpers used below (``get_smoothed_y_percentiles``,
    ``map_to_fovs``, ``get_trench_edges_y``, ``crop_y``, ...) are not defined in
    this class and are presumably inherited from ``kymograph_cluster``.
    """

    def __init__(self, headpath):
        """Load the global metadata table and initialise parameter storage.

        NOTE: I need to revisit the row detection, must ensure there can be no overlap...

        Args:
            headpath (str): Dataset root directory. Passed to
                ``kymograph_cluster``; images are read from
                ``<headpath>/hdf5/hdf5_<file index>.hdf5``.
        """
        super().__init__(headpath=headpath)

        # Global metadata table (index levels include "fov"); the attached
        # ``metadata`` dict carries experiment-wide values.
        self.metadf = self.meta_handle.read_df("global", read_metadata=True)
        self.metadata = self.metadf.metadata
        self.fov_list = self.metadf.index.get_level_values("fov").unique().values
        self.channels = self.metadata["channels"]
        self.timepoints_len = self.metadata["num_frames"]

        # Values chosen in the interactive previews, keyed by a readable name.
        self.final_params = {}

    def view_image(self, fov_idx, t, channel, invert):
        """Display a single raw image.

        Args:
            fov_idx: FOV label (first level of ``self.metadf``'s index).
            t (int): Second-level index label (timepoint).
            channel (str): Channel dataset name inside the hdf5 file.
            invert (bool): If True, invert intensities with ``skimage.util.invert``.
        """
        img_entry = self.metadf.loc[fov_idx, t]
        file_idx = int(img_entry["File Index"])
        img_idx = int(img_entry["Image Index"])

        with h5py.File(
            self.headpath + "/hdf5/hdf5_" + str(file_idx) + ".hdf5", "r"
        ) as infile:
            img_arr = infile[channel][img_idx, :, :]
        if invert:
            img_arr = sk.util.invert(img_arr)
        plt.imshow(img_arr, cmap="Greys_r")

    def view_image_interactive(self):
        """Show widgets for browsing raw images by FOV, timepoint and channel."""
        interact(
            self.view_image,
            fov_idx=Select(description="FOV number:", options=self.fov_list),
            t=IntSlider(
                value=0,
                min=0,
                max=self.timepoints_len - 1,
                step=1,
                continuous_update=False,
            ),
            channel=Dropdown(
                options=self.channels,
                value=self.channels[0],
                description="Channel:",
                disabled=False,
            ),
            invert=Dropdown(options=[True, False], value=False),
        )

    def import_hdf5_files(
        self, all_channels, seg_channel, invert, fov_list, t_subsample_step
    ):
        """Select the FOVs, timepoints and channel order used by later previews.

        Sets ``all_channels``, ``seg_channel``, ``fov_list``, ``t_subsample_step``,
        ``invert``, ``fovdf`` (metadata subset), ``filedf`` (same rows indexed
        by file/image index) and ``file_list`` (unique file indices).

        Args:
            all_channels (list): All channel names. Not modified; a reordered
                copy with ``seg_channel`` first is stored.
            seg_channel (str): Channel used for row/trench detection.
            invert (bool): Whether images should be intensity-inverted.
            fov_list (iterable): FOV labels to process (a tuple when it comes
                from a ``SelectMultiple`` widget).
            t_subsample_step (int): Step applied to the second index level;
                must be >= 1.

        Raises:
            ValueError: If ``t_subsample_step`` is less than 1.
        """
        if t_subsample_step < 1:
            raise ValueError(
                "t_subsample_step must be >= 1, got " + str(t_subsample_step)
            )

        # Copy before reordering: the caller typically passes self.channels,
        # which is shared with self.metadata and the channel dropdowns.
        all_channels = list(all_channels)
        all_channels.insert(0, all_channels.pop(all_channels.index(seg_channel)))
        self.all_channels = all_channels
        self.seg_channel = all_channels[0]
        # A list (not a tuple) is unambiguously a list of labels for .loc.
        self.fov_list = list(fov_list)
        self.t_subsample_step = t_subsample_step
        self.invert = invert

        self.fovdf = self.meta_handle.read_df("global", read_metadata=True)
        self.fovdf = self.fovdf.loc[
            pd.IndexSlice[self.fov_list, :: self.t_subsample_step], :
        ]

        self.filedf = self.fovdf.reset_index(inplace=False)
        self.filedf = self.filedf.set_index(
            ["File Index", "Image Index"], drop=True, append=False, inplace=False
        )
        self.filedf = self.filedf.sort_index()
        self.file_list = (
            self.filedf.index.get_level_values("File Index").unique().values
        )

    def import_hdf5_interactive(self):
        """Show widgets for choosing channels, FOVs and time subsampling."""
        import_hdf5 = interactive(
            self.import_hdf5_files,
            {"manual": True},
            all_channels=fixed(self.channels),
            seg_channel=Dropdown(options=self.channels, value=self.channels[0]),
            invert=Dropdown(options=[True, False], value=False),
            fov_list=SelectMultiple(options=self.fov_list),
            # min=1: a zero step is not a valid slice step.
            t_subsample_step=IntSlider(value=10, min=1, max=200, step=1),
        )
        display(import_hdf5)

    def preview_y_precentiles(
        self, y_percentile, smoothing_kernel_y_dim_0, y_percentile_threshold
    ):
        """Compute and plot smoothed y-percentile profiles for each selected file.

        Args:
            y_percentile (int): Percentile taken across x.
            smoothing_kernel_y_dim_0 (int): Median-filter kernel size.
            y_percentile_threshold (float): Threshold drawn on the plot and
                stored for row detection.

        Returns:
            list: One smoothed-percentile array per entry of ``self.file_list``.
        """
        self.final_params["Y Percentile"] = y_percentile
        self.final_params["Y Smoothing Kernel"] = smoothing_kernel_y_dim_0
        self.final_params["Y Percentile Threshold"] = y_percentile_threshold

        # NOTE(review): the former array-based path passed the kernel as
        # (smoothing_kernel_y_dim_0, 1); confirm which form the inherited
        # get_smoothed_y_percentiles expects.
        # NOTE(review): this list is per file while plot titles index
        # self.fov_list; confirm files map one-to-one onto selected FOVs.
        y_percentiles_smoothed_list = [
            self.get_smoothed_y_percentiles(
                file_idx, y_percentile, smoothing_kernel_y_dim_0
            )
            for file_idx in self.file_list
        ]

        self.plot_y_precentiles(
            y_percentiles_smoothed_list, self.fov_list, y_percentile_threshold
        )

        self.y_percentiles_smoothed_list = y_percentiles_smoothed_list

        return y_percentiles_smoothed_list

    def preview_y_precentiles_interactive(self):
        """Show widgets for tuning row-detection percentile, kernel and threshold."""
        row_detection = interactive(
            self.preview_y_precentiles,
            {"manual": True},
            y_percentile=IntSlider(value=99, min=0, max=100, step=1),
            smoothing_kernel_y_dim_0=IntSlider(value=29, min=1, max=200, step=2),
            y_percentile_threshold=FloatSlider(value=0.2, min=0.0, max=1.0, step=0.01),
        )
        display(row_detection)

    def plot_y_precentiles(
        self, y_percentiles_smoothed_list, fov_list, y_percentile_threshold
    ):
        """3-D plot of smoothed y-percentile profiles with the threshold overlaid.

        Args:
            y_percentiles_smoothed_list (list of array): One 2-D array per
                panel; axis 0 is drawn as y position and axis 1 as time.
            fov_list (sequence): Panel-title labels, parallel to the list.
            y_percentile_threshold (float): Drawn as a red line per timepoint.
        """
        fig = plt.figure()

        # Square-ish grid. add_subplot needs integers, and np.ceil returns a float.
        grid_dim = int(np.ceil(np.sqrt(len(y_percentiles_smoothed_list))))

        for j, y_percentiles_smoothed in enumerate(y_percentiles_smoothed_list):
            ax = fig.add_subplot(grid_dim, grid_dim, j + 1, projection="3d")

            # vert_arr[0] holds 1-based y positions, vert_arr[1] the intensities.
            vert_arr = np.array(
                [
                    np.add.accumulate(
                        np.ones(y_percentiles_smoothed.shape, dtype=int), axis=0
                    ),
                    y_percentiles_smoothed,
                ]
            )
            # One polygon per timepoint, subsampled every 10th y position.
            verts = []
            for t in range(vert_arr.shape[2]):
                w_vert = vert_arr[:, :, t]
                verts.append(
                    [
                        (w_vert[0, i], w_vert[1, i])
                        for i in range(0, w_vert.shape[1], 10)
                    ]
                )

            # 1-based time position of each polygon.
            zs = np.add.accumulate(np.ones(len(verts)))

            poly = PolyCollection(verts, facecolors=["b"])
            poly.set_alpha(0.5)
            ax.add_collection3d(poly, zs=zs, zdir="y")

            # Depict the threshold as a straight line at every timepoint.
            x_len, y_len = y_percentiles_smoothed.shape
            x_positions = np.arange(1, x_len + 1)
            threshold_line = np.full(x_len, y_percentile_threshold)
            for t in range(1, y_len + 1):
                ax.plot(x_positions, np.full(x_len, t), threshold_line, c="r")

            # Plot labels
            ax.set_title("FOV: " + str(fov_list[j]))
            ax.set_xlabel("y position")
            ax.set_xlim3d(0, vert_arr[0, -1, 0])
            ax.set_ylabel("time (s)")
            ax.set_ylim3d(0, len(verts))
            ax.set_zlabel("intensity")
            ax.set_zlim3d(0, np.max(vert_arr[1]))

        plt.show()

    def preview_y_crop(
        self,
        y_percentiles_smoothed_list,
        imported_array_list,
        y_min_edge_dist,
        padding_y,
        trench_len_y,
        expected_num_rows,
        alternate_orientation,
        orientation_detection,
        orientation_on_fail,
        images_per_row,
    ):
        """Detect trench rows, crop them in y and plot the crops.

        Requires ``preview_y_precentiles`` to have run first, because the
        stored "Y Percentile Threshold" is used here.

        Returns:
            list: Per-FOV y-cropped arrays; also stored as ``self.cropped_in_y_list``.
        """
        self.final_params["Minimum Trench Length"] = y_min_edge_dist
        self.final_params["Y Padding"] = padding_y
        self.final_params["Trench Length"] = trench_len_y
        self.final_params["Orientation Detection Method"] = orientation_detection
        self.final_params[
            "Expected Number of Rows (Manual Orientation Detection)"
        ] = expected_num_rows
        self.final_params["Alternate Orientation"] = alternate_orientation
        self.final_params[
            "Top Orientation when Row Drifts Out (Manual Orientation Detection)"
        ] = orientation_on_fail

        y_percentile_threshold = self.final_params["Y Percentile Threshold"]

        get_trench_edges_y_output = self.map_to_fovs(
            self.get_trench_edges_y, y_percentiles_smoothed_list, y_percentile_threshold
        )
        trench_edges_y_lists = [item[0] for item in get_trench_edges_y_output]
        start_above_lists = [item[1] for item in get_trench_edges_y_output]
        end_above_lists = [item[2] for item in get_trench_edges_y_output]

        get_manual_orientations_output = self.map_to_fovs(
            self.get_manual_orientations,
            trench_edges_y_lists,
            start_above_lists,
            end_above_lists,
            alternate_orientation,
            expected_num_rows,
            orientation_detection,
            orientation_on_fail,
            y_min_edge_dist,
        )

        orientations_list = [item[0] for item in get_manual_orientations_output]
        drop_first_row_list = [item[1] for item in get_manual_orientations_output]
        drop_last_row_list = [item[2] for item in get_manual_orientations_output]

        y_ends_lists = self.map_to_fovs(
            self.get_trench_ends,
            trench_edges_y_lists,
            start_above_lists,
            end_above_lists,
            orientations_list,
            drop_first_row_list,
            drop_last_row_list,
            y_min_edge_dist,
        )
        y_drift_list = self.map_to_fovs(self.get_y_drift, y_ends_lists)

        keep_in_frame_kernels_output = self.map_to_fovs(
            self.keep_in_frame_kernels,
            y_ends_lists,
            y_drift_list,
            imported_array_list,
            orientations_list,
            padding_y,
            trench_len_y,
        )
        valid_y_ends_list = [item[0] for item in keep_in_frame_kernels_output]
        valid_orientations_list = [item[1] for item in keep_in_frame_kernels_output]
        cropped_in_y_list = self.map_to_fovs(
            self.crop_y,
            imported_array_list,
            y_drift_list,
            valid_y_ends_list,
            valid_orientations_list,
            padding_y,
            trench_len_y,
        )

        self.plot_y_crop(
            cropped_in_y_list,
            imported_array_list,
            self.fov_list,
            valid_orientations_list,
            images_per_row,
        )

        self.cropped_in_y_list = cropped_in_y_list

        return cropped_in_y_list

    def preview_y_crop_interactive(self):
        """Show widgets for tuning row detection and y cropping."""
        # NOTE(review): self.imported_array_list is never assigned in this
        # class; presumably set elsewhere (e.g. the parent) — TODO confirm.
        y_cropping = interactive(
            self.preview_y_crop,
            {"manual": True},
            y_percentiles_smoothed_list=fixed(self.y_percentiles_smoothed_list),
            imported_array_list=fixed(self.imported_array_list),
            y_min_edge_dist=IntSlider(value=50, min=5, max=1000, step=5),
            padding_y=IntSlider(value=20, min=0, max=500, step=5),
            trench_len_y=IntSlider(value=270, min=0, max=1000, step=5),
            expected_num_rows=IntText(
                value=2, description="Number of Rows:", disabled=False
            ),
            alternate_orientation=Dropdown(
                options=[True, False],
                value=True,
                description="Alternate Orientation?:",
                disabled=False,
            ),
            orientation_detection=Dropdown(
                options=[0, 1, "phase"],
                value=0,
                description="Orientation:",
                disabled=False,
            ),
            orientation_on_fail=Dropdown(
                options=[None, 0, 1],
                value=0,
                description="Orientation when < expected rows:",
                disabled=False,
            ),
            images_per_row=IntSlider(value=3, min=1, max=10, step=1),
        )

        display(y_cropping)

    def plot_y_crop(
        self,
        cropped_in_y_list,
        imported_array_list,
        fov_list,
        valid_orientations_list,
        images_per_row,
    ):
        """Plot every timepoint of every y-cropped row, one lane per row.

        Each lane (one detected row of one FOV) starts on a fresh grid row and
        wraps after ``images_per_row`` images.

        Args:
            cropped_in_y_list (list): Per-FOV arrays indexed as
                ``[row, channel, y, x, t]``; channel 0 is plotted.
            imported_array_list (list): Only used for the timepoint count
                (``shape[3]`` of the first element).
            fov_list (sequence): FOV labels for titles.
            valid_orientations_list (list): Per-FOV lists whose length is the
                number of rows kept for that FOV.
            images_per_row (int): Grid columns.
        """
        time_list = range(1, imported_array_list[0].shape[3] + 1)
        time_per_img = len(time_list)
        ttl_lanes = sum(len(item) for item in valid_orientations_list)

        rows_per_lane = -(-time_per_img // images_per_row)  # ceiling division
        slots_per_lane = rows_per_lane * images_per_row
        nrows = rows_per_lane * ttl_lanes
        ncols = images_per_row

        # Start from an empty figure. plt.subplots would add a full-figure axes
        # that newer matplotlib leaves visible underneath the grid.
        fig = plt.figure(figsize=(20, 10))

        lane_offset = 0
        for i, cropped_in_y in enumerate(cropped_in_y_list):
            for j in range(len(valid_orientations_list[i])):
                for k, t in enumerate(time_list):
                    ax = fig.add_subplot(nrows, ncols, lane_offset + k + 1)
                    ax.axis("off")
                    ax.set_title(
                        "row=" + str(j) + ",fov=" + str(fov_list[i]) + ",t=" + str(t)
                    )
                    ax.imshow(cropped_in_y[j, 0, :, :, k], cmap="Greys_r")
                lane_offset += slots_per_lane

        fig.tight_layout()
        plt.show()

    def preview_x_percentiles(
        self,
        cropped_in_y_list,
        t,
        x_percentile,
        background_kernel_x,
        smoothing_kernel_x,
        otsu_scaling,
        min_threshold,
    ):
        """Compute, plot and threshold smoothed x-percentile profiles at timepoint ``t``.

        Returns:
            tuple: ``(smoothed_x_percentiles_list, all_midpoints_list, x_drift_list)``.
        """
        self.final_params["X Percentile"] = x_percentile
        self.final_params["X Background Kernel"] = background_kernel_x
        self.final_params["X Smoothing Kernel"] = smoothing_kernel_x
        self.final_params["Otsu Threshold Scaling"] = otsu_scaling
        self.final_params["Minimum X Threshold"] = min_threshold

        smoothed_x_percentiles_list = self.map_to_fovs(
            self.get_smoothed_x_percentiles,
            cropped_in_y_list,
            x_percentile,
            (background_kernel_x, 1),
            (smoothing_kernel_x, 1),
        )
        # Flat list of thresholds, in the same FOV/row order the plot uses.
        thresholds = [
            self.get_midpoints(
                smoothed_x_percentiles[:, t], otsu_scaling, min_threshold
            )[1]
            for smoothed_x_percentiles_row in smoothed_x_percentiles_list
            for smoothed_x_percentiles in smoothed_x_percentiles_row
        ]
        self.plot_x_percentiles(
            smoothed_x_percentiles_list, self.fov_list, t, thresholds
        )

        self.smoothed_x_percentiles_list = smoothed_x_percentiles_list
        all_midpoints_list, x_drift_list = self.preview_midpoints(
            smoothed_x_percentiles_list
        )

        return smoothed_x_percentiles_list, all_midpoints_list, x_drift_list

    def preview_midpoints(self, smoothed_x_percentiles_list):
        """Detect trench midpoints and x drift from the given percentile profiles.

        Uses the "Otsu Threshold Scaling" and "Minimum X Threshold" values
        stored in ``self.final_params``.

        Returns:
            tuple: ``(all_midpoints_list, x_drift_list)``; also stored on self.
        """
        otsu_scaling = self.final_params["Otsu Threshold Scaling"]
        min_threshold = self.final_params["Minimum X Threshold"]

        all_midpoints_list = self.map_to_fovs(
            self.get_all_midpoints,
            smoothed_x_percentiles_list,
            otsu_scaling,
            min_threshold,
        )
        self.plot_midpoints(all_midpoints_list, self.fov_list)
        x_drift_list = self.map_to_fovs(self.get_x_drift, all_midpoints_list)

        self.all_midpoints_list, self.x_drift_list = (all_midpoints_list, x_drift_list)

        return all_midpoints_list, x_drift_list

    def preview_x_percentiles_interactive(self):
        """Show widgets for tuning trench detection along x."""
        trench_detection = interactive(
            self.preview_x_percentiles,
            {"manual": True},
            cropped_in_y_list=fixed(self.cropped_in_y_list),
            t=IntSlider(
                value=0, min=0, max=self.cropped_in_y_list[0].shape[4] - 1, step=1
            ),
            x_percentile=IntSlider(value=85, min=50, max=100, step=1),
            background_kernel_x=IntSlider(value=21, min=1, max=601, step=20),
            smoothing_kernel_x=IntSlider(value=9, min=1, max=31, step=2),
            otsu_scaling=FloatSlider(value=0.25, min=0.0, max=2.0, step=0.01),
            min_threshold=IntSlider(value=0, min=0, max=65535, step=1),
        )

        display(trench_detection)

    def plot_x_percentiles(self, smoothed_x_percentiles_list, fov_list, t, thresholds):
        """Plot each row's x-percentile profile at timepoint ``t`` with its threshold.

        Args:
            smoothed_x_percentiles_list (list): Per-FOV iterables of 2-D
                ``(x, t)`` arrays, one per row.
            fov_list (sequence): FOV labels for titles.
            t (int): Column (timepoint) to plot.
            thresholds (list): Flat list of thresholds in FOV-then-row order.
        """
        num_fovs = len(smoothed_x_percentiles_list)
        if num_fovs == 0:
            return

        fig = plt.figure()
        # Size the grid from the data actually plotted so every panel fits.
        grid_cols = num_fovs
        grid_rows = (
            sum(len(lanes) for lanes in smoothed_x_percentiles_list) // num_fovs
        ) + 1

        idx = 0
        for i, smoothed_x_percentiles_lanes in enumerate(smoothed_x_percentiles_list):
            for j, smoothed_x_percentiles in enumerate(smoothed_x_percentiles_lanes):
                data = smoothed_x_percentiles[:, t]
                ax = fig.add_subplot(grid_rows, grid_cols, idx + 1)
                ax.plot(data)
                ax.plot(np.repeat(thresholds[idx], len(data)), c="r")
                ax.set_title("FOV: " + str(fov_list[i]) + " Lane: " + str(j))
                ax.set_xlabel("x position")
                ax.set_ylabel("intensity")
                idx += 1

        plt.show()

    def plot_midpoints(self, all_midpoints_list, fov_list):
        """Scatter detected midpoints (x) against time index for each FOV/row.

        Args:
            all_midpoints_list (list): Per-FOV lists of per-row sequences of
                1-D midpoint arrays, one array per timepoint.
            fov_list (sequence): FOV labels for titles.
        """
        # Empty figure; fig.gca() would add a full-figure axes that newer
        # matplotlib leaves visible beneath the subplots.
        fig = plt.figure()

        nrows = 2 * len(fov_list)
        ncols = 2

        idx = 0
        for i, top_bottom_list in enumerate(all_midpoints_list):
            for j, all_midpoints in enumerate(top_bottom_list):
                idx += 1
                ax = fig.add_subplot(nrows, ncols, idx)
                ax.set_title("row=" + str(j) + ",fov=" + str(fov_list[i]))
                # Rows of (midpoint x, time index k).
                data = np.concatenate(
                    [
                        np.array([item, np.ones(item.shape, dtype=int) * k]).T
                        for k, item in enumerate(all_midpoints)
                    ]
                )
                ax.scatter(data[:, 0], data[:, 1], alpha=0.7)
                ax.set_xlabel("x position")
                ax.set_ylabel("time")

        plt.tight_layout()
        plt.show()

    def preview_kymographs(
        self,
        cropped_in_y_list,
        all_midpoints_list,
        x_drift_list,
        trench_width_x,
        trench_present_thr,
    ):
        """Crop trenches in x and plot example kymographs plus corrected midpoints."""
        self.final_params["Trench Width"] = trench_width_x
        self.final_params["Trench Presence Threshold"] = trench_present_thr

        cropped_in_x_list = self.map_to_fovs(
            self.get_crop_in_x,
            cropped_in_y_list,
            all_midpoints_list,
            x_drift_list,
            trench_width_x,
            trench_present_thr,
        )
        corrected_midpoints_list = self.map_to_fovs(
            self.get_corrected_midpoints,
            all_midpoints_list,
            x_drift_list,
            trench_width_x,
            trench_present_thr,
        )

        self.plot_kymographs(cropped_in_x_list, self.fov_list)
        self.plot_midpoints(corrected_midpoints_list, self.fov_list)

    def preview_kymographs_interactive(self):
        """Show widgets for tuning trench width and presence threshold."""
        interact_manual(
            self.preview_kymographs,
            cropped_in_y_list=fixed(self.cropped_in_y_list),
            all_midpoints_list=fixed(self.all_midpoints_list),
            x_drift_list=fixed(self.x_drift_list),
            trench_width_x=IntSlider(value=30, min=2, max=1000, step=2),
            trench_present_thr=FloatSlider(value=0.0, min=0.0, max=1.0, step=0.05),
        )

    def plot_kymographs(self, cropped_in_x_list, fov_list, num_rows=2):
        """Plot one randomly chosen kymograph per row of each FOV.

        Args:
            cropped_in_x_list (list): Per-FOV lists of per-row arrays; element
                ``[0]`` of each row array is the segmentation channel, indexed
                first by trench.
            fov_list (sequence): FOV labels for titles.
            num_rows (int): Rows per FOV; sizes the subplot grid.
        """
        plt.figure()
        idx = 0
        ncol = num_rows
        nrow = len(fov_list) * num_rows

        for i, row_list in enumerate(cropped_in_x_list):
            for j, row_array in enumerate(row_list):
                seg_channel = row_array[0]
                idx += 1
                # Random trench, so repeated previews sample different trenches.
                rand_k = np.random.randint(0, seg_channel.shape[0])
                ax = plt.subplot(ncol, nrow, idx)
                self.plot_kymograph(ax, seg_channel[rand_k])
                ax.set_title(
                    "row="
                    + str(j)
                    + ",fov="
                    + str(fov_list[i])
                    + ",trench="
                    + str(rand_k)
                )

        plt.tight_layout()
        plt.show()

    def plot_kymograph(self, ax, kymograph):
        """Draw a kymograph on ``ax`` with its timepoints tiled left to right.

        Args:
            ax (matplotlib.axes.Axes): Target axes.
            kymograph (array): Array of shape (y_dim, x_dim, t_dim).
        """
        list_in_t = [kymograph[:, :, t] for t in range(kymograph.shape[2])]
        img_arr = np.concatenate(list_in_t, axis=1)
        ax.imshow(img_arr, cmap="Greys_r")

    def process_results(self):
        """Add channel/inversion settings to ``final_params`` and print them all."""
        self.final_params["All Channels"] = self.all_channels
        self.final_params["Invert"] = self.invert

        for key, value in self.final_params.items():
            print(key + " " + str(value))

    def write_param_file(self):
        """Pickle ``final_params`` to ``<headpath>/kymograph.par``.

        Only unpickle the resulting file from trusted locations; pickle
        loading can execute arbitrary code.
        """
        with open(self.headpath + "/kymograph.par", "wb") as outfile:
            pickle.dump(self.final_params, outfile)