In [None]:
# Step up to the parent directory: the cells below open files through relative
# paths ("data/...", "lightning_logs/..."), presumably relative to the repository root.
# NOTE: this is not idempotent -- re-running the cell moves up one more level each time.
%cd ..

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from trade import BoltzmannGeneratorHParams, BoltzmannGenerator
import torch
import numpy as np
from math import ceil, floor
from yaml import safe_load
import os
from functools import partial
import mdtraj as md
from matplotlib.colors import LogNorm
import matplotlib
from beta_scaling.data import get_loader
import pandas as pd
from tqdm.auto import trange

In [None]:
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## Plotting functions

In [None]:
# Global matplotlib font settings shared by every figure in this notebook.
font = dict(size=12)
matplotlib.rc('font', **font)

In [None]:
def plot_two_moons(flow, reference_data, reference_conditions, beta):
    """Compare flow samples with reference points for the zero observation.

    The flow is conditioned on an all-zero observation and sampled once per
    reference point. The right-hand panel shows the subset of
    ``reference_data`` whose condition lies within 1e-3 (Euclidean norm) of
    that zero observation.

    Args:
        flow: Model exposing ``parameters()`` and
            ``sample(n, c=[conditions], parameter=beta)``.
        reference_data (torch.Tensor): Reference points; the first two
            columns are plotted.
        reference_conditions (torch.Tensor): Condition belonging to each
            reference point, one row per entry of ``reference_data``.
        beta (float): Value passed to the flow as ``parameter`` and shown in
            the panel titles.

    Returns:
        tuple: ``(fig, ax)`` where ``ax`` holds the two panel axes.
    """
    flow_device = next(flow.parameters()).device
    n_samples = len(reference_data)
    target_observation = torch.zeros(n_samples, 2, device=flow_device)

    # Sampling is inference only, so no autograd graph is needed.
    with torch.no_grad():
        samples = flow.sample(n_samples, c=[target_observation], parameter=beta)
    samples = samples.cpu().detach().numpy()

    # Reference points whose condition (almost) coincides with the target observation.
    distance = torch.norm(
        reference_conditions - target_observation[:1].to(reference_conditions.device),
        dim=1,
    )
    ind_close = torch.where(distance < 0.001)[0]
    data_close = reference_data[ind_close].cpu().detach().numpy()

    # Create the figure only after sampling succeeded, so a failure leaves no empty figure behind.
    fig, ax = plt.subplots(1, 2, figsize=(12, 6))  # Larger figure size for better clarity

    panels = [
        # (axis, points, title, scatter styling)
        (ax[0], samples, f"TRADE at $\\beta$={beta:.1f}",
         dict(alpha=0.7, c='blue', label='NF samples')),
        (ax[1], data_close, f"Approximate ground truth at $\\beta$={beta:.1f}",
         dict(alpha=0.8, c='red', label='GT Reference Points')),
    ]
    for axis, points, title, style in panels:
        axis.scatter(points[:, 0], points[:, 1], s=50, edgecolor='black', **style)
        axis.set_title(title, fontsize=14)
        # Raw strings: "\p" is not a valid escape sequence in a normal string literal.
        axis.set_xlabel(r"$\psi_1$", fontsize=12)
        axis.set_ylabel(r"$\psi_2$", fontsize=12)
        axis.grid(True, linestyle='--', alpha=0.6)
        axis.set_xlim([-0.35, 0.35])
        axis.set_ylim([-0.35, 0.35])

    plt.tight_layout()
    return fig, ax

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def _two_moons_radius(params, conditions):
    """Reduce (parameters, conditions) pairs to the scalar radius used for calibration.

    ``params`` has shape ``(..., 2)`` and ``conditions`` must broadcast against it.
    The parameters are mapped to ``[-|p0 + p1| / sqrt(2), (p1 - p0) / sqrt(2)]``,
    subtracted from the conditions, the first component is shifted by 0.25 and the
    Euclidean norm is taken. The result keeps a trailing axis of length 1.
    """
    folded = np.stack(
        [
            -np.abs(params[..., 0] + params[..., 1]) / np.sqrt(2.0),
            (-params[..., 0] + params[..., 1]) / np.sqrt(2.0),
        ],
        axis=-1,
    )
    residual = conditions - folded  # fresh array, safe to modify in place
    residual[..., 0] -= 0.25
    return np.sqrt(np.sum(residual**2, axis=-1, keepdims=True))


def _posterior_radius(flow, reference_conditions, conditions_np, n_total, n_repeats, parameter):
    """Draw ``n_repeats`` samples per condition from ``flow`` and return their radii."""
    # Use the flow's own device instead of a notebook-level global.
    flow_device = next(flow.parameters()).device
    repeated = reference_conditions.repeat_interleave(n_repeats, 0).to(flow_device)
    samples = flow.sample(n_total, c=[repeated], parameter=parameter)
    samples = samples.reshape(-1, n_repeats, 2).cpu().detach().numpy()
    # Broadcasting conditions over the repeat axis equals repeat_interleave + reshape.
    return _two_moons_radius(samples, conditions_np[:, None, :])


def _central_interval_hits(truth, draws, level):
    """Boolean mask: does ``truth`` fall inside the central ``level`` interval of ``draws``?"""
    lower = np.percentile(draws, (1 - level) * 100 / 2, axis=1)
    upper = np.percentile(draws, (1 + level) * 100 / 2, axis=1)
    return (truth >= lower) & (truth <= upper)


def _bootstrap_coverage_ci(hit_masks, n_bootstrap, ci_alpha):
    """Bootstrap a ``ci_alpha`` confidence interval of the mean of each mask.

    All masks are resampled with the same indices so the intervals are comparable.
    Returns one ``(lower, upper)`` pair per mask.
    """
    n_points = len(hit_masks[0])
    indices = np.random.choice(n_points, size=(n_bootstrap, n_points), replace=True)
    bounds = [(1 - ci_alpha) / 2 * 100, (1 + ci_alpha) / 2 * 100]
    intervals = []
    for mask in hit_masks:
        resampled_coverage = mask[indices].mean(axis=1)
        lower, upper = np.percentile(resampled_coverage, bounds)
        intervals.append((lower, upper))
    return intervals


@torch.no_grad()
def plot_calibration_curve(flow, flow_T1, reference_data, reference_conditions, beta, 
                           param_names=None, n_intervals=10, n_bootstrap=1000, ci_alpha=0.95, plot_T1=True, n_repeats=300):
    """
    Plots calibration curves with bootstrap confidence bands for a conditional flow.

    For every reference point, ``n_repeats`` samples are drawn from each flow given
    the point's condition. Samples and true parameters are reduced to a scalar
    radius, and for a grid of nominal coverages the fraction of true radii that
    fall inside the central interval of the sampled radii is plotted.

    Args:
        flow: The model evaluated at ``parameter=beta``; must expose
            ``parameters()`` and ``sample(n, c=[conditions], parameter=...)``.
        flow_T1: Baseline model, sampled at ``parameter=1.0``.
        reference_data (torch.Tensor): True parameter values (shape: n_samples x 2).
        reference_conditions (torch.Tensor): Conditional inputs for the flow
            models (shape: n_samples x 2).
        beta (float): beta parameter for ``flow``; also shown in labels and titles.
        param_names (list): Names used in the panel titles (default: "Param i").
        n_intervals (int): Number of nominal coverage levels, evenly spaced in [0, 1].
        n_bootstrap (int): Number of bootstrap resamples for the confidence bands.
        ci_alpha (float): Confidence level of the bootstrap band (e.g. 0.95 for a
            95% band); this is the level itself, not the significance level.
        plot_T1 (bool): Whether to draw the baseline curve and its band.
        n_repeats (int): Number of samples drawn per reference point.

    Returns:
        None. The figure is created through ``plt.subplots`` and left for the
        active matplotlib backend to display.
    """
    n_total = len(reference_data) * n_repeats
    conditions_np = reference_conditions.cpu().detach().numpy()

    # Baseline first, then the evaluated flow (same sampling order as before).
    model_r_T1 = _posterior_radius(flow_T1, reference_conditions, conditions_np, n_total, n_repeats, 1.0)
    model_r = _posterior_radius(flow, reference_conditions, conditions_np, n_total, n_repeats, beta)

    true_params = reference_data.cpu().detach().numpy()
    true_r = _two_moons_radius(true_params, conditions_np)

    n_params = true_r.shape[1]

    if param_names is None:
        param_names = [f"Param {i+1}" for i in range(n_params)]

    fig, axes = plt.subplots(1, n_params, figsize=(6 * n_params, 5))

    predicted_coverage = np.linspace(0, 1, n_intervals)

    for i in range(n_params):
        truth = true_r[:, i]
        draws = model_r[:, :, i]
        draws_T1 = model_r_T1[:, :, i]

        interval_coverage, lower_ci, upper_ci = [], [], []
        interval_coverage_T1, lower_ci_T1, upper_ci_T1 = [], [], []

        for level in predicted_coverage:
            hits = _central_interval_hits(truth, draws, level)
            hits_T1 = _central_interval_hits(truth, draws_T1, level)

            # Bootstrap over reference points to get a band around the empirical coverage.
            (ci_lower, ci_upper), (ci_lower_T1, ci_upper_T1) = _bootstrap_coverage_ci(
                [hits, hits_T1], n_bootstrap, ci_alpha
            )

            interval_coverage.append(np.mean(hits))
            lower_ci.append(ci_lower)
            upper_ci.append(ci_upper)

            interval_coverage_T1.append(np.mean(hits_T1))
            lower_ci_T1.append(ci_lower_T1)
            upper_ci_T1.append(ci_upper_T1)

        # plt.subplots returns a bare Axes (not an array) when there is a single panel.
        ax = axes[i] if n_params > 1 else axes
        ax.plot(predicted_coverage, interval_coverage, label=f"TRADE at $\\beta={beta:.1f}$", marker="o", color="C0")
        ax.fill_between(predicted_coverage, lower_ci, upper_ci, color="C0", alpha=0.2, label=f"{int(ci_alpha * 100)}% CI")

        if plot_T1:
            ax.plot(predicted_coverage, interval_coverage_T1, label=r"Baseline model ($\beta=1.0$)", marker="o", color="C1")
            ax.fill_between(predicted_coverage, lower_ci_T1, upper_ci_T1, color="C1", alpha=0.2, label=f"{int(ci_alpha * 100)}% CI")

        ax.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Perfect calibration")
        ax.set_title(f"Calibration Curve for {param_names[i]} at $\\beta={beta:.1f}$")
        ax.set_xlabel("Predicted Coverage")
        ax.set_ylabel("Empirical Coverage")
        ax.legend(loc="upper left")

    plt.tight_layout()


## TRADE

In [None]:
model_folder = "lightning_logs/version_" # Path to your trained model

hparams_path = os.path.join(model_folder, "hparams.yaml")
checkpoint_path = os.path.join(model_folder, "checkpoints/last.ckpt")

# map_location="cpu": lets a checkpoint saved on a GPU be restored on a CPU-only machine;
# the flow is moved to the desired device explicitly in later cells.
# weights_only=False: the checkpoint stores pickled hyper-parameter objects (see
# `.additional_kwargs` below), which weights-only loading presumably rejects -- confirm.
# Unpickling executes arbitrary code -- only load checkpoints you trust.
ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
hparams = ckpt["hyper_parameters"]
# Drop entries that are not passed on to the model constructor (presumably
# training-schedule settings -- TODO confirm); tolerate checkpoints that lack them.
hparams.pop("n_steps", None)
hparams.pop("epoch_len", None)
hparams["parameter_pinf_loss"].additional_kwargs["mode"] = "continuous"
model = BoltzmannGenerator(hparams)
model.load_state_dict(ckpt["state_dict"])
model.eval()
pass  # keep the cell from echoing the module repr

In [None]:
# TRADE samples vs. reference points at beta = 1.0.
beta = 1.0
plot_data = np.load(f"data/two_moons_target_obs_{beta:.1f}.npz")
reference_data = torch.from_numpy(plot_data["coordinates"])
reference_conditions = torch.from_numpy(plot_data["conditions"])
# The plotting helper defined in this notebook is `plot_two_moons`;
# `plot_two_moons_new` does not exist after a kernel restart.
plot_two_moons(model.flow.cpu(), reference_data, reference_conditions, beta);

In [None]:
# TRADE samples vs. reference points at beta = 0.5.
beta = 0.5
plot_data = np.load(f"data/two_moons_target_obs_{beta:.1f}.npz")
reference_data = torch.from_numpy(plot_data["coordinates"])
reference_conditions = torch.from_numpy(plot_data["conditions"])
# The plotting helper defined in this notebook is `plot_two_moons`;
# `plot_two_moons_new` does not exist after a kernel restart.
plot_two_moons(model.flow.cpu(), reference_data, reference_conditions, beta);

In [None]:
# TRADE samples vs. reference points at beta = 2.0.
beta = 2.0
plot_data = np.load(f"data/two_moons_target_obs_{beta:.1f}.npz")
reference_data = torch.from_numpy(plot_data["coordinates"])
reference_conditions = torch.from_numpy(plot_data["conditions"])
# The plotting helper defined in this notebook is `plot_two_moons`;
# `plot_two_moons_new` does not exist after a kernel restart.
plot_two_moons(model.flow.cpu(), reference_data, reference_conditions, beta);

## Baseline Model

In [None]:
model_folder2 = "lightning_logs/version_" # Path to your trained model

hparams_path2 = os.path.join(model_folder2, "hparams.yaml")
checkpoint_path2 = os.path.join(model_folder2, "checkpoints/last.ckpt")

# map_location="cpu": lets a checkpoint saved on a GPU be restored on a CPU-only machine;
# the flow is moved to the desired device explicitly in later cells.
# weights_only=False: the checkpoint stores a "hyper_parameters" entry next to the
# weights, which weights-only loading presumably rejects if it holds pickled objects -- confirm.
# Unpickling executes arbitrary code -- only load checkpoints you trust.
ckpt2 = torch.load(checkpoint_path2, map_location="cpu", weights_only=False)
hparams2 = ckpt2["hyper_parameters"]
# Drop entries that are not passed on to the model constructor (presumably
# training-schedule settings -- TODO confirm); tolerate checkpoints that lack them.
hparams2.pop("n_steps", None)
hparams2.pop("epoch_len", None)
model2 = BoltzmannGenerator(hparams2)
model2.load_state_dict(ckpt2["state_dict"])
model2.eval()
pass  # keep the cell from echoing the module repr

In [None]:
# Baseline-model samples vs. reference points at beta = 1.0.
beta = 1.0
plot_data = np.load(f"data/two_moons_target_obs_{beta:.1f}.npz")
reference_data = torch.from_numpy(plot_data["coordinates"])
reference_conditions = torch.from_numpy(plot_data["conditions"])
# The plotting helper defined in this notebook is `plot_two_moons`;
# `plot_two_moons_new` does not exist after a kernel restart.
plot_two_moons(model2.flow.cpu(), reference_data, reference_conditions, beta);

## Combined Plots

In [None]:
# Calibration of TRADE against the baseline model at beta = 1.0,
# using the first 1000 reference points.
beta = 1.0
plot_data = np.load(f"data/two_moons_{beta:.1f}.npz")
reference_data, reference_conditions = (
    torch.from_numpy(plot_data[key]).float()[:1000]
    for key in ("coordinates", "conditions")
)
plot_calibration_curve(
    model.flow.to(device),
    model2.flow.to(device),
    reference_data,
    reference_conditions,
    beta,
    plot_T1=True,
    n_repeats=300,
    param_names=[r"$r$"],
)

In [None]:
# Calibration of TRADE against the baseline model at beta = 0.5,
# using the first 1000 reference points.
beta = 0.5
plot_data = np.load(f"data/two_moons_{beta:.1f}.npz")
reference_data = torch.from_numpy(plot_data["coordinates"]).float()[:1000]
reference_conditions = torch.from_numpy(plot_data["conditions"]).float()[:1000]
# Move both flows to `device` here as well, so this cell does not depend on an
# earlier cell having moved them in place.
plot_calibration_curve(model.flow.to(device), model2.flow.to(device), reference_data, reference_conditions, beta, plot_T1=True, param_names=[r"$r$"])

In [None]:
# Calibration of TRADE against the baseline model at beta = 2.0,
# using the first 1000 reference points.
beta = 2.0
plot_data = np.load(f"data/two_moons_{beta:.1f}.npz")
reference_data = torch.from_numpy(plot_data["coordinates"]).float()[:1000]
reference_conditions = torch.from_numpy(plot_data["conditions"]).float()[:1000]
# Move both flows to `device` here as well, so this cell does not depend on an
# earlier cell having moved them in place.
plot_calibration_curve(model.flow.to(device), model2.flow.to(device), reference_data, reference_conditions, beta, plot_T1=True, param_names=[r"$r$"])