### Preparation

In [None]:
# !wget -c https://ftp.ensemblgenomes.ebi.ac.uk/pub/plants/release-62/fasta/arabidopsis_thaliana/cds/Arabidopsis_thaliana.TAIR10.cds.all.fa.gz

# !pip install pyfastx

In [2]:
from pyfastx import Fasta

genome = Fasta("Arabidopsis_thaliana.TAIR10.cds.all.fa.gz")
with open("ath_cds.csv", "w") as f:
    print("seq_id,sequence", file=f)
    for seq in genome:
        print(f"{seq.name},{seq.seq}", file=f)

In [3]:
import copy
from dnallm import load_config, load_model_and_tokenizer, DNADataset, DNATrainer

In [4]:
# Load the datasets
data_path = "ath_cds.csv"
datasets = DNADataset.load_local_data(data_path, seq_col="sequence", sep=",")

# Sampling the datasets
datasets.sampling(0.1, seed=42, overwrite=True)
datasets.split_data(seed=42)

Generating train split: 0 examples [00:00, ? examples/s]

In [5]:
seq = datasets.dataset["test"][10]["sequence"]
prompt = seq[:10]
print("Length:", len(seq))
print("Prompt sequence:", prompt)
print("Full sequence:  ", seq)

Length: 207
Prompt sequence: ATGACTTGCA
Full sequence:   ATGACTTGCACGACAGAGATAGATATTTTGAAGTGGACAGTGAGGTATTGTTCGAGTTTAGCTGCACACCTTCTAACTCCTACGAGATTGTTCAAATATGAAATTCAACAACAGAGCGATTTGAGAAATGCAACTGAAAACAAAACTGAAAAATATATTTCTGACGACGTCGGTCATTGTAGACATACATACATGCAAATCAGATAA


### DNAGPT

In [6]:
# Load the config file
configs = load_config("./finetune_config.yaml")
configs["finetune"].output_dir = "./outputs_dnagpt"

In [7]:
# Load the model and tokenizer
model_name = "zhangtaolab/plant-dnagpt-singlebase"
# from Hugging Face
# model, tokenizer = load_model_and_tokenizer(model_name, task_config=configs['task'], source="huggingface")
# from ModelScope
model, tokenizer = load_model_and_tokenizer(model_name, task_config=configs['task'], source="modelscope")
tokenizer.model_max_length = 2048

Downloading Model from https://www.modelscope.cn to directory: /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/plant-dnagpt-singlebase


2025-12-28 17:50:16,908 - modelscope - INFO - Got 1 files, start to download ...


Processing 1 items:   0%|          | 0.00/1.00 [00:00<?, ?it/s]

Downloading [model.safetensors]:   0%|          | 0.00/328M [00:00<?, ?B/s]

2025-12-28 17:50:28,037 - modelscope - INFO - Download model 'zhangtaolab/plant-dnagpt-singlebase' successfully.


17:50:28 - dnallm.utils.support - INFO - Model files are stored in /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/plant-dnagpt-singlebase


In [8]:
# Encode the datasets
data = copy.deepcopy(datasets)
data.encode_sequences(tokenizer=tokenizer)

Encoding inputs:   0%|          | 0/3382 [00:00<?, ? examples/s]

Encoding inputs:   0%|          | 0/966 [00:00<?, ? examples/s]

Encoding inputs:   0%|          | 0/484 [00:00<?, ? examples/s]

In [9]:
# Initialize the trainer
trainer = DNATrainer(
    model=model,
    config=configs,
    datasets=data
)

In [10]:
# Start training
metrics = trainer.train()
print(metrics)

`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


Step,Training Loss,Validation Loss
200,1.2892,1.278937
400,1.2723,1.276405


There were missing keys in the checkpoint model loaded: ['lm_head.weight'].


{'train_runtime': 361.7112, 'train_samples_per_second': 18.7, 'train_steps_per_second': 1.172, 'total_flos': 1767379304448000.0, 'train_loss': 1.2800271870954982, 'epoch': 2.0}


In [11]:
model.eval()

tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_length=len(seq)+5, num_return_sequences=5, do_sample=True, top_k=50, top_p=0.95, temperature=1.0)

Setting `pad_token_id` to `eos_token_id`:9 for open-end generation.


In [12]:
print("Prompt:               ", prompt)
for i, out in enumerate(outputs):
    out_seq = tokenizer.decode(out, skip_special_tokens=True)
    print(f"Generated sequence {i}: ", out_seq.replace(" ", ""))
print("Raw sequence:         ", seq)

Prompt:                ATGACTTGCA
Generated sequence 0:  ATGACTTGCATGTGGTCGTTGGAATAGGAGGTCACTATGTGGTTTTGAACCCAAGATTCTCATTTGATGGCTTATCAATCTCCATGCTTCTAACATTAGGTTTTCTCTCGTTCTTCTTCTTCTTTTCCTCTTATGGCGGCGGCGGCTCTCCAGTAGCTTGTGTGGAATCTGGAAAGGCATATTGTAGACCAAGGAATCTATCTCCAGCTA
Generated sequence 1:  ATGACTTGCAGATATACGTCACGAAGAAACCAAAATTCGACCTGACAGAAGGGAATCAAGCTGGTGAGGTTGAAGAACTCGCTATCTTCAGGTCTAACAGTATACTCCTCCAAAGGAAAGAAACGCTCTTCTTCCACCCTTTCTCCGTCGATGATGGTGTCGTTGAGAAGGAAATCAGAGCAGTTAAAGAGGTTAGACCGAGGTTAGCGT
Generated sequence 2:  ATGACTTGCATTCTCTTTTCGAAGCTGTGTTTTATCTTCAGATCTCAAGATTTCGATGTGGGATTTTGAACCAGCTGATTGATAAAGCTGGTTCAGAGTCTGGTTCTGGCCAACAAGAGGATTCAGCTTTGATGATTTTGGGAGCAGATTGCTCTACCTCAAGAGTATGGTTACATCGGCTTTCCATGATTGATGTAAAAGTTCTTGACA
Generated sequence 3:  ATGACTTGCAATTACGGTCTGCGAGGACACACTCCTCGAAGCTCGTCTCCGACCCTAACGTCGAGTCCAACTCCGAGTACAACCCCTTCGATTTGCCCAAAAGCCGGATGGATGCTTTTGGAGCAATCAAAGGAGAGCAGACAGAGGCGTTACACCTAACACCCTACCTTACAGGCCAGTTCTGCGCCGAGCTTCAAGAATTGAAAAAAA
Generated sequence 4:  ATGACTT

### MegaDNA

In [13]:
# Load the config file
configs = load_config("./finetune_config.yaml")
configs["task"].task_type = "embedding"
configs["finetune"].output_dir = "./outputs_megadna"

In [None]:
# !git clone https://github.com/lingxusb/megaDNA.git
# !cd megaDNA
# !pip install .

In [15]:
# Load the model and tokenizer
model_name = "lingxusb/megaDNA_updated"
# from Hugging Face
model, tokenizer = load_model_and_tokenizer(model_name, task_config=configs['task'], source="huggingface")
# from ModelScope
# model, tokenizer = load_model_and_tokenizer(model_name, task_config=configs['task'], source="modelscope")
tokenizer.model_max_length = 2048



Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

17:59:18 - dnallm.utils.support - INFO - Model files are stored in /Users/forrest/.cache/huggingface/hub/models--lingxusb--megaDNA_updated/snapshots/ed298be539e1667b52a1181a6472528a34dd2ef9


In [16]:
# Encode the datasets
data = copy.deepcopy(datasets)
data.encode_sequences(tokenizer=tokenizer)

Encoding inputs:   0%|          | 0/3382 [00:00<?, ? examples/s]

Encoding inputs:   0%|          | 0/966 [00:00<?, ? examples/s]

Encoding inputs:   0%|          | 0/484 [00:00<?, ? examples/s]

In [17]:
# Specific processing for MEGA-DNA
data.dataset = data.dataset.remove_columns(["seq_id", "sequence", "token_type_ids", "attention_mask"])
data.dataset = data.dataset.rename_column("input_ids", "ids")
data.dataset

DatasetDict({
    train: Dataset({
        features: ['ids'],
        num_rows: 3382
    })
    test: Dataset({
        features: ['ids'],
        num_rows: 966
    })
    val: Dataset({
        features: ['ids'],
        num_rows: 484
    })
})

In [18]:
# Initialize the trainer
trainer = DNATrainer(
    model=model,
    config=configs,
    datasets=data
)

In [19]:
# Define a custom trainer for MEGA-DNA
class MegaDNATrainer(type(trainer.trainer)):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        loss = model(**inputs, return_value = "loss")
        if return_outputs:
            logits = model(**inputs, return_value = "logits")
            return (loss, logits)
        
        return loss

trainer.customize_trainer(MegaDNATrainer)
trainer.trainer.can_return_loss = True

In [20]:
# Start training
metrics = trainer.train()
print(metrics)

  self.gen = func(*args, **kwds)


Step,Training Loss,Validation Loss
200,1.3221,1.306477
400,1.305,1.30198


  self.gen = func(*args, **kwds)
  self.gen = func(*args, **kwds)


AttributeError: 'MEGADNA' object has no attribute 'save_pretrained'

In [None]:
model.eval()

inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = [model.generate(inputs["input_ids"], seq_len=len(seq)+5, temperature=0.95, filter_thres=0.0) for _ in range(5)]

  0%|          | 0/202 [00:00<?, ?it/s]

  self.gen = func(*args, **kwds)


  0%|          | 0/202 [00:00<?, ?it/s]

  0%|          | 0/202 [00:00<?, ?it/s]

  0%|          | 0/202 [00:00<?, ?it/s]

  0%|          | 0/202 [00:00<?, ?it/s]

In [None]:
print("Prompt:               ", prompt)
for i, out in enumerate(outputs):
    out_seq = tokenizer.decode(out[0], skip_special_tokens=True)
    print(f"Generated sequence {i}: ", out_seq.replace(" ", ""))
print("Raw sequence:         ", seq)

Prompt:                ATGACTTGCA
Generated sequence 0:  ATGACTTGCATGGCATCGAGCAATCACGAGTGCTCGAGTAGTTGGTGGCAGTCAGCCCATAGTGGATGCTCCACTAGTCTTGGGTTGACCTCCTCTGATTGGAAGTCTATGATTGTTGGACCATCCCCGTTTGGATCCCCATCTCTGGCTGGCTTTAGTACTAACTGGATCACTAGGACTCCTAATCATTCATCAGGTCTCGGGACCTGTGC
Generated sequence 1:  ATGACTTGCAAAAGGAGAGTATTTCTTGGCTGCCTCTCTGCCGAACCAAACATTCAAGAACCTCCCGAAATTGCTCGTGAAACTGTAACGCTCGGTATCAAAAACCCGAAATCAAGAAGGGAATATCTTACTCTCTACAAAAAACGAAGGGGAAAGATCTTTGTTCATCCGAGCGCTGATGTGCACATTATGGAACTCGAGATGGGTTTTCA
Generated sequence 2:  ATGACTTGCATGTTCTTCCATTCTTTCTCCTCACCTTGTCTTATCCGTAGCCCCCTGCTGCTTCAGGACTTTCGGTCTCTCCTGCTCTTTCTCCTGCTCCTGCTCTCTCTCACCGGGGATCTTCCCACATTTCTGACGCTGCCAGAAGTGGTGAAGCTGCTGGGCTTCCTCCCCTTCGTGGAGTTTCCTTTCTCCGCAGCCCGCCCATGTTG
Generated sequence 3:  ATGACTTGCATTTCCAGAGAAGACGAAATGCAAGCAATCCTCCACGAAGAGCGGGAAGAGATCAACGAGCTTCGCATTGAAGATGAAGAAGATGAAGGTGAACATGTTACCTCTTACAAGAAGAATGAATCGCTCACCACTCATGATGATCTGCTGGATATCGTTCTTGATGAGCTCAAGAAAGAGCGGATTGGTAATGAAGAAGCTGAGAT
Generated sequence 4: 