In [7]:
import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.transforms import RadiusGraph, NormalizeScale
from typing import Any
from torch_geometric.datasets import MD17
import torch

In [8]:
class NormalizeEnergy(T.BaseTransform):
    """Rescale ``data.energy`` by dividing it by a fixed constant.

    Args:
        scale: Divisor applied to ``data.energy``. The default,
            -406757.5938, is the minimum energy (i.e., the maximum absolute
            value of energy) in the 3 datasets. Because the default is
            negative, the division also flips the sign of every value.

    Raises:
        ValueError: If ``scale`` is zero.
    """

    # Minimum energy (i.e., the maximum absolute value of energy) in the 3 datasets.
    DEFAULT_SCALE = -406757.5938

    def __init__(self, scale: float = DEFAULT_SCALE) -> None:
        super().__init__()
        if scale == 0:
            raise ValueError('scale must be non-zero')
        self.scale = scale

    def __call__(self, data: Any) -> Any:
        """Replace ``data.energy`` with ``data.energy / self.scale`` and return ``data``.

        The attribute is rebound on the object that was passed in; no copy of
        ``data`` is made here.
        """
        data.energy = torch.div(data.energy, self.scale)
        return data

    def __repr__(self) -> str:
        # The default configuration keeps the original repr unchanged.
        # NOTE(review): PyG appears to compare the repr of `pre_transform`
        # with the one stored next to already-processed files -- confirm.
        # Showing a non-default scale lets such a check notice the change.
        if self.scale == self.DEFAULT_SCALE:
            return f'{self.__class__.__name__}()'
        return f'{self.__class__.__name__}(scale={self.scale})'

In [9]:
class NormalizeForce(T.BaseTransform):
    """Rescale ``data.force`` by dividing it by a fixed constant.

    Args:
        scale: Divisor applied to ``data.force``. The default, -406757.5938,
            is the minimum energy (i.e., the maximum absolute value of
            energy) in the 3 datasets, so by default forces are divided by an
            energy-derived constant. Because the default is negative, the
            division also flips the sign of every force component.
            NOTE(review): presumably this keeps force and energy on a common
            scale -- confirm that this is the intent.

    Raises:
        ValueError: If ``scale`` is zero.
    """

    # Minimum energy (i.e., the maximum absolute value of energy) in the 3 datasets.
    DEFAULT_SCALE = -406757.5938

    def __init__(self, scale: float = DEFAULT_SCALE) -> None:
        super().__init__()
        if scale == 0:
            raise ValueError('scale must be non-zero')
        self.scale = scale

    def __call__(self, data: Any) -> Any:
        """Replace ``data.force`` with ``data.force / self.scale`` and return ``data``.

        The attribute is rebound on the object that was passed in; no copy of
        ``data`` is made here.
        """
        data.force = torch.div(data.force, self.scale)
        return data

    def __repr__(self) -> str:
        # The default configuration keeps the original repr unchanged.
        # NOTE(review): PyG appears to compare the repr of `pre_transform`
        # with the one stored next to already-processed files -- confirm.
        # Showing a non-default scale lets such a check notice the change.
        if self.scale == self.DEFAULT_SCALE:
            return f'{self.__class__.__name__}()'
        return f'{self.__class__.__name__}(scale={self.scale})'

In [10]:
# Preprocessing pipeline handed to the datasets as `pre_transform`;
# the transforms are listed in the order they are composed.
pre_transform = T.Compose([
    # 1.8100 is the maximum distance between two connected nodes in QM9
    RadiusGraph(1.8100),
    NormalizeScale(),
    NormalizeForce(),
    NormalizeEnergy(),
])

In [11]:
# Build the three MD17 datasets in a fixed order (benzene, uracil, aspirin).
# Each one uses '<molecule>/' as its root directory and shares `pre_transform`.
benzene_dataset, uracil_dataset, aspirin_dataset = (
    MD17(root=f'{molecule}/', name=molecule, pre_transform=pre_transform)
    for molecule in ('benzene', 'uracil', 'aspirin')
)

Downloading http://quantum-machine.org/gdml/data/npz/md17_benzene2017.npz
Processing...
Done!
Downloading http://quantum-machine.org/gdml/data/npz/md17_uracil.npz
Processing...
Done!
Downloading http://quantum-machine.org/gdml/data/npz/md17_aspirin.npz
Processing...
Done!
