In [1]:
import glob
import os

import numpy as np
import pandas as pd
from Bio.PDB.Atom import Atom
from Bio.PDB.NeighborSearch import NeighborSearch
from Bio.PDB.PDBParser import PDBParser

In [2]:
## sites located at the interface in test dataset (anno)
# NOTE(review): PROJECT_DIR is never defined in this notebook; it must already exist in
# the kernel (e.g. from a config cell) or this cell fails on Restart & Run All.
df_anno = pd.read_table(f"{PROJECT_DIR}/dataset/transform/test_metalnet.tsv")
df_interface = pd.read_table(f"{PROJECT_DIR}/dataset/collect/analysis/sites_interface_multi_num.tsv")
site_cols = ['pdb', 'metal_chain', 'metal_pdb_seq_num']
# (pdb, metal_chain, metal_pdb_seq_num) keys of metal sites located at an interface
inter_sites = set(zip(*(df_interface[c] for c in site_cols)))
# Build the mask from zipped columns instead of `apply(axis=1)`, which materialises a
# Series for every row just to read three fields.
is_inter_site = pd.Series(
    [site in inter_sites for site in zip(*(df_anno[c] for c in site_cols))],
    index=df_anno.index,
    dtype=bool,
)
df_anno = df_anno[is_inter_site]
del df_interface, is_inter_site

## predicted as true in test dataset
# NOTE(review): unlike the PROJECT_DIR paths above, this path is relative to the
# notebook's working directory — confirm it points at the intended file.
df_pred = pd.read_table("../../pred_pairs.tsv")
df_pred = df_pred[df_pred['filter_by_graph'] == 1]

## How many true interface residues are successfully predicted?

In [3]:
# (seq_id, position) pairs annotated as interface residues
true_residues = {(seq_id, posi) for seq_id, posi in zip(df_anno['seq_id'], df_anno['resi_domain_posi'])}
# every residue that appears on either side of a kept predicted pair
pred_residues = {
    (seq_id, posi)
    for seq_id, posi_1, posi_2 in zip(df_pred['seq_id'], df_pred['resi_seq_posi_1'], df_pred['resi_seq_posi_2'])
    for posi in (posi_1, posi_2)
}
intersection = pred_residues & true_residues

In [7]:
n_inter_sites = len(set(zip(df_anno['pdb'], df_anno['metal_chain'], df_anno['metal_pdb_seq_num'])))
n_pred_hit, n_true = len(intersection), len(true_residues)
print("Number of metal sites located at interface: ", n_inter_sites)
print(f"Residues located at interface: {n_pred_hit}  / {n_true} (predicted / true)")
intersection

Number of metal sites located at interface:  24
Residues located at interface: 15  / 67 (predicted / true)


{('3a6v_B', 52),
 ('3a6v_B', 56),
 ('5gox_A', 100),
 ('5gox_A', 103),
 ('6hbe_C', 79),
 ('6hbe_C', 266),
 ('6j27_C', 117),
 ('6j27_C', 120),
 ('6sev_A', 24),
 ('6sev_A', 51),
 ('6sev_A', 55),
 ('7ukh_D', 104),
 ('7ukh_D', 110),
 ('7ukh_D', 131),
 ('7ukh_D', 132)}

## How many predicted sites can be modeled at a homomer interface?

In [8]:
## proteins with homomer (modeled, four species, from 'An atlas of protein homo-oligomerization across domains of life')
df = pd.read_table("~/database/pdb/homomer/homomer_four_species.csv")
# organism code in the table -> species label (used later to pick the `<species>_af2` directory)
sp_abbr = {'hs': 'human', 'sc': 'yeast', 'ec': 'ecoli', 'pf': 'pfuri'}
# an organism code missing from sp_abbr raises KeyError on purpose
uniprot_to_species = {
    code: sp_abbr[org]
    for code, org in zip(df['code_sub'], df['org'])
}
del df

  df = pd.read_table("~/database/pdb/homomer/homomer_four_species.csv")


In [9]:
## pdb to uniprot
df = pd.read_table("~/database/uniprot/idmapping/pdb2uniprot.csv")
# add human proteins in test_pred (no yeast, ecoli, pfuri proteins found)
extra = {
    ('8hmq', 'B'): 'P14618',
    ('8ba5', 'A'): 'Q9H3E2',
}
# (pdb, chain) -> UniProt accession; entries in `extra` take precedence over the table
pdb_to_uniprot = {
    **{(pdb_id, chain_id): acc for pdb_id, chain_id, acc in zip(df['pdb'], df['chain'], df['uniprot'])},
    **extra,
}
del df

In [12]:
def calc_avg_plddt(
    id: str,
    pdb_file: str
):
    """Average the CA-atom B-factor over every residue in ``pdb_file``.

    For AlphaFold models the B-factor column holds the per-residue pLDDT, so this
    is the model's mean pLDDT.

    Args:
        id: Structure id handed to the PDB parser (only used as a label).
        pdb_file: Path of the PDB file to read.

    Returns:
        ``np.average`` of the CA B-factors (raises ``KeyError`` if a residue has no CA).
    """
    structure = PDBParser(QUIET=True).get_structure(id, pdb_file)
    ca_bfactors = [residue['CA'].get_bfactor() for residue in structure.get_residues()]
    return np.average(ca_bfactors)

path = "~/database/pdb/species"
def get_af2_pdb(
    uniprot: str,
    species: str,
    root: "str | None" = None,
):
    """Find AlphaFold2 model files for ``uniprot`` under ``<root>/<species>_af2``.

    The species directory is searched recursively for entries whose name contains
    ``uniprot``. ``root`` defaults to the module-level ``path``; ``~`` is expanded.

    Args:
        uniprot: UniProt accession (substring-matched against file names).
        species: Species label selecting the ``<species>_af2`` sub-directory.
        root: Optional override of the search root; ``None`` uses ``path``.

    Returns:
        Sorted list of matching paths (empty when nothing matches).
    """
    # Use glob instead of shelling out to `find ... -name *X*`. An unquoted glob can be
    # expanded by the shell against the current directory before `find` sees it, and
    # interpolated arguments are open to shell injection. `find` output order is also
    # unspecified, whereas this result is sorted.
    species_dir = os.path.join(os.path.expanduser(path if root is None else root), f"{species}_af2")
    pattern = os.path.join(glob.escape(species_dir), "**", f"*{glob.escape(uniprot)}*")
    return sorted(glob.glob(pattern, recursive=True))

In [13]:
## predicted as true in test dataset, and has homomers
def _lookup_homomer_species(pdb, chain):
    """Map a PDB chain to ``(uniprot, species)`` when it has a homomer entry, else ``None``.

    Reads the notebook-level ``pdb_to_uniprot`` and ``uniprot_to_species`` mappings.
    """
    if (pdb, chain) not in pdb_to_uniprot:
        return None
    uniprot = pdb_to_uniprot[(pdb, chain)]
    if uniprot not in uniprot_to_species:
        return None
    return uniprot, uniprot_to_species[uniprot]

# Many predicted pairs belong to the same protein; parse each AF2 model once instead
# of once per pair (the average pLDDT depends only on the model file).
avg_plddt_by_uniprot = {}
records = []
for _, row in df_pred.iterrows():
    pdb, chain = row['seq_id'].split("_") # no split domains
    hit = _lookup_homomer_species(pdb, chain)
    if hit is None:
        continue
    uniprot, sp = hit
    if uniprot not in avg_plddt_by_uniprot:
        pdb_files = get_af2_pdb(uniprot, sp)
        assert len(pdb_files) == 1, f"expected one AF2 model for {uniprot}, got {pdb_files}" # no fragments
        avg_plddt_by_uniprot[uniprot] = calc_avg_plddt(row['seq_id'], pdb_files[0])
    records.append({
        **row,
        "uniprot": uniprot,
        "species": sp,
        "avg_plddt": avg_plddt_by_uniprot[uniprot],
    })
# Explicit columns keep the expected schema even if no prediction has a homomer.
df_pred_with_homomers = pd.DataFrame(
    records,
    columns=list(dict.fromkeys([*df_pred.columns, "uniprot", "species", "avg_plddt"])),
)
del records, avg_plddt_by_uniprot

## labeled as true interface sites, and has homomers
records = []
for _, row in df_anno.iterrows():
    assert row['domain'] == " "
    hit = _lookup_homomer_species(row['pdb'], row['metal_chain'])
    if hit is not None:
        uniprot, sp = hit
        records.append({
            **row,
            "uniprot": uniprot,
            "species": sp,
        })
df_anno_with_homomers = pd.DataFrame(
    records,
    columns=list(dict.fromkeys([*df_anno.columns, "uniprot", "species"])),
)
len(df_anno_with_homomers)

13

In [14]:
def get_homomer_protein(
    uniprot: str,
    root: str = "~/database/pdb/homomer/AF_dimer_models_full_length_relaxed", # NOTE: we consider dimer here
) -> str:
    """Return the path of the AlphaFold dimer model whose file name starts with ``uniprot``.

    Args:
        uniprot: UniProt accession used as the file-name prefix.
        root: Directory holding the dimer models (``~`` is expanded).

    Returns:
        The first matching path in sorted order.

    Raises:
        FileNotFoundError: if no file under ``root`` starts with ``uniprot``.
    """
    # glob instead of `ls` through a shell: arguments are not shell-parsed, and a missing
    # model raises here rather than silently returning "" to the structure parser.
    # NOTE(review): prefix matching also accepts names like "<uniprot>-2..." — confirm the
    # directory's naming scheme cannot produce such collisions.
    pattern = os.path.join(glob.escape(os.path.expanduser(root)), glob.escape(uniprot) + "*")
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError(f"no homodimer model matching {uniprot}* in {root}")
    return matches[0]

def get_inter_residues(
    uniprot: str,
    pdb_file: str,
    inter_threshold: float = 8.,
    neighbor_num_threshold: int = 2,
):
    """Find residues of a multi-chain model that sit at a chain-chain interface.

    A residue counts as interface when at least ``neighbor_num_threshold`` residues of a
    *different* chain have their CA atom within ``inter_threshold`` of its CA atom
    (CA-only neighbor search).

    Args:
        uniprot: Accession used as the structure id and as the first element of each key.
        pdb_file: Path of the multi-chain model.
        inter_threshold: Search radius around each CA, in the file's coordinate units.
        neighbor_num_threshold: Minimum number of other-chain neighbor residues.

    Returns:
        Set of ``(uniprot, residue_number - 1)`` tuples; residues from all chains map
        onto the same keys.
    """
    structure = PDBParser(QUIET=True).get_structure(uniprot, pdb_file)
    residues = list(structure.get_residues())
    # CA-only search index (raises KeyError if some residue has no CA atom)
    search = NeighborSearch([res['CA'] for res in residues])

    interface = set()
    for res in residues:
        own_chain = res.get_full_id()[2]
        center = res['CA'].get_vector().get_array()
        n_cross_chain = sum(
            1
            for neighbor in search.search(center, inter_threshold, "R")
            if neighbor.get_full_id()[2] != own_chain
        )
        if n_cross_chain >= neighbor_num_threshold:
            # NOTE(review): the -1 presumably converts 1-based model numbering to the
            # 0-based positions used by the predictions — confirm.
            interface.add((uniprot, res.get_full_id()[3][1] - 1))
    return interface

def get_pred_inter_residues(
    df_pred: pd.DataFrame,
    inter_threshold: float = 8.,
    neighbor_num_threshold: int = 2,
    plddt_threshold: int = 70
):
    """Predicted residues that also lie on the modeled homodimer interface.

    Rows with ``avg_plddt`` below ``plddt_threshold`` are dropped. Both residues of each
    remaining pair are keyed as ``(uniprot, position)``. These keys are intersected with
    the interface residues that ``get_inter_residues`` finds in each protein's dimer model.

    Args:
        df_pred: Predictions with ``uniprot``, ``avg_plddt``, ``resi_seq_posi_1`` and
            ``resi_seq_posi_2`` columns.
        inter_threshold: Forwarded to ``get_inter_residues``.
        neighbor_num_threshold: Forwarded to ``get_inter_residues``.
        plddt_threshold: Minimum (inclusive) ``avg_plddt`` for a prediction to be kept.

    Returns:
        Set of ``(uniprot, position)`` tuples.
    """
    confident = df_pred[df_pred['avg_plddt'] >= plddt_threshold]

    # NOTE(review): these positions come from the PDB-chain sequence (seq_id), while the
    # interface positions come from the full-length model's residue numbering — confirm
    # the two numberings agree for every protein.
    pred_residues = {
        (uniprot, posi)
        for uniprot, posi_1, posi_2 in zip(
            confident['uniprot'], confident['resi_seq_posi_1'], confident['resi_seq_posi_2']
        )
        for posi in (posi_1, posi_2)
    }

    inter_residues = set()
    for uniprot in set(confident['uniprot']):
        homomer_file = get_homomer_protein(uniprot)
        inter_residues.update(
            get_inter_residues(uniprot, homomer_file, inter_threshold, neighbor_num_threshold)
        )

    return pred_residues & inter_residues


def calc_metrics(
    true_residues: set,
    pred_residues: set,
):
    """Set-based precision, recall and F1 of ``pred_residues`` against ``true_residues``.

    Each metric falls back to 0 when its denominator would be zero.

    Returns:
        Dict with keys ``'precision'``, ``'recall'``, ``'f1'`` (in that order).
    """
    n_hit = len(true_residues & pred_residues)
    recall = n_hit / len(true_residues) if true_residues else 0
    precision = n_hit / len(pred_residues) if pred_residues else 0
    denom = recall + precision
    f1 = 2 * recall * precision / denom if denom else 0
    return {'precision': precision, 'recall': recall, 'f1': f1}

In [15]:
# (uniprot, position) keys of annotated interface residues in proteins with a homomer model
anno_inter_residues = {
    (uniprot, posi)
    for uniprot, posi in zip(df_anno_with_homomers['uniprot'], df_anno_with_homomers['resi_domain_posi'])
}

# parameter grid scanned below
inter_thresholds = list(range(4, 11, 2))       # [4, 6, 8, 10]: neighbor-search radius
neighbor_num_thresholds = list(range(1, 4))    # [1, 2, 3]: minimum other-chain neighbors
plddt_thresholds = list(range(60, 100, 10))    # [60, 70, 80, 90]: minimum average pLDDT

In [16]:
from itertools import product

# one metrics row per (inter_threshold, neighbor_num_threshold, plddt_threshold) combination
records = []
for i, n, p in product(inter_thresholds, neighbor_num_thresholds, plddt_thresholds):
    pred_inter_residues = get_pred_inter_residues(df_pred_with_homomers, i, n, p)
    record = {
        **calc_metrics(anno_inter_residues, pred_inter_residues),
        "inter_threshold": i,
        "neighbor_num_threshold": n,
        "plddt_threshold": p,
    }
    records.append(record)

In [17]:
df_result = pd.DataFrame(records)
# label of the first row with the highest F1 (same as its position: default RangeIndex)
best_idx = df_result['f1'].idxmax()
df_result.loc[best_idx]

precision                  0.500000
recall                     0.307692
f1                         0.380952
inter_threshold            8.000000
neighbor_num_threshold     1.000000
plddt_threshold           70.000000
Name: 25, dtype: float64

In [18]:
# Annotated interface residues next to the predicted interface residues for
# inter_threshold=8, neighbor_num_threshold=1, plddt_threshold=60.
# NOTE(review): the best-F1 row reported above used plddt_threshold=70, not 60 — confirm
# which setting is meant to be inspected here.
# NOTE(review): both sets are displayed only if IPython's ast_node_interactivity is set to
# "all" somewhere; by default only the last expression of a cell is shown.
anno_inter_residues
get_pred_inter_residues(df_pred_with_homomers, 8, 1, 60)

{('P31415', 127),
 ('P31415', 134),
 ('P31415', 258),
 ('P31415', 260),
 ('P31415', 263),
 ('P31415', 329),
 ('P31415', 330),
 ('Q08499', 79),
 ('Q08499', 81),
 ('Q9NZV8', 104),
 ('Q9NZV8', 110),
 ('Q9NZV8', 131),
 ('Q9NZV8', 132)}

{('A2RUC4', 182),
 ('A2RUC4', 184),
 ('P04183', 185),
 ('P35914', 238),
 ('Q08499', 243),
 ('Q9NZV8', 104),
 ('Q9NZV8', 110),
 ('Q9NZV8', 131),
 ('Q9NZV8', 132)}