# 使用CLEAN模型对模型输出的结果进行预测

CLEAN github：https://github.com/tttianhao/CLEAN?tab=readme-ov-file \
Paper：https://www.science.org/doi/10.1126/science.adf2465

In [1]:
import pandas as pd

## 1. 处理输入数据

读取数据

In [91]:
# The EC generation model writes its samples as a .txt file located under
# EC_training/output_samples/.

label = '2'       # target EC label; later compared against the predicted EC numbers
p = 0.75          # part of the sample file name (presumably a sampling parameter - TODO confirm)
version = '6'     # prefix of the sample file name
input_file_path = f'../../EC_training/output_samples/{version}_samples_EC_label_{label}_{p}.txt'
train_data = 'split100'


# Load the samples: the file has no header row, so column names are supplied here.
df = pd.read_csv(input_file_path, names=['sequence', 'score'], header=None)
print(df.head())
print(df.shape)

                                            sequence     score
0  MSDNQAPKLTFVSLGCPKALVDSERILTQLRSEGYDLVAKYDDADV...  0.303411
1  MITIGITGGIGSGKTTVARYFREHGVPVVDADIIAREVVRPGSECL...  1.017966
2  MNTPQIIKNESEAIKIAILSGSFNLTNYGAMKDVFANNAEIIKSIG...  1.883462
3  MANKYIVSWDMLQIHARKLASRLMPSEQWKGIIAVSRGGLVPGALL...  0.114259
4  MKVFLDTANVDEIKKANALGVISGVTTNPSLIAKEGRNFEEVINEI...  0.475771
(5120, 2)


除去不合格的序列

In [92]:
# Drop every sequence whose length is >= MAX_SEQ_LEN.
# The vectorized .str.len() yields NaN for a missing sequence; NaN < MAX_SEQ_LEN
# is False, so such rows are dropped instead of raising inside a per-row len().
MAX_SEQ_LEN = 508
df = df[df['sequence'].str.len() < MAX_SEQ_LEN]
print(df.head())
print(df.shape)


# # Optional: keep only the first 80% of rows after sorting by score as the evaluation set
# df = df.sort_values(by='score', ascending=True) # sort from smallest to largest
# df = df.head(int((df.shape[0])*0.8))        # keep the first 80%
# print(df.head())
# print(df.shape)

                                            sequence     score
0  MSDNQAPKLTFVSLGCPKALVDSERILTQLRSEGYDLVAKYDDADV...  0.303411
1  MITIGITGGIGSGKTTVARYFREHGVPVVDADIIAREVVRPGSECL...  1.017966
2  MNTPQIIKNESEAIKIAILSGSFNLTNYGAMKDVFANNAEIIKSIG...  1.883462
3  MANKYIVSWDMLQIHARKLASRLMPSEQWKGIIAVSRGGLVPGALL...  0.114259
4  MKVFLDTANVDEIKKANALGVISGVTTNPSLIAKEGRNFEEVINEI...  0.475771
(4264, 2)


将txt转化为fasta格式

In [93]:
# Write the sequences in df to the data/inputs folder in FASTA format.
# Records are named sample0, sample1, ... by row position.
output_file_path = f'data/inputs/1_samples_EC_label_{label}_{p}.fasta'
with open(output_file_path, 'w') as fasta_out:
    for idx, seq in enumerate(df['sequence']):
        fasta_out.write(f'>sample{idx}\n')
        fasta_out.write(seq + '\n')

## 2. 将数据输入模型中，进行预测

In [94]:
# The FASTA file to predict is read from the data/inputs folder
# The extracted ESM embeddings are written to the data/esm_data folder
import csv
import subprocess
import os
import torch


def retrive_esm1b_embedding(fasta_name):
    """Run the ESM extraction script on ``data/<fasta_name>.fasta``.

    Args:
        fasta_name: FASTA path relative to ``data/`` and without the
            ``.fasta`` extension, e.g. ``'inputs/1_samples_EC_label_2_0.75'``.

    Returns:
        The exit code of the extraction subprocess (0 on success).
    """
    # Local import so this cell does not depend on an import made elsewhere.
    import sys

    esm_script = "esm/scripts/extract.py"   # script that extracts ESM (Evolutionary Scale Modeling) embeddings
    esm_out = "data/esm_data"               # output directory for the embeddings
    esm_type = "esm1b_t33_650M_UR50S"       # ESM model type passed to the script
    fasta_path = "data/" + fasta_name + ".fasta"
    # sys.executable is the interpreter running this kernel; a bare "python"
    # would resolve through PATH and may pick an environment without esm/torch.
    command = [sys.executable, esm_script, esm_type,
               fasta_path, esm_out, "--include", "mean"]
    result = subprocess.run(command)
    return result.returncode


def prepare_infer_fasta(fasta_name):
    """Extract ESM embeddings for a FASTA file and write its placeholder CSV.

    Runs ``retrive_esm1b_embedding`` and prints whether it succeeded, then
    writes ``data/<fasta_name>.csv`` (tab-separated) with the header
    ``Entry / EC number / Sequence`` and one row per FASTA record. Each row
    holds the record id (header line without the leading '>') and a single
    space in the other two columns.

    Args:
        fasta_name: FASTA path relative to ``data/`` and without extension,
            e.g. ``'inputs/1_samples_EC_label_3.6.4.12_0.5'``.
    """
    returncode = retrive_esm1b_embedding(fasta_name)    # extract the ESM embeddings
    if returncode == 0:
        print("Command ran successfully")
    else:
        print("Command failed with return code:", returncode)

    # Open the FASTA first so a missing input does not leave an empty CSV behind;
    # the context managers close (and flush) both files on exit.
    with open('data/' + fasta_name + '.fasta', 'r') as fastafile, \
            open('data/' + fasta_name + '.csv', 'w', newline='') as csvfile:
        csvwriter = csv.writer(csvfile, delimiter='\t')
        csvwriter.writerow(['Entry', 'EC number', 'Sequence'])  # column names
        for line in fastafile:
            if line.startswith('>'):
                csvwriter.writerow([line.strip()[1:], ' ', ' '])

# Name of the FASTA file (without extension) and its path relative to data/,
# then run embedding extraction and build the placeholder CSV for it.
fasta_data = f'1_samples_EC_label_{label}_{p}'
test_data = f'inputs/{fasta_data}'
prepare_infer_fasta(test_data)

Transferred model to GPU
Read data/inputs/1_samples_EC_label_2_0.75.fasta with 4264 sequences
Processing 1 of 271 batches (60 sequences)
Processing 2 of 271 batches (49 sequences)
Processing 3 of 271 batches (44 sequences)
Processing 4 of 271 batches (44 sequences)
Processing 5 of 271 batches (43 sequences)
Processing 6 of 271 batches (39 sequences)
Processing 7 of 271 batches (35 sequences)
Processing 8 of 271 batches (34 sequences)
Processing 9 of 271 batches (32 sequences)
Processing 10 of 271 batches (32 sequences)
Processing 11 of 271 batches (32 sequences)
Processing 12 of 271 batches (31 sequences)
Processing 13 of 271 batches (30 sequences)
Processing 14 of 271 batches (29 sequences)
Processing 15 of 271 batches (29 sequences)
Processing 16 of 271 batches (29 sequences)
Processing 17 of 271 batches (28 sequences)
Processing 18 of 271 batches (28 sequences)
Processing 19 of 271 batches (28 sequences)
Processing 20 of 271 batches (28 sequences)
Processing 21 of 271 batches (28 se

In [95]:
from src.CLEAN.model import LayerNormNet
from src.CLEAN.utils import * 
from src.CLEAN.evaluate import *

def infer_maxsep(train_data, test_data, report_metrics = False, 
                 pretrained=True, model_name=None, gmm = None):
    """Predict EC numbers for ``test_data`` and write the max-separation results.

    Args:
        train_data: name of the training split; ``data/<train_data>.csv`` is
            read, and with ``pretrained=True`` so is
            ``data/pretrained/<train_data>.pth``.
        test_data: path (relative to ``data/``, without extension) of the CSV
            listing the test entries; also used as the output path under
            ``results/``.
        report_metrics: when True, compute and print precision / recall / F1 /
            AUC against the labels loaded from ``data/<test_data>``.
        pretrained: load weights from ``data/pretrained/`` when True, otherwise
            from ``data/model/<model_name>.pth``.
        model_name: checkpoint name, required when ``pretrained`` is False.
        gmm: forwarded unchanged to ``write_max_sep_choices``.

    Raises:
        TypeError: if ``pretrained`` is False and ``model_name`` is None.
        Exception: if the requested checkpoint file does not exist.
    """
    # Fail fast with a clear message instead of an opaque str + None error
    # raised later while building the checkpoint path.
    if not pretrained and model_name is None:
        raise TypeError("model_name must be provided when pretrained=False")

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    dtype = torch.float32
    id_ec_train, ec_id_dict_train = get_ec_id_dict('data/' + train_data + '.csv')
    id_ec_test, _ = get_ec_id_dict('data/' + test_data + '.csv')
    # load checkpoints
    # NOTE: change this to LayerNormNet(512, 256, device, dtype) 
    # and rebuild with [python build.py install]
    # if inferencing on model trained with supconH loss
    model = LayerNormNet(512, 128, device, dtype)

    if pretrained:
        try:
            checkpoint = torch.load('data/pretrained/'+ train_data +'.pth', map_location=device)
        except FileNotFoundError as error:
            raise Exception('No pretrained weights for this training data') from error
    else:
        try:
            checkpoint = torch.load('data/model/'+ model_name +'.pth', map_location=device)
        except FileNotFoundError as error:
            raise Exception('No model found!') from error

    model.load_state_dict(checkpoint)
    model.eval()
    # Inference only: no autograd graph is needed for the embeddings or distances.
    with torch.no_grad():
        # load precomputed EC cluster center embeddings if possible
        if train_data == "split70":
            emb_train = torch.load('data/pretrained/70.pt', map_location=device)
        elif train_data == "split100":
            emb_train = torch.load('data/pretrained/100.pt', map_location=device)
        else:
            emb_train = model(esm_embedding(ec_id_dict_train, device, dtype))

        emb_test = model_embedding_test(id_ec_test, model, device, dtype)
        eval_dist = get_dist_map_test(emb_train, emb_test, ec_id_dict_train, id_ec_test, device, dtype)
    seed_everything()
    eval_df = pd.DataFrame.from_dict(eval_dist)
    ensure_dirs("results")
    out_filename = "results/" +  test_data
    # test_data may contain a sub-directory (e.g. 'inputs/...'); make sure the
    # directory that will hold the output file exists, not just 'results'.
    os.makedirs(os.path.dirname(out_filename), exist_ok=True)
    write_max_sep_choices(eval_df, out_filename, gmm=gmm)
    if report_metrics:
        pred_label = get_pred_labels(out_filename, pred_type='_maxsep')
        pred_probs = get_pred_probs(out_filename, pred_type='_maxsep')
        true_label, all_label = get_true_labels('data/' + test_data)
        # acc is returned by get_eval_metrics but not reported below.
        pre, rec, f1, roc, acc = get_eval_metrics(
            pred_label, pred_probs, true_label, all_label)
        print("############ EC calling results using maximum separation ############")
        print('-' * 75)
        print(f'>>> total samples: {len(true_label)} | total ec: {len(all_label)} \n'
            f'>>> precision: {pre:.3} | recall: {rec:.3}'
            f'| F1: {f1:.3} | AUC: {roc:.3} ')
        print('-' * 75)


# Predict EC numbers for the generated sequences with the pretrained weights.
infer_maxsep(
    train_data,
    test_data,
    report_metrics=False,
    pretrained=True,
    gmm='data/pretrained/gmm_ensumble.pkl',
)
# Remove the dummy CSV file now that inference has finished.
os.remove(f"data/{test_data}.csv")

The embedding sizes for train and test: torch.Size([241025, 128]) torch.Size([4264, 128])


100%|██████████| 5242/5242 [00:00<00:00, 39568.12it/s]


Calculating eval distance map, between 4264 test ids and 5242 train EC cluster centers


4264it [00:04, 1037.69it/s]


## 3. 对生成的结果文件进行分析

In [96]:
# Read the prediction CSV produced above. Only columns 0 and 1 are read
# (sample id and first prediction), intended to guard against irregular rows.
df = pd.read_csv('results/'+test_data+'_maxsep.csv', header=None, names=['sequence', 'EC_number'], usecols=[0, 1])
print(df.head())


def _truncate_ec(prediction, depth):
    """Return the first `depth` EC levels of a prediction such as 'EC:2.7.1.24/0.9974'.

    The part after '/' and a leading 'EC:' prefix are removed before the EC
    number is cut, so the result does not depend on how many digits the
    trailing value has.
    """
    ec_number = prediction.split('/')[0]
    if ec_number.startswith('EC:'):
        ec_number = ec_number[3:]
    return '.'.join(ec_number.split('.')[:depth])


# Reduce every prediction to a plain EC string with as many levels as `label`.
# Labels with 2, 3 or 4 levels use that depth; anything else uses the first level only.
label_split = label.split('.')
label_depth = len(label_split) if len(label_split) in (2, 3, 4) else 1
df['EC_number'] = df['EC_number'].apply(lambda x: _truncate_ec(x, label_depth))
print(df.head())

  sequence           EC_number
0  sample0   EC:2.8.4.4/0.9980
1  sample1  EC:2.7.1.24/0.9974
2  sample2  EC:6.2.1.14/0.0000
3  sample3  EC:2.4.2.22/0.9953
4  sample4   EC:2.2.1.2/0.9886
  sequence EC_number
0  sample0         2
1  sample1         2
2  sample2         6
3  sample3         2
4  sample4         2


In [97]:
# Show how many sequences were assigned to each predicted EC number.
ec_counts = df['EC_number'].value_counts()
print(ec_counts)

EC_number
2    3634
6     215
1     146
3     140
4      84
5      37
7       8
Name: count, dtype: int64


## 4. 输出最终准确率

In [98]:
# Final generation accuracy: the fraction of rows whose predicted EC number equals `label`.
n_matching = int((df['EC_number'] == label).sum())
print(n_matching / df.shape[0])

0.8522514071294559
