<a href="https://colab.research.google.com/github/xy2119/SO3_Invariant_ProteinMPNN/blob/main/notebooks/find_pdb_domain.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import requests
import json
import pandas as pd
from typing import List
import time
import glob

In [None]:
# Load the per-chain listing and add a PROTEIN column holding the first four
# characters of CHAINID in upper case (e.g. "5naf_A" -> "5NAF").
df = pd.read_csv('/content/pdb_2021aug02/list.csv').assign(
    PROTEIN=lambda chains: chains['CHAINID'].str[:4].str.upper()
)

In [None]:
# Show the loaded table as the cell output (the notebook renders a truncated view).
df

Unnamed: 0,CHAINID,DEPOSITION,RESOLUTION,HASH,CLUSTER,SEQUENCE,PROTEIN
0,5naf_A,2017-02-27,2.493,57113,12123,MGSSHHHHHHSSGLEVLFQGPEENGAHTIANNHTDMMEVDGDVEIP...,5NAF
1,5naf_B,2017-02-27,2.493,57113,12123,MGSSHHHHHHSSGLEVLFQGPEENGAHTIANNHTDMMEVDGDVEIP...,5NAF
2,5naf_C,2017-02-27,2.493,57113,12123,MGSSHHHHHHSSGLEVLFQGPEENGAHTIANNHTDMMEVDGDVEIP...,5NAF
3,5naf_D,2017-02-27,2.493,57113,12123,MGSSHHHHHHSSGLEVLFQGPEENGAHTIANNHTDMMEVDGDVEIP...,5NAF
4,5nag_A,2017-02-27,1.680,82310,6750,MTATDNARQVTIIGAGLAGTLVARLLARNGWQVNLFERRPDPRIET...,5NAG
...,...,...,...,...,...,...,...
555780,4a11_B,2011-09-13,3.310,67664,25947,MLGFLSARQTGLEDPLRLRRAESTRRVLGLELNKDRDVERIHGGGI...,4A11
555781,4a12_A,2011-09-13,3.150,111624,5197,XRGETLKLKKDKRREAIRQQIDSNPFITDHELSDLFQVSIQTIRLD...,4A12
555782,4a12_B,2011-09-13,3.150,111624,5197,XRGETLKLKKDKRREAIRQQIDSNPFITDHELSDLFQVSIQTIRLD...,4A12
555783,4a12_C,2011-09-13,3.150,111624,5197,XRGETLKLKKDKRREAIRQQIDSNPFITDHELSDLFQVSIQTIRLD...,4A12


In [None]:
# Load the sampled per-chain listing and tag each row with the first four
# characters of CHAINID in upper case.
df = pd.read_csv('/content/pdb_2021aug02_sample/list.csv')
df['PROTEIN'] = df['CHAINID'].str[:4].str.upper()

# Read the test / validation cluster IDs (one integer per line).
# Blank lines (e.g. a trailing empty line) are skipped instead of crashing int().
with open("/content/pdb_2021aug02_sample/test_clusters.txt", "r") as f:
    test_clusters = [int(line) for line in f if line.strip()]

with open("/content/pdb_2021aug02_sample/valid_clusters.txt", "r") as f:
    valid_clusters = [int(line) for line in f if line.strip()]

# Flag each chain by the split its CLUSTER belongs to.
# TRAIN is every chain whose cluster is in neither the test nor the validation list.
df["TEST"] = df["CLUSTER"].isin(test_clusters)
df["VALID"] = df["CLUSTER"].isin(valid_clusters)
df["TRAIN"] = ~(df["TEST"] | df["VALID"])

# Distinct PROTEIN values that have at least one chain in each split.
train_protein = df.loc[df['TRAIN'], 'PROTEIN'].unique()
valid_protein = df.loc[df['VALID'], 'PROTEIN'].unique()
test_protein = df.loc[df['TEST'], 'PROTEIN'].unique()

print("No. of unique protein in train/valid/test")
len(train_protein), len(valid_protein), len(test_protein)

(148004, 8071, 7470)

In [None]:
# The size of a column equals the number of table rows (one per CHAINID entry),
# so this is a row count rather than a count of distinct PROTEIN values.
print("No. of protein:", df["PROTEIN"].shape[0])

555785

In [None]:
def get_RCSB_data_GRAPHQl(instance_ids: List[str], timeout: float = 60.0):
    """
    Fetch annotation data for polymer entity instances from the RCSB GraphQL API.

    Args:
        instance_ids (List[str]): List of instance IDs (e.g. ["5NAJ.A", "5NAJ.B"]).
        timeout (float): Seconds to wait for the HTTP request before giving up.
            Without a timeout a stalled connection would block the notebook forever.

    Returns:
        list: The ``polymer_entity_instances`` list from the response. Each item
        carries ``rcsb_id`` and ``rcsb_polymer_instance_annotation`` as requested
        by the query below. An empty list is returned when ``instance_ids`` is
        empty (no request is sent) or when the API returns null for the field.

    Raises:
        Exception: If the HTTP status is not 200, or if the response carries
            GraphQL errors and no data.

    A generic query string sent via HTTP looks like:
    https://data.rcsb.org/graphql?query={entry(entry_id:"4HHB"){exptl{method}}}
    """
    # Nothing to ask for: skip the network round trip.
    if not instance_ids:
        return []

    # Build the comma-separated, quoted ID list. json.dumps escapes quotes and
    # backslashes so an unusual ID cannot break out of the string literal.
    instance_ids_str = ",".join(json.dumps(str(instance_id)) for instance_id in instance_ids)
    # Define the API endpoint and GraphQL query
    api_url = "https://data.rcsb.org/graphql"
    query=f"""
      query {{
        polymer_entity_instances(instance_ids: [{instance_ids_str}]) {{
          rcsb_id
          rcsb_polymer_instance_annotation {{
            type
            annotation_lineage {{
              name
              depth
            }}
          }}
        }}
      }}
    """
    # Make the API request and parse the response JSON
    response = requests.post(api_url, json={'query': query}, timeout=timeout)

    if response.status_code != 200:
        raise Exception(f"API request failed with status code {response.status_code}")

    response_json = response.json()
    data = response_json.get("data")
    if data is None:
        # A 200 response can still carry query-level errors with no data.
        raise Exception(f"GraphQL query failed: {response_json.get('errors')}")

    # Extract the instance list; a null field is treated as "no instances".
    return data.get("polymer_entity_instances") or []


def extract_domain_information(
        query_ids: List, 
        response_data: List[dict],
        type: List[str] = ["CATH","ECOD","SCOP","SCOP2"]
    ):
    """
    Turn the API response for a batch of IDs into per-ID domain dictionaries.

    Args:
        query_ids (List): IDs that were requested. The list is not modified.
        response_data (List[dict]): Response items, each with an ``rcsb_id`` and
            an ``rcsb_polymer_instance_annotation`` entry (which may be null).
            ``None`` is treated as an empty response.
        type (List[str]): Annotation types to keep.

    Returns:
        tuple: ``(domain_info, no_data_ids)`` where

        * ``domain_info`` maps each answered ID to a dict keyed
          ``"<type>_<depth>"`` (e.g. ``{"CATH_1": "Alpha Beta"}``). An ID whose
          annotations contain none of the requested types maps to ``{}``.
        * ``no_data_ids`` lists every requested ID exactly once that either came
          back with null annotations (response order) or was absent from the
          response (request order).
    """
    # IDs still waiting for an answer; a set gives O(1) membership and removal.
    pending = set(query_ids)
    domain_info = {}
    no_data_ids = []

    for data in response_data or []:
        query_id = data["rcsb_id"]
        # Ignore items that were not requested or were already handled.
        if query_id not in pending:
            continue
        pending.discard(query_id)

        annotations = data.get("rcsb_polymer_instance_annotation")
        if annotations is None:
            # Known to the API but carrying no annotations: report it once here.
            no_data_ids.append(query_id)
            continue

        domain_data = {}
        for annotation in annotations:
            a_type = annotation["type"]
            if a_type not in type:
                continue
            # A null lineage is treated as an empty one.
            for lineage in annotation.get("annotation_lineage") or []:
                # e.g. {"CATH_1": Alpha Beta}
                domain_data[f"{a_type}_{lineage['depth']}"] = lineage["name"]

        domain_info[query_id] = domain_data

    # Requested IDs that never appeared in the response, in request order and
    # without duplicates.
    no_data_ids.extend(qid for qid in dict.fromkeys(query_ids) if qid in pending)

    return domain_info, no_data_ids


In [None]:
# Rewrite IDs from the "5naf_A" form into the "5NAF.A" form passed to the RCSB query.
# regex=False makes the underscore a literal match regardless of the pandas
# version's default for `regex`.
# NOTE(review): upper() also upper-cases the chain part of the ID; confirm that
# no two chains of the same entry differ only by letter case.
df['CHAINID'] = df['CHAINID'].str.replace('_', '.', regex=False).str.upper()

In [None]:
# Query the API batch by batch and collect the per-batch results.
BATCH_SIZE = 500  # number of instance IDs sent per request
results = []
missing_ids = []
for i in range(0, len(df), BATCH_SIZE):
    batch_df = df.iloc[i:i + BATCH_SIZE]
    batch_ids = batch_df["CHAINID"].to_list()
    batch_data = get_RCSB_data_GRAPHQl(batch_ids)
    batch_formatted_data, batch_missing_ids = extract_domain_information(batch_ids, batch_data)
    results.append(batch_formatted_data)
    # extend (not append) keeps missing_ids a flat list of IDs, so its length
    # is the number of IDs rather than the number of batches.
    missing_ids.extend(batch_missing_ids)

# Count each ID once even if it was reported more than once.
missing_ids = list(dict.fromkeys(missing_ids))

# Concatenate all the results into a single DataFrame indexed by instance ID
formatted_data = pd.concat([pd.DataFrame.from_dict(d, orient='index') for d in results])

# Print the number of missing IDs
print(f"Number of missing IDs: {len(missing_ids)}")

Number of missing IDs: 1112


In [None]:
# Attach the three CATH columns to the chain table: each row's CHAINID value is
# looked up in the index of formatted_data; rows without a match get NaN.
cath_annotations = formatted_data[["CATH_1","CATH_2","CATH_3"]]
merged_table = df.join(cath_annotations, on='CHAINID')
merged_table

Unnamed: 0,CHAINID,DEPOSITION,RESOLUTION,HASH,CLUSTER,SEQUENCE,PROTEIN,TEST,VALID,TRAIN,CATH_1,CATH_2,CATH_3
0,5NAF.A,2017-02-27,2.493,57113,12123,MGSSHHHHHHSSGLEVLFQGPEENGAHTIANNHTDMMEVDGDVEIP...,5NAF,False,False,True,Mainly Beta,7 Propeller,Methylamine Dehydrogenase
1,5NAF.B,2017-02-27,2.493,57113,12123,MGSSHHHHHHSSGLEVLFQGPEENGAHTIANNHTDMMEVDGDVEIP...,5NAF,False,False,True,Mainly Beta,7 Propeller,Methylamine Dehydrogenase
2,5NAF.C,2017-02-27,2.493,57113,12123,MGSSHHHHHHSSGLEVLFQGPEENGAHTIANNHTDMMEVDGDVEIP...,5NAF,False,False,True,Mainly Beta,7 Propeller,Methylamine Dehydrogenase
3,5NAF.D,2017-02-27,2.493,57113,12123,MGSSHHHHHHSSGLEVLFQGPEENGAHTIANNHTDMMEVDGDVEIP...,5NAF,False,False,True,Mainly Beta,7 Propeller,Methylamine Dehydrogenase
4,5NAG.A,2017-02-27,1.680,82310,6750,MTATDNARQVTIIGAGLAGTLVARLLARNGWQVNLFERRPDPRIET...,5NAG,True,False,False,Alpha Beta,3-Layer(bba) Sandwich,FAD/NAD(P)-binding domain
...,...,...,...,...,...,...,...,...,...,...,...,...,...
555780,4A11.B,2011-09-13,3.310,67664,25947,MLGFLSARQTGLEDPLRLRRAESTRRVLGLELNKDRDVERIHGGGI...,4A11,False,False,True,Mainly Beta,7 Propeller,Methylamine Dehydrogenase
555781,4A12.A,2011-09-13,3.150,111624,5197,XRGETLKLKKDKRREAIRQQIDSNPFITDHELSDLFQVSIQTIRLD...,4A12,False,False,True,Alpha Beta,Roll,Thiol Ester Dehydrase
555782,4A12.B,2011-09-13,3.150,111624,5197,XRGETLKLKKDKRREAIRQQIDSNPFITDHELSDLFQVSIQTIRLD...,4A12,False,False,True,Alpha Beta,Roll,Thiol Ester Dehydrase
555783,4A12.C,2011-09-13,3.150,111624,5197,XRGETLKLKKDKRREAIRQQIDSNPFITDHELSDLFQVSIQTIRLD...,4A12,False,False,True,Alpha Beta,Roll,Thiol Ester Dehydrase


In [None]:
# Persist the annotated table; the row index is written as the first CSV column.
merged_table.to_csv("pdb_domain.csv", index=True)

In [None]:
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
import os

# For every data split: plot the per-protein sequence-length histogram and the
# frequency of each CATH class at three levels, and save each figure as HTML.
for subset in ['TRAIN', 'VALID', 'TEST']:
    subset_label = subset.capitalize()      # 'TRAIN' -> 'Train', used in titles
    out_dir = f'./img/{subset.lower()}/'    # one output folder per split

    # filter the merged table by the current subset and group it by "PROTEIN":
    # chain count, first available CATH labels, and all chain sequences concatenated
    subset_table = merged_table[merged_table[subset] == True]
    protein_groups = subset_table.groupby('PROTEIN').agg({
        'CHAINID': 'count',
        'CATH_1': 'first',
        'CATH_2': 'first',
        'CATH_3': 'first',
        'SEQUENCE': lambda x: ''.join(x)
    }).reset_index()

    # create the output folder; exist_ok avoids a separate existence check
    os.makedirs(out_dir, exist_ok=True)

    # total residues per protein (length of the concatenated chain sequences)
    seq_lengths = protein_groups['SEQUENCE'].str.len()

    # create a new histogram figure for the sequence lengths
    fig = go.Figure(data=[go.Histogram(x=seq_lengths)])

    # customize the appearance of the figure
    fig.update_layout(
        xaxis_title='Sequence Length (Number of Residues)',
        yaxis_title='Count',
        title=f'Histogram of Sequence Lengths in {subset_label} Subset (Grouped by Protein)'
    )

    # show the plot
    fig.show()
    # save the plot as a standalone HTML file
    fig.write_html(f'{out_dir}sequence_length_histogram.html')

    for level in ['CATH_1','CATH_2', 'CATH_3']:
        # count the proteins per label at this level, most frequent label first
        grouped_data = (
            protein_groups
            .groupby(level, sort=False)['CHAINID']
            .count()
            .reset_index(name='count')
            .sort_values('count', ascending=False)
        )

        # create a bar chart from the sorted counts
        fig = px.bar(grouped_data, x=level, y='count')

        # customize the appearance of the figure
        fig.update_layout(
            xaxis_title=f'Protein Family ({level})',
            yaxis_title='Count',
            title=f'Histogram of Protein Family Distribution in {subset_label} Set ({level}, Descending Order)'
        )

        # pin the y axis to start at zero and end at the largest count
        fig.update_yaxes(range=[0, grouped_data['count'].max()])

        # show the plot
        fig.show()

        # save the plot as a standalone HTML file
        fig.write_html(f'{out_dir}{level}_histogram.html')
