---

In [None]:
import json          
import pandas as pd
import matplotlib.pyplot as plt
import sqlite3
import numpy as np
import sys
from pathlib import Path
from rich.console import Console

import warnings
import canonical_toolkit as ctk

In [None]:
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Optional explicit path to a run's output folder (the folder holding database.db).
# When set, the `ea/` package folder is assumed to be two levels above it.
# Leave as None to auto-detect: the last `output_folder` in ./run_history.csv,
# otherwise the current working directory (with config.py searched for upwards).
DATA_FOLDER = None
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In [None]:
warnings.filterwarnings("ignore", message="n_jobs value.*overridden.*")
console = Console()

# Resolve `data_path` (the run's output folder) and `ea_folder` (the `ea/` package dir).

# 1. An explicit DATA_FOLDER wins.
if DATA_FOLDER:
    data_path = Path(DATA_FOLDER)
    ea_folder = data_path.parent.parent  # __data__/run_xxx -> ea/

# 2. Otherwise use the most recent run recorded in ./run_history.csv.
elif (Path.cwd() / "run_history.csv").exists():
    df = pd.read_csv("run_history.csv")
    data_path = Path(df["output_folder"].iloc[-1])
    ea_folder = Path.cwd()

# 3. Otherwise assume the notebook is inside the output folder and walk up to config.py.
else:
    ea_folder = Path.cwd()
    while ea_folder != ea_folder.parent:
        if (ea_folder / "config.py").exists():
            break
        ea_folder = ea_folder.parent
    else:
        raise FileNotFoundError("Could not find ea/config.py in any parent directory")
    data_path = Path.cwd()

# Make the `ea` package importable from its parent directory.
sys.path.insert(0, str(ea_folder.parent))
from ea.config import Config

print(data_path)
db_path = data_path / "database.db"
# sqlite3.connect() silently creates an empty database when the file is missing,
# which would then fail with a confusing "no such table: individual" error and
# leave a stray empty database.db behind. Fail early with a clear message instead.
if not db_path.is_file():
    raise FileNotFoundError(f"No database.db found in {data_path}")
# Close the connection explicitly: sqlite3's context manager only commits, it does not close.
conn = sqlite3.connect(db_path)
try:
    data = pd.read_sql("SELECT * FROM individual", conn)
finally:
    conn.close()
config = Config.load(data_path)
config.large_description()

In [None]:
# Expand the JSON-encoded `tags_` column into one column per tag key.
tags_expanded = (
    data['tags_']
    .apply(lambda x: json.loads(x) if isinstance(x, str) else x)
    .apply(pd.Series)
)
# Drop tag columns merged by a previous run of this cell before concatenating,
# so re-running the cell does not append duplicate columns to `data`.
data = pd.concat(
    [data.drop(columns=tags_expanded.columns, errors='ignore'), tags_expanded],
    axis=1,
)

# Every generation each individual was alive in (birth..death, inclusive).
data['gen'] = [
    list(range(int(birth), int(death) + 1))
    for birth, death in zip(data['time_of_birth'], data['time_of_death'])
]

# One row per (individual, generation it was alive in).
gen_df = (data
    .explode('gen')
    # explode() yields an object column, and an empty range (death < birth)
    # becomes a NaN row that belongs to no generation: drop it and cast to int.
    .dropna(subset=['gen'])
    .astype({'gen': int})
    .rename(columns={'fitness_': 'fitness', 'genotype_': 'genotype', 'tags_': 'tags'})
    .sort_values(['gen', 'ctk_string'], ascending=[True, True])
)

# Rank within each generation (order follows the ctk_string sort above).
gen_df['rank'] = gen_df.groupby('gen').cumcount()
gen_df = gen_df.set_index(['gen', 'rank'])
gen_df.head()

In [None]:
# Rows whose individual is still alive after generation `gen` (it dies in a later generation).
survivors_df = (
    gen_df
    .reset_index()
    .loc[lambda frame: frame['time_of_death'] > frame['gen']]
    .set_index(['gen', 'rank'])
)
survivors_df.head()

In [None]:
# Rows recording the generation in which each individual died.
killed_df = (
    gen_df
    .reset_index()
    .loc[lambda frame: frame['time_of_death'] == frame['gen']]
    .set_index(['gen', 'rank'])
)
killed_df.head()

In [None]:
def plot_fittest_robots(gen_df, by, amt, config):
    """Render the `amt` best robots of the last generation, ranked by `by`.

    Args:
        gen_df: Per-generation frame, either with 'gen' as a column or indexed
            by (gen, rank) like the `gen_df` built earlier in this notebook.
            Needs 'id', 'ctk_string' and `by` columns.
        by: Metric column to rank by (e.g. 'fitness').
        amt: Number of robots to show.
        config: Object exposing IS_MAXIMISATION, which sets the sort direction.

    Returns:
        A ctk.GridPlotter holding a single row of robot images.
    """
    # The notebook's gen_df keeps 'gen' as an index level, where gen_df['gen']
    # would raise KeyError; flatten it like the other plotting helpers do.
    df = gen_df.reset_index() if 'gen' in gen_df.index.names else gen_df
    last_gen = df['gen'].max()

    # Descending when maximising (ascending otherwise) so the fittest come first.
    fittest_df = (df[df['gen'] == last_gen]
                  .sort_values(by=by, ascending=not config.IS_MAXIMISATION)
                  .head(amt))

    images, titles = [], []
    for row in fittest_df.itertuples():
        images.append(ctk.quick_view(
            ctk.node_from_string(row.ctk_string).to_graph(),
            return_img=True,
            white_background=True,
        ))
        titles.append(f"ID {row.id} | {by}={getattr(row, by):.3f}")

    plotter = ctk.GridPlotter()
    plotter.config.title_size = 10
    plotter.config.margin = (0.3, 0, 0, 0)
    plotter.config.col_space = 0.23
    plotter.config.dpi = 300

    # Wrap both lists so the plotter receives a single-row [rows][cols] structure.
    plotter.add_2D_image_data([images], titles_2d=[titles])
    plotter.suptitle(f"Top {amt} Fittest Robots (Gen {last_gen}) by {by.capitalize()}", font_size=8)

    return plotter

---

### Per-generation fitness and stored metrics (novelty, speed)

In [None]:
def plot_metrics(gen_df, metrics=('fitness',), is_max=True):
    """Plot the per-generation mean ± std and best value of each metric, one subplot per metric.

    Args:
        gen_df: DataFrame with a 'gen' index level (e.g. ``survivors_df``).
        metrics: Column names to plot. Defaults to ``('fitness',)``.
        is_max: If True the "Best" curve is the per-generation max, otherwise the min.

    Raises:
        ValueError: if `metrics` is empty.
    """
    metrics = list(metrics)
    if not metrics:
        raise ValueError("plot_metrics needs at least one metric column")

    gen_grouped = gen_df.groupby(level='gen')
    best_fn = 'max' if is_max else 'min'

    # squeeze=False keeps `axes` 2-D even for a single metric; with the default,
    # plt.subplots(1, 1) returns a bare Axes, which zip() and axes[-1] cannot handle.
    fig, axes = plt.subplots(len(metrics), 1, figsize=(10, 3 * len(metrics)),
                             sharex=True, squeeze=False)
    axes = axes[:, 0]

    for ax, key in zip(axes, metrics):
        stats = gen_grouped[key].agg(['mean', 'std', best_fn])
        # Plot against the actual generation labels so gaps are not collapsed.
        x = stats.index.to_numpy(dtype=float)
        mean, std, best = (stats[col].to_numpy() for col in ('mean', 'std', best_fn))

        ax.plot(x, mean, 'b-', lw=2, label='Mean')
        ax.fill_between(x, mean - std, mean + std, alpha=0.2, color='blue')
        ax.plot(x, best, 'g--', lw=1.5, label='Best')
        ax.set_ylabel(key.capitalize())
        ax.legend(loc='upper right')

    axes[-1].set_xlabel('Generation')
    fig.tight_layout()
    fig.dpi = 300
    plt.show()
                

In [None]:
# Fitness is always plotted; novelty/speed only when the run stored them.
optional_metrics = {'novelty': config.STORE_NOVELTY, 'speed': config.STORE_SPEED}
to_plot = ['fitness'] + [name for name, enabled in optional_metrics.items() if enabled]

plot_metrics(survivors_df, metrics=to_plot, is_max=config.IS_MAXIMISATION)

### Lifespans of the Fittest Individuals

In [None]:
import matplotlib.pyplot as plt                                                                                                                   
import matplotlib.cm as cm                                                                                                                        
import numpy as np                                                                                                                                
                    

def plot_top_lifespans(gen_df,*, is_maximalisation: bool = True, column='fitness', top_x=5, title=None):
    """Plot the lifespan of every individual that was ever in the top `top_x` of a generation.

    Each such individual is drawn as a horizontal line from the first to the
    last generation it appears in, at the height of its first `column` value,
    and labelled with its id. Red squares mark every generation in which it
    ranked in the top `top_x`.

    Args:
        gen_df: Per-generation frame; after ``reset_index()`` it must have
            'gen', 'id' and `column` columns.
        is_maximalisation: If True, higher `column` values rank higher.
        column: Metric used for ranking and as the y-value.
        top_x: How many individuals per generation count as "top".
        title: Optional plot title; a default title is generated otherwise.

    Returns:
        (fig, ax) of the created figure.
    """
    df = gen_df.reset_index()

    # Sort the whole frame once, then keep the first `top_x` rows of each generation.
    top_per_gen = (df
        .sort_values([column], ascending=not is_maximalisation)
        .groupby('gen')
        .head(top_x)
    )
    top_individuals = top_per_gen['id'].unique()

    # First `column` value and first/last generation seen, per top individual.
    lifespan_data = (df[df['id'].isin(top_individuals)]
        .groupby('id')
        .agg({
            column: 'first',
            'gen': ['min', 'max']
        })
    )
    lifespan_data.columns = [column, 'birth', 'death']
    lifespan_data = lifespan_data.sort_values(column, ascending=not is_maximalisation)

    fig, ax = plt.subplots(figsize=(20, 8))
    colors = cm.viridis(np.linspace(0, 1, len(lifespan_data)))

    for color, (ind_id, row) in zip(colors, lifespan_data.iterrows()):
        # Horizontal line from birth to death, labelled with the individual's id.
        ax.hlines(y=row[column], xmin=row['birth'], xmax=row['death'],
                  color=color, linewidth=2, alpha=0.7)
        ax.annotate(f'{int(ind_id)}', (row['death'] + 0.2, row[column]), fontsize=7, va='center')

    # Mark every (generation, individual) top-X membership with one vectorised
    # scatter call instead of a per-generation filter plus one call per point.
    # y is the height of that individual's lifespan line.
    ax.scatter(
        top_per_gen['gen'].to_numpy(dtype=float),
        lifespan_data.loc[top_per_gen['id'], column].to_numpy(dtype=float),
        color='red', s=20, marker='s', alpha=0.5, zorder=4,
    )

    ax.set_xlabel('Generation')
    ax.set_ylabel(column)
    ax.set_title(title or f'Individuals Ever in Top {top_x} (in top {top_x} that gen)')
    ax.grid(True, alpha=0.3)

    from matplotlib.lines import Line2D
    legend_elements = [
        Line2D([0], [0], marker='s', color='red', label=f'In top {top_x}', markersize=8, linestyle='', alpha=0.5),
    ]
    ax.legend(handles=legend_elements, loc='lower right')
    fig.dpi = 300
    plt.tight_layout()
    return fig, ax

In [None]:
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Metric to rank by, and how many individuals per generation count as "top".
column, top_x = 'fitness', 10
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In [None]:
# Trailing ';' suppresses the echoed (fig, ax) tuple repr; the figure itself is still displayed.
plot_top_lifespans(survivors_df, is_maximalisation=config.IS_MAXIMISATION, column=column, top_x=top_x);

In [None]:
import matplotlib.pyplot as plt                                                                                                                   
                                                                                                                             
def plot_lifespan_analysis(gen_df, fitness_xlim=None, column='fitness'):
    """Histogram of individual lifespans and a scatter of `column` against lifespan.

    Lifespan is the last generation an individual appears in minus the first,
    so an individual present in a single generation has lifespan 0.

    Args:
        gen_df: Per-generation frame with 'gen' as an index level or column,
            plus 'id' and `column` columns.
        fitness_xlim: Optional (min, max) x-limits for the scatter plot.
        column: Per-individual value on the scatter's x-axis; the first value
            seen for each individual is used. Defaults to 'fitness'.
    """
    temp_df = gen_df.reset_index()
    # 'gen' can be object dtype (e.g. DataFrame.explode output); make it numeric.
    temp_df['gen'] = pd.to_numeric(temp_df['gen'])

    first_gen, max_gen = temp_df['gen'].min(), temp_df['gen'].max()
    # Longest lifespan possible in this data (equals max_gen when the run starts at gen 0).
    max_span = max_gen - first_gen

    individuals = temp_df.groupby('id').agg({
        column: 'first',
        'gen': ['min', 'max']
    })
    individuals.columns = [column, 'birth', 'death']
    individuals['lifespan'] = individuals['death'] - individuals['birth']

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # --- Plot 1: Lifespan distribution (one integer-aligned bin per lifespan) ---
    ax1.hist(individuals['lifespan'], bins=range(int(max_span) + 2),
             edgecolor='black', color='skyblue', alpha=0.7)
    ax1.set_xlabel('Lifespan (generations)')
    ax1.set_ylabel('Count (Number of Individuals)')
    ax1.set_title('Distribution of Individual Longevity')

    # --- Plot 2: metric vs lifespan ---
    ax2.scatter(individuals[column], individuals['lifespan'],
                alpha=0.4, s=20, label='Individual Data')

    # Reference lines at the shortest and longest possible lifespans.
    ax2.axhline(y=0, color='red', linestyle='--', linewidth=1,
                label='Lifespan 0 (single generation)', zorder=0)
    ax2.axhline(y=max_span, color='green', linestyle='--', linewidth=1,
                label=f'Lifespan {max_span} (gen {first_gen} to {max_gen})', zorder=0)

    ax2.set_xlabel(column.capitalize())
    ax2.set_ylabel('Lifespan (last gen - first gen)')
    ax2.set_title(f'{column.capitalize()} vs Lifespan')
    ax2.legend(loc='upper left', frameon=True)

    if fitness_xlim:
        ax2.set_xlim(fitness_xlim)

    plt.tight_layout()
    plt.show()


In [None]:
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# x-axis limits (min, max) for the fitness-vs-lifespan scatter.
fitness_lim = (0, 1)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In [None]:
plot_lifespan_analysis(gen_df, fitness_lim)

In [None]:
def high_res_robot_gens(gen_df, config, top_n=3, n_samples=5, by='fitness') -> ctk.GridPlotter:
    """
    Plot top/bottom robots across sampled generations based on a specific metric.

    Args:
        gen_df: DataFrame with MultiIndex (gen, rank)
        config: Config object with IS_MAXIMISATION, NUM_GENERATIONS
        top_n: Positive for best, negative for worst
        n_samples: Number of generations to sample (evenly spaced; duplicates dropped)
        by: The column name to sort and display (e.g., 'fitness', 'speed', 'novelty')

    Returns:
        GridPlotter object with one row per rank and one column per sampled generation
    """
    # Sample only generations that exist in the data: NUM_GENERATIONS may point one
    # past the last recorded generation, which would make .loc[gen] raise KeyError.
    gen_values = gen_df.index.get_level_values('gen')
    first_gen, last_gen = int(gen_values.min()), int(gen_values.max())
    upper = min(int(config.NUM_GENERATIONS), last_gen)
    # np.unique removes repeats that linspace produces when n_samples exceeds the range.
    generations = np.unique(np.linspace(first_gen, upper, n_samples, dtype=int))
    n = abs(top_n)

    data_2d = [[] for _ in range(n)]
    titles_2d = [[] for _ in range(n)]

    # Within each generation, order from best to worst: descending when
    # maximising, ascending otherwise. head() is therefore always the best.
    df_sorted = gen_df.sort_values(
        by=['gen', by],
        ascending=[True, not config.IS_MAXIMISATION]
    )

    for gen in generations:
        gen_data = df_sorted.loc[gen]
        selection = gen_data.head(n) if top_n > 0 else gen_data.tail(n)

        for j, row in enumerate(selection.itertuples()):
            img = ctk.quick_view(
                ctk.node_from_string(row.ctk_string).to_graph(),
                return_img=True,
                white_background=True
            )
            data_2d[j].append(img)

            # Dynamically get the value for the chosen metric
            metric_val = getattr(row, by)
            titles_2d[j].append(f"Gen {gen} | ID {row.id} | {by}={metric_val:.3f}")

    plotter = ctk.GridPlotter()
    plotter.config.title_size = 10
    plotter.config.margin = (0.3, 0, 0, 0)
    plotter.config.col_space = 0.23
    plotter.config.dpi = 300
    plotter.add_2D_image_data(data_2d, titles_2d=titles_2d)

    # The sort direction already accounts for IS_MAXIMISATION, so head() (top_n > 0)
    # is the best in either mode; comparing against IS_MAXIMISATION again mislabelled
    # minimisation runs.
    label = "Best" if top_n > 0 else "Worst"
    plotter.suptitle(f"{label} {n} Robots by {by.capitalize()} Across Generations", font_size=8)

    return plotter

In [None]:
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# amt: robots per sampled generation (positive = best, negative = worst)
# gen_samples: number of evenly spaced generations to show
# col_name: metric column to rank by, e.g. 'fitness', 'novelty' or 'speed'
amt, gen_samples, col_name = 3, 5, 'fitness'
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In [None]:
plotter = high_res_robot_gens(
    gen_df,
    config,
    top_n=amt,
    n_samples=gen_samples,
    by=col_name,
)
plotter.show()