In [1]:
import base64
from io import BytesIO
from typing import Any

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import umap
import umap.plot
from bokeh.io import output_notebook
from bokeh.layouts import gridplot
from bokeh.models import (
    BasicTicker,
    ColorBar,
    ColumnDataSource,
    HoverTool,
    LinearColorMapper,
)
from bokeh.plotting import figure, show
from matplotlib.colors import to_hex
from PIL import Image
from rich.console import Console
from rich.progress import track
from scipy.spatial.distance import pdist, squareform
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.metrics.pairwise import cosine_similarity

from ariel_experiments.characterize.canonical.core.toolkit import (
    CanonicalToolKit as ctk,
)
from ariel_experiments.gui_vis.view_mujoco import view
from ariel_experiments.utils.initialize import generate_random_individual

# Rich console instance for formatted output (not referenced in the visible cells).
console = Console()
# Direct Bokeh `show(...)` output into this notebook rather than a separate HTML file/tab.
output_notebook()



In [2]:
def _local_color_range(matrix):
    """Return (vmin, vmax) over the finite entries of `matrix`, or (None, None) if none are finite."""
    values = np.asarray(matrix, dtype=float)
    # NaN/inf would otherwise propagate into vmin/vmax and break the colormap.
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        # Let seaborn fall back to its own autoscaling.
        return None, None
    return float(finite.min()), float(finite.max())


def plot_heatmap_row(
    matrices: list[np.ndarray],
    titles: list[str] = None,
    suptitle: str = None,
    figsize_per_plot: tuple[int, int] = (5, 6),
    cmap: str = "viridis",
):
    """
    Plot a horizontal row of heatmaps, each with its own local color scale.

    Every panel is colored between the minimum and maximum of its own finite
    values (NaN/inf are ignored), which maximizes per-panel contrast. As a
    consequence, colors are NOT comparable across panels.

    Args:
        matrices: 2D arrays (or array-likes accepted by ``sns.heatmap``), one
            per column. If empty, nothing is drawn and the function returns.
            Zero-size matrices leave their panel blank.
        titles: Optional per-subplot titles (``None`` for none); panels beyond
            ``len(titles)`` stay untitled.
        suptitle: Optional overall figure title (``None`` for none).
        figsize_per_plot: (width, height) in inches of each subplot; the figure
            width scales with the number of matrices.
        cmap: Colormap name passed to seaborn.
    """
    if not matrices:
        # plt.subplots(ncols=0) raises; an empty row has nothing to show.
        return

    num_plots = len(matrices)
    fig, axes = plt.subplots(
        nrows=1,
        ncols=num_plots,
        figsize=(figsize_per_plot[0] * num_plots, figsize_per_plot[1]),
        squeeze=False,  # always get a 2D array so ravel() below is uniform
    )
    axes = axes.ravel()

    if suptitle:
        fig.suptitle(suptitle, fontsize=16)

    for i, (ax, matrix) in enumerate(zip(axes, matrices)):
        if np.size(matrix) == 0:
            # Nothing to color (and min/max of an empty array raises).
            ax.set_axis_off()
        else:
            local_vmin, local_vmax = _local_color_range(matrix)
            sns.heatmap(
                matrix,
                ax=ax,
                cmap=cmap,
                vmin=local_vmin,
                vmax=local_vmax,
            )

        if titles and i < len(titles):
            ax.set_title(titles[i])

    # Only reserve headroom at the top when a suptitle is actually drawn.
    top = 0.96 if suptitle else 1.0
    fig.tight_layout(rect=(0, 0, 1, top))
    plt.show()

In [3]:
def plot_comparison_heatmaps(
    all_matrix_data: dict[str, dict[int, Any]],
    max_show_radius: int,
):
    """
    Draw one figure row of heatmaps per radius, comparing metrics side by side.

    Rows correspond to radii ``0..max_show_radius`` (inclusive). Columns are
    the metrics, in the iteration order of ``all_matrix_data``.

    Args:
        all_matrix_data: Mapping ``metric name -> {radius: matrix}``.
        max_show_radius: Largest radius to plot (inclusive).

    A metric with no entry for some radius is shown as a 1x1 zero matrix, so
    the columns stay aligned across rows.
    """
    for radius in range(max_show_radius + 1):
        panels = [
            (
                per_radius[radius] if radius in per_radius else np.zeros((1, 1)),
                f"r:{radius} {metric}",
            )
            for metric, per_radius in all_matrix_data.items()
        ]
        matrices = [panel_matrix for panel_matrix, _ in panels]
        titles = [panel_title for _, panel_title in panels]

        plot_heatmap_row(
            matrices=matrices,
            titles=titles,
            suptitle=f"Comparison at Radius {radius}",
        )

In [4]:
# def view_horizontal_groups(robot_tuples, titles=None):
#     """
#     Plots horizontal groups without rescaling images of different sizes.
#     Uses pixel-perfect width_ratios.
#     """

#     # --- 1. Pre-fetch ALL images to get dimensions ---
#     # We flatten the structure but keep track of where groups split
#     processed_items = []  # Will store dicts: {'type': 'img', 'data': img} or {'type': 'gap'}

#     # Config for spacing (in pixels, approximately)
#     GROUP_GAP_PX = 100  # The big gap between tuples
#     ROBOT_GAP_PX = 20  # The small gap between robots in a tuple
#     DPI = 100  # Screen dots per inch

#     max_height = 0

#     group_start_indices = []  # To help us place titles later
#     current_index = 0

#     for i, group in enumerate(robot_tuples):
#         # Record where this group starts in the flat list
#         group_start_indices.append(current_index)

#         for j, robot in enumerate(group):
#             # Generate the image
#             img = np.array(
#                 view(
#                     robot, return_img=True, remove_background=True, tilted=True
#                 )
#             )
#             h, w = img.shape[:2]

#             # Update global max height (defines the strip height)
#             if h > max_height:
#                 max_height = h

#             processed_items.append({
#                 "type": "img",
#                 "data": img,
#                 "width": w,
#                 "height": h,
#             })
#             current_index += 1

#             # Add a small gap after every robot, EXCEPT the last one in the group
#             if j < len(group) - 1:
#                 processed_items.append({"type": "gap", "width": ROBOT_GAP_PX})
#                 current_index += 1

#         # Add a large gap after every group, EXCEPT the last group
#         if i < len(robot_tuples) - 1:
#             processed_items.append({"type": "gap", "width": GROUP_GAP_PX})
#             current_index += 1

#     # --- 2. Calculate Figure Dimensions ---
#     # Total width is sum of all image widths + sum of all gap widths
#     total_width_px = sum(item["width"] for item in processed_items)

#     # Calculate Figure size in Inches (Pixels / DPI)
#     fig_width_in = total_width_px / DPI
#     fig_height_in = max_height / DPI

#     # Add a little buffer for titles at the top (e.g., 0.5 inches)
#     title_buffer_in = 0.5
#     fig_height_total = fig_height_in + title_buffer_in

#     # --- 3. Create Figure with Exact Aspect Ratio ---
#     fig = plt.figure(figsize=(fig_width_in, fig_height_total), dpi=DPI)

#     # List of widths to tell GridSpec exactly how much space each col gets
#     widths = [item["width"] for item in processed_items]

#     # One big row, N columns (images + gaps)
#     gs = gridspec.GridSpec(
#         1, len(processed_items), figure=fig, width_ratios=widths
#     )

#     # Remove default spacing, we are handling it manually with 'gap' columns
#     plt.subplots_adjust(
#         left=0,
#         right=1,
#         bottom=0,
#         top=fig_height_in / fig_height_total,
#         wspace=0,
#         hspace=0,
#     )

#     # --- 4. Plotting ---
#     axes_map = {}  # Map index to ax object to help with titles

#     for idx, item in enumerate(processed_items):
#         if item["type"] == "gap":
#             # Skip this column, let it be empty whitespace
#             continue

#         # Create subplot
#         ax = fig.add_subplot(gs[0, idx])

#         # Display Image
#         # anchor='S' aligns image to Bottom (South) if it's shorter than max_height
#         # 'C' would center it. 'N' would align top.
#         ax.imshow(
#             item["data"], aspect="equal", interpolation="none", origin="upper"
#         )

#         # Ensure the axes limits match the max_height so alignment works
#         # This keeps the "ceiling" consistent even for short images
#         ax.set_ylim(max_height, 0)
#         ax.set_xlim(0, item["width"])

#         ax.axis("off")
#         axes_map[idx] = ax

#     # --- 5. Titles ---
#     if titles:
#         # We need to find the center of each group
#         # We iterate through the original group structure to find start/end indices
#         flat_ptr = 0

#         for i, group in enumerate(robot_tuples):
#             # Find the first ax in this group
#             start_ax = axes_map[flat_ptr]

#             # Advance pointer to find the last ax in this group
#             # Structure in flattened list: [Img, Gap, Img, Gap, Img] ... [Big Gap] ...
#             # Length of group items = (len(group) * 2) - 1
#             items_in_group = (len(group) * 2) - 1
#             end_ptr = flat_ptr + items_in_group - 1

#             # If the group has only 1 robot, start and end are same
#             if end_ptr not in axes_map:
#                 # This happens if end_ptr points to a gap (logic check),
#                 # but based on math above, end_ptr should always hit an image.
#                 end_ax = start_ax
#             else:
#                 end_ax = axes_map[end_ptr]

#             # Calculate positions in Figure Coordinates (0 to 1)
#             bbox_start = start_ax.get_position()
#             bbox_end = end_ax.get_position()

#             center_x = (bbox_start.x0 + bbox_end.x1) / 2

#             # Place Title
#             fig.text(
#                 center_x,
#                 1.0 - (0.2 / fig_height_total),
#                 titles[i],
#                 ha="center",
#                 va="top",
#                 fontsize=12,
#                 weight="bold",
#             )

#             # Advance pointer past this group AND the group gap
#             flat_ptr += items_in_group + 1

#     plt.show()

In [5]:
def get_cum_matrix(matrix_dict):
    """
    Build running (cumulative) sums of matrices keyed by integer radius.

    Keys are visited in ascending order. The entry for radius ``r`` is the sum
    of every input matrix whose key is ``<= r``. The first matrix is copied
    and later entries are produced by ``+``, so the input matrices are never
    modified. An empty mapping yields ``{}``.
    """
    cumulative = {}
    total = None
    for radius in sorted(matrix_dict):
        layer = matrix_dict[radius]
        # Copy the first layer so callers can't see later mutations through it.
        total = layer.copy() if total is None else total + layer
        cumulative[radius] = total
    return cumulative

In [6]:
def get_sorted_coords_from_matrix(matrix, *, max_first=True):
    """
    Return strictly-upper-triangle coordinates of ``matrix`` sorted by value.

    Only cells above the main diagonal (``row < col``) are considered, so the
    diagonal is ignored and, for a symmetric matrix, each unordered pair
    appears once.

    Args:
        matrix: 2D array-like.
        max_first: If True (default), sort by descending value; otherwise
            ascending.

    Returns:
        List of ``(row, col)`` tuples of Python ints. Ties keep row-major
        order in both directions, and NaN entries are always placed last
        (never treated as the "largest" value).
    """
    values_matrix = np.asarray(matrix)
    rows, cols = np.triu_indices_from(values_matrix, k=1)
    values = values_matrix[rows, cols]

    # NaN != NaN, so this flags NaNs for any numeric dtype (always False for ints/bools).
    is_nan = values != values
    if max_first:
        # Dense integer ranks negate safely (``-values`` overflows for uint
        # and fails for bool), giving a descending key for any dtype.
        primary = -np.unique(values, return_inverse=True)[1]
    else:
        primary = values
    # np.lexsort treats the LAST key as primary: NaNs last, then by value,
    # then by row-major position so ties are broken deterministically.
    sort_idx = np.lexsort((np.arange(values.size), primary, is_nan))
    return list(zip(rows[sort_idx].tolist(), cols[sort_idx].tolist()))

In [7]:
def sorted_idx_dict(matrix_dict, *, max_first=True):
    """
    Map each key (radius) to its matrix's value-sorted upper-triangle coordinates.

    Every matrix in ``matrix_dict`` is passed to
    ``get_sorted_coords_from_matrix`` with the same ``max_first`` flag. The
    result preserves the key order of ``matrix_dict``.
    """
    result = {}
    for radius, matrix in matrix_dict.items():
        result[radius] = get_sorted_coords_from_matrix(matrix, max_first=max_first)
    return result

In [8]:
# def embeddable_image(data, size=(100, 100)):
#     """
#     Simplified version: Accepts HxWx4 (RGBA) or HxWx3 (RGB).
#     Returns PNG data-url with aspect ratio preserved.
#     """
#     arr = np.asarray(data)

#     # Normalize types
#     if np.issubdtype(arr.dtype, np.floating):
#         if arr.max() <= 1.0:
#             arr = (arr * 255).astype(np.uint8)
#         else:
#             arr = arr.astype(np.uint8)
#     else:
#         arr = arr.astype(np.uint8)

#     # Detect Mode
#     if arr.ndim == 3:
#         mode = "RGBA" if arr.shape[2] == 4 else "RGB"
#     else:
#         mode = "L"  # Grayscale

#     # Create Image and Thumbnail
#     img = Image.fromarray(arr, mode=mode)
#     img.thumbnail(size, Image.Resampling.BICUBIC)  # Keeps aspect ratio

#     buffer = BytesIO()
#     img.save(buffer, format="PNG")
#     return (
#         "data:image/png;base64," + base64.b64encode(buffer.getvalue()).decode()
#     )


# def robot_image(i):
#     """
#     Generates the image for robot i using the global POPULATION and view function.
#     """
#     graph = POPULATION[i].to_graph()
#     # Using remove_background=True for transparent PNGs
#     img_arr = np.array(
#         view(graph, return_img=True, tilted=True, remove_background=True)
#     )
#     return embeddable_image(img_arr)


# def get_population_images(population_size):
#     """
#     Pre-generates all images for the population to avoid re-rendering.
#     """
#     return [
#         robot_image(i)
#         for i in track(
#             range(population_size),
#             description=f"Generating {population_size} images...",
#         )
#     ]


In [9]:
# def matrix_to_heatmap_source(matrix, images, metric_name, radius):
#     """
#     Converts an NxN matrix into a Bokeh ColumnDataSource (Long format).
#     """
#     N = matrix.shape[0]

#     # Create coordinate grids
#     # Note: Bokeh origin (0,0) is bottom-left. Matrices are top-left.
#     # We invert y so the plot looks like the matrix.
#     x_indices, y_indices = np.meshgrid(np.arange(N), np.arange(N))

#     # Flatten arrays
#     x_flat = x_indices.flatten()
#     y_flat = N - 1 - y_indices.flatten()  # Invert Y for visual matrix layout
#     values = matrix.flatten()

#     # Map images
#     imgs_i = [images[r] for r in y_indices.flatten()]  # Row robot
#     imgs_j = [images[c] for c in x_indices.flatten()]  # Col robot

#     ids_i = [str(r) for r in y_indices.flatten()]
#     ids_j = [str(c) for c in x_indices.flatten()]

#     data = {
#         "x": x_flat,
#         "y": y_flat,
#         "value": values,
#         "img_row": imgs_i,  # Robot on Y axis
#         "img_col": imgs_j,  # Robot on X axis
#         "id_row": ids_i,
#         "id_col": ids_j,
#         "metric": [metric_name] * len(values),
#         "radius": [radius] * len(values),
#     }
#     return ColumnDataSource(data)

In [10]:
# def matrix_to_heatmap_source(matrix, images, metric_name, radius):
#     """
#     Converts an NxN matrix into a Bokeh ColumnDataSource (Long format).
#     """
#     N = matrix.shape[0]
#     x_indices, y_indices = np.meshgrid(np.arange(N), np.arange(N))
    
#     x_flat = x_indices.flatten()
#     y_flat = N - 1 - y_indices.flatten() 
#     values = matrix.flatten()
    
#     imgs_i = [images[r] for r in y_indices.flatten()]
#     imgs_j = [images[c] for c in x_indices.flatten()]
    
#     ids_i = [str(r) for r in y_indices.flatten()]
#     ids_j = [str(c) for c in x_indices.flatten()]
    
#     data = {
#         'x': x_flat, 'y': y_flat, 'value': values,
#         'img_row': imgs_i, 'img_col': imgs_j,
#         'id_row': ids_i, 'id_col': ids_j,
#         'metric': [metric_name] * len(values),
#         'radius': [radius] * len(values)
#     }
#     return ColumnDataSource(data)

# def plot_interactive_heatmaps(
#     all_matrix_data: dict, 
#     population_images: list, 
#     max_show_radius: int, 
#     plot_width=None, 
#     plot_height=None, 
#     palette="Reds256"
# ):
#     """
#     Creates a Grid of Interactive Heatmaps using Bokeh.
#     """
#     num_cols = len(all_matrix_data)
    
#     # Dynamic Sizing Defaults
#     # If width is not specified, make it large for single column, compact for multi-column
#     if plot_width is None:
#         plot_width = 650 if num_cols == 1 else 350
        
#     if plot_height is None:
#         # Keep heatmaps roughly square-ish (accounting for colorbar width in plot_width)
#         plot_height = 600 if num_cols == 1 else 300

#     grid_layout = []
#     for r in range(max_show_radius + 1):
#         row_plots = []
#         for name, matrix_dict in all_matrix_data.items():
#             if r in matrix_dict:
#                 matrix = matrix_dict[r]
#             else:
#                 matrix = np.zeros((1, 1))
            
#             source = matrix_to_heatmap_source(matrix, population_images, name, r)
#             vmin, vmax = matrix.min(), matrix.max()
#             mapper = LinearColorMapper(palette=palette, low=vmin, high=vmax)
            
#             p = figure(title=f"r:{r} {name}", x_range=(-0.5, matrix.shape[1]-0.5), y_range=(-0.5, matrix.shape[0]-0.5), width=plot_width, height=plot_height, tools="hover,save,reset", toolbar_location="above")
#             p.axis.visible = False
#             p.grid.visible = False
#             p.rect(x='x', y='y', width=1, height=1, source=source, fill_color={'field': 'value', 'transform': mapper}, line_color=None)
            
#             color_bar = ColorBar(color_mapper=mapper, ticker=BasicTicker(), label_standoff=8, border_line_color=None, location=(0,0), width=8)
#             p.add_layout(color_bar, 'right')
            
#             hover = p.select(dict(type=HoverTool))
#             hover.tooltips = """
#             <div style="display: flex; flex-direction: column; align-items: center; background: white; padding: 5px;">
#                 <div style="font-weight: bold; margin-bottom: 5px;">@metric (r=@radius) Val: @value{0.000}</div>
#                 <div style="display: flex; flex-direction: row; gap: 10px;">
#                     <div style="text-align: center;"><span style="font-size: 10px;">Row: @id_row</span><br><img src="@img_row" style="max-width: 60px; max-height: 60px; width: auto; height: auto;"></div>
#                     <div style="text-align: center;"><span style="font-size: 10px;">Col: @id_col</span><br><img src="@img_col" style="max-width: 60px; max-height: 60px; width: auto; height: auto;"></div>
#                 </div>
#             </div>
#             """
#             row_plots.append(p)
#         grid_layout.append(row_plots)
#     show(gridplot(grid_layout))

---

In [11]:

# def embeddable_image(data, scale=0.5):
#     """
#     Simplified version: Accepts HxWx4 (RGBA) or HxWx3 (RGB).
#     Returns PNG data-url with aspect ratio AND relative scale preserved.
#     """
#     arr = np.asarray(data)
    
#     # Normalize types
#     if np.issubdtype(arr.dtype, np.floating):
#         if arr.max() <= 1.0:
#             arr = (arr * 255).astype(np.uint8)
#         else:
#             arr = arr.astype(np.uint8)
#     else:
#         arr = arr.astype(np.uint8)

#     # Detect Mode
#     if arr.ndim == 3:
#         mode = 'RGBA' if arr.shape[2] == 4 else 'RGB'
#     else:
#         mode = 'L' # Grayscale

#     # Create Image
#     img = Image.fromarray(arr, mode=mode)
    
#     # Resize by a constant factor to preserve relative size differences
#     # (e.g., a big robot stays big, a small robot stays small)
#     if scale != 1.0:
#         new_size = (int(img.width * scale), int(img.height * scale))
#         img = img.resize(new_size, Image.Resampling.BICUBIC)

#     buffer = BytesIO()
#     img.save(buffer, format='PNG')
#     return 'data:image/png;base64,' + base64.b64encode(buffer.getvalue()).decode()

# def robot_image(i):
#     """
#     Generates the image for robot i using the global POPULATION and view function.
#     """
#     graph = POPULATION[i].to_graph()  
#     # Using remove_background=True for transparent PNGs
#     img_arr = np.array(view(graph, return_img=True, tilted=True, remove_background=True))
#     # Apply a 50% scale to keep file sizes manageable while preserving relative scale
#     return embeddable_image(img_arr, scale=0.5)

# def get_population_images(population_size):
#     """
#     Pre-generates all images for the population to avoid re-rendering.
#     """
#     return [robot_image(i) for i in track(range(population_size), description=f"Generating {population_size} images...")]

# # --- 2. Interactive Heatmap Functions (Matrix Visualization) ---

# def matrix_to_heatmap_source(matrix, images, metric_name, radius):
#     """
#     Converts an NxN matrix into a Bokeh ColumnDataSource (Long format).
#     """
#     N = matrix.shape[0]
#     x_indices, y_indices = np.meshgrid(np.arange(N), np.arange(N))
    
#     x_flat = x_indices.flatten()
#     y_flat = N - 1 - y_indices.flatten() 
#     values = matrix.flatten()
    
#     imgs_i = [images[r] for r in y_indices.flatten()]
#     imgs_j = [images[c] for c in x_indices.flatten()]
    
#     ids_i = [str(r) for r in y_indices.flatten()]
#     ids_j = [str(c) for c in x_indices.flatten()]
    
#     data = {
#         'x': x_flat, 'y': y_flat, 'value': values,
#         'img_row': imgs_i, 'img_col': imgs_j,
#         'id_row': ids_i, 'id_col': ids_j,
#         'metric': [metric_name] * len(values),
#         'radius': [radius] * len(values)
#     }
#     return ColumnDataSource(data)

# def plot_interactive_heatmaps(
#     all_matrix_data: dict, 
#     population_images: list, 
#     max_show_radius: int, 
#     plot_width=None, 
#     plot_height=None, 
#     palette="Reds256"
# ):
#     """
#     Creates a Grid of Interactive Heatmaps using Bokeh.
#     """
#     num_cols = len(all_matrix_data)
    
#     # Dynamic Sizing Defaults
#     # If width is not specified, make it large for single column, compact for multi-column
#     if plot_width is None:
#         plot_width = 650 if num_cols == 1 else 350
        
#     if plot_height is None:
#         # Keep heatmaps roughly square-ish (accounting for colorbar width in plot_width)
#         plot_height = 600 if num_cols == 1 else 300

#     grid_layout = []
#     for r in range(max_show_radius + 1):
#         row_plots = []
#         for name, matrix_dict in all_matrix_data.items():
#             if r in matrix_dict:
#                 matrix = matrix_dict[r]
#             else:
#                 matrix = np.zeros((1, 1))
            
#             source = matrix_to_heatmap_source(matrix, population_images, name, r)
#             vmin, vmax = matrix.min(), matrix.max()
#             mapper = LinearColorMapper(palette=palette, low=vmin, high=vmax)
            
#             p = figure(title=f"r:{r} {name}", x_range=(-0.5, matrix.shape[1]-0.5), y_range=(-0.5, matrix.shape[0]-0.5), width=plot_width, height=plot_height, tools="hover,save,reset", toolbar_location="above")
#             p.axis.visible = False
#             p.grid.visible = False
#             p.rect(x='x', y='y', width=1, height=1, source=source, fill_color={'field': 'value', 'transform': mapper}, line_color=None)
            
#             color_bar = ColorBar(color_mapper=mapper, ticker=BasicTicker(), label_standoff=8, border_line_color=None, location=(0,0), width=8)
#             p.add_layout(color_bar, 'right')
            
#             hover = p.select(dict(type=HoverTool))
#             # UPDATED CSS: Removed max-width/max-height constraints.
#             # Using 'width: auto' respects the relative scaling baked into the image data.
#             hover.tooltips = """
#             <div style="display: flex; flex-direction: column; align-items: center; background: white; padding: 5px;">
#                 <div style="font-weight: bold; margin-bottom: 5px;">@metric (r=@radius) Val: @value{0.000}</div>
#                 <div style="display: flex; flex-direction: row; gap: 10px;">
#                     <div style="text-align: center;"><span style="font-size: 10px;">Row: @id_row</span><br><img src="@img_row" style="width: auto; height: auto;"></div>
#                     <div style="text-align: center;"><span style="font-size: 10px;">Col: @id_col</span><br><img src="@img_col" style="width: auto; height: auto;"></div>
#                 </div>
#             </div>
#             """
#             row_plots.append(p)
#         grid_layout.append(row_plots)
#     show(gridplot(grid_layout))

# # --- 3. Interactive UMAP Grid Functions (Scatter Visualization) ---

# # def plot_interactive_umap_grid(
# #     umap_data: dict, 
# #     population_images: list, 
# #     max_show_radius: int,
# #     plot_width=None, 
# #     plot_height=350
# # ):
# #     """
# #     Creates a Grid of Interactive UMAP Scatter plots.
# #     Row = Radius
# #     Column = Metric (different embedding dictionaries)
# #     Color = Rainbow by ID (Consistent across all plots)
# #     """
# #     num_cols = len(umap_data)
    
# #     # Dynamic Sizing Defaults
# #     # If width is not specified, make it large for single column, compact for multi-column
# #     if plot_width is None:
# #         plot_width = 700 if num_cols == 1 else 350
        
# #     n = len(population_images)
    
# #     # 1. Pre-calculate Colors (Rainbow by ID)
# #     rgba_colors = plt.cm.rainbow(np.linspace(0, 1, n))
# #     hex_colors = [to_hex(c) for c in rgba_colors]
    
# #     grid_layout = []

# #     # 2. Iterate Rows (Radius)
# #     for r in range(max_show_radius + 1):
# #         row_plots = []
        
# #         # 3. Iterate Columns (Different UMAP methods)
# #         for name, matrix_dict in umap_data.items():
            
# #             # Check if embedding exists for this radius
# #             if r in matrix_dict:
# #                 emb = matrix_dict[r]
# #                 # Handle empty or malformed embeddings
# #                 if emb.ndim != 2 or emb.shape[1] != 2:
# #                      # Create empty placeholder if data is invalid
# #                     p = figure(title=f"r:{r} {name} (No Data)", width=plot_width, height=plot_height)
# #                     row_plots.append(p)
# #                     continue
# #             else:
# #                  # Create empty placeholder if radius missing
# #                 p = figure(title=f"r:{r} {name} (Missing)", width=plot_width, height=plot_height)
# #                 row_plots.append(p)
# #                 continue

# #             # Build DataSource
# #             robots_df = pd.DataFrame({
# #                 "x": emb[:, 0],
# #                 "y": emb[:, 1],
# #                 "digit": [str(i) for i in range(n)],
# #                 "image": population_images,
# #                 "color": hex_colors
# #             })
# #             source = ColumnDataSource(robots_df)
            
# #             # Create Figure
# #             p = figure(
# #                 title=f"r:{r} {name}", 
# #                 width=plot_width, 
# #                 height=plot_height, 
# #                 tools="pan,wheel_zoom,reset,save",
# #                 toolbar_location="above"
# #             )
            
# #             # Scatter Plot
# #             p.scatter(
# #                 'x', 'y',
# #                 source=source,
# #                 color='color',
# #                 line_alpha=0.6,
# #                 fill_alpha=0.7,
# #                 size=8
# #             )
            
# #             # Hover Tool with Image
# #             # CSS: Using width: auto to preserve relative size
# #             hover = HoverTool(tooltips="""
# #             <div>
# #                 <img src='@image' style='float:left; margin:5px; width:auto; height:auto;'/>
# #             </div>
# #             <div style="font-size:12px; font-weight: bold;">
# #                 <span style='color:#224499'>ID: @digit</span>
# #             </div>
# #             """)
# #             p.add_tools(hover)
            
# #             row_plots.append(p)
            
# #         grid_layout.append(row_plots)
        
# #     # Display Grid
# #     show(gridplot(grid_layout))

In [12]:
# def embeddable_image(data, scale=1.0):
#     """
#     Simplified version: Accepts HxWx4 (RGBA) or HxWx3 (RGB).
#     Returns PNG data-url with aspect ratio AND relative scale preserved.
#     """
#     arr = np.asarray(data)
    
#     # Normalize types
#     if np.issubdtype(arr.dtype, np.floating):
#         if arr.max() <= 1.0:
#             arr = (arr * 255).astype(np.uint8)
#         else:
#             arr = arr.astype(np.uint8)
#     else:
#         arr = arr.astype(np.uint8)

#     # Detect Mode
#     if arr.ndim == 3:
#         mode = 'RGBA' if arr.shape[2] == 4 else 'RGB'
#     else:
#         mode = 'L' # Grayscale

#     # Create Image
#     img = Image.fromarray(arr, mode=mode)
    
#     # Resize by a constant factor to preserve relative size differences
#     if scale != 1.0:
#         new_size = (int(img.width * scale), int(img.height * scale))
#         img = img.resize(new_size, Image.Resampling.BICUBIC)

#     buffer = BytesIO()
#     img.save(buffer, format='PNG')
#     return 'data:image/png;base64,' + base64.b64encode(buffer.getvalue()).decode()

# def robot_image(i):
#     """
#     Generates the image for robot i using the global POPULATION and view function.
#     """
#     graph = POPULATION[i].to_graph()  
#     # Using remove_background=True for transparent PNGs
#     img_arr = np.array(view(graph, return_img=True, tilted=True, remove_background=True))
#     # Apply 1.0 scale (Full Resolution) so images aren't pixelated
#     return embeddable_image(img_arr, scale=1.0)

# def get_population_images(population_size):
#     """
#     Pre-generates all images for the population to avoid re-rendering.
#     """
#     return [robot_image(i) for i in track(range(population_size), description=f"Generating {population_size} images...")]

# def decode_base64_image(data_url):
#     """
#     Helper to convert the stored base64 string back to a numpy array for matplotlib.
#     """
#     header, encoded = data_url.split(",", 1)
#     data = base64.b64decode(encoded)
#     return np.array(Image.open(BytesIO(data)))

# # --- 2. Interactive Bokeh Functions (Web/Interactive) ---

# def matrix_to_heatmap_source(matrix, images, metric_name, radius):
#     """Converts matrix to Bokeh DataSource."""
#     N = matrix.shape[0]
#     x_indices, y_indices = np.meshgrid(np.arange(N), np.arange(N))
#     x_flat = x_indices.flatten()
#     y_flat = N - 1 - y_indices.flatten() 
#     values = matrix.flatten()
#     imgs_i = [images[r] for r in y_indices.flatten()]
#     imgs_j = [images[c] for c in x_indices.flatten()]
#     ids_i = [str(r) for r in y_indices.flatten()]
#     ids_j = [str(c) for c in x_indices.flatten()]
#     data = {
#         'x': x_flat, 'y': y_flat, 'value': values,
#         'img_row': imgs_i, 'img_col': imgs_j,
#         'id_row': ids_i, 'id_col': ids_j,
#         'metric': [metric_name] * len(values),
#         'radius': [radius] * len(values)
#     }
#     return ColumnDataSource(data)

# def plot_interactive_heatmaps(all_matrix_data: dict, population_images: list, max_show_radius: int, plot_width=None, plot_height=None, palette="Reds256"):
#     """Creates a Grid of Interactive Heatmaps using Bokeh."""
#     num_cols = len(all_matrix_data)
#     if plot_width is None: plot_width = 650 if num_cols == 1 else 350
#     if plot_height is None: plot_height = 600 if num_cols == 1 else 300

#     grid_layout = []
#     for r in range(max_show_radius + 1):
#         row_plots = []
#         for name, matrix_dict in all_matrix_data.items():
#             if r in matrix_dict: matrix = matrix_dict[r]
#             else: matrix = np.zeros((1, 1))
            
#             source = matrix_to_heatmap_source(matrix, population_images, name, r)
#             vmin, vmax = matrix.min(), matrix.max()
#             mapper = LinearColorMapper(palette=palette, low=vmin, high=vmax)
            
#             p = figure(title=f"r:{r} {name}", x_range=(-0.5, matrix.shape[1]-0.5), y_range=(-0.5, matrix.shape[0]-0.5), width=plot_width, height=plot_height, tools="hover,save,reset", toolbar_location="above")
#             p.axis.visible = False; p.grid.visible = False
#             p.rect(x='x', y='y', width=1, height=1, source=source, fill_color={'field': 'value', 'transform': mapper}, line_color=None)
#             color_bar = ColorBar(color_mapper=mapper, ticker=BasicTicker(), label_standoff=8, border_line_color=None, location=(0,0), width=8)
#             p.add_layout(color_bar, 'right')
            
#             hover = p.select(dict(type=HoverTool))
#             hover.tooltips = """
#             <div style="display: flex; flex-direction: column; align-items: center; background: white; padding: 5px;">
#                 <div style="font-weight: bold; margin-bottom: 5px;">@metric (r=@radius) Val: @value{0.000}</div>
#                 <div style="display: flex; flex-direction: row; gap: 10px;">
#                     <div style="text-align: center;"><span style="font-size: 10px;">Row: @id_row</span><br><img src="@img_row" style="width: auto; height: auto;"></div>
#                     <div style="text-align: center;"><span style="font-size: 10px;">Col: @id_col</span><br><img src="@img_col" style="width: auto; height: auto;"></div>
#                 </div>
#             </div>
#             """
#             row_plots.append(p)
#         grid_layout.append(row_plots)
#     show(gridplot(grid_layout))

# # def plot_interactive_umap_grid(umap_data: dict, population_images: list, max_show_radius: int, plot_width=None, plot_height=350):
# #     """Creates a Grid of Interactive UMAP Scatter plots."""
# #     num_cols = len(umap_data)
# #     if plot_width is None: plot_width = 700 if num_cols == 1 else 350
# #     n = len(population_images)
# #     rgba_colors = plt.cm.rainbow(np.linspace(0, 1, n))
# #     hex_colors = [to_hex(c) for c in rgba_colors]
    
# #     grid_layout = []
# #     for r in range(max_show_radius + 1):
# #         row_plots = []
# #         for name, matrix_dict in umap_data.items():
# #             if r in matrix_dict:
# #                 emb = matrix_dict[r]
# #                 if emb.ndim != 2 or emb.shape[1] != 2:
# #                     p = figure(title=f"r:{r} {name} (No Data)", width=plot_width, height=plot_height); row_plots.append(p); continue
# #             else:
# #                 p = figure(title=f"r:{r} {name} (Missing)", width=plot_width, height=plot_height); row_plots.append(p); continue

# #             robots_df = pd.DataFrame({"x": emb[:, 0], "y": emb[:, 1], "digit": [str(i) for i in range(n)], "image": population_images, "color": hex_colors})
# #             source = ColumnDataSource(robots_df)
# #             p = figure(title=f"r:{r} {name}", width=plot_width, height=plot_height, tools="pan,wheel_zoom,reset,save", toolbar_location="above")
# #             p.scatter('x', 'y', source=source, color='color', line_alpha=0.6, fill_alpha=0.7, size=8)
# #             hover = HoverTool(tooltips="""<div><img src='@image' style='float:left; margin:5px; width:auto; height:auto;'/></div><div style="font-size:12px; font-weight: bold;"><span style='color:#224499'>ID: @digit</span></div>""")
# #             p.add_tools(hover)
# #             row_plots.append(p)
# #         grid_layout.append(row_plots)
# #     show(gridplot(grid_layout))

# # --- 3. Static Matplotlib Functions (Rows of Pairs) ---

# # def view_horizontal_groups(image_data_tuples, titles=None):
# #     """
# #     Plots horizontal groups without rescaling images of different sizes.
# #     Uses pixel-perfect width_ratios.
# #     Args:
# #         image_data_tuples: List of lists containing base64 image strings.
# #                            e.g. [[img_str_A, img_str_B], [img_str_C, img_str_D]]
# #     """
# #     # --- 1. Pre-fetch ALL images to get dimensions ---
# #     processed_items = []  # Will store dicts: {'type': 'img', 'data': img} or {'type': 'gap'}
    
# #     GROUP_GAP_PX = 100 
# #     ROBOT_GAP_PX = 20  
# #     DPI = 100  
# #     max_height = 0
    
# #     # Map from flattened index back to group index for titles
# #     axes_map = {} 
# #     current_index = 0

# #     for i, group in enumerate(image_data_tuples):
# #         for j, img_str in enumerate(group):
# #             # Decode the base64 string to numpy array for plotting
# #             img = decode_base64_image(img_str)
# #             h, w = img.shape[:2]

# #             if h > max_height: max_height = h

# #             processed_items.append({"type": "img", "data": img, "width": w, "height": h})
# #             current_index += 1

# #             if j < len(group) - 1:
# #                 processed_items.append({"type": "gap", "width": ROBOT_GAP_PX})
# #                 current_index += 1

# #         if i < len(image_data_tuples) - 1:
# #             processed_items.append({"type": "gap", "width": GROUP_GAP_PX})
# #             current_index += 1

# #     # --- 2. Calculate Figure Dimensions ---
# #     total_width_px = sum(item.get("width", 0) for item in processed_items)
# #     fig_width_in = total_width_px / DPI
# #     fig_height_in = max_height / DPI
# #     title_buffer_in = 0.5
# #     fig_height_total = fig_height_in + title_buffer_in

# #     # --- 3. Create Figure ---
# #     fig = plt.figure(figsize=(fig_width_in, fig_height_total), dpi=DPI)
# #     widths = [item.get("width", 0) for item in processed_items]
    
# #     gs = gridspec.GridSpec(1, len(processed_items), figure=fig, width_ratios=widths)
# #     plt.subplots_adjust(left=0, right=1, bottom=0, top=fig_height_in / fig_height_total, wspace=0, hspace=0)

# #     # --- 4. Plotting ---
# #     flat_ptr = 0 # To track position in processed_items list

# #     for idx, item in enumerate(processed_items):
# #         if item["type"] == "gap":
# #             continue

# #         ax = fig.add_subplot(gs[0, idx])
# #         ax.imshow(item["data"], aspect="equal", interpolation="none", origin="upper")
# #         ax.set_ylim(max_height, 0)
# #         ax.set_xlim(0, item["width"])
# #         ax.axis("off")
        
# #         # Store ax in map for title logic
# #         axes_map[idx] = ax

# #     # --- 5. Titles ---
# #     if titles:
# #         # We assume the processed_items structure directly maps to groups
# #         # Re-iterate to find the start and end axis for each group
# #         ptr = 0
# #         for i, group in enumerate(image_data_tuples):
# #             # Calculate how many items correspond to this group in the flat list
# #             # Each robot (except last) adds 1 image + 1 gap. Last robot adds 1 image.
# #             # Then there is 1 group gap (except last group).
            
# #             # Start of group
# #             start_idx = ptr
# #             start_ax = axes_map[start_idx]
            
# #             # Find end of group (index of the last robot image in this group)
# #             # The number of slots used by robots = (len(group) * 2) - 1
# #             # e.g., 2 robots -> [Img, Gap, Img] -> offset is 2
# #             offset = (len(group) * 2) - 1
# #             end_idx = ptr + offset - 1 
# #             end_ax = axes_map[end_idx]
            
# #             # Calculate Center
# #             bbox_start = start_ax.get_position()
# #             bbox_end = end_ax.get_position()
# #             center_x = (bbox_start.x0 + bbox_end.x1) / 2

# #             fig.text(center_x, 1.0 - (0.2 / fig_height_total), titles[i], ha="center", va="top", fontsize=10, weight="bold")
            
# #             # Move pointer: robots + group gap
# #             ptr += offset
# #             if i < len(image_data_tuples) - 1:
# #                 ptr += 1 # Skip group gap

# #     plt.show()

In [13]:
# def embeddable_image(data, scale=1.0):
#     """
#     Simplified version: Accepts HxWx4 (RGBA) or HxWx3 (RGB).
#     Returns PNG data-url with aspect ratio AND relative scale preserved.
#     """
#     arr = np.asarray(data)
    
#     # Normalize types
#     if np.issubdtype(arr.dtype, np.floating):
#         if arr.max() <= 1.0:
#             arr = (arr * 255).astype(np.uint8)
#         else:
#             arr = arr.astype(np.uint8)
#     else:
#         arr = arr.astype(np.uint8)

#     # Detect Mode
#     if arr.ndim == 3:
#         mode = 'RGBA' if arr.shape[2] == 4 else 'RGB'
#     else:
#         mode = 'L' # Grayscale

#     # Create Image
#     img = Image.fromarray(arr, mode=mode)
    
#     # Resize by a constant factor to preserve relative size differences
#     if scale != 1.0:
#         new_size = (int(img.width * scale), int(img.height * scale))
#         img = img.resize(new_size, Image.Resampling.BICUBIC)

#     buffer = BytesIO()
#     img.save(buffer, format='PNG')
#     return 'data:image/png;base64,' + base64.b64encode(buffer.getvalue()).decode()

# def robot_image(i):
#     """
#     Generates the image for robot i using the global POPULATION and view function.
#     """
#     graph = POPULATION[i].to_graph()  
#     # Using remove_background=True for transparent PNGs
#     img_arr = np.array(view(graph, return_img=True, tilted=True, remove_background=True))
#     # Apply 0.6 scale: Consistent global scaling.
#     # Keeps relative sizes correct (cores stay same size) but prevents giant tooltips.
#     return embeddable_image(img_arr, scale=0.6)

# def get_population_images(population_size):
#     """
#     Pre-generates all images for the population to avoid re-rendering.
#     """
#     return [robot_image(i) for i in track(range(population_size), description=f"Generating {population_size} images...")]

# def decode_base64_image(data_url):
#     """
#     Helper to convert the stored base64 string back to a numpy array for matplotlib.
#     """
#     header, encoded = data_url.split(",", 1)
#     data = base64.b64decode(encoded)
#     return np.array(Image.open(BytesIO(data)))

# # --- 2. Interactive Bokeh Functions (Web/Interactive) ---

# def matrix_to_heatmap_source(matrix, images, metric_name, radius):
#     """Converts matrix to Bokeh DataSource."""
#     N = matrix.shape[0]
#     x_indices, y_indices = np.meshgrid(np.arange(N), np.arange(N))
#     x_flat = x_indices.flatten()
#     y_flat = N - 1 - y_indices.flatten() 
#     values = matrix.flatten()
#     imgs_i = [images[r] for r in y_indices.flatten()]
#     imgs_j = [images[c] for c in x_indices.flatten()]
#     ids_i = [str(r) for r in y_indices.flatten()]
#     ids_j = [str(c) for c in x_indices.flatten()]
#     data = {
#         'x': x_flat, 'y': y_flat, 'value': values,
#         'img_row': imgs_i, 'img_col': imgs_j,
#         'id_row': ids_i, 'id_col': ids_j,
#         'metric': [metric_name] * len(values),
#         'radius': [radius] * len(values)
#     }
#     return ColumnDataSource(data)

# def plot_interactive_heatmaps(all_matrix_data: dict, population_images: list, max_show_radius: int, plot_width=None, plot_height=None, palette="Reds256"):
#     """Creates a Grid of Interactive Heatmaps using Bokeh."""
#     num_cols = len(all_matrix_data)
#     if plot_width is None: plot_width = 650 if num_cols == 1 else 350
#     if plot_height is None: plot_height = 600 if num_cols == 1 else 300

#     grid_layout = []
#     for r in range(max_show_radius + 1):
#         row_plots = []
#         for name, matrix_dict in all_matrix_data.items():
#             if r in matrix_dict: matrix = matrix_dict[r]
#             else: matrix = np.zeros((1, 1))
            
#             source = matrix_to_heatmap_source(matrix, population_images, name, r)
#             vmin, vmax = matrix.min(), matrix.max()
#             mapper = LinearColorMapper(palette=palette, low=vmin, high=vmax)
            
#             p = figure(title=f"r:{r} {name}", x_range=(-0.5, matrix.shape[1]-0.5), y_range=(-0.5, matrix.shape[0]-0.5), width=plot_width, height=plot_height, tools="hover,save,reset", toolbar_location="above")
#             p.axis.visible = False; p.grid.visible = False
#             p.rect(x='x', y='y', width=1, height=1, source=source, fill_color={'field': 'value', 'transform': mapper}, line_color=None)
#             color_bar = ColorBar(color_mapper=mapper, ticker=BasicTicker(), label_standoff=8, border_line_color=None, location=(0,0), width=8)
#             p.add_layout(color_bar, 'right')
            
#             hover = p.select(dict(type=HoverTool))
#             # REMOVED max-width/height. 
#             # Images are now naturally scaled by 'robot_image' (0.6x), so relative sizes are correct.
#             hover.tooltips = """
#             <div style="display: flex; flex-direction: column; align-items: center; background: white; padding: 5px;">
#                 <div style="font-weight: bold; margin-bottom: 5px;">@metric (r=@radius) Val: @value{0.000}</div>
#                 <div style="display: flex; flex-direction: row; gap: 10px;">
#                     <div style="text-align: center;"><span style="font-size: 10px;">Row: @id_row</span><br><img src="@img_row" style="width: auto; height: auto;"></div>
#                     <div style="text-align: center;"><span style="font-size: 10px;">Col: @id_col</span><br><img src="@img_col" style="width: auto; height: auto;"></div>
#                 </div>
#             </div>
#             """
#             row_plots.append(p)
#         grid_layout.append(row_plots)
#     show(gridplot(grid_layout))

# # def plot_interactive_umap_grid(umap_data: dict, population_images: list, max_show_radius: int, plot_width=None, plot_height=350):
# #     """Creates a Grid of Interactive UMAP Scatter plots."""
# #     num_cols = len(umap_data)
# #     if plot_width is None: plot_width = 700 if num_cols == 1 else 350
# #     n = len(population_images)
# #     rgba_colors = plt.cm.rainbow(np.linspace(0, 1, n))
# #     hex_colors = [to_hex(c) for c in rgba_colors]
    
# #     grid_layout = []
# #     for r in range(max_show_radius + 1):
# #         row_plots = []
# #         for name, matrix_dict in umap_data.items():
# #             if r in matrix_dict:
# #                 emb = matrix_dict[r]
# #                 if emb.ndim != 2 or emb.shape[1] != 2:
# #                     p = figure(title=f"r:{r} {name} (No Data)", width=plot_width, height=plot_height); row_plots.append(p); continue
# #             else:
# #                 p = figure(title=f"r:{r} {name} (Missing)", width=plot_width, height=plot_height); row_plots.append(p); continue

# #             robots_df = pd.DataFrame({"x": emb[:, 0], "y": emb[:, 1], "digit": [str(i) for i in range(n)], "image": population_images, "color": hex_colors})
# #             source = ColumnDataSource(robots_df)
# #             p = figure(title=f"r:{r} {name}", width=plot_width, height=plot_height, tools="pan,wheel_zoom,reset,save", toolbar_location="above")
# #             p.scatter('x', 'y', source=source, color='color', line_alpha=0.6, fill_alpha=0.7, size=8)
# #             # REMOVED max-width/height here as well
# #             hover = HoverTool(tooltips="""<div><img src='@image' style='float:left; margin:5px; width:auto; height:auto;'/></div><div style="font-size:12px; font-weight: bold;"><span style='color:#224499'>ID: @digit</span></div>""")
# #             p.add_tools(hover)
# #             row_plots.append(p)
# #         grid_layout.append(row_plots)
# #     show(gridplot(grid_layout))

# # --- 3. Static Matplotlib Functions (Rows of Pairs) ---

# # def view_horizontal_groups(image_data_tuples, titles=None):
# #     """
# #     Plots horizontal groups without rescaling images of different sizes.
# #     Uses pixel-perfect width_ratios.
# #     Args:
# #         image_data_tuples: List of lists containing base64 image strings.
# #                            e.g. [[img_str_A, img_str_B], [img_str_C, img_str_D]]
# #     """
# #     # --- 1. Pre-fetch ALL images to get dimensions ---
# #     processed_items = []  # Will store dicts: {'type': 'img', 'data': img} or {'type': 'gap'}
    
# #     GROUP_GAP_PX = 100 
# #     ROBOT_GAP_PX = 20  
# #     # LOWERED DPI to 60. 
# #     # This makes the images appear physically larger in the figure layout, 
# #     # preventing the text (titles) from overpowering the visual content.
# #     DPI = 60  
# #     max_height = 0
    
# #     # Map from flattened index back to group index for titles
# #     axes_map = {} 
# #     current_index = 0

# #     for i, group in enumerate(image_data_tuples):
# #         for j, img_str in enumerate(group):
# #             # Decode the base64 string to numpy array for plotting
# #             img = decode_base64_image(img_str)
# #             h, w = img.shape[:2]

# #             if h > max_height: max_height = h

# #             processed_items.append({"type": "img", "data": img, "width": w, "height": h})
# #             current_index += 1

# #             if j < len(group) - 1:
# #                 processed_items.append({"type": "gap", "width": ROBOT_GAP_PX})
# #                 current_index += 1

# #         if i < len(image_data_tuples) - 1:
# #             processed_items.append({"type": "gap", "width": GROUP_GAP_PX})
# #             current_index += 1

# #     # --- 2. Calculate Figure Dimensions ---
# #     total_width_px = sum(item.get("width", 0) for item in processed_items)
# #     fig_width_in = total_width_px / DPI
# #     fig_height_in = max_height / DPI
# #     title_buffer_in = 0.5
# #     fig_height_total = fig_height_in + title_buffer_in

# #     # --- 3. Create Figure ---
# #     fig = plt.figure(figsize=(fig_width_in, fig_height_total), dpi=DPI)
# #     widths = [item.get("width", 0) for item in processed_items]
    
# #     gs = gridspec.GridSpec(1, len(processed_items), figure=fig, width_ratios=widths)
# #     plt.subplots_adjust(left=0, right=1, bottom=0, top=fig_height_in / fig_height_total, wspace=0, hspace=0)

# #     # --- 4. Plotting ---
# #     flat_ptr = 0 # To track position in processed_items list

# #     for idx, item in enumerate(processed_items):
# #         if item["type"] == "gap":
# #             continue

# #         ax = fig.add_subplot(gs[0, idx])
# #         ax.imshow(item["data"], aspect="equal", interpolation="none", origin="upper")
# #         ax.set_ylim(max_height, 0)
# #         ax.set_xlim(0, item["width"])
# #         ax.axis("off")
        
# #         # Store ax in map for title logic
# #         axes_map[idx] = ax

# #     # --- 5. Titles ---
# #     if titles:
# #         # We assume the processed_items structure directly maps to groups
# #         # Re-iterate to find the start and end axis for each group
# #         ptr = 0
# #         for i, group in enumerate(image_data_tuples):
# #             # Calculate how many items correspond to this group in the flat list
# #             # Each robot (except last) adds 1 image + 1 gap. Last robot adds 1 image.
# #             # Then there is 1 group gap (except last group).
            
# #             # Start of group
# #             start_idx = ptr
# #             start_ax = axes_map[start_idx]
            
# #             # Find end of group (index of the last robot image in this group)
# #             # The number of slots used by robots = (len(group) * 2) - 1
# #             # e.g., 2 robots -> [Img, Gap, Img] -> offset is 2
# #             offset = (len(group) * 2) - 1
# #             end_idx = ptr + offset - 1 
# #             end_ax = axes_map[end_idx]
            
# #             # Calculate Center
# #             bbox_start = start_ax.get_position()
# #             bbox_end = end_ax.get_position()
# #             center_x = (bbox_start.x0 + bbox_end.x1) / 2

# #             fig.text(center_x, 1.0 - (0.2 / fig_height_total), titles[i], ha="center", va="top", fontsize=10, weight="bold")
            
# #             # Move pointer: robots + group gap
# #             ptr += offset
# #             if i < len(image_data_tuples) - 1:
# #                 ptr += 1 # Skip group gap

# #     plt.show()


In [14]:
def embeddable_image(data, scale=1.0):
    """
    Encode an image array as a PNG data URL usable in an HTML ``<img src=...>``.

    Accepts HxW (grayscale), HxWx1, HxWx2 (grayscale + alpha), HxWx3 (RGB) or
    HxWx4 (RGBA) arrays. Float arrays whose maximum is <= 1.0 are treated as
    normalized [0, 1] data and multiplied by 255; every other input is assumed
    to already be on the 0-255 scale. Values are clipped to [0, 255] before the
    uint8 cast so out-of-range pixels saturate instead of wrapping around.

    Args:
        data: Array-like image data.
        scale: Constant resize factor applied to both dimensions. Using the same
            factor for every image keeps relative sizes between images intact.

    Returns:
        A ``'data:image/png;base64,...'`` string.

    Raises:
        ValueError: If ``data`` is empty or ``scale`` is not positive.
    """
    arr = np.asarray(data)
    if arr.size == 0:
        raise ValueError("cannot embed an empty image array")
    if scale <= 0:
        raise ValueError(f"scale must be positive, got {scale}")

    # Normalize to uint8. NaN would make max() NaN and cast to garbage, so zero it.
    if np.issubdtype(arr.dtype, np.floating):
        arr = np.nan_to_num(arr, nan=0.0)
        if arr.max() <= 1.0:
            arr = arr * 255
    if arr.dtype != np.uint8:
        # Clip in float64 so no integer dtype can overflow on the bounds; a bare
        # astype(uint8) would wrap out-of-range values modulo 256.
        arr = np.clip(arr.astype(np.float64), 0, 255).astype(np.uint8)

    # A trailing singleton channel has no Pillow mode; treat it as grayscale.
    if arr.ndim == 3 and arr.shape[2] == 1:
        arr = arr[:, :, 0]

    # Pillow infers L / LA / RGB / RGBA from the uint8 array's shape; the explicit
    # ``mode=`` argument of Image.fromarray is deprecated.
    img = Image.fromarray(arr)

    # Resize by a constant factor to preserve relative size differences.
    if scale != 1.0:
        # Clamp to 1px so very small images never collapse to a zero-size resize.
        new_size = (max(1, int(img.width * scale)), max(1, int(img.height * scale)))
        img = img.resize(new_size, Image.Resampling.BICUBIC)

    buffer = BytesIO()
    img.save(buffer, format='PNG')
    return 'data:image/png;base64,' + base64.b64encode(buffer.getvalue()).decode()

def robot_image(i, scale=1.0, population=None):
    """
    Render robot ``i`` with ``view`` and return it as a PNG data URL.

    Args:
        i: Index of the individual to render.
        scale: Resize factor forwarded to ``embeddable_image``. The default 1.0
            keeps full resolution (high quality for Matplotlib); tooltips can be
            downscaled afterwards, e.g. with ``create_thumbnails``.
        population: Indexable collection of individuals exposing ``to_graph()``.
            Defaults to the notebook-global ``POPULATION`` so existing calls keep
            working. Passing it explicitly removes the hidden dependency on
            global state.

    Returns:
        A ``'data:image/png;base64,...'`` string.
    """
    if population is None:
        population = POPULATION  # legacy fallback: notebook-global population
    graph = population[i].to_graph()
    # remove_background=True is used to obtain transparent PNGs.
    img_arr = np.array(view(graph, return_img=True, tilted=True, remove_background=True))
    return embeddable_image(img_arr, scale=scale)

def get_population_images(population_size):
    """
    Render every individual once and return the images as PNG data URLs.

    Rendering is the expensive step, so call this once and reuse the list for
    every plot instead of re-rendering.

    Args:
        population_size: Number of individuals to render (indices
            ``0 .. population_size - 1``).

    Returns:
        List of data-URL strings in index order.
    """
    label = f"Generating {population_size} images..."
    images = []
    for idx in track(range(population_size), description=label):
        images.append(robot_image(idx))
    return images

def decode_base64_image(data_url):
    """
    Decode a base64-encoded image back to a numpy array.

    Args:
        data_url: Either a full data URL (``'data:image/png;base64,...'``, as
            produced by ``embeddable_image``) or a bare base64 payload.

    Returns:
        ``np.ndarray`` of the decoded pixels; its shape follows the image mode
        (e.g. HxWx4 for an RGBA PNG).
    """
    # Everything before the first comma is the data-URL header; a string without
    # a comma is treated as a bare base64 payload instead of crashing on unpack.
    _, sep, payload = data_url.partition(",")
    encoded = payload if sep else data_url
    with Image.open(BytesIO(base64.b64decode(encoded))) as img:
        return np.array(img)

def create_thumbnails(image_list, scale=0.5):
    """
    Build scaled-down copies of base64 PNG data URLs for use in web tooltips.

    Every image is resized by the same factor, so each image's aspect ratio
    and the relative sizes between images are preserved. The input list is
    not modified.

    Args:
        image_list: Iterable of ``'data:image/png;base64,...'`` strings.
        scale: Resize factor applied to both dimensions; must be > 0.

    Returns:
        New list of PNG data-URL strings, in the same order as ``image_list``.

    Raises:
        ValueError: If ``scale`` is not positive.
    """
    if scale <= 0:
        raise ValueError(f"scale must be positive, got {scale}")

    thumbnails = []
    for b64_str in image_list:
        # Decode (drop the data-URL header before the first comma).
        _, encoded = b64_str.split(",", 1)
        with Image.open(BytesIO(base64.b64decode(encoded))) as img:
            # Clamp to 1px so tiny images never collapse to a zero-size resize.
            new_size = (max(1, int(img.width * scale)), max(1, int(img.height * scale)))
            img_small = img.resize(new_size, Image.Resampling.BICUBIC)

        # Re-encode as PNG data URL.
        buffer = BytesIO()
        img_small.save(buffer, format='PNG')
        thumbnails.append('data:image/png;base64,' + base64.b64encode(buffer.getvalue()).decode())
    return thumbnails

# --- 2. Interactive Bokeh Functions (Web/Interactive) ---

def matrix_to_heatmap_source(matrix, images, metric_name, radius):
    """
    Flatten a 2-D matrix into a Bokeh ``ColumnDataSource`` of heatmap cells.

    Cell ``matrix[row, col]`` becomes one record at ``x = col`` and
    ``y = n_rows - 1 - row``, so row 0 is drawn at the top, as the matrix is
    printed. Each record also carries the image and string id of its row and
    column individual for hover tooltips.

    Args:
        matrix: 2-D array of values. Rectangular matrices are supported; rows
            and columns both index into ``images``.
        images: Sequence of image data URLs indexed by individual id. Must be
            at least ``max(n_rows, n_cols)`` long.
        metric_name: Label repeated on every record (shown in the tooltip).
        radius: Radius repeated on every record (shown in the tooltip).

    Returns:
        ``ColumnDataSource`` with columns ``x, y, value, img_row, img_col,
        id_row, id_col, metric, radius``.
    """
    matrix = np.asarray(matrix)
    n_rows, n_cols = matrix.shape
    # Column index varies fastest, matching the row-major order of flatten().
    col_idx, row_idx = np.meshgrid(np.arange(n_cols), np.arange(n_rows))
    rows = row_idx.flatten()
    cols = col_idx.flatten()
    values = matrix.flatten()
    n_cells = len(values)
    data = {
        'x': cols,
        'y': n_rows - 1 - rows,
        'value': values,
        'img_row': [images[r] for r in rows],
        'img_col': [images[c] for c in cols],
        'id_row': [str(r) for r in rows],
        'id_col': [str(c) for c in cols],
        'metric': [metric_name] * n_cells,
        'radius': [radius] * n_cells,
    }
    return ColumnDataSource(data)

def plot_interactive_heatmaps(all_matrix_data: dict, population_images: list, max_show_radius: int, plot_width=None, plot_height=None, palette="Reds256", thumbnail_scale=0.5):
    """
    Show a grid of interactive Bokeh heatmaps: one row per radius, one column per metric.

    Hovering a cell shows its value plus thumbnails of the row and column
    individuals. Each heatmap uses its own (local) color range.

    Args:
        all_matrix_data: ``{metric_name: {radius: 2-D matrix}}``. A radius
            missing for a metric is drawn as a 1x1 zero placeholder whose
            title is marked "(Missing)".
        population_images: Full-resolution image data URLs indexed by
            individual id; this list is left untouched.
        max_show_radius: Radii ``0..max_show_radius`` (inclusive) are plotted.
        plot_width: Width per plot in px; defaults to 650 for a single column,
            otherwise 350.
        plot_height: Height per plot in px; defaults to 600 for a single
            column, otherwise 300.
        palette: Bokeh palette name for the color mapper.
        thumbnail_scale: Factor to scale images down for the tooltip (default 0.5).
    """
    # Thumbnails are built per call, so the caller's full-resolution list is untouched.
    thumb_images = create_thumbnails(population_images, scale=thumbnail_scale)

    num_cols = len(all_matrix_data)
    if plot_width is None:
        plot_width = 650 if num_cols == 1 else 350
    if plot_height is None:
        plot_height = 600 if num_cols == 1 else 300

    # Loop-invariant tooltip template. No CSS max-width constraints: the
    # thumbnails are already physically smaller but proportional to each other.
    tooltip_html = """
            <div style="display: flex; flex-direction: column; align-items: center; background: white; padding: 5px;">
                <div style="font-weight: bold; margin-bottom: 5px;">@metric (r=@radius) Val: @value{0.000}</div>
                <div style="display: flex; flex-direction: row; gap: 10px;">
                    <div style="text-align: center;"><span style="font-size: 10px;">Row: @id_row</span><br><img src="@img_row" style="width: auto; height: auto;"></div>
                    <div style="text-align: center;"><span style="font-size: 10px;">Col: @id_col</span><br><img src="@img_col" style="width: auto; height: auto;"></div>
                </div>
            </div>
            """

    grid_layout = []
    for r in range(max_show_radius + 1):
        row_plots = []
        for name, matrix_dict in all_matrix_data.items():
            missing = r not in matrix_dict
            # A missing radius becomes a 1x1 zero placeholder so the grid stays rectangular.
            matrix = np.zeros((1, 1)) if missing else np.asarray(matrix_dict[r])
            title = f"r:{r} {name} (Missing)" if missing else f"r:{r} {name}"

            source = matrix_to_heatmap_source(matrix, thumb_images, name, r)
            # Local color scaling. NaN-aware so a single NaN cell does not make the
            # whole range NaN (NaN cells get the mapper's nan_color instead).
            vmin, vmax = float(np.nanmin(matrix)), float(np.nanmax(matrix))
            mapper = LinearColorMapper(palette=palette, low=vmin, high=vmax)

            p = figure(title=title, x_range=(-0.5, matrix.shape[1] - 0.5), y_range=(-0.5, matrix.shape[0] - 0.5), width=plot_width, height=plot_height, tools="hover,save,reset", toolbar_location="above")
            p.axis.visible = False
            p.grid.visible = False
            p.rect(x='x', y='y', width=1, height=1, source=source, fill_color={'field': 'value', 'transform': mapper}, line_color=None)
            color_bar = ColorBar(color_mapper=mapper, ticker=BasicTicker(), label_standoff=8, border_line_color=None, location=(0, 0), width=8)
            p.add_layout(color_bar, 'right')

            hover = p.select(dict(type=HoverTool))
            hover.tooltips = tooltip_html
            row_plots.append(p)
        grid_layout.append(row_plots)
    show(gridplot(grid_layout))

def plot_interactive_umap_grid(umap_data: dict, population_images: list, max_show_radius: int, plot_width=None, plot_height=350, thumbnail_scale=0.5):
    """
    Show a grid of interactive UMAP scatter plots: one row per radius, one column per metric.

    Hovering a point shows the individual's thumbnail and id. Each individual
    keeps the same color in every plot.

    Args:
        umap_data: ``{metric_name: {radius: embedding}}`` where an embedding is
            array-like of shape ``(len(population_images), 2)``. A missing
            radius yields an empty "(Missing)" plot; an embedding that is not
            2-D with two columns yields an empty "(No Data)" plot.
        population_images: Full-resolution image data URLs indexed by
            individual id; this list is left untouched.
        max_show_radius: Radii ``0..max_show_radius`` (inclusive) are plotted.
        plot_width: Width per plot in px; defaults to 700 for a single column,
            otherwise 350.
        plot_height: Height per plot in px.
        thumbnail_scale: Factor to scale images down for the tooltip (default 0.5).

    Raises:
        ValueError: If an embedding's number of points differs from the number
            of population images.
    """
    # Thumbnails are built per call, so the caller's full-resolution list is untouched.
    thumb_images = create_thumbnails(population_images, scale=thumbnail_scale)

    num_cols = len(umap_data)
    if plot_width is None:
        plot_width = 700 if num_cols == 1 else 350

    # Per-individual invariants, shared by every plot in the grid.
    n = len(population_images)
    hex_colors = [to_hex(c) for c in plt.cm.rainbow(np.linspace(0, 1, n))]
    ids = [str(i) for i in range(n)]
    tooltip_html = """<div><img src='@image' style='float:left; margin:5px; width:auto; height:auto;'/></div><div style="font-size:12px; font-weight: bold;"><span style='color:#224499'>ID: @digit</span></div>"""

    grid_layout = []
    for r in range(max_show_radius + 1):
        row_plots = []
        for name, matrix_dict in umap_data.items():
            if r not in matrix_dict:
                row_plots.append(figure(title=f"r:{r} {name} (Missing)", width=plot_width, height=plot_height))
                continue
            # asarray lets plain lists of (x, y) pairs work as embeddings too.
            emb = np.asarray(matrix_dict[r])
            if emb.ndim != 2 or emb.shape[1] != 2:
                row_plots.append(figure(title=f"r:{r} {name} (No Data)", width=plot_width, height=plot_height))
                continue
            if emb.shape[0] != n:
                raise ValueError(
                    f"UMAP embedding for {name!r} at r={r} has {emb.shape[0]} points, "
                    f"but {n} population images were given."
                )

            robots_df = pd.DataFrame({"x": emb[:, 0], "y": emb[:, 1], "digit": ids, "image": thumb_images, "color": hex_colors})
            source = ColumnDataSource(robots_df)
            p = figure(title=f"r:{r} {name}", width=plot_width, height=plot_height, tools="pan,wheel_zoom,reset,save", toolbar_location="above")
            p.scatter('x', 'y', source=source, color='color', line_alpha=0.6, fill_alpha=0.7, size=8)
            p.add_tools(HoverTool(tooltips=tooltip_html))
            row_plots.append(p)
        grid_layout.append(row_plots)
    show(gridplot(grid_layout))

# --- 3. Static Matplotlib Functions (Rows of Pairs) ---

# def view_horizontal_groups(image_data_tuples, titles=None):
#     """
#     Plots horizontal groups without rescaling images of different sizes.
#     Uses pixel-perfect width_ratios.
#     Args:
#         image_data_tuples: List of lists containing base64 image strings (HQ).
#     """
#     # --- 1. Pre-fetch ALL images to get dimensions ---
#     processed_items = []  # Will store dicts: {'type': 'img', 'data': img} or {'type': 'gap'}
    
#     GROUP_GAP_PX = 100 
#     ROBOT_GAP_PX = 20  
#     DPI = 100 # High DPI for Sharp Text/Images
#     max_height = 0
    
#     # Map from flattened index back to group index for titles
#     axes_map = {} 
#     current_index = 0

#     for i, group in enumerate(image_data_tuples):
#         for j, img_str in enumerate(group):
#             # Decode the base64 string to numpy array for plotting
#             img = decode_base64_image(img_str)
#             h, w = img.shape[:2]

#             if h > max_height: max_height = h

#             processed_items.append({"type": "img", "data": img, "width": w, "height": h})
#             current_index += 1

#             if j < len(group) - 1:
#                 processed_items.append({"type": "gap", "width": ROBOT_GAP_PX})
#                 current_index += 1

#         if i < len(image_data_tuples) - 1:
#             processed_items.append({"type": "gap", "width": GROUP_GAP_PX})
#             current_index += 1

#     # --- 2. Calculate Figure Dimensions ---
#     total_width_px = sum(item.get("width", 0) for item in processed_items)
#     fig_width_in = total_width_px / DPI
#     fig_height_in = max_height / DPI
#     title_buffer_in = 0.5
#     fig_height_total = fig_height_in + title_buffer_in

#     # --- 3. Create Figure ---
#     fig = plt.figure(figsize=(fig_width_in, fig_height_total), dpi=DPI)
#     widths = [item.get("width", 0) for item in processed_items]
    
#     gs = gridspec.GridSpec(1, len(processed_items), figure=fig, width_ratios=widths)
#     plt.subplots_adjust(left=0, right=1, bottom=0, top=fig_height_in / fig_height_total, wspace=0, hspace=0)

#     # --- 4. Plotting ---
#     flat_ptr = 0 # To track position in processed_items list

#     for idx, item in enumerate(processed_items):
#         if item["type"] == "gap":
#             continue

#         ax = fig.add_subplot(gs[0, idx])
#         ax.imshow(item["data"], aspect="equal", interpolation="none", origin="upper")
#         ax.set_ylim(max_height, 0)
#         ax.set_xlim(0, item["width"])
#         ax.axis("off")
        
#         # Store ax in map for title logic
#         axes_map[idx] = ax

#     # --- 5. Titles ---
#     if titles:
#         # We assume the processed_items structure directly maps to groups
#         # Re-iterate to find the start and end axis for each group
#         ptr = 0
#         for i, group in enumerate(image_data_tuples):
#             # Calculate how many items correspond to this group in the flat list
#             # Each robot (except last) adds 1 image + 1 gap. Last robot adds 1 image.
#             # Then there is 1 group gap (except last group).
            
#             # Start of group
#             start_idx = ptr
#             start_ax = axes_map[start_idx]
            
#             # Find end of group (index of the last robot image in this group)
#             # The number of slots used by robots = (len(group) * 2) - 1
#             # e.g., 2 robots -> [Img, Gap, Img] -> offset is 2
#             offset = (len(group) * 2) - 1
#             end_idx = ptr + offset - 1 
#             end_ax = axes_map[end_idx]
            
#             # Calculate Center
#             bbox_start = start_ax.get_position()
#             bbox_end = end_ax.get_position()
#             center_x = (bbox_start.x0 + bbox_end.x1) / 2

#             fig.text(center_x, 1.0 - (0.2 / fig_height_total), titles[i], ha="center", va="top", fontsize=10, weight="bold")
            
#             # Move pointer: robots + group gap
#             ptr += offset
#             if i < len(image_data_tuples) - 1:
#                 ptr += 1 # Skip group gap

#     plt.show()

In [15]:
def _stitch_images_horizontally(images, target_height=None, gap_px=20):
    """
    Stitches images horizontally with white background (handles transparency).
    If target_height is provided, pads all images vertically to match that height (alignment: top).
    """
    if not images or all(img is None for img in images):
        return None, 0, 0
    
    valid_images = [img for img in images if img is not None]
    if not valid_images: return None, 0, 0

    # 1. Normalize Types to uint8 RGB and composite transparent images onto white
    normalized_imgs = []
    for img in valid_images:
        # Handle float 0-1
        if np.issubdtype(img.dtype, np.floating):
            img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
        else:
            img = img.astype(np.uint8)
        
        # Handle Alpha channel - composite onto white background
        if len(img.shape) == 3 and img.shape[2] == 4:
            # Extract RGB and Alpha
            rgb = img[:, :, :3]
            alpha = img[:, :, 3:4] / 255.0  # Normalize alpha to 0-1
            
            # Create white background
            white_bg = np.full_like(rgb, 255, dtype=np.uint8)
            
            # Composite: result = foreground * alpha + background * (1 - alpha)
            img = (rgb * alpha + white_bg * (1 - alpha)).astype(np.uint8)
        elif len(img.shape) == 3 and img.shape[2] >= 3:
            img = img[:, :, :3]
            
        normalized_imgs.append(img)

    # 2. Determine Canvas Height
    current_max_h = max(img.shape[0] for img in normalized_imgs)
    final_h = target_height if target_height and target_height > current_max_h else current_max_h

    # 3. Create White Gap Column
    white_gap_col = np.full((final_h, gap_px, 3), 255, dtype=np.uint8)

    # 4. Stitching Loop
    stitched = None
    
    for i, img in enumerate(normalized_imgs):
        h, w = img.shape[:2]
        
        # Pad image to final_h (fill bottom with white)
        if h < final_h:
            pad = np.full((final_h - h, w, 3), 255, dtype=np.uint8)
            img = np.vstack((img, pad))
            
        if stitched is None:
            stitched = img
        else:
            stitched = np.hstack((stitched, white_gap_col, img))
            
    return stitched, stitched.shape[1], final_h

def view_grid_of_groups(rows_of_tuples, rows_of_titles=None, col_headers=None, main_title=None, gap_px=20):
    """
    Plot a grid of image groups, drawing every image at the same pixel scale.

    Each grid cell holds a group of encoded images. They are decoded with
    ``decode_base64_image`` and stitched side by side with
    ``_stitch_images_horizontally``. All axes share the same limits: the tallest
    image and the widest stitched group across the whole grid. No cell is
    auto-zoomed, so sizes can be compared visually.

    Args:
        rows_of_tuples: Rows of cells; each cell is an iterable of encoded image
            strings. Rows may differ in length; missing cells are left blank.
        rows_of_titles: Optional per-cell titles in the same row/column layout.
            Missing rows or entries are simply left untitled.
        col_headers: Optional headers drawn above the cells of the first row.
        main_title: Optional figure-level title.
        gap_px: White gap in pixels between images within one group.
    """
    if not rows_of_tuples:
        return

    n_rows = len(rows_of_tuples)
    # Size the grid by the widest row so ragged input neither crashes nor drops cells.
    n_cols = max(len(row) for row in rows_of_tuples)
    if n_cols == 0:
        return

    # --- Pass 1: decode every group and find the tallest image in the grid ---
    decoded_grid = [
        [
            [decode_base64_image(s) for s in row[c]] if c < len(row) else []
            for c in range(n_cols)
        ]
        for row in rows_of_tuples
    ]
    global_max_h = max(
        (
            img.shape[0]
            for row in decoded_grid
            for images in row
            for img in images
            if img is not None
        ),
        default=0,
    )

    # --- Pass 2: stitch each group to the shared height, track the widest ---
    global_max_w = 0
    stitched_grid = []
    for row in decoded_grid:
        stitched_row = []
        for images in row:
            stitched, width, _ = _stitch_images_horizontally(
                images, target_height=global_max_h, gap_px=gap_px
            )
            global_max_w = max(global_max_w, width)
            stitched_row.append(stitched)
        stitched_grid.append(stitched_row)

    # Matplotlib warns on identical low/high limits, e.g. when every cell is empty.
    x_limit = max(global_max_w, 1)
    y_limit = max(global_max_h, 1)

    # --- Pass 3: plot every cell with identical, fixed limits ---
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(4 * n_cols, 2.5 * n_rows),
        squeeze=False,
        facecolor='white',
    )

    if main_title:
        fig.suptitle(main_title, fontsize=16, weight="bold", y=0.98, color='black')

    for r in range(n_rows):
        row_titles = (rows_of_titles[r] or ()) if rows_of_titles and r < len(rows_of_titles) else ()
        for c in range(n_cols):
            ax = axes[r, c]
            ax.set_facecolor('white')

            img_data = stitched_grid[r][c]
            if img_data is not None:
                ax.imshow(img_data)

            # Force equal scaling across all cells (top-left origin, like imshow).
            ax.set_xlim(0, x_limit)
            ax.set_ylim(y_limit, 0)
            ax.set_aspect('equal')
            ax.axis('off')

            if c < len(row_titles):
                ax.set_title(row_titles[c], fontsize=10, color='black', pad=0)

            if r == 0 and col_headers and c < len(col_headers):
                ax.text(0.5, 1, col_headers[c], transform=ax.transAxes,
                        ha="center", va="bottom", fontsize=12, weight="bold", color="#224499")

    fig.subplots_adjust(wspace=0.1, hspace=0.25, top=0.95, bottom=0.02)
    plt.show()

def plot_rows_for_radii(cumulative_data, sorted_data, population_images, max_radius: int, pair_rank: int = 0, labels: list[str] = None, main_title: str = None):
    """
    Show, for every radius and metric, the robot pair at position ``pair_rank``.

    One grid row is built per radius ``r`` in ``0..max_radius`` (inclusive), with
    one column per metric in ``labels``. The pair shown is
    ``sorted_data[name][r][pair_rank]``, so ``0`` is the first pair in that
    metric's sorted order and ``-1`` the last. Its value is read from
    ``cumulative_data[name][r][i, j]``.

    Args:
        cumulative_data: ``{metric name: {radius: matrix}}``; matrices are indexed
            ``matrix[i, j]`` with robot indices.
        sorted_data: ``{metric name: {radius: sequence of (i, j) pairs}}``.
        population_images: Images indexable by robot index, forwarded to
            ``view_grid_of_groups``.
        max_radius: Largest radius to plot (inclusive).
        pair_rank: Position in each sorted pair list. Negative values count from
            the end, like normal Python indexing.
        labels: Metric names (grid columns) to show. Defaults to all keys of
            ``cumulative_data``.
        main_title: Optional figure title.

    When a metric has no pair at ``pair_rank`` for a radius, that cell is left
    empty and titled accordingly. It is not filled with robots (0, 0).
    """
    if labels is None:
        labels = list(cumulative_data.keys())

    all_rows_robots = []
    all_rows_titles = []

    for r in range(max_radius + 1):
        robots_row = []
        titles_row = []
        for name in labels:
            coords_list = sorted_data[name].get(r)
            n_pairs = 0 if coords_list is None else len(coords_list)

            # Valid Python index range, including negatives (-1 = last pair).
            if not -n_pairs <= pair_rank < n_pairs:
                robots_row.append([])
                titles_row.append(f"{name}\nr:{r} (no pair at rank {pair_rank})")
                continue

            i, j = coords_list[pair_rank]
            idx_i, idx_j = int(i), int(j)
            matrix = cumulative_data[name].get(r)
            val = matrix[idx_i, idx_j] if matrix is not None else 0.0

            robots_row.append([population_images[idx_i], population_images[idx_j]])
            titles_row.append(f"{name}\nr:{r} ({idx_i},{idx_j}) val={val:.3f}")

        all_rows_robots.append(robots_row)
        all_rows_titles.append(titles_row)

    view_grid_of_groups(all_rows_robots, all_rows_titles, col_headers=None, main_title=main_title)

---

## GLOBAL ANALYSIS SETTINGS

In [16]:
# Number of random robots, and the module count passed to generate_random_individual.
POPULATION_SIZE = 100
NUM_OF_MODULES = 20

# Radius cap handed to SimilarityConfig.
# NOTE(review): None presumably means "no cap" — confirm in ctk.SimilarityConfig.
MAX_RADIUS = None

CONFIG = ctk.SimilarityConfig(
    radius_strategy=ctk.RadiusStrategy.NODE_LOCAL,
    max_tree_radius=MAX_RADIUS,
)

In [17]:
# Random robots, each wrapped by ctk.from_graph.
POPULATION = [
    ctk.from_graph(generate_random_individual(NUM_OF_MODULES))
    for _robot_idx in range(POPULATION_SIZE)
]

# Subtree hashes per robot under CONFIG, then the population-wide count matrices
# (presumably one per radius, judging by how COUNT_MATRIX_DICT is consumed below).
SUBTREES = [
    ctk.collect_tree_hash_config_mode(robot, config=CONFIG) for robot in POPULATION
]
COUNT_MATRIX_DICT = ctk.get_count_matrix(SUBTREES, CONFIG)

In [None]:
# Robot images indexed by robot index (used below as population_images[idx]).
# NOTE(review): only the size is passed, not POPULATION itself — confirm that
# get_population_images renders these same robots in the same order.
POPULATION_IMGS = get_population_images(POPULATION_SIZE)

Output()

MESA: error: ZINK: failed to choose pdev
glx: failed to create drisw screen
Dropped Escape call with ulEscapeCode : 0x03007703


In [None]:
def apply_tfidf_transformer(count_matrix):
    """Return the row-by-row cosine similarity of ``count_matrix`` after TF-IDF reweighting."""
    weighted_counts = TfidfTransformer().fit_transform(count_matrix)
    return cosine_similarity(weighted_counts)


def _umap_cosine_embedding(count_matrix, n_neighbors, random_state=None):
    """
    Embed the rows of ``count_matrix`` with UMAP using the cosine metric.

    Args:
        count_matrix: Matrix whose rows are the samples to embed.
        n_neighbors: UMAP neighbourhood size.
        random_state: Forwarded to ``umap.UMAP``. ``None`` (UMAP's default) gives
            non-deterministic runs; pass an int for reproducible embeddings.

    Returns:
        The result of ``UMAP.fit_transform``, one row per input row.
    """
    reducer = umap.UMAP(metric="cosine", n_neighbors=n_neighbors, random_state=random_state)
    return reducer.fit_transform(count_matrix)


def apply_umap_n2(count_matrix):
    """Cosine UMAP embedding of ``count_matrix`` with ``n_neighbors=2``."""
    return _umap_cosine_embedding(count_matrix, n_neighbors=2)


def apply_umap_n20(count_matrix):
    """Cosine UMAP embedding of ``count_matrix`` with ``n_neighbors=20``."""
    return _umap_cosine_embedding(count_matrix, n_neighbors=20)


def apply_umap_n40(count_matrix):
    """Cosine UMAP embedding of ``count_matrix`` with ``n_neighbors=40``."""
    return _umap_cosine_embedding(count_matrix, n_neighbors=40)

def apply_emb_to_dist(umap_emb_matrix):
    """Return the square matrix of pairwise Euclidean distances between embedding rows."""
    return squareform(pdist(umap_emb_matrix, metric="euclidean"))

In [None]:
# Similarities computed directly on the subtree count vectors.
cos_matrix_dict = ctk.matrix_dict_applier(COUNT_MATRIX_DICT, cosine_similarity)
tfidf_matrix_dict = ctk.matrix_dict_applier(COUNT_MATRIX_DICT, apply_tfidf_transformer)

# UMAP embeddings for three neighbourhood sizes (computed in the order n2, n20, n40) ...
umap_dict_n2, umap_dict_n20, umap_dict_n40 = (
    ctk.matrix_dict_applier(COUNT_MATRIX_DICT, embed)
    for embed in (apply_umap_n2, apply_umap_n20, apply_umap_n40)
)

# ... and the pairwise Euclidean distances between the embedded robots.
umapdist_matrix_dict_n2, umapdist_matrix_dict_n20, umapdist_matrix_dict_n40 = (
    ctk.matrix_dict_applier(embedding, apply_emb_to_dist)
    for embedding in (umap_dict_n2, umap_dict_n20, umap_dict_n40)
)

In [None]:
# Per-radius matrix dicts for every metric; insertion order is the plotting order.
ALL_MATRIX_DATA = {
    "umap_dist_n2": umapdist_matrix_dict_n2,
    "umap_dist_n20": umapdist_matrix_dict_n20,
    "umap_dist_n40": umapdist_matrix_dict_n40,
    "cosine": cos_matrix_dict,
    "tfidf": tfidf_matrix_dict,
}
LABELS = list(ALL_MATRIX_DATA)

# Sort direction per metric. The UMAP entries are Euclidean distances, so the
# smallest value comes first. Cosine/TF-IDF are similarities, so the largest
# value comes first.
METRIC_MAX_FIRST = {
    "umap_dist_n2": False,
    "umap_dist_n20": False,
    "umap_dist_n40": False,
    "cosine": True,
    "tfidf": True,
}

CUMULATIVE_MATRIX_DATA = {
    name: get_cum_matrix(per_radius) for name, per_radius in ALL_MATRIX_DATA.items()
}

SORTED_IDX_DATA = {
    name: sorted_idx_dict(cumulative, max_first=METRIC_MAX_FIRST[name])
    for name, cumulative in CUMULATIVE_MATRIX_DATA.items()
}

UMAP_EMBEDDINGS = {
    "umap_emb_n2": umap_dict_n2,
    "umap_emb_n20": umap_dict_n20,
    "umap_emb_n40": umap_dict_n40,
}

# Largest radius passed to the per-radius plots below.
MAX_SHOW_RADIUS = 9

---

### Matrix per radius

In [None]:
# Heatmaps of every metric's raw (non-cumulative) per-radius matrices, up to MAX_SHOW_RADIUS.
plot_comparison_heatmaps(ALL_MATRIX_DATA, MAX_SHOW_RADIUS)

### Cumulative matrix per radius

In [None]:
# Same comparison, using the cumulative matrices produced by get_cum_matrix.
plot_comparison_heatmaps(CUMULATIVE_MATRIX_DATA, MAX_SHOW_RADIUS)

### Interactive heatmaps

In [None]:
# Interactive heatmaps of the cumulative matrices; POPULATION_IMGS supplies the robot images.
plot_interactive_heatmaps(CUMULATIVE_MATRIX_DATA, POPULATION_IMGS, MAX_SHOW_RADIUS)

### Show 2D UMAP embeddings

In [None]:
# Interactive grid of the UMAP embeddings for each n_neighbors setting.
plot_interactive_umap_grid(UMAP_EMBEDDINGS, POPULATION_IMGS, MAX_SHOW_RADIUS)

### Most similar pair per radius, for each metric

In [None]:
# Index into each metric's sorted pair list: 0 is the first (most similar) pair.
# The next cell increments this every time it runs.
rank = 0

In [None]:
# Intentionally stateful: each run plots the current rank, then steps to the next pair.
# Re-run the cell above to start again from the most similar pair.
print(f'rank {rank}')

plot_rows_for_radii(
    cumulative_data=CUMULATIVE_MATRIX_DATA,
    sorted_data=SORTED_IDX_DATA,
    population_images=POPULATION_IMGS,
    max_radius=MAX_SHOW_RADIUS,
    pair_rank=rank,
    labels=LABELS,
    main_title=f"Comparison Rank {np.abs(rank)} (Most Similar)",
)

rank = rank + 1

### Least similar pair per radius, for each metric

In [None]:
# Negative index into each metric's sorted pair list: -1 is the last (least similar) pair.
# The next cell decrements this every time it runs.
rank = -1

In [None]:
# Intentionally stateful: each run plots the current rank, then steps one pair
# further from the end. Re-run the cell above to restart at the least similar pair.
print(f'rank {rank}')

plot_rows_for_radii(
    cumulative_data=CUMULATIVE_MATRIX_DATA,
    sorted_data=SORTED_IDX_DATA,
    population_images=POPULATION_IMGS,
    max_radius=MAX_SHOW_RADIUS,
    pair_rank=rank,
    labels=LABELS,
    main_title=f"Comparison Rank {np.abs(rank)} (Least Similar)",
)

rank = rank - 1