In [17]:
import numpy as np
import mdtraj as md
import bgflow as bg
import nglview as nv

from openmm import *
from openmm.app import *
from openmm.unit import *

import torch

from tqdm.auto import tqdm

In [2]:
# Notebook-wide settings.
IMAGE_SIZE = "400px"  # width and height of the NGL viewer widget
data_dir = "../data"  # root folder holding <molecule>/<state>.pdb files

# State view

Render the structure stored in the selected molecule/state PDB file.

In [3]:
# Select which molecule/state PDB file the rest of the notebook works on.
molecule = "alanine"
state = "c5"
pdb_file = "{}/{}/{}.pdb".format(data_dir, molecule, state)

In [4]:
# Render the selected PDB structure in an NGL widget with a square viewport.
view = nv.show_structure_file(
    pdb_file,
    height=IMAGE_SIZE,
    width=IMAGE_SIZE,
)
view

NGLWidget()

In [94]:
# Load the structure; traj.xyz holds the Cartesian coordinates, nested as
# (frames, atoms, 3) as the saved output shows.
traj = md.load(pdb_file)
# Leave the array as the last expression so Jupyter uses its rich display
# instead of a plain-text print().
traj.xyz

[[[1.9185 0.3272 2.0062]
  [1.8852 0.3056 2.1077]
  [1.9711 0.2965 2.1742]
  [1.8267 0.2137 2.1094]
  [1.7969 0.4173 2.158 ]
  [1.6799 0.394  2.1826]
  [1.8542 0.5357 2.1737]
  [1.9537 0.5432 2.1581]
  [1.7848 0.6546 2.2226]
  [1.6778 0.6439 2.2048]
  [1.8056 0.6616 2.3749]
  [1.7618 0.7523 2.4166]
  [1.9126 0.6649 2.3954]
  [1.7636 0.5733 2.423 ]
  [1.833  0.7819 2.1506]
  [1.9476 0.789  2.1079]
  [1.7462 0.8834 2.143 ]
  [1.6555 0.8713 2.1856]
  [1.778  1.0114 2.0802]
  [1.69   1.0757 2.0804]
  [1.8606 1.0599 2.1323]
  [1.8047 0.9939 1.976 ]]]


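# Internal coordinates

Convert the Cartesian coordinates into bond lengths, angles and torsions (keeping a rigid block of atoms in Cartesian coordinates), and check that the transform round-trips.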

In [93]:
# Z-matrix for the non-rigid atoms. Each row lists four atom indices; presumably
# (placed atom, bond partner, angle partner, torsion partner) following bgflow's
# convention — TODO confirm. Each row contributes one bond, one angle and one
# torsion to the transform's output.
z_matrix = np.array([
    [ 0,  1,  4,  6],
    [ 1,  4,  6,  8],
    [ 2,  1,  4,  0],
    [ 3,  1,  4,  0],
    [ 4,  6,  8, 14],
    [ 5,  4,  6,  8],
    [ 7,  6,  8,  4],
    [11, 10,  8,  6],
    [12, 10,  8, 11],
    [13, 10,  8, 11],
    [15, 14,  8, 16],
    [16, 14,  8,  6],
    [17, 16, 14, 15],
    [18, 16, 14,  8],
    [19, 18, 16, 14],
    [20, 18, 16, 19],
    [21, 18, 16, 19]
])
# Atoms kept in Cartesian coordinates (passed through as `z_fixed` by the transform).
rigid_block = np.array([ 6,  8,  9, 10, 14])

# Sanity check: every atom index must be either placed by the z-matrix or fixed,
# exactly once. This catches typos in the hand-written table above.
atoms_covered = np.concatenate([z_matrix[:, 0], rigid_block])
assert np.array_equal(np.sort(atoms_covered), np.arange(len(atoms_covered))), (
    "z_matrix[:, 0] and rigid_block must together cover each atom index exactly once"
)

# Flow mapping flat Cartesian coordinates to internal coordinates (bonds, angles,
# torsions) for the z-matrix atoms, while the rigid-block atoms stay Cartesian.
coordinate_transform = bg.RelativeInternalCoordinateTransformation(
    z_matrix=z_matrix,
    fixed_atoms=rigid_block,
    eps=1e-7,
    normalize_angles=True,
)

In [96]:
# Flatten each frame's (n_atoms, 3) coordinates into a single row, one row per frame.
# flatten(start_dim=1) keeps the frame axis, whereas reshape(1, -1) would fold every
# frame of a multi-frame trajectory into one sample. For a single frame the result
# is unchanged.
example_data = torch.tensor(traj.xyz).flatten(start_dim=1)

In [97]:
# Round-trip check: Cartesian -> internal coordinates -> Cartesian should reproduce the input.
bonds, angles, torsions, z_fixed, dlogp = coordinate_transform.forward(example_data)
# Use the public entry point with inverse=True rather than the private `_inverse` method.
example_data_recovered = coordinate_transform.forward(
    bonds, angles, torsions, z_fixed, inverse=True
)
# Element 0 of the inverse output holds the reconstructed coordinates, laid out like
# example_data; compare the whole batch, not just the first sample.
roundtrip_residual = example_data_recovered[0] - example_data
# The float32 round-off observed for this structure is ~1e-7; an error near 1e-5
# indicates a real mismatch (e.g. a bad z-matrix) rather than numerics.
assert torch.allclose(example_data_recovered[0], example_data, atol=1e-5), (
    f"round-trip error too large: {roundtrip_residual.abs().max().item():.3e}"
)
roundtrip_residual

tensor([ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  2.9802e-08,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -5.9605e-08,
         0.0000e+00,  0.0000e+00,  5.9605e-08,  0.0000e+00,  0.0000e+00,
        -5.9605e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00, -1.1921e-07,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  2.3842e-07,  0.0000e+00,  0.0000e+00,
         0.0000e+00])

In [98]:
# Inspect the internal coordinates and the fixed-atom Cartesian block produced above.
(bonds, angles, torsions, z_fixed)

(tensor([[0.1090, 0.1510, 0.1090, 0.1090, 0.1325, 0.1218, 0.1010, 0.1090, 0.1090,
          0.1090, 0.1225, 0.1338, 0.1009, 0.1461, 0.1090, 0.1090, 0.1090]]),
 tensor([[0.6112, 0.6477, 0.6035, 0.5986, 0.6898, 0.6902, 0.6519, 0.6180, 0.6038,
          0.6131, 0.6680, 0.6488, 0.6551, 0.6846, 0.6106, 0.6136, 0.6049]]),
 tensor([[0.3188, 0.0018, 0.8350, 0.1661, 0.1035, 0.4999, 0.0141, 0.9891, 0.1725,
          0.8418, 0.9924, 0.9256, 0.0101, 0.9966, 0.9899, 0.8390, 0.1749]]),
 tensor([[1.8542, 0.5357, 2.1737, 1.7848, 0.6546, 2.2226, 1.6778, 0.6439, 2.2048,
          1.8056, 0.6616, 2.3749, 1.8330, 0.7819, 2.1506]]))

In [88]:
# Run the transform once and reuse the result; previously forward() was evaluated
# twice (once for the concatenation and again for the print).
ic_outputs = coordinate_transform.forward(example_data)
# Concatenate bonds, angles, torsions and fixed-atom coordinates (everything except the
# trailing dlogp term) into one flat feature row per sample. The name `asdf` is kept
# because the next cell reads it.
asdf = torch.cat(ic_outputs[:-1], dim=1)
print(ic_outputs)
print(asdf.shape)

(tensor([[0.1090, 0.1510, 0.1090, 0.1090, 0.1325, 0.1218, 0.1010, 0.1090, 0.1090,
         0.1090, 0.1225, 0.1338, 0.1009, 0.1461, 0.1090, 0.1090, 0.1090]]), tensor([[0.6112, 0.6477, 0.6035, 0.5986, 0.6898, 0.6902, 0.6519, 0.6180, 0.6038,
         0.6131, 0.6680, 0.6488, 0.6551, 0.6846, 0.6106, 0.6136, 0.6049]]), tensor([[0.3188, 0.0018, 0.8350, 0.1661, 0.1035, 0.4999, 0.0141, 0.9891, 0.1725,
         0.8418, 0.9924, 0.9256, 0.0101, 0.9966, 0.9899, 0.8390, 0.1749]]), tensor([[1.8542, 0.5357, 2.1737, 1.7848, 0.6546, 2.2226, 1.6778, 0.6439, 2.2048,
         1.8056, 0.6616, 2.3749, 1.8330, 0.7819, 2.1506]]), tensor([[24.1758]]))
torch.Size([1, 66])


In [92]:
# Split the flat feature row back into bonds / angles / torsions / fixed coordinates.
# Sizes come from the z-matrix (one bond, angle and torsion per row) and the rigid block
# (x, y, z per fixed atom) instead of hard-coded offsets; the old last bound (67) was
# one past the 66-wide tensor. torch.split raises if the sizes do not sum to asdf.shape[1].
n_internal = len(z_matrix)
torch.split(asdf, [n_internal, n_internal, n_internal, 3 * len(rigid_block)], dim=1)

(tensor([[0.1090, 0.1510, 0.1090, 0.1090, 0.1325, 0.1218, 0.1010, 0.1090, 0.1090,
          0.1090, 0.1225, 0.1338, 0.1009, 0.1461, 0.1090, 0.1090, 0.1090]]),
 tensor([[0.6112, 0.6477, 0.6035, 0.5986, 0.6898, 0.6902, 0.6519, 0.6180, 0.6038,
          0.6131, 0.6680, 0.6488, 0.6551, 0.6846, 0.6106, 0.6136, 0.6049]]),
 tensor([[0.3188, 0.0018, 0.8350, 0.1661, 0.1035, 0.4999, 0.0141, 0.9891, 0.1725,
          0.8418, 0.9924, 0.9256, 0.0101, 0.9966, 0.9899, 0.8390, 0.1749]]),
 tensor([[1.8542, 0.5357, 2.1737, 1.7848, 0.6546, 2.2226, 1.6778, 0.6439, 2.2048,
          1.8056, 0.6616, 2.3749, 1.8330, 0.7819, 2.1506]]))