In [5]:
import pickle
import ipywidgets as widgets
from IPython.display import display
import pandas as pd
from ipywidgets import HBox, VBox, widgets
from PIL import Image as PILImage
import io


# These are the names that got shuffled.
names = [
    "design_style_grid",
    "design_style_grid_anne",
    "design_style_grid_hybrid",
    "design_style_grid_hybrid2",
]

# read in randomized names
with open("randomized_names", "rb") as f:
    names = pickle.load(f)

datasets = {
    name: pd.read_csv(f"{name}.csv", index_col="style")
    for name in names
}

wgts = datasets[names[0]]["weight"].tolist()
styles = datasets[names[0]].index.tolist()
all_spectrums = datasets[names[0]].drop(columns="weight").columns.tolist()

base_width = 250
base_height = None
fixed_height = f"{base_width * 1.5:.0f}px"


def resize_image_aspect_ratio(image_path, base_width=None, base_height=None):
    with PILImage.open(image_path) as img:
        # If base_width is provided, calculate the height based on the aspect ratio
        if base_width is not None:
            w_percent = base_width / float(img.size[0])
            h_size = int((float(img.size[1]) * float(w_percent)))
            img = img.resize((base_width, h_size), PILImage.LANCZOS)

        # If base_height is provided, calculate the width based on the aspect ratio
        elif base_height is not None:
            h_percent = base_height / float(img.size[1])
            w_size = int((float(img.size[0]) * float(h_percent)))
            img = img.resize((w_size, base_height), PILImage.LANCZOS)

        # Convert to bytes
        img_byte_arr = io.BytesIO()
        img.save(img_byte_arr, format="PNG")
        img_byte_arr = img_byte_arr.getvalue()

    return img_byte_arr


# Helper function to get the image bytes
def load_image_bytes(file_path):
    with open(file_path, "rb") as f:
        return f.read()


def render_thumbnails():
    global thumbnail_widgets

    thumbnail_widgets = []
    for i, name in enumerate(names):
        curr_widgets = []
        for cs in chosen_styles[i]:
            widget = widgets.Image(
                value=resize_image_aspect_ratio(
                    f"thumbnails/{cs}.jpg", base_width=base_width, base_height=base_height
                ),
                format="jpg",
            )

            label = widgets.Label(value=cs)

            # Set the label layout to match the image width and center the text
            label.layout = widgets.Layout(width=f"{widget.width}px", text_align="center")

            # Use the style property to set a larger font size
            label.style = {"font_size": "20px"}  # , 'font_weight': 'bold'}

            curr_vbox = VBox(
                [label, widget],
                layout=widgets.Layout(align_items="center", height=fixed_height),
            )  # align_items centers the VBox content

            curr_widgets.append(curr_vbox)
        thumbnail_widgets.append(curr_widgets)


def update_thumbnails(vals, X):
    global thumbnail_widgets

    for i, name in enumerate(names):
        # Update chosen styles
        chosen_styles[i] = get_styles(X[i], vals, styles, wgts)

        for widget, cs in zip(thumbnail_widgets[i], chosen_styles[i]):
            widget.children[0].value = cs
            widget.children[1].value = resize_image_aspect_ratio(
                f"thumbnails/{cs}.jpg", base_width=base_width, base_height=base_height
            )


def get_styles(
    X: list[list[int]], usr_pnt: list[int], styles: list[str], weights: list[float]
):

    errors = []

    # Iterate through matrix, getting sum of absolute differences for each styl
    for i in range(len(X)):
        err = 0
        for j in range(len(all_spectrums)):
            # Difference between sliders and styles for style i and spectrum j
            curr_err = abs(X[i][j] - usr_pnt[j])

            # If error on traditional-modern spectrum > 25, remove from consideration
            if (all_spectrums[j] == "Traditional--Modern") & (curr_err > 25):
                curr_err += 1000

            err += curr_err
        err = err / weights[i]
        errors.append(err)

    # Find the lowest 3 error values
    curr_styles = []
    for _ in range(3):
        # Location of minimum value
        idx = errors.index(min(errors))

        # Add style
        curr_styles.append(styles[idx])

        # Replace chosen style with high error value
        errors[idx] += 1000

    return curr_styles


def initial_plot():
    global chosen_styles
    global disc1
    global disc2
    global vboxes
    global X

    # Start by creating sliders
    vboxes = []
    for spec in all_spectrums:
        # Parse spectrum name
        s1, s2 = spec.split("--")

        # Create slider
        slider = widgets.FloatSlider(
            value=50,  # Initial value
            min=0,  # Minimum value
            max=100,  # Maximum value
            step=1,
            description="",
            disabled=False,
            continuous_update=True,
            orientation="horizontal",
            readout=False,  # Hide the numeric value
            layout=widgets.Layout(width="600px"),  # Adjust the width as needed
        )

        # Add slider to visual
        spacer = widgets.Box(layout=widgets.Layout(width=f"{base_width / 2:.0f}px"))
        vboxes.append(
            HBox(
                [
                    spacer,
                    widgets.Label(
                        s1,
                        layout=widgets.Layout(
                            width="100px",
                            display="flex",
                            justify_content="flex-end",
                        ),
                        style={"font_size": "20px"},
                    ),
                    slider,
                    widgets.Label(s2, style={"font_size": "20px"}),
                ],
                layout={"display": ""},
            )
        )
        # sliders = [row[1] for row in vboxes]

    # Create disclaimers
    disc1 = widgets.Label("Note: This tool is for exploration only and is not exact.")
    disc2 = widgets.Label()

    # Find closest styles
    X = []
    chosen_styles = []
    for name in names:
        curr_x = datasets[name].reindex(columns=all_spectrums).to_numpy().tolist()
        X.append(curr_x)

        usr_pnt = [50, 50, 50, 50]

        chosen_styles.append(get_styles(curr_x, usr_pnt, styles, wgts))

    # Plot chosen styles
    render_thumbnails()


# Generate initial plot
initial_plot()


# Observe function to wrap update
def on_slider_change(change):
    update_thumbnails(
        [hb.children[2].value for hb in vboxes if hb.layout.display == ""], X
    )


# Link sliders to observe function
for hb in vboxes:
    hb.children[2].observe(on_slider_change, names="value")

# Combine everything into a final layout
spacer = widgets.HTML(value='&nbsp;', layout=widgets.Layout(width='50px'))
thumbnail_vert = [
    HBox(thumbnail_widgets[0] + [spacer] + thumbnail_widgets[1]),
    HBox(thumbnail_widgets[2] + [spacer] + thumbnail_widgets[3]),
]

final_layout = widgets.VBox(
    thumbnail_vert + vboxes
)  # , layout=widgets.Layout(align_items="center"))

# Display the final layout
display(final_layout)

VBox(children=(HBox(children=(VBox(children=(Label(value='Transitional', layout=Layout(width='px'), style=Labe…