In [None]:
import json
import matplotlib.pyplot as plt
import numpy as np
import os
import glob

from multiprocessing import Pool

%matplotlib inline

In [None]:
# All result files under grid_vals/grid*/. sorted() makes the file order (and
# therefore files[-1] used below) reproducible; raw glob order depends on the
# filesystem.
files = sorted(glob.glob("/scratch/04101/vvenu/SPARSE_TEST/voljo/grid_vals/grid*/*"))

In [None]:
# Number of result files matched by the glob.
len(files)

In [None]:
def read_json(file):
    """Load and return the JSON document stored at path ``file``.

    Returns None when the file's contents cannot be decoded as JSON, so a
    handful of unparseable files do not abort a bulk load. Errors opening the
    file (e.g. a missing path) still propagate to the caller.
    """
    with open(file, "r") as f:
        try:
            return json.load(f)
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError both subclass
            # ValueError; anything else (KeyboardInterrupt, MemoryError, ...)
            # is a real problem and should not be silently turned into None.
            return None

In [None]:
def read_all(files, report_every=100000, chunksize=1):
    """Parse every file in ``files`` in parallel with ``read_json``.

    Args:
        files: iterable of file paths.
        report_every: print the 0-based running count every this many
            completed files; 0 or None disables progress output.
        chunksize: number of paths handed to a worker per batch; values > 1
            reduce inter-process overhead when there are very many files.

    Returns:
        list of parsed results (None for files read_json could not decode),
        in completion order -- NOT the order of ``files`` (imap_unordered).
    """
    results = []

    with Pool() as pool:
        parsed = pool.imap_unordered(read_json, files, chunksize=chunksize)
        for i, result in enumerate(parsed):
            # Guard the modulo so report_every=0 disables output instead of
            # raising ZeroDivisionError.
            if report_every and i % report_every == 0:
                print(i)
            results.append(result)

    return results

In [None]:
# Parse a single file (the last one in `files`) to look at one result up front.
tets = read_json(files[-1])

In [None]:
# Display that single parsed result.
tets

In [None]:
# Parse all files in parallel. Entries are None for files read_json could not
# decode, and the list is in completion order, not the order of `files`.
results = read_all(files)

In [None]:
# Convert parsed results to plottable values. Files that failed to parse
# (read_json returns None for those) are dropped first; then None/False -> 0,
# True -> 1, lists -> str(list), recursing into nested dicts.
def to_plottable(value):
    """Return a plottable version of one parsed JSON value.

    None and False become 0, True becomes 1, lists become their str() form
    (str is hashable, unlike list), dicts are converted value-by-value, and
    anything else is returned unchanged.
    """
    # Identity checks: `== False` / `== True` would also match 0/1 and 0.0/1.0.
    if value is None or value is False:
        return 0
    if value is True:
        return 1
    if isinstance(value, list):
        return str(value)
    if isinstance(value, dict):
        return {key: to_plottable(val) for key, val in value.items()}
    return value


results = [to_plottable(run) for run in results if isinstance(run, dict)]

In [None]:
# Keep runs whose 'best' score is strictly lower than the 'frags' score on
# every one of voi_sum, nvi_sum and nid.
best_results = [
    run
    for run in results
    if all([run['best'][name] < run['frags'][name] for name in ['voi_sum', 'nvi_sum', 'nid']])
]

In [None]:
# How many runs have 'best' strictly below 'frags' on voi_sum, nvi_sum and nid.
len(best_results)

In [None]:
# Use every parsed run, overriding the best-vs-frags filter of the cell above.
# Each run dict is shallow-copied because the following cells add/delete keys
# in place; without the copy they would also strip 'best'/'frags' from
# `results`, and re-running from here would fail with KeyError.
best_results = [dict(run) for run in results]

In [None]:
# Hoist every entry of each run's 'best' sub-dict to the top level, then
# remove the now-redundant 'best' key.
for run in best_results:
    best_scores = run['best']
    run.update(best_scores)
    del run['best']

In [None]:
# Discard the 'frags' scores; only the hoisted 'best' values are kept.
for run in best_results:
    run.pop('frags')

In [None]:
# Inspect one flattened run ('best' keys hoisted, 'frags' removed).
best_results[0]

In [None]:
# Distribution of nvi_sum over all runs.
fig, ax = plt.subplots()
ax.hist([x['nvi_sum'] for x in best_results], bins=100)
ax.set_title("nvi_sum over all runs")
ax.set_xlabel("nvi_sum")
ax.set_ylabel("number of runs");

In [None]:
def _nvi_sum(run):
    """Sort key: the run's summed NVI score."""
    return run["nvi_sum"]


# Order runs by ascending nvi_sum.
best_results = sorted(best_results, key=_nvi_sum)

In [None]:
# Run with the lowest nvi_sum (best_results is now sorted ascending by nvi_sum).
best_results[0]

In [None]:
# Metrics that are reported as [mean, std] across ROIs when aggregating.
to_avg = [
    'rand_split', 'rand_merge',
    'voi_split', 'voi_merge',
    'nvi_split', 'nvi_merge',
    'nid',
    'merge_threshold',
    'voi_sum', 'nvi_sum',
]

In [None]:
# Fields that identify a run's configuration. When aggregating across ROIs
# these are copied rather than averaged, and they form the key used to match
# runs between ROIs.
# NOTE: 'roi' must stay FIRST -- later cells use not_to_avg[1:] to mean
# "every identifying field except roi".
not_to_avg = [
    'roi',
    'raw_file',
    'labels_dataset',
    'labels_mask',
    'pred_file',
    'pred_dataset',
    'downsampling',
    'denoising',
    'normalize_preds',
    'background_mask',
    'min_seed_distance',
    'merge_function',
    'pred_iteration',
    'affs_iteration',
    'sigma',
    'gb',
    'gt_type',
    'EA',
    'lite',
    'LR',
]

In [None]:
def derive_run_fields(run):
    """Add grouping fields parsed from a run's 'pred_dataset' / 'pred_file'.

    Mutates ``run`` in place, setting pred_iteration, affs_iteration, sigma,
    gb, gt_type, EA, lite and LR.
    """
    # assumes pred_dataset is '_'-separated with integer iterations at
    # positions -1 and -5 -- TODO confirm the dataset naming scheme
    dataset_parts = run['pred_dataset'].split('_')
    run['pred_iteration'] = int(dataset_parts[-1])
    run['affs_iteration'] = int(dataset_parts[-5])

    pred_file = run['pred_file']
    run['sigma'] = pred_file.split('/')[-4]

    # Later cells read run['gb'] for every run, so always define it: None when
    # no '<n>gb' tag is present. The last matching tag wins, as before.
    run['gb'] = None
    for n in (0, 1, 2):
        if f'{n}gb' in pred_file:
            run['gb'] = n

    if 'arlo' in pred_file:
        run['gt_type'] = "arlo"
    elif 'jan' in pred_file:
        run['gt_type'] = "jan"
    else:
        run['gt_type'] = "voronoi"

    run["EA"] = "no" if "noEA" in pred_file else "yes"
    run["lite"] = "yes" if "lite" in pred_file else "no"
    run["LR"] = "yes" if "LR" in pred_file else "no"


for run in best_results:
    derive_run_fields(run)

In [None]:
# Group runs by ROI. Sorting makes the ROI order -- and therefore the
# "roi {i}" panel labels and the reference ROI used when averaging -- stable
# across kernel restarts; set iteration order of strings is not (hash
# randomization). key=str keeps sorting safe for mixed value types.
all_rois = sorted(set(x['roi'] for x in best_results), key=str)
print(all_rois)

# Results by ROI, built in one pass instead of rescanning best_results per ROI.
results_by_roi = {roi: [] for roi in all_rois}
for run in best_results:
    results_by_roi[run['roi']].append(run)
print(list(enumerate(len(runs) for runs in results_by_roi.values())))

In [None]:
#find intersection of all arg combos among all rois
intersection = []
for roi in results_by_roi:
    intersection.append(
        [''.join([str(result[x]) for x in not_to_avg[1:]]) for result in results_by_roi[roi]]
    )
    
intersection = set.intersection(*map(set,intersection))

print(len(intersection))

In [None]:
#filter results_by_roi using intersection
results_by_roi = {k:[x for x in results_by_roi[k] if ''.join([str(x[j]) for j in not_to_avg[1:]]) in intersection] for k in all_rois}
print(list(zip(range(len(results_by_roi)),[len(x) for x in results_by_roi.values()])))

In [None]:
# Histogram of one metric for each ROI; one panel per ROI, titled by the
# ROI's position in rois_to_plot.
metric = 'nvi_sum'
rois_to_plot = all_rois

fig, axes = plt.subplots(
    1, len(rois_to_plot),
    # at least 12 inches wide; 4 inches per panel beyond three ROIs
    figsize=(max(12, 4 * len(rois_to_plot)), 4),
    sharex=False, sharey=False, squeeze=False,
)

for i, roi in enumerate(rois_to_plot):
    ax = axes[0][i]
    ax.hist([x[metric] for x in results_by_roi[roi]], bins=100)
    ax.set_title(f"roi {i}, {metric}")
    ax.set_xlabel(metric)
    ax.set_ylabel("number of runs")

plt.tight_layout()

In [None]:
# Average across ROIs. For every parameter combination present in every ROI,
# store its identifying fields (not_to_avg[1:], i.e. all but 'roi') plus
# [mean, std] across ROIs of each metric in to_avg.
#
# Runs are matched across ROIs by parameter combination, NOT by list
# position: best_results was sorted by nvi_sum before grouping, so the i-th
# run of one ROI is generally a different combination from the i-th run of
# another, and positional pairing would average unrelated runs.
group_fields = not_to_avg[1:]


def _combination_key(run):
    """Hashable key identifying a run's parameter combination (roi excluded)."""
    return tuple(str(run[field]) for field in group_fields)


# roi -> {combination key -> run}; the first run wins if a key repeats.
runs_by_combination = {}
for roi in all_rois:
    lookup = {}
    for run in results_by_roi[roi]:
        lookup.setdefault(_combination_key(run), run)
    runs_by_combination[roi] = lookup

results_roi_avg = []
for combination, reference_run in runs_by_combination[all_rois[0]].items():
    parts = [runs_by_combination[roi].get(combination) for roi in all_rois]
    # Skip combinations missing from any ROI instead of mis-pairing them.
    if any(part is None for part in parts):
        continue

    averaged = {field: reference_run[field] for field in group_fields}
    for metric_name in to_avg:
        vals = [part[metric_name] for part in parts]
        averaged[metric_name] = [np.mean(vals), np.std(vals)]

    results_roi_avg.append(averaged)

In [None]:
# Order ROI-averaged combinations by nvi_sum [mean, std], ascending: mean
# first, std breaks ties (same ordering as comparing the lists directly).
results_roi_avg = sorted(results_roi_avg, key=lambda run: tuple(run["nvi_sum"]))

In [None]:
# Select combinations whose mean nvi_sum and mean nid are below fixed cut-offs.
max_mean_nvi_sum = 0.3
max_mean_nid = 0.2

results_to_plot = [
    run
    for run in results_roi_avg
    if run['nvi_sum'][0] < max_mean_nvi_sum and run['nid'][0] < max_mean_nid
]
len(results_to_plot)

In [None]:
# Last entry of the selection: the combination with the largest nvi_sum
# [mean, std] that still passed the thresholds (the list keeps the ascending order).
results_to_plot[-1]

In [None]:
# Histograms of each identifying field (plus the two headline metrics) over
# the selected combinations, two panels per row. The grid is derived from the
# list, so adding or removing a field no longer breaks the hard-coded layout.
plot_names = [
    'normalize_preds',
    'background_mask',
    'min_seed_distance',
    'merge_function',
    'pred_iteration',
    'affs_iteration',
    'sigma',
    'gb',
    'gt_type',
    'EA',
    'lite',
    'LR',
    'nvi_sum',
    'nid',
]
n_cols = 2
n_rows = (len(plot_names) + n_cols - 1) // n_cols

fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), sharex=False, sharey=False, squeeze=False)

for plot_num, plot_name in enumerate(plot_names):
    ax = axes[plot_num // n_cols][plot_num % n_cols]

    if plot_name in to_avg:
        # Averaged metrics are stored as [mean, std]; plot the mean.
        data = [x[plot_name][0] for x in results_to_plot]
        nbins = 100
    else:
        data = [x[plot_name] for x in results_to_plot]
        nbins = 20

    ax.hist(data, bins=nbins)
    ax.set_title(plot_name)
    ax.tick_params(axis='x', rotation=40)

# Hide any unused trailing panel (odd number of plots).
for ax in axes.ravel()[len(plot_names):]:
    ax.set_visible(False)

plt.tight_layout()