# TrenchRipper Master Notebook

## Introduction

This notebook contains the entire `TrenchRipper` pipeline, divided into simple steps. This pipeline is ideal for Mother Machine image data where cells possess fluorescent segmentation markers. Segmentation on phase or brightfield data is being developed, but is still an experimental feature.

The steps in this pipeline are as follows:
1. Extracting your Mother Machine data (.nd2) into hdf5 format
2. Identifying and cropping individual trenches into kymographs
3. Segmenting cells with a fluorescent marker
4. Determining lineages and object properties

In each step, the user will dynamically specify parameters using a series of interactive diagnostics on their dataset. Following this, a parameter file will be written to disk and then used to deploy a parallel computation on the dataset, either locally or on a SLURM cluster.


This is intended as an end-to-end solution to analyzing Mother Machine data. As such, **it is not trivial to plug data directly into intermediate steps**, as it will lack the correct formatting and associated metadata. A notable exception to this is using another program to segment data. The library references binary segmentation masks using only metadata derived from their associated kymographs. As such, it is possible to generate segmentations on these kymographs elsewhere and place them into the segmentation data path to have `TrenchRipper` act on those segmentations instead. More on this in the segmentation section.

#### Imports

Run this section to import all relevant packages and libraries used in this notebook. You must run this every time you open a new Python kernel.

In [None]:
import paulssonlab.deaton.trenchripper.trenchripper as tr

import warnings

warnings.filterwarnings(action="once")

import matplotlib

matplotlib.rcParams["figure.figsize"] = [20, 10]

#### Specify Paths

Begin by defining the directory in which all processing will be done, as well as the initial `.nd2` file we will be processing. This cell should be run every time you open a new Python kernel.

The format should be: `headpath = "/path/to/folder"` and `nd2file = "/path/to/file.nd2"`

For example:
```
headpath = "/n/scratch2/de64/2019-05-31_validation_data"
nd2file = "/n/scratch2/de64/2019-05-31_validation_data/Main_Experiment.nd2"
```

Ideally, these files should be placed in a storage location with relatively fast I/O.

In [None]:
headpath = "/n/scratch2/de64/2020-03-02_plasmid_loss/"
nd2file = "/n/scratch2/de64/2020-03-02_plasmid_loss/Basilisk_SJC25x2_SJC28_Losses.nd2"

## Extract to hdf5 files

In this section, we will be extracting our image data. Currently this notebook only supports `.nd2` format; however, there are `.tiff` extractors in the TrenchRipper source files that are being added to `Master.ipynb` soon.

In the abstract, this step will take a single `.nd2` file and split it into a set of `.hdf5` files stored in `headpath/hdf5`. Splitting the file up in this way will facilitate quick processing in later steps. Each field of view will be split into one or more `.hdf5` files, depending on the number of timepoints per file requested (more on this later).

To keep track of which output files correspond to which FOVs, as well as to keep track of experiment metadata, the extractor also outputs a `metadata.hdf5` file in the `headpath` folder. The data from this step is accessible in that `metadata.hdf5` file under the `global` key. If you would like to look at this metadata, you may use the `tr.utils.pandas_hdf5_handler` to read from this file. Later steps will add additional metadata under different keys into the `metadata.hdf5` file.

#### Start Dask Workers

First, we start a `dask_controller` instance, which will handle all of our parallel processing. The default parameters here work well on O2. The critical arguments here are:

**walltime** : For a cluster, the length of time you will request each node for.

**local** : `True` if you want to perform computation locally. `False` if you want to perform it on a SLURM cluster.

**n_workers** : Number of nodes to request if on the cluster, or number of processes if computing locally.

**memory** : For a cluster, the amount of memory you will request each node for.

**working_directory** : For a cluster, the directory in which data will be spilled to disk. Usually set as a folder in the `headpath`.

In [None]:
dask_controller = tr.trcluster.dask_controller(
    walltime="04:00:00",
    local=False,
    n_workers=40,
    memory="2GB",
    working_directory=headpath + "/dask",
)
dask_controller.startdask()

After running the above line, you will have a running Dask client. Run the line below and click the link to supervise the computation being administered by the scheduler.

Don't be alarmed if the screen starts mostly blank; it may take time for your workers to spin up. If you get a 404 error on a cluster, it is likely that your ports are not being forwarded properly. If this occurs, please register the issue on GitHub.

In [None]:
dask_controller.daskclient

##### Perform Extraction

Now that we have our cluster scheduler spun up, it is time to convert files. This will be handled by the `hdf5_extractor` object. This extractor will pull up each FOV and split it such that each derived `.hdf5` file contains, at maximum, N timepoints of that FOV, where N is set by `tpts_per_file`. The image data stored in these files takes the form of `(N,Y,X)` arrays that are accessible using the desired channel name as a key.
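The splitting rule is plain integer arithmetic. As a rough illustration (the helper `chunk_timepoints` is hypothetical and not part of TrenchRipper, and it assumes each file holds a consecutive block of timepoints), a FOV with 275 timepoints extracted with `tpts_per_file=50` would yield five full files plus one holding the remaining 25:

```python
import math


def chunk_timepoints(num_timepoints, tpts_per_file):
    """Split timepoints 0..num_timepoints-1 into consecutive (start, stop) ranges.

    Hypothetical helper: each range holds at most `tpts_per_file` timepoints,
    and only the last range may be shorter.
    """
    n_files = math.ceil(num_timepoints / tpts_per_file)
    return [
        (i * tpts_per_file, min((i + 1) * tpts_per_file, num_timepoints))
        for i in range(n_files)
    ]


chunks = chunk_timepoints(275, 50)
print(len(chunks))  # 6
print(chunks[-1])   # (250, 275)
```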

The arguments for this extractor are:

 - **nd2file** : The filepath to the `.nd2` file you intend to extract.

 - **headpath** : The folder in which processing is occurring. Should be the same for each step in the pipeline.

 - **tpts_per_file** : The maximum number of timepoints stored in each output `.hdf5` file. Typical values are between 25 and 100.

 - **ignore_fovmetadata** : Used when `.nd2` data is corrupted and does not possess records for stage positions or timepoints. Only set to `True` if the extractor throws errors on metadata handling.

 - **nd2reader_override** : Overrides values in metadata recovered using the `nd2reader`. Currently set to `{"z_levels":[],"z_coordinates":[]}` by default to correct a known issue where z coordinates are mistakenly interpreted as a z stack. The example cell below additionally overrides `fields_of_view` and `num_frames`. See the [nd2reader](https://rbnvrw.github.io/nd2reader/) documentation for more info.

In [None]:
hdf5_extractor = tr.ndextract.hdf5_fov_extractor(
    nd2file,
    headpath,
    tpts_per_file=50,
    ignore_fovmetadata=False,
    nd2reader_override={
        "z_levels": [],
        "z_coordinates": [],
        "fields_of_view": list(range(71)),
        "num_frames": 275,
    },
)

##### Extraction Parameters

Here, you may set the time interval you want to extract. Useful for cropping data to the period exhibiting the dynamics of interest.

Optionally take notes to add to the `metadata.hdf5` file. Notes may also be taken directly in this notebook.

In [None]:
hdf5_extractor.inter_set_params()

##### Begin Extraction 

Running the following line will start the extraction process. This may be monitored by examining the `Dask Dashboard` under the link displayed earlier. Once the computation is complete, move to the next line.

This step may take a long time, though it is possible to speed it up using additional workers.

In [None]:
hdf5_extractor.extract(dask_controller)

##### Shutdown Dask

Once extraction is complete, it is likely that you will want to shut down your `dask_controller` if you are on a cluster. This is because the specifications of the current `dask_controller` will not be optimal for later steps. To do this, run the following line and wait for it to complete. If it hangs, interrupt your kernel and re-run it. If this also fails to shut down your workers, you will have to manually shut them down using `scancel` in a terminal.

In [None]:
dask_controller.shutdown()

## Kymographs

Now that you have extracted your data into a series of `.hdf5` files, we will perform identification and cropping of the individual trenches/growth channels present in the images. This algorithm assumes that your growth trenches are vertically aligned and that they alternate in their orientation from top to bottom. See the example image for the correct geometry:

![example_image](./resources/example_image.jpg)

The output of this step will be a set of `.hdf5` files stored in `headpath/kymograph`. The image data stored in these files takes the form of `(K,T,Y,X)` arrays where K is the trench index, T is time, and Y,X are the crop dimensions. These arrays are accessible using keys of the form `"[Trench Row Number]/[Image Channel]"`. For example, looking up phase channel data of trenches in the topmost row of an image will require the key `"0/Phase"`.
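To make the `(K,T,Y,X)` layout concrete, the sketch below uses a random NumPy array in place of one read from a kymograph file (reading the file itself, e.g. with `h5py`, is omitted). It selects a single trench and tiles its timepoints side by side into a 2D kymograph image. The array sizes are arbitrary, and this tiling is just one common way to view a kymograph, not necessarily how TrenchRipper displays it.

```python
import numpy as np

# Synthetic stand-in for the array stored under a key such as "0/Phase":
# K=4 trenches, T=5 timepoints, crops 30 px tall (Y) and 8 px wide (X).
K, T, Y, X = 4, 5, 30, 8
row_array = np.random.default_rng(0).random((K, T, Y, X))

# Select one trench: a (T, Y, X) stack of that trench over time
trench_idx = 2
trench_stack = row_array[trench_idx]

# Tile the timepoints left to right into a single 2D image of shape (Y, T*X)
kymograph_2d = np.concatenate(list(trench_stack), axis=1)
print(trench_stack.shape, kymograph_2d.shape)  # (5, 30, 8) (30, 40)
```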

### Test Parameters



##### Initialize the interactive kymograph class

As a first step, initialize the `tr.interactive.kymograph_interactive` class, which will help us choose the parameters we will use to generate kymographs.

In [None]:
interactive_kymograph = tr.kymograph_interactive(headpath)

##### Examine Images

Here you can manually inspect images before beginning parameter tuning.

In [None]:
interactive_kymograph.view_image_interactive()

You will now want to select a few test FOVs to try out parameters on, the channel you want to detect trenches on, and the time interval on which you will perform your processing.

The arguments for this step are:

- **seg_channel (string)** : The channel name that you would like to segment on.

- **invert (bool)** : Whether or not you want to invert the image before detecting trenches. By default, it is assumed that the trenches have a high pixel intensity relative to the background. This should be the case for Phase Contrast and Fluorescence Imaging, but may not be the case for Brightfield Imaging, in which case you will want to invert the image.

- **fov_list (list)** : List of integers corresponding to the FOVs that you wish to make test kymographs of.

- **t_subsample_step (int)** : Step size to be used for subsampling input files in time. We recommend choosing a step that leaves between 5 and 10 timepoints for quick processing.

Hit the "Run Interact" button to lock in your parameters. The button will briefly become transparent and become solid again when processing is complete. After that has occurred, move on to the next step.

In [None]:
interactive_kymograph.import_hdf5_interactive()

##### Tune "trench-row" detection hyperparameters

The kymograph code begins by detecting the positions of trench rows in the image as follows:

1. Reduce each 2D image to a 1D signal along the y-axis by computing the qth percentile of the data along the x-axis.
2. Smooth this signal using a median kernel.
3. Normalize the signal by linearly scaling it so that its minimum and maximum map to 0 and 1, respectively.
4. Use a set threshold to determine the trench row positions.
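The four steps above can be sketched with NumPy as follows. This is an illustrative reimplementation, not TrenchRipper's own code: the helper names are hypothetical, the median filter uses edge padding, and the synthetic image stands in for a real FOV. The keyword names mirror the widget arguments listed below.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def median_smooth(signal, kernel_size):
    """1D median filter (odd kernel_size) with edge-value padding."""
    pad = kernel_size // 2
    padded = np.pad(signal, pad, mode="edge")
    return np.median(sliding_window_view(padded, kernel_size), axis=1)


def detect_trench_rows(image, y_percentile, smoothing_kernel_y_dim_0, y_percentile_threshold):
    """Boolean mask over y that is True where a trench row is detected in a (Y, X) image."""
    # 1. Collapse to a 1D profile over y: qth percentile across x for each y
    profile = np.percentile(image, y_percentile, axis=1)
    # 2. Smooth with a median kernel
    smoothed = median_smooth(profile, smoothing_kernel_y_dim_0)
    # 3. Min-max normalize: minimum -> 0, maximum -> 1
    span = smoothed.max() - smoothed.min()
    if span == 0:
        return np.zeros(smoothed.shape, dtype=bool)  # flat signal: no rows
    normalized = (smoothed - smoothed.min()) / span
    # 4. Threshold
    return normalized > y_percentile_threshold


# Synthetic FOV with two bright trench rows
image = np.zeros((100, 50))
image[20:40] = 1.0
image[60:80] = 1.0
mask = detect_trench_rows(image, 90, 5, 0.5)
print(np.flatnonzero(np.diff(mask.astype(int))) + 1)  # [20 40 60 80]
```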

The arguments for this step are:

 - **y_percentile (int)** : Percentile to use for step 1.

 - **smoothing_kernel_y_dim_0 (int)** : Median kernel size to use for step 2.

 - **y_percentile_threshold (float)** : Threshold to use in step 4.

Running the following widget will display the smoothed 1-D signal for each of your timepoints. In addition, the threshold value for each FOV will be displayed as a red line.

In [None]:
interactive_kymograph.preview_y_precentiles_interactive()

##### Tune "trench-row" cropping hyperparameters

Next, we will use the detected rows to perform cropping of the input image in the y-dimension:

1. Determine edges of trench rows based on threshold mask.
2. Filter out rows that are too small.
3. Perform cropping using the "end" of the row as reference (the end referring to the part of the trench farthest from the feeding channel).
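Steps 1 and 2 amount to finding contiguous runs in the boolean row mask and discarding short ones. A minimal NumPy sketch, using a hypothetical `row_edges` helper that returns half-open `(start, stop)` pairs; the orientation-dependent crop in step 3 is not reproduced here.

```python
import numpy as np


def row_edges(mask, y_min_edge_dist):
    """Half-open (start, stop) bounds of each run of True in a 1D mask,
    keeping only runs at least y_min_edge_dist pixels long."""
    padded = np.concatenate(([0], np.asarray(mask, dtype=np.int8), [0]))
    steps = np.diff(padded)
    starts = np.flatnonzero(steps == 1)   # False -> True transitions
    stops = np.flatnonzero(steps == -1)   # True -> False transitions
    return [(int(s), int(e)) for s, e in zip(starts, stops) if e - s >= y_min_edge_dist]


mask = np.zeros(100, dtype=bool)
mask[10:12] = True   # small spurious object
mask[20:50] = True   # trench row
mask[70:95] = True   # trench row
print(row_edges(mask, 5))  # [(20, 50), (70, 95)]
```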

Step 3 uses a simple algorithm to determine the orientation of each trench row:

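```python
def assign_row_orientations(n_detected, number_of_rows, orientation, orientation_when_fewer):
    """Orientation of each detected row, from the topmost row down.

    Sketch of the logic described here, not the TrenchRipper source.
    Codes: 0 = downward-oriented trench opening, 1 = upward-oriented opening.
    """
    if n_detected == number_of_rows:
        first = orientation  # 'Orientation'
    elif n_detected < number_of_rows:
        first = orientation_when_fewer  # 'Orientation when < expected rows'
    else:
        # The description does not cover this case; raising is an assumption
        raise ValueError("more rows detected than 'Number of Rows'")
    # Rows alternate in orientation from top to bottom
    return [first if i % 2 == 0 else 1 - first for i in range(n_detected)]
```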

The arguments for this step are:

 - **y_min_edge_dist (int)** : Minimum row length necessary for detection (filters out small detected objects).

 - **padding_y (int)** : Padding to add to the end of trench row when cropping in the y-dimension.

 - **trench_len_y (int)** : Length from the end of each trench row to the feeding channel side of the crop.

 - **Number of Rows (int)** : The number of rows to expect in your image. For instance, two in the example image.

 - **Orientation (int)** : The orientation of the top-most row, where 0 corresponds to a trench with a downward-oriented trench opening and 1 corresponds to a trench with an upward-oriented trench opening.

 - **Orientation when < expected rows (int)** : The orientation of the top-most row when the number of detected rows is less than expected. Useful if your trenches drift out of your image in some FOVs.

 - **images_per_row (int)** : How many images to output per row for this widget.

Running the following widget will display y-cropped images for each fov and timepoint.

In [None]:
interactive_kymograph.preview_y_crop_interactive()

##### Tune trench detection hyperparameters

Next, we will detect the positions of trenches in the y-cropped images as follows:

1. Reduce each 2D image to a 1D signal along the x-axis by computing the qth percentile of the data along the y-axis.
2. Determine the signal background by smoothing this signal using a large median kernel.
3. Subtract the background signal.
4. Smooth the resultant signal using a median kernel.
5. Use an [Otsu threshold](https://imagej.net/Auto_Threshold#Otsu) to determine the trench midpoint positions.
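Steps 1 to 5 can be sketched in NumPy as follows. This is an illustrative reimplementation under stated assumptions, not TrenchRipper's code: the helper names are hypothetical, Otsu's method is written out directly (essentially the between-class-variance formulation used by scikit-image's `threshold_otsu`) to keep the block dependency-free, and taking each midpoint as the centre of a run of above-threshold pixels is an assumption about how positions are read off the thresholded signal.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def median_smooth(signal, kernel_size):
    """1D median filter (odd kernel_size) with edge-value padding."""
    pad = kernel_size // 2
    padded = np.pad(signal, pad, mode="edge")
    return np.median(sliding_window_view(padded, kernel_size), axis=1)


def otsu_threshold(values, nbins=256):
    """Otsu's threshold: the histogram cut that maximizes between-class variance."""
    hist, edges = np.histogram(values, bins=nbins)
    centers = (edges[:-1] + edges[1:]) / 2
    w0 = np.cumsum(hist)                 # pixels at or below each bin
    w1 = np.cumsum(hist[::-1])[::-1]     # pixels at or above each bin
    m0 = np.cumsum(hist * centers) / np.maximum(w0, 1)
    m1 = (np.cumsum((hist * centers)[::-1]) / np.maximum(w1[::-1], 1))[::-1]
    between_var = w0[:-1] * w1[1:] * (m0[:-1] - m1[1:]) ** 2
    return centers[np.argmax(between_var)]


def trench_midpoints(image, x_percentile, background_kernel_x, smoothing_kernel_x, otsu_scaling):
    """x positions of trench midpoints in a y-cropped (Y, X) image."""
    profile = np.percentile(image, x_percentile, axis=0)              # 1. 1D profile over x
    background = median_smooth(profile, background_kernel_x)          # 2. background estimate
    signal = median_smooth(profile - background, smoothing_kernel_x)  # 3. subtract, 4. smooth
    above = signal > otsu_threshold(signal) * otsu_scaling            # 5. scaled Otsu threshold
    # Assumed: each midpoint is the centre of a run of above-threshold pixels
    steps = np.diff(np.concatenate(([0], above.astype(np.int8), [0])))
    starts, stops = np.flatnonzero(steps == 1), np.flatnonzero(steps == -1)
    return (starts + stops - 1) / 2


image = np.zeros((40, 200))
for centre in (20, 60, 100, 140, 180):
    image[:, centre - 3 : centre + 4] = 1.0  # 7-px-wide bright trench
print(trench_midpoints(image, 90, 51, 3, 1.0))  # [ 20.  60. 100. 140. 180.]
```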

After this, x-dimension drift correction of our detected midpoints will be performed as follows:

6. Begin at t=1.
7. For each $m \in \{midpoints(t)\}$, assign to $m$ the midpoint $n \in \{midpoints(t-1)\}$ that is closest to $m$. Midpoints at time $t-1$ that are not the closest midpoint to any $m$ will not be mapped.
8. Compute the translation of each mapped midpoint from time $t-1$ to $t$.
9. Take the average of these translations as the x-dimension drift from time $t-1$ to $t$.
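The drift estimate in steps 6 to 9 can be sketched as below. The nearest-neighbour matching follows the description above; summing the step-wise drifts into a drift relative to the first timepoint (`cumulative_drift`) is an assumption about how the estimates are applied, and both helper names are hypothetical.

```python
import numpy as np


def estimate_x_drift(midpoints_prev, midpoints_curr):
    """Mean x translation from t-1 to t (steps 7 to 9)."""
    prev = np.asarray(midpoints_prev, dtype=float)
    curr = np.asarray(midpoints_curr, dtype=float)
    # Step 7: map each midpoint m at time t to the closest midpoint n at time t-1;
    # midpoints at t-1 that are nobody's closest match are simply never used
    nearest = prev[np.argmin(np.abs(curr[:, None] - prev[None, :]), axis=1)]
    # Steps 8-9: translation of each mapped pair, averaged
    return float(np.mean(curr - nearest))


def cumulative_drift(midpoints_by_t):
    """Drift of each timepoint relative to t=0, accumulating step-wise estimates."""
    steps = [
        estimate_x_drift(prev, curr)
        for prev, curr in zip(midpoints_by_t[:-1], midpoints_by_t[1:])
    ]
    return np.concatenate(([0.0], np.cumsum(steps)))


# Trenches drift 3 px then 2 px; one trench is missed at the last timepoint
midpoints_by_t = [[20, 60, 100], [23, 63, 103], [25, 65]]
print(cumulative_drift(midpoints_by_t))  # [0. 3. 5.]
```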

The arguments for this step are:

**t (int)** : Timepoint to examine the percentiles and threshold in.

**x_percentile (int)** : Percentile to use for step 1.

**background_kernel_x (int)** : Median kernel size to use for step 2.

**smoothing_kernel_x (int)** : Median kernel size to use for step 4.

**otsu_scaling (float)** : Scaling factor to apply to the threshold determined by Otsu's method.

Running the following widget will display the smoothed 1-D signal for each of your timepoints, with the threshold value for each FOV shown as a red line. It will also display the detected midpoints for each of your timepoints. If the detected midpoints are too sparse or discontinuous, your drift correction will not be accurate.

In [None]:
interactive_kymograph.preview_x_percentiles_interactive()

##### Tune trench cropping hyperparameters

Trench cropping simply uses the drift-corrected midpoints as a reference and crops out a fixed width around them to produce an output kymograph. **Note that the current implementation does not allow trench crops to overlap**. If your trench crops do overlap, the error will not be caught here, but will cause issues later in the pipeline. As such, try to crop your trenches as closely as possible. This issue will be fixed in a later update.
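Because overlapping crops are not caught by the widget, a quick pre-check on candidate midpoints can be useful. The hypothetical `crop_trenches` helper below (not part of TrenchRipper) crops fixed-width windows centred on each midpoint of a `(Y, X)` image and raises an error when neighbouring windows would overlap; skipping windows that run off the image edge is an assumption of this sketch.

```python
import numpy as np


def crop_trenches(image, midpoints, trench_width_x):
    """Crop a fixed-width window around each midpoint of a (Y, X) image.

    Hypothetical helper. Raises ValueError when neighbouring crops would overlap,
    which the interactive widget does not check.
    """
    mids = np.sort(np.round(np.asarray(midpoints)).astype(int))
    # Windows of equal width overlap when neighbouring centres are closer than the width
    if np.any(np.diff(mids) < trench_width_x):
        raise ValueError("trench crops overlap; try a smaller trench_width_x")
    crops = []
    for m in mids:
        left = m - trench_width_x // 2
        right = left + trench_width_x
        if left < 0 or right > image.shape[1]:
            continue  # assumption: skip windows that fall off the image edge
        crops.append(image[:, left:right])
    return np.stack(crops)  # (n_trenches, Y, trench_width_x)


image = np.arange(20 * 100).reshape(20, 100)
print(crop_trenches(image, [20, 50, 80], 10).shape)  # (3, 20, 10)
```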

The arguments for this step are:

**trench_width_x (int)** : Trench width to use for cropping.

**trench_present_thr (float)** : Trenches that appear in less than this percent of FOVs will be eliminated from the dataset. If not removed, missing positions will be inferred from the image drift.

Running the following widget will display a random kymograph for each row in each FOV and will also produce midpoint plots showing retained midpoints.

In [None]:
interactive_kymograph.preview_kymographs_interactive()

##### Export and save hyperparameters

Run the following line to register and display the parameters you have selected for kymograph creation.

In [None]:
interactive_kymograph.process_results()

If you are satisfied with the above parameters, run the following line to write these parameters to disk at `headpath/kymograph.par`. This file will be used to perform kymograph creation in the next section.

In [None]:
interactive_kymograph.write_param_file()

### Generate Kymograph

##### Start Dask Workers

Again, we start a `dask_controller` instance, which will handle all of our parallel processing. The default parameters here work well on O2 for kymograph creation. The critical arguments here are:

**walltime** : For a cluster, the length of time you will request each node for.

**local** : `True` if you want to perform computation locally. `False` if you want to perform it on a SLURM cluster.

**n_workers** : Number of nodes to request if on the cluster, or number of processes if computing locally.

**memory** : For a cluster, the amount of memory you will request each node for.

**working_directory** : For a cluster, the directory in which data will be spilled to disk. Usually set as a folder in the `headpath`.

In [None]:
dask_controller = tr.trcluster.dask_controller(
    walltime="04:00:00",
    death_timeout=30.0,
    local=False,
    n_workers=100,
    memory="8GB",
    working_directory=headpath + "/dask",
)
dask_controller.startdask()

After running the above line, you will have a running Dask client. Run the line below and click the link to supervise the computation being administered by the scheduler.

Don't be alarmed if the screen starts mostly blank; it may take time for your workers to spin up. If you get a 404 error on a cluster, it is likely that your ports are not being forwarded properly. If this occurs, please register the issue on GitHub.

In [None]:
dask_controller.daskclient

##### Perform Kymograph Cropping

Now that we have our cluster scheduler spun up, we will extract kymographs using the parameters stored in `headpath/kymograph.par`. This will be handled by the `kymograph_cluster` object, which will detect trenches in all of the files present in `headpath/hdf5` that you created in the first step. It will then crop these trenches and place the crops in a series of `.hdf5` files in `headpath/kymograph`. These files will store image data in the form of `(K,T,Y,X)` arrays where K is the trench index, T is time, and Y,X are the image dimensions of the crop.

The arguments for this step are:

 - **headpath** : The folder in which processing is occurring. Should be the same for each step in the pipeline.

 - **trenches_per_file** : The maximum number of trenches stored in each output `.hdf5` file. Typical values are between 25 and 100.

 - **paramfile** : Set to `True` if you want to use parameters from `headpath/kymograph.par`. Otherwise, you will have to specify parameters as direct arguments to `kymograph_cluster`.

In [None]:
kymoclust = tr.kymograph.kymograph_cluster(
    headpath=headpath, trenches_per_file=25, paramfile=True
)

##### Begin Kymograph Cropping 

Running the following line will start the cropping process. This may be monitored by examining the `Dask Dashboard` under the link displayed earlier. Once the computation is complete, move to the next line.

**Do not move on until all tasks are displayed as 'in memory' in Dask.**

In [None]:
kymoclust.generate_kymographs(dask_controller)

In [None]:
kymoclust.cleanup_kymographs()

##### Post-process Images

After the above step, kymographs will have been created for each `.hdf5` input file. They will now need to be reorganized into a new set of files such that each file holds at most `trenches_per_file` trenches.

**Do not move on until all tasks are displayed as 'in memory' in Dask.**

In [None]:
kymoclust.post_process(dask_controller)

##### Check kymograph statistics

Run the next line to display some statistics from kymograph creation. The outputs are:

 - **fovs processed** : The number of FOVs successfully processed out of the total number of FOVs
 - **rows processed** : The number of rows of trenches processed out of the total number of rows
 - **trenches processed** : The number of trenches successfully processed
 - **row/fov** : The average number of rows successfully processed per FOV
 - **trenches/fov** : The average number of trenches successfully processed per FOV
 - **failed fovs** : A list of failed FOVs. Spot check these FOVs in the viewer to determine potential problems

In [None]:
kymoclust.kymo_report()

##### Shutdown Dask

Once cropping is complete, it is likely that you will want to shut down your `dask_controller` if you are on a cluster. This is because the specifications of the current `dask_controller` will not be optimal for later steps. To do this, run the following line and wait for it to complete. If it hangs, interrupt your kernel and re-run it. If this also fails to shut down your workers, you will have to manually shut them down using `scancel` in a terminal.

In [None]:
dask_controller.shutdown()

## Fluorescence Segmentation

Now that you have cropped your data into kymographs, we will perform segmentation/cell detection on your kymographs. Currently, this pipeline only supports segmentation of fluorescence images; however, segmentation of transmitted light imaging techniques is in development.

The output of this step will be a set of `segmentation_[File #].hdf5` files stored in `headpath/fluorsegmentation`. The image data stored in these files takes the exact same form as the kymograph data, `(K,T,Y,X)` arrays where K is the trench index, T is time, and Y,X are the crop dimensions. These arrays are accessible using keys of the form `"[Trench Row Number]"`.

Since no metadata is generated by this step, it is possible to use another segmentation algorithm on the kymograph data. The output of segmentation must be split into `segmentation_[File #].hdf5` files, where `[File #]` agrees with the corresponding `kymograph_[File #].hdf5` file. Additionally, the `(K,T,Y,X)` arrays must be of the same shape as the kymograph arrays and accessible at the corresponding `"[Trench Row Number]"` key. These files must be placed into their own folder at `headpath/foldername`. This folder may then be used in later steps.
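Before pointing later steps at an externally generated segmentation, a quick consistency check can save a failed run. The sketch below is hypothetical (neither helper is part of TrenchRipper) and works on in-memory NumPy arrays; loading them from the `.hdf5` files, for example with `h5py`, is left out. Since the kymograph arrays for one channel live under `"[Trench Row Number]/[Image Channel]"` keys, they are assumed here to have been re-keyed by row number for comparison with the segmentation keys.

```python
import re

import numpy as np


def segmentation_name(kymograph_filename):
    """Map 'kymograph_[File #].hdf5' to the matching 'segmentation_[File #].hdf5'."""
    match = re.fullmatch(r"kymograph_(\d+)\.hdf5", kymograph_filename)
    if match is None:
        raise ValueError(f"unexpected kymograph filename: {kymograph_filename!r}")
    return f"segmentation_{match.group(1)}.hdf5"


def check_segmentation(kymo_arrays, seg_arrays):
    """List problems between kymograph and segmentation arrays of one file pair.

    Both arguments map a trench row number (as a string key, e.g. "0") to a
    (K, T, Y, X) array; an empty list means the keys and shapes line up.
    """
    problems = []
    for row_key, kymo in kymo_arrays.items():
        if row_key not in seg_arrays:
            problems.append(f"missing row key {row_key!r}")
        elif seg_arrays[row_key].shape != kymo.shape:
            problems.append(
                f"row {row_key!r}: shape {seg_arrays[row_key].shape} != {kymo.shape}"
            )
    return problems


kymo = {"0": np.zeros((3, 5, 40, 8)), "1": np.zeros((2, 5, 40, 8))}
seg = {"0": np.zeros((3, 5, 40, 8), dtype=bool), "1": np.zeros((2, 5, 40, 7), dtype=bool)}
print(segmentation_name("kymograph_12.hdf5"))  # segmentation_12.hdf5
print(check_segmentation(kymo, seg))  # ["row '1': shape (2, 5, 40, 7) != (2, 5, 40, 8)"]
```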

### Test Parameters

##### Initialize the interactive segmentation class

As a first step, initialize the `tr.fluo_segmentation_interactive` class that will be handling all steps of generating a segmentation. 

In [None]:
interactive_segmentation = tr.fluo_segmentation_interactive(headpath)

##### Choose channel to segment on

In [None]:
interactive_segmentation.choose_seg_channel_inter()

#### Import data

Fill in 

You will need to tune the following `args` and `kwargs` (in order):

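**fov_idx (int)** : The FOV to import test kymographs from.

**n_trenches (int)** : The number of trenches from that FOV to import.

**t_range (tuple)** : The (start, end) range of timepoints to import.

**t_subsample_step (int)** : Step size to be used for subsampling the imported timepoints.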

In [None]:
interactive_segmentation.import_array_inter()

#### Scale data

Fill in 

You will need to tune the following `args` and `kwargs` (in order):

**scale (bool)** : Whether to scale the kymograph in time.

**scaling_percentile (int)** : Whole image intensity percentile to use to determine scaling constant. 
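The text does not spell out how scaling in time works. One plausible reading, sketched below as an assumption rather than TrenchRipper's actual method, is that each timepoint is divided by its own `scaling_percentile`-th intensity percentile (relative to the largest such value over time), compensating for changes in overall brightness over the course of the experiment. `scale_in_time` is a hypothetical helper.

```python
import numpy as np


def scale_in_time(kymo, scaling_percentile):
    """Rescale each timepoint of a (T, Y, X) kymograph by its own intensity percentile.

    Hypothetical sketch: frame t is divided by its scaling_percentile-th percentile,
    expressed relative to the largest such value over time, so frames whose
    overall brightness drops are brought back to a common scale.
    """
    kymo = np.asarray(kymo, dtype=float)
    perc_t = np.percentile(kymo.reshape(kymo.shape[0], -1), scaling_percentile, axis=1)
    return kymo / (perc_t / perc_t.max())[:, None, None]


rng = np.random.default_rng(0)
base = 1.0 + rng.random((1, 10, 4))
dimming = base * np.array([1.0, 0.5, 0.25])[:, None, None]  # (3, 10, 4), fading over time
scaled = scale_in_time(dimming, 90)
print(np.allclose(scaled[0], scaled[2]))  # True
```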

#### Apply Gaussian Filter

Fill in 

You will need to tune the following `args` and `kwargs` (in order):

**smooth_sigma (float)** : Standard deviation of the Gaussian kernel.

In [None]:
interactive_segmentation.plot_processed_inter()

#### Determine Cell Mask Envelope

Fill in.

You will need to tune the following `args` and `kwargs` (in order):

**cell_mask_method (str)** : Thresholding method, can be a local or global Otsu threshold.

**cell_otsu_scaling (float)** : Scaling factor applied to determined threshold.

**local_otsu_r (int)** : Radius of the thresholding kernel used in local Otsu thresholding.

In [None]:
interactive_segmentation.plot_cell_mask_inter()

#### Display Edge Mask at Threshold Value

Fill in.

You will need to tune the following `args` and `kwargs` (in order):

**edge_threshold_scaling (float)** : Scaling factor applied to determined threshold.

In [None]:
interactive_segmentation.plot_threshold_result_inter()

#### Threshold Sampling and Convexity Calculation

Fill in.

You will need to tune the following `args` and `kwargs` (in order):

**edge_threshold_scaling (float)** : Scaling factor applied to determined threshold.

**threshold_step_perc (float)** : Threshold step size to be used for trying multiple thresholds.

**threshold_perc_num_steps (int)** : Number of steps to use when generating multiple thresholds.

In [None]:
interactive_segmentation.plot_scores_inter()

#### Convexity Thresholding

Fill in.

You will need to tune the following `args` and `kwargs` (in order):

**convex_threshold (float)** : Threshold to be used for convexity thresholding.

In [None]:
interactive_segmentation.plot_final_mask_inter()

In [None]:
interactive_segmentation.process_results()

In [None]:
interactive_segmentation.write_param_file()

### Generate Segmentation

#### Start Dask Workers

In [None]:
dask_controller = tr.trcluster.dask_controller(
    walltime="02:00:00",
    local=False,
    n_workers=200,
    memory="2GB",
    cores=1,
    working_directory=headpath + "/dask",
)
dask_controller.startdask()

In [None]:
dask_controller.displaydashboard()

In [None]:
segment = tr.fluo_segmentation_cluster(headpath, paramfile=True)

In [None]:
segment.dask_segment(dask_controller)

In [None]:
dask_controller.daskclient.restart()

#### Stop Dask Workers

In [None]:
dask_controller.shutdown()

## Lineage Tracing

### Test Parameters

In [None]:
score_function = tr.tracking.scorefn(
    headpath,
    "fluorsegmentation",
    u_size=0.16,
    sig_size=0.07,
    u_pos=0.16,
    sig_pos=0.1,
    w_pos=0.3,
    w_size=1.0,
    w_merge=0.8,
)

In [None]:
score_function.interactive_scorefn()

In [None]:
Tracking_Solver = tr.tracking.tracking_solver(
    headpath,
    "fluorsegmentation",
    ScoreFn=score_function,
    edge_limit=2,
)
data, orientation = score_function.output.result

In [None]:
Tracking_Solver.interactive_tracking(data, orientation)

In [None]:
Tracking_Solver.save_params()

### Generate Lineage Traces

In [None]:
dask_controller = tr.dask_controller(
    walltime="02:00:00",
    local=False,
    n_workers=100,
    memory="8GB",
    working_directory=headpath + "/dask",
)
dask_controller.startdask()

In [None]:
dask_controller.displaydashboard()

In [None]:
Tracking_Solver = tr.tracking.tracking_solver(
    headpath, "fluorsegmentation", paramfile=True
)

In [None]:
Tracking_Solver.compute_all_lineages(dask_controller)

In [None]:
dask_controller.daskclient.restart()

In [None]:
dask_controller.shutdown()

In [None]:
import paulssonlab.deaton.trenchripper.trenchripper as tr
import dask.dataframe as dd
import operator
import scipy as sp
import scipy.signal  # ensures sp.signal.find_peaks is available below

In [None]:
from matplotlib import pyplot as plt
import numpy as np

In [None]:
import pandas as pd

In [None]:
dask_controller = tr.trcluster.dask_controller(
    walltime="04:00:00",
    local=False,
    n_workers=10,
    cores=1,
    memory="8GB",
    working_directory=headpath + "/dask",
)
dask_controller.startdask()

In [None]:
dask_controller.displaydashboard()

In [None]:
meta = tr.pandas_hdf5_handler(headpath + "/metadata.hdf5")

In [None]:
microns = meta.read_df("global", read_metadata=True).metadata["pixel_microns"]

In [None]:
kymo = meta.read_df("kymograph")
kymo = kymo.loc[(slice(None), slice(0, 0)), :]

In [None]:
1392 * microns

In [None]:
# Parenthesize each comparison: & binds more tightly than ==
fov_kymo = kymo[(kymo["fov"] == 0) & (kymo["row"] == 1)]

In [None]:
fov_kymo[fov_kymo["x (local)"] > 450]

In [None]:
plt.hist(
    kymo[(kymo["fov"] == 0) & (kymo["row"] == 1)]["x (local)"] / microns,
    range=(1300, 1500),
    bins=100,
)

In [None]:
fov = kymo[(kymo["fov"] == 0)]
output = fov[
    (fov["x (local)"] < (1395 * microns))
    & (fov["x (local)"] > (1380 * microns))
    & (fov["y (local)"] < (1100 * microns))
    & (fov["y (local)"] > (1000 * microns))
]

In [None]:
output

In [None]:
df = dd.read_parquet(headpath + "/lineage/output", engine="fastparquet")

In [None]:
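def compute_del_area(series):
    # Largest fold change in area between consecutive timepoints of one group;
    # -1 is appended so that single-timepoint groups still return a value
    ttl_t = len(series["area"])
    del_areas = []
    for t in range(ttl_t - 1):
        del_area = series["area"].values[t + 1] / series["area"].values[t]
        del_areas.append(del_area)
    del_areas.append(-1)
    max_discont = np.max(del_areas)
    return max_discont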

In [None]:
df = dd.read_parquet(headpath + "/lineage/output", engine="fastparquet")
trenchid_group = df.groupby(["trenchid", "timepoints"])
max_cent_y = trenchid_group["Centroid Y"].idxmax()
max_cent_y_list = max_cent_y.compute().tolist()
mothers_df = df.loc[max_cent_y_list].persist()
area_groupby = mothers_df.groupby("trenchid")["area"]
peaks = area_groupby.apply(sp.signal.find_peaks, distance=3, prominence=3).compute()

first_peaks = []
ttl_peaks = []
for peak in peaks:
    if len(peak[0]) > 0:
        first_peaks.append(peak[0][0])
        ttl_peaks.append(sum(peak[0] > -1))
    else:
        first_peaks.append(-1)
        ttl_peaks.append(0)
first_peaks = pd.DataFrame({"lag time": first_peaks, "ttl peaks": ttl_peaks})
first_peaks.index = peaks.index

lag_df = mothers_df.join(first_peaks, on="trenchid").persist()
first_gen_df = (
    lag_df.groupby("trenchid")
    .apply(lambda g: g[g["timepoints"] <= g["lag time"].max()])
    .compute()
)
first_gen_df = first_gen_df.droplevel("trenchid")

max_discon = first_gen_df.groupby("trenchid").apply(compute_del_area)
max_discon = pd.DataFrame({"max discon": max_discon})
first_gen_df = first_gen_df.join(max_discon, on="trenchid")
first_gen_df = first_gen_df[first_gen_df["max discon"] < 2.0]

trenchid_groupby = first_gen_df.groupby("trenchid")
filtered_trenchids = trenchid_groupby["YFP mean_intensity"].mean() < 2500.0
filtered_trenchids = filtered_trenchids[filtered_trenchids].index.tolist()
lag_df_loss = first_gen_df[first_gen_df["trenchid"].isin(filtered_trenchids)]
loss_lag_times = lag_df_loss.groupby("trenchid")["lag time"].max()
lag_times = first_gen_df.groupby("trenchid")["lag time"].max()

file_indices = (
    lag_df_loss.groupby(["File Index", "File Trench Index"])["File Index"].min().values
)
trench_indices = (
    lag_df_loss.groupby(["File Index", "File Trench Index"])["File Trench Index"]
    .min()
    .values
)

In [None]:
%matplotlib inline
plt.hist(lag_times, bins=15, range=(0, 100), density=True, alpha=0.7)
plt.hist(loss_lag_times, bins=15, range=(0, 100), density=True, alpha=0.7)
plt.show()

In [None]:
df = dd.read_parquet(headpath + "/lineage/output", engine="fastparquet")

first_tpt_df = df[df["timepoints"] == 0]
original_cell_ids = first_tpt_df["Global CellID"].compute().tolist()
original_cell_df = df[df["Global CellID"].isin(original_cell_ids)]

In [None]:
max_discon = original_cell_df.groupby("Global CellID").apply(compute_del_area)
max_discon = pd.DataFrame({"max discon": max_discon})
original_cell_df = original_cell_df.join(max_discon, on="Global CellID", rsuffix="_moo")
original_cell_df = original_cell_df[
    (original_cell_df["max discon"] > 0.2) & (original_cell_df["max discon"] < 2.0)
]
original_cell_df_pd = original_cell_df.compute()

In [None]:
original_cell_df_pd

In [None]:
cellids_groupby = original_cell_df_pd.groupby(["Global CellID"])
filtered_cellids = cellids_groupby["YFP mean_intensity"].mean() < 2500.0
low_signal_cells = filtered_cellids[filtered_cellids].index.tolist()
low_signal_cells_df = original_cell_df_pd[
    original_cell_df_pd["Global CellID"].isin(low_signal_cells)
]

# grouped_signal_cells_df = low_signal_cells_df.groupby(['Global CellID'])
# low_norm_signal_cells = grouped_signal_cells_df.apply(lambda x: np.all((x['Normalized YFP'] < 0.2)))
# low_norm_signal_cells = low_norm_signal_cells[low_norm_signal_cells].index.tolist()
# low_norm_signal_cells_df = low_signal_cells_df[low_signal_cells_df["Global CellID"].isin(low_norm_signal_cells)]

grouped_signal_cells_df = low_signal_cells_df.groupby(["Global CellID"])
filtered_time = grouped_signal_cells_df["timepoints"].max() > 0
final_cellids = filtered_time[filtered_time].index.tolist()
out_df = low_signal_cells_df[low_signal_cells_df["Global CellID"].isin(final_cellids)]
out_df = out_df[out_df["Daughter CellID 1"] != -1]
out_df = out_df[out_df["Daughter CellID 2"] != -1]
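
The filter above does three things. It keeps cells whose mean YFP intensity is below 2500 and that are observed beyond the first timepoint. It then drops rows where either daughter ID is -1, meaning no daughter was recorded. Here is a minimal sketch of the same steps on plain dictionaries. The keys `cell_id`, `yfp`, `timepoint`, `daughter_1` and `daughter_2` are hypothetical stand-ins for the dataframe columns.

```python
from collections import defaultdict
from statistics import mean


def select_dividing_low_signal_cells(rows, yfp_threshold=2500.0):
    """Return rows of low-YFP cells that persist past timepoint 0 and have both daughters.

    rows: list of dicts with hypothetical keys "cell_id", "yfp", "timepoint",
    "daughter_1", "daughter_2" (-1 meaning "no daughter recorded").
    """
    by_cell = defaultdict(list)
    for row in rows:
        by_cell[row["cell_id"]].append(row)

    keep = set()
    for cell_id, cell_rows in by_cell.items():
        low_signal = mean(r["yfp"] for r in cell_rows) < yfp_threshold
        seen_after_start = max(r["timepoint"] for r in cell_rows) > 0
        if low_signal and seen_after_start:
            keep.add(cell_id)

    # Like the notebook, the daughter check is applied row by row, not per cell.
    return [
        r
        for r in rows
        if r["cell_id"] in keep and r["daughter_1"] != -1 and r["daughter_2"] != -1
    ]
```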

In [None]:
max_discon

In [None]:
cellid_df = df.set_index("Global CellID")

In [None]:
daughter_1_df = cellid_df.loc[out_df["Daughter CellID 1"].values].compute()
daughter_2_df = cellid_df.loc[out_df["Daughter CellID 2"].values].compute()

In [None]:
init_tpt = daughter_1_df.groupby("Global CellID")["timepoints"].min()
init_tpt = pd.DataFrame({"init_tpt": init_tpt})
daughter_1_df = daughter_1_df.join(init_tpt, on="Global CellID")
daughter_1_df = daughter_1_df[daughter_1_df["timepoints"] == daughter_1_df["init_tpt"]]

In [None]:
init_tpt = daughter_2_df.groupby("Global CellID")["timepoints"].min()
init_tpt = pd.DataFrame({"init_tpt": init_tpt})
daughter_2_df = daughter_2_df.join(init_tpt, on="Global CellID")
daughter_2_df = daughter_2_df[daughter_2_df["timepoints"] == daughter_2_df["init_tpt"]]

In [None]:
ttl_daughter_area = daughter_1_df["area"].values + daughter_2_df["area"].values

In [None]:
out_df["ttl_daughter_area"] = ttl_daughter_area

In [None]:
init_tpt = out_df.groupby("Global CellID")["timepoints"].max()
init_tpt = pd.DataFrame({"init_tpt": init_tpt})
out_df = out_df.join(init_tpt, on="Global CellID")
out_df = out_df[out_df["timepoints"] == out_df["init_tpt"]]
out_df["area_ratio"] = (out_df["ttl_daughter_area"] / out_df["area"]).values
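
`area_ratio` divides the summed areas of the two daughters at their first timepoint by the mother's area at her last timepoint. For a clean division this ratio should be close to 1. The later cut at 2.0 therefore presumably removes divisions where tracking or segmentation failed. A small sketch of the calculation for one mother, assuming each cell's trace is given as a list of `(timepoint, area)` pairs:

```python
def division_area_ratio(mother_trace, daughter_1_trace, daughter_2_trace):
    """Ratio of the summed daughter birth areas to the mother's final area.

    Each argument is a list of (timepoint, area) pairs for one cell. The mother
    contributes her area at her last timepoint and each daughter its area at its
    first timepoint, matching the init_tpt selections above.
    """
    _, mother_final = max(mother_trace)  # pair with the largest timepoint
    _, d1_birth = min(daughter_1_trace)  # pair with the smallest timepoint
    _, d2_birth = min(daughter_2_trace)
    return (d1_birth + d2_birth) / mother_final
```

For example, a mother ending at area 200 whose daughters start at 95 and 105 gives a ratio of 1.0.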

In [None]:
out_df

In [None]:
out_df = out_df[out_df["area_ratio"] < 2.0]

In [None]:
out_df

In [None]:
plt.hist(out_df["area_ratio"], bins=30, range=(0, 5))
plt.show()

In [None]:
lag_times = original_cell_df_pd.groupby(["Global CellID"])["timepoints"].max()
loss_lag_times = out_df.groupby(["Global CellID"])["timepoints"].max()

In [None]:
loss_lag_times

In [None]:
%matplotlib inline
plt.hist(lag_times, bins=15, range=(0, 100), density=True, alpha=0.7)
plt.hist(loss_lag_times, bins=15, range=(0, 100), density=True, alpha=0.7)
plt.show()

In [None]:
file_indices = (
    out_df.groupby(["File Index", "File Trench Index"])["File Index"].min().values
)
trench_indices = (
    out_df.groupby(["File Index", "File Trench Index"])["File Trench Index"]
    .min()
    .values
)

In [None]:
file_indices

In [None]:
trench_indices

In [None]:
test = df.loc[:100].compute()

In [None]:
test["Centroid Y"]

In [None]:
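def filter_df(df, query_list, client=False, repartition=False):
    """Filter a dask DataFrame with a list of query strings and persist the result.

    query_list: conditions in DataFrame.query format (see the pandas docs),
        combined with "and".
    client: optional dask controller; if given, the result is persisted on its
        cluster through client.daskclient, otherwise it is persisted locally.
    repartition: if True, shrink the number of partitions in proportion to the
        fraction of rows that survive the filter.
    """
    compiled_query = " and ".join(query_list)
    out_df = df.query(compiled_query)
    if client:
        out_df = client.daskclient.persist(out_df)
    else:
        out_df = out_df.persist()

    if repartition:
        init_size = len(df)
        final_size = len(out_df)
        # Clamp to avoid dividing by zero when the filter removes every row.
        ratio = max(init_size // max(final_size, 1), 1)
        out_df = out_df.repartition(npartitions=(df.npartitions // ratio) + 1)

        if client:
            out_df = client.daskclient.persist(out_df)
        else:
            out_df = out_df.persist()

    return out_df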
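
`filter_df` joins its conditions with `and`. When `repartition=True`, it also shrinks the partition count by the same integer factor as the row count. The helpers below are a hypothetical sketch of just that arithmetic, without dask. Two details are choices made here rather than part of the function above: each condition is wrapped in parentheses, and the sizes are clamped to at least 1.

```python
def compile_query(query_list):
    """Join DataFrame.query conditions into one conjunctive query string.

    Parenthesizing each condition keeps "or" clauses inside one condition
    from binding across the joining "and".
    """
    return " and ".join(f"({q})" for q in query_list)


def repartition_count(init_size, final_size, init_npartitions):
    """Partition count after filtering, shrunk by the row-count reduction factor.

    The "+ 1" guarantees at least one partition; the clamps avoid a zero
    division when the filter keeps no rows.
    """
    ratio = max(init_size // max(final_size, 1), 1)
    return init_npartitions // ratio + 1
```

With 1000 rows in 50 partitions, keeping 100 rows gives a reduction factor of 10 and therefore 6 partitions.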

In [None]:
trenchid_group = df.groupby(["trenchid", "timepoints"])

In [None]:
max_cent_y = trenchid_group["Centroid Y"].idxmax()

In [None]:
max_cent_y_list = max_cent_y.compute().tolist()

In [None]:
max_cent_y_list

In [None]:
mothers_df = df.loc[max_cent_y_list].persist()
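
Here the mother cell of each trench at each timepoint is taken to be the cell with the largest `Centroid Y`. This is presumably because the closed end of the trench lies toward larger y in these kymographs. Below is a pure-Python sketch of that selection, with hypothetical keys `trenchid`, `timepoint` and `centroid_y`. Like pandas' `idxmax`, it keeps the first row when two cells tie.

```python
def select_mothers(rows):
    """Pick one row per (trench, timepoint): the one with the largest centroid y.

    rows: iterable of dicts with hypothetical keys "trenchid", "timepoint",
    "centroid_y". Returns a dict mapping (trenchid, timepoint) -> row.
    """
    best = {}
    for row in rows:
        key = (row["trenchid"], row["timepoint"])
        # Strict ">" keeps the earlier row on ties, as idxmax does.
        if key not in best or row["centroid_y"] > best[key]["centroid_y"]:
            best[key] = row
    return best
```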

In [None]:
area_groupby = mothers_df.groupby("trenchid")["area"]

In [None]:
peaks = area_groupby.apply(sp.signal.find_peaks, distance=3, prominence=3).compute()

In [None]:
len(ttl_peaks)

In [None]:
first_peaks = []
ttl_peaks = []
for peak in peaks:
    if len(peak[0]) > 0:
        first_peaks.append(peak[0][0])
        ttl_peaks.append(sum(peak[0] > -1))
    else:
        first_peaks.append(-1)
        ttl_peaks.append(0)
first_peaks = pd.DataFrame({"lag time": first_peaks, "ttl peaks": ttl_peaks})
first_peaks.index = peaks.index
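
`scipy.signal.find_peaks` returns a tuple of peak indices and a dictionary of peak properties, so `peak[0]` above is the index array for one trench. The loop records two values per trench. The first is the index of the first area peak, used as the lag time; it presumably marks the mother's first division, and -1 means no peak was found. The second is the number of peaks. Peak indices are never negative, so `sum(peak[0] > -1)` is simply the peak count. Here is a sketch of the same reduction as a standalone function:

```python
def summarize_peaks(peak_results):
    """Reduce per-trench find_peaks output to (lag times, total peak counts).

    peak_results: iterable of (peak_indices, properties) pairs, the shape that
    scipy.signal.find_peaks returns for each trench.
    """
    lag_times, totals = [], []
    for indices, _properties in peak_results:
        if len(indices) > 0:
            lag_times.append(int(indices[0]))  # first peak taken as the lag time
            totals.append(len(indices))
        else:
            lag_times.append(-1)  # sentinel: no peak detected
            totals.append(0)
    return lag_times, totals
```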

In [None]:
%matplotlib inline
plt.hist(first_peaks["ttl peaks"], bins=30)
plt.show()

In [None]:
len(first_peaks)

In [None]:
trenchid_groupby["lag time"].min().compute()

In [None]:
temp_df = lag_df.set_index("trenchid").compute()

In [None]:
first_gen_df

In [None]:
lag_time_list = trenchid_groupby["lag time"].min().compute()

In [None]:
lag_time_list

In [None]:
first_gen_df = (
    lag_df.groupby("trenchid")
    .apply(lambda g: g[g["timepoints"] <= g["lag time"].max()])
    .compute()
)

In [None]:
test = trenchid_groupby.apply(
    lambda g: g[g["timepoints"] <= g["lag time"].max()]
).compute()

In [None]:
test[:30]

In [None]:
filtered_trenchids = trenchid_groupby["YFP mean_intensity"].mean() < 2500.0

In [None]:
for t in range(1):
    print(t)

In [None]:
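def compute_del_area(group):
    # Largest frame-to-frame fold change in "area" within one group (a trench or a cell).
    # A value well above 1 means the area jumped suddenly between consecutive frames;
    # the filters above use it to discard such traces (e.g. "max discon" < 2.0).
    # Returns -1 for a group with a single timepoint.
    areas = group["area"].values
    if len(areas) < 2:
        return -1
    return np.max(areas[1:] / areas[:-1])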

In [None]:
max_discont = first_gen_df.groupby("trenchid").apply(compute_del_area)
max_discont_filter = max_discont < 1.75

In [None]:
%matplotlib inline
plt.hist(max_discont, bins=30, range=(0, 2))
plt.show()

In [None]:
lag_df = mothers_df.join(first_peaks, on="trenchid").persist()
first_gen_df = (
    lag_df.groupby("trenchid")
    .apply(lambda g: g[g["timepoints"] <= g["lag time"].max()])
    .compute()
)
first_gen_df = first_gen_df.droplevel("trenchid")

max_discon = first_gen_df.groupby("trenchid").apply(compute_del_area)
max_discon = pd.DataFrame({"max discon": max_discon})
first_gen_df = first_gen_df.join(max_discon, on="trenchid")
first_gen_df = first_gen_df[first_gen_df["max discon"] < 2.0]

# lag_df = lag_df[(lag_df["ttl peaks"]>15)&(lag_df["ttl peaks"]<45)]
# trenchid_groupby = lag_df[lag_df["timepoints"]==0].groupby('trenchid')
# filtered_trenchids = trenchid_groupby.apply(lambda x: np.all((x['YFP mean_intensity'] < 2500.)))
# trenchid_groupby = lag_df.groupby('trenchid')
# filtered_trenchids = trenchid_groupby.apply(lambda x: np.all((x['YFP mean_intensity'] < 3000.)))
trenchid_groupby = first_gen_df.groupby("trenchid")
filtered_trenchids = trenchid_groupby["YFP mean_intensity"].mean() < 2500.0
filtered_trenchids = filtered_trenchids[filtered_trenchids].index.tolist()
lag_df_loss = first_gen_df[first_gen_df["trenchid"].isin(filtered_trenchids)]
loss_lag_times = lag_df_loss.groupby("trenchid")["lag time"].max()
lag_times = first_gen_df.groupby("trenchid")["lag time"].max()

In [None]:
filtered_trenchids

In [None]:
file_indices = (
    lag_df_loss.groupby(["File Index", "File Trench Index"])["File Index"].min().values
)
trench_indices = (
    lag_df_loss.groupby(["File Index", "File Trench Index"])["File Trench Index"]
    .min()
    .values
)

In [None]:
print(file_indices)
print(trench_indices)

In [None]:
trench_indices

In [None]:
%matplotlib inline
plt.hist(lag_times, bins=15, range=(0, 100), density=True, alpha=0.7)
plt.hist(loss_lag_times, bins=15, range=(0, 100), density=True, alpha=0.7)
plt.show()

In [None]:
lag_times

In [None]:
lag_df.compute()

In [None]:
lag_df_loss_idx

In [None]:
# Keep cells whose YFP intensity stays below 2500 at every timepoint.
lag_df_loss_idx = lag_df_loss.groupby("Global CellID").apply(
    lambda x: np.all(x["YFP mean_intensity"] < 2500.0)
)
lag_df_loss_idx = lag_df_loss_idx[lag_df_loss_idx].index.tolist()
lag_df_loss = lag_df_loss[lag_df_loss["Global CellID"].isin(lag_df_loss_idx)]

In [None]:
first_peaks

In [None]:
plt.hist(first_peaks["lag time"])
plt.show()

In [None]:
help(sp.signal.find_peaks)

In [None]:
area_groupby.app

In [None]:
pivot_df = mothers_df.pivot(index="timepoints", columns="trenchid", values="area")
pivot_df.plot(c="c", legend=False)
plt.scatter(peaks.values[0][0], np.repeat(0, len(peaks.values[0][0])))

In [None]:
first_tpt_df = df[df["timepoints"] == 0]
original_cell_ids = first_tpt_df["Global CellID"].compute().tolist()
original_cell_df = df[df["Global CellID"].isin(original_cell_ids)]
original_cell_df_pd = original_cell_df.compute()

In [None]:
original_cell_df

In [None]:
original_cell_df["Normalized YFP"] = (
    original_cell_df["YFP mean_intensity"] / original_cell_df["mCherry mean_intensity"]
)

In [None]:
original_cell_df_pd = original_cell_df.compute()

In [None]:
original_cell_df_pd.to_pickle("./original_cell_df.pkl")

In [None]:
import pandas as pd

In [None]:
original_cell_df_pd = pd.read_pickle("./original_cell_df.pkl")
# original_cell_df = dd.from_pandas(original_cell_df_pd,npartitions=50)

In [None]:
from matplotlib import pyplot as plt

In [None]:
original_cell_df_pd.loc[:1000]

In [None]:
plt.hist(original_cell_df_pd["YFP mean_intensity"], bins=100, range=(0, 5000))
plt.show()

In [None]:
import numpy as np

In [None]:
# cellids = original_cell_df.groupby(['Global CellID'])

In [None]:
cellids = original_cell_df_pd.groupby(["Global CellID"])

In [None]:
len(cellids)

In [None]:
filtered_cellids = cellids.apply(lambda x: np.all((x["YFP mean_intensity"] < 5000.0)))
low_signal_cells = filtered_cellids[filtered_cellids].index.tolist()
low_signal_cells_df = original_cell_df_pd[
    original_cell_df_pd["Global CellID"].isin(low_signal_cells)
]

# grouped_signal_cells_df = low_signal_cells_df.groupby(['Global CellID'])
# low_norm_signal_cells = grouped_signal_cells_df.apply(lambda x: np.all((x['Normalized YFP'] < 0.2)))
# low_norm_signal_cells = low_norm_signal_cells[low_norm_signal_cells].index.tolist()
# low_norm_signal_cells_df = low_signal_cells_df[low_signal_cells_df["Global CellID"].isin(low_norm_signal_cells)]

grouped_signal_cells_df = low_signal_cells_df.groupby(["Global CellID"])
filtered_time = grouped_signal_cells_df["timepoints"].max() > 0
final_cellids = filtered_time[filtered_time].index.tolist()
out_df = low_signal_cells_df[low_signal_cells_df["Global CellID"].isin(final_cellids)]
out_df = out_df[out_df["Daughter CellID 1"] != -1]

In [None]:
mother_df = out_df[out_df["CellID"] == 0]

In [None]:
out_df[:50][["File Index", "File Trench Index", "trenchid"]]

In [None]:
max_tpts = mother_df.groupby(["Global CellID"])["timepoints"].max()

In [None]:
%matplotlib inline
plt.hist(max_tpts, bins=20, range=(10, 70))
plt.show()

In [None]:
len(out_df)

In [None]:
out_df[:50]

In [None]:
pivot_df = out_df.pivot(
    index="timepoints", columns="Global CellID", values="Normalized YFP"
)
time_pivot_df = out_df.pivot(
    index="Global CellID", columns="timepoints", values="Normalized YFP"
)
mean_in_time = time_pivot_df.mean()
pivot_df.plot(c="c", legend=False)
# plt.plot(mean_in_time,c="r")

In [None]:
original_cell_df_pd["pos x"] = (
    original_cell_df_pd["x (local)"] + original_cell_df_pd["centroid x"]
) / microns
original_cell_df_pd["pos y"] = (
    original_cell_df_pd["y (local)"] + original_cell_df_pd["centroid y"]
) / microns

In [None]:
original_cell_df_pd["fov"] == 0

In [None]:
original_cell_df_pd[
    (original_cell_df_pd["fov"] == 0)
    & (original_cell_df_pd["timepoints"] == 0)
    & (original_cell_df_pd["pos y"] > 1000)
    & (original_cell_df_pd["pos y"] < 1200)
    & (original_cell_df_pd["pos x"] > 1300)
    & (original_cell_df_pd["pos x"] < 1500)
]

In [None]:
from matplotlib import pyplot as plt

plt.hist(
    original_cell_df_pd[
        (original_cell_df_pd["fov"] == 0)
        & (original_cell_df_pd["timepoints"] == 0)
        & (original_cell_df_pd["pos y"] > 1000)
        & (original_cell_df_pd["pos y"] < 1200)
    ]["pos x"],
    bins=30,
    range=(1300, 1500),
)

##### Note

Next, extract the normalized YFP trace of each cell and plot the traces together.

In [None]:
from matplotlib import pyplot as plt

In [None]:
# yfp_groupby = original_cell_df.groupby(['Global CellID'])["Normalized YFP"]
# time_groupby = original_cell_df.groupby(['Global CellID'])["timepoints"]
cell_groupby = original_cell_df_pd.groupby(["Global CellID"])

In [None]:
plt.hist(original_cell_df_pd["Normalized YFP"], range=(0, 2), bins=50)
plt.show()

In [None]:
pivot_df = original_cell_df_pd.pivot(
    index="timepoints", columns="Global CellID", values="Normalized YFP"
)
time_pivot_df = original_cell_df_pd.pivot(
    index="Global CellID", columns="timepoints", values="Normalized YFP"
)
mean_in_time = time_pivot_df.mean()
pivot_df.plot(c="c", legend=False)
plt.plot(mean_in_time, c="r")

In [None]:
original_cell_df.groupby(["Global CellID"])[""]

In [None]:
first_tpt_idx = (
    original_cell_df.groupby(["Global CellID"])["timepoints"].idxmin().compute()
)

In [None]:
last_tpt_idx = (
    original_cell_df.groupby(["Global CellID"])["timepoints"].idxmax().compute()
)

In [None]:
cell_first_tpt_df = (
    original_cell_df.loc[first_tpt_idx.tolist()].set_index("Global CellID").persist()
)
cell_last_tpt_df = (
    original_cell_df.loc[last_tpt_idx.tolist()].set_index("Global CellID").persist()
)

In [None]:
cell_first_tpt_df

In [None]:
first_div_time = (
    original_cell_df.groupby(["Global CellID"])["timepoints"].max().compute()
)

In [None]:
from matplotlib import pyplot as plt

In [None]:
yfp_signal_first = cell_first_tpt_df["YFP mean_intensity"].compute()
yfp_signal_last = cell_last_tpt_df["YFP mean_intensity"].compute()
mcherry_signal_first = cell_first_tpt_df["mCherry mean_intensity"].compute()
mcherry_signal_last = cell_last_tpt_df["mCherry mean_intensity"].compute()
normalized_yfp_first = yfp_signal_first / mcherry_signal_first
normalized_yfp_last = yfp_signal_last / mcherry_signal_last
normalized_signal_ratio = normalized_yfp_last / normalized_yfp_first
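
For each cell, YFP at the first and last timepoints is divided by the mCherry signal. The ratio of the two normalized values is then compared against 0.5 further down (`normalized_signal_ratio < 0.5`) to flag candidate plasmid-loss events. Treating mCherry as a reference channel that cancels cell-to-cell brightness differences is an assumption about this strain. A minimal scalar sketch:

```python
def plasmid_loss_candidate(yfp_first, mcherry_first, yfp_last, mcherry_last, cutoff=0.5):
    """Return (normalized signal ratio, flag) for one cell.

    YFP is normalized by mCherry (assumed to be a reference channel) at the
    first and last timepoints; the cell is flagged when the last normalized
    value falls below `cutoff` times the first.
    """
    normalized_first = yfp_first / mcherry_first
    normalized_last = yfp_last / mcherry_last
    ratio = normalized_last / normalized_first
    return ratio, ratio < cutoff
```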

In [None]:
plt.hist(normalized_yfp_first.values, range=(0, 2), bins=100)
plt.show()

In [None]:
plt.hist(normalized_yfp_last.values, range=(0, 2), bins=100)
plt.show()

In [None]:
plt.hist(normalized_signal_ratio.values, range=(0, 2), bins=100)
plt.show()

In [None]:
plt.hist(signal_ratio, range=(0, 0.5), bins=100)
plt.show()

In [None]:
max(signal_ratio)

In [None]:
cell_last_tpt_df["YFP mean_intensity"] < 250

In [None]:
signal_ratio = yfp_signal_last / yfp_signal_first

In [None]:
# plas_loss_last_tpt_df = cell_last_tpt_df[signal_ratio<0.3]
plas_loss_last_tpt_df = cell_last_tpt_df[normalized_signal_ratio < 0.5]
plas_loss_last_tpt_df = plas_loss_last_tpt_df[plas_loss_last_tpt_df["CellID"] == 0]

In [None]:
plas_loss_last_tpt_df.compute()

In [None]:
first_div_time_loss = plas_loss_last_tpt_df["timepoints"].compute()

In [None]:
%matplotlib inline

In [None]:
len(plas_loss_last_tpt_df)

In [None]:
matplotlib.rcParams["figure.figsize"] = [20, 10]

In [None]:
plt.hist(first_div_time.values, bins=50, range=(0, 100))
plt.show()

In [None]:
plt.hist(first_div_time_loss.values, bins=50, range=(0, 100))
plt.show()

In [None]:
first_div_time.values

In [None]:
mothers_daughters_df.groupby(["Global CellID"])[""]

In [None]:
original_cell_df = cell_max_tpt_df.loc[original_cell_ids]

In [None]:
cell_max_tpt_df = (
    reference_df.loc[cells_max_tpt.tolist()].set_index("Global CellID").persist()
)

In [None]:
cells_max_tpt = reference_df.groupby(["Global CellID"])["timepoints"].idxmax().compute()

In [None]:
mother_df[:50]

## Region Properties (No Lineage)

In [None]:
analyzer = tr.analysis.regionprops_extractor(
    headpath, "fluorsegmentation", intensity_channel_list=["mCherry", "YFP"]
)

In [None]:
analyzer.export_all_data()

In [None]:
idx = 0
print(file_indices[idx])
print(trench_indices[idx])

In [None]:
trench_indices

## Inspect Kymographs

In [None]:
%matplotlib widget

In [None]:
from ipywidgets import interact, interactive, Dropdown, IntText, IntSlider

kyview = tr.analysis.kymograph_viewer(headpath, "YFP", "fluorsegmentation")

In [None]:
kyviewer = interactive(
    kyview.inspect_trench,
    {"manual": True},
    file_idx=IntText(value=0, description="File Index:", disabled=False),
    trench_idx=IntText(value=0, description="Trench Index:", disabled=False),
    x_size=IntSlider(
        value=32, description="X Size:", min=0, max=50, step=1, disabled=False
    ),
    y_size=IntSlider(
        value=15, description="Y Size:", min=0, max=30, step=1, disabled=False
    ),
)
display(kyviewer)

In [None]:
interact(
    interactive_kymograph.view_image,
    fov_idx=IntText(value=0, description="FOV number:", disabled=False),
    t=IntSlider(
        value=0, min=0, max=timepoints_len - 1, step=1, continuous_update=False
    ),
    channel=Dropdown(
        options=channels, value=channels[0], description="Channel:", disabled=False
    ),
    invert=Dropdown(options=[True, False], value=False),
)

## Phase Segmentation Training

### Data Preparation

In [None]:
dataloader = tr.unet.UNet_Training_DataLoader(
    nndatapath="/n/scratch2/de64/nntest7",
    experimentname="First NN",
    trainpath="/n/scratch2/de64/2019-06-18_DE85_training_data",
    testpath="/n/scratch2/de64/2019-05-31_validation_data",
    valpath="/n/scratch2/de64/2019-05-31_validation_data",
)

In [None]:
dataloader = tr.unet.UNet_Training_DataLoader(
    nndatapath="/n/scratch2/de64/nntest8",
    experimentname="First NN",
    trainpath="/n/scratch2/de64/2019-05-31_validation_data",
    testpath="/n/scratch2/de64/2019-06-18_DE85_training_data",
    valpath="/n/scratch2/de64/2019-06-18_DE85_training_data",
)

In [None]:
dataloader = tr.unet.UNet_Training_DataLoader(
    nndatapath="/n/scratch2/de64/nntest9",
    experimentname="First NN",
    trainpath="/n/scratch2/de64/2019-05-31_validation_data",
    testpath="/n/scratch2/de64/2019-06-18_DE85_training_data",
    valpath="/n/scratch2/de64/2019-06-18_DE85_training_data",
)

In [None]:
dataloader = tr.unet.UNet_Training_DataLoader(
    nndatapath="/n/scratch2/de64/nntest10",
    experimentname="First NN",
    trainpath="/n/scratch2/de64/2019-06-18_DE85_training_data",
    testpath="/n/scratch2/de64/2019-05-31_validation_data",
    valpath="/n/scratch2/de64/2019-05-31_validation_data",
)

#### Training Set Selection

In [None]:
dataloader.inter_get_selection(dataloader.trainpath, "train")

#### Test Set Selection

In [None]:
dataloader.inter_get_selection(dataloader.testpath, "test")

#### Validation Set Selection

In [None]:
dataloader.inter_get_selection(dataloader.valpath, "val")

#### Weightmap Parameters

In [None]:
dataloader.display_grid()

In [None]:
dataloader.get_grid_params()
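
The exported weight maps are stored under names like `weight_(10.0, 4.0)`, as the test file read later in this notebook shows. This suggests each map is parameterized by a pair such as (w0, sigma) from the border-weighting scheme of the original U-Net paper (Ronneberger et al., 2015). That mapping is an assumption, not something this notebook confirms. Under that assumption, a single pixel's weight would be computed as follows:

```python
import math


def unet_border_weight(class_weight, d1, d2, w0=10.0, sigma=4.0):
    """Per-pixel loss weight in the style of the original U-Net paper (assumed scheme).

        w(x) = w_c(x) + w0 * exp(-(d1 + d2)**2 / (2 * sigma**2))

    class_weight: class-balancing term w_c(x) for this pixel.
    d1, d2: distances from the pixel to the nearest and second-nearest cell.
    """
    return class_weight + w0 * math.exp(-((d1 + d2) ** 2) / (2 * sigma**2))
```

Pixels squeezed between two nearby cells (small d1 + d2) receive weights close to `class_weight + w0`. This pushes the network to separate touching cells.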

#### Export

In [None]:
dataloader.export_all_data(memory="6GB")

### Hyperparameter (Grid) Search

#### Set-up Search

In [None]:
grid = tr.unet.GridSearch("/n/scratch2/de64/nntest10", numepochs=15)

In [None]:
grid.display_grid()

In [None]:
grid.get_grid_params()

#### Run Search

In [None]:
grid.run_grid_search(gres="gpu:teslaK80:1")
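
The `gres="gpu:teslaK80:1"` argument appears to request one Tesla K80 GPU per SLURM job. Conceptually, a grid search trains one model for every combination of the selected hyperparameter values. The number of runs is therefore the product of the number of values chosen for each parameter. The sketch below only illustrates that expansion. The parameter names are hypothetical and not necessarily the ones `get_grid_params` exposes.

```python
from itertools import product


def expand_grid(param_grid):
    """Expand {name: [values, ...]} into a list of {name: value} configurations."""
    names = list(param_grid)
    return [dict(zip(names, combo)) for combo in product(*(param_grid[n] for n in names))]
```

For example, two learning rates, two layer counts and one dropout value give 2 × 2 × 1 = 4 training runs.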

#### Evaluate Results

In [None]:
%matplotlib ipympl
matplotlib.rcParams["figure.figsize"] = [12, 8]

import seaborn as sns

sns.set(font_scale=2)

In [None]:
vis = tr.unet.TrainingVisualizer(
    "/n/scratch2/de64/nntest10", "/n/groups/paulsson/Daniel/NNModels"
)

In [None]:
vis.inter_plot_loss("Val Loss")
vis.grid_widget.on("filter_changed", vis.handle_filter_changed)

In [None]:
vis.grid_widget

In [None]:
vis.inter_df_columns()

In [None]:
vis.model_widget

In [None]:
import matplotlib
import numpy as np
from matplotlib import pyplot as plt

%matplotlib inline

plt.hist(vis.model_df["Val F1 Cell Scores"][0], bins=50)
plt.xlabel("F-Score")
plt.ylabel("Occurrences")
plt.xticks(np.arange(0, 1.01, step=0.5))
plt.draw()
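
`Val F1 Cell Scores` holds F-scores that balance the precision and recall of the segmented cells, presumably one score per validation cell or image. This notebook does not show how predicted and ground-truth cells are matched. Given counts of true positives, false positives and false negatives, the score itself is computed as below.

```python
def f1_score(true_positives, false_positives, false_negatives):
    """F1 = 2TP / (2TP + FP + FN), the harmonic mean of precision and recall.

    Returns 0.0 when all counts are zero, so empty cases do not divide by zero.
    """
    denominator = 2 * true_positives + false_positives + false_negatives
    return 2 * true_positives / denominator if denominator else 0.0
```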

In [None]:
headpath = "/n/scratch2/de64/2019-07-08_bacillus_rodz_mut_expt_bmbm_ti4"
unetseg = tr.unet.UNet_Segmenter(
    headpath, "Phase", "/n/groups/paulsson/Daniel/NNModels", min_obj_size=20
)

In [None]:
choose_channel = interactive(
    unetseg.choose_seg_channel,
    {"manual": True},
    seg_channel=Dropdown(options=unetseg.all_channels, value=unetseg.all_channels[0]),
)
display(choose_channel)

In [None]:
unetseg.inter_df_columns()

In [None]:
import torch
import numpy as np
import h5py
import trenchripper as tr
from matplotlib import pyplot as plt

In [None]:
with h5py.File("/n/scratch2/de64/nntest7/test.hdf5", "r") as infile:
    img_arr = torch.Tensor(infile["img"][535:550])
    seg_arr = torch.Tensor(infile["seg"][100:200:10])
    weight_arr = infile["weight_(10.0, 4.0)"][0:300:10]

In [None]:
testunet = tr.unet.UNet(1, 2, layers=3, hidden_size=32, dropout=0.0, withsoftmax=True)
device = torch.device("cpu")
testunet.load_state_dict(
    torch.load("/n/scratch2/de64/nntest7/models/0.pt", map_location=device)
)

In [None]:
# Run inference without tracking gradients; keep channel 1 of the two-class softmax output.
with torch.no_grad():
    y = testunet(img_arr).numpy()[:, 1]
x = img_arr.numpy().squeeze(1)

In [None]:
x.shape

In [None]:
plt.imshow(x[4])

In [None]:
plt.imshow(y[4])

In [None]:
img_kymo = tr.utils.kymo_handle()
img_kymo.import_wrap(x)
img = img_kymo.return_unwrap(padding=0)
plt.imshow(img)

In [None]:
seg_kymo = tr.utils.kymo_handle()
seg_kymo.import_wrap(y)
seg = seg_kymo.return_unwrap(padding=0)
plt.imshow(seg)

In [None]:
mask = seg > 0.6
plt.imshow(mask)

In [None]:
import skimage as sk

In [None]:
filtered_mask = sk.morphology.remove_small_objects(mask, min_size=30)

In [None]:
plt.imshow(filtered_mask)
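
The cells above turn the network's foreground probability map into a binary mask. Pixels above 0.6 are kept, then `sk.morphology.remove_small_objects` discards connected objects smaller than 30 pixels. The pure-Python sketch below spells out that logic for a small 2D list. It assumes 4-connectivity, which matches the default `connectivity=1` of `remove_small_objects` for 2D images, and it keeps objects of at least `min_size` pixels.

```python
from collections import deque


def threshold_and_filter(prob, threshold=0.6, min_size=30):
    """Binarize a 2D probability map and drop 4-connected objects below min_size pixels.

    prob: list of equal-length rows of floats. Returns a list of rows of bools.
    """
    height = len(prob)
    width = len(prob[0]) if prob else 0
    mask = [[value > threshold for value in row] for row in prob]  # strict ">" as above
    seen = [[False] * width for _ in range(height)]
    out = [[False] * width for _ in range(height)]

    for r in range(height):
        for c in range(width):
            if not mask[r][c] or seen[r][c]:
                continue
            # Breadth-first search collects one 4-connected component.
            component = []
            queue = deque([(r, c)])
            seen[r][c] = True
            while queue:
                y, x = queue.popleft()
                component.append((y, x))
                for dy, dx in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                    ny, nx = y + dy, x + dx
                    if 0 <= ny < height and 0 <= nx < width and mask[ny][nx] and not seen[ny][nx]:
                        seen[ny][nx] = True
                        queue.append((ny, nx))
            # Keep the component only if it is large enough.
            if len(component) >= min_size:
                for y, x in component:
                    out[y][x] = True
    return out
```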

## Other

#### Transfer files into the scratch folder

If you are working on the HMS O2 server, this convenience function transfers files into the `/n/scratch2` folder.

In [None]:
sourcedir = r"/n/files/SysBio/PAULSSON\ LAB/SILVIA/Ti4--Data/2020_03_28--PlasmidLosses_SJC25_SJC28/temp"
targetdir = "/n/scratch2/de64/2020-03-02_plasmid_loss"
tr.trcluster.transferjob(sourcedir, targetdir)

#### Dask Utilities

In [None]:
dask_controller.shutdown()

In [None]:
dask_controller.retry_failed()

In [None]:
dask_controller.daskclient.restart()

In [None]:
dask_controller.retry_processing()