In [None]:
# dependencies

import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import xfx.custom.binomial
import xfx.glm.binomial
import xfx.misc.plot

In [None]:
# helper functions

def package_samples(samples, factor_names, meta):
    """Turn a list of sampler draws into two labelled data frames.

    Each element of `samples` is unpacked as a pair `(rfx, prec)`; `rfx` is
    itself a sequence with one entry per random-effect block.

    Parameters
    ----------
    samples : sequence of `(rfx, prec)` pairs, one per retained iteration.
    factor_names : list of factor labels. The random-effect frame gets an
        additional leading label '_const' for the first block.
    meta : dict of constant columns (e.g. the algorithm label) added to the
        index of both frames.

    Returns
    -------
    tuple `(rfx_frame, prec_frame)` as produced by `package_rfx_samples` and
    `package_prec_samples`.
    """
    rfx_draws, prec_draws = zip(*samples)
    # regroup from "per iteration" to "per block", stacking iterations on axis 0
    rfx_by_block = [np.array(block_draws) for block_draws in zip(*rfx_draws)]
    prec_by_iter = np.array(prec_draws)
    rfx_frame = package_rfx_samples(rfx_by_block, ['_const'] + factor_names, meta)
    prec_frame = package_prec_samples(prec_by_iter, factor_names, meta)
    return rfx_frame, prec_frame

def package_rfx_samples(rfx_samples, factor_names, meta):
    """Stack per-factor sample arrays into one frame indexed by level.

    Parameters
    ----------
    rfx_samples : sequence of 2-d arrays, one per factor; each array is
        transposed so that its second axis becomes the rows ('level') and
        its first axis becomes the columns ('iter').
    factor_names : labels paired positionally with `rfx_samples`.
    meta : dict of constant columns added to every row.

    Returns
    -------
    DataFrame indexed by `list(meta) + ['factor', 'level']` with one column
    per iteration.
    """
    index_cols = list(meta.keys()) + ['factor', 'level']
    frames = []
    for name, draws in zip(factor_names, rfx_samples):
        frame = pd.DataFrame(draws.T).rename_axis(index='level', columns='iter')
        frame['factor'] = name
        for key, value in meta.items():
            frame[key] = value
        frames.append(frame)
    stacked = pd.concat(frames)
    # 'level' lives in the row index until here; expose it as a column so it
    # can join the metadata and factor label in the final MultiIndex
    return stacked.reset_index().set_index(index_cols)

def package_prec_samples(prec_samples, factor_names, meta):
    """Label a 2-d array of precision draws by factor.

    Parameters
    ----------
    prec_samples : 2-d array; it is transposed so that its second axis
        becomes the rows (one per entry of `factor_names`) and its first
        axis becomes the columns ('iter').
    factor_names : row labels, one per factor.
    meta : dict of constant columns added to every row.

    Returns
    -------
    DataFrame indexed by `list(meta) + ['factor']` with one column per
    iteration.
    """
    index_cols = list(meta.keys()) + ['factor']
    frame = pd.DataFrame(prec_samples.T, index=factor_names)
    frame['factor'] = factor_names
    for key, value in meta.items():
        frame[key] = value
    frame = frame.set_index(index_cols)
    return frame.rename_axis(columns='iter')

def est_acf(samples, n_lags):
    """Apply `xfx.misc.plot.est_acf` to every row of `samples`.

    Parameters
    ----------
    samples : DataFrame with one chain per row (iterations along columns).
    n_lags : forwarded unchanged to `xfx.misc.plot.est_acf`.

    Returns
    -------
    DataFrame with the same row index as `samples`; the list-like result of
    each row is expanded into columns, and the column axis is named 'lag'.
    """
    def row_acf(row):
        return xfx.misc.plot.est_acf(row.values, n_lags)

    acf = samples.apply(row_acf, axis=1, raw=False, result_type='expand')
    return acf.rename_axis(columns='lag')

def est_ess(acfs, titer):
    """Summarise each row of `acfs` as autocorrelation time and sampling rate.

    Parameters
    ----------
    acfs : DataFrame with one autocorrelation function per row.
    titer : time per iteration, used to convert the per-iteration figures
        into per-second ones.

    Returns
    -------
    DataFrame on the index of `acfs` with columns
    'iat[iter]'  - value of `xfx.misc.plot.est_int_autocor` for the row,
    'iat[sec]'   - 'iat[iter]' multiplied by `titer`,
    'rate[iter]' - 1 / (2 * 'iat[iter]'),
    'rate[sec]'  - 'rate[iter]' divided by `titer`.
    """
    iat = acfs.apply(
        lambda row: xfx.misc.plot.est_int_autocor(row.values),
        axis=1,
        raw=False,
        result_type='expand',
    ).rename('iat')

    summary = pd.DataFrame(index=acfs.index)
    summary['iat[iter]'] = iat
    summary['iat[sec]'] = summary['iat[iter]'] * titer
    summary['rate[iter]'] = 1 / (2 * summary['iat[iter]'])
    summary['rate[sec]'] = summary['rate[iter]'] / titer
    return summary

In [None]:
# config

# categorical covariates used as crossed factors, in model order
factor_names = [
    'province_id',
    'activity',
    'age',
    'education',
    'municipality_size',
    'voting_recall',
    'gender',
]
# answers kept as their own category; every other answer is pooled as '_others'
response_names = ['conservatives', 'social_democrats']
# voting intentions dropped from the data before modelling
exclude = ['abstention', 'invalid']
# seed of the shared random generator
seed = 0
# number of retained iterations per sampler (an equal number is run as burn-in)
n_samples = 10_000

In [None]:
# construct inputs

# Load the survey and keep one study wave, dropping respondents whose voting
# intention is missing or belongs to the excluded categories.
cis = pd.read_csv('paper/data/cis.csv')
is_kept = (
    (cis.study_id == '2019-11-10')
    & ~cis.voting_intention.isin(exclude)
    & cis.voting_intention.notna()
)
# .copy() detaches the subset from the raw table, so the column assignments
# below are not chained assignments on a slice (SettingWithCopyWarning).
cis = cis.loc[is_kept].copy()

# Answers outside `response_names` are pooled into a single '_others' label,
# both for the response and for the recalled vote used as a covariate.
cis['response'] = np.where(cis.voting_intention.isin(response_names), cis.voting_intention, '_others')
cis['voting_recall'] = np.where(cis.voting_recall.isin(response_names), cis.voting_recall, '_others')
cis = cis[factor_names + ['response']].dropna()

# Contingency table: one row per observed combination of factor levels, one
# column per response category. Every remaining column is a grouping key, so
# .size() counts rows per cell directly instead of aggregating over an empty
# set of value columns. fillna(0) keeps the counts as floats, as before.
cis = cis.groupby(factor_names + ['response']).size().unstack('response').fillna(0)

# Integer-encode each factor: column k of `codes` holds the category code of
# factor k for every cell of the table.
codes = cis.index.to_frame().apply(lambda x: x.astype('category').cat.codes).astype(np.int64)

indices = codes.values
response = cis['social_democrats'].values  # successes per cell
trials = cis.sum(axis=1).values  # total respondents per cell
n_levels = np.max(indices, 0) + 1  # codes are 0-based, so max + 1 per factor
rng = np.random.default_rng(seed)
prior_n_tau = np.repeat(len(response_names), len(n_levels))
gibbs_inputs = {
    'y': response,
    'n': trials,
    'j': n_levels,
    'i': indices,
    'prior_est_tau': None,
    'prior_n_tau': prior_n_tau,
}

In [None]:
# sample vanilla

vanilla_sampler = xfx.custom.binomial.sample_posterior(**gibbs_inputs, x=None, collapse=False, ome=rng)
# the first yield is consumed outside the timed section
next(vanilla_sampler)
t0 = datetime.datetime.now()
# burn-in: advance the chain n_samples times and discard the draws
for _ in range(n_samples):
    next(vanilla_sampler)
vanilla_samples = [next(vanilla_sampler) for _ in range(n_samples)]
t1 = datetime.datetime.now()
# NOTE(review): the timed span covers 2 * n_samples iterations (burn-in plus
# retained draws) but is divided by n_samples - confirm this is intended.
vanilla_titer = (t1 - t0).total_seconds() / n_samples
# keep elements 0 and 2 of every draw, dropping element 1, so that each draw
# is the pair that package_samples unpacks
vanilla_samples = [draw[:1] + draw[2:3] for draw in vanilla_samples]

In [None]:
# sample loccent

loccent_sampler = xfx.glm.binomial.sample_posterior(**gibbs_inputs, ome=rng)
# the first yield is consumed outside the timed section
next(loccent_sampler)
t0 = datetime.datetime.now()
# burn-in: advance the chain n_samples times and discard the draws
for _ in range(n_samples):
    next(loccent_sampler)
loccent_samples = [next(loccent_sampler) for _ in range(n_samples)]
t1 = datetime.datetime.now()
# NOTE(review): the timed span covers 2 * n_samples iterations (burn-in plus
# retained draws) but is divided by n_samples - confirm this is intended.
loccent_titer = (t1 - t0).total_seconds() / n_samples

In [None]:
# sample collapsed

collapsed_sampler = xfx.custom.binomial.sample_posterior(**gibbs_inputs, x=None, collapse=True, ome=rng)
# the first yield is consumed outside the timed section
next(collapsed_sampler)
t0 = datetime.datetime.now()
# burn-in: advance the chain n_samples times and discard the draws
for _ in range(n_samples):
    next(collapsed_sampler)
collapsed_samples = [next(collapsed_sampler) for _ in range(n_samples)]
t1 = datetime.datetime.now()
# NOTE(review): the timed span covers 2 * n_samples iterations (burn-in plus
# retained draws) but is divided by n_samples - confirm this is intended.
collapsed_titer = (t1 - t0).total_seconds() / n_samples
# keep elements 0 and 2 of every draw, dropping element 1, so that each draw
# is the pair that package_samples unpacks
collapsed_samples = [draw[:1] + draw[2:3] for draw in collapsed_samples]

In [None]:
# construct summaries

# Package the draws of each sampler, tagging every row with its algorithm.
vanilla_rfx_samples, vanilla_prec_samples = package_samples(vanilla_samples, factor_names, {'algo': 'PG/Van-G'})
loccent_rfx_samples, loccent_prec_samples = package_samples(loccent_samples, factor_names, {'algo': 'LC-MwG'})
collapsed_rfx_samples, collapsed_prec_samples = package_samples(collapsed_samples, factor_names, {'algo': 'PG/Col-G'})

rfx_samples = pd.concat([vanilla_rfx_samples, loccent_rfx_samples, collapsed_rfx_samples])
prec_samples = pd.concat([vanilla_prec_samples, loccent_prec_samples, collapsed_prec_samples])

# One autocorrelation function (256 lags) per parameter chain.
rfx_acf = est_acf(rfx_samples, 256)
prec_acf = est_acf(prec_samples, 256)

# The stacked frames mix three samplers with different costs per iteration,
# so the time per iteration must be supplied row by row. Passing a constant 1
# would make the '[sec]' columns identical to the '[iter]' ones.
titer_by_algo = {
    'PG/Van-G': vanilla_titer,
    'LC-MwG': loccent_titer,
    'PG/Col-G': collapsed_titer,
}
rfx_titer = rfx_acf.index.get_level_values('algo').map(titer_by_algo).to_numpy()
prec_titer = prec_acf.index.get_level_values('algo').map(titer_by_algo).to_numpy()

rfx_ess = est_ess(rfx_acf, rfx_titer)
prec_ess = est_ess(prec_acf, prec_titer)

In [None]:
# Long-format ACF table: one row per (algo, factor, level, lag).
acf = (
    pd.concat([rfx_acf, prec_acf])
    .reset_index()
    .melt(id_vars=['algo', 'factor', 'level'], var_name='lag')
)
# precision rows are indexed by algo and factor only, so their level is
# missing after the concat; give them level 0
acf['level'] = acf.level.fillna(0)
acf['factor'] = acf.factor.astype('category').cat.codes
# one integer id per (factor, level) pair, used to draw each chain as its own line
acf['group'] = (acf.factor.astype('str') + '-' + acf.level.astype('str')).astype('category').cat.codes
# convert lags to wall time with each algorithm's seconds per iteration
acf['time'] = acf.lag.astype(int) * np.select(
    [acf.algo == 'PG/Van-G', acf.algo == 'LC-MwG'],
    [vanilla_titer, loccent_titer],
    collapsed_titer,
)
ess = pd.concat([rfx_ess, prec_ess]).reset_index()
ess['level'] = ess.level.fillna(0)

algo_order = ('LC-MwG', 'PG/Van-G', 'PG/Col-G')

f, axes = plt.subplots(
    1,
    2,
    figsize=(2 * (8/5 + 4/3), 2),
    gridspec_kw={'width_ratios': [8/5, 4/3]},
)
# left panel: autocorrelation against wall time, one faint line per chain
g = sns.lineplot(
    data=acf,
    x='time',
    y='value',
    hue='algo',
    style='group',
    hue_order=algo_order,
    dashes=False,
    markers=False,
    legend=False,
    ci=None,
    alpha=1/3,
    lw=1/3,
    ax=axes[0],
)
g.set(xlabel='wall time [sec]', ylabel='ACF', xlim=(-.25, 4.25))
# right panel: distribution of the 'rate[sec]' column per algorithm
g = sns.boxplot(
    data=ess,
    y='algo',
    x='rate[sec]',
    order=algo_order,
    linewidth=1,
    fliersize=1,
    sym='o',
    ax=axes[1],
)
g.set(ylabel='$\\quad$', xlabel='ESS/sec', xscale='log')