In [None]:
# Change to the parent directory, presumably the repository root, so that the
# `trade` package and the relative `lightning_logs/` paths used below resolve — TODO confirm.
# NOTE(review): not idempotent; re-running this cell moves up one more directory.
%cd ..

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from trade import BoltzmannGeneratorHParams, BoltzmannGenerator
import torch
import numpy as np
from math import ceil, floor
from yaml import safe_load
import os
from functools import partial
import mdtraj as md
from matplotlib.colors import LogNorm
import matplotlib
from trade.data import get_loader
import pandas as pd
from tqdm.auto import trange


In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Fix the RNG state so the random reference subsets (torch.randperm) and, presumably,
# the flow samples below are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

## Plots and metrics

In [None]:
def compute_kl_divergence(data_model, data_ground_truth, bins=50, range=None):
    """
    Compute KL(P || Q) between 2D histograms of model (P) and ground truth (Q) data.

    Both histograms are built on the same bin grid, so bin ``(i, j)`` of ``P`` and
    bin ``(i, j)`` of ``Q`` cover the same region. When ``range`` is None, the grid
    spans the joint extent of both data sets. Inferring the range separately for
    each data set would compare misaligned bins.

    Parameters:
        data_model (array-like): Nx2 array of model data points.
        data_ground_truth (array-like): Mx2 array of ground truth data points.
        bins (int or tuple of ints): Number of bins along each dimension.
        range (array-like, optional): Range of the bins [(xmin, xmax), (ymin, ymax)].
                                      If None, the joint min/max of both data sets
                                      along each dimension is used.

    Returns:
        float: KL divergence between the two histograms. Only bins that are
        non-empty in both histograms contribute. The result therefore stays finite
        where the exact histogram KL would be infinite, i.e. where the model puts
        mass in bins the ground truth never visits.
    """
    data_model = np.asarray(data_model)
    data_ground_truth = np.asarray(data_ground_truth)

    if range is None:
        # One shared grid over the joint extent of both data sets.
        joint = np.concatenate([data_model[:, :2], data_ground_truth[:, :2]], axis=0)
        if joint.size:
            range = [
                (joint[:, 0].min(), joint[:, 0].max()),
                (joint[:, 1].min(), joint[:, 1].max()),
            ]

    # Compute 2D histograms on the common grid
    hist_model, _, _ = np.histogram2d(
        data_model[:, 0], data_model[:, 1], bins=bins, range=range
    )
    hist_ground_truth, _, _ = np.histogram2d(
        data_ground_truth[:, 0], data_ground_truth[:, 1], bins=bins, range=range
    )

    # Normalize histograms to get probability distributions
    P = hist_model / np.sum(hist_model)
    Q = hist_ground_truth / np.sum(hist_ground_truth)

    # Skip bins where either probability is zero (avoids log(0) and division by zero)
    mask = (P > 0) & (Q > 0)
    return np.sum(P[mask] * np.log(P[mask] / Q[mask]))

def get_angles(samples, system):
    """Return the (phi, psi) backbone dihedrals of every configuration in ``samples``.

    Parameters:
        samples: torch tensor or array of coordinates reshapeable to (-1, 22, 3).
        system: object providing ``mdtraj_topology`` and ``compute_phi_psi``.

    Returns:
        np.ndarray: phi and psi stacked along the last axis (``[..., 0]`` is phi,
        ``[..., 1]`` is psi), presumably of shape (n_frames, 2).
    """
    # Convert tensors explicitly instead of a bare try/except around ``.cpu()``.
    # Handing mdtraj a numpy array also avoids it iterating the tensor frame by frame.
    if isinstance(samples, torch.Tensor):
        samples = samples.detach().cpu().numpy()
    trajectory = md.Trajectory(
        xyz=np.asarray(samples).reshape(-1, 22, 3),
        topology=system.mdtraj_topology,
    )
    return np.stack(system.compute_phi_psi(trajectory), axis=-1)

In [None]:
@torch.no_grad()
def plot_ala2(samples, system, target_energy, reference_data=None):
    """Multi-panel summary of a single Boltzmann-generator sample set.

    Panels, left to right:
    - Ramachandran plot of ``samples``.
    - Ramachandran plot of ``reference_data`` (only when given).
    - Marginal density of phi.
    - Energy histogram.

    Parameters:
        samples (torch.Tensor): generated configurations, reshapeable to (-1, 22, 3).
        system: object providing ``mdtraj_topology`` and ``compute_phi_psi``.
        target_energy (callable): maps a tensor of configurations to a tensor of energies.
        reference_data (torch.Tensor, optional): MD configurations to compare against.

    Returns:
        (fig, ax): the matplotlib figure and its array of axes.
    """
    has_reference = reference_data is not None
    n_panels = 3 + has_reference
    fig, ax = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5))

    # plot_energies/plot_phi collect positional arguments in *samples, so everything
    # after the sample sets must be passed by keyword. Energies are evaluated on the
    # raw tensors, so this panel is drawn before any numpy conversion.
    plot_energies(ax[-1], samples, target_energy=target_energy,
                  test_data=reference_data, names=["BG"])
    ax[-1].set_title("Energy distribution")

    samples = samples.cpu().detach().numpy()
    if has_reference:
        reference_data = reference_data.cpu().detach().numpy()
        vmin, vmax = plot_phi_psi(ax[1], reference_data, system)
        ax[1].set_title("Ramachandran plot (MD)")
    else:
        vmin, vmax = None, None
    # Share the MD colour scale so both Ramachandran panels are comparable.
    plot_phi_psi(ax[0], samples, system, vmin=vmin, vmax=vmax)
    ax[0].set_title("Ramachandran plot (BG)")

    plot_phi(ax[1 + has_reference], samples, system=system,
             reference_data=reference_data, names=["BG"])
    ax[1 + has_reference].set_title(r"Density of $\phi$")

    plt.tight_layout()
    return fig, ax


@torch.no_grad()
def plot_ala2_together(*samples, system, target_energy, reference_data=None, names=None):
    """Side-by-side comparison of several sample sets (and optionally MD data).

    Panels, left to right:
    - One Ramachandran plot per sample set.
    - The MD Ramachandran plot (only when ``reference_data`` is given).
    - Marginal density of phi for all sets.
    - Energy histograms for all sets.

    Parameters:
        *samples (torch.Tensor): generated configurations, one tensor per model.
        system: object providing ``mdtraj_topology`` and ``compute_phi_psi``.
        target_energy (callable): maps a tensor of configurations to a tensor of energies.
        reference_data (torch.Tensor, optional): MD configurations to compare against.
        names (list of str, optional): one label per sample set; defaults to "model i".

    Returns:
        (fig, ax): the matplotlib figure and its array of axes.
    """
    if names is None:
        names = [f"model {i}" for i in range(len(samples))]
    has_reference = reference_data is not None
    n_panels = 2 + has_reference + len(samples)
    fig, ax = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5))

    # Energies need the raw tensors, so this panel is drawn before any numpy conversion.
    plot_energies(ax[-1], *samples, target_energy=target_energy,
                  test_data=reference_data, names=names)
    ax[-1].set_title("Energy distribution")

    if has_reference:
        reference_data = reference_data.cpu().detach().numpy()

    plot_phi(ax[-2], *samples, system=system, reference_data=reference_data, names=names)
    ax[-2].set_title(r"Density of $\phi$")

    if has_reference:
        vmin, vmax = plot_phi_psi(ax[-3], reference_data, system)
        # Title the panel that actually holds the MD plot, not one of the model panels.
        ax[-3].set_title("Ramachandran plot (MD)")
    else:
        vmin, vmax = None, None

    # Model Ramachandran plots share the MD colour scale when one is available.
    for i, (name, sample) in enumerate(zip(names, samples)):
        plot_phi_psi(ax[i], sample.cpu().detach().numpy(), system, vmin=vmin, vmax=vmax)
        ax[i].set_title(f"Ramachandran plot ({name})")

    plt.tight_layout()
    return fig, ax

def plot_phi_psi(ax, trajectory, system, vmin=None, vmax=None):
    """Draw a log-scaled Ramachandran (phi/psi) 2D histogram on ``ax``.

    Parameters:
        ax: matplotlib Axes to draw on.
        trajectory: ``md.Trajectory``, or coordinates (torch tensor or array)
            reshapeable to (-1, 22, 3).
        system: object providing ``mdtraj_topology`` and ``compute_phi_psi``.
        vmin, vmax (float, optional): colour limits for ``LogNorm``; autoscaled when None.

    Returns:
        tuple: the (vmin, vmax) colour limits actually used, so another panel can share them.
    """
    # Explicit conversion instead of a bare try/except around ``.cpu()``.
    if isinstance(trajectory, torch.Tensor):
        trajectory = trajectory.detach().cpu().numpy()
    if not isinstance(trajectory, md.Trajectory):
        trajectory = md.Trajectory(
            xyz=np.asarray(trajectory).reshape(-1, 22, 3),
            topology=system.mdtraj_topology,
        )
    phi, psi = system.compute_phi_psi(trajectory)
    hist = ax.hist2d(phi, psi, 50, norm=LogNorm(vmin=vmin, vmax=vmax), density=True)
    ax.set_xlim(-np.pi, np.pi)
    ax.set_ylim(-np.pi, np.pi)
    ax.set_xlabel(r"$\phi$")
    ax.set_ylabel(r"$\psi$")

    # hist2d returns (counts, xedges, yedges, image); the image carries the colour limits.
    return hist[-1].get_clim()

def plot_energies(ax, *samples, target_energy=None, test_data=None, names=None):
    """Overlay histograms of the target energy of one or more sample sets.

    The histograms are drawn on a twin y-axis of ``ax`` over a shared range that
    clips the high-energy tail:
    - Upper limit: the largest 80th percentile over the sample sets, or the MD
      maximum if that is larger.
    - Lower limit: 10% of the span below the overall minimum.

    When ``test_data`` is given, each sample histogram is rescaled so that it holds
    as many in-range counts as the MD histogram, making the curves comparable.

    Parameters:
        ax: matplotlib Axes to draw on.
        *samples: configurations accepted by ``target_energy``.
        target_energy (callable): maps configurations to a tensor of energies.
        test_data (optional): MD configurations; its first ``len(samples[0])``
            entries are histogrammed as "MD".
        names (list of str, optional): legend labels; defaults to "model 0", "model 1", ...
    """
    if names is None:
        names = [f"model {i}" for i in range(len(samples))]

    samples_energy = [target_energy(sample).cpu().detach().numpy() for sample in samples]

    min_energy = np.inf
    cut = -np.inf
    for energy in samples_energy:
        min_energy = min(np.nanmin(energy), min_energy)
        cut = max(np.nanpercentile(energy, 80), cut)

    md_energies = None
    if test_data is not None:
        md_energies = target_energy(test_data[:len(samples[0])]).cpu().detach().numpy()
        min_energy = min(np.nanmin(md_energies), min_energy)
        cut = max(np.nanmax(md_energies), cut)

    plot_range = (min_energy - 0.1 * (cut - min_energy), cut)

    def _n_in_range(energy):
        # Number of energies strictly inside the plotted range.
        return np.sum(np.logical_and(energy < plot_range[1], energy > plot_range[0]))

    ax.set_xlabel("Energy   [$k_B T$]")
    # y-axis on the right
    ax2 = ax.twinx()
    ax.get_yaxis().set_visible(False)

    n_md_in_range = _n_in_range(md_energies) if md_energies is not None else None
    for name, energy in zip(names, samples_energy):
        weights = None
        if md_energies is not None:
            n_in_range = _n_in_range(energy)
            # Match the MD in-range count. Skip the rescaling when no sample falls in
            # range, which would otherwise divide by zero.
            if n_in_range > 0:
                weights = np.ones_like(energy) * n_md_in_range / n_in_range
        ax2.hist(energy, range=plot_range, bins=40, weights=weights,
                 density=False, label=f"{name}", alpha=0.6)

    if md_energies is not None:
        ax2.hist(md_energies, range=plot_range, bins=40, density=False, label="MD", alpha=0.3)
    ax2.set_ylabel(f"Count   [#Samples / {len(samples)}]")
    ax2.legend()

def plot_phi(ax, *trajectories, system=None, reference_data=None, names=None):
    """Plot the marginal density of the backbone dihedral phi for each trajectory.

    Parameters:
        ax: matplotlib Axes to draw on.
        *trajectories: ``md.Trajectory`` objects, or coordinates (torch tensors or
            arrays) reshapeable to (-1, 22, 3).
        system: object providing ``mdtraj_topology`` and ``compute_phi_psi``.
        reference_data (optional): MD data in any of the same forms; drawn as "MD".
        names (list of str, optional): legend labels; defaults to "model 0", "model 1", ...
    """
    if names is None:
        names = [f"model {i}" for i in range(len(trajectories))]

    def _phi_density(data):
        # Density of phi on a fixed 100-bin grid over [-pi, pi]; returns (left edges, density).
        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().numpy()
        if not isinstance(data, md.Trajectory):
            data = md.Trajectory(
                xyz=np.asarray(data).reshape(-1, 22, 3),
                topology=system.mdtraj_topology,
            )
        phi, _ = system.compute_phi_psi(data)
        density, edges = np.histogram(phi, bins=100, range=(-np.pi, np.pi), density=True)
        return edges[:-1], density

    for name, trajectory in zip(names, trajectories):
        ax.plot(*_phi_density(trajectory), label=f"{name}")

    # Reference goes through the same conversion, so CUDA tensors and Trajectories work too.
    if reference_data is not None:
        ax.plot(*_phi_density(reference_data), label="MD")

    ax.set_xlim(-np.pi, np.pi)
    ax.set_xlabel(r"$\phi$")
    ax.set_ylabel(r"$p(\phi)$")
    ax.legend()

In [None]:
@torch.no_grad()
def ESS_from_log_weights(log_omega, clip_weights=False):
    """Relative (Kish) effective sample size of importance weights given in log space.

    Evaluates ``(sum w)^2 / (n * sum w^2)`` stably via log-sum-exp. The value lies in
    ``[1/n, 1]``: 1 for uniform weights, 1/n when a single weight dominates.

    Args:
        log_omega: tensor of unnormalised log importance weights; dim 0 indexes samples.
        clip_weights: accepted for API compatibility; currently not used.

    Returns:
        0-d tensor with the relative ESS.
    """
    # Normalise so the weights sum to one (the overall scale cancels in the ratio).
    normalised = log_omega - log_omega.logsumexp(dim=0)
    log_numerator = normalised.logsumexp(dim=0).mul(2)     # log (sum w)^2
    log_denominator = normalised.mul(2).logsumexp(dim=0)   # log sum w^2
    log_n = np.log(len(log_omega))
    return (log_numerator - log_denominator - log_n).exp()

## TRADE model

In [None]:
model_folder = "lightning_logs/version_" # Path to your trained model

# NOTE(review): hparams_path is not used; hyper-parameters are read from the checkpoint.
hparams_path = os.path.join(model_folder, "hparams.yaml")
checkpoint_path = os.path.join(model_folder, "checkpoints/last.ckpt")

# map_location="cpu": a checkpoint saved on a GPU would otherwise fail to load on a
# CPU-only machine; the model is moved to `device` in the next cell.
# NOTE(review): torch>=2.6 defaults to weights_only=True; if the Lightning
# hyper-parameters fail to unpickle, pass weights_only=False (trusted checkpoints only).
ckpt = torch.load(checkpoint_path, map_location="cpu")
hparams_trade = dict(ckpt["hyper_parameters"])
# Presumably training-schedule settings not wanted at evaluation time — TODO confirm.
del hparams_trade["n_steps"]
del hparams_trade["epoch_len"]
model_trade = BoltzmannGenerator(hparams_trade)
model_trade.load_state_dict(ckpt["state_dict"])
model_trade.eval();

In [None]:
# Evaluation settings for the TRADE cells below.
system, batch_size = model_trade.datasets[0].system, 10_000
model_trade = model_trade.to(device)

In [None]:
# Element 0 of the TRADE model's full validation item (presumably the coordinates of a
# (coordinates, condition) tuple — TODO confirm); used as the T=1.0 reference below.
reference_data_trade = model_trade.val_data[:][0]

In [None]:
T = 1.0
with torch.no_grad():
    samples = []
    log_weights = []
    nll = []
    # Split once instead of re-splitting the reference tensor on every iteration.
    reference_batches = reference_data_trade.split(batch_size)
    for i in trange(ceil(len(reference_data_trade)/batch_size)):

        x_gt, c_gt = reference_batches[i], []
        # Flow energy of the reference batch; averaged below and reported as the NLL.
        nll.append(model_trade.flow.energy(x_gt.to(device), c_gt, parameter=T))
        # Every batch draws a full batch_size samples, so the model sample count is
        # rounded up to a multiple of batch_size.
        c = model_trade.sample_condition(batch_size=batch_size)
        x = model_trade.flow.sample(batch_size, c=c, parameter=T)
        # NOTE(review): zero log importance weights are a placeholder, so the ESS printed
        # below is always exactly 100%. A real ESS needs log p_target(x) - log q(x) — TODO.
        lw = torch.zeros_like(x[:,0])
        samples.append(x)
        log_weights.append(lw)

    nll_trade = torch.cat(nll, dim=0).mean().item()
    samples_trade = torch.cat(samples, dim=0)
    log_weights = torch.cat(log_weights, dim=0)
    ESS_trade = ESS_from_log_weights(log_weights).item()
    model_angles = get_angles(samples_trade, system)
    gt_angles = get_angles(reference_data_trade, system)
    kl_divergence_angles = compute_kl_divergence(model_angles, gt_angles, bins=30)
    print(f"Effective Sample Size at T={T:.1f}: {ESS_trade*100:.2f}%")
    print(f"Negative Log Likelihood at T={T:.1f}: {nll_trade:.2f}")
    print(rf"KL divergence $\phi, \psi$ histogram at T={T:.1f}: {kl_divergence_angles:.4f}")

In [None]:
dataset = get_loader("ala2")(temperature=300)

# Random subset of the temperature=300 ala2 frames, sized like the T=1.0 reference set.
ind = torch.randperm(len(dataset.coordinates))[: len(reference_data_trade)]
reference_data_low = torch.from_numpy(
    dataset.coordinates.reshape(-1, dataset.dim)
)[ind]

In [None]:
T = 0.5
with torch.no_grad():
    samples = []
    log_weights = []
    nll = []
    # Split once instead of re-splitting the reference tensor on every iteration.
    reference_batches = reference_data_low.split(batch_size)
    for i in trange(ceil(len(reference_data_low)/batch_size)):

        x_gt, c_gt = reference_batches[i], []
        # Flow energy of the reference batch; averaged below and reported as the NLL.
        nll.append(model_trade.flow.energy(x_gt.to(device), c_gt, parameter=T))
        # Every batch draws a full batch_size samples, so the model sample count is
        # rounded up to a multiple of batch_size.
        c = model_trade.sample_condition(batch_size=batch_size)
        x = model_trade.flow.sample(batch_size, c=c, parameter=T)
        # NOTE(review): zero log importance weights are a placeholder, so the ESS printed
        # below is always exactly 100%. A real ESS needs log p_target(x) - log q(x) — TODO.
        lw = torch.zeros_like(x[:,0])
        samples.append(x)
        log_weights.append(lw)

    nll_trade = torch.cat(nll, dim=0).mean().item()
    samples_trade_lowT = torch.cat(samples, dim=0)
    log_weights = torch.cat(log_weights, dim=0)
    ESS_trade = ESS_from_log_weights(log_weights).item()
    model_angles = get_angles(samples_trade_lowT, system)
    gt_angles = get_angles(reference_data_low, system)
    kl_divergence_angles = compute_kl_divergence(model_angles, gt_angles, bins=30)
    print(f"Effective Sample Size at T={T:.1f}: {ESS_trade*100:.2f}%")
    print(f"Negative Log Likelihood at T={T:.1f}: {nll_trade:.2f}")
    print(rf"KL divergence $\phi, \psi$ histogram at T={T:.1f}: {kl_divergence_angles:.4f}")

## Temperature-steerable flow (TSF) model

In [None]:
model_folder = "lightning_logs/version_" # Path to your trained model

# NOTE(review): hparams_path is not used; hyper-parameters are read from the checkpoint.
hparams_path = os.path.join(model_folder, "hparams.yaml")
checkpoint_path = os.path.join(model_folder, "checkpoints/last.ckpt")

# map_location="cpu": a checkpoint saved on a GPU would otherwise fail to load on a
# CPU-only machine; the model is moved to `device` in the next cell.
# NOTE(review): torch>=2.6 defaults to weights_only=True; if the Lightning
# hyper-parameters fail to unpickle, pass weights_only=False (trusted checkpoints only).
ckpt = torch.load(checkpoint_path, map_location="cpu")
hparams_vp = dict(ckpt["hyper_parameters"])
# Presumably training-schedule settings not wanted at evaluation time — TODO confirm.
del hparams_vp["n_steps"]
del hparams_vp["epoch_len"]
model_vp = BoltzmannGenerator(hparams_vp)
model_vp.load_state_dict(ckpt["state_dict"])
model_vp.eval();

In [None]:
system = model_vp.datasets[0].system
batch_size = 10000
# Rebind model_vp; assigning the result to model_trade would clobber the TRADE model
# that the combined plots use later.
model_vp = model_vp.to(device)

In [None]:
# Element 0 of the TSF model's full validation item (presumably the coordinates of a
# (coordinates, condition) tuple — TODO confirm); used as the T=1.0 reference below.
reference_data_vp = model_vp.val_data[:][0]

In [None]:
T = 1.0
with torch.no_grad():
    samples = []
    log_weights = []
    nll = []
    # Split once instead of re-splitting the reference tensor on every iteration.
    reference_batches = reference_data_vp.split(batch_size)
    for i in trange(ceil(len(reference_data_vp)/batch_size)):

        x_gt, c_gt = reference_batches[i], []
        # Flow energy of the reference batch; averaged below and reported as the NLL.
        nll.append(model_vp.flow.energy(x_gt.to(device), c_gt, parameter=T))

        # Every batch draws a full batch_size samples, so the model sample count is
        # rounded up to a multiple of batch_size.
        c = model_vp.sample_condition(batch_size=batch_size)
        x = model_vp.flow.sample(batch_size, c=c, parameter=T)
        samples.append(x)
        # NOTE(review): zero log importance weights are a placeholder, so the ESS printed
        # below is always exactly 100%. A real ESS needs log p_target(x) - log q(x) — TODO.
        lw = torch.zeros_like(x[:,0])
        log_weights.append(lw)

    nll_vp = torch.cat(nll, dim=0).mean().item()
    samples_vp = torch.cat(samples, dim=0)
    log_weights = torch.cat(log_weights, dim=0)
    ESS_vp = ESS_from_log_weights(log_weights).item()
    model_angles = get_angles(samples_vp, system)
    gt_angles = get_angles(reference_data_vp, system)
    kl_divergence_angles = compute_kl_divergence(model_angles, gt_angles, bins=30)
    print(f"Effective Sample Size at T={T:.1f}: {ESS_vp*100:.2f}%")
    print(f"Negative Log Likelihood at T={T:.1f}: {nll_vp:.2f}")
    print(rf"KL divergence $\phi, \psi$ histogram at T={T:.1f}: {kl_divergence_angles:.4f}")

In [None]:
dataset = get_loader("ala2")(temperature=300)

# Random subset of the temperature=300 ala2 frames, sized like the T=1.0 reference set.
ind = torch.randperm(len(dataset.coordinates))[: len(reference_data_vp)]
reference_data_low = torch.from_numpy(
    dataset.coordinates.reshape(-1, dataset.dim)
)[ind]

In [None]:
T = 0.5
with torch.no_grad():
    samples = []
    log_weights = []
    nll = []
    # Split once instead of re-splitting the reference tensor on every iteration.
    reference_batches = reference_data_low.split(batch_size)
    for i in trange(ceil(len(reference_data_low)/batch_size)):

        x_gt, c_gt = reference_batches[i], []
        # Flow energy of the reference batch; averaged below and reported as the NLL.
        nll.append(model_vp.flow.energy(x_gt.to(device), c_gt, parameter=T))

        # Every batch draws a full batch_size samples, so the model sample count is
        # rounded up to a multiple of batch_size.
        c = model_vp.sample_condition(batch_size=batch_size)
        x = model_vp.flow.sample(batch_size, c=c, parameter=T)
        samples.append(x)
        # NOTE(review): zero log importance weights are a placeholder, so the ESS printed
        # below is always exactly 100%. A real ESS needs log p_target(x) - log q(x) — TODO.
        lw = torch.zeros_like(x[:,0])
        log_weights.append(lw)

    nll_vp = torch.cat(nll, dim=0).mean().item()
    samples_vp_lowT = torch.cat(samples, dim=0)
    log_weights = torch.cat(log_weights, dim=0)
    ESS_vp = ESS_from_log_weights(log_weights).item()
    model_angles = get_angles(samples_vp_lowT, system)
    gt_angles = get_angles(reference_data_low, system)
    kl_divergence_angles = compute_kl_divergence(model_angles, gt_angles, bins=30)
    print(f"Effective Sample Size at T={T:.1f}: {ESS_vp*100:.2f}%")
    print(f"Negative Log Likelihood at T={T:.1f}: {nll_vp:.2f}")
    print(rf"KL divergence $\phi, \psi$ histogram at T={T:.1f}: {kl_divergence_angles:.4f}")

## Reverse KL model

In [None]:
model_folder = "lightning_logs/version_" # Path to your trained model

# NOTE(review): hparams_path is not used; hyper-parameters are read from the checkpoint.
hparams_path = os.path.join(model_folder, "hparams.yaml")
checkpoint_path = os.path.join(model_folder, "checkpoints/last.ckpt")

# map_location="cpu": a checkpoint saved on a GPU would otherwise fail to load on a
# CPU-only machine; the model is moved to `device` in the next cell.
# NOTE(review): torch>=2.6 defaults to weights_only=True; if the Lightning
# hyper-parameters fail to unpickle, pass weights_only=False (trusted checkpoints only).
ckpt = torch.load(checkpoint_path, map_location="cpu")
hparams_rev_kl = dict(ckpt["hyper_parameters"])
# Presumably training-schedule settings not wanted at evaluation time — TODO confirm.
del hparams_rev_kl["n_steps"]
del hparams_rev_kl["epoch_len"]
model_rev_kl = BoltzmannGenerator(hparams_rev_kl)
model_rev_kl.load_state_dict(ckpt["state_dict"])
model_rev_kl.eval();

In [None]:
system = model_rev_kl.datasets[0].system
batch_size = 10000
# Rebind model_rev_kl; assigning the result to model_trade would clobber the TRADE
# model that the combined plots use later.
model_rev_kl = model_rev_kl.to(device)

In [None]:
# Element 0 of the reverse-KL model's full validation item (presumably the coordinates
# of a (coordinates, condition) tuple — TODO confirm); used as the T=1.0 reference below.
reference_data_rev_kl = model_rev_kl.val_data[:][0]

In [None]:
T = 1.0
with torch.no_grad():
    samples = []
    log_weights = []
    nll = []
    # Split once instead of re-splitting the reference tensor on every iteration.
    reference_batches = reference_data_rev_kl.split(batch_size)
    for i in trange(ceil(len(reference_data_rev_kl)/batch_size)):

        x_gt, c_gt = reference_batches[i], []
        # Flow energy of the reference batch; averaged below and reported as the NLL.
        nll.append(model_rev_kl.flow.energy(x_gt.to(device), c_gt, parameter=T))

        # Every batch draws a full batch_size samples, so the model sample count is
        # rounded up to a multiple of batch_size.
        c = model_rev_kl.sample_condition(batch_size=batch_size)
        x = model_rev_kl.flow.sample(batch_size, c=c, parameter=T)
        samples.append(x)
        # NOTE(review): zero log importance weights are a placeholder, so the ESS printed
        # below is always exactly 100%. A real ESS needs log p_target(x) - log q(x) — TODO.
        lw = torch.zeros_like(x[:,0])
        log_weights.append(lw)

    nll_rev_kl = torch.cat(nll, dim=0).mean().item()
    samples_rev_kl = torch.cat(samples, dim=0)
    log_weights = torch.cat(log_weights, dim=0)
    ESS_rev_kl = ESS_from_log_weights(log_weights).item()
    model_angles = get_angles(samples_rev_kl, system)
    gt_angles = get_angles(reference_data_rev_kl, system)
    kl_divergence_angles = compute_kl_divergence(model_angles, gt_angles, bins=30)
    print(f"Effective Sample Size at T={T:.1f}: {ESS_rev_kl*100:.2f}%")
    print(f"Negative Log Likelihood at T={T:.1f}: {nll_rev_kl:.2f}")
    print(rf"KL divergence $\phi, \psi$ histogram at T={T:.1f}: {kl_divergence_angles:.4f}")

In [None]:
dataset = get_loader("ala2")(temperature=300)

# Random subset of the temperature=300 ala2 frames, sized like the T=1.0 reference set.
ind = torch.randperm(len(dataset.coordinates))[: len(reference_data_rev_kl)]
reference_data_low = torch.from_numpy(
    dataset.coordinates.reshape(-1, dataset.dim)
)[ind]

In [None]:
T = 0.5
with torch.no_grad():
    samples = []
    log_weights = []
    nll = []
    # Split once instead of re-splitting the reference tensor on every iteration.
    reference_batches = reference_data_low.split(batch_size)
    for i in trange(ceil(len(reference_data_low)/batch_size)):

        x_gt, c_gt = reference_batches[i], []
        # Flow energy of the reference batch; averaged below and reported as the NLL.
        nll.append(model_rev_kl.flow.energy(x_gt.to(device), c_gt, parameter=T))

        # Every batch draws a full batch_size samples, so the model sample count is
        # rounded up to a multiple of batch_size.
        c = model_rev_kl.sample_condition(batch_size=batch_size)
        x = model_rev_kl.flow.sample(batch_size, c=c, parameter=T)
        samples.append(x)
        # NOTE(review): zero log importance weights are a placeholder, so the ESS printed
        # below is always exactly 100%. A real ESS needs log p_target(x) - log q(x) — TODO.
        lw = torch.zeros_like(x[:,0])
        log_weights.append(lw)

    nll_rev_kl = torch.cat(nll, dim=0).mean().item()
    samples_rev_kl_lowT = torch.cat(samples, dim=0)
    log_weights = torch.cat(log_weights, dim=0)
    ESS_rev_kl = ESS_from_log_weights(log_weights).item()
    model_angles = get_angles(samples_rev_kl_lowT, system)
    gt_angles = get_angles(reference_data_low, system)
    kl_divergence_angles = compute_kl_divergence(model_angles, gt_angles, bins=30)
    print(f"Effective Sample Size at T={T:.1f}: {ESS_rev_kl*100:.2f}%")
    print(f"Negative Log Likelihood at T={T:.1f}: {nll_rev_kl:.2f}")
    print(rf"KL divergence $\phi, \psi$ histogram at T={T:.1f}: {kl_divergence_angles:.4f}")

## Combined Plots

In [None]:
# Target energy at temperature=600 for the T=1.0 samples (T appears to be relative to
# 600, cf. the temperature=300 data used for T=0.5 — TODO confirm). The energy model is
# taken from model_trade, presumably the same ala2 potential for all three models.
# The trailing ";" suppresses the (fig, ax) tuple repr; the figure still renders.
plot_ala2_together(samples_vp, 
                   samples_rev_kl, 
                   samples_trade, 
                   reference_data=reference_data_trade, 
                   system=system, 
                   target_energy=partial(model_trade.flow.get_energy_model().energy, 
                                         temperature=600),
                   names=["TSF", "Rev KL", "TRADE"]);

In [None]:
# Target energy at temperature=300 for the T=0.5 samples, matching the temperature=300
# MD reference. The energy model is taken from model_trade, presumably the same ala2
# potential for all three models — TODO confirm.
# The trailing ";" suppresses the (fig, ax) tuple repr; the figure still renders.
plot_ala2_together(samples_vp_lowT, 
                   samples_rev_kl_lowT, 
                   samples_trade_lowT, 
                   reference_data=reference_data_low, 
                   system=system, 
                   target_energy=partial(model_trade.flow.get_energy_model().energy, 
                                         temperature=300),
                   names=["TSF", "Rev KL", "TRADE"]);