In [None]:
# Imports (stdlib, then third-party) followed by global configuration.
import os
import pickle
import random
import warnings

import imageio
import matplotlib
import matplotlib.pyplot as plt
import motile
import networkx as nx
import numpy as np
import pandas as pd
import pymeshfix
import seaborn as sns
import skimage
import zarr
from joblib import Parallel, delayed
from matplotlib.colors import LightSource
from motile.plot import draw_solution, draw_track_graph
from skimage.io import imread, imsave
from skimage.measure import marching_cubes
from skimage.transform import rescale
from tqdm import tqdm
from trimesh import Trimesh
from trimesh.smoothing import filter_taubin

# Seed both global RNGs so any stochastic step is reproducible.
random.seed(0)
np.random.seed(0)

# Embed TrueType (Type 42) fonts so exported PDF/PS figures keep editable text.
matplotlib.rcParams["pdf.fonttype"] = 42
matplotlib.rcParams["ps.fonttype"] = 42

# NOTE(review): this hides *all* warnings, including deprecation notices from
# matplotlib / numpy / scikit-image; consider narrowing the filter.
warnings.filterwarnings("ignore")

In [None]:
# Colour legend for the thunder plots: hours since the last fusion (0 h = yellow, fading to grey).
max_time_points = 10
cmap_gray_yellow = matplotlib.colors.LinearSegmentedColormap.from_list(
    "", ["#fff200", "#5e5e5e"], N=max_time_points + 1
)
norm = matplotlib.colors.Normalize(vmin=0, vmax=max_time_points + 1)

colorbar_path = "updated_figures_figure_2/colorbar_fusion.png"
os.makedirs(os.path.dirname(colorbar_path), exist_ok=True)

fig, ax = plt.subplots(figsize=(8, 1))
ax.axis("off")
# The whole axes is used as the colorbar's drawing area.
cbar = fig.colorbar(
    plt.cm.ScalarMappable(norm=norm, cmap=cmap_gray_yellow),
    cax=ax,
    orientation="horizontal",
    label=r"Hours from the last fusion event",
)
fig.tight_layout()
fig.savefig(colorbar_path, bbox_inches="tight")
plt.show()

In [None]:
# Per-frame lumen counts, divisions (nodes with two outgoing edges) and fusions
# (nodes with two incoming edges) for every position, pooled into `all_dfs`.
channel = "GFP"
positions = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
matrigel = [1, 2, 3, 4]
no_matrix = [5, 6, 7, 8, 13, 14, 15, 16]
agar = [9, 10, 11, 12]
smootheness = 5
label = f"lumen_masks_smooth_{smootheness}_processed"
n_frames = 125  # frames analysed per position

# Position -> culture condition; later lists take precedence if a position were listed twice.
condition_by_position = {}
for condition, members in (
    ("Matrigel", matrigel),
    ("No Matrix", no_matrix),
    ("Agarose", agar),
):
    for member in members:
        condition_by_position[member] = condition

position_frames = []
for position in positions:
    track_path = f"/Morphodynamics_of_human_early_brain_organoid_development/tracking/lumen_tracking/tracks/Position_{position}_Settings_1_Processed/"
    try:
        with open(track_path + f"{channel}_{label}_track_motile_graph.pickle", "rb") as handle:
            track_graph = pickle.load(handle)
        with open(track_path + f"{channel}_{label}_track.pickle", "rb") as handle:
            nx_graph = pickle.load(handle)
    except (OSError, EOFError, pickle.UnpicklingError) as err:
        # Missing or half-written pickles: tracking for this position has not finished yet.
        print(f"Position {position} still running ({err})")
        continue

    records = []
    for hour in range(n_frames):
        frame_nodes = track_graph.nodes_by_frame(hour)
        # Only nodes present in the solution graph can divide or fuse.
        tracked = [node for node in frame_nodes if node in nx_graph]
        n_divisions = sum(1 for node in tracked if nx_graph.out_degree(node) == 2)
        n_fusions = sum(1 for node in tracked if nx_graph.in_degree(node) == 2)
        records.append([hour, n_divisions, n_fusions, len(frame_nodes), position])

    df = pd.DataFrame(
        records, columns=["Hour", "N divisions", "N fusions", "N lumen", "position"]
    )
    df["N divisions RM"] = df["N divisions"].rolling(10, min_periods=1).mean()
    df["Smoothened number of fusions"] = df["N fusions"].rolling(10, min_periods=1).mean()
    # Frames without lumen give NaN/inf here; the rolling mean skips NaN.
    df["Fusions per lumen"] = df["N fusions"] / df["N lumen"]
    df["Smoothened fusions per lumen"] = (
        df["Fusions per lumen"].rolling(25, min_periods=1).mean()
    )
    df["Number of lumen"] = df["N lumen"]
    if position in condition_by_position:
        df["Condition"] = condition_by_position[position]
    position_frames.append(df)

if not position_frames:
    raise RuntimeError("No finished tracking results were found for any position")
all_dfs = pd.concat(position_frames, ignore_index=True)
all_dfs["Day"] = (all_dfs["Hour"] / 24) + 4  # hour 0 is mapped to day 4
pallette = {"Matrigel": "#17ad97", "No Matrix": "#4d4d4d", "Agarose": "#98d9d1"}

In [None]:
def extract_filtered(z, label, t_kernel_length, zarr_level, time_points=None, movie=None):
    """Extract one z-plane over time, mark pixels equal to 3, and median-filter along time.

    Args:
        z: index of the plane taken from each frame's label volume (axis 0).
        label: name of the label layer under ``movie[t]["labels"]``.
        t_kernel_length: temporal window of the median filter (frames).
        zarr_level: pyramid level key of the label layer.
        time_points: frames to use, in order; defaults to the global ``time_points_label``.
        movie: frame container indexed by time point; defaults to the global ``input_movie``.

    Returns:
        Boolean array of shape ``(len(time_points), 1, Y, X)``; row ``i`` corresponds to
        ``time_points[i]``.

    NOTE(review): class value 3 is hard-coded; presumably the lumen class of the label
    layer — confirm.
    """
    import scipy.ndimage

    if time_points is None:
        time_points = time_points_label
    if movie is None:
        movie = input_movie

    # Plane shape (Y, X) read from metadata instead of loading a slice.
    plane_shape = tuple(movie[0]["labels"][label][zarr_level].shape[1:])
    # Compare per frame into a bool buffer, so labels > 255 cannot wrap around to 3.
    mask = np.zeros((len(time_points), 1) + plane_shape, dtype=bool)
    for row, t in enumerate(time_points):
        mask[row, 0] = movie[t]["labels"][label][zarr_level][z, :, :] == 3
    return scipy.ndimage.median_filter(mask, size=(t_kernel_length, 1, 1, 1))


def concatenate_overlapping_arrays(arrays):
    """Merge arrays that share elements, directly or through a chain, into unique arrays.

    Two input arrays end up in the same output whenever they are linked by a chain of
    pairwise overlaps, so the returned arrays are pairwise disjoint.

    Args:
        arrays: iterable of 1-D array-likes.

    Returns:
        List of sorted, de-duplicated arrays (``np.unique`` of each merged group), in the
        order in which each group's first member appears in ``arrays``.
    """
    groups = []  # each group is a list of member arrays; groups stay pairwise disjoint
    for arr in arrays:
        hits = [
            idx
            for idx, group in enumerate(groups)
            if any(np.isin(arr, member).any() for member in group)
        ]
        if not hits:
            groups.append([arr])
            continue
        target = groups[hits[0]]
        target.append(arr)
        # `arr` bridges every other overlapping group into the earliest one; pop from the
        # back so the remaining indices (and the target's position) stay valid.
        for idx in reversed(hits[1:]):
            target.extend(groups.pop(idx))

    return [np.unique(np.concatenate(group)) for group in groups]

In [None]:
def _voxel_scale(level):
    """Factor converting voxel indices at pyramid ``level`` into plot coordinates.

    NOTE(review): 0.347 looks like a base pixel size and the extra factor 2 a fixed
    level-0 binning; both are inferred from the constants only — confirm.
    """
    return 2**level * 2 * 0.347


def _rescale_z(volume, z_factor):
    """Nearest-neighbour rescale of axis 0 only, cast to uint16 so labels stay integral."""
    return rescale(
        volume,
        [z_factor, 1, 1],
        order=0,
        anti_aliasing=False,
        preserve_range=True,
    ).astype(np.uint16)


def _load_isotropic_masks(t, level, rotate, select_tissue):
    """Read lumen labels and a tissue mask for frame ``t`` and rescale axis 0.

    Reads the globals ``input_movie`` and ``detection_name``. ``select_tissue`` maps the
    raw ``tissue_mask`` labels to a boolean mask. Returns ``(lumen_labels, organoid_mask)``,
    both uint16.
    """
    lumen_labels = input_movie[t]["labels"][detection_name][level][:]
    organoid_mask = select_tissue(input_movie[t]["labels"]["tissue_mask"][level][:])
    if rotate:
        # Rotate in the plane of axes 1 and 2 to match the orientation of the slices.
        lumen_labels = np.rot90(lumen_labels, 3, axes=(1, 2))
        organoid_mask = np.rot90(organoid_mask, 3, axes=(1, 2))
    z_factor = 2 / _voxel_scale(level)  # anisotropy correction for axis 0
    return _rescale_z(lumen_labels, z_factor), _rescale_z(organoid_mask, z_factor)


def _new_3d_axes(volume_shape, level):
    """Create the shared 6x6 inch 3-D figure with equal limits derived from ``volume_shape``.

    Returns ``(fig, ax, lightsource)``.
    """
    fig = plt.figure(figsize=(6, 6))
    fig.subplots_adjust(top=1, bottom=0, left=0, right=1, wspace=0)
    ax = fig.add_subplot(111, projection="3d")
    ax.set_box_aspect([1, 1, 1])
    max_size = (np.array(volume_shape) * _voxel_scale(level)).max()
    ax.set_xlim3d([0, max_size])
    ax.set_ylim3d([0, max_size])
    ax.set_zlim3d([0, max_size])
    ax.view_init(15, 15)
    ax.tick_params(axis="both", pad=5)
    return fig, ax, LightSource(azdeg=0, altdeg=-65)


def _mask_to_mesh(mask, level, step_size, smooth=True):
    """Triangulate a boolean mask with marching cubes and repair it with pymeshfix.

    Vertices are scaled to plot coordinates and rounded to ints. When ``smooth`` is
    true, 50 Taubin smoothing iterations are applied.
    """
    vertices, faces, _, _ = marching_cubes(mask.astype(int), 0, step_size=step_size)
    vertices, faces = pymeshfix.clean_from_arrays(vertices, faces)
    vertices = np.round(vertices * _voxel_scale(level)).astype(int)
    mesh = Trimesh(vertices=vertices, faces=faces)
    if smooth:
        mesh = filter_taubin(mesh, iterations=50)
    return mesh


def _plot_mesh(ax, mesh, lightsource, alpha, **style):
    """Draw ``mesh`` on a 3-D axes; ``style`` is forwarded (e.g. ``color`` or ``cmap``)."""
    ax.plot_trisurf(
        mesh.vertices[:, 0],
        mesh.vertices[:, 1],
        triangles=mesh.faces,
        Z=mesh.vertices[:, 2],
        lightsource=lightsource,
        alpha=alpha,
        **style,
    )


def _nonzero_labels(label_volume):
    """Sorted unique labels of ``label_volume``, excluding background 0."""
    labels = np.unique(label_volume)
    return labels[labels != 0]


def _draw_organoid(ax, organoid_mask, level, lightsource):
    """Overlay the tissue mask as a translucent grey surface (coarser marching-cubes step)."""
    mesh = _mask_to_mesh(organoid_mask == 1, level, step_size=2)
    _plot_mesh(ax, mesh, lightsource, alpha=0.2, cmap="gray")


def _figure_to_rgb(fig):
    """Render ``fig``, close it, and return an ``(H, W, 3)`` uint8 RGB copy."""
    fig.canvas.draw()
    rgba = np.asarray(fig.canvas.buffer_rgba())
    image = np.ascontiguousarray(rgba[..., :3])
    plt.close(fig)
    return image


def lumen_track_movie(t, level, axis_off=True, rotate=False):
    """Render frame ``t``: every tracked lumen coloured by track, over a translucent organoid.

    Reads the globals ``input_movie``, ``detection_name``, ``tracks`` (sequence of node-id
    arrays; lumen label ``l`` is looked up as node ``l - 1``) and ``cmap_tracks`` (one
    colour per entry of ``tracks``). Lumen contained in no track are not drawn; a lumen in
    several tracks takes the colour of the first one. The organoid is ``tissue_mask >= 2``.

    Returns:
        ``(H, W, 3)`` uint8 RGB image of the rendered figure.
    """
    lumen_mask_label, organoid_mask = _load_isotropic_masks(
        t, level, rotate, lambda tissue: tissue >= 2
    )
    fig, ax, ls = _new_3d_axes(lumen_mask_label.shape, level)

    for label in _nonzero_labels(lumen_mask_label):
        track_ids = [idx for idx, track in enumerate(tracks) if label - 1 in track]
        if not track_ids:
            continue
        mesh = _mask_to_mesh(lumen_mask_label == label, level, step_size=1)
        _plot_mesh(ax, mesh, ls, alpha=1.0, color=cmap_tracks[track_ids[0]])

    _draw_organoid(ax, organoid_mask, level, ls)
    if axis_off:
        ax.set_axis_off()
    else:
        ax.grid(False)
    return _figure_to_rgb(fig)


def lumen_track_one(t, level, track_num, axis_off=True):
    """Render frame ``t`` showing only lumen belonging to ``tracks[track_num]`` (always rotated).

    Lumen meshes are not smoothed here. Each drawn lumen takes the colour of the first
    track in ``tracks`` that contains it. Reads the same globals as ``lumen_track_movie``.

    Returns:
        ``(H, W, 3)`` uint8 RGB image of the rendered figure.
    """
    lumen_mask_label, organoid_mask = _load_isotropic_masks(
        t, level, True, lambda tissue: tissue >= 2
    )
    fig, ax, ls = _new_3d_axes(lumen_mask_label.shape, level)
    ax.grid(False)  # turn the 3-D gridlines off

    selected_track = tracks[track_num]
    for label in _nonzero_labels(lumen_mask_label):
        if label - 1 not in selected_track:
            continue
        first_track = next(idx for idx, track in enumerate(tracks) if label - 1 in track)
        mesh = _mask_to_mesh(lumen_mask_label == label, level, step_size=1, smooth=False)
        _plot_mesh(ax, mesh, ls, alpha=1.0, color=cmap_tracks[first_track])

    _draw_organoid(ax, organoid_mask, level, ls)
    if axis_off:
        ax.set_axis_off()
    else:
        ax.grid(False)
    return _figure_to_rgb(fig)


def thunder_plot(t, level, plot_organoid=True, rotate=True):
    """Render frame ``t`` with each lumen coloured by its fusion score.

    Reads the globals ``input_movie``, ``detection_name``, ``fusion_values`` (lumen label ->
    integer index into ``cmap_gray_yellow``; labels absent from it use index 0) and
    ``cmap_gray_yellow``. The optional organoid overlay is ``tissue_mask == 2``. Axes are
    always hidden.

    Returns:
        ``(H, W, 3)`` uint8 RGB image of the rendered figure.
    """
    lumen_mask_label, organoid_mask = _load_isotropic_masks(
        t, level, rotate, lambda tissue: tissue == 2
    )
    fig, ax, ls = _new_3d_axes(lumen_mask_label.shape, level)

    for label in _nonzero_labels(lumen_mask_label):
        mesh = _mask_to_mesh(lumen_mask_label == label, level, step_size=1)
        colour_index = int(fusion_values.get(label, 0))
        _plot_mesh(
            ax,
            mesh,
            ls,
            alpha=1.0,
            color=matplotlib.colors.to_hex(cmap_gray_yellow(colour_index)),
        )

    if plot_organoid:
        _draw_organoid(ax, organoid_mask, level, ls)

    ax.set_axis_off()
    return _figure_to_rgb(fig)

In [None]:
# Thunder-plot panels for three example positions at four time points (no organoid overlay).
# Each lumen is coloured by how few graph hops separate it from its closest preceding fusion.
label = "lumen_masks_smooth_5_processed"
detection_name = "lumen_masks_smooth_5_processed_detection"
max_time_points = 10
panel_hours = [48, 72, 96, 120]
panel_dir = "updated_figures_figure_2/thunder_plot_v3/colored_post_fusion"
os.makedirs(panel_dir, exist_ok=True)

# Grey -> yellow; `thunder_plot` indexes it with the integer values stored in `fusion_values`.
cmap_gray_yellow = matplotlib.colors.LinearSegmentedColormap.from_list(
    "", ["#5e5e5e", "#fff200"], N=max_time_points + 1
)

for position in [3, 12, 13]:
    zarr_path = f"/Morphodynamics_of_human_early_brain_organoid_development/tracking/lumen_tracking/Position_{position}_Settings_1_Processed_registered.zarr"
    input_movie = zarr.open(zarr_path, mode="r")[channel]
    track_path = f"/Morphodynamics_of_human_early_brain_organoid_development/tracking/lumen_tracking/tracks/Position_{position}_Settings_1_Processed/"
    with open(track_path + f"{channel}_{label}_track.pickle", "rb") as handle:
        nx_graph = pickle.load(handle)

    # A fusion is a node with two incoming edges.
    fused_nodes = [node for node, n_parents in nx_graph.in_degree() if n_parents == 2]

    # Multi-source BFS forward along edges from all fusions at once: the level at which a
    # node is first reached is its hop distance from the closest preceding fusion.
    hops_since_fusion = {node: 0 for node in fused_nodes}
    frontier = list(fused_nodes)
    for hops in range(1, max_time_points):
        next_frontier = []
        for node in frontier:
            for child in nx_graph.successors(node):
                if child not in hops_since_fusion:
                    hops_since_fusion[child] = hops
                    next_frontier.append(child)
        frontier = next_frontier

    # Lumen label = node id + 1. Just fused -> max_time_points, minus one per hop;
    # nodes max_time_points or more hops away (or unreachable) -> 0.
    fusion_values = {node + 1: 0 for node in nx_graph.nodes()}
    for node, hops in hops_since_fusion.items():
        fusion_values[node + 1] = max_time_points - hops

    panels = Parallel(n_jobs=3, backend="multiprocessing", verbose=5)(
        delayed(thunder_plot)(hour, 1, plot_organoid=False, rotate=False)
        for hour in panel_hours
    )
    for image, hour in zip(panels, panel_hours):
        imsave(f"{panel_dir}/thunder_plot_{position}_day_{4+(hour/24)}.png", image)

In [None]:
# Montage of one frame per position with every lumen coloured by its track.
# A track is a root lumen plus all of its descendants; overlapping tracks are merged.
positions = [1, 2, 3, 4, 5, 6, 7, 8, 13, 14, 15, 16, 9, 10, 11, 12]
montage_time_points = (60, 60, 60, 60, 96, 96, 96, 96, 96, 96, 96, 96, 97, 97, 97, 97)
label = "lumen_masks_smooth_5_processed"
detection_name = "lumen_masks_smooth_5_processed_detection"

all_lumen_track = []
for position, time_point in tqdm(
    zip(positions, montage_time_points), total=len(positions)
):
    zarr_path = f"/Morphodynamics_of_human_early_brain_organoid_development/tracking/lumen_tracking/Position_{position}_Settings_1_Processed_registered.zarr"
    input_movie = zarr.open(zarr_path, mode="r")[channel]
    track_path = f"/Morphodynamics_of_human_early_brain_organoid_development/tracking/lumen_tracking/tracks/Position_{position}_Settings_1_Processed/"
    with open(track_path + f"{channel}_{label}_track.pickle", "rb") as handle:
        nx_graph = pickle.load(handle)

    # One track per root (in-degree 0), visiting roots component by component.
    tracks = []
    for component in nx.weakly_connected_components(nx_graph):
        for root, n_parents in nx_graph.subgraph(component).in_degree():
            if n_parents != 0:
                continue
            if nx_graph.out_degree(root) == 0:
                # An edge-less root keeps an empty slot (and thus a palette colour)
                # but colours no lumen.
                tracks.append(np.array([]))
            else:
                tracks.append(np.array(sorted(nx.descendants(nx_graph, root) | {root})))

    tracks = concatenate_overlapping_arrays(tracks)
    cmap_tracks = sns.color_palette("husl", len(tracks))
    all_lumen_track.append(lumen_track_movie(time_point, 1, axis_off=False))

montage_path = "updated_figures_figure_2/track_colored_16_no_grid.png"
os.makedirs(os.path.dirname(montage_path), exist_ok=True)
imsave(
    montage_path,
    skimage.util.montage(np.array(all_lumen_track), channel_axis=-1),
)

In [None]:
# Per-position movies (125 frames) with every lumen coloured by its track,
# written as MP4 and as ImageJ TIFF.
label = "lumen_masks_smooth_5_processed"
detection_name = "lumen_masks_smooth_5_processed_detection"
movie_dir = "movies/same_track_color_movies"
os.makedirs(f"{movie_dir}/mp4", exist_ok=True)
os.makedirs(f"{movie_dir}/tiff", exist_ok=True)

for position in range(1, 17):
    zarr_path = f"/Morphodynamics_of_human_early_brain_organoid_development/tracking/lumen_tracking/Position_{position}_Settings_1_Processed_registered.zarr"
    input_movie = zarr.open(zarr_path, mode="r")[channel]
    track_path = f"/Morphodynamics_of_human_early_brain_organoid_development/tracking/lumen_tracking/tracks/Position_{position}_Settings_1_Processed/"
    with open(track_path + f"{channel}_{label}_track.pickle", "rb") as handle:
        nx_graph = pickle.load(handle)

    # One track per root (in-degree 0): the root plus all of its descendants.
    tracks = []
    for component in nx.weakly_connected_components(nx_graph):
        for root, n_parents in nx_graph.subgraph(component).in_degree():
            if n_parents != 0:
                continue
            if nx_graph.out_degree(root) == 0:
                # An edge-less root keeps an empty slot (and thus a palette colour)
                # but colours no lumen.
                tracks.append(np.array([]))
            else:
                tracks.append(np.array(sorted(nx.descendants(nx_graph, root) | {root})))

    tracks = concatenate_overlapping_arrays(tracks)
    cmap_tracks = sns.color_palette("husl", len(tracks))

    frames = np.array(
        Parallel(n_jobs=16, backend="multiprocessing", verbose=5)(
            delayed(lumen_track_movie)(time_point, 1) for time_point in range(125)
        )
    )

    with imageio.get_writer(
        f"{movie_dir}/mp4/colored_by_track_{position}.mp4", fps=20
    ) as writer:
        for frame in frames:
            writer.append_data(frame)

    # Move the RGB axis from last to second place to match the declared TCYX layout.
    imsave(
        f"{movie_dir}/tiff/colored_by_track_{position}.tiff",
        np.moveaxis(frames, -1, 1),
        imagej=True,
        metadata={"axes": "TCYX"},
        compression="zlib",
    )

In [None]:
# Thunder-plot movies: every frame of every position, lumen coloured by how few graph hops
# separate them from their closest preceding fusion (brightest = just fused).
label = "lumen_masks_smooth_5_processed"
detection_name = "lumen_masks_smooth_5_processed_detection"
max_time_points = 10
movie_dir = "movies/same_track_color_movies"
os.makedirs(f"{movie_dir}/mp4", exist_ok=True)
os.makedirs(f"{movie_dir}/tiff", exist_ok=True)

# Grey -> yellow; `thunder_plot` indexes it with the integer values stored in `fusion_values`.
cmap_gray_yellow = matplotlib.colors.LinearSegmentedColormap.from_list(
    "", ["#5e5e5e", "#fff200"], N=max_time_points + 1
)

for position in range(1, 17):
    zarr_path = f"/Morphodynamics_of_human_early_brain_organoid_development/tracking/lumen_tracking/Position_{position}_Settings_1_Processed_registered.zarr"
    input_movie = zarr.open(zarr_path, mode="r")[channel]
    track_path = f"/Morphodynamics_of_human_early_brain_organoid_development/tracking/lumen_tracking/tracks/Position_{position}_Settings_1_Processed/"
    with open(track_path + f"{channel}_{label}_track.pickle", "rb") as handle:
        nx_graph = pickle.load(handle)

    # A fusion is a node with two incoming edges.
    fused_nodes = [node for node, n_parents in nx_graph.in_degree() if n_parents == 2]

    # Multi-source BFS forward along edges from all fusions at once: the level at which a
    # node is first reached is its hop distance from the closest preceding fusion.
    hops_since_fusion = {node: 0 for node in fused_nodes}
    frontier = list(fused_nodes)
    for hops in range(1, max_time_points):
        next_frontier = []
        for node in frontier:
            for child in nx_graph.successors(node):
                if child not in hops_since_fusion:
                    hops_since_fusion[child] = hops
                    next_frontier.append(child)
        frontier = next_frontier

    # Lumen label = node id + 1. Just fused -> max_time_points, minus one per hop;
    # nodes max_time_points or more hops away (or unreachable) -> 0.
    fusion_values = {node + 1: 0 for node in nx_graph.nodes()}
    for node, hops in hops_since_fusion.items():
        fusion_values[node + 1] = max_time_points - hops

    frames = np.array(
        Parallel(n_jobs=16, backend="multiprocessing", verbose=5)(
            delayed(thunder_plot)(time_point, 1) for time_point in range(125)
        )
    )

    with imageio.get_writer(
        f"{movie_dir}/mp4/thunder_plot_{position}.mp4", fps=20
    ) as writer:
        for frame in frames:
            writer.append_data(frame)

    # Move the RGB axis from last to second place to match the declared TCYX layout.
    imsave(
        f"{movie_dir}/tiff/thunder_plot_{position}.tiff",
        np.moveaxis(frames, -1, 1),
        imagej=True,
        metadata={"axes": "TCYX"},
        compression="zlib",
    )

In [None]:
# "Reversed" thunder-plot movies: lumen coloured by how few graph hops remain until their
# closest upcoming fusion (brightest = about to fuse / fusing).
label = "lumen_masks_smooth_5_processed"
detection_name = "lumen_masks_smooth_5_processed_detection"
max_time_points = 10
movie_dir = "movies/same_track_color_movies"
os.makedirs(f"{movie_dir}/mp4", exist_ok=True)
os.makedirs(f"{movie_dir}/tiff", exist_ok=True)

# Grey -> yellow; `thunder_plot` indexes it with the integer values stored in `fusion_values`.
cmap_gray_yellow = matplotlib.colors.LinearSegmentedColormap.from_list(
    "", ["#5e5e5e", "#fff200"], N=max_time_points + 1
)

for position in range(1, 17):
    zarr_path = f"/Morphodynamics_of_human_early_brain_organoid_development/tracking/lumen_tracking/Position_{position}_Settings_1_Processed_registered.zarr"
    input_movie = zarr.open(zarr_path, mode="r")[channel]
    track_path = f"/Morphodynamics_of_human_early_brain_organoid_development/tracking/lumen_tracking/tracks/Position_{position}_Settings_1_Processed/"
    with open(track_path + f"{channel}_{label}_track.pickle", "rb") as handle:
        nx_graph = pickle.load(handle)

    # A fusion is a node with two incoming edges.
    fused_nodes = [node for node, n_parents in nx_graph.in_degree() if n_parents == 2]

    # Multi-source BFS backward along edges from all fusions at once: the level at which a
    # node is first reached is its hop distance to the closest fusion downstream of it.
    hops_until_fusion = {node: 0 for node in fused_nodes}
    frontier = list(fused_nodes)
    for hops in range(1, max_time_points):
        next_frontier = []
        for node in frontier:
            for parent in nx_graph.predecessors(node):
                if parent not in hops_until_fusion:
                    hops_until_fusion[parent] = hops
                    next_frontier.append(parent)
        frontier = next_frontier

    # Lumen label = node id + 1. Fusing now -> max_time_points, minus one per hop;
    # nodes max_time_points or more hops away (or never fusing) -> 0.
    fusion_values = {node + 1: 0 for node in nx_graph.nodes()}
    for node, hops in hops_until_fusion.items():
        fusion_values[node + 1] = max_time_points - hops

    frames = np.array(
        Parallel(n_jobs=16, backend="multiprocessing", verbose=5)(
            delayed(thunder_plot)(time_point, 1) for time_point in range(125)
        )
    )

    with imageio.get_writer(
        f"{movie_dir}/mp4/reversed_thunder_plot_{position}.mp4",
        fps=20,
    ) as writer:
        for frame in frames:
            writer.append_data(frame)

    # Move the RGB axis from last to second place to match the declared TCYX layout.
    imsave(
        f"{movie_dir}/tiff/reversed_thunder_plot_{position}.tiff",
        np.moveaxis(frames, -1, 1),
        imagej=True,
        metadata={"axes": "TCYX"},
        compression="zlib",
    )