<div class="alert alert-block alert-info"> <b>NOTE</b> Please select the kernel <code>Python [conda env:gnn-pytorch]</code> for this notebook. </div>

# 3. Inference - Tracking

In [1]:
import os
import subprocess
import sys
import re

In [2]:
# Directory added to the import path so the project modules imported below
# (preprocess_seq2graph_clean, inference_clean, postprocess_clean) can be found —
# presumably they live directly under `software/`.
software_dir = 'software'
# Guard so that re-running this cell does not append duplicate entries to sys.path.
if software_dir not in sys.path:
    sys.path.append(software_dir)

In [3]:
# Data layout used below: <data_dir>/<seg_type>/<model>/<3-digit sequence id>
data_dir = 'ctc-data'

# Datasets to process; each entry is also used as a key into the model-path
# dictionaries defined in the next cell.
models = [
    'Fluo-N2DL-HeLa'
]
# Segmentation sources to track (presumably ground-truth vs. predicted masks).
seg_dirs = ['gt-seg', 'pred-seg']

# Sequence folders are named with exactly three digits, e.g. '001'.
# Raw string: '\d' in a plain string literal is an invalid escape sequence
# (DeprecationWarning, SyntaxWarning since Python 3.12).
pattern = re.compile(r'\d{3}')

# NOTE: the sequence ids are listed from the first seg_dir/model only and reused
# for every seg_dir/model in the tracking loop.
ids_source_dir = os.path.join(data_dir, seg_dirs[0], models[0])
# os.listdir returns entries in arbitrary order; sort for a reproducible run order.
data_ids = sorted(f for f in os.listdir(ids_source_dir) if pattern.fullmatch(f))
assert data_ids, f"no 3-digit sequence folders found in {ids_source_dir!r}"

In [4]:
# Tracking configuration.
modality = '2D'   # '2D' or '3D'; forwarded to run_postprocess
min_size = 10     # forwarded to run_preprocess as min_cell_size

# Metric-learning feature-extractor parameters; one sub-folder per dataset
# (the tracking loop reads <this dir>/<model>/all_params.pth).
model_metric_learning_dir = os.path.join(software_dir, 'parameters/Features_Models')

# Dataset name -> PyTorch Lightning tracking checkpoint (relative to software_dir).
models_pytorch_lightning = {
    dataset: os.path.join(software_dir, checkpoint_rel_path)
    for dataset, checkpoint_rel_path in {
        'Fluo-N2DL-HeLa': 'parameters/Tracking_Models/Fluo-N2DL-HeLa/checkpoints/epoch=312.ckpt',
    }.items()
}

In [5]:
from preprocess_seq2graph_clean import create_csv
from inference_clean import predict
from postprocess_clean import Postprocess

def run_preprocess(input_images, input_segmentation, input_model, min_cell_size, output_csv):
    """Run ``create_csv`` for a single sequence.

    Thin adapter: ``create_csv`` takes the output location *before* the minimum
    cell size, so the last two arguments are reordered relative to this wrapper's
    signature. Presumably ``create_csv`` writes the per-frame CSV files into
    ``output_csv`` — confirm against preprocess_seq2graph_clean.
    """
    # Positional order expected by create_csv.
    csv_args = (input_images, input_segmentation, input_model, output_csv, min_cell_size)
    create_csv(*csv_args)

def run_inference(model_path, num_seq, output_csv):
    """Run ``predict`` with the tracking checkpoint on sequence ``num_seq``.

    Thin adapter: ``predict`` expects ``(model_path, output_csv, num_seq)``, so the
    last two arguments are swapped relative to this wrapper's signature.

    NOTE(review): the tracking loop passes the dataset directory as ``output_csv``
    (not the ``<id>_CSV`` folder); presumably ``predict`` locates the CSVs from it
    together with ``num_seq`` — confirm against inference_clean.
    """
    # Positional order expected by predict.
    predict_args = (model_path, output_csv, num_seq)
    predict(*predict_args)


# Override Postprocess.create_save_dir so that 3-digit sequence ids (e.g. '001') are supported.
class Postprocess_Fix(Postprocess):
    """``Postprocess`` whose result directory is named after a 3-digit sequence id."""

    def create_save_dir(self):
        """Create the tracking-result directory and store it in ``self.save_tra_dir``.

        The sequence id is the first three characters of the last path component of
        ``self.dir_result`` (e.g. ``001`` for a directory named ``001_RES_inference``).
        Results go to the sibling directory ``<id>_RES``, which is created if missing.
        """
        # normpath drops a trailing separator: a plain split('/') would yield an empty
        # last component (and an empty id, i.e. a '_RES' folder) for '.../001_RES_inference/'.
        # basename also respects the platform's native separator.
        last_component = os.path.basename(os.path.normpath(self.dir_result))
        num_seq = last_component[:3]
        # Sibling of dir_result: <dir_result>/../<id>_RES
        self.save_tra_dir = os.path.join(self.dir_result, os.pardir, f"{num_seq}_RES")
        os.makedirs(self.save_tra_dir, exist_ok=True)

def run_postprocess(modality, path_inference_output, path_Seg_result):
    """Build trajectories for one sequence from the inference output.

    Parameters
    ----------
    modality : str
        ``'2D'`` or ``'3D'``.
    path_inference_output : str
        Inference output directory of the sequence (presumably written by
        ``run_inference``).
    path_Seg_result : str
        Directory with the sequence's segmentation masks (``.tif``).

    Returns
    -------
    tuple
        ``(all_frames_traject, trajectory_same_label, df_trajectory, str_track)``
        as returned by ``create_trajectory``. Callers may ignore it.

    Raises
    ------
    ValueError
        If ``modality`` is neither ``'2D'`` nor ``'3D'``.
    """
    # Explicit check rather than `assert`, which is removed under `python -O`.
    if modality not in ('2D', '3D'):
        raise ValueError(f"modality must be '2D' or '3D', got {modality!r}")

    pp = Postprocess_Fix(
        is_3d=(modality == '3D'),
        type_masks='tif',
        merge_operation='AND',
        decision_threshold=0.5,
        path_inference_output=path_inference_output,
        center_coord=False,
        directed=True,
        path_seg_result=path_Seg_result,
    )
    all_frames_traject, trajectory_same_label, df_trajectory, str_track = pp.create_trajectory()
    # NOTE(review): presumably writes the relabelled result masks into
    # pp.save_tra_dir — confirm against postprocess_clean.
    pp.fill_mask_labels(debug=False)
    return all_frames_traject, trajectory_same_label, df_trajectory, str_track

  from .autonotebook import tqdm as notebook_tqdm
  if not hasattr(tensorboard, '__version__') or LooseVersion(tensorboard.__version__) < LooseVersion('1.15'):
  if not hasattr(tensorboard, '__version__') or LooseVersion(tensorboard.__version__) < LooseVersion('1.15'):
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  (np.object, string),
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  (np.bool, bool),
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  types_pb2.DT_STRING: np.object,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  types_pb2.DT_BOOL: np.bool,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  types_pb2.DT

In [6]:
# Full tracking pipeline (preprocess -> inference -> postprocess) for every
# segmentation source, dataset and sequence.
for seg_type in seg_dirs:
    for m in models:
        print('------', m, '------')
        sub_dir = os.path.join(data_dir, seg_type, m)
        for i in data_ids:
            print(i)
            # 1) Preprocess the sequence's images + masks; output goes to <sub_dir>/<id>_CSV.
            run_preprocess(
                input_images=os.path.join(sub_dir, i),
                input_segmentation=os.path.join(sub_dir, f'{i}_SEG_RES'),
                input_model=os.path.join(model_metric_learning_dir, m, 'all_params.pth'),
                min_cell_size=min_size,
                output_csv=os.path.join(sub_dir, f'{i}_CSV')
            )

            # 2) Inference with the tracking checkpoint. The dataset directory (not the
            #    <id>_CSV folder) is passed on purpose; predict also receives the id.
            run_inference(
                model_path=models_pytorch_lightning[m],
                num_seq=i,
                output_csv=sub_dir,
            )

            # 3) Postprocess; Postprocess_Fix derives the <id>_RES output folder name.
            run_postprocess(
                modality=modality,
                path_inference_output=os.path.join(sub_dir, f'{i}_RES_inference'),
                path_Seg_result=os.path.join(sub_dir, f'{i}_SEG_RES')
            )
            # Blank separator line between sequences (was print('/n'), which printed
            # the literal text "/n").
            print()

------ Fluo-N2DL-HeLa ------
001
start: ctc-data/gt-seg/Fluo-N2DL-HeLa/001_SEG_RES/mask000.tif
start: ctc-data/gt-seg/Fluo-N2DL-HeLa/001_SEG_RES/mask001.tif
start: ctc-data/gt-seg/Fluo-N2DL-HeLa/001_SEG_RES/mask002.tif
start: ctc-data/gt-seg/Fluo-N2DL-HeLa/001_SEG_RES/mask003.tif
start: ctc-data/gt-seg/Fluo-N2DL-HeLa/001_SEG_RES/mask004.tif
start: ctc-data/gt-seg/Fluo-N2DL-HeLa/001_SEG_RES/mask005.tif
start: ctc-data/gt-seg/Fluo-N2DL-HeLa/001_SEG_RES/mask006.tif
start: ctc-data/gt-seg/Fluo-N2DL-HeLa/001_SEG_RES/mask007.tif
start: ctc-data/gt-seg/Fluo-N2DL-HeLa/001_SEG_RES/mask008.tif
start: ctc-data/gt-seg/Fluo-N2DL-HeLa/001_SEG_RES/mask009.tif
start: ctc-data/gt-seg/Fluo-N2DL-HeLa/001_SEG_RES/mask010.tif
start: ctc-data/gt-seg/Fluo-N2DL-HeLa/001_SEG_RES/mask011.tif
start: ctc-data/gt-seg/Fluo-N2DL-HeLa/001_SEG_RES/mask012.tif
start: ctc-data/gt-seg/Fluo-N2DL-HeLa/001_SEG_RES/mask013.tif
start: ctc-data/gt-seg/Fluo-N2DL-HeLa/001_SEG_RES/mask014.tif
start: ctc-data/gt-seg/Fluo-N2DL-HeLa