# Benchmark Piscis

In [None]:
import json
import os
from pathlib import Path

import numpy as np

# Disable XLA's up-front GPU memory preallocation. The variable is read when the
# XLA backend is initialized, so it is set before the piscis imports below in case
# importing them already initializes it (piscis presumably builds on JAX — confirm).
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

from piscis import Piscis
from piscis.core import adjust_parameters
from piscis.data import load_datasets, transform_batch, transform_subdataset
from piscis.losses import dice_loss, smoothf1_loss, weighted_bce_loss
from piscis.metrics import compute_metrics
from piscis.utils import pad_and_stack

In [None]:
# The outputs folder sits next to the current working directory.
outputs_path = Path.cwd().parent / 'outputs'

# Dataset locations: composite Piscis datasets and deepBlink datasets.
datasets_path = outputs_path / 'datasets'
piscis_datasets_path, deepblink_datasets_path = (
    datasets_path / 'piscis',
    datasets_path / 'deepblink',
)

# Benchmark results are written here; create the folder (and parents) if missing.
benchmarks_path = outputs_path / 'benchmarks'
benchmarks_path.mkdir(parents=True, exist_ok=True)

### Set the default threshold parameter.

In [None]:
# Names (file stems) of every deepBlink .npz dataset; each name is later used to
# load the matching 'deepblink_<name>' Piscis model.
deepblink_npz_files = deepblink_datasets_path.glob('*.npz')
deepblink_datasets_list = [npz_file.stem for npz_file in deepblink_npz_files]

In [None]:
# Candidate detection thresholds for the search: 0.5, 1.0, ..., 8.5 (17 values).
thresholds = 0.5 * np.arange(1, 18)

In [None]:
# Run search: for each deepBlink dataset, score every candidate threshold by the
# mean per-image F1 on that dataset's validation split.
f1s = {}
match_distances = np.linspace(0, 3, 50)  # distance_thresholds passed to compute_metrics
for deepblink_dataset in deepblink_datasets_list:

    # Load only the validation split, with adjustment disabled.
    valid_ds = load_datasets(
        deepblink_datasets_path / f'{deepblink_dataset}.npz',
        adjustment=None,
        load_train=False,
        load_valid=True,
        load_test=False,
    )['valid']
    images, coords = valid_ds['images'], valid_ds['coords']

    # Run the dataset-specific model once with threshold=9 and keep its
    # intermediates; detections for each candidate threshold are then
    # re-derived from them with adjust_parameters.
    model = Piscis(model_name=f'deepblink_{deepblink_dataset}')
    _, y = model.predict(images, threshold=9, intermediates=True)

    # The per-threshold list is stored in f1s before it is filled.
    f1s[deepblink_dataset] = dataset_f1s = []

    for threshold in thresholds:
        coords_pred = adjust_parameters(y, threshold)
        # An image scores 0 when either the prediction or the ground truth is empty.
        threshold_f1s = [
            compute_metrics(c_pred, c, distance_thresholds=match_distances, return_df=False)['f1']
            if c_pred.size > 0 and c.size > 0 else 0
            for c_pred, c in zip(coords_pred, coords)
        ]
        dataset_f1s.append(np.mean(threshold_f1s))

In [None]:
# Default threshold: the candidate with the highest F1 summed over all deepBlink
# datasets (rows of the matrix are datasets, columns are thresholds).
summed_f1s = np.array(list(f1s.values())).sum(axis=0)
default_threshold = thresholds[summed_f1s.argmax()]
default_threshold

### Run and benchmark Piscis on the Piscis dataset.

In [None]:
# Fresh, independent result containers for the Piscis-dataset benchmark.
# NOTE(review): dice_f1s is not referenced anywhere else in this notebook view.
f1s, dice_f1s = {}, {}

In [None]:
# Composite Piscis dataset; every .npz file in its folder is one subdataset.
dataset = '20230905'
subdataset_paths = [*(piscis_datasets_path / dataset).glob('*.npz')]

In [None]:
# Benchmark each Piscis model variant (one per training loss) on the test split of
# every subdataset, checkpointing the per-image F1 scores to one JSON per variant.
# Per-variant detection thresholds; the default variant (suffix '') falls back to
# default_threshold from the search above.
suffix_thresholds = {'_dice': 0.5, '_bce': 0.95, '_focal': 0.5}
for suffix in ('', '_dice', '_bce', '_focal'):

    model = Piscis(model_name=f'{dataset}{suffix}')

    # f1s[dataset] is replaced by a fresh mapping for every variant.
    dataset_f1s = f1s[dataset] = {}

    for subdataset_path in subdataset_paths:

        # Load the test split and pad/stack its images into one array.
        test_ds = load_datasets(subdataset_path, load_train=False, load_valid=False, load_test=True)['test']
        test_ds['images'] = images = pad_and_stack(test_ds['images'])
        coords = test_ds['coords']

        subdataset_f1s = dataset_f1s[subdataset_path.stem] = []

        coords_pred = model.predict(images, threshold=suffix_thresholds.get(suffix, default_threshold))

        for c_pred, c in zip(coords_pred, coords):
            # An image scores 0 when either the prediction or the ground truth is empty.
            if not (c_pred.size > 0 and c.size > 0):
                subdataset_f1s.append(0)
                continue
            subdataset_f1s.append(
                compute_metrics(c_pred, c, distance_thresholds=np.linspace(0, 3, 50), return_df=False)['f1']
            )

        # Rewrite this variant's JSON after every subdataset so partial results persist.
        with open(benchmarks_path / f'piscis{suffix}_f1s.json', 'w') as f:
            json.dump(f1s, f)

### Run and benchmark Piscis on deepBlink datasets.

In [None]:
# Benchmark each deepBlink-trained Piscis model on its deepBlink test split and merge
# the per-image F1 scores into piscis_f1s.json next to the Piscis-dataset results.
with open(benchmarks_path / 'piscis_f1s.json', 'r') as f:
    f1s = json.load(f)

# Only .npz archives are datasets (the same pattern the threshold search uses); a
# bare '*' glob would also pick up stray files or folders and try to load them.
for deepblink_dataset_path in deepblink_datasets_path.glob('*.npz'):

    deepblink_dataset = deepblink_dataset_path.stem

    # Load the test split and apply transform_subdataset with a (512, 512) shape and
    # min_spots=1 (presumably resizing/cropping and filtering spot-free images — confirm).
    test_ds = load_datasets(deepblink_dataset_path, load_train=False, load_valid=False, load_test=True)['test']
    test_ds = transform_subdataset(test_ds, (512, 512), min_spots=1)
    # NOTE(review): squeeze() removes every singleton axis, including the batch axis
    # if the test split holds a single image — confirm this cannot happen.
    images = test_ds['images'].squeeze()
    coords = test_ds['coords']

    # Load and run the model trained on this deepBlink dataset.
    model = Piscis(model_name=f'deepblink_{deepblink_dataset}')
    coords_pred = model.predict(images, threshold=default_threshold)

    deepblink_dataset_f1s = []
    f1s[deepblink_dataset] = deepblink_dataset_f1s

    for c_pred, c in zip(coords_pred, coords):
        if (c_pred.size > 0) and (c.size > 0):
            # return_df=False is explicit, as in the other benchmark cells, so ['f1']
            # does not depend on compute_metrics' default output format.
            f1 = compute_metrics(c_pred, c, distance_thresholds=np.linspace(0, 3, 50), return_df=False)['f1']
        else:
            f1 = 0
        deepblink_dataset_f1s.append(f1)

    # Checkpoint after every dataset so partial results persist.
    with open(benchmarks_path / 'piscis_f1s.json', 'w') as f:
        json.dump(f1s, f)

### Compare F1 score estimation between SmoothF1 and Dice loss functions.

In [None]:
# Load only the training split of the composite Piscis dataset (adjustment
# disabled) and pad/stack its images into one array.
train_ds = load_datasets(
    piscis_datasets_path / dataset,
    adjustment=None,
    load_train=True,
    load_valid=False,
    load_test=False,
)['train']
train_ds['images'] = images = pad_and_stack(train_ds['images'])
coords = train_ds['coords']

In [None]:
# Compute -SmoothF1 and F1 scores for every training image using the default
# Piscis model, and save both lists to piscis_train_smoothf1.json.

# Load the Piscis model.
model = Piscis(model_name=dataset)

# Run the Piscis model with the same threshold the benchmark cells use, rather than
# relying on predict()'s own default.
coords_pred, y = model.predict(images, threshold=default_threshold, intermediates=True)
coords_pad_length = max(len(c) for c in coords)
distance_thresholds = np.linspace(0, 3, 50)

f1s = []
smoothf1s = []

for i in range(len(images)):
    # Single-image batch sliced from every field of the training set.
    batch = {k: v[i:i + 1] for k, v in train_ds.items()}
    # coords_pad_length is passed by keyword (as in the Dice cell) so it cannot bind to
    # a different positional parameter of transform_batch.
    transformed_batch = transform_batch(batch, coords_pad_length=coords_pad_length)
    # Intermediate channels 0-1 are used as deltas and channel 2 as labels, each moved
    # to channels-last.
    deltas = np.moveaxis(y[i, :2].to_numpy(), 0, -1)
    labels = np.moveaxis(y[i, 2:3].to_numpy(), 0, -1)
    # NOTE(review): the trailing 1 and 3.0 are positional smoothf1_loss
    # hyperparameters — confirm their meaning against the piscis API.
    smoothf1 = smoothf1_loss(deltas, labels, transformed_batch['deltas'][0], transformed_batch['labels'][0], 1, 3.0)
    smoothf1s.append(-float(smoothf1))
    c_pred = coords_pred[i]
    c = coords[i]
    # An image scores 0 when either the prediction or the ground truth is empty.
    if (c_pred.size > 0) and (c.size > 0):
        f1 = compute_metrics(c_pred, c, distance_thresholds=distance_thresholds, return_df=False)['f1']
    else:
        f1 = 0
    f1s.append(f1)

with open(benchmarks_path / 'piscis_train_smoothf1.json', 'w') as f:
    json.dump({'f1s': f1s, 'smoothf1s': smoothf1s}, f)

In [None]:
# Compute -Dice and F1 scores for every training image using the Dice-trained
# Piscis model, and save both lists to piscis_train_dice.json.

# Load the Piscis model.
model = Piscis(model_name=f'{dataset}_dice')

# Run the Piscis model at threshold 0.5, the value used for the Dice variant in the
# benchmark cell.
coords_pred, y = model.predict(images, threshold=0.5, intermediates=True)
coords_pad_length = max(len(c) for c in coords)
distance_thresholds = np.linspace(0, 3, 50)

f1s = []
dices = []

for i in range(len(images)):
    # Single-image batch sliced from every field of the training set.
    batch = {k: v[i:i + 1] for k, v in train_ds.items()}
    # dilation_iterations=0 (presumably no label dilation for the Dice target — confirm).
    transformed_batch = transform_batch(batch, dilation_iterations=0, coords_pad_length=coords_pad_length)
    # Only the label channel (index 2, moved to channels-last) feeds the Dice loss;
    # the delta channels are not needed here.
    labels = np.moveaxis(y[i, 2:3].to_numpy(), 0, -1)
    dice = dice_loss(labels, transformed_batch['labels'][0])
    dices.append(-float(dice))
    c_pred = coords_pred[i]
    c = coords[i]
    # An image scores 0 when either the prediction or the ground truth is empty.
    if (c_pred.size > 0) and (c.size > 0):
        f1 = compute_metrics(c_pred, c, distance_thresholds=distance_thresholds, return_df=False)['f1']
    else:
        f1 = 0
    f1s.append(f1)

with open(benchmarks_path / 'piscis_train_dice.json', 'w') as f:
    json.dump({'f1s': f1s, 'dices': dices}, f)