### In this script, the DBN has already been run; this script uses it to make predictions
### In this script, the DBN is run with a 1 s time bin and 3 time lags
### In this script, the animal tracking uses only one camera - camera 2 (middle)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn
import scipy
import scipy.stats as st
import sklearn
from sklearn.neighbors import KernelDensity
import string
import warnings
import pickle

import os
import glob
import random
from time import time

from pgmpy.models import BayesianModel
from pgmpy.models import DynamicBayesianNetwork as DBN
from pgmpy.models import BayesianNetwork
from pgmpy.estimators import MaximumLikelihoodEstimator
from pgmpy.inference import VariableElimination
from pgmpy.estimators import BayesianEstimator
from pgmpy.estimators import HillClimbSearch,BicScore
from pgmpy.base import DAG
import networkx as nx

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, roc_curve

### function - get body part location for each pair of cameras

In [None]:
from ana_functions.body_part_locs_eachpair import body_part_locs_eachpair
from ana_functions.body_part_locs_singlecam import body_part_locs_singlecam

### function - align the two cameras

In [None]:
from ana_functions.camera_align import camera_align       

### function - merge the two pairs of cameras

In [None]:
from ana_functions.camera_merge import camera_merge

### function - find social gaze time point

In [None]:
from ana_functions.find_socialgaze_timepoint import find_socialgaze_timepoint
from ana_functions.find_socialgaze_timepoint_singlecam import find_socialgaze_timepoint_singlecam
from ana_functions.find_socialgaze_timepoint_singlecam_wholebody import find_socialgaze_timepoint_singlecam_wholebody

### function - define time point of behavioral events

In [None]:
from ana_functions.bhv_events_timepoint import bhv_events_timepoint
from ana_functions.bhv_events_timepoint_singlecam import bhv_events_timepoint_singlecam

### function - plot behavioral events

In [None]:
from ana_functions.plot_bhv_events import plot_bhv_events
from ana_functions.plot_bhv_events_levertube import plot_bhv_events_levertube
from ana_functions.draw_self_loop import draw_self_loop
import matplotlib.patches as mpatches 
from matplotlib.collections import PatchCollection

### function - make demo videos with skeleton and important vectors

In [None]:
from ana_functions.tracking_video_singlecam_demo import tracking_video_singlecam_demo
from ana_functions.tracking_video_singlecam_wholebody_demo import tracking_video_singlecam_wholebody_demo

### function - interval between all behavioral events

In [None]:
from ana_functions.bhv_events_interval import bhv_events_interval
from ana_functions.bhv_events_interval import bhv_events_interval_certainEdges

### function - train the dynamic bayesian network - multi time lag (3 lags)

In [None]:
from ana_functions.train_DBN_multiLag import train_DBN_multiLag
from ana_functions.train_DBN_multiLag import train_DBN_multiLag_create_df_only
from ana_functions.train_DBN_multiLag import train_DBN_multiLag_training_only
from ana_functions.train_DBN_multiLag import graph_to_matrix
from ana_functions.train_DBN_multiLag import get_weighted_dags
from ana_functions.train_DBN_multiLag import get_significant_edges
from ana_functions.train_DBN_multiLag import threshold_edges
from ana_functions.train_DBN_multiLag import Modulation_Index
from ana_functions.EfficientTimeShuffling import EfficientShuffle
from ana_functions.AicScore import AicScore

## Analyze each session

### prepare the basic behavioral data (especially the time stamps for each bhv events)

In [None]:
# Gaze is decided from target rectangles rather than a gaze-angle threshold
# (...still needs updating)
sqr_thres_tubelever = 75  # size of the square drawn around the tube and the lever
sqr_thres_face = 1.15     # ratio defining the face boundary
sqr_thres_body = 4        # how many times the face box is elongated to cover the body

# frame rate of the analyzed video
fps = 30

# number of frames in the demo video (seconds * fps)
# nframes = 0.5*30
nframes = 2 * fps

# whether to re-analyze the video / redo any step
reanalyze_video = 0
redo_anystep = 0

# session list options
do_bestsession = 1  # only analyze the best (five) sessions for each condition during the training phase
do_trainedMCs = 1   # only consider trained (1s) MC sessions, with SR and NV as controls
if not do_bestsession:
    savefile_sufix = ''
elif do_trainedMCs:
    savefile_sufix = '_trainedMCsessions'
else:
    savefile_sufix = '_bestsessions'

# camera used for the analysis
cameraID = 'camera-2'
cameraID_short = 'cam2'

# where the summarizing data is saved
data_saved_folder = '/gpfs/radev/pi/nandy/jadi_gibbs_data/VideoTracker_SocialInter/3d_recontruction_analysis_self_and_coop_task_data_saved/'


    

### load the DBN related data for each dyad and run the prediction
### For each condition, only use the hypothetical dependencies
### For each dyad, align animal1 as the subordinate animal and animal2 as the dominant animal
### For each training dyad and each iteration, train on a random 80% of its data and test on a random subset of each dyad's data (including the training dyad itself)

In [None]:
# Suppress all warnings
warnings.simplefilter('ignore')

redoFitting = 0  # 1: ignore any saved ROC summary and recompute it

do_succfull = 0  # 1: use the 'succpull' entry of the DBN input data (successful/failed pull definition)

niters = 100  # random train/test splits per (training dyad, testing dyad) pair

# PLOT multiple pairs in one plot, so need to load data separately
moreSampSize = 0
mergetempRos = 0 # 1: merge different time bins

# dyads: animal1_fixedorders[i] is paired with animal2_fixedorders[i]
animal1_fixedorders = ['eddie','dodson','ginger','dannon','koala']
animal2_fixedorders = ['sparkle','scorch','kanga','kanga','vermelho']
# animal1_fixedorders = ['eddie',]
# animal2_fixedorders = ['sparkle',]
nanimalpairs = len(animal1_fixedorders)

# dominant animal names; since animal1 and 2 are already aligned, no need to check it; but keep it here for reference
# NOTE(review): 'koala' is listed as dominant although koala is in animal1_fixedorders
# (the subordinate slot per the alignment above) and a later cell keeps a commented-out
# variant with 'vermelho' instead -- confirm which animal of that dyad is dominant.
dom_animal_names = ['sparkle','scorch','kanga_withG','kanga_withD','koala']

temp_resolu = 1

# ONLY FOR PLOT!!
# define DBN related summarizing variables
# DBN_group_typenames = ['self','coop(2s)','coop(1.5s)','coop(1s)','no-vision']
# DBN_group_typeIDs  =  [1,3,3,3,5]
# DBN_group_coopthres = [0,2,1.5,1,0]
DBN_group_typenames = ['coop(2s)']
DBN_group_typeIDs  =  [3,]
DBN_group_coopthres = [2,]
if do_trainedMCs:
    DBN_group_typenames = ['coop(1s)']
    DBN_group_typeIDs  =  [3]
    DBN_group_coopthres = [1]
nDBN_groups = np.shape(DBN_group_typenames)[0]

# DBN input data - to and from Nodes
toNodes = ['pull1_t3','pull2_t3','owgaze1_t3','owgaze2_t3']
fromNodes = ['pull1_t2','pull2_t2','owgaze1_t2','owgaze2_t2']
eventnames = ["M1pull","M2pull","M1gaze","M2gaze"]
nevents = np.shape(eventnames)[0]

timelagtype = '' # '' means 1secondlag, otherwise will be specified
#
if timelagtype == '2secondlag':
    fromNodes = ['pull1_t1','pull2_t1','owgaze1_t1','owgaze2_t1']
if timelagtype == '3secondlag':
    fromNodes = ['pull1_t0','pull2_t0','owgaze1_t0','owgaze2_t0']

# hypothetical graph structures that reflect the strategies
# (rows index fromNodes, columns index toNodes)
# strategynames = ['threeMains','sync_pulls','gaze_lead_pull','social_attention','other_dependencies','other_noself_dependcies']
strategynames = ['threeMains',]
# strategynames = ['sync_pulls'] # ['all_threes','sync_pulls','gaze_lead_pull','social_attention']
bina_graphs_specific_strategy = {
    'threeMains': np.array([[0,1,0,1],[1,0,1,0],[1,0,0,0],[0,1,0,0]]),
    'sync_pulls': np.array([[0,1,0,0],[1,0,0,0],[0,0,0,0],[0,0,0,0]]),
    'gaze_lead_pull':np.array([[0,0,0,0],[0,0,0,0],[1,0,0,0],[0,1,0,0]]),
    'social_attention':np.array([[0,0,0,1],[0,0,1,0],[0,0,0,0],[0,0,0,0]]),
    'other_dependencies': np.array([[1,0,1,0],[0,1,0,1],[0,1,1,1],[1,0,1,1]]),
    'other_noself_dependcies': np.array([[0,0,1,0],[0,0,0,1],[0,1,0,1],[1,0,1,0]]),
}
nstrategies_forplot = np.shape(strategynames)[0]

# column order of the summary dataframe (kept identical to the previous append-based version)
ROC_summary_columns = ['train_animal','test_animal','action','testCondition','predROC',
                       'train_dyadID','test_dyadID','iters']


def _nooverlap_name(animal1, animal2):
    """Return animal2's label, disambiguating kanga, who is paired with two different animals."""
    if animal2 == 'kanga':
        if animal1 == 'ginger':
            return 'kanga_withG'
        if animal1 == 'dannon':
            return 'kanga_withD'
    return animal2


def _load_DBN_input(animal1, animal2, group_typename):
    """Load one dyad's DBN input dataframe for one task condition from its pickle file."""
    if do_succfull:
        subfolder_root = 'data_saved_singlecam_wholebody_SuccAndFailedPull_newDefinition'
    else:
        subfolder_root = 'data_saved_singlecam_wholebody'
    pairname = animal1 + animal2
    subfolder = data_saved_folder + subfolder_root + savefile_sufix + '_3lags/' + cameraID + '/' + pairname + '/'
    if mergetempRos:
        filename = subfolder + '/DBN_input_data_alltypes_' + pairname + '_mergeTempsReSo.pkl'
    else:
        filename = subfolder + '/DBN_input_data_alltypes_' + pairname + '_' + str(temp_resolu) + 'sReSo.pkl'
    with open(filename, 'rb') as f:
        data_alltypes = pickle.load(f)
    if do_succfull:
        return data_alltypes['succpull'][group_typename]
    return data_alltypes[group_typename]


def _swap_animal_columns(df):
    """Return df with animal1/animal2 columns exchanged (pull1<->pull2, owgaze1<->owgaze2)."""
    swaps = (('pull1', 'pull2'), ('pull2', 'pull1'), ('owgaze1', 'owgaze2'), ('owgaze2', 'owgaze1'))
    mapping = {}
    for col in df.columns:
        for old, new in swaps:
            if old in col:
                mapping[col] = col.replace(old, new)
                break
    return df.rename(columns=mapping)


def _predict_probs(infer, var, data, cache):
    """Return, for every row of data, the queried probability of state index 1 of var given the fromNodes evidence.

    cache maps (var, evidence values) -> probability; it must be reset whenever the model is refit.
    """
    probs = np.empty(len(data))
    for irow, values in enumerate(data[fromNodes].itertuples(index=False, name=None)):
        key = (var, values)
        if key not in cache:
            evidence = dict(zip(fromNodes, values))
            # values[1]: probability of the second state of var (outcome = 1 for 0/1 events)
            cache[key] = infer.query(variables=[var], evidence=evidence).values[1]
        probs[irow] = cache[key]
    return probs


def _safe_auc(true_labels, probs):
    """ROC AUC, or NaN when it is undefined (e.g. only one class present in the test labels)."""
    try:
        return roc_auc_score(true_labels, probs)
    except ValueError:
        return np.nan


for strategyname in strategynames:

    bina_graph_mean_strg = bina_graphs_specific_strategy[strategyname]

    # translate the binary DAG into edges (fromNodes[row] -> toNodes[col])
    nrows, ncols = np.shape(bina_graph_mean_strg)
    edgenames = [(fromNodes[irow], toNodes[icol])
                 for irow in range(nrows) for icol in range(ncols)
                 if bina_graph_mean_strg[irow, icol] > 0]

    # define the DBN predicting model
    bn = BayesianNetwork()
    bn.add_nodes_from(fromNodes)
    bn.add_nodes_from(toNodes)
    bn.add_edges_from(edgenames)

    effect_slice = toNodes

    data_saved_subfolder = data_saved_folder+'data_saved_singlecam_wholebodylabels_combinesessions_basicEvents_DBNpredictions_cross_dyad_validation/'+\
                           savefile_sufix+'/'+cameraID+'/'
    ROC_summary_file = data_saved_subfolder+'/ROC_summary_all_dependencies_'+strategyname+timelagtype+'.pkl'

    # load ROC_summary_all data unless a refit is requested; a missing or unreadable
    # file is not fatal, the summary is recomputed below
    ROC_summary_all = None
    if not redoFitting:
        try:
            print('load all ROC data for hypothetical dependencies, and only plot the summary figure')
            with open(ROC_summary_file, 'rb') as f:
                ROC_summary_all = pickle.load(f)
        except Exception as err:
            print(f'could not load saved ROC data ({err!r}); recomputing')
            ROC_summary_all = None

    if ROC_summary_all is None:

        ROC_rows = []

        for DBN_group_typename in DBN_group_typenames:

            # load every dyad's data once per condition (instead of once per train/test pair)
            DBN_input_data_dyads = [_load_DBN_input(animal1_fixedorders[ipair], animal2_fixedorders[ipair], DBN_group_typename)
                                    for ipair in range(nanimalpairs)]

            for ianimalpair_train in range(nanimalpairs):
                animal1_train = animal1_fixedorders[ianimalpair_train]
                animal2_train = animal2_fixedorders[ianimalpair_train]
                animal2_train_nooverlap = _nooverlap_name(animal1_train, animal2_train)
                DBN_input_data_train = DBN_input_data_dyads[ianimalpair_train]

                for ianimalpair_test in range(nanimalpairs):
                    animal1_test = animal1_fixedorders[ianimalpair_test]
                    animal2_test = animal2_fixedorders[ianimalpair_test]
                    animal2_test_nooverlap = _nooverlap_name(animal1_test, animal2_test)
                    DBN_input_data_test = DBN_input_data_dyads[ianimalpair_test]

                    for iiter in range(niters):

                        train_data, heldout_data = train_test_split(DBN_input_data_train, test_size=0.2)
                        if ianimalpair_test == ianimalpair_train:
                            # Same dyad: evaluate only on the held-out 20%. An independent split of
                            # the same data would put most test rows in the training set and
                            # inflate the within-dyad AUC.
                            test_data = heldout_data
                        else:
                            _, test_data = train_test_split(DBN_input_data_test, test_size=0.5)

                        # Perform parameter learning, then build the inference engine
                        bn.fit(train_data, estimator=MaximumLikelihoodEstimator)
                        infer = VariableElimination(bn)
                        query_cache = {}  # reset: the model was just refit

                        # Each entry: (test data, test-animal label for animal1 events, for animal2 events).
                        # Training set always has animal1 = sub, animal2 = dom.
                        test_sets = (
                            # aligned testing set: animal1 = sub, animal2 = dom
                            (test_data, animal1_test, animal2_test_nooverlap),
                            # swapped testing set: animal1 = dom, animal2 = sub
                            (_swap_animal_columns(test_data), animal2_test_nooverlap, animal1_test),
                        )
                        for data_eval, test_label_ev1, test_label_ev2 in test_sets:
                            for ievent in range(nevents):
                                var = effect_slice[ievent]
                                Pbehavior = _predict_probs(infer, var, data_eval, query_cache)
                                auc = _safe_auc(data_eval[var].values, Pbehavior)
                                print(f"AUC Score: {auc:.4f}")

                                is_animal1_event = ievent in (0, 2)  # M1pull, M1gaze
                                ROC_rows.append({
                                    'train_animal': animal1_train if is_animal1_event else animal2_train_nooverlap,
                                    'test_animal': test_label_ev1 if is_animal1_event else test_label_ev2,
                                    'train_dyadID': ianimalpair_train,
                                    'test_dyadID': ianimalpair_test,
                                    'action': eventnames[ievent][2:],
                                    'testCondition': DBN_group_typename,
                                    'predROC': auc,
                                    'iters': iiter,
                                })

        # build the summary once (DataFrame.append was removed in pandas 2.0)
        ROC_summary_all = pd.DataFrame(ROC_rows, columns=ROC_summary_columns)

        # save the summarizing data ROC_summary_all
        savedata = 0
        if savedata:
            os.makedirs(data_saved_subfolder, exist_ok=True)
            with open(ROC_summary_file, 'wb') as f:
                pickle.dump(ROC_summary_all, f)

            
            
    

In [None]:
# Disabled (if 0) quick-look heatmap of pull-prediction AUCs in the coop(1s) condition.
if 0:
    ind = (ROC_summary_all['action']=='pull') & (ROC_summary_all['testCondition']=='coop(1s)')
    ROC_summary_all_tgt = ROC_summary_all[ind]

    print(ROC_summary_all_tgt.keys())

    import seaborn as sns

    # ROC_summary_all holds one row per iteration, so each (train_animal, test_animal)
    # pair appears many times and DataFrame.pivot would raise on the duplicate entries;
    # average across iterations with pivot_table instead.
    heatmap_data = ROC_summary_all_tgt.pivot_table(index='train_animal', columns='test_animal',
                                                   values='predROC', aggfunc='mean')

    # Create the heatmap
    plt.figure(figsize=(8, 6))
    sns.heatmap(heatmap_data, annot=True, cmap='viridis', vmin=0.5, vmax=1.0)
    plt.title('ROC AUC Heatmap')
    plt.xlabel('Test Animal')
    plt.ylabel('Train Animal')
    plt.tight_layout()
    plt.show()

In [None]:
# further analysis
###
# step 1: add dom/sub information
###

# Function to assign rank
def get_rank(name):
    return 'dom' if name in dom_animal_names else 'sub'

# Apply to create new columns
ROC_summary_all['train_rank'] = ROC_summary_all['train_animal'].apply(get_rank)
ROC_summary_all['test_rank'] = ROC_summary_all['test_animal'].apply(get_rank)

# make sure AUCs are numeric (the column may have object dtype depending on how the
# frame was built); NaN marks splits where the AUC was undefined
ROC_summary_all['predROC'] = ROC_summary_all['predROC'].astype(float)

###
# Step 2: Group by unique combination (excluding iteration) and get mean/std of predROC
###
group_cols = ['train_animal', 'test_animal', 'action', 'testCondition']

# Group and compute mean and std
ROC_summary_grouped = (
    ROC_summary_all
    .groupby(group_cols)
    .agg(predROC_mean=('predROC', 'mean'),
         predROC_std=('predROC', 'std'))
    .reset_index()
)

# Drop duplicates to preserve unique train/test animal rank combinations
rank_info = ROC_summary_all[['train_animal', 'test_animal', 'train_rank', 'test_rank']].drop_duplicates()

# Merge into grouped summary
ROC_summary_grouped = ROC_summary_grouped.merge(rank_info, on=['train_animal', 'test_animal'], how='left')

###
# add statistic column: two-sided one-sample t-test of the per-iteration AUCs against 0.5
###
from scipy.stats import ttest_1samp
from statsmodels.stats.multitest import multipletests

# Step 1: collect all t-tests; NaN AUCs are omitted instead of turning the p-value into NaN
p_values = []
group_keys = []
for name, group in ROC_summary_all.groupby(group_cols):
    t_stat, p = ttest_1samp(group['predROC'], 0.5, nan_policy='omit')
    p_values.append(float(p))
    group_keys.append(name)

# Step 2: correct p-values. multipletests does not handle NaN (a group with < 2 valid
# AUCs or zero variance still yields p = NaN, and the BH running minimum would propagate
# it), so only the finite p-values are corrected; the others stay NaN / not significant.
p_values = np.asarray(p_values, dtype=float)
finite = np.isfinite(p_values)
pvals_corrected = np.full(p_values.shape, np.nan)
rejected = np.zeros(p_values.shape, dtype=bool)
if finite.any():
    rejected[finite], pvals_corrected[finite], _, _ = multipletests(p_values[finite], method='fdr_bh')  # or method='bonferroni'

# Step 3: build a DataFrame
significance_df = pd.DataFrame(group_keys, columns=group_cols)
significance_df['p_value'] = p_values
significance_df['p_value_corrected'] = pvals_corrected
significance_df['significant_vs_0.5'] = rejected

# Step 4: merge into ROC_summary_grouped
ROC_summary_grouped = ROC_summary_grouped.merge(significance_df, on=group_cols, how='left')



In [None]:
###
# Step 3: plot - heatmap plot
###
if 1:
    # Build one axis order shared by rows and columns: animals grouped by dyad,
    # with the 'dom' animal before the 'sub' animal inside each dyad.

    # Extract unique animals with their dyad and rank info
    train_info = ROC_summary_all[['train_animal', 'train_dyadID', 'train_rank']].drop_duplicates()
    test_info = ROC_summary_all[['test_animal', 'test_dyadID', 'test_rank']].drop_duplicates()

    # Rename columns to unify
    train_info = train_info.rename(columns={'train_animal':'animal', 'train_dyadID':'dyadID', 'train_rank':'rank'})
    test_info = test_info.rename(columns={'test_animal':'animal', 'test_dyadID':'dyadID', 'test_rank':'rank'})

    # Combine and drop duplicates (some animals appear in both)
    animal_info = pd.concat([train_info, test_info]).drop_duplicates().reset_index(drop=True)

    # Sort by dyadID ascending, then rank (put 'dom' before 'sub')
    animal_info['rank_order'] = animal_info['rank'].map({'dom': 0, 'sub': 1})
    animal_info = animal_info.sort_values(by=['dyadID', 'rank_order'])

    # Create the ordered list
    ordered_animals = animal_info['animal'].tolist()

    # choose the target entries; take an explicit copy because later cells add
    # columns to this subset (avoids pandas' SettingWithCopyWarning on a slice)
    ind = (ROC_summary_grouped['action']=='pull') & (ROC_summary_grouped['testCondition']=='coop(1s)')
    ROC_summary_all_tgt = ROC_summary_grouped[ind].copy()

    print(ROC_summary_all_tgt.keys())

    import seaborn as sns

    # Pivot to train_animal rows x test_animal columns (one grouped row per pair)
    heatmap_data = ROC_summary_all_tgt.pivot(index='train_animal', columns='test_animal', values='predROC_mean')

    heatmap_data = heatmap_data.reindex(index=ordered_animals, columns=ordered_animals)

    # Create the heatmap
    plt.figure(figsize=(8, 6))
    sns.heatmap(heatmap_data, annot=True, cmap='viridis', vmin=0.5, vmax=1.0)
    plt.title('ROC AUC Heatmap')
    plt.xlabel('Test Animal')
    plt.ylabel('Train Animal')
    plt.tight_layout()
    plt.show()

In [None]:
###
# Step 3: plot - violin/bar plot
###
if 1:
    import seaborn as sns
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    from scipy.stats import ttest_ind  # was used below but never imported
    from statsmodels.stats.multicomp import pairwise_tukeyhsd

    # work on an explicit copy so the column assignments below never target a slice
    ROC_summary_all_tgt = ROC_summary_all_tgt.copy()

    # --- Step 1: Define group column ---
    ROC_summary_all_tgt['train_test_rank'] = (
        'train_' + ROC_summary_all_tgt['train_rank'] + '_test_' + ROC_summary_all_tgt['test_rank']
    )

    group_order = ['train_dom_test_dom', 'train_dom_test_sub', 'train_sub_test_dom', 'train_sub_test_sub']

    # statistics only use rows whose mean AUC is defined (NaN would poison the tests)
    stats_data = ROC_summary_all_tgt.dropna(subset=['predROC_mean'])

    # --- Step 2: Run Tukey HSD post hoc test ---
    tukey = pairwise_tukeyhsd(
        endog=stats_data['predROC_mean'],
        groups=stats_data['train_test_rank'],
        alpha=0.05
    )

    # Convert to DataFrame
    tukey_results = pd.DataFrame(data=tukey.summary().data[1:], columns=tukey.summary().data[0])

    # Keep only significant results (p < 0.05)
    sig_results = tukey_results[tukey_results['p-adj'].astype(float) < 0.05]

    # --- Step 3: Create violin plot ---
    plt.figure(figsize=(8, 6))
    sns.violinplot(data=ROC_summary_all_tgt, x='train_test_rank', y='predROC_mean',
                   order=group_order, palette='muted')

    sns.stripplot(data=ROC_summary_all_tgt, x='train_test_rank', y='predROC_mean',
                  order=group_order, color='black', size=2, alpha=0.5, jitter=True)

    plt.axhline(0.5, color='gray', linestyle='--', linewidth=1)
    plt.ylabel('Mean ROC AUC')
    plt.xlabel('Train-Test Rank Group')
    plt.title('Distribution of ROC AUC by Train/Test Rank Combination')

    # --- Step 4: Annotate significant pairs (brackets stacked above the data) ---
    group_pos = {name: i for i, name in enumerate(group_order)}
    ymax = ROC_summary_all_tgt['predROC_mean'].max()
    line_offset = 0.01
    line_height = 0.02
    current_offset = 0

    for _, row in sig_results.iterrows():
        g1, g2 = row['group1'], row['group2']
        p_val = float(row['p-adj'])

        if g1 not in group_pos or g2 not in group_pos:
            continue

        x1, x2 = group_pos[g1], group_pos[g2]
        y = ymax + line_offset + current_offset
        plt.plot([x1, x1, x2, x2], [y, y + line_height, y + line_height, y], color='black')
        plt.text((x1 + x2) / 2, y + line_height + 0.005, f'p = {p_val:.3f}',
                 ha='center', va='bottom', fontsize=11)
        current_offset += 0.04  # Stack annotations

    plt.tight_layout()
    plt.show()

    # === Plot 2: 2-group violin with t-test ===
    def classify_same_vs_cross(row):
        return 'same_rank' if row['train_rank'] == row['test_rank'] else 'cross_rank'

    ROC_summary_all_tgt['rank_pair_type'] = ROC_summary_all_tgt.apply(classify_same_vs_cross, axis=1)

    plt.figure(figsize=(6, 6))
    sns.violinplot(data=ROC_summary_all_tgt, x='rank_pair_type', y='predROC_mean', palette='pastel')
    sns.stripplot(data=ROC_summary_all_tgt, x='rank_pair_type', y='predROC_mean',
                  color='black', size=2, alpha=0.5, jitter=True)

    # Stats: two-sample t-test, same-rank vs cross-rank (NaN means dropped)
    same_vals = ROC_summary_all_tgt.loc[ROC_summary_all_tgt['rank_pair_type'] == 'same_rank', 'predROC_mean'].dropna()
    cross_vals = ROC_summary_all_tgt.loc[ROC_summary_all_tgt['rank_pair_type'] == 'cross_rank', 'predROC_mean'].dropna()
    t_stat, p_ttest = ttest_ind(same_vals, cross_vals)

    # Annotate the t-test result only when significant (threshold was mistakenly 0.5)
    if p_ttest < 0.05:
        stars = '***' if p_ttest < 0.001 else '**' if p_ttest < 0.01 else '*'
        ymax = ROC_summary_all_tgt['predROC_mean'].max()
        plt.text(0.5, ymax + 0.01, f'p = {p_ttest:.3e} {stars}', ha='center', fontsize=12)

    plt.axhline(0.5, color='gray', linestyle='--', linewidth=1)
    plt.ylabel('Mean ROC AUC')
    plt.xlabel('Train-Test Rank Type')
    plt.title('ROC AUC: Same vs Cross Rank')
    plt.tight_layout()
    plt.show()

    

In [None]:
np.sum(ROC_summary_all_tgt['test_rank']=='sub')

In [None]:
np.sum(ROC_summary_all_tgt['train_test_rank']=='train_sub_test_sub')

In [None]:
np.sum(ROC_summary_grouped['significant_vs_0.5'])

### load the DBN-related data for each dyad and run the prediction (cross-dyad validation: the model trained on one dyad is tested on every dyad)
### the training and testing sets come from the same task condition
### use only the binary (0/1) DAG learned by the DBN (edges whose coop-vs-self modulation index is significantly positive); edge weights are ignored
### ignore self dependencies (a variable depending on itself, i.e. the diagonal of the DAG)

In [None]:
redoFitting = 0  # 1: ignore the saved ROC summary and recompute it

niters = 100  # number of random train/test splits per (train dyad, test dyad) pair

# PLOT multiple pairs in one plot, so the data need to be loaded separately
moreSampSize = 0
mergetempRos = 0  # 1: merge different time bins

# dyads: animal1_fixedorders[i] is paired with animal2_fixedorders[i]
animal1_fixedorders = ['eddie', 'dodson', 'ginger', 'dannon', 'koala']
animal2_fixedorders = ['sparkle', 'scorch', 'kanga', 'kanga', 'vermelho']
# animal1_fixedorders = ['eddie',]
# animal2_fixedorders = ['sparkle',]
nanimalpairs = len(animal1_fixedorders)

# dominant animal names; used by the downstream analysis to label each animal 'dom' or 'sub'
# NOTE(review): 'koala' is an animal1 here, while the prediction code comments assume
# animal1 is the subordinate animal of every dyad — confirm which list is correct
# dom_animal_names = ['sparkle','scorch','kanga_withG','kanga_withD','vermelho']
dom_animal_names = ['sparkle', 'scorch', 'kanga_withG', 'kanga_withD', 'koala']

temp_resolu = 1

# ONLY FOR PLOT!!
# define DBN related summarizing variables
if do_trainedMCs:
    DBN_group_typenames = ['coop(1s)']
    DBN_group_typeIDs = [3]
    DBN_group_coopthres = [1]
else:
    DBN_group_typenames = ['self', 'coop(2s)', 'coop(1.5s)', 'coop(1s)', 'no-vision']
    DBN_group_typeIDs = [1, 3, 3, 3, 5]
    DBN_group_coopthres = [0, 2, 1.5, 1, 0]
nDBN_groups = len(DBN_group_typenames)

# DBN model: nodes to predict and the matching event names (same order)
toNodes = ['pull1_t3', 'pull2_t3', 'owgaze1_t3', 'owgaze2_t3']
eventnames = ["M1pull", "M2pull", "M1gaze", "M2gaze"]
nevents = len(eventnames)

# Time-lag selection. Rows 0-3 / 4-7 / 8-11 of the learned DAG correspond to the
# lags t_-3 / t_-2 / t_-1. Alternative settings:
#   '1and2secondlag': time_lags = ['t_-2','t_-1'], fromRowIDs = [[4,5,6,7], [8,9,10,11]]
#   '1secondlag':     time_lags = ['t_-1'],        fromRowIDs = [[8,9,10,11]]
#   '2secondlag':     time_lags = ['t_-2'],        fromRowIDs = [[4,5,6,7]]
#   '3secondlag':     time_lags = ['t_-3'],        fromRowIDs = [[0,1,2,3]]
timelagtype = 'allthreelags'
time_lags = ['t_-3', 't_-2', 't_-1']
fromRowIDs = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
nlags = len(fromRowIDs)

# evidence ("from") node names: the t2 slice unless a single earlier lag is selected
fromNodes = {
    '2secondlag': ['pull1_t1', 'pull2_t1', 'owgaze1_t1', 'owgaze2_t1'],
    '3secondlag': ['pull1_t0', 'pull2_t0', 'owgaze1_t0', 'owgaze2_t0'],
}.get(timelagtype, ['pull1_t2', 'pull2_t2', 'owgaze1_t2', 'owgaze2_t2'])


# load ROC_summary_all data: reuse the saved summary unless redoFitting is set or the
# saved file cannot be read; otherwise recompute it with cross-dyad DBN predictions.


def _nooverlap_name(animal1, animal2):
    """Return a unique label for animal2 of a dyad.

    kanga is paired with both ginger and dannon, so her label gets the partner's
    initial ('kanga_withG' / 'kanga_withD'); every other name is returned unchanged.
    """
    if animal2 == 'kanga':
        if animal1 == 'ginger':
            return 'kanga_withG'
        if animal1 == 'dannon':
            return 'kanga_withD'
    return animal2


def _swap_animal_columns(df):
    """Return a copy of df with the animal-1 and animal-2 columns exchanged.

    Column names containing pull1/pull2/owgaze1/owgaze2 are renamed to their
    counterpart (first match wins, in that order).
    """
    swap_pairs = [('pull1', 'pull2'), ('pull2', 'pull1'),
                  ('owgaze1', 'owgaze2'), ('owgaze2', 'owgaze1')]
    mapping = {}
    for col in df.columns:
        for old, new in swap_pairs:
            if old in col:
                mapping[col] = col.replace(old, new)
                break
    return df.rename(columns=mapping)


def _predict_auc(infer, data, var, evidence_nodes):
    """Return the ROC AUC of predicting data[var] from evidence_nodes, row by row.

    For each row, P(var | evidence) is queried and entry [1] of the returned factor
    is used as the predicted probability (assumes binary states ordered 0, 1 — TODO
    confirm). Returns NaN when the AUC is undefined, e.g. when data[var] contains
    only one class.
    """
    probs = np.array([
        infer.query(variables=[var],
                    evidence={node: row[node] for node in evidence_nodes}).values[1]
        for _, row in data.iterrows()
    ])
    try:
        return roc_auc_score(data[var].values, probs)
    except ValueError:
        return np.nan


ROC_summary_columns = ['train_animal', 'test_animal', 'action', 'testCondition', 'predROC',
                       'train_dyadID', 'test_dyadID', 'iters']
summary_subfolder = data_saved_folder+'data_saved_singlecam_wholebodylabels_combinesessions_basicEvents_DBNpredictions_cross_dyad_validation/'+\
                    savefile_sufix+'/'+cameraID+'/'
summary_file = summary_subfolder+'/ROC_summary_all_dependencies_DBNdependenciesAfterMI_binary_noself_'+timelagtype+'.pkl'

ROC_summary_all = None
if not redoFitting:
    print('load all ROC data for within task condition (only binary dependencies without self dependencies), and only plot the summary figure')
    try:
        with open(summary_file, 'rb') as f:
            ROC_summary_all = pickle.load(f)
    except Exception as err:  # best effort cache: any unreadable/missing file means recompute
        print(f'could not load {summary_file} ({err!r}); recomputing')
        ROC_summary_all = None

if ROC_summary_all is None:
    # One dict per (group, train dyad, test dyad, iteration, swap, event). Rows are
    # collected in a list because DataFrame.append was removed in pandas 2.0.
    roc_rows = []

    for DBN_group_typename in DBN_group_typenames:

        # load the dyad for training
        for ianimalpair_train in range(nanimalpairs):

            # load the DBN input data
            animal1_train = animal1_fixedorders[ianimalpair_train]
            animal2_train = animal2_fixedorders[ianimalpair_train]
            animal2_train_nooverlap = _nooverlap_name(animal1_train, animal2_train)

            data_saved_subfolder = data_saved_folder+'data_saved_singlecam_wholebody'+savefile_sufix+'_3lags/'+cameraID+'/'+animal1_train+animal2_train+'/'
            if not mergetempRos:
                with open(data_saved_subfolder+'/DBN_input_data_alltypes_'+animal1_train+animal2_train+'_'+str(temp_resolu)+'sReSo.pkl', 'rb') as f:
                    DBN_input_data_alltypes = pickle.load(f)
            else:
                with open(data_saved_subfolder+'/DBN_input_data_alltypes_'+animal1_train+animal2_train+'_mergeTempsReSo.pkl', 'rb') as f:
                    DBN_input_data_alltypes = pickle.load(f)

            # load the DBN training outcome
            moresamp_suffix = '_moreSampSize' if moreSampSize else ''
            with open(data_saved_subfolder+'/weighted_graphs_diffTempRo_diffSampSize_'+animal1_train+animal2_train+moresamp_suffix+'.pkl', 'rb') as f:
                weighted_graphs_diffTempRo_diffSampSize = pickle.load(f)
            with open(data_saved_subfolder+'/weighted_graphs_shuffled_diffTempRo_diffSampSize_'+animal1_train+animal2_train+moresamp_suffix+'.pkl', 'rb') as f:
                weighted_graphs_shuffled_diffTempRo_diffSampSize = pickle.load(f)
            with open(data_saved_subfolder+'/sig_edges_diffTempRo_diffSampSize_'+animal1_train+animal2_train+moresamp_suffix+'.pkl', 'rb') as f:
                sig_edges_diffTempRo_diffSampSize = pickle.load(f)

            # make sure these variables are the same as in the previous steps
            temp_resolus = [1]  # temporal resolution in the DBN model, eg: 0.5 means 500ms
            if moreSampSize:
                samplingsizes = [1100]  # different data (down/re)sampling numbers
                samplingsizes_name = list(map(str, samplingsizes))
            else:
                samplingsizes_name = ['min_row_number']
            temp_resolu = temp_resolus[0]
            j_sampsize_name = samplingsizes_name[0]
            graph_key = (str(temp_resolu), j_sampsize_name)

            DBN_input_data_train = DBN_input_data_alltypes[DBN_group_typename]

            # significant edges of the target condition and of the self-reward baseline
            weighted_graphs_tgt = weighted_graphs_diffTempRo_diffSampSize[graph_key][DBN_group_typename]
            weighted_graphs_shuffled_tgt = weighted_graphs_shuffled_diffTempRo_diffSampSize[graph_key][DBN_group_typename]
            sig_edges_tgt = get_significant_edges(weighted_graphs_tgt, weighted_graphs_shuffled_tgt)
            weighted_graphs_self = weighted_graphs_diffTempRo_diffSampSize[graph_key]['self']
            weighted_graphs_shuffled_self = weighted_graphs_shuffled_diffTempRo_diffSampSize[graph_key]['self']
            sig_edges_self = get_significant_edges(weighted_graphs_self, weighted_graphs_shuffled_self)

            # modulation index (target vs self); keep only edges whose MI differs from 0
            # (one-sample t-test, p < 0.01) and whose mean MI is positive (enhanced)
            MI_coop_self_all, sig_edges_coop_self = Modulation_Index(weighted_graphs_self, weighted_graphs_tgt,
                                                                     sig_edges_self, sig_edges_tgt, 150)
            nfromNodes = np.shape(MI_coop_self_all)[1]
            ntoNodes = np.shape(MI_coop_self_all)[2]
            sig_edges_MI = np.zeros(np.shape(sig_edges_coop_self))
            for ifromNode in range(nfromNodes):
                for itoNode in range(ntoNodes):
                    _, pp = st.ttest_1samp(MI_coop_self_all[:, ifromNode, itoNode], 0)
                    if (pp < 0.01) and (np.nanmean(MI_coop_self_all[:, ifromNode, itoNode]) > 0):
                        sig_edges_MI[ifromNode, itoNode] = 1
            bina_graphs_all_lags = sig_edges_MI*sig_edges_coop_self

            # collapse the lag blocks selected by fromRowIDs onto one from-node block
            bina_graphs_mean_tgt = sum(bina_graphs_all_lags[rowIDs, :] for rowIDs in fromRowIDs)

            # remove the self dependencies, then translate the binary DAG to edges
            np.fill_diagonal(bina_graphs_mean_tgt, 0)
            edgenames = [(fromNodes[irow], toNodes[icol])
                         for irow, icol in zip(*np.nonzero(bina_graphs_mean_tgt > 0))]

            # define the DBN predicting model
            bn = BayesianNetwork()
            bn.add_nodes_from(fromNodes)
            bn.add_nodes_from(toNodes)
            bn.add_edges_from(edgenames)

            effect_slice = toNodes

            # load the dyad for testing
            for ianimalpair_test in range(nanimalpairs):

                animal1_test = animal1_fixedorders[ianimalpair_test]
                animal2_test = animal2_fixedorders[ianimalpair_test]
                animal2_test_nooverlap = _nooverlap_name(animal1_test, animal2_test)

                data_saved_subfolder = data_saved_folder+'data_saved_singlecam_wholebody'+savefile_sufix+'_3lags/'+cameraID+'/'+animal1_test+animal2_test+'/'
                if not mergetempRos:
                    with open(data_saved_subfolder+'/DBN_input_data_alltypes_'+animal1_test+animal2_test+'_'+str(temp_resolu)+'sReSo.pkl', 'rb') as f:
                        DBN_input_data_alltypes = pickle.load(f)
                else:
                    with open(data_saved_subfolder+'/DBN_input_data_alltypes_'+animal1_test+animal2_test+'_mergeTempsReSo.pkl', 'rb') as f:
                        DBN_input_data_alltypes = pickle.load(f)
                DBN_input_data_test = DBN_input_data_alltypes[DBN_group_typename]

                # run niters iterations for each condition
                for iiter in range(niters):

                    # independent random splits: 80% of the training dyad's data is used
                    # to fit, 50% of the testing dyad's data is used to evaluate
                    train_data, _ = train_test_split(DBN_input_data_train, test_size=0.2)
                    _, test_data = train_test_split(DBN_input_data_test, test_size=0.5)

                    # parameter learning and inference on the fixed structure
                    bn.fit(train_data, estimator=MaximumLikelihoodEstimator)
                    infer = VariableElimination(bn)

                    # pass 1: test data as aligned (animal1: sub, animal2: dom)
                    # pass 2: animal1/animal2 columns swapped, so each animal's model is
                    #         evaluated on the other animal of the testing dyad
                    for swapped in (False, True):
                        test_set = _swap_animal_columns(test_data) if swapped else test_data

                        for ievent in range(nevents):
                            var = effect_slice[ievent]
                            auc = _predict_auc(infer, test_set, var, fromNodes)
                            print(f"AUC Score: {auc:.4f}")

                            # events 0 and 2 (M1pull, M1gaze) belong to animal1
                            is_animal1_event = ievent in (0, 2)
                            train_animal = animal1_train if is_animal1_event else animal2_train_nooverlap
                            # swapping the columns flips which testing animal `var` describes
                            test_is_animal1 = is_animal1_event != swapped
                            test_animal = animal1_test if test_is_animal1 else animal2_test_nooverlap

                            roc_rows.append({'train_animal': train_animal,
                                             'test_animal': test_animal,
                                             'train_dyadID': ianimalpair_train,
                                             'test_dyadID': ianimalpair_test,
                                             'action': eventnames[ievent][2:],
                                             'testCondition': DBN_group_typename,
                                             'predROC': auc,
                                             'iters': iiter,
                                             })

    ROC_summary_all = pd.DataFrame(roc_rows, columns=ROC_summary_columns)

    # save the summarizing data ROC_summary_all
    savedata = 0
    if savedata:
        os.makedirs(summary_subfolder, exist_ok=True)
        with open(summary_file, 'wb') as f:
            pickle.dump(ROC_summary_all, f)
        



In [None]:
if 0:
    # Disabled quick look: ROC AUC heatmap of coop(1s) pull predictions from the ungrouped data.
    ind = (ROC_summary_all['action']=='pull') & (ROC_summary_all['testCondition']=='coop(1s)')
    ROC_summary_all_tgt = ROC_summary_all[ind]

    print(ROC_summary_all_tgt.keys())

    import seaborn as sns

    # ROC_summary_all has one row per iteration (and per swap pass), so every
    # train/test pair repeats and pivot() would raise on the duplicate entries.
    # Average the repeats with pivot_table instead; rows are train, columns are test.
    heatmap_data = (ROC_summary_all_tgt
                    .astype({'predROC': float})
                    .pivot_table(index='train_animal', columns='test_animal',
                                 values='predROC', aggfunc='mean'))

    # Create the heatmap
    plt.figure(figsize=(8, 6))
    sns.heatmap(heatmap_data, annot=True, cmap='viridis', vmin=0.5, vmax=1.0)
    plt.title('ROC AUC Heatmap')
    plt.xlabel('Test Animal')
    plt.ylabel('Train Animal')
    plt.tight_layout()
    plt.show()

In [None]:
# further analysis
###
# step 1: add dom/sub information
###

# Function to assign rank
def get_rank(name, dom_names=None):
    """Return 'dom' if `name` is listed as a dominant animal, otherwise 'sub'.

    dom_names: collection of dominant animal names. When omitted, the
    module-level `dom_animal_names` is used (looked up at call time).
    """
    if dom_names is None:
        dom_names = dom_animal_names
    return 'dom' if name in dom_names else 'sub'

# Label each row with the rank ('dom'/'sub') of its training and its testing animal
ROC_summary_all['train_rank'] = ROC_summary_all['train_animal'].map(get_rank)
ROC_summary_all['test_rank'] = ROC_summary_all['test_animal'].map(get_rank)

###
# Step 2: Group by unique combination (excluding iteration) and get mean/std of predROC
###
# one row per (train_animal, test_animal, action, testCondition), pooled over iterations
ROC_summary_grouped = ROC_summary_all.groupby(
    ['train_animal', 'test_animal', 'action', 'testCondition']
).agg(
    predROC_mean=('predROC', 'mean'),
    predROC_std=('predROC', 'std'),
).reset_index()

# ranks depend only on the two animals, so one row per train/test pair is enough
rank_info = ROC_summary_all.loc[:, ['train_animal', 'test_animal', 'train_rank', 'test_rank']].drop_duplicates()

# left-join the rank labels onto the grouped summary
ROC_summary_grouped = pd.merge(ROC_summary_grouped, rank_info,
                               how='left', on=['train_animal', 'test_animal'])

###
# add statistic column: per (train_animal, test_animal, action, testCondition)
# group, one-sample t-test of predROC against chance (0.5), FDR corrected
###
from scipy.stats import ttest_1samp
from statsmodels.stats.multitest import multipletests

# Step 1: collect all t-tests
# predROC is NaN for iterations whose test split held a single class (AUC undefined).
# Drop those per group, as the grouped mean/std above do; a group with fewer than
# two finite AUCs gets p = NaN.
p_values = []
group_keys = []
for name, group in ROC_summary_all.groupby(['train_animal', 'test_animal', 'action', 'testCondition']):
    aucs = group['predROC'].astype(float).dropna()
    if len(aucs) >= 2:
        _, p = ttest_1samp(aucs, 0.5)
    else:
        p = np.nan
    p_values.append(p)
    group_keys.append(name)
p_values = np.asarray(p_values, dtype=float)

# Step 2: correct p-values
# multipletests is not NaN-aware (NaNs count as tests and can propagate into the
# corrected values), so correct only the finite p-values. Undefined tests stay
# NaN and are marked not significant.
rejected = np.zeros(len(p_values), dtype=bool)
pvals_corrected = np.full(len(p_values), np.nan)
finite = np.isfinite(p_values)
if finite.any():
    rejected[finite], pvals_corrected[finite], _, _ = multipletests(p_values[finite], method='fdr_bh')  # or method='bonferroni'

# Step 3: build a DataFrame
significance_df = pd.DataFrame(group_keys, columns=['train_animal', 'test_animal', 'action', 'testCondition'])
significance_df['p_value'] = p_values
significance_df['p_value_corrected'] = pvals_corrected
significance_df['significant_vs_0.5'] = rejected

# Step 4: merge into ROC_summary_grouped
ROC_summary_grouped = ROC_summary_grouped.merge(significance_df,
                        on=['train_animal', 'test_animal', 'action', 'testCondition'], how='left')



In [None]:
###
# Step 3: plot - heatmap plot
###
if 1:
    # Gather (animal, dyadID, rank) for every animal seen on the training side or
    # on the testing side of ROC_summary_all, under shared column names.
    train_info = (ROC_summary_all[['train_animal', 'train_dyadID', 'train_rank']]
                  .drop_duplicates()
                  .rename(columns={'train_animal': 'animal', 'train_dyadID': 'dyadID', 'train_rank': 'rank'}))
    test_info = (ROC_summary_all[['test_animal', 'test_dyadID', 'test_rank']]
                 .drop_duplicates()
                 .rename(columns={'test_animal': 'animal', 'test_dyadID': 'dyadID', 'test_rank': 'rank'}))
    animal_info = pd.concat([train_info, test_info]).drop_duplicates().reset_index(drop=True)

    # Axis order: ascending dyadID; within a dyad the 'dom' animal precedes the 'sub' one
    animal_info['rank_order'] = animal_info['rank'].map({'dom': 0, 'sub': 1})
    animal_info = animal_info.sort_values(by=['dyadID', 'rank_order'])
    ordered_animals = list(animal_info['animal'])

    # Target entries: pull predictions in the coop(1s) condition
    ind = (ROC_summary_grouped['action'] == 'pull') & (ROC_summary_grouped['testCondition'] == 'coop(1s)')
    ROC_summary_all_tgt = ROC_summary_grouped[ind]

    print(ROC_summary_all_tgt.keys())

    import seaborn as sns

    # rows = training animal, columns = testing animal, both in ordered_animals order
    heatmap_data = (ROC_summary_all_tgt
                    .pivot(index='train_animal', columns='test_animal', values='predROC_mean')
                    .reindex(index=ordered_animals, columns=ordered_animals))

    # Draw the heatmap
    plt.figure(figsize=(8, 6))
    sns.heatmap(heatmap_data, annot=True, cmap='viridis', vmin=0.5, vmax=1.0)
    plt.title('ROC AUC Heatmap')
    plt.xlabel('Test Animal')
    plt.ylabel('Train Animal')
    plt.tight_layout()
    plt.show()

In [None]:
###
# Step 3: plot - violin/bar plot
###
if 1:
    import seaborn as sns
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    from statsmodels.stats.multicomp import pairwise_tukeyhsd
    from scipy.stats import f_oneway, ttest_ind

    # significance level for the Tukey HSD test and for the t-test annotation
    sig_alpha = 0.05

    # --- Step 1: Define group column ---
    # Drop pairs in which the model is tested on the animal it was trained on.
    # .copy() gives an independent frame, so the column assignments below do not
    # write into a slice of the previous frame (pandas SettingWithCopyWarning).
    ind_same = ROC_summary_all_tgt['train_animal'] == ROC_summary_all_tgt['test_animal']
    ROC_summary_all_tgt = ROC_summary_all_tgt[~ind_same].copy()

    ROC_summary_all_tgt['train_test_rank'] = (
        'train_' + ROC_summary_all_tgt['train_rank'] + '_test_' + ROC_summary_all_tgt['test_rank']
    )

    group_order = ['train_dom_test_dom', 'train_dom_test_sub', 'train_sub_test_dom', 'train_sub_test_sub']

    # --- Step 2: Run Tukey HSD post hoc test ---
    tukey = pairwise_tukeyhsd(
        endog=ROC_summary_all_tgt['predROC_mean'],
        groups=ROC_summary_all_tgt['train_test_rank'],
        alpha=sig_alpha
    )

    # Convert the summary table to a DataFrame (its first row holds the column names)
    tukey_table = tukey.summary().data
    tukey_results = pd.DataFrame(data=tukey_table[1:], columns=tukey_table[0])

    # Keep only significant pairs
    sig_results = tukey_results[tukey_results['p-adj'].astype(float) < sig_alpha]

    # --- Step 3: Create violin plot ---
    fig1 = plt.figure(figsize=(8, 6))
    sns.violinplot(data=ROC_summary_all_tgt, x='train_test_rank', y='predROC_mean',
                   order=group_order, palette='muted')

    sns.stripplot(data=ROC_summary_all_tgt, x='train_test_rank', y='predROC_mean',
                  order=group_order, color='black', size=2, alpha=0.5, jitter=True)

    plt.axhline(0.5, color='gray', linestyle='--', linewidth=1)
    plt.ylabel('Mean ROC AUC')
    plt.xlabel('Train-Test Rank Group')
    plt.title('Distribution of ROC AUC by Train/Test Rank Combination')

    # --- Step 4: Annotate significant pairs (brackets stacked above the data) ---
    group_pos = {name: i for i, name in enumerate(group_order)}
    ymax = ROC_summary_all_tgt['predROC_mean'].max()
    line_offset = 0.01
    line_height = 0.02
    current_offset = 0

    for _, row in sig_results.iterrows():
        g1, g2 = row['group1'], row['group2']
        p_val = float(row['p-adj'])

        if g1 not in group_pos or g2 not in group_pos:
            continue

        x1, x2 = group_pos[g1], group_pos[g2]
        y = ymax + line_offset + current_offset
        plt.plot([x1, x1, x2, x2], [y, y + line_height, y + line_height, y], color='black')
        plt.text((x1 + x2) / 2, y + line_height + 0.005, f'p = {p_val:.3f}',
                 ha='center', va='bottom', fontsize=11)
        current_offset += 0.04  # Stack annotations

    plt.tight_layout()
    plt.show()

    # === Plot 2: 2-group violin with t-test ===
    def classify_same_vs_cross(row):
        """Return 'same_rank' when train and test animals share a rank, else 'cross_rank'."""
        return 'same_rank' if row['train_rank'] == row['test_rank'] else 'cross_rank'

    ROC_summary_all_tgt['rank_pair_type'] = ROC_summary_all_tgt.apply(classify_same_vs_cross, axis=1)

    fig2 = plt.figure(figsize=(6, 6))
    sns.violinplot(data=ROC_summary_all_tgt, x='rank_pair_type', y='predROC_mean', palette='pastel')
    sns.stripplot(data=ROC_summary_all_tgt, x='rank_pair_type', y='predROC_mean',
                  color='black', size=2, alpha=0.5, jitter=True)

    # Stats: two-sample t-test, same-rank vs cross-rank
    same_vals = ROC_summary_all_tgt[ROC_summary_all_tgt['rank_pair_type'] == 'same_rank']['predROC_mean']
    cross_vals = ROC_summary_all_tgt[ROC_summary_all_tgt['rank_pair_type'] == 'cross_rank']['predROC_mean']
    t_stat, p_ttest = ttest_ind(same_vals, cross_vals)

    # Annotate the t-test result only when it is significant at sig_alpha
    if p_ttest < sig_alpha:
        stars = '***' if p_ttest < 0.001 else '**' if p_ttest < 0.01 else '*'
        ymax = ROC_summary_all_tgt['predROC_mean'].max()
        plt.text(0.5, ymax + 0.01, f'p = {p_ttest:.3e} {stars}', ha='center', fontsize=12)

    plt.axhline(0.5, color='gray', linestyle='--', linewidth=1)
    plt.ylabel('Mean ROC AUC')
    plt.xlabel('Train-Test Rank Type')
    plt.title('ROC AUC: Same vs Cross Rank')
    plt.tight_layout()
    plt.show()

    savefig = 1
    if savefig:
        figsavefolder = data_saved_folder+'figs_for_3LagDBN_and_bhv_singlecam_wholebodylabels_combinesessions_basicEvents_DBNpredictions_cross_dyad_validation/'+savefile_sufix+'/'+cameraID+'/'
        os.makedirs(figsavefolder, exist_ok=True)
        fig2.savefig(figsavefolder+'withinCondition_Acrossdyads_SameOrAcrossRanks_DBNdependenciesAfterMI_binary_noself_summarizingplot.pdf')
        
     
    

In [None]:

# display ROC_summary_all_tgt as left by the most recently executed cell that (re)assigned it
ROC_summary_all_tgt