In [None]:
import pandas as pd
import json
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def clean(rawdata):
    """Turn a Python-repr-style dict string into something ``json.loads`` accepts.

    Single quotes become double quotes, ``False``/``True`` become ``0``/``1`` and
    ``None`` becomes the JSON string ``"None"``. The rewrite is purely textual: any
    occurrence of these tokens (or of an apostrophe) inside a string value is
    rewritten as well.
    """
    substitutions = (
        ("'", "\""),
        ("False", "0"),
        ("True", "1"),
        ("None", "\"None\""),
    )
    cleaned = rawdata
    for old, new in substitutions:
        cleaned = cleaned.replace(old, new)
    return cleaned

In [None]:
def safe_check_required_key_value(dictionary, key, value):
    """Return whether ``dictionary`` contains ``key`` and its value equals ``value``.

    A missing key yields ``False`` instead of raising ``KeyError``.
    """
    if key not in dictionary.keys():
        return False
    return dictionary[key] == value

In [None]:
def match_dict_values(curdict, goaldict, ignore_keys_list):
    """Return True if ``curdict`` agrees with ``goaldict`` on every non-ignored key.

    Every key of ``goaldict`` not listed in ``ignore_keys_list`` must be present in
    ``curdict`` with an equal value. Keys present only in ``curdict`` are not checked.
    """
    return all(
        safe_check_required_key_value(curdict, key, expected)
        for key, expected in goaldict.items()
        if key not in ignore_keys_list
    )

In [None]:
# Collect one representative config per distinct run setup. Two configs are
# considered the same setup when they agree on everything except the log name,
# save directory and latent dimension.
all_run_dicts = []
rawdata = pd.read_csv("../fork_project.csv").reset_index()  # make sure indexes pair with number of rows
for config_text in rawdata["config"]:
    configdict = json.loads(clean(config_text))
    already_seen = any(
        match_dict_values(configdict, known, ["logname", "model_save_dir", "latent_dim"])
        for known in all_run_dicts
    )
    if not already_seen:
        all_run_dicts.append(configdict)

In [None]:
                ("decoder_name","functional_decoder_complicated"),
                ("dec_complicated_function_hidden_dims",[200]),
                ("training_data_added_timing_noise",0.1)

In [None]:
def create_dataset_combine(sweep_dictionaries, ignore_keys_list=("logname", "model_save_dir", "latent_dim"), minstamp=None):
    """Stack the per-run metric rows of several sweeps into a single array.

    Args:
        sweep_dictionaries: iterable of config dicts, each passed to ``create_dataset``.
        ignore_keys_list: config keys that are not compared when matching runs.
        minstamp: forwarded to ``create_dataset``; ``None`` keeps every run.

    Returns:
        The row-wise concatenation of the non-empty per-sweep arrays. When no sweep
        matched any run, returns an empty ``(0, 6)`` array, which matches the six
        metric columns ``create_dataset`` builds.
    """
    per_sweep = [create_dataset(sweep_dictionary, ignore_keys_list, minstamp)
                 for sweep_dictionary in sweep_dictionaries]
    # Sweeps without matching runs can yield empty arrays whose shape (e.g. (0,))
    # cannot be concatenated with (k, 6) arrays. np.concatenate also rejects an
    # empty list, so skip empty results and handle the "nothing at all" case.
    non_empty = [data for data in per_sweep if data.size]
    if not non_empty:
        return np.empty((0, 6))
    return np.concatenate(non_empty, axis=0)

In [None]:
def create_dataset(sweep_dictionary, ignore_keys_list=("logname", "model_save_dir", "latent_dim"), minstamp=None,
                   csv_path="../fork_project.csv", min_step=99):
    """Collect metric rows for every finished run whose config matches ``sweep_dictionary``.

    Args:
        sweep_dictionary: config key/value pairs a run must match (see ``match_dict_values``).
        ignore_keys_list: keys of ``sweep_dictionary`` that are not compared.
        minstamp: if given, only keep runs whose ``logname[21:43]`` substring compares
            ``>= minstamp`` as a string. Presumably this slice is a timestamp embedded in
            the log name; TODO confirm the bounds against the logname format.
        csv_path: exported runs table; its ``config`` and ``summary`` columns are parsed
            with ``clean`` + ``json.loads``.
        min_step: minimum summary ``_step`` for a run to be kept. Runs without a
            ``_step`` entry are skipped.

    Returns:
        Float array of shape ``(n_runs, 6)`` with columns
        (latent_dim, train_alignedRMSE, test_alignedRMSE, train_noiselessRMSE,
        test_noiselessRMSE, log2(train_KLD)). When nothing matches, the result has
        shape ``(0, 6)`` so callers can still slice ``[:, k]`` and concatenate it.
    """
    data = []
    rawdata = pd.read_csv(csv_path)
    for config_text, summary_text in zip(rawdata["config"], rawdata["summary"]):
        configdict = json.loads(clean(config_text))
        if not match_dict_values(configdict, sweep_dictionary, ignore_keys_list):
            continue
        rowdata = json.loads(clean(summary_text))
        # Keep only runs that logged the aligned-RMSE metrics and reached min_step.
        if "train_alignedRMSE" not in rowdata or rowdata.get("_step", -1) < min_step:
            continue
        if minstamp is not None and configdict["logname"][21:43] < minstamp:
            continue
        data.append((configdict["latent_dim"],
                     rowdata["train_alignedRMSE"],
                     rowdata["test_alignedRMSE"],
                     rowdata["train_noiselessRMSE"],
                     rowdata["test_noiselessRMSE"],
                     np.log2(rowdata["train_KLD"])))
    if not data:
        return np.empty((0, 6))
    return np.array(data)

In [None]:
datasets = list(map(create_dataset, all_run_dicts))  # one metrics array per distinct run config

In [None]:
# Number of matching runs (rows) found for each distinct configuration.
[d.shape[0] for d in datasets]

In [None]:
# Keep only configurations with more than 3 runs (rows) so the mean/std are meaningful.
valid_indices = [idx for idx, runs in enumerate(datasets) if runs.shape[0] > 3]

In [None]:
# Show which dataset indices survived the more-than-3-runs filter.
valid_indices

In [None]:
# Drop the under-sampled configurations from both parallel lists so that
# datasets[k] still belongs to all_run_dicts[k].
# NOTE: not idempotent; re-running this cell re-filters the already filtered lists.
datasets, all_run_dicts = (
    [datasets[idx] for idx in valid_indices],
    [all_run_dicts[idx] for idx in valid_indices],
)

In [None]:
def table_with_average_variance_line(data):
    """Summarise column 1 of ``data`` grouped by the distinct values of column 0.

    Args:
        data: 2-D array. Column 0 is the grouping value (the latent dimension for
            arrays built by ``create_dataset``) and column 1 is the metric to
            summarise. Any further columns are ignored.

    Returns:
        ``(xvals, yvals, stdyvals)``: the distinct column-0 values in ascending
        order, plus the mean and the population standard deviation (``np.std``
        default, ``ddof=0``) of column 1 for each of those values.
    """
    data = np.asarray(data)
    x_col = data[:, 0]
    y_col = data[:, 1]
    # np.unique returns the distinct values already sorted (the old code computed
    # an argsort but never applied it).
    xvals = np.unique(x_col)
    yvals = np.array([np.mean(y_col[x_col == x]) for x in xvals])
    stdyvals = np.array([np.std(y_col[x_col == x]) for x in xvals])
    return (xvals, yvals, stdyvals)

In [None]:

pcaresults = np.load("../fork_pca_results.npy")
# Keep PCA rows whose first column lies strictly between 0 and 17
# (presumably the latent dimension, as in create_dataset; TODO confirm).
pcaresults = pcaresults[(pcaresults[:, 0] > 0) & (pcaresults[:, 0] < 17)]
# NOTE: appends in place; re-running this cell adds a duplicate entry to datasets.
datasets.append(pcaresults)

In [None]:
# Inspect the filtered PCA results appended as the last entry of datasets.
pcaresults

In [None]:
import matplotlib.lines as mlines 
#https://stackoverflow.com/questions/47391702/how-to-make-a-colored-markers-legend-from-scratch

In [None]:
def save_table(namebase, toplot=None):
    """Summarise the rate, train and test metrics of selected datasets.

    Each selected dataset with more than five columns has three column pairs
    grouped by column 0 via ``table_with_average_variance_line``:

    * column 5 -> rate (log2 train KLD for arrays built by ``create_dataset``)
    * column 1 -> train aligned RMSE
    * column 2 -> test aligned RMSE

    Args:
        namebase: currently unused (nothing is written to disk); kept so existing
            calls keep working.
        toplot: indices into the global ``datasets`` list; defaults to all of them.
            Any iterable is accepted, including one-shot iterators.

    Returns:
        ``(ratedat, traindat, testdat)``: three parallel lists of
        ``(xvals, means, stds)`` tuples, one entry per summarised dataset.
        Datasets that are not 2-D with more than five columns are skipped. The
        lists therefore line up with ``toplot`` only when nothing is skipped.
    """
    if toplot is None:
        toplot = range(len(datasets))
    ratedat, traindat, testdat = [], [], []
    # A single pass consumes `toplot` once, so even an iterator fills all three lists.
    for i in toplot:
        dataval = datasets[i]
        if dataval.ndim != 2 or dataval.shape[1] <= 5:
            continue
        ratedat.append(table_with_average_variance_line(dataval[:, (0, 5)]))
        traindat.append(table_with_average_variance_line(dataval[:, (0, 1)]))
        testdat.append(table_with_average_variance_line(dataval[:, (0, 2)]))
    return (ratedat, traindat, testdat)

In [None]:
def is_base_config(testdict, must_not_match=()):
    """Check whether a run config is the base configuration, with some keys optionally varied.

    Args:
        testdict: run config dict to test.
        must_not_match: config key name(s) that must *differ* from their base value.
            Every other checked key must equal its base value. A single key may be
            given as a plain string.

    Returns:
        True if every checked key meets its requirement, False otherwise.
    """
    # Wrap a bare string, otherwise `key in must_not_match` would do a substring test.
    if isinstance(must_not_match, str):
        must_not_match = (must_not_match,)
    checks = [("beta", 0.1),
              ("scalar_timewarper_name", "modeled_scalar_timewarper"),
              ("decoder_name", "functional_decoder_complicated"),
              ("dec_complicated_function_hidden_dims", [200]),
              ("training_data_added_timing_noise", 0.1)]
    for key, base_value in checks:
        matches = safe_check_required_key_value(testdict, key, base_value)
        if key in must_not_match:
            if matches:
                return False
        elif not matches:
            return False
    return True

In [None]:
np.array(list(map(is_base_config, all_run_dicts)))  # which run configs are exactly the base config

In [None]:
np.array(list(map(is_base_config, all_run_dicts))).sum()  # should be 1: the next cell calls .item() on it

In [None]:
def _config_indices(must_not_match=()):
    """Indices into all_run_dicts of runs that match the base config, except that
    the keys listed in must_not_match must differ from their base values."""
    return np.where(np.array([is_base_config(d, must_not_match) for d in all_run_dicts]))[0]


# `.item()` raises ValueError unless exactly one run matches, so these lines
# double as a check that each single-run reference configuration is unique.
base_config_index = _config_indices().item()
notimewarp_index = _config_indices(["scalar_timewarper_name"]).item()
conv_index = _config_indices(["decoder_name", "scalar_timewarper_name", "dec_complicated_function_hidden_dims"]).item()

beta_inds = list(_config_indices(["beta"]))
no_tw_beta_inds = list(_config_indices(["beta", "scalar_timewarper_name"]))
conv_beta_inds = list(_config_indices(["beta", "decoder_name", "scalar_timewarper_name", "dec_complicated_function_hidden_dims"]))

print(len(all_run_dicts))

In [None]:
# Datasets to tabulate: base run, beta sweeps, no-timewarp runs, convolutional runs.
toplots = [base_config_index, *beta_inds, *no_tw_beta_inds, notimewarp_index, *conv_beta_inds, conv_index]

In [None]:
rates, trains, tests = save_table("beta", toplots)  # parallel lists, one entry per index in toplots

In [None]:
# Inspect the per-dataset train-error summaries: (xvals, means, stds) tuples from save_table.
trains

In [None]:
def _model_label(run_config):
    """Table label for a run config: timewarped runs first, then the convolutional beta-VAE, else NoTimewarp."""
    if run_config["scalar_timewarper_name"] == "modeled_scalar_timewarper":
        return "TimewarpVAE"
    if run_config["decoder_name"] == "convolutional_decoder_upsampling":
        return "beta-VAE"
    return "NoTimewarp"


names = [_model_label(all_run_dicts[k]) for k in toplots]
betas = np.array([all_run_dicts[k]["beta"] for k in toplots])

In [None]:
# Build the LaTeX rows of the results table. Rows are grouped by model in a
# fixed display order and sorted by ascending beta within each model. The model
# name is printed only on the first row of its group, and each error cell shows
# the mean and 3x the standard deviation across runs.
MODEL_DISPLAY_ORDER = ("TimewarpVAE", "NoTimewarp", "beta-VAE")
SKIPPED_BETAS = (0.0001,)  # runs with these beta values are left out of the table
# (model name, beta) -> which error column is typeset in bold
BOLD_COLUMN = {("TimewarpVAE", 0.1): "train", ("TimewarpVAE", 0.01): "test"}


def _latex_mean_pm(summary, bold=False):
    """Format an (xvals, means, stds) summary as 'mean $\\pm$ 3*std', optionally wrapped in \\textbf."""
    cell = f"{summary[1].item():.3f} $\\pm$ {3*summary[2].item():.3f}"
    return f"\\textbf{{{cell}}}" if bold else cell


def _display_rank(name):
    """Position of a model name in MODEL_DISPLAY_ORDER; unknown names sort last."""
    return MODEL_DISPLAY_ORDER.index(name) if name in MODEL_DISPLAY_ORDER else len(MODEL_DISPLAY_ORDER)


# An explicit sort key replaces the previous hard-coded permutation of a lexsort,
# which only worked for exactly 9 rows with exactly 3 beta-VAE rows.
sorted_indices = sorted(range(len(names)), key=lambda k: (_display_rank(names[k]), betas[k]))

outstring = ""
prevname = None
for i in sorted_indices:
    n, b = names[i], betas[i]
    if b in SKIPPED_BETAS:
        continue
    bold = BOLD_COLUMN.get((n, b))
    thisname = "" if n == prevname else n
    outstring += (f"{thisname} & {b} & {rates[i][1].item():.3f} & "
                  + f"{_latex_mean_pm(trains[i], bold == 'train')} & "
                  + f"{_latex_mean_pm(tests[i], bold == 'test')} \\\\\n")
    prevname = n
print(outstring)

In [None]:
# The main .tex file supplies the line break after the last row itself, so drop
# the final row's LaTeX terminator (two backslashes plus a newline, 3 characters).
# The old `[:-2]` removed only one backslash and left a stray `\` at the end.
ROW_TERMINATOR = "\\\\\n"
table_body = outstring[:-len(ROW_TERMINATOR)] if outstring.endswith(ROW_TERMINATOR) else outstring
with open("forkResultsTable.tex", "w", encoding="utf-8") as text_file:
    text_file.write(table_body.rstrip())