Skip to content

Commit

Permalink
move custom transform to examples
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Jul 6, 2019
1 parent 64c6de3 commit 3fbceb7
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 4 deletions.
17 changes: 16 additions & 1 deletion examples/colors_topk_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,27 @@
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.transforms import HandleNodeAttention
from torch_geometric.data import DataLoader
from torch_geometric.nn import GINConv, TopKPooling
from torch_geometric.nn import global_add_pool as gsum
from torch_scatter import scatter_mean


class HandleNodeAttention(object):
def __call__(self, data):
if data.x.dim() == 1:
data.x = data.x.unsqueeze(-1)
data.node_attention = torch.softmax(data.x[:, 0], dim=0)
if data.x.shape[1] > 1:
data.x = data.x[:, 1:]
else:
# not supposed to use node attention as node features,
# because it is typically not available in the val/test set
data.x = None

return data


train_path = osp.join(
osp.dirname(osp.realpath(__file__)), '..', 'data', 'COLORS-3')
dataset = TUDataset(train_path, name='COLORS-3', use_node_attr=True,
Expand Down
18 changes: 17 additions & 1 deletion examples/triangles_sag_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,29 @@
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.transforms import OneHotDegree, HandleNodeAttention
from torch_geometric.transforms import OneHotDegree
from torch_geometric.transforms import Compose
from torch_geometric.data import DataLoader
from torch_geometric.nn import GINConv, SAGPooling
from torch_geometric.nn import global_max_pool as gmp
from torch_scatter import scatter_mean


class HandleNodeAttention(object):
def __call__(self, data):
if data.x.dim() == 1:
data.x = data.x.unsqueeze(-1)
data.node_attention = torch.softmax(data.x[:, 0], dim=0)
if data.x.shape[1] > 1:
data.x = data.x[:, 1:]
else:
# not supposed to use node attention as node features,
# because it is typically not available in the val/test set
data.x = None

return data


transform = Compose([HandleNodeAttention(), OneHotDegree(max_degree=14)])

train_path = osp.join(
Expand Down
2 changes: 0 additions & 2 deletions torch_geometric/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from .two_hop import TwoHop
from .line_graph import LineGraph
from .generate_mesh_normals import GenerateMeshNormals
from .handle_node_attention import HandleNodeAttention

__all__ = [
'Compose',
Expand Down Expand Up @@ -63,5 +62,4 @@
'TwoHop',
'LineGraph',
'GenerateMeshNormals',
'HandleNodeAttention'
]

0 comments on commit 3fbceb7

Please sign in to comment.