In [None]:
import numpy as np
import os

from skimage import io
from tqdm import tqdm
import napari
import pandas as pd
from skimage import exposure

import zarr
import dask.array as da

import pywt

from os import listdir
from os.path import isfile, join
import json

import dask.array as da

In [None]:
# ---- Configuration: input locations ---------------------------------------
# Folder holding the cropped images and their masks.
dataset_folder = '/mnt/ampa02_data01/tmurakami/model_training/crops'
# JSON datacard listing the 'train' and 'test' datasets.
datacard_path = "/mnt/ampa02_data01/tmurakami/model_training/tatz_datacard02.json"
# Pickled table with one metadata row per crop.
meta_path = '/mnt/ampa02_data01/tmurakami/model_training/info.pkl'

# ---- Output locations -------------------------------------------------------
my_save_path = ''
save_path = dataset_folder

# ---- Load the datacard and the metadata table -------------------------------
with open(datacard_path) as f:
    datacard = json.load(f)

# NOTE(review): read_pickle can execute arbitrary code; only use it on files
# from a trusted location.
df = pd.read_pickle(meta_path)

# The first four characters of every image file name are used as the crop ID.
train_img_prefix = [entry['img'][:4] for entry in datacard['datasets']['train']]
test_img_prefix = [entry['img'][:4] for entry in datacard['datasets']['test']]

In [None]:
# Review loop: for every metadata row whose ID is not in the training set,
# open the source volume, cut a field of view (FoV) around the annotated crop
# and show it in napari together with the saved 2D mask. Each viewer blocks
# until it is closed, then the next row is shown.

# Loop-invariant settings.
voxel_size = (2.0,1.3,1.3)  # passed as `scale` to every napari layer
FoV = [100,768,768]         # size (in voxels, per axis) of the region shown around a crop

for i in range(df.shape[0]):
    dat = df.iloc[i,:]

    prefix = dat['ID']
    # Skip training crops first, so no N5 store is opened for a row that is
    # not going to be displayed.
    if prefix in train_img_prefix:
        continue

    # image path
    fix_n5_path = dat['source'] # /mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/fused/fused.n5'

    # create Zarr file object
    fix_zarr = zarr.open(store=zarr.N5Store(fix_n5_path), mode='r')
    n5_setups = list(fix_zarr.keys())

    # per-row parameters
    # NOTE(review): the metadata may hold None for the reference channel; that
    # case is not handled here and would fail at n5_setups[reference_chan].
    reference_chan = dat['ref_channel']
    segment_chan = dat['channel']

    crop_size = dat['crop_size']
    pos = dat['corner']
    plane_position = int(dat['plane_position'])

    mask_path = os.path.join(save_path, prefix+'_mask.tif')

    img = da.from_zarr(fix_zarr[n5_setups[segment_chan]]['timepoint0']['s0'])
    img_ref = da.from_zarr(fix_zarr[n5_setups[reference_chan]]['timepoint0']['s0'])

    # Centre the FoV on the crop, then clamp it to the bounds of the volume.
    margin = [(k-j)//2 for j,k in zip(crop_size, FoV)]
    top_corner = tuple(max(0, p-m) for p,m in zip(pos, margin))
    bottom_corner = tuple(min(s, p+c+m) for p,c,m,s in zip(pos, crop_size, margin, img.shape))

    # Position of the crop inside the FoV. It is derived from the clamped
    # top corner rather than from the nominal margin, so the border lines and
    # the mask stay on the crop when the FoV was cut off at the volume origin.
    # Without clamping this equals (FoV - crop_size)//2.
    # Assumes dat['corner'] is the crop's top corner in volume coordinates,
    # as the FoV computation above implies — TODO confirm.
    top_border_corner = tuple(p-t for p,t in zip(pos, top_corner))
    bottom_border_corner = tuple(o+c for o,c in zip(top_border_corner, crop_size))

    fov_slices = tuple(slice(lo,hi) for lo,hi in zip(top_corner, bottom_corner))
    FoV_segment = img[fov_slices]
    FoV_reference = img_ref[fov_slices]

    # make labeling data: an empty volume with the saved mask written into
    # one plane at the crop's in-plane position
    # NOTE(review): plane_position is used as an index into the FoV; it is not
    # adjusted here if the FoV was clamped along axis 0 — confirm its origin.
    labels = np.zeros_like(FoV_reference)
    label_plane = io.imread(mask_path)
    label_top_corner = top_border_corner[1:]
    slicer = (plane_position,) + tuple(slice(o,o+n,None) for o,n in zip(label_top_corner, label_plane.shape))
    labels[slicer] = label_plane

    # napari browsing
    viewer = napari.Viewer()
    viewer.add_image(FoV_reference, scale=voxel_size, contrast_limits=[0,10000],blending='additive', visible=False)

    # Outline of the crop: four line layers, in scaled in-plane coordinates.
    y0, x0 = top_border_corner[1]*voxel_size[1], top_border_corner[2]*voxel_size[2]
    y1, x1 = bottom_border_corner[1]*voxel_size[1], bottom_border_corner[2]*voxel_size[2]
    border_lines = [
        [[y1,x1],[y0,x1]],
        [[y0,x1],[y0,x0]],
        [[y1,x1],[y1,x0]],
        [[y1,x0],[y0,x0]],
    ]
    for border_line in border_lines:
        viewer.add_shapes(border_line,
                          edge_width=2,edge_color='white',ndim=2,shape_type='line')

    viewer.add_labels(labels, scale=voxel_size,blending='additive')
    viewer.add_image(FoV_segment, scale=voxel_size, contrast_limits=[0,10000],blending='additive')

    # start the viewer on the plane that carries the mask
    viewer.dims.set_current_step(axis=0, value=plane_position)
    viewer.show(block=True)

    print(i)

In [None]:
# Load the training samples listed in the datacard: index 1 along the first
# axis of the raw and the normalised image, plus the full label mask.
# A sample is loaded only when all three of its files exist, so that img,
# img_norm and mask always have the same length and entry k of each list
# belongs to the same sample.
img = []
img_norm = []
mask = []
required_keys = ('img', 'img_norm', 'label')
for train_data in datacard['datasets']['train']:
    if not all(key in train_data for key in required_keys):
        print(f'skipping sample, datacard entry lacks one of {required_keys}: {train_data}')
        continue

    sources = {key: os.path.join(dataset_folder, train_data[key]) for key in required_keys}
    missing = [path for path in sources.values() if not os.path.isfile(path)]
    if missing:
        print(f'skipping sample, file(s) not found: {missing}')
        continue

    img.append(io.imread(sources['img'])[1])
    img_norm.append(io.imread(sources['img_norm'])[1])
    mask.append(io.imread(sources['label']))

In [None]:
# Drop the last 25 entries of each collection and stack what remains along a
# new first axis.
# NOTE(review): this cell rebinds img / img_norm / mask, so running it again
# without re-running the loading cell trims another 25 entries each time.
const = -25
keep = slice(None, const)
img = np.stack(img[keep])
img_norm = np.stack(img_norm[keep])
mask = np.stack(mask[keep])

In [None]:
# Overlay the stacked training images and their masks in one napari viewer.
viewer = napari.Viewer()
image_layers = (
    (img, 'original', [100,20000]),
    (img_norm, 'image01', [0,2]),
)
for layer_data, layer_name, layer_limits in image_layers:
    viewer.add_image(layer_data, name=layer_name, contrast_limits=layer_limits, blending='additive')
viewer.add_labels(mask)

In [None]:
# Path of the pickled metadata table (the same file as `meta_path` in the
# configuration cell).
info_path = '/mnt/ampa02_data01/tmurakami/model_training/info.pkl'

In [None]:
# Display row 68 of the metadata table, re-read from disk.
# NOTE(review): read_pickle can execute arbitrary code; only use it on files
# from a trusted location.
pd.read_pickle(info_path).iloc[68,:]

In [None]:
# Display entry 14 of the datacard's training set.
datacard['datasets']['train'][14]