In [1]:
import os
import warnings

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

# NOTE(review): this silences *every* warning for the rest of the session (pandas,
# numpy, plotly deprecations, ...). Consider narrowing it to a specific category.
warnings.filterwarnings("ignore")

data_path = "data/persona.tsv"
embedding_path = "data/persona.npz"
N_TEST_POINTS = 200  # at most this many split_id == 1 rows are used as query points

df = pd.read_csv(data_path, sep="\t")
# Use the NpzFile as a context manager so its underlying file handle is closed.
with np.load(embedding_path) as npz:
    array = npz["low_dim_embeddings"]

# The embedding rows are matched to the TSV rows purely by position, so fail loudly
# (with a readable message) if they cannot line up.
if array.ndim != 2 or array.shape[0] != len(df) or array.shape[1] < 2:
    raise ValueError(
        f"Expected embeddings of shape ({len(df)}, >=2) to match {data_path}, "
        f"got {array.shape} from {embedding_path}"
    )

df["x"] = array[:, 0]
df["y"] = array[:, 1]
# read_csv gives a default RangeIndex, so "id" equals the row label used by df.loc.
df["id"] = np.arange(len(df))

test_indices_in_df = df.loc[df["split_id"] == 1, "id"].to_numpy()[:N_TEST_POINTS]
# Plain Python ints: position k in this list is "Test Point #k".
test_indices_map = test_indices_in_df.tolist()
print("Number of test points: ", len(test_indices_in_df))

Number of test points:  200


In [2]:
# Test-point number -> dataframe row id (entry k is the row shown as "Test Point #k").
print(f"{test_indices_map}")

[87, 114, 149, 226, 270, 278, 573, 575, 578, 585, 655, 690, 821, 833, 856, 866, 883, 928, 1045, 1057, 1221, 1230, 1318, 1322, 1332, 1349, 1402, 1414, 1430, 1550, 1572, 1728, 1761, 1966, 2029, 2104, 2129, 2174, 2180, 2280, 2360, 2495, 2514, 2742, 2937, 3231, 3253, 3393, 3789, 4050, 4151, 4162, 4221, 4232, 4293, 4337, 4367, 4510, 4624, 4645, 4688, 4727, 4790, 5019, 5073, 5169, 5215, 5223, 5324, 5438, 5464, 5904, 5937, 6148, 6175, 6234, 6267, 6729, 6748, 7129, 7144, 7315, 7335, 7677, 7748, 7841, 7858, 7895, 8018, 8028, 8093, 8257, 8285, 8311, 8452, 8525, 8569, 8786, 8866, 8950, 9093, 9161, 9162, 9423, 9494, 9600, 9660, 9887, 9937, 10049, 10501, 10550, 10571, 10618, 10817, 10899, 10980, 11014, 11040, 11060, 11141, 11147, 11317, 11372, 11405, 11446, 11468, 11650, 11712, 11855, 11863, 11895, 11930, 11968, 12042, 12085, 12250, 12256, 12387, 12396, 12469, 12533, 12560, 12750, 12896, 13121, 13343, 13413, 13551, 13579, 13705, 13735, 13797, 13799, 13914, 13924, 13928, 13935, 14038, 14106, 14149, 

In [5]:
def _wrap_persona(persona, words_per_line=10):
    """Return ``persona`` with ``<br>`` inserted after every ``words_per_line`` words.

    Used so long persona strings are split across lines in hover text and legend
    entries. Non-string values (e.g. a missing persona) are converted with ``str``.
    """
    words = str(persona).split(" ")
    lines = [
        " ".join(words[start:start + words_per_line])
        for start in range(0, len(words), words_per_line)
    ]
    return "<br>".join(lines)


def _local_neighborhood(frame, row_id, radius):
    """Return a new frame of the rows within ``radius`` of row ``row_id``, excluding that row.

    ``frame`` must have ``x`` and ``y`` columns and is not modified. The result adds
    ``x_relative``/``y_relative`` (offset from the query point) and ``distance``
    (Euclidean distance divided by ``radius``, hence in [0, 1]).
    """
    query_xy = frame.loc[row_id, ["x", "y"]].to_numpy(dtype=float)
    distances = np.linalg.norm(frame[["x", "y"]].to_numpy() - query_xy, axis=1)
    # Exclude the query by label, not by "distance > 0": a precision change can make
    # its self-distance non-zero, and other points at identical coordinates are
    # genuine neighbours that must still be shown.
    mask = (distances <= radius) & (frame.index != row_id)
    neighbors = frame.loc[mask].copy()
    neighbors["x_relative"] = neighbors["x"] - query_xy[0]
    neighbors["y_relative"] = neighbors["y"] - query_xy[1]
    neighbors["distance"] = distances[mask] / radius
    return neighbors


# Generate a scatter plot of the embedding around some query point given id
def show(query_id, radius=0.1, n_closest=5, output_dir="visualizations"):
    """Plot the embedding neighbourhood of one test point and save it as HTML.

    Args:
        query_id: Position in ``test_indices_map`` ("Test Point #query_id"), not a
            dataframe row id.
        radius: Neighbourhood radius in embedding units; also the half-width of both axes.
        n_closest: How many nearest neighbours get their own labelled legend trace.
        output_dir: Directory for ``vis_<query_id>.html``; created if it does not exist.

    Returns:
        The plotly figure (call ``.show()`` on it to display it inline).

    Raises:
        IndexError: if ``query_id`` is not a valid test-point position.
        ValueError: if ``radius`` is not positive.

    Reads the module-level ``df`` and ``test_indices_map``; ``df`` is not modified.
    """
    # Reject negatives explicitly: list indexing would silently wrap around.
    if not 0 <= query_id < len(test_indices_map):
        raise IndexError(f"query_id must be in [0, {len(test_indices_map)}), got {query_id}")
    if radius <= 0:
        raise ValueError(f"radius must be positive, got {radius}")

    row_id = test_indices_map[query_id]
    neighbors = _local_neighborhood(df, row_id, radius)
    neighbors["persona"] = neighbors["persona"].apply(_wrap_persona)

    fig = px.scatter(
        neighbors,
        x="x_relative",
        y="y_relative",
        title=f"Local neighborhood around Query Point {row_id} (Test Point #{query_id})",
        width=1000,
        height=500,
        hover_data={"persona": True, "x_relative": False, "y_relative": False, "distance": False},
    )
    fig.update_xaxes(title="x", range=[-radius, radius])
    fig.update_yaxes(title="y", range=[-radius, radius])
    fig.update_traces(marker=dict(size=10, color="gray"))

    # Overlay the closest points as separate traces so their personas appear in the legend.
    for _, row in neighbors.nsmallest(n_closest, "distance").iterrows():
        fig.add_trace(
            go.Scatter(
                x=[row["x_relative"]],
                y=[row["y_relative"]],
                mode="markers",
                marker=dict(size=10),
                name=row["persona"],
                hoverinfo="text",
                hovertext=row["persona"],
                showlegend=True,
            )
        )

    # Mark the query point itself (the origin of the relative coordinates).
    fig.add_trace(
        go.Scatter(
            x=[0],
            y=[0],
            mode="markers",
            marker=dict(color="black", size=10, symbol="x"),
            name="What could be a suitable persona at this point?",
            hoverinfo="text",
            hovertext="What could be a suitable persona at this point?",
            showlegend=True,
        )
    )

    os.makedirs(output_dir, exist_ok=True)
    fig.write_html(os.path.join(output_dir, f"vis_{query_id}.html"))
    return fig

def annotate(query_id: int, persona: str, output_dir: str = "data/persona_annotations") -> None:
    """Save a hand-written persona annotation for one query point.

    Writes ``persona`` verbatim as UTF-8 to ``<output_dir>/<query_id>.txt``.
    Creates the directory if needed and overwrites any earlier annotation for that id.

    NOTE(review): ``query_id`` is presumably the test-point number used for the
    ``vis_<query_id>.html`` plots, not the dataframe row id — confirm.
    """
    os.makedirs(output_dir, exist_ok=True)
    # Explicit encoding: the platform default may not be able to encode non-ASCII personas.
    with open(os.path.join(output_dir, f"{query_id}.txt"), "w", encoding="utf-8") as f:
        f.write(persona)

In [9]:
# Render one neighbourhood plot per available test point; the split may contain
# fewer points than the cap used when building test_indices_map.
for test_point in range(len(test_indices_map)):
    show(test_point)
