### Currently, this notebook is used for:
- loading the altSfMix experiment data in Python format
- plotting responses

### Set up

In [3]:
import numpy as np
import matplotlib
matplotlib.use('TkAgg') # to avoid GUI/cluster issues...
import matplotlib.pyplot as plt
import matplotlib.backends.backend_pdf as pltSave
import autoreload

import sys # so that we can import model_responses (in different folder)
sys.path.insert(0, './functions/') # 0 as in 'relative to this directory'
import model_responses

% matplotlib inline 

plt.style.use('https://raw.githubusercontent.com/paul-levy/SF_diversity/master/Analysis/Functions/paul_plt_cluster.mplstyle');

# at CNS
# dataPath = '/arc/2.2/p1/plevy/SF_diversity/sfDiv-OriModel/sfDiv-python/altExp/recordings/';
# savePath = '/arc/2.2/p1/plevy/SF_diversity/sfDiv-OriModel/sfDiv-python/altExp/analysis/';
# personal mac
dataPath = '/Users/paulgerald/work/sfDiversity/sfDiv-OriModel/sfDiv-python/altExp/analysis/structures/';
save_loc = '/Users/paulgerald/work/sfDiversity/sfDiv-OriModel/sfDiv-python/altExp/analysis/figures/';

conDig = 3; # round contrast to the 3rd digit

which_cell = 4; # index into the dataList

dataList = np.load(dataPath + 'dataList.npy').item();

cellStruct = np.load(dataPath + dataList['unitName'][which_cell] + '_sfm.npy').item();

In [4]:
# Reload model_responses so edits to ./functions/model_responses.py are picked up without
# restarting the kernel; the returned module is displayed as the cell output.
autoreload.reload(model_responses)

<module 'model_responses' from './functions/model_responses.py'>

In [5]:
# Compute the normalization response with model_responses.GetNormResp; judging by the
# printed output, it also saves the result into the structures directory.
# NOTE(review): the argument is hard-coded to 0 rather than `which_cell`; presumably it is a
# unit index into dataList — confirm. `z` is not used by any visible cell.
z = model_responses.GetNormResp(0)

Computing normalization response for m670l14 ...
Saving, it seems. In /Users/paulgerald/work/sfDiversity/sfDiv-OriModel/sfDiv-python/altExp/analysis/structures/


In [10]:
# Reload the structure for the selected unit. Index with `which_cell` (not a hard-coded 0) so
# the data plotted below matches the cell_<which_cell+1>.pdf names the figures are saved under.
# allow_pickle=True is required by NumPy >= 1.16.3 for this pickled object; trusted files only.
cellStruct = np.load(dataPath + dataList['unitName'][which_cell] + '_sfm.npy', allow_pickle=True).item();

### Organize data
#### Determine contrasts, center spatial frequencies, and dispersions

In [None]:
# Per-trial experiment data; the fields used below are indexed trial by trial.
data = cellStruct['sfm']['exp']['trial'];

# Unique total contrasts (rounded to conDig digits), NaNs dropped; np.unique sorts ascending.
all_cons = np.unique(np.round(data['total_con'], conDig));
all_cons = all_cons[~np.isnan(all_cons)];

# Unique center spatial frequencies, NaNs dropped.
all_sfs = np.unique(data['cent_sf']);
all_sfs = all_sfs[~np.isnan(all_sfs)];

# Unique dispersion levels (number of components); only positive values are kept.
all_disps = np.unique(data['num_comps']);
all_disps = all_disps[all_disps>0]; # ignore zero...

nCons = len(all_cons);
nSfs = len(all_sfs);
nDisps = len(all_disps);

# np.diff returns one element fewer than all_cons, so its mask cannot index all_cons directly
# (that raises IndexError). Pad it: keep the lowest contrast plus every contrast more than 0.01
# above its predecessor, i.e. drop near-duplicate contrasts.
# NOTE(review): closest_cons is not used by any visible cell; confirm this is the intent.
con_diffs = np.diff(all_cons);
keep_con = np.ones(len(all_cons), dtype=bool);
keep_con[1:] = con_diffs > 0.01;
closest_cons = all_cons[keep_con];

#### Put responses into (dispersion × sf × contrast) arrays

In [None]:
# NaN-filled (dispersion x sf x contrast) arrays; untouched conditions stay NaN.
respMean = np.full((nDisps, nSfs, nCons), np.nan);
respVar = np.full((nDisps, nSfs, nCons), np.nan);

In [None]:
# Starting indices (dispersion, contrast, sf), handy for stepping through the loop below by hand.
d, con, sf = 0, 0, 0

In [None]:
# Mean and spread of spike counts for every (dispersion, sf, contrast) condition; conditions
# without trials stay NaN.
# NOTE: despite its name, respVar holds the standard deviation (np.std, ddof=0); the plots
# below use it directly as error bars.
respMean = np.full((nDisps, nSfs, nCons), np.nan);
respVar = np.full((nDisps, nSfs, nCons), np.nan);

# For each dispersion index, the contrast indices with at least one valid mean response.
val_con_by_disp = [];

# Rounding the contrasts is loop-invariant: do it once, not once per condition.
rounded_cons = np.round(data['total_con'], conDig);

for d in range(nDisps):
    val_con_by_disp.append([]);
    valid_disp = data['num_comps'] == all_disps[d];

    for con in range(nCons):
        valid_con = rounded_cons == all_cons[con];

        for sf in range(nSfs):
            valid_sf = data['cent_sf'] == all_sfs[sf];

            valid_tr = valid_disp & valid_sf & valid_con;

            if not np.any(valid_tr):
                continue; # no trials for this condition

            respMean[d, sf, con] = np.mean(data['spikeCount'][valid_tr]);
            respVar[d, sf, con] = np.std(data['spikeCount'][valid_tr]);

        # One non-NaN mean is enough; np.nanmean is then non-NaN as well, so no second check.
        if np.any(~np.isnan(respMean[d, :, con])):
            val_con_by_disp[d].append(con);

### Plots

#### Plots by dispersion

In [None]:
# One figure per dispersion, with one row per valid contrast (highest contrast in the top row).
# Each row shows mean response vs. spatial frequency with +/- std error bars (respVar).
# All rows in a figure share the y-range set by that dispersion's largest mean response.
fDisp = []; dispAx = [];

for d in range(nDisps):

    v_cons = val_con_by_disp[d];
    n_v_cons = len(v_cons);

    if n_v_cons == 0:
        # plt.subplots(0, 1) raises; keep fDisp/dispAx aligned with the dispersion index instead
        fDisp.append(None);
        dispAx.append(None);
        continue;

    # squeeze=False so a single valid contrast still yields an indexable column of axes
    fCurr, dispCurr = plt.subplots(n_v_cons, 1, figsize=(40, n_v_cons*10), squeeze=False);
    dispCurr = dispCurr[:, 0];
    fDisp.append(fCurr)
    dispAx.append(dispCurr);

    # n_v_cons > 0 guarantees at least one non-NaN mean for this dispersion
    maxResp = np.nanmax(respMean[d, :, :]);

    for c in reversed(range(n_v_cons)):
        c_plt_ind = n_v_cons - c - 1; # v_cons is ascending, so the highest contrast lands in row 0
        v_sfs = ~np.isnan(respMean[d, :, v_cons[c]]);
        ax = dispAx[d][c_plt_ind];

        ax.errorbar(all_sfs[v_sfs], respMean[d, v_sfs, v_cons[c]], respVar[d, v_sfs, v_cons[c]]);
        ax.set_xlim((min(all_sfs), max(all_sfs)));
        ax.set_ylim((0, 1.2*maxResp));

        ax.set_xscale('log');
        ax.set_xlabel('sf (c/deg)');
        ax.set_ylabel('resp (sps)');
        ax.set_title('D%d: contrast: %.3f' % (d+1, all_cons[v_cons[c]]));


saveName = "/cell_%d.pdf" % (which_cell+1)
full_save = os.path.dirname(str(save_loc + 'byDisp/'));
if not os.path.exists(full_save):
    os.makedirs(full_save)
# the context manager closes the PDF even if a savefig call raises
with pltSave.PdfPages(full_save + saveName) as pdfSv:
    for f in fDisp:
        if f is not None: # dispersions without valid contrasts have no figure
            pdfSv.savefig(f)


#### Plot only the sfMix contrasts (the highest, up to 4, per dispersion)

In [None]:
# i.e. highest (up to) 4 contrasts for each dispersion
# One figure: columns are dispersions, rows are the highest (up to) mixCons valid contrasts,
# with the highest contrast in the top row. Every panel shares the y-range set by the largest
# mean response over all conditions.

mixCons = 4;
maxResp = np.nanmax(respMean);

# squeeze=False keeps sfMixAx 2-D (row, dispersion) even when there is only one dispersion
f, sfMixAx = plt.subplots(mixCons, nDisps, figsize=(40, 30), squeeze=False);

for d in range(nDisps):
    # the highest (up to) mixCons valid contrast indices; all of them when there are fewer
    v_cons = np.array(val_con_by_disp[d], dtype=int)[-mixCons:];
    n_v_cons = len(v_cons);

    for c in reversed(range(n_v_cons)):
        c_plt_ind = n_v_cons - c - 1;
        ax = sfMixAx[c_plt_ind, d];
        ax.set_title('con:' + str(np.round(all_cons[v_cons[c]], 2)))
        v_sfs = ~np.isnan(respMean[d, :, v_cons[c]]);

        ax.errorbar(all_sfs[v_sfs], respMean[d, v_sfs, v_cons[c]], respVar[d, v_sfs, v_cons[c]]);
        ax.set_xlim((np.min(all_sfs), np.max(all_sfs)));
        ax.set_ylim((0, 1.2*maxResp));
        ax.set_xscale('log');
        ax.set_xlabel('sf (c/deg)');
        ax.set_ylabel('resp (sps)');

saveName = "/cell_%d.pdf" % (which_cell+1)
full_save = os.path.dirname(str(save_loc + 'sfMixOnly/'));
if not os.path.exists(full_save):
    os.makedirs(full_save)
# the context manager closes the PDF even if savefig raises
with pltSave.PdfPages(full_save + saveName) as pdfSv:
    pdfSv.savefig(f) # only one figure here...

#### Plot contrast response functions

In [None]:
# Mean responses across all sfs for dispersion index 3 at the fourth-highest contrast
# (all_cons is sorted ascending). NOTE(review): assumes >= 4 dispersions and >= 4 contrasts.
respMean[3][:, -4]

In [None]:
# Mean responses across all contrasts for dispersion index 3, at the second sf that has any
# valid response. The sf indices are computed here rather than reusing `v_sfs`, which only
# exists (with this meaning) after the CRF cell further down has run.
sfs_with_resp = np.where(np.sum(~np.isnan(respMean[3, :, :]), axis=1) > 0)[0];
respMean[3, sfs_with_resp[1], :]

In [None]:
# One figure per dispersion, with one panel per sf that has at least one valid response.
# Each panel shows mean response vs. contrast on log-log axes with +/- std error bars (respVar).
crfAx = []; fCRF = [];
for d in range(nDisps):

    # which sfs have at least one contrast presentation?
    v_sfs = np.where(np.sum(~np.isnan(respMean[d, :, :]), axis = 1) > 0);
    n_v_sfs = len(v_sfs[0])

    if n_v_sfs == 0:
        # plt.subplots(1, 0) raises; keep fCRF/crfAx aligned with the dispersion index instead
        fCRF.append(None);
        crfAx.append(None);
        continue;

    # squeeze=False so a single valid sf still yields an indexable row of axes
    fCurr, crfCurr = plt.subplots(1, n_v_sfs, figsize=(n_v_sfs*15, 20), sharex = True, sharey = True, squeeze=False);
    crfCurr = crfCurr[0, :];
    fCRF.append(fCurr)
    crfAx.append(crfCurr);

    for sf in range(n_v_sfs):
        sf_ind = v_sfs[0][sf];
        v_cons = ~np.isnan(respMean[d, sf_ind, :]);
        ax = crfAx[d][sf];

        # 0.1 minimum to keep plot axis range OK (zero responses cannot be drawn on the log
        # y-axis)...should find alternative
        ax.errorbar(all_cons[v_cons], np.maximum(respMean[d, sf_ind, v_cons], 0.1),
                    respVar[d, sf_ind, v_cons]);
        ax.set_xscale('log');
        ax.set_yscale('log');
        ax.set_xlabel('contrast');
        ax.set_ylabel('resp (sps)');
        ax.set_title('D%d: sf: %.3f' % (d+1, all_sfs[sf_ind]));

saveName = "/cell_%d.pdf" % (which_cell+1)
full_save = os.path.dirname(str(save_loc + 'CRF/'));
if not os.path.exists(full_save):
    os.makedirs(full_save)
# the context manager closes the PDF even if a savefig call raises
with pltSave.PdfPages(full_save + saveName) as pdfSv:
    for f in fCRF:
        if f is not None: # dispersions without any valid sf have no figure
            pdfSv.savefig(f)