In [1]:
# Building the NER dataset needs two Python packages (pyfastx and pybedtools).
# pybedtools also requires the bedtools command-line tool to be installed on your system:
#   conda install -c bioconda bedtools
# or, on macOS, with Homebrew:
#   brew install bedtools

# Skip this step if both Python packages are already installed
# !uv pip install pyfastx pybedtools

In [2]:
import gzip
import random
import numpy as np
from pyfastx import Fasta
from pybedtools import BedTool
from tqdm import tqdm
from collections import defaultdict
import pickle

from dnallm import load_config, load_model_and_tokenizer

In [3]:
# Seed Python's `random` module so that the random flank lengths drawn in the next cell
# (and the gene sub-sampling done with `random.sample` in `tokenization`) are repeatable.
random.seed(42)

In [4]:
# Bounds (inclusive) for the random number of bases added on each side of a gene
min_ext, max_ext = 50, 100

# Pre-draw 60000 [left, right] flank lengths; `tokenization` picks one pair per gene by index.
# Within each pair the left value is drawn before the right one.
ext_list = [
    [random.randint(min_ext, max_ext) for _side in ("left", "right")]
    for _pair in range(60000)
]

In [5]:
# Named Entity Recognition (NER) tags and their integer ids.
# Common NER schemes are IO, IOB, IOE, IOBES, BI, IE and BIES; this notebook uses IOB:
# the first token of a feature gets a B- tag, the following tokens of the same feature get I- tags,
# and everything outside a feature is O.
# Example (one digit per position, digits are the ids in `tags_id`):
# ..........[ exon1 ]-----[ exon2 ]-------[ exon3 ]........
# 000000000012222222234444122222222344444412222222200000000

# Keys are "<feature type><0|1>": suffix 0 = first token of the feature (B-), suffix 1 = inside (I-).
named_entities = dict(
    intergenic="O",
    exon0="B-EXON",
    exon1="I-EXON",
    intron0="B-INTRON",
    intron1="I-INTRON",
)

# Tag -> integer label, numbered in the order listed here.
tags_id = {
    tag: tag_index
    for tag_index, tag in enumerate(("O", "B-EXON", "I-EXON", "B-INTRON", "I-INTRON"))
}

In [6]:
def get_gene_annotation(gene_anno):
    """
    Derive an ordered exon/intron feature list for every gene.

    For each gene the isoform with the largest summed segment length is taken as the
    representative one (on a tie, the isoform listed last wins). The gene is annotated only
    when that isoform spans exactly the gene's own start and end; otherwise, or when the
    gene has no isoform, its feature list stays empty.

    Parameters
    ----------
    gene_anno : dict
        gene -> {"chrom", "start", "end", "strand", "isoform"}, where "isoform" maps an
        isoform name to a list of [feature_type, start, end] segments.

    Returns
    -------
    dict
        gene -> list of [chrom, start, end, feature_type, strand]. Segments are visited in
        ascending start order on the "+" strand and descending start order otherwise; an
        "intron" entry is inserted before a segment whenever there is a positive-length gap
        between it and the previously visited segment.
    """
    annotations = {}
    for gene_id, record in gene_anno.items():
        features = annotations[gene_id] = []
        chrom = record["chrom"]
        gene_start = record["start"]
        gene_end = record["end"]
        strand = record["strand"]
        isoforms = record["isoform"]
        if not isoforms:
            continue

        # Representative isoform = the one covering the most bases
        # (a stable sort followed by "take the last" keeps the last isoform on ties).
        covered = [
            (iso_name, sum(seg[2] - seg[1] for seg in isoforms[iso_name]))
            for iso_name in isoforms
        ]
        representative = sorted(covered, key=lambda pair: pair[1])[-1][0]
        segments = isoforms[representative]

        span_start = min(seg[1] for seg in segments)
        span_end = max(seg[2] for seg in segments)
        if span_start != gene_start or span_end != gene_end:
            # The representative isoform does not cover the whole gene: leave it unannotated.
            continue

        forward = strand == "+"
        # Edge of the previously visited segment that faces the next one; 0 means "none yet".
        boundary = 0
        for seg in sorted(segments, key=lambda s: s[1], reverse=not forward):
            if forward:
                gap_start, gap_end, next_boundary = boundary, seg[1], seg[2]
            else:
                gap_start, gap_end, next_boundary = seg[2], boundary, seg[1]
            # Only a strictly positive gap between two segments is reported as an intron.
            if boundary and gap_start < gap_end:
                features.append([chrom, gap_start, gap_end, "intron", strand])
            boundary = next_boundary
            features.append([chrom, seg[1], seg[2], seg[0], strand])

    return annotations

In [7]:
def tokenization(genome, gene_anno, gene_info, tokenizer, outfile, ext_list, sampling=1e7):
    """
    Tokenize every gene feature by feature and record the genomic span of each token.

    For each (optionally sub-sampled) gene with a non-empty entry in `gene_info`, the gene
    span is padded with a pair of flank lengths from `ext_list`. The leading flank, every
    annotated feature and the trailing flank are then fetched and tokenized one at a time,
    so no token crosses a feature boundary. Genes on the "+" strand are read left to right;
    all other genes are read as the reverse complement (`.antisense`), starting from the
    right-hand flank. Token spans are always reported in forward-strand coordinates with
    start < end.

    Parameters
    ----------
    genome : mapping-like
        Indexed by chromosome name (a pyfastx `Fasta` in this notebook). Slicing a chromosome
        must return an object exposing `.seq` and `.antisense`.
    gene_anno : dict
        gene -> dict with at least "chrom" and "strand".
    gene_info : dict
        gene -> list of features whose items [1] and [2] are the start and end coordinates
        (the lists produced by `get_gene_annotation`). Features are tokenized in list order.
    tokenizer :
        Object providing `special_tokens_map`, `encode(text, add_special_tokens=False)` and
        `convert_ids_to_tokens(ids)` (a HuggingFace-style tokenizer).
    outfile : str
        Path of the tab-separated file written at the end, one token per line:
        chrom, start, end, token, gene, strand.
    ext_list : sequence of (left_extension, right_extension)
        Indexed by the position of the gene in the (possibly sampled) gene list, so it needs
        at least as many entries as there are genes to process.
    sampling : int or float
        If there are more genes than this, a random subset of this size is tokenized.

    Returns
    -------
    dict
        gene -> list of [genomic_start, genomic_end, token_string] for all non-special tokens,
        in tokenization order.

    Raises
    ------
    RuntimeError
        If the tokenizer returns a different number of token strings than token ids.
    """
    # Special tokens are excluded from the output. Values of `special_tokens_map` are usually
    # plain strings, but an entry such as "additional_special_tokens" holds a list, which
    # cannot be put into a set directly, so list-like values are flattened.
    sp_tokens = set()
    for value in tokenizer.special_tokens_map.values():
        if isinstance(value, (list, tuple, set)):
            sp_tokens.update(value)
        else:
            sp_tokens.add(value)

    token_pos = {}
    gene_list = list(gene_anno.keys())
    if len(gene_list) > sampling:
        gene_list = random.sample(gene_list, int(sampling))

    for num, gene in enumerate(tqdm(gene_list, desc="Genes")):
        chrom = gene_anno[gene]["chrom"]
        strand = gene_anno[gene]["strand"]

        # Skip genes not in gene_info or with empty annotation
        if gene not in gene_info or not gene_info[gene]:
            continue

        # Gene span = outermost feature coordinates; then pad it with this gene's flank lengths.
        features = gene_info[gene]
        start = min(feature[1] for feature in features)
        end = max(feature[2] for feature in features)

        left_ext, right_ext = ext_list[num]
        ext_start = max(0, start - left_ext)
        # NOTE(review): ext_end is not clamped to the chromosome length; this relies on the
        # genome object tolerating a slice that runs past the end of the sequence.
        ext_end = end + right_ext

        chrom_record = genome[chrom]

        # List of (anchor, sequence) pieces in reading order.
        # "+" strand: anchor is the forward coordinate of the piece's first character.
        # other strand: the piece is reverse-complemented, so its first character is the
        # right-most base; anchor is the (exclusive) right end of the piece.
        seqinfo = []
        if strand == "+":
            # 1) left flank
            try:
                upstream_seq = chrom_record[ext_start:start].seq
            except Exception:
                # Best effort: report the failing interval and go on without this flank
                print(f"ERROR: {chrom}\t{ext_start}\t{start}")
                upstream_seq = ""
            seqinfo.append((ext_start, str(upstream_seq)))

            # 2) each annotated feature
            for feature in features:
                feat_start, feat_end = feature[1], feature[2]
                if feat_start >= feat_end:
                    continue
                seqinfo.append((feat_start, str(chrom_record[feat_start:feat_end].seq)))

            # 3) right flank
            seqinfo.append((end, str(chrom_record[end:ext_end].seq)))

        else:
            # 1) right flank comes first when reading the reverse strand
            try:
                flank_seq = chrom_record[end:ext_end].antisense
            except Exception:
                print(f"ERROR (rev): {chrom}\t{end}\t{ext_end}")
                flank_seq = ""
            seqinfo.append((ext_end, str(flank_seq)))

            # 2) each annotated feature, reverse-complemented
            for feature in features:
                feat_start, feat_end = feature[1], feature[2]
                if feat_start >= feat_end:
                    continue
                seqinfo.append((feat_end, str(chrom_record[feat_start:feat_end].antisense)))

            # 3) left flank comes last
            seqinfo.append((start, str(chrom_record[ext_start:start].antisense)))

        token_pos[gene] = []

        for anchor, raw_seq in seqinfo:
            if not raw_seq:
                continue

            token_ids = tokenizer.encode(raw_seq, add_special_tokens=False)
            tok_strs = tokenizer.convert_ids_to_tokens(token_ids)
            if len(tok_strs) != len(token_ids):
                # This should never happen in a well-formed tokenizer, but just in case:
                raise RuntimeError("Offset mapping length ≠ token_ids length")

            # Character offsets are rebuilt from the token strings themselves.
            # NOTE(review): this assumes each token string is exactly the piece of raw_seq it
            # covers; a token whose text differs in length from what it replaced (e.g. an
            # unknown-token placeholder) would shift every following offset - confirm for
            # the tokenizer in use.
            cursor = 0
            for token_str in tok_strs:
                char_start = cursor
                char_end = cursor + len(token_str)
                cursor = char_end

                if token_str in sp_tokens:
                    continue

                if strand == "+":
                    # raw_seq[i] sits at forward position anchor + i
                    g_start = anchor + char_start
                    g_end = anchor + char_end
                else:
                    # raw_seq[i] sits at forward position anchor - 1 - i, so the token
                    # raw_seq[char_start:char_end] covers [anchor - char_end, anchor - char_start)
                    g_start = anchor - char_end
                    g_end = anchor - char_start

                token_pos[gene].append([g_start, g_end, token_str])

    # Save token positions as a 6-column, tab-separated file
    with open(outfile, "w") as outf:
        for gene in tqdm(token_pos, desc="Save token positions"):
            chrom = gene_anno[gene]["chrom"]
            strand = gene_anno[gene]["strand"]
            for token in token_pos[gene]:
                print(chrom, token[0], token[1], token[2], gene, strand, sep="\t", file=outf)

    return token_pos

In [8]:
def tokens_to_nerdata(tokens_bed, annotation_bed, outfile, named_entities, tags_id):
    """
    Build a token-level NER dataset by intersecting `tokens_bed` with `annotation_bed`.

    Every token keeps the label of the first annotation row it is reported with. A token
    starts a feature (B- tag) when its annotation name differs from that of the previous
    kept token of the same gene, and continues it (I- tag) otherwise. Tokens without any
    overlap, or whose annotation name has no entry in `named_entities`, get the
    "intergenic" tag.

    Parameters
    ----------
    tokens_bed : str
        6-column token file (chrom, start, end, token, gene, strand); all tokens of a gene
        are expected on consecutive lines, already in the order wanted in the dataset.
    annotation_bed : str
        6-column annotation file (chrom, start, end, gene, feature name, strand).
    outfile : str
        Path of the pickle file to write.
    named_entities : dict
        Maps "intergenic" and "<feature name>0" / "<feature name>1" (first / following
        token of a feature) to tag strings.
    tags_id : dict
        Maps tag strings to integer labels.

    Returns
    -------
    dict
        {"id": [gene, ...], "sequence": [[token, ...], ...], "labels": [[int, ...], ...]}

    Side effects
    ------------
    Writes two files:
      1) `outfile`, a pickle of the returned dict;
      2) `outfile` with its extension replaced by ".token_sizes", holding
         "gene<TAB>token_count" for each gene.
    """
    import os  # local import: only needed to derive the sizes-file path

    ne = named_entities

    # Split the tag table into "first token of a feature" (keys ending in "0")
    # and "inside a feature" (all other keys) lookups.
    zero_map = {}
    one_map = {}
    for base_name, ner_label in ne.items():
        if base_name == "intergenic":
            # "intergenic" has a single tag; register it under both suffixes
            zero_map["intergenic0"] = ner_label
            one_map["intergenic1"] = ner_label
        elif base_name.endswith("0"):
            zero_map[base_name] = ner_label
        else:
            one_map[base_name] = ner_label

    # Left outer join: every token is reported, with a null annotation when nothing overlaps
    intersection = BedTool(tokens_bed).intersect(annotation_bed, loj=True)

    ner_info = {
        "id":       [],  # gene IDs, in the order they are flushed
        "sequence": [],  # one list of token strings per gene
        "labels":   []   # one list of integer tags per gene
    }
    sizes_buffer = []  # (gene, token_count) pairs for the .token_sizes file

    def _flush(gene_id, gene_tokens, gene_labels):
        """Store one finished gene in the output containers."""
        sizes_buffer.append((gene_id, len(gene_tokens)))
        ner_info["id"].append(gene_id)
        ner_info["sequence"].append(gene_tokens)
        ner_info["labels"].append(gene_labels)

    # (start, end) spans already emitted per gene, so a token reported with several
    # overlapping annotations is only kept once (its first row wins)
    token_seen = defaultdict(set)

    current_gene = None
    tokens_list = []
    labels_list = []
    last_name = None  # annotation name of the previous kept token (B- vs I- decision)

    # Rows are consumed in the order the intersection yields them; this relies on all rows
    # of one gene being consecutive.
    for iv in intersection:
        start = iv.start
        end = iv.end
        token = iv.name        # 4th column of tokens_bed
        gene = iv.fields[4]    # 5th column of tokens_bed (gene the token was cut from)
        name = iv.fields[10]   # 5th column of the annotation row = feature name
        # NOTE(review): the annotation's own gene (fields[9]) is not compared with `gene`,
        # so a token overlapping a neighbouring gene's feature is labelled with that feature.
        token_id = (start, end)

        # A new gene starts: store the finished one and reset the per-gene state
        if gene != current_gene:
            if current_gene is not None and tokens_list:
                _flush(current_gene, tokens_list, labels_list)
                count = len(ner_info["id"])
                if count % 100 == 0:
                    print(count)
            current_gene = gene
            tokens_list = []
            labels_list = []
            last_name = None

        if token_id in token_seen[gene]:
            continue
        token_seen[gene].add(token_id)

        # Pick the tag:
        #   name == "-1"      -> no overlapping annotation -> intergenic
        #   name == last_name -> still inside the same kind of feature -> "<name>1"
        #   otherwise         -> a new feature starts -> "<name>0"
        if name == "-1":
            ner_label = ne["intergenic"]
        else:
            ner_label = None
            if name == last_name:
                ner_label = one_map.get(name + "1")
            if ner_label is None:
                # new feature, or no "inside" tag defined for this name
                ner_label = zero_map.get(name + "0")
            if ner_label is None:
                # feature type without any tag (e.g. not listed in named_entities):
                # treat it as intergenic instead of failing with a KeyError
                ner_label = ne["intergenic"]

        ner_tag = tags_id[ner_label]
        last_name = name

        tokens_list.append(token)
        labels_list.append(ner_tag)

    # Store the last gene once the loop ends
    if current_gene is not None and tokens_list:
        _flush(current_gene, tokens_list, labels_list)
        print(".", end="")

    # splitext only strips a real file extension, so a dot inside a directory name
    # (e.g. "./out/ner") cannot truncate the path
    sizes_file = os.path.splitext(outfile)[0] + ".token_sizes"
    with open(sizes_file, "w") as tsf:
        for gene_name, count in sizes_buffer:
            tsf.write(f"{gene_name}\t{count}\n")

    with open(outfile, "wb") as handle:
        pickle.dump(ner_info, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return ner_info

In [9]:
# Download the rice genome FASTA and the GFF3 gene models (requires the `wget` command on your PATH).
# `-c` continues a partial download; as the recorded output shows, nothing is fetched again
# when the files are already complete.
!wget -c https://rice.uga.edu/osa1r7_download/osa1_r7.asm.fa.gz
!wget -c https://rice.uga.edu/osa1r7_download/osa1_r7.all_models.gff3.gz

--2025-08-23 09:40:33--  https://rice.uga.edu/osa1r7_download/osa1_r7.asm.fa.gz
正在解析主机 rice.uga.edu (rice.uga.edu)... 128.192.162.131
正在连接 rice.uga.edu (rice.uga.edu)|128.192.162.131|:443... 已连接。
已发出 HTTP 请求，正在等待回应... 416 Requested Range Not Satisfiable

    文件已下载完成；不会进行任何操作。

--2025-08-23 09:40:35--  https://rice.uga.edu/osa1r7_download/osa1_r7.all_models.gff3.gz
正在解析主机 rice.uga.edu (rice.uga.edu)... 128.192.162.131
正在连接 rice.uga.edu (rice.uga.edu)|128.192.162.131|:443... 已连接。
已发出 HTTP 请求，正在等待回应... 416 Requested Range Not Satisfiable

    文件已下载完成；不会进行任何操作。



In [10]:
# Load genome sequence
genome_file = "osa1_r7.asm.fa.gz"
genome = Fasta(genome_file)

# Load annotation: every gene plus the exons of each of its isoforms.
# GFF3 coordinates are 1-based and inclusive; they are stored 0-based, end-exclusive.
# Exon lines are attached to the most recently read gene line, so this relies on the
# GFF3 listing each gene before its exons.
gene_anno = {}
gene = None  # gene the exon lines currently being read belong to
with gzip.open("osa1_r7.all_models.gff3.gz", "rt") as infile:
    for line in tqdm(infile):
        # Skip comments/directives and blank (including whitespace-only) lines
        if line.startswith("#") or not line.strip():
            continue
        info = line.strip().split("\t")
        if len(info) < 9:
            continue  # not a complete 9-column GFF3 record
        datatype = info[2]
        if datatype not in ("gene", "exon"):
            continue  # other feature types are not used
        chrom = info[0]
        start = int(info[3]) - 1
        end = int(info[4])
        strand = info[6]
        # key=value attributes of column 9 (a repeated key keeps its last value)
        attributes = dict(item.split("=", 1) for item in info[8].split(";") if "=" in item)

        if datatype == "gene":
            gene = attributes.get("Name")
            if gene is None:
                # Unnamed gene: its exons are skipped below instead of being
                # silently attached to the previous gene
                continue
            if gene not in gene_anno:
                gene_anno[gene] = {
                    "chrom": chrom,
                    "start": start,
                    "end": end,
                    "strand": strand,
                    "isoform": {},
                }
        else:  # exon
            # Only the first parent is used when an exon lists several
            isoform = attributes.get("Parent", "").split(",")[0]
            if gene is None or not isoform:
                continue
            gene_anno[gene]["isoform"].setdefault(isoform, []).append([datatype, start, end])

# Get full gene annotation information and save it as a 6-column BED-like file
# (chrom, start, end, gene, feature type, strand), genes ordered by chromosome and start
gene_info = get_gene_annotation(gene_anno)
annotation_bed = "rice_annotation.bed"
with open(annotation_bed, "w") as outf:
    for gene in sorted(gene_anno, key=lambda x: (gene_anno[x]["chrom"], gene_anno[x]["start"])):
        features = gene_info[gene]
        if gene_anno[gene]["strand"] != "+":
            # Non-"+" genes are stored in descending order; reverse them so the file is ascending
            features = features[::-1]
        for item in features:
            print(item[0], item[1], item[2], gene, item[3], item[4], sep="\t", file=outf)

813791it [00:01, 651827.12it/s]


In [11]:
# Load the task configuration, then the pretrained model and its tokenizer.
# The model is fetched from ModelScope (source="modelscope"). As the recorded log shows, the
# token-classification head is newly initialized, so the model must be fine-tuned before use.
configs = load_config("./ner_task_config.yaml")
model_name = "zhangtaolab/plant-dnagpt-6mer"
model, tokenizer = load_model_and_tokenizer(model_name, task_config=configs['task'], source="modelscope")

Downloading Model from https://www.modelscope.cn to directory: /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/plant-dnagpt-6mer
Model files are stored in /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/plant-dnagpt-6mer
Downloading Model from https://www.modelscope.cn to directory: /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/plant-dnagpt-6mer
Downloading Model from https://www.modelscope.cn to directory: /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/plant-dnagpt-6mer
Downloading Model from https://www.modelscope.cn to directory: /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/plant-dnagpt-6mer


Some weights of GPT2ForTokenClassification were not initialized from the model checkpoint at /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/plant-dnagpt-6mer and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
# Tokenize the gene sequences and generate the dataset in the format required for named entity recognition
print("# Performing sequence tokenization...")
# Output file: one token per line (chrom, start, end, token, gene, strand)
tokens_bed = "rice_genes_tokens.bed"

# Only a random sample of 2000 genes is tokenized here (sampling=2000)
token_pos = tokenization(genome, gene_anno, gene_info, tokenizer, tokens_bed, ext_list, sampling=2000)

# Performing sequence tokenization...


Genes:   0%|          | 8/2000 [00:00<00:26, 74.41it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (606 > 512). Running this sequence through the model will result in indexing errors
Genes: 100%|██████████| 2000/2000 [00:20<00:00, 98.11it/s] 
Save token positions: 100%|██████████| 1926/1926 [00:00<00:00, 2858.01it/s]


In [13]:



print("# Generate NER dataset...")

# Pickle file that will hold the gene ids, token sequences and integer labels
dataset = 'rice_gene_ner.pkl'
# Label every token by intersecting the token positions with the exon/intron annotation
ner_info = tokens_to_nerdata(tokens_bed, annotation_bed, dataset, named_entities, tags_id)


# Generate NER dataset...
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
.

In [14]:
from dnallm import DNADataset, DNATrainer

In [15]:
# Load the datasets from the pickle written by `tokens_to_nerdata`
# ("sequence" holds the token lists, "labels" the integer tags)
datasets = DNADataset.load_local_data("./rice_gene_ner.pkl", seq_col="sequence", label_col="labels", tokenizer=tokenizer, max_length=1024)

# Encode the sequences with given task's data collator
datasets.encode_sequences(task=configs['task'].task_type, remove_unused_columns=True)

# Split the dataset into train, test, and validation sets
# (called without arguments, i.e. with the library's default split settings)
datasets.split_data()

Format labels:   0%|          | 0/1926 [00:00<?, ? examples/s]

Encoding inputs:   0%|          | 0/1926 [00:00<?, ? examples/s]

In [None]:
# Check the dataset: print the number of samples in each split.
# The hasattr guard skips the report when `datasets.dataset` has no `keys()` method,
# i.e. when it is not a mapping of split name -> split.
if hasattr(datasets.dataset, 'keys'):
    for split_name in datasets.dataset.keys():
        print(f"{split_name}: {len(datasets.dataset[split_name])} samples")

train: 1348 samples
test: 385 samples
val: 193 samples


In [16]:
# Initialize the trainer with the model, the loaded configuration and the encoded, split datasets
trainer = DNATrainer(
    model=model,
    config=configs,
    datasets=datasets
)

In [17]:
# Start training (fine-tuning) and print the summary metrics returned by the trainer
metrics = trainer.train()
print(metrics)

Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
500,0.5122,0.237552,0.918376,0.572056,0.516667,0.542952
1000,0.175,0.195971,0.937131,0.639966,0.602381,0.620605


{'train_runtime': 761.5615, 'train_samples_per_second': 5.31, 'train_steps_per_second': 1.328, 'total_flos': 2113463702642688.0, 'train_loss': 0.3411682097777415, 'epoch': 3.0}


In [18]:
# Do prediction on the test set and print the evaluation metrics
# (loss, accuracy, precision, recall and F1 in the recorded output)
predictions = trainer.predict()
print(predictions.metrics)

{'test_loss': 0.21138842403888702, 'test_accuracy': 0.9275411185150473, 'test_precision': 0.6377063423110338, 'test_recall': 0.5921742638160549, 'test_f1': 0.6140974691487138, 'test_runtime': 23.6895, 'test_samples_per_second': 16.252, 'test_steps_per_second': 1.055}
