In [1]:
from generate_alkanes import generate_branched_alkane
from torsion_utils import get_torsion_tuples
from energy import get_conformer_energy

import numpy as np
import rdkit.Chem as Chem
import py3Dmol

In [2]:
mol = generate_branched_alkane(14)
Chem.AllChem.EmbedMultipleConfs(mol, numConfs=1)
Chem.rdForceFieldHelpers.MMFFOptimizeMoleculeConfs(mol, nonBondedThresh=10., )

[(1, 20.185011342132288)]

In [3]:
torsion_angles, _ = get_torsion_tuples(mol)

In [4]:
conf = mol.GetConformers()[0]

In [5]:
sample_degs = [30 for _ in range(len(torsion_angles))]

In [6]:
for idx, tors in enumerate(torsion_angles):
    Chem.rdMolTransforms.SetDihedralDeg(conf, *tors, float(sample_degs[idx]))

In [7]:
def drawConformer(mol, confIds=[-1], size=(300, 300), style="stick"):
    """Displays interactive 3-dimensional representation of specified conformer.

    Parameters
    ----------
    mol : RDKit Mol object
        The molecule containing the conformer to be displayed.
    confId : int
        The ID of the conformer to be displayed.
    size : Tuple[int, int]
        The size of the display (width, height).
    style: str
        The drawing style for displaying the molecule. Can be sphere, stick, line, cross, cartoon, and surface.
    """
    view = py3Dmol.view(width=size[0], height=size[1])
    for confId in confIds:
        block = Chem.MolToMolBlock(mol, confId=confId)
        view.addModel(block, 'mol')
    view.setStyle({style : {}})
    view.zoomTo()
    return view

In [8]:
drawConformer(mol, confIds=list(range(1)))

<py3Dmol.view at 0x7f8138dde910>

In [10]:
loss = get_conformer_energy(mol)

In [12]:
# optimizer.zero_grad()
loss.backward()
# nn.utils.clip_grad_norm_(self.network.parameters(), config.gradient_clip)
# self.optimizer.step()

AttributeError: 'float' object has no attribute 'backward'

In [14]:
from RTGN import RTGNGat

In [17]:
model = RTGNGat(6, 128, node_dim=5)

In [18]:
model

RTGNGat(
  (gat): GAT(
    (fc): Linear(in_features=5, out_features=128, bias=True)
    (conv_layers): ModuleList(
      (0): GATConv(128, 128, heads=2)
      (1): GATConv(256, 128, heads=2)
      (2): GATConv(256, 128, heads=2)
      (3): GATConv(256, 128, heads=2)
      (4): GATConv(256, 128, heads=2)
      (5): GATConv(256, 128, heads=1)
    )
  )
  (set2set): Set2Set(128, 256)
  (mlp): Sequential(
    (0): Linear(in_features=256, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=1, bias=True)
  )
)