# In [11]:
import plotly.figure_factory as ff
import plotly.express as px
from plotly.subplots import make_subplots
import numpy as np
import pandas as pd
import math

# In [33]:
# Load the rim-metrics table; the code below reads its 'stem', 'method',
# 'p95' and 'p99' columns.
all_df = pd.read_csv(
    '/Users/rushil/brain_extraction/results/quantitative/rim_comparison/rim_metrics_r1_scan.csv'
)

# Scans removed individually, matched on their exact stem.
exclude = [
    '6109-317_20150302_0647_ct',
    '6142-308_20150610_0707_ct',
    '6193-324_20150924_1431_ct',
    '6257-335_20160118_1150_ct',
    '6418-193_20161228_1248_ct',
    '6470-296_20170602_0607_ct',
    '6480-154_20170622_0937_ct',
]

# Any scan whose stem starts with one of these prefixes is removed as well.
exclude_prefixes = ("6046", "6084", "6096", "6246", "6315", "6342", "6499")

# Apply the two exclusion rules one after the other: exact stems first,
# then prefixes on whatever rows remain.
all_df = all_df.loc[~all_df["stem"].isin(exclude)]
all_df = all_df.loc[~all_df["stem"].str.startswith(exclude_prefixes)]

# Fraction of the data cut from each tail by trim() when no p is given.
TRIM = 0.02
TARGET_BINS = 20


def trim(x, p=TRIM):
    """Return the non-NaN values of *x* lying between its p and 1-p quantiles.

    *x* is converted to a float array and NaN entries are dropped before the
    quantiles are computed. Both quantile bounds are inclusive. When no
    values remain after dropping NaN, the empty array is returned as is.
    """
    values = np.asarray(x, float)
    present = values[~np.isnan(values)]
    if present.size == 0:
        return present
    lower, upper = np.quantile(present, [p, 1 - p])
    inside = (present >= lower) & (present <= upper)
    return present[inside]
def fd_bin_size(x):
    """Return a histogram bin width for *x* from the Freedman-Diaconis rule.

    The width is ``2 * IQR / n ** (1/3)``, where ``IQR`` is the distance
    between the 25th and 75th percentiles and ``n`` is the number of values.

    Args:
        x: Any array-like of numbers (ndarray, list, pandas Series, ...).
            NaN entries are ignored.

    Returns:
        The bin width as a float. Falls back to ``1.0`` when fewer than two
        usable values are available, and is never smaller than ``1e-6`` so
        that a zero IQR cannot produce a zero-width bin.
    """
    # Coerce first: a plain list has no ``.size``, and a NaN left in the data
    # would turn the percentiles (and therefore the bin width) into NaN.
    x = np.asarray(x, float)
    x = x[~np.isnan(x)]
    if x.size < 2:
        return 1.0
    q25, q75 = np.percentile(x, [25, 75])
    iqr = q75 - q25
    return max(float(2 * iqr / (x.size ** (1 / 3))), 1e-6)

# Panel order: methods sorted by ascending median p99 (NaN values ignored).
methods = sorted(
    all_df['method'].unique(),
    key=lambda m: np.nanmedian(all_df.loc[all_df['method'] == m, 'p99']),
)

# Two-column grid with one panel per method. shared_yaxes='columns' links the
# y-axes of every panel in a column, so densities stacked in the same column
# are drawn on a common scale.
rows, cols = math.ceil(len(methods) / 2), 2
fig = make_subplots(
    rows=rows,
    cols=cols,
    subplot_titles=[str(m) for m in methods],
    vertical_spacing=0.08,
    shared_yaxes='columns',
)

# Metric column -> colour used for its histogram, fitted curve and median line.
PERCENTILE_COLORS = {'p95': "#ff0909", 'p99': "#6700f8"}

# Labels that already own a legend entry; each label gets exactly one entry,
# on the first panel where it is actually drawn.
legend_shown = set()

for i, m in enumerate(methods):
    r, c = i // cols + 1, i % cols + 1

    # Trim each metric for this method and keep only the non-empty ones:
    # create_distplot and np.median cannot handle an empty series.
    is_method = all_df['method'] == m
    trimmed = {}
    for label in PERCENTILE_COLORS:
        values = trim(all_df.loc[is_method, label].values)
        if values.size:
            trimmed[label] = values
    if not trimmed:
        continue

    # One bin width for both metrics, derived from the pooled values, so the
    # overlaid histograms use the same bins.
    pooled = np.concatenate(list(trimmed.values()))
    binsz = fd_bin_size(pooled)
    f = ff.create_distplot(
        list(trimmed.values()),
        list(trimmed),
        bin_size=binsz,
        colors=[PERCENTILE_COLORS[label] for label in trimmed],
        curve_type='normal',
        show_hist=True,
        show_curve=True,
        show_rug=False,
    )
    for tr in f.data:
        if tr.type == 'scatter':
            # Fitted curve: the only trace type that appears in the legend.
            tr.line.width = 2
            tr.showlegend = tr.name not in legend_shown
            legend_shown.add(tr.name)
        else:
            # Histogram bars: faded so the two overlaid metrics stay readable.
            tr.opacity = 0.35
            tr.showlegend = False
        fig.add_trace(tr, row=r, col=c)

    # Restrict the x-axis to the central 98% of the pooled values.
    lo, hi = np.quantile(pooled, [0.01, 0.99])
    fig.update_xaxes(range=[lo, hi], row=r, col=c)

    # Dotted vertical line at the median of each plotted metric.
    for label, values in trimmed.items():
        fig.add_vline(
            x=np.median(values),
            line_width=1,
            line_dash='dot',
            line_color=PERCENTILE_COLORS[label],
            row=r,
            col=c,
        )

fig.update_layout(
    barmode='overlay',
    legend=dict(orientation='h', y=1.03, x=0, title=None),
    height=320 * rows,
    width=1000,
    margin=dict(l=60, r=20, t=60, b=40),
)
fig.update_yaxes(title_text='Density', row=1, col=1)
fig.update_xaxes(title_text='Value', row=rows, col=1)
fig.show()
