Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added modifed version of PMLP with edge attributes #8898

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 178 additions & 1 deletion torch_geometric/nn/models/pmlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import BatchNorm1d

from torch_geometric.nn import SimpleConv
from torch_geometric.nn import MessagePassing, SimpleConv
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.utils import add_self_loops


class PMLP(torch.nn.Module):
Expand Down Expand Up @@ -99,3 +101,178 @@
def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.in_channels}, '
f'{self.out_channels}, num_layers={self.num_layers})')


class EdgeConv(MessagePassing):
"""A message passing layer that uses edge features for convolution.

This layer extends PyTorch Geometric's `MessagePassing` class to perform
convolutions by incorporating edge attributes into the message passing process.
It aggregates messages using the mean aggregation function.

Args:
in_channels (int): Size of each input sample's features.
out_channels (int): Size of each output sample's features.
bias (bool, optional): If set to `False`, the layer will not learn an additive bias.
Default is `True`.

Methods:
forward(x, edge_index, edge_attr): Performs the forward pass of the layer.
message(x_j, edge_attr): Constructs messages to node i for each edge (j, i) in `edge_index`.
"""
def __init__(self, in_channels, out_channels, bias=True):
super(EdgeConv, self).__init__(aggr='mean') # Mean aggregation.
self.lin = Linear(in_channels + in_channels, out_channels, bias=bias)

Check warning on line 125 in torch_geometric/nn/models/pmlp.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/models/pmlp.py#L124-L125

Added lines #L124 - L125 were not covered by tests

def forward(self, x, edge_index, edge_attr):
"""Forward pass of the EdgeConv layer.

Args:
x (Tensor): Node feature matrix with shape [num_nodes, in_channels].
edge_index (LongTensor): Graph connectivity in COO format with shape [2, num_edges].
edge_attr (Tensor): Edge feature matrix with shape [num_edges, in_channels].

Returns:
Tensor: Updated node features with shape [num_nodes, out_channels].
"""
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
return self.propagate(edge_index, x=x, edge_attr=edge_attr)

Check warning on line 139 in torch_geometric/nn/models/pmlp.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/models/pmlp.py#L138-L139

Added lines #L138 - L139 were not covered by tests

def message(self, x_j, edge_attr):
"""Constructs messages for each node based on edge attributes.

Args:
x_j (Tensor): Incoming features for each edge with shape [num_edges, in_channels].
edge_attr (Tensor): Edge features for each edge with shape [num_edges, in_channels].

Returns:
Tensor: Messages for each node with shape [num_edges, out_channels].
"""
tmp = torch.cat([x_j, edge_attr], dim=1)
return self.lin(tmp)

Check warning on line 152 in torch_geometric/nn/models/pmlp.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/models/pmlp.py#L151-L152

Added lines #L151 - L152 were not covered by tests

def __repr__(self) -> str:
"""Creates a string representation of this EdgeConv instance, showing the
key configurations of the layer.

Returns:
str: A string representation including the class name and key layer
configurations such as input and output channels.
"""
return (f'{self.__class__.__name__}('

Check warning on line 162 in torch_geometric/nn/models/pmlp.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/models/pmlp.py#L162

Added line #L162 was not covered by tests
f'in_channels={self.in_channels}, '
f'out_channels={self.out_channels}, '
f'bias={self.lin.bias is not None})')


class PMLP_with_EdgeAttr(torch.nn.Module):
"""Propagational MLP with Edge Attributes (PMLP_with_EdgeAttr) model.

This model extends an MLP to incorporate edge attributes in its message passing mechanism,
using EdgeConv layers for convolution operations.

Args:
in_channels (int): Size of each input sample.
hidden_channels (int): Size of each hidden layer sample.
out_channels (int): Size of each output sample.
num_layers (int): Number of layers in the MLP.
edge_attr_dim (int): Dimensionality of edge attributes.
dropout (float, optional): Dropout probability. Default is 0.0.
norm (bool, optional): If `True`, applies batch normalization. Default is `True`.
bias (bool, optional): If `True`, layers will learn an additive bias. Default is `True`.

Methods:
forward(x, edge_index, edge_attr): Performs the forward pass of the model.
"""
def __init__(
self,
in_channels: int,
hidden_channels: int,
out_channels: int,
num_layers: int,
edge_attr_dim: int,
dropout: float = 0.0,
norm: bool = True,
bias: bool = True,
):
super(PMLP_with_EdgeAttr, self).__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.num_layers = num_layers
self.dropout = dropout
self.bias = bias

Check warning on line 204 in torch_geometric/nn/models/pmlp.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/models/pmlp.py#L198-L204

Added lines #L198 - L204 were not covered by tests

self.lins = torch.nn.ModuleList()
self.lins.append(Linear(in_channels, hidden_channels, bias))
for _ in range(num_layers - 2):
self.lins.append(Linear(hidden_channels, hidden_channels, bias))
self.lins.append(Linear(hidden_channels, out_channels, bias))

Check warning on line 210 in torch_geometric/nn/models/pmlp.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/models/pmlp.py#L206-L210

Added lines #L206 - L210 were not covered by tests

self.norm = None
if norm:
self.norm = BatchNorm1d(hidden_channels, affine=True,

Check warning on line 214 in torch_geometric/nn/models/pmlp.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/models/pmlp.py#L212-L214

Added lines #L212 - L214 were not covered by tests
track_running_stats=True)

self.conv = EdgeConv(hidden_channels + edge_attr_dim, hidden_channels,

Check warning on line 217 in torch_geometric/nn/models/pmlp.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/models/pmlp.py#L217

Added line #L217 was not covered by tests
bias=bias)
self.reset_parameters()

Check warning on line 219 in torch_geometric/nn/models/pmlp.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/models/pmlp.py#L219

Added line #L219 was not covered by tests

def reset_parameters(self):
"""Initializes or resets the model parameters."""
for lin in self.lins:
torch.nn.init.xavier_uniform_(

Check warning on line 224 in torch_geometric/nn/models/pmlp.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/models/pmlp.py#L223-L224

Added lines #L223 - L224 were not covered by tests
lin.weight, gain=torch.nn.init.calculate_gain('relu'))
if self.bias:
torch.nn.init.zeros_(lin.bias)
self.conv.reset_parameters()

Check warning on line 228 in torch_geometric/nn/models/pmlp.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/models/pmlp.py#L226-L228

Added lines #L226 - L228 were not covered by tests

def forward(self, x, edge_index, edge_attr):
"""Forward pass of the PMLP_with_EdgeAttr model.

Args:
x (Tensor): Node feature matrix with shape [num_nodes, in_channels].
edge_index (LongTensor): Graph connectivity in COO format with shape [2, num_edges].
edge_attr (Tensor, optional): Edge feature matrix with shape [num_edges, edge_attr_dim].

Returns:
Tensor: Updated node features with shape [num_nodes, out_channels].
"""
if not self.training and edge_index is None:
raise ValueError(

Check warning on line 242 in torch_geometric/nn/models/pmlp.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/models/pmlp.py#L241-L242

Added lines #L241 - L242 were not covered by tests
f"'edge_index' needs to be present during inference in '{self.__class__.__name__}'"
)

for i in range(self.num_layers):
x = x @ self.lins[i].weight.t()
if not self.training:
x = self.conv(x, edge_index, edge_attr)
if self.bias:
x = x + self.lins[i].bias
if i != self.num_layers - 1:
if self.norm is not None:
x = self.norm(x)
x = F.tanh(x)
x = F.dropout(x, p=self.dropout, training=self.training)

Check warning on line 256 in torch_geometric/nn/models/pmlp.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/models/pmlp.py#L246-L256

Added lines #L246 - L256 were not covered by tests

return x

Check warning on line 258 in torch_geometric/nn/models/pmlp.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/models/pmlp.py#L258

Added line #L258 was not covered by tests

def __repr__(self) -> str:
"""Creates a string representation of this PMLP_with_EdgeAttr instance,
showing the key configurations of the model.

Returns:
str: A string representation including the class name and key model
configurations such as input channels, hidden channels, output
channels, and the number of layers.
"""
return (

Check warning on line 269 in torch_geometric/nn/models/pmlp.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/models/pmlp.py#L269

Added line #L269 was not covered by tests
f'{self.__class__.__name__}('
f'in_channels={self.in_channels}, '
f'hidden_channels={self.hidden_channels}, '
f'out_channels={self.out_channels}, '
f'num_layers={self.num_layers}, '
f'edge_attr_dim={self.conv.in_channels - self.hidden_channels}, '
f'dropout={self.dropout}, '
f'norm={self.norm is not None}, '
f'bias={self.bias})')
Loading