<a href="https://colab.research.google.com/github/prithvirajanR/24h-ecg-analysis/blob/main/HyperbolicGNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Hyperbolic Graph Neural Networks

This notebook is part of the lab seminar [Machine Learning and AI in the Life Sciences: Methods and Applications](https://www.fu-berlin.de/vv/de/lv/935203) at Freie Universität Berlin.

In this practical, we will implement a hyperbolic graph convolutional network. As last week, we will apply it to graph-level prediction of the properties of drug-like molecules.

## Setup environment

Let's install the required Python packages. We will be working with the [`jax` library](https://docs.jax.dev/en/latest/index.html) and its ecosystem of deep learning tools. For hyperbolic space, we use the [`Morphomatics` library](https://morphomatics.github.io/).

In [None]:
%%capture
!pip install ipdb
!pip install jax jraph flax optax torch-geometric
!pip install morphomatics

# QM9 Dataset

We will work with the QM9 dataset, a benchmark database of molecular properties that is widely used in machine learning. It contains data for 133k small organic molecules with 19 regression targets. We use the PyG library to load the molecules as graphs equipped with a number of pre-computed features:

**Atom features (`G.x`)** - $\mathbb{R}^{|V| \times 11}$
- 1st-5th features: Atom type (one-hot: H, C, N, O, F)
- 6th feature (also `G.z`): Atomic number (number of protons).
- 7th feature: Aromatic (binary)
- 8th-10th features: Electron orbital hybridization (one-hot: sp, sp2, sp3)
- 11th feature: Number of hydrogens

**Edge index (`G.edge_index`)** - $\mathbb{N}^{2\times|E|}$
- An integer tensor of shape 2 × `G.num_edges` that describes the edge connectivity of the graph (each bond is listed once per direction)

**Edge features (`G.edge_attr`)** - $\mathbb{R}^{|E|\times 4}$
- 1st-4th features: bond type (one-hot: single, double, triple, aromatic)

**Atom positions (`G.pos`)** - $\mathbb{R}^{|V|\times 3}$
- 3D coordinates of each atom.
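
To make the layout concrete, here is a hand-built sketch of a water molecule (H2O) encoded with the atom and edge features listed above. The values are written by hand, not read from the dataset; the helper `atom_features` is hypothetical, and leaving all hybridization flags of the hydrogen atoms at zero is an assumption of this sketch.

```python
import numpy as np

# Hand-built illustration of the QM9 feature layout (not loaded from the dataset).
ATOM_TYPES = ["H", "C", "N", "O", "F"]

def atom_features(symbol, atomic_number, aromatic, hybridization, num_hs):
    """Hypothetical helper returning an 11-dim vector in the documented order."""
    one_hot = [1.0 if s == symbol else 0.0 for s in ATOM_TYPES]               # features 1-5
    hyb = [1.0 if h == hybridization else 0.0 for h in ("sp", "sp2", "sp3")]  # features 8-10
    return np.array(one_hot + [atomic_number, float(aromatic)] + hyb + [num_hs])

# Node 0: oxygen (sp3, two hydrogen neighbors); nodes 1 and 2: hydrogens.
x = np.stack([
    atom_features("O", 8, False, "sp3", 2),
    atom_features("H", 1, False, None, 0),  # assumption: no sp/sp2/sp3 flag for H
    atom_features("H", 1, False, None, 0),
])

# Each chemical bond appears twice in edge_index, once per direction.
edge_index = np.array([[0, 1, 0, 2],
                       [1, 0, 2, 0]])
# One-hot bond type (single, double, triple, aromatic): both O-H bonds are single.
edge_attr = np.tile([1.0, 0.0, 0.0, 0.0], (edge_index.shape[1], 1))

print(x.shape, edge_index.shape, edge_attr.shape)
# Decode the atom types from the first five columns.
print([ATOM_TYPES[i] for i in x[:, :5].argmax(axis=1)])
```

The first five columns are enough to recover the element of every node, which is handy when labelling nodes in plots.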

**Target (`G.y`)** - $\mathbb{R}^{19}$

| Target | Property                      | Unit         | Description                                              |
|--------|-------------------------------|--------------|----------------------------------------------------------|
| 0      | $\mu$                         | D            | Dipole moment                                            |
| 1      | $\alpha$                      | $a_0^3$      | Isotropic polarizability                                |
| 2      | $\epsilon_{\textrm{HOMO}}$    | eV           | Highest occupied molecular orbital energy                |
| 3      | $\epsilon_{\textrm{LUMO}}$    | eV           | Lowest unoccupied molecular orbital energy               |
| 4      | $\Delta \epsilon$             | eV           | Gap between $\epsilon_{\textrm{HOMO}}$ and $\epsilon_{\textrm{LUMO}}$ |
| 5      | $\langle R^2 \rangle$         | $a_0^2$      | Electronic spatial extent                                |
| 6      | ZPVE                          | eV           | Zero point vibrational energy                            |
| 7      | $U_0$                         | eV           | Internal energy at 0K                                    |
| 8      | $U$                           | eV           | Internal energy at 298.15K                               |
| 9      | $H$                           | eV           | Enthalpy at 298.15K                                      |
| 10     | $G$                           | eV           | Free energy at 298.15K                                   |
| 11     | $c_{\textrm{v}}$              | cal/mol·K    | Heat capacity at 298.15K                                 |
| 12     | $U_0^{\textrm{ATOM}}$         | eV           | Atomization energy at 0K                                 |
| 13     | $U^{\textrm{ATOM}}$           | eV           | Atomization energy at 298.15K                            |
| 14     | $H^{\textrm{ATOM}}$           | eV           | Atomization enthalpy at 298.15K                          |
| 15     | $G^{\textrm{ATOM}}$           | eV           | Atomization free energy at 298.15K                       |
| 16     | $A$                           | GHz          | Rotational constant                                      |
| 17     | $B$                           | GHz          | Rotational constant                                      |
| 18     | $C$                           | GHz          | Rotational constant                                      |



In [None]:
from typing import NamedTuple

import torch
from torch_geometric.datasets import QM9
import torch_geometric.transforms as T
from torch_geometric.utils import remove_self_loops

import numpy as np
import jax
import jax.numpy as jnp

class SetTarget(NamedTuple):
    """
    Only keep the label for a specific target (there are 19 targets in QM9).
    """
    target: int = 0

    def __call__(self, data):
        data.y = data.y[:, self.target]
        return data

# Transforms which are applied during data loading:
transform = SetTarget(0)

# Load QM9 with PyTorch Geometric (the data is downloaded on first use)
qm9_dataset = QM9('./qm9/', transform=transform)
print(f"Total number of samples: {len(qm9_dataset)}.")

# Split datasets (subset of full dataset for efficiency)
train_dataset = qm9_dataset[:1000]
val_dataset = qm9_dataset[1000:2000]
test_dataset = qm9_dataset[2000:3000]
print(f"Created dataset splits with {len(train_dataset)} training, {len(val_dataset)} validation, {len(test_dataset)} test samples.")


Downloading https://data.pyg.org/datasets/qm9_v3.zip
Extracting qm9/raw/qm9_v3.zip
Processing...
Using a pre-processed version of the dataset. Please install 'rdkit' to alternatively process the raw data.
Done!


Total number of samples: 130831.
Created dataset splits with 1000 training, 1000 validation, 1000 test samples.


In [None]:
import torch_geometric
from typing import List, Generator
from torch.utils.data import DataLoader

import jraph

BATCH_SIZE = 32
MAX_NUM_NODES = 29  # the largest QM9 molecule has 29 atoms (hydrogens included)
MAX_NUM_EDGES = 56  # per-graph edge budget used for padding
TYPE = jnp.float32

def jraph_iterate(data: List[torch_geometric.data.Data]) -> Generator[jraph.GraphsTuple, None, None]:
    """Convert each PyG graph into a single-graph jraph.GraphsTuple."""
    for G in data:
        yield jraph.GraphsTuple(
            n_node=jnp.asarray([G.num_nodes]),
            n_edge=jnp.asarray([G.num_edges]),
            nodes=jnp.asarray(G.x, dtype=TYPE),
            edges=jnp.ones(G.num_edges, dtype=TYPE),  # bond features are not used here
            globals=jnp.asarray(G.y.numpy().reshape(1, -1), dtype=TYPE),  # shape (1, 1)
            senders=jnp.asarray(G.edge_index[0]),
            receivers=jnp.asarray(G.edge_index[1]))

def jraph_collate(batch: List[torch_geometric.data.Data]) -> jraph.GraphsTuple:
    """
    Collate function specifies how to combine a list of data samples into one padded batch.
    """
    return next(jraph.dynamically_batch(
        jraph_iterate(batch),
        n_node=BATCH_SIZE * MAX_NUM_NODES + 1,  # Plus one for the extra padding node.
        n_edge=BATCH_SIZE * MAX_NUM_EDGES,
        n_graph=BATCH_SIZE + 1))  # Plus one for the padding graph.

# Create PyTorch data loaders with the custom collate function
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, collate_fn=jraph_collate, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, collate_fn=jraph_collate, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, collate_fn=jraph_collate, shuffle=False)

# Hyperbolic embedding

> <img src="https://cdn-icons-png.freepik.com/256/8089/8089604.png?semt=ais_hybrid" alt="drawing" width="100"/> <br>
> * implement a function that embeds a QM9 graph `G` into the Lorentz model of hyperbolic space (i.e., assign a point in hyperbolic space to each node based on its features `G.x`)
> * visualize the embedding for an example graph embedded into hyperbolic 2-space

**TIP:** Note that the definitions in [Liu et al. 2019] and the implementation in `Morphomatics` differ in the choice of symmetry axis of the hyperboloid, i.e. $(1, 0, ..., 0)$ versus $(0, ..., 0, 1)$, respectively.
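
In [Liu et al. 2019], the hyperboloid is $\{x \in \mathbb{R}^{d+1} : \langle x, x\rangle_{\mathcal{L}} = -1,\ x_0 > 0\}$ with $\langle x, y\rangle_{\mathcal{L}} = -x_0 y_0 + \sum_{i=1}^{d} x_i y_i$, so the time-like coordinate comes first. Switching to the convention with symmetry axis $(0, \dots, 0, 1)$ only moves that coordinate to the last position. A minimal sketch in `jax.numpy`; the helper names are hypothetical and not part of the Morphomatics API:

```python
import jax.numpy as jnp

def minkowski_first(x, y):
    """Lorentz inner product with the time coordinate first (symmetry axis (1, 0, ..., 0))."""
    return -x[..., 0] * y[..., 0] + jnp.sum(x[..., 1:] * y[..., 1:], axis=-1)

def minkowski_last(x, y):
    """Same inner product with the time coordinate last (symmetry axis (0, ..., 0, 1))."""
    return -x[..., -1] * y[..., -1] + jnp.sum(x[..., :-1] * y[..., :-1], axis=-1)

def first_to_last(x):
    """Move the time coordinate from index 0 to index -1."""
    return jnp.roll(x, shift=-1, axis=-1)

def last_to_first(x):
    """Inverse of first_to_last."""
    return jnp.roll(x, shift=1, axis=-1)

t = 0.7
p = jnp.array([jnp.cosh(t), jnp.sinh(t), 0.0])  # a point on H^2, time coordinate first
q = first_to_last(p)                             # the same point, time coordinate last
print(minkowski_first(p, p), minkowski_last(q, q))  # both approximately -1
```

Converting at the boundary between your own code and Morphomatics calls keeps the rest of the implementation in a single convention.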

In [None]:
from morphomatics.manifold import HyperbolicSpace
import jraph

def embedd(G: jraph.GraphsTuple, d: int) -> jraph.GraphsTuple:
  x = G.nodes  # node features (G.x in PyG)
  # TODO: implement a function that embeds graph G into hyperbolic d-space
  return G._replace(nodes=x)
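
A common way to place nodes on the hyperboloid (and, to our reading, the approach of [Liu et al. 2019]) is to treat transformed node features $u \in \mathbb{R}^d$ as tangent vectors $v = (0, u)$ at the origin $o = (1, 0, \dots, 0)$ and apply the exponential map $\exp_o(v) = \big(\cosh\lVert u\rVert,\ \sinh\lVert u\rVert\,u/\lVert u\rVert\big)$. The sketch below implements $\exp_o$, its inverse $\log_o$, and the projection $x \mapsto x_{1:}/(1 + x_0)$ into the Poincaré disk for plotting. The random projection from the 11 atom features to $d = 2$ is a hypothetical stand-in for a learned map, and the features themselves are random placeholders for `G.nodes`.

```python
import jax
import jax.numpy as jnp

EPS = 1e-7

def _safe_norm(u):
    # sqrt(max(|u|^2, EPS^2)) avoids division by zero and NaN gradients at u = 0
    return jnp.sqrt(jnp.maximum(jnp.sum(u * u, axis=-1, keepdims=True), EPS**2))

def expmap0(u):
    """Map u in R^d (tangent space at o = (1, 0, ..., 0)) onto H^d in R^(d+1), time coordinate first."""
    n = _safe_norm(u)
    return jnp.concatenate([jnp.cosh(n), jnp.sinh(n) * u / n], axis=-1)

def logmap0(x):
    """Inverse of expmap0: spatial part of log_o(x) (its time component is zero)."""
    xs = x[..., 1:]
    n = _safe_norm(xs)  # on H^d, |x_spatial| = sinh(dist(o, x))
    return jnp.arcsinh(n) * xs / n

def to_poincare(x):
    """Map hyperboloid points (time coordinate first) into the open unit ball."""
    return x[..., 1:] / (1.0 + x[..., :1])

# Hypothetical feature-to-tangent map: a fixed random projection from 11 features to d = 2.
d = 2
W = jax.random.normal(jax.random.PRNGKey(0), (11, d)) / jnp.sqrt(11.0)
features = jax.random.uniform(jax.random.PRNGKey(1), (5, 11))  # stand-in for G.nodes
h = expmap0(features @ W)  # (5, 3): points on H^2 inside R^3
disk = to_poincare(h)      # (5, 2): Poincare-disk coordinates, e.g. for plt.scatter
```

Raw QM9 features (atomic numbers, hydrogen counts) are not normalized, so without scaling the embedded points tend to crowd towards the boundary of the disk. For Morphomatics, move the time coordinate to the last position before passing points on.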

# Hyperbolic graph convolutional layer


> <img src="https://cdn-icons-png.freepik.com/256/8089/8089604.png?semt=ais_hybrid" alt="drawing" width="100"/> <br>
> * Implement the hyperbolic graph convolution layer introduced in [Liu et al. 2019]
> * Create and run a unit test that checks that your module is equivariant with respect to node permutations


> Liu, Q., Nickel, M., & Kiela, D.:</br>
> **[Hyperbolic graph neural networks.](https://arxiv.org/pdf/1910.12892)**  
> Advances in Neural Information Processing Systems 32, 2019.</br>
> [![Preprint](https://img.shields.io/badge/arXiv-1910.12892-red)](http://arxiv.org/abs/1910.12892)

In [None]:
import flax.linen as nn

class HyperbolicGraphConvolution(nn.Module):
    psi: nn.Module
    phi: nn.Module

    @nn.compact
    def __call__(self, G: jraph.GraphsTuple) -> jraph.GraphsTuple:
      x = G.nodes

      # TODO: Your code here

      return G._replace(nodes=x)
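
The following is a minimal, framework-free sketch of a tangent-space graph convolution on the Lorentz model in the spirit of [Liu et al. 2019]: nodes are mapped to the tangent space at the origin with $\log_o$, transformed linearly, averaged over each node and its in-neighbors, passed through a nonlinearity, and mapped back with $\exp_o$. The mean normalization, the placement of the ReLU in the tangent space, and all function names are assumptions of this sketch, not the paper's exact formulation. It also contains a permutation-equivariance check of the kind the exercise asks for.

```python
import jax
import jax.numpy as jnp

EPS = 1e-7

def _safe_norm(u):
    # sqrt(max(|u|^2, EPS^2)) avoids division by zero and NaN gradients at u = 0
    return jnp.sqrt(jnp.maximum(jnp.sum(u * u, axis=-1, keepdims=True), EPS**2))

def expmap0(u):
    """Exponential map at o = (1, 0, ..., 0) of the Lorentz model (time coordinate first)."""
    n = _safe_norm(u)
    return jnp.concatenate([jnp.cosh(n), jnp.sinh(n) * u / n], axis=-1)

def logmap0(x):
    """Spatial part of the logarithmic map at o (its time component is zero)."""
    xs = x[..., 1:]
    n = _safe_norm(xs)
    return jnp.arcsinh(n) * xs / n

def hyperbolic_conv(h, senders, receivers, W, b):
    """Tangent-space graph convolution from H^d to H^d_out.

    h: (N, d + 1) hyperboloid points, senders/receivers: (E,) edge endpoints,
    W: (d, d_out) weight matrix, b: (d_out,) bias.
    """
    num_nodes = h.shape[0]
    t = logmap0(h) @ W                                   # linear map in the tangent space at o
    msg = jax.ops.segment_sum(t[senders], receivers, num_segments=num_nodes)
    deg = jax.ops.segment_sum(jnp.ones(senders.shape[0], t.dtype), receivers,
                              num_segments=num_nodes)    # number of incoming edges per node
    agg = (t + msg) / (1.0 + deg)[:, None]               # mean over the node and its in-neighbors
    return expmap0(jax.nn.relu(agg + b))                 # nonlinearity in tangent space, back to H^d_out

def check_permutation_equivariance(key):
    """Relabel nodes with a random permutation and check that outputs are permuted the same way."""
    k_h, k_w, k_b, k_p = jax.random.split(key, 4)
    h = expmap0(jax.random.normal(k_h, (6, 3)))          # 6 nodes on H^3
    senders = jnp.array([0, 1, 1, 2, 3, 4, 5, 0])
    receivers = jnp.array([1, 0, 2, 1, 4, 3, 0, 5])
    W = jax.random.normal(k_w, (3, 4))
    b = jax.random.normal(k_b, (4,))
    out = hyperbolic_conv(h, senders, receivers, W, b)

    perm = jax.random.permutation(k_p, 6)                # new node k is old node perm[k]
    inv = jnp.argsort(perm)                              # old node j becomes new node inv[j]
    out_perm = hyperbolic_conv(h[perm], inv[senders], inv[receivers], W, b)
    return bool(jnp.allclose(out_perm, out[perm], atol=1e-5))

print(check_permutation_equivariance(jax.random.PRNGKey(0)))  # True
```

Wrapping this in the `flax.linen` module above is left to the exercise; there `W` and `b` would become learned parameters, and the same relabelling trick (permute `G.nodes`, map `G.senders` and `G.receivers` through the inverse permutation) works as a unit test.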

# Hyperbolic graph convolutional network


> <img src="https://cdn-icons-png.freepik.com/256/8089/8089604.png?semt=ais_hybrid" alt="drawing" width="100"/> <br>
> * Implement an __invariant__ graph neural network based on your hyperbolic, convolutional layer that regresses the graph-level target `G.y`


In [None]:
class HyperbolicGCN(nn.Module):
    num_layers: int
    hidden_size: int

    @nn.compact
    def __call__(self, G: jraph.GraphsTuple) -> jnp.ndarray:
        # TODO: implement a GCN for graph-level prediction
        graph_embedding = None

        # Regression head
        return nn.Dense(1)(graph_embedding)
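
Invariance at the graph level can be obtained with a symmetric pooling over each graph's nodes. A minimal sketch, assuming the Lorentz conventions used above (time coordinate first): map every node to the tangent space at the origin and average per graph, using `n_node` from the batched `jraph.GraphsTuple` to build graph indices. The function name `graph_readout` is hypothetical, and averaging tangent vectors at the origin is a simple choice, not the hyperbolic (Fréchet) mean.

```python
import jax
import jax.numpy as jnp

EPS = 1e-7

def _safe_norm(u):
    # sqrt(max(|u|^2, EPS^2)) avoids division by zero and NaN gradients at u = 0
    return jnp.sqrt(jnp.maximum(jnp.sum(u * u, axis=-1, keepdims=True), EPS**2))

def expmap0(u):
    """Exponential map at o = (1, 0, ..., 0) of the Lorentz model (time coordinate first)."""
    n = _safe_norm(u)
    return jnp.concatenate([jnp.cosh(n), jnp.sinh(n) * u / n], axis=-1)

def logmap0(x):
    """Spatial part of the logarithmic map at o (its time component is zero)."""
    xs = x[..., 1:]
    n = _safe_norm(xs)
    return jnp.arcsinh(n) * xs / n

def graph_readout(h, n_node):
    """Permutation-invariant readout: per-graph mean of the tangent vectors log_o(h_i).

    h: (N, d + 1) hyperboloid points of all nodes in the batch; n_node: (G,) nodes per graph.
    Returns a (G, d) Euclidean embedding that a Dense regression head can consume.
    """
    num_graphs = n_node.shape[0]
    graph_idx = jnp.repeat(jnp.arange(num_graphs), n_node, total_repeat_length=h.shape[0])
    summed = jax.ops.segment_sum(logmap0(h), graph_idx, num_segments=num_graphs)
    return summed / jnp.maximum(n_node, 1)[:, None]  # empty graphs get a zero embedding

h = expmap0(jax.random.normal(jax.random.PRNGKey(0), (5, 2)))  # graph 0: 3 nodes, graph 1: 2 nodes
n_node = jnp.array([3, 2])
emb = graph_readout(h, n_node)
perm = jnp.array([2, 0, 1, 4, 3])  # shuffles nodes only within each graph
emb_perm = graph_readout(h[perm], n_node)
print(emb.shape, jnp.allclose(emb, emb_perm, atol=1e-6))  # (2, 2) True
```

Inside `HyperbolicGCN`, the output of the last convolution could pass through such a readout (with `G.nodes` and `G.n_node`) before `nn.Dense(1)`.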

# Predicting the electric dipole moment


> <img src="https://cdn-icons-png.freepik.com/256/8089/8089604.png?semt=ais_hybrid" alt="drawing" width="100"/> <br>
> * Train your model on (a subset of) the QM9 dataset
> * Evaluate the performance of your model
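
Because `jraph_collate` pads every batch with one extra graph, losses and metrics should ignore that graph. A minimal sketch of masked MAE and MSE in `jax.numpy`, assuming predictions and targets of shape `(num_graphs, 1)` and a boolean per-graph mask; the function names are hypothetical:

```python
import jax.numpy as jnp

def _masked_mean(values, graph_mask):
    """Average over real graphs only; padding graphs (mask False) contribute nothing."""
    mask = graph_mask.astype(values.dtype)
    return jnp.sum(values * mask) / jnp.maximum(jnp.sum(mask), 1.0)

def masked_mae(pred, target, graph_mask):
    err = jnp.reshape(pred, (-1,)) - jnp.reshape(target, (-1,))
    return _masked_mean(jnp.abs(err), graph_mask)

def masked_mse(pred, target, graph_mask):
    err = jnp.reshape(pred, (-1,)) - jnp.reshape(target, (-1,))
    return _masked_mean(err ** 2, graph_mask)

pred = jnp.array([[1.0], [2.0], [5.0]])    # model output, shape (G, 1)
target = jnp.array([[1.5], [2.0], [0.0]])  # targets, shape (G, 1)
mask = jnp.array([True, True, False])      # the last graph is padding
print(float(masked_mae(pred, target, mask)), float(masked_mse(pred, target, mask)))  # 0.25 0.125
```

In the training loop, the mask can be obtained with `jraph.get_graph_padding_mask(G)` and the targets from `G.globals`; `masked_mse` then serves as the loss passed to `jax.grad`, and `masked_mae` (in Debye for the dipole moment) as an evaluation metric on the validation and test loaders.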
