# ols_results_delay

In [7]:
# General
import sys
import os.path as op
import copy
from time import time
from collections import OrderedDict as od
from glob import glob
import itertools
import warnings
from importlib import reload
# from cluster_helper.cluster import cluster_view

# Scientific
import numpy as np
import pandas as pd
pd.options.display.max_rows = 100
pd.options.display.max_columns = 999

# Stats
import statsmodels.api as sm
from statsmodels.formula.api import ols
from sklearn.utils.fixes import loguniform
import scipy.stats as stats
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import minmax_scale, StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.svm import SVC
from sklearn.model_selection import KFold, GridSearchCV, RandomizedSearchCV
from sklearn.metrics import confusion_matrix
from sklearn.metrics.pairwise import cosine_similarity
import patsy

# Plots
warnings.filterwarnings( 'ignore' )
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
import matplotlib as mpl
from matplotlib.lines import Line2D
import matplotlib.patches as patches
# from mpl_toolkits.mplot3d import Axes3D
# Shared figure style: thin faint grid (off by default), no top/right spines,
# 300-dpi figures, Helvetica text, and TrueType (type 42) fonts in saved PDFs.
# The color cycle is the matplotlib tab10 palette, reordered
# (blue, red, green, orange, ...).
colors = ['1f77b4', 'd62728', '2ca02c', 'ff7f0e', '9467bd',
          '8c564b', 'e377c2', '7f7f7f', 'bcbd22', '17becf']
mpl.rcParams.update({
    # Grid and lines
    'grid.linewidth': 0.1,
    'grid.alpha': 0.75,
    'lines.linewidth': 1,
    'lines.markersize': 3,
    # Ticks
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'xtick.major.width': 0.8,
    'ytick.major.width': 0.8,
    # Axes
    'axes.prop_cycle': mpl.cycler('color', colors),
    'axes.spines.right': False,
    'axes.spines.top': False,
    'axes.formatter.offset_threshold': 2,
    'axes.labelsize': 14,
    'axes.labelpad': 8,
    'axes.titlesize': 16,
    'axes.grid': False,
    'axes.axisbelow': True,
    # Legend
    'legend.loc': 'upper right',
    'legend.fontsize': 14,
    'legend.frameon': False,
    # Figure
    'figure.dpi': 300,
    'figure.titlesize': 16,
    'figure.figsize': (10, 4),
    'figure.subplot.wspace': 0.25,
    'figure.subplot.hspace': 0.25,
    # Fonts and output
    'font.sans-serif': ['Helvetica'],
    'savefig.format': 'pdf',
    'pdf.fonttype': 42,
})

# Personal
sys.path.append('/home1/cjmac/code/general')
sys.path.append('/home1/cjmac/code/projects/manning_replication')
sys.path.append('/home1/cjmac/code/projects')
import data_io as dio
import array_operations as aop
from helper_funcs import *
from eeg_plotting import plot_trace, plot_trace2
from time_cells import spike_sorting, spike_preproc, events_preproc, events_proc, time_bin_analysis, remapping, pop_decoding, time_cell_plots
from goldmine_replay import place_cells

# Font sizes by text role.
font = {'tick': 12, 'label': 14, 'annot': 12, 'fig': 16}

# Colors: entry c of an n-color seaborn palette for each sequential colormap
# (YlOrBr is drawn from n+3 colors), followed by black.
n = 4
c = 2
colors = [sns.color_palette(name, n)[c]
          for name in ('Blues', 'Reds', 'Greens', 'Purples', 'Oranges', 'Greys')]
colors += [sns.color_palette('YlOrBr', n + 3)[c], 'k']
# Diverging colormap: colors[0] -> white -> colors[1], in 501 steps.
cmap = sns.palettes.blend_palette((colors[0], 'w', colors[1]), 501)

# Figure/column widths keyed by layout name (units presumably inches).
colws = od([
    ('1', 6.55), ('2-1/2', 3.15), ('2-1/3', 2.1), ('2-2/3', 4.2),
    ('3', 2.083), ('4', 1.525), ('5', 1.19), ('6', 0.967),
    (1, 2.05), (2, 3.125), (3, 6.45),
    ('nat1w', 3.50394), ('nat2w', 7.20472), ('natl', 9.72441),
])

data_dir = '/home1/cjmac/projects/time_cells'
proj_dir = '/home1/cjmac/projects/time_cells'

In [8]:
# Get sessions.
# A session ID is the filename prefix before the first '-' of each
# '*-Events.pkl' file; the subject ID is the part of it before the first '_'.
sessions = np.unique([op.basename(f).split('-')[0] 
                      for f in glob(op.join(data_dir, 'analysis', 'events', '*-Events.pkl'))])
print('{} subjects, {} sessions'.format(len(np.unique([x.split('_')[0] for x in sessions])), len(sessions)))

1 subjects, 1 sessions


In [9]:
def run_ols_delay_parallel(subj_sess_unit, n_perm=1000):
    """Run the Delay1/Delay2 OLS analysis for one neuron on a cluster worker.

    Intended to be passed to ``view.map``. All imports happen inside the
    function so it is self-contained on a remote engine. Output is written
    by ``time_bin_analysis.run_ols_delay`` (save_output=True,
    overwrite=False), so nothing is returned. Any exception is caught and
    its full traceback is written to a per-neuron log file, so a single
    failing neuron does not abort the whole map.

    Parameters
    ----------
    subj_sess_unit : str
        Neuron identifier, e.g. 'U532_ses0-21-1'.
    n_perm : int, optional
        Number of permutations passed to ``run_ols_delay``. The default of
        1000 matches the value previously hard-coded in the call.

    Returns
    -------
    None
    """
    import sys
    import traceback

    code_dir = '/home1/cjmac/code/projects'
    # Engines may be reused across calls; don't grow sys.path every time.
    if code_dir not in sys.path:
        sys.path.append(code_dir)
    from time_cells import time_bin_analysis

    # NOTE(review): the notebook checks for finished neurons in
    # 'unit_to_behav_10k' but loads results from 'unit_to_behav_1000perm';
    # confirm which directory run_ols_delay writes to for this n_perm.
    try:
        time_bin_analysis.run_ols_delay(subj_sess_unit,
                                        n_perm=n_perm,
                                        alpha=0.05,
                                        save_output=True,
                                        overwrite=False)
    except Exception:
        # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit
        # still propagate; log the full traceback for remote debugging.
        errf = '/home1/cjmac/logs/TryExceptError-run_ols_delay_parallel-{}'.format(subj_sess_unit)
        with open(errf, 'w') as f:
            f.write(traceback.format_exc())

In [10]:
# Get neurons to process.
# Template path of each neuron's saved OLS model-pair output. Neurons whose
# file already exists are skipped when building `neurons` below.
# NOTE(review): this checks 'unit_to_behav_10k', while run_ols_delay_parallel
# passes n_perm=1000 and results are later loaded from
# 'unit_to_behav_1000perm'. Confirm which directory run_ols_delay writes to.
fpath = op.join(proj_dir, 'analysis', 'unit_to_behav_10k', '{}-Delay1_Delay2-ols_model_pairs.pkl')

'8.4.1'

In [None]:
# Neurons in the population whose OLS output file (fpath) does not exist yet.
pop_spikes = pop_decoding.load_pop_spikes()
neurons = [neuron for neuron in pop_spikes.neurons if not op.exists(fpath.format(neuron))]
print('{} neurons to process'.format(len(neurons)))

In [None]:
# Parallel processing
# cluster_view is imported here because the top-of-notebook import is
# commented out; without it this cell raises NameError.
from cluster_helper.cluster import cluster_view

n_ops = len(neurons)
print('Running code for {} operations.\n'.format(n_ops))
if n_ops > 0:
    # Up to 100 single-core SGE jobs; results are saved by the worker itself.
    with cluster_view(scheduler="sge", queue="RAM.q", num_jobs=min(n_ops, 100), cores_per_job=1) as view:
        output = view.map(run_ols_delay_parallel, neurons)

In [None]:
# Run the OLS delay analysis for one neuron (run_ols_delay's default
# arguments) and report the wall-clock time.
subj_sess_unit = 'U532_ses0-21-1' # 'U527_ses0-58-2' # 'U518_ses0-73-1'

start_time = time()

mod_pairs, ols_weights = time_bin_analysis.run_ols_delay(subj_sess_unit)

print('Done in {:.1f}s'.format(time() - start_time))

# Load mod_pairs, ols_weights

In [None]:
# Load processed OLS files.
# Concatenates every per-neuron pickle; pd.concat raises ValueError if the
# glob matches no files.
mod_pairs_globstr = op.join(proj_dir, 'analysis', 'unit_to_behav_1000perm', '*Delay1_Delay2-ols_model_pairs.pkl')
ols_weights_globstr = op.join(proj_dir, 'analysis', 'unit_to_behav_1000perm', '*Delay1_Delay2-ols_weights.pkl')
mod_pairs = pd.concat([dio.open_pickle(f) for f in glob(mod_pairs_globstr)]).reset_index(drop=True)
ols_weights = pd.concat([dio.open_pickle(f) for f in glob(ols_weights_globstr)]).reset_index(drop=True)

# Drop rows.
# In DataFrame.query, `column != [list]` acts as "not in", so this removes
# the listed reduced-model comparison(s).
drop_red = ['full-time,gameState:time']
mod_pairs = mod_pairs.query("(red!={})".format(drop_red)).reset_index(drop=True)

# Add columns.
# Map each reduced model name to the term it tests. _map[x] raises KeyError
# if any other 'red' value remains after the drop above.
_map = {'full-gameState': 'gameState',
        'full-time': 'time',
        'full-gameState:time': 'gameState:time'}
testvar_cat = pd.CategoricalDtype(['gameState', 'time', 'gameState:time'],
                                  ordered=True)
mod_pairs.insert(4, 'testvar', mod_pairs['red'].apply(lambda x: _map[x]))
mod_pairs['testvar'] = mod_pairs['testvar'].astype(testvar_cat)

print('mod_pairs: {}'.format(mod_pairs.shape))
print('ols_weights: {}'.format(ols_weights.shape))

In [None]:
# FDR correct across all neurons.
# For each test variable, correct the permutation p-values ('emp_pval')
# across neurons with two-stage Benjamini-Krieger-Yekutieli FDR
# ('fdr_tsbky'), and also flag p < 0.01.
# sig_cells is built from the column named by sig_col. With sig_col = 'sig'
# it uses the precomputed 'sig' column, not the FDR-corrected 'sig_fdr'.
alpha = 0.05
sig_col = 'sig'

# Initialise with proper bool/float defaults instead of '', so these are
# not mixed string/bool object columns.
mod_pairs['sig01'] = False
mod_pairs['sig_fdr'] = False
mod_pairs['pvals_fdr'] = np.nan
for testvar in mod_pairs['testvar'].unique():
    is_testvar = mod_pairs['testvar'] == testvar
    pvals = mod_pairs.loc[is_testvar, 'emp_pval']
    # multipletests returns (reject, pvals_corrected, ...); call it once.
    sig_fdr, pvals_fdr = sm.stats.multipletests(pvals, alpha, method='fdr_tsbky')[:2]

    mod_pairs.loc[is_testvar, 'sig01'] = pvals < 0.01
    mod_pairs.loc[is_testvar, 'pvals_fdr'] = pvals_fdr
    mod_pairs.loc[is_testvar, 'sig_fdr'] = sig_fdr

# Neuron IDs that are significant (per sig_col) for each test variable.
sig_cells = od([])
for testvar in mod_pairs['testvar'].unique():
    sig_cells[testvar] = mod_pairs.query("(testvar=='{}') & ({}==True)".format(testvar, sig_col))['subj_sess_unit'].tolist()

In [None]:
# For each test variable, a one-sided binomial test of whether the number of
# significant neurons exceeds the 5% expected by chance. The three p-values
# are then Holm-corrected.
# NOTE(review): scipy.stats.binom_test was deprecated in SciPy 1.10 and
# removed in 1.12; on newer SciPy use stats.binomtest(...).pvalue.
testvars = ['gameState', 'time', 'gameState:time']

count_sig = od([])
count_all = od([])
n_cells = mod_pairs['subj_sess_unit'].unique().size
pvals_unc = od([])
for testvar in testvars:
    n_sig = len(sig_cells[testvar])
    binom_p = stats.binom_test(n_sig,
                               n_cells,
                               p=0.05,
                               alternative='greater')
    count_sig[testvar] = n_sig
    count_all[testvar] = n_cells
    pvals_unc[testvar] = binom_p

# Bonferroni-Holm correct
pvals_corr = sm.stats.multipletests(list(pvals_unc.values()), method='holm')[1]

# One row per test variable: count/total (percent), then the uncorrected and
# Holm-corrected p-values, each starred when below alpha.
for iTest, testvar in enumerate(pvals_unc):
    print('{:>19}'.format(testvar),
          '{:>3}/{:>3} {:>6.1%} '.format(count_sig[testvar], count_all[testvar], (1. * count_sig[testvar]) / count_all[testvar]),
          '{:.10f}{:>1}'.format(pvals_unc[testvar], '*' if (pvals_unc[testvar] < alpha) else ''),
          '{:.10f}{:>1}'.format(pvals_corr[iTest], '*' if (pvals_corr[iTest] < alpha) else ''))

In [None]:
# How many time cells per subject?
# Subject ID = the part of subj_sess_unit before the first '_'.
aop.unique([x.split('_')[0] for x in sig_cells['time']])

In [None]:
# How many delay category cells per subject?
aop.unique([x.split('_')[0] for x in sig_cells['gameState']])

In [None]:
# Proportion of neurons flagged by each significance criterion, per test variable.
mod_pairs.groupby('testvar', observed=True).agg({'sig': count_pct,
                                                 'sig01': count_pct,
                                                 'sig_fdr': count_pct})

In [None]:
# Among significant rows: count and z_lr summary (mean/SEM, median/quartiles
# presumably, per the helper names) for each test variable.
(mod_pairs.query("(sig==True)")
          .groupby('testvar', observed=True)
          .agg({'subj_sess_unit': len,
                'z_lr': [mean_sem, median_q]}))

In [None]:
# Direction of the delay-category main effect. Each significant gameState
# neuron contributes its single full-model gameState weight. A positive
# weight means higher firing in Delay1 than Delay2: the notebook
# reconstructs Delay1 as icpt + weight and Delay2 as icpt - weight.
# ols_weights holds rows for several fitted models (other cells filter on
# model=='full'). Restricting to the full model keeps one row per neuron, so
# the count and binomial test are per neuron rather than per model-row.
_df = (ols_weights.query("(subj_sess_unit=={}) & (model=='full') & (factor=='gameState')".format(sig_cells['gameState']))
                  .set_index('subj_sess_unit')['z_weight'])
print(count_pct(_df>0) + ' neurons with a main effect of delay category fire more in Delay1 than Delay2', 
      '(p = {:.6f}, binomial test)'.format(stats.binom_test(np.count_nonzero(_df>0), len(_df), 0.5)))
plt.plot(_df.sort_values().values, marker='.', linewidth=0)

In [None]:
mod_name = 'full'
factor = 'time'
weight_col = 'z_weight'

# For each significant neuron, find its preferred level of `factor`: the level
# whose full-model weight has the largest |weight|, together with that weight.
qry = "(subj_sess_unit=={}) & (model=='{}') & (factor=='{}')"
levels = ols_weights.query(qry.format(sig_cells[factor], mod_name, factor))['level'].drop_duplicates().values

# NOTE: levels[np.argmax(...)] assumes each neuron's rows appear in the same
# level order as `levels`.
# observed=True (as in the later version of this cell): once 'factor' is an
# ordered categorical, groupby would otherwise add rows for unobserved
# (neuron, model, factor) combinations, inflating len(_df) and the binomial
# denominator below.
_df = (ols_weights.query(qry.format(sig_cells[factor], mod_name, factor))
                  .groupby(['subj_sess_unit', 'model', 'factor'], sort=False, observed=True)
                  .agg({weight_col: [lambda x: levels[np.argmax(np.abs(x))], lambda x: np.array(x)[np.argmax(np.abs(x))]]})
                  .reset_index())
_df.columns = ['subj_sess_unit', 'model', 'factor', 'level', weight_col]
_df['level'] = _df['level'].astype(pd.CategoricalDtype(levels, ordered=True)) 

# Sign test: are preferred-level weights more often positive than negative?
print('{} – {}; P = {:.6f}, binomial test'.format(factor,
                                                  count_pct(_df[weight_col]>0),
                                                  stats.binom_test(np.count_nonzero(_df[weight_col]>0), len(_df), 0.5)),
      end='\n'*2)
# Per level (all levels shown, including empty ones): neuron count, mean/SEM
# of |weight|, and share with a positive weight.
(_df.groupby('level', observed=False)
    .agg({'subj_sess_unit' : lambda x: '{:>2}/{:>2} ({:.1%})'.format(len(x), len(_df), len(x)/len(_df)),
          weight_col       : [lambda x: mean_sem(np.abs(x)), lambda x: count_pct(x>0)]}))

In [None]:
# Are time cells over-represented in each third of the delay?
# x: time-cell counts per third (values appear hard-coded from an earlier
# run); bins: number of the 10 time bins in each third; n: total neurons.
# Chance rate for a third = 0.05 * (bins / 10). The three tests are
# Holm-corrected. Note this rebinds the globals n and alpha.
# NOTE(review): binom_test defaults to a two-sided test, whereas the
# question (and the earlier cell) is one-sided (alternative='greater').
x = [67, 17, 15]
bins = [3, 4, 3]
n = 457
alpha = 0.05

pvals = [stats.binom_test(x[i], n, 0.05 * (bins[i]/10)) for i in range(len(x))]
pvals_corr = sm.stats.multipletests(pvals, alpha, method='holm')[1]

print(pvals, pvals_corr, pvals_corr<alpha)

In [None]:
# Give ols_weights' factor and level columns an explicit model order, so sorts
# and groupbys follow it. Values outside these lists become NaN after astype.
factors = ['icpt', 'gameState', 'time', 'gameState:time']
ols_weights['factor'] = ols_weights['factor'].astype(pd.CategoricalDtype(factors, ordered=True))

levels = ['icpt', 'gameState_Delay1'] + ['time_{}'.format(iTime) for iTime in range(1, 11)] + ['gameState_Delay1:time_{}'.format(iTime) for iTime in range(1, 11)]
ols_weights['level'] = ols_weights['level'].astype(pd.CategoricalDtype(levels, ordered=True))

In [None]:
def _merge_dfs(factor=None, mod_pairs=None, df=None):
    """Left-join each neuron's preferred level of `factor` onto mod_pairs.

    Parameters
    ----------
    factor : str, optional
        Test variable, e.g. 'time'. Defaults to the notebook-level `factor`.
    mod_pairs : pd.DataFrame, optional
        Model-comparison table with 'subj_sess_unit', 'full' and 'testvar'
        columns. Defaults to the notebook-level `mod_pairs`.
    df : pd.DataFrame, optional
        Per-neuron table with 'subj_sess_unit', 'model', 'factor', 'level'
        and 'z_weight' columns. Defaults to the notebook-level `_df`.

    Returns
    -------
    pd.DataFrame
        A new frame: `mod_pairs` plus 'level-<factor>' and
        'level-<factor>-z_weight' columns (NaN for neurons missing from `df`).
        Columns with those names from an earlier run are replaced. The input
        frame is not modified.
    """
    namespace = globals()
    if factor is None:
        factor = namespace['factor']
    if mod_pairs is None:
        mod_pairs = namespace['mod_pairs']
    if df is None:
        df = namespace['_df']

    level = 'level-{}'.format(factor)
    level_weight = level + '-z_weight'
    # Drop columns left by a previous run, so re-running the cell does not
    # create suffixed duplicates. errors='ignore' covers the first run.
    base = mod_pairs.drop(columns=[level, level_weight], errors='ignore')

    # df is keyed by (subj_sess_unit, model, factor); mod_pairs names the
    # model column 'full' and the factor column 'testvar'.
    return (pd.merge(base, df.rename(columns={'model': 'full', 'factor': 'testvar'}),
                     on=['subj_sess_unit', 'full', 'testvar'], how='left')
            .rename(columns={'level': level, 'z_weight': level_weight}))

# How many cells preferentially encode each level
# of a given variable of interest?
# Preferred level = the level with the largest |z_weight| in the chosen model.
# levels[np.argmax(...)] assumes every neuron's rows follow the order of
# `levels` (first-seen order in the query result).
mod_name = 'full'
factor = 'time'
merge_mod_pairs = True

qry = "(subj_sess_unit=={}) & (model=='{}') & (factor=='{}')"
levels = ols_weights.query(qry.format(sig_cells[factor], mod_name, factor))['level'].drop_duplicates().values

_df = (ols_weights.query(qry.format(sig_cells[factor], mod_name, factor))
                  .groupby(['subj_sess_unit', 'model', 'factor'], sort=False, observed=True)
                  .agg({'z_weight': [lambda x: levels[np.argmax(np.abs(x))], lambda x: np.array(x)[np.argmax(np.abs(x))]]})
                  .reset_index())
_df.columns = ['subj_sess_unit', 'model', 'factor', 'level', 'z_weight']
_df['level'] = _df['level'].astype(pd.CategoricalDtype(levels, ordered=True))
# Optionally attach the preferred level/weight to mod_pairs (rebinds mod_pairs).
if merge_mod_pairs:
    mod_pairs = _merge_dfs()
    
print('{} – {}; P = {:.6f}, binomial test'.format(factor,
                                                  count_pct(_df['z_weight']>0),
                                                  stats.binom_test(np.count_nonzero(_df['z_weight']>0), len(_df), 0.5)),
      end='\n'*2)
# Per observed level: neuron count, mean/SEM of |z_weight|, share positive.
(_df.groupby('level', observed=True)
    .agg({'subj_sess_unit' : lambda x: '{:>2}/{:>2} ({:.1%})'.format(len(x), len(_df), len(x)/len(_df)),
          'z_weight'       : [lambda x: mean_sem(np.abs(x)), lambda x: count_pct(x>0)]}))

In [None]:
# Top 30 full-model time cells ranked by z_lr (descending).
mod_pairs.query("(full=='full') & (testvar=='time') & (sig==True)").sort_values('z_lr', ascending=0).iloc[:30]

In [None]:
# Show one neuron's model comparisons and its full-model weights.
subj_sess_unit = 'U532_ses0-10-1'

display(mod_pairs.query("(subj_sess_unit=='{}')".format(subj_sess_unit)))
display(ols_weights.query("(model=='full') & (subj_sess_unit=='{}')".format(subj_sess_unit))
                   .sort_values(['factor', 'level'], ascending=[True, True]))

# Time fields

In [None]:
# Pick up edits to time_bin_analysis without restarting the kernel.
reload(time_bin_analysis)

In [None]:
# Find time fields for all significant time cells.
# Calls bootstrap_time_fields for each neuron in sig_cells['time'] with both
# delay states. save_output=1 / overwrite=0 are passed through; the returned
# `output` is overwritten on each iteration. thresh=1.96 is presumably a z
# threshold — confirm against bootstrap_time_fields.
save_output = 1
overwrite = 0
verbose = False
smooth = 1
n_perm = 1000
thresh = 1.96
max_skips = 1

start_time = time()

for subj_sess_unit in sig_cells['time']:
    output = time_bin_analysis.bootstrap_time_fields(subj_sess_unit,
                                                     game_states=['Delay1', 'Delay2'],
                                                     smooth=smooth,
                                                     n_perm=n_perm,
                                                     thresh=thresh,
                                                     max_skips=max_skips,
                                                     save_output=save_output,
                                                     overwrite=overwrite,
                                                     verbose=verbose)
    
print('Done in {:.1f}s'.format(time() - start_time))

In [None]:
# Load time field results.
# Uses the smooth1 outputs (smooth0 alternative left commented). pd.concat
# raises ValueError if no files match.
files = glob(op.join(proj_dir, 'analysis', 'time_fields', '*Delay1_Delay2-smooth1*.pkl'))
# files = glob(op.join(proj_dir, 'analysis', 'time_fields', '*Delay1_Delay2-smooth0*.pkl'))

time_fields = pd.concat([dio.open_pickle(f)['time_fields'] for f in files]).reset_index(drop=True)

print('time_fields:', time_fields.shape)

In [None]:
time_fields.query("(gameState=='Delay1Delay2')").shape

In [None]:
# Share of time cells with any, a positive, or a negative Delay1Delay2 time field.
print('{} time cells with a time_field'.format(count_pct(np.isin(sig_cells['time'],
          time_fields.query("(gameState=='Delay1Delay2')")['subj_sess_unit'].unique()))),
      '{} time cells with a positive time_field'.format(count_pct(np.isin(sig_cells['time'],
          time_fields.query("(gameState=='Delay1Delay2') & (field_type=='pos')")['subj_sess_unit'].unique()))),
      '{} time cells with a negative time_field'.format(count_pct(np.isin(sig_cells['time'],
          time_fields.query("(gameState=='Delay1Delay2') & (field_type=='neg')")['subj_sess_unit'].unique()))),
      sep='\n')

In [None]:
# Inspect one neuron's model comparisons.
subj_sess_unit = 'U518_ses0-16-1'

mod_pairs.query("(subj_sess_unit=='{}')".format(subj_sess_unit))

In [None]:
# ...and its time fields.
subj_sess_unit = 'U518_ses0-16-1'

time_fields.query("(subj_sess_unit=='{}')".format(subj_sess_unit))

In [None]:
# Per neuron with positive Delay1Delay2 fields: max peak_z, and lists of
# field peaks and sizes, sorted by peak_z.
(time_fields.query("(gameState=='Delay1Delay2') & (field_type=='pos')")
            .groupby('subj_sess_unit')
            .agg({'peak_z': np.max,
                  'field_peak': list,
                  'field_size': list})
            .sort_values('peak_z'))

In [None]:
# How many positive time fields does each time cell have?
# Histogram of fields-per-neuron, plus mean/SEM and median summaries.
_df = aop.unique(time_fields.query("(gameState=='Delay1Delay2') & (field_type=='pos')").groupby('subj_sess_unit').size(),
                 sort=False).reset_index()
_df.columns = ['fields', 'count']
print('{} time fields/neuron'.format(mean_sem(time_fields.query("(gameState=='Delay1Delay2') & (field_type=='pos')").groupby('subj_sess_unit').size())))
print('{} time fields/neuron'.format(median_q(time_fields.query("(gameState=='Delay1Delay2') & (field_type=='pos')").groupby('subj_sess_unit').size())))
print('{} neurons have only 1 time field'
      .format(np.sum(time_fields.query("(gameState=='Delay1Delay2') & (field_type=='pos')").groupby('subj_sess_unit').size()==1)))
      
plt.close()
sns.barplot(x='fields', y='count', data=_df, color='k')

In [None]:
# Field-size distribution for neurons with exactly one positive field.
# field_size / 2 is reported as seconds elsewhere (presumably 500-ms bins).
idx = np.where(time_fields.query("(gameState=='Delay1Delay2') & (field_type=='pos')").groupby('subj_sess_unit').size()==1)[0]
_neurons = time_fields.query("(gameState=='Delay1Delay2') & (field_type=='pos')").groupby('subj_sess_unit').size().index[idx].tolist()
_df = (time_fields.query("(gameState=='Delay1Delay2') & (field_type=='pos') & (subj_sess_unit=={})".format(_neurons))
                  .groupby('field_size').size()
                  .reset_index()
                  .rename(columns={0: 'count'}))
plt.close()
sns.barplot(x='field_size', y='count', data=_df, color='k')

print(mean_sd(time_fields.query("(gameState=='Delay1Delay2') & (field_type=='pos') & (subj_sess_unit=={})".format(_neurons))['field_size'] / 2))
print(median_q(time_fields.query("(gameState=='Delay1Delay2') & (field_type=='pos') & (subj_sess_unit=={})".format(_neurons))['field_size'] / 2))

In [None]:
# Over what duration is each time field firing above the mean?
# Histogram of positive-field sizes; sizes / 2 are printed in seconds.
_df = time_fields.query("(gameState=='Delay1Delay2') & (field_type=='pos')").groupby('field_size').size().reset_index()
_df.columns = ['field_size', 'count']
print('{}s field size'.format(mean_sem(time_fields.query("(gameState=='Delay1Delay2') & (field_type=='pos')")['field_size'] / 2)))
print('{}s field size'.format(median_q(time_fields.query("(gameState=='Delay1Delay2') & (field_type=='pos')")['field_size'] / 2)))

plt.close()
sns.barplot(x='field_size', y='count', data=_df, color='k')

In [None]:
# Where do positive fields peak? Histogram of peak time-bin index.
_df = time_fields.query("(gameState=='Delay1Delay2') & (field_type=='pos')").groupby('field_peak').size().reset_index()
_df.columns = ['field_peak', 'count']
print('{} peak firing index'.format(mean_sem(time_fields.query("(gameState=='Delay1Delay2') & (field_type=='pos')")['field_peak'])))
print('{} peak firing index'.format(median_q(time_fields.query("(gameState=='Delay1Delay2') & (field_type=='pos')")['field_peak'])))

plt.close()
sns.barplot(x='field_peak', y='count', data=_df, color='#7ccaa5')

In [None]:
# Same histogram for negative fields.
_df = time_fields.query("(gameState=='Delay1Delay2') & (field_type=='neg')").groupby('field_peak').size().reset_index()
_df.columns = ['field_peak', 'count']
print('{} peak firing index'.format(mean_sem(time_fields.query("(gameState=='Delay1Delay2') & (field_type=='neg')")['field_peak'])))
print('{} peak firing index'.format(median_q(time_fields.query("(gameState=='Delay1Delay2') & (field_type=='neg')")['field_peak'])))

plt.close()
sns.barplot(x='field_peak', y='count', data=_df, color='#e55749')

In [None]:
time_fields.head()

In [None]:
# Does peak position covary with field size or field strength? Pearson
# correlations, then an OLS of peak on size, mean_z and their interaction.
print('field_peak ~ field_size:', 
      stats.pearsonr(time_fields.query("(gameState=='Delay1Delay2') & (field_type=='pos')")['field_peak'],
                     time_fields.query("(gameState=='Delay1Delay2') & (field_type=='pos')")['field_size']))
print('field_peak ~ peak_z:', 
      stats.pearsonr(time_fields.query("(gameState=='Delay1Delay2') & (field_type=='pos')")['field_peak'],
                     time_fields.query("(gameState=='Delay1Delay2') & (field_type=='pos')")['peak_z']))
print('field_peak ~ mean_z:', 
      stats.pearsonr(time_fields.query("(gameState=='Delay1Delay2') & (field_type=='pos')")['field_peak'],
                     time_fields.query("(gameState=='Delay1Delay2') & (field_type=='pos')")['mean_z']))

formula = 'field_peak ~ field_size * mean_z'
mod = ols(formula, data=time_fields.query("(gameState=='Delay1Delay2') & (field_type=='pos')"))
fit = mod.fit()
print(fit.summary())

In [None]:
timebin = 9
# Sanity check of the deviation-coded OLS parameterisation: compare observed
# mean firing with the model reconstruction. Delay1 adds the gameState (and
# interaction) weights to the intercept; Delay2 subtracts them.
# NOTE(review): ols_weights holds rows for every neuron and model, so
# _params['icpt'] etc. return one value per matching row rather than a
# scalar. Filter to a single subj_sess_unit and model=='full' first.
# `event_spikes` and `neuron` are not defined in the visible cells.
_params = ols_weights.set_index('level')['weight']

delay1_spikes = event_spikes.get_spike_mat(neuron, 'Delay1', column='time_step').values.astype(float)
delay2_spikes = event_spikes.get_spike_mat(neuron, 'Delay2', column='time_step').values.astype(float)
print(((delay1_spikes + delay2_spikes) / 2).mean(),
        delay1_spikes.mean(),
        delay2_spikes.mean())
# Use the level name directly (as below) rather than Xycols['gameState'][0];
# Xycols is only assigned by a later cell.
print(_params['icpt'],
      _params['icpt'] + _params['gameState_Delay1'],
      _params['icpt'] - _params['gameState_Delay1'])
print('Delay1, timebin {}:'.format(timebin+1), delay1_spikes[:, timebin].mean())
print('Delay2, timebin {}:'.format(timebin+1), delay2_spikes[:, timebin].mean())
print('Delay1, timebin {}:'.format(timebin+1),
      _params['icpt'] + _params['gameState_Delay1'] + _params['time_{}'.format(timebin+1)] + _params['gameState_Delay1:time_{}'.format(timebin+1)])
# Model prediction for Delay2: the gameState and interaction terms enter with
# a minus sign, so this line is labelled Delay2.
print('Delay2, timebin {}:'.format(timebin+1),
      _params['icpt'] - _params['gameState_Delay1'] + _params['time_{}'.format(timebin+1)] - _params['gameState_Delay1:time_{}'.format(timebin+1)])

In [None]:
# Scratch checks on an OLS model set. NOTE(review): ols_mods, Xycols, Xy,
# event_spikes and neuron are not defined in the visible cells.
# Count the fitted parameters per term group (excluding the dependent 'neuron').
params = ols_mods['full'].fit().params
n_params = 0
for k, v in Xycols.items():
    print(k, len(v))
    if k != 'neuron':
        n_params += len(v)
print('{} parameters total'.format(n_params))

In [None]:
# Log-likelihood of each model.
for mod_name, mod in ols_mods.items():
    print(mod_name, mod.fit().llf)

In [None]:
# Observed means vs. reconstruction from the deviation-coded parameters
# (Delay1 = Intercept + gameState, Delay2 = Intercept - gameState).
# NOTE(review): this uses Xycols['gameState:time'], but get_ols_time_formulas
# names that group 'gameState_x_time'; it raises KeyError with that Xycols.
delay1_spikes = event_spikes.get_spike_mat(neuron, 'Delay1', column='time_step').values.astype(float)
delay2_spikes = event_spikes.get_spike_mat(neuron, 'Delay2', column='time_step').values.astype(float)
print(((delay1_spikes + delay2_spikes) / 2).mean(),
        delay1_spikes.mean(),
        delay2_spikes.mean())
print(params['Intercept'],
      params['Intercept'] + params[Xycols['gameState'][0]],
      params['Intercept'] - params[Xycols['gameState'][0]])
print(delay1_spikes[:, 3].mean(),
      params['Intercept'] + params[Xycols['gameState'][0]] + params[Xycols['time'][3]] + params[Xycols['gameState:time'][3]])
print(delay2_spikes[:, 3].mean(),
      params['Intercept'] - params[Xycols['gameState'][0]] + params[Xycols['time'][3]] - params[Xycols['gameState:time'][3]])

In [None]:
{k:v for k, v in Xycols.items() if k!='neuron'}

In [None]:
for mod_name, mod in ols_mods.items():
    print(mod_name, mod.fit().llf)

In [None]:
Xy['gameState_Delay1'].value_counts()

In [None]:
ols_mods['full'].fit().params[Xycols['gameState_x_time']]

In [None]:
print(ols_mods['full'].fit().summary())

In [None]:
# Hand arithmetic on coefficient values copied from a model summary.
5.7303+0.8273+0.3606

In [None]:
for mod, fit in mod_fits.items():
    print(mod, fit.fit().llf)

In [None]:
def get_ols_time_formulas(neuron, full_mod):
    """Define model formulas for single-unit to behavior comparisons.

    Rebuilds the design matrix of ``full_mod`` as a DataFrame whose patsy
    deviation-coded column names are shortened (e.g.
    'C(gameState, Sum)[S.Delay1]' -> 'gameState_Delay1',
    'C(time_step, Sum)[S.3]' -> 'time_3'). It then writes one full and three
    reduced OLS formulas over those columns.

    Parameters
    ----------
    neuron : str
        e.g. '5-2' would be channel 5, unit 2. Used as the name of the
        dependent-variable column (quoted with Q() in the formulas).
    full_mod
        Object with ``endog``, ``exog`` and ``exog_names`` attributes (such
        as an unfitted statsmodels ``ols`` model). Its exog must include an
        'Intercept' column.

    Returns
    -------
    Xy : pd.DataFrame
        The dependent column ``neuron`` followed by the renamed predictors.
    Xycols : OrderedDict
        Renamed column names grouped as 'neuron', 'gameState', 'time' and
        'gameState_x_time'.
    formulas : OrderedDict
        Formula strings keyed 'full', 'subgs', 'subtime' and 'subgsxtime'.
    """
    # Get the expanded predictor matrix of deviation-coded parameters,
    # and add on the dependent column.
    Xy = pd.concat((pd.Series(full_mod.endog, name=neuron),
                    pd.DataFrame(full_mod.exog, columns=full_mod.exog_names)),
                   axis=1)
    Xy.drop(columns=['Intercept'], inplace=True)

    # Group the patsy column names by the model term they belong to.
    Xycols_old = od([
        ('neuron', [neuron]),
        ('gameState', [col for col in Xy.columns
                       if 'gameState' in col and ':' not in col]),
        ('time', [col for col in Xy.columns
                  if 'time_step' in col and ':' not in col]),
        ('gameState_x_time', [col for col in Xy.columns
                              if 'time' in col and 'gameState' in col and ':' in col]),
    ])

    def _shorten(name):
        # Strip patsy's 'C(var, Sum)[S.level]' wrapper down to 'var_level'.
        for old, new in (('C(gameState, Sum)[S.', 'gameState_'),
                         ('C(time_step, Sum)[S.', 'time_'),
                         (']', '')):
            name = name.replace(old, new)
        return name

    # Shorten the names grouped above. This previously read an undefined
    # module-level `Xcols` instead of the names computed here.
    Xycols = od((key, cols if key == 'neuron' else [_shorten(col) for col in cols])
                for key, cols in Xycols_old.items())

    rename_map = {}
    for key in Xycols_old:
        rename_map.update(zip(Xycols_old[key], Xycols[key]))
    Xy.rename(columns=rename_map, inplace=True)

    # Define formulas.
    formulas = od([])
    formulas['full']             = "Q('{}') ~ 1 + {} + {} + {}".format(neuron, ' + '.join(Xycols['gameState']), ' + '.join(Xycols['time']), ' + '.join(Xycols['gameState_x_time']))
    formulas['subgs']       = "Q('{}') ~ 1      + {} + {}".format(neuron,                                  ' + '.join(Xycols['time']), ' + '.join(Xycols['gameState_x_time']))
    formulas['subtime']     = "Q('{}') ~ 1 + {}      + {}".format(neuron, ' + '.join(Xycols['gameState']),                             ' + '.join(Xycols['gameState_x_time']))
    formulas['subgsxtime']  = "Q('{}') ~ 1 + {} + {}     ".format(neuron, ' + '.join(Xycols['gameState']), ' + '.join(Xycols['time'])                                        )

    return Xy, Xycols, formulas

In [None]:
# Rebuild the design matrix and formulas from an existing full model.
# NOTE(review): mod_fits is reassigned by the next cell, so this depends on
# an earlier kernel state; `neuron` is not defined in the visible cells.
Xy, Xycols, formulas = get_ols_time_formulas(neuron, mod_fits['full'])

In [None]:
# Fit the model.
# Builds (but does not fit; .fit() is commented out) one ols model per formula.
mod_fits = od([])
for mod, formula in formulas.items():
    mod_fits[mod] = ols(formula, data=Xy)#.fit()

In [None]:
for mod, fit in mod_fits.items():
    print(mod, fit.fit().llf)

In [None]:
mod_fits['full'].fit().summary()

In [None]:
mod_fits['full'].fit().summary()

In [None]:
# Hand arithmetic: intercept ± gameState coefficient copied from the summary.
5.7303 + 0.8273, 5.7303 - 0.8273

In [None]:
delay1_spikes = event_spikes.get_spike_mat(neuron, 'Delay1', column='time_step').values.astype(float)
delay2_spikes = event_spikes.get_spike_mat(neuron, 'Delay2', column='time_step').values.astype(float)
delay1_spikes.mean(), delay2_spikes.mean(), ((delay1_spikes + delay2_spikes) / 2).mean()

In [None]:
5.7303 + 0.8273 + 0.3606 + -0.4636, delay1_spikes[:, 0].mean()

In [None]:
np.mean((delay1_spikes + delay2_spikes) / 2)

In [None]:
# NOTE(review): indexes row 1 of the spike matrix (not a column) and halves
# its mean — confirm intent.
delay1_spikes[1].mean() / 2

In [None]:
#event_spikes.get_spike_mat(neuron, 'Delay1', column='time_step').values.astype(float)

In [None]:
# mod_fits holds unfitted ols models (their .fit() is commented out where
# they are built), so fit before requesting the summary.
mod_fits['full'].fit().summary()

In [None]:
# NOTE(review): get_ols_time_formulas produces no 'red' key (only 'full',
# 'subgs', 'subtime', 'subgsxtime'), and its models are unfitted; confirm
# which mod_fits this expects.
mod_fits['red'].summary()

In [None]:
1

In [None]:
# Plot one neuron's spikes over navigation paths in the maze for each game
# state. Note this rebinds the notebook-level `font` (without 'annot').
# NOTE(review): `subj_sess` is not defined in the visible cells; fig.show()
# runs after the loop, so only the last game state's figure is shown.
neuron = '8-2'
dpi = 300
font = {'tick': 6, 'label': 7, 'fig': 9}
base_color = 'w'
game_states = ['Encoding']
spikes_when_moving = False

for game_state in game_states:    
    fig, ax = time_cell_plots.plot_firing_maze(subj_sess, 
                                               neuron, 
                                               game_state, 
                                               font=font, 
                                               base_color=base_color,
                                               only_show_spikes_when_moving=spikes_when_moving,
                                               nav_lw=0.12,
                                               nav_color='#296eb4',
                                               nav_alpha=0.5,
                                               spike_marker='.',
                                               spike_fill_color='#e10600',
                                               spike_edge_color='#e10600',
                                               spike_alpha=0.75,
                                               spike_markersize=1.5,
                                               spike_mew=0,
                                               dpi=dpi)
    
fig.show()

In [None]:
def run_place_cells_parallel(subj_sess_neuron):
    """Fit the place-cell OLS analysis for one neuron on a cluster worker.

    Intended to be passed to ``view.map``, so every import happens inside the
    function. The observed fit is compared against ``nperm`` refits with
    ``circshift_frs=True`` (presumably circularly shifted firing rates) via
    ``place_cells.get_ols_sig_place``. The result is pickled to
    ``<proj_dir>/ols_place_cells/<subj_sess_unit>-<gameState>.pkl`` unless
    that file already exists. Any exception is caught and its full traceback
    is written to a per-neuron log file, so one failing neuron does not abort
    the whole map.

    Parameters
    ----------
    subj_sess_neuron : str
        '<subj_sess>-<neuron>', e.g. 'U518_ses0-16-1'. Everything after the
        first '-' is the neuron ID.

    Returns
    -------
    None
    """
    import sys
    import os.path as op
    import traceback

    def _add_path(path):
        # Engines may be reused across calls; don't grow sys.path every time.
        if path not in sys.path:
            sys.path.append(path)

    import pandas as pd
    _add_path('/home1/cjmac/code/general')
    import data_io as dio
    _add_path('/home1/cjmac/code/projects')
    from time_cells import time_bin_analysis
    _add_path('/home1/cjmac/code/goldmine_replay')
    import place_cells

    save_output = True
    overwrite = False
    game_state = 'Encoding'
    nperm = 1000
    alpha = 0.05
    zthresh = 2
    data_dir = '/data7/goldmine'
    # NOTE(review): the notebook checks for / loads these outputs under its
    # own proj_dir ('/home1/cjmac/projects/time_cells'), not this one.
    proj_dir = '/home1/cjmac/projects/goldmine_replay'

    try:
        # partition keeps the original semantics: neuron = '' when there is no '-'.
        subj_sess, _, neuron = subj_sess_neuron.partition('-')

        # Load event_spikes.
        event_spikes = time_bin_analysis.load_event_spikes(subj_sess, proj_dir=data_dir)

        obs_weights = place_cells.get_ols_params_place(neuron, event_spikes, game_state=game_state)
        null_weights = pd.concat([place_cells.get_ols_params_place(neuron, event_spikes, game_state=game_state, circshift_frs=True)
                                  for _ in range(nperm)])
        place_fits = place_cells.get_ols_sig_place(obs_weights, null_weights, alpha=alpha, zthresh=zthresh)

        if save_output:
            filename = op.join(proj_dir, 'ols_place_cells', '{}-{}.pkl'.format(place_fits.iloc[0]['subj_sess_unit'], place_fits.iloc[0]['gameState']))
            if overwrite or not op.exists(filename):
                dio.save_pickle(place_fits, filename, verbose=False)
    except Exception:
        # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit
        # still propagate; log the full traceback for remote debugging.
        errf = '/home1/cjmac/logs/TryExceptError-run_place_cells_parallel-{}'.format(subj_sess_neuron)
        with open(errf, 'w') as f:
            f.write(traceback.format_exc())

In [None]:
# Get neurons to process.
# NOTE(review): run_place_cells_parallel saves under
# '/home1/cjmac/projects/goldmine_replay', but this checks the notebook's
# proj_dir; confirm both point at the same 'ols_place_cells' directory.
fpath = op.join(proj_dir, 'ols_place_cells', '{}-Encoding.pkl')
pop_spikes = pop_decoding.load_pop_spikes()
neurons = [neuron for neuron in pop_spikes.neurons if not op.exists(fpath.format(neuron))]
print('{} neurons to process'.format(len(neurons)))

# Parallel processing
# cluster_view is imported here because the top-of-notebook import is
# commented out; without it this cell raises NameError.
from cluster_helper.cluster import cluster_view

n_ops = len(neurons)
print('Running code for {} operations.\n'.format(n_ops))
if n_ops > 0:
    # Up to 200 single-core SGE jobs; results are saved by the worker itself.
    with cluster_view(scheduler="sge", queue="RAM.q", num_jobs=min(n_ops, 200), cores_per_job=1) as view:
        output = view.map(run_place_cells_parallel, neurons)

In [None]:
# Load place cells.
files = glob(op.join(proj_dir, 'ols_place_cells', '*-Encoding.pkl'))
place_fits = pd.concat([dio.open_pickle(f) for f in files]).reset_index(drop=True)

# Add columns.
# subj / subj_sess / neuron are parsed from subj_sess_unit
# ('<subj>_<sess>-<chan>-<unit>'). hemroi is looked up from the channel
# number. roi_gen maps hemroi without its first character (presumably the
# hemisphere letter) through spike_preproc.roi_mapping(5). sig_pos marks
# significant cells with at least one place field.
place_fits.insert(0, 'subj', place_fits['subj_sess_unit'].apply(lambda x: x.split('-')[0].split('_')[0]))
place_fits.insert(1, 'subj_sess', place_fits['subj_sess_unit'].apply(lambda x: x.split('-')[0]))
place_fits.insert(2, 'neuron', place_fits['subj_sess_unit'].apply(lambda x: '-'.join(x.split('-')[1:])))
place_fits.insert(3, 'hemroi', place_fits.apply(lambda x: spike_preproc.roi_lookup(x['subj_sess'], x['neuron'].split('-')[0]), axis=1))
roi_map = spike_preproc.roi_mapping(5)
place_fits.insert(4, 'roi_gen', place_fits['hemroi'].apply(lambda x: roi_map[x[1:]]))
place_fits.insert(place_fits.columns.tolist().index('sig')+1, 'sig_pos', (place_fits['sig'] == True) & (place_fits['n_place_fields'] > 0))

# Save indices to all place cells and place cells with a place field outside the base, respectively.
place_all_idx = place_fits.query("(sig_pos==True)").index.values
place_mine_idx = place_all_idx[np.where(place_fits.query("(sig_pos==True)")['place_fields']
                                                  .apply(lambda x: len([place for place in x if place!='Base'])>0))[0]]

print('place_fits:', place_fits.shape)

In [None]:
# Share of significant cells, and of those, the share with 1+ place fields.
# The second ratio raises ZeroDivisionError if no cell is significant.
n_cells = place_fits['subj_sess_unit'].size
n_place_cells = len(place_fits.query("(sig==True)"))
n_pos_place_cells = len(place_fits.query("(sig==True) & (n_place_fields>0)"))

print('{}/{} ({:.1%}) cells are significant for place'.format(n_place_cells, n_cells, n_place_cells/n_cells))
print('{}/{} ({:.1%}) significant cells have 1+ place fields'.format(n_pos_place_cells, n_place_cells, n_pos_place_cells/n_place_cells))
display(place_fits.query("(sig==True)").groupby('n_place_fields').agg({'subj_sess_unit': len}).rename(columns={'subj_sess_unit': 'n'}))

In [None]:
def sum_pct(x):
    """Summarise a boolean-like sequence as '<n_true>/<n_total> (<pct>%)'.

    >>> sum_pct([True, False, True, True])
    '3/4 (75.0%)'
    """
    n_total = len(x)
    n_true = np.sum(x)
    return '{}/{} ({:.1%})'.format(n_true, n_total, n_true / n_total)
    
# Share of cells that are significant place cells with 1+ fields, by region
# and by subject.
display(place_fits.groupby('roi_gen').agg({'sig_pos': sum_pct}))
display(place_fits.groupby('subj').agg({'sig_pos': sum_pct}))

In [None]:
# Count place fields by location. aop.unique appears to return per-value
# counts: its .sum() is used as the total field count and it is divided by
# that total for proportions.
place_fields = aop.unique(np.concatenate(place_fits.query("(sig_pos==True)")['place_fields'].tolist()))
# Total fields held by the subset of place cells with a field outside the base
# (place_mine_idx), counted from that subset. place_fields.sum() is the
# all-cell total.
n_mine_fields = sum(len(fields) for fields in place_fits.loc[place_mine_idx, 'place_fields'])

print('All place cells:')
print('{} place fields across {} place cells'.format(place_fields.sum(), len(place_fits.query("(sig_pos==True)"))), end='\n'*2)

# place_mine_idx selects cells with at least one field other than 'Base'.
print('Subset of place cells with a place field outside the base:')
print('{} place fields across {} place cells'.format(n_mine_fields, len(place_fits.loc[place_mine_idx])))

display(pd.concat([place_fields, place_fields / place_fields.sum()], axis=1).rename(columns={0: 'n', 1: 'prop'}))

In [None]:
full = 'full'
red = 'subplace'

# Wide table of weights (one column per level) for the `full` model.
# The query needs .format(full); without it pandas compares mod to the
# literal string '{}' and returns no rows.
# NOTE(review): in the place-cell cells above, `output` is the result of
# view.map; confirm this expects a DataFrame with a 'mod' column.
output.query("(mod=='{}')".format(full)).pivot(index=['subj_sess_unit', 'gameState', 'mod', 'llf'], columns=['level'], values=['weight'])

In [None]:
# Model summaries for the place-cell reduced ('subplace') and full models.
# NOTE(review): .summary() requires fitted results; the mod_fits built
# earlier in this notebook holds unfitted models.
print(mod_fits['subplace'].summary())

In [None]:
print(mod_fits['subplace'].summary())

In [None]:
print(mod_fits['full'].summary())