In [None]:
import ast
import re
import sys
from collections import OrderedDict
from pathlib import Path

import h5py as h5
import numpy as np
import pandas as pd

sys.path.append('../lib')
from local_paths import preproc_dir
from storage import get_storage_functions

# Set parameters

In [None]:
#============================================================================
# session
#============================================================================
# Session name: selects this session's rows in the annotation CSV and forms
# the input/output file names as <sess_name><suffix>.h5.
sess_name = 'sess_name'

#============================================================================
# target
#============================================================================
unit_dset = ''  # dset containing unit names; use . to indicate attr
# (e.g. 'dset.attr' reads attribute 'attr' of dataset 'dset'; the split is at
# the last '.')
dsets = []  # a list of dsets (as an hdf5 key) to group; a list literal string is also accepted
unit_axes = []  # an int axis per dset: the axis whose length equals the number of units
dsets_to_copy = None  # None or a list of dsets copied unchanged (or checked if already in output)

#============================================================================
# grouping
#============================================================================
stat = 'mean'  # np function name; applied over each group's units with axis=0
# Levels to output, finest to coarsest. 'Unit' keeps the individual units; every
# other level adds one row per group of >= 2 units sharing that Channel/Bank/Array.
hiers = ['Unit', 'Channel', 'Bank', 'Array']
save_unique_only = True  # if False, save all groups even if non-unique
# (True: groups with identical unit membership across levels are saved once)

#============================================================================
# paths
#============================================================================
# CSV with (at least) 'Session', 'Bank' and 'Array ID' columns.
annot_path = '../db/bank_array_regions.csv'
input_dir = preproc_dir
input_suffix = ''
output_dir = preproc_dir
output_suffix = ''

# Check prereqs and params

In [None]:
# Resolve all paths up front (expanding '~') so the logged paths are exactly the
# ones used, and fail early with a readable message if anything is missing.
input_path = (Path(input_dir) / (sess_name + input_suffix + '.h5')).expanduser()
print('Loading input from', input_path)
assert input_path.is_file(), f'Input file not found: {input_path}'

annot_path = Path(annot_path).expanduser()
print('Loading recording array annotations from', annot_path)
assert annot_path.is_file(), f'Annotation file not found: {annot_path}'
# Keep only this session's rows, indexed by bank. The list indexer `[[sess_name]]`
# keeps a DataFrame even when exactly one row matches.
adf = pd.read_csv(annot_path).set_index('Session').loc[[sess_name]].set_index('Bank')

output_dir = Path(output_dir).expanduser()
assert output_dir.is_dir(), f'Output directory not found: {output_dir}'
output_path = output_dir / (sess_name + output_suffix + '.h5')
print('Saving results to', output_path)

In [None]:
def maybe_str2seq(v):
    """Parse a list/tuple given as a Python-literal string; pass anything else through.

    Parameters supplied as strings (e.g. ``"['a', 'b']"``) are parsed with
    ``ast.literal_eval``, which only accepts literals and never executes code.

    Raises
    ------
    AssertionError
        If a string parses to something other than a list or tuple.
    ValueError, SyntaxError
        If a string is not a valid Python literal.
    """
    if isinstance(v, str):
        v = ast.literal_eval(v)
        assert isinstance(v, (list, tuple)), \
            f'expected a list or tuple literal, got {type(v).__name__}'
    return v

# Normalize parameters that may have been supplied as string literals.
dsets = maybe_str2seq(dsets)
unit_axes = maybe_str2seq(unit_axes)
hiers = maybe_str2seq(hiers)
dsets_to_copy = [] if dsets_to_copy is None else maybe_str2seq(dsets_to_copy)

assert len(dsets) == len(unit_axes), \
    f'need one unit axis per dset: got {len(dsets)} dsets and {len(unit_axes)} axes'
assert all(isinstance(a, int) for a in unit_axes), f'unit_axes must be ints: {unit_axes}'

# Canonical finest-to-coarsest order of the grouping levels.
HIER_ORDER = ['Unit', 'Channel', 'Bank', 'Array']
assert set(hiers) <= set(HIER_ORDER), \
    f'unknown hierarchy level(s): {sorted(set(hiers) - set(HIER_ORDER))}'
hiers = sorted(hiers, key=HIER_ORDER.index)

# Resolve the reduction by name, e.g. 'mean' -> np.mean; fail clearly on typos.
stat_fun = getattr(np, stat, None)
if not callable(stat_fun):
    raise ValueError(f'stat must name a numpy function, got {stat!r}')

In [None]:
# Storage helpers (from ../lib/storage) bound to the output file path.
(save_results, add_attr_to_dset, check_equals_saved,
 link_dsets, copy_group) = get_storage_functions(output_path)

# Main

In [None]:
def _channel_of(unit_name):
    """Return the first run of decimal digits in `unit_name` as an int.

    Raises ValueError (instead of an opaque AttributeError on None) when the
    name contains no digits.
    """
    match = re.search(r'\d+', unit_name)  # raw string: '\d' is not a valid str escape
    if match is None:
        raise ValueError(f'cannot parse a channel number from unit name {unit_name!r}')
    return int(match.group())


with h5.File(input_path, 'r') as fi, h5.File(output_path, 'a') as fo:
    # ---- Unit names ----------------------------------------------------------
    # `unit_dset` is either a dataset of unit names, or 'dset.attr' meaning
    # attribute `attr` of dataset `dset` (split at the last '.').
    # `unit_dset_out` is where the new row labels are saved afterwards.
    if '.' in unit_dset:
        unit_dset_in, unit_dset_attr = unit_dset.rsplit('.', 1)
        unit_names = fi[unit_dset_in].attrs[unit_dset_attr].astype(str)
        unit_dset_out = 'unit_names'
    else:
        unit_names = fi[unit_dset][()].astype(str)
        unit_dset_in = unit_dset_out = unit_dset
        unit_dset_attr = None
    n_unit = len(unit_names)

    # ---- Unit -> channel / bank / array table --------------------------------
    # Default RangeIndex, so groupby labels below are also row positions.
    unit_df = pd.DataFrame(data={'Name': unit_names})
    unit_df['Channel'] = [_channel_of(v) for v in unit_names]
    # Assumes 1-based channel numbers with 32 channels per bank.
    unit_df['Bank'] = (unit_df['Channel'] - 1) // 32
    unit_df['Array'] = adf.loc[unit_df['Bank'].values, 'Array ID'].values

    # ---- Groups --------------------------------------------------------------
    # unique_groups: sorted tuple of unit row indices -> group id 0, 1, ... in
    #   order of first appearance; levels with identical membership share an id.
    # name2ig: (level, group name) -> group id.
    # Groups with fewer than 2 units are skipped.
    unique_groups = OrderedDict()
    name2ig = OrderedDict()
    for hier in hiers:
        if hier == 'Unit':
            continue
        for name, idc in unit_df.groupby(hier).groups.items():
            if len(idc) < 2:
                continue
            k = tuple(sorted(idc))
            if k not in unique_groups:
                unique_groups[k] = len(unique_groups)
            name2ig[(hier, name)] = unique_groups[k]

    # Label each unique group by the first (level, name) that produced it. Ids
    # are handed out in increasing order, so entry i labels group id i.
    ig2name0 = OrderedDict()
    for name, ig in name2ig.items():
        if ig not in ig2name0:
            ig2name0[ig] = name
    unique_groups_name = np.array(['/'.join(map(str, v)) for v in ig2name0.values()])
    all_groups_name = np.array(['/'.join(map(str, v)) for v in name2ig.keys()])

    # Output row labels: per-unit rows (if requested), then group rows. They do
    # not depend on the dataset, so compute them once. This also keeps `names`
    # defined for the metadata cell when `dsets` is empty.
    name_parts = []
    if 'Unit' in hiers:
        name_parts.append([f'Unit/{v}' for v in unit_names])
    if name2ig:
        name_parts.append(unique_groups_name if save_unique_only else all_groups_name)
    names = np.concatenate(name_parts) if name_parts else np.array([], dtype=str)
    if dsets and len(names) == 0:
        raise ValueError("nothing to save: 'Unit' not in hiers and no multi-unit groups found")

    # ---- Datasets copied verbatim ---------------------------------------------
    for d in dsets_to_copy:
        if d not in fo:
            fi.copy(fi[d], fo, d)
        else:
            check_equals_saved(fi[d][()], fo[d][()], d)

    # ---- Grouped datasets -----------------------------------------------------
    for d, a in zip(dsets, unit_axes):
        vals = fi[d][()]
        assert vals.shape[a] == n_unit, \
            f'{d}: axis {a} has length {vals.shape[a]}, expected {n_unit} units'
        vals = np.swapaxes(vals, a, 0)  # units first

        # Reduce each unique group along the unit axis (iteration order = id order).
        # Stacking keeps the stat's own dtype, so e.g. the mean of integer
        # data is not truncated, as it was when written into an array of vals' dtype.
        group_stats = [stat_fun(vals[list(idc)], axis=0) for idc in unique_groups]
        if group_stats:
            gvals = np.stack(group_stats)
        else:
            gvals = np.empty((0, *vals.shape[1:]), dtype=vals.dtype)

        new_vals = []
        if 'Unit' in hiers:
            new_vals.append(vals)
        if name2ig:
            new_vals.append(gvals if save_unique_only else gvals[list(name2ig.values())])
        new_vals = np.swapaxes(np.concatenate(new_vals, axis=0), 0, a)
        assert new_vals.shape[a] == len(names)

        if d in fo and 'hier_grouped' in fo[d].attrs:
            # Already grouped by an earlier run: verify instead of overwriting.
            check_equals_saved(new_vals, fo[d][()], d)
        else:
            attrs = dict(fi[d].attrs.items())
            attrs['hier_grouped'] = True
            if unit_dset_attr is not None and d == unit_dset_in:
                # Unit names live in an attribute of this dset: relabel them.
                attrs[unit_dset_attr] = names.astype(bytes)
            save_results(d, new_vals, overwrite=True)
            for k, v in attrs.items():
                fo[d].attrs[k] = v

# Save the new row labels plus the metadata needed to interpret the grouping.
# Group ids are computed once and kept integer-typed even when there are no groups.
group_uids = np.array(list(name2ig.values()), dtype=int)
save_results(unit_dset_out, names.astype(bytes), overwrite=True)
add_attr_to_dset(unit_dset_out, {
    'all_groups_name': all_groups_name.astype(bytes),
    'all_groups_uid': group_uids})
save_results('hier_group/orig_unit_names', unit_names.astype(bytes))
save_results('hier_group/hiers', np.array(hiers, dtype=str).astype(bytes))
save_results('hier_group/stat', stat)
save_results('hier_group/save_unique_only', save_unique_only)
# One (level, group name) row per group; reshape keeps the (n, 2) layout when empty.
save_results('hier_group/groups/name',
             np.array(list(name2ig.keys()), dtype=str).reshape(-1, 2).astype(bytes))
save_results('hier_group/groups/uid', group_uids)
# Unit row indices belonging to each unique group, keyed by group id.
for i, k in enumerate(unique_groups):
    save_results(f'hier_group/groups/unit_indices/{i}', np.array(k))

# Wrap up

In [None]:
# Record in the output file that this notebook ran to completion.
_done_key = 'progress_report/hier_group/all_done'
save_results(_done_key, True)

In [None]:
# Record the run environment for provenance: Python/IPython versions and
# machine info (-v, -m), versions of imported packages (--iversions), and git
# repository info (-r remote, -b branch, -g commit hash).
%load_ext watermark
%watermark -vm --iversions -rbg