In [1]:
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch.utils.data import Dataset
from torch_geometric.nn import global_mean_pool
from torch.utils.data import random_split
import torch
from torch import Tensor, LongTensor
from torch_geometric.nn import radius_graph
import torch.nn as nn
import torch_geometric.transforms as T
from torch_geometric.nn import TransformerConv
from torch_geometric.nn import GATConv
from torch_geometric.data import Data
from typing import Tuple
from ase import Atoms

torch.manual_seed(42)

<torch._C.Generator at 0x10792a1d0>

In [5]:
dataset = QM9(root='./data')

y_values = dataset.y

# normalize y values along each column
y_values = (y_values - y_values.mean(dim=0)) / y_values.std(dim=0)
dataset._data.y = y_values

In [8]:
generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(dataset, [0.85, 0.05, 0.1], generator=generator)

In [4]:
torch.save(dataset, 'data/splits/dataset.pt')
torch.save(train_dataset, 'data/splits/train_dataset.pt')
torch.save(val_dataset, 'data/splits/val_dataset.pt')
torch.save(test_dataset, 'data/splits/test_dataset.pt')

In [9]:
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [11]:
class GaussianSmearing(torch.nn.Module):
    def __init__(
        self,
        start: float = 0.0,
        stop: float = 5.0,
        num_gaussians: int = 50,
    ):
        super().__init__()
        offset = torch.linspace(start, stop, num_gaussians)
        self.coeff = -0.5 / (offset[1] - offset[0]).item()**2
        self.register_buffer('offset', offset)

    def forward(self, dist: Tensor) -> Tensor:
        dist = dist.view(-1, 1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * torch.pow(dist, 2))

In [13]:
class RadiusInteractionGraph(torch.nn.Module):
    r"""Creates edges based on atom positions :obj:`pos` to all points within
    the cutoff distance.

    Args:
        cutoff (float, optional): Cutoff distance for interatomic interactions.
            (default: :obj:`10.0`)
        max_num_neighbors (int, optional): The maximum number of neighbors to
            collect for each node within the :attr:`cutoff` distance with the
            default interaction graph method.
            (default: :obj:`32`)
    """
    def __init__(self, cutoff: float = 10.0, max_num_neighbors: int = 32):
        super().__init__()
        self.cutoff = cutoff
        self.max_num_neighbors = max_num_neighbors

    def forward(self, pos: Tensor, batch: Tensor) -> Tuple[Tensor, Tensor]:
        r"""
        Args:
            pos (Tensor): Coordinates of each atom.
            batch (LongTensor, optional): Batch indices assigning each atom to
                a separate molecule.

        :rtype: (:class:`LongTensor`, :class:`Tensor`)
        """
        edge_index = radius_graph(pos, r=self.cutoff, batch=batch,
                                  max_num_neighbors=self.max_num_neighbors)
        row, col = edge_index
        edge_weight = (pos[row] - pos[col]).norm(dim=-1)
        return edge_index, edge_weight

In [32]:
class AttentionBlock(nn.Module):
    def __init__(self, num_features, num_targets, heads, num_gaussians):
        super().__init__()
        # GATConv with edge features
        self.conv = GATConv(
            in_channels=num_features, 
            out_channels=num_targets,
            heads=heads,
            edge_dim=num_gaussians  # match your num_gaussians for edge attributes
        )
        self.ln = nn.LayerNorm(num_targets * heads)
        self.ff = nn.Linear(num_targets * heads, num_features)

    def forward(self, x, edge_index, edge_weight, edge_attr):
        # GATConv only needs x, edge_index, and edge_attr
        out = self.conv(x, edge_index, edge_attr=edge_attr)
        out = self.ln(out)
        out = self.ff(out)
        return x + out  # residual connection


class Model(nn.Module):
    def __init__(
        self,
        hidden_channels: int = 128,
        num_features: int = 11,
        num_targets: int = 19,
        heads: int = 8,
        cutoff: float = 5.0,
        max_num_neighbors: int = 32,
        num_gaussians: int = 50,
        num_blocks: int = 3,
    ):
        super().__init__()
        self.embedding = nn.Embedding(100, hidden_channels, padding_idx=0)
        self.interaction_graph = RadiusInteractionGraph(cutoff, max_num_neighbors)
        self.distance_expansion = GaussianSmearing(0.0, cutoff, num_gaussians)
        self.blocks = nn.ModuleList(
            [
                AttentionBlock(hidden_channels, num_targets, heads, num_gaussians)
                for _ in range(num_blocks)
            ]
        )
        self.lin1 = nn.Linear(hidden_channels, hidden_channels // 2)
        self.act = nn.ReLU()
        self.lin2 = nn.Linear(hidden_channels // 2, num_targets)

    def forward(self, data):
        h = self.embedding(data.z)
        edge_index, edge_weight = self.interaction_graph(data.pos, data.batch)
        edge_attr = self.distance_expansion(edge_weight)

        for block in self.blocks:
            h = h + block(h, edge_index, edge_weight, edge_attr)

        out = self.lin1(h)
        out = self.act(out)
        out = self.lin2(out)

        return out


In [33]:
model = Model()
optimizer = torch.optim.Adam(model.parameters(), lr=0.03)
criterion = torch.nn.MSELoss()

In [36]:
# train on one batch to test the model
def train():
    model.train()
    # get first batch only
    data = next(iter(train_loader))
    optimizer.zero_grad()
    # get node embeddings
    node_embeddings = model(data)
    # get graph embedding
    out = global_mean_pool(node_embeddings, data.batch)
    loss = criterion(out[:,7], data.y[:,7])
    loss.backward()
    optimizer.step()
    return loss.item()

In [38]:
for epoch in range(1, 100):
    loss = train()
    print(f'Epoch {epoch}, Loss: {loss:.4f}')

Epoch 1, Loss: 0.5363
Epoch 2, Loss: 0.3657
Epoch 3, Loss: 0.3409
Epoch 4, Loss: 0.3845
Epoch 5, Loss: 0.4619
Epoch 6, Loss: 0.5015
Epoch 7, Loss: 0.4844
Epoch 8, Loss: 0.4412
Epoch 9, Loss: 0.2549
Epoch 10, Loss: 0.3926
Epoch 11, Loss: 0.6556
Epoch 12, Loss: 0.3610
Epoch 13, Loss: 0.4893
Epoch 14, Loss: 0.2691
Epoch 15, Loss: 0.3302
Epoch 16, Loss: 0.4661
Epoch 17, Loss: 0.5319
Epoch 18, Loss: 0.4192
Epoch 19, Loss: 0.4921
Epoch 20, Loss: 0.5398
Epoch 21, Loss: 0.5813
Epoch 22, Loss: 0.2633
Epoch 23, Loss: 0.4397
Epoch 24, Loss: 0.3027
Epoch 25, Loss: 0.7000
Epoch 26, Loss: 0.3932
Epoch 27, Loss: 0.3514
Epoch 28, Loss: 0.4299
Epoch 29, Loss: 0.6667
Epoch 30, Loss: 0.3789
Epoch 31, Loss: 0.2189
Epoch 32, Loss: 0.6897
Epoch 33, Loss: 1.6795
Epoch 34, Loss: 0.8366
Epoch 35, Loss: 0.4885
Epoch 36, Loss: 0.3352
Epoch 37, Loss: 0.2558
Epoch 38, Loss: 0.7270
Epoch 39, Loss: 0.4491
Epoch 40, Loss: 0.2835
Epoch 41, Loss: 1.2102
Epoch 42, Loss: 0.4246
Epoch 43, Loss: 0.3800
Epoch 44, Loss: 0.36

In [117]:
def test():
    total_loss = 0  
    model.eval()
    with torch.no_grad():
        for data in test_loader:
            out = model(data.x, data.edge_index)
            out = global_mean_pool(out, data.batch)
            loss = criterion(out, data.y)
            total_loss += loss.item()
    return total_loss / len(test_loader)

torch.Size([64, 19])
