In [1]:
import os, sys, torch, prody
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

self_dir = os.getcwd()
root_dir = os.path.normpath(self_dir + '/..' * 2)
package_dir = os.path.join(root_dir, 'src')
sys.path.append(package_dir)

# from ml_modules.data.datasets import Dataset
from ml_modules.data.enm import TNM_Computer
from ml_modules.data.retrievers import AlphaFold_Retriever



data_dir = '../../data/processed/v7a'

tnm_computer = TNM_Computer()
pdb_retriever = AlphaFold_Retriever()

coupling_types = ['codir', 'coord', 'deform']
edge_types = ['contact'] + coupling_types

vmd_resolution = 1024
v_cmap = plt.get_cmap('plasma')

setup = 'contact_12-codir_1CONT-coord_1CONT-deform_1CONT'

n_entries_to_process = 150


  import pkg_resources


In [2]:

stats_dir = 'stats'
os.makedirs(stats_dir, exist_ok=True)

plots_dir = 'plots'
os.makedirs(plots_dir, exist_ok=True)


et_color = {
    'contact': (0., 0.6039215686274509, 0.8705882352941177),
    'codir': (0., 0.803921568627451, 0.4235294117647059),
    'coord': (0.6862745098039216, 0.34509803921568627, 0.7294117647058823),
    'deform': (1., 0.7764705882352941, 0.11764705882352941),
}

labels = {
    et: f'{et} coupling' if et != 'contact' else 'distance (Ã…)'
    for et in edge_types
}

chain_color = {
    '0A': [0,0,255,255],
    '0B': [255,0,0,255],
    '0C': [0,255,255,255],
}

residue_color = {
    'ASP': [230,10,10,255],
    'GLU': [230,10,10,255],
    'CYS': [230,230,0,255],
    'MET': [230,230,0,255],
    'LYS': [20,90,255,255],
    'ARG': [20,90,255,255],
    'SER': [250,150,0,255],
    'THR': [250,150,0,255],
    'PHE': [50,50,170,255],
    'TYR': [50,50,170,255],
    'ASN': [0,220,220,255],
    'GLN': [0,220,220,255],
    'GLY': [235,235,235,255],
    'LEU': [15,130,15,255],
    'VAL': [15,130,15,255],
    'ILE': [15,130,15,255],
    'ALA': [200,200,200,255],
    'TRP': [180,90,180,255],
    'HIS': [130,130,210,255],
    # 'HSE': [130,130,210,255],
    'PRO': [220,150,130,255],
}

res3_to_res1 = {
    'CYS': 'C',
    'ASP': 'D',
    'SER': 'S',
    'GLN': 'Q',
    'LYS': 'K',
    'ILE': 'I',
    'PRO': 'P',
    'THR': 'T',
    'PHE': 'F',
    'ASN': 'N',
    'GLY': 'G',
    'HIS': 'H',
    #  'HSE': 'H',
    'LEU': 'L',
    'ARG': 'R',
    'TRP': 'W',
    'ALA': 'A',
    'VAL': 'V',
    'GLU': 'E',
    'TYR': 'Y',
    'MET': 'M',
}


In [3]:
src_dir = '../20250526-1 true and baseline vs dynamics performance/stats'

src_list = [
    f'{src_dir}/accessions - largest improvement (baseline 0 vs dynamics 4).csv',
    f'{src_dir}/accessions - largest improvement (baseline 1 vs dynamics 4).csv',
    f'{src_dir}/accessions - largest improvement (baseline 2 vs dynamics 4).csv',
]

accessions_to_process = np.unique(np.concatenate([
    np.loadtxt(
        os.path.abspath(src_file),
        usecols=0,
        delimiter=',',
        dtype=np.str_
    )[:n_entries_to_process] for src_file in src_list
], axis=0))

print('Number of accessions:', len(accessions_to_process))
print(accessions_to_process)


Number of accessions: 326
['A0A077S9N1' 'A0A0R4J009' 'A1Z7A8' 'A2BE93' 'A3FPI6' 'A5PMP1' 'A6H5X4'
 'A7MCK9' 'B2GTW6' 'D3Z645' 'F1RET2' 'F8WIT2' 'G5EFJ5' 'G5EFV5' 'I2HA94'
 'I7FXD9' 'O02373' 'O04420' 'O05264' 'O08908' 'O08914' 'O14508' 'O16406'
 'O17754' 'O22832' 'O23553' 'O31697' 'O31775' 'O32199' 'O34687' 'O34723'
 'O61980' 'O62431' 'O65693' 'O75529' 'O76075' 'O80934' 'O95551' 'P00358'
 'P00903' 'P00950' 'P01123' 'P02394' 'P02929' 'P02930' 'P03813' 'P03899'
 'P07395' 'P07487' 'P08997' 'P09029' 'P09372' 'P09734' 'P0A6J8' 'P0A794'
 'P0A7J3' 'P0A8U0' 'P0A8W0' 'P0A9W6' 'P0AAD6' 'P0AC59' 'P0ACN7' 'P0AEZ3'
 'P0AFE4' 'P0AFX4' 'P0AFX9' 'P0AG40' 'P0AGB6' 'P0CF20' 'P0CZ23' 'P10127'
 'P10674' 'P13056' 'P14802' 'P15565' 'P15624' 'P16371' 'P16989' 'P20042'
 'P21365' 'P23396' 'P23715' 'P25294' 'P27111' 'P28188' 'P28248' 'P28305'
 'P28707' 'P29746' 'P30177' 'P30750' 'P32795' 'P32939' 'P33204' 'P34460'
 'P35178' 'P35659' 'P36771' 'P37051' 'P38646' 'P39522' 'P39814' 'P40005'
 'P40161' 'P40423' 'P42226

In [4]:

pdb_src_dir = '/mnt/hdd/yenlin/data/AlphaFoldDB/pdb'
graph_src_dir = f'/mnt/hdd/yenlin/data/v7a/graphs/{setup}'

pdb_files_missing = []
graph_files_missing = []

for acc_idx, accession in enumerate(accessions_to_process):

    pdb_file = os.path.abspath(pdb_retriever.path_to_file(accession))
    if not os.path.exists(pdb_file):
        pdb_files_missing.append(os.path.basename(pdb_file))

    graph_file = os.path.abspath(f'{data_dir}/graphs/{setup}/{accession}-AFv4.pt')
    if not os.path.exists(graph_file):
        graph_files_missing.append(os.path.basename(graph_file))

print(f'PDB files to retrieve from dp180:{pdb_src_dir}')
print(pdb_files_missing)
print()
print(f'Graph files to retrieve from dp180:{graph_src_dir}')
print(graph_files_missing)


PDB files to retrieve from dp180:/mnt/hdd/yenlin/data/AlphaFoldDB/pdb
[]

Graph files to retrieve from dp180:/mnt/hdd/yenlin/data/v7a/graphs/contact_12-codir_1CONT-coord_1CONT-deform_1CONT
[]


In [5]:

stats_save_dir = f'{stats_dir}/{setup}'
os.makedirs(stats_save_dir, exist_ok=True)

pbar = tqdm(accessions_to_process)
for acc_idx, accession in enumerate(pbar):
    pbar.set_description(accession)

    save_file = f'{stats_save_dir}/metrics - {accession}.csv'
    if os.path.exists(save_file):
        continue

    pdb_file = os.path.abspath(pdb_retriever.path_to_file(accession))
    assert os.path.exists(pdb_file)

    graph_file = os.path.abspath(f'{data_dir}/graphs/{setup}/{accession}-AFv4.pt')
    data = torch.load(graph_file)

    atoms = prody.parsePDB(pdb_file, subset='ca')
    resnames1 = [
        res3_to_res1[resname] for resname in atoms.getResnames()
    ]

    n_nodes = data.num_nodes
    node_list = np.arange(n_nodes, dtype=np.int_)

    graphs = {}
    header = ['node_id', 'resnames']
    all_values = [node_list, resnames1]
    for et in edge_types:

        graphs[et] = nx.Graph()
        graphs[et].add_nodes_from(node_list)

        edge_index = data['residue', et, 'residue'].edge_index
        # remove reverse edges for undirected graphs
        edge_index = edge_index[:, edge_index[0] < edge_index[1]]

        graphs[et].add_edges_from(edge_index.T.tolist())

        assert graphs[et].number_of_nodes() == n_nodes
        assert graphs[et].number_of_edges() == edge_index.shape[1]

        ################################################################
        # COMPUTE METRICS
        ################################################################

        # DEGREE
        degree = np.array(graphs[et].degree())[:,1]
        header.append(f'{et}_degree')
        all_values.append(degree.astype(np.str_))

        # BETWEENNESS
        betweenness = np.array(list(
            nx.betweenness_centrality(graphs[et]).items()
        ))[:,1]
        header.append(f'{et}_betweenness')
        all_values.append([f'{v:.4f}' for v in betweenness])

        # CLOSENESS
        closeness = np.array(list(
            nx.closeness_centrality(graphs[et]).items()
        ))[:,1]
        header.append(f'{et}_closeness')
        all_values.append([f'{v:.4f}' for v in closeness])

        # LAPLACIAN
        laplacian = np.array(list(
            nx.laplacian_centrality(graphs[et]).items()
        ))[:,1]
        header.append(f'{et}_laplacian')
        all_values.append([f'{v:.4f}' for v in laplacian])

        # SUBGRAPH
        subgraph = np.array(list(
            nx.subgraph_centrality(graphs[et]).items()
        ))[:,1]
        header.append(f'{et}_subgraph')
        all_values.append([f'{v:.4f}' for v in subgraph])

        # HARMONIC
        harmonic = np.array(list(
            nx.harmonic_centrality(graphs[et]).items()
        ))[:,1]
        header.append(f'{et}_harmonic')
        all_values.append([f'{v:.4f}' for v in harmonic])

        # PERCOLATION
        percolation = np.array(list(
            nx.percolation_centrality(graphs[et]).items()
        ))[:,1]
        header.append(f'{et}_percolation')
        all_values.append([f'{v:.4f}' for v in percolation])

        # LOAD
        load = np.array(list(
            nx.load_centrality(graphs[et]).items()
        ))[:,1]
        header.append(f'{et}_load')
        all_values.append([f'{v:.4f}' for v in load])

    ### SAVE ALL VALUES
    np.savetxt(
        save_file,
        np.array(all_values).T,
        delimiter=',',
        header=','.join(header),
        fmt='%s'
    )


  0%|          | 0/326 [00:00<?, ?it/s]

  data = torch.load(graph_file)
  data = torch.load(graph_file)
