# Corrplot

In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import warnings 
import matplotlib.patches as patches
from datetime import datetime
# Resolution used for inline figures and for files written by savefig.
plt.rcParams.update({
    'figure.dpi': 200,
    'savefig.dpi': 600,
})

# Silences every warning category for the rest of the session.
warnings.filterwarnings('ignore')

In [2]:
# Name of the newest combined-results file in ../experiments/combined_base/,
# WITHOUT the ".csv" extension (it is appended below).
newest_file = ""  # insert newest file name

# The original bare `newest_file = #...` placeholder was a SyntaxError; fail
# with an actionable message instead when the name has not been filled in.
if not newest_file:
    raise ValueError(
        "Set `newest_file` to the newest file name in "
        "../experiments/combined_base/ (without the '.csv' extension)."
    )

big_table = pd.read_csv(f"../experiments/combined_base/{newest_file}.csv", index_col=0)
big_table

Unnamed: 0,dataset,measures,reranking,BPR,ItemKNN,MultiVAE,NCL
0,Lastfm,HR,-,0.772876,0.764706,0.778322,0.793028
1,Lastfm,HR,CM,0.587146,0.580610,0.523420,0.571351
2,Lastfm,MRR,-,0.491556,0.483544,0.476367,0.502564
3,Lastfm,MRR,CM,0.280363,0.270167,0.231551,0.259868
4,Lastfm,P,-,0.178159,0.172059,0.175763,0.184041
...,...,...,...,...,...,...,...
203,ML-10M,II-F_ori,CM,0.000357,0.000355,0.000356,0.000357
204,ML-10M,II-F_our,-,0.836064,0.846363,0.877612,0.827934
205,ML-10M,II-F_our,CM,0.908378,0.885336,0.898438,0.899725
206,ML-10M,AI-F_ori,-,0.000008,0.000023,0.000072,0.000013


In [3]:
# Append "_ori" to the bare "IFD_mul" / "IFD_div" measure names; the "$" anchor
# leaves names that already carry a suffix untouched.
big_table["measures"] = (
    big_table["measures"]
    .str.replace("IFD_mul$", "IFD_mul_ori", regex=True)
    .str.replace("IFD_div$", "IFD_div_ori", regex=True)
)
                                            

In [4]:
# Measures in order of first appearance; used later to re-order the reshaped table.
original_order = big_table["measures"].unique()
original_order

array(['HR', 'MRR', 'P', 'MAP', 'R', 'NDCG', 'Jain_our', 'QF_our',
       'Ent_our', 'FSat_our', 'Gini_our', 'IBO_ori', 'IBO_our', 'IWO_ori',
       'IWO_our', 'IAA_true_ori', 'IAA_our', 'IFD_div_ori', 'IFD_div_our',
       'IFD_mul_ori', 'IFD_mul_our', 'HD', 'MME_ori', 'II-F_ori',
       'II-F_our', 'AI-F_ori'], dtype=object)

In [5]:
# One row per (dataset, measure); the "reranking" index level is spread across
# the columns, and the measures are put back into their original order.
temp = (
    big_table
    .set_index(["dataset", "measures", "reranking"])
    .unstack(level=2)
    .reindex(original_order, level=1)
)

In [6]:
# Collapse the two-level column index left by unstack into a flat index of tuples.
temp.columns = temp.columns.to_flat_index()

In [7]:
# Label each column with its model name, adding "-<reranking>" unless the
# re-ranking entry is the "-" placeholder.
temp.columns = [
    model if rerank == "-" else f"{model}-{rerank}"
    for model, rerank in temp.columns
]

## Full corr

In [8]:
def clean_name(col):
    """Turn raw measure names into display labels (with mathtext suffixes).

    Parameters
    ----------
    col : pandas.Index or pandas.Series of str
        Raw measure names; anything exposing the ``.str`` accessor.

    Returns
    -------
    Same type as ``col``
        Names with a trailing "@10" removed, the "_our" suffix dropped for
        Gini/Jain/FSat/QF/Ent, "HD" tagged as "HD_ori", and the remaining
        "_div" / "_mul" / "_our" / "_ori" markers rewritten as mathtext
        subscripts; "_true" is removed.
    """
    # Remove only a literal trailing "@10". str.strip("@10") treats its argument
    # as a character set and would also eat leading/trailing '1', '0' and '@'
    # (e.g. "F1@10" -> "F").
    col = col.str.replace(r"@10$", "", regex=True)

    # Applied in order as literal (non-regex) replacements; order matters, e.g.
    # "HD" must become "HD_ori" before "_ori" is turned into a subscript.
    literal_replacements = [
        ("Gini_our", "Gini"),
        ("Jain_our", "Jain"),
        ("FSat_our", "FSat"),
        ("QF_our", "QF"),
        ("Ent_our", "Ent"),
        ("HD", "HD_ori"),
        ("_div", r"$_{\div}$"),
        ("_mul", r"$_{\times}$"),
        ("_our", "$_{our}$"),
        ("_ori", "$_{ori}$"),
        ("_true", ""),
    ]
    for old, new in literal_replacements:
        col = col.str.replace(old, new, regex=False)

    return col

In [9]:
# Regex alternation matched (case-sensitively) against the cleaned measure names;
# the plotting cell negates every column whose name contains one of these tokens.
# NOTE(review): presumably these are the lower-is-better measures — confirm.
pattern = "Gini|IAA|II|AI|IFD|MME|IWO|HD"

### Grid layout 2x2

In [10]:
from scipy import stats
from statsmodels.stats.multitest import multipletests

def significance_test(forcorr, mask, alpha=0.05, data=None):
    """Kendall-tau p-values for the unmasked cells of a table, corrected for multiple testing.

    For every (row label, column label) pair of ``forcorr`` a Kendall tau test
    is run between the two equally named columns of the source table. The
    p-values of the cells where ``mask`` is False are then corrected with
    Bonferroni, Holm and Benjamini-Hochberg (``fdr_bh``).

    Parameters
    ----------
    forcorr : pandas.DataFrame
        Table whose index and column labels define the pairs to test.
    mask : array-like of bool, same shape as ``forcorr``
        True marks a cell that is excluded from the correction.
    alpha : float, default 0.05
        Family-wise error rate / FDR level passed to ``multipletests``.
    data : pandas.DataFrame, optional
        Raw observations with one column per label of ``forcorr``. Pass this
        when ``forcorr`` is a correlation table, so the tests run on the
        underlying scores. When omitted, the columns are read from ``forcorr``
        itself, i.e. a correlation table passed alone is tested on its own
        columns, not on the scores it was computed from.

    Returns
    -------
    tuple
        ``(bonferroni, holm, fdr_bh)``, each the result of ``multipletests``
        over the unmasked, non-NaN p-values in row-major order.

    Raises
    ------
    AssertionError
        If ``data`` is given and the recomputed taus disagree with ``forcorr``
        by more than two-decimal rounding.
    """
    source = forcorr if data is None else data
    excluded = pd.DataFrame(mask)

    # Keep taus and p-values in separate float frames instead of storing
    # (tau, p) tuples inside the cells of an object-dtype frame.
    taus = pd.DataFrame(np.nan, index=forcorr.index, columns=forcorr.columns, dtype="float64")
    p_values = taus.copy()

    for i, row_label in enumerate(forcorr.index):
        for j, col_label in enumerate(forcorr.columns):
            tau, p_value = stats.kendalltau(source[row_label], source[col_label])
            taus.iloc[i, j] = tau
            if not excluded.iloc[i, j]:
                p_values.iloc[i, j] = p_value

    if data is not None:
        # The previous `assert all(<DataFrame>)` iterated over column labels and
        # could never fail. Compare values instead; the tolerance allows for
        # `forcorr` having been rounded to two decimals.
        assert np.allclose(
            taus.to_numpy(dtype="float64"),
            forcorr.to_numpy(dtype="float64"),
            atol=0.0051,
            equal_nan=True,
        ), "Kendall taus recomputed from `data` do not match `forcorr`"

    flat_p_values = p_values.to_numpy(dtype="float64").ravel()
    tested_p_values = flat_p_values[~np.isnan(flat_p_values)]

    res_bonf = multipletests(tested_p_values, alpha=alpha, method='bonferroni')
    res_holm = multipletests(tested_p_values, alpha=alpha, method='holm')
    res_fdr = multipletests(tested_p_values, alpha=alpha, method='fdr_bh')

    return res_bonf, res_holm, res_fdr

In [None]:
# 2x2 grid of rectangular Kendall-correlation heatmaps, one panel per dataset,
# sharing a single horizontal colour bar; saved as a time-stamped PDF.
list_dataset = big_table.dataset.unique()
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(20, 14))

# Axes for the shared colour bar, placed below the grid (figure coordinates).
cbar_ax = fig.add_axes([.175, -.03, .7, .025])
cbar_ax.tick_params(labelsize=10)

for i, (ax_id, data) in enumerate(zip(ax.flatten(), list_dataset)):

    # Measures with a missing value are dropped; after the transpose each
    # column is one measure.
    forcorr = temp.loc[data].dropna().T
    forcorr.columns = clean_name(forcorr.columns)

    # Flip the sign of every measure whose name matches `pattern`.
    flipped = forcorr.columns.str.contains(pattern, regex=True)
    forcorr.loc[:, flipped] = -forcorr.loc[:, flipped]

    corr_tab = forcorr.corr("kendall").round(2)

    # Keep only the rows from IBO_our onwards; all columns stay.
    corr_tab = corr_tab.loc["IBO$_{our}$":, :]

    # Columns up to "Gini" are shown in full; the remaining block only shows
    # its strict lower triangle. Builtin `bool` replaces `np.bool`, which was
    # removed in NumPy 1.24.
    unjoint_mask = np.zeros_like(corr_tab.loc[:, :"Gini"], dtype=bool)
    joint_mask = np.triu(np.ones_like(corr_tab.loc[:, "IBO$_{our}$":], dtype=bool))
    mask = np.concatenate((unjoint_mask, joint_mask), axis=1)

    # Only the first panel draws the (shared) colour bar.
    hm = sns.heatmap(corr_tab, annot=True, square=True, vmin=-1, vmax=1, cmap="coolwarm_r",
                     annot_kws={"size": 9.5}, ax=ax_id, mask=mask, cbar=i == 0,
                     cbar_ax=None if i else cbar_ax, cbar_kws={"orientation": "horizontal"})

    print(data)

    # NOTE(review): significance_test is given the rounded correlation table,
    # so its Kendall tests run on columns of `corr_tab` rather than on the raw
    # scores in `forcorr` — confirm this is intended.
    bon, _, _ = significance_test(corr_tab, mask=mask)

    for res, name in zip([bon], ["Bonferroni"]):
        sig_matrix = res[0]

        if sig_matrix.any():
            print(f"[{name}] There is some significant result(s) for {data}!")

            if name == "Bonferroni":
                indicator = ["$^*$" if is_significant else "" for is_significant in sig_matrix.tolist()]

            # Append the marker to the matching cell annotation; this relies on
            # the annotations and the tested p-values being in the same order.
            for t, ind in zip(hm.texts, indicator):
                t.set_text(t.get_text() + ind)

        else:
            print(f"[{name}] No significant result for {data}")

    ax_id.set_yticklabels(ax_id.get_yticklabels(), rotation=0)

    ax_id.set_title(f"{data}")

    ax_id.tick_params(axis='both', which='major', labelsize=14)

    args = {"clip_on": False, "size": 10}

    # Frame geometry in heatmap data coordinates (hard-coded).
    # NOTE(review): the column boundaries 6 / 11 / 24 and the lower edge assume
    # a fixed number of measures per group — confirm they match `corr_tab`.
    far_left = 0
    far_down = 9.75 + 6

    text_down = far_down + 0.5

    dict_args = dict(fill=False, linewidth=0.9, clip_on=False)

    # (left edge, right edge, label) of each column group
    column_groups = [
        (far_left, 6, 'EFF'),
        (6, 11, 'FAIR'),
        (11, 24, 'JOINT'),
    ]

    # Unfilled frame around each group, extending below the heatmap.
    for left, right, _ in column_groups:
        corners = [(left, 0), (right, 0), (right, far_down), (left, far_down)]
        ax_id.add_patch(patches.Polygon(xy=corners, **dict_args))

    # Group label centred inside the lower edge of each frame.
    for left, right, label in column_groups:
        ax_id.text(0.5*(left+right), text_down, label,
                   horizontalalignment='center',
                   verticalalignment='bottom',
                   **args)

    ax_id.set_xticklabels(ax_id.get_xticklabels(), rotation=90)
    ax_id.set_xlabel(None)
    ax_id.set_ylabel("JOINT", fontsize=10)

plt.tight_layout()
now = datetime.now()
time = now.strftime("%Y-%m-%d_%H%M%S")
plt.savefig(f'corr/temp/temp_{time}_selected_rectangular_corr_plot_grid.pdf', bbox_inches='tight')

plt.show()