Skip to content

Commit

Permalink
Created Coulomb prior (#159)
Browse files Browse the repository at this point in the history
* Created Coulomb prior

* Bug fix

* Fixed exception when a sample has no interactions within the cutoff

* Fixed error resuming training

* Fixed torchscript errors

* Fixed hardcoded neighborlist size
  • Loading branch information
peastman committed Jun 7, 2023
1 parent e20876f commit 237b4fe
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 9 deletions.
28 changes: 27 additions & 1 deletion tests/test_priors.py
Expand Up @@ -5,7 +5,7 @@
from torchmdnet import models
from torchmdnet.models.model import create_model, create_prior_models
from torchmdnet.module import LNNP
from torchmdnet.priors import Atomref, D2, ZBL
from torchmdnet.priors import Atomref, D2, ZBL, Coulomb
from torch_scatter import scatter
from utils import load_example_args, create_example_batch, DummyDataset
from os.path import dirname, join
Expand Down Expand Up @@ -63,6 +63,32 @@ def compute_interaction(pos1, pos2, z1, z2):
expected += compute_interaction(pos[i], pos[j], atomic_number[types[i]], atomic_number[types[j]])
torch.testing.assert_allclose(expected, energy)

def test_coulomb():
pos = torch.tensor([[0.5, 0.0, 0.0], [1.5, 0.0, 0.0], [0.8, 0.8, 0.0], [0.0, 0.0, -0.4]], dtype=torch.float32) # Atom positions in nm
charge = torch.tensor([0.2, -0.1, 0.8, -0.9], dtype=torch.float32) # Partial charges
types = torch.tensor([0, 1, 2, 1], dtype=torch.long) # Atom types
distance_scale = 1e-9 # Convert nm to meters
energy_scale = 1000.0/6.02214076e23 # Convert kJ/mol to Joules
alpha = 1.8

# Use the Coulomb class to compute the energy.

coulomb = Coulomb(alpha, 5, distance_scale=distance_scale, energy_scale=energy_scale)
energy = coulomb.post_reduce(torch.zeros((1,)), types, pos, torch.zeros_like(types), {'partial_charges':charge})[0]

# Compare to the expected value.

def compute_interaction(pos1, pos2, z1, z2):
delta = pos1-pos2
r = torch.sqrt(torch.dot(delta, delta))
return torch.erf(alpha*r)*138.935*z1*z2/r

expected = 0
for i in range(len(pos)):
for j in range(i):
expected += compute_interaction(pos[i], pos[j], charge[i], charge[j])
torch.testing.assert_allclose(expected, energy)

def test_multiple_priors():
# Create a model from a config file.

Expand Down
3 changes: 2 additions & 1 deletion torchmdnet/priors/__init__.py
@@ -1,5 +1,6 @@
from torchmdnet.priors.atomref import Atomref
from torchmdnet.priors.d2 import D2
from torchmdnet.priors.zbl import ZBL
from torchmdnet.priors.coulomb import Coulomb

__all__ = ["Atomref", "D2", "ZBL"]
__all__ = ["Atomref", "D2", "ZBL", "Coulomb"]
7 changes: 4 additions & 3 deletions torchmdnet/priors/base.py
@@ -1,4 +1,5 @@
from torch import nn
from torch import nn, Tensor
from typing import Optional, Dict


class BasePrior(nn.Module):
Expand All @@ -18,7 +19,7 @@ def get_init_args(self):
"""
return {}

def pre_reduce(self, x, z, pos, batch, extra_args):
def pre_reduce(self, x, z, pos, batch, extra_args: Optional[Dict[str, Tensor]]):
r"""Pre-reduce method of the prior model.
Args:
Expand All @@ -33,7 +34,7 @@ def pre_reduce(self, x, z, pos, batch, extra_args):
"""
return x

def post_reduce(self, y, z, pos, batch, extra_args):
def post_reduce(self, y, z, pos, batch, extra_args: Optional[Dict[str, Tensor]]):
r"""Post-reduce method of the prior model.
Args:
Expand Down
50 changes: 50 additions & 0 deletions torchmdnet/priors/coulomb.py
@@ -0,0 +1,50 @@
import torch
from torchmdnet.priors.base import BasePrior
from torchmdnet.models.utils import Distance
from torch_scatter import scatter
from typing import Optional, Dict

class Coulomb(BasePrior):
"""This class implements a Coulomb potential, scaled by erf(alpha*r) to reduce its
effect at short distances.
To use this prior, the Dataset must include a field called `partial_charges` with each sample,
containing the partial charge for each atom. It also must provide the following attributes.
distance_scale: multiply by this factor to convert coordinates stored in the dataset to meters
energy_scale: multiply by this factor to convert energies stored in the dataset to Joules (*not* J/mol)
"""
def __init__(self, alpha, max_num_neighbors, distance_scale=None, energy_scale=None, dataset=None):
super(Coulomb, self).__init__()
if distance_scale is None:
distance_scale = dataset.distance_scale
if energy_scale is None:
energy_scale = dataset.energy_scale
self.distance = Distance(0, torch.inf, max_num_neighbors=max_num_neighbors)
self.alpha = alpha
self.max_num_neighbors = max_num_neighbors
self.distance_scale = float(distance_scale)
self.energy_scale = float(energy_scale)

def get_init_args(self):
return {'alpha': self.alpha,
'max_num_neighbors': self.max_num_neighbors,
'distance_scale': self.distance_scale,
'energy_scale': self.energy_scale}

def reset_parameters(self):
pass

def post_reduce(self, y, z, pos, batch, extra_args: Optional[Dict[str, torch.Tensor]]):
# Convert to nm and calculate distance.
x = 1e9*self.distance_scale*pos
alpha = self.alpha/(1e9*self.distance_scale)
edge_index, distance, _ = self.distance(x, batch)

# Compute the energy, converting to the dataset's units. Multiply by 0.5 because every atom pair
# appears twice.
q = extra_args['partial_charges'][edge_index]
energy = torch.erf(alpha*distance)*q[0]*q[1]/distance
energy = 0.5*(2.30707e-28/self.energy_scale/self.distance_scale)*scatter(energy, batch[edge_index[0]], dim=0, reduce="sum")
energy = energy.reshape(y.shape)
return y + energy
17 changes: 13 additions & 4 deletions torchmdnet/priors/zbl.py
@@ -1,6 +1,8 @@
import torch
from torchmdnet.priors.base import BasePrior
from torchmdnet.models.utils import Distance, CosineCutoff
from torch_scatter import scatter
from typing import Optional, Dict

class ZBL(BasePrior):
"""This class implements the Ziegler-Biersack-Littmark (ZBL) potential for screened nuclear repulsion.
Expand Down Expand Up @@ -28,8 +30,8 @@ def __init__(self, cutoff_distance, max_num_neighbors, atomic_number=None, dista
self.cutoff = CosineCutoff(cutoff_upper=cutoff_distance)
self.cutoff_distance = cutoff_distance
self.max_num_neighbors = max_num_neighbors
self.distance_scale = distance_scale
self.energy_scale = energy_scale
self.distance_scale = float(distance_scale)
self.energy_scale = float(energy_scale)

def get_init_args(self):
return {'cutoff_distance': self.cutoff_distance,
Expand All @@ -41,8 +43,10 @@ def get_init_args(self):
def reset_parameters(self):
pass

def post_reduce(self, y, z, pos, batch, extra_args):
def post_reduce(self, y, z, pos, batch, extra_args: Optional[Dict[str, torch.Tensor]]):
edge_index, distance, _ = self.distance(pos, batch)
if edge_index.shape[1] == 0:
return y
atomic_number = self.atomic_number[z[edge_index]]
# 5.29e-11 is the Bohr radius in meters. All other numbers are magic constants from the ZBL potential.
a = 0.8854*5.29177210903e-11/(atomic_number[0]**0.23 + atomic_number[1]**0.23)
Expand All @@ -51,4 +55,9 @@ def post_reduce(self, y, z, pos, batch, extra_args):
f *= self.cutoff(distance)
# Compute the energy, converting to the dataset's units. Multiply by 0.5 because every atom pair
# appears twice.
return y + 0.5*(2.30707755e-28/self.energy_scale/self.distance_scale)*torch.sum(f*atomic_number[0]*atomic_number[1]/distance, dim=-1)
energy = f*atomic_number[0]*atomic_number[1]/distance
energy = 0.5*(2.30707755e-28/self.energy_scale/self.distance_scale)*scatter(energy, batch[edge_index[0]], dim=0, reduce="sum")
if energy.shape[0] < y.shape[0]:
energy = torch.nn.functional.pad(energy, (0, y.shape[0]-energy.shape[0]))
energy = energy.reshape(y.shape)
return y + energy

0 comments on commit 237b4fe

Please sign in to comment.