### In this script, DBN is run on the combined sessions, combined for each condition
### In this script, DBN is run with a 1s time bin and 3 time lags
### In this script, the animal tracking is done with only one camera - camera 2 (middle) 
### In this script, DBN structures also consider within layer edges (xx_t0 to yy_t0)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn
import scipy
import scipy.stats as st
import sklearn
from sklearn.neighbors import KernelDensity
import string
import warnings
import pickle

import os
import glob
import random
from time import time

from pgmpy.models import BayesianModel
from pgmpy.models import DynamicBayesianNetwork as DBN
from pgmpy.estimators import BayesianEstimator
from pgmpy.estimators import HillClimbSearch,BicScore
from pgmpy.base import DAG
import networkx as nx


### function - get body part location for each pair of cameras

In [None]:
from ana_functions.body_part_locs_eachpair import body_part_locs_eachpair
from ana_functions.body_part_locs_singlecam import body_part_locs_singlecam

### function - align the two cameras

In [None]:
from ana_functions.camera_align import camera_align       

### function - merge the two pairs of cameras

In [None]:
from ana_functions.camera_merge import camera_merge

### function - find social gaze time point

In [None]:
from ana_functions.find_socialgaze_timepoint import find_socialgaze_timepoint
from ana_functions.find_socialgaze_timepoint_singlecam import find_socialgaze_timepoint_singlecam
from ana_functions.find_socialgaze_timepoint_singlecam_wholebody import find_socialgaze_timepoint_singlecam_wholebody

### function - define time point of behavioral events

In [None]:
from ana_functions.bhv_events_timepoint import bhv_events_timepoint
from ana_functions.bhv_events_timepoint_singlecam import bhv_events_timepoint_singlecam

### function - plot behavioral events

In [None]:
from ana_functions.plot_bhv_events import plot_bhv_events
from ana_functions.plot_bhv_events_levertube import plot_bhv_events_levertube
from ana_functions.draw_self_loop import draw_self_loop
import matplotlib.patches as mpatches 
from matplotlib.collections import PatchCollection

### function - make demo videos with skeleton and important vectors

In [None]:
from ana_functions.tracking_video_singlecam_demo import tracking_video_singlecam_demo
from ana_functions.tracking_video_singlecam_wholebody_demo import tracking_video_singlecam_wholebody_demo

### function - interval between all behavioral events

In [None]:
from ana_functions.bhv_events_interval import bhv_events_interval
from ana_functions.bhv_events_interval import bhv_events_interval_certainEdges

### function - train the dynamic bayesian network - multi time lag (3 lags)

In [None]:
from ana_functions.train_DBN_multiLag_withinLayerEdges import train_DBN_multiLag
from ana_functions.train_DBN_multiLag_withinLayerEdges import train_DBN_multiLag_create_df_only
from ana_functions.train_DBN_multiLag_withinLayerEdges import train_DBN_multiLag_training_only
from ana_functions.train_DBN_multiLag_withinLayerEdges import graph_to_matrix
from ana_functions.train_DBN_multiLag_withinLayerEdges import get_weighted_dags
from ana_functions.train_DBN_multiLag_withinLayerEdges import get_significant_edges
from ana_functions.train_DBN_multiLag_withinLayerEdges import threshold_edges
from ana_functions.train_DBN_multiLag_withinLayerEdges import Modulation_Index
from ana_functions.EfficientTimeShuffling import EfficientShuffle
from ana_functions.AicScore import AicScore

## Analyze each session

### prepare the basic behavioral data (especially the time stamps for each bhv events)

In [None]:
# Gaze is decided by target rectangles rather than by a gaze angle threshold
# ...need to update
sqr_thres_tubelever = 75  # draw the square around tube and lever
sqr_thres_face = 1.15     # a ratio for defining face boundary
sqr_thres_body = 4        # how many times to elongate the face box boundary to the body

# frame rate of the analyzed video
fps = 30

# number of frames used for the demo video (seconds * fps)
# nframes = 0.5*30
nframes = 2 * fps

# switches: re-analyze the video / redo any analysis step
reanalyze_video = 0
redo_anystep = 0

# session list options
do_bestsession = 1  # only analyze the best (five) sessions for each conditions during the training phase
do_trainedMCs = 1   # the list that only consider trained (1s) MC, together with SR and NV as controls

# suffix appended to the saved-data folder name, chosen by the session list in use
if not do_bestsession:
    savefile_sufix = ''
elif do_trainedMCs:
    savefile_sufix = '_trainedMCsessions'
else:
    savefile_sufix = '_bestsessions'
    
# all the videos (no misaligned ones)
# aligned with the audio
# get the session start time from "videosound_bhv_sync.py/.ipynb"
# currently the session_start_time will be manually typed in. It can be updated after a better method is used

# dodson scorch
# `if 1` is a manual on/off switch for this animal pair. Every pair block
# assigns the same variables, so a later block that is also switched on
# overrides the values set here.
# dates_list and session_start_times are parallel lists (one start time per date).
if 1:
    animal1_fixedorder = ['dodson']
    animal2_fixedorder = ['scorch']

    animal1_filename = "Dodson"
    animal2_filename = "Scorch"

    if do_bestsession and do_trainedMCs:
        # SR and NV sessions plus the trained MC sessions
        dates_list = [
            "20220915", "20220920", "20221010", "20230208",  # SR
            "20230321", "20230322", "20230323", "20230324", "20230412", "20230413",  # trained MC
            "20230117", "20230118", "20230124",  # NV
        ]
        session_start_times = [
            0.00, 33.03,  6.50,  0.00,
            20.5,  21.4,  21.0,  24.5,  20.5,  26.6,
            0.00,  0.00,  0.00,
        ]  # in second
    elif do_bestsession:
        # pick only five sessions for each conditions during the training phase
        dates_list = [
            # "20220912",
            "20220915", "20220920", "20221010", "20230208",
            "20221011", "20221013", "20221015", "20221017",
            "20221022", "20221026", "20221028", "20221030", "20230209",
            "20221125", "20221128", "20221129", "20230214", "20230215",
            "20221205", "20221206", "20221209", "20221214", "20230112",
            "20230117", "20230118", "20230124",
            # "20230126",
        ]
        session_start_times = [
            # 18.10,
            0.00, 33.03,  6.50,  0.00,
            2.80, 27.80, 27.90, 27.00,
            51.90, 21.00, 30.80, 17.50,  0.00,
            26.40, 22.50, 28.50,  0.00, 33.00,
            0.00,  0.00, 21.70, 17.00, 14.20,
            0.00,  0.00,  0.00,
            # 0.00,
        ]  # in second
    else:
        # complete session list (not filled in)
        dates_list = []
        session_start_times = []  # in second
    
    
# eddie sparkle
# `if 1` is a manual on/off switch for this animal pair. Every pair block
# assigns the same variables, so a later block that is also switched on
# overrides the values set here.
# dates_list and session_start_times are parallel lists (one start time per date).
if 1:
    animal1_fixedorder = ['eddie']
    animal2_fixedorder = ['sparkle']

    animal1_filename = "Eddie"
    animal2_filename = "Sparkle"

    if do_bestsession and do_trainedMCs:
        # SR and NV sessions plus the trained MC sessions
        dates_list = [
            "20221122",  "20221125",  # sr
            "20230410",  "20230411",  "20230412",  "20230413",  "20230616",  # trained MC
            "20230331",  "20230403",  "20230404",  "20230405",  "20230406",  # nv
        ]
        session_start_times = [
            8.00, 38.00,
            23.2,  23.0,  21.2,  25.0,  23.0,
            4.50,  9.30, 25.50, 20.40, 21.30,
        ]  # in second
    elif do_bestsession:
        # pick only five sessions for each conditions during the training phase
        dates_list = [
            "20221122",  "20221125",
            "20221202",  "20221206",  "20230126",  "20230130",  "20230201",
            "20230207",  "20230208-1", "20230209",  "20230222",  "20230223-1",
            "20230227-1", "20230228-1", "20230302-1", "20230307-2", "20230313",
            "20230321",  "20230322",  "20230324",  "20230327",  "20230328",
            "20230331",  "20230403",  "20230404",  "20230405",  "20230406",
        ]
        session_start_times = [
            8.00,  38.00,
            9.50,   1.00, 38.00,  4.20,  3.80,
            9.00,   7.50,  8.50, 14.50,  7.80,
            8.00,   7.50,  8.00,  8.00,  4.00,
            7.00,   7.50,  5.50, 11.00,  9.00,
            4.50,   9.30, 25.50, 20.40, 21.30,
        ]  # in second
    else:
        # complete session list (not filled in)
        dates_list = []
        session_start_times = []  # in second
    
    
# ginger kanga
# `if 1` is a manual on/off switch for this animal pair. Every pair block
# assigns the same variables, so a later block that is also switched on
# overrides the values set here.
# dates_list and session_start_times are parallel lists (one start time per date).
if 1:
    animal1_fixedorder = ['ginger']
    animal2_fixedorder = ['kanga']

    animal1_filename = "Ginger"
    animal2_filename = "Kanga"

    if do_bestsession and do_trainedMCs:
        # SR and NV sessions plus the trained MC sessions
        dates_list = [
            "20230214",   "20230216",  # SR
            "20230614",   "20230615",  "20230711", "20230712",  # trained MC
            "20230522_ws", "20230524", "20230605_1", "20230606", "20230607",  # nv
        ]
        session_start_times = [
            0.00, 48.00,
            0.00,  0.00,  54.5,  24.7,
            0.00,  0.00,  0.00,  0.00,  0.00,
        ]  # in second
    elif do_bestsession:
        # pick only five sessions for each conditions during the training phase
        dates_list = [
            # "20230213",
            "20230214", "20230216",
            "20230228", "20230302", "20230303", "20230307",
            "20230314", "20230315", "20230316", "20230317",
            "20230301", "20230320", "20230321", "20230322",
            "20230323", "20230412", "20230413", "20230517",
            "20230522_ws", "20230524", "20230605_1", "20230606", "20230607",
        ]
        session_start_times = [
            # 0.00,
            0.00, 48.00,
            23.00, 28.50, 34.00, 25.50,
            25.50, 31.50, 28.00, 30.50,
            0.00,  0.00,  0.00,  0.00,
            0.00,  0.00,  0.00,  0.00,
            0.00,  0.00,  0.00,  0.00,  0.00,
        ]  # in second
    else:
        # complete session list (not filled in)
        dates_list = []
        session_start_times = []  # in second

    
# dannon kanga
# `if 1` is a manual on/off switch for this animal pair. Every pair block
# assigns the same variables, so a later block that is also switched on
# overrides the values set here.
# dates_list and session_start_times are parallel lists (one start time per date).
if 1:
    animal1_fixedorder = ['dannon']
    animal2_fixedorder = ['kanga']

    animal1_filename = "Dannon"
    animal2_filename = "Kanga"

    if do_bestsession and do_trainedMCs:
        # SR and NV sessions plus the trained MC sessions
        dates_list = [
            "20230718", "20230720", "20230914",  # sr
            "20231030", "20231031", "20231101", "20231102", "20240304", "20240305",  # trained MC
            "20231011", "20231013", "20231020", "20231024", "20231025",  # nv
        ]
        session_start_times = [
            0,    0,    0,
            18.2, 14.0, 15.8, 15.2, 16.3, 37.9,
            0,    0,    0,    0,    0,
        ]  # in second
    elif do_bestsession:
        # pick only five sessions for each conditions during the training phase
        dates_list = [
            "20230718", "20230720", "20230914", "20230726", "20230727", "20230809",
            "20230810", "20230811", "20230814", "20230816", "20230829", "20230907", "20230915",
            "20230918", "20230926", "20230928", "20231002", "20231010", "20231011",
            "20231013", "20231020", "20231024", "20231025",
        ]
        session_start_times = [
            0,    0,    0, 32.2, 27.2, 37.5,
            21.0, 21.5, 19.8, 32.0,    0,    0,   0,
            0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,
        ]  # in second
    else:
        # complete session list (not filled in)
        dates_list = []
        session_start_times = []  # in second

# Koala Vermelho
# `if 1` is a manual on/off switch for this animal pair. Every pair block
# assigns the same variables, so when this block is switched on it overrides
# whatever an earlier pair block assigned.
# dates_list and session_start_times are parallel lists (one start time per date).
if 1:
    animal1_fixedorder = ['koala']
    animal2_fixedorder = ['vermelho']

    animal1_filename = "Koala"
    animal2_filename = "Vermelho"

    if do_bestsession and do_trainedMCs:
        # SR and NV sessions plus the trained MC sessions
        dates_list = [
            "20231222", "20231226", "20231227",  # SR
            "20240220", "20240222", "20240223", "20240226",  # trained MC
            "20240214", "20240215", "20240216",  # NV
        ]
        session_start_times = [
            21.5,  0.00,  0.00,
            68.8,  43.8,  13.2,  47.5,
            86.8,  94.0,  65.0,
        ]  # in second
    elif do_bestsession:
        # pick only five sessions for each conditions during the training phase
        dates_list = [
            "20231222", "20231226", "20231227",  "20231229", "20231230",
            "20231231", "20240102", "20240104-2", "20240105", "20240108",
            "20240109", "20240115", "20240116",  "20240117", "20240118", "20240119",
            "20240207", "20240208", "20240209",  "20240212", "20240213",
            "20240214", "20240215", "20240216",
        ]
        session_start_times = [
            21.5,  0.00,  0.00,  0.00,  0.00,
            0.00,  12.2,  0.00,  18.8,  31.2,
            32.5,  0.00,  50.0,  0.00,  37.5,  29.5,
            58.5,  72.0,  0.00,  71.5,  70.5,
            86.8,  94.0,  65.0,
        ]  # in second
    else:
        # complete session list (not filled in)
        dates_list = []
        session_start_times = []  # in second
    
    
#    
# dates_list = ["20230718"]
# session_start_times = [0.00] # in second

# number of sessions (dates) to analyze
ndates = len(dates_list)

# session start expressed in video frames (fps is 30Hz).
# session_start_times is a plain Python list, and `list * int` repeats the
# list instead of scaling its values, so convert to an array before
# multiplying to get one frame index per session.
session_start_frames = np.asarray(session_start_times) * fps

# NOTE(review): presumably the total session length in seconds - confirm
totalsess_time = 600

# labels used in the video tracking results
# 'dodson'/'scorch' here do not really mean dodson and scorch; they stand for animal1 and animal2
animalnames_videotrack = ['dodson', 'scorch']
bodypartnames_videotrack = [
    'rightTuft',
    'whiteBlaze',
    'leftTuft',
    'rightEye',
    'leftEye',
    'mouth',
]


# camera whose tracking results are analyzed
cameraID = 'camera-2'
cameraID_short = 'cam2'


# lever / tube handling switches
considerlevertube = 1
considertubeonly = 0

# location of levers and tubes for camera 2
# get this information using DLC animal tracking GUI, the results are stored:
# /home/ws523/marmoset_tracking_DLCv2/marmoset_tracking_with_lever_tube-weikang-2023-04-13/labeled-data/
#
# alternative coordinates kept for reference:
# # camera 1
# lever_locs_camI = {'dodson':np.array([645,600]),'scorch':np.array([425,435])}
# tube_locs_camI  = {'dodson':np.array([1350,630]),'scorch':np.array([555,345])}
# # camera 2 (alternative tube positions)
# # lever_locs_camI = {'dodson':np.array([1335,715]),'scorch':np.array([550,715])}
# # tube_locs_camI  = {'dodson':np.array([1650,490]),'scorch':np.array([250,490])}
# # camera 3
# lever_locs_camI = {'dodson':np.array([1580,440]),'scorch':np.array([1296,540])}
# tube_locs_camI  = {'dodson':np.array([1470,375]),'scorch':np.array([805,475])}
#
# camera 2 (in use)
lever_locs_camI = dict(
    dodson=np.array([1335, 715]),
    scorch=np.array([550, 715]),
)
tube_locs_camI = dict(
    dodson=np.array([1550, 515]),
    scorch=np.array([350, 515]),
)


# every session date needs exactly one start time; stop with an explicit
# error instead of exiting silently when the two lists are out of sync
if len(session_start_times) != len(dates_list):
    raise ValueError(
        'session_start_times has ' + str(len(session_start_times)) +
        ' entries but dates_list has ' + str(len(dates_list)) +
        '; the two lists must have the same length'
    )

    
# define bhv events summarizing variables
# each array holds one value per session (shape: ndates x 1)
tasktypes_all_dates = np.zeros((ndates,1))
coopthres_all_dates = np.zeros((ndates,1))

succ_rate_all_dates = np.zeros((ndates,1))
interpullintv_all_dates = np.zeros((ndates,1))
trialnum_all_dates = np.zeros((ndates,1))

owgaze1_num_all_dates = np.zeros((ndates,1))
owgaze2_num_all_dates = np.zeros((ndates,1))
mtgaze1_num_all_dates = np.zeros((ndates,1))
mtgaze2_num_all_dates = np.zeros((ndates,1))
pull1_num_all_dates = np.zeros((ndates,1))
pull2_num_all_dates = np.zeros((ndates,1))

# per-date containers keyed by session date.
# Each key gets its own list: dict.fromkeys(keys, []) would make every key
# share one list object, so an in-place change for one date would show up
# under all dates.
bhv_intv_all_dates = {date_key: [] for date_key in dates_list}
pull_edges_intv_all_dates = {date_key: [] for date_key in dates_list}


# where to save the summarizing data
data_saved_folder = '/gpfs/radev/pi/nandy/jadi_gibbs_data/VideoTracker_SocialInter/3d_recontruction_analysis_self_and_coop_task_data_saved/'


    

In [None]:
# basic behavior analysis (define time stamps for each bhv events, etc)

try:
    if redo_anystep:
        dummy
    
    # load saved data
    data_saved_subfolder = data_saved_folder+'data_saved_singlecam_wholebody'+savefile_sufix+'/'+cameraID+'/'+animal1_fixedorder[0]+animal2_fixedorder[0]+'/'
    
    with open(data_saved_subfolder+'/owgaze1_num_all_dates_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'.pkl', 'rb') as f:
        owgaze1_num_all_dates = pickle.load(f)
    with open(data_saved_subfolder+'/owgaze2_num_all_dates_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'.pkl', 'rb') as f:
        owgaze2_num_all_dates = pickle.load(f)
    with open(data_saved_subfolder+'/mtgaze1_num_all_dates_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'.pkl', 'rb') as f:
        mtgaze1_num_all_dates = pickle.load(f)
    with open(data_saved_subfolder+'/mtgaze2_num_all_dates_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'.pkl', 'rb') as f:
        mtgaze2_num_all_dates = pickle.load(f)
    with open(data_saved_subfolder+'/pull1_num_all_dates_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'.pkl', 'rb') as f:
        pull1_num_all_dates = pickle.load(f)
    with open(data_saved_subfolder+'/pull2_num_all_dates_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'.pkl', 'rb') as f:
        pull2_num_all_dates = pickle.load(f)

    with open(data_saved_subfolder+'/tasktypes_all_dates_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'.pkl', 'rb') as f:
        tasktypes_all_dates = pickle.load(f)
    with open(data_saved_subfolder+'/coopthres_all_dates_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'.pkl', 'rb') as f:
        coopthres_all_dates = pickle.load(f)
    with open(data_saved_subfolder+'/succ_rate_all_dates_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'.pkl', 'rb') as f:
        succ_rate_all_dates = pickle.load(f)
    with open(data_saved_subfolder+'/interpullintv_all_dates_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'.pkl', 'rb') as f:
        interpullintv_all_dates = pickle.load(f)
    with open(data_saved_subfolder+'/trialnum_all_dates_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'.pkl', 'rb') as f:
        trialnum_all_dates = pickle.load(f)
    with open(data_saved_subfolder+'/bhv_intv_all_dates_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'.pkl', 'rb') as f:
        bhv_intv_all_dates = pickle.load(f)
    with open(data_saved_subfolder+'/pull_edges_intv_all_dates_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'.pkl', 'rb') as f:
        pull_edges_intv_all_dates = pickle.load(f)

    print('all data from all dates are loaded')

except:

    print('analyze all dates')

    for idate in np.arange(0,ndates,1):
        date_tgt = dates_list[idate]
        session_start_time = session_start_times[idate]

        # folder and file path
        camera12_analyzed_path = "/gpfs/radev/pi/nandy/jadi_gibbs_data/VideoTracker_SocialInter/test_video_cooperative_task_3d/"+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_camera12/"
        camera23_analyzed_path = "/gpfs/radev/pi/nandy/jadi_gibbs_data/VideoTracker_SocialInter/test_video_cooperative_task_3d/"+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_camera23/"
        
        try:
            singlecam_ana_type = "DLC_dlcrnetms5_marmoset_tracking_with_middle_cameraSep1shuffle1_150000"
            try: 
                bodyparts_camI_camIJ = camera12_analyzed_path+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_"+cameraID+singlecam_ana_type+"_el_filtered.h5"
                # get the bodypart data from files
                bodyparts_locs_camI = body_part_locs_singlecam(bodyparts_camI_camIJ,singlecam_ana_type,animalnames_videotrack,bodypartnames_videotrack,date_tgt)
                video_file_original = camera12_analyzed_path+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_"+cameraID+".mp4"
            except:
                bodyparts_camI_camIJ = camera23_analyzed_path+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_"+cameraID+singlecam_ana_type+"_el_filtered.h5"
                # get the bodypart data from files
                bodyparts_locs_camI = body_part_locs_singlecam(bodyparts_camI_camIJ,singlecam_ana_type,animalnames_videotrack,bodypartnames_videotrack,date_tgt)
                video_file_original = camera23_analyzed_path+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_"+cameraID+".mp4"        
        except:
            singlecam_ana_type = "DLC_dlcrnetms5_marmoset_tracking_with_middle_camera_withHeadchamberFeb28shuffle1_167500"
            try: 
                bodyparts_camI_camIJ = camera12_analyzed_path+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_"+cameraID+singlecam_ana_type+"_el_filtered.h5"
                # get the bodypart data from files
                bodyparts_locs_camI = body_part_locs_singlecam(bodyparts_camI_camIJ,singlecam_ana_type,animalnames_videotrack,bodypartnames_videotrack,date_tgt)
                video_file_original = camera12_analyzed_path+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_"+cameraID+".mp4"
            except:
                bodyparts_camI_camIJ = camera23_analyzed_path+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_"+cameraID+singlecam_ana_type+"_el_filtered.h5"
                # get the bodypart data from files
                bodyparts_locs_camI = body_part_locs_singlecam(bodyparts_camI_camIJ,singlecam_ana_type,animalnames_videotrack,bodypartnames_videotrack,date_tgt)
                video_file_original = camera23_analyzed_path+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_"+cameraID+".mp4"        
        
        
        # load behavioral results
        try:
            try:
                bhv_data_path = "/gpfs/radev/pi/nandy/jadi_gibbs_data/VideoTracker_SocialInter/marmoset_tracking_bhv_data_from_task_code/"+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"/"
                trial_record_json = glob.glob(bhv_data_path +date_tgt+"_"+animal2_filename+"_"+animal1_filename+"_TrialRecord_" + "*.json")
                bhv_data_json = glob.glob(bhv_data_path + date_tgt+"_"+animal2_filename+"_"+animal1_filename+"_bhv_data_" + "*.json")
                session_info_json = glob.glob(bhv_data_path + date_tgt+"_"+animal2_filename+"_"+animal1_filename+"_session_info_" + "*.json")
                #
                trial_record = pd.read_json(trial_record_json[0])
                bhv_data = pd.read_json(bhv_data_json[0])
                session_info = pd.read_json(session_info_json[0])
            except:
                bhv_data_path = "/gpfs/radev/pi/nandy/jadi_gibbs_data/VideoTracker_SocialInter/marmoset_tracking_bhv_data_from_task_code/"+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"/"
                trial_record_json = glob.glob(bhv_data_path + date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_TrialRecord_" + "*.json")
                bhv_data_json = glob.glob(bhv_data_path + date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_bhv_data_" + "*.json")
                session_info_json = glob.glob(bhv_data_path + date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_session_info_" + "*.json")
                #
                trial_record = pd.read_json(trial_record_json[0])
                bhv_data = pd.read_json(bhv_data_json[0])
                session_info = pd.read_json(session_info_json[0])
        except:
            try:
                bhv_data_path = "/gpfs/radev/pi/nandy/jadi_gibbs_data/VideoTracker_SocialInter/marmoset_tracking_bhv_data_forceManipulation_task/"+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"/"
                trial_record_json = glob.glob(bhv_data_path +date_tgt+"_"+animal2_filename+"_"+animal1_filename+"_TrialRecord_" + "*.json")
                bhv_data_json = glob.glob(bhv_data_path + date_tgt+"_"+animal2_filename+"_"+animal1_filename+"_bhv_data_" + "*.json")
                session_info_json = glob.glob(bhv_data_path + date_tgt+"_"+animal2_filename+"_"+animal1_filename+"_session_info_" + "*.json")
                #
                trial_record = pd.read_json(trial_record_json[0])
                bhv_data = pd.read_json(bhv_data_json[0])
                session_info = pd.read_json(session_info_json[0])
            except:
                bhv_data_path = "/gpfs/radev/pi/nandy/jadi_gibbs_data/VideoTracker_SocialInter/marmoset_tracking_bhv_data_forceManipulation_task/"+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"/"
                trial_record_json = glob.glob(bhv_data_path + date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_TrialRecord_" + "*.json")
                bhv_data_json = glob.glob(bhv_data_path + date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_bhv_data_" + "*.json")
                session_info_json = glob.glob(bhv_data_path + date_tgt+"_"+animal1_filename+"_"+animal2_filename+"_session_info_" + "*.json")
                #
                trial_record = pd.read_json(trial_record_json[0])
                bhv_data = pd.read_json(bhv_data_json[0])
                session_info = pd.read_json(session_info_json[0])

        # get animal info from the session information
        animal1 = session_info['lever1_animal'][0].lower()
        animal2 = session_info['lever2_animal'][0].lower()

        
        # get task type and cooperation threshold
        try:
            coop_thres = session_info["pulltime_thres"][0]
            tasktype = session_info["task_type"][0]
        except:
            coop_thres = 0
            tasktype = 1
        tasktypes_all_dates[idate] = tasktype
        coopthres_all_dates[idate] = coop_thres   

        # clean up the trial_record
        warnings.filterwarnings('ignore')
        trial_record_clean = pd.DataFrame(columns=trial_record.columns)
        for itrial in np.arange(0,np.max(trial_record['trial_number']),1):
            # trial_record_clean.loc[itrial] = trial_record[trial_record['trial_number']==itrial+1].iloc[[0]]
            trial_record_clean = trial_record_clean.append(trial_record[trial_record['trial_number']==itrial+1].iloc[[0]])
        trial_record_clean = trial_record_clean.reset_index(drop = True)

        # change bhv_data time to the absolute time
        time_points_new = pd.DataFrame(np.zeros(np.shape(bhv_data)[0]),columns=["time_points_new"])
        for itrial in np.arange(0,np.max(trial_record_clean['trial_number']),1):
            ind = bhv_data["trial_number"]==itrial+1
            new_time_itrial = bhv_data[ind]["time_points"] + trial_record_clean["trial_starttime"].iloc[itrial]
            time_points_new["time_points_new"][ind] = new_time_itrial
        bhv_data["time_points"] = time_points_new["time_points_new"]
        bhv_data = bhv_data[bhv_data["time_points"] != 0]


        # analyze behavior results
        # succ_rate_all_dates[idate] = np.sum(trial_record_clean["rewarded"]>0)/np.shape(trial_record_clean)[0]
        succ_rate_all_dates[idate] = np.sum((bhv_data['behavior_events']==3)|(bhv_data['behavior_events']==4))/np.sum((bhv_data['behavior_events']==1)|(bhv_data['behavior_events']==2))
        trialnum_all_dates[idate] = np.shape(trial_record_clean)[0]
        #
        pullid = np.array(bhv_data[(bhv_data['behavior_events']==1) | (bhv_data['behavior_events']==2)]["behavior_events"])
        pulltime = np.array(bhv_data[(bhv_data['behavior_events']==1) | (bhv_data['behavior_events']==2)]["time_points"])
        pullid_diff = np.abs(pullid[1:] - pullid[0:-1])
        pulltime_diff = pulltime[1:] - pulltime[0:-1]
        interpull_intv = pulltime_diff[pullid_diff==1]
        interpull_intv = interpull_intv[interpull_intv<10]
        mean_interpull_intv = np.nanmean(interpull_intv)
        std_interpull_intv = np.nanstd(interpull_intv)
        #
        interpullintv_all_dates[idate] = mean_interpull_intv
        # 
        pull1_num_all_dates[idate] = np.sum(bhv_data['behavior_events']==1) 
        pull2_num_all_dates[idate] = np.sum(bhv_data['behavior_events']==2)

        
        # load behavioral event results
        try:
            # dummy
            print('load social gaze with '+cameraID+' only of '+date_tgt)
            with open(data_saved_folder+"bhv_events_singlecam_wholebody/"+animal1_fixedorder[0]+animal2_fixedorder[0]+"/"+cameraID+'/'+date_tgt+'/output_look_ornot.pkl', 'rb') as f:
                output_look_ornot = pickle.load(f)
            with open(data_saved_folder+"bhv_events_singlecam_wholebody/"+animal1_fixedorder[0]+animal2_fixedorder[0]+"/"+cameraID+'/'+date_tgt+'/output_allvectors.pkl', 'rb') as f:
                output_allvectors = pickle.load(f)
            with open(data_saved_folder+"bhv_events_singlecam_wholebody/"+animal1_fixedorder[0]+animal2_fixedorder[0]+"/"+cameraID+'/'+date_tgt+'/output_allangles.pkl', 'rb') as f:
                output_allangles = pickle.load(f)  
        except:   
            print('analyze social gaze with '+cameraID+' only of '+date_tgt)
            # get social gaze information 
            output_look_ornot, output_allvectors, output_allangles = find_socialgaze_timepoint_singlecam_wholebody(bodyparts_locs_camI,lever_locs_camI,tube_locs_camI,
                                                                                                                   considerlevertube,considertubeonly,sqr_thres_tubelever,
                                                                                                                   sqr_thres_face,sqr_thres_body)
            # save data
            current_dir = data_saved_folder+'/bhv_events_singlecam_wholebody/'+animal1_fixedorder[0]+animal2_fixedorder[0]
            add_date_dir = os.path.join(current_dir,cameraID+'/'+date_tgt)
            if not os.path.exists(add_date_dir):
                os.makedirs(add_date_dir)
            #
            with open(data_saved_folder+"bhv_events_singlecam_wholebody/"+animal1_fixedorder[0]+animal2_fixedorder[0]+"/"+cameraID+'/'+date_tgt+'/output_look_ornot.pkl', 'wb') as f:
                pickle.dump(output_look_ornot, f)
            with open(data_saved_folder+"bhv_events_singlecam_wholebody/"+animal1_fixedorder[0]+animal2_fixedorder[0]+"/"+cameraID+'/'+date_tgt+'/output_allvectors.pkl', 'wb') as f:
                pickle.dump(output_allvectors, f)
            with open(data_saved_folder+"bhv_events_singlecam_wholebody/"+animal1_fixedorder[0]+animal2_fixedorder[0]+"/"+cameraID+'/'+date_tgt+'/output_allangles.pkl', 'wb') as f:
                pickle.dump(output_allangles, f)
  

        look_at_other_or_not_merge = output_look_ornot['look_at_other_or_not_merge']
        look_at_tube_or_not_merge = output_look_ornot['look_at_tube_or_not_merge']
        look_at_lever_or_not_merge = output_look_ornot['look_at_lever_or_not_merge']
        # change the unit to second
        session_start_time = session_start_times[idate]
        look_at_other_or_not_merge['time_in_second'] = np.arange(0,np.shape(look_at_other_or_not_merge['dodson'])[0],1)/fps - session_start_time
        look_at_lever_or_not_merge['time_in_second'] = np.arange(0,np.shape(look_at_lever_or_not_merge['dodson'])[0],1)/fps - session_start_time
        look_at_tube_or_not_merge['time_in_second'] = np.arange(0,np.shape(look_at_tube_or_not_merge['dodson'])[0],1)/fps - session_start_time 

        # find time point of behavioral events
        output_time_points_socialgaze ,output_time_points_levertube = bhv_events_timepoint_singlecam(bhv_data,look_at_other_or_not_merge,look_at_lever_or_not_merge,look_at_tube_or_not_merge)
        time_point_pull1 = output_time_points_socialgaze['time_point_pull1']
        time_point_pull2 = output_time_points_socialgaze['time_point_pull2']
        oneway_gaze1 = output_time_points_socialgaze['oneway_gaze1']
        oneway_gaze2 = output_time_points_socialgaze['oneway_gaze2']
        mutual_gaze1 = output_time_points_socialgaze['mutual_gaze1']
        mutual_gaze2 = output_time_points_socialgaze['mutual_gaze2']
            
                
        # # plot behavioral events
        if np.isin(animal1,animal1_fixedorder):
                plot_bhv_events(date_tgt,animal1, animal2, session_start_time, 600, time_point_pull1, time_point_pull2, oneway_gaze1, oneway_gaze2, mutual_gaze1, mutual_gaze2)
        else:
                plot_bhv_events(date_tgt,animal2, animal1, session_start_time, 600, time_point_pull2, time_point_pull1, oneway_gaze2, oneway_gaze1, mutual_gaze2, mutual_gaze1)
        #
        # save behavioral events plot
        if 0:
            current_dir = data_saved_folder+'/bhv_events_singlecam_wholebody/'+animal1_fixedorder[0]+animal2_fixedorder[0]
            add_date_dir = os.path.join(current_dir,cameraID+'/'+date_tgt)
            if not os.path.exists(add_date_dir):
                os.makedirs(add_date_dir)
            plt.savefig(data_saved_folder+"/bhv_events_singlecam_wholebody/"+animal1_fixedorder[0]+animal2_fixedorder[0]+"/"+cameraID+'/'+date_tgt+'/'+date_tgt+"_"+cameraID_short+".pdf")

        #
        owgaze1_num_all_dates[idate] = np.shape(oneway_gaze1)[0]
        owgaze2_num_all_dates[idate] = np.shape(oneway_gaze2)[0]
        mtgaze1_num_all_dates[idate] = np.shape(mutual_gaze1)[0]
        mtgaze2_num_all_dates[idate] = np.shape(mutual_gaze2)[0]

        # analyze the events interval, especially for the pull to other and other to pull interval
        # could be used for define time bin for DBN
        if np.isin(animal1,animal1_fixedorder):
            _,_,_,pullTOother_itv, otherTOpull_itv = bhv_events_interval(totalsess_time, session_start_time, time_point_pull1, time_point_pull2, 
                                                                         oneway_gaze1, oneway_gaze2, mutual_gaze1, mutual_gaze2)
            #
            pull_other_pool_itv = np.concatenate((pullTOother_itv,otherTOpull_itv))
            bhv_intv_all_dates[date_tgt] = {'pull_to_other':pullTOother_itv,'other_to_pull':otherTOpull_itv,
                            'pull_other_pooled': pull_other_pool_itv}
            
            all_pull_edges_intervals = bhv_events_interval_certainEdges(totalsess_time, session_start_time, time_point_pull1, time_point_pull2, 
                                                                        oneway_gaze1, oneway_gaze2, mutual_gaze1, mutual_gaze2)
            pull_edges_intv_all_dates[date_tgt] = all_pull_edges_intervals
        else:
            _,_,_,pullTOother_itv, otherTOpull_itv = bhv_events_interval(totalsess_time, session_start_time, time_point_pull2, time_point_pull1, 
                                                                         oneway_gaze2, oneway_gaze1, mutual_gaze2, mutual_gaze1)
            #
            pull_other_pool_itv = np.concatenate((pullTOother_itv,otherTOpull_itv))
            bhv_intv_all_dates[date_tgt] = {'pull_to_other':pullTOother_itv,'other_to_pull':otherTOpull_itv,
                            'pull_other_pooled': pull_other_pool_itv}
            
            all_pull_edges_intervals = bhv_events_interval_certainEdges(totalsess_time, session_start_time, time_point_pull2, time_point_pull1, 
                                                                        oneway_gaze2, oneway_gaze1, mutual_gaze2, mutual_gaze1)
            pull_edges_intv_all_dates[date_tgt] = all_pull_edges_intervals
   

        # plot the tracking demo video
        if 0: 
            tracking_video_singlecam_wholebody_demo(bodyparts_locs_camI,output_look_ornot,output_allvectors,output_allangles,
                                              lever_locs_camI,tube_locs_camI,time_point_pull1,time_point_pull2,
                                              animalnames_videotrack,bodypartnames_videotrack,date_tgt,
                                              animal1_filename,animal2_filename,session_start_time,fps,nframes,cameraID,
                                              video_file_original,sqr_thres_tubelever,sqr_thres_face,sqr_thres_body)         
        

    # save the per-date summary variables, one pickle file per variable
    if 1:

        animalpair_tag = animal1_fixedorder[0]+animal2_fixedorder[0]
        data_saved_subfolder = data_saved_folder+'data_saved_singlecam_wholebody'+savefile_sufix+'/'+cameraID+'/'+animalpair_tag+'/'
        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
        os.makedirs(data_saved_subfolder, exist_ok=True)

        # NOTE: DBN_input_data_alltypes is not written here (that dump was already disabled)

        # file name stem -> variable to pickle; files are written in this order and
        # end up as <subfolder>/<stem>_<animal pair>.pkl
        summaries_to_save = {
            'owgaze1_num_all_dates': owgaze1_num_all_dates,
            'owgaze2_num_all_dates': owgaze2_num_all_dates,
            'mtgaze1_num_all_dates': mtgaze1_num_all_dates,
            'mtgaze2_num_all_dates': mtgaze2_num_all_dates,
            'pull1_num_all_dates': pull1_num_all_dates,
            'pull2_num_all_dates': pull2_num_all_dates,
            'tasktypes_all_dates': tasktypes_all_dates,
            'coopthres_all_dates': coopthres_all_dates,
            'succ_rate_all_dates': succ_rate_all_dates,
            'interpullintv_all_dates': interpullintv_all_dates,
            'trialnum_all_dates': trialnum_all_dates,
            'bhv_intv_all_dates': bhv_intv_all_dates,
            'pull_edges_intv_all_dates': pull_edges_intv_all_dates,
        }
        for summary_name, summary_data in summaries_to_save.items():
            with open(data_saved_subfolder+'/'+summary_name+'_'+animalpair_tag+'.pkl', 'wb') as f:
                pickle.dump(summary_data, f)
    
    
    

### prepare the input data for DBN

In [None]:
# define DBN related summarizing variables
# the three lists are parallel: position i describes DBN group i
# (display name, task type ID, cooperation threshold)
DBN_group_typenames = ['self','coop(3s)','coop(2s)','coop(1.5s)','coop(1s)','no-vision']
DBN_group_typeIDs  =  [1,3,3,  3,3,5]
DBN_group_coopthres = [0,3,2,1.5,1,0]
if do_trainedMCs:
    DBN_group_typenames = ['self','coop(1s)','no-vision']
    DBN_group_typeIDs  =  [1,3,5]
    DBN_group_coopthres = [0,1,0]
nDBN_groups = len(DBN_group_typenames)

# switched to 1 below when the saved DBN input data cannot be loaded
prepare_input_data = 0

# one independent empty container per group
# (dict.fromkeys(names, []) would make every group share the same list object)
DBN_input_data_alltypes = {typename: [] for typename in DBN_group_typenames}

# DBN resolutions (make sure they are the same as in the later part of the code)
totalsess_time = 600 # total session time in s
# temp_resolus = [0.5,1,1.5,2] # temporal resolution in the DBN model, eg: 0.5 means 500ms
temp_resolus = [1] # temporal resolution in the DBN model, eg: 0.5 means 500ms

mergetempRos = 0 # 1: pool the tables of all temporal resolutions into one
if mergetempRos:
    temp_resolus = [0.5,1,1.5,2] # temporal resolution in the DBN model, eg: 0.5 means 500ms
    # use bhv event to decide temporal resolution
    #
    #low_lim,up_lim,_ = bhv_events_interval(totalsess_time, session_start_time, time_point_pull1, time_point_pull2, oneway_gaze1, oneway_gaze2, mutual_gaze1, mutual_gaze2)
    #temp_resolus = temp_resolus = np.arange(low_lim,up_lim,0.1)
ntemp_reses = len(temp_resolus)

# # train the dynamic bayesian network - Alec's model 
#   prepare the multi-session table; multi time steps (temporal resolution) as separate files

# try to reuse previously saved DBN input data; any failure falls back to preparing it again
try:
    for temp_resolu in temp_resolus:
        data_saved_subfolder = data_saved_folder+'data_saved_singlecam_wholebody'+savefile_sufix+'_3lags/'+cameraID+'/'+animal1_fixedorder[0]+animal2_fixedorder[0]+'/'
        if not mergetempRos:
            with open(data_saved_subfolder+'/DBN_input_data_alltypes_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'_'+str(temp_resolu)+'sReSo.pkl', 'rb') as f:
                DBN_input_data_alltypes = pickle.load(f)
        else:
            with open(data_saved_subfolder+'/DBN_input_data_alltypes_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'_mergeTempsReSo.pkl', 'rb') as f:
                DBN_input_data_alltypes = pickle.load(f)
    print('load DBN input results')    
except Exception as load_err:
    # Exception (not a bare except) so that Ctrl-C / SystemExit still interrupt the cell;
    # the reason is reported instead of being silently dropped
    print('could not load DBN input results ('+repr(load_err)+'); preparing them from the session data')
    prepare_input_data = 1


# prepare the DBN input data

# location of the task-code output; the session folders are searched in this order
bhv_data_root = "/gpfs/radev/pi/nandy/jadi_gibbs_data/VideoTracker_SocialInter/"
bhv_data_task_folders = ["marmoset_tracking_bhv_data_from_task_code",
                         "marmoset_tracking_bhv_data_forceManipulation_task"]


def _load_session_bhv_json(date_tgt, animal1_filename, animal2_filename):
    """Load the trial record, behavioral data and session info tables of one session.

    Each task folder is tried in turn; inside a folder the file names are first
    looked up as <date>_<animal2>_<animal1>_* and then as <date>_<animal1>_<animal2>_*.
    A candidate is skipped when one of its three json files is missing or unreadable.

    Returns:
        (trial_record, bhv_data, session_info) as pandas DataFrames.

    Raises:
        FileNotFoundError: when no candidate provides all three files.
    """
    last_err = None
    for task_folder in bhv_data_task_folders:
        bhv_data_path = bhv_data_root+task_folder+"/"+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"/"
        for first_name, second_name in ((animal2_filename, animal1_filename),
                                        (animal1_filename, animal2_filename)):
            file_prefix = bhv_data_path+date_tgt+"_"+first_name+"_"+second_name
            try:
                trial_record_json = glob.glob(file_prefix+"_TrialRecord_"+"*.json")
                bhv_data_json = glob.glob(file_prefix+"_bhv_data_"+"*.json")
                session_info_json = glob.glob(file_prefix+"_session_info_"+"*.json")
                # an empty glob result raises IndexError here -> try the next candidate
                trial_record = pd.read_json(trial_record_json[0])
                bhv_data = pd.read_json(bhv_data_json[0])
                session_info = pd.read_json(session_info_json[0])
                return trial_record, bhv_data, session_info
            except (IndexError, ValueError, OSError) as err:
                last_err = err
    raise FileNotFoundError("no complete set of behavioral json files found for session "+date_tgt) from last_err


if prepare_input_data:

    # silences all warnings for the rest of the session (same as before, now set once)
    warnings.filterwarnings('ignore')

    # start from empty pools so a partially loaded cache is never appended to
    # pooled table used when all temporal resolutions are merged
    DBN_input_data_merged = {typename: [] for typename in DBN_group_typenames}
    # one separate pool per temporal resolution otherwise, so that tables built with
    # different time bins are never mixed and each one is saved under its own name
    DBN_input_data_byresolu = {resolu: {typename: [] for typename in DBN_group_typenames}
                               for resolu in temp_resolus}

    for idate in np.arange(0,ndates,1):
        date_tgt = dates_list[idate]
        session_start_time = session_start_times[idate]

        # load behavioral results
        trial_record, bhv_data, session_info = _load_session_bhv_json(date_tgt, animal1_filename, animal2_filename)

        # get animal info
        animal1 = session_info['lever1_animal'][0].lower()
        animal2 = session_info['lever2_animal'][0].lower()

        # clean up the trial_record: keep the first row of every trial number 1..max
        # (DataFrame.append was removed in pandas 2.0, so collect the rows and concat once)
        first_rows = [trial_record[trial_record['trial_number']==itrial+1].iloc[[0]]
                      for itrial in np.arange(0,np.max(trial_record['trial_number']),1)]
        if first_rows:
            trial_record_clean = pd.concat(first_rows)
        else:
            trial_record_clean = trial_record.iloc[0:0]
        trial_record_clean = trial_record_clean.reset_index(drop = True)

        # change bhv_data time to the absolute time
        # the new times share bhv_data's index and are written with .loc; the former
        # chained assignment (df[col][mask] = ...) does nothing under pandas copy-on-write
        abs_time_points = pd.Series(np.zeros(np.shape(bhv_data)[0]), index=bhv_data.index)
        for itrial in np.arange(0,np.max(trial_record_clean['trial_number']),1):
            ind = bhv_data["trial_number"]==itrial+1
            abs_time_points.loc[ind] = bhv_data.loc[ind,"time_points"] + trial_record_clean["trial_starttime"].iloc[itrial]
        bhv_data["time_points"] = abs_time_points
        # rows that were not assigned to any trial keep time 0 and are dropped
        bhv_data = bhv_data[bhv_data["time_points"] != 0]

        # get task type and cooperation threshold
        try:
            coop_thres = session_info["pulltime_thres"][0]
            tasktype = session_info["task_type"][0]
        except (KeyError, IndexError):
            # session info without these fields: fall back to threshold 0 / task type 1
            coop_thres = 0
            tasktype = 1

        # load behavioral event results
        print('load social gaze with '+cameraID+' only of '+date_tgt)
        bhv_events_folder = data_saved_folder+"bhv_events_singlecam_wholebody/"+animal1_fixedorder[0]+animal2_fixedorder[0]+"/"+cameraID+'/'+date_tgt+'/'
        with open(bhv_events_folder+'output_look_ornot.pkl', 'rb') as f:
            output_look_ornot = pickle.load(f)
        with open(bhv_events_folder+'output_allvectors.pkl', 'rb') as f:
            output_allvectors = pickle.load(f)
        with open(bhv_events_folder+'output_allangles.pkl', 'rb') as f:
            output_allangles = pickle.load(f)
        #
        look_at_other_or_not_merge = output_look_ornot['look_at_other_or_not_merge']
        look_at_tube_or_not_merge = output_look_ornot['look_at_tube_or_not_merge']
        look_at_lever_or_not_merge = output_look_ornot['look_at_lever_or_not_merge']
        # change the unit to second
        # NOTE(review): the frame count is always read from the 'dodson' entry - confirm
        # that this key exists for every animal pair
        for look_table in (look_at_other_or_not_merge, look_at_lever_or_not_merge, look_at_tube_or_not_merge):
            look_table['time_in_second'] = np.arange(0,np.shape(look_table['dodson'])[0],1)/fps - session_start_time

        # redefine the totalsess_time for the length of each recording (NOT! remove the session_start_time)
        totalsess_time = int(np.ceil(np.shape(look_at_other_or_not_merge['dodson'])[0]/fps))

        # find time point of behavioral events
        output_time_points_socialgaze ,output_time_points_levertube = bhv_events_timepoint_singlecam(bhv_data,look_at_other_or_not_merge,look_at_lever_or_not_merge,look_at_tube_or_not_merge)
        time_point_pull1 = output_time_points_socialgaze['time_point_pull1']
        time_point_pull2 = output_time_points_socialgaze['time_point_pull2']
        oneway_gaze1 = output_time_points_socialgaze['oneway_gaze1']
        oneway_gaze2 = output_time_points_socialgaze['oneway_gaze2']
        mutual_gaze1 = output_time_points_socialgaze['mutual_gaze1']
        mutual_gaze2 = output_time_points_socialgaze['mutual_gaze2']

        if mergetempRos:
            temp_resolus = [0.5,1,1.5,2] # temporal resolution in the DBN model, eg: 0.5 means 500ms
            # use bhv event to decide temporal resolution
            #
            #low_lim,up_lim,_ = bhv_events_interval(totalsess_time, session_start_time, time_point_pull1, time_point_pull2, oneway_gaze1, oneway_gaze2, mutual_gaze1, mutual_gaze2)
            #temp_resolus = temp_resolus = np.arange(low_lim,up_lim,0.1)

        ntemp_reses = np.shape(temp_resolus)[0]

        # try different temporal resolutions
        for temp_resolu in temp_resolus:

            # animal order is swapped when this session's lever-1 animal is not the fixed-order animal1
            if np.isin(animal1,animal1_fixedorder):
                bhv_df_itr,_,_ = train_DBN_multiLag_create_df_only(totalsess_time, session_start_time, temp_resolu, time_point_pull1, time_point_pull2, oneway_gaze1, oneway_gaze2, mutual_gaze1, mutual_gaze2)
            else:
                bhv_df_itr,_,_ = train_DBN_multiLag_create_df_only(totalsess_time, session_start_time, temp_resolu, time_point_pull2, time_point_pull1, oneway_gaze2, oneway_gaze1, mutual_gaze2, mutual_gaze1)
            bhv_df = bhv_df_itr

            # pool this table goes into: shared across resolutions only when merging them
            if mergetempRos:
                DBN_input_pool = DBN_input_data_merged
            else:
                DBN_input_pool = DBN_input_data_byresolu.setdefault(
                    temp_resolu, {typename: [] for typename in DBN_group_typenames})

            # merge sessions from the same condition
            for iDBN_group_typename, iDBN_group_typeID, iDBN_group_cothres in zip(
                    DBN_group_typenames, DBN_group_typeIDs, DBN_group_coopthres):

                if (tasktype!=3):
                    in_group = (tasktype==iDBN_group_typeID)
                else:
                    # cooperation sessions are told apart by their threshold; the type ID is
                    # checked as well so that they can never land in a non-coop group that
                    # merely shares the same threshold value
                    in_group = (iDBN_group_typeID==3) and (coop_thres==iDBN_group_cothres)
                if not in_group:
                    continue

                if (len(DBN_input_pool[iDBN_group_typename])==0):
                    DBN_input_pool[iDBN_group_typename] = bhv_df
                else:
                    DBN_input_pool[iDBN_group_typename] = pd.concat([DBN_input_pool[iDBN_group_typename],bhv_df])

    # save data
    if 1:
        data_saved_subfolder = data_saved_folder+'data_saved_singlecam_wholebody'+savefile_sufix+'_3lags/'+cameraID+'/'+animal1_fixedorder[0]+animal2_fixedorder[0]+'/'
        os.makedirs(data_saved_subfolder, exist_ok=True)
        if not mergetempRos:
            # one file per temporal resolution; after the loop DBN_input_data_alltypes
            # holds the table of the last resolution
            for temp_resolu, DBN_input_data_alltypes in DBN_input_data_byresolu.items():
                with open(data_saved_subfolder+'/DBN_input_data_alltypes_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'_'+str(temp_resolu)+'sReSo.pkl', 'wb') as f:
                    pickle.dump(DBN_input_data_alltypes, f)
        else:
            DBN_input_data_alltypes = DBN_input_data_merged
            with open(data_saved_subfolder+'/DBN_input_data_alltypes_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'_mergeTempsReSo.pkl', 'wb') as f:
                pickle.dump(DBN_input_data_alltypes, f)

In [None]:
# display the DBN input data: a dict with one entry per DBN group name
DBN_input_data_alltypes


### run the DBN model on the combined session data set

#### a test run

In [None]:
# run DBN on the large table with merged sessions
# quick test run: one starting graph, one bootstrap, the smallest sampling size
# and the first DBN group only

mergetempRos = 0 # 1: merge different time bins

moreSampSize = 0 # 1: use more sample size (more than just minimal row number and max row number)

num_starting_points = 1 # number of random starting points/graphs
nbootstraps = 1

if 1:

    if moreSampSize:
        # different data (down/re)sampling numbers
        samplingsizes = np.arange(1100,3000,100)
        samplingsizes_name = list(map(str, samplingsizes))
        nsamplings = np.shape(samplingsizes)[0]

    # results, keyed by (temporal resolution, sampling size name)
    weighted_graphs_diffTempRo_diffSampSize = {}
    weighted_graphs_shuffled_diffTempRo_diffSampSize = {}
    sig_edges_diffTempRo_diffSampSize = {}
    DAGscores_diffTempRo_diffSampSize = {}
    DAGscores_shuffled_diffTempRo_diffSampSize = {}

    totalsess_time = 600 # total session time in s
    # temp_resolus = [0.5,1,1.5,2] # temporal resolution in the DBN model, eg: 0.5 means 500ms
    temp_resolus = [1] # temporal resolution in the DBN model, eg: 0.5 means 500ms
    ntemp_reses = np.shape(temp_resolus)[0]

    # try different temporal resolutions, remember to use the same settings as in the previous ones
    for temp_resolu in temp_resolus:

        data_saved_subfolder = data_saved_folder+'data_saved_singlecam_wholebody'+savefile_sufix+'_3lags/'+cameraID+'/'+animal1_fixedorder[0]+animal2_fixedorder[0]+'/'
        if not mergetempRos:
            with open(data_saved_subfolder+'/DBN_input_data_alltypes_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'_'+str(temp_resolu)+'sReSo.pkl', 'rb') as f:
                DBN_input_data_alltypes = pickle.load(f)
        else:
            with open(data_saved_subfolder+'/DBN_input_data_alltypes_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'_mergeTempsReSo.pkl', 'rb') as f:
                DBN_input_data_alltypes = pickle.load(f)

        # only try two sample sizes - minimal row number (require data downsample) and maximal row number (require data upsample)
        if not moreSampSize:
            key_to_value_lengths = {k:len(v) for k, v in DBN_input_data_alltypes.items()}
            key_to_value_lengths_array = np.fromiter(key_to_value_lengths.values(),dtype=float)
            # groups without any data must not define the minimal sample size
            key_to_value_lengths_array[key_to_value_lengths_array==0]=np.nan
            # both sizes are rounded down to a multiple of 100 rows
            min_samplesize = np.nanmin(key_to_value_lengths_array)
            min_samplesize = int(min_samplesize/100)*100
            max_samplesize = np.nanmax(key_to_value_lengths_array)
            max_samplesize = int(max_samplesize/100)*100
            samplingsizes = [min_samplesize,max_samplesize]
            samplingsizes_name = ['min_row_number','max_row_number']   
            nsamplings = np.shape(samplingsizes)[0]
            print(samplingsizes)

        # try different down/re-sampling size
        # for jj in np.arange(0,nsamplings,1):
        for jj in np.arange(0,1,1):

            isamplingsize = samplingsizes[jj]

            # one independent empty container per DBN group
            # (dict.fromkeys(names, []) would make all groups share the same list object)
            DAGs_alltypes = {typename: [] for typename in DBN_group_typenames}
            DAGs_shuffle_alltypes = {typename: [] for typename in DBN_group_typenames}
            DAGs_scores_alltypes = {typename: [] for typename in DBN_group_typenames}
            DAGs_shuffle_scores_alltypes = {typename: [] for typename in DBN_group_typenames}

            weighted_graphs_alltypes = {typename: [] for typename in DBN_group_typenames}
            weighted_graphs_shuffled_alltypes = {typename: [] for typename in DBN_group_typenames}
            sig_edges_alltypes = {typename: [] for typename in DBN_group_typenames}

            # different session conditions (aka DBN groups)
            # for iDBN_group in np.arange(0,nDBN_groups,1):
            for iDBN_group in np.arange(0,1,1):
                iDBN_group_typename = DBN_group_typenames[iDBN_group] 
                iDBN_group_typeID =  DBN_group_typeIDs[iDBN_group] 
                iDBN_group_cothres = DBN_group_coopthres[iDBN_group] 

                bhv_df_all = DBN_input_data_alltypes[iDBN_group_typename]

                # define DBN graph structures; make sure they are the same as in the train_DBN_multiLag
                colnames = list(bhv_df_all.columns)
                eventnames = ["pull1","pull2","owgaze1","owgaze2"]
                nevents = np.size(eventnames)

                all_pops = list(bhv_df_all.columns)
                # every column may be a source, only the 't3' columns may be a target;
                # this also allows within-layer edges between two 't3' columns
                # from_pops = [pop for pop in all_pops if not pop.endswith('t3')]
                from_pops = [pop for pop in all_pops]
                to_pops = [pop for pop in all_pops if pop.endswith('t3')]
                causal_whitelist = [(from_pop,to_pop) for from_pop in from_pops for to_pop in to_pops]
                # remove cycle edge (to self)
                causal_whitelist = [edge for edge in causal_whitelist if edge[0] != edge[1]]

                nFromNodes = np.shape(from_pops)[0]
                nToNodes = np.shape(to_pops)[0]

                DAGs_randstart = np.zeros((num_starting_points, nFromNodes, nToNodes))
                DAGs_randstart_shuffle = np.zeros((num_starting_points, nFromNodes, nToNodes))
                score_randstart = np.zeros((num_starting_points))
                score_randstart_shuffle = np.zeros((num_starting_points))

                # step 1: randomize the starting point for num_starting_points times
                for istarting_points in np.arange(0,num_starting_points,1):

                    # try different down/re-sampling size
                    bhv_df = bhv_df_all.sample(isamplingsize,replace = True, random_state = istarting_points) # take the subset for DBN training
                    aic = AicScore(bhv_df)

                    #Anirban(Alec) shuffle, slow
                    # NOTE: seeded with the wall clock, so the shuffled control is not reproducible
                    bhv_df_shuffle, df_shufflekeys = EfficientShuffle(bhv_df,round(time()))
                    aic_shuffle = AicScore(bhv_df_shuffle)

                    np.random.seed(istarting_points)
                    random.seed(istarting_points)
                    starting_edges = random.sample(causal_whitelist, np.random.randint(1,len(causal_whitelist)))
                    starting_graph = DAG()
                    starting_graph.add_nodes_from(nodes=all_pops)
                    # the whitelist holds both directions of the within-layer edges, so a random
                    # draw can contain cycles; add the edges one by one and skip every edge that
                    # would close a cycle, so the starting graph is a valid DAG
                    for from_pop, to_pop in starting_edges:
                        if not nx.has_path(starting_graph, to_pop, from_pop):
                            starting_graph.add_edge(from_pop, to_pop)

                    best_model,edges,DAGs = train_DBN_multiLag_training_only(bhv_df,starting_graph,colnames,eventnames,from_pops,to_pops)           
                    DAGs[0][np.isnan(DAGs[0])]=0

                    DAGs_randstart[istarting_points,:,:] = DAGs[0]
                    score_randstart[istarting_points] = aic.score(best_model)

                    # step 2: add the shuffled data results
                    # shuffled bhv_df
                    best_model,edges,DAGs = train_DBN_multiLag_training_only(bhv_df_shuffle,starting_graph,colnames,eventnames,from_pops,to_pops)           
                    DAGs[0][np.isnan(DAGs[0])]=0

                    DAGs_randstart_shuffle[istarting_points,:,:] = DAGs[0]
                    score_randstart_shuffle[istarting_points] = aic_shuffle.score(best_model)

                DAGs_alltypes[iDBN_group_typename] = DAGs_randstart 
                DAGs_shuffle_alltypes[iDBN_group_typename] = DAGs_randstart_shuffle

                DAGs_scores_alltypes[iDBN_group_typename] = score_randstart
                DAGs_shuffle_scores_alltypes[iDBN_group_typename] = score_randstart_shuffle

                weighted_graphs = get_weighted_dags(DAGs_alltypes[iDBN_group_typename],nbootstraps)
                weighted_graphs_shuffled = get_weighted_dags(DAGs_shuffle_alltypes[iDBN_group_typename],nbootstraps)
                sig_edges = get_significant_edges(weighted_graphs,weighted_graphs_shuffled)

                weighted_graphs_alltypes[iDBN_group_typename] = weighted_graphs
                weighted_graphs_shuffled_alltypes[iDBN_group_typename] = weighted_graphs_shuffled
                sig_edges_alltypes[iDBN_group_typename] = sig_edges

            DAGscores_diffTempRo_diffSampSize[(str(temp_resolu),samplingsizes_name[jj])] = DAGs_scores_alltypes
            DAGscores_shuffled_diffTempRo_diffSampSize[(str(temp_resolu),samplingsizes_name[jj])] = DAGs_shuffle_scores_alltypes

            weighted_graphs_diffTempRo_diffSampSize[(str(temp_resolu),samplingsizes_name[jj])] = weighted_graphs_alltypes
            weighted_graphs_shuffled_diffTempRo_diffSampSize[(str(temp_resolu),samplingsizes_name[jj])] = weighted_graphs_shuffled_alltypes
            sig_edges_diffTempRo_diffSampSize[(str(temp_resolu),samplingsizes_name[jj])] = sig_edges_alltypes

    print(weighted_graphs_diffTempRo_diffSampSize)
            
   

In [None]:
# build the whitelist of candidate edges and derive a cycle-free subset of it
all_pops = list(bhv_df.columns)

# every from-node paired with every to-node, self edges excluded
causal_whitelist = [(src, dst) for src in from_pops for dst in to_pops if src != dst]

# directed graph holding all whitelisted edges
starting_graph = nx.DiGraph()
starting_graph.add_edges_from(causal_whitelist)

# break cycles one at a time: drop the last edge of each cycle that is found,
# until networkx reports that none is left
while True:
    try:
        cycle = nx.find_cycle(starting_graph)
    except nx.NetworkXNoCycle:
        print("No cycles detected in the graph.")
        break
    print(f"Cycle detected and removing edge: {cycle[-1]}")
    starting_graph.remove_edge(*cycle[-1])

# whitelist entries that survived the cycle removal, in whitelist order
edges = list(starting_graph.edges)
surviving_edges = set(edges)
causal_whitelist_no_cycles = [pair for pair in causal_whitelist if pair in surviving_edges]

In [None]:
# notebook display cell: echo the current causal_whitelist as the cell output
causal_whitelist 

#### run on the entire population

In [None]:
# run DBN on the large table with merged sessions
#
# Workflow of this cell:
#   1. try to load previously saved DBN results (five pickle files) from the
#      '_3lags_withinLayerEdges' results folder;
#   2. if they cannot be loaded, re-run the DBN structure learning for every
#      temporal resolution / sampling size / session condition (on both the
#      resampled data and a shuffled copy of it) and save the results there.

mergetempRos = 0 # 1: merge different time bins

moreSampSize = 0 # 1: use more sample size (more than just minimal row number and max row number)

num_starting_points = 100 # number of random starting points/graphs
nbootstraps = 95

# The five result tables kept on disk. Files are named
# <name>_diffTempRo_diffSampSize_<animal initials>[_moreSampSize].pkl
DBN_result_names = ['DAGscores','DAGscores_shuffled','weighted_graphs','weighted_graphs_shuffled','sig_edges']
DBN_result_filetail = animal1_fixedorder[0]+animal2_fixedorder[0]+('_moreSampSize.pkl' if moreSampSize else '.pkl')

# ---- step 0: try to reuse saved results ----
try:
    data_saved_subfolder = data_saved_folder+'data_saved_singlecam_wholebody'+savefile_sufix+'_3lags_withinLayerEdges/'+cameraID+'/'+animal1_fixedorder[0]+animal2_fixedorder[0]+'/'
    os.makedirs(data_saved_subfolder, exist_ok=True)

    loaded_results = {}
    for result_name in DBN_result_names:
        with open(data_saved_subfolder+'/'+result_name+'_diffTempRo_diffSampSize_'+DBN_result_filetail, 'rb') as f:
            loaded_results[result_name] = pickle.load(f)

    DAGscores_diffTempRo_diffSampSize = loaded_results['DAGscores']
    DAGscores_shuffled_diffTempRo_diffSampSize = loaded_results['DAGscores_shuffled']
    weighted_graphs_diffTempRo_diffSampSize = loaded_results['weighted_graphs']
    weighted_graphs_shuffled_diffTempRo_diffSampSize = loaded_results['weighted_graphs_shuffled']
    sig_edges_diffTempRo_diffSampSize = loaded_results['sig_edges']
    results_loaded = True

except Exception as load_err:
    # Any load problem (missing or unreadable file, ...) falls back to recomputing,
    # but the reason is reported and a keyboard interrupt is no longer swallowed.
    print(f"Saved DBN results could not be loaded ({load_err!r}); re-running the DBN analysis.")
    results_loaded = False

# ---- step 1: (re)compute when nothing could be loaded ----
if not results_loaded:
    if moreSampSize:
        # different data (down/re)sampling numbers
        samplingsizes = np.arange(1100,3000,100)
        samplingsizes_name = list(map(str, samplingsizes))
        nsamplings = np.shape(samplingsizes)[0]

    weighted_graphs_diffTempRo_diffSampSize = {}
    weighted_graphs_shuffled_diffTempRo_diffSampSize = {}
    sig_edges_diffTempRo_diffSampSize = {}
    DAGscores_diffTempRo_diffSampSize = {}
    DAGscores_shuffled_diffTempRo_diffSampSize = {}

    totalsess_time = 600 # total session time in s
    # temp_resolus = [0.5,1,1.5,2] # temporal resolution in the DBN model, eg: 0.5 means 500ms
    temp_resolus = [1] # temporal resolution in the DBN model, eg: 0.5 means 500ms
    ntemp_reses = np.shape(temp_resolus)[0]

    # try different temporal resolutions, remember to use the same settings as in the previous ones
    for temp_resolu in temp_resolus:

        # NOTE: the DBN input tables are read from the '_3lags' folder,
        # whereas the results of this cell go to '_3lags_withinLayerEdges'
        data_saved_subfolder = data_saved_folder+'data_saved_singlecam_wholebody'+savefile_sufix+'_3lags/'+cameraID+'/'+animal1_fixedorder[0]+animal2_fixedorder[0]+'/'
        if not mergetempRos:
            with open(data_saved_subfolder+'/DBN_input_data_alltypes_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'_'+str(temp_resolu)+'sReSo.pkl', 'rb') as f:
                DBN_input_data_alltypes = pickle.load(f)
        else:
            with open(data_saved_subfolder+'/DBN_input_data_alltypes_'+animal1_fixedorder[0]+animal2_fixedorder[0]+'_mergeTempsReSo.pkl', 'rb') as f:
                DBN_input_data_alltypes = pickle.load(f)

        # only try the minimal row number across conditions (requires data downsampling);
        # empty tables are ignored and the size is rounded down to a multiple of 100
        if not moreSampSize:
            key_to_value_lengths = {k:len(v) for k, v in DBN_input_data_alltypes.items()}
            key_to_value_lengths_array = np.fromiter(key_to_value_lengths.values(),dtype=float)
            key_to_value_lengths_array[key_to_value_lengths_array==0]=np.nan
            min_samplesize = np.nanmin(key_to_value_lengths_array)
            min_samplesize = int(min_samplesize/100)*100
            max_samplesize = np.nanmax(key_to_value_lengths_array)
            max_samplesize = int(max_samplesize/100)*100
            # samplingsizes = [min_samplesize,max_samplesize]
            # samplingsizes_name = ['min_row_number','max_row_number']
            samplingsizes = [min_samplesize,]
            samplingsizes_name = ['min_row_number',]
            nsamplings = np.shape(samplingsizes)[0]
            print(samplingsizes)

        # try different down/re-sampling size
        for jj in range(nsamplings):

            isamplingsize = samplingsizes[jj]

            # one independent placeholder per condition (no list shared between keys)
            DAGs_alltypes = {typename: [] for typename in DBN_group_typenames}
            DAGs_shuffle_alltypes = {typename: [] for typename in DBN_group_typenames}
            DAGs_scores_alltypes = {typename: [] for typename in DBN_group_typenames}
            DAGs_shuffle_scores_alltypes = {typename: [] for typename in DBN_group_typenames}

            weighted_graphs_alltypes = {typename: [] for typename in DBN_group_typenames}
            weighted_graphs_shuffled_alltypes = {typename: [] for typename in DBN_group_typenames}
            sig_edges_alltypes = {typename: [] for typename in DBN_group_typenames}

            # different session conditions (aka DBN groups)
            for iDBN_group in range(nDBN_groups):
                iDBN_group_typename = DBN_group_typenames[iDBN_group]
                iDBN_group_typeID =  DBN_group_typeIDs[iDBN_group]
                iDBN_group_cothres = DBN_group_coopthres[iDBN_group]

                try:
                    bhv_df_all = DBN_input_data_alltypes[iDBN_group_typename]

                    # define DBN graph structures; make sure they are the same as in the train_DBN_multiLag
                    colnames = list(bhv_df_all.columns)
                    eventnames = ["pull1","pull2","owgaze1","owgaze2"]
                    nevents = np.size(eventnames)

                    # within-layer edges are allowed: every column may be a source,
                    # only the columns ending in 't3' may be a target
                    all_pops = list(bhv_df_all.columns)
                    from_pops = [pop for pop in all_pops]
                    to_pops = [pop for pop in all_pops if pop.endswith('t3')]
                    # NOTE(review): this whitelist still contains self pairs (x_t3, x_t3) and
                    # both directions between 't3' columns, so a sampled starting graph may
                    # not be acyclic - confirm that this is handled by the training function
                    causal_whitelist = [(from_pop,to_pop) for from_pop in from_pops for to_pop in to_pops]

                    nFromNodes = np.shape(from_pops)[0]
                    nToNodes = np.shape(to_pops)[0]

                    DAGs_randstart = np.zeros((num_starting_points, nFromNodes, nToNodes))
                    DAGs_randstart_shuffle = np.zeros((num_starting_points, nFromNodes, nToNodes))
                    score_randstart = np.zeros((num_starting_points))
                    score_randstart_shuffle = np.zeros((num_starting_points))

                    # step 1: randomize the starting point for num_starting_points times
                    # (plain Python ints, because random.seed rejects numpy integers on Python >= 3.11)
                    for istarting_points in range(num_starting_points):

                        # resample the table to the requested size (with replacement)
                        bhv_df = bhv_df_all.sample(isamplingsize,replace = True, random_state = istarting_points) # take the subset for DBN training
                        aic = AicScore(bhv_df)

                        #Anirban(Alec) shuffle, slow; seeded from the clock, i.e. different on every run
                        bhv_df_shuffle, df_shufflekeys = EfficientShuffle(bhv_df,round(time()))
                        aic_shuffle = AicScore(bhv_df_shuffle)

                        # reproducible random starting graph: a random subset of the whitelist
                        np.random.seed(istarting_points)
                        random.seed(istarting_points)
                        starting_edges = random.sample(causal_whitelist, np.random.randint(1,len(causal_whitelist)))
                        starting_graph = DAG()
                        starting_graph.add_nodes_from(nodes=all_pops)
                        starting_graph.add_edges_from(ebunch=starting_edges)

                        best_model,edges,DAGs = train_DBN_multiLag_training_only(bhv_df,starting_graph,colnames,eventnames,from_pops,to_pops)
                        DAGs[0][np.isnan(DAGs[0])]=0

                        DAGs_randstart[istarting_points,:,:] = DAGs[0]
                        score_randstart[istarting_points] = aic.score(best_model)

                        # step 2: add the shuffled data results (same starting graph, shuffled bhv_df)
                        best_model,edges,DAGs = train_DBN_multiLag_training_only(bhv_df_shuffle,starting_graph,colnames,eventnames,from_pops,to_pops)
                        DAGs[0][np.isnan(DAGs[0])]=0

                        DAGs_randstart_shuffle[istarting_points,:,:] = DAGs[0]
                        score_randstart_shuffle[istarting_points] = aic_shuffle.score(best_model)

                    DAGs_alltypes[iDBN_group_typename] = DAGs_randstart
                    DAGs_shuffle_alltypes[iDBN_group_typename] = DAGs_randstart_shuffle

                    DAGs_scores_alltypes[iDBN_group_typename] = score_randstart
                    DAGs_shuffle_scores_alltypes[iDBN_group_typename] = score_randstart_shuffle

                    weighted_graphs = get_weighted_dags(DAGs_alltypes[iDBN_group_typename],nbootstraps)
                    weighted_graphs_shuffled = get_weighted_dags(DAGs_shuffle_alltypes[iDBN_group_typename],nbootstraps)
                    sig_edges = get_significant_edges(weighted_graphs,weighted_graphs_shuffled)

                    weighted_graphs_alltypes[iDBN_group_typename] = weighted_graphs
                    weighted_graphs_shuffled_alltypes[iDBN_group_typename] = weighted_graphs_shuffled
                    sig_edges_alltypes[iDBN_group_typename] = sig_edges

                except Exception as group_err:
                    # keep going with the other conditions, but say why this one is empty
                    print(f"DBN training failed for '{iDBN_group_typename}' ({group_err!r}); storing empty results.")

                    DAGs_alltypes[iDBN_group_typename] = []
                    DAGs_shuffle_alltypes[iDBN_group_typename] = []

                    DAGs_scores_alltypes[iDBN_group_typename] = []
                    DAGs_shuffle_scores_alltypes[iDBN_group_typename] = []

                    weighted_graphs_alltypes[iDBN_group_typename] = []
                    weighted_graphs_shuffled_alltypes[iDBN_group_typename] = []
                    sig_edges_alltypes[iDBN_group_typename] = []

            # results are keyed by (temporal resolution as string, sampling size name)
            DAGscores_diffTempRo_diffSampSize[(str(temp_resolu),samplingsizes_name[jj])] = DAGs_scores_alltypes
            DAGscores_shuffled_diffTempRo_diffSampSize[(str(temp_resolu),samplingsizes_name[jj])] = DAGs_shuffle_scores_alltypes

            weighted_graphs_diffTempRo_diffSampSize[(str(temp_resolu),samplingsizes_name[jj])] = weighted_graphs_alltypes
            weighted_graphs_shuffled_diffTempRo_diffSampSize[(str(temp_resolu),samplingsizes_name[jj])] = weighted_graphs_shuffled_alltypes
            sig_edges_diffTempRo_diffSampSize[(str(temp_resolu),samplingsizes_name[jj])] = sig_edges_alltypes


    # save data
    savedata = 1
    if savedata:
        data_saved_subfolder = data_saved_folder+'data_saved_singlecam_wholebody'+savefile_sufix+'_3lags_withinLayerEdges/'+cameraID+'/'+animal1_fixedorder[0]+animal2_fixedorder[0]+'/'
        os.makedirs(data_saved_subfolder, exist_ok=True)

        results_to_save = {
            'DAGscores': DAGscores_diffTempRo_diffSampSize,
            'DAGscores_shuffled': DAGscores_shuffled_diffTempRo_diffSampSize,
            'weighted_graphs': weighted_graphs_diffTempRo_diffSampSize,
            'weighted_graphs_shuffled': weighted_graphs_shuffled_diffTempRo_diffSampSize,
            'sig_edges': sig_edges_diffTempRo_diffSampSize,
        }
        for result_name in DBN_result_names:
            with open(data_saved_subfolder+'/'+result_name+'_diffTempRo_diffSampSize_'+DBN_result_filetail, 'wb') as f:
                pickle.dump(results_to_save[result_name], f)


### plot graphs - show the edges with arrows; show the best time bin and row number; show the three time lags separately

In [None]:
# ONLY FOR PLOT!!
# Draws, for every time lag (rows) and session condition (columns), the graph of
# significant DBN edges (arrows / self loops) and below it a heatmap of the same
# edge weights, then optionally saves the figure as a pdf.

# imported here so that this cell does not depend on an earlier cell for these two names
import matplotlib.patches as mpatches
from matplotlib.collections import PatchCollection

# define DBN related summarizing variables
DBN_group_typenames = ['self','coop(1s)','no-vision']
DBN_group_typeIDs  =  [1,3,5]
DBN_group_coopthres = [0,1,0]
nDBN_groups = np.shape(DBN_group_typenames)[0]


# make sure these variables are the same as in the previous steps
# temp_resolus = [0.5,1,1.5,2] # temporal resolution in the DBN model, eg: 0.5 means 500ms
temp_resolus = [1] # temporal resolution in the DBN model, eg: 0.5 means 500ms
ntemp_reses = np.shape(temp_resolus)[0]
#
if moreSampSize:
    # different data (down/re)sampling numbers
    # samplingsizes = np.arange(1100,3000,100)
    samplingsizes = [1100]
    samplingsizes_name = list(map(str, samplingsizes))
else:
    samplingsizes_name = ['min_row_number']
nsamplings = np.shape(samplingsizes_name)[0]

# make sure these variables are consistent with the train_DBN_alec.py settings
# eventnames = ["pull1","pull2","gaze1","gaze2"]
eventnames = ["M1pull","M2pull","M1gaze","M2gaze"]
eventnode_locations = [[0,1],[1,1],[0,0],[1,0]]
eventname_locations = [[-0.5,1.0],[1.2,1],[-0.6,0],[1.2,0]]
# indicate where edge starts
# for the self edge, it's the center of the self loop
nodearrow_locations = [[[0.00,1.25],[0.25,1.10],[-.10,0.75],[0.15,0.65]],
                       [[0.75,1.00],[1.00,1.25],[0.85,0.65],[1.10,0.75]],
                       [[0.00,0.25],[0.25,0.35],[0.00,-.25],[0.25,-.10]],
                       [[0.75,0.35],[1.00,0.25],[0.75,0.00],[1.00,-.25]]]
# indicate where edge goes
# for the self edge, it's the theta1 and theta2 (with fixed radius)
nodearrow_directions = [[[ -45,-180],[0.50,0.00],[0.00,-.50],[0.50,-.50]],
                        [[-.50,0.00],[ -45,-180],[-.50,-.50],[0.00,-.50]],
                        [[0.00,0.50],[0.50,0.50],[ 180,  45],[0.50,0.00]],
                        [[-.50,0.50],[0.00,0.50],[-.50,0.00],[ 180,  45]]]

nevents = np.size(eventnames)
# eventnodes_color = ['b','r','y','g']
eventnodes_color = ['#BF3EFF','#FF7F00','#BF3EFF','#FF7F00']
eventnodes_shape = ["o","o","^","^"]

savefigs = 1

# different session conditions (aka DBN groups)
# different time lags (t_-3, t_-2 and t_-1)
# two subplot rows per time lag: the graph on top, the weight heatmap below it
fig, axs = plt.subplots(6,nDBN_groups)
fig.set_figheight(48)
fig.set_figwidth(8*nDBN_groups)

time_lags = ['t_-3','t_-2','t_-1']
# rows of the (from node x to node) weight matrix that belong to each time lag
fromRowIDs =[[0,1,2,3],[4,5,6,7],[8,9,10,11]]
ntime_lags = np.shape(time_lags)[0]

temp_resolu = temp_resolus[0]
j_sampsize_name = samplingsizes_name[0]

# edge colormap, looked up once; plt.get_cmap is used instead of mpl.cm.get_cmap,
# which has been deprecated by matplotlib
clmap = plt.get_cmap('Greens')

for ilag in range(ntime_lags):

    time_lag_name = time_lags[ilag]
    fromRowID = fromRowIDs[ilag]

    for iDBN_group in range(nDBN_groups):

        iDBN_group_typename = DBN_group_typenames[iDBN_group]
        ax_graph = axs[ilag*2+0,iDBN_group]
        ax_heat = axs[ilag*2+1,iDBN_group]

        try:

            weighted_graphs_tgt = weighted_graphs_diffTempRo_diffSampSize[(str(temp_resolu),j_sampsize_name)][iDBN_group_typename]
            weighted_graphs_shuffled_tgt = weighted_graphs_shuffled_diffTempRo_diffSampSize[(str(temp_resolu),j_sampsize_name)][iDBN_group_typename]
            # significance is recomputed here rather than read from sig_edges_diffTempRo_diffSampSize
            sig_edges_tgt = get_significant_edges(weighted_graphs_tgt,weighted_graphs_shuffled_tgt)

            #sig_edges_tgt = sig_edges_tgt*((weighted_graphs_tgt.mean(axis=0)>0.5)*1)

            # mean weight over the first axis, masked by significance, restricted to this lag's rows
            sig_avg_dags = weighted_graphs_tgt.mean(axis = 0) * sig_edges_tgt
            sig_avg_dags = sig_avg_dags[fromRowID,:]

            # ---- graph panel ----
            ax_graph.set_title(iDBN_group_typename,fontsize=18)
            ax_graph.set_xlim([-0.5,1.5])
            ax_graph.set_ylim([-0.5,1.5])
            ax_graph.set_xticks([])
            ax_graph.set_xticklabels([])
            ax_graph.set_yticks([])
            ax_graph.set_yticklabels([])
            for spine_name in ('top','right','bottom','left'):
                ax_graph.spines[spine_name].set_visible(False)

            # plot the event nodes
            for ieventnode in range(nevents):
                ax_graph.plot(eventnode_locations[ieventnode][0],eventnode_locations[ieventnode][1],
                              eventnodes_shape[ieventnode],markersize=60,markerfacecolor=eventnodes_color[ieventnode],
                              markeredgecolor='none')

            # plot the event edges - once per edge (this used to sit inside the node
            # loop above, which drew every edge nevents times on top of itself)
            for ifromNode in range(nevents):
                for itoNode in range(nevents):
                    edge_weight_tgt = sig_avg_dags[ifromNode,itoNode]
                    if not edge_weight_tgt>0:
                        continue
                    edge_color = clmap(edge_weight_tgt)
                    edge_start = nodearrow_locations[ifromNode][itoNode]
                    edge_dir = nodearrow_directions[ifromNode][itoNode]

                    if ifromNode != itoNode:
                        # edge between two different nodes: straight arrow
                        ax_graph.arrow(edge_start[0],edge_start[1],
                                       edge_dir[0],edge_dir[1],
                                       head_width=0.08,
                                       width=0.04,
                                       color = edge_color)
                    else:
                        # self edge: ring segment centred at edge_start, edge_dir holds (theta1, theta2)
                        ring = mpatches.Wedge(edge_start,
                                              .1, edge_dir[0],
                                              edge_dir[1],
                                              width=0.04,
                                              color = edge_color)
                        p = PatchCollection(
                            [ring],
                            facecolor=edge_color,
                            edgecolor=edge_color
                        )
                        ax_graph.add_collection(p)
                        # add arrow head: pointing down for the first two nodes, up for the others
                        head_dy = -0.05 if ifromNode < 2 else 0.02
                        ax_graph.arrow(edge_start[0]-0.1+0.02*edge_weight_tgt,
                                       edge_start[1],
                                       0,head_dy,color=edge_color,
                                       head_width=0.08,width=0.04
                                       )

            # ---- heatmap for the weights ----
            sig_avg_dags_df = pd.DataFrame(sig_avg_dags)
            sig_avg_dags_df.columns = eventnames
            sig_avg_dags_df.index = eventnames
            vmin,vmax = 0,1

            norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
            im = ax_heat.pcolormesh(sig_avg_dags_df,cmap="Greens",norm=norm)
            # one colorbar per row, attached to the last column
            if iDBN_group == nDBN_groups-1:
                cax = ax_heat.inset_axes([1.04, 0.2, 0.05, 0.8])
                fig.colorbar(im, ax=ax_heat, cax=cax,label='edge confidence')

            ax_heat.axis('equal')
            ax_heat.set_xlabel('to Node',fontsize=14)
            ax_heat.set_xticks(np.arange(0.5,4.5,1))
            ax_heat.set_xticklabels(eventnames)
            if iDBN_group == 0:
                # the first column also carries the y labels and the time lag captions
                ax_heat.set_ylabel('from Node',fontsize=14)
                ax_heat.set_yticks(np.arange(0.5,4.5,1))
                ax_heat.set_yticklabels(eventnames)
                ax_heat.text(-1.5,1,time_lag_name+' time lag',rotation=90,fontsize=20)
                ax_graph.text(-1.25,0,time_lag_name+' time lag',rotation=90,fontsize=20)
            else:
                ax_heat.set_yticks([])
                ax_heat.set_yticklabels([])

        except Exception as plot_err:
            # still skip a panel that cannot be drawn (e.g. a condition without
            # results), but report it instead of leaving a silent blank
            print(f"Skipped panel {time_lag_name} / {iDBN_group_typename}: {plot_err!r}")
            continue

if savefigs:
    figsavefolder = data_saved_folder+'figs_for_3LagDBN_withinLayerEdges_and_bhv_singlecam_wholebodylabels_combinesessions_basicEvents/'+savefile_sufix+'/'+cameraID+'/'+animal1_fixedorder[0]+animal2_fixedorder[0]+'/'
    os.makedirs(figsavefolder, exist_ok=True)
    if moreSampSize:
        fig.savefig(figsavefolder+"threeTimeLag_DAGs_"+animal1_fixedorder[0]+animal2_fixedorder[0]+'_'+str(temp_resolu)+'_'+str(j_sampsize_name)+'_rows.pdf')
    else:
        fig.savefig(figsavefolder+"threeTimeLag_DAGs_"+animal1_fixedorder[0]+animal2_fixedorder[0]+'_'+str(temp_resolu)+'_'+j_sampsize_name+'.pdf')
            
            
            

### plot graphs - show the edge differences, use one condition as the base

In [None]:
# ONLY FOR PLOT!! 
# define DBN related summarizing variables
DBN_group_typenames = ['self','coop(1s)','no-vision']
DBN_group_typeIDs  =  [1,3,5]
DBN_group_coopthres = [0,1,0]
nDBN_groups = np.shape(DBN_group_typenames)[0]

# make sure these variables are the same as in the previous steps
# temp_resolus = [0.5,1,1.5,2] # temporal resolution in the DBN model, eg: 0.5 means 500ms
temp_resolus = [1] # temporal resolution in the DBN model, eg: 0.5 means 500ms
ntemp_reses = np.shape(temp_resolus)[0]
#
if moreSampSize:
    # different data (down/re)sampling numbers
    # samplingsizes = np.arange(1100,3000,100)
    samplingsizes = [1100]
    # samplingsizes = [100,500,1000,1500,2000,2500,3000]        
    # samplingsizes = [100,500]
    # samplingsizes_name = ['100','500','1000','1500','2000','2500','3000']
    samplingsizes_name = list(map(str, samplingsizes))
else:
    samplingsizes_name = ['min_row_number']   
nsamplings = np.shape(samplingsizes_name)[0]

# Reference condition: the edges of every DBN group are compared against it below.
basecondition = 'self' # 'self', 'coop(1s)'

# make sure these variables are consistent with the train_DBN_alec.py settings
# eventnames = ["pull1","pull2","gaze1","gaze2"]
eventnames = ["M1pull","M2pull","M1gaze","M2gaze"]
# (x, y) position of each node marker and of its text label
eventnode_locations = [[0,1],[1,1],[0,0],[1,0]]
eventname_locations = [[-0.5,1.0],[1.2,1],[-0.6,0],[1.2,0]]
# indicate where edge starts (indexed as [from node][to node])
# for the self edge, it's the center of the self loop
nodearrow_locations = [[[0.00,1.25],[0.25,1.10],[-.10,0.75],[0.15,0.65]],
                       [[0.75,1.00],[1.00,1.25],[0.85,0.65],[1.10,0.75]],
                       [[0.00,0.25],[0.25,0.35],[0.00,-.25],[0.25,-.10]],
                       [[0.75,0.35],[1.00,0.25],[0.75,0.00],[1.00,-.25]]]
# indicate where edge goes (indexed as [from node][to node])
# for the self edge, it's the theta1 and theta2 (with fixed radius)
nodearrow_directions = [[[ -45,-180],[0.50,0.00],[0.00,-.50],[0.50,-.50]],
                        [[-.50,0.00],[ -45,-180],[-.50,-.50],[0.00,-.50]],
                        [[0.00,0.50],[0.50,0.50],[ 180,  45],[0.50,0.00]],
                        [[-.50,0.50],[0.00,0.50],[-.50,0.00],[ 180,  45]]]

nevents = np.size(eventnames)
# eventnodes_color = ['b','r','y','g']
eventnodes_color = ['#BF3EFF','#FF7F00','#BF3EFF','#FF7F00']
eventnodes_shape = ["o","o","^","^"]

nFromNodes = nevents
nToNodes = nevents

savefigs = 1

# different time lags (t_-3, t_-2 and t_-1)
# fromRowIDs lists, for each lag, the rows of the edge matrix used as "from" nodes
time_lags = ['t_-3','t_-2','t_-1']
fromRowIDs =[[0,1,2,3],[4,5,6,7],[8,9,10,11]]
ntime_lags = np.shape(time_lags)[0]

# different session conditions (aka DBN groups) along the columns,
# two rows of panels per time lag (graph drawing and weight heatmap).
# squeeze=False keeps axs two-dimensional even with a single DBN group,
# so the axs[row, column] indexing used by the plotting loop always works.
fig, axs = plt.subplots(2*ntime_lags,nDBN_groups,squeeze=False)
fig.set_figheight(48)
fig.set_figwidth(8*nDBN_groups)

temp_resolu = temp_resolus[0]
j_sampsize_name = samplingsizes_name[0]

# edge weights (real and shuffled) of the base condition and its significant edges
weighted_graphs_tgt = weighted_graphs_diffTempRo_diffSampSize[(str(temp_resolu),j_sampsize_name)][basecondition]
weighted_graphs_shuffled_tgt = weighted_graphs_shuffled_diffTempRo_diffSampSize[(str(temp_resolu),j_sampsize_name)][basecondition]
#sig_edges_tgt = sig_edges_diffTempRo_diffSampSize[(str(temp_resolu),j_sampsize_name)][basecondition]
sig_edges_tgt = get_significant_edges(weighted_graphs_tgt,weighted_graphs_shuffled_tgt)

# sig_edges_tgt = sig_edges_tgt*((weighted_graphs_tgt.mean(axis=0)>0.5)*1)

# keep the base-condition results under their own names; the *_tgt names are
# reassigned for every DBN group in the plotting loop
weighted_graphs_base = weighted_graphs_tgt

sig_edges_base = sig_edges_tgt

sig_avg_dags_base =  weighted_graphs_base.mean(axis = 0) * sig_edges_base
    
# Draw, for every time lag (pairs of rows) and every DBN group (columns), the
# graph of significant edge changes relative to the base condition and the
# matching heatmap of the edge values.

# colormap shared by all arrows / self loops; looked up once through pyplot,
# which works on matplotlib releases that no longer provide mpl.cm.get_cmap
clmap = plt.get_cmap('bwr')

for ilag in np.arange(0,ntime_lags,1):

    time_lag_name = time_lags[ilag]
    fromRowID = fromRowIDs[ilag]

    for iDBN_group in np.arange(0,nDBN_groups,1):

        try:

            iDBN_group_typename = DBN_group_typenames[iDBN_group]

            weighted_graphs_tgt = weighted_graphs_diffTempRo_diffSampSize[(str(temp_resolu),j_sampsize_name)][iDBN_group_typename]
            weighted_graphs_shuffled_tgt = weighted_graphs_shuffled_diffTempRo_diffSampSize[(str(temp_resolu),j_sampsize_name)][iDBN_group_typename]
            # sig_edges_tgt = sig_edges_diffTempRo_diffSampSize[(str(temp_resolu),j_sampsize_name)][iDBN_group_typename]
            sig_edges_tgt = get_significant_edges(weighted_graphs_tgt,weighted_graphs_shuffled_tgt)

            #sig_edges_tgt = sig_edges_tgt*((weighted_graphs_tgt.mean(axis=0)>0.5)*1)

            if 0:
                # disabled alternative: plain difference of the edge weights
                weighted_graphs_delta = (weighted_graphs_tgt-weighted_graphs_base)
                weighted_graphs_delta = weighted_graphs_delta.mean(axis=0)
                #
                sig_edges_delta = ((sig_edges_tgt+sig_edges_base)>0)*1
            else:
                weighted_graphs_delta,sig_edges_delta = Modulation_Index(weighted_graphs_base, weighted_graphs_tgt,
                                                                         sig_edges_base, sig_edges_tgt, 150)
                weighted_graphs_delta = weighted_graphs_delta.mean(axis=0)

            # keep only the flagged edges, then only the "from" rows of this time lag
            sig_avg_dags = weighted_graphs_delta * sig_edges_delta
            sig_avg_dags = sig_avg_dags[fromRowID,:]

            ax_graph = axs[ilag*2+0,iDBN_group]   # node/arrow drawing
            ax_heat = axs[ilag*2+1,iDBN_group]    # heatmap of the same values

            # plot
            ax_graph.set_title(iDBN_group_typename,fontsize=18)
            ax_graph.set_xlim([-0.5,1.5])
            ax_graph.set_ylim([-0.5,1.5])
            ax_graph.set_xticks([])
            ax_graph.set_xticklabels([])
            ax_graph.set_yticks([])
            ax_graph.set_yticklabels([])
            for spine_name in ['top','right','bottom','left']:
                ax_graph.spines[spine_name].set_visible(False)
            # ax_graph.axis('equal')

            # plot the event nodes and their labels
            for ieventnode in np.arange(0,nevents,1):
                ax_graph.plot(eventnode_locations[ieventnode][0],eventnode_locations[ieventnode][1],
                              eventnodes_shape[ieventnode],markersize=60,markerfacecolor=eventnodes_color[ieventnode],
                              markeredgecolor='none')

                ax_graph.text(eventname_locations[ieventnode][0],eventname_locations[ieventnode][1],
                              eventnames[ieventnode],fontsize=10)

            # plot the event edges; done once per panel (it does not depend on
            # the node loop above, so repeating it per node only re-drew the same patches)
            for ifromNode in np.arange(0,nevents,1):
                for itoNode in np.arange(0,nevents,1):
                    edge_weight_tgt = sig_avg_dags[ifromNode,itoNode]
                    if edge_weight_tgt == 0:
                        continue

                    # map the edge value from [-1, 1] onto the [0, 1] colormap range
                    edge_color = clmap((1+edge_weight_tgt)/2)
                    arrow_start = nodearrow_locations[ifromNode][itoNode]
                    arrow_direction = nodearrow_directions[ifromNode][itoNode]

                    if not ifromNode == itoNode:
                        ax_graph.arrow(arrow_start[0],arrow_start[1],
                                       arrow_direction[0],arrow_direction[1],
                                       # head_width=0.08*abs(edge_weight_tgt),
                                       # width=0.04*abs(edge_weight_tgt),
                                       head_width=0.08,
                                       width=0.04,
                                       color = edge_color)
                    else:
                        # self edge: a ring segment (theta1, theta2 taken from
                        # nodearrow_directions); width is passed by keyword
                        ring = mpatches.Wedge(arrow_start,
                                              .1, arrow_direction[0],
                                              arrow_direction[1],
                                              # width=0.04*abs(edge_weight_tgt)
                                              width=0.04
                                             )
                        p = PatchCollection(
                            [ring],
                            facecolor=edge_color,
                            edgecolor=edge_color
                        )
                        ax_graph.add_collection(p)
                        # add arrow head: pointing down for the first two nodes, up for the others
                        if ifromNode < 2:
                            arrowhead_dy = -0.05
                        else:
                            arrowhead_dy = 0.02
                        ax_graph.arrow(arrow_start[0]-0.1+0.02*edge_weight_tgt,
                                       arrow_start[1],
                                       0,arrowhead_dy,color=edge_color,
                                       # head_width=0.08*edge_weight_tgt,width=0.04*edge_weight_tgt
                                       head_width=0.08,width=0.04
                                       )

            # heatmap for the weights
            sig_avg_dags_df = pd.DataFrame(sig_avg_dags)
            sig_avg_dags_df.columns = eventnames
            sig_avg_dags_df.index = eventnames
            vmin,vmax = -1,1
            norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
            im = ax_heat.pcolormesh(sig_avg_dags_df,cmap="bwr",norm=norm)
            # colorbar only next to the last column
            if iDBN_group == nDBN_groups-1:
                cax = ax_heat.inset_axes([1.04, 0.2, 0.05, 0.8])
                fig.colorbar(im, ax=ax_heat, cax=cax,label='edge confidence')

            ax_heat.axis('equal')
            ax_heat.set_xlabel('to Node',fontsize=14)
            ax_heat.set_xticks(np.arange(0.5,4.5,1))
            ax_heat.set_xticklabels(eventnames)
            if iDBN_group == 0:
                ax_heat.set_ylabel('from Node',fontsize=14)
                ax_heat.set_yticks(np.arange(0.5,4.5,1))
                ax_heat.set_yticklabels(eventnames)
                ax_heat.text(-1.5,1,time_lag_name+' time lag',rotation=90,fontsize=20)
                ax_graph.text(-1.25,0,time_lag_name+' time lag',rotation=90,fontsize=20)
            else:
                ax_heat.set_yticks([])
                ax_heat.set_yticklabels([])

        except Exception as err:
            # still best effort (a panel that cannot be drawn is skipped), but
            # the reason is reported and Ctrl-C / kernel interrupts are no longer swallowed
            print('skipped panel for '+time_lag_name+', DBN group '+str(iDBN_group)+': '+repr(err))
            continue
    
    
# write the figure as a pdf into the folder of this animal pair
if savefigs:
    # pair tag: first letter of each animal name
    pair_initials = animal1_fixedorder[0]+animal2_fixedorder[0]
    figsavefolder = data_saved_folder+'figs_for_3LagDBN_withinLayerEdges_and_bhv_singlecam_wholebodylabels_combinesessions_basicEvents/'+savefile_sufix+'/'+cameraID+'/'+pair_initials+'/'
    if not os.path.exists(figsavefolder):
        os.makedirs(figsavefolder)
    # the file name differs between the two sampling modes only after the temporal resolution
    figname_head = figsavefolder+"threeTimeLag_DAGs_"+pair_initials+'_'+str(temp_resolu)+'_'
    if moreSampSize:
        figname_full = figname_head+str(j_sampsize_name)+'_rows_EdgeFifferenceFrom_'+basecondition+'AsBase.pdf'
    else:
        figname_full = figname_head+j_sampsize_name+'_EdgeFifferenceFrom_'+basecondition+'AsBase.pdf'
    plt.savefig(figname_full)
            
            
            

## Plots that include all pairs
## Plots the frequency/distribution of certain edges

### version 7-2-3-4:
#### similar to version 7-2-3, but grouped and plotted differently
#### show coop1s, 1.5s, 2s, 3s in one plot, and show 1s, 2s, 3s time lags or merged

In [None]:
# ---- session selection ----
do_bestsession = 1 # only analyze the best (five) sessions for each conditions during the training phase
do_trainedMCs = 1 # the list that only consider trained (1s) MC, together with SR and NV as controls
# suffix used in the saved-data folder names for the chosen session selection
if not do_bestsession:
    savefile_sufix = ''
elif do_trainedMCs:
    savefile_sufix = '_trainedMCsessions'
else:
    savefile_sufix = '_bestsessions'


# ---- animal pairs ----
# PLOT multiple pairs in one plot, so need to load data seperately
moreSampSize = 0
# the same five pairs are used with and without do_trainedMCs;
# animal_pooled_list holds two labels per pair (animal 1, animal 2)
# other subsets that can be swapped in:
#   without dannon/kanga and koala/vermelho:
#     ['eddie','dodson','ginger'] / ['sparkle','scorch','kanga'] / ['E','SP','DO','SC','G','KwG']
#   without dannon/kanga (dannon kanga did not have coop 3s data):
#     ['eddie','dodson','ginger','koala'] / ['sparkle','scorch','kanga','vermelho'] / ['E','SP','DO','SC','G','KwG','KO','V']
animal1_fixedorders = ['eddie','dodson','dannon','ginger','koala']
animal2_fixedorders = ['sparkle','scorch','kanga','kanga','vermelho']
animal_pooled_list = ['E','SP','DO','SC','DA','KwDA','G','KwG','KO','V']

nanimalpairs = len(animal1_fixedorders)
nanimalpooled = len(animal_pooled_list)

nMIbootstraps = 150

# ---- time lags to plot ----
# timelags: 1 or 2 or 3 or 0(merged - merge all three lags) or 12 (merged lag 1 and 2)
# timelagnames: '1/2/3second' or 'merged' or '12merged'
# alternatives:
#   timelags = [1,2,3] with timelagnames = ['1secondlag','2secondlag','3secondlag']
#   timelags = [12]    with timelagnames = ['12merged']
timelags = [0]
timelagnames = ['merged'] # together with timelag = 0
ntimelags_forplot = len(timelags)


# ---- conditions entering the modulation index ----
# base / control types: other options: 'coop(2s)', 'coop(1.5s)'
MI_basetype = 'self'
MI_conttype = 'no-vision'
# compared types; other lists used before:
#   ['no-vision','coop(3s)','coop(2s)','coop(1.5s)','coop(1s)'], ['coop(3s)','coop(2s)'], ['no-vision','coop(1s)']
if do_trainedMCs:
    MI_comptypes = ['coop(1s)']
else:
    MI_comptypes = ['coop(3s)','coop(2s)','coop(1.5s)','coop(1s)']
nMI_comptypes = len(MI_comptypes)

# ---- dependencies ----
# for plot
dependencynames = ['pull-pull','gaze-gaze','within_gazepull','across_gazepull','within_pullgaze','across_pullgaze']
# dependencytargets can also be dependencynames, or include 'pullgaze_merged'
dependencytargets = ['pull-pull','within_gazepull','across_pullgaze']
ndeptargets = len(dependencytargets)
    
#
# one row of panels per plotted time lag, one column per dependency target,
# each panel 10 x 10 inches
fig, axs = plt.subplots(nrows=ntimelags_forplot, ncols=ndeptargets)
fig.set_size_inches(10*ndeptargets, 10*ntimelags_forplot)

#
for itimelag in np.arange(0,ntimelags_forplot,1):
    timelag = timelags[itimelag]
    timelagname = timelagnames[itimelag]
    
    # containers keyed by comparison type; each entry is replaced by the
    # results of that type once they are computed.
    # Built with comprehensions so every key owns its own placeholder list:
    # dict.fromkeys(keys, []) would make all keys share one list object.
    MI_coop_self_all_IndiAni_pooled = {comptype: [] for comptype in MI_comptypes}
    MI_nov_coop_all_IndiAni_pooled = {comptype: [] for comptype in MI_comptypes}
    MI_coop_self_mean_IndiAni_pooled = {comptype: [] for comptype in MI_comptypes}
    MI_nov_coop_mean_IndiAni_pooled = {comptype: [] for comptype in MI_comptypes}
    
    for iMI_comptype in np.arange(0,nMI_comptypes,1):
        MI_comptype = MI_comptypes[iMI_comptype]
        
        # 
        # number of lags pooled per animal: three when all lags are merged
        # (timelag 0), two when lags 1 and 2 are merged (timelag 12), else one
        if timelag == 0:
            ntimelags = 3
        elif timelag == 12:
            ntimelags = 2
        else:
            ntimelags = 1

        # result containers, first axis: two animals per pair; last axis: the six dependency types
        # "all" arrays keep every bootstrap of every pooled lag, "mean" arrays one value each
        nindividuals = nanimalpairs*2
        nMIsamples = nMIbootstraps*ntimelags
        MI_coop_self_all_IndiAni = np.zeros([nindividuals,nMIsamples,6])
        MI_coop_self_mean_IndiAni = np.zeros([nindividuals,6])
        MI_nov_coop_all_IndiAni = np.zeros([nindividuals,nMIsamples,6])
        MI_nov_coop_mean_IndiAni = np.zeros([nindividuals,6])

            

        for ianimalpair in np.arange(0,nanimalpairs,1):
            # names of the two animals of this pair
            animal1_fixedorder = animal1_fixedorders[ianimalpair]
            animal2_fixedorder = animal2_fixedorders[ianimalpair]
            pair_fullname = animal1_fixedorder+animal2_fixedorder
            # folder that holds the saved DBN results of this pair
            data_saved_subfolder = data_saved_folder+'data_saved_singlecam_wholebody'+savefile_sufix+'_3lags_withinLayerEdges/'+cameraID+'/'+pair_fullname+'/'
            # files of the larger-sample-size runs carry an extra '_moreSampSize' tag
            if moreSampSize:
                pklfile_tail = pair_fullname+'_moreSampSize.pkl'
            else:
                pklfile_tail = pair_fullname+'.pkl'
            # edge weights, edge weights from shuffled data, and significant edges
            with open(data_saved_subfolder+'/weighted_graphs_diffTempRo_diffSampSize_'+pklfile_tail, 'rb') as f:
                weighted_graphs_diffTempRo_diffSampSize = pickle.load(f)
            with open(data_saved_subfolder+'/weighted_graphs_shuffled_diffTempRo_diffSampSize_'+pklfile_tail, 'rb') as f:
                weighted_graphs_shuffled_diffTempRo_diffSampSize = pickle.load(f)
            with open(data_saved_subfolder+'/sig_edges_diffTempRo_diffSampSize_'+pklfile_tail, 'rb') as f:
                sig_edges_diffTempRo_diffSampSize = pickle.load(f)

            # make sure these variables are the same as in the previous steps
            # temporal resolution in the DBN model, eg: 0.5 means 500ms
            # (other option: temp_resolus = [0.5,1,1.5,2])
            temp_resolus = [1]
            ntemp_reses = len(temp_resolus)
            # different data (down/re)sampling numbers
            # (other options: np.arange(1100,3000,100), [100,500,1000,1500,2000,2500,3000], [100,500])
            if moreSampSize:
                samplingsizes = [1100]
                samplingsizes_name = [str(isamplingsize) for isamplingsize in samplingsizes]
            else:
                samplingsizes_name = ['min_row_number']
            nsamplings = len(samplingsizes_name)

            # only the first temporal resolution / sampling size is analyzed here
            temp_resolu = temp_resolus[0]
            j_sampsize_name = samplingsizes_name[0]

            # load edge weight data
            # the saved dictionaries are keyed by (temporal resolution, sampling size) and then by condition
            DBN_datakey = (str(temp_resolu),j_sampsize_name)
            weighted_graphs_allconds = weighted_graphs_diffTempRo_diffSampSize[DBN_datakey]
            weighted_graphs_sf_allconds = weighted_graphs_shuffled_diffTempRo_diffSampSize[DBN_datakey]
            sig_edges_allconds = sig_edges_diffTempRo_diffSampSize[DBN_datakey]
            # base condition
            weighted_graphs_self = weighted_graphs_allconds[MI_basetype]
            weighted_graphs_sf_self = weighted_graphs_sf_allconds[MI_basetype]
            sig_edges_self = sig_edges_allconds[MI_basetype]
            # compared condition
            weighted_graphs_coop = weighted_graphs_allconds[MI_comptype]
            weighted_graphs_sf_coop = weighted_graphs_sf_allconds[MI_comptype]
            sig_edges_coop = sig_edges_allconds[MI_comptype]
            # control condition
            weighted_graphs_nov = weighted_graphs_allconds[MI_conttype]
            weighted_graphs_sf_nov = weighted_graphs_sf_allconds[MI_conttype]
            sig_edges_nov = sig_edges_allconds[MI_conttype]

            ## mainly for dannon kanga because they dont have coop3s
            # an empty compared condition is replaced by all-NaN placeholders
            if np.shape(weighted_graphs_coop)[0]==0:
                weighted_graphs_coop = np.full((95,12,4),np.nan)
                weighted_graphs_sf_coop = np.full((95,12,4),np.nan)
                sig_edges_coop = np.full((12,4),np.nan)
            
            # organize the key edge data
            # mean edge weights of the three conditions (average over the first axis)
            weighted_graphs_self_mean = weighted_graphs_self.mean(axis=0)
            weighted_graphs_coop_mean = weighted_graphs_coop.mean(axis=0)
            weighted_graphs_nov_mean = weighted_graphs_nov.mean(axis=0)
            # earlier variants computed (a-b)/(a+b) directly from the means or per sample
            #
            if 0:
                # disabled alternative: plain differences, with the means kept only
                # on edges significant in at least one of the two conditions
                MI_coop_self_all = weighted_graphs_coop-weighted_graphs_self
                MI_nov_coop_all = weighted_graphs_nov-weighted_graphs_coop
                #
                sig_edges_coop_self = ((sig_edges_coop+sig_edges_self)>0)*1
                sig_edges_nov_coop = ((sig_edges_coop+sig_edges_nov)>0)*1
                #
                MI_coop_self = (weighted_graphs_coop-weighted_graphs_self).mean(axis=0) * sig_edges_coop_self
                MI_nov_coop = (weighted_graphs_nov-weighted_graphs_coop).mean(axis=0) * sig_edges_nov_coop
                #
                nMIbootstraps = 1
            else:
                nMIbootstraps = 150
                # compared condition relative to the base condition
                MI_coop_self_all,sig_edges_coop_self = Modulation_Index(
                    weighted_graphs_self, weighted_graphs_coop,
                    sig_edges_self, sig_edges_coop, nMIbootstraps)
                # control condition relative to the compared condition
                MI_nov_coop_all,sig_edges_nov_coop = Modulation_Index(
                    weighted_graphs_coop, weighted_graphs_nov,
                    sig_edges_coop, sig_edges_nov, nMIbootstraps)
                # NaN-ignoring averages over the first axis;
                # the returned significant-edge matrices are not applied as masks here
                MI_coop_self = np.nanmean(MI_coop_self_all,axis = 0)
                MI_nov_coop = np.nanmean(MI_nov_coop_all,axis = 0)
    
            #
            # Row/column indices of the edges that make up each dependency type.
            # "from" nodes are rows; each time lag occupies a block of four rows
            # (rows 8-11 for lag 1, 4-7 for lag 2, 0-3 for lag 3), "to" nodes are
            # the four columns. Within a block the order is taken to be
            # pull1, pull2, gaze1, gaze2 - confirm against the DBN training script.
            # A single lag gives one index per animal, merged lags a list of indices per animal.
            lag_rowoffsets = {1:8, 2:4, 3:0, 0:[0,4,8], 12:[4,8]}
            # (from nodes, to nodes) inside one lag block, for animal 1 then animal 2
            dependency_basenodes = {'pull_pull':       ([1,0],[0,1]),
                                    'gaze_gaze':       ([3,2],[2,3]),
                                    'within_pullgaze': ([0,1],[2,3]),
                                    'across_pullgaze': ([1,0],[2,3]),
                                    'within_gazepull': ([2,3],[0,1]),
                                    'across_gazepull': ([3,2],[0,1])}
            if timelag in lag_rowoffsets:
                lag_rowoffset = lag_rowoffsets[timelag]
                lags_are_merged = isinstance(lag_rowoffset,list)
                lagged_nodes = {}
                for depname,(base_fromnodes,base_tonodes) in dependency_basenodes.items():
                    if lags_are_merged:
                        # one "from" row per merged lag, all pointing to the same "to" column
                        dep_fromnodes = [[inode+ioffset for ioffset in lag_rowoffset] for inode in base_fromnodes]
                        dep_tonodes = [[inode]*len(lag_rowoffset) for inode in base_tonodes]
                    else:
                        dep_fromnodes = [inode+lag_rowoffset for inode in base_fromnodes]
                        dep_tonodes = list(base_tonodes)
                    lagged_nodes[depname] = (dep_fromnodes,dep_tonodes)
                #
                pull_pull_fromNodes_all, pull_pull_toNodes_all = lagged_nodes['pull_pull']
                gaze_gaze_fromNodes_all, gaze_gaze_toNodes_all = lagged_nodes['gaze_gaze']
                within_pullgaze_fromNodes_all, within_pullgaze_toNodes_all = lagged_nodes['within_pullgaze']
                across_pullgaze_fromNodes_all, across_pullgaze_toNodes_all = lagged_nodes['across_pullgaze']
                within_gazepull_fromNodes_all, within_gazepull_toNodes_all = lagged_nodes['within_gazepull']
                across_gazepull_fromNodes_all, across_gazepull_toNodes_all = lagged_nodes['across_gazepull']
                #
                if lags_are_merged:
                    ntimelags = len(lag_rowoffset)
                else:
                    ntimelags = 1
    
    
            # (from nodes, to nodes) of the six dependency types, listed in the
            # order of the last axis of the *_IndiAni arrays:
            # pull-pull, gaze-gaze, within animal gazepull, across animal gazepull,
            # within animal pullgaze, across animal pullgaze
            dependency_nodepairs = [(pull_pull_fromNodes_all, pull_pull_toNodes_all),
                                    (gaze_gaze_fromNodes_all, gaze_gaze_toNodes_all),
                                    (within_gazepull_fromNodes_all, within_gazepull_toNodes_all),
                                    (across_gazepull_fromNodes_all, across_gazepull_toNodes_all),
                                    (within_pullgaze_fromNodes_all, within_pullgaze_toNodes_all),
                                    (across_pullgaze_fromNodes_all, across_pullgaze_toNodes_all)]

            for ianimal in np.arange(0,2,1):

                # row of this individual: two consecutive rows per animal pair
                iindividual = 2*ianimalpair+ianimal

                for idependency,(dep_fromNodes_all,dep_toNodes_all) in enumerate(dependency_nodepairs):
                    dep_fromNodes = dep_fromNodes_all[ianimal]
                    dep_toNodes = dep_toNodes_all[ianimal]

                    # coop self modulation: all values of the selected edge(s) and their NaN-ignoring mean
                    MIvalues_coop_self = MI_coop_self_all[:,dep_fromNodes,dep_toNodes].flatten()
                    MI_coop_self_all_IndiAni[iindividual,:,idependency] = MIvalues_coop_self
                    MI_coop_self_mean_IndiAni[iindividual,idependency] = np.nanmean(MIvalues_coop_self)

                    # novision coop modulation: same edges
                    MIvalues_nov_coop = MI_nov_coop_all[:,dep_fromNodes,dep_toNodes].flatten()
                    MI_nov_coop_all_IndiAni[iindividual,:,idependency] = MIvalues_nov_coop
                    MI_nov_coop_mean_IndiAni[iindividual,idependency] = np.nanmean(MIvalues_nov_coop)


        # prepare the data
        # stack all individuals, bootstraps and lags into rows; columns stay the six dependency types
        MI_coop_self_all_IndiAni_all = MI_coop_self_all_IndiAni.reshape(nanimalpooled*nMIbootstraps*ntimelags,6)
        MI_nov_coop_all_IndiAni_all = MI_nov_coop_all_IndiAni.reshape(nanimalpooled*nMIbootstraps*ntimelags,6)
        # average all animals for each dependency (NaN entries are ignored)
        MI_coop_self_all_IndiAni_allmean = np.nanmean(MI_coop_self_all_IndiAni_all,axis=0)
        MI_nov_coop_all_IndiAni_allmean = np.nanmean(MI_nov_coop_all_IndiAni_all,axis=0)
        # standard error: nanstd leaves NaN entries out, so the sample size has to be
        # the number of non-NaN entries per dependency, not the total number of rows
        # (identical to the total row count when there are no NaNs)
        nvalid_coop_self = np.sum(~np.isnan(MI_coop_self_all_IndiAni_all),axis=0)
        nvalid_nov_coop = np.sum(~np.isnan(MI_nov_coop_all_IndiAni_all),axis=0)
        MI_coop_self_all_IndiAni_allse = np.nanstd(MI_coop_self_all_IndiAni_all,axis=0)/np.sqrt(nvalid_coop_self)
        MI_nov_coop_all_IndiAni_allse = np.nanstd(MI_nov_coop_all_IndiAni_all,axis=0)/np.sqrt(nvalid_nov_coop)


        # pool everything together, keyed by the compared condition
        MI_coop_self_all_IndiAni_pooled[MI_comptype] = MI_coop_self_all_IndiAni_all
        MI_nov_coop_all_IndiAni_pooled[MI_comptype] = MI_nov_coop_all_IndiAni_all
        MI_coop_self_mean_IndiAni_pooled[MI_comptype] = MI_coop_self_mean_IndiAni
        MI_nov_coop_mean_IndiAni_pooled[MI_comptype] = MI_nov_coop_mean_IndiAni
    

    

    # plot 

    # all bootstraps
    # Each table: one row per pooled MI value, one column per dependency type,
    # plus the compared type ('MItype'), the reference type ('CTtype') and
    # 'pullgaze_merged', the mean of 'within_gazepull' and 'across_pullgaze'.

    # coop (1s), relative to the base condition
    MI_coop_self_all_IndiAni_MC1s_df = pd.DataFrame(MI_coop_self_all_IndiAni_pooled['coop(1s)'])
    MI_coop_self_all_IndiAni_MC1s_df.columns = dependencynames
    MI_coop_self_all_IndiAni_MC1s_df = MI_coop_self_all_IndiAni_MC1s_df.assign(
        MItype='coop(1s)',
        CTtype=MI_basetype,
        pullgaze_merged=lambda tbl: (tbl['within_gazepull']+tbl['across_pullgaze'])/2,
    )
    # coop (1s), control condition relative to it
    MI_nov_coop_all_IndiAni_MC1s_df = pd.DataFrame(MI_nov_coop_all_IndiAni_pooled['coop(1s)'])
    MI_nov_coop_all_IndiAni_MC1s_df.columns = dependencynames
    MI_nov_coop_all_IndiAni_MC1s_df = MI_nov_coop_all_IndiAni_MC1s_df.assign(
        MItype='coop(1s)',
        CTtype=MI_conttype,
        pullgaze_merged=lambda tbl: (tbl['within_gazepull']+tbl['across_pullgaze'])/2,
    )

    if 0:
        # no vision (disabled; needs 'no-vision' among the compared types)
        MI_coop_self_all_IndiAni_NV_df = pd.DataFrame(MI_coop_self_all_IndiAni_pooled['no-vision'])
        MI_coop_self_all_IndiAni_NV_df.columns = dependencynames
        MI_coop_self_all_IndiAni_NV_df = MI_coop_self_all_IndiAni_NV_df.assign(
            MItype='no-vision',
            CTtype=MI_basetype,
            pullgaze_merged=lambda tbl: (tbl['within_gazepull']+tbl['across_pullgaze'])/2,
        )
        #
        MI_nov_coop_all_IndiAni_NV_df = pd.DataFrame(MI_nov_coop_all_IndiAni_pooled['no-vision'])
        MI_nov_coop_all_IndiAni_NV_df.columns = dependencynames
        MI_nov_coop_all_IndiAni_NV_df = MI_nov_coop_all_IndiAni_NV_df.assign(
            MItype='no-vision',
            CTtype=MI_conttype,
            pullgaze_merged=lambda tbl: (tbl['within_gazepull']+tbl['across_pullgaze'])/2,
        )

    
    if not do_trainedMCs:
        # coop (1.5s)
        MI_coop_self_all_IndiAni_MC15s_df = pd.DataFrame(MI_coop_self_all_IndiAni_pooled['coop(1.5s)'])
        MI_coop_self_all_IndiAni_MC15s_df.columns = dependencynames
        MI_coop_self_all_IndiAni_MC15s_df['MItype'] = 'coop(1.5s)'
        MI_coop_self_all_IndiAni_MC15s_df['CTtype'] = MI_basetype
        MI_coop_self_all_IndiAni_MC15s_df['pullgaze_merged'] = (MI_coop_self_all_IndiAni_MC15s_df['within_gazepull']+MI_coop_self_all_IndiAni_MC15s_df['across_pullgaze'])/2
        #
        MI_nov_coop_all_IndiAni_MC15s_df = pd.DataFrame(MI_nov_coop_all_IndiAni_pooled['coop(1.5s)'])
        MI_nov_coop_all_IndiAni_MC15s_df.columns = dependencynames
        MI_nov_coop_all_IndiAni_MC15s_df['MItype'] = 'coop(1.5s)'
        MI_nov_coop_all_IndiAni_MC15s_df['CTtype'] = MI_conttype
        MI_nov_coop_all_IndiAni_MC15s_df['pullgaze_merged'] = (MI_nov_coop_all_IndiAni_MC15s_df['within_gazepull']+MI_nov_coop_all_IndiAni_MC15s_df['across_pullgaze'])/2

        # coop (2s)
        MI_coop_self_all_IndiAni_MC2s_df = pd.DataFrame(MI_coop_self_all_IndiAni_pooled['coop(2s)'])
        MI_coop_self_all_IndiAni_MC2s_df.columns = dependencynames
        MI_coop_self_all_IndiAni_MC2s_df['MItype'] = 'coop(2s)'
        MI_coop_self_all_IndiAni_MC2s_df['CTtype'] = MI_basetype
        MI_coop_self_all_IndiAni_MC2s_df['pullgaze_merged'] = (MI_coop_self_all_IndiAni_MC2s_df['within_gazepull']+MI_coop_self_all_IndiAni_MC2s_df['across_pullgaze'])/2
        #
        MI_nov_coop_all_IndiAni_MC2s_df = pd.DataFrame(MI_nov_coop_all_IndiAni_pooled['coop(2s)'])
        MI_nov_coop_all_IndiAni_MC2s_df.columns = dependencynames
        MI_nov_coop_all_IndiAni_MC2s_df['MItype'] = 'coop(2s)'
        MI_nov_coop_all_IndiAni_MC2s_df['CTtype'] = MI_conttype
        MI_nov_coop_all_IndiAni_MC2s_df['pullgaze_merged'] = (MI_nov_coop_all_IndiAni_MC2s_df['within_gazepull']+MI_nov_coop_all_IndiAni_MC2s_df['across_pullgaze'])/2

        # coop (3s)
        MI_coop_self_all_IndiAni_MC3s_df = pd.DataFrame(MI_coop_self_all_IndiAni_pooled['coop(3s)'])
        MI_coop_self_all_IndiAni_MC3s_df.columns = dependencynames
        MI_coop_self_all_IndiAni_MC3s_df['MItype'] = 'coop(3s)'
        MI_coop_self_all_IndiAni_MC3s_df['CTtype'] = MI_basetype
        MI_coop_self_all_IndiAni_MC3s_df['pullgaze_merged'] = (MI_coop_self_all_IndiAni_MC3s_df['within_gazepull']+MI_coop_self_all_IndiAni_MC3s_df['across_pullgaze'])/2
        #
        MI_nov_coop_all_IndiAni_MC3s_df = pd.DataFrame(MI_nov_coop_all_IndiAni_pooled['coop(3s)'])
        MI_nov_coop_all_IndiAni_MC3s_df.columns = dependencynames
        MI_nov_coop_all_IndiAni_MC3s_df['MItype'] = 'coop(3s)'
        MI_nov_coop_all_IndiAni_MC3s_df['CTtype'] = MI_conttype
        MI_nov_coop_all_IndiAni_MC3s_df['pullgaze_merged'] = (MI_nov_coop_all_IndiAni_MC3s_df['within_gazepull']+MI_nov_coop_all_IndiAni_MC3s_df['across_pullgaze'])/2
    
        #
        df_long_bt =pd.concat([MI_coop_self_all_IndiAni_MC3s_df,MI_nov_coop_all_IndiAni_MC3s_df,
                               MI_coop_self_all_IndiAni_MC2s_df,MI_nov_coop_all_IndiAni_MC2s_df,
                               MI_coop_self_all_IndiAni_MC15s_df,MI_nov_coop_all_IndiAni_MC15s_df,
                               MI_coop_self_all_IndiAni_MC1s_df,MI_nov_coop_all_IndiAni_MC1s_df,
                               # MI_coop_self_all_IndiAni_NV_df,MI_nov_coop_all_IndiAni_NV_df,
                              ])
        
    elif do_trainedMCs:
        df_long_bt =pd.concat([MI_coop_self_all_IndiAni_MC1s_df,MI_nov_coop_all_IndiAni_MC1s_df,
                              #  MI_coop_self_all_IndiAni_NV_df,MI_nov_coop_all_IndiAni_NV_df
                              ])
        
    df_long2_bt = df_long_bt.melt(id_vars=['MItype','CTtype'], value_vars=dependencytargets,
                                  var_name='condition', value_name='value')

    #
    # average for each animal individuals
    # coop (1s)
    MI_coop_self_mean_IndiAni_MC1s_df = pd.DataFrame(MI_coop_self_mean_IndiAni_pooled['coop(1s)'])
    MI_coop_self_mean_IndiAni_MC1s_df.columns = dependencynames
    MI_coop_self_mean_IndiAni_MC1s_df['MItype'] = 'coop(1s)'
    MI_coop_self_mean_IndiAni_MC1s_df['CTtype'] = MI_basetype
    MI_coop_self_mean_IndiAni_MC1s_df['pullgaze_merged'] = (MI_coop_self_mean_IndiAni_MC1s_df['within_gazepull']+MI_coop_self_mean_IndiAni_MC1s_df['across_pullgaze'])/2
    #
    MI_nov_coop_mean_IndiAni_MC1s_df = pd.DataFrame(MI_nov_coop_mean_IndiAni_pooled['coop(1s)'])
    MI_nov_coop_mean_IndiAni_MC1s_df.columns = dependencynames
    MI_nov_coop_mean_IndiAni_MC1s_df['MItype'] = 'coop(1s)'
    MI_nov_coop_mean_IndiAni_MC1s_df['CTtype'] = MI_conttype
    MI_nov_coop_mean_IndiAni_MC1s_df['pullgaze_merged'] = (MI_nov_coop_mean_IndiAni_MC1s_df['within_gazepull']+MI_nov_coop_mean_IndiAni_MC1s_df['across_pullgaze'])/2

    if 0:
        # no vision
        MI_coop_self_mean_IndiAni_NV_df = pd.DataFrame(MI_coop_self_mean_IndiAni_pooled['no-vision'])
        MI_coop_self_mean_IndiAni_NV_df.columns = dependencynames
        MI_coop_self_mean_IndiAni_NV_df['MItype'] = 'no-vision'
        MI_coop_self_mean_IndiAni_NV_df['CTtype'] = MI_basetype
        MI_coop_self_mean_IndiAni_NV_df['pullgaze_merged'] = (MI_coop_self_mean_IndiAni_NV_df['within_gazepull']+MI_coop_self_mean_IndiAni_NV_df['across_pullgaze'])/2
        #
        MI_nov_coop_mean_IndiAni_NV_df = pd.DataFrame(MI_nov_coop_mean_IndiAni_pooled['no-vision'])
        MI_nov_coop_mean_IndiAni_NV_df.columns = dependencynames
        MI_nov_coop_mean_IndiAni_NV_df['MItype'] = 'no-vision'
        MI_nov_coop_mean_IndiAni_NV_df['CTtype'] = MI_conttype
        MI_nov_coop_mean_IndiAni_NV_df['pullgaze_merged'] = (MI_nov_coop_mean_IndiAni_NV_df['within_gazepull']+MI_nov_coop_mean_IndiAni_NV_df['across_pullgaze'])/2

    
    if not do_trainedMCs:
        # coop (1.5s)
        MI_coop_self_mean_IndiAni_MC15s_df = pd.DataFrame(MI_coop_self_mean_IndiAni_pooled['coop(1.5s)'])
        MI_coop_self_mean_IndiAni_MC15s_df.columns = dependencynames
        MI_coop_self_mean_IndiAni_MC15s_df['MItype'] = 'coop(1.5s)'
        MI_coop_self_mean_IndiAni_MC15s_df['CTtype'] = MI_basetype
        MI_coop_self_mean_IndiAni_MC15s_df['pullgaze_merged'] = (MI_coop_self_mean_IndiAni_MC15s_df['within_gazepull']+MI_coop_self_mean_IndiAni_MC15s_df['across_pullgaze'])/2
        #
        MI_nov_coop_mean_IndiAni_MC15s_df = pd.DataFrame(MI_nov_coop_mean_IndiAni_pooled['coop(1.5s)'])
        MI_nov_coop_mean_IndiAni_MC15s_df.columns = dependencynames
        MI_nov_coop_mean_IndiAni_MC15s_df['MItype'] = 'coop(1.5s)'
        MI_nov_coop_mean_IndiAni_MC15s_df['CTtype'] = MI_conttype
        MI_nov_coop_mean_IndiAni_MC15s_df['pullgaze_merged'] = (MI_nov_coop_mean_IndiAni_MC15s_df['within_gazepull']+MI_nov_coop_mean_IndiAni_MC15s_df['across_pullgaze'])/2

        # coop (2s)
        MI_coop_self_mean_IndiAni_MC2s_df = pd.DataFrame(MI_coop_self_mean_IndiAni_pooled['coop(2s)'])
        MI_coop_self_mean_IndiAni_MC2s_df.columns = dependencynames
        MI_coop_self_mean_IndiAni_MC2s_df['MItype'] = 'coop(2s)'
        MI_coop_self_mean_IndiAni_MC2s_df['CTtype'] = MI_basetype
        MI_coop_self_mean_IndiAni_MC2s_df['pullgaze_merged'] = (MI_coop_self_mean_IndiAni_MC2s_df['within_gazepull']+MI_coop_self_mean_IndiAni_MC2s_df['across_pullgaze'])/2
        #
        MI_nov_coop_mean_IndiAni_MC2s_df = pd.DataFrame(MI_nov_coop_mean_IndiAni_pooled['coop(2s)'])
        MI_nov_coop_mean_IndiAni_MC2s_df.columns = dependencynames
        MI_nov_coop_mean_IndiAni_MC2s_df['MItype'] = 'coop(2s)'
        MI_nov_coop_mean_IndiAni_MC2s_df['CTtype'] = MI_conttype
        MI_nov_coop_mean_IndiAni_MC2s_df['pullgaze_merged'] = (MI_nov_coop_mean_IndiAni_MC2s_df['within_gazepull']+MI_nov_coop_mean_IndiAni_MC2s_df['across_pullgaze'])/2

        # coop (3s)
        MI_coop_self_mean_IndiAni_MC3s_df = pd.DataFrame(MI_coop_self_mean_IndiAni_pooled['coop(3s)'])
        MI_coop_self_mean_IndiAni_MC3s_df.columns = dependencynames
        MI_coop_self_mean_IndiAni_MC3s_df['MItype'] = 'coop(3s)'
        MI_coop_self_mean_IndiAni_MC3s_df['CTtype'] = MI_basetype
        MI_coop_self_mean_IndiAni_MC3s_df['pullgaze_merged'] = (MI_coop_self_mean_IndiAni_MC3s_df['within_gazepull']+MI_coop_self_mean_IndiAni_MC3s_df['across_pullgaze'])/2
        #
        MI_nov_coop_mean_IndiAni_MC3s_df = pd.DataFrame(MI_nov_coop_mean_IndiAni_pooled['coop(3s)'])
        MI_nov_coop_mean_IndiAni_MC3s_df.columns = dependencynames
        MI_nov_coop_mean_IndiAni_MC3s_df['MItype'] = 'coop(3s)'
        MI_nov_coop_mean_IndiAni_MC3s_df['CTtype'] = MI_conttype
        MI_nov_coop_mean_IndiAni_MC3s_df['pullgaze_merged'] = (MI_nov_coop_mean_IndiAni_MC3s_df['within_gazepull']+MI_nov_coop_mean_IndiAni_MC3s_df['across_pullgaze'])/2

        df_long=pd.concat([MI_coop_self_mean_IndiAni_MC3s_df,MI_nov_coop_mean_IndiAni_MC3s_df,
                           MI_coop_self_mean_IndiAni_MC2s_df,MI_nov_coop_mean_IndiAni_MC2s_df,
                           MI_coop_self_mean_IndiAni_MC15s_df,MI_nov_coop_mean_IndiAni_MC15s_df,
                           MI_coop_self_mean_IndiAni_MC1s_df,MI_nov_coop_mean_IndiAni_MC1s_df,
                           # MI_coop_self_mean_IndiAni_NV_df,MI_nov_coop_mean_IndiAni_NV_df,
                          ])
    elif do_trainedMCs:
        df_long=pd.concat([MI_coop_self_mean_IndiAni_MC1s_df,MI_nov_coop_mean_IndiAni_MC1s_df,
                          # MI_coop_self_mean_IndiAni_NV_df,MI_nov_coop_mean_IndiAni_NV_df
                          ],)
        
    df_long2 = df_long.melt(id_vars=['MItype','CTtype'], value_vars=dependencytargets,
                        var_name='condition', value_name='value')
    
    # for plot
    for idep in np.arange(0,ndeptargets,1):
        ind = df_long2.condition==dependencytargets[idep]
        # ind = df_long2_bt.condition==dependencytargets[idep]
        #
        if ntimelags_forplot == 1:
            # seaborn.lineplot(ax=axs[idep],data=df_long2[ind],x='MItype',y='value',hue='CTtype')
            # # seaborn.lineplot(ax=axs[idep],data=df_long2_bt[ind],x='MItype',y='value',hue='CTtype')
            # seaborn.boxplot(ax=axs[idep],data=df_long2[ind],x='MItype',y='value',hue='CTtype')
            seaborn.violinplot(ax=axs[idep],data=df_long2[ind],x='MItype',y='value',hue='CTtype')
            axs[idep].plot([0,3],[0,0],'k--')
            axs[idep].set_ylabel('Modulation Index',fontsize=20)
            axs[idep].set_title(timelagname+' '+dependencytargets[idep],fontsize=24)
            axs[idep].set_ylim([-2.02,2.02])
            #
            # add statistics
            CTtypes = [MI_basetype,MI_conttype]
            CTtype_plotlocs = [.75,-.75]
            nCTtypes = np.shape(CTtypes)[0]
            for iMItype in np.arange(0,nMI_comptypes,1):
                MItype_totest = MI_comptypes[iMItype]
                #
                for iCTtype in np.arange(0,nCTtypes,1):
                    CTtype_totest = CTtypes[iCTtype]
                    # 
                    ind_totest = (df_long2['condition']==dependencytargets[idep])&(df_long2['MItype']==MItype_totest)&(df_long2['CTtype']==CTtype_totest)
                    data_totest = np.array(df_long2[ind_totest]['value'])
                    # pp = st.ttest_1samp(data_totest[~np.isnan(data_totest)],0).pvalue
                    pp = st.wilcoxon(data_totest[~np.isnan(data_totest)]).pvalue
                    # 
                    if pp<=0.001:
                        axs[idep].text(iMItype,CTtype_plotlocs[iCTtype],'***',fontsize=20)
                    elif pp<=0.01:
                        axs[idep].text(iMItype,CTtype_plotlocs[iCTtype],'**',fontsize=20)
                    elif pp<=0.05:
                        axs[idep].text(iMItype,CTtype_plotlocs[iCTtype],'*',fontsize=20)
                    
                    
        else:
            # seaborn.lineplot(ax=axs[itimelag,idep],data=df_long2[ind],x='MItype',y='value',hue='CTtype')
            # # seaborn.lineplot(ax=axs[itimelag,idep],data=df_long2_bt[ind],x='MItype',y='value',hue='CTtype')
            seaborn.violinplot(ax=axs[itimelag,idep],data=df_long2[ind],x='MItype',y='value',hue='CTtype')
            axs[itimelag,idep].plot([0,3],[0,0],'k--')
            axs[itimelag,idep].set_ylabel('Modulation Index',fontsize=20)
            axs[itimelag,idep].set_title(timelagname+' '+dependencytargets[idep],fontsize=24)
            axs[itimelag,idep].set_ylim([-2.02,2.02])
            #
            # add statistics
            CTtypes = [MI_basetype,MI_conttype]
            CTtype_plotlocs = [.75,-.75]
            nCTtypes = np.shape(CTtypes)[0]
            for iMItype in np.arange(0,nMI_comptypes,1):
                MItype_totest = MI_comptypes[iMItype]
                #
                for iCTtype in np.arange(0,nCTtypes,1):
                    CTtype_totest = CTtypes[iCTtype]
                    # 
                    ind_totest = (df_long2['condition']==dependencytargets[idep])&(df_long2['MItype']==MItype_totest)&(df_long2['CTtype']==CTtype_totest)
                    data_totest = np.array(df_long2[ind_totest]['value'])
                    # pp = st.ttest_1samp(data_totest[~np.isnan(data_totest)],0).pvalue
                    pp = st.wilcoxon(data_totest[~np.isnan(data_totest)]).pvalue
                    # 
                    if pp<=0.001:
                        axs[itimelag,idep].text(iMItype,CTtype_plotlocs[iCTtype],'***',fontsize=20)
                    elif pp<=0.01:
                        axs[itimelag,idep].text(iMItype,CTtype_plotlocs[iCTtype],'**',fontsize=20)
                    elif pp<=0.05:
                        axs[itimelag,idep].text(iMItype,CTtype_plotlocs[iCTtype],'*',fontsize=20)
        

# Save the current figure (per-animal modulation-index summary) as a PDF.
# NOTE(review): timelagname, temp_resolu and j_sampsize_name are not set here;
# presumably they carry the last values assigned inside the plotting loop above - confirm.
savefig = 1
if savefig:
    # both sampling modes write into the same folder, so build it once
    figsavefolder = data_saved_folder+'figs_for_3LagDBN_withinLayerEdges_and_bhv_singlecam_wholebodylabels_combinesessions_basicEvents/'+savefile_sufix+'/'+cameraID+'/'
    # exist_ok=True replaces the racy "check os.path.exists, then os.makedirs" sequence
    os.makedirs(figsavefolder, exist_ok=True)
    # only the tail of the file name depends on the sampling mode
    if moreSampSize:
        figname_tail = '_rows_subset_basedonToNodes_multiTimeLag_multiCoopsOnePanel.pdf'
    else:
        figname_tail = '_subset_basedonToNodes_multiTimeLag_multiCoopsOnePanel.pdf'
    plt.savefig(figsavefolder+'threeTimeLag_Edge_ModulationIndex_'+timelagname+'Lag_IndiAnimal_summarized_'
                +str(temp_resolu)+'_'+str(j_sampsize_name)+figname_tail)
           
    


In [None]:
# Wilcoxon rank-sum test on the per-animal 'within_gazepull' modulation index for coop(1s),
# comparing the MI_coop_self table against the MI_nov_coop table.
# The previous call passed the MI_coop_self column as both samples, which can only
# ever return statistic 0 / p = 1.
# NOTE(review): the second sample is assumed to be the MI_nov_coop table - confirm the intended comparison.
scipy.stats.ranksums(np.array(MI_coop_self_mean_IndiAni_MC1s_df['within_gazepull']),
                     np.array(MI_nov_coop_mean_IndiAni_MC1s_df['within_gazepull']))



In [None]:
# inspect the rows of the per-animal averaged table whose CTtype is 'self'
df_long.loc[df_long['CTtype'] == 'self']

##### Same as the previous plot, but plotted separately for male, female, dominant (dom), or subordinate (sub) animals

In [None]:
# ---- session selection ----
do_bestsession = 1  # only analyze the best (five) sessions of each condition during the training phase
do_trainedMCs = 1   # session list that only considers trained (1s) MC, together with SR and NV as controls

# suffix used below when building the saved-data and figure folder names
if not do_bestsession:
    savefile_sufix = ''
elif do_trainedMCs:
    savefile_sufix = '_trainedMCsessions'
else:
    savefile_sufix = '_bestsessions'

# multiple pairs are drawn in one plot, so each pair's data is loaded separately
moreSampSize = 0

# ---- animal pairs: animal1_fixedorders[i] is paired with animal2_fixedorders[i] ----
animal1_fixedorders = ['eddie', 'dodson', 'dannon', 'ginger', 'koala']
animal2_fixedorders = ['sparkle', 'scorch', 'kanga', 'kanga', 'vermelho']
animal_pooled_list = ['E', 'SP', 'DO', 'SC', 'DA', 'KwDA', 'G', 'KwG', 'KO', 'V']

# alternative: three pairs only
# animal1_fixedorders = ['eddie','dodson','ginger',]
# animal2_fixedorders = ['sparkle','scorch','kanga',]
# animal_pooled_list = ['E','SP','DO','SC','G','KwG',]

# the trained-MC analysis always uses the full five-pair list
if do_trainedMCs:
    animal1_fixedorders = ['eddie', 'dodson', 'dannon', 'ginger', 'koala']
    animal2_fixedorders = ['sparkle', 'scorch', 'kanga', 'kanga', 'vermelho']
    animal_pooled_list = ['E', 'SP', 'DO', 'SC', 'DA', 'KwDA', 'G', 'KwG', 'KO', 'V']

# alternative: leave out dannon/kanga, who did not have coop 3s data
# animal1_fixedorders = ['eddie','dodson','ginger','koala']
# animal2_fixedorders = ['sparkle','scorch','kanga','vermelho']
# animal_pooled_list = ['E','SP','DO','SC','G','KwG','KO','V']

nanimalpairs = len(animal1_fixedorders)
nanimalpooled = len(animal_pooled_list)

nMIbootstraps = 150

# ---- time lags ----
# timelags entries: 1, 2 or 3 (single lag), 0 (all three lags merged) or 12 (lags 1 and 2 merged)
# matching timelagnames: '1secondlag' / '2secondlag' / '3secondlag', 'merged', '12merged'
# timelags = [1,2,3]; timelagnames = ['1secondlag','2secondlag','3secondlag']
# timelags = [12];    timelagnames = ['12merged']
timelags = [0]
timelagnames = ['merged']  # goes together with timelags = [0]
ntimelags_forplot = len(timelags)

# ---- conditions entering the modulation index ----
# each of the three settings may also be 'coop(2s)' or 'coop(1.5s)'
MI_basetype = 'self'
# MI_comptypes = ['coop(3s)','coop(2s)','coop(1.5s)','coop(1s)']
MI_comptypes = ['coop(1s)']
MI_conttype = 'no-vision'
nMI_comptypes = len(MI_comptypes)

# ---- dependencies: names of the six columns, and the subset that gets a panel ----
dependencynames = ['pull-pull', 'gaze-gaze', 'within_gazepull', 'across_gazepull',
                   'within_pullgaze', 'across_pullgaze']
# dependencytargets = dependencynames
# dependencytargets = ['pull-pull','within_gazepull','across_pullgaze','pullgaze_merged']
dependencytargets = ['pull-pull', 'within_gazepull', 'across_pullgaze']
ndeptargets = len(dependencytargets)
    
#
# one row of panels per plotted time lag, one column per dependency target;
# every panel gets a 10 x 10 inch area
fig, axs = plt.subplots(ntimelags_forplot, ndeptargets)
fig.set_size_inches(10*ndeptargets, 10*ntimelags_forplot)

#
for itimelag in np.arange(0,ntimelags_forplot,1):
    timelag = timelags[itimelag]
    timelagname = timelagnames[itimelag]
    
    MI_coop_self_all_IndiAni_pooled = dict.fromkeys(MI_comptypes,[])
    MI_nov_coop_all_IndiAni_pooled =  dict.fromkeys(MI_comptypes,[])
    MI_coop_self_mean_IndiAni_pooled =dict.fromkeys(MI_comptypes,[])
    MI_nov_coop_mean_IndiAni_pooled = dict.fromkeys(MI_comptypes,[])
    
    for iMI_comptype in np.arange(0,nMI_comptypes,1):
        MI_comptype = MI_comptypes[iMI_comptype]
        
        # 
        MI_coop_self_all_IndiAni = np.zeros([nanimalpairs*2,nMIbootstraps,6])
        MI_coop_self_mean_IndiAni = np.zeros([nanimalpairs*2,6])
        MI_nov_coop_all_IndiAni = np.zeros([nanimalpairs*2,nMIbootstraps,6])
        MI_nov_coop_mean_IndiAni = np.zeros([nanimalpairs*2,6])
        
        ntimelags = 1
        if timelag == 0:
            ntimelags = 3
            MI_coop_self_all_IndiAni = np.zeros([nanimalpairs*2,nMIbootstraps*3,6])
            MI_coop_self_mean_IndiAni = np.zeros([nanimalpairs*2,6])
            MI_nov_coop_all_IndiAni = np.zeros([nanimalpairs*2,nMIbootstraps*3,6])
            MI_nov_coop_mean_IndiAni = np.zeros([nanimalpairs*2,6])
        if timelag == 12:
            ntimelags = 2
            MI_coop_self_all_IndiAni = np.zeros([nanimalpairs*2,nMIbootstraps*2,6])
            MI_coop_self_mean_IndiAni = np.zeros([nanimalpairs*2,6])
            MI_nov_coop_all_IndiAni = np.zeros([nanimalpairs*2,nMIbootstraps*2,6])
            MI_nov_coop_mean_IndiAni = np.zeros([nanimalpairs*2,6])

            

        for ianimalpair in np.arange(0,nanimalpairs,1):
            animal1_fixedorder = animal1_fixedorders[ianimalpair]
            animal2_fixedorder = animal2_fixedorders[ianimalpair]
            #
            data_saved_subfolder = data_saved_folder+'data_saved_singlecam_wholebody'+savefile_sufix+'_3lags_withinLayerEdges/'+cameraID+'/'+animal1_fixedorder+animal2_fixedorder+'/'
            #
            if moreSampSize:
                with open(data_saved_subfolder+'/weighted_graphs_diffTempRo_diffSampSize_'+animal1_fixedorder+animal2_fixedorder+'_moreSampSize.pkl', 'rb') as f:
                    weighted_graphs_diffTempRo_diffSampSize = pickle.load(f)
                with open(data_saved_subfolder+'/weighted_graphs_shuffled_diffTempRo_diffSampSize_'+animal1_fixedorder+animal2_fixedorder+'_moreSampSize.pkl', 'rb') as f:
                    weighted_graphs_shuffled_diffTempRo_diffSampSize = pickle.load(f)
                with open(data_saved_subfolder+'/sig_edges_diffTempRo_diffSampSize_'+animal1_fixedorder+animal2_fixedorder+'_moreSampSize.pkl', 'rb') as f:
                    sig_edges_diffTempRo_diffSampSize = pickle.load(f)
            else:
                with open(data_saved_subfolder+'/weighted_graphs_diffTempRo_diffSampSize_'+animal1_fixedorder+animal2_fixedorder+'.pkl', 'rb') as f:
                    weighted_graphs_diffTempRo_diffSampSize = pickle.load(f)
                with open(data_saved_subfolder+'/weighted_graphs_shuffled_diffTempRo_diffSampSize_'+animal1_fixedorder+animal2_fixedorder+'.pkl', 'rb') as f:
                    weighted_graphs_shuffled_diffTempRo_diffSampSize = pickle.load(f)
                with open(data_saved_subfolder+'/sig_edges_diffTempRo_diffSampSize_'+animal1_fixedorder+animal2_fixedorder+'.pkl', 'rb') as f:
                    sig_edges_diffTempRo_diffSampSize = pickle.load(f)

            # make sure these variables are the same as in the previous steps
            # temp_resolus = [0.5,1,1.5,2] # temporal resolution in the DBN model, eg: 0.5 means 500ms
            temp_resolus = [1] # temporal resolution in the DBN model, eg: 0.5 means 500ms
            ntemp_reses = np.shape(temp_resolus)[0]
            #
            if moreSampSize:
                # different data (down/re)sampling numbers
                # samplingsizes = np.arange(1100,3000,100)
                samplingsizes = [1100]
                # samplingsizes = [100,500,1000,1500,2000,2500,3000]        
                # samplingsizes = [100,500]
                # samplingsizes_name = ['100','500','1000','1500','2000','2500','3000']
                samplingsizes_name = list(map(str, samplingsizes))
            else:
                samplingsizes_name = ['min_row_number']   
            nsamplings = np.shape(samplingsizes_name)[0]

            #
            temp_resolu = temp_resolus[0]
            j_sampsize_name = samplingsizes_name[0]    

            # load edge weight data    
            weighted_graphs_self = weighted_graphs_diffTempRo_diffSampSize[(str(temp_resolu),j_sampsize_name)][MI_basetype]
            weighted_graphs_sf_self = weighted_graphs_shuffled_diffTempRo_diffSampSize[(str(temp_resolu),j_sampsize_name)][MI_basetype]
            sig_edges_self = sig_edges_diffTempRo_diffSampSize[(str(temp_resolu),j_sampsize_name)][MI_basetype]
            #
            weighted_graphs_coop = weighted_graphs_diffTempRo_diffSampSize[(str(temp_resolu),j_sampsize_name)][MI_comptype]
            weighted_graphs_sf_coop = weighted_graphs_shuffled_diffTempRo_diffSampSize[(str(temp_resolu),j_sampsize_name)][MI_comptype]
            sig_edges_coop = sig_edges_diffTempRo_diffSampSize[(str(temp_resolu),j_sampsize_name)][MI_comptype]
            #
            weighted_graphs_nov = weighted_graphs_diffTempRo_diffSampSize[(str(temp_resolu),j_sampsize_name)][MI_conttype]
            weighted_graphs_sf_nov = weighted_graphs_shuffled_diffTempRo_diffSampSize[(str(temp_resolu),j_sampsize_name)][MI_conttype]
            sig_edges_nov = sig_edges_diffTempRo_diffSampSize[(str(temp_resolu),j_sampsize_name)][MI_conttype]

            ## mainly for dannon kanga because they dont have coop3s
            if np.shape(weighted_graphs_coop)[0]==0:
                weighted_graphs_coop = np.ones((95,12,4))*np.nan
                weighted_graphs_sf_coop = np.ones((95,12,4))*np.nan
                sig_edges_coop = np.ones((12,4))*np.nan
            
            # organize the key edge data
            weighted_graphs_self_mean = weighted_graphs_self.mean(axis=0)
            weighted_graphs_coop_mean = weighted_graphs_coop.mean(axis=0)
            weighted_graphs_nov_mean = weighted_graphs_nov.mean(axis=0)
            # MI_coop_self = (weighted_graphs_coop_mean-weighted_graphs_self_mean)/(weighted_graphs_coop_mean+weighted_graphs_self_mean)
            # MI_nov_coop = (weighted_graphs_nov_mean-weighted_graphs_coop_mean)/(weighted_graphs_nov_mean+weighted_graphs_coop_mean)
            # MI_coop_self = ((weighted_graphs_coop-weighted_graphs_self)/(weighted_graphs_coop+weighted_graphs_self)).mean(axis=0)
            # MI_nov_coop = ((weighted_graphs_nov-weighted_graphs_coop)/(weighted_graphs_nov+weighted_graphs_coop)).mean(axis=0)
            #
            if 0:
                MI_coop_self_all = weighted_graphs_coop-weighted_graphs_self
                MI_nov_coop_all = weighted_graphs_nov-weighted_graphs_coop  
                MI_coop_self = (weighted_graphs_coop-weighted_graphs_self).mean(axis=0)
                MI_nov_coop = (weighted_graphs_nov-weighted_graphs_coop).mean(axis=0)
                #
                sig_edges_coop_self = ((sig_edges_coop+sig_edges_self)>0)*1
                sig_edges_nov_coop = ((sig_edges_coop+sig_edges_nov)>0)*1
                #
                MI_coop_self = MI_coop_self * sig_edges_coop_self
                MI_nov_coop = MI_nov_coop * sig_edges_nov_coop
                #
                nMIbootstraps = 1
            else:
                nMIbootstraps = 150
                #
                MI_coop_self_all,sig_edges_coop_self = Modulation_Index(weighted_graphs_self, weighted_graphs_coop,
                                                  sig_edges_self, sig_edges_coop, nMIbootstraps)
                # MI_coop_self_all = MI_coop_self_all * sig_edges_coop_self
                # MI_coop_self_all[MI_coop_self_all==0] = np.nan
                MI_coop_self = np.nanmean(MI_coop_self_all,axis = 0)
                # MI_coop_self = MI_coop_self * sig_edges_coop_self
                
                MI_nov_coop_all,sig_edges_nov_coop  = Modulation_Index(weighted_graphs_coop, weighted_graphs_nov,
                                                  sig_edges_coop, sig_edges_nov, nMIbootstraps)
                # MI_nov_coop_all = MI_nov_coop_all * sig_edges_nov_coop
                # MI_nov_coop_all[MI_nov_coop_all==0] = np.nan
                MI_nov_coop = np.nanmean(MI_nov_coop_all,axis = 0)
                # MI_nov_coop = MI_nov_coop * sig_edges_nov_coop
    
            #
            if timelag == 1:
                pull_pull_fromNodes_all = [9,8]
                pull_pull_toNodes_all = [0,1]
                #
                gaze_gaze_fromNodes_all = [11,10]
                gaze_gaze_toNodes_all = [2,3]
                #
                within_pullgaze_fromNodes_all = [8,9]
                within_pullgaze_toNodes_all = [2,3]
                #
                across_pullgaze_fromNodes_all = [9,8]
                across_pullgaze_toNodes_all = [2,3]
                #
                within_gazepull_fromNodes_all = [10,11]
                within_gazepull_toNodes_all = [0,1]
                #
                across_gazepull_fromNodes_all = [11,10]
                across_gazepull_toNodes_all = [0,1]
                #
                ntimelags = 1
                #
            elif timelag == 2:
                pull_pull_fromNodes_all = [5,4]
                pull_pull_toNodes_all = [0,1]
                #
                gaze_gaze_fromNodes_all = [7,6]
                gaze_gaze_toNodes_all = [2,3]
                #
                within_pullgaze_fromNodes_all = [4,5]
                within_pullgaze_toNodes_all = [2,3]
                #
                across_pullgaze_fromNodes_all = [5,4]
                across_pullgaze_toNodes_all = [2,3]
                #
                within_gazepull_fromNodes_all = [6,7]
                within_gazepull_toNodes_all = [0,1]
                #
                across_gazepull_fromNodes_all = [7,6]
                across_gazepull_toNodes_all = [0,1]
                #
                ntimelags = 1
                #
            elif timelag == 3:
                pull_pull_fromNodes_all = [1,0]
                pull_pull_toNodes_all = [0,1]
                #
                gaze_gaze_fromNodes_all = [3,2]
                gaze_gaze_toNodes_all = [2,3]
                #
                within_pullgaze_fromNodes_all = [0,1]
                within_pullgaze_toNodes_all = [2,3]
                #
                across_pullgaze_fromNodes_all = [1,0]
                across_pullgaze_toNodes_all = [2,3]
                #
                within_gazepull_fromNodes_all = [2,3]
                within_gazepull_toNodes_all = [0,1]
                #
                across_gazepull_fromNodes_all = [3,2]
                across_gazepull_toNodes_all = [0,1]
                #
                ntimelags = 1
                #
            elif timelag == 0:
                pull_pull_fromNodes_all = [[1,5,9],[0,4,8]]
                pull_pull_toNodes_all = [[0,0,0],[1,1,1]]
                #
                gaze_gaze_fromNodes_all = [[3,7,11],[2,6,10]]
                gaze_gaze_toNodes_all = [[2,2,2],[3,3,3]]
                #
                within_pullgaze_fromNodes_all = [[0,4,8],[1,5,9]]
                within_pullgaze_toNodes_all = [[2,2,2],[3,3,3]]
                #
                across_pullgaze_fromNodes_all = [[1,5,9],[0,4,8]]
                across_pullgaze_toNodes_all = [[2,2,2],[3,3,3]]
                #
                within_gazepull_fromNodes_all = [[2,6,10],[3,7,11]]
                within_gazepull_toNodes_all = [[0,0,0],[1,1,1]]
                #
                across_gazepull_fromNodes_all = [[3,7,11],[2,6,10]]
                across_gazepull_toNodes_all = [[0,0,0],[1,1,1]]
                #
                ntimelags = 3
                #
            elif timelag == 12:
                pull_pull_fromNodes_all = [[5,9],[4,8]]
                pull_pull_toNodes_all = [[0,0],[1,1]]
                #
                gaze_gaze_fromNodes_all = [[7,11],[6,10]]
                gaze_gaze_toNodes_all = [[2,2],[3,3]]
                #
                within_pullgaze_fromNodes_all = [[4,8],[5,9]]
                within_pullgaze_toNodes_all = [[2,2],[3,3]]
                #
                across_pullgaze_fromNodes_all = [[5,9],[4,8]]
                across_pullgaze_toNodes_all = [[2,2],[3,3]]
                #
                within_gazepull_fromNodes_all = [[6,10],[7,11]]
                within_gazepull_toNodes_all = [[0,0],[1,1]]
                #
                across_gazepull_fromNodes_all = [[7,11],[6,10]]
                across_gazepull_toNodes_all = [[0,0],[1,1]]
                #
                ntimelags = 2
                #
    
    
            for ianimal in np.arange(0,2,1):

                # row of the per-animal summary arrays that belongs to this animal
                ianimal_row = 2*ianimalpair+ianimal

                # (from nodes, to nodes) for every dependency type; the position in
                # this list is the dependency column written in the summary arrays
                dependency_nodepairs = [
                    (pull_pull_fromNodes_all, pull_pull_toNodes_all),              # 0: pull-pull
                    (gaze_gaze_fromNodes_all, gaze_gaze_toNodes_all),              # 1: gaze-gaze
                    (within_gazepull_fromNodes_all, within_gazepull_toNodes_all),  # 2: within animal gazepull
                    (across_gazepull_fromNodes_all, across_gazepull_toNodes_all),  # 3: across animal gazepull
                    (within_pullgaze_fromNodes_all, within_pullgaze_toNodes_all),  # 4: within animal pullgaze
                    (across_pullgaze_fromNodes_all, across_pullgaze_toNodes_all),  # 5: across animal pullgaze
                ]

                # (source MI array, per-animal samples array, per-animal mean array)
                # coop self modulation is handled first, then novision coop modulation
                MI_summary_sets = [
                    (MI_coop_self_all, MI_coop_self_all_IndiAni, MI_coop_self_mean_IndiAni),
                    (MI_nov_coop_all, MI_nov_coop_all_IndiAni, MI_nov_coop_mean_IndiAni),
                ]

                for MI_source, MI_samples_out, MI_mean_out in MI_summary_sets:
                    for idepcol, (dep_fromnodes, dep_tonodes) in enumerate(dependency_nodepairs):
                        # all MI values of this animal's edges for this dependency, flattened to 1D
                        dep_samples = MI_source[:, dep_fromnodes[ianimal], dep_tonodes[ianimal]].flatten()
                        dep_mean = np.nanmean(dep_samples)
                        # the output arrays are filled in place
                        MI_samples_out[ianimal_row, :, idepcol] = dep_samples
                        MI_mean_out[ianimal_row, idepcol] = dep_mean


        # prepare the data
        # collapse the first two axes of the per-animal sample arrays so that each
        # row is one sample and each of the 6 columns is one dependency
        npooledsamples = nanimalpooled*nMIbootstraps*ntimelags
        MI_coop_self_all_IndiAni_all = np.reshape(MI_coop_self_all_IndiAni, (npooledsamples, 6))
        MI_nov_coop_all_IndiAni_all = np.reshape(MI_nov_coop_all_IndiAni, (npooledsamples, 6))

        # mean over all pooled samples for each dependency (NaNs ignored)
        MI_coop_self_all_IndiAni_allmean = np.nanmean(MI_coop_self_all_IndiAni_all, axis=0)
        MI_nov_coop_all_IndiAni_allmean = np.nanmean(MI_nov_coop_all_IndiAni_all, axis=0)

        # standard error: NaN-ignoring std divided by sqrt of the total row count
        # (the divisor counts every row, including rows that hold NaN)
        sqrt_npooledsamples = np.sqrt(npooledsamples)
        MI_coop_self_all_IndiAni_allse = np.nanstd(MI_coop_self_all_IndiAni_all, axis=0)/sqrt_npooledsamples
        MI_nov_coop_all_IndiAni_allse = np.nanstd(MI_nov_coop_all_IndiAni_all, axis=0)/sqrt_npooledsamples


        # pool everything together, keyed by the current MI comparison type
        MI_coop_self_all_IndiAni_pooled[MI_comptype] = MI_coop_self_all_IndiAni_all
        MI_nov_coop_all_IndiAni_pooled[MI_comptype] = MI_nov_coop_all_IndiAni_all
        MI_coop_self_mean_IndiAni_pooled[MI_comptype] = MI_coop_self_mean_IndiAni
        MI_nov_coop_mean_IndiAni_pooled[MI_comptype] = MI_nov_coop_mean_IndiAni
    

    

    # plot
    # separating animal group types
    animalgrouptypes = ['female','male','dom','sub']
    nanimalgrouptypes = np.shape(animalgrouptypes)[0]

    # rows of the per-animal mean arrays that belong to each animal group type
    anigroupIDs_bytype = {
        'male': [0,2,4,9],
        'female': [1,3,5,6,7,8],
        'sub': [0,2,4,6,9],
        'dom': [1,3,5,7,8],
    }

    def _IndiAni_mean_df(MI_mean_pooled, MItype, CTtype, anitype, aniIDs):
        """Build the per-animal mean MI table of one MI type for one animal group.

        The rows `aniIDs` of `MI_mean_pooled[MItype]` become the rows of the
        table, the columns are named after `dependencynames`, and the labels
        'anitype', 'MItype', 'CTtype' plus the 'pullgaze_merged' column (average
        of 'within_gazepull' and 'across_pullgaze') are appended.
        """
        MI_df = pd.DataFrame(MI_mean_pooled[MItype][aniIDs])
        MI_df.columns = dependencynames
        MI_df['anitype'] = anitype
        MI_df['MItype'] = MItype
        MI_df['CTtype'] = CTtype
        MI_df['pullgaze_merged'] = (MI_df['within_gazepull']+MI_df['across_pullgaze'])/2
        return MI_df

    for ianimalgrouptype, animalgrouptype in enumerate(animalgrouptypes):

        anigroupID = anigroupIDs_bytype[animalgrouptype]

        #
        # average for each animal individuals
        # coop (1s)
        MI_coop_self_mean_IndiAni_MC1s_df = _IndiAni_mean_df(
            MI_coop_self_mean_IndiAni_pooled, 'coop(1s)', MI_basetype, animalgrouptype, anigroupID)
        MI_nov_coop_mean_IndiAni_MC1s_df = _IndiAni_mean_df(
            MI_nov_coop_mean_IndiAni_pooled, 'coop(1s)', MI_conttype, animalgrouptype, anigroupID)

        if not do_trainedMCs:
            # coop (1.5s)
            MI_coop_self_mean_IndiAni_MC15s_df = _IndiAni_mean_df(
                MI_coop_self_mean_IndiAni_pooled, 'coop(1.5s)', MI_basetype, animalgrouptype, anigroupID)
            MI_nov_coop_mean_IndiAni_MC15s_df = _IndiAni_mean_df(
                MI_nov_coop_mean_IndiAni_pooled, 'coop(1.5s)', MI_conttype, animalgrouptype, anigroupID)

            # coop (2s)
            MI_coop_self_mean_IndiAni_MC2s_df = _IndiAni_mean_df(
                MI_coop_self_mean_IndiAni_pooled, 'coop(2s)', MI_basetype, animalgrouptype, anigroupID)
            MI_nov_coop_mean_IndiAni_MC2s_df = _IndiAni_mean_df(
                MI_nov_coop_mean_IndiAni_pooled, 'coop(2s)', MI_conttype, animalgrouptype, anigroupID)

            # coop (3s)
            MI_coop_self_mean_IndiAni_MC3s_df = _IndiAni_mean_df(
                MI_coop_self_mean_IndiAni_pooled, 'coop(3s)', MI_basetype, animalgrouptype, anigroupID)
            MI_nov_coop_mean_IndiAni_MC3s_df = _IndiAni_mean_df(
                MI_nov_coop_mean_IndiAni_pooled, 'coop(3s)', MI_conttype, animalgrouptype, anigroupID)

            # tables added for this animal group, from coop(3s) down to coop(1s)
            newgroup_dfs = [
                MI_coop_self_mean_IndiAni_MC3s_df, MI_nov_coop_mean_IndiAni_MC3s_df,
                MI_coop_self_mean_IndiAni_MC2s_df, MI_nov_coop_mean_IndiAni_MC2s_df,
                MI_coop_self_mean_IndiAni_MC15s_df, MI_nov_coop_mean_IndiAni_MC15s_df,
                MI_coop_self_mean_IndiAni_MC1s_df, MI_nov_coop_mean_IndiAni_MC1s_df,
            ]
        else:
            # only the coop(1s) table of the base type is used in this case
            newgroup_dfs = [MI_coop_self_mean_IndiAni_MC1s_df]

        # the first animal group starts the table, later groups are appended to it
        if ianimalgrouptype == 0:
            df_long = pd.concat(newgroup_dfs)
        else:
            df_long = pd.concat([df_long]+newgroup_dfs)

        # one row per (table row, dependency target); rebuilt on every pass, so the
        # version left after the last animal group holds all groups
        df_long2 = df_long.melt(id_vars=['MItype','CTtype','anitype'], value_vars=dependencytargets,
                                var_name='condition', value_name='value')
        df_long2['ani_ct_type'] = df_long2['anitype']+df_long2['CTtype']
        
    # for plot
    # one violin panel per dependency target, with significance stars on top
    for idep in np.arange(0,ndeptargets,1):
        ind = df_long2.condition==dependencytargets[idep]
        #
        # a single time lag is drawn on a 1D array of axes, several time lags on a
        # 2D (time lag, dependency) grid
        # NOTE(review): the single-lag layout only tests MI_basetype while the
        # multi-lag layout also tests MI_conttype - kept as in the original,
        # confirm that this difference is intended
        if ntimelags_forplot == 1:
            ax_dep = axs[idep]
            CTtypes = [MI_basetype]
        else:
            ax_dep = axs[itimelag,idep]
            CTtypes = [MI_basetype,MI_conttype]
        #
        seaborn.violinplot(ax=ax_dep,data=df_long2[ind],x='MItype',y='value',hue='ani_ct_type')
        ax_dep.plot([0,3],[0,0],'k--')
        ax_dep.set_ylabel('Modulation Index',fontsize=20)
        ax_dep.set_title(timelagname+' '+dependencytargets[idep],fontsize=24)
        ax_dep.set_ylim([-2.02,2.02])
        #
        # add statistics
        # y position of the stars for each entry of CTtypes
        CTtype_plotlocs = [.75,-.75]
        nCTtypes = np.shape(CTtypes)[0]
        for iMItype in np.arange(0,nMI_comptypes,1):
            MItype_totest = MI_comptypes[iMItype]
            #
            for iCTtype in np.arange(0,nCTtypes,1):
                CTtype_totest = CTtypes[iCTtype]
                #
                # NOTE(review): the selection does not filter on 'anitype', so the
                # rows of every animal group type in df_long2 are pooled into one
                # test; the group ID lists overlap (e.g. 'male' and 'sub'), so the
                # same animal can enter more than once - confirm this is intended
                ind_totest = (df_long2['condition']==dependencytargets[idep])&(df_long2['MItype']==MItype_totest)&(df_long2['CTtype']==CTtype_totest)
                data_totest = np.array(df_long2[ind_totest]['value'])
                # one-sample Wilcoxon signed-rank test on the non-NaN values
                # alternative: pp = st.ttest_1samp(data_totest[~np.isnan(data_totest)],0).pvalue
                pp = st.wilcoxon(data_totest[~np.isnan(data_totest)]).pvalue
                #
                if pp<=0.001:
                    sigmark = '***'
                elif pp<=0.01:
                    sigmark = '**'
                elif pp<=0.05:
                    sigmark = '*'
                else:
                    # not significant (or pp is NaN): nothing is drawn
                    sigmark = None
                if sigmark is not None:
                    ax_dep.text(iMItype,CTtype_plotlocs[iCTtype],sigmark,fontsize=20)
        
plt.tight_layout()

# save the summary figure as a pdf (set savefig to 0 to skip saving)
savefig = 1
if savefig:
    # both sample-size settings save into the same folder, only the file name differs
    figsavefolder = data_saved_folder+'figs_for_3LagDBN_withinLayerEdges_and_bhv_singlecam_wholebodylabels_combinesessions_basicEvents/'+savefile_sufix+'/'+cameraID+'/'
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
    os.makedirs(figsavefolder, exist_ok=True)
    # str() is applied in both cases so a non-string j_sampsize_name cannot break
    # the file name concatenation
    if moreSampSize:
        sampsize_tag = str(j_sampsize_name)+'_rows_subset'
    else:
        sampsize_tag = str(j_sampsize_name)+'_subset'
    plt.savefig(figsavefolder+'threeTimeLag_Edge_ModulationIndex_'+timelagname+'Lag_allanimalgrouptypes_summarized_'+str(temp_resolu)+'_'+sampsize_tag+'_basedonToNodes_multiTimeLag_multiCoopsOnePanel.pdf')
           
    


In [None]:
# display df_long, the concatenated per-animal mean modulation index tables built in the previous cell
df_long

In [None]:
# Mann-Whitney U comparison between animal groups (sub vs dom, then male vs female)
# for each dependency of interest
# NOTE(review): df_long is not filtered on 'MItype' or 'CTtype' here, so every row
# of an animal group is pooled into the comparison - confirm this is intended
for dependency_totest in ['within_gazepull','across_pullgaze','pull-pull']:
    for anitype_a, anitype_b in [('sub','dom'),('male','female')]:
        values_a = df_long[df_long['anitype']==anitype_a][dependency_totest]
        values_b = df_long[df_long['anitype']==anitype_b][dependency_totest]
        print(st.mannwhitneyu(values_a,values_b))

