# Benchmark TrackMate

In [None]:
!pip install pyimagej

In [None]:
import imagej
import json
import numpy as np
import pandas as pd
import tifffile

from pathlib import Path

from piscis.data import load_datasets
from piscis.metrics import compute_metrics
from piscis.utils import pad_and_stack

In [None]:
# All generated artifacts live in an `outputs` folder that sits next to the
# current working directory (i.e. inside its parent directory).
outputs_path = Path.cwd().parent / 'outputs'

# Input datasets: the composite dataset collection used by piscis.
datasets_path = outputs_path / 'datasets'
piscis_datasets_path = datasets_path / 'piscis'

# Destination for benchmark results, and a scratch folder (relative to the
# working directory) used to hand images over to Fiji.
benchmarks_path = outputs_path / 'benchmarks'
tmp_path = Path('tmp')

# Make sure both writable folders exist before anything is written to them.
for directory in (benchmarks_path, tmp_path):
    directory.mkdir(parents=True, exist_ok=True)

### Run and benchmark TrackMate

In [None]:
# Start an ImageJ gateway backed by the full Fiji distribution, which provides
# the TrackMate classes imported by the Groovy script below.
fiji_endpoint = 'sc.fiji:fiji'
ij = imagej.init(fiji_endpoint)

In [None]:
# Groovy script run inside Fiji once per (radius, threshold) grid point.
# Inputs (SciJava script parameters): `radius` and `threshold` for TrackMate's
# LoG detector. The script reads the image from the fixed relative path
# 'tmp/image.tif' (written by the grid-search loop) and evaluates to a list of
# [x, y, z] spot positions as strings. pyimagej exposes that final expression
# value as the script's 'result' output.
script = """
//@double radius
//@double threshold

import ij.IJ
import fiji.plugin.trackmate.Model
import fiji.plugin.trackmate.Settings
import fiji.plugin.trackmate.TrackMate
import fiji.plugin.trackmate.detection.LogDetectorFactory

// Silence console output. Note: this replaces System.out for the whole JVM,
// not just for the duration of this script.
System.out = new PrintStream(new OutputStream() {
    public void write(int b) {
        // NO-OP
    }
})

// Open the image written by the Python side (path is relative to the working directory).
imp = IJ.openImage('tmp/image.tif')
if (imp == null) {
    throw new IllegalStateException('Could not open tmp/image.tif')
}

// If there is a single time frame, move the Z slices onto the T axis.
// getDimensions() returns [width, height, channels, slices, frames].
dims = imp.getDimensions()
if (dims[4] == 1) {
    imp.setDimensions(dims[2], dims[4], dims[3])
}

// Configure the LoG detector with the grid-search parameters.
settings = new Settings(imp)

settings.detectorFactory = new LogDetectorFactory()
settings.detectorSettings = settings.detectorFactory.getDefaultSettings()
settings.detectorSettings['RADIUS'] = radius
settings.detectorSettings['THRESHOLD'] = threshold

// Run TrackMate. Only the detected spots are read back below; the boolean
// returned by process() is not checked.
model = new Model()
trackmate = new TrackMate(model, settings)

trackmate.process()

// Collect the (visible) spots' positions as strings.
spots = model.getSpots().iterable(true)
spotCoordinates = []

spots.each { spot ->
    coord = [spot.getDoublePosition(0).toString(), spot.getDoublePosition(1).toString(), spot.getDoublePosition(2).toString()]
    spotCoordinates.add(coord)
}

spotCoordinates
"""

In [None]:
# List subdatasets (one .npz file each) within the composite dataset.
# sorted() yields a reusable list in a stable, filesystem-independent order.
# A bare glob() generator is exhausted after a single pass, so re-running the
# grid-search cell would silently process nothing.
dataset = '20230905'
subdataset_paths = sorted((piscis_datasets_path / dataset).glob('*.npz'))

In [None]:
# Parameter grid for the LoG detector sweep:
# 5 evenly spaced spot radii and 20 evenly spaced detection thresholds.
radii = np.linspace(start=1.0, stop=3.0, num=5)
thresholds = np.linspace(start=0.02, stop=0.4, num=20)

In [None]:
# Run the grid search. For every test image of every subdataset, run TrackMate at
# each (radius, threshold) pair and score the detections against the ground-truth
# coordinates. Results are nested as
#     f1s[subdataset_stem][image_index][radius_index][threshold_index] = f1
# and checkpointed to JSON after every evaluation, so an interrupted run keeps
# its progress.
f1s = {}
f1s_file = benchmarks_path / 'trackmate_f1s.json'
f1s_tmp_file = benchmarks_path / 'trackmate_f1s.json.tmp'
# Matching distances passed to compute_metrics; loop-invariant, so built once.
distance_thresholds = np.linspace(0, 3, 50)

for subdataset_path in subdataset_paths:

    test_ds = load_datasets(subdataset_path, adjustment='normalize', load_train=False, load_valid=False, load_test=True)['test']
    images = pad_and_stack(test_ds['images'])
    test_ds['images'] = images
    coords = test_ds['coords']

    subdataset_f1s = {}
    f1s[subdataset_path.stem] = subdataset_f1s

    for i, (image, c) in enumerate(zip(images, coords)):
        # The Groovy script reads the fixed path 'tmp/image.tif', which must match tmp_path.
        tifffile.imwrite(tmp_path / 'image.tif', image.astype(np.float32))
        image_f1s = subdataset_f1s.setdefault(i, {})
        for j, radius in enumerate(radii):
            radius_f1s = image_f1s.setdefault(j, {})
            for k, threshold in enumerate(thresholds):
                args = {'radius': float(radius), 'threshold': float(threshold)}
                result = ij.py.run_script('groovy', script, args).getOutput('result')
                # Each row is [x, y, z] as strings; [1::-1] keeps (y, x), i.e. (row, col) order.
                c_pred = np.array([[float(str(string)) for string in list(row)[1::-1]] for row in result])
                if (c_pred.size > 0) and (c.size > 0):
                    f1 = compute_metrics(c_pred, c, evaluation_metrics='f1', distance_thresholds=distance_thresholds)['f1']
                else:
                    # No predictions or no ground truth: scored as 0 (even when both are empty).
                    f1 = 0
                radius_f1s[k] = f1
                # Write to a temporary file and atomically swap it in, so interrupting
                # the run mid-write never leaves a truncated JSON checkpoint behind.
                with open(f1s_tmp_file, 'w') as f:
                    json.dump(f1s, f)
                f1s_tmp_file.replace(f1s_file)