Skip to content
Permalink
Browse files

rename cgcnn_conv -> cg_conv, cleanup code

  • Loading branch information...
rusty1s committed Oct 15, 2019
1 parent 8aae2e5 commit 096754258a0f7ee3766a93bba010b9f6ab20d0f0
@@ -61,7 +61,7 @@ In detail, the following methods are currently implemented:
* **[GCNConv](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GCNConv)** from Kipf and Welling: [Semi-Supervised Classification with Graph Convolutional Networks](https://arxiv.org/abs/1609.02907) (ICLR 2017)
* **[ChebConv](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.ChebConv)** from Defferrard *et al.*: [Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering](https://arxiv.org/abs/1606.09375) (NIPS 2016)
* **[NNConv](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.NNConv)** from Gilmer *et al.*: [Neural Message Passing for Quantum Chemistry](https://arxiv.org/abs/1704.01212) (ICML 2017)
* **[CGCNNConv](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.CGCNNConv)** from Xie and Grossman: [Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.120.145301) (Physical Review Letters 120, 2018)
* **[CGConv](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.CGConv)** from Xie and Grossman: [Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.120.145301) (Physical Review Letters 120, 2018)
* **[ECConv](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.ECConv)** from Simonovsky and Komodakis: [Edge-Conditioned Convolution on Graphs](https://arxiv.org/abs/1704.02901) (CVPR 2017)
* **[GATConv](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GATConv)** from Veličković *et al.*: [Graph Attention Networks](https://arxiv.org/abs/1710.10903) (ICLR 2018)
* **[SAGEConv](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.SAGEConv)** from Hamilton *et al.*: [Inductive Representation Learning on Large Graphs](https://arxiv.org/abs/1706.02216) (NIPS 2017)
@@ -0,0 +1,13 @@
import torch
from torch_geometric.nn import CGConv


def test_cg_conv():
edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
num_nodes = edge_index.max().item() + 1
x = torch.randn((num_nodes, 16))
pseudo = torch.rand((edge_index.size(1), 3))

conv = CGConv(16, 3)
assert conv.__repr__() == 'CGConv(16, 16, dim=3)'
assert conv(x, edge_index, pseudo).size() == (num_nodes, 16)

This file was deleted.

@@ -11,6 +11,6 @@ def test_spline_conv():
pseudo = torch.rand((edge_index.size(1), 3))

conv = SplineConv(in_channels, out_channels, dim=3, kernel_size=5)
assert conv.__repr__() == 'SplineConv(16, 32)'
assert conv.__repr__() == 'SplineConv(16, 32, dim=3)'
with torch_geometric.debug():
assert conv(x, edge_index, pseudo).size() == (num_nodes, out_channels)
@@ -18,7 +18,7 @@
from .gmm_conv import GMMConv
from .spline_conv import SplineConv
from .nn_conv import NNConv, ECConv
from .cgcnn_conv import CGCNNConv
from .cg_conv import CGConv
from .edge_conv import EdgeConv, DynamicEdgeConv
from .x_conv import XConv
from .ppf_conv import PPFConv
@@ -47,7 +47,7 @@
'SplineConv',
'NNConv',
'ECConv',
'CGCNNConv',
'CGConv',
'EdgeConv',
'DynamicEdgeConv',
'XConv',
@@ -0,0 +1,66 @@
import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing


class CGConv(MessagePassing):
r"""The crystal graph convolutional operator from the
`"Crystal Graph Convolutional Neural Networks for an
Accurate and Interpretable Prediction of Material Properties"
<https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.120.145301>`_
paper
.. math::
\mathbf{x}^{\prime}_i = \mathbf{x}_i +
\sum_{j \in \mathcal{N}(i)} \sigma \left(
\sigma \mathbf{z}_{i,j} \mathbf{W}_f + \mathbf{b}_f \right) \odot
g \left( \mathbf{z}_{i,j} \mathbf{W}_s + \mathbf{b}_s \right)
where :math:`\mathbf{z}_{i,j} = \[ \mathbf{x}_i, \mathbf{x}_j,
\mathbf{e}_{i,j} \]` denotes the concatenation of central node features,
neighboring node features and edge features.
In addition, :math:`\sigma` and :math:`g` denote the sigmoid and softplus
functions, respectively.
Args:
channels (int): Size of each input sample.
dim (int): Edge feature dimensionality.
aggr (string, optional): The aggregation operator to use
(:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
(default: :obj:`"mean"`)
bias (bool, optional): If set to :obj:`False`, the layer will not learn
an additive bias. (default: :obj:`True`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MessagePassing`.
"""
def __init__(self, channels, dim, aggr='add', bias=True, **kwargs):
super(CGConv, self).__init__(aggr=aggr, **kwargs)
self.in_channels = channels
self.out_channels = channels
self.dim = dim

self.lin_f = Linear(2 * channels + dim, channels, bias=bias)
self.lin_s = Linear(2 * channels + dim, channels, bias=bias)

self.reset_parameters()

def reset_parameters(self):
self.lin_f.reset_parameters()
self.lin_s.reset_parameters()

def forward(self, x, edge_index, edge_attr):
""""""
return self.propagate(edge_index, x=x, edge_attr=edge_attr)

def message(self, x_i, x_j, edge_attr):
z = torch.cat([x_i, x_j, edge_attr], dim=-1)
return self.lin_f(z).sigmoid() * F.softplus(self.lin_s(z))

def update(self, aggr_out, x):
return aggr_out + x

def __repr__(self):
return '{}({}, {}, dim={})'.format(self.__class__.__name__,
self.in_channels, self.out_channels,
self.dim)

This file was deleted.

@@ -53,7 +53,6 @@ class SplineConv(MessagePassing):
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MessagePassing`.
"""

def __init__(self, in_channels, out_channels, dim, kernel_size,
is_open_spline=True, degree=1, aggr='mean', root_weight=True,
bias=True, **kwargs):
@@ -145,5 +144,6 @@ def update(self, aggr_out, x):
return aggr_out

def __repr__(self):
return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
self.out_channels)
return '{}({}, {}, dim={})'.format(self.__class__.__name__,
self.in_channels, self.out_channels,
self.dim)

0 comments on commit 0967542

Please sign in to comment.
You can’t perform that action at this time.