In [1]:
# Cell 1: Essential Imports, Global Constants, Initialisation of Data Stores, and Matplotlib Backend Setup

# I'm placing all my required imports here at the very beginning of the script.
# This ensures all modules are available before I start defining classes and functions.
import numpy as np             # My general-purpose numerical computing library, essential for array operations and calculations.
import random                  # A useful module for generating random numbers, crucial for my simulations of genetic processes.
import itertools               # For creating efficient iterators, such as the colour cycling for plotting different generations.
import matplotlib.pyplot as plt # This is the core plotting library I'm using to create all my visualisations of hybrid index and heterozygosity.
from matplotlib.lines import Line2D # Specifically importing Line2D from matplotlib.lines, as I need it to create custom legend entries for my plots.
from typing import List, Tuple, Dict, Any # I use type hints like List and Tuple to make my code more readable and robust.
import os                      # Essential for interacting with the operating system, particularly for managing file paths.
import matplotlib              # The main Matplotlib module itself, needed here to explicitly set the backend for interactive plots.
import re                      # The 're' module handles regular expressions, which I'll use for parsing generation names.
import mplcursors              # This is a fantastic third-party library that adds interactive data cursors to my Matplotlib plots.
import csv
import time

# --- IMPORTANT: I'm setting the Matplotlib backend for interactivity here ---
# This block is crucial for ensuring my plots are interactive within environments like Jupyter.
# I'm attempting to use 'QtAgg' first as it's often preferred, then falling back to 'TkAgg'.
try:
    matplotlib.use('QtAgg')
    print("I'm using Matplotlib backend: QtAgg")
except ImportError:
    # QtAgg could not be loaded, so I try TkAgg next. This second attempt needs its own guard:
    # an exception raised inside an `except` clause is NOT caught by the sibling clauses of the
    # same `try`, so without it a missing Tk installation would crash this cell.
    try:
        matplotlib.use('TkAgg')
        print("I'm falling back to Matplotlib backend: TkAgg")
    except Exception as e:
        # Neither preferred backend is usable; I leave Matplotlib on whatever it already selected.
        print(f"I couldn't set an interactive backend, an error occurred: {e}. Falling back to Matplotlib's default.")
except Exception as e:
    # Any other failure while selecting QtAgg: report it and carry on with Matplotlib's default.
    print(f"I couldn't set an interactive backend, an error occurred: {e}. Falling back to Matplotlib's default.")


# I'm defining constants for the alleles used in my simulation. This makes the code
# more readable and easier to modify if I ever change the allele representation.
MAGENTA = 'M'  # allele label
YELLOW = 'Y'   # allele label

# Mapping allele labels to numeric values.
# NOTE(review): this maps 'M' -> 1 and 'Y' -> 0, whereas create_pure_individual() maps
# 'M' -> 0 and 'Y' -> 1. One of the two conventions should be chosen and used everywhere;
# I have left the values untouched here because other cells may depend on them.
allele_to_num = {MAGENTA: 1, YELLOW: 0}

def genotype_to_numeric(genotype: List[str]) -> List[int]:
    """
    Convert a list of allele letters to their numeric codes.

    Args:
        genotype (List[str]): Allele labels, e.g. ['M', 'Y', 'M'].

    Returns:
        List[int]: The value from `allele_to_num` for each label, in order. Any label that is
                   not a key of `allele_to_num` is encoded as -1 rather than raising an error.
    """
    # I use typing.List (not the built-in list[...]) to match the rest of this notebook and to
    # keep the definition importable on Python versions older than 3.9.
    return [allele_to_num.get(allele, -1) for allele in genotype]

# I'm setting up global lists to store data generated throughout the simulation.
# This allows me to collect information from various parts of the process.
# One dictionary per locus per individual; record_individual_genome() appends to this list.
genetic_data_records: List[Dict[str, Any]] = []
# One dictionary per chromatid per individual; record_chromatid_recombination() appends to this list.
chromatid_recombination_records: List[Dict[str, Any]] = []

# Next unused individual ID. Individual.__init__ reads this value and then increments it,
# so every individual created in this kernel session gets a distinct identifier.
individual_id_counter: int = 1

I'm using Matplotlib backend: QtAgg


In [1]:
# TODO: Check the Individual class and its use of numeric genotypes. I'm not sure the 0/1 allele encoding is applied correctly, and I think it may be affecting the heterozygosity calculations.

In [2]:
# Cell 2: Genetic Data Structures - Chromosomes and Individuals

# This cell defines the fundamental data structures I'm using to represent genetic material
# and individuals within my simulation.

class Chromosome:
    """
    A single chromosome strand: an ordered sequence of integer alleles, one per locus.
    """

    def __init__(self, alleles: List[int]):
        """
        Args:
            alleles (List[int]): The allele at each locus, in order (e.g. [0, 0, 1, ...]).
                                 The list is stored as given, not copied, so the caller's
                                 list and this chromosome share the same object.
        """
        self.alleles = alleles

    def __repr__(self) -> str:
        """
        Short debugging representation: the first ten alleles run together, followed by '...'.
        The trailing '...' is always present, even when the chromosome has ten alleles or fewer.
        """
        if not self.alleles:
            return "Chr(...)"
        preview = ''.join(str(allele) for allele in self.alleles[:10])
        return "Chr({}...)".format(preview)


class DiploidChromosomePair:
    """
    A pair of homologous chromosomes, as carried by a diploid organism.
    """

    def __init__(self, chromatid1: Chromosome, chromatid2: Chromosome):
        """
        Args:
            chromatid1 (Chromosome): The first homolog of the pair.
            chromatid2 (Chromosome): The second homolog of the pair.
        """
        self.chromatid1 = chromatid1
        self.chromatid2 = chromatid2

    def __repr__(self) -> str:
        """
        Multi-line debugging representation with one homolog per indented line.
        """
        layout = "Pair(\n  {}\n  {}\n)"
        return layout.format(self.chromatid1, self.chromatid2)


class Individual:
    """
    A diploid individual: a uniquely numbered collection of homologous chromosome pairs, with
    helpers for hybrid index, heterozygosity and ancestry-block summaries.

    Alleles are expected to be the integers 0 or 1 at every locus.
    """

    def __init__(self, num_chromosomes: int, num_loci_per_chromosome: int):
        """
        Create an individual with an empty genome; the caller appends DiploidChromosomePair
        objects to `diploid_chromosome_pairs` afterwards.

        Args:
            num_chromosomes (int): Number of diploid chromosome pairs this individual should carry.
            num_loci_per_chromosome (int): Number of loci on each chromosome.
        """
        # The ID comes from the module-level counter, which is advanced on every construction.
        global individual_id_counter
        self.id = individual_id_counter
        individual_id_counter += 1

        self.num_chromosomes = num_chromosomes
        self.num_loci_per_chromosome = num_loci_per_chromosome
        self.diploid_chromosome_pairs: List['DiploidChromosomePair'] = [] # Type hint as string for forward reference

    def get_all_numeric_genotypes(self) -> List[int]:
        """
        Return, for every locus of every chromosome pair, the number of '1' alleles carried:
        0 for (0,0), 1 for a heterozygote, 2 for (1,1).

        Returns:
            List[int]: One dosage value per locus, chromosome by chromosome.

        Raises:
            ValueError: If any allele is not 0 or 1. Previously a homozygous locus with an
                        unexpected value was silently skipped (shrinking the denominator of the
                        statistics) and any unequal pair was scored as heterozygous, so bad data
                        would quietly distort hybrid index and heterozygosity.
        """
        dosages: List[int] = []
        for chr_idx, pair in enumerate(self.diploid_chromosome_pairs):
            strand_a = pair.chromatid1.alleles
            strand_b = pair.chromatid2.alleles
            for locus in range(self.num_loci_per_chromosome):
                allele_a = strand_a[locus]
                allele_b = strand_b[locus]
                if allele_a not in (0, 1) or allele_b not in (0, 1):
                    raise ValueError(
                        f"Individual {self.id}: unexpected allele pair ({allele_a!r}, {allele_b!r}) "
                        f"at chromosome {chr_idx + 1}, locus {locus}; alleles must be 0 or 1."
                    )
                # With alleles restricted to 0/1, their sum is exactly the count of '1' alleles.
                dosages.append(int(allele_a) + int(allele_b))
        return dosages

    def calculate_hybrid_index(self) -> float:
        """
        Proportion of all alleles in the genome that are '1' alleles.

        Returns:
            float: A value between 0.0 and 1.0, or 0.0 if the individual has no loci.
        """
        dosages = self.get_all_numeric_genotypes()
        total_loci = len(dosages)
        if total_loci == 0:
            return 0.0
        # Each locus carries two alleles, so the maximum possible count of '1' alleles is 2 * loci.
        return sum(dosages) / (2 * total_loci)

    def calculate_heterozygosity(self) -> float:
        """
        Proportion of loci that are heterozygous (carry exactly one '1' allele).

        Returns:
            float: A value between 0.0 and 1.0, or 0.0 if the individual has no loci.
        """
        dosages = self.get_all_numeric_genotypes()
        total_loci = len(dosages)
        if total_loci == 0:
            return 0.0
        return dosages.count(1) / total_loci

    def get_chromatid_block_data(self) -> List[Dict[str, Any]]:
        """
        Summarise the ancestry blocks on every chromatid of this individual.

        Returns:
            List[Dict[str, Any]]: One dictionary per chromatid with the keys 'individual_id',
            'diploid_chr_id' (1-based), 'chromatid_in_pair' ('A' for chromatid1, 'B' for
            chromatid2), 'total_junctions', 'block_lengths' and 'block_alleles'.
        """
        all_chromatid_data = []
        for chr_idx, diploid_pair in enumerate(self.diploid_chromosome_pairs):
            labelled_chromatids = (('A', diploid_pair.chromatid1), ('B', diploid_pair.chromatid2))
            for label, chromatid in labelled_chromatids:
                junctions, lengths, alleles = self._analyse_single_chromatid(chromatid.alleles)
                all_chromatid_data.append({
                    'individual_id': self.id,
                    'diploid_chr_id': chr_idx + 1,
                    'chromatid_in_pair': label,
                    'total_junctions': junctions,
                    'block_lengths': lengths,
                    'block_alleles': alleles
                })
        return all_chromatid_data

    def _analyse_single_chromatid(self, alleles: List[int]) -> Tuple[int, List[int], List[int]]:
        """
        Split one chromatid into runs of identical consecutive alleles.

        Args:
            alleles (List[int]): The allele sequence of a single chromatid.

        Returns:
            Tuple[int, List[int], List[int]]: (number of junctions between runs, length of each
            run, allele value of each run). An empty sequence gives (0, [], []).
        """
        if not alleles:
            return 0, [], []

        block_lengths = []
        block_alleles = []
        # itertools.groupby yields one group per run of equal consecutive values.
        for allele, run in itertools.groupby(alleles):
            block_lengths.append(sum(1 for _ in run))
            block_alleles.append(allele)

        # A junction is a boundary between two adjacent runs, so there is one fewer than runs.
        return len(block_lengths) - 1, block_lengths, block_alleles

In [3]:
# Cell 3: Meiosis Function

import random # Make sure random is imported at the top of this cell or earlier

def meiosis_with_recombination(
    diploid_pair: 'DiploidChromosomePair', # Assuming DiploidChromosomePair is defined
    recomb_event_probabilities: dict,
    recomb_probabilities: list # Crossover weights for the intervals BETWEEN loci (see docstring)
) -> 'Chromosome': # Assuming Chromosome is defined
    """
    Simulates meiosis with a variable number of recombination events for one chromosome pair.

    Args:
        diploid_pair (DiploidChromosomePair): The pair of homologous chromatids.
        recomb_event_probabilities (dict): Relative weights for 0, 1, or 2 recombination events,
            e.g., {0: 0.1, 1: 0.85, 2: 0.05}. Missing keys are treated as weight 0.
        recomb_probabilities (list): Relative crossover weights along the chromosome. Entry i
            weights a breakpoint between locus i and locus i+1. Only the first
            (number of loci - 1) entries are used, so a list with one entry per locus is also
            accepted and its final entry is ignored.

    Returns:
        Chromosome: A new chromosome (with its own allele list) produced by meiosis.

    Raises:
        ValueError: If crossovers must be placed by weight but `recomb_probabilities` has fewer
                    than (number of loci - 1) entries.
    """
    loci_len = len(diploid_pair.chromatid1.alleles)

    # Decide how many recombination events happen (0, 1, or 2)
    n_events = random.choices(
        population=[0, 1, 2],
        weights=[recomb_event_probabilities.get(i, 0) for i in [0, 1, 2]],
        k=1
    )[0]

    if n_events == 0:
        # With no recombination, randomly pass on either chromatid1 or chromatid2.
        # I copy the allele list so the gamete never shares a list object with its parent;
        # otherwise a later in-place edit of one would silently change the other.
        if random.random() < 0.5:
            return Chromosome(list(diploid_pair.chromatid1.alleles))
        return Chromosome(list(diploid_pair.chromatid2.alleles))

    # Possible breakpoints lie between loci: position p means "switch strands before locus p",
    # so for 'loci_len' loci the candidates are 1 .. loci_len-1.
    possible_positions = list(range(1, loci_len))

    # weights[i] belongs to possible_positions[i]. max(..., 0) stops the slice becoming [:-1]
    # when the chromosome has no loci at all.
    weights = recomb_probabilities[:max(loci_len - 1, 0)]

    if sum(weights) == 0:
        # If all weights are zero (or there are none), choose breakpoints uniformly at random.
        if len(possible_positions) < n_events:
            chosen_positions = possible_positions[:] # Not enough positions: use every one of them.
        else:
            chosen_positions = sorted(random.sample(possible_positions, n_events))
    else:
        if len(weights) != len(possible_positions):
            raise ValueError(
                f"recomb_probabilities needs at least {len(possible_positions)} entries "
                f"(one per interval between loci), but only {len(weights)} are available."
            )
        # Weighted sampling without replacement, done by redrawing until a new position appears.
        # Only positions with a positive weight can ever be drawn, so I cap the target at that
        # number; capping at len(possible_positions) alone made this loop spin forever whenever
        # fewer positions had positive weight than events requested.
        selectable_positions = sum(1 for weight in weights if weight > 0)
        num_events_to_pick = min(n_events, selectable_positions)
        chosen_positions = []
        while len(chosen_positions) < num_events_to_pick:
            pos = random.choices(possible_positions, weights=weights, k=1)[0]
            if pos not in chosen_positions:
                chosen_positions.append(pos)
        chosen_positions.sort()

    # Start with a random chromatid to begin the segment copying
    current_chromatid_source = random.choice([diploid_pair.chromatid1.alleles, diploid_pair.chromatid2.alleles])
    if current_chromatid_source is diploid_pair.chromatid2.alleles:
        other_chromatid_source = diploid_pair.chromatid1.alleles
    else:
        other_chromatid_source = diploid_pair.chromatid2.alleles

    recombinant_alleles = []
    last_pos = 0
    # The end of the chromosome acts as a final breakpoint so the last segment is copied too.
    for pos in chosen_positions + [loci_len]:
        recombinant_alleles.extend(current_chromatid_source[last_pos:pos])
        # Switch source for the next segment
        current_chromatid_source, other_chromatid_source = other_chromatid_source, current_chromatid_source
        last_pos = pos

    return Chromosome(recombinant_alleles)


In [4]:
# Cell 4: Data Recording Functions
# This cell contains helper functions I use to record the detailed genetic and recombination
# data of individuals into my global lists, as well as utilities for extracting key metrics.

def record_individual_genome(individual: Individual, generation_label: str):
    """
    Log the genotype at every locus of every chromosome pair of one individual by appending
    one dictionary per locus to the global `genetic_data_records` list.

    Each appended dictionary holds:
      - 'generation': The generation label passed in (e.g., 'F2', 'BC1A').
      - 'individual_id': The individual's unique identifier.
      - 'diploid_chr_id': The chromosome pair number (1-based).
      - 'locus_position': The locus index along the chromosome (0-based).
      - 'genotype': The two alleles joined by '|', chromatid1 first (e.g., '0|1' for integer alleles).

    Args:
        individual (Individual): The individual whose genome I want to record.
        generation_label (str): The label to store with every record.
    """
    locus_indices = range(individual.num_loci_per_chromosome)

    # Counting from 1 gives the 1-based chromosome ID directly.
    for chr_number, homologs in enumerate(individual.diploid_chromosome_pairs, start=1):
        strand_one = homologs.chromatid1.alleles
        strand_two = homologs.chromatid2.alleles
        for locus in locus_indices:
            genetic_data_records.append(dict(
                generation=generation_label,
                individual_id=individual.id,
                diploid_chr_id=chr_number,
                locus_position=locus,
                genotype=f"{strand_one[locus]}|{strand_two[locus]}",
            ))


def record_chromatid_recombination(individual: Individual, generation_label: str):
    """
    Log the ancestry-block summary of each chromatid of one individual.

    I ask the individual for its per-chromatid block data, stamp every record with the
    generation label under the key 'generation', and add the records to the global
    `chromatid_recombination_records` list.

    Args:
        individual (Individual): The individual whose recombination data I want to record.
        generation_label (str): The label to store with every record.
    """
    block_records = individual.get_chromatid_block_data()
    for block_record in block_records:
        block_record['generation'] = generation_label
    # All records are labelled now, so I can add them in one go, in their original order.
    chromatid_recombination_records.extend(block_records)

In [5]:
# Cell 5: Population Creation Functions and Statistics Utility

# This cell provides helper functions for setting up initial populations
# and for calculating summary statistics of any given population.

# Assuming Chromosome and Individual classes, and DiploidChromosomePair are defined.
import random # Make sure random is imported, if not already at the top of your script/notebook

def create_pure_individual(num_chromosomes: int, num_loci_per_chr: int, allele_type: str) -> 'Individual':
    """
    I use this function to create a single individual that is 'pure' for a specific allele type.
    This means all its loci across all its chromosomes will be homozygous for the given allele.
    This is typically used for my initial parental populations (P_A or P_B).

    Args:
        num_chromosomes (int): The total number of diploid chromosome pairs for this individual.
        num_loci_per_chr (int): The number of genetic loci on each chromosome.
        allele_type (str): The allele that will fill all loci: '0' or 'M' (any case) gives the
                           integer allele 0, '1' or 'Y' (any case) gives the integer allele 1.
                           The integers 0 and 1 are accepted as well.

    Returns:
        Individual: A newly created 'Individual' object, homozygous for the specified allele at every locus.

    Raises:
        ValueError: If `allele_type` is not one of the supported values.
    """
    # Alleles are stored as integers throughout the simulation, so I normalise through str():
    # this lets integer 0/1 be passed directly instead of failing on the missing .upper() method.
    # NOTE(review): 'M' -> 0 and 'Y' -> 1 here is the opposite of the module-level `allele_to_num`
    # mapping ('M' -> 1, 'Y' -> 0); the two conventions should be reconciled.
    allele_key = str(allele_type).upper()
    if allele_key in ('0', 'M'):
        initial_allele_int = 0
    elif allele_key in ('1', 'Y'):
        initial_allele_int = 1
    else:
        raise ValueError(f"Unsupported allele_type: {allele_type}. Expected '0', '1', 'M', or 'Y'.")

    # The Individual constructor assigns the ID itself and starts with an empty genome,
    # so I build the chromosome pairs manually below.
    individual = Individual(
        num_chromosomes=num_chromosomes,
        num_loci_per_chromosome=num_loci_per_chr
    )

    for _ in range(num_chromosomes):
        # Each chromatid gets its own freshly built list, so the two homologs never share storage.
        chromatid1 = Chromosome([initial_allele_int] * num_loci_per_chr)
        chromatid2 = Chromosome([initial_allele_int] * num_loci_per_chr)

        # Both chromatids hold the same allele everywhere, which makes every locus homozygous.
        individual.diploid_chromosome_pairs.append(DiploidChromosomePair(chromatid1, chromatid2))

    return individual

def create_pure_populations(
    num_individuals: int,
    num_chromosomes: int,
    num_loci_per_chr: int,
    allele_type: str
) -> List[Individual]:
    """
    Build a whole population of 'pure' individuals, each homozygous at every locus for the
    same allele type.

    Args:
        num_individuals (int): How many individuals the population should contain.
        num_chromosomes (int): The number of chromosome pairs per individual.
        num_loci_per_chr (int): The number of loci on each chromosome.
        allele_type (str): The allele label passed on to `create_pure_individual` for every member.

    Returns:
        List[Individual]: The newly created individuals, in creation order.
    """
    population = []
    for _ in range(num_individuals):
        # Every member is built independently, so each one receives its own ID and genome.
        member = create_pure_individual(num_chromosomes, num_loci_per_chr, allele_type)
        population.append(member)
    return population


def create_F1_population(
    pure_pop_A: List[Individual],
    pure_pop_B: List[Individual],
    recomb_event_probabilities: dict,
    recomb_probabilities: List[float]
) -> List[Individual]:
    """
    Generate the first filial (F1) hybrid population by crossing the two pure parental
    populations pairwise: the i-th individual of `pure_pop_A` is mated with the i-th individual
    of `pure_pop_B`, giving one offspring per pair. Each offspring chromosome pair consists of
    one gamete chromatid from the A parent (stored as chromatid1) and one from the B parent
    (stored as chromatid2).

    Every offspring is recorded under the label 'F1' via `record_individual_genome` and
    `record_chromatid_recombination` as soon as it is created.

    Args:
        pure_pop_A (List[Individual]): The first pure parental population.
        pure_pop_B (List[Individual]): The second pure parental population.
        recomb_event_probabilities (dict): Distribution of the number of recombination events per
                                           chromosome, passed on to `meiosis_with_recombination`.
        recomb_probabilities (List[float]): Position-dependent recombination weights, passed on
                                            to `meiosis_with_recombination`.

    Raises:
        ValueError: If the two parental populations differ in size, since index-based pairing
                    would then be ambiguous.

    Returns:
        List[Individual]: The F1 individuals, in the order of their parents.
    """
    if len(pure_pop_A) != len(pure_pop_B):
        raise ValueError("Error: Pure populations must be the same size to create F1 population via paired crosses.")

    offspring = []
    for parent_A, parent_B in zip(pure_pop_A, pure_pop_B):
        # The child copies its chromosome and locus counts from the A parent.
        hybrid = Individual(parent_A.num_chromosomes, parent_A.num_loci_per_chromosome)

        hybrid_pairs = []
        for chr_idx in range(parent_A.num_chromosomes):
            # I look up both parental pairs first, then draw the A gamete before the B gamete.
            pair_from_A = parent_A.diploid_chromosome_pairs[chr_idx]
            pair_from_B = parent_B.diploid_chromosome_pairs[chr_idx]
            gamete_A = meiosis_with_recombination(pair_from_A, recomb_event_probabilities, recomb_probabilities)
            gamete_B = meiosis_with_recombination(pair_from_B, recomb_event_probabilities, recomb_probabilities)
            hybrid_pairs.append(DiploidChromosomePair(gamete_A, gamete_B))
        hybrid.diploid_chromosome_pairs = hybrid_pairs

        # The genome is complete now, so I log it before moving on to the next mating pair.
        record_individual_genome(hybrid, 'F1')
        record_chromatid_recombination(hybrid, 'F1')
        offspring.append(hybrid)

    return offspring


def population_stats(pop: List[Individual]) -> dict:
    """
    Summarise a population by the mean and standard deviation of its individuals' hybrid
    index (HI) and heterozygosity (HET).

    Args:
        pop (List[Individual]): The individuals to summarise.

    Returns:
        dict: The keys 'mean_HI', 'std_HI', 'mean_HET', 'std_HET' and 'count' (population size).
              For an empty population all four statistics are 0.
    """
    def _mean_and_std(values):
        # NumPy would warn and return nan for an empty list, so I short-circuit to zeros.
        if not values:
            return 0, 0
        return np.mean(values), np.std(values)

    hybrid_indices = [member.calculate_hybrid_index() for member in pop]
    heterozygosities = [member.calculate_heterozygosity() for member in pop]

    mean_hi, std_hi = _mean_and_std(hybrid_indices)
    mean_het, std_het = _mean_and_std(heterozygosities)

    return {
        'mean_HI': mean_hi,
        'std_HI': std_hi,
        'mean_HET': mean_het,
        'std_HET': std_het,
        'count': len(pop)
    }

In [6]:
# Cell 6: Breeding Plan Functions

# This cell contains functions to systematically build my breeding plans,
# which define how different generations will be crossed.

def build_forward_generations(base_name: str, start_gen: int, end_gen: int) -> List[Tuple[str, str, str]]:
    """
    Build the breeding plan for a run of sequential forward generations (e.g., F1, F2, F3...).

    The generation numbered `start_gen` is always the cross of the two pure parental
    populations 'P_A' and 'P_B'. Every later generation, up to and including `end_gen`, is bred
    by crossing the previous generation with itself.

    Args:
        base_name (str): The prefix for the generation names (e.g., "F").
        start_gen (int): The first generation number (e.g., 1 for F1).
        end_gen (int): The last generation number to include (e.g., 5 for F5).

    Returns:
        List[Tuple[str, str, str]]: One tuple per planned cross, in order:
                                    (new_generation_label, parent1_label, parent2_label).
                                    The list is empty when `end_gen` is smaller than `start_gen`.
    """
    labels = [f"{base_name}{number}" for number in range(start_gen, end_gen + 1)]
    if not labels:
        return []

    # The first generation comes from the pure parents...
    plan = [(labels[0], 'P_A', 'P_B')]
    # ...and each following generation is a self-cross of the one just before it.
    for parent_label, child_label in zip(labels, labels[1:]):
        plan.append((child_label, parent_label, parent_label))
    return plan


def build_backcross_generations(
    base_name: str,
    initial_hybrid_gen_label: str, # This will be 'F1' (or whatever starts the BC series)
    pure_pop_label: str,
    num_backcross_generations: int # How many BC generations do you want (e.g., 5 for BC1, BC2, BC3, BC4, BC5)
) -> List[Tuple[str, str, str]]:
    """
    Build the breeding plan for a chain of backcrosses to one pure parental population.

    The first backcross mates `initial_hybrid_gen_label` with `pure_pop_label`; each later
    backcross mates the previous backcross generation with `pure_pop_label` again. Generation
    labels are `base_name`, the generation number, and the last character of `pure_pop_label`
    joined together (e.g., "BC" + 2 + "A" -> "BC2A").

    Args:
        base_name (str): The prefix for backcross generation names (e.g., "BC").
        initial_hybrid_gen_label (str): The label of the hybrid generation that starts the
                                        series (e.g., "F1").
        pure_pop_label (str): The label of the recurrent pure parental population
                              (e.g., "P_A" or "P_B").
        num_backcross_generations (int): How many backcross generations to plan.

    Returns:
        List[Tuple[str, str, str]]: One tuple per cross: (new label, hybrid parent, recurrent parent).
                                     Example: [('BC1A', 'F1', 'P_A'), ('BC2A', 'BC1A', 'P_A'), ...]
    """
    generation_numbers = range(1, num_backcross_generations + 1)
    labels = [f"{base_name}{number}{pure_pop_label[-1]}" for number in generation_numbers]

    # The hybrid parent of each cross is the generation made one step earlier; for the very
    # first cross that is the initial hybrid generation.
    hybrid_parents = [initial_hybrid_gen_label] + labels[:-1]

    return [
        (child_label, hybrid_label, pure_pop_label)
        for child_label, hybrid_label in zip(labels, hybrid_parents)
    ]

In [7]:
# Cell 7: Simulating Genetic Crosses

# Assuming Individual, DiploidChromosomePair, meiosis_with_recombination,
# record_individual_genome, and record_chromatid_recombination
# are defined in your previous cells or imported.


def run_genetic_cross(
    parents_pop_A: List['Individual'],
    parents_pop_B: List['Individual'],
    offspring_count_per_mating_pair: int,
    generation_label: str,
    num_chromosomes_for_offspring: int,
    recomb_event_probabilities: Dict[int, float],
    recomb_probabilities: List[float]
) -> List['Individual']:
    """
    Mate two parental groups pairwise and return every offspring produced.

    Both parent lists are put into an independent random order and then paired
    position by position, so the number of mating pairs equals the size of the
    smaller list and surplus parents in the larger list are not used. Every
    pair produces 'offspring_count_per_mating_pair' children.

    NOTE(review): because the two orderings are drawn independently, passing
    the same population as both parent groups can pair an individual with
    itself. Confirm whether that is intended.

    Args:
        parents_pop_A (List[Individual]): First group of candidate parents.
        parents_pop_B (List[Individual]): Second group of candidate parents.
        offspring_count_per_mating_pair (int): Number of children created for each mating pair.
        generation_label (str): Label of the generation being created (e.g., "F2", "BC1A");
                                used for child IDs, recording and the debug output.
        num_chromosomes_for_offspring (int): Number of diploid chromosome pairs built for each child.
        recomb_event_probabilities (dict): Forwarded unchanged to meiosis_with_recombination.
        recomb_probabilities (List[float]): Forwarded unchanged to meiosis_with_recombination.

    Returns:
        List[Individual]: All children created by this cross, in creation order.
    """
    print(f"\n--- DEBUG_CROSS for {generation_label} ---")
    print(f"DEBUG_CROSS: Parent A size entering cross: {len(parents_pop_A)}")
    print(f"DEBUG_CROSS: Parent B size entering cross: {len(parents_pop_B)}")
    print(f"DEBUG_CROSS: Offspring *per mating pair* expected: {offspring_count_per_mating_pair}")

    def make_gamete(chromosome_pair):
        # One recombinant haploid chromatid from a parent's diploid pair.
        return meiosis_with_recombination(chromosome_pair, recomb_event_probabilities, recomb_probabilities)

    # Random ordering of each group (A first, then B), then positional pairing.
    order_A = random.sample(parents_pop_A, len(parents_pop_A))
    order_B = random.sample(parents_pop_B, len(parents_pop_B))
    mating_pairs = list(zip(order_A, order_B))
    print(f"DEBUG_CROSS: Number of unique mating pairs formed: {len(mating_pairs)}")

    children = []
    for mate_A, mate_B in mating_pairs:
        for _litter_slot in range(offspring_count_per_mating_pair):
            # The child takes its loci count from the A-side parent.
            newborn = Individual(
                num_chromosomes=num_chromosomes_for_offspring,
                num_loci_per_chromosome=mate_A.num_loci_per_chromosome
            )
            # The ID is attached after construction; numbering is 1-based within this cross.
            newborn.id = f"{generation_label}_Ind{len(children) + 1}"

            for chromosome_number in range(num_chromosomes_for_offspring):
                source_pair_A = mate_A.diploid_chromosome_pairs[chromosome_number]
                source_pair_B = mate_B.diploid_chromosome_pairs[chromosome_number]

                # Gamete from A is always drawn before the gamete from B.
                gamete_A = make_gamete(source_pair_A)
                gamete_B = make_gamete(source_pair_B)
                newborn.diploid_chromosome_pairs.append(DiploidChromosomePair(gamete_A, gamete_B))

            # Log the finished child through the recording functions, then keep it.
            record_individual_genome(newborn, generation_label)
            record_chromatid_recombination(newborn, generation_label)
            children.append(newborn)

    print(f"DEBUG_CROSS: Final new_generation size created: {len(children)}")
    print(f"--- END DEBUG_CROSS for {generation_label} ---\n")

    return children

In [8]:
# Cell 8: simulate_generations function

# calculate_hi_het_for_population is defined in this cell, directly below.
# population_stats, run_genetic_cross, record_individual_genome and
# record_chromatid_recombination must already be defined in earlier cells or imported
# (for example: from your_module import population_stats, Individual).


def calculate_hi_het_for_population(population: List['Individual']) -> List[Dict[str, float]]:
    """
    Summarise every individual of a population as an ID / HI / HET record.

    Args:
        population (List[Individual]): Individuals exposing calculate_hybrid_index()
                                       and calculate_heterozygosity().

    Returns:
        List[dict]: One dict per individual, in input order, with the keys
                    'id' ('NoID' when the individual has no 'id' attribute),
                    'HI' and 'HET'.
    """
    def summarise(member) -> dict:
        # Hybrid index is computed before heterozygosity; the ID is read last.
        hybrid_index = member.calculate_hybrid_index()
        heterozygosity = member.calculate_heterozygosity()
        return {'id': getattr(member, 'id', 'NoID'), 'HI': hybrid_index, 'HET': heterozygosity}

    return [summarise(member) for member in population]

def simulate_generations(
    initial_pop_A: list = None,
    initial_pop_B: list = None,
    generation_plan: list = None,
    num_offspring_per_cross: int = 2, # This param remains the same name here
    num_chromosomes: int = 2,
    recomb_event_probabilities: dict = None,
    recomb_probabilities: list = None,
    existing_populations: dict = None,
    verbose: bool = False,
):
    """
    Run every cross listed in a generation plan and collect HI/HET data per generation.

    Args:
        initial_pop_A (list, optional): Founder population registered under the label 'P_A'
                                        (only if that label is not already present).
        initial_pop_B (list, optional): Founder population registered under the label 'P_B'
                                        (only if that label is not already present).
        generation_plan (list, optional): Sequence of plan entries. An entry of the form
                                          (label, parent_1, parent_2) creates generation 'label'
                                          by crossing the two named populations. An entry holding
                                          only a label is skipped. Parents must already exist when
                                          their entry is reached, so the order of the plan matters.
                                          Any names after the second parent are checked for
                                          existence but otherwise ignored.
        num_offspring_per_cross (int): Offspring produced per mating pair in every cross.
        num_chromosomes (int): Number of diploid chromosome pairs given to each offspring.
        recomb_event_probabilities (dict): Forwarded unchanged to run_genetic_cross.
        recomb_probabilities (list): Forwarded unchanged to run_genetic_cross.
        existing_populations (dict, optional): Populations from an earlier run, keyed by label.
                                               This dict is extended in place and returned.
        verbose (bool): If True, print summary statistics after each generation.

    Returns:
        tuple: (populations, all_generations_data).
               populations maps each label to its list of individuals.
               all_generations_data maps each label registered or created during THIS call
               to the list of {'id', 'HI', 'HET'} records of its individuals.

    Raises:
        ValueError: If a plan entry does not name two parents, or names a parent
                    population that does not exist yet.
    """
    # Reuse the caller's dict when supplied so repeated calls can build on each other.
    populations = existing_populations if existing_populations is not None else {}

    # HI and HET records for every generation handled in this call.
    all_generations_data = {}

    # Register the founder populations (P_A first, then P_B) unless already known.
    for founder_label, founders in (('P_A', initial_pop_A), ('P_B', initial_pop_B)):
        if founders is None or founder_label in populations:
            continue
        populations[founder_label] = founders
        for ind in founders:
            record_individual_genome(ind, founder_label)
            record_chromatid_recombination(ind, founder_label)
        all_generations_data[founder_label] = calculate_hi_het_for_population(founders)

    if generation_plan is None:
        print("Warning: No generation plan provided. Returning existing populations.")
        return populations, all_generations_data

    for gen_info in generation_plan:
        if len(gen_info) == 1:
            continue  # Label only, no cross information: nothing to simulate.

        # Fail early with a readable message instead of an IndexError further down.
        if len(gen_info) < 3:
            raise ValueError(
                f"Generation plan entry {gen_info!r} must name two parent populations: "
                f"(label, parent_1, parent_2)."
            )

        gen_name = gen_info[0]
        parents_names = gen_info[1:]

        for p_name in parents_names:
            if p_name not in populations:
                raise ValueError(f"Parent population '{p_name}' not found for generation '{gen_name}'.")

        new_pop = run_genetic_cross(
            populations[parents_names[0]],
            populations[parents_names[1]],
            offspring_count_per_mating_pair=num_offspring_per_cross,
            generation_label=gen_name,
            num_chromosomes_for_offspring=num_chromosomes,
            recomb_event_probabilities=recomb_event_probabilities,
            recomb_probabilities=recomb_probabilities
        )

        # Make the new generation available as a parent for later plan entries.
        populations[gen_name] = new_pop
        all_generations_data[gen_name] = calculate_hi_het_for_population(new_pop)

        print(f"DEBUG: Generated population {gen_name} with {len(new_pop)} individuals.")

        if verbose:
            stats = population_stats(new_pop)
            print(f"{gen_name} created from parents {parents_names[0]} and {parents_names[1]} | "
                  f"Count: {len(new_pop)} | Mean HI: {stats['mean_HI']:.3f} (±{stats['std_HI']:.3f}), "
                  f"Mean HET: {stats['mean_HET']:.3f} (±{stats['std_HET']:.3f})")
            print(f"Added '{gen_name}' to populations. Current population keys: {list(populations.keys())}")

    return populations, all_generations_data

### Cell 9: Main Simulation Execution

In [None]:
# Cell 9: Main Simulation Execution
#
# Builds the two pure founder populations, assembles the breeding plan
# (forward F generations plus backcrosses to each parent) and runs the simulation.
# Relies on create_pure_populations, build_forward_generations,
# build_backcross_generations and simulate_generations from earlier cells.

# 1. Define Simulation Parameters
# Fixed seed so that re-running this cell repeats the same random draws.
# Both RNG modules imported in Cell 1 are seeded because the helper functions
# may draw from either one. Change the value to obtain a different replicate.
SIMULATION_SEED = 42
random.seed(SIMULATION_SEED)
np.random.seed(SIMULATION_SEED)

num_individuals_per_pure_pop = 50 # Size of each founder population
num_offspring_per_cross = 2       # Offspring produced per mating pair

num_chromosomes = 10
num_loci_per_chr = 20

# Recombination parameters
# Distribution over the number of crossover events: all probability mass is on exactly 1 event.
recomb_event_probabilities = {0: 0, 1: 1, 2: 0}
# One uniform, low recombination probability per locus position.
recomb_probabilities = [0.01] * num_loci_per_chr

# 2. Create Initial Pure Populations (P_A and P_B)
print("Creating initial pure populations (P_A and P_B)...")

# P_A carries only '0' alleles. The allele type is passed as a string;
# NOTE(review): confirm create_pure_populations expects '0'/'1' strings rather than integers.
pop_A = create_pure_populations(
    num_individuals_per_pure_pop,
    num_chromosomes,
    num_loci_per_chr,
    allele_type='0'
)
print(f"P_A created with {len(pop_A)} individuals.")

# P_B carries only '1' alleles.
pop_B = create_pure_populations(
    num_individuals_per_pure_pop,
    num_chromosomes,
    num_loci_per_chr,
    allele_type='1'
)
print(f"P_B created with {len(pop_B)} individuals.")

# Keep both founder populations together under their labels; they are handed to
# simulate_generations as two separate arguments further down.
initial_populations = {'P_A': pop_A, 'P_B': pop_B}

# 3. Define Breeding Plans
print("\nDefining breeding plans for forward and backcross generations...")

# Forward generations F1 .. F10
forward_plan = build_forward_generations(
    base_name='F',
    start_gen=1, # First forward generation (F1)
    end_gen=10   # Last forward generation (F10)
)

# Backcross generations (BC1A to BC5A, and BC1B to BC5B), each series starting from F1
num_sequential_backcrosses = 5

backcross_plan_A = build_backcross_generations(
    base_name='BC',
    initial_hybrid_gen_label='F1',
    pure_pop_label='P_A',
    num_backcross_generations=num_sequential_backcrosses
)

backcross_plan_B = build_backcross_generations(
    base_name='BC',
    initial_hybrid_gen_label='F1',
    pure_pop_label='P_B',
    num_backcross_generations=num_sequential_backcrosses
)

# Combine all plans into one breeding plan. The forward plan comes first so that
# F1 already exists when the backcross entries that use it as a parent are reached.
full_breeding_plan = forward_plan + backcross_plan_A + backcross_plan_B
print(f"Total generations in breeding plan: {len(full_breeding_plan)}")

# 4. Simulate Generations
print("\nStarting genetic simulation across generations...")
populations, all_generations_data = simulate_generations(
    initial_pop_A=initial_populations['P_A'],
    initial_pop_B=initial_populations['P_B'],
    generation_plan=full_breeding_plan,
    num_offspring_per_cross=num_offspring_per_cross,
    num_chromosomes=num_chromosomes,
    recomb_event_probabilities=recomb_event_probabilities,
    recomb_probabilities=recomb_probabilities,
    verbose=True, # Print summary statistics for every generation
)

print("\nSimulation complete!")
print(f"Final number of populations tracked: {len(populations)}")
# Individual population sizes can be inspected with, e.g., print(len(populations['F1']))

Creating initial pure populations (P_A and P_B)...
P_A created with 50 individuals.
P_B created with 50 individuals.

Defining breeding plans for forward and backcross generations...
Total generations in breeding plan: 20

Starting genetic simulation across generations...

--- DEBUG_CROSS for F1 ---
DEBUG_CROSS: Parent A size entering cross: 50
DEBUG_CROSS: Parent B size entering cross: 50
DEBUG_CROSS: Offspring *per mating pair* expected: 2
DEBUG_CROSS: Number of unique mating pairs formed: 50
DEBUG_CROSS: Final new_generation size created: 100
--- END DEBUG_CROSS for F1 ---

DEBUG: Generated population F1 with 100 individuals.
F1 created from parents P_A and P_B | Count: 100 | Mean HI: 0.500 (±0.000), Mean HET: 1.000 (±0.000)
Added 'F1' to populations. Current population keys: ['P_A', 'P_B', 'F1']

--- DEBUG_CROSS for F2 ---
DEBUG_CROSS: Parent A size entering cross: 100
DEBUG_CROSS: Parent B size entering cross: 100
DEBUG_CROSS: Offspring *per mating pair* expected: 2
DEBUG_CROSS: N

In [None]:
# TODO: Check sim generation 4 - are parents A and B passed in every time, or is it correct for the F gens?

### All gens visualisation

In [22]:
def plot_hi_het_triangle_all_generations(all_generations_data: Dict[str, List[Dict[str, Any]]],
                                          save_filename: str = None,
                                          highlight_gen: str = None):
    """
    Plots Hybrid Index (HI) versus Heterozygosity (HET) for all simulated generations
    inside the HI/HET triangle, with a hover annotation for individual points.

    Without 'highlight_gen' each generation gets its own colour (fixed colours for
    P_A, P_B and F1, a palette indexed by generation number for the other F and BC
    generations). With 'highlight_gen' that generation is drawn in a distinct colour
    and every other generation in light grey.

    Individuals for which either HI or HET is missing or None are left out of the plot,
    so the two coordinates of every plotted point always belong to the same individual.

    Args:
        all_generations_data (dict): A dictionary where keys are generation names (e.g., 'P_A', 'F2', 'BC1A')
                                     and values are lists of dictionaries, each containing 'HI' and 'HET'
                                     values for individual organisms in that generation.
        save_filename (str, optional): If given, the figure is saved to this path before it is shown.
        highlight_gen (str, optional): The name of the generation to highlight (e.g., 'F4', 'BC1A').
                                       If None, default colouring applies. A name that is not a key of
                                       all_generations_data triggers a warning and leaves everything grey.

    Returns:
        None. The figure is displayed with plt.show().
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(1.2)
    ax.spines['bottom'].set_linewidth(1.2)

    ax.set_xlabel("Hybrid Index (proportion M alleles)", fontsize=12)
    ax.set_ylabel("Heterozygosity (proportion heterozygous loci)", fontsize=12)

    # Styles used when a generation is highlighted
    HIGHLIGHT_COLOR = 'dodgerblue'
    HIGHLIGHT_ALPHA = 1.0
    HIGHLIGHT_MARKER_SIZE = 35

    GRAYSCALE_COLOR = 'silver'
    GRAYSCALE_ALPHA = 0.2
    GRAYSCALE_MARKER_SIZE = 20

    # Default colours for the anchor populations (when nothing is highlighted)
    P_A_COLOR_DEFAULT = 'black'
    P_B_COLOR_DEFAULT = 'grey'
    F1_COLOR_DEFAULT = 'purple'

    # Palettes for the remaining generations, indexed by generation number modulo palette length
    F_PALETTE = ['blue', 'red', 'green', 'purple', 'orange', 'brown', 'pink', 'teal']
    BC_PALETTE = ['darkgreen', 'darkred', 'darkblue', 'darkgoldenrod', 'darkslategray', 'cornflowerblue']

    generation_styles = {}
    legend_elements = []

    scatter_artists = []
    scatter_data_map = {}  # scatter artist -> list of per-point dicts shown on hover

    def add_gen_style_and_legend(name: str, color: str, alpha: float, marker: str, s: int):
        # Remember the scatter style of a generation and create its legend entry.
        style = {'color': color, 'alpha': alpha, 'marker': marker, 's': s}
        generation_styles[name] = style
        legend_elements.append(Line2D([0], [0], marker=style['marker'], color='w',
                                     markerfacecolor=style['color'], markersize=8,
                                     alpha=style['alpha'], label=name))

    def palette_index(pattern: str, label: str, palette_size: int) -> int:
        # Generation number taken from the label, wrapped to the palette; 0 if the label has no number.
        number_match = re.match(pattern, label)
        return int(number_match.group(1)) % palette_size if number_match else 0

    def temp_sort_key_for_assignment(label: str):
        # Order: P_A, P_B, F1, numbered F generations, numbered BC generations, everything else.
        if label == 'P_A': return (0, label)
        if label == 'P_B': return (1, label)
        if label == 'F1': return (2, label)
        match_f = re.match(r'F(\d+)', label)
        if match_f:
            return (3, int(match_f.group(1)))
        elif label.startswith('F'):
            return (3, float('inf'), label)
        match_bc = re.match(r'BC(\d+)([A-Z]?)', label)
        if match_bc:
            num_part = int(match_bc.group(1))
            suffix_part = match_bc.group(2) if match_bc.group(2) else ''
            return (4, num_part, suffix_part)
        elif label.startswith('BC'):
            return (4, float('inf'), label)
        return (5, label)

    sorted_present_gen_names = sorted(list(all_generations_data.keys()), key=temp_sort_key_for_assignment)

    if highlight_gen is not None and highlight_gen not in all_generations_data:
        print(f"Warning: highlight_gen '{highlight_gen}' was not found in all_generations_data; "
              f"all generations will be drawn in grayscale.")

    # --- Decide the style (and legend entry) of every generation ---
    for gen_name in sorted_present_gen_names:
        if highlight_gen is not None:
            # Highlight mode: one generation in colour, all others in grey.
            if gen_name == highlight_gen:
                add_gen_style_and_legend(gen_name, HIGHLIGHT_COLOR, HIGHLIGHT_ALPHA, 'o', HIGHLIGHT_MARKER_SIZE)
            else:
                add_gen_style_and_legend(gen_name, GRAYSCALE_COLOR, GRAYSCALE_ALPHA, 'o', GRAYSCALE_MARKER_SIZE)
            continue

        # Default colouring
        if gen_name == 'P_A':
            add_gen_style_and_legend('P_A', P_A_COLOR_DEFAULT, 1.0, marker='o', s=50)
        elif gen_name == 'P_B':
            add_gen_style_and_legend('P_B', P_B_COLOR_DEFAULT, 1.0, marker='o', s=50)
        elif gen_name == 'F1':
            add_gen_style_and_legend('F1', F1_COLOR_DEFAULT, 1.0, marker='o', s=50)
        elif gen_name.startswith('F'):
            idx = palette_index(r'F(\d+)', gen_name, len(F_PALETTE))
            add_gen_style_and_legend(gen_name, F_PALETTE[idx], 1.0, 'o', 20)
        elif gen_name.startswith('BC'):
            idx = palette_index(r'BC(\d+)', gen_name, len(BC_PALETTE))
            add_gen_style_and_legend(gen_name, BC_PALETTE[idx], 0.7, 'o', 20)
        else:
            # Any label that follows none of the naming schemes above
            add_gen_style_and_legend(gen_name, 'darkgrey', 0.5, 'o', 20)

    # --- Plot all generations (including P_A, P_B, F1) ---
    for gen_name in sorted_present_gen_names:
        values = all_generations_data.get(gen_name, [])
        if not values: # Skip if no data for this generation
            continue

        style = generation_styles[gen_name]

        # Keep a record only if BOTH coordinates are present. Filtering HI and HET
        # separately would shift the two lists against each other as soon as one
        # record lacks a value.
        valid_points = [
            (d['HI'], d['HET']) for d in values
            if d.get('HI') is not None and d.get('HET') is not None
        ]
        if not valid_points:
            print(f"Skipping plotting for {gen_name} as no valid HI/Het data was found.")
            continue

        hi_values = [hi for hi, _ in valid_points]
        het_values = [het for _, het in valid_points]
        point_data_for_current_gen = [
            {'gen_name': gen_name, 'hi': hi, 'het': het} for hi, het in valid_points
        ]

        # Anchor populations and the highlighted generation are drawn on top.
        zorder_val = 2
        if gen_name in ['P_A', 'P_B', 'F1'] or (highlight_gen is not None and gen_name == highlight_gen):
            zorder_val = 5

        sc = ax.scatter(hi_values, het_values,
                        color=style['color'],
                        alpha=style['alpha'],
                        marker=style['marker'],
                        s=style['s'],
                        zorder=zorder_val)
        scatter_artists.append(sc)
        scatter_data_map[sc] = point_data_for_current_gen

    # Draw triangle edges
    triangle_edges = [
        [(0.0, 0.0), (0.5, 1.0)],
        [(0.5, 1.0), (1.0, 0.0)],
        [(0.0, 0.0), (1.0, 0.0)]
    ]
    for (x0, y0), (x1, y1) in triangle_edges:
        ax.plot([x0, x1], [y0, y1], linestyle='-', color='gray', linewidth=1.5, alpha=0.7)

    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(-0.05, 1.05)
    ax.set_aspect('equal', adjustable='box')

    # Leave room on the right for the legend, which is placed outside the axes.
    plt.subplots_adjust(left=0.1, right=0.75, top=0.9, bottom=0.1)

    ax.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.35, 1), fontsize=10, frameon=False)

    # Hover annotation setup
    annot = ax.annotate("", xy=(0, 0), xytext=(15, 15), textcoords="offset points",
                             bbox=dict(boxstyle="round", fc="w"),
                             arrowprops=dict(arrowstyle="->"))
    annot.set_visible(False)

    def update_annot(scatter, ind):
        # Move the annotation to the first hit point and list every point under the cursor.
        pos = scatter.get_offsets()[ind["ind"][0]]
        annot.xy = pos
        texts = []
        for idx in ind["ind"]:
            if idx < len(scatter_data_map[scatter]):
                d = scatter_data_map[scatter][idx]
                texts.append(f"{d['gen_name']}:\nHI = {d['hi']:.3f}\nHET = {d['het']:.3f}")
        annot.set_text("\n\n".join(texts))
        facecolor = scatter.get_facecolor()
        if len(facecolor) > 0:
            annot.get_bbox_patch().set_facecolor(facecolor[0])
        else:
            annot.get_bbox_patch().set_facecolor('lightgray')
        annot.get_bbox_patch().set_alpha(0.8)

    def hover(event):
        # Show the annotation for the first scatter under the cursor, hide it otherwise.
        visible = annot.get_visible()
        if event.inaxes == ax:
            for scatter in scatter_artists:
                cont, ind = scatter.contains(event)
                if cont:
                    update_annot(scatter, ind)
                    annot.set_visible(True)
                    fig.canvas.draw_idle()
                    return
        if visible:
            annot.set_visible(False)
            fig.canvas.draw_idle()

    fig.canvas.mpl_connect("motion_notify_event", hover)

    # Save before showing the figure.
    if save_filename:
        fig.savefig(save_filename, bbox_inches='tight')
        print(f"Plot saved to {save_filename}")

    plt.show()

In [23]:
# Draw the HI/HET triangle for every generation in all_generations_data using the
# default per-generation colours (no highlight_gen given) and save the figure as a PNG.
# NOTE(review): save_filename is an absolute path inside one Windows user profile;
# saving will presumably fail on a machine without that directory - consider a
# configurable output directory.
plot_hi_het_triangle_all_generations(
    all_generations_data=all_generations_data,
    save_filename='C:\\Users\\sophi\\Jupyter_projects\\Hybrid_Code\\output_data\\triangle_plot_images\\all_gens.png'
)

Plot saved to C:\Users\sophi\Jupyter_projects\Hybrid_Code\output_data\triangle_plot_images\all_gens.png


### Highlight Gen - triangle plot

In [26]:
# Example usage (assumes all_generations_data has been produced by the simulation cell):

# Plot F4 in colour ('dodgerblue') and every other generation in light grey ('silver').
# NOTE(review): save_filename is an absolute path inside one Windows user profile;
# saving will presumably fail on a machine without that directory - consider a
# configurable output directory.
plot_hi_het_triangle_all_generations(
    all_generations_data=all_generations_data,
    highlight_gen='F4', # Generation to highlight; all others are drawn in grey
    save_filename='C:\\Users\\sophi\\Jupyter_projects\\Hybrid_Code\\output_data\\triangle_plot_images\\F4_highlighted.png' 
)

Plot saved to C:\Users\sophi\Jupyter_projects\Hybrid_Code\output_data\triangle_plot_images\F4_highlighted.png


### Plotting mean Het and Hi

In [10]:
def plot_mean_hi_het_triangle_all_generations(all_generations_data: Dict[str, List[Dict[str, Any]]],
                                               show_individual_points: bool = False,
                                               save_filename: str = None):
    """
    Plot the mean Hybrid Index (HI) against the mean Heterozygosity (HET) of
    every generation in ``all_generations_data`` on a triangle plot.

    One labelled marker is drawn per generation (P_A, P_B and F1 are included
    when present and get fixed colours and a larger marker). Individuals can
    optionally be drawn behind the means as faint grey points.

    Missing data handling:
        * A record whose 'HI' (or 'HET') key is absent or None is ignored when
          the mean of that quantity is computed; the two means are filtered
          independently of each other.
        * Background points need an (HI, HET) pair, so only records carrying
          both values are drawn.
        * A generation with no records, or with no usable HI or HET values, is
          skipped with a printed message.

    Args:
        all_generations_data (dict): Maps generation names (e.g. 'P_A', 'F2', 'BC1A')
                                     to lists of per-individual dictionaries holding
                                     'HI' and 'HET' values.
        show_individual_points (bool): If True, plots all individual points in a light
                                       grayscale behind the mean points. Defaults to False.
        save_filename (str, optional): Path the figure is saved to before it is shown.
                                       Nothing is saved when omitted or empty.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    def valid_values(records, key):
        """Return the values stored under ``key``, skipping absent or None entries."""
        return [rec[key] for rec in records if rec.get(key) is not None]

    def paired_values(records):
        """Return parallel HI and HET lists built only from records that carry both."""
        complete = [rec for rec in records
                    if rec.get('HI') is not None and rec.get('HET') is not None]
        return [rec['HI'] for rec in complete], [rec['HET'] for rec in complete]

    fig, ax = plt.subplots(figsize=(10, 8))

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(1.2)
    ax.spines['bottom'].set_linewidth(1.2)

    ax.set_xlabel("Mean Hybrid Index (proportion M alleles)", fontsize=12)
    ax.set_ylabel("Mean Heterozygosity (proportion heterozygous loci)", fontsize=12)

    # Colours handed out, in order, to generations without a fixed anchor colour.
    default_colors = itertools.cycle([
        'blue', 'red', 'green', 'purple', 'orange', 'brown', 'pink', 'teal',
        'darkviolet', 'magenta', 'cyan', 'lime', 'gold', 'navy', 'maroon',
        'darkgreen', 'darkred', 'darkblue', 'darkgoldenrod', 'darkslategray',
        'cornflowerblue', 'olivedrab', 'peru', 'rosybrown', 'salmon',
        'seagreen', 'sienna', 'darkkhaki', 'mediumorchid', 'lightcoral'
    ])
    anchor_colors = {'P_A': 'black', 'P_B': 'grey', 'F1': 'purple'}
    color_map = {}  # Remembers the colour assigned to each generation

    legend_elements = []

    def get_color(gen_name):
        """Return the colour of a generation, assigning one on first use."""
        if gen_name not in color_map:
            if gen_name in anchor_colors:
                color_map[gen_name] = anchor_colors[gen_name]
            else:
                color_map[gen_name] = next(default_colors)
        return color_map[gen_name]

    def sort_key(label: str):
        """
        Order generations as P_A, P_B, F1, F<n> (numeric), BC<n><suffix>, others.

        Every key is a (group, number, text) tuple so that any two keys can be
        compared without mixing element types.
        """
        if label == 'P_A': return (0, 0, label)
        if label == 'P_B': return (1, 0, label)
        if label == 'F1': return (2, 0, label)
        match_f = re.match(r'F(\d+)', label)
        if match_f:
            # Empty text keeps equally numbered labels in their input order.
            return (3, int(match_f.group(1)), '')
        if label.startswith('F'):
            return (3, float('inf'), label)  # Unnumbered F labels go after numbered ones
        match_bc = re.match(r'BC(\d+)([A-Z]?)', label)
        if match_bc:
            return (4, int(match_bc.group(1)), match_bc.group(2) or '')
        if label.startswith('BC'):
            return (4, float('inf'), label)
        return (5, 0, label)

    sorted_gen_names = sorted(all_generations_data, key=sort_key)

    # Plot individual points in background if requested
    if show_individual_points:
        for gen_name in sorted_gen_names:
            values = all_generations_data[gen_name]
            if not values:
                continue

            # scatter() needs x and y of equal length, so use complete pairs only.
            hi_points, het_points = paired_values(values)
            if hi_points:
                # Use a very light alpha for individual points
                ax.scatter(hi_points, het_points, color='lightgray', alpha=0.1, s=10, zorder=1)

    # Plot mean points for each generation (including P_A, P_B, F1)
    for gen_name in sorted_gen_names:
        values = all_generations_data[gen_name]
        if not values:  # Skip if no data for this generation
            print(f"Skipping mean plot for {gen_name} due to missing data.")
            continue

        hi_values = valid_values(values, 'HI')
        het_values = valid_values(values, 'HET')

        if not (hi_values and het_values):
            print(f"Skipping mean plot for {gen_name} as no valid HI/HET data was found to calculate mean.")
            continue

        mean_hi = np.mean(hi_values)
        mean_het = np.mean(het_values)

        color = get_color(gen_name)

        # Anchor generations get a slightly larger marker than the rest.
        marker_size = 100 if gen_name in anchor_colors else 80

        ax.scatter(mean_hi, mean_het, color=color, s=marker_size, marker='o', edgecolors='black', linewidth=1.5, zorder=3)

        # Add text label next to the mean point for all generations
        ax.text(mean_hi + 0.01, mean_het + 0.01, gen_name, fontsize=9, color=color, ha='left', va='bottom', zorder=4)

        # Generation names are dictionary keys and therefore unique, so each
        # plotted generation contributes exactly one legend entry.
        legend_elements.append(Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=8, label=gen_name))

    # Draw triangle edges
    triangle_edges = [
        [(0.0, 0.0), (0.5, 1.0)],
        [(0.5, 1.0), (1.0, 0.0)],
        [(0.0, 0.0), (1.0, 0.0)]
    ]
    for (x0, y0), (x1, y1) in triangle_edges:
        ax.plot([x0, x1], [y0, y1], linestyle='-', color='gray', linewidth=1.5, alpha=0.7)

    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(-0.05, 1.05)
    ax.set_aspect('equal', adjustable='box')

    # Leave room on the right-hand side for the legend placed outside the axes.
    plt.subplots_adjust(left=0.1, right=0.75, top=0.9, bottom=0.1)

    # Add legend
    ax.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.35, 1), fontsize=10, frameon=False)

    # Save before showing the figure.
    if save_filename:
        plt.savefig(save_filename, bbox_inches='tight')
        print(f"Plot saved to {save_filename}")

    plt.show()

In [11]:
# Example call: plot one mean (HI, HET) marker per generation and save the figure.
# NOTE(review): save_filename is an absolute, machine-specific Windows path;
# consider building it from a configurable output directory.
plot_mean_hi_het_triangle_all_generations(
    all_generations_data=all_generations_data,
    show_individual_points=True, # Set to True to see background points, False for only means
    save_filename='C:\\Users\\sophi\\Jupyter_projects\\Hybrid_Code\\output_data\\triangle_plot_images\\mean_gens.png'
)

Plot saved to C:\Users\sophi\Jupyter_projects\Hybrid_Code\output_data\triangle_plot_images\mean_gens.png


### Debugging

In [None]:
# Debugging aid: print a short preview and summary statistics of the HI and HET
# values of selected generations.
# Requires all_generations_data to be defined by an earlier cell
# (i.e. the simulation or data loading code has already been run).

generations_to_inspect = ['F2', 'F3', 'F4', 'F5'] # Add more if you have F6, F7 etc.

print("--- Inspecting Individual HI and HET Data for F Generations ---")

for gen_name in generations_to_inspect:
    print(f"\n--- Generation: {gen_name} ---")

    if gen_name not in all_generations_data:
        print(f"  {gen_name} not found in all_generations_data.")
        continue

    values = all_generations_data[gen_name]

    if not values:
        print(f"  No data found for {gen_name}.")
        continue

    # Extract HI and HET values, ignoring records where the value is absent or None
    hi_values = [d['HI'] for d in values if d.get('HI') is not None]
    het_values = [d['HET'] for d in values if d.get('HET') is not None]

    if not hi_values or not het_values:
        print(f"  No valid HI or HET values found for {gen_name}.")
        continue

    # Print first 5 individuals' data (or fewer if less than 5)
    preview_count = min(5, len(values))
    print(f"  First {preview_count} individuals (ID, HI, HET):")
    for individual_data in values[:preview_count]:
        individual_id = individual_data.get('id', 'N/A')
        hi = individual_data.get('HI')
        het = individual_data.get('HET')
        # Only apply the numeric format to real values: formatting the 'N/A'
        # placeholder (or None) with ':.4f' would raise instead of printing it.
        hi_text = 'N/A' if hi is None else f"{hi:.4f}"
        het_text = 'N/A' if het is None else f"{het:.4f}"
        print(f"    ID: {individual_id}, HI: {hi_text}, HET: {het_text}")

    # Calculate and print statistics
    mean_hi = np.mean(hi_values)
    mean_het = np.mean(het_values)

    # Report 0.0 spread explicitly when only a single value is available
    std_hi = np.std(hi_values) if len(hi_values) > 1 else 0.0
    std_het = np.std(het_values) if len(het_values) > 1 else 0.0

    print(f"\n  Statistics for {gen_name}:")
    print(f"    Mean HI:  {mean_hi:.4f} (Std Dev: {std_hi:.4f})")
    print(f"    Mean HET: {mean_het:.4f} (Std Dev: {std_het:.4f})")
    print(f"    Number of individuals: {len(values)}")