In [None]:
import os
import params
from utils import *
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

## Cell 1: Load trigger times

In [None]:
# Experiment identifier and the two recordings to compare.
exp = params.exp
rec_before = '05_OptoStim1_20ND50%_1Hz.raw'  # CHANGE THE NAME OF THE RECORDINGS IF LAP4!!!
rec_after = '08_OptoStim2_18betaG_20ND10%_1Hz.raw'

# "Before" recording: load the trigger file, then divide the stored trigger
# indices by params.fs (presumably sample indices and a sampling rate in Hz,
# which would give onsets in seconds -- TODO confirm).
triggers_file_before = '{}_{}_triggers.pkl'.format(exp, rec_before)
triggers_path_before = os.path.normpath(os.path.join(params.triggers_directory, triggers_file_before))
trig_data_before = load_obj(triggers_path_before)
stim_onsets_before = trig_data_before['indices'] / params.fs

# "After" recording: same procedure.
triggers_file_after = '{}_{}_triggers.pkl'.format(exp, rec_after)
triggers_path_after = os.path.normpath(os.path.join(params.triggers_directory, triggers_file_after))
trig_data_after = load_obj(triggers_path_after)
stim_onsets_after = trig_data_after['indices'] / params.fs

## Cell 2: Load spike data

In [None]:
# Load the per-cell spike data of the whole experiment. The object is indexed
# as spike_trains[cell][recording_name] below.
output_directory = params.output_directory
neurons_data_path = os.path.join(output_directory, r'{}_fullexp_neurons_data.pkl'.format(exp))
spike_trains = load_obj(neurons_data_path)

cells = list(spike_trains.keys())

# One entry per cell, in the same order as `cells`, for each of the two recordings.
spike_times_before, spike_times_after = [], []
for cell in cells:
    cell_recordings = spike_trains[cell]
    spike_times_before.append(cell_recordings[rec_before])
    spike_times_after.append(cell_recordings[rec_after])

## Cell 3: Compute rasters and PSTHs

In [None]:
# Results, indexed as analyse[cell_id][condition] with condition in
# {'before', 'after'}. Each entry holds:
#   "spike_trains": list with one array per stimulus repetition, containing the
#                   spike times relative to the start of that repetition
#   "psth":         array of nb_bins firing rates (spikes per unit of time per
#                   repetition)
analyse = {}

stimulus_frequency = 1 # Change here
bin_size = 0.050 #s
nb_triggers_by_repetition = 4

# Duration of one repetition and the PSTH binning do not depend on the
# recording, so they are computed once. round() protects against floating-point
# truncation (e.g. int(0.3 / 0.1) == 2), and the rate is normalised by the bin
# width actually used by np.histogram, which equals bin_size whenever the
# repetition duration is a multiple of bin_size.
duration_repetition = nb_triggers_by_repetition / stimulus_frequency
nb_bins = int(round(duration_repetition / bin_size))
bin_width = duration_repetition / nb_bins

conditions = {
    'before': (stim_onsets_before, spike_times_before),
    'after': (stim_onsets_after, spike_times_after),
}

for recording, (onsets, spikes) in conditions.items():

    nb_triggers = len(onsets)

    # Only complete repetitions are analysed, so that the raster and the PSTH
    # are always built from the same set of trials.
    nb_repetitions = nb_triggers // nb_triggers_by_repetition
    sequence_first_indices = [rep * nb_triggers_by_repetition for rep in range(nb_repetitions)]

    # [start, stop] window of every repetition. A repetition normally stops at
    # the first trigger of the next one. The last repetition has no such
    # trigger: previously its window stopped at its own last trigger, which
    # dropped the final trigger period from the raster and the PSTH. It now
    # falls back to the nominal repetition duration.
    # The windows are identical for every cell, hence built outside the cell loop.
    repeated_sequences_times = []
    for first_idx in sequence_first_indices:
        start = onsets[first_idx]
        next_first_idx = first_idx + nb_triggers_by_repetition
        if next_first_idx < nb_triggers:
            stop = onsets[next_first_idx]
        else:
            stop = start + duration_repetition
        repeated_sequences_times.append([start, stop])

    # enumerate(tqdm(...)) rather than tqdm(enumerate(...)): tqdm can then read
    # len(cells) and display an actual progress bar.
    for cell_idx, cell_nb in enumerate(tqdm(cells)):

        if cell_nb not in analyse:
            analyse[cell_nb] = {}

        # Spike times of this cell; assumed to be a NumPy array expressed in the
        # same time base as `onsets` -- TODO confirm.
        SU_sptimes = spikes[cell_idx]

        # Cut the spike train into repetitions and align each one on the start
        # of its repetition. A dedicated name is used so that the `spike_trains`
        # dictionary loaded in Cell 2 is not overwritten.
        trial_spike_trains = []
        for start, stop in repeated_sequences_times:
            in_window = (SU_sptimes >= start) & (SU_sptimes <= stop)
            trial_spike_trains.append(SU_sptimes[in_window] - start)

        # Spike counts per bin for every repetition. Spikes falling outside
        # [0, duration_repetition] are ignored by np.histogram.
        binned_spikes = np.zeros((nb_repetitions, nb_bins))
        for rep, aligned_spikes in enumerate(trial_spike_trains):
            binned_spikes[rep, :] = np.histogram(aligned_spikes, bins=nb_bins, range=(0, duration_repetition))[0]

        # Sum over repetitions, then convert the counts into a firing rate.
        psth = binned_spikes.sum(axis=0) / nb_repetitions / bin_width

        analyse[cell_nb][recording] = {"spike_trains": trial_spike_trains, "psth": psth}

# `analyse` is a dict, so np.save pickles it into a 0-d object array (and
# appends '.npy' to the file name). Read it back with
# np.load(path, allow_pickle=True).item().
np.save(os.path.join(output_directory,'Opto_flash_data'), analyse)    #CHANGE THE NAME HERE FOR THE SAVE!!

## Cell 4: Plotting

In [None]:
# Folder receiving one PNG summary figure per cell.
fig_directory = os.path.normpath(os.path.join(output_directory,r'Opto_flash_figs'))
if not os.path.isdir(fig_directory):
    os.makedirs(fig_directory)

# Stimulus trace drawn in the top panel: a leading 0 followed by four blocks of
# nb_bins/4 samples each (1, 0, 1, 0). It is the same for every cell.
quarter_bins = int(nb_bins/4)
high_block = np.ones(quarter_bins)
low_block = np.zeros(quarter_bins)
stimulus_trace = np.concatenate(([0], high_block, low_block, high_block, low_block))

# (key in `analyse`, y-axis label, GridSpec rows) for the two raster panels.
raster_panels = (
    ('before', 'Before', slice(2, 5)),
    ('after', 'After', slice(5, 8)),
)

for cell in list(cells):
    fig = plt.figure(figsize=(10,10))
    gs = GridSpec(11, 6, figure=fig)

    # Panel 1: stimulus trace, x axis in PSTH bins.
    stim_ax = fig.add_subplot(gs[0:2,:])
    stim_ax.plot(stimulus_trace, '-', color='k', lw=3)
    stim_ax.set_ylabel('Stimulus', fontsize=25)
    stim_ax.set_xticks([])
    stim_ax.set_xlim([0,nb_bins])
    stim_ax.set_title('Cell {}'.format(cell), fontsize=28)

    # Panels 2 and 3: rasters of the aligned spike trains, x axis in the units
    # of the spike times, with a dashed line at every integer from 0 to 4.
    for condition, ylabel, rows in raster_panels:
        raster_ax = fig.add_subplot(gs[rows,:])
        raster_ax.eventplot(analyse[cell][condition]['spike_trains'])
        raster_ax.set_xlim([0,4])
        raster_ax.set_ylabel(ylabel, fontsize=25)
        raster_ax.set_xticks([])
        for tick in range(0,4+1,1):
            raster_ax.axvline(tick,ymin=0, ymax=1, ls='--',color='grey')

    # Panel 4: both PSTHs overlaid, x axis in bins relabelled from 0 to 4 in
    # steps of 0.5, with a dashed line every nb_bins/4 bins.
    psth_ax = fig.add_subplot(gs[8:11,:])
    psth_ax.plot(analyse[cell]['before']['psth'], label='Control')
    psth_ax.plot(analyse[cell]['after']['psth'], label='18BG')
    psth_ax.set_xlim([0,nb_bins])
    psth_ax.set_xticks(range(0,nb_bins+1,int(nb_bins/8)),np.arange(0,4.1,0.5))
    psth_ax.set_xlabel('Seconds')
    psth_ax.set_ylabel('Psths', fontsize=25)
    psth_ax.legend()
    for boundary in range(0,nb_bins+1,quarter_bins):
        psth_ax.axvline(boundary,ymin=0, ymax=1, ls='--',color='grey')

    # Remove the gaps between panels, write the figure to disk and release it.
    fig.subplots_adjust(wspace=0, hspace=0)
    fig_file = os.path.join(fig_directory,f'Cell_{cell}.png')
    fig.savefig(fig_file, dpi=fig.dpi)
    plt.close(fig)