## cd-hit and split datasets to 10 folds
### # http://cd-hit.org; http://www.bioinformatics.org/cd-hit/; https://github.com/weizhongli/cdhit/archive/V4.6.2.tar.gz;
### tar -zxvf cdhit-4.6.2.tar.gz; cd cdhit-4.6.2; make; export PATH=$PATH:/nfs/my/Xu/jicm/WWDynoMTGBM/share/cdhit-4.6.2; source ~/.bashrc;
### cd-hit -i ./data_process/cd-hit/sequences_mixed.fasta -o ./data_process/cd-hit/clustered_sequences.fasta -c 0.6 -n 3 -T 8

In [1]:
import string
from itertools import product
import pandas as pd


df_all = pd.read_pickle("./dataset/df_all_log_transformed.pkl")
input_sequences = df_all['sequence'].drop_duplicates()
print(f'Number of sequences: {len(input_sequences)}')
# 混合大小写字符集
characters = string.ascii_letters  # A-Z + a-z (52 个字符)
# 生成所有可能的 3 位字符组合
ids = [''.join(comb) for comb in product(characters, repeat=3)]

# 分配 ID
ids = ids[:len(input_sequences)]

# 定义生成 FASTA 文件的函数
def generate_fasta(filename):
    with open(filename, 'w') as f:
        for id, sequence in zip(ids, input_sequences):
            f.write(f">{id}\n")
            f.write(f"{sequence}\n")

# 保存为 fasta 文件
generate_fasta('./cd-hit/sequences_mixed.fasta')
print("FASTA is saved to ./cd-hit/sequences_mixed.fasta")

Number of sequences: 16659
FASTA is saved to ./cd-hit/sequences_mixed.fasta


In [2]:
import random
import pandas as pd

# 解析 .clstr 文件，获取每个聚类中的序列 ID
def parse_clstr(clstr_file):
    clusters = []
    with open(clstr_file, 'r') as f:
        cluster = []
        for line in f:
            if line.startswith('>Cluster'):
                if cluster:
                    clusters.append(cluster)
                cluster = []
            else:
                seq_id = line.split('>')[1].split('...')[0]
                cluster.append(seq_id)
        if cluster:
            clusters.append(cluster)
    print(f"Total number of clusters: {len(clusters)}")
    return clusters


clstr_file = "./cd-hit/clustered_sequences.fasta.clstr"
fasta_file = "./cd-hit/sequences_mixed.fasta"

# 加载 sequences_mixed.fasta，建立序列和 ID 的映射
with open(fasta_file, 'r') as f:
    records = f.readlines()
ids, sequences = [], []
for i in range(0, len(records), 2):  # 每两个行为一条记录
    ids.append(records[i].strip().lstrip('>'))  # 去掉 '>'
    sequences.append(records[i + 1].strip())
fasta_df = pd.DataFrame({'id': ids, 'sequence': sequences})
print(f"Total sequences in FASTA file: {len(fasta_df)}")

# 加载 df_all 文件
print(f"Total sequences in df_all: {len(df_all)}")

# 将序列与 ID 合并
df_all_index = df_all.index
df_all = pd.merge(df_all, fasta_df, on='sequence', how='left')
df_all.index = df_all_index
print(f"Total sequences after merging with IDs: {len(df_all)}")
df_all.head()

Total sequences in FASTA file: 16659
Total sequences in df_all: 46534
Total sequences after merging with IDs: 46534


Unnamed: 0,ec,organism,uniprot,substrate,smiles,sequence,type,ph,t,esm2,...,prott5,prost5,molebert,transsmiles,logkm,logkcat,logkcatkm,logp,mw,id
0,3.5.5.1,Saccharolobus solfataricus,P95896,trichloroacetonitrile,C(#N)C(Cl)(Cl)Cl,MGIKLPTLEDLREISKQFNLDLEDEELKSFLQLLKLQLESYERLDS...,wild,7.4,70.0,"[0.07309095, -0.085310504, 0.03223636, -0.0094...",...,"[0.0620778725, 0.0053198822, 0.026737025, -0.0...","[0.0229359791, -0.0089250933, -0.0310355425000...","[-0.0235576797, -0.1898318082, -0.005378013, 0...","[-0.1102231815, -0.2566757202, 0.2837018669, 0...",-21.416413,-4.60517,,1.88018,144.388,aaa
1,1.21.99.4,Homo sapiens,,L-thyroxine,C1=C(C=C(C(=C1I)OC2=CC(=C(C(=C2)I)O)I)I)CC(C(=...,MGLPQPGLWLKRLWVLLEVAVHVVVGKVLLILFPDRVKRNILAMGE...,mutant,7.5,37.0,"[0.054237492, -0.04574185, 0.008709021, 0.0387...",...,"[0.0137849757, 0.0193469338, 0.0360138603, 0.0...","[-0.0300230943, -0.0246387329, -0.0323411487, ...","[0.11776212600000001, 0.24018001560000002, -0....","[-0.049280483300000004, -0.2168657631, 0.30448...",-20.192638,,,4.5573,776.872,aab
2,1.21.99.4,Homo sapiens,,L-thyroxine,C1=C(C=C(C(=C1I)OC2=CC(=C(C(=C2)I)O)I)I)CC(C(=...,MGLPQPGLWLKRLWVLLEVAVHVVVGKVLLILFPDRVKRNILAMGE...,wild,7.5,37.0,"[0.054237492, -0.04574185, 0.008709021, 0.0387...",...,"[0.0137849757, 0.0193469338, 0.0360138603, 0.0...","[-0.0300230943, -0.0246387329, -0.0323411487, ...","[0.11776212600000001, 0.24018001560000002, -0....","[-0.049280483300000004, -0.2168657631, 0.30448...",-19.658555,,,4.5573,776.872,aab
3,3.5.5.1,Saccharolobus solfataricus,P95896,Cinnamonitrile,C1=CC=C(C=C1)C=CC#N,MGIKLPTLEDLREISKQFNLDLEDEELKSFLQLLKLQLESYERLDS...,mutant,7.4,70.0,"[0.07308519, -0.08452837, 0.029972142, -0.0119...",...,"[0.0622685216, 0.0037881299, 0.027366610200000...","[0.022653413900000002, -0.008009411400000001, ...","[0.1382588446, -0.2295262814, -0.0154548008, 0...","[-0.07580477000000001, -0.26762363310000004, 0...",-18.45114,-4.135167,,2.22338,129.162,aac
4,3.5.5.1,Saccharolobus solfataricus,P95896,Malononitrile,C(C#N)C#N,MGIKLPTLEDLREISKQFNLDLEDEELKSFLQLLKLQLESYERLDS...,wild,7.4,70.0,"[0.07309095, -0.085310504, 0.03223636, -0.0094...",...,"[0.0620778725, 0.0053198822, 0.026737025, -0.0...","[0.0229359791, -0.0089250933, -0.0310355425000...","[0.4168405533, -0.2187459022, -0.1292698383, 0...","[-0.0909005329, -0.2873343229, 0.2866819799, 0...",-16.821293,-1.966113,,0.42366,66.063,aaa


In [3]:
def split_by_sequence_proportion(clusters, df_all, n_fold=10):
    total_sequences = len(df_all)
    fold_target = int(total_sequences / n_fold)
    print(f"Target count per fold: {int(fold_target)}")

    # 初始化每个子集的信息
    fold_ids = [set() for _ in range(n_fold)]
    fold_counts = [0 for _ in range(n_fold)]

    random.shuffle(clusters)  # 打乱聚类顺序，增加随机性

    for cluster in clusters:
        # 获取该聚类中的所有样本
        cluster_ids = set(cluster)
        cluster_df = df_all[df_all['id'].isin(cluster_ids)]
        cluster_size = len(cluster_df)

        # 找出当前负载最小的子集，且添加后不会远超目标
        for i in range(n_fold):
            if fold_counts[i] + cluster_size <= fold_target:
                fold_ids[i].update(cluster)
                fold_counts[i] += cluster_size
                break

    print(f"Final counts: {fold_counts}")

    # 构建子集 DataFrame
    split_dfs = [df_all[df_all['id'].isin(fold_ids[i])].copy() for i in range(n_fold)]
    for i, df in enumerate(split_dfs): print(f"Fold {i+1}: {len(df)} samples")

    return split_dfs


# 分析数据集中涉及的聚类数量，以及序列最多和最少的聚类的序列数量
def analyze_clusters(clusters, df, dataset_name):
    cluster_counts = []
    for cluster in clusters:
        cluster_sequences = df[df['id'].isin(cluster)]
        cluster_size = len(cluster_sequences)
        if cluster_size > 0:
            cluster_counts.append(cluster_size)

    num_clusters = len(cluster_counts)
    max_cluster_size = max(cluster_counts) if cluster_counts else 0
    min_cluster_size = min(cluster_counts) if cluster_counts else 0

    print(f"{dataset_name} Cluster Analysis:")
    print(f"  - Total Clusters: {num_clusters}")
    print(f"  - Max Cluster Size: {max_cluster_size}")
    print(f"  - Min Cluster Size: {min_cluster_size}")
    print("-" * 40)

# 解析聚类文件
clusters = parse_clstr(clstr_file)

# split
fold_index = []
fold_dfs = split_by_sequence_proportion(clusters, df_all, n_fold=5)

for fold_idx, fold_df in enumerate(fold_dfs):
    analyze_clusters(clusters, fold_df, f"Fold {fold_idx}")

Total number of clusters: 6087
Target count per fold: 9306
Final counts: [9306, 9306, 9306, 9306, 9306]
Fold 1: 9306 samples
Fold 2: 9306 samples
Fold 3: 9306 samples
Fold 4: 9306 samples
Fold 5: 9306 samples
Fold 0 Cluster Analysis:
  - Total Clusters: 1273
  - Max Cluster Size: 158
  - Min Cluster Size: 1
----------------------------------------
Fold 1 Cluster Analysis:
  - Total Clusters: 1212
  - Max Cluster Size: 257
  - Min Cluster Size: 1
----------------------------------------
Fold 2 Cluster Analysis:
  - Total Clusters: 1276
  - Max Cluster Size: 156
  - Min Cluster Size: 1
----------------------------------------
Fold 3 Cluster Analysis:
  - Total Clusters: 1104
  - Max Cluster Size: 342
  - Min Cluster Size: 1
----------------------------------------
Fold 4 Cluster Analysis:
  - Total Clusters: 1222
  - Max Cluster Size: 186
  - Min Cluster Size: 1
----------------------------------------


In [4]:
print(len(fold_dfs[0]))
fold_dfs[0].head()

9306


Unnamed: 0,ec,organism,uniprot,substrate,smiles,sequence,type,ph,t,esm2,...,prott5,prost5,molebert,transsmiles,logkm,logkcat,logkcatkm,logp,mw,id
7,3.5.2.6,Enterobacter cloacae,P05364,Cloxacillin,CC1=C(C(=NO1)C2=CC=CC=C2Cl)C(=O)NC3C4N(C3=O)C(...,MMRKSLCCALLLGISCSALATPVSEKQLAEVVANTITPLMKAQSVP...,,7.0,30.0,"[0.02284837, -0.1021297, -0.008256997, -0.0174...",...,"[0.0337163433, 0.1072318107, 0.0364475697, 0.0...","[0.0198395196, -0.0072697727000000005, -0.0390...","[0.0014601951, 0.0383156165, -0.0392316021, -0...","[-0.056131124500000004, -0.2073351443, 0.29625...",-14.508658,-5.521461,,2.54872,435.889,aae
17,1.3.1.22,Homo sapiens,,testosterone,CC12CCC3C(C1CCC2O)CCC4=CC(=O)CCC34C,MATATGVAEERLLAALAYLQCAVGCAVFARNRQTNSVYGRHALPSH...,,7.0,30.0,"[0.006440393, -0.043844223, -0.010472644, 0.03...",...,"[-0.0024389753, 0.08504800500000001, 0.0047531...","[-0.0351138115, -0.0130013935, -0.0182109457, ...","[0.28372663260000003, -0.08476485310000001, 0....","[-0.0756168738, -0.2516127825, 0.3206992149000...",-13.553146,,,3.8792,288.431,aai
23,2.5.1.59,Rattus norvegicus,,geranylgeranyl diphosphate,CC(=CCCC(=CCCC(=CCCC(=CCOP(=O)(O)OP(=O)(O)O)C)...,MAATEDDRLAGSGEGERLDFLRDRHVRFFQRCLQVLPERYSSLETS...,,7.7,30.0,"[0.02027583, 0.020071417, 0.017161898, 0.03857...",...,"[0.048964381200000004, 0.0857807696, 0.0552370...","[0.030504426, 0.0137842335, 0.0333012007, 0.01...","[0.18105946480000001, -0.22527112070000002, -0...","[-0.0602552779, -0.2366906852, 0.3021515608, 0...",-12.716898,,,6.3584,450.449,aam
43,3.1.3.48,Saccharomyces cerevisiae,Q00684,"6,8-Difluoro-4-methylumbelliferyl phosphate",CC1=CC(=O)OC2=C(C(=C(C=C12)F)OP(=O)(O)O)F,MRRSVYLDNTIEFLRGRVYLGAYDYTPEDTDELVFFTVEDAIFYNS...,mutant,7.0,30.0,"[-0.024799412, -0.08170397, -0.012841493, 0.02...",...,"[0.0227221362, -0.0156827234, -0.0347927324, -...","[-0.041735515, -0.0003951836, -0.0228151977, -...","[-0.08841883390000001, 0.0610267036, -0.277456...","[-0.049670942100000004, -0.2183186412, 0.30093...",-11.748648,-4.961845,6.779922,1.85112,292.13,aau
52,2.5.1.59,Saccharomyces cerevisiae,,geranylgeranyl diphosphate,CC(=CCCC(=CCCC(=CCCC(=CCOP(=O)(O)OP(=O)(O)O)C)...,MCQATNGPSRVVTKKHRKFFERHLQLLPSSHQGHDVNRMAIIFYSI...,,7.5,37.0,"[0.061458018, 0.034990672, 0.048507832, 0.0237...",...,"[0.043895810800000004, -0.0132369939, 0.018462...","[-0.0121555841, -0.0175855793, 0.058201097, -0...","[0.18105946480000001, -0.22527112070000002, -0...","[-0.0602552779, -0.2366906852, 0.3021515608, 0...",-11.512925,,,6.3584,450.449,aay


In [5]:
fold_index = []
for fold_idx, fold_df in enumerate(fold_dfs):
    # 保存 test 文件
    test_df = fold_df[['smiles', 'sequence', 'organism', 'ph', 't', 'logkm', 'logkcat', 'logkcatkm']]
    test_df.to_csv(f'./dataset/cdhit/cdhit_fold{fold_idx}_test.csv', index=False)

    # 获取 train 数据（除了当前 fold 的所有折）
    train_dfs = [df for i, df in enumerate(fold_dfs) if i != fold_idx]
    train_df = pd.concat(train_dfs)[['smiles', 'sequence', 'organism', 'ph', 't', 'logkm', 'logkcat', 'logkcatkm']]
    train_df.to_csv(f'./dataset/cdhit/cdhit_fold{fold_idx}_train.csv', index=False)
    fold_index.append([train_df.index.to_list(), test_df.index.to_list()])

pd.DataFrame(fold_index).to_pickle('./dataset/cdhit/cdhit_fold_index.pkl')

In [6]:
df_ = pd.read_csv('./dataset/cdhit/cdhit_fold0_train.csv')
df_.head()

Unnamed: 0,smiles,sequence,organism,ph,t,logkm,logkcat,logkcatkm
0,C1=CC(=C[N+](=C1)C2C(C(C(O2)COP(=O)([O-])OP(=O...,MPGWSCLVTGAGGFVGQRIIRMLVQEKELQEVRALDKVFRPETKEE...,Rattus norvegicus,7.25,37.0,-12.267948,,
1,C=C(C(=O)O)OP(=O)(O)O,MEKFLVIAGPCAIESEELLLKVGEEIKRLSEKFKEVEFVFKSSFDK...,Aquifex aeolicus,5.0,40.0,-11.928441,-0.776529,
2,[O-]P(=O)([O-])OP(=O)([O-])[O-],MSLLNVPAGKDLPEDIYVVIEIPANADPIKYEIDKESGALFVDRFM...,Escherichia coli,9.1,25.0,-11.512925,3.73767,
3,C(CC(=O)O)C(C(=O)O)N,MTMASKSDSTHDESGDEAADSTEPESALETARRQLYHAASYLDIDQ...,Halobacterium salinarum,9.0,40.0,-11.417615,3.178054,-5.654992
4,C1=CC(=C[N+](=C1)C2C(C(C(O2)COP(=O)(O)OP(=O)(O...,MGTRASCGELQADSRGSGDTAQPQPRQQAARGSAAGAESAMAEQVA...,Ovis aries,8.0,25.0,-11.330604,,
