In [22]:
import pandas as pd
from rdkit import Chem
from rich import print

# Number of leading records inspected in both the CSV and the SDF.
n_preview = 100

# Load the first `n_preview` rows of the energy CSV.
energy_df = pd.read_csv('gdb9/gdb9.sdf.csv', nrows=n_preview)

# Build the mol_id -> u0 lookup once instead of filtering the DataFrame for
# every molecule. drop_duplicates keeps the first row per mol_id, which is the
# row the previous `.iloc[0]` lookup would have picked.
u0_by_id = energy_df.drop_duplicates('mol_id').set_index('mol_id')['u0'].to_dict()

# Iterate over the SDF records; removeHs=False is passed so hydrogens are kept.
supplier = Chem.SDMolSupplier('gdb9/gdb9.sdf', removeHs=False)

# Molecules that have a matching energy row, each tagged with a 'U0' property.
matched_mols = []

for i, mol in enumerate(supplier):
    # Check the bound first: with the None check in front, an unparsable
    # record past the limit would `continue` and keep the loop reading.
    if i >= n_preview:
        break
    if mol is None:
        continue
    if not mol.HasProp('_Name'):
        continue

    mol_id = mol.GetProp('_Name')
    if mol_id in u0_by_id:
        # Molecule properties are stored as strings.
        mol.SetProp('U0', str(u0_by_id[mol_id]))
        matched_mols.append(mol)

# Print results to check.
# NOTE: `print` is rich.print in this cell (imported above), not the builtin.
for mol in matched_mols:
    print(mol.GetProp('_Name'), mol.GetProp('U0'))


gdb_1 -40.47893
gdb_2 -56.525887
gdb_3 -76.404702
gdb_4 -77.308427
gdb_5 -93.411888
gdb_6 -114.483613
gdb_7 -79.764152
gdb_8 -115.679136
gdb_9 -116.609549
gdb_10 -132.71815
gdb_11 -153.787612
gdb_12 -169.860788
gdb_13 -119.052475
gdb_14 -154.972731
gdb_15 -154.960361
gdb_16 -117.824798
gdb_17 -153.742562
gdb_18 -193.08834
gdb_19 -209.159302
gdb_20 -225.221461
gdb_21 -158.342346
gdb_22 -194.267232
gdb_23 -153.459846
gdb_24 -169.557758
gdb_25 -185.648533
gdb_26 -190.624631
gdb_27 -206.721858
gdb_28 -227.798785
gdb_29 -155.908941
gdb_30 -155.897345
gdb_31 -172.006141
gdb_32 -188.042067
gdb_33 -191.810916
gdb_34 -207.916786
gdb_35 -193.075202
gdb_36 -209.144909
gdb_37 -229.013797
gdb_38 -228.992613
gdb_39 -158.340943
gdb_40 -194.261089
gdb_41 -194.254127
gdb_42 -230.183076
gdb_43 -157.116735
gdb_44 -193.039603
gdb_45 -173.147782
gdb_46 -193.034988
gdb_47 -157.115484
gdb_48 -193.034094
gdb_49 -248.375248
gdb_50 -210.101789
gdb_51 -226.160842
gdb_52 -229.969129
gdb_53 -246.02915
gdb_54 -197.

In [37]:
import pandas as pd
import torch
from rdkit import Chem
from torch_geometric.data import Data, Batch
import copy

import sys
import os

# Maximum number of SDF molecules / CSV rows processed by this cell.
N = 10000

# Add CGCF folder to sys.path, so imports like 'models.cnf_edge...' work
cgcf_path = os.path.abspath('CGCF')
if cgcf_path not in sys.path:
    sys.path.insert(0, cgcf_path)


# Load the pretrained checkpoint (adjust paths if needed).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# SECURITY: the checkpoint holds more than tensors (its 'args' entry is read
# below and accessed by attribute), so it is loaded with full unpickling.
# Unpickling can execute arbitrary code -- only load checkpoints from a
# trusted source. `weights_only` is spelled out so the behaviour does not
# depend on the installed PyTorch's default.
# NOTE(review): presumably weights_only=True would reject the 'args' object;
# confirm before tightening this.
checkpoint = torch.load(
    'CGCF/models/ckpt_drugs.pt',
    map_location=device,
    weights_only=False,
)
args = checkpoint['args']

from models.cnf_edge.spectral_norm import add_spectral_norm
from models.edgecnf import EdgeCNF
from utils.misc import seed_all
from utils.transforms import get_standard_transforms

# Seed with the value stored in the checkpoint's args before the model is
# built. NOTE(review): seed_all is a project helper -- presumably it seeds the
# Python/NumPy/torch RNGs; confirm.
seed_all(args.seed)
# Transform pipeline that mol_to_data_obj applies to every Data object.
tf = get_standard_transforms(order=args.aux_edge_order)

# Rebuild the network from the checkpoint's args and move it to `device`.
model = EdgeCNF(args).to(device)
# Spectral norm is added before the weights are loaded.
# NOTE(review): presumably add_spectral_norm changes the module's parameter
# names so the checkpoint keys only match afterwards -- confirm before
# reordering these statements.
if args.spectral_norm:
    add_spectral_norm(model)
model.load_state_dict(checkpoint['state_dict'])
# The model is only used for scoring below, so switch to evaluation mode.
model.eval()

# Helper functions: RDKit molecule -> graph conversion and per-edge length computation

def mol_to_data_obj(mol):
    """Convert an RDKit molecule into a torch_geometric ``Data`` graph.

    Args:
        mol: RDKit molecule. ``GetConformer()`` is called on it, so it must
            carry a conformer. The molecule is only read, never modified,
            so no defensive copy is made.

    Returns:
        ``Data`` with
          * ``x``: long tensor of atomic numbers, shape (num_atoms, 1)
          * ``node_type``: the same atomic numbers, shape (num_atoms,)
          * ``pos``: float tensor of the conformer's atom positions
          * ``edge_index``: long tensor, shape (2, 2 * num_bonds); every bond
            is stored in both directions
          * ``edge_type``: long tensor with one bond label per directed edge
        If the module-level transform ``tf`` is truthy, its output for that
        ``Data`` object is returned instead.
    """
    # NOTE(review): confirm these integer labels match the bond encoding the
    # loaded checkpoint was trained with.
    bond_type_to_int = {
        Chem.rdchem.BondType.SINGLE: 0,
        Chem.rdchem.BondType.DOUBLE: 1,
        Chem.rdchem.BondType.TRIPLE: 2,
        Chem.rdchem.BondType.AROMATIC: 3
    }

    atom_feats = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
    pos = mol.GetConformer().GetPositions()

    edge_index = []
    edge_feats = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        # Bond types outside the table are labelled -1 (unchanged behaviour).
        bond_feat = bond_type_to_int.get(bond.GetBondType(), -1)

        # Store each bond as two directed edges sharing the same label.
        edge_index.extend([(i, j), (j, i)])
        edge_feats.extend([bond_feat, bond_feat])

    # reshape(-1, 2) is a no-op for a non-empty list of pairs, but for a
    # molecule without bonds it yields shape (0, 2) -> (2, 0) after the
    # transpose instead of a 1-D empty tensor that cannot be unpacked into
    # (src, dst).
    edge_index = (
        torch.tensor(edge_index, dtype=torch.long).reshape(-1, 2).t().contiguous()
    )
    edge_feats = torch.tensor(edge_feats, dtype=torch.long)

    data = Data(
        x=torch.tensor(atom_feats, dtype=torch.long).view(-1, 1),
        pos=torch.tensor(pos, dtype=torch.float),
        edge_index=edge_index,
        edge_type=edge_feats
    )
    data.node_type = data.x.squeeze(-1)

    if tf:
        data = tf(data)
    return data

def compute_edge_lengths(data):
    """Return the Euclidean length of every edge in ``data``.

    ``data.edge_index`` is unpacked into its two rows (edge start and edge
    end indices); the result holds one distance per edge column, computed
    from the matching rows of ``data.pos``.
    """
    start_idx, end_idx = data.edge_index
    displacement = data.pos[start_idx] - data.pos[end_idx]
    return torch.norm(displacement, dim=1)

# Load the first N parsable molecules from the SDF. Reading stops as soon as
# N molecules have been collected, instead of parsing the whole file and
# slicing afterwards.
sdf_path = 'gdb9/gdb9.sdf'
supplier = Chem.SDMolSupplier(sdf_path, removeHs=False)
mols = []
for mol in supplier:
    if len(mols) >= N:
        break
    # The supplier yields None for records it cannot parse; skip those.
    if mol is not None:
        mols.append(mol)

# Load the first N rows from the CSV with energies.
# NOTE(review): every unparsable SDF record shifts the molecule window past
# the first N CSV rows, so trailing molecules may have no energy entry.
energy_df = pd.read_csv('gdb9/gdb9.sdf.csv', nrows=N)

# Create dict for quick lookup: mol_id -> u0 energy
energy_map = dict(zip(energy_df['mol_id'], energy_df['u0']))

# Parallel lists: logps[k] and energies[k] belong to the same molecule.
logps = []
energies = []

# Score each molecule with the model and print "name; logp; energy".
for mol in mols:
    mol_id = mol.GetProp('_Name') if mol.HasProp('_Name') else None
    if mol_id is None:
        print("Skipping molecule without name")
        continue

    u0 = energy_map.get(mol_id, None)
    if u0 is None:
        print(f"No energy found for {mol_id}")
        continue

    try:
        data = mol_to_data_obj(mol).to(device)
        batch = Batch.from_data_list([data])
        edge_lengths = compute_edge_lengths(data).view(-1, 1)
        logp = model.get_log_prob(batch, edge_lengths)
        logp_val = logp.mean().item()
    except Exception as e:
        # Best effort: report the failure and move on to the next molecule.
        print(f"Error computing logp for {mol_id}: {e}")
        logp_val = float('nan')
    else:
        # Only keep pairs where both numbers are finite. A NaN/inf log-prob
        # returned without an exception (or a NaN energy from the CSV) would
        # otherwise be recorded and corrupt the correlation statistics
        # computed from these lists. The chained comparison is False for NaN
        # and for +/-inf.
        logp_ok = float('-inf') < logp_val < float('inf')
        u0_ok = float('-inf') < u0 < float('inf')
        if logp_ok and u0_ok:
            logps.append(logp_val)
            energies.append(u0)
        else:
            print(f"Non-finite logp or energy for {mol_id}; excluded from statistics")

    print(f"{mol_id}; {logp_val:.6f}; {u0:.6f}")


  checkpoint = torch.load('CGCF/models/ckpt_drugs.pt', map_location=device)


gdb_1; -0.573236; -40.478930
gdb_2; 1.033338; -56.525887
gdb_3; 0.460272; -76.404702
gdb_4; -12.158129; -77.308427
gdb_5; -8.971259; -93.411888
gdb_6; -3.286950; -114.483613
gdb_7; -8.491429; -79.764152
gdb_8; -4.364229; -115.679136
gdb_9; -11.172298; -116.609549
gdb_10; -8.836445; -132.718150
gdb_11; -6.170240; -153.787612
gdb_12; -4.649626; -169.860788
gdb_13; -11.854399; -119.052475
gdb_14; -9.483953; -154.972731
gdb_15; -6.043585; -154.960361
gdb_16; -13.240016; -117.824798
gdb_17; -11.290878; -153.742562
gdb_18; -7.334419; -193.088340
gdb_19; -6.890088; -209.159302
gdb_20; -4.929487; -225.221461
gdb_21; -13.824153; -158.342346
gdb_22; -11.438672; -194.267232
gdb_23; -22.173809; -153.459846
gdb_24; -21.724550; -169.557758
gdb_25; -3.375345; -185.648533
gdb_26; -9.254565; -190.624631
gdb_27; -5.871166; -206.721858
gdb_28; -3.693443; -227.798785
gdb_29; -10.356908; -155.908941
gdb_30; -12.030312; -155.897345
gdb_31; -10.662642; -172.006141
gdb_32; -9.768029; -188.042067
gdb_33; -11.3

In [38]:
# Sanity check: the two result lists are filled in lockstep, so both counts
# printed here should be identical.
for collected in (logps, energies):
    print(len(collected))

9924
9924


In [39]:
from scipy.stats import spearmanr
from scipy.stats import pearsonr

# Rank (Spearman) and linear (Pearson) correlation between the model
# log-probabilities and the U0 energies.
spearman_stats = spearmanr(logps, energies)
pearson_stats = pearsonr(logps, energies)

# Both results unpack as (coefficient, p-value). Careful with the Pearson
# names: `p` is the correlation coefficient and `c` is the p-value.
corr, p_value = spearman_stats
p, c = pearson_stats

print(f"Spearman correlation: {corr:.4f}, p-value: {p_value:.4e}")
print(f"Pearson correlation: {p:.4f}, p-value: {c:.4e}")

Spearman correlation: -0.2183, p-value: 2.3059e-107
Pearson correlation: -0.2236, p-value: 1.1099e-112


In [41]:
from lifelines.utils import concordance_index
# The energies are negated before being used as the reference ordering --
# presumably so that a lower U0 is expected to pair with a higher
# log-probability (TODO confirm intent).
negative_energies = [-u0_value for u0_value in energies]

# First argument: reference values; second: the scores being evaluated.
ci = concordance_index(negative_energies, logps)
print(f"Concordance index: {ci:.4f}")

Concordance index: 0.5700


In [42]:
from sklearn.feature_selection import mutual_info_regression
import numpy as np

# Feature matrix with a single column (one row per molecule) and a 1-D target.
X = np.array(logps)[:, np.newaxis]
y = np.array(energies)

# random_state is fixed so repeated runs give the same estimate.
mi = mutual_info_regression(X, y, random_state=0)
print(f"Mutual Information: {mi[0]:.4f}")


Mutual Information: 0.6341


In [43]:
import numpy as np

# Root-mean-square difference between the raw log-probabilities and energies.
# NOTE(review): the two quantities are compared without any rescaling --
# confirm that this number is meaningful for the analysis.
residuals = np.array(logps) - np.array(energies)
rmse = np.sqrt(np.mean(residuals ** 2))
print(f"RMSE: {rmse:.4f}")


RMSE: 345.0896
