# AF3 rescore with gnina

## Setup

In [None]:
#| default_exp gnina_af3

In [None]:
#| export
import pandas as pd
import re, os, subprocess, py3Dmol
from Bio.PDB import MMCIFParser, PDBIO, Select
from rdkit import Chem
from rdkit.Chem import AllChem

## Split the AF3 output .cif into protein.pdb and ligand.sdf

In [None]:
#| export
class ChainSelect(Select):
    "Biopython `Select` filter that keeps only chains whose id is found in `chain_ids`"
    def __init__(self, chain_ids):
        # Anything that supports the `in` operator works here (str, list, set, ...)
        self.chain_ids = chain_ids
    def accept_chain(self, chain):
        "Accept `chain` when its id is contained in `chain_ids`"
        wanted = self.chain_ids
        return chain.get_id() in wanted

In [None]:
#| export
def rename_residues(structure, chain_id, new_resname='LIG'):
    "Set the name of every residue in chain `chain_id` to `new_resname` (the AF3 name LIG_L is too long and makes RDKit fail)"
    # Lazily walk every model and keep only the chains carrying the requested id
    matching_chains = (
        chain
        for model in structure
        for chain in model
        if chain.id == chain_id
    )
    for chain in matching_chains:
        for res in chain:
            res.resname = new_resname

In [None]:
#| export
def split_cif(cif_path, # AF3 output .cif file to read
              chainA_pdb_path, # output path for the receptor PDB
              chainL_pdb_path, # output path for the ligand PDB
              protein_chain='A', # chain id of the receptor inside the CIF
              ligand_chain='L', # chain id of the ligand inside the CIF
              ):
    "Split AF3 output CIF to protein and ligand PDBs"
    parser = MMCIFParser(QUIET=True)
    structure = parser.get_structure('complex', cif_path)
    # Ligand residues are renamed to LIG before writing (see `rename_residues`)
    rename_residues(structure, chain_id=ligand_chain, new_resname='LIG')
    io = PDBIO()
    io.set_structure(structure)
    # One-element tuples make the chain filter an exact id match rather than
    # a substring test against the id string.
    io.save(chainA_pdb_path, ChainSelect((protein_chain,)))  # receptor
    io.save(chainL_pdb_path, ChainSelect((ligand_chain,)))  # ligand

In [None]:
#| export
def pdb2sdf(pdb_path, sdf_path):
    "Write the ligand stored in `pdb_path` to `sdf_path`; return None on success, otherwise `pdb_path`"
    ligand = Chem.MolFromPDBFile(pdb_path, sanitize=True, removeHs=False)
    if not ligand:
        # No usable molecule came back: report it and hand the input path to the caller
        print('Conversion failed for:', pdb_path)
        return pdb_path
    sdf_out = Chem.SDWriter(sdf_path)
    sdf_out.write(ligand)
    sdf_out.close()
    return None

In [None]:
#| export
def prepare_rec_lig(cif_path, chainA_pdb_path, chainL_sdf_path):
    "Split AF3 cif to protein.pdb (chainA) and ligand.sdf (chainL); return None on success, else `cif_path`"
    import tempfile  # only needed for the intermediate ligand PDB

    # A unique temporary file avoids clobbering a user file named `tmp_lig.pdb`
    # and keeps concurrent calls from overwriting each other's intermediate.
    fd, tmp = tempfile.mkstemp(prefix='tmp_lig_', suffix='.pdb')
    os.close(fd)  # the file is reopened by path when the ligand PDB is written
    try:
        split_cif(cif_path, chainA_pdb_path, tmp)
        failed = pdb2sdf(tmp, chainL_sdf_path)
    finally:
        # Always clean up the intermediate, even when splitting/conversion raises
        try:
            os.remove(tmp)
        except OSError:
            pass
    # The intermediate path no longer exists at this point, so report the
    # input CIF as the failing item instead.
    return cif_path if failed else None

In [None]:
prepare_rec_lig('test.cif','chain_A.pdb','chain_L.sdf')

## gnina score

According to the [gnina documentation](https://github.com/gnina/gnina?tab=readme-ov-file), a ligand pose can be minimized and scored against its receptor with:

```bash
gnina -r chain_A.pdb -l chain_L.sdf --minimize -o minimized.sdf.gz
```

In [None]:
#| export
def gnina_rescore(protein_pdb, # receptor file
                  ligand_sdf, # ligand file
                  gnina_path='./gnina', # gnina executable (path, or a name found on PATH)
                  ):
    "Run `gnina --minimize` on a receptor/ligand pair and return the captured stdout text"
    command = [gnina_path,
               '-r', protein_pdb,
               '-l', ligand_sdf,
               '--minimize']

    result = subprocess.run(command, capture_output=True, text=True)
    if result.returncode != 0:
        # Surface the failure instead of silently dropping stderr; stdout is
        # still returned so existing callers keep working.
        print(f'gnina exited with code {result.returncode}:', result.stderr.strip())
    return result.stdout

In [None]:
# out = gnina_rescore('chain_A.pdb','chain_L.sdf')

In [None]:
# out



In [None]:
#| export
def extract_gnina_output(txt):
    """Extract GNINA output text to dictionary.

    Searches `txt` for the first `Affinity:` line followed by the `RMSD:`,
    `CNNscore:`, `CNNaffinity:` and `CNNvariance:` values and returns them as
    floats under the keys `binding_energy`, `uncertainty`, `RMSD`, `CNNscore`,
    `CNNaffinity` and `CNNvariance`.

    Raises `ValueError` when the expected fields cannot be found.

    NOTE(review): the second number on the `Affinity:` line is stored under
    `uncertainty`; it may actually be gnina's intramolecular energy - confirm
    against the gnina documentation before interpreting it.
    """
    # A well-formed decimal number with optional sign and exponent, so that
    # stray '-' or '.' characters are not captured and values such as 1e-05
    # are still recognised.
    num = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"

    match = re.search(
        rf"Affinity:\s+(?P<binding_energy>{num})\s+(?P<uncertainty>{num}).*?"
        rf"RMSD:\s+(?P<RMSD>{num}).*?"
        rf"CNNscore:\s+(?P<CNNscore>{num}).*?"
        rf"CNNaffinity:\s+(?P<CNNaffinity>{num}).*?"
        rf"CNNvariance:\s+(?P<CNNvariance>{num})",
        txt,
        re.DOTALL)  # DOTALL lets `.*?` span the line breaks between fields

    if match is None:
        # Fail with an explicit error rather than an AttributeError on None
        raise ValueError("Failed to match GNINA output format.")

    return {k: float(v) for k, v in match.groupdict().items()} # convert values to float

In [None]:
out = "              _             \n             (_)            \n   __ _ _ __  _ _ __   __ _ \n  / _` | '_ \\| | '_ \\ / _` |\n | (_| | | | | | | | | (_| |\n  \\__, |_| |_|_|_| |_|\\__,_|\n   __/ |                    \n  |___/                     \n\ngnina  master:e9cb230+   Built Feb 11 2023.\ngnina is based on smina and AutoDock Vina.\nPlease cite appropriately.\n\nWARNING: No GPU detected. CNN scoring will be slow.\nRecommend running with single model (--cnn crossdock_default2018)\nor without cnn scoring (--cnn_scoring=none).\n\nCommandline: ./gnina -r chain_A.pdb -l chain_L.sdf --minimize\nAffinity: -10.96345  -1.51405 (kcal/mol)\nRMSD: 1.15404\nCNNscore: 0.49978 \nCNNaffinity: 7.32008\nCNNvariance: 0.18500\n"

In [None]:
extract_gnina_output(out)

{'binding_energy': -10.96345,
 'uncertainty': -1.51405,
 'RMSD': 1.15404,
 'CNNscore': 0.49978,
 'CNNaffinity': 7.32008,
 'CNNvariance': 0.185}

In [None]:
#| export
def get_gnina_rescore(protein_pdb,ligand_sdf):
    "Run gnina on the given receptor and ligand files and return the parsed scores as a dict"
    # Use the caller's files; previously the paths were hard-coded and the
    # arguments were ignored.
    out = gnina_rescore(protein_pdb, ligand_sdf)
    return extract_gnina_output(out)

In [None]:
# get_gnina_rescore('chain_A.pdb','chain_L.sdf')

{'binding_energy': -10.96345,
 'uncertainty': -1.51405,
 'RMSD': 1.15404,
 'CNNscore': 0.49978,
 'CNNaffinity': 7.32008,
 'CNNvariance': 0.185}

## End

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()