diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e48e8867e23..bb8d1cae4ed3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [2.2.0] - 2022-MM-DD ### Added - Added triplet sampling in `LinkNeighborLoader` ([#6004](https://github.com/pyg-team/pytorch_geometric/pull/6004)) +- Added `FusedAggregation` of simple scatter reductions ([#6036](https://github.com/pyg-team/pytorch_geometric/pull/6036)) +- Added `HeteroData` support for `to_captum_model` and added `to_captum_input` ([#5934](https://github.com/pyg-team/pytorch_geometric/pull/5934)) +- Added `HeteroData` support in `RandomNodeLoader` ([#6007](https://github.com/pyg-team/pytorch_geometric/pull/6007)) +- Added bipartite `GraphSAGE` example ([#5834](https://github.com/pyg-team/pytorch_geometric/pull/5834)) - Added `LRGBDataset` to include 5 datasets from the [Long Range Graph Benchmark](https://openreview.net/pdf?id=in7XC5RcjEn) ([#5935](https://github.com/pyg-team/pytorch_geometric/pull/5935)) - Added a warning for invalid node and edge type names in `HeteroData` ([#5990](https://github.com/pyg-team/pytorch_geometric/pull/5990)) - Added PyTorch 1.13 support ([#5975](https://github.com/pyg-team/pytorch_geometric/pull/5975)) @@ -15,7 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Add `to_fixed_size` graph transformer ([#5939](https://github.com/pyg-team/pytorch_geometric/pull/5939)) - Add support for symbolic tracing of `SchNet` model ([#5938](https://github.com/pyg-team/pytorch_geometric/pull/5938)) - Add support for customizable interaction graph in `SchNet` model ([#5919](https://github.com/pyg-team/pytorch_geometric/pull/5919)) -- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944)) +- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003)) - Added `HydroNet` water cluster dataset ([#5537](https://github.com/pyg-team/pytorch_geometric/pull/5537), [#5902](https://github.com/pyg-team/pytorch_geometric/pull/5902), [#5903](https://github.com/pyg-team/pytorch_geometric/pull/5903)) - Added explainability support for heterogeneous GNNs ([#5886](https://github.com/pyg-team/pytorch_geometric/pull/5886)) - Added `SparseTensor` support to `SuperGATConv` ([#5888](https://github.com/pyg-team/pytorch_geometric/pull/5888)) diff --git a/examples/captum_explainability.py b/examples/captum_explainability.py index 85620d0ab3ad..af421722fa2c 100644 --- a/examples/captum_explainability.py +++ b/examples/captum_explainability.py @@ -7,7 +7,7 @@ import torch_geometric.transforms as T from torch_geometric.datasets import Planetoid -from torch_geometric.nn import Explainer, GCNConv, to_captum +from torch_geometric.nn import Explainer, GCNConv, to_captum_model dataset = 'Cora' path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid') @@ -49,7 +49,7 @@ def forward(self, x, edge_index): # Captum assumes that for all given input tensors, dimension 0 is # equal to the number of samples. Therefore, we use unsqueeze(0). 
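+# An all-ones edge mask corresponds to the unperturbed graph; Captum
+# perturbs this mask towards its baseline to estimate each edge's influence.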
-captum_model = to_captum(model, mask_type='edge', output_idx=output_idx) +captum_model = to_captum_model(model, mask_type='edge', output_idx=output_idx) edge_mask = torch.ones(data.num_edges, requires_grad=True, device=device) ig = IntegratedGradients(captum_model) @@ -69,7 +69,7 @@ def forward(self, x, edge_index): # Node explainability # =================== -captum_model = to_captum(model, mask_type='node', output_idx=output_idx) +captum_model = to_captum_model(model, mask_type='node', output_idx=output_idx) ig = IntegratedGradients(captum_model) ig_attr_node = ig.attribute(data.x.unsqueeze(0), target=target, @@ -88,8 +88,8 @@ def forward(self, x, edge_index): # Node and edge explainability # ============================ -captum_model = to_captum(model, mask_type='node_and_edge', - output_idx=output_idx) +captum_model = to_captum_model(model, mask_type='node_and_edge', + output_idx=output_idx) ig = IntegratedGradients(captum_model) ig_attr_node, ig_attr_edge = ig.attribute( diff --git a/examples/hetero/bipartite_sage.py b/examples/hetero/bipartite_sage.py new file mode 100644 index 000000000000..6d5af770af68 --- /dev/null +++ b/examples/hetero/bipartite_sage.py @@ -0,0 +1,163 @@ +import os.path as osp + +import torch +import torch.nn.functional as F +from torch.nn import Embedding, Linear + +import torch_geometric.transforms as T +from torch_geometric.datasets import MovieLens +from torch_geometric.nn import SAGEConv +from torch_geometric.nn.conv.gcn_conv import gcn_norm + +path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/MovieLens') +dataset = MovieLens(path, model_name='all-MiniLM-L6-v2') +data = dataset[0] +data['user'].x = torch.arange(data['user'].num_nodes) +data['user', 'movie'].edge_label = data['user', 'movie'].edge_label.float() + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +data = data.to(device) + +# Add a reverse ('movie', 'rev_rates', 'user') relation for message passing: +data = T.ToUndirected()(data) +del data['movie', 'rev_rates', 'user'].edge_label # Remove "reverse" label. 
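+# (`ToUndirected()` also copied `edge_label` onto the new reverse relation;
+# the label only belongs on the original ('user', 'rates', 'movie') edges.)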
+
+# Perform a link-level split into training, validation, and test edges:
+train_data, val_data, test_data = T.RandomLinkSplit(
+    num_val=0.1,
+    num_test=0.1,
+    neg_sampling_ratio=0.0,
+    edge_types=[('user', 'rates', 'movie')],
+    rev_edge_types=[('movie', 'rev_rates', 'user')],
+)(data)
+
+# Generate the co-occurrence matrix of movies<>movies:
+metapath = [('movie', 'rev_rates', 'user'), ('user', 'rates', 'movie')]
+train_data = T.AddMetaPaths(metapaths=[metapath])(train_data)
+
+# Apply GCN normalization and keep only sufficiently strong metapath edges:
+_, edge_weight = gcn_norm(
+    train_data['movie', 'movie'].edge_index,
+    num_nodes=train_data['movie'].num_nodes,
+    add_self_loops=False,
+)
+edge_index = train_data['movie', 'movie'].edge_index[:, edge_weight > 0.002]
+
+train_data['movie', 'metapath_0', 'movie'].edge_index = edge_index
+val_data['movie', 'metapath_0', 'movie'].edge_index = edge_index
+test_data['movie', 'metapath_0', 'movie'].edge_index = edge_index
+
+
+class MovieGNNEncoder(torch.nn.Module):
+    def __init__(self, hidden_channels, out_channels):
+        super().__init__()
+
+        self.conv1 = SAGEConv(-1, hidden_channels)
+        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
+        self.lin = Linear(hidden_channels, out_channels)
+
+    def forward(self, x, edge_index):
+        x = self.conv1(x, edge_index).relu()
+        x = self.conv2(x, edge_index).relu()
+        return self.lin(x)
+
+
+class UserGNNEncoder(torch.nn.Module):
+    def __init__(self, hidden_channels, out_channels):
+        super().__init__()
+        self.conv1 = SAGEConv((-1, -1), hidden_channels)
+        self.conv2 = SAGEConv((-1, -1), hidden_channels)
+        self.conv3 = SAGEConv((-1, -1), hidden_channels)
+        self.lin = Linear(hidden_channels, out_channels)
+
+    def forward(self, x_dict, edge_index_dict):
+        movie_x = self.conv1(
+            x_dict['movie'],
+            edge_index_dict[('movie', 'metapath_0', 'movie')],
+        ).relu()
+
+        user_x = self.conv2(
+            (x_dict['movie'], x_dict['user']),
+            edge_index_dict[('movie', 'rev_rates', 'user')],
+        ).relu()
+
+        user_x = self.conv3(
+            (movie_x, user_x),
+            edge_index_dict[('movie', 'rev_rates', 'user')],
+        ).relu()
+
+        return self.lin(user_x)
+
+
+class EdgeDecoder(torch.nn.Module):
+    def __init__(self, hidden_channels):
+        super().__init__()
+        self.lin1 = Linear(2 * hidden_channels, hidden_channels)
+        self.lin2 = Linear(hidden_channels, 1)
+
+    def forward(self, z_src, z_dst, edge_label_index):
+        row, col = edge_label_index
+        z = torch.cat([z_src[row], z_dst[col]], dim=-1)
+
+        z = self.lin1(z).relu()
+        z = self.lin2(z)
+        return z.view(-1)
+
+
+class Model(torch.nn.Module):
+    def __init__(self, num_users, hidden_channels, out_channels):
+        super().__init__()
+        self.user_emb = Embedding(num_users, hidden_channels)
+        self.user_encoder = UserGNNEncoder(hidden_channels, out_channels)
+        self.movie_encoder = MovieGNNEncoder(hidden_channels, out_channels)
+        self.decoder = EdgeDecoder(hidden_channels)
+
+    def forward(self, x_dict, edge_index_dict, edge_label_index):
+        z_dict = {}
+        x_dict['user'] = self.user_emb(x_dict['user'])
+        z_dict['user'] = self.user_encoder(x_dict, edge_index_dict)
+        z_dict['movie'] = self.movie_encoder(
+            x_dict['movie'],
+            edge_index_dict[('movie', 'metapath_0', 'movie')],
+        )
+        return self.decoder(z_dict['user'], z_dict['movie'], edge_label_index)
+
+
+model = Model(data['user'].num_nodes, hidden_channels=64, out_channels=64)
+model = model.to(device)
+optimizer = torch.optim.Adam(model.parameters(), lr=0.0003)
+
+
+def train():
+    model.train()
+    optimizer.zero_grad()
+    out = model(
+        train_data.x_dict,
+        train_data.edge_index_dict,
+
train_data['user', 'movie'].edge_label_index, + ) + loss = F.mse_loss(out, train_data['user', 'movie'].edge_label) + loss.backward() + optimizer.step() + return float(loss) + + +@torch.no_grad() +def test(data): + model.eval() + out = model( + data.x_dict, + data.edge_index_dict, + data['user', 'movie'].edge_label_index, + ).clamp(min=0, max=5) + rmse = F.mse_loss(out, data['user', 'movie'].edge_label).sqrt() + return float(rmse) + + +for epoch in range(1, 701): + loss = train() + train_rmse = test(train_data) + val_rmse = test(val_data) + test_rmse = test(test_data) + print(f'Epoch: {epoch:04d}, Loss: {loss:.4f}, Train: {train_rmse:.4f}, ' + f'Val: {val_rmse:.4f}, Test: {test_rmse:.4f}') diff --git a/examples/ogbn_proteins_deepgcn.py b/examples/ogbn_proteins_deepgcn.py index ef902ea595c8..9acbc67162c9 100644 --- a/examples/ogbn_proteins_deepgcn.py +++ b/examples/ogbn_proteins_deepgcn.py @@ -5,7 +5,7 @@ from torch_scatter import scatter from tqdm import tqdm -from torch_geometric.loader import RandomNodeSampler +from torch_geometric.loader import RandomNodeLoader from torch_geometric.nn import DeepGCNLayer, GENConv dataset = PygNodePropPredDataset('ogbn-proteins', root='../data') @@ -24,9 +24,9 @@ mask[splitted_idx[split]] = True data[f'{split}_mask'] = mask -train_loader = RandomNodeSampler(data, num_parts=40, shuffle=True, - num_workers=5) -test_loader = RandomNodeSampler(data, num_parts=5, num_workers=5) +train_loader = RandomNodeLoader(data, num_parts=40, shuffle=True, + num_workers=5) +test_loader = RandomNodeLoader(data, num_parts=5, num_workers=5) class DeeperGCN(torch.nn.Module): diff --git a/examples/rev_gnn.py b/examples/rev_gnn.py index 79f1911b4088..383d0b6b665e 100644 --- a/examples/rev_gnn.py +++ b/examples/rev_gnn.py @@ -13,7 +13,7 @@ from tqdm import tqdm import torch_geometric.transforms as T -from torch_geometric.loader import RandomNodeSampler +from torch_geometric.loader import RandomNodeLoader from torch_geometric.nn import GroupAddRev, SAGEConv from torch_geometric.utils import index_to_mask @@ -91,11 +91,11 @@ def forward(self, x, edge_index): for split in ['train', 'valid', 'test']: data[f'{split}_mask'] = index_to_mask(split_idx[split], data.y.shape[0]) -train_loader = RandomNodeSampler(data, num_parts=10, shuffle=True, - num_workers=5) +train_loader = RandomNodeLoader(data, num_parts=10, shuffle=True, + num_workers=5) # Increase the num_parts of the test loader if you cannot fit # the full batch graph into your GPU: -test_loader = RandomNodeSampler(data, num_parts=1, num_workers=5) +test_loader = RandomNodeLoader(data, num_parts=1, num_workers=5) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = RevGNN( diff --git a/test/loader/test_random_node_loader.py b/test/loader/test_random_node_loader.py new file mode 100644 index 000000000000..556ddc7de10f --- /dev/null +++ b/test/loader/test_random_node_loader.py @@ -0,0 +1,52 @@ +import torch + +from torch_geometric.data import Data, HeteroData +from torch_geometric.loader import RandomNodeLoader + + +def get_edge_index(num_src_nodes, num_dst_nodes, num_edges): + row = torch.randint(num_src_nodes, (num_edges, ), dtype=torch.long) + col = torch.randint(num_dst_nodes, (num_edges, ), dtype=torch.long) + return torch.stack([row, col], dim=0) + + +def test_random_node_loader(): + data = Data() + data.x = torch.randn(100, 128) + data.node_id = torch.arange(100) + data.edge_index = get_edge_index(100, 100, 500) + data.edge_attr = torch.randn(500, 32) + + loader = RandomNodeLoader(data, num_parts=4, 
shuffle=True) + assert len(loader) == 4 + + for batch in loader: + assert len(batch) == 4 + assert batch.node_id.min() >= 0 + assert batch.node_id.max() < 100 + assert batch.edge_index.size(1) == batch.edge_attr.size(0) + assert torch.allclose(batch.x, data.x[batch.node_id]) + batch.validate() + + +def test_heterogeneous_random_node_loader(): + data = HeteroData() + data['paper'].x = torch.randn(100, 128) + data['paper'].node_id = torch.arange(100) + data['author'].x = torch.randn(200, 128) + data['author'].node_id = torch.arange(200) + data['paper', 'author'].edge_index = get_edge_index(100, 200, 500) + data['paper', 'author'].edge_attr = torch.randn(500, 32) + data['author', 'paper'].edge_index = get_edge_index(200, 100, 400) + data['author', 'paper'].edge_attr = torch.randn(400, 32) + data['paper', 'paper'].edge_index = get_edge_index(100, 100, 600) + data['paper', 'paper'].edge_attr = torch.randn(600, 32) + + loader = RandomNodeLoader(data, num_parts=4, shuffle=True) + assert len(loader) == 4 + + for batch in loader: + assert len(batch) == 4 + assert batch.node_types == data.node_types + assert batch.edge_types == data.edge_types + batch.validate() diff --git a/test/nn/aggr/test_fused.py b/test/nn/aggr/test_fused.py new file mode 100644 index 000000000000..62e66cc71493 --- /dev/null +++ b/test/nn/aggr/test_fused.py @@ -0,0 +1,67 @@ +import pytest +import torch + +from torch_geometric.nn.aggr.fused import FusedAggregation +from torch_geometric.nn.resolver import aggregation_resolver + + +@pytest.mark.parametrize('aggrs', [ + ['sum', 'mean', 'min', 'max', 'mul', 'var', 'std'], + ['sum', 'min', 'max', 'mul', 'var', 'std'], + ['min', 'max', 'mul', 'var', 'std'], + ['mean', 'min', 'max', 'mul', 'var', 'std'], + ['sum', 'min', 'max', 'mul', 'std'], + ['mean', 'min', 'max', 'mul', 'std'], + ['min', 'max', 'mul', 'std'], +]) +def test_fused_aggregation(aggrs): + aggrs = [aggregation_resolver(aggr) for aggr in aggrs] + + x = torch.randn(6, 1) + y = x.clone() + index = torch.tensor([0, 0, 1, 1, 1, 3]) + + x.requires_grad_(True) + y.requires_grad_(True) + + aggr = FusedAggregation(aggrs) + assert str(aggr) == 'FusedAggregation()' + out = aggr(x, index) + + expected = torch.cat([aggr(y, index) for aggr in aggrs], dim=-1) + assert torch.allclose(out, expected) + + out.mean().backward() + assert x.grad is not None + expected.mean().backward() + assert y.grad is not None + assert torch.allclose(x.grad, y.grad) + + +if __name__ == '__main__': + import time + + x = torch.randn(50000, 64, device='cuda') + index = torch.randint(1000, (x.size(0), ), device='cuda') + + aggrs = ['sum', 'mean', 'max', 'std'] + aggrs = [aggregation_resolver(aggr) for aggr in aggrs] + fused_aggr = FusedAggregation(aggrs) + + num_warmups, num_steps = (500, 1000) + + for i in range(num_warmups + num_steps): + if i == num_warmups: + torch.cuda.synchronize() + t = time.perf_counter() + torch.cat([aggr(x, index, dim_size=1000) for aggr in aggrs], dim=-1) + torch.cuda.synchronize() + print(f'Vanilla implementation: {time.perf_counter() - t:.4f} seconds') + + for i in range(num_warmups + num_steps): + if i == num_warmups: + torch.cuda.synchronize() + t = time.perf_counter() + fused_aggr(x, index, dim_size=1000) + torch.cuda.synchronize() + print(f'Fused implementation: {time.perf_counter() - t:.4f} seconds') diff --git a/test/nn/models/test_hetero_explainer.py b/test/nn/models/test_hetero_explainer.py index 3bc639d438d3..f50e0ea73aa7 100644 --- a/test/nn/models/test_hetero_explainer.py +++ b/test/nn/models/test_hetero_explainer.py @@ 
-1,8 +1,38 @@
+import pytest
import torch

from torch_geometric.data import HeteroData
-from torch_geometric.nn import GCNConv, HeteroConv, SAGEConv, to_hetero
-from torch_geometric.nn.models.explainer import clear_masks, set_hetero_masks
+from torch_geometric.nn import (
+    GCNConv,
+    HeteroConv,
+    SAGEConv,
+    captum_output_to_dicts,
+    to_captum_input,
+    to_captum_model,
+    to_hetero,
+)
+from torch_geometric.nn.models.explainer import (
+    CaptumHeteroModel,
+    clear_masks,
+    set_hetero_masks,
+)
+from torch_geometric.testing import withPackage
+from torch_geometric.typing import Metadata
+
+mask_types = ['edge', 'node_and_edge', 'node']
+methods = [
+    'Saliency',
+    'InputXGradient',
+    'Deconvolution',
+    'FeatureAblation',
+    'ShapleyValueSampling',
+    'IntegratedGradients',
+    'GradientShap',
+    'Occlusion',
+    'GuidedBackprop',
+    'KernelShap',
+    'Lime',
+]


def get_edge_index(num_src_nodes, num_dst_nodes, num_edges):
@@ -11,6 +41,16 @@ def get_edge_index(num_src_nodes, num_dst_nodes, num_edges):
    return torch.stack([row, col], dim=0)


+def get_hetero_data():
+    data = HeteroData()
+    data['paper'].x = torch.randn(8, 16)
+    data['author'].x = torch.randn(10, 8)
+    data['paper', 'paper'].edge_index = get_edge_index(8, 8, 10)
+    data['author', 'paper'].edge_index = get_edge_index(10, 8, 10)
+    data['paper', 'author'].edge_index = get_edge_index(8, 10, 10)
+    return data
+
+
class HeteroModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -45,14 +85,20 @@ def forward(self, x, edge_index):
        return self.conv2(x, edge_index)


-def test_set_clear_mask():
-    data = HeteroData()
-    data['paper'].x = torch.randn(50, 16)
-    data['author'].x = torch.randn(30, 8)
-    data['paper', 'paper'].edge_index = get_edge_index(50, 50, 200)
-    data['author', 'paper'].edge_index = get_edge_index(30, 50, 100)
-    data['paper', 'author'].edge_index = get_edge_index(50, 30, 100)
+class HeteroSAGE(torch.nn.Module):
+    def __init__(self, metadata: Metadata):
+        super().__init__()
+        self.graph_sage = to_hetero(GraphSAGE(), metadata, debug=False)
+
+    def forward(self, x_dict, edge_index_dict,
+                additional_arg=None) -> torch.Tensor:
+        # Make sure the additional argument gets passed through:
+        assert additional_arg is not None
+        return self.graph_sage(x_dict, edge_index_dict)['paper']
+
+
+def test_set_clear_mask():
+    data = get_hetero_data()
    edge_mask_dict = {
        ('paper', 'to', 'paper'): torch.ones(200),
        ('author', 'to', 'paper'): torch.ones(100),
@@ -89,3 +135,76 @@ def test_set_clear_mask():
        str_edge_type = '__'.join(edge_type)
        assert model.conv1[str_edge_type]._edge_mask is None
        assert not model.conv1[str_edge_type].explain
+
+
+@withPackage('captum')
+@pytest.mark.parametrize('mask_type', mask_types)
+@pytest.mark.parametrize('method', methods)
+def test_captum_attribution_methods_hetero(mask_type, method):
+    from captum import attr  # noqa
+    data = get_hetero_data()
+    metadata = data.metadata()
+    model = HeteroSAGE(metadata)
+    captum_model = to_captum_model(model, mask_type, 0, metadata)
+    explainer = getattr(attr, method)(captum_model)
+    assert isinstance(captum_model, CaptumHeteroModel)
+
+    args = ['additional_arg1']
+    input, additional_forward_args = to_captum_input(data.x_dict,
+                                                     data.edge_index_dict,
+                                                     mask_type, *args)
+    if mask_type == 'node':
+        sliding_window_shapes = ((3, 3), (3, 3))
+    elif mask_type == 'edge':
+        sliding_window_shapes = ((5, ), (5, ), (5, ))
+    else:
+        sliding_window_shapes = ((3, 3), (3, 3), (5, ), (5, ), (5, ))
+
+    if method == 'IntegratedGradients':
+        attributions, delta = explainer.attribute(
+            input, target=0, internal_batch_size=1,
+            additional_forward_args=additional_forward_args,
+            return_convergence_delta=True)
+    elif method == 'GradientShap':
+        attributions, delta = explainer.attribute(
+            input, target=0, return_convergence_delta=True, baselines=input,
+            n_samples=1, additional_forward_args=additional_forward_args)
+    elif method == 'DeepLiftShap' or method == 'DeepLift':
+        attributions, delta = explainer.attribute(
+            input, target=0, return_convergence_delta=True, baselines=input,
+            additional_forward_args=additional_forward_args)
+    elif method == 'Occlusion':
+        attributions = explainer.attribute(
+            input, target=0, sliding_window_shapes=sliding_window_shapes,
+            additional_forward_args=additional_forward_args)
+    else:
+        attributions = explainer.attribute(
+            input, target=0, additional_forward_args=additional_forward_args)
+
+    if mask_type == 'node':
+        assert len(attributions) == len(metadata[0])
+        x_attr_dict, _ = captum_output_to_dicts(attributions, mask_type,
+                                                metadata)
+        for node_type in metadata[0]:
+            num_nodes = data[node_type].num_nodes
+            num_node_feats = data[node_type].x.shape[1]
+            assert x_attr_dict[node_type].shape == (num_nodes, num_node_feats)
+    elif mask_type == 'edge':
+        assert len(attributions) == len(metadata[1])
+        _, edge_attr_dict = captum_output_to_dicts(attributions, mask_type,
+                                                   metadata)
+        for edge_type in metadata[1]:
+            num_edges = data[edge_type].edge_index.shape[1]
+            assert edge_attr_dict[edge_type].shape == (num_edges, )
+    else:
+        assert len(attributions) == len(metadata[0]) + len(metadata[1])
+        x_attr_dict, edge_attr_dict = captum_output_to_dicts(
+            attributions, mask_type, metadata)
+        for edge_type in metadata[1]:
+            num_edges = data[edge_type].edge_index.shape[1]
+            assert edge_attr_dict[edge_type].shape == (num_edges, )
+
+        for node_type in metadata[0]:
+            num_nodes = data[node_type].num_nodes
+            num_node_feats = data[node_type].x.shape[1]
+            assert x_attr_dict[node_type].shape == (num_nodes, num_node_feats)
diff --git a/test/nn/models/test_to_captum.py b/test/nn/models/test_to_captum.py
index 55a1ad1ab4ff..08c9c46b3044 100644
--- a/test/nn/models/test_to_captum.py
+++ b/test/nn/models/test_to_captum.py
@@ -1,8
+1,10 @@ import pytest import torch -from torch_geometric.nn import GAT, GCN, Explainer, SAGEConv, to_captum +from torch_geometric.data import Data, HeteroData +from torch_geometric.nn import GAT, GCN, Explainer, SAGEConv from torch_geometric.nn.conv import MessagePassing +from torch_geometric.nn.models import to_captum_input, to_captum_model from torch_geometric.testing import withPackage x = torch.randn(8, 3, requires_grad=True) @@ -31,7 +33,8 @@ @pytest.mark.parametrize('model', [GCN, GAT]) @pytest.mark.parametrize('output_idx', [None, 1]) def test_to_captum(model, mask_type, output_idx): - captum_model = to_captum(model, mask_type=mask_type, output_idx=output_idx) + captum_model = to_captum_model(model, mask_type=mask_type, + output_idx=output_idx) pre_out = model(x, edge_index) if mask_type == 'node': mask = x * 0.0 @@ -61,22 +64,16 @@ def test_to_captum(model, mask_type, output_idx): def test_captum_attribution_methods(mask_type, method): from captum import attr # noqa - captum_model = to_captum(GCN, mask_type, 0) - input_mask = torch.ones((1, edge_index.shape[1]), dtype=torch.float, - requires_grad=True) + captum_model = to_captum_model(GCN, mask_type, 0) explainer = getattr(attr, method)(captum_model) - + data = Data(x, edge_index) + input, additional_forward_args = to_captum_input(data.x, data.edge_index, + mask_type) if mask_type == 'node': - input = x.clone().unsqueeze(0) - additional_forward_args = (edge_index, ) sliding_window_shapes = (3, 3) elif mask_type == 'edge': - input = input_mask - additional_forward_args = (x, edge_index) sliding_window_shapes = (5, ) elif mask_type == 'node_and_edge': - input = (x.clone().unsqueeze(0), input_mask) - additional_forward_args = (edge_index, ) sliding_window_shapes = ((3, 3), (5, )) if method == 'IntegratedGradients': @@ -100,9 +97,9 @@ def test_captum_attribution_methods(mask_type, method): attributions = explainer.attribute( input, target=0, additional_forward_args=additional_forward_args) if mask_type == 'node': - assert attributions.shape == (1, 8, 3) + assert attributions[0].shape == (1, 8, 3) elif mask_type == 'edge': - assert attributions.shape == (1, 14) + assert attributions[0].shape == (1, 14) else: assert attributions[0].shape == (1, 8, 3) assert attributions[1].shape == (1, 14) @@ -144,3 +141,77 @@ def explain_message(self, inputs, x_i, x_j): assert torch.allclose(conv.x_i, x[edge_index[1]]) assert torch.allclose(conv.x_j, x[edge_index[0]]) + + +@withPackage('captum') +@pytest.mark.parametrize('mask_type', ['node', 'edge', 'node_and_edge']) +def test_to_captum_input(mask_type): + num_nodes = x.shape[0] + num_node_feats = x.shape[1] + num_edges = edge_index.shape[1] + + # Check for Data: + data = Data(x, edge_index) + args = 'test_args' + inputs, additional_forward_args = to_captum_input(data.x, data.edge_index, + mask_type, args) + if mask_type == 'node': + assert len(inputs) == 1 + assert inputs[0].shape == (1, num_nodes, num_node_feats) + assert len(additional_forward_args) == 2 + assert torch.allclose(additional_forward_args[0], edge_index) + elif mask_type == 'edge': + assert len(inputs) == 1 + assert inputs[0].shape == (1, num_edges) + assert inputs[0].sum() == num_edges + assert len(additional_forward_args) == 3 + assert torch.allclose(additional_forward_args[0], x) + assert torch.allclose(additional_forward_args[1], edge_index) + else: + assert len(inputs) == 2 + assert inputs[0].shape == (1, num_nodes, num_node_feats) + assert inputs[1].shape == (1, num_edges) + assert inputs[1].sum() == num_edges + assert 
len(additional_forward_args) == 2
+        assert torch.allclose(additional_forward_args[0], edge_index)
+
+    # Check for HeteroData:
+    data = HeteroData()
+    x2 = torch.rand(8, 3)
+    data['paper'].x = x
+    data['author'].x = x2
+    data['paper', 'to', 'author'].edge_index = edge_index
+    data['author', 'to', 'paper'].edge_index = edge_index.flip([0])
+    inputs, additional_forward_args = to_captum_input(data.x_dict,
+                                                      data.edge_index_dict,
+                                                      mask_type, args)
+    if mask_type == 'node':
+        assert len(inputs) == 2
+        assert inputs[0].shape == (1, num_nodes, num_node_feats)
+        assert inputs[1].shape == (1, num_nodes, num_node_feats)
+        assert len(additional_forward_args) == 2
+        for key in data.edge_types:
+            assert torch.allclose(additional_forward_args[0][key],
+                                  data[key].edge_index)
+    elif mask_type == 'edge':
+        assert len(inputs) == 2
+        assert inputs[0].shape == (1, num_edges)
+        assert inputs[1].shape == (1, num_edges)
+        assert inputs[1].sum() == inputs[0].sum() == num_edges
+        assert len(additional_forward_args) == 3
+        for key in data.node_types:
+            assert torch.allclose(additional_forward_args[0][key],
+                                  data[key].x)
+        for key in data.edge_types:
+            assert torch.allclose(additional_forward_args[1][key],
+                                  data[key].edge_index)
+    else:
+        assert len(inputs) == 4
+        assert inputs[0].shape == (1, num_nodes, num_node_feats)
+        assert inputs[1].shape == (1, num_nodes, num_node_feats)
+        assert inputs[2].shape == (1, num_edges)
+        assert inputs[3].shape == (1, num_edges)
+        assert inputs[3].sum() == inputs[2].sum() == num_edges
+        assert len(additional_forward_args) == 2
+        for key in data.edge_types:
+            assert torch.allclose(additional_forward_args[0][key],
+                                  data[key].edge_index)
diff --git a/test/utils/test_sparse.py b/test/utils/test_sparse.py
index 170b481f73c4..f47079d6184e 100644
--- a/test/utils/test_sparse.py
+++ b/test/utils/test_sparse.py
@@ -1,7 +1,13 @@
import torch
+from torch_sparse import SparseTensor

from torch_geometric.testing import is_full_test
-from torch_geometric.utils import dense_to_sparse
+from torch_geometric.utils import (
+    dense_to_sparse,
+    is_sparse,
+    is_torch_sparse_tensor,
+    to_torch_coo_tensor,
+)


def test_dense_to_sparse():
@@ -35,3 +41,51 @@ def test_dense_to_sparse():
        edge_index, edge_attr = jit(adj)
        assert edge_index.tolist() == [[0, 0, 1, 2, 3], [0, 1, 0, 3, 3]]
        assert edge_attr.tolist() == [3, 1, 2, 1, 2]
+
+
+def test_is_torch_sparse_tensor():
+    x = torch.randn(5, 5)
+
+    assert not is_torch_sparse_tensor(x)
+    assert not is_torch_sparse_tensor(SparseTensor.from_dense(x))
+    assert is_torch_sparse_tensor(x.to_sparse())
+
+
+def test_is_sparse():
+    x = torch.randn(5, 5)
+
+    assert not is_sparse(x)
+    assert is_sparse(SparseTensor.from_dense(x))
+    assert is_sparse(x.to_sparse())
+
+
+def test_to_torch_coo_tensor():
+    edge_index = torch.tensor([
+        [0, 1, 1, 2, 2, 3],
+        [1, 0, 2, 1, 3, 2],
+    ])
+    edge_attr = torch.randn(edge_index.size(1), 8)
+
+    adj = to_torch_coo_tensor(edge_index)
+    assert adj.size() == (4, 4)
+    assert adj.layout == torch.sparse_coo
+    assert torch.allclose(adj.indices(), edge_index)
+
+    adj = to_torch_coo_tensor(edge_index, size=6)
+    assert adj.size() == (6, 6)
+    assert adj.layout == torch.sparse_coo
+    assert torch.allclose(adj.indices(), edge_index)
+
+    adj = to_torch_coo_tensor(edge_index, edge_attr)
+    assert adj.size() == (4, 4, 8)
+    assert adj.layout == torch.sparse_coo
+    assert torch.allclose(adj.indices(), edge_index)
+    assert torch.allclose(adj.values(), edge_attr)
+
+    if is_full_test():
+        jit = torch.jit.script(to_torch_coo_tensor)
+        adj = jit(edge_index, edge_attr)
+        assert adj.size() == (4,
4, 8) + assert adj.layout == torch.sparse_coo + assert torch.allclose(adj.indices(), edge_index) + assert torch.allclose(adj.values(), edge_attr) diff --git a/test/utils/test_torch_sparse_tensor.py b/test/utils/test_torch_sparse_tensor.py deleted file mode 100644 index 2bb862db673c..000000000000 --- a/test/utils/test_torch_sparse_tensor.py +++ /dev/null @@ -1,12 +0,0 @@ -import torch -from torch_sparse import SparseTensor - -from torch_geometric.utils import is_torch_sparse_tensor - - -def test_is_torch_sparse_tensor(): - x = torch.randn(5, 5) - - assert not is_torch_sparse_tensor(x) - assert not is_torch_sparse_tensor(SparseTensor.from_dense(x)) - assert is_torch_sparse_tensor(x.to_sparse()) diff --git a/torch_geometric/data/__init__.py b/torch_geometric/data/__init__.py index d3d68baa18ab..12044cc0019e 100644 --- a/torch_geometric/data/__init__.py +++ b/torch_geometric/data/__init__.py @@ -42,7 +42,7 @@ from torch_geometric.loader import GraphSAINTEdgeSampler # noqa from torch_geometric.loader import GraphSAINTRandomWalkSampler # noqa from torch_geometric.loader import ShaDowKHopSampler # noqa -from torch_geometric.loader import RandomNodeSampler # noqa +from torch_geometric.loader import RandomNodeLoader # noqa from torch_geometric.loader import DataLoader # noqa from torch_geometric.loader import DataListLoader # noqa from torch_geometric.loader import DenseDataLoader # noqa @@ -66,8 +66,8 @@ 'data.GraphSAINTRandomWalkSampler')(GraphSAINTRandomWalkSampler) ShaDowKHopSampler = deprecated("use 'loader.ShaDowKHopSampler' instead", 'data.ShaDowKHopSampler')(ShaDowKHopSampler) -RandomNodeSampler = deprecated("use 'loader.RandomNodeSampler' instead", - 'data.RandomNodeSampler')(RandomNodeSampler) +RandomNodeSampler = deprecated("use 'loader.RandomNodeLoader' instead", + 'data.RandomNodeSampler')(RandomNodeLoader) DataLoader = deprecated("use 'loader.DataLoader' instead", 'data.DataLoader')(DataLoader) DataListLoader = deprecated("use 'loader.DataListLoader' instead", diff --git a/torch_geometric/graphgym/loader.py b/torch_geometric/graphgym/loader.py index 177da7f34b17..28285cb872a3 100644 --- a/torch_geometric/graphgym/loader.py +++ b/torch_geometric/graphgym/loader.py @@ -26,7 +26,7 @@ GraphSAINTNodeSampler, GraphSAINTRandomWalkSampler, NeighborSampler, - RandomNodeSampler, + RandomNodeLoader, ) from torch_geometric.utils import ( index_to_mask, @@ -256,11 +256,11 @@ def get_loader(dataset, sampler, batch_size, shuffle=True): batch_size=batch_size, shuffle=shuffle, num_workers=cfg.num_workers, pin_memory=True) elif sampler == "random_node": - loader_train = RandomNodeSampler(dataset[0], - num_parts=cfg.train.train_parts, - shuffle=shuffle, - num_workers=cfg.num_workers, - pin_memory=True) + loader_train = RandomNodeLoader(dataset[0], + num_parts=cfg.train.train_parts, + shuffle=shuffle, + num_workers=cfg.num_workers, + pin_memory=True) elif sampler == "saint_rw": loader_train = \ GraphSAINTRandomWalkSampler(dataset[0], diff --git a/torch_geometric/loader/__init__.py b/torch_geometric/loader/__init__.py index 982e46edd635..92ff5f3ce96d 100644 --- a/torch_geometric/loader/__init__.py +++ b/torch_geometric/loader/__init__.py @@ -1,3 +1,5 @@ +from torch_geometric.deprecation import deprecated + from .dataloader import DataLoader from .neighbor_loader import NeighborLoader from .link_neighbor_loader import LinkNeighborLoader @@ -6,7 +8,7 @@ from .graph_saint import (GraphSAINTSampler, GraphSAINTNodeSampler, GraphSAINTEdgeSampler, GraphSAINTRandomWalkSampler) from .shadow import 
ShaDowKHopSampler -from .random_node_sampler import RandomNodeSampler +from .random_node_loader import RandomNodeLoader from .data_list_loader import DataListLoader from .dense_data_loader import DenseDataLoader from .temporal_dataloader import TemporalDataLoader @@ -30,7 +32,7 @@ 'GraphSAINTEdgeSampler', 'GraphSAINTRandomWalkSampler', 'ShaDowKHopSampler', - 'RandomNodeSampler', + 'RandomNodeLoader', 'DataListLoader', 'DenseDataLoader', 'TemporalDataLoader', @@ -38,3 +40,8 @@ 'ImbalancedSampler', 'DynamicBatchSampler', ] + +RandomNodeSampler = deprecated( + details="use 'loader.RandomNodeLoader' instead", + func_name='loader.RandomNodeSampler', +)(RandomNodeLoader) diff --git a/torch_geometric/loader/random_node_loader.py b/torch_geometric/loader/random_node_loader.py new file mode 100644 index 000000000000..3c80036ee358 --- /dev/null +++ b/torch_geometric/loader/random_node_loader.py @@ -0,0 +1,68 @@ +import math +from typing import Union + +import torch +from torch import Tensor + +from torch_geometric.data import Data, HeteroData +from torch_geometric.data.hetero_data import to_homogeneous_edge_index + + +class RandomNodeLoader(torch.utils.data.DataLoader): + r"""A data loader that randomly samples nodes within a graph and returns + their induced subgraph. + + .. note:: + + For an example of using + :class:`~torch_geometric.loader.RandomNodeLoader`, see + `examples/ogbn_proteins_deepgcn.py + `_. + + Args: + data (torch_geometric.data.Data or torch_geometric.data.HeteroData): + The :class:`~torch_geometric.data.Data` or + :class:`~torch_geometric.data.HeteroData` graph object. + num_parts (int): The number of partitions. + **kwargs (optional): Additional arguments of + :class:`torch.utils.data.DataLoader`, such as :obj:`num_workers`. + """ + def __init__( + self, + data: Union[Data, HeteroData], + num_parts: int, + **kwargs, + ): + self.data = data + self.num_parts = num_parts + + if isinstance(data, HeteroData): + edge_index, node_dict, edge_dict = to_homogeneous_edge_index(data) + self.node_dict, self.edge_dict = node_dict, edge_dict + else: + edge_index = data.edge_index + + self.edge_index = edge_index + self.num_nodes = data.num_nodes + + super().__init__( + range(self.num_nodes), + batch_size=math.ceil(self.num_nodes / num_parts), + collate_fn=self.collate_fn, + **kwargs, + ) + + def collate_fn(self, index): + if not isinstance(index, Tensor): + index = torch.tensor(index) + + if isinstance(self.data, Data): + return self.data.subgraph(index) + + elif isinstance(self.data, HeteroData): + node_dict = { + key: index[(index >= start) & (index < end)] - start + for key, (start, end) in self.node_dict.items() + } + return self.data.subgraph(node_dict) diff --git a/torch_geometric/loader/random_node_sampler.py b/torch_geometric/loader/random_node_sampler.py deleted file mode 100644 index 264a19c6fa7e..000000000000 --- a/torch_geometric/loader/random_node_sampler.py +++ /dev/null @@ -1,91 +0,0 @@ -import copy - -import torch -from torch import Tensor -from torch_sparse import SparseTensor - - -class RandomIndexSampler(torch.utils.data.Sampler): - def __init__(self, num_nodes: int, num_parts: int, shuffle: bool = False): - self.N = num_nodes - self.num_parts = num_parts - self.shuffle = shuffle - self.n_ids = self.get_node_indices() - - def get_node_indices(self): - n_id = torch.randint(self.num_parts, (self.N, ), dtype=torch.long) - n_ids = [(n_id == i).nonzero(as_tuple=False).view(-1) - for i in range(self.num_parts)] - return n_ids - - def __iter__(self): - if self.shuffle: - self.n_ids 
= self.get_node_indices() - return iter(self.n_ids) - - def __len__(self): - return self.num_parts - - -class RandomNodeSampler(torch.utils.data.DataLoader): - r"""A data loader that randomly samples nodes within a graph and returns - their induced subgraph. - - .. note:: - - For an example of using - :class:`~torch_geometric.loader.RandomNodeSampler`, see - `examples/ogbn_proteins_deepgcn.py - `_. - - Args: - data (torch_geometric.data.Data): The graph data object. - num_parts (int): The number of partitions. - shuffle (bool, optional): If set to :obj:`True`, the data is reshuffled - at every epoch (default: :obj:`False`). - **kwargs (optional): Additional arguments of - :class:`torch.utils.data.DataLoader`, such as :obj:`num_workers`. - """ - def __init__(self, data, num_parts: int, shuffle: bool = False, **kwargs): - assert data.edge_index is not None - - self.N = N = data.num_nodes - self.E = data.num_edges - - self.adj = SparseTensor( - row=data.edge_index[0], col=data.edge_index[1], - value=torch.arange(self.E, device=data.edge_index.device), - sparse_sizes=(N, N)) - - self.data = copy.copy(data) - self.data.edge_index = None - - super().__init__( - self, batch_size=1, - sampler=RandomIndexSampler(self.N, num_parts, shuffle), - collate_fn=self.__collate__, **kwargs) - - def __getitem__(self, idx): - return idx - - def __collate__(self, node_idx): - node_idx = node_idx[0] - - data = self.data.__class__() - data.num_nodes = node_idx.size(0) - adj, _ = self.adj.saint_subgraph(node_idx) - row, col, edge_idx = adj.coo() - data.edge_index = torch.stack([row, col], dim=0) - - for key, item in self.data: - if key in ['num_nodes']: - continue - if isinstance(item, Tensor) and item.size(0) == self.N: - data[key] = item[node_idx] - elif isinstance(item, Tensor) and item.size(0) == self.E: - data[key] = item[edge_idx] - else: - data[key] = item - - return data diff --git a/torch_geometric/nn/aggr/basic.py b/torch_geometric/nn/aggr/basic.py index 1cfc08ea4255..af1cae157769 100644 --- a/torch_geometric/nn/aggr/basic.py +++ b/torch_geometric/nn/aggr/basic.py @@ -109,7 +109,7 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) -> Tensor: var = self.var_aggr(x, index, ptr, dim_size, dim) - return torch.sqrt(var.relu() + 1e-5) + return (var.relu() + 1e-5).sqrt() class SoftmaxAggregation(Aggregation): diff --git a/torch_geometric/nn/aggr/fused.py b/torch_geometric/nn/aggr/fused.py new file mode 100644 index 000000000000..3778037247b5 --- /dev/null +++ b/torch_geometric/nn/aggr/fused.py @@ -0,0 +1,293 @@ +from typing import Any, List, Optional, Tuple, Union + +import torch +from torch import Tensor + +from torch_geometric.nn import ( + Aggregation, + MaxAggregation, + MeanAggregation, + MinAggregation, + MulAggregation, + StdAggregation, + SumAggregation, + VarAggregation, +) +from torch_geometric.nn.resolver import aggregation_resolver + + +class FusedAggregation(Aggregation): + r"""Helper class to fuse computation of multiple aggregations together. + Used internally in :class:`~torch_geometric.nn.aggr.MultiAggregation` to + speed-up computation. + Currently, the following optimizations are performed: + + * :class:`MeanAggregation` will share the output with + :class:`SumAggregation` in case it is present as well. + + * :class:`VarAggregation` will share the output with either + :class:`MeanAggregation` or :class:`SumAggregation` in case one of them + is present as well. 
+ + * :class:`StdAggregation` will share the output with either + :class:`VarAggregation`, :class:`MeanAggregation` or + :class:`SumAggregation` in case one of them is present as well. + + In addition, temporary values such as the count per group index or the + mask for invalid rows are shared as well. + + Benchmarking results on PyTorch 1.12 (summed over 1000 runs): + + +------------------------------+---------+---------+ + | Aggregators | Vanilla | Fusion | + +==============================+=========+=========+ + | :obj:`[sum, mean]` | 0.4019s | 0.1666s | + +------------------------------+---------+---------+ + | :obj:`[sum, mean, min, max]` | 0.7841s | 0.4223s | + +------------------------------+---------+---------+ + | :obj:`[sum, mean, var]` | 0.9711s | 0.3614s | + +------------------------------+---------+---------+ + | :obj:`[sum, mean, var, std]` | 1.5994s | 0.3722s | + +------------------------------+---------+---------+ + + Args: + aggrs (list): The list of aggregation schemes to use. + """ + # We can fuse all aggregations together that rely on `scatter` directives. + FUSABLE_AGGRS = { + SumAggregation, + MeanAggregation, + MinAggregation, + MaxAggregation, + MulAggregation, + VarAggregation, + StdAggregation, + } + + # All aggregations that rely on computing the degree of indices. + DEGREE_BASED_AGGRS = { + MeanAggregation, + VarAggregation, + StdAggregation, + } + + # All aggregations that require manual masking for invalid rows: + MASK_REQUIRED_AGGRS = { + MinAggregation, + MaxAggregation, + MulAggregation, + } + + # Map aggregations to `reduce` options in `scatter` directives. + REDUCE = { + SumAggregation: 'sum', + MeanAggregation: 'sum', + MinAggregation: 'amin', + MaxAggregation: 'amax', + MulAggregation: 'prod', + VarAggregation: 'pow_sum', + StdAggregation: 'pow_sum', + } + + def __init__(self, aggrs: List[Union[Aggregation, str]]): + super().__init__() + + if not isinstance(aggrs, (list, tuple)): + raise ValueError(f"'aggrs' of '{self.__class__.__name__}' should " + f"be a list or tuple (got '{type(aggrs)}').") + + if len(aggrs) == 0: + raise ValueError(f"'aggrs' of '{self.__class__.__name__}' should " + f"not be empty.") + + aggrs = [aggregation_resolver(aggr) for aggr in aggrs] + self.aggr_cls = [aggr.__class__ for aggr in aggrs] + self.aggr_index = {cls: i for i, cls in enumerate(self.aggr_cls)} + + for cls in self.aggr_cls: + if cls not in self.FUSABLE_AGGRS: + raise ValueError(f"Received aggregation '{cls.__name__}' in " + f"'{self.__class__.__name__}' which is not " + f"fusable") + + # Check whether we need to compute degree information: + self.need_degree = False + for cls in self.aggr_cls: + if cls in self.DEGREE_BASED_AGGRS: + self.need_degree = True + + # Check whether we need to compute mask information: + self.requires_mask = False + for cls in self.aggr_cls: + if cls in self.MASK_REQUIRED_AGGRS: + self.requires_mask = True + + # Determine which reduction to use for each aggregator: + # An entry of `None` means that this operator re-uses intermediate + # outputs from other aggregators. 
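+        # For example, `aggrs=['sum', 'mean', 'std']` yields
+        # `reduce_ops=['sum', None, 'pow_sum']`: the mean is derived from the
+        # sum, and the std needs only a sum of squares next to the shared
+        # mean.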
+ self.reduce_ops: List[Optional[str]] = [] + # Determine which `(Aggregator, index)` to use as intermediate output: + self.lookup_ops: List[Optional[Tuple[Any, int]]] = [] + + for cls in self.aggr_cls: + if cls == MeanAggregation: + # Directly use output of `SumAggregation`: + if SumAggregation in self.aggr_index: + self.reduce_ops.append(None) + self.lookup_ops.append( + (SumAggregation, self.aggr_index[SumAggregation])) + else: + self.reduce_ops.append(self.REDUCE[cls]) + self.lookup_ops.append(None) + + elif cls == VarAggregation: + if MeanAggregation in self.aggr_index: + self.reduce_ops.append(self.REDUCE[cls]) + self.lookup_ops.append( + (MeanAggregation, self.aggr_index[MeanAggregation])) + elif SumAggregation in self.aggr_index: + self.reduce_ops.append(self.REDUCE[cls]) + self.lookup_ops.append( + (SumAggregation, self.aggr_index[SumAggregation])) + else: + self.reduce_ops.append(self.REDUCE[cls]) + self.lookup_ops.append(None) + + elif cls == StdAggregation: + # Directly use output of `VarAggregation`: + if VarAggregation in self.aggr_index: + self.reduce_ops.append(None) + self.lookup_ops.append( + (VarAggregation, self.aggr_index[VarAggregation])) + elif MeanAggregation in self.aggr_index: + self.reduce_ops.append(self.REDUCE[cls]) + self.lookup_ops.append( + (MeanAggregation, self.aggr_index[MeanAggregation])) + elif SumAggregation in self.aggr_index: + self.reduce_ops.append(self.REDUCE[cls]) + self.lookup_ops.append( + (SumAggregation, self.aggr_index[SumAggregation])) + else: + self.reduce_ops.append(self.REDUCE[cls]) + self.lookup_ops.append(None) + + else: + self.reduce_ops.append(self.REDUCE[cls]) + self.lookup_ops.append(None) + + def forward(self, x: Tensor, index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, + dim: int = -2) -> Tensor: + + # Assert two-dimensional input for now to simplify computation: + # TODO refactor this to support any dimension. 
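+        # Overall strategy: run one `scatter_reduce_` per *required*
+        # reduction and derive dependent outputs (mean from sum, var/std
+        # from the sum of squares) instead of scattering once per aggregator.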
+        self.assert_index_present(index)
+        self.assert_two_dimensional_input(x, dim)
+
+        if self.need_degree:
+            count = x.new_zeros(dim_size)
+            count.scatter_add_(0, index, x.new_ones(x.size(0)))
+            if self.requires_mask:
+                mask = count == 0
+            count = count.clamp_(min=1).view(-1, 1)
+
+        elif self.requires_mask:  # Mask to set non-existing indices to zero:
+            mask = x.new_ones(dim_size, dtype=torch.bool)
+            mask[index] = False
+
+        num_feats = x.size(-1)
+        index = index.view(-1, 1).expand(-1, num_feats)
+
+        ###################################################################
+
+        outs: List[Optional[Tensor]] = []
+
+        # Iterate over all reduction ops to compute first results:
+        for i, reduce in enumerate(self.reduce_ops):
+            if reduce is None:
+                outs.append(None)
+                continue
+
+            src = x * x if reduce == 'pow_sum' else x
+            reduce = 'sum' if reduce == 'pow_sum' else reduce
+
+            fill_value = 0.0
+            if reduce == 'amin':
+                fill_value = float('inf')
+            elif reduce == 'amax':
+                fill_value = float('-inf')
+            elif reduce == 'prod':
+                fill_value = 1.0
+
+            # `include_self=True` + manual masking leads to faster runtime:
+            out = x.new_full((dim_size, num_feats), fill_value)
+            out.scatter_reduce_(0, index, src, reduce, include_self=True)
+            if fill_value != 0.0:
+                out = out.masked_fill(mask.view(-1, 1), 0.0)
+            outs.append(out)
+
+        ###################################################################
+
+        # Compute `MeanAggregation` first to be able to re-use it:
+        i = self.aggr_index.get(MeanAggregation)
+        if i is not None:
+            if self.lookup_ops[i] is None:
+                sum_ = outs[i]
+            else:
+                tmp_aggr, j = self.lookup_ops[i]
+                assert tmp_aggr == SumAggregation
+                sum_ = outs[j]
+
+            outs[i] = sum_ / count
+
+        # Compute `VarAggregation` second to be able to re-use it:
+        i = self.aggr_index.get(VarAggregation)
+        if i is not None:
+            if self.lookup_ops[i] is None:
+                mean = x.new_zeros(dim_size, num_feats)
+                mean.scatter_reduce_(0, index, x, 'sum', include_self=True)
+                mean = mean / count
+            else:
+                tmp_aggr, j = self.lookup_ops[i]
+                if tmp_aggr == SumAggregation:
+                    mean = outs[j] / count
+                elif tmp_aggr == MeanAggregation:
+                    mean = outs[j]
+                else:
+                    raise NotImplementedError
+
+            pow_sum = outs[i]
+            outs[i] = (pow_sum / count) - (mean * mean)
+
+        # Compute `StdAggregation` last:
+        i = self.aggr_index.get(StdAggregation)
+        if i is not None:
+            var = None
+            if self.lookup_ops[i] is None:
+                pow_sum = outs[i]
+                mean = x.new_zeros(dim_size, num_feats)
+                mean.scatter_reduce_(0, index, x, 'sum', include_self=True)
+                mean = mean / count
+            else:
+                tmp_aggr, j = self.lookup_ops[i]
+                if tmp_aggr == VarAggregation:
+                    var = outs[j]
+                elif tmp_aggr == SumAggregation:
+                    pow_sum = outs[i]
+                    mean = outs[j] / count
+                elif tmp_aggr == MeanAggregation:
+                    pow_sum = outs[i]
+                    mean = outs[j]
+                else:
+                    raise NotImplementedError
+
+            if var is None:
+                var = (pow_sum / count) - (mean * mean)
+
+            outs[i] = (var.relu() + 1e-5).sqrt()
+
+        ###################################################################
+
+        out = torch.cat(outs, dim=-1)
+
+        return out
diff --git a/torch_geometric/nn/conv/message_passing.py b/torch_geometric/nn/conv/message_passing.py
index 944cc22c6c15..acb76d18028a 100644
--- a/torch_geometric/nn/conv/message_passing.py
+++ b/torch_geometric/nn/conv/message_passing.py
@@ -26,7 +26,7 @@
from torch_geometric.nn.aggr import Aggregation, MultiAggregation
from torch_geometric.nn.resolver import aggregation_resolver as aggr_resolver
from torch_geometric.typing import Adj, Size
-from torch_geometric.utils import is_torch_sparse_tensor
+from
torch_geometric.utils import is_sparse, is_torch_sparse_tensor from .utils.helpers import expand_left from .utils.inspector import Inspector, func_body_repr, func_header_repr @@ -183,14 +183,15 @@ def __init__( def __check_input__(self, edge_index, size): the_size: List[Optional[int]] = [None, None] - if is_torch_sparse_tensor(edge_index): + if is_sparse(edge_index): if self.flow == 'target_to_source': raise ValueError( ('Flow direction "target_to_source" is invalid for ' - 'message propagation via `torch.sparse.Tensor`. If ' - 'you really want to make use of a reverse message ' - 'passing flow, pass in the transposed sparse tensor to ' - 'the message passing module, e.g., `adj_t.t()`.')) + 'message propagation via `torch_sparse.SparseTensor` ' + 'or `torch.sparse.Tensor`. If you really want to make ' + 'use of a reverse message passing flow, pass in the ' + 'transposed sparse tensor to the message passing module, ' + 'e.g., `adj_t.t()`.')) the_size[0] = edge_index.size(1) the_size[1] = edge_index.size(0) return the_size @@ -212,18 +213,6 @@ def __check_input__(self, edge_index, size): the_size[1] = size[1] return the_size - elif isinstance(edge_index, SparseTensor): - if self.flow == 'target_to_source': - raise ValueError( - ('Flow direction "target_to_source" is invalid for ' - 'message propagation via `torch_sparse.SparseTensor`. If ' - 'you really want to make use of a reverse message ' - 'passing flow, pass in the transposed sparse tensor to ' - 'the message passing module, e.g., `adj_t.t()`.')) - the_size[0] = edge_index.sparse_size(1) - the_size[1] = edge_index.sparse_size(0) - return the_size - raise ValueError( ('`MessagePassing.propagate` only supports integer tensors of ' 'shape `[2, num_messages]`, `torch_sparse.SparseTensor` or ' @@ -403,9 +392,7 @@ def propagate(self, edge_index: Adj, size: Size = None, **kwargs): size = self.__check_input__(edge_index, size) # Run "fused" message and aggregation (if applicable). 
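+        # (`is_sparse` accepts both `torch_sparse.SparseTensor` and
+        # `torch.sparse.Tensor` inputs, collapsing the former two-part
+        # check.)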
- if ((isinstance(edge_index, SparseTensor) - or is_torch_sparse_tensor(edge_index)) and self.fuse - and not self.explain): + if is_sparse(edge_index) and self.fuse and not self.explain: coll_dict = self.__collect__(self.__fused_user_args__, edge_index, size, kwargs) diff --git a/torch_geometric/nn/models/__init__.py b/torch_geometric/nn/models/__init__.py index a2cbc840b0e5..b1a24014e297 100644 --- a/torch_geometric/nn/models/__init__.py +++ b/torch_geometric/nn/models/__init__.py @@ -9,7 +9,8 @@ from .graph_unet import GraphUNet from .schnet import SchNet from .dimenet import DimeNet, DimeNetPlusPlus -from .explainer import Explainer, to_captum +from .explainer import (Explainer, to_captum, to_captum_model, to_captum_input, + captum_output_to_dicts) from .gnn_explainer import GNNExplainer from .metapath2vec import MetaPath2Vec from .deepgcn import DeepGCNLayer @@ -47,6 +48,9 @@ 'DimeNetPlusPlus', 'Explainer', 'to_captum', + 'to_captum_model', + 'to_captum_input', + 'captum_output_to_dicts', 'GNNExplainer', 'MetaPath2Vec', 'DeepGCNLayer', diff --git a/torch_geometric/nn/models/explainer.py b/torch_geometric/nn/models/explainer.py index 9925ed52f6d6..e01ec25609e7 100644 --- a/torch_geometric/nn/models/explainer.py +++ b/torch_geometric/nn/models/explainer.py @@ -1,13 +1,14 @@ from inspect import signature from math import sqrt -from typing import Dict, Optional +from typing import Dict, Optional, Tuple, Union import torch from torch import Tensor from torch_geometric.data import Data +from torch_geometric.deprecation import deprecated from torch_geometric.nn import MessagePassing -from torch_geometric.typing import EdgeType +from torch_geometric.typing import EdgeType, Metadata, NodeType from torch_geometric.utils import get_num_hops, k_hop_subgraph, to_networkx @@ -114,31 +115,290 @@ def forward(self, mask, *args): return x -def to_captum(model: torch.nn.Module, mask_type: str = "edge", - output_idx: Optional[int] = None) -> torch.nn.Module: +# TODO(jinu) Is there any point of inheriting from `CaptumModel` +class CaptumHeteroModel(CaptumModel): + def __init__(self, model: torch.nn.Module, mask_type: str, output_id: int, + metadata: Metadata): + super().__init__(model, mask_type, output_id) + self.node_types = metadata[0] + self.edge_types = metadata[1] + self.num_node_types = len(self.node_types) + self.num_edge_types = len(self.edge_types) + + def _captum_data_to_hetero_data( + self, *args + ) -> Tuple[Dict[NodeType, Tensor], Dict[EdgeType, Tensor], Optional[Dict[ + EdgeType, Tensor]]]: + """Converts tuple of tensors to `x_dict`, `edge_index_dict` and + `edge_mask_dict`.""" + + if self.mask_type == 'node': + node_tensors = args[:self.num_node_types] + node_tensors = [mask.squeeze(0) for mask in node_tensors] + x_dict = dict(zip(self.node_types, node_tensors)) + edge_index_dict = args[self.num_node_types] + elif self.mask_type == 'edge': + edge_mask_tensors = args[:self.num_edge_types] + x_dict = args[self.num_edge_types] + edge_index_dict = args[self.num_edge_types + 1] + else: + node_tensors = args[:self.num_node_types] + node_tensors = [mask.squeeze(0) for mask in node_tensors] + x_dict = dict(zip(self.node_types, node_tensors)) + edge_mask_tensors = args[self.num_node_types:self.num_node_types + + self.num_edge_types] + edge_index_dict = args[self.num_node_types + self.num_edge_types] + + if 'edge' in self.mask_type: + edge_mask_tensors = [mask.squeeze(0) for mask in edge_mask_tensors] + edge_mask_dict = dict(zip(self.edge_types, edge_mask_tensors)) + else: + edge_mask_dict = None + 
return x_dict, edge_index_dict, edge_mask_dict
+
+    def forward(self, *args):
+        # Validate args:
+        if self.mask_type == "node":
+            assert len(args) >= self.num_node_types + 1
+            len_remaining_args = len(args) - (self.num_node_types + 1)
+        elif self.mask_type == "edge":
+            assert len(args) >= self.num_edge_types + 2
+            len_remaining_args = len(args) - (self.num_edge_types + 2)
+        else:
+            assert len(args) >= self.num_node_types + self.num_edge_types + 1
+            len_remaining_args = len(args) - (self.num_node_types +
+                                              self.num_edge_types + 1)
+
+        # Get main args:
+        (x_dict, edge_index_dict,
+         edge_mask_dict) = self._captum_data_to_hetero_data(*args)
+
+        if 'edge' in self.mask_type:
+            set_hetero_masks(self.model, edge_mask_dict, edge_index_dict)
+
+        if len_remaining_args > 0:
+            # If there are args other than `x_dict` and `edge_index_dict`:
+            x = self.model(x_dict, edge_index_dict,
+                           *args[-len_remaining_args:])
+        else:
+            x = self.model(x_dict, edge_index_dict)
+
+        if 'edge' in self.mask_type:
+            clear_masks(self.model)
+
+        if self.output_idx is not None:
+            x = x[self.output_idx].unsqueeze(0)
+        return x
+
+
+def _to_edge_mask(edge_index: Tensor) -> Tensor:
+    num_edges = edge_index.shape[1]
+    return torch.ones(num_edges, requires_grad=True, device=edge_index.device)
+
+
+def _raise_on_invalid_mask_type(mask_type: str):
+    if mask_type not in ['node', 'edge', 'node_and_edge']:
+        raise ValueError(f"Invalid mask type (got {mask_type})")
+
+
+def to_captum_input(x: Union[Tensor, Dict[NodeType, Tensor]],
+                    edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
+                    mask_type: str,
+                    *args) -> Tuple[Tuple[Tensor], Tuple[Tensor]]:
+    r"""Given :obj:`x`, :obj:`edge_index` and :obj:`mask_type`, converts them
+    into a format to use in `Captum.ai <https://captum.ai/>`_ attribution
+    methods. Returns the :obj:`inputs` and :obj:`additional_forward_args`
+    required for `Captum`'s :obj:`attribute` functions.
+    See :obj:`torch_geometric.nn.to_captum_model` for example usage.
+
+    Args:
+        x (Tensor or Dict[NodeType, Tensor]): The node features. For
+            heterogeneous graphs this is a dictionary holding node features
+            for each node type.
+        edge_index (Tensor or Dict[EdgeType, Tensor]): The edge indices. For
+            heterogeneous graphs this is a dictionary holding the edge index
+            for each edge type.
+        mask_type (str): Denotes the type of mask to be created with
+            a Captum explainer. Valid inputs are :obj:`"edge"`, :obj:`"node"`,
+            and :obj:`"node_and_edge"`.
+        *args: Additional forward arguments of the model being explained
+            which will be added to :obj:`additional_forward_args`.
+            For :class:`Data` this is arguments other than :obj:`x` and
+            :obj:`edge_index`. For :class:`HeteroData` this is arguments other
+            than :obj:`x_dict` and :obj:`edge_index_dict`.
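+
+    Example (an illustrative sketch; shapes assume a homogeneous graph with
+    8 nodes, 3 node features and 14 edges):
+
+    .. code-block:: python
+
+        inputs, additional_forward_args = to_captum_input(
+            x, edge_index, mask_type='node_and_edge')
+        # inputs[0].shape: [1, 8, 3] -- node features with batch dimension
+        # inputs[1].shape: [1, 14]   -- all-ones edge mask
+        # additional_forward_args == (edge_index, )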
+ """ + _raise_on_invalid_mask_type(mask_type) + + additional_forward_args = [] + if isinstance(x, Tensor) and isinstance(edge_index, Tensor): + if mask_type == "node": + inputs = [x.unsqueeze(0)] + elif mask_type == "edge": + inputs = [_to_edge_mask(edge_index).unsqueeze(0)] + additional_forward_args.append(x) + else: + inputs = [x.unsqueeze(0), _to_edge_mask(edge_index).unsqueeze(0)] + additional_forward_args.append(edge_index) + + elif isinstance(x, Dict) and isinstance(edge_index, Dict): + node_types = x.keys() + edge_types = edge_index.keys() + inputs = [] + if mask_type == "node": + for key in node_types: + inputs.append(x[key].unsqueeze(0)) + elif mask_type == "edge": + for key in edge_types: + inputs.append(_to_edge_mask(edge_index[key]).unsqueeze(0)) + additional_forward_args.append(x) + else: + for key in node_types: + inputs.append(x[key].unsqueeze(0)) + for key in edge_types: + inputs.append(_to_edge_mask(edge_index[key]).unsqueeze(0)) + additional_forward_args.append(edge_index) + + else: + raise ValueError( + "'x' and 'edge_index' need to be either" + f"'Dict' or 'Tensor' got({type(x)}, {type(edge_index)})") + additional_forward_args.extend(args) + return tuple(inputs), tuple(additional_forward_args) + + +def captum_output_to_dicts( + captum_attrs: Tuple[Tensor], mask_type: str, metadata: Metadata +) -> Tuple[Optional[Dict[NodeType, Tensor]], Optional[Dict[EdgeType, Tensor]]]: + r"""Convert the output of `Captum.ai `_ attribution + methods which is a tuple of attributions to two dictonaries with node and + edge attribution tensors. This function is used while explaining + :obj:`HeteroData` objects. See :obj:`torch_geometric.nn.to_captum_model` + for example usage. + + Args: + captum_attrs (tuple[tensor]): The output of attribution methods. + mask_type (str): Denotes the type of mask to be created with + a Captum explainer. Valid inputs are :obj:`"edge"`, :obj:`"node"`, + and :obj:`"node_and_edge"`: + + 1. :obj:`"edge"`: :obj:`captum_attrs` contains only edge + attributions. The returned tuple has no node attributions and a + edge attribution dictionary with key `EdgeType` and value + edge mask tensor of shape :obj:`[num_edges]`. + + 2. :obj:`"node"`: :obj:`captum_attrs` contains only node + attributions. The returned tuple has node attribution dictonary + with key `NodeType` and value node mask tensor of shape + :obj:`[num_nodes, num_features]` and no edge attribution. + + 3. :obj:`"node_and_edge"`: :obj:`captum_attrs` contains only node + attributions. The returned tuple contains node attribution + dictionary followed by edge attribution dictionary. + + metadata (Metadata): The metadata of the heterogeneous graph. 
+ """ + _raise_on_invalid_mask_type(mask_type) + node_types = metadata[0] + edge_types = metadata[1] + x_attr_dict, edge_attr_dict = None, None + captum_attrs = [captum_attr.squeeze(0) for captum_attr in captum_attrs] + if mask_type == "node": + assert len(node_types) == len(captum_attrs) + x_attr_dict = dict(zip(node_types, captum_attrs)) + elif mask_type == "edge": + assert len(edge_types) == len(captum_attrs) + edge_attr_dict = dict(zip(edge_types, captum_attrs)) + elif mask_type == "node_and_edge": + assert len(edge_types) + len(node_types) == len(captum_attrs) + x_attr_dict = dict(zip(node_types, captum_attrs[:len(node_types)])) + edge_attr_dict = dict(zip(edge_types, captum_attrs[len(node_types):])) + return x_attr_dict, edge_attr_dict + + +@deprecated(details='Use `torch_geometric.nn.to_captum_model` instead') +def to_captum( + model: torch.nn.Module, mask_type: str = "edge", + output_idx: Optional[int] = None, metadata: Optional[Metadata] = None +) -> Union[CaptumModel, CaptumHeteroModel]: + r""" + Alias for :obj:`to_captum_model`. + + .. warning:: + + :obj:`~torch_geometric.nn.to_captum` is deprecated and will + be removed in a future release. + Use :obj:`torch_geometric.nn.to_captum_model` instead. + + """ + return to_captum_model(model, mask_type, output_idx, metadata) + + +def to_captum_model( + model: torch.nn.Module, mask_type: str = "edge", + output_idx: Optional[int] = None, metadata: Optional[Metadata] = None +) -> Union[CaptumModel, CaptumHeteroModel]: r"""Converts a model to a model that can be used for `Captum.ai `_ attribution methods. + Sample code for homogenous graphs: + .. code-block:: python from captum.attr import IntegratedGradients - from torch_geometric.nn import GCN, to_captum + from torch_geometric.data import Data + from torch_geometric.nn import GCN + from torch_geometric.nn import to_captum_model, to_captum_input + + data = Data(x=(...), edge_index(...)) model = GCN(...) ... # Train the model. # Explain predictions for node `10`: + mask_type="edge" output_idx = 10 + captum_model = to_captum_model(model, mask_type, output_idx) + inputs, additional_forward_args = to_captum_input(data.x, + data.edge_index,mask_type) + + ig = IntegratedGradients(captum_model) + ig_attr = ig.attribute(inputs = inputs, + target=int(y[output_idx]), + additional_forward_args=additional_forward_args, + internal_batch_size=1) + - captum_model = to_captum(model, mask_type="edge", - output_idx=output_idx) - edge_mask = torch.ones(num_edges, requires_grad=True, device=device) + Sample code for heterogenous graphs: + + .. code-block:: python + + from captum.attr import IntegratedGradients + + from torch_geometric.data import HeteroData + from torch_geometric.nn import HeteroConv + from torch_geometric.nn import (captum_output_to_dicts, + to_captum_model, to_captum_input) + + data = HeteroData(...) + model = HeteroConv(...) + ... # Train the model. + + # Explain predictions for node `10`: + mask_type="edge" + metadata = data.metadata + output_idx = 10 + captum_model = to_captum_model(model, mask_type, output_idx, metadata) + inputs, additional_forward_args = to_captum_input(data.x_dict, + data.edge_index_dict, mask_type) ig = IntegratedGradients(captum_model) - ig_attr = ig.attribute(edge_mask.unsqueeze(0), + ig_attr = ig.attribute(inputs=inputs, target=int(y[output_idx]), - additional_forward_args=(x, edge_index), + additional_forward_args=additional_forward_args, internal_batch_size=1) + edge_attr_dict = captum_output_to_dicts(ig_attr, mask_type, metadata) + .. 
        For an example of using a Captum attribution method within PyG, see
@@ -150,31 +410,19 @@ def to_captum(model: torch.nn.Module, mask_type: str = "edge",
         model (torch.nn.Module): The model to be explained.
         mask_type (str, optional): Denotes the type of mask to be created with
             a Captum explainer. Valid inputs are :obj:`"edge"`, :obj:`"node"`,
-            and :obj:`"node_and_edge"`:
-
-            1. :obj:`"edge"`: The inputs to the forward function should be an
-            edge mask tensor of shape :obj:`[1, num_edges]`, a regular
-            :obj:`x` matrix and a regular :obj:`edge_index` matrix.
-
-            2. :obj:`"node"`: The inputs to the forward function should be a
-            node feature tensor of shape :obj:`[1, num_nodes, num_features]`
-            and a regular :obj:`edge_index` matrix.
-
-            3. :obj:`"node_and_edge"`: The inputs to the forward function
-            should be a node feature tensor of shape
-            :obj:`[1, num_nodes, num_features]`, an edge mask tensor of
-            shape :obj:`[1, num_edges]` and a regular :obj:`edge_index`
-            matrix.
-
-            For all mask types, additional arguments can be passed to the
-            forward function as long as the first arguments are set as
-            described. (default: :obj:`"edge"`)
+            and :obj:`"node_and_edge"`. (default: :obj:`"edge"`)
         output_idx (int, optional): Index of the output element (node or
             link index) to be explained. With :obj:`output_idx` set, the
             forward function will return the output of the model for the
             element at the index specified. (default: :obj:`None`)
+        metadata (Metadata, optional): The metadata of the heterogeneous
+            graph. Only required if explaining over a :class:`HeteroData`
+            object. (default: :obj:`None`)
     """
-    return CaptumModel(model, mask_type, output_idx)
+    if metadata is None:
+        return CaptumModel(model, mask_type, output_idx)
+    else:
+        return CaptumHeteroModel(model, mask_type, output_idx, metadata)
 
 
 class Explainer(torch.nn.Module):
diff --git a/torch_geometric/utils/__init__.py b/torch_geometric/utils/__init__.py
index 8af72feedc5c..1ac6191893ff 100644
--- a/torch_geometric/utils/__init__.py
+++ b/torch_geometric/utils/__init__.py
@@ -18,7 +18,8 @@
 from .mask import index_to_mask, mask_to_index
 from .to_dense_batch import to_dense_batch
 from .to_dense_adj import to_dense_adj
-from .sparse import dense_to_sparse
+from .sparse import (dense_to_sparse, is_sparse, is_torch_sparse_tensor,
+                     to_torch_coo_tensor)
 from .unbatch import unbatch, unbatch_edge_index
 from .normalized_cut import normalized_cut
 from .grid import grid
@@ -36,7 +37,6 @@
                                        structured_negative_sampling_feasible)
 from .train_test_split_edges import train_test_split_edges
 from .scatter import scatter
-from .torch_sparse_tensor import is_torch_sparse_tensor
 from .spmm import spmm
 
 __all__ = [
@@ -98,6 +98,8 @@
     'train_test_split_edges',
     'scatter',
     'is_torch_sparse_tensor',
+    'is_sparse',
+    'to_torch_coo_tensor',
     'spmm',
 ]
diff --git a/torch_geometric/utils/sparse.py b/torch_geometric/utils/sparse.py
index cc3032ea75bd..ffcf504b4f8b 100644
--- a/torch_geometric/utils/sparse.py
+++ b/torch_geometric/utils/sparse.py
@@ -1,7 +1,8 @@
-from typing import Tuple
+from typing import Any, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
+from torch_sparse import SparseTensor
 
 
 def dense_to_sparse(adj: Tensor) -> Tuple[Tensor, Tensor]:
@@ -46,3 +47,69 @@ def dense_to_sparse(adj: Tensor) -> Tuple[Tensor, Tensor]:
     row = batch + edge_index[1]
     col = batch + edge_index[2]
     return torch.stack([row, col], dim=0), edge_attr
+
+
+def is_torch_sparse_tensor(src: Any) -> bool:
+    """Returns :obj:`True` if the input :obj:`src` is a
+    :class:`torch.sparse.Tensor` (in any sparse layout).
+
+    Args:
+        src (Any): The input object to be checked.
+    """
+    return isinstance(src, Tensor) and src.is_sparse
+
+
+def is_sparse(src: Any) -> bool:
+    """Returns :obj:`True` if the input :obj:`src` is of type
+    :class:`torch.sparse.Tensor` (in any sparse layout) or of type
+    :class:`torch_sparse.SparseTensor`.
+
+    Args:
+        src (Any): The input object to be checked.
+    """
+    return is_torch_sparse_tensor(src) or isinstance(src, SparseTensor)
+
+
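The two predicates differ only in whether :class:`torch_sparse.SparseTensor` counts as sparse; a short sketch (graph values assumed for illustration):

.. code-block:: python

    import torch
    from torch_sparse import SparseTensor
    from torch_geometric.utils import is_sparse, is_torch_sparse_tensor

    edge_index = torch.tensor([[0, 1],
                               [1, 0]])
    adj1 = torch.sparse_coo_tensor(edge_index, torch.ones(2), (2, 2))
    adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(2, 2))

    is_torch_sparse_tensor(adj1)  # True: a torch.Tensor with a sparse layout
    is_torch_sparse_tensor(adj2)  # False: torch_sparse.SparseTensor excluded
    is_sparse(adj1)               # True
    is_sparse(adj2)               # True: accepts either representation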
+def to_torch_coo_tensor(
+    edge_index: Tensor,
+    edge_attr: Optional[Tensor] = None,
+    size: Optional[Union[int, Tuple[int, int]]] = None,
+) -> Tensor:
+    """Converts a sparse adjacency matrix defined by edge indices and edge
+    attributes to a :class:`torch.sparse.Tensor`.
+
+    Args:
+        edge_index (LongTensor): The edge indices.
+        edge_attr (Tensor, optional): The edge attributes.
+            (default: :obj:`None`)
+        size (int or (int, int), optional): The size of the sparse matrix.
+            If given as an integer, will create a quadratic sparse matrix.
+            If set to :obj:`None`, will infer a quadratic sparse matrix based
+            on :obj:`edge_index.max() + 1`. (default: :obj:`None`)
+
+    :rtype: :class:`torch.sparse.FloatTensor`
+
+    Example:
+
+        >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
+        ...                            [1, 0, 2, 1, 3, 2]])
+        >>> to_torch_coo_tensor(edge_index)
+        tensor(indices=tensor([[0, 1, 1, 2, 2, 3],
+                               [1, 0, 2, 1, 3, 2]]),
+               values=tensor([1., 1., 1., 1., 1., 1.]),
+               size=(4, 4), nnz=6, layout=torch.sparse_coo)
+    """
+    if size is None:
+        size = int(edge_index.max()) + 1
+    if not isinstance(size, (tuple, list)):
+        size = (size, size)
+
+    if edge_attr is None:
+        edge_attr = torch.ones(edge_index.size(1), device=edge_index.device)
+
+    size = tuple(size) + edge_attr.size()[1:]
+    out = torch.sparse_coo_tensor(edge_index, edge_attr, size,
+                                  device=edge_index.device)
+    out = out.coalesce()
+    return out
diff --git a/torch_geometric/utils/spmm.py b/torch_geometric/utils/spmm.py
index 8c66f1dd4bec..98d8b74cb0bb 100644
--- a/torch_geometric/utils/spmm.py
+++ b/torch_geometric/utils/spmm.py
@@ -4,7 +4,7 @@
 from torch import Tensor
 from torch_sparse import SparseTensor, matmul
 
-from .torch_sparse_tensor import is_torch_sparse_tensor
+from .sparse import is_torch_sparse_tensor
 
 
 @torch.jit._overload
diff --git a/torch_geometric/utils/torch_sparse_tensor.py b/torch_geometric/utils/torch_sparse_tensor.py
deleted file mode 100644
index efe054689d55..000000000000
--- a/torch_geometric/utils/torch_sparse_tensor.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from typing import Any
-
-from torch import Tensor
-
-
-def is_torch_sparse_tensor(src: Any) -> bool:
-    """Returns :obj:`True` if the input :obj:`x` is a PyTorch
-    :obj:`SparseTensor` (in any sparse format).
-
-    Args:
-        src (Any): The input object to be checked.
-    """
-    return isinstance(src, Tensor) and src.is_sparse
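Beyond the scalar-weight doctest above, :obj:`to_torch_coo_tensor` also accepts multi-dimensional edge attributes, appending their trailing dimensions to the inferred sparse size. A minimal sketch with assumed values:

.. code-block:: python

    import torch
    from torch_geometric.utils import to_torch_coo_tensor

    edge_index = torch.tensor([[0, 1, 1, 2],
                               [1, 0, 2, 1]])
    edge_attr = torch.randn(4, 8)  # one 8-dimensional attribute per edge

    adj = to_torch_coo_tensor(edge_index, edge_attr)
    # adj.shape -> torch.Size([3, 3, 8]): the 3 x 3 size is inferred from
    # edge_index.max() + 1, the values carry the 8-dimensional attributes,
    # and the returned tensor is already coalesced.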