In [None]:
%load_ext autoreload
%autoreload 2

%cd '..'

In [None]:
import numpy as np

from sentence_transformers import SentenceTransformer

from preprocessing.utils import (
    load_event_comments,
    save_event_comments,
)

from load.utils import save_df_as_parquet


In [None]:
import logging
import sys

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler("data/logs/comment_embeddings.log"),
        logging.StreamHandler(stream=sys.stdout)
    ]
)

In [None]:
# Pretrained sentence-transformers checkpoint used to embed every comment.
MODEL_NAME = "all-MiniLM-L6-v2"
# Token limit per input; longer comments are truncated by the model.
MAX_SEQ_LENGTH = 256

model = SentenceTransformer(MODEL_NAME)
model.max_seq_length = MAX_SEQ_LENGTH

logging.info(f"Model: {model}")


In [None]:
# Constants
# Key of the event whose comments are processed. It is passed to load_event_comments and also
# prefixes the names of the saved outputs (comments with embeddings, user embeddings).
EVENT_KEY = "mass_shootings"


In [None]:

# Set to an int (e.g. 100) for a quick dry run on a random subset of comments;
# None embeds every comment of the event.
SAMPLE_SIZE = None

logging.info(f"Loading comments of event {EVENT_KEY}")

event_comments = load_event_comments(EVENT_KEY)
if SAMPLE_SIZE is not None:
    # Fixed random_state so a dry run picks the same rows after a kernel restart.
    event_comments = event_comments.sample(
        min(SAMPLE_SIZE, len(event_comments)), random_state=0
    )

In [None]:
logging.info(f"Computing embeddings for {len(event_comments)} comments...")

# convert_to_numpy=True makes encode() move each batch to host memory and return one 2-D float32
# array (one row per comment). With convert_to_numpy=False it returned a list of torch tensors left
# on the model's device. On a GPU that keeps every embedding in device memory, and the later
# np.array(...) conversion of the "embedding" column fails for CUDA tensors.
# list(...) splits the array into one 1-D vector per comment, ready to be stored as a column.
# NOTE(review): assumes every body_cleaned value is a string (NaN would break tokenization) — confirm.
embeddings = list(
    model.encode(
        event_comments["body_cleaned"].tolist(),
        show_progress_bar=True,
        normalize_embeddings=True,
        convert_to_numpy=True,
    )
)


In [None]:
# Store one numpy vector per comment in an object column. This adds the column to
# event_comments itself; event_comments_emb is just another name for the same frame.
event_comments["embedding"] = [np.array(vector) for vector in embeddings]
event_comments_emb = event_comments

In [None]:
# Plain string: the message has no placeholders, so the f-prefix was redundant (flake8 F541).
logging.info("Saving embeddings for comments...")

# Saved under "<EVENT_KEY>_with_embeddings" so the aggregation below can be rerun from disk.
save_event_comments(event_comments_emb, f"{EVENT_KEY}_with_embeddings")


In [None]:
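# To rerun the user-level aggregation without recomputing embeddings, uncomment the lines below
# and skip the encoding cells. Only the author and embedding columns are needed downstream.
# event_comments_emb = load_event_comments(f"{EVENT_KEY}_with_embeddings")[
#     ["author", "embedding"]
# ]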


In [None]:
def normalize(x):
    """Scale an array to unit L2 (Euclidean) norm.

    Args:
        x: array-like, converted with ``np.asarray``. In this notebook it is a 1-D embedding
            vector; for higher-rank input the norm of the whole flattened array is used.

    Returns:
        ``x / ||x||_2``. A zero vector has no direction. Dividing it would give NaNs (0 / 0,
        with a RuntimeWarning), so it is returned as zeros of the same shape, in the dtype the
        division would have produced.
    """
    x = np.asarray(x)
    norm = np.linalg.norm(x)
    if norm == 0:
        return np.zeros_like(x, dtype=np.result_type(x, norm))
    return x / norm

In [None]:
logging.info("Computing user embeddings...")


def _mean_pooled(vectors):
    """Unit-normalized element-wise mean of one author's comment embeddings."""
    return normalize(np.vstack(vectors).mean(axis=0))


def _max_pooled(vectors):
    """Unit-normalized element-wise max of one author's comment embeddings."""
    return normalize(np.vstack(vectors).max(axis=0))


# One row per author (author becomes the index): comment count plus two pooled embeddings.
user_embeddings = event_comments_emb.groupby(by="author").agg(
    count=("author", len),
    mean=("embedding", _mean_pooled),
    max=("embedding", _max_pooled),
)


In [None]:
n_users = len(user_embeddings)
logging.info(f"Saving {n_users} user embeddings...")

# Output file name is derived from EVENT_KEY.
# NOTE(review): authors are the index of user_embeddings; whether save_df_as_parquet keeps
# the index depends on its implementation — confirm.
save_df_as_parquet(
    target_file=f"{EVENT_KEY}_user_embeddings.parquet",
    data=user_embeddings,
)

logging.info("Done!")
