In [None]:
from dnallm import load_config
from dnallm import load_model_and_tokenizer, DNAInference
from dnallm import Benchmark

In [None]:
# Read the YAML configuration; the same `configs` object drives model loading,
# the inference helper and the benchmark below.
config_path = "./inference_config.yaml"
configs = load_config(config_path)

### Model inference

In [None]:
# Load the model and its tokenizer for the task described in configs['task'].
# Use the promoter fine-tuned checkpoint, the same one evaluated in the benchmark
# section, so that predictions/metrics on the labelled ./test.csv are meaningful.
# The base "zhangtaolab/plant-dnagpt-BPE" checkpoint is presumably not trained
# for this classification task — NOTE(review): confirm against configs['task'].
model_name = "zhangtaolab/plant-dnagpt-BPE-promoter"
# Model hub to download from: "modelscope" or "huggingface".
model_source = "modelscope"
model, tokenizer = load_model_and_tokenizer(
    model_name, task_config=configs['task'], source=model_source
)

In [None]:
# Wrap the loaded model and tokenizer in an inference helper configured by `configs`.
predictor = DNAInference(model=model, tokenizer=tokenizer, config=configs)

In [None]:
# Report which device the model is on.
device = model.device
print(device)

In [None]:
# Predict a couple of example sequences passed directly as Python strings.
seqs = [
    "GCACTTTACTTAAAGTAAAAAGAAAAAAACTGTGCGCTCTCCAACTACCGCAGCAACGTGTCGAGCACAGGAACACGTGTCACTTCAGTTCTTCCAATTGCTGGGGCCCACCACTGTTTACTTCTGTACAGGCAGGTGGCCATGCTGATGACACTCCACACTCCTCGACTTTCGTAGCAGCAAGCCACGCGTGACCGAGAAGCCTCGCG",
    "TTGTCATCACATTTGATCAACTACGATTTATGTTGTACTATTCATCTGTTTTCTCCTTTTTTTTTCCCTTATTGACAGGTTGTGGAGGTTCACAACGAACAGAATACAAGAAATTTTGGTAATCATTTGAGGACTTTCATGGGGTATGAATTGTGTGCTATAATAAATTAA",
]
# Distinct name so these results are not confused with the file-based predictions.
seq_results = predictor.predict_seqs(seqs)
# Leave the value as the last expression so Jupyter renders it (rich display).
seq_results

In [None]:
# Predict every sequence in a labelled CSV file and evaluate against its 'label' column.
# Distinct names (file_results / file_metrics) keep this single-model evaluation from
# colliding with the benchmark `metrics` computed further down, which the plotting
# cell consumes.
seq_file = './test.csv'
file_results, file_metrics = predictor.predict_file(seq_file, label_col='label', evaluate=True)
# Leave the value as the last expression so Jupyter renders it (rich display).
file_metrics

### Model benchmark

In [None]:
# Initialize the benchmark runner from the same configuration used for inference.
benchmark = Benchmark(config=configs)

In [None]:
# Load the test CSV for benchmarking: sequences come from the "sequence"
# column and labels from the "label" column.
dataset = benchmark.get_dataset(
    "./test.csv",
    seq_col="sequence",
    label_col="label",
)

In [None]:
# Models to compare: display label -> model id on the hub (insertion order is kept).
model_names = dict([
    ("Plant DNABERT", "zhangtaolab/plant-dnabert-BPE-promoter"),
    ("Plant DNAGPT", "zhangtaolab/plant-dnagpt-BPE-promoter"),
    ("Nucleotide Transformer", "zhangtaolab/nucleotide-transformer-v2-100m-promoter"),
])

In [None]:
# Benchmark every model in `model_names`, fetching checkpoints from ModelScope.
# The result is consumed by the plotting cell below.
metrics = benchmark.run(
    model_names,
    source="modelscope",
)

In [None]:
# Plot the benchmark results (save_path='plot.pdf' is passed to benchmark.plot).
#   pbar:  bar chart of all the scores
#   pline: ROC curves
pbar, pline = benchmark.plot(
    metrics,
    save_path='plot.pdf',
)

In [None]:
# Show the bar chart of all benchmark scores (rendered because it is the last expression).
pbar

In [None]:
# Show the ROC curve plot returned by benchmark.plot.
pline