# Inspect Datasets and Save as Smol Objects

In [None]:
import sys
sys.path.append("..")

In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
from rdkit import Chem, RDLogger
from torchmetrics import MetricCollection
from rdkit.Chem.Draw import IPythonConsole
IPythonConsole.ipython_3d = True

In [None]:
import semlaflow.util.rdkit as smolRD
import semlaflow.util.functional as smolF
import semlaflow.util.metrics as Metrics
from semlaflow.util.tokeniser import Vocabulary
from semlaflow.util.molrepr import GeometricMol, GeometricMolBatch

In [None]:
# Root directory of the QM9 dataset (relative path); wrapped in Path() by later cells.
QM9_PATH = "../../../data/qm9"
# Sub-directory of QM9_PATH holding the raw SDF, the metadata CSV and the skip file.
RAW_DIR ="raw"
# Sub-directory of QM9_PATH holding the per-split metadata CSVs (train/val/test).
SPLIT_DIR = "raw_split"
# Sub-directory of QM9_PATH where the serialised .smol batches are stored.
SAVE_DIR = "smol"
# File names looked up inside RAW_DIR.
SDF_FILE = "gdb9.sdf"
METADATA_FILE = "gdb9.sdf.csv"
SKIP_FILE = "uncharacterized.txt"

In [None]:
# Copied from MiDi code, so should create the same splits (they didn't make them available)
def split_qm9(metadata_df):
    """Shuffle the QM9 metadata table and cut it into train/val/test frames.

    The sizes are fixed by the values hard-coded below: 100000 training rows,
    10% of all rows (rounded down) for testing, and the remainder for
    validation. The shuffle uses a fixed seed, so repeated calls on the same
    frame produce the same split.

    Args:
        metadata_df: DataFrame with one row per molecule.

    Returns:
        Tuple ``(train, val, test)`` of DataFrames. Rows keep the index labels
        they had in ``metadata_df``.
    """
    n_samples = len(metadata_df)
    n_train = 100000
    n_test = int(0.1 * n_samples)
    n_val = n_samples - (n_train + n_test)

    # Shuffle dataset with df.sample, then slice by position. This selects the
    # same rows as np.split(df, [n_train, n_train + n_val]) without relying on
    # numpy dispatching to DataFrame.swapaxes, which pandas has deprecated.
    shuffled = metadata_df.sample(frac=1, random_state=42)
    val_end = n_train + n_val

    train = shuffled.iloc[:n_train]
    val = shuffled.iloc[n_train:val_end]
    test = shuffled.iloc[val_end:]
    return train, val, test

In [None]:
# Will skip mol indices which appear in the skip file
def rdkit_mols_from_df(split_path, sdf_path, skip_path):
    """Load the RDKit molecules that belong to one dataset split.

    A molecule at position ``i`` in the SDF file is kept when ``i`` appears in
    the index of the split CSV, is not listed in the skip file, and can be
    sanitised and converted to SMILES.

    Args:
        split_path: Path of the split CSV; its first column is read as the
            index and is matched against positions in the SDF file.
        sdf_path: Path of the SDF file containing every molecule.
        skip_path: Path of the text file listing molecules to skip. The first
            9 and last 2 lines are ignored; on every other line the first
            whitespace-separated field is read as a 1-based molecule number.

    Returns:
        List of sanitised RDKit molecules, in SDF order.
    """
    # Only the index of the split CSV is used, so no columns need touching
    target_index = pd.read_csv(split_path, index_col=0).index

    with open(skip_path, 'r') as f:
        skip_lines = f.read().split('\n')[9:-2]

    # A set keeps the per-molecule membership test O(1); the -1 converts the
    # 1-based numbers in the file to 0-based SDF positions
    skip = {int(line.split()[0]) - 1 for line in skip_lines}

    suppl = Chem.SDMolSupplier(str(sdf_path), removeHs=False, sanitize=False)

    mols = []
    errors = 0
    skipped = 0

    for i, mol in enumerate(tqdm(suppl)):
        if i not in target_index:
            continue

        if i in skip:
            skipped += 1
            continue

        # Any failure here is counted as a sanitisation error.
        # NOTE(review): presumably this also covers entries the supplier
        # could not parse and yielded as None -- confirm against RDKit docs.
        try:
            Chem.SanitizeMol(mol)
            smiles = Chem.MolToSmiles(mol, isomericSmiles=False)
        except Exception:
            smiles = None

        if smiles is None:
            errors += 1
            continue

        mols.append(mol)

    print(f"Skipped {skipped} mols which where in skip file.")
    print(f"Encountered {errors} molecules which failed sanitisation.")
    print(f"Completed loading of dataset with {len(mols)} molecules.")

    return mols

In [None]:
def build_vocab():
    """Create the token vocabulary used when converting Smol molecules to RDKit.

    Returns:
        A Vocabulary built from the special tokens followed by the atom symbols.
    """
    # "<PAD>" is listed first because PAD needs to have index 0
    # (assumes Vocabulary assigns indices in list order -- TODO confirm)
    tokens = [
        "<PAD>", "<MASK>",
        # Core atom types
        "H", "C", "N", "O", "F", "P", "S", "Cl",
        # Other atom types
        "Br", "B", "Al", "Si", "As", "I", "Hg", "Bi",
    ]
    return Vocabulary(tokens)

In [None]:
def matching_smiles(rdkit_mol, smol_mol, vocab):
    """Check that a Smol molecule converts back to the same SMILES as its source.

    Args:
        rdkit_mol: The reference RDKit molecule.
        smol_mol: Smol molecule, converted with its ``to_rdkit(vocab)`` method.
        vocab: Vocabulary passed through to ``to_rdkit``.

    Returns:
        True when both molecules give equal canonical SMILES.
    """
    pair = (rdkit_mol, smol_mol.to_rdkit(vocab))
    reference_smi, roundtrip_smi = (
        smolRD.smiles_from_mol(candidate, canonical=True) for candidate in pair
    )
    return reference_smi == roundtrip_smi

## QM9

### Split QM9 and load into separate CSVs

I have copied the code from the MiDi paper and used the same random seed, so this should hopefully generate the same splits as they used. However, they have not provided their splits, so we cannot confirm this.

This code only splits the CSV file, which contains metadata and properties for each molecule. The full molecular coordinates are stored in a single SDF file.

In [None]:
# Read the QM9 metadata table and shuffle/split it into train, val and test frames
qm9_path = Path(QM9_PATH)
dataset = pd.read_csv(qm9_path / RAW_DIR / METADATA_FILE)
train, val, test = split_qm9(dataset)

# Locations of the per-split metadata CSVs; later cells read the splits from these paths
train_csv_path = qm9_path / SPLIT_DIR / "train.csv"
val_csv_path = qm9_path / SPLIT_DIR / "val.csv"
test_csv_path = qm9_path / SPLIT_DIR / "test.csv"

# NOTE(review): the writes are commented out, so this cell does not create the
# CSVs itself -- uncomment the lines below to (re)generate them.
# train.to_csv(train_csv_path)
# val.to_csv(val_csv_path)
# test.to_csv(test_csv_path)

### Create Smol Datasets from RDKit Mols from SDF Files

In [None]:
RDLogger.DisableLog('rdApp.*')

In [None]:
vocab = build_vocab()

In [None]:
# Raw inputs shared by all three splits
sdf_path = qm9_path / RAW_DIR / SDF_FILE
skip_path = qm9_path / RAW_DIR / SKIP_FILE

# Each call below iterates over the whole SDF file and keeps only the molecules
# whose position appears in the index of the given split CSV.
print("Processing train data...")
train_mols = rdkit_mols_from_df(train_csv_path, sdf_path, skip_path)

print("Processing val data...")
val_mols = rdkit_mols_from_df(val_csv_path, sdf_path, skip_path)

print("Processing test data...")
test_mols = rdkit_mols_from_df(test_csv_path, sdf_path, skip_path)

In [None]:
# Create Smol batches for ease of use later on
train_batch = GeometricMolBatch([GeometricMol.from_rdkit(mol) for mol in train_mols])
val_batch = GeometricMolBatch([GeometricMol.from_rdkit(mol) for mol in val_mols])
test_batch = GeometricMolBatch([GeometricMol.from_rdkit(mol) for mol in test_mols])

In [None]:
# Check it looks right
print("Dataset sizes:")
print(len(train_batch))
print(len(val_batch))
print(len(test_batch))

example_mol = train_batch[567]
print()
print("Example mol:")
print(example_mol.coords)
print(example_mol.atomics)
print(example_mol.bonds)
print(example_mol.charges)

In [None]:
example_mol.to_rdkit(vocab)

In [None]:
for atom in example_mol.to_rdkit(vocab).GetAtoms():
    print(f"Atom {atom.GetSymbol()} -- charge {atom.GetFormalCharge()} -- valence {atom.GetExplicitValence()}")

In [None]:
# Output files for the serialised batches, one per split
train_path = qm9_path / SAVE_DIR / "train.smol"
val_path = qm9_path / SAVE_DIR / "val.smol"
test_path = qm9_path / SAVE_DIR / "test.smol"

train_bytes = train_batch.to_bytes()
val_bytes = val_batch.to_bytes()
test_bytes = test_batch.to_bytes()

# write_bytes does not create missing directories, so make sure the shared
# output directory exists before writing
train_path.parent.mkdir(parents=True, exist_ok=True)

train_path.write_bytes(train_bytes)
val_path.write_bytes(val_bytes)
test_path.write_bytes(test_bytes)

In [None]:
train_matching = [matching_smiles(mol1, mol2, vocab) for mol1, mol2 in zip(train_mols, train_batch.to_list())]
print("Proportion matching", sum(train_matching) / len(train_matching))

In [None]:
print(len(train_mols))
print(len(train_batch))

In [None]:
unmatched_idxs = [idx for idx, matching in enumerate(train_matching) if not matching]

In [None]:
idx = 100
unmatched_idx = unmatched_idxs[idx]
print(smolRD.smiles_from_mol(train_mols[unmatched_idx]))
print(smolRD.smiles_from_mol(train_batch[unmatched_idx].to_rdkit(vocab)))

In [None]:
train_valid = [smolRD.mol_is_valid(mol.to_rdkit(vocab)) for mol in train_batch]
print("Propertion valid", sum(train_valid) / len(train_valid))

## Analyse QM9 Dataset

In [None]:
train_coords = train_batch.coords
train_mask = train_batch.mask

_, std_dev = smolF.standardise_coords(train_coords, train_mask)
print("Coord std dev on train data", std_dev)

In [None]:
avg_n_atoms = sum(train_batch.seq_length) / len(train_batch.seq_length)
max_n_atoms = max(train_batch.seq_length)
min_n_atoms = min(train_batch.seq_length)
print("avg", avg_n_atoms)
print("max", max_n_atoms)
print("min", min_n_atoms)

In [None]:
plt.hist(train_batch.seq_length, bins=26)
plt.show()

### Firstly, try loading the saved data

In [None]:
SAVE_DIR = "smol"

In [None]:
qm9_path = Path(QM9_PATH)
train_path = qm9_path / SAVE_DIR / "train.smol"
val_path = qm9_path / SAVE_DIR / "val.smol"
test_path = qm9_path / SAVE_DIR / "test.smol"

In [None]:
train_bytes = train_path.read_bytes()
val_bytes = val_path.read_bytes()
test_bytes = test_path.read_bytes()

train_batch = GeometricMolBatch.from_bytes(train_bytes)
val_batch = GeometricMolBatch.from_bytes(val_bytes)
test_batch = GeometricMolBatch.from_bytes(test_bytes)

In [None]:
vocab = build_vocab()

In [None]:
sample_mols = train_batch.to_list()

In [None]:
sample_mols[567].to_rdkit(vocab)

In [None]:
for atom in sample_mols[567].to_rdkit(vocab).GetAtoms():
    print(f"Atom {atom.GetSymbol()} -- charge {atom.GetFormalCharge()} -- valence {atom.GetExplicitValence()}")

In [None]:
# Benchmark metrics, keyed by the name each result is reported under.
# The "opt-" / "-per-atom" / "fc-" variants differ from their base metric only
# in the keyword argument passed to the constructor.
gen_metrics = {
    "validity": Metrics.Validity(),
    "fc-validity": Metrics.Validity(connected=True),
    "uniqueness": Metrics.Uniqueness(),
    "energy-validity": Metrics.EnergyValidity(),
    "opt-energy-validity": Metrics.EnergyValidity(optimise=True),
    "energy": Metrics.AverageEnergy(),
    "energy-per-atom": Metrics.AverageEnergy(per_atom=True),
    "strain": Metrics.AverageStrainEnergy(),
    "strain-per-atom": Metrics.AverageStrainEnergy(per_atom=True),
    "opt-rmsd": Metrics.AverageOptRmsd()
}
# NOTE(review): compute_groups=False presumably keeps every metric's state
# separate instead of letting similar metrics share it -- confirm against the
# torchmetrics MetricCollection docs.
gen_metrics = MetricCollection(gen_metrics, compute_groups=False)

In [None]:
# Compute benchmark metrics on loaded train dataset samples
rdkit_sample_mols = [mol.to_rdkit(vocab, sanitize=True) for mol in sample_mols]
gen_metrics.reset()
gen_metrics.update(rdkit_sample_mols)
results = gen_metrics.compute()

In [None]:
for metric, result in results.items():
    print(f"{metric} -- {result.item():.3f}")

In [None]:
# Compute benchmark metrics on original train dataset samples
gen_metrics.reset()
gen_metrics.update(train_mols)
results = gen_metrics.compute()

In [None]:
for metric, result in results.items():
    print(f"{metric} -- {result.item():.3f}")

In [None]:
# Convert the molecules around the problematic entry. The printed index is the
# position in sample_mols (not the position within the slice), so it can be
# used directly to look the molecule up in the following cells.
for idx, mol in enumerate(sample_mols[82008:82010], start=82008):
    print(idx)
    mol.to_rdkit(vocab)

In [None]:
idx = 82000 + 8
sample_mols[idx].to_rdkit(vocab)

In [None]:
original_mol = Chem.Mol(train_mols[idx])
Chem.SanitizeMol(original_mol)
original_mol

### Recreate this issue with functions in the notebook

In [None]:
def mol_from_atoms(coords, tokens, bonds):
    """Build an unsanitised RDKit molecule from coordinates, symbols and bonds.

    Args:
        coords: Array of atom positions. ``coords.shape[0]`` is used as the
            atom count and ``coords.tolist()`` must give one position per atom
            (assumes shape (n_atoms, 3) -- TODO confirm).
        tokens: Iterable of atom symbols, one per atom.
        bonds: NumPy array whose rows unpack to ``(start, end, bond_type)``,
            where ``bond_type`` is a key of ``smolRD.IDX_BOND_MAP``.

    Returns:
        An RDKit Mol with a single conformer, or None when a symbol cannot be
        converted to an atomic number or a bond type is not in
        ``smolRD.IDX_BOND_MAP``.
    """
    # An unknown token (e.g. a special vocabulary token) means no molecule can
    # be built; Exception rather than a bare except so interrupts still propagate
    try:
        atomics = [smolRD.PT.atomic_from_symbol(token) for token in tokens]
    except Exception:
        return None

    # Add atom types
    mol = Chem.EditableMol(Chem.Mol())
    for atomic in atomics:
        mol.AddAtom(Chem.Atom(atomic))

    # Add 3D coords
    conf = Chem.Conformer(coords.shape[0])
    for idx, coord in enumerate(coords.tolist()):
        conf.SetAtomPosition(idx, coord)

    mol = mol.GetMol()
    mol.AddConformer(conf)

    # Add bonds if they have been provided
    mol = Chem.EditableMol(mol)
    for bond in bonds.astype(np.int32).tolist():
        start, end, b_type = bond

        if b_type not in smolRD.IDX_BOND_MAP:
            return None

        # Don't add self connections
        if start != end:
            b_type = smolRD.IDX_BOND_MAP[b_type]
            mol.AddBond(start, end, b_type)

    mol = mol.GetMol()
    for atom in mol.GetAtoms():
        atom.UpdatePropertyCache(strict=False)

    # NOTE: the molecule is returned without calling Chem.SanitizeMol; the
    # sanitisation step was disabled in this notebook copy of the function.
    return mol

In [None]:
def to_rdkit(mol, vocab):
    """Convert a Smol molecule into an RDKit molecule via ``mol_from_atoms``.

    Args:
        mol: Object exposing ``atomics``, ``coords`` and ``bonds`` tensors.
        vocab: Vocabulary used to turn indices into tokens when ``atomics`` is 2-D.

    Returns:
        Whatever ``mol_from_atoms`` returns for the extracted data.
    """
    atom_repr = mol.atomics

    if len(atom_repr.size()) != 2:
        # 1-D case: the entries are passed to symbol_from_atomic one by one
        tokens = [smolRD.PT.symbol_from_atomic(atomic) for atomic in atom_repr.tolist()]
    else:
        # 2-D case: take the highest-scoring vocabulary index for each row
        best_indices = torch.argmax(atom_repr, dim=1).tolist()
        tokens = vocab.tokens_from_indices(best_indices)

    return mol_from_atoms(mol.coords.numpy(), tokens, mol.bonds.numpy())

In [None]:
idx = 82000 + 8
problem_mol = sample_mols[idx]
rdkit_mol = to_rdkit(problem_mol, vocab)

In [None]:
for atom in rdkit_mol.GetAtoms():
    print(f"Atom {atom.GetSymbol()} -- charge {atom.GetFormalCharge()} -- valence {atom.GetExplicitValence()}")

print()
for atom in original_mol.GetAtoms():
    print(f"Atom {atom.GetSymbol()} -- charge {atom.GetFormalCharge()} -- valence {atom.GetExplicitValence()}")