# 🧬 RNA SSP Tutorial 4/4: Model Inference - Predicting on New Sequences

Welcome to the final tutorial in our series! We've come a long way:
1. Prepared RNA structure data ([01_data_preparation.ipynb](01_data_preparation.ipynb))
2. Initialized the OmniGenome model ([02_model_initialization.ipynb](02_model_initialization.ipynb))
3. Trained the model on the bpRNA dataset ([03_model_training.ipynb](03_model_training.ipynb))

Now it's time to put our trained model to work! This tutorial covers:
1. **Loading Trained Models**: How to load saved models from disk or from the hub
2. **Making Predictions**: Running inference on new RNA sequences
3. **Structure Validation**: Ensuring predictions are biologically valid
4. **Results Interpretation**: Understanding and visualizing output
5. **Batch Processing**: Efficiently predicting multiple sequences

By the end, you'll be able to predict RNA secondary structures for any sequence!

### 1. Understanding Inference

**Inference** (also called prediction or evaluation) is the process of using a trained model to make predictions on new, unseen data. Unlike training:
- No gradient computation (faster, less memory)
- No weight updates (model is frozen)
- Focus on prediction quality and speed

```mermaid
graph LR
    A["New RNA Sequence<br/>AUGCCGUGC"] --> B["Tokenization<br/>[CLS] A U G C... [SEP]"]
    B --> C["Trained Model<br/>OmniGenome-52M"]
    C --> D["Logits<br/>Per-token scores"]
    D --> E["Predictions<br/>.(((...))). "]
    E --> F["Validation<br/>Check brackets balance"]
    
    style A fill:#e1f5fe
    style C fill:#f3e5f5
    style E fill:#e8f5e8
    style F fill:#fff3e0
```
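
The last stages of the diagram, turning per-token scores into a dot-bracket string, can be sketched in plain Python. This is only an illustration: the score values are made up, and the sketch assumes special tokens such as `[CLS]` and `[SEP]` have already been stripped. Only the label mapping comes from this tutorial.

```python
# Label mapping used throughout this tutorial
label2id = {"(": 0, ")": 1, ".": 2}
id2label = {v: k for k, v in label2id.items()}


def decode_logits(logits):
    """Pick the highest-scoring class per position and map it to a symbol."""
    symbols = []
    for scores in logits:
        best_id = max(range(len(scores)), key=lambda i: scores[i])
        symbols.append(id2label[best_id])
    return "".join(symbols)


# Hypothetical scores for a 5-nucleotide sequence: one row per position,
# one column per class in the order "(", ")", "."
toy_logits = [
    [2.1, 0.3, 0.5],
    [0.2, 0.1, 1.9],
    [0.0, 0.4, 2.2],
    [0.3, 0.2, 1.5],
    [0.1, 2.4, 0.6],
]

decoded = decode_logits(toy_logits)
print(decoded)
```

Taking the per-position argmax treats every nucleotide independently, which is why the validation stage exists: nothing in this step guarantees that the brackets balance.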

## 🛠️ Practical Implementation

Let's run inference on new RNA sequences!

### Step 1: Environment Setup

In [None]:
# Install if needed
# !pip install omnigenbench -U

In [None]:
import torch
from omnigenbench import ModelHub

print("✅ Libraries imported successfully!")
print(f"🔥 PyTorch version: {torch.__version__}")
print(f"🎯 CUDA available: {torch.cuda.is_available()}")

### Step 2: Configuration

In [None]:
# Model path (local or hub)
# Option 1: Use locally trained model
model_path = "ogb_rna_structure_finetuned"

# Option 2: Use pre-trained model from hub (if available)
# model_path = "yangheng/ogb_rna_structure_finetuned"

# Label mapping (must match training)
label2id = {"(": 0, ")": 1, ".": 2}
id2label = {v: k for k, v in label2id.items()}

print("✅ Configuration complete!")
print(f"📊 Model path: {model_path}")
print(f"📊 Label mapping: {label2id}")

### Step 3: Load Trained Model

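OmniGenBench's `ModelHub` loads a saved model in a single call. The same `ModelHub.load` call takes either a local directory or a hub identifier, so switching between the two options above only means changing `model_path`.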

In [None]:
print("🔄 Loading trained model...")

# Load model using ModelHub
model = ModelHub.load(model_path)

# Set to evaluation mode
model.eval()

print("✅ Model loaded successfully!")
print(f"📊 Model device: {model.device}")
print(f"📊 Model type: {type(model).__name__}")

### Step 4: Prepare Sample Sequences

Let's create a small, varied set of RNA sequences to test our model:
- A sequence labeled as a simple hairpin
- A sequence labeled as a more complex structure
- A short sequence
- Low-complexity GC-rich and AU-rich repeats

In [None]:
# Sample RNA sequences with varying complexity
sample_sequences = {
    "Simple Hairpin": "GCUGGGAUGUUGGCUUAGAAGCAGCCAUCAUUUAAAGAGUGCGUAACAGCUCACCAGC",
    "Complex Structure": "AUCUGUACUAGUUAGCUAACUAGAUCUGUAUCUGGCGGUUCCGUGGAAGAACUGACGU",
    "Short Sequence": "AUGCCGUGCAUUAA",
    "GC-Rich": "GCGCGCGCGCGCGCGCGCGCGCGCGCGCGC",
    "AU-Rich": "AUAUAUAUAUAUAUAUAUAUAUAUAUAUAU",
}

print(f"📝 Prepared {len(sample_sequences)} test sequences:")
for name, seq in sample_sequences.items():
    print(f"  - {name}: {len(seq)} nucleotides")

### Step 5: Single Sequence Prediction

Let's start with predicting structure for a single sequence.

In [None]:
# Select a test sequence
test_name = "Simple Hairpin"
test_sequence = sample_sequences[test_name]

print(f"🔮 Predicting structure for: {test_name}")
print(f"📝 Sequence: {test_sequence}")
print(f"📏 Length: {len(test_sequence)} nucleotides\n")

# Run inference
with torch.no_grad():
    outputs = model.inference(test_sequence)

# Extract predictions
predictions = outputs.get('predictions', None)

if predictions is not None:
    # Convert predictions to structure notation
    predicted_structure = "".join([id2label[pred] for pred in predictions])
    
    print("📊 Prediction Results:")
    print("=" * 70)
    print(f"Sequence:  {test_sequence}")
    print(f"Structure: {predicted_structure}")
    print("=" * 70)
    
    # Count structure elements
    num_left = predicted_structure.count('(')
    num_right = predicted_structure.count(')')
    num_unpaired = predicted_structure.count('.')
    
    print("\n📈 Structure Statistics:")
    print(f"  - Opening brackets '(': {num_left}")
    print(f"  - Closing brackets ')': {num_right}")
    print(f"  - Unpaired '.': {num_unpaired}")
    print(f"  - Balanced: {'✅ Yes' if num_left == num_right else '❌ No'}")
else:
    print("❌ No predictions found in output")

### Step 6: Batch Prediction

Run the model over every sample sequence in turn and collect the predictions, along with a few summary statistics, in a single dictionary.

In [None]:
print("🔮 Running batch predictions...\n")

# Store results
results = {}

# Predict for all sample sequences
with torch.no_grad():
    for name, sequence in sample_sequences.items():
        outputs = model.inference(sequence)
        predictions = outputs.get('predictions', None)
        
        if predictions is not None:
            structure = "".join([id2label[pred] for pred in predictions])
            results[name] = {
                'sequence': sequence,
                'structure': structure,
                'length': len(sequence),
                'num_paired': structure.count('(') + structure.count(')'),
                'num_unpaired': structure.count('.'),
                'balanced': structure.count('(') == structure.count(')')
            }

# Display results
print("📊 Batch Prediction Results:")
print("=" * 80)
for name, result in results.items():
    print(f"\n📌 {name}:")
    print(f"  Sequence:  {result['sequence']}")
    print(f"  Structure: {result['structure']}")
    print(f"  Length: {result['length']} nt | Paired: {result['num_paired']} | Unpaired: {result['num_unpaired']} | Balanced: {'✅' if result['balanced'] else '❌'}")
print("=" * 80)

### Step 7: Structure Validation

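Check that each predicted structure is well-formed dot-bracket notation: it contains only valid characters, its brackets are balanced, and they are properly nested. This is a necessary condition for a biologically plausible structure, not a sufficient one.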

In [None]:
def validate_structure(structure):
    """
    Validate RNA secondary structure.
    
    Checks:
    1. Brackets are balanced
    2. Brackets are properly nested (no pseudoknots in simple notation)
    3. Contains only valid characters
    """
    # Check valid characters
    valid_chars = set('().')
    if not set(structure).issubset(valid_chars):
        return False, "Invalid characters found"
    
    # Check balance
    if structure.count('(') != structure.count(')'):
        return False, "Unbalanced brackets"
    
    # Check proper nesting
    stack = []
    for char in structure:
        if char == '(':
            stack.append(char)
        elif char == ')':
            if not stack:
                return False, "Closing bracket without opening"
            stack.pop()
    
    if stack:
        return False, "Unclosed opening brackets"
    
    return True, "Valid structure"

# Validate all predictions
print("✅ Structure Validation Results:\n")
for name, result in results.items():
    is_valid, message = validate_structure(result['structure'])
    status = "✅" if is_valid else "❌"
    print(f"{status} {name}: {message}")
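
The bracket checks above say nothing about which bases are paired. As a possible next step, the sketch below recovers the paired positions from a dot-bracket string and flags pairs that are neither Watson-Crick nor G-U wobble. The helper names `extract_pairs` and `non_canonical_pairs` are hypothetical and not part of OmniGenBench, and the example sequence is made up.

```python
def extract_pairs(structure):
    """Return (i, j) index pairs for matching brackets in a dot-bracket string."""
    stack, pairs = [], []
    for pos, char in enumerate(structure):
        if char == "(":
            stack.append(pos)
        elif char == ")":
            if not stack:
                raise ValueError(f"unmatched ')' at position {pos}")
            pairs.append((stack.pop(), pos))
    if stack:
        raise ValueError(f"unmatched '(' at position {stack[-1]}")
    return sorted(pairs)


# Watson-Crick pairs plus the G-U wobble pair
CANONICAL = {("A", "U"), ("U", "A"), ("G", "C"), ("C", "G"), ("G", "U"), ("U", "G")}


def non_canonical_pairs(sequence, structure):
    """List paired positions whose bases cannot form a canonical pair."""
    return [
        (i, j) for i, j in extract_pairs(structure)
        if (sequence[i], sequence[j]) not in CANONICAL
    ]


toy_sequence = "GGGAAAUCC"
toy_structure = "(((...)))"
print(extract_pairs(toy_structure))
print(non_canonical_pairs(toy_sequence, toy_structure))
```

A non-empty result does not prove a prediction wrong, since non-canonical pairs do occur in real RNA, but a large number of them is a reasonable signal to inspect the prediction more closely.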

### Step 8: Simple Visualization

Create a simple text-based visualization of the structure.

In [None]:
def visualize_structure(sequence, structure, name=""):
    """
    Create a simple text visualization of RNA structure.
    """
    print(f"\n{'='*70}")
    if name:
        print(f"📊 {name}")
    print(f"{'='*70}")
    
    # Print in chunks for readability
    chunk_size = 60
    for i in range(0, len(sequence), chunk_size):
        seq_chunk = sequence[i:i+chunk_size]
        struct_chunk = structure[i:i+chunk_size]
        
        print(f"\nPosition {i+1}-{i+len(seq_chunk)}:")
        print(f"  5' {seq_chunk} 3'")
        print(f"     {struct_chunk}")
    
    print(f"\n{'='*70}")

# Visualize one example
example_name = "Simple Hairpin"
if example_name in results:
    result = results[example_name]
    visualize_structure(result['sequence'], result['structure'], example_name)

### Step 9: Interactive Prediction

Try predicting structure for your own custom sequence!

In [None]:
# Define your custom RNA sequence here
custom_sequence = "GGGGCCCAUUUUGGGCC"  # Replace with your sequence

# Normalize case so validation and inference see the same string
custom_sequence = custom_sequence.upper()

print("🔮 Predicting structure for custom sequence...\n")
print(f"Input: {custom_sequence}")
print(f"Length: {len(custom_sequence)} nucleotides\n")

# Validate input
valid_bases = set('AUGC')
if not set(custom_sequence).issubset(valid_bases):
    print("❌ Error: Sequence contains invalid characters. Use only A, U, G, C")
else:
    # Run prediction
    with torch.no_grad():
        outputs = model.inference(custom_sequence)
    
    predictions = outputs.get('predictions', None)
    if predictions is not None:
        structure = "".join([id2label[pred] for pred in predictions])
        
        # Validate and display
        is_valid, message = validate_structure(structure)
        
        visualize_structure(custom_sequence, structure, "Custom Sequence")
        
        print(f"\nValidation: {message} {'✅' if is_valid else '❌'}")
    else:
        print("❌ Prediction failed")
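
Sequences copied from databases often arrive in lowercase, wrapped across lines, or written with the DNA alphabet. A small cleanup step before the validity check might look like the sketch below. `normalize_rna` is a hypothetical helper, and mapping `T` to `U` is an assumption about what the user intends rather than something the model requires.

```python
def normalize_rna(raw):
    """Uppercase, drop whitespace, map T to U, and reject anything else."""
    cleaned = "".join(raw.split()).upper().replace("T", "U")
    invalid = sorted(set(cleaned) - set("AUGC"))
    if invalid:
        raise ValueError(f"invalid characters: {invalid}")
    return cleaned


normalized = normalize_rna("ggg gcc\ncatttt")
print(normalized)
```

Raising an error instead of printing a message lets calling code decide how to handle a bad sequence, for example by skipping it in a loop over many inputs.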

### Step 10: Performance Analysis (Optional)

Measure inference speed and efficiency.

In [None]:
import time

# Performance test
test_sequence = sample_sequences["Simple Hairpin"]
num_iterations = 100

print(f"⚡ Performance Test ({num_iterations} iterations)...")

# Warm-up
with torch.no_grad():
    _ = model.inference(test_sequence)

# Time the inference (perf_counter is a monotonic, high-resolution clock)
start_time = time.perf_counter()
with torch.no_grad():
    for _ in range(num_iterations):
        _ = model.inference(test_sequence)
end_time = time.perf_counter()

# Calculate metrics
total_time = end_time - start_time
avg_time = total_time / num_iterations
throughput = num_iterations / total_time

print("\n📊 Performance Metrics:")
print(f"  - Total time: {total_time:.2f} seconds")
print(f"  - Average time per sequence: {avg_time*1000:.2f} ms")
print(f"  - Throughput: {throughput:.2f} sequences/second")
print(f"  - Sequence length: {len(test_sequence)} nucleotides")

## 💡 Inference Tips and Best Practices

### Improving Prediction Quality
1. **Ensemble Predictions**: Average predictions from multiple models
2. **Post-processing**: Apply structure constraints (e.g., minimum stem length)
3. **Confidence Thresholding**: Filter low-confidence predictions

### Improving Speed
1. **Batch Processing**: Process multiple sequences together
2. **GPU Utilization**: Ensure model is on GPU for faster inference
3. **Mixed Precision**: Use FP16 for faster computation
4. **ONNX Export**: Convert to ONNX for optimized inference
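
For the batch-processing tip, one common preparation step is to group sequences of similar length so that padded batches waste less computation. The sketch below shows only that grouping. `make_batches` is a hypothetical helper, and whether `model.inference` accepts a list of sequences is not confirmed in this tutorial, so check the OmniGenBench documentation before passing whole batches to it.

```python
def make_batches(sequences, batch_size):
    """Group (name, sequence) pairs into length-sorted batches."""
    ordered = sorted(sequences.items(), key=lambda item: len(item[1]))
    return [ordered[i:i + batch_size] for i in range(0, len(ordered), batch_size)]


toy_sequences = {
    "medium": "AUGCAUGC",
    "short": "AUGC",
    "long": "AUGCAUGCAUGCAUGC",
}

batches = make_batches(toy_sequences, batch_size=2)
for batch in batches:
    print([(name, len(seq)) for name, seq in batch])
```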

### Common Issues

| Issue | Solution |
|-------|----------|
| **Unbalanced brackets** | Apply post-processing to balance |
| **Slow inference** | Use batch processing or GPU |
| **Memory errors** | Reduce sequence length or batch size |
| **Inconsistent predictions** | Use ensemble or increase training data |
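
As one possible post-processing step for unbalanced brackets, the sketch below replaces every bracket that has no partner with `.`. It is a deliberately simple strategy, not the method used by OmniGenBench: it only ever removes pairs and never tries to guess where a missing partner should have been. `balance_structure` is a hypothetical helper name.

```python
def balance_structure(structure):
    """Replace every unmatched bracket with '.' so the result is well nested."""
    chars = list(structure)
    open_positions = []
    for pos, char in enumerate(chars):
        if char == "(":
            open_positions.append(pos)
        elif char == ")":
            if open_positions:
                open_positions.pop()
            else:
                chars[pos] = "."  # closing bracket with no partner
    for pos in open_positions:
        chars[pos] = "."  # opening bracket that was never closed
    return "".join(chars)


print(balance_structure("((..)"))
print(balance_structure(".)((..))("))
```

The output always has the same length as the input, so it stays aligned with the sequence.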

## 🚀 Advanced Topics

### A. Confidence Scores
```python
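import torch

# Get per-token class probabilities from the raw logits
# (assumes the inference output exposes a 'logits' entry)
logits = outputs['logits']
probs = torch.softmax(logits, dim=-1)
# Confidence = probability of the most likely class at each position
confidence = probs.max(dim=-1).values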
```
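
To make the idea of confidence thresholding concrete without depending on the model's output format, here is a plain-Python sketch that flags positions whose top class probability is low. The logit values are invented, the threshold of 0.6 is an arbitrary choice, and `low_confidence_positions` is a hypothetical helper.

```python
import math


def softmax(scores):
    """Numerically stable softmax over one position's class scores."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def low_confidence_positions(logits, threshold=0.6):
    """Return indices whose top class probability falls below the threshold."""
    return [
        pos for pos, scores in enumerate(logits)
        if max(softmax(scores)) < threshold
    ]


toy_logits = [
    [4.0, 0.0, 0.0],   # clearly "("
    [1.0, 1.0, 1.0],   # no preference at all
    [0.0, 0.2, 3.0],   # clearly "."
]
flagged = low_confidence_positions(toy_logits)
print(flagged)
```

Flagging positions, rather than overwriting them, avoids silently unbalancing the structure; what to do with the flagged positions is left to the downstream analysis.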

### B. Ensemble Predictions
```python
# Load multiple models and average predictions
models = [ModelHub.load(path) for path in model_paths]
predictions = [model.inference(seq) for model in models]
# Average or vote on predictions
```
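
One simple way to combine predictions is a per-position majority vote over the decoded structures. The sketch below assumes each model's output has already been converted to a dot-bracket string of the same length; `majority_vote` is a hypothetical helper and the three input strings are made up.

```python
from collections import Counter


def majority_vote(structures):
    """Combine equal-length dot-bracket strings by per-position majority vote."""
    if len({len(s) for s in structures}) != 1:
        raise ValueError("all structures must have the same length")
    voted = []
    for column in zip(*structures):
        # most_common(1) returns the symbol with the highest count; ties go
        # to the symbol that appeared first in the column
        voted.append(Counter(column).most_common(1)[0][0])
    return "".join(voted)


toy_predictions = ["((..))", "(....)", "((..))"]
consensus = majority_vote(toy_predictions)
print(consensus)
```

Voting position by position can produce a string whose brackets no longer balance, so the consensus should go through the same validation as any single prediction.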

### C. Structure Constraints
```python
# Apply biological constraints
def apply_constraints(structure):
    # Minimum stem length: 3 base pairs
    # Maximum loop size: 30 nucleotides
    # etc.
    pass
```
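
As a sketch of the first constraint only, the function below unpairs every stem shorter than a minimum length, where a stem is taken to be a run of directly stacked base pairs. It assumes the input is already balanced and properly nested, it does not implement the maximum loop size constraint, and `enforce_min_stem` is a hypothetical helper name.

```python
def enforce_min_stem(structure, min_stem=3):
    """Unpair every stem (run of directly stacked pairs) shorter than min_stem."""
    # Match brackets with a stack; assumes the structure is already balanced
    stack, partner = [], {}
    for pos, char in enumerate(structure):
        if char == "(":
            stack.append(pos)
        elif char == ")":
            partner[stack.pop()] = pos

    chars = list(structure)
    visited = set()
    for i in sorted(partner):
        if i in visited:
            continue
        # Walk inward while the next pair stacks directly on the current one
        stem = [(i, partner[i])]
        while partner.get(stem[-1][0] + 1) == stem[-1][1] - 1:
            nxt = stem[-1][0] + 1
            stem.append((nxt, partner[nxt]))
        visited.update(a for a, _ in stem)
        if len(stem) < min_stem:
            for a, b in stem:
                chars[a] = chars[b] = "."
    return "".join(chars)


print(enforce_min_stem("(((...)))..(...)"))
```

In the example, the three-pair stem is kept and the isolated pair at the end is removed. Because pairs are only ever removed together, the result stays balanced.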

## 📚 Summary and Congratulations!

🎉 **Congratulations!** You've completed the entire RNA Secondary Structure Prediction tutorial series!

### What We've Accomplished in This Tutorial
1. ✅ Loaded a trained model from disk
2. ✅ Made predictions on single sequences
3. ✅ Performed batch predictions efficiently
4. ✅ Validated structure predictions
5. ✅ Visualized results
6. ✅ Analyzed inference performance

### Complete Journey (4 Tutorials)
```python
# Tutorial 1: Data Preparation
datasets = OmniDatasetForTokenClassification.from_hub(...)

# Tutorial 2: Model Initialization
model = OmniModelForTokenClassification(...)

# Tutorial 3: Training
trainer = AccelerateTrainer(...)
trainer.train()

# Tutorial 4: Inference (this tutorial)
model = ModelHub.load("trained_model")
predictions = model.inference(sequence)
```

### Key Takeaways
- **Inference is fast**: No gradient computation needed
- **Validation is important**: Check structure validity
- **Batch processing**: More efficient for multiple sequences
- **Post-processing**: Can improve prediction quality

### Next Steps
1. **Apply to your data**: Use the model on your RNA sequences
2. **Experiment with parameters**: Try different configurations
3. **Explore advanced features**: Ensemble, constraints, etc.
4. **Share your results**: Contribute back to the community!

### Resources
- 📚 [OmniGenBench Documentation](../../docs/GETTING_STARTED.md)
- 🔬 [Other Examples](../../examples/)
- 💬 [GitHub Issues](https://github.com/yangheng95/OmniGenBench/issues)
- 📧 [Contact](mailto:hy345@exeter.ac.uk)

Thank you for following this tutorial series! We hope you found it helpful. Happy predicting! 🧬🚀