## Know your libraries (KYL)

In [None]:
from dataset import QM9Dataset,LogSDataset,LogPDataset
import numpy as np
import pandas as pd
from utils import *
import torch
from tqdm import tqdm
from model import Autoencoder # Simply importing the autoencoder model module from the model.py file
from model import GNN3D

import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.Draw import MolsToGridImage

import os

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

## Split your dataset (SYD)

In [None]:
# Read the full QM9 dipole table and split it into train/test frames
# (0.8 is presumably the train fraction — see split_dataset in utils).
dataset = pd.read_csv("../data/qm9_dipole/qm9.csv")
train_dataset, test_dataset = split_dataset(dataset, 0.8)

# Persist each split as a CSV next to the source table.
for split_frame, split_path in (
    (train_dataset, "../data/qm9_dipole/train.csv"),
    (test_dataset, "../data/qm9_dipole/test.csv"),
):
    split_frame.to_csv(split_path, index=False)

print("Train and test datasets saved successfully.")

## Process your data (PYD)

In [None]:
# Build the processed train/test datasets and print a summary of each,
# separated by a divider line.
train_samples = QM9Dataset("../data/qm9_dipole/train")
print(train_samples, "===================================", sep="\n")
test_samples = QM9Dataset("../data/qm9_dipole/test")
print(test_samples)

## Know your features (KYF)

In [None]:
# Print the shape of each of the nine tensors that make up one sample,
# together with a short description (the lower-cased label).
feature_labels = (
    "Atomic Features",
    "Bond Features",
    "Angle Features",
    "Dihedral Features",
    "Global Molecular Features",
    "Bond Indices",
    "Angle Indices",
    "Dihedral Indices",
    "Target",
)
for position, label in enumerate(feature_labels):
    # train_samples[0] is re-fetched per line, matching one print per feature.
    print(f"{label}: {train_samples[0][position].shape} - This represents the {label.lower()} of the molecule")

## Know your modules (KYM)

In [None]:
def train_autoencoder_model(samples, atom_autoencoder, bond_autoencoder, n_epochs=10, printstep=10, save_dir="./models/", dataset_name="logs"):
    """Train the atom and bond autoencoders by per-molecule reconstruction.

    For every molecule, the atom features (``molecule[0]``) and bond features
    (``molecule[1]``) are reconstructed by their autoencoder, and one
    optimiser step is taken per molecule. Relies on the notebook globals
    ``device``, ``mse_loss_fn``, ``atom_optimizer`` and ``bond_optimizer``.

    Args:
        samples: Sized iterable of molecules (indexable tensor sequences).
        atom_autoencoder: Module reconstructing atom features; set to train mode.
        bond_autoencoder: Module reconstructing bond features; set to train mode.
        n_epochs: Number of passes over ``samples``.
        printstep: Print running averages every ``printstep`` molecules.
        save_dir: Directory (created if missing) receiving the loss log and
            both state dicts. A trailing separator is optional.
        dataset_name: Suffix used in the output file names.

    Returns:
        Tuple ``(avg_atom_rmse_losses, avg_bond_rmse_losses)``: one entry per
        epoch, each the mean of the per-molecule RMSE over that epoch.
    """
    atom_autoencoder.train()
    bond_autoencoder.train()

    os.makedirs(save_dir, exist_ok=True)

    avg_atom_rmse_losses = []
    avg_bond_rmse_losses = []

    for epoch_i in range(n_epochs):
        avg_atom_rmse_loss = 0
        avg_bond_rmse_loss = 0
        total_samples = 0

        for i, molecule in enumerate(samples):
            atom_features = molecule[0].to(device)
            bond_features = molecule[1].to(device)

            atom_features_reconstructed = atom_autoencoder(atom_features)
            bond_features_reconstructed = bond_autoencoder(bond_features)

            atom_loss = mse_loss_fn(atom_features_reconstructed, atom_features)
            bond_loss = mse_loss_fn(bond_features_reconstructed, bond_features)

            # The two autoencoders are independent, so each has its own step.
            atom_optimizer.zero_grad()
            bond_optimizer.zero_grad()
            atom_loss.backward()
            bond_loss.backward()
            atom_optimizer.step()
            bond_optimizer.step()

            # Incremental mean of the per-molecule RMSE within this epoch.
            avg_atom_rmse_loss = (avg_atom_rmse_loss * total_samples + torch.sqrt(atom_loss).item()) / (total_samples + 1)
            avg_bond_rmse_loss = (avg_bond_rmse_loss * total_samples + torch.sqrt(bond_loss).item()) / (total_samples + 1)
            total_samples += 1

            if i % printstep == 0:
                print(f"Epoch: {epoch_i + 1:>3}/{n_epochs} | Samples: {i + 1:>6}/{len(samples)} | Atom Avg. RMSE Loss: {avg_atom_rmse_loss:.4f} | Bond Avg. RMSE Loss: {avg_bond_rmse_loss:.4f}")

        avg_atom_rmse_losses.append(avg_atom_rmse_loss)
        avg_bond_rmse_losses.append(avg_bond_rmse_loss)

    losses_file = os.path.join(save_dir, f"autoencoder_losses_{dataset_name}.txt")
    with open(losses_file, 'w') as f:
        f.write("Epoch\tAvg Atom RMSE Loss\tAvg Bond RMSE Loss\n")
        for epoch_no, (atom_rmse, bond_rmse) in enumerate(zip(avg_atom_rmse_losses, avg_bond_rmse_losses), start=1):
            f.write(f"{epoch_no}\t{atom_rmse}\t{bond_rmse}\n")

    # os.path.join (like the loss file above) so a save_dir without a trailing
    # separator still places the weights inside that directory.
    torch.save(atom_autoencoder.state_dict(), os.path.join(save_dir, f"atom_autoencoder_{dataset_name}.pth"))
    torch.save(bond_autoencoder.state_dict(), os.path.join(save_dir, f"bond_autoencoder_{dataset_name}.pth"))

    return avg_atom_rmse_losses, avg_bond_rmse_losses
    



def process_molecule(molecule, atom_autoencoder, bond_autoencoder):
    """Split a sample into GNN inputs and target, compressing atom/bond features.

    Args:
        molecule: Sequence of tensors. Indices 0-7 are the graph inputs
            (atomic, bond, angle, dihedral and global molecular features,
            then bond, angle and dihedral indices); index 8 is the target.
        atom_autoencoder: Module whose ``encode`` replaces the atomic features.
        bond_autoencoder: Module whose ``encode`` replaces the bond features.

    Returns:
        ``(molecule_data, target)``, where ``molecule_data`` is a list of the
        eight input tensors on the global ``device`` (entries 0 and 1 encoded).
        Returns ``(None, None)`` when the sample has fewer than nine entries.
    """
    # Nine entries are required: eight inputs plus the target at index 8.
    if len(molecule) < 9:
        return None, None

    target = molecule[8].to(device)
    molecule_data = [mol_elem.to(device) for mol_elem in molecule[:8]]

    # The encoders are fixed feature compressors here (no caller optimises
    # them through this path). no_grad stops autograd graphs from being built,
    # and .grad from accumulating, in the encoders on every sample.
    with torch.no_grad():
        molecule_data[0] = atom_autoencoder.encode(molecule_data[0])
        molecule_data[1] = bond_autoencoder.encode(molecule_data[1])

    return molecule_data, target




def train_gnn3d_model(gnn3d, samples, atom_autoencoder, bond_autoencoder, n_epochs=10, printstep=10, save_dir="./models/", dataset_name="logs"):
    """Train GNN3D on autoencoder-compressed molecule graphs.

    Each sample goes through ``process_molecule``, and samples it rejects are
    skipped. One optimiser step is taken per molecule. Relies on the notebook
    globals ``gnn_optimizer``, ``mse_loss_fn`` and ``process_molecule``.

    Args:
        gnn3d: Model to train; set to train mode.
        samples: Sized iterable of molecules.
        atom_autoencoder: Encoder for atomic features, passed to ``process_molecule``.
        bond_autoencoder: Encoder for bond features, passed to ``process_molecule``.
        n_epochs: Number of passes over ``samples``.
        printstep: Log running averages every ``printstep`` molecules.
        save_dir: Directory (created if missing) receiving the loss log and
            the state dict. A trailing separator is optional.
        dataset_name: Suffix used in the output file names.

    Returns:
        Tuple ``(avg_rmse_losses, avg_mse_losses)``: one entry per epoch, the
        mean per-molecule RMSE and MSE over that epoch's processed molecules.
    """
    gnn3d.train()

    os.makedirs(save_dir, exist_ok=True)

    avg_rmse_losses = []
    avg_mse_losses = []

    for epoch_i in range(n_epochs):
        avg_rmse_loss = 0
        avg_mse_loss = 0
        total_samples = 0

        for i, molecule in enumerate(tqdm(samples, desc='Processing samples')):
            molecule_data, target = process_molecule(molecule, atom_autoencoder, bond_autoencoder)
            if molecule_data is None:
                continue

            prediction = gnn3d(molecule_data)
            loss = mse_loss_fn(prediction, target)

            gnn_optimizer.zero_grad()
            loss.backward()
            gnn_optimizer.step()

            # Incremental means over the molecules processed so far this epoch.
            avg_rmse_loss = (avg_rmse_loss * total_samples + torch.sqrt(loss).item()) / (total_samples + 1)
            avg_mse_loss = (avg_mse_loss * total_samples + loss.item()) / (total_samples + 1)
            total_samples += 1

            if i % printstep == 0:
                # tqdm.write keeps the log line from breaking the progress bar;
                # numbering is 1-based like train_autoencoder_model.
                tqdm.write(f"Epoch: {epoch_i + 1:>3}/{n_epochs} | Samples: {i + 1:>6}/{len(samples)} | Avg. RMSE Loss: {avg_rmse_loss:.4f} | Avg. MSE Loss: {avg_mse_loss:.4f}")

        avg_rmse_losses.append(avg_rmse_loss)
        avg_mse_losses.append(avg_mse_loss)

    losses_file = os.path.join(save_dir, f"gnn3d_losses_{dataset_name}.txt")
    with open(losses_file, 'w') as f:
        # The second loss column holds the MSE, not a second RMSE.
        f.write("Epoch\tAvg RMSE Loss\tAvg MSE Loss\n")
        for epoch_no, (rmse, mse) in enumerate(zip(avg_rmse_losses, avg_mse_losses), start=1):
            f.write(f"{epoch_no}\t{rmse}\t{mse}\n")

    torch.save(gnn3d.state_dict(), os.path.join(save_dir, f"gnn3d_{dataset_name}.pth"))

    return avg_rmse_losses, avg_mse_losses




def evaluate_gnn3d_model(gnn3d, samples, atom_autoencoder, bond_autoencoder):
    """Evaluate GNN3D on ``samples`` without updating any weights.

    Puts ``gnn3d`` and both autoencoders in eval mode and runs under
    ``torch.no_grad``. Samples rejected by ``process_molecule`` are skipped.
    Relies on the notebook globals ``mse_loss_fn`` and ``process_molecule``.

    Args:
        gnn3d: Trained model to evaluate.
        samples: Iterable of molecules.
        atom_autoencoder: Encoder for atomic features.
        bond_autoencoder: Encoder for bond features.

    Returns:
        Tuple ``(avg_rmse_loss, avg_mse_loss)``: the mean over molecules of the
        per-molecule RMSE and MSE. Both are NaN when no molecule was evaluated.
    """
    gnn3d.eval()
    # Encoders must also be in eval mode so their outputs are deterministic.
    atom_autoencoder.eval()
    bond_autoencoder.eval()

    rmse_sum = 0.0
    mse_sum = 0.0
    total_samples = 0

    with torch.no_grad():
        for molecule in samples:
            molecule_data, target = process_molecule(molecule, atom_autoencoder, bond_autoencoder)
            if molecule_data is None:
                continue

            prediction = gnn3d(molecule_data)
            loss = mse_loss_fn(prediction, target)

            rmse_sum += torch.sqrt(loss).item()
            mse_sum += loss.item()
            total_samples += 1

    if total_samples == 0:
        # Avoid ZeroDivisionError; NaN signals "nothing evaluated".
        print("Evaluation Results: no valid samples to evaluate.")
        return float("nan"), float("nan")

    avg_rmse_loss = rmse_sum / total_samples
    avg_mse_loss = mse_sum / total_samples

    print("Evaluation Results:")
    print(f"Avg. RMSE Loss: {avg_rmse_loss:.4f}")
    print(f"Avg. MSE Loss: {avg_mse_loss:.4f}")

    return avg_rmse_loss, avg_mse_loss

## Training

### Autoencoder training

In [None]:
# One autoencoder per feature type, presumably Autoencoder(input_dim, latent_dim)
# — the latent sizes 10 and 3 match GNN3D's atomic/bond vector sizes below.
atom_autoencoder = Autoencoder(154, 10).to(device=device)
bond_autoencoder = Autoencoder(10, 3).to(device=device)
mse_loss_fn = torch.nn.MSELoss()
# Separate Adam optimisers, read as globals by train_autoencoder_model.
atom_optimizer = torch.optim.Adam(params=atom_autoencoder.parameters())
bond_optimizer = torch.optim.Adam(params=bond_autoencoder.parameters())

atom_losses, bond_losses = train_autoencoder_model(
    train_samples,
    atom_autoencoder,
    bond_autoencoder,
    n_epochs=5,
    printstep=500,
    save_dir="./models/qm9_dipole/",
    dataset_name="qm9",
)

In [None]:
# Plot each autoencoder's per-epoch average RMSE in its own figure and save it
# under a consistent "<part>_autoencoder_losses_qm9_dipole.png" name.
for losses, color, part in ((atom_losses, 'b', 'Atom'), (bond_losses, 'r', 'Bond')):
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, len(losses) + 1), losses, marker='o', linestyle='-', color=color, label=f'Avg. {part} RMSE Loss')
    plt.title(f"{part} Autoencoder Training Losses for qm9_dipole")
    plt.xlabel("Epoch")
    plt.ylabel("RMSE Loss")
    plt.legend()
    plt.grid(True)
    plt.savefig(f"./models/qm9_dipole/{part.lower()}_autoencoder_losses_qm9_dipole.png")
    plt.show()

### GNN3D training

In [None]:
# Use the autoencoders trained in the previous section, not fresh randomly
# initialised ones: load the weights that train_autoencoder_model saved.
# Loading from disk also lets this cell run after a kernel restart.
model_dir = "./models/qm9_dipole/"
atom_autoencoder = Autoencoder(154, 10).to(device)
bond_autoencoder = Autoencoder(10, 3).to(device)
atom_autoencoder.load_state_dict(torch.load(os.path.join(model_dir, "atom_autoencoder_qm9.pth"), map_location=device))
bond_autoencoder.load_state_dict(torch.load(os.path.join(model_dir, "bond_autoencoder_qm9.pth"), map_location=device))
# Only gnn3d is optimised below, so the encoders act as frozen feature extractors.
atom_autoencoder.eval()
bond_autoencoder.eval()

mse_loss_fn = torch.nn.MSELoss()
gnn3d = GNN3D(atomic_vector_size=10, bond_vector_size=3, number_of_molecular_features=200, number_of_targets=1).to(device)
gnn_optimizer = torch.optim.Adam(gnn3d.parameters())
rmse_losses, mse_losses = train_gnn3d_model(gnn3d, train_samples, atom_autoencoder, bond_autoencoder, n_epochs=5, printstep=500, save_dir=model_dir, dataset_name="qm9")

In [None]:
# Plot the per-epoch average RMSE returned by train_gnn3d_model, then save and show it.
epoch_numbers = range(1, len(rmse_losses) + 1)
plt.figure(figsize=(10, 6))
plt.plot(epoch_numbers, rmse_losses, marker='o', linestyle='-', color='b', label='Avg. RMSE Loss')
plt.title("GNN3D Training Losses for qm9_dipole")
plt.xlabel("Epoch")
plt.ylabel("RMSE Loss")
plt.legend()
plt.grid(True)
plt.savefig("./models/qm9_dipole/gnn3d_rmse_losses_qm9_dipole.png")
plt.show()

## Testing

In [None]:
# Score the trained GNN3D on the held-out split; the returned
# (avg RMSE, avg MSE) tuple becomes the cell's displayed output.
evaluate_gnn3d_model(gnn3d, samples=test_samples, atom_autoencoder=atom_autoencoder, bond_autoencoder=bond_autoencoder)

## Know your predictions (KYP)

In [None]:
def molecule_from_smiles(smiles):
    """Convert a SMILES string to an RDKit molecule, sanitizing as far as possible.

    The SMILES string is parsed unsanitized, then sanitized with
    ``catchErrors=True``. If a sanitization step fails, sanitization is rerun
    with that step excluded. Stereochemistry is then (re)assigned. No atom
    coordinates are computed here.

    Args:
        smiles: SMILES string.

    Returns:
        ``Chem.Mol``, or ``None`` if the SMILES string cannot be parsed.
    """
    molecule = Chem.MolFromSmiles(smiles, sanitize=False)
    if molecule is None:
        # Unparseable SMILES: return None instead of crashing in SanitizeMol.
        return None

    # With catchErrors=True, SanitizeMol returns the flag of the failed step.
    flag = Chem.SanitizeMol(molecule, catchErrors=True)
    if flag != Chem.SanitizeFlags.SANITIZE_NONE:
        Chem.SanitizeMol(molecule, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL ^ flag)

    Chem.AssignStereochemistry(molecule, cleanIt=True, force=True)
    return molecule

def smi2conf_etkdg(smiles, num_confs=500, random_seed=12412):
    """Convert a SMILES string to an RDKit molecule with 3D conformers.

    Uses the ETKDG version 3 distance-geometry method with the small-ring
    variant (``srETKDGv3``). ``params.numThreads = 0`` uses the maximum
    number of threads allowed on the computer.

    Args:
        smiles: SMILES string.
        num_confs: Number of conformers requested from ``EmbedMultipleConfs``.
        random_seed: Embedding seed, for reproducible conformers.

    Returns:
        ``Chem.Mol`` with explicit hydrogens and the embedded conformers
        (possibly fewer than ``num_confs`` if embedding fails for some),
        or ``None`` if the SMILES string cannot be parsed.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    params = Chem.rdDistGeom.srETKDGv3()
    params.randomSeed = random_seed
    params.clearConfs = True
    params.numThreads = 0

    mol = Chem.AddHs(mol)
    # A separate EmbedMolecule call beforehand would be wasted work:
    # clearConfs=True discards any existing conformers before embedding.
    AllChem.EmbedMultipleConfs(mol, numConfs=num_confs, params=params)
    return mol



# Predict the target of every test molecule with the trained GNN3D.
gnn3d.eval()

predictions = []         # predicted values, in test_samples order
prediction_indices = []  # position in test_samples of each entry of predictions

# no_grad covers the encoders called inside process_molecule as well as gnn3d.
with torch.no_grad():
    for sample_idx, molecule in enumerate(test_samples):
        # process_molecule already moves every tensor to `device`.
        molecule_data, _ = process_molecule(molecule, atom_autoencoder, bond_autoencoder)
        if molecule_data is None:
            continue  # indices are tracked so skips do not shift later rows

        prediction = gnn3d(molecule_data)
        # Assumes a single scalar output per molecule (number_of_targets=1).
        predictions.append(prediction.item())
        prediction_indices.append(sample_idx)

# The split written in the "Split your dataset" cell lives under qm9_dipole/.
dataset = pd.read_csv("../data/qm9_dipole/test.csv")

# Drawing the whole test split in one grid is impractically large; show a few.
# NOTE(review): assumes test_samples[i] corresponds to row i of test.csv — confirm.
n_shown = 20
shown = list(zip(prediction_indices, predictions))[:n_shown]
molecules = [molecule_from_smiles(dataset.smiles.values[i]) for i, _ in shown]
y_true = [dataset.dipole.values[i] for i, _ in shown]
y_pred = [pred for _, pred in shown]

legends = [f"y_true/y_pred = {t:.2f}/{p:.2f}" for t, p in zip(y_true, y_pred)]
MolsToGridImage(molecules, molsPerRow=5, maxMols=len(molecules), legends=legends)

## Future task: What is so unique about dihedral angles?