This notebook runs segmentation using the TensorStore package, together with the associated changes made to the neurotorch array and predictor classes to support it.

In [None]:
import tracemalloc
from datetime import datetime
import torch
from matplotlib import pyplot as plt
import os
import numpy as np
import ac_segmentation
import ac_segmentation.neurotorch.datasets.dataset
from ac_segmentation.neurotorch.datasets.dataset import open_ZarrTensor, create_EmptyTensor
import ac_segmentation.neurotorch.core.predictor
import ac_segmentation.neurotorch.nets.RSUNet

# `datatypes` was previously only reachable because importing
# `ac_segmentation.neurotorch.datasets.dataset` happened to load it as a side
# effect. Import it explicitly so the Vector/BoundingBox aliases below do not
# depend on another module's internal imports.
import ac_segmentation.neurotorch.datasets.datatypes

# Short aliases for the neurotorch classes used throughout this notebook.
Predictor = ac_segmentation.neurotorch.core.predictor.Predictor
Vector = ac_segmentation.neurotorch.datasets.datatypes.Vector
BoundingBox = ac_segmentation.neurotorch.datasets.datatypes.BoundingBox
TSArray = ac_segmentation.neurotorch.datasets.dataset.TSArray
Array = ac_segmentation.neurotorch.datasets.dataset.Array
RSUNet = ac_segmentation.neurotorch.nets.RSUNet.RSUNet

In [None]:
## Configuration for this run: edit paths / crop window here.
INPUT_ZARR_PATH = "/ACdata/Users/kevin/exaspim_ome_zarr/output_exa4/test.zarr/tile_x_0002_y_0001_z_0000_ch_488/0/"
OUTPUT_ZARR_PATH = "/ACdata/Users/connorl/Image_Files/TS_Array_tile_x_0002_y_0001_z_0000_ch_488.zarr"
# Forwarded to open_ZarrTensor as `bytes_limit` (presumably a TensorStore
# cache-pool size in bytes -- confirm in open_ZarrTensor).
INPUT_BYTES_LIMIT = 100_000_000
# Crop window: index 0 on the two leading axes, then a 250-voxel range on each
# of the last three axes. Axis order is presumably (t, c, z, y, x) per the
# OME-Zarr convention -- confirm against the dataset metadata.
INPUT_ROI = (0, 0, slice(4000, 4250), slice(4250, 4500), slice(4250, 4500))

## Open input tensor and crop it in one step, so `in_arr` only ever names the
## cropped region (append `.transpose()` here if an axis swap is needed).
in_arr = open_ZarrTensor(INPUT_ZARR_PATH, bytes_limit=INPUT_BYTES_LIMIT)[INPUT_ROI]

## Create output tensor with the same shape as the cropped input.
## fill_value=-np.inf is the initial value of every voxel in the new array.
out_arr = create_EmptyTensor(OUTPUT_ZARR_PATH, in_arr.shape, dtype='float32', fill_value=-np.inf)

In [None]:
# Wall-clock timer for the whole cell (model load, array setup and inference).
start = datetime.now()

## Run configuration
CHECKPOINT_FILE = "/allen/programs/celltypes/workgroups/mousecelltypes/MachineLearning/Olga/forConnor/new_model/best.ckpt"
# Patch geometry shared by the input and output arrays. It was previously
# written out twice; defining it once keeps the two arrays from drifting apart
# (the predictor presumably pairs input and output patches -- confirm).
PATCH_SHAPE = (64, 64, 64)
PATCH_STRIDE = (32, 32, 32)
# Percentile pair passed to TSArray.set_Rescale_Perc for intensity adjustment.
RESCALE_PERCENTILES = (90, 100)
# Upper bound on torch intra-op threads; clamped below to the usable CPUs.
REQUESTED_THREADS = 46

# gpu_device=None presumably selects CPU execution -- confirm in Predictor.
predictor = Predictor(RSUNet(), CHECKPOINT_FILE, gpu_device=None)

###Create input and output array objects
# Fresh BoundingBox/Vector objects are built per array so neither TSArray can
# observe mutations made by the other.
inarr = TSArray(in_arr,
                iteration_size=BoundingBox(Vector(0, 0, 0), Vector(*PATCH_SHAPE)),
                stride=Vector(*PATCH_STRIDE))
outarr = TSArray(out_arr,
                 iteration_size=BoundingBox(Vector(0, 0, 0), Vector(*PATCH_SHAPE)),
                 stride=Vector(*PATCH_STRIDE),
                 prob_map=True)

###Optional masking
##Li Thresholding
#inarr.set_Li_Threshold(1000)

##Set Mask from Array or Tensorstore
#mask = np.ones(inarr.tensor.shape)
#inarr.setMask(mask)

###Optional intensity adjustment
inarr.set_Rescale_Perc(RESCALE_PERCENTILES)

###Run segmentation
# Requesting more threads than this process may run on only oversubscribes the
# cores. sched_getaffinity respects CPU pinning (e.g. cluster job limits) where
# available; otherwise fall back to the host CPU count.
if hasattr(os, "sched_getaffinity"):
    available_cpus = len(os.sched_getaffinity(0))
else:
    available_cpus = os.cpu_count() or 1
torch.set_num_threads(max(1, min(REQUESTED_THREADS, available_cpus)))
predictor.run(inarr, outarr, batch_size=100, mini_batch_size=20, max_pix=40000)

end = datetime.now()
print(end-start)