In [None]:
# dependencies

import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import xfx.mvglm.cmult
import xfx.mvglm.fmult
import xfx.custom.symmetric_multinomial
import xfx.misc.plot

In [None]:
# helper functions

def package_samples(samples, factor_names, response_names, meta):
    """Split sampler draws into effect and precision parts and tabulate both.

    `samples` is a sequence of `(effects, precisions)` pairs, one pair per iteration. The effects
    are regrouped so that there is one stacked array per term, and each precision matrix is reduced
    to its trace. The effect table is labelled with `['_const'] + factor_names`, the precision
    table with `factor_names`.

    Returns a `(effects_frame, precisions_frame)` tuple built by `package_rfx_samples` and
    `package_prec_samples`.
    """
    rfx_draws, prec_draws = zip(*samples)
    # regroup from "per iteration, per term" to "per term, stacked over iterations"
    rfx_by_term = [np.array(term_draws) for term_draws in zip(*rfx_draws)]
    # trace over the last two axes of the stacked precisions, then transpose so terms index the rows
    prec_traces = np.trace(prec_draws, axis1=2, axis2=3).T
    rfx_frame = package_rfx_samples(rfx_by_term, ['_const'] + factor_names, response_names, meta)
    prec_frame = package_prec_samples(prec_traces, factor_names, meta)
    return rfx_frame, prec_frame

def package_rfx_samples(rfx_samples, factor_names, response_names, meta):
    """Tabulate effect draws with one row per (factor, level, response) and one column per iteration.

    Parameters
    ----------
    rfx_samples : sequence of numpy.ndarray
        One array per entry of `factor_names`, indexed as `[iteration, level, response]`.
    factor_names : list of str
        Label stored in the 'factor' index level for the corresponding array.
    response_names : list of str
        Labels for the last axis of every array, stored in the 'response' index level.
    meta : dict
        Constant values attached to every row. Every key becomes an index level, in the order the
        dict provides them, ahead of 'factor', 'level' and 'response'.

    Returns
    -------
    pandas.DataFrame
        Frame indexed by the keys of `meta` followed by 'factor', 'level', 'response'; the columns
        are the iteration numbers.
    """
    frames = []
    for factor_samples, factor_name in zip(rfx_samples, factor_names):
        for level in range(factor_samples.shape[1]):
            # slice one level and transpose so that responses index the rows
            frame = pd.DataFrame(factor_samples[:, level].T, index=response_names)
            frame.index = frame.index.rename('response')
            frame.columns = frame.columns.rename('iter')
            for key, value in meta.items():
                frame[key] = value
            frame['factor'] = factor_name
            frame['level'] = level
            frames.append(frame)
    # index on whatever metadata was supplied instead of assuming a key called 'algo', so that no
    # metadata column is left behind among the iteration columns
    return pd.concat(frames).reset_index().set_index(list(meta) + ['factor', 'level', 'response'])

def package_prec_samples(prec_samples, factor_names, meta):
    """Tabulate precision draws with one row per factor and one column per iteration.

    Parameters
    ----------
    prec_samples : array-like
        Two-dimensional values with one row per entry of `factor_names`.
    factor_names : list of str
        Row labels, stored in the 'factor' index level.
    meta : dict
        Constant values attached to every row. Every key becomes an index level, in the order the
        dict provides them, ahead of 'factor'.

    Returns
    -------
    pandas.DataFrame
        Frame indexed by the keys of `meta` followed by 'factor'; the columns are the iteration
        numbers.
    """
    df = pd.DataFrame(prec_samples, index=factor_names)
    df.index = df.index.rename('factor')
    df.columns = df.columns.rename('iter')
    for key, value in meta.items():
        df[key] = value
    # index on whatever metadata was supplied instead of assuming a key called 'algo'
    return df.reset_index().set_index(list(meta) + ['factor'])

def est_acf(samples, n_lags):
    """Estimate an autocorrelation function for every row of `samples`.

    Each row's values are handed to `xfx.misc.plot.est_acf` together with `n_lags`. The per-row
    results are expanded into columns, and the column axis is named 'lag'.
    """
    def _row_acf(row):
        return xfx.misc.plot.est_acf(row.values, n_lags)

    acf_table = samples.apply(_row_acf, axis=1, raw=False, result_type='expand')
    acf_table.columns = acf_table.columns.rename('lag')
    return acf_table

def est_ess(acfs, titer):
    """Derive autocorrelation-time and sampling-rate summaries from per-row ACFs.

    Each row of `acfs` is reduced with `xfx.misc.plot.est_int_autocor`. The result has the same
    index as `acfs` and four columns:

    - 'iat[iter]': the value returned by `est_int_autocor` for the row
    - 'iat[sec]': 'iat[iter]' multiplied by `titer`
    - 'rate[iter]': 1 / (2 * 'iat[iter]')
    - 'rate[sec]': 'rate[iter]' divided by `titer`
    """
    def _row_iat(row):
        return xfx.misc.plot.est_int_autocor(row.values)

    iat_per_iter = acfs.apply(_row_iat, axis=1, raw=False, result_type='expand').rename('iat')

    summary = pd.DataFrame(index=acfs.index)
    summary['iat[iter]'] = iat_per_iter
    summary['iat[sec]'] = summary['iat[iter]'] * titer
    summary['rate[iter]'] = 1 / (2 * summary['iat[iter]'])
    summary['rate[sec]'] = summary['rate[iter]'] / titer
    return summary

In [None]:
# config

# columns of the survey data used as grouping factors
factor_names = [
    'province_id',
    'activity',
    'age',
    'education',
    'municipality_size',
    'voting_recall',
    'gender',
]
# answers kept as their own category; every other answer is pooled into '_others'
response_names = [
    'conservatives',
    'social_democrats',
]
# voting intentions removed from the data before modelling
exclude = [
    'abstention',
    'invalid',
]
# seed of the random generator shared by the samplers
seed = 0
# draws kept per sampler (twice as many are generated, the first half is discarded)
n_samples = 10_000

In [None]:
# construct inputs

cis = pd.read_csv('paper/data/cis.csv')
# keep a single study and drop excluded or missing voting intentions; take a copy so that the
# column assignments below write to an independent frame rather than to a slice of the raw data
cis = cis.loc[(cis.study_id == '2019-11-10') & (~cis.voting_intention.isin(exclude)) & (~cis.voting_intention.isna())].copy()
# pool every answer outside `response_names` into a single '_others' category
cis['response'] = np.where(cis.voting_intention.isin(response_names), cis.voting_intention, '_others')
cis['voting_recall'] = np.where(cis.voting_recall.isin(response_names), cis.voting_recall, '_others')
cis = cis[factor_names + ['response']].dropna()
# count respondents per factor combination (rows) and response (columns); combinations that were
# never observed for a response are filled with zero
# NOTE(review): unstack appears to order the response columns by label ('_others' first), which is
# not the order of `response_names` - confirm this matches the response labels attached downstream
cis = cis.groupby(factor_names + ['response']).size().unstack('response').fillna(0)
# integer-code the levels of every factor
codes = cis.index.to_frame().apply(lambda x: x.astype('category').cat.codes).astype(np.int64)

indices = codes.values
response = cis.values
trials = cis.sum(1).values
n_levels = np.max(indices, 0) + 1
rng = np.random.default_rng(seed)

inputs = {'y': response, 'j': n_levels, 'i': indices}

In [None]:
# sample constrained ('const/LC-MwG')

constlc_sampler = xfx.mvglm.cmult.sample_posterior(**inputs, ome=rng)
next(constlc_sampler)  # advance once before the clock starts
t0 = datetime.datetime.now()
# discard the first n_samples draws, keep the next n_samples
for _ in range(n_samples):
    next(constlc_sampler)
constlc_samples = [next(constlc_sampler) for _ in range(n_samples)]
t1 = datetime.datetime.now()
# NOTE(review): the timed section advances the sampler 2 * n_samples times but the elapsed time is
# divided by n_samples - confirm this is the intended per-iteration cost
constlc_titer = (t1 - t0).total_seconds() / n_samples

In [None]:
# sample unconstrained ('free/LC-MwG')

freelc_sampler = xfx.mvglm.fmult.sample_posterior(**inputs, ome=rng)
next(freelc_sampler)  # advance once before the clock starts
t0 = datetime.datetime.now()
# discard the first n_samples draws, keep the next n_samples
for _ in range(n_samples):
    next(freelc_sampler)
freelc_samples = [next(freelc_sampler) for _ in range(n_samples)]
t1 = datetime.datetime.now()
# NOTE(review): the timed section advances the sampler 2 * n_samples times but the elapsed time is
# divided by n_samples - confirm this is the intended per-iteration cost
freelc_titer = (t1 - t0).total_seconds() / n_samples

In [None]:
# sample unconstrained ('free/PLC-MwG')

freeplc_sampler = xfx.custom.symmetric_multinomial.sample_posterior(**inputs, ome=rng)
next(freeplc_sampler)  # advance once before the clock starts
t0 = datetime.datetime.now()
# discard the first n_samples draws, keep the next n_samples
for _ in range(n_samples):
    next(freeplc_sampler)
freeplc_samples = [next(freeplc_sampler) for _ in range(n_samples)]
t1 = datetime.datetime.now()
# NOTE(review): the timed section advances the sampler 2 * n_samples times but the elapsed time is
# divided by n_samples - confirm this is the intended per-iteration cost
freeplc_titer = (t1 - t0).total_seconds() / n_samples

In [None]:
# construct summaries

# the unconstrained samplers are labelled with one more response than the constrained one
# NOTE(review): these labels are attached by position - confirm their order matches the response
# columns that were passed to the samplers
free_response_names = response_names + ['others']

freelc_rfx_samples, freelc_prec_samples = package_samples(
    freelc_samples, factor_names, free_response_names, {'algo': 'free/LC-MwG'})
constlc_rfx_samples, constlc_prec_samples = package_samples(
    constlc_samples, factor_names, response_names, {'algo': 'const/LC-MwG'})
freeplc_rfx_samples, freeplc_prec_samples = package_samples(
    freeplc_samples, factor_names, free_response_names, {'algo': 'free/PLC-MwG'})

rfx_samples = pd.concat([freelc_rfx_samples, constlc_rfx_samples, freeplc_rfx_samples])
prec_samples = pd.concat([freelc_prec_samples, constlc_prec_samples, freeplc_prec_samples])

rfx_acf = est_acf(rfx_samples, 256)
prec_acf = est_acf(prec_samples, 256)

# the per-second columns are computed with a unit time here and overwritten just below
rfx_ess = est_ess(rfx_acf, 1)
prec_ess = est_ess(prec_acf, 1)

# rescale the per-second columns with the time per iteration of the algorithm behind each row
for ess_ in (rfx_ess, prec_ess):
    algo_ = ess_.reset_index().algo
    titer_ = np.select(
        [algo_ == 'free/LC-MwG', algo_ == 'const/LC-MwG'],
        [freelc_titer, constlc_titer],
        freeplc_titer,
    )
    ess_['iat[sec]'] = ess_['iat[iter]'] * titer_
    ess_['rate[sec]'] = ess_['rate[iter]'] / titer_

In [None]:
# The effect tables are indexed by (algo, factor, level, response) while the precision tables are
# indexed by (algo, factor) only. Turn the index levels into ordinary columns before concatenating
# so the rows line up by column name; precision rows then simply have missing level and response.
acf = pd.concat([rfx_acf.reset_index(), prec_acf.reset_index()], ignore_index=True).melt(id_vars=['algo', 'factor', 'level', 'response'], var_name='lag')
acf['level'] = acf.level.fillna(0)
acf['factor'] = acf.factor.astype('category').cat.codes
# one group id per (factor, level) pair, used to draw a separate line for each pair
acf['group'] = (acf.factor.astype('str') + '-' + acf.level.astype('str')).astype('category').cat.codes
# convert lags to seconds with the time per iteration of the algorithm behind each row
acf['time'] = acf.lag.astype(int) * np.select([acf.algo == 'free/LC-MwG', acf.algo == 'const/LC-MwG'], [freelc_titer, constlc_titer], freeplc_titer)
ess = pd.concat([rfx_ess.reset_index(), prec_ess.reset_index()], ignore_index=True)
ess['level'] = ess.level.fillna(0)

# left panel: ACF against wall time; right panel: distribution of the per-second rate by algorithm
f, axes = plt.subplots(1, 2, figsize=(2 * (8/5 + 4/3), 2), gridspec_kw={'width_ratios': [8/5, 4/3]})
g = sns.lineplot(data=acf, x='time', y='value', hue='algo', style='group', hue_order=('const/LC-MwG', 'free/LC-MwG', 'free/PLC-MwG'), dashes=False, markers=False, legend=False, alpha=1/3, lw=1/3, ax=axes[0])
g.set(xlabel='wall time [sec]', ylabel='ACF', xlim=(-.25, 4.25))
# NOTE(review): `sym` is not a documented seaborn argument and is presumably forwarded to
# matplotlib - confirm the installed seaborn version still accepts it
g = sns.boxplot(data=ess, y='algo', x='rate[sec]', order=('const/LC-MwG', 'free/LC-MwG', 'free/PLC-MwG'), linewidth=1, fliersize=1, sym='o', ax=axes[1])
g.set(ylabel='$\\quad$', xlabel='ESS/sec', xscale='log')