In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import re
from collections import OrderedDict
from glob import glob
from itertools import combinations

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from evaluate import Utils, Fairness, compute_group_scores

In [3]:
# Figure resolution: on screen vs. in saved files.
plt.rcParams.update({"figure.dpi": 200, "savefig.dpi": 600})

# Dataset identifiers as they appear in the result file names.
list_data = ["ml-1m", "jobrec", "lfm-1b"]

# Display labels for datasets and for "<model>_<prompt type>" keys
# (NS = neutral prompt, S = sensitive prompt).
map_columns = {
    # datasets
    "ml-1m": "ML-1M",
    "jobrec": "JobRec",
    "lfm-1b": "LFM-1B",

    # models
    "meta-llama_neutral": "Llama-3.1-8B (NS)",
    "meta-llama_sensitive": "Llama-3.1-8B (S)",
    "mistralai_neutral": "Ministral-8B (NS)",
    "mistralai_sensitive": "Ministral-8B (S)",
    "Qwen_neutral": "Qwen2.5-7B (NS)",
    "Qwen_sensitive": "Qwen2.5-7B (S)",
    "THUDM_neutral": "GLM-4-9B (NS)",
    "THUDM_sensitive": "GLM-4-9B (S)",
}

# Helpers from the project-local `evaluate` module.
utils = Utils()
fairness = Fairness()

# Base eval

In [None]:
# Presumably the top-k cut-off of the recommendation lists.
# NOTE(review): `k` is not referenced anywhere else in this notebook (the result
# files already carry "hit@k"/"precision@k"/"ndcg@k" columns) — confirm it is still needed.
k=10

In [None]:
# Aggregated scores: results[dataset][model][measure] -> float.
results = {}

for data in list_data:
    # One row per user; every column after the first is a sensitive attribute.
    df_user = utils.load_df_user(data)
    sensitive_cols = df_user.columns[1:].to_list()

    # Positional index in user_id order, so the per-user result records can be
    # joined row by row below.
    # NOTE(review): assumes every result file lists users in ascending user_id order — confirm.
    df_user = df_user.sort_values("user_id").reset_index(drop=True)

    per_user_results = glob(f"../results_llm/{data}*-result*")

    for res_path in per_user_results:
        # basename() handles both "\\" (Windows) and "/" separators; splitting on
        # "\\" raised IndexError for POSIX paths.
        splitted_name = os.path.basename(res_path)
        _, model, prompt_type, _ = splitted_name.split("_")
        model = f"{model}_{prompt_type}"

        print(f"Doing {data}, {model}")
        df_res = pd.DataFrame(pd.read_pickle(res_path))
        df_res = df_res.drop(columns=["most_sim_items"])

        # Join into a fresh per-model frame. Merging into df_user itself made the
        # inner join cumulative: a result file with fewer rows silently dropped
        # those users for every later model of the same dataset.
        df_scores = pd.merge(df_user, df_res, left_index=True, right_index=True)
        df_scores = df_scores.rename(columns={
            "hit@k": "HR",
            "precision@k": "P",
            "ndcg@k": "NDCG",
        })

        # Reciprocal rank of the first relevant item; 0 if none is relevant.
        df_scores["MRR"] = df_scores.rel.apply(lambda x: 1 / (x.index(1) + 1) if 1 in x else 0)
        df_scores = df_scores.drop(columns="rel")

        # commented out to avoid overwriting the existing file
        # df_scores.to_csv(f"per_user_score/{data}_{model}.csv", index=False)

        results.setdefault(data, OrderedDict()).setdefault(model, OrderedDict())

        # === EFFECTIVENESS ===
        for metric in ["HR", "MRR", "P", "NDCG"]:
            results[data][model][metric] = df_scores[metric].mean()

        for base_score in ["P", "NDCG"]:

            # === GROUP FAIRNESS ===

            # 1 attribute at a time
            for col in sensitive_cols:
                selected_cols = [col, base_score]
                cleaned_col = utils.clean_col(col)
                compute_group_scores(results, data, model, df_scores, selected_cols, [col], base_score,
                                     agg_type=cleaned_col + "-", fairness=fairness)

            # 2 attributes at a time
            for col1, col2 in combinations(sensitive_cols, 2):
                if col1 == col2:
                    continue
                selected_sens = [col1, col2]
                selected_cols = selected_sens + [base_score]
                agg_type = "-".join(utils.clean_col(col) for col in selected_sens)
                compute_group_scores(results, data, model, df_scores, selected_cols, selected_sens, base_score,
                                     agg_type=agg_type + "-", fairness=fairness)

            # All attributes together. The last two returned values are presumably
            # the between-group and within-group Atkinson components.
            selected_cols = sensitive_cols + [base_score]
            _, _, atk_b, atk_within = compute_group_scores(results, data, model, df_scores, selected_cols,
                                                           sensitive_cols, base_score, fairness=fairness)

            # === INDIVIDUAL FAIRNESS ===
            scores = df_scores[base_score]
            results[data][model][f"SD-Ind-{base_score}"] = fairness.score_std(scores)
            results[data][model][f"Gini-Ind-{base_score}"] = fairness.gini(scores)
            results[data][model][f"Atk-Ind-{base_score}"] = fairness.atk(scores)
            # Equivalent to 1 - (1 - atk_within) * (1 - atk_b).
            results[data][model][f"DecAtk-Ind-{base_score}"] = atk_within + atk_b - atk_within * atk_b

Doing ml-1m, meta-llama_neutral
Doing ml-1m, meta-llama_sensitive
Doing ml-1m, mistralai_neutral
Doing ml-1m, mistralai_sensitive
Doing ml-1m, Qwen_neutral
Doing ml-1m, Qwen_sensitive
Doing ml-1m, THUDM_neutral
Doing ml-1m, THUDM_sensitive
Doing jobrec, meta-llama_neutral
Doing jobrec, meta-llama_sensitive
Doing jobrec, mistralai_neutral
Doing jobrec, mistralai_sensitive
Doing jobrec, Qwen_neutral
Doing jobrec, Qwen_sensitive
Doing jobrec, THUDM_neutral
Doing jobrec, THUDM_sensitive
Doing lfm-1b, meta-llama_neutral
Doing lfm-1b, meta-llama_sensitive
Doing lfm-1b, mistralai_neutral
Doing lfm-1b, mistralai_sensitive
Doing lfm-1b, Qwen_neutral
Doing lfm-1b, Qwen_sensitive
Doing lfm-1b, THUDM_neutral
Doing lfm-1b, THUDM_sensitive


# Big table LLM model

In [6]:
# Flatten the nested results dict into a one-column frame with a MultiIndex row index.
# NOTE(review): the level order is decided by utils.flatten_dict; the later
# unstack/reorder steps suggest (model, dataset, measure) — confirm.
df_results = pd.DataFrame(pd.Series(utils.flatten_dict(results)))

In [7]:
# Display order of the measures: effectiveness first, then the group-fairness
# aggregators, then individual fairness. Every fairness measure is listed for
# both base scores, P before NDCG.
measure_order = ["HR", "MRR", "P", "NDCG"] + [
    f"{aggregator}-{base}"
    for aggregator in ["Min", "Range", "SD", "MAD", "Gini", "Atk", "CV", "FStat", "KL", "GCE",
                       "SD-Ind", "Gini-Ind", "Atk-Ind", "DecAtk-Ind"]
    for base in ["P", "NDCG"]
]

In [8]:
def rotate_index(indices):
    r"""Wrap a LaTeX label so it is rendered rotated by 90 degrees.

    Parameters
    ----------
    indices : str
        LaTeX text to rotate.

    Returns
    -------
    str
        ``\rotatebox[origin=c]{90}{<indices>}``.
    """
    return "".join(["\\rotatebox[origin=c]{90}{", indices, "}"])

def add_arrows_and_sort(big_table):
    r"""Prefix every measure name with a LaTeX arrow showing its preferred direction.

    Measures matching one of the "lower is better" aggregators (Range, SD, Gini,
    CV, FStat, Atk, MAD, KL, GCE) get ``$\downarrow$``. The match is a substring
    match, so e.g. DecAtk is included. All other measures (HR, MRR, P, NDCG,
    Min-*, ...) get ``$\uparrow$``. Row order is unchanged.

    Parameters
    ----------
    big_table : pd.DataFrame
        Frame with a three-level row index whose third level holds measure names.

    Returns
    -------
    pd.DataFrame
        New frame whose third index level, named ``"measure"``, holds the
        arrow-prefixed names. The column-axis name is cleared. ``big_table``
        itself is not modified.
    """
    lower_is_better = "Range|SD|Gini|CV|FStat|Atk|MAD|KL|GCE"

    measure_name = big_table.index.get_level_values(2)
    mask_lower = measure_name.str.contains(lower_is_better)

    # Raw strings: "$\downarrow$" in a plain literal is an invalid escape sequence.
    measure_with_arrow = pd.Index(
        np.where(mask_lower, r"$\downarrow$ " + measure_name, r"$\uparrow$ " + measure_name),
        name="measure",  # read back later as the "measure" column after reset_index()
    )

    result = big_table.set_index([big_table.index.get_level_values(0),
                                  big_table.index.get_level_values(1),
                                  measure_with_arrow])
    return result.rename_axis(columns=None)


def get_measure_type(index):
    r"""Classify arrow-prefixed measure names into LaTeX measure-type labels.

    Parameters
    ----------
    index : pd.Series of str
        Measure names as produced by ``add_arrows_and_sort``, e.g.
        ``"$\uparrow$ NDCG"`` or ``"$\downarrow$ Gini-Ind-P"``.

    Returns
    -------
    np.ndarray of object
        One label per name:

        * ``\textsc{Fair (Ind.)}`` for names containing "Ind".
        * ``\textsc{Eff}`` for the plain effectiveness measures HR, P, NDCG and
          MRR. These are recognised by the space that precedes them after the
          arrow.
        * ``\textsc{Fair (Grp.)}`` for everything else.

        The input Series is not modified.
    """
    pattern = ["Ind", " HR| P| NDCG| MRR"]
    is_ind = index.str.contains(pattern[0], regex=True).to_numpy(dtype=bool)
    is_eff = index.str.contains(pattern[1], regex=True).to_numpy(dtype=bool)

    # A name matching both patterns counts as individual fairness.
    labels = np.select([is_ind, is_eff],
                       ["\\textsc{Fair (Ind.)}", "\\textsc{Eff}"],
                       default="\\textsc{Fair (Grp.)}")
    return labels.astype(object)

def add_measure_type_sort_col(big_table):
    r"""Add a rotated measure-type level (Eff / Fair (Grp.) / Fair (Ind.)) to the row index.

    Parameters
    ----------
    big_table : pd.DataFrame
        Frame with a three-level row index whose third level holds
        arrow-prefixed measure names (see ``add_arrows_and_sort``).

    Returns
    -------
    pd.DataFrame
        New frame with a four-level, unnamed row index, in this order:

        * the first two original levels, swapped;
        * the ``\rotatebox``-wrapped measure type;
        * the measure name.

        ``big_table`` itself is not modified.
    """
    measure_type = get_measure_type(big_table.index.get_level_values(2).to_series())

    # assign() works on a copy; writing big_table["measure_type"] mutated the caller's frame.
    with_type = big_table.assign(measure_type=[rotate_index(label) for label in measure_type])

    result = (
        with_type
        .set_index(["measure_type"], append=True)
        .reorder_levels([1, 0, 3, 2])
    )
    result.index.names = [None, None, None, None]
    return result

In [None]:
# consists of different ways of grouping (only 1 attribute, 2 attributes, and all 3 attributes)
big_table_full = add_arrows_and_sort(df_results)
big_table_full = big_table_full\
                        .unstack([0])\
                        .droplevel(0, axis=1)

#save here
timenow = utils.timenow()
big_table_full.to_csv(f"multiple_groups/multiple_groups_LLM_{timenow}.csv")

In [None]:
# Reindex the measure level by `measure_order`, then prefix the direction arrows.
sorted_df_results = df_results.reindex(measure_order, level=2)
df_result_arrow = add_arrows_and_sort(sorted_df_results)
df_result_arrow

In [11]:
# Add the rotated measure-type level (Eff / Fair (Grp.) / Fair (Ind.)).
df_result_measure_type = add_measure_type_sort_col(df_result_arrow)
df_result_measure_type

Unnamed: 0,Unnamed: 1,Unnamed: 2,Unnamed: 3,0
ml-1m,meta-llama_neutral,\rotatebox[origin=c]{90}{\textsc{Eff}},$\uparrow$ HR,0.259677
ml-1m,meta-llama_neutral,\rotatebox[origin=c]{90}{\textsc{Eff}},$\uparrow$ MRR,0.101489
ml-1m,meta-llama_neutral,\rotatebox[origin=c]{90}{\textsc{Eff}},$\uparrow$ P,0.045968
ml-1m,meta-llama_neutral,\rotatebox[origin=c]{90}{\textsc{Eff}},$\uparrow$ NDCG,0.139983
ml-1m,meta-llama_neutral,\rotatebox[origin=c]{90}{\textsc{Fair (Grp.)}},$\uparrow$ Min-P,0.030357
...,...,...,...,...
lfm-1b,THUDM_sensitive,\rotatebox[origin=c]{90}{\textsc{Fair (Ind.)}},$\downarrow$ Gini-Ind-NDCG,0.460593
lfm-1b,THUDM_sensitive,\rotatebox[origin=c]{90}{\textsc{Fair (Ind.)}},$\downarrow$ Atk-Ind-P,0.411271
lfm-1b,THUDM_sensitive,\rotatebox[origin=c]{90}{\textsc{Fair (Ind.)}},$\downarrow$ Atk-Ind-NDCG,0.357888
lfm-1b,THUDM_sensitive,\rotatebox[origin=c]{90}{\textsc{Fair (Ind.)}},$\downarrow$ DecAtk-Ind-P,0.411271


In [None]:
# Arrow-prefixed measure names in display order; used to reindex the tables below.
order = df_result_arrow.index.get_level_values(2).unique()

In [13]:
# Sanity check: the model keys present in the results.
df_result_measure_type.index.get_level_values(1).unique().to_list()

['meta-llama_neutral',
 'meta-llama_sensitive',
 'mistralai_neutral',
 'mistralai_sensitive',
 'Qwen_neutral',
 'Qwen_sensitive',
 'THUDM_neutral',
 'THUDM_sensitive']

In [14]:
#save results here, so we can do corr analysis in another notebook

big_table_for_save =  df_result_measure_type\
                .unstack([1])\
                .droplevel(0, axis=1)\
                .loc[["ml-1m", "jobrec", "lfm-1b"]]\
                .rename(columns=map_columns, index=map_columns)\
                .reindex(order, level=2)

# timenow = utils.timenow()
# big_table_for_save.to_csv(f"table/big_table_LLM_{timenow}.csv")

In [15]:

# Dataset and model levels become columns (with display labels); rows keep the
# measure type and measure, in display order, rounded to 3 decimals.
big_table = (
    df_result_measure_type
    .unstack([0, 1])
    .droplevel(0, axis=1)
    .loc[:, ["ml-1m", "jobrec", "lfm-1b"]]
    .rename(columns=map_columns)
    .reindex(order, level=1)
    .round(3)
)

In [16]:
# Replace "Ind-" with "\-" in the measure labels (the measure-type level already
# marks individual fairness). Raw string: "\-" is an invalid escape sequence.
big_table.index = big_table.index.set_levels(
    big_table.index.levels[1].str.replace("Ind-", r"\-"), level=1)

## Separate into P-based and NDCG-based tables

In [17]:
# Move the measure names into a regular column ("level_1") so rows can be filtered by name.
to_separate = big_table.reset_index(level=1)

In [18]:
# Effectiveness rows (HR, MRR and the base score itself) plus every "-<base score>" measure.
p_rows = to_separate["level_1"].str.contains("-P| HR| MRR| P")
ndcg_rows = to_separate["level_1"].str.contains("-NDCG| HR| MRR| NDCG")
big_table_p = to_separate.loc[p_rows]
big_table_ndcg = to_separate.loc[ndcg_rows]

In [19]:
# Put the measure names back as the second row-index level.
big_table_p = big_table_p.set_index("level_1", append=True)
big_table_ndcg = big_table_ndcg.set_index("level_1", append=True)

In [20]:
# Clear the index level names (presumably so they do not show up in the exported tables).
big_table_p.index.names = [None, None]
big_table_ndcg.index.names = [None, None]

In [21]:
# Measure display order per base score, with "Ind-" shortened to "\-" exactly as
# in `big_table`. Raw strings avoid the invalid escape sequence "\-".
measure_labels = df_result_arrow.index.get_level_values(2)
order_p = measure_labels[measure_labels.str.contains(" HR| MRR| P|-P")].str.replace("Ind-", r"\-").unique()
order_ndcg = measure_labels[measure_labels.str.contains(" HR| MRR| NDCG|-NDCG")].str.replace("Ind-", r"\-").unique()

In [22]:
# model type and prompt type
def separate_model_and_prompt(df, order):
    """Split combined labels such as "Qwen2.5-7B (NS)" into model and prompt column levels.

    Parameters
    ----------
    df : pd.DataFrame
        Table whose second column level holds labels "<model> (<prompt>)".
        NOTE(review): each label must contain exactly one space; otherwise the
        split lists are ragged and the 2-D array below cannot be built.
    order : array-like
        Labels used to reorder the rows on index level 1.

    Returns
    -------
    pd.DataFrame
        Same values with the combined column level replaced by two levels,
        "<model>" and "<prompt>", appended after the remaining column level(s).
    """
    # Move the combined label level from the columns into the row index.
    long_df = df.stack(1)

    # "Qwen2.5-7B (NS)" -> ["Qwen2.5-7B", "NS"], transposed into one array per new level.
    combined = long_df.index.get_level_values(2)
    pieces = combined.str.replace("[()]", "", regex=True).str.split(" ").tolist()
    model_and_prompt = pd.MultiIndex.from_arrays(np.asarray(pieces).T)

    # Swap the combined label for the two new levels and pivot them back into the columns.
    wide_df = (
        long_df
        .set_index(model_and_prompt, append=True)
        .droplevel(2)
        .unstack([2, 3])
    )
    return wide_df.reindex(order, level=1)

In [23]:
# Columns become (dataset, model, prompt type); rows follow the per-base-score order.
big_table_p = separate_model_and_prompt(big_table_p, order_p)
big_table_ndcg = separate_model_and_prompt(big_table_ndcg, order_ndcg)

In [37]:
import seaborn as sns

# Dataset display names, in the order used for tables and figures.
list_data_proper = ["ML-1M", "JobRec", "LFM-1B"]

# Shade at index 4 of a light seagreen palette; base colour of the table gradients in nicetable.
col = sns.light_palette("seagreen").as_hex()[4]

def highlight_max(x):
    """Return CSS that bolds the largest (NaN-ignoring) value(s) of ``x``.

    Meant for ``Styler.apply(axis=1)``: ``x`` is one row. The result has one
    entry per element, either ``"font-weight: bold;"`` or ``None``.
    """
    best = np.nanmax(x.to_numpy())
    is_best = x == best
    return np.where(is_best, "font-weight: bold;", None)

def highlight_min(x):
    """Return CSS that bolds the smallest (NaN-ignoring) value(s) of ``x``.

    Meant for ``Styler.apply(axis=1)``: ``x`` is one row. The result has one
    entry per element, either ``"font-weight: bold;"`` or ``None``.
    """
    best = np.nanmin(x.to_numpy())
    is_best = x == best
    return np.where(is_best, "font-weight: bold;", None)

def add_cdashline(measure, latex_code, df, ref_col="ML-1M"):
    r"""Insert a dashed ``\cdashline`` after every LaTeX line that matches ``measure``.

    Parameters
    ----------
    measure : str
        Regular expression. Each line containing a match is followed by the
        dashed rule.
    latex_code : str
        LaTeX table source.
    df : pd.DataFrame
        Table with MultiIndex columns. With n columns under ``ref_col``, the
        dash runs from column 2 to column ``n + 2``.
    ref_col : str, default "ML-1M"
        Top-level column label whose width sets the extent of the dash.

    Returns
    -------
    str
        ``latex_code`` with the ``\cdashline{2-<n+2>}`` lines inserted.
    """
    last_col = df[[ref_col]].shape[1] + 2
    dash = "\\cdashline{2-" + str(last_col) + "}\n"
    # A callable replacement keeps re.sub from interpreting the backslash in `dash`.
    return re.sub(f"{measure}.*\\n", lambda m: m.group(0) + dash, latex_code)


def nicetable(df, color=False):
    r"""Print one LaTeX table per dataset, with the best value of every row in bold.

    Parameters
    ----------
    df : pd.DataFrame
        Row index with two levels. The second level holds measure labels
        prefixed by ``$\uparrow$`` (higher is better) or ``$\downarrow$``
        (lower is better). Columns are a MultiIndex whose first level is the
        dataset name.
    color : bool, default False
        Also shade each row with a gradient derived from the global ``col``
        colour.

    Notes
    -----
    The LaTeX ``column_format`` assumes four groups of two numeric columns per
    dataset. One header fix-up hard-codes the model label ``"Qwen2.5-7B"``.
    Nothing is returned; the LaTeX source is printed.
    """
    measures = df.index.get_level_values(1).unique()
    row_with_up = measures[measures.str.contains("uparrow")]
    row_with_down = measures[measures.str.contains("downarrow")]

    idx = pd.IndexSlice
    slice_max = idx[idx[:, row_with_up], :]
    slice_min = idx[idx[:, row_with_down], :]

    if color:
        cm = sns.light_palette(col, as_cmap=True)
        cm_r = sns.light_palette(col, reverse=True, as_cmap=True)

    # Datasets in column order; `columns.levels[0]` is sorted and may keep unused labels.
    for data in df.columns.get_level_values(0).unique():

        styler = df[[data]].style
        styler.format(formatter="{:.3f}")
        styler\
            .apply(highlight_max, axis=1, subset=slice_max)\
            .apply(highlight_min, axis=1, subset=slice_min)

        if color:
            styler\
                .background_gradient(cmap=cm, axis=1, subset=slice_max)\
                .background_gradient(cmap=cm_r, axis=1, subset=slice_min)

        latex_code = styler.to_latex(
            hrules=True,
            clines="skip-last;data",
            convert_css=True,
            column_format="ll*{2}{r}|*{2}{r}|*{2}{r}|*{2}{r}",
            multicol_align="c|",
        )

        # Erase the last \cline (searched for in the final 100 characters) up to
        # \bottomrule. find() returns -1 when a marker is absent, and slicing with
        # -1 would corrupt the table, so only cut when both are found in order.
        last_cline_starts = latex_code.find("\\cline", -100, -1)
        last_cline_ends = latex_code.find("\\bottomrule")
        if 0 <= last_cline_starts < last_cline_ends:
            latex_code = latex_code[:last_cline_starts] + latex_code[last_cline_ends:]

        latex_code = add_cdashline(r"Atk\-(P|N)", latex_code, df)

        # Wrap the tabular in a resize box (closed after \end{tabular} below).
        latex_code = latex_code.replace("\\begin{tabular}", "\\resizebox{0.98\\columnwidth}{!}{\n\\begin{tabular}")
        # Add a midrule after the dataset / last-model header cells and drop their trailing "|".
        latex_code = latex_code.replace("{c|}{" + data + "} \\\\", "{c}{" + data + "} \\\\ \n\\midrule")
        latex_code = latex_code.replace("{c|}{Qwen2.5-7B} \\\\", "{c}{Qwen2.5-7B} \\\\ \n\\midrule")
        latex_code = latex_code.replace("\\end{tabular}", "\\end{tabular}}")
        # Turn literal tab characters into the two characters "\t".
        latex_code = latex_code.replace("\t", "\\t")

        # Strip base-score suffixes and leftover colour commands.
        latex_code = latex_code.replace("-P", "")
        latex_code = latex_code.replace("-NDCG", "")
        latex_code = latex_code.replace("-\\", "")
        latex_code = latex_code.replace("\\color[HTML]{F1F1F1}", "")
        latex_code = latex_code.replace("\\color[HTML]{000000}", "")

        print(latex_code)

In [None]:
# Print the LaTeX tables for the P-based and the NDCG-based measures.
nicetable(big_table_p[list_data_proper], color=True)
nicetable(big_table_ndcg[list_data_proper], color=True)

# Multiple ways of grouping table

In [63]:
from evaluate import measure_type_multiple_group, fill_attr, print_group_table

In [None]:
import pandas as pd
list_data_proper = ["ML-1M", "JobRec", "LFM-1B"]

# NOTE(review): "latest" relies on utils.timenow() stamps sorting chronologically as text — confirm.
candidates = sorted(glob("multiple_groups/multiple_groups_LLM_*.csv"))
if not candidates:
    # Clearer than the IndexError from indexing an empty list.
    raise FileNotFoundError("No multiple_groups/multiple_groups_LLM_*.csv found; run the export cell first.")
latest_file = candidates[-1]
# Load the latest file
print(f"Loading latest file: {latest_file}")
big_table_full = pd.read_csv(latest_file, index_col=[0,1])
big_table_full

In [65]:
# One model; SD / Gini / Atk aggregators on the NDCG base score, without the
# decomposed Atkinson ("DecAtk") measures.
selected_model = "THUDM_neutral"
selected_agg = "SD|Gini|Atk"
selected_base = "-NDCG"
excluded = "Dec" #decomposed Atk

df_group = big_table_full[selected_model].reset_index().rename(columns={"level_0":"dataset"})
keep = (
    df_group.measure.str.contains(selected_agg)
    & df_group.measure.str.contains(selected_base)
    & ~df_group.measure.str.contains(excluded)
)
df_group = df_group[keep]
df_group

Unnamed: 0,dataset,measure,THUDM_neutral
0,jobrec,$\downarrow$ Atk-Degree-Experience-NDCG,0.163747
2,jobrec,$\downarrow$ Atk-Degree-Major-NDCG,0.253214
4,jobrec,$\downarrow$ Atk-Degree-NDCG,0.007456
6,jobrec,$\downarrow$ Atk-Experience-Major-NDCG,0.412942
8,jobrec,$\downarrow$ Atk-Experience-NDCG,0.026185
...,...,...,...
554,ml-1m,$\downarrow$ SD-Within-Gender-Age-NDCG,0.329302
556,ml-1m,$\downarrow$ SD-Within-Gender-NDCG,0.329577
557,ml-1m,$\downarrow$ SD-Within-Gender-Occupation-NDCG,0.329850
560,ml-1m,$\downarrow$ SD-Within-NDCG,0.329207


In [None]:
# Classify each measure by how users are grouped; the labels presumably are
# "Grp (1)", "Grp (2)", "Grp (3)" and "Ind" (see the reindex further below).
df_group["measure_type"] = df_group.measure.apply(measure_type_multiple_group)
df_group

In [67]:
# Keep the full measure label; it is parsed into the grouping context below.
df_group["group_context"] = df_group["measure"].copy()

In [68]:
# Keep only the arrow-prefixed aggregator name: drop "-Ind"/"-NDCG", then
# everything from the first "-" on. Raw string: "\-" is an invalid escape sequence.
df_group["measure"] = df_group["measure"]\
                            .str.replace("-Ind|-NDCG","", regex=True)\
                            .str.replace(r"\-.*", "", regex=True)

In [69]:
# Grouping context = the "-"-separated parts left after dropping "-Ind"/"-NDCG"
# and the arrow-prefixed aggregator, e.g. "...SD-Within-Gender-NDCG" -> "Within-Gender".
df_group["group_context"] = (
    df_group["group_context"]
    .str.replace("-Ind|-NDCG", "", regex=True)
    .str.split("-")
    .apply(lambda parts: "-".join(part for part in parts if "arrow" not in part))
)

In [70]:
# Rows: dataset -> fairness type -> grouping context; columns: aggregator.
group_table = (
    df_group
    .set_index(["dataset", "measure_type", "group_context", "measure"])
    .reindex(["Grp (1)", "Grp (2)", "Grp (3)", "Ind"], level=1)
    .unstack(3)
    .loc[list_data]
)

group_table.index.names = ["data", "Fair", "Attr."]
group_table = group_table.droplevel(level=0, axis=1)


In [None]:
# Turn the (data, Fair, Attr.) index levels into regular columns.
group_table = group_table.reset_index()

In [None]:
# Between-group rows (no "Within" in the attribute label) and within-group rows;
# individual-fairness rows end up in both. .copy() makes independent frames, so
# later column assignments do not act on a slice of group_table.
group_table_no_within = group_table[~group_table["Attr."].str.contains("Within")].copy()
group_table_within = group_table[(group_table["Attr."].str.contains("Within"))|(group_table["Fair"].str.contains("Ind"))].copy()
group_table_no_within = fill_attr(group_table_no_within)
group_table_within = fill_attr(group_table_within)
printed_between = print_group_table(group_table_no_within)

In [None]:
# Drop the "Within-" prefix from the attribute labels before printing.
group_table_within["Attr."] = group_table_within["Attr."].str.replace("Within-","")
printed_within = print_group_table(group_table_within)

## Plot multiple ways of grouping vs. individual fairness

In [79]:
from evaluate import prep_df_for_lineplot, plot_line

In [80]:
# Base score read (as a global) by the group-count helpers below.
base_score = "NDCG"

In [81]:
def _count_nonempty_groups(df, by, score):
    """Return how many groups of ``df.groupby(by)`` have at least one non-null ``score``."""
    per_group_count = df.groupby(by)[score].count()
    return int((per_group_count > 0).sum())


def get_min_max_group_count(df=None, cols=None, score=None):
    """Smallest and largest number of non-empty groups when grouping by ONE attribute.

    Parameters
    ----------
    df : pd.DataFrame, optional
        Per-user frame. Defaults to the global ``df_user``, looked up at call time.
    cols : list of str, optional
        Sensitive attribute columns. Defaults to the global ``sensitive_cols``.
    score : str, optional
        Column whose non-null count decides whether a group is non-empty.
        Defaults to the global ``base_score``.

    Returns
    -------
    tuple of (int, int)
        ``(min, max)`` over the attributes, or ``(0, 0)`` if ``cols`` is empty.
    """
    df = df_user if df is None else df
    cols = sensitive_cols if cols is None else cols
    score = base_score if score is None else score

    counts = [_count_nonempty_groups(df, col, score) for col in cols]
    return min(counts, default=0), max(counts, default=0)


def get_min_max_group_count_nested(df=None, cols=None, score=None):
    """Smallest and largest number of non-empty groups over all PAIRS of attributes.

    Parameters and defaults as in ``get_min_max_group_count``. Returns
    ``(0, 0)`` when fewer than two attributes are given.
    """
    df = df_user if df is None else df
    cols = sensitive_cols if cols is None else cols
    score = base_score if score is None else score

    counts = [_count_nonempty_groups(df, [col1, col2], score)
              for col1, col2 in combinations(cols, 2)
              if col1 != col2]
    return min(counts, default=0), max(counts, default=0)


def get_min_max_group_count_all(df=None, cols=None, score=None):
    """Number of non-empty groups when grouping by ALL attributes at once.

    Parameters and defaults as in ``get_min_max_group_count``.
    """
    df = df_user if df is None else df
    cols = sensitive_cols if cols is None else cols
    score = base_score if score is None else score

    return _count_nonempty_groups(df, list(cols), score)

In [114]:
# Per dataset: "min-max" number of groups for 1 and 2 attributes, the count for
# all 3 attributes, and the number of users ("Ind").
min_max_group = dict()

for data in list_data:
    print(data)

    # Any model's per-user file works; only the sensitive attributes are needed.
    # The helpers read these two globals.
    df_user = pd.read_csv(f"per_user_score/{data}_Qwen_neutral.csv")
    sensitive_cols = df_user.columns[1:4].to_list()

    min_1, max_1 = get_min_max_group_count()          # 1 attribute at a time
    min_2, max_2 = get_min_max_group_count_nested()   # 2 attributes at a time
    min_max_group[data] = {
        "Grp (1)": f"{min_1}-{max_1}",
        "Grp (2)": f"{min_2}-{max_2}",
        "Grp (3)": str(get_min_max_group_count_all()),  # all 3 attributes
        "Ind": str(df_user.shape[0]),
    }


ml-1m
jobrec
lfm-1b


In [116]:
import pickle

In [None]:
# commented out to avoid overwriting the existing file
# with open(f"group_count_dict.pickle","wb") as f:
#     pickle.dump(min_max_group, f, pickle.HIGHEST_PROTOCOL)

In [118]:
# Load the cached group counts (written by the commented-out cell above).
# NOTE(review): unpickling can execute arbitrary code — only load trusted files.
min_max_group = pd.read_pickle("group_count_dict.pickle")

In [None]:
def map_to_group_count(df_group, counts=None):
    """Append the number of groups to each "Fairness Type" label.

    For rows whose ``data`` is ML-1M, JobRec or LFM-1B, a label ``x`` becomes
    ``x + "\\n" + counts[<dataset key>][x]``. Other rows are left unchanged.

    Parameters
    ----------
    df_group : pd.DataFrame
        Needs ``data`` (dataset display name) and ``Fairness Type`` columns.
        It is modified in place and also returned.
    counts : dict, optional
        ``{dataset_key: {fairness_type: str}}``. Defaults to the notebook
        global ``min_max_group``.

    Returns
    -------
    pd.DataFrame
        ``df_group``.
    """
    counts = min_max_group if counts is None else counts
    display_to_key = {"ML-1M": "ml-1m", "JobRec": "jobrec", "LFM-1B": "lfm-1b"}

    for display_name, key in display_to_key.items():
        rows = df_group.data == display_name
        df_group.loc[rows, "Fairness Type"] = df_group.loc[rows, "Fairness Type"].apply(
            lambda x: x + "\n" + counts[key][x])

    return df_group

In [None]:
save = False
# save = True

# Plot-ready frames for the between-group and within-group components.
btw_group = prep_df_for_lineplot(printed_between)
wth_group = prep_df_for_lineplot(printed_within)

# Only the between-group labels get the number of groups appended.
btw_group = map_to_group_count(btw_group).rename(columns={"Fairness Type": "#groups"})
wth_group = wth_group.rename(columns={"Fairness Type": "#groups"})

plot_line(btw_group, exp_type="between_LLM", save=save)
# NOTE(review): `save` is not passed here, so this figure ignores the toggle above — intended?
plot_line(wth_group, exp_type="within_LLM")

In [None]:
# Base score used as the default by the group-count helpers below.
base_score = "NDCG"

def get_group_count(df=None, cols=None, score=None):
    """Number of non-empty groups for each single sensitive attribute.

    A group is non-empty when it has at least one non-null ``score``.

    Parameters
    ----------
    df : pd.DataFrame, optional
        Per-user frame. Defaults to the global ``df_user``, looked up at call time.
    cols : list of str, optional
        Sensitive attribute columns. Defaults to the global ``sensitive_cols``.
    score : str, optional
        Score column. Defaults to the global ``base_score``.

    Returns
    -------
    dict
        ``{utils.clean_col(col): n_groups}``.
    """
    df = df_user if df is None else df
    cols = sensitive_cols if cols is None else cols
    score = base_score if score is None else score

    res = dict()
    for col in cols:
        per_group_count = df.groupby(col)[score].count()
        res[utils.clean_col(col)] = int((per_group_count > 0).sum())
    return res

def get_group_count_nested(df=None, cols=None, score=None):
    """Number of non-empty groups for every pair of sensitive attributes.

    Parameters and defaults as in ``get_group_count``. Returns
    ``{"<clean col1>-<clean col2>": n_groups}``.
    """
    df = df_user if df is None else df
    cols = sensitive_cols if cols is None else cols
    score = base_score if score is None else score

    res = dict()
    for col1, col2 in combinations(cols, 2):
        if col1 == col2:
            continue
        per_group_count = df.groupby([col1, col2])[score].count()
        res[f"{utils.clean_col(col1)}-{utils.clean_col(col2)}"] = int((per_group_count > 0).sum())
    return res

def get_group_count_all(df=None, cols=None, score=None):
    """Number of non-empty groups when grouping by all sensitive attributes at once.

    Parameters and defaults as in ``get_group_count``.
    """
    df = df_user if df is None else df
    cols = sensitive_cols if cols is None else cols
    score = base_score if score is None else score

    per_group_count = df.groupby(list(cols))[score].count()
    return int((per_group_count > 0).sum())

# Per dataset: exact number of non-empty groups for every way of grouping,
# plus the number of users ("Ind").
group_count = dict()

for data in list_data:
    print(data)

    # doesn't really matter which model we use here, as we only need the user sensitive attributes
    # (the helpers read these two globals)
    df_user = pd.read_csv(f"per_user_score/{data}_Qwen_neutral.csv")
    sensitive_cols = df_user.columns[1:4].to_list()

    counts = get_group_count()                 # 1 attribute at a time
    counts.update(get_group_count_nested())    # 2 attributes at a time
    all_attrs_count = get_group_count_all()    # all 3 attributes
    counts["-".join(utils.clean_col(col) for col in sensitive_cols)] = all_attrs_count
    counts["Ind"] = df_user.shape[0]
    group_count[data] = counts


In [140]:
def map_to_group_count_fine(df_group, counts=None):
    """Add a ``num`` column with the exact number of groups for each row's attribute set.

    Rows whose ``data`` is ML-1M, JobRec or LFM-1B get ``counts[<key>][Attr.]``.
    An ``Attr.`` of ``"-"`` maps to the ``"Ind"`` entry. Other rows keep NaN.

    Parameters
    ----------
    df_group : pd.DataFrame
        Needs ``data`` and ``Attr.`` columns. It is modified in place and
        also returned.
    counts : dict, optional
        ``{dataset_key: {attr_label: int}}``. Defaults to the notebook global
        ``group_count``.

    Returns
    -------
    pd.DataFrame
        ``df_group``.
    """
    counts = group_count if counts is None else counts
    display_to_key = {"ML-1M": "ml-1m", "JobRec": "jobrec", "LFM-1B": "lfm-1b"}

    # Explicit object dtype: an untyped pd.Series() is float64 on older pandas,
    # which turns the counts into floats ("12.0" once stringified for labels).
    df_group["num"] = pd.Series(index=df_group.index, dtype=object)
    for display_name, key in display_to_key.items():
        rows = df_group.data == display_name
        df_group.loc[rows, "num"] = df_group.loc[rows, "Attr."].apply(
            lambda x: counts[key][x] if x != "-" else counts[key]["Ind"])

    return df_group

In [141]:
import seaborn.objects as so
# Tag each frame with its component, then attach the exact number of groups per row.
btw_group["component"] = "between"
wth_group["component"] =  "within"

btw_group = map_to_group_count_fine(btw_group)
wth_group = map_to_group_count_fine(wth_group)

In [142]:
def sort_and_add_name(df):
    """Sort by dataset, then by group count, and put the count into the ``Attr.`` label.

    ``Attr.`` becomes ``"<Attr.>\\n(<num>)"`` with every "-" replaced by a line
    break. For example, "Gender-Age" with num 20 becomes "Gender\\nAge\\n(20)".

    Parameters
    ----------
    df : pd.DataFrame
        Needs ``data`` (ML-1M / JobRec / LFM-1B), ``Attr.`` and ``num`` columns.

    Returns
    -------
    pd.DataFrame
        New sorted frame with an extra ``data_idx`` column. ``df`` is not
        modified; the previous version sorted it in place.
    """
    map_data = {"ML-1M": 1, "JobRec": 2, "LFM-1B": 3}
    out = df.assign(data_idx=df["data"].map(map_data))
    out = out.sort_values(["data_idx", "num"], kind="stable")

    labels = out["Attr."] + "\n(" + out["num"].astype(str) + ")"
    return out.assign(**{"Attr.": labels.str.replace("-", "\n")})

In [143]:
# Order rows for plotting and put the group count into the x-axis labels.
btw_group = sort_and_add_name(btw_group)
wth_group = sort_and_add_name(wth_group)

In [None]:
import seaborn as sns

In [146]:
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
import matplotlib.ticker as ticker

In [None]:
# One figure per dataset: between- vs. within-group unfairness for every way of
# grouping (bars), with individual unfairness as a dashed reference line.
for data in list_data_proper:
    selected_btw = btw_group.query("data==@data")
    selected_wth = wth_group.query("data==@data")

    # Drop the individual-fairness rows (drawn as lines below) and relabel Continent
    # as Country. assign() builds new frames instead of writing into a slice of the
    # query() result (SettingWithCopyWarning).
    btw_no_ind = selected_btw[~selected_btw["#groups"].str.contains("Ind")]
    wth_no_ind = selected_wth[~selected_wth["#groups"].str.contains("Ind")]
    btw_no_ind = btw_no_ind.assign(**{"Attr.": btw_no_ind["Attr."].str.replace("Continent","Country")})
    wth_no_ind = wth_no_ind.assign(**{"Attr.": wth_no_ind["Attr."].str.replace("Continent","Country")})


    blue_patch = mpatches.Patch(color=sns.color_palette("colorblind")[0], label='The blue data')
    orange_patch = mpatches.Patch(color=sns.color_palette("colorblind")[1], label='The orange data')

    #handle for individual fairness
    dotted_line = mlines.Line2D([], [], color='black', linestyle="--", label="dotted_line")

    #placeholder handle for unfairness
    fake = mlines.Line2D([], [], color='white', linestyle="--", label="dotted_line")

    f = plt.figure(figsize=(7,3.75))
    f.set_layout_engine("constrained")
    plt.axis("off")


    legend = plt.legend(labels=["Unfairness", "between-group", "within-group", "individual"], handles=[fake, blue_patch, orange_patch, dotted_line],
            loc="upper center",
            bbox_to_anchor=(0.55, 1.585),
            ncols=4,
            )
    # Raw string: "$\downarrow$" in a plain literal is an invalid escape sequence.
    p = so.Plot(pd.concat([btw_no_ind, wth_no_ind]),
                x="Attr.", y=r"$\downarrow$Unfairness",
                color="component"
                )\
                .facet(row="measure")\
                .add(so.Bar(), so.Agg("mean"), so.Dodge(), legend=False)\
                .layout(size=(9, 9))\
                .share(x=True, y=False)\
                .scale(color="colorblind")\
                .on(f).plot()


    # f.axes[0] is the hidden axis from plt.axis("off"); the facets follow in this order.
    list_measures = ["SD", "Gini", "Atk"]
    for i in range(1, 4):
        ax = f.axes[i]
        measure = list_measures[i-1]
        y_val = selected_btw[selected_btw["#groups"].str.contains("Ind")]\
                                                                        .query("measure==@measure")[r"$\downarrow$Unfairness"]\
                                                                        .values[0]

        #plot Ind as line instead
        ax.axhline(y=y_val, color='black', linestyle='--', alpha=0.75, linewidth=0.75)

    # change x axis name (only on the last facet from the loop above)
    ax.set_xlabel("Group \n(#groups)")


    for ax in f.axes:
        ax.set_ylabel("$\\downarrow$"+ax.get_title())
        ax.yaxis.label.set_size(fontsize=10)
        ax.xaxis.label.set_size(fontsize=10)
        ax.set_title("")
        ax.yaxis.set_major_locator(ticker.MaxNLocator(nbins=4))
        ax.xaxis.set_label_coords(-0.05, -0.19)

    f.axes[1].set_title(data)
    time = utils.timenow()

    plt.savefig(f'multiple_groups/temp_{time}_decomposability_all_{data}.pdf', bbox_inches='tight')
    p.show()


# Agreement of different ways of grouping and individual fairness

In [151]:
# All models; SD / Gini / Atk on NDCG, without the decomposed Atkinson (DecAtk)
# and the within-group components.
selected_agg = "SD|Gini|Atk"
selected_base = "-NDCG"
excluded = "Dec|Within" # decomposed Atk and within-group components


df_group_all = big_table_full.reset_index().rename(columns={"level_0":"dataset"})
measure_names = df_group_all.measure
df_group_all = df_group_all[
    measure_names.str.contains(selected_agg)
    & measure_names.str.contains(selected_base)
    & ~measure_names.str.contains(excluded)
]

In [152]:
# Derive measure type, grouping context and bare aggregator name from the full
# measure label. assign() avoids chained column writes on the filtered frame;
# the raw string avoids the invalid escape sequence "\-".
raw_measure = df_group_all["measure"]
df_group_all = df_group_all.assign(
    measure_type=raw_measure.apply(measure_type_multiple_group),
    group_context=raw_measure
        .str.replace("-Ind|-NDCG", "", regex=True)
        .str.split("-")
        .apply(lambda parts: "-".join(part for part in parts if "arrow" not in part)),
    measure=raw_measure
        .str.replace("-Ind|-NDCG", "", regex=True)
        .str.replace(r"\-.*", "", regex=True),
)

group_table = df_group_all\
                .set_index(["dataset", "measure_type","group_context","measure"])\
                .reindex([ "Grp (1)", "Grp (2)","Grp (3)", "Ind",], level=1)\
                .unstack(3)\
                .loc[list_data]


In [None]:
from matplotlib.lines import Line2D

In [None]:
# Raw strings: "$\downarrow$" / "\ " in plain literals are invalid escape sequences.
list_measure = [r"$\downarrow$ SD", r"$\downarrow$ Gini", r"$\downarrow$ Atk"]

hm_kws = dict(annot=True, #square=True,
              vmin=-1, vmax=1,
              cmap="coolwarm_r", annot_kws={"size": 11},
              cbar_kws={"orientation": "horizontal"})

fig, ax = plt.subplots(ncols=len(list_data),
                       figsize=(8,3.75)
                       )

# Shared horizontal colour bar below the panels (drawn by the first heatmap only).
cbar_ax = fig.add_axes([.25, -.03, .7, .025])
cbar_ax.tick_params(labelsize=9)

for i, (ax_id, data) in enumerate(zip(ax, list_data)):
    # Kendall correlation, across the model columns, between individual fairness
    # and every way of grouping, for each aggregator (SD, Gini, Atk).
    per_measure = []
    for meas in list_measure:
        df_this_meas = df_group_all\
                            .set_index(["dataset", "measure_type","group_context","measure"])\
                            .reindex([ "Grp (1)", "Grp (2)","Grp (3)", "Ind",], level=1)\
                            .swaplevel(1,3, axis=0)\
                            .swaplevel(2,3, axis=0)\
                            .loc[data].T.corr("kendall")\
                            .round(2)\
                            .loc[meas].loc["Ind"].loc[:, meas]

        # assign() instead of item assignment on a .loc result (SettingWithCopyWarning).
        per_measure.append(df_this_meas.assign(measure=meas.split(" ")[-1]))

    # Build the frame once instead of growing it with concat inside the loop.
    df_agreement = pd.concat(per_measure)
    df_agreement = df_agreement.set_index("measure")
    df_agreement = df_agreement.droplevel(level=0, axis=1)

    # remove individual fairness corr with itself
    df_agreement = df_agreement.iloc[:,:-1].T

    # change last index to "All"
    idx_name = df_agreement.index.to_list()
    idx_name[-1] = "All"
    df_agreement.index = idx_name
    df_agreement.index = df_agreement.index.str.replace("-", "\n")

    # Rows ordered by number of groups (ascending), without the trailing "Ind" entry.
    index_order = pd.Series(group_count[data]).sort_values()[:-1].index
    index_order = index_order.str.replace("-", "\n").to_list()
    index_order[-1] = "All"

    df_agreement = df_agreement.reindex(index_order)
    df_agreement.index = df_agreement.index.str.replace("Continent","Country")

    hm = sns.heatmap(df_agreement, ax=ax_id, cbar=i==0,
                        cbar_ax=None if i else cbar_ax,
                     **hm_kws)

    ax_id.set_title(map_columns[data])
    if data != "lfm-1b":
        hm.set_ylabel(None)
    else:
        hm.set_ylabel(r"$Group\ fairness$", rotation=270,loc="center",labelpad=10, fontsize=11)
        ax_id.yaxis.set_label_position("right")

    hm.set_xlabel(r"$Individual\ fairness$", rotation=0, fontsize=11)
    ax_id.set_yticklabels(ax_id.get_yticklabels(), rotation=0, fontsize=11)
    ax_id.set_xticklabels(ax_id.get_xticklabels(), rotation=0, fontsize=11)

# Dashed horizontal separators, in figure coordinates, around the blocks of rows
# for 1, 2 and 3 attributes.
line_kw = dict(xdata=(0.1, 0.968),
               ls="--", color="#1A224C",
               transform=fig.transFigure)
fig.lines = tuple(Line2D(ydata=(height, height), **line_kw) for height in (0.58, 0.26, 0.9, 0.16))

# Shared style of the "#Attributes" header and the 1/2/3 row-block labels.
text_kw = dict(horizontalalignment="left",
               verticalalignment="bottom",
               weight="bold",
               color="#1A224C",
               clip_on=False,
               size=10)

ax[0].text(-2.5, 0, '#Attributes', **text_kw)

# Positions are in data coordinates of the first heatmap.
left = -2.75
for x_pos, y_pos, label in [(left + 0.005, 1.55, '1'), (left, 4.62, '2'), (left, 6.62, '3')]:
    ax[0].text(x_pos, y_pos, label, **text_kw)


plt.tight_layout(w_pad=0.11)
time = utils.timenow()
plt.savefig(f'multiple_groups/temp_{time}_corr_different_ways_of_grouping_LLM.pdf', bbox_inches='tight')
plt.show()