# Hyperparameter search analysis

In [76]:
import seaborn as sns
import pandas as pd
from pathlib import Path
from chunked_writer import TidyReader
%matplotlib inline

## The following cell contains functions for retrieving dataframes via tidywriters
## We also add the hyperparameters to the dataframes here

In [96]:
def load_df_and_params(posixpath, tag, columns):
    """ 
    Args:
        posixpath: Posixpath. Path to one specific epxeriment
        tag: String. The name of the metric which we want to retrieve
        columns: List[String]. The column headers of the resulting dataframe
    Returns:
        df: DataFrame
    """
    reader = TidyReader(os.path.join(posixpath, "data"))
    df = reader.read(tag=tag, columns=columns)
    params = stem_to_params(posixpath.name)
    return df, params


def eval_d(string: str):
    """
    Helper function which parses a string of the form 'x:y' into
    a dictionary {"x": "y"}. This is used for inferring the hyperparameters
    from the folder names.
    Args:
        string: String. The input string to evaluate.
    """
    k, v = string.split(":")
    return {k: v}

def stem_to_params(stem) -> dict:
    """
    Helper function which takes in a path stem of the form "param:value-param2:value2..."
    and returns a dictionary of parameters, e.g. "{"param":value,...}
    Args:
        stem: String. The name of the folder.
    """
    params = {k: v for d in map(eval_d, stem.split("-")[:-1]) for k, v in d.items()}
    return params

def series_to_mean(df):
    groups = df.groupby(["Rank", "Metric", "Type", "Agent",])
    return groups.apply(lambda x: x[x["Step"] >= 4000].mean())

In [117]:
paths = Path("results/jeanzay/results/sweeps/shared_ref_mnist/2021-04-12/23-42-59").glob("*")

dfs = []
for i, path in enumerate(paths):
    df, params = load_df_and_params(path, "pred_from_latent",
    ["Rank", "Step", "Value", "Metric", "Type", "Agent"])
    df = series_to_mean(df)
    for param, value in params.items():
        df[param] = value
    dfs.append(df)
    if i == 2:
        break


df = pd.concat(dfs)
print(df)
#print(df.reset_index(col_level=2))
# filter df
df = df[(df["Metric"] == "Accuracy") & (df["Type"] == "Reconstruction")]

groups = df.groupby(["Rank", "Metric", "Type", "Agent", "sigma", "eta_lsa", "eta_ae"])
#for name, group in groups:
#    print("NAME", name)
#    print(group)

df.to_csv("test_before.csv")
df = groups.apply(lambda x: x[x["Step"] >= 4000].mean())
df.to_csv("test.csv")
print(df)
sns.pointplot(data=df, x="eta_ae", y="sigma", hue="Value", style="Rank", join=False)
# sns.relplot(data=df[(df["Metric"] == "Accuracy") & (df["Type"] == "Reconstruction")],
#             x="Step",
#             y="Value",
#             col="Agent",
#             hue="eta_lsa"
#             )

KeyError: 'Metric'