# Create, save, and populate new parameter sets for `SplineLNPParams()`

In this Jupyter notebook, new parameter sets for a grid search are created and saved. It also shows how to load these parameter sets and populate them into the parameter table of the spline-based LNP models.

## Setup

In [1]:
run -im djd.main -- --dbname=dj_hmov --r

For remote access to work, make sure to first open an SSH tunnel with MySQL
port forwarding. Run the `djdtunnel` script in a separate terminal, with
optional `--user` argument if your local and remote user names differ.
Or, open the tunnel manually with:
  ssh -NL 3306:huxley.neuro.bzm:3306 -p 1021 USERNAME@tunnel.bio.lmu.de
Connecting execute@localhost:3306
Connected to database 'dj_hmov' as 'execute@10.153.172.3'
For remote file access to work, make sure to first mount the filesystem at tunnel.bio.lmu.de:1021 via SSHFS with `hux -r`


In [2]:
# Automatically reload modules to get code changes without restarting kernel
# NOTE: Does not work for DJD table modules
%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings("ignore")

## Check what is already populated

In [None]:
SplineLNPParams()

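Optionally, drop what is already populated. Dropping the parameter table also drops its dependent tables (here `__spline_l_n_p` and `__spline_l_n_p__eval`), so use this with care: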

In [7]:
#SplineLNPParams().drop()

`dj_hmovmodels`.`spline_l_n_p_params` (48 tuples)
`dj_hmovmodels`.`__spline_l_n_p` (282 tuples)
`dj_hmovmodels`.`__spline_l_n_p__eval` (282 tuples)


Proceed? [yes, No]:  yes


Tables dropped.  Restart kernel.


In [3]:
SplineLNPParams()

| Attribute | Description |
|---|---|
| `spl_paramset` | parameter set ID |
| `spl_distr` | nonlinearity in LNP |
| `spl_alpha` | weighting betw. L2 and L1 penalty (alpha=1 only uses L1) |
| `spl_lambda` | regularization parameter of penalty term |
| `spl_lr` | initial learning rate for the JAX optimizer |
| `spl_max_iter` | maximum number of iterations for the solver |
| `spl_dt` | inverse of the sampling rate |
| `spl_spat_df` | degrees of freedom num of basis functions for spatial domain |
| `spl_temp_df` | degrees of freedom num of basis functions for temp component |
| `spl_pshf` | fit post-spike history filter |
| `spl_pshf_len` | length of the post-spike history filter |
| `spl_pshf_df` | number of basis functions for post-spike history filter |
| `spl_verb` | when verbose=n progress will be printed in every n steps |
| `spl_metric` | 'None', 'mse', 'r2', or 'corrcoef' |
| `spl_norm_y` | normalize observed responses |
| `spl_nlag` | number of time steps of the kernel |
| `spl_shift` | shift kernel to not predict itself |
| `spl_spat_scaling` | scaling factor for spatial resolution of movie |
| `spl_opto` | fit optogenetics filter |
| `spl_opto_len` | length of the opto filter (number of time steps) |
| `spl_opto_df` | number of basis functions for opto filter |
| `spl_run` | fit running filter |
| `spl_run_len` | length of the running filter (number of time steps) |
| `spl_run_df` | number of basis functions for running filter |
| `spl_eye` | fit eye filter |
| `spl_eye_len` | length of the eye filter (number of time steps) |
| `spl_eye_df` | number of basis functions for eye filter |
| `spl_eye_lpfilt` | lowpass filter eye data |
| `spl_eye_cutoff` | cutoff freq. for lowpass filter (set 0 if lpfilt=False) |

(empty)


## Generate parameter sets

Some parameters are fixed (e.g. learning rate, `nlag`), while others are varied in the grid search.

Varied parameters are the following:
* number of spline basis functions in the spatial dimension (frame height): `spl_spat_df`
* model configurations:
    * post-spike history filter: `spl_pshf`
    * opto stimulation as model input: `spl_opto`
    * running speed as model input: `spl_run`
    * pupil size as model input: `spl_eye`
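The grid is the Cartesian product of the varied parameters: each of the four optional filters is either on or off (2^4 = 16 model configurations), and each configuration is combined with 3 values of `spl_spat_df`, giving 48 parameter sets. A minimal sketch that enumerates this grid with `itertools.product`, assuming the same string flags and `spl_spat_df` values as the grid-search cell:

```python
from itertools import product

# Varied parameters, as in the grid-search cell (flags are stored as strings)
spat_df_grid = [4, 6, 8]
flag_grid = ['True', 'False']

# All on/off combinations of the four optional filters: pshf, opto, run, eye
model_configs = list(product(flag_grid, repeat=4))

# Full grid: every model configuration combined with every spatial df
grid = [(*config, spat_df) for config in model_configs for spat_df in spat_df_grid]

print(len(model_configs))  # 16
print(len(grid))           # 48
```

Because `spat_df` is the innermost loop, consecutive entries differ only in the number of spatial basis functions.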

In [11]:
# Start index for new paramsets
parameterset_idx = 500

The attribute `spl_paramset` starts at `parameterset_idx` and is incremented by one for every subsequent parameter combination.

Make sure that none of the resulting IDs is already populated in the table.
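Since IDs are assigned consecutively, the 48 new parameter sets occupy the range `parameterset_idx` to `parameterset_idx + 47`. A minimal sketch of this ID assignment together with an overlap check; `assign_ids` and `check_ids_free` are hypothetical helpers, and the existing IDs are a placeholder here. In the notebook they could be fetched from the table, e.g. with DataJoint's `SplineLNPParams().fetch('spl_paramset')`:

```python
from itertools import product

def assign_ids(start, *grids):
    """Map each grid combination to a consecutive parameter set ID, starting at `start`."""
    return {combo: start + i for i, combo in enumerate(product(*grids))}

def check_ids_free(new_ids, existing_ids):
    """Raise ValueError if any of the new IDs is already in use."""
    clash = sorted(set(new_ids) & set(existing_ids))
    if clash:
        raise ValueError(f'Parameter set IDs already in use: {clash}')

flags = ['True', 'False']
# Same loop order as the grid-search cell: pshf, opto, run, eye, spat_df
ids = assign_ids(500, flags, flags, flags, flags, [4, 6, 8])

# Hypothetical existing IDs; in the notebook, e.g.:
# existing = SplineLNPParams().fetch('spl_paramset')
existing = [100, 101]
check_ids_free(ids.values(), existing)
print(min(ids.values()), max(ids.values()))  # 500 547
```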

In [12]:
from itertools import product

# Define fixed parameters
stim = 'hmov'
distr = 'softplus'
alpha = 1.0
lr = 0.1
spl_lambda = 1.4
max_iter = 1500
temp_df = 7
verbose = 200
metric = 'corrcoef'
norm_y = 'False'
nlag = 20
shift = 1
spat_scaling = 0.06
data_fs = 60
eye_smooth = 'True'

# Length and number of basis functions of an enabled filter
# (post-spike history, opto, run, and eye filters all use the same values)
filt_len = 20
filt_df = 10

# Define ranges for grid search (flags are stored as the strings 'True'/'False')
spat_df_grid = [4, 6, 8]
pshf_grid = ['True', 'False']
opto_grid = ['True', 'False']
run_grid = ['True', 'False']
eye_grid = ['True', 'False']

def filter_config(flag):
    """Return (length, df) of a filter: (filt_len, filt_df) if enabled, else (0, 0)."""
    # Compare strings with ==, not `is` (identity of string objects is not guaranteed)
    return (filt_len, filt_df) if flag == 'True' else (0, 0)

# Loop over all combinations (same order as nested loops: spat_df varies fastest)
paramdicts_grid = []
for pshf, opto, run, eye, spat_df in product(pshf_grid, opto_grid, run_grid,
                                             eye_grid, spat_df_grid):
    # Post-spike history, opto, and behavior filter configuration
    pshf_len, pshf_df = filter_config(pshf)
    opto_len, opto_df = filter_config(opto)
    run_len, run_df = filter_config(run)
    eye_len, eye_df = filter_config(eye)
    # Generate parameter dict
    param_dict = SplineLNPParams().generate_paramset(
        paramseti=parameterset_idx,
        # Grid
        spat_df=spat_df,
        # Behavior config
        pshf=pshf, pshf_len=pshf_len, pshf_df=pshf_df,
        opto=opto, opto_len=opto_len, opto_df=opto_df,
        run=run, run_len=run_len, run_df=run_df,
        eye=eye, eye_len=eye_len, eye_df=eye_df,
        # Fixed params
        stim=stim, lr=lr, spl_lambda=spl_lambda, max_iter=max_iter,
        nlag=nlag, shift=shift, temp_df=temp_df, distr=distr, alpha=alpha,
        verbose=verbose, metric=metric, norm_y=norm_y,
        spat_scaling=spat_scaling, data_fs=data_fs, eye_smooth=eye_smooth,
    )
    paramdicts_grid.append(param_dict)
    parameterset_idx += 1

Check number of generated parameter sets:

In [13]:
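# 3 spatial df values x 2^4 filter configurations
len(spat_df_grid) * len(pshf_grid) * len(opto_grid) * len(run_grid) * len(eye_grid)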

48

In [14]:
len(paramdicts_grid)

48

Check the first and last parameter sets:

In [15]:
paramdicts_grid[0]

{'spl_paramset': 500,
 'spl_stim': 'hmov',
 'spl_distr': 'softplus',
 'spl_alpha': 1.0,
 'spl_lambda': 1.4,
 'spl_lr': 0.1,
 'spl_max_iter': 1500,
 'spl_spat_df': 4,
 'spl_temp_df': 7,
 'spl_pshf': 'True',
 'spl_pshf_len': 20,
 'spl_pshf_df': 10,
 'spl_verb': 200,
 'spl_metric': 'corrcoef',
 'spl_norm_y': 'False',
 'spl_nlag': 20,
 'spl_shift': 1,
 'spl_spat_scaling': 0.06,
 'spl_data_fs': 60,
 'spl_opto': 'True',
 'spl_opto_len': 20,
 'spl_opto_df': 10,
 'spl_run': 'True',
 'spl_run_len': 20,
 'spl_run_df': 10,
 'spl_eye': 'True',
 'spl_eye_len': 20,
 'spl_eye_df': 10,
 'spl_eye_smooth': 'True'}

In [16]:
paramdicts_grid[-1]

{'spl_paramset': 547,
 'spl_stim': 'hmov',
 'spl_distr': 'softplus',
 'spl_alpha': 1.0,
 'spl_lambda': 1.4,
 'spl_lr': 0.1,
 'spl_max_iter': 1500,
 'spl_spat_df': 8,
 'spl_temp_df': 7,
 'spl_pshf': 'False',
 'spl_pshf_len': 0,
 'spl_pshf_df': 0,
 'spl_verb': 200,
 'spl_metric': 'corrcoef',
 'spl_norm_y': 'False',
 'spl_nlag': 20,
 'spl_shift': 1,
 'spl_spat_scaling': 0.06,
 'spl_data_fs': 60,
 'spl_opto': 'False',
 'spl_opto_len': 0,
 'spl_opto_df': 0,
 'spl_run': 'False',
 'spl_run_len': 0,
 'spl_run_df': 0,
 'spl_eye': 'False',
 'spl_eye_len': 0,
 'spl_eye_df': 0,
 'spl_eye_smooth': 'True'}
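Beyond inspecting the first and last entries, all parameter sets can be checked at once. A minimal sketch, assuming the dictionary layout shown above: the hypothetical function `check_paramdicts` verifies that the IDs are consecutive and that every filter flagged `'True'` has a positive length and number of basis functions, while every filter flagged `'False'` has both set to 0.

```python
def check_paramdicts(paramdicts, filters=('pshf', 'opto', 'run', 'eye')):
    """Sanity-check SplineLNPParams dictionaries; raise ValueError on the first problem."""
    if not paramdicts:
        raise ValueError('No parameter sets given')
    ids = [d['spl_paramset'] for d in paramdicts]
    # IDs must be consecutive (which also makes them unique)
    if ids != list(range(ids[0], ids[0] + len(ids))):
        raise ValueError('Parameter set IDs are not consecutive')
    for d in paramdicts:
        for f in filters:
            enabled = d[f'spl_{f}'] == 'True'
            length, df = d[f'spl_{f}_len'], d[f'spl_{f}_df']
            # Enabled filters need a positive length and df; disabled ones must be 0
            if enabled and (length <= 0 or df <= 0):
                raise ValueError(f"paramset {d['spl_paramset']}: {f} enabled but length/df is 0")
            if not enabled and (length != 0 or df != 0):
                raise ValueError(f"paramset {d['spl_paramset']}: {f} disabled but length/df is nonzero")
    return True

# Usage in the notebook:
# check_paramdicts(paramdicts_grid)
```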

## Save as .json

In [9]:
import json
import os

import numpy as np

# Timestamp such as '2021-06-23_102743' to make the file name unique
date_str = np.datetime_as_string(np.datetime64('now')).replace(":", "").replace("T", "_")
filename = 'SplineLNPParams_grid_search_{:s}.json'.format(date_str)
# path = './'  # alternatively, save to the current directory
path = '/mnt/hux/mudata/djstore/hmov_paramsets'
full_path = os.path.join(path, filename)
with open(full_path, 'w') as file_out:
    json.dump(paramdicts_grid, file_out, indent=2)
print('JSON file saved {:s}'.format(full_path))

JSON file saved /mnt/hux/mudata/djstore/hmov_paramsets/SplineLNPParams_grid_search_2021-06-23_102743.json
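To make sure a saved file can be loaded back unchanged, a JSON round trip can be checked. Judging from the dictionaries shown above, all values are plain Python ints, floats, and strings (flags are the strings `'True'`/`'False'`), so no custom encoder is needed. A minimal sketch with hypothetical helpers `save_paramsets` and `load_paramsets`, writing to a temporary directory instead of the `djstore` path:

```python
import json
import os
import tempfile

def save_paramsets(paramdicts, full_path):
    """Write a list of parameter dictionaries to a JSON file."""
    with open(full_path, 'w') as file_out:
        json.dump(paramdicts, file_out, indent=2)

def load_paramsets(full_path):
    """Read a list of parameter dictionaries from a JSON file."""
    with open(full_path) as file_in:
        return json.load(file_in)

# Round trip with a small example list (subset of the attributes above)
example = [{'spl_paramset': 500, 'spl_spat_df': 4, 'spl_pshf': 'True', 'spl_lambda': 1.4}]
with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, 'SplineLNPParams_example.json')
    save_paramsets(example, path)
    loaded = load_paramsets(path)
print(loaded == example)  # True
```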


## Populate `SplineLNPParams()`

Set `filename` to the name of the .json file you want to load.

Here, we populate the parameter sets generated above, so we can reuse `filename` from the previous cell.

Check the directory `/mnt/hux/mudata/djstore/hmov_paramsets/` for other parameter sets. Note that the attributes of the parameter table may have changed since older parameter sets were saved; such files may need their attributes adjusted before they can be loaded.
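Adjusting an older file amounts to dropping keys the table no longer has and filling in attributes that were added since. A minimal sketch with a hypothetical helper `align_paramsets`; the attribute names in the example are placeholders. In the notebook, the current attribute names could be read from the table heading, e.g. with DataJoint's `SplineLNPParams().heading.names`:

```python
def align_paramsets(paramdicts, table_attrs, defaults=None):
    """Adapt saved parameter dictionaries to the current table attributes.

    Keys not in `table_attrs` are dropped; missing attributes are filled from
    `defaults`. Raises KeyError if an attribute is missing and has no default.
    """
    defaults = defaults or {}
    aligned = []
    for d in paramdicts:
        # Keep only attributes that still exist in the table
        new = {k: v for k, v in d.items() if k in table_attrs}
        # Fill attributes that the old file does not contain
        for attr in table_attrs:
            if attr not in new:
                if attr not in defaults:
                    raise KeyError(f"paramset {d.get('spl_paramset')}: no value for '{attr}'")
                new[attr] = defaults[attr]
        aligned.append(new)
    return aligned

# Hypothetical example: 'spl_old_attr' was removed, 'spl_new_attr' was added
old = [{'spl_paramset': 1, 'spl_spat_df': 6, 'spl_old_attr': 'x'}]
table_attrs = ['spl_paramset', 'spl_spat_df', 'spl_new_attr']
aligned = align_paramsets(old, table_attrs, defaults={'spl_new_attr': 0})
print(aligned)  # [{'spl_paramset': 1, 'spl_spat_df': 6, 'spl_new_attr': 0}]
```

Renamed attributes need an explicit mapping instead of a default, since their values should be carried over.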

In [10]:
filename

'SplineLNPParams_grid_search_2021-06-23_102743.json'

In [None]:
SplineLNPParams().populate_saved_paramset(filename)
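The implementation of `populate_saved_paramset` is not shown in this notebook. A plausible equivalent, written as the hypothetical helper `insert_saved_paramsets`, loads the list of dictionaries from the JSON file and inserts them with DataJoint's `insert`, using `skip_duplicates=True` so that parameter sets whose primary key already exists are skipped:

```python
import json
import os

def insert_saved_paramsets(table, filename, path='/mnt/hux/mudata/djstore/hmov_paramsets'):
    """Hypothetical stand-in for populate_saved_paramset: load a JSON file and insert its rows."""
    with open(os.path.join(path, filename)) as file_in:
        paramdicts = json.load(file_in)
    # skip_duplicates=True ignores rows whose primary key (spl_paramset) already exists
    table.insert(paramdicts, skip_duplicates=True)
    return len(paramdicts)

# Usage in the notebook:
# insert_saved_paramsets(SplineLNPParams(), filename)
```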

## Check if population was successful

In [None]:
SplineLNPParams()