In [None]:
import pandas as pd
import glob

import matplotlib.pyplot as plt
import seaborn as sns
import warnings 

# On-screen figures render at 200 dpi; files written by savefig use 600 dpi.
plt.rcParams.update({'figure.dpi': 200, 'savefig.dpi': 600})

# NOTE(review): this silences *every* warning for the rest of the kernel session,
# including pandas/matplotlib deprecation notices — consider narrowing it.
warnings.simplefilter('ignore')

In [None]:
# Datasets to plot, one heatmap panel each (top to bottom).
list_data = "ML-1M JobRec LFM-1B".split()
# Regex alternation: matches labels containing "-P" or " P".
pat_p = "|".join(("-P", " P"))
# Substring used to select NDCG-based metric columns.
pat_ndcg = "NDCG"

In [None]:
# Pick the lexicographically greatest "table/big_table*" file. This presumably
# selects the newest table because filenames embed a sortable timestamp — TODO confirm.
big_table_files = glob.glob("table/big_table*")
if not big_table_files:
    raise FileNotFoundError("no files match 'table/big_table*'")
cand = max(big_table_files)
cand

In [None]:
# Read the selected table with a 3-level row index and drop the middle level.
# The remaining index is keyed first by dataset name (e.g. "ML-1M"), then by metric name.
df = pd.read_csv(cand, index_col=[0, 1, 2]).droplevel(level=1)
df

# Individual vs. all group fairness metrics (Kendall tau correlation of NDCG-based scores)

In [None]:
from evaluate import Utils
import os


def _kendall_ind_vs_group(results: pd.DataFrame) -> pd.DataFrame:
    """Kendall-tau correlation between individual- and group-fairness NDCG metrics.

    `results` is one dataset's slice of `df`, with metric names as rows. Other
    columns are presumably models/runs — TODO confirm.

    Processing steps, applied to the transposed slice:
    - Metrics whose name contains "down" are negated.
    - Metrics containing "Dec" are dropped.
    - Only metrics containing `pat_ndcg` are kept. Their labels are shortened by
      removing a "$...$ " prefix and the "-NDCG" suffix.

    Returns the rows "SD-Ind".."Atk-Ind" and columns "Min".."GCE" of the rounded
    correlation matrix. These are label slices, so they depend on the CSV's metric order.
    """
    # Copy so the negation below can never write back into `df`,
    # even if `.loc`/`.T` happen to return a view.
    metrics = results.T.copy()

    # Flip the sign of "down" metrics (presumably lower-is-better) before correlating.
    down_cols = metrics.columns[metrics.columns.str.contains("down")]
    metrics[down_cols] = -metrics[down_cols]
    metrics = metrics.loc[:, ~metrics.columns.str.contains("Dec")]

    ndcg = metrics.loc[:, metrics.columns.str.contains(pat_ndcg)]
    # Raw strings: "\$" / "\-" in plain literals are invalid escapes (SyntaxWarning on 3.12+).
    short_names = (ndcg.columns
                   .str.replace(r"\$.*\$ ", "", regex=True)
                   .str.replace(r"\-NDCG", "", regex=True))
    ndcg = ndcg.set_axis(short_names, axis=1)

    corr = ndcg.corr("kendall").round(2)
    return corr.loc["SD-Ind":"Atk-Ind", "Min":"GCE"]


utils = Utils()
# squeeze=False keeps the result 2-D even for a single dataset; take the one column
# so `ax` is always a 1-D array of Axes that can be zipped below.
fig, ax = plt.subplots(nrows=len(list_data), ncols=1, figsize=(6, 6), squeeze=False)
ax = ax[:, 0]

# One shared colorbar on the right-hand side, drawn only by the first heatmap.
cbar_ax = fig.add_axes([0.9, 0.15, .015, .7])
cbar_ax.tick_params(labelsize=8)

hm_kws = dict(annot=True, square=True, vmin=-1, vmax=1,
              cmap="coolwarm_r", annot_kws={"size": 10},
              cbar_kws={"orientation": "vertical"})

for i, (ax_id, data) in enumerate(zip(ax, list_data)):
    corr_tab_ndcg = _kendall_ind_vs_group(df.loc[data])

    sns.heatmap(corr_tab_ndcg, ax=ax_id,
                cbar=i == 0,
                cbar_ax=cbar_ax if i == 0 else None, **hm_kws)

    # Row labels: strip the "-Ind" suffix, e.g. "SD-Ind" -> "SD".
    ax_id.set_yticklabels(
        [label.get_text().replace("-Ind", "") for label in ax_id.get_yticklabels()],
        rotation=0)
    ax_id.set_ylabel(r"$Individual\ fairness$", rotation=90)
    ax_id.set_xlabel(r"$Group\ fairness$", rotation=0)
    ax_id.set_title(f"{data}", fontsize=14)

plt.tight_layout(w_pad=0.1)
timenow = utils.timenow()
# savefig does not create missing directories.
os.makedirs("corr", exist_ok=True)
plt.savefig(f'corr/temp_corr_ind-group_LLM_{timenow}.pdf', bbox_inches='tight')
plt.show()
