# mRNA Expression Optimization Tool

This notebook optimizes a protein-coding nucleic acid sequence for enhanced expression in a chosen host organism and for easier cloning and synthesis.

## 1. Setup and Initialization

In [None]:
# Basic imports
import os
import sys
import warnings
import re

warnings.filterwarnings('ignore')

# Import project modules
from src.optimization import genetic_algorithm, FitnessWeights
from src.pre_optimization import optimize_codons
from src.utils import (
    validate_sequence, translate_dna_to_protein, back_translate_protein,
    format_sequence_comparison, calculate_gc_content,
    calculate_cai, load_codon_table, HOST_TARGET_GC
)
from src.gceh_module import (
    gceh_anal,
    plot_codon_usage_compare,
    plot_gc3_compare,
    plot_gc3_sliding_compare,
    plot_gc_sliding_compare
)

# Import widgets
from ipywidgets import (
    Text, Textarea, Button, FloatSlider, IntSlider, Dropdown, 
    VBox, HBox, Label, HTML, Output, Layout, Tab, Accordion,
    FloatProgress, IntProgress, IntText, Checkbox
)
from IPython.display import display, clear_output

# Import plotting libraries
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

# Discover the available hosts: one option per codon-usage CSV file,
# named after the file without its extension.
codon_table_folder = 'data/codon_tables'
# Sorted so the host dropdown order is stable; os.listdir returns
# entries in arbitrary order.
host_options = sorted(
    os.path.splitext(f)[0]
    for f in os.listdir(codon_table_folder)
    if f.endswith('.csv')
)

## 2. Helper Functions

In [None]:
def validate_protein_sequence(sequence):
    '''Validate a protein sequence written in single-letter amino-acid code.

    All whitespace is removed and the input is upper-cased before checking.

    Returns:
        (cleaned_sequence, None) when every residue is one of the 20
        standard amino acids, otherwise (None, error_message).
    '''
    # Remove spaces and line breaks
    sequence = re.sub(r'\s+', '', sequence.upper())

    if not sequence:
        return None, 'Sequence is empty'

    # The 20 standard amino acids
    valid_aas = set('ACDEFGHIKLMNPQRSTVWY')

    # Check for invalid characters
    invalid_chars = set(sequence) - valid_aas
    if invalid_chars:
        # Joined outside the f-string: reusing the enclosing quote character
        # inside a replacement field is a SyntaxError before Python 3.12.
        # Sorted so the message does not depend on set iteration order.
        listed = ', '.join(sorted(invalid_chars))
        return None, f'Invalid amino acids: {listed}'

    return sequence, None


def clean_nucleotide_sequence(sequence):
    '''Clean and validate a DNA/RNA coding sequence.

    Processing steps, in order:
      1. Upper-case the input and strip whitespace and digits.
      2. Convert RNA to DNA by replacing U with T.
      3. Reject anything that is not A, T, G or C.
      4. Drop trailing bases that do not form a complete codon.
      5. Cut the sequence after the first in-frame stop codon (the stop
         codon itself is kept) and print a notice.

    Returns:
        (cleaned_sequence, None) on success, otherwise (None, error_message).
    '''
    # Remove spaces, line breaks, and numbers
    sequence = re.sub(r'[\s\d]+', '', sequence.upper())

    if not sequence:
        return None, 'Sequence is empty'

    # Accept RNA input: U is never reported as invalid, it is read as T
    sequence = sequence.replace('U', 'T')

    # Check for invalid characters
    invalid_chars = set(sequence) - set('ATGC')
    if invalid_chars:
        # Joined outside the f-string: reusing the enclosing quote character
        # inside a replacement field is a SyntaxError before Python 3.12.
        listed = ', '.join(sorted(invalid_chars))
        return None, f'Invalid nucleotides: {listed}'

    # Truncate to satisfy frame
    remainder = len(sequence) % 3
    if remainder:
        sequence = sequence[:-remainder]

    # One or two bases truncate to nothing; report that instead of handing
    # an empty "valid" sequence to the optimizer.
    if not sequence:
        return None, 'Sequence is shorter than one codon'

    # Find and truncate at first stop codon
    stop_codons = ('TAA', 'TAG', 'TGA')
    for i in range(0, len(sequence) - 2, 3):
        if sequence[i:i+3] in stop_codons:
            sequence = sequence[:i+3]
            print(f'Stop codon found at position {i+1}! Truncating sequence.\n')
            break

    return sequence, None

## 3. Create Interactive GUI

In [None]:
# Default fitness weights, keyed by the short names used for the weight
# sliders. The six values sum to 1.0.
# (No `global` statement is needed here: names assigned at module level are
# already global, so the previous declaration had no effect.)
DEFAULT_WEIGHTS = {
    'cai': 0.40,
    'gc': 0.20,
    'folding': 0.15,
    'motifs': 0.10,
    'repeats': 0.07,
    'splice': 0.08,
}

# --- Protein input: free-text amino-acid sequence plus back-translate button
protein_input = Textarea(
    value='',
    placeholder='Enter AA sequence (single letter code)\n',
    disabled=False,
    layout=Layout(width='600px', height='80px'),
)

backtranslate_button = Button(
    description='Back-translate to DNA',
    button_style='primary',
    tooltip='Convert protein to DNA sequence',
    layout=Layout(width='200px'),
)

protein_status = HTML(value='<i>No protein specified</i>')

# --- Nucleotide input: typed directly or filled in by back-translation
dna_input = Textarea(
    value='',
    placeholder='Enter DNA/RNA sequence\n',
    disabled=False,
    layout=Layout(width='600px', height='100px'),
)

dna_status = HTML(
    value='<i>DNA sequence will appear here after back-translation or direct input</i>'
)

# --- Host organism: one option per discovered codon table
host_dropdown = Dropdown(
    options=host_options,
    disabled=False,
    layout=Layout(width='300px'),
)
# Prefer human as the initial selection when its table is present
if 'hsapiens' in host_options:
    host_dropdown.value = 'hsapiens'

# --- Strategy used for the initial codon optimization pass
codon_method_dropdown = Dropdown(
    options=[
        'Simple best-codon',
        'Preserve pattern (percentile matching)',
        'Preserve distribution (full distribution matching)',
    ],
    value='Simple best-codon',
    disabled=False,
    layout=Layout(width='450px'),
)

# --- Genetic algorithm parameters
pop_size_slider = IntSlider(
    value=20,
    min=10,
    max=200,
    step=10,
    description='Population:',
    layout=Layout(width='400px'),
)

generations_slider = IntSlider(
    value=25,
    min=25,
    max=300,
    step=25,
    description='Generations:',
    layout=Layout(width='400px'),
)

mutation_slider = FloatSlider(
    value=0.035,
    min=0.005,
    max=0.1,
    step=0.005,
    description='Mutation:',
    readout_format='.3f',
    layout=Layout(width='400px'),
)

# Heading shown above the fitness weight controls
weights_label = HTML(value='<b>Fitness Function Weights:</b> (must sum to 1.0)')

# One 0..1 slider per fitness term, each starting at its default weight.
# All six sliders share the same range, step, readout format and width.
weight_sliders = {
    name: FloatSlider(
        value=DEFAULT_WEIGHTS[name],
        min=0,
        max=1,
        step=0.01,
        description='',
        readout_format='.2f',
        layout=Layout(width='300px'),
    )
    for name in ('cai', 'gc', 'folding', 'motifs', 'repeats', 'splice')
}

# Caption text displayed next to each slider
weight_label_names = {
    'cai': 'CAI:',
    'gc': 'GC Content:',
    'folding': 'Folding:',
    'motifs': 'Motifs:',
    'repeats': 'Repeats:',
    'splice': 'Splice Sites:',
}

# Captions start out green because the default weights add up to 1.0
weight_labels = {
    key: HTML(
        value=f'<span style = "color: green; width: 120px; display: inline-block;"><b>{weight_label_names[key]}</b></span>'
    )
    for key in weight_sliders
}

# Running total of the weights, shown below the sliders
sum_label = HTML(value='<span style = "color: green;"><b>Total: 1.00</b></span>')

# Button that restores the default fitness weights
reset_button = Button(
    description='Reset Weights',
    button_style='warning',
    tooltip='Reset weights to default values',
    icon='refresh',
    layout=Layout(width='150px'),
)

# Button that starts the optimization run
optimize_button = Button(
    description='Optimize Sequence',
    button_style='success',
    disabled=False,
    layout=Layout(width='200px', height='40px'),
)

# Percentage progress bar; created hidden (display = 'none')
progress_bar = IntProgress(
    value=0,
    min=0,
    max=100,
    description='Progress:',
    bar_style='info',
    orientation='horizontal',
    layout=Layout(width='600px', display='none'),
)

# Text line accompanying the progress bar; also created hidden
progress_label = HTML(value='', layout=Layout(display='none'))

# Container that captures printed and displayed results
output_area = Output(
    layout=Layout(width='100%', border='1px solid #ddd', padding='10px')
)

# Optional comma-separated sequences the user wants to avoid
avoid_input = Text(
    value='',
    placeholder='e.g., GAATTC,GGATCC',
    layout=Layout(width='400px'),
)

## 4. Event Handlers

In [None]:
def update_weight_labels():
    '''Refresh the weight captions, the total, and the optimize button.

    Captions and total are drawn green when the slider values sum to 1.0
    and red otherwise. The optimize button is disabled, restyled and
    relabelled while the sum is invalid.
    '''
    total = sum(slider.value for slider in weight_sliders.values())
    # The sliders move in 0.01 steps, so the only intended valid total is
    # 1.00. Half a step is used as tolerance: with a tolerance equal to the
    # step itself, a total shown as 0.99 or 1.01 could pass or fail
    # depending on floating-point rounding of the sum.
    is_valid = abs(total - 1.0) < 0.005

    color = 'green' if is_valid else 'red'
    for key, label in weight_labels.items():
        label.value = f'<span style = "color: {color}; width: 120px; display: inline-block;"><b>{weight_label_names[key]}</b></span>'

    sum_label.value = f'<span style = "color: {color};"><b>Total: {total:.2f}</b></span>'

    # Block the run until the weights are fixed
    optimize_button.disabled = not is_valid
    optimize_button.button_style = 'success' if is_valid else 'danger'
    optimize_button.description = 'Optimize Sequence' if is_valid else 'Fix Weights First'

def on_backtranslate_click(b):
    '''Handle a click on the back-translate button.

    Validates the protein text, back-translates it, writes the DNA into the
    nucleotide input, and reports the outcome in the two status lines.

    Args:
        b: The clicked button (unused).
    '''
    # Local import keeps this handler self-contained
    from html import escape

    clean_protein, error = validate_protein_sequence(protein_input.value)

    if error:
        # The message can contain characters typed by the user (e.g. '<'),
        # so escape it before embedding it in HTML.
        protein_status.value = f'<span style = "color: red;">❌ {escape(error)}</span>'
        return

    try:
        # NOTE(review): the host is hard-coded to 'human' and ignores
        # host_dropdown (whose human entry is named 'hsapiens') -- confirm
        # which identifier back_translate_protein expects.
        dna_seq = back_translate_protein(clean_protein, 'human', optimize = False)
        dna_input.value = dna_seq
        protein_status.value = f'<span style = "color: green;">✓ Protein: {len(clean_protein)} aa</span>'
        dna_status.value = f'<span style = "color: green;">✓ DNA: {len(dna_seq)} bp ({len(dna_seq)//3} codons)</span>'
    except Exception as e:
        # UI boundary: show the failure in the status line instead of raising
        protein_status.value = f'<span style = "color: red;">❌ Error: {escape(str(e))}</span>'

def on_weight_change(change):
    '''Observer for the weight sliders.

    Refreshes the weight display only for notifications of type 'change'
    on the 'value' trait; every other notification is ignored.
    '''
    if change['type'] != 'change' or change['name'] != 'value':
        return
    update_weight_labels()

def on_reset_click(b):
    '''Handle a click on the reset button.

    Puts every weight slider back to its entry in DEFAULT_WEIGHTS and then
    refreshes the weight display.

    Args:
        b: The clicked button (unused).
    '''
    for key in weight_sliders:
        weight_sliders[key].value = DEFAULT_WEIGHTS[key]
    update_weight_labels()

def update_progress(current, total):
    '''Update the progress bar and its text line during optimization.

    Args:
        current: Number of the generation just reached.
        total: Total number of generations in the run.
    '''
    # Guard against a non-positive total, which would otherwise raise
    # ZeroDivisionError from inside the optimizer's callback.
    progress = int((current / total) * 100) if total > 0 else 0
    progress_bar.value = progress
    progress_label.value = f'<i>Generation {current}/{total}</i>'

def _run_optimization(clean_seq):
    '''Run both optimization phases on clean_seq and print the report.

    Phase 1 pre-optimizes codons with the selected method; phase 2 refines
    the result with the genetic algorithm and then renders the GCEH
    comparison tables and plots. A failure in either phase is reported with
    a printed message rather than raised.
    '''
    host = host_dropdown.value
    method = codon_method_dropdown.value

    print('=' * 70)
    print('OPTIMIZATION STARTED')
    print('=' * 70)
    print(f'Host: {host}')
    print(f'Codon optimization method: {method}')
    print(f'Input sequence: {len(clean_seq)} bp')
    # A trailing stop codon is not counted as a residue
    coding_len = len(clean_seq)
    if clean_seq[-3:] in ('TAA', 'TAG', 'TGA'):
        coding_len -= 3
    print(f'Protein: {coding_len / 3:.0f} aa')

    print(f'Population: {pop_size_slider.value}, Generations: {generations_slider.value}')

    # Phase 1: Initial codon optimization
    print('\n' + '=' * 70)
    print('PHASE 1: Initial Codon Optimization')
    print('=' * 70)

    try:
        pre_optimized = optimize_codons(clean_seq, host, method)
        print(f'Method: {method}')
        print(f'Pre-optimized sequence ({len(pre_optimized)} bp):')

        # Display pre-optimized sequence
        display(Textarea(value = pre_optimized, layout = Layout(width = '600px', height = '400px')))

        # Calculate initial metrics
        codon_table = load_codon_table(host)
        pre_cai = calculate_cai(pre_optimized, codon_table)
        pre_gc = calculate_gc_content(pre_optimized)
        print(f'Pre-optimized CAI: {pre_cai:.3f}')
        print(f'Pre-optimized GC: {pre_gc:.1f}%')

    except Exception as e:
        print(f'\n❌ Codon optimization failed: {str(e)}')
        return

    # Phase 2: Genetic algorithm optimization
    print('\n' + '=' * 70)
    print('PHASE 2: Genetic Algorithm Optimization')
    print('=' * 70)
    print('Optimizing...\n')

    # Map the slider keys onto the weight names passed to the optimizer
    weights = {
        'cai': weight_sliders['cai'].value,
        'gc_deviation': weight_sliders['gc'].value,
        'folding_energy': weight_sliders['folding'].value,
        'unwanted_motifs': weight_sliders['motifs'].value,
        'repeats': weight_sliders['repeats'].value,
        'cryptic_splice': weight_sliders['splice'].value
    }

    # Kept as a wrapper so the parameter names seen by the optimizer
    # (generation, total_generations) stay exactly as before.
    def progress_callback(generation, total_generations):
        update_progress(generation, total_generations)

    # Run GA optimization on pre-optimized sequence
    try:
        optimized_seq, fitness, metrics = genetic_algorithm(
            initial_cds = pre_optimized,  # Use pre-optimized sequence
            host = host,
            # NOTE(review): .get() yields None for a host without an entry
            # in HOST_TARGET_GC -- confirm genetic_algorithm accepts that.
            target_gc = HOST_TARGET_GC.get(host),
            pop_size = pop_size_slider.value,
            generations = generations_slider.value,
            mutation_rate = mutation_slider.value,
            weights = weights,
            # NOTE(review): passed as the raw comma-separated text --
            # confirm genetic_algorithm parses it itself.
            avoid_sequences = avoid_input.value,
            verbose = True,
            progress_callback = progress_callback
        )

        print('\n' + '=' * 70)
        print('OPTIMIZATION COMPLETE')
        print('=' * 70)
        print(f'Final Fitness: {fitness:.4f}')
        # Double quotes inside the single-quoted f-strings: reusing the
        # enclosing quote character is a SyntaxError before Python 3.12.
        print(f'CAI Score: {metrics["cai"]:.3f}')
        print(f'GC Content: {metrics["gc_content"]:.1f}%')

        # Display final optimized sequence
        print('\nFinal Optimized Sequence:')
        display(Textarea(value = optimized_seq, layout = Layout(width = '600px', height = '400px')))

        ################
        #  Genetic Code Exploration Helper module call for visual analysis
        data_clean = gceh_anal('Input:', clean_seq)
        data_opt = gceh_anal('Optimized:', optimized_seq)
        print('\n' + 'CODON ANALYSIS')
        print('\n' + 'Input sequence:')
        print(data_clean.table_str)
        print('\n' + 'Optimized sequence:')
        print(data_opt.table_str)
        plot_codon_usage_compare(data_clean, data_opt, title = 'Codon usage: Input vs Optimized')

        # Sliding windows span 10% of the optimized sequence. Clamped to at
        # least 1 because sequences under 10 nt would give a window of 0;
        # presumably the plot helpers need a positive window.
        window = max(1, len(optimized_seq) // 10)

        # GC content sliding window analysis
        plot_gc_sliding_compare(data_clean, data_opt, window, step_nt = 1, title = 'Sliding GC%: Input vs Optimized')

        # Cumulative GC3 comparison
        plot_gc3_compare(data_clean, data_opt, title = 'Cumulative GC3 along sequence: Input vs Optimized')

        # Sliding-window GC3 comparison
        plot_gc3_sliding_compare(data_clean, data_opt, window, title = 'GC3 moving average: Input vs Optimized')
        #  End GCEH
        ################

    except Exception as e:
        print(f'\n❌ Optimization failed: {str(e)}')


def on_optimize_click(b):
    '''Handle a click on the optimize button.

    Clears the output area, validates the nucleotide input, and runs the
    optimization with the progress widgets visible. The progress widgets
    are hidden again however the run ends.

    Args:
        b: The clicked button (unused).
    '''
    with output_area:
        clear_output(wait = True)

        clean_seq, error = clean_nucleotide_sequence(dna_input.value)

        if error:
            print(f'❌ Error: {error}')
            return

        # Show progress bar
        progress_bar.layout.display = 'flex'
        progress_label.layout.display = 'block'
        progress_bar.value = 0

        try:
            _run_optimization(clean_seq)
        finally:
            # Single cleanup point replacing the per-branch hide code
            progress_bar.layout.display = 'none'
            progress_label.layout.display = 'none'

# Every weight slider notifies the same observer
for slider in weight_sliders.values():
    slider.observe(on_weight_change)

# Connect each button to its click handler
optimize_button.on_click(on_optimize_click)
reset_button.on_click(on_reset_click)
backtranslate_button.on_click(on_backtranslate_click)

## 5. Display GUI

In [None]:
# Build the full interface as one vertical stack, section by section,
# then render it.
gui = VBox([
    HTML('<h2>mRNA Sequence Optimization Tool</h2>'),
    HTML('<hr>'),

    # Section 1: optional protein sequence with back-translation
    HTML('<h3>1. Protein Sequence (Optional)</h3>'),
    protein_input,
    HBox([backtranslate_button, protein_status]),

    HTML('<hr>'),

    # Section 2: nucleotide sequence to optimize
    HTML('<h3>2. Nucleotide Sequence for Optimization</h3>'),
    dna_input,
    dna_status,

    HTML('<hr>'),

    # Section 3: host, sequences to avoid, codon method, GA settings
    HTML('<h3>3. Optimization Parameters</h3>'),
    VBox([
        HTML('<b>Host:</b>'),
        host_dropdown,
        HTML('<b>Avoid sequences (optional):</b>'),
        avoid_input,
        HTML('<b>Codon Optimization Method:</b>'),
        codon_method_dropdown,
    ]),

    HTML('<h4>Genetic Algorithm Settings:</h4>'),
    pop_size_slider,
    generations_slider,
    mutation_slider,

    HTML('<hr>'),

    # Section 4: one caption + slider row per fitness weight
    HTML('<h3>4. Fitness Function Weights</h3>'),
    weights_label,
    VBox([
        HBox([weight_labels[name], weight_sliders[name]])
        for name in ('cai', 'gc', 'folding', 'motifs', 'repeats', 'splice')
    ]),
    VBox([sum_label, reset_button]),

    HTML('<hr>'),

    # Section 5: run button with progress feedback
    HTML('<h3>5. Run Optimization</h3>'),
    optimize_button,
    progress_bar,
    progress_label,

    HTML('<hr>'),

    # Results: CSS that makes the output text selectable, then the output
    HTML('<h3>Results</h3>'),
    HTML('''
    <style>
    .output_area pre, .output_area {
        user-select: text !important;
        -webkit-user-select: text !important;
    }
    </style>
    '''),
    output_area,
])

display(gui)