### Imports

In [None]:
import base64
from io import BytesIO
from typing import Any

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import umap
import umap.plot
from bokeh.io import output_notebook
from bokeh.layouts import gridplot
from bokeh.models import (
    BasicTicker,
    ColorBar,
    ColumnDataSource,
    HoverTool,
    LinearColorMapper,
)
from bokeh.plotting import figure, show
from matplotlib.colors import to_hex
from PIL import Image
from rich.console import Console
from rich.progress import track
from scipy.spatial.distance import pdist, squareform
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.metrics.pairwise import cosine_similarity

from ariel_experiments.characterize.canonical.core.toolkit import (
    CanonicalToolKit as ctk,
)
from ariel_experiments.gui_vis.view_mujoco import view
from ariel_experiments.utils.initialize import generate_random_individual

# Rich console instance for formatted output.
console = Console()
# Make Bokeh's show(...) render inline in the notebook.
output_notebook()

### Functions

In [None]:
def plot_heatmap_row(
    matrices: list[np.ndarray],
    titles: list[str] | None = None,
    suptitle: str | None = None,
    figsize_per_plot: tuple[int, int] = (5, 6),
    cmap: str = "viridis",
):
    """
    Plot a horizontal row of heatmaps, each with its own color scale.

    Every heatmap is colored between its own min and max (NaNs ignored), so
    contrast is maximized per panel; colors are NOT comparable across panels.

    Args:
        matrices: Matrices to plot, one per subplot (left to right). An empty
            list draws nothing.
        titles: Optional per-subplot titles; subplots beyond ``len(titles)``
            are left untitled.
        suptitle: Optional overall figure title.
        figsize_per_plot: (width, height) in inches for each subplot.
        cmap: Matplotlib colormap name.
    """
    if not matrices:
        # plt.subplots(ncols=0) raises; there is nothing to draw anyway.
        return

    num_plots = len(matrices)
    fig, axes = plt.subplots(
        nrows=1,
        ncols=num_plots,
        figsize=(figsize_per_plot[0] * num_plots, figsize_per_plot[1]),
        squeeze=False,
    )
    axes = axes.flatten()

    if suptitle:
        fig.suptitle(suptitle, fontsize=16)

    for i, (ax, matrix) in enumerate(zip(axes, matrices)):
        # Local color scaling for maximum contrast. nan-aware, so a single NaN
        # cell does not turn the whole color range into NaN.
        local_vmin = np.nanmin(matrix)
        local_vmax = np.nanmax(matrix)

        sns.heatmap(matrix, ax=ax, cmap=cmap, vmin=local_vmin, vmax=local_vmax)

        if titles and i < len(titles):
            ax.set_title(titles[i])

    # Reserve the top 4% of the figure for the suptitle.
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()

In [None]:
def plot_comparison_heatmaps(
    all_matrix_data: dict[str, dict[int, Any]],
    max_show_radius: int,
):
    """
    Draw one row of heatmaps per radius, with one column per metric.

    Args:
        all_matrix_data: Mapping ``metric name -> {radius: matrix}``.
        max_show_radius: Radii ``0..max_show_radius`` (inclusive) are drawn.

    A metric with no entry for a given radius is shown as a 1x1 zero
    placeholder, so columns stay aligned across rows.
    """
    for radius in range(max_show_radius + 1):
        panels = [
            (
                per_radius[radius] if radius in per_radius else np.zeros((1, 1)),
                f"r:{radius} {metric}",
            )
            for metric, per_radius in all_matrix_data.items()
        ]
        plot_heatmap_row(
            matrices=[matrix for matrix, _ in panels],
            titles=[title for _, title in panels],
            suptitle=f"Comparison at Radius {radius}",
        )

In [None]:
def get_cumsum_dict(matrix_dict):
    """
    Return ``{radius: sum of the matrices for all radii <= radius}``.

    Keys are visited in ascending order. The first matrix is copied, and each
    later total is a new array (``acc + mat``), so neither the input matrices
    nor earlier totals are shared with or mutated by later entries.
    """
    totals = {}
    acc = None
    for radius in sorted(matrix_dict):
        mat = matrix_dict[radius]
        acc = mat.copy() if acc is None else acc + mat
        totals[radius] = acc
    return totals

In [None]:
# def get_sorted_coords_from_matrix(matrix, *, max_first=True):
#     """
#     Returns (row, col) tuples from the upper triangle, sorted by value.
#     """
#     rows, cols = np.triu_indices_from(matrix, k=1)
#     values = matrix[rows, cols]
#     sort_idx = np.argsort(values)
#     if max_first:
#         sort_idx = sort_idx[::-1]
#     return list(zip(rows[sort_idx], cols[sort_idx]))

In [None]:
# def sorted_idx_dict(matrix_dict, *, max_first=True):
#     """Return {radius: sorted_coord_list} by applying get_sorted_coords_from_matrix to each matrix."""
#     return {
#         r: get_sorted_coords_from_matrix(mat, max_first=max_first)
#         for r, mat in matrix_dict.items()
#     }

In [None]:
def sorted_idx_dict(data_dict, *, max_first=True):
    """
    Rank every entry of *data_dict* with ``get_sorted_coords``.

    Returns a new dict with the same keys; each value is the sorted index
    list (1D data) or sorted ``(row, col)`` list (2D data) for that entry.
    """
    ranked = {}
    for key, data in data_dict.items():
        ranked[key] = get_sorted_coords(data, max_first=max_first)
    return ranked

def get_sorted_coords(data, *, max_first=True):
    """
    Rank the entries of a 1D array, or the upper triangle of a 2D matrix.

    Args:
        data: 1D array-like (e.g. per-individual scores) or 2D matrix
            (e.g. pairwise similarities/distances). Objects without ``ndim``
            (plain lists, tuples) are converted with ``np.asarray``.
        max_first: If True, sort descending (largest value first); otherwise
            sort ascending.

    Returns:
        1D input: list of ``int`` indices, e.g. ``[5, 2, 9, ...]``.
        2D input: list of ``(row, col)`` tuples of ``int`` from the strict
        upper triangle (diagonal excluded, each pair once),
        e.g. ``[(0, 1), (3, 4), ...]``.

    Raises:
        ValueError: if the data is neither 1D nor 2D.
    """
    if not hasattr(data, "ndim"):
        data = np.asarray(data)

    if data.ndim == 1:
        order = np.argsort(data)
        if max_first:
            order = order[::-1]
        return order.tolist()

    if data.ndim == 2:
        rows, cols = np.triu_indices_from(data, k=1)
        order = np.argsort(data[rows, cols])
        if max_first:
            order = order[::-1]
        # .tolist() yields plain Python ints, matching the 1D branch.
        return list(zip(rows[order].tolist(), cols[order].tolist()))

    raise ValueError("Data must be 1D array or 2D matrix.")

Image helpers and interactive (Bokeh) plots

In [None]:
def embeddable_image(data, scale=1.0):
    """
    Encode an image array as a PNG ``data:`` URL (e.g. for HTML tooltips).

    Accepts HxW (grayscale), HxWx1, HxWx3 (RGB) or HxWx4 (RGBA) arrays.
    Floating-point data whose maximum is <= 1.0 is treated as 0-1 and scaled
    to 0-255; everything else is treated as 0-255. Values are clipped to
    [0, 255] before the uint8 conversion, so out-of-range pixels saturate
    instead of wrapping around.

    Args:
        data: Image array (anything ``np.asarray`` accepts).
        scale: Constant resize factor for both axes, so aspect ratio and
            relative sizes between images are preserved. Each side is kept
            at least 1 px.

    Returns:
        ``'data:image/png;base64,...'`` string.
    """
    arr = np.asarray(data)

    # Normalize to uint8.
    if arr.dtype != np.uint8:
        if np.issubdtype(arr.dtype, np.floating) and arr.max() <= 1.0:
            arr = arr * 255.0
        arr = np.clip(arr.astype(np.float64), 0, 255).astype(np.uint8)

    # PIL has no mode for HxWx1; treat it as grayscale.
    if arr.ndim == 3 and arr.shape[2] == 1:
        arr = arr[:, :, 0]

    # For uint8 data, PIL infers L / RGB / RGBA from the array shape (the
    # explicit `mode=` argument is deprecated in recent Pillow versions).
    img = Image.fromarray(arr)

    # Resize by a constant factor to preserve relative size differences.
    if scale != 1.0:
        new_size = (
            max(1, int(img.width * scale)),
            max(1, int(img.height * scale)),
        )
        img = img.resize(new_size, Image.Resampling.BICUBIC)

    buffer = BytesIO()
    img.save(buffer, format='PNG', optimize=False, compress_level=1)
    return 'data:image/png;base64,' + base64.b64encode(buffer.getvalue()).decode()

def robot_image(i, scale=1.0):
    """
    Render robot ``i`` of the global ``POPULATION`` as a PNG data URL.

    Args:
        i: Index into the notebook-global ``POPULATION`` list.
        scale: Resize factor forwarded to ``embeddable_image``. The default
            1.0 keeps the full-resolution render (high quality for Matplotlib).

    Returns:
        ``'data:image/png;base64,...'`` string.
    """
    graph = POPULATION[i].to_graph()
    # remove_background=True requests a transparent background (transparent PNG).
    img_arr = np.array(view(graph, return_img=True, tilted=True, remove_background=True))

    # Honor the caller's scale (get_population_images forwards it here).
    return embeddable_image(img_arr, scale=scale)

def get_population_images(population_size, scale=1.0):
    """
    Render every robot once, up front, so later plots need not re-render.

    Args:
        population_size: Number of robots to render; indices
            ``0..population_size-1`` of the global ``POPULATION``.
        scale: Resize factor passed to ``robot_image`` for every robot.

    Returns:
        List of PNG data-URL strings; entry ``i`` comes from ``robot_image(i)``.
    """
    progress = track(
        range(population_size),
        description=f"Generating {population_size} images...",
    )
    return [robot_image(idx, scale=scale) for idx in progress]

def decode_base64_image(data_url):
    """
    Decode a base64 image (data URL or bare payload) into a numpy array.

    Args:
        data_url: ``'data:image/png;base64,<payload>'`` as produced by
            ``embeddable_image``, or just the ``<payload>`` part.

    Returns:
        ``np.ndarray`` with the decoded pixels. Shape and channels follow the
        image mode, e.g. HxWx4 for RGBA PNGs.
    """
    # Everything after the last comma is the payload. Base64 never contains
    # commas, and a bare payload (no header) passes through unchanged.
    _, _, encoded = data_url.rpartition(",")
    with Image.open(BytesIO(base64.b64decode(encoded))) as img:
        # np.array forces the pixel data to load before the image is closed.
        return np.array(img)

def create_thumbnails(image_list, scale=0.5):
    """
    Make scaled-down copies of base64 PNG images for web tooltips.

    Both sides are multiplied by the same ``scale``, so aspect ratios and
    relative sizes between robots are preserved. ``image_list`` itself is
    only read, never modified.

    Args:
        image_list: Iterable of ``'data:image/png;base64,...'`` strings (or
            bare base64 payloads).
        scale: Resize factor. Each side is kept at least 1 px, so tiny images
            or tiny factors never produce an invalid 0-px size.

    Returns:
        New list of ``'data:image/png;base64,...'`` thumbnail strings, in the
        same order as ``image_list``.
    """
    thumbnails = []
    for b64_str in image_list:
        # Payload = everything after the last comma (base64 has no commas).
        _, _, encoded = b64_str.rpartition(",")
        with Image.open(BytesIO(base64.b64decode(encoded))) as img:
            new_size = (
                max(1, int(img.width * scale)),
                max(1, int(img.height * scale)),
            )
            # resize() returns a new, fully loaded image, so it outlives `img`.
            img_small = img.resize(new_size, Image.Resampling.BICUBIC)

        buffer = BytesIO()
        img_small.save(buffer, format='PNG')
        thumbnails.append(
            'data:image/png;base64,' + base64.b64encode(buffer.getvalue()).decode()
        )
    return thumbnails


def matrix_to_heatmap_source(matrix, images, metric_name, radius):
    """
    Flatten a 2D matrix into a Bokeh ``ColumnDataSource`` for a rect heatmap.

    The source has one entry per matrix cell ``(i, j)``:
      - ``x = j`` (column index) and ``y = n_rows - 1 - i``, so row 0 is drawn
        at the top, as in a printed matrix;
      - ``value = matrix[i, j]``;
      - ``img_row``/``img_col`` and ``id_row``/``id_col``: image and index
        string for row ``i`` and column ``j`` (used by the hover tooltip);
      - ``metric``/``radius``: constant labels repeated for every cell.

    Square and rectangular matrices are supported.

    Args:
        matrix: 2D array.
        images: Sequence indexable by every row and column index of ``matrix``.
        metric_name: Value stored in the ``metric`` column.
        radius: Value stored in the ``radius`` column.
    """
    matrix = np.asarray(matrix)
    n_rows, n_cols = matrix.shape
    # indexing="ij" -> row_idx[i, j] == i and col_idx[i, j] == j, so the
    # row-major flatten below lines up with matrix.flatten().
    row_idx, col_idx = np.meshgrid(np.arange(n_rows), np.arange(n_cols), indexing="ij")
    row_flat = row_idx.flatten()
    col_flat = col_idx.flatten()
    values = matrix.flatten()
    n_cells = len(values)

    data = {
        'x': col_flat,
        'y': n_rows - 1 - row_flat,
        'value': values,
        'img_row': [images[i] for i in row_flat],
        'img_col': [images[j] for j in col_flat],
        'id_row': [str(i) for i in row_flat],
        'id_col': [str(j) for j in col_flat],
        'metric': [metric_name] * n_cells,
        'radius': [radius] * n_cells,
    }
    return ColumnDataSource(data)

def plot_interactive_heatmaps(
    all_matrix_data: dict,
    population_images: list,
    max_show_radius: int,
    plot_width=None,
    plot_height=None,
    palette="Reds256",
    thumbnail_scale=0.5,
):
    """
    Show a Bokeh grid of interactive heatmaps: one row per radius, one column per metric.

    Hovering over a cell shows the metric value plus thumbnails of the row
    and column individuals.

    Args:
        all_matrix_data: Mapping ``metric name -> {radius: matrix}``. A missing
            radius is drawn as a 1x1 zero matrix.
        population_images: PNG data URLs indexed by matrix row/column. The list
            is not modified; tooltip thumbnails are derived from it.
        max_show_radius: Radii ``0..max_show_radius`` (inclusive) are drawn.
        plot_width: Panel width in px; 650 for a single metric, else 200.
        plot_height: Panel height in px; 600 for a single metric, else 200.
        palette: Bokeh palette for the color mapper.
        thumbnail_scale: Factor to scale images down for the tooltip (default 0.5).
    """
    thumb_images = create_thumbnails(population_images, scale=thumbnail_scale)

    single_metric = len(all_matrix_data) == 1
    if plot_width is None:
        plot_width = 650 if single_metric else 200
    if plot_height is None:
        plot_height = 600 if single_metric else 200

    rows = []
    for radius in range(max_show_radius + 1):
        panels = []
        for metric, per_radius in all_matrix_data.items():
            matrix = per_radius[radius] if radius in per_radius else np.zeros((1, 1))

            # Tooltips use the small thumbnails, not the full-size images.
            source = matrix_to_heatmap_source(matrix, thumb_images, metric, radius)
            low, high = matrix.min(), matrix.max()
            mapper = LinearColorMapper(palette=palette, low=low, high=high)

            p = figure(
                title=f"r:{radius} {metric}",
                x_range=(-0.5, matrix.shape[1] - 0.5),
                y_range=(-0.5, matrix.shape[0] - 0.5),
                width=plot_width,
                height=plot_height,
                tools="hover,save,reset",
                toolbar_location="above",
            )
            p.axis.visible = False
            p.grid.visible = False
            p.rect(
                x='x',
                y='y',
                width=1,
                height=1,
                source=source,
                fill_color={'field': 'value', 'transform': mapper},
                line_color=None,
            )
            p.add_layout(
                ColorBar(
                    color_mapper=mapper,
                    ticker=BasicTicker(),
                    label_standoff=8,
                    border_line_color=None,
                    location=(0, 0),
                    width=8,
                ),
                'right',
            )

            # The hover tool comes from tools="hover,..."; set its HTML.
            # There are no CSS size limits: the thumbnails are already scaled
            # down proportionally.
            hover_tools = p.select({"type": HoverTool})
            hover_tools.tooltips = """
            <div style="display: flex; flex-direction: column; align-items: center; background: white; padding: 5px;">
                <div style="font-weight: bold; margin-bottom: 5px;">@metric (r=@radius) Val: @value{0.000}</div>
                <div style="display: flex; flex-direction: row; gap: 10px;">
                    <div style="text-align: center;"><span style="font-size: 10px;">Row: @id_row</span><br><img src="@img_row" style="width: auto; height: auto;"></div>
                    <div style="text-align: center;"><span style="font-size: 10px;">Col: @id_col</span><br><img src="@img_col" style="width: auto; height: auto;"></div>
                </div>
            </div>
            """
            panels.append(p)
        rows.append(panels)
    show(gridplot(rows))

def plot_interactive_umap_grid(
    umap_data: dict,
    population_images: list,
    max_show_radius: int,
    follow_idx_list: list[int] | None = None,
    plot_width=None,
    plot_height=200,
    thumbnail_scale=0.5,
):
    """
    Show a Bokeh grid of interactive 2D UMAP scatter plots.

    Layout: one row per radius ``0..max_show_radius``, one column per entry of
    ``umap_data``. Each point is one individual; hovering shows its thumbnail
    and index.

    Args:
        umap_data: Mapping ``name -> {radius: embedding}``. An embedding must
            be an ``(n, 2)`` array with one row per individual. Missing radii
            and arrays that are not ``(n, 2)`` are shown as empty, labelled panels.
        population_images: PNG data URLs, one per individual (defines ``n``).
            The list is not modified.
        max_show_radius: Largest radius row to draw (inclusive).
        follow_idx_list: Optional indices to highlight. Highlighted points keep
            their rainbow color, get size 8 and a black outline, and are drawn
            last (on top). All other points become size-3 translucent black dots.
        plot_width: Panel width in px; 700 for a single column, else 200.
        plot_height: Panel height in px.
        thumbnail_scale: Factor to scale images down for the tooltip (default 0.5).
    """
    thumb_images = create_thumbnails(population_images, scale=thumbnail_scale)

    if plot_width is None:
        plot_width = 700 if len(umap_data) == 1 else 200
    n = len(population_images)

    # Per-point styling is the same for every panel, so compute it once.
    # The set is built once, giving O(1) membership tests.
    follow_set = set(follow_idx_list) if follow_idx_list else set()
    rgba_colors = plt.cm.rainbow(np.linspace(0, 1, n))
    if follow_set:
        sizes = [8 if i in follow_set else 3 for i in range(n)]
        line_colors = ['black' if i in follow_set else None for i in range(n)]
        for i in range(n):
            if i not in follow_set:
                rgba_colors[i] = [0.0, 0.0, 0.0, 0.3]
        # 1 = followed. An ascending sort draws followed points last (on top).
        sort_order = [1 if i in follow_set else 0 for i in range(n)]
    else:
        sizes = [4] * n
        line_colors = [None] * n
        sort_order = None
    hex_colors = [to_hex(c, keep_alpha=True) for c in rgba_colors]
    point_ids = [str(i) for i in range(n)]

    hover_html = """<div><img src='@image' style='float:left; margin:5px; width:auto; height:auto;'/></div><div style="font-size:12px; font-weight: bold;"><span style='color:#224499'>ID: @digit</span></div>"""

    grid_layout = []
    for r in range(max_show_radius + 1):
        row_plots = []
        for name, matrix_dict in umap_data.items():
            if r not in matrix_dict:
                row_plots.append(figure(title=f"r:{r} {name} (Missing)", width=plot_width, height=plot_height))
                continue
            emb = matrix_dict[r]
            if emb.ndim != 2 or emb.shape[1] != 2:
                row_plots.append(figure(title=f"r:{r} {name} (No Data)", width=plot_width, height=plot_height))
                continue

            robots_df = pd.DataFrame({
                "x": emb[:, 0],
                "y": emb[:, 1],
                "digit": point_ids,
                "image": thumb_images,
                "color": hex_colors,
                "size": sizes,
                "line_color": line_colors,
            })
            if sort_order is not None:
                robots_df['sort_order'] = sort_order
                # Stable sort keeps the original order within each group.
                robots_df = robots_df.sort_values('sort_order', ascending=True, kind="stable")

            source = ColumnDataSource(robots_df)
            p = figure(
                title=f"r:{r} {name}",
                width=plot_width,
                height=plot_height,
                tools="pan,wheel_zoom,reset,save",
                toolbar_location="above",
            )
            p.scatter('x', 'y', source=source, color='color', line_alpha=1, line_color='line_color', line_width=1, size='size')
            # One HoverTool per plot: each plot gets its own tool model.
            p.add_tools(HoverTool(tooltips=hover_html))
            row_plots.append(p)
        grid_layout.append(row_plots)
    show(gridplot(grid_layout))

Overview plotting (static Matplotlib image grids)

In [None]:
def _stitch_images_horizontally(images, target_height=None, gap_px=20):
    """
    Stitches images horizontally with white background (handles transparency).
    If target_height is provided, pads all images vertically to match that height (alignment: top).
    """
    if not images or all(img is None for img in images):
        return None, 0, 0
    
    valid_images = [img for img in images if img is not None]
    if not valid_images: return None, 0, 0

    # 1. Normalize Types to uint8 RGB and composite transparent images onto white
    normalized_imgs = []
    for img in valid_images:
        # Handle float 0-1
        if np.issubdtype(img.dtype, np.floating):
            img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
        else:
            img = img.astype(np.uint8)
        
        # Handle Alpha channel - composite onto white background
        if len(img.shape) == 3 and img.shape[2] == 4:
            # Extract RGB and Alpha
            rgb = img[:, :, :3]
            alpha = img[:, :, 3:4] / 255.0  # Normalize alpha to 0-1
            
            # Create white background
            white_bg = np.full_like(rgb, 255, dtype=np.uint8)
            
            # Composite: result = foreground * alpha + background * (1 - alpha)
            img = (rgb * alpha + white_bg * (1 - alpha)).astype(np.uint8)
        elif len(img.shape) == 3 and img.shape[2] >= 3:
            img = img[:, :, :3]
            
        normalized_imgs.append(img)

    # 2. Determine Canvas Height
    current_max_h = max(img.shape[0] for img in normalized_imgs)
    final_h = target_height if target_height and target_height > current_max_h else current_max_h

    # 3. Create White Gap Column
    white_gap_col = np.full((final_h, gap_px, 3), 255, dtype=np.uint8)

    # 4. Stitching Loop
    stitched = None
    
    for i, img in enumerate(normalized_imgs):
        h, w = img.shape[:2]
        
        # Pad image to final_h (fill bottom with white)
        if h < final_h:
            pad = np.full((final_h - h, w, 3), 255, dtype=np.uint8)
            img = np.vstack((img, pad))
            
        if stitched is None:
            stitched = img
        else:
            stitched = np.hstack((stitched, white_gap_col, img))
            
    return stitched, stitched.shape[1], final_h

def view_grid_of_groups(rows_of_tuples, rows_of_titles=None, col_headers=None, main_title=None):
    """
    Plot a grid of image groups that all share ONE pixel scale (no per-cell auto-zoom).

    Each cell holds a group of base64 images, stitched side by side. Every
    cell gets the same axis limits (tallest image overall x widest stitched
    group) and an equal aspect ratio, so sizes are comparable across cells.

    Args:
        rows_of_tuples: ``rows_of_tuples[r][c]`` is a sequence of base64 image
            strings for cell ``(r, c)``. The column count comes from the first
            row, and every row must have at least that many cells.
        rows_of_titles: Optional ``rows_of_titles[r][c]`` per cell; missing
            entries are skipped.
        col_headers: Optional header text drawn above each column of row 0.
        main_title: Optional figure title.
    """
    if not rows_of_tuples or not rows_of_tuples[0]:
        # Nothing to draw, and plt.subplots cannot build a 0-column grid.
        return

    n_rows = len(rows_of_tuples)
    n_cols = len(rows_of_tuples[0])
    ROBOT_GAP_PX = 20

    # --- PASS 1: decode every image and find the global max height ---
    grid_data = [
        [[decode_base64_image(s) for s in rows_of_tuples[r][c]] for c in range(n_cols)]
        for r in range(n_rows)
    ]
    global_max_h = max(
        (img.shape[0] for row in grid_data for group in row for img in group if img is not None),
        default=0,
    )

    # --- PASS 2: stitch each group to the global height and find the max width ---
    processed_images = []
    global_max_w = 0
    for row in grid_data:
        row_imgs = []
        for images in row:
            stitched, width, _ = _stitch_images_horizontally(
                images, target_height=global_max_h, gap_px=ROBOT_GAP_PX
            )
            global_max_w = max(global_max_w, width)
            row_imgs.append(stitched)
        processed_images.append(row_imgs)

    # Avoid identical lower/upper axis limits when every cell is empty.
    x_max = max(global_max_w, 1)
    y_max = max(global_max_h, 1)

    # --- PASS 3: plot with fixed, shared limits ---
    fig, axes = plt.subplots(n_rows, n_cols,
                             figsize=(4 * n_cols, 2.5 * n_rows),
                             squeeze=False,
                             facecolor='white')

    if main_title:
        fig.suptitle(main_title, fontsize=16, weight="bold", y=0.98, color='black')

    for r in range(n_rows):
        for c in range(n_cols):
            ax = axes[r, c]
            ax.set_facecolor('white')

            img_data = processed_images[r][c]
            if img_data is not None:
                ax.imshow(img_data)

            # Same limits + equal aspect in every cell => equal scaling.
            ax.set_xlim(0, x_max)
            ax.set_ylim(y_max, 0)
            ax.set_aspect('equal')
            ax.axis('off')

            if rows_of_titles and r < len(rows_of_titles) and c < len(rows_of_titles[r]):
                ax.set_title(rows_of_titles[r][c], fontsize=10, color='black', pad=0)

            if r == 0 and col_headers and c < len(col_headers):
                ax.text(0.5, 1, col_headers[c], transform=ax.transAxes,
                        ha="center", va="bottom", fontsize=12, weight="bold", color="#224499")

    plt.subplots_adjust(wspace=0.1, hspace=0.25, top=0.95, bottom=0.02)
    plt.show()

# def plot_rows_for_radii(cumulative_data, sorted_data, population_images, max_radius: int, pair_rank: int = 0, labels: list[str] = None, main_title: str = None):
#     if labels is None: labels = list(cumulative_data.keys())
#     all_rows_robots = []
#     all_rows_titles = []
    
#     for r in range(max_radius + 1):
#         robots_row = []
#         titles_row = []
#         for name in labels:
#             coords_list = sorted_data[name].get(r, [])
#             matrix = cumulative_data[name].get(r)
#             if coords_list and pair_rank < len(coords_list):
#                 i, j = coords_list[pair_rank]
#             else:
#                 i, j = (0, 0)
#             idx_i, idx_j = int(i), int(j)
#             val = matrix[idx_i, idx_j] if matrix is not None else 0.0
            
#             robots_row.append([population_images[idx_i], population_images[idx_j]])
#             titles_row.append(f"{name}\nr:{r} <{idx_i},{idx_j}> val={val:.3f}")
            
#         all_rows_robots.append(robots_row)
#         all_rows_titles.append(titles_row)

#     view_grid_of_groups(all_rows_robots, all_rows_titles, col_headers=None, main_title=main_title)

In [None]:
def _select_ranked_items(coords_list, pair_rank, plot_up_to):
    """Pick which entries of one ranked list to show (see plot_rows_for_radii)."""
    n_items = len(coords_list)
    if n_items == 0:
        return []
    if plot_up_to:
        if pair_rank >= 0:
            # Top N: ranks 0..pair_rank, clamped to the list length.
            return coords_list[:min(pair_rank + 1, n_items)]
        # Bottom N: ranks pair_rank..-1, clamped (e.g. -99 on a list of 10).
        return coords_list[max(-n_items, pair_rank):]
    # Single item. Negative ranks count from the end; out of range -> nothing.
    if -n_items <= pair_rank < n_items:
        return [coords_list[pair_rank]]
    return []


def _describe_item(item, matrix, population_images):
    """Return ``(images, label)`` for one ranked entry (a pair or a single index)."""
    # Pair (i, j): both robots plus matrix[i, j].
    if isinstance(item, (list, tuple, np.ndarray)) and len(item) == 2:
        idx_i, idx_j = int(item[0]), int(item[1])
        val = 0.0
        if matrix is not None:
            try:
                val = matrix[idx_i, idx_j]
            except IndexError:
                pass  # best-effort label: keep 0.0 if the index does not fit
        return [population_images[idx_i], population_images[idx_j]], f"<{idx_i},{idx_j}>={val:.2f}"

    # Single index: a plain int/numpy scalar, or a 1-element container.
    idx = int(item) if np.isscalar(item) else int(item[0])
    val = 0.0
    if matrix is not None:
        try:
            if matrix.ndim == 1:
                val = matrix[idx]
            elif matrix.ndim == 2:
                # One index into a 2D matrix: show the diagonal entry.
                val = matrix[idx, idx]
        except IndexError:
            pass  # best-effort label: keep 0.0 if the index does not fit
    return [population_images[idx]], f"#{idx}={val:.2f}"


def plot_rows_for_radii(cumulative_data, sorted_data, population_images, max_radius: int,
                        pair_rank: int = 0, plot_up_to: bool = False,
                        labels: list[str] | None = None, main_title: str | None = None):
    """
    Show the robots behind selected ranked entries, one grid cell per radius and metric.

    Layout: one row per radius ``0..max_radius``, one column per label. For
    each cell, the ranked list ``sorted_data[label][radius]`` (e.g. the output
    of ``sorted_idx_dict``: ``(i, j)`` pairs or single indices) is sliced by
    ``pair_rank``/``plot_up_to``. The matching robot images are stitched into
    one group, and the title shows values read from
    ``cumulative_data[label][radius]``.

    Args:
        cumulative_data: ``label -> {radius: 2D matrix or 1D array}`` used for
            the displayed values (0.0 when missing or out of range).
        sorted_data: ``label -> {radius: ranked list}``.
        population_images: base64 PNG per individual, indexed by the ranks.
        max_radius: Largest radius row to draw (inclusive).
        pair_rank: Rank to show (0 = first entry); negative counts from the end.
        plot_up_to: If True, show ranks ``0..pair_rank`` (or the last
            ``-pair_rank`` entries for a negative rank) instead of one entry.
        labels: Metric keys, one per column; defaults to all keys of ``cumulative_data``.
        main_title: Optional figure title.
    """
    if labels is None:
        labels = list(cumulative_data.keys())

    all_rows_robots = []
    all_rows_titles = []
    for r in range(max_radius + 1):
        robots_row = []
        titles_row = []
        for name in labels:
            items = _select_ranked_items(sorted_data[name].get(r, []), pair_rank, plot_up_to)
            matrix = cumulative_data[name].get(r)

            group_images = []
            item_labels = []
            for item in items:
                imgs, label = _describe_item(item, matrix, population_images)
                group_images.extend(imgs)
                item_labels.append(label)

            robots_row.append(group_images)

            # Keep the title short when plot_up_to selected many items.
            full_title_str = " | ".join(item_labels)
            if len(full_title_str) > 50:
                full_title_str = full_title_str[:47] + "..."
            titles_row.append(f"{name} (r:{r})\n{full_title_str}")

        all_rows_robots.append(robots_row)
        all_rows_titles.append(titles_row)

    view_grid_of_groups(all_rows_robots, all_rows_titles, col_headers=None, main_title=main_title)

In [None]:
def plot_histograms(data_dict, max_radius=2, bins=100, size_per_plot=4):
    """
    Plot a grid of histograms (rows = radius, columns = metric) with square subplots.

    Parameters:
    - data_dict: ``metric -> {radius: matrix or array}``. For 2D entries only
      the strict upper triangle is used (diagonal excluded, each pair once);
      other entries are flattened. A missing radius leaves a blank subplot.
    - max_radius: Largest radius to plot; rows are ``0..max_radius``.
    - bins: Number of histogram bins.
    - size_per_plot: Width/height in inches of each subplot.
    """
    metric_keys = list(data_dict.keys())
    n_cols = len(metric_keys)
    n_rows = max_radius + 1
    if n_cols == 0 or n_rows < 1:
        # plt.subplots cannot build a grid with zero rows or columns.
        return

    # Figure size scales with the grid so each subplot stays square.
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(n_cols * size_per_plot, n_rows * size_per_plot),
        constrained_layout=True,
        squeeze=False,  # axes is always 2D, even for a single row/column
    )

    for col_idx, metric_name in enumerate(metric_keys):
        radius_dict = data_dict[metric_name]

        for r in range(n_rows):
            ax = axes[r, col_idx]

            if r not in radius_dict:
                ax.axis('off')
                continue

            raw_data = radius_dict[r]

            # --- DATA PREPROCESSING ---
            if raw_data.ndim == 2:
                # Pairwise matrix: strict upper triangle (k=1 excludes the diagonal).
                vals = raw_data[np.triu_indices_from(raw_data, k=1)]
            else:
                vals = raw_data.flatten()

            # --- PLOTTING ---
            sns.histplot(vals, bins=bins, kde=True, ax=ax,
                         color=f"C{col_idx}", edgecolor='w', linewidth=0.5)

            # --- STATS ANNOTATION ---
            if len(vals) > 0:
                # "$\\mu$" is the mathtext mu. The backslash is escaped so Python
                # does not warn about an invalid "\m" escape sequence.
                stats_text = (f"$\\mu$: {np.mean(vals):.2f}\n"
                              f"Min: {np.min(vals):.2f}\n"
                              f"Max: {np.max(vals):.2f}")

                ax.text(0.95, 0.95, stats_text, transform=ax.transAxes,
                        fontsize=10, verticalalignment='top', horizontalalignment='right',
                        bbox=dict(boxstyle='round', facecolor='white', alpha=0.9))

            # --- LABELS & TITLES ---
            if r == 0:
                ax.set_title(metric_name, fontsize=14, fontweight='bold', pad=15)

            if col_idx == 0:
                ax.set_ylabel(f"Radius {r}\nCount", fontsize=12, fontweight='bold')
            else:
                ax.set_ylabel("")

            if r == n_rows - 1:
                ax.set_xlabel("Value", fontsize=11)
            else:
                ax.set_xlabel("")

    fig.suptitle(f"Distribution of Values per Radius (0 to {max_radius})", fontsize=18, y=1.02, fontweight='bold')
    plt.show()

---

### GLOBAL ANALYSIS SETTINGS

In [None]:
# Number of random individuals generated for the analysis.
POPULATION_SIZE = 500
# Module count passed to generate_random_individual for each individual.
NUM_OF_MODULES = 20

# Passed as max_tree_radius. NOTE(review): None presumably means "no radius
# cap"; confirm in ctk.SimilarityConfig.
MAX_RADIUS = None

# Configuration used for subtree collection and count-matrix construction.
CONFIG = ctk.SimilarityConfig(
    max_tree_radius=MAX_RADIUS, radius_strategy=ctk.RadiusStrategy.NODE_LOCAL
)

In [None]:
# POPULATION_SIZE random individuals, each generated with NUM_OF_MODULES
# modules and converted via ctk.from_graph.
POPULATION = [
    ctk.from_graph(generate_random_individual(NUM_OF_MODULES))
    for _ in range(POPULATION_SIZE)
]

# Subtree hashes per individual (same order as POPULATION), collected with CONFIG.
SUBTREES = [
    ctk.collect_tree_hash_config_mode(individual, config=CONFIG)
    for individual in POPULATION
]

# Count matrices over all individuals' subtrees. Later cells treat this as a
# {radius: matrix} dict. NOTE(review): rows presumably correspond to
# individuals in POPULATION order; confirm in ctk.get_count_matrix.
COUNT_MATRIX_DICT = ctk.get_count_matrix(SUBTREES, CONFIG)

The most time-consuming step: rendering an image of every robot.

In [None]:
# Render one full-resolution (scale=1) PNG data URL per individual, once, so
# later plots do not need to re-render.
POPULATION_IMGS = get_population_images(POPULATION_SIZE, scale=1)

#### Matrix Helpers

In [None]:
def apply_tfidf_transformer(count_matrix):
    """Fit a fresh ``TfidfTransformer`` on *count_matrix* and return its TF-IDF-weighted version."""
    return TfidfTransformer().fit_transform(count_matrix)

def _apply_umap(count_matrix, n_neighbors):
    """
    Embed the rows of *count_matrix* with cosine-metric UMAP.

    All other settings are fixed (random init, random_state and
    transform_seed 42, single job) so the three variants below differ only
    in ``n_neighbors`` and results are reproducible.
    """
    reducer = umap.UMAP(
        init='random',
        random_state=42,
        transform_seed=42,
        n_jobs=1,
        metric="cosine",
        n_neighbors=n_neighbors,
    )
    return reducer.fit_transform(count_matrix)


def apply_umap_n2(count_matrix):
    """UMAP embedding of *count_matrix* with ``n_neighbors=2``."""
    return _apply_umap(count_matrix, n_neighbors=2)


def apply_umap_n10(count_matrix):
    """UMAP embedding of *count_matrix* with ``n_neighbors=10``."""
    return _apply_umap(count_matrix, n_neighbors=10)


def apply_umap_n20(count_matrix):
    """UMAP embedding of *count_matrix* with ``n_neighbors=20``."""
    return _apply_umap(count_matrix, n_neighbors=20)

def apply_emb_to_dist(umap_emb_matrix):
    """Return the full, symmetric matrix of pairwise Euclidean distances between embedding rows."""
    return squareform(pdist(umap_emb_matrix, metric="euclidean"))

def apply_collapse_to_fitness(matrix):
    """Collapse a square pairwise matrix to one score per row.

    Each row's score is the sum of its entries minus its diagonal entry, i.e.
    the total over all *other* individuals. Returns a 1-D array of length n.

    Both operands are flattened to 1-D first. For SciPy sparse matrices (and
    ``np.matrix``), ``.sum(axis=1)`` returns an ``(n, 1)`` matrix, which would
    otherwise broadcast against the diagonal into a wrong ``(n, n)`` result.
    """
    row_totals = np.asarray(matrix.sum(axis=1)).ravel()
    self_scores = np.asarray(matrix.diagonal()).ravel()
    return row_totals - self_scores

In [None]:
# basis ---
# Each ctk.matrix_dict_applier call applies one function to every entry of a
# matrix dict (presumably one matrix per radius — see "metrics applied per
# unique radius" below).

# Cosine similarity between individuals on the raw subtree counts.
count_cos_matrix_dict = ctk.matrix_dict_applier(COUNT_MATRIX_DICT, cosine_similarity)

# TF-IDF reweighting of the counts, then cosine similarity on the weighted rows.
tfidf_matrix_dict = ctk.matrix_dict_applier(COUNT_MATRIX_DICT, apply_tfidf_transformer)
tfidf_cos_matrix_dict = ctk.matrix_dict_applier(tfidf_matrix_dict, cosine_similarity)


# umap ---
# For each n_neighbors setting: embed with UMAP (cosine metric), then turn the
# embedding into a pairwise Euclidean distance matrix. These are the slow steps.

# n_neighbors=2, raw counts
umap_dict_n2 = ctk.matrix_dict_applier(COUNT_MATRIX_DICT, apply_umap_n2) 
umapdist_matrix_dict_n2 = ctk.matrix_dict_applier(umap_dict_n2, apply_emb_to_dist)

# n_neighbors=2, TF-IDF-weighted counts
tfidf_umap_dict_n2 = ctk.matrix_dict_applier(tfidf_matrix_dict, apply_umap_n2) 
tfidf_umapdist_matrix_dict_n2 = ctk.matrix_dict_applier(tfidf_umap_dict_n2, apply_emb_to_dist)

# n_neighbors=10, raw counts
umap_dict_n10 = ctk.matrix_dict_applier(COUNT_MATRIX_DICT, apply_umap_n10) 
umapdist_matrix_dict_n10 = ctk.matrix_dict_applier(umap_dict_n10, apply_emb_to_dist)

# n_neighbors=10, TF-IDF-weighted counts
tfidf_umap_dict_n10 = ctk.matrix_dict_applier(tfidf_matrix_dict, apply_umap_n10) 
tfidf_umapdist_matrix_dict_n10 = ctk.matrix_dict_applier(tfidf_umap_dict_n10, apply_emb_to_dist)

# n_neighbors=20, raw counts only (no TF-IDF variant is computed for n20)
umap_dict_n20 = ctk.matrix_dict_applier(COUNT_MATRIX_DICT, apply_umap_n20) 
umapdist_matrix_dict_n20 = ctk.matrix_dict_applier(umap_dict_n20, apply_emb_to_dist)


In [None]:
# Raw UMAP embeddings, keyed by label, for the dimension-reduction plots.
UMAP_EMBEDDINGS = dict(
    umap_emb_n2=umap_dict_n2,
    tfidf_umap_emb_n2=tfidf_umap_dict_n2,
    umap_emb_n10=umap_dict_n10,
    tfidf_umap_emb_n10=tfidf_umap_dict_n10,
    umap_emb_n20=umap_dict_n20,
)

#### Matrix-data

In [None]:
# metrics applied per unique radius.
# "..._dist_..." entries are distances; "..._cos" entries are similarities.
ALL_MATRIX_DATA = {
    "umap_dist_n2": umapdist_matrix_dict_n2,
    "tfidf_umap_dist_n2": tfidf_umapdist_matrix_dict_n2,
    "umap_dist_n10": umapdist_matrix_dict_n10,
    "tfidf_umap_dist_n10": tfidf_umapdist_matrix_dict_n10,
    "umap_dist_n20": umapdist_matrix_dict_n20,
    "count_cos": count_cos_matrix_dict,
    "tfidf_cos": tfidf_cos_matrix_dict,
}

# Each radius contains the cumsum of the previous radii (same keys and order).
CUMULATIVE_MATRIX_DATA = {
    metric: get_cumsum_dict(per_radius)
    for metric, per_radius in ALL_MATRIX_DATA.items()
}

# Sort on the cumsums: similarity metrics ("_cos") with max_first=True,
# distance metrics with max_first=False. Presumably this puts the most similar
# pairs first for every metric, matching the "Most Similar" viewer using rank 0.
SORTED_IDX_DATA = {
    metric: sorted_idx_dict(cumulative, max_first=metric.endswith("_cos"))
    for metric, cumulative in CUMULATIVE_MATRIX_DATA.items()
}

MAX_SHOW_RADIUS = 7

#### Array Data [fitness]

TODO: try an entropy-based fitness.
TODO: try a KDE (inverse-density) fitness on the UMAP embeddings.

In [None]:
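# TODO: inverse-density (KDE) fitness on the UMAP embeddings — draft below.
# Enabling it requires `from sklearn.neighbors import KernelDensity`.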

# def calculate_inverse_density_fitness(umap_embeddings, bandwidth=0.1):
#     """
#     Calculates fitness based on the inverse of local density.
#     Higher Fitness = Sparser Region = More Unique.
#     """
#     # 1. Initialize KDE Model
#     # bandwidth is crucial: smaller = more localized density, larger = smoother.
#     kde = KernelDensity(kernel='gaussian', bandwidth=bandwidth)
    
#     # 2. Fit the model to the 2D/3D embeddings
#     kde.fit(umap_embeddings)
    
#     # 3. Score samples to get the log-probability density
#     # log_density is usually preferred for numerical stability
#     log_density = kde.score_samples(umap_embeddings)
    
#     # 4. Convert Log-Density to Density
#     density = np.exp(log_density)
    
#     # 5. Calculate Inverse Density Fitness
#     # Add a tiny epsilon to prevent division by zero
#     fitness_inverse_density = 1.0 / (density + 1e-6)
    
#     return fitness_inverse_density

In [None]:
# fitness (1d collapsed data) ---
# Every pairwise matrix is collapsed to one score per individual with
# apply_collapse_to_fitness (row sum excluding the diagonal).
FITNESS_DATA = {
    name: ctk.matrix_dict_applier(matrix_dict, apply_collapse_to_fitness)
    for name, matrix_dict in (
        ('umapdist_fitness_n2', umapdist_matrix_dict_n2),
        ('tfidf_umapdist_fitness_n2', tfidf_umapdist_matrix_dict_n2),
        ('umapdist_fitness_n10', umapdist_matrix_dict_n10),
        ('tfidf_umapdist_fitness_n10', tfidf_umapdist_matrix_dict_n10),
        ('count_cos_fitness', count_cos_matrix_dict),
        ('tfidf_cos_fitness', tfidf_cos_matrix_dict),
    )
}

# Running total over radii, same keys and order as FITNESS_DATA.
CUMULATIVE_FITNESS_DATA = {
    name: get_cumsum_dict(per_radius)
    for name, per_radius in FITNESS_DATA.items()
}

# Cosine-based fitnesses use max_first=True; distance-based ones use max_first=False.
# The key order here intentionally differs from FITNESS_DATA (kept as before).
# NOTE(review): unlike SORTED_IDX_DATA (sorted on cumulative values), this is
# sorted on the per-radius FITNESS_DATA, yet the viewer cells pass it together
# with CUMULATIVE_FITNESS_DATA — confirm which ordering is intended.
SORTED_FITNESS_IDX = {
    name: sorted_idx_dict(FITNESS_DATA[name], max_first=name.endswith('_cos_fitness'))
    for name in (
        'umapdist_fitness_n2',
        'umapdist_fitness_n10',
        'count_cos_fitness',
        'tfidf_cos_fitness',
        'tfidf_umapdist_fitness_n2',
        'tfidf_umapdist_fitness_n10',
    )
}

---

### HEATMAPS

#### Matrix per radius

In [None]:
# Per-radius matrices for every metric; MAX_SHOW_RADIUS presumably caps the radii drawn.
plot_comparison_heatmaps(ALL_MATRIX_DATA, MAX_SHOW_RADIUS)

#### Cumulative matrix per radius

In [None]:
# Same view on the cumulative (summed over radii) matrices.
plot_comparison_heatmaps(CUMULATIVE_MATRIX_DATA, MAX_SHOW_RADIUS)

Interactive Heatmaps

In [None]:
# Interactive heatmaps are only drawn for populations of fewer than 20
# individuals (presumably because they embed POPULATION_IMGS — TODO confirm).
if POPULATION_SIZE < 20:
    plot_interactive_heatmaps(ALL_MATRIX_DATA, POPULATION_IMGS, MAX_SHOW_RADIUS)
    plot_interactive_heatmaps(CUMULATIVE_MATRIX_DATA, POPULATION_IMGS, MAX_SHOW_RADIUS)

### HISTOGRAMS

In [None]:
# Value distributions of the per-radius matrices for every metric.
plot_histograms(ALL_MATRIX_DATA, MAX_SHOW_RADIUS)

In [None]:
# Value distributions of the cumulative matrices.
plot_histograms(CUMULATIVE_MATRIX_DATA, MAX_SHOW_RADIUS)

### 2D EMBEDDINGS

In [None]:
# Optionally highlight ("follow") specific individuals across the UMAP grid,
# e.g. ten evenly spaced indices:
# idxs = np.linspace(0, POPULATION_SIZE - 1, 10, dtype=int).tolist() # if you want to follow?
idxs = None
console.print(f'following {idxs}')
if idxs is None:
    plot_interactive_umap_grid(UMAP_EMBEDDINGS, POPULATION_IMGS, MAX_SHOW_RADIUS)
else:
    # Pass the selection through, as the other UMAP-grid cells do; otherwise
    # setting idxs would only change the printed message.
    plot_interactive_umap_grid(
        UMAP_EMBEDDINGS, POPULATION_IMGS, MAX_SHOW_RADIUS, follow_idx_list=idxs
    )

### DYNAMIC ROBOT POSTER VIEWER

#### Most Similar Per Radius per metric

In [None]:
# Starting rank for the "Most Similar" viewer; that cell increments it after
# each run, so re-running it steps through successive ranks.
similar_rank = 0

In [None]:
# Show the entry at `similar_rank` of SORTED_IDX_DATA (rank 0 = most similar)
# for every metric and radius; re-running steps to the next rank.
plot_rows_for_radii(
    cumulative_data=CUMULATIVE_MATRIX_DATA,
    sorted_data=SORTED_IDX_DATA,
    population_images=POPULATION_IMGS,
    max_radius=MAX_SHOW_RADIUS,
    pair_rank=similar_rank,
    # 1-based rank in the title, matching the "Least Similar" cell
    # (index -1 -> "Rank 1") and the fitness viewers (+ 1).
    main_title=f"Comparison Rank {similar_rank + 1} (Most Similar)",
)

similar_rank += 1

#### Least Similar Per Radius per metric

In [None]:
# Starting rank for the "Least Similar" viewer: counts back from the end
# (-1 = last entry); that cell decrements it after each run.
dif_rank = -1

In [None]:
# Same viewer from the other end of SORTED_IDX_DATA: negative ranks presumably
# index from the last (least similar) entry; re-running steps to -2, -3, ...
plot_rows_for_radii(
    cumulative_data=CUMULATIVE_MATRIX_DATA,
    sorted_data=SORTED_IDX_DATA,
    population_images=POPULATION_IMGS,
    max_radius=MAX_SHOW_RADIUS,
    pair_rank=dif_rank,
    main_title=f"Comparison Rank {np.abs(dif_rank)} (Least Similar)",
    # plot_up_to=True 
)

dif_rank -= 1

---

### FITNESS

In [None]:
# Per-radius fitness distributions: roughly one bin per three individuals, but
# at least one bin so populations smaller than 3 don't request bins=0.
plot_histograms(FITNESS_DATA, MAX_SHOW_RADIUS, bins=max(1, POPULATION_SIZE // 3))

In [None]:
# Cumulative fitness distributions; at least one bin so populations smaller
# than 3 don't request bins=0.
plot_histograms(CUMULATIVE_FITNESS_DATA, MAX_SHOW_RADIUS, bins=max(1, POPULATION_SIZE // 3))

Observation: TF-IDF shows notably clear clustering at radius 3.

In [None]:
# Follow the first three individuals of the radius-4 count-cosine fitness
# ordering (sorted with max_first=True).
# NOTE(review): radius 4 is hard-coded and will presumably fail if fewer radii exist.
idxs = SORTED_FITNESS_IDX['count_cos_fitness'][4][:3]
console.print(f'following {idxs}')
plot_interactive_umap_grid(UMAP_EMBEDDINGS, POPULATION_IMGS, MAX_SHOW_RADIUS, follow_idx_list=idxs)

In [None]:
# Follow the last five individuals of the same radius-4 count-cosine ordering.
idxs = SORTED_FITNESS_IDX['count_cos_fitness'][4][-5:]
console.print(f'following {idxs}')
plot_interactive_umap_grid(UMAP_EMBEDDINGS, POPULATION_IMGS, MAX_SHOW_RADIUS, follow_idx_list=idxs)

In [None]:
# Starting rank for the "Top ... Diverse" viewer; that cell increments it after each run.
fit_rank = 0

In [None]:
# Fitness viewer from the start of SORTED_FITNESS_IDX (with plot_up_to=True);
# the rank advances after each run.
# NOTE(review): distance fitnesses are sorted with max_first=False and cosine
# fitnesses with max_first=True. If that puts the smallest total distance /
# largest total similarity at index 0, then rank 0 is the *least* diverse
# individual, not the top one. The "Least ... Diverse" cell below also starts
# at rank 0. Confirm which end each cell should start from.
plot_rows_for_radii(
    cumulative_data=CUMULATIVE_FITNESS_DATA,
    sorted_data=SORTED_FITNESS_IDX,
    population_images=POPULATION_IMGS,
    max_radius=MAX_SHOW_RADIUS,
    pair_rank=fit_rank,
    main_title=f"Top {np.abs(fit_rank) + 1} Diverse",
    plot_up_to=True 
)

fit_rank += 1

In [None]:
# Starting rank for the "Least ... Diverse" viewer, decremented after each run.
# NOTE(review): starts at 0 like fit_rank, so the first run of both viewers
# requests the same position; dif_rank, by contrast, starts at -1. Confirm.
unfit_rank = 0

In [None]:
# Fitness viewer stepping backwards through SORTED_FITNESS_IDX (0, -1, -2, ...)
# with plot_up_to=True.
plot_rows_for_radii(
    cumulative_data=CUMULATIVE_FITNESS_DATA,
    sorted_data=SORTED_FITNESS_IDX,
    population_images=POPULATION_IMGS,
    max_radius=MAX_SHOW_RADIUS,
    pair_rank=unfit_rank,
    main_title=f"Least {np.abs(unfit_rank) + 1} Diverse",
    plot_up_to=True 
)

unfit_rank -= 1

---

### Ad-hoc checks

In [None]:
# Follow two hand-picked individuals (hard-coded; requires POPULATION_SIZE > 32).
idxs = [11, 32]
console.print(f'following {idxs}')
plot_interactive_umap_grid(UMAP_EMBEDDINGS, POPULATION_IMGS, MAX_SHOW_RADIUS, follow_idx_list=idxs)

In [None]:
# Print rows of COUNT_MATRIX_DICT[0] (presumably the radius-0 count matrix) and
# the raw subtrees for hand-picked individuals; the indices are hard-coded.
console.print(COUNT_MATRIX_DICT[0][11])
console.print(COUNT_MATRIX_DICT[0][32])
console.print(COUNT_MATRIX_DICT[0][84])
console.print(COUNT_MATRIX_DICT[0][15])


console.print(SUBTREES[11])
console.print(SUBTREES[32])