<a href="https://colab.research.google.com/github/yuvipaloozie/QM9-GNN/blob/main/QM9_GNN_2D_Baseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Installations
!pip install torch torch-geometric
!pip install rdkit
!pip install qm9pack
!pip install py3Dmol
!pip install pandas numpy matplotlib seaborn tqdm

# Core Libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm

# Cheminformatics & DL
import qm9pack
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import AllChem
import py3Dmol

In [None]:
# Basic EDA: load the 'qm9' table, inspect it, and drop rows unusable for modelling.

print("Loading data from qm9pack... ")
df = qm9pack.get_data('qm9')
print(f"Data shape: {df.shape}\n")
print("Data Head:")
display(df.head())

print("\nData Info:")
df.info()

# Column used as the regression target by the later modelling cells.
TARGET_PROPERTY = 'HOMO_LUMO_gap_au'

# Keep only rows that have both a SMILES string and a target value.
# `df` itself is left untouched; the filtered copy is `df_clean`.
original_size = len(df)
df_clean = df.dropna(subset=['SMILES', TARGET_PROPERTY])
cleaned_size = len(df_clean)

print(f"Original number of molecules: {original_size}")
print(f"Cleaned number of molecules (after dropping NaNs): {cleaned_size}")

# Call qm9pack.helper once per column name.
# NOTE(review): presumably this prints a description of each column -- confirm
# against the qm9pack documentation.
for key in df_clean.keys():
    qm9pack.helper(key)

In [None]:
# Visualizations: distribution of the target and correlations between selected properties.

# Histogram of the target column (100 bins) with a KDE curve overlaid.
plt.figure(figsize=(10, 6))
sns.histplot(df_clean[TARGET_PROPERTY], kde=True, bins=100)
plt.title(f'Distribution of {TARGET_PROPERTY}', fontsize=16)
plt.xlabel('HOMO-LUMO Gap (au)', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.show()

# Get summary statistics for the target
print(df_clean[TARGET_PROPERTY].describe())

# Select a subset of key numerical properties
# (the target column 'HOMO_LUMO_gap_au' is one of them).
numerical_cols = [
    'N_atoms', 'Dipole_debye', 'Polarizability_bohr3',
    'HOMO_au', 'LUMO_au', 'HOMO_LUMO_gap_au',
    'InternalEnergy_0K_au', 'Heatcapacity_Cv_cal_mol_K'
]

# Calculate the correlation matrix
# (DataFrame.corr() with its default settings).
corr_matrix = df_clean[numerical_cols].corr()

# Plot the heatmap, annotating each cell with the coefficient to two decimals.
plt.figure(figsize=(12, 10))
sns.heatmap(corr_matrix, annot=True, fmt='.2f', cmap='rocket')
plt.title('Correlation Heatmap of Key Properties', fontsize=16)
plt.show()

In [None]:
# Visualize molecules: draw one randomly chosen molecule in 2D and, when a
# conformer can be generated, in interactive 3D.

# A fresh random row is drawn every time this cell runs (no seed is fixed here).
random_row = df_clean.sample(n=1)
smiles_string = random_row['SMILES'].iloc[0]
molecule_index = random_row['Index'].iloc[0]
gap_value = random_row[TARGET_PROPERTY].iloc[0]

print(f"--- Displaying Random Molecule (Index: {molecule_index}) ---")
print(f"SMILES: {smiles_string}")
print(f"HOMO-LUMO Gap: {gap_value:.4f} au\n")

# 2D Visualization (RDKit)
print("2D Structure:")
mol_2d = Chem.MolFromSmiles(smiles_string)
display(Draw.MolToImage(mol_2d, size=(300, 300)))

# 3D Visualization (py3Dmol)
print("\n3D Interactive Structure:")
mol_3d = Chem.MolFromSmiles(smiles_string)
mol_3d = Chem.AddHs(mol_3d)

# EmbedMolecule reports failure through its return value (-1) instead of
# raising; optimising a molecule without a conformer would then raise, so the
# return code is checked before going any further.
embed_code = AllChem.EmbedMolecule(mol_3d, AllChem.ETKDG())

if embed_code == -1:
    print("  -> WARNING: RDKit's EmbedMolecule failed to generate 3D conformer.")
    print("  -> Skipping 3D visualization for this molecule.")
else:
    try:
        AllChem.MMFFOptimizeMolecule(mol_3d)
        mblock = Chem.MolToMolBlock(mol_3d)

        view = py3Dmol.view(width=500, height=400)
        view.addModel(mblock, 'mol')
        view.setStyle({'stick':{}, 'sphere': {'scale':0.3}})
        view.zoomTo()
        view.show()
    except ValueError as e:
        # Same best-effort policy as the failure-analysis cell: report and move on.
        print(f"  -> WARNING: MMFFOptimize failed: {e}")
        print("  -> Skipping 3D visualization for this molecule.")

In [None]:
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data, DataLoader
from torch_geometric.utils import from_smiles
from tqdm.notebook import tqdm
import numpy as np


# Convert every cleaned SMILES string into a PyG graph carrying the target as
# `y`, then create the train/validation/test splits and their loaders.
print("Converting SMILES to graphs for GCN model...")
data_list = []
failed_conversions = []  # (SMILES, repr(error)) for every molecule that was skipped
for index, row in tqdm(df_clean.iterrows(), total=df_clean.shape[0]):
    try:
        data = from_smiles(row['SMILES'])
        data.y = torch.tensor([[row[TARGET_PROPERTY]]], dtype=torch.float)
    except Exception as e:
        # Still best-effort (one bad molecule must not abort the run), but the
        # failure is recorded instead of being silently discarded.
        failed_conversions.append((row['SMILES'], repr(e)))
        continue
    data_list.append(data)

print(f"Successfully converted {len(data_list)} molecules.")
if failed_conversions:
    print(f"Skipped {len(failed_conversions)} molecules that could not be converted. First failures:")
    for failed_smiles, failure_reason in failed_conversions[:5]:
        print(f"  {failed_smiles}: {failure_reason}")
if not data_list:
    # Fail here with a clear message rather than later at `data_list[0]`.
    raise RuntimeError("No molecule could be converted to a graph; cannot build the dataset.")

# Create Train/Validation/Test Splits and DataLoaders
# NOTE: the analysis cell further down replays this exact seed + sort to
# recover the same test split, so the shuffle mechanism must not be changed
# here without changing it there as well.
torch.manual_seed(42)
data_list = sorted(data_list, key=lambda x: torch.rand(1)) # Shuffle
train_size = int(0.8 * len(data_list))
val_size = int(0.1 * len(data_list))
test_size = len(data_list) - train_size - val_size  # remainder goes to the test split

train_data = data_list[:train_size]
val_data = data_list[train_size:train_size + val_size]
test_data = data_list[train_size + val_size:]

print(f"Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}")

# Use a standard batch size for this scalar regression task
# Only the training loader reshuffles; validation/test keep a fixed order.
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64)
test_loader = DataLoader(test_data, batch_size=64)

class BaselineGCN(torch.nn.Module):
    """Graph-level regressor: three GCNConv layers, mean pooling, two-layer head.

    Args:
        feature_size: width of the input node-feature vectors.
        model_dim: width of every graph-convolution layer.
    """

    def __init__(self, feature_size, model_dim):
        super().__init__()
        # Re-seed the global torch RNG before the layers are created, so that
        # parameter initialisation is the same on every instantiation.
        torch.manual_seed(42)

        # Message-passing stack (GCNConv is the classic, simpler graph convolution).
        self.conv1 = GCNConv(feature_size, model_dim)
        self.conv2 = GCNConv(model_dim, model_dim)
        self.conv3 = GCNConv(model_dim, model_dim)

        # Readout head mapping the pooled graph vector to a single value.
        self.linear1 = Linear(model_dim, 128)
        self.linear2 = Linear(128, 1)

    def forward(self, data):
        """Return one prediction per graph in the batch `data`."""
        node_state = data.x
        edges = data.edge_index
        graph_ids = data.batch

        # Node features are cast to float before the first convolution.
        node_state = node_state.float()
        for conv in (self.conv1, self.conv2, self.conv3):
            node_state = F.elu(conv(node_state, edges))

        # Average node embeddings per graph, then apply the readout head.
        graph_state = global_mean_pool(node_state, graph_ids)
        hidden = F.elu(self.linear1(graph_state))
        return self.linear2(hidden)


# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"--- Using device: {device} ---")

# Width of the graph-convolution layers; the input width is read from the
# first converted graph.
model_dim = 64
feature_size = data_list[0].num_features

# NOTE: `optimizer` and `loss_fn` are read as globals by train_gcn / eval_gcn,
# so rebinding either name changes what those functions use.
gcn_model = BaselineGCN(feature_size, model_dim).to(device)
optimizer = torch.optim.Adam(gcn_model.parameters(), lr=0.001)
# We use L1Loss (Mean Absolute Error) as it's robust to outliers
loss_fn = torch.nn.L1Loss()

def train_gcn(model, loader):
    """Train `model` for one pass over `loader`.

    Uses the notebook-level `optimizer`, `loss_fn` and `device`.

    Returns:
        The sum of (batch loss * graphs in batch) divided by the dataset size.
    """
    model.train()
    running_loss = 0
    for batch in tqdm(loader, desc="Training", leave=False):
        batch = batch.to(device)
        optimizer.zero_grad()
        prediction = model(batch)
        batch_loss = loss_fn(prediction, batch.y)
        batch_loss.backward()
        optimizer.step()
        # Scale by the graph count so a smaller last batch is not over-weighted.
        running_loss += batch_loss.item() * batch.num_graphs
    dataset_size = len(loader.dataset)
    return running_loss / dataset_size

def eval_gcn(model, loader):
    """Evaluate `model` on `loader` without tracking gradients.

    Uses the notebook-level `loss_fn` and `device`.

    Returns:
        The sum of (batch loss * graphs in batch) divided by the dataset size.
    """
    model.eval()
    running_error = 0
    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating", leave=False):
            batch = batch.to(device)
            batch_error = loss_fn(model(batch), batch.y)
            running_error += batch_error.item() * batch.num_graphs
    dataset_size = len(loader.dataset)
    return running_error / dataset_size

import matplotlib.pyplot as plt

# Per-epoch loss values, collected for the learning-curve plot below.
train_loss_history = []
val_loss_history = []


print("--- Starting GCN Model Training (Tracking Loss) ---")
num_epochs = 25
for epoch in range(1, num_epochs + 1):
    # One optimisation pass over the training set, then a validation pass.
    train_loss = train_gcn(gcn_model, train_loader)
    val_mae = eval_gcn(gcn_model, val_loader)

    train_loss_history.append(train_loss)
    val_loss_history.append(val_mae)

    print(f'Epoch: {epoch:02d}, Train Loss (MAE): {train_loss:.4f}, Val MAE: {val_mae:.4f}')

print("\n--- GCN Model Training Complete ---")

# Learning curves: training vs. validation loss per epoch.
plt.figure(figsize=(10, 6))
plt.plot(range(1, num_epochs + 1), train_loss_history, label='Training Loss (MAE)')
plt.plot(range(1, num_epochs + 1), val_loss_history, label='Validation Loss (MAE)')
plt.title('GCN: Training vs. Validation Loss', fontsize=16)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Loss (Mean Absolute Error)', fontsize=12)
plt.legend()
plt.grid(True)
plt.show()
# The test split is evaluated once, after training has finished.
# `test_mae_gcn` is printed again by the GATv2 training cell for comparison.
test_mae_gcn = eval_gcn(gcn_model, test_loader)
print(f'Final Test MAE for GCN Model: {test_mae_gcn:.4f} au')

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import r2_score

def get_predictions(model, loader):
    """Collect model outputs and targets for every graph served by `loader`.

    Uses the notebook-level `device`.

    Returns:
        (predictions, labels) as two flattened NumPy arrays, in loader order.
    """
    model.eval()
    pred_chunks, label_chunks = [], []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Getting Predictions", leave=False):
            batch = batch.to(device)
            batch_out = model(batch)

            # Move to host memory before converting to NumPy.
            pred_chunks.append(batch_out.cpu().numpy())
            label_chunks.append(batch.y.cpu().numpy())

    flat_preds = np.concatenate(pred_chunks).flatten()
    flat_labels = np.concatenate(label_chunks).flatten()
    return flat_preds, flat_labels

# Collect test-set predictions from the trained GCN, report R^2, and draw a
# parity plot followed by a residual plot.
print("Gathering predictions from the GCN model...")
y_pred_gcn, y_true = get_predictions(gcn_model, test_loader)

print("Predictions and labels gathered. Ready for plotting.")

r2_gcn = r2_score(y_true, y_pred_gcn)
print(f'GCN Model R-squared (R²): {r2_gcn:.4f}')

plt.figure(figsize=(10, 10))

# Parity plot: predicted vs. actual, with the y = x line for reference.
sns.scatterplot(x=y_true, y=y_pred_gcn, alpha=0.3, s=20, label='Predictions')
# Shared limits covering both arrays, so the parity line spans all points.
min_val = min(y_true.min(), y_pred_gcn.min())
max_val = max(y_true.max(), y_pred_gcn.max())
plt.plot([min_val, max_val], [min_val, max_val], color='red', linestyle='--', lw=2, label='Parity Line')

# R^2 annotation placed near the top-left corner of the plotted range.
plt.text(min_val + (max_val - min_val) * 0.05,  # 5% from the left
         max_val - (max_val - min_val) * 0.05,  # 5% from the top
         f'$R^2 = {r2_gcn:.4f}$',
         fontsize=14,
         va='top', # Vertical alignment
         bbox=dict(boxstyle='round', facecolor='white', alpha=0.5)) # Add a semi-transparent box

plt.title('GCN: Predicted vs. Actual (with $R^2$ Score)', fontsize=16)
plt.xlabel('Actual HOMO-LUMO Gap (au)', fontsize=12)
plt.ylabel('Predicted HOMO-LUMO Gap (au)', fontsize=12)
# Equal aspect ratio so the parity line is drawn at 45 degrees.
plt.gca().set_aspect('equal', adjustable='box')
plt.legend()
plt.show()

# Residuals are defined as actual minus predicted.
residuals_gcn = y_true - y_pred_gcn

plt.figure(figsize=(12, 7))


sns.scatterplot(x=y_pred_gcn, y=residuals_gcn, alpha=0.3, s=20)

# Zero-residual reference line.
plt.axhline(y=0, color='red', linestyle='--', lw=2)

plt.title('GCN: Residual Plot', fontsize=16)
plt.xlabel('Predicted HOMO-LUMO Gap (au)', fontsize=12)
plt.ylabel('Residual (Actual - Predicted)', fontsize=12)
plt.show()

In [None]:
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GATv2Conv, global_mean_pool
from tqdm.notebook import tqdm
import numpy as np


class BaselineGAT(torch.nn.Module):
    """Graph-level regressor: three 4-head GATv2Conv layers, mean pooling, two-layer head.

    Args:
        feature_size: width of the input node-feature vectors.
        model_dim: per-head output width of every attention layer.
    """

    def __init__(self, feature_size, model_dim):
        super().__init__()
        # Re-seed the global torch RNG before the layers are created, so that
        # parameter initialisation is the same on every instantiation.
        torch.manual_seed(42)

        # GATv2 with 4 attention heads; the layers after the first take
        # `model_dim * 4` input features.
        self.conv1 = GATv2Conv(feature_size, model_dim, heads=4)
        self.conv2 = GATv2Conv(model_dim * 4, model_dim, heads=4)
        self.conv3 = GATv2Conv(model_dim * 4, model_dim, heads=4)

        # Readout head mapping the pooled graph vector to a single value.
        self.linear1 = Linear(model_dim * 4, 128)
        self.linear2 = Linear(128, 1)

    def forward(self, data):
        """Return one prediction per graph in the batch `data`."""
        node_state = data.x
        edges = data.edge_index
        graph_ids = data.batch

        # Node features are cast to float before the first attention layer.
        node_state = node_state.float()
        for attention_layer in (self.conv1, self.conv2, self.conv3):
            node_state = F.elu(attention_layer(node_state, edges))

        # Average node embeddings per graph, then apply the readout head.
        graph_state = global_mean_pool(node_state, graph_ids)
        hidden = F.elu(self.linear1(graph_state))
        return self.linear2(hidden)


# Same layer width and input width as the GCN baseline cell.
# NOTE: this rebinds the notebook-level `model_dim` and `feature_size`.
model_dim = 64
feature_size = data_list[0].num_features

# The GAT gets its own optimizer and loss objects, which are passed explicitly
# to train_gat / eval_gat below. `device` comes from the GCN cell.
gat_model = BaselineGAT(feature_size, model_dim).to(device)
optimizer_gat = torch.optim.Adam(gat_model.parameters(), lr=0.001)
loss_fn_gat = torch.nn.L1Loss()

def train_gat(model, loader, optimizer, loss_fn):
    """Train `model` for one pass over `loader` with the given optimizer and loss.

    Uses the notebook-level `device`.

    Returns:
        The sum of (batch loss * graphs in batch) divided by the dataset size.
    """
    model.train()
    weighted_loss_sum = 0
    for graph_batch in tqdm(loader, desc="Training", leave=False):
        graph_batch = graph_batch.to(device)
        optimizer.zero_grad()
        prediction = model(graph_batch)
        step_loss = loss_fn(prediction, graph_batch.y)
        step_loss.backward()
        optimizer.step()
        # Scale by the graph count so a smaller last batch is not over-weighted.
        weighted_loss_sum += step_loss.item() * graph_batch.num_graphs
    n_graphs = len(loader.dataset)
    return weighted_loss_sum / n_graphs

def eval_gat(model, loader, loss_fn):
    """Evaluate `model` on `loader` with `loss_fn`, without tracking gradients.

    Uses the notebook-level `device`.

    Returns:
        The sum of (batch loss * graphs in batch) divided by the dataset size.
    """
    model.eval()
    weighted_error_sum = 0
    with torch.no_grad():
        for graph_batch in tqdm(loader, desc="Evaluating", leave=False):
            graph_batch = graph_batch.to(device)
            step_error = loss_fn(model(graph_batch), graph_batch.y)
            weighted_error_sum += step_error.item() * graph_batch.num_graphs
    n_graphs = len(loader.dataset)
    return weighted_error_sum / n_graphs


# Train the GATv2 baseline on the same loaders as the GCN baseline and for the
# same number of epochs.
print("--- Starting GATv2 Model Training (Advanced Baseline) ---")
num_epochs = 25

# Lists to store loss history for plotting
gat_train_loss_history = []
gat_val_loss_history = []

for epoch in range(1, num_epochs + 1):
    # One optimisation pass over the training set, then a validation pass.
    train_loss = train_gat(gat_model, train_loader, optimizer_gat, loss_fn_gat)
    val_mae = eval_gat(gat_model, val_loader, loss_fn_gat)

    # Save history
    gat_train_loss_history.append(train_loss)
    gat_val_loss_history.append(val_mae)

    print(f'Epoch: {epoch:02d}, Train Loss (MAE): {train_loss:.4f}, Val MAE: {val_mae:.4f}')

print("\n--- GATv2 Model Training Complete ---")
# The test split is evaluated once, after training has finished.
test_mae_gat = eval_gat(gat_model, test_loader, loss_fn_gat)
print(f'Final Test MAE for GATv2 Model: {test_mae_gat:.4f} au')
# Side-by-side comparison with the GCN result. `test_mae_gcn` only exists if
# the GCN training cell was run in this kernel session, hence the NameError guard.
try:
    print(f'Final Test MAE for GCN Model:  {test_mae_gcn:.4f} au')
except NameError:
    print("Run GCN model cell to see GCN MAE.")

In [None]:
import matplotlib.pyplot as plt

# Learning curves for the GATv2 model.
# The x-axes are derived from the recorded histories rather than from the
# notebook-level `num_epochs`: if training was interrupted early, or
# `num_epochs` was rebound by another cell, the lengths would no longer match
# and plt.plot would raise a shape-mismatch error.
gat_train_epochs = range(1, len(gat_train_loss_history) + 1)
gat_val_epochs = range(1, len(gat_val_loss_history) + 1)

plt.figure(figsize=(10, 6))
plt.plot(gat_train_epochs, gat_train_loss_history, label='Training Loss (MAE)')
plt.plot(gat_val_epochs, gat_val_loss_history, label='Validation Loss (MAE)')
plt.title('GATv2: Training vs. Validation Loss', fontsize=16)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Loss (Mean Absolute Error)', fontsize=12)
plt.legend()
plt.grid(True)
plt.show()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import r2_score

def get_predictions(model, loader):
    """Run `model` over `loader` and return (predictions, labels) as flat NumPy arrays.

    NOTE: this re-defines (and therefore shadows) the function of the same
    name declared in the GCN evaluation cell; the two behave identically.
    Uses the notebook-level `device`.
    """
    model.eval()
    collected_preds = []
    collected_labels = []
    with torch.no_grad():
        progress = tqdm(loader, desc="Getting Predictions", leave=False)
        for graph_batch in progress:
            graph_batch = graph_batch.to(device)
            output = model(graph_batch)
            # Move to host memory before converting to NumPy.
            collected_preds.append(output.cpu().numpy())
            collected_labels.append(graph_batch.y.cpu().numpy())

    return (
        np.concatenate(collected_preds).flatten(),
        np.concatenate(collected_labels).flatten(),
    )


# Collect test-set predictions from the trained GATv2 model, report R^2, and
# draw a parity plot followed by a residual plot.
print("Gathering predictions from the GATv2 model...")
y_pred_gat, y_true_gat = get_predictions(gat_model, test_loader)


r2_gat = r2_score(y_true_gat, y_pred_gat)
print(f'GATv2 Model R-squared (R²): {r2_gat:.4f}')


# Parity plot: predicted vs. actual, with the y = x line for reference.
plt.figure(figsize=(10, 10))
sns.scatterplot(x=y_true_gat, y=y_pred_gat, alpha=0.3, s=20, label='GAT Predictions')

# Shared limits covering both arrays, so the parity line spans all points.
# NOTE: rebinds the notebook-level `min_val` / `max_val` set by the GCN plot cell.
min_val = min(y_true_gat.min(), y_pred_gat.min())
max_val = max(y_true_gat.max(), y_pred_gat.max())
plt.plot([min_val, max_val], [min_val, max_val], color='red', linestyle='--', lw=2, label='Parity Line')

# R^2 annotation placed 5% in from the left and 5% down from the top of the range.
plt.text(min_val + (max_val - min_val) * 0.05,
         max_val - (max_val - min_val) * 0.05,
         f'$R^2 = {r2_gat:.4f}$',
         fontsize=14, va='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.5))

plt.title('GATv2: Predicted vs. Actual (with $R^2$ Score)', fontsize=16)
plt.xlabel('Actual HOMO-LUMO Gap (au)', fontsize=12)
plt.ylabel('Predicted HOMO-LUMO Gap (au)', fontsize=12)
# Equal aspect ratio so the parity line is drawn at 45 degrees.
plt.gca().set_aspect('equal', adjustable='box')
plt.legend()
plt.show()


# Residuals are defined as actual minus predicted.
residuals_gat = y_true_gat - y_pred_gat
plt.figure(figsize=(12, 7))
sns.scatterplot(x=y_pred_gat, y=residuals_gat, alpha=0.3, s=20)
# Zero-residual reference line.
plt.axhline(y=0, color='red', linestyle='--', lw=2)
plt.title('GATv2: Residual Plot', fontsize=16)
plt.xlabel('Predicted HOMO-LUMO Gap (au)', fontsize=12)
plt.ylabel('Residual (Actual - Predicted)', fontsize=12)
plt.show()

In [None]:
# Re-create data_list, this time adding the .smiles attribute
print("Re-processing data to include SMILES strings for analysis...")

data_list_with_smiles = []
failed_analysis_conversions = []  # (SMILES, repr(error)) for every molecule that was skipped
for index, row in tqdm(df_clean.iterrows(), total=df_clean.shape[0]):
    try:
        data = from_smiles(row['SMILES'])
        data.y = torch.tensor([[row[TARGET_PROPERTY]]], dtype=torch.float)
        data.smiles = row['SMILES']
    except Exception as e:
        # Still best-effort, but the failure is recorded instead of being
        # silently discarded.
        failed_analysis_conversions.append((row['SMILES'], repr(e)))
        continue
    data_list_with_smiles.append(data)

if failed_analysis_conversions:
    print(f"Skipped {len(failed_analysis_conversions)} molecules that could not be converted. First failures:")
    for failed_smiles, failure_reason in failed_analysis_conversions[:5]:
        print(f"  {failed_smiles}: {failure_reason}")

# We use the same seed to ensure the shuffle is identical to before
# (this only holds if exactly the same molecules were converted, in the same order).
torch.manual_seed(42)
data_list_shuffled = sorted(data_list_with_smiles, key=lambda x: torch.rand(1))

train_size = int(0.8 * len(data_list_shuffled))
val_size = int(0.1 * len(data_list_shuffled))
test_size = len(data_list_shuffled) - train_size - val_size

analysis_test_data = data_list_shuffled[train_size + val_size:]

analysis_test_loader = DataLoader(analysis_test_data, batch_size=64, shuffle=False)

print(f"Created new test loader with {len(analysis_test_data)} molecules.")

# Sanity check: the replayed shuffle must reproduce the original held-out
# split, otherwise the failure analysis would silently include molecules the
# models were trained on. `test_data` only exists if the dataset cell was run.
try:
    split_matches = len(analysis_test_data) == len(test_data) and all(
        torch.equal(rebuilt.y, original.y)
        for rebuilt, original in zip(analysis_test_data, test_data)
    )
except NameError:
    split_matches = None  # original split not available; nothing to compare against
if split_matches is False:
    print("WARNING: rebuilt test split does not match the original test split; "
          "analysis results may include training molecules.")

In [None]:
import pandas as pd
import py3Dmol
from rdkit import Chem
from rdkit.Chem import AllChem


def get_predictions_with_details(model, loader):
    """Collect predictions, targets and SMILES strings for every graph in `loader`.

    Each batch is expected to expose a `.smiles` attribute that can be passed
    to `list.extend`. Uses the notebook-level `device`.

    Returns:
        (predictions, labels, smiles): two flattened NumPy arrays and a list,
        all in loader order.
    """
    model.eval()
    pred_chunks, label_chunks, smiles_seen = [], [], []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Getting Predictions", leave=False):
            batch = batch.to(device)
            batch_out = model(batch)

            pred_chunks.append(batch_out.cpu().numpy())
            label_chunks.append(batch.y.cpu().numpy())
            # The batch object carries the per-graph .smiles values along.
            smiles_seen.extend(batch.smiles)

    flat_preds = np.concatenate(pred_chunks).flatten()
    flat_labels = np.concatenate(label_chunks).flatten()
    return flat_preds, flat_labels, smiles_seen


# Failure analysis: rank the GATv2 test predictions by absolute error and
# visualise the ten worst molecules in 3D.
# NOTE: this rebinds the notebook-level `y_true` that the GCN plotting cell created.
y_pred, y_true, smiles = get_predictions_with_details(gat_model, analysis_test_loader)


error_df = pd.DataFrame({
    'SMILES': smiles,
    'Actual_Gap': y_true,
    'Predicted_Gap': y_pred
})
error_df['Absolute_Error'] = (error_df['Actual_Gap'] - error_df['Predicted_Gap']).abs()


# Largest errors first.
error_df_sorted = error_df.sort_values(by='Absolute_Error', ascending=False)

print("--- GATv2 Model Failure Analysis ---")
display(error_df_sorted.head(10))

print("--- Visualizing Top 10 Worst Predictions ---")
# sort_values keeps the original row labels, so the DataFrame index is the
# molecule's position in the test loader, not its rank. The rank is therefore
# counted explicitly with enumerate.
for rank, (_, row) in enumerate(error_df_sorted.head(10).iterrows(), start=1):
    print(f"--- Molecule Rank #{rank} ---")
    print(f"SMILES: {row['SMILES']}")
    print(f"Actual Gap: {row['Actual_Gap']:.4f} au")
    print(f"Predicted Gap: {row['Predicted_Gap']:.4f} au")
    print(f"Absolute Error: {row['Absolute_Error']:.4f} au")

    # Generate 3D model
    mol = Chem.MolFromSmiles(row['SMILES'])
    mol = Chem.AddHs(mol)

    # EmbedMolecule signals failure through its return code rather than raising.
    embed_code = AllChem.EmbedMolecule(mol, AllChem.ETKDG())

    if embed_code == 0:
        try:
            AllChem.MMFFOptimizeMolecule(mol)

            mblock = Chem.MolToMolBlock(mol)

            view = py3Dmol.view(width=500, height=400)
            view.addModel(mblock, 'mol')
            view.setStyle({'stick':{}, 'sphere': {'scale':0.3}})
            view.zoomTo()
            view.show()

        except ValueError as e:
            print(f"  -> WARNING: MMFFOptimize failed: {e}")
            print("  -> Skipping 3D visualization for this molecule.")
    else:
        # If embed_code was -1 (failure)
        print(f"  -> WARNING: RDKit's EmbedMolecule failed to generate 3D conformer.")
        print("  -> Skipping 3D visualization for this molecule.")

    print("-" * 20) # Separator