In [1]:
import pandas as pd
import numpy as np
import os
from tqdm.auto import tqdm

In [2]:
import pickle as pkl

def read_pkl(file_path):
    with open(file_path,'rb') as fr:
        return pkl.load(fr)

def save_pkl(file_path, val):
    fw = open(file_path, 'wb')
    pkl.dump(val, fw)
    fw.close()

In [3]:
tags = ['mf', 'cc', 'bp']
types = ['train', 'valid', 'test']

## 1. Firstly, we need generate pid_list_file

In [6]:
for tag in tags:
    for tp in types:
        pid_list = set()
        with open(f"./data_dpfunc/{tag}_{tp}_pid_list.txt", 'r') as f:
            lines = f.readlines()
            for line in lines:
                content = line.strip('\n').strip()
                pid_list.add(content)
        pid_list = list(pid_list)
        save_pkl(f"./processed_file/{tag}_{tp}_used_pid_list.pkl", pid_list)

## 2. pid_go_file

In [8]:
import os

In [10]:
for tag in tags:
    for tp in types:
        cmd = f"cp ./data_dpfunc/{tag}_{tp}_go.txt ./processed_file/{tag}_{tp}_go.txt"
        os.system(cmd)

## 3. pid_pdb_file

In [21]:
from tqdm.auto import tqdm
import gzip

In [None]:
from Bio.PDB import PDBParser
from Bio.SeqUtils import seq1

def extract_sequence_and_ca_coords(pdb_file, chain_id=None):
    parser = PDBParser(QUIET=True)
    if pdb_file.endswith('.gz'):
        with gzip.open(pdb_file, 'rt') as gz_file:
            temp_file = pdb_file.replace('.gz', '_temp')
            try:
                with open(temp_file, 'w') as temp:
                    temp.write(gz_file.read())
                
                structure = parser.get_structure('protein', temp_file)
            finally:
                if os.path.exists(temp_file):
                    os.remove(temp_file)
    else:
        structure = parser.get_structure('protein', pdb_file)
    
    results = {}
    
    for model in structure:
        for chain in model:
            if chain_id is None or chain.id == chain_id:
                sequence = ""
                ca_coords = []
                
                for residue in chain:
                    if residue.id[0] == ' ':
                        try:
                            aa = seq1(residue.resname)
                            sequence += aa
                            
                            if 'CA' in residue:
                                ca_atom = residue['CA']
                                coord = ca_atom.get_coord()
                                ca_coords.append((float(coord[0]), float(coord[1]), float(coord[2])))
                            else:
                                print(f"Warning: No CA atom found in residue {residue.resname}{residue.id[1]} of chain {chain.id}")
                                ca_coords.append(None)
                                
                        except KeyError:
                            print(f'Non natural residue in {pdb_file}')
                
                results[chain.id] = {
                    'sequence': sequence,
                    'ca_coords': ca_coords
                }
    
    return results

def extract_single_chain(pdb_file, chain_id='A'):
    results = extract_sequence_and_ca_coords(pdb_file, chain_id)
    
    if chain_id in results:
        return results[chain_id]['sequence'], results[chain_id]['ca_coords']
    else:
        print(f"Chain {chain_id} not found in PDB file")
        return "", []

In [6]:
pid_list = set()
for tag in tags:
    for tp in types:
        tp_pid_list = read_pkl(f'./processed_file/{tag}_{tp}_used_pid_list.pkl')
        pid_list = pid_list|set(tp_pid_list)
pid_list = list(pid_list)
len(pid_list)

59966

In [9]:
train_id_map = read_pkl('./data_dpfunc/train_id_map.pkl')
valid_id_map = read_pkl('./data_dpfunc/valid_id_map.pkl')
test_id_map = read_pkl('./data_dpfunc/test_id_map.pkl')

In [8]:
assert len(train_id_map)+len(valid_id_map)+len(test_id_map) == len(set(train_id_map.keys())|set(valid_id_map.keys())|set(test_id_map.keys()))

In [10]:
all_id_map = {}
for k,v in train_id_map.items():
    all_id_map[k] = v
for k,v in valid_id_map.items():
    all_id_map[k] = v
for k,v in test_id_map.items():
    all_id_map[k] = v

In [11]:
len(all_id_map)

60254

In [17]:
tp_map = pd.read_table('./idmapping_2025_06_13.tsv')
tp_map.head()

Unnamed: 0,From,Entry,Reviewed,Entry Name,Protein names
0,OTU4_ARATH,Q8LBZ4,reviewed,OTU4_ARATH,OVARIAN TUMOR DOMAIN-containing deubiquitinati...
1,CIB2_DROME,Q9W2Q5,reviewed,CIB2_DROME,Calcium and integrin-binding family member 2
2,DEKP3_ARATH,Q9SUA1,reviewed,DEKP3_ARATH,DEK domain-containing chromatin-associated pro...
3,RTEL1_MOUSE,Q0VGM9,reviewed,RTEL1_MOUSE,Regulator of telomere elongation helicase 1 (E...
4,ARM10_HUMAN,Q8N2F6,reviewed,ARM10_HUMAN,Armadillo repeat-containing protein 10 (Splici...


In [18]:
for idx, row in tp_map.iterrows():
    all_id_map[row['Entry Name']] = row['Entry']

In [None]:
pdb_points_info = {}
pdb_seq_info = {}
unseen_proteins = set()

for protein in tqdm(pid_list):
    uni_id = all_id_map[protein]
    pdb_file = f"./AF2DB/AF-{uni_id}-F1-model_v4.pdb.gz"
    if not os.path.exists(pdb_file):
        unseen_proteins.add(protein)
        continue
    
    sequence, coords_list = extract_single_chain(pdb_file, 'A')
    
    if sequence and coords_list:
        valid_coords = [coord for coord in coords_list if coord is not None]
        valid_sequence = ''.join([sequence[i] for i, coord in enumerate(coords_list) if coord is not None])
        
        if len(valid_coords)==0 or len(valid_sequence)==0:
            print(f"Empty {protein}, {uni_id}")
        
        pdb_points_info[protein] = valid_coords
        pdb_seq_info[protein] = valid_sequence

  0%|          | 0/59966 [00:00<?, ?it/s]

In [None]:
save_pkl('./processed_file/pdb_points.pkl', pdb_points_info)
save_pkl('./processed_file/pdb_seqs.pkl', pdb_seq_info)
save_pkl('./processed_file/unseen_proteins.pkl', unseen_proteins)

## 4.generate ESM - PDB

In [None]:
'''
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Please first run "process_esm.py" file to generate the esm data.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
'''

In [4]:
all_protein_list = []
for tag in tags:
    for tp in types:
        pid_list = read_pkl(f"./processed_file/{tag}_{tp}_used_pid_list.pkl")
        all_protein_list+=pid_list
all_protein_list = list(set(all_protein_list))

In [5]:
pdb_points_info = read_pkl('./processed_file/pdb_points.pkl')
pdb_seqs = read_pkl('./processed_file/pdb_seqs.pkl')

In [6]:
assert len(set(pdb_points_info.keys())&set(all_protein_list))==len(all_protein_list)
assert len(set(pdb_seqs.keys())&set(all_protein_list))==len(all_protein_list)

In [7]:
import dgl

In [8]:
import math
import torch

In [9]:
def get_dis(point1, point2):
    dis_x = point1[0] - point2[0]
    dis_y = point1[1] - point2[1]
    dis_z = point1[2] - point2[2]
    return math.sqrt(dis_x*dis_x + dis_y*dis_y + dis_z*dis_z)

def process_input_pdb_file(tag, part, pid_list, pdb_points_info, pdb_seqs, thresholds=12):
    protein_map = read_pkl('./processed_file/protein_map.pkl')
    pdb_graphs = []
    p_cnt = 0
    file_idx = 0
    for pid in tqdm(pid_list):
        p_cnt += 1
        points = pdb_points_info[pid]
        
        u_list = []
        v_list = []
        dis_list = []
        for uid, amino_1 in enumerate(points):
            for vid, amino_2 in enumerate(points):
                if uid==vid:
                    continue
                dist = get_dis(amino_1, amino_2)
                if dist<=thresholds:
                    u_list.append(uid)
                    v_list.append(vid)
                    dis_list.append(dist)
        u_list, v_list = torch.tensor(u_list), torch.tensor(v_list)
        dis_list = torch.tensor(dis_list)

        graph = dgl.graph((u_list, v_list), num_nodes=len(points))
        graph.edata['dis'] = dis_list

        # graph node feature - esm
        esm_file_idx = protein_map[pid]
        esm_features = read_pkl(f"./processed_file/esm_emds/esm_part_{esm_file_idx}.pkl")
        node_features = esm_features[pid]
        assert node_features.shape[0]==graph.num_nodes()
        graph.ndata['x'] = torch.from_numpy(node_features)
        pdb_graphs.append(graph)

        if p_cnt%5000==0:
            save_pkl('./processed_file/graph_features/{}_{}_whole_pdb_part{}.pkl'.format(tag, part, file_idx), pdb_graphs)
            p_cnt = 0
            file_idx += 1
            pdb_graphs = []
    if len(pdb_graphs)>0:
        save_pkl('./processed_file/graph_features/{}_{}_whole_pdb_part{}.pkl'.format(tag, part, file_idx), pdb_graphs)
    return file_idx

In [10]:
for tag in tags:
    if tag=='mf':
        continue
    for tp in types:
        pid_list = read_pkl(f"./processed_file/{tag}_{tp}_used_pid_list.pkl")
        max_cnt = process_input_pdb_file(tag, tp, pid_list, pdb_points_info, pdb_seqs)
        if tp=='train':
            print(f"{tag}-{tp}-train_file_count-{max_cnt}")

  0%|          | 0/41119 [00:00<?, ?it/s]

cc-train-train_file_count-8


  0%|          | 0/618 [00:00<?, ?it/s]

  0%|          | 0/990 [00:00<?, ?it/s]

  0%|          | 0/46642 [00:00<?, ?it/s]

bp-train-train_file_count-9


  0%|          | 0/707 [00:00<?, ?it/s]

  0%|          | 0/1280 [00:00<?, ?it/s]

## 5. generate Interpro

In [4]:
interpro_list = read_pkl('./data_dpfunc/interpro_list_26203.pkl')

In [6]:
len(interpro_list)

26203

In [7]:
inter_idx = {}
for idx, ipr in enumerate(interpro_list):
    inter_idx[ipr] = idx
save_pkl('./processed_file/inter_idx.pkl', inter_idx)

In [8]:
all_protein_list = []
for tag in tags:
    for tp in types:
        pid_list = read_pkl(f"./processed_file/{tag}_{tp}_used_pid_list.pkl")
        all_protein_list+=pid_list
all_protein_list = list(set(all_protein_list))
len(all_protein_list)

59350

In [9]:
all_protein_interpro = read_pkl('./data_dpfunc/all_protein_interpros.pkl')

In [12]:
for pr in all_protein_list:
    inters = all_protein_interpro[pr]
    inter_matrix = np.zeros(len(interpro_list))
    for it in inters:
        inter_matrix[inter_idx[it]] += 1
    save_pkl(f"./processed_file/interpro/{pr}.pkl", inter_matrix)

## 6. Check configures

In [None]:
'''
name: mf
mlb: ./mlb/mf_go.mlb
results: ./results

base:
  interpro_whole: ./processed_file/interpro/{}.pkl

train:
  name: train
  pid_list_file: ./processed_file/mf_train_used_pid_list.pkl
  pid_go_file: ./processed_file/mf_train_go.txt
  pid_pdb_file: ./processed_file/graph_features/mf_train_whole_pdb_part{}.pkl
  train_file_count: 7
  interpro_file: ./processed_file/mf_train_interpro.pkl

valid:
  name: valid
  pid_list_file: ./processed_file/mf_test1_used_pid_list.pkl
  pid_go_file: ./processed_file/mf_test1_go.txt
  pid_pdb_file: ./processed_file/graph_features/mf_test1_whole_pdb_part0.pkl
  interpro_file: ./processed_file/mf_test1_interpro.pkl
  
test:
  name: test
  pid_list_file: ./processed_file/mf_test2_used_pid_list.pkl
  pid_go_file: ./processed_file/mf_test2_go.txt
  pid_pdb_file: ./processed_file/graph_features/mf_test2_whole_pdb_part0.pkl
  interpro_file: ./processed_file/mf_test2_interpro.pkl
'''

In [None]:
'''
Run DPFunc_main.py / DPFunc_pred.py if you need:
python DPFunc_main.py -d mf -n 0 -e 15 -p DPFunc
python DPFunc_pred.py -d mf -n 0 -p DPFunc
'''