## IMPORTS

In [None]:

import warnings
import logging

# Suppress Bokeh warnings about missing renderers
warnings.filterwarnings('ignore', message='.*MISSING_RENDERERS.*')
warnings.filterwarnings('ignore', category=UserWarning, module='bokeh')

# Suppress Bokeh logger warnings
logging.getLogger('bokeh').setLevel(logging.ERROR)

import base64
import json
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List
# Import the functions from view_mujoco
from ariel_experiments.gui_vis.view_mujoco import (
    load_or_generate_cache,  # Convenience function
)

from sklearn.feature_extraction import FeatureHasher

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import umap
import umap.plot
from bokeh.io import output_notebook
from bokeh.layouts import gridplot
from bokeh.models import (
    BasicTicker,
    ColorBar,
    ColumnDataSource,
    HoverTool,
    LinearColorMapper,
)
from bokeh.plotting import figure, show
from matplotlib.colors import to_hex
from PIL import Image
from rich.console import Console
from rich.progress import track
from scipy.spatial.distance import pdist, squareform
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.neighbors import KernelDensity
from sqlalchemy import create_engine, text

import canonical_toolkit as ctk
from ariel_experiments.utils.io_canon_pop import load_ctk_strings_from_database

# from ariel_experiments.gui_vis.view_mujoco import view
from ariel_experiments.utils.initialize import generate_random_individual
from dataclasses import dataclass
from PIL import ImageDraw, ImageFont
from bokeh.models import Div
from bokeh.layouts import column

# Shared rich console used for status / debug output in the cells below.
console = Console()
# Configure Bokeh so that later `show(...)` calls render inline in the notebook.
output_notebook()


#### FUNCTION CONFIG CLASSES

In [None]:
@dataclass
class RobotSubplot:
    """Configuration for one cell of the static robot-image grid.

    Consumed by ``plot_robot_grid``: the cached images of ``idxs`` are
    stitched side by side and the cell is titled with ``title`` followed by
    ``under_title`` on the next line.
    """

    title: str
    under_title: str     # second title line, placed on a new line under `title`
    idxs: list[int]      # robot indices to draw, left to right, in this order
    
    # Optional captions drawn beneath each image; matched to `idxs` by position.
    img_under_title: list[str] | None = None
    img_under_title_fontsize: int = 10

    # Font sizes (matplotlib uses plain numbers, not "10pt")
    # NOTE(review): axis_label_fontsize, tick_fontsize, under_title_fontsize and
    # under_title_fontweight are not read by the plot_robot_grid in this file.
    axis_label_fontsize: int | str = 8     # int or named size like 'small'
    tick_fontsize: int | str = 6
    title_fontsize: int | str = 10
    under_title_fontsize: int | str = 20

    title_fontweight: str = 'normal'       # 'normal' not 'regular'
    under_title_fontweight: str = 'normal'

In [None]:
@dataclass
class EmbedSubplot:
    """Configuration for one scatter plot of the interactive embedding grid.

    Consumed by ``plot_interactive_umap_grid``.
    """

    title: str
    # 2-D coordinates, one (x, y) row per entry of `idxs`. The plotter converts
    # this with np.array and only draws it when the result has shape (n, 2);
    # anything else produces an empty placeholder figure.
    embeddings: list[Any]
    idxs: list[int]  # population indices matching the rows of `embeddings`
    hover_data: list[Any]  # extra per-point text shown in the tooltip (empty -> blank)

    # Dot size settings
    default_dot_size: int = 8  # Size for regular (non-highlighted) dots
    highlight_dot_size: int = 8  # Size for highlighted/followed dots
    # Font size settings
    axis_label_fontsize: str = "8pt"  # Font size for axis labels
    tick_fontsize: str = "6pt"  # Font size for tick labels
    title_fontsize: str = "10pt"  # Font size for subplot title
    hover_fontsize: str = "10px"  # Font size for hover tooltip (use px for HTML)


## FUNCTIONS

#### robot grid plotter

In [None]:

def _stitch_images_with_fixed_spacing(images: list[np.ndarray], target_height: int, gap_px: int = 15, img_under_titles: list[str] | None = None, text_height_px: int = 25, text_fontsize: int = 14) -> np.ndarray:
    """
    Stitch images horizontally with FIXED white gaps.
    Images keep their ORIGINAL widths, only HEIGHT is padded to target_height.
    Optionally adds text under each image (aligned at the bottom of tallest image).
    Handles transparency by compositing onto white background.
    """
    if not images:
        return np.zeros((100, 100, 3), dtype=np.uint8)
    
    # Calculate new target height if we're adding text
    actual_target_height = target_height + text_height_px if img_under_titles else target_height
    
    # Normalize all images to RGB uint8
    normalized_imgs = []
    
    for i, img in enumerate(images):
        # Handle float 0-1
        if np.issubdtype(img.dtype, np.floating):
            img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
        else:
            img = img.astype(np.uint8)
        
        # Handle Alpha channel - composite onto white background
        if len(img.shape) == 3 and img.shape[2] == 4:
            rgb = img[:, :, :3]
            alpha = img[:, :, 3:4] / 255.0
            white_bg = np.full_like(rgb, 255, dtype=np.uint8)
            img = (rgb * alpha + white_bg * (1 - alpha)).astype(np.uint8)
        elif len(img.shape) == 3 and img.shape[2] >= 3:
            img = img[:, :, :3]
        
        normalized_imgs.append(img)
    
    # Process images and add text AFTER padding to target_height
    processed_imgs = []
    
    for i, img in enumerate(normalized_imgs):
        h, w = img.shape[:2]
        
        # First, pad the image to target_height (all images same height)
        if h < target_height:
            pad = np.full((target_height - h, w, 3), 255, dtype=np.uint8)
            img = np.vstack((img, pad))
        
        # Convert to PIL
        img_pil = Image.fromarray(img)
        
        # Now add text under the padded image (so all text aligns at same height)
        if img_under_titles and i < len(img_under_titles):
            # Create canvas with extra space for text
            canvas = Image.new('RGB', (w, target_height + text_height_px), (255, 255, 255))
            canvas.paste(img_pil, (0, 0))
            
            # Draw text at the bottom
            draw = ImageDraw.Draw(canvas)
            text = img_under_titles[i]
            
            # Try to use a default font with custom size, fallback to PIL default
            try:
                font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", text_fontsize)
            except:
                font = ImageFont.load_default()
            
            # Get text bounding box for centering
            bbox = draw.textbbox((0, 0), text, font=font)
            text_w = bbox[2] - bbox[0]
            text_x = (w - text_w) // 2
            
            # Place text right after target_height
            draw.text((text_x, target_height + 5), text, fill=(0, 0, 0), font=font)
            img = np.array(canvas)
        else:
            img = np.array(img_pil)
        
        processed_imgs.append(img)
    
    # Create white gap column with actual target height
    white_gap = np.full((actual_target_height, gap_px, 3), 255, dtype=np.uint8)
    
    # Stitch with FIXED gaps
    stitched = None
    for img in processed_imgs:
        if stitched is None:
            stitched = img
        else:
            stitched = np.hstack((stitched, white_gap, img))
    
    return stitched

In [None]:
def plot_robot_grid(
    sub_plots: list[list[RobotSubplot]],
    cache_dir: str | Path = "__data__/img",
    max_full_width: int = 16, 
    subplot_height: float = 2.5,   
    main_title: str | None = None,
    robot_gap_px: int = 20,    
    dpi: int = 300
):
    """
    Draw a matplotlib grid in which every cell shows a strip of robot images.

    Args:
        sub_plots: 2D list (rows x columns) of RobotSubplot objects. Rows may
            have different lengths; unused cells are left blank.
        cache_dir: Directory holding the cached images, named
            ``robot_<idx, zero-padded to 4 digits>.png``. Missing files are
            replaced by a transparent 100x100 placeholder.
        max_full_width: Figure width in inches.
        subplot_height: Height of one grid row in inches.
        main_title: Optional figure-level title.
        robot_gap_px: White gap (in pixels) between robots inside one cell.
        dpi: Figure resolution.
    """
    n_rows = len(sub_plots)
    n_cols = max((len(row) for row in sub_plots), default=0)
    if n_rows == 0 or n_cols == 0:
        # Nothing to draw; plt.subplots would reject a 0-sized grid.
        return

    cache_path = Path(cache_dir)

    # Step 1: Collect all unique robot indices
    all_robot_idxs = set()
    for row_subplots in sub_plots:
        for subplot in row_subplots:
            all_robot_idxs.update(subplot.idxs)

    # Step 2: Load images (each file is closed again as soon as it is decoded)
    robot_images = {}
    global_max_h = 0

    for robot_idx in all_robot_idxs:
        try:
            with Image.open(cache_path / f"robot_{robot_idx:04d}.png") as img:
                img_array = np.array(img)
        except FileNotFoundError:
            img_array = np.zeros((100, 100, 4), dtype=np.uint8)
        robot_images[robot_idx] = img_array
        global_max_h = max(global_max_h, img_array.shape[0])

    # Step 3: Prepare stitched images and track the largest stitched size.
    # The stitched height can exceed global_max_h when captions are added.
    stitched_images = []
    global_max_stitched_w = 0
    global_max_stitched_h = global_max_h

    for row_subplots in sub_plots:
        row_stitched = []
        for subplot in row_subplots:
            images = [robot_images[robot_idx] for robot_idx in subplot.idxs]

            if images:
                stitched = _stitch_images_with_fixed_spacing(
                    images,
                    global_max_h,
                    robot_gap_px,
                    img_under_titles=subplot.img_under_title,
                    text_fontsize=subplot.img_under_title_fontsize
                )
                row_stitched.append(stitched)
                global_max_stitched_w = max(global_max_stitched_w, stitched.shape[1])
                global_max_stitched_h = max(global_max_stitched_h, stitched.shape[0])
            else:
                row_stitched.append(None)
        stitched_images.append(row_stitched)

    # Step 4: Create grid
    fig, axes = plt.subplots(
        n_rows, n_cols,
        figsize=(max_full_width, subplot_height * n_rows),
        squeeze=False,
        facecolor='white',
        dpi=dpi
    )

    if main_title:
        fig.suptitle(main_title, fontsize=16, fontweight='bold', y=0.9, color='black')

    # Step 5: Plot
    for row_idx, row_subplots in enumerate(sub_plots):
        for col_idx in range(n_cols):
            ax = axes[row_idx, col_idx]
            ax.set_facecolor('white')
            ax.axis('off')

            if col_idx >= len(row_subplots):
                # Short row: leave the remaining cells blank.
                continue

            subplot = row_subplots[col_idx]
            stitched = stitched_images[row_idx][col_idx]

            if stitched is not None:
                actual_height, actual_width = stitched.shape[:2]
                # Centre horizontally; keep one data unit per pixel so the
                # image is not squashed when it carries a caption strip.
                x_offset = (global_max_stitched_w - actual_width) / 2

                ax.imshow(stitched, extent=[
                    x_offset,
                    x_offset + actual_width,
                    actual_height,
                    0
                ])

            # Same limits in every cell so all robots share one scale.
            if global_max_stitched_w > 0 and global_max_stitched_h > 0:
                ax.set_xlim(0, global_max_stitched_w)
                ax.set_ylim(global_max_stitched_h, 0)
            ax.set_aspect('equal') 
            ax.axis('off')

            # Add title and under_title
            title_text = f"{subplot.title}\n{subplot.under_title}"
            ax.set_title(
                title_text,
                fontsize=subplot.title_fontsize,
                fontweight=subplot.title_fontweight,
                pad=10,          # space between title and image
                color='black'
            )

    plt.subplots_adjust(
        wspace=0.1, 
        hspace=0.4,
        # Leave extra head-room for the figure-level title when there is one.
        top=0.87 if main_title else 0.94, 
        bottom=0.04,
        left=0.1, 
        right=0.9 
    )
    plt.show()

#### plot UMAP

In [None]:
def plot_interactive_umap_grid(
    sub_plots: list[list[EmbedSubplot]], 
    population_thumbnails: list[str],
    follow_idx_list: list[int] | None = None,
    plot_width: int | None = None, 
    plot_height: int = 200,
    super_title: str | None = None
):
    """
    Creates a Grid of Interactive UMAP Scatter plots.

    NOTE: a later cell in this file defines a function with the same name;
    whichever definition is executed last is the one that gets called.
    
    Args:
        sub_plots: 2D list of EmbedSubplot objects defining the grid
        population_thumbnails: Pre-generated base64 thumbnail strings
        follow_idx_list: Indices to highlight (black border, opaque). Others are faded.
        plot_width: Width of each subplot (auto if None)
        plot_height: Height of each subplot
        super_title: Overall title for the grid (printed to the console)

    Raises:
        ValueError: if a subplot's idxs / hover_data do not match the number
            of embedding rows.
    """
    n_robots = len(population_thumbnails)
    n_rows = len(sub_plots)
    n_cols = len(sub_plots[0]) if n_rows > 0 else 0

    if plot_width is None:
        plot_width = 700 if n_cols == 1 else 200

    # Setup colors and sizes (rainbow with transparency)
    rgba_colors = plt.cm.rainbow(np.linspace(0, 1, n_robots))
    rgba_colors[:, 3] = 0.4  # Base transparency

    sizes = [8] * n_robots
    line_colors = [None] * n_robots

    # Build the lookup set once; it is reused for every point of every subplot.
    follow_set = set(follow_idx_list) if follow_idx_list else set()

    # Apply follow_idx highlighting
    if follow_set:
        line_colors = ['black' if i in follow_set else None for i in range(n_robots)]

        for i in range(n_robots):
            if i in follow_set:
                rgba_colors[i, 3] = 1.0  # Fully opaque
            else:
                rgba_colors[i] = [0.0, 0.0, 0.0, 0.2]  # Faded grey

    hex_colors = [to_hex(c, keep_alpha=True) for c in rgba_colors]

    # Build grid
    grid_layout = []

    for row_subplots in sub_plots:
        row_plots = []

        for subplot in row_subplots:
            emb = np.array(subplot.embeddings)

            # Validation: anything that is not an (n, 2) array gets a placeholder
            if emb.ndim != 2 or emb.shape[1] != 2:
                p = figure(
                    title=f"{subplot.title} (Invalid Data)",
                    width=plot_width,
                    height=plot_height
                )
                row_plots.append(p)
                continue

            n_points = len(subplot.idxs)
            hover_info = list(subplot.hover_data) if subplot.hover_data else [""] * n_points
            if n_points != emb.shape[0] or len(hover_info) != n_points:
                raise ValueError(
                    f"Subplot {subplot.title!r}: {emb.shape[0]} embedding rows, "
                    f"{n_points} idxs and {len(hover_info)} hover entries must all match"
                )

            # Build DataFrame with data for this subplot
            robots_df = pd.DataFrame({
                "x": emb[:, 0],
                "y": emb[:, 1],
                "digit": [str(idx) for idx in subplot.idxs],
                "image": [population_thumbnails[idx] for idx in subplot.idxs],
                "color": [hex_colors[idx] for idx in subplot.idxs],
                "size": [sizes[idx] for idx in subplot.idxs],
                "line_color": [line_colors[idx] for idx in subplot.idxs],
                "hover_info": hover_info
            })

            # Sort so highlighted points render on top
            if follow_set:
                robots_df['sort_order'] = [
                    1 if idx in follow_set else 0
                    for idx in subplot.idxs
                ]
                robots_df = robots_df.sort_values('sort_order', ascending=True)

            source = ColumnDataSource(robots_df)

            # Create figure
            p = figure(
                title=subplot.title,
                width=plot_width,
                height=plot_height,
                tools="pan,wheel_zoom,reset,save",
                toolbar_location="above"
            )

            p.scatter(
                'x', 'y',
                source=source,
                color='color',
                line_alpha=1,
                line_color='line_color',
                line_width=1,
                size='size'
            )

            # Hover tooltip
            hover = HoverTool(tooltips="""
                <div>
                    <img src='@image' style='float:left; margin:5px; width:auto; height:auto;'/>
                </div>
                <div style="font-size:12px; font-weight: bold;">
                    <span style='color:#224499'>ID: @digit</span><br>
                    <span style='color:#333'>@hover_info</span>
                </div>
            """)
            p.add_tools(hover)
            row_plots.append(p)

        grid_layout.append(row_plots)

    # Show with optional super title
    if super_title:
        # Note: Bokeh doesn't have native super titles, but you can print it
        console.print(f"[bold blue]{super_title}[/bold blue]")

    show(gridplot(grid_layout))

#### plot UMAP v2 (redefines `plot_interactive_umap_grid` from the cell above)

In [None]:
def _upscale_thumbnail(b64_str: str, scale: float) -> str:
    """Re-encode a ``<header>,<base64>`` thumbnail as a PNG data URI scaled by *scale*."""
    _, encoded = b64_str.split(",", 1)
    with Image.open(BytesIO(base64.b64decode(encoded))) as img:
        new_size = (int(img.width * scale), int(img.height * scale))
        img_scaled = img.resize(new_size, Image.Resampling.BICUBIC)

    buffer = BytesIO()
    img_scaled.save(buffer, format='PNG', optimize=False, compress_level=1)
    return 'data:image/png;base64,' + base64.b64encode(buffer.getvalue()).decode()


def _is_plottable_embedding(emb: np.ndarray) -> bool:
    """True when *emb* is a non-empty (n, 2) array."""
    return emb.ndim == 2 and emb.shape[1] == 2 and emb.shape[0] > 0


def _shared_axis_ranges(sub_plots):
    """Return ((x_lo, x_hi), (y_lo, y_hi)) spanning every plottable embedding, padded by 10%.

    Returns (None, None) when no subplot holds plottable data.
    """
    valid = []
    for row_subplots in sub_plots:
        for subplot in row_subplots:
            emb = np.array(subplot.embeddings)
            if _is_plottable_embedding(emb):
                valid.append(emb)
    if not valid:
        return None, None

    stacked = np.vstack(valid)
    lows, highs = stacked.min(axis=0), stacked.max(axis=0)
    ranges = []
    for lo, hi in zip(lows, highs):
        # A zero span (all points identical) would give a degenerate range,
        # so fall back to a fixed half-unit margin in that case.
        padding = (hi - lo) * 0.1 or 0.5
        ranges.append((float(lo - padding), float(hi + padding)))
    return ranges[0], ranges[1]


def plot_interactive_umap_grid(
    sub_plots: list[list[EmbedSubplot]], 
    population_thumbnails: list[str],
    follow_idx_list: list[int] | None = None,
    plot_width: int | None = None, 
    plot_height: int = 200,
    max_full_width: int = 800,
    super_title: str | None = None,
    upscale: int | None = None,
    global_axis: bool = True
):
    """
    Creates a Grid of Interactive UMAP Scatter plots.
    
    Args:
        sub_plots: 2D list of EmbedSubplot objects defining the grid
        population_thumbnails: Pre-generated base64 thumbnail strings
        follow_idx_list: Indices to highlight (black border, opaque). Others are faded.
        plot_width: Width of each subplot. If None, auto-calculated from max_full_width.
        plot_height: Height of each subplot
        max_full_width: Maximum total width of the grid (used when plot_width=None)
        super_title: Overall title for the grid
        upscale: If provided, upscales tooltip images by this factor (e.g., 2 = 2x larger)
        global_axis: If True, all subplots share the same axis range (default: True)

    Raises:
        ValueError: if a subplot's idxs / hover_data do not match the number
            of embedding rows.
    """
    n_robots = len(population_thumbnails)
    n_rows = len(sub_plots)
    n_cols = len(sub_plots[0]) if n_rows > 0 else 0

    # Auto-calculate plot_width if not specified
    if plot_width is None:
        spacing_per_plot = 20  # Approximate spacing between plots
        available_width = max_full_width - (spacing_per_plot * (n_cols - 1))
        # max(n_cols, 1) keeps an empty grid from dividing by zero
        plot_width = max(150, available_width // max(n_cols, 1))  # Min 150px per plot

    # Upscale thumbnails if requested
    if upscale is not None and upscale != 1:
        display_thumbnails = [
            _upscale_thumbnail(thumb, upscale)
            for thumb in track(population_thumbnails, description=f"Upscaling thumbnails {upscale}x...", disable=True)
        ]
    else:
        display_thumbnails = population_thumbnails

    # Setup colors (rainbow with transparency). They depend only on the
    # population and the follow list, so they are computed once for all subplots.
    rgba_colors = plt.cm.rainbow(np.linspace(0, 1, n_robots))
    rgba_colors[:, 3] = 0.4  # Base transparency

    follow_set = set(follow_idx_list) if follow_idx_list else set()
    if follow_set:
        for i in range(n_robots):
            if i in follow_set:
                rgba_colors[i, 3] = 1.0  # Fully opaque
            else:
                rgba_colors[i] = [0.0, 0.0, 0.0, 0.2]  # Faded grey

    hex_colors = [to_hex(c, keep_alpha=True) for c in rgba_colors]

    # Calculate global axis ranges if requested
    global_x_range = None
    global_y_range = None
    if global_axis:
        global_x_range, global_y_range = _shared_axis_ranges(sub_plots)

    # Build grid
    grid_layout = []

    for row_subplots in sub_plots:
        row_plots = []

        for subplot in row_subplots:
            emb = np.array(subplot.embeddings)

            # Validation - create clean white placeholder for empty/invalid data
            if not _is_plottable_embedding(emb):
                p = figure(
                    title=subplot.title,
                    width=plot_width,
                    height=plot_height,
                    toolbar_location=None,
                    x_range=global_x_range,
                    y_range=global_y_range
                )
                # Remove grid lines and axes for clean white appearance
                p.xgrid.visible = False
                p.ygrid.visible = False
                p.xaxis.visible = False
                p.yaxis.visible = False
                p.outline_line_color = None
                p.background_fill_color = "white"
                p.border_fill_color = "white"
                
                # Apply font size settings for title consistency
                p.title.text_font_size = subplot.title_fontsize
                
                row_plots.append(p)
                continue

            idxs = list(subplot.idxs)
            n_points = len(idxs)
            hover_info = list(subplot.hover_data) if subplot.hover_data else [""] * n_points
            if n_points != emb.shape[0] or len(hover_info) != n_points:
                raise ValueError(
                    f"Subplot {subplot.title!r}: {emb.shape[0]} embedding rows, "
                    f"{n_points} idxs and {len(hover_info)} hover entries must all match"
                )

            # Per-point highlight flag drives size, outline and draw order
            is_followed = [idx in follow_set for idx in idxs]

            # Build DataFrame with data for this subplot
            robots_df = pd.DataFrame({
                "x": emb[:, 0],
                "y": emb[:, 1],
                "digit": [str(idx) for idx in idxs],
                "image": [display_thumbnails[idx] for idx in idxs],
                "color": [hex_colors[idx] for idx in idxs],
                "size": [
                    subplot.highlight_dot_size if followed else subplot.default_dot_size
                    for followed in is_followed
                ],
                "line_color": ['black' if followed else None for followed in is_followed],
                "hover_info": hover_info
            })

            # Sort so highlighted points render on top
            if follow_set:
                robots_df['sort_order'] = [1 if followed else 0 for followed in is_followed]
                robots_df = robots_df.sort_values('sort_order', ascending=True)

            source = ColumnDataSource(robots_df)

            # Create figure with global axis ranges if enabled
            p = figure(
                title=subplot.title,
                width=plot_width,
                height=plot_height,
                tools="pan,wheel_zoom,reset,save",
                toolbar_location="above",
                x_range=global_x_range,
                y_range=global_y_range
            )

            p.scatter(
                'x', 'y',
                source=source,
                color='color',
                line_alpha=1,
                line_color='line_color',
                line_width=1,
                size='size'
            )

            # Apply font size settings from subplot
            p.title.text_font_size = subplot.title_fontsize
            p.xaxis.axis_label_text_font_size = subplot.axis_label_fontsize
            p.yaxis.axis_label_text_font_size = subplot.axis_label_fontsize
            p.xaxis.major_label_text_font_size = subplot.tick_fontsize
            p.yaxis.major_label_text_font_size = subplot.tick_fontsize

            # Hover tooltip
            hover = HoverTool(tooltips=f"""
                  <div>
                      <img src='@image' style='float:left; margin:5px; width:auto; height:auto;'/>
                  </div>
                  <div style="font-size:{subplot.hover_fontsize}; font-weight: bold;">
                      <span style='color:#224499'>ID: @digit</span><br>
                      <span style='color:#333'>@hover_info</span>
                  </div>
              """)
            p.add_tools(hover)
            row_plots.append(p)

        grid_layout.append(row_plots)

    grid = gridplot(grid_layout)
    # Show with optional super title
    if super_title:
        title_div = Div(
            text=f"<h2 style='text-align: center; font-weight: bold; color: #224499; margin-bottom:10px;'>{super_title}</h2>",
            width=max_full_width,
            height=40
        )
        show(column(title_div, grid))
    else:
        show(grid)


#### sequential robots

In [None]:
def plot_sequential_robots(total=100, group_size=5, cols=4, title="Robot Analysis"):
    """Plot robots ``0 .. total-1`` in consecutive groups arranged on a grid.

    Args:
        total: Number of robots to show (indices 0 to total-1).
        group_size: Robots stitched together in one grid cell.
        cols: Grid cells per row.
        title: Figure-level title passed to ``plot_robot_grid``.

    Raises:
        ValueError: if ``group_size`` or ``cols`` is not positive.
    """
    if group_size <= 0 or cols <= 0:
        raise ValueError("group_size and cols must be positive")

    # 1. Generate flat list of subplots; the last group is clamped to `total`
    #    so it never references indices beyond the requested range.
    flat_subplots = [
        RobotSubplot(
            title="", 
            under_title="", 
            idxs=list(range(start, min(start + group_size, total)))
        )
        for start in range(0, total, group_size)
    ]
    if not flat_subplots:
        # total <= 0: nothing to plot
        return

    # 2. Reshape into grid (list of lists)
    grid = [flat_subplots[i : i + cols] for i in range(0, len(flat_subplots), cols)]

    # 3. Plot
    plot_robot_grid(grid, main_title=title)

---

## PARAMETERS

#### CREATE population

In [None]:
# Hand-written hinge-only robots: name -> node parsed from its canonical string
made_up_robots = {
    robot_name: ctk.node_from_string(canonical_string)
    for robot_name, canonical_string in (
        ('robot_hinge_long', 'C[l(HHHHHH)r(HHHHHH)f(HHHHHH)b(HHHHHH)]'),
        ('robot_hinge_short', 'C[l(H)r(H)f(H)b(H)]'),
        ('robot_hinge_2', 'C[f(H)b(H)]'),
        ('robot_hinge_1', 'C[f(H)]'),
    )
}

# Default population; replaced further down when USE_DB_POPULATION is set.
population = [*made_up_robots.values()]

#### static globals

In [None]:
# When True, the next cell replaces `population` with individuals loaded from
# the database instead of the hand-written robots defined above.
USE_DB_POPULATION = True

CACHE_DIR = "__data__/img"  # passed to load_or_generate_cache as cache_dir
N_JOBS = 1                  # passed to load_or_generate_cache as max_workers
THUMBNAIL_SCALE = 0.3       # passed to load_or_generate_cache as scale


#### loaded globals

In [None]:
if USE_DB_POPULATION:
    from itertools import islice

    string_population = load_ctk_strings_from_database()
    # Only the first 100 individuals are used, so only those are parsed.
    population = [ctk.node_from_string(ind) for ind in islice(string_population, 100)]

POPULATION_NODES = population #type: ignore

# Parallel views of the same population, index-aligned with POPULATION_NODES.
POPULATION_GRAPHS = [ind.to_graph() for ind in POPULATION_NODES]
POPULATION_STRINGS = [ind.to_string() for ind in POPULATION_NODES]

# Thumbnails are used by the interactive plots as tooltip images.
POPULATION_THUMBNAILS, index_df = load_or_generate_cache(
    POPULATION_GRAPHS, 
    robot_names=POPULATION_STRINGS, 
    cache_dir=CACHE_DIR,
    scale=THUMBNAIL_SCALE,
    parallel=True,
    max_workers=N_JOBS
)

# assert set(index_df['robot_name']) == set(POPULATION_STRINGS)

#### show robots

In [None]:
# Never ask for more robots than the population holds (e.g. the 4 made-up robots).
plot_sequential_robots(total=min(100, len(POPULATION_NODES)), group_size=5, cols=4, title='first 100')

#### collect the subtrees

In [None]:
SIM_CONFIG = ctk.SimilarityConfig()

# part name -> one fingerprint per individual; starts with the whole body
# ('core'), the limb parts are added in a later cell.
subtrees = {
    'core': [
        ctk.collect_hash_fingerprint(node, config=SIM_CONFIG)
        for node in POPULATION_NODES
    ],
}

In [None]:
def process_limbs(population, part_name, prefix):
    """Fingerprint the `part_name` limb of every individual.

    Each limb that is found is detached from its parent first, so the
    individuals in `population` are modified in place. Returns one entry per
    individual: the fingerprint (hashed with `prefix`, using the module-level
    SIM_CONFIG), or None where the individual has no such limb.
    """
    fingerprints = []
    for individual in population:
        limb = individual.get(part_name)
        if not limb:
            fingerprints.append(None)
            continue

        # Detach first, then fingerprint the limb on its own.
        limb.detatch_from_parent()
        fingerprints.append(
            ctk.collect_hash_fingerprint(limb, config=SIM_CONFIG, hash_prefix=prefix)
        )
    return fingerprints


In [None]:
# process_limbs detaches limbs in place, so work on copies of the population.
population_copy = [ind.copy() for ind in POPULATION_NODES]

# limb name -> hash prefix used for that limb's fingerprint
parts_config = dict(front='f_', left='l_', back='b_', right='r_')

subtrees |= {
    part: process_limbs(population_copy, part, prefix)
    for part, prefix in parts_config.items()
}

console.print(subtrees)

#### process dicts to matrices

In [None]:
from sklearn.feature_extraction import FeatureHasher

# Hash string tokens into a fixed-width sparse feature space.
hasher = FeatureHasher(
    n_features=2**20,
    input_type="string", 
)

# matrixes[part]['all' | radius] -> sparse matrix with one row per robot
matrixes = {}

# Union of the radii seen in any part; needed for the combined 'all' part.
global_radii = set()
n_robots = len(next(iter(subtrees.values()))) if subtrees else 0

# ---------------------------------------------------------
# PHASE 1: one group of matrices per body part
# ---------------------------------------------------------
for part_name, population_list in subtrees.items():
    part_radii = set().union(*(ind for ind in population_list if ind is not None))
    global_radii |= part_radii

    # 1a. Every token of the part, regardless of radius (empty row if the part is missing)
    agg_corpus = [
        [] if ind is None else [t for token_list in ind.values() for t in token_list]
        for ind in population_list
    ]
    part_matrices = {'all': hasher.fit_transform(agg_corpus)}

    # 1b. One matrix per radius present in this part
    for r in sorted(part_radii):
        layer_corpus = [
            ind[r] if ind is not None and r in ind else []
            for ind in population_list
        ]
        part_matrices[r] = hasher.fit_transform(layer_corpus)

    matrixes[part_name] = part_matrices

# ---------------------------------------------------------
# PHASE 2: the combined 'all' part (tokens of every part together)
# ---------------------------------------------------------
matrixes['all'] = {}

# 2a. Every token of every part and radius, per robot
grand_corpus = []
for i in range(n_robots):
    robot_parts = [subtrees[name][i] for name in subtrees]
    grand_corpus.append([
        t
        for ind in robot_parts if ind
        for r_tokens in ind.values()
        for t in r_tokens
    ])

matrixes['all']['all'] = hasher.fit_transform(grand_corpus)

# 2b. Per radius: that radius' tokens from every part that has it
for r in sorted(global_radii):
    layer_corpus = []
    for i in range(n_robots):
        robot_parts = [subtrees[name][i] for name in subtrees]
        layer_corpus.append([
            t
            for ind in robot_parts if ind and r in ind
            for t in ind[r]
        ])

    matrixes['all'][r] = hasher.fit_transform(layer_corpus)

# --- Verification ---
print("Matrices created:")
for part, by_radius in matrixes.items():
    # 'all' first, then the numeric radii in ascending order
    keys = sorted(by_radius, key=lambda x: -1 if x == 'all' else x)
    print(f"  {part.upper()}: {keys}")
    print(f"    - Shape of 'all': {by_radius['all'].shape}")
    if 0 in by_radius:
        print(f"    - Shape of 'r0' : {by_radius[0].shape}")

In [None]:
# from sklearn.metrics.pairwise import cosine_similarity
# import numpy as np

# Pairwise cosine similarity for every feature matrix.
# Same nesting as `matrixes`: cosine_matrices['core'][0] -> (N, N) numpy array
cosine_matrices = {
    part_name: {
        r: cosine_similarity(feature_matrix)
        for r, feature_matrix in radius_dict.items()
    }
    for part_name, radius_dict in matrixes.items()
}


In [None]:

n_neighbors_list = [2, 5, 10, 15, 20, 50]

# Result container, keyed [n_neighbors][part_name][radius]; the two outer
# levels are pre-filled so the loop below only has to assign the leaves.
umap_embeddings = {
    n: {part: {} for part in matrixes}
    for n in n_neighbors_list
}

# One task per (part, radius, n_neighbors) so a single progress bar covers all of them.
tasks = [
    (part_name, r, n, feature_matrix)
    for part_name, radius_dict in matrixes.items()
    for r, feature_matrix in radius_dict.items()
    for n in n_neighbors_list
]

for part_name, r, n, feature_matrix in track(tasks, description="[green]Calculating UMAP variants..."):
    reducer = umap.UMAP(
        n_components=2,
        n_neighbors=n,
        min_dist=0.01,
        metric='cosine',
        random_state=42,
    )

    embedding = reducer.fit_transform(feature_matrix)

    # Shift so the embedding is centred on the origin before storing it.
    embedding = embedding - embedding.mean(axis=0)
    umap_embeddings[n][part_name][r] = embedding

print("Dimensions reduced successfully.")


In [None]:
# import base64
# import io
# import numpy as np
# import pandas as pd
# import matplotlib.pyplot as plt
# from matplotlib.colors import to_hex
# from PIL import Image
# from io import BytesIO
# from dataclasses import dataclass, field
# from bokeh.plotting import figure, show, gridplot
# from bokeh.models import ColumnDataSource, HoverTool, Div
# from bokeh.layouts import column


# # --- 3. Prepare the Data for the first 'n' ---

# # Select the first N from your list
# target_n = n_neighbors_list[0] 
# print(f"Preparing plot for N_Neighbors = {target_n}")

# # Get the data subset: {part_name: {r: embedding}}
# subset_data = umap_embeddings[target_n]

# # Sort keys to ensure the grid is ordered
# parts = list(subset_data.keys())         # Rows
# radii = list(subset_data[parts[0]].keys()) # Columns

# # Determine population size from the first valid embedding found
# # (Assuming all matrices have the same number of rows/robots)
# sample_emb = subset_data[parts[0]][radii[0]]
# n_population = sample_emb.shape[0]

# # Generate placeholder thumbnails

# # Build the 2D List of Subplots
# grid_subplots = []

# for part in parts:
#     row_list = []
#     for r in radii:
        
#         emb = subset_data[part].get(r, None)
        
#         # Create the configuration object for this specific plot
#         subplot_obj = UmapSubplot(
#             title=f"Part: {part} | R: {r}",
#             embeddings=emb,
#             idxs=list(range(n_population)), # 0 to N-1
#             hover_data=[f"Robot {i}" for i in range(n_population)]
#         )
#         row_list.append(subplot_obj)
    
#     grid_subplots.append(row_list)

# # --- 4. Call your plotting function ---

# # Note: Ensure 'plot_interactive_umap_grid' is defined in your scope 
# # (or paste the definition you provided before running this block)

# plot_interactive_umap_grid(
#     sub_plots=grid_subplots,
#     population_thumbnails=POPULATION_THUMBNAILS,
#     super_title=f"UMAP Structure Analysis (Neighbors={target_n})",
#     # max_full_width=1200,
#     # plot_height=300,
#     # plot_width=300
# )

In [None]:
# --- Helper for Sorting Radii ---
def radius_sort_key(val):
    """
    Build a sort key that orders numeric radii before string labels.

    Numbers (Python or NumPy scalars) sort ascending by value (0, 1, 10...);
    anything else (e.g. 'all') sorts after every number, ordered by its
    string form.

    Args:
        val: A radius key, either a number or a label such as 'all'.

    Returns:
        A ``(group, value)`` tuple suitable as ``key=`` for ``sorted``.
    """
    # Group 0: Numbers. NumPy scalar types are listed explicitly because
    # np.integer does not subclass int; without them such radii would be
    # compared as text and '10' would sort before '2'.
    if isinstance(val, (int, float, np.integer, np.floating)):
        return (0, val)
    # Group 1: Strings (like 'all')
    return (1, str(val))


# --- Configuration ---
# Use the embeddings computed for the first configured neighbour count.
target_n = n_neighbors_list[0]
subset_data = umap_embeddings[target_n]

# --- 1. Get Keys ---
# Parts form the grid columns and keep the mapping's insertion order.
parts = [*subset_data]

# Radii form the grid rows; they are read from the first part only and
# ordered so that numbers come first and labels such as 'all' come last.
radii = sorted(subset_data[parts[0]], key=radius_sort_key)

# --- 2. Build the Grid (Rows=Radii, Cols=Parts) ---
grid_subplots = []

for r in radii:
    row_list = []

    for part in parts:
        # A (part, radius) pair without an embedding gets an empty array.
        emb = subset_data[part].get(r)
        if emb is None:
            emb = np.array([])

        row_list.append(
            EmbedSubplot(
                title=f"R: {r} | Part: {part}",
                embeddings=emb,
                idxs=list(range(len(POPULATION_NODES))),
                hover_data=[],
                title_fontsize="10pt",
            )
        )

    grid_subplots.append(row_list)

# --- 3. Plot ---
plot_interactive_umap_grid(
    sub_plots=grid_subplots,
    population_thumbnails=POPULATION_THUMBNAILS,
    super_title=f"UMAP Grid (Neighbors={target_n}) | Rows=Radii, Cols=Parts",
    max_full_width=800,
    global_axis=True,
)

In [None]:
# # --- Configuration ---
# target_n = n_neighbors_list[0] 
# subset_data = umap_embeddings[target_n]

# # Get sorted keys
# parts = list(subset_data.keys())           # Will be COLUMNS
# radii = list(subset_data[parts[0]].keys()) # Will be ROWS

# # --- Build the Grid ---
# grid_subplots = []

# # Outer Loop: Iterate over Radii (Rows)
# for r in radii:
#     row_list = []
    
#     # Inner Loop: Iterate over Parts (Columns)
#     for part in parts:
        
#         # Retrieve the specific embedding
#         # Note: subset_data structure is [part][r], so we access it the same way, 
#         # just inside a different loop order.
#         emb = subset_data[part].get(r, None)
        
#         # Handle missing data
#         if emb is None:
#             emb = np.array([]) 

#         # Create the plot object
#         subplot_obj = UmapSubplot(
#             title=f"R: {r} | Part: {part}",
#             embeddings=emb,
#             idxs=list(range(n_population)), 
#             hover_data=[f"ID: {i}" for i in range(n_population)],
#             title_fontsize="10pt"
#         )
#         row_list.append(subplot_obj)
    
#     # Add the full row (all parts for this specific radius)
#     grid_subplots.append(row_list)

# # --- Plot ---
# plot_interactive_umap_grid(
#     sub_plots=grid_subplots,
#     population_thumbnails=POPULATION_THUMBNAILS, 
#     super_title=f"UMAP Grid (Neighbors={target_n}) | Rows=Radii, Cols=Parts",
#     max_full_width=100,
#     plot_height=250, 
#     plot_width=250
# )

In [None]:
# print cumulative data matrix???

In [None]:
# from sklearn.feature_extraction import FeatureHasher

# # 1. Setup the hasher
# hasher = FeatureHasher(
#     n_features=2**20,
#     input_type="string", # Expects lists of strings
# )

# # 2. The data container
# # Structure will be: matrixes['core'][0], matrixes['core'][1], etc.
# matrixes = {} 

# # 3. The Loop
# for part_name, population_list in subtrees.items():
#     matrixes[part_name] = {}
    
#     # Step A: Find all unique radii present in this specific body part across the whole population
#     # (e.g., maybe 'core' goes up to radius 12, but 'front' only goes to radius 5)
#     all_radii = set()
#     for ind in population_list:
#         if ind is not None:
#             all_radii.update(ind.keys())
            
#     # Step B: Create a matrix for each radius
#     for r in sorted(all_radii):
#         corpus = []
        
#         for ind in population_list:
#             # If the robot has this limb AND has data for this radius
#             if ind is not None and r in ind:
#                 corpus.append(ind[r]) # Add the list of tokens (e.g. ['r0__C', ...])
#             else:
#                 # Robot is missing limb or this radius depth -> Empty features
#                 corpus.append([])

#         # Step C: Transform and store
#         # Result is a (n_samples, 2**20) sparse matrix
#         matrixes[part_name][r] = hasher.transform(corpus)

# # --- Verification ---
# print("Matrices created:")
# for part in matrixes:
#     print(f"  {part}: Radii {list(matrixes[part].keys())}")
#     # Example: Check shape of Core Radius 0
#     if 0 in matrixes[part]:
#         print(f"    Shape of {part} r0: {matrixes[part][0].shape}")

In [None]:
# from rich.progress import track
# import umap

# # 1. Prepare container and Initialize structure
# # We pre-fill the keys so we don't hit KeyErrors inside the flat loop
# umap_embeddings = {k: {} for k in matrixes.keys()}

# # 2. Flatten the work into a single list
# # This creates a list of tuples: [('core', 0, matrix), ('core', 1, matrix), ...]
# tasks = []
# for part_name, radius_dict in matrixes.items():
#     for r, feature_matrix in radius_dict.items():
#         tasks.append((part_name, r, feature_matrix))

# # 3. Iterate with rich.track
# # We unwrap the tuple directly in the for-loop
# for part_name, r, feature_matrix in track(tasks, description="[green]Calculating UMAP..."):
    
#     # Initialize UMAP
#     reducer = umap.UMAP(
#         n_components=2, 
#         n_neighbors=15, 
#         min_dist=0.1, 
#         metric='cosine', 
#         random_state=42
#     )
    
#     # Calculate Embedding
#     embedding = reducer.fit_transform(feature_matrix)
    
#     # Store it (using the references we set up in step 1)
#     umap_embeddings[part_name][r] = embedding

# print("Dimensions reduced successfully.")

In [None]:
# # for every key in subtrees
# # for every radius in in the data list
# # create a featurehasher matrix

# hasher = FeatureHasher(
#     n_features=2**20,
#     input_type="string",
# )

# matrixes = {}

# # matrixes['core'] =

COSINE MATRICES

UMAPS

In [None]:
# CALC TILL MAX RADIUS

In [None]:

def evaluate_diversity_cum_cos(population: Population) -> Population:
    """
    Score each individual by its dissimilarity to the rest of the population.

    For every radius in ``0..SIM_CONFIG.max_tree_radius`` the per-individual
    subtree tokens are hashed into a feature matrix and the pairwise cosine
    similarity is computed. The similarity matrices are averaged over all
    radii, and an individual's fitness is
    ``1 - mean similarity to the other individuals``.

    Side effects: sets ``ind.tags['ctk_string']`` for individuals whose
    ``requires_eval`` flag is set, then writes ``ind.fitness`` and clears
    ``ind.requires_eval`` on every individual.

    Args:
        population: Individuals to score. Individuals that do not require
            evaluation must already carry ``tags['ctk_string']``.

    Returns:
        The same ``population`` object, updated in place. An empty
        population is returned untouched; a lone individual has no peers to
        resemble and therefore receives fitness 1.0 instead of NaN.
    """
    n_pop = len(population)
    if n_pop == 0:
        # Nothing to score, and the hashing / similarity steps below are
        # not meaningful for zero samples.
        return population

    # 1. Decode genotypes where needed and collect the subtree data.
    for ind in population:
        if ind.requires_eval:
            ind.tags['ctk_string'] = decode_genotype_to_string(ind.genotype)

    subtrees_dicts = [
        ctk.collect_tree_hash_config_mode(
            ctk.from_string(ind.tags['ctk_string']),
            config=SIM_CONFIG,
        )
        for ind in population
    ]

    keys = range(SIM_CONFIG.max_tree_radius + 1)

    # FeatureHasher is stateless, so a single instance is reused for every
    # radius.
    current_hasher = FeatureHasher(
        n_features=2**20,
        input_type="string",
    )

    # 2. Accumulate one (N x N) cosine-similarity matrix per radius.
    cumulative_sim_matrix = np.zeros((n_pop, n_pop))
    for key in keys:
        # Individuals without an entry for this radius contribute no tokens.
        specific_corpus = [d.get(key, []) for d in subtrees_dicts]
        count_matrix = current_hasher.fit_transform(specific_corpus)
        cumulative_sim_matrix += cosine_similarity(count_matrix)

    # 3. Average across radii and drop self-similarity from the totals.
    final_sim_matrix = cumulative_sim_matrix / len(keys)
    np.fill_diagonal(final_sim_matrix, 0)

    # 4. Mean similarity to the *other* individuals. The divisor is clamped
    #    to 1 so a single-individual population yields 0 rather than 0/0.
    mean_similarity_to_others = final_sim_matrix.sum(axis=1) / max(n_pop - 1, 1)
    diversity_scores = 1.0 - mean_similarity_to_others

    # 5. Assign to individuals.
    for ind, score in zip(population, diversity_scores):
        ind.fitness = float(score)
        ind.requires_eval = False

    return population

---

## SHOW EVOLUTION