# Plots
> Common plots used in the dashboards

In [1]:
#| default_exp plots

In [2]:
#| hide
from nbdev.showdoc import *
from fastcore.test import *

In [3]:
#| exporti
import json, os, math
import itertools as it
from collections import defaultdict

import numpy as np
import pandas as pd
import datetime as dt

from typing import List, Tuple, Dict, Union, Optional

import altair as alt
import scipy as sp
import scipy.stats as sps
from scipy.cluster import hierarchy
from KDEpy import FFTKDE
import arviz as az

from salk_toolkit.utils import *
from salk_toolkit.io import extract_column_meta, read_json
from salk_toolkit.pp import registry, registry_meta, e2e_plot, stk_plot

from matplotlib import font_manager
from PIL import ImageFont

In [4]:
# For tests: path to the sample data file that the e2e_plot example cells below read from
data_file = '../samples/w25_bootstrap.parquet'

In [5]:
# Turn off Altair's max-rows check so the example plots below are not rejected for dataset size
alt.data_transformers.disable_max_rows()

DataTransformerRegistry.enable('default')

## Utility function to wrap horizontal legends properly

In [6]:
#| exporti

# Find a sensible approximation to the font used in vega/altair
# A generic regular-weight sans-serif is resolved to a font file through matplotlib's font manager
font = font_manager.FontProperties(family='sans-serif', weight='regular')
font_file = font_manager.findfont(font)
# Loaded at size 10; legend_font.getlength() is used below to measure legend label widths
legend_font = ImageFont.truetype(font_file,10)

In [7]:
#| export

# Legends are not wrapped, nor is there a good way of doing it accurately in vega/altair
# This attempts to estimate a reasonable value for columns which induces wrapping
def estimate_legend_columns_horiz_naive(cats, width):
    """Roughly estimate how many legend columns fit into `width`.

    Every entry is assumed to be as wide as the longest label, at
    15 + 5*len(label) units per entry.

    Args:
        cats: sequence of legend label strings.
        width: available horizontal space, in the same units as the estimate above.

    Returns:
        Number of columns (always >= 1), balanced so that rows are filled
        roughly equally. An empty `cats` yields 1.
    """
    cats = list(cats)
    if not cats: # nothing to lay out - avoid max() on an empty sequence
        return 1
    max_str_len = max(map(len,cats))
    n_cols = max(1,width//(15+5*max_str_len))
    # distribute them roughly equally to avoid last row being awkwardly shorter
    n_rows = int(math.ceil(len(cats)/n_cols))
    return int(math.ceil(len(cats)/n_rows))

# More sophisticated version that looks at lengths of individual strings across multiple rows
# ToDo: it should max over each column separately not just look at max(sum(row)). This is close enough though.
def estimate_legend_columns_horiz(cats, width, font=None):
    """Estimate the number of legend columns so that no legend row exceeds `width`.

    Each label is measured individually as 15 + font.getlength(label). Labels are
    laid out row by row with at most `max_cols` per row; whenever a row would
    exceed `width`, `max_cols` is reduced to the number of labels that did fit
    and the layout is redone from the first label.

    Args:
        cats: sequence of legend label strings.
        width: available horizontal space.
        font: object with a `getlength(str)` method used for measuring;
            defaults to the module-level `legend_font`.

    Returns:
        Number of columns (always >= 1), balanced so that rows are filled
        roughly equally. An empty `cats` yields 1.
    """
    if font is None: font = legend_font
    lens = [15+font.getlength(s) for s in cats]
    if not lens: # nothing to lay out - would otherwise divide by zero below
        return 1

    max_cols, restart = len(lens), True
    while restart:
        restart, rl, cc = False, 0, 0
        for l in lens:
            if cc >= max_cols: # Start a new row
                rl, cc = l, 1
            elif rl + l > width: # Exceed width - restart
                max_cols = cc
                # Start from beginning every time columns number changes
                # This is because what ends up in second+ rows depends on length of first
                restart = True
                # Without leaving the loop here the rest of the pass would run with stale
                # row state (skipping this label) and could shrink max_cols more than needed
                break
            else: # Just append to existing row
                rl += l
                cc += 1

    # For very long labels just accept we can't do anything
    max_cols = max(max_cols,1)

    # distribute them roughly equally to avoid last row being awkwardly shorter
    n_rows = int(math.ceil(len(lens)/max_cols))
    return int(math.ceil(len(lens)/n_rows))

In [8]:
# Sanity check: at width 800 the 11 labels are expected to fit on one row, i.e. 11 columns
assert estimate_legend_columns_horiz(["Keskerakond", "EKRE", "Reformierakond", "Isamaa", "SDE", "Rohelised", "Eesti 200", "Parempoolsed", "Other", "None of the parties", "No opinion"],800) == 11

# Plots

## Data options:
 - data_format: either 'longform' (default) or 'raw' for raw data
 - n_facets: tuple of how many internal facets the plot needs: (minimum, recommended)
 - draws: if it requires draws to be present (such as boxplots)
 - no_question_facet: for this plot question does not make sense as an internal facet (f.e. for density plots)
 - requires: list of dicts describing what we expect from each facet
    - 'likert': True - requires category to be symmetric with the neutral/na value in middle for odd number of levels
    - 'ordered': True - requires category to be ordered
    - \<any kw\>: 'pass' - pass that value from column meta on to the function. useful for geoplot or electoral systems
 - args: dict of kw:type specifying what additional parameters the plot accepts via plot_args
 - priority: number that determines how likely this plot is to be picked as a default
 - group_size: requests pp to add a column to data with the size of each group. Needed for some plots that also represent group size
 - agg_fn: locks the aggregation function for continuous inputs (usually to sum, f.e. election modelling)
 - nonnegative: specifies that the value_col is expected to be non_negative for the plot to work properly
 

In [9]:
#| export
@stk_plot('boxplots', data_format='longform', draws=True, n_facets=(1,2), priority=50)
def boxplots(data, value_col='value', facets=[], val_format='%', width=800, tooltip=[]):
    """Horizontal box plots of `value_col`, one row per category of the first facet.

    With a second facet the boxes are offset and colored by it and a top legend
    is shown; otherwise they are colored by the first facet without a legend.
    A tick at the mean is layered together with each box as a fallback.

    Note: when `val_format` ends in '%', `data[value_col]` is multiplied by 100 in place.
    """
    row_facet = facets[0]
    group_facet = facets[1] if len(facets)>1 else None

    # Boxplots being a compound plot, percentages are pre-scaled and the format switched
    # to fixed-point as a workaround to get the axis & tooltips right
    if val_format[-1] == '%':
        data[value_col] *= 100
        val_format = val_format[:-1]+'f'

    if group_facet:
        n_legend_cols = estimate_legend_columns_horiz(group_facet['order'],width)
        color_enc = {
            'yOffset': alt.YOffset(f'{group_facet["col"]}:N', title=None, sort=group_facet['order']),
            'color': alt.Color(f'{group_facet["col"]}:N', scale=group_facet['colors'],
                               legend=alt.Legend(orient='top',columns=n_legend_cols)),
        }
    else:
        color_enc = {
            'color': alt.Color(f'{row_facet["col"]}:N', scale=row_facet['colors'], legend=None),
        }

    # Encodings common to both layers
    shared_enc = {
        'y': alt.Y(f'{row_facet["col"]}:N', title=None, sort=row_facet['order']),
        **color_enc,
    }

    base = alt.Chart(round(data, 2))

    # Backup layer: the boxplot does not draw if variance is very low, so mark the mean with a tick
    mean_ticks = base.mark_tick(thickness=3).encode(
        x=alt.X(f'mean({value_col}):Q'),
        tooltip=[alt.Tooltip(f'mean({value_col}):Q')] + tooltip[1:],
        **shared_enc
    )

    boxes = base.mark_boxplot(clip=True, outliers=False).encode(
        x=alt.X(f'{value_col}:Q', title=value_col, axis=alt.Axis(format=val_format)),
        tooltip=tooltip[1:],
        **shared_enc
    )
    return mean_ticks + boxes

In [10]:
# Example: 'boxplots' for party_preference with age_group as the factor column, on the sample data
e2e_plot({
    'res_col' : 'party_preference',
    'factor_cols': ['age_group'],
    'filter': {},
    'plot': 'boxplots',
    'internal_facet': True
},data_file)

In [11]:
#| export
@stk_plot('columns', data_format='longform', draws=False, n_facets=(1,2))
def columns(data, value_col='value', facets=[], val_format='%', width=800, tooltip=[]):
    """Horizontal bar chart of `value_col`, one row per category of the first facet.

    With a second facet the bars are offset and colored by it and a top legend
    is shown; otherwise they are colored by the first facet without a legend.
    """
    row_facet = facets[0]
    group_facet = facets[1] if len(facets)>1 else None

    encodings = {
        'x': alt.X(f'{value_col}:Q', title=value_col, axis=alt.Axis(format=val_format)),
        'y': alt.Y(f'{row_facet["col"]}:N', title=None, sort=row_facet["order"]),
        'tooltip': tooltip,
    }

    if group_facet:
        n_legend_cols = estimate_legend_columns_horiz(group_facet["order"],width)
        encodings['yOffset'] = alt.YOffset(f'{group_facet["col"]}:N', title=None, sort=group_facet["order"])
        encodings['color'] = alt.Color(f'{group_facet["col"]}:N', scale=group_facet["colors"],
                                       legend=alt.Legend(orient='top',columns=n_legend_cols))
    else:
        encodings['color'] = alt.Color(f'{row_facet["col"]}:N', scale=row_facet["colors"], legend=None)

    return alt.Chart(round(data, 3), width='container').mark_bar().encode(**encodings)

In [12]:
# Example: 'columns' for e-valimised with nationality as the factor column
e2e_plot({
    'res_col' : 'e-valimised',
    'factor_cols': ['nationality'],
    'filter': {},
    'plot': 'columns',
    'internal_facet': True
},data_file)

In [13]:
# Example: 'columns' for thermometer with nationality as the factor column
e2e_plot({
    'res_col' : 'thermometer',
    'factor_cols': ['nationality'],
    'filter': {},
    'plot': 'columns',
    'internal_facet': True
},data_file)

In [14]:
#| export
@stk_plot('stacked_columns', data_format='longform', draws=False, nonnegative=True, n_facets=(2,2), agg_fn='sum', args={'normalized':'bool'})
def stacked_columns(data, value_col='value', facets=[], n_datapoints=1, val_format='%', width=800, normalized=False, tooltip=[]):
    """Horizontal stacked bars: one row per first-facet category, segments per second-facet category.

    Args:
        data: longform data; `value_col` is divided by `n_datapoints` and an
            'f_order' helper column is added, both in place.
        value_col: name of the value column.
        facets: exactly two facet dicts (each providing 'col', 'order', 'colors').
        n_datapoints: divisor applied to `value_col` before plotting.
        val_format: axis format for the x axis.
        width: width used for estimating the number of legend columns.
        normalized: if True, each bar is stacked with stack='normalize'.
        tooltip: tooltip encoding passed through to the chart.

    Returns:
        The Altair bar chart.
    """
    # Both facets are required (n_facets=(2,2)), so there is no single-facet variant here
    f0, f1 = facets[0], facets[1]

    data[value_col] = data[value_col]/n_datapoints

    # Numeric position of each second-facet category, so segments stack in facet order
    ldict = dict(zip(f1["order"], range(len(f1["order"]))))
    data['f_order'] = data[f1["col"]].astype('object').replace(ldict).astype('int')

    plot = alt.Chart(round(data, 3), width = 'container' \
    ).mark_bar().encode(
        x=alt.X(
            f'{value_col}:Q',
            title=value_col,
            axis=alt.Axis(format=val_format),
            **({'stack':'normalize'} if normalized else {})
            # scale=alt.Scale(domain=[0,30]) is left out: it cuts off the right edge in some distributions
        ),
        y=alt.Y(f'{f0["col"]}:N', title=None, sort=f0["order"]),
        tooltip = tooltip,
        order=alt.Order('f_order:O'),
        color=alt.Color(f'{f1["col"]}:N', scale=f1["colors"],
                        legend=alt.Legend(orient='top',columns=estimate_legend_columns_horiz(f1["order"],width))),
    )
    return plot

In [15]:
# Example: non-normalized 'stacked_columns' for party_preference with EKRE as the factor column
e2e_plot({
  "res_col": "party_preference",
  "factor_cols": ["EKRE"],
  "internal_facet": True,
  "plot": "stacked_columns",
  "plot_args": {
    "normalized": False
  },
  "filter": {}
},data_file)

In [16]:
#| export
@stk_plot('diff_columns', data_format='longform', draws=False, n_facets=(2,2), args={'sort_descending':'bool'})
def diff_columns(data, value_col='value', facets=[], val_format='%', sort_descending=False, tooltip=[]):
    """Bar chart of the difference in `value_col` between two groups of the second facet.

    The two groups are the first two distinct values found in the second facet's
    column; the plotted value is (second group) - (first group), matched on all
    remaining columns. With `sort_descending`, the first facet's 'order' entry is
    overwritten with the categories sorted by descending difference.
    """
    f0, f1 = facets[0], facets[1]
    group_col = f1["col"]

    # Every column except the value and the group column identifies a row to be matched
    key_cols = list(set(data.columns)-{value_col,group_col})
    # use unique instead of categories to allow filters to select the two that remain
    groups = data[group_col].unique()

    indexed = data.set_index(key_cols)
    first_vals = indexed[indexed[group_col]==groups[0]][value_col]
    second_vals = indexed[indexed[group_col]==groups[1]][value_col]
    diff = (second_vals-first_vals).reset_index()

    if sort_descending:
        f0["order"] = list(diff.sort_values(value_col,ascending=False)[f0["col"]])

    # NOTE(review): a title is given both on the x channel and on its axis - confirm which one is meant to show
    x_enc = alt.X(
        f'{value_col}:Q',
        title=f"{groups[1]} - {groups[0]}",
        axis=alt.Axis(format=val_format, title=f"{groups[0]} <> {groups[1]}"),
    )
    tooltips = [
        alt.Tooltip(f'{f0["col"]}:N'),
        alt.Tooltip(f'{value_col}:Q',format=val_format, title=f'{value_col} difference'),
    ]

    return alt.Chart(round(diff, 3), width='container').mark_bar().encode(
        y=alt.Y(f'{f0["col"]}:N', title=None, sort=f0["order"]),
        x=x_enc,
        tooltip=tooltips,
        color=alt.Color(f'{f0["col"]}:N', scale=f0["colors"], legend=None),
    )

In [17]:
# Example: 'diff_columns' for thermometer, age_group filtered down to the two groups being compared
e2e_plot({
    'res_col' : 'thermometer',
    'factor_cols': ['age_group'],
    'filter': { 'age_group': ('25-34','35-44')},
    'plot': 'diff_columns',
    'internal_facet': True,
    'plot_args': { 'sort_descending': True }
},data_file)

In [18]:
#| export

# The idea was to also visualize the size of each cluster. Currently not very useful, may need to be rethought

@stk_plot('massplot', data_format='longform', draws=False, group_sizes=True, n_facets=(1,2), hidden=True)
def massplot(data, value_col='value', facets=[], n_datapoints=1, val_format='%', width=800, tooltip=[]):
    """Circle plot of `value_col` per first-facet category, with circle size taken from 'group_size'.

    Expects a 'group_size' column in `data`; it is divided by `n_datapoints` in place.
    With a second facet the circles are offset and colored by it and a top legend
    is shown; otherwise they are colored by the first facet without a legend.
    """
    row_facet = facets[0]
    group_facet = facets[1] if len(facets)>1 else None

    data['group_size'] = data['group_size']/n_datapoints

    encodings = dict(
        y=alt.Y(f'{row_facet["col"]}:N', title=None, sort=row_facet["order"]),
        x=alt.X(f'{value_col}:Q', title=value_col, axis=alt.Axis(format=val_format)),
        size=alt.Size('group_size:Q', legend=None, scale=alt.Scale(range=[100, 500])),
        opacity=alt.value(1.0),
        stroke=alt.value('#777'),
        tooltip=tooltip + [ alt.Tooltip('group_size:N',format='.1%',title='Group size') ],
    )

    if group_facet:
        n_legend_cols = estimate_legend_columns_horiz(group_facet["order"],width)
        encodings['yOffset'] = alt.YOffset(f'{group_facet["col"]}:N', title=None, sort=group_facet["order"])
        encodings['color'] = alt.Color(f'{group_facet["col"]}:N', scale=group_facet["colors"],
                                       legend=alt.Legend(orient='top',columns=n_legend_cols))
    else:
        encodings['color'] = alt.Color(f'{row_facet["col"]}:N', scale=row_facet["colors"], legend=None)

    return alt.Chart(round(data, 3), width='container').mark_circle().encode(**encodings)

In [19]:
# Example: 'massplot' for trust (converted to continuous) with party_preference as the factor column
e2e_plot({
    'res_col' : 'trust',
    'factor_cols': ['party_preference'],
    'filter': {},
    'plot': 'massplot',
    'internal_facet': True,
    'convert_res': 'continuous'
},data_file)#,dry_run=True)['data']

In [20]:
#| export

# Make the likert bar pieces
def make_start_end(x,value_col,cat_col,cat_order):
    """Compute 'start' and 'end' positions of the likert bar segments for one group.

    Args:
        x: DataFrame with the rows of one group, one row per category.
        value_col: name of the column holding the segment sizes.
        cat_col: name of the category column.
        cat_order: full list of categories in display order.

    Returns:
        DataFrame with added 'start' and 'end' columns. Segments in the first half
        of the rows extend left from zero, those in the second half extend right.
        With an odd number of rows the middle one is placed separately, starting at -1.0.
        Rows containing NaN are dropped.
    """
    #print("######################")
    #print(x)
    if len(x) != len(cat_order):
        # Fill in missing rows with value zero so they would just be skipped
        x = pd.merge(pd.DataFrame({cat_col:cat_order}),x,on=cat_col,how='left').fillna({value_col:0})
    mid = len(x)//2
        
    if len(x)%2==1: # odd:
        # The middle category is drawn apart from the rest: from -scale_start to -scale_start+value
        scale_start=1.0
        # NOTE(review): x_mid is taken via iloc without .copy() before being assigned to -
        # this may emit a SettingWithCopyWarning depending on pandas version; confirm
        x_mid = x.iloc[[mid],:]
        x_mid.loc[:,['start','end']] = -scale_start
        x_mid.loc[:,'end'] = -scale_start+x_mid[value_col]
        nonmid = [ i for i in range(len(x)) if i!=mid ]
        x_mid = [x_mid]
    else: # even - no separate mid
        nonmid = np.arange(len(x))
        x_mid = []
    
    x_other = x.iloc[nonmid,:].copy()
    # end: running total shifted by the first-half total, so the first half ends at or below zero
    x_other.loc[:,'end'] = x_other[value_col].cumsum() - x_other[:mid][value_col].sum()
    # start: reverse running total minus the second-half total, negated - the mirror image of 'end'
    x_other.loc[:,'start'] = (x_other[value_col][::-1].cumsum()[::-1] - x_other[mid:][value_col].sum())*-1
    res = pd.concat([x_other] + x_mid).dropna() # drop any na rows added in the filling in step
    #print(res)
    return res

@stk_plot('likert_bars', data_format='longform', draws=False, requires=[{'likert':True}], n_facets=(1,3), priority=50)
def likert_bars(data, value_col='value', facets=[],  tooltip=[], outer_factors=[]):
    """Diverging horizontal bars for likert-scale responses.

    Args:
        data: longform data. With a single facet a constant 'question' column is added in place.
        value_col: name of the value column.
        facets: one to three facet dicts. The first is the likert scale; with three,
            the third is used for the labeled rows and the second as a y offset.
            The list passed in is not modified.
        tooltip: tooltip encoding passed through to the chart.
        outer_factors: extra columns to group by in addition to the non-likert facets.

    Returns:
        The Altair bar chart, with segments positioned by `make_start_end`.
    """
    # First facet is likert, second is labeled question, third is offset. Second is better for question which usually goes last, hence reorder
    if len(facets)==1: # Create a dummy second facet
        # Build a new list instead of appending: the caller's list must stay untouched, otherwise a
        # repeated call with the same list would see a 'question' facet that its data does not have
        facets = facets + [{ 'col': 'question', 'order': [facets[0]['col']], 'colors': alt.Undefined }]
        data['question'] = facets[0]['col']
    if len(facets)>=3: f0, f1, f2 = facets[0], facets[2], facets[1]
    else: f0, f1, f2 = facets[0], facets[1], None

    gb_cols = outer_factors+[f["col"] for f in facets[1:]] # There can be other extra cols (like labels) that should be ignored
    bar_data = data.groupby(gb_cols, group_keys=False, observed=False)[data.columns].apply(make_start_end, value_col=value_col,cat_col=f0["col"],cat_order=f0["order"],include_groups=False)

    plot = alt.Chart(bar_data).mark_bar() \
        .encode(
            x=alt.X('start:Q', axis=alt.Axis(title=None, format = '%')),
            x2=alt.X2('end:Q'),
            y=alt.Y(f'{f1["col"]}:N', axis=alt.Axis(title=None, offset=5, ticks=False, minExtent=60, domain=False), sort=f1["order"]),
            tooltip=tooltip,
            color=alt.Color(
                f'{f0["col"]}:N',
                legend=alt.Legend(
                    title='Response',
                    orient='bottom',
                    ),
                scale=f0["colors"],
            ),
            **({ 'yOffset':alt.YOffset(f'{f2["col"]}:N', title=None, sort=f2["order"]),
               } if f2 else {})
        )
    return plot

In [None]:
# Example: 'likert_bars' for valitsus with no factor columns (the single-facet case)
e2e_plot({
    'res_col' : 'valitsus',
    'factor_cols': [],  'filter': {},
    'plot': 'likert_bars',
    'internal_facet': True
},data_file)

In [None]:
# Example: 'likert_bars' for trust with party_preference as the factor column
e2e_plot({
    'res_col' : 'trust',
    'factor_cols': ['party_preference'],  'filter': {},
    'plot': 'likert_bars',
    'internal_facet': True
},data_file)

In [None]:
#| export

# Calculate KDE ourselves using a fast library. This gets around having to do sampling which is unstable
def kde_1d(vc, value_col, ls, scale=False):
    """Evaluate a gaussian-kernel FFTKDE of the values in `vc` at the points `ls`.

    Returns a DataFrame with a 'density' column and the evaluation points under
    `value_col`. With `scale`, the density is multiplied by the number of values in `vc`.
    """
    estimator = FFTKDE(kernel='gaussian').fit(vc.to_numpy())
    density = estimator.evaluate(ls)
    if scale:
        density *= len(vc)
    return pd.DataFrame({'density': density, value_col: ls})

@stk_plot('density', data_format='raw', factor_columns=3, aspect_ratio=(1.0/1.0), n_facets=(0,1), args={'stacked':'bool'}, no_question_facet=True)
def density(data, value_col='value', facets=[], tooltip=[], outer_factors=[], stacked=False, width=800):
    """Kernel density plot of `value_col`, optionally split and colored by one facet.

    The density is evaluated with `kde_1d` on a common 200-point grid spanning the
    data range. Unstacked, one line is drawn per group. With `stacked`, each group's
    density is scaled by its size, divided by the total number of rows, and drawn
    as stacked areas.
    """
    split_facet = facets[0] if len(facets)>0 else None
    # There can be other extra cols (like labels) that should be ignored
    gb_cols = [ c for c in outer_factors+[f['col'] for f in facets] if c is not None ]

    # Common evaluation grid, padded slightly beyond the observed range
    lo, hi = data[value_col].min()-1e-10, data[value_col].max()+1e-10
    ls = np.linspace(lo,hi,200)
    ndata = gb_in_apply(data,gb_cols,cols=[value_col],fn=kde_1d,value_col=value_col,ls=ls,scale=stacked).reset_index()

    # Color (and top legend) by the facet, if there is one
    facet_enc = {}
    if split_facet:
        n_legend_cols = estimate_legend_columns_horiz(split_facet["order"],width)
        facet_enc['color'] = alt.Color(f'{split_facet["col"]}:N', scale=split_facet["colors"],
                                       legend=alt.Legend(orient='top',columns=n_legend_cols))

    if not stacked:
        return alt.Chart(ndata).mark_line().encode(
            x=alt.X(f"{value_col}:Q"),
            y=alt.Y('density:Q',axis=alt.Axis(title=None, format = '%')),
            tooltip=tooltip[1:],
            **facet_enc
        )

    if split_facet:
        # Reversed position of each category in the facet order, used as the stacking order
        n_cats = len(split_facet["order"])
        ldict = { cat: n_cats-1-i for i, cat in enumerate(split_facet["order"]) }
        ndata.loc[:,'order'] = ndata[split_facet["col"]].astype('object').replace(ldict).astype('int')
        facet_enc['order'] = alt.Order('order:O')

    ndata['density'] /= len(data)
    return alt.Chart(ndata).mark_area(interpolate='natural').encode(
        x=alt.X(f"{value_col}:Q"),
        y=alt.Y('density:Q',axis=alt.Axis(title=None, format = '%'),stack='zero'),
        tooltip=tooltip[1:],
        **facet_enc
    )

In [None]:
# Example: stacked 'density' for thermometer with party_preference as the factor column, at width 500
e2e_plot({
    'res_col' : 'thermometer',
    'factor_cols': ['party_preference'],  'filter': {},
    'plot': 'density',
    'plot_args': {'stacked': True},
    'internal_facet': True
},data_file,width=500)

In [None]:
#| export

@stk_plot('violin', data_format='raw', n_facets=(1,2), as_is=True)
def violin(data, value_col='value', facets=[], tooltip=[], outer_factors=[],width=800):
    """Violin-style density plot: one row per first-facet category, areas stacked around the center.

    Densities are evaluated with `kde_1d` (scaled by group size) on a common
    200-point grid and divided by the total number of rows. With a second facet
    the areas are colored and ordered by it and a top legend is shown; otherwise
    they are colored by the first facet without a legend.
    """
    row_facet = facets[0]
    color_facet = facets[1] if len(facets)>1 else None
    # There can be other extra cols (like labels) that should be ignored
    gb_cols = outer_factors + [ f['col'] for f in facets ]

    # Common evaluation grid, padded slightly beyond the observed range
    lo, hi = data[value_col].min()-1e-10, data[value_col].max()+1e-10
    ls = np.linspace(lo,hi,200)
    ndata = gb_in_apply(data,gb_cols,cols=[value_col],fn=kde_1d,value_col=value_col,ls=ls,scale=True).reset_index()

    if color_facet:
        # Reversed position of each category in the facet order, used as the stacking order
        n_cats = len(color_facet["order"])
        ldict = { cat: n_cats-1-i for i, cat in enumerate(color_facet["order"]) }
        ndata.loc[:,'order'] = ndata[color_facet["col"]].astype('object').replace(ldict).astype('int')

    ndata['density'] /= len(data)

    if color_facet:
        n_legend_cols = estimate_legend_columns_horiz(color_facet["order"],width)
        color_enc = {
            'color': alt.Color(f'{color_facet["col"]}:N', scale=color_facet["colors"],
                               legend=alt.Legend(orient='top',columns=n_legend_cols)),
            'order': alt.Order('order:O'),
        }
    else:
        color_enc = {'color': alt.Color(f'{row_facet["col"]}:N', scale=row_facet["colors"], legend=None)}

    chart = alt.Chart(ndata).mark_area(interpolate='natural').encode(
        x=alt.X(f"{value_col}:Q"),
        y=alt.Y('density:Q',axis=alt.Axis(title=None, labels=False, values=[0], grid=False),stack='center'),
        row=alt.Row(f'{row_facet["col"]}:N',header=alt.Header(orient='top',title=None),spacing=5,sort=row_facet["order"]),
        tooltip=tooltip[1:],
        **color_enc
    )
    return chart.properties(width=width,height=70)

In [None]:
# Example: 'violin' for thermometer with party_preference as the factor column, at width 650
e2e_plot({
    'res_col' : 'thermometer',
    'factor_cols': ['party_preference'],  'filter': {},
    'plot': 'violin',
    'internal_facet': True
},data_file, width=650)

In [None]:
#| export

# Cluster-based reordering
def cluster_based_reorder(X):
    """Return an ordering of the rows of `X` from Ward clustering with optimal leaf ordering.

    Pairwise distances use pdist's default metric. The result is the list of
    row indices in dendrogram leaf order.
    """
    dists = sp.spatial.distance.pdist(X)#,metric='cosine')
    linkage = hierarchy.ward(dists)
    ordered = hierarchy.optimal_leaf_ordering(linkage, dists)
    return hierarchy.leaves_list(ordered)

@stk_plot('matrix', data_format='longform', aspect_ratio=(1/0.8), n_facets=(2,2), args={'reorder':'bool'})
def matrix(data, value_col='value', facets=[], val_format='%', reorder=False, row_normalize=False, tooltip=[]):
    """Heatmap of `value_col` with the first facet on the y axis and the second on the x axis.

    Cells are colored on a red-yellow-green scale kept symmetric around zero and
    overlaid with the formatted value. With `reorder` (and no columns besides the
    value and the two facets), both facets' 'order' entries are overwritten with a
    cluster-based ordering. `row_normalize` is accepted but not used in this body.
    """
    row_facet, col_facet = facets[0], facets[1]

    other_cols = [c for c in data.columns if c not in [value_col,row_facet["col"]]]
    # Reordering only works if no external facets
    if reorder and len(other_cols)==1:
        X = data.pivot(columns=col_facet["col"],index=row_facet["col"]).to_numpy()
        row_facet["order"] = np.array(row_facet["order"])[cluster_based_reorder(X)]
        col_facet["order"] = np.array(col_facet["order"])[cluster_based_reorder(X.T)]

    # Find max absolute value to keep color scale symmetric
    dmax = max(-data[value_col].min(),data[value_col].max())
    color_scale = alt.Scale(scheme='redyellowgreen', domainMid=0, domainMin=-dmax, domainMax=dmax)

    # Draw colored boxes
    boxes = alt.Chart(data).mark_rect().encode(
        x=alt.X(f'{col_facet["col"]}:N', title=None, sort=col_facet["order"]),
        y=alt.Y(f'{row_facet["col"]}:N', title=None, sort=row_facet["order"]),
        color=alt.Color(f'{value_col}:Q', scale=color_scale, legend=alt.Legend(title=None)),
        tooltip=tooltip,
    )

    # Add in numerical values
    # NOTE(review): the condition looks up a datum field literally named '<value_col>:Q' -
    # presumably the plain column name was intended; confirm before relying on the white labels
    label_color = alt.condition(
        alt.datum[f'{value_col}:Q']**2 > 1.5,
        alt.value('white'),
        alt.value('black')
    )
    labels = boxes.mark_text().encode(
        text=alt.Text(f'{value_col}:Q', format=val_format),
        color=label_color,
        tooltip=tooltip
    )

    return boxes+labels

In [None]:
# Example: 'matrix' for thermometer with party_preference as the factor column
e2e_plot({
    'res_col' : 'thermometer',
    'factor_cols': [ 'party_preference'],  'filter': {},
    'plot': 'matrix',
    'internal_facet': True
},data_file)

In [None]:
# Example: 'matrix' for party_preference with age_group as the factor column
e2e_plot({
    'res_col' : 'party_preference',
    'factor_cols': [ 'age_group'],  'filter': {},
    'plot': 'matrix',
    'internal_facet': True
},data_file)

In [None]:
# Example: 'matrix' with reorder=True for wedge (converted to continuous, 'center' transform)
e2e_plot({
    'res_col' : 'wedge',
    'factor_cols': [ 'party_preference'],  'filter': {},
    'plot': 'matrix',
    'plot_args': { 'reorder':True },
    'internal_facet': True,
    'convert_res': 'continuous',
    'cont_transform': 'center'
},data_file)

In [None]:
#| export
@stk_plot('lines',data_format='longform', draws=False, requires=[{},{'ordered':True}], n_facets=(2,2), args={'smooth':'bool'})
def lines(data, value_col='value', facets=[], smooth=False, width=800, tooltip=[], val_format='.2f',):
    """Line chart of `value_col` over the second facet, one colored line per first-facet category.

    With `smooth`, lines use 'basis' interpolation and the points are transparent;
    otherwise 'natural' interpolation with visible points.
    """
    color_facet, x_facet = facets[0], facets[1]

    interpolation, point_marks = ('basis', 'transparent') if smooth else ('natural', True)
    n_legend_cols = estimate_legend_columns_horiz(color_facet["order"],width)

    return alt.Chart(data).mark_line(point=point_marks, interpolate=interpolation).encode(
        x=alt.X(f'{x_facet["col"]}:N', title=None, sort=x_facet["order"]),
        y=alt.Y(f'{value_col}:Q', title=None, axis=alt.Axis(format=val_format)),
        tooltip=tooltip,
        color=alt.Color(f'{color_facet["col"]}:N', scale=color_facet["colors"], sort=color_facet["order"],
                        legend=alt.Legend(orient='top',columns=n_legend_cols)),
    )

In [None]:
# Example: smoothed 'lines' for thermometer with education as the factor column
e2e_plot({
    'res_col' : 'thermometer',
    'factor_cols': ['education'],  'filter': {},
    'plot': 'lines',
    'internal_facet': True,
    'plot_args': { 'smooth':True }
},data_file)

In [None]:
#| export

def draws_to_hdis(data,vc,hdi_vals):
    """Collapse the draws in `data` into HDI bounds of column `vc`.

    Rows are grouped by every column other than `vc` and 'draw'. For each
    probability in `hdi_vals` the bounds of each group are computed with az.hdi.
    Returns a DataFrame with the grouping columns, 'hdi' (the probability) and
    the 'lo' / 'hi' bounds as columns.
    """
    group_cols = [ c for c in data.columns if c != vc and c != 'draw' ]

    def bounds_for(prob):
        # One row per (group, bound) pair, tagged with the probability it belongs to
        def hdi_bounds(s):
            return pd.Series(list(az.hdi(s.to_numpy(),hdi_prob=prob)),index=['lo','hi'])
        bounds = data.groupby(group_cols,observed=False)[vc].apply(hdi_bounds).reset_index()
        bounds['hdi'] = prob
        return bounds

    long_bounds = pd.concat([bounds_for(p) for p in hdi_vals]).reset_index(drop=True)
    # Third column from the end is the one reset_index created for the 'lo'/'hi' labels
    bound_label_col = long_bounds.columns[-3]
    wide = long_bounds.pivot(index=group_cols+['hdi'], columns=bound_label_col, values=vc)
    return wide.reset_index()

@stk_plot('lines_hdi',data_format='longform', draws=True, requires=[{},{'ordered':True}], n_facets=(2,2), args={'hdi1':'float','hdi2':'float'})
def lines_hdi(data, value_col='value', facets=[], width=800, tooltip=[], val_format='.2f', hdi1=0.94, hdi2=0.5):
    """Band chart of two HDI levels of `value_col` over the second facet, colored by the first facet.

    Args:
        data: longform data with draws; summarised with `draws_to_hdis`.
        value_col: name of the value column.
        facets: two facet dicts; the first gives the color, the second the x axis.
        width: accepted but not used in this body.
        tooltip: tooltip encodings; all but the first are appended to the HDI tooltips.
        val_format: format for the y axis and the HDI bound tooltips.
        hdi1, hdi2: the two HDI probabilities to draw. The wider one is drawn
            with opacity 0.25, the narrower with 0.75.

    Returns:
        The Altair area chart.
    """
    f0, f1 = facets[0], facets[1]

    hdf = draws_to_hdis(data,value_col,[hdi1,hdi2])
    # Draw them in reverse order so the things that are first (i.e. most important) are drawn last (i.e. on top of others)
    # Also draw wider hdi before the narrower
    hdf.sort_values([f0["col"],'hdi'],ascending=[False,False],inplace=True)

    # Opacity scale keyed by the HDI levels actually requested, so that non-default
    # hdi1/hdi2 values are matched too (same mapping as before for the defaults)
    hdi_opacity = {min(hdi1,hdi2):0.75, max(hdi1,hdi2):0.25}

    plot = alt.Chart(hdf).mark_area(interpolate='basis').encode(
        alt.X(f'{f1["col"]}:O', title=None, sort=f1["order"]),
        y=alt.Y('lo:Q',
            axis=alt.Axis(format=val_format),
            title=value_col
            ),
        y2=alt.Y2('hi:Q'),
        color=alt.Color(
            f'{f0["col"]}:N',
            sort=f0["order"],
            scale=f0["colors"]
            ),
        opacity=alt.Opacity('hdi:N',legend=None,scale=to_alt_scale(hdi_opacity)),
        tooltip=[
            alt.Tooltip('hdi:N', title='HDI', format='.0%'),
            alt.Tooltip('lo:Q', title='HDI lower', format=val_format),
            alt.Tooltip('hi:Q', title='HDI upper', format=val_format),] + tooltip[1:]
        )
    return plot

In [None]:
# Example: 'lines_hdi' for party_preference with age_group and nationality as factor columns
e2e_plot({
    'res_col' : 'party_preference',
    'factor_cols': ['age_group','nationality'],  'filter': {},
    'plot': 'lines_hdi',
    'internal_facet': True,
},data_file)

In [None]:
#| export
@stk_plot('area_smooth',data_format='longform', draws=False, nonnegative=True, requires=[{},{'ordered':True}], n_facets=(2,2))
def area_smooth(data, value_col='value', facets=[], width=800, tooltip=[]):
    """Normalized stacked area chart of `value_col` over the second facet, one area per first-facet category.

    An 'order' helper column holding each row's position in the first facet's
    order is added to `data` in place and used as the stacking order.
    """
    area_facet, x_facet = facets[0], facets[1]

    # Position of each category in the facet order
    ldict = { cat: i for i, cat in enumerate(area_facet["order"]) }
    data.loc[:,'order'] = data[area_facet["col"]].astype('object').replace(ldict).astype('int')

    n_legend_cols = estimate_legend_columns_horiz(area_facet["order"],width)
    y_enc = alt.Y(f'{value_col}:Q', title=None, stack='normalize',
                  scale=alt.Scale(domain=[0, 1]), axis=alt.Axis(format='%'))
    color_enc = alt.Color(f'{area_facet["col"]}:N',
                          legend=alt.Legend(orient='top',columns=n_legend_cols),
                          sort=area_facet["order"], scale=area_facet["colors"])

    return alt.Chart(data).mark_area(interpolate='natural').encode(
        x=alt.X(f'{x_facet["col"]}:O', title=None, sort=x_facet["order"]),
        y=y_enc,
        order=alt.Order('order:O'),
        color=color_enc,
        tooltip=tooltip,
    )

In [None]:
# Example: 'area_smooth' for party_preference with education and gender as factor columns
e2e_plot({
    'res_col' : 'party_preference',
    'factor_cols': ['education','gender'],  'filter': {},
    'plot': 'area_smooth',
    'internal_facet': True
},data_file)

In [None]:
#| export
def likert_aggregate(x, cat_col, cat_order, value_col):
    """Summarise one group's likert responses into three indices.

    With an odd number of categories the middle one is treated as neutral and
    left out of the denominator used for the first two indices.

    Returns a Series with:
        polarisation: the smaller of the lower-half and upper-half sums, over the non-middle sum
        radicalisation: sum of the first and last categories, over the non-middle sum
        relevance: 1 - (non-middle sum / total sum)
    """
    categories, values = x[cat_col], x[value_col]
    mid, odd = divmod(len(cat_order), 2)

    total = values.sum()
    if odd:
        nonmid_sum = values[categories != cat_order[mid]].sum()
    else:
        nonmid_sum = total

    lower_sum = values[categories.isin(cat_order[:mid])].sum()
    upper_sum = values[categories.isin(cat_order[mid+odd:])].sum()
    extremes_sum = values[categories.isin([cat_order[0],cat_order[-1]])].sum()

    return pd.Series({
        'polarisation': np.minimum(lower_sum, upper_sum) / nonmid_sum,
        'radicalisation': extremes_sum / nonmid_sum,
        'relevance': 1.0-nonmid_sum/total,
    })

@stk_plot('likert_rad_pol',data_format='longform', requires=[{'likert':True}], args={'normalized':'bool'}, n_facets=(1,2))
def likert_rad_pol(data, value_col='value', facets=[], normalized=True, width=800, outer_factors=[], tooltip=[]):
    """Scatter plot of polarisation (x) against radicalisation (y), with point size from relevance.

    The indices are computed per group with `likert_aggregate`, the first facet
    being the likert scale. With `normalized` and more than one group,
    polarisation and radicalisation are z-scored. A second facet, if present,
    colors the points and gets a top legend.
    """
    likert_facet = facets[0]
    color_facet = facets[1] if len(facets)>1 else None

    # There can be other extra cols (like labels) that should be ignored
    gb_cols = outer_factors + [f['col'] for f in facets[1:]]
    likert_indices = gb_in_apply(data,gb_cols,likert_aggregate,cat_col=likert_facet["col"],cat_order=likert_facet["order"],value_col=value_col).reset_index()

    if normalized and len(likert_indices)>1:
        zscored_cols = ['polarisation','radicalisation']
        likert_indices.loc[:,zscored_cols] = likert_indices[zscored_cols].apply(sps.zscore)

    color_enc = {}
    if color_facet:
        n_legend_cols = estimate_legend_columns_horiz(color_facet["order"],width)
        color_enc['color'] = alt.Color(f'{color_facet["col"]}:N', scale=color_facet["colors"],
                                       legend=alt.Legend(orient='top',columns=n_legend_cols))

    index_tooltips = [
        alt.Tooltip('radicalisation:Q', format='.2'),
        alt.Tooltip('polarisation:Q', format='.2'),
        alt.Tooltip('relevance:Q', format='.2'),
    ]

    return alt.Chart(likert_indices).mark_point().encode(
        x=alt.X('polarisation:Q'),
        y=alt.Y('radicalisation:Q'),
        size=alt.Size('relevance:Q', legend=None, scale=alt.Scale(range=[100, 500])),
        opacity=alt.value(1.0),
        tooltip=index_tooltips + tooltip[2:],
        **color_enc
    )

In [None]:
# Example: 'likert_rad_pol' for referendum by electoral_district; note this reads a different sample file
e2e_plot({
  "res_col": "referendum",
  "factor_cols": [
    "electoral_district"
  ],
  "internal_facet": True,
  "plot": "likert_rad_pol",
  "filter": {}
},'../samples/master_bootstrap.parquet',width=600)

In [None]:
#| export
@stk_plot('barbell', data_format='longform', draws=False, n_facets=(2,2))
def barbell(data, value_col='value', facets=[], n_datapoints=1, val_format='%', width=800, tooltip=[]):
    """Barbell chart: per first-facet category, a grey line joining one point per second-facet category.

    The points are colored by the second facet. Selecting entries in the legend
    keeps the matching points opaque and fades the others. `n_datapoints` is
    accepted but not used in this body.
    """
    row_facet, point_facet = facets[0], facets[1]

    # Position encodings shared by the connecting lines and the points
    base = alt.Chart(data).encode(
        x=alt.X(f'{value_col}:Q', title=None, axis=alt.Axis(format=val_format)),
        y=alt.Y(f'{row_facet["col"]}:N', title=None, sort=row_facet["order"]),
        tooltip=tooltip
    )

    connectors = base.mark_line(color='lightgrey', size=1, opacity=1.0).encode(
        detail=f'{row_facet["col"]}:N'
    )

    # Legend-bound selection on the second facet
    selection = alt.selection_point(fields=[point_facet["col"]], bind='legend')
    n_legend_cols = estimate_legend_columns_horiz(point_facet["order"],width)

    points = base.mark_point(size=50, opacity=1, filled=True).encode(
        color=alt.Color(f'{point_facet["col"]}:N',
            legend=alt.Legend(orient='top',columns=n_legend_cols),
            scale=point_facet["colors"],
            sort=point_facet["order"]
        ),
        opacity=alt.condition(selection, alt.value(1), alt.value(0.2)),
    ).add_params(selection)

    return connectors + points

In [None]:
# Example: 'barbell' for trust (converted to continuous) with party_preference as the factor column
e2e_plot({
    'res_col' : 'trust',
    'factor_cols': ['party_preference'],
    'filter': {},
    'plot': 'barbell',
    'internal_facet': True,
    'convert_res': 'continuous'
},data_file)

In [None]:
# This does not work because of a Vega lite bug: https://github.com/altair-viz/altair/issues/2369 -> https://github.com/vega/vega-lite/issues/3729
#@stk_plot('geoplot', data_format='longform', factor_meta=['topo_feature'])
def geoplot_ideal(data, value_col='value', facets=[], val_format='%'):
    """Sketch of a choropleth that folds one map per `facets[0]` category out of a single lookup.

    This function is not registered as a plot (its `@stk_plot` decorator is commented out)
    and is kept only as a reference for the approach.

    NOTE(review): as written it cannot run on its own:
      * `tjson_url` and `geo_fname` are neither parameters nor locals, so they would have to
        exist as module-level names - none are defined in the visible part of this notebook;
      * the chart is assigned to `plot` but never returned, so the function returns None.

    Parameters
    ----------
    data : long-form DataFrame that is pivoted so that each `facets[0]` category becomes a column
        and each `facets[1]` value becomes a row.
    value_col : name given to the folded value field and used for the colour encoding.
    facets : two facet dicts; "col" is read from both and "order" from the first.
    val_format : number format used in the value tooltip.
    """
    f0, f1 = facets[0], facets[1]
    
    # TODO: replace with data_format='table' one of these days
    # Wide table: one row per f1 value, one column per f0 category
    pdf = data.pivot(index=[f1["col"]], columns=f0["col"])
    pdf.columns = pdf.columns.get_level_values(1)
    pdf = pdf.reset_index()
    #print(data)
    
    source = alt.topo_feature(tjson_url, "data")
    # Join the wide table onto the map features, then fold the category columns back into long form
    plot = alt.Chart(source).transform_lookup(
        lookup = f"properties.{geo_fname}",
        from_ = alt.LookupData(
            data=pdf,
            key=f1["col"],
            fields=f0["order"]
        ),
    ).transform_fold(f0["order"],as_=[f0["col"],value_col]).mark_geoshape(stroke='white', strokeWidth=0.1).encode(
        tooltip=[alt.Tooltip(f'properties.{ geo_fname}:N', title='maakond'),
                alt.Tooltip(f'{value_col}:Q', title='osakaal', format=val_format)],
        color=alt.Color(
            f'{value_col}:Q',
            scale=alt.Scale(scheme="reds"),
            legend=alt.Legend(format=".0%", title=None, orient='top-left'),
        )
    )
    # The string literal below is a no-op expression statement: an alternative, faceted
    # variant of the same chart that was left in place as inert text.
    '''plot = alt.Chart(source).mark_geoshape(stroke='white', strokeWidth=0.1).transform_lookup(
        lookup = f"properties.{geo_fname}",
        from_ = alt.LookupData(
            data=pdf,
            key=f1["col"],
            fields=f0["order"]
        ),
    ).transform_fold(f0["order"],as_=[f0["col"],value_col]).encode(
        #tooltip=[alt.Tooltip(f'properties.{ geo_fname}:N', title='maakond'),
        #        alt.Tooltip(f'{value_col}:Q', title='osakaal', format='.0%')],
        color=alt.Color(
            f'{value_col}:Q',
            scale=alt.Scale(scheme="reds"),
            legend=alt.Legend(format=".0%", title=None, orient='top-left'),
        ),
        facet=alt.Facet(f'{f0["col"]}:N')
    )#.facet(f'{f0["col"]}:N')'''
    

In [None]:
#| export

@stk_plot('geoplot', data_format='longform', n_facets=(1,1), requires=[{'topo_feature':'pass'}], aspect_ratio=(4.0/3.0), no_question_facet=True)
def geoplot(data, topo_feature, value_col='value', facets=[], val_format='.2f',tooltip=[]):
    """Choropleth map: colour each map region by `value_col`.

    Parameters
    ----------
    data : DataFrame looked up from the map features; all of its columns are made available
        to the chart, and the row is matched on the `facets[0]["col"]` column.
    topo_feature : 3-item sequence unpacked as (url, feature/object name, property name);
        the first two go to `alt.topo_feature`, the third names the `properties.<name>`
        field of each map feature that is matched against `data`.
    value_col : quantitative column that drives the fill colour.
    facets : sequence whose first item is a facet dict providing "col".
    val_format : number format for the colour legend.
    tooltip : tooltip encoding passed through unchanged.

    Returns
    -------
    An altair geoshape chart using the 'mercator' projection.
    """
    region_facet = facets[0]
    topo_url, topo_object, topo_key = topo_feature

    shapes = alt.Chart(alt.topo_feature(topo_url, topo_object)).mark_geoshape(stroke='white', strokeWidth=0.1)

    # Attach the data row of each region to the matching map feature
    joined = shapes.transform_lookup(
        lookup=f"properties.{topo_key}",
        from_=alt.LookupData(data=data, key=region_facet["col"], fields=list(data.columns)),
    )

    # To use color scale, consider switching to opacity for value
    fill = alt.Color(
        f'{value_col}:Q',
        scale=alt.Scale(scheme="reds"),
        legend=alt.Legend(format=val_format, title=None, orient='top-left',gradientThickness=6),
    )

    return joined.encode(tooltip=tooltip, color=fill).project('mercator')

In [None]:
# Example: run the `geoplot` plot on `EKRE` by `electoral_district` at width 400.
# The metadata is read from JSON here and passed explicitly via `data_meta`;
# `data_meta` is also reused by the cells below.
data_metafile = '../data/master_meta.json'
data_meta = read_json(data_metafile)
    
e2e_plot({
    'res_col' : 'EKRE',
    'factor_cols': ['electoral_district'],
    'filter': {},
    'plot': 'geoplot',
    'internal_facet': True
}, data_file, data_meta=data_meta,width=400)

In [None]:
# Display the 'structure' entry of the metadata loaded in the previous cell.
# NOTE(review): this shows the whole entry as cell output - consider trimming it if it is large.
data_meta['structure']

In [None]:
#| exporti

# Assuming the rows of cvs are ordered so that equal rows are adjacent, find the split points
def split_ordered(cvs):
    """Return the indices at which a new run of equal rows starts.

    Parameters
    ----------
    cvs : 1-D or 2-D array-like. A 1-D input is treated as a single column. Equal rows are
        expected to be adjacent; only neighbouring rows are compared.

    Returns
    -------
    1-D integer array with every index `i >= 1` for which row `i` differs from row `i-1`
    in at least one column. Index 0 is never included, so the result can be passed
    directly to `np.split`. Empty input gives an empty result.
    """
    cvs = np.asarray(cvs)
    if cvs.ndim==1: cvs = cvs[:,None]
    starts_new_run = np.zeros(len(cvs), dtype=np.bool_) # row 0 has no predecessor to differ from
    starts_new_run[1:] = np.any(cvs[:-1, :] != cvs[1:, :], axis=-1)
    return np.arange(len(starts_new_run))[starts_new_run]

# Split a series of weights into groups of roughly equal sum
# This algorithm is greedy and does not split values but it is fast and should be good enough for most uses
def split_even_weight(ws, n):
    """Return split indices that cut `ws` into at most `n` consecutive groups of similar total weight.

    Each element is assigned to the bin `floor(cumulative_weight / (total / n))`, and a split is
    placed right after every element at which the bin number changes. Elements are never
    divided, so the number of groups can be smaller than `n` (for example when there are
    fewer than `n` elements or when a single weight spans several bins).

    Parameters
    ----------
    ws : 1-D array-like of weights.
    n : requested number of groups.

    Returns
    -------
    1-D integer array of indices suitable for `np.split(x, result)`. Empty when `ws` is
    empty or when its total weight is not positive.
    """
    cws = np.cumsum(np.asarray(ws))
    # Nothing sensible to split: no elements, or a zero / NaN / negative total
    if len(cws)==0 or not cws[-1]>0:
        return np.array([], dtype=int)

    total = cws[-1]
    # The element that completes the total lands in bin n (or n-1 after float rounding);
    # clipping to n-1 keeps it in the last group either way.
    bins = np.minimum((cws/(total/n)).astype('int'), n-1)
    splits = split_ordered(bins)+1

    # A split is only useful if some weight remains after it. This replaces dropping the
    # last split unconditionally, which removed a real split whenever rounding made the
    # final bin number come out below n.
    return splits[cws[splits-1] < total]

In [None]:
# Test: total weight 8 split into 4 groups of weight 2 each.
# The line breaks in the list mirror the expected groups, i.e. splits before indices 3, 4 and 6.
assert (split_even_weight([1,0,1,
                           2,
                           1.5,1,
                           1.5],4) == np.array([3,4,6])).all()

In [None]:
#| export

def fd_mangle(vc, value_col, factor_col, n_points=10): 
    """Compute the category mix of `factor_col` along the ordered values of `value_col`.

    Rows are sorted by `value_col` and cut into up to `n_points` consecutive bins with
    (roughly) the same number of rows; for each bin the share of every `factor_col`
    category is computed.

    Parameters
    ----------
    vc : DataFrame containing `value_col` and `factor_col`.
    value_col : column the rows are ordered by.
    factor_col : categorical column whose per-bin frequencies are returned.
    n_points : requested number of bins.

    Returns
    -------
    Long-form DataFrame with columns 'percentile' (bin position, evenly spaced on [0, 1]),
    `factor_col` and 'density' (share of that category within the bin).
    """
    vc = vc.sort_values(value_col)
    
    # Every row counts the same
    ws = np.ones(len(vc))
    splits = split_even_weight(ws, n_points)
    
    ccodes, cats = vc[factor_col].factorize()
    
    # One row per bin, one column per category
    ofreqs = np.stack([ np.bincount(g, weights=gw, minlength=len(cats))/gw.sum()
                        for g,gw in zip(np.split(ccodes,splits),np.split(ws,splits)) ],axis=0)
    
    df = pd.DataFrame(ofreqs, columns=cats)
    # split_even_weight never divides a row, so small inputs produce fewer than n_points bins;
    # size the axis from the actual number of bins to avoid a length-mismatch error.
    df['percentile'] = np.linspace(0,1,len(df))
    return df.melt(id_vars='percentile',value_vars=cats,var_name=factor_col,value_name='density')

@stk_plot('facet_dist', data_format='raw', factor_columns=3,aspect_ratio=(1.0/1.0), n_facets=(1,1), no_question_facet=True)
def facet_dist(data, value_col='value',facets=[], tooltip=[], outer_factors=[]):
    """Normalised stacked area chart of how the `facets[0]` category mix changes along `value_col`.

    `fd_mangle` is applied to every group defined by `outer_factors` (via `gb_in_apply`)
    and the resulting per-bin category shares are drawn against the bin position.

    Parameters
    ----------
    data : raw DataFrame containing `value_col`, the facet column and the outer factor columns.
    value_col : column the population is ordered by.
    facets : sequence whose first item is a facet dict providing "col" and "colors".
    tooltip : tooltip encoding; its first entry is dropped before use.
    outer_factors : grouping columns; `None` entries are ignored.

    Returns
    -------
    The altair area chart.
    """
    cat_facet = facets[0]
    cat_col = cat_facet["col"]

    # There can be other extra cols (like labels) that should be ignored
    group_cols = list(filter(lambda c: c is not None, outer_factors))

    densities = gb_in_apply(
        data, group_cols,
        cols=[value_col, cat_col],
        fn=fd_mangle,
        value_col=value_col,
        factor_col=cat_col,
    ).reset_index()

    x_enc = alt.X("percentile:Q", axis=alt.Axis(format='%'))
    y_enc = alt.Y('density:Q', axis=alt.Axis(title=None, format='%'), stack='normalize')
    color_enc = alt.Color(f'{cat_col}:N', scale=cat_facet["colors"], legend=alt.Legend(orient='top'))

    return alt.Chart(densities).mark_area(interpolate='natural').encode(
        x=x_enc,
        y=y_enc,
        tooltip=tooltip[1:],
        color=color_enc,
    )

In [None]:
# Example: run the `facet_dist` plot on `thermometer` with `party_preference` as the factor column.
e2e_plot({
    'res_col' : 'thermometer',
    'factor_cols': ['party_preference'],  'filter': {},
    'plot': 'facet_dist',
    'internal_facet': True
},data_file)

In [None]:
#| exporti

# Draw one category per row of a probability matrix in a single vectorized step
def vectorized_mn(prob_matrix):
    """Sample one column index per row, with probability proportional to the row's entries.

    Rows do not need to be normalised. Uses inverse-CDF sampling with one uniform draw per
    row taken from numpy's global random state.

    Parameters
    ----------
    prob_matrix : 2-D numpy array of non-negative weights, one row per draw.

    Returns
    -------
    1-D integer array with the sampled column index for every row.
    """
    cdf = prob_matrix.cumsum(axis=1)
    cdf = cdf/cdf[:,-1:]
    uniform = np.random.rand(prob_matrix.shape[0], 1)
    # The sampled index equals the number of CDF entries strictly below the uniform draw
    return (cdf < uniform).sum(axis=1)

def linevals(vals, value_col, n_points, dim, cats, ccodes=None, ocols=None, boost_signal=True,gc=False, weights=None):
    """Summarise an ordered population as up to `n_points` points, optionally tagged with a category.

    The values are sorted and cut into consecutive groups of (roughly) equal total weight;
    each output row holds the mean value of one group. When `dim` is given, each row also
    gets one category, sampled from the category frequencies inside the group.

    Parameters
    ----------
    vals : 1-D numpy array of values.
    value_col : name of the output column holding the group means.
    n_points : requested number of groups (fewer are produced for small inputs).
    dim : name of the output category column, or a falsy value to skip category handling.
    cats : sequence of category labels; `ccodes` index into it.
    ccodes : 1-D integer array with the category code of every value (needed when `dim` is set).
    ocols : optional pandas Series; every (index, value) pair is added as a constant column.
    boost_signal : if True, sample categories from the positive parts of the per-category
        K-L divergence terms (relative to the average frequencies over all groups) instead
        of the raw frequencies, and add their row sum as column 'kld'.
    gc : if True (and `dim` is set), order by category first and by value within category.
    weights : optional 1-D numpy array of weights; defaults to equal weights.
        NOTE(review): weights only affect the grouping and the category frequencies - the
        group value itself is a plain (unweighted) mean.

    Returns
    -------
    DataFrame with columns `value_col`, 'pos' (group position in [0, 1)), and - when `dim`
    is set - `dim`, 'probability' (raw in-group frequency of the sampled category) and
    optionally 'kld', plus the constant columns from `ocols`.

    Side effects
    ------------
    When `dim` is set, numpy's global random generator is re-seeded with 0.
    """
    ws = weights if weights is not None else np.ones(len(vals))
    
    order = np.lexsort((vals,ccodes)) if dim and gc else np.argsort(vals)
    splits = split_even_weight(ws[order], n_points)
    aer = np.array([ g.mean() for g in np.split(vals[order],splits) ])
    pdf = pd.DataFrame(aer, columns=[value_col])
    
    if dim:
        # Find the frequency of each category in ccodes
        osignal = np.stack([ np.bincount(g, weights=gw, minlength=len(cats))/gw.sum()
                            for g,gw in zip(np.split(ccodes[order],splits),np.split(ws[order],splits)) ],axis=0)
        
        ref_p = osignal.mean(axis=0)+1e-10
        np.random.seed(0) # So they don't change every re-render

        signal = osignal + 1e-10 # So even if values are zero, log and vectorized_mn would work
        if boost_signal: # Boost with K-L, leaving only categories that grew in probability boosted by how much they did
            klv = signal*(np.log(signal/ref_p[None,:]))
            signal = np.maximum(1e-10,klv)
            pdf['kld'] = np.sum(klv,axis=1)

        cat_inds = vectorized_mn(signal)
        pdf[dim] = np.array(cats)[cat_inds]
        pdf['probability'] = osignal[np.arange(len(cat_inds)),cat_inds]

    # Evenly spaced positions i/len(pdf). np.arange with a float step can return one element
    # too many for some lengths (e.g. 49), which would make this assignment fail;
    # linspace always returns exactly len(pdf) values.
    pdf['pos'] = np.linspace(0,1,len(pdf),endpoint=False)
    
    if ocols is not None:
        for iv in ocols.index:
            pdf[iv] = ocols[iv]

    return pdf

In [None]:
#| export

# NOTE(review): this decorator passes `plot_args=` while `marimekko` in this file passes `args=`
# for the same kind of information - confirm which keyword the plot registry expects.
@stk_plot('ordered_population', data_format='raw', factor_columns=3, aspect_ratio=(1.0/1.0), plot_args={'group_categories':'bool'}, n_facets=(0,1), no_question_facet=True)
def ordered_population(data, value_col='value', facets=[], tooltip=[], outer_factors=[], group_categories=False):
    """Dot plot of `value_col` over the population ordered by value.

    Within every `outer_factors` group, `linevals` reduces the rows to up to 200 points
    (group means, positioned by 'pos'). With a facet, every dot is coloured by the category
    `linevals` assigned to it; without one, all dots are red. A dashed red rule marks the
    mean of the plotted values per `outer_factors` group.

    Parameters
    ----------
    data : raw DataFrame with `value_col`, the optional facet column and the outer factor columns.
        Inputs longer than 1,000,000 rows are randomly down-sampled without replacement.
    value_col : numeric column to plot.
    facets : empty, or a sequence whose first item is a facet dict providing "col", "order" and "colors".
    tooltip : tooltip encodings; a 'probability' tooltip is appended when a facet is used.
    outer_factors : columns that define independent sub-populations. They are read through
        `.cat.codes`, so they have to be pandas categoricals.
    group_categories : forwarded to `linevals` as `gc`.

    Returns
    -------
    Layered altair chart (mean rule + dots).
    """
    f0 = facets[0] if len(facets)>0 else None
    
    # n_points: dots per sub-population; maxn: row cap before down-sampling
    n_points, maxn = 200, 1000000
    
     # TODO: use weight if available. linevals is ready for it, just needs to be fed in. 
    
    # Sample down to maxn points if exceeding that
    if len(data)>maxn: data = data.sample(maxn,replace=False)
    
    # Sorting makes the rows of each outer_factors combination contiguous (needed by split_ordered below)
    data = data.sort_values(outer_factors)
    vals = data[value_col].to_numpy()

    if len(facets)>=1:
        fcol = f0["col"]
        # Integer code per row plus the list of labels the codes refer to
        cat_idx, cats = pd.factorize(data[f0["col"]])
        cats = list(cats)
    else:
        fcol = None
        cat_idx, cats = None, []
        
    if outer_factors:
        
        # This is optimized to not use pandas.groupby as it makes it about 2x faster - which is 2+ seconds with big datasets
        
        # Assume data is sorted by outer_factors, split vals into groups by them
        ofids = np.stack([ data[f].cat.codes.values for f in outer_factors ],axis=1)
        splits = split_ordered(ofids)        
        groups = np.split(vals,splits)
        # Without a facet the value groups are zipped in only as placeholders (linevals gets dim=None)
        cgroups = np.split(cat_idx,splits) if len(facets)>=1 else groups
        
        # Perform the equivalent of groupby
        # First row of every group carries that group's outer factor values
        ocols = data.iloc[[0]+list(splits)][outer_factors]
        tdf = pd.concat([linevals(g,value_col=value_col,dim=fcol, ccodes=gc, cats=cats, n_points=n_points,ocols=ocols.iloc[i,:],gc=group_categories) 
                         for i,(g,gc) in enumerate(zip(groups,cgroups))])

        #tdf = data.groupby(outer_factors,observed=True).apply(linevals,value_col=value_col,dim=fcol,cats=cats,n_points=n_points,gc=group_categories,include_groups=False).reset_index()
    else:
        tdf = linevals(vals,value_col=value_col,dim=fcol,ccodes=cat_idx,cats=cats,n_points=n_points, gc=group_categories)
        #tdf = linevals(data,value_col=value_col,cats=cats,dim=fcol,n_points=n_points,gc=group_categories)
        
    #if boost_signal:
    #    tdf['matches'] = np.minimum(tdf['matches'],tdf['kld']/tdf['kld'].quantile(0.75))

    # x axis: position within the ordered population, shown without labels or ticks
    base = alt.Chart(tdf).encode(
        x=alt.X('pos:Q',
            title="",
            axis=alt.Axis(
                labels=False,
                ticks=False,
                #grid=False
            )
        )
    )
    #selection = alt.selection_multi(fields=[dim], bind='legend')
    # Despite the name, `line` is the layer of small circles
    line = base.mark_circle(size=10).encode(
        y=alt.Y(f"{value_col}:Q",impute={'value':None}, title='', axis=alt.Axis(grid=True)),
        #opacity=alt.condition(selection, alt.Opacity("matches:Q",scale=None), alt.value(0.1)),
        color=alt.Color(
            f'{f0["col"]}',
            sort=f0["order"],
            scale=f0["colors"]
            ) if len(facets)>=1 else alt.value('red'),
        tooltip=tooltip+([alt.Tooltip('probability:Q',format='.1%',title='category prob.')] if len(facets)>=1 else [])
    )#.add_selection(selection)


    # Dashed reference rule at the mean of the plotted values, computed per outer_factors group
    rule = alt.Chart().mark_rule(color='red', strokeDash=[2, 3]).encode(
        y=alt.Y('mv:Q')
    ).transform_joinaggregate(
        mv = f'mean({value_col}):Q',
        groupby=outer_factors
    )

    plot = alt.layer(
        rule,
        line,
        data=tdf,
    )
    return plot

In [None]:
# Example: run the `ordered_population` plot on `thermometer` with `party_preference`
# as the factor column at width 800. Uncomment the `plot_args` line to try `group_categories`.
e2e_plot({
    'res_col' : 'thermometer',
    'factor_cols': ['party_preference'],  'filter': {},
    'plot': 'ordered_population',
    #'plot_args': {'group_categories':True},
    'internal_facet': True
},data_file,width=800)

In [None]:
#| export
# NOTE(review): this decorator passes `args=` while `ordered_population` in this file passes
# `plot_args=` for the same kind of information - confirm which keyword the plot registry expects.
@stk_plot('marimekko', data_format='longform', draws=False, group_sizes=True, args={'separate':'bool'}, n_facets=(2,2))
def marimekko(data, value_col='value', facets=[], val_format='%', width=800, tooltip=[], outer_factors=[], separate=False):
    """Marimekko (mosaic) chart of two facets.

    `facets[0]` is laid out along the x axis and `facets[1]` along the y axis / colour.
    Every cell is weighted by `group_size * value_col`; rectangle extents are the
    cumulative weight shares computed below.

    Parameters
    ----------
    data : long-form DataFrame with `value_col`, 'group_size', both facet columns and the
        outer factor columns. The frame passed in is left unmodified.
    value_col : column multiplied with 'group_size' to get the cell weight.
    facets : two facet dicts; "col" is read from both and "colors" from the second.
    val_format, width : accepted but not used in this function.
    tooltip : extra tooltip encodings; the first entry is dropped and the rest are appended
        to the two proportion tooltips.
    outer_factors : columns defining independent sub-charts; all shares are computed within them.
    separate : if True, give every y category its own band (as tall as its largest share
        over the x categories) and centre each rectangle inside its band, instead of
        stacking them. The y axis labels are hidden in this mode.

    Returns
    -------
    The altair rect chart.
    """
    f0, f1 = facets[0], facets[1]

    #xcol, ycol, ycol_scale = f1["col"], f0["col"], f0["colors"]
    xcol, ycol, ycol_scale = f0["col"], f1["col"], f1["colors"]

    # Work on a derived frame so that the caller's DataFrame is neither extended nor re-ordered
    data = data.assign(w=data['group_size']*data[value_col]).sort_values([xcol,ycol],ascending=[True,False])

    if separate: # Split and center each ycol group so dynamics can be better tracked for all of them
        # Share of every y category within its x column
        ndata = data.groupby(outer_factors+[xcol],observed=False)[[ycol,value_col,'w']].apply(
            lambda df: pd.DataFrame({ ycol: df[ycol], 'yv': df['w']/df['w'].sum(), 'w': df['w']})
        ).reset_index()
        # Band height of a y category = its largest share over all x columns
        ndata = ndata.merge(
            ndata.groupby(outer_factors + [ycol],observed=True)['yv'].max().rename('ym').reset_index(),
            on=outer_factors + [ycol]
        ).fillna({'ym':0.0})
        # Centre every rectangle inside its band; normalise so the bands fill [0, 1]
        ndata = ndata.groupby(outer_factors+[xcol],observed=False)[[ycol,'w','yv','ym']].apply(
            lambda df: pd.DataFrame({ ycol: df[ycol], 'yv': df['yv'], 'w': df['w'].sum(),
                                      'y1': (df['ym'].cumsum()- df['ym']/2 - df['yv']/2)/df['ym'].sum(),
                                      'y2': (df['ym'].cumsum()- df['ym']/2 + df['yv']/2)/df['ym'].sum(), })
        ).reset_index()
    else: # Regular marimekko
        # Stack the y categories of every x column: y2 is the cumulative share, y1 the start
        ndata = data.groupby(outer_factors+[xcol],observed=False)[[ycol,value_col,'w']].apply(
            lambda df: pd.DataFrame({ ycol: df[ycol], 'w': df['w'].sum(),
                                      'yv': df['w']/df['w'].sum(), 
                                      'y2': df['w'].cumsum()/df['w'].sum()})
        ).reset_index()
        ndata['y1'] = ndata['y2']-ndata['yv']

    # At this point 'w' holds the total weight of the x column, so within each y category
    # the cumulative share over the x columns gives the horizontal extent
    ndata = ndata.groupby(outer_factors+[ycol],observed=False)[[xcol,'yv','y1','y2','w']].apply(
        lambda df: pd.DataFrame({ xcol: df[xcol], 'xv': df['w']/df['w'].sum(), 'x2': df['w'].cumsum()/df['w'].sum(),
                                  'yv':df['yv'], 'y1':df['y1'], 'y2':df['y2']})
    ).reset_index()
    ndata['x1'] = ndata['x2']-ndata['xv']


    #selection = alt.selection_point(fields=[yvar], bind="legend")
    STROKE = 0.25
    plot = alt.Chart(ndata).mark_rect(
            strokeWidth=STROKE,
            stroke="white",
            xOffset=STROKE / 2,
            x2Offset=STROKE / 2,
            yOffset=STROKE / 2,
            y2Offset=STROKE / 2,
        ).encode(
            x=alt.X(
                "x1:Q",
                axis=alt.Axis(
                    zindex=1, format="%", title=[f"{xcol} (% of total)", " "], grid=False
                ),
                scale=alt.Scale(domain=[0, 1]),
            ),
            x2="x2:Q",
            y=alt.Y(
                "y1:Q",
                axis=alt.Axis(
                    zindex=1, format="%", title=f"{ycol} (% of total)", grid=False, labels=not separate
                    ),
                scale=alt.Scale(domain=[0, 1])
            ),
            y2="y2:Q",
            color=alt.Color(
                f"{ycol}:N",
                legend=alt.Legend(title=None, symbolStrokeWidth=0), #title=f"{yvar}"),
                scale=ycol_scale,
            ),
            tooltip=[
                alt.Tooltip("yv:Q", title=f'{ycol} proportion', format='.1%' ),
                alt.Tooltip("xv:Q", title=f'{xcol} proportion', format='.1%' ),
            ]+tooltip[1:],
            #opacity=alt.condition(selection, alt.value(1), alt.value(0.3)),
        )
        #.add_params(selection)

    return plot

In [None]:
# Example: run the `marimekko` plot on `trust` with `party_preference` and `question`
# as the factor columns. Uncomment the `plot_args` line to try the `separate` layout.
e2e_plot({
    'res_col' : 'trust',
    'factor_cols': ['party_preference','question'],
    'filter': {},
    #'plot_args': {'separate':True},
    'plot': 'marimekko',
    'internal_facet': True,
},data_file)

In [None]:
# Test convert_res: `age_group` is requested as 'continuous' and drawn with the `boxplots` plot, by `gender`
e2e_plot({
    'res_col' : 'age_group',
    'factor_cols': ['gender'],
    'filter': {},
    'plot': 'boxplots',
    'internal_facet': True,
    'convert_res': 'continuous'
},data_file)

In [None]:
# Test replace_question and num_values
# `valitsus` is converted to continuous using the explicit `num_values` list, labelled via
# `value_name` / `value_format`, and drawn with `lines_hdi` (non-lazy, width 800).
# NOTE(review): no `replace_question` key is set in this description - the title above may be stale.
e2e_plot({
    'res_col' : 'valitsus',
    'factor_cols': ['party_preference','age_group'],
    'filter': {},
    'convert_res': 'continuous',
    'num_values': [0,0,0,1,1],
    'value_name': 'Pr[valitsus==Agree]',
    'value_format': '.0%',
    'plot': 'lines_hdi',
    'internal_facet': True,
}, data_file, width=800,lazy=False)

In [None]:
# Test labels: `likert_bars` on `trust` by `gender` and `age_group`, with explicitly supplied metadata
# NOTE(review): this rebinds the notebook-level `data_file` (and `data_meta`), so re-running any
# earlier cell afterwards will use this file instead of the one set at the top of the notebook.
data_file = '../samples/m_bootstrap.parquet'
# NOTE(review): this path points two directories up into `salk_internal_package` - confirm it is
# available wherever this notebook is expected to run.
data_metafile = '../../salk_internal_package/data/master_meta.json'
if data_metafile: # always true for the non-empty literal above; otherwise `data_meta` from the earlier cell would be reused
    data_meta = read_json(data_metafile)
e2e_plot({
    'res_col' : 'trust',
    'factor_cols': ['gender','age_group'],
    'filter': {},
    'plot': 'likert_bars',
    'internal_facet': True
}, data_file, width=800, data_meta=data_meta)

In [None]:
#| hide
# Run the nbdev export step when this notebook is executed
import nbdev; nbdev.nbdev_export()