In [None]:
import numpy as np
import pandas as pd
import os
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.metrics import roc_auc_score
import matplotlib as mpl
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pyplot as plt
import functools
import math as math
import seaborn as sns
from tqdm import tqdm
from scipy import stats as scpstats
import random
from scipy.signal import resample
from shapely.geometry import LineString
from matplotlib.ticker import FixedLocator, FixedFormatter,AutoMinorLocator
from matplotlib.patches import Rectangle
from scipy.interpolate import interp1d
import scipy.integrate as integrate
from sklearn.metrics import r2_score
from scipy.optimize import curve_fit
from shapely.geometry import LineString
# Grid sizes. extract_flow_velo re-declares NTIMES locally with the same
# value; NPOSES only appears in a commented-out expression in this file.
NTIMES = 101
NPOSES = 202

# Padding kept before (pre) and after (post) each detected stimulus window
# when plot_raws / plot_spikes cut out traces; expressed on the same axis as
# the stimulus 'Time' column.
winlen_pre = 0.2
winlen_post = 0
def Convert(string):
    """Parse the text form of a numeric array (e.g. ``"[0.1 0.2 0.3]"``) into a list.

    Brackets and commas are treated as separators and the remainder is split
    on any run of whitespace (spaces, newlines), so padded or multi-line
    array text parses correctly and an empty array (``"[]"``) yields ``[]``.

    Each token is parsed with ``ast.literal_eval`` instead of ``eval`` so that
    text read from a data file can never execute code; ints stay ints and
    floats stay floats.

    Raises:
        AttributeError: if ``string`` is not a str (e.g. a NaN cell).
        ValueError / SyntaxError: if a token is not a Python literal.
    """
    import ast  # local import: keeps this function self-contained

    cleaned = string.replace('[', ' ').replace(']', ' ').replace(',', ' ')
    return [ast.literal_eval(token) for token in cleaned.split()]

def extract_flow_velo(V,r):
    """Return the leading part of the stored curve for radius ``r`` and velocity ``V``.

    The table is looked up in the module-level ``Vitesses_dico`` under the key
    ``"R<r>_v<V>"``. A 101-sample time axis spanning ``[0, 6/V]`` is shifted by
    ``4/V`` and cut right after its first strictly positive sample.

    Returns:
        (times, rows): the shifted time samples up to and including the first
        positive one, and the table rows with labels ``0..idx`` (``.loc``
        slicing is label-inclusive, so both have ``idx + 1`` entries when the
        table has a default integer index).
    """
    n_samples = 101
    table = Vitesses_dico["R" + str(r) + "_v" + str(V)]  # look the key up in the dictionary
    shifted_time = np.linspace(0, 6/V, n_samples) - 4/V  # move the zero of the time axis
    last = np.argwhere(shifted_time > 0).min()
    return (shifted_time[:last+1], table.loc[:last])

def df_to_numpy(df):
    """Cast a pandas object to float, transpose it and return it as a NumPy array."""
    return df.astype(float).T.to_numpy()

def FiringRateBis(spiketrain, inf, sup):
    """Estimate a firing rate by summing one Gaussian kernel per spike.

    Args:
        spiketrain: iterable of spike times.
        inf, sup: bounds of the evaluation grid ``np.arange(inf, sup, 0.001)``
            (``sup`` excluded).

    Returns:
        (grid, rate): the evaluation grid and the summed kernel densities.
    """
    grid = np.arange(inf, sup, 0.001)
    rate = np.zeros_like(grid)
    for spike_time in spiketrain:
        # Gaussian kernel with a standard deviation (bandwidth) of 0.05, i.e. 50 ms
        rate += scpstats.norm.pdf(grid, loc=spike_time, scale=0.05)
    return (grid, rate)


def extract_clusters(df_final,clus):
    """Collect the rows of ``df_final`` belonging to the files of one cluster.

    The files are read from the module-level ``clusters`` table (rows whose
    ``'ID'`` equals ``clus``). Rows are appended file by file, in the order
    the files appear in ``clusters``, and the result gets a fresh integer index.
    """
    member_files = clusters.loc[clusters['ID'] == clus, 'file']
    gathered = pd.DataFrame()
    for fname in member_files:
        rows_of_file = df_final.loc[df_final['file'] == fname]
        gathered = pd.concat([gathered, rows_of_file], ignore_index=True)
    return gathered

def jitter_clust(CLUST):
    """Return a copy of ``CLUST`` with ``"v"`` shifted by a fixed amount per ``"R"`` value.

    Rows with R == 7 are shifted by -0.6, R == 15 by 0 and R == 30 by +0.6;
    rows with any other R are left untouched. The input frame is not modified.
    """
    offsets = [-0.6,0,0.6,0.6]  # only the first three entries are paired with a radius
    shifted = CLUST.copy()
    for radius, delta in zip([7,15,30], offsets):
        shifted.loc[shifted["R"]==radius, "v"] = shifted["v"].apply(lambda value: value + delta)
    return shifted
    
# Module-level list of four durations. plot_raws and plot_spikes each assign
# their own local `durat`, so this value is not read by them.
durat=[0.1,0.2,0.4,0.8]


def format_axes(ax,xlabel,ylabel, reffig, loclegend, pos='right',color='black'):
    """Apply the common panel styling: bold panel tag, hidden spines, 8-pt labels.

    Args:
        ax: matplotlib axes to style.
        xlabel, ylabel: axis labels (the y label is drawn in ``color``).
        reffig: sequence ``[x, y, text, fontsize]`` describing the bold tag,
            positioned in axes coordinates.
        loclegend: accepted for call compatibility; not used.
        pos: name of the spine hidden in addition to ``'top'``.
        color: colour of the y-axis label.
    """
    tag_x, tag_y, tag_text, tag_size = reffig[0], reffig[1], reffig[2], reffig[3]
    ax.text(tag_x, tag_y, tag_text, fontsize=tag_size, fontweight="bold", transform=ax.transAxes)

    for side in (pos, 'top'):
        ax.spines[side].set_visible(False)

    ax.set_xlabel(xlabel, fontsize=8)
    ax.set_ylabel(ylabel, fontsize=8, color=color)

    for which in ('x', 'y'):
        ax.tick_params(axis=which, labelsize=8)


def compare_clusters_lineplot(metric, clusters,Xval, Hval, r, ax,color ,selectR=True, regplot=True, alpha=0.1):
    """Draw a regression line plus jittered points of ``metric`` against ``Xval``.

    Args:
        metric: name of the y column.
        clusters: DataFrame holding the data.
        Xval: name of the x column.
        Hval: name of the column compared with ``r`` when ``selectR`` is true.
        r: value of ``Hval`` to keep.
        ax: matplotlib axes to draw on.
        color: colour of the regression line.
        selectR: when false, all rows of ``clusters`` are used.
        regplot: when false, nothing is drawn (only the fit is computed).
        alpha: accepted for call compatibility; not used.

    Returns:
        (slope, intercept) of a degree-1 least-squares fit of ``metric`` on
        ``Xval``, computed on the rows where both columns are present;
        ``(0, 0)`` when the fit cannot be computed.
    """
    clust = clusters.loc[clusters[Hval]==r] if selectR else clusters

    if regplot:
        # Regression line only (marker size 0); the points are drawn separately
        # so that they can be jittered horizontally.
        sns.regplot(x=Xval, y=metric, data=clust,ax=ax, color=color, scatter_kws={'s':0})
        sns.scatterplot(x=Xval, y=metric, data=jitter_clust(clust),ax=ax, marker='o',  s=4)

    ax.legend([],[], frameon=False)

    try:
        # Drop only the rows that miss one of the two fitted columns; a NaN in
        # an unrelated column must not remove an otherwise valid point.
        valid = clust.dropna(subset=[Xval, metric])
        slope, intercept = np.polyfit(valid[Xval], valid[metric], 1)
    except (KeyError, TypeError, ValueError, np.linalg.LinAlgError):
        # Best effort, as before: missing columns, empty or non-numeric data.
        slope, intercept = 0, 0
    return (slope, intercept)


def plot_raws(stimulus, raw , axs, color):
    """Overlay the raw trace around each detected stimulus on the matching axes.

    Stimulus windows are found as the points where the piston channel crosses
    the level 25. Consecutive crossings are paired (start, end); each pair's
    duration selects which of the four axes in ``axs`` receives the trace.

    Args:
        stimulus: table with the columns 'Time', '3 IN 4     ' and '2 IN 3     '.
        raw: table with the column '401 Memory' (the trace that is plotted).
        axs: sequence of axes indexed by the stimulus class (0..3).
        color: line colour.
    """
    R=raw['401 Memory']
    colors=[color, color, color, color]  # not used below
    piston = df_to_numpy(stimulus['3 IN 4     '])
    puff= df_to_numpy(stimulus['2 IN 3     '])  # not used below
    pistontime = df_to_numpy(stimulus['Time'])
    # Intersect the piston trace with the horizontal line y = 25.
    first_line = LineString(np.column_stack((pistontime, piston)))
    second_line = LineString(np.column_stack((pistontime, np.repeat(25,len(pistontime)))))
    intersection = first_line.intersection(second_line)

    L=[]
    for k in list(intersection.geoms):
        # The x coordinate is recovered from the WKT text by dropping its first
        # 7 and last 4 characters (presumably 'POINT (' and ' 25)' - confirm).
        # NOTE(review): this fails if a geometry is not a point.
        l=k.wkt
        l=l[7:]
        l=l[:-4]
        if float(l) >1:  # crossings at x <= 1 are ignored
            L+=[float(l)]
    fullL=sorted(L)
    durat=[]  # not used below
    D=[]
    # Even positions are paired with the following crossing (start, end);
    # the last four crossings are left out.
    for i in range(len(fullL)-4):
        if i%2==0:
            duration=(fullL[i+1]-fullL[i])
            dur=np.trunc(duration*10)/10  # truncated to one decimal
            D+=[dur]
    order=D
    comp=[0.1,0.2,0.4,0.8]
    j=0
    # Stimulus class from the duration: 0.1 -> 0, 0.2 -> 1, 0.4 -> 2, 0.8 -> 3.
    stim=[int(np.log2(k*10)) for k in order]
    for i in range(len(fullL)-4):
        if i%2==0:
            ID=stim[j]
            borneinf=fullL[i]-winlen_pre
            bornesup=fullL[i+1]+winlen_post
            # Times are turned into sample positions with a factor 10000
            # (presumably a 10 kHz recording - confirm).
            Rcopy=R[int(borneinf*10000):int(bornesup*10000)]
            Rcopy=(Rcopy+abs(min(Rcopy)))
            # Only windows whose shifted maximum exceeds 5 are drawn.
            if abs(max(Rcopy))>5:
                Rcopy=Rcopy/abs(max(Rcopy)) -0.05
                Rlinsp=np.linspace(-winlen_pre-comp[stim[j]],winlen_post,len(Rcopy))#np.linspace(0,bornesup-borneinf, len(Rcopy))
                axs[ID].plot(Rlinsp,Rcopy,color=color,linewidth=0.5,alpha=0.5)
            j+=1


def plot_spikes(stimulus, SPK , axs, color):
    """Draw the spikes around each detected stimulus as an event plot.

    Stimulus windows are found as the points where the piston channel crosses
    the level 25. Consecutive crossings are paired (start, end); each pair's
    duration selects which of the four axes in ``axs`` receives the spikes.

    Args:
        stimulus: table with the columns 'Time', '3 IN 4     ' and '2 IN 3     '.
        SPK: array of spike times supporting boolean-mask indexing.
        axs: sequence of axes indexed by the stimulus class (0..3).
        color: colour of the event lines.
    """
    colors=[color, color, color, color]
    piston = df_to_numpy(stimulus['3 IN 4     '])
    puff= df_to_numpy(stimulus['2 IN 3     '])  # not used below
    pistontime = df_to_numpy(stimulus['Time'])
    # Intersect the piston trace with the horizontal line y = 25.
    first_line = LineString(np.column_stack((pistontime, piston)))
    second_line = LineString(np.column_stack((pistontime, np.repeat(25,len(pistontime)))))
    intersection = first_line.intersection(second_line)
    L=[]
    for k in list(intersection.geoms):
        # The x coordinate is recovered from the WKT text by dropping its first
        # 7 and last 4 characters (presumably 'POINT (' and ' 25)' - confirm).
        # NOTE(review): this fails if a geometry is not a point.
        l=k.wkt
        l=l[7:]
        l=l[:-4]
        if float(l) >1:  # crossings at x <= 1 are ignored
            L+=[float(l)]
    fullL=sorted(L)
    durat=[]  # not used below
    D=[]
    # Even positions are paired with the following crossing (start, end);
    # the last four crossings are left out.
    for i in range(len(fullL)-4):
        if i%2==0:
            duration=(fullL[i+1]-fullL[i])
            dur=np.trunc(duration*10)/10  # truncated to one decimal
            D+=[dur]
    order=D
    j=0
    # Stimulus class from the duration: 0.1 -> 0, 0.2 -> 1, 0.4 -> 2, 0.8 -> 3.
    stim=[int(np.log2(k*10)) for k in order]
    for i in range(len(fullL)-4):
        if i%2==0:
            ID=stim[j]
            # Keep the spikes inside [start - winlen_pre, end + winlen_post] and
            # re-reference them to start + truncated duration.
            SPKcopy=SPK[SPK<fullL[i+1]+winlen_post]
            SPKcopy=SPKcopy[SPKcopy>fullL[i]-winlen_pre]-fullL[i] -order[j]
            axs[ID].eventplot(SPKcopy,color=colors[ID],linewidths=0.5)
            j+=1
            
def plot_traces(ax, stimulus, clust, label,color,showvelo,idxclus):
    """Draw four stacked panels (v = 40, 20, 10, 5) with stimulus curve and raw traces.

    Args:
        ax: host axes; it is hidden and four inset axes are stacked inside it.
        stimulus: stimulus table forwarded to ``plot_raws``.
        clust: raw-trace table forwarded to ``plot_raws``.
        label: bold panel tag written at the top left of the host axes.
        color: colour of the cluster title box and of the raw traces.
        showvelo: when equal to True, each panel is annotated with its velocity.
        idxclus: cluster number shown in the title box.
    """
    ax.set_axis_off()
    # Four insets of equal height, from top to bottom.
    axs = [ax.inset_axes([0, bottom, 1, 0.25]) for bottom in (0.75, 0.5, 0.25, 0.00)]

    for panel, speed in zip(axs, [40, 20, 10, 5]):
        times, curve = extract_flow_velo(speed, 15)
        panel.plot(times, curve*1e8/8, color='tab:grey')
        if showvelo==True:
            panel.text(0,0.2,'v='+str(speed)+'cm/s',fontsize=8, transform=panel.transAxes )

    for panel in axs:
        panel.set_ylim([0,2])
        panel.set_xlim([-0.9,0.1])
        for side in ('left', 'right', 'top'):
            panel.spines[side].set_visible(False)
        panel.tick_params(left = False, labelleft = False)

    top_panel, bottom_panel = axs[0], axs[3]
    top_panel.text(0.02, 0.95, label, fontsize=10, fontweight="bold", transform=ax.transAxes)
    bottom_panel.set_xlabel('time to collision', fontsize=8)
    bottom_panel.tick_params(axis='x', labelsize=8)
    top_panel.text(0.42,0.95,'Cluster '+str(idxclus),fontsize=10, fontweight="bold", transform=top_panel.transAxes,color=color, bbox=dict(facecolor='none', edgecolor=color))
    plot_raws(stimulus, clust, axs,color)
    
    
def transform_radius(R):
    """Map an electrophysiology radius label to the radius used in ``Vitesses_dico`` keys.

    7 -> 7.5, 30 -> 25, anything else -> 15.

    The previous version wrapped the result in ``int()``, which truncated 7.5
    to 7 and so produced a key ("R7_...") that does not exist. The integer
    cases still return ints so that ``str()`` yields "15" / "25".
    """
    if R==7:
        return 7.5
    if R==30:
        return 25
    return 15


    
def plot_timing_max_FR(subsetT, FRs, X, ax, offset, coloroffset, eps=0.01):
    """Draw the mean time of peak firing rate with an error bar at height ``offset``.

    Args:
        subsetT: only its length is used (number of rows of ``FRs`` to read).
        FRs: 2-D array, one firing-rate curve per row.
        X: time axis matching the columns of ``FRs``.
        ax: matplotlib axes to draw on.
        offset: vertical position of the bar.
        coloroffset: colour of the bar.
        eps: half-height of the three vertical ticks.
    """
    n_rows = len(subsetT)
    peak_times = [X[FRs[row, :].argmax()] for row in range(n_rows)]
    centre = np.mean(peak_times)
    # standard deviation divided by sqrt(n): standard error of the mean
    half_width = np.std(peak_times)/np.sqrt(n_rows)

    tick_low, tick_high = offset-eps, offset+eps
    for position in (centre, centre-half_width, centre+half_width):
        ax.vlines(position, tick_low, tick_high, color=coloroffset)
    ax.hlines(offset, centre-half_width, centre+half_width, color=coloroffset)
    

def plot_meanFR2(CLUS,v,r,V,ax,color,smalloffset):
    """Draw the mean peak-firing-rate time (with error bar) for one (v, R) condition.

    The firing-rate matrix is obtained from ``eventplot_FRs`` rather than from
    a second copy of the same loop, so both functions cannot drift apart.

    Args:
        CLUS: DataFrame with the columns 'v', 'R' and 'spikes'.
        v, r: velocity and radius selecting the rows.
        V: forwarded to ``eventplot_FRs`` (not used there).
        ax: matplotlib axes to draw on.
        color: colour of the bar.
        smalloffset: added to 0.05 to set the height of the bar.
    """
    # plot_timing_max_FR only needs the number of selected rows.
    subsetT = CLUS.loc[(CLUS['v']==v) & (CLUS['R']==r)]
    X, FRs = eventplot_FRs(CLUS, v, r, V)
    plot_timing_max_FR(subsetT, FRs, X, ax, 0.05+smalloffset, color, eps=0.01)

def eventplot_FRs(CLUS,v,r,V,ax=False,color=None,smalloffset=0):
    """Build the peak-normalised firing-rate curves of one (v, R) condition.

    Args:
        CLUS: DataFrame with the columns 'v', 'R' and 'spikes' (spike times
            stored as array text, parsed with ``Convert``).
        v, r: velocity and radius selecting the rows.
        V, ax, color, smalloffset: accepted for call compatibility; not used.

    Returns:
        (X, FRs): ``X`` is a time axis from ``-4/v`` to ``0.6`` with
        ``int(600 + (4/v)*1000)`` samples; ``FRs`` has one row per selected
        trial holding its firing rate divided by its maximum. Rows that cannot
        be computed (unparseable spikes, no spike at all) stay at zero.
    """
    subsetT = CLUS.loc[(CLUS['v']==v) & (CLUS['R']==r)].reset_index()
    n_samples = int(600+(4/v)*1000)
    FRs = np.zeros((len(subsetT), n_samples))
    for k in range(len(subsetT)):
        try:
            spikes = Convert(subsetT['spikes'].loc[k])
            rate = FiringRateBis(spikes, -4/v, 0.6)[1]
        except Exception:
            # Best effort, as before: the trial keeps an all-zero row. Unlike a
            # bare except, this lets KeyboardInterrupt / SystemExit through.
            continue
        # The kernel grid comes from np.arange and may differ from n_samples by
        # one sample for some v; align instead of silently dropping the trial.
        rate = rate[:n_samples]
        peak = rate.max() if len(rate) else 0
        if peak > 0:  # also false for NaN, so no 0/0 normalisation
            FRs[k, :len(rate)] = rate/peak
    X = np.linspace(-4/v, 0.6, n_samples)

    return (X, FRs)

In [None]:
# Build a dictionary of the linear momentum curves, keyed "R<radius>_v<velocity>".
# File names are listed verbatim: the two last R25 files use a lowercase "v".
Vitesses_dico = {
    key: pd.read_csv(path, header=None)
    for key, path in (
        ("R7.5_v5", "Volumes/2_Volume_R7.5_V5.csv"),
        ("R7.5_v10", "Volumes/2_Volume_R7.5_V10.csv"),
        ("R7.5_v20", "Volumes/2_Volume_R7.5_V20.csv"),
        ("R7.5_v40", "Volumes/2_Volume_R7.5_V40.csv"),
        ("R15_v5", "Volumes/2_Volume_R15_V5.csv"),
        ("R15_v10", "Volumes/2_Volume_R15_V10.csv"),
        ("R15_v20", "Volumes/2_Volume_R15_V20.csv"),
        ("R15_v40", "Volumes/2_Volume_R15_V40.csv"),
        ("R25_v5", "Volumes/2_Volume_R25_V5.csv"),
        ("R25_v10", "Volumes/2_Volume_R25_V10.csv"),
        ("R25_v20", "Volumes/2_Volume_R25_v20.csv"),
        ("R25_v40", "Volumes/2_Volume_R25_v40.csv"),
    )
}

# Per-trial data; the unnamed index column written by a previous export is dropped.
df_final = (
    pd.read_csv("data_final.csv")
    .drop('Unnamed: 0', axis=1)
    .reset_index(drop=True)
)

# Table giving the cluster 'ID' of each recording 'file'; one frame per cluster.
clusters = pd.read_csv('ACF_spikes_copy_2.csv')
T1, T2, T3 = (extract_clusters(df_final, cluster_id) for cluster_id in (1, 2, 3))

# Correction: this recording shows no response, hence a late timingFR; blank it out.
T1.loc[(T1["file"]=="2023_04_11_0014") & (T1["v"]==5) & (T1["R"]==30), 'timingFR'] = np.nan

# Select only data for v=5,10,20 or 40 (sometimes we tried v=25 or v=15 or v=50cm/s)
T1 = T1.loc[T1['v'].isin([5, 10,20,40])]
T2 = T2.loc[T2['v'].isin([5, 10,20,40])]
T3 = T3.loc[T3['v'].isin([5, 10,20,40])]


# 1. Plot D and v

In [None]:
# Example recordings, one per cluster: a plain-text array, the tab-separated
# valve (stimulus) table and the tab-separated raw trace.
clust1, clust2, clust3 = (
    np.loadtxt(path)
    for path in (
        "raw_texts_figures//2023_03_29_0007_2143.txt",
        "raw_texts_figures//2023_03_02_0036.txt",
        "raw_texts_figures//2023_03_17_0014_3412.txt",
    )
)

stimulus1, stimulus2, stimulus3 = (
    pd.read_csv(path, sep='\t')
    for path in (
        "raw_texts_figures//valves_2023_03_29_0007.txt",
        "raw_texts_figures//valves_2023_03_02_0036.txt",
        "raw_texts_figures//valves_2023_03_17_0014.txt",
    )
)

raw1, raw2, raw3 = (
    pd.read_csv(path, sep='\t')
    for path in (
        "raw_texts_figures//2023_03_29_0007_raw.txt",
        "raw_texts_figures//2023_03_02_0036_raw.txt",
        "raw_texts_figures//2023_03_17_0014_raw.txt",
    )
)

In [None]:
# Scratch cell: displays the parent-directory token used to build `fd` below.
os.pardir

In [None]:
# Scratch cell: displays the behavioural-data folder path.
# NOTE(review): `fd` is only assigned in the following cell, so this raises
# NameError when the notebook is run from top to bottom.
fd+'\\2-Behavioral_Data\\1_Influence_D_v\\'

In [None]:
# Folder two levels above the current working directory.
fd=os.path.abspath(os.path.join(os.getcwd(), os.pardir))
fd=os.path.abspath(os.path.join(fd, os.pardir))

# os.path.join instead of hard-coded '\\' so the path also resolves outside Windows.
df_final = pd.read_excel(os.path.join(fd, '2-Behavioral_Data', '1_Influence_D_v', 'Treadmill_Velocities.xlsx'))

# Create the column ID for each run of 6 consecutive rows
# (ex: 75-203412 means v=20-5-30-40-10-20 and D=7.5mm).
# A label slice is used so that a trailing incomplete run cannot append rows.
for k in range(0, len(df_final), 6):
    arr=list(map(int,np.floor(df_final['v'].iloc[k:k+6].to_numpy()/10).tolist()))
    T=str(df_final['R'].iloc[k])+'-'+''.join(str(a) for a in arr)
    df_final.loc[k:k+5,'ID']=T

# .copy() makes the filtered tables independent frames, so the column
# assignments below act on them directly (no chained-assignment warning).
df_final=df_final.loc[(df_final['keep'] == 1) ].copy()
condition=df_final[(df_final['tps_reaction_norm']<1) & (df_final['safe'] ==0)]
# If safe=0 but t<1 (in a few cases, there is a misalignment and the cricket is
# touched before the theoretical end of the stimulus), set the reaction time to 1.
df_final.loc[condition.index,'tps_reaction_norm']=1
df_final_new = df_final[df_final['frames_reac'] <= 600].copy() #threshold to remove data when the cricket does not react
df_final_new['treac2'] = df_final_new['tps_reaction_norm']
#df_final_new.to_csv("Treadmill_R3.csv") #Export data for analysis in R


#df_final : data without processing
df_final_N=df_final.loc[df_final['species']=='nemobius']
df_final_A=df_final.loc[df_final['species']=='acheta']

#df_final_new : data without the situations when the cricket does not move at all (more relevant to compare reaction times)
df_final_new_N=df_final_new.loc[df_final_new['species']=='nemobius']
df_final_new_A=df_final_new.loc[df_final_new['species']=='acheta']



def grouped_plot_2_BEH(metric, clusters,ax ,color, r,logistic):
    """Plot the regression of ``metric`` against ``"v"`` for the rows with ``R == r``.

    A seaborn regression line is always drawn (logistic when ``logistic`` is
    set). When ``logistic`` equals False the individual points are added
    through ``jitter_clust``.

    NOTE(review): ``jitter_clust`` only shifts rows with R in (7, 15, 30),
    while the callers in this file pass r = 75 / 150 / 250, so the points are
    presumably drawn without jitter - confirm whether that is intended.
    """
    subset = clusters.loc[clusters['R']==r]

    # Line only: the regplot markers are given size 0.
    sns.regplot(x="v", y=metric, data=subset,ax=ax, color=color, scatter_kws={'s':0},logistic=logistic)
    if logistic==False:
        sns.scatterplot(x="v", y=metric, data=jitter_clust(subset),ax=ax, marker='o', s=4)

    ax.legend([],[], frameon=False)


def plot_species(df_final, df_final_new, axs, args=('a. (','b. ('), shownumber=True):
    """Draw the two behavioural panels: escape success and normalised reaction time.

    Args:
        df_final: all kept trials (columns 'v', 'R', 'safe').
        df_final_new: trials with a reaction (columns 'v', 'R', 'treac2').
        axs: pair of axes ``(success_axes, reaction_time_axes)``.
        args: pair of panel tags. A trailing " (" in a tag is ignored, so both
            "a." and the historical "a. (" give the same title.
        shownumber: when true, the numbers of insects and tests follow the tag.

    The two titles used to be assembled differently (one added " (" to the
    tag, the other did not), giving "a. ( (12 insects" on one panel and
    "c4.12 insects" on the other; both now read "<tag> (<n> insects, <m> tests)".
    """
    # Normalise the tags once so both panels are formatted the same way.
    label1 = args[0].rstrip(' (')
    label2 = args[1].rstrip(' (')
    ax00, ax01 = axs

    shown = df_final.loc[df_final['R'] !=50]
    # NOTE(review): assumes 5 tests per insect - TODO confirm against the data layout.
    n_insects = int(len(shown)/5)

    # --- Escape success: logistic fit per sphere size plus the mean per velocity
    grouped_plot_2_BEH('safe', df_final, ax00, 'tab:blue', 75, logistic=True )
    grouped_plot_2_BEH('safe', df_final, ax00, 'tab:orange', 150, logistic=True )
    grouped_plot_2_BEH('safe', df_final, ax00, 'tab:green', 250, logistic=True )
    sns.pointplot(x="v", y='safe',  hue="R", data=shown, native_scale=True, markers='x',palette=['tab:blue','tab:orange','tab:green'],ax=ax00, join=False, errorbar=None)
    ax00.legend([],[], frameon=False)
    if shownumber:
        title1 = label1+' ('+str(n_insects)+' insects, '+str(len(shown))+' tests)'
    else:
        title1 = label1
    format_axes(ax00, 'Sphere velocity v (cm/s)', 'Escape success', [0.02,0.95,title1,10], None )

    # --- Normalised reaction time: linear fit per sphere size plus the points
    grouped_plot_2_BEH('treac2', df_final_new, ax01, 'tab:blue', 75, logistic=False )
    grouped_plot_2_BEH('treac2', df_final_new, ax01, 'tab:orange', 150, logistic=False )
    grouped_plot_2_BEH('treac2', df_final_new, ax01, 'tab:green', 250, logistic=False )

    ax01.set_ylim([0,2])
    ax01.set_xlim([2.5,42.5])
    if shownumber:
        title2 = label2+' ('+str(n_insects)+' insects, '+str(len(df_final_new))+' tests)'
    else:
        title2 = label2
    format_axes(ax01, 'Sphere velocity v (cm/s)', 'Normalized reaction time', [0.02,0.95,title2,10], None )

    # Legend entries: single-point lines carrying only colour and label.
    ax01.plot(0,0,label='D=7.5 mm',color='tab:blue')
    ax01.plot(0,0,label='D=15 mm',color='tab:orange')
    ax01.plot(0,0,label='D=25 mm',color='tab:green')
    ax01.legend(frameon=False, loc=4, prop={'size': 7})

    # The dashed line at 1 separates reactions before and after the collision.
    ax01.text(0.01,0.05,'before collision' ,rotation=90, transform=ax01.transAxes, fontsize=8)
    ax01.text(0.01,0.65,' after collision' ,rotation=90, transform=ax01.transAxes, fontsize=8)
    ax01.hlines(1, 2.5, 42.5, color='black',linestyle='dashed')

def plot_traces_behav(ax, stimulus, clust, label,color,showvelo, example_row=13):
    """Draw four stacked panels (v = 40, 20, 10, 5) with the stimulus curve and one reaction time.

    Args:
        ax: host axes; it is hidden and four inset axes are stacked inside it.
        stimulus, clust: accepted for call compatibility; not used.
        label: bold panel tag written at the top left of the host axes.
        color: colour of the title box and of the dashed reaction-time line.
        showvelo: when equal to True, each panel is annotated with its velocity.
        example_row: position, within the R == 250 trials of each velocity in
            the module-level ``df_final_new_A``, of the trial whose ID is shown
            (default 13, the value previously hard-coded).

    Raises:
        IndexError: if a velocity has fewer than ``example_row + 1`` trials.
    """
    ax.set_axis_off()
    # Four insets of equal height, from top to bottom.
    Iax1 = ax.inset_axes([0, 0.75, 1, 0.25])
    Iax2 = ax.inset_axes([0, 0.5, 1, 0.25])
    Iax3 = ax.inset_axes([0, 0.25, 1, 0.25])
    Iax4 = ax.inset_axes([0, 0.00, 1, 0.25])
    axs=[Iax1,Iax2,Iax3,Iax4]
    velocities=[40,20,10,5]

    for xax, velo in zip(axs, velocities):
        times, curve = extract_flow_velo(velo,15)  # looked up once per panel
        xax.plot(times, curve*1e8/8, color='tab:grey')
        if showvelo==True:
            xax.text(0,0.2,'v='+str(velo)+'cm/s',fontsize=8, transform=xax.transAxes )

    for xax in axs:
        xax.set_ylim([0,2])
        xax.set_xlim([-0.9,0.1])
        xax.spines['left'].set_visible(False)
        xax.spines['right'].set_visible(False)
        xax.spines['top'].set_visible(False)
        xax.tick_params(left = False, labelleft = False)

    axs[0].text(0.02, 0.95, label, fontsize=10, fontweight="bold", transform=ax.transAxes)
    axs[3].set_xlabel('time to collision', fontsize=8)
    axs[3].tick_params(axis='x', labelsize=8)
    axs[0].text(0.42,0.95,'Behavior ',fontsize=10, fontweight="bold", transform=axs[0].transAxes,color=color, bbox=dict(facecolor='none', edgecolor=color))

    for xax, velo in zip(axs, velocities):
        DFTEST=df_final_new_A.loc[(df_final_new_A["R"]==250) & (df_final_new_A["v"]==velo)]
        # Explicit positional access; Series[int] as a position is deprecated in pandas.
        IDTEST=DFTEST["ID"].iloc[example_row]
        same_id=DFTEST.loc[DFTEST['ID']==IDTEST]
        treac=same_id["tps_reaction"].values[0] - same_id["tps_total"].values[0]
        xax.vlines(treac,0,xax.get_ylim()[1], color=color, linestyle='dashed')


In [None]:
# Quick preview of the behaviour trace panel on its own figure.
# The stimulus and clust arguments are not used by plot_traces_behav, hence None.
fig,ax=plt.subplots()
plot_traces_behav(ax, None, None,'a4.','black',True)

In [None]:
# Draft of the figure: only the right-hand (behaviour) column is filled here.
fig,ax = plt.subplots(figsize=(16,12)) 
ax.axis('off')  # the base axes only serves as an invisible background

# 3 x 4 grid of panels; ax00..ax08 are created but left empty in this cell.
ax06=plt.subplot2grid((3,4),(0,0))
ax07=plt.subplot2grid((3,4),(0,1))
ax08=plt.subplot2grid((3,4),(0,2))
ax00=plt.subplot2grid((3,4),(1,0))
ax01=plt.subplot2grid((3,4),(1,1))
ax02=plt.subplot2grid((3,4),(1,2))
ax03=plt.subplot2grid((3,4),(2,0))
ax04=plt.subplot2grid((3,4),(2,1))
ax05=plt.subplot2grid((3,4),(2,2))

# Last column: behaviour panels.
ax09=plt.subplot2grid((3,4),(0,3))
ax10=plt.subplot2grid((3,4),(1,3))
ax11=plt.subplot2grid((3,4),(2,3))



plot_traces_behav(ax09, None, None,'a4.','black',True)
plot_species(df_final_A,df_final_new_A, [ax10, ax11], ['b4.', 'c4.'])

In [None]:
# Full figure: columns 1-3 show the three neuronal clusters, column 4 the behaviour.
# Row 1: example traces, row 2: max firing rate, row 3: peak timing.
fig,ax = plt.subplots(figsize=(12,12)) 
ax.axis('off')  # the base axes only serves as an invisible background

ax06=plt.subplot2grid((3,4),(0,0))
ax07=plt.subplot2grid((3,4),(0,1))
ax08=plt.subplot2grid((3,4),(0,2))
ax00=plt.subplot2grid((3,4),(1,0))
ax01=plt.subplot2grid((3,4),(1,1))
ax02=plt.subplot2grid((3,4),(1,2))
ax03=plt.subplot2grid((3,4),(2,0))
ax04=plt.subplot2grid((3,4),(2,1))
ax05=plt.subplot2grid((3,4),(2,2))

ax09=plt.subplot2grid((3,4),(0,3))
ax10=plt.subplot2grid((3,4),(1,3))
ax11=plt.subplot2grid((3,4),(2,3))

Xval="v"
Hval="R"
vals=[7,15,30]
clusts=[T1,T2,T3]
axs=[ax00,ax01,ax02,
    ax03,ax04,ax05] #

# One column per cluster: axs[idx] receives maxFR, axs[idx+3] receives timingFR.
# Each radius in vals is drawn in its own colour.
for idx,CLUST in enumerate(clusts):

    compare_clusters_lineplot("maxFR", CLUST,Xval,Hval,vals[0],axs[idx],'tab:blue')  
    compare_clusters_lineplot("maxFR", CLUST,Xval,Hval,vals[1],axs[idx],'tab:orange')
    compare_clusters_lineplot("maxFR", CLUST,Xval,Hval,vals[2],axs[idx],'tab:green')  

    compare_clusters_lineplot("timingFR", CLUST,Xval,Hval,vals[0],axs[idx+3],'tab:blue')  
    compare_clusters_lineplot("timingFR", CLUST,Xval,Hval,vals[1],axs[idx+3],'tab:orange')
    compare_clusters_lineplot("timingFR", CLUST,Xval,Hval,vals[2],axs[idx+3],'tab:green')
    
        
format_axes(ax00, 'Sphere velocity v (cm/s)', 'Max FR (s-1)', [0.02,0.95,'b1.',10], None )
format_axes(ax01, 'Sphere velocity v (cm/s)', 'Max FR (s-1)', [0.02,0.95,'b2.',10], None )
format_axes(ax02, 'Sphere velocity v (cm/s)', 'Max FR (s-1)', [0.02,0.95,'b3.',10], None )

ax00.set_ylim([0,20])
ax01.set_ylim([0,100])
ax02.set_ylim([0,160])
ax05.set_ylim([-0.3,0.1])

# Dotted line at 0 on the peak-timing panels.
minX=0
maxX=41
ax03.hlines(0,minX,maxX, color='black', linestyle='dotted')
ax04.hlines(0,minX,maxX, color='black', linestyle='dotted')
ax05.hlines(0,minX,maxX, color='black', linestyle='dotted')

ax03.text(0.05,0.05,'before collision' ,rotation=90, transform=ax03.transAxes, fontsize=8)
ax03.text(0.05,0.65,' after collision' ,rotation=90, transform=ax03.transAxes, fontsize=8)
ax04.text(0.05,0.05,'before collision' ,rotation=90, transform=ax04.transAxes, fontsize=8)
ax04.text(0.05,0.55,' after collision' ,rotation=90, transform=ax04.transAxes, fontsize=8)

format_axes(ax03, 'Sphere velocity v (cm/s)', 'Peak timing (s)', [0.02,0.95,'c1.',10], None )
format_axes(ax04, 'Sphere velocity v (cm/s)', 'Peak timing (s)', [0.02,0.95,'c2.',10], None )
format_axes(ax05, 'Sphere velocity v (cm/s)', 'Peak timing (s)', [0.02,0.95,'c3.',10], None )



# Top row: example raw traces of one recording per cluster.
plot_traces(ax06, stimulus1, raw1,'a1.','tab:red',True,1)
plot_traces(ax07, stimulus2, raw2,'a2.','tab:purple',False,2)
plot_traces(ax08, stimulus3, raw3,'a3.','tab:pink',False,3)


# Legend entries: single-point lines carrying only colour and label.
ax05.plot(0,0,label='D=7.5mm',color='tab:blue')
ax05.plot(0,0,label='D=15mm',color='tab:orange')
ax05.plot(0,0,label='D=25mm',color='tab:green')
ax05.legend(frameon=False, loc=4, prop={'size': 7})


ax00.set_xlim([4,41])
ax01.set_xlim([4,41])
ax02.set_xlim([4,41])

ax03.set_xlim([4,41])
ax04.set_xlim([4,41])
ax05.set_xlim([4,41])

ax03.set_ylim([-0.2,0.2])
ax04.set_ylim([-0.5,0.5])


# Right-hand column: behaviour panels, without the insect / test counts.
plot_traces_behav(ax09, None, None,'d1.','black',True)
plot_species(df_final_A,df_final_new_A, [ax10, ax11], ['d2.', 'd3.'], False)


fig.tight_layout()

plt.savefig("Plot_D_v_Beh_Neu.png", bbox_inches='tight', pad_inches=0) 


# 2. Behaviour + Electrophysiology

In [None]:
# Import the behavioural data: every sheet of the workbook is stacked into one
# table and only the rows flagged keep == 1 are retained. The workbook is read
# once (it used to be read and filtered twice to rebuild the same table).
df_final = pd.concat(pd.read_excel('Treadmill_clean.xlsx', sheet_name=None), ignore_index=True)
df_final = df_final.loc[(df_final['keep'] == 1) ].copy()

# df_final_new: trials with a reaction within 400 frames, excluding v == 100.
# One frame is converted to 0.002 time units. The explicit .copy() makes the
# column assignments below act on an independent frame.
df_final_new = df_final.assign(treac=df_final['frames_reac']*0.002)
df_final_new = df_final_new[df_final_new['frames_reac'] <= 400]
df_final_new = df_final_new[df_final_new['v']!= 100].copy()
df_final_new['dreac']=df_final_new['v']*df_final_new['treac']-4
df_final_new['treac2'] = df_final_new['treac'] - 4/df_final_new['v']
df_final_new['R_div_v']=df_final_new['R']*0.01/df_final_new['v']

# df_final keeps all kept trials; it carries 'temps_reac' and, as before, no 'treac'.
df_final['temps_reac']=df_final['frames_reac']*0.002

## 2.A. Means + SD

In [None]:
# Section 2.A figure: one panel per (radius, velocity) condition showing the
# stimulus curve, the mean peak-firing-rate time of each cluster and the mean
# behavioural reaction time, each with an error bar.
_,ax = plt.subplots(figsize=(14,10)) 
ax00=plt.subplot2grid((3,4),(0,0))
ax01=plt.subplot2grid((3,4),(0,1))
ax02=plt.subplot2grid((3,4),(0,2))
ax03=plt.subplot2grid((3,4),(0,3))
ax04=plt.subplot2grid((3,4),(1,0))
ax05=plt.subplot2grid((3,4),(1,1))
ax06=plt.subplot2grid((3,4),(1,2))
ax07=plt.subplot2grid((3,4),(1,3))
ax08=plt.subplot2grid((3,4),(2,0))
ax09=plt.subplot2grid((3,4),(2,1))
ax10=plt.subplot2grid((3,4),(2,2))
ax11=plt.subplot2grid((3,4),(2,3))


# Row-major list: one row per radius, one column per velocity.
axs=[ax00,ax01,ax02,ax03,
     ax04,ax05,ax06,ax07,
     ax08,ax09,ax10,ax11]
# Not used below.
cols=['tab:blue', 'tab:orange', 'tab:green', 'tab:red',
     'tab:blue', 'tab:orange', 'tab:green', 'tab:red',
     'tab:blue', 'tab:orange', 'tab:green', 'tab:red']

IDX=0

for r in [7,15,30]:
    for v in [5,10,20,40]:
        # Radius label of the cluster tables -> radius used in the Vitesses_dico
        # keys (and, times 10, in the behavioural 'R' column).
        if r==30:
            r_behav=25
        elif r==7:
            r_behav=7.5
        else:
            r_behav=15

        TIME,V= extract_flow_velo(v,r_behav)
        
        # Mean peak time of each cluster, stacked at heights 0.05 / 0.075 / 0.10.
        plot_meanFR2(T1,v,r,V,axs[IDX],'tab:red',0)
        plot_meanFR2(T2,v,r,V,axs[IDX],'tab:purple',0.025)
        plot_meanFR2(T3,v,r,V,axs[IDX],'tab:pink',0.05)
        
        # Behavioural reaction times below 1 for the same condition.
        T=df_final.loc[(df_final['v']==v) & (df_final['R']==r_behav*10) & (df_final['temps_reac']<1), 'temps_reac']
        #axs[IDX].plot(TIME,0.05*V/V.max(axis=0)[0],color='black') #cols[IDX]
        
        axs[IDX].plot(TIME,0.05*1e7*V,color='black')
        
        # Mean and standard error of the mean, shifted by -4/v like the time axis.
        MEANT=np.mean(T)
        STDT=np.std(T) /np.sqrt(len(T))
        offset=0.125
        axs[IDX].vlines(-4/v+ MEANT,offset-0.01,offset+0.01,color='grey')
        axs[IDX].vlines(-4/v+ MEANT-STDT,offset-0.01,offset+0.01,color='grey')
        axs[IDX].vlines(-4/v+ MEANT+STDT,offset-0.01,offset+0.01,color='grey')
        axs[IDX].hlines(offset, -4/v+ MEANT - STDT ,-4/v+ MEANT +STDT,color='grey')
        
        
        axs[IDX].set_xlim([-4/v,0.1])
        # This label is replaced by the format_axes calls at the end of the cell.
        axs[IDX].set_xlabel('Time (s) (R='+str(r_behav)+'V='+str(v)+')')
        
        
        
        # Labelled single-point lines; no legend is drawn from them in this cell.
        axs[IDX].plot(0,0, color='tab:red', label="Clust 1")
        axs[IDX].plot(0,0, color='tab:purple', label="Clust 2")
        axs[IDX].plot(0,0, color='tab:pink', label="Clust 3")
        axs[IDX].plot(0,0, color='grey', label="Behavior resp.")
        
        axs[IDX].set_ylim([0,0.16])
        #axs[IDX].set_aspect("equal", adjustable="datalim")
        IDX+=1
#axs[0].legend(loc='upper left')

# Row labels written on the first panel of each row (v = 5), in data coordinates.
axs[0].text(-0.8,0.050,'Clust 1',fontsize=8,color='tab:red')
axs[0].text(-0.8,0.075,'Clust 2',fontsize=8,color='tab:purple')
axs[0].text(-0.8,0.100,'Clust 3',fontsize=8,color='tab:pink')
axs[0].text(-0.8,0.125,'Behav',fontsize=8,color='grey')

axs[4].text(-0.8,0.050,'Clust 1',fontsize=8,color='tab:red')
axs[4].text(-0.8,0.075,'Clust 2',fontsize=8,color='tab:purple')
axs[4].text(-0.8,0.100,'Clust 3',fontsize=8,color='tab:pink')
axs[4].text(-0.8,0.125,'Behav',fontsize=8,color='grey')

axs[8].text(-0.8,0.050,'Clust 1',fontsize=8,color='tab:red')
axs[8].text(-0.8,0.075,'Clust 2',fontsize=8,color='tab:purple')
axs[8].text(-0.8,0.100,'Clust 3',fontsize=8,color='tab:pink')
axs[8].text(-0.8,0.125,'Behav',fontsize=8,color='grey')

# Panel tags; the y axis is hidden on every panel. Only the bottom row keeps
# an x label.
dico_lettres=['a1.','b1.','c1.','d1.',
             'a2.','b2.','c2.','d2.']
for idx,ax in enumerate([ax00,ax01,ax02,ax03,ax04,ax05,ax06,ax07]):
    format_axes(ax, '', '', [0.02,0.95,dico_lettres[idx],10], None )
    ax.spines['left'].set_visible(False)
    ax.axes.get_yaxis().set_visible(False)

dico_lettres_2=['a3.','b3.','c3.','d3.']
for idx,ax in enumerate([ax08, ax09, ax10, ax11]):
    format_axes(ax, 'Time (s)', '', [0.02,0.95,dico_lettres_2[idx],10], None )
    ax.spines['left'].set_visible(False)
    ax.axes.get_yaxis().set_visible(False)
    


## 2.B Figure Eventplot Mean

This figure shows the trial-to-trial jitter in the timing of the peak firing rate: each tick marks the peak time of one trial.

In [None]:
# Section 2.B figure: one panel per (radius, velocity) condition; each tick is
# the time of peak firing rate of one trial, coloured by cluster, drawn on top
# of the stimulus curve.
_,ax = plt.subplots(figsize=(14,10)) 
ax00=plt.subplot2grid((3,4),(0,0))
ax01=plt.subplot2grid((3,4),(0,1))
ax02=plt.subplot2grid((3,4),(0,2))
ax03=plt.subplot2grid((3,4),(0,3))
ax04=plt.subplot2grid((3,4),(1,0))
ax05=plt.subplot2grid((3,4),(1,1))
ax06=plt.subplot2grid((3,4),(1,2))
ax07=plt.subplot2grid((3,4),(1,3))
ax08=plt.subplot2grid((3,4),(2,0))
ax09=plt.subplot2grid((3,4),(2,1))
ax10=plt.subplot2grid((3,4),(2,2))
ax11=plt.subplot2grid((3,4),(2,3))

# Row-major list: one row per radius, one column per velocity.
axs=[ax00,ax01,ax02,ax03,
     ax04,ax05,ax06,ax07,
     ax08,ax09,ax10,ax11]
cols=['tab:blue', 'tab:orange', 'tab:green', 'tab:red',
     'tab:blue', 'tab:orange', 'tab:green', 'tab:red',
     'tab:blue', 'tab:orange', 'tab:green', 'tab:red']

IDX=0

# Loop invariants: the three clusters and their colours.
clusts=[T1,T2,T3]
colors=['tab:red', 'tab:purple', 'tab:pink']

for r in [7,15,30]:
    for v in [5,10,20,40]:
        # Radius label of the cluster tables -> radius used in the Vitesses_dico keys.
        if r==30:
            r_behav=25
        elif r==7:
            r_behav=7.5
        else:
            r_behav=15
        
        TIME,V= extract_flow_velo(v,r_behav)
        
        idx=0  # running row number: every trial gets its own line offset
        for clust, col in zip(clusts,colors):
            X,FRs=eventplot_FRs(clust,v,r,V)
            filters=FRs.argmax(axis=1)
            
            # A peak index of 0 marks a trial without a usable firing rate
            # (all-zero row); those are left out of the peak times.
            tpeaks= X[filters[filters !=0]]

            # One row per trial: valid peaks fill the first rows, the remaining
            # rows get a placeholder tick at -10, outside the displayed range.
            # The bound check replaces a bare `except:` used as control flow.
            for tk in range(len(filters)):
                position = tpeaks[tk] if tk < len(tpeaks) else -10
                axs[IDX].eventplot([position], lineoffsets=idx/10, linelengths=0.1, color=col)
                idx+=1
        
        axs[IDX].plot(TIME,0.05*1e8*V,color='black')
        
       
        offset=0.125
        axs[IDX].set_xlim([-0.8,0.1])
        # This label is replaced by the format_axes calls at the end of the cell.
        axs[IDX].set_xlabel('Time (s) (R='+str(r_behav)+'V='+str(v)+')')
        
        IDX+=1

# Panel tags; the y axis is hidden on every panel. Only the bottom row keeps
# an x label.
dico_lettres=['a1.','b1.','c1.','d1.',
             'a2.','b2.','c2.','d2.']
for idx,ax in enumerate([ax00,ax01,ax02,ax03,ax04,ax05,ax06,ax07]):
    format_axes(ax, '', '', [0.02,0.95,dico_lettres[idx],10], None )
    ax.spines['left'].set_visible(False)
    ax.axes.get_yaxis().set_visible(False)

dico_lettres_2=['a3.','b3.','c3.','d3.']
for idx,ax in enumerate([ax08, ax09, ax10, ax11]):
    format_axes(ax, 'Time (s)', '', [0.02,0.95,dico_lettres_2[idx],10], None )
    ax.spines['left'].set_visible(False)
    ax.axes.get_yaxis().set_visible(False)
    
