# Plot Pipeline
> Pipeline from a raw survey data file to the finished plot
> Built around the plot registry from plots.py

In [None]:
#| default_exp pp

In [None]:
#| hide
from nbdev.showdoc import *
from fastcore.test import *

In [None]:
#| exporti
import json, os
import itertools as it
from collections import defaultdict

import numpy as np
import pandas as pd
import polars as pl
import datetime as dt
import scipy.stats as sps

from typing import List, Tuple, Dict, Union, Optional

import altair as alt

from salk_toolkit.plots import stk_plot, stk_deregister, matching_plots, get_plot_fn, get_plot_meta
from salk_toolkit.utils import *
from salk_toolkit.io import load_parquet_with_metadata, extract_column_meta, group_columns_dict, list_aliases, read_annotated_data, read_json

In [None]:
alt.data_transformers.disable_max_rows() # Lift altair's default max-rows limit so larger datasets can be embedded in charts

DataTransformerRegistry.enable('default')

In [None]:
# Simple args value for testing individual functions
# A plot description: age_group as the result column, EKRE as the factor.
# Filters: tuples are inclusive ranges (EKRE range, age_group category range), lists are sets of allowed values.
args = {
    'res_col' : 'age_group',
    'factor_cols': ['EKRE'],
    'filter': { 'nationality': 'Estonian', 'EKRE': (-3,2), 'age_group': ('35-44', '75+'), 'party_preference': ['SDE','EKRE','Reformierakond'] },
    'plot': 'boxplots',
    'internal_facet': False
}

data_uri = '../../salk_internal_package/samples/bootstrap.parquet'

# Load a basic bootstrapped dataset (eagerly, as a pandas DataFrame) together with its stored metadata
full_df, f_meta = load_parquet_with_metadata(data_uri)
data_meta = f_meta['data']

In [None]:
# Read metafile directly - allows faster iteration
# This replaces the data_meta taken from the parquet file above; set data_metafile to a falsy value to keep that one
data_metafile = '../data/master_meta.json'
if data_metafile:
    from salk_toolkit.utils import replace_constants # NOTE(review): not used in this cell
    data_meta = read_json(data_metafile)

In [None]:
#| exporti

# Augment each draw with bootstrap data from across whole population to make sure there are at least <threshold> samples
def augment_draws(data, factors=None, n_draws=None, threshold=50):
    """Pad every draw with rows resampled from `data` so each draw has at least `threshold` rows.

    Parameters
    ----------
    data : pd.DataFrame
        Must contain an integer ``draw`` column.
    factors : list of str, optional
        If given, the padding is done separately within each combination of these columns,
        resampling only from rows of that combination.
    n_draws : int, optional
        Number of draws expected (draws ``0 .. n_draws-1``). Defaults to ``data.draw.max()+1``.
        Draws that are missing entirely get ``threshold`` resampled rows.
    threshold : int
        Minimum number of rows per draw (per factor combination, if `factors` is given).

    Returns
    -------
    pd.DataFrame
        The original rows followed by the resampled ones; their index values may repeat.
        With `factors`, the index is reset. Empty input is returned unchanged, as there is
        nothing to resample from.
    """
    if len(data) == 0: return data # np.random.choice cannot sample from an empty frame
    if n_draws is None: n_draws = data.draw.max()+1

    if factors: # Run recursively on each factor combination separately and concatenate results
        # Fast path for large datasets: every (draw, factor combination) cell exists and is big enough.
        # value_counts only lists observed cells, so also check that no cell is missing entirely.
        cell_counts = data[ ['draw']+factors ].value_counts()
        n_combos = len(data[factors].value_counts())
        if len(cell_counts) == n_draws*n_combos and cell_counts.min() >= threshold: return data
        # Slow-ish, but only needed on small data. Iterating the groups (instead of groupby.apply)
        # keeps the factor columns in every part regardless of pandas version.
        parts = [ augment_draws(grp, n_draws=n_draws, threshold=threshold)
                  for _, grp in data.groupby(factors, observed=False) ]
        if not parts: return data.iloc[:0].reset_index(drop=True)
        return pd.concat(parts).reset_index(drop=True)

    # Count rows per draw
    draw_counts = data['draw'].value_counts()
    if len(draw_counts)<n_draws: # Fill in completely missing draws with a count of 0
        draw_counts = (draw_counts + pd.Series(0,index=range(n_draws))).fillna(0).astype(int)

    # If no new rows are needed, just return the original
    if draw_counts.min()>=threshold: return data

    # One entry per row to add, holding the draw it should be assigned to
    new_draws = [ d for d,c in draw_counts[draw_counts<threshold].items() for _ in range(threshold-c) ]

    # Resample (with replacement) from all rows and relabel them with the target draws
    new_rows = data.iloc[np.random.choice(len(data),len(new_draws)),:].copy()
    new_rows['draw'] = new_draws

    return pd.concat([data, new_rows])

In [None]:
#| exporti

# Get the categories that are in use
def get_cats(col, cats=None):
    """Return the categories of the categorical Series `col` that actually occur in it.

    `cats` gives the preferred order. It is ignored, falling back to the dtype's own category
    order, when it is None or does not cover every category of `col`.
    """
    if cats is None or len(set(col.dtype.categories)-set(cats))>0: cats = col.dtype.categories
    present = set(col.unique()) # Compute once instead of once per category
    return [ c for c in cats if c in present ]

def transform_cont(data, transform):
    """Optionally transform continuous values.

    transform: falsy -> `data` is returned as is; 'center' -> the mean (ignoring NaN) is
    subtracted; 'zscore' -> values are standardised with `scipy.stats.zscore`, ignoring NaN.
    Any other value raises an Exception.
    """
    if transform:
        if transform == 'center':
            return data - data.mean(skipna=True)
        if transform == 'zscore':
            return sps.zscore(data, nan_policy='omit')
        raise Exception(f"Unknown transform '{transform}'")
    return data

In [None]:
#| export

# Get all data required for a given graph
# Only return columns and rows that are needed
# This can handle either a pandas DataFrame or a polars LazyFrame (to allow for loading only needed data)
def get_filtered_data(full_df, data_meta, pp_desc, columns=None):
    """Select, filter and reshape the data needed for the plot described by `pp_desc`.

    Parameters
    ----------
    full_df : pd.DataFrame or pl.LazyFrame
        The full dataset. With a LazyFrame only the needed columns and rows are collected.
    data_meta : dict
        Data annotation; column metadata and question groups are read from it.
    pp_desc : dict
        Plot description. Keys read here: 'plot', 'res_col', 'factor_cols', 'filter',
        'poststrat', 'convert_res', 'cont_transform' (plus whatever `wrangle_data` reads).
        'filter' maps column -> value:
        - tuple on a continuous/datetime column: inclusive range (missing values kept);
        - tuple on a categorical column: inclusive range in the metadata category order;
        - list: set of allowed values;
        - name of a group in the column's metadata 'groups': that group's values;
        - anything else: a single allowed value.
    columns : list of str, optional
        Extra columns to keep if present in `full_df`.

    Returns
    -------
    dict
        Plot parameters from `wrangle_data`, plus 'n_datapoints' - the number of rows left
        after filtering (before question groups are melted into long format).

    Raises
    ------
    Exception
        If a tuple filter is used on a categorical column whose category order is unknown,
        or whose range endpoints are not among its categories.
    """
    columns = [] if columns is None else list(columns)

    # Figure out which columns we actually need
    meta_cols = ['weight', 'training_subsample', '__index_level_0__'] + (['draw'] if vod(get_plot_meta(pp_desc['plot']),'draws') else []) + columns
    cols = [ pp_desc['res_col'] ]  + vod(pp_desc,'factor_cols',[]) + list(vod(pp_desc,'filter',{}).keys())
    cols += [ c for c in meta_cols if c in full_df.columns and c not in cols ]

    # If any aliases are used, convert them to column names according to the data_meta
    gc_dict = group_columns_dict(data_meta)
    c_meta = extract_column_meta(data_meta)

    cols = [ c for c in np.unique(list_aliases(cols,gc_dict)) if c in full_df.columns ]

    lazy = isinstance(full_df,pl.LazyFrame)
    if lazy: pl.enable_string_cache() # Needed for categories to be comparable to strings
    try:
        df = full_df.select(cols) if lazy else full_df[cols]

        # Combine all filters into one boolean mask (a polars expression when lazy)
        filter_dict = vod(pp_desc,'filter',{})
        inds = pl.lit(True) if lazy else np.full(len(df),True)
        for k, v in filter_dict.items():

            # Continuous/datetime columns: a tuple is an inclusive range; missing values are kept
            if isinstance(v,tuple) and (vod(c_meta[k],'continuous') or vod(c_meta[k],'datetime')):
                if lazy: inds = (((pl.col(k)>=v[0]) & (pl.col(k)<=v[1])) | pl.col(k).is_null()) & inds
                else: inds = (((df[k]>=v[0]) & (df[k]<=v[1])) | df[k].isna()) & inds
                continue # NB! range comparison does not work for ordered categoricals in a polars LazyFrame, hence those are turned into value lists below

            # Everything else becomes a list of allowed values
            if isinstance(v,tuple): # Inclusive range over the category order given in the metadata
                if vod(c_meta[k],'categories','infer')=='infer': raise Exception(f'Ordering unknown for column {k}')
                cats = list(c_meta[k]['categories'])
                if set(v) & set(cats) != set(v): raise Exception(f'Column {k} values {v} not found in {cats}')
                bi, ei = cats.index(v[0]), cats.index(v[1])
                flst = cats[bi:ei+1]
            elif isinstance(v,list): flst = v # List indicates a set of values
            elif 'groups' in c_meta[k] and v in c_meta[k]['groups']: # Named group of values from the metadata
                flst = c_meta[k]['groups'][v]
            else: flst = [v] # Just filter on single value

            inds = (pl.col(k).is_in(flst) if lazy else df[k].isin(flst)) & inds

        filtered_df = df.filter(inds).collect().to_pandas() if lazy else df[inds].copy()
        if lazy and '__index_level_0__' in filtered_df.columns: # Fix index, if provided. This is a hack but seems to be needed as polars does not handle index properly by default
            filtered_df.index = filtered_df['__index_level_0__']

        # Replace draw with the draws used in modelling - NB! does not currently work for group questions
        if 'draw' in filtered_df.columns and pp_desc['res_col'] in vod(data_meta,'draws_data',{}):
            uid, ndraws = data_meta['draws_data'][pp_desc['res_col']]
            filtered_df = deterministic_draws(filtered_df, ndraws, uid, n_total = data_meta['total_size'] )

        # If not poststratified: drop the weights and keep only the training subsample (if available)
        if not vod(pp_desc,'poststrat',True):
            filtered_df = filtered_df.assign(weight = 1.0) # Remove weighting
            if 'training_subsample' in filtered_df.columns:
                filtered_df = filtered_df[filtered_df['training_subsample']]

        n_datapoints = len(filtered_df)

        # Convert ordered categorical to continuous if requested and the category order is known
        res_meta = c_meta[pp_desc['res_col']]
        if vod(pp_desc,'convert_res') == 'continuous' and vod(res_meta,'ordered') and vod(res_meta,'categories','infer') != 'infer':
            cmap = dict(zip(res_meta['categories'],vod(res_meta,'num_values',range(len(res_meta['categories'])))))
            rc = gc_dict[pp_desc['res_col']] if pp_desc['res_col'] in gc_dict else [pp_desc['res_col']]
            for col in rc:
                if col not in filtered_df.columns: continue # Group members missing from the data were not selected above
                filtered_df[col] = pd.to_numeric(filtered_df[col].replace(cmap))

        # If res_col is a group of questions, melt it into long format with a 'question' column
        # This might move to wrangle but currently easier to do here as we have gc_dict handy
        if pp_desc['res_col'] in gc_dict:
            value_vars = [ c for c in gc_dict[pp_desc['res_col']] if c in cols ]

            if filtered_df[value_vars[0]].dtype.name != 'category':
                # Plain column assignment so the transformed (float) values replace the old dtype
                filtered_df[value_vars] = filtered_df[value_vars].apply(transform_cont,axis=0,transform=vod(pp_desc,'cont_transform'))

            id_vars = [ c for c in cols if c not in value_vars ]
            filtered_df = filtered_df.melt(id_vars=id_vars, value_vars=value_vars, var_name='question', value_name=pp_desc['res_col'])
            filtered_df['question'] = pd.Categorical(filtered_df['question'],gc_dict[pp_desc['res_col']])
        elif filtered_df[pp_desc['res_col']].dtype.name != 'category':
            filtered_df[pp_desc['res_col']] = transform_cont(filtered_df[pp_desc['res_col']],transform=vod(pp_desc,'cont_transform'))

        # Filter out the unused categories so plots are cleaner
        for k in filtered_df.columns:
            if filtered_df[k].dtype.name == 'category':
                k_meta = c_meta[k] if k in c_meta else {} # e.g. 'question' has no column metadata of its own
                m_cats = k_meta['categories'] if vod(k_meta,'categories','infer')!='infer' else None
                f_cats = get_cats(filtered_df[k],m_cats) if k != pp_desc['res_col'] or not vod(k_meta,'likert') else m_cats # Do not trim likert as plots need to be symmetric
                # Assign the whole column (not .loc[:,k], which sets in place on pandas>=2) so the trimmed dtype is used
                filtered_df[k] = pd.Categorical(filtered_df[k],f_cats,ordered=vod(k_meta,'ordered',False))

        # Aggregate the data into right shape
        pparams = wrangle_data(filtered_df, data_meta, pp_desc)

        # How many datapoints the plot is based on. This is useful metainfo to display sometimes
        pparams['n_datapoints'] = n_datapoints
    finally:
        if lazy: pl.disable_string_cache() # Always reset the global cache, even if filtering failed

    return pparams

In [None]:
#| exporti

# Groupby if needed - this simplifies the wrangle considerably :)
def gb_in(df, gb_cols):
    """Return `df` grouped by `gb_cols` (keeping unobserved categories), or `df` itself if `gb_cols` is empty."""
    if not len(gb_cols):
        return df
    return df.groupby(gb_cols, observed=False)

def discretize_continuous(col, col_meta={}):
    """Bin a continuous column into an ordered categorical with string labels.

    The bins come from `col_meta['bins']` (default 5) and the labels from the optional
    `col_meta['bin_labels']`. Labels are converted to strings, because interval objects
    do not play nice with altair.
    """
    # NB! qcut might be a better default - see where testing leads us
    binned = pd.cut(col, bins=vod(col_meta, 'bins', 5), labels=vod(col_meta, 'bin_labels', None))
    label_order = [ str(cat) for cat in binned.dtype.categories ]
    return pd.Categorical(binned.astype(str), categories=label_order, ordered=True)

# Helper function that handles reformatting data for create_plot
def wrangle_data(raw_df, data_meta, pp_desc):
    """Aggregate filtered data into the shape expected by the plot named in `pp_desc['plot']`.

    The plot's metadata ('data_format') decides the shape:
    * 'raw': rows are passed through (or resampled per group when the plot sets 'sample');
      the value column is `res_col`.
    * 'longform': a categorical `res_col` becomes weighted shares per category (weighted sums
      if the plot's agg_fn is 'sum') in a 'percent' column. A continuous `res_col` is
      aggregated with `agg_fn` (pp_desc value, default 'mean'; a plot-level agg_fn overrides it).
    Grouping is by 'draw' (if the plot uses draws), the factor columns and 'question'
    (if present). All remaining non-value columns of the result are made categorical.

    Parameters
    ----------
    raw_df : pd.DataFrame
        Filtered data, e.g. as prepared by `get_filtered_data`. Not modified.
    data_meta : dict
        Data annotation (column metadata).
    pp_desc : dict
        Plot description ('plot', 'res_col', 'factor_cols', optional 'augment_to', 'agg_fn').

    Returns
    -------
    dict
        'data' (the wrangled DataFrame), 'value_col' and, for categorical longform data, 'cat_col'.

    Raises
    ------
    Exception
        If the plot's data_format is not supported.
    """
    plot_meta = get_plot_meta(pp_desc['plot'])
    col_meta = extract_column_meta(data_meta)

    res_col, factor_cols = vod(pp_desc,'res_col'), vod(pp_desc,'factor_cols')

    draws, data_format = (vod(plot_meta, n, False) for n in ['draws','data_format'])

    gb_dims = (['draw'] if draws else []) + (factor_cols if factor_cols else []) + (['question'] if 'question' in raw_df.columns else [])

    # Missing weights count as 1. assign() returns a new frame, so the caller's data is left untouched
    if 'weight' not in raw_df.columns: raw_df = raw_df.assign(weight=1.0) # This also works for empty df-s
    else: raw_df = raw_df.assign(weight=raw_df['weight'].fillna(1.0))

    if draws and 'draw' in raw_df.columns and 'augment_to' in pp_desc: # Should we try to bootstrap the data to always have augment_to points. Note this is relatively slow
        raw_df = augment_draws(raw_df,gb_dims[1:],threshold=pp_desc['augment_to'])

    pparams = { 'value_col': 'value' }
    data = None

    if data_format=='raw':
        pparams['value_col'] = res_col
        if vod(plot_meta,'sample'): # Resample a fixed number of rows per group
            data = gb_in(raw_df[gb_dims+[res_col]],gb_dims).sample(plot_meta['sample'],replace=True)
        else: data = raw_df[gb_dims+[res_col]].copy() # Own copy, as columns are converted below

    elif False and data_format=='table': # TODO: Untested. Fix when first needed
        ddf = pd.get_dummies(raw_df[res_col])
        res_cols = list(ddf.columns)
        ddf.loc[:,gb_dims] = raw_df[gb_dims]
        data = gb_in(ddf,gb_dims)[res_cols].mean().reset_index()

    elif data_format=='longform':
        if raw_df[res_col].dtype == 'category': # categorical
            pparams['cat_col'] = res_col
            pparams['value_col'] = 'percent'

            # Aggregate the data: weighted counts per category, normalised within each group unless summing
            data = raw_df.groupby(gb_dims+[res_col],observed=False)['weight'].sum()
            if vod(plot_meta,'agg_fn')!='sum': data /= gb_in(raw_df,gb_dims)['weight'].sum()
            data = data.rename(pparams['value_col']).dropna().reset_index()

        else: # Continuous
            agg_fn = vod(pp_desc,'agg_fn','mean') # We may want to try median vs mean or plot sd-s or whatever
            agg_fn = vod(plot_meta,'agg_fn',agg_fn) # Some plots mandate this value (election model for instance)
            if gb_dims:
                data = getattr(gb_in(raw_df,gb_dims)[res_col],agg_fn)().dropna().reset_index()
            else: # Without grouping the aggregate is a scalar - wrap it in a one-row frame
                data = pd.DataFrame({ res_col: [getattr(raw_df[res_col],agg_fn)()] }).dropna()
            pparams['value_col'] = res_col

        if vod(plot_meta,'group_sizes'):
            if gb_dims:
                data = data.merge(gb_in(raw_df,gb_dims).size().rename('group_size').reset_index(),on=gb_dims,how='left')
            else: data = data.assign(group_size=len(raw_df)) # Single group: all rows
    else:
        raise Exception("Unknown data_format")

    # Ensure all rv columns other than value are categorical
    for c in data.columns:
        if c in ['group_size']: continue # bypass some columns added above
        if data[c].dtype.name != 'category' and c!=pparams['value_col']:
            c_info = vod(col_meta,c,{})
            # Whole-column assignment: .loc[:,c] would set in place on pandas>=2 and keep the old dtype
            if vod(c_info,'continuous'):
                data[c] = discretize_continuous(data[c],c_info)
            else: # Just assume it's categorical by any other name
                data[c] = pd.Categorical(data[c])

    pparams['data'] = data
    return pparams

In [None]:
# Run filtering + wrangling on the test args and look at a few rows of the resulting plot data
pparams = get_filtered_data(full_df, data_meta, args)
fdf = pparams['data']
fdf.sample(5)

Unnamed: 0,draw,EKRE,age_group,percent
799,26,"(-1.0, 0.0]",75+,0.291667
859,28,"(-1.0, 0.0]",75+,0.157895
933,31,"(-3.005, -2.0]",65-74,0.3
1181,39,"(-2.0, -1.0]",45-54,0.3
931,31,"(-3.005, -2.0]",45-54,0.3


In [None]:
#| exporti

# Create a color scale
ordered_gradient = ["#c30d24", "#f3a583", "#94c6da", "#1770ab"] # Default gradient for ordered categoricals without explicit colors
def meta_color_scale(cmeta,argname='colors',column=None, translate=None):
    """Build an altair color scale for `column` from the column metadata.

    Parameters
    ----------
    cmeta : dict
        Column metadata; `cmeta[argname]`, if present, is used as the color mapping.
    argname : str
        Metadata key holding the colors (e.g. 'colors' or 'question_colors').
    column : pd.Series, optional
        The data column. If it is an ordered categorical and the metadata gives no colors,
        colors are interpolated along `ordered_gradient`.
    translate : callable, optional
        Applied to the category names and to the keys of the color mapping, so the scale
        matches translated data.

    Returns
    -------
    The result of `to_alt_scale(scale, cats)`.
    """
    scale = vod(cmeta,argname)
    is_cat = column is not None and column.dtype.name=='category' # column may be omitted
    cats = column.dtype.categories if is_cat else None
    if scale is None and is_cat and column.dtype.ordered:
        scale = dict(zip(cats,gradient_to_discrete_color_scale(ordered_gradient, len(cats))))
    if translate and cats is not None:
        remap = { c: translate(c) for c in cats }
        if scale: scale = { remap.get(k,k): v for k,v in scale.items() }
        cats = [ remap[c] for c in cats ]
    return to_alt_scale(scale,cats)

In [None]:
#| export
def translate_df(df, translate):
    """Translate the column names and categorical values of `df` using `translate`.

    Parameters
    ----------
    df : pd.DataFrame
        Modified in place (columns renamed, categorical columns relabelled) and also returned.
    translate : callable
        Maps an original string to its translation.
    """
    df.columns = [ translate(c) for c in df.columns ]
    for col in df.columns:
        if df[col].dtype.name != 'category': continue
        remap = { cat: translate(cat) for cat in df[col].dtype.categories }
        if len(set(remap.values())) == len(remap):
            # rename_categories keeps the categorical dtype, order and ordered flag;
            # Series.replace on categoricals is deprecated for changing categories in pandas 2.x
            df[col] = df[col].cat.rename_categories(remap)
        else: # Several categories translate to the same label - merge them as before
            df[col] = df[col].replace(remap)
    return df

In [None]:
#| export

# Function that takes filtered raw data and plot information and outputs the plot
# Handles all of the data wrangling and parameter formatting
def create_plot(pparams, data_meta, pp_desc, alt_properties=None, alt_wrapper=None, dry_run=False, width=200, return_matrix_of_plots=False, translate=None):
    """Turn wrangled plot parameters into a chart using the registered plot function.

    Parameters
    ----------
    pparams : dict
        Output of `get_filtered_data`/`wrangle_data`; must contain 'data' and 'value_col'.
        It is updated in place with the derived plot arguments.
    data_meta : dict
        Data annotation; used for colors and other column metadata.
    pp_desc : dict
        Plot description ('plot', 'res_col', 'factor_cols', 'internal_facet', 'plot_args').
    alt_properties : dict, optional
        Extra properties passed to altair's `.properties(...)`.
    alt_wrapper : callable, optional
        Applied to each finished chart (identity by default).
    dry_run : bool
        If True, return the prepared `pparams` instead of drawing anything.
    width : int
        Total width, split evenly between the facet columns.
    return_matrix_of_plots : bool
        Return a list of rows of separate charts instead of one faceted chart.
    translate : callable, optional
        Translates column names, categories and labels. NB! the data frame is translated in place.

    Returns
    -------
    An altair chart, a list of rows of charts (`return_matrix_of_plots`), the return value of
    an 'as_is' plot function, or `pparams` when `dry_run` is set.
    """
    if alt_properties is None: alt_properties = {} # Fresh dict per call - it is handed on to the plot function

    data = pparams['data']

    plot_meta = get_plot_meta(pp_desc['plot'])
    col_meta = extract_column_meta(data_meta)
    plot_args = vod(pp_desc,'plot_args',{}) # Also needed below when choosing the number of facet columns

    if 'plot_args' in pp_desc: pparams.update(pp_desc['plot_args'])
    pparams['color_scale'] = meta_color_scale(col_meta[pp_desc['res_col']],'colors',data[pp_desc['res_col']],translate=translate)
    if data[pp_desc['res_col']].dtype.name=='category':
        pparams['cat_order'] = list(data[pp_desc['res_col']].dtype.categories)

    pparams['val_format'] = '.1%' if pparams['value_col'] == 'percent' else '.1f'

    # Handle factor columns
    factor_cols = vod(pp_desc,'factor_cols',[])

    # If we have a question column not handled by the plot, add it to factors:
    pparams['question_col'] = 'question'
    if 'question' in data.columns and not vod(plot_meta,'question'):
        factor_cols = factor_cols + ['question']
    # If we don't have a question column but need it, just fill it with res_col name
    elif 'question' not in data.columns and vod(plot_meta,'question'):
        data['question'] = pd.Categorical([pp_desc['res_col']]*len(pparams['data']))

    if vod(plot_meta,'question'):
        pparams['question_color_scale'] = meta_color_scale(col_meta[pp_desc['res_col']],'question_colors',data['question'],translate=translate)
        pparams['question_order'] = list(data['question'].dtype.categories)

    # For continuous plots, the categorical result column becomes an extra factor
    # (placed right after the internal facet, if one is used)
    if vod(plot_meta,'continuous') and 'cat_col' in pparams:
        to_ind = 1 if len(factor_cols)>0 and vod(pp_desc,'internal_facet') else 0
        factor_cols = factor_cols.copy()
        factor_cols.insert(to_ind,pparams['cat_col'])

    # Use the first factor as an internal facet (handled by the plot function, with its own color scale) if requested
    if factor_cols and vod(pp_desc,'internal_facet'):
        pparams['factor_col'] = factor_cols[0]
        if factor_cols[0] == 'question':
            pparams['factor_color_scale'] = meta_color_scale(col_meta[pp_desc['res_col']],'question_colors',data['question'],translate=translate)
        else:
            pparams['factor_color_scale'] = meta_color_scale(col_meta[factor_cols[0]],'colors',data[factor_cols[0]],translate=translate)
        pparams['factor_order'] = list(data[factor_cols[0]].dtype.categories)
        factor_cols = factor_cols[1:] # Leave rest for external faceting
        if 'factor_meta' in plot_meta:
            for kw in plot_meta['factor_meta']: pparams[kw] = vod(col_meta[pparams['factor_col']],kw)

    # Handle translations
    if translate:
        # Translate data - column names, categorical columns
        pparams['data'] = data = translate_df(data,translate)

        # Translate parameters holding column names or category lists - strings directly, lists elementwise
        translate_list = ['res_col','value_col','factor_col', 'question_col', 'cat_order', 'factor_order', 'question_order']
        for k, v in pparams.items():
            if k not in translate_list: continue
            if isinstance(v,str): pparams[k] = translate(v)
            else: pparams[k] = [ translate(c) for c in v ]

        # Translate facets too
        factor_cols = [ translate(c) for c in factor_cols ]

        # Pass translation function to plot in case it has new strings
        pparams['translate'] = translate
    else:
        pparams['translate'] = lambda s: s # Make sure function is given even if it does nothing

    # If we still have more than 1 factor - merge them into one (unless a matrix of plots is requested)
    if len(factor_cols)>1:
        n_facet_cols = len(data[factor_cols[-1]].dtype.categories) # One facet column per value of the last factor
        if not return_matrix_of_plots:
            factor_col = '+'.join(factor_cols)
            data[factor_col] = data[factor_cols].astype(str).agg(', '.join, axis=1) # astype(str): categories need not be strings
            pparams['data'] = data
            factor_cols = [factor_col]
    else:
        n_facet_cols = vod(plot_meta,'factor_columns',1)

    # Create the plot using its function
    if dry_run: return pparams

    if factor_cols: n_facet_cols = vod(plot_args,'n_facet_cols',n_facet_cols) # Allow plot_args to override col nr
    dims = {'width': width//n_facet_cols if factor_cols else width}
    if 'aspect_ratio' in plot_meta: dims['height'] = int(dims['width']/plot_meta['aspect_ratio'])

    # Make plot properties available to plot function (mostly useful for as_is plots)
    pparams.update(dims)
    pparams['alt_properties'] = alt_properties
    pparams['outer_factors'] = factor_cols

    # Trim down parameters list to what the plot function accepts
    plot_fn = get_plot_fn(pp_desc['plot'])
    pparams = clean_kwargs(plot_fn,pparams)

    if alt_wrapper is None: alt_wrapper = lambda p: p
    if vod(plot_meta,'as_is'): # if as_is set, just return the plot as-is
        return plot_fn(**pparams)
    elif factor_cols:
        if return_matrix_of_plots: # One separately drawn chart per combination of factor values
            pparams.pop('data',None) # Data is passed positionally per combination below
            combs = it.product( *[data[fc].dtype.categories for fc in factor_cols ])
            return list(batch([
                alt_wrapper(plot_fn(data[(data[factor_cols]==c).all(axis=1)],**pparams).properties(title='-'.join(map(str,c)),**dims, **alt_properties))
                for c in combs
                ], n_facet_cols))
        else: # Use faceting:
            plot = alt_wrapper(plot_fn(**pparams).properties(**dims, **alt_properties).facet(f'{factor_cols[0]}:O',columns=n_facet_cols))
    else:
        plot = alt_wrapper(plot_fn(**pparams).properties(**dims, **alt_properties))
        if return_matrix_of_plots: plot = [[plot]]

    return plot

In [None]:
# Example plot description: thermometer per party_preference among Estonians, party_preference as internal facet
pp_desc = {
    'res_col' : 'thermometer',
    'factor_cols': ['party_preference'],
    'filter': { 'nationality': 'Estonian' },
    'plot': 'matrix-cont',
    'internal_facet': True
}

In [None]:

# NB! get_filtered_data returns the plot parameter dict (data under 'data'), not a DataFrame
fdf = get_filtered_data(full_df, data_meta, pp_desc)
#wdf = wrangle_data(fdf, **args, **get_plot_meta(args['plot']))
#fdf['data'].sample(5)
create_plot(fdf,data_meta,pp_desc,width=800)

In [None]:
#| export

# A convenience function to draw a plot straight from a dataset
def e2e_plot(pp_desc, data_file=None, full_df=None, data_meta=None, width=800, check_match=True,lazy=True,**kwargs):
    """Load data if needed, then filter, wrangle and draw the plot described by `pp_desc`.

    Parameters
    ----------
    pp_desc : dict
        Plot description; 'plot' must name a registered plot.
    data_file : str, optional
        Path to the data. '.parquet' files are read with `load_parquet_with_metadata`
        (lazily if `lazy`), anything else with `read_annotated_data`.
    full_df : DataFrame or polars LazyFrame, optional
        Already loaded data; takes precedence over `data_file`.
    data_meta : dict, optional
        Data annotation. Required together with `full_df`; otherwise it defaults to the
        metadata loaded with `data_file`.
    width : int
        Plot width, passed to `create_plot`.
    check_match : bool
        If True (default), first check with `matching_plots` that the plot is registered and
        applicable to this data; set to False to skip the check.
    lazy : bool
        Lazy-load parquet files so only the needed data is read from disk.
    **kwargs
        Passed on to `create_plot`.

    Raises
    ------
    Exception
        If no data is given, if `full_df` is given without `data_meta`, or (with `check_match`)
        if the plot is not registered or not applicable.
    """
    if data_file is None and full_df is None:
        raise Exception('Data must be provided either as data_file or full_df')
    if data_file is None and data_meta is None:
        raise Exception('If data provided as full_df then data_meta must also be given')

    if full_df is None:
        if data_file.endswith('.parquet'): # Try lazy loading as it only loads what it needs from disk
            full_df, full_meta = load_parquet_with_metadata(data_file,lazy=lazy)
            dm = full_meta['data']
        else: full_df, dm = read_annotated_data(data_file)
        if data_meta is None: data_meta = dm

    if check_match: # Fail early with a clear message if the plot cannot be drawn for this data
        matches = matching_plots(pp_desc, full_df, data_meta, details=True, list_hidden=True)

        if pp_desc['plot'] not in matches:
            raise Exception(f"Plot not registered: {pp_desc['plot']}")

        fit, imp = matches[pp_desc['plot']]
        if fit<0:
            raise Exception(f"Plot {pp_desc['plot']} not applicable in this situation because of flags {imp}")

    pparams = get_filtered_data(full_df, data_meta, pp_desc)
    return create_plot(pparams, data_meta, pp_desc, width=width,**kwargs)

# Another convenience function to simplify testing new plots
def test_new_plot(fn, pp_desc, *args, plot_meta=None, **kwargs):
    """Temporarily register `fn` as plot 'test', draw it via `e2e_plot`, then deregister it.

    Parameters
    ----------
    fn : callable
        The plot function to try out.
    pp_desc : dict
        Plot description; its 'plot' entry is replaced with 'test'.
    *args, **kwargs
        Passed on to `e2e_plot`.
    plot_meta : dict, optional
        Plot metadata passed to `stk_plot` when registering `fn`.
    """
    stk_plot(**{**(plot_meta or {}),'plot_name':'test'})(fn) # Register the plot under name 'test'
    try:
        return e2e_plot({**pp_desc, 'plot': 'test'},*args,**kwargs)
    finally:
        stk_deregister('test') # De-register even if plotting failed, so 'test' is free for the next try

In [None]:
data_file = '../samples/w25_bootstrap.parquet'
# Load metadata separately; when passed to e2e_plot it is used instead of the metadata stored in the data file
data_metafile = '../../salk_internal_package/data/master_meta.json'
if data_metafile:
    data_meta = read_json(data_metafile)

# Small translation table (Estonian) for trying out the translate option
td = { 'unit': 'Üksus', 'Keskerakond':'Kekre', 'education': 'Haridus', 'Basic education':'Põhiharidus' }

def translate(s):
    """Return the translation of `s` from `td`, or `s` unchanged if it has none."""
    return td.get(s, s)
    
# Boxplots of age_group per party preference, with labels passed through the translate function
e2e_plot({
    'res_col' : 'age_group',
    'factor_cols': ['party_preference'],
    'filter': {},
    'plot': 'boxplots',
    'internal_facet': True
}, data_file, data_meta=data_meta,width=800, translate=translate)

{'18-24': '#c30d24', '25-34': '#db5853', '35-44': '#f3a583', '45-54': '#c3b6af', '55-64': '#94c6da', '65-74': '#559ac2', '75+': '#1770ab'} ['18-24', '25-34', '35-44', '45-54', '55-64', '65-74', '75+'] {'18-24': '18-24', '25-34': '25-34', '35-44': '35-44', '45-54': '45-54', '55-64': '55-64', '65-74': '65-74', '75+': '75+'}
{'EKRE': '#8B4513', 'Eesti 200': '#31758A', 'Isamaa': '#009BDF', 'Kekre': '#007557', 'Reformierakond': '#FFE200', 'Rohelised': '#88AF47', 'SDE': '#E10600', 'Parempoolsed': 'orange', 'None of the parties': 'grey', 'No opinion': 'lightgrey', 'Other': 'lightgrey'} ['Kekre', 'EKRE', 'Reformierakond', 'Isamaa', 'SDE', 'Rohelised', 'Eesti 200', 'Parempoolsed', 'No party', 'Other', 'Hard to say'] {'Keskerakond': 'Kekre', 'EKRE': 'EKRE', 'Reformierakond': 'Reformierakond', 'Isamaa': 'Isamaa', 'SDE': 'SDE', 'Rohelised': 'Rohelised', 'Eesti 200': 'Eesti 200', 'Parempoolsed': 'Parempoolsed', 'No party': 'No party', 'Other': 'Other', 'Hard to say': 'Hard to say'}


In [None]:
# Test e2e_plot
# Reads data_uri (a parquet file, so lazily) and uses the metadata stored in it
#alt.data_transformers.disable_max_rows()
pp_desc = {
    'res_col' : 'income',
    #'factor_cols': ['gender'],
    'filter': { 'nationality': 'Estonian' },
    'plot': 'boxplots',
    'internal_facet': True
}
e2e_plot(pp_desc,data_uri)

In [None]:
import altair as alt

# Test test_new_plot
def smooth(data, cat_col, value_col='value', color_scale=alt.Undefined, factor_col=None):
    """Smoothed, normalised stacked-area chart of `value_col` along `factor_col`, stacked by `cat_col`.

    `cat_col` must be categorical; its category order defines the stacking order.
    """
    # Stacking order = position of each category. assign() leaves the caller's frame untouched and,
    # unlike replace() on a categorical, is not deprecated in pandas 2.x
    data = data.assign(order=data[cat_col].cat.codes)
    plot=alt.Chart(data
        ).mark_area(interpolate='natural').encode(
            x=alt.X(f'{factor_col}:O', title=None),
            y=alt.Y(f'{value_col}:Q', title=None, stack='normalize',
                 scale=alt.Scale(domain=[0, 1]), axis=alt.Axis(format='%')
                 ),
            order="order:O",
            color=alt.Color(cat_col, legend=alt.Legend(orient='top', title=None),
                sort=alt.SortField("order", "descending"), scale=color_scale
                )
        )
    return plot

# Temporarily registers `smooth` as plot 'test' and draws it (test_new_plot replaces the 'plot' name below)
test_new_plot(smooth, {
    'res_col' : 'party_preference',
    'factor_cols': ['age_group','gender'],  'filter': {},
    'plot': 'area_smooth',
    'internal_facet': True
}, full_df=full_df, data_meta=data_meta, plot_meta={})

In [None]:
#| hide
import nbdev; nbdev.nbdev_export() # Write the #| export / #| exporti cells to the module named by #| default_exp

# Testing

In [None]:
alt.data_transformers.disable_max_rows() # Lift altair's default max-rows limit so larger datasets can be embedded in charts

In [None]:
data_file = '../samples/w25_bootstrap.parquet'
# Load metadata separately; when passed to e2e_plot it is used instead of the metadata stored in the data file
data_metafile = '../../salk_internal_package/data/master_meta.json'
if data_metafile:
    data_meta = read_json(data_metafile)

# Geoplot of party preference per unit, returned as a matrix (list of rows) of separate charts
e2e_plot({
    'res_col' : 'party_preference',
    'factor_cols': ['unit'],
    'filter': {},
    'plot': 'geoplot',
    'internal_facet': True
}, data_file, data_meta=data_meta,width=800, return_matrix_of_plots=True)