In [None]:
%load_ext autoreload
%autoreload 2

In [7]:
import os
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from fitting.utils import get_expt_dir
from fitting.utils import get_best_model_version
from fitting.utils import get_lab_example

# Output directory (relative path) used when building save-file names for results.
results_dir = 'behavenet_figs/'
# Intended switch for writing outputs to disk -- TODO confirm which cells honor it.
save_outputs = True

# 1. Plot train/val losses as a function of epochs

In [8]:
# Shared hyperparameter dict; later cells read and extend this notebook-level
# object when looking up fitted models.
hparams = dict(
    data_dir='/Volumes/paninski-locker/data/',  # might be a different path than what is saved in hparams
    tt_save_path='/Volumes/paninski-locker/analysis/behavenet/',
    experiment_name='test_pt',
    model_class='ae',
)

# 2. Plot test losses as a function of latents (TODO)

### load results

In [9]:
## Get normalizing values
from fitting.utils import get_best_model_and_data
from behavenet.models import AE


def _pixel_mean_and_var(image_batches):
    """Return ``(mean, variance)`` over every element of an iterable of arrays.

    The batches are consumed in a single pass, so each one only has to be
    loaded once. Sums are accumulated in float64 and the (population)
    variance is computed as ``E[x^2] - E[x]^2``.

    Raises
    ------
    ValueError
        if the iterable yields no elements at all
    """
    total = 0.0
    total_sq = 0.0
    count = 0
    for batch in image_batches:
        batch = np.asarray(batch, dtype=np.float64)
        total += batch.sum()
        total_sq += np.square(batch).sum()
        count += batch.size
    if count == 0:
        raise ValueError('cannot compute pixel statistics: no test images found')
    mean = total / count
    # clamp tiny negative values caused by floating-point cancellation
    return mean, max(total_sq / count - mean ** 2, 0.0)


normalizing_info = {}

labs = ['musall', 'steinmetz', 'steinmetz-face', 'datta']
labs_capped = ['Musall', 'Steinmetz', 'Steinmetz-face', 'Datta']
for lab in labs:

    model_type = 'conv'

    # NOTE: rebinds the notebook-level `hparams`, as the original cell did
    hparams = {
        'data_dir': '/Volumes/paninski-locker/data/',  # might be a different path than what is saved in hparams
        'tt_save_path': '/Volumes/paninski-locker/analysis/behavenet/',
        'experiment_name': 'test_pt',
        'model_class': 'ae'}

    hparams['model_type'] = model_type
    get_lab_example(hparams, lab)

    # only the data generator is needed here; the model itself is not used
    model_cae, data_generator = get_best_model_and_data(hparams, AE, version='best')

    # lazily yield each test trial's images as a numpy array (loaded once per trial)
    test_images = (
        data_generator.datasets[0][trial_idx]['images'].cpu().detach().numpy()
        for trial_idx in data_generator.batch_indxs[0]['test'])

    mean, var = _pixel_mean_and_var(test_images)
    normalizing_info[lab] = {'mean': mean, 'var': var}

Loading model defined in /Volumes/paninski-locker/analysis/behavenet/musall/vistrained/mSM30/10-Oct-2017/ae/conv/16_latents/test_tube_data/test_pt/version_0/meta_tags.pkl
Loading model defined in /Volumes/paninski-locker/analysis/behavenet/steinmetz/2-probe/mouse-01/session-01/ae/conv/12_latents/test_tube_data/test_pt/version_0/meta_tags.pkl
Loading model defined in /Volumes/paninski-locker/analysis/behavenet/steinmetz/2-probe-face/mouse-01/session-01/ae/conv/12_latents/test_tube_data/test_pt/version_0/meta_tags.pkl
Loading model defined in /Volumes/paninski-locker/analysis/behavenet/datta/inscopix/15566/2018-11-27/ae/conv/08_latents/test_tube_data/test_pt/version_0/meta_tags.pkl


In [11]:
def get_test_data(lab, model_types, n_latents, experiment_name):
    """Collect test losses for every (model type, latent count) fit of one dataset.

    For each combination, the ``metrics.csv`` of the best model version is read
    and its ``test_loss`` column is converted to long-format rows. Combinations
    that cannot be loaded are reported and skipped (best effort).

    Note: this reads and mutates the notebook-level ``hparams`` dict (it sets
    ``experiment_name``, ``model_type``, ``n_ae_latents`` and whatever
    ``get_lab_example`` fills in).

    Parameters
    ----------
    lab : str
        dataset key handed to ``get_lab_example``
    model_types : iterable of str
        values written to ``hparams['model_type']``
    n_latents : iterable of int
        values written to ``hparams['n_ae_latents']``
    experiment_name : str
        value written to ``hparams['experiment_name']``

    Returns
    -------
    pandas.DataFrame
        columns ``epoch``, ``loss``, ``n_latents``, ``dtype`` (always
        ``'test'``), ``model_type`` and ``data``; one row per row of each
        metrics file (NaN losses are kept). Empty, with the same columns, when
        nothing could be loaded.
    """
    columns = ['epoch', 'loss', 'n_latents', 'dtype', 'model_type', 'data']

    hparams['experiment_name'] = experiment_name
    get_lab_example(hparams, lab)

    # Only test rows are ever returned, so val/train rows are not assembled.
    test_frames = []
    for model_type in model_types:
        hparams['model_type'] = model_type
        for n_ae_latents in n_latents:
            hparams['n_ae_latents'] = n_ae_latents
            try:
                expt_dir = get_expt_dir(hparams)
                model_version = get_best_model_version(expt_dir)
                print(expt_dir)
                print(model_version)
                metric_file = os.path.join(expt_dir, model_version[0], 'metrics.csv')
                metrics = pd.read_csv(metric_file)
                if 'test_loss' not in metrics.columns:
                    continue
                test_frames.append(pd.DataFrame({
                    'epoch': metrics['epoch'].values,
                    'loss': metrics['test_loss'].values,
                    'n_latents': n_ae_latents,
                    'dtype': 'test',
                    'model_type': model_type,
                    'data': lab}))
            except Exception as exc:
                # keep going, but say what was skipped; `Exception` (not a bare
                # except) so that KeyboardInterrupt still stops the loop
                print('Skipping %s | %s | %s latents: %r' % (
                    lab, model_type, n_ae_latents, exc))

    if not test_frames:
        # pd.concat raises on an empty list; return an empty, well-formed frame
        return pd.DataFrame(columns=columns)

    # ignore_index gives a unique index instead of repeated labels
    return pd.concat(test_frames, ignore_index=True)[columns]

In [12]:
# Which fits to collect test losses for.
experiment_name = 'latent_search'
model_types = ['conv', 'linear']
n_latents = [4, 8, 16, 32]

# `labs` and `dataset_names` are parallel lists: dataset key and its panel title.
labs = ['musall', 'steinmetz', 'steinmetz-face', 'datta']
dataset_names = ['WFCI', 'NP', 'NP-zoom', 'Kinect']

# One test-loss table per dataset, keyed by dataset key.
data_queried = {}
for lab in labs:
    data_queried[lab] = get_test_data(
        lab=lab,
        model_types=model_types,
        n_latents=n_latents,
        experiment_name=experiment_name)
    

/Volumes/paninski-locker/analysis/behavenet/musall/vistrained/mSM30/10-Oct-2017/ae/conv/04_latents/test_tube_data/latent_search
['version_0']
/Volumes/paninski-locker/analysis/behavenet/musall/vistrained/mSM30/10-Oct-2017/ae/conv/08_latents/test_tube_data/latent_search
['version_0']
/Volumes/paninski-locker/analysis/behavenet/musall/vistrained/mSM30/10-Oct-2017/ae/conv/16_latents/test_tube_data/latent_search
['version_1']
/Volumes/paninski-locker/analysis/behavenet/musall/vistrained/mSM30/10-Oct-2017/ae/conv/32_latents/test_tube_data/latent_search
['version_0']
/Volumes/paninski-locker/analysis/behavenet/musall/vistrained/mSM30/10-Oct-2017/ae/linear/04_latents/test_tube_data/latent_search
['version_0']
/Volumes/paninski-locker/analysis/behavenet/musall/vistrained/mSM30/10-Oct-2017/ae/linear/08_latents/test_tube_data/latent_search
['version_0']
/Volumes/paninski-locker/analysis/behavenet/musall/vistrained/mSM30/10-Oct-2017/ae/linear/16_latents/test_tube_data/latent_search
['version_0']


In [22]:
# Scratch check: steinmetz test losses restricted to the 4-latent linear fits.
steinmetz_losses = data_queried['steinmetz']
ee = steinmetz_losses[
    (steinmetz_losses.n_latents == 4) & (steinmetz_losses.model_type == 'linear')]

In [26]:
# Standard deviation of the selected losses, ignoring NaN entries.
np.nanstd(ee['loss'])

0.0006047642699404232

### Plot conv vs. linear AE test losses as a function of latents

In [None]:
## Get example figure for every data type
from fitting.utils import get_best_model_and_data
from behavenet.models import AE


def _tile_channels(frames):
    """Concatenate ``frames[:, j]`` for every index j of axis 1 along axis 2."""
    n_channels = frames.shape[1]
    return np.concatenate([frames[:, j] for j in range(n_channels)], axis=2)


example_images = {}

labs = ['musall', 'steinmetz', 'steinmetz-face', 'datta']
labs_capped = ['Musall', 'Steinmetz', 'Steinmetz-face', 'Datta']
for lab in labs:

    example_images[lab] = {}

    for model_type in ['conv', 'linear']:

        # fresh hyperparameters for this (dataset, model type) pair
        hparams = {
            'data_dir': '/Volumes/paninski-locker/data/',  # might be a different path than what is saved in hparams
            'tt_save_path': '/Volumes/paninski-locker/analysis/behavenet/',
            'experiment_name': 'test_pt',
            'model_class': 'ae',
            'model_type': model_type}
        get_lab_example(hparams, lab)

        model_cae, data_generator = get_best_model_and_data(hparams, AE, version='best')
        # third entry of the test-trial index list is the example trial
        trial_idx = data_generator.batch_indxs[0]['test'][2]

        # reconstruction first, then the matching input frames, both as numpy
        recon_output, _ = model_cae(data_generator.datasets[0][trial_idx]['images'])
        recon = recon_output.cpu().detach().numpy()
        orig = data_generator.datasets[0][trial_idx]['images'].cpu().detach().numpy()

        if hparams['lab'] == 'musall':
            # swap the last two axes for this dataset only
            orig = np.transpose(orig, (0, 1, 3, 2))
            recon = np.transpose(recon, (0, 1, 3, 2))

        example_images[lab]['orig'] = _tile_channels(orig)
        example_images[lab][model_type] = _tile_channels(recon)

        if lab == 'datta':
            example_images[lab]['mask'] = data_generator.datasets[0][trial_idx]['masks'].cpu().detach().numpy()

In [None]:
# Inspect the mean/var entries stored for the steinmetz-face dataset.
normalizing_info['steinmetz-face']

In [None]:
# Scratch arithmetic. NOTE(review): presumably an MSE divided by a pixel variance -- confirm or delete.
.0004/.01

In [None]:
# Shape of the image tensor for the last loaded trial; relies on `data_generator`
# and `trial_idx` left over from an earlier cell.
data_generator.datasets[0][trial_idx]['images'].shape

In [None]:
import matplotlib
FONT_SIZE = 22

# One consolidated rcParams update (plt.rcParams and matplotlib.rcParams are
# the same object): font embedding, font family, and a single size everywhere.
matplotlib.rcParams.update({
    'pdf.fonttype': 42,              # fonttype 42 = TrueType embedding (see matplotlib rcParams docs)
    'ps.fonttype': 42,
    'font.family': 'Times New Roman',
    'font.size': FONT_SIZE,          # default text size
    'axes.titlesize': FONT_SIZE,     # axes title
    'axes.labelsize': FONT_SIZE,     # x and y labels
    'xtick.labelsize': FONT_SIZE,    # x tick labels
    'ytick.labelsize': FONT_SIZE,    # y tick labels
    'legend.fontsize': FONT_SIZE,    # legend text
    'figure.titlesize': FONT_SIZE,   # figure title
})
#matplotlib.rcParams['font.family'] = 'sans-serif'
#matplotlib.rcParams['font.sans-serif'] = ['Myriad Pro']

In [None]:
### Make figure 1
# Row 0: test loss vs. latent dimension per dataset (conv vs. linear AE).
# Row 1: one original frame per dataset.  Row 2: its conv-AE reconstruction.

y_axis = 'loss' # 'test_loss' | 'test_r2'
hue = 'model_type'
x_axis = 'n_latents' # 'layers' | 'layer_size' | 'pred' | 'lags'

MODEL_LABELS = {'conv': 'Conv AE', 'linear': 'Linear AE'}  # legend text per model_type
i_frame = 20  # index of the frame (within the example trial) shown in rows 1 and 2
image_kwargs = dict(vmin=0, vmax=1, cmap='gray')


def _strip_axes(ax):
    """Hide the spines and ticks of an image panel."""
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])


# squeeze=False keeps `axes` two-dimensional even for a single column
fig, axes = plt.subplots(3, len(labs), figsize=(24, 14), squeeze=False)

for i_ax, (lab_key, dataset_name) in enumerate(zip(labs, dataset_names)):

    # --- row 0: loss curves ---
    # relplot is figure-level and does not draw on a supplied `ax`; lineplot is
    # the axes-level equivalent. Legend labels come from the data itself, and
    # the index is reset because the loss tables may carry repeated labels.
    losses = data_queried[lab_key].reset_index(drop=True)
    losses = losses.assign(**{hue: losses[hue].map(lambda m: MODEL_LABELS.get(m, m))})
    sns.lineplot(
        x=x_axis, y=y_axis, hue=hue, hue_order=list(MODEL_LABELS.values()),
        data=losses, ax=axes[0][i_ax],
        legend='brief' if i_ax == 0 else False)  # legend on the first panel only
    axes[0][i_ax].set_title(dataset_name)
    axes[0][i_ax].set_xlabel('Latent dimension')
    axes[0][i_ax].set_ylabel('MSE' if i_ax == 0 else '')

    # --- row 1: original frame (masked for datta, drawn once) ---
    frame = example_images[lab_key]['orig'][i_frame]
    if lab_key == 'datta':
        frame = frame * example_images[lab_key]['mask'][i_frame][0]
        axes[1][i_ax].imshow(frame, interpolation='none', **image_kwargs)
    else:
        axes[1][i_ax].imshow(frame, **image_kwargs)
    _strip_axes(axes[1][i_ax])

    # --- row 2: conv AE reconstruction of the same frame ---
    axes[2][i_ax].imshow(
        example_images[lab_key]['conv'][i_frame], interpolation='none', **image_kwargs)
    _strip_axes(axes[2][i_ax])

legend = axes[0][0].get_legend()
if legend is not None:
    legend.set_title('')

axes[1][0].set_ylabel('Original \n Frame', rotation=0, labelpad=50)
axes[2][0].set_ylabel('Conv AE \n Reconstructed \n Frame', rotation=0, labelpad=60)
fig.tight_layout()

# honor the notebook-level output settings instead of a hard-coded path
if save_outputs:
    os.makedirs(results_dir, exist_ok=True)
    fig.savefig(os.path.join(results_dir, 'reconstructions.pdf'), transparent=True)

In [None]:
# # data_queried = metrics[~pd.notna(metrics.test_loss)]

# y_axis = 'loss' # 'test_loss' | 'test_r2'
# plot_type = 'line'
# hue = 'model_type'
# x_axis = 'n_latents' # 'layers' | 'layer_size' | 'pred' | 'lags'

# splt = sns.relplot(
#     x=x_axis, y=y_axis, hue=hue, kind=plot_type, data=data_queried)
# for i, ax in enumerate(splt.axes):
#     ax[0].set_yscale('log')
#     if i == 0:
#         ax[0].set_ylabel('MSE per pixel')

In [None]:
# Print the stored pixel variance for each of these datasets, one per line.
for lab_name in ('musall', 'steinmetz', 'steinmetz-face'):
    print(normalizing_info[lab_name]['var'])

# 3. Make reconstruction movies

In [None]:
from analyses.ae_movies import make_ae_reconstruction_movie
from fitting.utils import get_lab_example
from data.data_generator import ConcatSessionsGenerator

# Settings shared by every movie rendered below.
version = 'best'
include_linear = True
hparams = dict(
    data_dir='/Volumes/paninski-locker/data/',  # might be a different path than what is saved in hparams
    tt_save_path='/Volumes/paninski-locker/analysis/behavenet/',
    experiment_name='test_pt',
    lin_experiment_name='test_pt',
    model_class='ae',
    model_type='conv',
    lib='pt',
)

labs = ['datta'] #['steinmetz', 'steinmetz-face', 'musall', 'datta']

for lab in labs:
    get_lab_example(hparams, lab)

    # In this cell the generator is only used to look up a test-trial index.
    data_generator = ConcatSessionsGenerator(
        hparams['data_dir'], hparams,
        signals=['images'], transforms=[None], load_kwargs=[None],
        device='cpu', as_numpy=False, batch_load=True, rng_seed=0)
    test_trials = data_generator.batch_indxs[0]['test']
    trial = test_trials[4]

    # File stem: "<lab>_<n latents, zero-padded to 2 digits>_dim_recon_ae".
    # Presumably 'n_ae_latents' is filled in by get_lab_example -- verify.
    movie_name = '%s_%02i_dim_recon_ae' % (lab, hparams['n_ae_latents'])
    save_file = os.path.join(results_dir, movie_name)

    make_ae_reconstruction_movie(
        hparams,
        version=version,
        save_file=save_file,
        include_linear=include_linear,
        trial=trial)