In [None]:
import os
import torch
import sys
import yaml
from torch.utils.data import DataLoader
from pathlib import Path
import pandas as pd
import json
import numpy as np
from trade.datasets import GMM,DataSet2DGMM
from trade.models import set_up_sequence_INN_DoubleWell
import matplotlib.pyplot as plt
from functools import partial
from trade.plots import eval_pdf_on_grid_2D
# Run on the first GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

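Helper functions for loading trained models, evaluating the power-scaled target density, and building the validation data loaders.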

---

In [None]:
def load_INN(base_path:str,use_last:bool = False,device:str = "cuda:0"):
    """Build an INN from a Lightning run folder and load one of its checkpoints.

    Args:
        base_path: Run directory containing ``hparams.yaml`` and a ``checkpoints/``
            sub-folder.
        use_last: If True, load ``checkpoints/last.ckpt``. Otherwise load the first
            file whose name starts with ``checkpoint_epoch``, in lexicographic order.
            This notebook treats that file as the best-performing checkpoint.
        device: Device string written into the config before the model is built.

    Returns:
        Tuple ``(INN, config)``: the model switched to evaluation mode and the
        hyperparameter dictionary read from ``hparams.yaml``.

    Raises:
        FileNotFoundError: If no matching checkpoint file exists in ``checkpoints/``.
    """
    config_i = yaml.safe_load(Path(base_path + "/hparams.yaml").read_text())
    state_dict_folder_i = base_path + "/checkpoints/"

    # os.listdir returns entries in arbitrary order; sort so the same checkpoint
    # is selected on every run and on every machine.
    files = sorted(os.listdir(state_dict_folder_i))

    if use_last:
        # Use the last recorded state dict
        candidates = [f for f in files if f == "last.ckpt"]
    else:
        # Use the best performing state dict
        candidates = [f for f in files if f.startswith("checkpoint_epoch")]

    if not candidates:
        wanted = "last.ckpt" if use_last else "checkpoint_epoch*"
        raise FileNotFoundError(
            f"No checkpoint matching '{wanted}' found in {state_dict_folder_i}"
        )

    state_dict_path_i = os.path.join(state_dict_folder_i,candidates[0])

    config_i["device"] = device

    INN_i = set_up_sequence_INN_DoubleWell(config=config_i)
    # NOTE(review): a file path (not a dict) is passed here; presumably the project
    # model overrides load_state_dict to accept a checkpoint path — confirm.
    INN_i.load_state_dict(state_dict_path_i)
    INN_i.train(False)

    print(state_dict_path_i)

    return INN_i,config_i

In [None]:
def p_beta(x,beta,gmm,Z = None):
    """Evaluate the power-scaled density ``gmm(x) ** beta``, optionally divided by ``Z``.

    Args:
        x: Points at which ``gmm`` is evaluated.
        beta: Exponent applied to the density values (inverse temperature).
        gmm: Callable returning a tensor of density values for ``x``.
        Z: Optional normalisation constant. When given, the result is divided by it.

    Returns:
        ``gmm(x).pow(beta)`` if ``Z`` is None, otherwise that tensor divided by ``Z``.
    """
    unnormalised = gmm(x).pow(beta)
    return unnormalised if Z is None else unnormalised / Z

In [None]:
def get_validation_loader_dict_2D_GMM(T_list_eval,n_samples,base_path = "../data/2D_GMM/",batch_size = 1000,num_workers = 4):
    """Create one validation DataLoader per temperature for the 2D GMM data set.

    Args:
        T_list_eval: Iterable of temperatures to load validation data for.
        n_samples: Number of samples requested from ``DataSet2DGMM`` per temperature.
        base_path: Directory passed to ``DataSet2DGMM`` as ``base_path``.
        batch_size: Batch size of every returned DataLoader.
        num_workers: Number of DataLoader worker processes.

    Returns:
        Dict mapping ``f"{round(T, 5)}"`` to a non-shuffling DataLoader over the
        validation samples at that temperature.
    """
    validation_data_loader_dict = {}

    for T_i in T_list_eval:

        # Keys use the temperature rounded to 5 decimals so that lookups with the
        # rounded temperature strings used elsewhere in the notebook match.
        T_i = round(T_i,5)
        print(f"Loading validation data for T = {T_i}")

        DS_i = DataSet2DGMM(
            d = 2,
            mode = "validation",
            temperature_list=[T_i],
            base_path=base_path,
            n_samples=n_samples
            )

        DL_i = DataLoader(
            DS_i,
            batch_size = batch_size,
            shuffle = False,
            num_workers = num_workers
        )

        validation_data_loader_dict[f"{T_i}"] = DL_i

    return validation_data_loader_dict

Load the validation data

---

In [None]:
# Temperature grid: 20 log-spaced values in [0.1, 10] plus T = 1,
# sorted and rounded to 5 decimals (the rounding also defines the dict keys used below).
T_grid = torch.linspace(np.log(0.1),np.log(10),20).exp()
T_all = torch.cat((T_grid,torch.tensor([1.0])))
T_list = [round(t,5) for t in T_all.sort().values.tolist()]

# Evaluate only the central temperatures: drop the `10 - a` lowest and highest entries.
a = 7
n_trim = 10 - a
T_list_eval = T_list[n_trim:-n_trim]

validation_data_loader_dict = get_validation_loader_dict_2D_GMM(T_list_eval = T_list_eval,n_samples = 80000)


Set the paths to the trained models

---

In [None]:
home_dir = os.path.expanduser("~")

# Placeholders: replace each entry with the `lightning_logs/version_X` folder
# of the corresponding training run before executing the notebook.
base_path_TRADE_grid = "../results/runs_2D_GMM/Path to your trained model/lightning_logs/version_0"
base_path_TRADE_no_grid = "../results/runs_2D_GMM/Path to your trained model/lightning_logs/version_0"
base_path_reverse_KL = "../results/runs_2D_GMM/Path to your trained model/lightning_logs/version_0"
base_path_reverse_KL_nll = "../results/runs_2D_GMM/Path to your trained model/lightning_logs/version_0"
base_path_nll_only = "../results/runs_2D_GMM/Path to your trained model/lightning_logs/version_0"
base_path_reweighting = "../results/runs_2D_GMM/Path to your trained model/lightning_logs/version_0"
base_path_volume_preserving = "../results/runs_2D_GMM/Path to your trained model/lightning_logs/version_0"

# Insertion order fixes the row order of the result tables and density figures.
base_paths_dict = dict([
    ("TRADE_grid", base_path_TRADE_grid),
    ("TRADE_no_grid", base_path_TRADE_no_grid),
    ("nll_only", base_path_nll_only),
    ("reverse_KL", base_path_reverse_KL),
    ("reverse_KL_nll", base_path_reverse_KL_nll),
    ("reweighting", base_path_reweighting),
    ("volume preserving", base_path_volume_preserving),
])

Load the models

---

In [None]:
# Load the best-performing checkpoint (use_last=False) of every run, keyed by run name.
INN_dict = {}
config_dict = {}
for run_name, run_path in base_paths_dict.items():
    inn_model, inn_config = load_INN(base_path = run_path,device=device,use_last=False)
    INN_dict[run_name] = inn_model
    config_dict[run_name] = inn_config

Evaluate the validation KL divergence

---

Initialize the target distribution

In [None]:
# Component means of the 2D Gaussian mixture (one row per component)
means = torch.tensor([
    [-1.0,2.0],
    [3.0,7.0],
    [-4.0,2.0],
    [-2.0,-4.0],
    [0.0,4.0],
    [5.0,-2.0]
])

# Covariance matrices, one 2x2 matrix per component (same order as `means`)
S = torch.tensor([
    [[ 0.2778,  0.4797],
     [ 0.4797,  0.8615]],

    [[ 0.8958, -0.0249],
     [-0.0249,  0.1001]],

    [[ 1.3074,  0.9223],
     [ 0.9223,  0.7744]],

    [[ 0.0305,  0.0142],
     [ 0.0142,  0.4409]],

    [[ 0.0463,  0.0294],
     [ 0.0294,  0.3441]],

    [[ 0.15,  0.0294],
     [ 0.0294,  1.5]]])

gmm = GMM(means = means,covs=S,device=device)

# Approximated partition functions of the power-scaled target, keyed by the
# temperature string; Path.read_text opens and closes the file itself.
Z_T_dict = json.loads(Path("../data/2D_GMM/Z_T.json").read_text())

In [None]:
# Number of bootstrap resamples used to estimate the spread of each KLD estimate
n_bootstrap = 20
# Seeded generator so the reported bootstrap errors are reproducible across runs
bootstrap_seed = 0
bootstrap_rng = np.random.default_rng(bootstrap_seed)

# Results are keyed as dict[str(T)][model_name]
val_KLD_dicts = {f"{T_i}": {} for T_i in T_list_eval}
error_val_KLD_dicts = {f"{T_i}": {} for T_i in T_list_eval}

with torch.no_grad():
    for k in INN_dict:
        print("evaluate ",k)

        for T_i in T_list_eval:
            T_i = round(T_i,5)

            DL_i = validation_data_loader_dict[f"{T_i}"]

            # Collect per-batch results in lists and concatenate once at the end;
            # calling torch.cat inside the loop re-copies all previous batches.
            log_p_theta_batches = []
            log_p_gt_batches = []

            for beta_batch,x_batch in DL_i:
                x_batch = x_batch.to(device)

                # Model log-likelihood at the batch's inverse temperature
                log_p_theta_batches.append(
                    INN_dict[k].log_prob(x_batch,beta_tensor=beta_batch.to(device)).cpu()
                )

                # Ground-truth log-likelihood of the normalised power-scaled target
                log_p_gt_batches.append(
                    p_beta(x_batch,beta = 1 / T_i,gmm = gmm,Z = Z_T_dict[f"{T_i}"]).log().cpu()
                )

            if not log_p_theta_batches:
                raise ValueError(f"Validation loader for T = {T_i} yielded no batches")

            log_p_theta_val = torch.cat(log_p_theta_batches,0)
            log_p_gt_val = torch.cat(log_p_gt_batches,0)
            assert log_p_gt_val.shape == log_p_theta_val.shape

            # Per-sample log-ratio; its mean is reported as the KLD
            # (assumes validation samples are drawn from the ground-truth distribution — TODO confirm).
            log_ratio = log_p_gt_val - log_p_theta_val
            n_val = len(log_ratio)

            # Bootstrap: resample with replacement, then take the sample standard
            # deviation (ddof=1) of the resampled means as the error estimate.
            samples = np.empty(n_bootstrap)
            for b in range(n_bootstrap):
                indices = torch.from_numpy(bootstrap_rng.integers(0,n_val,n_val))
                samples[b] = log_ratio[indices].mean().item()
            error_val_KLD_dicts[f"{T_i}"][k] = float(np.std(samples,ddof = 1))

            # Point estimate on the full validation set
            val_KLD_dicts[f"{T_i}"][k] = log_ratio.mean().item()

In [None]:
def highlight_min(s):
    """Return one CSS string per entry of ``s`` that bolds its minimum value(s) (ties included)."""
    return ['font-weight: bold;' if is_min else '' for is_min in (s == s.min())]

# One row per model and one column per temperature; styling is applied column-wise,
# so the lowest KLD at each temperature is shown in bold.
df = pd.DataFrame(val_KLD_dicts)
df.style.apply(highlight_min, subset=pd.IndexSlice[:, :])

In [None]:
# Temperatures shown in each LaTeX sub-table (one inner list per sub-table)
T_print_list = [[0.20691,0.33598,0.54556],[0.88587,1.0,1.12884],[1.83298,2.97635,4.83293]]

row_name_dict = {
    "nll_only":"NLL + lat. TS",
    "TRADE_grid":"TRADE (grid)",
    "TRADE_no_grid":"TRADE (no grid)",
    "reverse_KL":"Reverse KLD",
    "reverse_KL_nll":"NLL + Reverse KLD",
    "reweighting":"Reweighting",
    "volume preserving":"Volume Preserving",
    "gt":"Ground Truth"
}

highlight_color = "lightgray"
rows_to_highlight = ["TRADE_grid","TRADE_no_grid"]


def _rounding_digits(error, n_significant = 3, fallback = 4):
    """Return the number of decimals that keeps ``n_significant`` significant digits of ``error``.

    The result may be negative for large errors; Python's ``round`` accepts that.
    Returns ``fallback`` when the error is zero or not finite, because log10 is undefined there.
    """
    error = abs(float(error))
    if error == 0.0 or not np.isfinite(error):
        return fallback
    return (n_significant - 1) - int(np.floor(np.log10(error)))


def _format_kld_cell(mean, error, bold):
    """Format one ``&mean$\\pm$error`` LaTeX cell, rounded according to the error's magnitude."""
    digits = _rounding_digits(error)
    cell = f"{round(mean,digits)}$\\pm${round(error,digits)}"
    return "&\\textbf{" + cell + "}" if bold else "&" + cell


def _latex_kld_table(temperatures):
    """Build one LaTeX ``tabularx`` with one KLD column per temperature and one row per model.

    Reads the notebook-level ``base_paths_dict``, ``row_name_dict``, ``rows_to_highlight``,
    ``highlight_color``, ``val_KLD_dicts``, ``error_val_KLD_dicts`` and ``is_best_dict``.
    """
    parts = [r"\begin{tabularx}{\textwidth}{|c|"]
    parts += [r">{\centering\arraybackslash}X|"] * len(temperatures)
    parts.append("}\n\\hline\n")

    # Column names
    parts += [f"&KLD $T = {T_i}\\downarrow$" for T_i in temperatures]
    parts.append("\\\\\n\\hline\n")

    for k in base_paths_dict.keys():
        if k in rows_to_highlight:
            parts.append("\\rowcolor{" + highlight_color + "}")
        parts.append(f"{row_name_dict[k]}")

        for T_i in temperatures:
            key = f"{T_i}"
            parts.append(_format_kld_cell(
                val_KLD_dicts[key][k],
                error_val_KLD_dicts[key][k],
                is_best_dict[key][k],
            ))
        parts.append("\\\\\n")

    parts.append("\\hline\n")
    parts.append(r"\end{tabularx}")
    return "".join(parts)


# Mark the model with the lowest KLD at every printed temperature
# (min() keeps the first model on ties, matching a strict '<' scan).
is_best_dict = {}
for c,T_print in enumerate(T_print_list):
    for T_i in T_print:
        best_k = min(INN_dict, key = lambda m: val_KLD_dicts[f"{T_i}"][m])
        is_best_dict[f"{T_i}"] = {k: k == best_k for k in INN_dict}

    table_str = _latex_kld_table(T_print)
    print(f"%subtable {c+1}")
    print(table_str)

    print("")

Plot the densities

---

Best-performing checkpoints (`checkpoint_epoch*` files)

In [None]:
# Settings shared by both density figures below
T_list_plotting = [0.20691,0.54556,1.0,1.83298,4.83293]  # temperatures shown as figure columns
lim_list_grid = [[-9,9],[-9,9]]  # [x limits, y limits] of the evaluation grid
res_list_grid = [500,500]        # [x resolution, y resolution] of the evaluation grid
cmap = "jet"                     # colormap passed to imshow
fs = 35                          # font size for titles, row labels and panel tags

In [None]:
# Reload the best-performing checkpoints (use_last=False) of every run for plotting.
# NOTE: despite its name, `INN_dict_last_cp` holds the *best* checkpoints here;
# the name is kept because the plotting cell below reads it.
INN_dict_last_cp = {}

for key in base_paths_dict:

    # Pass the notebook's device explicitly: load_INN defaults to "cuda:0",
    # which is not available on CPU-only machines.
    INN_last_i,_ = load_INN(base_path = base_paths_dict[key],use_last = False,device = device)

    INN_dict_last_cp[key] = INN_last_i

In [None]:
# One row for the ground truth plus one per model, one column per temperature.
# squeeze=False keeps `axes` 2-D even when there is a single row or column.
n_rows = len(base_paths_dict) + 1
n_cols = len(T_list_plotting)
fig,axes = plt.subplots(n_rows,n_cols,figsize = (n_cols * 5,n_rows * 5),squeeze = False)

# Grid settings shared by every density evaluation
grid_kwargs = dict(
    x_lims = lim_list_grid[0],
    y_lims = lim_list_grid[1],
    x_res = res_list_grid[0],
    y_res = res_list_grid[1],
    device = device
)

# Panel tags ("A1", "B2", ...) go near the top-left corner of the plotted region
tag_x = lim_list_grid[0][0] + 0.5
tag_y = lim_list_grid[1][1] - 2.0

with torch.no_grad():
    for i,T_i in enumerate(T_list_plotting):

        # Ground-truth (normalised, power-scaled) density
        p = partial(p_beta,gmm = gmm,beta = 1 / T_i, Z = Z_T_dict[f"{T_i}"])
        pdf_grid,x_grid,y_grid = eval_pdf_on_grid_2D(pdf=p,**grid_kwargs)
        grid_dict_i = {"gt":pdf_grid.detach().cpu()}

        # Model densities: log_prob is evaluated on the grid, then exponentiated
        for k in base_paths_dict:
            p_k = partial(INN_dict_last_cp[k].log_prob,beta_tensor = 1 / T_i)
            log_pdf_grid_k,x_grid,y_grid = eval_pdf_on_grid_2D(pdf=p_k,**grid_kwargs)
            grid_dict_i[k] = log_pdf_grid_k.detach().cpu().exp()

        extent = [
            x_grid.min().item(),
            x_grid.max().item(),
            y_grid.min().item(),
            y_grid.max().item()
        ]

        axes[0][i].set_title(f"c = {round(1 / T_i,4)}",fontsize = fs)
        for j,(k,grid_k) in enumerate(grid_dict_i.items()):
            ax = axes[j][i]

            # Each panel uses its own colour scale (no shared vmin/vmax)
            ax.imshow(grid_k,extent = extent,origin = 'lower',cmap = cmap)

            # Remove ticks and tick labels on both axes
            ax.set(xticklabels=[],yticklabels=[])
            ax.tick_params(left=False,bottom=False)

            if i == 0:
                ax.set_ylabel(row_name_dict[k],fontsize = fs)

            # Label: letter = row, number = column
            ax.text(tag_x,tag_y, f"{chr(ord('A') + j)}{i+1}", fontsize = fs,c = "w")

# Lay out once, after every panel has been drawn
fig.tight_layout()
fig.savefig("densities_2D_GMM_best.pdf")
plt.close(fig)

Models at the end of training (`last.ckpt`)

In [None]:
# Reload the final checkpoint (`last.ckpt`, use_last=True) of every run for plotting.
INN_dict_last_cp = {}

for key in base_paths_dict:

    # Pass the notebook's device explicitly: load_INN defaults to "cuda:0",
    # which is not available on CPU-only machines.
    INN_last_i,_ = load_INN(base_path = base_paths_dict[key],use_last = True,device = device)

    INN_dict_last_cp[key] = INN_last_i

In [None]:
# One row for the ground truth plus one per model, one column per temperature.
# squeeze=False keeps `axes` 2-D even when there is a single row or column.
n_rows = len(base_paths_dict) + 1
n_cols = len(T_list_plotting)
fig,axes = plt.subplots(n_rows,n_cols,figsize = (n_cols * 5,n_rows * 5),squeeze = False)

# Grid settings shared by every density evaluation
grid_kwargs = dict(
    x_lims = lim_list_grid[0],
    y_lims = lim_list_grid[1],
    x_res = res_list_grid[0],
    y_res = res_list_grid[1],
    device = device
)

# Panel tags ("A1", "B2", ...) go near the top-left corner of the plotted region
tag_x = lim_list_grid[0][0] + 0.5
tag_y = lim_list_grid[1][1] - 2.0

with torch.no_grad():
    for i,T_i in enumerate(T_list_plotting):

        # Ground-truth (normalised, power-scaled) density
        p = partial(p_beta,gmm = gmm,beta = 1 / T_i, Z = Z_T_dict[f"{T_i}"])
        pdf_grid,x_grid,y_grid = eval_pdf_on_grid_2D(pdf=p,**grid_kwargs)
        grid_dict_i = {"gt":pdf_grid.detach().cpu()}

        # Model densities: log_prob is evaluated on the grid, then exponentiated
        for k in base_paths_dict:
            p_k = partial(INN_dict_last_cp[k].log_prob,beta_tensor = 1 / T_i)
            log_pdf_grid_k,x_grid,y_grid = eval_pdf_on_grid_2D(pdf=p_k,**grid_kwargs)
            grid_dict_i[k] = log_pdf_grid_k.detach().cpu().exp()

        extent = [
            x_grid.min().item(),
            x_grid.max().item(),
            y_grid.min().item(),
            y_grid.max().item()
        ]

        axes[0][i].set_title(f"c = {round(1 / T_i,4)}",fontsize = fs)
        for j,(k,grid_k) in enumerate(grid_dict_i.items()):
            ax = axes[j][i]

            # Each panel uses its own colour scale (no shared vmin/vmax)
            ax.imshow(grid_k,extent = extent,origin = 'lower',cmap = cmap)

            # Remove ticks and tick labels on both axes
            ax.set(xticklabels=[],yticklabels=[])
            ax.tick_params(left=False,bottom=False)

            if i == 0:
                ax.set_ylabel(row_name_dict[k],fontsize = fs)

            # Label: letter = row, number = column
            ax.text(tag_x,tag_y, f"{chr(ord('A') + j)}{i+1}", fontsize = fs,c = "w")

# Lay out once, after every panel has been drawn
fig.tight_layout()
fig.savefig("densities_2D_GMM_final.pdf")
plt.close(fig)