In [None]:
# Root directory of the Visual Coding Neuropixels data; the ephys manifest is read from here below.
dpath = '/data/dynamic-brain-workshop/visual_coding_neuropixels'

In [None]:
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import seaborn as sns
import itertools as itt
import holoviews as hv
import dask as da
import dask
import xarray as xr
from IPython.core.debugger import set_trace
from dask.diagnostics import ProgressBar
from holoviews.operation.datashader import regrid, datashade
from __future__ import print_function
hv.extension('bokeh', width=100)

In [None]:
# Experiment manifest for the ephys sessions; its 'nwb_filename' column is used to pick a session.
manifest_file = os.path.join(dpath, 'ephys_manifest.csv')
expt_info_df = pd.read_csv(filepath_or_buffer=manifest_file)
expt_info_df.head()

In [None]:
from swdb_2018_neuropixels.ephys_nwb_adapter import NWB_adapter
from allensdk.core.brain_observatory_cache import BrainObservatoryCache
from allensdk.brain_observatory.stimulus_info import BrainObservatoryMonitor
# Load the Neuropixels session whose units are analysed below; `ds` (unit_df, stim_list,
# spike_times, stim_tables) is required by the tuning computation and was previously
# commented out, so a fresh kernel failed there with NameError.
# NOTE(review): confirm manifest row 6 is the session behind ephys_multi_58_tuning.csv.
ds = NWB_adapter(os.path.join(dpath, expt_info_df['nwb_filename'].iloc[6]))
boc = BrainObservatoryCache(manifest_file='/data/allen-brain-observatory/visual-coding-2p/manifest.json')
bom = BrainObservatoryMonitor()
# Screen size in pixels as (height, width); used as the output size when upsampling
# the locally sparse noise frames to the screen.
dim_mon = (1200, 1920)

In [None]:
def sub_spike(spk, ts, te):
    """Return the spike times of ``spk`` lying strictly inside the open interval (ts, te).

    Parameters
    ----------
    spk : numpy.ndarray (or pandas Series) of spike times
    ts, te : float
        Window start and end, in the same units as ``spk``.
    """
    return spk[(spk > ts) & (spk < te)]


def fr(spk, ts, te, bins):
    """Firing rate, or binned spike counts, of ``spk`` within the window (ts, te).

    Parameters
    ----------
    spk : numpy.ndarray of spike times
    ts, te : float
        Window start and end (``te > ts`` is assumed).
    bins : None, float, int or sequence
        * ``None``  -> return the mean rate ``n_spikes / (te - ts)``.
        * float     -> a bin *width*; the window is split into
          ``ceil((te - ts) / bins)`` equal bins.
        * otherwise -> passed to ``numpy.histogram`` as bin count or edges.

    Returns
    -------
    float if ``bins`` is None, else the ``(counts, edges)`` tuple of ``numpy.histogram``.
    """
    spk_sub = sub_spike(spk, ts, te)
    if bins is None:
        return len(spk_sub) / (te - ts)
    if isinstance(bins, (float, np.floating)):
        # np.histogram rejects a non-integer bin count, so convert the width to an int.
        bins = int(np.ceil((te - ts) / bins))
    # Histogram the windowed spikes so both branches apply the same open-interval selection.
    return np.histogram(spk_sub, bins, range=(ts, te))


def tuning_per_stim(spk, stim):
    """Mean firing rate of ``spk`` over all presentations (rows) of ``stim``.

    ``stim`` must have ``start`` and ``end`` columns.
    """
    return np.mean(stim.apply(
        lambda r: fr(spk, r['start'], r['end'], None), axis=1))


def tuning_per_unit(spk, stim):
    """Mean firing rate of ``spk`` for every distinct stimulus condition in ``stim``.

    Every column other than ``start``/``end`` is treated as a stimulus parameter.

    Returns
    -------
    pandas.DataFrame with the parameter columns (in the table's column order) plus a
    ``tuning`` column, one row per condition.
    """
    # Preserve the table's column order: iterating a set of strings is not reproducible
    # across interpreter sessions, which reordered the group keys / output columns.
    grp_dim = [c for c in stim.columns if c not in ('start', 'end')]
    if not grp_dim:
        # No parameter columns: a single condition spanning every presentation
        # (groupby([]) would raise instead).
        return pd.DataFrame({'tuning': [tuning_per_stim(spk, stim)]})
    return stim.groupby(grp_dim).apply(
        lambda df: tuning_per_stim(spk, df)).rename('tuning').reset_index()

In [None]:
%%time
tuning_list = []
for unit, stim_type in itt.product(
    [r for _, r in ds.unit_df.iterrows()], set(ds.stim_list) - set(['spontaneous'])):
    prob, uid = unit['probe'], unit['unit_id']
    tu = tuning_per_unit(ds.spike_times[prob][uid], ds.stim_tables[stim_type])
    tu['probe'] = prob
    tu['unit_id'] = uid
    tu['stim_type'] = stim_type
    tuning_list.append(tu)
tuning_df = pd.concat(tuning_list, axis=0, ignore_index=True)

In [None]:
# Load the cached per-unit tuning table; the 'Unnamed: 0' column is presumably the index
# written by DataFrame.to_csv and carries no information.
tuning_df = pd.read_csv(
    '/home/ec2-user/SageMaker/phildong/swdb_2018_tools/swdb_2018_tools/deencoder/ephys_multi_58_tuning.csv'
).drop(columns='Unnamed: 0')

# Get stimulus templates
## Get the monitor mask

In [None]:
# Screen mask wrapped as a labelled (height, width) array so it can be aligned with templates.
mask = bom.get_mask()
mask = xr.DataArray(
    mask,
    dims=('height', 'width'),
    coords={'height': np.arange(mask.shape[0]),
            'width': np.arange(mask.shape[1])},
)

## Get locally sparse noise templates

In [None]:
cont_df = pd.DataFrame(boc.get_experiment_containers())
# get_ophys_experiments' first positional parameter is `file_name` (a cache path). The
# container id was passed there, so it was used as a JSON cache file name rather than as
# a filter, and every experiment was returned. Request all experiments explicitly and
# pick the first 'three_session_C' one, which is what the original call effectively did.
# NOTE(review): to restrict to one container, pass
# experiment_container_ids=[cont_df['id'].iloc[0]]; some containers may only have a
# 'three_session_C2' session, so the selection below could then be empty.
exps_df = pd.DataFrame(boc.get_ophys_experiments())
oexp = boc.get_ophys_experiment_data(
    exps_df.loc[exps_df['session_type'] == 'three_session_C', 'id'].iloc[0])

In [None]:
# Display the stimulus epoch table of the selected ophys session.
oexp.get_stimulus_epoch_table()

In [None]:
temps_lsn = oexp.get_stimulus_template('locally_sparse_noise')
temps_lsn = xr.DataArray(
    temps_lsn, dims=['frame', 'height', 'width'],
    coords=dict(
        frame=np.arange(temps_lsn.shape[0]),
        height=np.arange(temps_lsn.shape[1]),
        width=np.arange(temps_lsn.shape[2])))
# Map every noise frame to screen resolution, lazily in chunks of 500 frames. The output
# spatial dims get temporary names because their sizes differ from the input's.
temps_lsn = xr.apply_ufunc(
    bom.lsn_image_to_screen,
    temps_lsn.chunk(dict(frame=500)),
    input_core_dims=[['height', 'width']],
    output_core_dims=[['height_new', 'width_new']],
    vectorize=True,
    output_dtypes=[temps_lsn.dtype],
    dask='parallelized',
    output_sizes=dict(height_new=dim_mon[0], width_new=dim_mon[1]),
)
# Rename to the mask's dim names and give matching pixel coordinates *before* masking.
# Masking first made `where` treat the mask's (height, width) as extra dimensions
# alongside (height_new, width_new) instead of the same screen axes.
temps_lsn = temps_lsn.rename(dict(height_new='height', width_new='width'))
temps_lsn = temps_lsn.assign_coords(height=mask['height'].values, width=mask['width'].values)
# Blank pixels outside the mask; rows/columns that are entirely outside are dropped.
temps_lsn = temps_lsn.where(mask, drop=True)
# Re-index the cropped screen from zero.
temps_lsn = temps_lsn.assign_coords(height=np.arange(temps_lsn.sizes['height']), width=np.arange(temps_lsn.sizes['width']))
temps_lsn = temps_lsn.rename('locally_sparse_noise').to_dataset()

In [None]:
# Writing the dask-backed templates triggers their computation; show progress meanwhile.
with ProgressBar():
    temps_lsn.to_netcdf(
        '/home/ec2-user/SageMaker/phildong/swdb_2018_tools/swdb_2018_tools/deencoder/temps_lsn.nc',
        engine='netcdf4',
    )

## Get static grating templates

In [None]:
tuning_df_grat = tuning_df.loc[tuning_df['stim_type'] == 'static_gratings'].dropna(axis='columns')
grat_list = []
# One lazily rendered full-screen image per (phase, spatial_frequency, orientation)
# combination of the values present in the static-grating tuning table.
grat_params = ('phase', 'spatial_frequency', 'orientation')
for ph, sfreq, ori in itt.product(*(tuning_df_grat[col].unique() for col in grat_params)):
    cur_grat = da.delayed(xr.DataArray)(
        da.delayed(bom.grating_to_screen)(ph, sfreq, ori),
        dims=['height', 'width'],
        coords={'height': np.arange(mask.sizes['height']),
                'width': np.arange(mask.sizes['width'])},
    )
    cur_grat = cur_grat.assign_coords(phase=ph, spatial_frequency=sfreq, orientation=ori)
    grat_list.append(cur_grat)

In [None]:
# Evaluate all delayed grating images in a single dask pass.
with ProgressBar():
    (grats,) = da.compute(grat_list)

In [None]:
# Stack the individual images along a new 'gratings' dimension and persist them.
gratings = xr.concat(grats, dim='gratings')
gratings = gratings.rename('gratings')
gratings.to_netcdf(
    '/home/ec2-user/SageMaker/phildong/swdb_2018_tools/swdb_2018_tools/deencoder/temps_sgrat.nc',
    engine='netcdf4',
)

In [None]:
# HoloViews view of the gratings, indexed by an integer 'gratings' coordinate.
hv_gratings = hv.Dataset(
    gratings.rename('grat')
    .reset_index('gratings')
    .assign_coords(gratings=np.arange(gratings.sizes['gratings']))
)

In [None]:
%%output size=30
regrid(hv_gratings.to(hv.Image, kdims=['width', 'height'])).opts(plot=dict(height=gratings.sizes['height'], width=gratings.sizes['width']))

# Compute preferred stimuli
## Static gratings

In [None]:
tuning_df_grat = tuning_df.loc[tuning_df['stim_type'] == 'static_gratings'].dropna(axis='columns')
gratings = xr.open_dataarray(
    '/home/ec2-user/SageMaker/phildong/swdb_2018_tools/swdb_2018_tools/deencoder/temps_sgrat.nc',
    autoclose=True,
)
# Label each grating by its (phase, spatial_frequency, orientation) triple.
gratings = gratings.set_index(gratings=['phase', 'spatial_frequency', 'orientation'])

In [None]:
def compute_preferred_stimulus(stim, tuning):
    """Tuning-weighted sum of stimulus images, min-max scaled to [0, 1].

    Parameters
    ----------
    stim : numpy.ndarray, shape (..., n_pixels, n_stimuli)
        Stimulus images with the stimulus axis last (as arranged by apply_ufunc below).
    tuning : numpy.ndarray, shape (n_stimuli,)
        Response of one unit to each stimulus, in the same order as ``stim``'s last axis.

    Returns
    -------
    numpy.ndarray, shape (..., n_pixels)
        The scaled weighted sum. If the weighted sum is constant (e.g. all-zero tuning),
        an all-zero array is returned instead of the 0/0 = NaN the scaling would give.
    """
    stim_sum = stim.dot(tuning)
    lo, hi = stim_sum.min(), stim_sum.max()
    span = hi - lo
    if span == 0:
        return np.zeros_like(stim_sum)
    return (stim_sum - lo) / span

In [None]:
pstim_list = []
# Flatten each screen into a single 'spatial' dim. Guarded so that re-running the cell
# does not try to stack dims that no longer exist.
if 'spatial' not in gratings.dims:
    gratings = gratings.stack(spatial=['height', 'width']).astype(np.float32)
tuning_df_grat = tuning_df_grat.sort_values(['phase', 'spatial_frequency', 'orientation'])
tuning_df_grat['tuning'] = tuning_df_grat['tuning'].astype(np.float32)
# (phase, spatial_frequency, orientation) index set on 'gratings' when the file was loaded.
grat_index = gratings.indexes['gratings']
for uid, unit_df in tuning_df_grat.groupby(['probe', 'unit_id']):
    # compute_preferred_stimulus pairs the 'gratings' axis with the tuning vector by
    # position. The grating order comes from unique() (first-appearance order), which is
    # not guaranteed to match the sorted tuning rows, so align by label instead.
    # Conditions missing for this unit become NaN rather than silently shifting.
    tuning = (unit_df.set_index(list(grat_index.names))['tuning']
              .reindex(grat_index)
              .values.astype(np.float32))
    tuning = xr.DataArray(tuning, dims=['g'])
    pstim = da.delayed(xr.apply_ufunc)(
        compute_preferred_stimulus,
        gratings,
        tuning,
        input_core_dims=[['spatial', 'gratings'], ['g']],
        output_core_dims=[['spatial']]
    )
    pstim = pstim.assign_coords(probe=uid[0], unit_id=uid[1])
    pstim_list.append(pstim)
pstims = da.delayed(xr.concat)(pstim_list, dim='unit')
pstims = pstims.unstack('spatial')

In [None]:
%%time
with ProgressBar():
    pstims = pstims.compute()

In [None]:
%%time
pstims.to_netcdf('/home/ec2-user/SageMaker/phildong/swdb_2018_tools/swdb_2018_tools/deencoder/tuning_sgrat.nc')