# Analyze & plot results from wandb

In [None]:
# HCM method name -> method family used to group columns in the heatmap plots.
# Lookups use groups.get(name, "Dist"), so unlisted methods fall into "Dist".
groups = {
    'allsh': "Statistical",
    'aum': "LB:Output",
    'cleanlab': "Statistical",
    'conf_agree': "Statistical",
    'dataiq': "LB:Output",
    'datamaps': "LB:Output",
    'detector': "LB:Stats",
    'el2n': "LB:Stats",
    'forgetting': "LB:Other",
    'grand': "LB:Grad",
    'loss': "LB:Other",
    # Previously misspelled "protypicality", so this entry never matched the method name.
    'prototypicality': "Dist",
    'vog': "LB:Grad",
}

In [None]:
import wandb
import pandas as pd
import os
import yaml
import itertools
import collections
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scienceplots 
# Reload matplotlib's style library (so the styles made available by the
# `scienceplots` import are found) and apply the paper plotting style.
plt.style.reload_library()
plt.style.use(["science", "ieee", "no-latex", "notebook", "grid", "vibrant"])

# Load wandb credentials from a local YAML file with keys `wandb_key` and
# `wandb_entity`. safe_load is sufficient for plain scalars and refuses
# arbitrary Python-object tags.
with open('./wandb.yaml', encoding='utf-8') as file:
    wandb_data = yaml.safe_load(file)

os.environ["WANDB_API_KEY"] = wandb_data['wandb_key']
wandb_entity = wandb_data['wandb_entity']

# Plot colour for every HCM method. This dict is the single source of truth:
# `colors` and `methods` below are derived from it so they cannot drift apart.
color_mapping = {
    'allsh': 'red',
    'aum': 'blue',
    'cleanlab': 'green',
    'conf_agree': 'orange',
    'dataiq': 'purple',
    'datamaps': 'brown',
    'detector': 'pink',
    'el2n': 'gray',
    'forgetting': 'olive',
    'grand': 'cyan',
    'loss': 'magenta',
    'prototypicality': 'lime',
    'vog': 'teal',
}

# Parallel lists in color_mapping's insertion order (colors[i] belongs to methods[i]).
colors = list(color_mapping.values())
methods = list(color_mapping)


In [None]:
def compute_mean_dict(dict_list):
    """Average a sequence of nested ``{key: {metric: value}}`` dicts entry by entry.

    For every (key, metric) pair that appears in at least one input dict, the
    result holds the arithmetic mean of the values seen for that pair. Pairs
    missing from some inputs are averaged only over the inputs that contain them.
    Keys and metrics appear in the result in first-seen order.
    """
    from collections import defaultdict

    # key -> metric -> every value observed for that pair, in input order.
    observed = defaultdict(lambda: defaultdict(list))
    for entry in dict_list:
        for key, metric_values in entry.items():
            for metric, value in metric_values.items():
                observed[key][metric].append(value)

    return {
        key: {metric: sum(values) / len(values) for metric, values in per_metric.items()}
        for key, per_metric in observed.items()
    }

In [None]:
def data_and_perf_dicts(performance_dict, prop_list=(0.1, 0.2, 0.3, 0.4, 0.5)):
    """Pivot per-method score lists by proportion and rank the methods at each one.

    Args:
        performance_dict: maps a column/method name to a sequence of scores; the
            score at index ``i`` belongs to proportion ``prop_list[i]``.
        prop_list: the proportions, in the same order as the score sequences.
            (Default is a tuple rather than a list to avoid a shared mutable default.)

    Returns:
        ``(rank_dict, perf_dict)`` where ``perf_dict[p][name]`` is the score of
        ``name`` at proportion ``p`` and ``rank_dict[p][name]`` is its rank at ``p``
        (1 = highest score). Ties keep the insertion order of ``performance_dict``
        because Python's sort is stable, also with ``reverse=True``.

    Raises:
        IndexError: if a score sequence is shorter than ``prop_list``.
    """
    perf_dict = {
        p: {name: scores[idx] for name, scores in performance_dict.items()}
        for idx, p in enumerate(prop_list)
    }

    rank_dict = {}
    for p, scores in perf_dict.items():
        best_first = sorted(scores, key=scores.__getitem__, reverse=True)
        rank_dict[p] = {name: rank for rank, name in enumerate(best_first, start=1)}

    return rank_dict, perf_dict

In [None]:
# Experiment grid: every (model, dataset, hardness) combination corresponds to
# one wandb project named "{hardness}_{dataset}_{model}".
models = ['LeNet', 'ResNet']
hardness_methods = ['uniform', 'asymmetric', 'instance', 'adjacent', 'domain_shift', 'ood_covariate', 'far_ood', 'crop_shift', 'zoom_shift']
datasets = ['mnist', 'cifar']
# datasets = ['cover', "eye", "diabetes"]#, 'cifar']

exp_combos = list(itertools.product(models, datasets, hardness_methods))

# Per-hardness accumulators for PRC and ROC results.
# NOTE(review): nothing in the visible code appends to these; confirm they are still needed.
perf_prc_dict = {method: [] for method in hardness_methods}
auc_prc_dict = {method: [] for method in hardness_methods}
data_prc_dict = {method: [] for method in hardness_methods}

perf_roc_dict = {method: [] for method in hardness_methods}
auc_roc_dict = {method: [] for method in hardness_methods}
data_roc_dict = {method: [] for method in hardness_methods}

# Names of projects whose analysis raised an exception.
failed = []

# Root folder for all generated figures; created if it does not exist yet.
base_folder = "results_folder"
os.makedirs(base_folder, exist_ok=True)

# wandb entity (user or team) that owns the projects, read from wandb.yaml.
username = wandb_entity

# Abbreviations used for long HCM names on plot axes.
_LABEL_ABBREVIATIONS = {
    'prototypicality': 'proto',
    'forgetting': 'forget',
    'conf_agree': 'conf_agr',
}


def _short_label(name):
    """Return the axis label for an HCM name, abbreviated when it is long."""
    return _LABEL_ABBREVIATIONS.get(name, name)


def _rank_frame(rank_scores):
    """Build a proportion x method DataFrame of ranks with short method labels.

    ``rank_scores`` is the ``{proportion: {column: rank}}`` dict returned by
    ``data_and_perf_dicts``; column names are cut at the first "." and abbreviated.
    """
    frame = pd.DataFrame.from_dict(rank_scores, orient='index')
    frame.columns = [_short_label(name.split(".")[0]) for name in frame.columns]
    return frame


def _plot_prc_heatmap(perf, prop_list, out_path, fs=8):
    """Save a heatmap of scores per (proportion, method), methods grouped by family.

    ``perf`` is the ``{proportion: {column: score}}`` dict from ``data_and_perf_dicts``;
    rows follow its key order and are labelled with ``prop_list``.
    """
    data = np.array([list(level.values()) for level in perf.values()])
    keys = [key.split(".")[0] for key in list(perf.values())[0]]

    # Columns ordered by family name, then by original position within a family.
    order = sorted(range(len(keys)), key=lambda i: (groups.get(keys[i], "Dist"), i))
    sorted_keys = [keys[i] for i in order]
    sorted_data = data[:, order]

    fig, ax = plt.subplots(figsize=(4, 2))
    # Snap the colour range outward to one decimal. Plain rounding could move vmin
    # above the data minimum (or vmax below the maximum) and saturate those cells.
    vmin = np.floor(sorted_data.min() * 10 + 1e-9) / 10
    vmax = np.ceil(sorted_data.max() * 10 - 1e-9) / 10
    heatmap = ax.imshow(sorted_data, cmap='bwr_r', vmin=vmin, vmax=vmax)
    ax.grid(False)

    # White separators between cells, one line per inner boundary.
    rows, cols = sorted_data.shape
    for i in range(rows - 1):
        ax.axhline(i + 0.5, color='white', linewidth=0.5)
    for j in range(cols - 1):
        ax.axvline(j + 0.5, color='white', linewidth=0.5)

    ax.set_xticks(np.arange(len(keys)))
    ax.set_yticks(np.arange(len(perf)))
    ax.set_yticklabels(prop_list, fontsize=fs)

    plt.minorticks_off()
    plt.tick_params(axis='both', which='both', bottom=True, top=False, labelbottom=True,
                    left=True, right=False, labelleft=True, direction='out')
    plt.xticks(fontsize=fs)
    plt.yticks(fontsize=fs)

    ax.set_xticklabels([_short_label(key) for key in sorted_keys], rotation=45, ha="right",
                       rotation_mode="anchor", fontsize=fs)
    ax.tick_params(axis='x', pad=-1)

    # Family label under the middle of each family; dashed borders around multi-method families.
    group_positions = {}
    for i, key in enumerate(sorted_keys):
        group_positions.setdefault(groups.get(key, "Dist"), []).append(i)
    for group, positions in group_positions.items():
        ax.text(np.mean(positions), -0.55, group, ha='center', va='top',
                transform=ax.get_xaxis_transform(), fontsize=6)
        if len(positions) > 1:
            ax.axvline(x=positions[0] - 0.5, color='k', linestyle='--')
            ax.axvline(x=positions[-1] + 0.5, color='k', linestyle='--')

    cbar = fig.colorbar(heatmap, shrink=0.55)
    cbar.ax.tick_params(labelsize=fs)
    ax.set_ylabel("Proportion perturbed", fontsize=fs)

    fig.tight_layout()
    fig.subplots_adjust(bottom=0)
    fig.savefig(out_path, bbox_inches='tight', pad_inches=0.1)
    plt.close(fig)


def _plot_rank_violin(rank_scores, out_path):
    """Save violin plots of each method's rank across proportions (rank 1 at the top)."""
    import seaborn as sns

    df = _rank_frame(rank_scores).sort_index(axis=1)
    method_order = list(df.columns)

    plt.figure(figsize=(6, 2))
    sns.violinplot(data=df, inner="stick", scale="width", cut=0, order=method_order)

    # Annotate each violin with its median rank. Iterating the Series yields values in
    # column order, which avoids deprecated positional Series[int] lookups on a str index.
    medians = df.median()
    vertical_offset = df.values.max() * 0.01
    for xtick, median in enumerate(medians):
        plt.text(xtick, median + vertical_offset, round(median, 2),
                 horizontalalignment='center', color='black')

    plt.ylabel('Rank Scores', fontsize=14)
    plt.xticks(rotation=45)
    plt.tick_params(axis='x', pad=-5)
    plt.gca().invert_yaxis()  # best rank (1) at the top

    plt.minorticks_off()
    plt.tick_params(axis='both', which='both', bottom=True, top=False, labelbottom=True,
                    left=True, right=False, labelleft=True, direction='out')
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=16)
    plt.yticks(range(1, int(df.values.max()) + 1, 2))

    plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1)
    plt.close()


def _plot_beats_matrix(rank_scores, out_path):
    """Save a method-vs-method heatmap of how often the row method out-ranks the column one.

    Cell (m1, m2) counts the proportions where m1 has a strictly lower (better) rank
    than m2, divided by the largest count; the diagonal is set to 0.5.
    """
    import seaborn as sns

    data = _rank_frame(rank_scores).sort_index(axis=0).sort_index(axis=1)
    counts = pd.DataFrame(0, index=data.columns, columns=data.columns)
    for col1 in data.columns:
        for col2 in data.columns:
            counts.loc[col1, col2] = sum(data[col1] < data[col2])

    # Work on a private float copy: DataFrame.values is not guaranteed to be a
    # writable view (e.g. under pandas copy-on-write), so fill_diagonal on it is unsafe.
    values = counts.to_numpy(dtype=float)
    peak = values.max()
    if peak > 0:  # with a single method every count is 0; avoid 0/0
        values = values / peak
    np.fill_diagonal(values, 0.5)
    matrix_normalized = pd.DataFrame(values, index=counts.index, columns=counts.columns)

    plt.figure(figsize=(10, 8))
    sns.heatmap(matrix_normalized, cmap='bwr_r', fmt='.1f', linewidths=0.5, linecolor='black')
    plt.xlabel('HCM 2')
    plt.ylabel('HCM 1')
    plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1)
    plt.close()


# One API client reused for every project.
api = wandb.Api()

for model, dataset_name, hardness in tqdm(exp_combos):
    # wandb project holding this combination's runs; also the name used in failure reports.
    project_name = f"{hardness}_{dataset_name}_{model}"
    try:
        folder = f"{base_folder}/{project_name}"
        os.makedirs(folder, exist_ok=True)

        # Shift-type perturbations use a smaller grid of proportions.
        if hardness in ('crop_shift', 'zoom_shift', 'far_ood'):
            prop_list = [0.05, 0.1, 0.15, 0.2, 0.25]
        else:
            prop_list = [0.1, 0.2, 0.3, 0.4, 0.5]

        # Collect the logged history of every run in the project.
        all_metrics = []
        for run in api.runs(f"{username}/{project_name}"):
            metrics = run.history()
            metrics["run_id"] = run.id
            all_metrics.append(metrics)
        if not all_metrics:
            raise ValueError(f"no runs found in {username}/{project_name}")
        df = pd.concat(all_metrics, ignore_index=True)

        # One row per run: the first non-null value of each column within that run.
        grouped_df = df.groupby('run_id').first().reset_index()
        merged_df = grouped_df[grouped_df['hardness'] == hardness].sort_values('p')

        auprc_cols = sorted(col for col in merged_df.columns if 'auc_prc' in col and 'accuracy' not in col)
        auroc_cols = sorted(col for col in merged_df.columns if 'auc_roc' in col and 'accuracy' not in col)
        p_list = merged_df['p'].unique()

        # Mean AUPRC of each column at every proportion p (in ascending-p order).
        auprc_dict = {
            col: [np.mean(merged_df[merged_df['p'] == p][col]) for p in p_list]
            for col in auprc_cols
        }
        # Mean AUROC per proportion; scores below 0.5 are mirrored to 1 - score.
        auroc_dict = {}
        for col in auroc_cols:
            scores = [np.mean(merged_df[merged_df['p'] == p][col]) for p in p_list]
            auroc_dict[col] = [1 - score if score < 0.5 else score for score in scores]

        data_prc, perf_prc = data_and_perf_dicts(auprc_dict, prop_list=prop_list)
        # NOTE(review): the ROC rankings are computed but not plotted below.
        data_roc, perf_roc = data_and_perf_dicts(auroc_dict, prop_list=prop_list)

        # Figure files are tagged by hardness; `folder` already encodes dataset and model.
        plot_tag = f"update_{hardness}"
        _plot_prc_heatmap(perf_prc, prop_list, f'{folder}/heatmaps_{plot_tag}_prc.pdf')
        _plot_rank_violin(data_prc, f'{folder}/violin_{plot_tag}_prc.pdf')
        _plot_beats_matrix(data_prc, f'{folder}/mat_compare_{plot_tag}_prc.pdf')

        # Critical-difference diagram: average rank per method plus the pairwise
        # results of scikit_posthocs.posthoc_siegel_friedman over the proportions.
        pdf = pd.DataFrame(data_prc)
        pdf.index = [name.split(".")[0] for name in pdf.index]
        pdf = pdf.reset_index()
        df_melted = pd.melt(pdf, id_vars='index', var_name='score_key', value_name='rank')
        df_melted.rename(columns={'index': 'method', "rank": "score"}, inplace=True)
        avg_rank = (
            df_melted.groupby('score_key').score.rank(pct=False, ascending=True)
            .groupby(df_melted.method).mean()
        )

        import scikit_posthocs as sp
        from src.plots import *  # provides critical_difference_diagram

        test_results = sp.posthoc_siegel_friedman(
            df_melted,
            melted=True,
            block_col='score_key',
            group_col='method',
            y_col='score',
        )

        plt.figure(figsize=(3.5, 2), dpi=1000)
        critical_difference_diagram(avg_rank, test_results,
                                    label_props={'color': 'black', 'fontsize': 12}, text_h_margin=0.1)
        plt.savefig(f'{folder}/cd_{plot_tag}_prc1.pdf', bbox_inches='tight', pad_inches=0.1)
        plt.close()

    except Exception:
        import traceback
        print(traceback.format_exc())
        print(f'Failed - {project_name}')
        failed.append(project_name)
        # Close any figure left half-drawn by the failure so figures don't accumulate.
        plt.close('all')