In [49]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from ipywidgets import HBox, VBox, widgets


data = pd.read_csv("design_style_grid.csv")
styles = data["style"].tolist()
# Boolean mask of the "important" styles. Cast explicitly: the mask is later
# negated with ``~imp`` and used to index rows, which only selects rows as
# intended for a bool array (on a 0/1 integer column ``~`` yields -1/-2 and
# the indexing silently becomes positional).
imp = data["important"].astype(bool).to_numpy()

# Every spectrum the user can toggle. Each entry is
# "<Category>: <low> to ... to <high>"; slider labels use the first and last
# endpoints, and the full string must match a column name in ``data``.
all_spectrums = [
    "Modernity: Traditional to Transitional to Modern",
    "Aesthetic: Minimalist to Maximalist",
    "Atmosphere: Rustic to Industrial",
    "Atmosphere2: Informal to Formal",
    "Culture and Location: Global to Local",
    "Materials: Nature-Inspired to High Tech",
]
# Spectrums that are checked when the tool first loads.
starting_spectrums = [
    "Modernity: Traditional to Transitional to Modern",
    "Aesthetic: Minimalist to Maximalist",
    "Atmosphere: Rustic to Industrial",
    "Atmosphere2: Informal to Formal",
]

# Create a Plotly figure (its traces are added by initial_plot()).
fig = go.FigureWidget()


# Function to update the plot
def update_point(vals):
    """Move the "Your Style" marker (trace 0) to the 2D projection of ``vals``.

    ``vals`` is expected to hold the visible slider values in the same order
    as the selected spectrum columns. They are normalised the same way as
    the style data (shifted by 50, divided by the global ``s_std``) and then
    projected onto the first two columns of the global ``eigenvectors``.

    The axis ranges are set to at least the data bounds (``max_val``) and are
    widened when needed so the marker stays inside the plot.
    """
    row = np.array(vals).reshape((1, -1))
    projected = ((row - 50) / s_std) @ eigenvectors[:, :2]
    px = projected[0, 0]
    py = projected[0, 1]

    with fig.batch_update():
        marker = fig.data[0]
        marker.x = [px]
        marker.y = [py]

    # Never shrink below the data bounds; grow to fit the user's point.
    half_width = max(max_val, np.abs(projected).max() + buf)
    fig.update_layout(
        xaxis=dict(range=[-half_width, half_width]),
        yaxis=dict(range=[-half_width, half_width + 2 * buf]),
    )


def initial_plot():
    """Build the slider rows and disclaimers, and draw the initial scatter plot.

    Creates one slider row per entry in ``all_spectrums``. Rows for spectrums
    whose checkbox is unchecked start hidden, so ``replot`` can toggle them
    later. The checked spectrum columns of ``data`` are then projected to 2D,
    and three traces are appended to ``fig``:

    * ``fig.data[0]``: the user's point, starting at the origin.
    * ``fig.data[1]``: styles whose ``imp`` flag is set.
    * ``fig.data[2]``: all other styles.

    ``update_point`` and ``replot`` address the traces by these indices. Each
    call appends three more traces, so this is intended to run once.

    Sets the module globals ``eigenvectors``, ``max_val``, ``buf``,
    ``s_std``, ``spectrum_names``, ``disc1``, ``disc2`` and ``vboxes``.

    Raises:
        ValueError: if fewer than 2 spectrums are checked.
    """
    global eigenvectors
    global max_val
    global buf
    global s_std
    global spectrum_names
    global disc1
    global disc2
    global vboxes

    spectrum_names = [cb.description for cb in cboxes if cb.value]
    # Fail before building any widgets or traces.
    if len(spectrum_names) < 2:
        raise ValueError("Must have at least 2 spectrum names")

    # One slider row per spectrum (checked or not); unchecked rows are hidden.
    vboxes = []
    for spec in all_spectrums:
        # "Name: Low to ... to High" -> endpoint labels "Low" and "High".
        endpoints = spec.split(": ")[1].split(" to ")
        s1, s2 = endpoints[0], endpoints[-1]

        slider = widgets.FloatSlider(
            value=50,  # Initial value (midpoint)
            min=0,
            max=100,
            step=1,
            description="",
            disabled=False,
            continuous_update=True,
            orientation="horizontal",
            readout=False,  # Hide the numeric value
            layout=widgets.Layout(width="300px"),
        )

        vboxes.append(
            HBox(
                [
                    widgets.Label(
                        s1,
                        layout=widgets.Layout(
                            width="100px",
                            display="flex",
                            justify_content="flex-end",
                        ),
                    ),
                    slider,
                    widgets.Label(s2),
                ],
                layout={"display": "" if spec in spectrum_names else "none"},
            )
        )

    # Disclaimers shown under the controls.
    disc1 = widgets.Label("Note: This tool is for exploration only and is not exact.")
    disc2 = widgets.Label()

    X = data.reindex(columns=spectrum_names).to_numpy()

    # Shift by the slider midpoint (50) rather than the column mean, so a user
    # point with every slider at 50 maps to the origin.
    s_std = X.std(axis=0)
    X_norm = (X - 50) / s_std
    if len(spectrum_names) > 2:
        cov = X_norm.T @ X_norm
        eigenvalues, eigenvectors = np.linalg.eigh(cov)
        # eigh returns eigenvalues in ascending order; flip so the leading
        # components come first.
        eigenvalues, eigenvectors = np.flip(eigenvalues), np.flip(eigenvectors, axis=1)
        disc2.value = (
            "Roughly"
            f" {eigenvalues[:2].sum() / eigenvalues.sum():.0%} of the information"
            f" from the {len(spectrum_names)} spectrums is preserved in this"
            " 2D visual."
        )
    else:
        # Exactly two spectrums: plot them directly as the x and y axes.
        eigenvectors = np.identity(2)
        disc2.value = ""

    X_trans = X_norm @ eigenvectors[:, :2]

    # Small uniform jitter, presumably so styles with identical scores do not
    # sit exactly on top of each other.
    jit = 0.05
    X_trans = X_trans + np.random.uniform(low=-jit, high=jit, size=X_trans.shape)

    # Trace 0: the user's point. All sliders start at 50, which maps to (0, 0).
    fig.add_scatter(
        x=[0],
        y=[0],
        mode="markers",
        marker=dict(size=40, color="#008080"),
        text=["Your Style"],
        hoverinfo="text",
    )

    # Traces 1 and 2: important styles, then the rest. Colours cycle over a
    # small palette, in row order across both traces.
    colors = ["#a2d2ff", "#ffb5a7", "#bde0fe", "#c3bef0"]
    color_sq = [colors[i % len(colors)] for i in range(len(X_trans))]
    n_imp = imp.sum()
    fig.add_scatter(
        x=X_trans[imp, 0],
        y=X_trans[imp, 1],
        mode="markers",
        marker=dict(size=20, color=color_sq[:n_imp]),
        text=[s for i, s in enumerate(styles) if imp[i]],
        hoverinfo="text",
    )
    fig.add_scatter(
        x=X_trans[~imp, 0],
        y=X_trans[~imp, 1],
        mode="markers",
        marker=dict(size=20, color=color_sq[n_imp:]),
        text=[s for i, s in enumerate(styles) if not imp[i]],
        hoverinfo="text",
    )

    # Symmetric plot bounds with a 5% buffer around the furthest point.
    max_val = np.abs(X_trans).max()
    buf = 0.05 * max_val
    max_val += buf

    fig.update_layout(
        autosize=True,
        showlegend=False,
        plot_bgcolor="#f5f5f5",  # light grey plot area
        paper_bgcolor="#f5f5f5",  # same grey around the plot
        xaxis={"visible": False, "range": [-max_val, max_val]},  # hidden x-axis
        yaxis={
            "visible": False,  # hidden y-axis
            # Extra 2*buf at the top, presumably headroom for the title.
            "range": [-max_val, max_val + 2 * buf],
        },
        margin={"l": 0, "r": 0, "t": 0, "b": 0},  # no margins
        title={
            "text": "Discover Your Style",
            "y": 0.95,
            "x": 0.5,
            "xanchor": "center",
            "yanchor": "top",
            "font": {
                "family": "Helvetica",
                "size": 35,
            },
        },
    )


def replot(b):
    """Redo the 2D projection for the checked spectrums and redraw the plot.

    Registered as the ``on_click`` handler of the "Confirm Selection" button,
    so ``b`` is the clicked button and is unused.

    Shows only the sliders for the checked spectrums and recomputes the PCA
    on those columns of ``data``. It then moves the style traces
    (``fig.data[1]``/``[2]``), recomputes the plot bounds, and finally
    re-places the user's point via ``update_point``. ``update_point`` also
    sets the axis ranges and widens them if the point falls outside the new
    bounds. The module globals ``eigenvectors``, ``max_val``, ``buf``,
    ``s_std`` and ``spectrum_names`` are updated, so later slider moves use
    the new projection.

    Raises:
        ValueError: if fewer than 2 spectrums are checked. The check runs
            before any widget or global is modified, so the existing plot
            stays consistent with the visible sliders.
    """
    global eigenvectors
    global max_val
    global buf
    global s_std
    global spectrum_names

    selected = [cb.description for cb in cboxes if cb.value]
    # Validate before mutating anything. Hiding sliders first would leave
    # fewer visible sliders than the old projection expects.
    if len(selected) < 2:
        raise ValueError("Must have at least 2 spectrum names")
    spectrum_names = selected

    # Update which sliders are visible (vboxes follows all_spectrums order).
    for spec, hb in zip(all_spectrums, vboxes):
        hb.layout.display = "" if spec in spectrum_names else "none"

    # Redo PCA on the selected columns.
    X = data.reindex(columns=spectrum_names).to_numpy()

    # Shift by the slider midpoint (50) rather than the column mean, so a user
    # point with every slider at 50 maps to the origin.
    s_std = X.std(axis=0)
    X_norm = (X - 50) / s_std
    if len(spectrum_names) > 2:
        cov = X_norm.T @ X_norm
        eigenvalues, eigenvectors = np.linalg.eigh(cov)
        # eigh returns eigenvalues in ascending order; flip so the leading
        # components come first.
        eigenvalues, eigenvectors = np.flip(eigenvalues), np.flip(eigenvectors, axis=1)
        disc2.value = (
            "Roughly"
            f" {eigenvalues[:2].sum() / eigenvalues.sum():.0%} of the information"
            f" from the {len(spectrum_names)} spectrums is preserved in this"
            " 2D visual."
        )
    else:
        # Exactly two spectrums: plot them directly as the x and y axes.
        eigenvectors = np.identity(2)
        disc2.value = ""

    X_trans = X_norm @ eigenvectors[:, :2]

    # Small uniform jitter, matching the initial plot.
    jit = 0.05
    X_trans = X_trans + np.random.uniform(low=-jit, high=jit, size=X_trans.shape)

    # Move the style points.
    with fig.batch_update():
        fig.data[1].x = X_trans[imp, 0]
        fig.data[1].y = X_trans[imp, 1]
        fig.data[2].x = X_trans[~imp, 0]
        fig.data[2].y = X_trans[~imp, 1]

    # Recompute bounds BEFORE re-placing the user point. update_point sets the
    # axis ranges to max(max_val, |point| + buf), so it must see the new
    # bounds. Previously a later update_layout overwrote that stretch and
    # could leave the user's point off-screen.
    max_val = np.abs(X_trans).max()
    buf = 0.05 * max_val
    max_val += buf

    # Re-place the user's point and set the axis ranges.
    update_point([hb.children[1].value for hb in vboxes if hb.layout.display == ""])


# One checkbox per spectrum; the starting subset is pre-ticked.
cboxes = [
    widgets.Checkbox(
        description=name,
        value=(name in starting_spectrums),
        style=dict(description_width="initial"),
    )
    for name in all_spectrums
]

# Clicking the button recomputes the projection for the ticked spectrums.
confirm = widgets.Button(description="Confirm Selection")
confirm.on_click(replot)

# Draw the first version of the plot (this also builds the slider rows).
initial_plot()


# Observe function to wrap update
def on_slider_change(change):
    """Slider ``observe`` callback: re-place the user's point.

    ``change`` is ignored. The value of every currently visible slider is
    re-read, in row order, and passed to ``update_point``.
    """
    visible_values = []
    for row in vboxes:
        if row.layout.display == "":
            visible_values.append(row.children[1].value)
    update_point(visible_values)


# Subscribe every slider (the middle child of each row) to value changes.
for row in vboxes:
    row.children[1].observe(on_slider_change, "value")

# Layout: slider rows, then the spectrum checkboxes with the confirm button,
# then the disclaimers below a blank spacer label.
ui = VBox(
    [
        *vboxes,
        VBox([*cboxes, confirm]),
        HBox([VBox([widgets.Label(), disc1, disc2])]),
    ]
)

# Show the figure with its controls underneath.
display(fig, ui)

FigureWidget({
    'data': [{'hoverinfo': 'text',
              'marker': {'color': '#008080', 'size': 40},
              'mode': 'markers',
              'text': [Your Style],
              'type': 'scatter',
              'uid': '6d6fa837-5417-4586-879a-5b1dd25f1310',
              'x': [0],
              'y': [0]},
             {'hoverinfo': 'text',
              'marker': {'color': ['#a2d2ff', '#ffb5a7', '#bde0fe'], 'size': 20},
              'mode': 'markers',
              'text': [Modern, Traditional, Transitional],
              'type': 'scatter',
              'uid': '17ea14c5-fbbe-44bc-9cdc-b23dad8d4872',
              'x': array([-1.36920411,  2.51914291, -0.00945885]),
              'y': array([0.87401409, 1.05856061, 0.83084499])},
             {'hoverinfo': 'text',
              'marker': {'color': [#c3bef0, #a2d2ff, #ffb5a7, #bde0fe, #c3bef0,
                                   #a2d2ff, #ffb5a7, #bde0fe, #c3bef0, #a2d2ff,
                                   #ffb5a7, #bde0f

VBox(children=(HBox(children=(Label(value='Traditional', layout=Layout(display='flex', justify_content='flex-e…