In [4]:
import ipywidgets as widgets
from IPython.display import display
import numpy as np
import pandas as pd
from ipywidgets import HBox, VBox, widgets
from PIL import Image as PILImage
import io


data = pd.read_csv("design_style_grid2.csv", index_col="style")
styles = data.index.tolist()
all_spectrums = [
    "Industrial--Rustic",
    "Formal--Informal",
    "Traditional--Modern",
    "Maximalist--Minimalist",
]

base_width = 500
base_height = None
fixed_height = f'{base_width * 1.5:.0f}px'


def resize_image_aspect_ratio(image_path, base_width=None, base_height=None):
    with PILImage.open(image_path) as img:
        # If base_width is provided, calculate the height based on the aspect ratio
        if base_width is not None:
            w_percent = base_width / float(img.size[0])
            h_size = int((float(img.size[1]) * float(w_percent)))
            img = img.resize((base_width, h_size), PILImage.LANCZOS)

        # If base_height is provided, calculate the width based on the aspect ratio
        elif base_height is not None:
            h_percent = base_height / float(img.size[1])
            w_size = int((float(img.size[0]) * float(h_percent)))
            img = img.resize((w_size, base_height), PILImage.LANCZOS)

        # Convert to bytes
        img_byte_arr = io.BytesIO()
        img.save(img_byte_arr, format="PNG")
        img_byte_arr = img_byte_arr.getvalue()

    return img_byte_arr


# Helper function to get the image bytes
def load_image_bytes(file_path):
    with open(file_path, "rb") as f:
        return f.read()


def render_thumbnails():
    global thumbnail_widgets

    thumbnail_widgets = []
    for cs in chosen_styles:
        widget = widgets.Image(
            value=resize_image_aspect_ratio(
                f"thumbnails/{cs}.jpg", base_width=base_width, base_height=base_height
            ),
            format="jpg",
        )

        label = widgets.Label(value=cs)

        # Set the label layout to match the image width and center the text
        label.layout = widgets.Layout(width=f"{widget.width}px", text_align="center")

        # Use the style property to set a larger font size
        label.style = {"font_size": "20px"}  # , 'font_weight': 'bold'}

        curr_vbox = VBox(
            [label, widget], layout=widgets.Layout(align_items="center", height=fixed_height)
        )  # align_items centers the VBox content

        thumbnail_widgets.append(curr_vbox)


def update_thumbnails(vals, X):
    global thumbnail_widgets

    # Update chosen styles
    chosen_styles = get_styles(X, vals, styles)

    for widget, cs in zip(thumbnail_widgets, chosen_styles):
        widget.children[0].value = cs
        widget.children[1].value = resize_image_aspect_ratio(
            f"thumbnails/{cs}.jpg", base_width=base_width, base_height=base_height
        )


def get_styles(X, vals, styles):
    errors = np.abs(X - vals)
    err = errors.sum(axis=1)
    err[errors[:, 2] > 25] = 10000

    idx = err.argsort()[:3]
    return [styles[i] for i in idx]


def initial_plot():
    global chosen_styles
    global disc1
    global disc2
    global vboxes
    global X

    # Start by creating sliders
    vboxes = []
    for spec in all_spectrums:
        # Parse spectrum name
        s1, s2 = spec.split("--")

        # Create slider
        slider = widgets.FloatSlider(
            value=50,  # Initial value
            min=0,  # Minimum value
            max=100,  # Maximum value
            step=1,
            description="",
            disabled=False,
            continuous_update=True,
            orientation="horizontal",
            readout=False,  # Hide the numeric value
            layout=widgets.Layout(width="600px"),  # Adjust the width as needed
        )

        # Add slider to visual
        vboxes.append(
            HBox(
                [
                    widgets.Label(
                        s1,
                        layout=widgets.Layout(
                            width="100px",
                            display="flex",
                            justify_content="flex-end",
                        ),
                    ),
                    slider,
                    widgets.Label(s2),
                ],
                layout={"display": ""},
            )
        )
        # sliders = [row[1] for row in vboxes]

    # Create disclaimers
    disc1 = widgets.Label("Note: This tool is for exploration only and is not exact.")
    disc2 = widgets.Label()

    # Find closest styles
    X_df = data.reindex(columns=all_spectrums)
    X = X_df.to_numpy()

    usr_pnt = np.array([50, 50, 50, 50])

    chosen_styles = get_styles(X, usr_pnt, styles)

    # Plot chosen styles
    render_thumbnails()


# Generate initial plot
initial_plot()


# Observe function to wrap update
def on_slider_change(change):
    update_thumbnails(
        [hb.children[1].value for hb in vboxes if hb.layout.display == ""], X
    )


# Link sliders to observe function
for hb in vboxes:
    hb.children[1].observe(on_slider_change, names="value")

# Combine everything into a final layout
thumbnails_hbox = HBox(thumbnail_widgets)
final_layout = widgets.VBox([thumbnails_hbox] + vboxes, layout=widgets.Layout(align_items="center"))

# Display the final layout
display(final_layout)

VBox(children=(HBox(children=(VBox(children=(Label(value='Modern Traditional', layout=Layout(width='px'), styl…