In [20]:
import sys
import os
sys.path.append("/home/kortkamp/helix/code/lightning_bg/src/lightning_bg")
from evaluate import Evaluator, ShowTraj
from architectures import get_network_by_name, BaseHParams
from utils import dataset_setter, Alignment
import bgmol
from bgmol.systems.peptide import peptide
import yaml
import mdtraj

In [21]:
data_path = "./data"
param_path = "./params"
molecule = "/Dialanine" # change this for different molecule 
experiment_name = "RNVPrvkl_latent.yaml" # change this for different experiment
version = 1 # change this for different version

In [22]:
import os
os.getcwd()

'/home/kortkamp/helix/code/lightning_bg'

In [23]:
molecule_path = os.path.join(data_path, "Molecules", molecule.lstrip("/"))
molecule_path

'./data/Molecules/Dialanine'

In [24]:
if "Dialanine" in molecule:
    is_data_here = os.path.exists(molecule_path + "/Ala2TSF300.npy")
    ala_data = bgmol.datasets.Ala2TSF300(download=not is_data_here, read=True, root=molecule_path)
    system = ala_data.system
    n_atoms = 22
else:
    with open(molecule_path.rstrip("/") + "/top.pdb", 'r') as file:
        lines = file.readlines()
        lastline = lines[-3]
        n_atoms = int(lastline[4:11].strip())
        n_res = int(lastline[22:26].strip())
        print(n_atoms, n_res)

    # define system & energy model
    system = peptide(short=False, n_atoms=n_atoms, n_res=n_res, filepath=molecule_path)
system.reinitialize_energy_model(n_workers=1)

Using downloaded and verified file: /tmp/alanine-dipeptide-nowater.pdb


In [25]:
import ipdb

In [26]:
experiment_path = os.path.join(param_path, molecule.lstrip("/"), experiment_name)
with open(experiment_path) as f:  # TODO: this is stupid bc there is a hparams file in the version folder. use it instead!!
    params = yaml.load(f, yaml.FullLoader)  # load the parameters
ModelClass = get_network_by_name(params['network_name'])
ParamClass = ModelClass.hparams_type
params['network_params']['n_dims'] = n_atoms * 3
hparams = ParamClass(**params['network_params'])
checkpoint = os.path.join(data_path, "lightning_logs", molecule.lstrip("/"), experiment_name[:-5], f"version_{version}", "checkpoints/last.ckpt")
# checkpoint = "/home/kortkamp/helix/code/lightning_bg/data/lightning_logs/OppA/Peptides/1b4z/params/OppA/1b4z/RNVPfwkl20000/version_0/checkpoints/last.ckpt"

if "Dialanine" in molecule:
    coordinates = ala_data.coordinates
else:
    traj = mdtraj.load_hdf5(molecule_path + "/traj.h5")
    coordinates = traj.xyz
train_split = params['training_params']['train_split']    
train_data, val_data, test_data = dataset_setter(coordinates, system, val_split=(.8 - train_split), test_split=.2, seed=42)

model = ModelClass.load_from_checkpoint(checkpoint, hparams=hparams, train_data=train_data, val_data=val_data, energy_function=system.energy_model, alignment_penalty=Alignment(system, train_data.reference_molecule).penalty)

In [27]:
x = model.sample((100,))
w = ShowTraj(x, system)

ValueError: coordinate "-180170300.0" could not be represnted in a width-8 field

In [None]:
w

In [30]:
E = Evaluator(model, system)
E.energy_plot(rg=[-2000, 2000])

KeyboardInterrupt: 

In [31]:
system.energy_model.energy(x)

tensor([[1.2727e+09],
        [1.3841e+07],
        [9.6622e+08],
        [7.3186e+08],
        [5.0869e+09],
        [2.5424e+09],
        [1.5464e+09],
        [1.3841e+07],
        [2.3681e+08],
        [2.0037e+08],
        [2.2021e+08],
        [2.7355e+08],
        [1.0900e+09],
        [4.3428e+06],
        [1.3553e+10],
        [6.3093e+08],
        [8.6662e+08],
        [2.4480e+08],
        [4.0838e+07],
        [7.2236e+08],
        [9.3676e+09],
        [2.0408e+09],
        [2.8140e+09],
        [2.2422e+07],
        [4.4661e+08],
        [3.8149e+06],
        [1.0776e+08],
        [3.3178e+07],
        [2.0173e+10],
        [3.2151e+08],
        [8.1632e+06],
        [2.7661e+08],
        [1.0834e+10],
        [1.2778e+09],
        [3.0674e+10],
        [1.9075e+09],
        [7.0283e+09],
        [6.7176e+06],
        [7.5090e+08],
        [2.1132e+09],
        [4.9134e+09],
        [3.3949e+10],
        [7.4719e+08],
        [1.7846e+07],
        [6.3202e+08],
        [2