# OmniGenome - A Demonstration based on RNA Secondary Structure Prediction
GitHub: https://github.com/yangheng95/OmniGenome
OmniGenome Hub: Huggingface Spaces

## Introduction
OmniGenome is a comprehensive package for developing and benchmarking pretrained genomic foundation models (FMs).
OmniGenome has the following key features:
- Automated genomic FM benchmarking on public genomic datasets
- Scalable genomic FM training and fine-tuning on genomic tasks
- Diversified genomic FMs implementation
- Easy-to-use pipeline for genomic FM development with no coding expertise required
- Accessible OmniGenome Hub for sharing FMs, datasets, and pipelines
- Extensive documentation and tutorials for genomic FM development

We introduce OmniGenome with a demonstration that trains a model to predict RNA secondary structures. The dataset used in this demonstration is the bpRNA dataset, which contains RNA sequences and their corresponding secondary structures. The secondary structure of an RNA sequence is the set of base pairs that describes how the RNA molecule folds, and it is important for understanding the molecule's function. In this demonstration, we train a model to predict the secondary structure of an RNA sequence from its primary sequence.
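To make the "set of base pairs" concrete, here is a minimal sketch that converts a dot-bracket string (the notation used for the labels later in this demonstration) into index pairs. The helper name `dot_bracket_to_pairs` is hypothetical and is not part of OmniGenome.

```python
def dot_bracket_to_pairs(structure):
    """Return the sorted (i, j) index pairs encoded by a dot-bracket string."""
    stack = []  # positions of "(" that are still waiting for a partner
    pairs = []
    for i, char in enumerate(structure):
        if char == "(":
            stack.append(i)
        elif char == ")":
            if not stack:
                raise ValueError(f"Unmatched ')' at position {i}")
            pairs.append((stack.pop(), i))
        elif char != ".":
            raise ValueError(f"Invalid character {char!r} at position {i}")
    if stack:
        raise ValueError(f"Unmatched '(' at position {stack[-1]}")
    return sorted(pairs)


print(dot_bracket_to_pairs("((..))."))  # [(0, 5), (1, 4)]
```

Each `(` pairs with the nearest unpaired `)` that closes it, and `.` marks an unpaired nucleotide.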

## Requirements
OmniGenome recommends the following dependencies:
- Python 3.9+
- PyTorch 2.0.0+
- Transformers 4.37.0+
- Pandas 1.3.3+
- Others in case of specific tasks

pip install OmniGenome
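
If it helps, the recommended versions can be checked before running the notebook. The sketch below is one possible way to do it with the standard library only. The helper names (`version_tuple`, `check_requirements`) are hypothetical, and the distribution names `torch`, `transformers`, and `pandas` are assumed to be the ones installed in your environment.

```python
import re
import sys
from importlib import metadata

# Minimum versions copied from the list above.
MINIMUMS = {"torch": "2.0.0", "transformers": "4.37.0", "pandas": "1.3.3"}


def version_tuple(version):
    """Turn '2.1.2+cu12.1' into (2, 1, 2), keeping only the leading numeric parts."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    parts = tuple(int(p) for p in match.group(0).split(".")) if match else ()
    return parts + (0,) * (3 - len(parts))  # pad '2.0' to (2, 0, 0)


def check_requirements(minimums=MINIMUMS):
    """Return {package: 'ok' | 'too old' | 'missing'} for each listed package."""
    report = {}
    for package, minimum in minimums.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            report[package] = "missing"
            continue
        is_ok = version_tuple(installed) >= version_tuple(minimum)
        report[package] = "ok" if is_ok else "too old"
    return report


if sys.version_info < (3, 9):
    print("Python 3.9+ is recommended")
print(check_requirements())
```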


## Fine-tuning Genomic FMs for RNA Secondary Structure Prediction

### Step 1: Import Libraries

In [1]:
import os

import autocuda
import torch
from metric_visualizer import MetricVisualizer

from omnigenome import OmniGenomeDatasetForTokenClassification
from omnigenome import ClassificationMetric
from omnigenome import OmniSingleNucleotideTokenizer, OmniKmersTokenizer
from omnigenome import OmniGenomeModelForTokenClassification
from omnigenome import Trainer

### Step 2: Define and Initialize the Tokenizer

In [2]:
# Predefined dataset label mapping
label2id = {"(": 0, ")": 1, ".": 2}

# This FM is exclusively powered by the OmniGenome package
model_name_or_path = "anonymous8/OmniGenome-186M"

# Generally, we use the tokenizers from transformers library, such as AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

# However, OmniGenome provides specialized tokenizers for genomic data, such as single nucleotide tokenizer and k-mers tokenizer
# we can force the tokenizer to be used in the model
tokenizer = OmniSingleNucleotideTokenizer.from_pretrained(model_name_or_path)
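
To illustrate the difference between the two tokenization schemes named in the comments above, here is a plain-Python sketch. These helpers are hypothetical and do not reproduce the OmniGenome tokenizer implementations (they ignore special tokens and vocabulary lookup). They only show why single-nucleotide tokenization is a natural fit for a task with one label per base.

```python
def single_nucleotide_tokens(sequence):
    """One token per base, so tokens and per-base labels line up one-to-one."""
    return list(sequence)


def kmer_tokens(sequence, k=3, stride=1):
    """Overlapping k-mers; a sequence of length n yields n - k + 1 tokens at stride 1."""
    return [sequence[i : i + k] for i in range(0, len(sequence) - k + 1, stride)]


seq = "CAGUGCC"
print(single_nucleotide_tokens(seq))  # ['C', 'A', 'G', 'U', 'G', 'C', 'C']
print(kmer_tokens(seq, k=3))          # ['CAG', 'AGU', 'GUG', 'UGC', 'GCC']
```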

### Step 3: Define and Initialize the Model

In [3]:
# OmniGenome implements a diverse set of genomic models; refer to the documentation for more details
ssp_model = OmniGenomeModelForTokenClassification(
    model_name_or_path,
    tokenizer=tokenizer,
    label2id=label2id,
)

You are using a model of type omnigenome to instantiate a model of type mprna. This is not supported for all configurations of models and can yield errors.
Some weights of the model checkpoint at anonymous8/OmniGenome-186M were not used when initializing OmniGenomeModel: ['classifier.bias', 'classifier.weight', 'dense.bias', 'dense.weight', 'lm_head.bias', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing OmniGenomeModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing OmniGenomeModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of OmniGenomeModel were not initial

[2024-08-11 17:43:16] (0.0.8alpha) Model Name: OmniGenomeModelForTokenClassification
Model Metadata: {'library_name': 'OmniGenome', 'omnigenome_version': '0.0.8alpha', 'torch_version': '2.1.2+cu12.1+gita8e7c98cb95ff97bb30a728c6b2a1ce6bff946eb', 'transformers_version': '4.43.2', 'model_cls': 'OmniGenomeModelForTokenClassification', 'tokenizer_cls': 'OmniSingleNucleotideTokenizer', 'model_name': 'OmniGenomeModelForTokenClassification'}
Base Model Name: anonymous8/OmniGenome-186M
Model Type: omnigenome
Model Architecture: ['OmniGenomeModel', 'OmniGenomeForTokenClassification', 'OmniGenomeForMaskedLM', 'OmniGenomeModelForSeq2SeqLM', 'OmniGenomeForTSequenceClassification', 'OmniGenomeForTokenClassification', 'OmniGenomeForSeq2SeqLM']
Model Parameters: 185.886801 M
Model Config: OmniGenomeConfig {
  "OmniGenomefold_config": null,
  "_name_or_path": "anonymous8/OmniGenome-186M",
  "architectures": [
    "OmniGenomeModel",
    "OmniGenomeForTokenClassification",
    "OmniGenomeForMaskedLM",
  

### Step 4: Define and Load the Dataset

In [4]:
# necessary hyperparameters
epochs = 10
learning_rate = 2e-5
weight_decay = 1e-5
batch_size = 8
max_length = 512
seeds = [45]  # Each seed will be used for one run


# Load the dataset according to the path
train_file = "toy_datasets/Archive2/train.json"
test_file = "toy_datasets/Archive2/test.json"
valid_file = "toy_datasets/Archive2/valid.json"

train_set = OmniGenomeDatasetForTokenClassification(
    data_source=train_file,
    tokenizer=tokenizer,
    label2id=label2id,
    max_length=max_length,
)
test_set = OmniGenomeDatasetForTokenClassification(
    data_source=test_file,
    tokenizer=tokenizer,
    label2id=label2id,
    max_length=max_length,
)
valid_set = OmniGenomeDatasetForTokenClassification(
    data_source=valid_file,
    tokenizer=tokenizer,
    label2id=label2id,
    max_length=max_length,
)
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True
)
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size)

[2024-08-11 17:43:16] (0.0.8alpha) Detected max_length=512 in the dataset, using it as the max_length.
[2024-08-11 17:43:16] (0.0.8alpha) Loading data from toy_datasets/Archive2/train.json...
[2024-08-11 17:43:16] (0.0.8alpha) Loaded 608 examples from toy_datasets/Archive2/train.json
[2024-08-11 17:43:16] (0.0.8alpha) Detected shuffle=True, shuffling the examples...


100%|██████████████████████████████████████████████████████████████████████████████| 608/608 [00:00<00:00, 5249.80it/s]


[2024-08-11 17:43:17] (0.0.8alpha) {'avg_seq_len': 130.54276315789474, 'max_seq_len': 501, 'min_seq_len': 56, 'avg_label_len': 501.0, 'max_label_len': 501, 'min_label_len': 501}
[2024-08-11 17:43:17] (0.0.8alpha) Detected max_length=512 in the dataset, using it as the max_length.
[2024-08-11 17:43:17] (0.0.8alpha) Loading data from toy_datasets/Archive2/test.json...
[2024-08-11 17:43:17] (0.0.8alpha) Loaded 82 examples from toy_datasets/Archive2/test.json
[2024-08-11 17:43:17] (0.0.8alpha) Detected shuffle=True, shuffling the examples...


100%|████████████████████████████████████████████████████████████████████████████████| 82/82 [00:00<00:00, 3625.84it/s]


[2024-08-11 17:43:17] (0.0.8alpha) {'avg_seq_len': 131.23170731707316, 'max_seq_len': 321, 'min_seq_len': 67, 'avg_label_len': 321.0, 'max_label_len': 321, 'min_label_len': 321}
[2024-08-11 17:43:17] (0.0.8alpha) Detected max_length=512 in the dataset, using it as the max_length.
[2024-08-11 17:43:17] (0.0.8alpha) Loading data from toy_datasets/Archive2/valid.json...
[2024-08-11 17:43:17] (0.0.8alpha) Loaded 76 examples from toy_datasets/Archive2/valid.json
[2024-08-11 17:43:17] (0.0.8alpha) Detected shuffle=True, shuffling the examples...


100%|████████████████████████████████████████████████████████████████████████████████| 76/76 [00:00<00:00, 5782.41it/s]

[2024-08-11 17:43:17] (0.0.8alpha) {'avg_seq_len': 117.39473684210526, 'max_seq_len': 308, 'min_seq_len': 60, 'avg_label_len': 308.0, 'max_label_len': 308, 'min_label_len': 308}
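
The statistics above report label lengths equal to the longest sequence in each split, which suggests that labels are padded to a common length. How `OmniGenomeDatasetForTokenClassification` does this internally is not shown here. The sketch below is only an illustration under two assumptions: padding uses the ignore index `-100` (the value passed as `ignore_y` in Step 5), and special tokens are left out. The helper `encode_structure` is hypothetical.

```python
label2id = {"(": 0, ")": 1, ".": 2}
IGNORE_INDEX = -100  # assumption: same value as ignore_y in the metrics of Step 5


def encode_structure(structure, max_length):
    """Map a dot-bracket string to label ids, truncating or padding to max_length."""
    ids = [label2id[char] for char in structure[:max_length]]
    return ids + [IGNORE_INDEX] * (max_length - len(ids))


print(encode_structure("((..))", max_length=8))  # [0, 0, 2, 2, 1, 1, -100, -100]
```

Padded positions carry no structural information, so they are excluded when the loss and the metrics are computed.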





### Step 5: Define the Metrics
OmniGenome implements a diverse set of genomic metrics; please refer to the documentation for details.
Users can also define their own metrics by inheriting from the `OmniGenomeMetric` class.
`compute_metrics` can be a list of metric functions, each of which should return a dictionary of metrics.

In [5]:
compute_metrics = [
    ClassificationMetric(ignore_y=-100).accuracy_score,
    ClassificationMetric(ignore_y=-100, average="macro").f1_score,
    ClassificationMetric(ignore_y=-100).matthews_corrcoef,
]
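
As a rough illustration of what a metric function that "returns a dictionary of metrics" might look like, here is a dependency-free sketch of accuracy and macro F1 that skip positions labeled `-100`. The function names are hypothetical and this is not the `ClassificationMetric` implementation. It only mirrors the contract described above.

```python
def _masked_pairs(y_true, y_pred, ignore_y=-100):
    """Keep only (true, predicted) pairs whose true label is not the ignore index."""
    return [(t, p) for t, p in zip(y_true, y_pred) if t != ignore_y]


def masked_accuracy(y_true, y_pred, ignore_y=-100):
    pairs = _masked_pairs(y_true, y_pred, ignore_y)
    correct = sum(t == p for t, p in pairs)
    return {"accuracy_score": correct / len(pairs) if pairs else 0.0}


def masked_macro_f1(y_true, y_pred, ignore_y=-100):
    pairs = _masked_pairs(y_true, y_pred, ignore_y)
    labels = sorted({t for t, _ in pairs} | {p for _, p in pairs})
    f1_values = []
    for label in labels:
        tp = sum(t == label and p == label for t, p in pairs)
        fp = sum(t != label and p == label for t, p in pairs)
        fn = sum(t == label and p != label for t, p in pairs)
        denom = 2 * tp + fp + fn
        f1_values.append(2 * tp / denom if denom else 0.0)
    # Macro averaging: unweighted mean of the per-class F1 values.
    return {"f1_score": sum(f1_values) / len(f1_values) if f1_values else 0.0}


y_true = [0, 0, 2, 1, -100]  # last position is padding and is ignored
y_pred = [0, 2, 2, 1, 1]
print(masked_accuracy(y_true, y_pred))  # {'accuracy_score': 0.75}
print(masked_macro_f1(y_true, y_pred))
```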


### Step 6: Define and Initialize the Trainer

In [6]:
# Initialize the MetricVisualizer for logging the metrics
mv = MetricVisualizer(name="OmniGenome-186M-SSP")

for seed in seeds:
    optimizer = torch.optim.AdamW(
        ssp_model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )
    trainer = Trainer(
        model=ssp_model,
        train_loader=train_loader,
        eval_loader=valid_loader,
        test_loader=test_loader,
        batch_size=batch_size,
        epochs=epochs,
        optimizer=optimizer,
        compute_metrics=compute_metrics,
        seeds=seed,
        device=autocuda.auto_cuda(),
    )

    metrics = trainer.train()
    test_metrics = metrics["test"][-1]
    mv.log(model_name_or_path.split("/")[-1], "F1", test_metrics["f1_score"])
    mv.log(
        model_name_or_path.split("/")[-1],
        "Accuracy",
        test_metrics["accuracy_score"],
    )
    print(metrics)
    mv.summary()

Evaluating: 100%|██████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  7.43it/s]


[2024-08-11 17:43:19] (0.0.8alpha) {'accuracy_score': 0.2790193842645382, 'f1_score': 0.28151975296578563, 'matthews_corrcoef': -0.09291127922709266}


Epoch 1/10 Loss: 0.7989: 100%|█████████████████████████████████████████████████████████| 76/76 [00:49<00:00,  1.54it/s]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  7.99it/s]


[2024-08-11 17:44:11] (0.0.8alpha) {'accuracy_score': 0.8913340935005701, 'f1_score': 0.8935400779001638, 'matthews_corrcoef': 0.8353253240117546}


Epoch 2/10 Loss: 0.6545: 100%|█████████████████████████████████████████████████████████| 76/76 [00:49<00:00,  1.54it/s]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  8.16it/s]


[2024-08-11 17:45:02] (0.0.8alpha) {'accuracy_score': 0.9076396807297605, 'f1_score': 0.9095038559875431, 'matthews_corrcoef': 0.8604032983011348}


Epoch 3/10 Loss: 0.6302: 100%|█████████████████████████████████████████████████████████| 76/76 [00:49<00:00,  1.55it/s]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  8.18it/s]


[2024-08-11 17:45:54] (0.0.8alpha) {'accuracy_score': 0.9148232611174458, 'f1_score': 0.9163503175903402, 'matthews_corrcoef': 0.86969111358666}


Epoch 4/10 Loss: 0.6151: 100%|█████████████████████████████████████████████████████████| 76/76 [00:49<00:00,  1.55it/s]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  8.19it/s]


[2024-08-11 17:46:45] (0.0.8alpha) {'accuracy_score': 0.9169897377423033, 'f1_score': 0.9185686268915924, 'matthews_corrcoef': 0.8725737867525207}


Epoch 5/10 Loss: 0.6071: 100%|█████████████████████████████████████████████████████████| 76/76 [00:48<00:00,  1.55it/s]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  8.22it/s]


[2024-08-11 17:47:36] (0.0.8alpha) {'accuracy_score': 0.9189281641961231, 'f1_score': 0.9205276415383489, 'matthews_corrcoef': 0.875436812852734}


Epoch 6/10 Loss: 0.6013: 100%|█████████████████████████████████████████████████████████| 76/76 [00:48<00:00,  1.56it/s]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  8.21it/s]


[2024-08-11 17:48:28] (0.0.8alpha) {'accuracy_score': 0.9210946408209806, 'f1_score': 0.9226092911100953, 'matthews_corrcoef': 0.879263171602823}


Epoch 7/10 Loss: 0.5989: 100%|█████████████████████████████████████████████████████████| 76/76 [00:48<00:00,  1.55it/s]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  8.24it/s]


[2024-08-11 17:49:19] (0.0.8alpha) {'accuracy_score': 0.9238312428734321, 'f1_score': 0.9253576750498466, 'matthews_corrcoef': 0.8831977559814651}


Epoch 8/10 Loss: 0.5979: 100%|█████████████████████████████████████████████████████████| 76/76 [00:48<00:00,  1.55it/s]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  8.23it/s]


[2024-08-11 17:50:10] (0.0.8alpha) {'accuracy_score': 0.9234891676168757, 'f1_score': 0.9250099970359921, 'matthews_corrcoef': 0.8820785908253933}


Epoch 9/10 Loss: 0.5955: 100%|█████████████████████████████████████████████████████████| 76/76 [00:49<00:00,  1.55it/s]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  8.17it/s]


[2024-08-11 17:51:00] (0.0.8alpha) {'accuracy_score': 0.9240592930444698, 'f1_score': 0.9255602479349917, 'matthews_corrcoef': 0.883211983456326}


Epoch 10/10 Loss: 0.5913: 100%|████████████████████████████████████████████████████████| 76/76 [00:49<00:00,  1.55it/s]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  8.24it/s]


[2024-08-11 17:51:51] (0.0.8alpha) {'accuracy_score': 0.9225769669327252, 'f1_score': 0.9241115922227455, 'matthews_corrcoef': 0.8821062314790764}


Testing: 100%|█████████████████████████████████████████████████████████████████████████| 11/11 [00:01<00:00,  8.05it/s]


[2024-08-11 17:51:53] (0.0.8alpha) {'accuracy_score': 0.902897046333868, 'f1_score': 0.9044334792769698, 'matthews_corrcoef': 0.8503789642989459}
{'valid': [{'accuracy_score': 0.2790193842645382, 'f1_score': 0.28151975296578563, 'matthews_corrcoef': -0.09291127922709266}, {'accuracy_score': 0.8913340935005701, 'f1_score': 0.8935400779001638, 'matthews_corrcoef': 0.8353253240117546}], 'best_valid': {'accuracy_score': 0.9240592930444698, 'f1_score': 0.9255602479349917, 'matthews_corrcoef': 0.883211983456326}, 'test': [{'accuracy_score': 0.902897046333868, 'f1_score': 0.9044334792769698, 'matthews_corrcoef': 0.8503789642989459}]}

----------------------------------------------- Raw Metric Records -----------------------------------------------
╒══════════╤═════════════════╤══════════════════════╤═══════════╤══════════╤═══════╤═══════╤══════════╤══════════╕
│ Metric   │ Trial           │ Values               │  Average  │  Median  │  Std  │  IQR  │   Min    │   Max    │
╞══════════╪═══════



### Step 7: Experimental Results Visualization
The experimental results are visualized in the following plots. The plots show the F1 score and accuracy of the model on the test set for each run. The average F1 score and accuracy are also shown.
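
The summary table printed by `mv.summary()` has the columns Average, Median, Std, IQR, Min, and Max. The sketch below shows one way such per-metric statistics could be computed across runs with the standard library. The exact formulas used by `MetricVisualizer` (for example, sample versus population standard deviation) are not stated in this demonstration, so treat the choices here as assumptions. The helper `summarize` is hypothetical.

```python
import statistics


def summarize(values):
    """Summary statistics for one metric collected over one or more runs."""
    values = sorted(values)
    if len(values) >= 2:
        q1, _, q3 = statistics.quantiles(values, n=4)
        std = statistics.stdev(values)  # sample standard deviation (assumption)
        iqr = q3 - q1
    else:
        # A single run has no spread to report.
        std, iqr = 0.0, 0.0
    return {
        "Average": statistics.mean(values),
        "Median": statistics.median(values),
        "Std": std,
        "IQR": iqr,
        "Min": values[0],
        "Max": values[-1],
    }


# The single test F1 score reported in the training log above (one seed, one run).
print(summarize([0.9044334792769698]))
```

With a single seed, as in this demonstration, the average equals the only recorded value. Adding more seeds to `seeds` would give the spread columns something to measure.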

### Step 8: Model Checkpoint for Sharing
The model checkpoint can be saved and shared with others. The following code saves the checkpoint, loads it again, and runs inference:

**Regular checkpointing and resuming are good practices to save the model at different stages of training.**

In [7]:
path_to_save = "OmniGenome-186M-SSP"
ssp_model.save(path_to_save, overwrite=True)

# Load the model checkpoint
ssp_model = ssp_model.load(path_to_save)
results = ssp_model.inference("CAGUGCCGAGGCCACGCGGAGAACGAUCGAGGGUACAGCACUA")
print(results["predictions"])
print("logits:", results["logits"])

[2024-08-11 17:51:55] (0.0.8alpha) The model is saved to OmniGenome-186M-SSP.
['.', '(', '(', '(', '(', '(', '.', '.', '.', '.', '(', '(', '(', '.', '(', '.', '(', '(', '(', '.', '.', '.', '.', '.', '.', '.', ')', ')', ')', '.', ')', '.', ')', '.', '.', '.', '.', ')', ')', ')', ')', ')', '.']
logits: tensor([[8.0241e-04, 6.8535e-04, 9.9851e-01],
        [1.8072e-03, 2.7458e-04, 9.9792e-01],
        [9.9968e-01, 1.4969e-04, 1.7153e-04],
        [9.9977e-01, 1.2595e-04, 1.0330e-04],
        [9.9973e-01, 1.5334e-04, 1.1417e-04],
        [9.9977e-01, 1.1016e-04, 1.1670e-04],
        [9.9974e-01, 1.4174e-04, 1.1885e-04],
        [1.6035e-04, 8.9402e-05, 9.9975e-01],
        [1.2057e-04, 1.2549e-04, 9.9975e-01],
        [1.0425e-04, 1.2844e-04, 9.9977e-01],
        [1.0099e-04, 1.1066e-04, 9.9979e-01],
        [9.9936e-01, 2.3561e-04, 4.0091e-04],
        [9.9964e-01, 1.5549e-04, 2.0940e-04],
        [9.9949e-01, 1.4136e-04, 3.7019e-04],
        [3.0048e-04, 1.4218e-04, 9.9956e-01],
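
Each row of the `logits` output holds one score per label, in the order given by `label2id`. As a small sketch, the label for a row can be recovered by taking the index of the largest score and mapping it back through the inverse of `label2id`. The three rows below are copied from the output above. Whether the first row corresponds to a special token rather than the first nucleotide is not stated here, so no alignment with the printed predictions is claimed. The helper `decode` is hypothetical.

```python
label2id = {"(": 0, ")": 1, ".": 2}
id2label = {idx: label for label, idx in label2id.items()}


def decode(score_rows):
    """Pick the highest-scoring label for each row of per-label scores."""
    return [id2label[max(range(len(row)), key=row.__getitem__)] for row in score_rows]


rows = [
    [8.0241e-04, 6.8535e-04, 9.9851e-01],
    [1.8072e-03, 2.7458e-04, 9.9792e-01],
    [9.9968e-01, 1.4969e-04, 1.7153e-04],
]
print("".join(decode(rows)))  # ..(
```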
        




### What if someone doesn't know how to initialize the model?

In [8]:
# We can load the model checkpoint using the ModelHub
from omnigenome import ModelHub

ssp_model = ModelHub.load("OmniGenome-186M-SSP")
results = ssp_model.inference("CAGUGCCGAGGCCACGCGGAGAACGAUCGAGGGUACAGCACUA")
print(results["predictions"])
print("logits:", results["logits"])

[2024-08-11 17:52:00] (0.0.8alpha) Model Name: OmniGenomeModelForTokenClassification
Model Metadata: {'library_name': 'OmniGenome', 'omnigenome_version': '0.0.8alpha', 'torch_version': '2.1.2+cu12.1+gita8e7c98cb95ff97bb30a728c6b2a1ce6bff946eb', 'transformers_version': '4.43.2', 'model_cls': 'OmniGenomeModelForTokenClassification', 'tokenizer_cls': 'OmniSingleNucleotideTokenizer', 'model_name': 'OmniGenomeModelForTokenClassification'}
Base Model Name: OmniGenome-186M-SSP
Model Type: mprna
Model Architecture: ['OmniGenomeModel']
Model Parameters: 185.886801 M
Model Config: OmniGenomeConfig {
  "OmniGenomefold_config": null,
  "_name_or_path": "OmniGenome-186M-SSP",
  "architectures": [
    "OmniGenomeModel"
  ],
  "attention_probs_dropout_prob": 0.0,
  "auto_map": {
    "AutoConfig": "anonymous8/OmniGenome-186M--configuration_omnigenome.OmniGenomeConfig",
    "AutoModel": "anonymous8/OmniGenome-186M--modeling_omnigenome.OmniGenomeModel",
    "AutoModelForMaskedLM": "anonymous8/OmniGenome

### Step 9: Model Inference

In [9]:
examples = [
    "GCUGGGAUGUUGGCUUAGAAGCAGCCAUCAUUUAAAGAGUGCGUAACAGCUCACCAGC",
    "AUCUGUACUAGUUAGCUAACUAGAUCUGUAUCUGGCGGUUCCGUGGAAGAACUGACGUGUUCAUAUUCCCGACCGCAGCCCUGGGAGACGUCUCAGAGGC",
]

results = ssp_model.inference(examples)
structures = ["".join(prediction) for prediction in results["predictions"]]
print(results)
print(structures)

{'predictions': [['(', '(', '(', '(', '(', '.', '(', '(', '(', '.', '(', '(', '(', '(', '(', '.', '.', '.', '.', '.', '.', '.', '.', ')', ')', ')', ')', '.', ')', ')', ')', '.', '.', '.', '.', '.', '(', '(', '(', '(', '.', '.', '.', '.', '.', '.', '.', '.', ')', ')', ')', ')', '.', ')', ')', ')', ')', ')'], ['.', '.', '.', '.', '.', '.', '.', '(', '(', '(', '(', '(', '.', '.', '.', '.', '.', '.', ')', ')', ')', ')', ')', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '(', '(', '(', '(', '.', '.', '.', '.', '.', '.', '(', '(', '(', '.', '.', '(', '(', '(', '.', '.', '.', '.', '.', '.', '.', ')', ')', ')', '.', '.', '.', ')', ')', ')', '.', '.', ')', ')', ')', ')', ')', ')', '.', '(', '(', '(', '(', '(', '(', '(', '(', '(', '.', '.', '.', '.', ')', ')', ')', ')', ')', ')', '.', ')', ')', ')']], 'logits': tensor([[[2.4458e-04, 2.9453e-04, 9.9946e-01],
         [9.9968e-01, 1.8715e-04, 1.3058e-04],
         [9.9971e-01, 1.7857e-04, 1.1603e-04],
         [9.9969e-01, 1.9235e-04, 1.2

### Step 10: Pipeline Creation
The OmniGenome package provides pipelines for genomic FM development. A pipeline can train a genomic FM on a dataset with a single command, fine-tune a pretrained genomic FM on a new dataset, evaluate a genomic FM on a dataset, and generate predictions.

In [10]:
# from omnigenome import Pipeline, PipelineHub
# 
# pipeline = Pipeline(
#     name="OmniGenome-186M-SSP-Pipeline",
#     # model_name_or_path="OmniGenome-186M-SSP",  # The model name or path can be specified
#     # tokenizer="OmniGenome-186M-SSP",  # The tokenizer can be specified
#     model_name_or_path=ssp_model,
#     tokenizer=ssp_model.tokenizer,
#     datasets={
#         "train": "toy_datasets/train.json",
#         "test": "toy_datasets/test.json",
#         "valid": "toy_datasets/valid.json",
#     },
#     trainer=trainer,
#     device=ssp_model.model.device,
# )

### Using the Pipeline

In [11]:
# results = pipeline(examples[0])
# print(results)
# 
# pipeline.train()
# 
# pipeline.save("OmniGenome-186M-SSP-Pipeline", overwrite=True)
# 
# pipeline = PipelineHub.load("OmniGenome-186M-SSP-Pipeline")
# results = pipeline(examples)
# print(results)

## Web Demo for RNA Secondary Structure Prediction

In [1]:
import numpy as np
import json
import gradio as gr
import RNA
from omnigenome import ModelHub
import os
print(os.listdir('.'))
ssp_model = ModelHub.load("OmniGenome-186M-SSP")

def ss_validity_loss(rna_strct):
    dotCount = 0
    leftCount = 0
    rightCount = 0
    unmatched_positions = []  # Records the positions of unmatched brackets
    uncoherentCount = 0
    prev_char = ""
    for i, char in enumerate(rna_strct):
        if prev_char != char:
            uncoherentCount += 1
        prev_char = char

        if char == "(":
            leftCount += 1
            unmatched_positions.append(i)  # Record the left bracket position
        elif char == ")":
            if leftCount > 0:
                leftCount -= 1
                unmatched_positions.pop()  # Remove the most recent left bracket position
            else:
                rightCount += 1
                unmatched_positions.append(i)  # Record the right bracket position
        elif char == ".":
            dotCount += 1
        else:
            raise ValueError(f"Invalid character {char} in RNA structure")
    # Share of brackets left unmatched; 0 means the structure is balanced
    match_loss = (leftCount + rightCount) / (len(rna_strct) - dotCount + 1e-5)
    return match_loss


def find_invalid_ss_positions(rna_strct):
    left_brackets = []  # Positions of left brackets that are still unmatched
    right_brackets = []  # Positions of unmatched right brackets
    for i, char in enumerate(rna_strct):
        if char == "(":
            left_brackets.append(i)
        elif char == ")":
            if left_brackets:
                left_brackets.pop()  # Matching left bracket found; remove it from the list
            else:
                right_brackets.append(i)  # No matching left bracket; record this position
    return left_brackets + right_brackets


def fold(rna_sequence):
    blank_svg = '<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100"></svg>'
    try:
        # Reference structure computed by ViennaRNA
        ref_struct = RNA.fold(rna_sequence)[0]
        RNA.svg_rna_plot(rna_sequence, ref_struct, "real_structure.svg")

        pred_structure = "".join(ssp_model.inference(rna_sequence)["predictions"])
        print(pred_structure)
        if ss_validity_loss(pred_structure) == 0:
            RNA.svg_rna_plot(rna_sequence, pred_structure, "predicted_structure.svg")
            # A balanced structure has no unmatched brackets, so repair leaves it unchanged
            pred_structure, _ = repair_rna_structure(
                rna_sequence, pred_structure
            )
            return (
                ref_struct,
                pred_structure,
                "real_structure.svg",
                "predicted_structure.svg",
            )
        else:
            # The prediction is unbalanced and cannot be plotted:
            # write a blank SVG and return it as the predicted image
            with open("predicted_structure.svg", "w") as f:
                f.write(blank_svg)
            return (
                ref_struct,
                pred_structure,
                "real_structure.svg",
                "predicted_structure.svg",
            )
    except Exception as e:
        # On any failure, show blank images and the error message in both text boxes
        with open("real_structure.svg", "w") as f:
            f.write(blank_svg)
        with open("predicted_structure.svg", "w") as f:
            f.write(blank_svg)
        return str(e), str(e), "real_structure.svg", "predicted_structure.svg"


def repair_rna_structure(rna_sequence, invalid_struct):
    try:
        # Replace every unmatched bracket with an unpaired position "."
        invalid_ss_positions = find_invalid_ss_positions(invalid_struct)
        for pos_idx in invalid_ss_positions:
            invalid_struct = (
                invalid_struct[:pos_idx] + "." + invalid_struct[pos_idx + 1 :]
            )

        best_pred_struct = invalid_struct
        RNA.svg_rna_plot(rna_sequence, best_pred_struct, "best_pred_struct.svg")
        return best_pred_struct, "best_pred_struct.svg"
    except Exception as e:
        # On failure, write a blank SVG and return the error message as text
        with open("best_pred_struct.svg", "w") as f:
            f.write(
                '<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100"></svg>'
            )
        return str(e), "best_pred_struct.svg"


def sample_rna_sequence():
    example = examples[np.random.randint(0, len(examples))]
    RNA.svg_rna_plot(example["seq"], example["label"], f"annotated_structure.svg")

    return example["seq"], example["label"], "annotated_structure.svg"


# Define the interface
with gr.Blocks() as demo:
    gr.Markdown("### RNA Secondary Structure Prediction")

    with gr.Row():
        with gr.Row():
            rna_input = gr.Textbox(
                label="RNA Sequence", placeholder="Enter RNA sequence here..."
            )
        with gr.Row():
            strcut_input = gr.Textbox(
                label="Annotated Secondary Structure",
                placeholder="Enter RNA secondary structure here...",
            )

    with gr.Row():
        #     examples = [
        #     ["GCGUCACACCGGUGAAGUCGCGCGUCACACCGGUGAAGUCGC"],
        #     ["GCUGGGAUGUUGGCUUAGAAGCAGCCAUCAUUUAAAGAGUGCGUAACAGCUCACCAGCGCUGGGAUGUUGGCUUAGAAGCAGCCAUCAUUUAAAGAGUGCGUAACAGCUCACCAGC"],
        #     ["GGCUGGUCCGAGUGCAGUGGUGUUUACAACUAAUUGAUCACAACCAGUUACAGAUUUCUUUGUUCCUUCUCCACUCCCACUGCUUCACUUGACUAGCCUU"],
        # ]
        #     gr.Examples(examples=examples, label="Examples", inputs=[rna_input])
        with open("toy_datasets/Archive2/test.json", "r") as f:
            examples = []
            for line in f:
                examples.append(json.loads(line))

        sample_button = gr.Button("Sample an RNA Sequence from the Archive2 test set")

    with gr.Row():
        submit_button = gr.Button("Run Prediction")

    with gr.Row():
        ref_structure_output = gr.Textbox(
            label="Secondary Structure by ViennaRNA", interactive=False
        )

    with gr.Row():
        pred_structure_output = gr.Textbox(
            label="Secondary Structure by Model", interactive=False
        )

    with gr.Row():
        anno_structure_output = gr.Image(
            label="Annotated Secondary Structure", show_share_button=True
        )
        real_image = gr.Image(
            label="Secondary Structure by ViennaRNA", show_share_button=True
        )
        predicted_image = gr.Image(
            label="Secondary Structure by Model", show_share_button=True
        )

    with gr.Row():
        repair_button = gr.Button("Run Prediction Repair")

    submit_button.click(
        fn=fold,
        inputs=rna_input,
        outputs=[
            ref_structure_output,
            pred_structure_output,
            real_image,
            predicted_image,
        ],
    )

    repair_button.click(
        fn=repair_rna_structure,
        inputs=[rna_input, pred_structure_output],
        outputs=[pred_structure_output, predicted_image],
    )

    sample_button.click(
        fn=sample_rna_sequence, outputs=[rna_input, strcut_input, anno_structure_output]
    )
demo.launch(share=True)

['.ipynb_checkpoints', 'annotated_structure.svg', 'auto_benchmark.py', 'benchmark', 'benchmarks_info.json', 'best_pred_struct.svg', 'easy_rna_design.py', 'eterna100_contrafold.txt', 'eterna100_vienna2.txt', 'eterna100_vienna2.txt.result', 'EternaV2_RNA_design_demo.py', 'mlm_augmentation.py', 'OmniGenome-186M-SSP', 'OmniGenome-186M-SSP-Pipeline', 'OmniGenome_RNA_design.ipynb', 'predicted_structure.svg', 'readme.md', 'real_structure.svg', 'rna_modeling_using_omnigenome.py', 'secondary_structure_prediction_demo.ipynb', 'ssp_inference.py', 'test.py', 'toy_datasets', 'true_struct.svg', 'zero_shot_secondary_structure_prediction.py']
[2024-08-14 22:47:21] (0.0.8alpha) Model Name: OmniGenomeModelForTokenClassification
Model Metadata: {'library_name': 'OmniGenome', 'omnigenome_version': '0.0.8alpha', 'torch_version': '2.1.2+cu12.1+gita8e7c98cb95ff97bb30a728c6b2a1ce6bff946eb', 'transformers_version': '4.42.0.dev0', 'model_cls': 'OmniGenomeModelForTokenClassification', 'tokenizer_cls': 'OmniSingl



Running on local URL:  http://127.0.0.1:7860
IMPORTANT: You are using gradio version 4.25.0, however version 4.29.0 is available, please upgrade.
--------
Running on public URL: https://092094b2837cbc5f03.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


ConnectTimeout: _ssl.c:1112: The handshake operation timed out
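
The repair step in the demo replaces unmatched brackets with `.`. For readers who want to try that logic without ViennaRNA or Gradio installed, here is a standalone sketch of the same idea. The function name `repair_dot_bracket` is hypothetical. It mirrors what `find_invalid_ss_positions` and `repair_rna_structure` do together, minus the SVG plotting.

```python
def repair_dot_bracket(structure):
    """Replace every unmatched bracket with '.' so the result is balanced."""
    chars = list(structure)
    open_positions = []  # positions of "(" that are still waiting for a ")"
    for i, char in enumerate(chars):
        if char == "(":
            open_positions.append(i)
        elif char == ")":
            if open_positions:
                open_positions.pop()
            else:
                chars[i] = "."  # ")" with no earlier "(" to pair with
    for i in open_positions:
        chars[i] = "."  # "(" that was never closed
    return "".join(chars)


print(repair_dot_bracket("(()"))   # .()
print(repair_dot_bracket("())("))  # ()..
```

A structure that is already balanced passes through unchanged, which is why calling the repair on a valid prediction is harmless.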

## Conclusion
In this demonstration, we have shown how to fine-tune a genomic foundation model for RNA secondary structure prediction using the OmniGenome package. We have also shown how to use the trained model for inference and how to create a web demo for RNA secondary structure prediction. We hope this demonstration will help you get started with genomic foundation model development using OmniGenome.