# Part 2: GNNs for Molecular Applications in Geometric Deep Learning



## Summary
The second part of the course focuses on the application of GNNs to molecular data. We will explore different geometric deep learning approaches specifically designed for handling molecules, from small organic compounds to large biomolecules.


## Plan

### 1. Introduction to Molecular Geometric Data
- **Definition of Molecular Geometric Data:** Understand how molecular structures are represented as geometric data.
- **Challenges in Molecular Data Processing:** Discuss challenges like irregularity, invariance to rotation/translation, and varying molecular size.

### 2. Invariant Networks
- **Concept of Invariance:** Learn the importance of invariant properties for molecular data.
- **Key Architectures:** Overview of models like SchNet and Invariant Point Attention.
- **Applications:** Use cases in molecular property prediction and drug discovery.

### 3. Cartesian Networks
- **Cartesian Coordinates in Molecular GNNs:** How GNNs use Cartesian data to model atomic interactions.
- **Key Architectures:** Examples like GVP-GNN and E(n)-GNN.
- **Applications:** Predicting molecular properties and dynamics.

### 4. Spherical Networks
- **Introduction to Spherical Molecular Data:** Understand how data on spherical domains is used in molecular modeling.
- **Spherical GNNs:** Introduction to architectures like Tensor Field Networks.
- **Applications:** Modeling molecular conformations and protein structures.




#### References
- Based on: ["Survey of Geometric GNNs for 3D Atomic Systems"](https://arxiv.org/pdf/2312.07511)

![Geometric GNNs for 3D atomic system](https://miro.medium.com/v2/resize:fit:1400/1*AYsGjZhbdr701OndCvnfng.png)



## 1. Introduction to Molecular Geometric Data

### Definition of Molecular Geometric Data
Molecular structures are typically represented as point clouds where each point corresponds to an atom in 3D space. The geometric properties, such as atomic positions and bond lengths, are crucial for understanding molecular behavior and properties.

### Challenges in Molecular Data Processing
Processing molecular data involves several challenges:
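- **Irregularity:** Atoms do not lie on a regular grid. A molecule is an unordered set of points with a variable number of neighbors per atom, so grid-based operations such as standard convolutions do not apply directly.
- **Invariance to Rotation/Translation:** Scalar molecular properties, such as energy, should not change when the whole molecule is rotated or translated.
- **Varying Molecular Size:** Models must handle both small organic molecules and large proteins.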
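To make the invariance requirement concrete, the sketch below applies a random rotation and translation to a toy set of coordinates (random numbers, not a real molecule) and compares raw coordinates with pairwise distances. The helper `random_rotation` is a hypothetical name introduced for this illustration.

```python
import torch

def random_rotation(generator=None):
    """Sample a random 3x3 rotation matrix (det = +1) via a QR decomposition."""
    A = torch.randn(3, 3, generator=generator, dtype=torch.float64)
    Q, R = torch.linalg.qr(A)
    Q = Q * torch.sign(torch.diagonal(R))   # make the decomposition unique
    if torch.det(Q) < 0:                    # flip one axis to turn a reflection into a rotation
        Q[:, 0] = -Q[:, 0]
    return Q

g = torch.Generator().manual_seed(0)
pos = torch.randn(5, 3, generator=g, dtype=torch.float64)   # toy coordinates for 5 "atoms"
R = random_rotation(g)
t = torch.randn(3, generator=g, dtype=torch.float64)
pos_moved = pos @ R.T + t                                     # rotate every atom, then translate

D = torch.cdist(pos, pos)                                     # [5, 5] pairwise distance matrix
D_moved = torch.cdist(pos_moved, pos_moved)

print(torch.allclose(pos, pos_moved))   # raw coordinates change
print(torch.allclose(D, D_moved))       # pairwise distances do not
```

Features built only from distances (as in the invariant networks of Section 2) inherit this invariance automatically.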

### Applications in Molecular Data
Molecular data is used in several key applications:
- **Molecular Property Prediction:** Estimating properties like binding affinity, reactivity, or toxicity.
- **Drug Discovery:** Identifying potential drug candidates by modeling molecule interactions.
- **Protein Structure Prediction:** Determining 3D protein structures from amino acid sequences.

### Molecules as Graphs
Molecules can be represented as graphs where atoms are nodes and bonds are edges. This graph representation lets GNNs learn relationships between atoms from their types and connectivity, and use them to predict molecular properties.
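Bonds are one way to define edges; many 3D molecular GNNs instead connect every pair of atoms closer than a cutoff radius, which also captures non-bonded neighbors. The sketch below builds such a radius graph in the COO `edge_index` format used by PyTorch Geometric, using plain PyTorch. The function name `radius_edge_index`, the toy coordinates, and the cutoff value are illustrative assumptions.

```python
import torch

def radius_edge_index(pos: torch.Tensor, cutoff: float) -> torch.Tensor:
    """Return a [2, num_edges] edge_index linking all pairs i != j with ||pos_i - pos_j|| < cutoff."""
    dist = torch.cdist(pos, pos)                                            # [N, N] pairwise distances
    no_self = ~torch.eye(len(pos), dtype=torch.bool, device=pos.device)     # drop self-loops
    src, dst = ((dist < cutoff) & no_self).nonzero(as_tuple=True)
    return torch.stack([src, dst], dim=0)

# Toy chain of three atoms spaced 1.2 apart along x (illustrative, not a real geometry)
pos = torch.tensor([[0.0, 0.0, 0.0], [1.2, 0.0, 0.0], [2.4, 0.0, 0.0]])
edge_index = radius_edge_index(pos, cutoff=1.5)
print(edge_index)   # atoms 0 and 2 are 2.4 apart, so they are not connected
```

Each undirected neighbor pair appears twice (once per direction), which is the convention message-passing layers expect.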


## 2. Invariant Networks

### Concept of Invariance
Invariant networks in geometric deep learning ensure that molecular representations remain unchanged under transformations like rotations and translations. This invariance is crucial for molecular data, where the orientation or position of the molecule should not affect the prediction.

### Key Architectures: SchNet and Invariant Point Attention

#### SchNet
**SchNet** is a neural network for predicting molecular properties from 3D molecular geometries. It uses continuous-filter convolutional layers on atomistic point clouds, where each filter is generated by a small neural network from the interatomic distance. Because the filters depend only on distances, SchNet's predictions are invariant to rotations and translations.

![SchNet Architecture](https://d3i71xaburhd42.cloudfront.net/5bf31dc4bd54b623008c13f8bc8954dc7c9a2d80/4-Figure2-1.png)

**Paper Link:** [SchNet: A Continuous-filter Convolutional Neural Network for Modeling Quantum Interactions](https://arxiv.org/abs/1706.08566)
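The continuous-filter convolution can be sketched in plain PyTorch as $x_i' = \sum_j x_j \odot W(\lVert r_i - r_j \rVert)$, where a filter-generating network maps a Gaussian RBF expansion of each distance to one weight per channel. This is a minimal sketch, not the full SchNet: the class names, the sizes (`num_rbf=20`, `cutoff=5.0`, `gamma=10.0`), and the two-layer filter network are illustrative assumptions, and SchNet's atom-wise layers and interaction blocks are omitted.

```python
import math
import torch
import torch.nn as nn

class ShiftedSoftplus(nn.Module):
    """ssp(x) = softplus(x) - log(2), the activation used in SchNet."""
    def forward(self, x):
        return nn.functional.softplus(x) - math.log(2.0)

class CFConv(nn.Module):
    """Simplified continuous-filter convolution: x_i' = sum_j x_j * W(||r_i - r_j||)."""
    def __init__(self, channels, num_rbf=20, cutoff=5.0, gamma=10.0):
        super().__init__()
        self.register_buffer("centers", torch.linspace(0.0, cutoff, num_rbf))
        self.gamma = gamma
        # Filter-generating network: RBF-expanded distance -> one filter value per channel
        self.filter_net = nn.Sequential(
            nn.Linear(num_rbf, channels), ShiftedSoftplus(), nn.Linear(channels, channels))

    def forward(self, x, pos, edge_index):
        src, dst = edge_index                                        # messages flow src -> dst
        dist = (pos[src] - pos[dst]).norm(dim=-1, keepdim=True)      # [E, 1], invariant
        rbf = torch.exp(-self.gamma * (dist - self.centers) ** 2)    # [E, num_rbf]
        W = self.filter_net(rbf)                                     # [E, channels]
        out = torch.zeros_like(x)
        out.index_add_(0, dst, x[src] * W)                           # sum filtered neighbor features
        return out

torch.manual_seed(0)
x = torch.randn(4, 8)                                                # 4 atoms, 8 channels
pos = torch.randn(4, 3)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
conv = CFConv(channels=8)
out = conv(x, pos, edge_index)
print(out.shape)
```

Rotating or translating `pos` leaves `out` unchanged up to floating-point error, since only distances enter the filters.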

#### Invariant Point Attention
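**Invariant Point Attention (IPA)** is the attention module of AlphaFold 2's structure module. Each residue carries a local reference frame (a rotation and a translation), and IPA computes attention weights from scalar features and from distances between query and key points mapped into global coordinates. Because these quantities do not change when the same rotation and translation is applied to every frame, the attention is invariant to global rigid motions, which makes it well suited to predicting 3D protein structure.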

![Invariant Point Attention](https://github.com/lucidrains/invariant-point-attention/blob/main/ipa.png?raw=true)

**Paper Link:** [Highly accurate protein structure prediction with AlphaFold](https://www.nature.com/articles/s41586-021-03819-2)
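The source of IPA's invariance can be isolated in a small numerical check. The sketch below implements only the point-distance term of the attention logit for a single head; the full module also combines scalar query-key terms, a bias from the pair representation, and learned per-head weights, all omitted here. The function name, shapes, and unweighted form are simplifying assumptions, not AlphaFold's implementation.

```python
import torch

def ipa_point_logits(R, t, q_local, k_local):
    """Point term of an IPA-style logit for one head (simplified sketch).

    R: [N, 3, 3] per-residue rotations, t: [N, 3] per-residue translations,
    q_local, k_local: [N, P, 3] query/key points in each residue's local frame.
    Returns [N, N] entries -0.5 * sum_p ||T_i(q_ip) - T_j(k_jp)||^2.
    """
    q_global = torch.einsum('nab,npb->npa', R, q_local) + t[:, None, :]   # apply frame T_i
    k_global = torch.einsum('nab,npb->npa', R, k_local) + t[:, None, :]   # apply frame T_j
    diff = q_global[:, None] - k_global[None, :]                           # [N, N, P, 3]
    return -0.5 * (diff ** 2).sum(dim=(-1, -2))

torch.manual_seed(0)
N, P = 5, 4
R = torch.linalg.qr(torch.randn(N, 3, 3, dtype=torch.float64))[0]          # orthogonal frames
t = torch.randn(N, 3, dtype=torch.float64)
q = torch.randn(N, P, 3, dtype=torch.float64)
k = torch.randn(N, P, 3, dtype=torch.float64)

# Compose every frame with the same global rigid motion: x -> R_g (R_i x + t_i) + t_g
R_g = torch.linalg.qr(torch.randn(3, 3, dtype=torch.float64))[0]
t_g = torch.randn(3, dtype=torch.float64)
logits = ipa_point_logits(R, t, q, k)
logits_moved = ipa_point_logits(R_g @ R, t @ R_g.T + t_g, q, k)
print(torch.allclose(logits, logits_moved))
```

Because the local points are fixed and only the frames move, every global point shifts by the same rigid motion, so all pairwise distances and hence the logits are unchanged.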

### Applications
- **Molecular Property Prediction:** Estimating molecular properties such as binding affinity and reactivity.
- **Drug Discovery:** Identifying new drug candidates by modeling molecular interactions.

### Implementing an Invariant Network with `MessagePassing`

Below is a simple example of an invariant network using the `MessagePassing` class from PyTorch Geometric:


```python
import torch
from torch_geometric.nn import MessagePassing
from torch.nn import Linear
import torch.nn.functional as F

class InvariantMPNN(MessagePassing):
    def __init__(self, in_channels, out_channels, num_rbf=16):
        super().__init__(aggr='add')
        self.lin = Linear(in_channels, out_channels)
        self.dist_lin = Linear(num_rbf, out_channels)  # Linear transformation for RBF-transformed distances
        # Buffers follow the module across .to(device) calls but are not trained
        self.register_buffer('rbf_centers', torch.linspace(0, 5, num_rbf))  # Radial basis function centers
        self.register_buffer('rbf_gamma', torch.tensor(1.0))               # Gamma (width) parameter for RBF

    def forward(self, x, pos, edge_index):
        # x: node features [N, in_channels]
        # pos: node coordinates [N, 3]
        # edge_index: edge indices [2, E]

        # Distances between connected nodes: invariant to rotations and translations
        row, col = edge_index
        edge_vectors = pos[row] - pos[col]
        distances = torch.norm(edge_vectors, p=2, dim=-1).unsqueeze(-1)  # [E, 1]

        # Gaussian RBF expansion: [E, 1] - [num_rbf] broadcasts to [E, num_rbf]
        rbf = torch.exp(-self.rbf_gamma * (distances - self.rbf_centers) ** 2)

        # Propagate messages
        return self.propagate(edge_index, x=x, rbf=rbf)

    def message(self, x_j, rbf):
        # x_j: source node features
        # rbf: RBF-transformed distance features
        edge_features = self.dist_lin(rbf)      # Transform RBF features
        return self.lin(x_j) + edge_features    # Combine node and edge features

    def update(self, aggr_out):
        # aggr_out: aggregated messages
        return F.relu(aggr_out)  # Apply ReLU non-linearity

# Example usage
node_features = torch.randn(4, 3)  # 4 nodes, 3 features per node
node_coords = torch.randn(4, 3)    # 4 nodes, 3D coordinates
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]], dtype=torch.long)  # Edges in COO format

model = InvariantMPNN(in_channels=3, out_channels=2)
output = model(node_features, node_coords, edge_index)

print("Output Node Features:", output)
```

#### Explanation of the Code:

* **Distance Calculation:** Computes Euclidean distances between connected nodes.
* **Radial Basis Function (RBF):** Transforms distances into edge features using RBF.
* **Message Passing:** Uses transformed edge features along with node features for message passing.
* **Update Function:** Applies ReLU activation to aggregated messages to update node features.

This implementation makes the MPNN invariant to rotations and translations by using distance-based features, making it well-suited for molecular applications such as property prediction and drug discovery.

## Equivariant Networks
![Invariance vs Equivariance](https://i.sstatic.net/BdU1v.png)

## 3. Cartesian Networks

### Cartesian Coordinates in Molecular GNNs

Cartesian networks use Cartesian coordinates to model atomic interactions in molecules. Each atom's position in 3D space is represented by its Cartesian coordinates $(x, y, z)$. These coordinates are used to compute distances and angles between atoms, which are fundamental for understanding molecular properties and dynamics.

The main idea behind Cartesian networks is to build vector outputs as weighted sums of vectors that are equivariant to rotations of the input data (e.g., vectors between atoms or hidden vector features), where the weights are learned functions of rotation-invariant quantities such as distances and scalar features. Since each weight is unchanged by a rotation while each vector rotates with the input, the weighted sum rotates with the input as well.
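This construction can be checked numerically. The sketch below computes $v_i = \sum_{j \neq i} w(h_i, h_j, \lVert x_i - x_j \rVert)\,(x_i - x_j)$ over all pairs, with an untrained MLP standing in for the learned weight function; the function name, layer sizes, and dense all-pairs graph are illustrative assumptions.

```python
import torch
import torch.nn as nn

def weighted_vector_sum(pos, h, weight_net):
    """v_i = sum_{j != i} w(h_i, h_j, ||x_i - x_j||) * (x_i - x_j), dense over all pairs."""
    n = pos.size(0)
    rel = pos[:, None, :] - pos[None, :, :]             # [N, N, 3] relative vectors x_i - x_j
    dist = rel.norm(dim=-1, keepdim=True)               # [N, N, 1] rotation-invariant
    pair = torch.cat([h[:, None, :].expand(n, n, -1),
                      h[None, :, :].expand(n, n, -1), dist], dim=-1)
    w = weight_net(pair)                                # [N, N, 1] invariant scalar weights
    return (w * rel).sum(dim=1)                         # j == i adds nothing since rel_ii = 0

torch.manual_seed(0)
weight_net = nn.Sequential(nn.Linear(2 * 4 + 1, 16), nn.SiLU(), nn.Linear(16, 1))  # stand-in for a learned net
pos, h = torch.randn(6, 3), torch.randn(6, 4)          # 6 nodes, 4 scalar features each
Q, _ = torch.linalg.qr(torch.randn(3, 3))              # random orthogonal matrix
t = torch.randn(3)

v = weighted_vector_sum(pos, h, weight_net)
v_moved = weighted_vector_sum(pos @ Q.T + t, h, weight_net)
print(torch.allclose(v_moved, v @ Q.T, atol=1e-5))     # output rotates with the input
```

Feeding raw coordinates (instead of distances) into `weight_net` would break this property, which is why Cartesian networks restrict the weights to invariant inputs.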

### Key Architectures: GVP-GNN and E(n)-GNN

#### 1. **Geometric Vector Perceptron (GVP-GNN)**
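The **GVP-GNN** is designed for molecular systems, especially protein structures, and integrates geometric information by carrying both scalar and vector features. Its Geometric Vector Perceptron layers only apply linear combinations across vector channels, norms, and norm-based scaling to vectors, so vector outputs rotate with the input while scalar outputs are rotation-invariant; building features from relative positions makes them translation-invariant as well.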

![GVP-GNN Architecture](https://raphael.tc.com/publication/gvp/featured_hu9b7018f8c7956abbafe4971f4d4d6c72_752084_720x0_resize_lanczos_2.png)

- **Paper Link:** [Learning from Protein Structure with Geometric Vector Perceptrons](https://arxiv.org/abs/2009.01411)
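A simplified Geometric Vector Perceptron can be sketched in PyTorch following the structure described in the GVP paper: a hidden linear map $W_h$ mixes vector channels, the norms of the hidden vectors feed the scalar branch, and a second map $W_\mu$ produces the vector outputs. The sigmoid gate on vector norms, the ReLU on scalars, and the class and argument names are assumptions for this sketch, not the published implementation.

```python
import torch
import torch.nn as nn

class SimpleGVP(nn.Module):
    """Simplified GVP. s: [N, s_in] scalar features, V: [N, v_in, 3] vector features."""
    def __init__(self, s_in, v_in, s_out, v_out):
        super().__init__()
        h = max(v_in, v_out)
        # No bias on vector maps: a bias vector would not rotate with the input
        self.W_h = nn.Linear(v_in, h, bias=False)
        self.W_mu = nn.Linear(h, v_out, bias=False)
        self.W_m = nn.Linear(s_in + h, s_out)

    def forward(self, s, V):
        # Linear maps act on the channel axis only, never on the x/y/z axis
        V_h = self.W_h(V.transpose(-1, -2)).transpose(-1, -2)       # [N, h, 3]
        V_mu = self.W_mu(V_h.transpose(-1, -2)).transpose(-1, -2)   # [N, v_out, 3]
        # Invariant vector norms are concatenated with the scalar features
        s_new = torch.relu(self.W_m(torch.cat([s, V_h.norm(dim=-1)], dim=-1)))  # [N, s_out]
        gate = torch.sigmoid(V_mu.norm(dim=-1, keepdim=True))       # [N, v_out, 1], invariant
        return s_new, gate * V_mu

torch.manual_seed(0)
gvp = SimpleGVP(s_in=4, v_in=2, s_out=5, v_out=3)
s, V = torch.randn(6, 4), torch.randn(6, 2, 3)
Q, _ = torch.linalg.qr(torch.randn(3, 3))
s1, V1 = gvp(s, V)
s2, V2 = gvp(s, V @ Q.T)                                            # rotate every input vector
print(torch.allclose(s1, s2, atol=1e-5), torch.allclose(V1 @ Q.T, V2, atol=1e-5))
```

The scalar outputs are unchanged by the rotation and the vector outputs rotate with it, which is the property that lets GVP layers be stacked inside a message-passing network.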

#### 2. **E(n)-Equivariant Graph Neural Networks (E(n)-GNN)**
**E(n)-GNN** extends GNNs to be equivariant under Euclidean transformations of $n$-dimensional space (translations, rotations, and reflections). Node features are updated only from invariant quantities (other node features and pairwise distances), while coordinates are updated along relative position vectors weighted by learned scalars, so the outputs transform consistently when the input is rotated, reflected, or translated.

- **Key Idea:** E(n)-GNN achieves equivariance using only relative positions and distances, without spherical harmonics or higher-order tensors, which keeps it simple and cheap to compute.

![E(n)-GNN Architecture](https://ehoogeboom.github.io/publication/egnn/featured_hua4419112e0b0f9c21e721be460820b18_120982_680x500_fill_q90_lanczos_center_2.png)

- **Paper Link:** [E(n) Equivariant Graph Neural Networks](https://arxiv.org/abs/2102.09844)

### Applications

- **Predicting Molecular Properties:** GVP-GNN and E(n)-GNN are used for tasks such as predicting binding affinities, chemical reactivity, and electronic properties.
- **Modeling Molecular Dynamics:** These architectures help simulate molecular motions and interactions over time, providing insights into complex molecular behaviors like folding and binding. 

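### Mathematical Formulation of E(n)-GNN

Following the EGNN paper, each layer computes an edge message $m_{ij}$ from the two node features and the squared distance, then uses it to update the coordinates $x_i$ and node features $h_i$ of each node $i$:

1. **Edge Message:**
$$
m_{ij} = \phi_e\left(h_i, h_j, \|x_i - x_j\|^2\right)
$$

2. **Coordinate Update:**
$$
x_i' = x_i + C \sum_{j \in \mathcal{N}(i)} (x_i - x_j)\, \phi_x(m_{ij})
$$

3. **Node Update:**
$$
h_i' = \phi_h\Big(h_i, \sum_{j \in \mathcal{N}(i)} m_{ij}\Big)
$$

where:
- $h_i$ is the feature vector of node $i$ and $x_i$ its coordinates.
- $\mathcal{N}(i)$ is the set of neighbors of node $i$.
- $\phi_e$, $\phi_x$, $\phi_h$ are learnable functions (MLPs); $\phi_x$ outputs a scalar.
- $C$ is a normalization constant, e.g. $1/|\mathcal{N}(i)|$ (the paper uses $1/(M-1)$ for a fully connected graph of $M$ nodes).

Because $\phi_x(m_{ij})$ depends only on invariant quantities, each term $(x_i - x_j)\,\phi_x(m_{ij})$ rotates with the input and is unchanged by translations, so the coordinate update is E(n)-equivariant while $h_i'$ is invariant.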
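One E(n)-GNN update can be sketched for a small fully connected graph in plain PyTorch, following the EGNN paper's form: $m_{ij} = \phi_e(h_i, h_j, \lVert x_i - x_j \rVert^2)$, $x_i' = x_i + C \sum_j (x_i - x_j)\,\phi_x(m_{ij})$ with $C = 1/(N-1)$, and $h_i' = \phi_h(h_i, \sum_j m_{ij})$. The class name, MLP sizes, and SiLU activations are illustrative assumptions.

```python
import torch
import torch.nn as nn

class DenseEGNNLayer(nn.Module):
    """One E(n)-equivariant update over a fully connected graph (sketch)."""
    def __init__(self, h_dim, m_dim=16):
        super().__init__()
        self.phi_e = nn.Sequential(nn.Linear(2 * h_dim + 1, m_dim), nn.SiLU(),
                                   nn.Linear(m_dim, m_dim), nn.SiLU())
        self.phi_x = nn.Linear(m_dim, 1)                               # scalar weight per edge
        self.phi_h = nn.Sequential(nn.Linear(h_dim + m_dim, h_dim), nn.SiLU(), nn.Linear(h_dim, h_dim))

    def forward(self, h, x):
        n = h.size(0)
        rel = x[:, None, :] - x[None, :, :]                            # [N, N, 3], x_i - x_j
        d2 = (rel ** 2).sum(dim=-1, keepdim=True)                      # [N, N, 1] squared distances
        pair = torch.cat([h[:, None].expand(n, n, -1), h[None, :].expand(n, n, -1), d2], dim=-1)
        m = self.phi_e(pair) * (1.0 - torch.eye(n, dtype=h.dtype))[..., None]  # messages, zero for j == i
        x_new = x + (rel * self.phi_x(m)).sum(dim=1) / (n - 1)        # C = 1 / (N - 1)
        h_new = self.phi_h(torch.cat([h, m.sum(dim=1)], dim=-1))
        return h_new, x_new

torch.manual_seed(0)
layer = DenseEGNNLayer(h_dim=4)
h, x = torch.randn(5, 4), torch.randn(5, 3)
Q, _ = torch.linalg.qr(torch.randn(3, 3))                              # rotation or reflection
t = torch.randn(3)
h1, x1 = layer(h, x)
h2, x2 = layer(h, x @ Q.T + t)
print(torch.allclose(h1, h2, atol=1e-5), torch.allclose(x1 @ Q.T + t, x2, atol=1e-5))
```

The node features come out invariant and the coordinates transform exactly like the input, including under reflections, because only $x_i - x_j$ and squared distances are used.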

### Code Snippet: Implementing a Simple Cartesian Network

Here is a simple implementation of a Cartesian-based GNN layer using PyTorch Geometric:


```python
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import scatter

class CartesianGNNLayer(MessagePassing):
    """EGNN-style layer: invariant edge messages, equivariant coordinate updates."""
    def __init__(self, in_channels, out_channels):
        # The base class requires an aggr, but the custom aggregate() below overrides it
        super().__init__(aggr='add')
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * in_channels + 1, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels)
        )
        self.coord_mlp = nn.Sequential(
            nn.Linear(out_channels, 1),
            nn.Tanh()  # Bounds each per-edge coordinate weight to (-1, 1)
        )
        self.node_mlp = nn.Sequential(
            nn.Linear(in_channels + out_channels, out_channels),
            nn.ReLU()
        )

    def forward(self, x, pos, edge_index):
        num_nodes = x.size(0)
        # message(), aggregate() and update() pass (features, coordinate update) tuples along
        x_out, coord_updates = self.propagate(edge_index, x=x, pos=pos, size=(num_nodes, num_nodes))
        # Equivariant position update: x_i' = x_i + mean_j (x_i - x_j) * phi_x(m_ij)
        pos = pos + coord_updates
        return x_out, pos

    def message(self, x_i, x_j, pos_i, pos_j):
        # Relative positional differences and distances
        diff = pos_i - pos_j                              # [num_edges, 3], x_i - x_j
        dist = torch.norm(diff, dim=-1, keepdim=True)     # [num_edges, 1], rotation-invariant

        # Invariant edge message from node features and distance
        edge_input = torch.cat([x_i, x_j, dist], dim=-1)  # [num_edges, 2 * in_channels + 1]
        e_ij = self.edge_mlp(edge_input)                  # [num_edges, out_channels]

        # Equivariant coordinate update: invariant scalar times relative vector
        coord_update = self.coord_mlp(e_ij) * diff        # [num_edges, 3]

        return e_ij, coord_update

    def aggregate(self, inputs, index, ptr=None, dim_size=None):
        e_ij, coord_update = inputs
        # Sum edge messages per target node
        aggr_e = scatter(e_ij, index, dim=0, dim_size=dim_size, reduce='sum')
        # Average coordinate updates per target node (C = 1 / |N(i)|)
        aggr_coord = scatter(coord_update, index, dim=0, dim_size=dim_size, reduce='mean')
        return aggr_e, aggr_coord

    def update(self, aggr_out, x):
        aggr_e, aggr_coord = aggr_out
        # h_i' = phi_h(h_i, sum_j m_ij)
        node_input = torch.cat([x, aggr_e], dim=-1)
        x_out = self.node_mlp(node_input)
        return x_out, aggr_coord

# Example usage
x = torch.randn(10, 3)    # 10 nodes, 3 features per node
pos = torch.randn(10, 3)  # 10 nodes, 3D positions
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]], dtype=torch.long)  # Edge index

layer = CartesianGNNLayer(in_channels=3, out_channels=2)
x_out, pos_out = layer(x, pos, edge_index)
print(x_out, pos_out)
```

## 4. Spherical Networks

### Higher-Degree Representations in Molecular GNNs

In Cartesian networks, we operate primarily on **degree-1 features** (vectors), which represent directional quantities such as relative positions or forces. However, molecular structures often exhibit geometric relationships that vectors alone do not capture efficiently, for example quantities described by rank-2 tensors such as a quadrupole moment. **Spherical networks** address this limitation by using representations that carry **higher-degree features**. These features capture finer angular patterns and symmetries that are relevant for modeling molecular interactions.

### Separating Radial and Angular Components

To achieve precise representations of molecular data, spherical networks consider functions defined in 3D space with separate **radial** and **angular** components. The **radial part** requires high precision to accurately represent distances between atoms, while the **angular part** captures the orientations and directions of atomic interactions. By treating these components separately, spherical networks can model complex molecular interactions more effectively.

### Spherical Harmonics as Fourier Basis on a Sphere

**Spherical harmonics** serve as the Fourier basis functions on a sphere and are fundamental in representing the angular component of functions in 3D space. They allow us to decompose complex angular dependencies into a series of functions with varying degrees $ l $, each corresponding to different levels of detail and symmetry properties.

- **Degrees $ l $**: Represent different levels of angular frequency, with higher degrees capturing more complex angular variations.
- **Orders $ m $**: Range from $-l$ to $+l$ and represent different orientations for a given degree.
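The degree and order structure can be checked with explicit real spherical harmonics for $l \le 2$. The sketch below assumes one common real-basis convention (orders $m = -l, \ldots, l$, no Condon-Shortley phase), which may differ in sign from the implementation later in this notebook; the identity it checks, $\sum_{m} Y_l^m(\hat{r})^2 = (2l+1)/(4\pi)$ for every direction (a consequence of the addition theorem), holds regardless of those sign choices. The function name is hypothetical.

```python
import math
import torch

def real_sh_upto_l2(u):
    """Real spherical harmonics of degree 0, 1, 2 at unit vectors u [E, 3]; returns [E, 1], [E, 3], [E, 5]."""
    x, y, z = u.unbind(-1)
    c1 = math.sqrt(3 / (4 * math.pi))
    c2 = math.sqrt(15 / (4 * math.pi))
    Y0 = torch.full_like(x, 0.5 / math.sqrt(math.pi))[:, None]   # 1 order:  m = 0
    Y1 = c1 * torch.stack([y, z, x], dim=-1)                     # 3 orders: m = -1, 0, 1
    Y2 = torch.stack([c2 * x * y,
                      c2 * y * z,
                      math.sqrt(5 / (16 * math.pi)) * (3 * z ** 2 - 1),
                      c2 * x * z,
                      0.5 * c2 * (x ** 2 - y ** 2)], dim=-1)     # 5 orders: m = -2, ..., 2
    return [Y0, Y1, Y2]

torch.manual_seed(0)
u = torch.nn.functional.normalize(torch.randn(100, 3, dtype=torch.float64), dim=-1)  # random directions
for l, Y in enumerate(real_sh_upto_l2(u)):
    total = (Y ** 2).sum(dim=-1)                                 # sum over the 2l + 1 orders
    print(l, Y.shape[1], torch.allclose(total, torch.full_like(total, (2 * l + 1) / (4 * math.pi))))
```

Each degree $l$ has $2l + 1$ orders, and the sum of squares within a degree does not depend on the direction, which is why the norm of each degree-$l$ block is rotation-invariant.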

### Irreducible Representations (Irreps) of Different Degrees

1. **Degree-0 Irreps (Scalars)**: Invariant under rotation, representing features that do not change with orientation.
2. **Degree-1 Irreps (Vectors)**: Rotate together with the input, representing directional features.
3. **Higher-Degree Irreps**: Degrees $ l = 2, 3, \ldots $ correspond to more complex features; for example, a degree-2 irrep has 5 components and corresponds to a symmetric traceless $3 \times 3$ tensor. These capture higher-order geometric information and transform predictably under rotations.
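When features of several degrees are stored in one tensor, a common layout (and the one the `tensor_product` helper later in this notebook assumes when it slices `(l + 1) ** 2` components) concatenates the $2l + 1$ components of each degree in order $l = 0, 1, 2, \ldots$. The hypothetical helper below recovers $(l, m)$ from a flat index under that layout.

```python
import math

def index_to_degree_order(i):
    """Map a flat index i to (l, m) when degrees l = 0, 1, 2, ... are stacked with orders m = -l, ..., l."""
    l = math.isqrt(i)          # degree l occupies indices l**2 .. (l + 1)**2 - 1
    return l, i - l * l - l    # offset within the block, shifted so m runs from -l to l

layout = [index_to_degree_order(i) for i in range(9)]  # degrees 0, 1, 2 -> 1 + 3 + 5 = 9 components
print(layout)
# [(0, 0), (1, -1), (1, 0), (1, 1), (2, -2), (2, -1), (2, 0), (2, 1), (2, 2)]
```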

### Message Passing in Spherical Networks

In spherical GNNs, the message-passing mechanism involves projecting features onto spherical harmonics and handling higher-degree interactions.

#### 1. Feature Functions

Each node $ j $ has a feature function $ f_j(\vec{r}) $ that can be expanded in terms of spherical harmonics of degree $ l_1 $:

$$
f_j(\vec{r}) = \sum_{l_1, m_1} f_j^{l_1 m_1} Y_{l_1}^{m_1}(\hat{r})
$$

- $ f_j^{l_1 m_1} $: Coefficients for the spherical harmonic components.
- $ Y_{l_1}^{m_1}(\hat{r}) $: Spherical harmonics evaluated at direction $ \hat{r} $.

#### 2. Delta Function Expansion

The delta function $ \delta(\vec{r}_{ij}) $, representing the positional difference between nodes $ i $ and $ j $, is expanded in spherical harmonics of degree $ l_2 $:

$$
\delta(\vec{r}_{ij}) = \delta(r_{ij}) \sum_{l_2, m_2} Y_{l_2}^{m_2}(\hat{r}_{ij})
$$

- $ \delta(r_{ij}) $: Radial part ensuring precision in distances.
- $ \hat{r}_{ij} $: Unit vector in the direction from $ i $ to $ j $.

#### 3. Tensor Product and Projection

The product of the feature function $ f_j(\vec{r}) $ and the delta function $ \delta(\vec{r}_{ij}) $ is projected onto degree $ l_3 $ using the **Clebsch-Gordan coefficients**, resulting in message terms:

$$
[m_{ij}]^{l_3 m_3}_{l_1 l_2} = \sum_{m_1, m_2} \langle l_1 m_1, l_2 m_2 | l_3 m_3 \rangle f_j^{l_1 m_1} Y_{l_2}^{m_2}(\hat{r}_{ij})
$$

- $ \langle l_1 m_1, l_2 m_2 | l_3 m_3 \rangle $: Clebsch-Gordan coefficients for coupling degrees $ l_1 $ and $ l_2 $ to $ l_3 $.
- This operation ensures the resulting features transform correctly under rotations.

#### 4. Message Computation

The message from node $ j $ to node $ i $ is computed by summing over all combinations of degrees $ l_1, l_2, l_3 $:

$$
m_{ij} = \sum_{l_1, l_2, l_3} w^{l_3}_{l_1 l_2} [m_{ij}]^{l_3}_{l_1 l_2}
$$

- $ w^{l_3}_{l_1 l_2} $: Learnable weights obtained from data, capturing interaction strengths.
- The summation over degrees allows the network to capture interactions at multiple levels of complexity.

#### 5. Feature Update

Each node $ i $ updates its features by aggregating the messages from its neighbors:

$$
f_i' = \sum_{j \in \mathcal{N}(i)} m_{ij}
$$

- $ \mathcal{N}(i) $: Set of neighboring nodes of $ i $.
- This aggregation maintains the equivariance property of the network.

### Tensor Product and Clebsch-Gordan Coefficients

The **tensor product** $ \otimes $ combines two irreducible representations of degrees $ l_1 $ and $ l_2 $ to produce representations of degree $ l_3 $. The Clebsch-Gordan coefficients govern this coupling, ensuring that the resulting features transform correctly under rotations, thus preserving **equivariance**.
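For two degree-1 features the coupling has a familiar Cartesian form: $1 \otimes 1 = 0 \oplus 1 \oplus 2$, i.e. the $3 \times 3$ outer product splits into the dot product (degree 0), the cross product (degree 1), and the symmetric traceless part (degree 2), with $1 + 3 + 5 = 9$ components. The Clebsch-Gordan coefficients perform the same split in the spherical basis; the Cartesian sketch below is an equivalent illustration up to a change of basis and normalization, not the CG computation itself. The function name is hypothetical.

```python
import torch

def couple_vectors(u, v):
    """Split the tensor product of two degree-1 features (3D vectors) into degree 0, 1 and 2 parts."""
    T = torch.outer(u, v)                                            # 3 x 3 tensor product (9 numbers)
    l0 = torch.trace(T)                                              # u . v -> degree 0 (1 number)
    l1 = torch.linalg.cross(u, v)                                    # u x v -> degree 1 (3 numbers)
    l2 = 0.5 * (T + T.T) - l0 / 3 * torch.eye(3, dtype=T.dtype)      # symmetric traceless -> degree 2 (5 numbers)
    return l0, l1, l2

torch.manual_seed(0)
u = torch.randn(3, dtype=torch.float64)
v = torch.randn(3, dtype=torch.float64)
Q, _ = torch.linalg.qr(torch.randn(3, 3, dtype=torch.float64))
R = Q if torch.det(Q) > 0 else -Q                                    # proper rotation (det = +1)

a0, a1, a2 = couple_vectors(u, v)
b0, b1, b2 = couple_vectors(R @ u, R @ v)
# Each part transforms in its own way: invariant, as a vector, and as R S R^T
print(torch.allclose(a0, b0), torch.allclose(R @ a1, b1), torch.allclose(R @ a2 @ R.T, b2))
```

The cross product is restricted to proper rotations here because it picks up an extra sign under reflections.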

### Example: Tensor Field Networks (TFNs)

**Tensor Field Networks (TFNs)** exemplify spherical networks that maintain equivariance under rotations and translations by operating on higher-degree irreps.

- **Feature Transformation**: Atomic features are expanded into spherical harmonics, capturing complex angular dependencies.
- **Equivariant Convolution**: TFNs perform convolutions using tensor products and Clebsch-Gordan coefficients, combining features in a rotation-equivariant manner.
- **Feature Aggregation**: Messages are aggregated over neighbors, with higher-degree irreps enabling the network to model intricate spatial relationships.

#### Mathematical Formulation in TFNs

1. **Input Features**: $ f_j^{l m} $ for each node $ j $.
2. **Message Passing**:

   $$
   m_{ij}^{l_3 m_3} = \sum_{l_1, l_2} \sum_{m_1, m_2} w^{l_3}_{l_1 l_2} \langle l_1 m_1, l_2 m_2 | l_3 m_3 \rangle f_j^{l_1 m_1} Y_{l_2}^{m_2}(\hat{r}_{ij}) \delta(r_{ij})
   $$

3. **Feature Update**:

   $$
   f_i'^{l_3 m_3} = \sum_{j \in \mathcal{N}(i)} m_{ij}^{l_3 m_3}
   $$
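For degree-0 inputs the message formula simplifies: $\langle 0\,0, l_2 m_2 | l_3 m_3 \rangle = \delta_{l_2 l_3}\delta_{m_2 m_3}$, so each message is the sender's scalar features times a radial function times $Y_{l}^{m}(\hat{r}_{ij})$. The sketch below implements this special case for $l \le 1$ in plain PyTorch, assuming a learned radial function on Gaussian RBFs in place of $\delta(r_{ij})$ (TFNs use learned radial functions) and folding the weights $w^{l_3}_{l_1 l_2}$ into it. The degree-1 harmonics are written in $(x, y, z)$ order, a fixed permutation of the usual $m = -1, 0, 1$ ordering; the class name and sizes are hypothetical.

```python
import math
import torch
import torch.nn as nn

class ScalarInputTFNMessage(nn.Module):
    """TFN-style messages for degree-0 inputs with output degrees 0 and 1 (sketch, not a full TFN layer)."""
    def __init__(self, channels, num_rbf=8, cutoff=5.0):
        super().__init__()
        self.register_buffer("centers", torch.linspace(0.0, cutoff, num_rbf))
        # One learned radial function per output degree, producing per-channel weights
        self.radial = nn.ModuleList([nn.Linear(num_rbf, channels) for _ in range(2)])

    def forward(self, f, pos, edge_index):
        src, dst = edge_index                                        # messages flow j = src -> i = dst
        rel = pos[src] - pos[dst]                                    # [E, 3], from i towards j
        r = rel.norm(dim=-1, keepdim=True)                           # [E, 1]
        rbf = torch.exp(-(r - self.centers) ** 2)                    # [E, num_rbf]
        Y0 = 0.5 / math.sqrt(math.pi)                                # Y_0^0 is constant
        Y1 = math.sqrt(3 / (4 * math.pi)) * rel / r                  # [E, 3], degree-1 harmonics
        fj = f[src]                                                  # [E, C] sender scalar features
        m0 = self.radial[0](rbf) * fj * Y0                           # [E, C] degree-0 messages
        m1 = (self.radial[1](rbf) * fj)[:, :, None] * Y1[:, None, :] # [E, C, 3] degree-1 messages
        n, c = f.shape
        out0 = torch.zeros(n, c, dtype=f.dtype).index_add_(0, dst, m0)      # f_i' for l = 0
        out1 = torch.zeros(n, c, 3, dtype=f.dtype).index_add_(0, dst, m1)   # f_i' for l = 1
        return out0, out1

torch.manual_seed(0)
layer = ScalarInputTFNMessage(channels=4)
f, pos = torch.randn(5, 4), torch.randn(5, 3)
edge_index = torch.tensor([[0, 1, 2, 3, 4, 0], [1, 2, 3, 4, 0, 2]])
Q, _ = torch.linalg.qr(torch.randn(3, 3))
a0, a1 = layer(f, pos, edge_index)
b0, b1 = layer(f, pos @ Q.T + 2.0, edge_index)
print(torch.allclose(a0, b0, atol=1e-5), torch.allclose(a1 @ Q.T, b1, atol=1e-5))
```

The degree-0 output is invariant and the degree-1 output rotates with the input, matching the transformation behavior described for each irrep.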

### Applications

- **Modeling Molecular Conformations**: By capturing higher-order geometric features, spherical networks like TFNs can model molecular shapes and dynamics.
- **Protein Structure Prediction**: The ability to model complex interactions and orientations makes spherical networks suitable for predicting 3D protein structures.

### Visualization of Spherical Harmonics

![Spherical Harmonics Visualization](https://upload.wikimedia.org/wikipedia/commons/7/74/Real_Spherical_Harmonics_Figure_Table_Complex_Radial_Magnitude.gif)

- The figure shows real spherical harmonics of various degrees and orders; higher-degree functions have more lobes and capture finer angular patterns.

### Key Takeaways

- **Higher-Degree Features**: Spherical networks extend beyond vectors to include higher-degree features, enabling richer representations of molecular structures.
- **Separation of Radial and Angular Components**: Handling radial and angular components separately allows for precise modeling of distances and orientations.
- **Equivariance through Spherical Harmonics**: Using spherical harmonics and tensor products ensures that network operations are equivariant under rotations, crucial for modeling physical systems.

### Additional Resources

- **Tensor Field Networks Paper**: [Tensor Field Networks: Rotation- and Translation-Equivariant Neural Networks for 3D Point Clouds](https://arxiv.org/abs/1802.08219)
- **Clebsch-Gordan Coefficients Tutorial**: [Understanding Clebsch-Gordan Coefficients](https://quantummechanics.ucsd.edu/ph130a/130_notes/node328.html)


![Tensor field flow](https://github.com/RobDHess/Steerable-E3-GNN/raw/main/assets/forward_pass_faster_larger.gif)
![TFN illustration](https://media.springernature.com/full/springer-static/image/art%3A10.1038%2Fs41467-022-29939-5/MediaObjects/41467_2022_29939_Fig1_HTML.png?as=webp)


#### Implementation of a Simple Spherical GNN Layer
In this section, we will implement a basic spherical GNN layer. This layer computes radial basis functions (RBF) of distances between nodes and spherical harmonics of the angular part of the vectors between nodes. The input degree is 0, and the output degree is 2.

We first define helper functions that compute the spherical harmonics and the tensor product, and we load precomputed Clebsch-Gordan coefficients from the file `CG_tensor_2.pt` (which must be present in the working directory) to couple the degrees correctly.

Step-by-Step Implementation
1. Define spherical harmonics

```python
import numpy as np
import torch

def associated_legendre_polynomials(L, x):
    """
    Compute normalized associated Legendre polynomials P_l^m(x) for 0 <= m <= l < L.

    The polynomials are scaled by sqrt((l - m)! / (l + m)!) and include the Condon-Shortley phase (-1)^m.

    Parameters:
    L (int): Number of degrees, i.e. degrees l = 0, ..., L - 1 are computed.
    x (torch.Tensor): Input tensor, typically the cosine of the polar angle.

    Returns:
    torch.Tensor: Tensor of shape [L * (L + 1) // 2, *x.shape]; entry (l, m) is stored at index l * (l + 1) // 2 + m.
    """
    P = [torch.ones_like(x) for _ in range((L + 1) * L // 2)]

    # Diagonal terms P_l^l from P_{l-1}^{l-1}
    for l in range(1, L):
        P[(l + 3) * l // 2] = -np.sqrt((2 * l - 1) / (2 * l)) * torch.sqrt(1 - x ** 2) * P[(l + 2) * (l - 1) // 2]

    # Off-diagonal terms: P_{m+1}^m, then the three-term recurrence in l for fixed m
    for m in range(L - 1):
        P[(m + 2) * (m + 1) // 2 + m] = x * np.sqrt(2 * m + 1) * P[(m + 1) * m // 2 + m]
        for l in range(m + 2, L):
            P[(l + 1) * l // 2 + m] = ((2 * l - 1) * x * P[l * (l - 1) // 2 + m] / np.sqrt(l ** 2 - m ** 2)
                                       - P[(l - 1) * (l - 2) // 2 + m] * np.sqrt(((l - 1) ** 2 - m ** 2) / (l ** 2 - m ** 2)))
    return torch.stack(P, dim=0)

def spherical_harmonics(L, THETA, PHI, device=None):
    """
    Compute real spherical harmonics Y_l^m for degrees l = 0, ..., L - 1.

    Parameters:
    L (int): Number of degrees.
    THETA (torch.Tensor): Azimuthal angles (angle in the xy-plane).
    PHI (torch.Tensor): Polar angles (angle from the z-axis).
    device (torch.device, optional): Output device; defaults to the device of THETA.

    Returns:
    torch.Tensor: Tensor of shape [L**2, *THETA.shape]; degree l occupies rows l**2 to (l + 1)**2 - 1,
    ordered m = -l, ..., l (negative m use sin(|m| THETA), positive m use cos(m THETA)).
    """
    if device is None:
        device = THETA.device  # keep results on the same device as the inputs
    P = associated_legendre_polynomials(L, torch.cos(PHI))
    M2 = [torch.zeros_like(THETA) for _ in range(2 * (L - 1) + 1)]
    output = [[torch.zeros_like(THETA) for _ in range(2 * l + 1)] for l in range(L)]

    # Azimuthal factors: cos(m THETA) stored at L-1+m, sin(m THETA) at L-1-m
    for m in range(L):
        if m > 0:
            M2[L - 1 + m] = torch.cos(m * THETA)
            M2[L - 1 - m] = torch.sin(m * THETA)
        else:
            M2[L - 1] = torch.ones_like(THETA)

    # Combine normalized Legendre polynomials with the azimuthal factors for each l and m
    for l in range(L):
        norm = np.sqrt((2 * l + 1) / (4 * np.pi))
        for m in range(l + 1):
            if m > 0:
                output[l][l + m] = np.sqrt(2) * P[(l + 1) * l // 2 + m] * norm * M2[L - 1 + m]
                output[l][l - m] = np.sqrt(2) * P[(l + 1) * l // 2 + m] * norm * M2[L - 1 - m]
            else:
                output[l][l] = P[(l + 1) * l // 2] * norm * M2[L - 1]

    return torch.cat([torch.stack(output_l, dim=0) for output_l in output]).to(device)
```

2. Implement the tensor product. 




```python
import math

def tensor_product(f_j, Y_r, cg, W, rbf, in_degree, r_degree, out_degree):
    """Equivariant tensor product of node features with edge spherical harmonics.

    f_j: [E, (in_degree + 1)**2, C_in], Y_r: [(r_degree + 1)**2, E],
    cg: Clebsch-Gordan tensor indexed by flattened (l, m) for each of the three degrees,
    W: [in_degree + 1, r_degree + 1, out_degree + 1, C_in, C_out, num_rbf], rbf: [E, num_rbf].
    Returns: [E, (out_degree + 1)**2, C_out].
    """
    # Map each flattened (l, m) index to its degree: index i belongs to degree floor(sqrt(i))
    def index_to_degree(max_degree):
        return torch.tensor([math.isqrt(i) for i in range((max_degree + 1) ** 2)],
                            dtype=torch.long, device=W.device)

    in_deg, r_deg, out_deg = index_to_degree(in_degree), index_to_degree(r_degree), index_to_degree(out_degree)
    # Share each (l1, l2, l3) weight across all orders m, which preserves equivariance
    W_spanned = W[in_deg][:, r_deg][:, :, out_deg]
    cg_sub = cg[:(in_degree + 1) ** 2, :(r_degree + 1) ** 2, :(out_degree + 1) ** 2]
    # Contract features (x, a), harmonics (y), CG (x, y, z), weights (a -> b) and radial basis (r) per edge e
    out = torch.einsum('exa, ye, xyz, xyzabr, er->ezb', f_j, Y_r, cg_sub, W_spanned, rbf)
    return out
```


3. Define a simple Spherical GNN Layer

```python
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing

# Define the Spherical GNN Layer
class SphericalGNNLayer(MessagePassing):
    def __init__(self, in_channels, out_channels, num_rbf=16, in_degree=0, r_degree=2, out_degree=2):
        super().__init__(aggr='add')
        self.num_rbf = num_rbf
        self.in_degree = in_degree
        self.r_degree = r_degree
        self.out_degree = out_degree
        self.register_buffer('rbf_centers', torch.linspace(0, 5, num_rbf))  # Radial basis function centers
        self.register_buffer('rbf_gamma', torch.tensor(1.0))               # Gamma parameter for RBF

        # One weight per (l1, l2, l3) path, input channel, output channel and RBF
        self.W = nn.Parameter(torch.randn(self.in_degree + 1, self.r_degree + 1, self.out_degree + 1,
                                          in_channels, out_channels, self.num_rbf))

        # Load precomputed Clebsch-Gordan coefficients (file must be in the working directory)
        self.register_buffer('cg', torch.load('CG_tensor_2.pt'))

    def forward(self, x, pos, edge_index):
        # x: [N, (in_degree + 1)**2 * in_channels]; with in_degree=0 this is [N, in_channels]
        row, col = edge_index
        diff = pos[row] - pos[col]
        dist = diff.norm(dim=-1)

        # Compute RBF [E, num_rbf] and spherical harmonics [(r_degree + 1)**2, E]
        rbf = torch.exp(-self.rbf_gamma * (dist[:, None] - self.rbf_centers[None]) ** 2)
        sh = self.spherical_harmonics(diff)

        # Output: [N, (out_degree + 1)**2 * out_channels]
        return self.propagate(edge_index, x=x, rbf=rbf, sh=sh)

    def message(self, x_j, rbf, sh):
        # Tensor product of input features with the spherical harmonics of the edge vectors
        x_j = torch.reshape(x_j, (x_j.shape[0], (self.in_degree + 1) ** 2, -1))
        tp = tensor_product(x_j, sh, self.cg, self.W, rbf, self.in_degree, self.r_degree, self.out_degree)
        return tp.reshape(tp.shape[0], -1)

    def spherical_harmonics(self, vectors):
        # theta: azimuthal angle, phi: polar angle (clamped to avoid NaN from rounding)
        theta = torch.atan2(vectors[:, 1], vectors[:, 0])
        phi = torch.acos((vectors[:, 2] / vectors.norm(dim=-1)).clamp(-1.0, 1.0))
        return spherical_harmonics(self.r_degree + 1, theta, phi)

# Example usage
node_features = torch.randn(4, 3)  # 4 nodes, 3 features per node
node_coords = torch.randn(4, 3)    # 4 nodes, 3D coordinates
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]], dtype=torch.long)  # Edges in COO format

model = SphericalGNNLayer(in_channels=3, out_channels=2)
output = model(node_features, node_coords, edge_index)

print("Output Node Features:", output)
```

#### Explanation of the Key Components

1. **Radial Basis Functions (RBF):** Gaussian functions of the distances between nodes. In this layer the RBF values are contracted with the learned weights `W` inside the tensor product, so each $(l_1, l_2, l_3)$ path gets its own learned radial profile.

2. **Spherical Harmonics (SH):** Computed from the angular parts of vectors between nodes. SH helps encode the angular information of atomic positions.

3. **Tensor Product:** Combines the input features with the spherical harmonics using Clebsch-Gordan coefficients to produce higher-degree features.



## Exercise: Training a GNN on the QM9 Dataset using Custom Layers

In this final exercise, you will train a GNN on the `QM9` dataset using one of the custom layers you have implemented: `InvariantMPNN`, `CartesianGNNLayer`, or `SphericalGNNLayer`.

### Step 1: Load the QM9 Dataset

In [None]:
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader


batch_size = 32

# Load the QM9 dataset
dataset = QM9(root='data/QM9')



# choose the target


'''
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | Target | Property                         | Description                                                                       | Unit                                        |
    +========+==================================+===================================================================================+=============================================+
    | 0      | :math:`\mu`                      | Dipole moment                                                                     | :math:`\textrm{D}`                          |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | 1      | :math:`\alpha`                   | Isotropic polarizability                                                          | :math:`{a_0}^3`                             |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | 2      | :math:`\epsilon_{\textrm{HOMO}}` | Highest occupied molecular orbital energy                                         | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | 3      | :math:`\epsilon_{\textrm{LUMO}}` | Lowest unoccupied molecular orbital energy                                        | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | 4      | :math:`\Delta \epsilon`          | Gap between :math:`\epsilon_{\textrm{HOMO}}` and :math:`\epsilon_{\textrm{LUMO}}` | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | 5      | :math:`\langle R^2 \rangle`      | Electronic spatial extent                                                         | :math:`{a_0}^2`                             |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | 6      | :math:`\textrm{ZPVE}`            | Zero point vibrational energy                                                     | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | 7      | :math:`U_0`                      | Internal energy at 0K                                                             | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | 8      | :math:`U`                        | Internal energy at 298.15K                                                        | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | 9      | :math:`H`                        | Enthalpy at 298.15K                                                               | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | 10     | :math:`G`                        | Free energy at 298.15K                                                            | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | 11     | :math:`c_{\textrm{v}}`           | Heat capacity at 298.15K                                                          | :math:`\frac{\textrm{cal}}{\textrm{mol K}}` |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | 12     | :math:`U_0^{\textrm{ATOM}}`      | Atomization energy at 0K                                                          | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | 13     | :math:`U^{\textrm{ATOM}}`        | Atomization energy at 298.15K                                                     | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | 14     | :math:`H^{\textrm{ATOM}}`        | Atomization enthalpy at 298.15K                                                   | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | 15     | :math:`G^{\textrm{ATOM}}`        | Atomization free energy at 298.15K                                                | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | 16     | :math:`A`                        | Rotational constant                                                               | :math:`\textrm{GHz}`                        |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | 17     | :math:`B`                        | Rotational constant                                                               | :math:`\textrm{GHz}`                        |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | 18     | :math:`C`                        | Rotational constant                                                               | :math:`\textrm{GHz}`                        |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+

    
'''

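# Index of the QM9 property to regress; 0 is the dipole moment μ (see the table above).
target = 0


# Standardize every target column to zero mean and unit variance.
mean = dataset.data.y.mean(dim=0, keepdim=True)
std = dataset.data.y.std(dim=0, keepdim=True)
dataset.data.y = (dataset.data.y - mean) / std
# Keep the chosen column's statistics as floats to map errors back to original units.
mean, std = mean[:, target].item(), std[:, target].item()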


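# Shuffle once with a fixed seed: QM9 is stored in a systematic order (roughly by
# number of heavy atoms), so unshuffled contiguous slices would not share a distribution.
import torch
torch.manual_seed(12345)
dataset = dataset.shuffle()

# Split: 110,000 training, 10,000 validation, and the remaining molecules for testing.
train_dataset = dataset[:110000]
val_dataset = dataset[110000:120000]
test_dataset = dataset[120000:]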


# Mini-batch loaders; only the training set is shuffled.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
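Because every target column is standardized, the model learns to predict values in normalized units. A prediction maps back as `pred * std + mean`, so the absolute error in the original unit is exactly `std` times the absolute error in normalized units. The sketch below checks this on a small made-up tensor standing in for `dataset.data.y`; the values and the variable names (`y_toy`, `col_mean`, `col_std`) are illustrative assumptions, chosen so that running the cell does not overwrite the notebook's own `mean`, `std`, or `target`.

```python
import torch

# Made-up stand-in for dataset.data.y: 6 molecules x 3 properties.
y_toy = torch.tensor([[1.0, 10.0, -2.0],
                      [2.0, 12.0, -1.0],
                      [3.0, 14.0,  0.0],
                      [4.0, 16.0,  1.0],
                      [5.0, 18.0,  2.0],
                      [6.0, 20.0,  3.0]])
t = 0  # column to "regress" in this toy example

# Per-column statistics, same pattern as the cell above (std is the unbiased estimator).
col_mean = y_toy.mean(dim=0, keepdim=True)   # shape [1, 3]
col_std = y_toy.std(dim=0, keepdim=True)     # shape [1, 3]
y_norm = (y_toy - col_mean) / col_std        # each column now has mean 0 and std 1

# Python floats for the chosen property.
mu, sigma = col_mean[:, t].item(), col_std[:, t].item()

# Pretend the model is off by 0.1 on every molecule, in normalized units.
pred_norm = y_norm[:, t] + 0.1
pred = pred_norm * sigma + mu                # map predictions back to the original unit

mae_norm = (pred_norm - y_norm[:, t]).abs().mean().item()
mae = (pred - y_toy[:, t]).abs().mean().item()
print(f"sigma = {sigma:.4f}, MAE (normalized) = {mae_norm:.4f}, MAE (original) = {mae:.4f}")
```

This is why the scalar `std` of the chosen target is kept around: it is the only factor needed to report errors in the property's physical unit.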


### Step 2: Choose and Define the Model
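Select one of the custom layers (`InvariantMPNN`, `CartesianGNNLayer`, or `SphericalGNNLayer`) and wrap it in a graph-level regression model: one message-passing layer, a global mean readout over the atoms of each molecule, and a two-layer MLP head. The `layer_type` argument selects the layer; the example below uses the `InvariantMPNN` variant.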

In [None]:
import torch
from torch.nn import Linear
from torch_geometric.nn import global_mean_pool

class CustomGNN(torch.nn.Module):
    """One message-passing layer, a global mean readout and a two-layer MLP head."""

    def __init__(self, in_channels, hidden_channels, out_channels, layer_type='invariant'):
        super().__init__()
        if layer_type == 'invariant':
            self.conv1 = InvariantMPNN(in_channels, hidden_channels)
        elif layer_type == 'cartesian':
            self.conv1 = CartesianGNNLayer(in_channels, hidden_channels)
        elif layer_type == 'spherical':
            self.conv1 = SphericalGNNLayer(in_channels, hidden_channels)
        else:
            raise ValueError(f"Unknown layer_type: {layer_type!r}")

        self.lin1 = Linear(hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, out_channels)

    def forward(self, x, pos, edge_index, batch):
        # Message passing over atoms, using node features and 3D coordinates
        x = self.conv1(x, pos, edge_index)
        x = torch.relu(x)

        # Readout: average atom embeddings into one vector per molecule
        x = global_mean_pool(x, batch)
        x = torch.relu(self.lin1(x))
        return self.lin2(x)  # shape [num_graphs, out_channels]

# QM9 provides 11 features per atom; predict a single scalar property
model = CustomGNN(in_channels=11, hidden_channels=64, out_channels=1, layer_type='invariant')
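`global_mean_pool` turns per-atom embeddings into one vector per molecule using the `batch` vector, which assigns every node in a mini-batch to the index of its graph. The sketch below reproduces that per-graph mean with plain PyTorch to make the operation explicit; `mean_pool` is a hypothetical helper written for illustration (not a PyG function), and the toy tensors are made up.

```python
import torch

def mean_pool(x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    """Average node features per graph; batch[i] is the graph index of node i."""
    num_graphs = int(batch.max()) + 1
    # Sum node features into one row per graph.
    sums = torch.zeros(num_graphs, x.size(1), dtype=x.dtype, device=x.device)
    sums.index_add_(0, batch, x)
    # Number of nodes per graph (clamped so an empty graph would not divide by zero).
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1).to(x.dtype)
    return sums / counts.unsqueeze(1)

# Two toy molecules with 3 and 2 atoms and 4-dimensional atom embeddings.
node_emb = torch.arange(20, dtype=torch.float32).view(5, 4)
batch_vec = torch.tensor([0, 0, 0, 1, 1])
pooled = mean_pool(node_emb, batch_vec)
print(pooled.shape)  # torch.Size([2, 4])
```

Because averaging ignores the order in which atoms are listed, reordering the atoms of a molecule (together with their `batch` entries) leaves the pooled vector unchanged, so the readout does not depend on atom numbering.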


### Step 3: Define Training and Evaluation Functions

In [None]:
import torch.nn.functional as F
from torch.optim import Adam

optimizer = Adam(model.parameters(), lr=0.001)


def train():
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(device)  # `device` is defined in Step 4, before train() is first called
        optimizer.zero_grad()
        out = model(data.x, data.pos, data.edge_index, data.batch)
        # `out` has shape [num_graphs, 1]; flatten it to match the [num_graphs] target
        loss = F.mse_loss(out.view(-1), data.y[:, target])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)  # mean per-batch loss, in normalized units

@torch.no_grad()
def test(loader):
    model.eval()
    error = 0
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.pos, data.edge_index, data.batch)
        # Multiply by `std` to undo the target normalization and report MAE in original units
        error += ((out.view(-1) - data.y[:, target]) * std).abs().sum().item()
    return error / len(loader.dataset)
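The model returns predictions of shape `[num_graphs, 1]`, while `data.y[:, target]` has shape `[num_graphs]`. Passed to `F.mse_loss` as they are, the two broadcast to `[num_graphs, num_graphs]`, so every prediction is compared with every target in the batch; PyTorch emits a size-mismatch warning but does not raise. The toy tensors below (random, for illustration only) show the difference, and flattening the output with `.view(-1)` restores the intended element-wise loss.

```python
import warnings
import torch
import torch.nn.functional as F

torch.manual_seed(0)
out = torch.randn(4, 1)   # model output: one prediction per graph, shape [B, 1]
y = torch.randn(4)        # target as sliced by data.y[:, target], shape [B]

# Mismatched shapes broadcast to [B, B]: all prediction/target pairs are mixed together.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")   # silence PyTorch's size-mismatch UserWarning
    wrong = F.mse_loss(out, y)

# Flattening the output gives the intended loss over B matched pairs.
right = F.mse_loss(out.view(-1), y)

print(f"broadcast loss: {wrong.item():.4f}, intended loss: {right.item():.4f}")
```

The same shape issue affects any metric computed as `out - data.y[:, target]`, including the MAE in `test`.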


### Step 4: Train and Evaluate the Model

In [None]:
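device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# `target` is fixed in Step 1 (before normalization); change it there so `mean` and `std` match.

for epoch in range(1, 101):
    loss = train()
    val_mae = test(val_loader)  # monitor the validation split during training
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val MAE: {val_mae:.4f}')

# Evaluate the held-out test split once, after training has finished.
test_mae = test(test_loader)
print(f'Test MAE: {test_mae:.4f}')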
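The training loop keeps whatever weights the final epoch produced. A common refinement is to snapshot the weights with the lowest validation MAE, restore them after training, and only then evaluate the test split. The sketch below is one way to do that under stated assumptions: `fit_with_model_selection` is a hypothetical helper, and it expects `train_step()` to return the epoch's training loss and `evaluate(loader)` to return an MAE, as the `train` and `test` functions above do.

```python
import copy

def fit_with_model_selection(model, train_step, evaluate, val_loader, test_loader, epochs):
    """Train for `epochs` epochs, keep the weights with the lowest validation MAE,
    then evaluate the test split once with those weights.

    train_step() runs one epoch and returns the mean training loss;
    evaluate(loader) returns the MAE on that loader.
    """
    best_val, best_epoch, best_state = float('inf'), 0, None
    for epoch in range(1, epochs + 1):
        loss = train_step()
        val_mae = evaluate(val_loader)
        if val_mae < best_val:
            best_val, best_epoch = val_mae, epoch
            # deepcopy: state_dict() tensors share storage with the live parameters
            best_state = copy.deepcopy(model.state_dict())
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val MAE: {val_mae:.4f}')
    if best_state is not None:
        model.load_state_dict(best_state)
    test_mae = evaluate(test_loader)
    return best_epoch, best_val, test_mae
```

With the notebook's objects, a call would look like `fit_with_model_selection(model, train, test, val_loader, test_loader, epochs=100)`. Passing the training and evaluation functions in as arguments keeps the helper independent of the layer type chosen in Step 2.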