In [None]:
import os
from tqdm import tqdm
from utils import pltManager
import matplotlib.pyplot as plt
from args_settings import select_args_specifications
from load_and_test_models import visualize_args_latent_space

In [None]:
# Display labels for each dataset key (the value that follows "--dataset" in a spec);
# visualize() passes the label on as the `dataset` argument of visualize_args_latent_space.
translation = dict(
    yahoo="Yahoo",
    yelp="Yelp",
    snli="SNLI",
    short_yelp="Short-Yelp",
)

In [None]:
def visualize(model_names, args_specifications, save_name=None, num_datasets=4):
    """Plot the latent spaces of several models on several datasets into one PNG grid.

    ``args_specifications`` must be in dataset-major order: all settings for the
    first dataset, then all settings for the second, and so on. ``model_names``
    (one label per setting) is tiled ``num_datasets`` times so that each spec is
    paired with its label. The figure has one row per spec and one column per
    (axis pair, mode) combination. Each ``visualize_args_latent_space`` call
    presumably fills one row.

    A lock file marks a figure as "in progress". A figure whose PNG already exists,
    or whose lock file is present, is skipped. This lets several runs of this
    notebook execute concurrently without redoing each other's work.

    Args:
        model_names: Display label for each model setting.
        args_specifications: CLI-style argument lists, exactly
            ``len(model_names) * num_datasets`` of them. Element ``[1]`` is the
            dataset key (the value following "--dataset") and must be a key of
            ``translation``.
        save_name: Output file stem; defaults to ``model_names[0]``.
        num_datasets: Number of datasets each setting was run on (default 4).

    Raises:
        AssertionError: if the number of specs is not ``len(model_names) * num_datasets``.
    """
    save_path = f"{save_name if save_name else model_names[0]}.png"
    lock_path = f"lock-{save_path}.txt"
    # Already rendered, or another run is rendering it right now.
    if os.path.exists(save_path) or os.path.exists(lock_path):
        return

    # Validate before taking the lock so a bad call cannot leave a stale lock behind.
    assert len(model_names) * num_datasets == len(args_specifications)

    # Exclusive create ("x") fails if the lock appeared since the check above,
    # so two concurrent runs can never both believe they own it.
    try:
        with open(lock_path, "x") as f:
            f.write("")
    except FileExistsError:
        return

    try:
        # Repeat the labels once per dataset to line up with the dataset-major specs.
        tiled_names = model_names * num_datasets

        # Six axis pairs (presumably latent-dimension pairs; interpreted by
        # visualize_args_latent_space) times two modes -> 12 panels per row.
        select_axis = [[0, 1], [6, 7], [12, 13], [18, 19], [24, 25], [30, 31]]
        mode = ["aggregated", "center"]
        columns = len(select_axis) * len(mode)
        lines = len(args_specifications)
        pltm = pltManager(plt, columns, lines)

        total = len(tiled_names)
        pairs = zip(args_specifications, tiled_names)
        for processed, (args_specification, model_name) in enumerate(tqdm(pairs, total=total)):
            # Record progress in the lock file so an observer can see how far this run got.
            with open(lock_path, "w") as f:
                f.write(f"{processed}/{total}")
            visualize_args_latent_space(
                args_specification, pltm, select_axis, mode, verbose=0,
                model_name=model_name, dataset=translation[args_specification[1]],
            )

        pltm.plt.savefig(save_path, dpi=300)
    finally:
        # Release the lock even on errors or interrupts; otherwise every later run
        # would silently skip this figure forever.
        os.remove(lock_path)

In [None]:
#    VAE (default)   & -330.7 & -330.7 & 0.0 & 0.0 & 0 & 32 \\
# Plain Gaussian-LSTM VAE with no extra flags: one spec per dataset, in the
# dataset-major order that visualize() expects.
model_names = ["VAE (default)"]
args_specifications = [
    ["--dataset", dataset, "--encoder_class", "GaussianLSTMEncoder"]
    for dataset in ("yahoo", "yelp", "snli", "short_yelp")
]
visualize(model_names, args_specifications, "VAE (default)")

In [None]:
#    cyclic-VAE      & -329.8 & -328.9 & 1.1 & 1.0 & 2 & 31 \\
# Gaussian-LSTM VAE trained with the "--cycle 20" option; one spec per dataset,
# in dataset-major order.
model_names = ["cyclic-VAE"]
args_specifications = [
    ["--dataset", dataset, "--encoder_class", "GaussianLSTMEncoder", "--cycle", "20"]
    for dataset in ("yahoo", "yelp", "snli", "short_yelp")
]
visualize(model_names, args_specifications, "cyclic-VAE")

In [None]:
#    bow-VAE         & -330.5 & -330.5 & 0.0 & 0.0 & 0 & 32 \\
# Gaussian-LSTM VAE with the "--add_bow" flag; one spec per dataset, in dataset-major order.
model_names = ["bow-VAE"]
args_specifications = [
    ["--dataset", dataset, "--encoder_class", "GaussianLSTMEncoder", "--add_bow"]
    for dataset in ("yahoo", "yelp", "snli", "short_yelp")
]
visualize(model_names, args_specifications, "bow-VAE")

In [None]:
#    skip-VAE        & -330.1 & -325.2 & 5.0 & 4.3 & 8 & 31 \\
# Gaussian-LSTM VAE with the "--add_skip" flag; one spec per dataset, in dataset-major order.
model_names = ["skip-VAE"]
args_specifications = [
    ["--dataset", dataset, "--encoder_class", "GaussianLSTMEncoder", "--add_skip"]
    for dataset in ("yahoo", "yelp", "snli", "short_yelp")
]
visualize(model_names, args_specifications, "skip-VAE")

In [None]:
#    $\delta$-VAE(0.15)  & -330.5 & -330.6 & 4.8 & 0.0 & 0 & 0 \\
# delta-VAE via its dedicated encoder class (no extra CLI flags); one spec per
# dataset, in dataset-major order. The LaTeX label is kept as a raw string.
model_names = [r"$\delta$-VAE(0.15)"]
args_specifications = [
    ["--dataset", dataset, "--encoder_class", "DeltaGaussianLSTMEncoder"]
    for dataset in ("yahoo", "yelp", "snli", "short_yelp")
]
visualize(model_names, args_specifications, "delta-VAE")

In [None]:
#    BN-VAE(0.6)     & -327.6 & -321.1 & 6.6 & 5.9 & 32 & 32 \\
#    %BN-VAE(0.7)     & -326.8 & -318.4 & 9.1 & 7.5 & 32 & 32 \\
#    BN-VAE(0.9)     & -327.0 & -313.8 & 15.5 & 9.0 & 32 & 32 \\
#    %BN-VAE(1.2)     & -330.9 & -310.1 & 26.2 & 9.2 & 32 & 0 \\
#    %BN-VAE(1.5)     & -337.8 & -310.2 & 37.6 & 9.2 & 32 & 0 \\
#    BN-VAE(1.8)     & -343.5 & -308.6 & 51.3 & 9.2 & 32 & 0 \\
# Sweep over --gamma. One list drives both the labels and the CLI args, so the two
# cannot drift apart. Specs are dataset-major (all gammas for a dataset, then the
# next dataset), matching how visualize() tiles model_names.
bn_gammas = [0.6, 0.7, 0.9, 1.2, 1.5, 1.8]
model_names = [f"BN-VAE({gamma:.1f})" for gamma in bn_gammas]
args_specifications = [
    [
        "--dataset", dataset,
        "--encoder_class", "BNGaussianLSTMEncoder",
        "--gamma", str(gamma),
    ]
    for dataset in ["yahoo", "yelp", "snli", "short_yelp"]
    for gamma in bn_gammas
]
visualize(model_names, args_specifications, "BN-VAEs")

In [None]:
#    FB-VAE(4)       & -329.8 & -328.4 & 3.9 & 1.8 & 32 & 32 \\
#    %FB-VAE(9)       & -327.8 & -326.3 & 8.8 & 4.1 & 32 & 12 \\
#    FB-VAE(16)      & -325.7 & -320.8 & 16.1 & 8.5 & 32 & 8 \\
#    %FB-VAE(25)      & -333.4 & -316.2 & 25.8 & 9.2 & 32 & 0 \\
#    %FB-VAE(36)      & -341.3 & -307.0 & 36.9 & 9.2 & 32 & 0 \\
#    FB-VAE(49)      & -344.6 & -296.1 & 50.0 & 9.2 & 32 & 0 \\
# Sweep over --target_kl. One list drives both the labels and the CLI args, so the
# two cannot drift apart. Specs are dataset-major, as visualize() expects.
fb_target_kls = [4, 9, 16, 25, 36, 49]
model_names = [f"FB-VAE({target_kl:d})" for target_kl in fb_target_kls]
args_specifications = [
    [
        "--dataset", dataset,
        "--encoder_class", "FineFBGaussianLSTMEncoder",
        "--target_kl", str(target_kl),
    ]
    for dataset in ["yahoo", "yelp", "snli", "short_yelp"]
    for target_kl in fb_target_kls
]
visualize(model_names, args_specifications, "FB-VAEs")

In [None]:
#    %$\beta$-VAE(0.8)    & -330.1 & -328.5 & 2.0 & 1.9 & 2 & 30 \\
#    $\beta$-VAE(0.4)    & -330.8 & -324.8 & 7.0 & 6.7 & 3 & 31 \\
#    $\beta$-VAE(0.2)    & -338.6 & -310.3 & 30.1 & 9.2 & 22 & 25 \\
#    $\beta$-VAE(0.1)    & -369.9 & -289.6 & 83.7 & 9.2 & 32 & 0 \\
#    %$\beta$-VAE(0.0)    & -445.2 & -280.3 & 178.8 & 9.2 & 32 & 0 \\
# Sweep over --kl_beta (1.0 is included as the reference point). One list drives
# both the labels and the CLI args, so the two cannot drift apart. Specs are
# dataset-major, as visualize() expects.
kl_betas = [1.0, 0.8, 0.4, 0.2, 0.1, 0.0]
model_names = [r"$\beta$-VAE({})".format(beta) for beta in kl_betas]
args_specifications = [
    [
        "--dataset", dataset,
        "--encoder_class", "GaussianLSTMEncoder",
        "--kl_beta", str(kl_beta),
    ]
    for dataset in ["yahoo", "yelp", "snli", "short_yelp"]
    for kl_beta in kl_betas
]
visualize(model_names, args_specifications, "Beta-VAEs")

In [None]:
#    DG-VAE ($|b|=1$)   & -330.7 & -330.7 & 0.0 & 0.0 & 0 & 32 \\
#    %DG-VAE ($|b|=2$)   & -330.1 & -326.4 & 4.1 & 4.0 & 4 & 32 \\
#    DG-VAE ($|b|=4$)   & -330.4 & -318.3 & 14.3 & 9.1 & 11 & 32 \\
#    %DG-VAE ($|b|=8$)   & -338.2 & -308.2 & 32.1 & 9.1 & 30 & 32 \\
#    %DG-VAE ($|b|=16$)  & -349.5 & -295.0 & 57.6 & 9.1 & 32 & 32 \\
#    DG-VAE ($|b|=32$)  & -355.4 & -294.1 & 65.2 & 9.1 & 32 & 32 \\
# Sweep over --agg_size. One list drives both the labels and the CLI args, so the
# two cannot drift apart. Specs are dataset-major, as visualize() expects.
dg_agg_sizes = [1, 2, 4, 8, 16, 32]
model_names = [r"DG-VAE ($|b|={}$)".format(agg_size) for agg_size in dg_agg_sizes]
args_specifications = [
    [
        "--dataset", dataset,
        "--encoder_class", "DGGaussianLSTMEncoder",
        "--agg_size", str(agg_size),
    ]
    for dataset in ["yahoo", "yelp", "snli", "short_yelp"]
    for agg_size in dg_agg_sizes
]
visualize(model_names, args_specifications, "DG-VAEs")

In [None]:
if __name__ == "__main__":
    save_path = "collapse_and_hole_5.png"
    lock_path = f"lock-{save_path}"
    # Skip if the figure already exists or another run holds the lock for it.
    if not os.path.exists(save_path) and not os.path.exists(lock_path):
        with open(lock_path, "w") as f:
            f.write("")
        try:
            args_specifications = [
                [
                    "--dataset", str(dataset),
                    "--encoder_class", str(encoder_class),
                    str(specific_1), str(specific_2)
                ] for dataset in [
                    "yahoo",
                ] for encoder_class, specific_1, specific_2 in [
                    # NOTE(review): entries without an extra option still append two
                    # empty-string arguments; confirm visualize_args_latent_space accepts them.
                    ("GaussianLSTMEncoder", "", ""),
                    ("BNGaussianLSTMEncoder", "--gamma", "1.2"),
                    ("GaussianLSTMEncoder", "--kl_beta", "0.4"),
                    ("FineFBGaussianLSTMEncoder", "--target_kl", "36"),
                    ("DGGaussianLSTMEncoder", "", "")
                ]
            ]
            model_names = ["VAE", "BN-VAE(1.2)", r"$\beta$-VAE(0.4)",  "FB-VAE(36)","DG-VAE"]

            select_datasets = range(1)
            select_models = range(len(model_names))
            select_axis = ["max_var"]
            mode = ["aggregated", "center"]
            # Grid: one column per (axis selector, model), one row per (dataset, mode).
            columns, lines = len(select_axis)*len(select_models), len(select_datasets)*len(mode)
            pltm = pltManager(plt, columns, lines)

            num_models = len(select_models)
            args_specifications = [args_specifications[i+num_models*j] for j in select_datasets for i in select_models]
            total = len(model_names)*len(select_datasets)
            # Draw every model in one mode before moving on to the next mode.
            # Presumably this makes each mode fill its own row of the grid.
            for single_mode in mode:
                for args_specification, model_name in tqdm(zip(args_specifications, model_names), total=total):
                    visualize_args_latent_space(args_specification, pltm, select_axis, [single_mode], verbose=0, model_name=model_name)

            pltm.plt.savefig(save_path, dpi=300)
        finally:
            # Release the lock even if plotting fails, so a later run can retry
            # instead of skipping this figure forever.
            os.remove(lock_path)

In [None]:
if __name__ == "__main__":
    save_path = "gaussian_models_max.png"
    lock_path = f"lock-{save_path}"
    # Skip if the figure already exists or another run holds the lock for it.
    if not os.path.exists(save_path) and not os.path.exists(lock_path):
        with open(lock_path, "w") as f:
            f.write("")
        try:
            # One hyper-parameter sweep per model family. These lists drive both the
            # CLI args and the labels, so the two cannot drift apart.
            bn_gammas = [0.6, 0.7, 0.9, 1.2, 1.5, 1.8]
            fb_target_kls = [4, 9, 16, 25, 36, 49]
            kl_betas = [1.0, 0.8, 0.4, 0.2, 0.1, 0.0]
            dg_agg_sizes = [1, 2, 4, 8, 16, 32]
            sweeps = [bn_gammas, fb_target_kls, kl_betas, dg_agg_sizes]
            # The grid layout below assumes every family has the same number of settings.
            sweep_length = len(bn_gammas)
            assert all(len(values) == sweep_length for values in sweeps)

            args_specifications = [
                [
                    "--dataset", str(dataset),
                    "--encoder_class", str(encoder_class),
                    str(specific_1), str(specific_2)
                ] for dataset in [
                    "yahoo",
                ] for encoder_class, specific_1, specific_2 in ([
                    ("BNGaussianLSTMEncoder", "--gamma", str(gamma)) for gamma in bn_gammas
                ] + [
                    ("FineFBGaussianLSTMEncoder", "--target_kl", str(target_kl)) for target_kl in fb_target_kls
                ] + [
                    ("GaussianLSTMEncoder", "--kl_beta", str(beta)) for beta in kl_betas
                ] + [
                    ("DGGaussianLSTMEncoder", "--agg_size", str(agg_size)) for agg_size in dg_agg_sizes
                ])
            ]
            model_names = [
                f"BN-VAE({gamma:.1f})" for gamma in bn_gammas
            ] + [
                f"FB-VAE({target_kl:d})" for target_kl in fb_target_kls
            ] + [
                r"$\beta$-VAE({})".format(beta) for beta in kl_betas
            ] + [
                r"DG-VAE ($|b|={}$)".format(agg_size) for agg_size in dg_agg_sizes
            ]
            print(model_names)

            select_datasets = range(1)
            select_models = range(len(model_names))
            select_axis = ["max_var"]
            mode = ["aggregated", "center"]
            # Grid: one row per (dataset, model family); one column per
            # (axis selector, mode, setting within the family).
            columns = len(select_axis) * len(mode) * sweep_length
            lines = len(select_datasets) * len(sweeps)
            pltm = pltManager(plt, columns, lines)

            num_models = len(select_models)
            args_specifications = [args_specifications[i+num_models*j] for j in select_datasets for i in select_models]
            for args_specification, model_name in tqdm(zip(args_specifications, model_names), total=len(model_names)):
                visualize_args_latent_space(args_specification, pltm, select_axis, mode, verbose=0, model_name=model_name)

            pltm.plt.savefig(save_path, dpi=300)
        finally:
            # Release the lock even if plotting fails, so a later run can retry
            # instead of skipping this figure forever.
            os.remove(lock_path)

In [None]:
if __name__ == "__main__":
    save_path = "collapse_and_hole.png"
    lock_path = f"lock-{save_path}"
    # Skip if the figure already exists or another run holds the lock for it.
    if not os.path.exists(save_path) and not os.path.exists(lock_path):
        with open(lock_path, "w") as f:
            f.write("")
        try:
            args_specifications = [
                [
                    "--dataset", str(dataset),
                    "--encoder_class", str(encoder_class),
                    str(specific_1), str(specific_2)
                ] for dataset in [
                    "yahoo","yelp"
                ] for encoder_class, specific_1, specific_2 in [
                    ("GaussianLSTMEncoder", "", ""),
                    ("GaussianLSTMEncoder", "--kl_beta", "0.0"),
                    ("DGGaussianLSTMEncoder", "", "")
                ]
            ]
            # "AE" is the GaussianLSTMEncoder run with --kl_beta 0.0.
            model_names = ["VAE", "AE", "DG-VAE"]

            # Only the first dataset (yahoo) is plotted; the yelp specs are built but unused.
            select_datasets = range(1)
            select_models = range(len(model_names))
            select_axis = ["mid_var"]
            mode = ["aggregated", "center"]
            # Grid: a single row per dataset; one column per (axis selector, mode, model).
            columns, lines = len(select_axis)*len(mode)*len(select_models), len(select_datasets)
            pltm = pltManager(plt, columns, lines)

            num_models = len(select_models)
            args_specifications = [args_specifications[i+num_models*j] for j in select_datasets for i in select_models]
            for args_specification, model_name in tqdm(zip(args_specifications, model_names), total=len(model_names)*len(select_datasets)):
                visualize_args_latent_space(args_specification, pltm, select_axis, mode, verbose=0, model_name=model_name)

            pltm.plt.savefig(save_path, dpi=300)
        finally:
            # Release the lock even if plotting fails, so a later run can retry
            # instead of skipping this figure forever.
            os.remove(lock_path)