Skip to content

Commit

Permalink
Only warn once on double underscore type name checks (#9268)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed May 2, 2024
1 parent ab8f3fd commit ae185ba
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
9 changes: 8 additions & 1 deletion test/data/test_hetero_data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import warnings

import pytest
import torch
Expand Down Expand Up @@ -563,7 +564,13 @@ def test_hetero_data_invalid_names():
data = HeteroData()
with pytest.warns(UserWarning, match="single underscores"):
data['my test', 'a__b', 'my test'].edge_attr = torch.randn(10, 16)
assert data.edge_types == [('my test', 'a__b', 'my test')]
with warnings.catch_warnings(): # No warning should be raised afterwards:
warnings.simplefilter('error')
data['my test', 'a__c', 'my test'].edge_attr = torch.randn(10, 16)
assert data.edge_types == [
('my test', 'a__b', 'my test'),
('my test', 'a__c', 'my test'),
]


def test_hetero_data_update():
Expand Down
16 changes: 11 additions & 5 deletions torch_geometric/data/hetero_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@

NodeOrEdgeStorage = Union[NodeStorage, EdgeStorage]

_DISPLAYED_TYPE_NAME_WARNING: bool = False


class HeteroData(BaseData, FeatureStore, GraphStore):
r"""A data object describing a heterogeneous graph, holding multiple node
Expand Down Expand Up @@ -562,11 +564,15 @@ def collect(
return mapping

def _check_type_name(self, name: str):
if '__' in name:
warnings.warn(f"The type '{name}' contains double underscores "
f"('__') which may lead to unexpected behavior. "
f"To avoid any issues, ensure that your type names "
f"only contain single underscores.")
global _DISPLAYED_TYPE_NAME_WARNING
if not _DISPLAYED_TYPE_NAME_WARNING and '__' in name:
_DISPLAYED_TYPE_NAME_WARNING = True
warnings.warn(f"There exist type names in the "
f"'{self.__class__.__name__}' object that contain "
f"double underscores '__' (e.g., '{name}'). This "
f"may lead to unexpected behavior. To avoid any "
f"issues, ensure that your type names only contain "
f"single underscores.")

def get_node_store(self, key: NodeType) -> NodeStorage:
r"""Gets the :class:`~torch_geometric.data.storage.NodeStorage` object
Expand Down

0 comments on commit ae185ba

Please sign in to comment.