In [None]:
# install plotting libraries
!pip install pandas seaborn

In [2]:
# import libraries
import json
import glob
import concurrent.futures

import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import matplotlib

%matplotlib inline

In [3]:
def load_file(filepath):
    """Load and decode a JSON file, returning None if it cannot be read.

    Parameters
    ----------
    filepath : str or path-like
        Path of the JSON file to load.

    Returns
    -------
    object or None
        The decoded JSON document, or None (after printing a warning) when the
        file cannot be opened or does not contain valid JSON.
    """
    try:
        with open(filepath, 'r') as f:
            return json.load(f)
    except (OSError, ValueError) as err:
        # OSError: missing / unreadable file. ValueError covers
        # json.JSONDecodeError and UnicodeDecodeError. Anything else (e.g. a
        # wrong argument type or KeyboardInterrupt) is a real error and propagates.
        print("bad filepath: ", filepath, f"({err})")
        return None

In [4]:
def split_dataset_plot(df, metric, datasets=None, sparsities=None, volume=None, network=None):
    """Plot `metric` against sparsity level, with one subplot column per dataset.

    Each point is one row of `df`. The metric value is on the x axis and the
    row's sparsity level is on the y axis. Levels are sorted in descending order
    and placed at y = 0, 1, 2, ..., so the first level in that order is at the
    bottom. Rows at the same level are spread vertically by a small
    deterministic offset so they do not overlap.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have 'dataset' and 'sparsity' columns, plus 'volume' / 'network'
        when those filters are used. `metric` is either a plain column, or a
        column group with 'mean' and 'std' sub-columns (as produced by
        ``groupby(...).agg({metric: ['mean', 'std']})``). In the latter case the
        mean is plotted and the std is drawn as horizontal error bars.
    metric : str
        Column to plot on the x axis; also used as the x-axis label.
    datasets, sparsities, volume, network : list-like, optional
        When given, keep only rows whose 'dataset' / 'sparsity' / 'volume' /
        'network' value is in the list.

    Returns
    -------
    (matplotlib.figure.Figure, list of matplotlib.axes.Axes)
        The figure and one axes per plotted dataset, in plotting order.
    """
    dataset_color_map = {
        'cremi_c': '#C3186C',
        'cremi_b': '#FE6100',
        'cremi_a': '#FFC107',
        'voljo': '#5E94F3',
        'fib25': '#785EF0',
        'epi': '#0FCEAD',
    }
    fallback_color = 'gray'  # for datasets missing from the map above

    # filter df
    plot_df = df.copy()
    for column, allowed in (('dataset', datasets), ('sparsity', sparsities),
                            ('volume', volume), ('network', network)):
        if allowed is not None:
            plot_df = plot_df[plot_df[column].isin(allowed)]

    # one sub plot per dataset, plus a narrow trailing column for the legend
    plotted_datasets = plot_df['dataset'].unique()
    sparsity_levels = sorted(plot_df['sparsity'].unique(), reverse=True)
    sparsity_to_pos = {str(s): i for i, s in enumerate(sparsity_levels)}
    n_datasets = len(plotted_datasets)
    fig = plt.figure(figsize=(1.5 * n_datasets + 2, 4), dpi=300)
    gs = fig.add_gridspec(1, n_datasets + 1, width_ratios=[1] * n_datasets + [0.2])
    axes = []

    for idx, dataset in enumerate(plotted_datasets):
        ax = fig.add_subplot(gs[0, idx])
        axes.append(ax)
        color = dataset_color_map.get(dataset, fallback_color)

        dataset_df = plot_df[plot_df['dataset'] == dataset].sort_values(by='sparsity')

        # y = index of the row's sparsity level, plus a deterministic offset
        # cycling through -0.15, 0, +0.15 so rows at the same level don't overlap.
        # Iterating the column (not iterrows) yields plain values even when the
        # frame has MultiIndex columns from an .agg() reduction.
        y_positions = [
            sparsity_to_pos[str(s)] + (i % 3 - 1) * 0.15
            for i, s in enumerate(dataset_df['sparsity'])
        ]

        values = dataset_df[metric]
        if isinstance(values, pd.DataFrame):
            # aggregated metric: plot mean with +-std error bars. The x limits
            # must cover the full extent of every bar; a NaN std (e.g. from a
            # single-sample group) counts as 0 for the limits.
            x_values = values['mean']
            x_errors = values['std']
            std_for_limits = x_errors.fillna(0)
            x_min = (x_values - std_for_limits).min()
            x_max = (x_values + std_for_limits).max()
        else:
            x_values = values
            x_errors = None
            x_min = values.min()
            x_max = values.max()

        ax.scatter(x_values, y_positions, color=color, label=dataset,
                   s=50, edgecolor='white', linewidth=0.5)
        if x_errors is not None:
            ax.errorbar(x_values, y_positions, xerr=x_errors, fmt='none',
                        color=color, alpha=0.5)

        ax.tick_params(axis='x', rotation=45)
        ax.grid(True, linestyle='--', linewidth=0.5, color='lightgray', axis='x')
        ax.set_axisbelow(True)
        ax.spines['top'].set_visible(True)
        ax.spines['right'].set_visible(True)

        # all subplots share the same y ticks; only the first one is labelled
        ax.set_yticks(range(len(sparsity_levels)))
        if idx == 0:
            ax.set_ylabel('Amount of human annotation')
            ax.set_yticklabels(list(sparsity_levels))
        else:
            ax.set_yticklabels([''] * len(sparsity_levels))
        ax.set_xlabel('')

        # pad the x limits by 20% of the data span. When all values coincide,
        # use a fixed pad instead of identical limits. When the limits are not
        # finite (all-NaN or infinite values), leave matplotlib's autoscaling alone.
        if np.isfinite(x_min) and np.isfinite(x_max):
            span = x_max - x_min
            x_margin = span * 0.2 if span > 0 else (abs(x_max) * 0.1 or 1.0)
            ax.set_xlim([x_min - x_margin, x_max + x_margin])

    # shared x label under the subplot columns
    fig.text(0.45, 0.02, metric, ha='center', va='center')
    legend_ax = fig.add_subplot(gs[0, -1])
    legend_ax.axis('off')
    # the legend lists every dataset in the color map, not only the plotted ones
    handles = [legend_ax.scatter([], [], color=c, label=name)
               for name, c in dataset_color_map.items()]
    legend_ax.legend(handles=handles, loc='center left')
    fig.tight_layout(rect=[0, 0.05, 0.95, 1])
    return fig, axes


In [5]:
# load result jsons recursively. Paths are sorted so that the row order of the
# dataframes built below (and idxmin tie-breaking) is the same on every run.
gt_paths = sorted(glob.glob('**/results_gt*.json', recursive=True))
pred_paths = sorted(glob.glob('**/results_pred*.json', recursive=True))

with concurrent.futures.ThreadPoolExecutor() as executor:
    # load_file returns None (and prints a warning) for unreadable files; drop
    # those so the flattening cells don't fail on `None.values()`
    gt_results = [res for res in executor.map(load_file, gt_paths) if res is not None]
    pred_results = [res for res in executor.map(load_file, pred_paths) if res is not None]

In [6]:
# flatten pred results: one row per mws segmentation, with the self-evaluation
# metrics from 'error_map' / 'error_mask' stored as 'errmap_*' / 'errmask_*' columns
pred_data = []
for result_dict in pred_results:
    # each result file contains all the results for all the segs for a given pred
    for r in result_dict.values():
        # the last 7 components of seg_ds identify the segmentation:
        # <dataset>_rep_<n>/<sparsity>/<volume>.zarr/<network>/<iteration>[--from--<prev preds>]/segmentations_<method>/<seg params>
        experiment_name, round_name, volume, setup_name, iteration_str, seg_prefix, seg_name = r['seg_ds'].split('/')[-7:]
        seg_method = seg_prefix.split('segmentations_')[1]
        if seg_method == 'ws':
            continue  # include only mws results

        # split on the first '--from--' only, so prev preds that themselves
        # contain '--from--' (or the iteration number) are kept whole
        iteration, sep, prev_preds = iteration_str.partition('--from--')
        result = {
            'dataset': experiment_name.split('_rep_')[0],
            'rep': int(experiment_name.split('_rep_')[1]),
            'sparsity': round_name,
            'volume': volume.split('.zarr')[0],
            'network': setup_name,
            'iteration': int(iteration),
            'prev_preds': prev_preds if sep else None,
            'seg_method': seg_method,
            'seg_name': seg_name,
        }

        for metric, value in r['error_map'].items():
            result['errmap_' + metric] = value
        for metric, value in r['error_mask'].items():
            result['errmask_' + metric] = value

        # mws seg params: seg_name is '--'-joined and ends in <bias>--<rm_debris>
        if seg_method == 'mws':
            seg_params = seg_name.split('--')
            bias, rm_debris = seg_params[-2:]
            if len(seg_params) == 3:
                # one leading param: a sigma if it contains '_', otherwise a noise eps
                if '_' in seg_params[0]:
                    sigma, noise_eps = seg_params[0], 0
                else:
                    sigma, noise_eps = "None", seg_params[0]
            else:
                # assumes <noise_eps>--<sigma>--...--<bias>--<rm_debris> — TODO confirm
                noise_eps, sigma = seg_params[0], seg_params[1]

            result['bias'] = bias
            # short bias: text after the last 'b' of the first '_' field; long bias: last '_' field
            result['short_bias'] = float(bias.split('_')[0].split('b')[-1])
            result['long_bias'] = float(bias.split('_')[-1])
            result['rm_debris'] = rm_debris
            result['noise_eps'] = float(noise_eps)
            result['sigma'] = sigma

        pred_data.append(result)

In [7]:
# flatten gt results: one row per mws segmentation, with the gt-based voi and
# skel metrics as columns plus a few derived sums
gt_data = []
for result_dict in gt_results:
    # each result file contains all the results for all the segs for a given pred
    for r in result_dict.values():
        # the last 7 components of seg_ds identify the segmentation:
        # <dataset>_rep_<n>/<sparsity>/<volume>.zarr/<network>/<iteration>[--from--<prev preds>]/segmentations_<method>/<seg params>
        experiment_name, round_name, volume, setup_name, iteration_str, seg_prefix, seg_name = r['seg_ds'].split('/')[-7:]
        seg_method = seg_prefix.split('segmentations_')[1]
        if seg_method == 'ws':
            continue  # include only mws

        # split on the first '--from--' only, so prev preds that themselves
        # contain '--from--' (or the iteration number) are kept whole
        iteration, sep, prev_preds = iteration_str.partition('--from--')
        result = {
            'dataset': experiment_name.split('_rep_')[0],
            'rep': int(experiment_name.split('_rep_')[1]),
            'sparsity': round_name,
            'volume': volume.split('.zarr')[0],
            'network': setup_name,
            'iteration': int(iteration),
            'prev_preds': prev_preds if sep else None,
            'seg_method': seg_method,
            'seg_name': seg_name,
        }
        # voi and skel metrics become columns (skel wins on duplicate names)
        result.update(r['metrics']['voi'])
        result.update(r['metrics']['skel'])

        # more metrics
        result['nvi_sum'] = result['nvi_split'] + result['nvi_merge']
        result['voi_sum'] = result['voi_split'] + result['voi_merge']
        result['total_skel_errors'] = result['n_splits'] + result['n_mergers']
        # skeleton errors per 1000 units of skeleton path length; NaN instead of
        # ZeroDivisionError when the path length is 0
        path_length = result['total_path_length']
        result['errors_per_skel_length'] = (
            1000 * result['total_skel_errors'] / path_length if path_length else float('nan')
        )

        # mws seg params: seg_name is '--'-joined and ends in <bias>--<rm_debris>
        if seg_method == 'mws':
            seg_params = seg_name.split('--')
            bias, rm_debris = seg_params[-2:]
            if len(seg_params) == 3:
                # one leading param: a sigma if it contains '_', otherwise a noise eps
                if '_' in seg_params[0]:
                    sigma, noise_eps = seg_params[0], 0
                else:
                    sigma, noise_eps = "None", seg_params[0]
            else:
                # assumes <noise_eps>--<sigma>--...--<bias>--<rm_debris> — TODO confirm
                noise_eps, sigma = seg_params[0], seg_params[1]

            result['bias'] = bias
            # short bias: text after the last 'b' of the first '_' field; long bias: last '_' field
            result['short_bias'] = float(bias.split('_')[0].split('b')[-1])
            result['long_bias'] = float(bias.split('_')[-1])
            result['rm_debris'] = rm_debris
            result['noise_eps'] = float(noise_eps)
            result['sigma'] = sigma

        gt_data.append(result)

In [8]:
# Build one dataframe each for the gt and pred results, then outer-merge them on
# every column that identifies a segmentation (experiment metadata + mws seg
# params). The result is one frame for all datasets, sparsities, reps and
# volumes with both gt and pred metrics. A segmentation present in only one of
# the two frames is kept, with NaN in the columns that come from the other.
merge_keys = [
    'dataset', 'sparsity', 'rep', 'network', 'volume',
    'iteration', 'prev_preds', 'seg_method', 'seg_name',
    'bias', 'short_bias', 'long_bias', 'rm_debris', 'noise_eps', 'sigma',
]
gt_df = pd.DataFrame(gt_data)
pred_df = pd.DataFrame(pred_data)
all_results = gt_df.merge(pred_df, how='outer', on=merge_keys)
# rows missing a value on either side: all_results[all_results.isna().any(axis=1)]

In [9]:
# reduce merged df: one row per combination of the `group_by` columns
group_by = ['dataset', 'rep', 'sparsity', 'volume', 'network']
# group_by += ['iteration', 'prev_preds']  # optionally keep iterations / source preds apart

metric_for_best_gt = 'nvi_sum'         # gt selection: row with the smallest value wins
metric_for_best_self = 'errmask_mean'  # self-eval selection: row with the smallest value wins
metrics_to_avg = ['rand_split', 'rand_merge', 'voi_split', 'voi_merge', 'nvi_split', 'nvi_merge', 'nid', 'erl', 'nerl', 'n_mergers', 'n_splits', 'n_non0_mergers', 'nvi_sum', 'voi_sum', 'total_skel_errors', 'errors_per_skel_length', 'errmap_mean', 'errmap_std', 'errmask_mean', 'errmask_std', 'errmask_num_nonzero_voxels', 'errmap_num_nonzero_voxels', 'errmask_nonzero_ratio', 'errmap_nonzero_ratio']


def best_per_group(results, keys, metric):
    """Return, for each group of `keys`, the row of `results` with the smallest `metric`.

    Rows whose `metric` is NaN are ignored. The outer merge leaves NaN in the gt
    columns of pred-only rows and vice versa. A group with no non-NaN value is
    therefore omitted, instead of making ``idxmin`` yield a NaN label (or raise,
    in newer pandas) that ``.loc`` cannot look up. The frame is
    ``reset_index()``-ed, so each row's original label is kept in an 'index' column.
    """
    valid = results.dropna(subset=[metric])
    best_labels = valid.groupby(keys)[metric].idxmin()
    return results.loc[best_labels].reset_index()


# 'best' reduction: pick best segmentation result for each combination of columns in `group_by`
best_results_gt = best_per_group(all_results, group_by, metric_for_best_gt)
best_results_self = best_per_group(all_results, group_by, metric_for_best_self)

# 'avg' reduction: average over all segmentation results for each combination of columns in `group_by`, with mean and std dev
avg_results = all_results.groupby(group_by).agg({col: ['mean', 'std'] for col in metrics_to_avg}).reset_index()

In [None]:
# best self-evaluated segmentation per group, scored by its error-mask mean
split_dataset_plot(df=best_results_self, metric='errmask_mean');

In [None]:
# error-map mean averaged over all segmentations per group (mean +- std)
split_dataset_plot(df=avg_results, metric='errmap_mean');

In [12]:
# Relative change of every float metric between the gt-selected and the
# self-selected best segmentation, (gt - self) / gt, computed per `group_by` group.
gt_by_group = best_results_gt.set_index(group_by)
self_by_group = best_results_self.set_index(group_by)

meta_columns = gt_by_group.select_dtypes(exclude=['float']).columns
gt_float = gt_by_group.select_dtypes(include=['float'])
# match rows by group key rather than by position: the two selections can
# contain different groups (NaN metrics are skipped when picking the best row)
self_float = self_by_group[gt_float.columns].reindex(gt_float.index)
# a zero gt value has no defined relative change; use NaN instead of +-inf
rel_diff = (gt_float - self_float) / gt_float.where(gt_float != 0)

diff_df = pd.concat([gt_by_group[meta_columns], rel_diff], axis=1).reset_index()

In [None]:
# relative change (gt - self) / gt of errors per skeleton length between the two selections
split_dataset_plot(df=diff_df, metric='errors_per_skel_length');

In [14]:
# export best segmentations: {row index -> row as dict} for each selection criterion
best_segs_gt_eval = best_results_gt.to_dict(orient='index')
best_segs_self_eval = best_results_self.to_dict(orient='index')

# save to json
for out_path, records in (('round_1_best_segs_gt_eval.json', best_segs_gt_eval),
                          ('round_1_best_segs_self_eval.json', best_segs_self_eval)):
    with open(out_path, 'w') as f:
        json.dump(records, f, indent=4)