
PyG 2.4.0: Model compilation, on-disk datasets, hierarchical sampling

@akihironitta released this on 12 Oct 08:28
f97f3e1

We are excited to announce the release of PyG 2.4 🎉🎉🎉

PyG 2.4 is the culmination of work from 62 contributors who have worked on features and bug-fixes for a total of over 500 commits since torch-geometric==2.3.1.

Highlights

PyTorch 2.1 and torch.compile(dynamic=True) support

The long wait has come to an end! With the release of PyTorch 2.1, PyG 2.4 now brings full support for torch.compile on graphs of varying size via the dynamic=True option, which is especially useful for use cases that involve DataLoader or NeighborLoader. Examples and tutorials have been updated accordingly (#8134), and models and layers in torch_geometric.nn have been tested to produce zero graph breaks:

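import torch_geometric
from torch_geometric.nn import GraphSAGE

model = GraphSAGE(in_channels=16, hidden_channels=64, num_layers=2)
# Compile with dynamic shapes so that varying graph sizes avoid recompilation:
model = torch_geometric.compile(model, dynamic=True)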

When the dynamic=True option is enabled, PyTorch will attempt up front to generate a kernel that is as dynamic as possible, avoiding recompilations when tensor sizes change across mini-batches. As such, you should only omit dynamic=True when graph sizes are guaranteed to never change. Note that dynamic=True requires PyTorch >= 2.1.0 to be installed.

PyG 2.4 is fully compatible with PyTorch 2.1, and supports the following combinations:

| PyTorch 2.1 | cpu | cu118 | cu121 |
| Linux       | ✅  | ✅    | ✅    |
| macOS       | ✅  |       |       |
| Windows     | ✅  | ✅    | ✅    |

You can still install PyG 2.4 on older PyTorch releases, as far back as PyTorch 1.11, in case you are not eager to update your PyTorch version.

OnDiskDataset Interface

We added the OnDiskDataset base class for creating large graph datasets (e.g., molecular databases with billions of graphs), which do not easily fit into CPU memory at once (#8028, #8044, #8046, #8051, #8052, #8054, #8057, #8058, #8066, #8088, #8092, #8106). OnDiskDataset leverages our newly introduced Database backend (sqlite3 by default) for on-disk storage and access of graphs, supports DataLoader out-of-the-box, and is optimized for maximum performance.

OnDiskDataset utilizes a user-specified schema to store data as efficiently as possible (instead of relying on Python pickling). The schema can take int, float, str, object or a dictionary with dtype and size keys (for specifying tensor data) as input, and can be nested as a dictionary. For example,

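import torch
from torch_geometric.data import OnDiskDataset

root = 'data/my_on_disk_dataset'  # storage directory (any writable path)
dataset = OnDiskDataset(root, schema={
    'x': dict(dtype=torch.float, size=(-1, 16)),
    'edge_index': dict(dtype=torch.long, size=(2, -1)),
    'y': float,
})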

creates a database with three columns, where x and edge_index are stored as binary data, and y is stored as a float.
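To make "stored as binary data" concrete, here is a standard-library sketch of how a tensor column declared with a dtype and a size such as (-1, 16) can round-trip through raw bytes, with the -1 dimension recovered from the byte length on the way back. It illustrates the idea only: the helpers encode_column/decode_column, the use of nested lists instead of tensors, and the float32/int64 type codes are our own assumptions, not PyG's actual SQLite backend code. The plain y column needs no such treatment, since SQLite stores floats natively.

```python
from array import array
from typing import List, Sequence, Tuple

# Hypothetical mapping from schema dtypes to array type codes
# (assumption: float32 for torch.float, int64 for torch.long).
TYPECODES = {'float': 'f', 'long': 'q'}


def encode_column(rows: Sequence[Sequence[float]], dtype: str) -> bytes:
    """Flatten a 2D nested list in row-major order and pack it into bytes."""
    flat = array(TYPECODES[dtype], [v for row in rows for v in row])
    return flat.tobytes()


def decode_column(blob: bytes, dtype: str,
                  size: Tuple[int, int]) -> List[list]:
    """Unpack raw bytes and restore the 2D shape; at most one dim may be -1."""
    flat = array(TYPECODES[dtype])
    flat.frombytes(blob)
    rows, cols = size
    if rows == -1 and cols == -1:
        raise ValueError('only one dimension can be inferred')
    if rows == -1:
        rows = len(flat) // cols
    elif cols == -1:
        cols = len(flat) // rows
    return [flat[r * cols:(r + 1) * cols].tolist() for r in range(rows)]


# Usage: a 3-node feature matrix and a 3-edge edge_index.
x = [[float(i + j) for j in range(16)] for i in range(3)]
edge_index = [[0, 1, 2], [1, 2, 0]]
x_blob = encode_column(x, 'float')
ei_blob = encode_column(edge_index, 'long')
assert decode_column(x_blob, 'float', (-1, 16)) == x
assert decode_column(ei_blob, 'long', (2, -1)) == edge_index
```

Storing the shape in the schema rather than in every row keeps each blob minimal: only the variable dimension has to be inferred at read time.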

Afterwards, you can append data to the OnDiskDataset and retrieve data from it via dataset.append()/dataset.extend(), and dataset.get()/dataset.multi_get(), respectively. We added a fully working example on how to set up your own OnDiskDataset here (#8102). You can also convert in-memory dataset instances to an OnDiskDataset instance by running InMemoryDataset.to_on_disk_dataset() (#8116).
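The append/extend and get/multi_get workflow maps naturally onto an integer-indexed database table. The following standard-library sketch mimics that interface on top of sqlite3 (the default backend mentioned above) for graphs represented as a dict with 16-dimensional node features 'x' and a float target 'y'. The class name TinyGraphStore, its table layout, and the fixed feature width are hypothetical; this is a conceptual sketch, not how OnDiskDataset is implemented.

```python
import sqlite3
from array import array
from typing import Dict, List, Sequence, Tuple

Graph = Dict[str, object]  # {'x': list of float rows, 'y': float}


class TinyGraphStore:
    """Hypothetical minimal on-disk store keyed by a contiguous index."""

    def __init__(self, path: str = ':memory:', num_features: int = 16):
        self.num_features = num_features
        self.conn = sqlite3.connect(path)
        self.conn.execute(
            'CREATE TABLE IF NOT EXISTS graphs '
            '(id INTEGER PRIMARY KEY, x BLOB NOT NULL, y REAL NOT NULL)')

    def __len__(self) -> int:
        return self.conn.execute('SELECT COUNT(*) FROM graphs').fetchone()[0]

    def _encode(self, graph: Graph) -> Tuple[bytes, float]:
        # Pack node features as raw float32 bytes instead of pickling.
        flat = array('f', [v for row in graph['x'] for v in row])
        return flat.tobytes(), float(graph['y'])

    def _decode(self, x_blob: bytes, y: float) -> Graph:
        flat = array('f')
        flat.frombytes(x_blob)
        k = self.num_features
        x = [flat[i:i + k].tolist() for i in range(0, len(flat), k)]
        return {'x': x, 'y': y}

    def append(self, graph: Graph) -> None:
        self.extend([graph])

    def extend(self, graphs: Sequence[Graph]) -> None:
        start = len(self)
        rows = [(start + i, *self._encode(g)) for i, g in enumerate(graphs)]
        with self.conn:  # insert all rows in a single transaction
            self.conn.executemany(
                'INSERT INTO graphs (id, x, y) VALUES (?, ?, ?)', rows)

    def get(self, idx: int) -> Graph:
        row = self.conn.execute(
            'SELECT x, y FROM graphs WHERE id = ?', (idx,)).fetchone()
        if row is None:
            raise IndexError(idx)
        return self._decode(*row)

    def multi_get(self, indices: Sequence[int]) -> List[Graph]:
        if not indices:
            return []
        # Fetch all requested rows in one query, then restore the order.
        placeholders = ','.join('?' * len(indices))
        rows = self.conn.execute(
            f'SELECT id, x, y FROM graphs WHERE id IN ({placeholders})',
            list(indices)).fetchall()
        by_id = {i: self._decode(x, y) for i, x, y in rows}
        return [by_id[i] for i in indices]


store = TinyGraphStore()
store.append({'x': [[0.0] * 16], 'y': 0.5})
store.extend([{'x': [[1.0] * 16, [2.0] * 16], 'y': 1.5}])
print(len(store), store.get(1)['y'])
```

Batching inserts in one transaction and reads in one IN query is what keeps extend and multi_get cheaper than repeated single-row calls.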

Neighbor Sampling Improvements

Hierarchical Sampling

One drawback of NeighborLoader is that it computes representations for all sampled nodes at all depths of the network. However, nodes sampled in later hops no longer contribute to the representations of seed nodes in later GNN layers, so this computation is wasted. NeighborLoader is therefore marginally slower than necessary, since it computes node embeddings for nodes that are no longer needed. This is a trade-off we have made to obtain a clean, modular and experimentation-friendly GNN design, which does not tie the definition of the model to its data loader routine.

With PyG 2.4, we introduced the option to eliminate this overhead and speed up training and inference in mini-batch GNNs further, which we call "Hierarchical Neighborhood Sampling" (see here for the full tutorial) (#6661, #7089, #7244, #7425, #7594, #7942). Its main idea is to progressively trim the adjacency matrix of the returned subgraph before inputting it to each GNN layer, and it works seamlessly across several models, in both the homogeneous and heterogeneous graph settings. To support this trimming and implement it effectively, the NeighborLoader implementations in PyG and in pyg-lib additionally return the number of nodes and edges sampled in each hop. These counts are then used on a per-layer basis to trim the adjacency matrix and the various feature matrices down to only what is required (see the trim_to_layer method):

from typing import List

import torch
from torch import Tensor
from torch.nn import Linear, ModuleList

from torch_geometric.nn import SAGEConv
from torch_geometric.utils import trim_to_layer


class GNN(torch.nn.Module):
    def __init__(self, in_channels: int, hidden_channels: int,
                 out_channels: int, num_layers: int):
        super().__init__()

        self.convs = ModuleList([SAGEConv(in_channels, hidden_channels)])
        for _ in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.lin = Linear(hidden_channels, out_channels)

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        num_sampled_nodes_per_hop: List[int],
        num_sampled_edges_per_hop: List[int],
    ) -> Tensor:

        for i, conv in enumerate(self.convs):
            # Trim edge and node information to the current layer `i`.
            x, edge_index, _ = trim_to_layer(
                i, num_sampled_nodes_per_hop, num_sampled_edges_per_hop,
                x, edge_index)

            x = conv(x, edge_index).relu()

        return self.lin(x)

Corresponding examples can be found here and here.
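The per-hop bookkeeping behind trimming can be illustrated without tensors. This sketch assumes that the sampled subgraph lists nodes and edges ordered by hop (seed nodes first), so that before layer i only the outermost remaining hop has to be sliced off, and the already-trimmed lists are fed back into the next layer. The helpers trim_lists and nodes_processed are hypothetical names, and Python lists stand in for x and edge_index; this is not the library implementation.

```python
from typing import Sequence, Tuple


def trim_lists(
    layer: int,
    num_nodes_per_hop: Sequence[int],
    num_edges_per_hop: Sequence[int],
    nodes: list,
    edges: list,
) -> Tuple[list, list]:
    """Drop the outermost hop that is still present before layer `layer`.

    Assumes `nodes` and `edges` are ordered by hop (seeds first) and that
    the caller passes in the lists already trimmed for layer - 1.
    """
    if layer <= 0:
        return nodes, edges
    num_nodes = len(nodes) - num_nodes_per_hop[-layer]
    num_edges = len(edges) - num_edges_per_hop[-layer]
    return nodes[:num_nodes], edges[:num_edges]


def nodes_processed(num_nodes_per_hop: Sequence[int], num_layers: int,
                    trim: bool) -> int:
    """Count node rows touched by all layers, with or without trimming."""
    total = sum(num_nodes_per_hop)
    work = 0
    for layer in range(num_layers):
        work += total
        if trim:  # the next layer no longer needs the outermost hop
            total -= num_nodes_per_hop[-(layer + 1)]
    return work


# One seed node with num_neighbors=[10, 10] samples at most 1 + 10 + 100 nodes.
print(nodes_processed([1, 10, 100], num_layers=2, trim=False))
print(nodes_processed([1, 10, 100], num_layers=2, trim=True))
```

In this worst-case fan-out, the second layer touches 11 instead of 111 node rows, which is where the speed-up of hierarchical sampling comes from.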

Biased Sampling

Additionally, we added support for weighted/biased sampling in NeighborLoader/LinkNeighborLoader scenarios. For this, simply pass the name of your edge weight attribute via the weight_attr argument during NeighborLoader initialization, and PyG will pick up these weights to perform weighted/biased sampling (#8038):

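import torch
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader

edge_index = torch.tensor([[0, 1, 1, 2, 3, 4],
                           [1, 0, 2, 1, 4, 3]])
edge_weight = torch.rand(edge_index.size(1))  # non-negative sampling weights
data = Data(num_nodes=5, edge_index=edge_index, edge_weight=edge_weight)

loader = NeighborLoader(
    data,
    num_neighbors=[10, 10],
    weight_attr='edge_weight',  # sample neighbors proportional to this attribute
)

batch = next(iter(loader))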
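Conceptually, biased sampling draws a node's neighbors with probability proportional to their edge weights instead of uniformly. One standard way to draw k distinct neighbors under such weights is the Efraimidis-Spirakis key trick, used below purely for illustration; it is not necessarily what pyg-lib does internally. The function name weighted_sample_neighbors, the per-node neighbor list, and the choice to skip zero-weight edges are our assumptions.

```python
import random
from typing import List, Optional, Sequence


def weighted_sample_neighbors(
    neighbors: Sequence[int],
    weights: Sequence[float],
    k: int,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Draw up to k distinct neighbors, biased by non-negative edge weights.

    Each neighbor with weight w > 0 gets the key u ** (1 / w), where u is
    uniform in [0, 1); keeping the k largest keys yields a weighted sample
    without replacement (Efraimidis-Spirakis). Zero-weight edges are skipped.
    """
    rng = rng or random.Random()
    keyed = [
        (rng.random() ** (1.0 / w), n)
        for n, w in zip(neighbors, weights)
        if w > 0
    ]
    keyed.sort(reverse=True)  # largest keys first
    return [n for _, n in keyed[:k]]


# With k=1, neighbor 2 (weight 8 of a total of 10) should win about 80% of draws.
rng = random.Random(0)
hits = sum(weighted_sample_neighbors([0, 1, 2], [1.0, 1.0, 8.0], 1, rng) == [2]
           for _ in range(10_000))
print(hits / 10_000)
```

Unlike sampling with replacement, the key trick never returns the same neighbor twice, which matches the usual fan-out semantics of num_neighbors.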

New models, datasets, examples & tutorials

As part of our algorithm and documentation sprints (#7892), we have added:

Join our Slack here if you're interested in joining community sprints in the future!

Breaking Changes

  • Data.keys() is now a method instead of a property (#7629):
    <=2.3:
      data = Data(x=x, edge_index=edge_index)
      print(data.keys)
      # ['x', 'edge_index']
    2.4:
      data = Data(x=x, edge_index=edge_index)
      print(data.keys())
      # ['x', 'edge_index']
  • Dropped Python 3.7 support (#7939)
  • Removed FastHGTConv in favor of HGTConv (#7117)
  • Removed the layer_type argument from GraphMaskExplainer (#7445)
  • Renamed dest argument to dst in utils.geodesic_distance (#7708)
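Code that has to run on both sides of the Data.keys change can branch on whether the attribute is callable. The helper name data_keys is hypothetical, and the two stand-in classes only mimic the property-style (<=2.3) and method-style (2.4) interfaces so the snippet runs without PyG installed.

```python
from typing import Any, List


def data_keys(data: Any) -> List[str]:
    """Return attribute names for both PyG <= 2.3 (property) and 2.4 (method)."""
    keys = data.keys
    return list(keys() if callable(keys) else keys)


class OldStyleData:  # stand-in for PyG <= 2.3, where `keys` is a property
    @property
    def keys(self) -> List[str]:
        return ['x', 'edge_index']


class NewStyleData:  # stand-in for PyG 2.4, where `keys()` is a method
    def keys(self) -> List[str]:
        return ['x', 'edge_index']


print(data_keys(OldStyleData()), data_keys(NewStyleData()))
```

Once a codebase only targets PyG 2.4 or later, the helper can be dropped in favor of calling data.keys() directly.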

Deprecations

Features

Data and HeteroData improvements

Data-loading improvements

Better support for sparse tensors

  • Added SparseTensor support to WLConvContinuous, GeneralConv, PDNConv and ARMAConv (#8013)
  • Changed torch_sparse.SparseTensor logic to utilize torch.sparse_csr instead (#7041)
  • Added support for torch.sparse.Tensor in DataLoader (#7252)
  • Added support for torch.jit.script within MessagePassing layers without torch_sparse being installed (#7061, #7062)
  • Added unbatching logic for torch.sparse.Tensor (#7037)
  • Added support for Data.num_edges for native torch.sparse.Tensor adjacency matrices (#7104)
  • Accelerated sparse tensor conversion routines (#7042, #7043)
  • Added a sparse cross_entropy implementation (#7447, #7466)

Integration with 3rd-party libraries

  • Added FlopsCount support via fvcore (#7693)
  • Added to_dgl and from_dgl conversion functions (#7053)

torch_geometric.transforms

Bugfixes

Changes

Full Changelog

Full Changelog: 2.3.0...2.4.0

New Contributors