In [10]:
import torch
import torch.nn as nn
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv, BatchNorm, JumpingKnowledge
import torch.nn.functional as F
import helper
import data_preprocessing_training

In [11]:
# Create a Dataset class
class MolecularGraphTrain(Dataset):
    def __init__(self, cleaned_data, transform=None, pre_transform=None):
        super(MolecularGraphTrain, self).__init__(transform=transform, pre_transform=pre_transform)
        self.graphs = list(cleaned_data.values())
        self._indices = range(len(self.graphs))
    
    def __len__(self):
        return len(self._indices)
    
    def get(self, idx):
        graph_info = self.graphs[idx]
        return self.create_pyg_data(graph_info)
    
    def __getitem__(self, idx):
        data = self.get(self._indices[idx])
        data = data if self.transform is None else self.transform(data)
        return data
    
    def create_pyg_data(self, graph_info):
        # Extract nodes and edges from the graph information
        node_id_feature = graph_info["node_id_feature"]
        edge_features = graph_info["edge_features"]
        target_variable = graph_info["target_variable"]
    
        # Create the node feature matrix
        node_ids = sorted(node_id_feature.keys())
        node_features = []
        for node_id in node_ids:
            features = [
                node_id_feature[node_id]["atomic"],
                node_id_feature[node_id]["valence"],
                node_id_feature[node_id]["formal_charge"],
                node_id_feature[node_id]["aromatic"],
                node_id_feature[node_id]["hybridization"],
                node_id_feature[node_id]["radical_electrons"]
            ]
            node_features.append(features)
        x = torch.tensor(node_features, dtype=torch.float)
    
        # Create the edge list
        edge_index = []
        edge_attr = []
        for edge in edge_features:
            edge_index.append([edge["source"], edge["target"]])
            edge_attr.append([
                edge["type"],
                edge["stereo"],
                edge["aromatic"],
                edge["conjugated"]
            ])
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_attr, dtype=torch.float)
    
        # Create the target variable tensor
        target_list = [target_variable[node_id] for node_id in node_ids]
        y = torch.tensor([[t["mass"], t["charge"], t["sigma"], t["epsilon"]] for t in target_list], dtype=torch.float)
    
        # Return the graph as a Data object
        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)

In [28]:
# Create a Dataset class
class MolecularGraphTest(Dataset):
    def __init__(self, cleaned_data, transform=None, pre_transform=None):
        super(MolecularGraphTest, self).__init__(transform=transform, pre_transform=pre_transform)
        self.graphs = list(cleaned_data.values())
        self._indices = range(len(self.graphs))
    
    def __len__(self):
        return len(self._indices)
    
    def get(self, idx):
        graph_info = self.graphs[idx]
        return self.create_pyg_data(graph_info)
    
    def __getitem__(self, idx):
        data = self.get(self._indices[idx])
        data = data if self.transform is None else self.transform(data)
        return data
    
    def create_pyg_data(self, graph_info):
        # Extract nodes and edges from the graph information
        node_id_feature = graph_info["node_id_feature"]
        edge_features = graph_info["edge_features"]
    
        # Create the node feature matrix
        node_ids = sorted(node_id_feature.keys())
        node_features = []
        for node_id in node_ids:
            features = [
                node_id_feature[node_id]["atomic"],
                node_id_feature[node_id]["valence"],
                node_id_feature[node_id]["formal_charge"],
                node_id_feature[node_id]["aromatic"],
                node_id_feature[node_id]["hybridization"],
                node_id_feature[node_id]["radical_electrons"]
            ]
            node_features.append(features)
        x = torch.tensor(node_features, dtype=torch.float)
    
        # Create the edge list
        edge_index = []
        edge_attr = []
        for edge in edge_features:
            edge_index.append([edge["source"], edge["target"]])
            edge_attr.append([
                edge["type"],
                edge["stereo"],
                edge["aromatic"],
                edge["conjugated"]
            ])
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_attr, dtype=torch.float)
    
        # Return the graph as a Data object
        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

In [29]:
class NodeEmbedding(nn.Module):
    def __init__(self, num_atomic, num_valence, num_formal_charge, num_hybridization, num_radical_electrons, embedding_dim):
        super(NodeEmbedding, self).__init__()
        self.atomic_embedding = nn.Embedding(num_atomic, embedding_dim)
        self.valence_embedding = nn.Embedding(num_valence, embedding_dim)
        self.formal_charge_embedding = nn.Embedding(num_formal_charge, embedding_dim)
        self.hybridization_embedding = nn.Embedding(num_hybridization, embedding_dim)
        self.radical_electrons_embedding = nn.Embedding(num_radical_electrons, embedding_dim)

    def forward(self, atomic, valence, formal_charge, aromatic, hybridization, radical_electrons):
        atomic_embed = self.atomic_embedding(atomic)
        valence_embed = self.valence_embedding(valence)
        formal_charge_embed = self.formal_charge_embedding(formal_charge)
        hybridization_embed = self.hybridization_embedding(hybridization)
        radical_electrons_embed = self.radical_electrons_embedding(radical_electrons)

        # Concatenate boolean features
        other_features = torch.stack([aromatic], dim=1).float()

        # Concatenate all features together
        return torch.cat([atomic_embed, valence_embed, formal_charge_embed, hybridization_embed, radical_electrons_embed, other_features], dim=1)

class EdgeEmbedding(nn.Module):
    def __init__(self, num_type, num_stereo, embedding_dim):
        super(EdgeEmbedding, self).__init__()
        self.type_embedding = nn.Embedding(num_type, embedding_dim)
        self.stereo_embedding = nn.Embedding(num_stereo, embedding_dim)

    def forward(self, type_, stereo, aromatic, conjugated):
        type_embed = self.type_embedding(type_)
        stereo_embed = self.stereo_embedding(stereo)

        # Concatenate boolean features directly
        other_features = torch.stack([aromatic, conjugated], dim=1).float()

        # Concatenate all features together
        return torch.cat([type_embed, stereo_embed, other_features], dim=1)

In [30]:
class ImprovedGNNWithEmbeddings(torch.nn.Module):

    def __init__(self, node_embedding_dim, edge_embedding_dim, hidden_dim, output_dim, num_layers = 6, num_atomic = 12, num_valence = 7, num_hybridization = 5, num_type = 4, num_stereo = 3 ,num_formal_charge = 3, num_radical_electrons = 1):
        super(ImprovedGNNWithEmbeddings, self).__init__()
        self.node_embedding = NodeEmbedding(num_atomic, num_valence, num_formal_charge, num_hybridization, num_radical_electrons, node_embedding_dim)
        self.edge_embedding = EdgeEmbedding(num_type, num_stereo, edge_embedding_dim)
        self.convs = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()

        # Define the first GINEConv layer, with the correct edge_dim specified
        self.convs.append(GINEConv(
            torch.nn.Sequential(
                torch.nn.Linear(node_input_dim, hidden_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_dim, hidden_dim)
            ),
            edge_dim=edge_input_dim, train_eps = True
        ))
        self.norms.append(BatchNorm(hidden_dim))

        # Additional GINEConv layers, each with the correct edge_dim
        for _ in range(num_layers - 1):
            self.convs.append(GINEConv(
                torch.nn.Sequential(
                    torch.nn.Linear(hidden_dim, hidden_dim),
                    torch.nn.ReLU(),
                    torch.nn.Linear(hidden_dim, hidden_dim)
                ),
                edge_dim=edge_input_dim, train_eps = True
            ))
            self.norms.append(BatchNorm(hidden_dim))

        # Jumping Knowledge mechanism
        self.jump = JumpingKnowledge(mode="cat")

        # Final fully connected layers
        self.fc1 = torch.nn.Linear(hidden_dim * num_layers, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        features = []

        # Pass through GINEConv layers and apply batch normalization
        for conv, norm in zip(self.convs, self.norms):
            x = conv(x, edge_index, edge_attr)
            x = F.relu(norm(x))
            features.append(x)

        # Apply Jumping Knowledge (JK) to concatenate all layers
        x = self.jump(features)

        # Directly pass through the linear layers
        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        return x

In [31]:
import data_preprocessing_testing

train_data = helper.load_data_from_file("data.json")
cleaned_data_train = data_preprocessing_training.extract_clean_data(train_data)


In [42]:
# Hyperparameters
node_input_dim = 6  # node feature dimension
edge_input_dim = 4  # edge feature dimension
hidden_dim = 256
output_dim = 4  # The number of outputs (mass, charge, sigma, epsilon)
num_epochs = 500
learning_rate = 0.0001

model = ImprovedGNNWithEmbeddings(node_embedding_dim = 32, edge_embedding_dim = 32, hidden_dim = hidden_dim, output_dim = 4)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = torch.nn.MSELoss()

# Move model to the appropriate device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

ImprovedGNNWithEmbeddings(
  (node_embedding): NodeEmbedding(
    (atomic_embedding): Embedding(12, 32)
    (valence_embedding): Embedding(7, 32)
    (formal_charge_embedding): Embedding(3, 32)
    (hybridization_embedding): Embedding(5, 32)
    (radical_electrons_embedding): Embedding(1, 32)
  )
  (edge_embedding): EdgeEmbedding(
    (type_embedding): Embedding(4, 32)
    (stereo_embedding): Embedding(3, 32)
  )
  (convs): ModuleList(
    (0): GINEConv(nn=Sequential(
      (0): Linear(in_features=6, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=256, bias=True)
    ))
    (1-5): 5 x GINEConv(nn=Sequential(
      (0): Linear(in_features=256, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=256, bias=True)
    ))
  )
  (norms): ModuleList(
    (0-5): 6 x BatchNorm(256)
  )
  (jump): JumpingKnowledge(cat)
  (fc1): Linear(in_features=1536, out_features=256, bias=True)
  (fc2): Linear(in_features=

In [43]:
from torch_geometric.loader import DataLoader as PyGDataLoader

# Set up the training and testing DataLoaders using PyTorch Geometric DataLoader
train_dataset = MolecularGraphTrain(cleaned_data_train)


# Use a smaller batch size for better memory management, particularly with graph data
train_loader = PyGDataLoader(train_dataset, batch_size=128, shuffle=True)


# Training loop
num_epochs = 300
model.train()

for epoch in range(num_epochs):
    total_loss = 0.0
    for batch in train_loader:
        # Move data to the device
        batch = batch.to(device)

        # Forward pass
        output = model(batch)
        loss = criterion(output, batch.y)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}")

# Model evaluation on the test dataset
model.eval()


# Convert predictions to a suitable format if needed


Epoch [1/300], Loss: 17.6491
Epoch [2/300], Loss: 5.0653
Epoch [3/300], Loss: 1.0530
Epoch [4/300], Loss: 0.4029
Epoch [5/300], Loss: 0.2339
Epoch [6/300], Loss: 0.1642
Epoch [7/300], Loss: 0.1334
Epoch [8/300], Loss: 0.1229
Epoch [9/300], Loss: 0.0945
Epoch [10/300], Loss: 0.0752
Epoch [11/300], Loss: 0.0699
Epoch [12/300], Loss: 0.0709
Epoch [13/300], Loss: 0.0566
Epoch [14/300], Loss: 0.0525
Epoch [15/300], Loss: 0.0570
Epoch [16/300], Loss: 0.0384
Epoch [17/300], Loss: 0.0412
Epoch [18/300], Loss: 0.0412
Epoch [19/300], Loss: 0.0358
Epoch [20/300], Loss: 0.0365
Epoch [21/300], Loss: 0.0289
Epoch [22/300], Loss: 0.0307
Epoch [23/300], Loss: 0.0308
Epoch [24/300], Loss: 0.0272
Epoch [25/300], Loss: 0.0297
Epoch [26/300], Loss: 0.0291
Epoch [27/300], Loss: 0.0327
Epoch [28/300], Loss: 0.0300
Epoch [29/300], Loss: 0.0258
Epoch [30/300], Loss: 0.0221
Epoch [31/300], Loss: 0.0208
Epoch [32/300], Loss: 0.0256
Epoch [33/300], Loss: 0.0227
Epoch [34/300], Loss: 0.0183
Epoch [35/300], Loss: 

In [45]:
torch.save(model.state_dict(), "improved_gnn_model.pth")

In [117]:
test_data = helper.load_data_from_file("permutation_masked.json")
cleaned_data_test = data_preprocessing_testing.extract_clean_data(test_data)
test_dataset = MolecularGraphTest(cleaned_data_test)
test_loader = PyGDataLoader(test_dataset, batch_size=1, shuffle=False)

In [118]:
predictions = []

with torch.no_grad():
    for batch in test_loader:
        # Move data to the device
        batch = batch.to(device)

        # Forward pass (prediction)
        output = model(batch)

        # Append the predictions
        predictions.append(output.cpu().numpy())

In [119]:
predictions[0]

array([[ 1.58286982e+01, -2.81746477e-01,  2.74159253e-01,
         5.95377922e-01],
       [ 1.21445084e+01, -1.17456332e-01,  3.52593303e-01,
         3.21284503e-01],
       [ 1.19519997e+01, -2.52483077e-02,  3.40997577e-01,
         2.78204381e-01],
       [ 9.79994118e-01,  8.54815692e-02,  2.37182260e-01,
         1.16635114e-01],
       [ 1.21018267e+01, -1.53920308e-01,  3.57927710e-01,
         3.40657890e-01],
       [ 1.19478636e+01, -7.81530291e-02,  3.47847670e-01,
         2.97787905e-01],
       [ 9.18771029e-01,  4.36948031e-01,  4.28940915e-03,
        -1.28815174e-02],
       [ 9.70475018e-01,  5.03990315e-02,  2.48270303e-01,
         1.10584401e-01],
       [ 1.20845108e+01, -1.60441294e-01,  3.51223499e-01,
         2.81991839e-01],
       [ 9.57518160e-01,  1.08655617e-01,  2.45491982e-01,
         1.15989536e-01],
       [ 1.22063856e+01, -6.03287406e-02,  3.32707584e-01,
         3.14476371e-01],
       [ 9.70475018e-01,  5.03990315e-02,  2.48270303e-01,
      