# Create, save, and populate new parameter sets for `SplineLNPParams()`

In this Jupyter notebook, new parameter sets for a grid search are created and saved as JSON. The notebook also shows how to load these parameter sets and populate the parameter table `SplineLNPParams()` of the spline-based LNP models.

## Setup

In [1]:
run -im djd.main -- --dbname=dj_hmovmodels --r

For remote access to work, make sure to first open an SSH tunnel with MySQL
port forwarding. Run the `djdtunnel` script in a separate terminal, with
optional `--user` argument if your local and remote user names differ.
Or, open the tunnel manually with:
  ssh -NL 3306:huxley.neuro.bzm:3306 -p 1021 USERNAME@tunnel.bio.lmu.de
Connecting execute@localhost:3306
Connected to database 'dj_hmovmodels' as 'execute@10.153.172.3'
For remote file access to work, make sure to first mount the filesystem at tunnel.bio.lmu.de:1021 via SSHFS with `hux -r`


In [2]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# Automatically reload modules to get code changes without restarting kernel
# NOTE: Does not work for DJD table modules
%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings("ignore")

## Check what is already populated

In [7]:
# Uncomment to drop SplineLNPParams() together with its dependent tables
#SplineLNPParams().drop()

`dj_hmovmodels`.`spline_l_n_p_params` (48 tuples)
`dj_hmovmodels`.`__spline_l_n_p` (282 tuples)
`dj_hmovmodels`.`__spline_l_n_p__eval` (282 tuples)


Proceed? [yes, No]:  yes


Tables dropped.  Restart kernel.


In [3]:
SplineLNPParams()

*(no entries)* The table has the following attributes:

| Attribute | Description |
|---|---|
| `spl_paramset` | parameter set ID |
| `spl_distr` | nonlinearity in LNP |
| `spl_alpha` | weighting betw. L2 and L1 penalty (alpha=1 only uses L1) |
| `spl_lambda` | regularization parameter of penalty term |
| `spl_lr` | initial learning rate for the JAX optimizer |
| `spl_max_iter` | maximum number of iterations for the solver |
| `spl_dt` | inverse of the sampling rate |
| `spl_spat_df` | degrees of freedom num of basis functions for spatial domain |
| `spl_temp_df` | degrees of freedom num of basis functions for temp component |
| `spl_pshf` | fit post-spike history filter |
| `spl_pshf_len` | length of the post-spike history filter |
| `spl_pshf_df` | number of basis functions for post-spike history filter |
| `spl_verb` | when verbose=n progress will be printed in every n steps |
| `spl_metric` | 'None', 'mse', 'r2', or 'corrcoef' |
| `spl_norm_y` | normalize observed responses |
| `spl_nlag` | number of time steps of the kernel |
| `spl_shift` | shift kernel to not predict itself |
| `spl_spat_scaling` | scaling factor for spatial resolution of movie |
| `spl_opto` | fit optogenetics filter |
| `spl_opto_len` | length of the opto filter (number of time steps) |
| `spl_opto_df` | number of basis functions for opto filter |
| `spl_run` | fit running filter |
| `spl_run_len` | length of the running filter (number of time steps) |
| `spl_run_df` | number of basis functions for running filter |
| `spl_eye` | fit eye filter |
| `spl_eye_len` | length of the eye filter (number of time steps) |
| `spl_eye_df` | number of basis functions for eye filter |
| `spl_eye_lpfilt` | lowpass filter eye data |
| `spl_eye_cutoff` | cutoff freq. for lowpass filter (set 0 if lpfilt=False) |


## Generate parameter sets for grid search

Some parameters are held fixed (e.g. learning rate, `nlag`, ...), while others are varied in the grid search.

Varied parameters are the following:
* regularization constant `spl_lambda`
* number of spline basis functions in the spatial dimension (frame height) `spl_spat_df`
* opto information as model input `spl_opto`
* running information as model input `spl_run`
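The grid is the Cartesian product of the four varied parameters. As a minimal plain-Python sketch (no DataJoint needed), `itertools.product` enumerates the combinations in the same order as nested loops over opto, run, `spl_spat_df` and `spl_lambda` (innermost), so the size of the grid and the position of each combination can be checked before any parameter set is generated. The grid values are copied from the grid-search cell.

```python
import itertools

# Grid values as defined in the grid-search cell
lambda_grid = [0.2, 0.4, 0.6, 0.8, 1.0, 1.2]
spat_df_grid = [6, 8]
opto_grid = ['False', 'True']
run_grid = ['False', 'True']

# product() advances the last iterable fastest, which matches nested loops
# ordered opto -> run -> spat_df -> lambda (innermost)
grid = list(itertools.product(opto_grid, run_grid, spat_df_grid, lambda_grid))

print(len(grid))   # 48 combinations
print(grid[0])     # ('False', 'False', 6, 0.2)
print(grid[-1])    # ('True', 'True', 8, 1.2)
```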

In [8]:
# Define fixed parameters
# Boolean-like options are given as the strings 'True'/'False'
lr = 0.01
max_iter = 4000
nlag = 8
shift = 1
temp_df = 7
distr = 'softplus'
alpha = 1.0
dt = 0.033
pshf = 'True'
pshf_len = 8
pshf_df = 6
verbose = 200
metric = 'corrcoef'
norm_y = 'False'
spat_scaling = 0.06
eye = 'False'
eye_len = 0
eye_df = 0
eye_lpfilt = 'False'
eye_cutoff = 0

# Define ranges for grid search
lambda_grid = [0.2, 0.4, 0.6, 0.8, 1.0, 1.2]
spat_df_grid = [6, 8]
opto_grid = ['False', 'True']
run_grid = ['False', 'True']

# Loop over conditions; parameter set IDs start at 100
paramdicts_grid = []
parameterset_idx = 100
for opto in opto_grid:
    # Opto filter: 8 time steps with 7 basis functions, or disabled.
    # Compare strings with ==, not `is` (which tests object identity).
    opto_len, opto_df = (8, 7) if opto == 'True' else (0, 0)
    for run in run_grid:
        # Running filter: same configuration as the opto filter
        run_len, run_df = (8, 7) if run == 'True' else (0, 0)
        for spat_df in spat_df_grid:
            for lambda_param in lambda_grid:
                # Generate parameter dict
                param_dict = SplineLNPParams().generate_paramset(
                    paramseti=parameterset_idx,
                    # Grid
                    spl_lambda=lambda_param,
                    spat_df=spat_df,
                    # Behavior config
                    opto=opto,
                    opto_len=opto_len,
                    opto_df=opto_df,
                    run=run,
                    run_len=run_len,
                    run_df=run_df,
                    # Fixed params
                    lr=lr,
                    max_iter=max_iter,
                    nlag=nlag,
                    shift=shift,
                    temp_df=temp_df,
                    distr=distr,
                    alpha=alpha,
                    dt=dt,
                    pshf=pshf,
                    pshf_len=pshf_len,
                    pshf_df=pshf_df,
                    verbose=verbose,
                    metric=metric,
                    norm_y=norm_y,
                    spat_scaling=spat_scaling,
                    eye=eye,
                    eye_len=eye_len,
                    eye_df=eye_df,
                    eye_lpfilt=eye_lpfilt,
                    eye_cutoff=eye_cutoff,
                    )
                paramdicts_grid.append(param_dict)
                parameterset_idx += 1
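The `opto` and `run` flags are the strings `'True'`/`'False'` rather than Python booleans, so they should be compared by value with `==`; `is` tests object identity and only works when CPython happens to intern both strings. A slightly stricter sketch of the same length/basis lookup, using a hypothetical helper `filter_config` (not part of djd), also rejects unexpected values such as a real boolean `True` instead of silently reusing the values from a previous loop iteration:

```python
def filter_config(flag, length=8, n_basis=7):
    """Return (filter length, number of basis functions) for a 'True'/'False' flag.

    Hypothetical helper: the defaults mirror the opto/run settings used in the
    grid search (8 time steps, 7 basis functions).
    """
    if flag == 'True':
        return length, n_basis
    if flag == 'False':
        return 0, 0
    raise ValueError(f"expected 'True' or 'False', got {flag!r}")

opto_len, opto_df = filter_config('True')   # (8, 7)
run_len, run_df = filter_config('False')    # (0, 0)
```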

Check the number of generated parameter sets:

In [9]:
len(lambda_grid) * len(spat_df_grid) * len(opto_grid) * len(run_grid)

48

In [10]:
len(paramdicts_grid)

48

Check the first and last parameter sets:

In [11]:
paramdicts_grid[0]

{'spl_paramset': 100,
 'spl_distr': 'softplus',
 'spl_alpha': 1.0,
 'spl_lambda': 0.2,
 'spl_lr': 0.01,
 'spl_max_iter': 4000,
 'spl_dt': 0.033,
 'spl_spat_df': 6,
 'spl_temp_df': 7,
 'spl_pshf': 'True',
 'spl_pshf_len': 8,
 'spl_pshf_df': 6,
 'spl_verb': 200,
 'spl_metric': 'corrcoef',
 'spl_norm_y': 'False',
 'spl_nlag': 8,
 'spl_shift': 1,
 'spl_spat_scaling': 0.06,
 'spl_opto': 'False',
 'spl_opto_len': 0,
 'spl_opto_df': 0,
 'spl_run': 'False',
 'spl_run_len': 0,
 'spl_run_df': 0,
 'spl_eye': 'False',
 'spl_eye_len': 0,
 'spl_eye_df': 0,
 'spl_eye_lpfilt': 'False',
 'spl_eye_cutoff': 0}

In [12]:
paramdicts_grid[-1]

{'spl_paramset': 147,
 'spl_distr': 'softplus',
 'spl_alpha': 1.0,
 'spl_lambda': 1.2,
 'spl_lr': 0.01,
 'spl_max_iter': 4000,
 'spl_dt': 0.033,
 'spl_spat_df': 8,
 'spl_temp_df': 7,
 'spl_pshf': 'True',
 'spl_pshf_len': 8,
 'spl_pshf_df': 6,
 'spl_verb': 200,
 'spl_metric': 'corrcoef',
 'spl_norm_y': 'False',
 'spl_nlag': 8,
 'spl_shift': 1,
 'spl_spat_scaling': 0.06,
 'spl_opto': 'True',
 'spl_opto_len': 8,
 'spl_opto_df': 7,
 'spl_run': 'True',
 'spl_run_len': 8,
 'spl_run_df': 7,
 'spl_eye': 'False',
 'spl_eye_len': 0,
 'spl_eye_df': 0,
 'spl_eye_lpfilt': 'False',
 'spl_eye_cutoff': 0}
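Comparing the two dictionaries by eye is error-prone. A small sketch with a hypothetical helper `diff_paramsets` lists only the keys whose values differ; in the notebook it would be called as `diff_paramsets(paramdicts_grid[0], paramdicts_grid[-1])`. Here it runs on a subset of the keys of the two printouts above, and should report only the grid and behavior-filter entries:

```python
def diff_paramsets(a, b):
    """Return {key: (value_in_a, value_in_b)} for all keys whose values differ."""
    keys = sorted(a.keys() | b.keys())
    return {k: (a.get(k), b.get(k)) for k in keys if a.get(k) != b.get(k)}

# Subset of the keys of the first and last parameter sets shown above
first = {'spl_paramset': 100, 'spl_lambda': 0.2, 'spl_spat_df': 6,
         'spl_opto': 'False', 'spl_opto_len': 0, 'spl_opto_df': 0,
         'spl_run': 'False', 'spl_run_len': 0, 'spl_run_df': 0,
         'spl_lr': 0.01, 'spl_nlag': 8}
last = {'spl_paramset': 147, 'spl_lambda': 1.2, 'spl_spat_df': 8,
        'spl_opto': 'True', 'spl_opto_len': 8, 'spl_opto_df': 7,
        'spl_run': 'True', 'spl_run_len': 8, 'spl_run_df': 7,
        'spl_lr': 0.01, 'spl_nlag': 8}

for key, (v_first, v_last) in diff_paramsets(first, last).items():
    print(f'{key}: {v_first} -> {v_last}')
```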

## Save as .json

In [13]:
import json
import os

import numpy as np

# Timestamp for the file name, e.g. '2021-03-18_165333'
date_str = np.datetime_as_string(np.datetime64('now')).replace(":", "").replace("T", "_")
filename = 'SplineLNPParams_grid_search_{:s}.json'.format(date_str)
#path = './'
path = '/mnt/hux/mudata/djstore/hmov_paramsets'
full_path = os.path.join(path, filename)
with open(full_path, 'w') as file_out:
    json.dump(paramdicts_grid, file_out, indent=2)
print('JSON file saved {:s}'.format(full_path))

JSON file saved /mnt/hux/mudata/djstore/hmov_paramsets/SplineLNPParam_grid_search_2021-03-18_165333.json
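The timestamp in the file name means that repeated exports do not overwrite each other. A minimal sketch of the same save step, with two changes that are assumptions rather than the notebook's code: the timestamp comes from the standard-library `datetime` (in UTC) instead of NumPy, and the file is read back to confirm that the JSON round trip preserves the parameter dicts. `timestamped_filename` and `save_paramsets` are hypothetical helper names.

```python
import json
import os
import tempfile
from datetime import datetime, timezone

def timestamped_filename(prefix, now=None):
    """Return '<prefix>_YYYY-MM-DD_HHMMSS.json', using the current UTC time by default."""
    if now is None:
        now = datetime.now(timezone.utc)
    return f"{prefix}_{now.strftime('%Y-%m-%d_%H%M%S')}.json"

def save_paramsets(paramdicts, path, prefix='SplineLNPParams_grid_search'):
    """Write the parameter dicts to a timestamped JSON file and read them back."""
    full_path = os.path.join(path, timestamped_filename(prefix))
    with open(full_path, 'w') as file_out:
        json.dump(paramdicts, file_out, indent=2)
    # Round trip: the reloaded list should equal what was written
    with open(full_path) as file_in:
        reloaded = json.load(file_in)
    if reloaded != paramdicts:
        raise ValueError(f'{full_path} does not round-trip')
    return full_path

# Usage with a temporary directory and a tiny stand-in list
example = [{'spl_paramset': 100, 'spl_lambda': 0.2, 'spl_opto': 'False'}]
with tempfile.TemporaryDirectory() as tmp_dir:
    saved = save_paramsets(example, tmp_dir)
    print(os.path.basename(saved))
```

In the notebook, `save_paramsets(paramdicts_grid, path)` would take the place of the manual `open`/`json.dump` calls.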


## Populate `SplineLNPParams()`

In [14]:
SplineLNPParams().populate_saved_paramset('SplineLNPParams_grid_search_2021-03-18_165333.json')

Parameter set dictionary loaded: SplineLNPParam_grid_search_2021-03-18_165333.json contains 48 parametersets.
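Before populating, a quick consistency check of the saved file can catch mistakes such as duplicated `spl_paramset` IDs or parameter dicts with differing keys. The helper `check_paramset_file` below is a hypothetical sketch, not part of djd, and makes no assumption about how `populate_saved_paramset` reads the file; in the notebook it would be called on `full_path` from the save step.

```python
import json
import os
import tempfile

def check_paramset_file(full_path, id_key='spl_paramset'):
    """Load a parameter-set JSON file, check it, and return the number of sets."""
    with open(full_path) as file_in:
        paramdicts = json.load(file_in)
    if not isinstance(paramdicts, list) or not paramdicts:
        raise ValueError('expected a non-empty list of parameter dicts')
    # Every parameter set should define exactly the same attributes
    keys = set(paramdicts[0])
    for paramdict in paramdicts:
        if set(paramdict) != keys:
            raise ValueError(f'parameter set {paramdict.get(id_key)} has different keys')
    # Parameter set IDs must be unique
    ids = [paramdict[id_key] for paramdict in paramdicts]
    if len(ids) != len(set(ids)):
        raise ValueError('duplicate parameter set IDs')
    return len(paramdicts)

# Usage with a small temporary file
with tempfile.TemporaryDirectory() as tmp_dir:
    example_path = os.path.join(tmp_dir, 'example.json')
    with open(example_path, 'w') as file_out:
        json.dump([{'spl_paramset': 100, 'spl_lambda': 0.2},
                   {'spl_paramset': 101, 'spl_lambda': 0.4}], file_out)
    print(check_paramset_file(example_path))  # 2
```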


In [15]:
SplineLNPParams()

| spl_paramset | spl_distr | spl_alpha | spl_lambda | spl_lr | spl_max_iter | spl_dt | spl_spat_df | spl_temp_df | spl_pshf | spl_pshf_len | spl_pshf_df | spl_verb | spl_metric | spl_norm_y | spl_nlag | spl_shift | spl_spat_scaling | spl_opto | spl_opto_len | spl_opto_df | spl_run | spl_run_len | spl_run_df | spl_eye | spl_eye_len | spl_eye_df | spl_eye_lpfilt | spl_eye_cutoff |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 100 | softplus | 1.0 | 0.2 | 0.01 | 4000 | 0.033 | 6 | 7 | True | 8 | 6 | 200 | corrcoef | False | 8 | 1 | 0.06 | False | 0 | 0 | False | 0 | 0 | False | 0 | 0 | False | 0 |
| 101 | softplus | 1.0 | 0.4 | 0.01 | 4000 | 0.033 | 6 | 7 | True | 8 | 6 | 200 | corrcoef | False | 8 | 1 | 0.06 | False | 0 | 0 | False | 0 | 0 | False | 0 | 0 | False | 0 |
| 102 | softplus | 1.0 | 0.6 | 0.01 | 4000 | 0.033 | 6 | 7 | True | 8 | 6 | 200 | corrcoef | False | 8 | 1 | 0.06 | False | 0 | 0 | False | 0 | 0 | False | 0 | 0 | False | 0 |
| 103 | softplus | 1.0 | 0.8 | 0.01 | 4000 | 0.033 | 6 | 7 | True | 8 | 6 | 200 | corrcoef | False | 8 | 1 | 0.06 | False | 0 | 0 | False | 0 | 0 | False | 0 | 0 | False | 0 |
| 104 | softplus | 1.0 | 1.0 | 0.01 | 4000 | 0.033 | 6 | 7 | True | 8 | 6 | 200 | corrcoef | False | 8 | 1 | 0.06 | False | 0 | 0 | False | 0 | 0 | False | 0 | 0 | False | 0 |
| 105 | softplus | 1.0 | 1.2 | 0.01 | 4000 | 0.033 | 6 | 7 | True | 8 | 6 | 200 | corrcoef | False | 8 | 1 | 0.06 | False | 0 | 0 | False | 0 | 0 | False | 0 | 0 | False | 0 |
| 106 | softplus | 1.0 | 0.2 | 0.01 | 4000 | 0.033 | 8 | 7 | True | 8 | 6 | 200 | corrcoef | False | 8 | 1 | 0.06 | False | 0 | 0 | False | 0 | 0 | False | 0 | 0 | False | 0 |
| 107 | softplus | 1.0 | 0.4 | 0.01 | 4000 | 0.033 | 8 | 7 | True | 8 | 6 | 200 | corrcoef | False | 8 | 1 | 0.06 | False | 0 | 0 | False | 0 | 0 | False | 0 | 0 | False | 0 |
| 108 | softplus | 1.0 | 0.6 | 0.01 | 4000 | 0.033 | 8 | 7 | True | 8 | 6 | 200 | corrcoef | False | 8 | 1 | 0.06 | False | 0 | 0 | False | 0 | 0 | False | 0 | 0 | False | 0 |
| 109 | softplus | 1.0 | 0.8 | 0.01 | 4000 | 0.033 | 8 | 7 | True | 8 | 6 | 200 | corrcoef | False | 8 | 1 | 0.06 | False | 0 | 0 | False | 0 | 0 | False | 0 | 0 | False | 0 |
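Because the IDs are assigned consecutively in loop order, each combination of the behavior inputs occupies a contiguous block of 12 IDs (2 values of `spl_spat_df` times 6 values of `spl_lambda`), which is why the rows above only differ in lambda and spatial basis size. A plain-Python sketch that reconstructs these blocks, assuming the grid and the start ID of the grid-search cell stay unchanged:

```python
import itertools
from collections import defaultdict

# Grid values and start ID as in the grid-search cell
lambda_grid = [0.2, 0.4, 0.6, 0.8, 1.0, 1.2]
spat_df_grid = [6, 8]
opto_grid = ['False', 'True']
run_grid = ['False', 'True']

# Enumerate in loop order and group the IDs by (opto, run)
ids_by_condition = defaultdict(list)
combos = itertools.product(opto_grid, run_grid, spat_df_grid, lambda_grid)
for paramset_id, (opto, run, *_) in enumerate(combos, start=100):
    ids_by_condition[(opto, run)].append(paramset_id)

for (opto, run), ids in ids_by_condition.items():
    print(f'opto={opto}, run={run}: {ids[0]}-{ids[-1]}')
# opto=False, run=False: 100-111
# opto=False, run=True: 112-123
# opto=True, run=False: 124-135
# opto=True, run=True: 136-147
```

When working against the database instead, a DataJoint restriction such as `SplineLNPParams() & {'spl_opto': 'True', 'spl_run': 'False'}` should select the same block of rows.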
