# TrenchRipper Master Notebook

## Introduction

This notebook contains the entire `TrenchRipper` pipeline, divided into simple steps. This pipeline is ideal for Mother <br>Machine image data where cells possess fluorescent segmentation markers. Segmentation on phase or brightfield data <br>is being developed, but is still an experimental feature.

The steps in this pipeline are as follows:
1. Extracting your Mother Machine data (.nd2) into hdf5 format
2. Identifying and cropping individual trenches into kymographs
3. Segmenting cells with a fluorescent marker
4. Determining lineages and object properties

In each step, the user will dynamically specify parameters using a series of interactive diagnostics on their dataset. <br>Following this, a parameter file will be written to disk and then used to deploy a parallel computation on the <br>dataset, either locally or on a SLURM cluster.


This is intended as an end-to-end solution to analyzing Mother Machine data. As such, **it is not trivial to plug data <br>directly into intermediate steps**, as it will lack the correct formatting and associated metadata. A notable <br>exception to this is using another program to segment data. The library references binary segmentation masks using <br>only metadata derived from their associated kymographs. As such, it is possible to generate segmentations on these <br>kymographs elsewhere and place them into the segmentation data path to have `TrenchRipper` act on those <br>segmentations instead. More on this in the segmentation section...

#### Imports

Run this section to import all relevant packages and libraries used in this notebook. You must run this every time you open a new Python kernel.

In [None]:
import paulssonlab.deaton.trenchripper.trenchripper as tr

import warnings

warnings.filterwarnings(action="once")

import matplotlib

matplotlib.rcParams["figure.figsize"] = [20, 10]

In [None]:
sourcedir = "/n/files/SysBio/PAULSSON LAB/Personal Folders/Daniel/Nanopore_Data/2020-11-23_lDE13/no_sample/20201124_0006_MN35044_FAO84917_7fb03513"
targetdir = "/home/de64/scratch/de64/2020-11-23_lDE13_sequencing"
tr.trcluster.transferjob(sourcedir, targetdir)

# mVenus

#### Specify Paths

Begin by defining the directory in which all processing will be done, as well as the initial nd2 file we will be <br>processing. This line should be run every time you open a new Python kernel.

The format should be: `headpath = "/path/to/folder"` and `nd2file = "/path/to/file.nd2"`

For example:
```
headpath = "/n/scratch2/de64/2019-05-31_validation_data"
nd2file = "/n/scratch2/de64/2019-05-31_validation_data/Main_Experiment.nd2"
```

Ideally, these files should be placed in a storage location with relatively fast I/O.

In [None]:
headpath = "/home/de64/scratch/de64/2020-11-23_mVenus_KO_library/mVenus_headpath"
# hdf5inputpath = "/home/de64/scratch/de64/2020-11-07_lDE11/run"
nd2file = "/home/de64/scratch/de64/2020-11-23_mVenus_KO_library/induction_real.nd2"

## Extract to hdf5 files

In this section, we will be extracting our image data. Currently this notebook only supports the `.nd2` format; however, <br>there are `.tiff` extractors in the TrenchRipper source files that will be added to `Master.ipynb` soon.

In the abstract, this step will take a single `.nd2` file and split it into a set of `.hdf5` files stored in <br>`headpath/hdf5`. Splitting the file up in this way will facilitate quick processing in later steps. Each field of <br>view will be split into one or more `.hdf5` files, depending on the number of images per file requested (more on <br>this later).

To keep track of which output files correspond to which FOVs, as well as to keep track of experiment metadata, the <br>extractor also outputs a `metadata.hdf5` file in the `headpath` folder. The data from this step is accessible in <br>that `metadata.hdf5` file under the `global` key. If you would like to look at this metadata, you may use the <br>`tr.utils.pandas_hdf5_handler` to read from this file. Later steps will add additional metadata under different <br>keys into the `metadata.hdf5` file.

#### Start Dask Workers

First, we start a `dask_controller` instance which will handle all of our parallel processing. The default parameters <br>here work well on O2. The critical arguments here are:

**walltime** : For a cluster, the length of time you will request each node for.

**local** : `True` if you want to perform computation locally. `False` if you want to perform it on a SLURM cluster.

**n_workers** : Number of nodes to request if on the cluster, or number of processes if computing locally.

**memory** : For a cluster, the amount of memory you will request each node for.

**working_directory** : For a cluster, the directory in which data will be spilled to disk. Usually set as a folder in <br>the `headpath`.

In [None]:
dask_controller = tr.trcluster.dask_controller(
    walltime="04:00:00",
    local=False,
    n_workers=20,
    memory="2GB",
    working_directory=headpath + "/dask",
)
dask_controller.startdask()

After running the above line, you will have a running Dask client. Run the line below and click the link to supervise <br>the computation being administered by the scheduler. 

Don't be alarmed if the screen starts mostly blank; it may take time for your workers to spin up. If you get a 404 <br>error on a cluster, it is likely that your ports are not being forwarded properly. If this occurs, please register <br>the issue on GitHub.

In [None]:
dask_controller.daskclient

##### Perform Extraction

Now that we have our cluster scheduler spun up, it is time to convert files. This will be handled by the <br>`hdf5_extractor` object. This extractor will pull up each FOV and split it such that each derived `.hdf5` file <br>contains, at maximum, N timepoints of that FOV per file. The image data stored in these files takes the <br>form of `(N,Y,X)` arrays that are accessible using the desired channel name as a key. 

The arguments for this extractor are:

 - **nd2file** : The filepath to the `.nd2` file you intend to extract.
 
 - **headpath** : The folder in which processing is occurring. Should be the same for each step in the pipeline.

 - **tpts_per_file** : The maximum number of timepoints stored in each output `.hdf5` file. Typical values are between 25 <br>and 100.

 - **ignore_fovmetadata** : Used when `.nd2` data is corrupted and does not possess records for stage positions or <br>timepoints. Only set `True` if the extractor throws errors on metadata handling.

 - **nd2reader_override** : Overrides values in metadata recovered using the `nd2reader`. Currently set to <br>`{"z_levels":[],"z_coordinates":[]}` by default to correct a known issue where z coordinates are mistakenly <br>interpreted as a z stack. See the [nd2reader](https://rbnvrw.github.io/nd2reader/) documentation for more info.
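
To make the effect of `tpts_per_file` concrete, the following minimal sketch shows how the timepoints of one FOV could be divided into consecutive chunks. It illustrates the chunking arithmetic only: `split_timepoints` is a hypothetical helper, not a function in `TrenchRipper`, and the extractor's actual file layout may differ.

```python
def split_timepoints(n_timepoints, tpts_per_file):
    """Return (start, stop) index pairs, each spanning at most tpts_per_file timepoints."""
    if tpts_per_file < 1:
        raise ValueError("tpts_per_file must be a positive integer")
    return [
        (start, min(start + tpts_per_file, n_timepoints))
        for start in range(0, n_timepoints, tpts_per_file)
    ]


# Hypothetical FOV with 120 timepoints and tpts_per_file=50
chunks = split_timepoints(120, 50)
print(chunks)  # [(0, 50), (50, 100), (100, 120)]
```

Under this scheme, the last file of each FOV simply holds whatever timepoints remain, so it may be shorter than the others.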

In [None]:
# hdf5_extractor = marlin_extractor(hdf5inputpath, headpath)

In [None]:
hdf5_extractor = tr.ndextract.hdf5_fov_extractor(
    nd2file,
    headpath,
    tpts_per_file=50,
    ignore_fovmetadata=False,
    nd2reader_override={"z_levels": [], "z_coordinates": []},
)

In [None]:
# hdf5_extractor = tr.ndextract.tiff_extractor(
#     tiffpath,
#     headpath,
#     ["Phase","YFP"],tpts_per_file=50
# )

##### Extraction Parameters

Here, you may set the time interval you want to extract. Useful for cropping data to the period exhibiting the dynamics of interest.

Optionally take notes to add to the `metadata.hdf5` file. Notes may also be taken directly in this notebook.

In [None]:
hdf5_extractor.inter_set_params()

##### Begin Extraction 

Running the following line will start the extraction process. This may be monitored by examining the `Dask Dashboard` <br> under the link displayed earlier. Once the computation is complete, move to the next line.

This step may take a long time, though it is possible to speed it up using additional workers.

In [None]:
hdf5_extractor.extract(dask_controller)

##### Shutdown Dask

Once extraction is complete, it is likely that you will want to shut down your `dask_controller` if you are on a <br>
cluster. This is because the specifications of the current `dask_controller` will not be optimal for later steps. <br>
To do this, run the following line and wait for it to complete. If it hangs, interrupt your kernel and re-run it. <br>
If this also fails to shut down your workers, you will have to manually shut them down using `scancel` in a terminal.

In [None]:
dask_controller.shutdown()

## Kymographs

Now that you have extracted your data into a series of `.hdf5` files, we will perform identification and cropping <br>of the individual trenches/growth channels present in the images. This algorithm assumes that your growth trenches <br>are vertically aligned and that they alternate in their orientation from top to bottom. See the example image for the <br>correct geometry:

![example_image](./resources/example_image.jpg)

The output of this step will be a set of `.hdf5` files stored in `headpath/kymograph`. The image data stored in these <br>files takes the form of `(K,T,Y,X)` arrays where K is the trench index, T is time, and Y,X are the crop dimensions. <br>These arrays are accessible using keys of the form `"[Image Channel]"`. For example, looking up phase channel <br>data of trenches in the topmost row of an image will require the key `"Phase"`.

### Test Parameters



##### Initialize the interactive kymograph class

As a first step, initialize the `tr.interactive.kymograph_interactive` class that will help us choose the <br>parameters we will use to generate kymographs.

In [None]:
interactive_kymograph = tr.kymograph_interactive(headpath)

In [None]:
viewer = tr.hdf5_viewer(headpath)

##### Examine Images

Here you can manually inspect images before beginning parameter tuning.

In [None]:
viewer.view(width=1200)

You will now want to select a few test FOVs to try out parameters on, the channel you want to detect trenches on, and <br>the time interval on which you will perform your processing.

The arguments for this step are:

- **seg_channel (string)** : The channel name that you would like to segment on.

- **invert (list)** : Whether or not you want to invert the image before detecting trenches. By default, it is assumed that <br>the trenches have a high pixel intensity relative to the background. This should be the case for Phase Contrast and <br>Fluorescence Imaging, but may not be the case for Brightfield Imaging, in which case you will want to invert the image.

- **fov_list (list)** : List of integers corresponding to the FOVs that you wish to make test kymographs of.

- **t_subsample_step (int)** : Step size to be used for subsampling input files in time. We recommend choosing a step that leaves <br>between 5 and 10 timepoints for quick processing.

Hit the "Run Interact" button to lock in your parameters. The button will become transparent briefly and become solid again <br>when processing is complete. After that has occurred, move on to the next step.

In [None]:
interactive_kymograph.import_hdf5_interactive()

##### Tune "trench-row" detection hyperparameters

The kymograph code begins by detecting the positions of trench rows in the image as follows:

1. Reduce each 2D image to a 1D signal along the y-axis by computing the qth percentile of the data along the x-axis
2. Smooth this signal using a median kernel
3. Normalize the signal by linearly scaling its minimum and maximum to 0 and 1, respectively
4. Use a set threshold to determine the trench row positions

The arguments for this step are:

 - **y_percentile (int)** : Percentile to use for step 1.

 - **smoothing_kernel_y_dim_0 (int)** : Median kernel size to use for step 2.

 - **y_percentile_threshold (float)** : Threshold to use in step 4.

Running the following widget will display the smoothed 1-D signal for each of your timepoints. In addition, the threshold <br>value for each fov will be displayed as a red line.
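
The four steps above can be sketched with NumPy on a synthetic image. This is an illustration under stated assumptions, not the library's implementation: `median_smooth` and `detect_row_mask` are hypothetical helpers, the kernel size is assumed to be odd, and the image is assumed not to be perfectly flat (otherwise the normalization divides by zero).

```python
import numpy as np


def median_smooth(signal, kernel_size):
    """Running median with edge padding; kernel_size is assumed to be odd."""
    pad = kernel_size // 2
    padded = np.pad(signal, pad, mode="edge")
    windows = np.lib.stride_tricks.sliding_window_view(padded, kernel_size)
    return np.median(windows, axis=1)


def detect_row_mask(img, y_percentile, smoothing_kernel, threshold):
    """Return the normalized 1D y-signal and a boolean mask of trench-row positions."""
    signal = np.percentile(img, y_percentile, axis=1)  # step 1: collapse the x-axis
    signal = median_smooth(signal, smoothing_kernel)  # step 2: median smoothing
    signal = (signal - signal.min()) / (signal.max() - signal.min())  # step 3: scale to [0, 1]
    return signal, signal > threshold  # step 4: threshold


# Synthetic 100x60 image with two bright horizontal bands standing in for trench rows
img = np.full((100, 60), 10.0)
img[20:40, :] = 100.0
img[60:80, :] = 100.0

signal, mask = detect_row_mask(img, y_percentile=99, smoothing_kernel=9, threshold=0.5)
edges = np.flatnonzero(np.diff(mask.astype(int))) + 1
print(edges)  # [20 40 60 80]
```

A median kernel is a natural choice here because it removes narrow spikes while keeping the sharp edges of the rows in place.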

In [None]:
interactive_kymograph.preview_y_precentiles_interactive()

##### Tune "trench-row" cropping hyperparameters

Next, we will use the detected rows to perform cropping of the input image in the y-dimension:

1. Determine edges of trench rows based on threshold mask.
2. Filter out rows that are too small.
3. Use the remaining rows to compute the drift in y in each image.
4. Apply the drift to the initially detected rows to get rows in all timepoints.
5. Perform cropping using the "end" of the row as reference (the end referring to the part of the trench farthest from <br>the feeding channel).
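
Steps 1 and 2 can be illustrated with a short NumPy sketch that turns a boolean threshold mask into row edges and discards rows that are too short. `find_rows` is a hypothetical helper written for this example; it reuses the `y_min_edge_dist` name from the widget but is not the library's code.

```python
import numpy as np


def find_rows(mask, y_min_edge_dist):
    """Return (start, stop) pairs for runs of True in mask that are long enough."""
    # Pad with False so that rows touching the image border still produce two edges
    padded = np.concatenate([[False], mask, [False]]).astype(int)
    changes = np.diff(padded)
    starts = np.flatnonzero(changes == 1)  # False -> True transitions
    stops = np.flatnonzero(changes == -1)  # True -> False transitions (exclusive index)
    return [(int(a), int(b)) for a, b in zip(starts, stops) if b - a >= y_min_edge_dist]


mask = np.zeros(100, dtype=bool)
mask[20:40] = True  # a genuine trench row
mask[50:53] = True  # a small spurious object
mask[60:80] = True  # a second trench row

print(find_rows(mask, y_min_edge_dist=10))  # [(20, 40), (60, 80)]
```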

Step 5 performs a simple algorithm to determine the orientation of each trench:

```
row_orientations = [] # A list of row orientations, starting from the topmost row
if the number of detected rows == 'Number of Rows': 
    row_orientations.append('Orientation')
elif the number of detected rows < 'Number of Rows':
    row_orientations.append('Orientation when < expected rows')
for row in rows:
    if row_orientations[-1] == downward:
        row_orientations.append(upward)
    elif row_orientations[-1] == upward:
        row_orientations.append(downward)
```

Additionally, if the device trenches face a single direction, alternation of row orientation may be turned off by setting the<br> `Alternate Orientation?` argument to False. The `Use Median Drift?` argument, when set to True, will use the<br> median drift in y across all FOVs for drift correction, instead of doing drift correction independently for each FOV. <br>This can be useful if a large fraction of FOVs are failing drift correction. Note that `Use Median Drift?` <br>sets this behavior for both y and x drift correction.
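
A runnable version of the orientation logic might look like the sketch below. It is an interpretation of the pseudocode rather than the library's code: `get_row_orientations` and its argument names are hypothetical, and it assumes that any mismatch between detected and expected rows falls back to the `Orientation when < expected rows` value.

```python
def get_row_orientations(n_detected, n_expected, orientation, orientation_on_fail, alternate=True):
    """Return one orientation per detected row, topmost first.

    Orientation codes follow the widget: 0 = opening faces down, 1 = opening faces up.
    """
    # Choose the orientation of the topmost row
    top = orientation if n_detected == n_expected else orientation_on_fail
    orientations = []
    for row_idx in range(n_detected):
        if alternate and row_idx % 2 == 1:
            orientations.append(1 - top)  # odd rows face the opposite way
        else:
            orientations.append(top)
    return orientations


print(get_row_orientations(2, 2, orientation=0, orientation_on_fail=1))  # [0, 1]
print(get_row_orientations(1, 2, orientation=0, orientation_on_fail=1))  # [1]
print(get_row_orientations(2, 2, orientation=0, orientation_on_fail=1, alternate=False))  # [0, 0]
```

The second call mimics an FOV where the top row has drifted out of view, so the single remaining row takes the fallback orientation.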

The arguments for this step are:

 - **y_min_edge_dist (int)** : Minimum row length necessary for detection (filters out small detected objects).

 - **padding_y (int)** : Padding to add to the end of trench row when cropping in the y-dimension.

 - **trench_len_y (int)** : Length from the end of each trench row to the feeding channel side of the crop.

 - **Number of Rows (int)** : The number of rows to expect in your image. For instance, two in the example image.
 
 - **Alternate Orientation? (bool)** : Whether or not to alternate the orientation of consecutive rows.

 - **Orientation (int)** : The orientation of the top-most row where 0 corresponds to a trench with a downward-oriented trench <br>opening and 1 corresponds to a trench with an upward-oriented trench opening.

 - **Orientation when < expected rows (int)** : The orientation of the top-most row when the number of detected rows is less than <br>expected. Useful if your trenches drift out of your image in some FOVs.
 
 - **Use Median Drift? (bool)** : Whether to use the median detected drift across all FOVs, instead of the drift detected in each FOV individually.

 - **images_per_row (int)** : How many images to output per row for this widget.

Running the following widget will display y-cropped images for each fov and timepoint.

In [None]:
interactive_kymograph.preview_y_crop_interactive()

##### Tune trench detection hyperparameters

Next, we will detect the positions of trenches in the y-cropped images as follows:

1. Reduce each 2D image to a 1D signal along the x-axis by computing the qth percentile of the data along the y-axis.
2. Determine the signal background by smoothing this signal using a large median kernel.
3. Subtract the background signal.
4. Smooth the resultant signal using a median kernel.
5. Use an [Otsu threshold](https://imagej.net/Auto_Threshold#Otsu) to determine the trench midpoint positions.

After this, x-dimension drift correction of our detected midpoints will be performed as follows:

6. Begin at t=1
7. For each $m \in \{midpoints(t)\}$, assign to $m$ the midpoint $n \in \{midpoints(t-1)\}$ that is closest to $m$.<br>
Midpoints at time $t-1$ that are not the closest midpoint to any $m$ will not be mapped.
8. Compute the translation of each mapped midpoint from time $t-1$ to time $t$.
9. Take the average of this value as the x-dimension drift from time t-1 to t.
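
Steps 7 to 9 can be sketched for a single pair of timepoints as follows. This is a minimal NumPy illustration of nearest-midpoint matching, with `x_drift` as a hypothetical helper; the library's own drift correction may handle ties and outliers differently.

```python
import numpy as np


def x_drift(prev_midpoints, curr_midpoints):
    """Estimate the x drift from t-1 to t as the mean shift between matched midpoints."""
    prev = np.asarray(prev_midpoints, dtype=float)
    curr = np.asarray(curr_midpoints, dtype=float)
    # Pairwise distances: rows index midpoints at t, columns index midpoints at t-1
    dist = np.abs(curr[:, None] - prev[None, :])
    nearest_prev = dist.argmin(axis=1)  # step 7: closest t-1 midpoint for each midpoint at t
    translations = curr - prev[nearest_prev]  # step 8: per-midpoint translation
    return translations.mean()  # step 9: average translation


prev = [100.0, 150.0, 200.0, 250.0]
curr = [103.0, 153.0, 203.0]  # one trench was not detected at time t
print(x_drift(prev, curr))  # 3.0
```

Note that the matching is only reliable when the drift between consecutive timepoints is small compared with the spacing between trenches, which is why sparse or discontinuous midpoints degrade the result.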

The arguments for this step are:

 - **t (int)** : Timepoint to examine the percentiles and threshold in.

 - **x_percentile (int)** : Percentile to use for step 1.

 - **background_kernel_x (int)** : Median kernel size to use for step 2.

 - **smoothing_kernel_x (int)** : Median kernel size to use for step 4.

 - **otsu_scaling (float)** : Scaling factor to apply to the threshold determined by Otsu's method.

Running the following widget will display the smoothed 1-D signal for each of your timepoints, with the threshold <br>value for each FOV displayed as a red line. It will also display the detected midpoints for each of your timepoints. <br>If the midpoints are too sparse or discontinuous, your drift correction will not be accurate.

In [None]:
interactive_kymograph.preview_x_percentiles_interactive()

##### Tune trench cropping hyperparameters

Trench cropping simply uses the drift-corrected midpoints as a reference and crops out some fixed length around them <br>
to produce an output kymograph. **Note that the current implementation does not allow trench crops to overlap**. If your<br>
trench crops do overlap, the error will not be caught here, but will cause issues later in the pipeline. As such, try <br>
to crop your trenches as closely as possible. This issue will be fixed in a later update.
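
Because overlapping crops are not caught at this stage, a quick manual check can be useful. The sketch below assumes each crop is centered on its midpoint and spans `trench_width_x` pixels in total, so two crops overlap whenever their midpoints are closer than `trench_width_x`. `crops_overlap` is a hypothetical helper, not part of `TrenchRipper`.

```python
def crops_overlap(midpoints, trench_width_x):
    """Return True if any two crops of width trench_width_x centered on the midpoints overlap."""
    ordered = sorted(midpoints)
    # Only neighboring midpoints need to be compared once the list is sorted
    return any(b - a < trench_width_x for a, b in zip(ordered, ordered[1:]))


midpoints = [100, 150, 200]  # hypothetical trench midpoints, 50 px apart
print(crops_overlap(midpoints, trench_width_x=40))  # False
print(crops_overlap(midpoints, trench_width_x=60))  # True
```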

The arguments for this step are:

 - **trench_width_x (int)** : Trench width to use for cropping.

 - **trench_present_thr (float)** : Trenches that appear in less than this percent of FOVs will be eliminated from the dataset.<br>
If not removed, missing positions will be inferred from the image drift.

 - **Use Median Drift? (bool)** : Whether to use the median detected drift across all FOVs, instead of the drift detected in each FOV individually.


Running the following widget will display a random kymograph for each row in each FOV and will also produce midpoint plots <br>showing retained midpoints.

In [None]:
interactive_kymograph.preview_kymographs_interactive()

##### Export and save hyperparameters

Run the following line to register and display the parameters you have selected for kymograph creation.

In [None]:
interactive_kymograph.process_results()

If you are satisfied with the above parameters, run the following line to write these parameters to disk at `headpath/kymograph.par`.<br>
This file will be used to perform kymograph creation in the next section.

In [None]:
interactive_kymograph.write_param_file()

### Generate Kymograph

##### Start Dask Workers

Again, we start a `dask_controller` instance which will handle all of our parallel processing. The default parameters <br>here work well on O2 for kymograph creation. The critical arguments here are:

**walltime** : For a cluster, the length of time you will request each node for.

**local** : `True` if you want to perform computation locally. `False` if you want to perform it on a SLURM cluster.

**n_workers** : Number of nodes to request if on the cluster, or number of processes if computing locally.

**memory** : For a cluster, the amount of memory you will request each node for.

**working_directory** : For a cluster, the directory in which data will be spilled to disk. Usually set as a folder in <br>the `headpath`.

In [None]:
dask_controller = tr.trcluster.dask_controller(
    walltime="04:00:00",
    local=False,
    n_workers=20,
    memory="2GB",
    working_directory=headpath + "/dask",
)
dask_controller.startdask()

After running the above line, you will have a running Dask client. Run the line below and click the link to supervise <br>the computation being administered by the scheduler. 

Don't be alarmed if the screen starts mostly blank; it may take time for your workers to spin up. If you get a 404 <br>error on a cluster, it is likely that your ports are not being forwarded properly. If this occurs, please register <br>the issue on GitHub.

In [None]:
dask_controller.daskclient

##### Perform Kymograph Cropping

Now that we have our cluster scheduler spun up, we will extract kymographs using the parameters stored in `headpath/kymograph.par`. <br>
This will be handled by the `kymograph_cluster` object. This will detect trenches in all of the files present in `headpath/hdf5` that <br>
you created in the first step. It will then crop these trenches and place the crops in a series of `.hdf5` files in `headpath/kymograph`. <br>
These files will store image data in the form of `(K,T,Y,X)` arrays where K is the trench index, T is time and Y,X are the image dimensions <br>
of the crop.

The arguments for this step are:

 - **headpath** : The folder in which processing is occurring. Should be the same for each step in the pipeline.

 - **trenches_per_file** : The maximum number of trenches stored in each output `.hdf5` file. Typical values are between 25 <br>and 100.

 - **paramfile** : Set to `True` if you want to use parameters from `headpath/kymograph.par`. Otherwise, you will have to specify <br>
 parameters as direct arguments to `kymograph_cluster`.

In [None]:
kymoclust = tr.kymograph.kymograph_cluster(
    headpath=headpath, trenches_per_file=200, paramfile=True
)

##### Begin Kymograph Cropping 

Running the following line will start the cropping process. This may be monitored by examining the `Dask Dashboard` <br>
under the link displayed earlier. Once the computation is complete, move to the next line.

**Do not move on until all tasks are displayed as 'in memory' in Dask.**

In [None]:
kymoclust.generate_kymographs(dask_controller)

In [None]:
ff = tr.focus_filter(headpath)

In [None]:
ff.choose_filter_channel_inter()

In [None]:
ff.plot_histograms()

In [None]:
ff.plot_focus_threshold_inter()

In [None]:
ff.write_param_file()

##### Post-process Images

After the above step, kymographs will have been created for each `.hdf5` input file. They will now need to be reorganized <br>
into a new set of files such that each file contains, at most, `trenches_per_file` trenches.

**Do not move on until all tasks are displayed as 'in memory' in Dask.**

In [None]:
kymoclust.post_process(dask_controller)

##### Check kymograph statistics

Run the next line to display some statistics from kymograph creation. The outputs are:

 - **fovs processed** : The number of FOVs successfully processed out of the total number of FOVs
 - **rows processed** : The number of rows of trenches processed out of the total number of rows
 - **trenches processed** : The number of trenches successfully processed
 - **row/fov** : The average number of rows successfully processed per FOV
 - **trenches/fov** : The average number of trenches successfully processed per FOV
 - **failed fovs** : A list of failed FOVs. Spot check these FOVs in the viewer to determine potential problems

In [None]:
kymoclust.kymo_report()

##### Shutdown Dask

Once cropping is complete, it is likely that you will want to shut down your `dask_controller` if you are on a <br>
cluster. This is because the specifications of the current `dask_controller` will not be optimal for later steps. <br>
To do this, run the following line and wait for it to complete. If it hangs, interrupt your kernel and re-run it. <br>
If this also fails to shut down your workers, you will have to manually shut them down using `scancel` in a terminal.

In [None]:
dask_controller.daskclient.restart()

In [None]:
dask_controller.shutdown()

#### Output

At this point you may want to use your output. The output of this step is a set of `.hdf5` files stored in <br>`headpath/kymograph`. The image data stored in these files takes the form of `(K,T,Y,X)` arrays <br>where K is the trench index, T is time, and Y,X are the crop dimensions.

These arrays are accessible using keys of the form `"[Trench Row Number]/[Image Channel]"`. <br>For example, looking up phase channel data of trenches in the topmost row of an image will require <br>the key `"0/Phase"`. The metadata associated with these files is a large pandas dataframe relating <br>crops to original FOVs, accessible using the "kymograph" key on `headpath/metadata.hdf5`.

To assist in accessing this file, you may use the `trenchripper.pandas_hdf5_handler` object to <br>interface with this file as follows:

In [None]:
import dask.dataframe as dd
import dask.delayed as delayed
from distributed.client import futures_of
import numpy as np
import pandas as pd
import h5py
import scipy.signal
import skimage as sk
from time import sleep
from matplotlib import pyplot as plt

In [None]:
def get_image_measurements(
    kymographpath, channels, file_idx, output_name, img_fn, *args, **kwargs
):
    df = dd.read_parquet(kymographpath + "/metadata")

    working_dfs = []

    proc_file_path = kymographpath + "/kymograph_" + str(file_idx) + ".hdf5"
    with h5py.File(proc_file_path, "r") as infile:
        working_filedf = df[df["File Index"] == file_idx].compute()
        trench_idx_list = working_filedf["File Trench Index"].unique().tolist()
        for trench_idx in trench_idx_list:
            # Copy so that adding measurement columns does not write to a view of working_filedf
            trench_df = working_filedf[
                working_filedf["File Trench Index"] == trench_idx
            ].copy()

            for channel in channels:
                kymo_arr = infile[channel][trench_idx]
                fn_out = [
                    img_fn(kymo_arr[i], *args, **kwargs)
                    for i in range(kymo_arr.shape[0])
                ]
                trench_df[channel + " " + output_name] = fn_out
            working_dfs.append(trench_df)

    out_df = pd.concat(working_dfs)
    return out_df


def get_all_image_measurements(
    headpath, output_path, channels, output_name, img_fn, *args, **kwargs
):
    kymographpath = headpath + "/kymograph"
    df = dd.read_parquet(kymographpath + "/metadata")

    file_list = df["File Index"].unique().compute().tolist()

    delayed_list = []

    for file_idx in file_list:
        df_delayed = delayed(get_image_measurements)(
            kymographpath, channels, file_idx, output_name, img_fn, *args, **kwargs
        )
        delayed_list.append(df_delayed.persist())

    ## wait for all tasks, then keep only the dataframes that finished without failing ##
    all_delayed_futures = []
    for item in delayed_list:
        all_delayed_futures += futures_of(item)
    while any(future.status == "pending" for future in all_delayed_futures):
        sleep(0.1)

    good_delayed = []
    for item in delayed_list:
        if all([future.status == "finished" for future in futures_of(item)]):
            good_delayed.append(item)

    ## compiling output dataframe ##
    df_out = dd.from_delayed(good_delayed).persist()
    df_out["FOV Parquet Index"] = df_out.index
    df_out = df_out.set_index("FOV Parquet Index", drop=True, sorted=False)
    df_out = df_out.repartition(partition_size="25MB").persist()

    dd.to_parquet(
        df_out,
        output_path,
        engine="fastparquet",
        compression="gzip",
        write_metadata_file=True,
    )

In [None]:
headpath = "/home/de64/scratch/de64/2020-11-23_mVenus_KO_library/mVenus_headpath"

In [None]:
kymograph_metadata = dd.read_parquet(headpath + "/kymograph/metadata")

# kymograph_metadata = pd.read_parquet(headpath + "/kymograph/metadata")
# kymograph_metadata = kymograph_metadata[kymograph_metadata["fov"]>125]

In [None]:
get_all_image_measurements(
    headpath,
    headpath + "/percentiles",
    ["mCherry", "YFP"],
    "90th Percentile",
    np.percentile,
    90,
)

In [None]:
kymograph_metadata = pd.read_parquet(headpath + "/percentiles")
kymograph_metadata = kymograph_metadata.set_index(["trenchid", "timepoints"], drop=True)
kymograph_metadata = kymograph_metadata.loc[(slice(None), range(10, 20)), :]

In [None]:
trenchid_groupby = kymograph_metadata.groupby(["trenchid"])
yfp_ratio_mean = trenchid_groupby.apply(
    lambda x: np.mean(x["YFP 90th Percentile"] / x["mCherry 90th Percentile"])
)

In [None]:
plt.hist(yfp_ratio_mean, bins=100)
plt.show()

## Fluorescence Segmentation

Now that you have cropped your data into kymographs, we will perform segmentation/cell detection <br>
on your kymographs. Currently, this pipeline only supports segmentation of fluorescence images; however, <br>
segmentation of transmitted light imaging techniques is in development.

The output of this step will be a set of `segmentation_[File #].hdf5` files stored in `headpath/fluorsegmentation`.<br>
The image data stored in these files takes the exact same form as the kymograph data, `(K,T,Y,X)` arrays <br>
where K is the trench index, T is time, and Y,X are the crop dimensions. These arrays are accessible using <br>
keys of the form `"[Trench Row Number]"`.

Since no metadata is generated by this step, it is possible to use another segmentation algorithm on the kymograph <br>
data. The output of segmentation must be split into `segmentation_[File #].hdf5` files, where `[File #]` agrees with the<br>
corresponding `kymograph_[File #].hdf5` file. Additionally, the `(K,T,Y,X)` arrays must be of the same shape as the <br>
kymograph arrays and accessible at the corresponding `"[Trench Row Number]"` key. These files must be placed into <br>
their own folder at `headpath/foldername`. This folder may then be used in later steps.

### Test Parameters

##### Initialize the interactive segmentation class

As a first step, initialize the `tr.fluo_segmentation_interactive` class that will be handling all steps of generating a segmentation. 

In [None]:
interactive_segmentation = tr.fluo_segmentation_interactive(headpath)

##### Choose channel to segment on

In [None]:
interactive_segmentation.choose_seg_channel_inter()

#### Import data

Fill in 

You will need to tune the following `args` and `kwargs` (in order):

**fov_idx (int)** :

**n_trenches (int)** :

**t_range (tuple)** :

**t_subsample_step (int)** :

In [None]:
interactive_segmentation.import_array_inter()

##### Process data

In [None]:
interactive_segmentation.plot_processed_inter()

#### Determine Cell Mask Envelope

Fill in.

You will need to tune the following `args` and `kwargs` (in order):

**cell_mask_method (str)** : Thresholding method, can be a local or global Otsu threshold.

**cell_otsu_scaling (float)** : Scaling factor applied to determined threshold.

**local_otsu_r (int)** : Radius of thresholding kernel used in the local otsu thresholding.
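
To show how a scaling factor such as `cell_otsu_scaling` interacts with a global Otsu threshold, here is a small NumPy sketch on synthetic intensities. It covers the global case only and is not the library's implementation: `otsu_threshold` is a hypothetical helper that follows the standard between-class-variance formulation, for which scikit-image provides `skimage.filters.threshold_otsu`.

```python
import numpy as np


def otsu_threshold(values, nbins=256):
    """Global Otsu threshold: the bin center that maximizes between-class variance."""
    counts, edges = np.histogram(values, bins=nbins)
    centers = (edges[:-1] + edges[1:]) / 2
    w0 = np.cumsum(counts)  # number of pixels in bins up to and including each bin
    w1 = np.cumsum(counts[::-1])[::-1]  # number of pixels in each bin and above
    m0 = np.cumsum(counts * centers) / np.maximum(w0, 1)  # mean of the lower class
    m1 = (np.cumsum((counts * centers)[::-1]) / np.maximum(w1[::-1], 1))[::-1]  # mean of the upper class
    between = w0[:-1] * w1[1:] * (m0[:-1] - m1[1:]) ** 2
    return centers[np.argmax(between)]


# Synthetic pixel intensities: dim background plus a smaller, bright "cell" population
rng = np.random.default_rng(0)
background = rng.normal(100, 5, size=5000)
cells = rng.normal(1000, 50, size=1000)
pixels = np.concatenate([background, cells])

threshold = otsu_threshold(pixels)
cell_otsu_scaling = 1.5  # values above 1 make the mask stricter, values below 1 more permissive
cell_mask = pixels > cell_otsu_scaling * threshold
print(threshold, cell_mask.sum())
```

In practice the scaling factor is a way to nudge the automatically chosen threshold when the mask envelope looks too generous or too tight in the widget.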

In [None]:
interactive_segmentation.plot_cell_mask_inter()

In [None]:
interactive_segmentation.plot_eig_mask_inter()

In [None]:
interactive_segmentation.plot_dist_mask_inter()

In [None]:
interactive_segmentation.plot_marker_mask_inter()

In [None]:
interactive_segmentation.process_results()

In [None]:
interactive_segmentation.write_param_file()

### Generate Segmentation

#### Start Dask Workers

In [None]:
dask_controller = tr.trcluster.dask_controller(
    walltime="01:00:00",
    local=False,
    n_workers=50,
    memory="1GB",
    working_directory=headpath + "/dask",
)
dask_controller.startdask()

In [None]:
dask_controller.displaydashboard()

In [None]:
segment = tr.segment.fluo_segmentation_cluster(headpath, paramfile=True)

In [None]:
segment.dask_segment(dask_controller)

#### Stop Dask Workers

In [None]:
dask_controller.shutdown()

## Region Properties (No Lineage)

Note that this step does not require a Dask client.

In [None]:
analyzer = tr.analysis.regionprops_extractor(
    headpath, "fluorsegmentation", intensity_channel_list=["mCherry", "YFP"]
)

In [None]:
analyzer.export_all_data()

In [None]:
import dask.dataframe as dd
import dask.delayed as delayed
from distributed.client import futures_of
import numpy as np
import pandas as pd
import h5py
import seaborn as sns
import scipy.signal
import skimage as sk
from time import sleep
from matplotlib import pyplot as plt

In [None]:
meta_handle = tr.pandas_hdf5_handler(headpath + "/metadata.hdf5")
meta_handle = meta_handle.read_df("global", read_metadata=True).metadata

In [None]:
region_props = pd.read_pickle(headpath + "/analysis.pkl").loc[
    (slice(None), slice(None), list(range(10, 20)), slice(None))
]
region_props = region_props.reset_index()
region_props = region_props.set_index(
    ["trenchid", "timepoints", "Intensity Channel", "Objectid"], drop=True
)
region_props = region_props.sort_index()

In [None]:
region_props[:5]

In [None]:
mchy_df = region_props.loc[(slice(None), slice(None), ["mCherry"]), :].droplevel(
    "Intensity Channel"
)
yfp_df = region_props.loc[(slice(None), slice(None), ["YFP"]), :].droplevel(
    "Intensity Channel"
)

In [None]:
ratio_df = yfp_df["mean_intensity"] / mchy_df["mean_intensity"]

scaled_yfp = (yfp_df["mean_intensity"] - yfp_df["mean_intensity"].min()) / (
    yfp_df["mean_intensity"].max() - yfp_df["mean_intensity"].min()
)
scaled_mchy = (mchy_df["mean_intensity"] - mchy_df["mean_intensity"].min()) / (
    mchy_df["mean_intensity"].max() - mchy_df["mean_intensity"].min()
)
scaled_ratio_df = scaled_yfp / scaled_mchy

In [None]:
mean_df = ratio_df.groupby("trenchid").mean()

In [None]:
plt.hist(mchy_df["mean_intensity"], bins=200, range=(0, 3000))
plt.show()

In [None]:
plt.hist(yfp_df["mean_intensity"], bins=200, range=(0, 15000))
plt.show()

In [None]:
plt.hist(ratio_df, bins=200, range=(0, 8))
plt.show()

In [None]:
plt.hist(scaled_ratio_df, bins=200, range=(0, 5))
plt.show()

In [None]:
plt.hist(mean_df, bins=200)
plt.show()

# Barcodes

In [None]:
# import h5py
# import pandas as pd
# import os

# hdf5inputpath = "/home/de64/scratch/de64/2020-11-23_mVenus_KO_library/run"

# metadata_files = []
# for root, _, files in os.walk(hdf5inputpath):
#             metadata_files.extend(
#                 [
#                     os.path.join(root, f)
#                     for f in files
#                     if "metadata" in os.path.splitext(f)[0]
#                 ]
#             )

# for i in range(1,9):
#     indf = pd.read_hdf(metadata_files[0],key='data/' + str(i))
#     indf.to_hdf(metadata_files[0][:-5] + "_t=" + str(i) + ".hdf5", "data" ,"w")

In [None]:
import h5py
import os
import shutil
import copy
import h5py_cache
import tifffile
import pickle as pkl
import numpy as np
import pandas as pd
import ipywidgets as ipyw

from nd2reader import ND2Reader
from tifffile import imsave, imread
from paulssonlab.deaton.trenchripper.trenchripper.utils import (
    pandas_hdf5_handler,
    writedir,
)
from parse import compile


class marlin_extractor:
    def __init__(
        self,
        hdf5inputpath,
        headpath,
        tpts_per_file=100,
        parsestr="fov={fov:d}_config={channel}_t={timepoints:d}.hdf5",
        metaparsestr="metadata_t={timepoint:d}.hdf5",
        zero_base_keys=["timepoints"],
    ):  # note this chunk size has a large role in downstream steps...make sure is less than 1 MB
        """Utility to import hdf5 format files from MARLIN Runs.

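        Attributes:
            hdf5inputpath (str): directory containing the hdf5 files from the MARLIN run
            headpath (str): base directory for data analysis
            metapath (str): metadata path
            hdf5path (str): where to store hdf5 data
            tpts_per_file (int): number of timepoints to put in each hdf5 file
            parsestr (str): format of filenames from which to extract metadata (using parse library)
            metaparsestr (str): format of metadata filenames from which to extract the timepoint
            zero_base_keys (list): parsed keys shifted to zero-based indexing when no 0 entry is present
        """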
        self.hdf5inputpath = hdf5inputpath
        self.headpath = headpath
        self.metapath = self.headpath + "/metadata.hdf5"
        self.hdf5path = self.headpath + "/hdf5"
        self.tpts_per_file = tpts_per_file
        self.parsestr = parsestr
        self.metaparsestr = metaparsestr
        self.zero_base_keys = zero_base_keys

        self.organism = ""
        self.microscope = ""
        self.notes = ""

    def get_metadata(
        self,
        hdf5inputpath,
        parsestr="fov={fov:d}_config={channel}_t={timepoints:d}.hdf5",
        metaparsestr="metadata_t={timepoint:d}.hdf5",
        zero_base_keys=["timepoints"],
    ):
        parser = compile(parsestr)
        parse_keys = [
            item.split("}")[0].split(":")[0] for item in parsestr.split("{")[1:]
        ] + ["image_paths"]

        exp_metadata = {}
        fov_metadata = {key: [] for key in parse_keys}

        hdf5_files = []
        metadata_files = []
        for root, _, files in os.walk(hdf5inputpath):
            hdf5_files.extend(
                [
                    os.path.join(root, f)
                    for f in files
                    if "config" in os.path.splitext(f)[0]
                ]
            )
            metadata_files.extend(
                [
                    os.path.join(root, f)
                    for f in files
                    if "metadata" in os.path.splitext(f)[0]
                ]
            )

        with h5py.File(hdf5_files[0], "r") as infile:
            hdf5_shape = infile["data"].shape
        exp_metadata["height"] = hdf5_shape[0]
        exp_metadata["width"] = hdf5_shape[1]
        #     exp_metadata['pixel_microns'] = tags['65326']

        for f in hdf5_files:
            match = parser.search(f)
            # ignore any files that don't match the parse pattern
            if match is not None:
                # Add to dictionary
                fov_frame_dict = match.named
                for key, value in fov_frame_dict.items():
                    fov_metadata[key].append(value)
                fov_metadata["image_paths"].append(f)

        for zero_base_key in zero_base_keys:
            if 0 not in fov_metadata[zero_base_key]:
                fov_metadata[zero_base_key] = [
                    item - 1 for item in fov_metadata[zero_base_key]
                ]

        channels = list(set(fov_metadata["channel"]))
        exp_metadata["channels"] = channels
        exp_metadata["num_fovs"] = len(set(fov_metadata["fov"]))
        exp_metadata["frames"] = list(set(fov_metadata["timepoints"]))
        exp_metadata["num_frames"] = len(exp_metadata["frames"])
        exp_metadata["pixel_microns"] = 0.16136255757596  ##hack assuming ti5 40x
        fov_metadata = pd.DataFrame(fov_metadata)
        fov_metadata = fov_metadata.set_index(["fov", "timepoints"]).sort_index()

        output_fov_metadata = []
        step = len(channels)
        for i in range(0, len(fov_metadata), step):
            rows = fov_metadata[i : i + step]
            channel_path_entry = dict(zip(rows["channel"], rows["image_paths"]))
            fov_entry = rows.index.get_level_values("fov").unique()[0]
            timepoint_entry = rows.index.get_level_values("timepoints").unique()[0]
            fov_metadata_entry = {
                "fov": fov_entry,
                "timepoints": timepoint_entry,
                "channel_paths": channel_path_entry,
            }
            output_fov_metadata.append(fov_metadata_entry)
        fov_metadata = pd.DataFrame(output_fov_metadata).set_index(
            ["fov", "timepoints"]
        )

        metaparser = compile(metaparsestr)
        meta_df_out = []
        for metadata_file in metadata_files:
            match = metaparser.search(metadata_file)
            if match is not None:
                timepoint = match.named["timepoint"]
                meta_df = pd.read_hdf(metadata_file)
                meta_df["timepoints"] = timepoint
                meta_df_out.append(meta_df)
        meta_df_out = pd.concat(meta_df_out)
        if 0 not in meta_df_out["timepoints"].unique().tolist():
            meta_df_out["timepoints"] = meta_df_out["timepoints"] - 1
        meta_df_out = meta_df_out.groupby(["fov", "timepoints"], as_index=False)
        meta_df_out = meta_df_out.apply(lambda x: x[0:1])
        meta_df_out = meta_df_out.set_index(["fov", "timepoints"], drop=True)
        fov_metadata = fov_metadata.join(meta_df_out)

        return exp_metadata, fov_metadata

    def assignidx(self, fov_metadata):
        numfovs = len(fov_metadata.index.get_level_values("fov").unique().tolist())
        timepoints_per_fov = len(
            fov_metadata.index.get_level_values("timepoints").unique().tolist()
        )

        files_per_fov = (timepoints_per_fov // self.tpts_per_file) + 1
        remainder = timepoints_per_fov % self.tpts_per_file
        ttlfiles = numfovs * files_per_fov
        fov_file_idx = np.repeat(list(range(files_per_fov)), self.tpts_per_file)[
            : -(self.tpts_per_file - remainder)
        ]
        file_idx = np.concatenate(
            [fov_file_idx + (fov_idx * files_per_fov) for fov_idx in range(numfovs)]
        )
        fov_img_idx = np.repeat(
            np.array(list(range(self.tpts_per_file)))[np.newaxis, :],
            files_per_fov,
            axis=0,
        )
        fov_img_idx = fov_img_idx.flatten()[: -(self.tpts_per_file - remainder)]
        img_idx = np.concatenate([fov_img_idx for fov_idx in range(numfovs)])

        fov_idx = np.repeat(list(range(numfovs)), timepoints_per_fov)
        timepoint_idx = np.repeat(
            np.array(list(range(timepoints_per_fov)))[np.newaxis, :], numfovs, axis=0
        ).flatten()

        outdf = copy.deepcopy(fov_metadata)
        outdf["File Index"] = file_idx
        outdf["Image Index"] = img_idx
        return outdf

    def writemetadata(self, t_range=None, fov_list=None):

        exp_metadata, fov_metadata = self.get_metadata(
            self.hdf5inputpath,
            parsestr=self.parsestr,
            metaparsestr=self.metaparsestr,
            zero_base_keys=self.zero_base_keys,
        )

        if t_range is not None:
            exp_metadata["frames"] = exp_metadata["frames"][t_range[0] : t_range[1] + 1]
            exp_metadata["num_frames"] = len(exp_metadata["frames"])
            fov_metadata = fov_metadata.loc[
                pd.IndexSlice[:, slice(t_range[0], t_range[1])], :
            ]  # 4 -> 70

        if fov_list is not None:
            fov_metadata = fov_metadata.loc[list(fov_list)]
            exp_metadata["fields_of_view"] = list(fov_list)

        self.chunk_shape = (1, exp_metadata["height"], exp_metadata["width"])
        chunk_bytes = 2 * np.multiply.accumulate(np.array(self.chunk_shape))[-1]
        self.chunk_cache_mem_size = 2 * chunk_bytes
        exp_metadata["chunk_shape"], exp_metadata["chunk_cache_mem_size"] = (
            self.chunk_shape,
            self.chunk_cache_mem_size,
        )
        exp_metadata["Organism"], exp_metadata["Microscope"], exp_metadata["Notes"] = (
            self.organism,
            self.microscope,
            self.notes,
        )
        self.meta_handle = pandas_hdf5_handler(self.metapath)

        assignment_metadata = self.assignidx(fov_metadata)
        # astype returns a new DataFrame, so the result must be reassigned
        assignment_metadata = assignment_metadata.astype(
            {"File Index": int, "Image Index": int}
        )

        self.meta_handle.write_df("global", assignment_metadata, metadata=exp_metadata)

    def read_metadata(self):
        writedir(self.hdf5path, overwrite=True)
        self.writemetadata()
        metadf = self.meta_handle.read_df("global", read_metadata=True)
        self.metadata = metadf.metadata
        metadf = metadf.reset_index(inplace=False)
        metadf = metadf.set_index(
            ["File Index", "Image Index"], drop=True, append=False, inplace=False
        )
        self.metadf = metadf.sort_index()

    def set_params(self, fov_list, t_range, organism, microscope, notes):
        self.fov_list = fov_list
        self.t_range = t_range
        self.organism = organism
        self.microscope = microscope
        self.notes = notes

    def inter_set_params(self):
        self.read_metadata()
        t0, tf = (self.metadata["frames"][0], self.metadata["frames"][-1])
        available_fov_list = self.metadf["fov"].unique().tolist()
        selection = ipyw.interactive(
            self.set_params,
            {"manual": True},
            fov_list=ipyw.SelectMultiple(options=available_fov_list),
            t_range=ipyw.IntRangeSlider(
                value=[t0, tf],
                min=t0,
                max=tf,
                step=1,
                description="Time Range:",
                disabled=False,
            ),
            organism=ipyw.Textarea(
                value="",
                placeholder="Organism imaged in this experiment.",
                description="Organism:",
                disabled=False,
            ),
            microscope=ipyw.Textarea(
                value="",
                placeholder="Microscope used in this experiment.",
                description="Microscope:",
                disabled=False,
            ),
            notes=ipyw.Textarea(
                value="",
                placeholder="General experiment notes.",
                description="Notes:",
                disabled=False,
            ),
        )
        display(selection)

    def extract(self, dask_controller, retries=1):
        dask_controller.futures = {}

        self.writemetadata(t_range=self.t_range, fov_list=self.fov_list)
        metadf = self.meta_handle.read_df("global", read_metadata=True)
        self.metadata = metadf.metadata
        metadf = metadf.reset_index(inplace=False)
        metadf = metadf.set_index(
            ["File Index", "Image Index"], drop=True, append=False, inplace=False
        )
        self.metadf = metadf.sort_index()

        def writehdf5(fovnum, num_entries, timepoint_list, file_idx):
            y_dim = self.metadata["height"]
            x_dim = self.metadata["width"]
            filedf = self.metadf.loc[file_idx].reset_index(inplace=False)
            filedf = filedf.set_index(
                ["timepoints"], drop=True, append=False, inplace=False
            )
            filedf = filedf.sort_index()

            with h5py_cache.File(
                self.hdf5path + "/hdf5_" + str(file_idx) + ".hdf5",
                "w",
                chunk_cache_mem_size=self.chunk_cache_mem_size,
            ) as h5pyfile:
                for i, channel in enumerate(self.metadata["channels"]):
                    hdf5_dataset = h5pyfile.create_dataset(
                        str(channel),
                        (num_entries, y_dim, x_dim),
                        chunks=self.chunk_shape,
                        dtype="uint16",
                    )
                    for j in range(len(timepoint_list)):
                        frame = timepoint_list[j]
                        entry = filedf.loc[frame]["channel_paths"]
                        file_path = entry[channel]
                        with h5py_cache.File(file_path, "r") as infile:
                            img = infile["data"][:]
                        hdf5_dataset[j, :, :] = img
            return "Done."

        file_list = self.metadf.index.get_level_values("File Index").unique().values
        num_jobs = len(file_list)
        random_priorities = np.random.uniform(size=(num_jobs,))

        for k, file_idx in enumerate(file_list):
            priority = random_priorities[k]
            filedf = self.metadf.loc[file_idx]

            fovnum = filedf[0:1]["fov"].values[0]
            num_entries = len(filedf.index.get_level_values("Image Index").values)
            timepoint_list = filedf["timepoints"].tolist()

            future = dask_controller.daskclient.submit(
                writehdf5,
                fovnum,
                num_entries,
                timepoint_list,
                file_idx,
                retries=retries,
                priority=priority,
            )
            dask_controller.futures["extract file: " + str(file_idx)] = future

        extracted_futures = [
            dask_controller.futures["extract file: " + str(file_idx)]
            for file_idx in file_list
        ]
        pause_for_extract = dask_controller.daskclient.gather(
            extracted_futures, errors="skip"
        )

        futures_name_list = ["extract file: " + str(file_idx) for file_idx in file_list]
        failed_files = [
            futures_name_list[k]
            for k, item in enumerate(extracted_futures)
            if item.status != "finished"
        ]
        failed_file_idx = [int(item.split(":")[1]) for item in failed_files]
        outdf = self.meta_handle.read_df("global", read_metadata=False)

        tempmeta = outdf.reset_index(inplace=False)
        tempmeta = tempmeta.set_index(
            ["File Index", "Image Index"], drop=True, append=False, inplace=False
        )
        failed_fovs = tempmeta.loc[failed_file_idx]["fov"].unique().tolist()

        outdf = outdf.drop(failed_fovs)

        if self.t_range is not None:
            outdf = outdf.reset_index(inplace=False)
            outdf["timepoints"] = outdf["timepoints"] - self.t_range[0]
            outdf = outdf.set_index(
                ["fov", "timepoints"], drop=True, append=False, inplace=False
            )

        self.meta_handle.write_df("global", outdf, metadata=self.metadata)

#### Specify Paths

Begin by defining the directory in which all processing will be done, as well as the initial nd2 file we will be <br>processing. This line should be run every time you open a new Python kernel.

The format should be: `headpath = "/path/to/folder"` and `nd2file = "/path/to/file.nd2"`

For example:
```
headpath = "/n/scratch2/de64/2019-05-31_validation_data"
nd2file = "/n/scratch2/de64/2019-05-31_validation_data/Main_Experiment.nd2"
```

Ideally, these files should be placed in a storage location with relatively fast I/O.

In [None]:
headpath = "/home/de64/scratch/de64/2020-11-23_mVenus_KO_library/FISH_headpath"
hdf5inputpath = "/home/de64/scratch/de64/2020-11-23_mVenus_KO_library/run"

## Extract to hdf5 files

In this section, we will be extracting our image data. Currently, this notebook only supports the `.nd2` format; however, <br>there are `.tiff` extractors in the TrenchRipper source files that will be added to `Master.ipynb` soon.

In the abstract, this step will take a single `.nd2` file and split it into a set of `.hdf5` files stored in <br>`headpath/hdf5`. Splitting the file up in this way will facilitate quick processing in later steps. Each field of <br>view will be split into one or more `.hdf5` files, depending on the number of images per file requested (more on <br>this later).

To keep track of which output files correspond to which FOVs, as well as to keep track of experiment metadata, the <br>extractor also outputs a `metadata.hdf5` file in the `headpath` folder. The data from this step is accessible in <br>that `metadata.hdf5` file under the `global` key. If you would like to look at this metadata, you may use the <br>`tr.utils.pandas_hdf5_handler` to read from this file. Later steps will add additional metadata under different <br>keys into the `metadata.hdf5` file.

#### Start Dask Workers

First, we start a `dask_controller` instance which will handle all of our parallel processing. The default parameters <br>here work well on O2. The critical arguments here are:

**walltime** : For a cluster, the length of time you will request each node for.

**local** : `True` if you want to perform computation locally. `False` if you want to perform it on a SLURM cluster.

**n_workers** : Number of nodes to request if on the cluster, or number of processes if computing locally.

**memory** : For a cluster, the amount of memory you will request each node for.

**working_directory** : For a cluster, the directory in which data will be spilled to disk. Usually set as a folder in <br>the `headpath`.

In [None]:
dask_controller = tr.trcluster.dask_controller(
    walltime="04:00:00",
    local=False,
    n_workers=20,
    memory="2GB",
    working_directory=headpath + "/dask",
)
dask_controller.startdask()

After running the above line, you will have a running Dask client. Run the line below and click the link to supervise <br>the computation being administered by the scheduler. 

Don't be alarmed if the screen starts mostly blank; it may take time for your workers to spin up. If you get a 404 <br>error on a cluster, it is likely that your ports are not being forwarded properly. If this occurs, please report <br>the issue on GitHub.

In [None]:
dask_controller.daskclient

##### Perform Extraction

Now that we have our cluster scheduler spun up, it is time to convert files. This will be handled by the <br>`hdf5_extractor` object. This extractor will pull up each FOV and split it such that each derived `.hdf5` file <br>contains, at maximum, N timepoints of that FOV per file. The image data stored in these files takes the <br>form of `(N,Y,X)` arrays that are accessible using the desired channel name as a key. 

The arguments for this extractor are:

 - **nd2file** : The filepath to the `.nd2` file you intend to extract.
 
 - **headpath** : The folder in which processing is occurring. Should be the same for each step in the pipeline.

 - **tpts_per_file** : The maximum number of timepoints stored in each output `.hdf5` file. Typical values are between 25 <br>and 100.

 - **ignore_fovmetadata** : Used when `.nd2` data is corrupted and does not possess records for stage positions or <br>timepoints. Only set `True` if the extractor throws errors on metadata handling.

 - **nd2reader_override** : Overrides values in metadata recovered using the `nd2reader`. Currently set to <br>`{"z_levels":[],"z_coordinates":[]}` by default to correct a known issue where z coordinates are mistakenly <br>interpreted as a z stack. See the [nd2reader](https://rbnvrw.github.io/nd2reader/) documentation for more info.
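
To make the role of `tpts_per_file` concrete, the following sketch maps each (FOV, timepoint) pair to an output file and a position within that file. It is a simplified illustration under the assumption that files are numbered consecutively per FOV; the function name `assign_file_indices` is hypothetical and the library's own numbering may differ.

```python
def assign_file_indices(num_fovs, timepoints_per_fov, tpts_per_file):
    """Map every (fov, timepoint) pair to a (file index, image index) pair."""
    # Ceiling division: the number of files needed to hold one FOV.
    files_per_fov = -(-timepoints_per_fov // tpts_per_file)
    assignment = {}
    for fov in range(num_fovs):
        for t in range(timepoints_per_fov):
            file_idx = fov * files_per_fov + t // tpts_per_file
            img_idx = t % tpts_per_file  # position of the frame inside its file
            assignment[(fov, t)] = (file_idx, img_idx)
    return assignment


# Two FOVs with five timepoints each, stored two timepoints per file.
assignment = assign_file_indices(num_fovs=2, timepoints_per_fov=5, tpts_per_file=2)
```

With these toy numbers each FOV occupies three files, the last of which holds a single frame.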

In [None]:
hdf5_extractor = marlin_extractor(hdf5inputpath, headpath)

##### Extraction Parameters

Here, you may set the time interval you want to extract. Useful for cropping data to the period exhibiting the dynamics of interest.

Optionally take notes to add to the `metadata.hdf5` file. Notes may also be taken directly in this notebook.

In [None]:
hdf5_extractor.inter_set_params()

##### Begin Extraction 

Running the following line will start the extraction process. This may be monitored by examining the `Dask Dashboard` <br> under the link displayed earlier. Once the computation is complete, move to the next line.

This step may take a long time, though it is possible to speed it up using additional workers.

In [None]:
hdf5_extractor.extract(dask_controller)

##### Shutdown Dask

Once extraction is complete, it is likely that you will want to shut down your `dask_controller` if you are on a <br>
cluster. This is because the specifications of the current `dask_controller` will not be optimal for later steps. <br>
To do this, run the following line and wait for it to complete. If it hangs, interrupt your kernel and re-run it. <br>
If this also fails to shut down your workers, you will have to manually shut them down using `scancel` in a terminal.

In [None]:
dask_controller.shutdown()

## Kymographs

Now that you have extracted your data into a series of `.hdf5` files, we will perform identification and cropping <br>of the individual trenches/growth channels present in the images. This algorithm assumes that your growth trenches <br>are vertically aligned and that they alternate in their orientation from top to bottom. See the example image for the <br>correct geometry:

![example_image](./resources/example_image.jpg)

The output of this step will be a set of `.hdf5` files stored in `headpath/kymograph`. The image data stored in these <br>files takes the form of `(K,T,Y,X)` arrays where K is the trench index, T is time, and Y,X are the crop dimensions. <br>These arrays are accessible using keys of the form `"[Image Channel]"`. For example, looking up phase channel <br>data of trenches in the topmost row of an image will require the key `"Phase"`.

### Test Parameters



##### Initialize the interactive kymograph class

As a first step, initialize the `tr.interactive.kymograph_interactive` class that will help us choose the <br>parameters we will use to generate kymographs.

In [None]:
interactive_kymograph = tr.kymograph_interactive(headpath)

You will now want to select a few test FOVs to try out parameters on, the channel you want to detect trenches on, and <br>the time interval on which you will perform your processing.

The arguments for this step are:

- **seg_channel (string)** : The channel name that you would like to segment on.

- **invert (list)** : Whether or not you want to invert the image before detecting trenches. By default, it is assumed that <br>the trenches have a high pixel intensity relative to the background. This should be the case for Phase Contrast and <br>Fluorescence Imaging, but may not be the case for Brightfield Imaging, in which case you will want to invert the image.

- **fov_list (list)** : List of integers corresponding to the FOVs that you wish to make test kymographs of.

- **t_subsample_step (int)** : Step size to be used for subsampling input files in time. We recommend choosing a step that yields <br>between 5 and 10 timepoints for quick processing.

Hit the "Run Interact" button to lock in your parameters. The button will become transparent briefly and become solid again <br>when processing is complete. After that has occurred, move on to the next step.
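
As a quick aid for choosing `t_subsample_step`, the hypothetical helper below picks the smallest step that keeps the number of subsampled timepoints at or below a target. Both the helper name `choose_subsample_step` and the default target of 8 frames are assumptions made for illustration.

```python
def choose_subsample_step(num_timepoints, target_frames=8):
    """Return a step so that range(0, num_timepoints, step) has at most target_frames entries."""
    # Ceiling division, clamped so the step is never below 1.
    return max(1, -(-num_timepoints // target_frames))


step = choose_subsample_step(100)
sampled_timepoints = list(range(0, 100, step))
```

For short experiments the helper returns 1, meaning every timepoint is used.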

In [None]:
interactive_kymograph.import_hdf5_interactive()

##### Tune "trench-row" detection hyperparameters

The kymograph code begins by detecting the positions of trench rows in the image as follows:

1. Reduce each 2D image to a 1D signal along the y-axis by computing the qth percentile of the data along the x-axis.
2. Smooth this signal using a median kernel.
3. Normalize the signal by linearly rescaling it so that its minimum and maximum map to 0 and 1, respectively.
4. Use a set threshold to determine the trench row positions.

The arguments for this step are:

 - **y_percentile (int)** : Percentile to use for step 1.

 - **smoothing_kernel_y_dim_0 (int)** : Median kernel size to use for step 2.

 - **y_percentile_threshold (float)** : Threshold to use in step 4.

Running the following widget will display the smoothed 1-D signal for each of your timepoints. In addition, the threshold <br>value for each fov will be displayed as a red line.
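
The four detection steps can be sketched in plain Python as follows. This is an illustrative approximation rather than the library's code: the helper names are hypothetical, the percentile uses a simple nearest-rank rule, and the median filter truncates its window at the image borders.

```python
from statistics import median


def percentile(values, q):
    """Nearest-rank percentile (q in [0, 100]) of a list of numbers."""
    ordered = sorted(values)
    rank = round(q / 100 * (len(ordered) - 1))
    return ordered[rank]


def median_smooth(signal, kernel):
    """Median filter with an odd kernel size; windows are truncated at the edges."""
    half = kernel // 2
    return [
        median(signal[max(0, i - half) : i + half + 1]) for i in range(len(signal))
    ]


def detect_row_mask(image, y_percentile, smoothing_kernel, threshold):
    # Step 1: one value per y position, taken across x.
    signal = [percentile(row, y_percentile) for row in image]
    # Step 2: median smoothing.
    signal = median_smooth(signal, smoothing_kernel)
    # Step 3: rescale so the minimum maps to 0 and the maximum to 1.
    lo, hi = min(signal), max(signal)
    span = (hi - lo) or 1  # avoid division by zero on a flat signal
    signal = [(s - lo) / span for s in signal]
    # Step 4: y positions above the threshold belong to a trench row.
    return [s > threshold for s in signal]


# Toy image: 8 rows (y) by 4 columns (x); y positions 2-5 hold a bright trench row.
row_values = [100, 110, 900, 950, 920, 940, 105, 100]
image = [[v] * 4 for v in row_values]
row_mask = detect_row_mask(image, y_percentile=99, smoothing_kernel=3, threshold=0.5)
```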

In [None]:
interactive_kymograph.preview_y_precentiles_interactive()

##### Tune "trench-row" cropping hyperparameters

Next, we will use the detected rows to perform cropping of the input image in the y-dimension:

1. Determine edges of trench rows based on threshold mask.
2. Filter out rows that are too small.
3. Use the remaining rows to compute the drift in y in each image.
4. Apply the drift to the initially detected rows to get rows in all timepoints.
5. Perform cropping using the "end" of the row as reference (the end referring to the part of the trench farthest from <br>the feeding channel).

Step 5 performs a simple algorithm to determine the orientation of each trench:

```
row_orientations = [] # A list of row orientations, starting from the topmost row
# Choose the orientation of the topmost row
if the number of detected rows == 'Number of Rows': 
    row_orientations.append('Orientation')
elif the number of detected rows < 'Number of Rows':
    row_orientations.append('Orientation when < expected rows')
# Each remaining row takes the opposite orientation of the row above it
for row in rows[1:]:
    if row_orientations[-1] == downward:
        row_orientations.append(upward)
    elif row_orientations[-1] == upward:
        row_orientations.append(downward)
```

Additionally, if the device trenches face a single direction, alternation of row orientation may be turned off by setting the<br> `Alternate Orientation?` argument to False. The `Use Median Drift?` argument, when set to True, will use the<br> median drift in y across all FOVs for drift correction, instead of doing drift correction independently for all FOVs. <br>This can be useful if a large fraction of FOVs are failing drift correction. Note that `Use Median Drift?` <br>sets this behavior for both y and x drift correction.
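
A runnable version of the orientation logic, including the `Alternate Orientation?` switch, might look like the sketch below. The function and argument names are hypothetical stand-ins for the widget fields, orientations use the 0 (downward opening) and 1 (upward opening) convention, and the case of detecting more rows than expected is not described here, so the sketch simply raises an error.

```python
def assign_row_orientations(
    num_detected, num_expected, orientation, orientation_if_fewer, alternate=True
):
    """Return one orientation per detected row, topmost first (0 = down, 1 = up)."""
    if num_detected > num_expected:
        # Not covered by the description above; this sketch refuses to guess.
        raise ValueError("more rows detected than expected")
    first = orientation if num_detected == num_expected else orientation_if_fewer
    orientations = []
    for i in range(num_detected):
        if alternate and i % 2 == 1:
            orientations.append(1 - first)  # flip relative to the topmost row
        else:
            orientations.append(first)
    return orientations


full_view = assign_row_orientations(2, 2, orientation=0, orientation_if_fewer=1)
drifted_view = assign_row_orientations(1, 2, orientation=0, orientation_if_fewer=1)
```

Here `drifted_view` models an FOV in which the top row has drifted out of the image, so the only visible row takes the fallback orientation.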

The arguments for this step are:

 - **y_min_edge_dist (int)** : Minimum row length necessary for detection (filters out small detected objects).

 - **padding_y (int)** : Padding to add to the end of trench row when cropping in the y-dimension.

 - **trench_len_y (int)** : Length from the end of each trench row to the feeding channel side of the crop.

 - **Number of Rows (int)** : The number of rows to expect in your image. For instance, two in the example image.
 
 - **Alternate Orientation? (bool)** : Whether or not to alternate the orientation of consecutive rows.

 - **Orientation (int)** : The orientation of the top-most row where 0 corresponds to a trench with a downward-oriented trench <br>opening and 1 corresponds to a trench with an upward-oriented trench opening.

 - **Orientation when < expected rows (int)** : The orientation of the top-most row when the number of detected rows is less than <br>expected. Useful if your trenches drift out of your image in some FOVs.
 
 - **Use Median Drift? (bool)** : Whether to use the median detected drift across all FOVs, instead of the drift detected in each FOV individually.

 - **images_per_row (int)** : How many images to output per row for this widget.

Running the following widget will display y-cropped images for each fov and timepoint.
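
Steps 1 and 2 of the cropping procedure (finding row edges in the threshold mask and discarding short rows) can be illustrated with the sketch below. The function name `find_row_edges` is hypothetical, and how the library treats rows that touch the image border is not stated here; this sketch keeps them.

```python
def find_row_edges(row_mask, y_min_edge_dist):
    """Return (start, end) pairs for runs of True at least y_min_edge_dist long.

    `end` is exclusive, so a row spans row_mask[start:end].
    """
    edges = []
    start = None
    for y, inside in enumerate(row_mask):
        if inside and start is None:
            start = y  # rising edge
        elif not inside and start is not None:
            edges.append((start, y))  # falling edge
            start = None
    if start is not None:  # the run touches the bottom of the image
        edges.append((start, len(row_mask)))
    return [(s, e) for s, e in edges if e - s >= y_min_edge_dist]


row_mask = [False, True, False, False, True, True, True, True, False, True, True, True]
rows = find_row_edges(row_mask, y_min_edge_dist=3)
```

The single-pixel run at y=1 is rejected as too small, while the two longer runs are kept.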

In [None]:
interactive_kymograph.preview_y_crop_interactive()

##### Tune trench detection hyperparameters

Next, we will detect the positions of trenches in the y-cropped images as follows:

1. Reduce each 2D image to a 1D signal along the x-axis by computing the qth percentile of the data along the y-axis.
2. Determine the signal background by smoothing this signal using a large median kernel.
3. Subtract the background signal.
4. Smooth the resultant signal using a median kernel.
5. Use an [Otsu threshold](https://imagej.net/Auto_Threshold#Otsu) to determine the trench midpoint positions.

After this, x-dimension drift correction of our detected midpoints will be performed as follows:

6. Begin at t=1.
7. For each $m \in \{midpoints(t)\}$, assign to m the midpoint $n \in \{midpoints(t-1)\}$ that is closest to m at time $t-1$.<br>
Midpoints at time $t-1$ that are not the closest midpoint to any m will not be mapped.
8. Compute the translation of each mapped midpoint from time $t-1$ to $t$.
9. Take the average of these translations as the x-dimension drift from time $t-1$ to $t$.
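
The drift estimate in steps 6 through 9 can be sketched as a nearest-neighbor match between consecutive timepoints. The function names are hypothetical, and accumulating the frame-to-frame drift relative to t=0 is an assumption added for illustration rather than something stated above.

```python
def estimate_x_drift(prev_midpoints, curr_midpoints):
    """Mean displacement between each current midpoint and its nearest previous midpoint."""
    if not prev_midpoints or not curr_midpoints:
        return 0.0
    shifts = []
    for m in curr_midpoints:
        nearest = min(prev_midpoints, key=lambda n: abs(n - m))
        shifts.append(m - nearest)
    return sum(shifts) / len(shifts)


def cumulative_drift(midpoints_by_time):
    """Accumulate frame-to-frame drift so each timepoint is expressed relative to t=0."""
    drift = [0.0]
    for t in range(1, len(midpoints_by_time)):
        step = estimate_x_drift(midpoints_by_time[t - 1], midpoints_by_time[t])
        drift.append(drift[-1] + step)
    return drift


midpoints_by_time = [
    [50, 100, 150],
    [52, 102, 152],
    [55, 105],  # one trench was missed at t=2
]
drift = cumulative_drift(midpoints_by_time)
```

Because the match is made from each current midpoint to its closest predecessor, a trench that goes undetected in one frame does not bias the average, as long as the remaining midpoints are matched correctly.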

The arguments for this step are:

 - **t (int)** : Timepoint to examine the percentiles and threshold in.

 - **x_percentile (int)** : Percentile to use for step 1.

 - **background_kernel_x (int)** : Median kernel size to use for step 2.

 - **smoothing_kernel_x (int)** : Median kernel size to use for step 4.

 - **otsu_scaling (float)** : Scaling factor to apply to the threshold determined by Otsu's method.

Running the following widget will display the smoothed 1-D signal for each of your timepoints, with the threshold <br>value for each fov displayed as a red line. It will also display the detected midpoints for each of your timepoints. <br>If there is too much sparsity, or discontinuity, your drift correction will not be accurate.
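
The midpoint detection in steps 2 through 5 can be approximated as shown below, starting from the 1D percentile signal of step 1. The helper names are hypothetical, and the threshold is passed in directly as a stand-in for the scaled Otsu value used by the pipeline.

```python
from statistics import median


def median_smooth(signal, kernel):
    """Median filter with an odd kernel size; windows are truncated at the edges."""
    half = kernel // 2
    return [
        median(signal[max(0, i - half) : i + half + 1]) for i in range(len(signal))
    ]


def detect_midpoints(signal, background_kernel, smoothing_kernel, threshold):
    # Steps 2-3: estimate the background with a large median kernel and subtract it.
    background = median_smooth(signal, background_kernel)
    corrected = [s - b for s, b in zip(signal, background)]
    # Step 4: smooth the background-subtracted signal.
    smoothed = median_smooth(corrected, smoothing_kernel)
    # Step 5: threshold, then take the center of each run above the threshold.
    mask = [s > threshold for s in smoothed]
    midpoints, start = [], None
    for x, inside in enumerate(mask + [False]):  # sentinel closes a trailing run
        if inside and start is None:
            start = x
        elif not inside and start is not None:
            midpoints.append((start + x - 1) / 2)
            start = None
    return midpoints


# Toy signal: flat background of 100 with three 3-pixel-wide trenches of 500.
signal = [100] * 30
for center in (6, 16, 26):
    for x in (center - 1, center, center + 1):
        signal[x] = 500
midpoints = detect_midpoints(
    signal, background_kernel=15, smoothing_kernel=3, threshold=200
)
```

The background kernel needs to be wide relative to a single trench; otherwise the median tracks the trench itself and the subtraction removes the signal of interest.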

In [None]:
interactive_kymograph.preview_x_percentiles_interactive()

##### Tune trench cropping hyperparameters

Trench cropping simply uses the drift-corrected midpoints as a reference and crops out some fixed length around them <br>
to produce an output kymograph. **Note that the current implementation does not allow trench crops to overlap**. If your<br>
trench crops do overlap, the error will not be caught here, but will cause issues later in the pipeline. As such, try <br>
to crop your trenches as closely as possible. This issue will be fixed in a later update.

The arguments for this step are:

 - **trench_width_x (int)** : Trench width to use for cropping.

 - **trench_present_thr (float)** : Trenches that appear in fewer than this percentage of FOVs will be eliminated from the dataset.<br>
If not removed, missing positions will be inferred from the image drift.

 - **Use Median Drift? (bool)** : Whether to use the median detected drift across all FOVs, instead of the drift detected in each FOV individually.


Running the following widget will display a random kymograph for each row in each FOV and will also produce midpoint plots <br>showing the retained midpoints.

In [None]:
interactive_kymograph.preview_kymographs_interactive()

##### Export and save hyperparameters

Run the following line to register and display the parameters you have selected for kymograph creation.

In [None]:
interactive_kymograph.process_results()

If you are satisfied with the above parameters, run the following line to write these parameters to disk at `headpath/kymograph.par`.<br>
This file will be used to perform kymograph creation in the next section.

In [None]:
interactive_kymograph.write_param_file()

### Generate Kymograph

##### Start Dask Workers

Again, we start a `dask_controller` instance which will handle all of our parallel processing. The default parameters <br>here work well on O2 for kymograph creation. The critical arguments here are:

**walltime** : For a cluster, the length of time you will request each node for.

**local** : `True` if you want to perform computation locally. `False` if you want to perform it on a SLURM cluster.

**n_workers** : Number of nodes to request if on the cluster, or number of processes if computing locally.

**memory** : For a cluster, the amount of memory you will request each node for.

**working_directory** : For a cluster, the directory in which data will be spilled to disk. Usually set as a folder in <br>the `headpath`.

In [None]:
dask_controller = tr.trcluster.dask_controller(
    walltime="04:00:00",
    local=False,
    n_workers=20,
    memory="2GB",
    working_directory=headpath + "/dask",
)
dask_controller.startdask()

After running the above line, you will have a running Dask client. Run the line below and click the link to supervise <br>the computation being administered by the scheduler. 

Don't be alarmed if the screen starts mostly blank; it may take time for your workers to spin up. If you get a 404 <br>error on a cluster, it is likely that your ports are not being forwarded properly. If this occurs, please register <br>the issue on GitHub.

In [None]:
dask_controller.daskclient

##### Perform Kymograph Cropping

Now that we have our cluster scheduler spun up, we will extract kymographs using the parameters stored in `headpath/kymograph.par`. <br>
This will be handled by the `kymograph_cluster` object. This will detect trenches in all of the files present in `headpath/hdf5` that <br>
you created in the first step. It will then crop these trenches and place the crops in a series of `.hdf5` files in `headpath/kymograph`. <br>
These files will store image data in the form of `(K,T,Y,X)` arrays where K is the trench index, T is time and Y,X are the image dimensions <br>
of the crop.

The arguments for this step are:

 - **headpath** : The folder in which processing is occurring. Should be the same for each step in the pipeline.

 - **trenches_per_file** : The maximum number of trenches stored in each output `.hdf5` file. Typical values are between 25 <br>and 100.

 - **paramfile** : Set to `True` if you want to use parameters from `headpath/kymograph.par`. Otherwise, you will have to specify <br>
 parameters as direct arguments to `kymograph_cluster`.

In [None]:
kymoclust = tr.kymograph.kymograph_cluster(
    headpath=headpath, trenches_per_file=200, paramfile=True
)

##### Begin Kymograph Cropping 

Running the following line will start the cropping process. This may be monitored by examining the `Dask Dashboard` <br>
under the link displayed earlier. Once the computation is complete, move to the next line.

**Do not move on until all tasks are displayed as 'in memory' in Dask.**

In [None]:
kymoclust.generate_kymographs(dask_controller)

In [None]:
ff = tr.focus_filter(headpath)

In [None]:
ff.choose_filter_channel_inter()

In [None]:
ff.plot_histograms()

In [None]:
ff.plot_focus_threshold_inter()

In [None]:
ff.write_param_file()

##### Post-process Images

After the above step, kymographs will have been created for each `.hdf5` input file. They will now need to be reorganized <br>
into a new set of files such that each file contains, at most, `trenches_per_file` trenches.
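The regrouping can be pictured as assigning each trench a file number and a position within that file. The sketch below assumes trenches are simply numbered consecutively and packed in order; `assign_output_files` is a hypothetical helper, and the actual ordering used by `post_process` is not specified here.

```python
import numpy as np


def assign_output_files(n_trenches, trenches_per_file):
    """Map a running trench number to (file index, index within that file)."""
    trench_idx = np.arange(n_trenches)
    file_idx = trench_idx // trenches_per_file
    file_trench_idx = trench_idx % trenches_per_file
    return file_idx, file_trench_idx


# Seven trenches packed three per file fill two files and start a third.
file_idx, file_trench_idx = assign_output_files(n_trenches=7, trenches_per_file=3)
print(file_idx)
print(file_trench_idx)
```

The analysis code later in this notebook looks trenches up through `File Index` and `File Trench Index` columns in the kymograph metadata, which play the same role as the two arrays returned here.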

**Do not move on until all tasks are displayed as 'in memory' in Dask.**

In [None]:
kymoclust.post_process(dask_controller)

##### Check kymograph statistics

Run the next line to display some statistics from kymograph creation. The outputs are:

 - **fovs processed** : The number of FOVs successfully processed out of the total number of FOVs
 - **rows processed** : The number of rows of trenches processed out of the total number of rows
 - **trenches processed** : The number of trenches successfully processed
 - **row/fov** : The average number of rows successfully processed per FOV
 - **trenches/fov** : The average number of trenches successfully processed per FOV
 - **failed fovs** : A list of failed FOVs. Spot check these FOVs in the viewer to determine potential problems
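As a rough guide to how these statistics relate to one another, the sketch below derives them from a toy list of per-trench records. The arrays and the choice to average over successfully processed FOVs are assumptions made for illustration; `kymo_report` may compute its figures differently.

```python
import numpy as np

# Hypothetical per-trench records: the FOV and row of each processed trench.
fov = np.array([0, 0, 0, 0, 1, 1, 3, 3, 3])
row = np.array([0, 0, 1, 1, 0, 0, 0, 1, 1])
all_fovs = np.arange(4)  # FOVs 0-3 were submitted; FOV 2 yielded no trenches

processed_fovs = np.unique(fov)
fov_row_pairs = np.unique(np.stack([fov, row], axis=1), axis=0)
failed_fovs = np.setdiff1d(all_fovs, processed_fovs)

report = {
    "fovs processed": f"{len(processed_fovs)}/{len(all_fovs)}",
    "trenches processed": len(fov),
    "row/fov": len(fov_row_pairs) / len(processed_fovs),
    "trenches/fov": len(fov) / len(processed_fovs),
    "failed fovs": failed_fovs.tolist(),
}
print(report)
```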

In [None]:
kymoclust.kymo_report()

##### Shutdown Dask

Once cropping is complete, it is likely that you will want to shut down your `dask_controller` if you are on a <br>
cluster. This is because the specifications of the current `dask_controller` will not be optimal for later steps. <br>
To do this, run the following line and wait for it to complete. If it hangs, interrupt your kernel and re-run it. <br>
If this also fails to shut down your workers, you will have to manually shut them down using `scancel` in a terminal.

In [None]:
dask_controller.daskclient.restart()

In [None]:
dask_controller.shutdown()

#### Output

At this point you may want to use your output. The output of this step is a set of `.hdf5` files stored in <br>`headpath/kymograph`. The image data stored in these files takes the form of `(K,T,Y,X)` arrays <br>where K is the trench index, T is time, and Y,X are the crop dimensions.

These arrays are accessible using keys of the form `"[Trench Row Number]/[Image Channel]"`. <br>For example, looking up phase channel data of trenches in the topmost row of an image will require <br>the key `"0/Phase"`. The metadata associated with these files is a large pandas dataframe relating <br>crops to original FOVs, accessible using the "kymograph" key on `headpath/metadata.hdf5`.
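The `(K,T,Y,X)` layout can be explored with ordinary NumPy indexing. The array below is a synthetic stand-in for a dataset such as `"0/Phase"`; its shape and contents are made up for illustration, and real arrays would be read from the `.hdf5` files.

```python
import numpy as np

# Synthetic stand-in: 5 trenches, 4 timepoints, 30 x 10 pixel crops.
K, T, Y, X = 5, 4, 30, 10
kymo_arr = np.arange(K * T * Y * X, dtype=np.uint16).reshape(K, T, Y, X)

trench = kymo_arr[2]     # all timepoints of trench 2 -> shape (T, Y, X)
frame = kymo_arr[2, 0]   # trench 2 at the first timepoint -> shape (Y, X)

# Lay the timepoints side by side to view the trench as one kymograph image.
kymograph = np.concatenate(list(trench), axis=1)  # shape (Y, T * X)
print(trench.shape, frame.shape, kymograph.shape)
```

Indexing the first axis selects a trench and the second selects a timepoint, so per-timepoint measurements reduce to a loop over `kymo_arr[k, t]`.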

To assist in accessing this file, you may use the `trenchripper.pandas_hdf5_handler` object to <br>interface with this file as follows:

In [None]:
import dask.dataframe as dd
import dask.delayed as delayed
from distributed.client import futures_of
import numpy as np
import pandas as pd
import h5py
import scipy.signal
import skimage as sk
from time import sleep

In [None]:
def get_image_measurements(
    kymographpath, channels, file_idx, output_name, img_fn, *args, **kwargs
):
    df = dd.read_parquet(kymographpath + "/metadata")

    working_dfs = []

    proc_file_path = kymographpath + "/kymograph_" + str(file_idx) + ".hdf5"
    with h5py.File(proc_file_path, "r") as infile:
        working_filedf = df[df["File Index"] == file_idx].compute()
        trench_idx_list = working_filedf["File Trench Index"].unique().tolist()
        for trench_idx in trench_idx_list:
            # Copy so that adding measurement columns does not write into a view.
            trench_df = working_filedf[
                working_filedf["File Trench Index"] == trench_idx
            ].copy()

            for channel in channels:
                kymo_arr = infile[channel][trench_idx]
                fn_out = [
                    img_fn(kymo_arr[i], *args, **kwargs)
                    for i in range(kymo_arr.shape[0])
                ]
                trench_df[channel + " " + output_name] = fn_out
            working_dfs.append(trench_df)

    out_df = pd.concat(working_dfs)
    return out_df


def get_all_image_measurements(
    headpath, output_path, channels, output_name, img_fn, *args, **kwargs
):
    kymographpath = headpath + "/kymograph"
    df = dd.read_parquet(kymographpath + "/metadata")

    file_list = df["File Index"].unique().compute().tolist()

    delayed_list = []

    for file_idx in file_list:
        df_delayed = delayed(get_image_measurements)(
            kymographpath, channels, file_idx, output_name, img_fn, *args, **kwargs
        )
        delayed_list.append(df_delayed.persist())

    ## keeping only dataframes whose futures finished successfully ##
    all_delayed_futures = []
    for item in delayed_list:
        all_delayed_futures += futures_of(item)
    while any(future.status == "pending" for future in all_delayed_futures):
        sleep(0.1)

    good_delayed = []
    for item in delayed_list:
        if all([future.status == "finished" for future in futures_of(item)]):
            good_delayed.append(item)

    ## compiling output dataframe ##
    df_out = dd.from_delayed(good_delayed).persist()
    df_out["FOV Parquet Index"] = df_out.index
    df_out = df_out.set_index("FOV Parquet Index", drop=True, sorted=False)
    df_out = df_out.repartition(partition_size="25MB").persist()

    dd.to_parquet(
        df_out,
        output_path,
        engine="fastparquet",
        compression="gzip",
        write_metadata_file=True,
    )

In [None]:
headpath = "/home/de64/scratch/de64/2020-11-23_mVenus_KO_library/FISH_headpath"

In [None]:
kymograph_metadata = dd.read_parquet(headpath + "/kymograph/metadata")

In [None]:
get_all_image_measurements(
    headpath,
    headpath + "/percentiles",
    ["RFP", "Cy5", "Cy7"],
    "98th Percentile",
    np.percentile,
    98,
)

In [None]:
import numpy as np
import pandas as pd
import skimage as sk
import scipy.signal
from matplotlib import pyplot as plt
import scipy as sp
import sklearn as skl
import sklearn.mixture


def get_background_dist(values):
    hist_range = (0, np.percentile(values, 60))
    freq, val = np.histogram(values, bins=100, range=hist_range)
    mu_n = val[np.argmax(freq)]
    lower_tail = values[values < mu_n]
    std_n = sp.stats.halfnorm.fit(-lower_tail)[1]
    return mu_n, std_n


def get_background_dist_peak(values):
    hist_range = (0, np.percentile(values, 90))
    hist_count, hist_vals = np.histogram(values, bins=100, range=hist_range)
    peaks = sp.signal.find_peaks(hist_count, distance=20)[0]
    mu_n = hist_vals[peaks[0]]
    lower_tail = values[values < mu_n]
    std_n = sp.stats.halfnorm.fit(-lower_tail)[1]
    return mu_n, std_n


def get_background_thr(values, background_fn, background_scaling=5.0):
    mu_n, std_n = background_fn(values)
    back_thr = mu_n + background_scaling * std_n
    return back_thr


def get_signal_sum(df):
    trench_group = df.groupby(["trenchid"])
    barcodes = trench_group.apply(
        lambda x: np.array(
            x["RFP 98th Percentile"].tolist()
            + x["Cy5 98th Percentile"].tolist()
            + x["Cy7 98th Percentile"].tolist()
        ).astype(float)
    )

    short = barcodes.apply(lambda x: len(x) != 24)
    for idx in np.where(short)[0]:
        barcodes[idx] = np.array([0.0 for i in range(24)])

    barcodes_arr = np.array(barcodes.to_list())
    barcodes_arr_no_short = np.array(barcodes[~short].to_list())

    signal_sum = np.sum(barcodes_arr, axis=1)
    signal_sum_no_short = np.sum(barcodes_arr_no_short, axis=1)

    signal_filter_thr = get_background_thr(
        signal_sum_no_short, get_background_dist_peak
    )
    print("Signal Threshold: " + str(signal_filter_thr))
    high_signal_mask = signal_sum > signal_filter_thr
    high_signal_barcodes = barcodes[high_signal_mask]

    trenchid_list = high_signal_barcodes.index.get_level_values(
        "trenchid"
    ).values.tolist()
    high_signal_metadata = df.loc[trenchid_list]

    return signal_sum, high_signal_barcodes


def get_gmm_params(values):
    gmm = skl.mixture.GaussianMixture(n_components=2, n_init=10)
    gmm.fit(values.reshape(-1, 1))
    #     probs = gmm.predict_proba(values.reshape(-1,1))
    return gmm.means_, (gmm.covariances_) ** (1 / 2)


def get_gmm_hard_assign(values):
    gmm = skl.mixture.GaussianMixture(n_components=2, n_init=10)
    gmm.fit(values.reshape(-1, 1))
    lower_mean_idx = np.argmin(gmm.means_)
    assign = gmm.predict(values.reshape(-1, 1))
    if lower_mean_idx == 1:
        assign = (-assign) + 1
    return assign


def get_gmm_probs(values):
    gmm = skl.mixture.GaussianMixture(n_components=2, n_init=10)
    gmm.fit(values.reshape(-1, 1))
    probs = gmm.predict_proba(values.reshape(-1, 1))
    return probs


def str_to_bool(string):
    code = {"1": True, "0": False}
    conv_str = np.array(list(map(lambda x: code[x], string)))
    return conv_str


def bool_to_str(integer):
    rev_code = {True: "1", False: "0"}
    conv_int = "".join(list(map(lambda x: rev_code[x], integer)))
    return conv_int

In [None]:
kymograph_metadata = pd.read_parquet(headpath + "/percentiles")
kymograph_metadata = kymograph_metadata.set_index(["trenchid", "timepoints"], drop=True)
signal_sum, high_signal_barcodes_series = get_signal_sum(kymograph_metadata)
high_signal_barcodes = np.array([item for item in high_signal_barcodes_series])

In [None]:
high_signal_barcodes

In [None]:
plt.hist(signal_sum, bins=100, range=(0, 200000))
plt.title("Sum of all barcode signal")
plt.xlabel("Summed Intensity")
plt.show()

In [None]:
assign_list = []
for i in range(24):
    #     if i == 1:
    #         assign = high_signal_barcodes[:, i]>1100
    #     else:
    assign = get_gmm_hard_assign(high_signal_barcodes[:, i])
    assign_list.append(assign)
assign_arr = np.array(assign_list, dtype=bool)

In [None]:
fig, axes = plt.subplots(3, 8, figsize=(24, 8))
colors = ["salmon", "violet", "grey"]

for i in range(high_signal_barcodes.shape[1]):
    row_idx = i // 8
    column_idx = i % 8
    color = colors[row_idx]
    max_val = np.percentile(high_signal_barcodes[:, i].flatten(), 99)

    bins = np.linspace(0, max_val, num=50)
    on_arr = high_signal_barcodes[:, i][assign_arr[i]]
    off_arr = high_signal_barcodes[:, i][~assign_arr[i]]

    on_frq, on_edges = np.histogram(on_arr, bins)
    off_frq, off_edges = np.histogram(off_arr, bins)
    ttl_frq = np.sum(on_frq) + np.sum(off_frq)
    on_frq, off_frq = (on_frq / ttl_frq, off_frq / ttl_frq)

    axes[row_idx, column_idx].bar(
        off_edges[:-1], off_frq, width=np.diff(off_edges), align="edge", color="black"
    )
    axes[row_idx, column_idx].bar(
        on_edges[:-1], on_frq, width=np.diff(on_edges), align="edge", color=color
    )

    fig.tight_layout()

fig.text(0.5, -0.04, "Trench Intensity", ha="center", size=26)
fig.text(-0.01, 0.5, "# Trenches", va="center", rotation="vertical", size=26)

plt.tight_layout()
# plt.savefig("./2020-10-10_lDE11_figure_1.png",dpi=300,bbox_inches="tight")
plt.show()

In [None]:
def barcode_to_FISH(barcodestr):
    cycleorder = [
        0,
        1,
        6,
        7,
        12,
        13,
        18,
        19,
        2,
        3,
        8,
        9,
        14,
        15,
        20,
        21,
        4,
        5,
        10,
        11,
        16,
        17,
        22,
        23,
    ]

    barcode = [bool(int(item)) for item in list(barcodestr)]
    FISH_barcode = np.array([barcode[i] for i in cycleorder])
    FISH_barcode = "".join(FISH_barcode.astype(int).astype(str))

    return FISH_barcode

In [None]:
lDE13_nanopore = pd.read_csv("./lDE13_final_df.tsv", delimiter="\t", index_col=0)
lDE13_lookup = {}
for _, row in lDE13_nanopore.iterrows():
    lDE13_lookup[barcode_to_FISH(row["24bit_barcode"])] = row
lDE13_lookup_df = pd.DataFrame(lDE13_lookup).T

In [None]:
lDE13_nanopore.iloc[0]

In [None]:
assign_strs = np.apply_along_axis(
    lambda x: "".join(x.astype(int).astype(str)), 0, assign_arr
)

In [None]:
barcode_df = pd.DataFrame(high_signal_barcodes_series, columns=["Barcode Signal"])
barcode_df["Barcode"] = assign_strs
barcode_df = barcode_df.reset_index(drop=False)

In [None]:
lDE13_lookup_df[0:2]

In [None]:
mergeddf = []
for _, row in barcode_df.iterrows():
    try:
        lDE_row = lDE13_lookup_df.loc[row["Barcode"]]
        entry = pd.concat([row, lDE_row])
        mergeddf.append(entry)
    except KeyError:
        # Barcode not present in the lookup table; leave this trench uncalled.
        pass
mergeddf = pd.DataFrame(mergeddf)

In [None]:
len(barcode_df)

In [None]:
len(mergeddf)

In [None]:
percent_called = len(mergeddf) / len(barcode_df)

In [None]:
percent_called

In [None]:
mergeddf.to_csv("./trench_sgrna_map.tsv", sep="\t")

In [None]:
mergeddf = pd.read_csv("./trench_sgrna_map.tsv", delimiter="\t")

In [None]:
headpath = "/home/de64/scratch/de64/2020-11-23_mVenus_KO_library/mVenus_headpath"

region_props = pd.read_pickle(headpath + "/analysis.pkl").loc[
    (slice(None), slice(None), list(range(2, 5)), slice(None))
]
region_props = region_props.reset_index()
region_props = region_props.set_index(
    ["trenchid", "timepoints", "Intensity Channel", "Objectid"], drop=True
)
region_props = region_props.sort_index()

In [None]:
kymo_first_tpt = kymograph_metadata.loc[(slice(None), 0), :]

In [None]:
def get_trenchid_map(kymodf1, kymodf2):
    trenchid_map = {}
    for fov in kymodf1["fov"].unique().tolist():
        df1_chunk = kymodf1[kymodf1["fov"] == fov]
        df2_chunk = kymodf2[kymodf2["fov"] == fov]

        df1_xy = df1_chunk[["y (local)", "x (local)"]].values
        df2_xy = df2_chunk[["y (local)", "x (local)"]].values

        ymat = np.subtract.outer(df1_xy[:, 0], df2_xy[:, 0])
        xmat = np.subtract.outer(df1_xy[:, 1], df2_xy[:, 1])
        distmat = (ymat**2 + xmat**2) ** (1 / 2)
        mapping = np.argmin(distmat, axis=1)

        df1_trenchids = df1_chunk.index.get_level_values("trenchid").tolist()
        df2_trenchids = df2_chunk.index.get_level_values("trenchid").tolist()

        trenchid_map.update(
            {
                trenchid: df2_trenchids[mapping[i]]
                for i, trenchid in enumerate(df1_trenchids)
            }
        )
    return trenchid_map

In [None]:
trenchid_map = get_trenchid_map(kymo_first_tpt, region_props)

In [None]:
trenchid_map

In [None]:
mchy_df = region_props.loc[(slice(None), slice(None), ["mCherry"]), :].droplevel(
    "Intensity Channel"
)
yfp_df = region_props.loc[(slice(None), slice(None), ["YFP"]), :].droplevel(
    "Intensity Channel"
)

In [None]:
plt.hist(mchy_df["mean_intensity"], range=(0, 3000), bins=100)
plt.show()

In [None]:
import copy

mean_mchy = mchy_df.groupby("trenchid").mean()
mean_yfp = yfp_df.groupby("trenchid").mean()

output_df = copy.deepcopy(mergeddf)
output_df["mean_mchy"] = mergeddf.apply(
    lambda x: mean_mchy.loc[trenchid_map[x["trenchid"]]]["mean_intensity"], axis=1
)
output_df["mean_yfp"] = mergeddf.apply(
    lambda x: mean_yfp.loc[trenchid_map[x["trenchid"]]]["mean_intensity"], axis=1
)

In [None]:
output_df.to_csv("./final_output_df.tsv", sep="\t")

In [None]:
import seaborn as sns

output_df = pd.read_csv("./final_output_df.tsv", delimiter="\t")
lDE13_nanopore = pd.read_csv("./lDE13_final_df.tsv", delimiter="\t", index_col=0)

In [None]:
len(output_df["barcodeid"].unique())

In [None]:
len(output_df["sgrnaid"].unique())

In [None]:
len(output_df)

In [None]:
# targetid = 41
targetid = 53
example_target = output_df[output_df["targetid"] == targetid]
nanopore_ex = (
    lDE13_nanopore[lDE13_nanopore["targetid"] == targetid]
    .groupby("sgrnaid")
    .apply(lambda x: x[0:1])
)

In [None]:
example_target[example_target["num_mismatch"] == 0].iloc[0]

In [None]:
fig = plt.figure(figsize=(12, 8))
g = sns.swarmplot(
    x=20 - example_target["num_mismatch"],
    y=example_target["mean_yfp"],
    palette="cividis_r",
    size=6,
)
plt.xlabel("Number Matched", fontsize=30)
plt.xticks(fontsize=30)
plt.ylim(0, 14000)
plt.ylabel("Mean mVenus Intensity", fontsize=30)
plt.yticks(fontsize=30)
plt.tight_layout()

sns.despine()

plt.savefig("2020-12-05_DAC_fig_5A.png", dpi=150)
plt.show()

In [None]:
g = sns.swarmplot(
    x=example_target["sgrnaid"].max() - example_target["sgrnaid"],
    y=example_target["mean_yfp"],
    palette="cividis_r",
)
plt.xlabel("sgRNAid", fontsize=30)
# plt.xticks(fontsize=30)
plt.ylabel("Mean YFP Intensity", fontsize=30)
plt.yticks(fontsize=30)
# plt.ylim(0,1)
for i in range(3, 27, 6):
    g.axvspan(i - 0.5, i + 2.5, color="C0", alpha=0.1, lw=0)
g.axvspan(27 - 0.5, 27 + 1.5, color="C0", alpha=0.1, lw=0)
plt.show()

In [None]:
g = sns.swarmplot(
    x=20 - example_target["num_mismatch"],
    y=example_target["mean_yfp"],
    palette="cividis_r",
)
plt.xlabel("Number Matched", fontsize=30)
plt.xticks(fontsize=30)
plt.ylabel("Mean YFP Intensity", fontsize=30)
plt.yticks(fontsize=30)
# plt.ylim(0,1)
plt.show()

In [None]:
example_target

In [None]:
mchy_bkd = 150.0

ratio_df = (yfp_df["mean_intensity"] - yfp_df["mean_intensity"].min()) / (
    mchy_df["mean_intensity"] - mchy_bkd
)
ratio_df = ratio_df.groupby("trenchid").median()

# Compute the percentile only after ratio_df has been defined.
max_percentile = np.percentile(ratio_df, 99)

# low_val = np.percentile(ratio_df,1.)
# high_val = np.percentile(ratio_df,95.)
# ratio_df = (ratio_df-low_val)/(high_val-low_val)
# ratio_df = pd.DataFrame(ratio_df)

In [None]:
plt.hist(ratio_df, bins=100, range=(0.0, 8.0))
plt.show()

In [None]:
plt.hist(ratio_df["mean_intensity"], bins=100, range=(0.0, 1.0))
plt.show()

In [None]:
import copy

output_df = copy.deepcopy(mergeddf)
output_df["mean_intensity"] = mergeddf.apply(
    lambda x: ratio_df.loc[trenchid_map[x["trenchid"]]]["mean_intensity"].values, axis=1
)

In [None]:
output_df.to_csv("./final_output_df.tsv", sep="\t")

In [None]:
import seaborn as sns
import pandas as pd
import numpy as np

from matplotlib import pyplot as plt

In [None]:
output_df = pd.read_csv("./final_output_df.tsv", delimiter="\t")

sgrnaids, ncount = np.unique(output_df["sgrnaid"], return_counts=True)
well_sampled_sgrnaids = sgrnaids[ncount > 10]

output_df = output_df[output_df["sgrnaid"].isin(well_sampled_sgrnaids)]

In [None]:
sgrnaid_groupby = output_df.groupby("sgrnaid")
yfp = sgrnaid_groupby.apply(lambda x: (x["mean_yfp"] / x["mean_mchy"]))
mean_yfp = yfp.groupby("sgrnaid").apply(lambda x: np.mean(x))
mean_yfp_argsort = np.argsort(mean_yfp).values[::-1].tolist()
mean_yfp_argsort = mean_yfp.index.values[mean_yfp_argsort]
yfp_sorted = yfp.loc[mean_yfp_argsort, :].to_frame()
yfp_sorted.columns = ["YFP Ratio"]
yfp_sorted = yfp_sorted.droplevel(1)
yfp_sorted = yfp_sorted.reset_index()
sgrna_to_index = dict(
    zip(yfp_sorted["sgrnaid"].unique(), range(len(yfp_sorted["sgrnaid"].unique())))
)
yfp_sorted["Order"] = yfp_sorted["sgrnaid"].apply(lambda x: sgrna_to_index[x])

yfp_sorted = yfp_sorted[yfp_sorted["Order"] > 20]

min_yfp = np.min(yfp_sorted.groupby("Order")["YFP Ratio"].mean())
max_yfp = np.max(yfp_sorted.groupby("Order")["YFP Ratio"].mean())
yfp_sorted["Relative YFP Ratio"] = (yfp_sorted["YFP Ratio"] - min_yfp) / (
    max_yfp - min_yfp
)

In [None]:
fig = plt.figure(figsize=(12, 8))
g = sns.lineplot(
    data=yfp_sorted, x="Order", y="Relative YFP Ratio", n_boot=100, linewidth=5.0
)
# g.set_yscale("log")
# g.set_yticks([0.01,0.1,1.])
plt.xlabel("Rank", fontsize=30)
plt.xticks(fontsize=30)
plt.ylabel("Mean Normalized Intensity", fontsize=30)
plt.yticks(fontsize=30)
plt.ylim(0, 1)
plt.tight_layout()
plt.savefig("2021-03-31_Qbio_fig_1.png", dpi=150)
plt.show()

In [None]:
example_target = output_df[output_df["targetid"] == 53]
# example_target = output_df[output_df["targetid"]==54]

In [None]:
example_target.iloc[0]

In [None]:
row["num_mismatch"]

In [None]:
output_data = []
for _, row in example_target.iterrows():
    for item in row["mean_intensity"]:
        output_data.append([row["num_mismatch"], item])
output_data = np.array(output_data)[::10]

In [None]:
output_data[::100, 0]

In [None]:
for mismatch in range(0, 11, 3):
    mismatch_mask = output_data[:, 0] == mismatch
    masked_output = output_data[mismatch_mask]
    # sns.distplot is deprecated; with hist=False it reduces to a KDE plot.
    g = sns.kdeplot(masked_output[:, 1])
#     g.set_xscale("log")
#     g.set_xticks([-2,-1,0.])
# plt.xlabel("Number Mismatch",fontsize=30)
# plt.xticks(fontsize=30)
# plt.xlim(-1.8,0.5)
# plt.ylabel("Mean Intensity Ratio",fontsize=30)
# plt.yticks(fontsize=30)
# plt.ylim(0,1)
plt.show()

# g = sns.swarmplot(x=20-example_target["num_mismatch"], y=example_target["mean_intensity"], palette="cividis");
# g.set_yscale("log")
# g.set_yticks([0.01,0.1,1.])

In [None]:
# plt.scatter(example_target["num_mismatch"],example_target["mean_intensity"])
sns.swarmplot(
    x="num_mismatch", y="mean_intensity", data=example_target, palette="cividis"
)
plt.xlabel("Number Mismatch", fontsize=30)
plt.xticks(fontsize=30)
plt.ylabel("Mean Intensity Ratio", fontsize=30)
plt.yticks(fontsize=30)
plt.ylim(0, 1)
plt.show()

In [None]:
g = sns.swarmplot(
    x=20 - example_target["num_mismatch"],
    y=example_target["mean_intensity"],
    palette="cividis",
)
g.set_yscale("log")
g.set_yticks([0.01, 0.1, 1.0])
plt.xlabel("Number Matched", fontsize=30)  # x is 20 - num_mismatch
plt.xticks(fontsize=30)
plt.ylabel("Mean Intensity Ratio", fontsize=30)
plt.yticks(fontsize=30)
plt.ylim(0.01, 1)  # a log axis cannot start at 0
plt.show()

In [None]:
output_df.groupby("targetid").nunique()

In [None]:
sgrna_list = [992, 1001, 1014, 1020]

In [None]:
for sgrna in sgrna_list:
    # sns.distplot is deprecated; with hist=False it reduces to a KDE plot.
    sns.kdeplot(
        example_target[example_target["sgrnaid"] == sgrna]["mean_intensity"],
    )
# The x axis shows the intensity ratio; the y axis is its estimated density.
plt.xlabel("Mean Intensity Ratio", fontsize=30)
plt.xticks(fontsize=30)
plt.ylabel("Density", fontsize=30)
plt.yticks(fontsize=30)
# plt.ylim(0,1)
plt.show()

In [None]:
nunique_df = example_target.groupby("sgrnaid").nunique()

In [None]:
sgrna_list = [992, 994, 1001, 1014, 1020]

In [None]:
example_target.groupby("sgrnaid").apply(lambda x: np.min(x["num_mismatch"]))[
    nunique_df["trenchid"] > 10
]

In [None]:
nunique_df["trenchid"] > 20

In [None]:
sorted(example_target["sgrnaid"].unique())

In [None]:
on_cds_mask = (
    (output_df["target_strand"] == 1)
    & (output_df["category"] == "Target")
    & (output_df["end"] <= 808475)
    & (output_df["end"] >= 807758)
)

promoter_mask = (output_df["category"] == "Target") & (output_df["end"] > 808475)

cds_antisense_mask = (
    (output_df["target_strand"] == -1)
    & (output_df["category"] == "Target")
    & (output_df["end"] <= 808475)
    & (output_df["end"] >= 807758)
)

dummy_mask = output_df["category"] == "Dummy"

In [None]:
on_cds_df = output_df[on_cds_mask]
promoter_df = output_df[promoter_mask]
antisense_df = output_df[cds_antisense_mask]
dummy_df = output_df[dummy_mask]

In [None]:
sorted(on_cds_df["targetid"].unique())

In [None]:
# Pass x and y by keyword; recent seaborn releases reject them positionally.
for targetid in on_cds_df["targetid"].unique():
    subdf = on_cds_df[on_cds_df["targetid"] == targetid]
    sns.lineplot(x=subdf["num_mismatch"], y=subdf["mean_yfp"], ci=None, color="C0")
for targetid in promoter_df["targetid"].unique():
    subdf = promoter_df[promoter_df["targetid"] == targetid]
    sns.lineplot(x=subdf["num_mismatch"], y=subdf["mean_yfp"], ci=None, color="C1")
for targetid in antisense_df["targetid"].unique():
    subdf = antisense_df[antisense_df["targetid"] == targetid]
    sns.lineplot(x=subdf["num_mismatch"], y=subdf["mean_yfp"], ci=None, color="grey")
# plt.ylim(0.,1.)
plt.show()

In [None]:
fig = plt.figure(figsize=(12, 8))
g = sns.swarmplot(
    x=20 - example_target["num_mismatch"],
    y=example_target["mean_yfp"],
    palette="cividis_r",
    size=6,
)
plt.xlabel("Number Matched", fontsize=30)
plt.xticks(fontsize=30)
plt.ylim(0, 14000)
plt.ylabel("Mean YFP Intensity", fontsize=30)
plt.yticks(fontsize=30)
plt.tight_layout()

sns.despine()

plt.savefig("2020-12-05_DAC_fig_5A.png", dpi=150)
plt.show()

In [None]:
20 - subdf["num_mismatch"]

In [None]:
fig = plt.figure(figsize=(12, 8))
for targetid in on_cds_df["targetid"].unique():
    subdf = on_cds_df[on_cds_df["targetid"] == targetid]
    #     plt.plot(20-subdf["num_mismatch"],subdf["mean_yfp"],color="C0")
    sns.lineplot(x=20 - subdf["num_mismatch"], y=subdf["mean_yfp"], ci=None, color="C0")
plt.xlabel("Number Matched", fontsize=30)
plt.xticks(fontsize=30)
plt.ylim(0, 14000)
plt.ylabel("Mean mVenus Intensity", fontsize=30)
plt.yticks(fontsize=30)
plt.tight_layout()

sns.despine()

plt.savefig("2020-12-05_DAC_fig_5B.png", dpi=150)
plt.show()

In [None]:
plt.hist(on_cds_df["mean_intensity"], bins=30)
plt.show()

In [None]:
plt.hist(dummy_df["mean_intensity"], bins=30)
plt.show()