In [48]:
def extract_points_from_json(json_file, voxel_size):
    """
    Load picked points from a JSON file and convert them to voxel coordinates.

    The file must hold a mapping whose "points" entry is a list of items, each
    carrying a "location" mapping with "x", "y" and "z" keys. Every coordinate
    is divided by the voxel size of its axis, and the axes are emitted in
    (z, y, x) order (the order Napari expects).

    Parameters:
        json_file (str): Path to the JSON file.
        voxel_size (tuple): Voxel size per axis as (z, y, x), in the same units as the JSON coordinates.

    Returns:
        np.ndarray or None: Float array of shape (N, 3) with columns (z, y, x).
        An empty "points" list yields an array of shape (0, 3), so the result
        is always 2-D. None is returned when the file has no "points" key.
    """
    with open(json_file, 'r') as f:
        data = json.load(f)
    if "points" not in data:
        return None

    size_z, size_y, size_x = voxel_size[0], voxel_size[1], voxel_size[2]
    # Convert [X, Y, Z] → [Z, Y, X] for Napari and scale by voxel size
    points = [
        [
            pt["location"]["z"] / size_z,
            pt["location"]["y"] / size_y,
            pt["location"]["x"] / size_x,
        ]
        for pt in data["points"]
    ]
    # reshape(-1, 3) keeps the array 2-D even when no points were listed;
    # np.array([]) alone would be 1-D with shape (0,).
    return np.array(points, dtype=float).reshape(-1, 3)

In [49]:
import os
import json
import numpy as np
from glob import glob
from tqdm import tqdm
gt_json_dir_path="/Users/yusufberkoruc/Desktop/Master_thesis/czii-cryo-et-object-identification/train/overlay/ExperimentRuns/TS_86_3/Picks"  # Ground truth JSON directory
prediction_json_path="/Users/yusufberkoruc/Desktop/Master_thesis/3d-adapted-model_6/TS_86_3_protein_detections_peak_local_max.json"  # Prediction JSON file

# --- Predictions from JSON ---
# prediction_points stays None unless the file exists AND holds a top-level list;
# every other outcome is reported instead of being silently skipped.
prediction_points = None
if prediction_json_path and os.path.exists(prediction_json_path):
    with open(prediction_json_path, 'r') as f:
        prediction_data = json.load(f)
    if isinstance(prediction_data, list):  # Handle list of lists structure
        prediction_points = np.array(prediction_data)  # Directly convert to NumPy array
        # Only announce success once the points were actually converted.
        print(f"✅ Loaded predictions from {prediction_json_path}")
    else:
        print(f"⚠️ Unsupported prediction JSON structure ({type(prediction_data).__name__}) in {prediction_json_path}")
else:
    print(f"⚠️ No valid prediction JSON file: {prediction_json_path}")

# --- Ground Truth ---
# Maps each JSON file's base name (without extension) to its point array.
gt_points_dict = {}
voxel_size = (10.0, 10.0, 10.0)  # Example voxel size in Ångstroms
if gt_json_dir_path and os.path.isdir(gt_json_dir_path):
    # glob() returns files in arbitrary order; sort for a reproducible load order.
    gt_json_files = sorted(glob(os.path.join(gt_json_dir_path, "*.json")))
    for jf in gt_json_files:
        label = os.path.splitext(os.path.basename(jf))[0]
        points = extract_points_from_json(jf, voxel_size)
        # Files without a "points" key (None) or without any point are left out.
        if points is not None and len(points) > 0:
            gt_points_dict[label] = points
    print(f"✅ Loaded {len(gt_points_dict)} ground truth JSON files")
else:
    print(f"⚠️ No valid ground truth JSON directory: {gt_json_dir_path}")


✅ Loaded predictions from /Users/yusufberkoruc/Desktop/Master_thesis/3d-adapted-model_6/TS_86_3_protein_detections_peak_local_max.json
✅ Loaded 6 ground truth JSON files


In [50]:
# Alphabetically ordered labels of the loaded ground-truth point sets.
particle_list = sorted(gt_points_dict)
print(f"Particle list: {particle_list}")
# Index 3 selects the fourth label in sorted order (not the last one).
particle_name = particle_list[3]
particle = gt_points_dict[particle_name]
print(f"Particle name: {particle_name}")

Particle list: ['apo-ferritin', 'beta-amylase', 'beta-galactosidase', 'ribosome', 'thyroglobulin', 'virus-like-particle']
Particle name: ribosome


In [51]:
import numpy as np
from scipy.spatial.distance import cdist
from scipy.optimize import linear_sum_assignment
from skimage.metrics import mean_squared_error
def _calc_sMAPE(n,m):
    if n == 0 and m == 0:
        return 0
    elif (n == 0 and m > 0) or (n > 0 and m == 0):
        return 1
    elif n > 0 and m > 0:
        # from https://arxiv.org/pdf/2108.01234.pdf
        return abs(n-m)/(abs(n) + abs(m))
    else: 
        return 1
    
def _calc_mae(n, m):
    return np.abs(n-m)

def _compute_pairwise_distances(gt, pred):
    pairwise_distances = cdist(gt, pred, metric="euclidean")

    if np.any(pairwise_distances):
        pairwise_distances = np.stack(pairwise_distances)
    return pairwise_distances

def _dev_percentage(n,m):
    if n > 0 and m > 0:
        return np.abs(((n - m)/n))
    elif (n > 0 and m == 0) or (n == 0 and m > 0):
        return 1
    elif n == 0 and m == 0:
        return 0
    else:
        raise Exception("Number of ground truths and/or preds are negative: gts = {n}, preds = {m}.")
        
def metric_coords(gts, preds, match_distance=45):
    """
    Match predicted points to ground-truth points and score the detection.

    NOTE(review): this is a debugging variant. Whenever at least one
    ground-truth/prediction pair lies closer than `match_distance`, the
    function prints and returns the raw assignment `(label_ind, pred_ind)`
    instead of the metrics; the true-positive counting written after that
    `return` is unreachable. All other paths return the 6-tuple below, so the
    length of the returned tuple depends on the input.

    Parameters:
        gts: Sequence of ground-truth coordinates, one point per entry.
        preds: Sequence of predicted coordinates, one point per entry.
        match_distance: A pair only counts as a match when its Euclidean
            distance is strictly smaller than this value.

    Returns:
        tuple: `(label_ind, pred_ind)` index arrays from
        `linear_sum_assignment` on the matching path; otherwise
        `(precision, recall, f1, dev_percentage, sMAPE, MAE)`.

    Raises:
        Exception: Guard for negative lengths; `len()` never returns a
            negative value, so this branch is not expected to trigger.
    """
    n = len(gts)
    m = len(preds)

    # Nothing to find and nothing predicted: scored as a perfect result.
    if n == 0 and m == 0:
        return 1, 1, 1, 0, 0, 0
    
    # Exactly one side is empty: no match is possible.
    elif (n == 0 and m > 0) or (n > 0 and m == 0):
        return 0, 0, 0, _dev_percentage(n, m), 1, _calc_mae(n, m)
    
    elif n > 0 and m > 0:
        pairwise_distances = _compute_pairwise_distances(gts, preds)
        # NOTE(review): np.any() is False only when every distance is exactly 0,
        # i.e. all points coincide; that case falls through to the all-zero
        # result in the `else` below — confirm this is intended.
        if np.any(pairwise_distances):
            # "trivial" means no pair is within the matching distance.
            trivial = not np.any(pairwise_distances < match_distance)
            if trivial:
                true_positives = 0
            else:
                # match the predicted points to labels via linear cost assignment ('hungarian matching')
                max_distance = pairwise_distances.max()
                # the costs for matching: the first term sets a low cost for all pairs that are in
                # the matching distance, the second term sets a lower cost for shorter distances,
                # so that the closest points are matched
                costs = -(pairwise_distances < match_distance).astype(float) - (max_distance - pairwise_distances) / max_distance
                # perform the matching, returns the indices of the matched coordinates
                label_ind, pred_ind = linear_sum_assignment(costs)
                print(f"label_ind: {label_ind}, pred_ind: {pred_ind}")
                # Debug early exit: returns the assignment itself (a 2-tuple).
                return label_ind, pred_ind 
                # --- unreachable from here to the end of this branch ---
                # check how many of the matches are smaller than the match distance
                # these are the true positives
                match_ok = pairwise_distances[label_ind, pred_ind] < match_distance
                true_positives = np.count_nonzero(match_ok)
                print(f"true_positives: {true_positives}")

            # Only reached on the trivial path (true_positives == 0).
            # compute false positives and false negatives
            false_positives = m - true_positives
            false_negatives = n - true_positives

            precision = true_positives / (true_positives + false_positives) if true_positives > 0 else 0
            recall = true_positives / (true_positives + false_negatives) if true_positives > 0 else 0
            # from https://www.v7labs.com/blog/f1-score-guide#:~:text=The%20F1%20score%20is%20calculated,denotes%20a%20better%20quality%20classifier.
            f1 =  (2*precision*recall)/(precision+recall) if (precision > 0 and recall > 0) else 0
            return precision, recall, f1, _dev_percentage(n, m), _calc_sMAPE(n, m), _calc_mae(n, m)  
        else:
            return 0,0,0, _dev_percentage(n, m), _calc_sMAPE(n, m), _calc_mae(n, m)
    else:
        raise Exception(f"Number of ground truths and/or predictions are negative (metric): len(gts) = {n}, len(preds) = {m}.")


In [52]:
# Report how many points each side contributes before matching them.
for heading, pts in (("Ground Truth Points:", particle), ("Prediction Points:", prediction_points)):
    print(heading)
    print(len(pts))
# Unpacks two values: this relies on metric_coords returning the raw
# (label_ind, pred_ind) assignment rather than the six metrics.
label_ind, pred_ind = metric_coords(particle, prediction_points)

Ground Truth Points:
55
Prediction Points:
289
label_ind: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52 53 54], pred_ind: [200 130  76 137  17 146  81 133  10  49 213 281 114  71  82  44  33  14
  92 206   1  45 107  89  56  62  60  41  12 259 178  90   6 128   7  42
 201 217  26 230 157  46 134  86  23  21  79 129  53 109  75  48 208  29
 263]


In [53]:
# Despite the name, this collects one Euclidean distance per matched pair.
distance_matrix = []
for i, k in zip(label_ind, pred_ind):
    print(f"Label Index: {i}, Prediction Index: {k}")
    gt = particle[i]
    print(f"Ground Truth Point: {gt}")
    pred = prediction_points[k]
    print(f"Prediction Point: {pred}")
    print(gt, pred)
    # Euclidean distance between the matched ground-truth and predicted point.
    offset = gt - pred
    d = np.sqrt(np.sum(offset ** 2))
    distance_matrix.append(d)
    # 45 mirrors the default match_distance of metric_coords.
    print("Distance is less than 45" if d < 45 else "Distance is greater than 45")
    print()
    print(mean_squared_error(gt, pred))

Label Index: 0, Prediction Index: 200
Ground Truth Point: [ 94.5084 590.7258 466.2299]
Prediction Point: [ 73 570 443]
[ 94.5084 590.7258 466.2299] [ 73 570 443]
Distance is less than 45

477.2661034033322
Label Index: 1, Prediction Index: 130
Ground Truth Point: [110.2158 575.8297 450.4116]
Prediction Point: [111 576 447]
[110.2158 575.8297 450.4116] [111 576 447]
Distance is less than 45

4.094328763333388
Label Index: 2, Prediction Index: 76
Ground Truth Point: [ 99.9208 620.7162 446.2021]
Prediction Point: [ 92 590 464]
[ 99.9208 620.7162 446.2021] [ 92 590 464]
Distance is less than 45

440.99641983000174
Label Index: 3, Prediction Index: 137
Ground Truth Point: [106.3849 615.6435 477.7299]
Prediction Point: [108 615 477]
[106.3849 615.6435 477.7299] [108 615 477]
Distance is less than 45

1.1851314233333474
Label Index: 4, Prediction Index: 17
Ground Truth Point: [ 74.7842 521.5513 466.6205]
Prediction Point: [ 77 518 464]
[ 74.7842 521.5513 466.6205] [ 77 518 464]
Distance is le

In [54]:
# Turn the collected per-pair distances into an array, then display them in
# ascending order (np.sort returns a sorted copy; the array itself is untouched).
distance_matrix = np.array(distance_matrix)
sorted_distances = np.sort(distance_matrix)
sorted_distances

array([ 1.0226585 ,  1.37914036,  1.52073049,  1.59586926,  1.78502923,
        1.82772056,  1.85343429,  1.87620454,  1.88557532,  1.90794718,
        2.25085837,  2.59970542,  2.61979553,  2.68637541,  2.78837997,
        2.8448673 ,  2.90434041,  2.92857564,  3.05187988,  3.12898924,
        3.13873217,  3.13900897,  3.43727343,  3.50470916,  3.60212835,
        3.63905885,  3.86987765,  3.98117208,  4.08129889,  4.21778869,
        4.29408014,  4.34558979,  4.48251866,  4.66479299,  4.73340401,
        4.80278214,  4.93847361,  4.96063653,  5.28608155,  5.29159255,
        5.41918951,  5.55471779,  5.66850505,  5.77603095,  5.78172984,
        6.32970154,  6.42659065,  9.48587102, 15.68336523, 16.47198829,
       22.78069956, 24.74800611, 36.37291931, 37.83911085, 40.08873559])

In [55]:
import numpy as np
from scipy.spatial.distance import cdist

def metric_coords(gts, preds, match_distance=45):
    """
    Score predicted points against ground-truth points using a one-to-one
    linear-sum ('Hungarian') assignment.

    Parameters:
        gts: Sequence of ground-truth coordinates, one point per entry.
        preds: Sequence of predicted coordinates, one point per entry.
        match_distance: A matched pair counts as a true positive only when
            its Euclidean distance is strictly smaller than this value.

    Returns:
        tuple: (precision, recall, f1, dev_percentage, sMAPE, MAE), where the
        last three compare the number of ground-truth points with the number
        of predicted points. Two empty inputs score (1, 1, 1, 0, 0, 0); a
        single empty input scores zero precision/recall/F1.

    Side effects:
        Prints the number of true positives whenever an assignment is solved.
    """
    n = len(gts)
    m = len(preds)

    # Nothing to find and nothing predicted: scored as a perfect result.
    if n == 0 and m == 0:
        return 1, 1, 1, 0, 0, 0

    # Exactly one side is empty (len() is never negative): no match possible.
    if n == 0 or m == 0:
        return 0, 0, 0, _dev_percentage(n, m), 1, _calc_mae(n, m)

    pairwise_distances = _compute_pairwise_distances(gts, preds)
    within_reach = pairwise_distances < match_distance

    if not np.any(within_reach):
        # No pair is close enough, so there is nothing to assign.
        true_positives = 0
    else:
        # match the predicted points to labels via linear cost assignment ('hungarian matching')
        max_distance = pairwise_distances.max()
        if max_distance > 0:
            closeness = (max_distance - pairwise_distances) / max_distance
        else:
            # Every prediction coincides with every ground-truth point. All
            # pairs are equally good, and normalising would divide by zero.
            # (Previously this case was reported as precision = recall = 0.)
            closeness = np.zeros_like(pairwise_distances, dtype=float)
        # the costs for matching: the first term sets a low cost for all pairs that are in
        # the matching distance, the second term sets a lower cost for shorter distances,
        # so that the closest points are matched
        costs = -within_reach.astype(float) - closeness
        # perform the matching, returns the indices of the matched coordinates
        label_ind, pred_ind = linear_sum_assignment(costs)
        # check how many of the matches are smaller than the match distance
        # these are the true positives
        match_ok = within_reach[label_ind, pred_ind]
        true_positives = np.count_nonzero(match_ok)
        print(f"true_positives: {true_positives}")

    # compute false positives and false negatives
    false_positives = m - true_positives
    false_negatives = n - true_positives

    precision = true_positives / (true_positives + false_positives) if true_positives > 0 else 0
    recall = true_positives / (true_positives + false_negatives) if true_positives > 0 else 0
    # from https://www.v7labs.com/blog/f1-score-guide#:~:text=The%20F1%20score%20is%20calculated,denotes%20a%20better%20quality%20classifier.
    f1 = (2*precision*recall)/(precision+recall) if (precision > 0 and recall > 0) else 0
    return precision, recall, f1, _dev_percentage(n, m), _calc_sMAPE(n, m), _calc_mae(n, m)

def metric_coords_greedy(gts, preds, match_distance=45):
    """
    Score predictions against ground truth with greedy closest-pair matching.

    The globally closest remaining (ground truth, prediction) pair is matched
    repeatedly; both points are then retired. Matching stops as soon as the
    closest remaining pair is not strictly closer than `match_distance`.

    Parameters:
        gts: Sequence of ground-truth coordinates, one point per entry.
        preds: Sequence of predicted coordinates, one point per entry.
        match_distance: Upper bound (exclusive) on the Euclidean distance of
            a valid match.

    Returns:
        tuple: (precision, recall, f1, dev_percentage, sMAPE, MAE).
    """
    n, m = len(gts), len(preds)

    if n == 0 and m == 0:
        return 1, 1, 1, 0, 0, 0
    if (n == 0 and m > 0) or (n > 0 and m == 0):
        return 0, 0, 0, _dev_percentage(n, m), 1, _calc_mae(n, m)
    if not (n > 0 and m > 0):
        raise Exception(f"Number of ground truths and/or predictions are negative (metric): len(gts) = {n}, len(preds) = {m}.")

    remaining = cdist(gts, preds, metric="euclidean")
    true_positives = 0
    attempts_left = min(n, m)  # at most one match per point on the smaller side
    while attempts_left > 0:
        attempts_left -= 1
        # Closest pair among the points that are still unmatched.
        gt_idx, pred_idx = np.unravel_index(remaining.argmin(), remaining.shape)
        if not remaining[gt_idx, pred_idx] < match_distance:
            # Nothing left within the threshold.
            break
        true_positives += 1
        # Retire both points so neither can be matched a second time.
        remaining[gt_idx, :] = np.inf
        remaining[:, pred_idx] = np.inf

    false_positives = m - true_positives
    false_negatives = n - true_positives

    if true_positives > 0:
        precision = true_positives / (true_positives + false_positives)
        recall = true_positives / (true_positives + false_negatives)
    else:
        precision = 0
        recall = 0

    if precision > 0 and recall > 0:
        f1 = (2 * precision * recall) / (precision + recall)
    else:
        f1 = 0
    return precision, recall, f1, _dev_percentage(n, m), _calc_sMAPE(n, m), _calc_mae(n, m)

In [57]:
# Score the same point sets with the Hungarian matcher first, then the greedy one.
for scorer in (metric_coords, metric_coords_greedy):
    precision, recall, f1, dev_percentage, calc_sMAPE, calc_mae = scorer(particle, prediction_points)
    print(f"Precision: {precision}, Recall: {recall}, F1: {f1}, Dev Percentage: {dev_percentage}, sMAPE: {calc_sMAPE}, MAE: {calc_mae}")


true_positives: 55
Precision: 0.1903114186851211, Recall: 1.0, F1: 0.31976744186046513, Dev Percentage: 4.254545454545455, sMAPE: 0.6802325581395349, MAE: 234
Precision: 0.18685121107266436, Recall: 0.9818181818181818, F1: 0.313953488372093, Dev Percentage: 4.254545454545455, sMAPE: 0.6802325581395349, MAE: 234
