In [2]:
%reload_ext autoreload
%autoreload 2
import os
import shapley_value
import pandas as pd
import numpy as np
from getting_data import read_conf
from xai_process import get_smp_shap_data
from s2search_score_pdp import pdp_based_importance

# Sample / dataset names processed by every cell below (presumably arXiv cs.* sub-categories).
# Whitespace-separated block, split into a list in the same order as written.
ds_list = """
    cslg cscv csai csit cscl
    cscr csds cssy csni csro
    csdc cssi cslo csna cscy
    csdm csir csne csse cscc
    cshc csgt cssd csdb cscg
    cspl csma csce csfl csdl
    csmm csgr cspf csar cset
    csoh csms cssc csos csgl
""".split()

In [6]:
from datetime import datetime
from xai_process import get_smp_shap_paper_count_from_log, get_smp_shap_time_from_log
pd.set_option('display.expand_frame_repr', False)

# Format of the "[...]" timestamp prefix on ranker-call log lines.
LOG_TIME_FORMAT = '%m/%d/%Y, %H:%M:%S'


def _parse_log_time(line):
    """Parse the timestamp between the leading '[' and the first ']' of `line`."""
    return datetime.strptime(line[1:line.index(']')], LOG_TIME_FORMAT)


def _read_masking_metrics(log_file_path, tag):
    """Extract the masking (SV) paper count and elapsed time from one log file.

    Args:
        log_file_path: path of a ``ranker_calls_masking`` log file.
        tag: marker identifying this experiment's lines, ``f'{exp_name}_{sample_name}'``.

    Returns:
        ``(paper_count, elapsed_seconds)``.
        ``paper_count`` sums the integers on lines containing ``'=== end'``
        that directly follow a tagged line. The integer is whatever remains
        after stripping ``'='``, ``'end'`` and spaces.
        ``elapsed_seconds`` is the span between the timestamps of the first
        and last tagged lines, or 0 when the file contains no tagged line.
    """
    with open(log_file_path) as f:
        lines = [l.strip() for l in f if l.strip()]

    paper_count = 0
    first_tagged = last_tagged = None
    for i, curr_line in enumerate(lines):
        if i > 0 and tag in lines[i - 1] and '=== end' in curr_line:
            paper_count += int(curr_line.replace('=', '').replace('end', '').replace(' ', ''))
        if tag in curr_line:
            if first_tagged is None:
                first_tagged = curr_line
            last_tagged = curr_line

    if first_tagged is None:
        # No tagged line: report zero time, the same default used when no masking log exists.
        return paper_count, 0
    elapsed_sec = (_parse_log_time(last_tagged) - _parse_log_time(first_tagged)).total_seconds()
    return paper_count, elapsed_sec


log_metrics_arr = []

for sample_name in ds_list:
    exp_name = f'exp-{sample_name}'
    exp_path = os.path.join('.', f'pipelining/{exp_name}')
    log_path = os.path.join(exp_path, 'log')

    # SHAP paper count and time come from the xai_process log helpers.
    shap_time_sec = get_smp_shap_time_from_log(exp_name, [sample_name])[sample_name]
    shap_paper_count = get_smp_shap_paper_count_from_log(exp_name, [sample_name])[sample_name]

    # Masking (SV) metrics come from the last ranker_calls_masking log. os.listdir order is
    # arbitrary, so sort before picking the last one.
    # NOTE(review): assumes the file names sort chronologically — confirm.
    masking_calls_log_files = sorted(
        lf for lf in os.listdir(log_path) if 'ranker_calls_masking' in lf
    )
    masking_paper_count, masking_time_sec = 0, 0
    if masking_calls_log_files:
        masking_paper_count, masking_time_sec = _read_masking_metrics(
            os.path.join(log_path, masking_calls_log_files[-1]),
            f'{exp_name}_{sample_name}',
        )

    log_metrics_arr.append([
        sample_name,
        masking_paper_count, round(masking_time_sec),
        shap_paper_count, round(shap_time_sec),
    ])

ms = ['sv paper', 'sv time', 'shap paper', 'shap time']

log_metrics_pd = pd.DataFrame(
    columns=['dataset', *ms], data=log_metrics_arr, index=range(1, len(log_metrics_arr) + 1)
)
log_metrics_pd.to_csv('shap_sv_time.csv')

log_metrics_pd

Unnamed: 0,dataset,sv paper,sv time,shap paper,shap time
1,cslg,5948032,2849,6226847,2858
2,cscv,4129984,2052,4323578,2074
3,csai,2518016,1146,2636049,1197
4,csit,2299904,1139,2407713,1176
5,cscl,1968000,772,2060251,920
6,cscr,1239936,456,1298059,549
7,csds,1217088,429,1274140,545
8,cssy,1139968,607,1193405,584
9,csni,1051008,467,1100275,506
10,csro,1029632,366,1077897,462


In [22]:
# Accumulators filled by the per-dataset loop further down in this cell.
xai_metrics_arr = []  # NOTE(review): not written to anywhere in this cell
data_len_of = []      # number of SV rows per dataset, in ds_list order

# Feature-importance dicts keyed as mp[method][dataset][feature].
mp = dict(sv={}, shap={})

def aggregation(arr):
    """Aggregate one feature's attribution values with ``pdp_based_importance``.

    Not called by the loop in this cell, which uses ``aggregation2`` instead.
    """
    importance = pdp_based_importance(arr)
    return importance

def aggregation2(arr):
    """Return the mean absolute value of ``arr`` (mean |attribution| importance)."""
    magnitudes = np.abs(arr)
    return np.mean(magnitudes)

sv_df_data = []
shap_df_data = []

# Feature names for the positional importance lists produced below, in column order.
FEATURE_KEYS = ('title', 'abstract', 'venue', 'authors', 'year', 'cita')


def _to_feature_dict(importances):
    """Map a positional list of per-feature importances onto FEATURE_KEYS.

    Raises:
        ValueError: if the number of importances differs from len(FEATURE_KEYS),
            which would otherwise silently misalign names and values.
    """
    if len(importances) != len(FEATURE_KEYS):
        raise ValueError(
            f'expected {len(FEATURE_KEYS)} feature importances, got {len(importances)}'
        )
    return dict(zip(FEATURE_KEYS, importances))


for sample_name in ds_list:
    exp_name = f'exp-{sample_name}'

    # Shapley values from the masking approach; one column per feature.
    sv = shapley_value.compute_shapley_value(exp_name, sample_name)
    data_len = sv.shape[0]
    data_len_of.append(data_len)

    sv_fi = [aggregation2(sv[col]) for col in sv.columns]
    mp['sv'][sample_name] = _to_feature_dict(sv_fi)
    sv_df_data.append([sample_name, *sv_fi])

    # SHAP values: swap the first two axes so each row holds one feature's values.
    # This is exactly what np.flipud(np.rot90(x)) computes; presumably the input is
    # (samples, features) — confirm against get_smp_shap_data.
    shap_sv = np.array(get_smp_shap_data(exp_name)[sample_name]['shap_sv']).swapaxes(0, 1)
    shap_fi = [aggregation2(feature_sv) for feature_sv in shap_sv]
    mp['shap'][sample_name] = _to_feature_dict(shap_fi)
    shap_df_data.append([sample_name, *shap_fi])
    

In [24]:
# Tabulate the per-dataset mean |SV| importances and persist them.
SV_COLUMNS = ['dataset', 'title', 'abstract', 'venue', 'authors', 'year', 'n_citations']
sv_pd = pd.DataFrame(sv_df_data, columns=SV_COLUMNS)
sv_pd.to_csv('sv.csv')

sv_pd

Unnamed: 0,dataset,title,abstract,venue,authors,year,n_citations
0,cslg,3.194204,7.441923,0.186365,0.090231,1.407588,0.16609
1,cscv,0.238091,2.572076,3.164207,0.101029,1.348643,0.218324
2,csai,2.318256,5.449311,0.091314,0.092025,1.26285,0.173482
3,csit,0.745772,3.831369,4.642609,0.114755,0.793315,0.256647
4,cscl,1.515217,6.904959,0.138453,0.100864,1.299759,0.205863
5,cscr,1.688565,7.540762,1.067041,0.093934,1.096295,0.192501
6,csds,2.051777,8.100917,0.37524,0.093795,0.856648,0.206577
7,cssy,3.678797,7.450833,2.562602,0.088735,1.260261,0.139943
8,csni,3.298869,6.737578,0.639195,0.096662,0.85909,0.183971
9,csro,0.408608,2.221882,5.071916,0.101068,1.368311,0.165135


In [25]:
# Tabulate the per-dataset mean |SHAP| importances and persist them.
# Fix: the rows live in shap_df_data (built by the importance loop); the
# previous `data=shap` referenced an undefined name.
shap_pd = pd.DataFrame(
    columns=['dataset', 'title', 'abstract', 'venue', 'authors', 'year', 'n_citations'],
    data=shap_df_data,
)

shap_pd.to_csv('shap.csv')

shap_pd

Unnamed: 0,dataset,title,abstract,venue,authors,year,n_citations
0,cslg,2.511312,6.143623,0.154023,0.077296,1.20664,0.148921
1,cscv,0.188319,2.163166,2.776231,0.080743,1.09987,0.166331
2,csai,1.818835,4.496852,0.074846,0.079123,1.082479,0.156463
3,csit,0.566912,3.160956,3.958713,0.096927,0.655974,0.189663
4,cscl,1.184304,5.796074,0.112086,0.088482,1.114594,0.183957
5,cscr,1.303017,6.293317,0.864929,0.096472,0.929483,0.15934
6,csds,1.63177,6.808143,0.294983,0.095651,0.726389,0.173355
7,cssy,2.836428,6.037388,2.078557,0.089216,1.049239,0.115
8,csni,2.628898,5.567021,0.518906,0.084557,0.73876,0.164374
9,csro,0.311762,1.835492,4.295107,0.101237,1.163559,0.136861
