In [1]:
from typing import Tuple, Iterable, List

import bgflow as bg
import bgmol
import FrEIA.framework as Ff
import FrEIA.modules as Fm
import mdtraj
import torch
import torch.nn as nn
from bgmol.systems import peptide
from torch import Tensor

from src.lightning_bg.evaluate import ShowTraj

In [2]:
# Load the Ala2TSF300 dataset from the local data directory (no download)
# and keep a handle on the system object it exposes.
molecule_path = "../data/Molecules/Dialanine"
ala_data = bgmol.datasets.Ala2TSF300(
    root=molecule_path,
    read=True,
    download=False,
)
system = ala_data.system

Using downloaded and verified file: /tmp/alanine-dipeptide-nowater.pdb


In [3]:
# Build a Z-matrix for the system from its mdtraj topology using the factory's
# "naive" construction. `build_naive()` also returns `fixed_atoms`, which no
# later visible cell reads.
zfactory = bgmol.zmatrix.ZMatrixFactory(system.mdtraj_topology)
zmatrix, fixed_atoms = zfactory.build_naive()

In [4]:
# Wrap the Z-matrix in bgflow's global internal-coordinate transformation layer.
ic_layer = bg.GlobalInternalCoordinateTransformation(zmatrix)

In [5]:
# Reshape the dataset coordinates to rows of `ala_data.dim` values, convert to
# a torch tensor and keep only the first 100 rows for quick experiments.
xyz = torch.Tensor(
    ala_data.coordinates.reshape(-1, ala_data.dim)
)[0:100]
xyz.size()

torch.Size([100, 66])

In [6]:
# Shape of the first tensor returned by the internal-coordinate layer.
ic_layer(xyz)[0].size()

torch.Size([100, 21])

In [7]:
# Instantiate a FrEIA ActNorm module for a single input of shape (33,).
# NOTE(review): `actnorm` is not referenced by any later visible cell — looks
# like a leftover experiment; confirm before keeping.
actnorm = Fm.ActNorm([(33,)])

In [8]:
# Check that the bgflow layer is a regular torch module.
isinstance(ic_layer, nn.Module)

True

In [9]:
class ExperimentLayer(Fm.InvertibleModule):
    """Probe module that prints the ``dims_in`` it is constructed with.

    It defines no ``forward``; it exists only to inspect the arguments the
    framework hands to a module's constructor.
    """

    def __init__(self, dims_in: List[Tuple[int]]) -> None:
        super().__init__(dims_in)
        # Deliberate debug output: shows the exact structure of `dims_in`.
        print(dims_in)
        
    def output_dims(self, input_dims: List[Tuple[int]]) -> List[Tuple[int]]:
        """Return ``input_dims`` unchanged."""
        return input_dims

In [10]:
# Size the network from the dataset instead of repeating the literal 66, so the
# declared input width cannot drift from the width `xyz` was reshaped to.
inn = Ff.SequenceINN(int(ala_data.dim))
# Appending the probe layer prints the `dims_in` that FrEIA passes to modules.
inn.append(ExperimentLayer)

[(66,)]


In [11]:
# noinspection PyShadowingNames
class ICTransform(Fm.InvertibleModule):
    """FrEIA module wrapping bgflow's ``GlobalInternalCoordinateTransformation``.

    Forward (``rev=False``) maps coordinates to the concatenation
    ``[bonds, angles, torsions, zeros(6)]``. The location and rotation returned
    by the bgflow layer are discarded and replaced by six zero columns.

    Reverse (``rev=True``) reads bonds, angles and torsions from the leading
    columns, ignores every remaining column, and calls the bgflow inverse with
    an all-zero location and rotation. The global pose of the input is
    therefore not carried through a forward/reverse round trip.

    Args:
        dims_in: Input shapes, passed through to ``Fm.InvertibleModule``.
        dims_c: Condition shapes, passed through to ``Fm.InvertibleModule``.
        system: Object exposing ``mdtraj_topology``; required despite the
            ``None`` default, which only exists so the argument can be passed
            by keyword.

    Raises:
        ValueError: If ``system`` is not provided.
    """

    def __init__(self, dims_in, dims_c=None, system=None):
        super().__init__(dims_in, dims_c)
        if system is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise ValueError(
                "ICTransform requires a `system` that provides `mdtraj_topology`."
            )
        zfactory = bgmol.zmatrix.ZMatrixFactory(system.mdtraj_topology)
        zmatrix, _ = zfactory.build_naive()
        self.bg_layer = bg.GlobalInternalCoordinateTransformation(zmatrix)

    def output_dims(self, input_dims):
        """Return ``input_dims`` unchanged (output is declared the same size as input)."""
        return input_dims

    def forward(
            self, x_or_z: Iterable[Tensor],
            c: Iterable[Tensor] = None,
            rev: bool = False,
            jac: bool = True
        ) -> Tuple[Tuple[Tensor], Tensor]:
        """Apply the transformation in the requested direction.

        Args:
            x_or_z: Sequence whose first element is the input tensor; the
                first axis is treated as the batch axis.
            c: Unused.
            rev: ``False`` for coordinates -> internal coordinates,
                ``True`` for the inverse.
            jac: Unused; the log-determinant from bgflow is always returned.

        Returns:
            ``((out,), log_jac_det)`` with ``log_jac_det`` reshaped to
            ``(batch,)``.
        """
        x = x_or_z[0]
        n_batch = x.shape[0]
        layer = self.bg_layer
        if not rev:
            bonds, angles, torsions, _loc, _rot, jac_det = layer._forward(x)
            # Zero placeholder for the discarded pose; new_zeros keeps dtype
            # and device identical to the tensors it is concatenated with.
            origin = bonds.new_zeros((n_batch, 6))
            out = torch.cat([bonds, angles, torsions, origin], dim=1)
        else:
            bonds_end = layer.dim_bonds
            angles_end = bonds_end + layer.dim_angles
            torsions_end = angles_end + layer.dim_torsions
            bonds = x[:, :bonds_end]
            angles = x[:, bonds_end:angles_end]
            torsions = x[:, angles_end:torsions_end]
            loc = x.new_zeros((n_batch, 1, 3))
            rot = x.new_zeros((n_batch, 3))
            out, jac_det = layer._inverse(bonds, angles, torsions, loc, rot)
        # FrEIA accumulates one log-det per sample. NOTE(review): bgflow is
        # assumed to return one value per sample, possibly with a trailing
        # singleton axis, which would broadcast when summed with (batch,)
        # log-dets of other modules — confirm against the bgflow version used.
        return (out,), jac_det.reshape(n_batch)

In [12]:
# Size the network from the dataset instead of repeating the literal 66.
inn = Ff.SequenceINN(int(ala_data.dim))
# Extra keyword arguments to `append` are forwarded to the module constructor.
inn.append(ICTransform, system=system)

In [13]:
# `xyz` is already a torch tensor, so it is passed directly (no re-wrapping in
# the legacy `torch.Tensor(...)` constructor). `[0]` keeps the transformed
# output and drops the log-determinant.
ic = inn(xyz)[0]

In [14]:
# Run the network in reverse on the internal coordinates; `[0]` keeps the
# output tensor and drops the log-determinant.
reconstructed = inn(ic, rev=True)[0]

In [15]:
# Visual check of the reconstructed frames. `ShowTraj` is imported in the
# notebook's import cell rather than here, so all dependencies are declared
# in one place.
w = ShowTraj(reconstructed, system)
w



ShowTraj(children=(NGLWidget(max_frame=99), BoundedFloatText(value=-26.55, description='Energy:', max=1e+50, m…

In [16]:
molecule_path = "../data/Molecules/OppA/Peptides/1b4z"

# Read atom and residue counts from the last ATOM/HETATM record of the
# topology. Searching for the record (instead of indexing `lines[-3]`) does not
# depend on how many trailer lines (TER/END/CONECT/...) follow the atoms.
with open(molecule_path.rstrip("/") + "/top.pdb", 'r') as file:
    atom_records = [line for line in file if line.startswith(("ATOM", "HETATM"))]
if not atom_records:
    raise ValueError(f"No ATOM/HETATM records found in {molecule_path}/top.pdb")
lastline = atom_records[-1]
# Fixed-width PDB fields: atom serial in columns 7-11, residue number in 23-26.
# Assumes serials and residue numbers start at 1 and are contiguous, so the
# last values equal the counts — TODO confirm for other input files.
n_atoms = int(lastline[6:11])
n_res = int(lastline[22:26])
print(f"Number of atoms: {n_atoms}, residues: {n_res}")

# define system & energy model
system = peptide(short=False, n_atoms=n_atoms, n_res=n_res, filepath=molecule_path)
system.reinitialize_energy_model(temperature=300., n_workers=1)
energy_model = system.energy_model

# Load the trajectory and keep the first 10 frames.
traj = mdtraj.load_hdf5(molecule_path.rstrip("/") + "/traj.h5")
coordinates = traj.xyz[:10]


Number of atoms: 59, residues: 3


In [17]:
# Three Cartesian components per atom; derived from `n_atoms` (parsed from the
# topology) instead of hard-coding 59 so the cell follows the loaded molecule.
inn = Ff.SequenceINN(3 * n_atoms)
inn.append(ICTransform, system=system)



In [18]:
# Flatten each frame to one row so the input matches the flat width the
# network was declared with (and the layout used for the dataset tensor
# earlier in the notebook). `[0]` keeps the output, dropping the log-det.
ic = inn(torch.Tensor(coordinates).reshape(len(coordinates), -1))[0]

In [19]:
# Run the network in reverse on the internal coordinates; `[0]` keeps the
# output tensor and drops the log-determinant.
reconstructed = inn(ic, rev=True)[0]

In [20]:
# Visual check of the reconstructed peptide frames. `ShowTraj` comes from the
# notebook's import cell; the duplicate in-cell import is dropped.
w = ShowTraj(reconstructed, system)
w

ShowTraj(children=(NGLWidget(max_frame=9), BoundedFloatText(value=-402.54, description='Energy:', max=1e+50, m…