In [14]:
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from biopandas.pdb import PandasPdb
from prody import parsePDBHeader
from tqdm import tqdm
from utils import read_pdb_to_dataframe


In [15]:
amino_acid_dict = {
    "ALA": "A",
    "CYS": "C",
    "ASP": "D",
    "GLU": "E",
    "PHE": "F",
    "GLY": "G",
    "HIS": "H",
    "ILE": "I",
    "LYS": "K",
    "LEU": "L",
    "MET": "M",
    "ASN": "N",
    "PRO": "P",
    "GLN": "Q",
    "ARG": "R",
    "SER": "S",
    "THR": "T",
    "VAL": "V",
    "TRP": "W",
    "TYR": "Y",
}


In [67]:
def get_binding_residues(
    df_antibody: pd.DataFrame, df_antigen: pd.DataFrame
) -> Dict[str, Dict[int, str]]:
    """Return dictionary with imgt positions mapping to aa name, and list of distances to antigen.

    Args:
        df_antibody (pd.DataFrame): Dataframe of antibody chain with columns taken from pdb file.
        df_antigen (pd.DataFrame): Dataframe of antigen chain with columns taken from pdb file.

    Returns:
        Dict[str, Dict[int, str]]: Dictionary with imgt positions mapping to aa name, and list \
            of distances to antigen
    """

    antibody_coords = df_antibody[["x", "y", "z"]].astype(float).to_numpy()
    antigen_coords = df_antigen[["x", "y", "z"]].astype(float).to_numpy()

    # Compute pairwise distances

    values = df_antibody[["AA", "Res_Num"]].values
    distances = np.linalg.norm(antibody_coords[:, np.newaxis] - antigen_coords, axis=2)
    amino_dict = {"positions": {}, "distances": {}}

    for row_ab, (AA, Res_Num) in enumerate(values):
        # print(AA,Res_Num)
        if Res_Num=="nan":
            raise ValueError(Res_Num)
        amino_dict["positions"][Res_Num] = amino_acid_dict[AA]
        if Res_Num not in amino_dict["distances"]:
            amino_dict["distances"][Res_Num] = [np.min(distances[row_ab, :])]

        else:
            amino_dict["distances"][Res_Num].append(np.min(distances[row_ab, :]))
    return amino_dict


In [68]:
def format_pdb(pdb_file):
    '''
    Process pdb file into pandas df

    Original author: Alissa Hummer

    :param pdb_file: file path of .pdb file to convert
    :returns: df with atomic level info
    '''

    pd.options.mode.chained_assignment = None
    pdb_whole = pd.read_csv(pdb_file,header=None,delimiter='\t')
    pdb_whole.columns = ['pdb']
    pdb = pdb_whole[pdb_whole['pdb'].str.startswith('ATOM')]
    pdb['Atom_Num'] = pdb['pdb'].str[6:11].copy()
    pdb['Atom_Name'] = pdb['pdb'].str[11:16].copy()
    pdb['AA'] = pdb['pdb'].str[17:20].copy()
    pdb['Chain'] = pdb['pdb'].str[20:22].copy()
    pdb['Res_Num'] = pdb['pdb'].str[22:27].copy().str.strip()
    pdb['x'] = pdb['pdb'].str[27:38].copy()
    pdb['y'] = pdb['pdb'].str[38:46].copy()
    pdb['z'] = pdb['pdb'].str[46:54].copy()#
    pdb['Atom_type'] = pdb['pdb'].str[77].copy()
    pdb.drop('pdb',axis=1,inplace=True)
    pdb.replace({' ':''}, regex=True, inplace=True)
    pdb.reset_index(inplace=True)
    pdb.drop('index',axis=1,inplace=True)

    # remove H atoms from our data (interested in heavy atoms only)
    pdb = pdb[pdb['Atom_type']!='H']

    return pdb


In [104]:
df = format_pdb("/home/gathenes/all_structures/imgt/1bvk.pdb")


In [105]:
print(df.columns)


Index(['Atom_Num', 'Atom_Name', 'AA', 'Chain', 'Res_Num', 'x', 'y', 'z',
       'Atom_type'],
      dtype='object')


In [107]:
df_chain_heavy = df.query("Chain == 'A'")
df_chain_light = df.query("Chain == 'B'")
df_chain_antigen = df.query("Chain.isin(['C'])")


In [108]:
print(df_chain_antigen)


     Atom_Num Atom_Name   AA Chain Res_Num       x       y       z Atom_type
0           1         N  LYS     C       1  28.475  91.225  65.120         N
1           2        CA  LYS     C       1  29.068  90.365  66.170         C
2           3         C  LYS     C       1  28.160  89.218  66.450         C
3           4         O  LYS     C       1  27.013  89.191  66.023         O
4           5        CB  LYS     C       1  29.274  91.178  67.454         C
...       ...       ...  ...   ...     ...     ...     ...     ...       ...
996       997        CB  LEU     C     129  12.071  87.273  77.569         C
997       998        CG  LEU     C     129  12.758  86.136  76.780         C
998       999       CD1  LEU     C     129  11.740  85.152  76.219         C
999      1000       CD2  LEU     C     129  13.589  86.731  75.647         C
1000     1001       OXT  LEU     C     129   9.892  85.376  79.471         O

[1001 rows x 9 columns]


In [None]:
abr_dict_heavy = get_binding_residues(df_chain_heavy, df_chain_antigen)


In [None]:
print(abr_dict_heavy)


{'positions': {'1': 'E', '2': 'V', '3': 'Q', '4': 'L', '5': 'V', '6': 'E', '7': 'S', '8': 'G', '9': 'G', '11': 'G', '12': 'L', '13': 'V', '14': 'Q', '15': 'P', '16': 'G', '17': 'G', '18': 'S', '19': 'L', '20': 'R', '21': 'L', '22': 'S', '23': 'C', '24': 'A', '25': 'A', '26': 'S', '27': 'G', '28': 'F', '29': 'D', '30': 'I', '35': 'Y', '36': 'D', '37': 'D', '38': 'D', '39': 'I', '40': 'H', '41': 'W', '42': 'V', '43': 'R', '44': 'Q', '45': 'A', '46': 'P', '47': 'G', '48': 'K', '49': 'G', '50': 'L', '51': 'E', '52': 'W', '53': 'V', '54': 'A', '55': 'Y', '56': 'I', '57': 'A', '58': 'P', '59': 'S', '62': 'Y', '63': 'G', '64': 'Y', '65': 'T', '66': 'D', '67': 'Y', '68': 'A', '69': 'D', '70': 'S', '71': 'V', '72': 'K', '74': 'G', '75': 'R', '76': 'F', '77': 'T', '78': 'I', '79': 'S', '80': 'A', '81': 'D', '82': 'T', '83': 'S', '84': 'K', '85': 'N', '86': 'T', '87': 'A', '88': 'Y', '89': 'L', '90': 'Q', '91': 'M', '92': 'N', '93': 'S', '94': 'L', '95': 'R', '96': 'A', '97': 'E', '98': 'D', '99'

In [None]:
from utils import get_labels


In [None]:
labels_heavy_4_5, sequence_heavy, numbers_heavy = get_labels(abr_dict_heavy["positions"], abr_dict_heavy["distances"], alpha=4.5)


In [None]:
print(labels_heavy_4_5, sequence_heavy, numbers_heavy)
print(len(labels_heavy_4_5), len(sequence_heavy), len(numbers_heavy))


[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] ['E', 'V', 'Q', 'L', 'V', 'E', 'S', 'G', 'G', 'G', 'L', 'V', 'Q', 'P', 'G', 'G', 'S', 'L', 'R', 'L', 'S', 'C', 'A', 'A', 'S', 'G', 'F', 'D', 'I', 'Y', 'D', 'D', 'D', 'I', 'H', 'W', 'V', 'R', 'Q', 'A', 'P', 'G', 'K', 'G', 'L', 'E', 'W', 'V', 'A', 'Y', 'I', 'A', 'P', 'S', 'Y', 'G', 'Y', 'T', 'D', 'Y', 'A', 'D', 'S', 'V', 'K', 'G', 'R', 'F', 'T', 'I

In [123]:
def rec_dd():
    return defaultdict(rec_dd)
def build_dictionary(pdbs_and_chain:pd.DataFrame)->Dict:
    """Transform dataframe with pdb codes and heavy and light chain names into Dictionary with \
        indices mapping to heavy and light lists of matching imgt numbers, sequences and labels.

    Args:
        pdbs_and_chain (pd.DataFrame): Dataframe with pdb codes and heavy and light chain names.

    Returns:
        Dict: Dictionary with indices mapping to heavy and light lists of matching imgt numbers, \
            sequences and labels.
    """
    dataset_dict = rec_dd()
    for index in tqdm(range(len(pdbs_and_chain))):
        pdb_code = pdbs_and_chain.iloc[index]["pdb"]
        if pdb_code=='2ltq':
            continue
        H_id = pdbs_and_chain.iloc[index]["Hchain"]
        L_id = pdbs_and_chain.iloc[index]["Lchain"]
        antigen_id = pdbs_and_chain.iloc[index]["antigen_chain"]
        df = format_pdb(f"/home/gathenes/all_structures/imgt/{pdb_code}.pdb")

        df_chain_heavy = df.query("Chain == @H_id")
        df_chain_light = df.query("Chain == @L_id")
        antigen_ids=antigen_id.split(";")
        df_chain_antigen = df.query("Chain.isin(@antigen_ids)")
        abr_dict_heavy = get_binding_residues(df_chain_heavy, df_chain_antigen)
        labels_heavy_4_5, sequence_heavy, numbers_heavy = get_labels(abr_dict_heavy["positions"], abr_dict_heavy["distances"], alpha=4.5)
        abr_dict_light = get_binding_residues(df_chain_light, df_chain_antigen)
        labels_light_4_5, sequence_light, numbers_light = get_labels(abr_dict_light["positions"], abr_dict_light["distances"], alpha=4.5)

        dataset_dict[index]["pdb_code"]=pdb_code
        inverse_number_heavy = {each : i for i,each in enumerate(numbers_heavy)}
        inverse_number_light = {each : i for i,each in enumerate(numbers_light)}
        left, right=1, 128
        while str(left) not in inverse_number_heavy :
            left+=1
        while str(right) not in inverse_number_heavy:
            right-=1
        heavy_left, heavy_right = inverse_number_heavy[str(left)], inverse_number_heavy[str(right)]
        left, right=1, 128
        while str(left) not in inverse_number_light :
            left+=1
        while str(right) not in inverse_number_light:
            right-=1
        light_left, light_right = inverse_number_light[str(left)], inverse_number_light[str(right)]

        labels_heavy_4_5, sequence_heavy, numbers_heavy = labels_heavy_4_5[heavy_left:heavy_right+1], sequence_heavy[heavy_left:heavy_right+1], numbers_heavy[heavy_left:heavy_right+1]
        labels_light_4_5, sequence_light, numbers_light = labels_light_4_5[light_left:light_right+1], sequence_light[light_left:light_right+1], numbers_light[light_left:light_right+1]

        dataset_dict[index]["H_id numbers"]=numbers_heavy
        dataset_dict[index]["L_id numbers"]=numbers_light

        dataset_dict[index]["H_id sequence"] = "".join(sequence_heavy)
        dataset_dict[index]["H_id labels 4.5"] = labels_heavy_4_5
        dataset_dict[index]["L_id sequence"] = "".join(sequence_light)
        dataset_dict[index]["L_id labels 4.5"] = labels_light_4_5

        for alpha in [3,3.5,4,5,5.5,6,6.5,7,7.5]:
            labels_heavy, _, _ = get_labels(abr_dict_heavy["positions"], abr_dict_heavy["distances"], alpha=alpha)
            labels_light, _, _ = get_labels(abr_dict_light["positions"], abr_dict_light["distances"], alpha=alpha)
            labels_heavy = labels_heavy[heavy_left:heavy_right+1]
            labels_light = labels_light[light_left:light_right+1]
            dataset_dict[index][f"H_id labels {alpha}"] = labels_heavy
            dataset_dict[index][f"L_id labels {alpha}"] = labels_light
    return dataset_dict


In [124]:
test=pd.read_csv("/home/gathenes/paragraph_benchmark/expanded_dataset/test_set.csv")


In [125]:
dictionary = build_dictionary(test)


  0%|          | 0/218 [00:00<?, ?it/s]

100%|██████████| 218/218 [01:54<00:00,  1.91it/s]


In [132]:
for each in dictionary.keys():
    print(each,dictionary[each]["pdb_code"])


0 1bvk
1 1egj
2 1fe8
3 1fj1
4 1fpt
5 1h0d
6 1hh6
7 1kb5
8 1kcs
9 1nfd
10 1nmc
11 1pkq
12 1qkz
13 1tet
14 1tzh
15 1v7m
16 1wej
17 1ztx
18 2dd8
19 2h1p
20 2jel
21 2nyy
22 2qsc
23 2r29
24 2zch
25 3fn0
26 3hmx
27 3hr5
28 3i50
29 3ifp
30 3l95
31 3o6l
32 3r1g
33 3s37
34 3sge
35 3t2n
36 3uc0
37 3uji
38 4bz1
39 4cni
40 4edw
41 4etq
42 4f37
43 4gms
44 4hha
45 4i2x
46 4i3r
47 4i77
48 4j6r
49 4jb9
50 4jkp
51 4jzj
52 4m8q
53 4n0y
54 4np4
55 4nrx
56 4om0
57 4onf
58 4qci
59 4qyo
60 4rgn
61 4u6v
62 4uu9
63 4wht
64 4xxd
65 4ydj
66 4yk4
67 4yue
68 5anm
69 5b71
70 5bk1
71 5bk2
72 5c0n
73 5csz
74 5d96
75 5do2
76 5e2w
77 5e94
78 5ggs
79 5hdq
80 5ijk
81 5iq9
82 5jq6
83 5k9q
84 5kvf
85 5l0q
86 5ldn
87 5mhr
88 5mv3
89 5myk
90 5myx
91 5n7w
92 5nph
93 5o4g
94 5t5f
95 5t6p
96 5te6
97 5tud
98 5usi
99 5v6l
100 5vag
101 5vic
102 5vpl
103 5vta
104 5vzy
105 5w08
106 5w0d
107 5w3p
108 5wna
109 5xcr
110 5xwd
111 5y2l
112 6a0z
113 6apb
114 6axk
115 6b08
116 6b0s
117 6ba5
118 6bit
119 6bqb
120 6c9u
121 6cdm
122 6cdp
123

In [133]:
print(each,dictionary[14])


217 defaultdict(<function rec_dd at 0x7f8010df2820>, {'pdb_code': '1tzh', 'H_id numbers': ['1', '2', '3', '4', '5', '6', '7', '8', '9', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99', '100', '101', '102', '103', '104', '105', '106', '107', '108', '109', '110', '111', '111A', '112A', '112', '113', '114', '115', '116', '117', '118', '119', '120', '121', '122', '123', '124', '125', '126', '127', '128'], 'L_id numbers': ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', 