In [None]:
import sys
from pathlib import Path

import h5py as h5
import numpy as np
import pandas as pd

sys.path.append('../lib')
from storage import get_storage_functions
from local_paths import preproc_dir, analysis_dir

# Set parameters

In [None]:
#============================================================================
# session
#============================================================================
sess_name = 'sess_name'  # placeholder: set to the session to process


#============================================================================
# response windows (all in ms)
#============================================================================
t_pre  = 500  # how far before image onset the response time course starts
t_post = 500  # how far after (onset + mode trial duration) it extends
t_win  = 200  # width of each averaging window, centred on each time point
t_step =  50  # spacing between consecutive time points


#============================================================================
# temporal resolution
#============================================================================
# for summarizing trial duration and iti (ms): values are rounded to these
# grids before taking the most common value (mode)
dur_res = 100
iti_res =  50
min_dur = 200  # durations <= min_dur are excluded before finding the mode trial dur


#============================================================================
# paths
#============================================================================
proc_dir = preproc_dir

# Join with pathlib (as done for the input files) so this works whether or not
# analysis_dir ends with a path separator
output_dir = str(Path(analysis_dir) / 'trial_level_psth')

# Check prereqs and params

In [None]:
def _expand_existing_file(path):
    """Return *path* with ``~`` expanded, raising if it is not an existing file.

    Uses an explicit exception instead of ``assert`` so the check stays active
    under ``python -O`` and the error names the missing path.
    """
    path = Path(path).expanduser()
    if not path.is_file():
        raise FileNotFoundError(f'Required input file not found: {path}')
    return path


preproc_main_path = Path(proc_dir) / (sess_name + '-main.nwb')
print('Loading session from', preproc_main_path)
preproc_main_path = _expand_existing_file(preproc_main_path)

proc_path = Path(proc_dir) / (sess_name + '-proc.h5')
print('Loading shared processing from', proc_path)
proc_path = _expand_existing_file(proc_path)

rasters_path = Path(proc_dir) / (sess_name + '-rasters.nwb')
print('Loading rasters from', rasters_path)
rasters_path = _expand_existing_file(rasters_path)

output_dir = Path(output_dir)
if not output_dir.expanduser().is_dir():
    raise FileNotFoundError(f'Output directory not found: {output_dir.expanduser()}')
output_path = output_dir / (sess_name + '.h5')
print('Saving results to', output_path)
output_path = output_path.expanduser()

In [None]:
# Names of the units kept by the shared 'simple' unit selection; the stored
# values (presumably byte strings) are cast to str for name matching later on
with h5.File(proc_path, 'r') as proc_file:
    selected = proc_file['unit_selection/simple'][()]
unit_names = selected.astype(str)

In [None]:
analysis_name = 'trial_level_psth'

# Refuse to redo a session whose results file already carries a completion flag
done_flag = f'progress_report/{analysis_name}/all_done'
already_done = False
if output_path.is_file():
    with h5.File(output_path, 'r') as f:
        # a missing flag (or missing progress group) just means "not done yet"
        if done_flag in f:
            already_done = f[done_flag][()].item()
if already_done:
    raise RuntimeError(f'{sess_name} has already been processed')

# Save config

In [None]:
# Storage helpers obtained for this session's output_path
(save_results, add_attr_to_dset, check_equals_saved,
 link_dsets, copy_group) = get_storage_functions(output_path)

In [None]:
# Record the response-window parameters under the config group, tagged as ms
group = analysis_name + '/config/time_windows/'
window_params = (('t_pre', t_pre), ('t_post', t_post),
                 ('t_win', t_win), ('t_step', t_step))
for param_name, param_value in window_params:
    save_results(group + param_name, param_value)
add_attr_to_dset(group, attrs=dict(unit='ms'))

# Select complete trials (presentations with the most common duration)

In [None]:
with h5.File(preproc_main_path, 'r') as f:
    pres_iim = f['stimulus/presentation/presentations/data'][()]
    pres_tid = f['intervals/presentations/trial_id'][()]
    pres_t0s = f['intervals/presentations/start_time'][()]
    pres_t1s = f['intervals/presentations/stop_time'][()]

# Presentation durations in ms, snapped to the dur_res grid
durs = (pres_t1s - pres_t0s) * 1e3
durs = np.round(durs / dur_res) * dur_res
long_durs = durs[durs > min_dur]
if not long_durs.size:
    # otherwise the mode is empty and round(nan) fails with a cryptic error
    raise ValueError(
        f'No presentation in {sess_name} is longer than min_dur={min_dur} ms')
# If several durations tie for most common, average them and re-snap to the grid
dur = pd.Series(long_durs).mode().values.mean().item()
dur = round(dur / dur_res) * dur_res
print('Mode trial duration:\t', dur, 'ms')

# Gaps between consecutive presentations in ms, snapped to the iti_res grid
itis = (pres_t0s[1:] - pres_t1s[:-1]) * 1e3
itis = np.round(itis / iti_res) * iti_res
iti = pd.Series(itis).mode().values.mean().item()
iti = np.round(iti / iti_res) * iti_res
print('Mode ITI:\t\t', iti, 'ms')  # for the record only

# these are trial indices, relative to pres_*; selected directly by position
# (instead of mapping back from trial IDs) so tr_sel stays aligned with
# itr_sel even if a trial ID occurs on more than one presentation
tr_sel = np.flatnonzero(durs == dur)

# these are trial ID's
itr_sel = pres_tid[tr_sel]
print('Selected', len(itr_sel), 'complete trials')

In [None]:
save_results(analysis_name+'/mode_dur', dur)
save_results(analysis_name+'/mode_iti', iti)
save_results(analysis_name+'/trial_selection', itr_sel)

# Get image onset-aligned responses

In [None]:
with h5.File(rasters_path, 'r') as f:
    all_unit_names = list(f['processing/ecephys/unit_names/unit_name'][()].astype(str))
    # Map each name to its first column (same result as list.index) so each
    # selected unit is looked up in O(1) rather than by a linear scan
    name_to_col = {}
    for col, name in enumerate(all_unit_names):
        name_to_col.setdefault(name, col)
    missing = [n for n in unit_names if n not in name_to_col]
    if missing:
        raise ValueError(
            f'{len(missing)} selected unit(s) not found in rasters, e.g. {missing[:5]}')
    # dtype=int keeps this a valid index array even when no units are selected
    sel_ = np.array([name_to_col[n] for n in unit_names], dtype=int)
    # keep only the selected units (second axis); the full dataset is read first
    rasters = f['processing/ecephys/rasters/data'][()][:,sel_]
rasters.shape, rasters.dtype

In [None]:
# Window centres relative to image onset (ms): from t_pre before onset to
# t_post after the end of a mode-duration trial
ts = np.arange(-t_pre, dur+t_post, t_step)
hwin = t_win//2
n_bins = rasters.shape[0]

# (trial, time, unit) mean raster value in each window
resps = np.empty((itr_sel.size, ts.size, unit_names.size), dtype=np.float32)
n_clipped = 0

for i, t0 in enumerate(pres_t0s[tr_sel]*1e3):
    for j, t in enumerate(np.round(t0 - hwin + ts).astype(int)):
        # Clamp the window to the recording. A negative start would otherwise
        # wrap around to the end of the array and silently average unrelated
        # data; clamping treats the start of the recording the way slicing
        # already treats its end: a partial window averages the available
        # bins, and a window with no overlap at all gives NaN.
        lo, hi = max(t, 0), min(t + t_win, n_bins)
        if hi - lo < t_win:
            n_clipped += 1
        resps[i,j] = rasters[lo:hi,:].mean(0) if hi > lo else np.nan

if n_clipped:
    print(f'Warning: {n_clipped} response windows extend beyond the recording '
          'and were averaged over the available bins only (NaN if none)')

In [None]:
# Trial-averaged responses scaled by 1e3 (to spikes/s), saved with attributes
# describing the axes (dims, time points, unit names) and the trial count
mean_resps = resps.mean(0) * 1e3
mean_attrs = dict(
    dims=np.array(['time', 'unit'], dtype=bytes),
    time=ts,
    unit=unit_names.astype(bytes),
    n_trial=len(resps),
)
save_results(analysis_name + '/mean_responses', mean_resps, attrs=mean_attrs)

# Wrap up

In [None]:
# Set the completion flag that the "already processed" check looks for
progress_key = f'progress_report/{analysis_name}/all_done'
save_results(progress_key, True)

In [None]:
# Record the software environment and git state for provenance
# (watermark flags, per its docs — verify: -v Python, -m machine,
# --iversions imported package versions, -r/-b/-g git repo/branch/hash)
%load_ext watermark
%watermark
%watermark -vm --iversions -rbg

# Plots

In [None]:
# Used only by the plots below; none of the analysis cells above need matplotlib
import matplotlib.pyplot as plt

In [None]:
# Grand-mean PSTH: average over trials (axis 0) and units (last axis), x1e3
# to express as spikes/s
fig, ax = plt.subplots()
ax.plot(ts, resps.mean((0,-1)) * 1e3)

# Shade the image-on period (0..dur) and one ITI before and after it.
# fill_betweenx re-triggers autoscaling (adding margins), so restore the
# y-limits afterwards to make the shading span the full height of the axes.
yl = ax.get_ylim()
ax.fill_betweenx(yl, 0, dur, ec='none', fc='whitesmoke', zorder=-1)
ax.fill_betweenx(yl, dur, dur+iti, ec='none', fc=(1, .9, .9), zorder=-1)
ax.fill_betweenx(yl, -iti, 0, ec='none', fc=(1, .9, .9), zorder=-1)
ax.set_ylim(yl)

ax.set_title(f'{sess_name}: trial-level PSTH ({len(resps)} trials)')
ax.set_xlabel('Time rel. image onset, ms')
ax.set_ylabel('Grand mean firing rate, spikes/s');