In [None]:
import sys
from pathlib import Path

import h5py as h5
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import xarray as xr

sys.path.append('../../lib')
from local_paths import analysis_dir
from hier_group import unpack_hier_names

# Parameters

In [None]:
#============================================================================
# analysis type and result path
#============================================================================
# Group holding the RF fits inside each results file, formatted as in the
# rf_gaussian_fit script. The alternative value 'rf_fit/opt/across_splits'
# selects the across-splits fits instead.
rf_fit_group = 'rf_fit/opt/per_split'
results_subdir = 'feat_corr_map-hg-fix-rf_fit'

#============================================================================
# selection criteria
#============================================================================
# Minimum (inclusive) value each fit-quality column must reach to be selected.
rf_fit_thres = {
    'rf_at_fit_peak': 0.04,
    'goodness_of_fit': 0.7,
    'fit_coverage': 0.5,
}

# Preamble

In [None]:
# Folder holding the per-session .h5 result files of this analysis.
# Join as path components (not string concatenation) so a missing trailing
# separator in `analysis_dir`, or a Path-typed `analysis_dir`, still works.
results_dir = Path(analysis_dir, results_subdir).expanduser()
assert results_dir.is_dir(), f'results directory not found: {results_dir}'

# Columns written to the summary CSVs ('r' is derived after loading).
cols_to_save = ['x', 'y', 'r', 'goodness_of_fit', 'rf_fit_weighted_mean', 'a', 'b', 'ang_rad']

# Load data

In [None]:
# Gather RF-fit results from every session file whose rf_fit stage finished,
# as one row per (unit[, split]) with the fitted features as columns.
session_dfs = []
has_split = None  # whether results carry a 'split' dim; must agree across sessions
for fp in sorted(results_dir.glob('*.h5')):  # sorted for a reproducible row order
    with h5.File(fp, 'r') as f:
        # Explicit check rather than `assert`, so skipping unfinished sessions
        # still works when Python runs with -O (asserts stripped).
        try:
            all_done = bool(f['progress_report/rf_fit/all_done'][()])
        except KeyError:
            all_done = False
        if not all_done:
            continue
        rf_unit_names = f['rf_fit/unit_names'][()].astype(str)

    rf_fit_ds = xr.load_dataset(fp, group=rf_fit_group)
    rf_fit_data = rf_fit_ds['data'].loc[{'unit': rf_unit_names}]

    session_has_split = 'split' in rf_fit_data.dims
    if has_split is None:
        has_split = session_has_split
    else:
        assert has_split == session_has_split, \
            f'{fp.name}: presence of "split" dim differs from earlier sessions'

    # Any remaining dims must jointly hold exactly one condition; put them
    # first (in their existing order) so the data flattens to (rows, feature).
    dims_ = tuple(d for d in rf_fit_data.dims if d not in ('unit', 'feature', 'split'))
    assert np.prod([rf_fit_data.sizes[d] for d in dims_]) == 1, \
        f'{fp.name}: expected a single condition'
    new_dims = dims_ + (('split',) if has_split else ()) + ('unit', 'feature')
    rf_fit_data = rf_fit_data.transpose(*new_dims)

    # Flatten to (rows, feature); rows run split-major, unit-minor.
    data = rf_fit_data.values.reshape(-1, rf_fit_data.sizes['feature'])
    rf_df = pd.DataFrame(data=data, columns=rf_fit_data.coords['feature'].values.astype(str))
    index = unpack_hier_names(rf_unit_names)  # presumably (n_unit, 2) rows of (Level, Name) — TODO confirm
    if has_split:
        # match the split-major row order: each split label repeated once per unit
        rf_df['Split'] = np.repeat(rf_fit_data['split'].values, rf_fit_data.sizes['unit'])
        rf_df[['Level', 'Name']] = np.concatenate([index] * rf_fit_data.sizes['split'], axis=0)
    else:
        rf_df[['Level', 'Name']] = index
    rf_df['Session'] = fp.stem
    session_dfs.append(rf_df)

assert session_dfs, f'no completed rf_fit results found in {results_dir}'
df = pd.concat(session_dfs)
index_cols = ['Session', 'Level', 'Name'] + (['Split'] if has_split else [])
df = df.set_index(index_cols)
assert not df.index.has_duplicates
# r = sqrt(a * b). Plain multiplication keeps NaN when a or b is NaN;
# np.prod on a DataFrame calls DataFrame.prod, which skips NaN by default.
df['r'] = np.sqrt(df['a'] * df['b'])
print(df.shape)

# per-split results get no suffix; across-split results are tagged
output_sfx = '' if has_split else '-across_splits'

# Select from all results

In [None]:
def select_rf_fit(rf_df, rf_fit_thres=rf_fit_thres):
    """Flag rows of `rf_df` that pass every RF-fit selection criterion.

    The criteria are:
    - 'Is valid': columns x, y, a and b are all finite;
    - for each ``column: threshold`` in `rf_fit_thres`: ``rf_df[column] >= threshold``.
    A pass-rate line is printed for each criterion and for their conjunction.

    Parameters
    ----------
    rf_df : pandas.DataFrame
        Must contain columns x, y, a, b and every key of `rf_fit_thres`.
    rf_fit_thres : dict, optional
        Minimum value per column; defaults to the notebook-level thresholds.

    Returns
    -------
    numpy.ndarray of bool
        One flag per row of `rf_df`, True where all criteria pass.
    """
    def report(label, mask):
        pct = mask.mean() * 100
        print(f'criterion: {label:<20} passed: {pct:.1f}% ({mask.sum()} of {mask.size})')

    checks = {'Is valid': np.isfinite(rf_df[list('xyab')].values).all(1)}
    checks.update({col: rf_df[col] >= lo for col, lo in rf_fit_thres.items()})

    print(f'Selecting from {len(rf_df)} entries')
    for label, mask in checks.items():
        report(label, mask)

    passed_all = np.all(list(checks.values()), axis=0)
    report('All', passed_all)
    return passed_all

In [None]:
# Flag every loaded entry (all levels) against the selection criteria.
rf_df = df.assign(Selected=select_rf_fit(df))

In [None]:
# Save the selected entries; create the output folder first because
# DataFrame.to_csv does not create missing parent directories.
Path('summary').mkdir(exist_ok=True)
rf_df.loc[rf_df['Selected'], cols_to_save].to_csv(
    f'summary/rf_fit{output_sfx}.csv.gz')

# Summarize array-level results

In [None]:
# Array -> region/hemisphere lookup, keyed by (Subject, Array ID); the
# subject is taken as the first two characters of the session name.
arreg = (
    pd.read_csv('../../db/bank_array_regions.csv')
    .astype({'Array ID': str})
    .assign(Subject=lambda tbl: tbl['Session'].map(lambda sess: sess[:2]))
    .groupby(['Subject', 'Array ID'])
    .first()
)

In [None]:
# Keep only array-level entries, then flag them against the selection criteria.
rf_df = (
    df.reset_index()
    .pipe(lambda tbl: tbl[tbl['Level'] == 'Array'])
    .copy()
)
rf_df['Selected'] = select_rf_fit(rf_df)

In [None]:
# Selected array-level entries, annotated with region and hemisphere from the
# array lookup table (raises KeyError if an array is missing from it).
adf = rf_df.loc[rf_df['Selected']].copy()
adf['Subject'] = adf['Session'].map(lambda sess: sess[:2])
lookup_keys = list(zip(adf['Subject'], adf['Name']))
region_info = arreg.loc[lookup_keys][['Region', 'Hemisphere']]
adf[['Region', 'Hemisphere']] = region_info.values

In [None]:
print('Array-level RF fit, per session')
# One density histogram per parameter, split by hemisphere.
hist_kws = dict(
    hue='Hemisphere', hue_order=('L','R'), stat='density',
    element='poly', common_norm=False, fill=False)
fig, axs = plt.subplots(1, 3, figsize=(9,2.5))
for ax, col in zip(axs, ('x', 'y', 'r')):
    sns.histplot(data=adf, x=col, ax=ax, **hist_kws)

In [None]:
print('Array-level estimates, median across sessions')
# One row per (Subject, array): median of each saved column over sessions.
df_ = adf.groupby(['Subject','Name']).agg({
    'Region': 'first', 'Hemisphere': 'first', 'Selected': 'mean',
    **{k: 'median' for k in cols_to_save}})
assert df_['Selected'].all()  # sanity check: adf holds selected entries only
fig, axs = plt.subplots(1, 3, figsize=(9,2.5))
for x, ax in zip('xyr', axs):
    # same hue order as the per-session figure so hemisphere colors match
    sns.histplot(
        data=df_, x=x, hue='Hemisphere', hue_order=('L','R'),
        stat='density', element='poly', common_norm=False, fill=False, ax=ax)
df_

In [None]:
print('Array-level estimates, median across arrays')
# Median of the per-array estimates within each (Region, Hemisphere) group.
median_by_region = df_.groupby(['Region', 'Hemisphere']).median()
median_by_region

In [None]:
# Save array-level results: per-array medians across sessions.
# DataFrame.to_csv does not create missing parent directories.
Path('summary').mkdir(exist_ok=True)
gb = adf.groupby(['Subject', 'Name'])
df_ = gb[cols_to_save].median()
# number of entries per array with non-null x (the selected sessions)
df_['Count'] = gb['x'].count()
df_[['Region', 'Hemisphere']] = gb[['Region', 'Hemisphere']].first()
df_['Level'] = 'Array'
df_ = df_.reset_index().set_index(['Subject', 'Level', 'Name'])
df_.to_csv(f'summary/rf_fit{output_sfx}-array_level.csv.gz')
df_