# Benchmark RS-FISH

In [None]:
!pip install zarr

In [None]:
import json
import numpy as np
import pandas as pd
import subprocess
import zarr

from pathlib import Path

from piscis.data import load_datasets
from piscis.metrics import compute_metrics
from piscis.utils import pad_and_stack

In [None]:
# All artifacts live in the 'outputs' folder that sits next to the current
# working directory (i.e. <cwd>/../outputs).
outputs_path = Path.cwd().parent / 'outputs'

# Dataset locations inside the outputs folder.
datasets_path = outputs_path / 'datasets'
piscis_datasets_path = datasets_path / 'piscis'

# Where benchmark results are written, and a scratch folder relative to the
# working directory for the cloned repo, N5 data and RS-FISH output.
benchmarks_path = outputs_path / 'benchmarks'
tmp_path = Path('tmp')

# Create both folders (and any missing parents); a no-op when they already exist.
benchmarks_path.mkdir(parents=True, exist_ok=True)
tmp_path.mkdir(parents=True, exist_ok=True)

### Clone and build the RS-FISH-Spark repository.

In [None]:
# Clone RS-FISH-Spark into tmp/RS-FISH-Spark. git refuses to clone into an existing
# non-empty directory, so delete that folder first when re-running this cell.
!git clone https://github.com/PreibischLab/RS-FISH-Spark tmp/RS-FISH-Spark

In [None]:
# Build the RS-FISH-Spark jar (with dependencies) with Maven, then copy it to
# ../outputs/rs-fish.jar -- the same file the RS-FISH command below references
# as outputs_path / 'rs-fish.jar' (outputs_path is <cwd>/../outputs).
!mvn clean package -f tmp/RS-FISH-Spark/pom.xml
!cp tmp/RS-FISH-Spark/target/RS-Fish-jar-with-dependencies.jar ../outputs/rs-fish.jar

### Generate the N5 dataset.

In [None]:
# List subdatasets within the composite dataset.
# Path.glob yields entries in filesystem-dependent order; sort them so the N5
# group creation order and the key order of the saved F1 results are reproducible.
dataset = '20230905'
subdataset_paths = tuple(sorted((piscis_datasets_path / dataset).glob('*.npz')))

In [None]:
# Load the test split of the composite dataset selected above (`dataset`) and
# combine its images with pad_and_stack (presumably padding to a common shape --
# confirm against piscis.utils).
test_ds = load_datasets(piscis_datasets_path / dataset, adjustment='normalize', load_train=False, load_valid=False, load_test=True)['test']
images = pad_and_stack(test_ds['images'])
test_ds['images'] = images
coords = test_ds['coords']

In [None]:
# Save images as an N5 dataset.
# Layout: tmp/<dataset>.n5/<subdataset stem>/<image index>. This matches the
# --image=tmp/<dataset>.n5 and --dataset=<stem>/<index> arguments given to RS-FISH.
store = zarr.N5Store(str(tmp_path / f'{dataset}.n5'))
root = zarr.group(store=store)
# level=-1 selects zlib's default compression level.
compressor = zarr.GZip(level=-1)
for subdataset_path in subdataset_paths:

    subgroup = root.create_group(subdataset_path.stem, overwrite=True)

    test_ds = load_datasets(subdataset_path, adjustment='normalize', load_train=False, load_valid=False, load_test=True)['test']
    images = pad_and_stack(test_ds['images'])

    for i, image in enumerate(images):
        # zarr array names are storage paths, so pass the index as a string.
        subgroup.create_dataset(str(i), data=image, compressor=compressor, overwrite=True)

### Run and benchmark RS-FISH.

In [None]:
# Define the base command for running RS-FISH.
# It is assembled from three parts: the Java launcher with the RS-FISH jar on the
# classpath, JVM options, and the main class with the arguments shared by every
# run. The grid search appends --dataset, --sigma and --threshold for each run.
base_command = (
    ['java', '-cp', str(outputs_path / 'rs-fish.jar')]
    + [
        '-Xmx20G',
        # NOTE(review): '/path/to/log4j.properties' looks like an unfilled
        # placeholder -- confirm the intended log4j configuration file.
        '-Dspark.driver.extraJavaOptions=-Dlog4j.configuration=file:/path/to/log4j.properties',
        '-Dspark.master=local[8]',
    ]
    + [
        'net.preibisch.rsfish.spark.SparkRSFISH',
        f"--image={tmp_path / f'{dataset}.n5'}",
        # Intensity range presumably matches the 'normalize' adjustment used when
        # exporting the images -- confirm.
        '--minIntensity=0',
        '--maxIntensity=1',
        '--anisotropy=1.0',
        f"--output={tmp_path / 'output.csv'}",
    ]
)

In [None]:
# Define the search grid for the RS-FISH --sigma and --threshold parameters
# (both endpoints included).
sigmas = np.linspace(start=1.0, stop=3.0, num=5)          # 1.0, 1.5, ..., 3.0
thresholds = np.linspace(start=0.002, stop=0.04, num=20)  # 0.002, 0.004, ..., 0.04

In [None]:
# Run grid search.
# For every test image of every subdataset, run RS-FISH once per (sigma, threshold)
# pair and score the detections against the ground-truth coordinates with F1.
# Results are nested as f1s[subdataset stem][image index][sigma index][threshold index].
# They are re-dumped to JSON after every run, so partial progress survives an
# interruption; json converts the integer keys to strings.
output_csv_path = tmp_path / 'output.csv'  # must match --output in base_command
output_log_path = tmp_path / 'output_log.txt'
distance_thresholds = np.linspace(0, 3, 50)
f1s = {}
for subdataset_path in subdataset_paths:

    coords = load_datasets(subdataset_path, adjustment='normalize', load_train=False, load_valid=False, load_test=True)['test']['coords']

    subdataset_f1s = {}
    f1s[subdataset_path.stem] = subdataset_f1s

    for i, c in enumerate(coords):
        subdataset_f1s.setdefault(i, {})
        for j, sigma in enumerate(sigmas):
            subdataset_f1s[i].setdefault(j, {})
            for k, threshold in enumerate(thresholds):
                command = base_command + [f'--dataset={subdataset_path.stem}/{i}', f'--sigma={sigma}', f'--threshold={threshold}']
                # Remove the previous run's CSV so a failed run can never be scored
                # with stale detections from an earlier parameter combination.
                output_csv_path.unlink(missing_ok=True)
                with open(output_log_path, 'w') as output_log:
                    result = subprocess.run(command, stdout=output_log, stderr=subprocess.STDOUT)
                if result.returncode != 0:
                    raise RuntimeError(
                        f'RS-FISH exited with code {result.returncode} for '
                        f'{subdataset_path.stem}/{i} (sigma={sigma}, threshold={threshold}); '
                        f'see {output_log_path}'
                    )
                try:
                    csv = pd.read_csv(output_csv_path)
                except pd.errors.EmptyDataError:
                    # A completely empty output file is scored as no detections.
                    f1 = 0
                else:
                    # Stack as (y, x), presumably to match the (row, col) order of
                    # the ground-truth coords -- confirm against piscis.
                    c_pred = np.stack((csv['y'], csv['x']), axis=-1)
                    if (c_pred.size > 0) and (c.size > 0):
                        f1 = compute_metrics(c_pred, c, evaluation_metrics='f1', distance_thresholds=distance_thresholds)['f1']
                    else:
                        f1 = 0
                subdataset_f1s[i][j][k] = f1
                with open(benchmarks_path / 'rs_fish_f1s.json', 'w') as f:
                    json.dump(f1s, f)