Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Drop usage of __dunder__ names #6999

Merged
merged 2 commits into from
Mar 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Drop internal usage of `__dunder__` names ([#6999](https://github.com/pyg-team/pytorch_geometric/issues/6999))
- Changed the interface of `sort_edge_index`, `coalesce` and `to_undirected` to only return single `edge_index` information in case the `edge_attr` argument is not specified ([#6875](https://github.com/pyg-team/pytorch_geometric/issues/6875), [#6879](https://github.com/pyg-team/pytorch_geometric/issues/6879), [#6893](https://github.com/pyg-team/pytorch_geometric/issues/6893))
- Fixed a bug in `to_hetero` when using an uninitialized submodule without implementing `reset_parameters` ([#6863](https://github.com/pyg-team/pytorch_geometric/issues/6790))
- Fixed a bug in `get_mesh_laplacian` ([#6790](https://github.com/pyg-team/pytorch_geometric/issues/6790))
Expand Down
14 changes: 7 additions & 7 deletions torch_geometric/contrib/explain/graphmask_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,15 +315,15 @@ def freeze_model(self, module):
for param in module.parameters():
param.requires_grad = False

def __set_flags__(self, model):
def _set_flags(self, model):
for module in model.modules():
if isinstance(module, MessagePassing):
module.explain_message = explain_message.__get__(
module, MessagePassing)
module.explain = True

def __inject_messages__(self, model: torch.nn.Module, message_scale,
message_replacement, set=False):
def _inject_messages(self, model: torch.nn.Module, message_scale,
message_replacement, set=False):
i = 0
for module in model.modules():
if isinstance(module, MessagePassing):
Expand All @@ -345,7 +345,7 @@ def train_explainer(self, model: torch.nn.Module, x: Tensor,
"'integer' or set to 'None' instead.")

self.freeze_model(model)
self.__set_flags__(model)
self._set_flags(model)

input_dims, output_dims = [], []
for module in model.modules():
Expand Down Expand Up @@ -406,7 +406,7 @@ def train_explainer(self, model: torch.nn.Module, x: Tensor,
gates.append(sampling_weights)
total_penalty += penalty

self.__inject_messages__(model, gates, self.baselines)
self._inject_messages(model, gates, self.baselines)

self.lambda_op = torch.tensor(self.init_lambda,
requires_grad=True)
Expand All @@ -425,7 +425,7 @@ def train_explainer(self, model: torch.nn.Module, x: Tensor,
if index is not None:
y_hat, y = y_hat[index], y[index]

self.__inject_messages__(model, gates, self.baselines, True)
self._inject_messages(model, gates, self.baselines, True)

loss = self._loss(y_hat, y, total_penalty)

Expand Down Expand Up @@ -456,7 +456,7 @@ def explain(self, model: torch.nn.Module, *,
"'integer' or set to 'None' instead.")

self.freeze_model(model)
self.__set_flags__(model)
self._set_flags(model)

with torch.no_grad():
latest_source_embeddings, latest_messages = [], []
Expand Down
8 changes: 4 additions & 4 deletions torch_geometric/loader/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ def __init__(self, data, num_parts: int, recursive: bool = False,
if log: # pragma: no cover
print('Done!', file=sys.stderr)

self.data = self.__permute_data__(data, perm, adj)
self.data = self._permute_data(data, perm, adj)
self.partptr = partptr
self.perm = perm

def __permute_data__(self, data, node_idx, adj):
def _permute_data(self, data, node_idx, adj):
out = copy.copy(data)
for key, value in data.items():
if data.is_node_attr(key):
Expand Down Expand Up @@ -138,10 +138,10 @@ class ClusterLoader(torch.utils.data.DataLoader):
def __init__(self, cluster_data, **kwargs):
self.cluster_data = cluster_data

super().__init__(range(len(cluster_data)), collate_fn=self.__collate__,
super().__init__(range(len(cluster_data)), collate_fn=self._collate,
**kwargs)

def __collate__(self, batch):
def _collate(self, batch):
if not isinstance(batch, torch.Tensor):
batch = torch.tensor(batch)

Expand Down
26 changes: 13 additions & 13 deletions torch_geometric/loader/graph_saint.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self, data, batch_size: int, num_steps: int = 1,
assert not data.edge_index.is_cuda

self.num_steps = num_steps
self.__batch_size__ = batch_size
self._batch_size = batch_size
self.sample_coverage = sample_coverage
self.save_dir = save_dir
self.log = log
Expand All @@ -71,34 +71,34 @@ def __init__(self, data, batch_size: int, num_steps: int = 1,

self.data = data

super().__init__(self, batch_size=1, collate_fn=self.__collate__,
super().__init__(self, batch_size=1, collate_fn=self._collate,
**kwargs)

if self.sample_coverage > 0:
path = osp.join(save_dir or '', self.__filename__)
path = osp.join(save_dir or '', self._filename)
if save_dir is not None and osp.exists(path): # pragma: no cover
self.node_norm, self.edge_norm = torch.load(path)
else:
self.node_norm, self.edge_norm = self.__compute_norm__()
self.node_norm, self.edge_norm = self._compute_norm()
if save_dir is not None: # pragma: no cover
torch.save((self.node_norm, self.edge_norm), path)

@property
def __filename__(self):
def _filename(self):
return f'{self.__class__.__name__.lower()}_{self.sample_coverage}.pt'

def __len__(self):
return self.num_steps

def __sample_nodes__(self, batch_size):
def _sample_nodes(self, batch_size):
raise NotImplementedError

def __getitem__(self, idx):
node_idx = self.__sample_nodes__(self.__batch_size__).unique()
node_idx = self._sample_nodes(self._batch_size).unique()
adj, _ = self.adj.saint_subgraph(node_idx)
return node_idx, adj

def __collate__(self, data_list):
def _collate(self, data_list):
assert len(data_list) == 1
node_idx, adj = data_list[0]

Expand All @@ -123,7 +123,7 @@ def __collate__(self, data_list):

return data

def __compute_norm__(self):
def _compute_norm(self):
node_count = torch.zeros(self.N, dtype=torch.float)
edge_count = torch.zeros(self.E, dtype=torch.float)

Expand Down Expand Up @@ -166,7 +166,7 @@ class GraphSAINTNodeSampler(GraphSAINTSampler):
r"""The GraphSAINT node sampler class (see
:class:`~torch_geometric.loader.GraphSAINTSampler`).
"""
def __sample_nodes__(self, batch_size):
def _sample_nodes(self, batch_size):
edge_sample = torch.randint(0, self.E, (batch_size, self.batch_size),
dtype=torch.long)

Expand All @@ -177,7 +177,7 @@ class GraphSAINTEdgeSampler(GraphSAINTSampler):
r"""The GraphSAINT edge sampler class (see
:class:`~torch_geometric.loader.GraphSAINTSampler`).
"""
def __sample_nodes__(self, batch_size):
def _sample_nodes(self, batch_size):
row, col, _ = self.adj.coo()

deg_in = 1. / self.adj.storage.colcount()
Expand Down Expand Up @@ -210,11 +210,11 @@ def __init__(self, data, batch_size: int, walk_length: int,
save_dir, log, **kwargs)

@property
def __filename__(self):
def _filename(self):
return (f'{self.__class__.__name__.lower()}_{self.walk_length}_'
f'{self.sample_coverage}.pt')

def __sample_nodes__(self, batch_size):
def _sample_nodes(self, batch_size):
start = torch.randint(0, self.N, (batch_size, ), dtype=torch.long)
node_idx = self.adj.random_walk(start.flatten(), self.walk_length)
return node_idx.view(-1)
46 changes: 23 additions & 23 deletions torch_geometric/nn/conv/message_passing.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -36,28 +36,28 @@ class EdgeUpdater_Collect_{{uid}}(NamedTuple):
class {{cls_name}}({{parent_cls_name}}):

@torch.jit._overload_method
def __check_input__(self, edge_index, size):
def _check_input(self, edge_index, size):
# type: (Tensor, Size) -> List[Optional[int]]
pass

@torch.jit._overload_method
def __check_input__(self, edge_index, size):
def _check_input(self, edge_index, size):
# type: (SparseTensor, Size) -> List[Optional[int]]
pass

{{check_input}}

@torch.jit._overload_method
def __lift__(self, src, edge_index, dim):
def _lift(self, src, edge_index, dim):
# type: (Tensor, Tensor, int) -> Tensor
pass

@torch.jit._overload_method
def __lift__(self, src, edge_index, dim):
def _lift(self, src, edge_index, dim):
# type: (Tensor, SparseTensor, int) -> Tensor
pass

def __lift__(self, src, edge_index, dim):
def _lift(self, src, edge_index, dim):
if isinstance(edge_index, Tensor):
index = edge_index[dim]
return src.index_select(self.node_dim, index)
Expand All @@ -76,16 +76,16 @@ class {{cls_name}}({{parent_cls_name}}):
'argument `edge_index`.'))

@torch.jit._overload_method
def __collect__(self, edge_def, size, kwargs):
def _collect(self, edge_def, size, kwargs):
# type: (Tensor, List[Optional[int]], Propagate_{{uid}}) -> Collect_{{uid}}
pass

@torch.jit._overload_method
def __collect__(self, edge_def, size, kwargs):
def _collect(self, edge_def, size, kwargs):
# type: (SparseTensor, List[Optional[int]], Propagate_{{uid}}) -> Collect_{{uid}}
pass

def __collect__(self, edge_def, size, kwargs):
def _collect(self, edge_def, size, kwargs):
init = torch.tensor(0.)
i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)
{% for arg in user_args %}
Expand All @@ -99,19 +99,19 @@ class {{cls_name}}({{parent_cls_name}}):
{%- if arg[-2:] == '_j' %}
tmp = data[1]
if isinstance(tmp, Tensor):
self.__set_size__(size, 1, tmp)
self._set_size(size, 1, tmp)
{{arg}} = data[0]
{%- else %}
tmp = data[0]
if isinstance(tmp, Tensor):
self.__set_size__(size, 0, tmp)
self._set_size(size, 0, tmp)
{{arg}} = data[1]
{%- endif %}
else:
{{arg}} = data
if isinstance({{arg}}, Tensor):
self.__set_size__(size, {% if arg[-2:] == '_j'%}0{% else %}1{% endif %}, {{arg}})
{{arg}} = self.__lift__({{arg}}, edge_def, {% if arg[-2:] == "_j" %}j{% else %}i{% endif %})
self._set_size(size, {% if arg[-2:] == '_j'%}0{% else %}1{% endif %}, {{arg}})
{{arg}} = self._lift({{arg}}, edge_def, {% if arg[-2:] == "_j" %}j{% else %}i{% endif %})
{%- endif %}
{%- endfor %}

Expand Down Expand Up @@ -155,16 +155,16 @@ class {{cls_name}}({{parent_cls_name}}):

{% if edge_updater_types|length > 0 %}
@torch.jit._overload_method
def __collect_edge__(self, edge_def, size, kwargs):
def _collect_edge(self, edge_def, size, kwargs):
# type: (Tensor, List[Optional[int]], EdgeUpdater_{{uid}}) -> EdgeUpdater_Collect_{{uid}}
pass

@torch.jit._overload_method
def __collect_edge__(self, edge_def, size, kwargs):
def _collect_edge(self, edge_def, size, kwargs):
# type: (SparseTensor, List[Optional[int]], EdgeUpdater_{{uid}}) -> EdgeUpdater_Collect_{{uid}}
pass

def __collect_edge__(self, edge_def, size, kwargs):
def _collect_edge(self, edge_def, size, kwargs):
init = torch.tensor(0.)
i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)
{% for arg in edge_user_args %}
Expand All @@ -178,19 +178,19 @@ class {{cls_name}}({{parent_cls_name}}):
{%- if arg[-2:] == '_j' %}
tmp = data[1]
if isinstance(tmp, Tensor):
self.__set_size__(size, 1, tmp)
self._set_size(size, 1, tmp)
{{arg}} = data[0]
{%- else %}
tmp = data[0]
if isinstance(tmp, Tensor):
self.__set_size__(size, 0, tmp)
self._set_size(size, 0, tmp)
{{arg}} = data[1]
{%- endif %}
else:
{{arg}} = data
if isinstance({{arg}}, Tensor):
self.__set_size__(size, {% if arg[-2:] == '_j'%}0{% else %}1{% endif %}, {{arg}})
{{arg}} = self.__lift__({{arg}}, edge_def, {% if arg[-2:] == "_j" %}j{% else %}i{% endif %})
self._set_size(size, {% if arg[-2:] == '_j'%}0{% else %}1{% endif %}, {{arg}})
{{arg}} = self._lift({{arg}}, edge_def, {% if arg[-2:] == "_j" %}j{% else %}i{% endif %})
{%- endif %}
{%- endfor %}

Expand Down Expand Up @@ -246,7 +246,7 @@ class {{cls_name}}({{parent_cls_name}}):
pass

def propagate(self, edge_index, {{ prop_types.keys()|join(', ') }}, size=None):
the_size = self.__check_input__(edge_index, size)
the_size = self._check_input(edge_index, size)
in_kwargs = Propagate_{{uid}}({% for k in prop_types.keys() %}{{k}}={{k}}{{ ", " if not loop.last }}{% endfor %})

{% if fuse %}
Expand All @@ -255,7 +255,7 @@ class {{cls_name}}({{parent_cls_name}}):
return self.update(out{% for k in update_args %}, {{k}}=in_kwargs.{{k}}{% endfor %})
{% endif %}

kwargs = self.__collect__(edge_index, the_size, in_kwargs)
kwargs = self._collect(edge_index, the_size, in_kwargs)
out = self.message({% for k in msg_args %}{{k}}=kwargs.{{k}}{{ ", " if not loop.last }}{% endfor %})
out = self.aggregate(out{% for k in aggr_args %}, {{k}}=kwargs.{{k}}{% endfor %})
return self.update(out{% for k in update_args %}, {{k}}=kwargs.{{k}}{% endfor %})
Expand All @@ -272,9 +272,9 @@ class {{cls_name}}({{parent_cls_name}}):
pass

def edge_updater(self, edge_index{% for k in edge_updater_types.keys() %}, {{k}} {% endfor %}):
the_size = self.__check_input__(edge_index, size=None)
the_size = self._check_input(edge_index, size=None)
in_kwargs = EdgeUpdater_{{uid}}({% for k in edge_updater_types.keys() %}{{k}}={{k}}{{ ", " if not loop.last }}{% endfor %})
kwargs = self.__collect_edge__(edge_index, the_size, in_kwargs)
kwargs = self._collect_edge(edge_index, the_size, in_kwargs)
return self.edge_update({% for k in edge_update_args %}{{k}}=kwargs.{{k}}{{ ", " if not loop.last }}{% endfor %})
{% else %}
def edge_updater(self):
Expand Down
Loading