## Automated_hybrid_sim

In [10]:
# Import necessary libraries for numerical operations, plotting, random choices, and data manipulation.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import random
import matplotlib.patches as mpatches
import pandas as pd
import ipywidgets as widgets
from IPython.display import display, clear_output

# --- Global Constants for Allele Types and Colors ---
# Single-character labels for the two allele types tracked by the simulation.
# The labels themselves are arbitrary; they only need to be distinct.
MAGENTA = 'M'
YELLOW = 'Y'

# Colour used to draw each allele label when chromosomes are plotted.
# Note that the 'M' (Magenta) allele is rendered with matplotlib's 'purple'.
col_map = {
    MAGENTA: 'purple',
    YELLOW: 'gold',
}

# --- Class Definitions ---
# These classes model the fundamental biological components of the simulation:
# Chromosomes, diploid chromosome pairs, and individual organisms.

class Chromosome:
    """
    A single haploid chromosome: an ordered sequence of allele labels, one per locus.
    Two of these make up a DiploidChromosomePair.
    """

    def __init__(self, alleles):
        """
        Stores the allele sequence carried by this chromosome.

        Args:
            alleles (list): Allele labels in locus order (e.g., ['M', 'Y', 'M', ...]).
                            The list is kept by reference, not copied.
        """
        self.alleles = alleles

class DiploidChromosomePair:
    """
    A homologous pair of chromosomes belonging to one diploid individual.
    The two members are stored as `chromatid1` and `chromatid2`; each is
    expected to be a Chromosome contributed by one of the two parents.
    """

    def __init__(self, chromatid1, chromatid2):
        """
        Bundles two haploid chromosomes into a diploid pair.

        Args:
            chromatid1 (Chromosome): First member of the pair.
            chromatid2 (Chromosome): Second member of the pair.
        """
        # Both references are stored as given; nothing is copied.
        self.chromatid1, self.chromatid2 = chromatid1, chromatid2

class Individual:
    """
    Represents a single diploid organism in the simulation.
    An individual's genome is composed of one or more diploid chromosome pairs.
    It includes methods to calculate key genetic metrics like Hybrid Index and Heterozygosity.
    """
    def __init__(self, num_chromosomes, num_loci_per_chromosome):
        """
        Initializes an Individual object with an empty genome and a fresh unique ID.

        The ID is taken from the module-level `individual_id_counter`, which is
        incremented as a side effect of every construction.

        Args:
            num_chromosomes (int): The number of diploid chromosome pairs this individual has.
            num_loci_per_chromosome (int): The number of loci on each chromatid.
        """
        global individual_id_counter # Module-level counter shared by all individuals
        self.id = individual_id_counter
        individual_id_counter += 1

        self.num_chromosomes = num_chromosomes
        self.num_loci_per_chromosome = num_loci_per_chromosome
        # Filled in by the caller with DiploidChromosomePair objects.
        self.diploid_chromosome_pairs = []

    def calculate_hybrid_index(self):
        """
        Calculates the Hybrid Index (HI) for the individual.
        HI is the proportion of 'M' (Magenta) alleles across all loci in the genome.
        A pure 'Y' parent would have HI = 0.0, a pure 'M' parent HI = 1.0,
        and an F1 hybrid (MY at all loci) HI = 0.5.

        Returns:
            float: The calculated Hybrid Index, or 0.0 if the genome holds no alleles.
        """
        total_m_alleles = 0
        total_alleles = 0
        for pair in self.diploid_chromosome_pairs:
            for chromatid in (pair.chromatid1, pair.chromatid2):
                total_m_alleles += chromatid.alleles.count(MAGENTA)
                total_alleles += len(chromatid.alleles)

        return total_m_alleles / total_alleles if total_alleles > 0 else 0.0

    def calculate_heterozygosity(self):
        """
        Calculates the genome-wide Heterozygosity for the individual.
        Heterozygosity is the proportion of loci where the two alleles on a
        diploid chromosome pair are different (e.g., one 'M' and one 'Y').

        The loci are taken from the allele lists themselves (paired position by
        position), not from `num_loci_per_chromosome`, so the result covers the
        same data that `calculate_hybrid_index` counts and cannot index past the
        end of a chromatid.

        Returns:
            float: The calculated Heterozygosity, or 0.0 if there are no loci.
        """
        heterozygous_loci = 0
        total_loci = 0
        for pair in self.diploid_chromosome_pairs:
            # zip stops at the shorter chromatid, so mismatched lengths cannot raise.
            for allele1, allele2 in zip(pair.chromatid1.alleles, pair.chromatid2.alleles):
                total_loci += 1
                if allele1 != allele2:
                    heterozygous_loci += 1

        return heterozygous_loci / total_loci if total_loci > 0 else 0.0

# --- Helper Functions for Genetic Processes ---

# Helper function to record an individual's detailed genomic data into the global list.
def _record_individual_genome_for_detailed_dataframe(individual, generation_stage_label):
    """
    Records the detailed genotype information for an individual at each locus
    into the global 'genetic_data_records' list. This list will later be converted
    into a pandas DataFrame for comprehensive data analysis.

    Args:
        individual (Individual): The individual whose genome data is to be recorded.
        generation_stage_label (str): A string label indicating the generation stage of the individual
                                      (e.g., "P0_A", "F1", "F2", "BC1_A").
    """
    # Iterate through each diploid chromosome pair that the individual possesses.
    # 'chr_idx' serves as a 0-indexed counter for the chromosome pair.
    for chr_idx, diploid_pair in enumerate(individual.diploid_chromosome_pairs):
        # Iterate through each locus position along the length of the chromosomes.
        for locus_idx in range(individual.num_loci_per_chromosome):
            # Extract the alleles from both chromatids at the current locus.
            allele_a = diploid_pair.chromatid1.alleles[locus_idx]
            allele_b = diploid_pair.chromatid2.alleles[locus_idx]

            # Sort the alleles (e.g., ensures 'M|Y' instead of 'Y|M') for consistent genotype representation.
            sorted_alleles = sorted([allele_a, allele_b])
            # Form a genotype string representing the alleles at this locus.
            genotype_str = f"{sorted_alleles[0]}|{sorted_alleles[1]}"

            # Append a dictionary representing this locus's data to the global records list.
            # 'diploid_chr_id' is 1-indexed for better readability (e.g., Chr 1, Chr 2).
            genetic_data_records.append({
                'generation': generation_stage_label,
                'individual_id': individual.id, # This is the unique ID for the whole individual
                'diploid_chr_id': chr_idx + 1, # This is the ID for the specific diploid chromosome pair within the individual
                'locus_position': locus_idx,
                'genotype': genotype_str # The combined genotype at this specific locus (e.g., 'M|Y').
            })

def form_gamete_from_diploid_pair(diploid_chromosome_pair, chromosome_length, recomb_event_probabilities, recomb_probabilities):
    """
    Simulates the formation of a single haploid chromatid (gamete) from a diploid chromosome pair,
    including recombination (crossing over).

    The number of crossovers is drawn from `recomb_event_probabilities`. Their positions are then
    drawn without replacement from the intervals between adjacent loci, weighted by
    `recomb_probabilities`, so an interval with a higher value (a hotspot) is proportionally more
    likely to host a crossover. Intervals with a value of 0 never recombine.

    Args:
        diploid_chromosome_pair (DiploidChromosomePair): The diploid pair from which the gamete is formed.
        chromosome_length (int): The number of loci on the chromosome.
        recomb_event_probabilities (list): Weights for 0, 1, 2, ... recombination events.
                                           e.g., [0.5, 0.5, 0] means 50% chance of 0 crossovers, 50% chance of 1.
        recomb_probabilities (list): Relative recombination weight of each interval between adjacent
                                     loci (expected length `chromosome_length - 1`; extra entries are ignored).

    Returns:
        Chromosome: A new Chromosome object (with its own allele list) representing the haploid gamete.
    """
    homologs = (diploid_chromosome_pair.chromatid1, diploid_chromosome_pair.chromatid2)

    if chromosome_length <= 1:
        # With 0 or 1 locus there is no interval to cross over in: pass on one homolog whole.
        # A copy is returned so the offspring never shares an allele list with its parent.
        return Chromosome(list(random.choice(homologs).alleles))

    # Decide how many crossovers happen. All-zero (or empty) weights mean "never recombine";
    # random.choices would reject such weights, so that case is handled explicitly.
    if sum(recomb_event_probabilities) == 0:
        num_recombs = 0
    else:
        num_recombs = random.choices(
            population=range(len(recomb_event_probabilities)),
            weights=recomb_event_probabilities,
            k=1
        )[0]

    # Pick distinct crossover intervals, weighted by their recombination probability.
    crossover_points = []
    if num_recombs > 0:
        # Only the first chromosome_length - 1 entries correspond to real intervals.
        candidates = [
            i for i, p in enumerate(recomb_probabilities[:chromosome_length - 1]) if p > 0
        ]
        weights = [recomb_probabilities[i] for i in candidates]
        # Draw one interval at a time and remove it, i.e. weighted sampling without replacement.
        # The count is capped by the number of intervals that can recombine at all.
        for _ in range(min(num_recombs, len(candidates))):
            pick = random.choices(range(len(candidates)), weights=weights, k=1)[0]
            crossover_points.append(candidates.pop(pick))
            weights.pop(pick)
        crossover_points.sort()

    # Start on a randomly chosen homolog and switch strands after every crossover interval.
    current_strand = random.choice(homologs)
    gamete_alleles = []
    segment_start = 0
    for crossover_point in crossover_points:
        # A crossover in interval i falls between locus i and locus i + 1.
        gamete_alleles.extend(current_strand.alleles[segment_start:crossover_point + 1])
        current_strand = homologs[0] if current_strand is homologs[1] else homologs[1]
        segment_start = crossover_point + 1

    # Remainder after the last crossover (the whole strand when there were none).
    gamete_alleles.extend(current_strand.alleles[segment_start:])

    # Safeguard: only reachable when the homologs' allele lists disagree with chromosome_length.
    if len(gamete_alleles) != chromosome_length:
        print(f"Warning: Gamete length mismatch ({len(gamete_alleles)} vs {chromosome_length}). Falling back to simple inheritance.")
        return Chromosome(list(random.choice(homologs).alleles))

    return Chromosome(gamete_alleles)

def produce_haploid_set_of_chromosomes_for_gamete(parent_individual, target_num_chromosomes_in_gamete_set, recomb_event_probabilities, recomb_probabilities):
    """
    Generates a full haploid set of chromosomes (a gamete) from a parent individual.
    One recombined chromatid is formed from each diploid chromosome pair of the parent,
    and the set is then resized to the requested number of chromosomes.

    Args:
        parent_individual (Individual): The parent from which the gamete is produced.
        target_num_chromosomes_in_gamete_set (int): The number of chromosomes the resulting gamete should have.
                                                     This matches the number of diploid pairs in the offspring.
        recomb_event_probabilities (list): Probabilities for 0, 1, 2, ... recombination events.
        recomb_probabilities (list): A list of probabilities for recombination occurring between
                                     each adjacent pair of loci.

    Returns:
        list: A list of Chromosome objects, representing the haploid set (gamete). Its length equals
              the target whenever the parent has at least one chromosome pair; a parent with no
              pairs yields an empty list.
    """
    parental_pairs = parent_individual.diploid_chromosome_pairs

    def _one_meiosis(diploid_pair):
        # Form a single recombined chromatid from one of the parent's pairs.
        return form_gamete_from_diploid_pair(
            diploid_pair, parent_individual.num_loci_per_chromosome,
            recomb_event_probabilities, recomb_probabilities
        )

    haploid_set = [_one_meiosis(diploid_pair) for diploid_pair in parental_pairs]

    # If the parent has more chromosomes than the offspring is supposed to inherit,
    # randomly select the required number (in random order).
    if len(haploid_set) > target_num_chromosomes_in_gamete_set:
        return random.sample(haploid_set, target_num_chromosomes_in_gamete_set)

    # If the parent has fewer pairs than requested (e.g. a 1-pair F1 parent producing a
    # multi-chromosome F2), top the set up with further independent meioses of randomly chosen
    # parental pairs. Returning a short set would make the caller index past its end.
    # NOTE(review): padding by repeated meiosis is the assumed intent - confirm.
    if parental_pairs:
        while len(haploid_set) < target_num_chromosomes_in_gamete_set:
            haploid_set.append(_one_meiosis(random.choice(parental_pairs)))

    return haploid_set

def generate_offspring_individual(parent_a, parent_b, offspring_num_chromosomes, offspring_num_loci_per_chromosome, generation_label, recomb_event_probabilities, recomb_probabilities):
    """
    Builds one offspring by drawing a gamete from each parent and pairing the
    gametes' chromosomes position by position.

    The finished offspring is also recorded, locus by locus, through
    `_record_individual_genome_for_detailed_dataframe`.

    Args:
        parent_a (Individual): The first parent.
        parent_b (Individual): The second parent.
        offspring_num_chromosomes (int): The number of diploid chromosome pairs the offspring will have.
        offspring_num_loci_per_chromosome (int): The number of loci on each chromosome for the offspring.
        generation_label (str): The label for the offspring's generation (e.g., "F1", "F2").
        recomb_event_probabilities (list): Probabilities for 0, 1, 2, ... recombination events.
        recomb_probabilities (list): A list of probabilities for recombination occurring between
                                     each adjacent pair of loci.

    Returns:
        Individual: A new Individual object representing the offspring.
    """
    # Both gametes are requested with the same size and recombination settings.
    gamete_settings = (offspring_num_chromosomes, recomb_event_probabilities, recomb_probabilities)
    from_parent_a = produce_haploid_set_of_chromosomes_for_gamete(parent_a, *gamete_settings)
    from_parent_b = produce_haploid_set_of_chromosomes_for_gamete(parent_b, *gamete_settings)

    child = Individual(offspring_num_chromosomes, offspring_num_loci_per_chromosome)

    # Chromosome i of the child pairs chromosome i of each gamete.
    child.diploid_chromosome_pairs.extend(
        DiploidChromosomePair(from_parent_a[idx], from_parent_b[idx])
        for idx in range(offspring_num_chromosomes)
    )

    _record_individual_genome_for_detailed_dataframe(child, generation_label)
    return child

def run_genetic_cross(parent_population_a, parent_population_b, num_offspring_to_create,
                      generation_label, num_chromosomes_for_offspring,
                      recomb_event_probabilities, recomb_probabilities):
    """
    Crosses two parent populations to produce a new generation.

    For every offspring, one parent is drawn at random (with replacement) from
    each population and handed to `generate_offspring_individual`.

    Args:
        parent_population_a (list): A list of Individual objects representing the first parent population.
        parent_population_b (list): A list of Individual objects representing the second parent population.
        num_offspring_to_create (int): The desired number of offspring in the new generation.
        generation_label (str): The label for the new generation (e.g., "F1", "F2", "BC1").
        num_chromosomes_for_offspring (int): The number of diploid chromosome pairs offspring will have.
        recomb_event_probabilities (list): Probabilities for 0, 1, 2, ... recombination events.
        recomb_probabilities (list): A list of probabilities for recombination occurring between
                                     each adjacent pair of loci.

    Returns:
        list: A list of Individual objects representing the newly created offspring population.
    """
    # Chromosome length is read from the first parent of population A; every
    # individual is assumed to share it.
    loci_per_chromosome = parent_population_a[0].num_loci_per_chromosome

    return [
        generate_offspring_individual(
            random.choice(parent_population_a),
            random.choice(parent_population_b),
            num_chromosomes_for_offspring,
            loci_per_chromosome,
            generation_label,
            recomb_event_probabilities,
            recomb_probabilities,
        )
        for _ in range(num_offspring_to_create)
    ]

# --- Chromosome Visualisation Function ---

def plot_individual_chromosomes(individuals_to_plot, titles, loci_per_chromosome, max_chromosomes_to_plot=1):
    """
    Visualises the alleles along a specified number of chromosomes for one or more individuals.
    Each diploid chromosome pair is shown with its two constituent chromatids, drawn as
    vertical stacks of coloured blocks (one block per locus, coloured via `col_map`).

    Subplots are allocated per individual according to how many chromosome pairs that
    individual actually has (capped by `max_chromosomes_to_plot`), so individuals with
    different chromosome counts can be mixed without leaving blank axes or misplacing titles.

    Args:
        individuals_to_plot (list): A list of Individual objects to visualise.
        titles (list): A list of titles for each individual, matching the order of individuals_to_plot.
        loci_per_chromosome (int): The length of each chromosome (number of loci).
        max_chromosomes_to_plot (int, optional): The maximum number of diploid chromosome pairs to plot per individual.
                                                  Defaults to 1 for a concise view across generations.
                                                  If set to None, plots all pairs (use with caution for F2+!).
    """
    if not individuals_to_plot:
        print("No individuals to plot.")
        return

    # Number of chromosome pairs that will actually be drawn for each individual.
    pairs_to_draw = []
    for individual in individuals_to_plot:
        available = len(individual.diploid_chromosome_pairs)
        if max_chromosomes_to_plot is None:
            pairs_to_draw.append(available)
        else:
            pairs_to_draw.append(max(0, min(available, max_chromosomes_to_plot)))

    # Each diploid pair consists of 2 chromatids, hence multiplying by 2.
    total_subplots = 2 * sum(pairs_to_draw)
    if total_subplots == 0:
        # plt.subplots cannot build a grid with zero columns.
        print("No chromosomes to plot.")
        return

    # Width scales with the number of subplots and height with chromosome length,
    # with a minimum size for basic readability.
    fig_width = max(10, total_subplots * 1.0)
    fig_height = max(6, (loci_per_chromosome * 0.05) + 3)

    # `sharey=True` gives all subplots the same Y-axis limits. `squeeze=False` always returns a
    # 2-D array of axes, so the single row can be indexed uniformly whatever the column count.
    fig, axs = plt.subplots(1, total_subplots, figsize=(fig_width, fig_height), sharey=True, squeeze=False)
    axs = axs[0]

    # Drawing settings for individual allele blocks within the chromosome visualisation.
    block_width = 0.8    # Width of the rectangular block representing an allele.
    block_height = 0.9   # Height of the rectangular block.
    block_spacing = 1.1  # Vertical spacing between allele blocks.

    def _draw_chromatid(ax, alleles, label):
        # Draw one chromatid as a column of coloured blocks with a label underneath.
        for k, allele in enumerate(alleles):
            ax.add_patch(
                mpatches.Rectangle((0.1, k * block_spacing), block_width, block_height, color=col_map[allele])
            )
        ax.set_xlim(0, 1)
        ax.set_ylim(-block_spacing, loci_per_chromosome * block_spacing)
        ax.axis('off') # Hide axis ticks and labels for a cleaner chromosome visualisation.
        ax.text(0.5, -0.05, label, ha='center', va='top', fontsize=9, transform=ax.transAxes)

    next_ax_idx = 0 # Index of the next unused subplot.

    for ind_idx, (individual, num_pairs) in enumerate(zip(individuals_to_plot, pairs_to_draw)):
        if num_pairs == 0:
            # Nothing drawn for this individual, so there are no axes to centre a title over.
            continue

        first_ax_idx = next_ax_idx
        for pair_idx in range(num_pairs):
            diploid_pair = individual.diploid_chromosome_pairs[pair_idx]
            _draw_chromatid(axs[next_ax_idx], diploid_pair.chromatid1.alleles, f"Chr {pair_idx+1}a")
            _draw_chromatid(axs[next_ax_idx + 1], diploid_pair.chromatid2.alleles, f"Chr {pair_idx+1}b")
            next_ax_idx += 2
        last_ax_idx = next_ax_idx - 1

        # Centre the individual's title over the span of axes it actually occupies.
        left_edge = axs[first_ax_idx].get_position().x0
        right_edge = axs[last_ax_idx].get_position().x1
        mid_x = left_edge + (right_edge - left_edge) / 2

        # A small offset (-0.02) is applied to subtly shift the title left.
        fig.text(mid_x - 0.02, 0.95, titles[ind_idx], ha='center', va='bottom', fontsize=12, fontweight='bold')

    # Create a global legend for the allele colours (Magenta and Yellow).
    legend_patches = [
        mpatches.Patch(color=col_map[MAGENTA], label=f'{MAGENTA} (Magenta)'),
        mpatches.Patch(color=col_map[YELLOW], label=f'{YELLOW} (Yellow)')
    ]
    # Place the legend to the right of the plot area, vertically centred.
    fig.legend(handles=legend_patches, loc='center left', bbox_to_anchor=(0.99, 0.5), frameon=False, fontsize=10)

    # Adjust horizontal spacing between subplots, then display the plot.
    plt.subplots_adjust(wspace=0.3)
    plt.show()


# --- Global Data Storage and Counter ---
# Global list to store all detailed genetic data records (locus-level genotypes).
# Each dictionary in this list represents the genotype at a single locus for a given individual
# (appended by _record_individual_genome_for_detailed_dataframe).
# run_genetic_simulation rebinds this name to a fresh empty list at the start of every run.
genetic_data_records: list = []

# Global counter for assigning unique identification numbers to each individual created in the simulation.
# Individual.__init__ reads the current value as the new individual's ID and then increments it.
# Starts at 1 so that individual IDs are 1-indexed.
# run_genetic_simulation resets it to 1 at the start of every run.
individual_id_counter: int = 1


# --- Main Simulation Function ---
def run_genetic_simulation(
    chromosome_length_input,
    num_f1_individuals,
    num_target_gen_individuals, # This will be the size of the final generation (F2, F3, F_N, BC)
    num_chromosomes_f2_onwards,
    target_generation_number, # e.g., 2 for F2, 3 for F3, etc.
    include_backcrosses, # True/False
    use_uniform_recomb, # True/False
    output_all_generations # True/False
):
    """
    Runs the full genetic simulation based on the provided parameters.
    It generates P0, F1, and then proceeds to the target filial generation (F2, F3, F_N)
    and optional backcross generations.
    """
    global genetic_data_records
    global individual_id_counter

    # Reset globals for a clean run each time
    genetic_data_records = []
    individual_id_counter = 1

    print(f"--- Starting Genetic Simulation for F{target_generation_number} ---")
    print(f"Chromosome Length: {chromosome_length_input}")
    print(f"Number of F1 Individuals: {num_f1_individuals}")
    print(f"Number of Target Gen Individuals: {num_target_gen_individuals}")
    print(f"Chromosomes F2 Onwards: {num_chromosomes_f2_onwards}")
    print(f"Target Filial Generation: F{target_generation_number}")
    print(f"Include Backcrosses: {include_backcrosses}")
    print(f"Use Uniform Recombination: {use_uniform_recomb}")
    print(f"Output All Generations Data: {output_all_generations}\n")


    # --- Recombination Setup (adapted from your original code) ---
    # These global parameters will now be local to the function or derived from inputs
    current_chromosome_length = chromosome_length_input

    if use_uniform_recomb:
        recomb_probs_for_crosses = [0.5] * (current_chromosome_length - 1)
    else:
        # Example hotspot defined here. Adjust as needed for specific scenarios.
        recomb_probs_for_crosses = [0.01] * (current_chromosome_length - 1) # default low rate
        for i in range(current_chromosome_length - 1):
            if 45 <= i < 55: # Example hotspot region
                recomb_probs_for_crosses[i] = 0.1 # Higher recombination rate in hotspot

    recomb_event_probs_for_crosses = [0, 1, 0] # Example: always one crossover event if possible

    # --- Store populations and metrics for potential output ---
    populations_data = {}
    hybrid_indices_data = {}
    heterozygosities_data = {}

    # --- P0 Generation Creation and Calculations ---
    print("--- Creating P0 Generation ---")
    p0_a_individual = create_pure_parent(MAGENTA, 1, current_chromosome_length, "P0_A")
    p0_a_population = [p0_a_individual]
    p0_b_individual = create_pure_parent(YELLOW, 1, current_chromosome_length, "P0_B")
    p0_b_population = [p0_b_individual]

    p0_a_hi = p0_a_individual.calculate_hybrid_index()
    p0_a_het = p0_a_individual.calculate_heterozygosity()
    p0_b_hi = p0_b_individual.calculate_hybrid_index()
    p0_b_het = p0_b_individual.calculate_heterozygosity()

    populations_data['P0_A'] = p0_a_population
    populations_data['P0_B'] = p0_b_population
    hybrid_indices_data['P0_A'] = [p0_a_hi]
    hybrid_indices_data['P0_B'] = [p0_b_hi]
    heterozygosities_data['P0_A'] = [p0_a_het]
    heterozygosities_data['P0_B'] = [p0_b_het]


    # --- F1 Generation Creation and Calculations ---
    print("--- Creating F1 Generation ---")
    f1_population = run_genetic_cross(
        parent_population_a=p0_a_population,
        parent_population_b=p0_b_population,
        num_offspring_to_create=num_f1_individuals,
        generation_label="F1",
        num_chromosomes_for_offspring=1, # F1 always has 1 chromosome pair from P0
        recomb_event_probabilities=recomb_event_probs_for_crosses,
        recomb_probabilities=recomb_probs_for_crosses
    )
    f1_hybrid_indices = [ind.calculate_hybrid_index() for ind in f1_population]
    f1_heterozygosities = [ind.calculate_heterozygosity() for ind in f1_population]
    populations_data['F1'] = f1_population
    hybrid_indices_data['F1'] = f1_hybrid_indices
    heterozygosities_data['F1'] = f1_heterozygosities

    current_parent_population = f1_population

    # --- Loop for F2, F3, F_N Generations ---
    for gen_num in range(2, target_generation_number + 1):
        gen_label = f"F{gen_num}"
        print(f"--- Creating {gen_label} Generation ---")
        current_generation_population = run_genetic_cross(
            parent_population_a=current_parent_population,
            parent_population_b=current_parent_population,
            num_offspring_to_create=num_target_gen_individuals, # All subsequent F-gens use this size
            generation_label=gen_label,
            num_chromosomes_for_offspring=num_chromosomes_f2_onwards,
            recomb_event_probabilities=recomb_event_probs_for_crosses,
            recomb_probabilities=recomb_probs_for_crosses
        )
        current_gen_hi = [ind.calculate_hybrid_index() for ind in current_generation_population]
        current_gen_het = [ind.calculate_heterozygosity() for ind in current_generation_population]

        populations_data[gen_label] = current_generation_population
        hybrid_indices_data[gen_label] = current_gen_hi
        heterozygosities_data[gen_label] = current_gen_het

        current_parent_population = current_generation_population # Set for next iteration

    # --- Backcross Generations (Conditional) ---
    if include_backcrosses:
        # BC1_A: F1 x P0_A
        print("--- Creating BC1_A Generation ---")
        bc1_a_population = run_genetic_cross(
            parent_population_a=f1_population,
            parent_population_b=p0_a_population,
            num_offspring_to_create=num_target_gen_individuals, # Using target gen size for BCs
            generation_label="BC1_A",
            num_chromosomes_for_offspring=num_chromosomes_f2_onwards,
            recomb_event_probabilities=recomb_event_probs_for_crosses,
            recomb_probabilities=recomb_probs_for_crosses
        )
        bc1_a_hi = [ind.calculate_hybrid_index() for ind in bc1_a_population]
        bc1_a_het = [ind.calculate_heterozygosity() for ind in bc1_a_population]
        populations_data['BC1_A'] = bc1_a_population
        hybrid_indices_data['BC1_A'] = bc1_a_hi
        heterozygosities_data['BC1_A'] = bc1_a_het

        # BC1_B: F1 x P0_B
        print("--- Creating BC1_B Generation ---")
        bc1_b_population = run_genetic_cross(
            parent_population_a=f1_population,
            parent_population_b=p0_b_population,
            num_offspring_to_create=num_target_gen_individuals, # Using target gen size for BCs
            generation_label="BC1_B",
            num_chromosomes_for_offspring=num_chromosomes_f2_onwards,
            recomb_event_probabilities=recomb_event_probs_for_crosses,
            recomb_probabilities=recomb_probs_for_crosses
        )
        bc1_b_hi = [ind.calculate_hybrid_index() for ind in bc1_b_population]
        bc1_b_het = [ind.calculate_heterozygosity() for ind in bc1_b_population]
        populations_data['BC1_B'] = bc1_b_population
        hybrid_indices_data['BC1_B'] = bc1_b_hi
        heterozygosities_data['BC1_B'] = bc1_b_het


    # --- Create and Display DataFrame ---
    print("\n--- Compiling Genetic Data DataFrame ---")
    final_genetic_df = pd.DataFrame(genetic_data_records)

    print(f"\nSimulation complete for F{target_generation_number}.")

    # --- Return Results Based on 'output_all_generations' ---
    if output_all_generations:
        return {
            'genetic_df': final_genetic_df,
            'populations': populations_data,
            'hybrid_indices': hybrid_indices_data,
            'heterozygosities': heterozygosities_data
        }
    else:
        # Return only the target generation's data and relevant DataFrame slice
        target_gen_label = f"F{target_generation_number}"
        
        # Determine which generations to include in the "relevant_df" if output_all_generations is False
        generations_for_relevant_df = [target_gen_label]
        if include_backcrosses:
             generations_for_relevant_df.extend(['BC1_A', 'BC1_B'])
        
        relevant_df = final_genetic_df[final_genetic_df['generation'].isin(generations_for_relevant_df)]
        
        # Filter populations, hybrid_indices, and heterozygosities to include only what's needed for plotting
        filtered_populations = {}
        filtered_hybrid_indices = {}
        filtered_heterozygosities = {}

        # Always include P0 and F1 for plotting reference in the filtered output
        filtered_populations['P0_A'] = populations_data.get('P0_A')
        filtered_populations['P0_B'] = populations_data.get('P0_B')
        filtered_populations['F1'] = populations_data.get('F1')
        
        filtered_hybrid_indices['P0_A'] = hybrid_indices_data.get('P0_A')
        filtered_hybrid_indices['P0_B'] = hybrid_indices_data.get('P0_B')
        filtered_hybrid_indices['F1'] = hybrid_indices_data.get('F1')
        
        filtered_heterozygosities['P0_A'] = heterozygosities_data.get('P0_A')
        filtered_heterozygosities['P0_B'] = heterozygosities_data.get('P0_B')
        filtered_heterozygosities['F1'] = heterozygosities_data.get('F1')

        # Add target F-gen and BCs to the filtered data
        if target_gen_label in populations_data:
            filtered_populations[target_gen_label] = populations_data[target_gen_label]
            filtered_hybrid_indices[target_gen_label] = hybrid_indices_data[target_gen_label]
            filtered_heterozygosities[target_gen_label] = heterozygosities_data[target_gen_label]

        if include_backcrosses:
            if 'BC1_A' in populations_data:
                filtered_populations['BC1_A'] = populations_data['BC1_A']
                filtered_hybrid_indices['BC1_A'] = hybrid_indices_data['BC1_A']
                filtered_heterozygosities['BC1_A'] = heterozygosities_data['BC1_A']
            if 'BC1_B' in populations_data:
                filtered_populations['BC1_B'] = populations_data['BC1_B']
                filtered_hybrid_indices['BC1_B'] = hybrid_indices_data['BC1_B']
                filtered_heterozygosities['BC1_B'] = heterozygosities_data['BC1_B']


        return {
            'genetic_df': relevant_df,
            'populations': filtered_populations,
            'hybrid_indices': filtered_hybrid_indices,
            'heterozygosities': filtered_heterozygosities
        }

# --- Interactive Controls for the Simulation ---

# ipywidgets gives the description label a fixed-width column and truncates longer text.
# Letting the label size itself keeps captions such as 'Num Chromosomes:' fully readable.
_SLIDER_STYLE = {'description_width': 'initial'}

# Define widgets for each parameter.
# Every slider keeps (value - min) and (max - min) as whole multiples of `step`, so the
# default and the maximum stay selectable after the handle has been moved.
chromosome_length_slider = widgets.IntSlider(
    value=50,
    min=10,
    max=200,
    step=5,
    description='Locus Length:',
    continuous_update=False,  # only report the value once the handle is released
    orientation='horizontal',
    readout=True,
    readout_format='d',
    style=_SLIDER_STYLE
)

num_f1_individuals_slider = widgets.IntSlider(
    value=100, # A reasonable default for F1, as they are usually identical
    min=1,
    max=500,
    # step was 10, which (starting from min=1) only offered 1, 11, 21, ... and therefore
    # could never return to the default of 100; step=1 keeps all of those values reachable.
    step=1,
    description='Num F1 Indivs:',
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d',
    style=_SLIDER_STYLE
)

num_target_gen_individuals_slider = widgets.IntSlider(
    value=500, # Default for F2, F3, BCs etc.
    min=10,
    max=2000,
    # step was 50, which (starting from min=10) only offered 10, 60, 110, ... and skipped
    # both the default (500) and the maximum (2000); step=10 is a superset of that grid.
    step=10,
    description='Num Individuals:',
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d',
    style=_SLIDER_STYLE
)

num_chromosomes_f2_onwards_slider = widgets.IntSlider(
    value=10,
    min=1,
    max=20,
    step=1,
    description='Num Chromosomes:',
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d',
    style=_SLIDER_STYLE
)

target_generation_number_slider = widgets.IntSlider(
    value=2, # Default to F2
    min=1,
    max=10, # Limiting for now, can be increased if needed
    step=1,
    description='Target F-Gen:',
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d',
    style=_SLIDER_STYLE
)

# Boolean switches; indent=False drops the empty label column to the left of the checkbox.
include_backcrosses_checkbox = widgets.Checkbox(
    value=False,
    description='Include BC1s?',
    disabled=False,
    indent=False
)

use_uniform_recomb_checkbox = widgets.Checkbox(
    value=True,
    description='Uniform Recomb?',
    disabled=False,
    indent=False
)

output_all_generations_checkbox = widgets.Checkbox(
    value=True, # Default to showing all generation data
    description='Output All Generations Data?',
    disabled=False,
    indent=False
)

# --- Output Area for Plots and Text ---
# Everything the click handler prints or plots is captured into this widget.
output_area = widgets.Output()

# --- Function to run simulation and display results ---
def _draw_triangle_plot(populations, hybrid_indices, heterozygosities):
    """
    Draw the hybrid-index vs. heterozygosity "triangle" plot for one simulation run.

    P0_A, P0_B and F1 are drawn as single annotated reference markers; every other
    generation found in ``populations`` is drawn as a semi-transparent point cloud.

    Args:
        populations (dict): Generation label -> population. Only the keys are used here,
            to decide which generations receive a scatter layer.
        hybrid_indices (dict): Generation label -> list of per-individual hybrid indices.
        heterozygosities (dict): Generation label -> list of per-individual heterozygosities.

    Raises:
        KeyError, IndexError: If 'P0_A', 'P0_B' or 'F1' is missing or empty in
            ``hybrid_indices`` / ``heterozygosities``.
    """
    # (generation label, marker colour, legend text) for the three reference points.
    reference_specs = [
        ('P0_A', 'magenta', 'P0_A (MM)'),
        ('P0_B', 'yellow', 'P0_B (YY)'),
        ('F1', 'orange', 'F1'),
    ]
    # The first individual stands in for its whole generation
    # (assuming F1s are identical for this point).
    reference_points = [
        (label, colour, legend_text, hybrid_indices[label][0], heterozygosities[label][0])
        for label, colour, legend_text in reference_specs
    ]

    print("\n--- Plotting Triangle Plot ---")
    _, ax = plt.subplots(figsize=(8, 6))
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(1.2)
    ax.spines['bottom'].set_linewidth(1.2)
    ax.set_xlabel("Hybrid Index (proportion M alleles)", fontsize=12)
    ax.set_ylabel("Heterozygosity (proportion heterozygous loci)", fontsize=12)

    # Legend entries are built by hand so that marker size and opacity in the legend
    # can differ from the (tiny, translucent) markers used in the point clouds.
    plot_handles = []
    for label, colour, legend_text, hi_value, het_value in reference_points:
        ax.scatter(hi_value, het_value, color=colour, s=80, edgecolor='black', zorder=5, label=legend_text)
        ax.annotate(label, (hi_value, het_value), xytext=(5, 5), textcoords='offset points',
                    fontsize=10, fontweight='bold')
        plot_handles.append(
            Line2D([0], [0], marker='o', color='w', markerfacecolor=colour, markersize=10,
                   markeredgecolor='black', label=legend_text)
        )

    # Consistent colours per generation; labels not listed fall back to gray.
    plot_colors = {
        'F2': 'black', 'F3': 'green', 'F4': 'darkblue', 'F5': 'purple',
        'F6': 'brown', 'F7': 'teal', 'F8': 'olive', 'F9': 'pink', 'F10': 'cyan',
        'BC1_A': 'blue', 'BC1_B': 'red'
    }
    # Backcross clouds are drawn more opaque than F-generation clouds.
    backcross_labels = ('BC1_A', 'BC1_B')

    def sort_key(gen_label):
        # F-generations in numeric order first, then BC1_A, then everything else.
        if gen_label.startswith('F'):
            return int(gen_label[1:])
        return 100 if gen_label.startswith('BC1_A') else 200

    # P0 and F1 are already on the plot as reference points, so they get no cloud.
    generations_to_plot_scatter = sorted(
        (gen_label for gen_label in populations if gen_label not in ('P0_A', 'P0_B', 'F1')),
        key=sort_key
    )

    for gen_label in generations_to_plot_scatter:
        gen_hybrid_indices = hybrid_indices.get(gen_label)
        if not gen_hybrid_indices:  # generation absent or empty: nothing to draw
            continue
        # Resolve colour/alpha once so the cloud and its legend handle always agree.
        gen_color = plot_colors.get(gen_label, 'gray')
        gen_alpha = 0.7 if gen_label in backcross_labels else 0.2
        ax.scatter(
            gen_hybrid_indices,
            heterozygosities[gen_label],
            color=gen_color,
            alpha=gen_alpha,
            s=15,
            label=f'{gen_label} Population'
        )
        plot_handles.append(Line2D([0], [0], marker='o', color='w',
                                   markerfacecolor=gen_color,
                                   alpha=gen_alpha, markersize=8,
                                   label=f'{gen_label} Population'))

    # Outline of the triangle with corners (0, 0), (0.5, 1) and (1, 0).
    triangle_edges = [
        [(0.0, 0.0), (0.5, 1.0)],
        [(0.5, 1.0), (1.0, 0.0)],
        [(0.0, 0.0), (1.0, 0.0)]
    ]
    for (x0, y0), (x1, y1) in triangle_edges:
        ax.plot([x0, x1], [y0, y1], linestyle='-', color='gray', linewidth=1.5, alpha=0.7)

    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(-0.05, 1.05)
    ax.grid(False)
    ax.set_aspect('equal', adjustable='box')
    ax.legend(handles=plot_handles, loc="center left", bbox_to_anchor=(1.02, 0.5), frameon=False, fontsize=9)
    plt.tight_layout()
    plt.subplots_adjust(right=0.8)  # leave room on the right for the outside legend
    plt.show()


def _select_representative_individuals(populations, target_label, include_backcrosses):
    """
    Pick the first individual of each generation that should get a chromosome drawing.

    P0_A, P0_B and F1 are always included. The target generation is added only when it
    is not already one of those three, so a target of F1 is not drawn twice.

    Args:
        populations (dict): Generation label -> list of individuals.
        target_label (str): Label of the target F-generation, e.g. 'F2'.
        include_backcrosses (bool): Whether BC1_A / BC1_B should be included when present.

    Returns:
        tuple: ``(individuals, titles)`` as two parallel lists.
    """
    always_shown = ['P0_A', 'P0_B', 'F1']
    individuals = [populations[label][0] for label in always_shown]
    titles = list(always_shown)

    optional_labels = []
    if target_label not in always_shown:
        optional_labels.append(target_label)
    if include_backcrosses:
        optional_labels.extend(['BC1_A', 'BC1_B'])

    for label in optional_labels:
        members = populations.get(label)
        if members:  # skip generations that were not produced or are empty
            individuals.append(members[0])
            titles.append(label)

    return individuals, titles


# --- Function to run simulation and display results ---
def on_simulate_button_clicked(b):
    """
    Run the simulation with the current widget values and render the results.

    Output (progress text, the head of the genetic DataFrame, the triangle plot and the
    representative chromosome drawings) is written into ``output_area``, replacing whatever
    the previous run left there.

    Args:
        b: The clicked button, as supplied by ``Button.on_click``. Unused.
    """
    # Read the widgets that are needed again after the simulation exactly once, so the
    # plots are guaranteed to describe the run that produced them.
    chromosome_length = chromosome_length_slider.value
    target_generation_number = target_generation_number_slider.value
    include_backcrosses = include_backcrosses_checkbox.value

    with output_area:
        clear_output(wait=True) # Clear previous output before new run
        print("Running simulation...")

        # Call the main simulation function with current widget values
        simulation_results = run_genetic_simulation(
            chromosome_length_input=chromosome_length,
            num_f1_individuals=num_f1_individuals_slider.value,
            num_target_gen_individuals=num_target_gen_individuals_slider.value,
            num_chromosomes_f2_onwards=num_chromosomes_f2_onwards_slider.value,
            target_generation_number=target_generation_number,
            include_backcrosses=include_backcrosses,
            use_uniform_recomb=use_uniform_recomb_checkbox.value,
            output_all_generations=output_all_generations_checkbox.value
        )

        # --- Display Results ---
        genetic_df = simulation_results['genetic_df']
        populations = simulation_results['populations']
        hybrid_indices = simulation_results['hybrid_indices']
        heterozygosities = simulation_results['heterozygosities']

        print("\n--- Genetic Data DataFrame (First 5 rows of selected data) ---")
        display(genetic_df.head())
        print(f"\nTotal records in DataFrame: {len(genetic_df)}")

        # --- Triangle Plot ---
        _draw_triangle_plot(populations, hybrid_indices, heterozygosities)

        # --- Chromosome Visualisation ---
        print("\n--- Visualising Representative Chromosomes ---")
        individuals_to_plot_viz, titles_viz = _select_representative_individuals(
            populations, f"F{target_generation_number}", include_backcrosses
        )
        plot_individual_chromosomes(individuals_to_plot_viz, titles_viz, chromosome_length, max_chromosomes_to_plot=1)

# --- "Run Simulation" button: clicking it triggers a full simulate-and-plot cycle ---
simulate_button = widgets.Button(description="Run Simulation")
simulate_button.on_click(on_simulate_button_clicked)

# --- Control panel layout ---
# One vertical stack: numeric parameters first, then the on/off options, then the button.
_parameter_sliders = (
    chromosome_length_slider,
    num_f1_individuals_slider,
    num_target_gen_individuals_slider,
    num_chromosomes_f2_onwards_slider,
    target_generation_number_slider,
)
_option_checkboxes = (
    include_backcrosses_checkbox,
    use_uniform_recomb_checkbox,
    output_all_generations_checkbox,
)
controls = widgets.VBox([*_parameter_sliders, *_option_checkboxes, simulate_button])

# Show the control panel, followed by the area that receives each run's output
display(controls, output_area)

VBox(children=(IntSlider(value=50, continuous_update=False, description='Locus Length:', max=200, min=10, step…

Output()