<div class="alert alert-block alert-info"> <b>NOTE</b> Please select the kernel <code>Python [conda env:gnn-pytorch]</code> for this notebook. </div>

# 3. Inference - Tracking

In [1]:
import os
import subprocess
import sys
import re

In [None]:
# Make the project modules in `software_dir` (preprocess_seq2graph_clean,
# inference_clean, postprocess_clean) importable by later cells.
# Guarded so that re-running this cell does not append duplicate entries.
software_dir = 'software'
if software_dir not in sys.path:
    sys.path.append(software_dir)

In [2]:
data_dir = 'data'

# Datasets to process; each has matching entries under software_dir/parameters.
models = [
    'Fluo-N2DL-HeLa',
    'Fluo-N2DH-SIM+'
]

# Sequence directories are named with exactly three digits (e.g. "011").
# Raw string: a backslash-d escape in a plain literal is invalid
# (SyntaxWarning on Python 3.12+).
pattern = re.compile(r'\d{3}')
# IDs are discovered from the first dataset only and reused for every entry in
# `models`, so every dataset is assumed to contain the same sequence directories.
# sorted() makes the processing order deterministic (os.listdir order is arbitrary).
data_ids = sorted(
    f for f in os.listdir(os.path.join(data_dir, models[0])) if pattern.fullmatch(f)
)

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# Run configuration shared by every dataset.
modality = '2D'   # '2D' or '3D'
min_size = '10'   # minimum cell size; note it is a string, not an int

# Metric-learning feature models, one sub-directory per dataset.
model_metric_learning_dir = os.path.join(software_dir, 'parameters/Features_Models')

# Tracking-model checkpoint per dataset, keyed by the names used in `models`.
models_pytorch_lightning = {
    dataset: os.path.join(
        software_dir,
        f'parameters/Tracking_Models/{dataset}/checkpoints/epoch={epoch}.ckpt',
    )
    for dataset, epoch in (('Fluo-N2DL-HeLa', 312), ('Fluo-N2DH-SIM+', 132))
}

In [18]:
from preprocess_seq2graph_clean import create_csv
from inference_clean import predict
from postprocess_clean import Postprocess

def run_preprocess(input_images, input_segmentation, input_model, min_cell_size, output_csv):
    """Run the preprocessing step (``create_csv``) for a single sequence.

    Thin wrapper: ``create_csv`` expects ``output_csv`` *before*
    ``min_cell_size``, so the arguments are reordered here.
    """
    create_csv_args = (input_images, input_segmentation, input_model, output_csv, min_cell_size)
    create_csv(*create_csv_args)

def run_inference(model_path, num_seq, output_csv):
    """Run the tracking model (``predict``) on a single sequence.

    Thin wrapper: ``predict`` takes ``(model_path, output_csv, num_seq)``,
    so the last two arguments are swapped here.
    """
    predict_args = (model_path, output_csv, num_seq)
    predict(*predict_args)


# Override the create_save_dir function in Postprocess to support 3 digit sequence ids
class Postprocess_Fix(Postprocess):
    """``Postprocess`` whose result directory is named from a 3-character sequence id."""

    def create_save_dir(self):
        """Set ``self.save_tra_dir`` from ``self.dir_result`` and create it.

        The sequence id is the first three characters of the last path
        component of ``self.dir_result``. The output directory ``<id>_RES`` is
        created as a sibling of ``self.dir_result``. For example,
        ``data/<dataset>/011_SEG_RES`` maps to
        ``data/<dataset>/011_SEG_RES/../011_RES``.
        """
        # basename(normpath(...)) rather than split('/')[-1]: tolerates a trailing
        # separator (split would give '' and create a bogus "_RES" directory) and
        # OS-native path separators.
        num_seq = os.path.basename(os.path.normpath(self.dir_result))[:3]
        save_tra_dir = os.path.join(self.dir_result, f"../{num_seq}_RES")
        self.save_tra_dir = save_tra_dir
        os.makedirs(self.save_tra_dir, exist_ok=True)
    
def run_postprocess(modality, path_inference_output, path_Seg_result):
    """Post-process one sequence's inference output into tracking result masks.

    Args:
        modality: ``'2D'`` or ``'3D'`` (case-insensitive).
        path_inference_output: directory holding the inference-step output for
            the sequence.
        path_Seg_result: directory holding the segmentation masks for the
            sequence; results are written by ``Postprocess_Fix``.

    Raises:
        ValueError: if ``modality`` is not '2D' or '3D'.
    """
    # Explicit check instead of ``assert``: assertions are stripped under ``python -O``.
    if not isinstance(modality, str) or modality.upper() not in ('2D', '3D'):
        raise ValueError(f"modality must be '2D' or '3D', got {modality!r}")

    is_3d = modality.upper() == '3D'

    pp = Postprocess_Fix(is_3d=is_3d,
                         type_masks='tif',
                         merge_operation='AND',
                         decision_threshold=0.5,
                         path_inference_output=path_inference_output,
                         center_coord=False,
                         directed=True,
                         path_seg_result=path_Seg_result)
    # The returned trajectory objects are not used here. The call is kept because
    # fill_mask_labels is run after it (presumably it depends on state that
    # create_trajectory sets on ``pp`` -- TODO confirm in postprocess_clean).
    pp.create_trajectory()
    pp.fill_mask_labels(debug=False)

In [19]:
# Per (dataset, sequence) pipeline: preprocess -> inference -> postprocess.
# By default only post-processing runs; it reads files written by earlier runs
# of the preprocessing/inference steps. Set these flags to True to regenerate them.
RUN_PREPROCESS = False
RUN_INFERENCE = False

for m in models:
    print('------', m, '------')
    for i in data_ids:
        print(i)

        if RUN_PREPROCESS:
            run_preprocess(
                input_images=os.path.join(data_dir, m, i),
                input_segmentation=os.path.join(data_dir, m, f'{i}_SEG_RES'),
                input_model=os.path.join(model_metric_learning_dir, m, 'all_params.pth'),
                min_cell_size=min_size,
                output_csv=os.path.join(data_dir, m, f'{i}_CSV')
            )

        if RUN_INFERENCE:
            run_inference(
                model_path=models_pytorch_lightning[m],
                num_seq=i,
                output_csv=os.path.join(data_dir, m)
            )

        run_postprocess(
            modality=modality,
            path_inference_output=os.path.join(data_dir, m, f'{i}_RES_inference'),
            path_Seg_result=os.path.join(data_dir, m, f'{i}_SEG_RES')
        )
        # Blank-line separator between sequences ('\n', not the '/n' typo).
        print('\n')

------ Fluo-N2DL-HeLa ------
011
Load data/Fluo-N2DL-HeLa/011_RES_inference/pytorch_geometric_data.pt
Load data/Fluo-N2DL-HeLa/011_RES_inference/all_data_df.csv
Load data/Fluo-N2DL-HeLa/011_RES_inference/raw_output.pt
Saving file: data/Fluo-N2DL-HeLa/011_SEG_RES/../011_RES/mask000.tif
Saving file: data/Fluo-N2DL-HeLa/011_SEG_RES/../011_RES/mask001.tif
Saving file: data/Fluo-N2DL-HeLa/011_SEG_RES/../011_RES/mask002.tif
Saving file: data/Fluo-N2DL-HeLa/011_SEG_RES/../011_RES/mask003.tif
Saving file: data/Fluo-N2DL-HeLa/011_SEG_RES/../011_RES/mask004.tif
Saving file: data/Fluo-N2DL-HeLa/011_SEG_RES/../011_RES/mask005.tif
Saving file: data/Fluo-N2DL-HeLa/011_SEG_RES/../011_RES/mask006.tif
Saving file: data/Fluo-N2DL-HeLa/011_SEG_RES/../011_RES/mask007.tif
Saving file: data/Fluo-N2DL-HeLa/011_SEG_RES/../011_RES/mask008.tif
Saving file: data/Fluo-N2DL-HeLa/011_SEG_RES/../011_RES/mask009.tif
Saving file: data/Fluo-N2DL-HeLa/011_SEG_RES/../011_RES/mask010.tif
Saving file: data/Fluo-N2DL-HeLa/0