In [17]:
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv
import pandas as pd
import numpy as np

from utils import atom_type_to_onehot, pairwise_edges

# NOTE: This pipeline is set up on CPU

In [15]:
class AtomDataset(InMemoryDataset):
    """Single-graph dataset built from an ``atoms.csv`` table.

    Each CSV row becomes one node. Node features are the one-hot encoding of
    the atom type followed by the X, Y, Z coordinates; edges come from
    ``pairwise_edges`` with a distance cutoff; the node-level target is the
    ``MAGNETIC_MOMENT`` column.

    NOTE(review): the processed file name is the constant ``data.pt`` and does
    not depend on ``cutoff``. PyG normally skips ``process()`` when the
    processed file already exists, so changing ``cutoff`` presumably requires
    deleting the cached file first -- confirm against the PyG version in use.
    """

    def __init__(self, root, transform=None, pre_transform=None, cutoff=2.0):
        """
        root: dataset root directory (with PyG's default layout the CSV is
              expected at ``<root>/raw/atoms.csv`` -- TODO confirm)
        transform: optional callable forwarded to ``InMemoryDataset``
        pre_transform: optional callable applied once to the graph before it
                       is cached on disk
        cutoff: distance threshold for constructing edges
        """
        # Must be set before super().__init__(), because the base constructor
        # may call process(), which reads self.cutoff.
        self.cutoff = cutoff
        super().__init__(root, transform, pre_transform)
        # Go through the helper so the load is explicit about weights_only.
        self.data, self.slices = self._load_processed_data()

    @property
    def raw_file_names(self):
        """Raw files this dataset expects to find."""
        return ['atoms.csv']

    @property
    def processed_file_names(self):
        """Files written by ``process()``."""
        return ['data.pt']

    def process(self):
        """Read the CSV, build the single graph, and cache it on disk."""
        df = pd.read_csv(self.raw_paths[0])

        # 1) Node features: [one-hot(atom_type), x, y, z]
        positions = df[['X', 'Y', 'Z']].values

        # NOTE(review): only the first character of the label is used as the
        # atom type (e.g. 'N1' -> 'N'); two-letter element symbols would be
        # truncated -- confirm this matches what atom_type_to_onehot expects.
        node_features = [
            list(atom_type_to_onehot(atom_label[0])) + list(coords)
            for atom_label, coords in zip(df['ATOM'], positions.tolist())
        ]

        # shape (num_nodes, 7), assuming a 4-entry one-hot plus 3 coordinates
        x = torch.tensor(node_features, dtype=torch.float)

        # 2) Construct edge_index using distance cutoff
        edge_index = pairwise_edges(positions, cutoff=self.cutoff)

        # 3) Magnetic moment as labels (node-level), one column per node
        y = torch.tensor(df['MAGNETIC_MOMENT'].values, dtype=torch.float).view(-1, 1)

        data = Data(x=x, edge_index=edge_index, y=y)

        # Honour the pre_transform accepted by the constructor; previously it
        # was silently ignored.
        if self.pre_transform is not None:
            data = self.pre_transform(data)

        data_list = [data]
        self.data, self.slices = self.collate(data_list)
        torch.save((self.data, self.slices), self.processed_paths[0])

    def _load_processed_data(self):
        """
        Load the cached ``(data, slices)`` tuple across PyTorch versions.

        The cache holds pickled PyG objects, so it needs a full unpickle.
        ``weights_only=False`` is passed explicitly so the call does not rely
        on the library default. Versions that do not accept the keyword raise
        ``TypeError``; the call is then retried without it.

        SECURITY: a full unpickle can execute arbitrary code. Only load
        processed files this class wrote itself; do not point ``root`` at an
        untrusted directory.
        """
        processed_path = self.processed_paths[0]
        try:
            return torch.load(processed_path, weights_only=False)
        except TypeError:  # keyword not supported by this PyTorch version
            return torch.load(processed_path)

In [16]:
# Build the dataset rooted at the current directory, using a 2.0 distance
# cutoff for edge construction.
dataset = AtomDataset(root='./', cutoff=2.0)
# process() stores a single graph, so index 0 is the whole structure.
data = dataset[0]
# Print the graph summary (tensor shapes of x, edge_index and y).
print(data)

Data(x=[66, 7], edge_index=[2, 156], y=[66, 1])


  self.data, self.slices = torch.load(self.processed_paths[0])


In [None]:
class GNNModel(nn.Module):
    """Stack of three GCNConv layers producing one output row per node.

    The first two layers are each followed by a ReLU; the third layer's
    output is returned without an activation.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        """
        in_channels: size of each input node feature vector
        hidden_channels: width of the two hidden layers
        out_channels: size of each output node vector
        """
        super().__init__()
        # Layers are created in forward order and keep their original
        # attribute names, so parameter names stay the same.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """
        x: node features, (num_nodes, in_channels)
        edge_index: graph connectivity, (2, E)
        returns: (num_nodes, out_channels)
        """
        hidden = x
        # Hidden layers: graph convolution followed by ReLU.
        for hidden_conv in (self.conv1, self.conv2):
            hidden = hidden_conv(hidden, edge_index).relu()

        # Output layer: no activation applied.
        return self.conv3(hidden, edge_index)