## Parental Care Model (with Teleportation case to limit mate guarding)

# In [None]:
import numpy as np
import matplotlib.pyplot as plt
import random
from dataclasses import dataclass, field
from typing import List, Dict
from scipy.stats import ttest_rel, linregress
import seaborn as sns
from collections import defaultdict


# 1. Constants

# Grid dimensions. The grid is large relative to the population, so density is low.
GRID_WIDTH = 50
GRID_HEIGHT = 50

# Population size and sex codes.
INITIAL_POPULATION = 100
SEX_MALE = 0
SEX_FEMALE = 1

# Simulation parameters.
BENEFIT_PARENTAL_CARE = 1.3    # clutch-size multiplier when care succeeds (+30%)
GRAVID_PERIOD = 14             # days a female is unavailable after mating
NEIGHBORHOOD_RADIUS = 8        # how far (in cells, square neighbourhood) a male can see
HIGH_QUALITY_PERCENTILE = 75   # males at/above this percentile are "high quality" (top 25%)
FECUNDITY_MULTIPLIER = 5.0     # zygotes per clutch = female quality * this value
STUCK_THRESHOLD = 5            # days a male may stay put before he explores
# Scale factor applied to softmax scores in female choice. It multiplies the
# scores, so LARGER values make choice LESS fuzzy (more strongly favouring
# the highest-quality male).
CHOICE_FUZZINESS = 10.0




@dataclass
class Individual:
    """One agent living on the grid.

    Fields are plain mutable attributes; the simulation updates them in place.
    """
    id: int                      # key of this individual in its generation's population dict
    sex: int                     # SEX_MALE or SEX_FEMALE
    quality: float               # heritable quality; values created in this file lie in [0, 1]
    position: List[int]          # current [x, y] grid cell
    fitness: float = 0.0         # offspring credited to this individual at season end
    is_gravid_days: int = 0      # days until a female can mate again (0 = available)
    next_position: List[int] = field(default_factory=list)  # cell planned for the next move
    days_stationary: int = 0     # consecutive days a male has planned to stay put

    def __post_init__(self):
        # Default the planned move to the current cell. Use a copy so the two
        # lists are never aliased. An explicitly supplied next_position is
        # kept rather than silently overwritten.
        if not self.next_position:
            self.next_position = list(self.position)

# 3. Helper functions

def softmax(x: np.ndarray, temp: float) -> np.ndarray:
    """Return softmax probabilities for the scores in ``x``.

    ``temp`` multiplies the max-shifted scores, so it behaves as an inverse
    temperature: larger values concentrate probability on the highest score.
    An empty input yields an empty array.
    """
    if len(x) == 0:
        return np.array([])
    # Shift by the maximum so the largest exponent is exp(0): avoids overflow.
    shifted = (x - np.max(x)) * temp
    weights = np.exp(shifted)
    return weights / weights.sum()

def initialize_population(n_total: int) -> Dict[int, Individual]:
    """Create the founding population with one female per territory.

    Ids ``0 .. n_total // 2 - 1`` are females, each placed on a distinct grid
    cell. The remaining ids are males, placed uniformly at random (several males
    may share a cell). Female quality ~ Beta(3, 7), skewed low because
    low-quality individuals are more common. Male quality ~ Beta(5, 5).

    Raises:
        ValueError: if there are more females than grid cells. Without this
            check, the rejection-sampling loop below would never terminate.
    """
    population = {}
    num_females = n_total // 2
    if num_females > GRID_WIDTH * GRID_HEIGHT:
        raise ValueError(
            f"cannot place {num_females} females on a {GRID_WIDTH}x{GRID_HEIGHT} grid "
            "with one female per cell"
        )

    occupied_female_territories = set()
    for i in range(num_females):
        # Rejection sampling: redraw until we hit a cell no female holds yet.
        while True:
            pos_tuple = (random.randint(0, GRID_WIDTH - 1), random.randint(0, GRID_HEIGHT - 1))
            if pos_tuple not in occupied_female_territories:
                occupied_female_territories.add(pos_tuple)
                break
        quality = np.random.beta(3, 7)
        population[i] = Individual(id=i, sex=SEX_FEMALE, quality=quality, position=list(pos_tuple))

    for i in range(num_females, n_total):
        pos = [random.randint(0, GRID_WIDTH - 1), random.randint(0, GRID_HEIGHT - 1)]
        quality = np.random.beta(5, 5)
        population[i] = Individual(id=i, sex=SEX_MALE, quality=quality, position=pos)

    return population

def find_neighbors(pos: List[int], radius: int) -> List[List[int]]:
    """Return every in-bounds cell in the square of half-width ``radius`` around ``pos``.

    The centre cell ``pos`` itself is excluded. Cells are ordered by x offset
    (outer) then y offset (inner), each returned as a new ``[x, y]`` list.
    """
    cx, cy = pos
    offsets = range(-radius, radius + 1)
    return [
        [cx + ox, cy + oy]
        for ox in offsets
        for oy in offsets
        if (ox, oy) != (0, 0)
        and 0 <= cx + ox < GRID_WIDTH
        and 0 <= cy + oy < GRID_HEIGHT
    ]

def create_next_generation(zygotebook: list, current_pop_size: int, dispersal_rate: float) -> Dict[int, Individual]:
    """Creates the next generation, ensuring females get unique territories."""
    next_generation = {}
    if not zygotebook: return next_generation

    total_zygotes = sum(c['num_zygotes'] for c in zygotebook)
    if total_zygotes <= 0: return next_generation

    scaling_factor = current_pop_size / total_zygotes

    male_offspring, female_offspring = [], []
    next_id = 0
    for clutch in zygotebook:
        num_to_create = int(np.round(clutch['num_zygotes'] * scaling_factor))
        for _ in range(num_to_create):
            if next_id >= current_pop_size: break

            #px, py = clutch['hatch_pos']
            #new_pos_x = int(np.clip(random.uniform(px - dispersal_rate, px + dispersal_rate), 0, GRID_WIDTH - 1))
            #new_pos_y = int(np.clip(random.uniform(py - dispersal_rate, py + dispersal_rate), 0, GRID_HEIGHT - 1))
            # This is the new, global dispersal
            new_pos_x = random.randint(0, GRID_WIDTH - 1)
            new_pos_y = random.randint(0, GRID_HEIGHT - 1)

            sex = random.choice([SEX_MALE, SEX_FEMALE])
            offspring = Individual(id=next_id, sex=sex, quality=clutch['avg_quality'], position=[new_pos_x, new_pos_y])

            if sex == SEX_FEMALE: female_offspring.append(offspring)
            else: male_offspring.append(offspring)
            next_id += 1

    occupied_territories = set()
    for female in female_offspring:
        if tuple(female.position) in occupied_territories:
            available_spots = [(x, y) for x in range(GRID_WIDTH) for y in range(GRID_HEIGHT) if (x, y) not in occupied_territories]
            if not available_spots: continue
            female.position = list(random.choice(available_spots))

        occupied_territories.add(tuple(female.position))
        next_generation[female.id] = female

    for male in male_offspring:
        next_generation[male.id] = male

    return next_generation

def _males_by_square(males):
    """Group males by the cell they occupy, preserving the order of ``males``."""
    grouped = defaultdict(list)
    for m in males:
        grouped[tuple(m.position)].append(m)
    return grouped


def _win_probability(male, males_on_square):
    """Chance that ``male`` is chosen on a cell currently holding ``males_on_square``.

    ``male`` always counts as a contender, whether or not he is already on that
    cell: this is his prospect *if he moves there*. It uses the same softmax rule
    the female applies in the mating phase.
    """
    contenders = [m for m in males_on_square if m is not male]
    contenders.append(male)
    probs = softmax(np.array([c.quality for c in contenders]), CHOICE_FUZZINESS)
    return float(probs[-1])


def _plan_male_moves(all_males, all_females, interaction_log):
    """Male decision phase: set each male's ``next_position``. Positions are unchanged.

    A male scores every available (non-gravid) female he can see, on his own cell
    or any cell within NEIGHBORHOOD_RADIUS. The score is
    ``female.quality * P(he wins her)``, and he targets the best one. With no
    prospect he stays. Once he has stayed more than STUCK_THRESHOLD days, he jumps
    to a random visible cell, unless he shares the cell of a gravid female he has
    mated with (i.e. he is providing care).
    """
    occupants = _males_by_square(all_males)  # positions are fixed during this phase
    for male in all_males:
        here = tuple(male.position)
        neighbors = find_neighbors(male.position, NEIGHBORHOOD_RADIUS)
        visible = {tuple(cell) for cell in neighbors}
        visible.add(here)

        best_female, best_score = None, 0.0
        for female in all_females:
            square = tuple(female.position)
            if female.is_gravid_days != 0 or square not in visible:
                continue
            score = female.quality * _win_probability(male, occupants.get(square, []))
            if score > best_score:  # strict: the first female wins ties
                best_female, best_score = female, score

        # Copy positions so a male's plan never aliases another individual's list.
        if best_female is not None:
            male.next_position = list(best_female.position)
        else:
            male.next_position = list(male.position)

        if male.next_position == male.position:
            male.days_stationary += 1
        else:
            male.days_stationary = 0

        if male.days_stationary > STUCK_THRESHOLD:
            female_in_square = next((f for f in all_females if f.position == male.position), None)
            is_providing_care = (
                female_in_square is not None
                and female_in_square.is_gravid_days > 0
                and (male.id, female_in_square.id) in interaction_log
            )
            if not is_providing_care:
                if neighbors:
                    male.next_position = random.choice(neighbors)
                male.days_stationary = 0


def _mate_available_females(all_males, all_females, zygotebook, interaction_log):
    """Mating phase: each non-gravid female sharing a cell with males picks one.

    The mate is drawn with softmax(male quality) weights. Each mating starts a
    GRAVID_PERIOD countdown, appends a clutch to ``zygotebook``, and creates or
    refreshes the pair's entry in ``interaction_log``.
    """
    occupants = _males_by_square(all_males)
    for female in all_females:
        if female.is_gravid_days > 0:
            continue
        males_in_territory = occupants.get(tuple(female.position), [])
        if not males_in_territory:
            continue

        mating_probs = softmax(np.array([m.quality for m in males_in_territory]), CHOICE_FUZZINESS)
        chosen_male = random.choices(males_in_territory, weights=mating_probs, k=1)[0]
        chosen_male.days_stationary = 0

        female.is_gravid_days = GRAVID_PERIOD
        num_zygotes = female.quality * FECUNDITY_MULTIPLIER
        zygote_quality = np.random.normal((female.quality + chosen_male.quality) / 2, 0.1)
        zygotebook.append({
            "mother_id": female.id, "father_id": chosen_male.id, "num_zygotes": num_zygotes,
            "avg_quality": np.clip(zygote_quality, 0, 1), "hatch_pos": list(female.position),
            "care_successful": False,
        })

        pair_key = (chosen_male.id, female.id)
        entry = interaction_log.get(pair_key)
        if entry is None:
            interaction_log[pair_key] = {'male_q': chosen_male.quality, 'female_q': female.quality,
                                         'rounds_since_mating': 1, 'mating_count': 1}
        else:
            entry['mating_count'] += 1
            entry['rounds_since_mating'] = 1


def _advance_female_clocks(all_females):
    """Count down gravid periods.

    A female whose countdown ends teleports to a random cell not held by another
    female. This is the "female teleportation" meant to separate mate guarding
    from parental care.
    """
    current_female_positions = {tuple(f.position) for f in all_females}
    for female in all_females:
        if female.is_gravid_days == 1:
            female.is_gravid_days = 0  # available again from now on
            current_female_positions.discard(tuple(female.position))
            while True:
                new_pos_tuple = (random.randint(0, GRID_WIDTH - 1), random.randint(0, GRID_HEIGHT - 1))
                if new_pos_tuple not in current_female_positions:
                    female.position = list(new_pos_tuple)
                    current_female_positions.add(new_pos_tuple)
                    break
        elif female.is_gravid_days > 1:
            female.is_gravid_days -= 1


def _update_care_counters(interaction_log, population):
    """Add one day of care to every mated pair whose male shares the female's cell."""
    # Every log entry is created together with its clutch, so no clutch lookup
    # is needed to confirm the pair actually mated.
    for (male_id, female_id), entry in interaction_log.items():
        male = population.get(male_id)
        female = population.get(female_id)
        if male is not None and female is not None and male.position == female.position:
            entry['rounds_since_mating'] += 1


def run_breeding_season(population: Dict[int, Individual], days: int, dispersal_rate: float):
    """Simulate one breeding season, then build the next generation.

    Each day runs five steps:
      1. Males plan moves.
      2. Males move.
      3. Available females mate with a male on their cell.
      4. Gravid countdowns advance. A female whose countdown ends teleports to a
         fresh cell ("female teleportation", to separate mate guarding from
         parental care).
      5. Co-located mated pairs accrue a day of care.

    At season end, clutches whose pair accrued at least GRAVID_PERIOD care days
    get the BENEFIT_PARENTAL_CARE bonus. Each parent's fitness is reset and then
    credited with the zygotes of its clutches.

    Args:
        population: id -> Individual. Mutated in place (positions, gravid days,
            days_stationary, fitness).
        days: number of simulated days.
        dispersal_rate: forwarded to ``create_next_generation``.

    Returns:
        ``(next_gen_pop, interaction_log, population, raw_offspring_count)``, where
        ``raw_offspring_count`` is the zygote total before care bonuses.
    """
    interaction_log = {}
    zygotebook = []
    all_males = [ind for ind in population.values() if ind.sex == SEX_MALE]
    all_females = [ind for ind in population.values() if ind.sex == SEX_FEMALE]

    for _ in range(days):
        _plan_male_moves(all_males, all_females, interaction_log)
        for male in all_males:
            male.position = list(male.next_position)
        _mate_available_females(all_males, all_females, zygotebook, interaction_log)
        _advance_female_clocks(all_females)
        _update_care_counters(interaction_log, population)

    # End of season calculations.
    raw_offspring_count = sum(c['num_zygotes'] for c in zygotebook)

    # NOTE(review): care is tracked per pair. Every clutch of a pair that mated
    # more than once is judged by the counter of the pair's most recent mating.
    for clutch in zygotebook:
        entry = interaction_log.get((clutch['father_id'], clutch['mother_id']))
        if entry is not None and entry['rounds_since_mating'] >= GRAVID_PERIOD:
            clutch['care_successful'] = True
            clutch['num_zygotes'] += clutch['num_zygotes'] * (BENEFIT_PARENTAL_CARE - 1)

    for ind in population.values():
        ind.fitness = 0
    for clutch in zygotebook:
        if clutch['mother_id'] in population and clutch['father_id'] in population:
            population[clutch['mother_id']].fitness += clutch['num_zygotes']
            population[clutch['father_id']].fitness += clutch['num_zygotes']

    next_gen_pop = create_next_generation(zygotebook, INITIAL_POPULATION, dispersal_rate)

    return next_gen_pop, interaction_log, population, raw_offspring_count


# 4. ANALYSIS AND PLOTTING FUNCTIONS
def analyze_and_print_stats(population, gen_num):
    """Print mean fitness of low- vs high-quality males for one generation.

    Males at or above the HIGH_QUALITY_PERCENTILE of male quality are "high
    quality"; the rest are "low quality". An empty group reports 0. Nothing is
    printed when there are no males.
    """
    males = [ind for ind in population.values() if ind.sex == SEX_MALE]
    if not males:
        return
    cutoff = np.percentile([m.quality for m in males], HIGH_QUALITY_PERCENTILE)
    low_fitness = [m.fitness for m in males if m.quality < cutoff]
    high_fitness = [m.fitness for m in males if m.quality >= cutoff]
    avg_low_q_fitness = np.mean(low_fitness) if low_fitness else 0
    avg_high_q_fitness = np.mean(high_fitness) if high_fitness else 0

    print(f"--- Generation {gen_num + 1} Stats ---")
    print(f"Average Fitness (Low Quality Males): {avg_low_q_fitness:.2f}")
    print(f"Average Fitness (High Quality Males): {avg_high_q_fitness:.2f}")
    print("-" * 20)

def analyze_parental_care(all_logs, all_populations):
    """Average ``rounds_since_mating`` per generation, split by male quality class.

    A pair's male counts as low quality when his quality at mating is below the
    HIGH_QUALITY_PERCENTILE of that generation's males. Returns two equal-length
    lists ``(low_q_times, high_q_times)``. A 0 stands in where a generation has
    no males or no pairs in that class.
    """
    low_q_times, high_q_times = [], []
    for gen_num, interaction_log in enumerate(all_logs):
        population = all_populations[gen_num]
        males = [ind for ind in population.values() if ind.sex == SEX_MALE]
        if not males:
            low_q_times.append(0)
            high_q_times.append(0)
            continue

        threshold = np.percentile([m.quality for m in males], HIGH_QUALITY_PERCENTILE)
        low_bucket, high_bucket = [], []
        for entry in interaction_log.values():
            bucket = low_bucket if entry['male_q'] < threshold else high_bucket
            bucket.append(entry['rounds_since_mating'])

        low_q_times.append(np.mean(low_bucket) if low_bucket else 0)
        high_q_times.append(np.mean(high_bucket) if high_bucket else 0)
    return low_q_times, high_q_times

def plot_parental_care(low_quality_times, high_quality_times):
    """Plot per-generation care durations with linear trend lines, then run a paired t-test.

    Does nothing when there are fewer than two generations. The two series must
    have equal length (plotting fails otherwise).
    """
    n_gens = len(low_quality_times)
    if n_gens < 2:
        return
    gens = np.arange(1, n_gens + 1)

    plt.figure(figsize=(12, 7))
    data_series = (
        (low_quality_times, 'b', 'Low Quality Males'),
        (high_quality_times, 'r', 'High Quality Males'),
    )
    for values, color, label in data_series:
        plt.plot(gens, values, marker='o', linestyle='-', color=color, label=label)

    trend_series = (
        (low_quality_times, 'cornflowerblue'),
        (high_quality_times, 'lightcoral'),
    )
    for values, trend_color in trend_series:
        fit = linregress(gens, values)
        plt.plot(gens, fit.slope * gens + fit.intercept, '--', color=trend_color)

    plt.title('Average Parental Care Duration Over Generations')
    plt.xlabel('Generation')
    plt.ylabel('Average Time Spent with Partner (Days)')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.tight_layout()
    plt.show()

    # Both series have at least two entries here, so the paired test is defined.
    t_stat, p_value = ttest_rel(low_quality_times, high_quality_times)
    print("\n--- Paired T-test Results for Parental Care ---")
    print(f"T-statistic: {t_stat:.4f}")
    print(f"P-value: {p_value:.4f}")
    if p_value < 0.05:
        print("The difference in parental care is statistically significant.")
    else:
        print("The difference in parental care is not statistically significant.")

def plot_grid_heatmap(population, title):
    """Show side-by-side heatmaps of male and female counts per grid cell.

    The count grids are indexed ``[y, x]`` (rows are y), so they must have shape
    ``(GRID_HEIGHT, GRID_WIDTH)``. This matters on non-square grids. Does nothing
    for an empty population.
    """
    if not population:
        return
    male_grid = np.zeros((GRID_HEIGHT, GRID_WIDTH))
    female_grid = np.zeros((GRID_HEIGHT, GRID_WIDTH))
    for ind in population.values():
        x, y = ind.position
        if ind.sex == SEX_MALE:
            male_grid[y, x] += 1
        else:
            female_grid[y, x] += 1

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
    fig.suptitle(title, fontsize=16)
    sns.heatmap(male_grid, ax=ax1, cmap="Blues", linewidths=.5, linecolor='lightgray', cbar_kws={'label': 'Number of Males'})
    ax1.set_title('Male Distribution'); ax1.set_aspect('equal')
    sns.heatmap(female_grid, ax=ax2, cmap="Reds", linewidths=.5, linecolor='lightgray', cbar_kws={'label': 'Number of Females'})
    ax2.set_title('Female Distribution'); ax2.set_aspect('equal')
    plt.tight_layout(rect=[0, 0, 1, 0.96]); plt.show()

def plot_male_quality_evolution(all_populations):
    """Plot mean male quality per generation with a one-standard-deviation band.

    A generation with no males is plotted as mean 0, standard deviation 0.
    """
    per_gen = []
    for pop in all_populations:
        qualities = [ind.quality for ind in pop.values() if ind.sex == SEX_MALE]
        per_gen.append((np.mean(qualities), np.std(qualities)) if qualities else (0, 0))

    mean_qualities = np.array([mean for mean, _ in per_gen])
    std_devs = np.array([spread for _, spread in per_gen])
    generations_axis = np.arange(1, len(per_gen) + 1)

    plt.figure(figsize=(12, 7))
    plt.plot(generations_axis, mean_qualities, color='green', label='Mean Male Quality')
    plt.fill_between(generations_axis, mean_qualities - std_devs, mean_qualities + std_devs,
                     color='green', alpha=0.2, label='1 Std. Dev.')
    plt.title('Evolution of Male Quality Over Generations')
    plt.xlabel('Generation')
    plt.ylabel('Male Quality')
    plt.ylim(0, 1)
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.show()

def plot_fitness_quality_evolution(all_populations):
    """Plot, per generation, the slope of male fitness regressed on male quality.

    A positive slope means higher-quality males gained more offspring. The slope
    is undefined in two cases:
      - fewer than two males;
      - every male has the same quality. This is possible because siblings share
        quality and values are clipped to [0, 1], and ``linregress`` raises
        ValueError for it.
    Both cases are plotted as 0.
    """
    slopes = []
    for pop in all_populations:
        males = [ind for ind in pop.values() if ind.sex == SEX_MALE]
        qualities = [m.quality for m in males]
        # Short-circuit keeps min()/max() away from an empty list.
        if len(males) < 2 or min(qualities) == max(qualities):
            slopes.append(0)
            continue
        fitnesses = [m.fitness for m in males]
        slopes.append(linregress(qualities, fitnesses).slope)

    plt.figure(figsize=(12, 7))
    plt.plot(range(1, len(slopes) + 1), slopes, marker='.', linestyle='-', color='darkorange')
    plt.axhline(0, color='grey', linestyle='--', lw=1)
    plt.title('Evolution of Selection on Male Quality')
    plt.xlabel('Generation')
    plt.ylabel('Slope of Fitness vs. Quality Regression')
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.show()

def plot_monogamy_evolution(all_logs, all_populations):
    """Plot, per generation, the slope of male monogamy index regressed on quality.

    A male's monogamy index is (matings with his most frequent partner) divided
    by (all his matings), so 1.0 means every mating was with one female. Only
    males present in that generation's population are included. The slope is
    undefined in two cases:
      - fewer than two males;
      - all of them share one quality value (``linregress`` raises ValueError
        for it).
    Both cases are plotted as 0.
    """
    slopes = []
    for gen_num, interaction_log in enumerate(all_logs):
        population = all_populations[gen_num]
        male_mating_data = defaultdict(lambda: {'total_matings': 0, 'partner_counts': defaultdict(int)})
        for (male_id, female_id), log_entry in interaction_log.items():
            count = log_entry['mating_count']
            male_mating_data[male_id]['total_matings'] += count
            male_mating_data[male_id]['partner_counts'][female_id] += count

        gen_results = []
        for male_id, data in male_mating_data.items():
            if data['total_matings'] > 0 and male_id in population:
                monogamy_index = max(data['partner_counts'].values()) / data['total_matings']
                gen_results.append({'quality': population[male_id].quality, 'index': monogamy_index})

        qualities = [d['quality'] for d in gen_results]
        # Short-circuit keeps min()/max() away from an empty list.
        if len(gen_results) < 2 or min(qualities) == max(qualities):
            slopes.append(0)
            continue
        indices = [d['index'] for d in gen_results]
        slopes.append(linregress(qualities, indices).slope)

    plt.figure(figsize=(12, 7))
    plt.plot(range(1, len(slopes) + 1), slopes, marker='.', linestyle='-', color='crimson')
    plt.axhline(0, color='grey', linestyle='--', lw=1)
    plt.title('Evolution of Monogamy Strategy vs. Quality')
    plt.xlabel('Generation')
    plt.ylabel('Slope of Monogamy Index vs. Quality Regression')
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.show()

# 5. Runner
def run_full_simulation(num_generations=50, breeding_season_length=60, dispersal_rate=8.0):
    """Run the model for ``num_generations`` breeding seasons and report the results.

    Plots the initial and (if any survive) final spatial distributions, prints
    per-generation male fitness statistics, then shows the cross-generation care,
    quality, selection and monogamy analyses. Stops early, with a message, if the
    population becomes empty.
    """
    all_final_populations, all_interaction_logs = [], []
    population = initialize_population(INITIAL_POPULATION)

    for gen in range(num_generations):
        if not population:
            print("Population died out. Ending simulation.")
            break
        print(f"Running Generation {gen + 1}/{num_generations}...")
        if gen == 0:
            plot_grid_heatmap(population, f"Initial Population Distribution (Generation {gen + 1})")

        # The fourth value (raw offspring count before care bonuses) is not used here.
        next_gen_pop, interaction_log, final_pop_state, _ = run_breeding_season(
            population, breeding_season_length, dispersal_rate)

        analyze_and_print_stats(final_pop_state, gen)

        all_final_populations.append(final_pop_state)
        all_interaction_logs.append(interaction_log)

        population = next_gen_pop

    if population:
        plot_grid_heatmap(population, f"Final Population Distribution (Start of Gen {num_generations + 1})")

    print("\n--- Final Analysis Across All Generations ---")
    low_q_times, high_q_times = analyze_parental_care(all_interaction_logs, all_final_populations)
    plot_parental_care(low_q_times, high_q_times)
    plot_male_quality_evolution(all_final_populations)
    plot_fitness_quality_evolution(all_final_populations)
    plot_monogamy_evolution(all_interaction_logs, all_final_populations)


if __name__ == '__main__':
    # Long run: 300 generations with a dispersal rate of 20.
    run_full_simulation(dispersal_rate=20.0, num_generations=300)