## Important Note

Under the current policy, SLURM schedulers charge the full allocated time to the account fairshare, regardless of how much of that time is actually used. You are therefore penalized for time that you request but do not use, which makes allocating resources later in the day (after running one or two steps) more difficult. For this reason, it is a good idea to rework the SLURM scheduling code to adopt a "take and hold" approach to O2 resources.

The code needed for a take-and-hold model would have to include the following:

- More sophisticated memory clearing of the `dask_controller` object (total memory clearing without a restart that leads to crashes). This should be done first.
- Rule-based adjustments of chunk sizes to limit per-node memory use to make sure a flat memory request can be made at cluster initialization
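
As a rough illustration of the second item, a chunk-size rule could be as simple as the sketch below. The function name `max_timepoints_per_chunk`, the `safety_factor` allowance, and the example image dimensions are all hypothetical and are not part of `TrenchRipper`.

```python
def max_timepoints_per_chunk(
    memory_bytes, y_dim, x_dim, bytes_per_pixel=2, safety_factor=4
):
    """Largest number of (Y, X) frames that fit in one worker's memory budget.

    safety_factor is an assumed allowance for intermediate copies made
    while a chunk is being processed.
    """
    frame_bytes = y_dim * x_dim * bytes_per_pixel
    n_frames = memory_bytes // (frame_bytes * safety_factor)
    return max(int(n_frames), 1)  # always process at least one frame


# Example: a 1 GB worker and 2048 x 2048 16-bit frames
chunk = max_timepoints_per_chunk(10**9, 2048, 2048)
print(chunk)
```

With a rule like this, the memory request made at cluster initialization can stay flat, and only the chunk size changes from dataset to dataset.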

# Introduction

This notebook contains the entire `TrenchRipper` pipeline, divided into simple steps. This pipeline is ideal for Mother <br>Machine image data where cells possess fluorescent segmentation markers. Segmentation on phase or brightfield data <br>is being developed, but is still an experimental feature.

The steps in this pipeline are as follows:
1. Extracting your Mother Machine data (.nd2) into hdf5 format
2. Identifying and cropping individual trenches into kymographs
3. Segmenting cells with a fluorescent marker
4. Determining lineages and object properties

In each step, the user will dynamically specify parameters using a series of interactive diagnostics on their dataset. <br>Following this, a parameter file will be written to disk and then used to deploy a parallel computation on the <br>dataset, either locally or on a SLURM cluster.


This is intended as an end-to-end solution to analyzing Mother Machine data. As such, **it is not trivial to plug data <br>directly into intermediate steps**, as it will lack the correct formatting and associated metadata. A notable <br>exception to this is using another program to segment data. The library references binary segmentation masks using <br>only metadata derived from their associated kymographs. As such, it is possible to generate segmentations on these <br>kymographs elsewhere and place them into the segmentation data path to have `TrenchRipper` act on those <br>segmentations instead. More on this in the segmentation section...

#### Imports

Run this section to import all relevant packages and libraries used in this notebook. You must run this every time you open a new Python kernel.

In [None]:
import paulssonlab.deaton.trenchripper.trenchripper as tr

import warnings

warnings.filterwarnings(action="once")

import matplotlib

matplotlib.rcParams["figure.figsize"] = [20, 10]

In [None]:
# addition of active memory manager
import dask

dask.config.set({"distributed.scheduler.active-memory-manager.start": True})
dask.config.set({"distributed.scheduler.worker-ttl": "5m"})
dask.config.set({"distributed.scheduler.allowed-failures": 100})

# Part 1: Growth/Division

#### Specify Paths

Begin by defining the directory in which all processing will be done, as well as the initial nd2 file we will be <br>processing. This line should be run every time you open a new Python kernel.

The format should be: `headpath = "/path/to/folder"` and `nd2file = "/path/to/file.nd2"`

For example:
```
headpath = "/n/scratch2/de64/2019-05-31_validation_data"
nd2file = "/n/scratch2/de64/2019-05-31_validation_data/Main_Experiment.nd2"
```

Ideally, these files should be placed in a storage location with relatively fast I/O.

In [None]:
headpath = "/home/de64/scratch/de64/sync_folder/2022-02-11_DE524_rne/"
nd2file = "/home/de64/scratch/de64/sync_folder/2022-02-11_DE524_rne/Experiment.nd2"

## Extract to hdf5 files

In this section, we will be extracting our image data. Currently this notebook only supports `.nd2` format; however <br>there are `.tiff` extractors in the TrenchRipper source files that are being added to `Master.ipynb` soon.

In the abstract, this step will take a single `.nd2` file and split it into a set of `.hdf5` files stored in <br>`headpath/hdf5`. Splitting the file up in this way will facilitate quick processing in later steps. Each field of <br>view will be split into one or more `.hdf5` files, depending on the number of images per file requested (more on <br>this later). 

To keep track of which output files correspond to which FOVs, as well as to keep track of experiment metadata, the <br>extractor also outputs a `metadata.hdf5` file in the `headpath` folder. The data from this step is accessible in <br>that `metadata.hdf5` file under the `global` key. If you would like to look at this metadata, you may use the <br>`tr.utils.pandas_hdf5_handler` to read from this file. Later steps will add additional metadata under different <br>keys into the `metadata.hdf5` file.

#### Start Dask Workers

First, we start a `dask_controller` instance which will handle all of our parallel processing. The default parameters <br>here work well on O2. The critical arguments here are:

**walltime** : For a cluster, the length of time you will request each node for.

**local** : `True` if you want to perform computation locally. `False` if you want to perform it on a SLURM cluster.

**n_workers** : Number of nodes to request if on the cluster, or number of processes if computing locally.

**memory** : For a cluster, the amount of memory you will request each node for.

**working_directory** : For a cluster, the directory in which data will be spilled to disk. Usually set as a folder in <br>the `headpath`.

In [None]:
# dask_controller = tr.trcluster.dask_controller(
#     walltime="1:00:00",
#     local=False,
#     n_workers=100,
#     n_workers_min=20,
#     memory="4GB",
#     working_directory="/home/de64/scratch/de64/dask",
# )
# dask_controller.startdask()
dask_controller = tr.trcluster.dask_controller(
    walltime="2:00:00",
    local=False,
    n_workers=400,
    n_workers_min=50,
    memory="1GB",
    working_directory="/home/de64/scratch/de64/dask",
)
dask_controller.startdask()

In [None]:
dask_controller.reset_worker_memory()

After running the above line, you will have a running Dask client. Run the line below and click the link to supervise <br>the computation being administered by the scheduler. 

Don't be alarmed if the screen starts mostly blank; it may take time for your workers to spin up. If you get a 404 <br>error on a cluster, it is likely that your ports are not being forwarded properly. If this occurs, please register <br>the issue on GitHub.

In [None]:
dask_controller.daskclient

##### Perform Extraction

Now that we have our cluster scheduler spun up, it is time to convert files. This will be handled by the <br>`hdf5_extractor` object. This extractor will pull up each FOV and split it such that each derived `.hdf5` file <br>contains, at maximum, N timepoints of that FOV per file. The image data stored in these files takes the <br>form of `(N,Y,X)` arrays that are accessible using the desired channel name as a key. 

The arguments for this extractor are:

 - **nd2file** : The filepath to the `.nd2` file you intend to extract.
 
 - **headpath** : The folder in which processing is occurring. Should be the same for each step in the pipeline.

 - **tpts_per_file** : The maximum number of timepoints stored in each output `.hdf5` file. Typical values are between 25 <br>and 100.

 - **ignore_fovmetadata** : Used when `.nd2` data is corrupted and does not possess records for stage positions or <br>timepoints. Only set `True` if the extractor throws errors on metadata handling.

 - **nd2reader_override** : Overrides values in metadata recovered using the `nd2reader`. Currently set to <br>`{"z_levels":[],"z_coordinates":[]}` by default to correct a known issue where z coordinates are mistakenly <br>interpreted as a z stack. See the [nd2reader](https://rbnvrw.github.io/nd2reader/) documentation for more info.
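
To make the effect of `tpts_per_file` concrete, the sketch below shows how the timepoints of one FOV could be divided into files that each hold at most N timepoints. The helper `split_timepoints` is hypothetical; the actual file layout produced by `hdf5_fov_extractor` may differ in its details.

```python
def split_timepoints(n_timepoints, tpts_per_file):
    """Return one (start, stop) index pair per output file.

    Each pair covers at most tpts_per_file timepoints; stop is exclusive.
    """
    return [
        (start, min(start + tpts_per_file, n_timepoints))
        for start in range(0, n_timepoints, tpts_per_file)
    ]


# A FOV with 120 timepoints and tpts_per_file=50 needs three files
file_ranges = split_timepoints(n_timepoints=120, tpts_per_file=50)
print(file_ranges)  # [(0, 50), (50, 100), (100, 120)]
```

Smaller values give more, smaller files (finer-grained parallelism), while larger values reduce the number of files the later steps must open.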

In [None]:
hdf5_extractor = tr.ndextract.hdf5_fov_extractor(
    nd2file,
    headpath,
    tpts_per_file=50,
    ignore_fovmetadata=False,
    nd2reader_override={"z_levels": [], "z_coordinates": []},
)

##### Extraction Parameters

Here, you may set the time interval you want to extract. Useful for cropping data to the period exhibiting the dynamics of interest.

Optionally take notes to add to the `metadata.hdf5` file. Notes may also be taken directly in this notebook.

In [None]:
hdf5_extractor.inter_set_params()

In [None]:
hdf5_extractor.inter_set_flatfieldpaths()

##### Begin Extraction 

Running the following line will start the extraction process. This may be monitored by examining the `Dask Dashboard` <br> under the link displayed earlier. Once the computation is complete, move to the next line.

This step may take a long time, though it is possible to speed it up using additional workers.

In [None]:
hdf5_extractor.extract(dask_controller)

In [None]:
dask_controller.daskclient.restart()

In [None]:
dask_controller.shutdown()

## Kymographs

Now that you have extracted your data into a series of `.hdf5` files, we will perform identification and cropping <br>of the individual trenches/growth channels present in the images. This algorithm assumes that your growth trenches <br>are vertically aligned and that they alternate in their orientation from top to bottom. See the example image for the <br>correct geometry:

![example_image](./resources/example_image.jpg)

The output of this step will be a set of `.hdf5` files stored in `headpath/kymograph`. The image data stored in these <br>files takes the form of `(K,T,Y,X)` arrays where K is the trench index, T is time, and Y,X are the crop dimensions. <br>These arrays are accessible using keys of the form `"[Image Channel]"`. For example, looking up phase channel <br>data will require the key `"Phase"`.

[ '/n/scratch3/users/d/de64/190917_20x_phase_gfp_segmentation002',
 '/n/scratch3/users/d/de64/190922_20x_phase_gfp_segmentation',
 '/n/scratch3/users/d/de64/190925_20x_phase_yfp_segmentation',
 '/n/scratch3/users/d/de64/ezrdm_training_sb7',
 '/n/scratch3/users/d/de64/mbm_training_sb7',
 '/n/scratch3/users/d/de64/Sb7_L35',
 '/n/scratch3/users/d/de64/MM_DVCvecto_TOP_1_9',
 '/n/scratch3/users/d/de64/Vibrio_2_1_TOP',
 '/n/scratch3/users/d/de64/Vibrio_A_B_VZRDM--04--RUN_80ms',
 '/n/scratch3/users/d/de64/RpoSOutliers_WT_hipQ_100X',
 '/n/scratch3/users/d/de64/Main_Experiment',
 '/n/scratch3/users/d/de64/bde17_gotime']

### Test Parameters



##### Initialize the interactive kymograph class

As a first step, initialize the `tr.interactive.kymograph_interactive` class that will help us choose the <br>parameters we will use to generate kymographs. 

In [None]:
interactive_kymograph = tr.kymograph_interactive(headpath)

In [None]:
viewer = tr.hdf5_viewer(headpath, persist_data=False)

##### Examine Images

Here you can manually inspect images before beginning parameter tuning.

In [None]:
viewer.view(width=1200)

You will now want to select a few test FOVs to try out parameters on, the channel you want to detect trenches on, and <br>the time interval on which you will perform your processing.

The arguments for this step are:

- **seg_channel (string)** : The channel name that you would like to segment on.

- **invert (list)** : Whether or not you want to invert the image before detecting trenches. By default, it is assumed that <br>the trenches have a high pixel intensity relative to the background. This should be the case for Phase Contrast and <br>Fluorescence Imaging, but may not be the case for Brightfield Imaging, in which case you will want to invert the image.

- **fov_list (list)** : List of integers corresponding to the FOVs that you wish to make test kymographs of.

- **t_subsample_step (int)** : Step size to be used for subsampling input files in time. We recommend choosing a step that leaves <br>between 5 and 10 timepoints for quick processing.

Hit the "Run Interact" button to lock in your parameters. The button will become transparent briefly and become solid again <br>when processing is complete. After that has occurred, move on to the next step. 

In [None]:
interactive_kymograph.import_hdf5_interactive()

##### Tune "trench-row" detection hyperparameters

The kymograph code begins by detecting the positions of trench rows in the image as follows:

1. Reduce each 2D image to a 1D signal along the y-axis by computing the qth percentile of the data along the x-axis
2. Smooth this signal using a median kernel
3. Normalize the signal by linearly rescaling it so that its minimum and maximum map to 0 and 1, respectively
4. Use a set threshold to determine the trench row positions
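
The four steps above can be sketched with NumPy as follows. This is a minimal illustration on a synthetic image, not the library's implementation: `median_smooth` and `detect_row_mask` are hypothetical helpers, and the default values are arbitrary.

```python
import numpy as np


def median_smooth(signal, kernel_size):
    """Median filter with edge padding; kernel_size is assumed to be odd."""
    pad = kernel_size // 2
    padded = np.pad(signal, pad, mode="edge")
    windows = np.lib.stride_tricks.sliding_window_view(padded, kernel_size)
    return np.median(windows, axis=1)


def detect_row_mask(img, y_percentile=99, kernel_size=5, threshold=0.5):
    # Step 1: qth percentile along x gives one value per y position
    signal = np.percentile(img, y_percentile, axis=1)
    # Step 2: median smoothing
    signal = median_smooth(signal, kernel_size)
    # Step 3: rescale to [0, 1] (assumes the signal is not perfectly flat)
    signal = (signal - signal.min()) / (signal.max() - signal.min())
    # Step 4: y positions above the threshold belong to trench rows
    return signal > threshold


# Synthetic image: two bright bands (y = 10-19 and y = 40-49) on a dark background
img = np.zeros((60, 100))
img[10:20, :] = 100.0
img[40:50, :] = 100.0

mask = detect_row_mask(img)
row_positions = np.flatnonzero(mask)
print(row_positions[0], row_positions[-1])
```

Lowering the threshold widens the detected rows, while a larger median kernel suppresses narrow bright artifacts.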

The arguments for this step are:

 - **y_percentile (int)** : Percentile to use for step 1.

 - **smoothing_kernel_y_dim_0 (int)** : Median kernel size to use for step 2.

 - **y_percentile_threshold (float)** : Threshold to use in step 4.

Running the following widget will display the smoothed 1-D signal for each of your timepoints. In addition, the threshold <br>value for each fov will be displayed as a red line.

In [None]:
interactive_kymograph.preview_y_precentiles_interactive()

##### Tune "trench-row" cropping hyperparameters

Next, we will use the detected rows to perform cropping of the input image in the y-dimension:

1. Determine edges of trench rows based on threshold mask.
2. Filter out rows that are too small.
3. Use the remaining rows to compute the drift in y in each image.
4. Apply the drift to the initially detected rows to get rows in all timepoints.
5. Perform cropping using the "end" of the row as reference (the end referring to the part of the trench farthest from <br>the feeding channel).

Step 5 performs a simple algorithm to determine the orientation of each trench:

```
# Illustrative variable names; 0 = downward-facing trench opening, 1 = upward-facing
n_rows_expected = 2  # 'Number of Rows'
orientation = 0  # 'Orientation'
orientation_if_fewer = 1  # 'Orientation when < expected rows'
n_rows_detected = 2  # number of rows detected in this FOV

# Orientation of the topmost row
if n_rows_detected == n_rows_expected:
    row_orientations = [orientation]
else:  # fewer rows detected than expected
    row_orientations = [orientation_if_fewer]
# Each remaining row faces the opposite way from the row above it
for _ in range(n_rows_detected - 1):
    row_orientations.append(1 - row_orientations[-1])
```

Additionally, if the device trenches face a single direction, alternation of row orientation may be turned off by setting the<br> `Alternate Orientation?` argument to False. The `Use Median Drift?` argument, when set to True, will use the<br> median drift in y across all FOVs for drift correction, instead of doing drift correction independently for all FOVs. <br>This can be useful if there are a large fraction of FOVs which are failing drift correction. Note that `Use Median Drift?` <br>sets this behavior for both y and x drift correction.

The arguments for this step are:

 - **y_min_edge_dist (int)** : Minimum row length necessary for detection (filters out small detected objects).

 - **padding_y (int)** : Padding to add to the end of trench row when cropping in the y-dimension.

 - **trench_len_y (int)** : Length from the end of each trench row to the feeding channel side of the crop.

 - **Number of Rows (int)** : The number of rows to expect in your image. For instance, two in the example image.
 
 - **Alternate Orientation? (bool)** : Whether or not to alternate the orientation of consecutive rows.

 - **Orientation (int)** : The orientation of the top-most row where 0 corresponds to a trench with a downward-oriented trench <br>opening and 1 corresponds to a trench with an upward-oriented trench opening.

 - **Orientation when < expected rows(int)** : The orientation of the top-most row when the number of detected rows is less than <br>expected. Useful if your trenches drift out of your image in some FOVs.
 
 - **Use Median Drift? (bool)** : Whether to use the median detected drift across all FOVs, instead of the drift detected in each FOV individually.

 - **images_per_row(int)** : How many images to output per row for this widget.

Running the following widget will display y-cropped images for each fov and timepoint.

In [None]:
interactive_kymograph.preview_y_precentiles_consensus_interactive()

In [None]:
interactive_kymograph.preview_y_crop_interactive()

##### Tune trench detection hyperparameters

Next, we will detect the positions of trenches in the y-cropped images as follows:

1. Reduce each 2D image to a 1D signal along the x-axis by computing the qth percentile of the data along the y-axis.
2. Determine the signal background by smoothing this signal using a large median kernel.
3. Subtract the background signal.
4. Smooth the resultant signal using a median kernel.
5. Use an [Otsu threshold](https://imagej.net/Auto_Threshold#Otsu) to determine the trench midpoint positions.
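
Steps 1 to 5 can be sketched as below on a synthetic y-cropped image. This is only an illustration under stated assumptions: the helper functions are hypothetical, `otsu_threshold` is a plain histogram-based version of Otsu's method rather than whatever routine the library calls, and the midpoint of each above-threshold run is taken as the trench midpoint.

```python
import numpy as np


def median_smooth(signal, kernel_size):
    """Median filter with edge padding; kernel_size is assumed to be odd."""
    pad = kernel_size // 2
    padded = np.pad(signal, pad, mode="edge")
    windows = np.lib.stride_tricks.sliding_window_view(padded, kernel_size)
    return np.median(windows, axis=1)


def otsu_threshold(values, n_bins=256):
    """Histogram-based Otsu threshold: maximizes the between-class variance."""
    hist, edges = np.histogram(values, bins=n_bins)
    centers = (edges[:-1] + edges[1:]) / 2
    w0 = np.cumsum(hist)  # samples in bins 0..k
    w1 = w0[-1] - w0  # samples in bins k+1..end
    s0 = np.cumsum(hist * centers)
    mu0 = s0 / np.maximum(w0, 1)  # mean of the lower class
    mu1 = (s0[-1] - s0) / np.maximum(w1, 1)  # mean of the upper class
    between = w0 * w1 * (mu0 - mu1) ** 2
    return centers[np.argmax(between)]


def run_midpoints(mask):
    """Midpoint of every contiguous run of True values in a 1D mask."""
    padded = np.concatenate(([False], mask, [False])).astype(int)
    edges = np.diff(padded)
    starts = np.flatnonzero(edges == 1)  # first index of each run
    stops = np.flatnonzero(edges == -1)  # one past the last index of each run
    return (starts + stops - 1) / 2


def detect_trench_midpoints(
    img, x_percentile=85, background_kernel_x=31, smoothing_kernel_x=3, otsu_scaling=1.0
):
    signal = np.percentile(img, x_percentile, axis=0)  # step 1
    background = median_smooth(signal, background_kernel_x)  # step 2
    signal = signal - background  # step 3
    signal = median_smooth(signal, smoothing_kernel_x)  # step 4
    mask = signal > otsu_scaling * otsu_threshold(signal)  # step 5
    return run_midpoints(mask)


# Synthetic crop: three 6-pixel-wide bright trenches on a uniform background
img = np.full((30, 120), 10.0)
for start in (20, 50, 80):
    img[:, start : start + 6] = 110.0

midpoints = detect_trench_midpoints(img)
print(midpoints)
```

The background kernel has to be wide relative to a single trench, otherwise the background estimate follows the trenches themselves and subtracting it removes the signal of interest.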

After this, x-dimension drift correction of our detected midpoints will be performed as follows:

6. Begin at t=1.
7. For each midpoint $m \in \{midpoints(t)\}$, assign to $m$ the midpoint $n \in \{midpoints(t-1)\}$ that is closest to it.<br>
Midpoints at time $t-1$ that are not the closest midpoint to any $m$ will not be mapped.
8. Compute the translation of each mapped midpoint from time $t-1$ to $t$.
9. Take the average of these translations as the x-dimension drift from time $t-1$ to $t$.
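
A minimal sketch of steps 7 to 9 for a single pair of timepoints is shown below. The function `x_drift` is hypothetical, and it has no handling for trenches that appear or disappear between timepoints, which would bias the average.

```python
def x_drift(prev_midpoints, curr_midpoints):
    """Mean shift between each current midpoint and its nearest previous midpoint."""
    shifts = []
    for m in curr_midpoints:
        # Step 7: the closest midpoint at time t-1 is assigned to m
        nearest = min(prev_midpoints, key=lambda n: abs(n - m))
        # Step 8: translation of this midpoint from t-1 to t
        shifts.append(m - nearest)
    # Step 9: the average translation is the drift estimate
    return sum(shifts) / len(shifts)


prev = [22.5, 52.5, 82.5]  # midpoints at time t-1
curr = [24.5, 54.5, 84.5]  # the same trenches at time t, shifted by 2 pixels
drift = x_drift(prev, curr)
print(drift)  # 2.0
```

Because the matching is purely nearest-neighbor, it is only reliable when the drift between consecutive timepoints is small compared with the trench spacing.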

The arguments for this step are:

 - **t (int)** : Timepoint to examine the percentiles and threshold in.

 - **x_percentile (int)** : Percentile to use for step 1.

 - **background_kernel_x (int)** : Median kernel size to use for step 2.

 - **smoothing_kernel_x (int)** : Median kernel size to use for step 4.

 - **otsu_scaling (float)** : Scaling factor to apply to the threshold determined by Otsu's method.

Running the following widget will display the smoothed 1-D signal for each of your timepoints, with the threshold <br>value for each FOV displayed as a red line. It will also display the detected midpoints for each of your timepoints. <br>If there is too much sparsity, or discontinuity, your drift correction will not be accurate.

In [None]:
interactive_kymograph.preview_x_percentiles_interactive()

##### Tune trench cropping hyperparameters

Trench cropping simply uses the drift-corrected midpoints as a reference and crops out some fixed length around them <br>
to produce an output kymograph. **Note that the current implementation does not allow trench crops to overlap**. If your<br>
trench crops do overlap, the error will not be caught here, but will cause issues later in the pipeline. As such, try <br>
to crop your trenches as closely as possible. This issue will be fixed in a later update.
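
Since overlapping crops are not caught at this stage, it can be worth checking for them by hand. The sketch below is one way to do so, assuming crops are centered on integer midpoints and treated as half-open x-ranges; `crop_bounds` and `find_overlaps` are hypothetical helpers, not library functions.

```python
def crop_bounds(midpoints, trench_width_x):
    """Half-open (left, right) x-range of the crop around each midpoint."""
    half = trench_width_x // 2
    return [(int(m) - half, int(m) + half) for m in sorted(midpoints)]


def find_overlaps(bounds):
    """Index pairs of neighboring crops whose x-ranges overlap."""
    return [
        (i, i + 1)
        for i in range(len(bounds) - 1)
        if bounds[i][1] > bounds[i + 1][0]
    ]


midpoints = [22, 52, 82]  # trenches spaced 30 pixels apart
ok = find_overlaps(crop_bounds(midpoints, trench_width_x=20))
bad = find_overlaps(crop_bounds(midpoints, trench_width_x=40))
print(ok)   # []
print(bad)  # [(0, 1), (1, 2)]
```

In this example a crop width of 20 pixels is safe, while 40 pixels exceeds the trench spacing and makes every neighboring pair overlap.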

The arguments for this step are:

 - **trench_width_x (int)** : Trench width to use for cropping.

 - **trench_present_thr (float)** : Trenches that appear in less than this percent of FOVs will be eliminated from the dataset.<br>
If not removed, missing positions will be inferred from the image drift.

 - **Use Median Drift? (bool)** : Whether to use the median detected drift across all FOVs, instead of the drift detected in each FOV individually.


Running the following widget will display a random kymograph for each row in each FOV and will also produce midpoint plots <br>showing retained midpoints.

In [None]:
interactive_kymograph.preview_kymographs_interactive()

##### Export and save hyperparameters

Run the following line to register and display the parameters you have selected for kymograph creation.

In [None]:
interactive_kymograph.process_results()

If you are satisfied with the above parameters, run the following line to write these parameters to disk at `headpath/kymograph.par`.<br>
This file will be used to perform kymograph creation in the next section.

In [None]:
interactive_kymograph.write_param_file()

## Notes on improved scheduling

Make a utility that will be able to run analysis steps without intervention from the ipynb, so things can be run as "fire and forget"

I will start by adapting this to the simple single step case of the lineage trace. In order to fire off an independent job that will deploy, maintain and close the cluster dynamically for a single application I need the following information from the process:

- The estimated time for the process to complete so that the scheduler head knows how long it needs to be queued up
- The estimated minimum memory in the pool to execute the task successfully

I should be able to get this information as a method from the process.

In addition, this scheduler should attempt to minimally impact my fairshare. Features that would promote this would be:

- Don't overschedule the time usage of the process. Queue up workers with short wall times (30 mins) and just maintain a constant target size. Currently implementing this...
- Setting some kind of minimum worker number to ensure progression requires the proper amount of memory to be in place. Try setting the adaptive minimum based on memory considerations. Forget this for now, the adaptive scheduler might be able to handle this...

Implement headless scheduler with fixed resources and task before integrating with the requirement of the job itself...

Headless scheduler doesn't seem worth it....continue using with the adaptive scheduling adjustment and see if that helps...

Couldn't find an easy way to dynamically adapt the memory per worker.

### Generate Kymograph

##### Start Dask Workers

Again, we start a `dask_controller` instance which will handle all of our parallel processing. The default parameters <br>here work well on O2 for kymograph creation. The critical arguments here are:

**walltime** : For a cluster, the length of time you will request each node for.

**local** : `True` if you want to perform computation locally. `False` if you want to perform it on a SLURM cluster.

**n_workers** : Number of nodes to request if on the cluster, or number of processes if computing locally.

**memory** : For a cluster, the amount of memory you will request each node for.

**working_directory** : For a cluster, the directory in which data will be spilled to disk. Usually set as a folder in <br>the `headpath`.

After running the above line, you will have a running Dask client. Run the line below and click the link to supervise <br>the computation being administered by the scheduler. 

Don't be alarmed if the screen starts mostly blank; it may take time for your workers to spin up. If you get a 404 <br>error on a cluster, it is likely that your ports are not being forwarded properly. If this occurs, please register <br>the issue on GitHub.

##### Perform Kymograph Cropping

Now that we have our cluster scheduler spun up, we will extract kymographs using the parameters stored in `headpath/kymograph.par`. <br>
This will be handled by the `kymograph_cluster` object. This will detect trenches in all of the files present in `headpath/hdf5` that <br>
you created in the first step. It will then crop these trenches and place the crops in a series of `.hdf5` files in `headpath/kymograph`. <br>
These files will store image data in the form of `(K,T,Y,X)` arrays where K is the trench index, T is time and Y,X are the image dimensions <br>
of the crop.

The arguments for this step are:

 - **headpath** : The folder in which processing is occurring. Should be the same for each step in the pipeline.

 - **trenches_per_file** : The maximum number of trenches stored in each output `.hdf5` file. Typical values are between 25 <br>and 100.

 - **paramfile** : Set to true if you want to use parameters from `headpath/kymograph.par`. Otherwise, you will have to specify <br>
 parameters as direct arguments to `kymograph_cluster`.

In [None]:
kymoclust = tr.kymograph.kymograph_cluster(headpath=headpath, paramfile=True)

##### Begin Kymograph Cropping 

Running the following line will start the cropping process. This may be monitored by examining the `Dask Dashboard` <br>
under the link displayed earlier. Once the computation is complete, move to the next line.

**Do not move on until all tasks are displayed as 'in memory' in Dask.**

In [None]:
kymoclust.generate_kymographs(dask_controller)

In [None]:
ff = tr.focus_filter(headpath)

In [None]:
ff.choose_filter_channel_inter()

In [None]:
ff.plot_histograms(intensity_range=(0, 1500))

In [None]:
ff.plot_focus_threshold_inter()

In [None]:
ff.write_param_file()

##### Post-process Images

After the above step, kymographs will have been created for each `.hdf5` input file. They will now need to be reorganized <br>
into a new set of files such that each file contains at most `trenches_per_file` trenches.
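
The regrouping can be pictured with the sketch below, which assigns every trench a position in the new set of files. The function `assign_output_files` is hypothetical and only illustrates the bookkeeping; it does not reflect how `post_process` is implemented.

```python
def assign_output_files(trenches_per_input_file, trenches_per_file):
    """Map each trench to (input file, index in input, output file, index in output)."""
    assignments = []
    global_idx = 0
    for input_file, n_trenches in enumerate(trenches_per_input_file):
        for local_idx in range(n_trenches):
            # Output files are filled in order, trenches_per_file at a time
            out_file, out_idx = divmod(global_idx, trenches_per_file)
            assignments.append((input_file, local_idx, out_file, out_idx))
            global_idx += 1
    return assignments


# Three input files holding 3, 2 and 4 trenches, regrouped 4 trenches per file
assignments = assign_output_files([3, 2, 4], trenches_per_file=4)
n_output_files = assignments[-1][2] + 1
print(n_output_files)  # 3
```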

**Do not move on until all tasks are displayed as 'in memory' in Dask.**

In [None]:
kymoclust.post_process(dask_controller, trench_timepoints_per_file=1000)

##### Check kymograph statistics

Run the next line to display some statistics from kymograph creation. The outputs are:

 - **fovs processed** : The number of FOVs successfully processed out of the total number of FOVs
 - **rows processed** : The number of rows of trenches processed out of the total number of rows
 - **trenches processed** : The number of trenches successfully processed
 - **row/fov** : The average number of rows successfully processed per FOV
 - **trenches/fov** : The average number of trenches successfully processed per FOV
 - **failed fovs** : A list of failed FOVs. Spot check these FOVs in the viewer to determine potential problems

In [None]:
kymoclust.kymo_report()

##### Shutdown Dask

Once cropping is complete, it is likely that you will want to shut down your `dask_controller` if you are on a <br>
cluster. This is because the specifications of the current `dask_controller` will not be optimal for later steps. <br>
To do this, run the following line and wait for it to complete. If it hangs, interrupt your kernel and re-run it. <br>
If this also fails to shut down your workers, you will have to manually shut them down using `scancel` in a terminal.

In [None]:
dask_controller.daskclient.restart()

In [None]:
dask_controller.shutdown()

## Fluorescence Segmentation

Now that you have cropped your data into kymographs, we will perform segmentation/cell detection <br>
on your kymographs. Currently, this pipeline only supports segmentation of fluorescence images; however, <br>
segmentation of transmitted light imaging techniques is in development.

The output of this step will be a set of `segmentation_[File #].hdf5` files stored in `headpath/fluorsegmentation`.<br>
The image data stored in these files takes the exact same form as the kymograph data, `(K,T,Y,X)` arrays <br>
where K is the trench index, T is time, and Y,X are the crop dimensions. These arrays are accessible using <br>
keys of the form `"[Trench Row Number]"`.

Since no metadata is generated by this step, it is possible to use another segmentation algorithm on the kymograph <br>
data. The output of segmentation must be split into `segmentation_[File #].hdf5` files, where `[File #]` agrees with the<br>
corresponding `kymograph_[File #].hdf5` file. Additionally, the `(K,T,Y,X)` arrays must be of the same shape as the <br>
kymograph arrays and accessible at the corresponding `"[Trench Row Number]"` key. These files must be placed into <br>
their own folder at `headpath/foldername`. This folder may then be used in later steps.
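
Before pointing later steps at an externally generated segmentation folder, it may help to confirm that the file numbering agrees. The sketch below checks only the file names described above; `file_numbers` and `check_pairing` are hypothetical helpers, and in practice the name lists would come from something like `os.listdir` on the two folders. Array shapes and keys would still need to be checked separately.

```python
import re


def file_numbers(filenames, prefix):
    """Extract the integer [File #] from names like '<prefix>_<File #>.hdf5'."""
    pattern = re.compile(rf"{re.escape(prefix)}_(\d+)\.hdf5")
    matches = (pattern.fullmatch(name) for name in filenames)
    return {int(m.group(1)) for m in matches if m is not None}


def check_pairing(kymograph_files, segmentation_files):
    """Return (file numbers missing a segmentation, segmentations without a kymograph)."""
    kymo = file_numbers(kymograph_files, "kymograph")
    seg = file_numbers(segmentation_files, "segmentation")
    return sorted(kymo - seg), sorted(seg - kymo)


missing, extra = check_pairing(
    ["kymograph_0.hdf5", "kymograph_1.hdf5", "kymograph_2.hdf5"],
    ["segmentation_0.hdf5", "segmentation_2.hdf5", "segmentation_5.hdf5"],
)
print(missing)  # [1]
print(extra)    # [5]
```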

### Test Parameters

##### Initialize the interactive segmentation class

As a first step, initialize the `tr.fluo_segmentation_interactive` class that will be handling all steps of generating a segmentation. 

In [None]:
interactive_segmentation = tr.fluo_segmentation_interactive(headpath)

##### Choose channel to segment on

In [None]:
interactive_segmentation.choose_seg_channel_inter()

#### Import data

Fill in 

You will need to tune the following `args` and `kwargs` (in order):

**fov_idx (int)** :

**n_trenches (int)** :

**t_range (tuple)** :

**t_subsample_step (int)** :

In [None]:
interactive_segmentation.import_array_inter()

##### Process data

In [None]:
interactive_segmentation.plot_processed_inter()

#### Determine Cell Mask Envelope

Fill in.

You will need to tune the following `args` and `kwargs` (in order):

**cell_mask_method (str)** : Thresholding method, can be a local or global Otsu threshold.

**cell_otsu_scaling (float)** : Scaling factor applied to determined threshold.

**local_otsu_r (int)** : Radius of thresholding kernel used in the local otsu thresholding.

In [None]:
Local Threshold Method: otsu
Background Threshold Method: triangle
Global Threshold: 20
Local Window Size: 15
Otsu Scaling: 1.0
Niblack K: 0.2
Background Threshold Scaling: 1.0
Minimum Object Size: 20

In [None]:
interactive_segmentation.plot_cell_mask_inter()

In [None]:
interactive_segmentation.plot_eig_mask_inter()

In [None]:
interactive_segmentation.plot_dist_mask_inter()

In [None]:
interactive_segmentation.plot_marker_mask_inter()

In [None]:
interactive_segmentation.process_results()

In [None]:
interactive_segmentation.write_param_file()

### Generate Segmentation

In [None]:
segment = tr.segment.fluo_segmentation_cluster(headpath, paramfile=True)

In [None]:
segment.dask_segment(dask_controller)

## Lineage Tracing

### Test Parameters

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
score_function = tr.tracking.scorefn(
    headpath,
    "fluorsegmentation",
    u_size=0.04,
    sig_size=0.02,
    u_pos=0.04,
    sig_pos=0.02,
    w_merge=0.0,
)

In [None]:
score_function.interactive_scorefn()

In [None]:
Tracking_Solver = tr.tracking.tracking_solver(
    headpath,
    "fluorsegmentation",
    ScoreFn=score_function,
    edge_limit=2,
)
data, orientation = score_function.output.result

In [None]:
Tracking_Solver.interactive_tracking(data, orientation)

In [None]:
Tracking_Solver.save_params()

In [None]:
dask_controller.shutdown()

### Generate Lineage Traces

In [None]:
dask_controller = tr.trcluster.dask_controller(
    walltime="1:00:00",
    local=False,
    n_workers=400,
    n_workers_min=20,
    memory="2GB",
    working_directory="/home/de64/scratch/de64/dask",
)
dask_controller.startdask()

In [None]:
dask_controller.displaydashboard()

In [None]:
Tracking_Solver = tr.tracking.tracking_solver(
    headpath,
    "fluorsegmentation",
    paramfile=True,
    volume_estimation=True,
    props_list=["area", "major_axis_length", "minor_axis_length"],
    props_to_unpack={},
    pixel_scaling_factors={"area": 2, "major_axis_length": 1, "minor_axis_length": 1},
    intensity_props_list=["mean_intensity"],
)

In [None]:
Tracking_Solver.compute_all_lineages(dask_controller, entries_per_partition=100000)

In [None]:
dask_controller.shutdown()

## Experimenting with ypet-dnaN analysis

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import scipy as sp
import sklearn as skl
import dask.dataframe as dd
import dask.array as da

import scipy.stats
from sklearn.linear_model import LinearRegression

from matplotlib import pyplot as plt
import holoviews as hv

hv.extension("bokeh")

### Import Lineage

In [None]:
lineage_df = dd.read_parquet(
    "/home/de64/scratch/de64/sync_folder/2022-02-11_DE524_rne/lineage"
)

# temporary fix: cast the cell ID columns to int
lineage_df["CellID"] = lineage_df["CellID"].astype(int)
lineage_df["Global CellID"] = lineage_df["Global CellID"].astype(int)

In [None]:
dask_controller = tr.trcluster.dask_controller(
    walltime="3:00:00",
    local=False,
    n_workers=50,
    n_workers_min=20,
    memory="8GB",
    working_directory="/home/de64/scratch/de64/dask",
)
dask_controller.startdask()

In [None]:
dask_controller.displaydashboard()

### Optimizing Growth Quantification

In [None]:
def filter_df(df, query_list, client=False, repartition=False, persist=False):
    # query_list must contain strings in df.query format (see pandas docs);
    # the individual queries are combined with "and"

    # returns the filtered dataframe, persisted on the cluster or locally if persist=True

    compiled_query = " and ".join(query_list)
    out_df = df.query(compiled_query)
    if persist:
        if client:
            out_df = client.daskclient.persist(out_df)
        else:
            out_df = out_df.persist()

    if repartition:
        init_size = len(df)
        final_size = len(out_df)
        # guard against an empty result, which would otherwise divide by zero
        ratio = max(init_size // max(final_size, 1), 1)
        out_df = out_df.repartition(npartitions=(df.npartitions // ratio) + 1)
        if persist:
            if client:
                out_df = client.daskclient.persist(out_df)
            else:
                out_df = out_df.persist()

    return out_df


def get_first_cell_timepoint(df):
    min_tpts = df.groupby(["Global CellID"])["timepoints"].idxmin().tolist()
    init_cells = df.loc[min_tpts]
    return init_cells


def get_last_cell_timepoint(df):
    max_tpts = df.groupby(["Global CellID"])["timepoints"].idxmax().tolist()
    fin_cells = df.loc[max_tpts]
    return fin_cells


def get_growth_and_division_stats(
    lineage_df,
    kymo_df_path,
    trench_score_thr=-75,
    absolute_time=True,
    delta_t_min=None,
    size_metrics=[
        "area",
        "major_axis_length",
        "minor_axis_length",
        "Volume",
        "Surface Area",
    ],
):

    kymo_df = dd.read_parquet(kymo_df_path)
    kymo_idx_list = lineage_df["Kymograph FOV Parquet Index"].tolist()

    if not absolute_time:
        kymo_df["time (s)"] = kymo_df["timepoints"] * delta_t_min * 60.0

    kymo_time_series = (
        kymo_df["time (s)"].loc[kymo_idx_list].compute(scheduler="threads")
    )
    kymo_time_series.index = lineage_df.index
    lineage_df["time (s)"] = kymo_time_series

    reference = filter_df(lineage_df, ["`Trench Score` < " + str(trench_score_thr)])
    query = filter_df(
        lineage_df,
        [
            "`Mother CellID` != -1",
            "`Daughter CellID 1` != -1",
            "`Daughter CellID 2` != -1",
            "`Sister CellID` != -1",
            "`Trench Score` < " + str(trench_score_thr),
        ],
    )

    init_cells = (
        get_first_cell_timepoint(query)
        .reset_index()
        .set_index("Global CellID")
        .sort_index()
    )
    fin_cells = (
        get_last_cell_timepoint(query)
        .reset_index()
        .set_index("Global CellID")
        .sort_index()
    )

    cell_min_tpt_df = (
        get_first_cell_timepoint(reference)
        .reset_index()
        .set_index("Global CellID")
        .sort_index()
    )
    cell_max_tpt_df = (
        get_last_cell_timepoint(reference)
        .reset_index()
        .set_index("Global CellID")
        .sort_index()
    )

    mother_df = cell_max_tpt_df.loc[init_cells["Mother CellID"].tolist()]
    sister_df = cell_min_tpt_df.loc[init_cells["Sister CellID"].tolist()]
    daughter_1_df = cell_min_tpt_df.loc[fin_cells["Daughter CellID 1"].tolist()]
    daughter_2_df = cell_min_tpt_df.loc[fin_cells["Daughter CellID 2"].tolist()]

    for metric in size_metrics:

        if metric == "minor_axis_length":

            init_cells["Birth: " + metric] = init_cells[metric].values
            init_cells["Division: " + metric] = fin_cells[metric].values
            init_cells["Delta: " + metric] = (
                fin_cells[metric].values - init_cells[metric].values
            )

        else:

            interp_mother_final_size = (
                (init_cells[metric].values + sister_df[metric].values)
                * mother_df[metric].values
            ) ** (1 / 2)
            sister_frac = init_cells[metric].values / (
                sister_df[metric].values + init_cells[metric].values
            )
            init_cells["Birth: " + metric] = sister_frac * interp_mother_final_size

            init_cells["Division: " + metric] = (
                (daughter_1_df[metric].values + daughter_2_df[metric].values)
                * fin_cells[metric].values
            ) ** (1 / 2)

            init_cells["Delta: " + metric] = (
                init_cells["Division: " + metric].values
                - init_cells["Birth: " + metric].values
            )

    init_cells["Final timepoints"] = daughter_1_df[
        "timepoints"
    ].values  # counting a timepoint in which a division occurs as a full timepoint, hacky
    init_cells["Delta Timepoints"] = (
        init_cells["Final timepoints"] - init_cells["timepoints"]
    )

    # if absolute_time:
    interpolated_final_time = (
        fin_cells["time (s)"].values + daughter_1_df["time (s)"].values
    ) / 2  # interpolating under the same assumptions as the size quantification
    interpolated_init_time = (
        init_cells["time (s)"].values + mother_df["time (s)"].values
    ) / 2
    init_cells["Final time (s)"] = interpolated_final_time
    init_cells["Delta time (s)"] = interpolated_final_time - interpolated_init_time

    query = (
        query.reset_index()
        .set_index(["Global CellID", "timepoints"])
        .sort_index()
        .reset_index(level=1)
    )

    # if absolute_time:

    delta_t_series = query.groupby(level=0, sort=False)["time (s)"].apply(
        lambda x: ((x[1:].values - x[:-1].values))
    )

    init_time_gap = init_cells["time (s)"].values - interpolated_init_time
    final_time_gap = interpolated_final_time - fin_cells["time (s)"].values

    for size_metric in size_metrics:  # Haven't decided between mean and median
        init_size = query.groupby(level=0, sort=False)[size_metric].apply(
            lambda x: x.iloc[0]
        )
        final_size = query.groupby(level=0, sort=False)[size_metric].apply(
            lambda x: x.iloc[-1]
        )

        init_linear_gr = init_size - (init_cells["Birth: " + size_metric].values)
        init_linear_gr = init_linear_gr / init_time_gap
        final_linear_gr = (init_cells["Division: " + size_metric].values) - final_size
        final_linear_gr = final_linear_gr / final_time_gap

        init_exp_gr = 2 * (
            (init_size - (init_cells["Birth: " + size_metric].values))
            / (init_size + (init_cells["Birth: " + size_metric].values))
        )
        init_exp_gr = init_exp_gr / init_time_gap
        final_exp_gr = 2 * (
            ((init_cells["Division: " + size_metric].values) - final_size)
            / ((init_cells["Division: " + size_metric].values) + final_size)
        )
        final_exp_gr = final_exp_gr / final_time_gap

        all_linear_gr = query.groupby(level=0, sort=False)[size_metric].apply(
            lambda x: x[1:].values - x[:-1].values
        )  ##needs to interpolate last growth rate
        all_linear_gr = all_linear_gr / delta_t_series
        all_linear_gr = all_linear_gr.apply(lambda x: x.tolist())
        all_linear_gr = all_linear_gr.to_frame()
        all_linear_gr = all_linear_gr.rename(columns={0: "Main List"})
        all_linear_gr["Start"] = init_linear_gr
        all_linear_gr["End"] = final_linear_gr
        all_linear_gr["Appended"] = all_linear_gr.apply(
            lambda x: [x["Start"]] + x["Main List"] + [x["End"]], axis=1
        )
        mean_linear_gr = all_linear_gr["Appended"].apply(lambda x: np.nanmean(x))
        del all_linear_gr
        mean_linear_gr = mean_linear_gr * 3600  # size unit per hr

        all_exp_gr = query.groupby(level=0, sort=False)[size_metric].apply(
            lambda x: 2
            * ((x[1:].values - x[:-1].values) / (x[1:].values + x[:-1].values))
        )  ##needs to interpolate last growth rate
        all_exp_gr = all_exp_gr / delta_t_series
        all_exp_gr = all_exp_gr.apply(lambda x: x.tolist())
        all_exp_gr = all_exp_gr.to_frame()
        all_exp_gr = all_exp_gr.rename(columns={0: "Main List"})
        all_exp_gr["Start"] = init_exp_gr
        all_exp_gr["End"] = final_exp_gr
        all_exp_gr["Appended"] = all_exp_gr.apply(
            lambda x: [x["Start"]] + x["Main List"] + [x["End"]], axis=1
        )
        mean_exp_gr = all_exp_gr["Appended"].apply(lambda x: np.nanmean(x))
        del all_exp_gr
        mean_exp_gr = mean_exp_gr * 3600  # size unit per hr

        mean_cell_size_metric = query.groupby(level=0, sort=False)[size_metric].apply(
            lambda x: np.nanmean(x.values)
        )

        init_cells["Mean: " + size_metric] = mean_cell_size_metric
        init_cells["Mean Linear Growth Rate: " + size_metric] = mean_linear_gr
        init_cells["Mean Exponential Growth Rate: " + size_metric] = mean_exp_gr
    #     else:
    #         for size_metric in size_metrics: # Haven't decided between mean and median
    #             mean_linear_gr = query.groupby(level=0,sort=False)[size_metric].apply(lambda x: np.nanmean(x[1:].values - x[:-1].values))
    #             mean_linear_gr = (mean_linear_gr/delta_t_min)*60 #size unit per hr
    #             mean_exp_gr = query.groupby(level=0,sort=False)[size_metric].apply(lambda x: np.nanmean((2*(x[1:].values - x[:-1].values))/(x[1:].values + x[:-1].values)))
    #             mean_exp_gr = (mean_exp_gr/delta_t_min)*60 #exponential size unit per hr
    #             mean_cell_size_metric = query.groupby(level=0,sort=False)[size_metric].apply(lambda x: np.nanmean(x.values))

    #             init_cells["Mean: " + size_metric] = mean_cell_size_metric
    #             init_cells["Mean Linear Growth Rate: " + size_metric] = mean_linear_gr
    #             init_cells["Mean Exponential Growth Rate: " + size_metric] = mean_exp_gr

    # per-cell mean (not median) of the mCherry intensity, ignoring NaNs
    mean_mchy_intensity = query.groupby("Global CellID")[
        "mCherry mean_intensity"
    ].apply(lambda x: np.nanmean(x.values))
    init_cells["Mean: mCherry Intensity"] = mean_mchy_intensity

    init_cells = init_cells.rename(columns={"timepoints": "initial timepoints"})

    return init_cells


def get_all_growth_and_division_stats(
    lineage_df,
    kymo_df_path,
    trench_score_thr=-75,
    absolute_time=True,
    delta_t_min=None,
    size_metrics=[
        "area",
        "major_axis_length",
        "minor_axis_length",
        "Volume",
        "Surface Area",
    ],
):
    test_partition = lineage_df.get_partition(0).compute()
    test_partition = get_growth_and_division_stats(
        test_partition,
        kymo_df_path,
        trench_score_thr=trench_score_thr,
        absolute_time=absolute_time,
        delta_t_min=delta_t_min,
        size_metrics=size_metrics,
    )

    growth_div_df = dd.map_partitions(
        get_growth_and_division_stats,
        lineage_df,
        kymo_df_path,
        trench_score_thr=trench_score_thr,
        absolute_time=absolute_time,
        delta_t_min=delta_t_min,
        size_metrics=size_metrics,
        meta=test_partition,
    )

    return growth_div_df
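
The birth-size interpolation inside `get_growth_and_division_stats` can be hard to read in its vectorized form. The following is a minimal sketch of the same arithmetic for a single cell, written in plain Python to stay dependency-free. The function name `interpolate_birth_size` and the example sizes are hypothetical and not part of `TrenchRipper`.

```python
import math


def interpolate_birth_size(cell_first, sister_first, mother_last):
    """Estimate a cell's size at the division that created it.

    cell_first, sister_first: sizes of the two siblings at their first timepoint.
    mother_last: size of the mother at her last timepoint.
    """
    # The geometric mean of the mother's last size and the summed sibling sizes
    # approximates the mother's size at the (unobserved) division instant.
    mother_at_division = math.sqrt((cell_first + sister_first) * mother_last)
    # Split that size between the siblings in proportion to their observed sizes.
    cell_fraction = cell_first / (cell_first + sister_first)
    return cell_fraction * mother_at_division


# Symmetric division: siblings of size 2.0 each, mother last seen at 4.0.
birth_symmetric = interpolate_birth_size(2.0, 2.0, 4.0)
# Asymmetric division: siblings of 1.0 and 3.0, mother last seen at 9.0.
birth_asymmetric = interpolate_birth_size(1.0, 3.0, 9.0)
print(birth_symmetric, birth_asymmetric)
```

The division size in the function follows the same pattern, using the two daughters' first sizes and the cell's own last size.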

In [None]:
test = lineage_df["Trench Score"].compute()

In [None]:
moo = (
    lineage_df.groupby("trenchid").apply(lambda x: x.iloc[0]["Trench Score"]).compute()
)

In [None]:
growth_div_df = get_all_growth_and_division_stats(
    lineage_df,
    "/home/de64/scratch/de64/sync_folder/2022-02-11_DE524_rne/kymograph/metadata",
    absolute_time=False,
    delta_t_min=5.0,
    trench_score_thr=-85,
).persist()
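
The "exponential" growth rate computed by `get_growth_and_division_stats` appears to be a symmetric relative difference, `2 * (b - a) / (b + a)`, divided by the time step. As a rough check of that reading, the sketch below compares it with the exact log-ratio rate for one frame at `delta_t_min=5.0`. The helper names and the example sizes are hypothetical.

```python
import math


def relative_growth_rate(size_start, size_end, dt_seconds):
    """Symmetric relative difference per second: 2 * (b - a) / (b + a) / dt."""
    return 2.0 * (size_end - size_start) / (size_end + size_start) / dt_seconds


def log_growth_rate(size_start, size_end, dt_seconds):
    """Exact exponential rate for comparison: ln(b / a) / dt."""
    return math.log(size_end / size_start) / dt_seconds


dt = 5.0 * 60.0  # one frame at delta_t_min=5.0, in seconds
approx_per_hr = relative_growth_rate(1.0, 1.1, dt) * 3600  # per hour
exact_per_hr = log_growth_rate(1.0, 1.1, dt) * 3600  # per hour
print(approx_per_hr, exact_per_hr)
```

For a 10% size increase per frame the two values agree to within about 0.1%, so the approximation should be adequate when frame-to-frame growth is small.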

In [None]:
growth_div_df.to_parquet(
    "/home/de64/scratch/de64/sync_folder/2022-02-11_DE524_rne/2022-02-24_growth_division_df",
    engine="pyarrow",
    overwrite=True,
)

In [None]:
growth_div_df = dd.read_parquet(
    "/home/de64/scratch/de64/sync_folder/2022-02-11_DE524_rne/2022-02-24_growth_division_df"
)
growth_div_df = growth_div_df.reset_index().set_index("trenchid")

In [None]:
mean_yfp = (
    growth_div_df.groupby("trenchid")
    .apply(lambda x: np.mean(x["YFP mean_intensity"]))
    .compute()
)
bright_trenchids = mean_yfp[mean_yfp > 1450].index.tolist()

In [None]:
growth_div_df = growth_div_df.loc[bright_trenchids].persist()

In [None]:
plt.hist(mean_yfp, bins=50)
plt.show()

In [None]:
def remove_early_outliers(
    final_output_df_lineage,
    trenchid_name="phenotype trenchid",
    early_timepoint_cutoff=12,
    gaussian_subsample=0.2,
    percentile_threshold=10,
    filter_params=[
        "Mean Linear Growth Rate: Volume",
        "Mean Exponential Growth Rate: Volume",
        "Division: major_axis_length",
        "Mean: minor_axis_length",
        "Mean: mCherry Intensity",
        "Delta time (s)",
    ],
    plot_values_names=[
        "Volume Growth Rate (linear)",
        "Volume Growth Rate (ratio)",
        "Division Length",
        "Minor Axis Length",
        "Mean mCherry Intensity",
        "Interdivision Time",
    ],
):

    final_output_df_trench_groupby = final_output_df_lineage.groupby(
        trenchid_name, sort=False
    )
    early_tpt_df = final_output_df_trench_groupby.apply(
        lambda x: x[x["Final timepoints"] < early_timepoint_cutoff].reset_index(
            drop=True
        )
    ).persist()
    for filter_param in filter_params:
        early_param_series = early_tpt_df[filter_param]
        all_param_values = (
            early_param_series.sample(frac=gaussian_subsample).compute().tolist()
        )
        gaussian_fit = sp.stats.norm.fit(all_param_values)
        gaussian_fit = sp.stats.norm(loc=gaussian_fit[0], scale=gaussian_fit[1])

        early_param_series = dd.from_pandas(
            early_param_series.compute().droplevel(1), npartitions=50
        )
        trench_probability = early_param_series.groupby(trenchid_name).apply(
            lambda x: np.exp(np.sum(gaussian_fit.logpdf(x)) / len(x)), meta=float
        )

        final_output_df_lineage[
            filter_param + ": Probability"
        ] = trench_probability.persist()

    final_output_df_onetrench = (
        final_output_df_lineage.groupby(trenchid_name)
        .apply(lambda x: x.iloc[0])
        .compute()
    )

    plt.figure(figsize=(22, 16))
    query_list = []
    for i, filter_param in enumerate(filter_params):
        prob_threshold = np.nanpercentile(
            final_output_df_onetrench[filter_param + ": Probability"].tolist(),
            percentile_threshold,
        )
        query = "`" + filter_param + ": Probability` > " + str(prob_threshold)
        query_list.append(query)

        min_v, max_v = (
            np.min(final_output_df_onetrench[filter_param + ": Probability"]),
            np.max(final_output_df_onetrench[filter_param + ": Probability"]),
        )

        plt.subplot(2, 3, i + 1)
        plt.title(plot_values_names[i], fontsize=22)
        plt.xlabel("Unnormalized Likelihood", fontsize=18)
        plt.xticks(fontsize=18)
        plt.yticks(fontsize=18)
        plt.hist(
            final_output_df_onetrench[
                final_output_df_onetrench[filter_param + ": Probability"]
                < prob_threshold
            ][filter_param + ": Probability"].tolist(),
            bins=50,
            range=(min_v, max_v),
        )
        plt.hist(
            final_output_df_onetrench[
                final_output_df_onetrench[filter_param + ": Probability"]
                >= prob_threshold
            ][filter_param + ": Probability"].tolist(),
            bins=50,
            range=(min_v, max_v),
        )

    compiled_query = " and ".join(query_list)
    final_output_df_onetrench_filtered = final_output_df_onetrench.query(compiled_query)
    final_output_df_filtered = final_output_df_lineage.loc[
        final_output_df_onetrench_filtered.index.tolist()
    ].persist()

    return final_output_df_filtered
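
The per-trench score in `remove_early_outliers` is `exp` of the mean log-density under a Gaussian fitted to the pooled early values, which is the geometric mean of the density over that trench. The sketch below reproduces this for two hypothetical trenches in plain Python. The helper names and all numbers are illustrative assumptions, and `statistics.pstdev` stands in for the maximum-likelihood scale that `scipy.stats.norm.fit` returns.

```python
import math
import statistics


def gaussian_logpdf(x, mu, sigma):
    """Log-density of a normal distribution with mean mu and std sigma."""
    return -0.5 * math.log(2.0 * math.pi * sigma**2) - (x - mu) ** 2 / (2.0 * sigma**2)


def trench_score(values, mu, sigma):
    """exp(mean log-density): the geometric mean of the Gaussian pdf over a trench."""
    mean_logpdf = sum(gaussian_logpdf(v, mu, sigma) for v in values) / len(values)
    return math.exp(mean_logpdf)


# Hypothetical pooled early-timepoint values used to fit the reference Gaussian.
pooled = [9.0, 9.5, 10.0, 10.0, 10.5, 11.0]
mu = statistics.fmean(pooled)
sigma = statistics.pstdev(pooled)  # population std, i.e. the maximum-likelihood scale

typical_score = trench_score([9.8, 10.1, 10.3], mu, sigma)
outlier_score = trench_score([13.0, 14.0, 13.5], mu, sigma)
print(typical_score, outlier_score)
```

A trench whose early values sit far from the pooled distribution receives a much lower score, which is what the percentile threshold then removes.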

In [None]:
final_output_df_filtered = remove_early_outliers(
    growth_div_df,
    trenchid_name="trenchid",
    early_timepoint_cutoff=30,
    gaussian_subsample=0.5,
    percentile_threshold=30,
)
plt.savefig("Prob_threshold.png", dpi=500)

In [None]:
from statsmodels.nonparametric import kernel_regression
from scipy.stats import iqr
from statsmodels.nonparametric.smoothers_lowess import lowess
import sklearn as skl
from tslearn.clustering import TimeSeriesKMeans
from tslearn.preprocessing import TimeSeriesScalerMeanVariance
import copy


def timeseries_kernel_reg(df, y_label, min_tpt, max_tpt, kernel_bins, kernel_bandwidth):
    def kernel_reg(
        x_arr,
        y_arr,
        start=min_tpt,
        end=max_tpt,
        kernel_bins=kernel_bins,
        kernel_bandwidth=kernel_bandwidth,
    ):
        intervals = np.linspace(start, end, num=kernel_bins, dtype=float)
        w = kernel_regression.KernelReg(
            y_arr,
            x_arr,
            "c",
            reg_type="lc",
            bw=np.array([kernel_bandwidth]),
            ckertype="gaussian",
        ).fit(intervals)[0]
        reg_x, reg_y = (intervals, w)
        return reg_x, reg_y

    kernel_result = df.groupby("trenchid").apply(
        lambda x: kernel_reg(
            (x["Final time (s)"].values - (x["Delta time (s)"].values / 2)),
            x[y_label].values,
        )[1],
        meta=float,
    )
    return kernel_result


def get_all_kernel_regs(
    df, y_label_list, min_tpt=0, max_tpt=36000, kernel_bins=20, kernel_bandwidth=2700
):
    out_df = copy.copy(df)

    for y_label in y_label_list:
        kernel_result = timeseries_kernel_reg(
            out_df, y_label, min_tpt, max_tpt, kernel_bins, kernel_bandwidth
        )
        out_df["Kernel Trace: " + y_label] = kernel_result.persist()

    return out_df
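
`timeseries_kernel_reg` uses a local-constant (`reg_type="lc"`) kernel regression with a Gaussian kernel, evaluated on an evenly spaced time grid. As a rough illustration of what that estimator does, the following is a hand-written Gaussian-weighted average in plain Python. It is a sketch of the technique, not the `statsmodels` implementation, and the function name and data are hypothetical.

```python
import math


def local_constant_regression(x_obs, y_obs, x_eval, bandwidth):
    """Gaussian-kernel weighted average of y_obs at each evaluation point."""
    fitted = []
    for x0 in x_eval:
        # Each observation is weighted by its distance from the evaluation point.
        weights = [math.exp(-0.5 * ((x0 - x) / bandwidth) ** 2) for x in x_obs]
        total = sum(weights)
        fitted.append(sum(w * y for w, y in zip(weights, y_obs)) / total)
    return fitted


# Hypothetical division-event midpoints (s) and one measured value per event.
times = [600.0, 3000.0, 6000.0, 9000.0]
values = [1.0, 1.2, 1.4, 1.6]
grid = [0.0, 4500.0, 9000.0]  # evaluation grid, analogous to the linspace in kernel_reg
trace = local_constant_regression(times, values, grid, bandwidth=2700.0)
flat = local_constant_regression(times, [2.0] * 4, grid, bandwidth=2700.0)
print(trace, flat)
```

Because every fitted value is a weighted average, the trace stays within the range of the observed values, and a larger bandwidth gives a smoother, flatter trace.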

In [None]:
params_to_trace = [
    "Mean Linear Growth Rate: Volume",
    "Mean Exponential Growth Rate: Volume",
    "Birth: major_axis_length",
    "Division: major_axis_length",
    "Birth: Volume",
    "Division: Volume",
    "Birth: Surface Area",
    "Division: Surface Area",
    "Mean: minor_axis_length",
    "Mean: mCherry Intensity",
    "Delta time (s)",
]

In [None]:
final_output_df_filtered

In [None]:
trenchiddf = get_all_kernel_regs(
    final_output_df_filtered,
    params_to_trace,
    min_tpt=0,
    max_tpt=50000,
    kernel_bins=40,
    kernel_bandwidth=2700,
)

In [None]:
50000 - 12500

In [None]:
div_volume_df = (
    trenchiddf["Kernel Trace: Division: Volume"]
    .groupby("trenchid")
    .apply(lambda x: x.iloc[0])
    .compute()
)

In [None]:
div_size_arr = np.stack(div_volume_df.tolist())

In [None]:
plt.plot(np.mean(div_size_arr, axis=0))

In [None]:
trenchiddf_early = trenchiddf[trenchiddf["time (s)"] < 12500]
trenchiddf_late = trenchiddf[trenchiddf["time (s)"] > 37500]

In [None]:
trenchiddf_late.compute()

In [None]:
remaining_file_idx_lookup = (
    trenchiddf.groupby("File Index")
    .apply(lambda x: sorted(x["File Trench Index"].unique().tolist()))
    .compute()
    .sort_index()
    .to_dict()
)

In [None]:
import os
import lmfit
from lmfit.lineshapes import gaussian2d
import h5py
import skimage as sk

In [None]:
with h5py.File(
    "/home/de64/scratch/de64/sync_folder/2022-02-11_DE524_rne/kymograph/kymograph_400.hdf5",
    "r",
) as infile:
    yfp_data = infile["YFP"][:]
with h5py.File(
    "/home/de64/scratch/de64/sync_folder/2022-02-11_DE524_rne/fluorsegmentation/segmentation_400.hdf5",
    "r",
) as infile:
    seg_data = infile["data"][:]

idx = 3
test = yfp_data[idx][0] / 65535

test = sk.filters.gaussian(test, sigma=2.0)

min_sig = 2
max_sig = 8

blobs_log = sk.feature.blob_log(
    test,
    min_sigma=min_sig,
    max_sigma=max_sig,
    num_sigma=max_sig - min_sig,
    threshold=0.001,
)

fig, axes = plt.subplots(1, 1, figsize=(18, 10))
axes.imshow(test)
for blob in blobs_log:
    y, x, r = blob
    c = plt.Circle((x, y), r, color="r", linewidth=2, fill=False)
    axes.add_patch(c)
plt.show()

In [None]:
with h5py.File(
    "/home/de64/scratch/de64/sync_folder/2022-02-11_DE524_rne/kymograph/kymograph_10.hdf5",
    "r",
) as infile:
    yfp_data = infile["YFP"][:]
with h5py.File(
    "/home/de64/scratch/de64/sync_folder/2022-02-11_DE524_rne/fluorsegmentation/segmentation_10.hdf5",
    "r",
) as infile:
    seg_data = infile["data"][:]

In [None]:
idx = 2

yfp_handle = tr.kymo_handle()
yfp_handle.import_wrap(yfp_data[idx])

seg_handle = tr.kymo_handle()
seg_handle.import_wrap(seg_data[idx])

In [None]:
plt.imshow(yfp_handle.return_unwrap())

In [None]:
plt.imshow(seg_handle.return_unwrap())

In [None]:
## get time intervals from length dynamics

### a) Making fork plots (basic)


Reference notes from the scikit-image documentation:

- `orientation` (float): Angle between the 0th axis (rows) and the major axis of the ellipse that has the same second moments as the region, ranging from -pi/2 to pi/2 counter-clockwise.
- `skimage.transform.rotate(image, angle, resize=False, center=None, order=None, mode='constant', cval=0, clip=True, preserve_range=False)`, where `angle` is given in degrees.
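
Since `orientation` is reported in radians while `rotate` expects degrees, the cells in this section convert with `(rp.orientation / (2 * np.pi)) * 360`. A small sketch confirming that this matches the standard radians-to-degrees conversion is shown here. The function name `orientation_to_degrees` is hypothetical.

```python
import math


def orientation_to_degrees(orientation_rad):
    """Convert a regionprops-style orientation (radians) to degrees."""
    return math.degrees(orientation_rad)


# The notebook's manual conversion, written out for comparison.
manual = (math.pi / 4 / (2 * math.pi)) * 360
converted = orientation_to_degrees(math.pi / 4)
# The documented orientation range of -pi/2..pi/2 maps to -90..90 degrees.
limits = (orientation_to_degrees(-math.pi / 2), orientation_to_degrees(math.pi / 2))
print(manual, converted, limits)
```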

In [None]:
import os

In [None]:
yfp_data.shape

In [None]:
file_indices = np.sort(
    [
        int(item.split("_")[1].split(".")[0])
        for item in os.listdir(
            "/home/de64/scratch/de64/sync_folder/2022-01-23_DE511_test/kymograph/"
        )
        if "kymograph" in item
    ]
)

In [None]:
for i in range(yfp_data.shape[0]):
    yfp_data[i] = sk.filters.gaussian(yfp_data[i], sigma=2.0, preserve_range=True)

In [None]:
nan_thr = 3
percentile = 99
t_min = 0
t_max = 60

percentile_traces = []

file_indices = np.sort(
    [
        int(item.split("_")[1].split(".")[0])
        for item in os.listdir(
            "/home/de64/scratch/de64/sync_folder/2022-01-23_DE511_test/kymograph/"
        )
        if "kymograph" in item
    ]
)

for file_idx in file_indices:
    if file_idx not in remaining_file_idx_lookup.keys():
        continue
    trench_indices = remaining_file_idx_lookup[file_idx]
    with h5py.File(
        "/home/de64/scratch/de64/sync_folder/2022-02-11_DE524_rne/kymograph/kymograph_"
        + str(file_idx)
        + ".hdf5",
        "r",
    ) as infile:
        yfp_data = infile["YFP"][:]
    with h5py.File(
        "/home/de64/scratch/de64/sync_folder/2022-02-11_DE524_rne/fluorsegmentation/segmentation_"
        + str(file_idx)
        + ".hdf5",
        "r",
    ) as infile:
        seg_data = infile["data"][:]

    for idx in trench_indices:

        for t in range(t_min, t_max):
            yfp_data[idx, t] = sk.filters.gaussian(
                yfp_data[idx, t], sigma=2.0, preserve_range=True
            )

            rps = sk.measure.regionprops(seg_data[idx][t], yfp_data[idx][t])
            degrees = [(rp.orientation / (2 * np.pi)) * 360 for rp in rps]
            centroids = [rp.centroid for rp in rps]

            for i, degree in enumerate(degrees):
                centroid = centroids[i]
                masked_img = (seg_data[idx][t] == (i + 1)) * yfp_data[idx][t]
                rotated_img = sk.transform.rotate(
                    masked_img, degree, center=centroid[::-1], preserve_range=True
                )
                rotated_img[rotated_img == 0.0] = np.nan

                n_not_nan = np.sum(~np.isnan(rotated_img), axis=1)
                nan_mask = n_not_nan > nan_thr
                nan_masked_rotated_img = rotated_img[nan_mask]

                percentile_trace = np.nanpercentile(
                    nan_masked_rotated_img, percentile, axis=1
                )
                percentile_traces.append(percentile_trace)
# traces differ in length, so store them as an object array
percentile_traces = np.array(percentile_traces, dtype=object)
trace_lens = np.array([len(trace) for trace in percentile_traces])
trace_bins = np.sort(np.unique(trace_lens))
med_traces = []
for trace_len in trace_bins:
    trace_group = np.stack(percentile_traces[trace_lens == trace_len])
    med_trace = np.nanmedian(trace_group, axis=0)
    med_traces.append(med_trace)
max_trace_len = np.max(trace_bins)
trace_dif = max_trace_len - trace_bins
padded_traces = np.array(
    [
        np.pad(
            trace,
            ((trace_dif[i] // 2) + trace_dif[i] % 2, trace_dif[i] // 2),
            constant_values=np.nan,
        )
        for i, trace in enumerate(med_traces)
    ]
)
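
The padding step above centres each median trace within the longest trace length, placing the extra slot on the left when the length difference is odd. The following plain-Python sketch shows that rule on two hypothetical traces; `center_pad` is an illustrative helper, not part of the notebook.

```python
import math


def center_pad(trace, target_len):
    """Pad a trace with NaN on both sides so it is centred in target_len slots."""
    diff = target_len - len(trace)
    left = diff // 2 + diff % 2  # the extra slot goes on the left when diff is odd
    right = diff // 2
    return [math.nan] * left + list(trace) + [math.nan] * right


traces = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0, 8.0, 9.0]]
max_len = max(len(t) for t in traces)
padded = [center_pad(t, max_len) for t in traces]
print(padded)
```

Stacking the padded rows gives a rectangular array in which shorter cells are aligned to the same midpoint as longer ones, which is what the fork-plot image displays.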

In [None]:
trace_lens = np.array([len(trace) for trace in percentile_traces])
trace_bins = np.sort(np.unique(trace_lens))[:100]
med_traces = []
for trace_len in trace_bins:
    trace_group = np.stack(percentile_traces[trace_lens == trace_len])
    med_trace = np.nanmedian(trace_group, axis=0)
    med_traces.append(med_trace)
max_trace_len = np.max(trace_bins)
trace_dif = max_trace_len - trace_bins
padded_traces = np.array(
    [
        np.pad(
            trace,
            ((trace_dif[i] // 2) + trace_dif[i] % 2, trace_dif[i] // 2),
            constant_values=np.nan,
        )
        for i, trace in enumerate(med_traces)
    ]
)

In [None]:
plt.plot(percentile_traces[1])
plt.show()

In [None]:
plt.imshow(padded_traces[:80], vmin=1350, vmax=1800)

In [None]:
nan_thr = 3
percentile = 99
t_min = 150
t_max = 200

percentile_traces = []

file_indices = np.sort(
    [
        int(item.split("_")[1].split(".")[0])
        for item in os.listdir(
            "/home/de64/scratch/de64/sync_folder/2022-01-23_DE511_test/kymograph/"
        )
        if "kymograph" in item
    ]
)

for file_idx in file_indices:
    if file_idx not in remaining_file_idx_lookup.keys():
        continue
    trench_indices = remaining_file_idx_lookup[file_idx]
    with h5py.File(
        "/home/de64/scratch/de64/sync_folder/2022-02-11_DE524_rne/kymograph/kymograph_"
        + str(file_idx)
        + ".hdf5",
        "r",
    ) as infile:
        yfp_data = infile["YFP"][:]
    with h5py.File(
        "/home/de64/scratch/de64/sync_folder/2022-02-11_DE524_rne/fluorsegmentation/segmentation_"
        + str(file_idx)
        + ".hdf5",
        "r",
    ) as infile:
        seg_data = infile["data"][:]

    for idx in trench_indices:

        for t in range(t_min, t_max):
            yfp_data[idx, t] = sk.filters.gaussian(
                yfp_data[idx, t], sigma=2.0, preserve_range=True
            )

            rps = sk.measure.regionprops(seg_data[idx][t], yfp_data[idx][t])
            degrees = [(rp.orientation / (2 * np.pi)) * 360 for rp in rps]
            centroids = [rp.centroid for rp in rps]

            for i, degree in enumerate(degrees):
                centroid = centroids[i]
                masked_img = (seg_data[idx][t] == (i + 1)) * yfp_data[idx][t]
                rotated_img = sk.transform.rotate(
                    masked_img, degree, center=centroid[::-1], preserve_range=True
                )
                rotated_img[rotated_img == 0.0] = np.nan

                n_not_nan = np.sum(~np.isnan(rotated_img), axis=1)
                nan_mask = n_not_nan > nan_thr
                nan_masked_rotated_img = rotated_img[nan_mask]

                percentile_trace = np.nanpercentile(
                    nan_masked_rotated_img, percentile, axis=1
                )
                percentile_traces.append(percentile_trace)
# traces differ in length, so store them as an object array
percentile_traces = np.array(percentile_traces, dtype=object)

In [None]:
trace_lens = np.array([len(trace) for trace in percentile_traces])
trace_bins = np.sort(np.unique(trace_lens))[:100]
med_traces = []
for trace_len in trace_bins:
    trace_group = np.stack(percentile_traces[trace_lens == trace_len])
    med_trace = np.nanmedian(trace_group, axis=0)
    med_traces.append(med_trace)
max_trace_len = np.max(trace_bins)
trace_dif = max_trace_len - trace_bins
padded_traces = np.array(
    [
        np.pad(
            trace,
            ((trace_dif[i] // 2) + trace_dif[i] % 2, trace_dif[i] // 2),
            constant_values=np.nan,
        )
        for i, trace in enumerate(med_traces)
    ]
)

In [None]:
trace_bins[:80]

In [None]:
plt.imshow(padded_traces[30:100], vmin=1350, vmax=1800)

In [None]:
plt.hist(padded_traces.flatten())

### b) Making fork plots (Suckjun localization)

In [None]:
import os
import lmfit
from lmfit.lineshapes import gaussian2d

In [None]:
idx = 11
test = yfp_data[idx][0] / 65535

In [None]:
idx = 10
test = yfp_data[idx][0] / 65535

min_sig = 2
max_sig = 8

blobs_log = sk.feature.blob_log(
    test,
    min_sigma=min_sig,
    max_sigma=max_sig,
    num_sigma=max_sig - min_sig,
    threshold=0.001,
)

fig, axes = plt.subplots(1, 1, figsize=(18, 10))
axes.imshow(test)
for blob in blobs_log:
    y, x, r = blob
    c = plt.Circle((x, y), r, color="r", linewidth=2, fill=False)
    axes.add_patch(c)
plt.show()

In [None]:
blob_pad = 3

# def get_foci_list(img,blobs_log,blob_pad = 3): ## too slow, abandoned for now...
#     xy_foci = []
#     for blob in blobs_log:
#         y_blob, x_blob, r_blob = blob
#         y_blob, x_blob, r_blob = int(y_blob), int(x_blob), int(r_blob)
#         top_left = (y_blob-r_blob,x_blob-r_blob)
#         blob_patch = img[top_left[0]-blob_pad:(top_left[0]+r_blob+blob_pad+1),top_left[1]-blob_pad:(top_left[1]+r_blob+blob_pad+1)]
#         x_mesh,y_mesh = np.meshgrid(range(blob_patch.shape[1]),range(blob_patch.shape[0]))
#         model = lmfit.models.Gaussian2dModel()
#         params = model.guess(blob_patch.flatten(), x_mesh.flatten(), y_mesh.flatten())
#         result = model.fit(blob_patch.flatten(), x=x_mesh.flatten(), y=y_mesh.flatten(), params=params)
#         fit = model.func(x_mesh, y_mesh, **result.best_values)
#         x_fit = top_left[1]+result.best_values['centerx']-blob_pad
#         y_fit = top_left[0]+result.best_values['centery']-blob_pad
#         xy_focus = [x_fit,y_fit]
#         xy_foci.append(xy_focus)
#     return xy_foci

# try taking the blob region, getting regionprops and extracting the weighted centroid, should be basically similar and much faster
def get_foci_list(img, blobs_log):
    disk_img = np.zeros(img.shape, dtype=bool)
    for blob in blobs_log:
        y_blob, x_blob, r_blob = blob
        y_blob, x_blob, r_blob = int(y_blob), int(x_blob), int(r_blob)

        disk_mask = sk.morphology.disk(r_blob, dtype=bool)

        disk_left = x_blob - r_blob
        disk_right = x_blob + r_blob + 1
        disk_top = y_blob - r_blob
        disk_bottom = y_blob + r_blob + 1

        cropped_disk_left = disk_left >= 0
        cropped_disk_right = disk_right <= disk_img.shape[1]
        cropped_disk_top = disk_top >= 0
        cropped_disk_bottom = disk_bottom <= disk_img.shape[0]

        if (
            cropped_disk_left
            and cropped_disk_right
            and cropped_disk_top
            and cropped_disk_bottom
        ):
            disk_img[disk_top:disk_bottom, disk_left:disk_right] = disk_mask[
                : (r_blob * 2) + 1, : (r_blob * 2) + 1
            ]

    disk_img = sk.measure.label(disk_img)
    rps = sk.measure.regionprops(disk_img, img)
    weighted_centroids = [rp.centroid_weighted for rp in rps]

    return weighted_centroids

    #     blob_patch = img[top_left[0]-blob_pad:(top_left[0]+r_blob+blob_pad+1),top_left[1]-blob_pad:(top_left[1]+r_blob+blob_pad+1)]
    #     x_mesh,y_mesh = np.meshgrid(range(blob_patch.shape[1]),range(blob_patch.shape[0]))
    #     model = lmfit.models.Gaussian2dModel()
    #     params = model.guess(blob_patch.flatten(), x_mesh.flatten(), y_mesh.flatten())
    #     result = model.fit(blob_patch.flatten(), x=x_mesh.flatten(), y=y_mesh.flatten(), params=params)
    #     fit = model.func(x_mesh, y_mesh, **result.best_values)
    #     x_fit = top_left[1]+result.best_values['centerx']-blob_pad
    #     y_fit = top_left[0]+result.best_values['centery']-blob_pad
    #     xy_focus = [x_fit,y_fit]
    #     xy_foci.append(xy_focus)
    # return xy_foci
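
`get_foci_list` replaces the 2D Gaussian fit with the intensity-weighted centroid of each blob region. As a minimal sketch of that quantity, the plain-Python function below computes the weighted centroid of a small hypothetical patch. Unlike `regionprops`, it weights every pixel in the patch rather than only those inside a labelled disk.

```python
def weighted_centroid(patch):
    """Return the intensity-weighted (row, col) centroid of a 2D list of intensities."""
    total = 0.0
    row_sum = 0.0
    col_sum = 0.0
    for r, row in enumerate(patch):
        for c, value in enumerate(row):
            total += value
            row_sum += r * value
            col_sum += c * value
    return row_sum / total, col_sum / total


# Hypothetical 3x3 patch with a dim pixel at (1, 1) and a brighter one at (1, 2).
patch = [
    [0.0, 0.0, 0.0],
    [0.0, 1.0, 3.0],
    [0.0, 0.0, 0.0],
]
centroid = weighted_centroid(patch)
print(centroid)
```

The centroid is pulled toward the brighter pixel, giving a sub-pixel focus position without an iterative fit.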

In [None]:
sk.morphology.disk(4, dtype=bool)

In [None]:
plt.imshow(test)
plt.scatter(xy_foci[:, 0], xy_foci[:, 1], color="r")

In [None]:
disk_img = np.zeros(norm_yfp_data.shape, dtype=bool)

In [None]:
plt.imshow(disk_img)

In [None]:
y_blob

In [None]:
x_blob

In [None]:
r_blob

In [None]:
disk_mask[:, :]

In [None]:
disk_img.shape

In [None]:
disk_left = x_blob - r_blob
disk_right = x_blob + r_blob
disk_top = y_blob - r_blob
disk_bottom = y_blob + r_blob

cropped_disk_left = max(0, disk_left)
cropped_disk_right = min(disk_img.shape[1], disk_right)
cropped_disk_top = max(0, disk_top)
cropped_disk_bottom = min(disk_img.shape[0], disk_bottom)

del_left = cropped_disk_left - disk_left
del_right = disk_right - cropped_disk_right
del_top = cropped_disk_top - disk_top
del_bottom = disk_bottom - cropped_disk_bottom

In [None]:
disk_mask = disk_mask[
    cropped_disk_top : cropped_disk_bottom + 1,
    cropped_disk_left : cropped_disk_right + 1,
]

In [None]:
disk_img = disk_img[
    cropped_disk_top : cropped_disk_bottom + 1,
    cropped_disk_left : cropped_disk_right + 1,
]

In [None]:
disk_mask.shape

In [None]:
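# Clip the destination window to the image, then copy the matching disk window.
row_stop = min(disk_img.shape[0], disk_bottom + 1)
col_stop = min(disk_img.shape[1], disk_right + 1)
disk_img[cropped_disk_top:row_stop, cropped_disk_left:col_stop] = disk_mask[
    cropped_disk_top - disk_top : row_stop - disk_top,
    cropped_disk_left - disk_left : col_stop - disk_left,
]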

In [None]:
disk_mask.shape

In [None]:
disk_right - cropped_disk_right

In [None]:
disk_img.shape[1]

In [None]:
for blob in blobs_log:
    y_blob, x_blob, r_blob = blob
    y_blob, x_blob, r_blob = int(y_blob), int(x_blob), int(r_blob)

    disk_mask = sk.morphology.disk(r_blob, dtype=bool)

    disk_left = x_blob - r_blob
    disk_right = x_blob + r_blob
    disk_top = y_blob - r_blob
    disk_bottom = y_blob + r_blob

    # Clip the disk's bounding box to the image; all four bounds are inclusive.
    cropped_disk_left = max(0, disk_left)
    cropped_disk_right = min(disk_img.shape[1] - 1, disk_right)
    cropped_disk_top = max(0, disk_top)
    cropped_disk_bottom = min(disk_img.shape[0] - 1, disk_bottom)

    # Number of disk rows/columns cut off on each side.
    del_left = cropped_disk_left - disk_left
    del_right = disk_right - cropped_disk_right
    del_top = cropped_disk_top - disk_top
    del_bottom = disk_bottom - cropped_disk_bottom

    # OR the in-bounds part of the disk in, so earlier overlapping disks are kept.
    disk_img[
        cropped_disk_top : cropped_disk_bottom + 1,
        cropped_disk_left : cropped_disk_right + 1,
    ] |= disk_mask[
        del_top : (r_blob * 2) - del_bottom + 1,
        del_left : (r_blob * 2) - del_right + 1,
    ]
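
The loop above can be condensed into a small helper. The sketch below is a hypothetical function (`paint_disks` is not part of `TrenchRipper`) that assumes each blob is given as `(row, col, radius)` and that disks partly outside the image should simply be clipped.

```python
import numpy as np
from skimage.morphology import disk


def paint_disks(shape, blobs):
    """Return a boolean mask with one filled disk per (row, col, radius) blob.

    Hypothetical helper for illustration only.
    """
    mask = np.zeros(shape, dtype=bool)
    for row, col, radius in blobs:
        row, col, radius = int(row), int(col), int(radius)
        footprint = disk(radius, dtype=bool)  # square of side 2 * radius + 1

        # Inclusive bounding box of the disk, clipped to the image.
        top, bottom = max(0, row - radius), min(shape[0] - 1, row + radius)
        left, right = max(0, col - radius), min(shape[1] - 1, col + radius)
        if top > bottom or left > right:
            continue  # disk lies entirely outside the image

        # Matching window inside the footprint.
        f_top = top - (row - radius)
        f_left = left - (col - radius)
        height, width = bottom - top + 1, right - left + 1

        # OR keeps pixels already set by overlapping disks.
        mask[top : bottom + 1, left : right + 1] |= footprint[
            f_top : f_top + height, f_left : f_left + width
        ]
    return mask
```

A call such as `disk_img = paint_disks(norm_yfp_data.shape, blobs_log)` would stand in for the loop. Note that `sk.feature.blob_log` reports a sigma in its third column; for 2D images the blob radius is approximately `sqrt(2) * sigma`, so rescaling that column first may be appropriate.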

In [None]:
# NOTE: indices are listed from the 2022-01-23_DE511_test folder, but the files
# below are read from 2022-02-11_DE524_rne.
file_indices = np.sort([int(item.split("_")[1].split(".")[0]) for item in os.listdir("/home/de64/scratch/de64/sync_folder/2022-01-23_DE511_test/kymograph/") if 'kymograph' in item])

for file_idx in file_indices:
    if file_idx not in remaining_file_idx_lookup:
        continue
    trench_indices = remaining_file_idx_lookup[file_idx]
    with h5py.File("/home/de64/scratch/de64/sync_folder/2022-02-11_DE524_rne/kymograph/kymograph_" + str(file_idx) + ".hdf5", "r") as infile:
        yfp_data = infile["YFP"][:]
    with h5py.File("/home/de64/scratch/de64/sync_folder/2022-02-11_DE524_rne/fluorsegmentation/segmentation_" + str(file_idx) + ".hdf5", "r") as infile:
        seg_data = infile["data"][:]  # same dataset name as in the full loop below

In [None]:
nan_thr = 3
min_sig = 3
max_sig = 8
brightness_threshold = 0.01

t_min = 0
t_max = 30

ttl_imgs = 0
bright_imgs = 0

rel_xy_foci_in_seg_list = []

file_indices = np.sort(
    [
        int(item.split("_")[1].split(".")[0])
        for item in os.listdir(
            "/home/de64/scratch/de64/sync_folder/2022-02-11_DE524_rne/kymograph/"
        )
        if "kymograph" in item
    ]
)
file_indices = file_indices[:500]

for file_idx in file_indices:
    with h5py.File(
        "/home/de64/scratch/de64/sync_folder/2022-02-11_DE524_rne/kymograph/kymograph_"
        + str(file_idx)
        + ".hdf5",
        "r",
    ) as infile:
        yfp_data = infile["YFP"][:]
    with h5py.File(
        "/home/de64/scratch/de64/sync_folder/2022-02-11_DE524_rne/fluorsegmentation/segmentation_"
        + str(file_idx)
        + ".hdf5",
        "r",
    ) as infile:
        seg_data = infile["data"][:]

    for idx in range(seg_data.shape[0]):
        for t in range(t_min, t_max):
            rps = sk.measure.regionprops(seg_data[idx][t], yfp_data[idx][t])
            radians = [rp.orientation for rp in rps]
            coords = [rp.coords for rp in rps]
            centroids = [rp.centroid for rp in rps]
            cell_lens = [rp.axis_major_length for rp in rps]  # temp, do better later

            theta = np.array(radians)
            R_all = np.array(
                [[np.cos(-theta), -np.sin(-theta)], [np.sin(-theta), np.cos(-theta)]]
            )

            norm_yfp_data = yfp_data[idx][t] / 65535

            foreground_data = norm_yfp_data[
                norm_yfp_data > sk.filters.threshold_otsu(norm_yfp_data)
            ].flatten()
            mean_foreground = np.mean(foreground_data)

            ttl_imgs += 1
            if mean_foreground < brightness_threshold:
                continue
            bright_imgs += 1

            blobs_log = sk.feature.blob_log(
                norm_yfp_data,
                min_sigma=min_sig,
                max_sigma=max_sig,
                num_sigma=max_sig - min_sig,
                threshold=0.001,
            )
            xy_foci = get_foci_list(norm_yfp_data, blobs_log)
            xy_foci_arr = np.array(xy_foci)

            for i, radian in enumerate(radians):
                coord = coords[i]
                centroid = centroids[i]
                cell_len = cell_lens[i]

                # A set makes the per-focus membership test below O(1).
                coord_tuples = {(xy_coord[0], xy_coord[1]) for xy_coord in coord}
                xy_foci_coord_list = [
                    (item[0], item[1])
                    for item in np.round(np.array(xy_foci)).astype(int).tolist()
                ]
                foci_in_seg = [
                    f
                    for f, item in enumerate(xy_foci_coord_list)
                    if item in coord_tuples
                ]
                xy_foci_in_seg = xy_foci_arr[foci_in_seg]
                # plt.imshow(norm_yfp_data)
                # plt.scatter(xy_foci_in_seg[:,1],xy_foci_in_seg[:,0],c='r')
                # plt.imshow(seg_data[idx][t],alpha=0.4)
                # plt.show()
                if len(xy_foci_in_seg) != 0:
                    rel_xy_foci_in_seg = xy_foci_in_seg - np.array(centroid)
                    rel_xy_foci_in_seg = (
                        R_all[:, :, i] @ (rel_xy_foci_in_seg[:, ::-1].T)
                    ).T[:, ::-1]
                    rel_xy_foci_in_seg = np.concatenate(
                        [
                            rel_xy_foci_in_seg,
                            np.array(
                                [[cell_len for m in range(rel_xy_foci_in_seg.shape[0])]]
                            ).T,
                        ],
                        axis=1,
                    )
                    rel_xy_foci_in_seg_list.append(rel_xy_foci_in_seg)
rel_xy_foci_in_seg_arr = np.concatenate(rel_xy_foci_in_seg_list)
# percentile_traces = np.array(percentile_traces)
# trace_lens = np.array([len(trace) for trace in percentile_traces])
# trace_bins = np.sort(np.unique(trace_lens))
# med_traces = []
# for trace_len in trace_bins:
#     trace_group = np.stack(percentile_traces[trace_lens==trace_len])
#     med_trace = np.nanmedian(trace_group,axis=0)
#     med_traces.append(med_trace)
# max_trace_len = np.max(trace_bins)
# trace_dif = max_trace_len-trace_bins
# padded_traces = np.array([np.pad(trace,((trace_dif[i]//2)+trace_dif[i]%2,trace_dif[i]//2),constant_values=np.NaN) for i,trace in enumerate(med_traces)])
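
The loop above assigns foci to cells by testing each rounded focus position against every region's coordinate list. A label-image lookup is one alternative. The sketch below is a hypothetical helper (`foci_labels` is not part of `TrenchRipper`) and assumes foci are stored as `(row, col)` positions, matching `rp.coords` and `rp.centroid`.

```python
import numpy as np


def foci_labels(label_img, foci_rc):
    """Return the segmentation label under each focus (0 means background).

    Hypothetical helper; assumes foci are (row, col) positions.
    """
    foci_rc = np.asarray(foci_rc, dtype=float).reshape(-1, 2)
    rc = np.round(foci_rc).astype(int)

    # Foci that round to a pixel outside the image are treated as background.
    inside = (
        (rc[:, 0] >= 0)
        & (rc[:, 0] < label_img.shape[0])
        & (rc[:, 1] >= 0)
        & (rc[:, 1] < label_img.shape[1])
    )
    labels = np.zeros(len(rc), dtype=label_img.dtype)
    labels[inside] = label_img[rc[inside, 0], rc[inside, 1]]
    return labels
```

With this, the foci belonging to one region could be selected as `xy_foci_arr[foci_labels(seg_data[idx][t], xy_foci_arr) == rp.label]`, which avoids rebuilding coordinate lists for every cell.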

In [None]:
print(str(bright_imgs) + "/" + str(ttl_imgs))

In [None]:
pixel_scale = 0.215

scaled_foci_pos = rel_xy_foci_in_seg_arr * pixel_scale

In [None]:
pixel_scale = 0.215

y_min, y_max = -40, 40

bins = np.linspace(30, 60, num=40)

hist_list = []
for i in range(len(bins) - 1):
    bin_interval = bins[i : i + 2]
    padding = int(np.round(np.mean(bin_interval))) // 2
    under_thr = rel_xy_foci_in_seg_arr[:, 2] <= bin_interval[1]
    over_thr = bin_interval[0] < rel_xy_foci_in_seg_arr[:, 2]
    filtered_xy_foci = rel_xy_foci_in_seg_arr[under_thr & over_thr]
    y_coords = filtered_xy_foci[:, 0]
    hist_out = np.histogram(y_coords, bins=y_max - y_min, range=(y_min, y_max))[0]
    hist_out = hist_out / np.max(hist_out)
    # np.NaN was removed in NumPy 2.0; np.nan works in all versions.
    hist_out[: y_max - padding] = np.nan
    hist_out[y_max + padding :] = np.nan
    hist_list.append(hist_out)
hist_arr = np.array(hist_list)
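
The binning above can also be written as a function that guards against empty length bins. This is a sketch under assumptions: `length_binned_profiles` and its parameter names are hypothetical, positions and lengths are in pixels, and each row uses 1-pixel bins centred on zero.

```python
import numpy as np


def length_binned_profiles(positions, lengths, length_edges, half_width):
    """Peak-normalized position histograms, one row per cell-length bin.

    Hypothetical helper; each row spans [-half_width, half_width) in 1-pixel bins.
    """
    positions = np.asarray(positions, dtype=float)
    lengths = np.asarray(lengths, dtype=float)
    n_rows = len(length_edges) - 1
    profiles = np.full((n_rows, 2 * half_width), np.nan)
    for i in range(n_rows):
        lo, hi = length_edges[i], length_edges[i + 1]
        in_bin = (lengths > lo) & (lengths <= hi)
        counts, _ = np.histogram(
            positions[in_bin], bins=2 * half_width, range=(-half_width, half_width)
        )
        if counts.max() == 0:
            continue  # leave empty length bins as NaN instead of dividing by zero
        row = counts / counts.max()
        # Blank positions farther from the centre than half the mean bin length.
        pad = int(round((lo + hi) / 2)) // 2
        row[: max(0, half_width - pad)] = np.nan
        row[half_width + pad :] = np.nan
        profiles[i] = row
    return profiles
```

For the data above, a call might look like `length_binned_profiles(rel_xy_foci_in_seg_arr[:, 0], rel_xy_foci_in_seg_arr[:, 2], bins, y_max)`.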

In [None]:
y_ticks = np.linspace(30 * pixel_scale, 60 * pixel_scale, num=40)
x_ticks = np.linspace(y_min * pixel_scale, y_max * pixel_scale, num=y_max - y_min)

In [None]:
y_ticks

In [None]:
padding = int(np.round(np.mean(bin_interval))) // 2

In [None]:
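fig, ax1 = plt.subplots(1, 1)
# extent maps the image edges to scaled units; setting tick labels alone would mislabel the default ticks.
extent = (x_ticks[0], x_ticks[-1], y_ticks[-1], y_ticks[0])
ax1.imshow(hist_arr, cmap="jet", aspect="auto", extent=extent)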

In [None]:
ax1.imshow(hist_arr, cmap="jet")

In [None]:
plt.plot(hist_out)  # hist_out is 1D, so imshow would reject it

In [None]:
plt.hist(y_coords, bins=60, range=(-30, 30))[0]

In [None]:
xy_foci_in_seg - np.array(centroid)

In [None]:
rel_xy_foci_in_seg_list

In [None]:
xy_foci_coord_list

In [None]:
xy_foci_arr.shape

In [None]:
coord_tuples

In [None]:
rel_xy_foci_in_seg.shape

In [None]:
np.concatenate(
    [rel_xy_foci_in_seg, np.array([[0 for m in range(rel_xy_foci_in_seg.shape[0])]]).T],
    axis=1,
)

In [None]:
np.array([[0 for m in range(rel_xy_foci_in_seg.shape[0])]]).T.shape

In [None]:
xy_foci_in_seg[:, ::-1]

In [None]:
np.array(centroid)

In [None]:
rel_xy_foci_in_seg

In [None]:
rel_xy_foci_in_seg

In [None]:
theta = np.deg2rad(degrees)  # np.cos / np.sin expect radians
R = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])

In [None]:
rel_xy_foci_in_seg = (R @ (rel_xy_foci_in_seg[:, ::-1].T)).T[:, ::-1]

In [None]:
rel_xy_foci_in_seg

In [None]:
R.shape

In [None]:
rel_xy_foci_in_seg

In [None]:
R

In [None]:
rot_test = sk.transform.rotate(
    seg_data[idx][0], 5, center=centroid[::-1], preserve_range=True
).astype(int)
rps = sk.measure.regionprops(rot_test)
degrees = [(rp.orientation / (2 * np.pi)) * 360 for rp in rps]
coords = [rp.coords for rp in rps]
centroids = [rp.centroid for rp in rps]

In [None]:
degrees

In [None]:
plt.imshow(rot_test)

In [None]:
theta = np.deg2rad(degrees)  # np.cos / np.sin expect radians
R = np.array([[np.cos(-theta), -np.sin(-theta)], [np.sin(-theta), np.cos(-theta)]])

In [None]:
R.shape
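
The cells above build one rotation matrix per region with shape `(2, 2, N)`. The sketch below is a hypothetical helper (`rotation_matrices` is not part of `TrenchRipper`) that takes angles in radians, the unit `rp.orientation` already uses, and returns the matrices stacked as `(N, 2, 2)` so that `R[i]` is the matrix for region `i`.

```python
import numpy as np


def rotation_matrices(angles_rad):
    """Stack of 2x2 counter-clockwise rotation matrices, shape (N, 2, 2).

    Hypothetical helper; angles must be in radians.
    """
    theta = np.atleast_1d(np.asarray(angles_rad, dtype=float))
    cos, sin = np.cos(theta), np.sin(theta)
    return np.stack(
        [np.stack([cos, -sin], axis=-1), np.stack([sin, cos], axis=-1)], axis=-2
    )
```

If only degree values are at hand, `rotation_matrices(-np.deg2rad(degrees))` would give the inverse rotations, and `(R[i] @ points_xy.T).T` applies one of them to an `(M, 2)` array of points.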