In [2]:
from tifffile import imread, imwrite
import os
import matplotlib.pyplot as plt
import cv2
import networkx as nx
import numpy as np
import skan
import skimage
from skimage.measure import profile_line
from numpy.linalg import norm
from scipy.spatial import distance
from scipy import interpolate, sparse
from scipy.sparse import csgraph
from scipy.spatial import distance
import matplotlib
from skimage.util import img_as_ubyte, img_as_uint
# Default figure size for every plot drawn in this notebook.
matplotlib.rcParams['figure.figsize'] = [10, 10]


import numpy as np
from scipy.interpolate import interp1d
from scipy.ndimage import map_coordinates
from skimage.transform import resize

# Input/output locations for this straightening test run.
images_dir = "/mnt/external.data/TowbinLab/spsalmon/straightening_test/videos/"
masks_dir = "/mnt/external.data/TowbinLab/spsalmon/straightening_test/videos_skels/"

output_dir = "/mnt/external.data/TowbinLab/spsalmon/straightening_test/videos_str_python/"

os.makedirs(output_dir, exist_ok=True)

# os.listdir returns entries in arbitrary order, so both listings are sorted to
# make images[i] and masks[i] refer to the same video.
# NOTE(review): assumes a video and its skeleton have names that sort
# identically in their respective directories — confirm the naming scheme.
images = sorted(os.path.join(images_dir, x) for x in os.listdir(images_dir))
masks = sorted(os.path.join(masks_dir, x) for x in os.listdir(masks_dir))

idx = 0

image = imread(images[idx])
mask = imread(masks[idx])

In [8]:
def main_branch_paths(skeleton: skan.Skeleton):
    """Return the row indices of ``skeleton.paths`` forming the main branch, in order.

    ``skan.summarize(..., find_main_branch=True)`` flags the main-branch rows.
    Only the sub-skeleton with the largest summed ``branch-distance`` is kept.
    Its main-branch paths are then ordered from one end node to the other.

    Parameters
    ----------
    skeleton : skan.Skeleton
        Skeleton to analyse.

    Returns
    -------
    numpy.ndarray
        Indices of rows of ``skeleton.paths``, ordered along the main branch.

    Raises
    ------
    ValueError
        Raised by the ``end1, end2 = ...`` unpacking when the main-branch graph
        does not have exactly two nodes of degree 1 (e.g. if it is a cycle).
    """
    branch_data = skan.summarize(skeleton, find_main_branch=True)
    # get the largest subskeleton (largest total branch length)
    subskeletons = branch_data.groupby(by="skeleton-id", as_index=True)
    largest_subskeleton = subskeletons.get_group(
        subskeletons["branch-distance"].sum().idxmax()
    )
    # get the nodes/edges for skeleton's branch-graph
    main_branch_edges = largest_subskeleton.loc[
        largest_subskeleton.main == True, ["node-id-src", "node-id-dst"]
    ].to_numpy()
    # a node appearing in only one main-branch edge is an end of the branch
    main_branch_nodes, counts = np.unique(
        main_branch_edges.flatten(), return_counts=True
    )

    main_branch_graph = nx.Graph()
    main_branch_graph.add_edges_from(main_branch_edges)
    end1, end2 = main_branch_nodes[np.nonzero(counts == 1)]
    # order the nodes by walking the branch graph from one end to the other
    main_branch_nodes = np.asarray(
        nx.shortest_path(main_branch_graph, end1, end2))

    PATH_AXIS = 0
    NODE_AXIS = 1
    ALL_PATHS = np.s_[:]
    # main_branch_mask has True if path is part of main branch
    # path is part of main branch if both of its nodes are in main_branch_nodes, hence check for sum(NODE_AXIS) == 2
    main_branch_mask = (
        skeleton.paths[ALL_PATHS, main_branch_nodes].sum(axis=NODE_AXIS) == 2
    )
    main_branch_paths = np.flatnonzero(main_branch_mask)
    # subset skeleton.paths with unordered main_branch_paths and the ordered main_branch_nodes to give something like the following:
    # 0 1 1 0 0
    # 1 1 0 0 0
    # 0 0 1 1 0
    # 0 0 0 1 1
    # argmax for each row will give the col-index for the first "1" in that row
    # argsorting the col-index gives the path order
    # i.e. read the 1s in matrix from left to right, recording the order of the rows
    # for the example above, the desired result is [1, 0, 2, 3], which is the row order that would reorder the matrix to:
    # 1 1 0 0 0
    # 0 1 1 0 0
    # 0 0 1 1 0
    # 0 0 0 1 1
    # NOTE(review): `.A1` relies on the sparse-matrix API returning np.matrix
    # from argmax/argsort; a scipy sparse *array* would not have `.A1` —
    # confirm against the installed skan/scipy versions.
    path_order = (
        skeleton.paths[np.ix_(main_branch_paths, main_branch_nodes)]
        .argmax(axis=NODE_AXIS)
        .argsort(axis=PATH_AXIS)
        .A1  # A1 converts np.matrix to 1D np.array
    )
    main_branch_paths = main_branch_paths[path_order]
    return main_branch_paths


def paths_to_coordinates(skeleton: skan.Skeleton, paths):
    """Chain the pixel coordinates of consecutive skeleton paths into one polyline.

    Each path is read with ``skeleton.path_coordinates``. The paths are
    oriented so that every path ends where the next one touches it, then
    concatenated. Repeated coordinates (e.g. shared junction pixels) are
    dropped, keeping the first occurrence, because ``splprep`` rejects
    duplicate points.

    Parameters
    ----------
    skeleton : skan.Skeleton
        Object providing ``path_coordinates(path)``.
    paths : sequence
        Path indices in traversal order (e.g. from ``main_branch_paths``).

    Returns
    -------
    numpy.ndarray
        (n_points, n_dims) array of unique coordinates in traversal order.
    """
    segments = [skeleton.path_coordinates(p) for p in paths]

    oriented = []
    for current, following in zip(segments[:-1], segments[1:]):
        tail = current[-1]
        # keep the segment as-is if its last point touches either end of the
        # next segment; otherwise it is stored backwards and must be reversed
        touches_next = np.all(tail == following[0]) or np.all(tail == following[-1])
        if touches_next:
            oriented.append(current)
        else:
            oriented.append(np.flip(current, axis=0))

    # the last segment has no successor, so orient it against the one before it
    final = segments[-1]
    if oriented and not np.all(oriented[-1][-1] == final[0]):
        final = np.flip(final, axis=0)
    oriented.append(final)

    stacked = np.concatenate(oriented, axis=0)
    _, first_seen = np.unique(stacked, return_index=True, axis=0)
    return stacked[np.sort(first_seen)]


def generate_parametrisation(backbone, max_length):
    """Assign a spline parameter ``u`` to every backbone point.

    ``u`` grows with arc length at a rate of ``1 / max_length`` per pixel. It is
    offset so that the backbone is centred on 0.5, i.e. it spans
    ``[(1 - L/max_length)/2, (1 + L/max_length)/2]`` for a backbone of length ``L``.

    Parameters
    ----------
    backbone : (N, D) array
        Ordered points of the backbone.
    max_length : float
        Length (in pixels) that maps to a ``u`` range of 1.

    Returns
    -------
    numpy.ndarray
        (N,) float array of parameter values, one per backbone point.
    """
    segment_lengths = np.linalg.norm(backbone[:-1] - backbone[1:], axis=1)
    scale = 1 / max_length
    offset = (1 - segment_lengths.sum() * scale) / 2
    return np.concatenate(([offset], offset + np.cumsum(segment_lengths * scale)))


def calculate_worm_length(backbone):
    """Return the polyline length of ``backbone``.

    This is the sum of the Euclidean distances between consecutive points.
    A single-point backbone has length 0.
    """
    segment_vectors = backbone[:-1] - backbone[1:]
    segment_lengths = np.linalg.norm(segment_vectors, axis=1)
    return segment_lengths.sum()


def spline_from_backbone(backbone, upsampling: int = 2, smoothing: float = 75):
    """Fit a smoothing B-spline through ``backbone`` and resample it densely.

    The backbone is parametrised by arc length (see
    ``generate_parametrisation``), fitted with ``scipy.interpolate.splprep``
    and evaluated at ``len(backbone) * upsampling`` evenly spaced parameter
    values covering the whole backbone.

    Parameters
    ----------
    backbone : (N, 2) array
        Ordered ``(y, x)`` points of the centre line (N must suit ``splprep``,
        i.e. more points than the spline degree; duplicates are not allowed).
    upsampling : int, default 2
        Number of output points per input point. This was previously ignored
        (always 2); the default keeps the old behaviour.
    smoothing : float, default 75
        ``s`` argument of ``splprep`` (higher = smoother spline).

    Returns
    -------
    numpy.ndarray
        (N * upsampling, 2) array of ``(y, x)`` spline points.
    """
    max_length = calculate_worm_length(backbone)
    ys, xs = backbone.T
    us = generate_parametrisation(backbone, max_length)
    (knots, coeffs, degree), _ = interpolate.splprep(
        [ys, xs], u=us, ub=us[0], ue=us[-1], s=smoothing
    )
    # splprep returns one coefficient array per dimension; BSpline wants them
    # stacked as (n_coeffs, n_dims) to evaluate all dimensions at once.
    spline = interpolate.BSpline(knots, np.asarray(coeffs).T, degree, extrapolate=False)
    sample_us = np.linspace(us.min(), us.max(), int(len(us) * upsampling))
    return spline(sample_us)


def straighten_worm(image, mask, width=60):
    """Straighten a worm by sampling intensity profiles across its centre line.

    The mask is skeletonised and the main branch is turned into an ordered
    backbone. A spline is fitted through the backbone. For every spline point,
    ``profile_line`` (nearest-neighbour, ``order=0``) reads the image from
    ``point + width * normal`` to ``point - width * normal``. Each profile is
    resampled onto ``2 * width`` positions. The straightened image is shown
    and returned.

    Previously this function always failed: it called ``spline_from_backbone``
    without its required argument and unpacked a derivative that function
    does not return.

    Parameters
    ----------
    image : 2D array
        Image to straighten.
    mask : 2D array
        Mask passed to ``skimage.morphology.skeletonize``.
    width : int, default 60
        Half-width of the sampled band, in pixels.

    Returns
    -------
    numpy.ndarray
        (2 * width, number of spline points) straightened image.
    """
    skeleton = skan.Skeleton(skimage.morphology.skeletonize(mask))
    main_branch = main_branch_paths(skeleton)
    backbone = paths_to_coordinates(skeleton, main_branch)

    # spline_from_backbone returns only the resampled points, so the tangent
    # is estimated from them by finite differences.
    spline = spline_from_backbone(backbone, upsampling=2)
    tangents = np.gradient(spline, axis=0)

    # Rotate each (dy, dx) tangent by 90 degrees to (-dx, dy), then normalise.
    normal_vectors = tangents[:, ::-1].copy()
    normal_vectors[:, 0] *= -1
    normal_vectors /= norm(normal_vectors, axis=1, keepdims=True)

    endpoints = np.stack([spline + width * normal_vectors,
                          spline - width * normal_vectors], axis=1)

    profiles = [profile_line(image, start, end, linewidth=1,
                             order=0, mode='constant') for start, end in endpoints]
    # profile_line's length varies with rounding; centre every profile and
    # resample it onto the same 2 * width grid so the columns line up.
    target = np.arange(-width, width)
    resampled = [np.interp(target, np.arange(-len(p) / 2, len(p) / 2), p)
                 for p in profiles]

    straightened = np.array(resampled).T
    plt.imshow(straightened)
    plt.show()
    return straightened


def straighten(image, spline, normal_vectors, normal_vectors_norm, strwidth):
    """Sample ``strwidth`` pixels along the given normal at every spline point.

    The previous implementation could not run:
    - It allocated ``strwidth`` rows but sampled ``strwidth + 1`` points.
    - Its ``dx == 0`` / ``dy == 0`` branches produced arrays of different
      lengths.
    - Its final ``interp1d`` call shifted the columns by one and zeroed the
      last column.
    Samples are now symmetric around the centre line and spaced by the length
    of the normal vector.

    Parameters
    ----------
    image : 2D array
        Image indexed as ``image[y, x]``; sampled with bilinear interpolation
        (``map_coordinates(order=1)``); points outside the image read as 0.
    spline : (N, 2) array
        ``(y, x)`` centre-line points.
    normal_vectors : (N, 2) array
        ``(y, x)`` normal direction at each spline point. Assumed to be unit
        length, so samples are one pixel apart.
    normal_vectors_norm : (N,) array
        Unused; kept for backward compatibility of the signature.
    strwidth : int
        Number of samples across the centre line (output height).

    Returns
    -------
    numpy.ndarray
        (strwidth, N) int array. Column ``i`` is the profile across
        ``spline[i]``; row 0 lies at
        ``spline[i] - normal_vectors[i] * (strwidth - 1) / 2``.
    """
    half_span = (strwidth - 1) / 2
    offsets = np.linspace(-half_span, half_span, strwidth)[:, np.newaxis]

    # (strwidth, N) grids of sample coordinates, one column per spline point.
    sample_y = spline[:, 0] + offsets * normal_vectors[:, 0]
    sample_x = spline[:, 1] + offsets * normal_vectors[:, 1]

    profiles = map_coordinates(image, np.stack((sample_y, sample_x)), order=1)
    return profiles.astype(int)


def are_tips_flipped(tip_pair1, tip_pair2):
    """Tell whether ``tip_pair2`` lists the tips of ``tip_pair1`` in reverse order.

    Two pairings are compared: matching the tips as given (first-to-first,
    second-to-second) and matching them swapped. The tips count as flipped
    when the swapped pairing has the strictly smaller total Euclidean
    distance.

    Parameters
    ----------
    tip_pair1, tip_pair2 : (2, D) arrays
        Two endpoints each (e.g. ``backbone[[0, -1]]``).

    Returns
    -------
    bool
        True if the swapped pairing is closer.
    """
    dists = distance.cdist(tip_pair1, tip_pair2, metric="euclidean")
    straight_cost = dists[0, 0] + dists[1, 1]
    crossed_cost = dists[0, 1] + dists[1, 0]
    return crossed_cost < straight_cost


def test_straighten(image, spline, strwidth):

    # Adjust strwidth
    strwidth = strwidth-1

    ypoints = spline[:, 0]
    xpoints = spline[:, 1]

    x2 = xpoints[0]
    y2 = ypoints[0]

    tempinterpim = np.zeros((strwidth+1, len(xpoints)))

    pos = np.zeros(len(xpoints))
    pos[0] = 0
    for i in range(1, len(xpoints)):
        x1 = x2
        y1 = y2
        x2 = xpoints[i]
        y2 = ypoints[i]

        dlx = x2 - x1
        dly = y1 - y2
        le = np.sqrt(dlx * dlx + dly * dly)
        dx = dly / le
        dy = dlx / le

        if dx == 0:
            xeval = np.full((strwidth+1,), x1)
        else:
            xeval = np.linspace(x1 - (dx * strwidth / 2),
                                x1 + (dx * strwidth / 2) + dx, strwidth+1)

        if dy == 0:
            yeval = np.full((strwidth+1,), y1)
        else:
            yeval = np.linspace(y1 - (dy * strwidth / 2),
                                y1 + (dy * strwidth / 2) + dy, strwidth+1)

        coords = np.vstack((yeval, xeval))
        zeval = map_coordinates(image, coords, order=1)

        tempinterpim[:, i] = zeval
        if i > 0:
            pos[i] = pos[i-1] + le

    worm = np.zeros((tempinterpim.shape[0], int(np.ceil(pos[-1]))))

    # for i in range(tempinterpim.shape[0]):
    #     interp_func = interp1d(
    #         pos, tempinterpim[i, :], kind='linear', bounds_error=False, fill_value=0)
    #     temp = interp_func(np.arange(1, worm.shape[1] + 1))
    #     worm[i, :] = temp

    interp_func = interp1d(
        pos, tempinterpim, kind='linear', axis=1, bounds_error=False, fill_value=0)
    temp = interp_func(np.arange(1, worm.shape[1] + 1))
    worm[:, :] = temp

    return worm.astype(int)


def test_straighten_worm_from_skel(image, skel, width=60):
    """Straighten one frame along its skeleton, display it and save it as TIFF.

    Fixes relative to the previous version:
    - It skeletonised the notebook-global ``mask`` instead of the ``skel``
      argument.
    - It called ``spline_from_backbone`` without its required argument and
      unpacked a derivative that function does not return.
    - It computed normals that were never used.

    Parameters
    ----------
    image : 2D array
        Frame to straighten.
    skel : 2D array
        Skeleton (or mask) of the worm in ``image``.
    width : int, default 60
        Currently unused — the profile width is fixed at 141 samples.
        NOTE(review): decide whether this should drive the width.

    Side effects
    ------------
    Shows the result with matplotlib and writes
    ``{output_dir}zhzhzhz.tiff`` (zlib-compressed uint16).
    """
    skeleton = skan.Skeleton(skimage.morphology.skeletonize(skel))
    main_branch = main_branch_paths(skeleton)
    backbone = paths_to_coordinates(skeleton, main_branch)

    spline = spline_from_backbone(backbone, upsampling=2)

    straightened = test_straighten(image, spline, 141)

    plt.imshow(straightened)
    plt.show()

    print(straightened.shape)

    # test_straighten returns int64; cast explicitly instead of relying on
    # img_as_uint's downcast heuristics (which scale values that don't fit).
    straightened = np.clip(straightened, 0, np.iinfo(np.uint16).max).astype(np.uint16)

    imwrite(f'{output_dir}zhzhzhz.tiff', straightened, compression="zlib")


def straighten_video_from_skel(source_video, skel_video, width):
    """Straighten every frame of a video along its skeleton and save the stack.

    Each frame is processed in four steps:
    1. The skeleton frame's main branch becomes an ordered backbone.
    2. The backbone is reversed when its tips are swapped relative to the
       previous frame, so head/tail orientation stays consistent over time.
    3. A spline is fitted through the backbone.
    4. The frame is straightened with ``test_straighten``.

    The straightened frames are zero-padded to a common size, stacked as
    uint16 and written to ``{output_dir}test_unflipopo.tiff``.

    Parameters
    ----------
    source_video : iterable of 2D arrays
        Frames to straighten.
    skel_video : iterable of 2D arrays
        Skeleton frames, aligned with ``source_video``.
    width : int
        Number of samples across the centre line.
    """
    frames = []
    reference_tips = None
    for source_frame, skel_frame in zip(source_video, skel_video):
        skeleton = skan.Skeleton(skel_frame)
        backbone = paths_to_coordinates(skeleton, main_branch_paths(skeleton))

        # Keep the backbone's orientation consistent with the previous frame.
        if reference_tips is not None and are_tips_flipped(reference_tips, backbone[[0, -1]]):
            backbone = np.flip(backbone, axis=0)
        reference_tips = backbone[[0, -1]]

        spline = spline_from_backbone(backbone, upsampling=2)
        frames.append(test_straighten(source_frame, spline, width))

    video_shape = (
        len(frames),
        max(frame.shape[0] for frame in frames),
        max(frame.shape[1] for frame in frames),
    )
    print(video_shape)
    video = np.zeros(video_shape, 'uint16')

    video_name = f'{output_dir}test_unflipopo.tiff'
    print(video_name)
    for index, frame in enumerate(frames):
        rows, cols = frame.shape
        video[index, :rows, :cols] = frame
    imwrite(video_name, video, compression="zlib")


# Straighten the loaded video (`image`) along its skeleton video (`mask`),
# sampling 141 pixels across the centre line; writes the TIFF to output_dir.
straighten_video_from_skel(image, mask, 141)

(300, 141, 342)
/mnt/external.data/TowbinLab/spsalmon/straightening_test/videos_str_python/test_unflipopo.tiff
