In [None]:
import pandas as pd
import json
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def clean(rawdata):
    """Rewrite a dict-literal string (presumably a Python repr) into text json.loads can parse.

    Single quotes become double quotes, False/True become 0/1, and None becomes the
    JSON string "None". The substitutions are plain substring replacements over the
    whole text, applied in that order, so they also rewrite matching characters or
    words that occur inside string values.
    """
    substitutions = (
        ("'", "\""),
        ("False", "0"),
        ("True", "1"),
        ("None", "\"None\""),
    )
    text = rawdata
    for old, new in substitutions:
        text = text.replace(old, new)
    return text

In [None]:
def safe_check_required_key_value(dictionary, key, value):
    """Return True only when `key` is present in `dictionary` and its value == `value`.

    A missing key yields False rather than raising KeyError.
    """
    if key not in dictionary.keys():
        return False
    return dictionary[key] == value

In [None]:
def match_dict_values(curdict, goaldict, ignore_keys_list):
    """Return True if `curdict` agrees with `goaldict` on every goal key not in `ignore_keys_list`.

    Only the keys of `goaldict` are checked: extra keys in `curdict` are ignored, and a
    goal key missing from `curdict` counts as a mismatch.
    """
    return all(
        safe_check_required_key_value(curdict, key, goal_value)
        for key, goal_value in goaldict.items()
        if key not in ignore_keys_list
    )

In [None]:
all_run_dicts = []
rawdata = pd.read_csv("../project.csv")
rawdata = rawdata.reset_index()  # make sure indexes pair with number of rows
# Keep one representative config per setting: a row is skipped when some already-kept
# config matches it on all of that kept config's keys except logname, model_save_dir
# and latent_dim, so runs differing only in those keys collapse into one entry.
for index, row in rawdata.iterrows():
    configdict = json.loads(clean(row.config))
    seen_before = any(
        match_dict_values(configdict, kept, ["logname", "model_save_dir", "latent_dim"])
        for kept in all_run_dicts
    )
    if not seen_before:
        all_run_dicts.append(configdict)

In [None]:
def create_dataset_combine(sweep_dictionaries, ignore_keys_list=("logname", "model_save_dir", "latent_dim"), minstamp=None):
    """Stack the create_dataset results of several sweep configurations into one array.

    Parameters
    ----------
    sweep_dictionaries : iterable of dict
        Reference configs; each is passed to create_dataset.
    ignore_keys_list : sequence of str
        Config keys allowed to differ between matched runs (forwarded to create_dataset).
    minstamp : str or None
        Optional lower bound forwarded to create_dataset.

    Returns
    -------
    np.ndarray
        The non-empty per-config arrays concatenated along axis 0; an empty (0, 6)
        array when nothing matched (or no configs were given) instead of raising.
    """
    parts = [create_dataset(sweep, ignore_keys_list, minstamp=minstamp) for sweep in sweep_dictionaries]
    # An empty result built via np.array([]) is 1-D (shape (0,)) and cannot be
    # concatenated with (n, 6) arrays, so drop empty parts before stacking.
    parts = [part for part in parts if part.size > 0]
    if not parts:
        return np.empty((0, 6))
    return np.concatenate(parts, axis=0)

In [None]:
def create_dataset(sweep_dictionary, ignore_keys_list=("logname", "model_save_dir", "latent_dim"), minstamp=None,
                   csv_path="../project.csv", min_step=99):
    """Collect final metrics of every run in the project CSV that matches a sweep config.

    Parameters
    ----------
    sweep_dictionary : dict
        Reference run config. A run is included when its config agrees with every key
        of `sweep_dictionary` except those in `ignore_keys_list`.
    ignore_keys_list : sequence of str
        Config keys that may differ between included runs.
    minstamp : str or None
        If given, keep only runs with `logname[21:43] >= minstamp` (string comparison).
        NOTE(review): presumably that slice is a sortable timestamp; confirm against the
        code that builds lognames.
    csv_path : str
        CSV to read; its `config` and `summary` columns hold dict literals that are
        parsed with `clean` + `json.loads`.
    min_step : int
        Minimum summary `_step` a run must have reached to be included.

    Returns
    -------
    np.ndarray of shape (n_runs, 6) with columns
        [latent_dim, train_alignedRMSE, test_alignedRMSE,
         train_noiselessRMSE, test_noiselessRMSE, train_KLD].
    The result is always 2-D: when no run qualifies it has shape (0, 6), so callers can
    still index columns (`.shape[1]`, `[:, (0, k)]`) and concatenate results.
    """
    metric_keys = ("train_alignedRMSE", "test_alignedRMSE",
                   "train_noiselessRMSE", "test_noiselessRMSE", "train_KLD")
    rawdata = pd.read_csv(csv_path)

    data = []
    for config_text, summary_text in zip(rawdata["config"], rawdata["summary"]):
        configdict = json.loads(clean(config_text))
        if not match_dict_values(configdict, sweep_dictionary, ignore_keys_list):
            continue
        rowdata = json.loads(clean(summary_text))
        # Skip runs without the metric or that stopped before min_step.
        if "train_alignedRMSE" not in rowdata or rowdata["_step"] < min_step:
            continue
        if minstamp is not None and configdict["logname"][21:43] < minstamp:
            continue
        data.append((configdict["latent_dim"], *(rowdata[key] for key in metric_keys)))

    if not data:
        return np.empty((0, len(metric_keys) + 1))
    return np.array(data)

In [None]:
# One metrics array per deduplicated run config, in all_run_dicts order.
datasets = list(map(create_dataset, all_run_dicts))

In [None]:
# Print the shape of each per-config dataset. A plain loop avoids the stray list of
# None values that a side-effect list comprehension would display as cell output.
for d in datasets:
    print(d.shape)

In [None]:
# Display the per-config metric arrays built above.
datasets

In [None]:
def plot_with_average_variance_line(ax,data,color,label,markerstyle):
    """Plot mean y per distinct x as a line with markers, shaded by +/- one standard deviation.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    data : array-like, shape (n, >=2)
        Column 0 is x, column 1 is y; rows sharing an x value are aggregated.
    color : matplotlib color
        Used for the line, the markers and the shaded band.
    label : str
        Legend label attached to the line.
    markerstyle : str
        Marker for the per-x mean points.

    Returns
    -------
    (xvals, yvals, stdyvals) : tuple of np.ndarray
        Sorted distinct x values, mean y and population std (ddof=0) of y at each x.
    """
    data = np.asarray(data)
    xcol = data[:, 0]
    ycol = data[:, 1]
    # np.unique returns the distinct x values already sorted, replacing the original
    # quadratic "x not in list" scan followed by argsort.
    xvals = np.unique(xcol)
    groups = [ycol[xcol == x] for x in xvals]
    yvals = np.array([np.mean(g) for g in groups])
    stdyvals = np.array([np.std(g) for g in groups])

    ax.plot(xvals, yvals, c=color, label=label, zorder=2)
    ax.scatter(xvals, yvals, c=color, alpha=1, marker=markerstyle, s=100, edgecolors="k", linewidth=1, zorder=3)
    ax.fill_between(xvals, yvals - stdyvals, yvals + stdyvals, facecolor=color, alpha=0.5, zorder=1)
    return xvals, yvals, stdyvals

In [None]:

# Load the precomputed PCA baseline and keep rows whose column 0 lies in 1..16
# (presumably the latent dimension; this matches the 0.5-16.5 x-range used for the sweeps).
pcaresults = np.load("../paper_calculations/pca_results.npy")
pcaresults = pcaresults[(pcaresults[:, 0] > 0) & (pcaresults[:, 0] < 17)]

# Later cells refer to the PCA baseline as datasets[-1]. Replace it instead of appending
# again when this cell is re-run, so datasets stays "one entry per run config + PCA".
if len(datasets) > len(all_run_dicts):
    datasets[-1] = pcaresults
else:
    datasets.append(pcaresults)

In [None]:
# Display the filtered PCA baseline results.
pcaresults

In [None]:
import matplotlib.lines as mlines 
#https://stackoverflow.com/questions/47391702/how-to-make-a-colored-markers-legend-from-scratch

In [None]:
def _new_sweep_axes(fsize):
    """Create a figure of size `fsize` with one axes at [0.1, 0.1, 0.8, 0.8]; return (figure, axes)."""
    fig = plt.figure(figsize=fsize)
    ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
    return fig, ax


def _format_latent_axis(ax):
    """Shared latent-dimension x-axis: limits 0.5-16.5, labelled ticks at 1, 6, 11, 16, minor ticks at 1..16."""
    odds = [i for i in range(1, 17) if i % 5 == 1]
    ax.set_xlim(0.5, 16.5)
    ax.set_xticks(odds, odds)
    ax.set_xticks(range(1, 17), minor=True)


def save_plots(namebase,rateylim,toplot = None,labels=None,colors=None,shapes=None,labeltextcolor="white",data=None):
    """Save three latent-dimension sweep figures for the selected datasets, then show them.

    Writes
      - ``{namebase}LatentSweepRate.pdf``: log2 of column 5 vs column 0 (only datasets with > 5 columns),
      - ``{namebase}LatentSweepTrain.pdf``: column 1 vs column 0,
      - ``{namebase}LatentSweepTest.pdf``: column 2 vs column 0, with the legend.
    For arrays built by create_dataset, columns 0/1/2/5 are latent_dim, train_alignedRMSE,
    test_alignedRMSE and train_KLD.

    Parameters
    ----------
    namebase : str
        Output filename prefix.
    rateylim : float
        Upper y-limit of the rate panel in un-logged units (log2 is applied).
    toplot : iterable of int, optional
        Indices into `data` to draw; defaults to all of them.
    labels : sequence, optional
        Legend label per entry of `toplot`; defaults to the indices themselves.
    colors : sequence, optional
        Matplotlib color per series; defaults to "C0", "C1", ...
    shapes : sequence, optional
        Marker per series; defaults to circles.
    labeltextcolor : str
        Legend text color.
    data : list of 2-D arrays, optional
        Datasets to index into; defaults to the notebook-level `datasets`.

    Series are built with zip, so the shortest of toplot/labels/colors/shapes bounds them.
    """
    if data is None:
        data = datasets
    fsize = (8, 6)

    if toplot is None:
        toplot = range(len(data))
    if labels is None:
        labels = toplot
    if colors is None:
        colors = [f"C{i}" for i in range(len(labels))]
    if shapes is None:
        shapes = ["o"] * len(data)
    # Materialize once so every panel iterates the same series, even if toplot is an iterator.
    series = list(zip(toplot, labels, colors, shapes))

    # Rate panel: log2 of column 5 against latent dimension.
    fig, ax = _new_sweep_axes(fsize)
    for i, label, color, shape in series:
        dataval = data[i]
        # Datasets without a column 5 (presumably the PCA baseline) have no rate to plot.
        if dataval.shape[1] > 5:
            plotdat = dataval[:, [0, 5]]  # fancy indexing copies, so log2 below leaves data[i] intact
            plotdat[:, 1] = np.log2(plotdat[:, 1])
            plot_with_average_variance_line(ax, plotdat, color, label=f"{label}", markerstyle=shape)
    ax.set_ylim(2, np.log2(rateylim))
    _format_latent_axis(ax)
    fig.savefig(f"{namebase}LatentSweepRate.pdf", bbox_inches="tight")

    # Train (column 1) and test (column 2) panels; only the test panel carries the legend.
    for col, suffix in ((1, "Train"), (2, "Test")):
        fig, ax = _new_sweep_axes(fsize)
        legend_lines = []
        for i, label, color, shape in series:
            plot_with_average_variance_line(ax, data[i][:, (0, col)], color, label=f"{label}", markerstyle=shape)
            legend_lines.append(
                mlines.Line2D([], [], color=color, marker=shape, label=f"{label}",
                              markersize=10, markeredgecolor="k", markeredgewidth=1))
        ax.set_ylim(0, 1.2)
        _format_latent_axis(ax)
        if suffix == "Test":
            ax.legend(handles=legend_lines, labelcolor=labeltextcolor, frameon=False)
        fig.savefig(f"{namebase}LatentSweep{suffix}.pdf", bbox_inches="tight")
    plt.show()

In [None]:
def is_base_config(testdict,must_not_match=()):
    """Return True if `testdict` equals the base config on every checked key except `must_not_match`.

    The base config is the list of (key, value) pairs in `checks` below.

    Parameters
    ----------
    testdict : dict
        Run config to test.
    must_not_match : str or iterable of str
        Checked keys whose value must NOT equal the base value; a missing key counts as
        not matching. A single key may be given as a bare string. All other checked keys
        must be present and equal to the base value.
    """
    if isinstance(must_not_match, str):
        # Treat a bare string as one key. Otherwise `key in must_not_match` would be a
        # substring test and silently exempt any check whose key is a substring of it.
        must_not_match = (must_not_match,)
    checks = [("beta", 0.001),
              ("scalar_timewarper_name", "modeled_scalar_timewarper"),
              ("decoder_name", "functional_decoder_complicated"),
              ("dec_complicated_function_hidden_dims", [200]),
              ("training_data_added_timing_noise", 0.1),
              ("vector_timewarper_name", "identity_vector_timewarper"),
              ("encoder_name", "convolutional_encoder"),
              ]
    for key, base_value in checks:
        matches = safe_check_required_key_value(testdict, key, base_value)
        if key in must_not_match:
            if matches:
                return False
        elif not matches:
            return False
    return True

In [None]:
# Inspect one of the deduplicated run configs.
all_run_dicts[1]

In [None]:
# Show the run configs that differ from the base config in both timewarper settings
# (the same selection used for dtw_index below).
list(filter(lambda run: is_base_config(run, ["vector_timewarper_name", "scalar_timewarper_name"]), all_run_dicts))

In [None]:
def _matching_run_indices(must_not_match=()):
    """Indices into all_run_dicts of runs for which is_base_config(run, must_not_match) holds."""
    return [ix for ix, run in enumerate(all_run_dicts) if is_base_config(run, must_not_match)]


def _unique_run_index(must_not_match=()):
    """Index of the single run passing is_base_config(run, must_not_match); ValueError unless exactly one does."""
    matches = _matching_run_indices(must_not_match)
    if len(matches) != 1:
        raise ValueError(
            f"expected exactly one run matching the base config except on {list(must_not_match)}, "
            f"found {len(matches)}: {matches}"
        )
    return matches[0]


# Single runs that differ from the base config on exactly the listed keys (and match it elsewhere).
base_config_index = _unique_run_index()
noise_off_index = _unique_run_index(["training_data_added_timing_noise"])
timewarp_PCA_index = _unique_run_index(["dec_complicated_function_hidden_dims"])
notimewarp_index = _unique_run_index(["scalar_timewarper_name"])

conv_index = _unique_run_index(["decoder_name", "dec_complicated_function_hidden_dims", "scalar_timewarper_name"])
no_noise_timewarp_PCA_index = _unique_run_index(["training_data_added_timing_noise", "dec_complicated_function_hidden_dims"])
no_timewarp_PCA_index = _unique_run_index(["scalar_timewarper_name", "dec_complicated_function_hidden_dims"])

# Beta sweeps: every run whose beta differs from the base value (plus the listed extra keys).
beta_inds = _matching_run_indices(["beta"])
no_tw_beta_inds = _matching_run_indices(["beta", "scalar_timewarper_name"])
conv_beta_inds = _matching_run_indices(["beta", "decoder_name", "dec_complicated_function_hidden_dims", "scalar_timewarper_name"])

dtw_index = _unique_run_index(["vector_timewarper_name", "scalar_timewarper_name"])
trans_index = _unique_run_index(["encoder_name", "scalar_timewarper_name"])

# 4 + 3 + 2 + 2 + 2
# NOTE(review): presumably the expected tally of distinct run configs; confirm against the printed count.
print(len(all_run_dicts))

In [None]:
import matplotlib

# Use 22-pt text in all figures created after this cell.
font = dict(size=22)
matplotlib.rc("font", **font)

In [None]:
# Ablation figures: the base run plus three runs that each change one base-config setting.
save_plots(
    "ablation",
    rateylim=115,
    toplot=[base_config_index, noise_off_index, timewarp_PCA_index, notimewarp_index],
    labels=["TimewarpVAE", "ndaug", "nnonlin", "NoTimewarp"],
    shapes=["o", "P", "X", "."],
)

In [None]:
# Base run vs. runs that change the decoder hidden dims together with the timewarper or the noise setting.
save_plots(
    "conv",
    rateylim=500,
    toplot=[base_config_index, no_timewarp_PCA_index, no_noise_timewarp_PCA_index],
    labels=["TimewarpVAE", "NoTWNoNonlinear", "NoNoiseNoNonlinear"],
)

In [None]:
# Beta sweeps with the timewarper, and without it (plus the no-timewarp run at the base beta).
save_plots("beta", rateylim=100, toplot=[base_config_index, *beta_inds])
save_plots("BetaNoTw", rateylim=100, toplot=[base_config_index, *no_tw_beta_inds, notimewarp_index])

In [None]:
def rate_distortion(namebase,rateylim,latent_dim,toplot,colors=None,shapes=None,data=None):
    """Scatter column 5 (x) against columns 1 and 2 (y) at a single latent dimension and save the figure.

    For arrays built by create_dataset these are train_KLD (rate) vs. train/test
    aligned RMSE (distortion). Saves ``{namebase}RateDistortion.pdf``.

    Parameters
    ----------
    namebase : str
        Output filename prefix.
    rateylim : float
        Upper limit of the x (rate) axis; despite the name it sets xlim.
    latent_dim : number
        Only rows whose column 0 equals this value are plotted.
    toplot : iterable of int
        Indices into `data` to plot.
    colors : sequence of 2 colors, optional
        Colors for training and test points; defaults to "C0" and "C1".
    shapes : sequence of 2 markers, optional
        Markers for training and test points; defaults to circles.
    data : list of 2-D arrays, optional
        Datasets to index into; defaults to the notebook-level `datasets`.
    """
    if data is None:
        data = datasets
    fsize = (8, 6)
    if colors is None:
        colors = [f"C{i}" for i in range(2)]
    if shapes is None:
        shapes = ["o", "o"]  # only two markers are used, so don't depend on len(data)

    f = plt.figure(figsize=fsize)
    ax = f.add_axes([0.1, 0.1, 0.8, 0.8])
    for position, i in enumerate(toplot):
        dataval = data[i]
        dataval = dataval[dataval[:, 0] == latent_dim]
        # Label only the first plotted dataset so each legend entry appears once. This is
        # keyed on loop position, not the dataset index, because toplot need not contain 0.
        first = position == 0
        ax.scatter(dataval[:, 5], dataval[:, 1], c=colors[0], label="Training" if first else None, marker=shapes[0])
        ax.scatter(dataval[:, 5], dataval[:, 2], c=colors[1], label="Test" if first else None, marker=shapes[1])
    ax.set_xlim(0, rateylim)
    ax.set_ylim(0, 1)
    ax.legend()
    f.savefig(f"{namebase}RateDistortion.pdf", bbox_inches="tight")
    

In [None]:
# Rate-distortion scatter at latent_dim 16 for the base run and its beta sweep (default colors/markers).
rate_distortion(
    "base",
    rateylim=100,
    latent_dim=16,
    toplot=[base_config_index, *beta_inds],
)

In [None]:
# These base rate-distortion results are consistent with the bottom right of Figure 3 in
# http://proceedings.mlr.press/v130/bozkurt21a/bozkurt21a.pdf, but they aren't very "exciting".

In [None]:
# Sort the conv-decoder runs and the TimewarpVAE runs by beta (ascending).
all_conv_inds = conv_beta_inds + [conv_index]
conv_betas = [all_run_dicts[ix]["beta"] for ix in all_conv_inds]
sortorder = np.argsort(conv_betas)
sorted_conv_inds = [all_conv_inds[s] for s in sortorder]
all_timewarpvae_indices = [base_config_index] + beta_inds
timewarpvae_betas = [all_run_dicts[ix]["beta"] for ix in all_timewarpvae_indices]
twvae_sortorder = np.argsort(timewarpvae_betas)
sorted_twvae_inds = [all_timewarpvae_indices[s] for s in twvae_sortorder]

# Plot only the two smallest betas of each family, plus the PCA baseline (datasets[-1]).
n_betas = 2
twvae_subset = sorted_twvae_inds[:n_betas]
conv_subset = sorted_conv_inds[:n_betas]
labels = ([f"TimewarpVAEbt{timewarpvae_betas[s]}" for s in twvae_sortorder[:n_betas]]
          + [f"Conv{conv_betas[s]}" for s in sortorder[:n_betas]])
# Size colors/shapes from the actual subsets so they stay aligned with toplot even if a
# family has fewer than n_betas runs (save_plots zips these lists together).
shape_cycle = ["o", "X"]
save_plots("BetaConvTwo", rateylim=500,
           toplot=twvae_subset + conv_subset + [-1],
           labels=labels + ["PCA"],
           colors=["C0"] * len(twvae_subset) + ["C1"] * len(conv_subset) + ["C2"],
           shapes=shape_cycle[:len(twvae_subset)] + shape_cycle[:len(conv_subset)] + ["."],
           labeltextcolor="white")

In [None]:
# Plot every beta of both families: TimewarpVAE in C0, conv decoder in C1.
# Reuses the beta-sorted indices (sorted_twvae_inds, sorted_conv_inds, twvae_sortorder,
# sortorder, timewarpvae_betas, conv_betas) computed in the previous cell.
labels = ([f"TimewarpVAEbt{timewarpvae_betas[s]}" for s in twvae_sortorder]
          + [f"Conv{conv_betas[s]}" for s in sortorder])
# Derive colors/markers from the group sizes so they stay aligned with toplot for any
# number of runs per family; markers cycle o, P, X within each family.
marker_cycle = ["o", "P", "X"]
twvae_shapes = [marker_cycle[k % len(marker_cycle)] for k in range(len(sorted_twvae_inds))]
conv_shapes = [marker_cycle[k % len(marker_cycle)] for k in range(len(sorted_conv_inds))]
save_plots("BetaConv", rateylim=500, toplot=sorted_twvae_inds + sorted_conv_inds,
           labels=labels,
           colors=["C0"] * len(sorted_twvae_inds) + ["C1"] * len(sorted_conv_inds),
           shapes=twvae_shapes + conv_shapes,
           labeltextcolor="white")

In [None]:
# Compare the base TimewarpVAE run with the run selected as dtw_index
# (both timewarper settings changed from the base config).
# NOTE(review): trans_index is computed earlier but not plotted anywhere in this notebook.
save_plots("TimewarpVAEDTW", rateylim=100, toplot=[base_config_index, dtw_index],
           labels=["TimewarpVAE", "beta-VAE DTW---sp"],
           colors=["C0", "C1"],
           shapes=["o", "P"],
           labeltextcolor="white")