# OmniGenome: A Demonstration Based on RNA Secondary Structure Prediction
GitHub: https://github.com/yangheng95/OmniGenome
OmniGenome Hub: Hugging Face Spaces

## Introduction
OmniGenome is a comprehensive package for developing and benchmarking pretrained genomic foundation models (FMs).
OmniGenome has the following key features:
- Automated benchmarking of genomic FMs on public genomic datasets
- Scalable training and fine-tuning of genomic FMs on genomic tasks
- Implementations of diverse genomic FMs
- An easy-to-use pipeline for genomic FM development that requires no coding expertise
- An accessible OmniGenome Hub for sharing FMs, datasets, and pipelines
- Extensive documentation and tutorials for genomic FM development

We introduce OmniGenome through a demonstration that trains a model to predict RNA secondary structures. The demonstration uses the bpRNA dataset, which contains RNA sequences and their corresponding secondary structures. The secondary structure of an RNA sequence is the set of base pairs formed as the molecule folds, and it is important for understanding the molecule's function. In this demonstration, we train a model to predict the secondary structure of an RNA sequence from its primary sequence alone.
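In this demonstration, structures are written in dot-bracket notation, matching the label set `(`, `)`, `.` used throughout: each `(` pairs with a later `)`, and `.` marks an unpaired base. As a minimal illustration (a hypothetical helper, not part of OmniGenome), the base pairs encoded by a dot-bracket string can be recovered with a stack:

```python
def dot_bracket_to_pairs(structure: str) -> list[tuple[int, int]]:
    """Return the 0-based (i, j) base pairs encoded by a dot-bracket string."""
    stack = []  # positions of '(' still waiting for a partner
    pairs = []
    for j, char in enumerate(structure):
        if char == "(":
            stack.append(j)
        elif char == ")":
            if not stack:
                raise ValueError(f"Unmatched ')' at position {j}")
            i = stack.pop()  # the most recent open bracket closes first
            pairs.append((i, j))
        elif char != ".":
            raise ValueError(f"Unexpected character {char!r} at position {j}")
    if stack:
        raise ValueError(f"Unmatched '(' at position {stack[-1]}")
    return sorted(pairs)


print(dot_bracket_to_pairs("((..))."))  # [(0, 5), (1, 4)]
```

A stack is sufficient because brackets in this notation are properly nested; pseudoknots, which need extra bracket types such as `[]`, cannot be expressed with `(`, `)`, and `.` alone.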

## Requirements
OmniGenome recommends the following dependencies:
- Python 3.9+
- PyTorch 2.0.0+
- Transformers 4.37.0+
- Pandas 1.3.3+
- Other packages as required by specific tasks
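One way to check an environment against these minimums is to read installed versions with the standard library's `importlib.metadata`. The sketch below is an illustrative helper, not part of OmniGenome; it compares only the numeric release part of each version, so local tags such as `+cu121` and pre-release tags such as `.dev0` are ignored.

```python
import re
import sys
from importlib import metadata


def version_tuple(version: str) -> tuple[int, ...]:
    """Keep only the leading numeric release part, e.g. '2.1.2+cu121' -> (2, 1, 2)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"Cannot parse version {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed: str, minimum: str) -> bool:
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so that '2.0' compares equal to '2.0.0'
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


# Minimum versions from the list above, keyed by PyPI distribution name
MINIMUMS = {"torch": "2.0.0", "transformers": "4.37.0", "pandas": "1.3.3"}

if __name__ == "__main__":
    python_ok = sys.version_info >= (3, 9)
    print(f"Python {sys.version.split()[0]}: {'OK' if python_ok else 'too old'} (need >= 3.9)")
    for package, minimum in MINIMUMS.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            print(f"{package}: not installed (need >= {minimum})")
            continue
        status = "OK" if meets_minimum(installed, minimum) else "too old"
        print(f"{package} {installed}: {status} (need >= {minimum})")
```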

## Fine-tuning Genomic FMs for RNA Secondary Structure Prediction

### Step 1: Import Libraries

In [2]:
import autocuda
import torch
from metric_visualizer import MetricVisualizer

from omnigenome import OmniGenomeDatasetForTokenClassification
from omnigenome import ClassificationMetric
from omnigenome import OmniSingleNucleotideTokenizer, OmniKmersTokenizer
from omnigenome import OmniGenomeModelForTokenClassification
from omnigenome import Trainer


### Step 2: Define and Initialize the Tokenizer

In [3]:
# Predefined dataset label mapping
label2id = {"(": 0, ")": 1, ".": 2}

# This FM is powered exclusively by the OmniGenome package
model_name_or_path = "anonymous8/OmniGenome-52M"

# Generally, we would use a tokenizer from the transformers library, such as AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

# However, OmniGenome provides specialized tokenizers for genomic data, such as the single-nucleotide tokenizer and the k-mer tokenizer,
# so we can explicitly choose the tokenizer used by the model
tokenizer = OmniSingleNucleotideTokenizer.from_pretrained(model_name_or_path)
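The difference between the two tokenizer families can be illustrated in plain Python. These functions are hypothetical stand-ins that only show the segmentation idea; they do not reproduce the vocabularies, special tokens, or options of `OmniSingleNucleotideTokenizer` or `OmniKmersTokenizer`.

```python
def single_nucleotide_tokenize(sequence: str) -> list[str]:
    """One token per base: 'acgu' -> ['A', 'C', 'G', 'U']."""
    return list(sequence.upper())


def kmer_tokenize(sequence: str, k: int = 3, stride: int = 1) -> list[str]:
    """k-mers starting every `stride` bases (overlapping when stride < k)."""
    sequence = sequence.upper()
    if len(sequence) < k:
        return [sequence] if sequence else []
    return [sequence[i : i + k] for i in range(0, len(sequence) - k + 1, stride)]


seq = "GCUGGGAU"
print(single_nucleotide_tokenize(seq))  # ['G', 'C', 'U', 'G', 'G', 'G', 'A', 'U']
print(kmer_tokenize(seq, k=3))  # ['GCU', 'CUG', 'UGG', 'GGG', 'GGA', 'GAU']
```

With single-nucleotide tokens, each base maps to exactly one token, so per-base structure labels line up one-to-one with tokens. Overlapping k-mers produce `len(sequence) - k + 1` tokens and break that correspondence.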

### Step 3: Define and Initialize the Model

In [4]:
# We have implemented a diverse set of genomic models in OmniGenome; please refer to the documentation for more details
ssp_model = OmniGenomeModelForTokenClassification(
    model_name_or_path,
    tokenizer=tokenizer,
    label2id=label2id,
)

Some weights of the model checkpoint at anonymous8/OmniGenome-52M were not used when initializing OmniGenomeModel: ['classifier.bias', 'classifier.weight', 'dense.bias', 'dense.weight', 'lm_head.bias', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing OmniGenomeModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing OmniGenomeModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of OmniGenomeModel were not initialized from the model checkpoint at anonymous8/OmniGenome-52M and are newly initialized: ['OmniGenome.pooler.dense.bias', 'OmniGenome.pooler.dense.weight']

[2024-06-03 19:04:28] (0.0.6alpha) Model Name: OmniGenomeModelForTokenClassification
Model Metadata: {'library_name': 'OmniGenome', 'omnigenome_version': '0.0.6alpha', 'torch_version': '2.1.2+cu12.1+gita8e7c98cb95ff97bb30a728c6b2a1ce6bff946eb', 'transformers_version': '4.41.0.dev0', 'model_cls': 'OmniGenomeModelForTokenClassification', 'tokenizer_cls': 'OmniSingleNucleotideTokenizer', 'model_name': 'OmniGenomeModelForTokenClassification'}
Base Model Name: anonymous8/OmniGenome-52M
Model Type: omnigenome
Model Architecture: ['OmniGenomeModel', 'OmniGenomeForTokenClassification', 'OmniGenomeForMaskedLM', 'OmniGenomeModelForSeq2SeqLM', 'OmniGenomeForTSequenceClassification', 'OmniGenomeForTokenClassification', 'OmniGenomeForSeq2SeqLM']
Model Parameters: 52.453345 M
Model Config: OmniGenomeConfig {
  "OmniGenomefold_config": null,
  "_name_or_path": "anonymous8/OmniGenome-52M",
  "architectures": [
    "OmniGenomeModel",
    "OmniGenomeForTokenClassification",
    "OmniGenomeForMaskedLM",
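The warnings above indicate that pretraining heads stored in the checkpoint (for example `lm_head.*`) are not loaded, so the task-specific head presumably starts from fresh weights and is learned during fine-tuning. Conceptually, a token-classification head maps each position's hidden state to one score per label. The NumPy sketch below shows that generic idea with random weights; it is not OmniGenome's actual head, whose layers may differ. The hidden size of 480 and the weight scale of 0.02 are taken from `hidden_size` and `initializer_range` in the model config printed later in this notebook.

```python
import numpy as np

rng = np.random.default_rng(0)

seq_len, hidden_size, num_labels = 10, 480, 3
hidden_states = rng.normal(size=(seq_len, hidden_size))  # stand-in for encoder outputs

# Hypothetical freshly initialized linear head: small random weights, zero bias
W = rng.normal(scale=0.02, size=(hidden_size, num_labels))
b = np.zeros(num_labels)

logits = hidden_states @ W + b  # shape (seq_len, num_labels)

# Softmax over the label dimension, shifted by the row max for numerical stability
shifted = logits - logits.max(axis=-1, keepdims=True)
probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)

id2label = {0: "(", 1: ")", 2: "."}
predicted = [id2label[i] for i in probs.argmax(axis=-1)]
print(probs.shape, "".join(predicted))
```

Until the head is trained, these per-position labels are essentially arbitrary.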


### Step 4: Define and Load the Dataset

In [5]:
# necessary hyperparameters
epochs = 1
learning_rate = 2e-5
weight_decay = 1e-5
batch_size = 8
max_length = 512
seeds = [45]  # Each seed will be used for one run


# Load the dataset according to the path
train_file = "toy_datasets/train.json"
test_file = "toy_datasets/test.json"
valid_file = "toy_datasets/valid.json"

train_set = OmniGenomeDatasetForTokenClassification(
    data_source=train_file,
    tokenizer=tokenizer,
    label2id=label2id,
    max_length=max_length,
)
test_set = OmniGenomeDatasetForTokenClassification(
    data_source=test_file,
    tokenizer=tokenizer,
    label2id=label2id,
    max_length=max_length,
)
valid_set = OmniGenomeDatasetForTokenClassification(
    data_source=valid_file,
    tokenizer=tokenizer,
    label2id=label2id,
    max_length=max_length,
)
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True
)
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size)

[2024-06-03 19:04:28] (0.0.6alpha) Detected max_length=512 in the dataset, using it as the max_length.
[2024-06-03 19:04:28] (0.0.6alpha) Loading data from toy_datasets/train.json...
[2024-06-03 19:04:28] (0.0.6alpha) Loaded 2278 examples from toy_datasets/train.json


100%|██████████| 2278/2278 [00:00<00:00, 4146.42it/s]


[2024-06-03 19:04:29] (0.0.6alpha) {'avg': 138.3305531167691, 'max': 502, 'min': 32}
[2024-06-03 19:04:29] (0.0.6alpha) Preview of the first two samples in the dataset:
{'input_ids': tensor([0, 6, 6, 5, 4, 6, 5, 5, 6, 5, 4, 6, 6, 6, 6, 9, 5, 5, 5, 5, 9, 5, 6, 5,
        6, 4, 5, 6, 4, 4, 9, 9, 6, 5, 5, 6, 9, 6, 4, 4, 5, 5, 5, 5, 6, 5, 5, 4,
        6, 6, 5, 5, 5, 6, 6, 4, 4, 6, 6, 6, 4, 6, 5, 4, 4, 5, 6, 6, 9, 4, 6, 5,
        6, 6, 9, 6, 6, 4, 9, 9, 9, 6, 6, 6, 9, 6, 5, 5, 6, 4, 6, 6, 9, 6, 5, 6,
        6, 5, 9, 9, 9, 9, 6, 9, 6, 6, 5, 9, 6, 9, 5, 9, 9, 9, 9, 2, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 

100%|██████████| 285/285 [00:00<00:00, 5009.35it/s]


[2024-06-03 19:04:29] (0.0.6alpha) {'avg': 132.5157894736842, 'max': 495, 'min': 34}
[2024-06-03 19:04:29] (0.0.6alpha) Preview of the first two samples in the dataset:
{'input_ids': tensor([0, 6, 6, 6, 5, 5, 6, 6, 9, 6, 6, 5, 9, 5, 4, 6, 5, 5, 9, 6, 6, 9, 9, 4,
        6, 4, 6, 5, 6, 6, 5, 6, 6, 6, 5, 9, 5, 9, 9, 4, 4, 5, 5, 5, 6, 5, 4, 6,
        6, 9, 5, 5, 6, 6, 6, 6, 9, 9, 5, 6, 4, 4, 9, 5, 5, 5, 5, 6, 5, 5, 6, 6,
        5, 5, 5, 6, 5, 5, 4, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 

100%|██████████| 285/285 [00:00<00:00, 5314.13it/s]


[2024-06-03 19:04:29] (0.0.6alpha) {'avg': 140.90877192982455, 'max': 495, 'min': 32}
[2024-06-03 19:04:29] (0.0.6alpha) Preview of the first two samples in the dataset:
{'input_ids': tensor([0, 5, 5, 9, 6, 6, 5, 6, 6, 5, 6, 4, 9, 4, 6, 9, 6, 5, 4, 6, 9, 6, 6, 9,
        5, 5, 5, 4, 5, 5, 9, 6, 4, 4, 9, 5, 5, 4, 9, 6, 5, 5, 6, 4, 4, 5, 9, 5,
        4, 6, 4, 4, 6, 9, 6, 4, 4, 4, 5, 6, 5, 9, 6, 9, 9, 4, 9, 6, 5, 5, 6, 4,
        9, 6, 6, 9, 4, 6, 9, 6, 9, 6, 6, 6, 6, 9, 4, 9, 5, 5, 5, 5, 4, 9, 6, 9,
        6, 4, 6, 4, 6, 9, 4, 6, 6, 9, 5, 4, 5, 5, 6, 5, 5, 4, 6, 6, 2, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1,
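In the previews above, each `input_ids` tensor starts with id `0`, has id `2` right after the last nucleotide, and is padded with id `1` up to `max_length`. This suggests one begin token and one end token around a single-nucleotide sequence, although that reading is inferred from the printed ids rather than documented here. Under that assumption, per-base labels must be shifted by one position, and the special and padding positions masked with `-100` (the default `ignore_index` of PyTorch's `CrossEntropyLoss`) so they are not scored. A hypothetical helper illustrating the alignment:

```python
IGNORE_INDEX = -100  # the same value passed as ignore_y to ClassificationMetric in Step 5
label2id = {"(": 0, ")": 1, ".": 2}


def align_labels(structure: str, max_length: int, n_prefix: int = 1, n_suffix: int = 1) -> list[int]:
    """Map a dot-bracket string to max_length per-token label ids.

    Assumes one token per nucleotide, with n_prefix special tokens before the
    sequence and n_suffix after it; special tokens and padding get IGNORE_INDEX.
    """
    # Truncate so the sequence plus its special tokens fits in max_length
    body = [label2id[c] for c in structure][: max_length - n_prefix - n_suffix]
    labels = [IGNORE_INDEX] * n_prefix + body + [IGNORE_INDEX] * n_suffix
    labels += [IGNORE_INDEX] * (max_length - len(labels))  # padding positions
    return labels


print(align_labels("((.))", max_length=10))
# [-100, 0, 0, 2, 1, 1, -100, -100, -100, -100]
```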

### Step 5: Define the Metrics
We have implemented a diverse set of genomic metrics in OmniGenome; please refer to the documentation for more details.
Users can also define their own metrics by inheriting from the `OmniGenomeMetric` class.
`compute_metrics` can be a list of metric functions, each of which should return a dictionary of metrics.

In [6]:
compute_metrics = [
    ClassificationMetric(ignore_y=-100).accuracy_score,
    ClassificationMetric(ignore_y=-100, average="macro").f1_score,
]
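Here `ignore_y=-100` tells the metrics to skip positions labeled `-100`, which typically marks special tokens and padding. The plain-Python sketch below shows what a masked token accuracy and a macro-averaged F1 compute; it is an illustration, not the `ClassificationMetric` implementation.

```python
IGNORE_INDEX = -100


def masked_accuracy_and_macro_f1(y_true, y_pred, ignore=IGNORE_INDEX):
    """Token accuracy and macro-F1 over positions whose true label != ignore."""
    pairs = [(t, p) for t, p in zip(y_true, y_pred) if t != ignore]
    if not pairs:
        raise ValueError("No labeled positions to score")
    accuracy = sum(t == p for t, p in pairs) / len(pairs)

    # Macro-F1: unweighted mean of per-label F1 over labels seen in y_true or y_pred
    f1_scores = []
    for label in sorted({t for t, _ in pairs} | {p for _, p in pairs}):
        tp = sum(t == label and p == label for t, p in pairs)
        fp = sum(t != label and p == label for t, p in pairs)
        fn = sum(t == label and p != label for t, p in pairs)
        denom = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denom if denom else 0.0)
    return accuracy, sum(f1_scores) / len(f1_scores)


y_true = [-100, 0, 0, 2, 1, 1, -100]  # -100 positions are ignored
y_pred = [2, 0, 2, 2, 1, 1, 0]
print(masked_accuracy_and_macro_f1(y_true, y_pred))  # accuracy 0.8, macro-F1 7/9 ≈ 0.778
```

Macro averaging weights `(`, `)`, and `.` equally, so it is not dominated by whichever label is most frequent.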

### Step 6: Define and Initialize the Trainer

In [7]:
# Initialize the MetricVisualizer for logging the metrics
mv = MetricVisualizer(name="OmniGenome-52M-SSP")

for seed in seeds:
    optimizer = torch.optim.AdamW(
        ssp_model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )
    trainer = Trainer(
        model=ssp_model,
        train_loader=train_loader,
        eval_loader=valid_loader,
        test_loader=test_loader,
        batch_size=batch_size,
        epochs=epochs,
        optimizer=optimizer,
        compute_metrics=compute_metrics,
        seeds=seed,
        device=autocuda.auto_cuda(),
    )

    metrics = trainer.train()
    test_metrics = metrics["test"][-1]
    mv.log(model_name_or_path.split("/")[-1], "F1", test_metrics["f1_score"])
    mv.log(
        model_name_or_path.split("/")[-1],
        "Accuracy",
        test_metrics["accuracy_score"],
    )
    print(metrics)
    mv.summary()

Evaluating: 100%|██████████| 36/36 [00:03<00:00,  9.77it/s]


[2024-06-03 19:04:33] (0.0.6alpha) {'accuracy_score': 0.2781833337543257, 'f1_score': 0.27604302848700996}


Epoch 1/1 Loss: 0.6855: 100%|██████████| 285/285 [01:14<00:00,  3.83it/s]
Evaluating: 100%|██████████| 36/36 [00:03<00:00, 10.57it/s]


[2024-06-03 19:05:51] (0.0.6alpha) {'accuracy_score': 0.9250044204198136, 'f1_score': 0.9265131989308374}


Testing: 100%|██████████| 36/36 [00:03<00:00, 10.60it/s]

[2024-06-03 19:05:55] (0.0.6alpha) {'accuracy_score': 0.9209344839637605, 'f1_score': 0.9225768824097732}
{'valid': [{'accuracy_score': 0.2781833337543257, 'f1_score': 0.27604302848700996}, {'accuracy_score': 0.9250044204198136, 'f1_score': 0.9265131989308374}], 'best_valid': {'accuracy_score': 0.9250044204198136, 'f1_score': 0.9265131989308374}, 'test': [{'accuracy_score': 0.9209344839637605, 'f1_score': 0.9225768824097732}]}

---------------------------------------------- Raw Metric Records ----------------------------------------------
╒══════════╤════════════════╤══════════════════════╤═══════════╤══════════╤═══════╤═══════╤══════════╤══════════╕
│ Metric   │ Trial          │ Values               │  Average  │  Median  │  Std  │  IQR  │   Min    │   Max    │
╞══════════╪════════════════╪══════════════════════╪═══════════╪══════════╪═══════╪═══════╪══════════╪══════════╡
│ F1       │ OmniGenome-52M │ [0.9225768824097732] │ 0.922577  │ 0.922577 │   0   │   0   │ 0.922577 │ 0.922577 │
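With more than one entry in `seeds`, `MetricVisualizer` aggregates per-seed scores into the columns shown in the table. A rough standard-library equivalent is sketched below; it assumes the sample standard deviation and the default method of `statistics.quantiles`, either of which may differ from the choices MetricVisualizer makes.

```python
import statistics


def summarize(values):
    """Average, median, std, IQR, min and max of per-seed scores."""
    if not values:
        raise ValueError("need at least one score")
    if len(values) > 1:
        q1, _, q3 = statistics.quantiles(values, n=4)  # default 'exclusive' method
        std = statistics.stdev(values)  # sample standard deviation
    else:
        q1 = q3 = values[0]  # a single run has no spread
        std = 0.0
    return {
        "average": statistics.mean(values),
        "median": statistics.median(values),
        "std": std,
        "iqr": q3 - q1,
        "min": min(values),
        "max": max(values),
    }


# The single F1 score from the table above: the spread statistics are all zero
print(summarize([0.9225768824097732]))
```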




### Step 7: Experimental Results Visualization
The experimental results are visualized in plots showing the model's F1 score and accuracy on the test set for each run, along with the average F1 score and accuracy.

### Step 8: Model Checkpoint for Sharing
The fine-tuned model can be saved as a checkpoint and shared with others for further use. The following code saves the checkpoint, loads it back, and runs inference on an example sequence:

**Regular checkpointing and resuming are good practices to save the model at different stages of training.**

In [8]:
path_to_save = "OmniGenome-52M-SSP"
ssp_model.save(path_to_save, overwrite=True)

# Load the model checkpoint
ssp_model = ssp_model.load(path_to_save)
results = ssp_model.inference("CAGUGCCGAGGCCACGCGGAGAACGAUCGAGGGUACAGCACUA")
print(results["predictions"])
print("logits:", results["logits"])

[2024-06-03 19:05:56] (0.0.6alpha) The model is saved to OmniGenome-52M-SSP.
['.', '(', '(', '(', '(', '(', '.', '.', '.', '.', '(', '(', '(', '.', '(', '.', '.', '(', '.', '.', '.', '.', '.', '.', '.', '.', ')', ')', ')', '.', '.', ')', ')', ')', '.', '.', '.', ')', ')', ')', ')', ')', '.']
logits: tensor([[1.2754e-02, 8.4659e-03, 9.7878e-01],
        [8.2270e-04, 1.9768e-04, 9.9898e-01],
        [9.9559e-01, 4.2518e-04, 3.9845e-03],
        [9.9918e-01, 3.8828e-04, 4.2924e-04],
        [9.9916e-01, 3.7811e-04, 4.5945e-04],
        [9.9916e-01, 3.3816e-04, 4.9766e-04],
        [9.9905e-01, 3.7165e-04, 5.8111e-04],
        [2.4212e-03, 3.5749e-04, 9.9722e-01],
        [8.2577e-04, 2.3750e-04, 9.9894e-01],
        [2.7535e-04, 1.8864e-04, 9.9954e-01],
        [6.5934e-04, 2.1868e-04, 9.9912e-01],
        [9.9584e-01, 7.8992e-04, 3.3739e-03],
        [9.9885e-01, 3.4025e-04, 8.1102e-04],
        [9.9390e-01, 3.9648e-04, 5.7052e-03],
        [1.2942e-03, 4.4689e-04, 9.9826e-01],
        [
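Comparing the printed values with `predictions` suggests two things: each row of the `logits` tensor sums to about 1, so the values behave like softmax probabilities, and the first row appears to belong to a leading special token, because `predictions` lines up with the rows starting from the second one. Both are observations from this output rather than documented behavior. The sketch below reproduces the first predicted labels by taking the argmax of each row after skipping the first:

```python
id2label = {0: "(", 1: ")", 2: "."}

# First rows of the printed "logits" tensor above (each row sums to about 1)
rows = [
    [1.2754e-02, 8.4659e-03, 9.7878e-01],
    [8.2270e-04, 1.9768e-04, 9.9898e-01],
    [9.9559e-01, 4.2518e-04, 3.9845e-03],
    [9.9918e-01, 3.8828e-04, 4.2924e-04],
    [9.9916e-01, 3.7811e-04, 4.5945e-04],
    [9.9916e-01, 3.3816e-04, 4.9766e-04],
    [9.9905e-01, 3.7165e-04, 5.8111e-04],
    [2.4212e-03, 3.5749e-04, 9.9722e-01],
]


def decode(rows, skip_first=1):
    """Argmax each row to a label, dropping leading special-token rows."""
    labels = []
    for row in rows[skip_first:]:
        best = max(range(len(row)), key=row.__getitem__)  # index of the largest value
        labels.append(id2label[best])
    return labels


print(decode(rows))  # ['.', '(', '(', '(', '(', '(', '.']
```

These seven labels match the first seven entries of `predictions` printed above.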




### What If We Don't Know How to Initialize the Model?

In [9]:
# We can load the model checkpoint using the ModelHub
from omnigenome import ModelHub

ssp_model = ModelHub.load("OmniGenome-52M-SSP")
results = ssp_model.inference("CAGUGCCGAGGCCACGCGGAGAACGAUCGAGGGUACAGCACUA")
print(results["predictions"])
print("logits:", results["logits"])

[2024-06-03 19:05:56] (0.0.6alpha) Model Name: OmniGenomeModelForTokenClassification
Model Metadata: {'library_name': 'OmniGenome', 'omnigenome_version': '0.0.6alpha', 'torch_version': '2.1.2+cu12.1+gita8e7c98cb95ff97bb30a728c6b2a1ce6bff946eb', 'transformers_version': '4.41.0.dev0', 'model_cls': 'OmniGenomeModelForTokenClassification', 'tokenizer_cls': 'OmniSingleNucleotideTokenizer', 'model_name': 'OmniGenomeModelForTokenClassification'}
Base Model Name: OmniGenome-52M-SSP
Model Type: omnigenome
Model Architecture: ['OmniGenomeModel']
Model Parameters: 52.453345 M
Model Config: OmniGenomeConfig {
  "OmniGenomefold_config": null,
  "_name_or_path": "OmniGenome-52M-SSP",
  "architectures": [
    "OmniGenomeModel"
  ],
  "attention_probs_dropout_prob": 0.0,
  "classifier_dropout": null,
  "emb_layer_norm_before": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0,
  "hidden_size": 480,
  "id2label": {
    "0": "(",
    "1": ")",
    "2": "."
  },
  "initializer_range": 0.02,
  "in

### Model Inference

In [10]:
examples = [
    "GCUGGGAUGUUGGCUUAGAAGCAGCCAUCAUUUAAAGAGUGCGUAACAGCUCACCAGC",
    "AUCUGUACUAGUUAGCUAACUAGAUCUGUAUCUGGCGGUUCCGUGGAAGAACUGACGUGUUCAUAUUCCCGACCGCAGCCCUGGGAGACGUCUCAGAGGC",
]

results = ssp_model.inference(examples)
structures = ["".join(prediction) for prediction in results["predictions"]]
print(results)
print(structures)

{'predictions': [['(', '(', '(', '(', '(', '.', '.', '.', '(', '(', '(', '(', '(', '(', '.', '.', '.', '.', '.', '.', '.', '.', ')', ')', ')', ')', ')', '.', '.', ')', ')', '.', '.', '.', '.', '.', '(', '(', '(', '(', '.', '.', '.', '.', '.', '.', '.', '.', ')', ')', ')', ')', '.', ')', ')', ')', ')', ')'], ['.', '(', '.', '.', '(', '.', '.', '(', '(', '(', '(', '(', '.', '.', '.', '.', '.', '.', ')', ')', ')', ')', ')', '.', '.', '.', '.', '.', '.', '.', '.', '(', '(', '.', '(', '(', '(', '(', '.', '.', '.', '(', '(', '.', '(', '(', '(', '.', '(', '(', '(', '(', '.', '.', '.', '.', '.', '.', ')', ')', ')', ')', '.', '.', '.', ')', '.', '.', ')', ')', ')', '.', '.', '.', ')', ')', '.', ')', '.', '.', '(', '(', '(', ')', '(', '(', '.', '.', '.', '.', ')', ')', ')', ')', ')', ')', '.', '.', '.', '.']], 'logits': tensor([[[1.1020e-02, 7.6820e-03, 9.8130e-01],
         [9.9856e-01, 4.4648e-04, 9.9442e-04],
         [9.9907e-01, 4.1033e-04, 5.2263e-04],
         [9.9866e-01, 4.9436e-04, 8.4
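Because the model predicts a label for each position separately, nothing forces a predicted `(`/`)` pair to be chemically plausible. A simple post-hoc check, sketched below as an illustration rather than an OmniGenome feature, flags predicted pairs that are neither Watson-Crick (A-U, G-C) nor G-U wobble pairs.

```python
CANONICAL = {("A", "U"), ("U", "A"), ("G", "C"), ("C", "G"), ("G", "U"), ("U", "G")}


def non_canonical_pairs(sequence, structure):
    """Return (i, j, pair) for predicted base pairs that are not Watson-Crick or G-U wobble."""
    if len(sequence) != len(structure):
        raise ValueError("sequence and structure must have the same length")
    stack, flagged = [], []
    for j, char in enumerate(structure):
        if char == "(":
            stack.append(j)
        elif char == ")" and stack:  # an unmatched ')' is simply skipped here
            i = stack.pop()
            if (sequence[i], sequence[j]) not in CANONICAL:
                flagged.append((i, j, sequence[i] + sequence[j]))
    return flagged


print(non_canonical_pairs("GGGAAACCC", "(((...)))"))  # []
print(non_canonical_pairs("GGAAAAACC", "(((...)))"))  # [(2, 6, 'AA')]
```

In the notebook, the same check can be applied to `examples[0]` and `structures[0]` from the cell above.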

### Step 9: Pipeline Creation
The OmniGenome package provides pipelines for genomic FM development. A pipeline bundles a model, its tokenizer, the datasets, and a trainer, and can be used to train a genomic FM on a dataset, fine-tune a pretrained genomic FM on a new dataset, evaluate a genomic FM's performance, and generate predictions. Training a genomic FM on a dataset takes a single command.

In [11]:
from omnigenome import Pipeline, PipelineHub

pipeline = Pipeline(
    name="OmniGenome-52M-SSP-Pipeline",
    # model_name_or_path="OmniGenome-52M-SSP",  # The model name or path can be specified
    # tokenizer="OmniGenome-52M-SSP",  # The tokenizer can be specified
    model_name_or_path=ssp_model,
    tokenizer=ssp_model.tokenizer,
    datasets={
        "train": "toy_datasets/train.json",
        "test": "toy_datasets/test.json",
        "valid": "toy_datasets/valid.json",
    },
    trainer=trainer,
    device=ssp_model.model.device,
)

### Using the Pipeline

In [None]:
results = pipeline(examples[0])
print(results)

pipeline.train()

pipeline.save("OmniGenome-52M-SSP-Pipeline", overwrite=True)

pipeline = PipelineHub.load("OmniGenome-52M-SSP-Pipeline")
results = pipeline(examples)
print(results)

{'predictions': ['(', '(', '(', '(', '(', '.', '.', '.', '(', '(', '(', '(', '(', '(', '.', '.', '.', '.', '.', '.', '.', '.', ')', ')', ')', ')', ')', '.', '.', ')', ')', '.', '.', '.', '.', '.', '(', '(', '(', '(', '.', '.', '.', '.', '.', '.', '.', '.', ')', ')', ')', ')', '.', ')', ')', ')', ')', ')'], 'logits': tensor([[1.1020e-02, 7.6820e-03, 9.8130e-01],
        [9.9856e-01, 4.4648e-04, 9.9442e-04],
        [9.9907e-01, 4.1033e-04, 5.2263e-04],
        [9.9866e-01, 4.9436e-04, 8.4987e-04],
        [9.9889e-01, 3.5461e-04, 7.5468e-04],
        [9.9705e-01, 1.0346e-03, 1.9177e-03],
        [2.2339e-02, 4.6507e-04, 9.7720e-01],
        [7.1744e-02, 1.7932e-03, 9.2646e-01],
        [2.3055e-02, 1.0602e-03, 9.7588e-01],
        [9.7184e-01, 6.7886e-03, 2.1375e-02],
        [9.8951e-01, 3.5503e-03, 6.9399e-03],
        [9.9309e-01, 1.5381e-03, 5.3732e-03],
        [9.9490e-01, 1.7584e-03, 3.3442e-03],
        [9.9707e-01, 1.8680e-03, 1.0640e-03],
        [9.8953e-01, 2.1098e-03, 8.364

Evaluating: 100%|██████████| 36/36 [00:03<00:00, 10.61it/s]


[2024-06-03 19:06:00] (0.0.6alpha) {'accuracy_score': 0.9250549395033975, 'f1_score': 0.9265583277708882}


Epoch 1/1 Loss: 0.6112: 100%|██████████| 285/285 [01:14<00:00,  3.82it/s]
Evaluating: 100%|██████████| 36/36 [00:03<00:00, 10.52it/s]


[2024-06-03 19:07:18] (0.0.6alpha) {'accuracy_score': 0.9386950920710299, 'f1_score': 0.9400390719527772}


Testing: 100%|██████████| 36/36 [00:03<00:00, 10.57it/s]


[2024-06-03 19:07:22] (0.0.6alpha) {'accuracy_score': 0.9339462859908058, 'f1_score': 0.9353768554174792}
[2024-06-03 19:07:25] (0.0.6alpha) The model is saved to OmniGenome-52M-SSP-Pipeline.


## Web Demo for RNA Secondary Structure Prediction

In [2]:
import numpy as np
import json
import gradio as gr
import RNA


def ss_validity_loss(rna_strct):
    dotCount = 0
    leftCount = 0
    rightCount = 0
    unmatched_positions = []  # records the positions of unmatched brackets
    uncoherentCount = 0
    prev_char = ""
    for i, char in enumerate(rna_strct):
        if prev_char != char:
            uncoherentCount += 1
        prev_char = char

        if char == "(":
            leftCount += 1
            unmatched_positions.append(i)  # record the position of the left bracket
        elif char == ")":
            if leftCount > 0:
                leftCount -= 1
                unmatched_positions.pop()  # remove the most recent left-bracket position
            else:
                rightCount += 1
                unmatched_positions.append(i)  # record the position of the unmatched right bracket
        elif char == ".":
            dotCount += 1
        else:
            raise ValueError(f"Invalid character {char} in RNA structure")
    match_loss = (leftCount + rightCount) / (len(rna_strct) - dotCount + 1e-5)
    return match_loss


def find_invalid_ss_positions(rna_strct):
    left_brackets = []  # positions of left brackets
    right_brackets = []  # positions of unmatched right brackets
    for i, char in enumerate(rna_strct):
        if char == "(":
            left_brackets.append(i)
        elif char == ")":
            if left_brackets:
                left_brackets.pop()  # found a matching left bracket; remove it from the list
            else:
                right_brackets.append(i)  # no matching left bracket; record the right bracket's position
    return left_brackets + right_brackets


def fold(rna_sequence):
    ref_struct = RNA.fold(rna_sequence)[0]
    RNA.svg_rna_plot(rna_sequence, ref_struct, f"real_structure.svg")

    pred_structure = "".join(ssp_model.inference(rna_sequence)["predictions"])
    print(pred_structure)
    if ss_validity_loss(pred_structure) == 0:
        RNA.svg_rna_plot(rna_sequence, pred_structure, f"predicted_structure.svg")
        return (
            ref_struct,
            pred_structure,
            "real_structure.svg",
            "predicted_structure.svg",
        )
    else:
        # return blank image of predicted structure
        # generate a blank svg image
        with open("predicted_structure.svg", "w") as f:
            f.write(
                '<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100"></svg>'
            )
        return (
            ref_struct,
            pred_structure,
            "real_structure.svg",
            "predicted_structure.svg",
        )


def repair_rna_structure(rna_sequence, invalid_struct):
    try:
        invalid_ss_positions = find_invalid_ss_positions(invalid_struct)
        for pos_idx in invalid_ss_positions:
            # Both unmatched "(" and unmatched ")" are replaced with an unpaired "."
            invalid_struct = (
                invalid_struct[:pos_idx] + "." + invalid_struct[pos_idx + 1 :]
            )

        best_pred_struct = invalid_struct
        RNA.svg_rna_plot(rna_sequence, best_pred_struct, f"best_pred_struct.svg")
        return best_pred_struct, "best_pred_struct.svg"
    except Exception as e:
        with open("best_pred_struct.svg", "w") as f:
            f.write(
                '<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100"></svg>'
            )
        return str(e), "best_pred_struct.svg"  # show the error message in the textbox


def sample_rna_sequence():
    example = examples[np.random.randint(0, len(examples))]
    RNA.svg_rna_plot(example["seq"], example["label"], f"annotated_structure.svg")

    return example["seq"], example["label"], "annotated_structure.svg"


# Define the interface
with gr.Blocks() as demo:
    gr.Markdown("### RNA Secondary Structure Prediction")

    with gr.Row():
        with gr.Row():
            rna_input = gr.Textbox(
                label="RNA Sequence", placeholder="Enter RNA sequence here..."
            )
        with gr.Row():
            strcut_input = gr.Textbox(
                label="Annotated Secondary Structure",
                placeholder="Enter RNA secondary structure here...",
            )

    with gr.Row():
        #     examples = [
        #     ["GCGUCACACCGGUGAAGUCGCGCGUCACACCGGUGAAGUCGC"],
        #     ["GCUGGGAUGUUGGCUUAGAAGCAGCCAUCAUUUAAAGAGUGCGUAACAGCUCACCAGCGCUGGGAUGUUGGCUUAGAAGCAGCCAUCAUUUAAAGAGUGCGUAACAGCUCACCAGC"],
        #     ["GGCUGGUCCGAGUGCAGUGGUGUUUACAACUAAUUGAUCACAACCAGUUACAGAUUUCUUUGUUCCUUCUCCACUCCCACUGCUUCACUUGACUAGCCUU"],
        # ]
        #     gr.Examples(examples=examples, label="Examples", inputs=[rna_input])
        with open("toy_datasets/test.json", "r") as f:
            examples = []
            for line in f:
                examples.append(json.loads(line))

        sample_button = gr.Button("Sample an RNA Sequence from the RNAStrand2 Test Set")

    with gr.Row():
        submit_button = gr.Button("Run Prediction")

    with gr.Row():
        ref_structure_output = gr.Textbox(
            label="Secondary Structure by ViennaRNA", interactive=False
        )

    with gr.Row():
        pred_structure_output = gr.Textbox(
            label="Secondary Structure by Model", interactive=False
        )

    with gr.Row():
        anno_structure_output = gr.Image(
            label="Annotated Secondary Structure", show_share_button=True
        )
        real_image = gr.Image(
            label="Secondary Structure by ViennaRNA", show_share_button=True
        )
        predicted_image = gr.Image(
            label="Secondary Structure by Model", show_share_button=True
        )

    with gr.Row():
        repair_button = gr.Button("Run Prediction Repair")

    submit_button.click(
        fn=fold,
        inputs=rna_input,
        outputs=[
            ref_structure_output,
            pred_structure_output,
            real_image,
            predicted_image,
        ],
    )

    repair_button.click(
        fn=repair_rna_structure,
        inputs=[rna_input, pred_structure_output],
        outputs=[pred_structure_output, predicted_image],
    )

    sample_button.click(
        fn=sample_rna_sequence, outputs=[rna_input, strcut_input, anno_structure_output]
    )
demo.launch()

Running on local URL:  http://127.0.0.1:7861

To create a public link, set `share=True` in `launch()`.
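The `ss_validity_loss` function used by the demo computes the fraction of bracket characters left unmatched after a left-to-right scan:

$$
\mathrm{match\_loss} = \frac{n_{(} + n_{)}}{L - n_{.} + 10^{-5}}
$$

where $L$ is the structure length, $n_{.}$ is the number of dots, $n_{(}$ counts left brackets still open at the end of the scan, and $n_{)}$ counts right brackets that found no partner. For example, `((.)` has $L = 4$, $n_{.} = 1$, $n_{(} = 1$, and $n_{)} = 0$, so the loss is $1 / (3 + 10^{-5}) \approx 0.333$. A well-formed structure scores exactly 0, which is the condition `fold` checks before plotting the predicted structure; the $10^{-5}$ term only guards against division by zero for an all-dot structure.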




## Conclusion
In this demonstration, we have shown how to fine-tune a genomic foundation model for RNA secondary structure prediction using the OmniGenome package. We have also shown how to use the trained model for inference and how to create a web demo for RNA secondary structure prediction. We hope this demonstration will help you get started with genomic foundation model development using OmniGenome.