# Tragedy embeddings

## Setup

In [1]:
import polars as pl 

# Source corpus; later cells read its `dramatist`, `title`, `text` and `year` columns.
df = pl.read_parquet(source="./tragedy-with-years.parquet")

In [2]:
# Collapse the corpus to one row per (dramatist, title): `text` becomes the list of that
# play's text values (row order within a group is preserved by polars) and `year` is
# taken from the play's first row.
# maintain_order=True makes the order of plays follow the input instead of being
# arbitrary. Without it, the order changes between runs, and that order ends up in the
# JSON embedding cache and in tie-breaking when plays are later sorted by year.
by_play_grouped_text = df.group_by(
    "dramatist",
    "title",
    maintain_order=True,
).agg(pl.col("text"), pl.col("year").first())

In [3]:
# Split each play's joined text into rough "sentences" on ".".
# Other sentence-level delimiters are first normalised to "." so one split cuts at all of them:
#   * `str.replace` only rewrites the FIRST match per string, so previously every
#     delimiter after the first one in a play survived; `replace_all` fixes that.
#   * "Â·" is a UTF-8 middle dot mis-decoded as Latin-1. The raw middle dot (U+00B7)
#     and the Greek ano teleia (U+0387) are handled as well, whichever form the data
#     holds. "Â·" must be replaced before the bare dot, or a stray "Â" would remain.
#   * The Greek question mark (U+037E) is the canonical equivalent of ";" and is
#     treated the same way.
#   * literal=True: the delimiters are plain strings, not regex patterns.
# NOTE: ./tragedy-embeddings.json caches embeddings of these sentences; delete it after
# changing the split, otherwise the stale cache is reused.
sentences_df = by_play_grouped_text.with_columns(
    joined_text=pl.col("text").list.join(" ")
).with_columns(
    sentences=pl.col("joined_text")
    .str.replace_all("Â·", ".", literal=True)
    .str.replace_all("\u00b7", ".", literal=True)
    .str.replace_all("\u0387", ".", literal=True)
    .str.replace_all("\u037e", ".", literal=True)
    .str.replace_all(";", ".", literal=True)
    .str.strip_chars()
    .str.split(".")
)

In [4]:
from sentence_transformers import SentenceTransformer

# Sentence encoder shared by every embedding and similarity computation below.
model_name = "bowphs/SPhilBerta"
model = SentenceTransformer(model_name_or_path=model_name)

In [5]:
import json

from pathlib import Path

def calculate_embeddings(embeddings_file=Path("./tragedy-embeddings.json")):
    """Return per-play sentence embeddings, computing them only when no cache exists.

    If ``embeddings_file`` exists it is loaded and returned unchanged. It is NOT
    checked against the current ``sentences_df``; delete the file to force a
    recomputation. Otherwise every play in the module-level ``sentences_df`` is
    encoded with the module-level ``model`` and the result is written to
    ``embeddings_file`` before being returned.

    Args:
        embeddings_file: Path (or str) of the JSON cache. Defaults to
            ``./tragedy-embeddings.json``.

    Returns:
        A list of dicts, one per successfully encoded play, with keys ``"id"``
        (``"<dramatist> <title>"``), ``"embeddings"`` (the ``model.encode``
        output for the play's non-blank sentences, as nested lists) and ``"year"``.
        Plays whose encoding raises ``IndexError`` are reported and skipped.
    """
    embeddings_file = Path(embeddings_file)

    if embeddings_file.exists():
        # The cache is written with ensure_ascii=False, so read it back explicitly as
        # UTF-8 rather than with the platform's default encoding.
        with open(embeddings_file, encoding="utf-8") as f:
            return json.load(f)

    tragedy_embeddings = []

    for row in sentences_df.iter_rows(named=True):
        play_id = f"{row['dramatist']} {row['title']}"
        sentences = [s for s in row["sentences"] if s.strip() != ""]

        # Guard only the encoder call: a play that fails is reported and skipped.
        try:
            embeddings = model.encode(sentences)
        except IndexError:
            print(f"Index error on {play_id}")
            continue

        tragedy_embeddings.append(
            {"id": play_id, "embeddings": embeddings.tolist(), "year": row["year"]}
        )

    with open(embeddings_file, "w", encoding="utf-8") as f:
        json.dump(tragedy_embeddings, f, ensure_ascii=False)

    return tragedy_embeddings

In [6]:
# Per-play sentence embeddings. They are loaded from ./tragedy-embeddings.json if that
# file exists; otherwise they are computed from `sentences_df` and written there
# (see calculate_embeddings).
tragedy_embeddings = calculate_embeddings()

In [7]:
import numpy as np

# The same records as `tragedy_embeddings`, with each play's embeddings converted from
# nested JSON lists to a float32 NumPy array.
all_embeddings = [
    {
        "id": record["id"],
        "embeddings": np.asarray(record["embeddings"], dtype=np.float32),
        "year": record["year"],
    }
    for record in tragedy_embeddings
]

In [8]:
# Plays in chronological order. `sorted` is stable, so same-year plays keep their
# input order.
sorted_embeddings = sorted(all_embeddings, key=lambda x: x["year"])

# Every ordered pair [play_a, play_b], including each play paired with itself, with
# play_a varying slowest. All ordered pairs are needed because the max-mean score
# computed from them is directional.
comparisons = [[first, second] for first in sorted_embeddings for second in sorted_embeddings]

## Method 1: Tensor cosine similarity

In [16]:
def embed(sentences):
    """Encode ``sentences`` with the module-level SentenceTransformer ``model``.

    A thin wrapper so the encoder can be passed around as a plain callable;
    returns exactly what ``model.encode`` returns.
    """
    encoded = model.encode(sentences)
    return encoded


def calculate_tensor_cos_sim():
    """Return the similarity of every ordered pair of plays as a long DataFrame.

    Each play is represented by one vector: the mean of the ``embed`` embeddings of
    its non-blank sentences. Pairs are scored with ``model.similarity``. Plays are
    ordered by year, plays without any non-blank sentence are dropped, and
    "Cyclops" is excluded.

    Previously the whole ``sentences`` column was handed to ``embed`` through
    ``map_batches``. ``model.encode`` then received one *list* of sentences per play,
    which the sentence-transformers tokenizer treats as a (text, text_pair) tuple.
    As a result, each play was embedded from only its first two sentences.

    Returns:
        ``pl.DataFrame`` with columns ``play_1`` and ``play_2``
        (``"<dramatist> <title>"``) and ``similarity`` (float). It has one row per
        ordered pair, with ``play_1`` varying slowest.
    """
    plays = (
        sentences_df.with_columns(
            row_id=pl.concat_str(pl.col("dramatist"), pl.col("title"), separator=" "),
            # Drop blank sentences, mirroring the filter in calculate_embeddings.
            sentences=pl.col("sentences").list.eval(
                pl.element().filter(pl.element().str.strip_chars() != "")
            ),
        )
        # A play with no sentences has no mean embedding.
        .filter(pl.col("sentences").list.len() > 0)
        .sort("year", maintain_order=True)
        # NOTE(review): the Cyclops exclusion is kept from the original; its reason is not recorded.
        .remove(pl.col("title") == "Cyclops")
    )

    # Mean-pool each play's sentence embeddings into a single vector per play.
    play_vectors = np.stack(
        [
            np.asarray(embed(play_sentences), dtype=np.float32).mean(axis=0)
            for play_sentences in plays["sentences"].to_list()
        ]
    )
    similarities = model.similarity(play_vectors, play_vectors)

    # Row-major flatten: entry i * n + j is the score of (ids[i], ids[j]).
    ids = plays["row_id"].to_list()
    return pl.DataFrame(
        {
            "play_1": [a for a in ids for _ in ids],
            "play_2": [b for _ in ids for b in ids],
            "similarity": [float(score) for score in similarities.flatten().tolist()],
        }
    )

In [17]:
import altair as alt

plotable = calculate_tensor_cos_sim()

# Heatmap of pairwise play similarity. sort=None keeps the DataFrame's row order
# (plays by year) on both axes instead of sorting labels alphabetically.
alt.Chart(plotable, title="Tragedy similarity").mark_rect().encode(
    x=alt.X("play_2:N", sort=None),
    y=alt.Y("play_1:N", sort=None),
    color=alt.Color("similarity:Q", scale=alt.Scale(scheme="yelloworangebrown")),
    tooltip=[
        alt.Tooltip("play_1:N"),
        alt.Tooltip("play_2:N"),
        alt.Tooltip("similarity:Q"),
    ],
)

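## Method 2: Mean of maximum sentence similarities

For each sentence of play 1, take its highest similarity to any sentence of play 2, then average these maxima over all sentences of play 1 (so the score is directional).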

In [9]:
import torch

def calculate_max_mean_similarities():
    """Score every pair in the module-level ``comparisons`` with the max-mean measure.

    For a pair (A, B):
    1. Build the sentence-by-sentence similarity matrix between A's and B's
       embeddings with ``model.similarity``.
    2. Take each A-sentence's best match in B (row-wise max).
    3. Average those maxima.

    The score is directional: score(A, B) generally differs from score(B, A).

    Side effect: sets ``model.max_seq_length = 256`` on the shared model. Nothing in
    this function encodes text, so that setting does not affect the scores here.

    Returns:
        dict of parallel lists ``play_1`` and ``play_2`` (play ids) and
        ``similarity`` (Python floats), in ``comparisons`` order.
    """
    similarities: dict[str, list[str | float]] = {"play_1": [], "play_2": [], "similarity": []}

    model.max_seq_length = 256

    for first, second in comparisons:
        similarity_by_sentence = model.similarity(first["embeddings"], second["embeddings"])

        # Row-wise max, then mean, computed on the tensor instead of via a Python
        # loop over rows and a re-wrapped list of 0-d tensors.
        mean_of_maxima = similarity_by_sentence.max(dim=1).values.mean()

        similarities["play_1"].append(first["id"])
        similarities["play_2"].append(second["id"])
        similarities["similarity"].append(float(mean_of_maxima))

    return similarities

In [10]:
import polars as pl

# Max-mean scores for every ordered pair of plays, collected into a long-format
# DataFrame (play_1, play_2, similarity) for plotting.
similarities = calculate_max_mean_similarities()
similarities_df = pl.DataFrame(data=similarities)

In [11]:
import altair as alt

# Heatmap of the max-mean scores. sort=None keeps the data's row order (plays by
# year) on both axes instead of sorting labels alphabetically.
alt.Chart(similarities_df, title="Tragedy similarity").mark_rect().encode(
    x=alt.X(field="play_2", type="nominal", sort=None),
    y=alt.Y(field="play_1", type="nominal", sort=None),
    color=alt.Color(
        field="similarity",
        type="quantitative",
        scale=alt.Scale(scheme="yelloworangebrown"),
    ),
    tooltip=[
        alt.Tooltip(field="play_1", type="nominal"),
        alt.Tooltip(field="play_2", type="nominal"),
        alt.Tooltip(field="similarity", type="quantitative"),
    ],
)