In [1]:
import torch
import fm
import pandas as pd
import joblib
# Load the pretrained RNA-FM model and the converter that turns
# (name, sequence) pairs into token tensors.
rna_fm_model, alphabet = fm.pretrained.rna_fm_t12()
batch_converter = alphabet.get_batch_converter()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
rna_fm_model.to(device)
rna_fm_model.eval()  # disables dropout for deterministic results

# Chunking parameters.
# NOTE(review): 1022 presumably leaves room for one special token on each side
# of a fragment within a 1024-position model input -- confirm against RNA-FM docs.
FRAGMENT_LEN = 1022  # nucleotides sent to the model per forward pass
END_LEN = 2999  # nucleotides kept from each end of an over-long sequence
MAX_FULL_LEN = 2 * END_LEN  # sequences longer than this (5998) are truncated to both ends
REPR_LAYER = 12  # layer whose representations are extracted

file_path = "dataset/mRNA/mRNA_sublocation_TrainingSet.tsv"
df = pd.read_csv(
    file_path,
    sep="\t",
    header=None,
    names=["ensembl_transcript_id", "name", "cdna", "tag"],
    skiprows=1,  # the file's own header row is replaced by `names`
)

# Fail fast, before any model work: a missing (NaN) or empty sequence would
# otherwise crash mid-run at len() or at torch.cat() on an empty list.
invalid_rows = [
    idx
    for idx, seq in zip(df.index, df["cdna"])
    if not (isinstance(seq, str) and seq)
]
if invalid_rows:
    raise ValueError(
        f"{file_path}: missing or empty cDNA sequence in row(s) {invalid_rows}"
    )

embeddings_list = []
tags_list = []
dtype = torch.bfloat16  # not used in this cell; kept in case other cells read it

for index, row in df.iterrows():
    cDNA = row["cdna"]

    # Over-long sequences are reduced to their first and last END_LEN
    # nucleotides; everything else is used in full.
    if len(cDNA) > MAX_FULL_LEN:
        parts = [cDNA[:END_LEN], cDNA[-END_LEN:]]
    else:
        parts = [cDNA]

    # Cut each part into consecutive, non-overlapping fragments of at most
    # FRAGMENT_LEN nucleotides (the last fragment of a part may be shorter).
    fragments = [
        part[start : start + FRAGMENT_LEN]
        for part in parts
        for start in range(0, len(part), FRAGMENT_LEN)
    ]

    fragment_embeddings = []
    last_index_f = len(fragments) - 1

    for index_f, fragment in enumerate(fragments):
        # The converter takes a list of (label, sequence) tuples; only the
        # token tensor is needed here.
        _, _, batch_tokens = batch_converter([(row["name"], fragment)])
        batch_tokens = batch_tokens.to(device)

        with torch.no_grad():
            results = rna_fm_model(batch_tokens, repr_layers=[REPR_LAYER])
        fragment_embedding = results["representations"][REPR_LAYER].squeeze(0)

        # Assumes the converter adds one special token at each end of the
        # fragment (TODO confirm). Keep the leading one only on the first
        # fragment and the trailing one only on the last fragment, so every
        # stitched sequence carries exactly one of each. The two decisions are
        # independent: previously a single-fragment sequence matched only the
        # "first" case and lost its trailing position, unlike longer sequences.
        keep_from = 0 if index_f == 0 else 1
        keep_to = None if index_f == last_index_f else -1
        fragment_embeddings.append(fragment_embedding[keep_from:keep_to])

    # Stitch the per-fragment embeddings along the position axis and move the
    # result off the accelerator before storing it.
    concatenated_embeddings = torch.cat(fragment_embeddings, dim=0)
    embeddings_list.append(concatenated_embeddings.cpu())
    tags_list.append(row["tag"])

# Persist embeddings and labels as two parallel lists (same order as `df`).
with open("dataset/mRNA/mRNA_train_embeddings_list.pkl", "wb") as f:
    joblib.dump(embeddings_list, f)

with open("dataset/mRNA/mRNA_train_tags_list.pkl", "wb") as f:
    joblib.dump(tags_list, f)

In [None]:
import torch
import fm
import pandas as pd
import joblib
# Load the pretrained RNA-FM model and the converter that turns
# (name, sequence) pairs into token tensors.
rna_fm_model, alphabet = fm.pretrained.rna_fm_t12()
batch_converter = alphabet.get_batch_converter()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
rna_fm_model.to(device)
rna_fm_model.eval()  # disables dropout for deterministic results

# Chunking parameters.
# NOTE(review): 1022 presumably leaves room for one special token on each side
# of a fragment within a 1024-position model input -- confirm against RNA-FM docs.
FRAGMENT_LEN = 1022  # nucleotides sent to the model per forward pass
END_LEN = 2999  # nucleotides kept from each end of an over-long sequence
MAX_FULL_LEN = 2 * END_LEN  # sequences longer than this (5998) are truncated to both ends
REPR_LAYER = 12  # layer whose representations are extracted

file_path = "dataset/mRNA/mRNA_sublocation_TestSet.tsv"
df = pd.read_csv(
    file_path,
    sep="\t",
    header=None,
    names=["ensembl_transcript_id", "name", "cdna", "tag"],
    skiprows=1,  # the file's own header row is replaced by `names`
)

# Fail fast, before any model work: a missing (NaN) or empty sequence would
# otherwise crash mid-run at len() or at torch.cat() on an empty list.
invalid_rows = [
    idx
    for idx, seq in zip(df.index, df["cdna"])
    if not (isinstance(seq, str) and seq)
]
if invalid_rows:
    raise ValueError(
        f"{file_path}: missing or empty cDNA sequence in row(s) {invalid_rows}"
    )

embeddings_list = []
tags_list = []
dtype = torch.bfloat16  # not used in this cell; kept in case other cells read it

for index, row in df.iterrows():
    cDNA = row["cdna"]

    # Over-long sequences are reduced to their first and last END_LEN
    # nucleotides; everything else is used in full.
    if len(cDNA) > MAX_FULL_LEN:
        parts = [cDNA[:END_LEN], cDNA[-END_LEN:]]
    else:
        parts = [cDNA]

    # Cut each part into consecutive, non-overlapping fragments of at most
    # FRAGMENT_LEN nucleotides (the last fragment of a part may be shorter).
    fragments = [
        part[start : start + FRAGMENT_LEN]
        for part in parts
        for start in range(0, len(part), FRAGMENT_LEN)
    ]

    fragment_embeddings = []
    last_index_f = len(fragments) - 1

    for index_f, fragment in enumerate(fragments):
        # The converter takes a list of (label, sequence) tuples; only the
        # token tensor is needed here.
        _, _, batch_tokens = batch_converter([(row["name"], fragment)])
        batch_tokens = batch_tokens.to(device)

        with torch.no_grad():
            results = rna_fm_model(batch_tokens, repr_layers=[REPR_LAYER])
        fragment_embedding = results["representations"][REPR_LAYER].squeeze(0)

        # Assumes the converter adds one special token at each end of the
        # fragment (TODO confirm). Keep the leading one only on the first
        # fragment and the trailing one only on the last fragment, so every
        # stitched sequence carries exactly one of each. The two decisions are
        # independent: previously a single-fragment sequence matched only the
        # "first" case and lost its trailing position, unlike longer sequences.
        keep_from = 0 if index_f == 0 else 1
        keep_to = None if index_f == last_index_f else -1
        fragment_embeddings.append(fragment_embedding[keep_from:keep_to])

    # Stitch the per-fragment embeddings along the position axis and move the
    # result off the accelerator before storing it.
    concatenated_embeddings = torch.cat(fragment_embeddings, dim=0)
    embeddings_list.append(concatenated_embeddings.cpu())
    tags_list.append(row["tag"])

# Persist embeddings and labels as two parallel lists (same order as `df`).
with open("dataset/mRNA/mRNA_test_embeddings_list.pkl", "wb") as f:
    joblib.dump(embeddings_list, f)

with open("dataset/mRNA/mRNA_test_tags_list.pkl", "wb") as f:
    joblib.dump(tags_list, f)