# In [5]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import argparse
import json
import random
from typing import Dict, List, Tuple, Optional, Union
from dataclasses import dataclass, asdict

# Utility class for JSON serialization
class NpEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to native Python types."""

    def default(self, obj):
        """
        Return a JSON-serialisable stand-in for `obj`.

        NumPy arrays become (nested) lists, NumPy floating scalars become
        `float` and NumPy integer scalars become `int`. Every other object is
        delegated to the base encoder's `default`.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        return super().default(obj)

# %%
@dataclass
class Marker:
    """
    A single genetic marker and its location on a chromosome.

    Attributes:
        id (str): Unique identifier for the marker (e.g., 'SNP1').
        physical_position (float): Where the marker sits in physical
            coordinates (e.g., base pairs).
        genetic_position (float): Where the marker sits on the genetic map
            (e.g., centiMorgans).
    """
    id: str
    physical_position: float
    genetic_position: float

# %%
@dataclass
class Chromosome:
    """
    Structural description of one chromosome: its lengths and its markers.

    Attributes:
        id (int): Unique integer identifier for the chromosome.
        physical_length_bp (float): Total physical length in base pairs.
        genetic_length_cM (float): Total genetic map length in centiMorgans.
        markers (List[Marker]): The Marker objects located on this chromosome.
    """
    id: int
    physical_length_bp: float
    genetic_length_cM: float
    markers: List[Marker]

# %%
class Individual:
    """
    Represents a diploid individual with a pair of homologous chromosomes.

    Each individual carries two copies of each chromosome: one inherited from its
    biological mother and one from its biological father. These are stored in
    `maternal_chroms` and `paternal_chroms` respectively.

    Attributes:
        id (str): A unique identifier for the individual.
        chromosomes (List[Chromosome]): A list of Chromosome objects defining the structure
                                        of the genome for this individual.
        maternal_chroms (Dict[int, Dict]): Stores data for chromosomes inherited from the
                                            individual's biological mother (keyed by chromosome ID).
                                            Each entry contains 'alleles' and 'positions'.
        paternal_chroms (Dict[int, Dict]): Stores data for chromosomes inherited from the
                                            individual's biological father (keyed by chromosome ID).
                                            Each entry contains 'alleles' and 'positions'.
    """

    def __init__(self, id: str, chromosomes: List[Chromosome]):
        """
        Creates an individual with empty homolog stores.

        Args:
            id (str): Identifier for the individual.
            chromosomes (List[Chromosome]): Genome structure shared with the simulator.
        """
        self.id = id
        self.chromosomes = chromosomes
        # Only 'alleles' and 'positions' are stored per chromosome; ancestry is
        # read directly from the allele value.
        self.maternal_chroms = {}
        self.paternal_chroms = {}

    def initialise_ancestral_chromosomes(self, ancestry: int):
        """
        Fills both homologs of every chromosome with a single ancestry value.

        Every marker on the maternal and the paternal copy receives `ancestry`,
        so the individual represents a pure founder. Ancestry is defined only
        at marker points; there are no continuous ancestry blocks.

        Args:
            ancestry (int): The ancestral population ID (0 or 2) for this individual.
        """
        for chrom in self.chromosomes:
            marker_count = len(chrom.markers)
            # Each homolog gets its own list objects so later edits to one
            # copy never leak into the other.
            self.maternal_chroms[chrom.id] = {
                'alleles': [ancestry] * marker_count,
                'positions': [m.physical_position for m in chrom.markers]
            }
            self.paternal_chroms[chrom.id] = {
                'alleles': [ancestry] * marker_count,
                'positions': [m.physical_position for m in chrom.markers]
            }

    @staticmethod
    def _draw_crossover_positions(chrom,
                                  use_poisson: bool,
                                  fixed_count: int,
                                  custom_counts: Optional[List[int]],
                                  custom_probs: Optional[List[float]]) -> list:
        """
        Draws the sorted genetic positions (cM) of the crossovers for one meiosis.

        The number of crossovers comes from, in priority order: a Poisson draw
        with mean `genetic_length_cM / 100` (if `use_poisson`), the custom
        discrete distribution (if both lists are given), or `fixed_count`.
        Positions are then drawn uniformly over the genetic length.

        A chromosome whose genetic length is not positive receives no
        crossovers: on a zero-length map every position would collapse to 0
        and flip the whole chromosome instead of recombining it.

        Raises:
            ValueError: If the custom draw or the fixed count is negative.
        """
        if use_poisson:
            # Expected number of crossovers is the genetic length in Morgans.
            n_crossovers = np.random.poisson(chrom.genetic_length_cM / 100)
        elif custom_counts is not None and custom_probs is not None:
            n_crossovers = np.random.choice(custom_counts, p=custom_probs)
            if n_crossovers < 0:
                raise ValueError("Number of crossovers drawn from custom distribution cannot be negative.")
        else:
            n_crossovers = fixed_count
            if n_crossovers < 0:
                raise ValueError("Fixed crossover count cannot be negative.")

        if n_crossovers <= 0 or chrom.genetic_length_cM <= 0:
            return []
        return sorted(np.random.uniform(0, chrom.genetic_length_cM, n_crossovers))

    def produce_gamete(self,
                         chrom: Chromosome,
                         use_poisson: bool,
                         fixed_crossover_count_for_chrom: int,
                         custom_crossover_counts_dist: Optional[List[int]],
                         custom_crossover_probs_dist: Optional[List[float]],
                         sim_detected_crossovers: List[Dict],
                         sim_all_true_crossovers: List[Dict]) -> List[int]:
        """
        Simulates meiosis for a single chromosome pair from this individual to produce
        one haploid chromosome for a gamete.

        The gamete is a mosaic of this individual's maternal and paternal
        homologs: it starts on the maternal copy and switches to the other
        copy at every crossover whose genetic position is less than or equal
        to the genetic position of the marker being filled in.

        Args:
            chrom (Chromosome): The Chromosome object for which a gamete is being produced.
            use_poisson (bool): If True, the number of crossovers is a Poisson draw
                                based on the chromosome's genetic length.
            fixed_crossover_count_for_chrom (int): The exact number of crossovers to simulate
                                                    if `use_poisson` is False AND no custom
                                                    distribution is provided.
            custom_crossover_counts_dist (Optional[List[int]]): Possible crossover counts
                                                                  for a custom discrete distribution.
            custom_crossover_probs_dist (Optional[List[float]]): Probabilities matching
                                                                   `custom_crossover_counts_dist`.
            sim_detected_crossovers (List[Dict]): Output list; one entry is appended for every
                                                    pair of adjacent markers whose gamete
                                                    alleles differ.
            sim_all_true_crossovers (List[Dict]): Output list; one entry is appended for every
                                                    crossover that was drawn, detectable or not.

        Returns:
            List[int]: The alleles for each marker on the recombined gamete chromosome.
                        (The allele value itself represents ancestral origin at that marker).

        Raises:
            ValueError: If the crossover count (custom draw or fixed) is negative.
        """
        maternal_alleles = self.maternal_chroms[chrom.id]['alleles']
        paternal_alleles = self.paternal_chroms[chrom.id]['alleles']

        crossover_genetic_positions = self._draw_crossover_positions(
            chrom,
            use_poisson,
            fixed_crossover_count_for_chrom,
            custom_crossover_counts_dist,
            custom_crossover_probs_dist,
        )

        # Log every crossover that physically occurred.
        for pos in crossover_genetic_positions:
            sim_all_true_crossovers.append({
                'chromosome': chrom.id,
                'genetic_position': pos,
                'parent_id': self.id # Track which parent this crossover happened in
            })

        # Index 0 is the maternal homolog, index 1 the paternal one.
        # NOTE(review): the walk always starts on the maternal homolog, so an
        # individual that makes several gametes starts each of them on the same
        # copy. Kept as documented behaviour; confirm this is intended.
        homologs = (maternal_alleles, paternal_alleles)
        source_idx = 0
        next_crossover_idx = 0
        crossover_total = len(crossover_genetic_positions)

        gamete_alleles = []
        previous_marker = None
        for marker_idx, marker in enumerate(chrom.markers):
            # Consume every crossover located at or before this marker; each
            # one flips the contributing homolog.
            while (next_crossover_idx < crossover_total and
                   crossover_genetic_positions[next_crossover_idx] <= marker.genetic_position):
                source_idx = 1 - source_idx
                next_crossover_idx += 1

            allele = homologs[source_idx][marker_idx]

            # An allele change relative to the previous marker is what an
            # observer can see; its position is approximated by this marker.
            if gamete_alleles and allele != gamete_alleles[-1]:
                sim_detected_crossovers.append({
                    'chromosome': chrom.id,
                    'genetic_position_approx': marker.genetic_position,
                    'physical_position_approx': marker.physical_position,
                    'marker_interval_detection': f"{previous_marker.id}-{marker.id}",
                    'parent_id': self.id # Track which parent this detected crossover came from
                })

            gamete_alleles.append(allele)
            previous_marker = marker

        return gamete_alleles

# %%
class RecombinationSimulator:
    """
    The main class for setting up and running genetic recombination simulations.

    It manages chromosome definitions, marker placement, allele frequency handling,
    and the simulation of gamete formation and offspring creation through crosses.
    """

    _SUPPORTED_MARKER_DISTRIBUTIONS = ('uniform', 'random')

    def __init__(self,
                 n_chromosomes: int = 4,
                 chromosome_sizes: Optional[List[float]] = None,
                 n_markers: int = 10,
                 marker_distribution: str = 'uniform',  # 'uniform' or 'random'
                 use_poisson: bool = True,
                 use_centimorgan: bool = True,
                 allele_freq_file: Optional[str] = None,
                 random_seed: Optional[int] = None,
                 fixed_crossover_config: Optional[Union[int, Dict[int, int]]] = None,
                 custom_crossover_counts: Optional[List[int]] = None,
                 custom_crossover_probs: Optional[List[float]] = None):
        """
        Stores the configuration, seeds the RNGs and loads allele frequencies.

        Args:
            n_chromosomes (int): Number of chromosomes to simulate.
            chromosome_sizes (Optional[List[float]]): Relative chromosome sizes; normalised to
                sum to 1.0. Defaults to equal sizes.
            n_markers (int): Total number of markers, split across the chromosomes.
            marker_distribution (str): 'uniform' or 'random' marker placement.
            use_poisson (bool): Draw crossover counts from a Poisson distribution.
            use_centimorgan (bool): If False, chromosomes get a genetic length of 0.
            allele_freq_file (Optional[str]): CSV with allele frequencies; random ones are
                generated when it is missing or unreadable.
            random_seed (Optional[int]): Seed applied to both `random` and `numpy.random`.
            fixed_crossover_config (Optional[Union[int, Dict[int, int]]]): One count for all
                chromosomes, or a mapping from chromosome ID to count. Any other type is
                ignored with a warning.
            custom_crossover_counts (Optional[List[int]]): Support of a custom discrete
                crossover-count distribution.
            custom_crossover_probs (Optional[List[float]]): Probabilities matching
                `custom_crossover_counts`.

        Raises:
            ValueError: For negative fixed counts, mismatched custom distribution lengths,
                or a non-positive total chromosome size.
        """
        self.n_chromosomes = n_chromosomes
        self.chromosome_sizes = chromosome_sizes or [1.0] * n_chromosomes
        self.n_markers = n_markers
        self.marker_distribution = marker_distribution
        self.use_poisson = use_poisson
        self.use_centimorgan = use_centimorgan
        self.allele_freq_file = allele_freq_file

        # Parse the fixed crossover configuration into its uniform / per-chromosome form.
        self.fixed_crossover_config = fixed_crossover_config
        self._fixed_crossover_uniform_count: Optional[int] = None
        self._fixed_crossover_per_chrom_counts: Optional[Dict[int, int]] = None

        if isinstance(self.fixed_crossover_config, int):
            self._fixed_crossover_uniform_count = self.fixed_crossover_config
            if self._fixed_crossover_uniform_count < 0:
                raise ValueError("Fixed uniform crossover count cannot be negative.")
        elif isinstance(self.fixed_crossover_config, dict):
            self._fixed_crossover_per_chrom_counts = self.fixed_crossover_config
            if any(count < 0 for count in self._fixed_crossover_per_chrom_counts.values()):
                raise ValueError("Fixed per-chromosome crossover counts cannot be negative.")
        elif self.fixed_crossover_config is not None:
            print("Warning: fixed_crossover_config was provided in an unsupported format. It will be ignored.")
            self.fixed_crossover_config = None

        # Custom discrete distribution for crossover counts. Probabilities are
        # expected to arrive already normalised; only the lengths are checked here.
        self._custom_crossover_counts = custom_crossover_counts
        self._custom_crossover_probs = custom_crossover_probs
        if self._custom_crossover_counts is not None and self._custom_crossover_probs is not None:
            if len(self._custom_crossover_counts) != len(self._custom_crossover_probs):
                raise ValueError("Custom crossover counts and probabilities lists must have the same length.")

        # Seed before anything random happens (allele frequencies are drawn below).
        self.random_seed = random_seed
        if self.random_seed is not None:
            random.seed(self.random_seed)
            np.random.seed(self.random_seed)
            print(f"Set random seed to {self.random_seed}")

        self.allele_frequencies = self.load_allele_frequencies()

        # Normalise chromosome sizes so they sum to 1.0, ensuring proportional allocation.
        total_size = sum(self.chromosome_sizes)
        if total_size <= 0:
            raise ValueError("Total chromosome size must be greater than zero.")
        self.chromosome_sizes = [s / total_size for s in self.chromosome_sizes]

        self.chromosomes = []
        self.all_true_crossovers = []  # Every crossover drawn during the last cross
        self.detected_crossovers = []  # Marker-to-marker allele switches seen in the last cross
        self.blind_spot_crossovers = []  # True crossovers with no matching observed switch
        self.current_simulation_crossovers_info = {}  # Per-parent breakdown of the last cross

    def load_allele_frequencies(self) -> Dict[str, Dict[int, float]]:
        """
        Loads allele frequencies from a CSV file if provided, otherwise generates random ones.

        CSV format: marker_id, allele_0_freq_pop1, allele_0_freq_pop2. Frequencies are
        clamped to [0, 1]. Any failure to read or validate the file is reported on
        stdout and random frequencies are generated instead.

        Returns:
            Dict[str, Dict[int, float]]: A dictionary where keys are marker IDs,
                                         and values are dictionaries containing allele 0 frequencies
                                         for population 0 and population 2.
                                         e.g., {'marker_id_1': {0: 0.9, 2: 0.1}}
        """
        if not self.allele_freq_file:
            return self._generate_random_allele_frequencies()

        required_columns = ['marker_id', 'allele_0_freq_pop1', 'allele_0_freq_pop2']
        try:
            df = pd.read_csv(self.allele_freq_file)
            if not all(col in df.columns for col in required_columns):
                raise ValueError("Allele frequency CSV must contain 'marker_id', 'allele_0_freq_pop1', 'allele_0_freq_pop2' columns.")

            allele_freqs = {}
            for _, row in df.iterrows():
                allele_freqs[row['marker_id']] = {
                    0: max(0.0, min(1.0, row['allele_0_freq_pop1'])),
                    2: max(0.0, min(1.0, row['allele_0_freq_pop2'])),
                }
            print(f"Loaded allele frequencies from {self.allele_freq_file}")
            return allele_freqs
        except FileNotFoundError:
            print(f"Warning: Allele frequency file '{self.allele_freq_file}' not found. Generating random frequencies.")
            return self._generate_random_allele_frequencies()
        except Exception as e:
            # Deliberately best-effort: a malformed file must not abort the simulation.
            print(f"Error loading allele frequency file: {e}. Generating random frequencies.")
            return self._generate_random_allele_frequencies()

    def _generate_random_allele_frequencies(self) -> Dict[str, Dict[int, float]]:
        """
        Generates random allele frequencies for `n_chromosomes * n_markers` markers.

        Keys are 'marker_1', 'marker_2', ...; each frequency is uniform in [0.1, 0.9)
        so that every marker keeps some polymorphism.
        """
        print("Generating random allele frequencies...")
        allele_freqs = {}
        for i in range(self.n_chromosomes * self.n_markers):
            allele_freqs[f"marker_{i+1}"] = {
                0: np.random.uniform(0.1, 0.9),  # Allele 0 freq in Pop A
                2: np.random.uniform(0.1, 0.9)   # Allele 0 freq in Pop B
            }
        return allele_freqs

    def create_chromosomes(self, base_length: float = 100_000_000, base_genetic_length: float = 100.0):
        """
        Creates Chromosome objects based on configured sizes.

        Chromosome IDs start at 1. Existing chromosomes are discarded.

        Args:
            base_length (float): The physical length (in base pairs) for a chromosome
                                 with a relative size of 1.0.
            base_genetic_length (float): The genetic length (in centiMorgans) for a chromosome
                                         with a relative size of 1.0.
        """
        self.chromosomes = []
        for index in range(self.n_chromosomes):
            relative_size = self.chromosome_sizes[index]
            # With the cM model off the genetic map has zero length.
            genetic_len = base_genetic_length * relative_size if self.use_centimorgan else 0.0
            self.chromosomes.append(
                Chromosome(id=index + 1,
                           physical_length_bp=base_length * relative_size,
                           genetic_length_cM=genetic_len,
                           markers=[])
            )

    @staticmethod
    def _random_marker_positions(chrom, count: int) -> List[Tuple[float, float]]:
        """
        Draws `count` random (physical, genetic) marker positions in chromosome order.

        Physical and genetic coordinates are drawn independently, then each set is
        sorted so the k-th marker along the physical map is also the k-th along the
        genetic map. Gamete production walks markers in list order and needs this
        ordering to agree.
        """
        physical, genetic = [], []
        for _ in range(count):
            physical.append(np.random.uniform(0, chrom.physical_length_bp))
            genetic.append(np.random.uniform(0, chrom.genetic_length_cM))
        return list(zip(sorted(physical), sorted(genetic)))

    def assign_markers_to_chromosomes(self):
        """
        Assigns markers to the created chromosomes based on the specified distribution.

        `n_markers` is split as evenly as possible; the first chromosomes absorb the
        remainder. Existing markers are replaced. Marker IDs are numbered consecutively
        in positional order across chromosomes.

        Raises:
            ValueError: If `marker_distribution` is neither 'uniform' nor 'random'.
        """
        if self.marker_distribution not in self._SUPPORTED_MARKER_DISTRIBUTIONS:
            raise ValueError(
                f"Unsupported marker_distribution '{self.marker_distribution}'; "
                f"expected one of {self._SUPPORTED_MARKER_DISTRIBUTIONS}."
            )

        base_count, remaining_markers = divmod(self.n_markers, self.n_chromosomes)
        next_marker_number = 1

        for chrom in self.chromosomes:
            marker_count = base_count
            if remaining_markers > 0:
                marker_count += 1
                remaining_markers -= 1

            if self.marker_distribution == 'uniform':
                # The +1 offsets keep markers off both chromosome ends.
                positions = [
                    (chrom.physical_length_bp * ((i + 1) / (marker_count + 1)),
                     chrom.genetic_length_cM * ((i + 1) / (marker_count + 1)))
                    for i in range(marker_count)
                ]
            else:
                positions = self._random_marker_positions(chrom, marker_count)

            chrom.markers = []
            for physical_pos, genetic_pos in positions:
                chrom.markers.append(
                    Marker(id=f"chr{chrom.id}_marker_{next_marker_number}",
                           physical_position=physical_pos,
                           genetic_position=genetic_pos)
                )
                next_marker_number += 1

    def calculate_hybrid_index(self, individual: Individual) -> float:
        """
        Calculates the hybrid index (proportion of ancestry from Population 0) for an individual.

        Args:
            individual (Individual): The individual to calculate the hybrid index for.

        Returns:
            float: The proportion of alleles, over both homologs, equal to 0.
                   Returns 0.0 when the individual has no alleles.
        """
        total_alleles = 0
        pop0_alleles = 0
        for chrom_id in individual.maternal_chroms:
            for homolog in (individual.maternal_chroms, individual.paternal_chroms):
                alleles = homolog[chrom_id]['alleles']
                total_alleles += len(alleles)
                pop0_alleles += sum(1 for allele in alleles if allele == 0)

        if total_alleles == 0:
            return 0.0  # Avoid division by zero
        return pop0_alleles / total_alleles

    def calculate_heterozygosity(self, individual: Individual) -> float:
        """
        Calculates the average heterozygosity (proportion of heterozygous markers) for an individual.

        A marker is heterozygous when the maternal and paternal alleles at that
        marker differ (e.g., one is 0 and the other is 2).

        Args:
            individual (Individual): The individual to calculate heterozygosity for.

        Returns:
            float: The proportion of markers that are heterozygous, or 0.0 when
                   there are no markers.
        """
        total_markers = 0
        heterozygous_markers = 0
        for chrom_id in individual.maternal_chroms:
            m_alleles = individual.maternal_chroms[chrom_id]['alleles']
            p_alleles = individual.paternal_chroms[chrom_id]['alleles']
            total_markers += len(m_alleles)
            heterozygous_markers += sum(1 for m, p in zip(m_alleles, p_alleles) if m != p)

        if total_markers == 0:
            return 0.0  # Avoid division by zero
        return heterozygous_markers / total_markers

    def print_summary(self):
        """
        Prints a summary of the current simulation parameters and recombination events.
        This focuses on the most recently simulated cross.
        """
        print("\n--- Simulation Summary ---")
        print(f"Number of Chromosomes: {self.n_chromosomes}")
        print(f"Total Markers: {self.n_markers}")
        print(f"Marker Distribution: {self.marker_distribution}")
        print(f"Crossover Model: {'Poisson' if self.use_poisson else 'Fixed/Custom'}")
        if not self.use_poisson:
            if self._custom_crossover_counts is not None:
                dist_str = ', '.join(
                    f"{c}:{p:.2f}"
                    for c, p in zip(self._custom_crossover_counts, self._custom_crossover_probs)
                )
                print(f"  Custom Crossover Distribution (counts:probabilities): {dist_str}")
            elif self._fixed_crossover_uniform_count is not None:
                print(f"  Fixed Crossovers per Chromosome (Uniform): {self._fixed_crossover_uniform_count}")
            elif self._fixed_crossover_per_chrom_counts:
                print(f"  Fixed Crossovers per Chromosome (Per Chrom ID): {self._fixed_crossover_per_chrom_counts}")
            else:
                print("  No Crossovers (Fixed or Custom not set and Poisson is off)")
        print(f"Using cM distances for recombination: {self.use_centimorgan}")
        print(f"Random Seed: {'Not set' if self.random_seed is None else 'Set to ' + str(self.random_seed)}")

        print("\n--- Recombination Events (Last Cross) ---")
        print(f"Total true crossovers (occurred): {len(self.all_true_crossovers)}")
        print(f"Total detected crossovers (observable): {len(self.detected_crossovers)}")
        print(f"Total blind spot crossovers (undetected): {len(self.blind_spot_crossovers)}")

        if self.all_true_crossovers:
            df_true = pd.DataFrame(self.all_true_crossovers)
            print("\nTrue Crossovers by Chromosome:")
            print(df_true['chromosome'].value_counts().sort_index())

        if self.detected_crossovers:
            df_detected = pd.DataFrame(self.detected_crossovers)
            print("\nDetected Crossovers by Chromosome:")
            print(df_detected['chromosome'].value_counts().sort_index())

    def _fixed_crossover_count_for(self, chrom_id: int) -> int:
        """Returns the configured fixed crossover count for a chromosome (0 if none is set)."""
        if self._fixed_crossover_uniform_count is not None:
            return self._fixed_crossover_uniform_count
        if self._fixed_crossover_per_chrom_counts is not None:
            return self._fixed_crossover_per_chrom_counts.get(chrom_id, 0)
        return 0

    @staticmethod
    def _find_blind_spot_crossovers(chromosomes, true_crossovers: List[Dict],
                                    detected_crossovers: List[Dict]) -> List[Dict]:
        """
        Returns the true crossovers of ONE parent that left no observable trace.

        Crossovers are assigned to marker intervals with the same rule gamete
        production uses: a crossover belongs to the first marker (in list order)
        whose genetic position is greater than or equal to the crossover position.
        A crossover is a blind spot when it
          * lies before the first marker or after the last one,
          * shares its interval with an even number of crossovers (they cancel), or
          * lies in an interval for which no allele switch was logged.
        When an odd number of crossovers share an interval with a logged switch,
        one of them accounts for the switch and the others are blind spots.

        Args:
            chromosomes: Chromosome objects providing `id` and ordered `markers`.
            true_crossovers (List[Dict]): Entries with 'chromosome' and 'genetic_position'.
            detected_crossovers (List[Dict]): Entries with 'chromosome' and
                'marker_interval_detection' ("<previous id>-<current id>").
        """
        detected_intervals = {
            (d['chromosome'], d['marker_interval_detection']) for d in detected_crossovers
        }

        blind_spots = []
        for chrom in chromosomes:
            pending = sorted(
                (tc for tc in true_crossovers if tc['chromosome'] == chrom.id),
                key=lambda tc: tc['genetic_position'],
            )
            cursor = 0
            previous_marker = None
            for marker in chrom.markers:
                start = cursor
                while (cursor < len(pending) and
                       pending[cursor]['genetic_position'] <= marker.genetic_position):
                    cursor += 1
                in_interval = pending[start:cursor]
                if in_interval:
                    observed = (
                        previous_marker is not None
                        and len(in_interval) % 2 == 1
                        and (chrom.id, f"{previous_marker.id}-{marker.id}") in detected_intervals
                    )
                    blind_spots.extend(in_interval[1:] if observed else in_interval)
                previous_marker = marker
            # Anything beyond the last marker can never be observed.
            blind_spots.extend(pending[cursor:])
        return blind_spots

    def simulate_recombination(self, parent1: Individual, parent2: Individual) -> Individual:
        """
        Simulates a genetic cross between two parent individuals to produce one offspring.

        This involves:
        1. Each parent producing a haploid gamete through meiosis, chromosome by chromosome.
        2. Combining the two gametes (one from each parent) to form a new diploid offspring;
           which gamete is stored as 'maternal' is decided at random per chromosome.
        The crossover bookkeeping attributes (`all_true_crossovers`, `detected_crossovers`,
        `blind_spot_crossovers`, `current_simulation_crossovers_info`) are reset and
        refilled for this cross.

        Args:
            parent1 (Individual): The first parent individual.
            parent2 (Individual): The second parent individual.

        Returns:
            Individual: A new `Individual` object representing the offspring.
        """
        offspring = Individual(f"offspring_{random.randint(1000,9999)}", self.chromosomes)

        # Start from a clean slate so the bookkeeping describes this cross only.
        self.all_true_crossovers = []
        self.detected_crossovers = []
        self.blind_spot_crossovers = []
        self.current_simulation_crossovers_info = {'parent1': {'all_true': [], 'detected': []},
                                                    'parent2': {'all_true': [], 'detected': []}}
        info = self.current_simulation_crossovers_info

        print(f"  Crossing {parent1.id} with {parent2.id} to produce offspring {offspring.id}...")

        for chrom in self.chromosomes:
            # Only used when Poisson is off and no custom distribution is given.
            fixed_count_for_this_chrom = self._fixed_crossover_count_for(chrom.id)

            gametes = {}
            for label, parent in (('parent1', parent1), ('parent2', parent2)):
                gamete_detected = []
                gamete_all_true = []
                gametes[label] = parent.produce_gamete(
                    chrom,
                    self.use_poisson,
                    fixed_count_for_this_chrom,
                    self._custom_crossover_counts,
                    self._custom_crossover_probs,
                    gamete_detected,
                    gamete_all_true
                )
                info[label]['all_true'].extend(gamete_all_true)
                info[label]['detected'].extend(gamete_detected)

            # Randomly decide which parent's gamete is stored as the maternal copy.
            if random.random() < 0.5:
                maternal_label, paternal_label = 'parent1', 'parent2'
            else:
                maternal_label, paternal_label = 'parent2', 'parent1'

            offspring.maternal_chroms[chrom.id] = {
                'alleles': gametes[maternal_label],
                'positions': [m.physical_position for m in chrom.markers]
            }
            offspring.paternal_chroms[chrom.id] = {
                'alleles': gametes[paternal_label],
                'positions': [m.physical_position for m in chrom.markers]
            }

        self.all_true_crossovers = info['parent1']['all_true'] + info['parent2']['all_true']
        self.detected_crossovers = info['parent1']['detected'] + info['parent2']['detected']

        # Blind spots are resolved per parent so that a switch seen in one
        # parent's gamete never hides a crossover in the other parent's gamete.
        self.blind_spot_crossovers = []
        for label in ('parent1', 'parent2'):
            self.blind_spot_crossovers.extend(
                self._find_blind_spot_crossovers(
                    self.chromosomes, info[label]['all_true'], info[label]['detected']
                )
            )

        return offspring

# %%
def _parse_custom_crossover_dist(spec):
    """
    Parse a ``"count:prob,count:prob"`` string into crossover counts and probabilities.

    Probabilities are rescaled so that they sum to 1.0.

    Args:
        spec (str): Raw value of ``--custom_crossover_dist`` (e.g., "0:0.2,1:0.8").

    Returns:
        tuple: ``(counts, probs)`` as two parallel lists, or ``(None, None)`` if the
               string is malformed (a warning is printed and the option is ignored).
    """
    counts = []
    probs = []
    try:
        for item in spec.split(','):
            # A wrong number of ':' separators raises ValueError on unpacking.
            num_str, prob_str = item.split(':')
            count = int(num_str)
            prob = float(prob_str)
            if prob < 0:
                raise ValueError("Probabilities cannot be negative.")
            counts.append(count)
            probs.append(prob)

        total_prob = sum(probs)
        if not total_prob > 0:
            raise ValueError("Sum of probabilities must be greater than zero.")
    except ValueError as e:
        # Every failure mode above (bad split, bad int/float, bad totals) is a ValueError.
        print(f"Warning: Could not parse --custom_crossover_dist '{spec}'. Error: {e}. It will be ignored.")
        return None, None

    return counts, [p / total_prob for p in probs]


def _parse_fixed_crossovers(spec):
    """
    Parse ``--fixed_crossovers`` into a single int or a per-chromosome dict.

    Args:
        spec (str): Either a single integer (applies to all chromosomes) or
                    comma-separated ``chromosome_id:count`` pairs (e.g., "1:2,2:0").

    Returns:
        int | dict | None: The integer count, a ``{chromosome_id: count}`` dict built
                           from the valid entries, or None if nothing could be parsed.
    """
    try:
        return int(spec)
    except ValueError:
        pass  # Not a plain integer; fall through to the per-chromosome format.

    per_chromosome = {}
    for item in spec.split(','):
        if ':' not in item:
            print(f"Warning: Invalid fixed_crossovers format for '{item}'. Expected 'chromosome_id:count' or a single integer (e.g., '1'). Skipping this entry.")
            continue
        try:
            chrom_id_str, count_str = item.split(':')
            chrom_id = int(chrom_id_str)
            count = int(count_str)
        except ValueError:
            print(f"Warning: Invalid fixed_crossovers format for '{item}'. Expected 'chromosome_id:count' (e.g., '1:1'). Skipping this entry.")
        else:
            per_chromosome[chrom_id] = count

    if per_chromosome:
        return per_chromosome

    print(f"Warning: Could not parse --fixed_crossovers '{spec}'. It will be ignored.")
    return None


def _snapshot_parent_crossovers(sim):
    """
    Return an independent copy of the per-parent crossover records currently held by `sim`.

    The copy is deep so that later crosses run on the same simulator cannot alter it.
    """
    import copy  # Local import: only needed here, keeps the file's import block untouched.

    info = sim.current_simulation_crossovers_info
    return {
        'parent1_all_true': copy.deepcopy(info['parent1']['all_true']),
        'parent1_detected': copy.deepcopy(info['parent1']['detected']),
        'parent2_all_true': copy.deepcopy(info['parent2']['all_true']),
        'parent2_detected': copy.deepcopy(info['parent2']['detected'])
    }


def main(cli_args=None):
    """
    Main function to run the chromosome recombination simulation.
    It handles argument parsing, simulation setup, running an F1 cross and a
    backcross, calculating metrics, and saving results to a JSON file.

    Args:
        cli_args (list, optional): A list of strings representing command-line arguments.
                                   Useful for testing in notebooks without using `!python`.
                                   If None, `argparse` will parse `sys.argv`.
    """
    parser = argparse.ArgumentParser(description='Advanced Chromosome Recombination Simulator')
    parser.add_argument('--n_chromosomes', type=int, default=4,
                        help='Number of chromosomes (default: 4)')
    parser.add_argument('--chromosome_sizes', type=str, default='1.0,0.8,0.6,0.4',
                        help='Comma-separated relative chromosome sizes (default: 1.0,0.8,0.6,0.4)')
    parser.add_argument('--n_markers', type=int, default=20,
                        help='Total number of markers (default: 20)')
    parser.add_argument('--marker_distribution', choices=['uniform', 'random'], default='uniform',
                        help='Marker distribution method (default: uniform)')
    parser.add_argument('--use_poisson', action='store_true',
                        help='Use Poisson process for crossovers (default: False, i.e., simplified recombination model)')
    parser.add_argument('--no_centimorgan', action='store_false', dest='use_centimorgan', default=True,
                        help='Do NOT use centiMorgan genetic distances for recombination calculation (default: False, i.e., will use cM)')
    parser.add_argument('--allele_freq_file', type=str, default=None,
                        help='CSV file with allele frequencies (columns: marker_id, allele_0_freq_pop1, allele_0_freq_pop2)')
    parser.add_argument('--random_seed', type=int, default=None,
                        help='Seed for random number generators for reproducibility (default: None)')
    parser.add_argument('--output', type=str, default='simulation_results.json',
                        help='Output file name (default: simulation_results.json)')
    parser.add_argument('--fixed_crossovers', type=str, default=None,
                        help='Integer number of crossovers for ALL chromosomes, or comma-separated fixed crossover counts per chromosome (e.g., "1:1,2:2"). Used when --use_poisson is FALSE AND --custom_crossover_dist is NOT set.')
    parser.add_argument('--custom_crossover_dist', type=str, default=None,
                        help='Custom discrete probability distribution for crossover counts per gamete per chromosome (e.g., "0:0.2,1:0.8"). Used when --use_poisson is FALSE.')

    args = parser.parse_args(cli_args)  # Pass cli_args if running from notebook, otherwise parses sys.argv

    # Parse chromosome sizes from string argument
    chromosome_sizes = [float(x) for x in args.chromosome_sizes.split(',')]

    # Parse the optional crossover-count settings; malformed values are ignored with a warning.
    custom_crossover_counts = None
    custom_crossover_probs = None
    if args.custom_crossover_dist:
        custom_crossover_counts, custom_crossover_probs = _parse_custom_crossover_dist(
            args.custom_crossover_dist
        )

    fixed_crossover_config = None
    if args.fixed_crossovers:
        fixed_crossover_config = _parse_fixed_crossovers(args.fixed_crossovers)

    # Initialise the recombination simulator
    sim = RecombinationSimulator(
        n_chromosomes=args.n_chromosomes,
        chromosome_sizes=chromosome_sizes,
        n_markers=args.n_markers,
        marker_distribution=args.marker_distribution,
        use_poisson=args.use_poisson,
        use_centimorgan=args.use_centimorgan,
        allele_freq_file=args.allele_freq_file,
        random_seed=args.random_seed,
        fixed_crossover_config=fixed_crossover_config,
        custom_crossover_counts=custom_crossover_counts,
        custom_crossover_probs=custom_crossover_probs
    )

    # Step 1: Create the chromosomes with their physical and genetic properties
    print("Creating chromosomes...")
    sim.create_chromosomes()

    # Step 2: Assign markers to the chromosomes
    print("Assigning markers to chromosomes...")
    sim.assign_markers_to_chromosomes()

    # Step 3: Create the founding populations (pure lines)
    print("Creating founding populations (Parent A and Parent B)...")
    parent_A = Individual("P_A", sim.chromosomes)
    parent_A.initialise_ancestral_chromosomes(ancestry=0)  # Pure Population 0 ancestry

    parent_B = Individual("P_B", sim.chromosomes)
    parent_B.initialise_ancestral_chromosomes(ancestry=2)  # Pure Population 2 ancestry

    # Step 4: Simulate the F1 hybrid generation
    print("Simulating the F1 (First Filial) generation by crossing P_A x P_B...")
    offspring_F1 = sim.simulate_recombination(parent_A, parent_B)

    # Capture the F1 crossover records NOW. The simulator is reused for the backcross
    # below and its crossover attributes are treated as describing the last cross, so
    # reading them at save time would store backcross events under the F1 section.
    crossovers_info_F1 = _snapshot_parent_crossovers(sim)

    # Calculate and print metrics for the F1 offspring
    hybrid_index_F1 = sim.calculate_hybrid_index(offspring_F1)
    heterozygosity_F1 = sim.calculate_heterozygosity(offspring_F1)

    print(f"\n--- F1 Generation Results (Offspring {offspring_F1.id}) ---")
    print(f"F1 Hybrid Index (Proportion from Pop 0): {hybrid_index_F1:.3f}")
    print(f"F1 Heterozygosity: {heterozygosity_F1:.3f}")

    # Print summary of the simulation parameters and the recombination events from the F1 cross
    sim.print_summary()

    # Step 5: Simulate a Backcross generation (e.g., F1 x P_A)
    print("\nSimulating a Backcross generation (e.g., F1 x P_A)...")
    # For a backcross, one parent is an F1 individual, the other is a pure parent
    bc_offspring = sim.simulate_recombination(offspring_F1, parent_A)

    # Calculate and print metrics for the Backcross offspring
    bc_hybrid_index = sim.calculate_hybrid_index(bc_offspring)
    bc_heterozygosity = sim.calculate_heterozygosity(bc_offspring)

    print(f"\n--- Backcross (F1 x P_A) Results (Offspring {bc_offspring.id}) ---")
    print(f"BC Hybrid Index (Proportion from Pop 0): {bc_hybrid_index:.3f}")
    print(f"BC Heterozygosity: {bc_heterozygosity:.3f}")

    # Print summary of the simulation parameters and recombination events from the BC cross
    sim.print_summary()

    # Step 6: Save results to a JSON file
    # Compile all relevant results into a dictionary for saving
    results_to_save = {
        'F1_generation': {
            'offspring_id': offspring_F1.id,
            'hybrid_index': hybrid_index_F1,
            'heterozygosity': heterozygosity_F1,
            # Snapshot taken right after the F1 cross (see Step 4)
            'crossovers_info_F1': crossovers_info_F1,
            'offspring_chromosomes': {
                'maternal_chroms': dict(offspring_F1.maternal_chroms.items()),
                'paternal_chroms': dict(offspring_F1.paternal_chroms.items())
            }
        },
        'BC_generation': {
            'offspring_id': bc_offspring.id,
            'hybrid_index': bc_hybrid_index,
            'heterozygosity': bc_heterozygosity,
            'crossovers_info_BC': {
                # The backcross is the last cross run, so these attributes refer to it
                'all_true_crossovers': sim.all_true_crossovers,
                'detected_crossovers': sim.detected_crossovers,
                'blind_spot_crossovers': sim.blind_spot_crossovers
            },
            'offspring_chromosomes': {
                'maternal_chroms': dict(bc_offspring.maternal_chroms.items()),
                'paternal_chroms': dict(bc_offspring.paternal_chroms.items())
            }
        },
        'chromosome_info': [asdict(chrom) for chrom in sim.chromosomes]
    }

    # NpEncoder converts NumPy scalars/arrays that the default JSON encoder rejects
    with open(args.output, 'w') as f:
        json.dump(results_to_save, f, indent=2, cls=NpEncoder)

    print(f"\nAll simulation results saved to '{args.output}'")


# %%
def main_toy_run():
    """
    Run a tiny, hard-coded F1 cross for quick testing and debugging in Jupyter.

    No command-line parsing is involved: a single chromosome with five uniformly
    placed markers is simulated with a fixed seed, using a custom crossover-count
    distribution (20% chance of 0 crossovers, 80% chance of 1). Metrics and the
    simulator summary are printed, and the results are written to
    'toy_simulation_results.json' on a best-effort basis.
    """
    print("--- Running Toy Simulation (main_toy_run) ---")

    output_filename = 'toy_simulation_results.json'

    # Build the simulator. use_poisson stays False so the custom distribution is
    # the one in effect; fixed_crossover_config is deliberately not supplied.
    sim = RecombinationSimulator(
        n_chromosomes=1,
        chromosome_sizes=[1.0],            # one entry for the single chromosome
        n_markers=5,
        marker_distribution='uniform',
        use_poisson=False,
        random_seed=42,                    # fixed seed for reproducible toy runs
        custom_crossover_counts=[0, 1],
        custom_crossover_probs=[0.2, 0.8]  # probabilities paired with the counts above
    )

    # Step 1: chromosomes (smaller lengths than the defaults, for the toy case)
    print("\nCreating chromosomes for toy run...")
    sim.create_chromosomes(base_length=10_000, base_genetic_length=10.0)

    # Step 2: markers
    print("Assigning markers to chromosomes for toy run...")
    sim.assign_markers_to_chromosomes()

    # Step 3: the two pure founding parents
    print("Creating founding populations (Parent A and Parent B) for toy run...")
    founder_a = Individual("P_A_Toy", sim.chromosomes)
    founder_a.initialise_ancestral_chromosomes(ancestry=0)

    founder_b = Individual("P_B_Toy", sim.chromosomes)
    founder_b.initialise_ancestral_chromosomes(ancestry=2)

    # Step 4: the F1 cross and its metrics
    print("Simulating F1 generation (P_A x P_B) for toy run...")
    f1 = sim.simulate_recombination(founder_a, founder_b)

    f1_hybrid_index = sim.calculate_hybrid_index(f1)
    f1_heterozygosity = sim.calculate_heterozygosity(f1)

    print(f"\n--- Toy F1 Generation Results (Offspring {f1.id}) ---")
    print(f"F1 Hybrid Index (Proportion from Pop 0): {f1_hybrid_index:.3f}")
    print(f"F1 Heterozygosity: {f1_heterozygosity:.3f}")

    # Simulation parameters plus the recombination events of the F1 cross
    sim.print_summary()

    # Persist the results; any failure here is reported but does not abort the run.
    try:
        per_parent = sim.current_simulation_crossovers_info
        crossover_record = {
            'parent1_all_true': per_parent['parent1']['all_true'],
            'parent1_detected': per_parent['parent1']['detected'],
            'parent2_all_true': per_parent['parent2']['all_true'],
            'parent2_detected': per_parent['parent2']['detected']
        }
        chromosome_record = {
            'maternal_chroms': dict(f1.maternal_chroms.items()),
            'paternal_chroms': dict(f1.paternal_chroms.items())
        }
        payload = {
            'Toy_F1_generation': {
                'offspring_id': f1.id,
                'hybrid_index': f1_hybrid_index,
                'heterozygosity': f1_heterozygosity,
                'crossovers_info_F1': crossover_record,
                'offspring_chromosomes': chromosome_record
            },
            'chromosome_info': [asdict(chrom) for chrom in sim.chromosomes]
        }

        # NpEncoder handles NumPy values during serialisation
        with open(output_filename, 'w') as out_file:
            json.dump(payload, out_file, indent=2, cls=NpEncoder)
        print(f"\nToy simulation results saved to '{output_filename}'")
    except Exception as e:
        print(f"\nError saving toy simulation results to JSON: {e}")

    print("--- Toy Simulation Finished ---")


# %%
# --- Final Execution Block ---
# This block allows you to run the simulation by calling the main functions.
# In Jupyter, ensure all cells above this have been run before executing this cell.

if __name__ == "__main__":
    # Entry point. Keep exactly ONE of the calls below active at a time.
    # `main_toy_run()` is currently enabled; to run another configuration,
    # comment it out and uncomment one of the `main(...)` examples.

    # 1. Run the small toy example for debugging and testing in Jupyter.
    #    It uses the custom crossover distribution 0:0.2, 1:0.8 that is
    #    hard-coded inside main_toy_run().
    main_toy_run()

    # 2. Run the full simulation with the argparse defaults defined in main().
    #    Original author's note: without --use_poisson this effectively results
    #    in 0 crossovers (depends on the simulator's fallback behaviour; verify).
    #    NOTE(review): main() with no arguments makes argparse read sys.argv;
    #    inside a notebook that may contain kernel arguments, in which case
    #    main(cli_args=[]) is presumably the safer call (TODO confirm).
    # main() 

    # 3. Run the full simulation with custom command-line parameters:
    #    Example: Use the custom distribution (20% for 0 CO, 80% for 1 CO) for 4 chromosomes.
    # main(cli_args=[
    #     '--n_chromosomes', '4',
    #     '--n_markers', '50',
    #     # IMPORTANT: Do NOT include '--use_poisson' here, so that the custom distribution is used
    #     '--random_seed', '123',
    #     '--output', 'custom_crossover_dist_results.json',
    #     '--chromosome_sizes', '1.0,0.8,0.6,0.4',
    #
    #     # Specify the custom discrete crossover distribution
    #     # Format: "count1:prob1,count2:prob2,..."
    #     '--custom_crossover_dist', '0:0.2,1:0.8' 
    # ])

    # Example: To force 1 crossover on ALL chromosomes (using the fixed_crossovers argument)
    # main(cli_args=[
    #     '--n_chromosomes', '4',
    #     '--n_markers', '50',
    #     '--random_seed', '123',
    #     '--output', 'fixed_1_crossover_run_results.json',
    #     '--chromosome_sizes', '1.0,0.8,0.6,0.4',
    #     '--fixed_crossovers', '1' # A single integer is applied to EACH of the 4 chromosomes
    # ])

    # Example: To force 2 crossovers on Chr1 and 0 on Chr2 (using fixed_crossovers with per-chrom control)
    # main(cli_args=[
    #     '--n_chromosomes', '2',
    #     '--n_markers', '20',
    #     '--random_seed', '123',
    #     '--output', 'fixed_per_chrom_results.json',
    #     '--chromosome_sizes', '1.0,1.0',
    #     '--fixed_crossovers', '1:2,2:0' # Format is chromosome_id:count -> Chr1: 2 COs, Chr2: 0 COs
    # ])

--- Running Toy Simulation (main_toy_run) ---
Set random seed to 42
Generating random allele frequencies...

Creating chromosomes for toy run...
Assigning markers to chromosomes for toy run...
Creating founding populations (Parent A and Parent B) for toy run...
Simulating F1 generation (P_A x P_B) for toy run...
  Crossing P_A_Toy with P_B_Toy to produce offspring offspring_2824...

--- Toy F1 Generation Results (Offspring offspring_2824) ---
F1 Hybrid Index (Proportion from Pop 0): 0.500
F1 Heterozygosity: 1.000

--- Simulation Summary ---
Number of Chromosomes: 1
Total Markers: 5
Marker Distribution: uniform
Crossover Model: Fixed/Custom
  Custom Crossover Distribution (counts:probabilities): 0:0.20, 1:0.80
Using cM distances for recombination: True
Random Seed: Set to 42

--- Recombination Events (Last Cross) ---
Total true crossovers (occurred): 1
Total detected crossovers (observable): 0
Total blind spot crossovers (undetected): 1

True Crossovers by Chromosome:
chromosome
1    1
