In [None]:
# Change the working directory to the parent folder so that the relative paths
# used in later cells ("lightning_logs/", "data/...") resolve from there.
# NOTE(review): presumably the notebook lives one level below the project root — confirm.
# This cell is not idempotent: running it again moves up one more directory.
%cd ..

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from trade import BoltzmannGeneratorHParams, BoltzmannGenerator
import torch
import numpy as np
from math import ceil, floor
from yaml import safe_load
import os
from functools import partial
import mdtraj as md
from matplotlib.colors import LogNorm
import matplotlib
from trade.data import get_loader
import pandas as pd
from tqdm.auto import trange

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
@torch.no_grad()
def ESS_from_log_weights(log_omega, clip_weights=False):
    """Relative effective sample size of a set of log importance weights.

    Computes ``(sum_i w_i)^2 / (N * sum_i w_i^2)`` with ``w_i = exp(log_omega_i)``,
    where the sums run over dimension 0 and ``N = len(log_omega)``. Both sums
    are evaluated with ``logsumexp`` so the weights are never exponentiated
    individually.

    Args:
        log_omega: Tensor of (unnormalised) log weights; reduced over dim 0.
        clip_weights: Accepted for interface compatibility; not used by the
            computation.

    Returns:
        Tensor holding the relative ESS (1.0 when all weights are equal).
    """
    n_weights = len(log_omega)
    # log of the squared sum of weights
    log_squared_sum = 2 * torch.logsumexp(log_omega, 0)
    # log of the sum of squared weights
    log_sum_of_squares = torch.logsumexp(2 * log_omega, 0)
    return torch.exp(log_squared_sum - log_sum_of_squares) / n_weights

In [None]:
# Folder that contains one `version_<i>` sub-folder per trained model.
runs_folder = "lightning_logs/"

# Half-open (start, stop) range of `version_<i>` run indices to evaluate,
# e.g. (0, 5) for version_0 ... version_4.
# Leave as None to evaluate every `version_<i>` folder found in `runs_folder`.
version_range = None

if version_range is not None:
    version_ids = list(range(*version_range))
elif os.path.isdir(runs_folder):
    # Sort numerically so that version_10 comes after version_9.
    version_ids = sorted(
        int(entry[len("version_"):])
        for entry in os.listdir(runs_folder)
        if entry.startswith("version_") and entry[len("version_"):].isdigit()
    )
else:
    # No runs folder: nothing to evaluate (the loop over `full_names` is a no-op).
    version_ids = []

full_names = [os.path.join(runs_folder, f"version_{i}") for i in version_ids]  # Paths to your models

In [None]:
# Number of prior samples drawn per temperature for the ESS estimate.
n_samples_ESS = 30000

# "coordinates" array of the stored data set, moved to the compute device;
# it is evaluated with parameter=0.5 in the evaluation loop.
data_low_T = torch.from_numpy(
    np.load("data/multi_well_5d_0.5.npz")["coordinates"]
).to(device)

In [None]:
# Evaluate every model in `full_names` at parameter 0.5 and 1.0 and collect one
# row per model: its training settings plus the mean flow energy (reported as
# NLL) and the relative effective sample size (ESS) at both parameter values.
result_columns = ["version",
                  "loss_mode",
                  "sample_origin",
                  "energy_origin",
                  "reweight_samples",
                  "temperature_weighted_loss",
                  "backward_kl",
                  "no_causality_weights",
                  "NLL T=0.5",
                  "NLL T=1.0",
                  "ESS T=0.5",
                  "ESS T=1.0"]

low_T = 0.5      # loop-invariant, so it is defined once up front
row_index = []   # positions in `full_names` of the models that were evaluated
rows = []        # one list of column values per evaluated model

with torch.no_grad():
    for i, model_folder in enumerate(full_names):
        # Not read below: the hyperparameters are taken from the checkpoint itself.
        hparams_path = os.path.join(model_folder, "hparams.yaml")
        checkpoint_path = os.path.join(model_folder, "checkpoints/last.ckpt")

        # map_location lets checkpoints that were saved on a GPU load on a CPU-only machine.
        # NOTE(review): the checkpoint stores pickled hyperparameter objects, so only load
        # checkpoints you trust. Recent PyTorch releases default to weights_only=True and
        # may require weights_only=False here — confirm against the installed version.
        ckpt = torch.load(checkpoint_path, map_location=device)
        hparams = ckpt["hyper_parameters"]
        # These two entries are removed before the model is constructed;
        # pop() tolerates checkpoints that do not contain them.
        hparams.pop("n_steps", None)
        hparams.pop("epoch_len", None)

        # The loss mode is derived from the presence of "n_points_param_grid"
        # and written back into the loss configuration before building the model.
        pinf_loss = hparams["parameter_pinf_loss"]
        loss_mode = None
        if pinf_loss is not None:
            if "n_points_param_grid" in pinf_loss.additional_kwargs.keys():
                loss_mode = "grid"
            else:
                loss_mode = "continuous"
            pinf_loss.additional_kwargs["mode"] = loss_mode

        model = BoltzmannGenerator(hparams)
        model.load_state_dict(ckpt["state_dict"])
        model.eval().to(device)

        try:
            model_nll_low_T = model.flow.energy(data_low_T, c=[], parameter=low_T).mean().item()
            model_nll_high_T = model.flow.energy(model.val_data[:][0].to(device), c=[], parameter=1.0).mean().item()

            log_weights_low_T = - model.flow.energy_ratio_from_latent(model.flow.prior.sample([n_samples_ESS], parameter=low_T), c=[], parameter=low_T)
            log_weights_high_T = - model.flow.energy_ratio_from_latent(model.flow.prior.sample([n_samples_ESS], parameter=1.0), c=[], parameter=1.0)
        except (RuntimeError, TypeError) as err:
            # Models that cannot be evaluated are still skipped, but no longer silently.
            print(f"Skipping {model_folder}: {type(err).__name__}: {err}")
            continue
        ESS_low_T = ESS_from_log_weights(log_weights_low_T).item()
        ESS_high_T = ESS_from_log_weights(log_weights_high_T).item()

        if pinf_loss is not None:
            pinf_kwargs = pinf_loss.additional_kwargs
            sample_origin = pinf_kwargs["take_samples_from"]
            energy_origin = pinf_kwargs["check_consistency_with"]
            pinf_kwargs_dict = dict(pinf_kwargs)
            reweight_samples = "reference_parameter" in pinf_kwargs_dict.keys()
            reweight = False
            backward_kl = False
            no_causality_weights = pinf_kwargs_dict.get("use_target_proposals")
        else:
            backward_kl = hparams["kl_loss"] is not None
            reweight = hparams["temperature_weighted_loss"] is not None
            sample_origin = None
            energy_origin = None
            reweight_samples = None
            no_causality_weights = None

        # "version_<k>" -> "<k>"; normpath makes this independent of a trailing separator.
        version = os.path.basename(os.path.normpath(model_folder)).split("_")[-1]

        row_index.append(i)
        rows.append([version,
                     loss_mode,
                     sample_origin,
                     energy_origin,
                     reweight_samples,
                     reweight,
                     backward_kl,
                     no_causality_weights,
                     model_nll_low_T,
                     model_nll_high_T,
                     ESS_low_T,
                     ESS_high_T])

# Build the frame once instead of growing it row by row; the index keeps the
# position of each model in `full_names`, and metric columns get a numeric dtype.
df = pd.DataFrame(rows, index=row_index, columns=result_columns)

In [None]:
# Columns that describe how a model was trained; runs sharing all of these
# values are aggregated together.
settings_columns = [
    'loss_mode', 'sample_origin', 'energy_origin', 'reweight_samples',
    'temperature_weighted_loss', 'backward_kl', 'no_causality_weights'
]

# Metric columns that are averaged over the runs of each configuration.
results_columns = [
    'NLL T=0.5', 'NLL T=1.0', 'ESS T=0.5', 'ESS T=1.0'
]

# Group by the settings columns and aggregate the results columns.
# Missing settings are replaced by the string "None" so those rows are not
# dropped by groupby. The replacement is restricted to the settings columns:
# filling the whole frame would turn a NaN metric into the string "None" and
# break mean/std. The metrics are cast to float so the aggregation is numeric.
aggregated_df = (
    df
    .fillna({col: "None" for col in settings_columns})
    .astype({col: float for col in results_columns})
    .groupby(settings_columns)[results_columns]
    .agg(['mean', 'std'])
    .reset_index()
)

aggregated_df