In [None]:
# Enable IPython's autoreload extension so edited project modules are
# re-imported automatically before each cell executes (mode 2 = all modules).
%load_ext autoreload
%autoreload 2

# Move the kernel's working directory up one level; the relative "data/..."
# paths used by later cells are resolved from there.
# NOTE(review): this is not idempotent - re-running the cell moves up another
# level, so run it exactly once per kernel session.
%cd '..'

In [None]:
import json

import numpy as np
from tqdm import tqdm
tqdm.pandas()
import seaborn as sns
sns.set_theme()
import matplotlib.pyplot as plt

from word_partisanship.utils import (
    logodds_with_prior,
)
from preprocessing.utils import (
    split_by_party,
    load_event_comments,
    build_term_vector,
    load_event_vocab,
)
from preprocessing.constants import OUTPUT_DIR


In [None]:
import logging
import sys

# Send every record to two destinations: a log file on disk (path is relative
# to the current working directory) and the notebook's stdout stream.
log_handlers = [
    logging.FileHandler("data/logs/word_partisanship.log"),
    logging.StreamHandler(stream=sys.stdout),
]

# Root logger: INFO and above, formatted as "<timestamp> [<LEVEL>] <message>".
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
    handlers=log_handlers,
)

In [None]:
# (theme, event_name) pairs to process, expanded from a theme-grouped table.
# The order below is the order in which the events are processed.
EVENT_NAMES = [
    (theme, event_name)
    for theme, event_names in (
        ("gun_control", ("mass_shootings_gun_control", "mass_shootings")),
        (
            "elections",
            (
                "us_elections_2012",
                "us_elections_2016",
                "us_midterms_2014",
                "us_midterms_2018",
            ),
        ),
        ("abortion", ("abortion",)),
    )
    for event_name in event_names
]

# Number of most extreme tokens per party written to the JSON output.
N_SAVE = 50
# Number of most extreme tokens per party shown in the bar plot.
N_DISPLAY = 20

# Passed as `ngram_range` to build_term_vector
# (presumably unigrams and bigrams - confirm against that helper).
NGRAM_RANGE = (1, 2)

In [None]:
# For every (theme, event) pair: build term vectors for all comments and for
# each party, compute z-scored log-odds per vocabulary token, save the most
# extreme tokens per party as JSON, and plot the top ones as a bar chart.
for theme, event_name in EVENT_NAMES:
    # Read event data
    logging.info(f"Loading event data {event_name}...")
    event_comments = load_event_comments(theme=theme, event_name=event_name)
    event_vocab = load_event_vocab(theme=theme, event_name=event_name)

    # logging applies %-style formatting to its extra arguments; without a
    # placeholder the call emits a "--- Logging error ---" traceback instead
    # of the message.
    logging.info("Vocab length: %d", len(event_vocab))

    # Invert the vocabulary mapping so that the indices returned by argsort
    # below can be translated back into tokens.
    ind_to_token = {v: k for k, v in event_vocab.items()}

    dem_comments, rep_comments = split_by_party(
        comments=event_comments,
    )

    logging.info(dem_comments.shape)
    logging.info(rep_comments.shape)

    tokens = event_comments["tokens"]
    dem_tokens = dem_comments["tokens"]
    rep_tokens = rep_comments["tokens"]

    logging.info("Building overall term vector...")
    overall_term_vec = build_term_vector(
        tokens, ngram_range=NGRAM_RANGE, vocab=event_vocab
    )

    logging.info("Building dem term vector...")
    dem_term_vec = build_term_vector(
        dem_tokens, ngram_range=NGRAM_RANGE, vocab=event_vocab
    )
    logging.info("Building rep term vector...")
    rep_term_vec = build_term_vector(
        rep_tokens, ngram_range=NGRAM_RANGE, vocab=event_vocab
    )

    logging.info("Calculating loggodds...")
    logodds = logodds_with_prior(
        overall_term_vec,
        rep_term_vec,
        dem_term_vec,
        zscore=True,
    )

    # Indices of the vocabulary ordered by ascending log-odds. The code below
    # treats the largest values as Republican and the smallest as Democrat.
    sorted_logodds_indices = np.argsort(logodds)

    idiosyncratic_tokens = {}

    # Scores are cast to plain floats so json.dump accepts them regardless of
    # the NumPy dtype returned by logodds_with_prior.
    logging.info("Republican tokens")
    # Ascending order: the most Republican token is the LAST entry.
    rep_idiosyncratic_tokens = [
        (ind_to_token[index], float(logodds[index]))
        for index in sorted_logodds_indices[-N_SAVE:]
    ]
    logging.info(rep_idiosyncratic_tokens)

    idiosyncratic_tokens["rep"] = rep_idiosyncratic_tokens

    logging.info("Democrat tokens")
    # Descending order: the most Democrat (most negative) token is the LAST entry.
    dem_idiosyncratic_tokens = [
        (ind_to_token[index], float(logodds[index]))
        for index in reversed(sorted_logodds_indices[:N_SAVE])
    ]
    logging.info(dem_idiosyncratic_tokens)

    idiosyncratic_tokens["dem"] = dem_idiosyncratic_tokens

    # The saved file can later be read back with json.load to redo the plot
    # without recomputing the log-odds.
    with open(f"{OUTPUT_DIR}/{event_name}_idiosyncratic_tokens.json", "w") as f:
        json.dump(idiosyncratic_tokens, f)

    # Bar plot of the top tokens: Republican tokens first (strongest at the
    # top), followed by Democrat tokens (strongest at the bottom).
    rep_display = idiosyncratic_tokens["rep"][-N_DISPLAY:][::-1]
    dem_display = idiosyncratic_tokens["dem"][-N_DISPLAY:]
    display_entries = rep_display + dem_display
    # Derive positions and colours from the actual number of entries so they
    # always match, even if fewer than N_DISPLAY tokens are available.
    bar_positions = list(range(len(display_entries)))

    fig = plt.figure(figsize=(5, 12))

    sns.barplot(
        orient="h",
        x=[score for _, score in display_entries],
        y=bar_positions,
        palette=["#E81B23"] * len(rep_display) + ["#00AEF3"] * len(dem_display),
    )
    # Replace the numeric bar positions with the tokens themselves.
    plt.yticks(
        ticks=bar_positions,
        labels=[token for token, _ in display_entries],
    )

    plt.xlabel("Weighted log-odds ratio")

    plt.savefig(
        fname=f"data/figures/wp/{event_name}_idiosyncratic_tokens.pdf",
        bbox_inches="tight",
        pad_inches=0,
    )

    plt.show()
    # Release the figure explicitly; one figure is created per event and
    # would otherwise accumulate on backends that do not close it on show().
    plt.close(fig)
