# Analyze & plot from wandb

In [None]:
import itertools
import wandb
import yaml
import pickle
import tempfile
from tqdm import tqdm
import os

from scipy import stats
import numpy as np
import itertools

# Experiment grid to analyse and plot. Each list may hold several entries.
models = ['LeNet']
# Other hardness types that can be placed in this list: 'asymmetric',
# 'adjacent', 'instance', 'domain_shift', 'ood_covariate', 'crop_shift',
# 'zoom_shift', 'far_ood'.
hardness_methods = ['uniform']
datasets = ['cifar']

# Every (model, dataset, hardness) triple, with hardness varying fastest.
exp_combos = [
    (model_name, dataset, hardness_type)
    for model_name in models
    for dataset in datasets
    for hardness_type in hardness_methods
]

In [None]:
def _load_artifact_pickle(api, artifact_path, pickle_name=None):
    """Download a wandb artifact into a throw-away directory.

    Args:
        api: a ``wandb.Api`` instance used to resolve the artifact.
        artifact_path: full ``entity/project/name:alias`` artifact path.
        pickle_name: file name inside the artifact to unpickle. When None the
            artifact is only downloaded (existence check) and None is returned.

    Returns:
        The unpickled object, or None when ``pickle_name`` is None.
    """
    # The context manager removes the directory even when the download or
    # the unpickling raises, so failed runs no longer leak temp directories.
    with tempfile.TemporaryDirectory() as tmp_root:
        artifact = api.artifact(artifact_path, type='pickle')
        artifact_dir = artifact.download(root=tmp_root)
        if pickle_name is None:
            return None
        with open(os.path.join(artifact_dir, pickle_name), "rb") as f:
            # NOTE(security): pickle can execute arbitrary code on load; only
            # use this with artifacts from a trusted wandb project.
            return pickle.load(f)


# Titles of the experiment combinations that could not be processed.
failed = []
# Per hardness type: one {prop: {score_name: mean correlation}} dict for each
# successfully processed (model, dataset) combination.
sensitivity_dict = {method: [] for method in hardness_methods}
corr_dict = {method: [] for method in hardness_methods}
for exp_combo in tqdm(exp_combos):
    model, dataset_name, hardness = exp_combo
    title = f"{hardness}_{model}_{dataset_name}"

    try:
        # Read inside the try so that a missing/invalid config marks the
        # combination as failed instead of aborting the whole cell.
        with open('./wandb.yaml') as file:
            wandb_data = yaml.load(file, Loader=yaml.FullLoader)

        os.environ["WANDB_API_KEY"] = wandb_data['wandb_key']
        wandb_entity = str(wandb_data['wandb_entity'])

        seed = 0
        runid = 0
        project_name = f'{hardness}_{dataset_name}_{model}'
        # Created after the API key is exported; reused for every download.
        api = wandb.Api()

        # check file exists for one run; a failure here skips the combination
        metainfo = f"{dataset_name}_{hardness}_{0.1}_{seed}_{runid}"
        _load_artifact_pickle(
            api,
            f'{wandb_entity}/{project_name}/scores_dict_{metainfo}:latest',
        )

        props = [0.1, 0.2, 0.3, 0.4, 0.5]
        runids = [0, 1, 2]

        # prop -> list of loaded score dicts (one per run that downloaded).
        overall_results = {}

        for prop in props:
            prop_list = []

            for runid in runids:
                metainfo = f"{dataset_name}_{hardness}_{prop}_{seed}_{runid}"
                try:
                    data = _load_artifact_pickle(
                        api,
                        f'{wandb_entity}/{project_name}/scores_dict_{metainfo}:latest',
                        f"scores_dict_{metainfo}.pkl",
                    )
                except Exception:
                    # Best effort: a missing run is reported and skipped.
                    print(f'FAILED {metainfo} - {model}')
                    continue

                prop_list.append(data)

            overall_results[prop] = prop_list

        # All unordered pairs of runs, sized by the runs available at 0.1.
        combos = list(itertools.combinations(range(len(overall_results[0.1])), 2))

        sensitivity = {}

        for prop in props:
            # Named differently from the module-level ``corr_dict`` so that
            # accumulator is no longer overwritten on every iteration.
            prop_corrs = {}

            # ``score_name`` (not ``model``) so the experiment's model name is
            # not shadowed by the keys of the 'scores' dict.
            for score_name in overall_results[0.1][0]['scores']:
                corr = []
                for c1, c2 in combos:
                    try:
                        rvs1 = overall_results[prop][c1]['scores'][score_name]
                        rvs2 = overall_results[prop][c2]['scores'][score_name]
                        corr.append(stats.spearmanr(rvs1, rvs2)[0])
                    except Exception:
                        # Pair unavailable for this prop (e.g. a run failed
                        # to download); skip it.
                        continue

                # NaN without the numpy "mean of empty slice" warning when no
                # pair could be correlated.
                prop_corrs[score_name] = np.mean(corr) if corr else np.nan

            sensitivity[prop] = prop_corrs

        sensitivity_dict[hardness].append(sensitivity)

    except Exception:
        import traceback
        print(traceback.format_exc())
        print(f'failed {title}')
        failed.append(title)

In [None]:
def compute_mean_dict(dict_list):
    """Average a list of two-level dictionaries element-wise.

    Each item of ``dict_list`` maps an outer key to a dict of
    ``{metric: value}``. The result has the same two-level layout, where every
    value is the arithmetic mean over the items that contain that
    (outer key, metric) pair. Outer keys whose metric dicts are always empty
    do not appear in the result.
    """
    from collections import defaultdict

    # Running sums and occurrence counts per (outer key, metric).
    sums = defaultdict(dict)
    hits = defaultdict(dict)

    for entry in dict_list:
        for outer_key, inner in entry.items():
            for name, val in inner.items():
                # Index inside the innermost loop so an empty ``inner`` never
                # creates an entry for ``outer_key``.
                sums[outer_key][name] = sums[outer_key].get(name, 0) + val
                hits[outer_key][name] = hits[outer_key].get(name, 0) + 1

    averaged = {}
    for outer_key, inner_sums in sums.items():
        inner_hits = hits[outer_key]
        averaged[outer_key] = {
            name: inner_sums[name] / inner_hits[name] for name in inner_sums
        }

    return averaged

In [None]:
# Category label for each scoring-method name. Lookups elsewhere in this file
# use ``groups.get(name, "Dist")``, so unlisted names fall back to "Dist".
groups = {
    'allsh': "Statistical",
    'aum': "LB:Output",
    'cleanlab': "Statistical",
    'conf_agree': "Statistical",
    'dataiq': "LB:Output",
    'datamaps': "LB:Output",
    'detector': "LB:Stats",
    'el2n': "LB:Stats",
    'forgetting': "LB:Other",
    'grand': "LB:Grad",
    'loss': "LB:Other",
    # Correct spelling, matching the 'prototypicality' check in plot_bar_plot.
    "prototypicality": "Dist",
    # Original (misspelled) key kept so existing direct lookups still resolve.
    "protypicality": "Dist",
    'vog': "LB:Grad",
}


def get_sorted(dictionary, title, save='png'):
    """Arrange a two-level result dict as an array with columns grouped by category.

    Args:
        dictionary: ``{level: {name: value}}``. Every level is expected to
            list the same names in the same order; the names of the first
            level are used for all columns.
        title: unused; kept for call compatibility.
        save: unused; kept for call compatibility.

    Returns:
        A tuple ``(sorted_data, sorted_keys, levels, group_labels,
        sorted_data_dict)`` where
        - ``sorted_data`` is a 2D array (rows = levels, columns = names)
          with columns ordered by category label,
        - ``sorted_keys`` are the column names in that order (text after
          the first "." in a name is dropped),
        - ``levels`` are the outer keys of ``dictionary``,
        - ``group_labels`` is the category label of each sorted column,
        - ``sorted_data_dict`` is an ``OrderedDict`` mapping each sorted
          name to its column of ``sorted_data``.

    Raises:
        IndexError: if ``dictionary`` is empty.
    """
    import collections

    # Rows are levels, columns are the per-level values.
    data = np.array([list(level.values()) for level in dictionary.values()])

    # Column names come from the first level; keep only the part before ".".
    keys = [key.split(".")[0] for key in list(dictionary.values())[0].keys()]
    levels = list(dictionary.keys())

    # Category of every column; names missing from ``groups`` count as "Dist".
    key_groups = [groups.get(key, "Dist") for key in keys]

    # Stable sort of column positions by category: columns of the same
    # category keep their original relative order. Sorting positions rather
    # than names avoids repeated ``list.index`` scans and stays correct when
    # two names collapse to the same prefix.
    order = sorted(range(len(keys)), key=lambda i: key_groups[i])

    sorted_keys = [keys[i] for i in order]
    group_labels = [key_groups[i] for i in order]
    sorted_data_dict = collections.OrderedDict(
        (keys[i], data[:, i]) for i in order
    )
    sorted_data = data[:, order]

    return sorted_data, sorted_keys, levels, group_labels, sorted_data_dict

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def plot_bar_plot(sorted_data, sorted_keys, levels, group_labels, sorted_data_dict, title):
    """Save a bar chart of the per-method mean value across levels.

    Args:
        sorted_data: 2D array, rows = levels, columns = methods.
        sorted_keys: method name of each column of ``sorted_data``.
        levels: row labels of ``sorted_data``.
        group_labels: unused; kept for call compatibility.
        sorted_data_dict: mapping whose key order gives the bar positions used
            to place the category labels under the x-axis.
        title: file stem; the figure is written to
            ``results_folder/model_compare/{title}.pdf``.

    Bar heights are the column means and the error bars are the column
    standard deviations (pandas default) across levels.
    """
    # Shorter tick labels for the longest method names.
    short_names = {
        'prototypicality': 'proto',
        'forgetting': 'forget',
        'conf_agree': 'conf_agr',
    }
    sorted_keys = [short_names.get(key, key) for key in sorted_keys]

    # create pandas dataframe from sorted_data, where rows are levels and columns are sorted_keys
    df = pd.DataFrame(sorted_data, index=levels, columns=sorted_keys)

    # compute mean and standard deviation across levels
    mean = df.mean(axis=0)
    std = df.std(axis=0)

    min_val = mean.min()

    fs = 18
    fig = plt.figure(figsize=(7, 3))
    try:
        # NOTE(review): ``yerr`` is not a documented seaborn argument and is
        # presumably forwarded to matplotlib's bar call - verify with the
        # installed seaborn version.
        ax = sns.barplot(x=mean.index, y=mean.values, yerr=std.values)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", fontsize=fs-3)

        # set fontsize of yticks
        ax.tick_params(axis='y', which='major', labelsize=fs)

        # lower limit leaves a margin below the smallest bar
        ax.set_ylim(min_val-0.1, 1.0)
        ax.set_ylabel("Spearman correlation", fontsize=fs)

        # light horizontal grid lines (drawn for the y axis) to help read off
        # values without being distracting
        ax.grid(True, which='major', axis='y', linestyle='--', color='grey', alpha=0.33)

        # collect the bar positions belonging to each category
        group_positions = {}
        for i, key in enumerate(sorted_data_dict.keys()):
            group_positions.setdefault(groups.get(key, "Dist"), []).append(i)

        # write each category label below the axis, centred on its bars
        for group, positions in group_positions.items():
            midpoint = np.mean(positions)
            ax.text(midpoint, -0.425, group, ha='center', va='top',
                    transform=ax.get_xaxis_transform(), fontsize=fs-5)

        # create the output folder (and its parent) if it does not exist yet
        os.makedirs('results_folder/model_compare', exist_ok=True)

        plt.savefig(f"results_folder/model_compare/{title}.pdf", bbox_inches='tight', pad_inches=0.1)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

In [None]:
# For every hardness type: average the collected sensitivity dicts, order the
# columns by category and save the bar chart. A failure for one hardness type
# is reported and the loop moves on to the next one.
for hardness in hardness_methods:
    try:
        sensitivity_indiv = compute_mean_dict(sensitivity_dict[hardness])
        title = f"{hardness}"
        (
            sorted_data,
            sorted_keys,
            levels,
            group_labels,
            sorted_data_dict,
        ) = get_sorted(dictionary=sensitivity_indiv, title=title, save='png')
        plot_bar_plot(
            sorted_data,
            sorted_keys,
            levels,
            group_labels,
            sorted_data_dict,
            title=title,
        )
    except Exception as e:
        print(f'failed {hardness}')
        print(e)