In [39]:
import torch
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch_geometric.datasets import MD17
from torch_geometric.nn import GATConv, global_add_pool
from torch_geometric.transforms import KNNGraph

In [40]:
# Compute edge indices
knn_graph = KNNGraph(k=6) 

In [54]:
# Load the MD17 dataset
#https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/datasets/md17.html#MD17
path = "/home/sire/phd/srz228573/benchmarking_datasets/baseline_model_dataset/gat"
dataset = MD17(root=path, name = 'revised benzene', transform=knn_graph)

In [51]:
# # Prepare the data
# for data in dataset:
#     data.z = data.z.unsqueeze(-1).float()  # Use atomic numbers as node features


In [55]:
dataset[0].z

tensor([6, 6, 6, 6, 6, 6, 1, 1, 1, 1, 1, 1])

In [9]:
# Use 80% of the data for training, and 20% for testing
train_dataset = dataset[:int(0.8 * len(dataset))]
test_dataset = dataset[int(0.8 * len(dataset)):]

In [30]:
train_dataset[0]

tensor([-145434.1250])

In [56]:

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)



In [57]:
class GAT(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes):
        super(GAT, self).__init__()
        self.conv1 = GATConv(num_features, hidden_channels)
        self.conv2 = GATConv(hidden_channels, hidden_channels)
        self.fc1 = torch.nn.Linear(hidden_channels, hidden_channels)
        self.fc2 = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = global_add_pool(x, batch)  # Pooling to obtain graph-level representation
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x


In [58]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GAT(dataset.num_features, hidden_channels=64, num_classes=1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)


In [59]:
def train():
    model.train()
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data.z, data.edge_index, data.batch)
        loss = F.mse_loss(output, data.energy)
        loss.backward()
        optimizer.step()


In [60]:
def test(loader):
    model.eval()
    loss = 0
    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            output = model(data.z, data.edge_index, data.batch)
        loss += F.mse_loss(output, data.energy).item()
    return loss / len(loader)

In [61]:
for epoch in range(2):
    train()
    train_loss = test(train_loader)
    test_loss = test(test_loader)
    print(f'Epoch: {epoch + 1}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')

AssertionError: Static graphs not supported in 'GATConv'