In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pymeshfix
import skimage
from matplotlib.colors import LightSource
from skimage.io import imread, imsave
from skimage.measure import marching_cubes
from skimage.transform import rescale
from tqdm import tqdm
from trimesh import Trimesh
from trimesh.smoothing import filter_taubin

In [None]:
def set_axes_equal(ax: plt.Axes):
    """Make a 3D axes cubic: equal box aspect and equal data span per axis.

    The current x/y/z limits are read, each axis keeps its own midpoint, and
    all three axes are widened to the largest of the three spans so that one
    data unit has the same on-screen length in every direction.

    Parameters
    ----------
    ax : plt.Axes
        Axes exposing the ``get_*lim3d`` / ``set_*lim3d`` accessors (i.e. one
        created with ``projection="3d"``). It is modified in place.
    """
    ax.set_box_aspect([1, 1, 1])

    getters = (ax.get_xlim3d, ax.get_ylim3d, ax.get_zlim3d)
    setters = (ax.set_xlim3d, ax.set_ylim3d, ax.set_zlim3d)

    # One (low, high) row per axis, in x, y, z order.
    bounds = np.array([read_limits() for read_limits in getters])
    midpoints = bounds.mean(axis=1)
    # Half of the widest span; applied to every axis around its own midpoint.
    half_span = 0.5 * np.abs(bounds[:, 1] - bounds[:, 0]).max()

    for write_limits, mid in zip(setters, midpoints):
        write_limits([mid - half_span, mid + half_span])

In [None]:
def _smoothed_surface_mesh(binary_mask, step_size, voxel_size, iterations=50):
    """Convert a binary volume into a Taubin-smoothed triangle mesh.

    The surface is extracted with marching cubes (iso-level 0), repaired with
    pymeshfix, scaled from voxel indices by ``voxel_size`` and rounded to
    integers, then smoothed with ``filter_taubin``.
    """
    vertices, faces, _, _ = marching_cubes(
        binary_mask.astype(int), 0, step_size=step_size
    )
    vertices, faces = pymeshfix.clean_from_arrays(vertices, faces)
    vertices = np.round(vertices * voxel_size).astype(int)
    mesh = Trimesh(vertices=vertices, faces=faces)
    return filter_taubin(mesh, iterations=iterations)


def create_surface_plot(
    lumen_mask_label, organoid_mask, file_name, pad=0, voxel_size=4 * 0.347
):
    """Render labelled lumina inside a translucent organoid hull and save it.

    Every non-zero label in ``lumen_mask_label`` is drawn as its own opaque
    surface; the voxels of ``organoid_mask`` equal to 1 are drawn once as a
    gray surface with ``alpha=0.3``. The figure is written to ``file_name``
    at 300 dpi and closed afterwards, even if plotting fails.

    Parameters
    ----------
    lumen_mask_label : ndarray
        3D integer label volume; 0 is treated as background. Labels are taken
        from the values actually present, so gaps in the numbering are fine.
    organoid_mask : ndarray
        3D volume of the same grid; voxels equal to 1 form the hull.
    file_name : str or path-like
        Destination handed to ``Figure.savefig``.
    pad : float, optional
        Extra length added to the common upper axis limit.
    voxel_size : float, optional
        Factor converting voxel indices to plot units. The default,
        ``4 * 0.347``, is the value previously hard-coded here (presumably
        the physical voxel edge length -- confirm against the acquisition).
    """
    fig = plt.figure(figsize=(10, 10))
    try:
        fig.subplots_adjust(top=1, bottom=0, left=0, right=1, wspace=0)
        ax = fig.add_subplot(111, projection="3d")
        ax.set_box_aspect([1, 1, 1])

        # Same [0, max_size] range on all axes so the cubic box is undistorted.
        max_size = (np.array(lumen_mask_label.shape) * voxel_size).max() + pad
        ax.set_xlim3d([0, max_size])
        ax.set_ylim3d([0, max_size])
        ax.set_zlim3d([0, max_size])
        ax.view_init(15, 15)
        ax.tick_params(axis="both", pad=5)
        ls = LightSource(azdeg=0, altdeg=-65)

        # Iterate over the labels that exist instead of assuming 1..N with a
        # background present; a missing label would otherwise yield an empty
        # mask and the highest label could be skipped.
        present_labels = np.unique(lumen_mask_label)
        present_labels = present_labels[present_labels != 0]
        for label in tqdm(present_labels):
            lumen_mesh = _smoothed_surface_mesh(
                lumen_mask_label == label, step_size=1, voxel_size=voxel_size
            )
            ax.plot_trisurf(
                lumen_mesh.vertices[:, 0],
                lumen_mesh.vertices[:, 1],
                triangles=lumen_mesh.faces,
                Z=lumen_mesh.vertices[:, 2],
                lightsource=ls,
                alpha=1.0,
            )

        # The hull is extracted at a coarser step than the lumina.
        organoid_mesh = _smoothed_surface_mesh(
            organoid_mask == 1, step_size=2, voxel_size=voxel_size
        )
        ax.plot_trisurf(
            organoid_mesh.vertices[:, 0],
            organoid_mesh.vertices[:, 1],
            triangles=organoid_mesh.faces,
            Z=organoid_mesh.vertices[:, 2],
            lightsource=ls,
            alpha=0.3,
            cmap="gray",
        )

        # Hide the grid lines by making them fully transparent. This touches
        # a private matplotlib attribute, kept as in the original rendering.
        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis._axinfo["grid"]["color"] = (1, 1, 1, 0)

        # Fixed x ticks; labels are re-set only to change their alignment.
        ax.set_xticks([0, 200, 400, 600])
        ax.set_xticklabels(ax.get_xticks(), va="bottom")

        fig.savefig(file_name, pad_inches=0, dpi=300)
    finally:
        # Always release the figure; this function is called in long loops.
        plt.close(fig)

In [None]:
# Single position / time point for ad-hoc runs. The loops in the following
# cells rebind both names before reading them.
position, time_point = 7, 152

In [None]:
# Multimo: render one 3D lumen/organoid plot per position at three hours of
# interest (original note: start, max number of lumina, max lumen volume).
# The factor 2 turns hours into frame counts; the output file name below
# converts back with (time_point - 1) / 2.
times = np.array([21, 40, 60]) * 2
positions = list(range(1, 17))

# Voxel geometry, named once because it drives both the size filter and the
# resampling. Units are presumably micrometres -- confirm with acquisition.
XY_SPACING = 4 * 0.347  # in-plane voxel size
Z_SPACING = 2  # spacing along axis 0 (presumably z)
MIN_LUMEN_VOLUME = 20000  # lumina smaller than this physical volume are dropped
MIN_LUMEN_VOXELS = MIN_LUMEN_VOLUME / (XY_SPACING * XY_SPACING * Z_SPACING)
# Stretch axis 0 so that voxels become isotropic at the in-plane spacing.
ISOTROPIC_SCALE = [Z_SPACING / XY_SPACING, 1, 1]
# Nearest-neighbour resampling keeps label values intact.
RESCALE_KWARGS = dict(order=0, anti_aliasing=False, preserve_range=True)

# Save in extended_figure folder
for position in tqdm(positions):
    mask_dir = f"/20210503_201032_6_lines_mosaic_HB4_D4_processed/Position_{position}_Settings_1_Processed/lumen_masks_v19_05_2023/"
    for frame in times:
        # Mask files are numbered one higher than the frame count
        # (presumably 1-based file indices -- confirm).
        time_point = frame + 1

        # Load the combined mask and rotate each slice by 3 x 90 degrees.
        combined_masks = imread(
            mask_dir + f"{time_point:04}" + "_lumen_organoid_mask_processed.tif"
        )
        combined_masks = np.rot90(combined_masks, 3, axes=(1, 2))

        # Values >= 2 are used as organoid, value 3 as lumen.
        organoid_mask = combined_masks >= 2
        lumen_mask = skimage.morphology.remove_small_objects(
            (combined_masks == 3), min_size=MIN_LUMEN_VOXELS
        ).astype(bool)
        lumen_mask_label = skimage.measure.label(lumen_mask)

        lumen_mask_label = rescale(
            lumen_mask_label, ISOTROPIC_SCALE, **RESCALE_KWARGS
        ).astype(np.uint16)
        organoid_mask = rescale(
            organoid_mask, ISOTROPIC_SCALE, **RESCALE_KWARGS
        ).astype(np.uint16)

        file_name = f"extended_figures/Multimo_3D_min_lumen_most_lumen_peak_lumen/3D_plot_lumen_hour_{((time_point-1)/2)}_position_{position}.png"
        create_surface_plot(lumen_mask_label, organoid_mask, file_name)

In [None]:
# Multimo: tile the per-position 3D plots into one grid image per hour.
times = np.array([21, 40, 60])
positions = list(range(1, 17))
image_folder = "extended_figures/Multimo_3D_min_lumen_most_lumen_peak_lumen/"
GRID_COLUMNS = 4  # images per grid row; len(positions) must be a multiple

# Save in extended_figure folder
for hour in times:
    all_image_3d_lum = []
    for position in tqdm(positions):
        # float(hour) reproduces the "21.0"-style hour written by the
        # rendering cell's file names.
        image_3d_lum = imread(
            image_folder
            + f"3D_plot_lumen_hour_{float(hour)}_position_{position}.png"
        )
        # Crop fixed margins around the rendered axes.
        image_3d_lum = image_3d_lum[250:-100, 150:-200]
        all_image_3d_lum.append(image_3d_lum)

    # Join GRID_COLUMNS images side by side per row, then stack the rows.
    grid_rows = [
        np.hstack(all_image_3d_lum[start : start + GRID_COLUMNS])
        for start in range(0, len(all_image_3d_lum), GRID_COLUMNS)
    ]
    stack_vert = np.vstack(grid_rows)
    imsave(image_folder + f"3D_grid_lumen_{hour}.png", stack_vert)