In [None]:
import matplotlib.pyplot as plt
import numpy as np
import h5py
from skimage.io import imread, imsave
from skimage.registration import phase_cross_correlation
from scipy.fftpack import fft2, fftshift, ifft2, ifftshift

In [None]:
import h5py_cache
import paulssonlab.deaton.trenchripper.trenchripper as tr
from paulssonlab.deaton.trenchripper.trenchripper import pandas_hdf5_handler
from paulssonlab.deaton.trenchripper.trenchripper import writedir

In [None]:
# Experiment root directory: fft_drift_correction_cluster reads <headpath>/metadata.hdf5
# and <headpath>/hdf5/, and writes its output folder (and the dask working dir) under it.
# NOTE(review): absolute, user-specific cluster path — adjust before running elsewhere.
headpath = "/home/de64/scratch/de64/sync_folder/2021-01-28_lDE14/gfp"

In [None]:
class fft_drift_correction:
    """Translation-only drift correction for (T, Y, X) image stacks.

    Shifts are estimated by phase cross-correlation between each frame and a
    single reference frame, after both have been filtered in Fourier space
    with a cross-shaped mask (``make_mask_cross`` + ``filter_fourier``). The
    integer shifts are then applied with ``drift_correct_images``.

    Args:
        vert_cross_w (int): Cross width intended to be passed as ``wy`` to
            ``make_mask_cross``. This base class only stores it.
        horiz_cross_w (int): Cross width intended to be passed as ``wx`` to
            ``make_mask_cross``. This base class only stores it.
    """

    def __init__(self, vert_cross_w=-1, horiz_cross_w=-1):

        self.vert_cross_w = vert_cross_w
        self.horiz_cross_w = horiz_cross_w

    @staticmethod
    def _centered_band(size, width):
        """Slice of exactly ``width`` indices around ``size // 2``, clipped to [0, size)."""
        start = size // 2 - width // 2
        # Clamp at 0 so an over-wide band never turns into a negative
        # (wrap-around) slice start.
        return slice(max(start, 0), max(start + width, 0))

    def make_mask_cross(self, imgs_h, imgs_w, wy=-1, wx=-1, preview=False):
        """Build a boolean cross-shaped mask centred on the zero frequency.

        The mask is laid out for an fftshift-ed spectrum, whose zero frequency
        sits at (imgs_h // 2, imgs_w // 2). It is True on a horizontal band of
        ``wy`` rows and on a vertical band of ``wx`` columns through that
        centre (bands are clipped to the image).

        Args:
            imgs_h (int): Height of images.
            imgs_w (int): Width of images.
            wy (int): Height in rows of the horizontal band. Any negative
                value (default -1) selects ``imgs_h // 6``.
            wx (int): Width in columns of the vertical band. Any negative
                value (default -1) selects ``imgs_w // 6``.
            preview (bool): If True, show the mask with ``plt.imshow``.

        Returns:
            np.ndarray: (imgs_h, imgs_w) boolean mask.
        """
        if wy < 0:
            wy = imgs_h // 6
        if wx < 0:
            wx = imgs_w // 6

        img_mask = np.zeros((imgs_h, imgs_w), dtype=bool)
        img_mask[self._centered_band(imgs_h, wy), :] = True
        img_mask[:, self._centered_band(imgs_w, wx)] = True

        if preview:
            plt.imshow(img_mask)

        return img_mask

    def filter_fourier(self, imgs, img_mask):
        """Zero the masked Fourier components of every frame.

        Args:
            imgs: (T, Y, X) image stack.
            img_mask: (Y, X) boolean mask in fftshift-ed coordinates; True
                entries of the spectrum are set to zero.

        Returns:
            np.ndarray: (T, Y, X) uint16 stack holding the magnitude of each
            inverse-transformed frame, clipped to [0, 65535].
        """
        uint16_max = np.iinfo(np.uint16).max
        imgs_filt = np.zeros(imgs.shape, dtype="uint16")

        for i in range(imgs.shape[0]):
            img_fft = fftshift(fft2(imgs[i]))
            img_fft[img_mask] = 0.0
            magnitude = np.abs(ifft2(ifftshift(img_fft)))
            # Removing frequencies can make the result ring above the uint16
            # range; clip instead of letting the float->uint16 cast wrap.
            imgs_filt[i] = np.clip(magnitude, 0, uint16_max)
        return imgs_filt

    def drift_correction_find_shifts(self, imgs, img_ref):
        """Find the integer shift that registers each frame onto ``img_ref``.

        Every frame is compared against the same reference image (not
        against the previous frame).

        Args:
            imgs: (T, Y, X) image stack.
            img_ref: (Y, X) reference image.

        Returns:
            shifts (np.ndarray): (T, 2) int array; row i is the
            (row_shift (y-shift), column_shift (x-shift)) that
            ``phase_cross_correlation`` reports for registering ``imgs[i]``
            with ``img_ref``, in the form ``drift_correct_images`` expects.
        """
        shifts = np.zeros((imgs.shape[0], 2), dtype=int)

        for i in range(imgs.shape[0]):
            shift, _, _ = phase_cross_correlation(img_ref, imgs[i])
            # Round explicitly instead of relying on float->int truncation.
            shifts[i] = np.round(shift).astype(int)

        return shifts

    def drift_correct_images(self, imgs, shifts):
        """Apply per-frame integer shifts to an image stack.

        Each frame is padded with the stack median, rolled by its shift and
        cropped back to its original size. Pixels exposed by a shift are
        therefore filled with the median instead of wrapped-around content.

        Parameters:
            imgs: Image stack to correct (T, Y, X); anything ``np.asarray``
                accepts (e.g. an h5py dataset, which is then read once).
            shifts (array): (T, 2) shifts ordered as row_shift (y-shift),
                column_shift (x-shift), e.g. from ``drift_correction_find_shifts``.

        Returns:
            imgs_reg (np.ndarray): (T, Y, X) uint16 drift-corrected stack.
        """
        imgs = np.asarray(imgs)
        shifts = np.asarray(shifts)
        if imgs.shape[0] == 0:
            # Nothing to correct; np.max over an empty axis would raise below.
            return np.zeros(imgs.shape, dtype="uint16")

        # Pad by the largest absolute shift on each axis so rolling can never
        # wrap real image data into the kept region.
        max_shifts = np.max(np.abs(shifts), axis=0).astype(int)
        pad_y, pad_x = max_shifts[0], max_shifts[1]

        # TODO: the fill value is the median of the whole stack; a per-frame
        # median may be more appropriate.
        imgs_median = np.median(imgs)
        imgs_padded = np.pad(
            imgs,
            ((0, 0), (pad_y, pad_y), (pad_x, pad_x)),
            constant_values=imgs_median,
        )

        imgs_reg = np.zeros(imgs_padded.shape, dtype="uint16")
        for i in range(imgs.shape[0]):
            imgs_reg[i] = np.roll(imgs_padded[i], shift=shifts[i], axis=(0, 1))

        # Crop the padding back off.
        return imgs_reg[
            :,
            pad_y : imgs_reg.shape[1] - pad_y,
            pad_x : imgs_reg.shape[2] - pad_x,
        ]


class fft_drift_correction_cluster(fft_drift_correction):
    """Run ``fft_drift_correction`` over every hdf5 file of an experiment with dask.

    Reads ``<headpath>/metadata.hdf5`` and ``<headpath>/hdf5/hdf5_<idx>.hdf5``.
    It writes corrected files with the same names to ``<headpath>/<outputfolder>``.

    Args:
        headpath (str): Experiment root directory.
        outputfolder (str): Output directory name, relative to ``headpath``.
        driftchannel (str): Dataset (channel) used to estimate drift; the
            resulting shifts are applied to every channel in the file.
        vert_cross_w (int): Passed as ``wy`` to ``make_mask_cross``.
        horiz_cross_w (int): Passed as ``wx`` to ``make_mask_cross``.
    """

    def __init__(
        self, headpath, outputfolder, driftchannel, vert_cross_w=-1, horiz_cross_w=-1
    ):
        super().__init__(vert_cross_w=vert_cross_w, horiz_cross_w=horiz_cross_w)

        self.headpath = headpath
        self.metapath = headpath + "/metadata.hdf5"
        self.hdf5path = headpath + "/hdf5"
        self.outputpath = headpath + "/" + outputfolder
        self.driftchannel = driftchannel

        self.meta_handle = pandas_hdf5_handler(self.metapath)
        self.metadata = self.meta_handle.read_df("global", read_metadata=True).metadata

        # One (1, height, width) frame; 2 bytes per pixel since output is uint16.
        self.chunk_shape = (1, self.metadata["height"], self.metadata["width"])
        chunk_bytes = 2 * int(np.prod(self.chunk_shape))
        # Room for two such frames in the output file's chunk cache.
        self.chunk_cache_mem_size = 2 * chunk_bytes

    @staticmethod
    def _hdf5_file_path(folder, file_idx):
        """Path of ``hdf5_<file_idx>.hdf5`` inside ``folder``."""
        return folder + "/hdf5_" + str(file_idx) + ".hdf5"

    def drift_correct_file(self, file_idx, ref_file_idx, ref_img_idx):
        """Drift-correct every channel of one hdf5 file against a reference frame.

        Args:
            file_idx: Index of the input file ``hdf5/hdf5_<file_idx>.hdf5``.
            ref_file_idx: Index of the file holding the reference frame.
            ref_img_idx: Position of the reference frame inside that file's
                ``driftchannel`` dataset.

        Returns:
            file_idx, so the finished dask future identifies its file.
        """
        ref_path = self._hdf5_file_path(self.hdf5path, ref_file_idx)
        with h5py.File(ref_path, "r") as ref_file:
            ref_img = ref_file[self.driftchannel][ref_img_idx : ref_img_idx + 1]  # 1,y,x

        input_path = self._hdf5_file_path(self.hdf5path, file_idx)
        with h5py.File(input_path, "r") as input_file:
            imgs = input_file[self.driftchannel][:]  # t,y,x

            img_mask = self.make_mask_cross(
                imgs_h=imgs.shape[1],
                imgs_w=imgs.shape[2],
                wy=self.vert_cross_w,
                wx=self.horiz_cross_w,
            )
            # Shifts are measured on Fourier-filtered copies but applied to the
            # raw data of every channel below.
            ref_img_filt = self.filter_fourier(ref_img, img_mask)[0]
            imgs_filt = self.filter_fourier(imgs, img_mask)
            shifts = self.drift_correction_find_shifts(imgs_filt, ref_img_filt)

            with h5py_cache.File(
                self._hdf5_file_path(self.outputpath, file_idx),
                "w",
                chunk_cache_mem_size=self.chunk_cache_mem_size,
            ) as output_file:
                for channel in input_file.keys():
                    # Read each channel into memory once before correcting it.
                    imgs_reg = self.drift_correct_images(input_file[channel][:], shifts)
                    output_file.create_dataset(channel, data=imgs_reg, dtype="uint16")

        return file_idx

    @staticmethod
    def _first_frame_of_fov(fov_df):
        """(File Index, Image Index) of the reference frame for one FOV.

        The reference file is the lowest ``File Index`` of the FOV. The image
        is the lowest ``Image Index`` among rows of *that* file, so it always
        names a frame the reference file contains.
        """
        file_indices = fov_df["File Index"].to_numpy()
        image_indices = fov_df["Image Index"].to_numpy()
        ref_file_idx = file_indices.min()
        ref_img_idx = image_indices[file_indices == ref_file_idx].min()
        return ref_file_idx, ref_img_idx

    def dask_segment(self, dask_controller):
        """Submit one ``drift_correct_file`` task per hdf5 file and wait for them.

        Each file is registered against the reference frame of its FOV (see
        ``_first_frame_of_fov``). Futures are stored in
        ``dask_controller.futures`` under ``"Drift Correction: <file index>"``.

        Returns:
            list: Results (file indices) of the tasks that succeeded. Failed
            tasks are skipped (``errors="skip"``), so compare against the
            metadata's file list to detect failures.
        """
        writedir(self.outputpath, overwrite=True)
        dask_controller.futures = {}

        global_df = self.meta_handle.read_df("global")
        file_list = global_df["File Index"].unique().tolist()

        file_to_fov_dict = (
            global_df.groupby("File Index")
            .apply(lambda x: x.index.get_level_values("fov").unique()[0])
            .to_dict()
        )
        fov_to_reference = {
            fov: self._first_frame_of_fov(fov_df)
            for fov, fov_df in global_df.groupby("fov")
        }

        # Random priorities so dask does not process files strictly in list order.
        random_priorities = np.random.uniform(size=(len(file_list),))
        for priority, file_idx in zip(random_priorities, file_list):
            ref_file_idx, ref_img_idx = fov_to_reference[file_to_fov_dict[file_idx]]
            future = dask_controller.daskclient.submit(
                self.drift_correct_file,
                file_idx,
                ref_file_idx,
                ref_img_idx,
                retries=0,
                priority=priority,
            )
            dask_controller.futures["Drift Correction: " + str(file_idx)] = future

        return dask_controller.daskclient.gather(
            [
                dask_controller.futures["Drift Correction: " + str(file_idx)]
                for file_idx in file_list
            ],
            errors="skip",
        )

In [None]:
# Drift-correct every file using the "RFP-Penta" channel; results go to <headpath>/luistest.
fft_drift_clust = fft_drift_correction_cluster(
    headpath,
    outputfolder="luistest",
    driftchannel="RFP-Penta",
)

In [None]:
# Cluster settings for the (non-local) dask cluster that dask_segment submits to.
cluster_kwargs = dict(
    walltime="04:00:00",
    local=False,
    n_workers=100,
    memory="2GB",
    working_directory=headpath + "/dask",
)
dask_controller = tr.trcluster.dask_controller(**cluster_kwargs)
dask_controller.startdask()

In [None]:
# Display the dask client (presumably a dask.distributed Client) to check the cluster came up.
# NOTE(review): no visible cell calls fft_drift_clust.dask_segment(dask_controller), so as
# written this notebook starts and stops the cluster without running the drift correction.
dask_controller.daskclient

In [None]:
# Tear down the dask cluster started with startdask().
dask_controller.shutdown()