In [3]:
import torch
import numpy as np

import matplotlib.pyplot as plt

from bgflow.utils import as_numpy
from bgflow import DiffEqFlow, BoltzmannGeneratorCV, MeanFreeNormalDistribution
from bgflow import BlackBoxDynamics, BruteForceEstimator
from tbg.models2 import EGNN_dynamics_AD2_cat
from tbg.modelwithcv import EGNN_AD2_CV
from bgflow import BlackBoxDynamics, BruteForceEstimator



****** PyMBAR will use 64-bit JAX! *******
* JAX is currently set to 32-bit bitsize *
* which is its default.                  *
*                                        *
* PyMBAR requires 64-bit mode and WILL   *
* enable JAX's 64-bit mode when called.  *
*                                        *
* This MAY cause problems with other     *
* Uses of JAX in the same code.          *
******************************************



In [7]:
n_particles = 22
n_dimensions = 3
dim = n_dimensions * n_particles

# Per-atom type labels used to build the one-hot node features.
# Start with one label per atom, then overwrite the groups below so that the
# listed atoms share a label (groups copied verbatim; the chemical meaning of
# each group is not stated here — TODO document).
atom_types = np.arange(n_particles)
atom_types[[0, 2, 3]] = 0      # atoms 0, 2, 3 -> label 0
atom_types[[1]] = 2            # atom 1 -> label 2 (label freed by atom 2 above)
atom_types[[11, 12, 13]] = 12  # atoms 11-13 -> label 12
atom_types[[19, 20, 21]] = 20  # atoms 19-21 -> label 20

h_initial = torch.nn.functional.one_hot(torch.tensor(atom_types))
prior = MeanFreeNormalDistribution(dim, n_particles, two_event_dims=False).cuda()


## Original model (`EGNN_dynamics_AD2_cat`, no CV conditioning)

In [8]:
# Exact (brute-force) divergence estimator; replaced by a faster variant below.
brute_force_estimator = BruteForceEstimator()

# Baseline EGNN vector field (no CV conditioning).
net_dynamics = EGNN_dynamics_AD2_cat(
    n_particles=n_particles,
    device="cuda",
    n_dimension=dim // n_particles,
    h_initial=h_initial,
    hidden_nf=64,
    act_fn=torch.nn.SiLU(),
    n_layers=5,
    recurrent=True,
    tanh=True,
    attention=True,
    condition_time=True,
    mode="egnn_dynamics",
    agg="sum",
)

# Wrap the network as bgflow dynamics -> DiffEqFlow -> Boltzmann generator.
bb_dynamics = BlackBoxDynamics(
    dynamics_function=net_dynamics, divergence_estimator=brute_force_estimator
)
flow = DiffEqFlow(dynamics=bb_dynamics)
bg = BoltzmannGeneratorCV(prior, flow, prior).cuda()

class BruteForceEstimatorFast(torch.nn.Module):
    """
    Exact brute-force estimate of the divergence of a dynamics function.

    Every input coordinate is sliced into its own ``[n_batch, 1]`` column and
    the columns are concatenated back before calling ``dynamics``. Because each
    slice is a separate node in the autograd graph, ``torch.autograd.grad`` can
    differentiate the i-th output column with respect to the i-th input column
    only. Summing those derivatives gives the trace of the Jacobian.

    Args:
        dynamics: callable ``dynamics(t, x)``; its output must be 2-D and is
            expected to have the same shape as ``x``.
        t: time argument, passed through unchanged to ``dynamics``.
        xs: input tensor, expected to be ``[n_batch, system_dim]``. Its
            ``requires_grad`` flag is switched on in place.
        cv_condition: accepted but IGNORED. This variant does not forward it to
            ``dynamics``. It is presumably kept for signature compatibility with
            a CV-conditioned caller; confirm against bgflow's call site.

    Returns:
        ``(dxs, neg_divergence)``: ``dxs`` is the output of ``dynamics``, and
        ``neg_divergence`` is the negated divergence with shape
        ``[n_batch, 1]``.

    Raises:
        AssertionError: if ``dynamics`` does not return a 2-D tensor.

    NOTE(review): per-sample results assume ``dynamics`` treats batch rows
    independently, because the gradient is taken of the batch-summed output.
    """

    def __init__(self):
        super().__init__()

    def forward(self, dynamics, t, xs, cv_condition=None):

        with torch.set_grad_enabled(True):
            xs.requires_grad_(True)
            # One [n_batch, 1] slice per coordinate; re-concatenating keeps every
            # slice in the graph so it can be differentiated against on its own.
            x = [xs[:, [i]] for i in range(xs.size(1))]

            dxs = dynamics(t, torch.cat(x, dim=1))

            assert len(dxs.shape) == 2, "`dxs` must have shape [n_batch, system_dim]"
            divergence = 0
            for i in range(xs.size(1)):
                # d(dxs_i)/d(x_i): the i-th diagonal entry of the Jacobian.
                # retain_graph=True because the same graph is reused for every i.
                divergence += torch.autograd.grad(
                    dxs[:, [i]], x[i], torch.ones_like(dxs[:, [i]]), retain_graph=True
                )[0]

        return dxs, -divergence.view(-1, 1)


# Swap in the faster exact estimator defined above, then adjust private bgflow
# settings. The names presumably mean ODE solver tolerances and gradient
# checkpointing; confirm against bgflow.
bb_dynamics._divergence_estimator = brute_force_estimator_fast = BruteForceEstimatorFast()
bg.flow._integrator_atol = bg.flow._integrator_rtol = 1e-4
flow._use_checkpoints, flow._kwargs = False, {}

# Restore trained weights.
# NOTE(review): torch.load is pickle-based; only load checkpoints you trust.
filename = "FM-AD2-train-repro-custom-data"
PATH_last = "models/" + filename
checkpoint = torch.load(PATH_last)
flow.load_state_dict(checkpoint["model_state_dict"])

<All keys matched successfully>

In [9]:
# Show the module tree of the original (non-CV) flow via its rich repr.
flow

DiffEqFlow(
  (_dynamics): DensityDynamics(
    (_dynamics): BlackBoxDynamics(
      (_dynamics_function): EGNN_dynamics_AD2_cat(
        (egnn): EGNN(
          (embedding): Linear(in_features=22, out_features=64, bias=True)
          (embedding_out): Linear(in_features=64, out_features=22, bias=True)
          (gcl_0): E_GCL(
            (edge_mlp): Sequential(
              (0): Linear(in_features=130, out_features=64, bias=True)
              (1): SiLU()
              (2): Linear(in_features=64, out_features=64, bias=True)
              (3): SiLU()
            )
            (node_mlp): Sequential(
              (0): Linear(in_features=128, out_features=64, bias=True)
              (1): SiLU()
              (2): Linear(in_features=64, out_features=64, bias=True)
            )
            (coord_mlp): Sequential(
              (0): Linear(in_features=64, out_features=64, bias=True)
              (1): SiLU()
              (2): Linear(in_features=64, out_features=1, bias=False)
       

## CV-conditioned model (`EGNN_AD2_CV`)

In [14]:
# Exact (brute-force) divergence estimator; replaced by a faster variant below.
brute_force_estimator = BruteForceEstimator()

# CV-conditioned EGNN vector field. Hyper-parameters match the original model.
net_dynamics = EGNN_AD2_CV(
    # system description
    n_particles=n_particles,
    n_dimension=dim // n_particles,
    h_initial=h_initial,
    device="cuda",
    # network architecture
    hidden_nf=64,
    n_layers=5,
    act_fn=torch.nn.SiLU(),
    attention=True,
    recurrent=True,
    tanh=True,
    agg="sum",
    # dynamics options
    mode="egnn_dynamics",
    condition_time=True,
)

# Set up the dynamics: network -> bgflow dynamics -> DiffEqFlow -> generator.
bb_dynamics = BlackBoxDynamics(
    dynamics_function=net_dynamics,
    divergence_estimator=brute_force_estimator,
)
flow = DiffEqFlow(dynamics=bb_dynamics)
bg = BoltzmannGeneratorCV(prior, flow, prior).cuda()

class BruteForceEstimatorFast(torch.nn.Module):
    """
    Exact brute-force divergence estimate for a CV-conditioned dynamics function.

    Every input coordinate is sliced into its own ``[n_batch, 1]`` column and
    the columns are concatenated back before calling ``dynamics``. Because each
    slice is a separate node in the autograd graph, ``torch.autograd.grad`` can
    differentiate the i-th output column with respect to the i-th input column
    only. Summing those derivatives gives the trace of the Jacobian.

    Args:
        dynamics: callable ``dynamics(t, x, cv_condition=...)``; its output
            must be 2-D and is expected to have the same shape as ``x``.
        t: time argument, passed through unchanged to ``dynamics``.
        xs: input tensor, expected to be ``[n_batch, system_dim]``. Its
            ``requires_grad`` flag is switched on in place.
        cv_condition: forwarded to ``dynamics`` as the ``cv_condition`` keyword
            (``None`` is forwarded as-is).

    Returns:
        ``(dxs, neg_divergence)``: ``dxs`` is the output of ``dynamics``, and
        ``neg_divergence`` is the negated divergence with shape
        ``[n_batch, 1]``.

    Raises:
        AssertionError: if ``dynamics`` does not return a 2-D tensor.

    NOTE(review): per-sample results assume ``dynamics`` treats batch rows
    independently, because the gradient is taken of the batch-summed output.
    """

    def __init__(self):
        super().__init__()

    def forward(self, dynamics, t, xs, cv_condition=None):

        with torch.set_grad_enabled(True):
            xs.requires_grad_(True)
            # One [n_batch, 1] slice per coordinate; re-concatenating keeps every
            # slice in the graph so it can be differentiated against on its own.
            x = [xs[:, [i]] for i in range(xs.size(1))]

            dxs = dynamics(t, torch.cat(x, dim=1), cv_condition=cv_condition)

            assert len(dxs.shape) == 2, "`dxs` must have shape [n_batch, system_dim]"
            divergence = 0
            for i in range(xs.size(1)):
                # d(dxs_i)/d(x_i): the i-th diagonal entry of the Jacobian.
                # retain_graph=True because the same graph is reused for every i.
                divergence += torch.autograd.grad(
                    dxs[:, [i]], x[i], torch.ones_like(dxs[:, [i]]), retain_graph=True
                )[0]

        return dxs, -divergence.view(-1, 1)


print(">> Loading force estimator")
# Swap in the faster, CV-forwarding exact estimator defined above, then adjust
# private bgflow settings. The names presumably mean ODE solver tolerances and
# gradient checkpointing; confirm against bgflow.
bb_dynamics._divergence_estimator = brute_force_estimator_fast = BruteForceEstimatorFast()
bg.flow._integrator_atol = bg.flow._integrator_rtol = 1e-4
flow._use_checkpoints, flow._kwargs = False, {}

# Restore trained weights of the CV-conditioned model.
# NOTE(review): torch.load is pickle-based; only load checkpoints you trust.
filename = "tbg-fixed6"
PATH_last = "models/" + filename + ".pt"
checkpoint = torch.load(PATH_last)
flow.load_state_dict(checkpoint["model_state_dict"])

TBGCV(
  (norm_in): Normalization(in_features=45, out_features=45, mode=mean_std)
  (encoder): FeedForward(
    (nn): Sequential(
      (0): Linear(in_features=45, out_features=30, bias=True)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=30, out_features=30, bias=True)
      (3): ReLU(inplace=True)
      (4): Linear(in_features=30, out_features=2, bias=True)
    )
  )
)
>> Loading force estimator


<All keys matched successfully>

In [15]:
# Show the module tree of the CV-conditioned flow via its rich repr
# (includes the `cv` submodule of EGNN_AD2_CV).
flow

DiffEqFlow(
  (_dynamics): DensityDynamics(
    (_dynamics): BlackBoxDynamics(
      (_dynamics_function): EGNN_AD2_CV(
        (cv): TBGCV(
          (norm_in): Normalization(in_features=45, out_features=45, mode=mean_std)
          (encoder): FeedForward(
            (nn): Sequential(
              (0): Linear(in_features=45, out_features=30, bias=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=30, out_features=30, bias=True)
              (3): ReLU(inplace=True)
              (4): Linear(in_features=30, out_features=2, bias=True)
            )
          )
        )
        (egnn): EGNN(
          (embedding): Linear(in_features=24, out_features=64, bias=True)
          (embedding_out): Linear(in_features=64, out_features=24, bias=True)
          (gcl_0): E_GCL(
            (edge_mlp): Sequential(
              (0): Linear(in_features=130, out_features=64, bias=True)
              (1): SiLU()
              (2): Linear(in_features=64, out_features=64,