# OmniGenome - A Demonstration based on RNA Secondary Structure Prediction
GitHub: https://github.com/yangheng95/OmniGenome
OmniGenome Hub: Hugging Face Spaces

## Introduction
OmniGenome is a comprehensive package for developing pretrained genomic foundation models (FMs) and benchmarking them.
OmniGenome has the following key features:
- Automated genomic FM benchmarking on public genomic datasets
- Scalable genomic FM training and fine-tuning on genomic tasks
- Implementations of diverse genomic FMs
- Easy-to-use pipelines for genomic FM development that require no coding expertise
- Accessible OmniGenome Hub for sharing FMs, datasets, and pipelines
- Extensive documentation and tutorials for genomic FM development

We introduce OmniGenome through a demonstration: training a model to predict RNA secondary structure. The demonstration uses the bpRNA dataset, which contains RNA sequences paired with their secondary structures. The secondary structure of an RNA molecule is the set of base pairs that form as the molecule folds, and it is important for understanding the molecule's function. Given only the primary sequence (the nucleotide string), the model learns to predict this secondary structure.
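The structure labels in this demonstration use dot-bracket notation with the same three symbols as the `label2id` mapping in Step 2: matching `(` and `)` mark a base pair and `.` marks an unpaired nucleotide. As a minimal sketch (the helper `dot_bracket_to_pairs` is hypothetical, not part of OmniGenome), the base pairs can be recovered with a stack, because each `)` pairs with the most recent unmatched `(`:

```python
from typing import List, Tuple


def dot_bracket_to_pairs(structure: str) -> List[Tuple[int, int]]:
    """Return the 0-indexed (i, j) base pairs encoded by a dot-bracket string."""
    stack: List[int] = []
    pairs: List[Tuple[int, int]] = []
    for j, char in enumerate(structure):
        if char == "(":
            stack.append(j)  # remember the opening position
        elif char == ")":
            if not stack:
                raise ValueError(f"Unmatched ')' at position {j}")
            i = stack.pop()  # the most recent unmatched '(' pairs with this ')'
            pairs.append((i, j))
        elif char != ".":
            raise ValueError(f"Unexpected character {char!r} at position {j}")
    if stack:
        raise ValueError(f"Unmatched '(' at positions {stack}")
    return sorted(pairs)


print(dot_bracket_to_pairs("((..))."))  # [(0, 5), (1, 4)]
```

Nested brackets therefore describe nested helices; pseudoknots would need extra bracket types, which this label set does not include.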

## Requirements
OmniGenome recommends the following dependencies:
- Python 3.9+
- PyTorch 2.0.0+
- Transformers 4.37.0+
- Pandas 1.3.3+
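- Additional packages for specific tasks (this demonstration also imports `autocuda`, `metric_visualizer`, `gradio`, and the ViennaRNA Python bindings `RNA`)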

pip install OmniGenome
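Before running the notebook, it can help to confirm that the environment meets the recommended versions listed above. The sketch below is not part of OmniGenome; it assumes the usual PyPI distribution names `torch`, `transformers`, and `pandas`, and uses only the standard library's `importlib.metadata`:

```python
import re
import sys
from importlib.metadata import PackageNotFoundError, version

# Recommended minimum versions from the requirements list (distribution names assumed)
MINIMUMS = {"torch": (2, 0, 0), "transformers": (4, 37, 0), "pandas": (1, 3, 3)}


def parse_version(text: str) -> tuple:
    """Turn strings like '2.1.0+cu118' or '4.37.0.dev0' into a 3-tuple of ints."""
    parts = []
    for piece in text.split(".")[:3]:
        match = re.match(r"\d+", piece)  # keep only the leading digits
        if not match:
            break
        parts.append(int(match.group()))
    return tuple(parts + [0] * (3 - len(parts)))  # pad '2.0' to (2, 0, 0)


def check_environment() -> dict:
    """Map each requirement to True if it is installed and recent enough."""
    report = {"python": sys.version_info[:2] >= (3, 9)}
    for name, minimum in MINIMUMS.items():
        try:
            report[name] = parse_version(version(name)) >= minimum
        except PackageNotFoundError:
            report[name] = False  # not installed
    return report


print(check_environment())
```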


## Fine-tuning Genomic FMs for RNA Secondary Structure Prediction

### Step 1: Import Libraries

In [None]:
import os

import autocuda
import torch
from metric_visualizer import MetricVisualizer

from omnigenbench import OmniDatasetForTokenClassification
from omnigenbench import ClassificationMetric
from omnigenbench import OmniSingleNucleotideTokenizer, OmniKmersTokenizer
from omnigenbench import OmniModelForTokenClassification
from omnigenbench import Trainer

### Step 2: Define and Initialize the Tokenizer

In [None]:
# Predefined dataset label mapping
label2id = {"(": 0, ")": 1, ".": 2}

# This FM is exclusively powered by the OmniGenome package
model_name_or_path = "anonymous8/OmniGenome-186M"

# Generally, tokenizers come from the transformers library, e.g. AutoTokenizer:
# tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

# However, OmniGenome provides specialized tokenizers for genomic data, such as the
# single-nucleotide tokenizer and the k-mers tokenizer; here we explicitly use the former
tokenizer = OmniSingleNucleotideTokenizer.from_pretrained(model_name_or_path)
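Single-nucleotide tokenization keeps a one-to-one alignment between tokens and per-base structure labels, which is what token classification needs. The plain-Python sketch below only illustrates the segmentation idea; the real `OmniSingleNucleotideTokenizer` and `OmniKmersTokenizer` also add special tokens and map tokens to vocabulary ids, and the k-mer stride of 1 used here is an assumption:

```python
from typing import List


def single_nucleotide_tokens(seq: str) -> List[str]:
    # One token per base, so token i lines up with structure label i
    return list(seq)


def kmer_tokens(seq: str, k: int = 3, stride: int = 1) -> List[str]:
    # Windows of length k; with stride=1 they overlap (illustrative assumption)
    return [seq[i : i + k] for i in range(0, len(seq) - k + 1, stride)]


seq = "GCUGGGAU"
print(single_nucleotide_tokens(seq))  # ['G', 'C', 'U', 'G', 'G', 'G', 'A', 'U']
print(kmer_tokens(seq, k=3))  # ['GCU', 'CUG', 'UGG', 'GGG', 'GGA', 'GAU']
```

The 8-base sequence yields 8 single-nucleotide tokens but only 6 overlapping 3-mers, so k-mer tokens would need an extra alignment step before per-base labels could be attached.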

### Step 3: Define and Initialize the Model

In [None]:
# OmniGenome implements a diverse set of genomic models; see the documentation for more details
ssp_model = OmniModelForTokenClassification(
    model_name_or_path,
    tokenizer=tokenizer,
    label2id=label2id,
)

### Step 4: Define and Load the Dataset

In [None]:
# necessary hyperparameters
epochs = 10
learning_rate = 2e-5
weight_decay = 1e-5
batch_size = 8
max_length = 512
seeds = [45]  # Each seed will be used for one run


# Paths to the train/test/validation splits
train_file = "toy_datasets/Archive2/train.json"
test_file = "toy_datasets/Archive2/test.json"
valid_file = "toy_datasets/Archive2/valid.json"

train_set = OmniDatasetForTokenClassification(
    data_source=train_file,
    tokenizer=tokenizer,
    label2id=label2id,
    max_length=max_length,
)
test_set = OmniDatasetForTokenClassification(
    data_source=test_file,
    tokenizer=tokenizer,
    label2id=label2id,
    max_length=max_length,
)
valid_set = OmniDatasetForTokenClassification(
    data_source=valid_file,
    tokenizer=tokenizer,
    label2id=label2id,
    max_length=max_length,
)
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True
)
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size)
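Judging from the `sample_rna_sequence` helper in the web demo below, each line of a split file is a JSON object with a `seq` field and a dot-bracket `label` field. The sketch writes one such record and shows one plausible way to turn its label into per-token ids with `label2id`, padding to `max_length` with `-100`, the value the metrics in Step 5 ignore (`ignore_y=-100`) and PyTorch's default `ignore_index` for cross-entropy. How `OmniDatasetForTokenClassification` actually aligns labels with special tokens is not shown; the helper `encode_labels` and the file name `toy_record.json` are hypothetical:

```python
import json
import os
import tempfile
from typing import Dict, List

label2id = {"(": 0, ")": 1, ".": 2}


def encode_labels(structure: str, label2id: Dict[str, int], max_length: int) -> List[int]:
    """Map each dot-bracket character to its id, truncate, then pad with -100."""
    ids = [label2id[c] for c in structure[:max_length]]
    # -100 marks padding positions that the loss and metrics should skip
    return ids + [-100] * (max_length - len(ids))


record = {"seq": "GGGAAACCC", "label": "(((...)))"}
path = os.path.join(tempfile.mkdtemp(), "toy_record.json")
with open(path, "w") as f:
    f.write(json.dumps(record) + "\n")  # one JSON object per line

with open(path) as f:
    loaded = [json.loads(line) for line in f]

print(encode_labels(loaded[0]["label"], label2id, max_length=12))
# [0, 0, 0, 2, 2, 2, 1, 1, 1, -100, -100, -100]
```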

### Step 5: Define the Metrics
OmniGenome implements a diverse set of genomic metrics; see the documentation for more details.
Users can also define their own metrics by inheriting from the `OmniGenomeMetric` class.
`compute_metrics` can be a list of metric functions, each of which returns a dictionary of metrics.

In [None]:
compute_metrics = [
    ClassificationMetric(ignore_y=-100).accuracy_score,
    ClassificationMetric(ignore_y=-100, average="macro").f1_score,
    ClassificationMetric(ignore_y=-100).matthews_corrcoef,
]
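Because each metric function returns a dictionary, a custom metric can be a small standalone function. The NumPy-only sketch below computes token accuracy while skipping positions labelled `-100`. The `(y_true, y_pred)` call signature and the `masked_accuracy` name are assumptions for illustration; check the OmniGenome documentation (or the `OmniGenomeMetric` base class) for the signature the `Trainer` actually expects before adding it to `compute_metrics`:

```python
import numpy as np


def masked_accuracy(y_true, y_pred, ignore_y: int = -100) -> dict:
    """Token-level accuracy that skips positions whose true label is `ignore_y`."""
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    keep = y_true != ignore_y  # boolean mask of real (non-padding) positions
    if not keep.any():
        return {"masked_accuracy": 0.0}
    score = float((y_true[keep] == y_pred[keep]).mean())
    return {"masked_accuracy": score}


# The padded position (-100) is ignored, so 2 of the 3 real positions match
print(masked_accuracy([[0, 2, 1, -100]], [[0, 2, 2, 1]]))
```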


### Step 6: Define and Initialize the Trainer

In [None]:
# Initialize the MetricVisualizer for logging the metrics
mv = MetricVisualizer(name="OmniGenome-186M-SSP")
model_tag = model_name_or_path.split("/")[-1]

for seed in seeds:
    # Re-initialize the model for each seed so that runs do not share trained weights
    ssp_model = OmniModelForTokenClassification(
        model_name_or_path,
        tokenizer=tokenizer,
        label2id=label2id,
    )
    optimizer = torch.optim.AdamW(
        ssp_model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )
    trainer = Trainer(
        model=ssp_model,
        train_loader=train_loader,
        eval_loader=valid_loader,
        test_loader=test_loader,
        batch_size=batch_size,
        epochs=epochs,
        optimizer=optimizer,
        compute_metrics=compute_metrics,
        seeds=seed,
        device=autocuda.auto_cuda(),
    )

    metrics = trainer.train()
    # Log the final test-set metrics of this run
    test_metrics = metrics["test"][-1]
    mv.log(model_tag, "F1", test_metrics["f1_score"])
    mv.log(model_tag, "Accuracy", test_metrics["accuracy_score"])
    print(metrics)

# Summarize the logged metrics across all runs
mv.summary()
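Running one training per seed only makes the runs comparable if every source of randomness is seeded. The sketch below shows what seeding typically involves in a PyTorch project; `set_seed` is a hypothetical helper and says nothing about how OmniGenome's `Trainer` handles its `seeds` argument. PyTorch is seeded only if it is installed:

```python
import random

import numpy as np


def set_seed(seed: int) -> None:
    """Seed Python's, NumPy's and (if available) PyTorch's random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch

        torch.manual_seed(seed)  # seeds the RNG on all devices, including CUDA
    except ImportError:
        pass  # PyTorch not installed; nothing more to seed


set_seed(45)
first = (random.random(), float(np.random.rand()))
set_seed(45)
second = (random.random(), float(np.random.rand()))
print(first == second)  # True
```

Full determinism on GPU can additionally require deterministic cuDNN settings, which this sketch does not touch.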

### Step 7: Experimental Results Visualization
The experimental results are visualized in the following plots. The plots show the F1 score and accuracy of the model on the test set for each run. The average F1 score and accuracy are also shown.

### Step 8: Model Checkpoint for Sharing
The fine-tuned model can be saved as a checkpoint and shared with others for further use. The following code saves the checkpoint, loads it back, and runs inference on a sequence:

**Saving checkpoints regularly, and resuming from them, is good practice for preserving the model at different stages of training.**

In [None]:
path_to_save = "OmniGenome-186M-SSP"
ssp_model.save(path_to_save, overwrite=True)

# Load the model checkpoint
ssp_model = ssp_model.load(path_to_save)
results = ssp_model.inference("CAGUGCCGAGGCCACGCGGAGAACGAUCGAGGGUACAGCACUA")
print(results["predictions"])
print("logits:", results["logits"])
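The `predictions` field already holds per-nucleotide labels, so the `logits` are mainly useful for custom post-processing. As a sketch, assuming the logits for one sequence form a `(sequence_length, num_labels)` array whose columns follow `label2id` (any special-token positions would have to be stripped first; this layout is an assumption), decoding is an argmax followed by the inverse label mapping:

```python
import numpy as np

label2id = {"(": 0, ")": 1, ".": 2}
id2label = {i: label for label, i in label2id.items()}  # invert the mapping


def logits_to_structure(logits) -> str:
    """Pick the highest-scoring label at each position and join into dot-bracket."""
    ids = np.asarray(logits).argmax(axis=-1)
    return "".join(id2label[int(i)] for i in ids)


toy_logits = np.array([
    [2.0, 0.1, 0.3],  # highest score for "("
    [0.2, 0.1, 1.5],  # highest score for "."
    [0.1, 3.0, 0.2],  # highest score for ")"
])
print(logits_to_structure(toy_logits))  # (.)
```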


### What If You Don't Know How to Initialize the Model?

In [None]:
# ModelHub loads the saved checkpoint without constructing the model class manually
from omnigenbench import ModelHub

ssp_model = ModelHub.load("OmniGenome-186M-SSP")
results = ssp_model.inference("CAGUGCCGAGGCCACGCGGAGAACGAUCGAGGGUACAGCACUA")
print(results["predictions"])
print("logits:", results["logits"])

### Step 9: Model Inference

In [None]:
examples = [
    "GCUGGGAUGUUGGCUUAGAAGCAGCCAUCAUUUAAAGAGUGCGUAACAGCUCACCAGC",
    "AUCUGUACUAGUUAGCUAACUAGAUCUGUAUCUGGCGGUUCCGUGGAAGAACUGACGUGUUCAUAUUCCCGACCGCAGCCCUGGGAGACGUCUCAGAGGC",
]

results = ssp_model.inference(examples)
structures = ["".join(prediction) for prediction in results["predictions"]]
print(results)
print(structures)

### Step 10: Pipeline Creation
The OmniGenome package provides pipelines for genomic FM development. A pipeline bundles a model, its tokenizer, the datasets, and a trainer, and can be run with a single command to train a genomic FM on a dataset, fine-tune a pretrained genomic FM on a new dataset, evaluate an FM's performance on a dataset, or generate predictions.

In [None]:
# from omnigenbench import Pipeline, PipelineHub
# 
# pipeline = Pipeline(
#     name="OmniGenome-186M-SSP-Pipeline",
#     # model_name_or_path="OmniGenome-186M-SSP",  # The model name or path can be specified
#     # tokenizer="OmniGenome-186M-SSP",  # The tokenizer can be specified
#     model_name_or_path=ssp_model,
#     tokenizer=ssp_model.tokenizer,
#     datasets={
#         "train": "toy_datasets/Archive2/train.json",
#         "test": "toy_datasets/Archive2/test.json",
#         "valid": "toy_datasets/Archive2/valid.json",
#     },
#     trainer=trainer,
#     device=ssp_model.model.device,
# )

### Using the Pipeline

In [None]:
# results = pipeline(examples[0])
# print(results)
# 
# pipeline.train()
# 
# pipeline.save("OmniGenome-186M-SSP-Pipeline", overwrite=True)
# 
# pipeline = PipelineHub.load("OmniGenome-186M-SSP-Pipeline")
# results = pipeline(examples)
# print(results)

## Web Demo for RNA Secondary Structure Prediction

In [None]:
import os
import time
import base64
import tempfile
from pathlib import Path
import json
import numpy as np
import gradio as gr
import RNA
from omnigenbench import ModelHub

# Load the model
ssp_model = ModelHub.load("OmniGenome-186M-SSP")

# Temporary directory for the generated SVG files
TEMP_DIR = Path(tempfile.mkdtemp())
print(f"Using temporary directory: {TEMP_DIR}")


def ss_validity_loss(rna_strct: str) -> float:
    left = right = 0
    dots = rna_strct.count('.')
    for c in rna_strct:
        if c == '(':
            left += 1
        elif c == ')':
            if left:
                left -= 1
            else:
                right += 1
        elif c != '.':
            raise ValueError(f"Invalid char {c}")
    return (left + right) / (len(rna_strct) - dots + 1e-8)


def find_invalid_positions(struct: str) -> list:
    stack, invalid = [], []
    for i, c in enumerate(struct):
        if c == '(': stack.append(i)
        elif c == ')':
            if stack:
                stack.pop()
            else:
                invalid.append(i)
    invalid.extend(stack)
    return invalid


def generate_svg_datauri(rna_seq: str, struct: str) -> str:
    """Generate an SVG plot of the structure and return it as a Base64 data URI."""
    try:
        path = TEMP_DIR / f"{hash(rna_seq+struct)}.svg"
        RNA.svg_rna_plot(rna_seq, struct, str(path))
        time.sleep(0.1)
        svg_bytes = path.read_bytes()
        b64 = base64.b64encode(svg_bytes).decode('utf-8')
    except Exception as e:
        err = ('<svg xmlns="http://www.w3.org/2000/svg" width="400" height="200">'
               f'<text x="50" y="100" fill="red">Error: {e}</text></svg>')
        b64 = base64.b64encode(err.encode()).decode('utf-8')
    return f"data:image/svg+xml;base64,{b64}"


def fold(rna_seq: str, gt_struct: str):
    """Compare the ground-truth, ViennaRNA, and model-predicted structures."""
    if not rna_seq.strip():
        # One empty value for each of the five output components
        return "", "", "", "", ""
    # Ground truth: use the user-provided structure if one is given
    ground = gt_struct.strip() if gt_struct and gt_struct.strip() else ""
    gt_uri = generate_svg_datauri(rna_seq, ground) if ground else ""

    # ViennaRNA prediction (minimum free energy structure)
    vienna_struct, vienna_energy = RNA.fold(rna_seq)
    vienna_uri = generate_svg_datauri(rna_seq, vienna_struct)

    # Model prediction; unmatched brackets are replaced with '.' so the structure can be plotted
    result = ssp_model.inference(rna_seq)
    pred = "".join(result.get('predictions', []))
    if ss_validity_loss(pred):
        for i in find_invalid_positions(pred):
            pred = pred[:i] + '.' + pred[i+1:]
    pred_uri = generate_svg_datauri(rna_seq, pred)

    # Statistics: fraction of positions where two structures agree
    parts = []
    if ground:
        match_gt = sum(a == b for a, b in zip(ground, pred)) / len(ground)
        parts.append(f"GT↔Pred Match: {match_gt:.2%}")
    match_vienna = sum(a == b for a, b in zip(vienna_struct, pred)) / len(vienna_struct)
    parts.append(f"Vienna↔Pred Match: {match_vienna:.2%}")
    stats = " | ".join(parts)

    # Combined HTML: the three plots side by side (no backslashes inside f-string
    # expressions, which Python versions before 3.12 reject)
    img_style = 'style="max-width:100%;height:auto;"'
    gt_html = f'<div><h4>Ground Truth</h4><img src="{gt_uri}" {img_style}/></div>' if ground else ""
    combined = (
        '<div style="display:flex;justify-content:space-around;">'
        f'{gt_html}'
        f'<div><h4>ViennaRNA</h4><img src="{vienna_uri}" {img_style}/></div>'
        f'<div><h4>Prediction</h4><img src="{pred_uri}" {img_style}/></div>'
        '</div>'
    )
    return ground, vienna_struct, pred, stats, combined


def sample_rna_sequence():
    """Sample an example from the test set; return its sequence and ground-truth structure."""
    try:
        with open('toy_datasets/Archive2/test.json') as f:
            exs = [json.loads(line) for line in f]
        ex = exs[np.random.randint(len(exs))]
        return ex['seq'], ex.get('label', '')
    except Exception as e:
        return f"Error loading sample: {e}", ""

# Gradio UI
with gr.Blocks(css="""
.heading {text-align:center;color:#2a4365;}
.controls {display:flex;gap:10px;margin:20px 0;}
.status {padding:10px;background:#f0f4f8;border-radius:4px;white-space:pre;}
""") as demo:
    gr.Markdown("# RNA Structure Prediction Comparison", elem_classes="heading")
    with gr.Row():
        rna_input = gr.Textbox(label="RNA Sequence", lines=3)
        structure_input = gr.Textbox(label="Ground Truth Structure (optional)", lines=3)
    with gr.Row(elem_classes="controls"):
        sample_btn = gr.Button("Draw Sample")
        run_btn = gr.Button("Predict and Compare", variant="primary")
    stats_out = gr.Textbox(label="Statistics", interactive=False, elem_classes="status")
    gt_out = gr.Textbox(label="Ground Truth", interactive=False)
    vienna_out = gr.Textbox(label="ViennaRNA Structure", interactive=False)
    pred_out = gr.Textbox(label="Predicted Structure", interactive=False)
    combined_view = gr.HTML(label="Side-by-Side Comparison")

    run_btn.click(
        fold,
        inputs=[rna_input, structure_input],
        outputs=[gt_out, vienna_out, pred_out, stats_out, combined_view]
    )
    sample_btn.click(
        sample_rna_sequence,
        outputs=[rna_input, structure_input]
    )

# Launch once the layout is fully defined; share=True also creates a public link
demo.launch(share=True)
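The per-position match rates reported by `fold` count agreeing dots as well as agreeing brackets, so two mostly unpaired structures can score highly while sharing few base pairs. A common complementary measure for secondary structure prediction is base-pair precision, recall, and F1. The sketch below is not what the demo computes; the helper names `pairs_of` and `base_pair_f1` are hypothetical, and it compares the sets of pairs encoded by two dot-bracket strings:

```python
from typing import Dict, Set, Tuple


def pairs_of(structure: str) -> Set[Tuple[int, int]]:
    """Set of (i, j) base pairs in a dot-bracket string; unmatched ')' are ignored."""
    stack, pairs = [], set()
    for j, char in enumerate(structure):
        if char == "(":
            stack.append(j)
        elif char == ")" and stack:
            pairs.add((stack.pop(), j))
    return pairs


def base_pair_f1(reference: str, predicted: str) -> Dict[str, float]:
    """Precision, recall and F1 over base pairs rather than individual positions."""
    ref, pred = pairs_of(reference), pairs_of(predicted)
    true_pos = len(ref & pred)  # pairs present in both structures
    precision = true_pos / len(pred) if pred else 0.0
    recall = true_pos / len(ref) if ref else 0.0
    f1 = 2 * precision * recall / (precision + recall) if true_pos else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


# The prediction recovers one of the two reference pairs
print(base_pair_f1("((...))", "(.....)"))
```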


## Conclusion
In this demonstration, we have shown how to fine-tune a genomic foundation model for RNA secondary structure prediction using the OmniGenome package. We have also shown how to use the trained model for inference and how to create a web demo for RNA secondary structure prediction. We hope this demonstration will help you get started with genomic foundation model development using OmniGenome.