In [1]:

"""
idp_rl Analysis Example
=============================
View the notebook in Google Colab: https://drive.google.com/drive/folders/1WAnTv4SGwEQHHqyMcbrExzUob_mOfTcM?usp=sharing

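This notebook gives examples of how the analysis functions in idp_rl can be used.
The example data (example_data1.pickle, example_data2.pickle and example_data3.pickle) were all
taken from different evaluation episodes of an actual run.
(Note: the cells below load a molecule config from ../GYDPETGTWG.pkl and roll out a single
episode, rather than reading these example files.)

The full API reference for the analysis module can be found at: TODO (link missing)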
"""
import sys
sys.path.append("..")

import copy
import os
import random
import pickle
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
sns.set_theme(style="darkgrid")
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from idp_rl import utils
from idp_rl.agents import PPORecurrentExternalCurriculumAgent
from idp_rl.config import Config
from idp_rl.environments import Task
from idp_rl.models import RTGNRecurrent
from idp_rl.environments.environment_components.forcefield_mixins import CharMMMixin

from idp_rl.molecule_generation.generate_chignolin import generate_chignolin
from idp_rl.molecule_generation.generate_molecule_config import config_from_rdkit
from idp_rl.utils.misc_utils import to_np

import logging
logging.basicConfig(level=logging.DEBUG)

# One seed for every RNG the notebook touches (Python, NumPy, PyTorch),
# so a fresh-kernel "Run All" is reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
# The trailing ';' suppresses the `<torch._C.Generator ...>` repr that
# torch.manual_seed would otherwise display as this cell's output.
torch.manual_seed(SEED);

  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7fc685eab0f0>

In [3]:
# `os` is already imported in the first cell; no need to re-import it here.
# NOTE(review): MP_RANK is presumably read by idp_rl (e.g. to identify the
# process rank) — TODO confirm which component consumes it.
os.environ["MP_RANK"] = "0"
os.getenv("MP_RANK")

'0'

In [4]:
filename = "../GYDPETGTWG.pkl"
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open(filename, 'rb') as file:
    mol_config = pickle.load(file)

# Fall back to CPU so the notebook still runs on hosts without a GPU
# (a hard-coded "cuda:0" raises there); on GPU hosts this is still cuda:0.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): no checkpoint is loaded, so the rollout below uses randomly
# initialised weights — TODO load trained parameters if a trained policy is intended.
network = RTGNRecurrent(6, 128, edge_dim=6, node_dim=5).to(device)
# The network is only used for inference below; switch any train/eval-dependent
# layers to inference behaviour.
network.eval()

# Single-environment task; `state`, `done`, `rstates` and `info` are consumed by
# the rollout loop in the next cell.
env = Task('GibbsScorePruningEnvCharmm-v0', num_envs=1, mol_config=mol_config)
state = env.reset()
done = False
rstates = None  # recurrent hidden state; None lets the network start fresh
info = None

DEBUG:root:initializing conformer environment
[13:35:22] Molecule does not have explicit Hs. Consider calling AddHs()
DEBUG:root:reset called
DEBUG:root:reset called


In [14]:
# Roll out one episode, threading the recurrent hidden state `rstates` through
# every step. Depends on `state`, `done` and `rstates` from the setup cell; once the
# episode ends `done` stays truthy, so re-running this cell alone does nothing —
# re-run the setup cell to reset the environment first.
while not done:
    with torch.no_grad():
        prediction, rstates = network(state, rstates)
        action = prediction['a']
    state, reward, done, info = env.step(to_np(action))

INFO:root:step 130 reward 0.0
INFO:root:step 131 reward 0.00017485828695318383
INFO:root:step 132 reward 0.0
INFO:root:step 133 reward 0.0
INFO:root:step 134 reward 0.0
INFO:root:step 135 reward 0.0
INFO:root:step 136 reward 0.0
INFO:root:step 137 reward 0.0
INFO:root:step 138 reward 0.0
INFO:root:step 139 reward 0.0
INFO:root:step 140 reward 0.0
INFO:root:step 141 reward 0.0
INFO:root:step 142 reward 0.0
INFO:root:step 143 reward 0.0
INFO:root:step 144 reward 0.0
INFO:root:step 145 reward 0.0
INFO:root:step 146 reward 0.0
INFO:root:step 147 reward 0.0
INFO:root:step 148 reward 0.0
INFO:root:step 149 reward 0.0
INFO:root:step 150 reward 0.0
INFO:root:step 151 reward 0.0
INFO:root:step 152 reward 0.0
INFO:root:step 153 reward 0.0
INFO:root:step 154 reward 0.0
INFO:root:step 155 reward 0.0
INFO:root:step 156 reward 0.0
INFO:root:step 157 reward 0.0
INFO:root:step 158 reward 0.0
INFO:root:step 159 reward 0.0
INFO:root:step 160 reward 0.0
INFO:root:step 161 reward 0.0
INFO:root:step 162 re

In [None]:
from rdkit.Chem import TorsionFingerprints
from rdkit.Chem import rdMolTransforms

def get_gly7_dihedrals(version_mol, phi_idx=24, psi_idx=25):
    """Return the (phi, psi) dihedral pair of one residue for every conformer.

    Torsions are enumerated with RDKit's ``TorsionFingerprints.CalculateTorsionLists``,
    and the non-ring torsions at positions ``phi_idx`` / ``psi_idx`` are measured on
    each conformer of ``version_mol``.

    Parameters
    ----------
    version_mol : rdkit.Chem.Mol
        Molecule holding zero or more conformers.
    phi_idx, psi_idx : int
        Indices into the non-ring torsion list. The defaults (24, 25) are presumably
        the phi/psi of Gly7 in GYDPETGTWG — TODO confirm against the torsion list.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(n_conformers, 2)``: column 0 is phi, column 1 is psi,
        both in degrees. The shape is ``(0, 2)`` when there are no conformers, so
        ``result[:, 0]`` indexing always works.

    Raises
    ------
    IndexError
        If there are fewer than ``max(phi_idx, psi_idx) + 1`` non-ring torsions.
    """
    nonring, _ring = TorsionFingerprints.CalculateTorsionLists(version_mol)
    # RDKit returns each non-ring torsion as (atom-index quadruples, max deviation);
    # keep the quadruples.
    torsions = [nr[0] for nr in nonring]
    # Measure the first atom quadruple of each selected torsion (loop-invariant).
    phi_atoms = torsions[phi_idx][0]
    psi_atoms = torsions[psi_idx][0]

    # Iterate conformers directly: conformer IDs are not guaranteed to be 0..n-1,
    # so GetConformer(i) over range(GetNumConformers()) can fail.
    full_dihedrals = [
        [rdMolTransforms.GetDihedralDeg(conf, *phi_atoms),
         rdMolTransforms.GetDihedralDeg(conf, *psi_atoms)]
        for conf in version_mol.GetConformers()
    ]
    # reshape keeps the result 2-D even when the list is empty.
    return np.array(full_dihedrals, dtype=float).reshape(-1, 2)

In [None]:
# Conformers collected during the rolled-out episode, and their Gly7 (phi, psi) pairs.
baseline = env.envs[0].episode_info["mol"]
baseline_dihedrals = get_gly7_dihedrals(baseline)

# Explicit Axes interface; dihedrals come back in degrees, so fix the full range
# to make plots from different runs directly comparable.
fig, ax = plt.subplots()
ax.scatter(baseline_dihedrals[:, 0], baseline_dihedrals[:, 1])
ax.set(
    title="Gly7 backbone dihedrals of episode conformers",
    xlabel="phi (degrees)",
    ylabel="psi (degrees)",
    xlim=(-180, 180),
    ylim=(-180, 180),
)
plt.show()