In [None]:
import os
import sys

# Put the notebook's parent directory on the import path so that
# `salk_turnout_models` resolves when running from this folder.
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
import arviz as az
import pandas as pd
import numpy as np
import pathlib
import plotnine as p9
import salk_turnout_models as tm

## Preprocess data

### Census data

In [None]:
def filter_census_data(df, name, cond):
    """Keep the rows of ``df`` where ``cond`` holds and report how many persons were dropped.

    Args:
        df: census frame with a person-count column ``'n'``.
        name: label printed in the report (typically the filter condition).
        cond: boolean mask aligned with ``df``.

    Returns:
        The filtered frame.
    """
    new_df = df[cond]
    removed = df['n'].sum() - new_df['n'].sum()
    # Computed outside the f-string: reusing the enclosing quote character inside a
    # replacement field is a SyntaxError before Python 3.12 (PEP 701).
    print(f'{name}: Removed {removed} persons')
    return new_df

# Estonian column headers of the census sheet -> short working names.
col_name_map = {
    'Vanuserühm': 'VANUS',
    'Sugu': 'SUGU',
    'Rahvus': 'RAHVUS_3',
    'Maakonna/Tallinna kood': 'CODE',
    'Maakond': 'MK_NIMI',
    'Piirkond': 'Piirkonna_NIMI',
    'Haridustase': 'HARIDUS_4',
    'Arv': 'n',
}

census_raw_df = pd.read_excel(
    '../data/census_2022.xlsx',
    sheet_name='Eesti kodanikud',
    skiprows=5,
    usecols=col_name_map.keys(),
)
estonia_census_df = census_raw_df.rename(columns=col_name_map)

# The county column lists Tallinn on its own; fold it into Harju county.
mk_map = {
    'Tallinn': 'Harju maakond',
}
estonia_census_df['MK_NIMI'] = estonia_census_df['MK_NIMI'].map(lambda value: mk_map.get(value, value))

# Drop the 0-15 band, then turn 16-24 into 18-24 by keeping 7 of its 9 single-year
# ages (16..24 minus 16 and 17) -- assumes counts are spread evenly over those ages.
vanus_map = { '16-24': '18-24' }
estonia_census_df = filter_census_data(estonia_census_df, 'VANUS != 0-15', estonia_census_df['VANUS'] != '0-15')
is_16_24 = estonia_census_df['VANUS'] == '16-24'
estonia_census_df.loc[is_16_24, 'n'] = (
    estonia_census_df.loc[is_16_24, 'n']
    .apply(lambda count: count * (9-2)/9)
    .round(0)
    .astype(int)
)
estonia_census_df['VANUS'] = estonia_census_df['VANUS'].map(lambda value: vanus_map.get(value, value))

# Drop persons whose education or nationality is unknown.
unknown_filters = [
    ('HARIDUS_4 != Haridustase teadmata', 'HARIDUS_4', 'Haridustase teadmata'),
    ('RAHVUS_3 != Rahvus teadmata', 'RAHVUS_3', 'Rahvus teadmata'),
]
for label, column, unknown_value in unknown_filters:
    estonia_census_df = filter_census_data(estonia_census_df, label, estonia_census_df[column] != unknown_value)

estonia_census_df.to_csv('../data/census_base.csv', index=False)
print(estonia_census_df['n'].sum())
estonia_census_df.head()

In [None]:
census_base_df = pd.read_csv('../data/census_base.csv')

# Output column -> source column in census_base.csv. 'electoral_district' and 'unit'
# are both derived from the same region column, at different granularity.
census_col_names = {
    'age_group': 'VANUS',
    'education': 'HARIDUS_4',
    'gender': 'SUGU',
    'nationality': 'RAHVUS_3',
    'electoral_district': 'Piirkonna_NIMI',
    'unit': 'Piirkonna_NIMI',
    'N': 'n',
}

census_dtype = {
    'age_group': 'category',
    'education': 'category',
    'gender': 'category',
    'nationality': 'category',
    'electoral_district': 'category',
    'unit': 'category',
    'voting_intent': 'category',
    'N': 'int64',
}

census_cats = {
    'age_group': ['18-24', '25-34', '35-44', '45-54', '55-64', '65-74', '75+'],
    'education': ['Basic education', 'Secondary education', 'Higher education'],
    'gender': ['Male', 'Female'],
    'nationality': ['Estonian', 'Other'],
    'electoral_district': ['Haabersti, Põhja-Tallinn ja Kristiine', 'Harju- ja Raplamaa', 'Hiiu-, Lääne- ja Saaremaa',
                           'Ida-Virumaa', 'Järva- ja Viljandimaa', 'Jõgeva- ja Tartumaa', 'Kesklinn, Lasnamäe ja Pirita',
                           'Lääne-Virumaa', 'Mustamäe ja Nõmme', 'Pärnumaa', 'Tartu linn', 'Võru-, Valga- ja Põlvamaa'],
    'unit': ['Haabersti', 'Harjumaa', 'Hiiumaa', 'Ida-Virumaa', 'Järvamaa', 'Jõgevamaa', 'Kesklinn', 'Kristiine',
             'Lasnamäe', 'Lääne-Virumaa', 'Läänemaa', 'Mustamäe', 'Nõmme', 'Pirita', 'Pärnumaa', 'Põhja-Tallinn',
             'Põlvamaa', 'Raplamaa', 'Saaremaa', 'Tartu linn', 'Tartumaa', 'Valgamaa', 'Viljandimaa', 'Võrumaa'],
}

# Source labels (Estonian) -> output labels, per output column.
census_map = {
    'education': {
        'Keskharidus või kutseharidus/keskeriharidus keskhariduse baasil': 'Secondary education',
        'Kõrgharidus': 'Higher education',
        'Põhiharidus või madalam': 'Basic education',
    },
    'gender': {
        'Mehed': 'Male',
        'Naised': 'Female',
    },
    'nationality': {
        'Eestlased': 'Estonian',
        'Muud rahvused': 'Other',
    },
    'electoral_district': {
        'Haabersti linnaosa': 'Haabersti, Põhja-Tallinn ja Kristiine',
        'Kesklinna linnaosa': 'Kesklinn, Lasnamäe ja Pirita',
        'Kristiine linnaosa': 'Haabersti, Põhja-Tallinn ja Kristiine',
        'Lasnamäe linnaosa': 'Kesklinn, Lasnamäe ja Pirita',
        'Mustamäe linnaosa': 'Mustamäe ja Nõmme',
        'Nõmme linnaosa': 'Mustamäe ja Nõmme',
        'Pirita linnaosa': 'Kesklinn, Lasnamäe ja Pirita',
        'Tartu linn': 'Tartu linn',
        'Põhja-Tallinna linnaosa': 'Haabersti, Põhja-Tallinn ja Kristiine',
        'Harju maakond': 'Harju- ja Raplamaa',
        # Same region labels as handled in census_map['unit']; without these, such rows
        # would get a missing electoral district.
        'Harju maakond, v.a Tallinn': 'Harju- ja Raplamaa',
        'Lääne-Viru maakond': 'Lääne-Virumaa',
        'Ida-Viru maakond': 'Ida-Virumaa',
        'Pärnu maakond': 'Pärnumaa',
        'Lääne maakond': 'Hiiu-, Lääne- ja Saaremaa',
        'Hiiu maakond': 'Hiiu-, Lääne- ja Saaremaa',
        'Saare maakond': 'Hiiu-, Lääne- ja Saaremaa',
        'Jõgeva maakond': 'Jõgeva- ja Tartumaa',
        'Järva maakond': 'Järva- ja Viljandimaa',
        'Viljandi maakond': 'Järva- ja Viljandimaa',
        'Rapla maakond': 'Harju- ja Raplamaa',
        'Tartu maakond': 'Jõgeva- ja Tartumaa',
        'Tartu maakond, v.a Tartu linn': 'Jõgeva- ja Tartumaa',
        'Võru maakond': 'Võru-, Valga- ja Põlvamaa',
        'Valga maakond': 'Võru-, Valga- ja Põlvamaa',
        'Põlva maakond': 'Võru-, Valga- ja Põlvamaa',
    },
    'unit': {
        'Haabersti linnaosa': 'Haabersti',
        'Kesklinna linnaosa': 'Kesklinn',
        'Kristiine linnaosa': 'Kristiine',
        'Lasnamäe linnaosa': 'Lasnamäe',
        'Mustamäe linnaosa': 'Mustamäe',
        'Nõmme linnaosa': 'Nõmme',
        'Pirita linnaosa': 'Pirita',
        'Tartu linn': 'Tartu linn',
        'Põhja-Tallinna linnaosa': 'Põhja-Tallinn',
        'Harju maakond': 'Harjumaa',
        'Harju maakond, v.a Tallinn': 'Harjumaa',
        'Lääne-Viru maakond': 'Lääne-Virumaa',
        'Ida-Viru maakond': 'Ida-Virumaa',
        'Pärnu maakond': 'Pärnumaa',
        'Lääne maakond': 'Läänemaa',
        'Hiiu maakond': 'Hiiumaa',
        'Saare maakond': 'Saaremaa',
        'Jõgeva maakond': 'Jõgevamaa',
        'Järva maakond': 'Järvamaa',
        'Viljandi maakond': 'Viljandimaa',
        'Rapla maakond': 'Raplamaa',
        'Tartu maakond': 'Tartumaa',
        'Tartu maakond, v.a Tartu linn': 'Tartumaa',
        'Võru maakond': 'Võrumaa',
        'Valga maakond': 'Valgamaa',
        'Põlva maakond': 'Põlvamaa',
    },
}


def _report_lost_values(col, before, after):
    """Print the distinct non-missing ``before`` values that became missing in ``after``.

    Both ``dict``-based ``Series.map`` and ``pd.Categorical`` turn unknown labels into
    NaN without complaint; this makes such losses visible.
    """
    lost_mask = pd.isna(np.asarray(after)) & pd.notna(np.asarray(before))
    if lost_mask.any():
        lost_values = sorted(pd.Series(np.asarray(before)[lost_mask]).astype(str).unique())
        print(f'Warning: {col}: no mapping/category for {lost_values}')


census_data = {}
for k, v in census_col_names.items():
    data = census_base_df[v].copy()

    if k in census_map:
        mapped = data.map(census_map[k])
        _report_lost_values(k, data, mapped)
        data = mapped

    if k in census_dtype:
        if census_dtype[k] == 'category' and k in census_cats:
            categorical = pd.Categorical(data, categories=census_cats[k])
            _report_lost_values(k, data, categorical)
            data = categorical
        else:
            data = data.astype(census_dtype[k])

    census_data[k] = data

estonia_census_df = pd.DataFrame(census_data)
estonia_census_df.to_csv('../data/census.csv', index=False)
print(estonia_census_df.shape, estonia_census_df['N'].sum())
assert estonia_census_df.duplicated(subset=['age_group', 'gender', 'nationality', 'education', 'electoral_district', 'unit'], keep=False).sum() == 0
estonia_census_df.head()

In [None]:
# Census persons per age group. observed=False is explicit: it keeps empty categories
# and avoids pandas' FutureWarning about the default for categorical groupers.
estonia_census_df.groupby('age_group', observed=False)['N'].sum()

In [None]:
# Census persons per gender (observed=False made explicit for the categorical column).
estonia_census_df.groupby('gender', observed=False)['N'].sum()

In [None]:
# Census persons per unit (observed=False made explicit for the categorical column).
estonia_census_df.groupby('unit', observed=False)['N'].sum()

In [None]:
# Smallest and largest non-empty cell over the five model input variables.
# observed=False (the current default, made explicit) builds the full category
# cross-product; empty combinations are dropped by the N > 0 filter.
tmp_df = estonia_census_df.groupby(['age_group', 'gender', 'nationality', 'education', 'unit'], observed=False)['N'].sum().reset_index()
tmp_df = tmp_df[tmp_df.N > 0]
tmp_df.N.agg(['min', 'max'])

### Survey data

In [None]:
# Column dtypes for the raw survey export; 'date' is parsed via parse_dates below.
rk2023_survey_dtype = dict(
    methods='category',
    wave='int64',
    age_group='category',
    education='category',
    gender='category',
    nationality='category',
    electoral_district='category',
    unit='category',
    voting_intent='category',
)
rk2023_survey_base_df = pd.read_csv(
    '../data/estonia_turnout/survey_base.csv',
    dtype=rk2023_survey_dtype,
    parse_dates=['date'],
)
rk2023_survey_base_df.head()

In [None]:
# Earliest and latest survey date.
survey_dates = rk2023_survey_base_df['date']
survey_dates.min(), survey_dates.max()

In [None]:
# Respondents per survey wave.
rk2023_survey_base_df['wave'].value_counts()

In [None]:
# Respondents per survey method.
rk2023_survey_base_df['methods'].value_counts()

In [None]:
# Share of each voting intent. value_counts drops missing values by default while the
# denominator counts every row, so missing answers lower all shares.
rk2023_survey_base_df['voting_intent'].value_counts().div(len(rk2023_survey_base_df))

In [None]:
# Keep only the columns the models consume and write the slimmed survey file.
survey_cols = ['age_group', 'education', 'gender', 'nationality', 'electoral_district', 'unit', 'voting_intent']
survey_export_df = rk2023_survey_base_df.loc[:, survey_cols]
survey_export_df.to_csv('../data/estonia_turnout/survey.csv', index=False)

In [None]:
# NOTE(review): not referenced anywhere else in this notebook.
target_column: str = 'age_group'



In [None]:
# Reload the slimmed survey file with the same dtypes as the base export.
rk2023_survey_df = pd.read_csv(
    '../data/estonia_turnout/survey.csv',
    dtype=rk2023_survey_dtype,
)
rk2023_survey_df.head()

### Margins data

In [None]:
def turnout_to_margins(infile, outfile, census_df, category_col, category_map=None):
    """Turn per-category turnout rates into Yes/No person counts and save them.

    The turnout CSV must contain ``category_col`` and a 'turnout' column. Each
    category's census total (sum of ``census_df['N']``) is split into
    ``Yes = round(turnout * total)`` and ``No = total - Yes``.

    Args:
        infile: path of the turnout CSV.
        outfile: path the long-format margins CSV is written to.
        census_df: census cells with ``category_col`` and a person-count column 'N'.
        category_col: column shared by the turnout file and ``census_df``.
        category_map: optional mapping applied to the turnout file's category labels
            (labels missing from the mapping become NaN).

    Returns:
        Long DataFrame with columns ``category_col``, 'voting_intent' ('Yes'/'No')
        and 'N', sorted by category and then outcome.

    Raises:
        AssertionError: if the counts do not add up to the census total (e.g. a
            census category is absent from the turnout file); nothing is written then.
        ValueError: if a turnout category has no census total (NaN cannot become int).
    """
    turnout_df = pd.read_csv(infile)

    if category_map is not None:
        turnout_df[category_col] = turnout_df[category_col].map(category_map)

    turnout_categories = turnout_df[category_col].unique().tolist()
    census_categories = census_df[category_col].unique().tolist()
    missing_categories = set(census_categories) - set(turnout_categories)
    if len(missing_categories) > 0:
        print(f'Warning: {len(missing_categories)} categories in census not present in turnout data: {missing_categories}')
        print('Turnout categories:', turnout_categories)
        print('Census categories:', census_categories)

    totals = census_df.groupby(category_col, observed=False)['N'].sum().rename('total')
    turnout_df = pd.merge(turnout_df, totals, on=category_col, how='left')
    turnout_df['Yes'] = (turnout_df['turnout'] * turnout_df['total']).round(0).astype(int)
    # 'No' is the remainder, so Yes + No == total exactly. Rounding turnout*total and
    # (1-turnout)*total separately can be off by one (1.5 and 1.5 both round to 2).
    turnout_df['No'] = (turnout_df['total'] - turnout_df['Yes']).astype(int)
    margin_df = pd.melt(turnout_df, id_vars=category_col, value_vars=['Yes', 'No'], var_name='voting_intent', value_name='N').sort_values(by=[category_col, 'voting_intent'])

    # Validate before writing so an inconsistent margins file is never saved.
    assert margin_df['N'].sum() == census_df['N'].sum(), f'{margin_df["N"].sum()} != {census_df["N"].sum()}'
    margin_df.to_csv(outfile, index=False)

    return margin_df

In [None]:
# Region labels used in the turnout files -> unit names used in the census data.
unit_map = {
    # Tallinn city districts
    'Haabersti linnaosa': 'Haabersti',
    'Kristiine linnaosa': 'Kristiine',
    'Põhja-Tallinna linnaosa': 'Põhja-Tallinn',
    'Kesklinna linnaosa': 'Kesklinn',
    'Lasnamäe linnaosa': 'Lasnamäe',
    'Pirita linnaosa': 'Pirita',
    'Mustamäe linnaosa': 'Mustamäe',
    'Nõmme linnaosa': 'Nõmme',
    # Counties (and Tartu city)
    'Harju maakond': 'Harjumaa',
    # Alternative label also handled by census_map['unit'].
    'Harju maakond, v.a Tallinn': 'Harjumaa',
    'Rapla maakond': 'Raplamaa',
    'Hiiu maakond': 'Hiiumaa',
    'Lääne maakond': 'Läänemaa',
    'Saare maakond': 'Saaremaa',
    'Lääne-Viru maakond': 'Lääne-Virumaa',
    'Ida-Viru maakond': 'Ida-Virumaa',
    'Järva maakond': 'Järvamaa',
    'Viljandi maakond': 'Viljandimaa',
    'Jõgeva maakond': 'Jõgevamaa',
    'Tartu maakond': 'Tartumaa',
    # Alternative label also handled by census_map['unit'].
    'Tartu maakond, v.a Tartu linn': 'Tartumaa',
    'Tartu linn': 'Tartu linn',
    'Põlva maakond': 'Põlvamaa',
    'Valga maakond': 'Valgamaa',
    'Võru maakond': 'Võrumaa',
    'Pärnu maakond': 'Pärnumaa',
}

In [None]:
rk2023_voters_df = pd.read_csv('../data/estonia_turnout/rk2023.csv.gz', dtype={'value': int})
# Sum of 'value' per county, as a two-column frame (unit, voters).
# NOTE(review): rk2023_unit_voters_df is not used elsewhere in this notebook.
rk2023_unit_voters_df = (
    rk2023_voters_df
    .groupby('county')['value']
    .sum()
    .reset_index()
    .rename(columns={'county': 'unit', 'value': 'voters'})
)

In [None]:
# Unit-level Yes/No margins; turnout-file region labels are translated via unit_map.
unit_turnout_df = turnout_to_margins(
    '../data/estonia_turnout/rk2023_unit_turnout.csv',
    '../data/estonia_turnout/rk2023_unit_margins.csv',
    estonia_census_df,
    'unit',
    category_map=unit_map,
)

In [None]:
# Overall turnout implied by the unit margins: 'Yes' persons over all persons.
is_yes = unit_turnout_df['voting_intent'] == 'Yes'
unit_turnout_df.loc[is_yes, 'N'].sum() / unit_turnout_df['N'].sum()

In [None]:
# Age-group Yes/No margins (labels already match the census, so no mapping).
turnout_to_margins(
    '../data/estonia_turnout/rk2023_age_group_turnout.csv',
    '../data/estonia_turnout/rk2023_age_group_margins.csv',
    census_df=estonia_census_df,
    category_col='age_group',
)

In [None]:
# Gender Yes/No margins (labels already match the census, so no mapping).
turnout_to_margins(
    '../data/estonia_turnout/rk2023_gender_turnout.csv',
    '../data/estonia_turnout/rk2023_gender_margins.csv',
    census_df=estonia_census_df,
    category_col='gender',
)

## Run turnout models

In [None]:
# Where fitted models are stored, and the sampler settings passed to tm.run_model.
model_path_prefix = '../tmp/estonia_turnout'
sample_kwargs = dict(chains=4, tune=1000, draws=500)

# Predictors shared by every model, with interactions switched on for all of them.
input_vars = ['age_group', 'education', 'gender', 'nationality', 'unit']
interactions = True

In [None]:
def run_model(config, sample_kwargs, model_path_prefix):
    """Return posterior draws for ``config``, re-using cached draws when present.

    Draws are looked up at ``<model_path_prefix>/<config['name']>/draws.parquet``.
    When that file is absent the model is fitted with ``tm.run_model`` (given the
    model directory as ``save_path``) and the mean R-hat of the fit is printed.
    NOTE(review): the cache only takes effect if tm.run_model writes draws.parquet
    into save_path -- not visible from here.
    """
    model_dir = pathlib.Path(model_path_prefix) / config['name']
    cached_draws = model_dir / 'draws.parquet'
    if cached_draws.exists():
        return pd.read_parquet(cached_draws)

    fit = tm.run_model(config, sample_kwargs=sample_kwargs, save_path=model_dir)
    draws = fit['draws']

    rhat_mean = az.summary(fit['idata']).r_hat.mean().item()
    print('Mean R-hat:', rhat_mean)

    return draws

In [None]:
# Model 1 (BP): population + survey, no margin file.
bp_config = dict(
    name='1_bp',
    model_type='BP',
    outcome_col='voting_intent',
    input_cols=input_vars,
    interactions=interactions,
    population='../data/census.csv',
    survey='../data/estonia_turnout/survey.csv',
)

bp_model_df = run_model(bp_config, sample_kwargs, model_path_prefix)

In [None]:
# Model 2 (EI): population + unit margins, no survey.
ei_config = dict(
    name='2_ei',
    model_type='EI',
    outcome_col='voting_intent',
    input_cols=input_vars,
    interactions=interactions,
    population='../data/census.csv',
    margin='../data/estonia_turnout/rk2023_unit_margins.csv',
)

ei_model_df = run_model(ei_config, sample_kwargs, model_path_prefix)

In [None]:
# Model 3 (GG): population + unit margins + survey.
gg_config = dict(
    name='3_gg',
    model_type='GG',
    outcome_col='voting_intent',
    input_cols=input_vars,
    interactions=interactions,
    population='../data/census.csv',
    margin='../data/estonia_turnout/rk2023_unit_margins.csv',
    survey='../data/estonia_turnout/survey.csv',
)

gg_model_df = run_model(gg_config, sample_kwargs, model_path_prefix)

In [None]:
# Model 4 (PM): population + unit margins + survey.
pm_config = dict(
    name='4_pm',
    model_type='PM',
    outcome_col='voting_intent',
    input_cols=input_vars,
    interactions=interactions,
    population='../data/census.csv',
    margin='../data/estonia_turnout/rk2023_unit_margins.csv',
    survey='../data/estonia_turnout/survey.csv',
)

pm_model_df = run_model(pm_config, sample_kwargs, model_path_prefix)

In [None]:
# Model 5 (FS): population + unit margins + survey.
fs_config = dict(
    name='5_fs',
    model_type='FS',
    outcome_col='voting_intent',
    input_cols=input_vars,
    interactions=interactions,
    population='../data/census.csv',
    margin='../data/estonia_turnout/rk2023_unit_margins.csv',
    survey='../data/estonia_turnout/survey.csv',
    # imr=True,  # Use inverse Mill's ratio approximation to improve sampling
)

fs_model_df = run_model(fs_config, sample_kwargs, model_path_prefix)

## Results

In [None]:
# Yes/No person totals in the gender margins file.
tmp = pd.read_csv('../data/estonia_turnout/rk2023_gender_margins.csv')
tmp.groupby(by='voting_intent').N.sum()

In [None]:
# Yes/No person totals in the age-group margins file.
tmp = pd.read_csv('../data/estonia_turnout/rk2023_age_group_margins.csv')
tmp.groupby(by='voting_intent').N.sum()

In [None]:
# Yes/No person totals in the unit margins file.
tmp = pd.read_csv('../data/estonia_turnout/rk2023_unit_margins.csv')
tmp.groupby(by='voting_intent').N.sum()

In [None]:
tmp_figures_prefix = '../tmp/figures'
# Output folder for the saved figures; idempotent, so safe on re-runs.
figures_dir = pathlib.Path(tmp_figures_prefix)
figures_dir.mkdir(exist_ok=True, parents=True)

In [None]:
def df_margins(pdf, columns, outcome):
    """Joint proportions of ``outcome`` (crossed with ``columns`` if given) over all rows.

    Every count is divided by ``len(pdf)``, so the result is a share of the whole frame,
    returned as a Series named 'proportion'.
    """
    n_rows = len(pdf)
    if not columns:
        counts = pdf[outcome].value_counts()
    else:
        counts = pdf.groupby(columns, observed=False)[outcome].value_counts()
    return (counts / n_rows).rename('proportion')

def df_turnout(pdf, columns, outcome):
    """Within-group shares of ``outcome`` for the groups defined by ``columns``.

    Without ``columns`` this is the overall share of each outcome value. The result is
    a Series named 'proportion'.
    """
    if not columns:
        return (pdf[outcome].value_counts() / len(pdf)).rename('proportion')

    by_group = pdf.groupby(columns, observed=True)[outcome]
    return (by_group.value_counts() / by_group.size()).rename('proportion')

def _aggregate_cells(df, cell_cols, margin_cols):
    """Average 'N'/'N_census' per cell, then sum those cell averages per margin group.

    With an empty ``margin_cols`` a single-row frame holding the grand totals is returned.
    """
    # observed=False spells out the current pandas default (keep every category
    # combination) instead of relying on it; the implicit default is deprecated for
    # categorical groupers. Empty combinations average to NaN and then sum to 0.
    avg_df = df.groupby(cell_cols, observed=False)[['N', 'N_census']].mean().reset_index()

    if margin_cols:
        return avg_df.groupby(margin_cols, observed=False)[['N', 'N_census']].sum().reset_index()
    return pd.DataFrame({'N': avg_df['N'].sum(), 'N_census': avg_df['N_census'].sum()}, index=[0])


def _yes_no_proportions(grouped_df, margin_cols, outcome, denominator):
    """Stack 'Yes' (= N) and 'No' (= N_census - N) rows, each divided by ``denominator``.

    Returns a frame indexed by ``margin_cols + [outcome]`` with one 'proportion' column.
    """
    yes_df = grouped_df.copy()
    yes_df['proportion'] = yes_df['N'] / denominator
    yes_df[outcome] = 'Yes'

    no_df = grouped_df.copy()
    no_df['proportion'] = (no_df['N_census'] - no_df['N']) / denominator
    no_df[outcome] = 'No'

    index_cols = margin_cols + [outcome]
    return pd.concat([yes_df, no_df])[index_cols + ['proportion']].set_index(index_cols)


def cell_margins(df, cell_cols, margin_cols, outcome):
    """Model-implied joint Yes/No proportions over ``margin_cols``.

    'N' is the Yes count and 'N_census' the population of each cell. Both are first
    averaged over rows that share the same ``cell_cols`` (e.g. posterior draws), summed
    per margin group, and divided by the grand population total, so all returned
    proportions together sum to 1.

    Args:
        df: frame with ``cell_cols``, 'N' and 'N_census'.
        cell_cols: columns identifying a population cell.
        margin_cols: columns to aggregate to (list/tuple; empty or None = overall).
        outcome: name of the Yes/No column in the result.

    Returns:
        DataFrame indexed by ``margin_cols + [outcome]`` with a 'proportion' column.
    """
    margin_cols = list(margin_cols) if margin_cols else []
    grouped_df = _aggregate_cells(df, cell_cols, margin_cols)
    return _yes_no_proportions(grouped_df, margin_cols, outcome, grouped_df['N_census'].sum())


def cell_turnout(df, cell_cols, margin_cols, outcome):
    """Model-implied within-group Yes/No shares (turnout) for each ``margin_cols`` group.

    Same aggregation as ``cell_margins``, but each group's counts are divided by that
    group's own population, so 'Yes' and 'No' sum to 1 within every group.

    Args:
        df: frame with ``cell_cols``, 'N' and 'N_census'.
        cell_cols: columns identifying a population cell.
        margin_cols: columns to aggregate to (list/tuple; empty or None = overall).
        outcome: name of the Yes/No column in the result.

    Returns:
        DataFrame indexed by ``margin_cols + [outcome]`` with a 'proportion' column
        (NaN for groups with zero population).
    """
    margin_cols = list(margin_cols) if margin_cols else []
    grouped_df = _aggregate_cells(df, cell_cols, margin_cols)
    return _yes_no_proportions(grouped_df, margin_cols, outcome, grouped_df['N_census'])

def kl_divergence(margins_df, epsilon=1e-10):
    """Kullback-Leibler divergence D_KL(P || Q) between two discrete distributions.

    P is ``margins_df['proportion_pop']`` and Q is ``margins_df['proportion_mod']``.
    Both are clipped into ``[epsilon, 1]`` first so zero entries give neither a
    division by zero nor log(0). Returns a Python float.
    """
    pop = np.clip(margins_df['proportion_pop'].to_numpy().ravel(), epsilon, 1)
    mod = np.clip(margins_df['proportion_mod'].to_numpy().ravel(), epsilon, 1)
    terms = pop * np.log(pop / mod)
    return float(terms.sum())

def em_distance(margins_df, epsilon=1e-10):
    """Half the L1 distance between the population and model proportion vectors.

    For unordered categories with a 0/1 ground metric this equals the earth mover's
    (total variation) distance. Values are clipped into ``[epsilon, 1]`` like in
    ``kl_divergence``; the clipping is not required here (no division or log) and
    only shifts the result on the order of ``epsilon``. Returns a Python float.
    """
    pop = np.clip(margins_df['proportion_pop'].to_numpy().ravel(), epsilon, 1)
    mod = np.clip(margins_df['proportion_mod'].to_numpy().ravel(), epsilon, 1)
    l1 = float(np.abs(pop - mod).sum())
    return l1 / 2

def model_margin_stats(model_name, mod_df, cell_cols, outcome_col, pop_margins_map):
    """Distance between one model's implied margins and the population margins.

    For every ``margin_col -> pop_margins`` entry, the model's margin proportions are
    computed with ``cell_margins`` and outer-merged with the population table on
    ``[margin_col, outcome_col]``; entries missing on either side count as 0.

    Args:
        model_name: label stored in the 'model_name' index level.
        mod_df: model output with ``cell_cols``, 'N' and 'N_census' columns.
        cell_cols: columns identifying a population cell.
        outcome_col: name of the Yes/No outcome column (e.g. 'voting_intent').
        pop_margins_map: margin column -> population margins with a 'proportion'
            column, keyed (as columns or index levels) by that margin and ``outcome_col``.

    Returns:
        DataFrame indexed by (model_name, margin_name) with columns 'kld' and 'emd'.
    """
    results = []

    for margin_col, pop_margins in pop_margins_map.items():
        # Label the outcome with the caller's outcome_col so it matches the merge keys.
        mod_margins = cell_margins(mod_df, cell_cols, [margin_col], outcome_col).reset_index()
        margins_df = pd.merge(pop_margins, mod_margins, on=[margin_col, outcome_col], how='outer', suffixes=('_pop', '_mod')).fillna(0)
        index = pd.MultiIndex.from_arrays([[model_name], [margin_col]], names=['model_name', 'margin_name'])
        results.append(pd.DataFrame({
            'kld': kl_divergence(margins_df),
            'emd': em_distance(margins_df),
        }, index=index))

    return pd.concat(results)

def model_turnout_stats(model_name, mod_df, cell_cols, turnout_cols, outcome_col):
    """Model-implied 'Yes' turnout within each category of every column in ``turnout_cols``.

    Returns a long frame with columns 'category', 'turnout', 'margin_name' and
    'model_name', one row per category of each turnout column.
    """
    frames = []
    for margin in turnout_cols:
        shares = cell_turnout(mod_df, cell_cols, [margin], outcome_col).reset_index()
        yes_rows = shares.loc[shares[outcome_col] == 'Yes']
        yes_rows = yes_rows.drop(columns=[outcome_col]).rename(columns={margin: 'category', 'proportion': 'turnout'})
        yes_rows['margin_name'] = margin
        yes_rows['model_name'] = model_name
        frames.append(yes_rows)
    return pd.concat(frames)

def map_var(df, col, var_map):
    """Relabel ``df[col]`` through ``var_map`` and make it categorical.

    Rows whose value is not a key of ``var_map`` are dropped. The categories follow
    the order of ``var_map``'s values (which drives plot ordering). A new frame is
    returned; ``df`` itself is left untouched.
    """
    keep = df[col].isin(list(var_map))
    relabelled = df.loc[keep].copy()
    category_type = pd.CategoricalDtype(categories=list(var_map.values()))
    relabelled[col] = relabelled[col].map(var_map).astype(category_type)
    return relabelled

In [None]:
def _load_pop_margins(path, col):
    """Read a Yes/No margins CSV indexed by (col, voting_intent).

    Adds 'proportion' (share of all persons) and 'turnout' (share within ``col``'s
    category), and drops the raw count column 'N'.
    """
    margins = pd.read_csv(path).set_index([col, 'voting_intent'])
    margins['proportion'] = margins['N'] / margins['N'].sum()
    margins['turnout'] = margins['N'] / margins.groupby(col)['N'].sum()
    return margins.drop(columns=['N'])


age_group_margins = _load_pop_margins('../data/estonia_turnout/rk2023_age_group_margins.csv', 'age_group')
gender_margins = _load_pop_margins('../data/estonia_turnout/rk2023_gender_margins.csv', 'gender')

In [None]:
pop_margins_map = {
    'gender': gender_margins,
    'age_group': age_group_margins,
}

# Model draws in the order they should appear in the stats table.
model_draws = {
    'bp_model': bp_model_df,
    'ei_model': ei_model_df,
    'gg_model': gg_model_df,
    'pm_model': pm_model_df,
    'fs_model': fs_model_df,
}
model_stats_df = pd.concat([
    model_margin_stats(name, draws, input_vars, 'voting_intent', pop_margins_map)
    for name, draws in model_draws.items()
]).reset_index()

model_stats_long_df = pd.melt(model_stats_df, id_vars=['model_name', 'margin_name'], value_vars=['kld', 'emd'], var_name='variable', value_name='value')

# Display labels; map_var also uses their order as the category order.
model_name_map = {
    'bp_model': 'BP model',
    'ei_model': 'EI model',
    'gg_model': 'GG model',
    'pm_model': 'PM model',
    'fs_model': 'FS model',
}

margin_name_map = {
    'age_group': 'Age',
    'gender': 'Gender',
}

variable_map = {
    'kld': 'Margin $D_{KL}$',
    'emd': 'Margin $D_{EM}$',
}

for column, labels in [('model_name', model_name_map), ('margin_name', margin_name_map), ('variable', variable_map)]:
    model_stats_long_df = map_var(model_stats_long_df, column, labels)

In [None]:
# One fixed colour per model, shared by all figures in this section.
model_color_map = {
    'BP model': '#e41a1c',
    'EI model': '#377eb8',
    'GG model': '#4daf4a',
    'PM model': '#984ea3',
    'FS model': '#ff7f00',
}

# Both distance metrics per margin, one dodged bar per model; y scales are free per metric.
p = (
    p9.ggplot(model_stats_long_df, p9.aes(x='margin_name', y='value', fill='model_name')) +
    p9.geom_bar(stat='identity', position='dodge') +
    p9.facet_wrap('~variable', scales='free_y') +
    p9.scale_fill_manual(breaks=list(model_color_map.keys()), values=list(model_color_map.values())) +
    p9.theme_minimal() +
    p9.labs(x='Margin group', y='Distance metric', fill='Model') +
    p9.theme(
        axis_text_x=p9.element_text(angle=90),
        figure_size=(5, 5.0*2.0/3),
        dpi=300,
        legend_position='bottom',
        panel_background=p9.element_rect(fill='white',color='white'),
        plot_background=p9.element_rect(fill='white',color='white')
    ) +
    p9.guides(fill=p9.guide_legend(nrow=2))
)

p.save(f'{tmp_figures_prefix}/estonia-turnout-metrics.png')
p

In [None]:
# Same metric chart with the BP model left out.
# NOTE(review): the plotting code duplicates the previous cell; a shared helper would avoid drift.
is_bp = model_stats_long_df['model_name'] == 'BP model'
model_stats_long_wo_bp_df = model_stats_long_df[~is_bp]
model_color_wo_bp_map = {name: colour for name, colour in model_color_map.items() if name != 'BP model'}

p = (
    p9.ggplot(model_stats_long_wo_bp_df, p9.aes(x='margin_name', y='value', fill='model_name'))
    + p9.geom_bar(stat='identity', position='dodge')
    + p9.facet_wrap('~variable', scales='free_y')
    + p9.scale_fill_manual(breaks=list(model_color_wo_bp_map.keys()), values=list(model_color_wo_bp_map.values()))
    + p9.theme_minimal()
    + p9.labs(x='Margin group', y='Distance metric', fill='Model')
    + p9.theme(
        axis_text_x=p9.element_text(angle=90),
        figure_size=(5, 5.0*2.0/3),
        dpi=300,
        legend_position='bottom',
        panel_background=p9.element_rect(fill='white', color='white'),
        plot_background=p9.element_rect(fill='white', color='white'),
    )
    + p9.guides(fill=p9.guide_legend(nrow=1))
)

p.save(f'{tmp_figures_prefix}/estonia-turnout-metrics-no-bp.png')
p

In [None]:
# Observed ("population") turnout per category, in the same long layout that
# model_turnout_stats produces: category, turnout, margin_name, model_name.
pop_age_group_turnout = (
    pd.read_csv('../data/estonia_turnout/rk2023_age_group_turnout.csv')
    .rename(columns={'age_group': 'category'})
    .assign(margin_name='age_group', model_name='population')
)
pop_gender_turnout = (
    pd.read_csv('../data/estonia_turnout/rk2023_gender_turnout.csv')
    .rename(columns={'gender': 'category'})
    .assign(margin_name='gender', model_name='population')
)
pop_unit_turnout = (
    pd.read_csv('../data/estonia_turnout/rk2023_unit_turnout.csv')
    .rename(columns={'unit': 'category'})
    .drop(columns=['total_voters_rk2023', 'total_voters_census', 'voters_rk2023', 'voters_adj'])
    .assign(
        category=lambda frame: frame['category'].map(unit_map),
        margin_name='unit',
        model_name='population',
    )
)
pop_turnout = pd.concat([pop_age_group_turnout, pop_gender_turnout, pop_unit_turnout]).reset_index(drop=True)

In [None]:
# NOTE(review): repeats the unit-margin totals check from earlier in the Results section.
tmp = pd.read_csv('../data/estonia_turnout/rk2023_unit_margins.csv')
tmp.groupby(by='voting_intent').N.sum()

In [None]:
# NOTE(review): repeats the age-group margin totals check from earlier in the Results section.
tmp = pd.read_csv('../data/estonia_turnout/rk2023_age_group_margins.csv')
tmp.groupby(by='voting_intent').N.sum()

In [None]:
# NOTE(review): repeats the gender margin totals check from earlier in the Results section.
tmp = pd.read_csv('../data/estonia_turnout/rk2023_gender_margins.csv')
tmp.groupby(by='voting_intent').N.sum()

In [None]:
turnout_cols = ['age_group', 'gender', 'unit']

# Model draws, stacked after the population rows in this order.
draws_by_model = {
    'bp_model': bp_model_df,
    'ei_model': ei_model_df,
    'gg_model': gg_model_df,
    'pm_model': pm_model_df,
    'fs_model': fs_model_df,
}
model_turnouts = [
    model_turnout_stats(name, draws, input_vars, turnout_cols, 'voting_intent')
    for name, draws in draws_by_model.items()
]
turnout_df = pd.concat([pop_turnout, *model_turnouts]).reset_index(drop=True)

model_map = {
    'bp_model': 'BP model',
    'ei_model': 'EI model',
    'gg_model': 'GG model',
    'pm_model': 'PM model',
    'fs_model': 'FS model',
    'population': 'True value',
}

turnout_df = map_var(turnout_df, 'model_name', model_map)
# map_var keeps only margins listed in margin_name_map (age and gender), so the
# 'unit' rows are dropped here.
turnout_df = map_var(turnout_df, 'margin_name', margin_name_map)

# 0-based position of each category on its facet's x axis, used to place the
# true-value reference segments in the next plot.
cat_codes = {'18-24': 0, '25-34': 1, '35-44': 2, '45-54': 3, '55-64': 4, '65-74': 5, '75+': 6, 'Female': 0, 'Male': 1}
turnout_df['cat_nr'] = turnout_df['category'].map(cat_codes)

In [None]:
# 'True value' rows are drawn as dotted reference segments rather than bars, so the
# bar fill scale only needs the model colours.
cat_model_color_map = model_color_map

is_true_value = turnout_df['model_name'] == 'True value'
true_value_df = turnout_df[is_true_value].copy()
turnout_wo_true_value_df = turnout_df[~is_true_value].copy()

# Each reference segment spans [cat_nr + 0.55, cat_nr + 1.45]; this assumes discrete
# x positions are 1-based and ordered as in cat_codes -- verify if categories change.
p = (
    p9.ggplot(turnout_wo_true_value_df, p9.aes(x='category', y='turnout', fill='model_name'))
    + p9.geom_bar(stat='identity', position='dodge')
    + p9.facet_wrap('~margin_name', scales='free_x')
    + p9.geom_segment(
        p9.aes(x='cat_nr+0.55', xend='cat_nr+1.45', y='turnout', yend='turnout', linetype='model_name'),
        color='black',
        size=0.5,
        data=true_value_df,
    )
    + p9.scale_fill_manual(breaks=list(cat_model_color_map.keys()), values=list(cat_model_color_map.values()))
    + p9.scale_linetype_manual(breaks=['True value'], values=['dotted'])
    + p9.theme_minimal()
    + p9.labs(x='Category', y='Within-category turnout proportion', fill='Model', linetype='')
    + p9.theme(
        axis_text_x=p9.element_text(angle=90),
        figure_size=(7, 7.0*2.1/3),
        dpi=300,
        legend_position='bottom',
        panel_background=p9.element_rect(fill='white', color='white'),
        plot_background=p9.element_rect(fill='white', color='white'),
    )
    + p9.guides(fill=p9.guide_legend(nrow=2))
)

p.save(f'{tmp_figures_prefix}/estonia-turnout-categories.png')
p