In [1]:
import torch
from sklearn.model_selection import train_test_split
import os
import pandas as pd
import numpy as np
import joblib
pd.set_option('display.max_columns', None)
def convert_to_float64(arr):
    """Copy ``arr`` into a new NumPy array with dtype float64.

    Accepts anything ``np.array`` accepts (lists, tuples, arrays). Because
    ``np.array`` copies by default, the result never aliases the input, even
    when ``arr`` is already a float64 array.
    """
    target_dtype = np.float64
    converted = np.array(arr, dtype=target_dtype)
    return converted

In [13]:
# Km dataset with precomputed metabolite/protein features.
# joblib.load unpickles the file, so only load trusted data.
path = "../data/km_data_with_features.joblib"
df = (
    joblib.load(path)
    .reset_index(drop=True)  # renumber rows 0..n-1
)
df.head()

Unnamed: 0,ECNumber,organism,Substrate,Sequence,Smiles,value,type,source,metabolite_features,protein_features,label
0,1.1.1.1,Haloferax volcanii,NADH,MRAAVLREHGEPLDVTEVPDPTCDADGVVVEVEACGICRSDWHSWM...,NC(=O)C1=CN([C@@H]2O[C@H](COP(=O)(O)OP(=O)(O)O...,0.071,wild,custom,"[0.24308163, 0.20997117, -0.3391949, -0.368485...","[0.063692555, 0.0023747068, 0.16711928, -0.258...",-1.148742
1,1.1.1.1,Geobacillus stearothermophilus,NAD+,MKAAVVEQFKEPLKIKEVEKPTISYGEVLVRIKACGVCHTDLHAAH...,NC(=O)c1ccc[n+]([C@@H]2O[C@H](COP(=O)(O)OP(=O)...,0.2,mutant,custom,"[0.24613404, 0.28360435, -0.31403223, -0.48450...","[-0.052756682, -0.07354015, 0.28844666, 0.1124...",-0.69897
2,1.1.1.1,Picrophilus torridus,ethanol,MRAIVLERFGIENIKIEDIDDESPGIPVKITMAGLNPVDYSTVNGN...,CCO,0.056,wild,custom,"[1.1508174, 2.0531795, -0.729387, -2.5030358, ...","[-0.0042622425, -0.09458677, 0.15183483, 0.133...",-1.251812
3,1.1.1.1,Saccharolobus solfataricus,ethanol,MRAVRLVEIGKPLSLQEIGVPKPKGPQVLIKVEAAGVCHSDVHMRQ...,CCO,66.0,mutant,custom,"[1.1508174, 2.0531795, -0.729387, -2.5030358, ...","[-0.03386792, -0.06237053, 0.24611998, 0.03784...",1.819544
4,1.1.1.1,Saccharolobus solfataricus,1-Pentanol,MRAVRLVEIGKPLSLQEIGVPKPKGPQVLIKVEAAGVCHSDVHMRQ...,CCCCCO,0.26,mutant,custom,"[0.99664927, 2.2531178, -0.6534443, -2.62829, ...","[-0.030971242, -0.06459375, 0.24837981, 0.0344...",-0.585027


In [14]:
# Map the source column names onto the downstream schema and keep only the
# model inputs. A non-inplace rename keeps this cell free of in-place mutation.
cols = ['ID',  'Ligand', 'Protein', 'metabolite_features', 'protein_features', 'regression_label']
df = df.rename(
    columns={'ECNumber': 'ID', 'label': 'regression_label', 'Sequence': 'Protein', 'Smiles': 'Ligand'}
)[cols]


In [15]:
# Preview the first rows of the renamed/selected frame.
df.head(n=5)

Unnamed: 0,ID,Ligand,Protein,metabolite_features,protein_features,regression_label
0,1.1.1.1,NC(=O)C1=CN([C@@H]2O[C@H](COP(=O)(O)OP(=O)(O)O...,MRAAVLREHGEPLDVTEVPDPTCDADGVVVEVEACGICRSDWHSWM...,"[0.24308163, 0.20997117, -0.3391949, -0.368485...","[0.063692555, 0.0023747068, 0.16711928, -0.258...",-1.148742
1,1.1.1.1,NC(=O)c1ccc[n+]([C@@H]2O[C@H](COP(=O)(O)OP(=O)...,MKAAVVEQFKEPLKIKEVEKPTISYGEVLVRIKACGVCHTDLHAAH...,"[0.24613404, 0.28360435, -0.31403223, -0.48450...","[-0.052756682, -0.07354015, 0.28844666, 0.1124...",-0.69897
2,1.1.1.1,CCO,MRAIVLERFGIENIKIEDIDDESPGIPVKITMAGLNPVDYSTVNGN...,"[1.1508174, 2.0531795, -0.729387, -2.5030358, ...","[-0.0042622425, -0.09458677, 0.15183483, 0.133...",-1.251812
3,1.1.1.1,CCO,MRAVRLVEIGKPLSLQEIGVPKPKGPQVLIKVEAAGVCHSDVHMRQ...,"[1.1508174, 2.0531795, -0.729387, -2.5030358, ...","[-0.03386792, -0.06237053, 0.24611998, 0.03784...",1.819544
4,1.1.1.1,CCCCCO,MRAVRLVEIGKPLSLQEIGVPKPKGPQVLIKVEAAGVCHSDVHMRQ...,"[0.99664927, 2.2531178, -0.6534443, -2.62829, ...","[-0.030971242, -0.06459375, 0.24837981, 0.0344...",-0.585027


In [20]:
# Key columns for the splits below. These are the names set by the rename cell;
# no alternative naming conventions are checked here.
drug_col, protein_col = 'Ligand', 'Protein'

print(f"\nUsing drug column: {drug_col}")
print(f"Using protein column: {protein_col}")



Using drug column: Ligand
Using protein column: Protein


In [21]:
# Function to perform cold split
def cold_split(unique_items, test_size=0.2, val_size=0.1, random_state=42):
    """Split a collection of unique keys into disjoint train / val / test groups.

    ``test_size`` and ``val_size`` are fractions of the whole collection. The
    validation fraction is rescaled to the non-test remainder before the second
    split (e.g. 0.1 / (1 - 0.2) = 0.125). Both splits use ``random_state``, so
    the result is deterministic for a fixed input order. The rescaled ratio is
    printed.

    Returns:
        (train_items, val_items, test_items), in that order.
    """
    remaining, test_items = train_test_split(
        unique_items, test_size=test_size, random_state=random_state
    )
    # Share of the non-test items that must go to validation.
    val_ratio = val_size / (1 - test_size)
    print(f"val_ratio: {val_ratio}")
    train_items, val_items = train_test_split(
        remaining, test_size=val_ratio, random_state=random_state
    )
    return train_items, val_items, test_items



In [22]:
# Column index of the working frame (DataFrame.keys() returns df.columns).
df.keys()

Index(['ID', 'Ligand', 'Protein', 'metabolite_features', 'protein_features',
       'regression_label'],
      dtype='object')

In [23]:
# Unique ligand / protein keys available for cold splitting.
unique_drugs, unique_proteins = (df[key].unique() for key in (drug_col, protein_col))

print(f"Total samples: {len(df)}")
print(f"Unique drugs: {len(unique_drugs)}")
print(f"Unique proteins: {len(unique_proteins)}")

Total samples: 29321
Unique drugs: 5247
Unique proteins: 14058


In [24]:
# Cold-drug split: partition the unique ligands, then route every row by its ligand.
train_drugs, val_drugs, test_drugs = cold_split(unique_drugs)

train_df, val_df, test_df = (
    df[df[drug_col].isin(keys)].copy() for keys in (train_drugs, val_drugs, test_drugs)
)

# Sanity check: the three ligand groups must be pairwise disjoint.
train_drugs_set, val_drugs_set, test_drugs_set = map(set, (train_drugs, val_drugs, test_drugs))
assert not (train_drugs_set & val_drugs_set), "Overlap between train and val drugs!"
assert not (train_drugs_set & test_drugs_set), "Overlap between train and test drugs!"
assert not (val_drugs_set & test_drugs_set), "Overlap between val and test drugs!"
            

val_ratio: 0.125


In [25]:
# Renumber every split's rows from 0.
train_df, val_df, test_df = (
    split.reset_index(drop=True) for split in (train_df, val_df, test_df)
)

# Share of all rows that landed in each split.
print(f"    Train: {len(train_df)} samples ({len(train_df)/len(df)*100:.1f}%)")
print(f"    Val: {len(val_df)} samples ({len(val_df)/len(df)*100:.1f}%)")
print(f"    Test: {len(test_df)} samples ({len(test_df)/len(df)*100:.1f}%)")
        

    Train: 21374 samples (72.9%)
    Val: 2718 samples (9.3%)
    Test: 5229 samples (17.8%)


In [28]:
# Write a preview (first 100 rows) of the training split to CSV.
# NOTE(review): only 100 rows are written although the file name suggests the
# full split, and object (list/array) columns are written as their text repr.
# Confirm this is intended.
path = "../output/data/km_data_with_features_train.csv"
# to_csv does not create missing parent directories.
os.makedirs(os.path.dirname(path), exist_ok=True)
train_df.head(100).to_csv(path, index=False)

In [29]:
# Write 100-row previews of the validation and test splits to CSV.
# NOTE(review): only the first 100 rows of each split are written. Confirm intended.
for path, split_df in (
    ("../output/data/km_data_with_features_val.csv", val_df),
    ("../output/data/km_data_with_features_test.csv", test_df),
):
    # to_csv does not create missing parent directories.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    split_df.head(100).to_csv(path, index=False)

In [None]:

# Build cold-drug / cold-protein / random train/val/test splits of the CatPred Km
# data once per protein-embedding column, and write each split to parquet.
# NOTE(review): this cell was never executed ("In [None]") and duplicates the
# later split cell. `catpred_km` is only created by later cells, so on a fresh
# kernel this cell raises NameError on its first statement.
from pathlib import Path  # not yet imported at this point in the notebook

df = catpred_km.copy()

# Define embedding columns
embedding_cols = ['ESMv1_embedding',
                  'ESM2_embedding',
                  'MUTAPLM_embedding',
                  'ProteinCLIP_embedding']

# Identify drug and protein columns (handle both naming conventions)
if 'Smiles' in df.columns:
    drug_col = 'Smiles'
elif 'Drug' in df.columns:
    drug_col = 'Drug'
else:
    raise ValueError("Could not find drug column. Expected 'Smiles' or 'Drug'")

if 'Sequence' in df.columns:
    protein_col = 'Sequence'
elif 'Target' in df.columns:
    protein_col = 'Target'
else:
    raise ValueError("Could not find protein column. Expected 'Sequence' or 'Target'")

print(f"\nUsing drug column: {drug_col}")
print(f"Using protein column: {protein_col}")

# Define split types and dataset splits
split_types = ['cold_drug', 'cold_protein', 'random']
dataset_splits = ['train', 'test', 'val']  # NOTE(review): not used in this cell

# Root directory for the per-(embedding, split) output folders.
output_base = Path('/Users/cheng.wang/Documents/mpi-web/exp_of_catpred_MPEK_EITLEM_inhouse_dataset/enzyme_embeddings_dataset')


def cold_split(unique_items, test_size=0.2, val_size=0.1, random_state=42):
    """
    Perform cold split on unique items.
    Returns train, val, test items.

    ``test_size`` / ``val_size`` are fractions of the whole collection; the
    validation fraction is rescaled to the non-test remainder.
    """
    train_items, test_items = train_test_split(
        unique_items,
        test_size=test_size,
        random_state=random_state
    )
    # Adjust val_size to account for test_size
    val_ratio = val_size / (1 - test_size)
    train_items, val_items = train_test_split(
        train_items,
        test_size=val_ratio,
        random_state=random_state
    )
    return train_items, val_items, test_items


# The split keys do not depend on the embedding column, so compute them once.
unique_drugs = df[drug_col].unique()
unique_proteins = df[protein_col].unique()

# Process each embedding type
for ebd_col in embedding_cols:
    print(f"\n{'='*60}")
    print(f"Processing {ebd_col}")
    print(f"{'='*60}")

    # New frame whose `protein_features` holds this embedding. `df` itself (and
    # its original protein_features column) is left untouched across iterations.
    ebd_df = df.assign(protein_features=df[ebd_col])

    print(f"Total samples: {len(df)}")
    print(f"Unique drugs: {len(unique_drugs)}")
    print(f"Unique proteins: {len(unique_proteins)}")

    # Process each split type
    for split_type in split_types:
        print(f"\n  Processing {split_type} split...")

        if split_type == 'cold_drug':
            # Cold drug split: split by unique drugs
            train_drugs, val_drugs, test_drugs = cold_split(unique_drugs)

            train_df = ebd_df[ebd_df[drug_col].isin(train_drugs)].copy()
            val_df = ebd_df[ebd_df[drug_col].isin(val_drugs)].copy()
            test_df = ebd_df[ebd_df[drug_col].isin(test_drugs)].copy()

            # Verify no overlap
            train_drugs_set = set(train_drugs)
            val_drugs_set = set(val_drugs)
            test_drugs_set = set(test_drugs)
            assert len(train_drugs_set & val_drugs_set) == 0, "Overlap between train and val drugs!"
            assert len(train_drugs_set & test_drugs_set) == 0, "Overlap between train and test drugs!"
            assert len(val_drugs_set & test_drugs_set) == 0, "Overlap between val and test drugs!"

        elif split_type == 'cold_protein':
            # Cold protein split: split by unique proteins
            train_proteins, val_proteins, test_proteins = cold_split(unique_proteins)

            train_df = ebd_df[ebd_df[protein_col].isin(train_proteins)].copy()
            val_df = ebd_df[ebd_df[protein_col].isin(val_proteins)].copy()
            test_df = ebd_df[ebd_df[protein_col].isin(test_proteins)].copy()

            # Verify no overlap
            train_proteins_set = set(train_proteins)
            val_proteins_set = set(val_proteins)
            test_proteins_set = set(test_proteins)
            assert len(train_proteins_set & val_proteins_set) == 0, "Overlap between train and val proteins!"
            assert len(train_proteins_set & test_proteins_set) == 0, "Overlap between train and test proteins!"
            assert len(val_proteins_set & test_proteins_set) == 0, "Overlap between val and test proteins!"

        elif split_type == 'random':
            # Random split: split randomly without considering drugs/proteins
            train_df, temp_df = train_test_split(
                ebd_df,
                test_size=0.2,
                random_state=42
            )
            val_df, test_df = train_test_split(
                temp_df,
                test_size=0.5,  # 0.5 of 0.2 = 0.1 total
                random_state=42
            )
            train_df = train_df.copy()
            val_df = val_df.copy()
            test_df = test_df.copy()

        else:
            # Guard: never fall through and save the previous iteration's splits.
            raise ValueError(f"Unknown split type: {split_type}")

        # Reset indices
        train_df = train_df.reset_index(drop=True)
        val_df = val_df.reset_index(drop=True)
        test_df = test_df.reset_index(drop=True)

        # Print split statistics
        print(f"    Train: {len(train_df)} samples ({len(train_df)/len(df)*100:.1f}%)")
        print(f"    Val: {len(val_df)} samples ({len(val_df)/len(df)*100:.1f}%)")
        print(f"    Test: {len(test_df)} samples ({len(test_df)/len(df)*100:.1f}%)")

        if split_type == 'cold_drug':
            print(f"    Train drugs: {len(train_drugs)}, Val drugs: {len(val_drugs)}, Test drugs: {len(test_drugs)}")
        elif split_type == 'cold_protein':
            print(f"    Train proteins: {len(train_proteins)}, Val proteins: {len(val_proteins)}, Test proteins: {len(test_proteins)}")

        # Folder name: catpred_km_<embedding column>_<split type>.
        # The former `ebd_col.replace('_Embedding', '')` never matched the
        # lowercase "_embedding" suffix, so the full column name was already used.
        # It is kept explicitly so existing output folders stay valid.
        ebd_prefix = ebd_col
        folder_name = f"catpred_km_{ebd_prefix}_{split_type}"
        output_dir = output_base / folder_name
        output_dir.mkdir(parents=True, exist_ok=True)

        # Save splits
        train_df.to_parquet(output_dir / 'train.parquet')
        val_df.to_parquet(output_dir / 'val.parquet')
        test_df.to_parquet(output_dir / 'test.parquet')

        print(f"    Saved to: {output_dir}")

print("\n" + "="*60)
print("All splits generated successfully!")
print("="*60)

In [2]:
# NOTE(review): this .pt file holds a pandas DataFrame (see the preview below),
# so torch.load has to unpickle it; load trusted files only. Newer torch
# releases default to weights_only=True and may refuse such files.
embeddings_file = '/Users/cheng.wang/Documents/mpi-web/exp_of_catpred_MPEK_EITLEM_inhouse_dataset/embeddings_dataset/20251203_data_km_with_features_and_embeddings.pt'
embeddings = torch.load(embeddings_file)
embeddings.head()

Unnamed: 0,ECNumber,Type,Organism,Substrate,Kcat,Km,pH,Temp,Smiles,Sequence,log10_Kcat,log10_Km,Kcat_Unit,Km_Unit,metabolite_features,protein_features,ESMv1_embedding,ESM2_embedding,MUTAPLM_embedding,ProteinCLIP_embedding
0,1.1.99.31,mutant,Pseudomonas putida,2-Hydroxy-3-butynoate,14.8,4.3,7.5,20.0,C#CC(C(=O)[O-])O,MSQNLFNVEDYRKLRQKRLPKMVYDYLEGGAEDEYGVKHNRDVFQQ...,1.170262,0.633468,s^(-1),mM,"[0.1228977, 0.7496512, -0.24066311, -0.4157082...","[-0.016876074, -0.05564796, 0.076313846, 0.080...","[0.029985197, 0.09301364, -0.0493219, 0.039276...","[0.037578337, -0.08526868, -0.006642793, -0.01...","[0.36402643, 0.0801322, 0.040295407, 1.8510873...","[-0.04857145, -0.0003646736, 0.005266936, 0.00..."
1,1.1.99.31,wildtype,Pseudomonas putida,2-Hydroxy-3-butynoate,3.9,22.0,7.5,20.0,C#CC(C(=O)[O-])O,MSQNLFNVEDYRKLRQKRLPKMVYDYLEGGAEDEYGVKHNRDVFQQ...,0.591065,1.342423,s^(-1),mM,"[0.1228977, 0.7496512, -0.24066311, -0.4157082...","[-0.016387729, -0.055583365, 0.07538496, 0.082...","[0.030519469, 0.09706152, -0.05022144, 0.03885...","[0.036715094, -0.08720552, -0.0068982467, -0.0...","[0.43954754, 0.12984525, -0.09636725, 1.915048...","[-0.04991224, 0.0006942024, 0.005273169, 0.003..."
2,1.4.3.2,wildtype,Crotalus atrox,L-propargylglycine,7.75,11.78,7.5,28.0,C#CCC(C(=O)O)N,MNVFFMFSLLFLAALGSCAHDRNPLEECFRETDYEEFLEIAKNGLT...,0.889302,1.071145,s^(-1),mM,"[0.08531234, 0.018572615, 0.079737425, -0.7695...","[-0.0658514, -0.14893617, 0.073952675, 0.22143...","[0.061326716, 0.10479735, -0.117684826, 0.0421...","[0.005299273, -0.0010562497, -0.015425095, 0.0...","[-1.3927803, 2.2993002, -0.4493971, -0.2907000...","[0.16714047, 0.07128104, -0.14348145, -0.12102..."
3,1.1.1.169,mutant,Escherichia coli,Propargylamine,,65.0,7.5,25.0,C#CCN,MKITVLGCGALGQLWLTALCKQGHEVQGWLRVPQPYCSVNLVETDG...,,1.812913,,mM,"[-0.061515212, 0.5467514, 0.27831262, -1.00274...","[-0.013440645, -0.18082525, 0.061067406, -0.07...","[-0.069296174, 0.09819967, -0.13210024, 0.0269...","[0.022534946, -0.13141964, 0.021545902, 0.0157...","[-0.21310124, -0.6200448, 0.31904882, 1.065778...","[0.067515224, -0.052900158, -0.14412864, -0.12..."
4,1.6.5.2,wildtype,Kluyveromyces marxianus,"2,3-Dichloro-5,6-dicyano-1,4-benzoquinone",8.666667,298.2,6.5,25.0,C(#N)C1=C(C(=O)C(=C(C1=O)Cl)Cl)C#N,MSSFLSKRFISTTQRAMSQLPKAKSLIYSSHDQDVSKILKVHTYQP...,0.937852,2.474508,s^(-1),mM,"[0.30274612, 0.08312379, 0.14933385, -0.178018...","[-0.066582374, 0.018061409, 0.21740742, 0.0066...","[-0.10812851, 0.102599956, -0.08162622, -0.101...","[0.012352087, -0.05722501, -0.03693495, 0.0513...","[-1.1678288, 0.4664118, 0.15785684, 3.8421142,...","[0.027986268, 0.06020311, 0.059641764, -0.0419..."


In [14]:
import torch
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from pathlib import Path
import os
import glob
# Set random seed for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# Load every per-dataset protein-embedding table and stack them into one frame.
embedding_files = sorted(glob.glob('/Users/cheng.wang/Documents/mpi-web/enzyme_activity_data/*.pt'))
if not embedding_files:
    # pd.concat([]) would otherwise fail with an opaque "No objects to concatenate".
    raise FileNotFoundError("No embedding .pt files matched in enzyme_activity_data/")

embedding_frames = []
for file in embedding_files:
    # File names look like "<date>_dataset_<DATASET>_<label>_..." (see the output below).
    # basename is portable, unlike splitting on '/'.
    name_parts = os.path.basename(file).split('_')
    dataset_name, label_name = name_parts[2], name_parts[3]
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
    data = torch.load(file)
    embedding_frames.append(data)
    print(dataset_name, label_name, file, list(data.columns))

# Each file's index restarts at 0; ignore_index avoids duplicate index labels.
embeddings = pd.concat(embedding_frames, ignore_index=True)

EITLEM kcat /Users/cheng.wang/Documents/mpi-web/enzyme_activity_data/20251204_dataset_EITLEM_kcat_kcat_data_with_features_and_embeddings.pt ['Sequence', 'ESMv1_embedding', 'ESM2_embedding', 'MUTAPLM_embedding', 'ProteinCLIP_embedding']
EITLEM kkm /Users/cheng.wang/Documents/mpi-web/enzyme_activity_data/20251204_dataset_EITLEM_kkm_kkm_data_with_features_and_embeddings.pt ['Sequence', 'ESMv1_embedding', 'ESM2_embedding', 'MUTAPLM_embedding', 'ProteinCLIP_embedding']
EITLEM km /Users/cheng.wang/Documents/mpi-web/enzyme_activity_data/20251204_dataset_EITLEM_km_km_data_with_features_and_embeddings.pt ['Sequence', 'ESMv1_embedding', 'ESM2_embedding', 'MUTAPLM_embedding', 'ProteinCLIP_embedding']
MPEK kcat /Users/cheng.wang/Documents/mpi-web/enzyme_activity_data/20251204_dataset_MPEK_kcat_data_kcat_with_features_and_embeddings.pt ['Sequence', 'ESMv1_embedding', 'ESM2_embedding', 'MUTAPLM_embedding', 'ProteinCLIP_embedding']
MPEK kcat /Users/cheng.wang/Documents/mpi-web/enzyme_activity_data/20

In [16]:
# Preview the stacked embedding table.
embeddings.head(n=5)

Unnamed: 0,Sequence,ESMv1_embedding,ESM2_embedding,MUTAPLM_embedding,ProteinCLIP_embedding
0,MRAIVLERFGIENIKIEDIDDESPGIPVKITMAGLNPVDYSTVNGN...,"[0.04086296, 0.15116347, -0.07425522, 0.174843...","[0.03968187, -0.008488885, -0.0144894235, 0.08...","[0.020446844, 0.3811567, -2.2385237, 1.6300449...","[-0.040493883, 0.06432688, -0.0041583185, -0.0..."
1,MRAVRLVEIGKPLSLQEIGVPKPKGPQVLIKVEAAGVCHSDVHMRQ...,"[-0.10295879, 0.20300558, 0.06610628, 0.066524...","[0.041890863, -0.053846564, -0.0022026624, 0.0...","[-0.13251053, 0.7809831, -1.9364418, 1.4279354...","[0.00016483433, 0.15033893, -0.05325313, -0.08..."
2,MRAVRLVEIGKPLSLQEIGVPKPKGPQVLIKVEAAGVCHSDVHMRQ...,"[-0.11003746, 0.19600387, 0.05785173, 0.071878...","[0.04101067, -0.055868216, -0.004563694, 0.036...","[-0.1467754, 0.47730213, -1.7648451, 1.7315975...","[0.004419966, 0.14977796, -0.051655266, -0.080..."
3,MRAVRLVEIGKPLSLQEIGVPKPKGPQVLIKVEAAGVCHSDVHMRQ...,"[-0.09918236, 0.20210914, 0.061877523, 0.07223...","[0.044095144, -0.05485441, -0.0029434075, 0.03...","[-0.055095762, 0.68348485, -1.7683418, 1.44411...","[0.00045073644, 0.14759801, -0.05338094, -0.07..."
4,MSVARIVFPPLSHVGWGALDQLVPEVKRLGAKHILVITDPMLVKIG...,"[0.023023674, 0.049577378, 0.015320542, 0.0299...","[0.029797172, -0.11233459, 0.019906148, 0.0228...","[0.017170891, -0.23267624, -1.2182617, 1.85922...","[0.06639592, 0.0630165, -0.009678955, -0.08530..."


In [19]:
# Keep one embedding row per unique protein sequence (first occurrence wins),
# so that merging on Sequence later cannot duplicate rows.
embeddings_de = (
    embeddings
    .drop_duplicates(subset="Sequence", keep="first")
    .reset_index(drop=True)
)
embeddings_de.shape

(26240, 5)

In [42]:
# Feature tables written by each experiment's A01_dataset step, in sorted path
# order. The next cell selects files from this list by position.
joblib_files = sorted(
    glob.glob('/Users/cheng.wang/Documents/mpi-web/exp_of_catpred_MPEK_EITLEM_inhouse_dataset/experiments/*/A01_dataset/*with_features.joblib')
)
joblib_files

['/Users/cheng.wang/Documents/mpi-web/exp_of_catpred_MPEK_EITLEM_inhouse_dataset/experiments/dataset_EITLEM_kcat/A01_dataset/kcat_data_with_features.joblib',
 '/Users/cheng.wang/Documents/mpi-web/exp_of_catpred_MPEK_EITLEM_inhouse_dataset/experiments/dataset_EITLEM_kkm/A01_dataset/kkm_data_with_features.joblib',
 '/Users/cheng.wang/Documents/mpi-web/exp_of_catpred_MPEK_EITLEM_inhouse_dataset/experiments/dataset_EITLEM_km/A01_dataset/km_data_with_features.joblib',
 '/Users/cheng.wang/Documents/mpi-web/exp_of_catpred_MPEK_EITLEM_inhouse_dataset/experiments/dataset_MPEK_kcat/A01_dataset/data_kcat_with_features.joblib',
 '/Users/cheng.wang/Documents/mpi-web/exp_of_catpred_MPEK_EITLEM_inhouse_dataset/experiments/dataset_MPEK_kcat/A01_dataset/data_km_with_features.joblib',
 '/Users/cheng.wang/Documents/mpi-web/exp_of_catpred_MPEK_EITLEM_inhouse_dataset/experiments/dataset_MPEK_km/A01_dataset/data_km_with_features.joblib',
 '/Users/cheng.wang/Documents/mpi-web/exp_of_catpred_MPEK_EITLEM_inhou

In [72]:
def _load_joblib_frame(idx):
    """Load ``joblib_files[idx]`` (a pickled DataFrame) and renumber its rows from 0."""
    return joblib.load(joblib_files[idx]).reset_index(drop=True)


# Positions index the sorted `joblib_files` list.
# NOTE(review): position 4 (dataset_MPEK_kcat/.../data_km_with_features.joblib in
# the listing above) is not loaded. Entries 6-8 are cut off in the saved output;
# confirm they are the catpred kcat / ki / km files. Position-based selection
# breaks silently if files are added or renamed.
EITLEM_kcat = _load_joblib_frame(0)
EITLEM_kkm = _load_joblib_frame(1)
EITLEM_km = _load_joblib_frame(2)
MPEK_kcat = _load_joblib_frame(3)
MPEK_km = _load_joblib_frame(5)
catpred_kcat = _load_joblib_frame(6)
catpred_ki = _load_joblib_frame(7)
catpred_km = _load_joblib_frame(8)

In [73]:
import numpy as np
import pandas as pd

def add_id(prefix, df):
    """Return a copy of ``df`` with an ``ID`` column of per-row identifiers.

    IDs are ``f"{prefix}_{i:06d}"`` for i = 0 .. len(df) - 1, i.e. the first row
    gets ``prefix_000000``. The input frame is not modified.
    """
    labelled = df.copy()
    labelled["ID"] = [f"{prefix}_{i:06d}" for i in range(len(labelled))]
    return labelled

def add_label_column(df, col, eps=1e-9):
    """Return a copy of ``df`` with ``Label = log10(df[col] + eps)``.

    Steps, applied to the copy:
      1. ``col`` is coerced to numeric; unparsable entries become NaN.
      2. Negative values are clipped to 0. This also overwrites ``col`` in the
         returned frame.
      3. ``Label`` is set to ``log10(value + eps)``. ``eps`` keeps zeros finite:
         a zero maps to ``log10(eps)``, i.e. -9 with the default.

    Args:
        df: input frame; not modified.
        col: name of the raw value column.
        eps: additive offset applied before the log. Defaults to 1e-9, the value
            previously hard-coded.

    Returns:
        A new DataFrame. NaN entries in ``col`` produce NaN labels.
    """
    df = df.copy()

    # 1) Coerce to numeric; invalid entries become NaN.
    df[col] = pd.to_numeric(df[col], errors="coerce")

    # 2) Clip negatives to 0 (these metrics should never be negative).
    df[col] = df[col].clip(lower=0)

    # 3) log10 transform; eps avoids log10(0) = -inf.
    df["Label"] = np.log10(df[col].to_numpy(dtype=float) + eps)

    return df

In [74]:
# Attach IDs, log10 labels and per-sequence protein embeddings to each EITLEM table.
# Embedding columns from a previous run are dropped before merging, so re-running
# this cell does not create *_x / *_y duplicates. validate= ensures embeddings_de
# has one row per Sequence, so the left merge cannot multiply rows.
_emb_cols = embeddings_de.columns.drop("Sequence")
EITLEM_kcat, EITLEM_kkm, EITLEM_km = (
    pd.merge(
        add_label_column(add_id(prefix, frame), "value").drop(columns=_emb_cols, errors="ignore"),
        embeddings_de,
        on="Sequence",
        how="left",
        validate="many_to_one",
    )
    for prefix, frame in (
        ("EITLEM_kcat", EITLEM_kcat),
        ("EITLEM_kkm", EITLEM_kkm),
        ("EITLEM_km", EITLEM_km),
    )
)

In [75]:
# Attach IDs, log10 labels (Km / Kcat respectively) and per-sequence protein
# embeddings to the MPEK tables. Existing embedding columns are dropped first so
# the cell is safe to re-run; validate= guards against row multiplication.
_emb_cols = embeddings_de.columns.drop("Sequence")
MPEK_km, MPEK_kcat = (
    pd.merge(
        add_label_column(add_id(prefix, frame), value_col).drop(columns=_emb_cols, errors="ignore"),
        embeddings_de,
        on="Sequence",
        how="left",
        validate="many_to_one",
    )
    for prefix, frame, value_col in (
        ("MPEK_km", MPEK_km, "Km"),
        ("MPEK_kcat", MPEK_kcat, "Kcat"),
    )
)

In [76]:
# Attach IDs, log10 labels and per-sequence protein embeddings to the CatPred
# tables. Existing embedding columns are dropped first so the cell is safe to
# re-run; validate= guards against row multiplication.
_emb_cols = embeddings_de.columns.drop("Sequence")
catpred_kcat, catpred_ki, catpred_km = (
    pd.merge(
        add_label_column(add_id(prefix, frame), "value").drop(columns=_emb_cols, errors="ignore"),
        embeddings_de,
        on="Sequence",
        how="left",
        validate="many_to_one",
    )
    for prefix, frame in (
        ("catpred_kcat", catpred_kcat),
        ("catpred_ki", catpred_ki),
        ("catpred_km", catpred_km),
    )
)

In [78]:
# Per-column missing-value counts for each EITLEM table after the merge.
for _table in (EITLEM_kcat, EITLEM_kkm, EITLEM_km):
    print(_table.isna().sum())

ECNumber                 0
organism                 0
Substrate                0
Sequence                 0
Smiles                   0
value                    0
type                     0
source                   0
metabolite_features      0
protein_features         0
label                    0
ID                       0
Label                    0
ESMv1_embedding          0
ESM2_embedding           0
MUTAPLM_embedding        0
ProteinCLIP_embedding    0
dtype: int64
ECNumber                 0
organism                 0
Substrate                0
Sequence                 0
Smiles                   0
value                    0
type                     0
source                   0
metabolite_features      0
protein_features         0
label                    0
ID                       0
Label                    0
ESMv1_embedding          0
ESM2_embedding           0
MUTAPLM_embedding        0
ProteinCLIP_embedding    0
dtype: int64
ECNumber                 0
organism                 0
Su

In [85]:
# Per-column missing-value counts for the MPEK tables after the merge.
for _table in (MPEK_kcat, MPEK_km):
    print(_table.isna().sum())

original_index              0
ECNumber                    0
Type                        0
Organism                    0
Substrate                   0
Kcat                        0
Km                       3656
pH                          0
Temp                        0
Smiles                      0
Sequence                    0
log10_Kcat                  0
log10_Km                 3656
Kcat_Unit                   0
Km_Unit                  3656
metabolite_features         0
protein_features            0
ID                          0
Label                       0
ESMv1_embedding             0
ESM2_embedding              0
MUTAPLM_embedding           0
ProteinCLIP_embedding       0
dtype: int64
ECNumber                     0
Type                         0
Organism                     0
Substrate                    0
Kcat                     10348
Km                           0
pH                           0
Temp                         0
Smiles                       0
Sequence          

In [None]:
# Per-column missing-value counts for the CatPred tables after the merge.
for _table in (catpred_kcat, catpred_ki, catpred_km):
    print(_table.isna().sum())

Sequence                    0
sequence_source             0
uniprot                     0
Smiles                      0
value                       0
ec                          0
log10_value                 0
reactant_smiles             0
product_smiles           2335
log10kcat_max               0
metabolite_features         0
protein_features            0
ID                          0
Label                       0
ESMv1_embedding             0
ESM2_embedding              0
MUTAPLM_embedding           0
ProteinCLIP_embedding       0
dtype: int64


In [None]:
# dataset_na.to_csv('/Users/cheng.wang/Documents/mpi-web/inhouse_dataset_with_embeddings_dataset_na_1206.csv', index=False)

In [92]:
def check_cols(df, cols):
    """Print the total number of missing values in ``df[cols]``.

    Every name in ``cols`` must be a column of ``df``. Any missing names are
    reported first, one line each, and then a ``KeyError`` listing all of them
    is raised. Previously the message was printed and ``df[cols]`` then failed
    with pandas' generic error.

    Args:
        df: frame to inspect.
        cols: column names that must be present.

    Raises:
        KeyError: if any of ``cols`` is not a column of ``df``.
    """
    missing = [col for col in cols if col not in df.columns]
    for col in missing:
        print(f"{col} not in {df.columns}")
    if missing:
        raise KeyError(f"Missing required columns: {missing}")
    # Total NaN count over all requested columns (0 means fully populated).
    print(int(df[cols].isna().sum().sum()))
    
# Every prepared table must expose these columns with no missing values.
_required_cols = ['ID','Label','Sequence','Smiles','metabolite_features','protein_features','ESMv1_embedding','ESM2_embedding','MUTAPLM_embedding','ProteinCLIP_embedding']
for _table in (EITLEM_kcat, EITLEM_km, EITLEM_kkm,
               MPEK_kcat, MPEK_km,
               catpred_kcat, catpred_ki, catpred_km):
    check_cols(_table, _required_cols)



0
0
0
0
0
0
0
0


In [101]:
# Build cold-drug / cold-protein / random train/val/test splits of the CatPred Km
# data once per protein-embedding column, and write each split to parquet.
from pathlib import Path

df = catpred_km.copy()

# Define embedding columns
embedding_cols = ['ESMv1_embedding',
                  'ESM2_embedding',
                  'MUTAPLM_embedding',
                  'ProteinCLIP_embedding']

# Identify drug and protein columns (handle both naming conventions)
if 'Smiles' in df.columns:
    drug_col = 'Smiles'
elif 'Drug' in df.columns:
    drug_col = 'Drug'
else:
    raise ValueError("Could not find drug column. Expected 'Smiles' or 'Drug'")

if 'Sequence' in df.columns:
    protein_col = 'Sequence'
elif 'Target' in df.columns:
    protein_col = 'Target'
else:
    raise ValueError("Could not find protein column. Expected 'Sequence' or 'Target'")

print(f"\nUsing drug column: {drug_col}")
print(f"Using protein column: {protein_col}")

# Define split types and dataset splits
split_types = ['cold_drug', 'cold_protein', 'random']
dataset_splits = ['train', 'test', 'val']  # NOTE(review): not used in this cell

# Root directory for the per-(embedding, split) output folders.
output_base = Path('/Users/cheng.wang/Documents/mpi-web/exp_of_catpred_MPEK_EITLEM_inhouse_dataset/enzyme_embeddings_dataset')


def cold_split(unique_items, test_size=0.2, val_size=0.1, random_state=42):
    """
    Perform cold split on unique items.
    Returns train, val, test items.

    ``test_size`` / ``val_size`` are fractions of the whole collection; the
    validation fraction is rescaled to the non-test remainder.
    """
    train_items, test_items = train_test_split(
        unique_items,
        test_size=test_size,
        random_state=random_state
    )
    # Adjust val_size to account for test_size
    val_ratio = val_size / (1 - test_size)
    train_items, val_items = train_test_split(
        train_items,
        test_size=val_ratio,
        random_state=random_state
    )
    return train_items, val_items, test_items


# The split keys do not depend on the embedding column, so compute them once.
unique_drugs = df[drug_col].unique()
unique_proteins = df[protein_col].unique()

# Process each embedding type
for ebd_col in embedding_cols:
    print(f"\n{'='*60}")
    print(f"Processing {ebd_col}")
    print(f"{'='*60}")

    # New frame whose `protein_features` holds this embedding. `df` itself (and
    # its original protein_features column) is left untouched across iterations.
    ebd_df = df.assign(protein_features=df[ebd_col])

    print(f"Total samples: {len(df)}")
    print(f"Unique drugs: {len(unique_drugs)}")
    print(f"Unique proteins: {len(unique_proteins)}")

    # Process each split type
    for split_type in split_types:
        print(f"\n  Processing {split_type} split...")

        if split_type == 'cold_drug':
            # Cold drug split: split by unique drugs
            train_drugs, val_drugs, test_drugs = cold_split(unique_drugs)

            train_df = ebd_df[ebd_df[drug_col].isin(train_drugs)].copy()
            val_df = ebd_df[ebd_df[drug_col].isin(val_drugs)].copy()
            test_df = ebd_df[ebd_df[drug_col].isin(test_drugs)].copy()

            # Verify no overlap
            train_drugs_set = set(train_drugs)
            val_drugs_set = set(val_drugs)
            test_drugs_set = set(test_drugs)
            assert len(train_drugs_set & val_drugs_set) == 0, "Overlap between train and val drugs!"
            assert len(train_drugs_set & test_drugs_set) == 0, "Overlap between train and test drugs!"
            assert len(val_drugs_set & test_drugs_set) == 0, "Overlap between val and test drugs!"

        elif split_type == 'cold_protein':
            # Cold protein split: split by unique proteins
            train_proteins, val_proteins, test_proteins = cold_split(unique_proteins)

            train_df = ebd_df[ebd_df[protein_col].isin(train_proteins)].copy()
            val_df = ebd_df[ebd_df[protein_col].isin(val_proteins)].copy()
            test_df = ebd_df[ebd_df[protein_col].isin(test_proteins)].copy()

            # Verify no overlap
            train_proteins_set = set(train_proteins)
            val_proteins_set = set(val_proteins)
            test_proteins_set = set(test_proteins)
            assert len(train_proteins_set & val_proteins_set) == 0, "Overlap between train and val proteins!"
            assert len(train_proteins_set & test_proteins_set) == 0, "Overlap between train and test proteins!"
            assert len(val_proteins_set & test_proteins_set) == 0, "Overlap between val and test proteins!"

        elif split_type == 'random':
            # Random split: split randomly without considering drugs/proteins
            train_df, temp_df = train_test_split(
                ebd_df,
                test_size=0.2,
                random_state=42
            )
            val_df, test_df = train_test_split(
                temp_df,
                test_size=0.5,  # 0.5 of 0.2 = 0.1 total
                random_state=42
            )
            train_df = train_df.copy()
            val_df = val_df.copy()
            test_df = test_df.copy()

        else:
            # Guard: never fall through and save the previous iteration's splits.
            raise ValueError(f"Unknown split type: {split_type}")

        # Reset indices
        train_df = train_df.reset_index(drop=True)
        val_df = val_df.reset_index(drop=True)
        test_df = test_df.reset_index(drop=True)

        # Print split statistics
        print(f"    Train: {len(train_df)} samples ({len(train_df)/len(df)*100:.1f}%)")
        print(f"    Val: {len(val_df)} samples ({len(val_df)/len(df)*100:.1f}%)")
        print(f"    Test: {len(test_df)} samples ({len(test_df)/len(df)*100:.1f}%)")

        if split_type == 'cold_drug':
            print(f"    Train drugs: {len(train_drugs)}, Val drugs: {len(val_drugs)}, Test drugs: {len(test_drugs)}")
        elif split_type == 'cold_protein':
            print(f"    Train proteins: {len(train_proteins)}, Val proteins: {len(val_proteins)}, Test proteins: {len(test_proteins)}")

        # Folder name: catpred_km_<embedding column>_<split type>.
        # The former `ebd_col.replace('_Embedding', '')` never matched the
        # lowercase "_embedding" suffix, so the full column name was already used
        # (e.g. catpred_km_ESMv1_embedding_cold_drug). It is kept explicitly so
        # existing output folders stay valid.
        ebd_prefix = ebd_col
        folder_name = f"catpred_km_{ebd_prefix}_{split_type}"
        output_dir = output_base / folder_name
        output_dir.mkdir(parents=True, exist_ok=True)

        # Save splits
        train_df.to_parquet(output_dir / 'train.parquet')
        val_df.to_parquet(output_dir / 'val.parquet')
        test_df.to_parquet(output_dir / 'test.parquet')

        print(f"    Saved to: {output_dir}")

print("\n" + "="*60)
print("All splits generated successfully!")
print("="*60)


Using drug column: Smiles
Using protein column: Sequence

Processing ESMv1_embedding
Total samples: 41174
Unique drugs: 10535
Unique proteins: 12355

  Processing cold_drug split...
    Train: 28489 samples (69.2%)
    Val: 4262 samples (10.4%)
    Test: 8423 samples (20.5%)
    Train drugs: 7374, Val drugs: 1054, Test drugs: 2107
    Saved to: /Users/cheng.wang/Documents/mpi-web/exp_of_catpred_MPEK_EITLEM_inhouse_dataset/enzyme_embeddings_dataset/catpred_km_ESMv1_embedding_cold_drug

  Processing cold_protein split...
    Train: 28449 samples (69.1%)
    Val: 4007 samples (9.7%)
    Test: 8718 samples (21.2%)
    Train proteins: 8648, Val proteins: 1236, Test proteins: 2471
    Saved to: /Users/cheng.wang/Documents/mpi-web/exp_of_catpred_MPEK_EITLEM_inhouse_dataset/enzyme_embeddings_dataset/catpred_km_ESMv1_embedding_cold_protein

  Processing random split...
    Train: 32939 samples (80.0%)
    Val: 4117 samples (10.0%)
    Test: 4118 samples (10.0%)
    Saved to: /Users/cheng.wang/

In [None]:
import numpy as np
import pandas as pd

def add_id(prefix, df):
    """Return a copy of ``df`` with an ``ID`` column of per-row identifiers.

    IDs are ``f"{prefix}_{i:06d}"`` for i = 0 .. len(df) - 1, i.e. the first row
    gets ``prefix_000000``. The input frame is not modified.
    """
    labelled = df.copy()
    labelled["ID"] = [f"{prefix}_{i:06d}" for i in range(len(labelled))]
    return labelled

def add_label_column(df, col, eps=1e-9):
    """Return a copy of ``df`` with ``Label = log10(df[col] + eps)``.

    Steps, applied to the copy:
      1. ``col`` is coerced to numeric; unparsable entries become NaN.
      2. Negative values are clipped to 0. This also overwrites ``col`` in the
         returned frame.
      3. ``Label`` is set to ``log10(value + eps)``. ``eps`` keeps zeros finite:
         a zero maps to ``log10(eps)``, i.e. -9 with the default.

    Args:
        df: input frame; not modified.
        col: name of the raw value column.
        eps: additive offset applied before the log. Defaults to 1e-9, the value
            previously hard-coded.

    Returns:
        A new DataFrame. NaN entries in ``col`` produce NaN labels.
    """
    df = df.copy()

    # 1) Coerce to numeric; invalid entries become NaN.
    df[col] = pd.to_numeric(df[col], errors="coerce")

    # 2) Clip negatives to 0 (these metrics should never be negative).
    df[col] = df[col].clip(lower=0)

    # 3) log10 transform; eps avoids log10(0) = -inf.
    df["Label"] = np.log10(df[col].to_numpy(dtype=float) + eps)

    return df