In [None]:
from matplotlib.pyplot import *
from numpy import *
import params
from utils import *
import itertools
from tqdm import tqdm

def compute_tuning(ch_raster, base_fire, seq_len, seq_sep, n_repeats=4, n_angles=8, nbins=1600):
    """Compute the direction tuning of one cell from its drifting-grating rasters.

    Parameters
    ----------
    ch_raster : list of array-like
        One entry per repetition. Each holds spike times (s) in which the spikes evoked
        by angle ``a`` were shifted by ``a*seq_sep``, so the angles sit side by side.
    base_fire : float
        Baseline firing rate (spikes/s) subtracted from every bin (0 disables it).
    seq_len : float
        Duration (s) of one grating presentation.
    seq_sep : float
        Spacing (s) between consecutive angles in ``ch_raster``.
    n_repeats : int, optional
        Number of repetitions merged in ``ch_raster``; only used to scale ``base_fire``.
    n_angles : int, optional
        Number of equally spaced directions (8 -> steps of 45 degrees).
    nbins : int, optional
        Total number of histogram bins over ``n_angles*seq_sep`` seconds
        (1600 = 8 angles * 20 s * 10 bins/s for the slow gratings).

    Returns
    -------
    TuneSum : ndarray, shape (n_angles + 1,)
        Summed response per angle, normalised to its maximum. The last entry repeats
        angle 0 so the curve closes on a polar plot.
    atune : float
        Direction (rad) of the vector sum of the responses.
    R : float
        Length of that vector sum, in units of the maximum response.
    IDX : float
        Direction selectivity index (pref - opposite) / (pref + opposite).
    counts : ndarray
        Baseline-subtracted spike counts per bin, all repetitions merged.
    maxcount : float
        Maximum of ``counts``.
    bins : ndarray
        Histogram bin edges (s).
    DG_data : dict
        Keys 'IDX', 'Tuning', 'atune', 'Rtune', 'rasters'.
    """
    # All spike times of all gratings in one list. Binning it once bins, for every angle,
    # the responses of all repetitions of that angle together.
    merged = list(itertools.chain(*ch_raster))

    total_dur = seq_sep * n_angles          # length (s) of the side-by-side rasters
    bins_per_sec = nbins / total_dur        # e.g. 10 bins/s for the slow gratings
    # Baseline rate (spikes/s) -> expected count per bin, summed over the merged repetitions.
    base_fire = base_fire * (total_dur / nbins) * n_repeats

    bins = np.linspace(0, total_dur, nbins + 1)
    counts, bins = np.histogram(merged, bins=bins)   # bin all repetitions at once
    counts = counts - base_fire
    maxcount = np.amax(counts)

    # counts covers n_angles*seq_sep seconds, but only the first seq_len seconds of each
    # seq_sep slot contain the stimulus (e.g. 12 of 20 s for the slow gratings).
    def _bin_index(t):
        """Index of the histogram bin that contains time t (s)."""
        # Indices are computed from times with the real bin width, so windows stay aligned
        # even when the bin width does not divide 1 s (45 ms bins for the fast gratings).
        # The epsilon keeps a time exactly on a bin edge from slipping into the previous
        # bin through float rounding.
        return int(np.floor(t * bins_per_sec + 1e-9))

    TuneSum = np.zeros(n_angles + 1)
    VxS = 0
    VyS = 0
    step_deg = 360 / n_angles
    for a in range(n_angles):
        # For each angle, sum the bins from seq_len/6 after onset (2 s for the slow
        # gratings) up to the grating offset.
        # NOTE(review): the reason for skipping the first sixth (onset transient?) is not
        # documented — confirm.
        first = _bin_index(a * seq_sep + seq_len / 6)
        last = _bin_index(a * seq_sep + seq_len)
        TuneSum[a] = np.sum(counts[first:last])
        VxS += np.cos(np.pi * a * step_deg / 180) * TuneSum[a]
        VyS += np.sin(np.pi * a * step_deg / 180) * TuneSum[a]
    TuneSum[n_angles] = TuneSum[0]   # close the polar curve

    peak = np.amax(TuneSum)
    # No response, or (with a baseline subtracted) no positive response: there is nothing
    # meaningful to normalise by.
    if np.sum(TuneSum) == 0 or peak <= 0:
        DG_data = {'IDX': 0, 'Tuning': TuneSum, 'atune': 0, 'Rtune': 0, 'rasters': ch_raster}
        return np.zeros(n_angles + 1), 0, 0, 0, counts, maxcount, bins, DG_data

    VxS = VxS / peak
    VyS = VyS / peak
    TuneSum = TuneSum / peak
    atune = np.arctan2(VyS, VxS)
    R = np.sqrt(VyS**2 + VxS**2)

    tuning = TuneSum[:-1]
    half = n_angles // 2

    def _dsi(idx):
        """(response - opposite response) / their sum, for direction index idx (any int)."""
        pref = tuning[idx % n_angles]
        opposite = tuning[(idx + half) % n_angles]
        return (pref - opposite) / (pref + opposite)

    # Direction closest to the vector sum.
    angle = int(np.round(atune / np.pi * half))
    IDX = _dsi(angle)
    # If that direction is not actually preferred over its opposite, look nearby:
    # first angle+1, then angle-1, and finally angle+2.
    if IDX < -0.2:
        angle += 1
        IDX = _dsi(angle)
    if IDX < -0.2:
        IDX = _dsi(angle - 2)
        if IDX < -0.2:
            angle += 1
            IDX = _dsi(angle)
        else:
            angle -= 2   # keep the neighbour that passed instead of discarding it

    DG_data = {'IDX': IDX, 'Tuning': TuneSum, 'atune': atune, 'Rtune': R, 'rasters': ch_raster}
    return TuneSum, atune, R, IDX, counts, maxcount, bins, DG_data


## Cell 1: Load trigger times and load the spike data into a list of arrays

In [3]:
# Load the DG trigger times and, for every cell, its spike times in the selected recording.
exp = params.exp

#select DG recording
recording_names = params.recording_names
# Set to an integer index (e.g. 2) to skip the interactive prompt so that
# "Restart & Run All" is reproducible; leave as None to be asked.
DG_REC_INDEX = None

print(*['{} : {}'.format(i,recording_name) for i, recording_name in enumerate(recording_names)], sep="\n")
if DG_REC_INDEX is None:
    rec_index = int(input("\nSelect DG recording : "))
else:
    rec_index = DG_REC_INDEX
# Reject out-of-range indices; a negative one would otherwise silently pick from the end.
if not 0 <= rec_index < len(recording_names):
    raise ValueError('Recording index must be between 0 and {}, got {}'.format(len(recording_names) - 1, rec_index))
rec = recording_names[rec_index]

# NOTE(review): load_obj comes from utils and presumably unpickles the file — only load trusted files.
#Load triggers
trig_data = load_obj(os.path.normpath(os.path.join(params.triggers_directory,'{}_{}_triggers.pkl'.format(exp,rec))))
#convert them in seconds (trigger sample indices / sampling rate)
stim_onsets = trig_data['indices']/params.fs


output_directory=params.output_directory
spike_trains=load_obj(os.path.join(output_directory, r'{}_fullexp_neurons_data.pkl'.format(exp)))

#Load spikes of the selected recording: spike_times[i] belongs to cells[i]
cells=list(spike_trains.keys())
spike_times=[spike_trains[cell][rec] for cell in cells]

0 : 00_Checkerboard_30ND50%_20pix30checks_30Hz
1 : 01_Checkerboard_30ND50%_16pix40checks_30Hz
2 : 02_DG_30ND50%_2sT_50Hz
3 : 03_Chirp_20reps_30ND50%_50Hz
4 : 04_Flicker_BeforeDrugs_30ND50%_1Hz
5 : 05_VDH_Synchro+MultiSpots(bright)_N8_Z(-35)_30ND50%_40Hz
6 : 06_VDH_Synchro_N10_Z(-35)_30ND50%_40Hz
7 : 07_Flicker_LAP4+ACET_t+10_30ND50%_1Hz
8 : 08_HoloStim1_LAP4+ACET_N8_Z(-35)
9 : 09_HoloStim1_LAP4+ACET_N15_Z(-30)
10 : 10_OptoStim1_LAP4+ACET_15ND50%_1Hz
11 : 11_OptoStim1_LAP4+ACET_5ND50%_1Hz
12 : 12_HoloStim2_GRF_t30_N15_Z(-30)
13 : 13_OptoStim2_GRF_t35_15ND50%_1Hz
14 : 14_OptoStim2_GRF_t40_5ND50%_1Hz
15 : 15_HoloStim3_SR95531_t30_N15_Z(-30)
16 : 16_OptoStim3_SR95531_t35_15ND50%_1Hz
17 : 17_OptoStim3_SR95531_t40_5ND50%_1Hz
18 : 18_HoloStim3_18BG_t5_N15_Z(-30)
19 : 19_OptoStim3_18BG_t10_15ND50%_1Hz
20 : 20_OptoStim3_18BG_t15_5ND50%_1Hz
21 : 21_HoloStim3_18BG_t20_N15_Z(-30)

Select DG recording : 2


## Cell 2: Make DG rasters

Initialization of the random (shuffled) angle sequences for the significance test. Uncomment the next cell only if you want to perform the significance test.

In [None]:
# Significance-test setup (disabled): draws shuffled angle orders (each angle 4 times)
# and keeps only those in which no position matches the real DG_seq. The resulting
# Ran_seq_test is what the commented-out shuffle test in the next cell indexes.
# NOTE(review): 100000 shuffles with a print per accepted sequence floods the output.
# # Make Random sequences with no repeats on the real ones for significance assessing
# DG_seq = [0, 1, 2, 3, 4, 5, 6, 7, 4, 1, 5, 2, 0, 3, 7, 6, 1, 4, 0, 3, 2, 5, 6, 7, 5, 2, 3, 6, 1, 4, 7, 0]
# count=0
# seq_base = list(np.arange(8))*4
# Ran_seq_test=[]
# for i in np.arange(100000):
#     a= random.shuffle(seq_base)
#     if np.sum(np.array(seq_base)==np.array(DG_seq))<1:
#         Ran_seq_test.append(np.copy(seq_base))
#         count+=1
#         print(count)

The stimulus is composed of 8 drifting gratings, each moving in a different direction, displayed at 50 Hz. The gratings can be shown at different speeds, and the time each grating stays on the screen depends on the speed: a slow grating sweeps for 12 seconds (medium 6 s, fast 3.96 s, per the `seq_len` values below). Each of the 8 gratings is repeated 4 times (32 presentations in total).

In [None]:
# Build the DG rasters for every cell, compute its direction tuning, optionally plot it,
# and save the per-cell results (DG_set).
fig_directory = os.path.normpath(os.path.join(output_directory,r'DG_figs'))
os.makedirs(fig_directory, exist_ok=True)

#select grating's speed
T = 2  # T=0 = fast, T=1 medium,  T=2 slow

################
# Per speed: (seq_len, seq_sep, label)
#   seq_len : how long (s) one grating sweeps on the screen
#   seq_sep : spacing (s) used to lay the 8 angles side by side in the rasters
DG_SPEEDS = {
    0: (3.96, 9, 'FAST'),
    1: (6, 10, 'MEDIUM'),
    2: (12, 20, 'SLOW'),
}
if T not in DG_SPEEDS:
    raise ValueError('T must be one of {}, got {!r}'.format(sorted(DG_SPEEDS), T))
seq_len, seq_sep, ttext = DG_SPEEDS[T]
# The stimulus runs at 50 Hz, so one grating spans 50*seq_len triggers. int() makes it
# usable as an index; round() guards against float error truncating it by one.
trigsinrep = int(round(50 * seq_len))
################

save_data = True
#do_test = False

PLOT = True     #whether to plot or not a figure for cell 
SKIP = False    #whether to run the code for all the cells or a single one
clus = 118      # if SKIP: for which cell the code should be run

#-------------------------------
# Testing
#N_shuffles = 1000
#seq_base = list(np.arange(8))*4

#-------------------------------

# DG angle sequence order
DG_seq = [0, 1, 2, 3, 4, 5, 6, 7, 4, 1, 5, 2, 0, 3, 7, 6, 1, 4, 0, 3, 2, 5, 6, 7, 5, 2, 3, 6, 1, 4, 7, 0]
DG_seq = (np.ones(32)*7-DG_seq).astype('int')  #(angles go counterclockwise in the stim)

n_angles=8
n_repeats=4

Test_set = {}
Tune_data = {}
DG_set={}
Nspikes_set = {}  #a dict that per each cells has the tot nb of spikes that the total stimulus evoked

if SKIP:
    matches = np.flatnonzero(np.array(cells) == clus)
    if matches.size == 0:
        raise ValueError('cell {} not found in cells'.format(clus))
    i0 = matches[0]
    iz = i0+1
else:
    i0=0
    iz=len(cells)

#-------------------------------
# Grating onsets depend only on the recording, not on the cell: compute them once.
nb_rep = len(stim_onsets)//trigsinrep    # nb_angles*n_repeats   (32)
if nb_rep > len(DG_seq):
    raise ValueError('{} gratings found in the triggers but DG_seq only lists {}'.format(nb_rep, len(DG_seq)))
dg_rep_starts = [stim_onsets[trigsinrep*n] for n in range(nb_rep)]  #the times at which each grating starts sweeping

#-------------------------------
for i in tqdm(np.arange(i0,iz)):

    clus=cells[i]
    dg_sptimes = spike_times[i]
    # TODO: baseline firing rate (spikes/s) to subtract in compute_tuning; not estimated yet (0 disables it).
    base_fire = 0

    #--------------------
    # Make rasters. ch_raster is a list with one entry per repetition; in each, the spikes
    # evoked by angle a are shifted by a*seq_sep so the 8 angles sit side by side.
    dg_count=np.zeros([n_angles],dtype='int')   # repetitions already seen, per angle
    ch_raster=[[] for _ in range(n_repeats)]
    # NOTE(review): the first n_angles gratings (the first repetition) are skipped — reason undocumented.
    for n in np.arange(n_angles,nb_rep):
        # spikes evoked by grating n: until the next grating starts, or, for the last one, until its offset
        if n == nb_rep-1:
            rep_end = dg_rep_starts[n]+seq_len
        else:
            rep_end = dg_rep_starts[n+1]
        rep_sptimes = dg_sptimes[(dg_rep_starts[n]<dg_sptimes)&(dg_sptimes<rep_end)]

        angle = DG_seq[n]
        ch_raster[dg_count[angle]] = np.append(ch_raster[dg_count[angle]],rep_sptimes-dg_rep_starts[n]+angle*seq_sep)
        dg_count[angle]+=1

    Nspikes = len(list(itertools.chain(*ch_raster)))
    if Nspikes == 0: continue  # no spikes evoked by any grating
    Nspikes_set.update({clus:Nspikes})

    # Pass the number of repetitions actually merged (the first one is skipped above),
    # so that a non-zero baseline would be scaled correctly.
    TuneSum, atune,R,IDX, counts, maxcount, bins , DG_data =compute_tuning(ch_raster,base_fire, seq_len, seq_sep, n_repeats=int(dg_count.max()))
    Tune_data.update({clus:[TuneSum,R,atune]})
    DG_set.update({clus:DG_data})
#     #--------------------------------------------------------------
#     # TEST SHUFFLES  (disabled; NOTE(review): the compute_tuning call below uses an
#     # outdated argument list and must be updated before re-enabling)
#     #--------------------
#     shift=0
#     if do_test:
#         Rt = np.zeros(N_shuffles)
#         It = np.zeros(N_shuffles)
#         for sh in np.arange(N_shuffles):
#             dg_count=np.zeros([n_angles],dtype='int')
#             # Get start times and make rasters
#             ch_raster_test=[]
#             for rep in np.arange(4):
#                 ch_raster_test.append([])
#             for n in np.arange(nb_rep):
#                 if n == nb_rep-1:
#                     rep_sptimes = dg_sptimes[(dg_rep_starts[n]<dg_sptimes)&(dg_sptimes<dg_rep_starts[n]+seq_len)]
#                 else:
#                     rep_sptimes = dg_sptimes[(dg_rep_starts[n]<dg_sptimes)&(dg_sptimes<dg_rep_starts[n+1])]
#                 ch_raster_test[dg_count[Ran_seq_test[sh+shift][n]]] = np.append(ch_raster_test[dg_count[Ran_seq_test[sh+shift][n]]],rep_sptimes-dg_rep_starts[n]+Ran_seq_test[sh+shift][n]*seq_sep)
#                 dg_count[Ran_seq_test[sh+shift][n]]+=1
#
#             #--------------------------------------------------------------
#             Tt,at,Rt[sh],It[sh],ct,mt,bt,DG_st,DG_dt = compute_tuning(ch_raster_test,base_fire,save_DG,Nspikes)
#             #--------------------------------------------------------------
#         test_set= dict({'Rt':Rt,'R':R,'It':It,'I':IDX})
#         Test_set.update({str(clus):test_set})
#     #------------------------
    # save rasters
#     if save_rasters:
#         saveF = rootF + r'\pckls_dg_times\c{}_dg_times'.format(clus)
#         np.save(saveF, ch_raster)
#     #------------------------
    #PLOT
    if PLOT:
        # Hide the cartesian spines. rc_context applies this from the very first figure
        # and does not leak the setting into later cells.
        with plt.rc_context({'axes.spines.bottom': False, 'axes.spines.left': False,
                             'axes.spines.right': False, 'axes.spines.top': False}):
            #--------plot the rasters-------------------
            fig = plt.figure(figsize=(12, 8))
            fig.suptitle('{}    cluster {}    Nspikes {}'.format(ttext, clus, Nspikes))
            gs = fig.add_gridspec(5, 8,
                          left=0.1, right=0.9, bottom=0.1, top=0.9,
                          wspace=0.3, hspace=0.7)

            ax = fig.add_subplot(gs[0:2, 0:8])
            ax.eventplot(ch_raster[:],color='k',lw=1,linelengths=0.95)
            for a in np.arange(n_angles):
                ax.axvline(a*seq_sep,color='gray',lw=2)                        # grating onset
                ax.axvline(a*seq_sep+seq_len,color='gray',lw=2)                # grating offset
                ax.axvline(a*seq_sep+seq_len/6,color='gray',ls='--',lw=1.5)    # start of the window used for tuning

            ax.set_xlim([-seq_sep/2,seq_sep*8])
            ax.set_ylim([-0.5,3.5+2+4+2])
            ax.set_yticks(np.arange(4))
            ax.set_ylabel('Repetition               Counts       ',size=10)
            ax.set_xlabel('Time (s) {8 angles}',size=10)
            ax.text(5,12,'0                      45                      90                    135                    180                   225                    270                   315')
            ax.axhline(3.5+2,color='k',lw=0.5)  # base_firing

            #--------------------------plot the histograms------------------------
            # Rescale the counts to a height of 4 above the rasters (zero line at y=5.5);
            # a flat line if the cell has no positive count.
            hist_scale = 4/maxcount if maxcount > 0 else 0
            counts_plot = counts*hist_scale+3.5+2
            ax.hist(bins[:-1], bins,histtype='step',lw=1.5,color='darkblue',weights = counts_plot)

            #--------------------------plot the polar plot left--------------
            ax = fig.add_subplot(gs[2:5, 1:4],polar=True)

            theta = np.linspace(0, 2 * np.pi, 9)   # the 8 angles plus angle 0 again, to close the curve
            # One angular grid line every 45 degrees, labelled in degrees
            ax.set_thetagrids(range(0, 360, int(360/8)),np.arange(0,360,45))

            # Normalised tuning curve
            ax.plot(theta, TuneSum)
            ax.fill(theta, TuneSum, 'b', alpha=0.1)

            # Vector sum: direction atune, length R
            ax.plot([atune,atune],[0,R],'b-')
            ax.plot([atune],[R],'bo')

            ax.set_yticks([0,0.25,0.5,0.75,1])
            ax.set_yticklabels([])
            ax.set_ylim([0,1])

            ax.text(np.pi*1/5,1.3,'IDX = '+str(np.round(IDX,1)),size=18)
            ax.text(np.pi*1/8,1.25,'R = '+str(np.round(R,1)),size=18)

            #---------------------------plot the polar plot right (same as left but not limited between 0 and 1)------
            ax = fig.add_subplot(gs[2:5, 5:8],polar=True)

            ax.plot([atune,atune],[0,R],'b-')
            ax.plot([atune],[R],'bo')
            ax.plot(theta, TuneSum)
            ax.fill(theta, TuneSum, 'b', alpha=0.1)

            ax.set_yticks([0,0.5,1,1.5,2])
            ax.set_yticklabels([0,'',1,'',2])

            #-----------------------------------------------------------------------------------
            fsave = os.path.join(fig_directory, 'DG_resp_exp{}_c{}'.format(exp, clus))

            fig.savefig(fsave+'.png',format='png',dpi=90)
        plt.close(fig)   # free the figure: one is created per cell
        #--------------------------------------------------------------

# save data
if save_data:
    savef = os.path.join(output_directory, 'DG_data_exp{}'.format(exp)) 
    save_obj(DG_set,savef)

print('Done!')

In [None]:
# Compact per-cell summary. The full DG_set also stores the raw rasters and tuning
# curves, which flood the output when displayed for every cell.
{cell: {key: DG_set[cell][key] for key in ('IDX', 'Rtune', 'atune')} for cell in DG_set}