In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import functools as ft
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
import statsmodels.api as sm
import seaborn as sns
from shapely.geometry import LineString
from matplotlib.patches import Circle



def grouped_plot(metric, clusters, ax):
    """Draw box plots of *metric* per velocity "v", split by the "R" column,
    with every individual observation overlaid as a faint black point.

    Parameters
    ----------
    metric : str
        Name of the column plotted on the y axis.
    clusters : pandas.DataFrame
        Data containing the columns "v", "R" and *metric*.
    ax : matplotlib Axes
        Axes to draw on. Its legend is hidden afterwards.
    """
    # Arguments shared by both seaborn layers.
    common = dict(x="v", y=metric, hue="R", data=clusters, ax=ax)
    box_palette = ['tab:blue', 'tab:orange', 'tab:green']

    sns.boxplot(palette=box_palette, **common)
    # Raw points, dodged so they sit over their own box.
    sns.stripplot(dodge=True, alpha=0.2, palette=['black'] * 3, **common)

    # Hide the legend seaborn builds from the hue levels.
    ax.legend([], [], frameon=False)
    
def intersect(line1, line2):
    """Return the y value of the first coordinate of the intersection of two polylines.

    Parameters
    ----------
    line1, line2 : sequence of (x, y) pairs
        Vertices passed straight to ``shapely.geometry.LineString``.

    Returns
    -------
    float
        ``intersection.xy[1][0]``, i.e. the y coordinate of the first point of
        the intersection geometry. NaN when that coordinate cannot be read
        (for example an empty intersection or a multi-part result).
    """
    first_line = LineString(line1)
    second_line = LineString(line2)
    try:
        crossing = first_line.intersection(second_line)
        vth = crossing.xy[1][0]
    except Exception:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.
        # Any failure to compute or read the intersection still yields NaN.
        vth = np.nan
    return vth

def format_axes(ax, xlabel, ylabel, reffig, loclegend):
    """Apply the common panel styling used throughout this notebook.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to style.
    xlabel, ylabel : str
        Axis labels, drawn at font size 8.
    reffig : sequence
        ``[x, y, text, fontsize]``: a bold panel tag placed at (x, y) in axes
        coordinates.
    loclegend :
        Currently unused (legend placement is left to the caller); kept so
        existing calls keep working.
    """
    x_pos, y_pos, tag, tag_size = reffig[0], reffig[1], reffig[2], reffig[3]
    ax.text(x_pos, y_pos, tag, fontsize=tag_size, fontweight="bold", transform=ax.transAxes)

    # Open "L"-shaped frame: hide the right and top spines.
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)

    for set_label, text in ((ax.set_xlabel, xlabel), (ax.set_ylabel, ylabel)):
        set_label(text, fontsize=8)
    ax.tick_params(axis='both', labelsize=8)
    
def jitter_clust(CLUST, offsets=None, column="v"):
    """Return a copy of *CLUST* with *column* shifted by an offset chosen per R value.

    Used to spread overlapping scatter points horizontally. By default, rows
    with R == 75 are shifted by -0.6, R == 150 by 0 and R == 250 by +0.6.
    Rows with any other R value are left unchanged.

    Parameters
    ----------
    CLUST : pandas.DataFrame
        Must contain an 'R' column and *column*. It is not modified.
    offsets : dict, optional
        Mapping ``R value -> offset``. Defaults to ``{75: -0.6, 150: 0.0, 250: 0.6}``.
    column : str, optional
        Name of the column to shift (default ``"v"``).

    Returns
    -------
    pandas.DataFrame
        The jittered copy. *column* is cast to float.
    """
    if offsets is None:
        offsets = {75: -0.6, 150: 0.0, 250: 0.6}
    CLUST2 = CLUST.copy()
    # Cast up front: writing fractional offsets into an integer column through
    # .loc is an incompatible-dtype setitem, which pandas has deprecated.
    CLUST2[column] = CLUST2[column].astype(float)
    for r, offset in offsets.items():
        # Shift only the matching rows. The former version applied a lambda to
        # the whole column for every R value.
        CLUST2.loc[CLUST2["R"] == r, column] += offset
    return CLUST2

def grouped_plot_2(metric, clusters, ax, color, r, logistic):
    """Fit *metric* against velocity "v" for the rows whose "R" equals *r*.

    A seaborn regression line (logistic or linear, depending on *logistic*)
    is drawn in *color* with its own scatter hidden. For the linear case
    (``logistic == False``) the raw points are added separately, horizontally
    jittered by ``jitter_clust``. The axes legend is hidden afterwards.
    """
    subset = clusters.loc[clusters['R'] == r]

    sns.regplot(x="v", y=metric, data=subset, ax=ax, color=color,
                scatter_kws={'s': 0}, logistic=logistic)

    # Equality (not truthiness) is intentional and matches the original test,
    # so e.g. logistic=None still skips the raw points.
    if logistic == False:  # noqa: E712
        # NOTE(review): no color is passed here, so the points use seaborn's
        # default color rather than *color* — confirm this is intended.
        sns.scatterplot(x="v", y=metric, data=jitter_clust(subset), ax=ax, marker='o', s=4)

    ax.legend([], [], frameon=False)
    
def plot_species(df_final, df_final_new, args=('a. (', 'b. (')):
    """Draw the two-panel velocity figure for one species.

    Left panel: for each R in (75, 150, 250), a logistic fit of 'safe' against
    sphere velocity 'v', plus the per-velocity mean as crosses (rows with
    R == 50 excluded).
    Right panel: for each R, a linear fit of 'treac2' against 'v' with
    jittered raw points. A dashed line at 1 separates reactions before and
    after collision.

    Parameters
    ----------
    df_final : pandas.DataFrame
        All kept tests. Uses the columns 'R', 'v' and 'safe'.
    df_final_new : pandas.DataFrame
        Tests kept for reaction-timing analysis. Uses 'R', 'v' and 'treac2'.
    args : sequence of two str
        Panel-tag prefixes. Each is completed with the insect and test counts
        and a closing parenthesis, e.g. ``"a. (12 insects, 60 tests)"``.
    """
    label1, label2 = args[0], args[1]

    # Exactly two axes. The former plt.subplots() + plt.subplot2grid() pair
    # left an extra, empty full-figure Axes underneath the two panels.
    _fig, (ax00, ax01) = plt.subplots(1, 2, figsize=(12, 4))

    # (R value, color, legend label) for each sphere size.
    size_styles = (
        (75, 'tab:blue', 'D=7.5mm'),
        (150, 'tab:orange', 'D=15mm'),
        (250, 'tab:green', 'D=25mm'),
    )

    kept = df_final.loc[df_final['R'] != 50]
    n_tests = len(kept)
    # NOTE(review): the insect count is tests // 5, as in the original
    # analysis — confirm the number of tests per insect.
    n_insects = n_tests // 5

    # --- Left panel: proportion of crickets that responded before collision.
    for r, color, _ in size_styles:
        grouped_plot_2('safe', df_final, ax00, color, r, logistic=True)
    sns.pointplot(x="v", y='safe', hue="R", data=kept, native_scale=True, markers='x',
                  palette=[color for _, color, _ in size_styles], ax=ax00,
                  join=False, errorbar=None)
    ax00.legend([], [], frameon=False)
    format_axes(ax00, 'Sphere velocity v (cm/s)',
                'Proportion of crickets that responded before collision',
                [0.02, 0.95, f"{label1}{n_insects} insects, {n_tests} tests)", 10], None)

    # --- Right panel: reaction timing.
    for r, color, _ in size_styles:
        grouped_plot_2('treac2', df_final_new, ax01, color, r, logistic=False)
    ax01.set_ylim([0, 2])
    ax01.set_xlim([2.5, 42.5])
    # NOTE(review): treac2 is compared with 1 (= collision) below, which suggests
    # a normalised, unitless time; the "(s)" unit in this label may be misleading.
    format_axes(ax01, 'Sphere velocity v (cm/s)', 'Timing of cricket reaction (s)',
                [0.02, 0.95, f"{label2}{n_insects} insects, {len(df_final_new)} tests)", 10], None)

    # Invisible lines at (0, 0), outside the x-limits, that only provide legend entries.
    for _, color, legend_label in size_styles:
        ax01.plot(0, 0, label=legend_label, color=color)
    ax01.legend(frameon=False, loc=4, prop={'size': 7})

    ax01.text(0.01, 0.05, 'before collision', rotation=90, transform=ax01.transAxes, fontsize=8)
    ax01.text(0.01, 0.65, ' after collision', rotation=90, transform=ax01.transAxes, fontsize=8)
    ax01.hlines(1, 2.5, 42.5, color='black', linestyle='dashed')
    
    
def grouped_plot_3(metric, clusters, ax, color, r, logistic, order=1):
    """Fit *metric* against velocity "v" for rows with ``R == r``, with an
    optional polynomial order for the non-logistic fit.

    - ``logistic == False``: a polynomial fit of degree *order* is drawn. Raw,
      jittered points are added only when ``order < 2``.
    - otherwise: a seaborn regplot with ``logistic=logistic``; *order* is not
      forwarded.

    The axes legend is hidden afterwards.
    """
    subset = clusters.loc[clusters['R'] == r]
    # Equality (not truthiness) matches the original branch condition.
    linear_fit = logistic == False  # noqa: E712

    reg_kwargs = {
        "x": "v",
        "y": metric,
        "data": subset,
        "ax": ax,
        "color": color,
        "scatter_kws": {'s': 0},
        "logistic": logistic,
    }
    if linear_fit:
        # order is only forwarded for the non-logistic fit (seaborn treats a
        # polynomial order > 1 and a logistic fit as mutually exclusive).
        reg_kwargs["order"] = order
    sns.regplot(**reg_kwargs)

    if linear_fit and order < 2:
        sns.scatterplot(x="v", y=metric, data=jitter_clust(subset), ax=ax, marker='o', s=4)

    ax.legend([], [], frameon=False)
    
    
def grouped_plot_R1(metric, clusters, ax, color, v, logistic):
    """Fit *metric* against "R" for the rows whose velocity "v" equals *v*.

    A seaborn regression line (logistic or linear, depending on *logistic*)
    is drawn in *color* with its own scatter hidden. For the linear case
    (``logistic == False``) the raw points are overlaid in the same color.
    The axes legend is hidden afterwards.
    """
    at_speed = clusters.loc[clusters['v'] == v]

    sns.regplot(x="R", y=metric, data=at_speed, ax=ax, color=color,
                scatter_kws={'s': 0}, logistic=logistic)

    if logistic == False:  # noqa: E712  (equality kept from the original)
        # NOTE(review): jitter_clust shifts the 'v' column (the hue here), not
        # 'R' (the x axis), and only for R in {75, 150, 250}. It therefore does
        # not spread points along x — confirm whether jittering R was intended.
        sns.scatterplot(x="R", y=metric, hue="v", data=jitter_clust(at_speed), ax=ax,
                        marker='o', palette=[color], s=4)

    ax.legend([], [], frameon=False)

# 1. Treadmill Data

## 1.A. Various velocities

In [None]:
df_final = pd.read_excel('Treadmill_Velocities.xlsx')

# Build the ID column. Rows come in consecutive blocks of 6 tests, and every
# row of a block gets "<R>-<d1..d6>" with d_i = floor(v_i / 10)
# (e.g. "75-203412" means R=75 and v = 20, 5, 30, 40, 10, 20).
# The whole column is assigned at once. The old per-row `.loc[k + i, 'ID']`
# writes appended spurious NaN rows whenever the row count was not a multiple of 6.
BLOCK_SIZE = 6
block_ids = []
for start in range(0, len(df_final), BLOCK_SIZE):
    block = df_final.iloc[start:start + BLOCK_SIZE]
    digits = ''.join(str(int(d)) for d in np.floor(block['v'].to_numpy() / 10))
    block_ids.extend([str(block['R'].iloc[0]) + '-' + digits] * len(block))
df_final['ID'] = block_ids

# Keep only tests flagged as usable. Take an explicit copy so the in-place fix
# below does not raise a chained-assignment warning.
df_final = df_final.loc[df_final['keep'] == 1].copy()

# If safe == 0 but the normalised reaction time is < 1, set the reaction time
# to 1. In a few tests a misalignment lets the sphere touch the cricket before
# the theoretical end of the stimulus.
touched_early = (df_final['tps_reaction_norm'] < 1) & (df_final['safe'] == 0)
df_final.loc[touched_early, 'tps_reaction_norm'] = 1

# Threshold (600 frames) removing tests where the cricket does not react.
# .copy() makes df_final_new independent, so adding 'treac2' is safe.
df_final_new = df_final[df_final['frames_reac'] <= 600].copy()
df_final_new['treac2'] = df_final_new['tps_reaction_norm']
# df_final_new.to_csv("Treadmill_R3.csv")  # uncomment to export the data for analysis in R


# df_final: all kept tests, split by species.
df_final_N = df_final.loc[df_final['species'] == 'nemobius']
df_final_A = df_final.loc[df_final['species'] == 'acheta']

# df_final_new: excludes tests where the cricket does not move at all
# (more relevant when comparing reaction times), split by species.
df_final_new_N = df_final_new.loc[df_final_new['species'] == 'nemobius']
df_final_new_A = df_final_new.loc[df_final_new['species'] == 'acheta']

In [None]:
# One two-panel figure per species: Acheta first, then Nemobius.
for raw_tests, timing_tests, panel_tags in (
    (df_final_A, df_final_new_A, ['a. (', 'b. (']),
    (df_final_N, df_final_new_N, ['b1. (', 'b2. (']),
):
    plot_species(raw_tests, timing_tests, panel_tags)

### Histogram of reaction times

# Overlaid reaction-time distributions for the two species. Both histograms share
# one set of bin edges (the old calls binned each species independently, so the
# bars were not comparable) and are drawn on a fresh Axes rather than whatever
# Axes happened to be current.
fig, ax = plt.subplots()
reac_N = df_final_new_N['frames_reac'].to_numpy()
reac_A = df_final_new_A['frames_reac'].to_numpy()
shared_bins = np.histogram_bin_edges(np.concatenate([reac_N, reac_A]), bins=10)
ax.hist(reac_N, bins=shared_bins, alpha=0.5, label='nemobius')
ax.hist(reac_A, bins=shared_bins, alpha=0.5, label='acheta')
ax.set_xlabel('Reaction time (frames)')
ax.set_ylabel('Number of tests')
ax.legend(frameon=False)

## 1.B. Various diameters

In [None]:
# Stack every sheet of the workbook into a single table.
compare_sizes = pd.concat(pd.read_excel('Treadmill_Sizes.xlsx', sheet_name=None), ignore_index=True)
# Keep only usable tests. The explicit copy avoids a chained-assignment warning
# when adding 'treac' next.
compare_sizes = compare_sizes.loc[compare_sizes['keep'] == 1].copy()
# Frame count -> time, at 0.002 per frame (presumably a 500 fps recording — confirm).
compare_sizes['treac'] = compare_sizes['frames_reac'] * 0.002

# Independent copy without the v == 100 tests, so the columns added below
# neither warn nor touch compare_sizes.
compare_sizes_new = compare_sizes.loc[compare_sizes['v'] != 100].copy()
# NOTE(review): the constant 4 looks like the approach distance (cm) of the
# sphere, so 4 / v would be the time to collision — confirm.
compare_sizes_new['dreac'] = compare_sizes_new['v'] * compare_sizes_new['treac'] - 4
compare_sizes_new['treac2'] = compare_sizes_new['treac'] / (4 / compare_sizes_new['v'])

compare_sizes_new['R'] = compare_sizes_new['R'] / 10

# Drop tests where the cricket did not react.
# NOTE(review): this uses `< 600`, whereas the velocity data uses `<= 600` —
# confirm which threshold is intended.
compare_sizes_new = compare_sizes_new.loc[compare_sizes_new['frames_reac'] < 600]

In [None]:
# Two side-by-side panels. plt.subplots(1, 2) gives exactly two axes; the former
# plt.subplots() + plt.subplot2grid() pair left an extra empty Axes behind them.
fig, (ax00, ax01) = plt.subplots(1, 2, figsize=(12, 4))

# (sphere velocity, color, legend label) for each tested speed.
speed_styles = (
    (10, 'tab:olive', 'v=10cm/s'),
    (40, 'tab:grey', 'v=40cm/s'),
)

# a. Proportion of crickets that responded before collision vs sphere diameter.
for speed, color, _ in speed_styles:
    grouped_plot_R1('safe', compare_sizes_new, ax00, color, speed, logistic=True)
sns.pointplot(x="R", y='safe', hue="v", data=compare_sizes_new, native_scale=True, markers='x',
              palette=[color for _, color, _ in speed_styles], ax=ax00,
              join=False, errorbar=None)

ax00.legend([], [], frameon=False)
format_axes(ax00, 'Sphere diameter D (mm)', 'Proportion of crickets that responded before collision',
            [0.02, 0.95, 'a.', 10], None)

# b. Reaction timing vs sphere diameter; the dashed line at 1 marks collision.
for speed, color, _ in speed_styles:
    grouped_plot_R1('treac2', compare_sizes_new, ax01, color, speed, logistic=False)

format_axes(ax01, 'Sphere diameter D (mm)', 'Timing of cricket reaction (s)',
            [0.02, 0.95, 'b.', 10], None)
ax01.hlines(1, 6.5, 25.5, color='black', linestyle='dashed')
ax01.set_ylim([0, 2])
ax01.set_xlim([6.5, 25.5])
ax01.text(0.01, 0.05, 'before collision', rotation=90, transform=ax01.transAxes, fontsize=8)
ax01.text(0.01, 0.65, ' after collision', rotation=90, transform=ax01.transAxes, fontsize=8)
# Invisible lines at (0, 0), outside the x-limits, that only provide legend entries.
for _, color, legend_label in speed_styles:
    ax01.plot(0, 0, label=legend_label, color=color)
ax01.legend(frameon=False, loc=4, prop={'size': 7})
    