In [1]:
import os
import params
from utils import *
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

## Cell 1: Load trigger times

In [3]:
exp = params.exp

# Control recording, taken before drug application.          /!\ CHANGE HERE /!\
# rec_before = '04_Flicker_BeforeDrugs_35ND10%_1Hz'
rec_before = '07_OptoStim1_25ND10%_1Hz'

# Recordings taken after drug application, in the order they should be compared.   /!\ CHANGE HERE /!\
# Triggers are loaded for every entry listed here.
# NOTE: check that the analysis/plotting cells below handle as many recordings as listed here.
recs_after = ['16_OptoStim2_TPMPA_15ND50%_1Hz',
              '20_OptoStim3_SR93351_15ND50%_1Hz',
              # '21_OptoStim3_SR93351_5ND50%_1Hz',
             ]


def _load_stim_onsets(rec_name):
    """Return the stimulus-onset times of recording `rec_name`.

    Reads `<params.triggers_directory>/<exp>_<rec_name>_triggers.pkl` and divides
    its 'indices' entry by `params.fs` (presumably sample indices -> seconds).
    """
    trig_path = os.path.normpath(os.path.join(params.triggers_directory,
                                              '{}_{}_triggers.pkl'.format(exp, rec_name)))
    trig_data = load_obj(trig_path)
    return trig_data['indices'] / params.fs


stim_onsets_before = _load_stim_onsets(rec_before)

nb_recs_after = len(recs_after)
if nb_recs_after == 0:
    raise ValueError('recs_after is empty: list at least one recording taken after the drug.')

# One onset array per post-drug recording, in the same order as `recs_after`.
stim_onsets_after = [_load_stim_onsets(rec) for rec in recs_after]


## Cell 2: Load spike data

In [4]:
output_directory = params.output_directory

# Whole-experiment spike data: spike_trains[cell][recording_name] -> spike times of `cell`
# in that recording (indexed this way below; the pickle's producer is not shown here).
spike_trains = load_obj(os.path.join(output_directory, r'{}_fullexp_neurons_data.pkl'.format(exp)))

cells = list(spike_trains.keys())

# Spike times of every cell (same order as `cells`) in the control recording.
spike_times_before = [spike_trains[cell][rec_before] for cell in cells]

# spike_times_after[k][cell_idx]: spike times of cells[cell_idx] in recs_after[k].
# One entry per post-drug recording, however many are listed.
spike_times_after = [[spike_trains[cell][rec] for cell in cells] for rec in recs_after]

## Cell 3: Compute rasters and PSTHs

In [7]:
################
#                /!\ CHANGE BELOW /!\
#                Change stimulus_frequency and bin_size (if needed)
#                Change the output file name at the bottom of the cell
################

stimulus_frequency = 1          # Hz   /!\ CHANGE HERE /!\
bin_size = 0.050                # s
nb_triggers_by_repetition = 4

# One repetition of the stimulus spans `nb_triggers_by_repetition` trigger periods.
duration_repetition = nb_triggers_by_repetition / stimulus_frequency   # s
# round() guards against float error (e.g. 0.3 / 0.1 == 2.9999999999999996).
nb_bins = int(round(duration_repetition / bin_size))

# Recording label -> (stimulus onsets, per-cell spike times).
# Labels are 'before', 'after', 'after2', 'after3', ... for any number of post-drug recordings.
recordings = {'before': (stim_onsets_before, spike_times_before)}
for k in range(nb_recs_after):
    label = 'after' if k == 0 else 'after{}'.format(k + 1)
    recordings[label] = (stim_onsets_after[k], spike_times_after[k])
rec_labels = list(recordings)


def _repetition_windows(onsets, n_triggers_per_rep, duration):
    """Return the (start, end) time of every complete stimulus repetition.

    A repetition starts at every `n_triggers_per_rep`-th onset and lasts `duration`.
    The window end is start + duration (not the next onset), so the last repetition
    keeps its final stimulus period. A trailing incomplete group of onsets is dropped,
    so rasters and PSTHs use the same repetitions.
    """
    nb_repetitions = len(onsets) // n_triggers_per_rep
    return [(onsets[k * n_triggers_per_rep], onsets[k * n_triggers_per_rep] + duration)
            for k in range(nb_repetitions)]


def _aligned_spike_trains(sptimes, windows):
    """Spikes inside each [start, end] window, shifted so the window starts at 0."""
    return [sptimes[(sptimes >= start) & (sptimes <= end)] - start for start, end in windows]


def _psth(aligned_trains, n_bins, duration, bin_width):
    """Trial-averaged firing rate: spike counts in `n_bins` bins over [0, duration],
    divided by the number of trials and by `bin_width`. NaN-filled when there are no trials."""
    if not aligned_trains:
        return np.full(n_bins, np.nan)
    counts = sum(np.histogram(train, bins=n_bins, range=(0, duration))[0]
                 for train in aligned_trains)
    return counts / len(aligned_trains) / bin_width


# analyse[cell][label] = {"spike_trains": aligned trains (one per repetition), "psth": rate array}
analyse = {}
for label, (onsets, spikes) in recordings.items():
    # Repetition windows depend only on the recording, not on the cell.
    windows = _repetition_windows(onsets, nb_triggers_by_repetition, duration_repetition)
    for cell_idx, cell_nb in enumerate(tqdm(cells)):
        aligned = _aligned_spike_trains(spikes[cell_idx], windows)
        analyse.setdefault(cell_nb, {})[label] = {
            "spike_trains": aligned,
            "psth": _psth(aligned, nb_bins, duration_repetition, bin_size),
        }

np.save(os.path.join(output_directory, 'OptoStim_control+TPMPA+SR_10^5R'), analyse)        #  /!\ CHANGE HERE /!\

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

## Cell 4: Plotting

In [8]:
################
#                /!\ CHANGE BELOW /!\
#                Change the figure directory name and the labels (if needed)
#                Stimulus timing (frequency, triggers per repetition, bin size) comes from the PSTH cell
################

fig_directory = os.path.normpath(os.path.join(output_directory, r'OptoStim_control+TPMPA+SR_10^5R_figs'))        #  /!\ CHANGE HERE /!\
os.makedirs(fig_directory, exist_ok=True)

# Raster y-label and PSTH legend entry for each recording label (falls back to the label itself).
raster_titles = {'before': 'Before', 'after': 'After', 'after2': 'After 2', 'after3': 'After 3'}
psth_labels = {'before': 'control_10^5R*',    # 'Control'
               'after': 'TPMPA_10^5R*',       # 'LAP4+ACET_t10' # OptoStim1_10^4R*
               'after2': 'SR93351_10^5R*',
               'after3': ''}                  #  /!\ CHANGE HERE /!\

# Timing of one stimulus repetition, in seconds.
trigger_period = 1 / stimulus_frequency
repetition_duration = nb_triggers_by_repetition * trigger_period
trigger_times = np.arange(nb_triggers_by_repetition + 1) * trigger_period

# Stimulus trace: starts at 0, then ON (1) on even triggers and OFF (0) on odd ones.
stim_levels = (np.arange(nb_triggers_by_repetition) % 2 == 0).astype(float)
stim_x = np.concatenate(([0.0], trigger_times))
stim_y = np.concatenate(([0.0], stim_levels, stim_levels[-1:]))


def _draw_trigger_lines(ax):
    """Dashed grey vertical line at every trigger of one repetition."""
    for t in trigger_times:
        ax.axvline(t, ymin=0, ymax=1, ls='--', color='grey')


for cell in tqdm(cells):
    labels = list(analyse[cell])       # 'before', 'after', ... in the order they were computed
    n_recs = len(labels)

    fig = plt.figure(figsize=(10, 10))
    # Grid rows: 2 for the stimulus, 3 per raster, 3 for the PSTH panel.
    gs = GridSpec(2 + 3 * n_recs + 3, 6, figure=fig)

    ax = fig.add_subplot(gs[0:2, :])
    ax.step(stim_x, stim_y, where='post', color='k', lw=3)
    ax.set_ylabel('Stimulus', fontsize=15)
    ax.set_xticks([])
    ax.set_xlim([0, repetition_duration])
    ax.set_title('Cell {}'.format(cell), fontsize=28)

    for row, label in enumerate(labels):
        ax = fig.add_subplot(gs[2 + 3 * row:5 + 3 * row, :])
        ax.eventplot(analyse[cell][label]['spike_trains'])
        ax.set_xlim([0, repetition_duration])
        ax.set_ylabel(raster_titles.get(label, label), fontsize=25 if label == 'before' else 20)
        ax.set_xticks([])
        _draw_trigger_lines(ax)

    ax = fig.add_subplot(gs[2 + 3 * n_recs:5 + 3 * n_recs, :])
    for label in labels:
        psth = analyse[cell][label]['psth']
        # Each PSTH value is drawn at the left edge of its bin, in seconds.
        ax.plot(np.arange(len(psth)) * bin_size, psth, label=psth_labels.get(label, label))
    ax.set_xlim([0, repetition_duration])
    ax.set_xticks(np.arange(0, repetition_duration + bin_size / 2, 0.5))
    ax.set_xlabel('Seconds')
    ax.set_ylabel('Psths', fontsize=25)
    ax.legend()
    _draw_trigger_lines(ax)

    fig.subplots_adjust(wspace=0, hspace=0)
    fig.savefig(os.path.join(fig_directory, f'Cell_{cell}.png'), dpi=fig.dpi)
    plt.close(fig)

  0%|          | 0/189 [00:00<?, ?it/s]

## Cell 5: Plotting PSTH

In [7]:
################
#                /!\ CHANGE BELOW /!\
#                Change the name of the data file and of the figure directory
#                Change the compared recording, the figure file-name suffix and the labels
################

data = np.load(os.path.join(output_directory, 'OptoStim_control+TPMPA+SR_10^5R.npy'), allow_pickle=True).item()

fig_directory = os.path.normpath(os.path.join(output_directory, r'OptoStim_control+TPMPA+SR_10^5R_PSTH'))        #  /!\ CHANGE HERE /!\
os.makedirs(fig_directory, exist_ok=True)

# Recording labels (keys saved by the PSTH cell) and their legend entries.
reference_label, reference_legend = 'before', "control_10^5R*"
compared_label, compared_legend = 'after2', "SR93351_10^5R*"
# File-name suffix: no '*', which is not allowed in Windows file names.
fig_suffix = 'OptoStim_control+SR_10^5R'

for cell in cells:
    recs = data[cell]
    for label in (reference_label, compared_label):
        if label not in recs:
            raise KeyError('Recording {!r} not found for cell {} (available: {})'
                           .format(label, cell, list(recs)))

    fig, ax = plt.subplots()
    ax.plot(recs[reference_label]['psth'], label=reference_legend)
    ax.plot(recs[compared_label]['psth'], label=compared_legend)
    ax.set_xlabel('Time bin')
    # The PSTH cell divides spike counts by the number of repetitions and by bin_size (s).
    ax.set_ylabel('Firing rate (spikes/s)')
    ax.set_title('Cell {}'.format(cell))
    ax.legend(loc='upper right')
    fig.savefig(os.path.join(fig_directory, f'Psth_c{cell}_{fig_suffix}.png'))
    plt.close(fig)