In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from collections import defaultdict
import json

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from utils import load_data
from constants import DATA_DIR, EVENTS_DIR, DEM_SUBREDDITS, REP_SUBREDDITS, PARTISAN_SUBREDDITS

In [None]:
# Development switch. When True, the `if dev:` cell further down replaces
# `reddit` with a random fraction of the loaded rows (presumably to speed up
# iteration while developing — confirm with the notebook author).
dev = False

### User Affiliation

In [None]:
# Load the 2016 data from DATA_DIR through the project helper `load_data`.
# Later cells call `.sample(...)` and select the "author" and "subreddit"
# columns on the result, so it is used as a pandas DataFrame.
# NOTE(review): `dev=False` is hard-coded here instead of passing the
# notebook-level `dev` flag — presumably deliberate because the dev
# down-sampling is done in the next cell; confirm against `load_data`.
reddit = load_data(DATA_DIR, year=2016, tokenize=False, comp="parquet", dev=False)

In [None]:
if dev:
    # Development mode: continue with a random subset of the rows only.
    reduction_frac = 0.15
    # Fixed seed so that every dev run draws the same rows; without it the
    # sample (and everything derived from it) changes on each execution.
    sample_seed = 42
    print(f'Reduce data to fraction {reduction_frac}')

    # NOTE: this overwrites `reddit`, so re-running the cell without reloading
    # the data shrinks the frame again.
    reddit = reddit.sample(frac=reduction_frac, random_state=sample_seed)

In [None]:
# Map every author to the list of partisan subreddits they appear in: one list
# entry per matching row of `reddit`, in row order. Rows whose subreddit is not
# in PARTISAN_SUBREDDITS are ignored.
users = defaultdict(list)

# Build the membership set once so each row costs a single hash lookup.
partisan_subreddits = set(PARTISAN_SUBREDDITS)

# Walk the two columns in lockstep instead of DataFrame.iterrows(), which
# constructs a full Series object for every row and is much slower on a large
# frame.
for author, subreddit in zip(reddit["author"], reddit["subreddit"]):
    if subreddit in partisan_subreddits:
        users[author].append(subreddit)

In [None]:
# One (user, number of collected partisan-subreddit entries) record per author.
post_counts = [(author, len(subreddits)) for author, subreddits in users.items()]
users_n_posts_df = pd.DataFrame(post_counts, columns=['user', 'n_posts'])

# Persist the counts; the default integer index is written as the first column.
users_n_posts_df.to_csv(f'{DATA_DIR}/users_n_posts.csv')

In [None]:
def calculate_partisan_score(subreddits_list):
    """Return the net partisan score of a sequence of subreddit names.

    Each entry found in ``DEM_SUBREDDITS`` counts +1; each entry that is not
    in ``DEM_SUBREDDITS`` but is in ``REP_SUBREDDITS`` counts -1; every other
    entry counts 0.

    Args:
        subreddits_list: iterable of subreddit names, one per post.

    Returns:
        int: sum of the per-entry contributions (0 for an empty input).
    """
    def _contribution(name):
        # DEM membership is checked first, mirroring an if/elif chain.
        if name in DEM_SUBREDDITS:
            return 1
        if name in REP_SUBREDDITS:
            return -1
        return 0

    return sum(_contribution(name) for name in subreddits_list)


def is_rep_or_dem(score):
    """Translate a partisan score into a leaning label.

    Args:
        score: numeric partisan score.

    Returns:
        str: 'D' when ``score >= 1``, 'R' when ``score <= -1``, otherwise 'N'.
    """
    if score <= -1:
        return 'R'
    if score >= 1:
        return 'D'
    # Anything strictly between -1 and 1 is treated as neutral.
    return 'N'

In [None]:
# Label every user whose net partisan score is large enough in magnitude.
user_affiliation = {}          # user -> 'D' / 'R' label
users_affiliation_data = []    # rows of [user, score, label] for the DataFrame

for author, subreddits in users.items():
    score = calculate_partisan_score(subreddits)
    # Only users with |score| > 5 are kept; everyone else is left unlabelled.
    if abs(score) <= 5:
        continue
    leaning = is_rep_or_dem(score)
    users_affiliation_data.append([author, score, leaning])
    user_affiliation[author] = leaning

users_affiliation_df = pd.DataFrame(
    users_affiliation_data,
    columns=['user', 'score', 'leaning'],
)

In [None]:
# Tabular form (user, score, leaning) goes to the events directory as CSV.
affiliation_csv_path = f'{EVENTS_DIR}/brexit_user_affiliation.csv'
users_affiliation_df.to_csv(affiliation_csv_path)

# The plain user -> leaning mapping goes to the data directory as JSON.
affiliation_json_path = f"{DATA_DIR}/user_affiliation.json"
with open(affiliation_json_path, "w") as json_file:
    json.dump(user_affiliation, json_file)

In [None]:
# Histogram of the net partisan score of the retained users.
# Scores are integer counts, so use unit-width bins centred on each integer;
# evenly spaced edges from linspace(-250, 250, 500) are ~1.002 wide and make
# the outermost bins swallow two integer values each.
score_limit = 250  # scores beyond +/- score_limit are not shown
bin_edges = np.arange(-score_limit - 0.5, score_limit + 1.5, 1.0)

fig, ax = plt.subplots()
users_affiliation_df['score'].hist(bins=bin_edges, ax=ax)
ax.set_xlim(bin_edges[0], bin_edges[-1])
ax.set_xlabel('Partisan score (positive = Democrat-leaning, negative = Republican-leaning)')
ax.set_ylabel('Number of users')
ax.set_title('Distribution of user partisan scores')
plt.show()