In [1]:
import uuid, sys, os
import pandas as pd
import numpy as np
from tqdm import tqdm
import ast
import math
import random

import torch
print("PyTorch:", torch.__version__)

from tqdm import tqdm
# setting device on GPU if available, else CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print("Current location:", os.getcwd())

# from esm.models.esmc import ESMC
# from esm.sdk.api import ESMProtein, LogitsConfig

import torch
print(torch.version.cuda)   # shows the CUDA version PyTorch was built against
print(torch.backends.cudnn.version())  # cuDNN version
print(torch.cuda.is_available())
import training_utils.partitioning_utils as pat_utils

PyTorch: 2.9.1+cu128
Using device: cuda
Current location: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts
12.8
91002
True


In [3]:
path_to_mmseqs_clustering = "/work3/s232958/data/PPint_DB/3_å_dataset5_singlefasta/clusterRes40"
all_seqs, clust, clust_keys = pat_utils.mmseqs_parser(path_to_mmseqs_clustering)

path_to_interaction_df = "/work3/s232958/data/PPint_DB/disordered_interfaces_no_cutoff_filtered_nonredundant80_3å_5.csv.gz"
disordered_interfaces_df = pd.read_csv(path_to_interaction_df,index_col=0).reset_index(drop=True)
disordered_interfaces_df["PDB_chain_name"] = (disordered_interfaces_df["PDB"] + "_" + disordered_interfaces_df["chainname"]).tolist()
disordered_interfaces_df["index_num"] = np.arange(len(disordered_interfaces_df))
disordered_interfaces_df["chain_name_index"] = [row["PDB_chain_name"] + "_" + str(row["index_num"]) for index, row in disordered_interfaces_df.iterrows()]
disordered_interfaces_df = disordered_interfaces_df.set_index("PDB_interface_name")
disordered_interfaces_df["interface_residues"] = disordered_interfaces_df["interface_residues"].apply(lambda x: ast.literal_eval(x))
# disordered_interfaces_df["inter_chain_hamming"] = [1 - (Ldistance(seq.split("-")[0], seq.split("-")[1]))/np.max([len(seq.split("-")[0]), len(seq.split("-")[1])]) for seq in disordered_interfaces_df["protien_interface_sequences"]]
disordered_interfaces_df["dimer"] = disordered_interfaces_df["inter_chain_hamming"] > 0.60
disordered_interfaces_df["clust_keys"] = [clust_keys.get(row["chain_name_index"]) for index, row in disordered_interfaces_df.iterrows()] 
print(disordered_interfaces_df["clust_keys"].head())

pdb_interface_and_clust_keys = {index:disordered_interfaces_df.loc[index,"clust_keys"].values.tolist() for index in tqdm(disordered_interfaces_df.index.drop_duplicates(), total=len(disordered_interfaces_df)/2)}
new_clusters, new_clusters_clustkeys = pat_utils.recluster_mmseqs_keys_to_non_overlapping_groups(pdb_interface_and_clust_keys)

### Creating train and test datasets based on train and test-idexes
train_indexes, test_indexes = pat_utils.run_train_test_partition(interaction_df=disordered_interfaces_df,
                                                    clustering=new_clusters, # Clusters from Bidentate-graphs
                                                    train_ratio=0.8, 
                                                    test_ratio=0.2, 
                                                    v=True, 
                                                    seed=0)

disordered_interfaces_df["ID"] = [row["PDB"]+"_"+str(row["interface_index"])+"_"+row["chainname"] for __, row in disordered_interfaces_df.iterrows()]
disordered_interfaces_df["PDB_interface_name"] = disordered_interfaces_df.index

PDB_interface_name
6NZA_0    6NZA_B_1
6NZA_0    6NZA_B_1
9JKA_1    9JKA_C_3
9JKA_1    9JKA_C_3
8DQ6_1    8DQ6_B_4
Name: clust_keys, dtype: object


100%|██████████████████████████████████████████████████████████████████████████| 24725/24725.0 [00:31<00:00, 779.27it/s]
100%|█████████████████████████████████████████████████████████████████████████| 27834/27834 [00:00<00:00, 813985.61it/s]


0.8
0.2


In [3]:
print(f"Train indeces number: {len(train_indexes)}")
print(f"Test indeces number: {len(test_indexes)}")

Train indeces number: 19781
Test indeces number: 4944


In [4]:
# Creating new dataframe with pairs of proteins (PPints)
grouped = {}
for _, row in disordered_interfaces_df.iterrows():
    iface = row["PDB_interface_name"]
    seq = row["sequence"]
    rid = row["ID"]
    dimer = row["dimer"]
    
    if iface not in grouped:
        grouped[iface] = {
            "sequences": [],
            "IDs": [],
            "dimer": dimer,        # keep the dimer value for this interface
        }
    else:
        # Optional: sanity-check it's consistent per interface
        if grouped[iface]["dimer"] != dimer:
            print(f"Warning: multiple dimers for interface {iface}:",
                  grouped[iface]['dimer'], "vs", dimer)

    grouped[iface]["sequences"].append(seq)
    grouped[iface]["IDs"].append(rid)

records = []
for iface, vals in grouped.items():
    seqs = vals["sequences"]
    ids = vals["IDs"]
    if len(seqs) >= 2 and len(ids) >= 2:
        records.append({
            "interface_id": iface,
            "seq_target": seqs[0],
            "seq_binder": seqs[1],
            "ID1": ids[0],
            "ID2": ids[1],
            "dimer": vals["dimer"],   # <- add dimer to final record
        })

PPint_interactions_NEW = pd.DataFrame(records)
PPint_interactions_NEW["seq_target_len"] = [len(row.seq_target) for __, row in PPint_interactions_NEW.iterrows()]
PPint_interactions_NEW["seq_binder_len"] = [len(row.seq_binder) for __, row in PPint_interactions_NEW.iterrows()]
PPint_interactions_NEW["target_binder_id"] = PPint_interactions_NEW["ID1"] + "_" + PPint_interactions_NEW["ID2"]

PPint_interactions_NEW

Unnamed: 0,interface_id,seq_target,seq_binder,ID1,ID2,dimer,seq_target_len,seq_binder_len,target_binder_id
0,6NZA_0,MNTVRSEKDSMGAIDVPADKLWGAQTQRSLEHFRISTEKMPTSLIH...,TVRSEKDSMGAIDVPADKLWGAQTQRSLEHFRISTEKMPTSLIHAL...,6NZA_0_A,6NZA_0_B,True,461,459,6NZA_0_A_6NZA_0_B
1,9JKA_1,VAAGATLALLSFLTPLAFLLLPPLLWREELEPCGTACEGLFISVAF...,VAAGATLALLSFLTPLAFLLLPPLLWREELEPCGTACEGLFISVAF...,9JKA_1_B,9JKA_1_C,True,362,362,9JKA_1_B_9JKA_1_C
2,8DQ6_1,PTLNLFTNIPVDAVTCSDILKDATKAVAKIIGKPESYVMILLNSGV...,PTLNLFTNIPVDAVTCSDILKDATKAVAKIIGKPESYVMILLNSGV...,8DQ6_1_B,8DQ6_1_C,True,109,97,8DQ6_1_B_8DQ6_1_C
3,2YMZ_0,ARMFEMFNLDWKSGGTMKIKGHISEDAESFAINLGCKSSDLALHFN...,ARMFEMFNLDWKSGGTMKIKGHISEDAESFAINLGCKSSDLALHFN...,2YMZ_0_A,2YMZ_0_B,True,130,130,2YMZ_0_A_2YMZ_0_B
4,6IDB_0,DKICLGHHAVSNGTKVNTLTERGVEVVNATETVERTNIPRICSKGK...,GLFGAIAGFIENGWEGLIDGWYGFRHQNAQGEGTAADYKSTQSAID...,6IDB_0_A,6IDB_0_B,False,317,172,6IDB_0_A_6IDB_0_B
...,...,...,...,...,...,...,...,...,...
24720,6O42_0,EIVLTQSPGTLSLSPGERATLSCRASQSVSSSYLAWYQQKPGQAPR...,QVQLVQSGAEVKKPGSSVKVSCKASGGTFSSYAISWVRQAPGQGLE...,6O42_0_L,6O42_0_H,False,214,220,6O42_0_L_6O42_0_H
24721,2QIB_0,GVEERRQQLIGVALDLFSRRSPDEVSIDEIASAAGISRPLVYHYFP...,RRQQLIGVALDLFSRRSPDEVSIDEIASAAGISRPLVYHYFPGKLS...,2QIB_0_A,2QIB_0_B,True,217,204,2QIB_0_A_2QIB_0_B
24722,7OXG_0,MKVGQDKVVTIRYTLQVEGEVLDQGELSYLHGHRNLIPGLEEALEG...,TRYWNAKALPFAFG,7OXG_0_A,7OXG_0_C,False,99,14,7OXG_0_A_7OXG_0_C
24723,6UUM_0,EIVLTQSPVTLSLSSGETGTLSCRASQNISSSWIAWYQQRRGQVPR...,QVQLVQSGAEVRKPGSSVTISCKPVGGTFTNFAIHWVRQAPGQGLE...,6UUM_0_F,6UUM_0_H,False,215,227,6UUM_0_F_6UUM_0_H


In [7]:
# sample random 10%
random.seed(0)
train_indexes_sample = random.sample(train_indexes, int(len(train_indexes) * 0.1))
test_indexes_sample = random.sample(test_indexes, int(len(test_indexes) * 0.1))

Df_train = PPint_interactions_NEW[PPint_interactions_NEW.interface_id.isin(train_indexes_sample)]
Df_test = PPint_interactions_NEW[PPint_interactions_NEW.interface_id.isin(train_indexes_sample)]

In [10]:
Df_train.to_csv('/work3/s232958/data/PPint_DB/PPint_train.csv', index=False)
Df_test.to_csv('/work3/s232958/data/PPint_DB/PPint_test.csv', index=False)

In [11]:
TRAIN_interaction_df = pd.read_csv("/work3/s232958/data/PPint_DB/PPint_train.csv",index_col=0).reset_index(drop=True)
TEST_interaction_df = pd.read_csv("/work3/s232958/data/PPint_DB/PPint_test.csv",index_col=0).reset_index(drop=True)
TRAIN_interaction_df

Unnamed: 0,seq_target,seq_binder,ID1,ID2,dimer,seq_target_len,seq_binder_len,target_binder_id
0,DKICLGHHAVSNGTKVNTLTERGVEVVNATETVERTNIPRICSKGK...,GLFGAIAGFIENGWEGLIDGWYGFRHQNAQGEGTAADYKSTQSAID...,6IDB_0_A,6IDB_0_B,False,317,172,6IDB_0_A_6IDB_0_B
1,VQLQESGGGLVQAGGSLRLSCTASRRTGSNWCMGWFRQLAGKEPEL...,TIKNFTFFSPNSTEFPVGSNNDGKLYMMLTGMDYRTIRRKDWSSPL...,2WZP_3_D,2WZP_3_G,False,122,266,2WZP_3_D_2WZP_3_G
2,LYFQSNAKTVVGFWGGFPEAGEATSGYLFEHDGFRLLVDCGSGVLA...,AKTVVGFWGGFPEAGEATSGYLFEHDGFRLLVDCGSGVLAQLQKYI...,1ZKP_0_A,1ZKP_0_C,True,246,240,1ZKP_0_A_1ZKP_0_C
3,SKHELSLVEVTHYTDPEVLAIVKDFHVRGNFASLPEFAERTFVSAV...,MINVYSNLMSAWPATMAMSPKLNRNMPTFSQIWDYERITPASAAGE...,6GRH_3_C,6GRH_3_D,False,266,396,6GRH_3_C_6GRH_3_D
4,DLMTALQLVMKKSSAHDGLVKGLREAAKAIEKHAAQICVLAEDCDQ...,PKKQKHKHKKVKLAVLQFYKVDDATGKVTRLRKECPNADCGAGTFM...,8R57_1_M,8R57_1_f,False,118,64,8R57_1_M_8R57_1_f
...,...,...,...,...,...,...,...,...
1973,HENLYFQGVQKIGILGAMREEITPILELFGVDFEEIPLGGNVFHKG...,HHHHHENLYFQGVQKIGILGAMREEITPILELFGVDFEEIPLGGNV...,4YO8_0_A,4YO8_0_B,True,238,242,4YO8_0_A_4YO8_0_B
1974,DPMKNTCKLLVVADHRFYRYMGRGEESTTTNYLIELIDRVDDIYRN...,CTCSPSHPQDAFCNSDIVIRAKVVGKKLVKEGPFGTLVYTIKQMKM...,3CKI_0_A,3CKI_0_B,False,256,121,3CKI_0_A_3CKI_0_B
1975,QVQLRQSGAELAKPGASVKMSCKASGYTFTNYWLHWIKQRPGQGLE...,DVLMTQTPLSLPVSLGDQVSISCRSSQSIVHNTYLEWYLQKPGQSP...,7MHY_1_M,7MHY_1_N,False,118,109,7MHY_1_M_7MHY_1_N
1976,IQLVQSGPELVKISCKASGYTFTNYGMNWVRQAPGKGLKWMGWINT...,VLMTQTPLSLPVSISCRSSQSIVHSNGNTYLEWYLQKPGQSPKLLI...,7MHY_2_O,7MHY_2_P,False,100,94,7MHY_2_O_7MHY_2_P


In [12]:
interaction_Dict = {}

for _, row in TRAIN_interaction_df.iterrows():
    key1 = row['ID1'].split("_")[0]+"_"+row['ID1'].split("_")[-1]
    key2 = row['ID2'].split("_")[0]+"_"+row['ID2'].split("_")[-1]
    seq1 = row['seq_target']
    seq2 = row['seq_binder']       
    interaction_Dict[key1] = seq1
    interaction_Dict[key2] = seq2

for _, row in TEST_interaction_df.iterrows():
    key1 = row['ID1'].split("_")[0]+"_"+row['ID1'].split("_")[-1]
    key2 = row['ID2'].split("_")[0]+"_"+row['ID2'].split("_")[-1]
    seq1 = row['seq_target']
    seq2 = row['seq_binder']       
    interaction_Dict[key1] = seq1
    interaction_Dict[key2] = seq2
    
list(interaction_Dict.items())[:2]

[('6IDB_A',
  'DKICLGHHAVSNGTKVNTLTERGVEVVNATETVERTNIPRICSKGKRTVDLGQCGLLGTITGPPQCDQFLEFSADLIIERREGSDVCYPGKFVNEEALRQILRESGGIDKEAMGFTYSGIRTNGATSSCRRSGSSFYAEMKWLLSNTDNAAFPQMTKSYKNTRKSPALIVWGIHHSVSTAEQTKLYGSGNKLVTVGSSNYQQSFVPSPGARTQVNGQSGRIDFHWLMLNPNDTVTFSFNGAFIAPDRASFLRGKSMGIQSGVQVDANCEGDCYHSGGTIISNLPFQNIDSRAVGKCPRYVKQRSLLLATGMKNVPEI'),
 ('6IDB_B',
  'GLFGAIAGFIENGWEGLIDGWYGFRHQNAQGEGTAADYKSTQSAIDQITGKLNRLIEKTNQQFELIDNEFNEVEKQIGNVINWTRDSITEVWSYNAELLVAMENQHTIDLADSEMDKLYERVKRQLRENAEEDGTGCFEIFHKCDDDCMASIRNNTYDHSKYREEAMQNRIQ')]

In [None]:
### ESM-C embeddings calculated running cal_esmC.py dataset = "PPint"

# client = ESMC.from_pretrained("esmc_600m", device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))

# def calculate_ESM_pr_res_embeddings(sequence):
#     protein = ESMProtein(sequence=sequence)
#     protein_tensor = client.encode(protein)
#     logits_output = client.logits(
#     protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
#     )
#     return logits_output.embeddings.detach().cpu().numpy()

# path_to_output_embeddings = "/work3/s232958/data/PPint_DB/embeddings_esmC"
# os.makedirs(path_to_output_embeddings, exist_ok=True)

# # helper: convert torch tensor to numpy
# def to_numpy(x):
#     try:
#         return x.detach().cpu().numpy()
#     except AttributeError:
#         return np.asarray(x)
    
# for name, seq in tqdm(interaction_Dict.items(), total=len(interaction_Dict.items()), desc="Embedding targets"):
#     emb = calculate_ESM_pr_res_embeddings(seq)
#     emb_np = to_numpy(emb)
#     out_path = os.path.join(path_to_output_embeddings, f"{name}.npy")
#     np.save(out_path, emb_np)
#     print(f"Protein {name} embedded and saved to {out_path}")