In [None]:
# Render figures inline and reload imported modules before each cell runs.
%matplotlib inline
%load_ext autoreload
%autoreload 2


In [None]:
from nips2018.movie import data, parameters, models
from nips2018.movie.analysis import performance
import pandas as pd
import seaborn as sns
import datajoint as dj
import numpy as np
import matplotlib.pyplot as plt
from nips2018.movie.analysis import tuning
from contextlib import contextmanager
import pycircstat as circ
from config import movie_vs_noise_cmap, fix_axis, scan_order, scan_cmap, performance_yticks, performance_ylim, strike
from skimage.transform import resize
from nips2018.movie import oracle

@contextmanager
def silence():
    """Discard everything written to ``sys.stdout`` inside the ``with`` block.

    The previous stdout is restored on exit, also when the block raises.
    """
    # Local imports: ``sys``/``os`` are not imported at the top of this notebook,
    # and the old version failed with NameError on ``sys``.
    import os
    from contextlib import redirect_stdout

    # ``with`` closes the null device (the old version leaked the handle);
    # ``os.devnull`` also works on non-POSIX systems.
    with open(os.devnull, 'w') as devnull, redirect_stdout(devnull):
        yield
        
groups = [21, 22, 23]
# DataJoint restriction string selecting all dataset groups listed above.
group_constr = 'group_id in ({})'.format(','.join(map(str, groups)))


# Difference in preferred orientation and direction between neurons and models

In [None]:
group_ids = [21, 22, 23]
group_constr = 'group_id in ({})'.format(','.join(str(gid) for gid in group_ids))

# Settings shared by both model variants compared in this notebook.
base = {
    'core_hash': '22d11147b37e3947e7d1034cc00d402c',  # 12 x 36
    'seed': 2606,
    'train_hash': '624f62a2ef01d39f6703f3491bb9242b',  # batchsize=8 stop gradient
    'ro_hash': 'bf00321c11e46d68d4a42653a725969d',  # 2 and 4
}
# Movie-trained and noise-trained datasets.
data_constr = 'data_hash in ("5253599d3dceed531841271d6eeba9c5", "6c0290da908317e55c4baf92e379d651")'

# Variant 0: regular shifter/modulator hashes (labelled 'full' further below).
network_configs0 = dj.AndList([
    base,
    'mod_hash in ("4954311aa3bebb347ebf411ab5198890")',
    'shift_hash in ("64add03e1462b7413b59812d446aee9f")',
    data_constr,
    group_constr,
])
# Variant 1: shifter and modulator hashes both bafd7322... (the ablated variant).
network_configs1 = dj.AndList([
    base,
    'mod_hash in ("bafd7322c6e97d25b6299b5d6fe8920b")',
    'shift_hash in ("bafd7322c6e97d25b6299b5d6fe8920b")',
    data_constr,
    group_constr,
])
network_configs = [network_configs0, network_configs1]



In [None]:
# Compute MonetResponse entries restricted to the model configurations above.
response_table = tuning.MonetResponse()
response_table.populate(network_configs)

In [None]:
# Compute model orientation/direction tuning for the same configurations.
ori_table = tuning.MonetOri()
ori_table.populate(network_configs)

In [None]:
# Inspect the orientation (not direction) entries that were computed.
tuning.MonetOri() & network_configs & {'ori_type': 'ori'}

## Orientation tuning

In [None]:
# Measured orientation tuning of reliably tuned units in the selected groups.
ori_key = {'stimulus_type': "stimulus.Monet2", 'ori_type': 'ori', 'spike_method': 5,
           'segmentation_method': 3, 'ori_version': 3}
rel_cell = (tuning.Ori.Cell() & ori_key
            & 'selectivity > .2 and r2 > 0.005'
            & (data.MovieMultiDataset.Member() & group_constr))
df_cell = pd.DataFrame(rel_cell.fetch())

# Model-predicted preferred angle, labelled by the dataset the model was trained on.
train_label = 'IF(data_hash = "5253599d3dceed531841271d6eeba9c5", "movies", "noise")'
rel_model = tuning.MonetOri.Cell().proj(model_angle='angle', train_data=train_label) & network_configs
df_model = pd.DataFrame(rel_model.fetch())

# Models whose shifter AND modulator hashes are both bafd7322... are the ablated variant.
ablated_hash = 'bafd7322c6e97d25b6299b5d6fe8920b'
df_model['model'] = [
    strike('shifter/modulator') if (shift == ablated_hash and mod == ablated_hash) else 'full'
    for shift, mod in zip(df_model['shift_hash'], df_model['mod_hash'])
]

merge_key = ['animal_id', 'session', 'scan_idx', 'unit_id', 'ori_type']
merged = df_cell.merge(df_model, on=merge_key)
df = merged.drop([col for col in merged.columns if 'hash' in col], axis=1)
df.head(10)

In [None]:
# Scan label '<animal_id>-<session>-<scan_idx>', used as the hue in the plots.
df['scan'] = ['{}-{}-{}'.format(animal, session, scan_idx)
              for animal, session, scan_idx in zip(df['animal_id'], df['session'], df['scan_idx'])]
# Orientation is axial: compare doubled angles on the circle, then halve the difference.
df[r'$\Delta \phi$'] = circ.cdiff(2 * df.angle, 2 * df.model_angle) / 2


In [None]:
# Letter-value plots of the measured-vs-model preferred-orientation difference:
# x = model variant, hue = scan, one column per training set.
sns.set_context('paper', font_scale=1.2)

sns.set_palette(scan_cmap)
# NOTE(review): sns.factorplot was renamed sns.catplot in newer seaborn — update if seaborn is upgraded.
g = sns.factorplot("model",  r'$\Delta \phi$', hue='scan', data=df, kind='lv', 
                   col='train_data', legend=False, order=[strike('shifter/modulator'), "full"])
g.fig.set_dpi(150)
# Ticks at multiples of pi/4 over [-pi/2, pi/2], the range of the halved circular difference.
g.axes[0,0].set_yticks([-np.pi/2,-np.pi/4,0,np.pi/4,np.pi/2])
g.axes[0,0].set_yticklabels([r'$-\frac{\pi}{2}$', r'$-\frac{\pi}{4}$', '0', r'$\frac{\pi}{4}$', r'$\frac{\pi}{2}$'])
# g.set_xlabels("")
g.axes[0,0].set_ylabel(r'$\Delta$ preferred orientation')
# Dotted horizontal grid behind the data in both panels.
g.axes[0,0].yaxis.grid(linestyle=':', zorder=-20, lw=1)
g.axes[0,1].yaxis.grid(linestyle=':', zorder=-20, lw=1)
# Scan legend only on the right-hand panel.
g.axes[0,1].legend(ncol=1)
g.set_titles("{col_name}")
g.fig.set_size_inches((5,5))
sns.despine(trim=True)
g.fig.subplots_adjust(bottom=.2, left=.125)
g.fig.savefig('figures/delta_ori_lv.png', dpi=200)

### Full model vs. shifter/modulator ablation: spread of $\Delta$ preferred orientation

In [None]:
from scipy.stats import levene

def test(d):
    """Levene test for equal spread of Delta-phi between full and ablated model.

    Parameters
    ----------
    d : pandas.DataFrame
        One (train_data, scan) group of the per-unit table, with one column per
        model label ('full' and ``strike('shifter/modulator')``).

    Returns
    -------
    pandas.Series
        Single entry ``'p-value'`` holding the Levene p-value.
    """
    # Drop missing values per model: a unit without a fit for one model would
    # otherwise make the statistic, and hence the p-value, NaN.
    full = d['full'].dropna()
    ablated = d[strike('shifter/modulator')].dropna()
    return pd.Series({'p-value': levene(full, ablated)[1]})

delta_col = r'$\Delta \phi$'
ablated = strike('shifter/modulator')

# Per training set and scan: std of Delta-phi for each model and the ablated/full ratio.
df2 = df.groupby(['train_data', 'scan', 'model'])[delta_col].std().unstack('model')
df2['ratio'] = df2[ablated] / df2['full']

# Levene p-value per (train_data, scan), computed on the per-unit values.
per_unit = (df.set_index(['train_data', 'scan', 'unit_id', 'model'])[delta_col]
              .unstack('model')
              .reset_index())
df3 = per_unit.groupby(['train_data', 'scan']).apply(test)
df2['p-value'] = df3['p-value']
df2

## Direction tuning

In [None]:
# Inspect the direction-tuning entries (version 3) of the model configurations.
tuning.MonetOri() & network_configs & {'ori_type': 'dir', 'ori_version': 3}

In [None]:
# Measured direction tuning of reliably tuned units in the selected groups.
# (An alternative, unused criterion was '1/(2/selectivity-1) > .1 and r2 > 0.002'.)
dir_key = {'stimulus_type': "stimulus.Monet2", 'ori_type': 'dir', 'spike_method': 5,
           'segmentation_method': 3, 'ori_version': 3}
rel_cell = (tuning.Ori.Cell() & dir_key
            & 'selectivity > .1 and r2 > 0.002'
            & (data.MovieMultiDataset.Member() & group_constr))
df_cell = pd.DataFrame(rel_cell.fetch())

# Model-predicted preferred angle, labelled by the dataset the model was trained on.
train_label = 'IF(data_hash = "5253599d3dceed531841271d6eeba9c5", "movies", "noise")'
rel_model = tuning.MonetOri.Cell().proj(model_angle='angle', train_data=train_label) & network_configs
df_model = pd.DataFrame(rel_model.fetch())

# Models whose shifter AND modulator hashes are both bafd7322... are the ablated variant.
ablated_hash = 'bafd7322c6e97d25b6299b5d6fe8920b'
df_model['model'] = [
    strike('shifter/modulator') if (shift == ablated_hash and mod == ablated_hash) else 'full'
    for shift, mod in zip(df_model['shift_hash'], df_model['mod_hash'])
]

merge_key = ['animal_id', 'session', 'scan_idx', 'unit_id', 'ori_type']
merged = df_cell.merge(df_model, on=merge_key)
df = merged.drop([col for col in merged.columns if 'hash' in col], axis=1)
df.head(10)

In [None]:
# Direction is a full-circle quantity, so the plain circular difference is used.
df[r'$\Delta \phi$'] = circ.cdiff(df.angle, df.model_angle)
# Scan label '<animal_id>-<session>-<scan_idx>'.
df['scan'] = ['{}-{}-{}'.format(animal, session, scan_idx)
              for animal, session, scan_idx in zip(df['animal_id'], df['session'], df['scan_idx'])]
len(df)

In [None]:
# Letter-value plots of the measured-vs-model preferred-direction difference:
# x = model variant, hue = scan, one column per training set.
sns.set_context('paper', font_scale=1.2)
sns.set_palette(scan_cmap)
# NOTE(review): sns.factorplot was renamed sns.catplot in newer seaborn — update if seaborn is upgraded.
g = sns.factorplot("model",  r'$\Delta \phi$', hue='scan', data=df,  kind='lv', col='train_data', legend=False, order=[strike('shifter/modulator'), "full"])
g.fig.set_dpi(150)
# Five ticks over [-pi, pi]; [::2] keeps the matching labels (-pi, -pi/2, 0, pi/2, pi).
g.axes[0,0].set_yticks(np.linspace(-1, 1, 5) * np.pi)
g.axes[0,0].set_yticklabels([r'$-\pi$', r'$-\frac{3\pi}{4}$', r'$-\frac{\pi}{2}$',  r'$-\frac{\pi}{4}$', '0',  r'$\frac{\pi}{4}$', r'$\frac{\pi}{2}$', r'$\frac{3\pi}{4}$', r'$\pi$'][::2])
g.axes[0,0].set_ylabel(r'$\Delta$ preferred direction')
# Dotted horizontal grid behind the data in both panels.
g.axes[0,0].yaxis.grid(linestyle=':', zorder=-20, lw=1)
g.axes[0,1].yaxis.grid(linestyle=':', zorder=-20, lw=1)
# Scan legend only on the right-hand panel.
g.axes[0,1].legend(ncol=1, loc='upper right')
g.set_titles("{col_name}")
g.fig.set_size_inches((5,5))
g.fig.subplots_adjust(bottom=.2, left=.125)

sns.despine(trim=True)
g.fig.savefig('figures/delta_dir_lv.png', dpi=200)

In [None]:
from scipy.stats import levene

def test(d):
    """Levene test for equal spread of Delta-phi between full and ablated model.

    Parameters
    ----------
    d : pandas.DataFrame
        One (train_data, scan) group of the per-unit table, with one column per
        model label ('full' and ``strike('shifter/modulator')``).

    Returns
    -------
    pandas.Series
        Single entry ``'p-value'`` holding the Levene p-value.
    """
    # Drop missing values per model: a unit without a fit for one model would
    # otherwise make the statistic, and hence the p-value, NaN.
    full = d['full'].dropna()
    ablated = d[strike('shifter/modulator')].dropna()
    return pd.Series({'p-value': levene(full, ablated)[1]})

delta_col = r'$\Delta \phi$'
ablated = strike('shifter/modulator')

# Per training set and scan: std of Delta-phi for each model and the ablated/full ratio.
df2 = df.groupby(['train_data', 'scan', 'model'])[delta_col].std().unstack('model')
df2['ratio'] = df2[ablated] / df2['full']

# Levene p-value per (train_data, scan), computed on the per-unit values.
per_unit = (df.set_index(['train_data', 'scan', 'unit_id', 'model'])[delta_col]
              .unstack('model')
              .reset_index())
df3 = per_unit.groupby(['train_data', 'scan']).apply(test)
df2['p-value'] = df3['p-value']
df2

# Plot tuning curves

In [None]:
group_ids = [21, 22, 23]
group_constr = 'group_id in ({})'.format(','.join(str(gid) for gid in group_ids))

# A single model configuration (the regular shifter/modulator variant),
# restricted to the movie- and noise-trained datasets and the groups above.
model_key = {
    'core_hash': '22d11147b37e3947e7d1034cc00d402c',  # 12 x 36
    'mod_hash': '4954311aa3bebb347ebf411ab5198890',
    'seed': 2606,
    'shift_hash': '64add03e1462b7413b59812d446aee9f',
    'train_hash': '624f62a2ef01d39f6703f3491bb9242b',  # batchsize=8 stop gradient
    'ro_hash': 'bf00321c11e46d68d4a42653a725969d',  # 2 and 4
}
network_config = dj.AndList([
    model_key,
    'data_hash in ("5253599d3dceed531841271d6eeba9c5", "6c0290da908317e55c4baf92e379d651")',
    group_constr,
])
    


In [None]:
# Inspect the model tuning-curve entries for this configuration.
curve_entries = tuning.MonetCurve() & network_config
curve_entries

In [None]:
# Measured direction tuning curves, best fits (highest r2) first.
# NOTE(review): unlike the earlier queries this restriction does not pin
# ori_version=3 — confirm that is intended.
dir_curve_key = dict(stimulus_type="stimulus.Monet2", ori_type='dir', spike_method=5, segmentation_method=3)
constr_cell = (tuning.Ori.Cell() & dir_curve_key
               & '1/(2/selectivity - 1) > .2 and r2 > 0.005'
               & (data.MovieMultiDataset.Member() & group_constr))
rel_curve = tuning.DirCurve() * tuning.DirCurve.Cell() * constr_cell
df_curve = pd.DataFrame(rel_curve.fetch(order_by='r2 DESC'))


In [None]:
# Model tuning curves with each model's test correlation, labelled by training set.
train_label = 'IF(data_hash = "5253599d3dceed531841271d6eeba9c5", "movies", "noise")'
curve_models = tuning.MonetCurve() * tuning.MonetCurve.Cell() * models.Encoder.UnitTestScores()
rel_model = curve_models.proj('pearson',
                              model_directions='directions',
                              model_curve='curve',
                              train_data=train_label) & network_config
df_model = pd.DataFrame(rel_model.fetch())

df = pd.merge(df_curve, df_model, on=['animal_id', 'session', 'scan_idx', 'unit_id'])


In [None]:
def get_movie_rank(df):
    """Rank the units of one scan by the movie-trained model's test correlation.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows of a single scan with at least ``unit_id``, ``train_data``
        ("movies" / "noise"), ``pearson`` and the scan key columns
        ``animal_id``, ``session``, ``scan_idx``.

    Returns
    -------
    pandas.DataFrame
        Movie rows (sorted by ``unit_id``) followed by noise rows, each with an
        integer ``score_rank`` where 0 is the highest movie-model ``pearson``.
        A unit's noise row gets the same rank as its movie row. The scan key
        columns are dropped so the caller's ``groupby(...).apply(...).reset_index()``
        can re-insert them from the group keys.

    Raises
    ------
    ValueError
        If the movie and noise rows do not cover the same units.
    """
    movies = df[df.train_data == "movies"].sort_values("unit_id")
    noise = df[df.train_data == "noise"].sort_values("unit_id")
    # Ranks are assigned positionally to both frames, so unit order must match.
    if not np.array_equal(movies.unit_id.values, noise.unit_id.values):
        raise ValueError("movie- and noise-trained models must cover the same units")
    # np.argsort gives the *order* of units (best first), not each unit's rank;
    # scattering 0..n-1 into that order inverts the permutation into ranks.
    order = np.argsort(-movies.pearson.values, kind='mergesort')
    rank = np.empty_like(order)
    rank[order] = np.arange(len(order))
    # assign() returns new frames, avoiding chained assignment on slices.
    movies = movies.assign(score_rank=rank)
    noise = noise.assign(score_rank=rank)
    return pd.concat([movies, noise]).drop(['animal_id', 'session', 'scan_idx'], axis=1)

# Rank units within each scan, then attach a '<animal_id>-<session>-<scan_idx>' label.
scan_keys = ['animal_id', 'session', 'scan_idx']
df2 = df.groupby(scan_keys).apply(get_movie_rank).reset_index()
df2['scan'] = ['{}-{}-{}'.format(animal, session, scan_idx)
               for animal, session, scan_idx in zip(df2['animal_id'], df2['session'], df2['scan_idx'])]


In [None]:
# Preview the units that will be plotted. N is defined here so this cell does
# not raise NameError on a fresh kernel, where the plotting cell below (which
# also sets N = 12) has not run yet.
N = 12
df2[df2.score_rank < N]

In [None]:
N = 12  # number of best-ranked units (grid columns) shown per scan
sns.set_context('paper', font_scale=2)
sns.set_palette(scan_cmap)
def plot_ori(d, dm, c, cm, zscore=True, **kwargs):
    """FacetGrid mapper: overlay measured and model direction tuning curves.

    ``d``/``c`` are the measured directions and responses, ``dm``/``cm`` the
    model's. Each is a Series, and only its first entry (a curve) is used.
    The measured curve is drawn only for the "movies" hue level, as a dashed
    line. The model curve is blue for "noise" and pink otherwise. With
    ``zscore`` both curves are standardized to zero mean and unit std.
    """
    hue_label = kwargs.pop('label')
    kwargs.pop('color')  # colours are chosen explicitly below

    def _standardize(curve):
        return (curve - curve.mean()) / curve.std()

    neuron_curve = c.iloc[0]
    model_curve = cm.iloc[0]
    if zscore:
        neuron_curve = _standardize(neuron_curve)
        model_curve = _standardize(model_curve)

    if hue_label == 'movies':
        plt.plot(d.iloc[0], neuron_curve, '--', color='darkslategray', label='neuron', zorder=-10, **kwargs)
    if hue_label == 'noise':
        model_color = sns.xkcd_rgb['cerulean blue']
    else:
        model_color = sns.xkcd_rgb['deep pink']
    plt.plot(dm.iloc[0], model_curve, label=hue_label, color=model_color, **kwargs)

    axis = plt.gca()
    axis.set_xticks([0, np.pi / 2, np.pi, 3 * np.pi / 2, 2 * np.pi])
    axis.set_xticklabels(['0', r'$\frac{\pi}{2}$', r'$\pi$', r'$\frac{3\pi}{2}$', r'$2\pi$'])

# One row per scan and one column per rank 0..N-1; hue separates the movie- and
# noise-trained models. col_order now follows N instead of a hard-coded 12.
g = sns.FacetGrid(df2[df2.score_rank < N], row='scan', col='score_rank', hue='train_data', margin_titles=False, col_order=np.arange(N))
g.map(plot_ori, "directions",  "model_directions", "curve", "model_curve", lw=3)
g.set_axis_labels("directions", "z-scored mean\nresponse")
# First clear all default facet titles, then switch on margin titles so the
# second call only writes "scan <name>" next to each row.
g.set_titles(template="", col_template="", row_template="")
g._margin_titles = True  # private seaborn attribute; may break across seaborn versions
g.set_titles(template="", col_template="", row_template="{row_var} {row_name}")
leg = g.axes[0,0].legend(loc="upper left", ncol=1)
leg.get_frame().set_linewidth(0.0)
sns.despine(trim=True)
g.fig.subplots_adjust(wspace=.05, hspace=.05)
g.fig.savefig('figures/direction_tuning.png', dpi=200)

# Receptive fields

In [None]:
group_ids = [21, 22, 23]
group_constr = 'group_id in ({})'.format(','.join(str(gid) for gid in group_ids))

# A single model configuration (the regular shifter/modulator variant),
# restricted to the movie- and noise-trained datasets and the groups above.
model_key = {
    'core_hash': '22d11147b37e3947e7d1034cc00d402c',  # 12 x 36
    'mod_hash': '4954311aa3bebb347ebf411ab5198890',
    'seed': 2606,
    'shift_hash': '64add03e1462b7413b59812d446aee9f',
    'train_hash': '624f62a2ef01d39f6703f3491bb9242b',  # batchsize=8 stop gradient
    'ro_hash': 'bf00321c11e46d68d4a42653a725969d',  # 2 and 4
}
network_config = dj.AndList([
    model_key,
    'data_hash in ("5253599d3dceed531841271d6eeba9c5", "6c0290da908317e55c4baf92e379d651")',
    group_constr,
])
    


In [None]:
# Inspect the model STA entries for this configuration.
sta_entries = tuning.STA() & network_config
sta_entries

In [None]:
# Measured STA maps of high-SNR units in the selected groups, best SNR first.
sta_key = dict(stimulus_type="stimulus.Monet2", spike_method=5, segmentation_method=3)
constr_cell = (tuning.NeuroSTAQual() & sta_key
               & 'snr > 6'
               & (data.MovieMultiDataset.Member() & group_constr))
rel_sta_neuron = tuning.NeuroSTA.Map() * constr_cell
df_sta = pd.DataFrame(rel_sta_neuron.fetch(order_by='snr DESC'))


In [None]:
# Model STA maps with each model's test correlation, labelled by training set.
train_label = 'IF(data_hash = "5253599d3dceed531841271d6eeba9c5", "movies", "noise")'
sta_models = tuning.STA() * tuning.STA.Map() * models.Encoder.UnitTestScores()
rel_model = sta_models.proj('pearson', model_map='map', train_data=train_label) & network_config
df_model = pd.DataFrame(rel_model.fetch())



In [None]:
# Put each unit's movie- and noise-trained model side by side (suffixes _movies / _noise).
is_movies = df_model.train_data == "movies"
is_noise = df_model.train_data == "noise"
tmp1 = df_model.loc[is_movies]
tmp2 = df_model.loc[is_noise]
df3 = pd.merge(tmp1, tmp2, on=['animal_id', 'session', 'scan_idx', 'unit_id'], suffixes=('_movies', '_noise'))


In [None]:
# Attach the paired model maps to each measured STA map.
df = pd.merge(df_sta, df3, on=['animal_id', 'session', 'scan_idx', 'unit_id'])

def get_rank(df):
    """Add ``score_rank`` to one scan's units, ranked by movie-model correlation.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows of a single scan with a ``pearson_movies`` column.

    Returns
    -------
    pandas.DataFrame
        A copy of ``df`` (same index) with integer ``score_rank``, where 0 is
        the highest ``pearson_movies``. The input frame is not modified.
    """
    # np.argsort gives the order of units (best first), not each unit's rank;
    # scattering 0..n-1 into that order inverts the permutation into ranks.
    order = np.argsort(-df.pearson_movies.values, kind='mergesort')
    rank = np.empty_like(order)
    rank[order] = np.arange(len(order))
    return df.assign(score_rank=rank)

# Rank units within each scan, then attach a '<animal_id>-<session>-<scan_idx>' label.
scan_keys = ['animal_id', 'session', 'scan_idx']
df2 = df.groupby(scan_keys).apply(get_rank).reset_index()
df2['scan'] = ['{}-{}-{}'.format(animal, session, scan_idx)
               for animal, session, scan_idx in zip(df2['animal_id'], df2['session'], df2['scan_idx'])]


In [None]:
df2.head(10)

In [None]:
from functools import partial

def preprocess(rf, lag):
    """Collapse the first ``lag`` entries along the last axis of a single map.

    ``rf`` must hold exactly one element (one facet cell). Its entry is an
    array, presumably with time lags on the last axis. Returns the mean over
    ``[..., :lag]``.
    """
    assert len(rf) == 1
    return rf.iloc[0][..., :lag].mean(axis=-1)
    

sns.set_context('paper', font_scale=1.0)
N = 12    # number of best-ranked units (grid columns) per scan
lag = 3   # leading entries averaged by preprocess() before plotting
def plot_rf(rf_natural, rf_neuron, rf_noise, lag=1, v=65, g=7, **kwargs):
    """FacetGrid mapper: show model and measured receptive fields stacked in one image.

    Parameters
    ----------
    rf_natural, rf_neuron, rf_noise : pandas.Series
        Facet columns holding the movie-trained model map, the measured map and
        the noise-trained model map. Each is reduced with ``preprocess``
        (mean over the first ``lag`` entries of the last axis).
    lag : int
        Passed to ``preprocess``.
    v : float
        Symmetric colour limit; the image is drawn with vmin=-v, vmax=v.
    g : int
        Number of grid cells per map along each axis (tick/grid spacing).
    **kwargs
        Forwarded to ``ax.matshow`` after dropping FacetGrid's ``color``.

    The two model maps are resized to the measured map's shape and stacked
    vertically as [movie model, neuron, noise model], separated by dashed lines.

    NOTE(review): ``ax.axis([0, x+1, -1, 3*y])`` sets increasing y-limits, which
    probably undoes matshow's top-down y-axis. The maps would then appear
    vertically flipped, with the movie model at the bottom. That matches the
    'movies | neuron | noise' y-label read bottom-to-top. Confirm this is intended.
    """
    # FacetGrid passes a colour; matshow uses the colormap instead.
    kwargs.pop('color')
    rf_natural, rf_neuron, rf_noise = map(partial(preprocess, lag=lag), [rf_natural, rf_neuron, rf_noise])

    # If the measured and model maps differ in aspect ratio and the measured map
    # is 126x216, crop a 4-pixel border down to 117x208 (presumably to match the
    # model's aspect ratio — TODO confirm). Other shapes are left untouched.
    if not rf_neuron.shape[0] / rf_natural.shape[0] == rf_neuron.shape[1] / rf_natural.shape[1]:
        if rf_neuron.shape == (126, 216):
            rf_neuron = rf_neuron[4:4 + 117, 4:4 + 208]

    shape = rf_neuron.shape
    ax = plt.gca()
    # Bring both model maps to the measured map's resolution, keeping value range.
    tmp1 = resize(rf_natural, shape, preserve_range=True) 
    tmp2 = resize(rf_noise, shape, preserve_range=True) 
    tmp = np.vstack((tmp1, rf_neuron, tmp2))
    ax.matshow(tmp, vmin=-v, vmax=v, **kwargs)
    y, x = shape
    # Ticks only serve as grid positions: g cells per map along each axis.
    yt = np.linspace(0,3*y,3*g + 1)
    ax.set_xticks(np.linspace(0,x,g+1))
    ax.set_yticks(yt)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    # Dashed separators between the three stacked maps.
    ax.plot([0,x], [y,y], '--', color='darkslategray',lw=.5)
    ax.plot([0,x], [2 * y,2 * y], '--',color='darkslategray',lw=.5)
    ax.axis('tight')
    # Explicit limits on this line take precedence over the 'tight' call above.
    ax.axis([0,x+1, -1,3*y])
    ax.grid(lw=.5)
    

# Grid of receptive fields: one row per scan, one column per rank (< N).
with sns.axes_style('whitegrid'):
    g = sns.FacetGrid(df2[df2.score_rank < N], row='scan', col='score_rank', margin_titles=False)
    # Columns map positionally onto plot_rf(rf_natural, rf_neuron, rf_noise).
    g.map(plot_rf, "model_map_movies", "map", "model_map_noise", lag=lag, cmap='bwr', v=90)
# Drop default axis labels and facet titles; row labels are written manually below.
g.set_axis_labels("", "")
g.set_titles(template="", col_template="", row_template="")
# g._margin_titles = True
# g.set_titles(template="", col_template="", row_template="{row_var} {row_name}")

# Manual row labels: "scan <name>", rotated, just right of the last column.
for ax, rowlab in zip(g.axes[:, -1], g.row_names):
    ax.text(1.1, 0.5, 'scan ' + rowlab, horizontalalignment='center',verticalalignment='center', transform=ax.transAxes, rotation=-90)

# Y label names the three maps stacked in each panel.
g.set_ylabels('movies | neuron | noise')
# leg = g.axes[0,0].legend(loc="upper left")
# leg.get_frame().set_linewidth(0.0)
sns.despine(left=True, bottom=True)

g.fig.set_size_inches((14,6))
g.fig.subplots_adjust(left=.05, hspace=.1, wspace=.1, right=0.95)
g.fig.savefig('figures/receptive_fields.png', dpi=200)