In [1]:
import cupy as cp
import h5py
import numpy as np
import skimage.io
from cupyx.scipy.ndimage import gaussian_filter, rotate
from matplotlib.colors import ListedColormap
from skimage import exposure
from skimage.transform import rescale
from tifffile import imsave
from tqdm import tqdm

# CuPy keeps freed device and pinned-host allocations cached in pools; release
# the unused cached blocks (e.g. left over from an earlier run of this notebook).
mempool, pinned_mempool = (
    cp.get_default_memory_pool(),
    cp.get_default_pinned_memory_pool(),
)
pinned_mempool.free_all_blocks()
mempool.free_all_blocks()

# Implementation following royerlab's DEXP (https://github.com/royerlab/dexp).


def attenuation_filter(
    image, attenuation_min_density, attenuation, attenuation_filtering, axis=0
):
    """Darken voxels in proportion to the density accumulated in front of them.

    The (optionally Gaussian-blurred) image is treated as a density, mixed with
    a constant baseline, accumulated with a cumulative sum along ``axis``, and
    the image is multiplied by ``exp(-attenuation * cumulative_density)``.

    Args:
        image: array supported by ``cp`` / ``gaussian_filter`` (CuPy arrays in
            this notebook).
        attenuation_min_density: baseline density blended in at every voxel,
            so attenuation still grows with depth through empty regions.
        attenuation: coefficient in the exponent (attenuation strength).
        attenuation_filtering: Gaussian sigma used to smooth the density
            estimate; values ``<= 0`` disable smoothing.
        axis: axis along which density accumulates (default 0).

    Returns:
        A new attenuated array. ``image`` itself is left unmodified.
    """
    if attenuation_filtering > 0:
        density_source = gaussian_filter(image, sigma=attenuation_filtering)
    else:
        density_source = image

    cum_density = cp.cumsum(
        attenuation_min_density + (1 - attenuation_min_density) * density_source,
        axis=axis,
    )
    # Out-of-place multiply: an in-place `image *= ...` would overwrite the
    # caller's array and fail with a casting error for integer-typed input.
    return image * cp.exp(-attenuation * cum_density)


def create_colored_image(im_proj, lower_percentile=0.5, upper_percentile=99.5):
    """Render a two-channel projection as a green/magenta RGBA composite.

    ``im_proj`` is indexed as ``[channel, y, x]``: channel 0 is shown in green,
    channel 1 in magenta. The whole array is first rescaled jointly to [0, 1];
    each channel is then contrast-stretched so that its
    ``[lower_percentile, upper_percentile]`` window spans the full colormap.
    The two colored layers are merged with a per-pixel, per-component maximum.

    Returns:
        The RGBA composite produced by the matplotlib colormaps.
    """

    def _stretch_channel(channel):
        # Map this channel's percentile window onto [0, 1] (values outside are clipped).
        low, high = np.percentile(channel, q=(lower_percentile, upper_percentile))
        return exposure.rescale_intensity(
            channel, in_range=(low, high), out_range=np.float32
        )

    green_lut = ListedColormap([[0, level / 255, 0] for level in range(256)], "Green")
    magenta_lut = ListedColormap(
        [[level / 255, 0, level / 255] for level in range(256)], "Magenta"
    )

    normalized = exposure.rescale_intensity(im_proj, out_range=(0, 1))
    green_layer = _stretch_channel(normalized[0, :, :])
    magenta_layer = _stretch_channel(normalized[1, :, :])

    return np.maximum(green_lut(green_layer), magenta_lut(magenta_layer))


# Experiment directory used by read_tiff_stacks when input_dir is empty.
_DEFAULT_EXPERIMENT_DIR = (
    "/cluster/project/treutlein/DATA/imaging/viventis/"
    "20210503_201032_6_lines_mosaic_HB4_D4_processed"
)


def _load_channel(stack_file, time_point, channel, scale_factor, pad_size):
    """Read ``t<time_point:05>/<channel>/0/cells`` into memory, optionally rescaled.

    With ``scale_factor=None`` the dataset is returned as stored. Otherwise it is
    rescaled with anti-aliasing and zero-padded by ``pad_size`` on both sides
    of every axis.
    """
    # `[()]` reads the whole dataset into a fresh NumPy array, so no copy is needed.
    stack = stack_file[f"t{time_point:05}"][channel]["0"]["cells"][()]
    if scale_factor is None:
        return stack
    stack = rescale(stack, scale_factor, anti_aliasing=True)
    return np.pad(stack, (pad_size, pad_size), "constant")


def read_tiff_stacks(input_dir, time_point, scale_factor, pad_size=32, position=None):
    """Load the GFP (``s00``) and mCherry (``s01``) stacks for one time point.

    Args:
        input_dir: experiment directory containing the
            ``Position_<n>_Settings_1_Processed`` folders. An empty string falls
            back to the hard-coded cluster path used originally.
        time_point: time index; read from the HDF5 group ``t<time_point:05>``.
        scale_factor: per-axis factors for ``skimage.transform.rescale``. When
            every factor equals 1 the stacks are returned unscaled.
        pad_size: zero-padding added on both sides of every axis after
            rescaling.
        position: Viventis position number. Defaults to the notebook-level
            global ``position`` for backward compatibility.

    Returns:
        ``(stack_gfp, stack_mcherry)`` as NumPy arrays.
    """
    import os

    if position is None:
        # Backward compatibility: earlier versions read the notebook global.
        position = globals()["position"]

    experiment_dir = input_dir or _DEFAULT_EXPERIMENT_DIR
    stack_path = os.path.join(
        experiment_dir,
        f"Position_{str(position)}_Settings_1_Processed",
        "denoised_registered_processed.h5",
    )

    # Works for lists, tuples and arrays alike (a tuple never equals a list).
    needs_rescale = not np.all(np.asarray(scale_factor, dtype=float) == 1.0)
    # NOTE(review): padding is only applied together with rescaling; at scale 1
    # the stacks are returned unpadded (original behaviour) — confirm intended.
    factor = scale_factor if needs_rescale else None

    # Context manager guarantees the HDF5 handle is closed after reading.
    with h5py.File(stack_path, "r") as stack_file:
        stack_gfp = _load_channel(stack_file, time_point, "s00", factor, pad_size)
        stack_mcherry = _load_channel(stack_file, time_point, "s01", factor, pad_size)

    return stack_gfp, stack_mcherry


def _rotate_stack(stack, angle, rotation_axes):
    """Rotate one stack on the GPU for a movie frame.

    The stack is rotated by ``|angle - 180|`` degrees in the plane spanned by
    ``rotation_axes`` (output shape kept, out-of-bounds filled with 0). NaNs
    are then zeroed and intensities multiplied by 1000. Returns a CuPy array.
    """
    # NOTE(review): |angle - 180| maps angles 1..360 to 179 -> 0 -> 180 degrees,
    # so the view swings back and then forward rather than making one full turn.
    # Confirm this is intended.
    rotated = rotate(
        cp.asarray(stack),
        np.abs(angle - 180),
        mode="constant",
        axes=rotation_axes,
        reshape=False,
    )
    return cp.nan_to_num(rotated) * 1000


def _render_movie_frame(
    stack_gfp,
    stack_mcherry,
    angle,
    rotate_stacks,
    rotation_axes,
    attenuation,
    attenuation_min_density,
    attenuation_strength,
    attenuation_filtering,
    MIP,
):
    """Build one host-side (NumPy) movie frame from the two channel stacks.

    Channel 0 is GFP, channel 1 mCherry. When ``MIP`` is true the
    channel-stacked volume is max-projected along axis 1. Otherwise the
    unprojected volume is returned (shape ``(2, Z, Y, X)`` assuming 3-D stacks).
    """
    if rotate_stacks:
        channels = [
            _rotate_stack(stack, angle, rotation_axes)
            for stack in (stack_gfp, stack_mcherry)
        ]
    else:
        channels = [stack_gfp, stack_mcherry]

    if attenuation:
        channels = [
            attenuation_filter(
                cp.asarray(channel),
                attenuation_min_density,
                attenuation_strength,
                attenuation_filtering,
            )
            for channel in channels
        ]

    # cp.asnumpy accepts both CuPy and NumPy arrays, so every branch
    # combination ends up on the host.
    image = np.stack([cp.asnumpy(channel) for channel in channels], axis=0)
    if MIP:
        return np.max(image, axis=1)
    return image


def create_movie(
    input_dir,
    time_point_start=1,
    time_point_stop=1,
    start_angle=1,
    stop_angle=360,
    n_frames=213,
    scale=0.25,
    voxel_sizes=(2, 0.347, 0.347),
    attenuation=True,
    MIP=True,
    stack_slice=None,
    run_through_slice=False,
    rotation_axes=(0, 2),
    attenuation_filtering=4,
    attenuation_min_density=0.002,
    attenuation_strength=0.01,
    pad_size=32,
):
    """Render a (rotating) two-channel movie as a list of frames.

    Args:
        input_dir: forwarded to ``read_tiff_stacks``.
        time_point_start, time_point_stop: time range. Equal values render a
            single time point, loaded once.
        start_angle, stop_angle: angle range swept over the frames. Equal values
            disable rotation.
        n_frames: number of frames when ``run_through_slice`` is false.
        scale: downscaling applied to the finest voxel axis. Other axes are
            scaled proportionally so output voxels are isotropic.
        voxel_sizes: physical voxel size per axis.
        attenuation: apply ``attenuation_filter`` to each channel.
        MIP: max-project each frame along the stack axis.
        stack_slice: currently unused; kept for API compatibility.
        run_through_slice: use one frame per slice of the single loaded time
            point. Frames are still rendered whole; no slice is selected.
        rotation_axes: plane of rotation passed to ``rotate``.
        attenuation_filtering, attenuation_min_density, attenuation_strength:
            forwarded to ``attenuation_filter``.
        pad_size: forwarded to ``read_tiff_stacks``.

    Returns:
        List of NumPy frames, one per (angle, time point) pair.

    Raises:
        ValueError: if ``run_through_slice`` is requested for more than one
            time point.
    """
    # Factor for axis i is scale * voxel_i / min(voxel); output voxel size is
    # therefore min(voxel) / scale on every axis.
    min_voxel = np.min(voxel_sizes)
    scale_factor = [scale * (size / min_voxel) for size in voxel_sizes]
    print(scale_factor)

    single_time_point = time_point_start == time_point_stop
    if single_time_point:
        stack_gfp_downscaled, stack_mcherry_downscaled = read_tiff_stacks(
            input_dir, time_point_start, scale_factor, pad_size=pad_size
        )

    if run_through_slice:
        if not single_time_point:
            raise ValueError(
                "run_through_slice=True requires time_point_start == time_point_stop"
            )
        time_range = np.arange(0, len(stack_gfp_downscaled)).astype(int)
        angle_range = np.linspace(start_angle, stop_angle, len(time_range))
    else:
        angle_range = np.linspace(start_angle, stop_angle, n_frames)
        time_range = np.linspace(time_point_start, time_point_stop, n_frames).astype(
            int
        )

    print(angle_range)
    print(time_range)

    assert len(angle_range) == len(time_range)

    rotate_stacks = start_angle != stop_angle
    ims = []
    tp_old = -1
    for angle, time_point in tqdm(
        zip(angle_range, time_range), total=len(angle_range)
    ):
        if not single_time_point and time_point != tp_old:
            stack_gfp_downscaled, stack_mcherry_downscaled = read_tiff_stacks(
                input_dir, time_point, scale_factor, pad_size=pad_size
            )
            # Remember what is loaded so consecutive frames sharing a time point
            # reuse it instead of re-reading the HDF5 file.
            tp_old = time_point

        ims.append(
            _render_movie_frame(
                stack_gfp_downscaled,
                stack_mcherry_downscaled,
                angle,
                rotate_stacks,
                rotation_axes,
                attenuation,
                attenuation_min_density,
                attenuation_strength,
                attenuation_filtering,
                MIP,
            )
        )
    return ims

In [2]:
# Viventis position to render; read_tiff_stacks picks up this module-level name.
position = 10
ims = create_movie(
    input_dir="",
    # Single time point (188), 360 frames with the angle parameter going 1 -> 360.
    time_point_start=188,
    time_point_stop=188,
    n_frames=360,
    start_angle=1,
    stop_angle=360,
    # Geometry: physical voxel sizes per axis and downscaling of the finest axis.
    voxel_sizes=[2, 0.347, 0.347],
    scale=0.25,
    # Rendering options.
    attenuation=True,
    MIP=True,
    run_through_slice=False,
    stack_slice=None,
)

[1.4409221902017293, 0.25, 0.25]
[  1.   2.   3.   4.   5.   6.   7.   8.   9.  10.  11.  12.  13.  14.
  15.  16.  17.  18.  19.  20.  21.  22.  23.  24.  25.  26.  27.  28.
  29.  30.  31.  32.  33.  34.  35.  36.  37.  38.  39.  40.  41.  42.
  43.  44.  45.  46.  47.  48.  49.  50.  51.  52.  53.  54.  55.  56.
  57.  58.  59.  60.  61.  62.  63.  64.  65.  66.  67.  68.  69.  70.
  71.  72.  73.  74.  75.  76.  77.  78.  79.  80.  81.  82.  83.  84.
  85.  86.  87.  88.  89.  90.  91.  92.  93.  94.  95.  96.  97.  98.
  99. 100. 101. 102. 103. 104. 105. 106. 107. 108. 109. 110. 111. 112.
 113. 114. 115. 116. 117. 118. 119. 120. 121. 122. 123. 124. 125. 126.
 127. 128. 129. 130. 131. 132. 133. 134. 135. 136. 137. 138. 139. 140.
 141. 142. 143. 144. 145. 146. 147. 148. 149. 150. 151. 152. 153. 154.
 155. 156. 157. 158. 159. 160. 161. 162. 163. 164. 165. 166. 167. 168.
 169. 170. 171. 172. 173. 174. 175. 176. 177. 178. 179. 180. 181. 182.
 183. 184. 185. 186. 187. 188. 189. 190. 191

360it [26:11,  4.36s/it]


In [3]:
# Shape of the assembled movie, computed without np.array(ims), which would
# copy every frame into one new array just to read its shape.
(len(ims), *np.shape(ims[0])) if ims else (0,)

(360, 2, 638, 638)

In [4]:
# Write the movie as an ImageJ hyperstack (axes TCYX: frame, channel, Y, X).
# XY pixel size = raw XY voxel size (0.347 um) / downscaling factor (0.25).
xy_pixel_size_um = 0.347 / 0.25
imsave(
    "rotating_movie.tiff",
    # Cast straight to float32 instead of building a full-precision copy first.
    np.array(ims, dtype=np.float32),
    imagej=True,
    resolution=(1.0 / xy_pixel_size_um, 1.0 / xy_pixel_size_um),
    metadata={"unit": "um", "axes": "TCYX"},
    compression="zlib",
)