In [1]:
import os

# Change to the repository root (one level above the notebook's working directory) so the
# relative paths used below ('inference_api/...', 'semlaflow/...') resolve. The root is
# resolved only once, so re-running this cell does not keep climbing directories.
if "REPO_ROOT" not in globals():
    REPO_ROOT = os.path.abspath(os.pardir)
os.chdir(REPO_ROOT)

In [2]:
import torch

import datamol 
from rdkit import Chem

import biotite.structure.io.pdb as pdb
import semlaflow.scriptutil as util
from model_args import get_ar_semla_model_args

### Create a dataloader for a custom set of protein-ligand complexes

In [3]:
# Get some mols: ChEMBL drug SMILES (via datamol) parsed with RDKit.
data = datamol.data.chembl_drugs()
smiles = data["smiles"].tolist()
# Chem.MolFromSmiles returns None for SMILES RDKit cannot parse; drop those so every
# entry can be converted with GeometricMol.from_rdkit when the complexes are built.
mols = [mol for mol in (Chem.MolFromSmiles(s) for s in smiles) if mol is not None]



In [4]:
# Get some protein pockets in Biotite format.
# Presumably the file was produced earlier with:
# dm.val_dataset[0].holo.write_pdb('inference_api/data/example_holo.pdb')
pdb_path = 'inference_api/data/example_holo.pdb'
pdb_file = pdb.PDBFile.read(pdb_path)
# get_structure yields an AtomArrayStack, but an AtomArray is needed here,
# so only the first model is kept.
pocket_atoms = pdb.get_structure(pdb_file, include_bonds=True)[0]

In [5]:
# Create the protein ligand complex batch
from semlaflow.util.molrepr import GeometricMol
from semlaflow.data.datasets import PocketComplexDataset
from semlaflow.util.pocket import ProteinPocket, PocketComplex, PocketComplexBatch

# Number of example complexes to build. Slicing caps this at len(mols) instead of
# raising IndexError when fewer ligands are available.
N_COMPLEXES = 20

geo_complex_list = []
for ligand_mol in mols[:N_COMPLEXES]:
    geo_mol = GeometricMol.from_rdkit(ligand_mol)
    # Every ligand is paired with the same pocket structure. A fresh ProteinPocket is
    # built per complex, not one shared instance, in case downstream code mutates it.
    geo_pocket = ProteinPocket.from_pocket_atoms(pocket_atoms, infer_res_bonds=True)
    geo_complex_list.append(PocketComplex(holo=geo_pocket, ligand=geo_mol))

# Create the batch
geo_complex_batch = PocketComplexBatch(geo_complex_list)

[21:07:12] Molecule does not have explicit Hs. Consider calling AddHs()
[21:07:12] Molecule does not have explicit Hs. Consider calling AddHs()
[21:07:12] Molecule does not have explicit Hs. Consider calling AddHs()
[21:07:12] Molecule does not have explicit Hs. Consider calling AddHs()
[21:07:12] Molecule does not have explicit Hs. Consider calling AddHs()
[21:07:12] Molecule does not have explicit Hs. Consider calling AddHs()
[21:07:12] Molecule does not have explicit Hs. Consider calling AddHs()
[21:07:12] Molecule does not have explicit Hs. Consider calling AddHs()
[21:07:12] Molecule does not have explicit Hs. Consider calling AddHs()
[21:07:12] Molecule does not have explicit Hs. Consider calling AddHs()
[21:07:12] Molecule does not have explicit Hs. Consider calling AddHs()
[21:07:12] Molecule does not have explicit Hs. Consider calling AddHs()
[21:07:12] Molecule does not have explicit Hs. Consider calling AddHs()
[21:07:12] Molecule does not have explicit Hs. Consider calling 

### Build the model and load saved weights

In [6]:
from semlaflow.buildutil import build_dm
from semlaflow.buildutil import build_model

# Fine-tuned checkpoint to load (path relative to the repository root).
CHECKPOINT_PATH = 'semlaflow/saved/models/plinder-finetune-conf/epoch=311.ckpt'

args = get_ar_semla_model_args()
vocab = util.build_vocab()
# Build the datamodule from the custom complex batch; its val_dataloader() feeds generation.
dm = build_dm(args, vocab, geo_complex_batch)
model = build_model(args, dm, vocab)

checkpoint = torch.load(CHECKPOINT_PATH, map_location=torch.device('cuda'))
model.load_state_dict(checkpoint['state_dict'])
# Move to GPU and switch to eval mode. Generation calls model methods directly, outside a
# Trainer, so nothing else disables training-only layers (e.g. dropout, if the model has any).
model = model.cuda().eval()

Using type ARGeometricComplexInterpolant for training

items per bucket [0, 1, 2, 3, 6, 5, 2, 1, 0, 0, 0, 0, 0, 0]
bucket batch sizes [88, 72, 64, 56, 48, 48, 40, 40, 40, 32, 32, 32, 24, 8]
batches per bucket [0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
Total training steps 1400
Using model class ComplexSemlaGenerator
Using CFM class ARComplexMolecularCFM


### Generate the molecules

In [42]:
print("Initialising metrics...")
# init_metrics returns four metric groups; unpack them into the names used below.
metric_groups = util.init_metrics()
metrics, stability_metrics, complex_metrics, conf_metrics = metric_groups
print("Metrics complete.")

Initialising metrics...
No training data provided. Skipping novelty metric.
Metrics complete.


In [43]:
def predict(batch):
    """Generate ligands for one dataloader batch and return them with their references.

    Moves the prior and the holo tensors to the GPU, integrates the model from the prior,
    and converts both the generated output and the batch's ground-truth data to molecules.

    Relies on the notebook globals ``model`` and ``args`` from the model-loading cell.

    Args:
        batch: A 9-tuple from ``dm.val_dataloader()``. Only the prior, data, holo_mols,
            holos and gen_times entries are used.

    Returns:
        tuple: ``(gen_mols, data_mols, holos)``. These are the generated molecules, the
        reference molecules rebuilt from the batch data, and the batch's ``holos``
        passed through unchanged.
    """
    # The interpolated state and times (and the two ignored slots) are not needed here.
    prior, data, _, _, holo_mols, holos, _, _, gen_times = batch
    prior = {k: v.cuda() for k, v in prior.items()}
    holo_mols = {k: v.cuda() for k, v in holo_mols.items()}
    gen_times = gen_times.cuda()

    # Inference only: skip building an autograd graph to save memory and time.
    # NOTE(review): assumes _generate never needs gradients (e.g. gradient-based guidance).
    with torch.no_grad():
        output = model._generate(
            prior,
            gen_times,
            args.integration_steps,
            args.ode_sampling_strategy,
            holo=holo_mols,
        )
        gen_mols = model._generate_mols(output)
        data_mols = model._generate_mols(data, rescale=True)

    return gen_mols, data_mols, holos

In [40]:
from tqdm import tqdm
from time import time

print("Running generation...")
# Collect per-batch outputs under names that do not overwrite the input `mols` list.
# Reusing that name previously broke re-running the complex-building cell.
gen_mol_batches = []
holo_batches = []
data_mol_batches = []
for batch in tqdm(dm.val_dataloader()):
    batch_gen_mols, batch_data_mols, batch_holos = predict(batch)

    gen_mol_batches.append(batch_gen_mols)
    holo_batches.append(batch_holos)
    data_mol_batches.append(batch_data_mols)

# Flatten the per-batch lists. Assumes predict returns the three lists in matching order,
# so index i refers to the same complex in each flat list.
molecules = [mol for mols_in_batch in gen_mol_batches for mol in mols_in_batch]
holos_list = [holo for holos_in_batch in holo_batches for holo in holos_in_batch]
data_mols_list = [mol for mols_in_batch in data_mol_batches for mol in mols_in_batch]
print("Generation complete.")

Running generation...

items per bucket [0, 1, 2, 3, 6, 5, 2, 1, 0, 0, 0, 0, 0, 0]
bucket batch sizes [88, 72, 64, 56, 48, 48, 40, 40, 40, 32, 32, 32, 24, 8]
batches per bucket [0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]


  0%|          | 0/7 [00:00<?, ?it/s]

100%|██████████| 7/7 [01:13<00:00, 10.52s/it]

Generation complete.





### Calculate generative metrics

**TODO:** document how to prepare the batch.

In [None]:
util.disable_lib_stdout()

print("Calculating generative metrics...")
# Complex and conformer metrics receive the pockets and reference ligands,
# aligned by index with `molecules`.
metric_inputs = dict(
    complex_metrics=complex_metrics,
    holos=holos_list,
    conf_metrics=conf_metrics,
    data_mols=data_mols_list,
)
results = util.calc_metrics_(molecules, metrics, **metric_inputs)
util.print_results(results)
print("Generation script complete!")

Calculating generative metrics...




  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]


Metric                Result
------------------------------
connected-validity    1.00000
energy                198.96985
energy-per-atom       9.98923
energy-validity       1.00000
opt-energy-validity   1.00000
opt-rmsd              0.83566
strain                159.30522
strain-per-atom       8.11815
uniqueness            1.00000
validity              1.00000
clash                 65.30000
hydrophobic           5.70000
vdw                   9.00000
hbacceptor            0.55000
hbdonor               0.10000
conformer-centroid-rmsd2.70221
conformer-no-align-rmsd5.88650
conformer-rmsd        1.30960

Generation script complete!


In [49]:
from semlaflow.util.visualize import py3Dmol_visualize

# Render generated molecule `idx` together with its holo pocket.
idx = 2
ligand, pocket = molecules[idx], holos_list[idx]
view = py3Dmol_visualize(ligand, pocket)
view.zoomTo()
view.show()

In [51]:
# Score a single complex on its own; the returned dict is displayed as the cell output.
idx = 2
util.calc_metrics_(
    [molecules[idx]],
    metrics,
    complex_metrics=complex_metrics,
    holos=[holos_list[idx]],
    conf_metrics=conf_metrics,
    data_mols=[data_mols_list[idx]],
)



  0%|          | 0/1 [00:00<?, ?it/s]

{'connected-validity': tensor(1.),
 'energy': tensor(144.8999),
 'energy-per-atom': tensor(14.4900),
 'energy-validity': tensor(1.),
 'opt-energy-validity': tensor(1.),
 'opt-rmsd': tensor(0.5880),
 'strain': tensor(127.4572),
 'strain-per-atom': tensor(12.7457),
 'uniqueness': tensor(1.),
 'validity': tensor(1.),
 'clash': tensor(21.),
 'hydrophobic': tensor(5.),
 'vdw': tensor(5.),
 'hbacceptor': tensor(0.),
 'hbdonor': tensor(0.),
 'conformer-centroid-rmsd': tensor(1.7941),
 'conformer-no-align-rmsd': tensor(3.1904),
 'conformer-rmsd': tensor(1.0654)}