In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load IPython's autoreload extension; mode 2 reloads all modules before
# executing each cell, so edits to imported project code are picked up.
%load_ext autoreload
%autoreload 2 

In [None]:
import os
import datajoint as dj
# Database connection settings are read from DJ_* environment variables;
# a missing variable raises KeyError on the corresponding assignment.
for config_key, env_var in (
    ('database.host', 'DJ_HOST'),
    ('database.user', 'DJ_USER'),
    ('database.password', 'DJ_PASS'),
):
    dj.config[config_key] = os.environ[env_var]
dj.config['enable_python_native_blobs'] = True

# The schema name is derived from a short tag for this notebook's data.
name = "simdata"
dj.config['schema_name'] = f"konstantin_nnsysident_{name}"

In [None]:
import torch
import string
import numpy as np
import pickle 
import pandas as pd
# Display settings: never truncate columns, but cap displayed rows at 10.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 10)
import matplotlib.pyplot as plt
import matplotlib
import re
import seaborn as sns

import nnfabrik
from nnfabrik.main import *
from nnfabrik import builder
from nnfabrik.utility.hypersearch import Bayesian

from mlutils.measures import corr

from nnsysident.tables.scoring import OracleScore, OracleScoreTransfer
from nnsysident.tables.experiments import *
from nnsysident.tables.bayesian import *
from nnsysident.datasets.mouse_loaders import static_shared_loaders
from nnsysident.datasets.mouse_loaders import static_loaders
from nnsysident.datasets.mouse_loaders import static_loader

In [None]:
def get_transfer(old_experiment_name):
    """Return a flat table describing the models that were used as transfer sources.

    Every ``Transfer`` entry is fetched and its ``transfer_config`` expanded into
    columns. It is then inner-joined on the three ``t_*_hash`` columns with the
    trained models of `old_experiment_name`, keeping only the highest-``score``
    row per hash triple. Finally ``dataset_config`` is expanded as well and every
    column that does not already start with ``t_`` or ``transfer`` is prefixed
    with ``t_``.

    The result is meant to be merged (on ``transfer_fn`` and ``transfer_hash``)
    with the table of models that the transferred model was used for.
    """
    hash_keys = ['t_model_hash', 't_trainer_hash', 't_dataset_hash']

    def expand_config(frame, config_column):
        # Turn a column of dicts into one column per key and drop the dict column.
        expanded = frame[config_column].apply(pd.Series)
        return pd.concat([frame, expanded], axis=1).drop(config_column, axis=1)

    transfer = expand_config(pd.DataFrame(Transfer.fetch()), 'transfer_config')

    # Trained models of the source experiment, with hash columns renamed to the
    # names used in the expanded transfer table.
    restriction = 'experiment_name = "{}"'.format(old_experiment_name)
    source_query = TrainedModel * Dataset * Seed * Experiments.Restrictions & restriction
    source_models = pd.DataFrame(source_query.fetch()).rename(
        columns={'model_hash': 't_model_hash',
                 'trainer_hash': 't_trainer_hash',
                 'dataset_hash': 't_dataset_hash'})
    # Keep only the best-scoring row for each (model, trainer, dataset) triple.
    best_sources = source_models.sort_values('score', ascending=False).drop_duplicates(hash_keys)

    transfer = transfer.merge(best_sources, how='inner', on=hash_keys)
    transfer = expand_config(transfer, 'dataset_config')
    transfer.columns = [
        col if col[:2] == 't_' or col[:8] == 'transfer' else 't_' + col
        for col in transfer.columns
    ]
    return transfer.sort_values(['t_neuron_n', 't_image_n', 't_neuron_base_seed', 't_image_base_seed'])

def baseline(data, tier):
    """Estimate the highest possible correlation based on the ground truth.

    Builds the dataset described by the first row of the input frame (with the
    dataset seed forced to 1), removes its transforms and correlates the
    responses with the neurons' ground truths on the trials of one tier.

    Args:
        data: DataFrame with ``dataset_fn`` and ``dataset_config`` columns; only
            its first row is used. If either column is missing (e.g. because
            ``dataset_config`` was already expanded into separate columns), the
            notebook-level frame ``data_`` is used instead, which is what this
            function always read before it honoured its argument.
        tier: Value compared against ``dataset.trial_info.tiers`` to select trials.

    Returns:
        The mean of ``corr(responses, ground_truths, axis=0)``
        (presumably one correlation per neuron -- confirm against ``corr``).

    Raises:
        ValueError: If the frame to read the dataset description from is empty.
    """
    # Prefer the explicit argument; fall back to the legacy global only when the
    # argument cannot describe a dataset.
    columns = getattr(data, 'columns', ())
    if 'dataset_config' in columns and 'dataset_fn' in columns:
        source = data
    else:
        source = data_
    if len(source) == 0:
        raise ValueError("baseline() needs at least one row to build the dataset from")

    # get dataset
    row = source.iloc[0]
    # Copy before overriding the seed so the dict stored in the frame is untouched.
    dataset_config = dict(row['dataset_config'])
    dataset_config.update(seed=1)
    dataloaders = builder.get_data(row['dataset_fn'], dataset_config)
    dataset = dataloaders['train']['0-0-3-0'].dataset
    dataset.transforms = []  # work on the untransformed data points

    # Extract data
    idx = dataset.trial_info.tiers == tier
    gts = np.array(list(dataset.neurons.ground_truths)).T[idx]
    resps = np.array([datapoint.responses for datapoint in dataset])[idx]

    # Compute correlation and return
    return np.mean(corr(resps, gts, axis=0))

# Triple plot for simulated data

### Get data

In [None]:
# Direct: models trained directly, one experiment per readout type.
direct_frames = []
for experiment_name in ['SIM, Direct, se2d_spatialxfeaturelinear, 0-0-3', 'SIM, Direct, se2d_pointpooled, 0-0-3', 'SIM, Direct, se2d_fullgaussian2d, 0-0-3']:
    data_ = pd.DataFrame((TrainedModel * Dataset * Model * Trainer * Seed * OracleScore * Experiments.Restrictions & 'experiment_name="{}"'.format(experiment_name)).fetch())
    direct_frames.append(data_)
# Concatenate once (instead of growing a frame inside the loop) and renumber the
# rows so the combined frame has no duplicated index labels.
data = pd.concat(direct_frames, ignore_index=True)

# Expand the config dicts into one column per key.
data = pd.concat([data, data['dataset_config'].apply(pd.Series)], axis = 1).drop('dataset_config', axis = 1)
data = pd.concat([data, data['model_config'].apply(pd.Series)], axis = 1).drop('model_config', axis = 1)
# Readout name = last dotted component of model_fn without its first five characters.
data['Readout'] = [model_fn.split('.')[-1][5:] for model_fn in data['model_fn']]
data = data.rename(columns = {'neuron_n': '# neurons', 'image_n': "# images"})
direct_data = data.copy()
direct_data['Condition'] = "direct"
# Pass the raw fetch: `data` no longer has the `dataset_config` column that
# baseline() reads, whereas `data_` (the last direct experiment) still does.
base_line = baseline(data=data_, tier='test')

# Filter out best performing models over model seeds
direct_data = direct_data.sort_values('score', ascending=False).drop_duplicates(['Readout',
                                                                   '# neurons',
                                                                   '# images', 
                                                                   'neuron_base_seed',
                                                                   'image_base_seed']).sort_values(['Readout', '# neurons', '# images'])


# Full readout data
readout_frames = []
for experiment_name in ["SIM, core_transfer (sameNI), se2d_fullgaussian2d, 0-0-3 -> 0-0-3, readout full I",
                        "SIM, core_transfer (sameNI), se2d_pointpooled, 0-0-3 -> 0-0-3, readout full I",
                        "SIM, core_transfer (sameNI), se2d_spatialxfeaturelinear, 0-0-3 -> 0-0-3, readout full I"]:
    data_ = pd.DataFrame((TrainedModelTransfer * Dataset * Model * Trainer * Seed * OracleScoreTransfer * Transfer.proj() * ExperimentsTransfer.Restrictions & 'experiment_name="{}"'.format(experiment_name)).fetch())
    # The third comma-separated field of the experiment name is the model name,
    # which selects the matching "Direct" experiment the cores were taken from.
    transfer = get_transfer(old_experiment_name='SIM, Direct, {}, 0-0-3'.format(experiment_name.split(', ')[2]))
    data_ = pd.merge(data_, transfer, how='inner', on=['transfer_hash', 'transfer_fn'])
    readout_frames.append(data_)
data = pd.concat(readout_frames, ignore_index=True)
data['Readout'] = [model_fn.split('.')[-1][5:] for model_fn in data['model_fn']]
# Here the x-axis quantities come from the transferred (source) model's columns.
data = data.rename(columns = {'t_neuron_n': '# neurons', 't_image_n': "# images"})
full_readout_data = data.copy()
full_readout_data['Condition'] = "diff-core/best-readout"

# Full core data
core_frames = []
for experiment_name in ["SIM, core_transfer (best), se2d_fullgaussian2d, 0-0-3 -> 0-0-3",
                         "SIM, core_transfer (best), se2d_pointpooled, 0-0-3 -> 0-0-3",
                         "SIM, core_transfer (best), se2d_spatialxfeaturelinear, 0-0-3 -> 0-0-3"]:
    data_ = pd.DataFrame((TrainedModelTransfer * Dataset * Model * Trainer * Seed * OracleScoreTransfer * Transfer.proj() * ExperimentsTransfer.Restrictions & 'experiment_name="{}"'.format(experiment_name)).fetch())
    transfer = get_transfer(old_experiment_name='SIM, Direct, {}, 0-0-3'.format(experiment_name.split(', ')[2]))
    data_ = pd.merge(data_, transfer, how='inner', on=['transfer_hash', 'transfer_fn'])
    core_frames.append(data_)
data = pd.concat(core_frames, ignore_index=True)
data['Readout'] = [model_fn.split('.')[-1][5:] for model_fn in data['model_fn']]
# Here the x-axis quantities come from the target model's own dataset_config.
data = pd.concat([data, data['dataset_config'].apply(pd.Series)], axis = 1).drop('dataset_config', axis = 1)
data = data.rename(columns = {'neuron_n': '# neurons', 'image_n': "# images"})
full_core_data = data.copy()
full_core_data['Condition'] = "best-core/diff-readout"

# Combine the three conditions into the frame used for plotting and give the
# readouts their display names.
data = pd.concat([direct_data, full_readout_data, full_core_data], ignore_index=True)
data = data.replace({'Readout': {'spatialxfeaturelinear':'Factorized readout', 'fullgaussian2d':'Gaussian readout', 'pointpooled':'Point readout'}})



### Plot

In [None]:
# Figure settings
title = 'Simulated data triple plot'
scoring_measure = "fraction_oracle"
sns.set_context("paper")
col_order = ['Factorized readout', 'Gaussian readout', 'Point readout']
# Use entries 5, 6 and 8 of seaborn's 'bright' palette for the hue levels.
bright_colors = sns.color_palette('bright')
palette = [bright_colors[k] for k in (5, 6, 8)]

paper_rc = {'lines.linewidth': 2, 'lines.markersize': 10}
relplot_kwargs = dict(
    x="# images",
    y=scoring_measure,
    hue="Condition",
    col="Readout",
    col_order=col_order,
    kind="line",
    marker="o",
    data=data,
    palette=palette,
    aspect=0.88,
)
with sns.plotting_context('paper', rc=paper_rc, font_scale=2.5), sns.color_palette('bright'), sns.axes_style('ticks'):
    # One panel per readout type, one line per condition.
    g = sns.relplot(**relplot_kwargs)
    # Dashed horizontal reference line at the ground-truth baseline in every panel.
    g.map(plt.axhline, y=base_line, c='k', ls='--', zorder=0, label= "ground truth")

    # Drop the automatically created legend and build a new one after the
    # reference line has been drawn.
    g._legend.remove()
    g.add_legend()

    g.axes[0, 0].set_ylabel(scoring_measure.replace('_', ' '))
    for column in range(3):
        g.axes[0, column].set_xlabel('# images')

    # Blank the first legend entry, move the legend and enlarge its text.
    legend = g._legend
    legend.texts[0].set_text("")
    legend.set_bbox_to_anchor((.57,.46,.1,.1))
    for legend_text in legend.texts:
        legend_text.set_size(25)

    # Ticks are placed at the image counts that occur in the data.
    image_counts = np.unique(data['# images'])
    for panel_idx, ax in enumerate(g.axes.flatten()):
        # Panel letter (A, B, C, ...) in the upper-left corner.
        ax.text(-0.1, 1.09, string.ascii_uppercase[panel_idx], transform=ax.transAxes,
                size=25, weight='bold')
        ax.yaxis.grid(True)
        ax.set_xscale('log')
        ax.set_xticks(image_counts)
        # Plain numbers instead of powers of ten on the log axis.
        ax.get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())

        # Drop the first ten characters of the facet title and enlarge it.
        ax.set_title(ax.get_title()[10:], fontsize=30)

        # Only the first panel keeps its axis labels.
        if panel_idx == 0:
            ax.set_xlabel("# images", fontsize=30)
            ax.set_ylabel("Fraction oracle", fontsize=30)
        else:
            ax.set_xlabel("")

    sns.despine(trim=True)
    plt.tight_layout()
#     g.fig.savefig('figures/' + title + '.pdf', dpi=150)