In [None]:
import ast
import math
import os
import statistics
import sys

sys.path.append('../')

import numpy as np
from mlflow.tracking import MlflowClient
from matplotlib import pyplot as plt
import torch
%matplotlib inline


from src.deepal.database.mlflow_logger import MLFlowLogger, export_from_mlflow

In [None]:
# Plot styling per strategy variant: (colour, marker, legend label, line style).
# Keys are the strategy name, optionally suffixed with "_pca" or "_output" to match
# the run's `strategy_parameters.emb` value; un-suffixed keys cover any other embedding.
colormapper = {
    "KCenterGreedy": ("tab:red", "x", "latent", "solid"),
    "KCenterGreedy_pca": ("tab:red", "x", "pca", "dotted"),
    "KCenterGreedy_output": ("tab:red", "x", "prob", "dashed"),
    "KMeansSampling": ("tab:orange", "*", "latent", "solid"),
    "KMeansSampling_output": ("tab:orange", "*", "prob", "dashed"),
    "KMeansSampling_pca": ("tab:orange", "*", "pca", "dotted"),
    "KMeansPP": ("tab:green", "|", "latent", "solid"),
    "KMeansPP_output": ("tab:green", "|", "prob", "dashed"),
    "KMeansPP_pca": ("tab:green", "|", "pca", "dotted"),
}

# Human-readable strategy names used as subplot titles.
namemapper = {"KMeansSampling": "KMeansCenter", "KCenterGreedy": "KCenterGreedy", "KMeansPP": "KMeans++"}

# MLflow tracking server URI. Read from the standard MLFLOW_TRACKING_URI environment
# variable so the notebook need not be edited; empty string when it is unset.
# TODO: set MLFLOW_TRACKING_URI (or edit this line) before running.
mlflow_uri = os.environ.get("MLFLOW_TRACKING_URI", "")

In [None]:
def get_avg_accs(run_id, seeds, colname="avg_acc", client=None):
    """Return per-step mean accuracy and a +/- 1 standard-deviation band for one MLflow run.

    Args:
        run_id: MLflow run id.
        seeds: seeds whose per-seed metrics ``acc_{seed}`` are aggregated.
        colname: name of a pre-aggregated accuracy metric. When the run logged it, its
            history is used as the mean curve; otherwise the mean is computed across seeds.
        client: object exposing ``get_metric_history(run_id, key)``; defaults to the
            notebook-level ``tracking`` client (looked up only when ``client`` is None).

    Returns:
        ``(accs, lower_bound, upper_bound)`` numpy arrays, ordered by metric step. The band
        uses the sample std across seeds (0 where only one seed logged that step).
        NOTE(review): if ``colname``'s history and the per-seed histories have different
        lengths, the subtraction below fails with a numpy broadcast error.
    """
    client = tracking if client is None else client

    # Group per-seed accuracy values by the step they were logged at.
    accs_by_step = {}
    for seed in seeds:
        for metric in client.get_metric_history(run_id, f"acc_{seed}"):
            accs_by_step.setdefault(metric.step, []).append(metric.value)
    steps = sorted(accs_by_step)

    std_accs = np.array([
        statistics.stdev(accs_by_step[step]) if len(accs_by_step[step]) > 1 else 0
        for step in steps
    ])

    # Probe the metric history directly instead of a leaked notebook `row`.
    avg_history = sorted(client.get_metric_history(run_id, colname), key=lambda m: m.step)
    if avg_history:
        accs = np.array([m.value for m in avg_history])
    else:
        accs = np.array([statistics.mean(accs_by_step[step]) for step in steps])

    return accs, accs - std_accs, accs + std_accs
    
def get_nlabels(run_id, prepend_zero=True, seeds=None, client=None):
    """Return the sorted metric steps (number of labels) logged for one MLflow run.

    Steps come from the ``avg_acc`` history; if the run did not log it, they come from
    ``acc_{seeds[0]}`` instead.

    Args:
        run_id: MLflow run id.
        prepend_zero: when truthy, a leading 0 is inserted (older call sites pass the
            seed list positionally here; a non-empty list is truthy, so that still works).
        seeds: seeds to fall back on; defaults to the notebook-level ``seeds`` variable,
            which is only looked up when the fallback is actually needed.
        client: object exposing ``get_metric_history(run_id, key)``; defaults to the
            notebook-level ``tracking`` client.

    Returns:
        numpy int array of steps in ascending order (with an optional leading 0).
    """
    client = tracking if client is None else client

    history = client.get_metric_history(run_id, "avg_acc")
    if not history:
        if seeds is None:
            # Legacy behaviour: use the `seeds` set by the plotting loop.
            seeds = globals()["seeds"]
        history = client.get_metric_history(run_id, f"acc_{seeds[0]}") if seeds else []

    labels = np.array(sorted(m.step for m in history), dtype=int)
    # Previously `nlabels` was unbound when prepend_zero was falsy; return in both branches.
    return np.insert(labels, 0, 0, axis=0) if prepend_zero else labels

def get_qtimes(run_id, seeds, shape, prepend_zero=True, client=None):
    """Return cumulative query time averaged over seeds for one MLflow run.

    For every seed that logged ``sec_{seed}``, the per-round durations (ordered by step)
    are averaged element-wise across seeds and then cumulatively summed.

    Args:
        run_id: MLflow run id.
        seeds: seeds whose ``sec_{seed}`` histories are averaged.
        shape: number of timing entries expected per seed. Extra entries are truncated;
            missing ones are ignored in the mean rather than counted as 1.0.
        prepend_zero: when truthy, a leading 0 is inserted.
        client: object exposing ``get_metric_history(run_id, key)``; defaults to the
            notebook-level ``tracking`` client.

    Returns:
        numpy float array of length ``shape`` (+1 with ``prepend_zero``). It is NaN-filled
        when no seed logged timings, so the caller's plotting still runs.
    """
    client = tracking if client is None else client

    per_seed = []
    for seed in seeds:
        history = sorted(client.get_metric_history(run_id, f"sec_{seed}"), key=lambda m: m.step)
        if history:
            per_seed.append([m.value for m in history][:shape])

    if per_seed:
        # NaN padding lets nanmean skip rounds a seed did not log.
        times = np.full((len(per_seed), shape), np.nan)
        for row_i, values in enumerate(per_seed):
            times[row_i, :len(values)] = values
        avg_seconds = np.cumsum(np.nanmean(times, axis=0))
    else:
        avg_seconds = np.full(shape, np.nan)

    if prepend_zero:
        avg_seconds = np.insert(avg_seconds, 0, 0, axis=0)
    return avg_seconds



In [None]:
# One figure row per strategy with three panels:
# accuracy vs # labels, accuracy vs cumulative query time, and query time vs # labels.
competitor_strategies = [
    "KMeansSampling",
    "KCenterGreedy",
    "KMeansPP",
]

experiment_name = "low-dim-div-sampling"

all_runs_df = export_from_mlflow(mlflow_uri=mlflow_uri, mlflow_experiment_name=experiment_name, metrics=['avg_acc'])
query_sizes = all_runs_df['n_queries'].unique()
query_sizes.sort()

LEGEND_FONTSIZE = 45
LABEL_FONTSIZE = 50
TITLE_FONTSIZE = 60
MARKER_SIZE = 12
# Per-panel (y label, x label, legend location), in column order.
PANEL_SPECS = [
    ("Accuracy", "# Labels", "lower right"),
    ("Accuracy", "Query Time (s)", "lower right"),
    ("Query Time (s)", "# Labels", "upper left"),
]

# squeeze=False keeps `ax` 2-D even when only one strategy is plotted.
fig, ax = plt.subplots(nrows=len(competitor_strategies), ncols=3, figsize=(40, 35), squeeze=False)

for strat_i, strat in enumerate(competitor_strategies):
    acc_ax, time_ax, cost_ax = ax[strat_i]
    filtered_runs_df = all_runs_df[all_runs_df.strategy.isin([strat])]
    filtered_runs_df = filtered_runs_df.sort_values(by=['strategy', 'strategy_parameters.emb', 'model_params.linear_hidden_dims'], ascending=False)

    # `row`, `seeds` and `tracking` stay module-level names: the helper functions may read them.
    for run_id, row in filtered_runs_df.iterrows():
        seeds = ast.literal_eval(row["seed"])
        # NOTE(review): instantiated presumably for its side effect of configuring MLflow
        # for this run's experiment before MlflowClient() is created; confirm.
        db = MLFlowLogger(tracking_uri=mlflow_uri, experiment_name=row["experiment_name"])
        tracking = MlflowClient()

        avg_accs, lower_bound, upper_bound = get_avg_accs(run_id, seeds)
        # prepend_zero uses its default; previously `seeds` was passed here by mistake.
        nlabels = get_nlabels(run_id)
        nseconds = get_qtimes(run_id, seeds, avg_accs.shape[0] - 1)

        emb = row["strategy_parameters.emb"]
        title_prepared = row["strategy"] + (f"_{emb}" if emb in ("output", "pca") else "")
        _, marker, label, _ = colormapper[title_prepared]
        style = dict(label=label, ms=MARKER_SIZE, marker=marker)

        acc_ax.plot(nlabels, np.insert(avg_accs, 0, 0, axis=0), **style)
        time_ax.plot(nseconds, avg_accs, **style)
        cost_ax.plot(nlabels[2:], nseconds[1:], linewidth=3, **style)
        acc_ax.fill_between(nlabels, np.insert(lower_bound, 0, 0, axis=0), np.insert(upper_bound, 0, 0, axis=0), alpha=0.15)
        time_ax.fill_between(nseconds, lower_bound, upper_bound, alpha=0.15)

    # Title from `strat`, not the loop-leaked `row` (stale/undefined for strategies without runs).
    time_ax.set_title(f"{namemapper[strat]}", fontsize=TITLE_FONTSIZE)
    acc_ax.set_ylim(0.7, 1)
    time_ax.set_ylim(0.7, 1)

    for panel, (ylabel, xlabel, legend_loc) in zip(ax[strat_i], PANEL_SPECS):
        panel.legend(fontsize=LEGEND_FONTSIZE, loc=legend_loc, fancybox=True)
        panel.set_ylabel(ylabel, fontsize=LABEL_FONTSIZE)
        panel.set_xlabel(xlabel, fontsize=LABEL_FONTSIZE)
        panel.grid(visible=True, color='lightgrey', linestyle='--', linewidth=1)
        panel.set_aspect('auto')

fig.tight_layout()
plt.show()