In [1]:
import numpy as np
from skimage import io
import napari
import pandas as pd
import zarr
import os
import pickle
import matplotlib.pyplot as plt
import tifffile as tf
import glob

### Get Annotation data

In [2]:
fix_n5_path = ['/mnt/aperto/fused/fused.n5','/mnt/aperto/Tatz_brain_data/fused/fused.n5']
directory = './tatz_anno_2d_cplm/'
if not os.path.exists(directory):
    os.makedirs(directory)
    print("Directory created:", directory)
save_path = directory
meta_path = './tatz_anno_2d_cplm/tatz_anno_2d_cplm.pkl'
# create Zarr file object
fix_zarr = [zarr.open(store=zarr.N5Store(fix_n5_path[0]), mode='r'), zarr.open(store=zarr.N5Store(fix_n5_path[1]), mode='r')]
n5_setups = list(fix_zarr[0].keys())
voxel_size = (2.0,1.3,1.3)

  fix_zarr = [zarr.open(store=zarr.N5Store(fix_n5_path[0]), mode='r'), zarr.open(store=zarr.N5Store(fix_n5_path[1]), mode='r')]


In [3]:
# make metadata file if it does not exist.
if not os.path.exists(meta_path):
    df = pd.DataFrame(columns=['ID', 'integer_ID', 'instance_counts', 'corner', 'source', 'ref_channel', 'channel', 'crop_size', 'isHard', 'plane_position','model'])
else:
    df = pd.read_pickle(meta_path)


In [4]:
resolusion_0_shape = fix_zarr[0][n5_setups[1]]['timepoint0']['s0'].shape
resolusion_2_shape = fix_zarr[0][n5_setups[3]]['timepoint0']['s2'].shape
scale_size = []
for i in range(len(resolusion_0_shape)):
    scale_size.append(resolusion_2_shape[i]/resolusion_0_shape[i])

In [5]:
# set your parameters here
reference_chan = 3 # Integer or None
segment_chan = 4

# [100,256,256] crop size and FoV [100,768,768] are recommended for the 2D annotation
crop_size = [100,256,256]
FoV = [100,768, 768]
FoV1 = [1, 2137, 1603]
# set True for 2D annotation and set False for 3D annotation
select_plane = True

# processing of parameters
if not all([(j-i)>=0 for i,j in zip(crop_size, FoV)]):
    raise ValueError('FoV should be larger than crop_size')

#### load ori 234 dataframe

In [6]:
with open('./annotation_position_info.pkl', 'rb') as f:
    ori_234_train = pickle.load(f)

  ori_234_train = pickle.load(f)


In [7]:
sel_77 = pd.read_csv('./selection77.csv')

In [8]:
re_anno = sel_77[sel_77['Re-annotation (y or n)'] == 'y'].index

In [9]:
with open('./pos_77.pkl', 'rb') as f:
    pos_id = pickle.load(f)

In [10]:
id_list = []

In [11]:
for i in range(len(pos_id)):
    if i in re_anno:
        id_list.append(pos_id[i])

In [12]:
### img need 3d view
img_3d_vu = [10,13,14,16,17,33,39,44,68,74,76] 

In [13]:
for item in img_3d_vu:
  id_list.append(pos_id[item])

In [14]:
len(id_list)

55

#### load cellpose, stardist, swin label

In [15]:
with open('./cellpose_training_tatz.pkl', 'rb') as f:
    cellpose_training_all = pickle.load(f)
    
with open('./stardist_training_tatz.pkl', 'rb') as f:
    stardist_training_all = pickle.load(f)

cellpose, stardist = cellpose_training_all, stardist_training_all

In [16]:
with open('./swin_training_tatz.pkl', 'rb') as f:
    swin2d = pickle.load(f)

In [17]:
begin_id = 0

In [22]:
for ori_id in id_list[begin_id:]:
    ori_corner_position = [ori_234_train['corner'][int(ori_id)]]
    ori_plane = int(ori_234_train['plane_position'][int(ori_id)])
    segment_chan = int(ori_234_train['channel'][int(ori_id)])
    reference_chan =  int(ori_234_train['ref_channel'][int(ori_id)])
    #check which source it should be int 0 or 1
    brain_source = int(ori_234_train['source'][int(ori_id)].split('/')[-3].split('_')[-2][0])-1
    
    pos = ori_corner_position[0]
    if len(pos) <= 2:
        raise ValueError('The position should have length 3')
    elif len(pos) == 3:
        isHard = 0
    elif len(pos) == 4:
        isHard = pos[-1]
        pos = pos[:-1]
    else:
        raise ValueError('You have a wrong position format')
        
    print(f"The index {ori_id} with the position {pos}")
    idx = int(ori_id)

    z_index = int((pos[0]+ori_plane)*scale_size[0])
    # prepare to make border lines
    a = [1, pos[1]*scale_size[1], pos[2]*scale_size[2]]
    b = [1, (pos[1]+256)*scale_size[1], (pos[2]+256)*scale_size[2]]
    top_border_corner = tuple(a)
    bottom_border_corner = tuple(b)
   
    #### full brain plane
    zarr_array = fix_zarr[brain_source][n5_setups[segment_chan]]['timepoint0']['s2']
    selected_plane_1 = zarr_array[z_index, :, :]
    zarr_array = fix_zarr[brain_source][n5_setups[reference_chan]]['timepoint0']['s2']
    selected_plane_2 = zarr_array[z_index, :, :]
    #blue_image = np.full_like(zarr_array[0, :, :], fill_value=255)
    
    viewer1 = napari.Viewer()
    
    @viewer1.bind_key('q')
    def close_viewer(viewer):
        print("Closing viewer...")
        viewer.close()

     #set hide hotkey
    @viewer1.bind_key('h')
    def toggle_layer_visibility(viewer):
        layer = viewer.layers.selection.active
        if layer is not None:
            layer.visible = not layer.visible
    
    #data3 = np.zeros(shape, dtype=np.uint8)
    # Initialize the first Napari viewer and add the first Z-plane
   
    viewer1.add_image(selected_plane_1, name=f'Z-plane {ori_plane}', colormap='gray',opacity = 0.5)

    viewer1.add_image(selected_plane_2, name=f'Z-plane {ori_plane}', colormap='gray',opacity = 0.5)
    #viewer1.add_image(blue_image, name='Cover Layer')
    viewer1.add_shapes([[bottom_border_corner[1],bottom_border_corner[2]],[top_border_corner[1],bottom_border_corner[2]]],
                      edge_width=2,edge_color='white',ndim=2,shape_type='line')
    viewer1.add_shapes([[top_border_corner[1],bottom_border_corner[2]],[top_border_corner[1],top_border_corner[2]]],
                  edge_width=2,edge_color='white',ndim=2,shape_type='line')
    viewer1.add_shapes([[bottom_border_corner[1],bottom_border_corner[2]],[bottom_border_corner[1],top_border_corner[2]]],
                  edge_width=2,edge_color='white',ndim=2,shape_type='line')
    viewer1.add_shapes([[bottom_border_corner[1],top_border_corner[2]],[top_border_corner[1],top_border_corner[2]]],
                  edge_width=2,edge_color='white',ndim=2,shape_type='line') 
 
    # find out any duplication between the current data and the metadata
    # if it is duplicated, ask 
    flag = False
    if df['corner'].isin([pos]).any():
        for k in df['integer_ID'][df['corner'].isin([pos])].to_list():
            if ((df.loc[k,'source'] == fix_n5_path[brain_source]) and 
                (df.loc[k,'ref_channel'] == reference_chan) and 
                (df.loc[k,'channel'] == segment_chan) and 
                (df.loc[k,'crop_size'] == crop_size) and
                (df.loc[k,'select_plane'] == select_plane)):
                flag = True
                idx = k
    if flag:
        ans = input("Do you want to re-analyze the data? y or n")
       
        if ans != 'y':
            continue


  
    # set file path to be saved for both image and mask
    prefix = str(idx)
    while len(prefix) < 4:
        prefix = '0' + prefix
    img_path = os.path.join(save_path, prefix+'_img.tif')
    #mask_path = os.path.join(save_path, prefix+'_mask.tif')
    mask_path = [os.path.join(save_path, prefix+'_nw_mask.tif'),os.path.join(save_path, prefix+'_cp_mask.tif'), 
                 os.path.join(save_path, prefix+'_sd_mask.tif'), os.path.join(save_path, prefix+'_sw_mask.tif')]
    # get the image of a channel to be segmented
    FoV_stack = []
    img = fix_zarr[brain_source][n5_setups[segment_chan]]['timepoint0']['s0']

    # set the corner of FoV in napari
    top_corner = tuple(i-(k-j)//2 for i,j,k in zip(pos, crop_size, FoV))
    bottom_corner = tuple(i+j+(k-j)//2 for i,j,k in zip(pos, crop_size, FoV))
    top_corner = tuple(j if j>=i else i for i,j in zip([0,0,0],top_corner))
    bottom_corner = tuple(j if j<=i else i for i,j in zip(img.shape,bottom_corner))
    
    # prepare to make border lines
    top_border_corner = tuple((k-j)//2 for j,k in zip(crop_size, FoV))
    bottom_border_corner = tuple(j+(k-j)//2 for j,k in zip(crop_size, FoV))
    
    FoV_segment = img[tuple(slice(i,j) for i,j in zip(top_corner, bottom_corner))]
    # get the image of a reference of channel
    if reference_chan is not None:
        img = fix_zarr[brain_source][n5_setups[reference_chan]]['timepoint0']['s0']
        FoV_reference = img[tuple(slice(i,j) for i,j in zip(top_corner, bottom_corner))]
        FoV_stack.append(FoV_reference)
    
    FoV_stack.append(FoV_segment)
    FoV_stack = np.stack(FoV_stack)

    ##### FoV is a 2channel(reference and signal) Field of view

    ## add label layer data
    shape = (100, 768, 768)
    data = np.zeros(shape, dtype=np.uint16)
    
    ### add psesudo labels
    # cellpose
    cp_pseudo = np.zeros(shape, dtype=np.uint16)
    cp_pseudo[ori_plane,256:512,256:512] = cellpose[ori_id]


    # stardist
    sd_pseudo = np.zeros(shape, dtype=np.uint16)
    sd_pseudo[ori_plane,256:512,256:512] = stardist[ori_id]


    # swin

    sw_pseudo =  np.zeros(shape, dtype=np.uint16)
    sw_pseudo[ori_plane,256:512,256:512] = swin2d[ori_id]

    viewer = napari.Viewer()

      ### choose model


    @viewer.bind_key('q')
    def close_viewer(viewer):
        print("Closing viewer...")
        viewer.close()

     #set hide hotkey
    @viewer.bind_key('h')
    def toggle_layer_visibility(viewer):
        layer = viewer.layers.selection.active
        if layer is not None:
            layer.visible = not layer.visible

    viewer.add_image(FoV_stack, channel_axis=0, scale=voxel_size, contrast_limits=[0,65535])
    viewer.add_shapes([[bottom_border_corner[1]*voxel_size[1],bottom_border_corner[2]*voxel_size[2]],[top_border_corner[1]*voxel_size[1],bottom_border_corner[2]*voxel_size[2]]],
                      edge_width=2,edge_color='white',ndim=2,shape_type='line')
    viewer.add_shapes([[top_border_corner[1]*voxel_size[1],bottom_border_corner[2]*voxel_size[2]],[top_border_corner[1]*voxel_size[1],top_border_corner[2]*voxel_size[2]]],
                  edge_width=2,edge_color='white',ndim=2,shape_type='line')
    viewer.add_shapes([[bottom_border_corner[1]*voxel_size[1],bottom_border_corner[2]*voxel_size[2]],[bottom_border_corner[1]*voxel_size[1],top_border_corner[2]*voxel_size[2]]],
                  edge_width=2,edge_color='white',ndim=2,shape_type='line')
    viewer.add_shapes([[bottom_border_corner[1]*voxel_size[1],top_border_corner[2]*voxel_size[2]],[top_border_corner[1]*voxel_size[1],top_border_corner[2]*voxel_size[2]]],
                  edge_width=2,edge_color='white',ndim=2,shape_type='line') 

    re_anno_label = viewer.add_labels(data,  name=f'Label({ori_plane})', scale=voxel_size)
    re_anno_label.opacity = 1.0
    re_anno_label.brush_size = 1



    ### cellpose
    cp_labels = viewer.add_labels(cp_pseudo,  name=f'cellpose({ori_plane})', scale=voxel_size)
    cp_labels.opacity = 1.0
    cp_labels.brush_size = 1


    ### stardist
    sd_labels = viewer.add_labels(sd_pseudo,  name=f'stardist({ori_plane})', scale=voxel_size)
    sd_labels.opacity = 1.0
    sd_labels.brush_size = 1


    ### swin
    sw_labels = viewer.add_labels(sw_pseudo,  name=f'swin2d({ori_plane})', scale=voxel_size)
    sw_labels.opacity = 1.0
    sw_labels.brush_size = 1

    viewer.camera.zoom = 1.5
    viewer.dims.current_step = (ori_plane,400,400)
    viewer.show(block=True)

    sub_area_slicer = tuple(slice(i,j) for i,j in zip(top_border_corner,bottom_border_corner))
    ######
    # subarea shape
    # save images and segmentation.
    img = np.swapaxes(FoV_stack[(slice(0,None),)+sub_area_slicer],0,1)
    #label layer
    re_anno_res = re_anno_label.data[sub_area_slicer]
    ### cellpose
    cp_anno_res = cp_labels.data[sub_area_slicer]
    ### stardist
    sd_anno_res = sd_labels.data[sub_area_slicer]

    ### swin
    sw_anno_res = sw_labels.data[sub_area_slicer]
    
    
    if select_plane:
        print('selected and save')
        ##########
        # modify this to original label
        plane_pos = ori_plane
        img = img[plane_pos,...]
        re_anno_res = re_anno_res[plane_pos,...]
        #cp,sd,sw
        cp_anno_res = cp_anno_res[plane_pos,...]
        sd_anno_res = sd_anno_res[plane_pos,...]
        sw_anno_res = sw_anno_res[plane_pos,...]
        

        io.imsave(img_path, img, plugin='tifffile', imagej=True, metadata={'axes': 'CYX'})
        ### save masks
        io.imsave(mask_path[0], re_anno_res, plugin='tifffile')
        io.imsave(mask_path[1], cp_anno_res, plugin='tifffile')
        io.imsave(mask_path[2], sd_anno_res, plugin='tifffile')
        io.imsave(mask_path[3], sw_anno_res, plugin='tifffile')
        
    else:
        print('#############')
        print('not selected and save')
        io.imsave(img_path, img, plugin='tifffile', imagej=True, metadata={'axes': 'ZCYX'})

    
    model = input("which model you will re-annotate?(cp, sd, sw, n)?")

    # update the metadata
    df.loc[idx,'ID'] = prefix
    df.loc[idx,'integer_ID'] = idx
    count = (np.unique(re_anno_res)).size - 1
    df.loc[idx,'instance_counts'] = count
    df.at[idx,'corner'] = pos
    df.loc[idx, 'source'] = fix_n5_path[brain_source]
    df.loc[idx, 'ref_channel'] = reference_chan
    df.loc[idx, 'channel'] = segment_chan
    df.at[idx, 'crop_size'] = crop_size
    df.loc[idx, 'select_plane'] = select_plane
    df.loc[idx, 'isHard'] = isHard
    df.loc[idx, 'model'] = model
    if select_plane:
        df.loc[idx, 'plane_position'] = int(plane_pos)
    else:
        df.loc[idx, 'plane_position'] = -1


The index 0126 with the position [93, 4958, 4866]


  io.imsave(mask_path[0], re_anno_res, plugin='tifffile')
  io.imsave(mask_path[1], cp_anno_res, plugin='tifffile')
  io.imsave(mask_path[2], sd_anno_res, plugin='tifffile')
  io.imsave(mask_path[3], sw_anno_res, plugin='tifffile')


selected and save


which model you will re-annotate?(cp, sd, sw, n)? sw


The index 0191 with the position [586, 6230, 3173]
Closing viewer...


  self.window.show(block=block)


AttributeError: 'Window' object has no attribute '_qt_window'

Closing viewer...


In [13]:
df.to_pickle(meta_path)

In [14]:
# if you don't want to run all again, set the beginning id as begin
curent_id = id_list.index(ori_id)
begin_id = curent_id
print(f'current image id {curent_id}, next begin id {begin_id}')

current image id 46, next begin id 46


In [20]:
# (optional)if you want to start all over again, reset begin_id as 0
# begin_id = 0

In [96]:
ori_plane = 50

In [97]:
shape = (100, 768, 768)
cp_pseudo = np.zeros(shape, dtype=np.uint16)
cp_pseudo[ori_plane,256:512,256:512] = cellpose['0210']

In [101]:
prefix = '0021'

In [21]:
mask_path 

['./tatz_anno_2d_cplm/0126_nw_mask.tif',
 './tatz_anno_2d_cplm/0126_cp_mask.tif',
 './tatz_anno_2d_cplm/0126_sd_mask.tif',
 './tatz_anno_2d_cplm/0126_sw_mask.tif']