In [None]:
%pylab inline
import os
import sys
import numpy as np
import importlib

# Make the parent of the notebook's working directory importable so that the
# local `cnn_sys_ident` package (imported below) can be found.
# os.getcwd() replaces the shell capture `!pwd`, which depends on an external
# `pwd` executable and fails on platforms without one (e.g. Windows).
p = os.path.dirname(os.getcwd())
if p not in sys.path:
    sys.path.append(p)
    
from cnn_sys_ident.data import Dataset, MonkeyDataset
from cnn_sys_ident.cnnsysid import ConvNet

### Load Data

To train with all data, use `Dataset.get_clean_data()`. To use specific image types for training and testing, run the commented-out code instead and specify the training and/or testing types as lists in `Dataset.add_train_test_types()`. For example, `Dataset.add_train_test_types(data_dict, types_train=['conv1', 'conv2'], types_test=['conv4'])` trains on types conv1 and conv2 and tests on conv4.

In [None]:
# Option A: step-by-step pipeline. Uncomment these four lines (and comment out
# the wrapper call below) to choose which image types are used for training
# and testing, e.g. types_train=['conv1', 'conv2'], types_test=['conv4'].
# data_dict = Dataset.load_data()
# data_dict = Dataset.manage_repeats(data_dict)
# data_dict = Dataset.preprocess_nans(data_dict)
# data_dict = Dataset.add_train_test_types(data_dict, types_train='all', types_test='all')

# Option B (default): with a wrapper function that uses all data.
# NOTE(review): presumably equivalent to Option A with types_train='all' and
# types_test='all' — confirm against Dataset.get_clean_data.
data_dict = Dataset.get_clean_data()

In [None]:
# Wrap the cleaned data dictionary in a dataset object.
# NOTE(review): `seed` presumably fixes the random split and `train_frac` the
# share of data assigned to training — confirm in MonkeyDataset.
data = MonkeyDataset(
    data_dict,
    seed=1000,
    train_frac=0.8,
    subsample=2,
    crop=30,
)

### Define the Model

In [None]:
# Create the CNN wrapper. `obs_noise_model` is also used later as the label
# of the log-likelihood term printed during training.
model = ConvNet(
    data,
    log_dir='monkey',
    log_hash='cnn',
    obs_noise_model='poisson',
)

In [None]:
# NOTE(review): the label says "Log dir" but the value printed is model.log_hash.
print('Log dir: %s' % model.log_hash)

# Fetch the three splits; only the response arrays are needed here.
_, test_responses = data.test_av()
_, val_responses, real_val_resps = data.val()
_, tr_responses, real_tr_resps = data.train()

# Combine each split's responses via data.nanarray.
val_array = data.nanarray(real_val_resps, val_responses)
tr_array = data.nanarray(real_tr_resps, tr_responses)  # NOTE(review): unused below

# Per-column variance (NaN-aware), then averaged across columns.
val_avg_var = np.nanmean(np.nanvar(val_array, axis=0))
test_avg_var = np.nanmean(np.nanvar(test_responses, axis=0))
print('Average variances | validation set: %f | test set: %f' % (val_avg_var, test_avg_var))

### Build the Model

Let's create a 3-layer CNN with filters of sizes 13x13x32, 3x3x32, and 3x3x32.
Each convolutional layer has stride 1 and uses 'VALID', 'SAME', and 'SAME' padding, respectively.
Additionally, we impose a 2D smoothness penalty of 3e-4 on the filters of the first layer and an L1 penalty of 2.5e-4 on the filters of the 2nd and 3rd layers.
The readout has an L1 sparsity penalty of 2e-4, and the smoothing penalty on the output nonlinearity is set to 0 in this case (see the paper for details).

In [None]:
# Three conv layers: 13x13 then 3x3 then 3x3 kernels, 32 channels each.
# Layer 1 gets the smoothness penalty; layers 2 and 3 get the L1 penalty.
model.build(
    filter_sizes=[13, 3, 3],
    out_channels=[32, 32, 32],
    strides=[1, 1, 1],
    paddings=['VALID', 'SAME', 'SAME'],
    smooth_weights=[3e-4, 0, 0],
    sparse_weights=[0.0, 2.5e-4, 2.5e-4],
    readout_sparse_weight=2e-4,
    output_nonlin_smooth_weight=0,
)

### Train the model

In [None]:
# Train in rounds: model.train is called again after each round with the
# learning rate divided by LR_DECAY_FACTOR.
N_LR_ROUNDS = 3
LR_DECAY_FACTOR = 3

learning_rate = 0.001
for lr_decay in range(N_LR_ROUNDS):
    training = model.train(max_iter=10000, val_steps=100, save_steps=10000,
                           early_stopping_steps=10, batch_size=256,
                           learning_rate=learning_rate)
    for i, (logl, readout_sparse, conv_sparse, smooth, total_loss, pred) in training:
        result = model.eval()
        print('Step %d | Loss: %s | %s: %s | L1 readout: %s | L1 conv: %s | L2 conv: %s | Var(test): %.4f | Var(val): %.4f' %
              (i, total_loss, model.obs_noise_model, logl, readout_sparse, conv_sparse, smooth,
               np.mean(np.var(pred, axis=0)), np.mean(np.var(result[-1], axis=0))))

    # Decay (and announce it) only when another round will actually use the
    # new rate; previously the last round reported a reduction that never ran.
    if lr_decay < N_LR_ROUNDS - 1:
        learning_rate /= LR_DECAY_FACTOR
        print('Reducing learning rate to %f' % learning_rate)
print('Done fitting')

### Test Performance of the Model

In [None]:
# Evaluate on the test set; model.eve is read immediately afterwards.
model.performance_test()
eve = model.eve.mean()  # average over all entries of model.eve
print('Explainable variance explained on test set: {score}'.format(score=eve))

In [None]:
# Evaluate on the validation set; model.eve_val is read immediately afterwards.
model.performance_val()
eve_val = model.eve_val.mean()  # average over all entries of model.eve_val
print('Explainable variance explained on validation set: {score}'.format(score=eve_val))

In [None]:
# Average single-trial correlation between predictions and validation data.
avg_correlation_valset = model.evaluate_avg_corr_val()
print('Mean single trial correlation on validation set: {corr}'.format(corr=avg_correlation_valset))