In [19]:
import os, sys, json
module_path = os.path.abspath(os.path.join('..'))
sys.path.append(module_path)

import matplotlib.pyplot as plt
from collections import defaultdict

In [20]:
# Names of the distribution-distance metrics for this experiment
# (not referenced by the code cells shown in this notebook section).
distance_metrics = [
    "KL-Divergence",
    "Jensen-Shannon",
    "Wasserstein-Distance",
    "Energy-Distance",
]

In [21]:
# Model key -> short model identifier.
# NOTE(review): the short names look like module/file stems; confirm against their users.
models = dict(
    wgan='w_gan',
    wgpgan='w_gp_gan',
    nsgan='ns_gan',
    lsgan='ls_gan',
    mmgan='mm_gan',
    dragan='dra_gan',
    began='be_gan',
    ragan='ra_gan',
    infogan='info_gan',
    fishergan='fisher_gan',
    fgan_forward_kl='forkl_gan',
    fgan_reverse_kl='revkl_gan',
    fgan_jensen_shannon='js_gan',
    fgan_total_var='tv_gan',
    fgan_hellinger='hellinger_gan',
    fgan_pearson='pearson_gan',
    vae='vae',
    autoencoder='ae',
)

In [55]:
def get_best_performance_multivariate(data_type, start_time, trial,
                                      base_dir="/Users/sob/github/gans6883/hypertuning"):
    """ For a trial, get the best performance for a multivariate model.

    Reads every results file in ``{base_dir}/{data_type}/{start_time}/trial_{trial}``.
    The loops below expect each file to be JSON shaped like
    model -> distribution -> {metric: values, "LR": ..., "HDIM": ..., "BSIZE": ...}.
    For every (model, distribution, metric), the run whose *last* value is
    lowest is kept. On ties, the first file in sorted name order wins.

    Args:
        data_type: results sub-directory, e.g. 'multivariate'.
        start_time: timestamp directory of the tuning run.
        trial: trial number; formatted into ``trial_{trial}``.
        base_dir: root of the hypertuning results. Defaults to the original
            hard-coded location, so existing calls behave the same.

    Returns:
        A nested_pickle_dict where optimal[model][distribution][metric] has
        "value" (the winning run's values) and "parameters" ([LR, HDIM, BSIZE]).

    NOTE(review): this uses nested_pickle_dict. The cell that defines it
    (In [16]) sits below the cells that call this function; move it up so
    Restart & Run All works.
    """
    import math

    hyperparams = ("LR", "HDIM", "BSIZE")

    def _final_score(values):
        # Rank on the last recorded value. A NaN (e.g. a diverged run) must
        # never win: `nan > x` is always False, so once stored it could never
        # be replaced.
        last = values[-1]
        if isinstance(last, float) and math.isnan(last):
            return math.inf
        return last

    mypath = os.path.join(base_dir, "{0}/{1}/trial_{2}".format(data_type, start_time, trial))
    # Skip hidden entries such as macOS .DS_Store (not JSON). Sort so that
    # tie-breaking does not depend on os.listdir's arbitrary order.
    files = sorted(
        name for name in os.listdir(mypath)
        if not name.startswith(".") and os.path.isfile(os.path.join(mypath, name))
    )
    results = []
    for name in files:
        with open(os.path.join(mypath, name)) as f:
            results.append(json.load(f))

    optimal = nested_pickle_dict()
    for result in results:
        for model, distributions in result.items():
            for distribution, metrics in distributions.items():
                for metric, values in metrics.items():
                    if metric in hyperparams:
                        continue
                    best = optimal[model][distribution]
                    # The first sighting is the best so far; afterwards keep the lower final value.
                    if metric not in best or _final_score(values) < _final_score(best[metric]["value"]):
                        best[metric]["value"] = values
                        best[metric]["parameters"] = [metrics[key] for key in hyperparams]

    return optimal

In [56]:
def get_best_performance_mixture(data_type, start_time, trial,
                                 base_dir="/Users/sob/github/gans6883/hypertuning"):
    """ For a trial, get the best performance for a mixture model.

    Reads every results file in ``{base_dir}/{data_type}/{start_time}/trial_{trial}``.
    The loops below expect each file to be JSON shaped like
    model -> mixture -> distribution -> {metric: values, "LR": ..., "HDIM": ..., "BSIZE": ...}.
    For every (model, mixture, distribution, metric), the run whose *last*
    value is lowest is kept. On ties, the first file in sorted name order wins.

    Args:
        data_type: results sub-directory, e.g. 'mixture'.
        start_time: timestamp directory of the tuning run.
        trial: trial number; formatted into ``trial_{trial}``.
        base_dir: root of the hypertuning results. Defaults to the original
            hard-coded location, so existing calls behave the same.

    Returns:
        A nested_pickle_dict where optimal[model][mixture][distribution][metric]
        has "value" (the winning run's values) and "parameters" ([LR, HDIM, BSIZE]).

    NOTE(review): this uses nested_pickle_dict. The cell that defines it
    (In [16]) sits below the cells that call this function; move it up so
    Restart & Run All works.
    """
    import math

    hyperparams = ("LR", "HDIM", "BSIZE")

    def _final_score(values):
        # Rank on the last recorded value. A NaN (e.g. a diverged run) must
        # never win: `nan > x` is always False, so once stored it could never
        # be replaced.
        last = values[-1]
        if isinstance(last, float) and math.isnan(last):
            return math.inf
        return last

    mypath = os.path.join(base_dir, "{0}/{1}/trial_{2}".format(data_type, start_time, trial))
    # Skip hidden entries such as macOS .DS_Store (not JSON). Sort so that
    # tie-breaking does not depend on os.listdir's arbitrary order.
    files = sorted(
        name for name in os.listdir(mypath)
        if not name.startswith(".") and os.path.isfile(os.path.join(mypath, name))
    )
    results = []
    for name in files:
        with open(os.path.join(mypath, name)) as f:
            results.append(json.load(f))

    optimal = nested_pickle_dict()
    for result in results:
        for model, mixtures in result.items():
            for mixture, distributions in mixtures.items():
                for distribution, metrics in distributions.items():
                    for metric, values in metrics.items():
                        if metric in hyperparams:
                            continue
                        best = optimal[model][mixture][distribution]
                        # The first sighting is the best so far; afterwards keep the lower final value.
                        if metric not in best or _final_score(values) < _final_score(best[metric]["value"]):
                            best[metric]["value"] = values
                            best[metric]["parameters"] = [metrics[key] for key in hyperparams]

    return optimal

In [None]:
def get_confidence_intervals_multivariate(data_type, start_time=None, trial=None, base_dir="best"):
    """ Compute per-position 5th/95th percentile bands across saved results (multivariate).

    The 5th-95th percentile range spans the central 90% of the runs. It is an
    empirical percentile band, not a 95% confidence interval.

    Reads every JSON file in ``{base_dir}/{data_type}``. Each file is expected
    to map gan -> distribution -> metric -> {"value": [...], ...}. This is
    presumably the saved output of get_best_performance_multivariate; verify.

    Args:
        data_type: sub-directory of ``base_dir`` holding the saved results.
        start_time, trial: unused. Kept (now optional) so existing calls still work.
        base_dir: parent directory; relative to the working directory by default.

    Returns:
        A dict where optimal[gan][distribution][metric] contains:
        - "original": one value list per file containing that metric, in
          sorted file-name order;
        - "5" and "95": element-wise percentiles over those lists (axis 0),
          as Python lists.
    """
    import numpy as np  # numpy is not imported in the notebook's import cell

    mypath = os.path.join(base_dir, str(data_type))
    # Skip hidden entries such as macOS .DS_Store (not JSON); sort for a
    # reproducible order of the "original" lists.
    files = sorted(
        name for name in os.listdir(mypath)
        if not name.startswith(".") and os.path.isfile(os.path.join(mypath, name))
    )

    optimal = {}
    for name in files:
        with open(os.path.join(mypath, name)) as f:
            result = json.load(f)
        for gan, distributions in result.items():
            gan_entry = optimal.setdefault(gan, {})
            for distribution, metrics in distributions.items():
                dist_entry = gan_entry.setdefault(distribution, {})
                for metric, values in metrics.items():
                    dist_entry.setdefault(metric, {"original": []})["original"].append(values["value"])

    # Compute the percentiles once per metric, after all files have been
    # collected (previously they were recomputed once per file).
    for distributions in optimal.values():
        for metrics in distributions.values():
            for stats in metrics.values():
                data = np.array(stats["original"])
                stats["5"] = list(np.percentile(data, 5, axis=0))
                stats["95"] = list(np.percentile(data, 95, axis=0))

    return optimal

In [58]:
# Best hyperparameters per model/distribution/metric for multivariate trial 1.
optimal = get_best_performance_multivariate(
    data_type='multivariate',
    start_time='2018-09-21-1537557860',
    trial='1',
)

In [60]:
# Best hyperparameters per model/mixture/distribution/metric for mixture trial 1.
optim = get_best_performance_mixture(
    data_type='mixture',
    start_time='2018-09-21-1537559182',
    trial='1',
)

In [64]:
# Show which model keys appear in the best-mixture results.
optim.keys()

dict_keys(['wgan', 'wgpgan'])

In [16]:
def nested_pickle_dict():
    """ Build an auto-vivifying tree of defaultdicts.

    Looking up a missing key at any depth inserts a fresh nested_pickle_dict,
    so ``d[a][b][c] = v`` works without creating ``d[a]`` or ``d[a][b]`` first.
    The factory is this module-level function rather than a lambda, which is
    what keeps the resulting tree picklable.
    """
    tree = defaultdict(nested_pickle_dict)
    return tree

In [146]:
# # def get_confidence_intervals_multivariate(data_type):
# mypath = "/Users/sob/Desktop/mnist_best/"
# files = [f for f in os.listdir(mypath) if os.path.isfile(os.path.join(mypath, f))]
# results = []
# for file in files:
#     print(file)
#     with open("{}/{}".format(mypath, file)) as f:
#         data = json.load(f)
#     results.append(data)
#     data2 = data

# # optimal = nested_pickle_dict()
# # for result in results:
# #     for model, distributions in result.items():
# #         for dist, metrics in distributions.items():
# #             for metric, values in metrics.items():
# #                 if metric not in optimal[model][distribution]:
# #                     optimal[model][dist][metric] = {"original": []}
# #                 optimal[model][dist][metric]["original"].append(values['value'])

# # for result in results:
# #     for model, distributions in result.items():
# #         for distribution, metrics in distributions.items():
# #             for metric, values in metrics.items():
# #                 data = np.array(optimal[gan][distribution][metric]["original"])
# #                 optimal[gan][distribution][metric]['5'] = list(np.percentile(data, 5, axis=0))
# #                 optimal[gan][distribution][metric]['95'] = list(np.percentile(data, 95, axis=0))

# # #     return optimal

results_0_2018-09-10.json
False
results_10_2018-09-15.json
True
results_11_2018-09-16.json
True
results_12_2018-09-16.json
True
results_13_2018-09-16.json
True
results_14_2018-09-16.json
True
results_15_2018-09-17.json
True
results_16_2018-09-17.json
True
results_17_2018-09-17.json
True
results_18_2018-09-18.json
True
results_19_2018-09-18.json
True
results_1_2018-09-11.json
True
results_2_2018-09-11.json
True
results_3_2018-09-12.json
True
results_4_2018-09-12.json
True
results_5_2018-09-13.json
True
results_6_2018-09-13.json
True
results_7_2018-09-13.json
True
results_8_2018-09-14.json
True
results_9_2018-09-14.json
True
