## Annotated DimeNet [<a href="https://arxiv.org/abs/2003.03123">Paper</a>]

Molecular representation is a crucial task in computational chemistry and drug discovery. To understand the properties and interactions of molecules, it is essential to have the ability to accurately represent their complex structures. One approach that has shown promise in this area is the DimeNet architecture (https://arxiv.org/abs/2003.03123, <a href="https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/models/dimenet.py">Pytorch Source Code</a>). DimeNet utilizes directional message passing to efficiently and accurately represent molecules. In this blog, we will use PyTorch to explore each component of this network. Feel free to follow along in a Google Colab for hands-on experience.

### Introduction

Graph Neural Networks (GNNs) have become a highly sought-after architecture for modeling the quantum mechanical properties of molecules. In the past, GNNs primarily used 2D graph representations of molecules. However, recent advancements such as the DimeNet paper, which utilizes 3D graph representations, have greatly improved the performance of GNNs on these tasks. This paper proposes a new technique called directional message passing. This approach utilizes the positions of atoms in 3D and performs message passing using inter-atomic distances and angles between triplets of atoms, resulting in a more efficient and accurate representation of the molecules.

The contributions of this paper are as follows:
* A message passing scheme that utilizes directional information.
* The construction of directional embeddings using Spherical Bessel functions and Spherical Harmonics.

### Prerequisites
This blog assumes familiarity with message passing in graph neural networks and the usage of PyTorch. For people new to it, feel free to refer to these excellent distil pub articles,
- https://distill.pub/2021/gnn-intro/
- https://distill.pub/2021/understanding-gnns/
- [Scatter Function](https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html#torch_scatter.scatter) for implementing message passing

### Overview
The blog is structured as follows:

1. An explanation of how to construct directional embeddings using 2D spherical Fourier-Bessel Basis Functions.
2. A detailed description of the architecture of DimeNet and the implementation of each component.
3. An experiment to demonstrate the use of DimeNet architecture on QM9 dataset.

### Package Installation Commands

In [1]:
# # Ensure there is pyg, torch_geometric, torch, sympy, jupyter
# conda create -n annotated-dimenet python=3.8
# conda activate annotated-dimenet
# conda install pytorch torchvision pytorch-cuda=11.6 -c pytorch -c nvidia
# pip install pyg-lib torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-1.13.1+cu116.html
# pip install torch_geometric
# conda install sympy
# pip install jupyter

### Import the necessary libraries

In [2]:
import os
import os.path as osp
from math import pi as PI
from math import sqrt
from typing import Callable, Optional, Tuple, Union

import numpy as np
import torch
from torch import Tensor
from torch.nn import Embedding, Linear
from torch_scatter import scatter
from torch_sparse import SparseTensor

from torch_geometric.data import Dataset, download_url
from torch_geometric.data.makedirs import makedirs
from torch_geometric.nn import radius_graph
from torch_geometric.nn.inits import glorot_orthogonal
from torch_geometric.nn.resolver import activation_resolver
from torch_geometric.typing import OptTensor

  from .autonotebook import tqdm as notebook_tqdm


## Directional Embeddings

<center>
<img src="imgs/message-passing.png" width="200"/> <br/>
</center>

The key idea in this paper is to include directional information during message passing. Consider an edge between atom `i` and `j`, the message from `j` to `i` is represented as $m_{ji}$. The message $m_{ji}$ is updated with the help of directional information by considering neighbours of `j` forming an atom triplet `ijk`. For example, we can use the angle between the direction `kj` (from atom `k` to atom `j`) and direction `ji` (from atom `j` to atom `i`) using $\alpha_{(kj, ji)} = \angle x_k x_j x_i$.  To update the directional embedding $m_{ji}$, we first consider messages $m_{kj}$ from all the neighbours (as seen in the figure below) to `j`. The message passing update for edge embeddings can be updated from the following three,
* incoming messages $m_{kj}$ from neighbours `k`
* the directional information $\alpha_{(kj, ji)}$ 
* the inter-atomic distance $d_{ij}$. 

$$
h_i = \sum_{k \in \mathcal{N}_i}m_{ki}
$$ 


### Representation of inter-atomic distances and angles
As we previously discussed, the message passing layer involves inter-atomic distances and angles. A logical question that arises is how these interatomic distances and angles are provided to the model. The paper employs the following two methods to transform the raw values:

* <i>Radial Basis Function (RBF)</i> $e_{RBF}^{{ji}}$ for interatomic distances $d_{ji}$.
* <i>Spherical Basis Function (SBF)</i> $a_{SBF}^{(kj, ji)}$, a joint angle and distance-based function that takes as input the angle $\alpha_{(kj, ji)}$ and the interatomic distance $d_{kj}$.

<i>Think of the above two functions as ways to transform the input (angles and distances) into a form or space where the model can more easily extract relevant information.

### Radial Basis Function
The radial basis function used here is an orthogonal basis. It takes into input the interatomic distance $d$ and a cutoff distance $c$. The functional form is as follows, 

$$
\tilde{\mathcal{e}}_{RBF}(d) = \sqrt{\frac{2}{c}}\frac{\sin{(\frac{n\pi}{c}}d)}{d}
$$

* The value of the integer $n$ ranges from $n \in [1, ... , N_{RBF}]$, where $N_{RBF}$ denotes the number of orthogonal components (bases).
* To simplify, each inter-atomic distance input $d$ will now be converted into a $N_{RBF}$ sized tensor using the radial basis function.
* Using this basis improves parameter efficiency by using 1/4th of the number of parameters used by a Gaussian RBF (Table 3 in the [Paper](https://arxiv.org/abs/2003.03123))
<!-- * NOT INCLUDING THIS The paper verifies that this basis function requires 1/4th the number of parameters ($N_{RBF}$) as compared to the gaussian radial basis -->

### Spherical Basis Function
This basis layer is a joint 2D basis for $d_{kj}$ and $\alpha_{(kj,ji)}$, a function that depends on both the interatomic distance and an angle. The functional form is as follows,

$$
\tilde{\mathcal{a}}_{SBF, ln}(d, \alpha) = \sqrt{\frac{2}{c^3 j^2_{l + 1}(z_{ln})}} j_l(\frac{z_{ln}}{c}d)Y_l^0(\alpha)
$$
Here, $l \in [0 .. N_{SHBF} - 1]$ and $n \in [1 ... N_{SRBF}]$. The $z_{ln}$ can be computed using bessel function implementations, while $Y_l^0(\alpha)$ can be computed using spherical harmonics function implementations.

<i>An easier way to think about this is that each pair of interatomic distance input $d_{ij}$ and angle $\alpha_{(kj,ji)}$ is converted into a $N_{SRBF} \times N_{SHBF}$ sized tensor using the spherical basis function.</i>

### Continuous Cutoff using an Envelope Function

As we use the 3D structure of the molecule in the model, every interatomic pair could potentially be considered for message passing. However, using all interatomic pairs could significantly increase the computational overhead. To mitigate this, we introduce a cutoff ($c$) on the interatomic distance to determine neighbors.

One issue that arises with a hard cutoff is that the above functions (radial basis and spherical basis) that depend on $d$ are no longer twice continuously differentiable. To alleviate this problem, the paper multiplies the $\tilde{\mathcal{a}}_{SBF}(d)$ and $\tilde{\mathcal{e}}_{RBF}(d)$ with an envelope function $u(d)$, which makes the function twice differentiable.

$$
\mathcal{e}_{RBF}(d) = u(d)\tilde{\mathcal{e}}_{RBF}(d)
$$
$$
\mathcal{a}_{SBF}(d) = u(d) \tilde{\mathcal{a}}_{SBF}(d)
$$

Here, $u(d)$ is the envelope function.

### Message Passing Equation
The message passing equation would be as follows,

$$
m_{ji}^{(l + 1)} = f_{\text{update}}(m_{ji}^{(l)}, \sum_{k in \mathcal{N}_j \backslash \{i\}}f_{\text{int}}(m_{kj}^{(l)}), e_{\text{RBF}}^{(ji)}, a_{\text{SBF}}^{(kj, ji)})
$$

## Implementation of Spherical Basis Function and Radial Basis Function

### Continuous Cutoff Envelope

The equation for the continuous envelope

$$
u(d) = 1 - \frac{(p + 1)(p + 2)}{2}d^p + p(p + 2)d^{p+1} - \frac{p(p + 1)}{2}d^{p + 2}
$$

The paper uses a default value of p = 6 or `exponent = 5.` in the code.


In [3]:
class Envelope(torch.nn.Module):
    def __init__(self, exponent: int):
        super().__init__()
        self.p = exponent + 1
        self.a = -(self.p + 1) * (self.p + 2) / 2
        self.b = self.p * (self.p + 2)
        self.c = -self.p * (self.p + 1) / 2

    def forward(self, x: Tensor) -> Tensor:
        p, a, b, c = self.p, self.a, self.b, self.c
        x_pow_p0 = x.pow(p - 1)
        x_pow_p1 = x_pow_p0 * x
        x_pow_p2 = x_pow_p1 * x
        
        return (
            1. / x + a * x_pow_p0 + b * x_pow_p1 +
            c * x_pow_p2) * (x < 1.0).to(x.dtype)

### Radial Basis Layer

If we recall the equation,
$$
\mathcal{e}_{RBF}(d) = u(d)\tilde{\mathcal{e}}_{RBF}(d)
$$

$$
\tilde{\mathcal{e}}_{RBF}(d) = \sqrt{\frac{2}{c}}\frac{\sin{(\frac{n\pi}{c}}d)}{d} 
$$

$$
u(d) : \text{Envelope Equation}
$$

$$
n \in [1, ... , N_{RBF}]
$$

where $N_{RBF}$ denotes the number of orthogonal bases. The paper implements the Radial Basis Function as a layer where the frequency information is learned via backpropagation. We initialize the parameter values initially as $n\pi/c$. The following mapping from the math variables to code variables is used,
- $N_{RBF} \to $ `num_radial` 
- $c \to $ `cutoff`

In [4]:
class RadialBasisLayer(torch.nn.Module):
    '''RadialBasisLayer'''
    def __init__(self, num_radial: int, cutoff: float = 5.0,
                 envelope_exponent: int = 5):
        super().__init__()
        # the c in the radial basis layer equation
        self.cutoff = cutoff
        # u(d) / envelope
        self.envelope = Envelope(envelope_exponent)
        
        # the different frequencies to be considered to generate orthogonal basis
        self.freq = torch.nn.Parameter(torch.Tensor(num_radial))
        
        # make sure we reset_parameters during __init__()
        self.reset_parameters()

    def reset_parameters(self):
        with torch.no_grad():
            torch.arange(1, self.freq.numel() + 1, out=self.freq).mul_(PI)
        self.freq.requires_grad_()

    def forward(self, dist: Tensor) -> Tensor:
        # compute d = (d/c)
        dist = (dist.unsqueeze(-1) / self.cutoff)
        # compute u(d/c) * sin(\frac{n\pi}{c} \times d)
        return self.envelope(dist) * (self.freq * dist).sin()

### Spherical Basis Layer

If we recall the equation,

$$
\mathcal{a}_{RBF}(d) = u(d) \tilde{\mathcal{a}}_{RBF}(d) \\
$$
$$
\tilde{\mathcal{a}}_{SBF, ln}(d, \alpha) = \sqrt{\frac{2}{c^3 j^2_{l + 1}(z_{ln})}} j_l(\frac{z_{ln}}{c}d)Y_l^0(\alpha)
$$
where $l \in [0 .. N_{SHBF} - 1]$ and $n \in [1 ... N_{SRBF}]$

The following mapping from math variables to code variables will be used,
* $N_{SHBF} \to $ `num_spherical` 
* $N_{SRBF} \to $ `num_radial` 
* $c \to $ `cutoff`

The following two variables represent,
* $z_{ln}$: $n'$th root of the $l$-order Bessel Functions
* $Y_l^0(\alpha)$: Special Harmonics

These variables have been implemented in `torch_geometric.nn.models.dimenet_utils` as `bessel_basis` and `real_sph_harm`. The exact details of implementing these functions using sympy's <a href='https://docs.sympy.org/latest/index.html'>symbolic computation library</a> are beyond the scope of this blog. However, it's worth noting that,
- $z_{ln}$ has a total of $N_{SHBF} \times N_{SRBF}$ values as it varies with both $l$ and $n$ whereas $Y_l^0(\alpha)$ has $N_{SHBF}$ values as it varies with $l$
- The output of these functions are sympy expressions, which means they will contain symbols like `sin()`, `cos()`, and `x`
- To convert any symbolic expression involving (`x`, `sin` and `cos`) into a lambda function, we use `sym.lambdify([x], expression, modules)` [<a href='https://docs.sympy.org/latest/modules/utilities/lambdify.html#sympy.utilities.lambdify.lambdify'>Link</a>], the modules here maps the symbolic functions `sin()`and `cos()` to our torch functions

In [5]:
import sympy as sym

from torch_geometric.nn.models.dimenet_utils import (
    bessel_basis, # SRBF
    real_sph_harm, # SHBF
)

class SphericalBasisLayer(torch.nn.Module):
    def __init__(
        self, 
        num_spherical: int, 
        num_radial: int,
        cutoff: float = 5.0, 
        envelope_exponent: int = 5
    ):
        super().__init__()
        assert num_radial <= 64
        self.num_spherical = num_spherical
        self.num_radial = num_radial
        self.cutoff = cutoff
        self.envelope = Envelope(envelope_exponent)
        
        # We can bessel basis and spherical Harmonic forms in sympy expressions
        # sympy expressions -> Equations containing `sin` `cos` `x` and `theta`
        # computing z_{ln} constant, based on l(num_spherical) and n(num_radial)
        bessel_forms = bessel_basis(num_spherical, num_radial)
        # computing Y_l^0(\alpha)
        sph_harm_forms = real_sph_harm(num_spherical)
        
        # let's fill spherical and radial functions for 
        # l \in [0, ..., N_SHBF - 1]
        # n \in [1, ..., N_SRBF]
        
        # spherical functions only dependent on l, there will be N_SHBF of them
        self.sph_funcs = []
        # bessel functions dependent on l and n, there will be N_SHBF x N_SRBF of them
        self.bessel_funcs = []
            
        # Using Sympy, we convert sympy expressions into lambda functions
        x, theta = sym.symbols('x theta')
        modules = {'sin': torch.sin, 'cos': torch.cos}
        
        # i goes from 0 to num_spherical x-x 1 (exactly the range of l)
        for i in range(num_spherical):
            if i == 0:
                sph1 = sym.lambdify([theta], sph_harm_forms[i][0], modules)(0)
                self.sph_funcs.append(lambda x: torch.zeros_like(x) + sph1)
            else:
                sph = sym.lambdify([theta], sph_harm_forms[i][0], modules)
                self.sph_funcs.append(sph)
                
            for j in range(num_radial):
                bessel = sym.lambdify([x], bessel_forms[i][j], modules)
                self.bessel_funcs.append(bessel)

    def forward(self, dist: Tensor, angle: Tensor, idx_kj: Tensor) -> Tensor:
        '''Performs Forward Pass'''
        # computes d / c
        dist = dist / self.cutoff
        
        # n, k = self.num_spherical, self.num_radial
        # computes over all radial bessel (n x k) functions and stack over feature dimension (1)
        rbf = torch.stack([f(dist) for f in self.bessel_funcs], dim=1)
        # multiply with u(d) * bessel() 
        # Since u(d) will be of same shape as d, u(d).unsqueeze(-1) to allow dot product
        rbf = self.envelope(dist).unsqueeze(-1) * rbf
        
        # compute over all spherical (n) functions and stack over feature dimension (1)
        sbf = torch.stack([f(angle) for f in self.sph_funcs], dim=1)
        
        n, k = self.num_spherical, self.num_radial
        # multiply the two to get a_SBF = u(d) * radial_bessel() * spherical_harmonics()
        out = (rbf[idx_kj].view(-1, n, k) * sbf.view(-1, n, 1)).view(-1, n * k)
        
        return out

## Architecture of DimeNet

<center>
<img src='imgs/dimenet-arch.png' width='600'/>
</center>

The architecture consists of the following,
- RBF and SBF to transform our interatomic distances and angles
- Embedding Blocks to Initialize initial messages embeddings
- (Multiple) Interaction Blocks to perform Message Passing
- Output Block to convert embeddings to generate a prediction



<b> High-Level Overview of Network Architecture</b><br>
Let's take a look at the high-level structure of the architecture and informally write down the `forward()` function. The architecture is as follows:
<center>
<img src="imgs/model-only.png" width="200"/>
</center>

* Given the number of nodes `num_nodes` and edge_index `edge_index`, get all triplets of form `ijk` and edge indices of `ji` and `kj`
* From each triplet, we can get the distance of the `ij` pair and the angle `ijk` with the help of the coordinates,

<center>
<img src="imgs/angle-distance.png" width="600"/>
</center>

* Convert distances and angles into a tensor
* Transform the inter-atomic distance using the radial basis function `self.rbf(d)`
* Transform the inter-atomic distance and angle using the spherical basis function `self.sbf(d, angle)`
* The outputs are constructed from `1 Embedding Block` and `6 Interaction Blocks` in a sequential manner, where:
    * The first output prediction comes from the `embedding + output` block using initial input embeddings
    * All other output predictions come from the `interaction + output` block using embeddings from the previous layer
* All the outputs are summed up to provide the final scalar prediction per molecule




### Embedding Block: Getting Message Embedding

Let's take a closer look at the Embedding Block.

<center>
<img src='imgs/embedding-block.png' width='400'/>
</center>

Here's what happens in this block:
1. Four inputs are passed to this Embedding Block (as seen in the `forward()` function)
    - The atomic Numbers of the atoms in the batch (`x`)
    - The distance between atoms `i` and `j`, transformed using the Radial Basis Function (RBF)
    - The atomic number of `i`
    - The atomic number of `j`
2. The RBF output is passed through a `Linear(num_radial, hidden_channels)` layer to obtain distance embeddings.
3. Atomic numbers are transformed into learnable embeddings using `nn.Embedding(95, hidden_channels)`. Here, 95 is the maximum atomic number that can be expected.
3. The embeddings for atoms `i` and `j`, and the distance embeddings are concatenated along the `dim=1` dimension. This results in a shape `3 x hidden_channels`.
4. A `Linear(3 * hidden_channels, hidden_channels)` layer is applied to the concatenated representation, resulting in the final edge embedding between atoms `i` and `j`.
5. The final message, along with the RBF input, is passed to the `Output Block` to obtain a scalar prediction for each atom $t_i^{(1)}$.

In [6]:
class EmbeddingBlock(torch.nn.Module):
    '''Implementation of Embedding Block
    
    Parameters
    ----------
    num_radial: int
        Number of radial features (feature dimensions of rbf output)
    hidden_channels: int
        feature dimension of output of this embedding block
    
    '''
    def __init__(self, num_radial: int, hidden_channels: int, act: Callable):
        super().__init__()
        self.act = act
        
        self.emb = Embedding(95, hidden_channels)
        self.lin_rbf = Linear(num_radial, hidden_channels)
        self.lin = Linear(3 * hidden_channels, hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        self.emb.weight.data.uniform_(-sqrt(3), sqrt(3))
        self.lin_rbf.reset_parameters()
        self.lin.reset_parameters()

    def forward(self, x: Tensor, rbf: Tensor, i: Tensor, j: Tensor) -> Tensor:
        x = self.emb(x)
        rbf = self.act(self.lin_rbf(rbf))
        return self.act(self.lin(torch.cat([x[i], x[j], rbf], dim=-1)))

### Residual Layer

The Interaction Block makes use of residual layers, inspired by ResNets, to improve the flow of gradients. The architecture of a residual layer is as follows:

<center>
<img src='imgs/residual.png' width='200'/>
</center>

The implementation is as follows,
- 2 `nn.Linear` layers are used, which do not change the input feature dimension
- A non-linear activation function is applied between the two linear layers, and the initial input (`x`) is added to the output `act(lin_2(act(lin_1(x))))` to produce the final prediction.

In [8]:
class ResidualLayer(torch.nn.Module):
    def __init__(self, hidden_channels: int, act: Callable):
        super().__init__()
        self.act = act
        self.lin1 = Linear(hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        glorot_orthogonal(self.lin1.weight, scale=2.0)
        self.lin1.bias.data.fill_(0)
        glorot_orthogonal(self.lin2.weight, scale=2.0)
        self.lin2.bias.data.fill_(0)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.act(self.lin2(self.act(self.lin1(x))))

### Interaction Block: Message Passing

This block implements the main message passing equation,

$$
m_{ji}^{(l + 1)} = f_{\text{update}}(m_{ji}^{(l)}, \sum_{k in \mathcal{N}_j \backslash \{i\}}f_{\text{int}}(m_{kj}^{(l)}), e_{\text{RBF}}^{(ji)}, a_{\text{SBF}}^{(kj, ji)})
$$

<center>
    <img src='imgs/interaction-block.png' width='500'/>
</center>


Some points over the implementation of the interaction block,
1. The inputs to this block are,
    - Input embedding features `x` (from the embedding/interaction block before)
    - Radial Basis Function input (input feature dimension: `num_radial`)
    - Spherical Basis Function embeddings (input feature dimension: `num_spherical * num_radial`)
    - `idx_kj`: All indices from atom `k` to atom `j`
    - `idx_ji`: All indices from atom `j` to atom `i`
2. Apply `Linear(num_radial, hidden_channels, bias=False)` to $e_{\text{RBF}}$
3. $a_{\text{SBF}}$ is transformed into a $N_\text{bilinear}$-representation using `Linear(num_radial * num_spherical, num_bilinear, bias=False)`
3. To compute the message to be calculated,
    - Transform the messages $x_{ij}$ and $x_{kj}$ using `activation(linear(hidden_channels, hidden_channels))`
    - Dot product between `linear(rbf)` and above transformed $x_{kj}$
    - Bilinear product $a_{sbf}^T\cdot W \cdot x_{kj}$
    - Perform message passing using `torch_scatter.scatter` to get $\sum_k {m_{kj}}$
4. Post Message Passing, update our $m_{ji}^{(l)} = \text{Swish Activation}(\text{linear}(m_{ji}^{(l - 1)})) + \sum_k {m_{kj}}$ 
5. Run this through some `Residual Layer`
6. Pass the final updated message along with RBF input into the `Output Layer` to get the target scalar value for that interaction block $l$ and each atom $i$ as $t_i^{(l)}$

<b>Note</b>: [Swish](https://arxiv.org/abs/1710.05941v2) / [SiLU](https://arxiv.org/abs/1702.03118) works much better than activation functions used in previous works.

In [7]:
class InteractionBlock(torch.nn.Module):
    '''Interaction Block in DimeNet: Reponsible for Message Passing'''
    def __init__(self, hidden_channels: int, num_bilinear: int,
                 num_spherical: int, num_radial: int, num_before_skip: int,
                 num_after_skip: int, act: Callable):
        '''
        Initialize Interaction Module
        
        Args:
            hidden_channels (int): Hidden embedding size.
            num_bilinear (int): Size of the bilinear layer tensor.
            num_spherical (int): Number of spherical harmonics.
            num_radial (int): Number of radial basis functions.
            num_before_skip (int, optional): Number of residual layers in the
                interaction blocks before the skip connection. (default: :obj:`1`)
            num_after_skip (int, optional): Number of residual layers in the
                interaction blocks after the skip connection. (default: :obj:`2`)
            act (str or Callable, optional): The activation function.
                (default: :obj:`"swish"`)        
        '''
        super().__init__()
        self.act = act

        self.lin_rbf = Linear(num_radial, hidden_channels, bias=False)
        self.lin_sbf = Linear(num_spherical * num_radial, num_bilinear,
                              bias=False)

        # Dense transformations of input messages.
        self.lin_kj = Linear(hidden_channels, hidden_channels)
        self.lin_ji = Linear(hidden_channels, hidden_channels)

        self.W = torch.nn.Parameter(
            torch.Tensor(hidden_channels, num_bilinear, hidden_channels))

        self.layers_before_skip = torch.nn.ModuleList([
            ResidualLayer(hidden_channels, act) for _ in range(num_before_skip)
        ])
        self.lin = Linear(hidden_channels, hidden_channels)
        self.layers_after_skip = torch.nn.ModuleList([
            ResidualLayer(hidden_channels, act) for _ in range(num_after_skip)
        ])

        self.reset_parameters()

    def reset_parameters(self):
        '''Initializing Parameters'''
        glorot_orthogonal(self.lin_rbf.weight, scale=2.0)
        glorot_orthogonal(self.lin_sbf.weight, scale=2.0)
        glorot_orthogonal(self.lin_kj.weight, scale=2.0)
        self.lin_kj.bias.data.fill_(0)
        glorot_orthogonal(self.lin_ji.weight, scale=2.0)
        self.lin_ji.bias.data.fill_(0)
        self.W.data.normal_(mean=0, std=2 / self.W.size(0))
        for res_layer in self.layers_before_skip:
            res_layer.reset_parameters()
        glorot_orthogonal(self.lin.weight, scale=2.0)
        self.lin.bias.data.fill_(0)
        for res_layer in self.layers_after_skip:
            res_layer.reset_parameters()

    def forward(self, x: Tensor, rbf: Tensor, sbf: Tensor, idx_kj: Tensor,
                idx_ji: Tensor) -> Tensor:
        # transform rbf and sbf input using their respective nn.Linear()
        rbf = self.lin_rbf(rbf)
        sbf = self.lin_sbf(sbf)
        
        # Transform the messages into activation(linear(message))
        x_ji = self.act(self.lin_ji(x))
        x_kj = self.act(self.lin_kj(x))
        x_kj = x_kj * rbf
        
        # bilinear product
        x_kj = torch.einsum('wj,wl,ijl->wi', sbf, x_kj[idx_kj], self.W)
        
        # message passing
        x_kj = scatter(x_kj, idx_ji, dim=0, dim_size=x.size(0))
        
        # update our message
        h = x_ji + x_kj
        
        # Apply residual layers
        for layer in self.layers_before_skip:
            h = layer(h)
        h = self.act(self.lin(h)) + x
        for layer in self.layers_after_skip:
            h = layer(h)

        return h

### Output Block

<center>
<img src='imgs/output-block.png' width='100'/>
</center>

The output block is applied after the embedding block and each interaction block. This is responsible for calculating the final scalar target $t_i^{(l)}$ for each atom ($i$) for that interaction block ($l$) or Embedding Block. The following steps happen:

* Take in the RBF input and the input messages from the previous block as input
* Calculate the dot product between the Linear Weight transformed RBF input and input messages. Using the torch_scatter.scatter() function is an efficient way to do this for each atom $i$.
* Forward pass through standard feedforward neural network layers

<i>It is important to note that the dot product between the transformed RBF output and input messages guarantees a twice differentiable output because of the envelope function.</i>

In [9]:
class OutputBlock(torch.nn.Module):
    def __init__(self, num_radial: int, hidden_channels: int,
                 out_channels: int, num_layers: int, act: Callable):
        super().__init__()
        self.act = act
        
        # Linear layer to convert rbf input
        self.lin_rbf = Linear(num_radial, hidden_channels, bias=False)
        
        # linear layers to convert output through num_layers
        self.lins = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.lins.append(Linear(hidden_channels, hidden_channels))
        
        # final linear layer to convert it into `out_channels` dim output
        self.lin = Linear(hidden_channels, out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        glorot_orthogonal(self.lin_rbf.weight, scale=2.0)
        for lin in self.lins:
            glorot_orthogonal(lin.weight, scale=2.0)
            lin.bias.data.fill_(0)
        self.lin.weight.data.fill_(0)

    def forward(self, x: Tensor, rbf: Tensor, i: Tensor,
                num_nodes: Optional[int] = None) -> Tensor:
        x = self.lin_rbf(rbf) * x

        # scatter-add to add messages from all m_{ji} to atom $i$
        x = scatter(x, i, dim=0, dim_size=num_nodes)
        
        # pass the x through multiple linear layers
        for lin in self.lins:
            x = self.act(lin(x))
        
        # pass it through the final output layer
        return self.lin(x)

### Final Prediction

The final prediction is calculated by summing up the outputs from the output blocks of all the interaction layers and the embedding layer.


$$
t = \sum_i \sum_l t_{i}^{(l)}
$$

### DimeNet: Putting it All Together


A quick summary of the forward pass steps,
* Given the 3D structure of the molecule, calculate the inter-atomic distance and angle for each atom triplet.
* Use the radial basis function `self.rbf(d)` and spherical basis function `self.sbf(d, angle)` to transform these distances and angles.
* Utilize the output from the embedding and interaction blocks, in a sequential manner, to construct the final prediction.
* Sum the outputs to get a final scalar prediction per molecule.

In [10]:
class DimeNet(torch.nn.Module):
    r"""The directional message passing neural network (DimeNet) from the
    `"Directional Message Passing for Molecular Graphs"
    <https://arxiv.org/abs/2003.03123>`_ paper.
    DimeNet transforms messages based on the angle between them in a
    rotation-equivariant fashion.

    Args:
        hidden_channels (int): Hidden embedding size.
        out_channels (int): Size of each output sample.
        num_blocks (int): Number of building blocks.
        num_bilinear (int): Size of the bilinear layer tensor.
        num_spherical (int): Number of spherical harmonics.
        num_radial (int): Number of radial basis functions.
        cutoff (float, optional): Cutoff distance for interatomic
            interactions. (default: :obj:`5.0`)
        max_num_neighbors (int, optional): The maximum number of neighbors to
            collect for each node within the :attr:`cutoff` distance.
            (default: :obj:`32`)
        envelope_exponent (int, optional): Shape of the smooth cutoff.
            (default: :obj:`5`)
        num_before_skip (int, optional): Number of residual layers in the
            interaction blocks before the skip connection. (default: :obj:`1`)
        num_after_skip (int, optional): Number of residual layers in the
            interaction blocks after the skip connection. (default: :obj:`2`)
        num_output_layers (int, optional): Number of linear layers for the
            output blocks. (default: :obj:`3`)
        act (str or Callable, optional): The activation function.
            (default: :obj:`"swish"`)
    """

    def __init__(
        self,
        hidden_channels: int,
        out_channels: int,
        num_blocks: int,
        num_bilinear: int,
        num_spherical: int,
        num_radial,
        cutoff: float = 5.0,
        max_num_neighbors: int = 32,
        envelope_exponent: int = 5,
        num_before_skip: int = 1,
        num_after_skip: int = 2,
        num_output_layers: int = 3,
        act: Union[str, Callable] = 'swish',
    ):
        super().__init__()

        if num_spherical < 2:
            raise ValueError("num_spherical should be greater than 1")
        
        # get activation function for the `act` string passing in `__init__`
        act = activation_resolver(act)

        # cutoff value $c$ below which we consider atoms to be neighbours
        self.cutoff = cutoff
        # If neighbours exceed this we consider top 
        # max_num_neighbours based on inter-atomic distance
        self.max_num_neighbors = max_num_neighbors
        
        # number of interaction blocks/message passing blocks
        self.num_blocks = num_blocks

        # Our radial basis layer for inter-atomic distance
        self.rbf = RadialBasisLayer(num_radial, cutoff, envelope_exponent)

        # Spherical Basis Layer for 2D joint representation
        # using distance and angle
        self.sbf = SphericalBasisLayer(num_spherical, num_radial, cutoff,
                                       envelope_exponent)
        
        # embedding block
        self.emb = EmbeddingBlock(num_radial, hidden_channels, act)
        
        # embedding blocks
        self.output_blocks = torch.nn.ModuleList([
            OutputBlock(num_radial, hidden_channels, out_channels,
                        num_output_layers, act) for _ in range(num_blocks + 1)
        ])

        self.interaction_blocks = torch.nn.ModuleList([
            InteractionBlock(hidden_channels, num_bilinear, num_spherical,
                             num_radial, num_before_skip, num_after_skip, act)
            for _ in range(num_blocks)
        ])

        self.reset_parameters()

    def reset_parameters(self):
        self.rbf.reset_parameters()
        self.emb.reset_parameters()
        for out in self.output_blocks:
            out.reset_parameters()
        for interaction in self.interaction_blocks:
            interaction.reset_parameters()

    def triplets(
        self,
        edge_index: Tensor,
        num_nodes: int,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        '''Get kji triplets for directional message passing'''
        row, col = edge_index  # j->i
        
        value = torch.arange(row.size(0), device=row.device)
        
        # get sparse adjacency matrix
        adj_t = SparseTensor(row=col, col=row, value=value,
                             sparse_sizes=(num_nodes, num_nodes))
        adj_t_row = adj_t[row]
        num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)

        # Node indices (k->j->i) for triplets.
        idx_i = col.repeat_interleave(num_triplets)
        idx_j = row.repeat_interleave(num_triplets)
        idx_k = adj_t_row.storage.col()
        mask = idx_i != idx_k  # Remove i == k triplets.
        idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask]

        # Edge indices (k-j, j->i) for triplets.
        idx_kj = adj_t_row.storage.value()[mask]
        idx_ji = adj_t_row.storage.row()[mask]

        return col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji            
            
    def forward(
        self,
        z: Tensor,
        pos: Tensor,
        batch: OptTensor = None,
    ) -> Tensor:
        """"""
        # construct edges based on the cutoff decided
        # max neighbours should still be under self.max_num_neighbours (Hyperparameter)
        edge_index = radius_graph(
            pos, r=self.cutoff, batch=batch, max_num_neighbors=self.max_num_neighbors
        )
        
        # get list of triplets
        i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets(
            edge_index, num_nodes=z.size(0))

        # Calculate L2 distances. 
        dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

        # Calculate angles
        # first compute the direction j -> i and k -> j
        pos_ji, pos_ki = pos[idx_j] - pos[idx_i], pos[idx_k] - pos[idx_j]
        # dot product (|x||y|cos\theta)
        a = (pos_ji * pos_ki).sum(dim=-1)
        # cross product (|x||y|sin\theta)
        b = torch.cross(pos_ji, pos_ki).norm(dim=-1)
        # computes tan inverse of b / a or 
        angle = torch.atan2(b, a)

        rbf = self.rbf(dist)
        sbf = self.sbf(dist, angle, idx_kj)

        # Embedding block and it's corresponding output
        x = self.emb(z, rbf, i, j)
        P = self.output_blocks[0](x, rbf, i, num_nodes=pos.size(0))

        # Message Passing Interaction blocks
        for interaction_block, output_block in zip(self.interaction_blocks,
                                                   self.output_blocks[1:]):
            x = interaction_block(x, rbf, sbf, idx_kj, idx_ji)
            P = P + output_block(x, rbf, i, num_nodes=pos.size(0))

        return P.sum(dim=0) if batch is None else scatter(P, batch, dim=0)

## Run on QM9 Dataset

A quick overview of the training process with the QM9 dataset.

1. Initialize the model with hyperparameters based on the paper, and load the dataset.
    - Note that we are using `out_channels = 1` since we are training on a single target.
2. Set the dataset.data.y to the chosen target. We use `target = 0` for the property $\mu$.
3. Train the model using the usual Pytorch method, and track the mean absolute error for the train and test datasets in each epoch for 10 epochs.

In [11]:
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from tqdm import tqdm

# initialize dataset
dataset = QM9('.')
# initialize model
model = DimeNet(
    hidden_channels=128,
    out_channels=1,
    num_blocks=6,
    num_bilinear=8,
    num_spherical=7,
    num_radial=6,
    cutoff=5.0,
    envelope_exponent=5,
    num_before_skip=1,
    num_after_skip=2,
    num_output_layers=3,
)

# we use the 0th target, for others refer to the original paper
target = 0
dataset.data.y = dataset.data.y[:, target]

# set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [12]:
# Use the same random seed as the official DimeNet` implementation.
random_state = np.random.RandomState(seed=42)
perm = torch.from_numpy(random_state.permutation(np.arange(130831)))    
train_idx = perm[:110000]
val_idx = perm[110000:120000]
test_idx = perm[120000:]
train_dataset, val_dataset, test_dataset = (dataset[train_idx], dataset[val_idx], dataset[test_idx])

In [13]:
model = model.to(device)
train_loader = DataLoader(train_dataset, batch_size=64)
test_loader = DataLoader(test_dataset, batch_size=64)
loss_fn = torch.nn.L1Loss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
epochs = 2

for epoch in range(epochs):
    # store mae in each epoch
    epoch_losses = []
    epoch_maes = []
    for _, data in enumerate(tqdm(train_loader)):
        data = data.to(device)
        y_true = data.y.unsqueeze(-1)
        
        # loss 
        y_pred = model(data.z, data.pos, data.batch)
        loss = loss_fn(y_pred, y_true)
        
        # optimization step
        optimizer.zero_grad()
        loss.backward()
        # to prevent exploding gradients, added gradient clipping
        torch.nn.utils.clip_grad_norm_(
            parameters=model.parameters(), max_norm=10, norm_type=2.0
        )
        optimizer.step()
        
        # update loss for batch
        epoch_losses.append(loss.detach().cpu().item())
        
        # compute mae
        epoch_maes.append(
            (y_true.squeeze() - y_pred.squeeze()).mean().abs().item()
        )
    
    # test on test dataloader
    test_epoch_maes = []
    for _, data in enumerate(tqdm(test_loader)):
        data = data.to(device)
        y_true = data.y        
        # run without grad
        with torch.no_grad():
            y_test_pred = model(data.z, data.pos, data.batch)
        
        # compute mae
        test_epoch_maes.append(
            (y_true.squeeze() - y_test_pred.squeeze()).mean().abs().item()
        )
    
    # compute
    print (
        f'For training epoch {epoch}: '
        f'mean train loss is {np.mean(epoch_losses)}, '
        f'mean train mae is {np.mean(epoch_maes)}, '
        f'mean test mae is {np.mean(test_epoch_maes)}'
    )

100%|██████████| 1719/1719 [06:34<00:00,  4.35it/s]
100%|██████████| 170/170 [00:13<00:00, 12.38it/s]


For training epoch 0: mean train loss is 5.1584482583249285, mean train mae is 4.963539379876142, mean test mae is 0.15622362069347326


100%|██████████| 1719/1719 [06:21<00:00,  4.51it/s]
100%|██████████| 170/170 [00:12<00:00, 13.30it/s]

For training epoch 1: mean train loss is 0.25099485188531906, mean train mae is 0.14398329698647536, mean test mae is 0.08410120517672862





## References

This blog is inspired from the following sources,
* <a href="https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/models/dimenet.py">Pytorch Geometric Model</a>
* The original <a href="https://arxiv.org/abs/2003.03123">Directional Message Passing for Molecular Graphs</a> paper


## Acknowledgement

I would like to extend my sincerest gratitude to [Johannes Gasteigger](https://twitter.com/gasteigerjo) for reviewing everything patiently and providing timely feedback. Thank you for your time and support.