In [None]:
from datasets import load_dataset
import pandas as pd
import numpy as np

# Fetch the FinePersonas clustering dataset from the Hugging Face Hub and
# convert its "train" split into a pandas DataFrame for the cells below.
DATASET_ID = "argilla/FinePersonas-v0.1-clustering-100k"
ds = load_dataset(DATASET_ID)
df = ds["train"].to_pandas()

In [None]:
# Peek at 5 random rows. A fixed random_state keeps the displayed sample
# (and therefore the saved notebook output) stable across Restart & Run All.
df.sample(5, random_state=42)

In [None]:
# Visualize the 2D projections of the embeddings with an interactive Plotly scatter.
import plotly.express as px
import plotly.graph_objects as go

# Unpack the 'projection' column into x/y coordinate columns in one vectorised
# step instead of two row-by-row .apply() passes. Assumes every projection has
# the same length with at least 2 components — TODO confirm for new dataset versions.
projection_coords = np.array(df['projection'].to_list())
assert projection_coords.ndim == 2 and projection_coords.shape[1] >= 2, (
    f"expected projections of shape (n_rows, >=2), got {projection_coords.shape}"
)
df['x'] = projection_coords[:, 0]
df['y'] = projection_coords[:, 1]

# Scatter of the 2D coordinates; hovering a point shows its persona text and label.
fig = px.scatter(
    df, x='x', y='y',
    hover_data=['persona', 'summary_label'],
    opacity=0.7,
    title='2D Embeddings Visualization',
)

# Very small markers with a thin outline.
fig.update_traces(marker=dict(size=1, line=dict(width=0.5, color='DarkSlateGrey')))

# Overall figure styling: square 1200x1200 canvas on a white background.
fig.update_layout(
    title_font_size=24,
    xaxis_title='X Coordinate',
    yaxis_title='Y Coordinate',
    width=1200,
    height=1200,
    plot_bgcolor='white',
    paper_bgcolor='white',
    hovermode='closest',
    font=dict(family="Arial, sans-serif", size=14, color="black"),
)

# Light grey grid lines, no emphasised zero lines.
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGrey', zeroline=False)
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGrey', zeroline=False)

# Thin black frame around the plotting area ("paper" coordinates span 0..1).
fig.add_shape(
    type="rect",
    xref="paper", yref="paper",
    x0=0, y0=0, x1=1, y1=1,
    line=dict(color="Black", width=1),
)

# Render the figure.
fig.show()

In [None]:
# Build the persona export with the columns: id, persona, split_id.
# split_id == 1 marks a reproducible 1% random sample (random_state=42);
# every other row gets split_id == 0.
import os

# Membership test via the index; assumes df's index is unique — TODO confirm
# if df is ever re-indexed or concatenated before this cell.
split_ids = df.sample(frac=0.01, random_state=42).index
df['split_id'] = df.index.isin(split_ids).astype(int)
df_split = df[['id', 'persona', 'split_id']]

# Write as tab-separated values. pandas refuses to write into a missing
# directory (OSError), so create ../data first on a fresh checkout.
data_path = "../data/persona.tsv"
os.makedirs(os.path.dirname(data_path), exist_ok=True)
df_split.to_csv(data_path, sep='\t', index=False)


In [None]:
# Sanity check: count how many rows landed in each split_id bucket
# (displayed as the cell's output).
split_counts = df_split["split_id"].value_counts()
split_counts

In [None]:
# Stack the per-row high-dimensional embeddings and their 2D projections into
# dense arrays; row i of each array corresponds to row i of df.
high_dim_embeddings = np.array(df['embedding'].to_list())
low_dim_embeddings = np.array(df['projection'].to_list())

# Fail loudly instead of only printing: a ragged column would not produce a
# 2D array, and both arrays must have exactly one row per DataFrame row.
assert high_dim_embeddings.ndim == 2, f"high-dim embeddings are not 2D: {high_dim_embeddings.shape}"
assert low_dim_embeddings.ndim == 2, f"low-dim embeddings are not 2D: {low_dim_embeddings.shape}"
assert len(high_dim_embeddings) == len(low_dim_embeddings) == len(df), "embedding rows out of sync with df"

# Show the shapes for the reader.
print(high_dim_embeddings.shape)
print(low_dim_embeddings.shape)


In [None]:
# Save the high-dim and 2D embeddings together in one uncompressed .npz archive
# under the keys 'high_dim_embeddings' and 'low_dim_embeddings'.
import os

embeddings_path = "../data/persona.npz"
# np.savez does not create missing directories (FileNotFoundError), so make sure
# the target directory exists first.
os.makedirs(os.path.dirname(embeddings_path), exist_ok=True)
np.savez(
    embeddings_path,
    high_dim_embeddings=high_dim_embeddings,
    low_dim_embeddings=low_dim_embeddings,
)