In [None]:
%load_ext autoreload
%autoreload 2

%cd '..'

In [None]:
import numpy as np

from sentence_transformers import SentenceTransformer

from preprocessing.utils import (
    load_event_comments,
    save_event_comments,
    normalize,
)

from load.utils import save_df_as_parquet


In [None]:
import logging
import sys
from pathlib import Path

# Log both to a file (persistent record of long embedding runs) and to stdout
# (visible in the notebook output).
LOG_FILE = Path("data/logs/comment_embeddings.log")

# logging.FileHandler does not create missing directories; make sure the log
# directory exists so a fresh checkout does not fail with FileNotFoundError.
LOG_FILE.parent.mkdir(parents=True, exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(LOG_FILE),
        logging.StreamHandler(stream=sys.stdout),
    ],
    # Without force=True, basicConfig is a silent no-op whenever the root
    # logger already has handlers, and re-running this cell would leak the
    # freshly opened FileHandler. force=True closes/replaces old handlers.
    force=True,
)

In [None]:
# Sentence encoder used for every comment.
MODEL_NAME = "all-MiniLM-L6-v2"
# Maximum input length in tokens; the encoder truncates longer comments.
MAX_SEQ_LENGTH = 256

model = SentenceTransformer(MODEL_NAME)
model.max_seq_length = MAX_SEQ_LENGTH

logging.info("Model: %s", model)


In [None]:
# Events to process, grouped by theme. The theme is what gets passed as
# `theme=` when loading an event's comments; order here is processing order.
_EVENTS_BY_THEME = {
    "gun_control": [
        "mass_shootings_gun_control",
        "mass_shootings",
    ],
    "elections": [
        "us_elections_2012",
        "us_elections_2016",
        "us_midterms_2014",
        "us_midterms_2018",
    ],
    "abortion": [
        "abortion",
    ],
}

# Flattened (theme, event_name) pairs, in declaration order.
EVENT_NAMES = [
    (theme, name)
    for theme, names in _EVENTS_BY_THEME.items()
    for name in names
]


In [None]:
# Seed for every random choice in this stage (quick-run subsampling and the
# per-user "sample" embedding) so reruns produce identical outputs.
SEED = 42
# Set to e.g. 100 for a quick smoke run; None processes every comment.
N_SAMPLES = None


def embed_comments(comments, encoder):
    """Return a copy of ``comments`` with an added ``embedding`` column.

    Each row's ``body_cleaned`` text is encoded with ``encoder`` (requested
    with ``normalize_embeddings=True``) and stored as one 1-D ``np.ndarray``
    per row. The input frame is not modified.
    """
    vectors = encoder.encode(
        comments["body_cleaned"].tolist(),
        show_progress_bar=True,
        normalize_embeddings=True,
        # convert_to_numpy=True brings results back to CPU as a numpy array;
        # with False the tensors stay on the model's device and np.array()
        # on a CUDA tensor raises.
        convert_to_numpy=True,
    )
    # One 1-D vector per row, stored in an object column.
    return comments.assign(embedding=list(vectors))


def aggregate_user_embeddings(comments_emb, rng):
    """Collapse comment embeddings to one row per (author, party).

    Output columns: author, party, count (number of comments), sample (one
    comment embedding drawn with ``rng``), mean / max (element-wise mean / max
    over the author's comment embeddings, passed through ``normalize``).
    """
    return (
        comments_emb.groupby(by=["author", "party"])
        .agg(
            count=("author", "size"),
            # Return the drawn vector itself rather than a 1-row Series.
            sample=("embedding", lambda x: x.sample(1, random_state=rng).iloc[0]),
            mean=("embedding", lambda x: normalize(np.vstack(x).mean(axis=0))),
            max=("embedding", lambda x: normalize(np.vstack(x).max(axis=0))),
        )
        .reset_index()
    )


for event_theme, event_name in EVENT_NAMES:
    logging.info(f"Loading comments of event {event_name}")

    event_comments = load_event_comments(theme=event_theme, event_name=event_name)
    if N_SAMPLES:
        event_comments = event_comments.sample(
            n=min(N_SAMPLES, len(event_comments)), random_state=SEED
        )

    logging.info(f"Computing embeddings for {len(event_comments)} comments...")

    event_comments_emb = embed_comments(event_comments, model)

    logging.info("Saving embeddings for comments...")

    save_event_comments(event_comments_emb, f"{event_name}_with_embeddings")

    # `party` is a groupby key below, so it must be part of the selection;
    # selecting only author/embedding made the groupby raise KeyError: 'party'.
    # NOTE(review): confirm the reloaded frame is where `party` comes from;
    # if it is already in `event_comments_emb`, the reload can be skipped.
    user_comments_emb = load_event_comments(
        theme=event_theme, event_name=f"{event_name}_with_embeddings"
    )[["author", "party", "embedding"]]

    logging.info("Computing user embeddings...")

    # Fresh generator per event: each event's output is reproducible on its
    # own, independent of which events were processed before it.
    user_embeddings = aggregate_user_embeddings(
        user_comments_emb, rng=np.random.default_rng(SEED)
    )

    logging.info(f"Saving {len(user_embeddings)} user embeddings...")

    save_df_as_parquet(
        data=user_embeddings,
        target_file=f"{event_name}_user_embeddings.parquet",
    )

    logging.info("Done!")
