In [None]:
# Imports
import os
import time
import logging

import tifffile as tf
import numpy as np

import scipy.ndimage as ndi

import line_utils
import image_utils
import file_utils

# Root logging configuration: every record from DEBUG upwards is written to
# pseudotime_run.log with a timestamp and a left-aligned, 8-character level name.
_log_settings = {
    'level': logging.DEBUG,
    'filename': 'pseudotime_run.log',
    'datefmt': '%Y-%m-%d %H:%M:%S',
    'format': '%(asctime)s %(levelname)-8s %(message)s',
}
logging.basicConfig(**_log_settings)

# Named logger used by the cells below.
logger = logging.getLogger('pseudotime')
# Uncomment to additionally attach a stream handler to this logger:
#logger.addHandler(logging.StreamHandler())

In [None]:
# Target definitions, loaded from targets.yaml
targets = file_utils.load_targets('targets.yaml')

# When using the aggregate table loader (CellPose PCA), this workbook is read
# INSTEAD OF the per-target workbooks described by targets.
file_path = "/Users/zachcm/Documents/Documents - Beyonce/Projects/ExM_Nadja/Final_Feature_Excel_File.xlsx"

# --- Pseudotime ordering ---------------------------------------------------
# Column that identifies the pseudotime group of each row
# time_key = "Stage"
time_key = "Frame"

# Order of the groups along pseudotime. Earlier, stage-name based orderings:
# time_order = ["CF", "RC", "CS", "RS", "SM", "BA", "A"]
# time_order = ["CF", "RC", "CS", "RS", "SM", "SM1", "SM2", "SM3", "BA", "BA1", "BA2", "BA3", "BA4", "A"]
# should be the range [0, n_time_bins] from step1_find_pseduotime_bins.ipynb
time_order = list(range(0, 6))

# --- Channels ----------------------------------------------------------------
# Channels per image (TODO: Auto detect)
n_ch = 4

# Wavelengths to be found in the file names.
# Sublists are grouped. First element of the sublist is a group name.
# NOTE: First element must be a number!
wvls = [
    488,
    [568, "orange"],
    [646, 647, 657],
]

# Desired channel order, specified by keys in targets.
# Must include MTs, septin and DAPI.
# Had to remove anillin from the list: PCA didn't write it to file.
desired_channel_order = [
    "MTs",
    "septin",
    "DAPI",
    "CellMask",
    "MKLP1",
    "RacGAP1",
    "PRC1",
    "Cit-K",
    "myoIIA",
    "myoIIB",
    "actin",
    "Septin7",
    "Septin11",
    "Septin9",
    "BORG4",
    "Tsg101",
    "Tsg101-ab83m",
    "ALIXrb",
    "ALIXm",
    "IST1",
    "CHMP4B",
]
# Previously used channel orders, kept for reference:
# desired_channel_order = ["MTs", "septin", "DAPI", "CellMask", "MKLP1", "RacGAP1", "PRC1", "Cit-K", "anillin", "myoIIA", "myoIIB", "actin", "Septin7", "Septin11", "Septin9", "BORG4", "Tsg101", "Tsg101-ab83m", "ALIXrb", "ALIXm", "IST1", "CHMP4B"]
# desired_channel_order = ["MTs", "septin", "DAPI", "CellMask", "MKLP1", "RacGAP1", "PRC1", "Cit-K", "anillin", "myoIIA", "myoIIB", "actin", "Septin7", "Septin11", "Septin9", "BORG4", "Tsg101", "ALIXrb", "ALIXm", "IST1", "CHMP4B"]
# desired_channel_order = ["MTs", "septin", "DAPI", "MKLP1", "RacGAP1"]#, "Cit-K", "PRC1"]
# desired_channel_order = ["MTs", "septin", "DAPI", "MKLP1", "RacGAP1", "anillin", "myoIIA", "myoIIB", "Cit-K", "CellMask", "PRC1", "actin"]
# desired_channel_order = ["MTs", "septin", "DAPI", "MKLP1", "RacGAP1", "anillin", "myoIIA", "Cit-K", "CellMask", "PRC1", "actin"]

# --- Geometry ----------------------------------------------------------------
# Length of cropped pseudotime region (should be roughly the line length)
length = 500
# length = 1000

# Output mode: one of "z-stack", "mean-proj", "mean-proj-individual" or "radial-proj".
# They are mostly self-explanatory, but mean-proj-individual produces one slice
# per mean projection.
# Note: only one mode can be run at a time at the moment.
mode = "mean-proj"
# mode = "mean-proj-individual"
# mode = "radial-proj"
# mode = "z-stack"

# In the worst case (one tubule is sandwiched at the top of the stack and the other
# at the bottom), this must be 2*<max stack length>-1
num_planes = 100

# Pixel sizes (we assume they are constant)
dx = 0.09
dy = 0.09
dz = 1

In [None]:
# Before we do anything, let's make sure all of our targets exist.
# Every missing name is collected and reported in a single KeyError so the
# configuration can be fixed in one pass, instead of failing on the first
# missing key only.
missing_targets = [key for key in desired_channel_order if key not in targets]
if missing_targets:
    raise KeyError(f"Element(s) {missing_targets} do not exist in targets dictionary!")

# ...and let's make sure our mode is supported
supported_modes = ["z-stack", "mean-proj", "mean-proj-individual", "radial-proj"]
assert mode in supported_modes, f"Unsupported mode {mode!r}; expected one of {supported_modes}"

In [None]:
# Alternative loader: read the aggregate metrics table from the "processed" sheet
# of the workbook at file_path.
# WARNING: IF YOU RUN THIS, DO NOT RUN THE CELL BELOW.
import pandas as pd

metrics_all = pd.read_excel(file_path, sheet_name="processed")

# Keep only the rows whose time_key value is listed in time_order
in_time_order = metrics_all[time_key].isin(time_order)
metrics = metrics_all[in_time_order]

In [None]:
# We only work with processed data in this notebook, so point every target at
# its "<sheet>_processed" workbook sheet and read headers from row 0.
#
# NOTE: targets_processed is the same object as targets (no copy is made), so the
# edits below are visible through both names. Because of that, the suffix is only
# appended when it is not already present; otherwise re-running this cell would
# produce "<sheet>_processed_processed" and the workbook load would look for a
# sheet that does not exist.
processed_suffix = "_processed"
targets_processed = targets
for k, v in targets.items():
    try:
        sheet_name = v["workbook_sheet_name"]
    except KeyError:
        # Targets without a workbook sheet are left untouched
        continue
    if not str(sheet_name).endswith(processed_suffix):
        sheet_name = f"{sheet_name}{processed_suffix}"
    targets_processed[k]["workbook_sheet_name"] = sheet_name
    targets_processed[k]["workbook_header_row"] = 0

# Load data from the workbooks
metrics = file_utils.load_workbooks(targets_processed, desired_channel_order)
# Keep only the rows whose time_key value is listed in time_order
metrics = metrics[metrics[time_key].isin(time_order)]

In [None]:
# Build one averaged image per pseudotime group (one group per value of time_key).
# For every row of metrics the image is loaded, rotated/translated using the row's
# X, Y and Angle, min-max normalized along its first axis, cropped to
# length x length, rescaled along x by the row's dx_septin relative to the group
# mean, cropped again and accumulated into group_img.

# Mode flags derived from the configured mode string
z_stack = mode == "z-stack"
radial_proj = mode == "radial-proj"
mean_proj_individual = mode == "mean-proj-individual"

groups = metrics.groupby(time_key)

plot_stack = None
n_groups = len(groups)
l2 = length // 2
# Width of the region extracted around each point before the final crop:
# sqrt(2) * length, rounded up (presumably so a rotated length x length square
# still fits -- TODO confirm against image_utils.pad_rot_and_trans_im).
rot_width = int((np.sqrt(2) * length) + 1)

# NOTE: this overwrites the num_planes value from the configuration cell for every
# mode except "z-stack"; re-run the configuration cell after switching modes.
if mean_proj_individual:
    # One plane per row; sized by the largest group
    num_planes = metrics.value_counts(time_key).max()
elif not z_stack:
    num_planes = 1

# Allocate the output by MODE rather than by the value of num_planes, so the array
# always has the axes the loop below indexes (and that the writer cell declares):
#   z-stack / mean-proj-individual -> (T, C, Z, Y, X), also when Z happens to be 1
#   radial-proj                    -> (T, C, length//8, length)
#   mean-proj                      -> (T, C, Y, X)
# A blanket squeeze() is deliberately avoided: it would also drop the group or
# channel axis whenever there is exactly one group or one channel.
n_channels = len(desired_channel_order)
if z_stack or mean_proj_individual:
    group_img = np.zeros((n_groups, n_channels, num_planes, length, length))
elif radial_proj:
    group_img = np.zeros((n_groups, n_channels, length//8, length))
else:
    group_img = np.zeros((n_groups, n_channels, length, length))

for group, (name, entries) in enumerate(groups):
    n_group = len(entries)
    logger.info(f"{name}: {n_group} averaged")
    im_proj = {}

    # Now compute the average distance
    mean_dX2 = entries['dx_septin'].mean()

    # In our second pass, average these images
    for name2, entries2 in entries.groupby("target"):
        n_target = len(entries2)
        logger.info(f"  {name2}: {n_target} averaged")

        # Per-channel divisors used when accumulating: the first three channels are
        # divided by the number of rows in the group, the fourth by the number of
        # rows of this target.
        weights = np.array([n_group, n_group, n_group, n_target])

        if mean_proj_individual:
            # Index of the next free plane for this target
            k = 0
        for i, ml in entries2.iterrows():
            logger.info(f"  Analyzing {os.path.basename(ml['filename'])}")

            start = time.time()

            # Get the image associated with this row
            im = image_utils.NDImage(ml["filename"], load_sorted=True)

            stop = time.time()
            duration = stop-start
            logger.debug(f"  time to load image: {duration:.2f} s")
            start = time.time()

            channel_order, group_channel_order, mt_channel = image_utils.get_channel_orders(
                ml["filename"],
                wvls,
                n_ch,
                targets,
                desired_channel_order)

            # get x, y, angle for this row
            x, y, angle = ml[["X", "Y", "Angle"]]

            # Rotate the image  # CYX
            im_rot = image_utils.pad_rot_and_trans_im(im[:], angle, x, y, crop_length=rot_width)

            stop = time.time()
            duration = stop-start
            logger.debug(f"  time to rotate image: {duration:.2f} s")
            start = time.time()

            # Min-max normalize each entry along the first axis. A zero range
            # (constant data) is replaced by 1 so the result is 0 instead of NaN,
            # which would otherwise propagate into the accumulated group image.
            if z_stack or radial_proj:
                im_min = im_rot.min(-1).min(-1).min(-1)
                im_range = im_rot.max(-1).max(-1).max(-1) - im_min
                im_range = np.where(im_range == 0, 1, im_range)
                im_rot = (im_rot - im_min[:, None, None, None]) / im_range[:, None, None, None]
            else:
                # Mean projection over axis 1 before normalizing
                im_rot = im_rot.mean(1).squeeze()

                im_min = im_rot.min(-1).min(-1)
                im_range = im_rot.max(-1).max(-1) - im_min
                im_range = np.where(im_range == 0, 1, im_range)
                im_rot = (im_rot - im_min[:, None, None]) / im_range[:, None, None]

            stop = time.time()
            duration = stop-start
            logger.debug(f"  time to normalize image: {duration:.2f} s")
            start = time.time()

            # Grab coordinates
            xc, yc = im_rot.shape[-1]//2, im_rot.shape[-2]//2

            if z_stack or radial_proj:
                # Central z position, computed from a 25-row band around yc summed
                # over axis 2, and stored back on the metrics row
                z_coord = int(round(line_utils.find_central_pos(im_rot[...,(yc-12):(yc+13),:].sum(2).squeeze(), xc, ch=mt_channel)))
                metrics.loc[i, "z_coord"] = z_coord
                logger.info(f"  im.shape: {im_rot.shape} projection shape: {im_rot[:].max(2).squeeze().shape} z_coord: {z_coord}")

            stop = time.time()
            duration = stop-start
            logger.debug(f"  time to get centroid: {duration:.2f} s")
            start = time.time()

            logger.info(f"  xc: {xc}  yc: {yc}  length: {length}")

            # Crop the image
            im_crop = im_rot[...,(yc-l2):(yc+l2),(xc-l2):(xc+l2)]

            logger.debug(f"  im_crop shape: {im_crop.shape}")

            stop = time.time()
            duration = stop-start
            logger.debug(f"  time to crop image: {duration:.2f} s")
            start = time.time()

            # rescale the image
            # if np.isnan(ml["dX (pxl)"]):
            if np.isnan(ml["dx_septin"]):
                mag = 1
                # im_zoom = im_crop
            else:
                # mag = ml["dX (pxl)"]/mean_dX
                mag = ml["dx_septin"]/mean_dX2
                if mag < 0.707 or mag > 1.414:
                    # misfit: more than ~sqrt(2) away from the group mean, do not rescale
                    mag = 1

            if z_stack or radial_proj:
                # NOTE(review): floor division of floats (1 // 0.09 == 11.0). Other
                # pixel sizes can floor unexpectedly (e.g. 1 // 0.1 == 9.0) -- confirm
                # that an integer-valued z magnification is what is intended here.
                zmag = dz // dx if radial_proj else 1
                im_zoom = ndi.zoom(im_crop, (1,zmag,1,mag), order=0)
            else:
                im_zoom = ndi.zoom(im_crop, (1,1,mag), order=0)

            stop = time.time()
            duration = stop-start
            logger.debug(f"  time to zoom image: {duration:.2f} s zoom shape: {im_zoom.shape}")
            start = time.time()

            # Crop the image again
            xc, yc = im_zoom.shape[-1]//2, im_zoom.shape[-2]//2
            im_crop2 = im_zoom[...,(yc-l2):(yc+l2),(xc-l2):(xc+l2)]
            logger.debug(f"  im_crop2 shape: {im_crop2.shape}")

            if radial_proj:
                # Map z_coord from the unzoomed stack into the zoomed stack,
                # measured relative to the middle plane of each
                num_planes2_crop = im_crop.shape[-3] // 2
                num_planes2 = im_crop2.shape[-3] // 2
                new_z_coord = num_planes2 - (num_planes2_crop - z_coord)*zmag
                im_crop2 = image_utils.radial_projection(im_crop2, length//8, 1,
                                                         l2, l2, 0,
                                                         new_z_coord, dx=1, dy=1, dz=1)

                stop = time.time()
                duration = stop-start
                logger.debug(f"  time to get radial proj image: {duration:.2f} s")

            # Add the image with a weighting 1/length of the group
            if z_stack:
                # Place the stack so that plane z_coord lands on plane
                # num_planes // 2 - 1 of the output
                z_length = im_crop2.shape[-3]
                num_planes2 = num_planes // 2
                zl, zu = num_planes2 - z_coord - 1, num_planes2 + z_length - z_coord - 1

                group_img[group,group_channel_order,zl:zu,...] += (im_crop2[channel_order]/weights[:,None,None,None])
            elif mean_proj_individual:
                logger.debug(f"  group_img shape: {group_img[group,group_channel_order,k,...].shape}")
                group_img[group,group_channel_order,k,...] += (im_crop2[channel_order]/weights[:,None,None])
                k += 1
            else:
                group_img[group,group_channel_order,...] += (im_crop2[channel_order]/weights[:,None,None])
        

In [None]:
# Keys of the groupby, in the order reported by the per-group unique() result
group_order = list(groups[time_key].unique().keys())

# For each entry of time_order that actually has a group, look up the position of
# that group in group_order; entries without a group are skipped.
position_in_group_order = {}
for position, group_key in enumerate(group_order):
    position_in_group_order.setdefault(group_key, position)

group_img_sorted = []
for wanted in time_order:
    if wanted in position_in_group_order:
        group_img_sorted.append(position_in_group_order[wanted])

logger.debug(group_img_sorted)

In [None]:
# File name encodes the mode plus a 3-character tag per channel
# (first two characters + last character of each channel name).
channel_tags = [f"{channel[0:2]}{channel[-1]}" for channel in desired_channel_order]
stack_fn = f'pseudotime_images_{mode}_{"_".join(channel_tags)}.ome.tif'

# Modes that keep a plane axis are written as TCZYX, all others as TCYX
has_plane_axis = z_stack or mean_proj_individual
stack_axes = 'TCZYX' if has_plane_axis else 'TCYX'

# Write the groups reordered according to group_img_sorted
tf.imwrite(
    stack_fn,
    group_img[group_img_sorted,...],
    metadata={'axes': stack_axes},
    dtype=group_img.dtype,
)