### In this script, DBN is run on all the sessions
### In this script, DBN is run with a 1s time bin and 3 time lags
### In this script, the animal tracking is done with only one camera - camera 2 (middle) 
### only focus on the strategy switching sessions (among Ginger/Kanga/Dodson/Dannon)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn
import scipy
import scipy.stats as st
from sklearn.neighbors import KernelDensity
import string
import warnings
import pickle

import os
import glob
import random
from time import time

from pgmpy.models import BayesianModel
from pgmpy.models import DynamicBayesianNetwork as DBN
from pgmpy.estimators import BayesianEstimator
from pgmpy.estimators import HillClimbSearch,BicScore
from pgmpy.base import DAG
import networkx as nx


### function - get body part location for each pair of cameras

In [None]:
from ana_functions.body_part_locs_eachpair import body_part_locs_eachpair
from ana_functions.body_part_locs_singlecam import body_part_locs_singlecam

### function - align the two cameras

In [None]:
from ana_functions.camera_align import camera_align       

### function - merge the two pairs of cameras

In [None]:
from ana_functions.camera_merge import camera_merge

### function - find social gaze time point

In [None]:
from ana_functions.find_socialgaze_timepoint import find_socialgaze_timepoint
from ana_functions.find_socialgaze_timepoint_singlecam import find_socialgaze_timepoint_singlecam
from ana_functions.find_socialgaze_timepoint_singlecam_wholebody import find_socialgaze_timepoint_singlecam_wholebody

### function - define time point of behavioral events

In [None]:
from ana_functions.bhv_events_timepoint import bhv_events_timepoint
from ana_functions.bhv_events_timepoint_singlecam import bhv_events_timepoint_singlecam

### function - plot behavioral events

In [None]:
from ana_functions.plot_bhv_events import plot_bhv_events
from ana_functions.plot_bhv_events_levertube import plot_bhv_events_levertube
from ana_functions.draw_self_loop import draw_self_loop
import matplotlib.patches as mpatches 
from matplotlib.collections import PatchCollection

### function - plot inter-pull interval

In [None]:
from ana_functions.plot_interpull_interval import plot_interpull_interval

### function - make demo videos with skeleton and important vectors

In [None]:
from ana_functions.tracking_video_singlecam_demo import tracking_video_singlecam_demo
from ana_functions.tracking_video_singlecam_wholebody_demo import tracking_video_singlecam_wholebody_demo

### function - interval between all behavioral events

In [None]:
from ana_functions.bhv_events_interval import bhv_events_interval

### function - train the dynamic bayesian network - multi time lag (3 lags)

In [None]:
from ana_functions.train_DBN_multiLag import train_DBN_multiLag
from ana_functions.train_DBN_multiLag import train_DBN_multiLag_create_df_only
from ana_functions.train_DBN_multiLag import train_DBN_multiLag_training_only
from ana_functions.train_DBN_multiLag import graph_to_matrix
from ana_functions.train_DBN_multiLag import get_weighted_dags
from ana_functions.train_DBN_multiLag import get_significant_edges
from ana_functions.train_DBN_multiLag import threshold_edges
from ana_functions.train_DBN_multiLag import Modulation_Index
from ana_functions.EfficientTimeShuffling import EfficientShuffle
from ana_functions.AicScore import AicScore

## Analyze each session

### prepare the basic behavioral data (especially the time stamps for each bhv events)

In [None]:
# instead of using a gaze angle threshold, use target rectangles to decide gaze info
# ...need to update
sqr_thres_tubelever = 75 # draw the square around tube and lever
sqr_thres_face = 1.15 # a ratio for defining face boundary
sqr_thres_body = 4 # how many times to elongate the face box boundary to the body


# get the fps of the analyzed video
fps = 30

# frame number of the demo video
# nframes = 0.5*30 # second*30fps
nframes = 45*30 # second*30fps

# re-analyze the video or not
reanalyze_video = 0
# set to 1 to ignore previously saved summary files and re-analyze every date
redo_anystep = 0

# session list options
do_bestsession = 1 # only analyze the best (five) sessions for each condition during the training phase
if do_bestsession:
    savefile_sufix = '_bestsessions_StraSwitch'
else:
    savefile_sufix = '_StraSwitch'

# all the videos (no misaligned ones)
# aligned with the audio
# get the session start time from "videosound_bhv_sync.py/.ipynb"
# currently the session_start_time is typed in manually. It can be updated once a better method is available
# exactly one of the animal-pair sections below should be enabled (if 1)

# dodson ginger
if 1:
    if not do_bestsession:
        dates_list = []
        session_start_times = [] # in second
    elif do_bestsession:
        dates_list = [
                      "20240924","20240926","20241001","20241003","20241007",
                     ]
        session_start_times = [
                             0.00, 43.0, 20.0, 0.00, 0.00,
                              ] # in second

    animal1_fixedorder = ['dodson']
    animal2_fixedorder = ['ginger']

    animal1_filename = "Dodson"
    animal2_filename = "Ginger"

# ginger kanga
if 0:
    if not do_bestsession:
        dates_list = []
        session_start_times = [] # in second
    elif do_bestsession:
        dates_list = [
                      "20240923","20240925","20240930","20241002","20241004",
                   ]
        session_start_times = [
                                 19.0, 0.00, 26.8, 35.0, 15.4,
                              ] # in second

    animal1_fixedorder = ['ginger']
    animal2_fixedorder = ['kanga']

    animal1_filename = "Ginger"
    animal2_filename = "Kanga"


# dannon kanga
if 0:
    if not do_bestsession:
        dates_list = []
        session_start_times = [] # in second
    elif do_bestsession:
        dates_list = [
                      "20240926", "20241001", "20241003", "20241007",
                   ]
        session_start_times = [
                                   0.00,  37.0, 0.00, 0.00,
                              ] # in second

    animal1_fixedorder = ['dannon']
    animal2_fixedorder = ['kanga']

    animal1_filename = "Dannon"
    animal2_filename = "Kanga"


#
# dates_list = ["20221128"]
# session_start_times = [1.00] # in second
ndates = np.shape(dates_list)[0]

# every date needs exactly one session start time; fail loudly instead of
# calling exit(), which would shut down the notebook kernel without a message
if np.shape(session_start_times)[0] != ndates:
    raise ValueError(
        'session_start_times has '+str(np.shape(session_start_times)[0])
        +' entries but dates_list has '+str(ndates)
    )

# session start in video frames. Convert to an array first: a plain list times
# an int would repeat the list fps times instead of scaling each start time.
session_start_frames = np.array(session_start_times) * fps # fps is 30Hz

totalsess_time = 600

# video tracking results info
animalnames_videotrack = ['dodson','scorch'] # does not really mean dodson and scorch, instead, indicate animal1 and animal2
bodypartnames_videotrack = ['rightTuft','whiteBlaze','leftTuft','rightEye','leftEye','mouth']


# which camera to analyze
cameraID = 'camera-2'
cameraID_short = 'cam2'


# location of levers and tubes for camera 2
# get this information using DLC animal tracking GUI, the results are stored:
# /home/ws523/marmoset_tracking_DLCv2/marmoset_tracking_with_lever_tube-weikang-2023-04-13/labeled-data/
considerlevertube = 1
considertubeonly = 0
# # camera 1
# lever_locs_camI = {'dodson':np.array([645,600]),'scorch':np.array([425,435])}
# tube_locs_camI  = {'dodson':np.array([1350,630]),'scorch':np.array([555,345])}
# # camera 2
lever_locs_camI = {'dodson':np.array([1335,715]),'scorch':np.array([550,715])}
tube_locs_camI  = {'dodson':np.array([1550,515]),'scorch':np.array([350,515])}
# # lever_locs_camI = {'dodson':np.array([1335,715]),'scorch':np.array([550,715])}
# # tube_locs_camI  = {'dodson':np.array([1650,490]),'scorch':np.array([250,490])}
# # camera 3
# lever_locs_camI = {'dodson':np.array([1580,440]),'scorch':np.array([1296,540])}
# tube_locs_camI  = {'dodson':np.array([1470,375]),'scorch':np.array([805,475])}

    
# define bhv events summarizing variables (one row per date in dates_list)
tasktypes_all_dates = np.zeros((ndates,1))
coopthres_all_dates = np.zeros((ndates,1))

succ_rate_all_dates = np.zeros((ndates,1))
interpullintv_all_dates = np.zeros((ndates,1))
trialnum_all_dates = np.zeros((ndates,1))

owgaze1_num_all_dates = np.zeros((ndates,1))
owgaze2_num_all_dates = np.zeros((ndates,1))
mtgaze1_num_all_dates = np.zeros((ndates,1))
mtgaze2_num_all_dates = np.zeros((ndates,1))
pull1_num_all_dates = np.zeros((ndates,1))
pull2_num_all_dates = np.zeros((ndates,1))

# one independent placeholder list per date; dict.fromkeys(dates_list, []) would
# make every date share the very same list object
bhv_intv_all_dates = {date_i: [] for date_i in dates_list}

sess_videotimes_all_dates = np.zeros((ndates,1))

# where to save the summarizing data
data_saved_folder = '/gpfs/radev/pi/nandy/jadi_gibbs_data/VideoTracker_SocialInter/3d_recontruction_analysis_self_and_coop_task_data_saved/'

# save the session start times and the date list next to the summary data
data_saved_subfolder = data_saved_folder+'data_saved_singlecam_wholebody'+savefile_sufix+'/'+cameraID+'/'+animal1_fixedorder[0]+animal2_fixedorder[0]+'/'
os.makedirs(data_saved_subfolder, exist_ok=True)
#
with open(data_saved_subfolder+'sessstart_time_all_dates_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'.pkl', 'wb') as f:
    pickle.dump(session_start_times, f)
with open(data_saved_subfolder+'dates_list_all_dates_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'.pkl', 'wb') as f:
    pickle.dump(dates_list, f)


In [None]:
# basic behavior analysis (define time stamps for each bhv events, etc)
#
# Either load the per-date summary variables saved by a previous run, or -- when
# any of those files cannot be loaded, or redo_anystep is set -- re-analyze every
# date in dates_list and save the results.

# names of the per-date summary variables; each one is pickled to
# <data_saved_subfolder>/<name>_<animal1><animal2>.pkl
_summary_varnames = ['owgaze1_num_all_dates','owgaze2_num_all_dates',
                     'mtgaze1_num_all_dates','mtgaze2_num_all_dates',
                     'pull1_num_all_dates','pull2_num_all_dates',
                     'tasktypes_all_dates','coopthres_all_dates','succ_rate_all_dates',
                     'interpullintv_all_dates','trialnum_all_dates','bhv_intv_all_dates',
                     'sess_videotimes_all_dates']
_pairtag = animal1_fixedorder[0]+animal2_fixedorder[0]

try:
    if redo_anystep:
        # jump straight to the re-analysis branch below
        raise RuntimeError('redo_anystep is set, re-analyze all dates')

    # load saved data
    data_saved_subfolder = data_saved_folder+'data_saved_singlecam_wholebody'+savefile_sufix+'/'+cameraID+'/'+_pairtag+'/'

    # read every file before touching the globals, so a partially missing set does
    # not leave some summary arrays replaced by stale saved values when the
    # re-analysis branch then fills them in by index
    _loaded_summaries = {}
    for _varname in _summary_varnames:
        with open(data_saved_subfolder+'/'+_varname+'_'+_pairtag+'.pkl', 'rb') as f:
            _loaded_summaries[_varname] = pickle.load(f)

    owgaze1_num_all_dates = _loaded_summaries['owgaze1_num_all_dates']
    owgaze2_num_all_dates = _loaded_summaries['owgaze2_num_all_dates']
    mtgaze1_num_all_dates = _loaded_summaries['mtgaze1_num_all_dates']
    mtgaze2_num_all_dates = _loaded_summaries['mtgaze2_num_all_dates']
    pull1_num_all_dates = _loaded_summaries['pull1_num_all_dates']
    pull2_num_all_dates = _loaded_summaries['pull2_num_all_dates']
    tasktypes_all_dates = _loaded_summaries['tasktypes_all_dates']
    coopthres_all_dates = _loaded_summaries['coopthres_all_dates']
    succ_rate_all_dates = _loaded_summaries['succ_rate_all_dates']
    interpullintv_all_dates = _loaded_summaries['interpullintv_all_dates']
    trialnum_all_dates = _loaded_summaries['trialnum_all_dates']
    bhv_intv_all_dates = _loaded_summaries['bhv_intv_all_dates']
    sess_videotimes_all_dates = _loaded_summaries['sess_videotimes_all_dates']

    print('all data from all dates are loaded')

except Exception:
    # (Exception rather than a bare except, so a KeyboardInterrupt still stops the cell)

    print('analyze all dates')

    for idate in np.arange(0,ndates,1):
        date_tgt = dates_list[idate]
        session_start_time = session_start_times[idate]

        # folder and file path
        camera12_analyzed_path = "/gpfs/radev/pi/nandy/jadi_gibbs_data/VideoTracker_SocialInter/test_video_cooperative_task_3d/"+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_camera12/"
        camera23_analyzed_path = "/gpfs/radev/pi/nandy/jadi_gibbs_data/VideoTracker_SocialInter/test_video_cooperative_task_3d/"+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_camera23/"

        # load the single-camera DLC tracking results: try the first model name,
        # then the head-chamber model name; for each, look in the camera12 folder
        # first and fall back to the camera23 folder
        try:
            singlecam_ana_type = "DLC_dlcrnetms5_marmoset_tracking_with_middle_cameraSep1shuffle1_150000"
            try:
                bodyparts_camI_camIJ = camera12_analyzed_path+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_"+cameraID+singlecam_ana_type+"_el_filtered.h5"
                # get the bodypart data from files
                bodyparts_locs_camI = body_part_locs_singlecam(bodyparts_camI_camIJ,singlecam_ana_type,animalnames_videotrack,bodypartnames_videotrack,date_tgt)
                video_file_original = camera12_analyzed_path+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_"+cameraID+".mp4"
            except Exception:
                bodyparts_camI_camIJ = camera23_analyzed_path+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_"+cameraID+singlecam_ana_type+"_el_filtered.h5"
                # get the bodypart data from files
                bodyparts_locs_camI = body_part_locs_singlecam(bodyparts_camI_camIJ,singlecam_ana_type,animalnames_videotrack,bodypartnames_videotrack,date_tgt)
                video_file_original = camera23_analyzed_path+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_"+cameraID+".mp4"
        except Exception:
            singlecam_ana_type = "DLC_dlcrnetms5_marmoset_tracking_with_middle_camera_withHeadchamberFeb28shuffle1_167500"
            try:
                bodyparts_camI_camIJ = camera12_analyzed_path+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_"+cameraID+singlecam_ana_type+"_el_filtered.h5"
                # get the bodypart data from files
                bodyparts_locs_camI = body_part_locs_singlecam(bodyparts_camI_camIJ,singlecam_ana_type,animalnames_videotrack,bodypartnames_videotrack,date_tgt)
                video_file_original = camera12_analyzed_path+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_"+cameraID+".mp4"
            except Exception:
                bodyparts_camI_camIJ = camera23_analyzed_path+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_"+cameraID+singlecam_ana_type+"_el_filtered.h5"
                # get the bodypart data from files
                bodyparts_locs_camI = body_part_locs_singlecam(bodyparts_camI_camIJ,singlecam_ana_type,animalnames_videotrack,bodypartnames_videotrack,date_tgt)
                video_file_original = camera23_analyzed_path+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_"+cameraID+".mp4"

        # number of frames of the first tracked entry -> video duration in seconds
        min_length = list(bodyparts_locs_camI.values())[0].shape[0]

        sess_videotimes_all_dates[idate] = min_length/fps

        # load behavioral results: try the task-code data folder first, then the
        # forceManipulation folder; in each, try the animal2_animal1 file-name
        # order before animal1_animal2
        try:
            try:
                bhv_data_path = "/gpfs/radev/pi/nandy/jadi_gibbs_data/VideoTracker_SocialInter/marmoset_tracking_bhv_data_from_task_code/"+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"/"
                trial_record_json = glob.glob(bhv_data_path +date_tgt+"_"+animal2_filename+"_"+animal1_filename+"_TrialRecord_" + "*.json")
                bhv_data_json = glob.glob(bhv_data_path + date_tgt+"_"+animal2_filename+"_"+animal1_filename+"_bhv_data_" + "*.json")
                session_info_json = glob.glob(bhv_data_path + date_tgt+"_"+animal2_filename+"_"+animal1_filename+"_session_info_" + "*.json")
                #
                trial_record = pd.read_json(trial_record_json[0])
                bhv_data = pd.read_json(bhv_data_json[0])
                session_info = pd.read_json(session_info_json[0])
            except Exception:
                bhv_data_path = "/gpfs/radev/pi/nandy/jadi_gibbs_data/VideoTracker_SocialInter/marmoset_tracking_bhv_data_from_task_code/"+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"/"
                trial_record_json = glob.glob(bhv_data_path + date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_TrialRecord_" + "*.json")
                bhv_data_json = glob.glob(bhv_data_path + date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_bhv_data_" + "*.json")
                session_info_json = glob.glob(bhv_data_path + date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_session_info_" + "*.json")
                #
                trial_record = pd.read_json(trial_record_json[0])
                bhv_data = pd.read_json(bhv_data_json[0])
                session_info = pd.read_json(session_info_json[0])
        except Exception:
            try:
                bhv_data_path = "/gpfs/radev/pi/nandy/jadi_gibbs_data/VideoTracker_SocialInter/marmoset_tracking_bhv_data_forceManipulation_task/"+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"/"
                trial_record_json = glob.glob(bhv_data_path +date_tgt+"_"+animal2_filename+"_"+animal1_filename+"_TrialRecord_" + "*.json")
                bhv_data_json = glob.glob(bhv_data_path + date_tgt+"_"+animal2_filename+"_"+animal1_filename+"_bhv_data_" + "*.json")
                session_info_json = glob.glob(bhv_data_path + date_tgt+"_"+animal2_filename+"_"+animal1_filename+"_session_info_" + "*.json")
                #
                trial_record = pd.read_json(trial_record_json[0])
                bhv_data = pd.read_json(bhv_data_json[0])
                session_info = pd.read_json(session_info_json[0])
            except Exception:
                bhv_data_path = "/gpfs/radev/pi/nandy/jadi_gibbs_data/VideoTracker_SocialInter/marmoset_tracking_bhv_data_forceManipulation_task/"+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"/"
                trial_record_json = glob.glob(bhv_data_path + date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_TrialRecord_" + "*.json")
                bhv_data_json = glob.glob(bhv_data_path + date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_bhv_data_" + "*.json")
                session_info_json = glob.glob(bhv_data_path + date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_session_info_" + "*.json")
                #
                trial_record = pd.read_json(trial_record_json[0])
                bhv_data = pd.read_json(bhv_data_json[0])
                session_info = pd.read_json(session_info_json[0])

        # get animal info from the session information
        animal1 = session_info['lever1_animal'][0].lower()
        animal2 = session_info['lever2_animal'][0].lower()


        # get task type and cooperation threshold; fall back to defaults
        # (threshold 0, task type 1) when these fields cannot be read
        try:
            coop_thres = session_info["pulltime_thres"][0]
            tasktype = session_info["task_type"][0]
        except Exception:
            coop_thres = 0
            tasktype = 1
        tasktypes_all_dates[idate] = tasktype
        coopthres_all_dates[idate] = coop_thres

        # clean up the trial_record: keep only the first row recorded for each
        # trial number 1..max(trial_number)
        warnings.filterwarnings('ignore')
        first_rows_each_trial = [trial_record[trial_record['trial_number']==itrial+1].iloc[[0]]
                                 for itrial in np.arange(0,np.max(trial_record['trial_number']),1)]
        # DataFrame.append was removed in pandas 2.0; concatenate all rows once instead
        if first_rows_each_trial:
            trial_record_clean = pd.concat(first_rows_each_trial)
        else:
            trial_record_clean = pd.DataFrame(columns=trial_record.columns)
        trial_record_clean = trial_record_clean.reset_index(drop = True)

        # change bhv_data time to the absolute time (within-trial time + trial start time)
        time_points_new = pd.DataFrame(np.zeros(np.shape(bhv_data)[0]),columns=["time_points_new"])
        for itrial in np.arange(0,np.max(trial_record_clean['trial_number']),1):
            ind = bhv_data["trial_number"]==itrial+1
            new_time_itrial = bhv_data[ind]["time_points"] + trial_record_clean["trial_starttime"].iloc[itrial]
            # positional write through .loc: the chained form df[col][mask] = ...
            # is unreliable and silently does nothing under pandas copy-on-write
            time_points_new.loc[ind.to_numpy(), "time_points_new"] = new_time_itrial.to_numpy()
        bhv_data["time_points"] = time_points_new["time_points_new"].to_numpy()
        # drop rows whose time is still 0 (e.g. events not matched to any trial above)
        bhv_data = bhv_data[bhv_data["time_points"] != 0]


        # analyze behavior results
        # succ_rate_all_dates[idate] = np.sum(trial_record_clean["rewarded"]>0)/np.shape(trial_record_clean)[0]
        # success rate = number of event-3/4 rows divided by number of pull (event 1/2) rows
        succ_rate_all_dates[idate] = np.sum((bhv_data['behavior_events']==3)|(bhv_data['behavior_events']==4))/np.sum((bhv_data['behavior_events']==1)|(bhv_data['behavior_events']==2))
        trialnum_all_dates[idate] = np.shape(trial_record_clean)[0]
        #
        # inter-pull interval: time between consecutive pulls made by different
        # levers (pull id changes 1<->2), keeping only gaps shorter than 10 s
        pullid = np.array(bhv_data[(bhv_data['behavior_events']==1) | (bhv_data['behavior_events']==2)]["behavior_events"])
        pulltime = np.array(bhv_data[(bhv_data['behavior_events']==1) | (bhv_data['behavior_events']==2)]["time_points"])
        pullid_diff = np.abs(pullid[1:] - pullid[0:-1])
        pulltime_diff = pulltime[1:] - pulltime[0:-1]
        interpull_intv = pulltime_diff[pullid_diff==1]
        interpull_intv = interpull_intv[interpull_intv<10]
        mean_interpull_intv = np.nanmean(interpull_intv)
        std_interpull_intv = np.nanstd(interpull_intv)
        #
        interpullintv_all_dates[idate] = mean_interpull_intv
        #
        # keep animal1/animal2 in the fixed order regardless of which lever each used
        if animal1 in animal1_fixedorder:
            pull1_num_all_dates[idate] = np.sum(bhv_data['behavior_events']==1)
            pull2_num_all_dates[idate] = np.sum(bhv_data['behavior_events']==2)
        else:
            pull1_num_all_dates[idate] = np.sum(bhv_data['behavior_events']==2)
            pull2_num_all_dates[idate] = np.sum(bhv_data['behavior_events']==1)

        # load behavioral event (social gaze) results cached by an earlier run,
        # otherwise compute them and cache them
        bhv_events_dir = data_saved_folder+"bhv_events_singlecam_wholebody/"+_pairtag+"/"+cameraID+'/'+date_tgt+'/'
        try:
            print('load social gaze with '+cameraID+' only of '+date_tgt)
            with open(bhv_events_dir+'output_look_ornot.pkl', 'rb') as f:
                output_look_ornot = pickle.load(f)
            with open(bhv_events_dir+'output_allvectors.pkl', 'rb') as f:
                output_allvectors = pickle.load(f)
            with open(bhv_events_dir+'output_allangles.pkl', 'rb') as f:
                output_allangles = pickle.load(f)
        except Exception:
            print('analyze social gaze with '+cameraID+' only of '+date_tgt)
            # get social gaze information
            output_look_ornot, output_allvectors, output_allangles = find_socialgaze_timepoint_singlecam_wholebody(bodyparts_locs_camI,lever_locs_camI,tube_locs_camI,
                                                                                                                   considerlevertube,considertubeonly,sqr_thres_tubelever,
                                                                                                                   sqr_thres_face,sqr_thres_body)
            # save data
            os.makedirs(bhv_events_dir, exist_ok=True)
            with open(bhv_events_dir+'output_look_ornot.pkl', 'wb') as f:
                pickle.dump(output_look_ornot, f)
            with open(bhv_events_dir+'output_allvectors.pkl', 'wb') as f:
                pickle.dump(output_allvectors, f)
            with open(bhv_events_dir+'output_allangles.pkl', 'wb') as f:
                pickle.dump(output_allangles, f)


        look_at_other_or_not_merge = output_look_ornot['look_at_other_or_not_merge']
        look_at_tube_or_not_merge = output_look_ornot['look_at_tube_or_not_merge']
        look_at_lever_or_not_merge = output_look_ornot['look_at_lever_or_not_merge']
        # change the unit to second, relative to the session start
        look_at_other_or_not_merge['time_in_second'] = np.arange(0,np.shape(look_at_other_or_not_merge['dodson'])[0],1)/fps - session_start_time
        look_at_lever_or_not_merge['time_in_second'] = np.arange(0,np.shape(look_at_lever_or_not_merge['dodson'])[0],1)/fps - session_start_time
        look_at_tube_or_not_merge['time_in_second'] = np.arange(0,np.shape(look_at_tube_or_not_merge['dodson'])[0],1)/fps - session_start_time

        # find time point of behavioral events
        output_time_points_socialgaze ,output_time_points_levertube = bhv_events_timepoint_singlecam(bhv_data,look_at_other_or_not_merge,look_at_lever_or_not_merge,look_at_tube_or_not_merge)
        time_point_pull1 = output_time_points_socialgaze['time_point_pull1']
        time_point_pull2 = output_time_points_socialgaze['time_point_pull2']
        oneway_gaze1 = output_time_points_socialgaze['oneway_gaze1']
        oneway_gaze2 = output_time_points_socialgaze['oneway_gaze2']
        mutual_gaze1 = output_time_points_socialgaze['mutual_gaze1']
        mutual_gaze2 = output_time_points_socialgaze['mutual_gaze2']


        # # plot behavioral events
        if animal1 in animal1_fixedorder:
            plot_bhv_events(date_tgt,animal1, animal2, session_start_time, 600, time_point_pull1, time_point_pull2, oneway_gaze1, oneway_gaze2, mutual_gaze1, mutual_gaze2)
        else:
            plot_bhv_events(date_tgt,animal2, animal1, session_start_time, 600, time_point_pull2, time_point_pull1, oneway_gaze2, oneway_gaze1, mutual_gaze2, mutual_gaze1)
        #
        # save behavioral events plot
        if 0:
            os.makedirs(bhv_events_dir, exist_ok=True)
            plt.savefig(bhv_events_dir+date_tgt+"_"+cameraID_short+".pdf")

        # count gaze events, again in the fixed animal order
        if animal1 in animal1_fixedorder:
            owgaze1_num_all_dates[idate] = np.shape(oneway_gaze1)[0]#/(min_length/fps)
            owgaze2_num_all_dates[idate] = np.shape(oneway_gaze2)[0]#/(min_length/fps)
            mtgaze1_num_all_dates[idate] = np.shape(mutual_gaze1)[0]#/(min_length/fps)
            mtgaze2_num_all_dates[idate] = np.shape(mutual_gaze2)[0]#/(min_length/fps)
        else:
            owgaze1_num_all_dates[idate] = np.shape(oneway_gaze2)[0]#/(min_length/fps)
            owgaze2_num_all_dates[idate] = np.shape(oneway_gaze1)[0]#/(min_length/fps)
            mtgaze1_num_all_dates[idate] = np.shape(mutual_gaze2)[0]#/(min_length/fps)
            mtgaze2_num_all_dates[idate] = np.shape(mutual_gaze1)[0]#/(min_length/fps)

        # analyze the events interval, especially for the pull to other and other to pull interval
        # could be used for define time bin for DBN
        if 1:
            _,_,_,pullTOother_itv, otherTOpull_itv = bhv_events_interval(totalsess_time, session_start_time, time_point_pull1, time_point_pull2,
                                                                         oneway_gaze1, oneway_gaze2, mutual_gaze1, mutual_gaze2)
            #
            pull_other_pool_itv = np.concatenate((pullTOother_itv,otherTOpull_itv))
            bhv_intv_all_dates[date_tgt] = {'pull_to_other':pullTOother_itv,'other_to_pull':otherTOpull_itv,
                            'pull_other_pooled': pull_other_pool_itv}

        # plot the tracking demo video
        if 0:
            tracking_video_singlecam_wholebody_demo(bodyparts_locs_camI,output_look_ornot,output_allvectors,output_allangles,
                                              lever_locs_camI,tube_locs_camI,time_point_pull1,time_point_pull2,
                                              animalnames_videotrack,bodypartnames_videotrack,date_tgt,
                                              animal1_filename,animal2_filename,session_start_time,fps,nframes,cameraID,
                                              video_file_original,sqr_thres_tubelever,sqr_thres_face,sqr_thres_body)


    # save data (one pickle per summary variable, same file names the loader reads)
    if 1:

        data_saved_subfolder = data_saved_folder+'data_saved_singlecam_wholebody'+savefile_sufix+'/'+cameraID+'/'+_pairtag+'/'
        os.makedirs(data_saved_subfolder, exist_ok=True)

        # with open(data_saved_subfolder+'/DBN_input_data_alltypes_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'.pkl', 'wb') as f:
        #     pickle.dump(DBN_input_data_alltypes, f)

        _summaries_to_save = {
            'owgaze1_num_all_dates': owgaze1_num_all_dates,
            'owgaze2_num_all_dates': owgaze2_num_all_dates,
            'mtgaze1_num_all_dates': mtgaze1_num_all_dates,
            'mtgaze2_num_all_dates': mtgaze2_num_all_dates,
            'pull1_num_all_dates': pull1_num_all_dates,
            'pull2_num_all_dates': pull2_num_all_dates,
            'tasktypes_all_dates': tasktypes_all_dates,
            'coopthres_all_dates': coopthres_all_dates,
            'succ_rate_all_dates': succ_rate_all_dates,
            'interpullintv_all_dates': interpullintv_all_dates,
            'trialnum_all_dates': trialnum_all_dates,
            'bhv_intv_all_dates': bhv_intv_all_dates,
            'sess_videotimes_all_dates': sess_videotimes_all_dates,
        }
        for _varname, _value in _summaries_to_save.items():
            with open(data_saved_subfolder+'/'+_varname+'_'+_pairtag+'.pkl', 'wb') as f:
                pickle.dump(_value, f)

In [None]:
# show the per-date task-type codes (Jupyter displays the value of a cell's last expression)
tasktypes_all_dates

#### redefine the tasktype and cooperation threshold to merge them together

In [None]:
# Sort-key legend used by the plots below:
#   100: self; 3: 3s coop; 2: 2s coop; 1.5: 1.5s coop; 1: 1s coop; -1: no-vision
# NOTE(review): for a recoded no-vision session the key evaluates to
# -coopthres, which equals -1 only when its threshold is 1 -- confirm intent.

# recode the no-vision task type (5) as -1, in place
np.putmask(tasktypes_all_dates, tasktypes_all_dates == 5, -1)

# sort key per date: half of (task type - 1) times the cooperation threshold
coopthres_forsort = coopthres_all_dates * (tasktypes_all_dates - 1) / 2
# a zero key (e.g. task type 1 or a zero threshold) is mapped to 100, in place
np.putmask(coopthres_forsort, coopthres_forsort == 0, 100)

### plot behavioral event intervals to get a sense of the time bin
#### only focus on pull_to_other_bhv_interval and other_bhv_to_pull_interval

In [None]:
fig, ax1 = plt.subplots(figsize=(10, 5))
#
# sort the sessions by condition code (descending) and then by date
sorting_df = pd.DataFrame({'dates': dates_list, 'coopthres': coopthres_forsort.ravel()}, columns=['dates', 'coopthres'])
sorting_df = sorting_df.sort_values(by=['coopthres','dates'], ascending = [False, True])
dates_list_sorted = np.array(dates_list)[sorting_df.index]
ndates_sorted = np.shape(dates_list_sorted)[0]

# per-session pooled pull<->other-behavior intervals, with outliers (> mean + 2 SD) set to NaN
pull_other_intv_forplots = {}
pull_other_intv_mean = np.zeros(ndates_sorted)
for ii in np.arange(0,ndates_sorted,1):
    pull_other_intv_ii = pd.Series(bhv_intv_all_dates[dates_list_sorted[ii]]['pull_other_pooled'])
    # remove the interval that is too large; Series.mask returns a float copy with NaN,
    # avoiding in-place NaN assignment into a possibly integer-typed Series
    outlier_thres = np.nanmean(pull_other_intv_ii)+2*np.nanstd(pull_other_intv_ii)
    pull_other_intv_ii = pull_other_intv_ii.mask(pull_other_intv_ii > outlier_thres)
    # pull_other_intv_ii[pull_other_intv_ii>10]= np.nan
    pull_other_intv_forplots[ii] = pull_other_intv_ii
    pull_other_intv_mean[ii] = np.nanmean(pull_other_intv_ii)

# sessions become columns (aligned on interval index, shorter sessions padded with NaN)
pull_other_intv_forplots = pd.DataFrame(pull_other_intv_forplots)
# grand mean over every retained interval of every session (computed once, reused below)
pull_other_intv_grandmean = np.nanmean(pull_other_intv_forplots)

#
# plot: one box per session plus a red star at each session mean
pull_other_intv_forplots.plot(kind = 'box',ax=ax1, positions=np.arange(0,ndates_sorted,1))
plt.plot(np.arange(0,ndates_sorted,1),pull_other_intv_mean,'r*',markersize=10)
#
ax1.set_ylabel("bhv event interval(around pulls)",fontsize=13)
ax1.set_ylim([-2,16])
#
plt.xticks(np.arange(0,ndates_sorted,1),dates_list_sorted, rotation=90,fontsize=10)
plt.yticks(fontsize=10)
#
# dashed separators where the sorted condition code changes, plus one label per condition
tasktypes = ['MC']
taskswitches = np.where(np.array(sorting_df['coopthres'])[1:]-np.array(sorting_df['coopthres'])[:-1]!=0)[0]+0.5
for taskswitch in taskswitches:
    ax1.plot([taskswitch,taskswitch],[-2,15],'k--')
taskswitches = np.concatenate(([0],taskswitches))
for itaskswitch, taskswitch in enumerate(taskswitches):
    ax1.text(taskswitch+0.25,-1,tasktypes[itaskswitch],fontsize=10)
# the label sits above the last condition block (typo "Inteval" fixed)
ax1.text(taskswitches[-1],15,'mean Interval = '+str(pull_other_intv_grandmean),fontsize=10)

print(pull_other_intv_mean)
print(pull_other_intv_grandmean)

savefigs = 1
if savefigs:
    figsavefolder = data_saved_folder+'figs_for_3LagDBN_and_bhv_singlecam_wholebodylabels_allsessions_basicEvents/'+savefile_sufix+'/'+cameraID+'/'+animal1_fixedorder[0]+animal2_fixedorder[0]+'/'
    if not os.path.exists(figsavefolder):
        os.makedirs(figsavefolder)
    plt.savefig(figsavefolder+"bhvInterval_hist_"+animal1_fixedorder[0]+animal2_fixedorder[0]+'.jpg')

### plot some other basic behavioral measures
#### success rate

In [None]:
fig, ax1 = plt.subplots(figsize=(10, 5))

# order sessions: highest condition code first, then chronologically
sorting_df = (
    pd.DataFrame({'dates': dates_list, 'coopthres': coopthres_forsort.ravel()}, columns=['dates', 'coopthres'])
    .sort_values(by=['coopthres', 'dates'], ascending=[False, True])
)
dates_list_sorted = np.array(dates_list)[sorting_df.index]
ndates_sorted = np.shape(dates_list_sorted)[0]
session_xpos = np.arange(ndates_sorted)

# one dot per session
ax1.plot(session_xpos, succ_rate_all_dates[sorting_df.index], 'o', markersize=10)
ax1.set_ylabel("successful rate", fontsize=13)
ax1.set_ylim([-0.1, 1.1])
ax1.set_xlim([-0.5, ndates_sorted - 0.5])

plt.xticks(session_xpos, dates_list_sorted, rotation=90, fontsize=10)
plt.yticks(fontsize=10)

# dashed separators wherever the sorted condition code changes, then a label per condition block
tasktypes = ['MC']
coopthres_sorted = np.array(sorting_df['coopthres'])
taskswitches = np.where(np.diff(coopthres_sorted) != 0)[0] + 0.5
for taskswitch in taskswitches:
    ax1.plot([taskswitch, taskswitch], [-0.1, 1.1], 'k--')
taskswitches = np.concatenate(([0], taskswitches))
for itaskswitch, taskswitch in enumerate(taskswitches):
    ax1.text(taskswitch + 0.25, -0.05, tasktypes[itaskswitch], fontsize=10)

savefigs = 1
if savefigs:
    figsavefolder = data_saved_folder+'figs_for_3LagDBN_and_bhv_singlecam_wholebodylabels_allsessions_basicEvents/'+savefile_sufix+'/'+cameraID+'/'+animal1_fixedorder[0]+animal2_fixedorder[0]+'/'
    os.makedirs(figsavefolder, exist_ok=True)
    plt.savefig(figsavefolder+"successfulrate_"+animal1_fixedorder[0]+animal2_fixedorder[0]+'.jpg')


#### animal pull numbers

In [None]:
fig, ax1 = plt.subplots(figsize=(10, 5))

# order sessions: highest condition code first, then chronologically
sorting_df = (
    pd.DataFrame({'dates': dates_list, 'coopthres': coopthres_forsort.ravel()}, columns=['dates', 'coopthres'])
    .sort_values(by=['coopthres', 'dates'], ascending=[False, True])
)
dates_list_sorted = np.array(dates_list)[sorting_df.index]
ndates_sorted = np.shape(dates_list_sorted)[0]
session_xpos = np.arange(ndates_sorted)

# each animal's pull count per session, plus the pair average
pullmean_num_all_dates = (pull1_num_all_dates+pull2_num_all_dates)/2
pull_series = [
    (pull1_num_all_dates, 'bv', 5, 'animal1 pull #'),
    (pull2_num_all_dates, 'rv', 5, 'animal2 pull #'),
    (pullmean_num_all_dates, 'kv', 8, 'mean pull #'),
]
for pull_counts, marker_fmt, marker_size, series_label in pull_series:
    ax1.plot(session_xpos, pull_counts[sorting_df.index], marker_fmt, markersize=marker_size, label=series_label)
ax1.legend()

ax1.set_ylabel("pull numbers", fontsize=13)
ax1.set_ylim([-20, 240])
ax1.set_xlim([-0.5, ndates_sorted - 0.5])

plt.xticks(session_xpos, dates_list_sorted, rotation=90, fontsize=10)
plt.yticks(fontsize=10)

# dashed separators wherever the sorted condition code changes, then a label per condition block
tasktypes = ['MC']
coopthres_sorted = np.array(sorting_df['coopthres'])
taskswitches = np.where(np.diff(coopthres_sorted) != 0)[0] + 0.5
for taskswitch in taskswitches:
    ax1.plot([taskswitch, taskswitch], [-20, 240], 'k--')
taskswitches = np.concatenate(([0], taskswitches))
for itaskswitch, taskswitch in enumerate(taskswitches):
    ax1.text(taskswitch + 0.25, -10, tasktypes[itaskswitch], fontsize=10)

savefigs = 1
if savefigs:
    figsavefolder = data_saved_folder+'figs_for_3LagDBN_and_bhv_singlecam_wholebodylabels_allsessions_basicEvents/'+savefile_sufix+'/'+cameraID+'/'+animal1_fixedorder[0]+animal2_fixedorder[0]+'/'
    os.makedirs(figsavefolder, exist_ok=True)
    plt.savefig(figsavefolder+"pullnumbers_"+animal1_fixedorder[0]+animal2_fixedorder[0]+'.jpg')


#### gaze number

In [None]:

# per-session social gaze count for each animal: one-way gazes plus mutual gazes
gaze1_num_all_dates, gaze2_num_all_dates = (
    owgaze1_num_all_dates + mtgaze1_num_all_dates,
    owgaze2_num_all_dates + mtgaze2_num_all_dates,
)
# average of the two animals' counts
gazemean_num_all_dates = (gaze1_num_all_dates + gaze2_num_all_dates) / 2

# largest per-session count of each animal (used to pick the y-limits of the next plot)
for gaze_counts in (gaze1_num_all_dates, gaze2_num_all_dates):
    print(np.nanmax(gaze_counts))

In [None]:
fig, ax1 = plt.subplots(figsize=(10, 5))

# order sessions: highest condition code first, then chronologically
sorting_df = (
    pd.DataFrame({'dates': dates_list, 'coopthres': coopthres_forsort.ravel()}, columns=['dates', 'coopthres'])
    .sort_values(by=['coopthres', 'dates'], ascending=[False, True])
)
dates_list_sorted = np.array(dates_list)[sorting_df.index]
ndates_sorted = np.shape(dates_list_sorted)[0]
session_xpos = np.arange(ndates_sorted)

# each animal's social gaze count per session, plus the pair average
gaze_series = [
    (gaze1_num_all_dates, 'b^', 5, 'animal1 gaze #'),
    (gaze2_num_all_dates, 'r^', 5, 'animal2 gaze #'),
    (gazemean_num_all_dates, 'k^', 8, 'mean gaze #'),
]
for gaze_counts, marker_fmt, marker_size, series_label in gaze_series:
    ax1.plot(session_xpos, gaze_counts[sorting_df.index], marker_fmt, markersize=marker_size, label=series_label)
ax1.legend()

ax1.set_ylabel("social gaze number", fontsize=13)
ax1.set_ylim([-20, 1500])
ax1.set_xlim([-0.5, ndates_sorted - 0.5])

plt.xticks(session_xpos, dates_list_sorted, rotation=90, fontsize=10)
plt.yticks(fontsize=10)

# dashed separators wherever the sorted condition code changes, then a label per condition block
tasktypes = ['MC']
coopthres_sorted = np.array(sorting_df['coopthres'])
taskswitches = np.where(np.diff(coopthres_sorted) != 0)[0] + 0.5
for taskswitch in taskswitches:
    ax1.plot([taskswitch, taskswitch], [-20, 1500], 'k--')
taskswitches = np.concatenate(([0], taskswitches))
for itaskswitch, taskswitch in enumerate(taskswitches):
    ax1.text(taskswitch + 0.25, -10, tasktypes[itaskswitch], fontsize=10)

savefigs = 1
if savefigs:
    figsavefolder = data_saved_folder+'figs_for_3LagDBN_and_bhv_singlecam_wholebodylabels_allsessions_basicEvents/'+savefile_sufix+'/'+cameraID+'/'+animal1_fixedorder[0]+animal2_fixedorder[0]+'/'
    os.makedirs(figsavefolder, exist_ok=True)
    plt.savefig(figsavefolder+"gazenumbers_"+animal1_fixedorder[0]+animal2_fixedorder[0]+'.jpg')


In [None]:
# pooled social gaze count per session (both animals, one-way + mutual)
gaze_sum_all_dates = owgaze1_num_all_dates+owgaze2_num_all_dates+mtgaze1_num_all_dates+mtgaze2_num_all_dates
# NOTE(review): the /30 scaling is not explained here -- presumably a per-unit-time
# normalisation; confirm against the session length
gaze_numbers = gaze_sum_all_dates/30
gaze_pull_ratios = gaze_sum_all_dates/(pull1_num_all_dates+pull2_num_all_dates)/30

fig, ax1 = plt.subplots(figsize=(10, 5))

grouptypes = ['MC']

# sessions whose condition code is 1 form the 'MC' group; flatten the selected values to 1-D
mc_mask = np.transpose(coopthres_forsort==1)[0]
gaze_numbers_groups = [np.transpose(gaze_numbers[mc_mask])[0]]

gaze_numbers_plot = plt.boxplot(gaze_numbers_groups)

plt.xticks(np.arange(1, len(grouptypes)+1, 1), grouptypes, fontsize=12)
ax1.set_ylim([-30/30, 5400/30])
ax1.set_ylabel("average social gaze numbers", fontsize=13)

savefigs = 1
if savefigs:
    figsavefolder = data_saved_folder+'figs_for_3LagDBN_and_bhv_singlecam_wholebodylabels_allsessions_basicEvents/'+savefile_sufix+'/'+cameraID+'/'+animal1_fixedorder[0]+animal2_fixedorder[0]+'/'
    os.makedirs(figsavefolder, exist_ok=True)
    plt.savefig(figsavefolder+"averaged_gazenumbers_"+animal1_fixedorder[0]+animal2_fixedorder[0]+'.pdf')


## plot the gaze numbers for all individuals 

In [None]:
if 1:
    # Pool per-session social gaze rates (gazes per second of video) and pull counts for
    # every animal of the three animal pairs, plot them per animal, and t-test each animal
    # that worked with two different partners.

    animal1_fixedorders = ['dodson',       'ginger_withK', 'dannon']
    animal2_fixedorders = ['ginger_withD', 'kanga_withG',  'kanga_withD']
    
    animal1_filenames = ['dodson', 'ginger', 'dannon']
    animal2_filenames = ['ginger', 'kanga',  'kanga']
    
    nanimalpairs = np.shape(animal1_fixedorders)[0]

    grouptypes = ['MC',]
    coopthres_IDs = [ 1, ]
    
    ngrouptypes = np.shape(grouptypes)[0]

    gazenum_columns = ['dates','condition','act_animal','gazenumber','pullnumber']
    # DataFrame.append was removed in pandas 2.0: collect plain dict rows and build the
    # summary table once after the loops (this also avoids quadratic re-copying)
    gazenum_rows = []

    def _load_pair_pickle(subfolder, varname, pairname):
        """Load and return the pickle stored at `<subfolder>/<varname>_<pairname>.pkl`."""
        with open(subfolder+'/'+varname+'_'+pairname+'.pkl', 'rb') as f:
            return pickle.load(f)

    #
    for igrouptype in np.arange(0,ngrouptypes,1):

        grouptype = grouptypes[igrouptype]
        coopthres_ID = coopthres_IDs[igrouptype]

        #
        for ianimalpair in np.arange(0,nanimalpairs,1):
            animal1 = animal1_fixedorders[ianimalpair]
            animal2 = animal2_fixedorders[ianimalpair]

            animal1_filename = animal1_filenames[ianimalpair]
            animal2_filename = animal2_filenames[ianimalpair]
            pairname = animal1_filename+animal2_filename

            data_saved_subfolder = data_saved_folder+'data_saved_singlecam_wholebody'+savefile_sufix+'/'+cameraID+'/'+pairname+'/'
            owgaze1_num_all_dates = _load_pair_pickle(data_saved_subfolder, 'owgaze1_num_all_dates', pairname)
            owgaze2_num_all_dates = _load_pair_pickle(data_saved_subfolder, 'owgaze2_num_all_dates', pairname)
            mtgaze1_num_all_dates = _load_pair_pickle(data_saved_subfolder, 'mtgaze1_num_all_dates', pairname)
            mtgaze2_num_all_dates = _load_pair_pickle(data_saved_subfolder, 'mtgaze2_num_all_dates', pairname)
            pull1_num_all_dates = _load_pair_pickle(data_saved_subfolder, 'pull1_num_all_dates', pairname)
            pull2_num_all_dates = _load_pair_pickle(data_saved_subfolder, 'pull2_num_all_dates', pairname)

            tasktypes_all_dates = _load_pair_pickle(data_saved_subfolder, 'tasktypes_all_dates', pairname)
            coopthres_all_dates = _load_pair_pickle(data_saved_subfolder, 'coopthres_all_dates', pairname)
            succ_rate_all_dates = _load_pair_pickle(data_saved_subfolder, 'succ_rate_all_dates', pairname)
            interpullintv_all_dates = _load_pair_pickle(data_saved_subfolder, 'interpullintv_all_dates', pairname)
            trialnum_all_dates = _load_pair_pickle(data_saved_subfolder, 'trialnum_all_dates', pairname)
            bhv_intv_all_dates = _load_pair_pickle(data_saved_subfolder, 'bhv_intv_all_dates', pairname)

            sess_videotimes_all_dates = _load_pair_pickle(data_saved_subfolder, 'sess_videotimes_all_dates', pairname)

            dates_list_all_dates = np.array(_load_pair_pickle(data_saved_subfolder, 'dates_list_all_dates', pairname))

            # combine owgaze and mtgaze, normalised by each session's video time
            gaze1_num_all_dates = (owgaze1_num_all_dates + mtgaze1_num_all_dates)/sess_videotimes_all_dates
            gaze2_num_all_dates = (owgaze2_num_all_dates + mtgaze2_num_all_dates)/sess_videotimes_all_dates

            #
            # 100: self; 3: 3s coop; 2: 2s coop; 1.5: 1.5s coop; 1: 1s coop; -1: no-vision
            tasktypes_all_dates[tasktypes_all_dates==5] = -1 # change the task type code for no-vision
            coopthres_forsort = (tasktypes_all_dates-1)*coopthres_all_dates/2
            coopthres_forsort[coopthres_forsort==0] = 100 # get the cooperation threshold for sorting

            # keep only the sessions of the current condition
            dates_list_tgt = dates_list_all_dates[np.transpose(coopthres_forsort==coopthres_ID)[0]]
            gaze1_nums_tgt = gaze1_num_all_dates[coopthres_forsort==coopthres_ID]
            gaze2_nums_tgt = gaze2_num_all_dates[coopthres_forsort==coopthres_ID]
            pull1_nums_tgt = pull1_num_all_dates[coopthres_forsort==coopthres_ID]
            pull2_nums_tgt = pull2_num_all_dates[coopthres_forsort==coopthres_ID]
            ndates = np.shape(dates_list_tgt)[0]
            
            # two rows per session: one for each acting animal
            for idate in np.arange(0,ndates,1):
                date_tgt = dates_list_tgt[idate]
                gaze1_num = gaze1_nums_tgt[idate]
                gaze2_num = gaze2_nums_tgt[idate]
                pull1_num = pull1_nums_tgt[idate]
                pull2_num = pull2_nums_tgt[idate]

                gazenum_rows.append({'dates': date_tgt,
                                     'condition':grouptype,
                                     'act_animal':animal1,
                                     'gazenumber':gaze1_num,
                                     'pullnumber':pull1_num,
                                    })
                gazenum_rows.append({'dates': date_tgt,
                                     'condition':grouptype,
                                     'act_animal':animal2,
                                     'gazenumber':gaze2_num,
                                     'pullnumber':pull2_num,
                                    })

    gazenum_foreachgroup_foreachAni = pd.DataFrame(gazenum_rows, columns=gazenum_columns)

    # for plot: row 0 = pull number, row 1 = gaze number, one column per condition.
    # squeeze=False keeps axs 2-D so each condition gets its own column even when
    # ngrouptypes > 1; with a single condition axs[0, 0] is the same axis as before.
    fig, axs = plt.subplots(2,ngrouptypes,squeeze=False)
    fig.set_figheight(10)
    fig.set_figwidth(7*ngrouptypes)
    
    for igrouptype in np.arange(0,ngrouptypes,1):

        grouptype = grouptypes[igrouptype]
        ax_pull = axs[0, igrouptype]
        ax_gaze = axs[1, igrouptype]

        gazenum_foreachgroup_foreachAni_toplot = gazenum_foreachgroup_foreachAni[gazenum_foreachgroup_foreachAni['condition']==grouptype]

        seaborn.violinplot(ax=ax_pull,data=gazenum_foreachgroup_foreachAni_toplot,
                        x='act_animal',y='pullnumber')  
        ax_pull.set_title('pull number')
        
        seaborn.violinplot(ax=ax_gaze,data=gazenum_foreachgroup_foreachAni_toplot,
                        x='act_animal',y='gazenumber')  
        ax_gaze.set_title('gaze number')
        
        # perform the anova on all animals (disabled)
        if 0:
            import statsmodels.api as sm
            from statsmodels.formula.api import ols
            from statsmodels.stats.multicomp import pairwise_tukeyhsd

            # anova
            cw_lm=ols('pullnumber ~ act_animal', data=gazenum_foreachgroup_foreachAni_toplot).fit() #Specify C for Categorical
            print(sm.stats.anova_lm(cw_lm, typ=2))

            # post hoc test 
            tukey = pairwise_tukeyhsd(endog=gazenum_foreachgroup_foreachAni_toplot['pullnumber'], 
                                      groups=gazenum_foreachgroup_foreachAni_toplot['act_animal'], alpha=0.05)
            print(tukey)


            cw_lm=ols('gazenumber ~ act_animal', data=gazenum_foreachgroup_foreachAni_toplot).fit() #Specify C for Categorical
            print(sm.stats.anova_lm(cw_lm, typ=2))

            # post hoc test 
            tukey = pairwise_tukeyhsd(endog=gazenum_foreachgroup_foreachAni_toplot['gazenumber'], 
                                      groups=gazenum_foreachgroup_foreachAni_toplot['act_animal'], alpha=0.05)
            print(tukey)
    
        # independent two-sample t-tests: the same animal with its two different partners
        if 1:
            ttest_comparisons = [('Ginger with D or K', 'ginger_withD', 'ginger_withK'),
                                 ('Kanga with D or G', 'kanga_withD', 'kanga_withG')]
            for comparison_label, animal_a, animal_b in ttest_comparisons:
                for measure in ['gazenumber', 'pullnumber']:
                    data1 = gazenum_foreachgroup_foreachAni_toplot[gazenum_foreachgroup_foreachAni_toplot['act_animal']==animal_a][measure]
                    data2 = gazenum_foreachgroup_foreachAni_toplot[gazenum_foreachgroup_foreachAni_toplot['act_animal']==animal_b][measure]
                    t_stat, p_value = st.ttest_ind(data1, data2)
                    print(comparison_label+' '+measure+' '+'ttest p value = '+str(p_value))
            
    savefigs = 1
    if savefigs:
        figsavefolder = data_saved_folder+'figs_for_3LagDBN_and_bhv_singlecam_wholebodylabels_allsessions_basicEvents/'+savefile_sufix+'/'+cameraID+'/'
        if not os.path.exists(figsavefolder):
            os.makedirs(figsavefolder)

        plt.savefig(figsavefolder+"socialgazenumber_pullnumber_summary_forallAnimals.pdf")

In [None]:
fig, axs = plt.subplots(1,2)
fig.set_figheight(4*1)
fig.set_figwidth(8*2)

gazenum_toplot = gazenum_foreachgroup_foreachAni[gazenum_foreachgroup_foreachAni['condition']=='MC']

# Ginger and Kanga each worked with two partners; select each animal's rows under both partners
ind_G = gazenum_toplot['act_animal'].isin(['ginger_withD', 'ginger_withK'])
ind_K = gazenum_toplot['act_animal'].isin(['kanga_withD', 'kanga_withG'])
gazenum_toplot_G = gazenum_toplot[ind_G]
gazenum_toplot_K = gazenum_toplot[ind_K]

# left panel Ginger, right panel Kanga: grey line across days, dots coloured by partner
for ax, gazenum_subset in zip(axs, (gazenum_toplot_G, gazenum_toplot_K)):
    gazenum_toplot_sorted = gazenum_subset.sort_values(by=['dates'])
    seaborn.lineplot(ax=ax, data=gazenum_toplot_sorted,
                     x='dates', y='gazenumber', color='darkgray')
    seaborn.scatterplot(ax=ax, data=gazenum_toplot_sorted,
                        x='dates', y='gazenumber', hue='act_animal', s=150)
    ax.set_ylabel('gaze number per second')

plt.tight_layout()

savefigs = 1
if savefigs:
    figsavefolder = data_saved_folder+'figs_for_3LagDBN_and_bhv_singlecam_wholebodylabels_allsessions_basicEvents/'+savefile_sufix+'/'+cameraID+'/'
    os.makedirs(figsavefolder, exist_ok=True)
    fig.savefig(figsavefolder+"socialgazenumber_ChangeOverDays_GingerAndKanga.pdf")

In [None]:
# display the pooled per-session gaze/pull summary table (the cell's last expression is echoed)
gazenum_foreachgroup_foreachAni

### prepare the input data for DBN

In [None]:
# define DBN related summarizing variables
# (a comprehension gives every date its own placeholder; dict.fromkeys(dates_list, [])
#  would make all keys share one list object)
DBN_input_data_alltypes = {date_key: [] for date_key in dates_list}

doBhvitv_timebin = 0 # 1: if use the mean bhv event interval for time bin

prepare_input_data = 0

# DBN resolutions (make sure they are the same as in the later part of the code)
totalsess_time = 600 # total session time in s
# temp_resolus = [0.5,1,1.5,2] # temporal resolution in the DBN model, eg: 0.5 means 500ms
temp_resolus = [1] # temporal resolution in the DBN model, eg: 0.5 means 500ms
ntemp_reses = np.shape(temp_resolus)[0]

mergetempRos = 0 # 1: concatenate the DBN tables of all temporal resolutions for each session

# root folder of the behavioral JSON files written by the task code
bhv_data_root = "/gpfs/radev/pi/nandy/jadi_gibbs_data/VideoTracker_SocialInter/"


def _read_bhv_jsons(bhv_data_path, date_tgt, name_a, name_b):
    """Read one session's TrialRecord, bhv_data and session_info JSON files.

    Each file is found with the glob
    ``<bhv_data_path><date_tgt>_<name_a>_<name_b>_<kind>_*.json`` and the first match is
    read with ``pd.read_json``. Raises IndexError when no file matches.
    """
    prefix = bhv_data_path+date_tgt+"_"+name_a+"_"+name_b
    trial_record_json = glob.glob(prefix+"_TrialRecord_" + "*.json")
    bhv_data_json = glob.glob(prefix+"_bhv_data_" + "*.json")
    session_info_json = glob.glob(prefix+"_session_info_" + "*.json")
    #
    trial_record = pd.read_json(trial_record_json[0])
    bhv_data = pd.read_json(bhv_data_json[0])
    session_info = pd.read_json(session_info_json[0])
    return trial_record, bhv_data, session_info


# # train the dynamic bayesian network - Alec's model 
#   prepare the multi-session table; one time lag; multi time steps (temporal resolution) as separate files

# prepare the DBN input data
if prepare_input_data:

    # `ndates` is overwritten by the all-animal summary cell above (it is left holding the
    # session count of the last animal pair), so recompute it from the dates looped over here
    ndates = np.shape(dates_list)[0]
    # NOTE(review): that cell also reassigns animal1_filename / animal2_filename (ending on
    # the last pair); confirm they name the current pair before enabling prepare_input_data

    for idate in np.arange(0,ndates,1):
        date_tgt = dates_list[idate]
        session_start_time = session_start_times[idate]

        # load behavioral results: try the task-code folder, then the force-manipulation
        # folder, each with both animal orders in the file names; re-raise if all fail
        pair_folder = date_tgt+"_"+animal1_filename+"_"+animal2_filename+"/"
        bhv_data_candidates = [
            (bhv_data_root+"marmoset_tracking_bhv_data_from_task_code/"+pair_folder, animal2_filename, animal1_filename),
            (bhv_data_root+"marmoset_tracking_bhv_data_from_task_code/"+pair_folder, animal1_filename, animal2_filename),
            (bhv_data_root+"marmoset_tracking_bhv_data_forceManipulation_task/"+pair_folder, animal2_filename, animal1_filename),
            (bhv_data_root+"marmoset_tracking_bhv_data_forceManipulation_task/"+pair_folder, animal1_filename, animal2_filename),
        ]
        for icand, (bhv_data_path, name_a, name_b) in enumerate(bhv_data_candidates):
            try:
                trial_record, bhv_data, session_info = _read_bhv_jsons(bhv_data_path, date_tgt, name_a, name_b)
                break
            except Exception:
                if icand == len(bhv_data_candidates)-1:
                    raise
            
        # get animal info
        animal1 = session_info['lever1_animal'][0].lower()
        animal2 = session_info['lever2_animal'][0].lower()
        
        # clean up the trial_record: keep the first row of each trial number 1..max
        # (DataFrame.append was removed in pandas 2.0, so collect rows and concat once)
        warnings.filterwarnings('ignore')
        first_rows = [trial_record[trial_record['trial_number']==itrial+1].iloc[[0]]
                      for itrial in np.arange(0,np.max(trial_record['trial_number']),1)]
        if first_rows:
            trial_record_clean = pd.concat(first_rows)
        else:
            trial_record_clean = pd.DataFrame(columns=trial_record.columns)
        trial_record_clean = trial_record_clean.reset_index(drop = True)

        # change bhv_data time to the absolute time (trial-relative time + trial start time)
        time_points_new = pd.DataFrame(np.zeros(np.shape(bhv_data)[0]),columns=["time_points_new"])
        for itrial in np.arange(0,np.max(trial_record_clean['trial_number']),1):
            ind = bhv_data["trial_number"]==itrial+1
            new_time_itrial = bhv_data[ind]["time_points"] + trial_record_clean["trial_starttime"].iloc[itrial]
            # .loc avoids chained assignment, which silently does nothing under copy-on-write
            time_points_new.loc[ind, "time_points_new"] = new_time_itrial
        bhv_data["time_points"] = time_points_new["time_points_new"]
        bhv_data = bhv_data[bhv_data["time_points"] != 0]
            
        # get task type and cooperation threshold (fall back to 0 / 1 if the keys are missing)
        try:
            coop_thres = session_info["pulltime_thres"][0]
            tasktype = session_info["task_type"][0]
        except Exception:
            coop_thres = 0
            tasktype = 1

        # load behavioral event results
        print('load social gaze with '+cameraID+' only of '+date_tgt)
        with open(data_saved_folder+"bhv_events_singlecam_wholebody/"+animal1_fixedorder[0]+animal2_fixedorder[0]+"/"+cameraID+'/'+date_tgt+'/output_look_ornot.pkl', 'rb') as f:
            output_look_ornot = pickle.load(f)
        with open(data_saved_folder+"bhv_events_singlecam_wholebody/"+animal1_fixedorder[0]+animal2_fixedorder[0]+"/"+cameraID+'/'+date_tgt+'/output_allvectors.pkl', 'rb') as f:
            output_allvectors = pickle.load(f)
        with open(data_saved_folder+"bhv_events_singlecam_wholebody/"+animal1_fixedorder[0]+animal2_fixedorder[0]+"/"+cameraID+'/'+date_tgt+'/output_allangles.pkl', 'rb') as f:
            output_allangles = pickle.load(f)  
        #
        look_at_other_or_not_merge = output_look_ornot['look_at_other_or_not_merge']
        look_at_tube_or_not_merge = output_look_ornot['look_at_tube_or_not_merge']
        look_at_lever_or_not_merge = output_look_ornot['look_at_lever_or_not_merge']
        # change the unit to second
        # NOTE(review): frame counts are read from the hard-coded 'dodson' key, which presumably
        # fails for pairs without Dodson -- confirm the intended key
        session_start_time = session_start_times[idate]
        look_at_other_or_not_merge['time_in_second'] = np.arange(0,np.shape(look_at_other_or_not_merge['dodson'])[0],1)/fps - session_start_time
        look_at_lever_or_not_merge['time_in_second'] = np.arange(0,np.shape(look_at_lever_or_not_merge['dodson'])[0],1)/fps - session_start_time
        look_at_tube_or_not_merge['time_in_second'] = np.arange(0,np.shape(look_at_tube_or_not_merge['dodson'])[0],1)/fps - session_start_time 

        # redefine the totalsess_time for the length of each recording (NOT! remove the session_start_time)
        totalsess_time = int(np.ceil(np.shape(look_at_other_or_not_merge['dodson'])[0]/fps))
        
        # find time point of behavioral events
        output_time_points_socialgaze ,output_time_points_levertube = bhv_events_timepoint_singlecam(bhv_data,look_at_other_or_not_merge,look_at_lever_or_not_merge,look_at_tube_or_not_merge)
        time_point_pull1 = output_time_points_socialgaze['time_point_pull1']
        time_point_pull2 = output_time_points_socialgaze['time_point_pull2']
        oneway_gaze1 = output_time_points_socialgaze['oneway_gaze1']
        oneway_gaze2 = output_time_points_socialgaze['oneway_gaze2']
        mutual_gaze1 = output_time_points_socialgaze['mutual_gaze1']
        mutual_gaze2 = output_time_points_socialgaze['mutual_gaze2']   

        if mergetempRos:
            temp_resolus = [0.5,1,1.5,2] # temporal resolution in the DBN model, eg: 0.5 means 500ms
            # use bhv event to decide temporal resolution
            #
            #low_lim,up_lim,_ = bhv_events_interval(totalsess_time, session_start_time, time_point_pull1, time_point_pull2, oneway_gaze1, oneway_gaze2, mutual_gaze1, mutual_gaze2)
            #temp_resolus = temp_resolus = np.arange(low_lim,up_lim,0.1)
        #
        if doBhvitv_timebin:
            pull_other_intv_ii = pd.Series(bhv_intv_all_dates[date_tgt]['pull_other_pooled'])
            # remove the interval that is too large
            pull_other_intv_ii[pull_other_intv_ii>(np.nanmean(pull_other_intv_ii)+2*np.nanstd(pull_other_intv_ii))]= np.nan    
            # pull_other_intv_ii[pull_other_intv_ii>10]= np.nan
            temp_resolus = [np.nanmean(pull_other_intv_ii)]          
        #
        ntemp_reses = np.shape(temp_resolus)[0]           

        # try different temporal resolutions. bhv_df is initialised once per session so that,
        # when mergetempRos is set, the tables of all resolutions are concatenated (it used to
        # be reset inside the loop, so only the last resolution was ever kept); otherwise the
        # last resolution's table is kept, as before.
        bhv_df = []
        for temp_resolu in temp_resolus:

            # animal order in the DBN table follows animal1_fixedorder
            if np.isin(animal1,animal1_fixedorder):
                bhv_df_itr,_,_ = train_DBN_multiLag_create_df_only(totalsess_time, session_start_time, temp_resolu, time_point_pull1, time_point_pull2, oneway_gaze1, oneway_gaze2, mutual_gaze1, mutual_gaze2)
            else:
                bhv_df_itr,_,_ = train_DBN_multiLag_create_df_only(totalsess_time, session_start_time, temp_resolu, time_point_pull2, time_point_pull1, oneway_gaze2, oneway_gaze1, mutual_gaze2, mutual_gaze1)     

            if mergetempRos and len(bhv_df) > 0:
                bhv_df = pd.concat([bhv_df,bhv_df_itr])                   
                bhv_df = bhv_df.reset_index(drop=True)        
            else:
                bhv_df = bhv_df_itr

        DBN_input_data_alltypes[date_tgt] = bhv_df
            
    # save data
    if 1:
        data_saved_subfolder = data_saved_folder+'data_saved_singlecam_wholebody_allsessions'+savefile_sufix+'_3lags/'+cameraID+'/'+animal1_fixedorder[0]+animal2_fixedorder[0]+'/'
        if not os.path.exists(data_saved_subfolder):
            os.makedirs(data_saved_subfolder)
        if not mergetempRos:
            if doBhvitv_timebin:
                with open(data_saved_subfolder+'/DBN_input_data_alltypes_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'_'+str(temp_resolu)+'bhvItvTempReSo.pkl', 'wb') as f:
                    pickle.dump(DBN_input_data_alltypes, f)
            else:
                with open(data_saved_subfolder+'/DBN_input_data_alltypes_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'_'+str(temp_resolu)+'sReSo.pkl', 'wb') as f:
                    pickle.dump(DBN_input_data_alltypes, f)
        else:
            with open(data_saved_subfolder+'/DBN_input_data_alltypes_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'_mergeTempsReSo.pkl', 'wb') as f:
                pickle.dump(DBN_input_data_alltypes, f)     

In [None]:
# int(np.ceil(np.shape(look_at_other_or_not_merge['dodson'])[0]/fps-session_start_time))

#### plot the gaze distribution around pulls; the analysis is based on the all-session DBN_input_data format
#### similar plot was in "3LagDBN_and_SuccAndFailedPull_singlecam_wholebodylabels_allsessions_basicEvents" looking at the difference between successful and failed pulls
#### pool across all animals; compare self reward, 3s to 1s cooperation, and no vision

#### get the half (max - min) width for selected conditions 

In [None]:
from scipy.interpolate import splrep, sproot, splev
import matplotlib.pyplot as plt 
from scipy.optimize import curve_fit 

class MultiplePeaks(Exception):
    """Signals that a dataset has more than one peak, so its FWHM is ambiguous."""


class NoPeaksFound(Exception):
    """Signals that no proper peak was found in a dataset (e.g. flat, all-zero data)."""

def fwhm(x, y, k=3):
    """
    Determine the full width at half maximum (FWHM) of a peaked set of points x, y.

    ``y - max(y)/2`` is interpolated with a B-spline of degree ``k``
    (``scipy.interpolate.splrep``). Its roots, the half-maximum crossings,
    are then found with ``scipy.interpolate.sproot``.

    Parameters
    ----------
    x, y : array-like
        Sample positions (increasing, as ``splrep`` requires) and values.
        The dataset is assumed to contain a single peak.
    k : int, optional
        Spline degree. ``sproot`` only supports cubic splines, so this must
        be 3 (the default).

    Returns
    -------
    float
        ``abs(roots[1] - roots[0])`` when at least two half-maximum crossings
        are found, otherwise ``np.nan`` (e.g. flat or monotonic data).

    Notes
    -----
    When more than two crossings are found (several peaks), the first two are
    still used instead of raising ``MultiplePeaks``. With several peaks, this
    may not be the width of the tallest one.
    """
    y = np.asarray(y, dtype=float)
    half_max = np.max(y) / 2.0
    spline = splrep(x, y - half_max, k=k)
    roots = sproot(spline)

    if len(roots) < 2:
        # no pair of half-max crossings -> width is undefined
        return np.nan
    return abs(roots[1] - roots[0])
        
        
#
# Define the Gaussian function 
def Gauss(x, A, B): 
    y = A*np.exp(-1*B*x**2) 
    return y 

# Shifted Gaussian (peak at x = C); this is the model passed to curve_fit
def gaussian(x, A, B, C):
    """Return A * exp(-B * (x - C)**2), evaluated elementwise for numpy arrays.

    A is the peak amplitude, B controls the width (larger B -> narrower) and
    C is the position of the peak.
    """
    offset = x - C
    return A * np.exp(-B * offset ** 2)

In [None]:
if 0:
    # PLOT multiple pairs in one plot, so need to load data separately
    mergetempRos = 0 # 1: merge different time bins
    minmaxfullSampSize = 1 # 1: use the min row number and max row number, or the full row for each session
    moreSampSize = 0 # 1: use more sample size (more than just minimal row number and max row number)

    temp_resolu = 1
    dist_twin_range = 5 # number of time bins taken on each side of a pull

    animal1_fixedorders = ['dodson',       'ginger_withK', 'dannon']
    animal2_fixedorders = ['ginger_withD', 'kanga_withG',  'kanga_withD']

    animal1_filenames = ['dodson', 'ginger', 'dannon']
    animal2_filenames = ['ginger', 'kanga',  'kanga']

    nanimalpairs = np.shape(animal1_fixedorders)[0]

    grouptypes = ['MC',]
    coopthres_IDs = [ 1, ]
    ngrouptypes = np.shape(grouptypes)[0]

    gazeDist_columns = ['dates','condition','act_animal','trig_average']

    def _pull_trig_gaze_average(pulls, gazes, twin):
        """Pull-triggered gaze profile for one pair of 0/1 time series.

        Both series are zero-padded by `twin` bins on each side, so windows
        never run past the ends. For every bin where `pulls` is 1, the `gazes`
        values from -twin..+twin bins are taken and averaged over pulls. The
        result is divided by (total gazes / total pulls).
        Returns a length 2*twin+1 array, or a single all-NaN array when there
        are no pulls.
        """
        pad = np.zeros(twin)
        pulls = np.hstack([pad, pulls, pad])
        gazes = np.hstack([pad, gazes, pad])
        pull_idx = np.where(pulls == 1)[0]
        if pull_idx.size == 0:
            return np.full(2*twin+1, np.nan)
        windows = gazes[pull_idx[:, None] + np.arange(-twin, twin+1)[None, :]]
        return np.nanmean(windows, axis=0)/(np.sum(gazes)/np.sum(pulls))

    def _real_and_shuffled_trig_average(pull_values, gaze_values, twin):
        """Return (observed, shuffled) pull-triggered gaze profiles.

        Inputs are binarised with == 1. For the shuffled baseline, the pull
        and gaze series are permuted independently (pulls first, then gazes).
        """
        pulls = (np.array(pull_values)==1)*1
        gazes = (np.array(gaze_values)==1)*1
        pulls_shuffled = pulls.copy()
        np.random.shuffle(pulls_shuffled)
        gazes_shuffled = gazes.copy()
        np.random.shuffle(gazes_shuffled)
        return (_pull_trig_gaze_average(pulls, gazes, twin),
                _pull_trig_gaze_average(pulls_shuffled, gazes_shuffled, twin))

    # rows are collected in lists and turned into DataFrames at the end
    # (DataFrame.append was removed in pandas 2.0)
    SameAnimal_mean_rows = []
    AcroAnimal_mean_rows = []
    # shuffle both the pull and gaze time stamp
    SameAnimal_shuffle_rows = []
    AcroAnimal_shuffle_rows = []

    for igrouptype in np.arange(0,ngrouptypes,1):

        grouptype = grouptypes[igrouptype]
        coopthres_ID = coopthres_IDs[igrouptype]

        for ianimalpair in np.arange(0,nanimalpairs,1):
            animal1 = animal1_fixedorders[ianimalpair]
            animal2 = animal2_fixedorders[ianimalpair]
            animal1_filename = animal1_filenames[ianimalpair]
            animal2_filename = animal2_filenames[ianimalpair]

            # load the basic behavioral measures
            data_saved_subfolder = data_saved_folder+'data_saved_singlecam_wholebody'+savefile_sufix+'/'+cameraID+'/'+animal1_filename+animal2_filename+'/'
            with open(data_saved_subfolder+'/tasktypes_all_dates_'+animal1_filename+animal2_filename+'.pkl', 'rb') as f:
                tasktypes_all_dates = pickle.load(f)
            with open(data_saved_subfolder+'/coopthres_all_dates_'+animal1_filename+animal2_filename+'.pkl', 'rb') as f:
                coopthres_all_dates = pickle.load(f)

            # load the DBN input data
            data_saved_subfolder = data_saved_folder+'data_saved_singlecam_wholebody_allsessions'+savefile_sufix+'_3lags/'+cameraID+'/'+animal1_filename+animal2_filename+'/'
            if not mergetempRos:
                with open(data_saved_subfolder+'/DBN_input_data_alltypes_'+animal1_filename+animal2_filename+'_'+str(temp_resolu)+'sReSo.pkl', 'rb') as f:
                    DBN_input_data_alltypes = pickle.load(f)
            else:
                with open(data_saved_subfolder+'/DBN_input_data_alltypes_'+animal1_filename+animal2_filename+'_mergeTempsReSo.pkl', 'rb') as f:
                    DBN_input_data_alltypes = pickle.load(f)

            # re-organize the target dates
            # 100: self; 3: 3s coop; 2: 2s coop; 1.5: 1.5s coop; 1: 1s coop; -1: no-vision
            tasktypes_all_dates[tasktypes_all_dates==5] = -1 # change the task type code for no-vision
            coopthres_forsort = (tasktypes_all_dates-1)*coopthres_all_dates/2
            coopthres_forsort[coopthres_forsort==0] = 100 # get the cooperation threshold for sorting

            # sort the data based on task type and dates
            # NOTE(review): assumes the dict's key order matches the order of tasktypes/coopthres arrays
            dates_list = list(DBN_input_data_alltypes.keys())
            sorting_df = pd.DataFrame({'dates': dates_list, 'coopthres': coopthres_forsort.ravel()}, columns=['dates', 'coopthres'])
            sorting_df = sorting_df.sort_values(by=['coopthres','dates'], ascending = [False, True])
            # only select the targeted dates
            sorting_tgt_df = sorting_df[(sorting_df['coopthres']==coopthres_ID)]
            dates_list_tgt = np.array(sorting_tgt_df['dates'])

            for idate_name in dates_list_tgt:
                DBN_input_data_idate = DBN_input_data_alltypes[idate_name]

                # (pull column, gaze column, animal label, observed rows, shuffled rows)
                # within animal: an animal's pulls vs. its own gaze
                # across animal: one animal's pulls vs. the partner's gaze,
                #                labelled with the animal whose gaze column is used
                pairings = [
                    ('pull1_t0', 'owgaze1_t0', animal1, SameAnimal_mean_rows, SameAnimal_shuffle_rows),
                    ('pull2_t0', 'owgaze2_t0', animal2, SameAnimal_mean_rows, SameAnimal_shuffle_rows),
                    ('pull1_t0', 'owgaze2_t0', animal2, AcroAnimal_mean_rows, AcroAnimal_shuffle_rows),
                    ('pull2_t0', 'owgaze1_t0', animal1, AcroAnimal_mean_rows, AcroAnimal_shuffle_rows),
                ]
                for pull_col, gaze_col, act_animal, mean_rows, shuffle_rows in pairings:
                    trig_mean, trig_shuffle = _real_and_shuffled_trig_average(
                        DBN_input_data_idate[pull_col], DBN_input_data_idate[gaze_col], dist_twin_range)
                    # exactly one row per session/pairing (NaN profile when there are no pulls)
                    mean_rows.append({'dates': idate_name,
                                      'condition': grouptype,
                                      'act_animal': act_animal,
                                      'trig_average': trig_mean})
                    shuffle_rows.append({'dates': idate_name,
                                         'condition': grouptype,
                                         'act_animal': act_animal,
                                         'trig_average': trig_shuffle})

    SameAnimal_gazeDist_mean_forEachAni = pd.DataFrame(SameAnimal_mean_rows, columns=gazeDist_columns)
    AcroAnimal_gazeDist_mean_forEachAni = pd.DataFrame(AcroAnimal_mean_rows, columns=gazeDist_columns)
    SameAnimal_gazeDist_shuffle_forEachAni = pd.DataFrame(SameAnimal_shuffle_rows, columns=gazeDist_columns)
    AcroAnimal_gazeDist_shuffle_forEachAni = pd.DataFrame(AcroAnimal_shuffle_rows, columns=gazeDist_columns)

    # plot
    if 1:

        xxx = np.arange(-dist_twin_range,dist_twin_range+1,1)

        def _mean_and_sem(df, act_animal):
            """Across-session nan-mean and SEM of 'trig_average' for one animal.

            The SEM uses the number of non-NaN sessions per time bin, matching
            what nanstd actually averages over.
            """
            stacked = np.vstack(list(df[df['act_animal']==act_animal]['trig_average']))
            nvalid = np.sum(~np.isnan(stacked), axis=0)
            return np.nanmean(stacked, axis=0), np.nanstd(stacked, axis=0)/np.sqrt(nvalid)

        fig, axs = plt.subplots(ngrouptypes,2)
        fig.set_figheight(5*ngrouptypes)
        fig.set_figwidth(5*2)

        plot_titles = ['within animal: all animals', 'across animal: all animals']

        for iplottype in np.arange(0,2,1):

            if iplottype == 0:
                mean_df = SameAnimal_gazeDist_mean_forEachAni
                shuffle_df = SameAnimal_gazeDist_shuffle_forEachAni
            else:
                mean_df = AcroAnimal_gazeDist_mean_forEachAni
                shuffle_df = AcroAnimal_gazeDist_shuffle_forEachAni

            for igrouptype in np.arange(0,ngrouptypes,1):

                grouptype = grouptypes[igrouptype]
                mean_toplot = mean_df[mean_df['condition']==grouptype]
                shuffle_toplot = shuffle_df[shuffle_df['condition']==grouptype]

                # plt.subplots returns a 1-D axes array when there is a single row
                ax = axs[igrouptype,iplottype] if ngrouptypes > 1 else axs[iplottype]

                # plot, all animals in one figure
                conds_forplot = ['ginger_withK','ginger_withD','kanga_withG','kanga_withD','dodson','dannon']

                gazeDist_average_forplot = {}
                gazeDist_std_forplot = {}
                gazeDist_average_shf_forplot = {}
                gazeDist_std_shf_forplot = {}
                for cond_forplot in conds_forplot:
                    gazeDist_average_forplot[cond_forplot], gazeDist_std_forplot[cond_forplot] = _mean_and_sem(mean_toplot, cond_forplot)
                    gazeDist_average_shf_forplot[cond_forplot], gazeDist_std_shf_forplot[cond_forplot] = _mean_and_sem(shuffle_toplot, cond_forplot)

                    ax.errorbar(xxx,gazeDist_average_forplot[cond_forplot],
                                gazeDist_std_forplot[cond_forplot],label=cond_forplot)
                    # ax.errorbar(xxx,gazeDist_average_shf_forplot[cond_forplot],
                    #             gazeDist_std_shf_forplot[cond_forplot],label="shuffled "+cond_forplot)

                ax.plot([0,0],[0,1],'--',color='0.5')
                ax.set_xlim(-dist_twin_range-0.75,dist_twin_range+0.75)
                ax.set_ylim(0,0.3)
                ax.set_ylabel('social gaze probability',fontsize=15)
                ax.legend()
                ax.set_title(plot_titles[iplottype],fontsize=16)

        savefigs = 1
        if savefigs:
            figsavefolder = data_saved_folder+'figs_for_3LagDBN_and_bhv_singlecam_wholebodylabels_allsessions_basicEvents/'+savefile_sufix+'/'+cameraID+'/'
            if not os.path.exists(figsavefolder):
                os.makedirs(figsavefolder)

            fig.savefig(figsavefolder+"socialgaze_distribution_summaryplot.pdf")

In [None]:
if 0:
    # Per session: min-max normalise the across-animal pull-triggered gaze
    # profile, fit a shifted Gaussian, and measure the FWHM of the fitted curve.
    x =  np.arange(-dist_twin_range,dist_twin_range+1,1)

    conditions = ['ginger_withK','ginger_withD','kanga_withG','kanga_withD','dodson','dannon']
    # conditions = ['ginger_withK','ginger_withD','kanga_withG','kanga_withD',]
    nconds = np.shape(conditions)[0]

    # condition -> array with one half-max width per session (NaN if it could not be estimated)
    halfwidth_all = dict.fromkeys(conditions)

    for condname in conditions:

        y_allsess = np.array(AcroAnimal_gazeDist_mean_forEachAni[AcroAnimal_gazeDist_mean_forEachAni['act_animal']==condname]['trig_average'])
        nsess = np.shape(y_allsess)[0]

        halfwidth_all[condname] = np.full(nsess, np.nan)

        for isess in np.arange(0,nsess,1):

            try:
                y = y_allsess[isess]
                y = (y-np.nanmin(y))/(np.nanmax(y)-np.nanmin(y))

                parameters, covariance = curve_fit(gaussian, x, y)
                fit_A, fit_B, fit_C = parameters
                fit_y = gaussian(x,fit_A,fit_B,fit_C)
                y = (fit_y-np.nanmin(fit_y))/(np.nanmax(fit_y)-np.nanmin(fit_y))

                halfwidth_all[condname][isess] = fwhm(x, y, k=3)

            except Exception:
                # best effort: sessions that cannot be fitted (e.g. all-NaN
                # profiles, a non-converging curve_fit) are left as NaN
                halfwidth_all[condname][isess] = np.nan

    # box plot
    fig, axs = plt.subplots(1,1)
    fig.set_figheight(5)
    fig.set_figwidth(5)

    # subplot 1 - all animals
    halfwidth_all_df = pd.DataFrame.from_dict(halfwidth_all,orient='index')
    halfwidth_all_df = halfwidth_all_df.transpose()
    halfwidth_all_df['type'] = 'all'
    #
    df_long=pd.concat([halfwidth_all_df])
    df_long2 = df_long.melt(id_vars=['type'], value_vars=conditions,var_name='condition', value_name='value')
    #
    # boxplot and swarmplot
    seaborn.boxplot(ax=axs,data=df_long2,x='condition',y='value',hue='type')
    # seaborn.swarmplot(ax=axs,data=df_long2,x='condition',y='value',hue='type',
    #                   alpha=.9,size= 9,dodge=True,legend=False)
    axs.set_xlabel('')
    axs.set_xticklabels(conditions)
    axs.xaxis.set_tick_params(labelsize=15,rotation=45)
    axs.set_ylabel("half max width",fontsize=15)
    axs.set_title('all animals' ,fontsize=24)
    axs.set_ylim([0,10])
    axs.legend(fontsize=18)

    savefigs = 1
    if savefigs:
        figsavefolder = data_saved_folder+'figs_for_3LagDBN_and_bhv_singlecam_wholebodylabels_allsessions_basicEvents/'+savefile_sufix+'/'+cameraID+'/'
        if not os.path.exists(figsavefolder):
            os.makedirs(figsavefolder)

        # save this figure explicitly rather than whatever pyplot considers "current"
        fig.savefig(figsavefolder+"socialgaze_distribution_summaryplot_halfmaxWitdh.pdf")



In [None]:
if 0:  # disabled: one-way ANOVA + Tukey HSD on the half-max widths per condition
    # drop the sessions whose half-max width could not be estimated (NaN)
    valid_rows = ~np.isnan(df_long2.value)
    df_long2 = df_long2[valid_rows]

    # one-way ANOVA of 'value' across 'condition' (condition could be wrapped as C(condition))
    anova_fit = ols('value ~ condition', data=df_long2).fit()
    print(sm.stats.anova_lm(anova_fit, typ=2))

    # post hoc pairwise comparison between conditions
    tukey = pairwise_tukeyhsd(endog=df_long2['value'],
                              groups=df_long2['condition'],
                              alpha=0.05)
    print(tukey)

### run the DBN model on the combined session data set

#### a test run

In [None]:
# run DBN on the large table with merged sessions
#
# Disabled ("if 0") quick test run of the per-session DBN fit: a single random
# starting point and a single bootstrap, only the first sampling-size setting.
# Reads globals from earlier cells (data_saved_folder, savefile_sufix, cameraID,
# animal1_fixedorder, animal2_fixedorder, doBhvitv_timebin, dates_list) and the
# helpers AicScore, EfficientShuffle, train_DBN_multiLag_training_only,
# get_weighted_dags, get_significant_edges.

mergetempRos = 0 # 1: merge different time bins

minmaxfullSampSize = 1 # 1: use the  min row number and max row number, or the full row for each session

moreSampSize = 0 # 1: use more sample size (more than just minimal row number and max row number)

num_starting_points = 1 # number of random starting points/graphs
nbootstraps = 1

if 0:

    if moreSampSize:
        # different data (down/re)sampling numbers
        samplingsizes = np.arange(1100,3000,100)
        # samplingsizes = [100,500,1000,1500,2000,2500,3000]        
        # samplingsizes = [100,500]
        # samplingsizes_name = ['100','500','1000','1500','2000','2500','3000']
        samplingsizes_name = list(map(str, samplingsizes))
        nsamplings = np.shape(samplingsizes)[0]

    # result containers, keyed by (str(temp_resolu), sampling-size name), then by date
    weighted_graphs_diffTempRo_diffSampSize = {}
    weighted_graphs_shuffled_diffTempRo_diffSampSize = {}
    sig_edges_diffTempRo_diffSampSize = {}
    DAGscores_diffTempRo_diffSampSize = {}
    DAGscores_shuffled_diffTempRo_diffSampSize = {}

    totalsess_time = 600 # total session time in s
    # temp_resolus = [0.5,1,1.5,2] # temporal resolution in the DBN model, eg: 0.5 means 500ms
    temp_resolus = [1] # temporal resolution in the DBN model, eg: 0.5 means 500ms
    ntemp_reses = np.shape(temp_resolus)[0]

    # try different temporal resolutions, remember to use the same settings as in the previous ones
    for temp_resolu in temp_resolus:

        # load the per-session DBN input tables (a dict indexed by session date, as used below)
        data_saved_subfolder = data_saved_folder+'data_saved_singlecam_wholebody_allsessions'+savefile_sufix+'_3lags/'+cameraID+'/'+animal1_fixedorder[0]+animal2_fixedorder[0]+'/'
        if not mergetempRos:
            if doBhvitv_timebin:
                with open(data_saved_subfolder+'/DBN_input_data_alltypes_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'_'+str(temp_resolu)+'bhvItvTempReSo.pkl', 'rb') as f:
                    DBN_input_data_alltypes = pickle.load(f)
            else:
                with open(data_saved_subfolder+'/DBN_input_data_alltypes_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'_'+str(temp_resolu)+'sReSo.pkl', 'rb') as f:
                    DBN_input_data_alltypes = pickle.load(f)
        else:
            with open(data_saved_subfolder+'/DBN_input_data_alltypes_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'_mergeTempsReSo.pkl', 'rb') as f:
                DBN_input_data_alltypes = pickle.load(f)

                
        # only try three sample sizes
        #- minimal row number (require data downsample) and maximal row number (require data upsample)
        #- full row number of each session
        # (currently only the full row number is used; min/max are computed but unused)
        if minmaxfullSampSize:
            key_to_value_lengths = {k:len(v) for k, v in DBN_input_data_alltypes.items()}
            key_to_value_lengths_array = np.fromiter(key_to_value_lengths.values(),dtype=float)
            key_to_value_lengths_array[key_to_value_lengths_array==0]=np.nan
            min_samplesize = np.nanmin(key_to_value_lengths_array)
            min_samplesize = int(min_samplesize/100)*100
            max_samplesize = np.nanmax(key_to_value_lengths_array)
            max_samplesize = int(max_samplesize/100)*100
            #samplingsizes = [min_samplesize,max_samplesize,np.nan]
            #samplingsizes_name = ['min_row_number','max_row_number','full_row_number']
            samplingsizes = [np.nan]
            samplingsizes_name = ['full_row_number']
            nsamplings = np.shape(samplingsizes)[0]
            print(samplingsizes)
                
        # try different down/re-sampling size
        # NOTE: this test run only uses the first sampling-size setting (jj = 0)
        # for jj in np.arange(0,nsamplings,1):
        for jj in np.arange(0,1,1):
            
            isamplingsize = samplingsizes[jj]
            
            # values are always reassigned (never mutated), so sharing the [] default is harmless
            DAGs_alltypes = dict.fromkeys(dates_list, [])
            DAGs_shuffle_alltypes = dict.fromkeys(dates_list, [])
            DAGs_scores_alltypes = dict.fromkeys(dates_list, [])
            DAGs_shuffle_scores_alltypes = dict.fromkeys(dates_list, [])

            weighted_graphs_alltypes = dict.fromkeys(dates_list, [])
            weighted_graphs_shuffled_alltypes = dict.fromkeys(dates_list, [])
            sig_edges_alltypes = dict.fromkeys(dates_list, [])

            # different individual sessions
            ndates = np.shape(dates_list)[0]
            for idate in np.arange(0,ndates,1):
                date_tgt = dates_list[idate]
                
                # 'full_row_number': resample as many rows as the session has
                if samplingsizes_name[jj]=='full_row_number':
                    isamplingsize = np.shape(DBN_input_data_alltypes[date_tgt])[0]

                # NOTE(review): the bare except below turns ANY failure for this
                # session (missing date, fitting error, typo) into empty results
                try:
                    bhv_df_all = DBN_input_data_alltypes[date_tgt]

                    # define DBN graph structures; make sure they are the same as in the train_DBN_multiLag
                    colnames = list(bhv_df_all.columns)
                    eventnames = ["pull1","pull2","owgaze1","owgaze2"]
                    nevents = np.size(eventnames)

                    # columns ending in 't3' are the target (latest) time slice; all
                    # other columns are candidate parents, so edges only go from -> to
                    all_pops = list(bhv_df_all.columns)
                    from_pops = [pop for pop in all_pops if not pop.endswith('t3')]
                    to_pops = [pop for pop in all_pops if pop.endswith('t3')]
                    causal_whitelist = [(from_pop,to_pop) for from_pop in from_pops for to_pop in to_pops]

                    nFromNodes = np.shape(from_pops)[0]
                    nToNodes = np.shape(to_pops)[0]

                    DAGs_randstart = np.zeros((num_starting_points, nFromNodes, nToNodes))
                    DAGs_randstart_shuffle = np.zeros((num_starting_points, nFromNodes, nToNodes))
                    score_randstart = np.zeros((num_starting_points))
                    score_randstart_shuffle = np.zeros((num_starting_points))

                    # step 1: randomize the starting point for num_starting_points times
                    for istarting_points in np.arange(0,num_starting_points,1):

                        # try different down/re-sampling size
                        # (bootstrap resample with replacement, seeded by the starting-point index)
                        bhv_df = bhv_df_all.sample(isamplingsize,replace = True, random_state = istarting_points) # take the subset for DBN training
                        aic = AicScore(bhv_df)

                        #Anirban(Alec) shuffle, slow
                        # NOTE(review): seeded from the wall clock -> shuffled results are not reproducible
                        bhv_df_shuffle, df_shufflekeys = EfficientShuffle(bhv_df,round(time()))
                        aic_shuffle = AicScore(bhv_df_shuffle)

                        # random starting graph: a random subset of the allowed from->to edges
                        np.random.seed(istarting_points)
                        random.seed(istarting_points)
                        starting_edges = random.sample(causal_whitelist, np.random.randint(1,len(causal_whitelist)))
                        starting_graph = DAG()
                        starting_graph.add_nodes_from(nodes=all_pops)
                        starting_graph.add_edges_from(ebunch=starting_edges)

                        best_model,edges,DAGs = train_DBN_multiLag_training_only(bhv_df,starting_graph,colnames,eventnames,from_pops,to_pops)           
                        DAGs[0][np.isnan(DAGs[0])]=0

                        DAGs_randstart[istarting_points,:,:] = DAGs[0]
                        score_randstart[istarting_points] = aic.score(best_model)

                        # step 2: add the shuffled data results
                        # shuffled bhv_df
                        best_model,edges,DAGs = train_DBN_multiLag_training_only(bhv_df_shuffle,starting_graph,colnames,eventnames,from_pops,to_pops)           
                        DAGs[0][np.isnan(DAGs[0])]=0

                        DAGs_randstart_shuffle[istarting_points,:,:] = DAGs[0]
                        score_randstart_shuffle[istarting_points] = aic_shuffle.score(best_model)

                    DAGs_alltypes[date_tgt] = DAGs_randstart 
                    DAGs_shuffle_alltypes[date_tgt] = DAGs_randstart_shuffle

                    DAGs_scores_alltypes[date_tgt] = score_randstart
                    DAGs_shuffle_scores_alltypes[date_tgt] = score_randstart_shuffle

                    # summarise the starting-point DAGs and compare real vs shuffled
                    weighted_graphs = get_weighted_dags(DAGs_alltypes[date_tgt],nbootstraps)
                    weighted_graphs_shuffled = get_weighted_dags(DAGs_shuffle_alltypes[date_tgt],nbootstraps)
                    sig_edges = get_significant_edges(weighted_graphs,weighted_graphs_shuffled)

                    weighted_graphs_alltypes[date_tgt] = weighted_graphs
                    weighted_graphs_shuffled_alltypes[date_tgt] = weighted_graphs_shuffled
                    sig_edges_alltypes[date_tgt] = sig_edges
                    
                except:
                    DAGs_alltypes[date_tgt] = [] 
                    DAGs_shuffle_alltypes[date_tgt] = []

                    DAGs_scores_alltypes[date_tgt] = []
                    DAGs_shuffle_scores_alltypes[date_tgt] = []

                    weighted_graphs_alltypes[date_tgt] = []
                    weighted_graphs_shuffled_alltypes[date_tgt] = []
                    sig_edges_alltypes[date_tgt] = []
                
            DAGscores_diffTempRo_diffSampSize[(str(temp_resolu),samplingsizes_name[jj])] = DAGs_scores_alltypes
            DAGscores_shuffled_diffTempRo_diffSampSize[(str(temp_resolu),samplingsizes_name[jj])] = DAGs_shuffle_scores_alltypes

            weighted_graphs_diffTempRo_diffSampSize[(str(temp_resolu),samplingsizes_name[jj])] = weighted_graphs_alltypes
            weighted_graphs_shuffled_diffTempRo_diffSampSize[(str(temp_resolu),samplingsizes_name[jj])] = weighted_graphs_shuffled_alltypes
            sig_edges_diffTempRo_diffSampSize[(str(temp_resolu),samplingsizes_name[jj])] = sig_edges_alltypes

    print(weighted_graphs_diffTempRo_diffSampSize)
            
   

#### run on the entire population

In [None]:
# run DBN on the large table with merged sessions
#
# For every session in dates_list, fit the 3-lag DBN from num_starting_points
# random starting graphs (each on a resample of that session's rows), fit the
# same on shuffled data, and summarise the fits into weighted graphs and
# significant edges. If results of an earlier run are found on disk they are
# loaded instead of being recomputed.
#
# Reads globals from earlier cells: data_saved_folder, savefile_sufix, cameraID,
# animal1_fixedorder, animal2_fixedorder, doBhvitv_timebin, dates_list, and the
# helpers AicScore, EfficientShuffle, train_DBN_multiLag_training_only,
# get_weighted_dags, get_significant_edges, DAG.
# Produces the dicts *_diffTempRo_diffSampSize, keyed by
# (str(temp_resolu), sampling-size name) and then by session date.

mergetempRos = 0 # 1: merge different time bins

minmaxfullSampSize = 1 # 1: use the  min row number and max row number, or the full row for each session

moreSampSize = 0 # 1: use more sample size (more than just minimal row number and max row number)

num_starting_points = 100 # number of random starting points/graphs
nbootstraps = 95

pair_name = animal1_fixedorder[0]+animal2_fixedorder[0]
data_saved_subfolder = data_saved_folder+'data_saved_singlecam_wholebody_allsessions'+savefile_sufix+'_3lags/'+cameraID+'/'+pair_name+'/'
if not os.path.exists(data_saved_subfolder):
    os.makedirs(data_saved_subfolder)

# The result-file suffix follows the same precedence as the save step at the
# end of this cell (moreSampSize, then minmaxfullSampSize, else no suffix), so a
# run always reloads exactly the files an earlier run wrote.
if moreSampSize:
    results_file_suffix = '_moreSampSize'
elif minmaxfullSampSize:
    results_file_suffix = '_minmaxfullSampSize'
else:
    results_file_suffix = ''

# each result dict is stored as <name>_<pair><suffix>.pkl
DBN_result_names = ['DAGscores_diffTempRo_diffSampSize',
                    'DAGscores_shuffled_diffTempRo_diffSampSize',
                    'weighted_graphs_diffTempRo_diffSampSize',
                    'weighted_graphs_shuffled_diffTempRo_diffSampSize',
                    'sig_edges_diffTempRo_diffSampSize']

# try to load previously saved results (all of them, or none)
# NOTE: pickle files are trusted local outputs of this notebook
try:
    DBN_results_loaded = {}
    for result_name in DBN_result_names:
        with open(data_saved_subfolder+'/'+result_name+'_'+pair_name+results_file_suffix+'.pkl', 'rb') as f:
            DBN_results_loaded[result_name] = pickle.load(f)
except (OSError, EOFError, pickle.UnpicklingError):
    # missing or unreadable saved results -> recompute below
    DBN_results_loaded = None

if DBN_results_loaded is not None:
    DAGscores_diffTempRo_diffSampSize = DBN_results_loaded['DAGscores_diffTempRo_diffSampSize']
    DAGscores_shuffled_diffTempRo_diffSampSize = DBN_results_loaded['DAGscores_shuffled_diffTempRo_diffSampSize']
    weighted_graphs_diffTempRo_diffSampSize = DBN_results_loaded['weighted_graphs_diffTempRo_diffSampSize']
    weighted_graphs_shuffled_diffTempRo_diffSampSize = DBN_results_loaded['weighted_graphs_shuffled_diffTempRo_diffSampSize']
    sig_edges_diffTempRo_diffSampSize = DBN_results_loaded['sig_edges_diffTempRo_diffSampSize']

else:
    if moreSampSize:
        # different data (down/re)sampling numbers
        samplingsizes = np.arange(1100,3000,100)
        samplingsizes_name = list(map(str, samplingsizes))
        nsamplings = np.shape(samplingsizes)[0]
    # NOTE(review): if neither moreSampSize nor minmaxfullSampSize is set,
    # samplingsizes / nsamplings are not defined by this cell

    weighted_graphs_diffTempRo_diffSampSize = {}
    weighted_graphs_shuffled_diffTempRo_diffSampSize = {}
    sig_edges_diffTempRo_diffSampSize = {}
    DAGscores_diffTempRo_diffSampSize = {}
    DAGscores_shuffled_diffTempRo_diffSampSize = {}

    totalsess_time = 600 # total session time in s
    # temp_resolus = [0.5,1,1.5,2] # temporal resolution in the DBN model, eg: 0.5 means 500ms
    temp_resolus = [1] # temporal resolution in the DBN model, eg: 0.5 means 500ms
    ntemp_reses = np.shape(temp_resolus)[0]

    # try different temporal resolutions, remember to use the same settings as in the previous ones
    for temp_resolu in temp_resolus:

        # load the per-session DBN input tables; always into DBN_input_data_alltypes,
        # which is what the session loop below reads
        if not mergetempRos:
            if doBhvitv_timebin:
                DBN_input_file = data_saved_subfolder+'/DBN_input_data_alltypes_'+pair_name+'_'+str(temp_resolu)+'bhvItvTempReSo.pkl'
            else:
                DBN_input_file = data_saved_subfolder+'/DBN_input_data_alltypes_'+pair_name+'_'+str(temp_resolu)+'sReSo.pkl'
        else:
            DBN_input_file = data_saved_subfolder+'/DBN_input_data_alltypes_'+pair_name+'_mergeTempsReSo.pkl'
        with open(DBN_input_file, 'rb') as f:
            DBN_input_data_alltypes = pickle.load(f)
        # alias kept for any later cell that refers to the old name
        DBN_input_data_allsessions = DBN_input_data_alltypes

        # only try three sample sizes
        #- minimal row number (require data downsample) and maximal row number (require data upsample)
        #- full row number of each session
        # (currently only the full row number is used; min/max are computed but unused)
        if minmaxfullSampSize:
            key_to_value_lengths = {k:len(v) for k, v in DBN_input_data_alltypes.items()}
            key_to_value_lengths_array = np.fromiter(key_to_value_lengths.values(),dtype=float)
            key_to_value_lengths_array[key_to_value_lengths_array==0]=np.nan
            min_samplesize = np.nanmin(key_to_value_lengths_array)
            min_samplesize = int(min_samplesize/100)*100
            max_samplesize = np.nanmax(key_to_value_lengths_array)
            max_samplesize = int(max_samplesize/100)*100
            samplingsizes = [np.nan]
            samplingsizes_name = ['full_row_number']
            nsamplings = np.shape(samplingsizes)[0]
            print(samplingsizes)

        # try different down/re-sampling size
        for jj in np.arange(0,nsamplings,1):

            isamplingsize = samplingsizes[jj]

            # values are always reassigned (never mutated), so sharing the [] default is harmless
            DAGs_alltypes = dict.fromkeys(dates_list, [])
            DAGs_shuffle_alltypes = dict.fromkeys(dates_list, [])
            DAGs_scores_alltypes = dict.fromkeys(dates_list, [])
            DAGs_shuffle_scores_alltypes = dict.fromkeys(dates_list, [])

            weighted_graphs_alltypes = dict.fromkeys(dates_list, [])
            weighted_graphs_shuffled_alltypes = dict.fromkeys(dates_list, [])
            sig_edges_alltypes = dict.fromkeys(dates_list, [])

            # different individual sessions; an error in any session stops the run
            ndates = np.shape(dates_list)[0]
            for date_tgt in dates_list:

                bhv_df_all = DBN_input_data_alltypes[date_tgt]

                # 'full_row_number': resample as many rows as the session has
                if samplingsizes_name[jj]=='full_row_number':
                    isamplingsize = np.shape(bhv_df_all)[0]

                # define DBN graph structures; make sure they are the same as in the train_DBN_multiLag
                colnames = list(bhv_df_all.columns)
                eventnames = ["pull1","pull2","owgaze1","owgaze2"]
                nevents = np.size(eventnames)

                # columns ending in 't3' are the target time slice; every other
                # column is a candidate parent, so edges only go from_pops -> to_pops
                all_pops = list(bhv_df_all.columns)
                from_pops = [pop for pop in all_pops if not pop.endswith('t3')]
                to_pops = [pop for pop in all_pops if pop.endswith('t3')]
                causal_whitelist = [(from_pop,to_pop) for from_pop in from_pops for to_pop in to_pops]

                nFromNodes = np.shape(from_pops)[0]
                nToNodes = np.shape(to_pops)[0]

                DAGs_randstart = np.zeros((num_starting_points, nFromNodes, nToNodes))
                DAGs_randstart_shuffle = np.zeros((num_starting_points, nFromNodes, nToNodes))
                score_randstart = np.zeros((num_starting_points))
                score_randstart_shuffle = np.zeros((num_starting_points))

                # step 1: randomize the starting point for num_starting_points times
                for istarting_points in np.arange(0,num_starting_points,1):

                    # resample rows with replacement, seeded by the starting-point index
                    bhv_df = bhv_df_all.sample(isamplingsize,replace = True, random_state = istarting_points) # take the subset for DBN training
                    aic = AicScore(bhv_df)

                    #Anirban(Alec) shuffle, slow
                    # NOTE(review): seeded from the wall clock -> shuffled results are not reproducible
                    bhv_df_shuffle, df_shufflekeys = EfficientShuffle(bhv_df,round(time()))
                    aic_shuffle = AicScore(bhv_df_shuffle)

                    # random starting graph: a random subset of the allowed from->to edges
                    np.random.seed(istarting_points)
                    random.seed(istarting_points)
                    starting_edges = random.sample(causal_whitelist, np.random.randint(1,len(causal_whitelist)))
                    starting_graph = DAG()
                    starting_graph.add_nodes_from(nodes=all_pops)
                    starting_graph.add_edges_from(ebunch=starting_edges)

                    best_model,edges,DAGs = train_DBN_multiLag_training_only(bhv_df,starting_graph,colnames,eventnames,from_pops,to_pops)
                    DAGs[0][np.isnan(DAGs[0])]=0

                    DAGs_randstart[istarting_points,:,:] = DAGs[0]
                    score_randstart[istarting_points] = aic.score(best_model)

                    # step 2: add the shuffled data results
                    best_model,edges,DAGs = train_DBN_multiLag_training_only(bhv_df_shuffle,starting_graph,colnames,eventnames,from_pops,to_pops)
                    DAGs[0][np.isnan(DAGs[0])]=0

                    DAGs_randstart_shuffle[istarting_points,:,:] = DAGs[0]
                    score_randstart_shuffle[istarting_points] = aic_shuffle.score(best_model)

                DAGs_alltypes[date_tgt] = DAGs_randstart
                DAGs_shuffle_alltypes[date_tgt] = DAGs_randstart_shuffle

                DAGs_scores_alltypes[date_tgt] = score_randstart
                DAGs_shuffle_scores_alltypes[date_tgt] = score_randstart_shuffle

                # summarise the starting-point DAGs and compare real vs shuffled
                weighted_graphs = get_weighted_dags(DAGs_alltypes[date_tgt],nbootstraps)
                weighted_graphs_shuffled = get_weighted_dags(DAGs_shuffle_alltypes[date_tgt],nbootstraps)
                sig_edges = get_significant_edges(weighted_graphs,weighted_graphs_shuffled)

                weighted_graphs_alltypes[date_tgt] = weighted_graphs
                weighted_graphs_shuffled_alltypes[date_tgt] = weighted_graphs_shuffled
                sig_edges_alltypes[date_tgt] = sig_edges

            DAGscores_diffTempRo_diffSampSize[(str(temp_resolu),samplingsizes_name[jj])] = DAGs_scores_alltypes
            DAGscores_shuffled_diffTempRo_diffSampSize[(str(temp_resolu),samplingsizes_name[jj])] = DAGs_shuffle_scores_alltypes

            weighted_graphs_diffTempRo_diffSampSize[(str(temp_resolu),samplingsizes_name[jj])] = weighted_graphs_alltypes
            weighted_graphs_shuffled_diffTempRo_diffSampSize[(str(temp_resolu),samplingsizes_name[jj])] = weighted_graphs_shuffled_alltypes
            sig_edges_diffTempRo_diffSampSize[(str(temp_resolu),samplingsizes_name[jj])] = sig_edges_alltypes

    # save data (same file names the load step above looks for)
    savedata = 0
    if savedata:
        DBN_results_to_save = {
            'DAGscores_diffTempRo_diffSampSize': DAGscores_diffTempRo_diffSampSize,
            'DAGscores_shuffled_diffTempRo_diffSampSize': DAGscores_shuffled_diffTempRo_diffSampSize,
            'weighted_graphs_diffTempRo_diffSampSize': weighted_graphs_diffTempRo_diffSampSize,
            'weighted_graphs_shuffled_diffTempRo_diffSampSize': weighted_graphs_shuffled_diffTempRo_diffSampSize,
            'sig_edges_diffTempRo_diffSampSize': sig_edges_diffTempRo_diffSampSize,
        }
        for result_name in DBN_result_names:
            with open(data_saved_subfolder+'/'+result_name+'_'+pair_name+results_file_suffix+'.pkl', 'wb') as f:
                pickle.dump(DBN_results_to_save[result_name], f)


### plot the edges over time (session)
#### mean edge weights of selected edges

In [None]:
# Build a per-session sort key from task type and cooperation threshold:
# 100: self; 3: 3s coop; 2: 2s coop; 1.5: 1.5s coop; 1: 1s coop; -1: no-vision
# then order sessions by that key (descending) and by date (ascending).
tasktypes_all_dates[tasktypes_all_dates==5] = -1  # recode no-vision (5) as -1, in place
coopthres_forsort = coopthres_all_dates*(tasktypes_all_dates-1)/2
coopthres_forsort[coopthres_forsort==0] = 100  # a zero key means a self session

# sort the data based on task type and dates
sorting_df = pd.DataFrame({'dates': dates_list,
                           'coopthres': coopthres_forsort.ravel()})
sorting_df = sorting_df.sort_values(by=['coopthres', 'dates'],
                                    ascending=[False, True])
# the DataFrame index still holds each row's original position in dates_list
sorted_positions = sorting_df.index
dates_list_sorted = np.array(dates_list)[sorted_positions]
ndates_sorted = len(dates_list_sorted)

In [None]:
# display the session dates (original, unsorted order)
dates_list

In [None]:
# Plot, for each selected DBN edge, its mean weight over the bootstrapped
# weighted graphs in every session (ordered as dates_list_sorted). Dashed lines
# mark where the task condition (sorting_df['coopthres']) changes.
#
# make sure these variables are the same as in the previous steps
# temp_resolus = [0.5,1,1.5,2] # temporal resolution in the DBN model, eg: 0.5 means 500ms
temp_resolus = [1] # temporal resolution in the DBN model, eg: 0.5 means 500ms
ntemp_reses = np.shape(temp_resolus)[0]
#
if moreSampSize:
    # different data (down/re)sampling numbers
    # samplingsizes = np.arange(1100,3000,100)
    samplingsizes = [1100]
    samplingsizes_name = list(map(str, samplingsizes))
elif minmaxfullSampSize:
    samplingsizes_name = ['full_row_number']
nsamplings = np.shape(samplingsizes_name)[0]

temp_resolu = temp_resolus[0]
j_sampsize_name = samplingsizes_name[0]

# 1s time lag
# NOTE(review): the node IDs index axes 1 (from-node) and 2 (to-node) of the
# weighted graphs; they must match the from_pops / to_pops column order of the
# DBN input tables — confirm against that order.
edges_target_names = [['1slag_pull2_pull1','1slag_pull1_pull2'],
                      ['1slag_gaze1_pull1','1slag_gaze2_pull2'],
                      ['1slag_pull2_gaze1','1slag_pull1_gaze2'],]
fromNodesIDs = [[ 9, 8],
                [10,11],
                [ 9, 8],]
toNodesIDs = [[0,1],
              [0,1],
              [2,3]]

edges_target_names_flat = np.array(edges_target_names).flatten()
fromNodesIDs_flat = np.array(fromNodesIDs).flatten()
toNodesIDs_flat = np.array(toNodesIDs).flatten()
n_edges = np.shape(edges_target_names_flat)[0]

# figure initiate
ncols_subplot = 2
nrows_subplot = int(np.ceil(n_edges/ncols_subplot))
fig, axs = plt.subplots(nrows_subplot,ncols_subplot)
fig.set_figheight(5*nrows_subplot)
fig.set_figwidth(10*ncols_subplot)

# task-condition blocks along the x axis (same for every subplot):
# taskswitches_between are the x positions between two different conditions,
# taskswitches adds 0 for the first block
tasktypes = ['self','coop(3s)','coop(2s)','coop(1.5s)','coop(1s)','no-vision']
coopthres_sorted = np.array(sorting_df['coopthres'])
switch_idx = np.where(coopthres_sorted[1:]-coopthres_sorted[:-1]!=0)[0]
taskswitches_between = switch_idx+0.5
taskswitches = np.concatenate(([0],taskswitches_between))
# label each block from its own sort-key value (100 self; 3/2/1.5/1 coop; negative no-vision)
# instead of by position, so a missing condition does not shift the labels
condition_labels = dict(zip([100,3,2,1.5,1], tasktypes[:-1]))
block_start_idx = np.concatenate(([0],switch_idx+1)).astype(int)
block_start_idx = block_start_idx[block_start_idx < np.shape(coopthres_sorted)[0]]
block_labels = [tasktypes[-1] if value < 0 else condition_labels.get(value, str(value))
                for value in coopthres_sorted[block_start_idx]]

#
for i_edge in np.arange(0,n_edges,1):
    ax = axs.flatten()[i_edge]
    #
    edgeweight_mean_forplot_all_dates = np.zeros((ndates_sorted,1))
    edgeweight_shuffled_mean_forplot_all_dates = np.zeros((ndates_sorted,1))
    edgeweight_std_forplot_all_dates = np.zeros((ndates_sorted,1))
    edgeweight_shuffled_std_forplot_all_dates = np.zeros((ndates_sorted,1))

    edge_tgt_name = edges_target_names_flat[i_edge]
    fromNodesID = fromNodesIDs_flat[i_edge]
    toNodesID = toNodesIDs_flat[i_edge]

    for idate in np.arange(0,ndates_sorted,1):
        idate_name = dates_list_sorted[idate]

        weighted_graphs_tgt = weighted_graphs_diffTempRo_diffSampSize[(str(temp_resolu),j_sampsize_name)][idate_name]
        weighted_graphs_shuffled_tgt = weighted_graphs_shuffled_diffTempRo_diffSampSize[(str(temp_resolu),j_sampsize_name)][idate_name]

        # mean/std over the first axis of the weighted graphs (one value per bootstrap)
        edgeweight_mean_forplot_all_dates[idate] = np.nanmean(weighted_graphs_tgt[:,fromNodesID,toNodesID])
        edgeweight_shuffled_mean_forplot_all_dates[idate] = np.nanmean(weighted_graphs_shuffled_tgt[:,fromNodesID,toNodesID])
        edgeweight_std_forplot_all_dates[idate] = np.nanstd(weighted_graphs_tgt[:,fromNodesID,toNodesID])
        edgeweight_shuffled_std_forplot_all_dates[idate] = np.nanstd(weighted_graphs_shuffled_tgt[:,fromNodesID,toNodesID])

    # plot
    ax.plot(np.arange(0,ndates_sorted,1),edgeweight_mean_forplot_all_dates,'ko',markersize=10)
    #
    ax.set_title(edge_tgt_name,fontsize=16)
    ax.set_ylabel('mean edge weight',fontsize=13)
    ax.set_ylim([-0.1,1.1])
    ax.set_xlim([-0.5,ndates_sorted-0.5])
    #
    # only the bottom subplot of each column shows the session dates
    ax.set_xticks(np.arange(0,ndates_sorted,1))
    if i_edge >= n_edges-ncols_subplot:
        ax.set_xticklabels(dates_list_sorted, rotation=90,fontsize=10)
    else:
        ax.set_xticklabels([])
    #
    for taskswitch in taskswitches_between:
        ax.plot([taskswitch,taskswitch],[-0.1,1.1],'k--')
    for taskswitch, block_label in zip(taskswitches, block_labels):
        ax.text(taskswitch+0.25,-0.05,block_label,fontsize=10)

savefigs = 1
if savefigs:
    figsavefolder = data_saved_folder+'figs_for_3LagDBN_and_bhv_singlecam_wholebodylabels_allsessions_basicEvents/'+savefile_sufix+'/'+cameraID+'/'+animal1_fixedorder[0]+animal2_fixedorder[0]+'/'
    if not os.path.exists(figsavefolder):
        os.makedirs(figsavefolder)
    plt.savefig(figsavefolder+"edgeweight_acrossAllSessions_"+animal1_fixedorder[0]+animal2_fixedorder[0]+'.pdf')
    
    

In [None]:
# display the session dates in plotting order (by task condition, then date)
dates_list_sorted

In [None]:
# list the session dates that have weighted graphs for the 1s / full-row-number setting
weighted_graphs_diffTempRo_diffSampSize[('1','full_row_number')].keys()

#### mean edge weights of selected edges v.s. other behavioral measures
##### only the cooperation days

In [None]:
# Keep only the cooperation sessions whose sort key (coop threshold) is 1, 1.5 or 2.
# Other selections tried: [1, 1.5, 2, 3], [1, 2], [1].
target_coopthres = [1, 1.5, 2]
is_target_session = sorting_df['coopthres'].isin(target_coopthres)
sorting_tgt_df = sorting_df[is_target_session]

dates_list_tgt = np.array(sorting_tgt_df['dates'])
#
ndates_tgt = len(dates_list_tgt)

In [None]:
# notebook display: inspect the full session table (it holds at least 'dates' and 'coopthres')
sorting_df

In [None]:
# Scatter each selected 1s-lag DBN edge's per-session mean weight against a
# behavioral measure (success rate or mean gaze number) over the target
# cooperation sessions, with Spearman and linear-regression statistics, then
# save the figure as a PDF.

# make sure these variables are the same as in the previous steps
# temp_resolus = [0.5,1,1.5,2] # temporal resolution in the DBN model, eg: 0.5 means 500ms
temp_resolus = [1] # temporal resolution in the DBN model, eg: 0.5 means 500ms
ntemp_reses = np.shape(temp_resolus)[0]
#
if moreSampSize:
    # different data (down/re)sampling numbers
    # samplingsizes = np.arange(1100,3000,100)
    samplingsizes = [1100]
    samplingsizes_name = list(map(str, samplingsizes))
elif minmaxfullSampSize:
    samplingsizes_name = ['full_row_number']
else:
    # without this guard a stale samplingsizes_name from an earlier cell would be reused silently
    raise ValueError('set either moreSampSize or minmaxfullSampSize before running this cell')
nsamplings = np.shape(samplingsizes_name)[0]

temp_resolu = temp_resolus[0]
j_sampsize_name = samplingsizes_name[0]
graphs_key = (str(temp_resolu), j_sampsize_name)

# 1s time lag edges, each row is [animal1 version, animal2 version];
# the name at a position labels the (from node, to node) pair at the same position
edges_target_names = [['1slag_pull2_pull1','1slag_pull1_pull2'],
                      ['1slag_gaze1_pull1','1slag_gaze2_pull2'],
                      ['1slag_pull2_gaze1','1slag_pull1_gaze2'],]
fromNodesIDs = [[ 9, 8],
                [10,11],
                [ 9, 8],]
toNodesIDs = [[0,1],
              [0,1],
              [2,3]]

#
xplottype = 'succrate' # 'succrate', 'meangazenum'
xplotlabel = 'successful rate' # 'successful rate', 'mean gaze number'
# xplottype = 'meangazenum' # 'succrate', 'meangazenum'
# xplotlabel = 'mean gaze number' # 'successful rate', 'mean gaze number'

edges_target_flat = np.array(edges_target_names).flatten()
fromNodesIDs_flat = np.array(fromNodesIDs).flatten()
toNodesIDs_flat = np.array(toNodesIDs).flatten()
n_edges = np.shape(edges_target_flat)[0]

# behavioral measure for the x axis, one value per target session; it is the
# same for every edge panel, so it is looked up once here.
# NOTE(review): assumes sorting_df's index labels are row positions into
# succ_rate_all_dates / gazemean_num_all_dates - confirm.
if xplottype == 'succrate':
    xxx = succ_rate_all_dates[sorting_tgt_df.index]
elif xplottype == 'meangazenum':
    xxx = gazemean_num_all_dates[sorting_tgt_df.index]
else:
    # an unknown type previously left a stale xxx from an earlier run in place
    raise ValueError("xplottype must be 'succrate' or 'meangazenum', got "+repr(xplottype))

# figure initiate: two panels per row
fig, axs = plt.subplots(int(np.ceil(n_edges/2)),2)
fig.set_figheight(5*np.ceil(n_edges/2))
fig.set_figwidth(5*2)

#
for i_edge in np.arange(0,n_edges,1):
    ax = axs.flatten()[i_edge]
    edge_tgt_name = edges_target_flat[i_edge]
    fromNodesID = fromNodesIDs_flat[i_edge]
    toNodesID = toNodesIDs_flat[i_edge]

    # per-session mean/std of this edge's weight over the first axis of the stored
    # graph stack (presumably the resampling repetitions - confirm)
    edgeweight_mean_forplot_all_dates = np.zeros((ndates_tgt,1))
    edgeweight_shuffled_mean_forplot_all_dates = np.zeros((ndates_tgt,1))
    edgeweight_std_forplot_all_dates = np.zeros((ndates_tgt,1))
    edgeweight_shuffled_std_forplot_all_dates = np.zeros((ndates_tgt,1))

    for idate in np.arange(0,ndates_tgt,1):
        idate_name = dates_list_tgt[idate]

        weighted_graphs_tgt = weighted_graphs_diffTempRo_diffSampSize[graphs_key][idate_name]
        weighted_graphs_shuffled_tgt = weighted_graphs_shuffled_diffTempRo_diffSampSize[graphs_key][idate_name]

        edge_weights = weighted_graphs_tgt[:,fromNodesID,toNodesID]
        edge_weights_shuffled = weighted_graphs_shuffled_tgt[:,fromNodesID,toNodesID]

        edgeweight_mean_forplot_all_dates[idate] = np.nanmean(edge_weights)
        edgeweight_shuffled_mean_forplot_all_dates[idate] = np.nanmean(edge_weights_shuffled)
        edgeweight_std_forplot_all_dates[idate] = np.nanstd(edge_weights)
        edgeweight_shuffled_std_forplot_all_dates[idate] = np.nanstd(edge_weights_shuffled)

    # plot
    yyy = edgeweight_mean_forplot_all_dates
    #
    rr_spe,pp_spe = scipy.stats.spearmanr(xxx, yyy)
    slope, intercept, rr_reg, pp_reg, std_err = st.linregress(xxx.astype(float).T[0], yyy.astype(float).T[0])
    #
    ax.plot(xxx,yyy,'bo',markersize=8)
    ax.plot(np.array([xxx.min(),xxx.max()]),np.array([xxx.min(),xxx.max()])*slope+intercept,'k-')
    #
    ax.set_title(edge_tgt_name,fontsize=16)
    ax.set_ylabel('mean edge weight',fontsize=13)
    ax.set_ylim([-0.1,1.1])
    # only the bottom row (last two panels) gets the x label; other panels hide tick labels
    if i_edge > int(n_edges-3):
        ax.set_xlabel(xplotlabel,fontsize=13)
    else:
        ax.set_xticklabels('')
    # p-values use 3 significant digits: '{:.2f}' showed highly significant results as 0.00
    ax.text(xxx.min(),1.0,'spearman r='+"{:.2f}".format(rr_spe),fontsize=10)
    ax.text(xxx.min(),0.9,'spearman p='+"{:.3g}".format(pp_spe),fontsize=10)
    ax.text(xxx.min(),0.8,'regression r='+"{:.2f}".format(rr_reg),fontsize=10)
    ax.text(xxx.min(),0.7,'regression p='+"{:.3g}".format(pp_reg),fontsize=10)


savefigs = 1
if savefigs:
    figsavefolder = data_saved_folder+'figs_for_3LagDBN_and_bhv_singlecam_wholebodylabels_allsessions_basicEvents/'+savefile_sufix+'/'+cameraID+'/'+animal1_fixedorder[0]+animal2_fixedorder[0]+'/'
    if not os.path.exists(figsavefolder):
        os.makedirs(figsavefolder)
    # save the figure created above explicitly rather than whatever is current
    fig.savefig(figsavefolder+"edgeweights_vs_"+xplottype+"_"+animal1_fixedorder[0]+animal2_fixedorder[0]+'.pdf')
    
    

In [None]:
# Scratch check: each pair appears to put a 1s-lag from-node next to its
# 2s-lag counterpart (e.g. 9 and 5). Display the first pair.
fromNodesIDs = [
    [9, 5], [8, 4],
    [10, 6], [11, 7],
    [9, 5], [8, 4],
]
np.asarray(fromNodesIDs)[0]

## Plots that include all pairs
### Modulation index for Ginger: with Kanga minus with Dodson

In [None]:
# Load the saved DBN weighted graphs for every animal pair, collapse them into
# six dependency types (pull-pull, gaze-gaze, within/across gaze->pull,
# within/across pull->gaze) per acting animal and session, then boxplot the
# per-session mean weights for Ginger and Kanga with their different partners.

# PLOT multiple pairs in one plot, so need to load data separately
moreSampSize = 0
mergetempRos = 0 # 1: merge different time bins

temp_resolu = 1
j_sampsize_name = 'full_row_number'
if moreSampSize:
    j_sampsize_name = '3400'
condname = 'coop(1s)'

timelag = 0 # 1 or 2 or 3 or 0(merged - merge all three lags) or 12 (merged lag 1 and 2)
# timelagname = '1second' # '1/2/3second' or 'merged' or '12merged'
timelagname = 'merged' # together with timelag = 0
# timelagname = '12merged' # together with timelag = 12

# Node layout as used by the tables below: from-nodes 0-3 form the 3s-lag slice,
# 4-7 the 2s-lag slice and 8-11 the 1s-lag slice, each ordered pull1, pull2,
# gaze1, gaze2; to-nodes 0-3 are pull1, pull2, gaze1, gaze2 (naming as in the
# 1s-lag edges_target_names list earlier in this notebook).
# Each *_fromNodes_all / *_toNodes_all is [entry for animal1, entry for animal2];
# timelag 0 / 12 use lists of nodes so several lags are pooled together.
if timelag == 1:
    pull_pull_fromNodes_all = [9,8]
    pull_pull_toNodes_all = [0,1]
    #
    gaze_gaze_fromNodes_all = [11,10]
    gaze_gaze_toNodes_all = [2,3]
    #
    within_pullgaze_fromNodes_all = [8,9]
    within_pullgaze_toNodes_all = [2,3]
    #
    across_pullgaze_fromNodes_all = [9,8]
    across_pullgaze_toNodes_all = [2,3]
    #
    within_gazepull_fromNodes_all = [10,11]
    within_gazepull_toNodes_all = [0,1]
    #
    across_gazepull_fromNodes_all = [11,10]
    across_gazepull_toNodes_all = [0,1]
    #
elif timelag == 2:
    pull_pull_fromNodes_all = [5,4]
    pull_pull_toNodes_all = [0,1]
    #
    gaze_gaze_fromNodes_all = [7,6]
    gaze_gaze_toNodes_all = [2,3]
    #
    within_pullgaze_fromNodes_all = [4,5]
    within_pullgaze_toNodes_all = [2,3]
    #
    across_pullgaze_fromNodes_all = [5,4]
    across_pullgaze_toNodes_all = [2,3]
    #
    within_gazepull_fromNodes_all = [6,7]
    within_gazepull_toNodes_all = [0,1]
    #
    across_gazepull_fromNodes_all = [7,6]
    across_gazepull_toNodes_all = [0,1]
    #
elif timelag == 3:
    pull_pull_fromNodes_all = [1,0]
    pull_pull_toNodes_all = [0,1]
    #
    gaze_gaze_fromNodes_all = [3,2]
    gaze_gaze_toNodes_all = [2,3]
    #
    within_pullgaze_fromNodes_all = [0,1]
    within_pullgaze_toNodes_all = [2,3]
    #
    across_pullgaze_fromNodes_all = [1,0]
    across_pullgaze_toNodes_all = [2,3]
    #
    within_gazepull_fromNodes_all = [2,3]
    within_gazepull_toNodes_all = [0,1]
    #
    across_gazepull_fromNodes_all = [3,2]
    across_gazepull_toNodes_all = [0,1]
    #
elif timelag == 0:
    pull_pull_fromNodes_all = [[1,5,9],[0,4,8]]
    pull_pull_toNodes_all = [[0,0,0],[1,1,1]]
    #
    gaze_gaze_fromNodes_all = [[3,7,11],[2,6,10]]
    gaze_gaze_toNodes_all = [[2,2,2],[3,3,3]]
    #
    within_pullgaze_fromNodes_all = [[0,4,8],[1,5,9]]
    within_pullgaze_toNodes_all = [[2,2,2],[3,3,3]]
    #
    across_pullgaze_fromNodes_all = [[1,5,9],[0,4,8]]
    across_pullgaze_toNodes_all = [[2,2,2],[3,3,3]]
    #
    within_gazepull_fromNodes_all = [[2,6,10],[3,7,11]]
    within_gazepull_toNodes_all = [[0,0,0],[1,1,1]]
    #
    across_gazepull_fromNodes_all = [[3,7,11],[2,6,10]]
    across_gazepull_toNodes_all = [[0,0,0],[1,1,1]]
    #
elif timelag == 12:
    pull_pull_fromNodes_all = [[5,9],[4,8]]
    pull_pull_toNodes_all = [[0,0],[1,1]]
    #
    gaze_gaze_fromNodes_all = [[7,11],[6,10]]
    gaze_gaze_toNodes_all = [[2,2],[3,3]]
    #
    within_pullgaze_fromNodes_all = [[4,8],[5,9]]
    within_pullgaze_toNodes_all = [[2,2],[3,3]]
    #
    across_pullgaze_fromNodes_all = [[5,9],[4,8]]
    across_pullgaze_toNodes_all = [[2,2],[3,3]]
    #
    within_gazepull_fromNodes_all = [[6,10],[7,11]]
    within_gazepull_toNodes_all = [[0,0],[1,1]]
    #
    across_gazepull_fromNodes_all = [[7,11],[6,10]]
    across_gazepull_toNodes_all = [[0,0],[1,1]]
else:
    # otherwise the node tables would be undefined (or stale from an earlier run)
    raise ValueError('timelag must be one of 0, 1, 2, 3, 12; got '+str(timelag))

# (dependency label, from-node table, to-node table). The order of this list
# fixes the row order inside weighted_all_df / weighted_mean_df.
dependency_specs = [
    ('pull-pull',              pull_pull_fromNodes_all,       pull_pull_toNodes_all),
    ('gaze-gaze',              gaze_gaze_fromNodes_all,       gaze_gaze_toNodes_all),
    ('within animal gazepull', within_gazepull_fromNodes_all, within_gazepull_toNodes_all),
    ('across animal gazepull', across_gazepull_fromNodes_all, across_gazepull_toNodes_all),
    ('within animal pullgaze', within_pullgaze_fromNodes_all, within_pullgaze_toNodes_all),
    ('across animal pullgaze', across_pullgaze_fromNodes_all, across_pullgaze_toNodes_all),
]

# Rows are accumulated as plain dicts and turned into DataFrames once at the end:
# DataFrame.append was removed in pandas 2.0 and is quadratic when used row by row.
weighted_all_rows = []
weighted_mean_rows = []

animal1_all = ['dodson','ginger','dannon']
animal2_all = ['ginger', 'kanga', 'kanga']

act_animal1_all = ['dodson',      'ginger_withK',      'dannon']
act_animal2_all = ['ginger_withD', 'kanga_withG', 'kanga_withD']

nanimalpairs = np.shape(animal1_all)[0]

for ianimalpair in np.arange(0,nanimalpairs,1):

    animal1 = animal1_all[ianimalpair]
    animal2 = animal2_all[ianimalpair]

    act_animal1 = act_animal1_all[ianimalpair]
    act_animal2 = act_animal2_all[ianimalpair]

    # load the DBN related analysis
    # load data
    data_saved_subfolder = data_saved_folder+'data_saved_singlecam_wholebody_allsessions'+savefile_sufix+'_3lags/'+cameraID+'/'+animal1+animal2+'/'
    #
    if moreSampSize:
        with open(data_saved_subfolder+'/weighted_graphs_diffTempRo_diffSampSize_'+animal1+animal2+'_moreSampSize.pkl', 'rb') as f:
            weighted_graphs_diffTempRo_diffSampSize = pickle.load(f)
        with open(data_saved_subfolder+'/weighted_graphs_shuffled_diffTempRo_diffSampSize_'+animal1+animal2+'_moreSampSize.pkl', 'rb') as f:
            weighted_graphs_shuffled_diffTempRo_diffSampSize = pickle.load(f)
        with open(data_saved_subfolder+'/sig_edges_diffTempRo_diffSampSize_'+animal1+animal2+'_moreSampSize.pkl', 'rb') as f:
            sig_edges_diffTempRo_diffSampSize = pickle.load(f)
    if minmaxfullSampSize:
        with open(data_saved_subfolder+'/weighted_graphs_diffTempRo_diffSampSize_'+animal1+animal2+'_minmaxfullSampSize.pkl', 'rb') as f:
            weighted_graphs_diffTempRo_diffSampSize = pickle.load(f)
        with open(data_saved_subfolder+'/weighted_graphs_shuffled_diffTempRo_diffSampSize_'+animal1+animal2+'_minmaxfullSampSize.pkl', 'rb') as f:
            weighted_graphs_shuffled_diffTempRo_diffSampSize = pickle.load(f)
        with open(data_saved_subfolder+'/sig_edges_diffTempRo_diffSampSize_'+animal1+animal2+'_minmaxfullSampSize.pkl', 'rb') as f:
            sig_edges_diffTempRo_diffSampSize = pickle.load(f)
    #
    if not mergetempRos:
        with open(data_saved_subfolder+'/DBN_input_data_alltypes_'+animal1+animal2+'_'+str(temp_resolu)+'sReSo.pkl', 'rb') as f:
            DBN_input_data_alltypes = pickle.load(f)
    else:
        with open(data_saved_subfolder+'/DBN_input_data_alltypes_'+animal1+animal2+'_mergeTempsReSo.pkl', 'rb') as f:
            DBN_input_data_alltypes = pickle.load(f)

    weighted_graphs_MC1 = weighted_graphs_diffTempRo_diffSampSize[(str(temp_resolu),j_sampsize_name)]
    weighted_graphs_sf_MC1 = weighted_graphs_shuffled_diffTempRo_diffSampSize[(str(temp_resolu),j_sampsize_name)]
    sig_edges_MC1 = sig_edges_diffTempRo_diffSampSize[(str(temp_resolu),j_sampsize_name)]

    dateslist = list(weighted_graphs_MC1.keys())
    ndates = np.shape(dateslist)[0]

    for idate in np.arange(0,ndates,1):
        date_tgt = dateslist[idate]
        weighted_graphs_tgt = weighted_graphs_MC1[date_tgt]
        weighted_graphs_sf_tgt = weighted_graphs_sf_MC1[date_tgt]
        sig_edges_tgt = sig_edges_MC1[date_tgt]

        # significance mask with 0 -> NaN; only applied if the line below is re-enabled
        sig_edges_tgt = sig_edges_tgt.astype('float')
        sig_edges_tgt[sig_edges_tgt==0] = np.nan
        #
        # weighted_graphs_tgt = weighted_graphs_tgt * sig_edges_tgt

        #
        for ianimal in np.arange(0,2,1):
            act_animal = act_animal1 if ianimal == 0 else act_animal2

            # flattened edge weights for each dependency; for pooled lags
            # (timelag 0/12) the indexing yields (n, nlags) before flattening, so
            # 'shuffleID' below indexes that flattened vector rather than row n alone
            dep_weights = [
                (weighted_graphs_tgt[:,fromNodes_all[ianimal],toNodes_all[ianimal]]).flatten()
                for _, fromNodes_all, toNodes_all in dependency_specs
            ]

            # fill up the weighted_all_rows: for each index, one row per dependency
            nshuffles = np.shape(dep_weights[0])[0]
            for ishuffle in np.arange(0,nshuffles,1):
                for (depname, _, _), weights in zip(dependency_specs, dep_weights):
                    weighted_all_rows.append({'dependency':depname,
                                              'act_animal': act_animal,
                                              'dates':date_tgt,
                                              'shuffleID':ishuffle,
                                              'DepWeights':weights[ishuffle]})

            # fill up the weighted_mean_rows: NaN-ignoring mean per dependency
            for (depname, _, _), weights in zip(dependency_specs, dep_weights):
                weighted_mean_rows.append({'dependency':depname,
                                           'act_animal': act_animal,
                                           'dates':date_tgt,
                                           'DepWeights':np.nanmean(weights)})

weighted_all_df = pd.DataFrame(weighted_all_rows, columns=['dependency','dates','act_animal','shuffleID','DepWeights'])
weighted_mean_df = pd.DataFrame(weighted_mean_rows, columns=['dependency','dates','act_animal','DepWeights'])

# for plot
fig, axs = plt.subplots(1,1)
fig.set_figheight(5)
fig.set_figwidth(8)

#
toplotorder = ['ginger_withD','ginger_withK','kanga_withD','kanga_withG']
# .copy() so the categorical assignment below modifies an independent frame,
# not a view of weighted_mean_df (avoids SettingWithCopyWarning)
weighted_mean_toplot = weighted_mean_df[np.isin(weighted_mean_df['act_animal'],toplotorder)].copy()
weighted_mean_toplot['act_animal'] = pd.Categorical(weighted_mean_toplot['act_animal'],
                                                    categories=toplotorder, ordered=True)
weighted_mean_toplot_sorted = weighted_mean_toplot.sort_values(by='act_animal')

#
seaborn.boxplot(ax=axs,data=weighted_mean_toplot_sorted,x='dependency',y='DepWeights',hue='act_animal')
# seaborn.violinplot(ax=axs,data=weighted_mean_toplot_sorted,x='dependency',y='DepWeights',hue='act_animal')
axs.set_title('time lag: '+timelagname)
axs.set_xlabel('dependencies')
axs.set_xticklabels(axs.get_xticklabels(),rotation=45)

savefigs = 1
if savefigs:
    figsavefolder = data_saved_folder+'figs_for_3LagDBN_and_bhv_singlecam_wholebodylabels_allsessions_basicEvents/'+savefile_sufix+'/'+cameraID+'/'
    if not os.path.exists(figsavefolder):
        os.makedirs(figsavefolder)
    fig.savefig(figsavefolder+"DependencyWeights_GingerAndKanga_"+timelagname+'timelag.pdf')
    

In [None]:
# Independent two-sample t-test of one dependency's per-session mean weights
# between two act_animal groups (same animal, different partners).
# Other comparisons used before:
#   'across animal pullgaze': ginger_withD vs ginger_withK, kanga_withD vs kanga_withG
#   'within animal gazepull': ginger_withD vs ginger_withK
dependency_tgt = 'within animal gazepull'
act_animal_a = 'kanga_withD'
act_animal_b = 'kanga_withG'

ind1=(weighted_mean_toplot['dependency']==dependency_tgt)&(weighted_mean_toplot['act_animal']==act_animal_a)
ind2=(weighted_mean_toplot['dependency']==dependency_tgt)&(weighted_mean_toplot['act_animal']==act_animal_b)

# cast to float: the DepWeights column may be object dtype (e.g. when the frame
# was grown with DataFrame.append), and np.isnan needs a numeric array
xx1 = np.asarray(weighted_mean_toplot[ind1]['DepWeights'], dtype=float)
xx2 = np.asarray(weighted_mean_toplot[ind2]['DepWeights'], dtype=float)

# Drop NaN sessions before testing. The former mask `xx1!=np.nan` is always
# True (NaN never compares equal to anything), so it removed nothing and any
# NaN made the t statistic and p-value NaN.
st.ttest_ind(xx1[~np.isnan(xx1)],xx2[~np.isnan(xx2)])


In [None]:
# Plot, for every dependency type, how the per-session mean weight changes over
# days: left column Ginger (with Dodson / with Kanga), right column Kanga
# (with Dannon / with Ginger). Grey line joins sessions in date order; dots are
# coloured by partner condition.
warnings.filterwarnings("ignore", category=UserWarning)

dependnames = np.unique(weighted_mean_df['dependency'])
ndependnames = len(dependnames)

fig, axs = plt.subplots(ndependnames, 2)
fig.set_figheight(4 * ndependnames)
fig.set_figwidth(8 * 2)

# (subplot column, act_animal labels drawn in that column)
animal_panels = [
    (0, ['ginger_withD', 'ginger_withK']),
    (1, ['kanga_withD', 'kanga_withG']),
]

for idependname, dependname in enumerate(dependnames):
    weighted_mean_toplot = weighted_mean_df[weighted_mean_df['dependency'] == dependname]

    for icol, panel_animals in animal_panels:
        ax = axs[idependname, icol]
        in_panel = weighted_mean_toplot['act_animal'].isin(panel_animals)
        weighted_mean_toplot_sorted = weighted_mean_toplot[in_panel].sort_values(by=['dates'])

        seaborn.lineplot(ax=ax, data=weighted_mean_toplot_sorted,
                         x='dates', y='DepWeights', color='darkgray')
        seaborn.scatterplot(ax=ax, data=weighted_mean_toplot_sorted,
                            x='dates', y='DepWeights', hue='act_animal', s=150)
        ax.set_title('time lag: ' + timelagname)
        ax.set_ylabel('dependency weight:' + dependname)

plt.tight_layout()

savefigs = 1
if savefigs:
    figsavefolder = (data_saved_folder
                     + 'figs_for_3LagDBN_and_bhv_singlecam_wholebodylabels_allsessions_basicEvents/'
                     + savefile_sufix + '/' + cameraID + '/')
    os.makedirs(figsavefolder, exist_ok=True)
    fig.savefig(figsavefolder + "DependencyWeights_ChangeOverDays_GingerAndKanga_" + timelagname + 'timelag.pdf')
    

In [None]:
# notebook display: x tick labels (session dates) of the last dependency's left-column (Ginger) panel
axs[idependname,0].get_xticklabels()