diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 89b1cfe4c..956f683bd 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -55,34 +55,31 @@ def tensor_norm(tensor): class TensorNet(nn.Module): - r"""TensorNet's architecture. + r"""TensorNet's architecture, from TensorNet: Cartesian Tensor Representations + for Efficient Learning of Molecular Potentials; G. Simeon and G. de Fabritiis. Args: hidden_channels (int, optional): Hidden embedding size. (default: :obj:`128`) num_layers (int, optional): The number of interaction layers. - (default: :obj:`6`) + (default: :obj:`2`) num_rbf (int, optional): The number of radial basis functions :math:`\mu`. - (default: :obj:`50`) + (default: :obj:`32`) rbf_type (string, optional): The type of radial basis function to use. (default: :obj:`"expnorm"`) trainable_rbf (bool, optional): Whether to train RBF parameters with - backpropagation. (default: :obj:`True`) + backpropagation. (default: :obj:`False`) activation (string, optional): The type of activation function to use. (default: :obj:`"silu"`) cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. (default: :obj:`0.0`) cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. - (default: :obj:`5.0`) + (default: :obj:`4.5`) max_z (int, optional): Maximum atomic number. Used for initializing embeddings. - (default: :obj:`100`) + (default: :obj:`128`) max_num_neighbors (int, optional): Maximum number of neighbors to return for a given node/atom when constructing the molecular graph during forward passes. - This attribute is passed to the torch_cluster radius_graph routine keyword - max_num_neighbors, which normally defaults to 32. Users should set this to - higher values if they are using higher upper distance cutoffs and expect more - than 32 neighbors per node/atom. - (default: :obj:`32`) + (default: :obj:`64`) equivariance_invariance_group (string, optional): Group under whose action on input positions internal tensor features will be equivariant and scalar predictions will be invariant. O(3) or SO(3).