# setup
## import

In [1]:
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import seaborn as sns
import itertools as itt
import holoviews as hv
import dask as da
import dask
import xarray as xr
import numba as nb
import warnings as wrn
import collections
import warnings
from bokeh import models
from bokeh.io import export_svgs
from IPython.core.debugger import set_trace
from dask.diagnostics import ProgressBar
from holoviews.operation.datashader import regrid, datashade, dynspread
from skimage.transform import resize
from __future__ import print_function
from holoviews.core.options import Options
# Load the bokeh plotting backend for HoloViews in this notebook.
hv.extension('bokeh', width=100)
# NOTE(review): this silences *all* warnings for the whole session, including
# the "Overlapping windows" warning that construct_bins emits via wrn.warn.
warnings.simplefilter('ignore')

## function definitions

In [2]:
def sub_spike(spk, ts, te):
    """Return the spike times of ``spk`` lying strictly inside (ts, te)."""
    inside = np.logical_and(spk > ts, spk < te)
    return spk[inside]

def fr(spk, ts, te, bins):
    """Firing rate or spike-count histogram of ``spk`` within ``[ts, te]``.

    Parameters
    ----------
    spk : np.ndarray
        1-D array of spike times, in the same time base as ``ts``/``te``.
    ts, te : float
        Window start and end.
    bins : None, float or anything ``np.histogram`` accepts
        ``None``  -> return the mean rate, i.e. the number of spikes strictly
        inside ``(ts, te)`` divided by ``te - ts``.
        ``float`` -> interpreted as a bin *width*. The window is split into
        ``ceil((te - ts) / bins)`` equal bins (at least one).
        otherwise -> passed unchanged to ``np.histogram`` as ``bins``.

    Returns
    -------
    float or tuple
        The rate when ``bins is None``. Otherwise the ``(counts, edges)``
        tuple from ``np.histogram(spk, bins, range=(ts, te))``.
    """
    if bins is None:
        return len(sub_spike(spk, ts, te)) / (te - ts)
    if isinstance(bins, float):
        # np.histogram needs an integer bin count, but np.ceil returns a float.
        # The small tolerance stops round-off (e.g. 0.4 / 0.1 ->
        # 4.000000000000001) from adding an extra, narrower bin.
        bins = max(1, int(np.ceil((te - ts) / bins - 1e-9)))
    # ``range`` already confines the histogram to the window, so the full
    # spike train can be passed without pre-filtering it.
    return np.histogram(spk, bins, range=(ts, te))

def construct_bins(stim, grp_dim, wnd, align=('start', 'end')):
    """Build per-trial windows and the interleaved edge table for binning.

    Parameters
    ----------
    stim : pd.DataFrame
        Stimulus table with the columns named in ``align``. It is not modified.
    grp_dim : list or tuple of str
        Columns copied onto every edge row so counts can be joined back later.
    wnd : (float, float)
        Offsets added to ``stim[align[0]]`` (window start) and
        ``stim[align[1]]`` (window end).
    align : pair of str
        Columns the window start and end are anchored to.

    Returns
    -------
    stim : pd.DataFrame
        A copy of the input with ``wnd_start``/``wnd_end`` columns, sorted by
        ``wnd_start``.
    bin_df : pd.DataFrame
        ``grp_dim`` + ``t`` (edge time) + ``edge`` ('start'/'end'). One row per
        window edge, sorted by ``t``.

    Warns (through ``wrn``) if any window is empty or overlaps the next one.
    In that case the edge list is not a clean start/end alternation.
    """
    # Work on a copy so the caller's stimulus table is not silently mutated.
    stim = stim.copy()
    stim['wnd_start'] = stim[align[0]] + wnd[0]
    stim['wnd_end'] = stim[align[1]] + wnd[1]
    stim = stim.sort_values('wnd_start')
    sarr, earr = stim['wnd_start'].values, stim['wnd_end'].values
    # Each window must be non-empty and must end before the next one starts.
    if not ((sarr < earr).all() and (earr[:-1] < sarr[1:]).all()):
        wrn.warn("Overlapping windows with size {}".format(wnd))
    cols = list(grp_dim)
    sdf = stim[cols + ['wnd_start']].rename(dict(wnd_start='t'), axis='columns')
    edf = stim[cols + ['wnd_end']].rename(dict(wnd_end='t'), axis='columns')
    sdf['edge'] = 'start'
    edf['edge'] = 'end'
    bin_df = pd.concat([sdf, edf], axis='rows')
    bin_df = bin_df.sort_values('t').reset_index(drop=True)
    return stim, bin_df

def mean_fr(spk, bin_df):
    """Attach spike counts between consecutive edges ``bin_df['t']``.

    ``bin_df`` is modified in place (a column ``'fr'`` is added) and is also
    returned. Row *i* gets the number of spikes in ``[t_i, t_{i+1})``; the
    last histogram bin is closed on the right, as in ``np.histogram``. The
    final row has no following edge and gets NaN. The values are raw counts,
    not rates.
    """
    edges = bin_df['t'].values
    counts = np.histogram(spk, edges)[0]
    bin_df['fr'] = np.concatenate([counts, [np.nan]])
    return bin_df

def fr_vec(spk, times, bins):
    """Stack per-trial spike histograms into an ``(n_trials, n_bins)`` array.

    ``times[i, 0]`` / ``times[i, 1]`` give the start and end of trial *i*.
    Each output row is the ``counts`` part of ``fr(spk, start, end, bins)``.
    Raises IndexError when ``times`` has no rows.
    """
    per_trial = [fr(spk, t_start, t_end, bins)[0]
                 for t_start, t_end in zip(times[:, 0], times[:, 1])]
    out = np.zeros((len(per_trial), len(per_trial[0])))
    for row, counts in enumerate(per_trial):
        out[row, :] = counts
    return out

@nb.jit(nopython=True, nogil=True)
def fr_vec_nb(spk, times, bins):
    """Numba-compiled per-trial spike-count histograms.

    Returns a float array of shape ``(times.shape[0], n_bins)``. Row *i* holds
    the ``np.histogram`` counts of ``spk`` over ``[times[i, 0], times[i, 1]]``
    using ``bins``.
    """

    def fr(spk, ts, te, bins):
        # ``range`` confines the histogram to [ts, te]. The pre-filtered copy
        # of ``spk`` built here before was never used, so it is no longer
        # created (it cost one boolean mask and one copy per trial).
        return np.histogram(spk, bins, range=(ts, te))

    fr0, _ = fr(spk, times[0, 0], times[0, 1], bins)
    result = np.zeros((times.shape[0], len(fr0)))
    result[0, :] = fr0
    for id_times in range(1, times.shape[0]):
        cur_fr, _ = fr(spk, times[id_times, 0], times[id_times, 1], bins)
        result[id_times, :] = cur_fr
    return result

@nb.jit(nopython=True, nogil=True)
def norm_sum(x):
    """Scale ``x`` so that its NaN-ignoring total equals 1.

    The sum is taken over *all* elements (no axis), so an N-d input is divided
    by a single scalar rather than normalised slice by slice.
    """
    total = np.nansum(x)
    return x / total

def normalize(x):
    """Min-max scale ``x`` to roughly [0, 1], ignoring NaNs.

    A float32 machine epsilon is added to the range so that constant inputs
    map to 0 instead of dividing by zero.
    """
    lo = np.nanmin(x)
    hi = np.nanmax(x)
    span = hi - lo + np.finfo(np.float32).eps
    return (x - lo) / span

In [3]:
def flatten(l):
    """Recursively yield the leaves of an arbitrarily nested iterable.

    Text and byte strings are treated as leaves instead of being iterated
    character by character.
    """
    # ``collections.Iterable`` was removed in Python 3.10, and ``basestring``
    # does not exist on Python 3 (the old check raised NameError on the first
    # nested element). Resolve both in a way that works on 2 and 3.
    try:
        from collections.abc import Iterable
    except ImportError:  # Python 2
        from collections import Iterable
    try:
        string_types = (basestring,)  # noqa: F821 -- Python 2 only
    except NameError:
        string_types = (str, bytes)
    for el in l:
        if isinstance(el, Iterable) and not isinstance(el, string_types):
            for sub in flatten(el):
                yield sub
        else:
            yield el

def _get_figures_core(objs):
    """Expand a bokeh layout into nested lists of its leaf models.

    Lists and ``Row``/``Column`` containers are walked recursively, skipping
    toolbar and widget boxes. Any other object is returned unchanged.
    """
    if isinstance(objs, list):
        return [_get_figures_core(item) for item in objs]
    if isinstance(objs, (models.Column, models.Row)):
        return [
            _get_figures_core(child)
            for child in objs.children
            if not isinstance(child, (models.ToolbarBox, models.WidgetBox))
        ]
    return objs

def _get_figures(objs):
    """Return a flat list of the leaf models found in ``objs``.

    If the expanded structure cannot be iterated (``flatten`` raises
    TypeError, e.g. for a single figure), it is wrapped in a one-item list.
    """
    try:
        figures = list(flatten(_get_figures_core(objs)))
    except TypeError:
        figures = [_get_figures_core(objs)]
    return figures

def _save_to_svg(hv_obj, save):
    """Export every bokeh figure of a HoloViews object as an SVG file.

    Parameters
    ----------
    hv_obj : HoloViews object
        Rendered with the bokeh renderer.
    save : str
        With a single figure, the output file name (``.svg`` is appended if
        missing). With several figures, a directory (created if needed) in
        which each figure is written as ``<title>_<index>.svg``.
    """
    bokeh_obj = hv.renderer('bokeh').get_plot(hv_obj).state
    figures = _get_figures(bokeh_obj)
    multiple = len(figures) != 1
    if multiple and not os.path.exists(save):
        os.makedirs(save)
    for i, figure in enumerate(figures):
        figure.output_backend = 'svg'

        if multiple:
            # Untitled figures have no title text (previously producing
            # "None_<i>"). Path separators in a title would create sub-dirs.
            title = figure.title.text if figure.title is not None else None
            tidied_title = (title or 'figure').replace(os.sep, '_')
            save_fp = os.path.join(save, '{0}_{1}'.format(tidied_title, i))
        else:
            save_fp = save

        # Check the real extension, not merely the trailing letters "svg".
        if not save_fp.endswith('.svg'):
            save_fp = '{0}.{1}'.format(save_fp, 'svg')

        export_svgs(figure, save_fp)


## load monitor information

In [4]:
from allensdk.brain_observatory.stimulus_info import BrainObservatoryMonitor
# Target (height, width) that every stimulus template and the screen mask are
# resized to below.
# NOTE(review): presumably a 4x downsampling of the monitor (output files are
# named "*_ds4x.nc") -- confirm against BrainObservatoryMonitor.
dim_mon = (300, 480)
bom = BrainObservatoryMonitor()

## load ephys experiments

In [5]:
from swdb_2018_neuropixels.ephys_nwb_adapter import NWB_adapter
dpath = '/data/dynamic-brain-workshop/visual_coding_neuropixels'
manifest_file = os.path.join(dpath,'ephys_manifest.csv')
expt_info_df = pd.read_csv(manifest_file)
exp_name = expt_info_df['nwb_filename'].iloc[5].rstrip(".nwb")

In [32]:
ds = NWB_adapter(os.path.join(dpath, exp_name + ".nwb"))

## load ophys experiments

In [None]:
from allensdk.core.brain_observatory_cache import BrainObservatoryCache
# Two-photon Brain Observatory cache; used below only to fetch stimulus
# templates (natural scenes, locally sparse noise).
boc = BrainObservatoryCache(manifest_file='/data/allen-brain-observatory/visual-coding-2p/manifest.json')

# get neuronal responses
## try different baseline time window

In [None]:
# For each baseline window, build the stimulus table and edge table of every
# stimulus set. Results are collected in stim_dfs / bin_dfs (one per window).
stim_list = ['static_gratings', 'natural_scenes']
stim_dfs = []
bin_dfs = []
# wnd_list = [(-0.05, 0), (-0.2, 0)]
# Windows are offsets relative to stimulus onset, in the stimulus table's
# time units (presumably seconds).
wnd_list = [(-0.2, 0)]
for wnd in wnd_list:
    stim_df_list = []
    bin_df_list = []
    for stim in stim_list:
        stim_df = ds.stim_tables[stim].copy()
        # every non-timing column identifies a stimulus condition
        grp_dim = list(set(stim_df.columns) - set(['start', 'end']))
        stim_df = stim_df.dropna()
        # keep the original row index as a 'trial' id
        stim_df = stim_df.reset_index().rename(dict(index='trial'), axis='columns')
        stim_df['stim_type'] = stim
        # 'occur' = 1-based repeat number of each condition, in table order
        stim_df['occur'] = 1
        stim_df['occur'] = stim_df.groupby(grp_dim)['occur'].cumsum()
        # recompute to also carry trial / stim_type / occur onto the edges
        grp_dim = list(set(stim_df.columns) - set(['start', 'end']))
        # both window edges are anchored to stimulus onset ('start')
        stim_df_wnd, bin_df = construct_bins(
            stim_df, grp_dim, wnd, ('start', 'start'))
        stim_df_wnd['window'] = str(wnd)
        bin_df['window'] = str(wnd)
        stim_df_wnd['wnd_size'] = wnd[1] - wnd[0]
        bin_df['wnd_size'] = wnd[1] - wnd[0]
        stim_df_list.append(stim_df_wnd)
        bin_df_list.append(bin_df)
    # merge both stimulus sets into one time-sorted edge list per window
    stim_df = pd.concat(stim_df_list, axis='rows')
    bin_df = pd.concat(bin_df_list, axis='rows')
    bin_df = bin_df.sort_values('t').reset_index(drop=True)
    stim_dfs.append(stim_df)
    bin_dfs.append(bin_df)

In [None]:
%%time
# Count every unit's spikes between consecutive edges, then join the counts of
# each window-start edge back to its stimulus row.
bases = []
for bin_df, stim_df in zip(bin_dfs, stim_dfs):
    base_list = []
    for i, unit in ds.unit_df.iterrows():
        prob, uid = unit['probe'], unit['unit_id']
        spk = ds.spike_times[prob][uid]
        # mean_fr writes into its argument, hence the copy
        fr_df = mean_fr(spk, bin_df.copy())
        fr_df['probe'] = prob
        fr_df['unit_id'] = int(uid)
        base_list.append(fr_df)
        # progress output (i is the unit_df index label)
        if i % 200 == 0:
            print(i)
    base_df = pd.concat(base_list, axis=0, ignore_index=True)
    base_df = base_df.rename(dict(t='wnd_start'), axis='columns')
    # Inner merge on wnd_start: rows for 'end' edges normally match no window
    # start and are dropped, leaving one in-window count per trial and unit.
    tu_df = stim_df.merge(
        base_df,
        on=['frame', 'occur', 'orientation', 'phase', 'trial',
            'spatial_frequency', 'stim_type', 'wnd_start', 'window', 'wnd_size'],
        validate='one_to_many',
        copy=False)
    bases.append(tu_df)

In [None]:
base_df = pd.concat(bases, axis='rows', ignore_index=True)
base_df['fr'] = base_df['fr'] / base_df['wnd_size']

In [None]:
base_df.to_hdf(os.path.join(exp_name, "baseline.h5"), 'baseline')

## average across trial

In [None]:
%%time
# base_df = pd.read_hdf(os.path.join(exp_name, "baseline.h5"), 'baseline')
base_df_grt = base_df[base_df['stim_type'] == 'static_gratings']
base_df_ns = base_df[base_df['stim_type'] == 'natural_scenes']
base_df_grt_mean = base_df_grt.groupby(['window', 'probe', 'unit_id']).mean()
base_df_grt_mean = base_df_grt_mean['fr'].reset_index()
base_df_ns_mean = base_df_ns.groupby(['window', 'probe', 'unit_id']).mean()
base_df_ns_mean = base_df_ns_mean['fr'].reset_index()

In [None]:
# Wrap the trial-averaged baseline rates as HoloViews datasets.
hv_base_grt = hv.Dataset(
    base_df_grt_mean,
    kdims=['window', 'unit_id', 'probe'], vdims=['fr'])
hv_base_ns = hv.Dataset(
    base_df_ns_mean,
    kdims=['window', 'unit_id', 'probe'], vdims=['fr'])

In [None]:
%%opts HeatMap [height=500, width=800, tools=['hover'], colorbar=True] {+framewise}
# baseline rate of each unit (x) on each probe (y), per 'window'
hv_base_grt.to(hv.HeatMap, kdims=['unit_id', 'probe'])

In [None]:
%%opts HeatMap [height=500, width=800, tools=['hover'], colorbar=True] {+framewise}
# same view for the natural-scenes baseline
hv_base_ns.to(hv.HeatMap, kdims=['unit_id', 'probe'])

## get tuning

In [None]:
# Response windows: one per trial, spanning the stimulus presentation.
stim_list = ['static_gratings', 'natural_scenes']
stim_df_list = []
bin_df_list = []
for stim in stim_list:
    stim_df = ds.stim_tables[stim].copy()
    grp_dim = list(set(stim_df.columns) - set(['start', 'end']))
    stim_df = stim_df.dropna()
    stim_df = stim_df.reset_index().rename(dict(index='trial'), axis='columns')
    stim_df['stim_type'] = stim
    # 1-based repeat number of each condition
    stim_df['occur'] = 1
    stim_df['occur'] = stim_df.groupby(grp_dim)['occur'].cumsum()
    grp_dim = list(set(stim_df.columns) - set(['start', 'end']))
    # Window = [start, end - eps]. The eps keeps a trial's end edge strictly
    # before the next trial's start, as construct_bins requires.
    eps = np.finfo(np.float32).eps
    stim_df_wnd, bin_df = construct_bins(
        stim_df, grp_dim, (0, -eps), ('start', 'end'))
    stim_df_list.append(stim_df_wnd)
    bin_df_list.append(bin_df)
stim_df = pd.concat(stim_df_list, axis='rows')
bin_df = pd.concat(bin_df_list, axis='rows')
bin_df = bin_df.sort_values('t').reset_index(drop=True)

In [None]:
%%time
tu_list = []
for i, unit in ds.unit_df.iterrows():
    prob, uid = unit['probe'], unit['unit_id']
    spk = ds.spike_times[prob][uid]
    fr_df = mean_fr(spk, bin_df.copy())
    fr_df['probe'] = prob
    fr_df['unit_id'] = int(uid)
    tu_list.append(fr_df)
    if i % 200 == 0:
        print(i)
tu_df_pt = pd.concat(tu_list, axis=0, ignore_index=True)
tu_df_pt = tu_df_pt.rename(dict(t='wnd_start'), axis='columns')
tu_df = stim_df.merge(
    tu_df_pt,
    on=['frame', 'occur', 'orientation', 'phase', 'trial',
        'spatial_frequency', 'stim_type', 'wnd_start'],
    validate='one_to_many',
    copy=False)
tu_df['fr'] = tu_df['fr'] / 0.25

In [None]:
# Mean baseline for the (-0.2, 0) window and per-trial responses, per set.
base_grt = base_df_grt_mean[base_df_grt_mean['window'] == '(-0.2, 0)']
base_ns = base_df_ns_mean[base_df_ns_mean['window'] == '(-0.2, 0)']
tu_grt = tu_df[tu_df['stim_type'] == 'static_gratings']
tu_ns = tu_df[tu_df['stim_type'] == 'natural_scenes']

In [None]:
# keep only the columns relevant to each stimulus set
grt_dim = ['occur', 'fr', 'probe', 'unit_id', 'trial', 'orientation', 'phase', 'spatial_frequency']
ns_dim = ['occur', 'fr', 'probe', 'unit_id', 'trial', 'frame']
base_grt = base_grt.drop(set(base_grt.columns) - set(grt_dim), axis=1)
tu_grt = tu_grt.drop(set(tu_grt.columns) - set(grt_dim), axis=1)
base_ns = base_ns.drop(set(base_ns.columns) - set(ns_dim), axis=1)
tu_ns = tu_ns.drop(set(tu_ns.columns) - set(ns_dim), axis=1)

In [None]:
%%time
# Pair every trial response with its unit's mean baseline (fr_x = baseline,
# fr_y = response) and keep the baseline-subtracted rate as 'fr'.
tu_grt_sub = base_grt.merge(tu_grt, on=['probe', 'unit_id'])
tu_ns_sub = base_ns.merge(tu_ns, on=['probe', 'unit_id'])
tu_grt_sub['fr'] = tu_grt_sub['fr_y'] - tu_grt_sub['fr_x']
tu_ns_sub['fr'] = tu_ns_sub['fr_y'] - tu_ns_sub['fr_x']
tu_grt_sub = tu_grt_sub.drop(['fr_x', 'fr_y'], axis=1)
tu_ns_sub = tu_ns_sub.drop(['fr_x', 'fr_y',], axis=1)
# one string label per grating condition, used as a plot axis below
tu_grt_sub['gratings'] = tu_grt_sub['orientation'].astype(str) + '_' + tu_grt_sub['phase'].astype(str) + '_' + tu_grt_sub['spatial_frequency'].astype(str)

In [None]:
tu_grt_sub.to_hdf(os.path.join(exp_name, "tuning_grt.h5"), "tuning")
tu_ns_sub.to_hdf(os.path.join(exp_name, "tuning_ns.h5"), "tuning")

## visualize

In [None]:
# Reload the saved baseline-subtracted responses so the plots below can run
# without recomputing them.
tu_grt_sub = pd.read_hdf(os.path.join(exp_name, 'tuning_grt.h5'), "tuning")
tu_ns_sub = pd.read_hdf(os.path.join(exp_name, 'tuning_ns.h5'), "tuning")

In [None]:
%%opts HeatMap [height=600, width=800, colorbar=True, tools=['hover']] (cmap='RdBu') {+framewise}
# Natural scenes: image (frame) x repeat (occur) for units 0-100 on probeA,
# with the colour range fixed to +/-25.
hv_tu_ns = hv.Dataset(tu_ns_sub, kdims=['probe', 'unit_id', 'frame', 'occur'], vdims=['fr'])
hv_tu_ns.select(unit_id=slice(0,100), probe='probeA').to(hv.HeatMap, kdims=['frame', 'occur']).redim.range(fr=(-25,25))

In [None]:
%%opts HeatMap [height=600, width=800, colorbar=True, tools=['hover']] (cmap='RdBu') {+framewise}
# Natural scenes: every 20th trial x unit, to keep the heatmap small.
n_trial = tu_ns_sub['trial'].max()
tu_ns_sub_plt = tu_ns_sub[tu_ns_sub['trial'].isin(np.arange(0, n_trial, 20))]
hv_tu_ns = hv.Dataset(tu_ns_sub_plt, kdims=['probe', 'unit_id', 'trial'], vdims=['fr'])
hv_tu_ns.select(probe='probeA', unit_id=slice(0, 100)).to(hv.HeatMap, kdims=['trial', 'unit_id']).redim.range(fr=(-25,25))

In [None]:
%%opts HeatMap [height=600, width=800, colorbar=True, tools=['hover']] (cmap='RdBu') {+framewise}
# Static gratings: grating label x repeat for units 0-100 on probeA.
hv_tu_gr = hv.Dataset(tu_grt_sub, kdims=['probe', 'unit_id', 'gratings', 'occur'], vdims=['fr'])
hv_tu_gr.select(unit_id=slice(0,100), probe='probeA').to(hv.HeatMap, kdims=['gratings', 'occur']).redim.range(fr=(-25,25))

In [None]:
%%opts HeatMap [height=600, width=800, colorbar=True, tools=['hover']] (cmap='RdBu') {+framewise}
# Static gratings: every 20th trial x unit.
n_trial = tu_grt_sub['trial'].max()
tu_grt_sub_plt = tu_grt_sub[tu_grt_sub['trial'].isin(np.arange(0, n_trial, 20))]
hv_tu_gr = hv.Dataset(tu_grt_sub_plt, kdims=['probe', 'unit_id', 'trial'], vdims=['fr'])
hv_tu_gr.select(probe='probeA', unit_id=slice(0, 100)).to(hv.HeatMap, kdims=['trial', 'unit_id']).redim.range(fr=(-25,25))

## get firing activities

In [None]:
# Copy so the shared table in ``ds.stim_tables`` is not mutated in place
# (the other sections also copy before adding columns).
stim_tab_ns = ds.stim_tables['natural_scenes'].copy()
# 1-based repeat number of each image
stim_tab_ns['occur'] = 1
stim_tab_ns['occur'] = stim_tab_ns.groupby('frame')['occur'].cumsum()
# Trial-level inputs are identical for every unit, so build them once.
trial_times = stim_tab_ns[['start', 'end']].values
trial_idx = np.arange(len(stim_tab_ns))
frm_coords = xr.DataArray(
    stim_tab_ns['frame'].values, dims=['trial'], coords=dict(trial=trial_idx))
occ_coords = xr.DataArray(
    stim_tab_ns['occur'].values, dims=['trial'], coords=dict(trial=trial_idx))
firing_list = []
for _, unit in ds.unit_df.iterrows():
    prob, uid = unit['probe'], unit['unit_id']
    spk = ds.spike_times[prob][uid]
    # lazily: spike counts in 50 equal bins spanning each trial
    frs = da.delayed(fr_vec_nb)(spk, trial_times, 50)
    frs = da.delayed(xr.DataArray)(
        frs, dims=['trial', 'time_bin'],
        coords=dict(
            trial=trial_idx,
            time_bin=np.arange(50)))
    frs = frs.assign_coords(
        probe=prob, unit_id=int(uid), frame=frm_coords, occur=occ_coords)
    firing_list.append(frs)

In [None]:
%%time
with ProgressBar():
    firing_list, = da.compute(firing_list)
firing = xr.concat(firing_list, dim='unit')
firing = firing / 0.005
firing = firing.assign_coords(unit=np.arange(firing.sizes['unit'])).rename('firing')

In [None]:
%%opts Image [colorbar=True]
# time_bin x unit firing image, one frame per trial. regrid (HoloViews'
# datashader operation) resamples the image for display.
hv_firing = hv.Dataset(firing, kdims=['trial', 'time_bin', 'unit'])
firing_plot = regrid(hv_firing.to(hv.Image, kdims=['time_bin', 'unit']))
firing_plot.opts(plot=dict(height=500, width=500))

In [None]:
%%time
# cache the binned firing rates for later analysis
firing.to_netcdf(os.path.join(exp_name, "firing.nc"), engine='netcdf4')

# get stimulus templates
## get mask

In [None]:
# Screen mask from the monitor model, resized to dim_mon and wrapped as an
# xarray with integer pixel coordinates; used below to crop templates.
# NOTE(review): resize may return interpolated (non-boolean) values along the
# mask border -- confirm the later .where(mask) calls treat them as intended.
mask = resize(bom.get_mask(), dim_mon, mode='reflect')
mask = xr.DataArray(
    mask,
    dims=['height', 'width'],
    coords=dict(
        height=np.arange(mask.shape[0]),
        width=np.arange(mask.shape[1])))

## get natural images

In [None]:
cont_df = pd.DataFrame(boc.get_experiment_containers())
# Restrict to the first container's experiments. The first positional
# parameter of get_ophys_experiments is ``file_name`` (where to save the
# query), so the id must be passed by keyword.
exps_df = pd.DataFrame(boc.get_ophys_experiments(
    experiment_container_ids=[int(cont_df['id'].iloc[0])]))
# session B contains the natural-scenes stimulus template
oexp = boc.get_ophys_experiment_data(
    exps_df[exps_df['session_type'] == 'three_session_B']['id'].iloc[0])
ns = oexp.get_stimulus_template('natural_scenes')

In [None]:
# Wrap the raw natural-scene templates as (frame, height, width).
temp_ns = xr.DataArray(
    ns, dims=['frame', 'height', 'width'],
    coords=dict(
        frame=np.arange(ns.shape[0]),
        height=np.arange(ns.shape[1]),
        width=np.arange(ns.shape[2])))

In [None]:
# Convert each template with bom.natural_scene_image_to_screen, resize it to
# dim_mon, rename the new core dims back to height/width, then crop to mask.
temp_ns = xr.apply_ufunc(
    lambda im: resize(bom.natural_scene_image_to_screen(im), dim_mon, mode='reflect'),
    temp_ns,
    input_core_dims=[['height', 'width']], 
    output_core_dims=[['height_new', 'width_new']],
    vectorize=True,
    dask='parallelized',
    output_dtypes=[temp_ns.dtype]
)
temp_ns = temp_ns.rename(dict(height_new='height', width_new='width'))
temp_ns = temp_ns.assign_coords(
    height=np.arange(dim_mon[0]), width=np.arange(dim_mon[1]))
temp_ns = temp_ns.where(mask, drop=True).rename('natural_scenes')

In [None]:
# browse the converted templates frame by frame
hv_ns = hv.Dataset(temp_ns, kdims=['frame', 'height', 'width'])
ns_plot = hv_ns.to(hv.Image, kdims=['width', 'height'])
regrid(ns_plot).opts(plot=dict(height=dim_mon[0], width=dim_mon[1]))

In [None]:
from imageio import imwrite
ns = ns.astype(np.uint8)
for id_im, img in enumerate(ns):
    imwrite(os.path.join("./stim_temp/natural", str(id_im) + ".png"), img)

In [None]:
temp_ns.to_netcdf("./stim_temp/ns_ds4x.nc")

## get locally sparse noise templates

In [None]:
cont_df = pd.DataFrame(boc.get_experiment_containers())
# Restrict to the first container's experiments. The first positional
# parameter of get_ophys_experiments is ``file_name``, so pass the id by keyword.
exps_df = pd.DataFrame(boc.get_ophys_experiments(
    experiment_container_ids=[int(cont_df['id'].iloc[0])]))
# session C contains the locally-sparse-noise stimulus template
oexp = boc.get_ophys_experiment_data(
    exps_df[exps_df['session_type'] == 'three_session_C']['id'].iloc[0])

In [None]:
temps_lsn = oexp.get_stimulus_template('locally_sparse_noise')
temps_lsn = xr.DataArray(
    temps_lsn, dims=['frame', 'height', 'width'],
    coords=dict(
        frame=np.arange(temps_lsn.shape[0]),
        height=np.arange(temps_lsn.shape[1]),
        width=np.arange(temps_lsn.shape[2])))
# Lazily convert each frame with bom.lsn_image_to_screen and resize it to
# dim_mon, in dask chunks of 500 frames. output_sizes declares the lengths of
# the new core dims for dask.
temps_lsn = xr.apply_ufunc(
    lambda im: resize(bom.lsn_image_to_screen(im), dim_mon, mode='reflect'),
    temps_lsn.chunk(dict(frame=500)),
    input_core_dims=[['height', 'width']],
    output_core_dims=[['height_new', 'width_new']],
    vectorize=True,
    output_dtypes=[temps_lsn.dtype],
    dask = 'parallelized',
    output_sizes=dict(height_new = dim_mon[0], width_new = dim_mon[1])
)
temps_lsn = temps_lsn.rename(dict(height_new='height', width_new='width'))
temps_lsn = temps_lsn.assign_coords(
    height=np.arange(mask.sizes['height']), width=np.arange(mask.sizes['width']))
# crop to the screen mask
temps_lsn = temps_lsn.where(mask, drop=True)
temps_lsn = temps_lsn.rename('locally_sparse_noise')

In [None]:
%%time
# run the lazy dask graph built above
with ProgressBar():
    temps_lsn = temps_lsn.compute()

In [None]:
hv_temps_lsn = hv.Dataset(temps_lsn, kdims=['height', 'width', 'frame'])
temps_lsn_plot = regrid(hv_temps_lsn.to(hv.Image, kdims=['width', 'height']))
temps_lsn_plot.opts(plot=dict(height=dim_mon[0], width=dim_mon[1]))

In [None]:
%%time
temps_lsn.to_netcdf('stim_temp/lsn_ds4x.nc', engine='netcdf4')

## get static gratings

In [None]:
tuning_df = ds.stim_tables['static_gratings']
grat_keys = ['phase', 'spatial_frequency', 'orientation']
# NOTE(review): the other sections add 'stim_type' only to copies of
# ds.stim_tables, so confirm this column exists in the raw table.
# Drop rows that lack a grating parameter (e.g. blank sweeps). The previous
# dropna(axis='columns') would drop a whole parameter column if any row had
# a NaN, breaking the lookups below.
tuning_df_grat = tuning_df[tuning_df['stim_type'] == 'static_gratings'].dropna(subset=grat_keys)
grat_list = []
# Iterate the parameter values in sorted order, so the 'gratings' axis follows
# the lexicographic (phase, spatial_frequency, orientation) order that the
# tuning table is sorted into later.
for ph, sfreq, ori in itt.product(*(np.sort(tuning_df_grat[k].unique())
                                    for k in grat_keys)):
    cur_grat = da.delayed(bom.grating_to_screen)(ph, sfreq, ori)
    cur_grat = da.delayed(resize)(cur_grat, dim_mon, mode='reflect')
    cur_grat = da.delayed(xr.DataArray)(cur_grat,
                            dims=['height', 'width'],
                            coords=dict(height=np.arange(mask.sizes['height']),
                                        width=np.arange(mask.sizes['width'])))
    cur_grat = cur_grat.assign_coords(phase=ph, spatial_frequency=sfreq, orientation=ori)
    grat_list.append(cur_grat)

In [None]:
%%time
with ProgressBar():
    grats, = da.compute(grat_list)
gratings = xr.concat(grats, dim='gratings')
gratings = gratings.assign_coords(gratings=np.arange(gratings.sizes['gratings']))
gratings = gratings.where(mask, drop=True).rename('static_gratings')

In [None]:
hv_gratings = hv.Dataset(gratings)
gratings_plot = regrid(hv_gratings.to(hv.Image, kdims=['width', 'height']))
gratings_plot.opts(plot=dict(height=dim_mon[0], width=dim_mon[1]))

In [None]:
%%time
gratings.to_netcdf('stim_temp/sgrat_ds4x.nc', engine='netcdf4')

# compute preferred stimulus
## compute static gratings

In [None]:
# Reload the grating templates and the baseline-subtracted grating responses.
# NOTE(review): ``autoclose`` is deprecated in newer xarray releases.
gratings = xr.open_dataarray('stim_temp/sgrat_ds4x.nc', autoclose=True)
tuning_df_grat = pd.read_hdf(os.path.join(exp_name, "tuning_grt.h5"), "tuning")
# label the template axis by its (phase, spatial_frequency, orientation)
gratings = gratings.set_index(gratings=['phase', 'spatial_frequency', 'orientation']).load()

In [None]:
%%time
# Mean baseline-subtracted response per unit and grating, across repeats.
# Only 'fr' is averaged: the table also holds the string column 'gratings',
# which pandas >= 2 refuses to average.
tuning_df_grat = (
    tuning_df_grat
    .groupby(['probe', 'unit_id', 'orientation', 'phase', 'spatial_frequency'])['fr']
    .mean()
    .reset_index())

In [None]:
sgrat_tu_list = []
tuning_df_grat = tuning_df_grat.sort_values(['phase', 'spatial_frequency', 'orientation'])
grat_index = gratings.indexes['gratings']
for uid, unit_df in tuning_df_grat.groupby(['probe', 'unit_id']):
    # Align this unit's responses to the template axis by label, not by
    # position, so a different ordering of 'gratings' cannot silently pair
    # responses with the wrong images.
    unit_tuning = unit_df.set_index(list(grat_index.names))['fr']
    if not grat_index.isin(unit_tuning.index).all():
        raise ValueError(
            "unit {} has no response for some gratings".format(uid))
    tuning = xr.DataArray(unit_tuning.reindex(grat_index).values, dims=['gratings'])
    tuning = tuning.assign_coords(gratings = gratings.coords['gratings'])
    # response-weighted sum of the templates = this unit's preferred image
    sgrat_tu = da.delayed(gratings).dot(tuning, 'gratings')
    sgrat_tu = sgrat_tu.assign_coords(probe=uid[0], unit_id=uid[1])
    sgrat_tu_list.append(sgrat_tu)
sgrat_tu = da.delayed(xr.concat)(sgrat_tu_list, dim='unit')
sgrat_tu = sgrat_tu.assign_coords(unit=da.delayed(np.arange)(sgrat_tu.sizes['unit']))

In [None]:
%%time
with ProgressBar():
    sgrat_tu = sgrat_tu.compute()
    sgrat_tu = xr.apply_ufunc(
        normalize,
        sgrat_tu.chunk(dict(unit=100)),
        input_core_dims=[['height', 'width']],
        output_core_dims=[['height', 'width']],
        vectorize=True,
        dask='parallelized',
        output_dtypes=[sgrat_tu.dtype]
    )
    sgrat_tu = sgrat_tu.compute().rename('static_gratings_tuning')

In [None]:
%%opts Image [colorbar=True]
hv_sgrat_tu = hv.Dataset(sgrat_tu)
sgrat_tu_plot = regrid(hv_sgrat_tu.to(hv.Image, kdims=['width', 'height']))
sgrat_tu_plot.opts(plot=dict(height=dim_mon[0], width=dim_mon[1]))

In [None]:
%%time
sgrat_tu.to_netcdf(os.path.join(exp_name, "pref_sgrat.nc"))

## compute preferred natural images

In [None]:
# Reload the natural-scene templates and responses. Negative frame ids
# (presumably blank/grey presentations) have no template and are dropped.
ns = xr.open_dataarray('stim_temp/ns_ds4x.nc', autoclose=True).load()
tuning_df_ns = pd.read_hdf(os.path.join(exp_name, "tuning_ns.h5"), "tuning")
tuning_df_ns = tuning_df_ns[tuning_df_ns['frame'] >= 0]

In [None]:
%%time
# mean baseline-subtracted response per unit and image, across repeats
tuning_df_ns = tuning_df_ns.groupby(
    ['probe', 'unit_id', 'frame']).mean()
tuning_df_ns = tuning_df_ns.reset_index().drop('occur', axis=1)

In [None]:
ns_tu_list = []
for uid, unit_df in tuning_df_ns.groupby(['probe', 'unit_id']):
    # Response per image, labelled by 'frame' so the dot product below is
    # aligned with the templates on the shared 'frame' coordinate.
    tuning = unit_df['fr'].values
    tuning = xr.DataArray(
        tuning, dims=['frame'],
        coords=dict(frame=unit_df['frame'].values))
    # response-weighted sum of the templates = this unit's preferred image
    ns_tu = da.delayed(ns).dot(tuning, 'frame')
    ns_tu = ns_tu.assign_coords(probe=uid[0], unit_id=uid[1])
    ns_tu_list.append(ns_tu)
ns_tu = da.delayed(xr.concat)(ns_tu_list, dim='unit')
ns_tu = ns_tu.assign_coords(
    unit=da.delayed(np.arange)(ns_tu.sizes['unit']))

In [None]:
%%time
with ProgressBar():
    # Evaluate the weighted sums, then min-max normalise each unit's image.
    ns_tu = ns_tu.compute()
    ns_tu = xr.apply_ufunc(
        normalize,
        ns_tu.chunk(dict(unit=100)),
        input_core_dims=[['height', 'width']],
        output_core_dims=[['height', 'width']],
        vectorize=True,
        dask='parallelized',
        output_dtypes=[ns_tu.dtype]
    )
    ns_tu = ns_tu.compute().rename('natural_images_tuning')

In [None]:
hv_ns_tu = hv.Dataset(ns_tu, kdims=['height', 'width', 'unit'])
ns_tu_plot = regrid(hv_ns_tu.to(hv.Image, kdims=['width', 'height']))
ns_tu_plot.opts(plot=dict(height=dim_mon[0], width=dim_mon[1]))

In [None]:
%%time
ns_tu.to_netcdf(os.path.join(exp_name, "pref_ns.nc"))

# get neuronal images
## get natural neuronal images

In [None]:
# Reload the per-unit preferred images lazily (dask chunks of 100x100 pixels),
# index units by (probe, unit_id), and set NaN pixels to 0.
sgrat_tu = xr.open_dataarray(
    os.path.join(exp_name, "pref_sgrat.nc"),
    autoclose=True, chunks=dict(height=100, width=100))
ns_tu = xr.open_dataarray(
    os.path.join(exp_name, "pref_ns.nc"),
    autoclose=True, chunks=dict(height=100, width=100))
sgrat_tu = sgrat_tu.set_index(unit=['probe', 'unit_id']).fillna(0)
ns_tu = ns_tu.set_index(unit=['probe', 'unit_id']).fillna(0)
tuning_ns = pd.read_hdf(os.path.join(exp_name, "tuning_ns.h5"), "tuning")
# Response table -> (frame, occur, unit) array. Units with any missing
# response are dropped.
firing = tuning_ns.set_index(['probe', 'unit_id', 'frame', 'occur']).to_xarray()['fr']
firing = firing.stack(unit=['probe', 'unit_id']).dropna('unit')

In [None]:
# NOTE(review): without vectorize=True, norm_sum receives the whole array
# ('unit' moved last) and divides by the sum over *all* elements, not per
# (frame, occur) population vector -- confirm which was intended.
firing_norm = xr.apply_ufunc(
    norm_sum, firing,
    input_core_dims=[['unit']], output_core_dims=[['unit']],
    dask='parallelized', output_dtypes=[firing.dtype])
# neuronal image = response-weighted sum of the units' preferred images
nis_from_gr = firing_norm.dot(sgrat_tu, 'unit')
nis_from_ns = firing_norm.dot(ns_tu, 'unit')

In [None]:
%%time
with ProgressBar():
    nis_from_gr = nis_from_gr.compute()
    nis_from_ns = nis_from_ns.compute()

In [None]:
# blank out pixels outside the screen mask
nis_from_gr = nis_from_gr.where(mask)
nis_from_ns = nis_from_ns.where(mask)

In [None]:
%%opts Image {+framewise}
hv_nis_gr = hv.Dataset(nis_from_gr.rename("nis"))
nis_gr_plot = regrid(hv_nis_gr.to(hv.Image, kdims=['width', 'height']))
nis_gr_plot.opts(plot=dict(height=dim_mon[0], width=dim_mon[1]))

In [None]:
%%opts Image {+framewise}
hv_nis_ns = hv.Dataset(nis_from_ns.rename("nis"))
nis_ns_plot = regrid(hv_nis_ns.to(hv.Image, kdims=['width', 'height']))
nis_ns_plot.opts(plot=dict(height=dim_mon[0], width=dim_mon[1]))

In [None]:
nis_from_gr.to_netcdf(os.path.join(exp_name, "nis_from_gr.nc"), engine='netcdf4')
nis_from_ns.to_netcdf(os.path.join(exp_name, "nis_from_ns.nc"), engine='netcdf4')

## get grating neuronal images

In [None]:
# Same pipeline as the natural-scene neuronal images, driven by the grating
# responses instead.
sgrat_tu = xr.open_dataarray(
    os.path.join(exp_name, "pref_sgrat.nc"),
    autoclose=True, chunks=dict(height=100, width=100))
ns_tu = xr.open_dataarray(
    os.path.join(exp_name, "pref_ns.nc"),
    autoclose=True, chunks=dict(height=100, width=100))
sgrat_tu = sgrat_tu.set_index(unit=['probe', 'unit_id']).fillna(0)
ns_tu = ns_tu.set_index(unit=['probe', 'unit_id']).fillna(0)
tuning_gr = pd.read_hdf(os.path.join(exp_name, "tuning_grt.h5"), "tuning")
firing = tuning_gr.set_index(
    ['probe', 'unit_id', 'orientation', 'phase', 'spatial_frequency', 'occur'])
firing = firing.to_xarray()['fr']
# Collapse into 'unit' and 'gratings' axes. Drop only units/gratings that
# are entirely NaN (the full parameter grid may contain unused combinations).
firing = firing.stack(
    unit=['probe', 'unit_id'],
    gratings=['orientation', 'phase', 'spatial_frequency'])
firing = firing.dropna('unit', 'all').dropna('gratings', 'all')

In [None]:
# NOTE(review): norm_sum divides by the sum over the whole array here too.
firing_norm = xr.apply_ufunc(
    norm_sum, firing,
    input_core_dims=[['unit']], output_core_dims=[['unit']],
    dask='parallelized', output_dtypes=[firing.dtype])
grs_from_gr = firing_norm.dot(sgrat_tu, 'unit')
grs_from_ns = firing_norm.dot(ns_tu, 'unit')

In [None]:
%%time
with ProgressBar():
    grs_from_gr = grs_from_gr.compute()
    grs_from_ns = grs_from_ns.compute()

In [None]:
# mask, then restore the separate orientation/phase/spatial_frequency dims
grs_from_gr = grs_from_gr.where(mask).unstack('gratings')
grs_from_ns = grs_from_ns.where(mask).unstack('gratings')

In [None]:
%%opts Image {+framewise}
hv_grs_gr = hv.Dataset(grs_from_gr.rename("grs"))
grs_gr_plot = regrid(hv_grs_gr.to(hv.Image, kdims=['width', 'height']))
grs_gr_plot.opts(plot=dict(height=dim_mon[0], width=dim_mon[1]))

In [None]:
%%opts Image {+framewise}
hv_grs_ns = hv.Dataset(grs_from_ns.rename("grs"))
grs_ns_plot = regrid(hv_grs_ns.to(hv.Image, kdims=['width', 'height']))
grs_ns_plot.opts(plot=dict(height=dim_mon[0], width=dim_mon[1]))

In [None]:
grs_from_gr.to_netcdf(os.path.join(exp_name, "grs_from_gr.nc"), engine='netcdf4')
grs_from_ns.to_netcdf(os.path.join(exp_name, "grs_from_ns.nc"), engine='netcdf4')

# compare neuronal images
## aggregate across `occur`

In [None]:
%%time
chk = dict(height=50, width=50)
grs_from_gr = xr.open_dataarray(os.path.join(exp_name, "grs_from_gr.nc"), chunks=chk).persist()
nis_from_gr = xr.open_dataarray(os.path.join(exp_name, "nis_from_gr.nc"), chunks=chk).persist()
grs_from_ns = xr.open_dataarray(os.path.join(exp_name, "grs_from_ns.nc"), chunks=chk).persist()
nis_from_ns = xr.open_dataarray(os.path.join(exp_name, "nis_from_ns.nc"), chunks=chk).persist()

In [None]:
%%time
with np.errstate(invalid='ignore'), da.config.set(scheduler='threads'), ProgressBar():
    grs_from_gr_mean = grs_from_gr.mean('occur').compute()
    nis_from_gr_mean = nis_from_gr.mean('occur').compute()
    grs_from_ns_mean = grs_from_ns.mean('occur').compute()
    nis_from_ns_mean = nis_from_ns.mean('occur').compute()

In [None]:
%%time
grs_from_gr_mean.to_netcdf(os.path.join(exp_name, "grs_from_gr_mean.nc"))
nis_from_gr_mean.to_netcdf(os.path.join(exp_name, "nis_from_gr_mean.nc"))
grs_from_ns_mean.to_netcdf(os.path.join(exp_name, "grs_from_ns_mean.nc"))
nis_from_ns_mean.to_netcdf(os.path.join(exp_name, "nis_from_ns_mean.nc"))

## visualize

In [None]:
# Side by side: mean images reconstructed from grating-derived preferred
# images, for grating trials (left) and natural-scene trials (right).
opts = dict(plot=dict(height=dim_mon[0], width=dim_mon[1]))
hv_grs_from_gr_mean = hv.Dataset(grs_from_gr_mean.rename("grs")).to(hv.Image, kdims=['width', 'height'])
hv_nis_from_gr_mean = hv.Dataset(nis_from_gr_mean.rename("nis")).to(hv.Image, kdims=['width', 'height'])
regrid(hv_grs_from_gr_mean).opts(**opts) + regrid(hv_nis_from_gr_mean).opts(**opts)

In [None]:
# the same comparison using natural-scene-derived preferred images
opts = dict(plot=dict(height=dim_mon[0], width=dim_mon[1]))
hv_grs_from_ns_mean = hv.Dataset(grs_from_ns_mean.rename("grs")).to(hv.Image, kdims=['width', 'height'])
hv_nis_from_ns_mean = hv.Dataset(nis_from_ns_mean.rename("nis")).to(hv.Image, kdims=['width', 'height'])
regrid(hv_grs_from_ns_mean).opts(**opts) + regrid(hv_nis_from_ns_mean).opts(**opts)

## export

In [None]:
grs_from_gr_mean = xr.open_dataarray(os.path.join(exp_name, "grs_from_gr_mean.nc"), autoclose=True)
grs_from_ns_mean = xr.open_dataarray(os.path.join(exp_name, "grs_from_ns_mean.nc"), autoclose=True)
nis_from_gr_mean = xr.open_dataarray(os.path.join(exp_name, "nis_from_gr_mean.nc"), autoclose=True)
nis_from_ns_mean = xr.open_dataarray(os.path.join(exp_name, "nis_from_ns_mean.nc"), autoclose=True)

In [None]:
# Put the image index first so `.values` iterates one (height, width) image
# at a time. NOTE(review): the PNGs are later named only by position along
# this stacked axis; the parameter labels are not saved alongside them.
grs_from_gr_mean = grs_from_gr_mean.stack(gratings=['phase', 'orientation', 'spatial_frequency']).transpose('gratings', 'height', 'width')
grs_from_ns_mean = grs_from_ns_mean.stack(gratings=['phase', 'orientation', 'spatial_frequency']).transpose('gratings', 'height', 'width')
nis_from_gr_mean = nis_from_gr_mean.transpose('frame', 'height', 'width')
nis_from_ns_mean = nis_from_ns_mean.transpose('frame', 'height', 'width')
# min-max normalise each image to [0, 255] ...
grs_from_gr_mean = xr.apply_ufunc(
    normalize, grs_from_gr_mean, input_core_dims=[['height', 'width']],
    output_core_dims=[['height', 'width']], vectorize=True) * 255
grs_from_ns_mean = xr.apply_ufunc(
    normalize, grs_from_ns_mean, input_core_dims=[['height', 'width']],
    output_core_dims=[['height', 'width']], vectorize=True) * 255
nis_from_gr_mean = xr.apply_ufunc(
    normalize, nis_from_gr_mean, input_core_dims=[['height', 'width']],
    output_core_dims=[['height', 'width']], vectorize=True) * 255
nis_from_ns_mean = xr.apply_ufunc(
    normalize, nis_from_ns_mean, input_core_dims=[['height', 'width']],
    output_core_dims=[['height', 'width']], vectorize=True) *255
# ... and convert to 8-bit, with masked (NaN) pixels set to 0
grs_from_gr_mean = grs_from_gr_mean.fillna(0).astype(np.uint8)
grs_from_ns_mean = grs_from_ns_mean.fillna(0).astype(np.uint8)
nis_from_gr_mean = nis_from_gr_mean.fillna(0).astype(np.uint8)
nis_from_ns_mean = nis_from_ns_mean.fillna(0).astype(np.uint8)

In [None]:
from imageio import imwrite


def _export_png_stack(images, *subdirs):
    """Write each slice along the first axis of `images.values` as `<index>.png`.

    Files are written to ``<exp_name>/neuronal_images/<subdirs...>/``. The directory
    is created up front if it is missing. The old approach retried after *any*
    OSError, which hid the real write error behind a misleading "directory exists"
    failure.
    """
    out_dir = os.path.join(exp_name, "neuronal_images", *subdirs)
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    for id_im, im in enumerate(images.values):
        imwrite(os.path.join(out_dir, str(id_im) + '.png'), im)


_export_png_stack(grs_from_gr_mean, "from_grt", "gratings")
_export_png_stack(grs_from_ns_mean, "from_ns", "gratings")
_export_png_stack(nis_from_gr_mean, "from_grt", "natural_scenes")
_export_png_stack(nis_from_ns_mean, "from_ns", "natural_scenes")

# analysis of firing rates
## subtract baseline from firing

In [None]:
baseline = pd.read_hdf(os.path.join(exp_name, "baseline.h5"))
# Keep the '(-0.2, 0)' window of the natural_scenes stimulus. `.assign` builds a new
# frame instead of writing into the filtered result.
baseline = (
    baseline
    .query("window == '(-0.2, 0)' and stim_type == 'natural_scenes'")
    .assign(unit_id=lambda df: df['unit_id'].astype(int)))
base_fr = baseline.set_index(['trial', 'probe', 'unit_id'])['fr'].to_xarray()
# Attach each trial's (frame, occur) labels to base_fr as trial-indexed coordinates.
base_tr = baseline[['trial', 'frame', 'occur']].drop_duplicates()
coords_fr = xr.DataArray(
    base_tr['frame'].values, dims=['trial'],
    coords=dict(trial=base_tr['trial'].values))
coords_occ = xr.DataArray(
    base_tr['occur'].values, dims=['trial'],
    coords=dict(trial=base_tr['trial'].values))
base_fr = base_fr.assign_coords(frame=coords_fr, occur=coords_occ)

In [None]:
# Firing rates from firing.nc, with the stacked 'unit' index split back into
# separate unit_id / probe dimensions.
stim_fr = (
    xr.open_dataarray(os.path.join(exp_name, "firing.nc"))
    .set_index(unit=['unit_id', 'probe'])
    .unstack('unit'))

In [None]:
%%time
sub_fr = stim_fr - base_fr.mean('trial')
sub_fr = sub_fr.stack(unit=['unit_id', 'probe']).dropna('unit')
sub_fr = sub_fr.reset_index('unit').assign_coords(unit=np.arange(sub_fr.sizes['unit']))

In [None]:
%%time
# Persist the baseline-subtracted rates; the PCA section below reloads this file.
sub_fr.to_netcdf(os.path.join(exp_name, "sub_firing.nc"))

## PCA of raster

In [6]:
# Reload the baseline-subtracted rates written above; unit_id/probe come back as
# plain coordinates on 'unit' because the index was reset before saving.
sub_fr = xr.open_dataarray(os.path.join(exp_name, "sub_firing.nc"))

In [7]:
# Rebuild the (unit_id, probe) index and split it into two dimensions.
sub_fr = sub_fr.set_index(unit=['unit_id', 'probe'])
sub_fr = sub_fr.unstack('unit')

In [8]:
# Flatten (unit_id, probe, time_bin) into one 'features' axis, drop any feature that
# has a NaN anywhere, and call the trial axis 'samples' for the scaler/PCA.
sub_fr_stk = (
    sub_fr
    .stack(features=['unit_id', 'probe', 'time_bin'])
    .dropna('features')
    .rename(dict(trial='samples')))

In [9]:
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
# Per-feature standardization (zero mean, unit variance) -- scikit-learn's defaults spelled out.
scl = StandardScaler(copy=True, with_mean=True, with_std=True)

In [10]:
%%time
sub_fr_stk_norm = xr.apply_ufunc(
    scl.fit_transform, sub_fr_stk,
    input_core_dims=[['samples', 'features']],
    output_core_dims=[['samples', 'features']])
sub_fr_stk_norm = sub_fr_stk_norm.transpose('samples', 'features')

CPU times: user 3.91 s, sys: 3.66 s, total: 7.58 s
Wall time: 7.57 s


In [11]:
%%time
pca = PCA(n_components=100)
pca.fit(sub_fr_stk_norm.values)
sub_fr_stk_trans = xr.apply_ufunc(
    pca.transform, sub_fr_stk_norm,
    input_core_dims=[['samples', 'features']],
    output_core_dims=[['samples', 'PC']])
sub_fr_stk_trans = sub_fr_stk_trans.assign_coords(
    PC='PC' + np.char.array(np.arange(sub_fr_stk_trans.sizes['PC'])))

CPU times: user 1min 54s, sys: 17.9 s, total: 2min 12s
Wall time: 21.7 s


In [12]:
# One bar per principal component: fraction of variance it explains.
hvobj = hv.Bars(
    pca.explained_variance_ratio_, kdims='components', vdims='var_explained',
    label="PCA: Variance Explained").opts(plot=dict(height=500, width=1200, xrotation=90))
_save_to_svg(hvobj, os.path.join(exp_name, "imgs", "pca_variance_explained"))
hvobj



In [13]:
pd_fr_trans = sub_fr_stk_trans.rename(
    dict(samples='trial')).to_pandas().reset_index()
# Stimulus labels (frame, occur) per trial, keyed by the trial *coordinate* instead of
# row position, so the merge stays correct even if trial labels are not 0..n-1.
# Assumes frame/occur are 1-D coordinates along 'trial' -- TODO confirm.
pre_df = pd.DataFrame(
    {'frame': sub_fr.coords['frame'].to_pandas(),
     'occur': sub_fr.coords['occur'].to_pandas()},
    columns=['frame', 'occur'])
pre_df = pre_df.rename_axis('trial').reset_index()
pd_fr_trans = pd_fr_trans.merge(pre_df, on='trial', validate='one_to_one')

In [14]:
# Keep trials strictly between 0 and 6000 (combined with `&`; with `|` the mask is always
# true and nothing is filtered), then keep only frames -1..49.
pd_fr_trans_sub = pd_fr_trans[(pd_fr_trans['trial'] > 0) & (pd_fr_trans['trial'] < 6000)]
pd_fr_trans_sub = pd_fr_trans_sub[pd_fr_trans_sub['frame'].isin(np.arange(-1, 50))]

In [15]:
%%capture
opts = dict(
    plot=dict(height=350, width=400, color_index=2, colorbar=True, tools=['hover']),
    style=dict(cmap='Category20b', size=2.5))
hvobj = (
    hv.Scatter(pd_fr_trans_sub, kdims=['PC0'], vdims=['PC1', 'frame'], label="colored by frame").opts(**opts) + \
    hv.Scatter(pd_fr_trans_sub, kdims=['PC0'], vdims=['PC1', 'occur'], label="colored by occurance").opts(**opts) + \
    hv.Scatter(pd_fr_trans_sub, kdims=['PC0'], vdims=['PC1', 'trial'], label="colored by trial").opts(**opts))
hvobj = hvobj.relabel('Principal Components 0 vs 1')
_save_to_svg(hvobj, os.path.join(exp_name, "imgs", "pc01"))



In [16]:
# Display the PC0-vs-PC1 layout assembled in the preceding %%capture cell.
hvobj

In [17]:
%%capture
opts = dict(
    plot=dict(height=350, width=400, color_index=2, colorbar=True, tools=['hover']),
    style=dict(cmap='Category20b', size=2.5))
hvobj = (
    hv.Scatter(pd_fr_trans_sub, kdims=['PC1'], vdims=['PC2', 'frame'], label="colored by frame").opts(**opts) + \
    hv.Scatter(pd_fr_trans_sub, kdims=['PC1'], vdims=['PC2', 'occur'], label="colored by occurance").opts(**opts) + \
    hv.Scatter(pd_fr_trans_sub, kdims=['PC1'], vdims=['PC2', 'trial'], label="colored by trial").opts(**opts))
hvobj = hvobj.relabel('Principal Components 1 vs 2')
_save_to_svg(hvobj, os.path.join(exp_name, "imgs", "pc12"))



In [None]:
# Display the PC1-vs-PC2 layout assembled in the preceding %%capture cell.
hvobj

In [None]:
%%capture
opts = dict(
    plot=dict(height=350, width=400, color_index=2, colorbar=True, tools=['hover']),
    style=dict(cmap='Category20b', size=2.5))
hvobj = (
    hv.Scatter(pd_fr_trans_sub, kdims=['PC2'], vdims=['PC3', 'frame'], label="colored by frame").opts(**opts) + \
    hv.Scatter(pd_fr_trans_sub, kdims=['PC2'], vdims=['PC3', 'occur'], label="colored by occurance").opts(**opts) + \
    hv.Scatter(pd_fr_trans_sub, kdims=['PC2'], vdims=['PC3', 'trial'], label="colored by trial").opts(**opts))
hvobj = hvobj.relabel('Principal Components 2 vs 3')
_save_to_svg(hvobj, os.path.join(exp_name, "imgs", "pc23"))

In [None]:
# Display the PC2-vs-PC3 layout assembled in the preceding %%capture cell.
hvobj

In [None]:
# Component loadings, labelled with the stacked (unit_id, probe, time_bin) feature
# coordinate so they can be unstacked back into per-unit rows.
PC = xr.DataArray(pca.components_,
                  dims=['PC', 'features'],
                  coords=dict(PC=sub_fr_stk_trans.coords['PC'], features=sub_fr_stk.coords['features']))
# Unstacking rebuilds the full product of levels, so features dropped earlier reappear
# as NaN; discard units that are entirely NaN. `how` is passed by keyword because
# positional use is deprecated in recent xarray.
PC = PC.unstack('features').stack(unit=['probe', 'unit_id']).dropna('unit', how='all')

In [None]:
# Loadings of PC0-PC3 for every second unit; units are first renumbered 0..n-1,
# replacing the stacked (probe, unit_id) index.
PC_plt = (
    PC.rename('PCs')
    .assign_coords(unit=np.arange(PC.sizes['unit']))
    .sel(unit=np.arange(0, PC.sizes['unit'], 2),
         PC=['PC0', 'PC1', 'PC2', 'PC3']))

In [None]:
opts = dict(
    plot=dict(height=500, width=900, colorbar=True),
    style=dict(cmap='RdBu_r'))
# One unit x time_bin image per PC, selectable through the 'PC' key dimension.
hv_PC = hv.Dataset(PC_plt, kdims=['unit', 'time_bin', 'PC'])
hv_PC = hv_PC.to(hv.Image, kdims=['unit', 'time_bin']).opts(**opts)
regrid(hv_PC).opts(**opts)

In [None]:
# Mean (un-standardized) feature values within 4 bins of the 'samples' coordinate,
# reshaped back to per-unit rows. Units that are entirely NaN after unstacking are
# dropped (`how` by keyword; positional use is deprecated in recent xarray). Units are
# renumbered 0..n-1.
fr_pc = sub_fr_stk.groupby_bins('samples', 4).mean('samples')
fr_pc = fr_pc.unstack('features').stack(unit=['probe', 'unit_id']).dropna('unit', how='all')
fr_pc = fr_pc.assign_coords(unit=np.arange(fr_pc.sizes['unit']))

In [29]:
opts = dict(
    plot=dict(height=400, width=500, colorbar=True, symmetric=True),
    style=dict(cmap="RdBu_r"),
    norm=dict(axiswise=True))
# Panels of binned mean features, one per 'samples_bins' bin, value range fixed to +-3.
hv_fr_pc = (
    hv.Dataset(fr_pc.rename('fr_pc'))
    .to(hv.Image, kdims=['time_bin', 'unit'])
    .opts(**opts)
    .layout('samples_bins')
    .redim.range(fr_pc=(-3, 3)))
# Panels of PC0-PC3 loadings, value range fixed to +-0.015.
hv_pc = (
    hv.Dataset(PC_plt.sel(PC=['PC0', 'PC1', 'PC2', 'PC3']).rename('pc'))
    .to(hv.Image, kdims=['time_bin', 'unit'])
    .opts(**opts)
    .layout('PC')
    .redim.range(pc=(-0.015, 0.015)))
hvobj = (hv_fr_pc.opts(**opts) + hv_pc.opts(**opts)).cols(1)
_save_to_svg(hvobj, os.path.join(exp_name, "imgs", "fr_pc"))

In [30]:
# Display the binned-feature / PC-loading panels built in the preceding cell.
hvobj

In [None]:
%%opts Curve [width=500, height=300, show_legend=True]
PC_plt_crv = PC_plt.mean('unit')
hv_PC_plt_crv = hv.Dataset(PC_plt_crv.rename('pc_crv'))
hv_PC_plt_crv.to(hv.Curve, kdims=['time_bin']).overlay('PC')

In [None]:
%%opts Scatter [color_index=2, height=500, width=800] (size=4)
hv_pre = hv.Dataset(pre_df, kdims=['trial'], vdims=['frame', 'occur'])
hv_pre.to(hv.Scatter)

In [31]:
# Time-averaged baseline-subtracted rate per trial and unit. Units that are entirely
# NaN are dropped (`how` by keyword; positional use is deprecated in recent xarray),
# then renumbered 0..n-1.
sub_fr_mean = sub_fr.mean('time_bin')
sub_fr_mean = sub_fr_mean.stack(unit=['probe', 'unit_id']).dropna('unit', how='all')
sub_fr_mean = sub_fr_mean.assign_coords(unit=np.arange(sub_fr_mean.sizes['unit']))
sub_fr_mean = sub_fr_mean.rename('sub_fr')

In [None]:
opts = dict(
    plot=dict(colorbar=True, symmetric=True, height=500, width=1000),
    style=dict(cmap="RdBu_r"))
# Trial x unit raster of the time-averaged rates, value range clamped to +-10.
hv_sub_fr = hv.Image(sub_fr_mean, kdims=['trial', 'unit'])
hvobj = hv_sub_fr.opts(**opts)
hvobj = hvobj.redim.range(sub_fr=(-10, 10))
_save_to_svg(hvobj, os.path.join(exp_name, "imgs", "raw_raster"))

In [None]:
# Regrid the raster for interactive display and re-apply the same options.
raster_regridded = regrid(hvobj)
raster_regridded.opts(**opts)

## PCA of tuning for natural images

In [None]:
tuning_ns = pd.read_hdf(os.path.join(exp_name, "tuning_ns.h5"), key="tuning")
# 'fr' indexed by (probe, unit_id, frame, occur), with probe/unit_id merged into one
# 'unit' dimension. Units with any missing entry are dropped.
tu_ns = (
    tuning_ns
    .set_index(['probe', 'unit_id', 'frame', 'occur'])
    .to_xarray()['fr']
    .stack(unit=['probe', 'unit_id'])
    .dropna('unit'))

In [None]:
# Mean over repeats for each frame, laid out as (frame, unit).
tu_ns_mean = tu_ns.mean('occur').transpose('frame', 'unit')
# Every (frame, occur) repeat as its own row, laid out as (trial, unit).
tu_ns_flat = tu_ns.stack(trial=['frame', 'occur']).transpose('trial', 'unit')

In [None]:
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
# Standardize each unit across all (frame, occur) trials, then fit a 20-component PCA.
# NOTE: this rebinds `scl` and `pca` from the raster-PCA section; earlier cells that
# read `pca` will see this model if they are re-run after this cell.
scl = StandardScaler()
tu_ns_flat_norm = xr.apply_ufunc(
    scl.fit_transform, tu_ns_flat, input_core_dims=[['trial', 'unit']],
    output_core_dims=[['trial', 'unit']])
pca = PCA(n_components=20)
pca.fit(tu_ns_flat_norm.values)
tu_ns_flat_trans = xr.apply_ufunc(
    pca.transform, tu_ns_flat_norm, input_core_dims=[['trial', 'unit']],
    output_core_dims=[['trial', 'PC']])
# Plain-str labels 'PC0', 'PC1', ...; `'PC' + np.char.array(<ints>)` mixes a str with a
# bytes chararray, which is not portable to Python 3 / recent NumPy.
tu_ns_flat_trans = tu_ns_flat_trans.assign_coords(
    PC=['PC{}'.format(i) for i in range(tu_ns_flat_trans.sizes['PC'])])

In [None]:
%%opts Bars [height=500, width=800]
hv.Bars(pca.explained_variance_ratio_, 'components', 'var_explained')

In [None]:
# Scores on the first two PCs per (frame, occur) trial, as a flat table for plotting.
pd_ns_trans = tu_ns_flat_trans.isel(PC=slice(0, 2)).to_pandas().reset_index()

In [None]:
%%opts Scatter [height=600, width=800, color_index=2, colorbar=True] (cmap='Category20b', size=10)
hv.Scatter(pd_ns_trans[pd_ns_trans['frame'].isin(np.arange(-1, 4))], kdims=['PC0'], vdims=['PC1', 'frame'])