In [1]:
import os

# Work from the parent of the directory the kernel was started in, so the relative paths
# used below ('inference_api/data/...', 'semlaflow/saved/...') resolve.
# The start directory is recorded only once per kernel session, which makes this cell
# safe to re-run: a bare os.chdir('../') would climb one more level on every execution.
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_START_DIR, os.pardir))

In [2]:
import torch
import datamol
import numpy as np 
from rdkit import Chem
import biotite.structure.io.pdb as pdb
from model_args import get_ar_semla_model_args

import semlaflow.scriptutil as util
from semlaflow.buildutil import build_dm
from semlaflow.buildutil import build_model
from semlaflow.util.molrepr import GeometricMol
from semlaflow.util.pocket import ProteinPocket, PocketComplex, PocketComplexBatch

Failed to find the pandas get_adjustment() function to patch
Failed to patch pandas - PandasTools will have limited functionality


### Creating a dataloader for a custom set of protein-ligand complexes

In [3]:
# Get some example molecules: the ChEMBL drugs dataset shipped with datamol
data = datamol.data.chembl_drugs()
smiles = data["smiles"].tolist()
# Chem.MolFromSmiles returns None for SMILES RDKit cannot parse; drop those so that
# GeometricMol.from_rdkit further down never receives None.
mols = [mol for mol in (Chem.MolFromSmiles(s) for s in smiles) if mol is not None]



In [4]:
# Read the example holo structure with Biotite, keeping bond information.
holo_pdb_path = 'inference_api/data/example_holo.pdb'
pdb_file = pdb.PDBFile.read(holo_pdb_path)
# get_structure yields an AtomArrayStack (one entry per model); downstream code needs
# a single AtomArray, so keep only the first model.
atom_stack = pdb.get_structure(pdb_file, include_bonds=True)
pocket_atoms = atom_stack[0]

In [None]:
# Create the protein-ligand complex batch: one PocketComplex per example ligand, every
# ligand paired with the same example pocket. This batch is only consumed by build_dm
# further down when use_existing_dm is False.
num_to_generate = 20

# Skip molecules RDKit failed to parse (None) and slice instead of indexing, so having
# fewer than `num_to_generate` usable molecules no longer raises IndexError.
# NOTE(review): SMILES-derived mols carry no explicit Hs; RDKit logs "Molecule does not
# have explicit Hs" while this cell runs — confirm from_rdkit handles that as intended.
valid_mols = [mol for mol in mols if mol is not None][:num_to_generate]

geo_complex_list = []
for mol in valid_mols:
    geo_mol = GeometricMol.from_rdkit(mol)
    # Build a fresh pocket for each complex so no two complexes share one pocket object.
    geo_pocket = ProteinPocket.from_pocket_atoms(pocket_atoms, infer_res_bonds=True)
    geo_complex_list.append(PocketComplex(holo=geo_pocket, ligand=geo_mol))

# Create the batch
geo_complex_batch = PocketComplexBatch(geo_complex_list)

[00:18:32] Molecule does not have explicit Hs. Consider calling AddHs()
[00:18:32] Molecule does not have explicit Hs. Consider calling AddHs()
[00:18:32] Molecule does not have explicit Hs. Consider calling AddHs()
[00:18:32] Molecule does not have explicit Hs. Consider calling AddHs()
[00:18:32] Molecule does not have explicit Hs. Consider calling AddHs()
[00:18:32] Molecule does not have explicit Hs. Consider calling AddHs()
[00:18:32] Molecule does not have explicit Hs. Consider calling AddHs()
[00:18:32] Molecule does not have explicit Hs. Consider calling AddHs()
[00:18:32] Molecule does not have explicit Hs. Consider calling AddHs()
[00:18:32] Molecule does not have explicit Hs. Consider calling AddHs()
[00:18:32] Molecule does not have explicit Hs. Consider calling AddHs()
[00:18:32] Molecule does not have explicit Hs. Consider calling AddHs()
[00:18:32] Molecule does not have explicit Hs. Consider calling AddHs()
[00:18:32] Molecule does not have explicit Hs. Consider calling 

In [7]:
# Default autoregressive SemlaFlow arguments from model_args.get_ar_semla_model_args().
# use_complex_metrics presumably enables protein-ligand complex metrics — confirm in model_args.
args = get_ar_semla_model_args()
args.use_complex_metrics = True

### Build the model and load saved weights

In [8]:
# True: build the datamodule from `args` alone.
# False: also pass the custom `geo_complex_batch` built above to build_dm.
use_existing_dm = True

vocab = util.build_vocab()
if use_existing_dm:
    dm = build_dm(args, vocab)
else:
    dm = build_dm(args, vocab, geo_complex_batch)
model = build_model(args, dm, vocab)

CKPT_PATH = 'semlaflow/saved/models/plinder-ar-conf/last.ckpt'
# Deserialize onto the CPU: only checkpoint['state_dict'] is used here, the checkpoint
# dict may hold other entries we do not need on the GPU, and model.cuda() below moves
# the loaded parameters to the GPU anyway.
checkpoint = torch.load(CKPT_PATH, map_location='cpu')


def remap_pocket_encoder_keys(state_dict, old_prefix='pocket_encoder.', new_prefix='pocket_encoder.encoder.'):
    """HACK (remove later): rename checkpoint keys so they match the current model layout.

    Every key that begins with ``old_prefix`` is moved under ``new_prefix``; all other keys
    are kept as-is, and key order is preserved. Only the leading prefix is rewritten —
    ``str.replace`` would also rewrite any later occurrence of ``old_prefix`` in the key.

    Args:
        state_dict: mapping of parameter name -> tensor, as stored in the checkpoint.
        old_prefix: key prefix used by the checkpoint.
        new_prefix: key prefix expected by the current model.

    Returns:
        A new dict with the renamed keys; ``state_dict`` itself is not modified.
    """
    remapped = {}
    for key, value in state_dict.items():
        if key.startswith(old_prefix):
            key = new_prefix + key[len(old_prefix):]
        remapped[key] = value
    return remapped


new_checkpoint = remap_pocket_encoder_keys(checkpoint['state_dict'])

model.load_state_dict(new_checkpoint)
# This notebook only runs inference: switch dropout/normalisation layers (if any) to eval mode.
model.eval()
model = model.cuda()

Using type ARGeometricComplexInterpolant for training

items per bucket [45, 39, 51, 53, 76, 104, 92, 86, 114, 84, 63, 111, 79, 3]
bucket batch sizes [88, 72, 64, 56, 48, 48, 40, 40, 40, 32, 32, 32, 24, 8]
batches per bucket [1, 1, 1, 1, 2, 3, 3, 3, 3, 3, 2, 4, 4, 1]
Total training steps 6400
Using model class ComplexSemlaGenerator
Using CFM class ARComplexMolecularCFM


### Generate the molecules

In [9]:
def _ar_interval_bounds(t_per_ar_action, max_num_cuts):
    """Return ``(start_times, end_times)`` of the autoregressive integration intervals.

    Interval ``i`` starts at ``i * t_per_ar_action``; each interval ends where the next one
    starts, and the last interval always ends at 1.0. For example ``t_per_ar_action=0.25``,
    ``max_num_cuts=3`` gives starts ``[0, 0.25, 0.5, 0.75]`` and ends ``[0.25, 0.5, 0.75, 1.0]``.
    """
    start_times = np.array([i * t_per_ar_action for i in range(max_num_cuts + 1)])
    end_times = list(start_times[1:]) + [1.0]
    return start_times, end_times


def _interval_num_steps(start_time, end_time, integration_steps):
    """Number of integration steps passed to ``model._step_interval`` for one interval.

    Returns ``floor((end_time - start_time) * integration_steps) + 1``. The previous form,
    ``(end_time - start_time) // (1.0 / integration_steps)``, is off by one whenever
    ``1 / integration_steps`` is not exactly representable: ``0.25 // 0.01 == 24.0``, not 25.
    The small tolerance absorbs rounding in the interval length itself
    (e.g. ``0.3 - 0.2 == 0.09999999999999998``).
    """
    n_grid = (end_time - start_time) * integration_steps
    return int(np.floor(n_grid + 1e-6)) + 1


def interval_predict(batch):
    """Generate ligands for one batch by integrating the flow one interval at a time.

    The time axis is split into ``args.max_num_cuts + 1`` intervals that start every
    ``args.t_per_ar_action``; the final interval ends at 1.0. ``model._step_interval``
    is called once per interval. Uses the notebook-level ``model`` and ``args``.

    Args:
        batch: one item from ``dm.val_dataloader()``, unpacked as the 9-tuple
            ``(prior, data_mols, _, _, holo_mols, holo_pocks, _, _, gen_times)``.

    Returns:
        ``(gen_mols, data_mols, holo_pocks)``: the generated molecules and the batch's
        data molecules (both converted by ``model._generate_mols``), plus the batch pockets.
    """
    # prior is just the initial fragment with gaussian position
    # holo_mols is the protein pocket (for model input)
    # holo_pocks is also the protein pocket (for evaluation)
    # gen_times is the time at which the fragment is generated (should be all 0 for prior)
    prior, data_mols, _, _, holo_mols, holo_pocks, _, _, gen_times = batch

    # Move all the model inputs to the GPU
    curr = {k: v.cuda() for k, v in prior.items()}
    holo_mols = {k: v.cuda() for k, v in holo_mols.items()}
    gen_times = gen_times.cuda()

    # Start and end times of each interpolation interval, e.g.
    # start [0, 0.25, 0.5, 0.75] / end [0.25, 0.5, 0.75, 1.0].
    # After each interval, we ask RxnFlow to predict the next fragment to add.
    start_times, end_times = _ar_interval_bounds(args.t_per_ar_action, args.max_num_cuts)

    with torch.no_grad():
        for start_time, end_time in zip(start_times, end_times):
            num_steps = _interval_num_steps(start_time, end_time, args.integration_steps)

            curr, predicted, _ = model._step_interval(
                curr, gen_times, num_steps, start_time, end_time, holo=holo_mols, holo_pocks=holo_pocks)

            # TODO: Here is where we use the partial prediction of the flow matching model
            # to condition the RxnFlow model policy. We use the predicted["coords"] to condition.
            # Once RxnFlow outputs the new fragment to add - we should do the following
            # 1. Update the "curr" batch with the newly added fragment
            # - initialize the position of new fragment as Gaussian centered at 0
            # - update the atomics, mask, charges, and most importantly the bonds
            # 2. Update the gen_times, so that we assign current time to the new fragment as the gen times.

    # Undo the model's coordinate scaling on the final prediction before building molecules.
    predicted["coords"] = predicted["coords"] * model.coord_scale

    gen_mols = model._generate_mols(predicted)
    data_mols = model._generate_mols(data_mols, rescale=True)
    return gen_mols, data_mols, holo_pocks

In [10]:
from itertools import islice

from tqdm.auto import tqdm

# Number of validation batches to generate for; set to None to run the whole validation set.
MAX_GEN_BATCHES = 1

mols_list = []
data_mols_list = []
holos_list = []
for batch in tqdm(islice(dm.val_dataloader(), MAX_GEN_BATCHES), total=MAX_GEN_BATCHES, desc="generating"):
    gen_mols, data_mols, holos = interval_predict(batch)

    # interval_predict returns per-batch sequences; flatten them into one list per kind.
    mols_list.extend(gen_mols)
    data_mols_list.extend(data_mols)
    holos_list.extend(holos)

print("Generation complete.")


items per bucket [4, 4, 10, 3, 5, 1, 2, 3, 1, 3, 4, 6, 4, 0]
bucket batch sizes [88, 72, 64, 56, 48, 48, 40, 40, 40, 32, 32, 32, 24, 8]
batches per bucket [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]


  0%|          | 0/13 [00:00<?, ?it/s]

[00:19:14] Can't kekulize mol.  Unkekulized atoms: 0 1 2 3 4 5 7
[00:19:14] Can't kekulize mol.  Unkekulized atoms: 1 3 4
[00:19:14] Can't kekulize mol.  Unkekulized atoms: 14 15 16 19 20 22 23
[00:19:14] Can't kekulize mol.  Unkekulized atoms: 12 13 14 15 16 17 18 28 29
[00:19:14] Can't kekulize mol.  Unkekulized atoms: 14 15 16 19 20 22 23
[00:19:14] Can't kekulize mol.  Unkekulized atoms: 22 23 25 28 29 31 32
[00:19:14] Can't kekulize mol.  Unkekulized atoms: 14 15 16 19 20 22 23
[00:19:14] Can't kekulize mol.  Unkekulized atoms: 1 3 4
[00:19:14] Can't kekulize mol.  Unkekulized atoms: 27 30 33
[00:19:14] Can't kekulize mol.  Unkekulized atoms: 22 23 24 27 28 30 31
Please either pass the dim explicitly or simply use torch.linalg.cross.
The default value of dim will change to agree with that of linalg.cross in a future release. (Triggered internally at ../aten/src/ATen/native/Cross.cpp:63.)
  n_2 = _normalize(torch.cross(u_2, u_1), dim=-1)
  0%|          | 0/13 [00:16<?, ?it/s]

Generation complete.





### Evaluate the metrics

In [11]:
print("Initialising metrics...")
metric_groups = util.init_metrics()
# init_metrics returns four metric collections; stability_metrics is not used further below.
metrics, stability_metrics, complex_metrics, conf_metrics = metric_groups
print("Metrics complete.")

Initialising metrics...
No training data provided. Skipping novelty metric.
Metrics complete.


In [12]:
util.disable_lib_stdout()

print("Calculating generative metrics...")
# Keyword arguments for the complex/conformer metrics, which need the pockets and data ligands.
metric_kwargs = dict(
    complex_metrics=complex_metrics,
    holo_pocks=holos_list,
    conf_metrics=conf_metrics,
    data_mols=data_mols_list,
)
results = util.calc_metrics_(mols_list, metrics, **metric_kwargs)
util.print_results(results)
print("Generation script complete!")

Calculating generative metrics...




  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]


Metric                Result
------------------------------
connected-validity    1.00000
energy                57.63604
energy-per-atom       4.42124
energy-validity       1.00000
opt-energy-validity   1.00000
opt-rmsd              0.42911
strain                97.99454
strain-per-atom       7.49077
uniqueness            0.75000
validity              1.00000
clash                 30.00000
hydrophobic           3.00000
vdw                   7.00000
hbacceptor            2.25000
hbdonor               0.75000
conformer-centroid-rmsd0.80683
conformer-no-align-rmsd3.39102
conformer-rmsd        0.68081

Generation script complete!


### Visually inspect individual complex predictions

In [18]:
from semlaflow.util.visualize import complex_to_3dview

# Show one generated ligand together with its pocket and the corresponding data ligand.
idx = 0
complex_parts = (mols_list[idx], holos_list[idx], data_mols_list[idx])
view = complex_to_3dview(*complex_parts)
view.zoomTo()
view.show()

In [20]:
idx = 0
# Re-run the full metric suite on just one complex (single-element lists).
one_mol, one_holo, one_data_mol = ([seq[idx]] for seq in (mols_list, holos_list, data_mols_list))
util.calc_metrics_(
    one_mol,
    metrics,
    complex_metrics=complex_metrics,
    holo_pocks=one_holo,
    conf_metrics=conf_metrics,
    data_mols=one_data_mol,
)



  0%|          | 0/1 [00:00<?, ?it/s]

{'connected-validity': tensor(1.),
 'energy': tensor(132.8751),
 'energy-per-atom': tensor(9.4911),
 'energy-validity': tensor(1.),
 'opt-energy-validity': tensor(1.),
 'opt-rmsd': tensor(0.9442),
 'strain': tensor(116.6753),
 'strain-per-atom': tensor(8.3339),
 'uniqueness': tensor(1.),
 'validity': tensor(1.),
 'clash': tensor(44.),
 'hydrophobic': tensor(7.),
 'vdw': tensor(7.),
 'hbacceptor': tensor(0.),
 'hbdonor': tensor(1.),
 'conformer-centroid-rmsd': tensor(1.2794),
 'conformer-no-align-rmsd': tensor(4.5411),
 'conformer-rmsd': tensor(0.8846)}