In [None]:
from matplotlib import pyplot as plt
from collections import defaultdict
import seaborn as sns
import pandas as pd
import psycopg2
import datetime
import pickle
import os

%load_ext nb_black

In [None]:
def connect():
    """Open a new psycopg2 connection to the local ``venmo`` Postgres database.

    The password is read from the ``POSTGRES_PASS`` environment variable and
    falls back to an empty string when the variable is unset.

    Returns:
        The connection object returned by ``psycopg2.connect``; the caller is
        responsible for closing it.
    """
    settings = {
        "user": "postgres",
        "password": os.environ.get("POSTGRES_PASS", ""),
        "host": "localhost",
        "port": 5432,
        "database": "venmo",
    }
    return psycopg2.connect(**settings)


def location_to_state(location):
    """Extract the US state name from a geocoder result.

    Args:
        location: Object exposing ``raw["display_name"]``, a comma-separated
            address string (presumably a geopy ``Location`` -- confirm against
            whatever populates ``geo_cache``). ``None`` is accepted and
            treated as an unknown location.

    Returns:
        ``None`` when ``location`` is ``None`` or the address does not mention
        the United States of America. Otherwise the address component just
        before the country, or the one before that when that component is a
        postal code. The placeholder ``"na"`` is returned when the address is
        too short to contain such a component.
    """
    if location is None:
        return None
    name = location.raw["display_name"]
    if "United States of America" not in name:
        return None
    # Pad on the left so that very short addresses (down to the bare country
    # name) still provide three components instead of failing to unpack.
    region, candidate, _country = (["na", "na"] + name.split(", "))[-3:]
    # Postal codes may look like "43215", "43215-1234" or "13413:13501".
    if candidate.replace("-", "").replace(":", "").isdigit():
        return region
    return candidate

In [None]:
# Cached lookups used by the next cell: user id -> (lat, lng, location key)
# and location key -> geocoder result.
# NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
with open("user_id_to_loc.pkl", "rb") as user_loc_file:
    user_id_to_loc_saved = pickle.load(user_loc_file)

with open("geo_cache.pkl", "rb") as geo_cache_file:
    geo_cache = pickle.load(geo_cache_file)

In [None]:
# Collect, per US state, every transaction created after 2020-01-01 that a
# located user either sent ('from') or received ('to'), then cache the result
# on disk so later cells do not need the database.
transactions_by_state = defaultdict(list)

# Loop-invariant query text; the leading literal column tags which branch of
# the UNION a row came from.
USER_TRANSACTIONS_SQL = """
    SELECT 'from', id, message, type, created, actor_user_id, recipient_id FROM transactions 
      WHERE actor_user_id=%s AND created > '2020-01-01'
    UNION ALL
    SELECT 'to', id, message, type, created, actor_user_id, recipient_id FROM transactions 
      WHERE recipient_id=%s AND created > '2020-01-01'
    """

conn = connect()
try:
    cur = conn.cursor()
    for user_id, (_lat, _lng, loc) in user_id_to_loc_saved.items():
        state = location_to_state(geo_cache[loc])
        if state is None:
            # User is not located in the US (or the location is unknown).
            continue
        cur.execute(USER_TRANSACTIONS_SQL, (user_id, user_id))
        rows = cur.fetchall()
        # Strip the branch tag and keep sent rows ahead of received rows,
        # whatever order the database returned them in.
        sent = [row[1:] for row in rows if row[0] == "from"]
        received = [row[1:] for row in rows if row[0] == "to"]
        transactions_by_state[state].extend(sent + received)
finally:
    # Release the connection even when a query or lookup fails part-way.
    conn.close()

print("Saving...")
with open("transactions_by_state.pkl", "wb") as f:
    pickle.dump(transactions_by_state, f)

In [None]:
# COVID-related keywords; some entries are word stems (e.g. "self isolat").
# Not used in this cell -- presumably matched against transaction messages
# elsewhere (TODO confirm).
COVID_WORDS = [
    "diagnosed", "pneumonia", "coronavirus", "fever",
    "covid", "isolating", "quarantine", "cough",
    "sick", "social distancing", "self isolat", "self-isolat",
]

# NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
with open("transactions_by_state.pkl", "rb") as pickle_file:
    transactions_by_state_saved = pickle.load(pickle_file)

# Discard the bogus "state" key, which looks like a postal-code range rather
# than a state name; nothing happens when the key is absent.
transactions_by_state_saved.pop("13413:13501", None)

# One row per transaction: the state it belongs to and its creation time as
# epoch seconds.
df_by_state_data = {"State": [], "Date": []}
for state_name, state_rows in transactions_by_state_saved.items():
    # Unpacking all six fields keeps the strict row-shape check.
    epoch_seconds = [
        created.timestamp()
        for _id, _message, _type, created, _sender, _recipient in state_rows
    ]
    df_by_state_data["State"].extend([state_name] * len(epoch_seconds))
    df_by_state_data["Date"].extend(epoch_seconds)
df_by_state = pd.DataFrame(df_by_state_data)

In [None]:
# Transaction count per state. As expected, bigger states have more transactions.
df_by_state.groupby("State").agg("count")

In [None]:
# 2x2 grid of panels; each one is filled by a plot_state(...) call.
fig, ((ax, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, figsize=(15, 15))
fig.tight_layout(pad=6.0)

# Epoch-seconds cutoff (about 2020-01-31); only later case rows are kept.
CASES_START_TS = 1.580450e09

cases_df = pd.read_csv("United_States_COVID-19_Cases_and_Deaths_by_State_over_Time.csv")
cases_df["State"] = cases_df.state
# Dates become epoch seconds so they share an x-axis with the transaction
# timestamps.
cases_df["Date"] = cases_df.submission_date.apply(
    lambda date: datetime.datetime.strptime(date, "%m/%d/%Y").timestamp()
)
# The file holds every state in one table, so the 7-day mean has to be taken
# per state over date-ordered rows. A rolling window over the whole frame
# would average rows belonging to different states.
cases_df = cases_df.sort_values(["State", "Date"])
cases_df["Cases"] = cases_df.groupby("State")["new_case"].transform(
    lambda daily: daily.rolling(7).mean()
)
cases_df = cases_df[cases_df.Date > CASES_START_TS]


def plot_state(state, state_abbr, use_ax, copy_ax=None):
    """Plot one state's Venmo transaction histogram with its COVID case curve.

    Draws a 20-bin histogram of the ``df_by_state`` dates for ``state`` on
    ``use_ax`` and overlays the ``Cases`` column of ``cases_df`` for
    ``state_abbr`` in red on a twin y-axis.

    Args:
        state: State name as stored in ``df_by_state["State"]``.
        state_abbr: State code as stored in ``cases_df.state``.
        use_ax: Matplotlib axes to draw on.
        copy_ax: Optional axes whose x tick positions and labels are reused,
            so several panels show the same ticks. When ``None`` the ticks
            come from this panel's histogram and are labelled with ISO dates.
    """
    sns.histplot(
        df_by_state[df_by_state["State"] == state].dropna(),
        x="Date",
        ax=use_ax,
        fill=False,
        bins=20,
    ).set_title("Venmo transactions in " + state + " (State cases in red)")
    if copy_ax is None:
        # Pin the tick positions before labelling them. Labels set on
        # automatically located ticks are attached by index, so they end up
        # on the wrong ticks once the twin-axis line plot rescales x.
        tick_positions = use_ax.get_xticks()
        use_ax.set_xticks(tick_positions)
        use_ax.set_xticklabels(
            [
                datetime.datetime.fromtimestamp(ts).isoformat()[:10]
                for ts in tick_positions
            ]
        )
    else:
        use_ax.set_xticks(copy_ax.get_xticks())
        use_ax.set_xticklabels(copy_ax.get_xticklabels())
    # Case counts use their own y-scale on a twin axis sharing the date axis.
    sns.lineplot(
        data=cases_df[cases_df.state == state_abbr],
        x="Date",
        y="Cases",
        ax=use_ax.twinx(),
        color="red",
    )

    # total_cases = cases_df.groupby("Date")[["Cases"]].agg("sum")
    # sns.lineplot(
    #     data=total_cases, x="Date", y="Cases", ax=use_ax.twinx(), color="green"
    # )


# Ohio is drawn first; the other panels reuse its x ticks and labels.
plot_state("Ohio", "OH", ax)
for state_name, state_code, panel in (
    ("California", "CA", ax2),
    ("Texas", "TX", ax3),
    ("New York", "NY", ax4),
):
    plot_state(state_name, state_code, panel, copy_ax=ax)

# total_cases = cases_df.groupby("Date")[["Cases"]].agg("sum")
# sns.lineplot(data=total_cases, x="Date", y="Cases", ax=ax.twinx(), color="green")