In [5]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, Dropdown, IntSlider, VBox, Button, GridBox, Layout, HBox
from typing import Tuple, Optional
import numpy.typing as npt
from IPython.display import display, clear_output
import random
from classes import load_all
from ipywidgets import Output
import matplotlib.colors as mcolors


# Type aliases
Array2D = npt.NDArray[np.float64]  # float64 ndarray; "2D" is a naming convention, not enforced

# Constant palette: integer cell value -> matplotlib color name.
COLOR_MAP = dict(enumerate([
    'black', 'blue', 'red', 'green', 'yellow',
    'grey', 'pink', 'orange', 'lightblue', 'darkred',
]))

def get_custom_colormap():
    """Return a ListedColormap whose k-th color is ``COLOR_MAP[k]`` for k in 0..9."""
    # ListedColormap assigns colors in list order, so position k must hold COLOR_MAP[k].
    ordered_colors = list(map(COLOR_MAP.__getitem__, range(10)))
    return mcolors.ListedColormap(ordered_colors)

def shannon_entropy(arr: npt.ArrayLike) -> float:
    """Return the Shannon entropy, in bits, of the values in ``arr``.

    Each distinct value is treated as one symbol whose probability is its
    relative frequency in ``arr``. Any shape is accepted: the values are
    pooled regardless of position (callers pass a 1-D masked selection).

    Args:
        arr: Array-like of discrete values.

    Returns:
        ``-sum(p * log2(p))`` over the symbol probabilities, as a Python float.
        The result is exactly 0.0 for empty input or a single distinct value.
    """
    _, counts = np.unique(arr, return_counts=True)
    if counts.size <= 1:
        # Zero or one distinct symbol carries no information. Returning early
        # also avoids a 0/0 total and yields +0.0 rather than -0.0.
        return 0.0
    probs = counts / counts.sum()
    return float(-np.sum(probs * np.log2(probs)))

def convolve2d_entropy(input_arr: npt.ArrayLike, kernel: npt.ArrayLike,
                       stride: int = 1, padding: int = 0) -> npt.NDArray[np.float64]:
    """Slide ``kernel`` over ``input_arr`` and record the entropy of each window.

    This is not an arithmetic convolution: the kernel acts only as a mask.
    Cells where ``kernel != 0`` select which values of the current window are
    passed to ``shannon_entropy``; the kernel's magnitudes are ignored.

    Args:
        input_arr: 2-D grid of cell values.
        kernel: 2-D mask; nonzero entries mark the cells included in each window.
        stride: Step between successive windows along both axes (must be >= 1).
        padding: Border width added on every side. Padded cells hold ``-1``, so
            they count as their own symbol in the entropy.

    Returns:
        Float array of shape ``((H - kh) // stride + 1, (W - kw) // stride + 1)``,
        where H and W are the padded input dims. A dimension is 0 when the
        kernel is larger than the padded input along that axis.

    Raises:
        ValueError: if ``stride`` < 1.
    """
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")

    input_arr = np.asarray(input_arr)
    # The mask is identical for every window, so build it once.
    mask = np.asarray(kernel) != 0
    if padding > 0:
        input_arr = np.pad(input_arr, padding, mode='constant', constant_values=-1)

    input_h, input_w = input_arr.shape
    kernel_h, kernel_w = mask.shape

    # Clamp at 0: an oversized kernel yields an empty result instead of
    # np.zeros raising on a negative dimension.
    output_h = max(0, (input_h - kernel_h) // stride + 1)
    output_w = max(0, (input_w - kernel_w) // stride + 1)

    output = np.zeros((output_h, output_w))
    if not mask.any():
        # An empty mask selects nothing in any window, so every entry stays 0.0.
        return output

    for i in range(output_h):
        for j in range(output_w):
            top, left = i * stride, j * stride
            window = input_arr[top:top + kernel_h, left:left + kernel_w]
            output[i, j] = shannon_entropy(window[mask])

    return output

class ClickableFilterWidget:
    """Clickable square grid of 0/1 cells that defines the convolution mask.

    Clicking a cell toggles it between 0 and 1 and redraws the plot into the
    attached ``Output`` widget, once one is set via ``set_output_widget``.
    """

    def __init__(self, filt_size: Optional[int] = None):
        """Create an all-zero ``filt_size`` x ``filt_size`` filter grid.

        Args:
            filt_size: Side length of the grid. When omitted, the module-level
                ``current_filt_size`` is read at call time, which preserves the
                original zero-argument behavior.

        Raises:
            ValueError: if ``filt_size`` < 1.
        """
        if filt_size is None:
            filt_size = current_filt_size
        if filt_size < 1:
            raise ValueError(f"filt_size must be >= 1, got {filt_size}")
        self.filter_values = np.zeros((filt_size, filt_size), dtype=int)
        self.buttons = []
        self.output_widget = None
        self.create_buttons()

    @property
    def filt_size(self) -> int:
        """Side length of this grid, derived from ``filter_values``."""
        return self.filter_values.shape[0]

    def create_buttons(self) -> None:
        """(Re)build one 60x60px toggle button per filter cell, row-major.

        Each button's label and style reflect the cell's current value.
        """
        n_rows, n_cols = self.filter_values.shape
        self.buttons = []
        for i in range(n_rows):
            row_buttons = []
            for j in range(n_cols):
                value = self.filter_values[i, j]
                btn = Button(
                    description=str(value),
                    layout=Layout(width='60px', height='60px'),
                    button_style='info' if value else '',
                    style={'font_weight': 'bold', 'font_size': '16px'},
                )
                # Default args bind i/j now; a bare closure would see only the final loop values.
                btn.on_click(lambda _btn, row=i, col=j: self.toggle_cell(row, col))
                row_buttons.append(btn)
            self.buttons.append(row_buttons)

    def toggle_cell(self, i: int, j: int) -> None:
        """Flip cell (i, j) between 0 and 1, refresh its button, and redraw."""
        self.filter_values[i, j] = 1 - self.filter_values[i, j]
        value = self.filter_values[i, j]
        self.buttons[i][j].description = str(value)
        self.buttons[i][j].button_style = 'info' if value else ''
        self.update_visualization()

    def get_widget(self):
        """Return a fresh GridBox laying out the buttons with one column per filter column."""
        n_cols = self.filter_values.shape[1]
        return GridBox(
            children=[btn for row in self.buttons for btn in row],
            layout=Layout(grid_template_columns=f'repeat({n_cols}, 60px)', grid_gap='2px'),
        )

    def set_output_widget(self, output_widget):
        """Attach the Output widget that ``update_visualization`` draws into."""
        self.output_widget = output_widget

    def _reduced_kernel(self) -> npt.NDArray[np.float64]:
        """Crop the filter to the smallest rectangle containing every 1.

        Returns:
            The cropped 0/1 filter as float64. If no cell is set, returns a
            1x1 zero kernel, which masks out everything.
        """
        ones = np.argwhere(self.filter_values == 1)
        if ones.size == 0:
            return np.zeros((1, 1), dtype=float)
        (rmin, cmin), (rmax, cmax) = ones.min(0), ones.max(0)
        return self.filter_values[rmin:rmax + 1, cmin:cmax + 1].astype(np.float64)

    def update_visualization(self):
        """Redraw the convolution plot for the current selection.

        Reads the module-level ``current_*`` selection globals and passes this
        grid's cropped kernel. Does nothing until an output widget is attached.
        """
        if self.output_widget is None:
            return
        with self.output_widget:
            clear_output(wait=True)
            visualize_convolution_static(
                current_input_name,
                current_input_num,
                self._reduced_kernel(),
                current_stride,
                current_padding,
            )

# Module-level UI state: read by ClickableFilterWidget.update_visualization and
# overwritten by update_parameters whenever a control changes.
current_input_name = ''   # '' matches no problem, so nothing is drawn initially
current_input_num = 1     # 1-based index into the problem's train_pairs
current_filt_size = 3
current_stride = 1
current_padding = 0
filter_widget = ClickableFilterWidget()  # reads current_filt_size, so must follow it


arc = load_all()
# Problem name -> first training input, over both the training and evaluation sets.
input_list = {
    problem.name: problem.train_pairs[0].input
    for problem in (arc.training + arc.evaluation)
}

def visualize_convolution_static(
        input_name: str,
        input_num: int,
        kernel: Array2D,
        stride: int,
        padding: int,
    ) -> None:
    """Plot one training input next to its entropy-convolution output.

    Draws a 1x2 figure and calls ``plt.show()``. The left panel is the
    selected input grid in the fixed 10-color palette. The right panel is
    ``convolve2d_entropy(input, kernel, stride, padding)`` drawn with viridis.

    Args:
        input_name: Problem name. Nothing is drawn unless it is a key of
            ``input_list``.
        input_num: 1-based index into the problem's ``train_pairs``. When it
            is out of range, a message is printed and nothing is drawn.
        kernel: Filter mask passed straight to ``convolve2d_entropy``.
        stride: Convolution stride.
        padding: Convolution padding.
    """
    # Unknown names (e.g. the initial empty selection) draw nothing.
    if input_name not in input_list:
        return

    train_pairs = arc.get_problem(input_name).train_pairs
    if not 1 <= input_num <= len(train_pairs):
        # The "Input #" slider range can lag behind a newly selected problem.
        print(f'{input_name} has {len(train_pairs)} training input(s); '
              f'input #{input_num} is out of range.')
        return
    input_arr = np.asarray(train_pairs[input_num - 1].input)

    output = convolve2d_entropy(input_arr, kernel, stride, padding)

    _, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(15, 4))

    # Input: fixed palette with vmin/vmax pinned so each value keeps its color.
    ax_in.imshow(input_arr, cmap=get_custom_colormap(), interpolation='nearest', vmin=0, vmax=9)
    ax_in.set_title(f'Input: {input_name}')
    ax_in.set_xticks(range(input_arr.shape[1]))
    ax_in.set_yticks(range(input_arr.shape[0]))

    # Output
    ax_out.set_title(f'Output (stride={stride}, pad={padding})')
    if output.size == 0:
        # Kernel larger than the padded input: there are no window positions to show.
        ax_out.text(0.5, 0.5, 'Kernel larger than padded input',
                    ha='center', va='center', transform=ax_out.transAxes)
        ax_out.set_axis_off()
    else:
        ax_out.imshow(output, cmap='viridis', interpolation='nearest')
        ax_out.set_xticks(range(output.shape[1]))
        ax_out.set_yticks(range(output.shape[0]))

    plt.tight_layout()
    plt.show()


def update_parameters(
        input_name: str,
        input_num: int,
        filt_size: int,
        stride: int,
        padding: int
) -> None:
    """Store the control values in module globals and redraw the plot.

    Intended as the ``interact`` callback, so it runs whenever any control
    changes. The filter grid is rebuilt only when its size changes, so cells
    the user toggled survive changes to the input, stride or padding.

    Args:
        input_name: Selected problem name.
        input_num: 1-based training-pair index within that problem.
        filt_size: Side length of the square filter grid.
        stride: Convolution stride.
        padding: Convolution padding.
    """
    global current_input_name, current_input_num, current_filt_size, current_stride, current_padding, filter_widget
    current_input_name = input_name
    current_input_num = input_num
    current_filt_size = filt_size
    current_stride = stride
    current_padding = padding

    if filter_widget.filter_values.shape != (filt_size, filt_size):
        # Size changed: build a fresh all-zero grid (it reads current_filt_size,
        # set above) and swap it into the filter display.
        filter_widget = ClickableFilterWidget()
        with filter_display:
            clear_output(wait=True)
            display(filter_widget.get_widget())
    filter_widget.set_output_widget(plot_output)

    filter_widget.update_visualization()

# Create output widgets: one holds the clickable filter grid, one the plot.
plot_output = Output()
filter_display = Output()
filter_widget.set_output_widget(plot_output)

# Get available inputs
available_inputs = list(input_list.keys())


def _train_pair_count(name: str) -> int:
    """Return the number of training pairs in problem ``name`` (at least 1 for the slider)."""
    return max(1, len(arc.get_problem(name).train_pairs))


# Create widgets manually for better control
input_dropdown = Dropdown(options=available_inputs, value=available_inputs[0], description='Input:')
input_num_slider = IntSlider(min=1, max=_train_pair_count(input_dropdown.value), value=1, description='Input #')
# NOTE(review): num_colors_slider is not passed to interact below; kept in case other cells use it.
num_colors_slider = IntSlider(min=2, max=5, value=2, description='Colors')
size_slider = IntSlider(min=2, max=10, value=3, description='Filter Size')
stride_slider = IntSlider(min=1, max=3, value=1, description='Stride')
padding_slider = IntSlider(min=0, max=2, value=1, description='Padding')


def _sync_input_num_range(change) -> None:
    """Resize the 'Input #' slider to the newly selected problem's pair count."""
    input_num_slider.max = _train_pair_count(change['new'])


# Registered before interact() so that, with traitlets notifying observers in
# registration order, the slider range is fixed before interact redraws.
input_dropdown.observe(_sync_input_num_range, names='value')

# Create parameter controls with the combined input selector
param_controls = interact(
    update_parameters,
    input_name=input_dropdown,
    input_num=input_num_slider,
    filt_size=size_slider,
    stride=stride_slider,
    padding=padding_slider
)

# Display the filter grid. Clear first so this is idempotent: update_parameters
# may already have displayed a grid here when interact ran it.
with filter_display:
    clear_output(wait=True)
    display(filter_widget.get_widget())

display(filter_display)
display(plot_output)

# Initial visualization
filter_widget.update_visualization()

interactive(children=(Dropdown(description='Input:', options=('d037b0a7', 'caa06a1f', 'a5f85a15', '22168020', …

Output()

Output()