In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
import torch
from torch.nn import Sequential, Linear, ReLU, CrossEntropyLoss
from torch_geometric.data import Dataset, Data
import torch.nn.functional as F 
import os.path as osp
torch.__version__

'2.3.0'

In [2]:
# Check if MPS (Apple GPU support) is available and use it; otherwise, use CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

Using device: mps


In [3]:
import pandas as pd
import psutil

# Path to your Parquet file
# file_path = '/Users/leo/Desktop/Dataset/polyOne_Dataset/raw/polyOne_aa.parquet'
# file_path = '/Users/lpc_0066/Desktop/Dataset/polyOne_Dataset/raw/polyOne_aa.parquet'
# file_path = "/Users/lpc_0066/Desktop/Dataset/ZINC_250k/250k_rndm_zinc_drugs_clean_3.csv"
file_path = "/Users/leo/Desktop/Dataset/aqueous_solubility/raw/curated-solubility-dataset.csv"

# Step 1: Estimate Dataset Size
# Load the dataset
# full_df = pd.read_parquet(file_path, engine='pyarrow')
full_df = pd.read_csv(file_path)

# Take a sample of the dataset to estimate memory usage per row
sample_df = full_df.sample(n=1000)

# Estimate total memory usage
estimated_memory_per_row = sample_df.memory_usage(deep=True).sum() / len(sample_df)
total_rows = len(full_df)
estimated_total_memory = estimated_memory_per_row * total_rows

print(f"Estimated memory usage: {estimated_total_memory / (1024**3):.2f} GB")

# Step 2: Check Available Memory
# Get total and available memory
memory_info = psutil.virtual_memory()
total_memory = memory_info.total
available_memory = memory_info.available

print(f"Total memory: {total_memory / (1024**3):.2f} GB")
print(f"Available memory: {available_memory / (1024**3):.2f} GB")

# Step 3: Compare Dataset Size to Available Memory
if estimated_total_memory < available_memory:
    print("The dataset should fit into the available memory.")
else:
    print("The dataset is too large to fit into the available memory.")

Estimated memory usage: 0.01 GB
Total memory: 16.00 GB
Available memory: 5.84 GB
The dataset should fit into the available memory.


Build Dataset

In [4]:
import torch
from torch_geometric.data import Dataset
from typing import Callable, Any, Optional

# root = "/Users/leo/Desktop/Dataset/polyOne_Dataset"
# root = "/Users/lpc_0066/Desktop/Dataset/polyOne_Dataset/"
# root = "/Users/lpc_0066/Desktop/Dataset/ZINC_250k"
root = "/Users/leo/Desktop/Dataset/aqueous_solubility"

class GraphDataset(Dataset):
    """
    Only support molecule not polymer.
    """
    def __init__(self, 
                 root, 
                 filename, 
                 target_name, 
                 transform: Callable[..., Any] | None = None, 
                 pre_transform: Callable[..., Any] | None = None, 
                 pre_filter: Callable[..., Any] | None = None) -> None:
        
        self.filename = filename
        self.target_name = target_name
        super().__init__(root, transform, pre_transform, pre_filter)
 
    @property
    def raw_file_names(self):
        return self.filename
    
    @property
    def processed_file_names(self): # self.data should be defined here or error occurs cuz super() calls processed_file_names property.
        # self.data = pd.read_parquet(self.raw_paths[0])
        self.data = pd.read_csv(self.raw_paths[0])
        return [f"data_{i}.pt" for i in range(len(self.data))]
    
    def download(self):
        # Used when the file is downloaded from url
        pass

    def len(self):
        return self.data.shape[0]
    
    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))
        return data.to(device)
    
    def process(self):
        for index, row in self.data.iterrows():
            mol = Chem.MolFromSmiles(row.smiles)
            node_feats = self.get_node_features(mol)
            edge_feats = self.get_edge_features(mol)
            edge_index = self.get_adjacency_info(mol)
            label = self.get_labels(row[self.target_name])
            structure_id = [[row.smiles]]
            data = Data(x=node_feats, 
                        edge_index=edge_index,
                        edge_attr=edge_feats,
                        y=label,
                        structure_id=structure_id)
            torch.save(data, osp.join(self.processed_dir, f'data_{index}.pt'))
            
    def get_node_features(self, mol):
        all_node_features = []
        for atom in mol.GetAtoms():
            node_features = F.one_hot(torch.tensor(atom.GetAtomicNum()-1), num_classes=118) 
            node_features = node_features.tolist()

            node_features.append(atom.GetDegree())
            node_features.append(atom.GetFormalCharge())
            node_features.append(atom.GetHybridization())
            node_features.append(atom.GetIsAromatic())
            node_features.append(atom.GetTotalNumHs())
            node_features.append(atom.GetNumRadicalElectrons())
            node_features.append(atom.IsInRing())
            node_features.append(atom.GetChiralTag())
            node_features.append(atom.GetMass())

            all_node_features.append(node_features)
        return torch.tensor(all_node_features, dtype=torch.float32)
            
    def get_edge_features(self, mol):
        all_edge_features = []
        for bond in mol.GetBonds():
            edge_features = []
            edge_features.append(torch.tensor(bond.GetBondTypeAsDouble()))
            edge_features.append(bond.IsInRing())
            edge_features.append(bond.GetStereo())
            edge_features.append(bond.GetIsConjugated())

            all_edge_features.append(edge_features)
        return torch.tensor(all_edge_features, dtype=torch.float32)
    
    def get_adjacency_info(self, mol):
        edge_indices = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_indices += [[i, j], [j,i]]

        edge_indices_tensor = torch.tensor(edge_indices, dtype=torch.int64)
        return edge_indices_tensor.T
    
    def get_labels(self, label):
        tensor_label = torch.tensor([[label]])
        return tensor_label

In [5]:
dataset = GraphDataset(root=root, filename="curated-solubility-dataset.csv", target_name="Solubility")

In [6]:
dataset[1345]

Data(x=[23, 127], edge_index=[2, 50], edge_attr=[25, 4], y=[1, 1], structure_id=[1])

In [7]:
from torch_geometric.loader import DataLoader
dataset = dataset.shuffle()
#学習データとテストデータに分割する。
from sklearn.model_selection import train_test_split
dataset_train, dataset_test = train_test_split(dataset, test_size=0.2)

#バッチサイズ
batch_size = 64
loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
loader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)

GCN model construction

In [10]:
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, global_mean_pool

class GCN(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int = 64, output_dim: int = 1):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)
        self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc2 = nn.Linear(hidden_dim // 2, output_dim)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, data) -> torch.Tensor:
        x, edge_index, batch = data.x, data.edge_index, data.batch
        edge_index = edge_index.to(torch.int64)
        
        # GCN layers with ReLU activation
        x = self.conv1(x, edge_index)
        x = self.relu(x)
        x = self.conv2(x, edge_index)
        x = self.relu(x)
        x = self.conv3(x, edge_index)

        # Global mean pooling
        x = global_mean_pool(x, batch)

        # Fully connected layers with dropout and ReLU
        x = self.dropout(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)

        return x

In [20]:
from torch import optim
from tqdm import tqdm as tqdm

device_train = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
device_eval = torch.device('cpu')

input_dim = dataset.num_node_features
gcn = GCN(input_dim)

optimizer = optim.AdamW(gcn.parameters(), lr=0.001)
criterion = nn.L1Loss()

# Lists to record losses
epoch_loss_train = []
epoch_loss_test = []

# Define training loop
def train(model, device, loader, optimizer, criterion):
    model.train()  # Set model to training mode
    total_loss = 0
    for data in tqdm(loader, desc="Training", leave=False):
        data = data.to(device)  # Move data to the training device
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

# Define evaluation loop
def evaluate(model, device, loader, criterion):
    model.to(device)  # Move model to evaluation device
    model.eval()  # Set model to evaluation mode
    total_loss = 0
    with torch.inference_mode():
        for data in tqdm(loader, desc="Evaluating", leave=False):
            data = data.to(device)  # Move data to the evaluation device
            output = model(data)
            loss = criterion(output, data.y)
            total_loss += loss.item()
    return total_loss / len(loader)

for epoch in range(100):
    # Training phase
    gcn.to(device_train)  # Ensure model is on the training device
    train_loss = train(gcn, device_train, loader_train, optimizer, criterion)
    epoch_loss_train.append(train_loss)
    print(f"Epoch {epoch+1}, Train Loss: {train_loss}")

    # Evaluation phase
    gcn.to(device_eval)  # Ensure model is on the evaluat
    eval_loss = evaluate(gcn, device_eval, loader_test, criterion)
    epoch_loss_test.append(eval_loss)
    print(f"Epoch {epoch+1}, Eval Loss: {eval_loss}")

Training:   0%|          | 0/125 [00:00<?, ?it/s]

                                                           

Epoch 1, Train Loss: 1.9313265705108642


                                                           

Epoch 1, Eval Loss: 1.8131747245788574


                                                           

Epoch 2, Train Loss: 1.7251832084655763


                                                           

Epoch 2, Eval Loss: 1.6634995862841606


                                                           

Epoch 3, Train Loss: 1.6311508340835572


                                                           

Epoch 3, Eval Loss: 1.5504405684769154


                                                           

Epoch 4, Train Loss: 1.5347203903198243


                                                           

Epoch 4, Eval Loss: 1.50291109085083


                                                           

Epoch 5, Train Loss: 1.4593891878128051


                                                           

Epoch 5, Eval Loss: 1.3959333561360836


                                                           

Epoch 6, Train Loss: 1.3619255032539368


                                                           

Epoch 6, Eval Loss: 1.3377492055296898


                                                           

Epoch 7, Train Loss: 1.3138109979629518


                                                           

Epoch 7, Eval Loss: 1.3465438298881054


                                                           

KeyboardInterrupt: 

Learning Curve

In [None]:
import matplotlib.pyplot as plt

plt.plot(range(len(epoch_loss_train)), epoch_loss_train, label="Train")
plt.plot(range(len(epoch_loss_test)), epoch_loss_test, label="Test")
plt.legend()

plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

Test performance

In [None]:
gcn.eval()  # ネットワークを推論モードに切り替える
predictions = []
labels = []

with torch.inference_mode():
    for data in loader_test:
        label = data.y
        output = gcn(data)  # ネットワークにテストデータを入力して予測結果を取得
        labels.append(label)
        predictions.append(output)

predictions = torch.cat(predictions, dim=0)  # 予測結果を結合して1つのテンソルにする
labels = torch.cat(labels, dim=0)
plt.scatter(labels.float(), predictions.float())
plt.xlabel('Measured')
plt.ylabel('Predicted')
plt.show()