In [None]:
# Notebook setup: show matplotlib figures inline and load IPython's autoreload
# extension in mode 2, so edited project modules are re-imported before each cell runs.
# Comments stay on their own lines: text after a magic is passed to it as an argument.
%matplotlib inline
%load_ext autoreload
%autoreload 2 

In [None]:
import os
import datajoint as dj
# Database credentials are read from the environment rather than hard-coded.
# A missing DJ_HOST / DJ_USER / DJ_PASS variable raises KeyError right here.
dj.config['database.host'] = os.environ['DJ_HOST']
dj.config['database.user'] = os.environ['DJ_USER']
dj.config['database.password'] = os.environ['DJ_PASS']
dj.config['enable_python_native_blobs'] = True
# NOTE(review): presumably the number of rows shown when a table is previewed - confirm.
dj.config['display.limit'] = 200

# Suffix selecting which schema this notebook works on; the alternative used so far is "simdata".
name = 'realdata' #"simdata"
# This runs before the nnfabrik / nnsysident imports in the next cell.
# NOTE(review): presumably those table modules read 'schema_name' at import time - confirm.
dj.config['schema_name'] = f"konstantin_nnsysident_{name}"

In [None]:
import torch
import shutil
import numpy as np
import pickle 
import pandas as pd
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 10)
import matplotlib.pyplot as plt
import matplotlib
import re
import string
import seaborn as sns
import hiplot as hip
import statsmodels

import nnfabrik
from nnfabrik.main import *
from nnfabrik import builder
from nnfabrik.utility.hypersearch import Bayesian
from nnsysident.utility.measures import get_correlations

from nnsysident.tables.experiments import *
from nnsysident.tables.bayesian import *
from nnsysident.datasets.mouse_loaders import static_shared_loaders
from nnsysident.datasets.mouse_loaders import static_loaders
from nnsysident.tables.scoring import OracleScore, OracleScoreTransfer

from nnsysident.datasets.mouse_loaders import static_loader

def find_number(text, c):
    """Return the integer that directly follows the prefix `c` inside `text`.

    The prefix is inserted into the regular expression as-is (it is not
    escaped), so regex metacharacters in `c` are interpreted as pattern syntax.

    Args:
        text: String to search.
        c: Prefix that must immediately precede the run of digits.

    Returns:
        The number as an int, or None when the prefix is never followed by digits.

    Raises:
        ValueError: If the prefix is followed by digits at more than one place.
    """
    matches = re.findall(r'%s(\d+)' % c, text)
    if len(matches) > 1:
        raise ValueError('More than one number found..')
    return int(matches[0]) if matches else None

def get_transfer(transfer_hashes):
    """Describe the trained models that the given transfer entries point to.

    The result is meant to be merged (on transfer_fn and transfer_hash) with the
    table of models that used the transferred model.

    Steps:
      1. Fetch the `Transfer` rows whose `transfer_hash` is in `transfer_hashes`
         and expand the `transfer_config` dict into columns. The code relies on
         it providing `t_model_hash`, `t_dataset_hash` and `t_trainer_hash`.
      2. Fetch the matching `TrainedModel * Dataset * Seed` rows and keep only
         the highest-scoring row per (model, trainer, dataset) hash triple.
      3. Inner-merge both frames, expand `dataset_config` into columns, and
         prefix every column with `t_` unless it already starts with `t_` or
         `transfer`.

    Args:
        transfer_hashes: Non-empty iterable of `transfer_hash` values.

    Returns:
        pandas.DataFrame sorted by `t_multi_match_n`, `t_image_n`,
        `t_multi_match_base_seed` and `t_image_base_seed`.
        NOTE(review): these columns are expected to come from the expanded
        `dataset_config` - confirm.

    Raises:
        ValueError: If `transfer_hashes` is empty.
    """
    transfer_hashes = list(transfer_hashes)
    if not transfer_hashes:
        raise ValueError('transfer_hashes must contain at least one hash')

    # Restrict with a list of dicts (OR-ed together by DataJoint) instead of
    # formatting a Python tuple into SQL: the repr of a one-element tuple is
    # "('abc',)", which is not a valid SQL IN list, and NumPy scalar reprs
    # could leak into the query as well.
    hash_restriction = [{'transfer_hash': transfer_hash} for transfer_hash in transfer_hashes]
    transfer = pd.DataFrame((Transfer & hash_restriction).fetch())
    transfer = pd.concat(
        [transfer, transfer['transfer_config'].apply(pd.Series)], axis=1
    ).drop('transfer_config', axis=1)

    # Translate the t_* hashes back to the primary-key names of TrainedModel so
    # they can be used as a restriction.
    restriction = transfer.rename(
        columns={'t_model_hash': 'model_hash', 't_dataset_hash': 'dataset_hash', 't_trainer_hash': 't_trainer_hash'.replace('t_', '', 1)}
    )
    restriction = restriction[['model_hash', 'dataset_hash', 'trainer_hash']].to_dict('records')

    tm = pd.DataFrame((TrainedModel * Dataset * Seed & restriction).fetch()).rename(
        columns={'model_hash': 't_model_hash', 'trainer_hash': 't_trainer_hash', 'dataset_hash': 't_dataset_hash'}
    )
    # Several rows can share one hash triple (the join includes Seed); keep the best-scoring one.
    tm = tm.sort_values('score', ascending=False).drop_duplicates(
        ['t_model_hash', 't_trainer_hash', 't_dataset_hash']
    )

    transfer = pd.merge(transfer, tm, how='inner', on=['t_model_hash', 't_trainer_hash', 't_dataset_hash'])
    transfer = pd.concat(
        [transfer, transfer['dataset_config'].apply(pd.Series)], axis=1
    ).drop('dataset_config', axis=1)

    # Mark every column as belonging to the transferred model, leaving the
    # ones that already carry a 't_' / 'transfer' prefix untouched.
    transfer.columns = [
        col if col.startswith(('t_', 'transfer')) else 't_' + col
        for col in transfer.columns
    ]
    transfer = transfer.sort_values(
        ['t_multi_match_n', 't_image_n', 't_multi_match_base_seed', 't_image_base_seed']
    )
    return transfer


def get_transfer_entries(old_experiment_name, overall_best,
                         transfer_fn='nnsysident.models.transfer_functions.core_transfer',
                         transfer_fabrikant='kklurz'):
    """Build Transfer-table entries from the best trained models of an experiment.

    Fetches `TrainedModel * Dataset * Seed * Experiments.Restrictions` for the
    experiment, expands `dataset_config` into columns, and keeps the
    highest-scoring model for every combination of `multi_match_n`, `image_n`,
    `multi_match_base_seed` and `image_base_seed`.

    Args:
        old_experiment_name: Value of `experiment_name` to take the models from.
        overall_best: If truthy, only models trained with the largest
            `multi_match_n` and the largest `image_n` are considered.
        transfer_fn: Value stored as `transfer_fn` in every entry.
        transfer_fabrikant: Value stored as `transfer_fabrikant` in every entry.

    Returns:
        List of dicts with the keys `transfer_fn`, `transfer_config`,
        `transfer_comment` and `transfer_fabrikant`, ordered by `multi_match_n`
        and `image_n`.

    Raises:
        AssertionError: If the experiment does not contain exactly one model function.
    """
    tm = pd.DataFrame(
        (TrainedModel * Dataset * Seed * Experiments.Restrictions
         & 'experiment_name="{}"'.format(old_experiment_name)).fetch()
    )
    tm = pd.concat([tm, tm['dataset_config'].apply(pd.Series)], axis=1).drop('dataset_config', axis=1)

    model_fn = np.unique(tm['model_fn'])
    assert len(model_fn) == 1, "Must have exactly 1 model function in experiment"
    model_fn = model_fn[0]

    # Filter out best model(s).
    # Plain truthiness instead of `is True`: an identity check silently skips
    # the filter for truthy flags that are not the bool singleton (e.g. np.True_).
    if overall_best:
        tm = tm.loc[(tm['multi_match_n'] == tm['multi_match_n'].max()) & (tm['image_n'] == tm['image_n'].max())]
    tm = (
        tm.sort_values('score', ascending=False)
        .drop_duplicates(['multi_match_n', 'image_n', 'multi_match_base_seed', 'image_base_seed'])
        .sort_values(['multi_match_n', 'image_n'])
    )

    # Make entries for the Transfer table.
    model_name = model_fn.split('.')[-1]
    comment_template = ', multi_match_n={}, multi_match_base_seed={}, image_n={}, image_base_seed={}'
    entries = [
        dict(
            transfer_fn=transfer_fn,
            transfer_config=dict(
                t_model_hash=row.model_hash,
                t_dataset_hash=row.dataset_hash,
                t_trainer_hash=row.trainer_hash,
            ),
            transfer_comment=model_name + comment_template.format(
                row.multi_match_n,
                row.multi_match_base_seed,
                row.image_n,
                row.image_base_seed,
            ),
            transfer_fabrikant=transfer_fabrikant,
        )
        for _, row in tm.iterrows()
    ]
    return entries

---