In [None]:
import os
import re
import copy
import json
import random
import itertools
import numpy as np
from tqdm import tqdm
import pandas as pd


import scipy.stats as statsc
import matplotlib.pyplot as plt
import seaborn as sns
from pingouin import partial_corr

import warnings; warnings.filterwarnings("ignore")

import networkx as nx
from networkx.drawing.nx_agraph import graphviz_layout, pygraphviz_layout
from networkx.drawing.nx_pydot import pydot_layout
from networkx.algorithms.community import louvain_communities

from tigramite import data_processing as pp
from tigramite.jpcmciplus import JPCMCIplus
from tigramite.independence_tests.gpdc import GPDC

from utils import get_dicts

In [None]:
# Raw Q&A outputs; the path is relative to the notebook's working directory.
out_df = pd.read_csv('./data/qna_output.csv')
# Lookup tables from the project-local `utils` module. As used below: the keys of
# `label_dict` name the columns unpacked from 'sae_preds', `act_max_dict` supplies the
# scaling denominators, `act_labels` lists the activation columns, and `abbv_dict`
# maps long names to the abbreviations used as column / intervention names.
label_dict, act_max_dict, act_labels, abbv_dict = get_dicts()

def df2obs(out_df, itv_type, itv_thought, sample_id, step, act_labels, act_max_dict, label_dict=label_dict, abbv_dict=abbv_dict, step_agg='max'):
    """Convert the raw Q&A output table into scaled activations and intervention indicators.

    Parameters
    ----------
    out_df : pd.DataFrame
        Raw table. The code reads the columns 'output_text', 'query', 'sae_preds',
        'itv_type', 'itv_thought', 'step' and 'sample_id'.
    itv_type, itv_thought, sample_id, step : iterable
        Values to keep on the index level of the same name after aggregation.
    act_labels : list of str
        Activation columns (abbreviated names) to keep in the output.
    act_max_dict : dict
        Its values are the denominators used to scale the aggregated columns.
    label_dict : dict
        Its keys name the columns created from the JSON-encoded 'sae_preds' lists.
    abbv_dict : dict
        Maps long names to abbreviations, for 'itv_thought' values and column names alike.
    step_agg : {'max', 'mean', 'sum'}
        Reduction applied within each (itv_type, itv_thought, step, sample_id) group.

    Returns
    -------
    obs : pd.DataFrame
        Scaled activations indexed by (itv_type, itv_thought, step, sample_id).
    itv : pd.DataFrame
        One '<label>_itv' indicator column per intervention that is active in at least one row.
    obs_itv : pd.DataFrame
        `obs` and `itv` side by side, indexed by (itv_type, sample_id, itv_thought, step).

    Raises
    ------
    ValueError
        If `step_agg` is not one of the supported reductions. Previously an unknown
        value skipped the aggregation and failed later with an unrelated KeyError.
    """
    supported_aggs = ('max', 'mean', 'sum')
    if step_agg not in supported_aggs:
        raise ValueError(f"step_agg must be one of {supported_aggs}, got {step_agg!r}")

    out_df = out_df.drop(['output_text', 'query'], axis=1)

    # convert column 'sae_preds' (JSON-encoded lists) into one column per label
    new_cols = list(label_dict.keys())
    sea_preds = np.array(out_df['sae_preds'].apply(json.loads).tolist())
    out_df = pd.concat([out_df, pd.DataFrame(sea_preds, columns=new_cols)], axis=1)
    out_df = out_df.drop(['sae_preds'], axis=1)
    # interventions without an abbreviation (including missing ones) become 'none'
    out_df['itv_thought'] = out_df['itv_thought'].map(abbv_dict).fillna('none')
    out_df = out_df.rename(columns=abbv_dict)

    # aggregate for each sample and step
    grouped = out_df.groupby(['itv_type', 'itv_thought', 'step', 'sample_id'])
    obs = getattr(grouped, step_agg)()

    # filter for specific itv_type, itv_thought, sample_id, and step
    obs = obs[obs.index.get_level_values('itv_type').isin(itv_type)]
    obs = obs[obs.index.get_level_values('itv_thought').isin(itv_thought)]
    obs = obs[obs.index.get_level_values('step').isin(step)]
    obs = obs[obs.index.get_level_values('sample_id').isin(sample_id)]

    # scale the values
    # NOTE(review): the division is positional, so it presumably relies on the aggregated
    # columns being in the same order as `act_max_dict` -- confirm against get_dicts().
    obs = (obs / act_max_dict.values())[act_labels]

    # convert column 'itv_thought' into one indicator column per intervention
    obs_flat = obs.reset_index()
    itv = pd.DataFrame(columns=act_labels + ['none'], data=np.zeros((len(obs), len(act_labels)+1)))
    itv_vals = obs_flat['itv_thought']
    itv_types = obs_flat['itv_type']
    for i, (thought, phase) in enumerate(zip(itv_vals, itv_types)):
        # the indicator is switched off for 'phase_4' rows and on for every other phase
        itv.loc[i, thought] = 0 if phase == 'phase_4' else 1

    itv = itv.reset_index(drop=True)
    # keep only interventions that are active somewhere in the selected rows
    itv = itv.loc[:, (itv != 0).any(axis=0)].fillna(0)
    itv.columns = itv.columns + '_itv'
    if 'none_itv' in itv.columns:
        itv = itv.drop('none_itv', axis=1)

    obs_itv = pd.concat([obs_flat, itv], axis=1)
    obs_itv = obs_itv.set_index(['itv_type', 'sample_id', 'itv_thought', 'step'])

    return obs, itv, obs_itv

# Build the observation tables used by every later cell: phases 3 and 4, all
# interventions (plus the 'none' condition), sample ids 1-10 and steps 1-100,
# reduced with a per-group maximum.
obs, itvn, obs_itv = df2obs(
    out_df,
    itv_type=['phase_3', 'phase_4'], 
    itv_thought=act_labels + ['none'], 
    sample_id=list(range(1, 11)),
    step=range(1, 101), 
    act_labels=act_labels,
    act_max_dict=act_max_dict,
    abbv_dict=abbv_dict,
    step_agg='max'
)

# Infer Causal Structure

In [None]:
"""
SETUP FOR CAUSAL NETWORK DISCOVERY
"""

from sklearn.gaussian_process.kernels import RBF, WhiteKernel as W

# --- variables -------------------------------------------------------------
# lag window searched by the discovery algorithm
tau_min = 0
tau_max = 1
# optional dummy / context nodes (all switched off here)
time_dummy = False
space_dummy = False
time_context = False
# endogenous (system) variables and their exogenous intervention indicators
sys_vars = act_labels
cxt_vars = [col + '_itv' for col in sys_vars]
if time_context:
    cxt_vars = ['phase'] + cxt_vars
full_vars = sys_vars + cxt_vars

# --- hyperparameters -------------------------------------------------------
alpha = 0.05
bootstrap_sample_size = 200
itv_per_bootstrap_sample = 2
# a link must appear in at least 75% of the bootstrap graphs to be kept
link_removal_threshold = int(bootstrap_sample_size * 0.75)
kernel_length = 2
noise_level = 0.1

# --- conditional-independence test ----------------------------------------
# GPDC with a fixed white-noise + RBF kernel
ci_test = GPDC(
    significance='analytic',
    gp_params={
        'alpha': 0.0,
        'kernel': W(noise_level, 'fixed') + RBF(kernel_length, 'fixed'),
    },
)

# --- data ------------------------------------------------------------------
# phase-3 rows of the real interventions for sample ids 1-10, restricted to
# the system and context columns (in that order)
_obs_itv = copy.deepcopy(obs_itv)
_obs_itv = _obs_itv[
    _obs_itv.index.get_level_values('itv_type').isin(['phase_3'])
    & _obs_itv.index.get_level_values('itv_thought').isin(sys_vars)
    & _obs_itv.index.get_level_values('sample_id').isin(range(1, 11))
]
_obs_itv = _obs_itv[sys_vars + cxt_vars]


In [None]:
"""
RUN CAUSAL DISCOVERY W/ BOOTSTRAPPED SAMPLES
"""

def get_link_assumptions(jpcmciplus, tau_min, tau_max, num_system_vars, assumptions=[1,2]):
    """Build the link-assumption dictionary handed to `run_jpcmciplus`.

    Starts from the default assumptions produced by the JPCMCIplus object (context
    nodes treated as exogenous, then cleaned) and prunes lag-0 entries:

    * assumption 1 -- drop every lag-0 entry between two different system nodes;
    * assumption 2 -- for each system node keep, and orient as '-->', only the lag-0
      entry of the space-context node whose index equals the system node's index
      plus `num_system_vars`; drop the lag-0 entries of all other space-context nodes;
    * always -- drop lag-0 entries between two different space-context nodes.

    Returns the modified dictionary keyed by node index.
    """
    context_nodes = jpcmciplus.time_context_nodes + jpcmciplus.space_context_nodes
    links = jpcmciplus._set_link_assumptions(None, tau_min, tau_max, remove_contemp=False)
    links = jpcmciplus.assume_exogenous_context(links, context_nodes)
    links = jpcmciplus.clean_link_assumptions(links, tau_max)

    use_assumption_1 = 1 in assumptions
    use_assumption_2 = 2 in assumptions

    for sys_node in jpcmciplus.system_nodes:
        if use_assumption_1:
            # no contemporaneous links among the endogenous variables
            for other_sys in jpcmciplus.system_nodes:
                if other_sys == sys_node:
                    continue
                links[sys_node].pop((other_sys, 0), None)

        if use_assumption_2:
            for ctx_node in jpcmciplus.space_context_nodes:
                lag0_key = (ctx_node, 0)
                if sys_node == (ctx_node - num_system_vars):
                    # the matching intervention acts on its own variable at lag 0
                    links[sys_node][lag0_key] = '-->'
                else:
                    # non-matching interventions have no lag-0 effect
                    links[sys_node].pop(lag0_key, None)

    # exogenous nodes are not linked to each other at lag 0
    for ctx_node in jpcmciplus.space_context_nodes:
        for other_ctx in jpcmciplus.space_context_nodes:
            if other_ctx == ctx_node:
                continue
            links[ctx_node].pop((other_ctx, 0), None)

    return links

def causal_inference(_obs_itv, ci_test=ci_test, alpha=alpha, tau_max=tau_max, tau_min=tau_min, time_dummy=time_dummy, space_dummy=space_dummy, time_context=time_context):
    """Run JPCMCI+ on a collection of time series stored in one data frame.

    Parameters
    ----------
    _obs_itv : pd.DataFrame
        Must carry the index levels 'step' and 'domain'. Columns whose name ends in
        '_itv' are treated as exogenous (space-context) nodes, every other column as
        a system node. Rows are split into one dataset per integer domain id
        0 .. num_domains-1.
    ci_test, alpha, tau_max, tau_min, time_dummy, space_dummy, time_context
        Default to the notebook-level settings that exist when this cell is run
        (the defaults are bound at definition time).

    Returns
    -------
    results : output of `JPCMCIplus.run_jpcmciplus`
    jpcmciplus : the fitted JPCMCIplus object
    dataframe : the tigramite DataFrame that was analysed
    var_names : list of node names (columns plus optional dummy names)
    """
    # basic stats
    num_time = _obs_itv.index.get_level_values('step').nunique() 
    num_system_vars = len([col for col in _obs_itv.columns if not re.search(r'_itv$', col)]) # num_system_vars corresponds to the number of endogenous variables (i.e., the dysfunctional representational states)
    num_context_vars = len([col for col in _obs_itv.columns if re.search(r'_itv$', col)]) # num_context_vars corresponds to the number of exogenous variables (i.e., the intervention variables)
    num_domains = (_obs_itv.index.get_level_values('domain').nunique()) # num_domain corresponds to the number of distinct time series datasets
    var_names = list(_obs_itv.columns)

    # with a time-context column ('phase') one non-'_itv' column is re-counted as context
    if time_context: num_context_vars += 1; num_system_vars -= 1
    if time_dummy:  var_names += ['t_dummy']
    if space_dummy: var_names += ['s_dummy']


    # prepare data for jpcmci by placing each dataset into a dictionary
    # NOTE(review): domain ids are looked up as 0..num_domains-1, so ids with gaps would
    # leave some datasets out -- confirm the caller numbers domains consecutively.
    data_dict = {}
    dummy_data_time = np.identity(num_time)
    for i in range(num_domains):
        df = _obs_itv[_obs_itv.index.get_level_values('domain') == i]
        dummy_data_space = np.zeros((num_time, num_domains))
        dummy_data_space[:, i] = 1

        if len(df) > 0: 
            data = df.to_numpy()
            # the dummy blocks have num_time rows, so stacking requires len(df) == num_time
            if time_dummy:  data = np.hstack((data, dummy_data_time))
            if space_dummy: data = np.hstack((data, dummy_data_space))
            data_dict.update({i: data})  
    

    # specify node types: system (endogenous), time_context (time-lagged exogenous), space_context (non-lagged exogenous)
    node_classification = dict(zip(
        _obs_itv.columns,
        ["space_context" if "_itv" in col else "system" for col in _obs_itv.columns],
    ))
    if time_context: node_classification['phase'] = "time_context"

    # node indices are laid out as [system | time context | space context]
    observed_indices_time = [i for i, col in enumerate(_obs_itv.columns) if node_classification[col] == "time_context"]
    t_context_nodes = list(range(
        num_system_vars, 
        num_system_vars + len(observed_indices_time)
    ))

    observed_indices_space = [i for i, col in enumerate(_obs_itv.columns) if node_classification[col] == "space_context"]
    s_context_nodes = list(range(
        num_system_vars + len(observed_indices_time), 
        num_system_vars + len(observed_indices_time) + len(observed_indices_space)
    ))

    system_indices = [i for i, col in enumerate(_obs_itv.columns) if node_classification[col] == "system"]
    observed_indices = system_indices + observed_indices_time + observed_indices_space

    # NOTE(review): keys are grouped by node type while the values stay in column order,
    # so the pairing is only right when the columns are already ordered
    # system -> time context -> space context -- confirm for other column layouts.
    node_classification_jpcmci = dict(zip(observed_indices, node_classification.values()))
    vector_vars = {i: [(i, 0)] for i in system_indices + t_context_nodes + s_context_nodes}

    # dummy nodes are appended after the last observed node
    new_idx = (num_system_vars + num_context_vars - 1)
    if time_dummy:  
        new_idx += 1
        t_dummy = list(range(new_idx, new_idx + num_time))
        node_classification_jpcmci.update({new_idx : "time_dummy"})
        vector_vars[new_idx] = [(i, 0) for i in t_dummy]
        
    if space_dummy: 
        new_idx += 1
        # NOTE(review): `data` is whatever the last non-empty domain left behind in the loop
        # above, and the range spans num_context_vars columns although the space dummy
        # block is num_domains wide -- confirm before enabling space_dummy.
        s_dummy = list(range((data).shape[1] - num_context_vars, (data).shape[1]))
        node_classification_jpcmci.update({new_idx : "space_dummy"})
        vector_vars[new_idx] = [(i, 0) for i in s_dummy]


    # specify the data_types: 0 is continuous; 1 discrete data
    # NOTE(review): no columns are added here for the time dummy -- confirm the shape
    # matches the data when time_dummy is enabled.
    data_type1 =  np.zeros((num_domains, num_time, num_system_vars), dtype='int')
    data_type2 =  np.ones((num_domains, num_time, num_context_vars), dtype='int')
    # if time_context: data_type2[:, :, 0] = 0
    if space_dummy: data_type2 = np.concatenate([data_type2, np.ones((num_domains, num_time, num_domains), dtype='int')], axis=2)
    data_type = np.concatenate([data_type1, data_type2], axis=2)


    # run jpcmciplus
    dataframe = pp.DataFrame(
        data=data_dict,
        analysis_mode='multiple',
        var_names=var_names,
        data_type=data_type,
        vector_vars=vector_vars,
    )
    jpcmciplus = JPCMCIplus(
        dataframe=dataframe, 
        cond_ind_test=ci_test,
        node_classification=node_classification_jpcmci,
        verbosity=0,
    )
    # link assumption 1 = no lag-0 links between endogenous -> endogenous; 
    # link assumption 2 = no lag-0 links between non-corresponding exogenous -> endogenous; 
    results = jpcmciplus.run_jpcmciplus(
        tau_min=tau_min, 
        tau_max=tau_max, 
        pc_alpha=alpha, 
        reset_lagged_links=True,
        link_assumptions=get_link_assumptions(jpcmciplus, tau_min, tau_max, num_system_vars, assumptions=[1,2]), 
    )

    return results, jpcmciplus, dataframe, var_names

def dag2adj(dag, num_vars):
    """Convert a tigramite-style string graph into binary adjacency matrices.

    Parameters
    ----------
    dag : array-like of str, shape (>= num_vars, >= num_vars, >= 2)
        `dag[i, j, lag]` is '-->' for a link i -> j, '<--' for a link j -> i and
        '' (or anything else) for no link. Only lags 0 and 1 are read.
    num_vars : int
        Size of the leading square block to convert.

    Returns
    -------
    adj_cat, adj_lag0, adj_lag1 : np.ndarray of float, shape (num_vars, num_vars)
        `adj_lagK[i, j] == 1` means i -> j at lag K; `adj_cat` is their union.

    Raises
    ------
    IndexError
        If `dag` is smaller than `num_vars` along either node axis.

    Notes
    -----
    The entry for a link is set whenever *either* endpoint records it. The earlier
    loop-based version also wrote 0 for '' entries, which could erase a link that the
    opposite endpoint had already recorded, depending on iteration order.
    """
    graph = np.asarray(dag)
    if graph.shape[0] < num_vars or graph.shape[1] < num_vars:
        raise IndexError(f"dag has shape {graph.shape}, cannot extract {num_vars} variables")
    graph = graph[:num_vars, :num_vars]

    def _lag_adjacency(lag):
        layer = graph[:, :, lag]
        forward = layer == '-->'          # i -> j stored at [i, j]
        backward = (layer == '<--').T     # j -> i stored at [i, j], i.e. an edge at [j, i]
        return (forward | backward).astype(float)

    adj_lag0 = _lag_adjacency(0)
    adj_lag1 = _lag_adjacency(1)

    # union of both lags, still binary
    adj_cat = np.minimum(adj_lag0 + adj_lag1, 1.0)

    return adj_cat, adj_lag0, adj_lag1

def bootstrap_causal_inference(obs_itv, num_samples, k, seed=None):
    """Repeat causal discovery on randomly re-drawn collections of time series.

    For every bootstrap round and every (itv_thought, itv_type) combination present
    in `obs_itv`, `k` distinct sample ids are drawn without replacement; the rows of
    each drawn sample form one dataset ("domain") passed to `causal_inference`.

    Parameters
    ----------
    obs_itv : pd.DataFrame
        Indexed by (itv_type, sample_id, itv_thought, step).
    num_samples : int
        Number of bootstrap rounds.
    k : int
        Sample ids drawn per (itv_thought, itv_type) combination.
    seed : int, optional
        Seeds a private random generator so a run can be reproduced. With the default
        `None` the module-level `random` state is used, as before.

    Returns
    -------
    models : list
        The fitted JPCMCIplus object of every round.
    dags : list
        The DAG extracted from each round's CPDAG.

    Notes
    -----
    Domain ids are assigned consecutively to non-empty datasets only. The previous
    `i * k + j` numbering left gaps whenever a drawn combination had no rows, while
    `causal_inference` looks domains up as 0..nunique-1.
    """
    rng = random.Random(seed) if seed is not None else random

    # index levels are extracted once instead of once per draw
    thought_of_row = obs_itv.index.get_level_values('itv_thought')
    type_of_row = obs_itv.index.get_level_values('itv_type')
    sample_of_row = obs_itv.index.get_level_values('sample_id')

    sample_ids = list(sample_of_row.unique())
    itv_ids = thought_of_row.unique()
    type_ids = type_of_row.unique()

    # the row mask of every (intervention, phase) pair is identical across rounds
    pair_masks = [
        (thought_of_row == itv) & (type_of_row == typ)
        for itv, typ in itertools.product(itv_ids, type_ids)
    ]

    dags, models = [], []

    for _ in tqdm(range(num_samples)):

        series_list = []
        for pair_mask in pair_masks:
            for s_id in rng.sample(sample_ids, k=k):
                series = obs_itv[pair_mask & (sample_of_row == s_id)]
                if len(series) == 0:
                    continue

                # add level for domain; a domain corresponds to a distinct time series dataset
                series = series.reset_index()
                series['domain'] = len(series_list)
                series = series.set_index(['itv_type', 'sample_id', 'itv_thought', 'step', 'domain'])
                series_list.append(series)

        resampled = pd.concat(series_list)

        results, jpcmciplus, dataframe, var_names = causal_inference(resampled)
        dag = jpcmciplus._get_dag_from_cpdag(cpdag_graph=results['graph'], variable_order=range(len(var_names)))

        models.append(jpcmciplus); dags.append(dag)

    return models, dags

# Runs `bootstrap_sample_size` independent discovery rounds, drawing
# `itv_per_bootstrap_sample` sample ids per intervention in each round.
jpcmciplus_list, dags = bootstrap_causal_inference(_obs_itv, num_samples=bootstrap_sample_size, k=itv_per_bootstrap_sample)


In [None]:
"""
GET THE FINAL CAUSAL STRUCTURE BY REMOVING LESS FREQUENT LINKS
"""

def get_adj_matrix(dag, num_vars):
    """Translate a string-coded graph into 0/1 adjacency matrices for lag 0 and lag 1.

    `dag[i, j, lag]` equal to '-->' marks the link i -> j, '<--' marks j -> i; any
    other entry contributes nothing. Only the leading `num_vars` x `num_vars` block
    and the lags 0 and 1 are read. (Same contract as `dag2adj` in the previous cell.)

    Returns
    -------
    adj, adj0, adj1 : np.ndarray of float, shape (num_vars, num_vars)
        `adj0` / `adj1` hold the lag-0 / lag-1 links with rows as sources and columns
        as targets; `adj` is 1 wherever either of them is 1.

    Raises
    ------
    IndexError
        If `dag` has fewer than `num_vars` nodes along either node axis.

    Notes
    -----
    A link is kept when either endpoint records it. The former element-wise loop
    also assigned 0 for '' entries, so a link written through a '<--' entry could be
    erased again when the mirrored, empty entry was visited later.
    """
    block = np.asarray(dag)
    if min(block.shape[0], block.shape[1]) < num_vars:
        raise IndexError(f"dag has shape {block.shape}, cannot extract {num_vars} variables")
    block = block[:num_vars, :num_vars]

    per_lag = []
    for lag in (0, 1):
        symbols = block[:, :, lag]
        # '<--' at [i, j] is the edge j -> i, hence the transpose
        linked = (symbols == '-->') | (symbols == '<--').T
        per_lag.append(linked.astype(float))
    adj0, adj1 = per_lag

    adj = np.minimum(adj0 + adj1, 1.0)

    return adj, adj0, adj1

def bootstrap_outcomes(dags, link_removal_threshold, num_full_vars):
    """Keep only the links that occur in enough bootstrap graphs.

    Each DAG is converted with `get_adj_matrix`; the lag-0 and lag-1 matrices are
    summed over all DAGs separately, and a link survives when its count is at least
    `link_removal_threshold`.

    Returns
    -------
    adj_cat, adj_lag0, adj_lag1 : np.ndarray
        Binary matrices of the surviving links (union, lag 0, lag 1).
    """
    converted = [get_adj_matrix(dag, num_full_vars) for dag in dags]

    def _majority(position):
        # position 1 -> lag-0 matrix, position 2 -> lag-1 matrix of each triple
        votes = np.stack([triple[position] for triple in converted]).sum(axis=0)
        return (votes >= link_removal_threshold).astype(votes.dtype)

    adj_lag1 = _majority(2)
    adj_lag0 = _majority(1)

    adj_cat = np.minimum(adj_lag0 + adj_lag1, 1)

    return adj_cat, adj_lag0, adj_lag1

# Consensus graph over all bootstrap DAGs; the matrices cover system and context nodes.
adj_cat, adj_lag0, adj_lag1 = bootstrap_outcomes(dags, link_removal_threshold, len(full_vars))

In [None]:
"""
ANALYZE THE CAUSAL STRUCTURE
"""

def get_nx_graphs(adj, adj0, adj1, var_names):
    """Build directed graphs from the leading `len(var_names)` nodes of each adjacency matrix.

    Parameters
    ----------
    adj, adj0, adj1 : np.ndarray
        Combined, lag-0 and lag-1 adjacency matrices (rows are sources).
    var_names : sequence
        Only its length is used: it selects the leading square block of each matrix.

    Returns
    -------
    G, G0, G1 : nx.DiGraph
        Combined, lag-0 and lag-1 graphs with integer node labels. Self-loops are
        removed from `G` and `G1`; `G0` is returned as built.

    Raises
    ------
    AssertionError
        If the lag-0 graph contains a cycle.
    """
    n = len(var_names)
    G0 = nx.DiGraph(adj0[:n, :n])
    G1 = nx.DiGraph(adj1[:n, :n])
    G = nx.DiGraph(adj[:n, :n])

    # materialise the self-loops first so the graph is not modified while being iterated
    G.remove_edges_from(list(nx.selfloop_edges(G)))
    G1.remove_edges_from(list(nx.selfloop_edges(G1)))

    assert nx.is_directed_acyclic_graph(G0), "lag-0 graph is not acyclic"

    return G, G0, G1

def graph_stats(G, var_names, print_output=False, seed=None):
    """Compute per-node centrality measures and Louvain community labels.

    Parameters
    ----------
    G : nx.DiGraph
        Graph to analyse; it is not modified.
    var_names : sequence of str
        Row labels of the result. They are assigned positionally, so they must be in
        the same order as the graph's nodes.
    print_output : bool
        If true, print the table rounded to two decimals and sorted by community.
    seed : int, optional
        Forwarded to `louvain_communities` so the community labels can be reproduced;
        `None` keeps the previous unseeded behaviour.

    Returns
    -------
    pd.DataFrame
        Columns 'in-d', 'out-d', 'pagerank', 'btw', 'close' and 'comm_lv'.
    """
    community_of = {}
    for label, members in enumerate(louvain_communities(G, seed=seed)):
        for node in members:
            community_of[node] = label

    df = pd.DataFrame({
        'in-d': nx.in_degree_centrality(G),
        'out-d': nx.out_degree_centrality(G),
        'pagerank': nx.pagerank(G),
        'btw': nx.betweenness_centrality(G),
        'close': nx.closeness_centrality(G),
        'comm_lv': community_of,
    })
    df.index = var_names
    if print_output:
        # DataFrame.round replaces the deprecated element-wise applymap(round)
        print(df.round(2).sort_values('comm_lv', ascending=False))

    return df

def draw_graph(G1, var_names):
    """Draw a directed graph with a pydot layout and labelled nodes.

    Parameters
    ----------
    G1 : nx.DiGraph
        Graph with integer node labels 0..len(var_names)-1.
    var_names : sequence of str
        Text shown on the nodes, in node order.

    A single 10x10 figure is created and shown. The extra 5x5 figure that used to be
    opened first was never drawn on and only produced an empty output.
    """
    # alternatives: graphviz_layout, pygraphviz_layout, nx.circular_layout
    pos = pydot_layout(G1)

    # every edge gets the same score, i.e. the same colour from the colormap
    edge_scores1 = [0.15] * len(G1.edges)

    node_size = 2000
    plt.figure(figsize=(10, 10))
    nx.draw(G1, pos, arrowsize=0)
    nx.draw_networkx_nodes(G1, pos,
        node_size=node_size,
        node_color='seashell',
    )
    nx.draw_networkx_labels(G1, pos,
        labels=dict(zip(range(len(var_names)), var_names)),
        font_size=10,
        font_weight='bold',
    )
    nx.draw_networkx_edges(G1, pos,
        node_size=node_size,
        width=4,
        arrowsize=15,
        connectionstyle='arc3,rad=0.2',
        edge_color=edge_scores1,
        edge_cmap=plt.cm.coolwarm,
        edge_vmin=-0.25,
        edge_vmax=0.25,
        label='Edge Weights',
    )
    plt.show()
    
# Restrict the consensus graph to the system variables, then summarise and draw
# its lag-1 part.
G, G0, G1 = get_nx_graphs(adj_cat, adj_lag0, adj_lag1, sys_vars)
df_stat = graph_stats(G1, sys_vars, print_output=True)
draw_graph(G1, sys_vars)

# Eval Causal Structure

In [None]:
"""
GET UNIT TIME-LAGGED CORR. FROM LLM
"""

itv_vars = [col + '_itv' for col in act_labels]

# observations of both phases for every intervention, including the 'none' condition
_obs = copy.deepcopy(obs_itv)
_obs = _obs[_obs.index.get_level_values('itv_thought').isin(act_labels + ['none'])]
_obs = _obs[_obs.index.get_level_values('itv_type').isin(['phase_3','phase_4'])]

lag = 1
# entry [a, b]: (partial) Spearman correlation between a at step t and b at step t - lag
lagged_pcorr_s = pd.DataFrame(index=act_labels, columns=act_labels, dtype=float)
lagged_corr_s = pd.DataFrame(index=act_labels, columns=act_labels, dtype=float)

# one data frame per (sample, intervention) combination that has any rows
dfs = []
sample_ids = list(_obs.index.get_level_values('sample_id').unique())
itv_thoughts = list(_obs.index.get_level_values('itv_thought').unique())
for sample_id, itv_thought in itertools.product(sample_ids, itv_thoughts):
    _same_thought = _obs.index.get_level_values('itv_thought') == itv_thought
    _same_sample = _obs.index.get_level_values('sample_id') == sample_id
    df = _obs[_same_thought & _same_sample]
    if len(df) > 0:
        dfs.append(df)
print("Number of datasets - LLM: ", len(dfs))

for var1, var2 in itertools.product(act_labels, act_labels):

    # For an autocorrelation the lagged copy of var1 must stay in the frame, so every
    # lagged column gets a trailing underscore to keep it apart from the current var1.
    # Otherwise the lagged var1 column is dropped.
    is_auto = (var1 == var2)
    lagged_name = var2 + '_' if is_auto else var2

    _obs_lag = []
    for df in dfs:
        current = df.iloc[lag:].reset_index(drop=True)[var1]
        previous = df.iloc[:-lag].reset_index(drop=True)
        if is_auto:
            previous.columns = [col + '_' for col in previous.columns]
        else:
            previous = previous.drop([var1], axis=1)
        _obs_lag.append(pd.concat([current, previous], axis=1))

    _obs_lag = pd.concat(_obs_lag, axis=0).reset_index(drop=True).dropna()

    # condition on every remaining (lagged) column
    controls = list(_obs_lag.columns.difference([var1, lagged_name]))
    lagged_pcorr_s.loc[var1, var2] = partial_corr(_obs_lag, x=var1, y=lagged_name, covar=controls, method='spearman')['r'].values[0]
    lagged_corr_s.loc[var1, var2] = _obs_lag[[var1, lagged_name]].corr(method='spearman').iloc[0, 1]

In [None]:
"""
RELATIONSHIP BETWEEN LAGGED CORR. AND SHORTEST PATH DISTANCE
"""

# hop counts between every ordered pair of nodes in the combined graph;
# unreachable pairs stay at infinity
distances = dict(nx.all_pairs_shortest_path_length(G))
n = len(distances)
dist_matrix = np.full((n, n), np.inf)
for i in range(n):
    for j, hops in distances[i].items():
        dist_matrix[i, j] = hops

# NOTE(review): dist_matrix[a, b] is the path length from a to b, whereas
# lagged_*_s.loc[a, b] relates a at step t to b at step t - lag (see the cell that
# fills them). Flattening both row-major therefore pairs the path a -> b with the
# lagged influence of b on a -- confirm whether one side should be transposed.
df_dist = pd.DataFrame({
    'distance': dist_matrix.flatten(),
    'lagged_corr_s': lagged_corr_s.values.flatten(),
    'lagged_pcorr_s': lagged_pcorr_s.values.flatten(),
})
distance_list = df_dist['distance'].values
lagged_corr_s_list = df_dist['lagged_corr_s'].values
lagged_pcorr_s_list = df_dist['lagged_pcorr_s'].values
# drop the unreachable pairs
df_dist = df_dist[df_dist['distance'] != np.inf]


var1, var2 = 'distance', 'lagged_pcorr_s'
plt.figure(figsize=(5, 4.5))
sns.lineplot(data=df_dist, x=var1, y=var2, color='black', linewidth=5, errorbar=None, err_kws={'alpha': 0.2, 'linewidth': 0}, alpha=0.7)
sns.stripplot(data=df_dist, x=var1, y=var2, alpha=.3, legend=False, color='gray', size=8, jitter=False, dodge=False)
plt.xlabel('Shortest Path Distance')
plt.ylabel('Lag-1 Correlation')
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.show()

In [None]:
"""
RELATIONSHIP BETWEEN CENTRALITY AND NETWORK ACTIVATION
"""
_df_stat = df_stat.copy()

# mean activation over all variables while each thought is the intervened one;
# the assignment aligns on the thought name
_obs = copy.deepcopy(obs)[act_labels]
net_act = _obs.groupby(['itv_thought']).mean().mean(1)
_df_stat['net_act'] = net_act
_df_stat = _df_stat.reset_index()

x_vars = ['out-d', 'btw', 'close']
var2 = 'net_act'

x_scores = []
y_scores = []
for x_var in x_vars:
    x_scores += [_df_stat[x_var].values]
    y_scores += [_df_stat[var2].values]

# rank 1 = most central
for _metric in ('btw', 'out-d', 'close'):
    _df_stat['rank-' + _metric] = _df_stat[_metric].rank(ascending=False)

# long-format table (one block of rows per centrality measure) for the scatter plot
_df_stats = []
for i, (x_score, y_score) in enumerate(zip(x_scores, y_scores)):
    x_score = _df_stat[x_vars[i]].rank(ascending=False)
    _long_block = pd.DataFrame({'score': x_score, var2: y_score, 'index': _df_stat['index'], 'x_var': x_vars[i]})
    _df_stats.append(_long_block)
_df_stats = pd.concat(_df_stats, axis=0)


plt.figure(figsize=(4, 5))
_trend_style = {'alpha': 0.5, 'linewidth': 5}
sns.regplot(data=_df_stat, x='rank-btw', y=var2, scatter=False, color='green', line_kws=dict(_trend_style))
sns.regplot(data=_df_stat, x='rank-out-d', y=var2, scatter=False, color='red', line_kws=dict(_trend_style))
sns.regplot(data=_df_stat, x='rank-close', y=var2, scatter=False, color='blue', line_kws=dict(_trend_style))
sns.scatterplot(data=_df_stats, x='score', y=var2, hue='x_var', s=100, palette='hls', style='index', zorder=2, legend=False)

plt.xlim(0, 15)
plt.xlabel('Centrality Score (Ranked)')
plt.ylabel('Network Activation')
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.show()

# Infer SCM

In [None]:
from sklearn.linear_model import LinearRegression
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel as W, ConstantKernel as C
from sklearn.neural_network import MLPRegressor
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import make_pipeline
from sklearn.neighbors import KernelDensity
from sklearn.metrics import r2_score
from tabulate import tabulate

def fit_structural_equations(_obs_itv, adj_lag0, adj_lag1, sys_vars=sys_vars, full_vars=full_vars, model_name='gpr'):
    """Fit one regression model per system variable from its graph parents.

    Parameters
    ----------
    _obs_itv : iterable of pd.DataFrame
        One frame per time series; rows are consecutive steps.
    adj_lag0, adj_lag1 : np.ndarray
        Adjacency matrices over `full_vars`; `adj[j, i] == 1` makes `full_vars[j]` a
        parent of the i-th system variable at the given lag.
    sys_vars, full_vars : list of str
        Target variables and all node names (default: notebook-level lists).
    model_name : {'linear', 'poly', 'gpr', 'mlp'}
        Regressor family.

    Returns
    -------
    dict
        Fitted model per system variable. The training R^2 table is printed.

    Raises
    ------
    ValueError
        If `model_name` is not supported (this used to surface as an
        UnboundLocalError at fit time).
    """
    def _new_model():
        if model_name == 'linear':
            return LinearRegression(fit_intercept=False)
        if model_name == 'poly':
            return make_pipeline(PolynomialFeatures(degree=2, interaction_only=False, include_bias=False), LinearRegression(fit_intercept=False))
        if model_name == 'gpr':
            return make_pipeline(GaussianProcessRegressor(kernel=W() + RBF(), alpha=0.0))
        if model_name == 'mlp':
            return MLPRegressor(hidden_layer_sizes=(64,), max_iter=1000, learning_rate_init=0.05)
        raise ValueError(f"unknown model_name {model_name!r}; expected 'linear', 'poly', 'gpr' or 'mlp'")

    if model_name not in ('linear', 'poly', 'gpr', 'mlp'):
        _new_model()  # raises ValueError before any work is done

    # fit structural equation for each variable
    models = {}
    records = []
    for i, var in enumerate(sys_vars):

        # init feature: lag-0 parents are read at step t, lag-1 parents at step t-1
        parents_0 = [full_vars[j] for j in range(len(full_vars)) if adj_lag0[j, i] == 1]
        parents_1 = [full_vars[j] for j in range(len(full_vars)) if adj_lag1[j, i] == 1]
        y, x0, x1 = [], [], []
        for _obs in _obs_itv:
            y.extend(_obs[var].values[1:])
            x0.extend(_obs[parents_0].values[1:])
            x1.extend(_obs[parents_1].values[:-1])
        y = np.array(y); x0 = np.array(x0); x1 = np.array(x1)
        x = np.concatenate([x0, x1], axis=1)

        # init and fit model
        model = _new_model()
        model.fit(x, y)
        y_hat = model.predict(x)

        # gather results
        models[var] = model
        records.append({'var': var, 'r2': round(r2_score(y, y_hat), 2)})

    # build the table once instead of concatenating inside the loop
    perf_df = pd.DataFrame(records, columns=['var', 'r2'])

    print("\nTrain Performance:")
    print(tabulate(perf_df.T, headers='keys', tablefmt='pretty'))

    return models

def test_structural_equations(_obs_itv, adj_lag0, adj_lag1, models, sys_vars=sys_vars, full_vars=full_vars, model_name='gpr'):
    """Score previously fitted structural equations on a list of time series.

    Parameters
    ----------
    _obs_itv : iterable of pd.DataFrame
        One frame per time series; rows are consecutive steps.
    adj_lag0, adj_lag1 : np.ndarray
        Adjacency matrices over `full_vars`; `adj[j, i] == 1` makes `full_vars[j]` a
        parent of the i-th system variable at the given lag.
    models : dict
        Fitted model per system variable, as returned by `fit_structural_equations`.
    sys_vars, full_vars : list of str
        Target variables and all node names (default: notebook-level lists).
    model_name : str
        Unused; kept for signature compatibility.

    Returns
    -------
    pd.DataFrame
        Columns 'var' and 'r2' (rounded to two decimals), one row per system
        variable. The table is also printed.
    """
    records = []

    # test structural equation for each variable
    for i, var in enumerate(sys_vars):

        # init feature: lag-0 parents are read at step t, lag-1 parents at step t-1
        parents_0 = [full_vars[j] for j in range(len(full_vars)) if adj_lag0[j, i] == 1]
        parents_1 = [full_vars[j] for j in range(len(full_vars)) if adj_lag1[j, i] == 1]
        y, x0, x1 = [], [], []
        for _obs in _obs_itv:
            y.extend(_obs[var].values[1:])
            x0.extend(_obs[parents_0].values[1:])
            x1.extend(_obs[parents_1].values[:-1])
        y = np.array(y); x0 = np.array(x0); x1 = np.array(x1)
        x = np.concatenate([x0, x1], axis=1)

        # predict with model
        y_hat = models[var].predict(x)

        # gather results
        records.append({'var': var, 'r2': round(r2_score(y, y_hat), 2)})

    # build the table once instead of concatenating inside the loop
    perf_df = pd.DataFrame(records, columns=['var', 'r2'])

    print("\nTest Performance")
    print(tabulate(perf_df.T, headers='keys', tablefmt='pretty'))

    return perf_df

def visualize_activation(_obs):
    """Plot activation trajectories, one panel per intervention (plus 'none').

    `_obs` must carry the index levels 'itv_thought', 'sample_id' and 'step'; its
    columns are the activation variables. The notebook-level `act_labels` and
    `abbv_dict` are read directly. Each panel shows the mean activation of the
    non-intervened columns (black), the mean activation of the 'none' condition
    (orange-red), and sub-sampled per-node scores as markers.

    The grid is fixed at 3 x 5 panels, so at most 15 interventions fit.
    """

    n_cols, n_rows = 5,3
    fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(22, 10))

    for i, itv in enumerate(act_labels + ['none']):
        
        # gather activation scores for this intervention; panels with fewer than 20 rows stay empty
        node_list, step_list, score_list = [], [], []
        itv_node_idx = (_obs.index.get_level_values('itv_thought') == itv)
        _obs_itv = _obs[itv_node_idx]
        if len(_obs_itv) < 20: continue
        
        # gather mean activation scores over the columns whose name does not contain `itv`
        # (substring match), then average per (sample_id, step)
        none_itv_cols = [c for c in _obs.columns if itv not in c]
        obs_mean = _obs_itv[none_itv_cols].mean(axis=1)
        obs_mean = pd.DataFrame(obs_mean, columns=['score'])
        obs_mean = obs_mean.groupby(['sample_id', 'step']).mean().reset_index()
        
        # gather mean activation scores when no intervened thought is present
        # (does not depend on `itv`; recomputed for every panel)
        none_node_idx = (_obs.index.get_level_values('itv_thought') == 'none')
        obs_base = _obs[none_node_idx].mean(axis=1)
        obs_base = pd.DataFrame(obs_base, columns=['score'])
        obs_base = obs_base.groupby(['sample_id', 'step']).mean().reset_index()
        
        steps = (_obs_itv.index.get_level_values('step').unique())
        for s in (steps): 
            node_obs = _obs_itv[_obs_itv.index.get_level_values('step') == s]
            node_obs = node_obs.mean() # avg pool activation scores over all samples
            for idx in node_obs.index:
                node_list.append(idx)
                step_list.append(s)
                score_list.append(node_obs[idx])
                
        # long-format table: one row per (node, step); 'itv' flags the intervened node
        out_node = pd.DataFrame({'node': node_list, 'step': step_list, 'score': score_list, 'itv': [(n in [itv]) for n in node_list]})
        out_node['node'] = pd.Categorical(out_node['node'], act_labels)
        
        # thin out the markers: keep step 1 and every multiple of half a window
        # NOTE(review): int(window_size/2) is 0 when there are fewer than 20 distinct
        # steps, which makes the modulo below divide by zero -- confirm inputs.
        num_steps = out_node['step'].nunique()
        window_size = int(num_steps // 10)
        out_node = out_node[(out_node['step'] == 1) | (out_node['step'] % int(window_size/2) == 0)]
        
        # plot
        row = i // n_cols; col = i % n_cols
        ax = axes[row, col]
        legend = False
        sns.lineplot(data=obs_mean, x='step', y='score', ax=ax, linestyle='-', linewidth=2.5, legend=False, color='black', errorbar='sd', err_kws={'alpha': 0.15, 'linewidth': 0}, zorder=1)
        sns.lineplot(data=obs_base, x='step', y='score', ax=ax, linestyle='-', linewidth=2.5, legend=False, color='orangered', errorbar='sd', err_kws={'alpha': 0.15, 'linewidth': 0}, zorder=1)
        sns.scatterplot(data=out_node[out_node['itv']==False], x='step', y='score', ax=ax, style="node", markers=True, s=40, alpha=.7, legend=legend, style_order=list(abbv_dict.values()), zorder=0, color='gray')
        sns.scatterplot(data=out_node[out_node['itv']==True], x='step', y='score', ax=ax, style="node", markers=True, s=60, alpha=1, legend=legend, style_order=list(abbv_dict.values()), zorder=2, color='darkblue')
        
        ax.set_xlabel(''); ax.set_ylabel(''); ax.set_title(f"[{itv}-itv]", fontweight='bold', fontsize=14)
        if row == n_rows-1: ax.set_xlabel('Q&A Step', fontsize=12, fontweight='bold')
        if col == 0: ax.set_ylabel('Thought Activation', fontsize=12, fontweight='bold')
        
        ax.set_yticks(np.arange(0, 1.1, 0.2))
        ax.tick_params(axis='both', which='major', labelsize=10)

        # dotted marker at step 50 for runs longer than 50 steps
        if num_steps > 50:
            ax.axvline(x=50, color='black', linestyle='dotted', linewidth=2.5, zorder=0)
        if legend:
            ax.legend(loc='center', bbox_to_anchor=(-0.7, 0.5), ncol=2, fontsize=12)
        
        # NOTE(review): this clears the axis labels and the title that were set a few
        # lines above, so panels end up unlabeled -- confirm this is intended.
        axes[row, col].set_xlabel(''); axes[row, col].set_ylabel(''); axes[row, col].set_title("")#; axes[row, col].set_xticks([]); axes[row, col].set_yticks([])
                
    plt.subplots_adjust(wspace=0.15, hspace=0.15)
    plt.show()
    
def simulate_SCM(N, T, adj_lag0, adj_lag1, models, resid_dist_dict, act_labels=act_labels, sys_vars=sys_vars, full_vars=full_vars, model_name='gpr', plot=True):
    """Roll the fitted structural equations forward under every intervention.

    For each of `N` repetitions and each intervention in `act_labels + ['none']`, a
    trajectory of `T` steps is generated from an all-zero state at step 0. The
    intervention indicator is 1 for steps 0-50 and 0 afterwards.

    Parameters
    ----------
    N, T : int
        Number of repetitions and number of simulated steps.
    adj_lag0, adj_lag1 : np.ndarray
        Adjacency matrices over `full_vars`; `adj[j, i] == 1` makes `full_vars[j]` a
        parent of the i-th system variable at the given lag.
    models : dict
        Fitted model per system variable.
    resid_dist_dict : dict or None
        Per-variable density estimators with a `sample` method; when given, one
        sampled residual is added to every prediction.
    act_labels, sys_vars, full_vars : list of str
        Default to the notebook-level lists.
    model_name : str
        Unused; kept for signature compatibility.
    plot : bool
        Whether to draw the simulated activations. The flag used to be ignored and
        the plot was always drawn.

    Returns
    -------
    pd.DataFrame
        Simulated values of `full_vars`, indexed by (itv_thought, step, sample_id).
    """
    # parent lists depend only on the graph, so they are resolved once
    n_full = len(full_vars)
    equations = [
        (
            var,
            [full_vars[j] for j in range(n_full) if adj_lag0[j, i] == 1],
            [full_vars[j] for j in range(n_full) if adj_lag1[j, i] == 1],
        )
        for i, var in enumerate(sys_vars)
    ]

    X_list = []
    for n in range(N):
        for itv in act_labels + ['none']:

            X = pd.DataFrame(columns=full_vars, data=np.zeros((T+1, len(full_vars))))
            # for 'none' this creates a 'none_itv' column that is dropped again below
            X[itv+'_itv'] = [1 if step < 51 else 0 for step in range(T+1)]

            for step in range(1, T+1):
                # variables are updated in `sys_vars` order within a step
                for var, parents_0, parents_1 in equations:
                    x0 = X.loc[(step)][parents_0].values
                    x1 = X.loc[(step-1)][parents_1].values
                    x = np.concatenate([x0, x1], axis=0).reshape(1, -1)
                    y = models[var].predict(x)

                    if resid_dist_dict is not None:
                        resid_kde = resid_dist_dict[var]
                        y += resid_kde.sample(1).squeeze()

                    # activations are kept inside [0, 1]
                    y = np.clip(y, 0, 1)
                    X.loc[(step), var] = y

            X['step'] = range(T+1)
            X['itv_thought'] = itv
            X['sample_id'] = n
            X = X.set_index(['itv_thought', 'step', 'sample_id'])
            X_list.append(X[full_vars])

    obs = pd.concat(X_list, axis=0)
    if plot:
        visualize_activation(obs[sys_vars])

    return obs

def estimate_residuals(_obs_itv, models, bandwidth, sys_vars=sys_vars, full_vars=full_vars, plot=True, adj_lag0=None, adj_lag1=None):
    """Fit a Gaussian KDE to the one-step residuals of every structural equation.

    For each variable in ``sys_vars`` the design matrix is rebuilt from the
    series in ``_obs_itv`` (lag-0 parents at rows ``1:``, lag-1 parents at
    rows ``:-1``, target at rows ``1:``), ``models[var].predict`` is applied,
    and a ``KernelDensity`` is fitted to ``y - y_hat``.

    Parameters
    ----------
    _obs_itv : list of pandas.DataFrame
        One frame per series; it is iterated once per variable, so it must be
        re-iterable (a list, not a generator).
    models : mapping
        ``models[var].predict(x)`` must accept the stacked parent matrix.
    bandwidth : float
        Passed to ``KernelDensity(bandwidth=...)`` and to the plotted KDE.
    sys_vars, full_vars : list of str
        Modelled variables and the full variable list the adjacency rows
        refer to.
    plot : bool
        When True, draw one residual histogram per variable.
    adj_lag0, adj_lag1 : 2-D indexable or None
        ``adj[j, i] == 1`` marks ``full_vars[j]`` as a parent of
        ``sys_vars[i]``. When None, the notebook-level ``adj_lag0`` /
        ``adj_lag1`` globals are used, which is what this function always
        did before these parameters existed.

    Returns
    -------
    dict
        ``{var: fitted KernelDensity}``.
    """
    # Backward-compatible fallback: earlier versions read the adjacency
    # matrices straight from the notebook namespace.
    lag0 = globals()['adj_lag0'] if adj_lag0 is None else adj_lag0
    lag1 = globals()['adj_lag1'] if adj_lag1 is None else adj_lag1

    n_cols = 7
    if plot:
        # At least the original 2x7 grid; extra rows are added when there are
        # more than 14 variables so that axes[row, col] cannot go out of range.
        n_rows = max(2, -(-len(sys_vars) // n_cols))
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(25, 2.5 * n_rows))
        plt.subplots_adjust(wspace=0.35, hspace=0.35)

    # fit noise of structural equation for each variable
    resid_dist_dict = {}
    for i, var in enumerate(sys_vars):

        # init feature
        parents_0 = [full_vars[j] for j in range(len(full_vars)) if lag0[j, i] == 1]
        parents_1 = [full_vars[j] for j in range(len(full_vars)) if lag1[j, i] == 1]
        y, x0, x1 = [], [], []
        for _obs in _obs_itv:
            # Row t of the target pairs with lag-0 parents at t and lag-1
            # parents at t-1, hence the [1:] / [:-1] offsets.
            y.extend(_obs[var].values[1:])
            x0.extend(_obs[parents_0].values[1:])
            x1.extend(_obs[parents_1].values[:-1])
        y = np.array(y); x0 = np.array(x0); x1 = np.array(x1)
        x = np.concatenate([x0, x1], axis=1)

        # gather residuals
        y_hat = models[var].predict(x)
        residual = y - y_hat

        # fit kde probability density estimation to the residuals
        kde = KernelDensity(kernel='gaussian', bandwidth=bandwidth).fit(residual.reshape(-1, 1))
        resid_dist_dict[var] = kde

        if plot:
            ax = axes[i // n_cols, i % n_cols]
            # NOTE(review): `bw_method` is handed to seaborn's KDE, which may
            # interpret a scalar differently from sklearn's absolute
            # `bandwidth`; the overlaid curve is therefore not guaranteed to
            # match the fitted KDE -- confirm.
            sns.histplot(residual, bins=50, kde=True, color='darkblue', ax=ax, kde_kws={'bw_method': bandwidth}, stat='density')
            ax.set_title(f"Residual Distribution [{var}]")
            ax.set_xlabel(''); ax.set_ylabel('')

    if plot: plt.show()

    return resid_dist_dict

def get_list_of_data(data, act_labels):
    """Split a multi-indexed frame into one frame per (intervention, sample).

    Interventions are visited in the order ``act_labels + ['none']``; within
    each intervention, samples are visited in order of first appearance of
    their ``sample_id``. Rows whose ``itv_thought`` is not in that list are
    left out.

    Parameters
    ----------
    data : pandas.DataFrame
        Frame whose index has the levels ``'itv_thought'`` and ``'sample_id'``.
    act_labels : list
        Intervention labels to extract (``'none'`` is always appended).

    Returns
    -------
    list of pandas.DataFrame
        One sub-frame per (intervention, sample) pair that has rows.
    """
    def _rows_where(frame, level, value):
        # Boolean-mask the frame on a single index level.
        return frame[frame.index.get_level_values(level) == value]

    per_series = []
    for thought in act_labels + ['none']:
        thought_rows = _rows_where(data, 'itv_thought', thought)
        per_series.extend(
            _rows_where(thought_rows, 'sample_id', sid)
            for sid in thought_rows.index.get_level_values('sample_id').unique()
        )
    return per_series



In [None]:
"""
FIT & EVALUATE THE STRUCTURAL EQUATIONS
"""

def eval_structural_equation(obs_itv, adj_lag0, adj_lag1, model_name='poly', act_labels=act_labels, sys_vars=sys_vars, full_vars=full_vars):
    """Leave-one-sample-out evaluation of the structural equations.

    Every ``sample_id`` found in ``obs_itv`` is held out once: the equations
    are fitted on the remaining samples with ``fit_structural_equations`` and
    scored on the held-out sample with ``test_structural_equations``. Only
    rows with ``itv_type == 'phase_3'`` and ``itv_thought`` in ``act_labels``
    are used, restricted to the columns ``sys_vars + cxt_vars`` (``cxt_vars``
    is read from the notebook namespace).

    Parameters
    ----------
    obs_itv : pandas.DataFrame
        Observations whose index has the levels ``'itv_type'``,
        ``'itv_thought'`` and ``'sample_id'``.
    adj_lag0, adj_lag1
        Adjacency matrices forwarded unchanged to the fit / test helpers.
    model_name : str
        Model family forwarded to the fit / test helpers.
    act_labels : list
        Intervention labels to keep.
    sys_vars, full_vars : list of str
        Forwarded to the fit / test helpers; ``sys_vars`` also selects columns.

    Returns
    -------
    (models, perf_df)
        ``models`` are the models fitted in the LAST fold only (not a refit
        on all samples); ``perf_df`` is the concatenation of the per-fold
        performance frames.

    Raises
    ------
    ValueError
        If ``obs_itv`` contains no ``sample_id`` values.
    """
    sample_ids = obs_itv.index.get_level_values('sample_id').unique()
    if len(sample_ids) == 0:
        raise ValueError("obs_itv contains no sample_id values; nothing to evaluate")

    # The phase / thought / column filters are identical for every fold, so
    # apply them once. Boolean-mask and column-list indexing already return
    # new frames, so no deepcopy of the full dataset is needed per fold.
    index = obs_itv.index
    keep = (
        index.get_level_values('itv_type').isin(['phase_3'])
        & index.get_level_values('itv_thought').isin(act_labels)
    )
    obs_eval = obs_itv[keep][sys_vars + cxt_vars]
    eval_sample_ids = obs_eval.index.get_level_values('sample_id')

    perf_dfs = []
    for test_id in sample_ids:
        train_ids = [i for i in sample_ids if i != test_id]

        _obs_train = obs_eval[eval_sample_ids.isin(train_ids)]
        _obs_test = obs_eval[eval_sample_ids.isin([test_id])]

        _obs_train_list = get_list_of_data(_obs_train, act_labels)
        _obs_test_list = get_list_of_data(_obs_test, act_labels)

        models = fit_structural_equations(_obs_train_list, adj_lag0, adj_lag1, sys_vars=sys_vars, full_vars=full_vars, model_name=model_name)
        perf_df = test_structural_equations(_obs_test_list, adj_lag0, adj_lag1, models=models, sys_vars=sys_vars, full_vars=full_vars, model_name=model_name)

        perf_dfs.append(perf_df)

    perf_df = pd.concat(perf_dfs, axis=0)
    print("\nFinal Test Performance")
    print(tabulate(perf_df.groupby('var').mean().round(2).T, headers='keys', tablefmt='pretty'))

    return models, perf_df

# Run the leave-one-sample-out evaluation with the 'poly' model family.
# `models` receives the models returned by the call (those fitted in its
# final fold) and `perf_df` the concatenated per-fold performance table;
# both become notebook-level names.
model_name = 'poly'
models, perf_df = eval_structural_equation(obs_itv, adj_lag0, adj_lag1, model_name=model_name)

In [None]:
"""
INFER SCM AND GENERATE DATA DISTRIBUTIONS
"""

# setup parameters
# N: simulated repetitions per intervention, T: last simulated step
# (simulate_SCM produces steps 0..T). `std` is not referenced anywhere in
# this cell.
N, T, std = 1, 100, 0.0
model_name = 'poly'


# setup data
# Keep only the 'phase_3' rows and the system + context columns.
_obs_itv = copy.deepcopy(obs_itv)
_obs_itv = _obs_itv[_obs_itv.index.get_level_values('itv_type').isin(['phase_3'])]
_obs_itv = _obs_itv[sys_vars + cxt_vars]


# fit structural equations
# One frame per (intervention, sample); no sample is held out here, and
# `models` from any earlier cell is overwritten.
_obs_itv_list = get_list_of_data(_obs_itv, act_labels)
models = fit_structural_equations(_obs_itv_list, adj_lag0, adj_lag1, sys_vars=sys_vars, full_vars=full_vars, model_name=model_name)
# With resid_dist_dict = None, simulate_SCM adds no residual noise to the
# model predictions. The commented line below is the alternative that fits
# residual KDEs so that noise is sampled.
resid_dist_dict = None
# resid_dist_dict = estimate_residuals(_obs_itv_list, models, bandwidth=0.2)


# generate data distributions
# X: simulated activations indexed by (itv_thought, step, sample_id).
X = simulate_SCM(
    N, T, adj_lag0, adj_lag1, models, 
    resid_dist_dict=resid_dist_dict, 
    act_labels=act_labels, 
    sys_vars=sys_vars,
    full_vars=full_vars,
    model_name=model_name,
    plot=True
)

In [None]:
"""
MEASURE CORR. BETWEEN UNIT ACTIVATION DISTRIBUTION FROM LLM AND SCM
"""

# SCM side: keep the intervened runs only and drop step 0 (the all-zero
# initial row of the simulation), leaving steps 1..100.
_obs_scm = copy.deepcopy(X)
_obs_scm = _obs_scm[_obs_scm.index.get_level_values('itv_thought').isin(act_labels)]
_obs_scm = _obs_scm[_obs_scm.index.get_level_values('step').isin(range(1, 101))]
_obs_scm = _obs_scm.groupby(['itv_thought', 'step', 'sample_id']).mean().reset_index()

# LLM side: same interventions, phases 3 and 4, averaged over samples.
itv_vars = [col + '_itv' for col in act_labels]
_obs_llm = copy.deepcopy(obs_itv)
_obs_llm = _obs_llm[_obs_llm.index.get_level_values('itv_thought').isin(act_labels)]
_obs_llm = _obs_llm[_obs_llm.index.get_level_values('itv_type').isin(['phase_3', 'phase_4'])]
_obs_llm = _obs_llm.groupby(['itv_thought', 'itv_type','step']).mean().reset_index()

itv_thoughts = _obs_llm['itv_thought'].unique()
for itv_t in act_labels:
    corrs = []

    # The per-intervention slices do not depend on the unit, so build and
    # sort them once instead of once per sys_var.
    _llm_sample = _obs_llm[_obs_llm['itv_thought'] == itv_t].sort_values('step')
    _scm_sample = _obs_scm[_obs_scm['itv_thought'] == itv_t]
    # drop column itv_thought and sample_id, then average over repetitions
    _scm_sample = _scm_sample.drop(columns=['itv_thought', 'sample_id'])
    _scm_sample = _scm_sample.groupby('step').mean().reset_index().sort_values('step')

    # Concatenate the step-ordered trajectories of all units into one vector
    # per source; the two vectors are paired by position, not by step label.
    _llm_samples, _scm_samples = [], []
    for sys_var in sys_vars:
        _llm_samples.extend(_llm_sample[sys_var].values)
        _scm_samples.extend(_scm_sample[sys_var].values)

    if len(_llm_samples) != len(_scm_samples):
        raise ValueError(
            f"[{itv_t}] LLM and SCM trajectories differ in length "
            f"({len(_llm_samples)} vs {len(_scm_samples)}); both sources must cover the same steps"
        )

    spearman_corr = statsc.spearmanr(_llm_samples, _scm_samples)[0]
    print(f"[{itv_t}] Corr {round(np.mean(spearman_corr),2)}")
