Skip to content

Commit

Permalink
Update egnn_pytorch_geometric.py
Browse files Browse the repository at this point in the history
This is due to a naming convention change in torch_geometric v2.3.0, where __attribute__ changed to _attribute. This means attributes like __check_input__ and __user_args__ have been renamed to _check_input and _user_args respectively. More details can be found in this pull request: pyg-team/pytorch_geometric#6999
  • Loading branch information
souramoo authored Nov 17, 2023
1 parent 23e725f commit 3d3e033
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions egnn_pytorch/egnn_pytorch_geometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,8 @@ def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
**kwargs: Any additional data which is needed to construct and
aggregate messages, and to update node embeddings.
"""
size = self.__check_input__(edge_index, size)
coll_dict = self.__collect__(self.__user_args__,
size = self._check_input(edge_index, size)
coll_dict = self._collect(self._user_args,
edge_index, size, kwargs)
msg_kwargs = self.inspector.distribute('message', coll_dict)
aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
Expand Down Expand Up @@ -436,4 +436,4 @@ def forward(self, x, edge_index, batch, edge_attr,
return x

def __repr__(self):
return 'EGNN_Sparse_Network of: {0} layers'.format(len(self.mpnn_layers))
return 'EGNN_Sparse_Network of: {0} layers'.format(len(self.mpnn_layers))

0 comments on commit 3d3e033

Please sign in to comment.