# Glm and SplineLNP model tables

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
run -im djd.main -- --dbname=dj_lisa --r

For remote access to work, make sure to first open an SSH tunnel with MySQL
port forwarding. Run the `djdtunnel` script in a separate terminal, with
optional `--user` argument if your local and remote user names differ.
Or, open the tunnel manually with:
  ssh -NL 3306:huxley.neuro.bzm:3306 -p 1021 USERNAME@tunnel.bio.lmu.de
Connecting execute@localhost:3306
Connected to database 'dj_lisa' as 'execute@10.153.172.3'
For remote file access to work, make sure to first mount the filesystem at tunnel.bio.lmu.de:1021 via SSH


In [None]:
from djd import glm # for plotting (glm._plot_RF is called in the plotting cells below)

# Populate new parameter sets

To add a new parameter set, uncomment the block below and fill in the dictionary, giving it a parameter set ID that is not yet in use:

In [None]:
# Template for a new Glm parameter set. Uncomment the dictionary and the
# populate call below, adjust the values and run the cell once.
# glm_run1 = {
#     'glm_paramset':1,
#     'glm_distr':'softplus',
#     'glm_alpha':1.0, # the weighting between L1 penalty and L2 penalty term of the loss function
#     'glm_lambda':0.00015, # regularization parameter of penalty term; default: 0.1
#     'glm_solver':'batch-gradient', # optimization method
#     'glm_lr':0.7, # learning rate for gradient descent
#     'glm_max_iter':1000,# maximum number of iterations for the solver
#     'glm_tol':1e-6, # convergence threshold (stopping criterion)
#     'glm_seed':0, # seed of the random number generator used to initialize the solution
#     'glm_norm_y':'True',
#     'glm_nlag':8,
#     'glm_shift':1,
# }

# NOTE(review): assumes GlmParams defines a populate() that accepts a parameter
# dictionary - confirm against the djd table definition.
# GlmParams().populate(glm_run1)

# Bare last expression: displays the current contents of the GlmParams table.
GlmParams()

In [None]:
# Template for a new SplineLNP parameter set. Uncomment the dictionary and the
# populate call below, adjust the values and run the cell once.
# splineLNP_run1 = {
#     'spl_paramset':1,
#     'spl_nonlin':'softplus',
#     'spl_alpha':1,
#     'spl_lambda':3,
#     'spl_lr':1e-2,
#     'spl_max_iter':2000,
#     'spl_dt':0.033,
#     'spl_spat_df':6,
#     'spl_temp_df':6,
#     'spl_psh_filt':'True',
#     'spl_verb':200,
#     'spl_metric':'corrcoef',
#     'spl_norm_y':'False',
#     'spl_nlag':8,
#     'spl_shift':1,
# }

# NOTE(review): assumes SplineLNPParams defines a populate() that accepts a
# parameter dictionary - confirm against the djd table definition.
# SplineLNPParams().populate(splineLNP_run1)

# Bare last expression: displays the current contents of the SplineLNPParams table.
SplineLNPParams()

# Populate model tables

Populating the `Glm()` table in particular can take a while for new parameter sets, depending on the maximum number of iterations.

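To test, try populating a single (good) unit first by passing the restriction to `populate()`, e.g. `Glm.populate({'m': 'Ntsr1Cre_2019_0008', 's':3, 'e':7, 'u':14, 'glm_paramset':1})` (for `SplineLNP`, restrict by `spl_paramset` instead). DataJoint refuses to run `populate()` on an already restricted table such as `(SplineLNP & {...})`, so the restriction has to be given as an argument.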

In [None]:
# Uncomment to populate the Glm table for all pending keys (can take a while).
#Glm.populate()

In [None]:
# Uncomment to populate the SplineLNP table for all pending keys.
#SplineLNP.populate()

# Evaluate models

In [None]:
# Histogram bin edges for correlation coefficients: 40 bins of width 0.05 that
# cover the full range [-1, 1]. np.arange(-1, 1, 0.05) excludes the stop value,
# so its last edge was 0.95 and units with r > 0.95 were silently left out of
# the histograms; linspace includes both end points exactly.
bins = np.linspace(-1, 1, 41)

# Glm

In [None]:
# Histogram of the per-unit train/test correlation coefficients of the Glm fits.
# Both columns are fetched once and reused for the histograms, the mean markers
# and the printed summary instead of re-querying the database for every use.
glm_r_train, glm_r_test = Glm.GlmEval().fetch('glm_r_train', 'glm_r_test')
glm_r_train_mean = np.mean(glm_r_train)
glm_r_test_mean = np.mean(glm_r_test)

plt.hist(glm_r_train, bins=bins, label='training set', alpha=0.5)
plt.hist(glm_r_test, bins=bins, label='test set', alpha=0.5)
# axvline spans the full axis height, so the dashed mean markers no longer rely
# on a hard-coded ymax=50 that may be too short or stretch the y-axis.
plt.axvline(x=glm_r_train_mean, linestyle='dashed', color='tab:blue')
plt.axvline(x=glm_r_test_mean, linestyle='dashed', color='tab:orange')
sns.despine()
plt.legend()
plt.xlabel('correlation coefficient')
plt.ylabel('number of units')
# NOTE(review): fetch1 needs exactly one distinct glm_paramset in GlmEval, and the
# histogram above pools every row of GlmEval - confirm only one paramset is populated.
plt.title(('Glm performance per unit for paramset {:02d}').format((dj.U('glm_paramset') & Glm.GlmEval()).fetch1('glm_paramset')));

print('Mean performance training set: ', np.round(glm_r_train_mean, 3))
print('Mean performance test set: ', np.round(glm_r_test_mean, 3))

# SplineLNP

### without post-spike history filter

In [None]:
# Show the stored parameters of SplineLNP parameter set 2.
SplineLNPParams() & dict(spl_paramset=2)

In [None]:
# Test and training correlations of every SplineLNPEval row of parameter set 2,
# retrieved together from one restricted query expression.
all_perf_test, all_perf_train = (
    SplineLNP.SplineLNPEval() & {'spl_paramset':2}
).fetch('spl_r_test', 'spl_r_train')

In [None]:
# Histogram of per-unit train/test correlation coefficients for the SplineLNP
# fits of parameter set 2 (section: without post-spike history filter).
# The restriction is defined once and shared by the data and the title query;
# the original title queried {'spl_paramset':1, 'spl_psh_filt':'True'} and
# therefore labelled this plot with the wrong parameter set.
spl_restriction = {'spl_paramset':2, 'spl_psh_filt':'False'}
# Fetch both columns once instead of re-querying the database for every use.
spl_r_train, spl_r_test = (SplineLNP.SplineLNPEval() & spl_restriction).fetch('spl_r_train', 'spl_r_test')
spl_r_train_mean = np.mean(spl_r_train)
spl_r_test_mean = np.mean(spl_r_test)

plt.hist(spl_r_train, bins=bins, label='training set', alpha=0.5)
plt.hist(spl_r_test, bins=bins, label='test set', alpha=0.5)
# axvline spans the full axis height, so the dashed mean markers no longer rely
# on a hard-coded ymax=50.
plt.axvline(x=spl_r_train_mean, linestyle='dashed', color='tab:blue')
plt.axvline(x=spl_r_test_mean, linestyle='dashed', color='tab:orange')
sns.despine()
plt.legend()
plt.xlabel('correlation coefficient')
plt.ylabel('number of units')
plt.title(('SplineLNP performance per unit for paramset {:02d}').format((dj.U('spl_paramset') & SplineLNP.SplineLNPEval() & spl_restriction).fetch1('spl_paramset')));

print('Mean performance training set: ', np.round(spl_r_train_mean, 3))
print('Mean performance test set: ', np.round(spl_r_test_mean, 3))

### with post-spike history filter

In [None]:
# Show the stored parameters of SplineLNP parameter set 1.
SplineLNPParams() & dict(spl_paramset=1)

In [None]:
# Histogram of per-unit train/test correlation coefficients for the SplineLNP
# fits of parameter set 1 (section: with post-spike history filter).
spl_restriction = {'spl_paramset':1, 'spl_psh_filt':'True'}
# Fetch both columns once instead of re-querying the database for every
# histogram, mean marker and printed summary.
spl_r_train, spl_r_test = (SplineLNP.SplineLNPEval() & spl_restriction).fetch('spl_r_train', 'spl_r_test')
spl_r_train_mean = np.mean(spl_r_train)
spl_r_test_mean = np.mean(spl_r_test)

plt.hist(spl_r_train, bins=bins, label='training set', alpha=0.5)
plt.hist(spl_r_test, bins=bins, label='test set', alpha=0.5)
# axvline spans the full axis height, so the dashed mean markers no longer rely
# on a hard-coded ymax=50.
plt.axvline(x=spl_r_train_mean, linestyle='dashed', color='tab:blue')
plt.axvline(x=spl_r_test_mean, linestyle='dashed', color='tab:orange')
sns.despine()
plt.legend()
plt.xlabel('correlation coefficient')
plt.ylabel('number of units')
plt.title(('SplineLNP performance per unit for paramset {:02d}').format((dj.U('spl_paramset') & SplineLNP.SplineLNPEval() & spl_restriction).fetch1('spl_paramset')));

print('Mean performance training set: ', np.round(spl_r_train_mean, 3))
print('Mean performance test set: ', np.round(spl_r_test_mean, 3))

### with post-spike filter but higher regularization

In [None]:
# Show the stored parameters of SplineLNP parameter set 3.
SplineLNPParams() & dict(spl_paramset=3)

In [None]:
# Histogram of per-unit train/test correlation coefficients for the SplineLNP
# fits of parameter set 3 (section: post-spike filter with higher regularization).
spl_restriction = {'spl_paramset':3, 'spl_psh_filt':'True'}
# Fetch both columns once instead of re-querying the database for every
# histogram, mean marker and printed summary.
spl_r_train, spl_r_test = (SplineLNP.SplineLNPEval() & spl_restriction).fetch('spl_r_train', 'spl_r_test')
spl_r_train_mean = np.mean(spl_r_train)
spl_r_test_mean = np.mean(spl_r_test)

plt.hist(spl_r_train, bins=bins, label='training set', alpha=0.5)
plt.hist(spl_r_test, bins=bins, label='test set', alpha=0.5)
# axvline spans the full axis height, so the dashed mean markers no longer rely
# on a hard-coded ymax=50.
plt.axvline(x=spl_r_train_mean, linestyle='dashed', color='tab:blue')
plt.axvline(x=spl_r_test_mean, linestyle='dashed', color='tab:orange')
sns.despine()
plt.legend()
plt.xlabel('correlation coefficient')
plt.ylabel('number of units')
plt.title(('SplineLNP performance per unit for paramset {:02d}').format((dj.U('spl_paramset') & SplineLNP.SplineLNPEval() & spl_restriction).fetch1('spl_paramset')));

print('Mean performance training set: ', np.round(spl_r_train_mean, 3))
print('Mean performance test set: ', np.round(spl_r_test_mean, 3))

## Plot best units

### Glm

In [None]:
# Mean and standard deviation (each rounded to 3 decimals) of the Glm test-set
# correlations; they define the "best units" threshold in the next cell.
# The column is fetched once rather than once per statistic.
glm_r_test = Glm.GlmEval().fetch('glm_r_test')
glm_mean_test = np.round(np.mean(glm_r_test), 3)
glm_std_test = np.round(np.std(glm_r_test), 3)

In [None]:
# Keys of mouse Ntsr1Cre_2019_0008 whose test correlation exceeds
# mean + 2 * std of the Glm test-set correlations.
glm_threshold = glm_mean_test+2*glm_std_test
glm_best_units = (
    Glm*Glm.GlmEval()
    & 'glm_r_test > {:.3f}'.format(glm_threshold)
    & {'m':'Ntsr1Cre_2019_0008'}
)
glm_keys = glm_best_units.fetch(dj.key)
print('Number of units that are better than mean correlation + 2 times std: ', len(glm_keys))

In [None]:
# One glm._plot_RF call per selected Glm key.
for unit_key in glm_keys:
    glm._plot_RF(unit_key)

### SplineLNP

In [None]:
# Display all stored SplineLNP parameter sets to choose `paramset` from.
SplineLNPParams()

In [None]:
# SplineLNP parameter set ID (spl_paramset) used by the remaining cells of this section.
paramset = 3

In [None]:
# Mean train/test correlation over all SplineLNPEval rows of the chosen paramset.
spl_eval_selected = SplineLNP.SplineLNPEval() & {'spl_paramset':paramset}
spl_train_mean = np.mean(spl_eval_selected.fetch('spl_r_train'))
spl_test_mean = np.mean(spl_eval_selected.fetch('spl_r_test'))
print('Mean performance training set: ', np.round(spl_train_mean, 3))
print('Mean performance test set: ', np.round(spl_test_mean, 3))

In [None]:
# Mean and standard deviation (each rounded to 3 decimals) of the SplineLNP
# test-set correlations for the chosen paramset; they define the "best units"
# threshold in the next cell. The column is fetched once rather than once per
# statistic.
spl_r_test = (SplineLNP.SplineLNPEval()&{'spl_paramset':paramset}).fetch('spl_r_test')
spl_mean_test = np.round(np.mean(spl_r_test), 3)
spl_std_test = np.round(np.std(spl_r_test), 3)

In [None]:
# Keys of mouse Ntsr1Cre_2019_0008 (chosen paramset) whose test correlation
# exceeds mean + 1.5 * std of the SplineLNP test-set correlations.
spl_threshold = spl_mean_test+1.5*spl_std_test
spl_best_units = (
    SplineLNP*SplineLNP.SplineLNPEval()
    & {'spl_paramset':paramset}
    & 'spl_r_test > {:.3f}'.format(spl_threshold)
    & {'m':'Ntsr1Cre_2019_0008'}
)
spl_keys = spl_best_units.fetch(dj.key)
print('Number of units that are better than mean correlation + 1.5 times std: ', len(spl_keys))

In [None]:
# One glm._plot_RF call per selected SplineLNP key.
# NOTE(review): the plotting helper lives in the glm module but is called with
# SplineLNP keys here - presumably it handles both model tables; confirm.
for unit_key in spl_keys:
    glm._plot_RF(unit_key)

## Plots for all units

In [None]:
# Uncomment to call glm._plot_RF for every key in the Glm table.
# keys_glm = Glm().fetch(dj.key)
# for key in keys_glm:
#     glm._plot_RF(key)

In [None]:
# Uncomment to call glm._plot_RF for every key in the SplineLNP table.
# keys_spl = SplineLNP().fetch(dj.key)
# for key in keys_spl:
#     glm._plot_RF(key)