In [1]:
import sys
import os

# Append the directory of the module to the Python path
sys.path.append("/Users/sherryyang/Projects/super-segger-toolkit/superseggertoolkit")

In [3]:
%matplotlib widget

In [2]:
from cell import Cell
from link_composer import LinkComposer
import cells_extractor 
import visualizer
import link_algorithm
from cell_event import CellEvent, CellDefine, Cell

In [3]:
import networkx as nx
from queue import Queue

In [4]:
import glob as glob

files = glob.glob("/Users/sherryyang/Documents/wiggins-lab/kevins-data/teresa_high_frame_rate/*.tif")

In [5]:
index = 0

In [6]:
mask_tif = "/Users/sherryyang/Documents/wiggins-lab/kevins-data/Archive/0_masks.tif"
composer = LinkComposer.read_tif(mask_tif=mask_tif)
composer.show_mask_error

<bound method LinkComposer.show_mask_error of <link_composer.LinkComposer object at 0x1686cbe50>>

In [7]:
files[index]

'/Users/sherryyang/Documents/wiggins-lab/kevins-data/teresa_high_frame_rate/14_edited_labels.tif'

In [8]:
def cut_graph(G):
    for u, v in G.edges():
        G[u][v]['reliability'] = sys.maxsize

    for node in G.nodes():
        edges_with_weights = G.edges(node, data='weight')
        sorted_edges = sorted(edges_with_weights, key=lambda x: x[2])
        for i, (u, v, weight) in enumerate(sorted_edges):
            G[u][v]['reliability'] = min(i, G[u][v]['reliability'])

    edges_with_rel=  G.edges(data='reliability')
    sorted_edges = sorted(edges_with_rel, key=lambda x: x[2], reverse=True)
    index = 0

    while nx.is_connected(G):
        u,v, realiability_target = sorted_edges[index]
        realiability = realiability_target
        while realiability == realiability_target:
            index += 1
            G.remove_edge(u, v)
            u,v, realiability_target = sorted_edges[index]

    return G

def link_subtree(G):
    all_edges = set(G.edges(data='weight'))
    test_G = nx.Graph()
    edges = set()
    for u, v, weight in all_edges:
        if G.degree(u) == 1 or G.degree(v) == 1:
            test_G.add_edge(u,v, weight = 0)
        else:
            edges.add((u,v, weight))

    edges = list(edges)
    test_G.add_nodes_from(G.nodes())

    min_cost = sys.maxsize
    min_tree = nx.Graph()

    min_cost, min_tree = link_subtree_helper(0, edges, test_G, min_cost, min_tree)
    return min_cost, min_tree



def link_subtree_helper(index, edges, test_G, min_cost, min_tree):
    if index == len(edges):
        cost = sum(weight for _, _, weight in test_G.edges(data='weight'))
        cost += len(list(nx.isolates(test_G))) * 10
        if cost < min_cost:
            min_cost = cost
            min_tree = test_G.copy()
        return min_cost, min_tree

    u, v, weight = edges[index]

    min_cost, min_tree = link_subtree_helper(index + 1, edges, test_G, min_cost, min_tree)
    if (test_G.degree(u) <= 1) or (test_G.degree(v) <= 1):
        test_G.add_edge(u, v, weight=weight)
        min_cost, min_tree = link_subtree_helper(index + 1, edges, test_G, min_cost, min_tree)
        test_G.remove_edge(u, v)

    return min_cost, min_tree

In [9]:
tree_set = set()

for frame in range(1,composer.frame_num):

    dict = composer.cells_frame_dict
    sorted_frame = sorted(dict)

    source_cells = list(dict[sorted_frame[frame-1]])
    target_cells =  list(dict[sorted_frame[frame]])

    source_cells.sort()
    target_cells.sort()

    source_dim = len(source_cells)
    target_dim = len(target_cells)

    sub_G = nx.Graph()
    for i in range(source_dim):
        for j in range(target_dim):
            source_cell = source_cells[i]
            target_cell = target_cells[j]

            intersection_area = source_cell.polygon.intersection(target_cell.polygon).area
            if intersection_area > 0:
                union_area = source_cell.polygon.union(target_cell.polygon).area
                weight = -1 * intersection_area /  union_area  + 1
                sub_G.add_edge(source_cell, target_cell, weight=weight)

    
    queue = Queue()
    queue.put(sub_G)

    while not queue.empty():
        sub_G = queue.get()
        components = nx.connected_components(sub_G)
        subgraphs = []
        for nodes in components:
            subgraph = sub_G.subgraph(nodes).copy()
            if len(subgraph.edges()) < 20:
                min_cost, min_tree = link_subtree(subgraph)
                tree_set.add(min_tree)
            else:
                subgraph = cut_graph(subgraph)
                queue.put(subgraph)


G = composer.make_new_dircted_graph()
for graph in tree_set:
    for u, v in graph.edges():
        if u < v:
            G.add_edge(u,v)
        else:
            G.add_edge(v,u)

In [10]:
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from mpl_interactions import zoom_factory
from IPython.display import display
import matplotlib

plt.close('all')
matplotlib.use('TkAgg')  # Make sure this backend is compatible with your environment

frame = 0  # Initialize frame
max_frame = composer.frame_num - 1

# Assuming visualizer, composer, and G are defined
# Create the initial plot
plt.ioff()
fig, ax = plt.subplots(figsize=(10, 10))
fig.tight_layout()


label_info_1 = visualizer.get_label_info(G)
label_info_2 = visualizer.get_generation_label_info(G)
labels = [label_info_1, label_info_2]
li_index = 0
label_info = labels[li_index]



label_styles = ["regular", "circled", "empty"]
ls_index = 0
label_style = label_styles[ls_index]


image = composer.get_single_frame_phase(frame)
ax = visualizer.subplot_single_frame_phase(ax=ax, G=G, image=image, cells_frame_dict=composer.cells_frame_dict, label_style  = label_style, frame=frame, info=label_info, fontsize=7, figsize=(15,15), representative_point=True)
disconnect_zoom = zoom_factory(ax)

# Update plot function
def update_plot(frame, ax, fig):
    global label_styles, labels, li_index, ls_index

    label_info = labels[li_index]
    label_style = label_styles[ls_index]

    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    
    ax.clear()
    image = composer.get_single_frame_phase(frame)
    ax =  visualizer.subplot_single_frame_phase(ax=ax, G=G, image=image, cells_frame_dict=composer.cells_frame_dict, label_style  = label_style, frame=frame, info=label_info, fontsize=7, figsize=(15,15), representative_point=True)
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    disconnect_zoom = zoom_factory(ax)
    fig.suptitle(f"Frame: {frame}, use keyboard: ⬅️ ➡️ to change frame, use 'c' to circle label, use 1 to change cell tag ", color = "blue")
    fig.canvas.draw_idle()
    return ax

def on_key(event):
    global frame, ax, max_frame, fig, li_index , ls_index
    if event.key in ['right', 'down']:
        frame = min(frame + 1, max_frame) 
        update_plot(frame, ax, fig)
    elif event.key in ['left', 'up']:
        frame  = max(frame - 1, 0)
        update_plot(frame, ax, fig)
    elif event.key == 'c':
        ls_index += 1
        ls_index =  ls_index % len(label_styles)
        update_plot(frame, ax, fig)
    elif event.key == '1':
        li_index += 1
        li_index =  li_index % len(labels)
        update_plot(frame, ax, fig)
    
        

fig.suptitle(f"Frame: {frame}, use keyboard: ⬅️ ➡️, use 'l' to change label, use c to change color/tag", color = "blue")
fig.canvas.mpl_connect('key_press_event', lambda event: on_key(event))

# Show the plot in a separate window
plt.show()


In [11]:
for u, v in G.edges():
    if not u.label ==  v.label:
        print(u.label, v.label)

14 16
14 15
4 8
4 9
1 2
1 3
24 27
24 28
35 41
35 42
34 36
34 35
3 4
3 5
38 43
38 44
13 17
13 18
37 45
37 46
23 32
23 31
33 37
33 38
20 21
20 22
12 14
12 13
7 10
7 11
21 25
21 26
36 40
36 39
2 7
2 6
19 24
19 23
22 29
22 30


In [12]:
visualizer.quick_lineage(G)