In [2]:
import polars as pl

# Raw input: messenger speeches (grouped by play and speaker, per the filename).
parquet_path = "./messengers_by_play_and_speaker_grouped_text.parquet"
df = pl.read_parquet(parquet_path)

In [3]:
# Join each speech's text fragments and split it into sentences.
# The Greek ano teleia (U+0387, canonically equivalent to U+00B7 MIDDLE DOT)
# is treated as a sentence break like ".". `replace_all` is required: plain
# `str.replace` only rewrites the first match in each string.
# Pieces are stripped and empty ones dropped, so a trailing "." or "..." does
# not produce blank "sentences" that would later be embedded.
sentences_df = df.with_columns(joined_text=pl.col("text").list.join(" ")).with_columns(
    sentences=pl.col("joined_text")
    .str.replace_all("[\u00b7\u0387]", ".")
    .str.split(".")
    .list.eval(pl.element().str.strip_chars())
    .list.eval(pl.element().filter(pl.element() != ""))
)

In [4]:
from sentence_transformers import SentenceTransformer

# Sentence-embedding model, loaded by its identifier.
model_name = "bowphs/SPhilBerta"
model = SentenceTransformer(model_name_or_path=model_name)

  from .autonotebook import tqdm as notebook_tqdm


In [13]:
def embed(sentences):
    """Embed each speech as the mean of its sentence embeddings.

    Used with ``Expr.map_batches``, so ``sentences`` is a whole Series whose
    items are per-speech ``list[str]``. ``model.encode`` expects plain strings;
    given list-valued items it interprets each one as a (text_a, text_b) pair
    instead of independent sentences. So every non-blank sentence is encoded
    individually in a single batch and the vectors are averaged per speech.

    Args:
        sentences: polars Series of ``list[str]`` (one list per speech).

    Returns:
        polars Series of ``Array(Float32, dim)`` with one vector per input row,
        in input order. A speech with no non-blank sentence (or a null list)
        gets an all-zero vector.
    """
    import numpy as np

    per_speech = [
        [s.strip() for s in (row or []) if s and s.strip()]
        for row in sentences.to_list()
    ]
    flat = [s for row in per_speech for s in row]
    dim = model.get_sentence_embedding_dimension()

    # One encode call for all sentences lets the model batch efficiently.
    flat_vectors = (
        model.encode(flat, convert_to_numpy=True)
        if flat
        else np.empty((0, dim), dtype=np.float32)
    )

    pooled = np.zeros((len(per_speech), dim), dtype=np.float32)
    offset = 0
    for i, row in enumerate(per_speech):
        if row:
            pooled[i] = flat_vectors[offset : offset + len(row)].mean(axis=0)
        offset += len(row)

    return pl.Series(sentences.name, pooled.tolist(), dtype=pl.Array(pl.Float32, dim))


# One row per speech, with:
#   - embeddings: the vector produced by `embed` from the speech's sentences;
#   - row_id: "dramatist title speaker <first n>" label, used as plot axis label
#     (first entry of `n` is presumably the opening line number — confirm).
# Rows are ordered by year, then by that first `n` value.
# The Array width comes from the model rather than a hard-coded 768, so changing
# `model_name` cannot silently mismatch the declared dtype.
embedding_dim = model.get_sentence_embedding_dimension()

embeddings_df = sentences_df.with_columns(
    embeddings=pl.col("sentences").map_batches(
        embed, return_dtype=pl.Array(pl.Float32, (embedding_dim,))
    ),
    row_id=pl.concat_str(
        pl.col("dramatist"),
        pl.col("title"),
        pl.col("speaker"),
        pl.col("n").list.first(),
        separator=" ",
    ),
).sort(pl.col("year"), pl.col("n").list.first())

In [41]:
# Pairwise similarity of every speech embedding against every other one,
# rows and columns both in embeddings_df order.
embedding_matrix = embeddings_df["embeddings"].to_numpy()
similarities = model.similarity(embedding_matrix, embedding_matrix)


def to_comparison_dataframe(frame=None, scores=None):
    """Return the similarity matrix as a wide table, one column per speech.

    The first column, ``"title"``, lists every speech's ``row_id``. Each further
    column is named after one speech's ``row_id`` and holds that speech's row of
    similarity scores, in the same order as ``"title"``.

    Args:
        frame: DataFrame with a ``row_id`` column whose order matches the rows
            of ``scores``. Defaults to the global ``embeddings_df``.
        scores: square similarity matrix (tensor or array rows supporting
            ``.tolist()``). Defaults to the global ``similarities``.

    Raises:
        ValueError: if ``row_id`` values repeat (as dict keys they would silently
            overwrite each other and drop columns) or if the number of score
            rows differs from the number of speeches.
    """
    frame = embeddings_df if frame is None else frame
    scores = similarities if scores is None else scores

    row_ids = frame["row_id"].to_list()
    if len(set(row_ids)) != len(row_ids):
        raise ValueError("row_id values must be unique to be used as column names")
    if len(scores) != len(row_ids):
        raise ValueError(
            f"expected {len(row_ids)} rows of similarity scores, got {len(scores)}"
        )

    columns = {"title": row_ids}
    for row_id, score_row in zip(row_ids, scores):
        # .tolist() turns a tensor/array row into plain Python floats.
        columns[row_id] = score_row.tolist()

    return pl.DataFrame(columns)

def to_plotable_comparison(frame=None, scores=None):
    """Return the similarity matrix in long form, ready for a heatmap.

    Produces one row per ordered pair of speeches (each speech is also paired
    with itself), in row-major order of the matrix, with columns:
    ``speech_1`` and ``speech_2`` (the two ``row_id`` labels) and
    ``similarity`` (score as a Python float).

    Args:
        frame: DataFrame with a ``row_id`` column whose order matches the rows
            and columns of ``scores``. Defaults to the global ``embeddings_df``.
        scores: square similarity matrix (tensor or array rows supporting
            ``.tolist()``). Defaults to the global ``similarities``.

    Raises:
        ValueError: if the number of score rows differs from the number of
            speeches.
    """
    frame = embeddings_df if frame is None else frame
    scores = similarities if scores is None else scores

    # Fetch the labels once instead of materialising a full row (embeddings and
    # text lists included) for every one of the n² pairs.
    row_ids = frame["row_id"].to_list()
    if len(scores) != len(row_ids):
        raise ValueError(
            f"expected {len(row_ids)} rows of similarity scores, got {len(scores)}"
        )

    raw: dict[str, list[str | float]] = {"speech_1": [], "speech_2": [], "similarity": []}
    for speech_1, score_row in zip(row_ids, scores):
        for speech_2, score in zip(row_ids, score_row.tolist()):
            raw["speech_1"].append(speech_1)
            raw["speech_2"].append(speech_2)
            raw["similarity"].append(float(score))

    return pl.DataFrame(raw)


In [50]:
import altair as alt

plotable = to_plotable_comparison()

# sort=None keeps the data order on both axes, i.e. the order of embeddings_df
# (year, then first `n`), instead of Altair's default ascending sort.
similarity_tooltip = ["speech_1:N", "speech_2:N", "similarity:Q"]
(
    alt.Chart(plotable)
    .mark_rect()
    .encode(
        y=alt.Y("speech_1:N", sort=None),
        x=alt.X("speech_2:N", sort=None),
        color="similarity:Q",
        tooltip=similarity_tooltip,
    )
    .properties(title="Messenger speech similarity")
)

In [49]:
# Same heatmap with Altair's default ascending sort on both nominal axes, so
# speeches are ordered alphabetically by their "dramatist title speaker n" label.
alphabetical_chart = (
    alt.Chart(plotable)
    .mark_rect()
    .encode(
        y=alt.Y("speech_1:N"),
        x=alt.X("speech_2:N"),
        color="similarity:Q",
        tooltip=["speech_1:N", "speech_2:N", "similarity:Q"],
    )
    .properties(title="Messenger speech similarity, alphabetical order")
)
alphabetical_chart