# Advanced Hybrid ORF Detection - Stage 2

## Our Algorithm Architecture:
1. **Universal ORF Detection** - Find all potential genes
2. **Self-Training** - Learn species-specific patterns  
3. **Multi-Evidence Scoring** - Combine traditional methods
4. **Conflict Resolution** - Optimize overlapping predictions

## Goal:
Build an unsupervised gene predictor that adapts to bacterial genomes (and possibly archaeal genomes as well).

1. Interpolated Markov Models (IMMs)

GLIMMER uses interpolated Markov models to identify coding regions, typically finding 98-99% of all relatively long protein-coding genes (source: GLIMMER, Wikipedia).
Variable-length context models that adapt based on data availability
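
As a rough illustration of the variable-length context idea, the sketch below blends Markov orders 0 to k and trusts a longer context only as far as its training count allows. The class name `SimpleIMM`, the `min_count` and `pseudocount` parameters, and the count-based weighting are assumptions made for this example; GLIMMER derives its interpolation weights differently.

```python
import math
from collections import defaultdict


class SimpleIMM:
    """Toy interpolated Markov model (illustrative only, not GLIMMER's scheme)."""

    def __init__(self, max_order=3, min_count=10, pseudocount=1.0):
        self.max_order = max_order
        self.min_count = min_count      # assumed: observations needed to fully trust a context
        self.pseudocount = pseudocount
        # counts[order][context][next_base] -> number of observations
        self.counts = [defaultdict(lambda: defaultdict(int))
                       for _ in range(max_order + 1)]

    def train(self, sequences):
        for seq in sequences:
            for i, base in enumerate(seq):
                for order in range(self.max_order + 1):
                    if i >= order:
                        self.counts[order][seq[i - order:i]][base] += 1

    def prob(self, context, base):
        """P(base | context), interpolated from order 0 up to the longest seen context."""
        p = 0.25  # uniform prior used before any counts are consulted
        for order in range(min(self.max_order, len(context)) + 1):
            ctx = context[len(context) - order:]
            seen = self.counts[order].get(ctx)
            if not seen:
                break  # a longer context cannot have been seen either
            total = sum(seen.values())
            estimate = (seen.get(base, 0) + self.pseudocount) / (total + 4 * self.pseudocount)
            lam = min(1.0, total / self.min_count)  # weight grows with available data
            p = lam * estimate + (1.0 - lam) * p
        return p

    def log_score(self, seq):
        """Sum of log-probabilities of each base given its preceding context."""
        return sum(math.log(self.prob(seq[max(0, i - self.max_order):i], base))
                   for i, base in enumerate(seq))


# Toy usage: train on a repetitive "coding-like" sequence and score two strings
imm = SimpleIMM(max_order=2)
imm.train(["ATG" * 20])
coding_like = imm.log_score("ATGATG")
random_like = imm.log_score("CCCCCC")
```

In a real predictor one model would be trained per codon position and compared against a non-coding model; this single-model version only shows the interpolation step.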

2. Ribosome Binding Site (RBS) Detection

A position weight matrix (PWM) scores each potential RBS; a Gibbs sampler finds the RBS motifs iteratively (source: "Identifying bacterial genes and endosymbiont DNA with Glimmer", Bioinformatics, Oxford Academic).
Critical for accurate start site prediction
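
To make the PWM idea concrete, here is a minimal sketch that builds a log-odds matrix from a few aligned sites and slides it along an upstream region. The training hexamers, the pseudocount, and the uniform background are assumptions chosen for illustration, and the Gibbs-sampling step that would discover the sites is not shown.

```python
import math

BASES = "ACGT"


def build_pwm(aligned_sites, pseudocount=0.5, background=0.25):
    """Build a log-odds PWM (one dict per column) from equal-length aligned sites."""
    width = len(aligned_sites[0])
    if any(len(site) != width for site in aligned_sites):
        raise ValueError("all sites must have the same length")
    pwm = []
    for pos in range(width):
        column = [site[pos] for site in aligned_sites]
        total = len(column) + 4 * pseudocount
        pwm.append({
            base: math.log2(((column.count(base) + pseudocount) / total) / background)
            for base in BASES
        })
    return pwm


def best_pwm_hit(pwm, upstream_seq):
    """Slide the PWM along upstream_seq; return (best_score, offset) or (None, None)."""
    width = len(pwm)
    best_score, best_offset = None, None
    for offset in range(len(upstream_seq) - width + 1):
        window = upstream_seq[offset:offset + width]
        if any(base not in BASES for base in window):
            continue  # skip windows containing ambiguous bases such as N
        score = sum(pwm[i][base] for i, base in enumerate(window))
        if best_score is None or score > best_score:
            best_score, best_offset = score, offset
    return best_score, best_offset


# Hypothetical training sites resembling the Shine-Dalgarno core
sites = ["AGGAGG", "AGGAGG", "AGGAGA", "AGGTGG", "AAGAGG"]
pwm = build_pwm(sites)
score, offset = best_pwm_hit(pwm, "TTTACAGGAGGTAACAT")
```

Compared with exact motif matching, a PWM tolerates single mismatches in proportion to how often they occur in the training sites.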

3. Reverse ORF Scanning

GLIMMER 3.0 scores each ORF in reverse, starting at the stop codon and moving back toward the start codon (source: GLIMMER, Wikipedia).
More accurate identification of coding regions
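
One way to picture reverse scanning: accumulate a coding score codon by codon from the stop codon toward the 5' end, and keep the start codon where the running total peaks. The sketch below does this with a hypothetical per-codon score table standing in for the IMM; the function name and the toy scores are assumptions, not GLIMMER's implementation.

```python
START_CODONS = {"ATG", "GTG", "TTG"}


def pick_start_by_reverse_scan(orf_seq, codon_score):
    """
    orf_seq: in-frame sequence that ends with a stop codon.
    codon_score: callable mapping a codon to a log-odds style score (assumed).
    Returns (best_start_offset, best_cumulative_score), or (None, None) if
    no start codon is present.
    """
    if len(orf_seq) % 3 != 0:
        raise ValueError("orf_seq length must be a multiple of 3")
    cumulative = 0.0
    best_offset, best_score = None, None
    # Walk from the codon just before the stop back to the 5' end.
    for pos in range(len(orf_seq) - 6, -1, -3):
        codon = orf_seq[pos:pos + 3]
        cumulative += codon_score(codon)
        if codon in START_CODONS and (best_score is None or cumulative > best_score):
            best_offset, best_score = pos, cumulative
    return best_offset, best_score


# Toy scores: GCG and AAA look "coding", TTT looks "non-coding"
TOY_SCORES = {"GCG": 1.0, "AAA": 1.0, "TTT": -2.0}
orf = "ATG" "TTT" "TTT" "ATG" "GCG" "AAA" "GCG" "TAA"
result = pick_start_by_reverse_scan(orf, lambda c: TOY_SCORES.get(c, 0.0))
```

In this toy ORF the upstream ATG is rejected because the two low-scoring codons after it pull the cumulative score below the value reached at the internal ATG.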


In [71]:
# IMPORTS AND GLOBAL VARIABLES
import sys
import numpy as np
import pandas as pd
from collections import defaultdict, Counter
from Bio.Seq import Seq
import matplotlib.pyplot as plt
import math
import time
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')
import seaborn as sns
from typing import Dict, List



sys.path.insert(0, '..')  


from src.data_management import (
    load_genome_sequence,
    load_reference_genes_from_gff,
    get_fasta_path,
    get_gff_path,
    TEST_GENOMES  
)
from src.config import (
    MIN_ORF_LENGTH,
    START_SELECTION_WEIGHTS,
    SCORE_WEIGHTS,
    LENGTH_REFERENCE_BP,
    START_CODON_WEIGHTS,
)
from src.cache import load_cache
from src.comparative_analysis import compare_orfs_to_reference


print("Advanced ORF detection ready")


Advanced ORF detection ready


In [72]:
genome_id = "NC_000913.3"
cached_data=load_cache()

Loading cached data from C:\Users\User\Desktop\Bacterial Gene Prediction & Comparison Project\data\processed\cached_orfs.pkl...
Loaded 15 cached genomes


# Stage 1 - Universal ORF Detection
- ATG starts: high confidence
- GTG starts: medium confidence
- TTG starts: lower confidence
- Ribosome binding site detection

**Integration Strategy:**
Weight final gene scores by start codon reliability to reduce false positives from rare start codons while maintaining sensitivity for real genes with unusual starts.
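
A minimal sketch of this weighting, assuming a simple multiplicative scheme on non-negative scores. The weight values below are placeholders invented for the example (the project's actual values live in `src.config.START_CODON_WEIGHTS`), and `weight_score_by_start_codon` is a hypothetical helper.

```python
# Placeholder weights for illustration only; not the values from src.config
EXAMPLE_START_CODON_WEIGHTS = {"ATG": 1.0, "GTG": 0.8, "TTG": 0.6}


def weight_score_by_start_codon(raw_score, start_codon,
                                weights=EXAMPLE_START_CODON_WEIGHTS, default=0.5):
    """Scale a non-negative evidence score by the reliability of the start codon."""
    return raw_score * weights.get(start_codon, default)


candidates = [
    {"id": "orf_a", "start_codon": "ATG", "raw_score": 10.0},
    {"id": "orf_b", "start_codon": "TTG", "raw_score": 12.0},
    {"id": "orf_c", "start_codon": "GTG", "raw_score": 11.0},
]
for orf in candidates:
    orf["final_score"] = weight_score_by_start_codon(orf["raw_score"], orf["start_codon"])

# Rank candidates by weighted score, best first
ranked = sorted(candidates, key=lambda o: o["final_score"], reverse=True)
```

A TTG candidate can still outrank an ATG candidate under this scheme, but only when its other evidence is strong enough to overcome the lower weight.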


In [73]:
# Known RBS motifs - just use these directly
KNOWN_RBS_MOTIFS = [
    "AGGAGG",
    "GGAGG", 
    "AGGAG",
    "GAGG",
    "AGGA",
    "GGAG"
]

def find_purine_rich_regions(sequence, min_length=4, min_purine_content=0.6):
    """
    Find purine-rich regions (A and G rich) in sequence.
    """
    purine_regions = []
    
    for start in range(len(sequence)):
        for length in range(min_length, min(9, len(sequence) - start + 1)):
            subseq = sequence[start:start + length]
            
            purines = subseq.count('A') + subseq.count('G')
            purine_fraction = purines / length
            
            if purine_fraction >= min_purine_content:
                purine_regions.append({
                    'sequence': subseq,
                    'start': start,
                    'end': start + length,
                    'purine_content': purine_fraction,
                    'length': length
                })
    
    return purine_regions

def evaluate_spacing_score(spacing):
    """
    Evaluate spacing between SD sequence and start codon.
    A spacing of 6-8 nucleotides scores highest; 5-10 is still scored as very good.
    """
    if 6 <= spacing <= 8:
        return 3.0  # Optimal
    elif 5 <= spacing <= 10:
        return 2.5  # Very good
    elif 4 <= spacing <= 12:
        return 1.5  # Good
    elif 3 <= spacing <= 14:
        return 1.0  # Acceptable
    else:
        return 0.2  # Poor

def score_motif_similarity(sequence):
    """
    Score how similar a sequence is to known RBS motifs.
    No padding, no PWM - just direct comparison.
    """
    best_score = 0.0
    best_motif = None
    
    for motif in KNOWN_RBS_MOTIFS:
        # Try every alignment of the motif against the sequence, shifting in both directions
        for offset in range(-(len(sequence) - 1), len(motif)):
            matches = 0
            total_positions = 0
            
            # Compare overlapping region
            for i in range(len(sequence)):
                motif_pos = i + offset
                if 0 <= motif_pos < len(motif):
                    total_positions += 1
                    if sequence[i] == motif[motif_pos]:
                        matches += 1
            
            if total_positions > 0:
                similarity = matches / total_positions
                
                # Weight by overlap length and motif quality
                overlap_length = total_positions
                motif_weight = len(motif) / 6.0  # AGGAGG gets weight 1.0
                
                score = similarity * overlap_length * motif_weight
                
                if score > best_score:
                    best_score = score
                    best_motif = motif
    
    return best_score, best_motif

def predict_rbs_simple(sequence, orf, upstream_length=20):
    """
    Simple RBS prediction: purine-rich + spacing + direct motif comparison.
    """
    # orf['start'] is 1-based; convert to the 0-based index of the start codon's first base
    start_pos = orf['start'] - 1
    
    if start_pos < upstream_length:
        return {
            'rbs_score': -5.0,
            'spacing_score': 0.0,
            'motif_score': 0.0,
            'best_sequence': None,
            'best_motif': None,
            'spacing': 0,
            'position': 0
        }

    # Translation initiation region (20 bases upstream)
    upstream_start = start_pos - upstream_length
    upstream_seq = sequence[upstream_start:start_pos]
    
    # Find purine-rich regions
    purine_regions = find_purine_rich_regions(upstream_seq, min_length=4, min_purine_content=0.6)
    
    best_score = -5.0
    best_prediction = None
    
    # Evaluate each purine-rich region
    for region in purine_regions:
        sd_candidate = region['sequence']
        
        spacing = len(upstream_seq) - region['end']
        
        if spacing < 4 or spacing > 12:
            continue
        

        spacing_score = evaluate_spacing_score(spacing)
        
        motif_score, best_motif = score_motif_similarity(sd_candidate)
        
        purine_bonus = (region['purine_content'] - 0.6) * 2.0
        
        # Combined score
        combined_score = (
            spacing_score * 2.0 +    # Spacing is critical
            motif_score * 1.5 +      # Direct motif similarity  
            purine_bonus             # Higher purine content
        )
        
        if combined_score > best_score:
            best_score = combined_score
            best_prediction = {
                'rbs_score': combined_score,
                'spacing_score': spacing_score,
                'motif_score': motif_score,
                'best_sequence': sd_candidate,
                'best_motif': best_motif,
                'spacing': spacing,
                'position': region['start'],
                'purine_content': region['purine_content'],
                'length': region['length']
            }
    
    return best_prediction or {
        'rbs_score': -5.0,
        'spacing_score': 0.0,
        'motif_score': 0.0,
        'best_sequence': None,
        'best_motif': None,
        'spacing': 0,
        'position': 0
    }
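
The spacing used in the RBS score is the number of bases between the 3' end of the SD candidate and the first base of the start codon. The worked example below shows that arithmetic on a toy sequence with a 0-based start index; `sd_spacing` is a hypothetical helper written for this illustration and uses exact motif matching only.

```python
def sd_spacing(sequence, start_index, motif="AGGAGG", upstream_length=20):
    """
    Return the number of bases between the end of `motif` and the start codon,
    or None if the motif is absent from the upstream window.
    start_index is the 0-based index of the first base of the start codon.
    """
    window_start = max(0, start_index - upstream_length)
    upstream = sequence[window_start:start_index]  # excludes the start codon itself
    hit = upstream.rfind(motif)                    # occurrence closest to the start codon
    if hit == -1:
        return None
    return len(upstream) - (hit + len(motif))


# CCCC + AGGAGG + 7 spacer bases + ATG...
toy = "CCCC" "AGGAGG" "TTTTTTT" "ATGAAA"
start_index = toy.index("ATG")
spacing = sd_spacing(toy, start_index)
```

Here the motif ends at index 10 and the start codon begins at index 17, so the spacing is 7 bases.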

In [74]:
def find_orfs_candidates(sequence, min_length):
    """
    ORF detection with dual coordinate system and immediate RBS calculation.
    
    Returns ORFs with:
    - 'start', 'end': 1-based coordinates on the scanned strand (the forward
      sequence, or its reverse complement for reverse-strand ORFs); used for scoring
    - 'genome_start', 'genome_end': forward-strand coordinates (for validation)
    - 'sequence': ORF sequence as read on the scanned strand (for scoring)
    - 'rbs_score', 'rbs_motif', 'rbs_spacing', 'rbs_sequence': RBS features
    """
    start_codons = {'ATG', 'GTG', 'TTG'}
    stop_codons = {'TAA', 'TAG', 'TGA'}
    orfs = []
    
    sequences = [
        ('forward', sequence),
        ('reverse', str(Seq(sequence).reverse_complement()))
    ]
    seq_len = len(sequence)

    print("Detecting ORFs and calculating RBS...")

    for strand_name, seq in sequences:
        for frame in range(3):
            active_starts = [] 
            for i in range(frame, len(seq) - 2, 3):
                codon = seq[i:i+3]
                
                if len(codon) != 3:  
                    break
                
                if codon in start_codons:
                    active_starts.append((i, codon))
                    
                elif codon in stop_codons and active_starts:
                    for start_pos, start_codon in active_starts:
                        orf_length = i + 3 - start_pos
                        if orf_length >= min_length:
                            # Create ORF
                            if strand_name == 'forward':
                                orf = {
                                    'start': start_pos + 1,
                                    'end': i + 3,
                                    'genome_start': start_pos + 1,
                                    'genome_end': i + 3,
                                    'length': orf_length,
                                    'frame': frame,
                                    'strand': 'forward',
                                    'start_codon': start_codon,
                                    'sequence': seq[start_pos:i+3]
                                }
                            else:  # reverse strand
                                orf = {
                                    'start': start_pos + 1,
                                    'end': i + 3,
                                    'genome_start': seq_len - (i + 3) + 1,
                                    'genome_end': seq_len - start_pos,
                                    'length': orf_length,
                                    'frame': frame,
                                    'strand': 'reverse',
                                    'start_codon': start_codon,
                                    'sequence': seq[start_pos:i+3]  
                                }
                            
                            # Calculate RBS
                            rbs_result = predict_rbs_simple(seq, orf, upstream_length=20)
                            orf['rbs_score'] = rbs_result['rbs_score']
                            orf['rbs_motif'] = rbs_result.get('best_motif')
                            orf['rbs_spacing'] = rbs_result.get('spacing', 0)
                            orf['rbs_sequence'] = rbs_result.get('best_sequence')
                            
                            orfs.append(orf)
                    active_starts = []
    
    print(f"Complete: {len(orfs):,} ORFs detected with RBS scores")
    return orfs
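
The reverse-strand coordinate conversion is easy to get wrong by one, so here is a small self-check of the same arithmetic on a toy genome. The helper names are hypothetical, and a plain-Python reverse complement is used so the example has no dependencies.

```python
COMPLEMENT = str.maketrans("ACGT", "TGCA")


def reverse_complement(seq):
    return seq.translate(COMPLEMENT)[::-1]


def rc_to_forward_coords(rc_start0, rc_end0_exclusive, seq_len):
    """
    Map a 0-based half-open interval on the reverse-complement sequence
    to 1-based inclusive coordinates on the forward strand.
    """
    genome_start = seq_len - rc_end0_exclusive + 1
    genome_end = seq_len - rc_start0
    return genome_start, genome_end


genome = "GGTTACATCCC"            # carries ATGTAA on the reverse strand
rc = reverse_complement(genome)   # the sequence the reverse-strand scan sees
rc_start = rc.index("ATGTAA")     # 0-based start of the ORF in RC coordinates
g_start, g_end = rc_to_forward_coords(rc_start, rc_start + 6, len(genome))

# Slicing the forward strand at the mapped coordinates and reverse-complementing
# it should give back the ORF exactly as it was read on the RC sequence.
recovered = reverse_complement(genome[g_start - 1:g_end])
```

If `recovered` ever differs from the ORF sequence, the mapping and the scan disagree about where the ORF lies.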

## Phase 2: Self-Training and Advanced Scoring

**Challenge:** 176,315 ORF candidates → ~4,300 real genes  
**Solution:** Implement Glimmer-style advanced methods

### Methods to Implement:
1. **Self-training on long ORFs** - Build species-specific models
2. **Interpolated Markov Models** - Score coding potential  
3. **RBS (Ribosome Binding Site) detection** - Find translation signals
4. **Codon usage bias analysis** - Species-specific patterns

### Strategy:
Use the longest, most confident ORFs to train models, then score all candidates.
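
As one concrete instance of "train on the longest ORFs, then score everything", the sketch below learns codon frequencies from a small training set and scores candidates by mean log-odds against a uniform background. The function names, the pseudocount, and the uniform 1/64 background are assumptions for illustration; the toy sequences are far too short for real training.

```python
import math
from collections import Counter
from itertools import product

ALL_CODONS = ["".join(c) for c in product("ACGT", repeat=3)]
CODON_SET = set(ALL_CODONS)


def split_codons(seq):
    """Split seq into complete codons, dropping any trailing partial codon."""
    usable = len(seq) - len(seq) % 3
    return [seq[i:i + 3] for i in range(0, usable, 3)]


def train_codon_model(training_seqs, pseudocount=1.0):
    """Estimate codon frequencies (with pseudocounts) from training sequences."""
    counts = Counter()
    for seq in training_seqs:
        counts.update(c for c in split_codons(seq) if c in CODON_SET)
    total = sum(counts.values()) + pseudocount * len(ALL_CODONS)
    return {codon: (counts[codon] + pseudocount) / total for codon in ALL_CODONS}


def codon_log_odds(seq, model, background=1.0 / 64):
    """Mean log2 odds of the sequence's codons versus a uniform background."""
    scores = [math.log2(model[c] / background) for c in split_codons(seq) if c in model]
    return sum(scores) / len(scores) if scores else 0.0


# Toy candidates; in the notebook these would be ORF dicts with 'sequence' and 'length'
orfs = [
    {"sequence": "ATGGCGAAAGCGTAA", "length": 15},
    {"sequence": "ATGGCGGCGAAATAA", "length": 15},
    {"sequence": "ATGCCCTTTTAA", "length": 12},
]
training = sorted(orfs, key=lambda o: o["length"], reverse=True)[:2]
model = train_codon_model([o["sequence"] for o in training])
for orf in orfs:
    orf["codon_score"] = codon_log_odds(orf["sequence"], model)
```

Averaging per codon keeps the score comparable across ORFs of different lengths; a summed score would favor long ORFs regardless of composition.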

In [75]:
# TRAINING SET SELECTION: possible strategies
def calculate_amino_acid_entropy(orf):
    """
    Calculate Shannon entropy of amino acid composition.
    Uses cached value if already calculated.
    
    Lower entropy = biased distribution = likely real gene
    Higher entropy = uniform distribution = likely random ORF
    
    Args:
        orf (dict): ORF dictionary with 'sequence' key
        
    Returns:
        float: Entropy value
               Returns 999.0 if translation fails
               
    Side effects:
        Caches result in orf['aa_entropy'] for reuse
    """
    # Return cached value if available
    if 'aa_entropy' in orf:
        return orf['aa_entropy']
    
    try:
        protein = str(Seq(orf['sequence']).translate(to_stop=True))
        
        # Need sufficient length for reliable statistics
        if len(protein) < 10:
            orf['aa_entropy'] = 999.0
            return 999.0
        
        # Count amino acid frequencies
        aa_counts = Counter(protein)
        total = len(protein)
        
        # Calculate Shannon entropy: H = -Σ(p_i × log₂(p_i))
        entropy = 0.0
        for count in aa_counts.values():
            if count > 0:
                freq = count / total
                entropy -= freq * math.log2(freq)
        
        orf['aa_entropy'] = entropy
        return entropy
        
    except Exception:
        orf['aa_entropy'] = 999.0
        return 999.0

def select_training_glimmer_3(all_orfs, min_length=300, 
                               max_entropy=4.2, max_training_size=2000):
    """
    GLIMMER 3.0 strategy with entropy filtering.
    """
    long_orfs = [orf for orf in all_orfs if orf['length'] >= min_length]
    
    # Calculate entropy for all long ORFs (the function caches the result on each ORF)
    for orf in long_orfs:
        calculate_amino_acid_entropy(orf)
    
    # Filter by entropy
    entropy_filtered = [orf for orf in long_orfs if orf['aa_entropy'] <= max_entropy]
    entropy_filtered.sort(key=lambda x: x['length'], reverse=True)
    
    # Select non-overlapping
    training_set = []
    covered_intervals = []
    
    for orf in entropy_filtered:
        start = orf.get('genome_start', orf['start'])
        end = orf.get('genome_end', orf['end'])
        if start > end:
            start, end = end, start
        
        overlaps = any(not (end < cov_start or start > cov_end) 
                      for cov_start, cov_end in covered_intervals)
        
        if not overlaps:
            training_set.append(orf)
            covered_intervals.append((start, end))
            if max_training_size is not None and len(training_set) >= max_training_size:
                break
    
    return training_set

def select_training_glimmer(all_orfs, min_length=300, max_training_size=2000):
    """
    GLIMMER Pure - optimized (no entropy needed).
    Stops as soon as max_training_size non-overlapping ORFs have been selected.
    """
    long_orfs = [orf for orf in all_orfs if orf['length'] >= min_length]
    long_orfs.sort(key=lambda x: x['length'], reverse=True)
    
    training_set = []
    covered_intervals = []
    
    for orf in long_orfs:
        start = orf.get('genome_start', orf['start'])
        end = orf.get('genome_end', orf['end'])
        if start > end:
            start, end = end, start
        
        overlaps = False
        for cov_start, cov_end in covered_intervals:
            if not (end < cov_start or start > cov_end):
                overlaps = True
                break
        
        if not overlaps:
            training_set.append(orf)
            covered_intervals.append((start, end))
            if max_training_size is not None and len(training_set) >= max_training_size:
                break
    
    return training_set

def select_training_flexible(all_orfs, 
                             target_size=500, 
                             min_length=300, 
                             max_length=2400, 
                             max_overlap_fraction=0.3,
                             use_entropy_filter=True, 
                             max_entropy=4.2,
                             prefer_atg=True):
    """
    Flexible training set selection with entropy filtering and controlled overlap.
    
    This strategy allows some overlap between ORFs (unlike pure Glimmer) and
    prioritizes ATG start codons when available.
    
    Args:
        all_orfs (list): All detected ORFs
        target_size (int): Target number of training ORFs (default: 500)
        min_length (int): Minimum ORF length in bp (default: 300)
        max_length (int): Maximum ORF length in bp (default: 2400)
        max_overlap_fraction (float): Maximum allowed overlap as fraction of ORF length (default: 0.3)
        use_entropy_filter (bool): Whether to filter by amino acid entropy (default: True)
        max_entropy (float): Maximum amino acid entropy if filtering (default: 4.2)
        prefer_atg (bool): Prioritize ATG start codons (default: True)
        
    Returns:
        list: Selected training ORFs
    """
    # Filter by length
    filtered = [orf for orf in all_orfs 
                if min_length <= orf['length'] <= max_length]
    
    # Apply entropy filter if requested
    if use_entropy_filter:
        for orf in filtered:
            calculate_amino_acid_entropy(orf)
        filtered = [orf for orf in filtered if orf['aa_entropy'] <= max_entropy]
    
    # Sort candidates: ATG first if preferred, then by length
    if prefer_atg:
        atg_orfs = [orf for orf in filtered if orf.get('start_codon') == 'ATG']
        non_atg_orfs = [orf for orf in filtered if orf.get('start_codon') != 'ATG']
        
        atg_orfs.sort(key=lambda x: x['length'], reverse=True)
        non_atg_orfs.sort(key=lambda x: x['length'], reverse=True)
        
        candidates = atg_orfs + non_atg_orfs
    else:
        candidates = sorted(filtered, key=lambda x: x['length'], reverse=True)
    
    selected = []
    
    for orf in candidates:
        # Normalize coordinates
        orf_start = orf.get('genome_start', orf['start'])
        orf_end = orf.get('genome_end', orf['end'])
        if orf_start > orf_end:
            orf_start, orf_end = orf_end, orf_start
        
        orf_strand = orf.get('strand', 'forward')
        orf_length = orf['length']
        
        # Check overlap with already selected ORFs on same strand
        max_overlap = 0.0
        for sel in selected:
            # Only check same-strand overlaps
            if sel.get('strand', 'forward') != orf_strand:
                continue
            
            # Normalize selected ORF coordinates
            sel_start = sel.get('genome_start', sel['start'])
            sel_end = sel.get('genome_end', sel['end'])
            if sel_start > sel_end:
                sel_start, sel_end = sel_end, sel_start
            
            # Calculate overlap
            overlap_bp = max(0, min(orf_end, sel_end) - max(orf_start, sel_start) + 1)
            overlap_frac = overlap_bp / orf_length
            max_overlap = max(max_overlap, overlap_frac)
        
        # Accept ORF if overlap is acceptable
        if max_overlap <= max_overlap_fraction:
            selected.append(orf)
            
            # Stop when target size reached
            if len(selected) >= target_size:
                break
    
    return selected

def extract_intergenic_regions(sequence, training_orfs, buffer=50, min_length=150):
    """
    Extract intergenic regions using high-confidence genes (Glimmer approach).
    Returns both concatenated sequence and coordinates.
    """
    gene_regions = []
    for orf in training_orfs:
        start = orf.get('genome_start', orf['start'])
        end = orf.get('genome_end', orf['end'])
        if start > end:
            start, end = end, start
        gene_regions.append((max(1, start-buffer), min(len(sequence), end+buffer)))
    
    # Merge overlapping regions
    merged = []
    for s, e in sorted(gene_regions):
        if merged and s <= merged[-1][1]:
            merged[-1] = (merged[-1][0], max(merged[-1][1], e))
        else:
            merged.append((s, e))
    
    # Extract gaps
    intergenic_seqs = []
    intergenic_coords = []
    last_end = 1
    for s, e in merged:
        if s - last_end >= min_length:
            intergenic_coords.append((last_end, s-1))
            intergenic_seqs.append(sequence[last_end-1:s-1])
        last_end = e + 1
    if len(sequence) - last_end + 1 >= min_length:
        intergenic_coords.append((last_end, len(sequence)))
        intergenic_seqs.append(sequence[last_end-1:])
    
    concatenated = ''.join(intergenic_seqs)
    return concatenated, intergenic_coords

def extract_non_orf_regions_conservative(sequence, all_orfs, min_rbs_threshold=3.0, min_length=150):
    """
    Extract non-ORF regions using RBS-filtering.
    Returns both concatenated sequence and coordinates.
    """
    filtered = [orf for orf in all_orfs if orf.get('rbs_score', 0) >= min_rbs_threshold]
    occupied = []
    for orf in filtered:
        start = orf.get('genome_start', orf['start'])
        end = orf.get('genome_end', orf['end'])
        if start > end:
            start, end = end, start
        occupied.append((start, end))
    
    merged = []
    for s, e in sorted(occupied):
        if merged and s <= merged[-1][1]:
            merged[-1] = (merged[-1][0], max(merged[-1][1], e))
        else:
            merged.append((s, e))
    
    non_orf_seqs = []
    non_orf_coords = []
    last_end = 1
    for s, e in merged:
        if s - last_end >= min_length:
            non_orf_coords.append((last_end, s-1))
            non_orf_seqs.append(sequence[last_end-1:s-1])
        last_end = e + 1
    if len(sequence) - last_end + 1 >= min_length:
        non_orf_coords.append((last_end, len(sequence)))
        non_orf_seqs.append(sequence[last_end-1:])
    
    concatenated = ''.join(non_orf_seqs)
    return concatenated, non_orf_coords


def extract_all_non_orf_regions(sequence, all_orfs, min_length=150):
    """
    Extract all non-ORF regions (no filtering).
    Returns both concatenated sequence and coordinates.
    """
    occupied = []
    for orf in all_orfs:
        start = orf.get('genome_start', orf['start'])
        end = orf.get('genome_end', orf['end'])
        if start > end:
            start, end = end, start
        occupied.append((start, end))
    
    merged = []
    for s, e in sorted(occupied):
        if merged and s <= merged[-1][1]:
            merged[-1] = (merged[-1][0], max(merged[-1][1], e))
        else:
            merged.append((s, e))
    
    non_orf_seqs = []
    non_orf_coords = []
    last_end = 1
    for s, e in merged:
        if s - last_end >= min_length:
            non_orf_coords.append((last_end, s-1))
            non_orf_seqs.append(sequence[last_end-1:s-1])
        last_end = e + 1
    if len(sequence) - last_end + 1 >= min_length:
        non_orf_coords.append((last_end, len(sequence)))
        non_orf_seqs.append(sequence[last_end-1:])
    
    concatenated = ''.join(non_orf_seqs)
    return concatenated, non_orf_coords

def merge_intervals(intervals):
    """Merge overlapping intervals."""
    if not intervals:
        return []
    intervals = sorted(intervals)
    merged = [intervals[0]]
    for s, e in intervals[1:]:
        if s <= merged[-1][1]:
            merged[-1] = (merged[-1][0], max(merged[-1][1], e))
        else:
            merged.append((s, e))
    return merged
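
The three extraction functions above share one pattern: merge the occupied intervals, then collect the gaps between them. A compact, self-contained version of that pattern with 1-based inclusive coordinates is sketched below; `gaps_between` is a hypothetical helper, not a function from this project.

```python
def merge_intervals(intervals):
    """Merge overlapping (start, end) intervals; coordinates are inclusive."""
    merged = []
    for start, end in sorted(intervals):
        if merged and start <= merged[-1][1]:
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            merged.append((start, end))
    return merged


def gaps_between(intervals, genome_length, min_length=1):
    """Return 1-based inclusive gaps of at least min_length bp not covered by any interval."""
    gaps = []
    last_end = 1  # first position not yet known to be covered
    for start, end in merge_intervals(intervals):
        if start - last_end >= min_length:
            gaps.append((last_end, start - 1))
        last_end = max(last_end, end + 1)
    if genome_length - last_end + 1 >= min_length:
        gaps.append((last_end, genome_length))
    return gaps


occupied = [(10, 20), (15, 30), (50, 60)]
gaps = gaps_between(occupied, genome_length=100, min_length=5)
```

A gap `(s, e)` in these coordinates corresponds to the Python slice `sequence[s - 1:e]`.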


In [76]:
def check_intergenic_purity_fast(intergenic_regions, ref_genes_df, genome_length, method_name):
    """Fast purity check using merged gene intervals for linear scan."""
    if not intergenic_regions:
        return {
            'method': method_name,
            'total_regions': 0,
            'pure_regions': 0,
            'contaminated_regions': 0,
            'total_bp': 0,
            'pure_bp': 0,
            'contaminated_bp': 0,
            'region_purity': 0.0,
            'bp_purity': 0.0,
            'pct_of_genome': 0.0
        }

    # 1. Merge overlapping gene intervals
    gene_intervals = sorted([(int(row['start']), int(row['end'])) for _, row in ref_genes_df.iterrows()])
    merged_genes = []
    for s, e in gene_intervals:
        if merged_genes and s <= merged_genes[-1][1]:
            merged_genes[-1] = (merged_genes[-1][0], max(merged_genes[-1][1], e))
        else:
            merged_genes.append((s, e))

    # 2. Scan intergenic regions against merged genes
    pure_regions = 0
    contaminated_regions = 0
    pure_bp = 0
    contaminated_bp = 0

    g_idx = 0
    for inter_start, inter_end in intergenic_regions:
        region_length = inter_end - inter_start + 1
        total_bp_in_region = region_length
        region_pure_bp = 0
        overlap_found = False

        region_start = inter_start
        while g_idx < len(merged_genes) and merged_genes[g_idx][1] < inter_start:
            g_idx += 1
        temp_idx = g_idx
        while temp_idx < len(merged_genes) and merged_genes[temp_idx][0] <= inter_end:
            overlap_found = True
            overlap_start = max(region_start, merged_genes[temp_idx][0])
            overlap_end = min(inter_end, merged_genes[temp_idx][1])
            region_pure_bp += overlap_start - region_start  # add pure part before overlap
            region_start = overlap_end + 1
            temp_idx += 1
        if region_start <= inter_end:
            region_pure_bp += inter_end - region_start + 1

        pure_bp += region_pure_bp
        contaminated_bp += total_bp_in_region - region_pure_bp
        if overlap_found:
            contaminated_regions += 1
        else:
            pure_regions += 1

    total_regions = len(intergenic_regions)
    total_bp = sum(end - start + 1 for start, end in intergenic_regions)

    region_purity = pure_regions / total_regions * 100 if total_regions else 0
    bp_purity = pure_bp / total_bp * 100 if total_bp else 0
    pct_of_genome = total_bp / genome_length * 100 if genome_length else 0

    return {
        'method': method_name,
        'total_regions': total_regions,
        'pure_regions': pure_regions,
        'contaminated_regions': contaminated_regions,
        'total_bp': total_bp,
        'pure_bp': pure_bp,
        'contaminated_bp': contaminated_bp,
        'region_purity': region_purity,
        'bp_purity': bp_purity,
        'pct_of_genome': pct_of_genome
    }

def process_single_genome_intergenic(genome_id, cached_data):
    start_time = time.time()
    
    try:
        genome_data = cached_data.get(genome_id)
        if genome_data is None:
            raise ValueError(f"No precomputed ORFs found for {genome_id}")

        all_orfs = genome_data['orfs']
        gff_path = get_gff_path(genome_id)
        fasta_path = get_fasta_path(genome_id)
        genome_sequence = load_genome_sequence(fasta_path)
        genome_sequence = genome_sequence['sequence']

        # Load reference genes from cached GFF
        ref = pd.read_csv(gff_path, sep="\t", comment="#", header=None)
        if (ref[2] == "CDS").sum() > 0:
            ref_genes = ref[ref[2] == "CDS"][[3, 4]]
        else:
            ref_genes = ref[ref[2] == "gene"][[3, 4]]
        ref_genes.columns = ["start", "end"]

        genome_length = len(genome_sequence)
        intergenic_results = {}

        # -------------------
        # Glimmer-style
        # -------------------
        likely_genes = [orf for orf in all_orfs if orf['length'] >= 200]
        intergenic_seq_1, intergenic_coords_1 = extract_intergenic_regions(
            genome_sequence, likely_genes, buffer=50, min_length=150
        )
        intergenic_results['glimmer'] = check_intergenic_purity_fast(intergenic_coords_1, ref_genes, genome_length, "Glimmer")

        # -------------------
        # RBS-filtered
        # -------------------
        intergenic_seq_2, intergenic_coords_2 = extract_non_orf_regions_conservative(
            genome_sequence, all_orfs, min_rbs_threshold=3.0, min_length=150
        )
        intergenic_results['rbs_filtered'] = check_intergenic_purity_fast(intergenic_coords_2, ref_genes, genome_length, "RBS-Filtered")

        # -------------------
        # All non-ORF
        # -------------------
        intergenic_seq_3, intergenic_coords_3 = extract_all_non_orf_regions(
            genome_sequence, all_orfs, min_length=150
        )
        intergenic_results['all_nonorf'] = check_intergenic_purity_fast(intergenic_coords_3, ref_genes, genome_length, "All-NonORF")

        # -------------------
        # Union of all three
        # -------------------
        all_union_coords = merge_intervals(intergenic_coords_1 + intergenic_coords_2 + intergenic_coords_3)
        intergenic_results['union'] = check_intergenic_purity_fast(all_union_coords, ref_genes, genome_length, "Union")

        # -------------------
        # Flatten results
        # -------------------
        results = {
            'genome_id': genome_id,
            'total_orfs': len(all_orfs),
            'likely_genes': len(likely_genes),
            'processing_time': time.time() - start_time,
            'status': 'success'
        }

        for name, stats in intergenic_results.items():
            for key, value in stats.items():
                if key != 'method':
                    results[f'{name}_{key}'] = value

        return results

    except Exception as e:
        return {
            'genome_id': genome_id,
            'status': 'failed',
            'error': str(e),
            'processing_time': time.time() - start_time
        }

def run_intergenic_test_precomputed_sequential(genomes, cached_candidates):
    test_start = time.time()
    
    print("="*80)
    print("SEQUENTIAL INTERGENIC PURITY TEST (PRECOMPUTED ORFs)")
    print("="*80)
    print(f"Started: {datetime.now().strftime('%H:%M:%S')}")
    print(f"Processing {len(genomes)} genome(s) sequentially")
    print("="*80)
    
    all_results = []

    for genome_id in genomes:  # avoid shadowing the built-in id()
        result = process_single_genome_intergenic(genome_id, cached_candidates)
        all_results.append(result)
        status = "✓" if result.get('status') == 'success' else "✗"
        print(f"  {status} {genome_id} ({result.get('processing_time',0):.1f}s) [{len(all_results)}/{len(genomes)}]")

    successful_results = [r for r in all_results if r.get('status') == 'success']
    if not successful_results:
        print("\n✗ ERROR: All genomes failed to process!")
        return None, None

    results_df = pd.DataFrame(successful_results)

    summary_data = []
    for _, row in results_df.iterrows():
        for strategy in ['glimmer', 'rbs_filtered', 'all_nonorf', 'union']:
            summary_data.append({
                'Strategy': strategy.replace('_',' ').title(),
                'Regions': row[f'{strategy}_total_regions'],
                'Pure BP': row[f'{strategy}_pure_bp'],
                'BP Purity (%)': row[f'{strategy}_bp_purity'],
                'Pct of Genome (%)': row[f'{strategy}_pct_of_genome']
            })
    summary_df = pd.DataFrame(summary_data)

    print(f"\n{'='*80}")
    print("FINAL INTERGENIC PURITY RESULTS")
    print(f"{'='*80}\n")
    print(summary_df.to_string(index=False))


    return summary_df, results_df

def process_single_genome_precomputed(genome_id, cached_data, use_entropy=False):
    """
    Process a single genome with precomputed ORFs.
    
    Parameters:
    -----------
    genome_id : str
        Genome identifier
    cached_data : dict
        Cached ORF data
    use_entropy : bool, default=False
        If True, includes entropy-based filtering strategies
        If False, only uses strategies without entropy
    """
    start_time = time.time()
    
    try:
        # Use precomputed ORFs and GFF path
        genome_data = cached_data.get(genome_id)
        if genome_data is None:
            raise ValueError(f"No precomputed ORFs found for {genome_id}")
        
        all_orfs = genome_data['orfs']
        gff_path = get_gff_path(genome_id)

        # Define strategies based on use_entropy flag
        strategies = {}
        sets_to_intersect = []
        
        # Always include glimmer_pure
        glimmer_pure = select_training_glimmer(all_orfs, min_length=300, max_training_size=2000)
        strategies['glimmer_pure'] = glimmer_pure
        sets_to_intersect.append(
            set((orf.get('genome_start', orf['start']), orf.get('genome_end', orf['end'])) 
                for orf in glimmer_pure)
        )
        
        if use_entropy:
            # Pre-calculate entropy for long ORFs (>=300)
            long_orfs = [orf for orf in all_orfs if orf['length'] >= 300]
            for orf in long_orfs:
                calculate_amino_acid_entropy(orf)
            
            # Include entropy-based strategies
            glimmer_3_0 = select_training_glimmer_3(
                all_orfs, min_length=300, max_entropy=4.2, max_training_size=2000
            )
            enhanced_flexible = select_training_flexible(
                all_orfs, target_size=2000, min_length=300, max_length=20000,
                max_overlap_fraction=0.3, use_entropy_filter=True, max_entropy=4.2, prefer_atg=True
            )
            
            strategies['glimmer_3_0'] = glimmer_3_0
            strategies['enhanced_flexible'] = enhanced_flexible
            
            sets_to_intersect.append(
                set((orf.get('genome_start', orf['start']), orf.get('genome_end', orf['end'])) 
                    for orf in glimmer_3_0)
            )
            sets_to_intersect.append(
                set((orf.get('genome_start', orf['start']), orf.get('genome_end', orf['end'])) 
                    for orf in enhanced_flexible)
            )
        
        # Always include enhanced_flexible_no_entropy
        enhanced_flexible_no_entropy = select_training_flexible(
            all_orfs, target_size=2000, min_length=300, max_length=20000,
            max_overlap_fraction=0.3, use_entropy_filter=False, prefer_atg=True
        )
        strategies['enhanced_flexible_no_entropy'] = enhanced_flexible_no_entropy
        sets_to_intersect.append(
            set((orf.get('genome_start', orf['start']), orf.get('genome_end', orf['end'])) 
                for orf in enhanced_flexible_no_entropy)
        )
        
        # Intersection of all strategies
        intersection_coords = set.intersection(*sets_to_intersect)
        intersection_all = [
            orf for orf in all_orfs 
            if (orf.get('genome_start', orf['start']), orf.get('genome_end', orf['end'])) 
            in intersection_coords
        ]
        strategies['intersection'] = intersection_all
        
        # Load reference genes from cached GFF
        ref = pd.read_csv(gff_path, sep="\t", comment="#", header=None)
        if (ref[2] == "CDS").sum() > 0:
            ref_genes = ref[ref[2] == "CDS"][[3, 4]]
        else:
            ref_genes = ref[ref[2] == "gene"][[3, 4]]
        ref_genes.columns = ["start", "end"]

        # Calculate purity
        results = {
            'genome_id': genome_id,
            'total_orfs': len(all_orfs),
            'ref_genes': len(ref_genes),
            'processing_time': time.time() - start_time,
            'status': 'success',
            'use_entropy': use_entropy
        }

        for strategy_name, training_set in strategies.items():
            training_coords = pd.DataFrame({
                'start': [orf.get('genome_start', orf['start']) for orf in training_set],
                'end': [orf.get('genome_end', orf['end']) for orf in training_set]
            })
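            # Deduplicate reference coordinates so a repeated GFF row cannot
            # count the same ORF twice and inflate purity
            matches = pd.merge(training_coords, ref_genes.drop_duplicates(), on=['start', 'end'])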
            true_genes = len(matches)
            purity = (true_genes / len(training_set) * 100) if len(training_set) > 0 else 0

            results[f'{strategy_name}_size'] = len(training_set)
            results[f'{strategy_name}_true_genes'] = true_genes
            results[f'{strategy_name}_purity'] = purity

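        # Record timing after purity evaluation so it covers the whole run
        results['processing_time'] = time.time() - start_time
        return results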

    except Exception as e:
        return {
            'genome_id': genome_id,
            'status': 'failed',
            'error': str(e),
            'processing_time': time.time() - start_time
        }

def run_genome_test_precomputed_sequential(genomes, cached_data, use_entropy=False):
    """
    Run sequential genome-level purity test with precomputed ORFs.
    
    Parameters:
    -----------
    genomes : list
        List of genome IDs to process
    cached_data : dict
        Cached ORF data
    use_entropy : bool, default=False
        If True, includes entropy-based filtering strategies
        If False, only uses strategies without entropy
    """
    test_start = time.time()
    
    entropy_status = "WITH ENTROPY" if use_entropy else "WITHOUT ENTROPY"
    print("="*80)
    print(f"SEQUENTIAL GENOME-LEVEL PURITY TEST (PRECOMPUTED ORFs) - {entropy_status}")
    print("="*80)
    print(f"Started: {datetime.now().strftime('%H:%M:%S')}")
    print(f"Processing {len(genomes)} genome(s) sequentially")
    print("="*80)
    
    all_results = []

    # Sequential execution
    for genome_id in genomes:
        result = process_single_genome_precomputed(genome_id, cached_data, use_entropy=use_entropy)
        all_results.append(result)
        status = "✓" if result.get('status') == 'success' else "✗"
        print(f"  {status} {genome_id} ({result.get('processing_time',0):.1f}s) [{len(all_results)}/{len(genomes)}]")

    total_time = time.time() - test_start

    successful_results = [r for r in all_results if r.get('status') == 'success']
    if not successful_results:
        print("\n✗ ERROR: All genomes failed to process!")
        return None, None

    results_df = pd.DataFrame(successful_results)

    # Determine which strategies to show
    if use_entropy:
        strategy_names = ['glimmer_pure', 'glimmer_3_0', 'enhanced_flexible', 
                         'enhanced_flexible_no_entropy', 'intersection']
    else:
        strategy_names = ['glimmer_pure', 'enhanced_flexible_no_entropy', 'intersection']

    # Summary table
    summary_data = []
    for _, row in results_df.iterrows():
        for strategy in strategy_names:
            if f'{strategy}_size' in row:
                summary_data.append({
                    'Strategy': strategy.replace('_', ' ').title(),
                    'Training Size': row[f'{strategy}_size'],
                    'True Genes': row[f'{strategy}_true_genes'],
                    'Purity (%)': row[f'{strategy}_purity']
                })
    summary_df = pd.DataFrame(summary_data)

    # Display results
    print(f"\n{'='*80}")
    print("FINAL RESULTS TABLE")
    print(f"{'='*80}\n")
    print(summary_df.to_string(index=False))

    # Average purity
    print(f"\n{'='*80}")
    print("AVERAGE PURITY BY STRATEGY")
    print(f"{'='*80}")
    avg_purity = summary_df.groupby('Strategy')['Purity (%)'].mean().sort_values(ascending=False)
    for strategy, purity in avg_purity.items():
        symbol = "🏆" if purity == avg_purity.max() else "✓"
        print(f"  {symbol} {strategy:<25} {purity:>6.2f}%")

    # Performance metrics
    print(f"\n{'='*80}")
    print("PERFORMANCE METRICS")
    print(f"{'='*80}")
    avg_processing_time = results_df['processing_time'].mean()
    print(f"  Average per-genome time: {avg_processing_time:.1f}s")
    print(f"  Total wall-clock time: {total_time:.1f}s ({total_time/60:.1f} min)")
    print(f"  Genomes processed: {len(successful_results)}/{len(genomes)}")
    print(f"  Entropy filtering: {'Enabled' if use_entropy else 'Disabled'}")

    return summary_df, results_df
'''
# Test WITHOUT entropy
summary_no_entropy, results_no_entropy = run_genome_test_precomputed_sequential(TEST_GENOMES, cached_data, use_entropy=False)
# Test WITH entropy
summary_with_entropy, results_with_entropy = run_genome_test_precomputed_sequential(TEST_GENOMES, cached_data, use_entropy=True)
# Test INTERGENIC
summary_df, results_df = run_intergenic_test_precomputed_sequential(TEST_GENOMES, cached_data)
'''
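
The purity figures reported by these tests come from an exact `(start, end)` join between training ORFs and reference annotations. The following is a minimal sketch with toy coordinates (all values are made up) showing how that join behaves. It also shows one pitfall: an inner merge returns one row per matching pair, so a reference table with repeated coordinates can count the same ORF more than once.

```python
import pandas as pd

# Toy training set and reference annotation (illustrative coordinates only)
training = pd.DataFrame({'start': [100, 500, 900], 'end': [400, 800, 1200]})
# The reference lists (100, 400) twice, as could happen with a repeated GFF row
reference = pd.DataFrame({'start': [100, 100, 900, 2000],
                          'end': [400, 400, 1200, 2500]})

# Inner join on exact coordinates: one output row per matching pair
naive_matches = len(pd.merge(training, reference, on=['start', 'end']))

# Dropping duplicate reference rows first gives one row per matched ORF
dedup_matches = len(pd.merge(training, reference.drop_duplicates(), on=['start', 'end']))

purity = dedup_matches / len(training) * 100
print(f"naive={naive_matches}, deduplicated={dedup_matches}, purity={purity:.1f}%")
```

Because the match is exact, an ORF whose stop codon is correct but whose start is off by one codon counts as a miss under this definition.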


In [77]:
def create_training_set(genome_id, cached_candidates, glimmer_max_size=2000, flexible_target_size=2000):
    """
    Create a training set by intersecting Glimmer and Flexible (no-entropy) selections
    for a given genome using precomputed ORFs from cached_candidates.

    Args:
        genome_id (str): Genome identifier
        cached_candidates (dict): Precomputed ORFs for all genomes
        glimmer_max_size (int): Maximum size for Glimmer training set
        flexible_target_size (int): Target size for Flexible training set

    Returns:
        list: ORFs belonging to the intersection training set
    """
    genome_data = cached_candidates.get(genome_id)
    if genome_data is None:
        raise ValueError(f"No precomputed ORFs found for genome_id {genome_id}")

    all_orfs = genome_data['orfs']

    # --- Step 1: Glimmer selection (pure, no entropy) ---
    glimmer_set = select_training_glimmer(all_orfs, min_length=300, max_training_size=glimmer_max_size)

    # --- Step 2: Flexible selection (no entropy) ---
    flexible_set = select_training_flexible(
        all_orfs, target_size=flexible_target_size, min_length=300, max_length=20000,
        max_overlap_fraction=0.3, use_entropy_filter=False, prefer_atg=True
    )

    # --- Step 3: Intersection based on coordinates ---
    glimmer_coords = set((orf.get('genome_start', orf['start']),
                          orf.get('genome_end', orf['end'])) for orf in glimmer_set)
    flexible_coords = set((orf.get('genome_start', orf['start']),
                           orf.get('genome_end', orf['end'])) for orf in flexible_set)

    intersection_coords = glimmer_coords & flexible_coords

    intersection_orfs = [orf for orf in all_orfs
                         if (orf.get('genome_start', orf['start']),
                             orf.get('genome_end', orf['end'])) in intersection_coords]

    return intersection_orfs

def create_intergenic_set(genome_id, cached_candidates, buffer=50, min_length=150, min_rbs_threshold=3.0):
    """
    Create intergenic regions by taking the union of three extraction strategies:
    - Glimmer-like: regions outside likely genes (ORFs with length >= 200), with a buffer
    - RBS-filtered (conservative) non-ORF regions
    - All non-ORF regions

    Args:
        genome_id (str): Genome identifier
        cached_candidates (dict): Precomputed ORFs and genome sequences
        buffer (int): Buffer around ORFs for Glimmer-like extraction
        min_length (int): Minimum intergenic region length
        min_rbs_threshold (float): Minimum RBS score for conservative extraction

    Returns:
        list of dicts: intergenic regions, each with start, end, length, and sequence
    """
    genome_data = cached_candidates.get(genome_id)
    if genome_data is None:
        raise ValueError(f"No precomputed ORFs found for genome {genome_id}")

    genome_sequence = genome_data['sequence']
    all_orfs = genome_data['orfs']

    # Step 1: Identify likely coding regions (for exclusion)
    likely_genes = [orf for orf in all_orfs if orf['length'] >= 200]

    # Step 2: Extract intergenic regions from multiple strategies
    _, intergenic_coords_1 = extract_intergenic_regions(
        genome_sequence, likely_genes, buffer=buffer, min_length=min_length
    )
    _, intergenic_coords_2 = extract_non_orf_regions_conservative(
        genome_sequence, all_orfs, min_rbs_threshold=min_rbs_threshold, min_length=min_length
    )
    _, intergenic_coords_3 = extract_all_non_orf_regions(
        genome_sequence, all_orfs, min_length=min_length
    )

    # Step 3: Merge all coordinates into a single union
    all_union_coords = merge_intervals(intergenic_coords_1 + intergenic_coords_2 + intergenic_coords_3)

    # Step 4: Build dictionary objects similar to ORFs
    intergenic_regions = []
    for start, end in all_union_coords:
        seq = genome_sequence[start-1:end]
        intergenic_regions.append({
            'start': start,
            'end': end,
            'length': len(seq),
            'sequence': seq,
            'type': 'intergenic'
        })

    return intergenic_regions
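

The union step in `create_intergenic_set` relies on `merge_intervals`, which is defined elsewhere in the notebook. As a minimal sketch of what such a merge typically does, assuming 1-based inclusive `(start, end)` tuples and that overlapping or directly adjacent intervals should be fused (the name `merge_intervals_sketch` is hypothetical and is not the notebook's implementation):

```python
def merge_intervals_sketch(intervals):
    """Merge overlapping or adjacent 1-based inclusive (start, end) intervals."""
    merged = []
    for start, end in sorted(intervals):
        if merged and start <= merged[-1][1] + 1:
            # Overlaps or touches the previous interval: extend it
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            merged.append((start, end))
    return merged

# Toy coordinates standing in for the three strategy outputs concatenated
coords = [(100, 250), (200, 400), (401, 450), (900, 1000)]
merged_coords = merge_intervals_sketch(coords)
print(merged_coords)
```

Whether adjacent intervals such as `(200, 400)` and `(401, 450)` should be fused is a design choice. Fusing them avoids splitting one contiguous intergenic stretch into two training sequences.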


In [78]:

def build_codon_model(sequences):
    """
    Build a species-specific codon usage model from a list of sequences.
    
    Args:
        sequences (list of dict): High-confidence regions (training ORFs or
            intergenic regions), each with a 'sequence' key
        
    Returns:
        dict: codon -> frequency mapping
    """
    codon_counts = Counter()
    total_codons = 0

    for seq in sequences:
        sequence = seq['sequence']
        for i in range(0, len(sequence) - 2, 3):
            codon = sequence[i:i+3]
            if len(codon) == 3 and 'N' not in codon:
                codon_counts[codon] += 1
                total_codons += 1

    if total_codons == 0:
        return {}

    frequencies = {codon: count / total_codons for codon, count in codon_counts.items()}

    return frequencies

def prepare_models(training_orfs, intergenic_regions):
    """
    Prepare both coding and background (noncoding) codon usage models.

    Args:
        training_orfs (list of dict): High-confidence ORFs with 'sequence' key
        intergenic_regions (list of dict): Intergenic regions with 'sequence' key

    Returns:
        tuple: (codon_model, background_codon_model)
            codon_model     -> dict: codon -> frequency
            background_codon_model -> dict: codon -> frequency
    """
    codon_model = build_codon_model(training_orfs)
    background_codon_model = build_codon_model(intergenic_regions)

    return codon_model, background_codon_model

def score_codon_bias_ratio(orf_sequence, codon_model, background_codon_model):
    """
    Score ORF by comparing coding vs background codon usage.
    
    Args:
        orf_sequence (str): Sequence to score
        codon_model (dict): Codon frequencies from genes
        background_codon_model (dict): Codon frequencies from intergenic regions
        
    Returns:
        float: Log ratio score (positive = more gene-like)
    """
    if len(orf_sequence) < 3:
        return 0.0
    
    coding_score = 0.0
    background_score = 0.0
    codon_count = 0
    
    for i in range(0, len(orf_sequence) - 2, 3):
        codon = orf_sequence[i:i+3]
        if len(codon) == 3 and 'N' not in codon:
            # Get frequencies (with small pseudocount for unseen codons)
            coding_freq = codon_model.get(codon, 0.0001)
            background_freq = background_codon_model.get(codon, 0.0001)
            
            coding_score += math.log(coding_freq)
            background_score += math.log(background_freq)
            codon_count += 1
    
    if codon_count == 0:
        return 0.0
    
    # Return log likelihood ratio
    return (coding_score - background_score) / codon_count
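

To make the codon bias score concrete, here is a small worked sketch of the same per-codon log-likelihood ratio. The function name `codon_llr_sketch` and the toy frequencies are assumptions for illustration only. They are not measured from any genome.

```python
import math

def codon_llr_sketch(seq, coding, background, floor=1e-4):
    """Mean per-codon log-likelihood ratio; unseen codons get a floor frequency."""
    codons = [seq[i:i + 3] for i in range(0, len(seq) - 2, 3)]
    codons = [c for c in codons if 'N' not in c]
    if not codons:
        return 0.0
    total = sum(math.log(coding.get(c, floor) / background.get(c, floor))
                for c in codons)
    return total / len(codons)

# Toy frequencies (made up): GCG is 4x enriched in coding, AAA is neutral
coding = {'ATG': 0.03, 'GCG': 0.04, 'AAA': 0.03}
background = {'ATG': 0.02, 'GCG': 0.01, 'AAA': 0.03}

score = codon_llr_sketch("ATGGCGAAA", coding, background)
print(f"mean log ratio per codon: {score:.3f}")
```

A positive value means the codons are, on average, more frequent in the coding model than in the background. Note that the fixed floor is not a true pseudocount: frequencies are not renormalized after it is applied, so it mainly serves to keep the logarithm finite.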


In [79]:
# Attempt: new IMM with codon-frame awareness

def build_interpolated_markov_model(training_sequences, max_order, min_observations=10):
    """
    Build frame-aware IMM from training sequences.
    Internally creates 3 position-specific models but API stays the same!
    
    Args:
        training_sequences (list): List of DNA sequences (must start at codon position 0)
        max_order (int): Maximum context length
        min_observations (int): Minimum observations for reliable probability
        
    Returns:
        list: [model_pos0, model_pos1, model_pos2] - 3 position-specific models
              (internally frame-aware, but returned as list to maintain compatibility)
    """
    
    # Initialize 3 models, one for each codon position
    position_models = [
        defaultdict(lambda: defaultdict(int)),  # Position 0 (1st base of codon)
        defaultdict(lambda: defaultdict(int)),  # Position 1 (2nd base of codon)
        defaultdict(lambda: defaultdict(int))   # Position 2 (3rd base of codon)
    ]
    
    # Build counts for each position
    for sequence in training_sequences:
        for i in range(len(sequence)):
            nucleotide = sequence[i]
            codon_position = i % 3  # Which position in the codon (0, 1, or 2)
            
            # Build context with different orders
            for order in range(min(i + 1, max_order + 1)):
                if order == 0:
                    context = ""  # 0th order - no context
                else:
                    context = sequence[i-order:i]  # Previous 'order' nucleotides
                
                # Add to the model for this codon position
                position_models[codon_position][context][nucleotide] += 1
    
    # Convert counts to probabilities for each position
    position_probabilities = []
    
    for pos in range(3):
        probabilities = {}
        for context, nucleotide_counts in position_models[pos].items():
            total_count = sum(nucleotide_counts.values())
            
            if total_count >= min_observations:  # Only if enough data
                probabilities[context] = {}
                for nucleotide, count in nucleotide_counts.items():
                    probabilities[context][nucleotide] = count / total_count
        
        position_probabilities.append(probabilities)
    
    return position_probabilities

def get_interpolated_probability(nucleotide, context, probabilities, fallback_prob=0.25):
    """
    Get probability using longest reliable context.
    Works with either a single frame-agnostic model or one of the
    position-specific models; in both cases `probabilities` is a dict.
    
    Args:
        nucleotide (str): Nucleotide to predict (A, T, G, C)
        context (str): Context sequence (previous nucleotides)
        probabilities (dict): Probability model for this position
        fallback_prob (float): Default probability if no context found
        
    Returns:
        float: Probability of nucleotide given context
    """
    # Try longest context first, fall back to shorter contexts
    for order in range(len(context), -1, -1):
        current_context = context[-order:] if order > 0 else ""
        
        if current_context in probabilities:
            if nucleotide in probabilities[current_context]:
                return probabilities[current_context][nucleotide]
    
    # Ultimate fallback - uniform probability
    return fallback_prob

def score_imm_ratio(sequence, coding_imm, noncoding_imm, max_order):
    """
    Score sequence using frame-aware IMM log-likelihood ratio.
    
    This function automatically detects if models are frame-aware (list of 3)
    or frame-agnostic (single dict) and handles both!
    
    Args:
        sequence (str): DNA sequence to score (should start at codon position 0)
        coding_imm: Coding model (list of 3 dicts or single dict)
        noncoding_imm: Noncoding model (list of 3 dicts or single dict)
        max_order (int): Maximum context length
        
    Returns:
        float: Log likelihood ratio (positive = more coding-like)
    """
    if len(sequence) < 3:
        return 0.0
    
    # Auto-detect if frame-aware (list of 3) or frame-agnostic (single dict)
    is_frame_aware = isinstance(coding_imm, list) and len(coding_imm) == 3
    
    coding_log_prob = 0.0
    noncoding_log_prob = 0.0
    
    # Pseudocount to prevent log(0)
    EPSILON = 1e-10
    
    for i in range(len(sequence)):
        nucleotide = sequence[i]
        
        # Get context (up to max_order previous nucleotides)
        context_start = max(0, i - max_order)
        context = sequence[context_start:i]
        
        if is_frame_aware:
            # Frame-aware: use position-specific model
            codon_position = i % 3  # Which position in codon (0, 1, or 2)
            coding_prob = get_interpolated_probability(
                nucleotide, context, coding_imm[codon_position]
            )
            noncoding_prob = get_interpolated_probability(
                nucleotide, context, noncoding_imm[codon_position]
            )
        else:
            # Frame-agnostic: use single model
            coding_prob = get_interpolated_probability(nucleotide, context, coding_imm)
            noncoding_prob = get_interpolated_probability(nucleotide, context, noncoding_imm)
        
        # Guard against log(0)
        coding_prob = max(coding_prob, EPSILON)
        noncoding_prob = max(noncoding_prob, EPSILON)
        
        # Add to log probabilities
        coding_log_prob += math.log(coding_prob)
        noncoding_log_prob += math.log(noncoding_prob)
    
    # Return log likelihood ratio normalized by sequence length
    return (coding_log_prob - noncoding_log_prob) / len(sequence)

def prepare_imm_models(training_set, intergenic_set, min_observations=10):
    """
    Prepare frame-aware coding and noncoding IMM models.
    Assumes both sets contain dicts with 'sequence' fields.
    """
    # Extract sequences directly
    training_seqs = [orf['sequence'] for orf in training_set]
    intergenic_seqs = [orf['sequence'] for orf in intergenic_set]
    
    # Compute total sequence lengths
    n_training = sum(len(seq) for seq in training_seqs)
    n_intergenic = sum(len(seq) for seq in intergenic_seqs)
    effective_n = min(n_training, n_intergenic)
    
    # Estimate max order based on available data
    if effective_n < min_observations:
        estimated_order = 0
    else:
        estimated_order = math.floor(math.log2(effective_n / min_observations) / 2)
    
    # Clip to reasonable bacterial range
    estimated_order = min(estimated_order, 8)
    estimated_order = max(estimated_order, 3)
    
    # Build frame-aware IMM models
    coding_imm = build_interpolated_markov_model(training_seqs, estimated_order, min_observations)
    noncoding_imm = build_interpolated_markov_model(intergenic_seqs, estimated_order, min_observations)
    
    return coding_imm, noncoding_imm, estimated_order
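

One way to read the order estimate in `prepare_imm_models`: an order-k model has 4^k possible contexts, so requiring roughly `min_observations` per context gives k <= log4(n / min_observations) = log2(n / min_observations) / 2. The sketch below reproduces that arithmetic in isolation. The function name `estimate_order_sketch` and the example sizes are hypothetical.

```python
import math

def estimate_order_sketch(effective_n, min_observations=10, lo=3, hi=8):
    """Estimate the IMM order from data size, then clip it to [lo, hi]."""
    if effective_n < min_observations:
        k = 0
    else:
        # log base 4 of (n / min_observations), written with log2
        k = math.floor(math.log2(effective_n / min_observations) / 2)
    return max(lo, min(k, hi))

# Example sizes in nucleotides (illustrative, not taken from a real genome)
for n in (500, 100_000, 2_000_000):
    print(f"{n:>9,} nt -> order {estimate_order_sketch(n)}")
```

The lower clip means that a very small training set still requests order 3. In that case many high-order contexts fall below `min_observations` and are dropped when the model is built, so scoring falls back to shorter contexts.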



In [80]:
def build_all_scoring_models(training_set, intergenic_set, min_observations=10):
    """
    Build all scoring models needed for ORF evaluation.
    
    Args:
        training_set (list of dict): Training ORFs, each with a 'sequence' key
        intergenic_set (list of dict): Intergenic regions, each with a 'sequence' key
        min_observations (int): Minimum observations for IMM contexts
        
    Returns:
        dict: All models needed for scoring:
            - codon_model
            - background_codon_model
            - coding_imm
            - noncoding_imm
            - max_order
    """
    print("Building scoring models...")
    start_time = time.time()
    
    print("  Building codon usage models...")
    codon_model, background_codon_model = prepare_models(training_set, intergenic_set)

    print("  Building IMM models...")
    coding_imm, noncoding_imm, max_order = prepare_imm_models(training_set, intergenic_set, min_observations=min_observations)
    
    print(f"✓ All models built in {time.time() - start_time:.1f}s")
    
    return {
        'codon_model': codon_model,
        'background_codon_model': background_codon_model,
        'coding_imm': coding_imm,
        'noncoding_imm': noncoding_imm,
        'max_order': max_order
    }


# ORF Scoring System

## Goal
Score all 176,315 ORF candidates using multiple evidence-based methods and rank them to identify real genes.

## Scoring Methods
1. **Codon Usage Bias** - Compare to coding vs non-coding patterns
2. **Interpolated Markov Model** - Nucleotide context patterns 
3. **Ribosome Binding Site** - Translation initiation signals 
4. **ORF Length** - Log-scaled length relative to a reference length
5. **Start Codon** - Frequency-based weighting

## Process
1. Score each ORF with all methods (raw scores)
2. Normalize scores to prevent range dominance
3. Combine with weighted sum
4. Rank all ORFs

## Key Improvement
**Normalization** prevents RBS (range -5 to 15) from dominating IMM (range -0.2 to 0.13)
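
A minimal sketch of why this matters, using toy scores that span the two ranges quoted above (the individual values are illustrative and the helper name `zscore` is an assumption). In the raw sum the IMM term can shift the total by at most about 0.3, while after z-scoring both components have unit standard deviation.

```python
import numpy as np

def zscore(values):
    """Standardize to mean 0 and std 1; constant input maps to all zeros."""
    values = np.asarray(values, dtype=float)
    std = values.std()
    return np.zeros_like(values) if std == 0 else (values - values.mean()) / std

# Toy raw scores for four ORFs
rbs_raw = np.array([-5.0, 0.0, 5.0, 15.0])
imm_raw = np.array([-0.20, -0.05, 0.05, 0.13])

raw_sum = rbs_raw + imm_raw                    # dominated by the RBS scale
norm_sum = zscore(rbs_raw) + zscore(imm_raw)   # both components contribute

print("raw sum :", np.round(raw_sum, 2))
print("norm sum:", np.round(norm_sum, 2))
```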

In [81]:
def score_orf_length(orf_length):
    """
    Score ORF length using log scale.
    
    Args:
        orf_length (int): Length of ORF in base pairs
        
    Returns:
        float: Length score
    """
    return math.log(max(orf_length, MIN_ORF_LENGTH) / LENGTH_REFERENCE_BP)

def score_start_codon(start_codon):
    """
    Score start codon based on frequency in bacterial genomes.
    """
    return START_CODON_WEIGHTS.get(start_codon, 0.4)
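

As a quick illustration of the log-scaled length score, the sketch below uses placeholder constants. The real `MIN_ORF_LENGTH` and `LENGTH_REFERENCE_BP` come from `src.config` and are not shown in this notebook, so the two values here are assumptions.

```python
import math

# Hypothetical values; the notebook imports the real ones from src.config
MIN_ORF_LENGTH_EXAMPLE = 100
LENGTH_REFERENCE_BP_EXAMPLE = 1000

def length_score_sketch(orf_length):
    """Log ratio of ORF length to a reference length, floored at the minimum."""
    clamped = max(orf_length, MIN_ORF_LENGTH_EXAMPLE)
    return math.log(clamped / LENGTH_REFERENCE_BP_EXAMPLE)

for length in (60, 300, 1000, 3000):
    print(f"{length:>5} bp -> {length_score_sketch(length):+.3f}")
```

With this form, an ORF at the reference length scores 0, shorter ORFs score negative, and every ORF below the minimum length gets the same floor score.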


In [82]:
def score_all_orfs(all_orfs, models):
    """
    Score all ORFs with all methods using pre-built models.
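    Adds codon, IMM, length, and start-codon scores to each ORF dict.
    RBS scores are not computed here; normalize_all_orf_scores() expects
    an 'rbs_score' field to already be present on each ORF.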
    
    Args:
        all_orfs (list): List of ORF dictionaries
        models (dict): Pre-built models from build_all_scoring_models()
        
    Returns:
        list: ORFs with scores added
    """
    print(f"Scoring {len(all_orfs):,} ORFs...")
    start_time = time.time()
    
    codon_model = models['codon_model']
    background_codon_model = models['background_codon_model']
    coding_imm = models['coding_imm']
    noncoding_imm = models['noncoding_imm']
    max_order = models['max_order']
    
    for i, orf in enumerate(all_orfs):
        if i % 25000 == 0 and i > 0:
            print(f"  {i:,}...")
        
        # Score each component
        orf['codon_score'] = score_codon_bias_ratio(
            orf['sequence'], 
            codon_model, 
            background_codon_model
        )
        
        orf['imm_score'] = score_imm_ratio(
            orf['sequence'], 
            coding_imm, 
            noncoding_imm,
            max_order
        )
        # Length score
        orf['length_score'] = score_orf_length(orf['length'])
        
        # Start codon score
        orf['start_score'] = score_start_codon(orf.get('start_codon', 'ATG'))
    
    print(f"Done in {(time.time() - start_time)/60:.1f} minutes")
    return all_orfs

# Normalization functions

In [83]:
def normalize_scores_zscore(scores):
    """
    Normalize scores using z-score (standard score).
    
    Formula: z = (x - mean) / std
    
    Args:
        scores (list or array): Raw scores
        
    Returns:
        numpy.ndarray: Normalized scores (mean=0, std=1)
    """
    # Cast to float so integer input does not yield an integer zero array
    scores = np.asarray(scores, dtype=float)
    mean = np.mean(scores)
    std = np.std(scores)
    
    if std == 0:  # All scores identical
        return np.zeros_like(scores)
    
    return (scores - mean) / std

def normalize_all_orf_scores(scored_orfs):
    """
    Normalize all score components across all ORFs.
    
    Adds normalized versions to each ORF:
    - codon_score_norm
    - imm_score_norm
    - rbs_score_norm
    - length_score_norm
    - start_score_norm
    
    Args:
        scored_orfs (list): ORFs with raw scores
        
    Returns:
        list: Same ORFs with normalized scores added
    """
    
    print(f"\nNormalizing ORFs...")
    
    codon_scores = np.array([orf['codon_score'] for orf in scored_orfs])
    imm_scores = np.array([orf['imm_score'] for orf in scored_orfs])
    rbs_scores = np.array([orf['rbs_score'] for orf in scored_orfs])
    length_scores = np.array([orf['length_score'] for orf in scored_orfs])
    start_scores = np.array([orf['start_score'] for orf in scored_orfs])
    
    # Normalize each component
    codon_norm = normalize_scores_zscore(codon_scores)
    imm_norm = normalize_scores_zscore(imm_scores)
    rbs_norm = normalize_scores_zscore(rbs_scores)
    length_norm = normalize_scores_zscore(length_scores)
    start_norm = normalize_scores_zscore(start_scores)
    
    # Add back to ORFs
    for i, orf in enumerate(scored_orfs):
        orf['codon_score_norm'] = codon_norm[i]
        orf['imm_score_norm'] = imm_norm[i]
        orf['rbs_score_norm'] = rbs_norm[i]
        orf['length_score_norm'] = length_norm[i]
        orf['start_score_norm'] = start_norm[i]
    
    components = [
        ('Codon', codon_scores, codon_norm),
        ('IMM', imm_scores, imm_norm),
        ('RBS', rbs_scores, rbs_norm),
        ('Length', length_scores, length_norm),
        ('Start', start_scores, start_norm)
    ]
    
    return scored_orfs

def calculate_combined_score(orf, weights=None):
    """
    Calculate weighted combined score from normalized components.
    
    Args:
        orf (dict): ORF with normalized scores
        weights (dict): Optional custom weights (uses global SCORE_WEIGHTS if None)
        
    Returns:
        float: Combined score
    """
    if weights is None:
        weights = SCORE_WEIGHTS  
    
    combined = (
        orf['codon_score_norm'] * weights['codon'] +
        orf['imm_score_norm'] * weights['imm'] +
        orf['rbs_score_norm'] * weights['rbs'] +
        orf['length_score_norm'] * weights['length'] +
        orf['start_score_norm'] * weights['start']
    )
    
    return combined

def add_combined_scores(scored_orfs, weights=None):
    """
    Add combined score to all ORFs.
    
    Args:
        scored_orfs (list): ORFs with normalized scores
        weights (dict): Optional custom weights (uses global SCORE_WEIGHTS if None)
        
    Returns:
        list: ORFs with combined_score added
    """

    if weights is None:
        weights = SCORE_WEIGHTS 
    
    print(f"\nCalculating weighted scores...")
    
    for orf in scored_orfs:
        orf['combined_score'] = calculate_combined_score(orf, weights)
    
    combined_scores = np.array([orf['combined_score'] for orf in scored_orfs])
     
    return scored_orfs
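

Putting the last two process steps together (weighted sum, then ranking), here is a small self-contained sketch. The weights and the three toy ORFs are hypothetical. The notebook reads its actual weights from `SCORE_WEIGHTS` in `src.config`.

```python
# Hypothetical weights, using the same keys as calculate_combined_score()
example_weights = {'codon': 0.3, 'imm': 0.3, 'rbs': 0.2, 'length': 0.1, 'start': 0.1}

# Toy ORFs with already-normalized component scores (made-up values)
orfs = [
    {'id': 'orf_b', 'codon_score_norm': -0.5, 'imm_score_norm': -1.0,
     'rbs_score_norm': 2.0, 'length_score_norm': -0.2, 'start_score_norm': 0.4},
    {'id': 'orf_c', 'codon_score_norm': -1.0, 'imm_score_norm': -0.9,
     'rbs_score_norm': -1.0, 'length_score_norm': -1.2, 'start_score_norm': -1.5},
    {'id': 'orf_a', 'codon_score_norm': 1.2, 'imm_score_norm': 0.8,
     'rbs_score_norm': 0.5, 'length_score_norm': 1.0, 'start_score_norm': 0.4},
]

def combined_sketch(orf, weights):
    """Weighted sum over the normalized components named in `weights`."""
    return sum(orf[f'{name}_score_norm'] * w for name, w in weights.items())

for orf in orfs:
    orf['combined_score'] = combined_sketch(orf, example_weights)

# Rank from most to least gene-like
ranked = sorted(orfs, key=lambda o: o['combined_score'], reverse=True)
for orf in ranked:
    print(f"{orf['id']}: {orf['combined_score']:+.2f}")
```

In this toy case `orf_b` has the strongest RBS signal but still ranks below `orf_a`, because the codon and IMM components carry more weight.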


# Filtering and analysis

In [84]:
def diagnose_rbs_implementation_fixed(all_orfs, reference_genes, genome_sequence=None):
    """
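    Fixed diagnostic that correctly displays RBS sequences from ORF data.
    
    Key fix: instead of re-extracting sequences from the genome, use the
    'rbs_sequence' field that was already calculated for each ORF.
    
    Args:
        all_orfs (list): ORF dicts with 'rbs_score' (and optionally
            'rbs_sequence' and 'rbs_spacing') fields
        reference_genes: Container of (start, end) tuples supporting
            membership tests, e.g. a set
        genome_sequence: Unused; kept for backward compatibility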
    """
    
    print("="*80)
    print("RBS SCORE DIAGNOSTIC (FIXED)")
    print("="*80)
    
    # Step 1: Classify ORFs
    print("\n1. CLASSIFYING ORFs AS TP/FP...")
    tp_orfs = []
    fp_orfs = []
    
    for orf in all_orfs:
        orf_start = orf.get('genome_start', orf['start'])
        orf_end = orf.get('genome_end', orf['end'])
        is_match = (orf_start, orf_end) in reference_genes
        
        if is_match:
            tp_orfs.append(orf)
        else:
            fp_orfs.append(orf)
    
    print(f"   TP ORFs: {len(tp_orfs):,}")
    print(f"   FP ORFs: {len(fp_orfs):,}")
    
    # Step 2: Check RBS scores
    print("\n2. CHECKING RBS SCORE AVAILABILITY...")
    if not all_orfs or 'rbs_score' not in all_orfs[0]:
        print("   ❌ ERROR: 'rbs_score' field not found!")
        return
    print("   ✓ RBS scores found in ORFs")
    
    # Step 3: Score distributions
    print("\n3. ANALYZING RBS SCORE DISTRIBUTIONS...")
    tp_scores = [orf['rbs_score'] for orf in tp_orfs if 'rbs_score' in orf]
    fp_scores = [orf['rbs_score'] for orf in fp_orfs if 'rbs_score' in orf]
    
    if not tp_scores or not fp_scores:
        print("   ❌ ERROR: No RBS scores available!")
        return
    
    tp_mean = np.mean(tp_scores)
    fp_mean = np.mean(fp_scores)
    tp_std = np.std(tp_scores)
    fp_std = np.std(fp_scores)
    tp_median = np.median(tp_scores)
    fp_median = np.median(fp_scores)
    
    print(f"   TP: mean={tp_mean:.4f}, std={tp_std:.4f}, median={tp_median:.4f}")
    print(f"   FP: mean={fp_mean:.4f}, std={fp_std:.4f}, median={fp_median:.4f}")
    print(f"   Mean difference: {tp_mean - fp_mean:.4f}")
    
    pooled_std = np.sqrt((tp_std**2 + fp_std**2) / 2)
    effect_size = abs(tp_mean - fp_mean) / pooled_std if pooled_std > 0 else 0
    
    print(f"   Effect size (Cohen's d): {effect_size:.4f}")
    if effect_size < 0.5:
        print("   ⚠️  Weak separation")
    elif effect_size < 0.8:
        print("   ✓ Moderate separation")
    else:
        print("   ✓✓ Good separation!")
    
    # Step 4: FIXED - Analyze RBS sequences from ORF data
    print("\n4. ANALYZING RBS SEQUENCES FROM ORF DATA...")
    
    tp_with_rbs = [o for o in tp_orfs if o.get('rbs_sequence')][:50]
    fp_with_rbs = [o for o in fp_orfs if o.get('rbs_sequence')][:50]
    
    # Check for SD motifs in detected RBS sequences
    sd_motifs = ['AGGAGG', 'GGAGG', 'AGGAG', 'GAGG', 'AGGA']
    
    print(f"\n   Shine-Dalgarno motif presence in detected RBS sequences:")
    print(f"   (Analyzing {len(tp_with_rbs)} TP and {len(fp_with_rbs)} FP samples)")
    
    for motif in sd_motifs:
        tp_count = sum(1 for o in tp_with_rbs if motif in str(o.get('rbs_sequence', '')))
        fp_count = sum(1 for o in fp_with_rbs if motif in str(o.get('rbs_sequence', '')))
        tp_pct = tp_count / len(tp_with_rbs) * 100 if tp_with_rbs else 0
        fp_pct = fp_count / len(fp_with_rbs) * 100 if fp_with_rbs else 0
        print(f"   {motif:8s}: TP={tp_pct:5.1f}%  FP={fp_pct:5.1f}%  (diff={tp_pct-fp_pct:+6.1f}%)")
    
    print("\n5. RBS SPACING DISTRIBUTION...")
    
    # Keep a spacing of 0; skip only ORFs with no recorded spacing
    tp_spacing = [o['rbs_spacing'] for o in tp_orfs if o.get('rbs_spacing') is not None]
    fp_spacing = [o['rbs_spacing'] for o in fp_orfs if o.get('rbs_spacing') is not None]
    
    if tp_spacing:
        tp_spacing_counts = Counter(tp_spacing)
        fp_spacing_counts = Counter(fp_spacing)
        
        print(f"\n   Spacing distribution (nucleotides between SD and start codon):")
        print(f"   {'Spacing':>8} {'TP Count':>10} {'TP %':>8} {'FP Count':>10} {'FP %':>8}")
        print(f"   {'-'*8} {'-'*10} {'-'*8} {'-'*10} {'-'*8}")
        
        all_spacings = sorted(set(tp_spacing_counts.keys()) | set(fp_spacing_counts.keys()))
        for spacing in all_spacings[:15]:
            tp_count = tp_spacing_counts.get(spacing, 0)
            fp_count = fp_spacing_counts.get(spacing, 0)
            tp_pct = tp_count / len(tp_spacing) * 100 if tp_spacing else 0
            fp_pct = fp_count / len(fp_spacing) * 100 if fp_spacing else 0
            print(f"   {spacing:8d} {tp_count:10d} {tp_pct:7.1f}% {fp_count:10d} {fp_pct:7.1f}%")
    
    print("\n" + "="*80)
    print("SUMMARY")
    print("="*80)
    
    if effect_size >= 0.8:
        print("\n✓✓ RBS scoring is working well!")
        print(f"   - Strong separation: Cohen's d = {effect_size:.3f}")
        print(f"   - Mean difference: {tp_mean - fp_mean:.2f}")
    elif effect_size >= 0.5:
        print("\n✓ RBS scoring is working moderately well")
        print(f"   - Moderate separation: Cohen's d = {effect_size:.3f}")
    else:
        print("\n  RBS scoring needs improvement")
        print(f"   - Weak separation: Cohen's d = {effect_size:.3f}")
    
    print("="*80)

def diagnose_imm_implementation(
    genome_id,
    all_orfs,
    cached_data,
    models,
    training_set=None,
    intergenic_set=None
):
    """
    Comprehensive diagnostic for frame-aware IMM score problems.
    
    Args:
        genome_id: Genome identifier
        all_orfs: List of ORF dictionaries with 'imm_score'
        cached_data: Cached data with genome info (currently unused in this function)
        models: Model dictionary from build_all_scoring_models()
        training_set: Optional training sequences
        intergenic_set: Optional intergenic sequences
    """

    print("="*80)
    print("FRAME-AWARE IMM DIAGNOSTIC")
    print("="*80)
    
    # Load reference genes
    print("\n0. LOADING REFERENCE GENES...")
    try:
        gff_path = get_gff_path(genome_id)
        reference_genes = load_reference_genes_from_gff(gff_path)
        print(f"   Loaded {len(reference_genes):,} reference genes")
    except Exception as e:
        print(f"   [ERROR] {e}")
        return
    
    # Extract models
    coding_imm = models['coding_imm']
    noncoding_imm = models['noncoding_imm']
    max_order = models['max_order']
    
    # ========================================================================
    # STEP 1: Classify ORFs as TP/FP
    # ========================================================================
    print("\n1. CLASSIFYING ORFs AS TP/FP...")
    tp_orfs = []
    fp_orfs = []
    
    for orf in all_orfs:
        orf_start = orf.get('genome_start', orf['start'])
        orf_end = orf.get('genome_end', orf['end'])
        is_match = (orf_start, orf_end) in reference_genes
        
        if is_match:
            tp_orfs.append(orf)
        else:
            fp_orfs.append(orf)
    
    print(f"   TP ORFs: {len(tp_orfs):,}")
    print(f"   FP ORFs: {len(fp_orfs):,}")
    
    if len(tp_orfs) == 0:
        print("   [CRITICAL] No true positives found!")
        return
    
    # ========================================================================
    # STEP 2: Check IMM scores
    # ========================================================================
    print("\n2. CHECKING IMM SCORES...")
    if not all_orfs or 'imm_score' not in all_orfs[0]:
        print("   [ERROR] No IMM scores found!")
        return
    print("   IMM scores present")
    
    # ========================================================================
    # STEP 3: Analyze score distributions
    # ========================================================================
    print("\n3. ANALYZING IMM SCORE DISTRIBUTIONS...")
    tp_scores = [orf['imm_score'] for orf in tp_orfs]
    fp_scores = [orf['imm_score'] for orf in fp_orfs]
    
    tp_mean = np.mean(tp_scores)
    fp_mean = np.mean(fp_scores)
    tp_std = np.std(tp_scores)
    fp_std = np.std(fp_scores)
    tp_median = np.median(tp_scores)
    fp_median = np.median(fp_scores)
    
    print(f"   TP: mean={tp_mean:.6f}, std={tp_std:.6f}, median={tp_median:.6f}")
    print(f"   FP: mean={fp_mean:.6f}, std={fp_std:.6f}, median={fp_median:.6f}")
    print(f"   Mean difference: {tp_mean - fp_mean:.6f}")
    
    # Calculate effect size (Cohen's d)
    pooled_std = np.sqrt((tp_std**2 + fp_std**2) / 2)
    effect_size = abs(tp_mean - fp_mean) / pooled_std if pooled_std > 0 else 0
    
    print(f"   Effect size (Cohen's d): {effect_size:.4f}")
    if effect_size < 0.2:
        print("   [WARNING] Very weak separation")
    elif effect_size < 0.5:
        print("   [WARNING] Weak separation")
    elif effect_size < 0.8:
        print("   Moderate separation")
    else:
        print("   Strong separation")
    
    # ========================================================================
    # STEP 4: Check frame-aware model structure
    # ========================================================================
    print("\n4. CHECKING FRAME-AWARE MODEL STRUCTURE...")
    print(f"   Max order: {max_order}")
    
    # Verify frame-aware structure (must be list of 3)
    if not (isinstance(coding_imm, list) and len(coding_imm) == 3):
        print("   [ERROR] Not a frame-aware IMM! Expected list of 3 models.")
        return
    
    print("   Frame-aware IMM confirmed (3 position-specific models)")
    
    for pos in range(3):
        print(f"   Position {pos}: coding={len(coding_imm[pos])} contexts, "
              f"noncoding={len(noncoding_imm[pos])} contexts")
    
    # Sample from position 0 model
    print(f"\n   Sampling contexts from Position 0:")
    sample_contexts = list(coding_imm[0].keys())[:5]
    
    models_differ = False
    for context in sample_contexts:
        if context in coding_imm[0] and context in noncoding_imm[0]:
            coding_probs = coding_imm[0][context]
            noncoding_probs = noncoding_imm[0][context]
            
            for nuc in ['A', 'T', 'G', 'C']:
                c_prob = coding_probs.get(nuc, 0)
                nc_prob = noncoding_probs.get(nuc, 0)
                diff = abs(c_prob - nc_prob)
                
                if diff > 0.01:
                    models_differ = True
                
                if diff > 0.05:
                    print(f"   Context '{context}'+{nuc}: "
                          f"coding={c_prob:.3f}, noncoding={nc_prob:.3f} (diff={diff:.3f})")
    
    if not models_differ:
        print("   [CRITICAL] Models nearly identical")
    else:
        print("   Models show differences")
    
    # ========================================================================
    # STEP 5: Check training data quality
    # ========================================================================
    if training_set and intergenic_set:
        print("\n5. CHECKING TRAINING DATA...")
        
        train_seqs = [s['sequence'] for s in training_set]
        inter_seqs = [s['sequence'] for s in intergenic_set]
        
        total_coding_bp = sum(len(s) for s in train_seqs)
        total_noncoding_bp = sum(len(s) for s in inter_seqs)
        
        print(f"   Coding: {len(train_seqs)} sequences, {total_coding_bp:,} bp")
        print(f"   Noncoding: {len(inter_seqs)} sequences, {total_noncoding_bp:,} bp")
        
        # GC content
        coding_gc = sum(s.count('G') + s.count('C') for s in train_seqs) / total_coding_bp
        noncoding_gc = sum(s.count('G') + s.count('C') for s in inter_seqs) / total_noncoding_bp
        
        print(f"   Coding GC%: {coding_gc*100:.2f}%")
        print(f"   Noncoding GC%: {noncoding_gc*100:.2f}%")
        print(f"   GC difference: {abs(coding_gc - noncoding_gc)*100:.2f}%")
        
        if abs(coding_gc - noncoding_gc) < 0.02:
            print("   [WARNING] Very similar GC content")
    
    # ========================================================================
    # STEP 6: Diagnostic plots
    # ========================================================================
    print("\n6. GENERATING PLOTS...")
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Plot 1: Score distributions
    ax1 = axes[0]
    ax1.hist(fp_scores, bins=50, alpha=0.6, color='red', 
             label=f'FP (n={len(fp_scores):,})', density=True)
    ax1.hist(tp_scores, bins=50, alpha=0.6, color='green', 
             label=f'TP (n={len(tp_scores):,})', density=True)
    ax1.axvline(tp_mean, color='darkgreen', linestyle='--', linewidth=2, 
                label=f'TP mean: {tp_mean:.4f}')
    ax1.axvline(fp_mean, color='darkred', linestyle='--', linewidth=2, 
                label=f'FP mean: {fp_mean:.4f}')
    ax1.set_xlabel('IMM Score')
    ax1.set_ylabel('Density')
    ax1.set_title('Frame-Aware IMM: TP vs FP Distribution')
    ax1.legend(fontsize=9)
    ax1.grid(True, alpha=0.3)
    
    # Plot 2: Score vs Length
    ax2 = axes[1]
    tp_lengths = [len(orf.get('sequence', '')) for orf in tp_orfs[:1000]]
    fp_lengths = [len(orf.get('sequence', '')) for orf in fp_orfs[:1000]]
    ax2.scatter(tp_lengths, tp_scores[:1000], alpha=0.3, s=10, color='green', label='TP')
    ax2.scatter(fp_lengths, fp_scores[:1000], alpha=0.3, s=10, color='red', label='FP')
    ax2.set_xlabel('ORF Length (bp)')
    ax2.set_ylabel('IMM Score')
    ax2.set_title('IMM Score vs ORF Length')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # ========================================================================
    # STEP 7: Summary
    # ========================================================================
    print("\n" + "="*80)
    print("DIAGNOSTIC SUMMARY")
    print("="*80)
    
    issues = []
    
    if effect_size < 0.2:
        issues.append("CRITICAL: Very weak TP/FP separation")
    
    if not models_differ:
        issues.append("CRITICAL: Models nearly identical")
    
    total_contexts = sum(len(coding_imm[i]) for i in range(3))
    if total_contexts < 30:
        issues.append("WARNING: Insufficient contexts (<30)")
    
    if max_order < 3:
        issues.append("WARNING: Low order model (<3)")
    
    if issues:
        print("\n[ISSUES FOUND]")
        for i, issue in enumerate(issues, 1):
            print(f"   {i}. {issue}")
        
        print("\n[RECOMMENDATIONS]")
        print("   1. Verify training sequences are true coding regions (CDS)")
        print("   2. Verify intergenic sequences are non-coding")
        print("   3. Increase training data (aim for >100kb per class)")
        print("   4. Check sequences are in correct reading frame")
    else:
        print("\n[STATUS] Frame-aware IMM working well")
        print(f"   Effect size: {effect_size:.3f}")
        print(f"   Total contexts: {total_contexts}")
    
    print("="*80)
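
The effect-size checks in the diagnostics above pool the two standard deviations with equal weight. When the TP and FP groups differ greatly in size, a sample-size-weighted pooled standard deviation is a common alternative. The sketch below is not part of the pipeline: `cohens_d_weighted` is a hypothetical helper, and the example scores are made up.

```python
import math
import statistics

def cohens_d_weighted(a, b):
    """Cohen's d with a sample-size-weighted pooled standard deviation."""
    n_a, n_b = len(a), len(b)
    if n_a < 2 or n_b < 2:
        return 0.0  # sample variance is undefined for fewer than 2 values
    var_a = statistics.variance(a)  # sample variance (n - 1 denominator)
    var_b = statistics.variance(b)
    pooled_var = ((n_a - 1) * var_a + (n_b - 1) * var_b) / (n_a + n_b - 2)
    if pooled_var == 0:
        return 0.0
    return abs(statistics.fmean(a) - statistics.fmean(b)) / math.sqrt(pooled_var)

# Made-up scores, for illustration only
tp_demo = [1.0, 2.0, 3.0]
fp_demo = [0.0, 1.0, 2.0]
d_demo = cohens_d_weighted(tp_demo, fp_demo)
print(f"Cohen's d (weighted pooled SD): {d_demo:.3f}")
```

With equal group sizes this should agree closely with the unweighted version. The two diverge when one group is much larger and has a different variance.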


In [85]:

def filter_candidates(all_orfs, codon_threshold=0, imm_threshold=0, length_threshold=0, combined_threshold=0):
    """
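    Remove low-scoring ORF candidates.

    An ORF is removed if:
    - ALL THREE scores (codon, imm, length) are below their thresholds, OR
    - combined_score is below its threshold

    Missing scores default to 0.

    Args:
        all_orfs: List of ORF dictionaries with score fields
        codon_threshold, imm_threshold, length_threshold: Per-score cutoffs
        combined_threshold: Cutoff for 'combined_score'

    Returns:
        list: ORFs that pass the filter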
    """
    filtered_orfs = []
    
    for orf in all_orfs:
        length_score = orf.get('length_score', 0)
        codon_score = orf.get('codon_score', 0)
        imm_score = orf.get('imm_score', 0)
        combined_score = orf.get('combined_score', 0)
        
        # Remove if ALL THREE are below thresholds OR if combined_score is below threshold
        all_three_below = (length_score < length_threshold and 
                          codon_score < codon_threshold and 
                          imm_score < imm_threshold)
        combined_below = combined_score < combined_threshold
        
        if not (all_three_below or combined_below):
            filtered_orfs.append(orf)
    
    return filtered_orfs
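
The removal rule in `filter_candidates` can be restated as a single predicate, which makes the AND/OR structure easier to check on toy data. This is only an illustrative sketch: `should_remove` and the toy ORFs are hypothetical, and the scores are made up.

```python
def should_remove(orf, codon_threshold=0, imm_threshold=0,
                  length_threshold=0, combined_threshold=0):
    """Return True if the ORF would be dropped by the filtering rule."""
    all_three_below = (
        orf.get('length_score', 0) < length_threshold
        and orf.get('codon_score', 0) < codon_threshold
        and orf.get('imm_score', 0) < imm_threshold
    )
    combined_below = orf.get('combined_score', 0) < combined_threshold
    return all_three_below or combined_below

toy_orfs = [
    # Only one weak component and a positive combined score: kept
    {'id': 'a', 'codon_score': -0.5, 'imm_score': 0.3,
     'length_score': 0.1, 'combined_score': 0.2},
    # All three component scores below 0: removed
    {'id': 'b', 'codon_score': -0.5, 'imm_score': -0.1,
     'length_score': -0.2, 'combined_score': 0.1},
    # Components are fine but the combined score is below 0: removed
    {'id': 'c', 'codon_score': 0.4, 'imm_score': 0.2,
     'length_score': 0.3, 'combined_score': -0.1},
]

kept_ids = [o['id'] for o in toy_orfs if not should_remove(o)]
print(kept_ids)
```

Note that a single low component score never removes an ORF by itself. Either all three component scores must fall below their thresholds, or the combined score must.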


## Known Limitation: Start Codon Selection

Example: ORF at 2,420,617-2,421,960 (1344 bp)
- BLAST: 100% match to 215 AA protein (645 bp)
- Issue: Algorithm selected upstream ATG
- Real start: ~2,421,317 (700 bp downstream)

Challenge: multiple ATGs in the same frame
- ORF detection appears to favor the longest possible ORF (not yet confirmed)
- Better start codon discrimination is probably needed
- RBS positioning may be the key signal

Possible improvements:
- Weight RBS more heavily for start selection
- Penalize unusually long ORFs without strong evidence
- Check alternative internal starts
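
One way the first two ideas could be combined is sketched below: weight RBS above length, and subtract a penalty from the longest candidate when its RBS evidence is weak. This is an assumption-laden illustration, not the notebook's method. `pick_start`, all weights and thresholds, and the scores in the toy data are hypothetical; only the field names and the coordinates of the example ORF come from this notebook.

```python
def pick_start(candidates, rbs_weight=0.6, length_weight=0.2,
               weak_rbs=0.3, long_penalty=0.2):
    """Pick one start among ORFs that share a stop codon.

    All weights and thresholds are hypothetical, not tuned values.
    Scores are assumed to be normalized to [0, 1].
    """
    longest = max(c['length_score_norm'] for c in candidates)

    def score(orf):
        total = (rbs_weight * orf['rbs_score_norm']
                 + length_weight * orf['length_score_norm'])
        # Penalize the longest candidate when its RBS evidence is weak
        # and a shorter alternative exists.
        is_longest = orf['length_score_norm'] == longest
        if len(candidates) > 1 and is_longest and orf['rbs_score_norm'] < weak_rbs:
            total -= long_penalty
        return total

    return max(candidates, key=score)

# Made-up scores, loosely modeled on the example ORF above
nested = [
    {'start': 2420617, 'end': 2421960,
     'rbs_score_norm': 0.1, 'length_score_norm': 1.0},  # upstream ATG
    {'start': 2421317, 'end': 2421960,
     'rbs_score_norm': 0.8, 'length_score_norm': 0.5},  # internal start
]
best = pick_start(nested)
print(best['start'])
```

With these made-up scores the internal start wins. Whether a penalty like this helps on real genomes would need to be checked against the reference annotation.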

In [86]:
def organize_nested_orfs(all_orfs):
    """
    Group ORFs by stop codon, sort by start position within each group.
    
    Returns:
        dict: {(strand, end_position): [list of ORFs sorted by start]}
    """
    groups = defaultdict(list)
    
    for orf in all_orfs:
        key = (orf['strand'], orf['end'])
        groups[key].append(orf)
    
    for key in groups:
        groups[key].sort(key=lambda x: x['start'])
    
    return groups

def select_best_starts(nested_groups, weights=None):
    """
    For each stop codon, select the best start position using multi-factor scoring.
    Uses START_SELECTION_WEIGHTS by default to prioritize biological signals
    over length when choosing between nested ORFs.
    
    Args:
        nested_groups (dict): {(strand, end): [list of ORFs with same stop]}
        weights (dict): Optional custom weights for start selection
        
    Returns:
        list: Selected ORFs (one per stop position)
    """
    if weights is None:
        weights = START_SELECTION_WEIGHTS
    
    print(f"\nSelecting best start for {len(nested_groups):,} stop-codon groups")
    
    selected_orfs = []
    single_option = 0
    multiple_options = 0
    
    selection_reasons = {
        'rbs_winner': 0,
        'imm_winner': 0,
        'codon_winner': 0,
        'length_winner': 0
    }
    
    for (strand, end), orfs in nested_groups.items():
        if len(orfs) == 1:
            selected_orfs.append(orfs[0])
            single_option += 1
        else:
            # Recalculate the score using start-selection weights
            for orf in orfs:
                orf['start_selection_score'] = (
                    orf['codon_score_norm'] * weights['codon'] +
                    orf['imm_score_norm'] * weights['imm'] +
                    orf['rbs_score_norm'] * weights['rbs'] +
                    orf['length_score_norm'] * weights['length'] +
                    orf['start_score_norm'] * weights['start']
                )
            
            # Select based on the NEW score
            best_orf = max(orfs, key=lambda x: x['start_selection_score'])
            selected_orfs.append(best_orf)
            multiple_options += 1
            
            # Track selection reasons
            components = ['rbs_score_norm', 'imm_score_norm', 'codon_score_norm', 'length_score_norm']
            component_names = ['rbs_winner', 'imm_winner', 'codon_winner', 'length_winner']
            
            best_component_value = -999
            best_component = None
            for comp, name in zip(components, component_names):
                if best_orf[comp] > best_component_value:
                    best_component_value = best_orf[comp]
                    best_component = name
            
            if best_component:
                selection_reasons[best_component] += 1
    
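    # Report the counters collected above
    print(f"  Single-candidate groups: {single_option:,}")
    print(f"  Multi-candidate groups:  {multiple_options:,}")
    print(f"  Highest normalized component of each selected ORF: {selection_reasons}")
    
    return selected_orfs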

In [87]:
def analyze_all_groups_correct_vs_selected(nested_groups, genome_id, weights=None):
    """
    Compare, for all reference-containing groups, the real gene vs the algorithm's selected ORF.
    Performs component-wise comparisons and reports:
      - % correct selections
      - Which ORF wins more scoring components per group
      - Optional per-score preference summary
    
    Args:
        nested_groups: Dictionary of grouped ORFs
        genome_id: Genome accession ID (e.g., "NC_000913.3")
        weights: Optional dict of score weights
    
    Returns:
        list: Analyzed groups with comparison results
    """

    if weights is None:
        weights = START_SELECTION_WEIGHTS
    
    # Load reference coordinates using proper function
    gff_path = get_gff_path(genome_id)
    ref = pd.read_csv(gff_path, sep="\t", comment="#", header=None)
    ref_genes = ref[ref[2] == "CDS"][[3, 4]]
    ref_genes.columns = ["start", "end"]
    ref_genes = ref_genes.drop_duplicates()
    ref_set = set(zip(ref_genes['start'], ref_genes['end']))
    
    analyzed_groups = []

    # Helper for weighted selection
    def weighted_score(orf):
        return (
            orf['codon_score_norm'] * weights.get('codon', 0) +
            orf['imm_score_norm'] * weights.get('imm', 0) +
            orf['rbs_score_norm'] * weights.get('rbs', 0) +
            orf['length_score_norm'] * weights.get('length', 0) +
            orf['start_score_norm'] * weights.get('start', 0)
        )

    # Iterate over groups
    for orfs_at_stop in nested_groups.values():
        if not orfs_at_stop:
            continue

        # Find reference orf
        correct_orf = None
        for orf in orfs_at_stop:
            coord = (orf.get('genome_start', orf['start']), orf.get('genome_end', orf['end']))
            if coord in ref_set:
                correct_orf = orf
                break
        if correct_orf is None:
            continue

        # Skip singletons
        if len(orfs_at_stop) == 1:
            continue

        selected_orf = max(orfs_at_stop, key=weighted_score)

        # Component-wise comparison
        score_keys = ['codon_score_norm', 'imm_score_norm', 'rbs_score_norm', 
                      'length_score_norm', 'start_score_norm']
        correct_points = 0
        selected_points = 0
        per_component = {}

        for key in score_keys:
            c_val = correct_orf.get(key, 0)
            s_val = selected_orf.get(key, 0)
            if c_val > s_val:
                correct_points += 1
                per_component[key] = "correct"
            elif s_val > c_val:
                selected_points += 1
                per_component[key] = "selected"
            else:
                per_component[key] = "tie"

        # Exact match test
        selected_coord = (selected_orf.get('genome_start', selected_orf['start']),
                         selected_orf.get('genome_end', selected_orf['end']))
        correct_coord = (correct_orf.get('genome_start', correct_orf['start']),
                        correct_orf.get('genome_end', correct_orf['end']))
        
        analyzed_groups.append({
            'correct_orf': correct_orf,
            'selected_orf': selected_orf,
            'correct_points': correct_points,
            'selected_points': selected_points,
            'is_correct': selected_coord == correct_coord,
            'group_size': len(orfs_at_stop),
            **per_component
        })

    # Reporting
    print("="*80)
    print(f"CORRECT vs SELECTED ANALYSIS: {genome_id}")
    print("="*80)
    print(f"Total groups with reference genes: {len(analyzed_groups)}")
    print(f"Selection weights: {weights}")
    print("="*80)

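    # Avoid division by zero in the percentage reports below
    if not analyzed_groups:
        print("\nNo multi-candidate groups contain a reference gene; nothing to compare.")
        return analyzed_groups

    correct_selections = [g for g in analyzed_groups if g['is_correct']]
    wrong_selections = [g for g in analyzed_groups if not g['is_correct']]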

    print(f"\nSelection results:")
    print(f"  Correct selections: {len(correct_selections)} ({len(correct_selections)/len(analyzed_groups)*100:.1f}%)")
    print(f"  Wrong selections:   {len(wrong_selections)} ({len(wrong_selections)/len(analyzed_groups)*100:.1f}%)")

    # Point comparison
    correct_won_more = sum(1 for g in analyzed_groups if g['correct_points'] > g['selected_points'])
    selected_won_more = sum(1 for g in analyzed_groups if g['selected_points'] > g['correct_points'])
    tied = len(analyzed_groups) - correct_won_more - selected_won_more

    print("\n" + "="*80)
    print("POINT COMPARISON (Correct vs Selected)")
    print("="*80)
    print(f"  Correct gene won more components:   {correct_won_more:>4} ({correct_won_more/len(analyzed_groups)*100:>5.1f}%)")
    print(f"  Selected gene won more components:  {selected_won_more:>4} ({selected_won_more/len(analyzed_groups)*100:>5.1f}%)")
    print(f"  Tied:                               {tied:>4} ({tied/len(analyzed_groups)*100:>5.1f}%)")
    print(f"\nNote: {len(correct_selections)} groups where correct=selected contribute to 'Tied'")

    # Per-score summary
    print("\n" + "="*80)
    print("SCORE COMPONENT PREFERENCE SUMMARY")
    print("="*80)
    for key in ['codon_score_norm', 'imm_score_norm', 'rbs_score_norm', 
                'length_score_norm', 'start_score_norm']:
        correct_favored = sum(1 for g in analyzed_groups if g[key] == 'correct')
        selected_favored = sum(1 for g in analyzed_groups if g[key] == 'selected')
        tied_comp = sum(1 for g in analyzed_groups if g[key] == 'tie')
        print(f"{key:20s} -> correct: {correct_favored:5}  selected: {selected_favored:5}  tied: {tied_comp:5}")

    return analyzed_groups

def comprehensive_fn_analysis(nested_groups, genome_id, weights=None):
    """
    Complete analysis of false negatives from start selection.
    Shows raw scores, weighted contributions, and point-based comparison.
    
    Args:
        nested_groups: Dictionary of grouped ORFs
        genome_id: Genome accession ID (e.g., "NC_000913.3")
        weights: Optional dict of score weights
    
    Returns:
        dict: FN statistics and analysis results
    """
    if weights is None:
        weights = START_SELECTION_WEIGHTS
        
    # Load reference
    gff_path = get_gff_path(genome_id)
    ref = pd.read_csv(gff_path, sep="\t", comment="#", header=None)
    ref_genes = ref[ref[2] == "CDS"][[3, 4]]
    ref_genes.columns = ["start", "end"]
    ref_genes = ref_genes.drop_duplicates()
    ref_set = set(zip(ref_genes['start'], ref_genes['end']))
    
    lost_genes = []
    
    def weighted_score(orf):
        return (
            orf['codon_score_norm'] * weights.get('codon', 0) +
            orf['imm_score_norm'] * weights.get('imm', 0) +
            orf['rbs_score_norm'] * weights.get('rbs', 0) +
            orf['length_score_norm'] * weights.get('length', 0) +
            orf['start_score_norm'] * weights.get('start', 0)
        )
    
    for orfs_at_stop in nested_groups.values():
        if len(orfs_at_stop) == 1:
            continue
        
        correct_orf = None
        for orf in orfs_at_stop:
            coord = (orf.get('genome_start', orf['start']), 
                    orf.get('genome_end', orf['end']))
            if coord in ref_set:
                correct_orf = orf
                break
        
        if correct_orf is None:
            continue
        
        selected_orf = max(orfs_at_stop, key=weighted_score)
        
        selected_coord = (selected_orf.get('genome_start', selected_orf['start']),
                         selected_orf.get('genome_end', selected_orf['end']))
        correct_coord = (correct_orf.get('genome_start', correct_orf['start']),
                        correct_orf.get('genome_end', correct_orf['end']))
        
        if selected_coord != correct_coord:
            # Calculate points
            score_keys = ['codon_score_norm', 'imm_score_norm', 'rbs_score_norm',
                         'length_score_norm', 'start_score_norm']
            correct_points = 0
            selected_points = 0
            
            for key in score_keys:
                if correct_orf[key] > selected_orf[key]:
                    correct_points += 1
                elif selected_orf[key] > correct_orf[key]:
                    selected_points += 1
            
            lost_genes.append({
                'correct': correct_orf,
                'selected': selected_orf,
                'correct_points': correct_points,
                'selected_points': selected_points
            })
    
    print("="*80)
    print(f"COMPREHENSIVE FALSE NEGATIVE ANALYSIS: {genome_id}")
    print(f"Total genes lost: {len(lost_genes)}")
    print(f"Selection weights: {weights}")
    print("="*80)
    
    if not lost_genes:
        print("\nNo false negatives with these weights")
        return {'fn_count': 0}
    
    # Calculate all statistics
    imm_raw = []
    codon_raw = []
    rbs_raw = []
    length_raw = []
    start_raw = []
    
    imm_weighted = []
    codon_weighted = []
    rbs_weighted = []
    length_weighted = []
    start_weighted = []
    
    for case in lost_genes:
        sel = case['selected']
        cor = case['correct']
        
        # Raw differences
        imm_raw.append(sel['imm_score_norm'] - cor['imm_score_norm'])
        codon_raw.append(sel['codon_score_norm'] - cor['codon_score_norm'])
        rbs_raw.append(sel['rbs_score_norm'] - cor['rbs_score_norm'])
        length_raw.append(sel['length_score_norm'] - cor['length_score_norm'])
        start_raw.append(sel['start_score_norm'] - cor['start_score_norm'])
        
        # Weighted differences
        imm_weighted.append((sel['imm_score_norm'] - cor['imm_score_norm']) * weights.get('imm', 0))
        codon_weighted.append((sel['codon_score_norm'] - cor['codon_score_norm']) * weights.get('codon', 0))
        rbs_weighted.append((sel['rbs_score_norm'] - cor['rbs_score_norm']) * weights.get('rbs', 0))
        length_weighted.append((sel['length_score_norm'] - cor['length_score_norm']) * weights.get('length', 0))
        start_weighted.append((sel['start_score_norm'] - cor['start_score_norm']) * weights.get('start', 0))
    
    # Build summary table
    print("\n" + "="*120)
    print(f"{'Component':<10} | {'Weight':<6} | {'Raw Diff':<20} | {'Weighted Contrib':<20} | {'Favored Correct':<15} | {'Favored Wrong':<15}")
    print("="*120)
    
    components = [
        ('Codon', codon_raw, codon_weighted, weights.get('codon', 0)),
        ('IMM', imm_raw, imm_weighted, weights.get('imm', 0)),
        ('RBS', rbs_raw, rbs_weighted, weights.get('rbs', 0)),
        ('Length', length_raw, length_weighted, weights.get('length', 0)),
        ('Start', start_raw, start_weighted, weights.get('start', 0))
    ]
    
    for name, raw, weighted, weight in components:
        favored_correct = sum(1 for r in raw if r < 0)
        favored_wrong = sum(1 for r in raw if r > 0)
        
        raw_mean = np.mean(raw)
        weighted_mean = np.mean(weighted) if weight != 0 else 0
        
        if weight == 0:
            print(f"{name:<10} | {weight:>6.1f} | {raw_mean:>8.3f} (mean)    | {'DISABLED':>20} | {favored_correct:>4} ({favored_correct/len(lost_genes)*100:>5.1f}%) | {favored_wrong:>4} ({favored_wrong/len(lost_genes)*100:>5.1f}%)")
        else:
            print(f"{name:<10} | {weight:>6.1f} | {raw_mean:>8.3f} (mean)    | {weighted_mean:>8.3f} (mean)    | {favored_correct:>4} ({favored_correct/len(lost_genes)*100:>5.1f}%) | {favored_wrong:>4} ({favored_wrong/len(lost_genes)*100:>5.1f}%)")
    
    print("="*120)
    
    # Total weighted contribution
    total_weighted = [sum([imm_weighted[i], codon_weighted[i], rbs_weighted[i], 
                           length_weighted[i], start_weighted[i]]) 
                      for i in range(len(lost_genes))]
    print(f"\nTotal weighted advantage for wrong start: mean={np.mean(total_weighted):.3f}, median={np.median(total_weighted):.3f}")
    
    # Point-based summary
    correct_won_more = sum(1 for c in lost_genes if c['correct_points'] > c['selected_points'])
    selected_won_more = sum(1 for c in lost_genes if c['selected_points'] > c['correct_points'])
    tied = sum(1 for c in lost_genes if c['correct_points'] == c['selected_points'])
    
    print(f"\n{'='*80}")
    print("POINT-BASED COMPARISON (1 point per component won)")
    print(f"{'='*80}")
    print(f"  Correct gene won MORE components:   {correct_won_more:>4} ({correct_won_more/len(lost_genes)*100:>5.1f}%)")
    print(f"  Selected gene won MORE components:  {selected_won_more:>4} ({selected_won_more/len(lost_genes)*100:>5.1f}%)")
    print(f"  Tied:                                {tied:>4} ({tied/len(lost_genes)*100:>5.1f}%)")
    
    return {
        'fn_count': len(lost_genes),
        'raw_diffs': {
            'imm': np.mean(imm_raw),
            'codon': np.mean(codon_raw),
            'rbs': np.mean(rbs_raw),
            'length': np.mean(length_raw),
            'start': np.mean(start_raw)
        },
        'weighted_contributions': {
            'imm': np.mean(imm_weighted),
            'codon': np.mean(codon_weighted),
            'rbs': np.mean(rbs_weighted),
            'length': np.mean(length_weighted),
            'start': np.mean(start_weighted)
        },
        'point_comparison': {
            'correct_won_more': correct_won_more,
            'selected_won_more': selected_won_more,
            'tied': tied
        }
    }

def analyze_negative_scores(all_orfs, reference_genes):
    """
    Count TP and FP ORFs with negative scores, individually and in
    pairwise combinations, and count ORFs whose codon, IMM, and length
    scores (or combined score) all fall below 0.2.
    
    Args:
        all_orfs: All ORF candidates with scores
        reference_genes: Set of (start, end) tuples for reference genes
    """
    from itertools import combinations
    
    # Classify ORFs as TP or FP
    print("="*80)
    print("NEGATIVE SCORE ANALYSIS")
    print("="*80)
    
    tp_orfs = []
    fp_orfs = []
    
    for orf in all_orfs:
        orf_start = orf.get('genome_start', orf['start'])
        orf_end = orf.get('genome_end', orf['end'])
        
        if (orf_start, orf_end) in reference_genes:
            tp_orfs.append(orf)
        else:
            fp_orfs.append(orf)
    
    print(f"Total ORFs: {len(all_orfs):,}")
    print(f"  True Positives: {len(tp_orfs):,}")
    print(f"  False Positives: {len(fp_orfs):,}")
    
    # Individual score analysis
    print("\n" + "="*80)
    print("INDIVIDUAL SCORES WITH NEGATIVE VALUES")
    print("="*80)
    
    scores_to_check = ['codon_score', 'imm_score', 'length_score', 'rbs_score', 'start_score']
    
    print(f"\n{'Score':<20} {'TP Negative':>15} {'TP %':>10} {'FP Negative':>15} {'FP %':>10}")
    print("-" * 80)
    
    for score_name in scores_to_check:
        tp_negative = sum(1 for orf in tp_orfs if orf.get(score_name, 0) < 0)
        fp_negative = sum(1 for orf in fp_orfs if orf.get(score_name, 0) < 0)
        
        tp_pct = tp_negative / len(tp_orfs) * 100 if tp_orfs else 0
        fp_pct = fp_negative / len(fp_orfs) * 100 if fp_orfs else 0
        
        print(f"{score_name:<20} {tp_negative:>15,} {tp_pct:>9.2f}% {fp_negative:>15,} {fp_pct:>9.2f}%")
    
    # Two-score combinations
    print("\n" + "="*80)
    print("TWO-SCORE COMBINATIONS (BOTH NEGATIVE)")
    print("="*80)
    
    main_scores = ['codon_score', 'imm_score', 'length_score']
    
    print(f"\n{'Combination':<40} {'TP Count':>15} {'TP %':>10} {'FP Count':>15} {'FP %':>10}")
    print("-" * 90)
    
    for score1, score2 in combinations(main_scores, 2):
        tp_count = sum(1 for orf in tp_orfs 
                      if orf.get(score1, 0) < 0 and orf.get(score2, 0) < 0)
        fp_count = sum(1 for orf in fp_orfs 
                      if orf.get(score1, 0) < 0 and orf.get(score2, 0) < 0)
        
        tp_pct = tp_count / len(tp_orfs) * 100 if tp_orfs else 0
        fp_pct = fp_count / len(fp_orfs) * 100 if fp_orfs else 0
        
        combo_name = f"{score1} & {score2}"
        print(f"{combo_name:<40} {tp_count:>15,} {tp_pct:>9.2f}% {fp_count:>15,} {fp_pct:>9.2f}%")
    
    # Three-score combination
    print("\n" + "="*80)
    print("THREE-SCORE COMBINATION (ALL VERY LOW)")
    print("="*80)
    
    tp_all_three = sum(1 for orf in tp_orfs 
                      if orf.get('codon_score', 0) < 0.2 
                      and orf.get('imm_score', 0) < 0.2 
                      and orf.get('length_score', 0) < 0.2)
    fp_all_three = sum(1 for orf in fp_orfs 
                      if orf.get('codon_score', 0) < 0.2
                      and orf.get('imm_score', 0) < 0.2 
                      and orf.get('length_score', 0) < 0.2)
    
    tp_pct = tp_all_three / len(tp_orfs) * 100 if tp_orfs else 0
    fp_pct = fp_all_three / len(fp_orfs) * 100 if fp_orfs else 0
    
    print(f"\n{'codon_score & imm_score & length_score':<40} {'TP Count':>15} {'TP %':>10} {'FP Count':>15} {'FP %':>10}")
    print("-" * 90)
    print(f"{'ALL THREE VERY LOW (<0.2)':<40} {tp_all_three:>15,} {tp_pct:>9.2f}% {fp_all_three:>15,} {fp_pct:>9.2f}%")
    
    # Combined score analysis
    if all_orfs and 'combined_score' in all_orfs[0]:
        print("\n" + "="*80)
        print("COMBINED SCORE ANALYSIS")
        print("="*80)
        
        tp_combined_neg = sum(1 for orf in tp_orfs if orf.get('combined_score', 0) < 0.2)
        fp_combined_neg = sum(1 for orf in fp_orfs if orf.get('combined_score', 0) < 0.2)
        
        tp_pct = tp_combined_neg / len(tp_orfs) * 100 if tp_orfs else 0
        fp_pct = fp_combined_neg / len(fp_orfs) * 100 if fp_orfs else 0
        
        print(f"\n{'Combined Score < 0.2':<40} {'TP Count':>15} {'TP %':>10} {'FP Count':>15} {'FP %':>10}")
        print("-" * 90)
        print(f"{'VERY LOW combined_score':<40} {tp_combined_neg:>15,} {tp_pct:>9.2f}% {fp_combined_neg:>15,} {fp_pct:>9.2f}%")

def diagnose_false_negatives(genome_id, cached_data, all_orfs, reference_genes, min_length=100):
    """
    Analyze why reference genes are not being detected as ORFs.
    
    Args:
        genome_id: Genome accession ID
        cached_data: Cached genome data
        all_orfs: All detected ORF candidates
        reference_genes: Set of (start, end) tuples for reference genes
        min_length: Minimum ORF length used in detection
    """
    
    def reverse_complement(seq):
        """Simple reverse complement function."""
        complement = {'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', 'N': 'N'}
        return ''.join(complement.get(base, 'N') for base in reversed(seq))
    
    print("="*100)
    print("FALSE NEGATIVE ANALYSIS - Why are genes missing?")
    print("="*100)
    
    # Get genome sequence
    genome_sequence = cached_data[genome_id]['sequence']
    
    # Identify which genes are missing
    detected_positions = set()
    for orf in all_orfs:
        orf_start = orf.get('genome_start', orf['start'])
        orf_end = orf.get('genome_end', orf['end'])
        detected_positions.add((orf_start, orf_end))
    
    missing_genes = reference_genes - detected_positions
    found_genes = reference_genes & detected_positions
    
    if not reference_genes:
        print("\nNo reference genes provided; nothing to analyze")
        return
    
    print(f"\nTotal reference genes: {len(reference_genes):,}")
    print(f"Detected as ORFs: {len(found_genes):,} ({len(found_genes)/len(reference_genes)*100:.1f}%)")
    print(f"NOT detected (False Negatives): {len(missing_genes):,} ({len(missing_genes)/len(reference_genes)*100:.1f}%)")
    
    if not missing_genes:
        print("\nAll reference genes were detected as ORFs")
        return
    
    # Load GFF for detailed analysis
    gff_path = get_gff_path(genome_id)
    print(f"\nLoading gene details from GFF...")
    
    gff_df = pd.read_csv(gff_path, sep="\t", comment="#", header=None,
                         names=['seqid', 'source', 'type', 'start', 'end', 'score', 'strand', 'phase', 'attributes'])
    
    gff_df = gff_df[gff_df['type'] == 'CDS']
    
    # Create gene_details dictionary
    gene_details = {}
    for _, row in gff_df.iterrows():
        start = int(row['start'])
        end = int(row['end'])
        
        gene_details[(start, end)] = {
            'length': end - start + 1,
            'strand': row['strand'],
            'start_pos': start,
            'end_pos': end,
            'attributes': row['attributes']
        }
    
    # Categorize missing genes by reason
    reasons = {
        'too_short': [],
        'no_start_codon': [],
        'no_stop_codon': [],
        'overlapping': [],
        'frameshifted': [],
        'other': []
    }
    
    start_codons = {'ATG', 'GTG', 'TTG'}
    stop_codons = {'TAA', 'TAG', 'TGA'}
    
    print(f"\nAnalyzing {len(missing_genes):,} missing genes...")
    
    for gene_pos in missing_genes:
        if gene_pos not in gene_details:
            reasons['other'].append({'position': gene_pos, 'reason': 'Not in GFF details'})
            continue
        
        details = gene_details[gene_pos]
        start = details['start_pos']
        end = details['end_pos']
        length = details['length']
        strand = details['strand']
        
        # Check length
        if length < min_length:
            reasons['too_short'].append({
                'position': gene_pos,
                'length': length,
                'strand': strand
            })
            continue
        
        # Extract sequence
        if strand == '+':
            gene_seq = genome_sequence[start-1:end]
        else:
            gene_seq = reverse_complement(genome_sequence[start-1:end])
        
        # Check for start codon
        start_codon = gene_seq[:3]
        has_start = start_codon in start_codons
        
        # Check for stop codon
        stop_codon = gene_seq[-3:]
        has_stop = stop_codon in stop_codons
        
        # Check if it's in frame
        in_frame = (length % 3 == 0)
        
        if not has_start:
            reasons['no_start_codon'].append({
                'position': gene_pos,
                'length': length,
                'strand': strand,
                'start_codon': start_codon
            })
        elif not has_stop:
            reasons['no_stop_codon'].append({
                'position': gene_pos,
                'length': length,
                'strand': strand,
                'stop_codon': stop_codon
            })
        elif not in_frame:
            reasons['frameshifted'].append({
                'position': gene_pos,
                'length': length,
                'strand': strand
            })
        else:
            # Check for overlaps
            is_overlapping = False
            for other_gene_pos in reference_genes:
                if other_gene_pos == gene_pos:
                    continue
                other_start, other_end = other_gene_pos
                if not (end < other_start or start > other_end):
                    is_overlapping = True
                    break
            
            if is_overlapping:
                reasons['overlapping'].append({
                    'position': gene_pos,
                    'length': length,
                    'strand': strand
                })
            else:
                reasons['other'].append({
                    'position': gene_pos,
                    'length': length,
                    'strand': strand,
                    'start_codon': start_codon,
                    'stop_codon': stop_codon
                })
    
    # Report results
    print("\n" + "="*100)
    print("CATEGORIZATION OF FALSE NEGATIVES")
    print("="*100)
    
    print(f"\n{'Category':<25} {'Count':>10} {'% of FN':>10} {'% of Total':>12}")
    print("-" * 100)
    
    for category, genes in reasons.items():
        count = len(genes)
        pct_fn = count / len(missing_genes) * 100 if missing_genes else 0
        pct_total = count / len(reference_genes) * 100
        
        category_name = category.replace('_', ' ').title()
        print(f"{category_name:<25} {count:>10,} {pct_fn:>9.1f}% {pct_total:>11.1f}%")
    
    # Detailed breakdown
    print("\n" + "="*100)
    print("DETAILED BREAKDOWN")
    print("="*100)
    
    if reasons['too_short']:
        print(f"\n1. TOO SHORT (< {min_length} bp): {len(reasons['too_short']):,} genes")
        lengths = [g['length'] for g in reasons['too_short']]
        print(f"   Length range: {min(lengths)}-{max(lengths)} bp")
        print(f"   Mean length: {sum(lengths)/len(lengths):.1f} bp")
        print(f"   [INFO] Solution: Consider lowering min_length threshold (currently {min_length})")
        
        # Length histogram of the missed short genes; every length here is < min_length,
        # so the bin edges are derived from min_length rather than hard-coded.
        bin_edges = sorted({0, min_length // 2, (3 * min_length) // 4, min_length})
        for lo, hi in zip(bin_edges, bin_edges[1:]):
            count = sum(1 for l in lengths if lo <= l < hi)
            print(f"      {lo}-{hi-1} bp: {count:,} genes")
    
    if reasons['no_start_codon']:
        print(f"\n2. NO STANDARD START CODON: {len(reasons['no_start_codon']):,} genes")
        start_codons_found = {}
        for g in reasons['no_start_codon']:
            codon = g['start_codon']
            start_codons_found[codon] = start_codons_found.get(codon, 0) + 1
        print("   Start codons found:")
        for codon, count in sorted(start_codons_found.items(), key=lambda x: x[1], reverse=True)[:10]:
            print(f"      {codon}: {count:,}")
        print(f"   [INFO] Solution: These genes use non-standard start codons")
        print(f"      Consider adding alternative start codons to detection")
    
    if reasons['no_stop_codon']:
        print(f"\n3. NO STANDARD STOP CODON: {len(reasons['no_stop_codon']):,} genes")
        print(f"   [INFO] Solution: These genes may extend beyond expected boundaries")
        print(f"      or have annotation errors")
    
    if reasons['frameshifted']:
        print(f"\n4. NOT IN READING FRAME: {len(reasons['frameshifted']):,} genes")
        print(f"   [INFO] Solution: These genes may have programmed frameshifts")
        print(f"      or annotation errors")
    
    if reasons['overlapping']:
        print(f"\n5. OVERLAPPING WITH OTHER GENES: {len(reasons['overlapping']):,} genes")
        print(f"   [INFO] Solution: These genes overlap with others on the same or opposite strand")
        print(f"      Your ORF detection may be choosing one over the other")
    
    if reasons['other']:
        print(f"\n6. OTHER REASONS: {len(reasons['other']):,} genes")
        print(f"   These genes have standard start/stop codons and correct length")
        print(f"   but are still not detected. Requires manual investigation.")
        
        print(f"\n   Sample cases:")
        for i, g in enumerate(reasons['other'][:5], 1):
            pos = g['position']
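            # Entries added as 'Not in GFF details' carry no length/strand, so use .get()
            print(f"      {i}. Position {pos[0]:,}-{pos[1]:,}, {g.get('length', 'N/A')} bp, strand {g.get('strand', 'N/A')}")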
            print(f"         Start: {g.get('start_codon', 'N/A')}, Stop: {g.get('stop_codon', 'N/A')}")
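
The start/stop codon checks in diagnose_false_negatives depend on converting GFF coordinates (1-based, inclusive) into Python slices and reverse-complementing minus-strand genes. The sketch below illustrates only that step on a made-up 24 bp genome; `extract_cds` and the toy sequences are hypothetical and not part of the pipeline.

```python
# Translation table for complementing bases; other characters pass through unchanged.
_COMPLEMENT = str.maketrans("ACGTN", "TGCAN")


def reverse_complement(seq: str) -> str:
    """Return the reverse complement of an upper- or lower-case DNA string."""
    return seq.upper().translate(_COMPLEMENT)[::-1]


def extract_cds(genome: str, start: int, end: int, strand: str) -> str:
    """Extract a CDS given GFF-style coordinates (1-based, end inclusive)."""
    region = genome[start - 1:end]  # 1-based inclusive -> 0-based half-open
    return region if strand == "+" else reverse_complement(region)


# Toy genome (assumption): one gene on each strand, padded with spacer bases.
plus_gene = "ATGAAATAA"
minus_gene = "ATGCCCTGA"
genome = "CC" + plus_gene + "GG" + reverse_complement(minus_gene) + "TT"

cds_plus = extract_cds(genome, 3, 11, "+")
cds_minus = extract_cds(genome, 14, 22, "-")

print(cds_plus[:3], cds_plus[-3:])    # start and stop codon, plus strand
print(cds_minus[:3], cds_minus[-3:])  # start and stop codon, minus strand
```

For a minus-strand gene the GFF `start` is still the smaller coordinate, so the start codon only appears at the front of the sequence after reverse-complementing.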
    

In [88]:
def process_genome(genome_id, cached_data):
    """Process a single genome through the ORF prediction pipeline."""
    genome_data = cached_data[genome_id]
    all_orfs = genome_data['orfs']
    
    training_set = create_training_set(genome_id, cached_data)
    intergenic_set = create_intergenic_set(genome_id, cached_data)
    
    models = build_all_scoring_models(training_set, intergenic_set)
    
    scored_orfs = score_all_orfs(all_orfs, models)
    scored_orfs = normalize_all_orf_scores(scored_orfs)
    scored_orfs = add_combined_scores(scored_orfs)
    
    candidates = filter_candidates(scored_orfs)
    grouped_orfs = organize_nested_orfs(candidates)
    top_candidates = select_best_starts(grouped_orfs)
    top_candidates = filter_candidates(top_candidates, 0.2, 0.2, 0.2, 0)
    
    compare_orfs_to_reference(top_candidates, genome_id)
    
    return top_candidates

# TODO: attempt to improve scoring efficiency
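
One possible approach to the TODO above is memoization with `functools.lru_cache`. The sketch below is a minimal illustration, not the pipeline code: `_MODEL`, `cached_prob`, and `set_model` are hypothetical names. It shows the main pitfall when the cached function reads module-level state: the cache key covers only the arguments, so the cache has to be cleared when the model is swapped.

```python
from functools import lru_cache

# Hypothetical module-level model, standing in for a global probability table.
_MODEL = {"A": 0.7}


@lru_cache(maxsize=None)
def cached_prob(nucleotide: str) -> float:
    """Look up a probability in whatever model is currently installed."""
    return _MODEL.get(nucleotide, 0.25)


def set_model(new_model: dict, clear: bool = True) -> None:
    """Install a new model; optionally drop memoized lookups from the old one."""
    global _MODEL
    _MODEL = new_model
    if clear:
        cached_prob.cache_clear()


first = cached_prob("A")             # computed from the first model
set_model({"A": 0.1}, clear=False)   # swap the model without clearing
stale = cached_prob("A")             # served from the cache: old model's value
set_model({"A": 0.1}, clear=True)    # swap again, this time clearing
fresh = cached_prob("A")             # recomputed from the new model

print(first, stale, fresh)
```

The same concern applies when several genomes are scored in one session: each genome has its own models, so cached lookups from the previous genome must not survive the switch.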

In [89]:
from functools import lru_cache

# Global cache variables - MUST match traditional_methods.py
_GLOBAL_CODING_IMM = None
_GLOBAL_NONCODING_IMM = None


def clear_imm_cache():
    """Clear the LRU cache for IMM scoring."""
    get_interpolated_probability_cached.cache_clear()

def build_all_scoring_models_cached(
    training_set: List[Dict], 
    intergenic_set: List[Dict], 
    min_observations: int = 10
) -> Dict:
    """Build all traditional scoring models from training data."""
    print("Building traditional scoring models...")
    start_time = time.time()
    
    clear_imm_cache()
    
    print("  Building codon usage models...")
    codon_model = build_codon_model(training_set)
    background_codon_model = build_codon_model(intergenic_set)

    print("  Building IMM models...")
    training_seqs = [orf['sequence'] for orf in training_set]
    intergenic_seqs = [orf['sequence'] for orf in intergenic_set]
    
    n_training = sum(len(seq) for seq in training_seqs)
    n_intergenic = sum(len(seq) for seq in intergenic_seqs)
    effective_n = min(n_training, n_intergenic)
    
    if effective_n < min_observations:
        estimated_order = 0
    else:
        estimated_order = math.floor(math.log2(effective_n / min_observations) / 2)
    
    estimated_order = min(estimated_order, 8)
    estimated_order = max(estimated_order, 3)
    
    coding_imm = build_interpolated_markov_model(training_seqs, estimated_order, min_observations)
    noncoding_imm = build_interpolated_markov_model(intergenic_seqs, estimated_order, min_observations)
    
    print(f"✓ All models built in {time.time() - start_time:.1f}s")
    print(f"  IMM order: {estimated_order}")
    print(f"  Training sequences: {len(training_seqs)} ({n_training:,} bp)")
    print(f"  Intergenic sequences: {len(intergenic_seqs)} ({n_intergenic:,} bp)")
    
    return {
        'codon_model': codon_model,
        'background_codon_model': background_codon_model,
        'coding_imm': coding_imm,
        'noncoding_imm': noncoding_imm,
        'max_order': estimated_order
    }


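@lru_cache(maxsize=200000)
def get_interpolated_probability_cached(
    nucleotide: str,
    context: str,
    codon_pos: int,
    imm_type: str,
    fallback_prob: float = 0.25
) -> float:
    """
    Look up P(nucleotide | context) in the global coding or noncoding IMM.

    Backs off from the longest available context to the empty context and
    returns fallback_prob if no entry is found. Results are memoized on the
    arguments only, so clear_imm_cache() must be called whenever the global
    models change.
    """
    probabilities = _GLOBAL_CODING_IMM if imm_type == 'coding' else _GLOBAL_NONCODING_IMM
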
    for order in range(len(context), -1, -1):
        current_context = context[-order:] if order > 0 else ""
        
        
        if current_context in probabilities[codon_pos]:
            if nucleotide in probabilities[codon_pos][current_context]:
                return probabilities[codon_pos][current_context][nucleotide]

    return fallback_prob

def score_imm_ratio_cached(
    sequence: str, 
    coding_imm: List[Dict], 
    noncoding_imm: List[Dict], 
    max_order: int
) -> float:
    """Score sequence using frame-aware IMM log-likelihood ratio."""
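    global _GLOBAL_CODING_IMM, _GLOBAL_NONCODING_IMM
    
    # Point the cached lookup at these models. The cache key does not include
    # the models, so memoized values are dropped whenever the models change.
    if _GLOBAL_CODING_IMM is not coding_imm or _GLOBAL_NONCODING_IMM is not noncoding_imm:
        _GLOBAL_CODING_IMM = coding_imm
        _GLOBAL_NONCODING_IMM = noncoding_imm
        get_interpolated_probability_cached.cache_clear()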
    
    if len(sequence) < 3:
        return 0.0
    
    is_frame_aware = isinstance(coding_imm, list) and len(coding_imm) == 3
    
    coding_log_prob = 0.0
    noncoding_log_prob = 0.0
    EPSILON = 1e-10
    
    for i in range(len(sequence)):
        nucleotide = sequence[i]
        
        context_start = max(0, i - max_order)
        context = sequence[context_start:i]
        
        if is_frame_aware:
            codon_position = i % 3
            coding_prob = get_interpolated_probability_cached(
                nucleotide, context, codon_position, 'coding'
            )
            noncoding_prob = get_interpolated_probability_cached(
                nucleotide, context, codon_position, 'noncoding'
            )
        else:
            coding_prob = get_interpolated_probability_cached(
                nucleotide, context, 0, 'coding'
            )
            noncoding_prob = get_interpolated_probability_cached(
                nucleotide, context, 0, 'noncoding'
            )
        
        coding_prob = max(coding_prob, EPSILON)
        noncoding_prob = max(noncoding_prob, EPSILON)
        
        coding_log_prob += math.log(coding_prob)
        noncoding_log_prob += math.log(noncoding_prob)
    
    return (coding_log_prob - noncoding_log_prob) / len(sequence)

def score_all_orfs_cached(all_orfs, models):
    """
    Score all ORFs with the codon, IMM, length, and start-codon methods using pre-built models.
    Adds these scores to each ORF dict in place (RBS scores are not computed here).
    
    Args:
        all_orfs (list): List of ORF dictionaries
        models (dict): Pre-built models from build_all_scoring_models_cached()
        
    Returns:
        list: ORFs with scores added
    """
    print(f"Scoring {len(all_orfs):,} ORFs...")
    start_time = time.time()
    
    codon_model = models['codon_model']
    background_codon_model = models['background_codon_model']
    coding_imm = models['coding_imm']
    noncoding_imm = models['noncoding_imm']
    max_order = models['max_order']
    
    for i, orf in enumerate(all_orfs):
        if i % 25000 == 0 and i > 0:
            print(f"  {i:,}...")
        
        # Score each component
        orf['codon_score'] = score_codon_bias_ratio(
            orf['sequence'], 
            codon_model, 
            background_codon_model
        )
        
        orf['imm_score'] = score_imm_ratio_cached(
            orf['sequence'], 
            coding_imm, 
            noncoding_imm,
            max_order
        )
        # Length score
        orf['length_score'] = score_orf_length(orf['length'])
        
        # Start codon score
        orf['start_score'] = score_start_codon(orf.get('start_codon', 'ATG'))
    
    print(f"Done in {(time.time() - start_time)/60:.1f} minutes")
    return all_orfs

def process_genome_cached(genome_id, cached_data):
    """
    Process genome with LRU-cached IMM scoring.
    Runs the same pipeline as process_genome(), but returns a dict with
    'predictions', 'metrics', and 'count' instead of the prediction list.
    """
    genome_data = cached_data[genome_id]
    all_orfs = genome_data['orfs']
    
    # Create training sets
    training_set = create_training_set(genome_id, cached_data)
    intergenic_set = create_intergenic_set(genome_id, cached_data)
    
    # Build models
    models = build_all_scoring_models_cached(training_set, intergenic_set)
    
    scored_orfs = score_all_orfs_cached(all_orfs, models)
    # Normalize and combine (scoring modified the ORF dicts in place)
    scored_orfs = normalize_all_orf_scores(scored_orfs)
    scored_orfs = add_combined_scores(scored_orfs)
    
    # Filter and select
    candidates = filter_candidates(scored_orfs)
    grouped_orfs = organize_nested_orfs(candidates)
    top_candidates = select_best_starts(grouped_orfs)
    final_predictions = filter_candidates(top_candidates, 0.2, 0.2, 0.2, 0)
    
    # Get metrics
    metrics = compare_orfs_to_reference(final_predictions, genome_id)
    
    return {
        'predictions': final_predictions,
        'metrics': metrics,
        'count': len(final_predictions)
    }
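
As a worked check of the IMM order heuristic in build_all_scoring_models_cached, the sketch below re-implements only that calculation. `estimate_imm_order` is a hypothetical helper name, and the base-pair totals are the ones reported for the NC_000913.3 run in this notebook. One reading of the formula: an order-k context has 4^k possible values, so asking for roughly `min_observations` bases per context gives k ≈ log4(n / min_observations) = log2(n / min_observations) / 2.

```python
import math


def estimate_imm_order(n_training: int, n_intergenic: int,
                       min_observations: int = 10,
                       min_order: int = 3, max_order: int = 8) -> int:
    """Pick a Markov order from the smaller of the two training sets (in bp)."""
    effective_n = min(n_training, n_intergenic)
    if effective_n < min_observations:
        order = 0
    else:
        order = math.floor(math.log2(effective_n / min_observations) / 2)
    # Clamp to the same [3, 8] range used in the notebook.
    return max(min_order, min(order, max_order))


# 1,809,003 bp of training ORFs and 148,082 bp of intergenic sequence.
order = estimate_imm_order(1_809_003, 148_082)
print(order)
```

Because the smaller set drives the estimate, the intergenic sequence (148,082 bp) determines the order here, not the much larger coding set.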


In [90]:
predictions = process_genome_cached(genome_id, cached_data)

Building traditional scoring models...
  Building codon usage models...
  Building IMM models...
✓ All models built in 2.6s
  IMM order: 6
  Training sequences: 1220 (1,809,003 bp)
  Intergenic sequences: 490 (148,082 bp)
Scoring 176,315 ORFs...
  25,000...
  50,000...
  75,000...
  100,000...
  125,000...
  150,000...
  175,000...
Done in 1.1 minutes

Normalizing ORFs...

Calculating weighted scores...

Selecting best start for 9,972
VALIDATION SUMMARY: NC_000913.3
Predicted ORFs:              5,687
Reference CDS (proteins):    4,340
True positives (exact):      3,427
False negatives (missed):    913
False positives (spurious):  2,260

Sensitivity (Recall):        78.96%
Precision:                   60.26%
F1 Score:                    68.36
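
The summary figures can be reproduced from the three counts, assuming the standard definitions of recall, precision, and F1. How compare_orfs_to_reference computes them internally is not shown here, and `exact_match_metrics` is a hypothetical helper used only for this check.

```python
def exact_match_metrics(tp: int, fn: int, fp: int) -> dict:
    """Recall, precision, and F1 (as fractions) from exact-match counts."""
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return {"recall": recall, "precision": precision, "f1": f1}


# Counts from the validation summary for NC_000913.3.
m = exact_match_metrics(tp=3427, fn=913, fp=2260)
print(f"Recall:    {m['recall'] * 100:.2f}%")
print(f"Precision: {m['precision'] * 100:.2f}%")
print(f"F1:        {m['f1'] * 100:.2f}%")
```

Note that tp + fn equals the 4,340 reference CDS and tp + fp equals the 5,687 predicted ORFs, so the counts are internally consistent.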


WEIGHTS OPTIMIZATION - takes about 2 hours to run

In [None]:
from scipy.optimize import differential_evolution

def objective_function(weight_array, grouped_orfs, ref_set):
    """
    Objective function: Returns NEGATIVE accuracy (to minimize).
    
    Args:
        weight_array: [codon, imm, rbs, length, start]
        grouped_orfs: The grouped ORFs
        ref_set: Set of reference gene coordinates
    
    Returns:
        -accuracy (negative so we can minimize)
    """
    weights = {
        'codon': weight_array[0],
        'imm': weight_array[1],
        'rbs': weight_array[2],
        'length': weight_array[3],
        'start': weight_array[4]
    }
    
    def weighted_score(orf):
        return (
            orf['codon_score_norm'] * weights['codon'] +
            orf['imm_score_norm'] * weights['imm'] +
            orf['rbs_score_norm'] * weights['rbs'] +
            orf['length_score_norm'] * weights['length'] +
            orf['start_score_norm'] * weights['start']
        )
    
    correct = 0
    total = 0
    
    for stop_pos, orf_group in grouped_orfs.items():
        if len(orf_group) < 2:
            continue
        
        correct_orf = None
        for orf in orf_group:
            coord = (orf.get('genome_start', orf['start']), 
                    orf.get('genome_end', orf['end']))
            if coord in ref_set:
                correct_orf = orf
                break
        
        if correct_orf is None:
            continue
        
        total += 1
        
        selected = max(orf_group, key=weighted_score)
        selected_coord = (selected.get('genome_start', selected['start']),
                         selected.get('genome_end', selected['end']))
        correct_coord = (correct_orf.get('genome_start', correct_orf['start']),
                        correct_orf.get('genome_end', correct_orf['end']))
        
        if selected_coord == correct_coord:
            correct += 1
    
    accuracy = correct / total if total > 0 else 0
    return -accuracy  # Negative because we minimize
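
To make the objective concrete, the sketch below evaluates start-selection accuracy on two made-up ORF groups with only two score components. The data, the reduced feature set, and the helper name `start_selection_accuracy` are assumptions for illustration; the real objective uses five normalized scores and the `genome_start`/`genome_end` fallbacks.

```python
def start_selection_accuracy(weights: dict, grouped_orfs: dict, ref_set: set) -> float:
    """Fraction of groups where the top-scoring candidate is the reference ORF."""
    def weighted_score(orf: dict) -> float:
        return sum(w * orf[f"{name}_score_norm"] for name, w in weights.items())

    correct = total = 0
    for group in grouped_orfs.values():
        has_reference = any((o["start"], o["end"]) in ref_set for o in group)
        if len(group) < 2 or not has_reference:
            continue  # nothing to choose between, or no known answer
        total += 1
        best = max(group, key=weighted_score)
        correct += (best["start"], best["end"]) in ref_set
    return correct / total if total else 0.0


# Toy candidates sharing a stop codon (keys are stop positions).
grouped = {
    400: [
        {"start": 100, "end": 400, "imm_score_norm": 0.5, "length_score_norm": 0.9},
        {"start": 160, "end": 400, "imm_score_norm": 0.8, "length_score_norm": 0.6},
    ],
    900: [
        {"start": 500, "end": 900, "imm_score_norm": 0.7, "length_score_norm": 0.9},
        {"start": 560, "end": 900, "imm_score_norm": 0.4, "length_score_norm": 0.7},
    ],
}
reference = {(160, 400), (500, 900)}

acc_length_heavy = start_selection_accuracy({"imm": 1.0, "length": 5.0}, grouped, reference)
acc_imm_heavy = start_selection_accuracy({"imm": 5.0, "length": 1.0}, grouped, reference)
acc_scaled = start_selection_accuracy({"imm": 50.0, "length": 10.0}, grouped, reference)
print(acc_length_heavy, acc_imm_heavy, acc_scaled)
```

Multiplying every weight by the same positive constant leaves each group's ranking unchanged, so only the ratios between weights affect this objective. The accuracy is also piecewise constant in the weights, which is one reason a derivative-free optimizer is a reasonable choice.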



In [None]:
def prepare_all_genome_data(TEST_GENOMES, cached_data):
    """
    Prepare grouped ORFs and reference sets for all genomes.
    Returns dict with genome_id -> (grouped_orfs, ref_set)
    """
    from src.data_management import get_gff_path  # same module as the notebook's top-level import
    
    all_genome_data = {}
    
    print("=" * 80)
    print("PREPARING ALL GENOME DATA")
    print("=" * 80)
    
    for gid in TEST_GENOMES:
        print(f"\nProcessing {gid}...")
        
        # Run pipeline to get grouped ORFs
        genome_data = cached_data[gid]
        all_orfs = genome_data['orfs']
        
        training_set = create_training_set(gid, cached_data)
        intergenic_set = create_intergenic_set(gid, cached_data)
        models = build_all_scoring_models_cached(training_set, intergenic_set)
        
        scored_orfs = score_all_orfs_cached(all_orfs, models)
        scored_orfs = normalize_all_orf_scores(scored_orfs)
        scored_orfs = add_combined_scores(scored_orfs)
        
        candidates = filter_candidates(scored_orfs)
        grouped_orfs = organize_nested_orfs(candidates)
        
        # Load reference
        gff_path = get_gff_path(gid)
        ref = pd.read_csv(gff_path, sep="\t", comment="#", header=None)
        ref_genes = ref[ref[2] == "CDS"][[3, 4]]
        ref_genes.columns = ["start", "end"]
        ref_set = set(zip(ref_genes['start'], ref_genes['end']))
        
        all_genome_data[gid] = {
            'grouped_orfs': grouped_orfs,
            'ref_set': ref_set
        }
        
        print(f"  ✓ {len(grouped_orfs)} ORF groups prepared")
    
    print("\n" + "=" * 80)
    print(f"Total genomes prepared: {len(all_genome_data)}")
    print("=" * 80)
    
    return all_genome_data


def objective_function_multi_genome(weight_array, genome_data_dict, genome_ids):
    """
    Objective function for multiple genomes.
    Returns the negative start-selection accuracy pooled over the listed genomes
    (total correct / total evaluated groups), so genomes with more evaluated
    groups carry more weight than in a per-genome average.
    
    Args:
        weight_array: [codon, imm, rbs, length, start]
        genome_data_dict: Dict of genome_id -> {grouped_orfs, ref_set}
        genome_ids: List of genome IDs to evaluate on
    
    Returns:
        -pooled_accuracy (negative for minimization)
    """
    weights = {
        'codon': weight_array[0],
        'imm': weight_array[1],
        'rbs': weight_array[2],
        'length': weight_array[3],
        'start': weight_array[4]
    }
    
    def weighted_score(orf):
        return (
            orf['codon_score_norm'] * weights['codon'] +
            orf['imm_score_norm'] * weights['imm'] +
            orf['rbs_score_norm'] * weights['rbs'] +
            orf['length_score_norm'] * weights['length'] +
            orf['start_score_norm'] * weights['start']
        )
    
    total_correct = 0
    total_count = 0
    
    for gid in genome_ids:
        grouped_orfs = genome_data_dict[gid]['grouped_orfs']
        ref_set = genome_data_dict[gid]['ref_set']
        
        correct = 0
        count = 0
        
        for stop_pos, orf_group in grouped_orfs.items():
            if len(orf_group) < 2:
                continue
            
            correct_orf = None
            for orf in orf_group:
                coord = (orf.get('genome_start', orf['start']), 
                        orf.get('genome_end', orf['end']))
                if coord in ref_set:
                    correct_orf = orf
                    break
            
            if correct_orf is None:
                continue
            
            count += 1
            
            selected = max(orf_group, key=weighted_score)
            selected_coord = (selected.get('genome_start', selected['start']),
                             selected.get('genome_end', selected['end']))
            correct_coord = (correct_orf.get('genome_start', correct_orf['start']),
                            correct_orf.get('genome_end', correct_orf['end']))
            
            if selected_coord == correct_coord:
                correct += 1
        
        total_correct += correct
        total_count += count
    
    average_accuracy = total_correct / total_count if total_count > 0 else 0
    return -average_accuracy  # Negative for minimization


def optimize_with_train_val_split(all_genome_data, train_ratio=0.7):
    """
    OPTION C: Train/Validation Split with proper ML methodology.
    
    Splits genomes into train/validation sets, optimizes on train,
    evaluates on validation to check generalization.
    """
    genome_ids = list(all_genome_data.keys())
    n_train = int(len(genome_ids) * train_ratio)
    
    # Shuffle and split
    import random
    random.seed(42)
    shuffled_ids = genome_ids.copy()
    random.shuffle(shuffled_ids)
    
    train_ids = shuffled_ids[:n_train]
    val_ids = shuffled_ids[n_train:]
    
    print("=" * 80)
    print("OPTION C: TRAIN/VALIDATION SPLIT OPTIMIZATION")
    print("=" * 80)
    print(f"Total genomes: {len(genome_ids)}")
    print(f"Training genomes ({len(train_ids)}): {train_ids}")
    print(f"Validation genomes ({len(val_ids)}): {val_ids}")
    print("=" * 80)
    print()
    
    # Optimize on training set
    print("PHASE 1: OPTIMIZING ON TRAINING SET")
    print("-" * 80)
    
    bounds = [
        (0.1, 5.0),  # codon
        (0.1, 10.0), # imm
        (0.1, 5.0),  # rbs
        (0.1, 10.0), # length
        (0.1, 3.0),  # start
    ]
    
    iteration = [0]
    best_train = [float('inf')]
    
    def callback_train(xk, convergence):
        iteration[0] += 1
        current_score = objective_function_multi_genome(xk, all_genome_data, train_ids)
        if current_score < best_train[0]:
            best_train[0] = current_score
            train_acc = -current_score * 100
            print(f"Iter {iteration[0]}: Train accuracy = {train_acc:.2f}% "
                  f"[codon={xk[0]:.2f}, imm={xk[1]:.2f}, rbs={xk[2]:.2f}, length={xk[3]:.2f}, start={xk[4]:.2f}]")
    
    result_train = differential_evolution(
        objective_function_multi_genome,
        bounds,
        args=(all_genome_data, train_ids),
        strategy='best1bin',
        maxiter=100,
        popsize=15,
        tol=0.0001,
        mutation=(0.5, 1),
        recombination=0.7,
        callback=callback_train,
        polish=True,
        workers=1,
        updating='deferred'
    )
    
    optimal_weights = {
        'codon': result_train.x[0],
        'imm': result_train.x[1],
        'rbs': result_train.x[2],
        'length': result_train.x[3],
        'start': result_train.x[4]
    }
    
    train_accuracy = -result_train.fun * 100
    
    print("\n" + "=" * 80)
    print("PHASE 2: EVALUATING ON VALIDATION SET")
    print("-" * 80)
    
    # Evaluate on validation set
    val_score = objective_function_multi_genome(result_train.x, all_genome_data, val_ids)
    val_accuracy = -val_score * 100
    
    print(f"Validation accuracy: {val_accuracy:.2f}%")
    
    # Evaluate on individual validation genomes
    print("\nPer-genome validation results:")
    for gid in val_ids:
        score = objective_function_multi_genome(result_train.x, all_genome_data, [gid])
        acc = -score * 100
        print(f"  {gid}: {acc:.2f}%")
    
    print("\n" + "=" * 80)
    print("TRAIN/VALIDATION RESULTS")
    print("=" * 80)
    print(f"Training accuracy:   {train_accuracy:.2f}%")
    print(f"Validation accuracy: {val_accuracy:.2f}%")
    print(f"Generalization gap:  {train_accuracy - val_accuracy:.2f} percentage points")
    
    if abs(train_accuracy - val_accuracy) < 2.0:
        print("✓ Good generalization! Weights work across genomes.")
    else:
        print("⚠ Possible overfitting - weights may be too genome-specific.")
    
    print("\nOptimal weights (from training):")
    for key, val in optimal_weights.items():
        print(f"  {key}: {val:.4f}")
    
    print("=" * 80)
    
    return optimal_weights, train_accuracy, val_accuracy


def optimize_on_all_genomes(all_genome_data):
    """
    OPTION B: Optimize on ALL genomes (no split).
    Use this for final production weights after validation.
    """
    genome_ids = list(all_genome_data.keys())
    
    print("=" * 80)
    print("OPTION B: OPTIMIZING ON ALL GENOMES")
    print("=" * 80)
    print(f"Optimizing across {len(genome_ids)} genomes: {genome_ids}")
    print("This finds weights that work best on average across all organisms.")
    print("=" * 80)
    print()
    
    bounds = [
        (0.1, 5.0),  # codon
        (0.1, 10.0), # imm
        (0.1, 5.0),  # rbs
        (0.1, 10.0), # length
        (0.1, 3.0),  # start
    ]
    
    iteration = [0]
    best_so_far = [float('inf')]
    
    def callback(xk, convergence):
        iteration[0] += 1
        current_score = objective_function_multi_genome(xk, all_genome_data, genome_ids)
        if current_score < best_so_far[0]:
            best_so_far[0] = current_score
            avg_acc = -current_score * 100
            print(f"Iter {iteration[0]}: Avg accuracy = {avg_acc:.2f}% "
                  f"[codon={xk[0]:.2f}, imm={xk[1]:.2f}, rbs={xk[2]:.2f}, length={xk[3]:.2f}, start={xk[4]:.2f}]")
    
    result = differential_evolution(
        objective_function_multi_genome,
        bounds,
        args=(all_genome_data, genome_ids),
        strategy='best1bin',
        maxiter=100,
        popsize=15,
        tol=0.0001,
        mutation=(0.5, 1),
        recombination=0.7,
        callback=callback,
        polish=True,
        workers=1,
        updating='deferred'
    )
    
    optimal_weights = {
        'codon': result.x[0],
        'imm': result.x[1],
        'rbs': result.x[2],
        'length': result.x[3],
        'start': result.x[4]
    }
    
    avg_accuracy = -result.fun * 100
    
    print("\n" + "=" * 80)
    print("ALL-GENOME OPTIMIZATION COMPLETE")
    print("=" * 80)
    print(f"Average accuracy: {avg_accuracy:.2f}%")
    
    # Show per-genome breakdown
    print("\nPer-genome accuracy:")
    for gid in genome_ids:
        score = objective_function_multi_genome(result.x, all_genome_data, [gid])
        acc = -score * 100
        print(f"  {gid}: {acc:.2f}%")
    
    print("\nOptimal weights:")
    for key, val in optimal_weights.items():
        print(f"  {key}: {val:.4f}")
    
    print("=" * 80)
    
    return optimal_weights, avg_accuracy


def combined_best_practice_optimization(TEST_GENOMES, cached_data):
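    """
    Combined best-practice workflow:
    1. Prepare scored data for all genomes
    2. Optimize with a train/validation split to check generalization
    3. If the train/validation accuracy gap is under 2 percentage points,
       re-optimize on all genomes for the final weights; otherwise keep the
       train/validation weights
    """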
    print("\n" + "=" * 100)
    print(" " * 30 + "BEST PRACTICE WEIGHT OPTIMIZATION")
    print("=" * 100)
    
    # Step 1: Prepare data
    print("\nSTEP 1: PREPARING DATA")
    all_genome_data = prepare_all_genome_data(TEST_GENOMES, cached_data)
    
    # Step 2: Train/Val split
    print("\n" + "=" * 100)
    print("\nSTEP 2: TRAIN/VALIDATION SPLIT")
    weights_trainval, train_acc, val_acc = optimize_with_train_val_split(all_genome_data)
    
    # Step 3: Decision point (accuracies are percentages, so the gap is in percentage points)
    generalization_gap = abs(train_acc - val_acc)
    
    print("\n" + "=" * 100)
    print("\nSTEP 3: FINAL WEIGHT SELECTION")
    print("-" * 100)
    
    if generalization_gap < 2.0:
        print("✓ Good generalization detected!")
        print("  Training on ALL genomes for final production weights...")
        print()
        weights_final, avg_acc = optimize_on_all_genomes(all_genome_data)
        
        print("\n" + "=" * 100)
        print("FINAL RECOMMENDATION: Use ALL-GENOME weights")
        print("=" * 100)
        print(f"Average accuracy: {avg_acc:.2f}%")
        print("\nFinal weights:")
        for key, val in weights_final.items():
            print(f"  {key}: {val:.4f}")
        
        return weights_final, avg_acc
    else:
        print("⚠ Generalization gap is significant")
        print("  Recommend using train/val weights or investigating overfitting")
        print()
        
        print("\n" + "=" * 100)
        print("FINAL RECOMMENDATION: Use TRAIN/VAL weights (more conservative)")
        print("=" * 100)
        print(f"Training accuracy: {train_acc:.2f}%")
        print(f"Validation accuracy: {val_acc:.2f}%")
        print("\nWeights:")
        for key, val in weights_trainval.items():
            print(f"  {key}: {val:.4f}")
        
        return weights_trainval, val_acc


final_weights, final_accuracy = combined_best_practice_optimization(TEST_GENOMES, cached_data)


                              BEST PRACTICE WEIGHT OPTIMIZATION

STEP 1: PREPARING DATA
PREPARING ALL GENOME DATA

Processing NC_000913.3...
Building traditional scoring models...
  Building codon usage models...
  Building IMM models...
✓ All models built in 2.5s
  IMM order: 6
  Training sequences: 1220 (1,809,003 bp)
  Intergenic sequences: 490 (148,082 bp)
Scoring 176,315 ORFs...
  25,000...
  50,000...
  75,000...
  100,000...
  125,000...
  150,000...
  175,000...
Done in 1.0 minutes

Normalizing ORFs...

Calculating weighted scores...
  ✓ 9972 ORF groups prepared

Processing NC_000964.3...
Building traditional scoring models...
  Building codon usage models...
  Building IMM models...
✓ All models built in 4.0s
  IMM order: 6
  Training sequences: 1187 (1,605,582 bp)
  Intergenic sequences: 437 (133,496 bp)
Scoring 139,046 ORFs...
  25,000...
  50,000...
  75,000...
  100,000...
  125,000...
Done in 1.7 minutes

Normalizing ORFs...

Calculating weighted scores...
  ✓ 8545 ORF g

KeyboardInterrupt: 
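
The weight search above relies on SciPy's `differential_evolution` with per-parameter bounds and a per-generation callback. Because the real objective takes minutes per genome to prepare, the call pattern is easier to check on a toy problem first. The sketch below is only an illustration: `toy_objective` and `TARGET` are hypothetical stand-ins for `objective_function_multi_genome`, and `seed=0` is added for repeatability. The bounds and the solver settings are the same as in the cell above.

```python
import numpy as np
from scipy.optimize import differential_evolution

# Hypothetical stand-in for the real objective: a smooth bowl whose minimum
# sits at known "ideal" weights, so the search result is easy to verify.
TARGET = np.array([2.0, 5.0, 1.0, 4.0, 0.5])

def toy_objective(weights):
    """Value to MINIMIZE (the notebook minimizes negative accuracy)."""
    return float(np.sum((np.asarray(weights) - TARGET) ** 2))

# Same bounds as the weight search: codon, imm, rbs, length, start
bounds = [(0.1, 5.0), (0.1, 10.0), (0.1, 5.0), (0.1, 10.0), (0.1, 3.0)]

history = []

def callback(xk, convergence):
    # Called once per generation with the best vector found so far
    history.append(toy_objective(xk))

result = differential_evolution(
    toy_objective,
    bounds,
    strategy='best1bin',
    maxiter=100,
    popsize=15,
    tol=1e-4,
    mutation=(0.5, 1),
    recombination=0.7,
    callback=callback,
    polish=True,
    seed=0,  # assumption: fixed seed, not used in the notebook
)

toy_weights = dict(zip(['codon', 'imm', 'rbs', 'length', 'start'], result.x))
print(toy_weights)
print(f"Best objective value: {result.fun:.6f} after {len(history)} generations")
```

Reading the final score from `result.fun` rather than from a value tracked in the callback also picks up any improvement made by the `polish=True` step, which runs after the last callback.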

# FILTERING THRESHOLDS OPTIMIZATION
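
Both filter stages optimized in this section share one rejection rule: an ORF is dropped when its codon, IMM, and length scores are all below their thresholds, or when its combined score alone is below the combined threshold. A minimal sketch of that rule follows. The helper name `passes_filter` is hypothetical (the notebook inlines this logic), and the score values and thresholds are made up for illustration.

```python
def passes_filter(orf, codon_thr, imm_thr, length_thr, combined_thr):
    """Return True if the ORF survives the rejection rule."""
    all_three_below = (
        orf.get('length_score', 0) < length_thr
        and orf.get('codon_score', 0) < codon_thr
        and orf.get('imm_score', 0) < imm_thr
    )
    combined_below = orf.get('combined_score', 0) < combined_thr
    return not (all_three_below or combined_below)

# Toy ORFs with made-up scores (illustrative only)
toy_orfs = [
    # Strong on everything: kept
    {'id': 'a', 'codon_score': 0.4, 'imm_score': 0.3, 'length_score': 0.5, 'combined_score': 0.4},
    # All three individual scores below threshold: rejected
    {'id': 'b', 'codon_score': -0.2, 'imm_score': -0.1, 'length_score': -0.3, 'combined_score': 0.3},
    # Good codon score, but combined score too low: rejected
    {'id': 'c', 'codon_score': 0.4, 'imm_score': -0.1, 'length_score': -0.3, 'combined_score': 0.05},
    # One good individual score and an adequate combined score: kept
    {'id': 'd', 'codon_score': 0.4, 'imm_score': -0.1, 'length_score': -0.3, 'combined_score': 0.3},
]

kept = [orf['id'] for orf in toy_orfs
        if passes_filter(orf, codon_thr=0.0, imm_thr=0.0, length_thr=0.0, combined_thr=0.1)]
print(kept)  # → ['a', 'd']
```

A single strong individual score is enough to escape the first condition, so the combined threshold is the only way an ORF like `c` gets removed.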

In [93]:
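def prepare_genome_data_with_scores(genome_list, cached_data, weights):
    """
    Pre-compute scored ORFs and the reference CDS coordinate set for each genome.

    Note: `weights` is currently unused here; the weighted score is computed
    later, inside the threshold objective function.
    """
    # get_gff_path is already imported from src.data_management at the top of the notebook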
    
    all_genome_data = {}
    
    print("=" * 80)
    print("PREPARING SCORED GENOME DATA")
    print("=" * 80)
    
    for gid in genome_list:
        print(f"\nProcessing {gid}...")
        
        genome_data = cached_data[gid]
        all_orfs = genome_data['orfs']
        
        training_set = create_training_set(gid, cached_data)
        intergenic_set = create_intergenic_set(gid, cached_data)
        models = build_all_scoring_models_cached(training_set, intergenic_set)
        
        scored_orfs = score_all_orfs_cached(all_orfs, models)
        scored_orfs = normalize_all_orf_scores(scored_orfs)
        scored_orfs = add_combined_scores(scored_orfs)
        
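        gff_path = get_gff_path(gid)
        # GFF columns (0-indexed): 2 = feature type, 3 = start, 4 = end (1-based, inclusive)
        ref = pd.read_csv(gff_path, sep="\t", comment="#", header=None)
        ref_genes = ref[ref[2] == "CDS"][[3, 4]]
        ref_genes.columns = ["start", "end"]
        # Exact (start, end) pairs; strand is not part of the match key
        ref_set = set(zip(ref_genes['start'], ref_genes['end']))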
        
        all_genome_data[gid] = {
            'scored_orfs': scored_orfs,
            'ref_set': ref_set
        }
        
        print(f"  ✓ {len(scored_orfs)} ORFs scored")
    
    print("\n" + "=" * 80)
    print(f"Total genomes prepared: {len(all_genome_data)}")
    print("=" * 80)
    
    return all_genome_data

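def optimize_ALL_FOUR_THRESHOLDS(TEST_GENOMES, cached_data, fixed_weights):
    """
    Optimize the codon, IMM, length, and combined thresholds for both the
    initial and the final filter (4 thresholds x 2 filters = 8 parameters).

    Note: the TEST_GENOMES argument is currently ignored; optimization runs
    on the fixed three-genome subset defined below.
    """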
    print("\n" + "=" * 100)
    print(" " * 15 + "OPTIMIZING ALL 4 THRESHOLDS (8 parameters)")
    print("=" * 100)
    
    genome_subset = ['NC_000913.3', 'NC_000964.3', 'NC_002505.1']
    
    print("\nSTEP 1: PREPARING DATA")
    genome_data = prepare_genome_data_with_scores(genome_subset, cached_data, fixed_weights)
    genome_ids = list(genome_data.keys())
    
    def objective_function(threshold_array, genome_data_dict, genome_ids, weights):
        """
        8 parameters: [init_codon, init_imm, init_length, init_combined,
                       final_codon, final_imm, final_length, final_combined]
        """
        init_codon, init_imm, init_length, init_combined = threshold_array[:4]
        final_codon, final_imm, final_length, final_combined = threshold_array[4:]
        
        def weighted_score(orf):
            return (
                orf['codon_score_norm'] * weights['codon'] +
                orf['imm_score_norm'] * weights['imm'] +
                orf['rbs_score_norm'] * weights['rbs'] +
                orf['length_score_norm'] * weights['length'] +
                orf['start_score_norm'] * weights['start']
            )
        
        total_tp = 0
        total_fp = 0
        total_fn = 0
        
        for gid in genome_ids:
            scored_orfs = genome_data_dict[gid]['scored_orfs']
            ref_set = genome_data_dict[gid]['ref_set']
            
            # Initial filter with ALL 4 thresholds
            candidates = []
            for orf in scored_orfs:
                length_score = orf.get('length_score', 0)
                codon_score = orf.get('codon_score', 0)
                imm_score = orf.get('imm_score', 0)
                combined_score = orf.get('combined_score', 0)
                
                all_three_below = (length_score < init_length and 
                                  codon_score < init_codon and 
                                  imm_score < init_imm)
                combined_below = combined_score < init_combined
                
                if not (all_three_below or combined_below):
                    candidates.append(orf)
            
            # Group and select starts
            grouped = organize_nested_orfs(candidates)
            top_candidates = []
            for stop_pos, orf_group in grouped.items():
                if len(orf_group) == 1:
                    top_candidates.append(orf_group[0])
                else:
                    best = max(orf_group, key=weighted_score)
                    top_candidates.append(best)
            
            # Final filter with ALL 4 thresholds
            final_predictions = []
            for orf in top_candidates:
                length_score = orf.get('length_score', 0)
                codon_score = orf.get('codon_score', 0)
                imm_score = orf.get('imm_score', 0)
                combined_score = orf.get('combined_score', 0)
                
                all_three_below = (length_score < final_length and 
                                  codon_score < final_codon and 
                                  imm_score < final_imm)
                combined_below = combined_score < final_combined
                
                if not (all_three_below or combined_below):
                    final_predictions.append(orf)
            
            # Calculate metrics
            for orf in final_predictions:
                coord = (orf.get('genome_start', orf['start']), 
                        orf.get('genome_end', orf['end']))
                if coord in ref_set:
                    total_tp += 1
                else:
                    total_fp += 1
            
            predicted_set = {(orf.get('genome_start', orf['start']), 
                             orf.get('genome_end', orf['end']))
                            for orf in final_predictions}
            total_fn += len(ref_set - predicted_set)
        
        sensitivity = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
        precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
        
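        # Geometric mean is high only when BOTH sensitivity and precision are high
        geometric_mean = (sensitivity * precision) ** 0.5
        
        # Penalize sensitivity below 0.78 and precision below 0.60 (10x the shortfall)
        penalty = 0
        if sensitivity < 0.78:
            penalty += 10 * (0.78 - sensitivity)
        if precision < 0.60:
            penalty += 10 * (0.60 - precision)
        
        # differential_evolution minimizes, so return the negated penalized score
        return -(geometric_mean - penalty)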
    
    print("\n" + "=" * 100)
    print("STEP 2: OPTIMIZING ALL 8 THRESHOLD PARAMETERS")
    print("=" * 100)
    print("Parameters: [init_codon, init_imm, init_length, init_combined,")
    print("             final_codon, final_imm, final_length, final_combined]")
    print()
    
    # Bounds for ALL 4 thresholds × 2 filters = 8 parameters
    bounds = [
        (-0.3, 0.3),   # init_codon_threshold
        (-0.3, 0.3),   # init_imm_threshold
        (-0.3, 0.3),   # init_length_threshold
        (-0.3, 0.2),   # init_combined_threshold
        (0.0, 0.5),    # final_codon_threshold
        (0.0, 0.5),    # final_imm_threshold
        (0.0, 0.5),    # final_length_threshold
        (0.1, 0.5),    # final_combined_threshold
    ]
    
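    iteration = [0]
    # Best penalized score so far; start at -inf so negative (penalized) scores are still reported
    best_geom = [float('-inf')]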
    
    def callback(xk, convergence):
        iteration[0] += 1
        current_score = -objective_function(xk, genome_data, genome_ids, fixed_weights)
        if current_score > best_geom[0]:
            best_geom[0] = current_score
            print(f"Iter {iteration[0]}: Score={current_score:.4f}")
            print(f"  Init:  codon={xk[0]:.3f}, imm={xk[1]:.3f}, length={xk[2]:.3f}, comb={xk[3]:.3f}")
            print(f"  Final: codon={xk[4]:.3f}, imm={xk[5]:.3f}, length={xk[6]:.3f}, comb={xk[7]:.3f}")
    
    from scipy.optimize import differential_evolution
    
    print("Starting optimization...")
    print()
    
    result = differential_evolution(
        objective_function,
        bounds,
        args=(genome_data, genome_ids, fixed_weights),
        strategy='best1bin',
        maxiter=50,      # Max generations (the weight search above used 100)
        popsize=15,      # Multiplier: population = 15 x 8 parameters = 120 individuals
        tol=0.001,
        mutation=(0.5, 1),
        recombination=0.7,
        callback=callback,
        polish=True,
        workers=1,
        updating='deferred'
    )
    
    optimal_thresholds = {
        'initial_filter': {
            'codon_threshold': result.x[0],
            'imm_threshold': result.x[1],
            'length_threshold': result.x[2],
            'combined_threshold': result.x[3]
        },
        'final_filter': {
            'codon_threshold': result.x[4],
            'imm_threshold': result.x[5],
            'length_threshold': result.x[6],
            'combined_threshold': result.x[7]
        }
    }
    
    print("\n" + "=" * 100)
    print("ALL-THRESHOLD OPTIMIZATION COMPLETE!")
    print("=" * 100)
    print(f"Best score (geometric mean minus penalties): {-result.fun:.4f}")
    print("\nOptimal Initial Filter:")
    print(f"  codon_threshold    = {optimal_thresholds['initial_filter']['codon_threshold']:.4f}")
    print(f"  imm_threshold      = {optimal_thresholds['initial_filter']['imm_threshold']:.4f}")
    print(f"  length_threshold   = {optimal_thresholds['initial_filter']['length_threshold']:.4f}")
    print(f"  combined_threshold = {optimal_thresholds['initial_filter']['combined_threshold']:.4f}")
    print("\nOptimal Final Filter:")
    print(f"  codon_threshold    = {optimal_thresholds['final_filter']['codon_threshold']:.4f}")
    print(f"  imm_threshold      = {optimal_thresholds['final_filter']['imm_threshold']:.4f}")
    print(f"  length_threshold   = {optimal_thresholds['final_filter']['length_threshold']:.4f}")
    print(f"  combined_threshold = {optimal_thresholds['final_filter']['combined_threshold']:.4f}")
    print("=" * 100)
    
    return optimal_thresholds
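

The objective above scores a threshold setting by the geometric mean of sensitivity and precision, minus a penalty when sensitivity drops below 0.78 or precision below 0.60. A worked example on toy coordinates makes the arithmetic concrete. This is a sketch under stated assumptions: `penalized_geometric_mean` is a hypothetical helper, the coordinates are invented, and predictions are held in a set, so duplicate predictions collapse (the notebook counts them from a list).

```python
def penalized_geometric_mean(predicted, reference,
                             min_sensitivity=0.78, min_precision=0.60):
    """Return (sensitivity, precision, penalized score) for exact coordinate matches."""
    tp = len(predicted & reference)
    fp = len(predicted - reference)
    fn = len(reference - predicted)

    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    geometric_mean = (sensitivity * precision) ** 0.5

    # Same penalty shape as the notebook: 10x the shortfall below each floor
    penalty = 0.0
    if sensitivity < min_sensitivity:
        penalty += 10 * (min_sensitivity - sensitivity)
    if precision < min_precision:
        penalty += 10 * (min_precision - precision)

    return sensitivity, precision, geometric_mean - penalty

# Invented (start, end) pairs for illustration
reference = {(1, 100), (200, 500), (600, 900), (1000, 1300)}
predicted = {(1, 100), (200, 500), (600, 900), (950, 1300), (1500, 1800)}

sens, prec, score = penalized_geometric_mean(predicted, reference)
print(f"sensitivity={sens:.2f}, precision={prec:.2f}, score={score:.4f}")
```

Here three of four reference genes are matched (sensitivity 0.75) and three of five predictions are correct (precision 0.60). The prediction `(950, 1300)` has the right stop but the wrong start, so under exact matching it counts as both a false positive and a false negative. The sensitivity shortfall of 0.03 costs 0.3, which pulls the score well below the raw geometric mean of about 0.67.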


In [94]:

# Run the threshold optimization with fixed scoring weights
optimal_weights = {
    'codon': 4.8562,
    'imm': 1.0107,
    'rbs': 0.6383,
    'length': 7.4367,
    'start': 0.2755
}

# Optimize ALL 4 thresholds (8 parameters)
complete_thresholds = optimize_ALL_FOUR_THRESHOLDS(TEST_GENOMES, cached_data, optimal_weights)


               OPTIMIZING ALL 4 THRESHOLDS (8 parameters)

STEP 1: PREPARING DATA
PREPARING SCORED GENOME DATA

Processing NC_000913.3...
Building traditional scoring models...
  Building codon usage models...
  Building IMM models...
✓ All models built in 2.5s
  IMM order: 6
  Training sequences: 1220 (1,809,003 bp)
  Intergenic sequences: 490 (148,082 bp)
Scoring 176,315 ORFs...


KeyboardInterrupt: 
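
Inside the threshold objective, start selection keeps one ORF per stop codon: the candidate with the highest weighted sum of normalized scores. The sketch below uses the weights from the run cell above on toy ORFs. It assumes that nested ORFs can be grouped by a shared `end` coordinate; the notebook's `organize_nested_orfs` may group differently, and `select_best_starts` is a hypothetical helper.

```python
from collections import defaultdict

# Weights copied from the run cell above
weights = {'codon': 4.8562, 'imm': 1.0107, 'rbs': 0.6383,
           'length': 7.4367, 'start': 0.2755}

def weighted_score(orf, weights):
    """Weighted sum over the *_score_norm fields, one per weight."""
    return sum(orf[f'{name}_score_norm'] * w for name, w in weights.items())

def select_best_starts(orfs, weights):
    """Keep the highest-scoring ORF among those sharing a stop position."""
    groups = defaultdict(list)
    for orf in orfs:
        groups[orf['end']].append(orf)  # assumption: same 'end' means same stop codon
    return [max(group, key=lambda o: weighted_score(o, weights))
            for group in groups.values()]

# Toy ORFs with made-up normalized scores
toy_orfs = [
    # Longer candidate, weak RBS
    {'start': 100, 'end': 900, 'codon_score_norm': 0.5, 'imm_score_norm': 0.4,
     'rbs_score_norm': 0.1, 'length_score_norm': 0.8, 'start_score_norm': 0.9},
    # Shorter candidate with the same stop, strong RBS
    {'start': 250, 'end': 900, 'codon_score_norm': 0.6, 'imm_score_norm': 0.5,
     'rbs_score_norm': 0.9, 'length_score_norm': 0.6, 'start_score_norm': 0.9},
    # Only candidate for its stop
    {'start': 1200, 'end': 1500, 'codon_score_norm': 0.3, 'imm_score_norm': 0.2,
     'rbs_score_norm': 0.4, 'length_score_norm': 0.2, 'start_score_norm': 0.5},
]

best = select_best_starts(toy_orfs, weights)
print([orf['start'] for orf in best])  # → [100, 1200]
```

With these weights the length term dominates: the longer candidate wins its group even though the shorter one has a much stronger RBS score, because the `length` weight is more than ten times the `rbs` weight.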