# AF3 rescore with gnina

## Setup

In [None]:
#| default_exp gnina_rescore

In [None]:
#| export
import pandas as pd
import re, os, subprocess, py3Dmol
from Bio.PDB import MMCIFParser, PDBIO, Select
from rdkit import Chem
from rdkit.Chem import AllChem
from pathlib import Path
from fastcore.all import L
from tqdm.contrib.concurrent import process_map
from functools import partial

## Split the AF3 output .cif into protein.pdb and ligand.sdf

In [None]:
#| export
class ChainSelect(Select):
    "Biopython `Select` filter that keeps only chains whose ID is in `chain_ids`"
    def __init__(self, chain_ids):
        # Chain IDs that pass the filter; checked with `in` by accept_chain
        self.chain_ids = chain_ids

    def accept_chain(self, chain):
        "Return True when this chain's ID is found in `chain_ids`"
        current_id = chain.get_id()
        return current_id in self.chain_ids

In [None]:
#| export
def rename_residues(structure, chain_id, new_resname='LIG'):
    "Set the residue name of every residue in chain `chain_id` to `new_resname` (AF3's `LIG_L` is too long and makes RDKit fail)"
    # Lazily walk every model and pick out the chains carrying the requested ID
    target_chains = (ch for mdl in structure for ch in mdl if ch.id == chain_id)
    for ch in target_chains:
        for res in ch:
            res.resname = new_resname

In [None]:
#| export
def split_cif(cif_path, chainA_pdb_path, chainL_pdb_path, rec_chain='A', lig_chain='L'):
    """Split AF3 output CIF to protein and ligand PDBs.

    cif_path: AF3 output mmCIF file holding the complex.
    chainA_pdb_path: where the receptor chain is written (PDB).
    chainL_pdb_path: where the ligand chain is written (PDB).
    rec_chain: chain ID of the receptor (default 'A').
    lig_chain: chain ID of the ligand (default 'L').
    """
    parser = MMCIFParser(QUIET=True)
    structure = parser.get_structure('complex', cif_path)
    # Ligand residues are renamed to 'LIG' before writing (see rename_residues)
    rename_residues(structure, chain_id=lig_chain, new_resname='LIG')
    io = PDBIO()
    io.set_structure(structure)
    # Pass one-element lists so the chain ID is matched exactly, not as a substring
    io.save(str(chainA_pdb_path), ChainSelect([rec_chain]))  # receptor
    io.save(str(chainL_pdb_path), ChainSelect([lig_chain]))  # ligand

In [None]:
#| export
def pdb2sdf(pdb_path, sdf_path):
    """Convert ligand pdb to sdf file.

    Returns None on success. When RDKit cannot parse the PDB, prints a message
    and returns `pdb_path` so callers can collect the inputs that failed.
    """
    # str() so pathlib.Path arguments are handled the same way split_cif handles them
    mol = Chem.MolFromPDBFile(str(pdb_path), sanitize=True, removeHs=False)
    if mol is None:
        print('Conversion failed for:', pdb_path)
        return pdb_path

    writer = Chem.SDWriter(str(sdf_path))
    try:
        writer.write(mol)
    finally:
        # Close the writer even if writing raises, so the file handle is not leaked
        writer.close()
    return None

In [None]:
#| export
def prepare_rec_lig(cif_path, chainA_pdb_path, chainL_sdf_path):
    """Split AF3 cif to protein.pdb (chainA) and ligand.sdf (chainL).

    Returns whatever `pdb2sdf` returns: None on success, or the path of the
    intermediate ligand PDB when the SDF conversion failed.
    """
    # Intermediate ligand PDB, written to the current working directory and
    # named after the CIF stem.
    tmp_path = f'{Path(cif_path).stem}_lig.pdb'
    try:
        split_cif(cif_path, chainA_pdb_path, tmp_path)
        return pdb2sdf(tmp_path, chainL_sdf_path)
    finally:
        # Best-effort cleanup that also runs when either step above raises,
        # so a failed run does not leave the intermediate file behind.
        try:
            os.remove(tmp_path)
        except OSError:
            pass

In [None]:
# Example: split the sample CIF into a receptor PDB (chain A) and a ligand SDF (chain L)
prepare_rec_lig('gnina_test/cif/test.cif','gnina_test/chain_A.pdb','gnina_test/chain_L.sdf')

## gnina score

According to the [gnina documentation](https://github.com/gnina/gnina?tab=readme-ov-file), a receptor–ligand complex can be minimized and rescored with:

```bash
gnina -r chain_A.pdb -l chain_L.sdf --minimize -o minimized.sdf.gz
```

In [None]:
#| export
def gnina_rescore_local(protein_pdb, # receptor file
                  ligand_sdf, # ligand file
                  ):
    "Run the `./gnina` binary from the current directory with `--minimize` and return its captured stdout"
    gnina_args = ['./gnina', '-r', protein_pdb, '-l', ligand_sdf, '--minimize']
    completed = subprocess.run(gnina_args, capture_output=True, text=True)
    return completed.stdout

In [None]:
# %%time
# out = gnina_rescore_local('gnina_test/chain_A.pdb','gnina_test/chain_L.sdf')
# out

CPU times: user 2.42 ms, sys: 417 µs, total: 2.83 ms
Wall time: 4.43 s




In [None]:
#| export
def gnina_rescore_docker(protein_pdb, ligand_sdf):
    """
    Rescore a receptor/ligand pair with the `gnina/gnina` Docker image and return stdout.
    The receptor and ligand may live in different folders; each folder is mounted separately.
    """
    rec_file = Path(protein_pdb).resolve()
    lig_file = Path(ligand_sdf).resolve()

    # Locations inside the container where the two host folders appear
    rec_mount, lig_mount = '/recdata', '/ligdata'

    volume_args = [
        '-v', f'{rec_file.parent}:{rec_mount}',
        '-v', f'{lig_file.parent}:{lig_mount}',
    ]
    gnina_args = [
        'gnina',
        '-r', f'{rec_mount}/{rec_file.name}',
        '-l', f'{lig_mount}/{lig_file.name}',
        '--minimize',
    ]

    completed = subprocess.run(
        ['docker', 'run', '--rm', *volume_args, 'gnina/gnina', *gnina_args],
        capture_output=True,
        text=True,
    )
    return completed.stdout

In [None]:
# %%time
# out = gnina_rescore_docker('gnina_test/chain_A.pdb','gnina_test/chain_L.sdf')
# out

In [None]:
#| export
def extract_gnina_rescore(txt):
    """Extract GNINA output text to dictionary.

    Looks for the `Affinity:`, `RMSD:`, `CNNscore:`, `CNNaffinity:` and
    `CNNvariance:` lines (in that order) and returns their values as floats
    under the keys `binding_energy`, `uncertainty` (the second number on the
    Affinity line), `RMSD`, `CNNscore`, `CNNaffinity` and `CNNvariance`.

    Raises ValueError when `txt` does not contain all of these fields, e.g.
    when gnina failed and produced no score lines.
    """
    match = re.search(
        r"Affinity:\s+(?P<binding_energy>[-.\d]+)\s+(?P<uncertainty>[-.\d]+).*?"
        r"RMSD:\s+(?P<RMSD>[-.\d]+).*?"
        r"CNNscore:\s+(?P<CNNscore>[-.\d]+).*?"
        r"CNNaffinity:\s+(?P<CNNaffinity>[-.\d]+).*?"
        r"CNNvariance:\s+(?P<CNNvariance>[-.\d]+)",
        txt,
        re.DOTALL)  # DOTALL lets `.*?` span the newlines between fields

    if match is None:
        # Fail with a clear message instead of an AttributeError on None.groupdict()
        raise ValueError("Failed to match GNINA output format.")

    return {k: float(v) for k, v in match.groupdict().items()} # convert values to float

In [None]:
out = "              _             \n             (_)            \n   __ _ _ __  _ _ __   __ _ \n  / _` | '_ \\| | '_ \\ / _` |\n | (_| | | | | | | | | (_| |\n  \\__, |_| |_|_|_| |_|\\__,_|\n   __/ |                    \n  |___/                     \n\ngnina  master:e9cb230+   Built Feb 11 2023.\ngnina is based on smina and AutoDock Vina.\nPlease cite appropriately.\n\nWARNING: No GPU detected. CNN scoring will be slow.\nRecommend running with single model (--cnn crossdock_default2018)\nor without cnn scoring (--cnn_scoring=none).\n\nCommandline: ./gnina -r chain_A.pdb -l chain_L.sdf --minimize\nAffinity: -10.96345  -1.51405 (kcal/mol)\nRMSD: 1.15404\nCNNscore: 0.49978 \nCNNaffinity: 7.32008\nCNNvariance: 0.18500\n"

In [None]:
extract_gnina_rescore(out)

{'binding_energy': -10.96345,
 'uncertainty': -1.51405,
 'RMSD': 1.15404,
 'CNNscore': 0.49978,
 'CNNaffinity': 7.32008,
 'CNNvariance': 0.185}

In [None]:
#| export
def get_gnina_rescore(cif_path,is_local=False):
    """Split the CIF into receptor and ligand folders, then extract the GNINA rescored affinity score.

    The receptor PDB and ligand SDF are written to sibling folders named
    `<cif folder>_receptor` and `<cif folder>_ligand`. With `is_local=True`
    the local gnina binary is used, otherwise the Docker image.

    Raises RuntimeError when the ligand could not be converted to SDF.
    """
    cif_path = Path(cif_path).expanduser()
    parent,stem = cif_path.parent,cif_path.stem

    rec_dir,lig_dir = Path(str(parent) + '_receptor'),Path(str(parent) + '_ligand')
    rec_path,lig_path = rec_dir/f'{stem}.pdb',lig_dir/f'{stem}.sdf'

    rec_dir.mkdir(exist_ok=True)
    lig_dir.mkdir(exist_ok=True)

    failed = prepare_rec_lig(cif_path,rec_path,lig_path)
    if failed is not None:
        # Stop here: no new SDF was written, so any file at lig_path is missing
        # or left over from an earlier run and must not be rescored.
        raise RuntimeError(f'Ligand conversion to SDF failed for: {cif_path}')

    rescore = gnina_rescore_local if is_local else gnina_rescore_docker
    gnina_output = rescore(rec_path,lig_path)
    return extract_gnina_rescore(gnina_output)

In [None]:
# get_gnina_rescore('gnina_test/cif/test.cif',is_local=True)

Non-parallel version for multiple `.cif` files:

In [None]:
# cifs = L(Path('gnina_test/cif').expanduser().glob("*.cif")) # just take cif file

# out = {p.stem: get_gnina_rescore(p) for p in tqdm(cifs)}

# out_df = pd.DataFrame(out).T

In [None]:
#| export
def get_gnina_rescore_folder(cif_folder,is_local=False,max_workers=4):
    """Parallel processing to get gnina rescore given folder path.

    cif_folder: folder whose `*.cif` files are rescored.
    is_local: forwarded to `get_gnina_rescore` (local binary vs Docker).
    max_workers: number of worker processes passed to `process_map` (default 4).

    Returns a DataFrame with one row per CIF; the `ID` column holds the file stem.
    """
    # Sort so the row order is reproducible; glob order depends on the filesystem
    cifs = sorted(Path(cif_folder).expanduser().glob("*.cif")) # just take cif file

    func = partial(get_gnina_rescore,is_local=is_local)
    results = process_map(func, cifs, max_workers=max_workers)

    # use path.stem as df index
    results_dict = {p.stem: res for p, res in zip(cifs, results)}
    return pd.DataFrame(results_dict).T.reset_index(names='ID')

In [None]:
# %%time
# get_gnina_rescore_folder('gnina_test/cif',is_local=True)

  0%|          | 0/4 [00:00<?, ?it/s]

CPU times: user 30.3 ms, sys: 28.8 ms, total: 59.1 ms
Wall time: 8.38 s


Unnamed: 0,ID,binding_energy,uncertainty,RMSD,CNNscore,CNNaffinity,CNNvariance
0,test,-10.96345,-1.51405,1.15404,0.50439,7.30706,0.17173
1,test2,-14.18709,-1.21779,0.51255,0.89946,8.61992,0.02119
2,test3,-10.35332,-1.34231,0.53057,0.71652,7.61942,0.25805
3,test4,-14.19527,-1.27426,0.58653,0.92182,8.65907,0.02081


## End

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()