In [10]:
import math
import os

import dgl
import torch

import espaloma as esp

Using backend: pytorch


In [11]:
# Hidden width shared by the representation layers and both Janossy readouts;
# the readouts' in_features must equal the representation's output width.
N_HIDDEN = 128

# Graph representation: a SAGEConv-based stack. The config alternates a layer
# width with an activation name, repeated three times.
layer = esp.nn.layers.dgl_legacy.gn("SAGEConv")
representation = esp.nn.Sequential(layer, config=[N_HIDDEN, "relu"] * 3)

# MLP config used inside both Janossy pooling readouts (width/activation pairs, x4).
janossy_config = [N_HIDDEN, "relu"] * 4

# Janossy pooling heads. Keys of out_features are presumably the term orders
# (2 -> n2, 3 -> n3, 4 -> n4; confirm against espaloma): two log-coefficients
# each for orders 2 and 3, and six 'k' values for order 4.
readout = esp.nn.readout.janossy.JanossyPooling(
    in_features=N_HIDDEN,
    config=janossy_config,
    out_features={
        2: {'log_coefficients': 2},
        3: {'log_coefficients': 2},
        4: {'k': 6},
    },
)

# Separate readout for improper terms, sharing the same MLP config.
readout_improper = esp.nn.readout.janossy.JanossyPoolingImproper(
    in_features=N_HIDDEN,
    config=janossy_config,
)

class ExpCoeff(torch.nn.Module):
    """Convert predicted log-coefficients on the n2 and n3 nodes into ``k`` / ``eq``.

    For each node type, ``log_coefficients`` is exponentiated into a strictly
    positive ``coefficients`` tensor. Its two columns are then passed, together
    with that node type's pair of basis values, to
    ``esp.mm.functional.linear_mixture_to_original``. The two returned tensors
    are stored as ``k`` and ``eq`` on the same nodes.

    Parameters
    ----------
    n2_bounds : tuple of float, default (1.5, 6.0)
        Basis values for the n2 mixture.
        NOTE(review): presumably equilibrium bond lengths in espaloma's internal
        distance unit; confirm.
    n3_bounds : tuple of float, default (0.0, math.pi)
        Basis values for the n3 mixture (presumably angles in radians).

    The defaults reproduce the previously hard-coded values. Because these are
    plain attributes (not parameters or buffers), the module's state_dict is
    unaffected, and existing checkpoints keep loading.
    """

    def __init__(self, n2_bounds=(1.5, 6.0), n3_bounds=(0.0, math.pi)):
        super().__init__()
        # Insertion order fixes evaluation order: n2 first, then n3.
        self.mixture_bounds = {"n2": tuple(n2_bounds), "n3": tuple(n3_bounds)}

    def forward(self, g):
        """Write ``coefficients``, ``k`` and ``eq`` onto the n2/n3 nodes of ``g``.

        Returns the same graph object so the module can sit in ``torch.nn.Sequential``.
        """
        for node_type, (low, high) in self.mixture_bounds.items():
            data = g.nodes[node_type].data
            coefficients = data['log_coefficients'].exp()
            data['coefficients'] = coefficients
            # [:, i][:, None] keeps each coefficient column 2-D (one column each).
            data['k'], data['eq'] = esp.mm.functional.linear_mixture_to_original(
                coefficients[:, 0][:, None],
                coefficients[:, 1][:, None],
                low, high,
            )
        return g

class GetLoss(torch.nn.Module):
    """Mean-squared error between mean-centered ``u`` and ``u_ref`` on the ``'g'`` nodes.

    Each tensor is shifted by its own mean before comparison, so a constant offset
    between prediction and reference does not contribute to the loss.

    NOTE(review): ``.mean()`` reduces over every element. For a batched graph, all
    molecules therefore share one offset. If per-molecule centering is intended, the
    mean should probably be taken over the last dimension; confirm the tensor layout.
    """

    def forward(self, g):
        graph_data = g.nodes['g'].data
        predicted = graph_data['u']
        reference = graph_data['u_ref']
        centered_predicted = predicted - predicted.mean()
        centered_reference = reference - reference.mean()
        # Default reduction='mean', same as torch.nn.MSELoss().
        return torch.nn.functional.mse_loss(centered_predicted, centered_reference)

# Parameter-prediction pipeline, applied in order to a graph. Each stage's position
# becomes its state_dict key prefix ("0.", "1.", ...), so the order must match any
# saved checkpoint. Geometry and energy evaluation (esp.mm.geometry.GeometryInGraph,
# esp.mm.energy.EnergyInGraph) are not included, so the pipeline stops once k/eq
# have been written to the graph.
net = torch.nn.Sequential(
    representation,    # SAGEConv-based graph representation
    readout,           # Janossy pooling heads for orders 2, 3 and 4
    readout_improper,  # Janossy pooling head for improper terms
    ExpCoeff(),        # log-coefficients -> coefficients -> k / eq on n2 and n3
)

In [12]:
# Checkpoint location. Set the ESPALOMA_CHECKPOINT environment variable to point
# elsewhere instead of editing a user-specific absolute path; the original path
# remains the fallback.
CHECKPOINT_PATH = os.environ.get(
    "ESPALOMA_CHECKPOINT", "/Users/wangy1/Downloads/net2810.th"
)

# map_location="cpu" remaps stored tensors to CPU, so GPU-saved weights load anywhere.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
# (torch >= 1.13 also accepts weights_only=True for plain state dicts).
net.load_state_dict(
    torch.load(
        CHECKPOINT_PATH,
        map_location="cpu",
    )
)

<All keys matched successfully>

In [24]:
from rdkit import Chem
from openff.toolkit.topology import Molecule

LIGAND_SDF = "ligand.sdf"
PROTEIN_PDB = "protein.pdb"

# RDKit readers signal parse failures by yielding/returning None rather than raising.
# An empty SDF would also make a bare next() raise StopIteration. Check explicitly so
# the failure names the file, instead of surfacing deep inside Molecule.from_rdkit.
ligand_rdmol = next(iter(Chem.SDMolSupplier(LIGAND_SDF)), None)
if ligand_rdmol is None:
    raise ValueError(f"No readable molecule in the first record of {LIGAND_SDF!r}")
ligand = esp.Graph(Molecule.from_rdkit(ligand_rdmol))

protein_rdmol = Chem.MolFromPDBFile(PROTEIN_PDB)
if protein_rdmol is None:
    raise ValueError(f"RDKit could not parse {PROTEIN_PDB!r}")
# NOTE(review): in the recorded run this cell was stopped by hand (KeyboardInterrupt)
# after the undefined-stereo atom listing. That suggests building the graph for the
# whole protein is very slow; consider a pocket/fragment. Confirm where the time goes.
protein = esp.Graph(
    Molecule.from_rdkit(protein_rdmol, allow_undefined_stereo=True)
)

 - Atom C (index 0)
 - Atom C (index 282)
 - Atom C (index 285)
 - Atom C (index 289)



KeyboardInterrupt: 