In [None]:
"""
Script Name: ChainLAMP_coverage_analysis_v3.4.ipynb
Version: v3.4
Author: Shane Gilligan-Steinberg
Date: 240822

Description: ******** START HERE PLEASE ********
Each of the cells of this jupyter notebook can operated separately.

Process:
[1] Coverage Software - For initial testing you can skip extension and go straight here as there are test files already set up. Make sure to run [1A]-[1B]-[1C] in order
    [1A] Setup for all neccesary inputs - please go through the folder files to understand how assays and primers are denoted.
    [1B] Running the actual code to determine coverage

Additional software: Script Name: ChainLAMP_coverage_analysis_v3.4_extension.ipynb
This can be used to generate a sets of input sequences organized by subtype for use in pipeline
[1] Generate library of sequences to be inputted into the pipeline (from LANL alignment - need to remove gaps and organize by subtype).
Other capability of splitting by year is not available in this version
    [1B] Generate alignments with additional split by year of sequeneces
    [1C] Another option is to gather sequences from GenBank IDs
"""
# Install all neccesary libraries
!pip install biopython pandas matplotlib tqdm

In [16]:
"""
Script Name: ChainLAMP_coverage_analysis_v2.7.ipynb [1A]
Version: v2.11
Author: Shane Gilligan-Steinberg
Date: 240715

Description:
Run analysis of primer sets agains libraries of HIV sequences.
Aligns primers within assays to libraries of sequences (organizes by subtypes or other metric). Finds best alignment (no gaps) and all mismatches.
Then identifies coverage in a few ways (1) Perfect coverage (2) Coverage with single mismatch 
(3) Coverage with single mismatch (not in last 3 bp of critical termini) (4) Coverage with single mismatch (not in last 3 bp of critical termini)
(5) Coverage based on ROSALIND Scoring
Could add extra readout of coverage by assay and primer (for development)

Input:
1. List of assays (.csv) [~/Assays]
2. List of primer sequences for each assay (.fasta) [~/Assays]
- This software enables assessment of original and reverse complement.
- Report primers as 5'-3' with 3' end as critical termini (e.x. Report RC of F1 as the 5' is the critical termini)
- For LAMP, use primer regions. Make sure to be careful about where you split F1/F2 and B1/B2
- Probes: add NNNNN if direction of alignment is not important
3. .csv list of subtypes (can use script 230704_Split_FASTA_v2_CRF.py to perform separation from LANL alignmnents) [~/Targets]
4. .fasta library of HIV sequences separated by subtype (can use script 230704_Split_FASTA_v2_CRF.py to perform separation
from LANL alignmnents) [~/Alignments]

Outputs: In folder ("/Outputs")
1. NAME_organized_data.json: alignment for all primers
2. NAME_aggregate.json: alignmeng by assay
3. NAME_coverage.json: coverage by subtype
4. NAME_coverage.csv: coverage
5. NAME_organized_data_primer.json: coverage
6. NAME_by_base.json: coverage at the base pair level
7. NAME_case_info.json: information about the software run
"""

# Import libraries
import time
import pandas as pd
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqUtils import nt_search
from Bio.SeqRecord import SeqRecord
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
import re
from Bio import Align
from Bio.Seq import Seq
import json
import tqdm
import csv

# Parameters
testing = False # True disables the global coverage weighting
globalAnalysis = False # True enables different coverage by year or test case
singleBaseAnalysis = True # True enables a readout of the coverage at the base pair level
enableGaps = False # The code is not able to parse gaps at the moment

# Primer order check for coverage
primer_order_check = True
primer_proximity_check = True
proximity_max_footprint = 500 ## 500 for LAMP to give buffer

# Primer Tm Calculation (Na/Mg = mM) (DNA = nM)
na_concentration = 60  # 60 mM
primer_concentration = 1000  # 1 µM = 1000 nM
mg_concentration = 8  # 8 mM
dntp_concentration = 1.4 # 1.4 mM
corrected_mg_concentration = max(0, mg_concentration - dntp_concentration)

## Expected primer orders (allows both backwards and forwards). This is the first word of the primer name (ex. F1 T7 - "F1" is primer name)
expected_orders = [
    ['F2', 'FL', 'F1', 'B1', 'BL', 'B2'],
    ['F3', 'F2', 'FL', 'F1', 'B1', 'BL', 'B2', 'B3'],
    ['F3', 'F2', 'FL', 'F1', 'B1', 'B2', 'B3'],
    ['FP', 'Probe', 'RP'],
    ['FP', 'RP']
]

# Below are the three test cases included in this tutorial
# 1. Chain LAMP with a small set of sequences
test_cases = [
    {
        "output_name_base": "240716_Chain_TEST", # Output name
        "assays": "Targets/targets_Chain_LAMP.csv", # list of assays. see sample file.  This points towards .fasta files organized by assay with primers
        "alignments": "Alignments/Subtypes_Pol_TEST.csv", # List of subtypes
        "rosalind_threshold": 3,
        "testing": True # True disables global coverage weighting
    }]
# # 2. Chain LAMP with a all pol sequences (takes ~ 30 minutes to run)
# test_cases = [
#     {
#         "output_name_base": "240716_Chain", # Output name
#         "assays": "Targets/targets_Chain_LAMP.csv", # list of assays. see sample file.  This points towards .fasta files organized by assay with primers
#         "alignments": "Alignments/Subtypes_Pol.csv", # List of subtypes
#         "rosalind_threshold": 3,
#         "testing": False
#     }]
# # 3. Mock validation set with known outputs
# test_cases = [
#     {
#         "output_name_base": "Validation", # Output name
#         "assays": "Targets/240710_Test.csv", # list of assays. see sample file.  This points towards .fasta files organized by assay with primers
#         "alignments": "Alignments/Subtypes_240710_Test.csv", # List of subtypes
#         "rosalind_threshold": 3,
#         "testing": True
#     }]
# ### Validation - expected outcome should align with Validation_ExpectedOutput.csv
# Please set the following to false for the validation
# ### primer_order_check = False
# ### primer_proximity_check = False

print(test_cases)

[{'output_name_base': 'Validation', 'assays': 'Targets/240710_Test.csv', 'alignments': 'Alignments/Subtypes_240710_Test.csv', 'rosalind_threshold': 3, 'testing': True}]


In [17]:
""" [1B] Coverage Software - Running the actual code to determine coverage """

from Bio.SeqUtils import MeltingTemp as mt
import json
from collections import defaultdict, OrderedDict

iupac_codes = {
    'A': {'A'}, 'C': {'C'}, 'G': {'G'}, 'T': {'T'},
    'R': {'A', 'G'}, 'Y': {'C', 'T'}, 'S': {'G', 'C'}, 'W': {'A', 'T'},
    'K': {'G', 'T'}, 'M': {'A', 'C'}, 'B': {'C', 'G', 'T'}, 'D': {'A', 'G', 'T'},
    'H': {'A', 'C', 'T'}, 'V': {'A', 'C', 'G'}, 'N': {'A', 'C', 'G', 'T'}
}

def read_fasta(file):
    return list(SeqIO.parse(file, "fasta"))

def reverse_complement(seq):
    return str(Seq(seq).reverse_complement())

def find_mismatches(seq, primer, start_pos, alignment_type):
    """
    Find mismatches between a sequence segment and a primer.
    Args:
        seq (str): The target sequence.
        primer (str): The primer sequence.
        start_pos (int): Starting position of the segment in the target sequence.  
    Returns:
        List[Tuple[int, int, str, str]]: List of mismatches as (target_location, primer_location, target_base, primer_base).
    """
    segment = seq[start_pos:start_pos + len(primer)]
    if alignment_type == "RC": ## Adjust the location of mismatch for RC
        mismatches = [(start_pos + j, len(primer) - 1 - j, a, b) for j, (a, b) in enumerate(zip(segment, primer))
                  if a not in iupac_codes.get(b, {b})]
    else:
        mismatches = [(start_pos + j, j, a, b) for j, (a, b) in enumerate(zip(segment, primer))
                    if a not in iupac_codes.get(b, {b})]
    return mismatches # (target_location, primer_location, target_base, primer_base)

def find_best_match(seq, primer):
    """
    Find the best match for a primer in a sequence, considering mismatches.
    Args:
        seq (str): The target sequence.
        primer (str): The primer sequence.
    Returns:
        Tuple[int, List[Tuple[int, int, str, str]], str, int]: Best match position, mismatches, alignment type, and number of mismatches.
    """
    best_mismatches = len(primer) + 1
    best_position = -1
    best_mismatches_details = []
    best_alignment_type = "primer"  # "primer" or "RC"

    for alignment_type, current_primer in [("primer", primer), ("RC", reverse_complement(primer))]:
        for i in range(len(seq) - len(current_primer) + 1):
            segment = seq[i:i + len(current_primer)]
            mismatch_count = sum(1 for a, b in zip(segment, current_primer) if a not in iupac_codes.get(b, {b}))
            if mismatch_count < best_mismatches:
                best_mismatches = mismatch_count
                best_position = i
                best_mismatches_details = find_mismatches(seq, current_primer, i, alignment_type)
                best_alignment_type = alignment_type

    return best_position, best_mismatches_details, best_alignment_type, best_mismatches

def organize_hierarchically(data):
    """
    Organize data hierarchically by file name, subtype, sequence ID, and assay name.
    Args:
        data (List[Tuple]): List of data tuples to organize.
    Returns:
        Dict: Hierarchically organized data.
    """
    organized_data = {}

    for entry in data:
        file_name, subtype, assay_name, sequence_id, primer_id, primer_length, best_position, mismatches, alignment_type, num_mismatches, primer_seq = entry
        if file_name not in organized_data:
            organized_data[file_name] = {}
        if subtype not in organized_data[file_name]:
            organized_data[file_name][subtype] = {}
        if sequence_id not in organized_data[file_name][subtype]:
            organized_data[file_name][subtype][sequence_id] = {}
        if assay_name not in organized_data[file_name][subtype][sequence_id]:
            organized_data[file_name][subtype][sequence_id][assay_name] = []
        organized_data[file_name][subtype][sequence_id][assay_name].append({
            "Primer_ID": primer_id,
            "Primer_length": primer_length,
            "Best_Position": best_position,
            "Mismatches": mismatches,
            "Alignment_Type": alignment_type,
            "Num_Mismatches": num_mismatches,
            "Primer_Sequence": primer_seq
        })

    return organized_data

def perfect_coverage(organized_data):
    """
    Calculate perfect coverage and coverage with one mismatch for each entry in the data.
    Args:
        organized_data (Dict): Hierarchically organized data. 
    Returns:
        Dict: Data with added coverage information.
    """
    for file_name, subtypes in organized_data.items():
        for subtype, sequences in subtypes.items():
            for sequence_id, assays in sequences.items():
                for assay_name, entries in assays.items():
                    for entry in entries:
                        # Add Coverage field based on Num_Mismatches
                        entry['Perfect_coverage'] = 1 if entry['Num_Mismatches'] == 0 else 0
                        entry['Coverage_with_one_mismatch'] = 1 if entry['Num_Mismatches'] <= 1 else 0
    return organized_data

def coverage_with_one_mismatch_n_bases_from_end(organized_data, n):
    """
    Calculate coverage with one mismatch within n bases from the end for each entry in the data.
    Args:
        organized_data (Dict): Hierarchically organized data.
        n (int): Number of bases from the end to consider. 
    Returns:
        Dict: Data with added coverage information.
    """
    coverage_field_name = f"Coverage_{n}"
    for file_name, subtypes in organized_data.items():
        for subtype, sequences in subtypes.items():
            for sequence_id, assays in sequences.items():
                for assay_name, entries in assays.items():
                    for entry in entries:
                        mismatches = entry['Mismatches']
                        primer_length = entry['Primer_length']
                        mismatch_within_n_bases = False

                        if len(mismatches) == 1:
                            mismatch_position = mismatches[0][1]
                            if primer_length - mismatch_position <= n:
                                mismatch_within_n_bases = True

                        if len(mismatches) == 0 or (len(mismatches) == 1 and not mismatch_within_n_bases):
                            entry[coverage_field_name] = 1
                        else:
                            entry[coverage_field_name] = 0
    return organized_data

def coverage_with_one_mismatch_per_assay(organized_data):
    """
    Calculate coverage with one mismatch per assay for each entry in the data.
    Args:
        organized_data (Dict): Hierarchically organized data.
    Returns:
        Dict: Data with added coverage information.
    """
    for file_name, subtypes in organized_data.items():
        for subtype, sequences in subtypes.items():
            for sequence_id, assays in sequences.items():
                for assay_name, entries in assays.items():
                    num_entries_with_no_mismatches = sum(1 for entry in entries if entry['Num_Mismatches'] == 0)
                    num_entries_with_one_mismatch = sum(1 for entry in entries if entry['Num_Mismatches'] == 1)
                    # Up to one entry with a mismatch while all others have none
                    if num_entries_with_one_mismatch <= 1 and num_entries_with_no_mismatches >= len(entries) - 1:
                        for entry in entries:
                            entry['Coverage_with_one_mismatch_per_assay'] = 1
                    else:
                        for entry in entries:
                            entry['Coverage_with_one_mismatch_per_assay'] = 0
    return organized_data

def calculate_assay_score(organized_data):
    """
    Calculate ROSALIND score for each primer based on mismatch positions and add it to the data.
    ROSALIND: https://www.rosalind.bio/en/knowledge/calculating-the-severity-score-for-lamp-assays
    Args:
        organized_data (Dict): Hierarchically organized data.
    Returns:
        Dict: Data with added primer scores.
    """

    for file_name, subtypes in organized_data.items():
        for subtype, sequences in subtypes.items():
            for sequence_id, assays in sequences.items():
                for assay_name, entries in assays.items():
                    #print(assay_name)
                    for entry in entries:
                        #print(entry['Primer_ID'])
                        score = 0
                        mismatches = entry['Mismatches']
                        primer_length = entry['Primer_length']
                        total_mismatches = entry['Num_Mismatches']
                        sequence = entry['Primer_Sequence']
                        old_sequence = correct_mismatches(entry['Primer_Sequence'], entry['Mismatches'], entry['Alignment_Type'])

                        # Score +2 for mismatches that are next to each other
                        for i in range(len(mismatches) - 1):
                            if mismatches[i][1] + 1 == mismatches[i + 1][1]:
                                score += 2
                                #print("+2 Next")

                        # Score +3 per mismatch in the last 2 bases
                        for mismatch in mismatches:
                            mismatch_position = mismatch[1]
                            if primer_length - mismatch_position <= 2:
                                score += 3
                                #print("+3 Last2")

                        # Score +2 per mismatch in positions 3-5 from the end
                        for mismatch in mismatches:
                            mismatch_position = mismatch[1]
                            if 3 <= primer_length - mismatch_position <= 5:
                                score += 2
                                #print("+2 Last 5")
                        
                        # Calculate Tm difference and adjust score
                        tm_diff = calculate_tm_diff(old_sequence, sequence)
                        if abs(tm_diff) > 5:
                            score += 4
                            #print("+4 TM")
                        elif abs(tm_diff) > 2.5:
                            score += 2
                            #print("+2 TM")

                        # Calculate score based on total number of mismatches and other criteria
                        entry['ROSALIND_Score'] = calculate_mismatch_score(total_mismatches) + score
                        #print("MS" + str(total_mismatches))
                        #print(calculate_mismatch_score(total_mismatches))

    return organized_data

def calculate_mismatch_score(mismatches):
    """
    Calculate the mismatch score (within ROSALIND) based on the total number of mismatches.
    ROSALIND: https://www.rosalind.bio/en/knowledge/calculating-the-severity-score-for-lamp-assays
    Args:
        mismatches (int): Total number of mismatches.   
    Returns:
        int: Mismatch score.
    """
    #print("INPUT" + str(mismatches))
    if mismatches == 0 or mismatches == 1:
        return 0
    elif mismatches == 2:
        return 1
    elif mismatches == 3:
        return 2
    else:
        return 4

def aggregate_coverage(organized_data):
    """
    Aggregate coverage information from the organized data.
    Args:
        organized_data (Dict): Hierarchically organized data.  
    Returns:
        Dict: Aggregated coverage information.
    """
    aggregated_data = {}

    for file_name, subtypes in organized_data.items():
        for subtype, sequences in subtypes.items():
            for sequence_id, assays in sequences.items():
                for assay_name, entries in assays.items():
                    max_score = max(entry['ROSALIND_Score'] for entry in entries)
                    representative_score = min(max_score, 5)
                    ##representative_score = entries[0]['ROSALIND_Score'] if entries else 0 ## Old version with score for whole assay

                    perfect_coverage = all(entry['Perfect_coverage'] == 1 for entry in entries)
                    coverage_with_one_mismatch = all(entry['Coverage_with_one_mismatch'] == 1 for entry in entries)
                    coverage_3 = all(entry['Coverage_3'] == 1 for entry in entries)
                    coverage_5 = all(entry['Coverage_5'] == 1 for entry in entries)
                    coverage_with_one_mismatch_per_assay = all(entry['Coverage_with_one_mismatch_per_assay'] == 1 for entry in entries)

                    primers_close = all(entry['Primers_Close_Enough'] for entry in entries)
                    primers_ordered = all(entry['Primers_Ordered_Correctly'] for entry in entries)

                    
                    if file_name not in aggregated_data:
                        aggregated_data[file_name] = {}
                    if subtype not in aggregated_data[file_name]:
                        aggregated_data[file_name][subtype] = {}
                    if sequence_id not in aggregated_data[file_name][subtype]:
                        aggregated_data[file_name][subtype][sequence_id] = {}
                    
                    aggregated_data[file_name][subtype][sequence_id][assay_name] = {
                        "ROSALIND_Score": representative_score,
                        "Perfect_coverage": int(perfect_coverage),
                        "Coverage_with_one_mismatch": int(coverage_with_one_mismatch),
                        "Coverage_3": int(coverage_3),
                        "Coverage_5": int(coverage_5),
                        "Coverage_with_one_mismatch_per_assay": int(coverage_with_one_mismatch_per_assay),
                        "Primers_Close_Enough": primers_close,
                        "Primers_Ordered_Correctly": primers_ordered
                    }
    
    return aggregated_data

def update_json_file(file_path, new_data):
    """
    Update an existing JSON file with new data.
    Args:
        file_path (str): Path to the JSON file.
        new_data (Dict): New data to update in the JSON file.
    """
    os.makedirs(os.path.dirname(file_path), exist_ok=True) # Ensure the directory exists

    # Read the existing file if it exists
    if os.path.exists(file_path):
        with open(file_path, 'r') as file:
            existing_data = json.load(file)
    else:
        existing_data = {}

    # Update the existing data with new aggregated data
    for key, value in new_data.items():
        if key in existing_data:
            existing_data[key].update(value)
        else:
            existing_data[key] = value

    # Write the updated data back to the file
    with open(file_path, 'w') as file:
        json.dump(existing_data, file, indent=4)

def assay_coverage_analysis(file_path, rosalind_threshold):
    """
    Create the final nested dictionary with assay coverage analysis.
    Args:
        file_path (str): Path to the JSON file.
        rosalind_threshold (int): Threshold for the ROSALIND score.  
    Returns:
        Dict: Final nested dictionary with assay coverage analysis.
    """
    with open(file_path, 'r') as file:
        data = json.load(file)
    
    final_dict = {}
    
    for file_name, subtypes in data.items():
        if file_name not in final_dict:
            final_dict[file_name] = {}
        for subtype, sequences in subtypes.items():
            if subtype not in final_dict[file_name]:
                final_dict[file_name][subtype] = {
                    "ROSALIND_Score_count": 0,
                    "Perfect_coverage_count": 0,
                    "Coverage_with_one_mismatch_count": 0,
                    "Coverage_3_count": 0,
                    "Coverage_5_count": 0,
                    "Coverage_with_one_mismatch_per_assay_count": 0,
                    "Total_sequences": len(sequences)
                }
            for sequence_id, assays in sequences.items():
                perfect_coverage = False
                coverage_with_one_mismatch = False
                coverage_3 = False
                coverage_5 = False
                coverage_ROSALIND = False
                coverage_with_one_mismatch_per_assay = False

                for entry in assays.values():

                    ## Leave as false if primer order check or proximity is turned on and doesn't meet criteria
                    if primer_order_check and not entry['Primers_Ordered_Correctly'] == True:
                            continue
                    if primer_proximity_check and not entry['Primers_Close_Enough'] == True:
                            continue
                    
                    ## Set each of these to true if true for one assay in set
                    perfect_coverage = perfect_coverage or entry["Perfect_coverage"] == 1
                    coverage_with_one_mismatch = coverage_with_one_mismatch or entry["Coverage_with_one_mismatch"] == 1
                    coverage_3 = coverage_3 or entry["Coverage_3"] == 1
                    coverage_5 = coverage_5 or entry["Coverage_5"] == 1
                    coverage_ROSALIND = coverage_ROSALIND or entry["ROSALIND_Score"] <= rosalind_threshold
                    coverage_with_one_mismatch_per_assay = coverage_with_one_mismatch_per_assay or entry["Coverage_with_one_mismatch_per_assay"] == 1

                if perfect_coverage:
                    final_dict[file_name][subtype]["Perfect_coverage_count"] += 1
                if coverage_with_one_mismatch:
                    final_dict[file_name][subtype]["Coverage_with_one_mismatch_count"] += 1
                if coverage_3:
                    final_dict[file_name][subtype]["Coverage_3_count"] += 1
                if coverage_5:
                    final_dict[file_name][subtype]["Coverage_5_count"] += 1
                if coverage_with_one_mismatch_per_assay:
                    final_dict[file_name][subtype]["Coverage_with_one_mismatch_per_assay_count"] += 1
                if coverage_ROSALIND:
                    final_dict[file_name][subtype]["ROSALIND_Score_count"] += 1    
    return final_dict

def create_csv_from_final_dict(final_dict, output_csv):
    categories = ["ROSALIND_Score_count", "Perfect_coverage_count", "Coverage_with_one_mismatch_count", "Coverage_3_count", "Coverage_5_count", "Coverage_with_one_mismatch_per_assay_count"]

    # Prepare the data for CSV
    rows = []
    aggregate_totals = {category: 0 for category in categories}
    aggregate_total_sequences = 0

    for file_name, subtypes in final_dict.items():
        for subtype, metrics in subtypes.items():
            row = [subtype]
            total_sequences = metrics["Total_sequences"]
            aggregate_total_sequences += total_sequences

            for category in categories:
                count = metrics[category]
                aggregate_totals[category] += count
                percent = (count / total_sequences) * 100 if total_sequences > 0 else 0
                row.append(count)
                row.append(f"{percent:.1f}%")
            row.append(total_sequences)  # Add the total sequences to the row
            rows.append(row)
    
    # Calculate aggregated row
    aggregate_row = ["Aggregated"]
    for category in categories:
        count = aggregate_totals[category]
        percent = (count / aggregate_total_sequences) * 100 if aggregate_total_sequences > 0 else 0
        aggregate_row.append(count)
        aggregate_row.append(f"{percent:.1f}%")
    aggregate_row.append(aggregate_total_sequences)  # Add the total sequences to the aggregated row

    # Write the data to CSV
    with open(output_csv, 'w', newline='') as csvfile:
        csvwriter = csv.writer(csvfile)
        header = ["Subtype"]
        for category in categories:
            header.append(f"{category.replace('_count', '')} (Count)")
            header.append(f"{category.replace('_count', '')} (%)")
        header.append("Total_sequences")  # Add the header for total sequences
        csvwriter.writerow(header)
        csvwriter.writerows(rows)
        csvwriter.writerow(aggregate_row)  # Write the aggregated row

    print(f"CSV file {output_csv} created successfully.")

def calculate_weighted_aggregate(global_diversity_csv, output_csv, output_aggregate_csv):
    # Load the CSV files
    global_diversity_df = pd.read_csv(global_diversity_csv)
    output_df = pd.read_csv(output_csv)
    input_df = output_df
    # Remove the "Aggregated" row from the output dataframe
    output_df = output_df[output_df['Subtype'] != 'Aggregated']

    # Ensure both dataframes have the same number of rows
    assert len(global_diversity_df) == len(output_df), "The number of rows in the two files must match."

    # Define the categories to be aggregated
    categories = [
        "ROSALIND_Score (%)", "Perfect_coverage (%)",
        "Coverage_with_one_mismatch (%)", "Coverage_3 (%)",
        "Coverage_5 (%)", "Coverage_with_one_mismatch_per_assay (%)"
    ]
    for category in categories:
        #output_df[category] = output_df[category].str.rstrip('%').astype('float') / 100.0
        output_df.loc[:, category] = output_df[category].str.rstrip('%').astype('float') / 100.0
        
    # Initialize dictionary to store weighted sums and total prevalence
    weighted_sums = {category: 0 for category in categories}
    total_prevalence = global_diversity_df['Prevalence'].sum()

    # Calculate weighted sums
    for index in range(len(global_diversity_df)):
        prevalence = global_diversity_df.loc[index, 'Prevalence']
        for category in categories:
            weighted_sums[category] += output_df.loc[index, category] * prevalence

    # Calculate weighted percentages
    weighted_aggregate_row = ["Weighted Aggregated"]
    for category in categories:  # Skip Total_sequences for percentage calculation
        weighted_percent = weighted_sums[category] 
        weighted_aggregate_row.append(f"{weighted_percent:.1f}%")
    
    # Prepare the data for writing to CSV
    rows = [row for row in input_df.values]  # Existing rows
    print(output_df)
    weighted_aggregate_row_format = []
    for cell in weighted_aggregate_row:
        weighted_aggregate_row_format.append(cell)
        weighted_aggregate_row_format.append("")
    rows.append(weighted_aggregate_row_format)  # Append the weighted aggregate row

    # Write the data to the new CSV
    with open(output_aggregate_csv, 'w', newline='') as csvfile:
        csvwriter = csv.writer(csvfile)
        header = output_df.columns.tolist()
        csvwriter.writerow(header)
        csvwriter.writerows(rows)
    
    print(f"CSV file {output_aggregate_csv} created successfully with weighted aggregates.")

def calculate_coverage_by_subtype_primer(organized_data_file, output_file, rosalind_threshold):
    """
    Calculate coverage by subtype for each primer region, including ROSALIND scoring.
    
    Args:
        organized_data_file (str): Path to the JSON file containing organized data.
        output_file (str): Path to the output JSON file for saving the coverage analysis results.
        rosalind_threshold (int): Threshold for the ROSALIND score. Default is 5.
    
    Returns: Create organized_data_primer.json file
    """

    with open(organized_data_file, 'r') as file:
        organized_data = json.load(file)

    coverage_by_subtype = {}

    for file_name, subtypes in organized_data.items():
        for subtype, sequences in subtypes.items():
            if subtype not in coverage_by_subtype:
                coverage_by_subtype[subtype] = {}
            for sequence_id, assays in sequences.items():
                for assay_name, entries in assays.items():
                    if assay_name not in coverage_by_subtype[subtype]:
                        coverage_by_subtype[subtype][assay_name] = {}
                    for entry in entries:
                        primer_id = entry['Primer_ID']
                        if primer_id not in coverage_by_subtype[subtype][assay_name]:
                            coverage_by_subtype[subtype][assay_name][primer_id] = {
                                'Perfect_coverage': 0,
                                'Coverage_with_one_mismatch': 0,
                                'Coverage_3': 0,
                                'Coverage_5': 0,
                                'Coverage_with_one_mismatch_per_assay': 0,
                                'ROSALIND_Score_within_threshold': 0,
                                'Total': 0
                            }

                        coverage_by_subtype[subtype][assay_name][primer_id]['Perfect_coverage'] += entry['Perfect_coverage']
                        coverage_by_subtype[subtype][assay_name][primer_id]['Coverage_with_one_mismatch'] += entry['Coverage_with_one_mismatch']
                        coverage_by_subtype[subtype][assay_name][primer_id]['Coverage_3'] += entry['Coverage_3']
                        coverage_by_subtype[subtype][assay_name][primer_id]['Coverage_5'] += entry['Coverage_5']
                        coverage_by_subtype[subtype][assay_name][primer_id]['Coverage_with_one_mismatch_per_assay'] += entry['Coverage_with_one_mismatch_per_assay']
                        if entry['ROSALIND_Score'] <= rosalind_threshold:
                            coverage_by_subtype[subtype][assay_name][primer_id]['ROSALIND_Score_within_threshold'] += 1
                        coverage_by_subtype[subtype][assay_name][primer_id]['Total'] += 1

    with open(output_file, 'w') as outfile:
        json.dump(coverage_by_subtype, outfile, indent=4)

    print(f"Coverage by subtype with ROSALIND scoring has been calculated and saved to {output_file}")

def check_primer_proximity_and_order(organized_data):
    """
    Check if all primers in each assay are close enough to each other and in correct order based on 'Best_Position'.
    Args:
        organized_data (Dict): Hierarchically organized data.
        max_distance (int): Maximum allowed distance between primers.
    Returns:
        Dict: Data with flags indicating if primers are within the maximum distance and in correct order.
    """
    def are_primers_close_enough(primer_positions, max_distance):
        for i in range(len(primer_positions)):
            for j in range(i + 1, len(primer_positions)):
                if abs(primer_positions[i] - primer_positions[j]) > max_distance:
                    return False
        return True

    def are_primers_ordered(sorted_primers):
        sorted_ids = [primer['Primer_ID'] for primer in sorted_primers]

        for expected_order in expected_orders:
            if sorted_ids == expected_order or sorted_ids == expected_order[::-1]:
                return True
        return False

    for file_name, subtypes in organized_data.items():
        for subtype, sequences in subtypes.items():
            for sequence_id, assays in sequences.items():
                for assay_name, entries in assays.items():
                    # Sort primers based on Best_Position
                    sorted_primers = sorted(entries, key=lambda x: x['Best_Position'])
                    
                    # Extract the Best_Position values
                    primer_positions = [entry['Best_Position'] for entry in sorted_primers]

                    # Check if all primers are within the max_distance
                    primers_close = are_primers_close_enough(primer_positions, proximity_max_footprint)
                    
                    # Check if primers are ordered correctly based on 'Primer_ID' and 'Best_Position'
                    primers_ordered = are_primers_ordered(sorted_primers)
                    
                    # Add the flags to indicate primer proximity and order
                    for entry in entries:
                        entry['Primers_Close_Enough'] = primers_close
                        entry['Primers_Ordered_Correctly'] = primers_ordered

    return organized_data

def correct_mismatches(primer_sequence, mismatches, alignment_type):
    """
    Correct mismatches in the primer sequence to generate the original sequence.
    Args:
        primer_sequence (str): The primer sequence with mismatches.
        mismatches (list): List of mismatches, each defined as [position, index, original_base, mismatched_base].
        alignment_type (str): Alignment type, "RC" for reverse complement.
    Returns:
        str: The corrected (original) sequence.
    """
    # If the alignment type is RC, reverse complement the sequence
    if alignment_type == "RC":
        primer_sequence = str(Seq(primer_sequence).reverse_complement())

    # Convert to list for mutable string
    original_sequence = list(primer_sequence)

    # Correct each mismatch
    for mismatch in mismatches:
        position, index, original_base, mismatched_base = mismatch
        
        # Adjust index if alignment is RC since we flipped the sequence
        if alignment_type == "RC":
            index = len(primer_sequence) - 1 - index

        # Replace the mismatched base with the original base
        original_sequence[index] = original_base

    # Join the list back into a string
    return ''.join(original_sequence)

def calculate_tm_diff(primer_sequence, old_primer):
    """
    Calculate the melting temperature (Tm) of a DNA primer using the Biopython library.

    Parameters:
    primer_sequence (str): The DNA sequence of the primer.
    na_concentration (float): The concentration of monovalent cations (e.g., Na+) in M. Default is 50 mM.
    primer_concentration (float): The concentration of primers in M. Default is 250 nM.

    Returns:
    float: The melting temperature (Tm) in degrees Celsius.
    """
    # Calculate the Tm using the nearest-neighbor method provided by Biopython
    tm = mt.Tm_NN(primer_sequence, Na=na_concentration, Mg=corrected_mg_concentration, dnac1=primer_concentration)
    tm_old = mt.Tm_NN(old_primer, Na=na_concentration, Mg=corrected_mg_concentration, dnac1=primer_concentration)

    return tm-tm_old

def write_run_info(output_file, output_name_base, assays, alignments, rosalind_threshold,
                   testing, globalAnalysis, primer_order_check, primer_proximity_check,
                   proximity_max_footprint, na_concentration, primer_concentration,
                   mg_concentration, dntp_concentration, corrected_mg_concentration,
                   expected_orders):
    """
    Write the run information for the test case to a JSON file.

    Args:
        output_file (str): The path to the output JSON file.
        output_name_base (str): The base name for the output.
        assays (str): Path to the assays file.
        alignments (str): Path to the alignments file.
        rosalind_threshold (int): Threshold for the ROSALIND score.
        testing (bool): Indicates if the run is for testing purposes.
        globalAnalysis (bool): Indicates if global analysis is enabled.
        primer_order_check (bool): Flag to check primer order for coverage.
        primer_proximity_check (bool): Flag to check primer proximity for coverage.
        proximity_max_footprint (int): Maximum allowed distance for primer proximity.
        na_concentration (float): Sodium ion concentration in mM.
        primer_concentration (float): Primer concentration in nM.
        mg_concentration (float): Magnesium ion concentration in mM.
        dntp_concentration (float): dNTP concentration in mM.
        corrected_mg_concentration (float): Corrected magnesium ion concentration after accounting for dNTP binding.
        expected_orders (list): List of expected primer orders.

    Returns:
        None
    """
    run_info = {
        "output_name_base": output_name_base,
        "assays": assays,
        "alignments": alignments,
        "rosalind_threshold": rosalind_threshold,
        "testing": testing,
        "globalAnalysis": globalAnalysis,
        "primer_order_check": primer_order_check,
        "primer_proximity_check": primer_proximity_check,
        "proximity_max_footprint": proximity_max_footprint,
        "na_concentration": na_concentration,
        "primer_concentration": primer_concentration,
        "mg_concentration": mg_concentration,
        "dntp_concentration": dntp_concentration,
        "corrected_mg_concentration": corrected_mg_concentration,
        "expected_orders": expected_orders
    }

    with open(output_file, 'w') as f:
        json.dump(run_info, f, indent=4)

# Dictionary to handle IUPAC nucleotide complements
IUPAC_COMPLEMENT = {
    'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G',
    'R': 'Y', 'Y': 'R', 'S': 'S', 'W': 'W',
    'K': 'M', 'M': 'K', 'B': 'V', 'V': 'B',
    'D': 'H', 'H': 'D', 'N': 'N'
}

def process_mismatches(json_file):
    with open(json_file, 'r') as file:
        data = json.load(file)

    mismatch_frequencies = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(int))))

    for experiment, subtypes in data.items():
        for subtype, sequences in subtypes.items():
            for sequence, assays in sequences.items():
                for assay_name, primers in assays.items():
                    for primer in primers:
                        primer_id = primer.get("Primer_ID")
                        primer_sequence = primer.get("Primer_Sequence")
                        mismatches = primer.get("Mismatches", [])

                        # Track mismatched positions
                        mismatched_positions = set()

                        # Process mismatches
                        for mismatch in mismatches:
                            position = mismatch[1]
                            reference_base = mismatch[3]  # Switch to observed base
                            observed_base = mismatch[2]   # Switch to reference base

                            if primer["Alignment_Type"] == "RC":
                                # Apply reverse complement if alignment type is RC
                                position = len(primer["Primer_Sequence"]) - 1 - position
                                reference_base = reverse_complement(mismatch[3])
                                observed_base = reverse_complement(mismatch[2])

                            # Mark position as mismatched
                            mismatched_positions.add(position)

                            # Update the frequency count
                            mismatch_frequencies[subtype][assay_name][primer_id][(position, observed_base)] += 1

                        # Initialize positions that have no mismatches with the original base
                        for i, base in enumerate(primer_sequence):
                            if i not in mismatched_positions:
                                mismatch_frequencies[subtype][assay_name][primer_id][(i, base+"*")] += 1

    return mismatch_frequencies

def calculate_percentages(mismatch_frequencies):
    percentages = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(dict))))

    for subtype, assays in mismatch_frequencies.items():
        for assay, primers in assays.items():
            for primer_id, mismatches in primers.items():
                position_totals = defaultdict(int)

                # Calculate total occurrences per position
                for (position, base), frequency in mismatches.items():
                    position_totals[position] += frequency

                # Calculate percentages
                for (position, base), frequency in mismatches.items():
                    total = position_totals[position]
                    if total > 0:
                        percentage = (frequency / total) * 100
                        percentages[subtype][assay][primer_id][position][base] = round(percentage, 2)

    return percentages

def save_percentages_to_json(percentages, output_file):
    with open(output_file, 'w') as file:
        json.dump(percentages, file, indent=4)

def print_percentages(percentages):
    for subtype, assays in percentages.items():
        print(f"Subtype: {subtype}")
        for assay, primers in assays.items():
            print(f"  Assay: {assay}")
            for primer_id, positions in primers.items():
                print(f"    Primer ID: {primer_id}")
                for position, bases in positions.items():
                    print(f"      Position {position}:")
                    for base, percentage in bases.items():
                        print(f"        Base {base}: {percentage:.2f}%")

def sort_mismatch_frequencies_by_position(mismatch_frequencies):
    """
    Sort the positions within each primer by position.
    """
    sorted_frequencies = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: OrderedDict())))

    for subtype, assays in mismatch_frequencies.items():
        for assay, primers in assays.items():
            for primer_id, positions in primers.items():
                # Sort the positions by their integer key (position)
                sorted_positions = OrderedDict(sorted(positions.items()))
                sorted_frequencies[subtype][assay][primer_id] = sorted_positions

    return sorted_frequencies

def process_and_save_mismatch_percentages(json_file, output_file):
    save_percentages_to_json(
        calculate_percentages(
            sort_mismatch_frequencies_by_position(
                process_mismatches(json_file)
            )
        ), 
        output_file
    )

    # Optionally, you can include the print statements here:
    sorted_mismatch_frequencies = sort_mismatch_frequencies_by_position(process_mismatches(json_file))
    # print(sorted_mismatch_frequencies)
    
    percentages = calculate_percentages(sorted_mismatch_frequencies)
    # print_percentages(percentages)
    
    print(f"Percentages saved to {output_file}")

####### WRITE CODE
for case in test_cases:
    output_name_base = case["output_name_base"]
    assays = case["assays"]
    alignments = case["alignments"]
    rosalind_threshold = case["rosalind_threshold"]
    
    # Print the variables to verify
    print(f"Output Name Base: {output_name_base}")
    print(f"Assays: {assays}")
    print(f"Alignments: {alignments}")
    print(f"Rosalind Threshold: {rosalind_threshold}")
    print("---")
    
    # Read in the Subtypes.csv file
    subtype_data = pd.read_csv(alignments)

    # File Naming
    output_dir_base = "Outputs/" + output_name_base
    output_aggregate = output_dir_base + "_aggregate.json"
    output_organized = output_dir_base + "_organized_data.json"
    output_organized_primer = output_dir_base + "_organized_data_primer.json"
    output_coverage = output_dir_base + "_coverage.json"
    output_csv = output_dir_base + "_coverage.csv"
    output_agg_csv = output_dir_base + "_coverage_agg.csv"
    output_case_info = output_dir_base + "_case_info.json"
    output_by_base = output_dir_base + "_by_base.json"

    write_run_info(output_case_info, output_name_base, assays, alignments, rosalind_threshold,
               testing, globalAnalysis, primer_order_check, primer_proximity_check,
               proximity_max_footprint, na_concentration, primer_concentration,
               mg_concentration, dntp_concentration, corrected_mg_concentration,
               expected_orders)

    # This is for allowing a probe to not count in the 3' end but I haven't added it (For now - add NNNNN to the end of a probe)
    PCR_primers = True

    print("Analyzing assay: " + assays)
    print("Analyzing subtypes: " + alignments)
    # RUNNING THE CODE BELOW FOR ANALYSIS
    all_mismatch_details = []
    start_time = time.time()

    for subtype in subtype_data["subtypes"]: # Iterate over subtypes. This allows for analysis of multiple subtypes
        print(f"Processing subtype {subtype}")
        # Clean the subtype name to remove file extensions and paths
        subtype_name = re.sub(r"\.fasta$", "", subtype)
        subtype_name = re.sub(r".*/", "", re.sub(r"\.fasta$", "", subtype))
        output_json = output_dir_base + "_" + subtype_name
        # Read the sequences from the alignment file
        seqfile = f"Alignments/{subtype}"
        seqstring = read_fasta(seqfile)
        #print(seqstring)
        
        numseqs = len(seqstring)
        targets = pd.read_csv(assays) # Read the target assays
        # Create a dictionary of primers
        myprimers = {row[0]: read_fasta(f"Assays/{row[0]}.fasta") for _, row in targets.iterrows()}
        #print(myprimers)
        mismatch_details = []

        total_sequences = len(seqstring) # Number of sequences in this Subtype
        with tqdm.tqdm(total=total_sequences, desc="Processing sequences", unit="iter") as pbar:
            for seq_record in seqstring: # Loop through each sequence in the subtype file
                pbar.update(1)  # Update the progress bar for sequence iteration
                for k, (assay_name, primers) in enumerate(myprimers.items()): # Loop through each assay and associated primers
                    for primer in primers: # Loop through each primers
                        seq = str(seq_record.seq)
                        primer_seq = str(primer.seq)
                        primer_length = len(primer_seq)
                        primer_id = primer.id  # Assuming primer has an id attribute


                        # Find best match and mismatches
                        best_position, mismatches, alignment_type, num_mismatches = find_best_match(seq, primer_seq)
                        mismatch_details.append((output_name_base, subtype, assay_name, seq_record.id, primer.id, primer_length, best_position, mismatches, alignment_type, num_mismatches, primer_seq))
                #print(mismatch_details)
                # Add mismatch details to the global list
                all_mismatch_details.extend(mismatch_details)
        
            # Organize data for analysis
            organized_data = organize_hierarchically(mismatch_details)
            # Assess perfect coverage and coverage with single mismatch
            organized_data = perfect_coverage(organized_data)
        
            # Assess coverage allowing mismatches except with n from end
            organized_data = coverage_with_one_mismatch_n_bases_from_end(organized_data, 3)
            organized_data = coverage_with_one_mismatch_n_bases_from_end(organized_data, 5)
            #print(json.dumps(organized_data, indent=4))

            # Assess coverage with one mismatch per assay
            organized_data = coverage_with_one_mismatch_per_assay(organized_data)
            #print(json.dumps(organized_data, indent=4))
            
            # Calculate ROSALIND assay score based on total number of mismatches per assay
            organized_data = calculate_assay_score(organized_data)

            # Add a check for primer proximity and order
            organized_data = check_primer_proximity_and_order(organized_data)
            #print(json.dumps(organized_data, indent=4))
            update_json_file(output_organized, organized_data)
        
            # Aggregate ROSALIND scores
            aggregated_data = aggregate_coverage(organized_data)
            #print(json.dumps(aggregated_data, indent=4))
            update_json_file(output_aggregate, aggregated_data)


    # Create the final nested dictionary with a ROSALIND threshold
    print(output_aggregate)
    final_dict = assay_coverage_analysis(output_aggregate, rosalind_threshold)

    # Print the final dictionary
    print(json.dumps(final_dict, indent=4))
    # Write the updated data back to the file
    with open(output_coverage, 'w') as file:
        json.dump(final_dict, file, indent=4)

    create_csv_from_final_dict(final_dict, output_csv)
    print("DONE")
    print("Total time: " + str(time.time() - start_time))
    
    testing = case["testing"]

    global_diversity_csv = case.get("global_diversity", "Global_Diversity.csv")

    # if globalAnalysis: 
    #     global_diversity_csv = case["global_diversity"]
    # else:
    #     global_diversity_csv = "Global_Diversity.csv"
    if not testing:
        calculate_weighted_aggregate(global_diversity_csv, output_csv, output_agg_csv)
    
    calculate_coverage_by_subtype_primer(output_organized, output_organized_primer, rosalind_threshold)

    if singleBaseAnalysis:
        process_and_save_mismatch_percentages(output_organized, output_by_base)

Output Name Base: Validation
Assays: Targets/240710_Test.csv
Alignments: Alignments/Subtypes_240710_Test.csv
Rosalind Threshold: 3
---
Analyzing assay: Targets/240710_Test.csv
Analyzing subtypes: Alignments/Subtypes_240710_Test.csv
Processing subtype 240710_TestA.fasta


Processing sequences: 100%|██████████| 5/5 [00:00<00:00, 861.15iter/s]


Processing subtype 240710_TestB.fasta


Processing sequences: 100%|██████████| 7/7 [00:00<00:00, 2401.45iter/s]

Outputs/Validation_aggregate.json
{
    "Validation": {
        "240710_TestA.fasta": {
            "ROSALIND_Score_count": 5,
            "Perfect_coverage_count": 1,
            "Coverage_with_one_mismatch_count": 4,
            "Coverage_3_count": 3,
            "Coverage_5_count": 2,
            "Coverage_with_one_mismatch_per_assay_count": 5,
            "Total_sequences": 5
        },
        "240710_TestB.fasta": {
            "ROSALIND_Score_count": 3,
            "Perfect_coverage_count": 2,
            "Coverage_with_one_mismatch_count": 3,
            "Coverage_3_count": 3,
            "Coverage_5_count": 3,
            "Coverage_with_one_mismatch_per_assay_count": 7,
            "Total_sequences": 7
        }
    }
}
CSV file Outputs/Validation_coverage.csv created successfully.
DONE
Total time: 0.013569116592407227
Coverage by subtype with ROSALIND scoring has been calculated and saved to Outputs/Validation_organized_data_primer.json
defaultdict(<function sort_mismatch_fre


