In [1]:
from dnallm import load_config
from dnallm import load_model_and_tokenizer, DNAPredictor
from dnallm import Benchmark

In [2]:
# 读取配置文件
configs = load_config("./inference_config.yaml")

### 模型推理

In [3]:
# 读取模型和分词器
model_name = "zhangtaolab/plant-dnagpt-BPE"
# from Hugging Face
# model, tokenizer = load_model_and_tokenizer(model_name, task_config=configs['task'], source="huggingface")
# from ModelScope
model, tokenizer = load_model_and_tokenizer(model_name, task_config=configs['task'], source="modelscope")

Downloading Model from https://www.modelscope.cn to directory: /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/plant-dnagpt-BPE
Model files are stored in /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/plant-dnagpt-BPE
Downloading Model from https://www.modelscope.cn to directory: /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/plant-dnagpt-BPE
Downloading Model from https://www.modelscope.cn to directory: /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/plant-dnagpt-BPE


Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/plant-dnagpt-BPE and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
# 创建预测器
predictor = DNAPredictor(
    model=model,
    tokenizer=tokenizer,
    config=configs
)

Use device: mps


In [5]:
# 检测是否使用GPU
print(model.device)

mps:0


In [6]:
# 输入序列进行预测
seqs = ["AAGCAAAGCTAATTATGGGTCAAAAGACTCAAAGGCCAGAATTGACGCAGCCGTTTATGAGAAGTGAGAACACAATTTCGGAGTCACTTCCTTTATTTCCTCTCTTCCTTCACTCTCTCCTATATAAACCTTCCTCTCCTCTTCCTCTCTTCTCATCTCTTCAAACCATT",
        "TGCGGGTGCTTGTCTCCGAGGCCATCGACGAGCGAGTGGCGGAGGGTGAGGAAGGGGATGGCGCGATGAGGCTATTCGTGGGCCTCCCGTGGACGCGGTGGACTCTGGCAGCTAGCATCACCCTTCCTCCTTCCTGTTGGATTGGTTTCGCTTGCACTCACCAGGACACG"]
results = predictor.predict_seqs(seqs)
print(results)

Encoding inputs:   0%|          | 0/2 [00:00<?, ? examples/s]

Predicting: 100%|██████████| 1/1 [00:18<00:00, 18.41s/it]

{0: {'sequence': 'AAGCAAAGCTAATTATGGGTCAAAAGACTCAAAGGCCAGAATTGACGCAGCCGTTTATGAGAAGTGAGAACACAATTTCGGAGTCACTTCCTTTATTTCCTCTCTTCCTTCACTCTCTCCTATATAAACCTTCCTCTCCTCTTCCTCTCTTCTCATCTCTTCAAACCATT', 'label': 'Not promoter', 'scores': {'Not promoter': 0.6557487845420837, 'Core promoter': 0.34425121545791626}}, 1: {'sequence': 'TGCGGGTGCTTGTCTCCGAGGCCATCGACGAGCGAGTGGCGGAGGGTGAGGAAGGGGATGGCGCGATGAGGCTATTCGTGGGCCTCCCGTGGACGCGGTGGACTCTGGCAGCTAGCATCACCCTTCCTCCTTCCTGTTGGATTGGTTTCGCTTGCACTCACCAGGACACG', 'label': 'Not promoter', 'scores': {'Not promoter': 0.5779901742935181, 'Core promoter': 0.42200979590415955}}}





In [7]:
# 读取文件进行预测
seq_file = './test.csv'
results, metrics = predictor.predict_file(seq_file, label_col='label', evaluate=True)
print(metrics)

Format labels:   0%|          | 0/500 [00:00<?, ? examples/s]

Encoding inputs:   0%|          | 0/500 [00:00<?, ? examples/s]

Predicting: 100%|██████████| 32/32 [00:41<00:00,  1.29s/it]

{'accuracy': 0.516, 'precision': 0.5304347826086957, 'recall': 0.4765625, 'f1': 0.5020576131687243, 'mcc': 0.034038872412899164, 'AUROC': 0.5080686475409836, 'AUPRC': 0.4901074319158247, 'TPR': 0.4765625, 'TNR': 0.5573770491803278, 'FPR': 0.4426229508196721, 'FNR': 0.5234375}





### 模型基准测试

In [8]:
# 初始化基准测试
benchmark = Benchmark(config=configs)

In [9]:
# 获取数据集
dataset = benchmark.get_dataset("./test.csv", seq_col="sequence", label_col="label")

In [10]:
# 指定模型
model_names = {
    "Plant DNABERT": "zhangtaolab/plant-dnabert-BPE-promoter",
    "Plant DNAGPT": "zhangtaolab/plant-dnagpt-BPE-promoter",
    "Nucleotide Transformer": "zhangtaolab/nucleotide-transformer-v2-100m-promoter",
}

In [11]:
# 运行基准测试
metrics = benchmark.run(model_names, source="modelscope")

Dataset name: custom
Model name: Plant DNABERT
Downloading Model from https://www.modelscope.cn to directory: /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/plant-dnabert-BPE-promoter
Model files are stored in /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/plant-dnabert-BPE-promoter
Downloading Model from https://www.modelscope.cn to directory: /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/plant-dnabert-BPE-promoter
Downloading Model from https://www.modelscope.cn to directory: /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/plant-dnabert-BPE-promoter


Encoding inputs:   0%|          | 0/500 [00:00<?, ? examples/s]

Use device: mps


Predicting: 100%|██████████| 32/32 [00:39<00:00,  1.25s/it]


Model name: Plant DNAGPT
Downloading Model from https://www.modelscope.cn to directory: /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/plant-dnagpt-BPE-promoter
Model files are stored in /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/plant-dnagpt-BPE-promoter
Downloading Model from https://www.modelscope.cn to directory: /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/plant-dnagpt-BPE-promoter
Downloading Model from https://www.modelscope.cn to directory: /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/plant-dnagpt-BPE-promoter


Encoding inputs:   0%|          | 0/500 [00:00<?, ? examples/s]

Use device: mps


Predicting: 100%|██████████| 32/32 [00:41<00:00,  1.31s/it]


Model name: Nucleotide Transformer
Downloading Model from https://www.modelscope.cn to directory: /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/nucleotide-transformer-v2-100m-promoter
Model files are stored in /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/nucleotide-transformer-v2-100m-promoter
Downloading Model from https://www.modelscope.cn to directory: /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/nucleotide-transformer-v2-100m-promoter
Downloading Model from https://www.modelscope.cn to directory: /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/nucleotide-transformer-v2-100m-promoter


Encoding inputs:   0%|          | 0/500 [00:00<?, ? examples/s]

Use device: mps


Predicting: 100%|██████████| 32/32 [00:34<00:00,  1.08s/it]


In [12]:
# 画图（pbar：各种得分柱状图；pline：ROC曲线）
pbar, pline = benchmark.plot(metrics, save_path='plot.pdf')

Metrics bar charts saved to plot_metrics.pdf
ROC curves saved to plot_roc.pdf


In [13]:
# 在Notebook中展示图
pbar

In [14]:
pline