In [None]:
import livecell_tracker.sample_data
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from napari.layers import Shapes


In [None]:
# Both datasets sit side by side under one root directory.
dataset_root = Path("../datasets/test_data_STAV-A549")
dic_dataset_path = dataset_root / "DIC_data"
mask_dataset_path = dataset_root / "mask_data"
dic_dataset, mask_dataset = livecell_tracker.sample_data.tutorial_three_image_sys(
    dic_dataset_path,
    mask_dataset_path,
)


In [None]:
from livecell_tracker.core.io_sc import prep_scs_from_mask_dataset
# Build single-cell objects from the mask dataset, with the DIC dataset passed alongside
# (presumably as the image source for each cell — confirm against io_sc).
# The result is indexed below and its items are used as SingleCellStatic objects.
single_cells = prep_scs_from_mask_dataset(mask_dataset, dic_dataset)

In [None]:
from livecell_tracker.core import SingleCellTrajectory, SingleCellStatic
from livecell_tracker.segment.ou_utils import create_ou_input_from_sc
from livecell_tracker.segment.utils import find_contours_opencv
from livecell_tracker.core.datasets import SingleImageDataset

class ScSegOperator:
    """Show one single cell's contour in napari and apply segmentation corrections to it.

    The operator holds one ``SingleCellStatic`` (``sc``), the napari ``viewer`` its
    polygon is drawn in, and the napari ``Shapes`` layer displaying that polygon.
    ``replace_sc_mask`` / ``replace_sc_contour`` mutate ``sc`` in place and can
    redraw the layer afterwards.
    """

    def __init__(
        self,
        sc: SingleCellStatic,
        viewer,
        shape_layer: Shapes = None,
        face_color=(0, 0, 1, 1),
    ):
        """
        Args:
            sc: cell whose contour is displayed and corrected (mutated in place by
                ``replace_sc_mask`` / ``replace_sc_contour``).
            viewer: napari viewer that ``create_sc_layer`` adds the shapes layer to.
            shape_layer: optional existing Shapes layer for ``sc``; when omitted,
                call ``create_sc_layer`` to create one.
            face_color: polygon fill colour passed to napari (RGBA tuple).
        """
        self.sc = sc
        self.viewer = viewer
        self.shape_layer = shape_layer
        self.face_color = face_color
        if self.shape_layer is not None:
            self.setup_edit_contour_shape_layer()

    def create_sc_layer(self, name=None, contour_sample_num=100):
        """Add a polygon Shapes layer for ``self.sc`` to the viewer and store it on ``self``.

        Args:
            name: layer name; defaults to ``"sc_<sc.id>"``.
            contour_sample_num: number of contour points requested from
                ``sc.get_napari_shape_contour_vec``.
        """
        if name is None:
            name = f"sc_{self.sc.id}"
        shape_vec = self.sc.get_napari_shape_contour_vec(contour_sample_num=contour_sample_num)
        # Attach the cell object to its (single) shape as a layer property.
        properties = {"sc": [self.sc]}
        self.shape_layer = self.viewer.add_shapes(
            [shape_vec],
            properties=properties,
            face_color=[self.face_color],
            shape_type="polygon",
            name=name,
        )
        self.setup_edit_contour_shape_layer()

    def update_shape_layer_by_sc(self, contour_sample_num=100):
        """Redraw the shapes layer from ``self.sc``'s current contour.

        Args:
            contour_sample_num: number of contour points requested from
                ``sc.get_napari_shape_contour_vec``.

        Does nothing when no layer exists yet; ``create_sc_layer`` draws the current
        contour whenever it is called.
        """
        if self.shape_layer is None:
            return
        shape_vec = self.sc.get_napari_shape_contour_vec(contour_sample_num=contour_sample_num)
        self.shape_layer.data = [shape_vec]

    def correct_segment(self, model, create_ou_input_kwargs=None, input_size=(412, 412)):
        """Run ``model`` on the cell's cropped input and return its output at crop size.

        Args:
            model: correction network, called as ``model(x)`` with ``x`` a float32
                tensor of shape ``(1, 3, *input_size)`` on the device of the model's
                parameters (CUDA if the model exposes no parameters).
            create_ou_input_kwargs: keyword arguments forwarded to
                ``create_ou_input_from_sc``. ``None`` uses ``padding_pixels=50,
                dtype=float, remove_bg=False, one_object=True, scale=0``.
            input_size: (height, width) the crop is resized to before inference.

        Returns:
            The model output resized back to the crop's (height, width). Computed
            under ``torch.no_grad()``, so it carries no autograd graph.
        """
        import torch
        from torchvision import transforms

        if create_ou_input_kwargs is None:
            # Built per call so a shared default dict can never be mutated between calls.
            create_ou_input_kwargs = {
                "padding_pixels": 50,
                "dtype": float,
                "remove_bg": False,
                "one_object": True,
                "scale": 0,
            }

        ou_input = create_ou_input_from_sc(self.sc, **create_ou_input_kwargs)
        original_shape = ou_input.shape

        input_transforms = transforms.Resize(size=tuple(input_size))
        back_transforms = transforms.Resize(size=(original_shape[0], original_shape[1]))

        # Run where the model's weights live instead of assuming CUDA.
        try:
            device = next(model.parameters()).device
        except (AttributeError, StopIteration):
            device = torch.device("cuda")

        # (H, W) -> (1, H, W); from_numpy avoids torch.tensor's slow list-of-ndarray path.
        batch = torch.from_numpy(np.ascontiguousarray(ou_input)[np.newaxis])
        batch = input_transforms(batch)
        # Replicate the single channel three times: (1, H', W') -> (1, 3, H', W').
        batch = torch.stack([batch, batch, batch], dim=1)
        batch = batch.float().to(device)

        with torch.no_grad():
            output = model(batch)
        return back_transforms(output)

    def replace_sc_mask(self, mask, padding_pixels=0, refresh=True):
        """Replace the cell's mask and contour with ``mask``, given in padded-crop coordinates.

        Args:
            mask: mask of the cell's crop (e.g. a thresholded ``correct_segment`` output).
            padding_pixels: padding used when that crop was built; it must match, or
                the resulting contour is shifted.
            refresh: redraw the shapes layer afterwards.

        Raises:
            ValueError: if ``mask`` does not yield exactly one contour. ``sc`` is left
                unchanged in that case.
        """
        contours = find_contours_opencv(mask)
        if len(contours) != 1:
            raise ValueError(f"expected exactly one contour in mask, found {len(contours)}")
        self.sc.mask_dataset = SingleImageDataset(mask)
        self.replace_sc_contour(contours[0], padding_pixels=padding_pixels, refresh=refresh)

    def replace_sc_contour(self, contour, padding_pixels=0, refresh=True):
        """Set ``self.sc.contour`` from a contour given in padded-crop coordinates.

        The contour is offset by ``sc.bbox[:2]`` (read before the bbox is updated)
        minus ``padding_pixels``. ``sc.update_bbox()`` is called afterwards.

        Args:
            contour: array-like contour points relative to the crop.
            padding_pixels: padding used when the crop was built.
            refresh: redraw the shapes layer afterwards.
        """
        self.sc.contour = np.asarray(contour) + self.sc.bbox[:2] - padding_pixels
        self.sc.update_bbox()
        if refresh:
            self.update_shape_layer_by_sc()

    def setup_edit_contour_shape_layer(self):
        """Hook for keeping the shapes layer to a single polygon while editing.

        Currently disabled: the method returns immediately and the callback below is
        never connected. TODO: finish and enable it.
        """
        return
        # NOTE(review): assigning self.shape_layer.data inside a data-event handler may
        # re-fire the same event; confirm before enabling.
        from copy import deepcopy
        # Callback to check if shape_layer has more than one shape and revert to the saved one
        self.saved_data = deepcopy(self.shape_layer.data)
        def _shape_data_changed(event):
            print("_shape_data_changed fired")
            print("len of shape_layer.data:", len(self.shape_layer.data))
            if len(self.shape_layer.data) > 1:
                print("[_shape_data_changed] len of saved_data:", len(self.saved_data))
                self.shape_layer.data = deepcopy(self.saved_data)
            elif len(self.shape_layer.data) == 1:
                self.saved_data = deepcopy(self.shape_layer.data)
        # If the shape_layer already exists, connect the callback
        if self.shape_layer is not None:
            self.shape_layer.events.data.connect(_shape_data_changed)




In [None]:
import napari

# Open an empty viewer, then add the DIC frames (as a dask array) as its image layer.
viewer = napari.Viewer()
dic_image_layer = viewer.add_image(dic_dataset.to_dask(), name="dic_image", cache=True)

In [None]:
# Pick one cell and draw its contour (20 sampled points) as its own shapes layer.
SAMPLE_SC_INDEX = 1
sample_sc = single_cells[SAMPLE_SC_INDEX]
sample_sc_seg_operator = ScSegOperator(sc=sample_sc, viewer=viewer)
sample_sc_seg_operator.create_sc_layer(contour_sample_num=20)

In [None]:
import torch
from livecell_tracker.model_zoo.segmentation.sc_correction import CorrectSegNet

ckpt = r"./notebook_results/csn_models/model_v11_epoch=3282-test_loss=2.3688.ckpt"

# Use the GPU when one is available instead of hard-coding .cuda(),
# so the notebook can also load the model on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CorrectSegNet.load_from_checkpoint(ckpt, map_location=device)
model = model.to(device)
# Inference only from here on.
model = model.eval()

In [None]:
# Padding (pixels) used when cropping the model input around the cell; the mask cells
# below must reuse this value to map the prediction back onto the image.
padding_pixels = 40
create_ou_input_kwargs = dict(
    padding_pixels=padding_pixels,
    dtype=float,
    remove_bg=False,
    one_object=True,
    scale=0,
)

output = sample_sc_seg_operator.correct_segment(model, create_ou_input_kwargs=create_ou_input_kwargs)
plt.imshow(output[0, 0].detach().cpu().numpy())

In [None]:
# Threshold the first output channel at 0.5 and write the result back onto the cell.
# NOTE: this mutates sample_sc in place (contour, mask_dataset, bbox). sc.update_bbox()
# presumably changes the bbox, so re-running this cell offsets the contour from the
# already-updated bbox. Re-run from the cell that builds `single_cells` instead.
predicted_map = output[0].cpu().detach().numpy()[0]
mask = predicted_map > 0.5
sample_sc_seg_operator.replace_sc_mask(mask, padding_pixels=padding_pixels)

In [None]:
# Build a corrected copy of the cell and display it as a separate shapes layer.
# The predicted mask is in the coordinates of the padded crop fed to the model, so the
# contour must be offset by the SAME padding used to build that crop. A different
# hard-coded value here would shift the contour.
padding_pixels = create_ou_input_kwargs["padding_pixels"]
new_sc = sample_sc_seg_operator.sc.copy()
new_mask = output[0].cpu().detach().numpy()[0] > 0.5
new_sc.mask_dataset = SingleImageDataset(new_mask)
# NOTE(review): if replace_sc_mask was already applied to sample_sc, new_sc copies the
# corrected cell, whose bbox has presumably been updated. The offset below would then not
# be relative to the original crop origin. TODO confirm intended ordering of these cells.
new_sc.contour = np.array(find_contours_opencv(new_mask)[0]) + new_sc.bbox[:2] - padding_pixels
new_sc_seg_operator = ScSegOperator(new_sc, viewer=viewer)
new_sc_seg_operator.create_sc_layer(contour_sample_num=20)