In [None]:
import ast
from collections import Counter
from pathlib import Path

import numpy as np
import pandas as pd
from datasets import load_dataset

# Load the FinePersonas clustering subset with `datasets.load_dataset` and
# convert its "train" split into a pandas DataFrame.
ds = load_dataset(path="argilla/FinePersonas-v0.1-clustering-100k")
df = ds["train"].to_pandas()

In [None]:
# Preview 5 random rows; a fixed random_state keeps the preview identical
# across Restart & Run All.
df.sample(5, random_state=42)

In [None]:
# `summary_label` is presumably stored as the string form of a Python list of
# labels; parse it into a real list. ast.literal_eval only accepts Python
# literals, unlike eval, which would execute arbitrary code embedded in the
# downloaded data. The isinstance guard keeps the cell idempotent: re-running
# it leaves already-parsed values untouched instead of raising TypeError.
df["summary_label"] = df["summary_label"].apply(
    lambda value: ast.literal_eval(value) if isinstance(value, str) else value
)

In [None]:
# Count occurrences of each label across all rows of summary_label, then order
# the (label, count) pairs by count, descending. Counter.most_common orders
# equal counts by first occurrence, matching a stable sort of a plain dict's
# items, and its generator variables do not leak into the notebook namespace.
summary_label_counts = Counter(
    label for labels in df['summary_label'] for label in labels
).most_common()

# Print the sorted (label, count) pairs
print(summary_label_counts)


In [None]:
# Compute the average projected position of the N most frequent labels.
# summary_label_counts is a list of (label, count) pairs sorted by count,
# descending, so its head holds the most frequent labels.
N_FREQUENT_LABELS = 10
frequent_labels = [label for label, _count in summary_label_counts[:N_FREQUENT_LABELS]]

def compute_average_position(df, label):
    """Mean ``projection`` of every row whose ``summary_label`` contains ``label``.

    Args:
        df: Frame with a ``summary_label`` column of label collections and a
            ``projection`` column of coordinate sequences (presumably all of
            equal length; NumPy cannot average ragged rows).
        label: The summary label used to select rows.

    Returns:
        The element-wise (axis 0) mean of the matching ``projection`` values.
        If no row matches, NumPy averages an empty list and yields ``nan``
        together with a RuntimeWarning.
    """
    def _is_tagged(row_labels):
        return label in row_labels

    tagged = df[df['summary_label'].apply(_is_tagged)]
    coordinates = tagged['projection'].tolist()
    return np.mean(coordinates, axis=0)

# Map each frequent label to the mean projected position of the rows that carry it.
average_positions = dict(
    (frequent_label, compute_average_position(df, frequent_label))
    for frequent_label in frequent_labels
)

print(average_positions)

In [None]:
# Visualize the 2d embeddings using Plotly
import plotly.express as px
import plotly.graph_objects as go

# Set to True to annotate each frequent label at its average projected
# position (requires `average_positions` from the earlier cell).
SHOW_LABEL_ANNOTATIONS = False

# Extract the coordinates from the 'projection' column
df["x"] = df["projection"].apply(lambda x: x[0])
df["y"] = df["projection"].apply(lambda x: x[1])

# Interactive scatter plot; hovering a point shows its summary labels.
fig = px.scatter(
    df,
    x="x",
    y="y",
    hover_data=["summary_label"],
    opacity=0.5,
)

# Tiny markers with a thin outline for a dense scatter.
fig.update_traces(marker=dict(size=1, line=dict(width=0.5, color="DarkSlateGrey")))

# Square, white canvas with no axis titles.
fig.update_layout(
    title_font_size=24,
    xaxis_title="",
    yaxis_title="",
    width=600,
    height=600,
    plot_bgcolor="white",
    paper_bgcolor="white",
    font=dict(family="Arial, sans-serif", size=14, color="black"),
)

# Clean axes: no grid lines, no zero line, no tick labels.
fig.update_xaxes(showgrid=False, zeroline=False, showticklabels=False)
fig.update_yaxes(showgrid=False, zeroline=False, showticklabels=False)

# Optionally label the most frequent summary labels at their mean positions.
if SHOW_LABEL_ANNOTATIONS:
    for label, position in average_positions.items():
        fig.add_annotation(
            x=position[0],
            y=position[1],
            text=label,
            showarrow=True,
            arrowhead=2,
            font=dict(color="red"),
        )

fig.show()

In [None]:
# Create a dataframe with the following columns:
# id persona split_id

# Randomly (but reproducibly) tag SPLIT_FRAC of the rows as split_id=1;
# all remaining rows get split_id=0.
SPLIT_FRAC = 0.01
SPLIT_SEED = 42
split_ids = df.sample(frac=SPLIT_FRAC, random_state=SPLIT_SEED).index
df['split_id'] = df.index.isin(split_ids).astype(int)
df_split = df[['id', 'persona', 'split_id']]

# Save the dataframe to a tsv file, creating the output directory on a fresh
# checkout so to_csv does not fail with a missing-directory error.
data_path = "../data/persona.tsv"
Path(data_path).parent.mkdir(parents=True, exist_ok=True)
df_split.to_csv(data_path, sep='\t', index=False)


In [None]:
# Sanity-check the split sizes: rows per split_id value, most common first.
df_split.split_id.value_counts()

In [None]:
# Get the high-dim and 2d embeddings as matrices (one row per dataframe row).
high_dim_embeddings = np.array(df['embedding'].to_list())
low_dim_embeddings = np.array(df['projection'].to_list())

# Verify the shapes: print them, and fail fast if either matrix is ragged
# (not 2-D) or if the two matrices are not row-aligned with the dataframe.
print(high_dim_embeddings.shape)
print(low_dim_embeddings.shape)
assert high_dim_embeddings.ndim == 2, "embedding rows have inconsistent lengths"
assert low_dim_embeddings.ndim == 2, "projection rows have inconsistent lengths"
assert high_dim_embeddings.shape[0] == low_dim_embeddings.shape[0] == len(df)


In [None]:
# Bundle both embedding matrices into one .npz archive, stored under the keys
# "high_dim_embeddings" and "low_dim_embeddings".
np.savez(
    "../data/persona.npz",
    high_dim_embeddings=high_dim_embeddings,
    low_dim_embeddings=low_dim_embeddings,
)