In [1]:
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch.utils.data import Dataset
from torch_geometric.nn import global_mean_pool
from torch.utils.data import random_split
import torch
from torch import Tensor, LongTensor
from torch_geometric.nn import radius_graph
import torch.nn as nn
import torch_geometric.transforms as T
from torch_geometric.nn import TransformerConv
from torch_geometric.nn import GATConv
from torch_geometric.data import Data
from typing import Tuple
from ase import Atoms

torch.manual_seed(42)

<torch._C.Generator at 0x11b206370>

In [41]:
from cbp.model import CBPSimpleMLP
model = CBPSimpleMLP(num_features=3696, num_hidden=1000, num_targets=1)

In [42]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.5)
criterion = torch.nn.MSELoss()

In [43]:
data = next(iter(train_loader))
def train():
    model.train()
    # get first batch only
    optimizer.zero_grad()
    # get node embeddings
    node_embeddings = model(data)
    # get graph embedding
    out = global_mean_pool(node_embeddings, data.batch)
    loss = criterion(out[:,7], data.y[:,7])
    loss.backward()
    optimizer.step()
    return loss.item()

In [44]:
for epoch in range(1, 500):
    loss = train()
    scheduler.step()
    if epoch % 50 == 0:
        print(f'Epoch {epoch}, Loss: {loss:.4f}, LR: {optimizer.param_groups[0]["lr"]:.6f}')


NotImplementedError: Module [MLP] is missing the required "forward" function

In [4]:
dataset = QM9(root='./data')

y_values = dataset.y

# normalize y values along each column
y_values = (y_values - y_values.mean(dim=0)) / y_values.std(dim=0)
dataset._data.y = y_values

In [83]:
from dscribe.descriptors import SOAP
soap_desc = SOAP(species=["C", "H", "O", "N", "F"], r_cut=5, n_max=8, l_max=6, average="inner")

In [30]:
from ase import Atoms
atoms = Atoms(numbers=dataset[0].z.tolist(), positions=dataset[0].pos.tolist())
desc = soap_desc.create(atoms)
desc

array([0.01623526, 0.06291466, 0.15724008, ..., 0.        , 0.        ,
       0.        ])

In [55]:
from torch_geometric.data import Data

setattr(dataset[0], 'soap', desc)

In [56]:
dataset[0]

Data(x=[5, 11], edge_index=[2, 8], edge_attr=[8, 4], y=[1, 19], pos=[5, 3], idx=[1], name='gdb_1', z=[5])

In [8]:
generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(dataset, [0.85, 0.05, 0.1], generator=generator)

In [4]:
torch.save(dataset, 'data/splits/dataset.pt')
torch.save(train_dataset, 'data/splits/train_dataset.pt')
torch.save(val_dataset, 'data/splits/val_dataset.pt')
torch.save(test_dataset, 'data/splits/test_dataset.pt')

In [46]:
# read in dataste
train_dataset = torch.load('data/splits/train_dataset.pt', weights_only=False)
val_dataset = torch.load('data/splits/val_dataset.pt', weights_only=False)
test_dataset = torch.load('data/splits/test_dataset.pt', weights_only=False)
# create dataloaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

In [86]:
class AddFeatureTransform:
    def __init__(self):
        pass
        
    def __call__(self, data):
        # Calculate your new feature here
        data.soap = self.calc_soap(data)
        return data
    
    def calc_soap(self, data):
        atoms = Atoms(numbers=data.z.tolist(), positions=data.pos.tolist())
        desc = soap_desc.create(atoms)
        return desc

# Usage
transform = AddFeatureTransform()
dataset = QM9(root="data", force_reload=False, pre_transform=transform)
dataset[0]



Data(x=[5, 11], edge_index=[2, 8], edge_attr=[8, 4], y=[1, 19], pos=[5, 3], idx=[1], name='gdb_1', z=[5], soap=[5740])

In [89]:
generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(dataset, [0.85, 0.05, 0.1], generator=generator)

In [91]:
# torch.save(dataset, 'data/splits/dataset.pt')
# torch.save(train_dataset, 'data/splits/train_dataset.pt')
# torch.save(val_dataset, 'data/splits/val_dataset.pt')
torch.save(test_dataset, 'data/splits/test_dataset.pt')

: 

In [9]:
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [11]:
class GaussianSmearing(torch.nn.Module):
    def __init__(
        self,
        start: float = 0.0,
        stop: float = 5.0,
        num_gaussians: int = 50,
    ):
        super().__init__()
        offset = torch.linspace(start, stop, num_gaussians)
        self.coeff = -0.5 / (offset[1] - offset[0]).item()**2
        self.register_buffer('offset', offset)

    def forward(self, dist: Tensor) -> Tensor:
        dist = dist.view(-1, 1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * torch.pow(dist, 2))

In [13]:
class RadiusInteractionGraph(torch.nn.Module):
    r"""Creates edges based on atom positions :obj:`pos` to all points within
    the cutoff distance.

    Args:
        cutoff (float, optional): Cutoff distance for interatomic interactions.
            (default: :obj:`10.0`)
        max_num_neighbors (int, optional): The maximum number of neighbors to
            collect for each node within the :attr:`cutoff` distance with the
            default interaction graph method.
            (default: :obj:`32`)
    """
    def __init__(self, cutoff: float = 10.0, max_num_neighbors: int = 32):
        super().__init__()
        self.cutoff = cutoff
        self.max_num_neighbors = max_num_neighbors

    def forward(self, pos: Tensor, batch: Tensor) -> Tuple[Tensor, Tensor]:
        r"""
        Args:
            pos (Tensor): Coordinates of each atom.
            batch (LongTensor, optional): Batch indices assigning each atom to
                a separate molecule.

        :rtype: (:class:`LongTensor`, :class:`Tensor`)
        """
        edge_index = radius_graph(pos, r=self.cutoff, batch=batch,
                                  max_num_neighbors=self.max_num_neighbors)
        row, col = edge_index
        edge_weight = (pos[row] - pos[col]).norm(dim=-1)
        return edge_index, edge_weight

In [32]:
class AttentionBlock(nn.Module):
    def __init__(self, num_features, num_targets, heads, num_gaussians):
        super().__init__()
        # GATConv with edge features
        self.conv = GATConv(
            in_channels=num_features, 
            out_channels=num_targets,
            heads=heads,
            edge_dim=num_gaussians  # match your num_gaussians for edge attributes
        )
        self.ln = nn.LayerNorm(num_targets * heads)
        self.ff = nn.Linear(num_targets * heads, num_features)

    def forward(self, x, edge_index, edge_weight, edge_attr):
        # GATConv only needs x, edge_index, and edge_attr
        out = self.conv(x, edge_index, edge_attr=edge_attr)
        out = self.ln(out)
        out = self.ff(out)
        return x + out  # residual connection


class Model(nn.Module):
    def __init__(
        self,
        hidden_channels: int = 128,
        num_features: int = 11,
        num_targets: int = 19,
        heads: int = 8,
        cutoff: float = 5.0,
        max_num_neighbors: int = 32,
        num_gaussians: int = 50,
        num_blocks: int = 3,
    ):
        super().__init__()
        self.embedding = nn.Embedding(100, hidden_channels, padding_idx=0)
        self.interaction_graph = RadiusInteractionGraph(cutoff, max_num_neighbors)
        self.distance_expansion = GaussianSmearing(0.0, cutoff, num_gaussians)
        self.blocks = nn.ModuleList(
            [
                AttentionBlock(hidden_channels, num_targets, heads, num_gaussians)
                for _ in range(num_blocks)
            ]
        )
        self.lin1 = nn.Linear(hidden_channels, hidden_channels // 2)
        self.act = nn.ReLU()
        self.lin2 = nn.Linear(hidden_channels // 2, num_targets)

    def forward(self, data):
        h = self.embedding(data.z)
        edge_index, edge_weight = self.interaction_graph(data.pos, data.batch)
        edge_attr = self.distance_expansion(edge_weight)

        for block in self.blocks:
            h = h + block(h, edge_index, edge_weight, edge_attr)

        out = self.lin1(h)
        out = self.act(out)
        out = self.lin2(out)

        return out


In [96]:
model = Model()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.5)
criterion = torch.nn.MSELoss()

In [97]:
# train on one batch to test the model
data = next(iter(train_loader))
def train():
    model.train()
    # get first batch only
    optimizer.zero_grad()
    # get node embeddings
    node_embeddings = model(data)
    # get graph embedding
    out = global_mean_pool(node_embeddings, data.batch)
    loss = criterion(out[:,7], data.y[:,7])
    loss.backward()
    optimizer.step()
    return loss.item()

In [98]:
for epoch in range(1, 500):
    loss = train()
    scheduler.step()
    if epoch % 50 == 0:
        print(f'Epoch {epoch}, Loss: {loss:.4f}, LR: {optimizer.param_groups[0]["lr"]:.6f}')


Epoch 50, Loss: 0.2424, LR: 0.001000
Epoch 100, Loss: 0.1470, LR: 0.001000
Epoch 150, Loss: 0.0455, LR: 0.001000
Epoch 200, Loss: 0.0246, LR: 0.001000
Epoch 250, Loss: 0.0086, LR: 0.001000
Epoch 300, Loss: 0.0089, LR: 0.001000
Epoch 350, Loss: 0.0013, LR: 0.001000
Epoch 400, Loss: 0.0009, LR: 0.001000
Epoch 450, Loss: 0.0003, LR: 0.001000


In [99]:
with torch.no_grad():
    out = model(data)
    out = global_mean_pool(out, data.batch)
    print(out[:,7])
    print(data.y[:,7])


tensor([-0.1652, -1.2174, -0.3294,  0.7120, -1.0163, -3.9658, -0.2178, -1.2608,
         0.2047, -0.1852,  1.5539, -0.2822, -2.9801,  0.6996, -0.7121,  0.2769,
        -0.6954,  0.7406, -0.1978,  0.5902, -0.2861,  0.7648,  1.5744, -0.6649,
        -0.2589,  2.4867, -1.0844, -0.3288,  0.1834, -0.4734, -1.1804, -0.1702,
         0.3361,  0.3196,  4.0020,  1.1726, -0.7010, -1.2437,  1.1501, -0.2748,
        -0.6873,  0.2087, -0.3241, -1.6299, -1.6005,  0.6289,  0.1907,  0.2356,
        -0.9825,  1.3913,  0.6752, -1.6797,  1.5454,  1.5256, -1.1322, -0.1567,
         0.6679, -0.6931,  0.6295, -1.2356, -1.5883, -0.1738, -1.0776, -1.5624])
tensor([-0.1814, -1.2356, -0.3352,  0.6782, -1.0160, -3.9753, -0.2403, -1.2358,
         0.1932, -0.2100,  1.5271, -0.3054, -2.9874,  0.6894, -0.7376,  0.2757,
        -0.7062,  0.7238, -0.2098,  0.5653, -0.3051,  0.7429,  1.5504, -0.6802,
        -0.2748,  2.4801, -1.1115, -0.3668,  0.1627, -0.4872, -1.2069, -0.1791,
         0.3177,  0.3059,  3.9877,  1.1

In [72]:
with torch.no_grad():
    out = model(next(iter(test_loader)))
    print(out[:,7].shape)

torch.Size([1115])


In [73]:
data = next(iter(test_loader))
print(data.y[:,7].shape)


torch.Size([64])


In [117]:
def test():
    total_loss = 0  
    model.eval()
    with torch.no_grad():
        for data in test_loader:
            out = model(data.x, data.edge_index)
            out = global_mean_pool(out, data.batch)
            loss = criterion(out, data.y)
            total_loss += loss.item()
    return total_loss / len(test_loader)

torch.Size([64, 19])
