In [1]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from ipywidgets import HBox, VBox, widgets


data = pd.read_csv("design_style_grid.csv")
styles = data["style"].tolist()
# Coerce to bool so ``imp`` always acts as a boolean mask (``X_trans[imp]``,
# ``~imp``). If the CSV stores the flag as 0/1 integers, an int array would be
# treated as positional indices and ``~`` would yield -1/-2 instead of a mask.
imp = data["important"].to_numpy().astype(bool)
spectrum_names = [
    "Modernity: Traditional to Transitional to Modern",
    "Aesthetic: Minimalist to Maximalist",
    "Atmosphere: Rustic to Industrial",
    "Atmosphere2: Informal to Formal",
]

# Select the spectrum columns explicitly. ``reindex`` would silently insert an
# all-NaN column for a missing/misspelled header (turning every projected value
# into NaN); plain ``[]`` selection raises a KeyError naming the column instead.
X_df = data[spectrum_names]
X = X_df.to_numpy()
features = X_df.columns

# Principal components plotted on the x and y axes, respectively.
pc1, pc2 = 0, 1

# NOTE(review): ``s_means`` is not used below; normalization centres on the
# slider midpoint (50) instead of the column means.
s_means = X.mean(axis=0)
s_std = X.std(axis=0)
# A constant column has zero spread; dividing by it would produce inf/NaN.
# Leave such columns unscaled instead.
s_std = np.where(s_std == 0, 1.0, s_std)

# Centre on 50 (the sliders' midpoint) rather than the column means, so that a
# user with every slider at 50 projects exactly to the origin, where the
# "Your Style" marker starts. Each column is then scaled to unit std.
X_norm = (X - 50) / s_std
# Second-moment matrix of the normalized scores (about 50, not the mean).
cov = X_norm.T @ X_norm
eigenvalues, eigenvectors = np.linalg.eigh(cov)
# ``eigh`` returns eigenvalues in ascending order; flip so that column 0 of
# ``eigenvectors`` belongs to the largest eigenvalue.
eigenvalues, eigenvectors = np.flip(eigenvalues), np.flip(eigenvectors, axis=1)

# Keep enough leading components to index both plotted components, whichever
# order pc1/pc2 are given in.
k = max(pc1, pc2) + 1
X_trans = X_norm @ eigenvectors[:, :k]

# Add a little jitter for plotting so styles with identical scores do not
# overlap exactly. A seeded generator keeps the layout stable across re-runs.
_jitter_rng = np.random.default_rng(0)
X_trans = X_trans + _jitter_rng.normal(scale=0.1, size=X_trans.shape)

# NOTE(review): these placeholder coordinates are not used by the plotting
# code in this cell; kept in case later cells reference them.
x = [1, 2, 3, 4, 5]
y = [1, 2, 3, 4, 5]
x_special = [2.5]
y_special = [2.5]

# Figure widget, so traces can be updated in place from the slider callbacks.
fig = go.FigureWidget()

# Trace 0: the user's own position, moved by ``update_point``. It starts at the
# origin, which is where all sliders at 50 project.
# NOTE: ``add_scatter`` returns the figure itself, so ``user_point`` refers to
# ``fig``; the new trace is ``fig.data[0]``.
user_point = fig.add_scatter(
    x=[0],
    y=[0],
    mode="markers",
    marker=dict(size=40, color="#008080"),
    text=["Your Style"],
    hoverinfo="text",
)

# Boolean mask of highlighted styles. Coerced so integer 0/1 flags are not
# mistaken for positional indices by NumPy fancy indexing.
imp_mask = np.asarray(imp, dtype=bool)
n_imp = int(imp_mask.sum())

# Cycle through a small pastel palette. The first ``n_imp`` colours go to the
# highlighted styles and the remainder to the other styles.
colors = ["#a2d2ff", "#ffb5a7", "#bde0fe", "#c3bef0"]
color_sq = [colors[i % len(colors)] for i in range(len(X_trans))]

# Trace 1: highlighted ("important") styles.
fig.add_scatter(
    x=X_trans[imp_mask, pc1],
    y=X_trans[imp_mask, pc2],
    mode="markers",
    marker=dict(size=20, color=color_sq[:n_imp]),
    text=[s for s, flag in zip(styles, imp_mask) if flag],
    hoverinfo="text",
)
# Trace 2: all remaining styles.
fig.add_scatter(
    x=X_trans[~imp_mask, pc1],
    y=X_trans[~imp_mask, pc2],
    mode="markers",
    marker=dict(size=20, color=color_sq[n_imp:]),
    text=[s for s, flag in zip(styles, imp_mask) if not flag],
    hoverinfo="text",
)

# Square, borderless canvas with hidden axes; only the points are shown.
fig.update_layout(
    autosize=False,
    width=800,
    height=800,
    showlegend=False,
    plot_bgcolor="#f5f5f5",  # Light-grey plot area
    paper_bgcolor="#f5f5f5",  # Same light grey around the plot area
    xaxis={"visible": False},  # Hide the x-axis
    yaxis={"visible": False},  # Hide the y-axis
    margin={"l": 0, "r": 0, "t": 0, "b": 0},  # Remove margins
)


# Function to update the plot
def update_point(modern, minimalist, industrial, formal):
    """Move the "Your Style" marker (trace 0) to the projection of the scores.

    The four scores are normalized the same way as the style data (centred
    on 50, divided by ``s_std``) and projected onto the same leading
    eigenvectors. This places the marker in the same coordinate system as the
    plotted styles.

    Args:
        modern, minimalist, industrial, formal: slider values, in the order
            of ``spectrum_names``.
    """
    usr_pnt = np.array([modern, minimalist, industrial, formal]).reshape([1, -1])
    usr_pnt = (usr_pnt - 50) / s_std
    usr_pnt = usr_pnt @ eigenvectors[:, :k]
    # Use the same component indices as the style traces (pc1/pc2) instead of
    # hard-coded 0/1. Otherwise changing pc1/pc2 would put the marker and the
    # style points on different axes.
    with fig.batch_update():
        fig.data[0].x = [usr_pnt[0, pc1]]
        fig.data[0].y = [usr_pnt[0, pc2]]


# Each entry is one UI row: [left endpoint label, slider, right endpoint label].
vboxes = []
desc_width = "150px"
for spec in spectrum_names:
    # "Category: A to ... to B" -> left endpoint A, right endpoint B.
    ends = spec.split(": ")[1].split(" to ")
    s1, s2 = ends[0], ends[-1]

    # 0-100 slider starting at the midpoint; the numeric readout is hidden
    # because the endpoint labels convey the meaning.
    slider = widgets.FloatSlider(
        value=50,
        min=0,
        max=100,
        step=1,
        description="",
        disabled=False,
        continuous_update=True,
        orientation="horizontal",
        readout=False,
        layout=widgets.Layout(width="300px"),
    )
    left_label = widgets.Label(
        s1,
        layout=widgets.Layout(
            width="100px", display="flex", justify_content="flex-end"
        ),
    )
    right_label = widgets.Label(s2)

    vboxes.append([left_label, slider, right_label])


# Observe function to wrap update
def on_slider_change(change):
    """Re-plot the user's marker from the current values of all sliders.

    ``change`` is the notification passed by ``observe``. It is ignored
    because every slider's value is re-read on each call.
    """
    current = (slider_row[1].value for slider_row in vboxes)
    update_point(*current)


# Link sliders to observe function
for row in vboxes:
    row[1].observe(on_slider_change, names="value")

# Initial call: sync the marker with the sliders' starting values, so the
# marker stays correct even if the slider defaults change.
on_slider_change(None)

ui = VBox([HBox(row) for row in vboxes])

# Display the interactive plot with sliders
display(fig, ui)

FigureWidget({
    'data': [{'hoverinfo': 'text',
              'marker': {'color': '#008080', 'size': 40},
              'mode': 'markers',
              'text': [Your Style],
              'type': 'scatter',
              'uid': '0d150a2a-00f5-4d06-98bc-633e56676288',
              'x': [0],
              'y': [0]},
             {'hoverinfo': 'text',
              'marker': {'color': ['#a2d2ff', '#ffb5a7', '#bde0fe'], 'size': 20},
              'mode': 'markers',
              'text': [Modern, Traditional, Transitional],
              'type': 'scatter',
              'uid': '945e0713-a7bd-4a29-87e1-caa254264ac9',
              'x': array([-1.33495687,  2.57915828, -0.23210983]),
              'y': array([0.86581112, 0.94024626, 0.84124545])},
             {'hoverinfo': 'text',
              'marker': {'color': [#c3bef0, #a2d2ff, #ffb5a7, #bde0fe, #c3bef0,
                                   #a2d2ff, #ffb5a7, #bde0fe, #c3bef0, #a2d2ff,
                                   #ffb5a7, #bde0f

VBox(children=(HBox(children=(Label(value='Traditional', layout=Layout(display='flex', justify_content='flex-e…