In [1]:
import sys
from os import path
# Make the current working directory importable. The later
# `from _settings ...` / `from _utils ...` imports presumably resolve from it.
# The membership guard keeps the cell idempotent: re-running it no longer
# appends the same directory to sys.path again.
_notebook_dir = path.abspath(".")
if _notebook_dir not in sys.path:
    sys.path.append(_notebook_dir)
print(sys.path)

['/Users/fukai/projects/napari-travali/src/napari_travali', '/Users/fukai/kpzutils/lib', '/Users/fukai/exputils/lib', '/Users/fukai/projects/napari-travali/src/napari_travali', '/Users/fukai/.pyenv/versions/miniforge3-4.14.0-2/envs/image_analysis2/lib/python310.zip', '/Users/fukai/.pyenv/versions/miniforge3-4.14.0-2/envs/image_analysis2/lib/python3.10', '/Users/fukai/.pyenv/versions/miniforge3-4.14.0-2/envs/image_analysis2/lib/python3.10/lib-dynload', '', '/Users/fukai/.local/lib/python3.10/site-packages', '/Users/fukai/.pyenv/versions/miniforge3-4.14.0-2/envs/image_analysis2/lib/python3.10/site-packages', '/Users/fukai/.pyenv/versions/miniforge3-4.14.0-2/envs/image_analysis2/lib/python3.10/site-packages/PyQt5_sip-12.11.0-py3.10-macosx-10.9-x86_64.egg', '/Users/fukai/projects/image_analysis/napari-travali/src', '/Users/fukai/projects/napari-travali/src/napari_travali']


In [104]:
import io
import logging
import os
from os import path

import click
import dask.array as da
import napari
import numpy as np
import pandas as pd
from transitions import Machine
import zarr
from tqdm import tqdm

from _settings._transitions import transitions
from _utils._logging import log_error
from _utils._logging import logger
from _settings._consts import DF_DIVISIONS_COLUMNS
from _settings._consts import DF_TRACKS_COLUMNS
from _settings._consts import LOGGING_PATH
from _utils._logging import logger


In [127]:
from enum import Enum
from typing import Dict
from typing import List

import dask.array as da
import networkx as nx
import numpy as np
import pandas as pd
from napari.layers._multiscale_data import MultiScaleData
import zarr
from _settings._consts import DF_DIVISIONS_COLUMNS
from _settings._consts import DF_TRACKS_COLUMNS
from _settings._consts import NEW_LABEL_VALUE
from _settings._consts import NOSEL_VALUE
from _utils._gui_utils import ask_draw_label
from _utils._gui_utils import ask_ok_or_not
from _utils._gui_utils import choose_direction_by_mbox
from _utils._gui_utils import choose_division_by_mbox
from _utils._gui_utils import get_annotation_of_track_end
from _utils._logging import log_error
from _utils._logging import logger
from skimage.util import map_array


class ViewerState(Enum):
    """States the viewer can be in.

    Every member is a key of ``VIEWER_STATE_VISIBILITY`` (which layers are
    shown) and of ``ViewerModel.viewer_state_active`` (which layer is
    selected) for that state.
    """

    ALL_LABEL = 1
    LABEL_SELECTED = 2
    LABEL_REDRAW = 3
    LABEL_SWITCH = 4
    DAUGHTER_SWITCH = 5
    DAUGHTER_DRAW = 6
    DAUGHTER_CHOOSE_MODE = 7


# Visibility flags per viewer state. ``ViewerModel.update_layer_status``
# applies the i-th flag to the i-th entry of ``ViewerModel.layers``, so the
# order of each list is:
# [label_layer, sel_label_layer, redraw_label_layer, finalized_label_layer]
VIEWER_STATE_VISIBILITY = {
    ViewerState.ALL_LABEL: [True, False, False, True],
    ViewerState.LABEL_SELECTED: [True, True, False, False],
    ViewerState.LABEL_REDRAW: [False, False, True, False],
    ViewerState.LABEL_SWITCH: [True, False, False, True],
    ViewerState.DAUGHTER_SWITCH: [True, False, False, True],
    ViewerState.DAUGHTER_DRAW: [False, False, True, False],
    ViewerState.DAUGHTER_CHOOSE_MODE: [False, True, False, False],
}


class ViewerModel:
    """The model responsible for updating the viewer state"""

    def __init__(
        self,
        viewer,
        target_Ts,
        label_layer,
        redraw_label_layer,
        sel_label_layer,
        finalized_label_layer,
        df_tracks: pd.DataFrame,
        df_divisions: pd.DataFrame,
        *,
        new_track_id: int,
        new_label_value: int,
        finalized_track_ids: List[int],
        candidate_track_ids: List[int],
        termination_annotations: Dict[int, str],
    ):
        """Initialize the model

        Parameters
        ----------
        viewer
            the viewer object; this class uses its ``layers.selection`` and
            ``dims.current_step`` (presumably a ``napari.Viewer``)
        target_Ts
            iterable of frame indices that may be edited; stored as a list
        label_layer
            layer holding the label image that is validated and edited
        redraw_label_layer
            scratch layer the user paints on when redrawing a label
        sel_label_layer
            layer showing the currently selected track and its daughters
        finalized_label_layer
            layer whose data is replaced here by a lazily mapped image
            (1 = finalized track, 2 = candidate track)
        df_tracks : pd.DataFrame
            a dataframe containing the segment information; it is indexed by
            the levels ``frame`` and ``label`` and has a ``track_id`` column
        df_divisions : pd.DataFrame
            a dataframe containing the division information
        new_track_id : int
            the id of the new segment
        new_label_value : int
            the value of the new label
        finalized_track_ids : List[int]
            the segment ids that have been finalized.
            NOTE(review): ``finalize_track`` calls ``.add`` on this object,
            so a set-like container is required despite the annotation.
        candidate_track_ids : List[int]
            the segment ids that are candidates for annotation.
            NOTE(review): ``finalize_track`` calls ``.add`` / ``.discard`` on
            this object, so a set-like container is required.
        termination_annotations : Dict[int,str]
            the dict for association between track_id and termination annotation
        """
        # State of the current selection; filled in by ``select_track``.
        self.selected_label = None
        self.track_id = None
        self.frame_childs = None
        self.label_childs = None
        self.segment_labels = None
        self.label_edited = None

        self.target_Ts = list(target_Ts)
        self.viewer = viewer
        self.label_layer = label_layer
        self.redraw_label_layer = redraw_label_layer
        self.sel_label_layer = sel_label_layer
        self.finalized_label_layer = finalized_label_layer

        self.termination_annotation = ""
        self.shape = self.label_layer.data.shape
        self.sizeT = self.label_layer.data.shape[0]

        self.df_tracks = df_tracks
        self.df_divisions = df_divisions
        self.new_track_id = new_track_id
        self.new_label_value = new_label_value
        self.finalized_track_ids = finalized_track_ids
        self.candidate_track_ids = candidate_track_ids
        self.termination_annotations = termination_annotations
        
        # One lazily mapped array per element of ``label_layer.data``.
        # NOTE(review): this iterates ``label_layer.data`` as a sequence of
        # arrays, while ``finalize_track`` and ``save_results`` call array
        # methods on ``label_layer.data`` directly - confirm which is intended.
        self.finalized_label_layer.data = MultiScaleData([d.map_blocks(
                self.__label_to_finalized_label, 
                dtype=np.uint8) for d in self.label_layer.data])

        # Layer that gets selected in the viewer for each state.
        self.viewer_state_active = {
            ViewerState.ALL_LABEL: self.label_layer,
            ViewerState.LABEL_SELECTED: self.sel_label_layer,
            ViewerState.LABEL_REDRAW: self.redraw_label_layer,
            ViewerState.LABEL_SWITCH: self.label_layer,
            ViewerState.DAUGHTER_SWITCH: self.label_layer,
            ViewerState.DAUGHTER_DRAW: self.redraw_label_layer,
            ViewerState.DAUGHTER_CHOOSE_MODE: self.sel_label_layer,
        }
        # Order must match the flag lists in VIEWER_STATE_VISIBILITY.
        self.layers = [
            self.label_layer,
            self.sel_label_layer,
            self.redraw_label_layer,
            self.finalized_label_layer,
        ]

    @log_error
    def update_layer_status(self, *_):
        """Apply the visibility flags and the active layer of ``self.state``.

        ``self.state`` is not set in this class; it is presumably provided by
        the state machine attached to the model. Extra positional arguments
        are ignored.
        """
        visibility_flags = VIEWER_STATE_VISIBILITY[self.state]
        assert len(visibility_flags) == len(self.layers)
        for layer, is_visible in zip(self.layers, visibility_flags):
            try:
                layer.visible = is_visible
            except ValueError:
                # best effort: a layer refusing the change is skipped
                pass

        selection = self.viewer.layers.selection
        selection.clear()
        selection.add(self.viewer_state_active[self.state])

    @log_error
    def refresh_redraw_label_layer(self):
        """Blank the redraw layer and put it into paint mode."""
        layer = self.redraw_label_layer
        layer.data = np.zeros_like(layer.data)
        layer.mode = "paint"

    ############# functions to map label #############
    @log_error
    def __label_to_selected_label_image(self, block, block_info=None):
        """Map one block of the label image to the selection image.

        Pixels equal to the selected label of the block's frame become 1.
        The j-th daughter (0-based) becomes ``j + 2`` on its own frame; a
        daughter is either a scalar label value or a drawn mask.

        Returns ``None`` when no ``block_info`` is supplied.
        """
        assert self.segment_labels is not None
        assert self.frame_childs is not None
        assert self.label_childs is not None
        if block_info is None or len(block_info) == 0:
            return None
        location = block_info[0]["array-location"]
        current_frame = location[0][0]
        selected = (block == self.segment_labels[current_frame]).astype(np.uint8)
        for value, (child_frame, child_label) in enumerate(
            zip(self.frame_childs, self.label_childs), start=2
        ):
            if current_frame != child_frame:
                continue
            if np.isscalar(child_label):
                selected[block == child_label] = value
            else:
                # drawn daughter: cut the part of the mask covered by this
                # block, skipping the first two axes of the block location
                window = tuple(slice(start, stop) for start, stop in location)
                selected[0, 0][child_label[window[2:]]] = value
        return selected

    @log_error
    def __label_to_finalized_label(self, block, block_info=None):
        """Map one block of the label image to the finalized-status image.

        Labels whose track is in ``self.finalized_track_ids`` map to 1 and
        labels whose track is in ``self.candidate_track_ids`` map to 2.
        A frame without any row in ``df_tracks`` yields an all-zero block.

        Returns ``None`` when no ``block_info`` is supplied.
        """
        if block_info is None or len(block_info) == 0:
            return None
        location = block_info[0]["array-location"]
        frame = location[0][0]
        try:
            segments_at_frame = self.df_tracks.loc[frame]
        except KeyError:
            return np.zeros_like(block, dtype=np.uint8)

        def labels_of(track_ids):
            """Labels at this frame whose track id is in ``track_ids``."""
            owned = segments_at_frame[segments_at_frame["track_id"].isin(track_ids)]
            return owned.index.get_level_values("label").to_list()

        finalized_labels_at_frame = labels_of(self.finalized_track_ids)
        candidate_labels_at_frame = labels_of(self.candidate_track_ids)

        # map_array reads ``output_vals.dtype`` and calls ``.astype`` on both
        # value arguments, so they must be ndarrays rather than lists. The
        # explicit dtypes also keep empty inputs integer-typed and make the
        # result match the uint8 dtype declared at the map_blocks call sites.
        input_vals = np.asarray(
            finalized_labels_at_frame + candidate_labels_at_frame, dtype=np.int64
        )
        output_vals = np.asarray(
            [1] * len(finalized_labels_at_frame)
            + [2] * len(candidate_labels_at_frame),
            dtype=np.uint8,
        )

        return map_array(block, input_vals, output_vals)

    ############# transition functions #############
    @log_error
    def select_track(self, frame, val, track_id):
        """Select ``track_id`` and rebuild the selection layer for it.

        Parameters
        ----------
        frame, val
            not used by this method
        track_id
            id of the track to select

        ``self.segment_labels`` holds, per frame, the label of the track or
        ``NOSEL_VALUE`` where the track is absent. The daughters are read from
        ``df_divisions``. If more than one division row matches the track, a
        warning is logged and the selection layer is left untouched.
        """
        self.track_id = track_id
        segment_labels = np.ones(self.sizeT, dtype=np.uint32) * NOSEL_VALUE
        track_rows = self.df_tracks[self.df_tracks["track_id"] == track_id]
        frames = track_rows.index.get_level_values("frame").values
        labels = track_rows.index.get_level_values("label").values
        segment_labels[frames] = labels

        self.label_edited = np.zeros(len(segment_labels), dtype=bool)
        self.segment_labels = segment_labels
        # used to rewrite track on exit
        self.original_segment_labels = segment_labels.copy()

        division_rows = self.df_divisions[
            self.df_divisions["parent_track_id"] == track_id
        ]
        logger.info("segment id: %s", track_id)
        logger.info("segment labels: %s", segment_labels)
        logger.info("division rows: %s", division_rows)
        if len(division_rows) > 1:
            # Previously a silent return; the ambiguity is now made visible.
            logger.warning(
                "track %s has %i division rows; selection layer not updated",
                track_id,
                len(division_rows),
            )
            return
        if len(division_rows) == 1:
            division = division_rows.iloc[0]
            self.frame_childs = list(division[["frame_child1", "frame_child2"]])
            self.label_childs = list(division[["label_child1", "label_child2"]])
        else:
            self.frame_childs = []
            self.label_childs = []
        logger.info(
            "daughter frames: %s, daughter labels: %s",
            self.frame_childs,
            self.label_childs,
        )
        self.sel_label_layer.data = [
            d.map_blocks(self.__label_to_selected_label_image, dtype=np.uint8)
            for d in self.label_layer.data
        ]

    @log_error
    def label_redraw_enter_valid(self):
        """Return whether redrawing may start at the current frame.

        True when the current frame is in ``target_Ts`` and the selected
        track has a label at the current, the previous or the next target
        frame (the neighbours are clamped at both ends of ``target_Ts``).
        """
        current = self.viewer.dims.current_step[0]
        if current not in self.target_Ts:
            logger.info("this frame is not in target_Ts")
            return False
        position = self.target_Ts.index(current)
        previous_frame = self.target_Ts[max(0, position - 1)]
        next_frame = self.target_Ts[min(len(self.target_Ts) - 1, position + 1)]
        if all(
            self.segment_labels[t] == NOSEL_VALUE
            for t in (current, previous_frame, next_frame)
        ):
            logger.info("track does not exist in connected timeframe")
            return False
        logger.info("redraw valid")
        return True

    @log_error
    def check_drawn_label(self):
        """Return whether any pixel of the redraw layer has the value 1."""
        drawn_mask = self.redraw_label_layer.data == 1
        return np.any(drawn_mask)

    @log_error
    def label_redraw_finish(self):
        """Copy the drawn shape into the selection layer at the current frame.

        The frame is marked as edited. Its selected label becomes
        ``NEW_LABEL_VALUE`` when the track had no label there, or when the
        user answers "new" to ``ask_draw_label``; otherwise the existing
        label value is kept.
        """
        # The original logged this message twice; once is enough.
        logger.info("label redraw finish")
        iT = self.viewer.dims.current_step[0]
        self.sel_label_layer.data[iT] = 0
        self.sel_label_layer.data[iT] = self.redraw_label_layer.data == 1
        self.label_edited[iT] = True
        # ``or`` short-circuits, so the dialog is only shown when the track
        # already had a label at this frame.
        if (
            self.segment_labels[iT] == NOSEL_VALUE
            or ask_draw_label(self.viewer) == "new"
        ):
            self.segment_labels[iT] = NEW_LABEL_VALUE

    @log_error
    def switch_track_enter_valid(self):
        """Return whether switching the track may start at the current frame.

        True when the current frame is in ``target_Ts`` and the selected
        track has a label at the current, the previous or the next target
        frame (the neighbours are clamped at both ends of ``target_Ts``).
        """
        current = self.viewer.dims.current_step[0]
        if current not in self.target_Ts:
            logger.info("this frame is not in target_Ts")
            return False
        position = self.target_Ts.index(current)
        previous_frame = self.target_Ts[max(0, position - 1)]
        next_frame = self.target_Ts[min(len(self.target_Ts) - 1, position + 1)]
        if all(
            self.segment_labels[t] == NOSEL_VALUE
            for t in (current, previous_frame, next_frame)
        ):
            logger.info("track does not exist in connected timeframe")
            return False
        logger.info("switch valid")
        return True

    @log_error
    def switch_track(self, frame, val, track_id):
        """Replace one side of the selected track by the labels of ``track_id``.

        The user is asked for a direction. "forward" replaces the selection
        from ``frame`` onwards (and adopts the divisions of ``track_id``),
        "backward" replaces it up to and including ``frame``. A falsy answer
        leaves everything unchanged.

        Parameters
        ----------
        frame
            frame at which the switch happens
        val
            not used by this method
        track_id
            id of the track whose labels are adopted
        """
        direction = choose_direction_by_mbox(self.viewer)

        if not direction:
            return
        elif direction == "forward":
            logger.info("forward ... ")
            df = self.df_tracks[
                (self.df_tracks["track_id"] == track_id)
                & (self.df_tracks.index.get_level_values("frame") >= frame)
            ]
            frames = df.index.get_level_values("frame").values
            labels = df.index.get_level_values("label").values

            self.segment_labels[frame:] = NOSEL_VALUE
            self.segment_labels[frames] = labels
            self.label_edited[frame:] = False
            # FIXME revert layer to original
            row = self.df_divisions[self.df_divisions["parent_track_id"] == track_id]

            if len(row) == 1:
                # plain lists, the same representation ``select_track`` uses
                self.frame_childs = list(
                    row.iloc[0][["frame_child1", "frame_child2"]]
                )
                self.label_childs = list(
                    row.iloc[0][["label_child1", "label_child2"]]
                )
            elif len(row) == 0:
                self.frame_childs = []
                self.label_childs = []
            self.termination_annotation = ""

        elif direction == "backward":
            df = self.df_tracks[
                (self.df_tracks["track_id"] == track_id)
                & (self.df_tracks.index.get_level_values("frame") <= frame)
            ]
            frames = df.index.get_level_values("frame").values
            labels = df.index.get_level_values("label").values
            # The adopted rows include ``frame`` itself (``<= frame``), so the
            # reset must cover it too. The original ``[:frame]`` left
            # ``label_edited[frame]`` set, unlike the forward branch.
            self.segment_labels[: frame + 1] = NOSEL_VALUE
            self.segment_labels[frames] = labels
            self.label_edited[: frame + 1] = False

    @log_error
    def daughter_choose_mode_enter_valid(self):
        """Return whether a division may be marked at the current frame.

        Requires the current frame to be in ``target_Ts`` and the selected
        track to have a label at the current or the previous target frame.
        On success the daughter candidates are reset for the current frame.
        """
        logger.info("enter daughter choose")
        current = self.viewer.dims.current_step[0]
        if current not in self.target_Ts:
            logger.info("this frame is not in target_Ts")
            return False
        previous_frame = self.target_Ts[max(0, self.target_Ts.index(current) - 1)]
        if all(
            self.segment_labels[t] == NOSEL_VALUE for t in (current, previous_frame)
        ):
            logger.info("track does not exist in connected timeframe")
            return False
        logger.info("mark division...")
        self.frame_child_candidate = current
        self.label_child_candidates = []
        return True

    @log_error
    def on_enter_DAUGHTER_CHOOSE_MODE(self, *_):
        """Finish the division or ask how the next daughter is specified.

        With two candidates collected the division is finalized and the model
        returns to LABEL_SELECTED. Otherwise the user chooses between
        selecting an existing label, drawing one, or leaving the mode.
        """
        n_candidates = len(self.label_child_candidates)
        logger.info("candidates count: %i", n_candidates)
        if n_candidates == 2:
            self.finalize_daughter()
            self.to_LABEL_SELECTED()
            return

        method = choose_division_by_mbox(self.viewer)
        logger.info("%s selected", method)
        if method == "select":
            self.to_DAUGHTER_SWITCH()
            return
        if method == "draw":
            self.refresh_redraw_label_layer()
            self.to_DAUGHTER_DRAW()
            return
        self.to_LABEL_SELECTED()

    @log_error
    def daughter_select(self, frame, val, track_id):
        """Add label ``val`` as a daughter candidate.

        The label is only accepted on the frame stored in
        ``self.frame_child_candidate``; ``track_id`` is not used.
        """
        if frame != self.frame_child_candidate:
            logger.info("frame not correct")
            return
        self.label_child_candidates.append(int(val))

    @log_error
    def daughter_draw_finish(self):
        """Add the drawn shape (redraw layer == 1) as a daughter candidate."""
        drawn_mask = self.redraw_label_layer.data == 1
        self.label_child_candidates.append(drawn_mask)

    @log_error
    def finalize_daughter(self):
        """Turn the two daughter candidates into the daughters of the track.

        Both daughters share the candidate frame, and the selected track is
        cleared from that frame onwards.
        """
        candidates = self.label_child_candidates
        assert len(candidates) == 2
        division_frame = self.frame_child_candidate
        self.frame_childs = [division_frame] * len(candidates)
        self.label_childs = list(candidates)
        self.segment_labels[division_frame:] = NOSEL_VALUE

    @log_error
    def mark_termination_enter_valid(self):
        """Return whether the selected track is drawn at the current frame."""
        current = self.viewer.dims.current_step[0]
        if np.any(self.sel_label_layer.data[current] == 1):
            logger.info("marking termination valid")
            return True
        logger.info("track does not exist in connected timeframe")
        return False

    @log_error
    def mark_termination(self):
        """Ask for a termination annotation and end the track at this frame.

        The dialog is pre-filled with the annotation already stored for the
        track, if any. When the user confirms, the annotation is kept in
        ``self.termination_annotation`` and the selected labels after the
        current frame are cleared. Cancelling changes nothing.
        """
        iT = self.viewer.dims.current_step[0]
        termination_annotation, res = get_annotation_of_track_end(
            self.viewer, self.termination_annotations.get(self.track_id, "")
        )
        if not res:
            logger.info("marking termination cancelled")
            return
        # The original message lacked the f-prefix, so it logged the literal
        # text "{termination_annotation}" and never the value.
        logger.info("marking termination: %s", termination_annotation)
        self.termination_annotation = termination_annotation
        if iT < self.segment_labels.shape[0] - 1:
            self.segment_labels[iT + 1 :] = NOSEL_VALUE

    @log_error
    def finalize_track(self):
        """Write the edited selection back to the tables and the label image.

        Steps, in order:

        1. Build a graph whose nodes are ``(frame, label)`` pairs and whose
           edges join consecutive rows of every track that owns one of the
           selected labels or one of the selected (scalar) daughter labels.
        2. Remove the selected labels from the graph and assign them to
           ``self.track_id`` in ``df_tracks``.
        3. Cut the edge between each scalar daughter label and its
           predecessor in the graph.
        4. Give every remaining connected component a fresh track id. When a
           component contains the last frame of its original track, the
           division row of that track is re-pointed to the new id.
        5. Write redrawn frames into the label image and store their
           bounding boxes.
        6. Replace the division row of ``self.track_id`` by the current
           daughters; drawn daughters are written into the label image with
           a new label value and a new track id.
        7. Mark the track as finalized, store its termination annotation and
           rebuild the finalized layer.
        """
        track_id = self.track_id
        segment_labels = self.segment_labels

        frame_childs = self.frame_childs.copy()
        label_childs = self.label_childs.copy()

        segment_graph = nx.Graph()
        frame_labels = list(enumerate(segment_labels)) + list(
            zip(frame_childs, label_childs)
        )
        # Tracks owning any selected label or scalar daughter label.
        relevant_track_ids = np.unique(
            [
                self.df_tracks.loc[(frame, label), "track_id"]
                for frame, label in frame_labels
                if np.isscalar(label)
                and label != NOSEL_VALUE
                and label != NEW_LABEL_VALUE
            ]
        )

        last_frames = {}
        for relevant_track_id in relevant_track_ids:
            df = self.df_tracks[self.df_tracks["track_id"] == relevant_track_id]
            if len(df) == 0:
                continue
            df = df.sort_index(level="frame")
            last_frames[relevant_track_id] = df.index.get_level_values("frame")[-1]
            if len(df) == 1:
                frame, label = df.index[0]
                segment_graph.add_node((frame, label))
            else:
                for ((frame1, label1), _), ((frame2, label2), _) in zip(
                    df.iloc[:-1].iterrows(), df.iloc[1:].iterrows()
                ):
                    segment_graph.add_edge((frame1, label1), (frame2, label2))

        for frame, label in enumerate(segment_labels):
            if label in (NOSEL_VALUE, NEW_LABEL_VALUE):
                continue
            segment_graph.remove_node((frame, label))
            self.df_tracks.loc[(frame, label), "track_id"] = track_id

        for frame, label in zip(frame_childs, label_childs):
            if not np.isscalar(label):
                continue
            neighbors = segment_graph.neighbors((frame, label))
            ancestors = [n for n in neighbors if n[0] < frame]
            if len(ancestors) == 0:
                continue
            else:
                assert len(ancestors) == 1
                ancestor = ancestors[0]
                segment_graph.remove_edge((frame, label), ancestor)

        # relabel divided tracks
        for subsegment in nx.connected_components(segment_graph):
            component_nodes = sorted(subsegment, key=lambda x: x[0])
            owner_ids = self.df_tracks.loc[component_nodes, "track_id"]
            assert np.all(owner_ids.iloc[0] == owner_ids)
            original_track_id = owner_ids.iloc[0]
            last_frame = last_frames[original_track_id]
            frames, _ = zip(*component_nodes)

            self.df_tracks.loc[component_nodes, "track_id"] = self.new_track_id
            # ``frames`` is a tuple: compare it as an array explicitly rather
            # than relying on numpy's reflected ``==`` with a scalar.
            if np.any(np.asarray(frames) == last_frame):
                ind = self.df_divisions["parent_track_id"] == original_track_id
                if np.any(ind):
                    assert np.sum(ind) == 1
                    self.df_divisions.loc[ind, "parent_track_id"] = self.new_track_id
            self.new_track_id += 1

        def _to_numpy(arr):
            """Compute ``arr`` when it is a dask array; return it otherwise."""
            return arr.compute() if isinstance(arr, da.Array) else arr

        def _draw_label(label_image, frame, label):
            """Write ``label`` where the 2D mask ``label_image`` is true.

            Only the bounding box of the mask is read from and written to
            the label layer. Returns ``[(y0, y1), (x0, x1)]`` with exclusive
            upper bounds.
            """
            # XXX tentative implementation, faster if directly edit the zarr?
            inds = [_to_numpy(i) for i in np.where(label_image)]
            bboxes = [(np.min(ind), np.max(ind) + 1) for ind in inds]
            subimg = np.array(
                self.label_layer.data[frame, 0, 0, slice(*bboxes[0]), slice(*bboxes[1])]
            )
            subimg[tuple((ind - bbox[0]) for ind, bbox in zip(inds, bboxes))] = label
            self.label_layer.data[
                frame, 0, 0, slice(*bboxes[0]), slice(*bboxes[1])
            ] = subimg
            return bboxes

        def _append_track_row(frame, label, **columns):
            """Append one ``(frame, label)`` row to ``self.df_tracks``.

            Replaces ``DataFrame.append``, which was removed in pandas 2.0.
            The row is given the index level names of ``df_tracks`` so they
            survive the concatenation; columns not passed are left missing.
            """
            new_row = pd.DataFrame(
                {name: [value] for name, value in columns.items()},
                index=pd.MultiIndex.from_tuples(
                    [(frame, label)], names=self.df_tracks.index.names
                ),
            )
            self.df_tracks = pd.concat([self.df_tracks, new_row])

        for redrawn_frame in np.where(self.label_edited)[0]:
            label = self.segment_labels[redrawn_frame]
            if label not in [NOSEL_VALUE, NEW_LABEL_VALUE]:
                # existing label kept: erase its old shape first
                _draw_label(
                    self.label_layer.data[redrawn_frame, 0, 0] == label,
                    redrawn_frame,
                    0,
                )
            else:
                label = self.new_label_value
                _append_track_row(redrawn_frame, label, track_id=track_id)
                self.new_label_value += 1

            bboxes = _draw_label(
                self.sel_label_layer.data[redrawn_frame, 0, 0] == 1,
                redrawn_frame,
                label,
            )
            # set bounding box
            self.df_tracks.loc[(redrawn_frame, label), "bbox_y0"] = bboxes[0][0]
            self.df_tracks.loc[(redrawn_frame, label), "bbox_y1"] = bboxes[0][1]
            self.df_tracks.loc[(redrawn_frame, label), "bbox_x0"] = bboxes[1][0]
            self.df_tracks.loc[(redrawn_frame, label), "bbox_x1"] = bboxes[1][1]

        # Drop the stored division of this track; it is re-added below when
        # daughters are currently selected.
        ind = self.df_divisions["parent_track_id"] == track_id
        if np.any(ind):
            assert np.sum(ind) == 1
            self.df_divisions = self.df_divisions[~ind]
            self.new_track_id += 1

        if len(frame_childs) > 0 and len(label_childs) > 0:
            assert len(frame_childs) == 2 and len(label_childs) == 2
            division_row = {"parent_track_id": track_id}
            for j, (frame_child, label_child) in enumerate(
                zip(frame_childs, label_childs)
            ):
                division_row[f"frame_child{j+1}"] = frame_child
                if np.isscalar(label_child):
                    # means the daughter was selected
                    division_row[f"label_child{j+1}"] = label_child
                    track_id_child = self.df_tracks.loc[
                        (frame_child, label_child), "track_id"
                    ]
                else:
                    # means the daughter was drawn: give it a new label
                    # value and a new track id
                    bboxes = _draw_label(
                        label_child[0], frame_child, self.new_label_value
                    )
                    division_row[f"label_child{j+1}"] = self.new_label_value
                    _append_track_row(
                        frame_child,
                        self.new_label_value,
                        track_id=self.new_track_id,
                        bbox_y0=bboxes[0][0],
                        bbox_y1=bboxes[0][1],
                        bbox_x0=bboxes[1][0],
                        bbox_x1=bboxes[1][1],
                    )
                    track_id_child = self.new_track_id
                    self.new_track_id += 1
                    self.new_label_value += 1
                if track_id_child not in self.finalized_track_ids:
                    logger.info(f"candidate adding ... {track_id_child}")
                    self.candidate_track_ids.add(track_id_child)
            self.df_divisions = pd.concat(
                [self.df_divisions, pd.DataFrame([division_row])],
                ignore_index=True,
            )

        self.finalized_track_ids.add(track_id)
        self.candidate_track_ids.discard(track_id)
        self.termination_annotations[track_id] = self.termination_annotation

        # NOTE(review): ``__init__`` builds this layer from one mapped array
        # per element of ``label_layer.data``; here ``map_blocks`` is called on
        # ``label_layer.data`` itself - confirm which form the data has.
        self.finalized_label_layer.data = self.label_layer.data.map_blocks(
            self.__label_to_finalized_label, dtype=np.uint8
        )

    @log_error
    def save_results(self, zarr_path, label_dataset_name, chunks, persist):
        """Save the label image, the tracks and the divisions to a zarr file.

        Parameters
        ----------
        zarr_path
            path of the zarr file, opened in append mode
        label_dataset_name
            name of the datasets to write; ".travali" is appended when the
            name does not already end with it
        chunks
            chunk sizes; entry 1 is dropped because axis 1 of the label data
            is dropped before saving (presumably one entry per axis of the
            label layer - TODO confirm)
        persist
            when true, the re-opened label data is persisted with dask

        If a label dataset of that name already exists the user is asked
        whether to overwrite it; declining returns without saving anything.
        """
        logger.info("saving validation results...")

        if not label_dataset_name.endswith(".travali"):
            label_dataset_name += ".travali"
        zarr_file = zarr.open(zarr_path, "a")
        if label_dataset_name in zarr_file["labels"].keys():
            if_overwrite = ask_ok_or_not(
                self.viewer, "Validation file already exists. Overwrite?"
            )
            if not if_overwrite:
                logger.warning("label not saved")
                return
        logger.info("saving label ...")

        # to avoid IO from/to the same array, save to a temp array and then rename
        label_group = zarr_file["labels"]
        label_chunks = [chunks[0], *chunks[2:]]
        # index 0 on axis 1 drops that axis before writing
        label_data = self.label_layer.data[:, 0, :, :, :].rechunk(label_chunks)
        ds = label_group.create_dataset(
            f"{label_dataset_name}_tmp",
            shape=label_data.shape,
            dtype=label_data.dtype,
            chunks=label_chunks,
            overwrite=True,
        )
        label_data.to_zarr(ds, overwrite=True)
        # replace the previous dataset by the freshly written temp dataset
        if label_dataset_name in label_group.keys():
            del label_group[label_dataset_name]
        label_group.store.rename(ds.name, f"{label_group.name}/{label_dataset_name}")
        label_group[label_dataset_name].attrs["target_Ts"] = list(
            map(int, self.target_Ts)
        )

        logger.info("saving segments...")

        segments_group = zarr_file["df_tracks"]
        if label_dataset_name in segments_group.keys():
            del segments_group[label_dataset_name]
        # the table is stored as a plain integer array in DF_TRACKS_COLUMNS order
        segments_ds = segments_group.create_dataset(
            label_dataset_name,
            data=self.df_tracks.reset_index()[DF_TRACKS_COLUMNS].astype(int).values,
        )
        # the validation state travels with the table as dataset attributes
        segments_ds.attrs["finalized_track_ids"] = list(
            map(int, self.finalized_track_ids)
        )
        segments_ds.attrs["candidate_track_ids"] = list(
            map(int, self.candidate_track_ids)
        )
        segments_ds.attrs["termination_annotations"] = {
            int(k): str(v) for k, v in self.termination_annotations.items()
        }

        logger.info("saving divisions...")

        divisions_group = zarr_file["df_divisions"]
        if label_dataset_name in divisions_group.keys():
            del divisions_group[label_dataset_name]
        divisions_group.create_dataset(
            label_dataset_name,
            data=self.df_divisions.reset_index()[DF_DIVISIONS_COLUMNS]
            .astype(int)
            .values,
        )
        logger.info("reading data ...")
        # point the layer at the stored dataset, with axis 1 inserted again
        self.label_layer.data = da.from_zarr(label_group[label_dataset_name])[
            :, np.newaxis, :, :, :
        ]
        if persist:
            self.label_layer.data = self.label_layer.data.persist()
        logger.info("saving validation results finished")


In [62]:
# NOTE(review): this rebinds the LOGGING_PATH imported from _settings._consts
# in the notebook namespace only - confirm whether that has any effect.
LOGGING_PATH = ".travali/log.txt"
# Data directory. Set the TRAVALI_BASEDIR environment variable to run the
# notebook on another machine; the default is the original hard-coded path.
basedir = os.environ.get("TRAVALI_BASEDIR", "/Users/fukai/Downloads/")
zarr_path = path.join(basedir, "aligned_image_small2.zarr")
label_dataset_name = "original"
# Opened in append mode: the file is created if missing and stays writable.
zarr_file = zarr.open(zarr_path, "a")

In [39]:
# One lazy array per entry of the "image" group.
images = [da.from_zarr(level) for level in zarr_file["image"].values()]
label_group = zarr_file["labels"]
label_ds = label_group[label_dataset_name]
# One lazy array per entry of the label group, with an axis inserted at position 1.
labelss = [
    da.from_zarr(level)[:, np.newaxis, :, :, :] for level in label_ds.values()
]
data_chunks = images[0].chunks

# Tables are stored as plain arrays; the column names come from the constants.
tracks_ds = zarr_file["df_tracks"][label_dataset_name]
df_tracks_original = pd.DataFrame(
    tracks_ds,
    columns=DF_TRACKS_COLUMNS,
).set_index(["frame", "label"])
df_divisions = pd.DataFrame(
    zarr_file["df_divisions"][label_dataset_name],
    columns=DF_DIVISIONS_COLUMNS,
)

# Validation state saved as dataset attributes; empty when never saved.
finalized_track_ids = (
    set(tracks_ds.attrs["finalized_track_ids"])
    if "finalized_track_ids" in tracks_ds.attrs
    else set()
)
candidate_track_ids = (
    set(tracks_ds.attrs["candidate_track_ids"])
    if "candidate_track_ids" in tracks_ds.attrs
    else set()
)
termination_annotations = {
    int(track): str(annotation)
    for track, annotation in tracks_ds.attrs.get(
        "termination_annotations", {}
    ).items()
}

# Frames that may be edited, as sorted plain ints within the frame range.
target_Ts = sorted(int(t) for t in label_ds.attrs["target_Ts"])
assert all(t < labelss[0].shape[0] for t in target_Ts)

# First unused label value / track id.
new_label_value = df_tracks_original.index.get_level_values("label").max() + 1
new_track_id = df_tracks_original["track_id"].max() + 1

#### only extract information in target_Ts ####
def show_only_target_Ts(block, block_info=None):
    """Keep blocks that start at a frame in ``target_Ts``; zero the others.

    Intended for ``map_blocks``: the frame is the start of the block along
    axis 0, read from ``block_info[0]["array-location"]``. Returns ``None``
    when no ``block_info`` is supplied.
    """
    if block_info is None or len(block_info) == 0:
        return None
    first_frame = block_info[0]["array-location"][0][0]
    return block if first_frame in target_Ts else np.zeros_like(block)
# Same arrays as ``labelss`` with every non-target frame blanked lazily.
label2 = [
    level.map_blocks(show_only_target_Ts, dtype=level.dtype) for level in labelss
]

In [157]:
# Open the viewer and add the images and the target-frame-only labels;
# both are passed as lists of arrays.
viewer = napari.Viewer()
viewer.add_image(images)
viewer.add_labels(label2, name="labels")

<Labels layer 'labels' at 0x7fde185def50>

In [158]:
def _get_move_in_target_Ts(forward: bool):
    """Build a key callback that jumps to the neighbouring target frame.

    With ``forward`` true the callback moves to the smallest target frame
    above the current position on axis 0, otherwise to the largest one
    below it. Nothing happens when there is no such frame.
    """

    def fn(_event) -> None:
        # XXX dirty implementation but works
        frames = np.array(target_Ts)
        logger.info(f"moving {forward}")
        current = viewer.dims.point[0]
        if forward:
            candidates = frames[frames > current]
            pick = np.min
        else:
            candidates = frames[frames < current]
            pick = np.max
        if len(candidates) > 0:
            viewer.dims.set_point(0, pick(candidates))

    return fn

# Shift-Right / Shift-Left jump to the next / previous frame in target_Ts,
# replacing any existing binding of these keys.
viewer.bind_key("Shift-Right", _get_move_in_target_Ts(True), overwrite=True)
viewer.bind_key("Shift-Left", _get_move_in_target_Ts(False), overwrite=True)


# Load pickup sample positions

In [159]:
# Read the pickup sample table from the data directory.
pick_up_samples_df2 = pd.read_csv(path.join(basedir,"pickup_sample_positions2.csv"))
# All samples must come from one frame, which becomes the reference frame.
reference_frames = pick_up_samples_df2["frame"].unique()
assert len(reference_frames) == 1
reference_frame = reference_frames[0]

In [160]:
# Load the reference frame of the first label array into memory.
labels = labelss[0][reference_frame].compute()

In [161]:
# NOTE(review): this import sits mid-notebook; consider moving it to the top import cell.
from skimage.measure import regionprops_table
# Label value and centroid of every region in the [0, 0] plane of the reference-frame labels
# (presumably the first channel / z-slice — TODO confirm the axis order).
regionprops_df = pd.DataFrame(regionprops_table(labels[0,0], properties=["label", "centroid"]))

In [162]:
# Inspect the available columns; note that some names carry a trailing space (e.g. 'tube ID ').
pick_up_samples_df2.columns

Index(['index_x', 'position ID', 'tube ID ', 'eject confirmed ',
       'mosaic Y [px]', 'mosaic X [px]', 'mask_value', 'frame', 'index_y',
       'label', 'centroid-0', 'centroid-1', 'area', 'intensity_mean-0',
       'intensity_mean-1', 'intensity_mean-2', 'intensity_mean-3',
       'intensity_max-0', 'intensity_max-1', 'intensity_max-2',
       'intensity_max-3', 'intensity_min-0', 'intensity_min-1',
       'intensity_min-2', 'intensity_min-3', 'frame_y', 'tree_id', 'track_id'],
      dtype='object')

In [163]:
# Attach the centroids computed for the reference frame to each picked-up sample, matching
# rows on the label value (default inner join: samples without a matching region drop out).
sample_ids = pick_up_samples_df2[["position ID", "tube ID ", "label"]]
positions_df = sample_ids.merge(regionprops_df, on="label")

In [164]:
# Point coordinates have five columns: the reference frame, two zero-filled columns and the
# (centroid-0, centroid-1) position of every picked-up sample.
# NOTE(review): the two zero columns presumably address the first channel / z-slice of a
# 5-D image stack — confirm against the dimensionality of `images`.
data = np.hstack([np.ones((len(positions_df), 1)) * reference_frame,
                  np.zeros((len(positions_df), 2)),
                  positions_df[["centroid-0", "centroid-1"]].values])
# Each point is annotated with its tube ID; the created Points layer is the cell output.
viewer.add_points(data, size=20, face_color="red", name="positions", text=list(positions_df["tube ID "].values))


<Points layer 'positions' at 0x7fe1706c54e0>

In [182]:
# Move the viewer's first dimension to the reference frame of the pickup samples.
viewer.dims.set_point(0, reference_frame)

# Load tracks

In [101]:
# Keep only the rows whose "frame" index level is one of the target frames; copy so that
# later edits of df_tracks never write back into df_tracks_original.
is_target_frame = df_tracks_original.index.get_level_values("frame").isin(target_Ts)
df_tracks = df_tracks_original.loc[is_target_frame].copy()
df_tracks

Unnamed: 0_level_0,Unnamed: 1_level_0,track_id,bbox_y0,bbox_y1,bbox_x0,bbox_x1
frame,label,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
0,199,198,31,60,6635,6651
0,200,199,61,89,9487,9513
0,201,200,75,106,5081,5113
0,202,201,83,110,9506,9531
0,203,202,136,192,3722,3794
...,...,...,...,...,...,...
1963,12081,345951,13820,13849,308,339
1963,12088,354010,13895,13937,1103,1147
1963,12136,354024,13918,13945,11,35
1963,12141,354026,13993,14034,911,956


In [102]:
# For every track, pick its row with the smallest frame among the target frames
# (the "head" of the track) and index the result by track_id.
df_tracks2 = df_tracks.reset_index()
track_heads = df_tracks2.loc[
    df_tracks2.groupby("track_id")["frame"].idxmin()
].set_index("track_id")
track_heads

Unnamed: 0_level_0,frame,label,bbox_y0,bbox_y1,bbox_x0,bbox_x1
track_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
194,96,235,0,5,3498,3517
196,24,210,0,2,9115,9128
198,0,199,31,60,6635,6651
199,0,200,61,89,9487,9513
200,0,201,75,106,5081,5113
...,...,...,...,...,...,...
354004,1963,12076,13903,13926,1639,1660
354010,1963,12088,13895,13937,1103,1147
354024,1963,12136,13918,13945,11,35
354026,1963,12141,13993,14034,911,956


In [48]:
# Exploratory lookup: show all rows of df_tracks_original whose first index level equals 207.
# NOTE(review): hard-coded inspection with an out-of-order execution count — consider removing.
df_tracks_original.loc[(207)]

Unnamed: 0_level_0,track_id,bbox_y0,bbox_y1,bbox_x0,bbox_x1
label,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
294,9938,0,2,5113,5127
297,1637,28,59,6629,6647
298,9177,55,87,9095,9125
299,6218,59,87,5031,5061
300,5733,72,97,9191,9221
...,...,...,...,...,...
1861,1225,13904,13936,2502,2536
1862,7549,13906,13932,802,836
1863,1226,13931,13961,2477,2505
1864,12954,13970,13993,2409,2433


In [106]:
logger.info("organizing dataframes")

# Re-anchor both children of every division onto the head (first target frame) of the track
# they belong to. A child whose (frame, label) is unknown in df_tracks_original, or whose
# track has no entry in track_heads, gets None in both columns and is dropped at the end.
ng_indices = []
for i in [1, 2]:
    frame_col = f"frame_child{i}"
    label_col = f"label_child{i}"
    for j, row in tqdm(df_divisions.iterrows()):
        frame, label = row[frame_col], row[label_col]
        try:
            track_id = df_tracks_original.loc[(frame, label)]["track_id"]
            df_matched = track_heads.loc[track_id]
            head_frame, head_label = df_matched["frame"], df_matched["label"]
        except KeyError:
            # Either lookup failed: mark this child as missing.
            head_frame, head_label = None, None
        df_divisions.loc[j, frame_col] = head_frame
        df_divisions.loc[j, label_col] = head_label

    # Nullable integer dtype keeps the columns integral despite the missing entries.
    for col in (frame_col, label_col):
        df_divisions[col] = df_divisions[col].astype(pd.Int64Dtype())
# NOTE(review): dropna() removes rows with a missing value in *any* column, not only the
# child columns — confirm this is intended.
df_divisions = df_divisions.dropna()

91018it [00:25, 3525.86it/s]
91018it [00:25, 3577.07it/s]


In [107]:
# Sanity checks: tracks and both division children must only reference target frames.
assert df_tracks.index.get_level_values("frame").isin(target_Ts).all()
for i in (1, 2):
    assert df_divisions[f"frame_child{i}"].isin(target_Ts).all()

# Running viewer

In [124]:
# Fresh napari viewer for the annotation session; rebinds the name `viewer` used above.
viewer = napari.Viewer()

In [125]:

# contrast_limits = np.percentile(np.array(images[0][0]).ravel(), (2, 98))

viewer.add_image(images)

# Labels restricted to the target frames.
label_layer = viewer.add_labels(label2, name="label", cache=False)

# Overlay for the currently selected label: one empty uint8 array per entry of labelss.
sel_label_layer = viewer.add_labels(
    [da.zeros_like(level, dtype=np.uint8) for level in labelss],
    name="Selected label",
    cache=False,
)
sel_label_layer.contour = 3

# Drawing canvas shaped like the last three axes of labelss[0].
redraw_label_layer = viewer.add_labels(
    np.zeros(labelss[0].shape[-3:], dtype=np.uint8),
    name="Drawing",
    cache=False,
)

# Overlay for finalized tracks, built like the selection overlay.
finalized_label_layer = viewer.add_labels(
    [da.zeros_like(level, dtype=np.uint8) for level in labelss],
    name="Finalized",
    # color ={1:"red"}, not working
    opacity=1.0,
    #    blending="opaque",
    cache=False,
)
finalized_label_layer.contour = 3

In [128]:
# Bundle the viewer, its four label layers and the tracking tables into the model object
# that the state machine below drives.
# NOTE(review): the recorded output of this cell is "AttributeError: 'list' object has no
# attribute 'shapes'", raised from napari's draw callback (layer._update_draw). Presumably a
# layer's data ends up as a plain list somewhere — verify inside ViewerModel.
viewer_model = ViewerModel(
    viewer,
    target_Ts,
    label_layer,
    redraw_label_layer,
    sel_label_layer,
    finalized_label_layer,
    df_tracks,
    df_divisions,
    new_track_id=new_track_id,
    new_label_value=new_label_value,
    finalized_track_ids=finalized_track_ids,
    candidate_track_ids=candidate_track_ids,
    termination_annotations=termination_annotations,
)
# State machine over ViewerState using the imported `transitions` table; it starts in
# ALL_LABEL and names "update_layer_status" as the after-state-change callback.
machine = Machine(
    model=viewer_model,
    states=ViewerState,
    transitions=transitions,
    after_state_change="update_layer_status",
    initial=ViewerState.ALL_LABEL,
    ignore_invalid_triggers=True,  # ignore invalid key presses
)
# Call the layer-status update once by hand for the initial state.
viewer_model.update_layer_status()

AttributeError: 'list' object has no attribute 'shapes'

Traceback (most recent call last):
  File "/Users/fukai/.pyenv/versions/miniforge3-4.14.0-2/envs/image_analysis2/lib/python3.10/site-packages/vispy/app/backends/_qt.py", line 903, in paintGL
    self._vispy_canvas.events.draw(region=None)
  File "/Users/fukai/.pyenv/versions/miniforge3-4.14.0-2/envs/image_analysis2/lib/python3.10/site-packages/vispy/util/event.py", line 453, in __call__
    self._invoke_callback(cb, event)
  File "/Users/fukai/.pyenv/versions/miniforge3-4.14.0-2/envs/image_analysis2/lib/python3.10/site-packages/vispy/util/event.py", line 471, in _invoke_callback
    _handle_exception(self.ignore_callback_errors,
  File "/Users/fukai/.pyenv/versions/miniforge3-4.14.0-2/envs/image_analysis2/lib/python3.10/site-packages/vispy/util/event.py", line 469, in _invoke_callback
    cb(event)
  File "/Users/fukai/.pyenv/versions/miniforge3-4.14.0-2/envs/image_analysis2/lib/python3.10/site-packages/napari/_qt/qt_viewer.py", line 1102, in on_draw
    layer._update_draw(
  File "/Us