In [None]:
import os
import sys

# Make the parent of the current working directory importable so that
# `salk_turnout_models` resolves; appended only once, so re-running this
# cell does not keep growing sys.path.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
import copy
import datetime
import hashlib
import itertools as it
import json
import logging
import pathlib
import shutil
import traceback
import arviz as az
import numpy as np
import pandas as pd
import statsmodels.api as sm
import salk_turnout_models as tm
from time import perf_counter
from tqdm.notebook import tqdm

In [None]:
import warnings

warnings.filterwarnings('ignore', module='arviz', message='invalid value encountered in scalar divide')

In [None]:
# Notebook-wide parent logger: run_model derives a per-model child logger from it
# (via ModelLogger). Records down to DEBUG are let through at this level.
LOGGER = logging.getLogger('synth-data-models')
LOGGER.setLevel('DEBUG')

## Models

In [None]:
def df_margins(pdf, columns, outcome):
    """Joint proportions of `outcome` (crossed with `columns`) in a row-level frame.

    Every count is divided by the total number of rows, so the result sums to 1
    over all (columns..., outcome) combinations when `outcome` has no missing values.

    Args:
        pdf: row-level DataFrame (presumably one row per person — confirm).
        columns: grouping column names; an empty/falsy value means no grouping.
        outcome: name of the outcome column.

    Returns:
        pd.Series named 'proportion', indexed by `columns` + `outcome`.
    """
    n_rows = len(pdf)
    if not columns:
        counts = pdf[outcome].value_counts()
    else:
        # observed=False: for categorical groupers, unobserved combinations are kept too.
        counts = pdf.groupby(columns, observed=False)[outcome].value_counts()
    return (counts / n_rows).rename('proportion')

def df_turnout(pdf, columns, outcome):
    """Distribution of `outcome` within each group of `columns` (e.g. turnout per group).

    Unlike df_margins, counts are normalised by each group's size, so proportions
    sum to 1 within every group (when `outcome` has no missing values). With no
    grouping columns this is simply the overall distribution of `outcome`.

    Args:
        pdf: row-level DataFrame.
        columns: grouping column names; an empty/falsy value means no grouping.
        outcome: name of the outcome column.

    Returns:
        pd.Series named 'proportion', indexed by `columns` + `outcome`.
    """
    if not columns:
        return (pdf[outcome].value_counts() / len(pdf)).rename('proportion')
    grouped_outcome = pdf.groupby(columns, observed=True)[outcome]
    # The per-group size aligns on the grouping level(s) of the value_counts index.
    per_group = grouped_outcome.value_counts() / grouped_outcome.size()
    return per_group.rename('proportion')

def _cell_yes_no(df, cell_cols, margin_cols, outcome, normalize_by_total):
    """Shared core of cell_margins / cell_turnout.

    Averages the modelled count `N` and census count `N_census` per cell
    (collapsing repeated rows of a cell, e.g. the chain/draw rows of model
    output), sums them up to `margin_cols`, and emits one 'Yes' row (N) and one
    'No' row (N_census - N) per margin group.

    If `normalize_by_total` is True, both rows are divided by the total census
    count over all margin groups (joint proportions); otherwise each group is
    divided by its own census count (within-group shares).
    """
    avg_df = df.groupby(cell_cols)[['N', 'N_census']].mean().reset_index()

    if margin_cols:
        grouped_df = avg_df.groupby(margin_cols)[['N', 'N_census']].sum().reset_index()
    else:
        grouped_df = pd.DataFrame({'N': avg_df['N'].sum(), 'N_census': avg_df['N_census'].sum()}, index=[0])

    denominator = grouped_df['N_census'].sum() if normalize_by_total else grouped_df['N_census']

    yes_df = grouped_df.copy()
    yes_df['proportion'] = yes_df['N'] / denominator
    yes_df[outcome] = 'Yes'

    no_df = grouped_df.copy()
    no_df['proportion'] = (no_df['N_census'] - no_df['N']) / denominator
    no_df[outcome] = 'No'

    return pd.concat([yes_df, no_df])[margin_cols + [outcome, 'proportion']].set_index(margin_cols + [outcome])

def cell_margins(df, cell_cols, margin_cols, outcome):
    """Joint Yes/No proportions of a cell-level model output over `margin_cols`.

    Args:
        df: cell-level frame with `cell_cols`, 'N' and 'N_census' columns.
        cell_cols: columns identifying a cell.
        margin_cols: list of columns to aggregate to ([] for a single topline).
        outcome: name given to the Yes/No outcome index level.

    Returns:
        DataFrame with a 'proportion' column indexed by `margin_cols` + `outcome`;
        proportions are shares of the total census count, so they sum to 1.
    """
    return _cell_yes_no(df, cell_cols, margin_cols, outcome, normalize_by_total=True)

def cell_turnout(df, cell_cols, margin_cols, outcome):
    """Within-group Yes/No shares of a cell-level model output over `margin_cols`.

    Same inputs and output layout as cell_margins, but each margin group is
    divided by its own census count, so Yes + No sum to 1 within every group.
    """
    return _cell_yes_no(df, cell_cols, margin_cols, outcome, normalize_by_total=False)

def kl_divergence(margins_df, epsilon=1e-10):
    """Kullback-Leibler divergence KL(P || Q) between population and model margins.

    P is `margins_df['proportion_pop']`, Q is `margins_df['proportion_mod']`; the
    result is sum(P * log(P / Q)) in nats, returned as a Python float. Both
    columns are clipped to [epsilon, 1] first so zero proportions cause neither
    division by zero nor log(0).
    """
    p, q = [
        np.clip(margins_df[col].values.flatten(), epsilon, 1)
        for col in ('proportion_pop', 'proportion_mod')
    ]
    log_ratio = np.log(p / q)
    return np.sum(p * log_ratio).item()

def em_distance(margins_df, epsilon=1e-10):
    """Half the L1 distance between `proportion_pop` and `proportion_mod`.

    For two probability vectors this is the total variation distance.
    NOTE(review): despite the name, no ground metric between categories is
    used, so this is not an earth mover's distance in the usual sense.

    Both columns are clipped to [epsilon, 1] first (mirroring kl_divergence);
    here that only lifts values below epsilon (e.g. exact zeros) and caps at 1.
    """
    pop, mod = (
        np.clip(margins_df[name].values.flatten(), epsilon, 1)
        for name in ('proportion_pop', 'proportion_mod')
    )
    l1 = np.abs(pop - mod).sum().item()
    return l1 / 2

def get_distances(pop_df, mod_df, columns, outcome):
    """Compare population and model margins of `outcome` over `columns`.

    The population side is df_margins(pop_df, ...); the model side is
    cell_margins(mod_df, ...) aggregated over the same `columns`. The two are
    outer-joined (a combination missing on one side counts as 0 there) and
    compared with em_distance and kl_divergence:

    - 'emd' / 'kld': on the full joint table over `columns` + `outcome`;
    - 'emd_1d' / 'kld_1d': mean over the marginal of each single column;
    - 'emd_2d' / 'kld_2d': mean over the marginal of each pair of columns.

    Also reports the overall 'Yes' share of `outcome` in the model
    ('topline_yes') and in the population ('pop_topline_yes').

    Args:
        pop_df: row-level population frame containing `columns` and `outcome`.
        mod_df: cell-level model output with `columns`, 'N' and 'N_census'.
        columns: list of margin column names.
        outcome: outcome column name (values 'Yes' / 'No').

    Returns:
        dict of floats; the 1d/2d means are NaN when there are no columns/pairs.
    """
    pop_margins = df_margins(pop_df, columns, outcome).reset_index()
    mod_margins = cell_margins(mod_df, columns, columns, outcome).reset_index()

    # Outer join so combinations present on only one side count as proportion 0 on the other.
    margins_df = pd.merge(pop_margins, mod_margins, on=columns + [outcome], how='outer', suffixes=('_pop', '_mod')).fillna(0)

    emd = em_distance(margins_df)
    kld = kl_divergence(margins_df)

    def _marginal(cols):
        # Sum the joint table down to `cols` + outcome.
        return margins_df.groupby(cols + [outcome])[['proportion_pop', 'proportion_mod']].sum()

    # Build each lower-dimensional marginal once and score it with both metrics.
    marginals_1d = [_marginal([col]) for col in columns]
    marginals_2d = [_marginal([c1, c2]) for c1, c2 in it.combinations(columns, 2)]

    emd_1d = np.array([em_distance(m) for m in marginals_1d])
    kld_1d = np.array([kl_divergence(m) for m in marginals_1d])

    emd_2d = np.array([em_distance(m) for m in marginals_2d])
    kld_2d = np.array([kl_divergence(m) for m in marginals_2d])

    topline_margins = {
        'topline_yes': mod_margins[mod_margins[outcome] == 'Yes']['proportion'].sum().item(),
        'pop_topline_yes': pop_margins[pop_margins[outcome] == 'Yes']['proportion'].sum().item(),
    }

    return {
        'kld': kld,
        'kld_1d': kld_1d.mean().item(),
        'kld_2d': kld_2d.mean().item(),
        'emd': emd,
        'emd_1d': emd_1d.mean().item(),
        'emd_2d': emd_2d.mean().item(),
    } | topline_margins

def get_coefs(model_path, posterior=None):
    """Pair the data-generating ('real') Heckman coefficients with posterior means.

    Reads `heckman_coefs.json` from `model_path`. For the 'selection' and
    'outcome' processes, every coefficient in its 'beta' section is matched with
    the posterior variable '<process>_<col>_effect' (main effects) or
    '<process>_<c1>,<c2>_effect' (interactions, in whichever column order the
    posterior uses). Both sides are centred on their own mean before pairing,
    so only relative effects are compared.

    Interactions involving 'selection_latent', and coefficients without a
    matching posterior variable, are skipped.

    Args:
        model_path: pathlib.Path of the model directory.
        posterior: posterior group of the inference data (anything supporting
            `name in posterior` and `posterior[name].mean(dim=...).to_series()`);
            loaded from `model_path / 'idata.nc'` with arviz when None.

    Returns:
        DataFrame with columns ['process', 'var', 'index', 'real', 'model'];
        empty (with those columns) when nothing could be matched.
    """
    if posterior is None:
        posterior = az.from_netcdf(model_path / 'idata.nc').posterior
    with open(model_path / 'heckman_coefs.json') as f:
        heckman_coefs = json.load(f)

    # Interaction terms involving these columns are skipped.
    ignore_cols = ['selection_latent']

    dfs = []

    for p in ['selection', 'outcome']:
        # Posterior effect variables are named '<process>_<column(s)>_effect'.
        vname = f'{p}_%s_effect'
        hcoefs = heckman_coefs[p].get('beta', {})

        for c in hcoefs:
            cname = c

            if ':' in c:
                # Interaction 'c1:c2' whose categories are keyed 'cat1:cat2'.
                c1, c2 = c.split(':')

                if c1 in ignore_cols or c2 in ignore_cols:
                    continue

                index = []

                if vname % f'{c1},{c2}' in posterior:
                    for cat in hcoefs[c]:
                        cat1, cat2 = cat.split(':')
                        index.append((cat1, cat2))
                elif vname % f'{c2},{c1}' in posterior:
                    # The posterior stores this pair in the opposite order:
                    # swap the column names and each category pair to match it.
                    c1, c2 = c2, c1

                    for cat in hcoefs[c]:
                        cat2, cat1 = cat.split(':')
                        index.append((cat1, cat2))
                else:
                    continue

                cname = f'{c1}:{c2}'
                rvals = pd.Series(hcoefs[c].values(), index=pd.MultiIndex.from_tuples(index, names=[c1, c2])).rename('real')
                mvals = posterior[vname % f'{c1},{c2}'].mean(dim=['chain', 'draw']).to_series().rename('model')
            else:
                rvals = pd.Series(hcoefs[c]).rename('real')
                if vname % c not in posterior:
                    continue
                mvals = posterior[vname % c].mean(dim=['chain', 'draw']).to_series().rename('model')

            # Centre both sides so only relative effects are compared.
            rvals -= rvals.mean()
            mvals -= mvals.mean()

            df = pd.concat([rvals, mvals], axis=1)
            df['index'] = list(df.index)
            df['process'] = p
            df['var'] = cname
            df.reset_index(drop=True, inplace=True)

            dfs.append(df)

    if len(dfs) == 0:
        return pd.DataFrame({'process': [], 'var': [], 'index': [], 'real': [], 'model': []})

    coefs = pd.concat(dfs)

    return coefs[['process', 'var', 'index', 'real', 'model']]

In [None]:
class StreamToLogger:
    """
    Fake file-like stream object that redirects writes to a logger instance.
    https://stackoverflow.com/questions/19425736/how-to-redirect-stdout-and-stderr-to-logger-in-python

    Each write is logged line by line at `level` (trailing whitespace stripped;
    a chunk that is only whitespace produces no record) and is also passed
    through unchanged to the wrapped `stream`. Any other attribute (isatty,
    fileno, encoding, ...) is delegated to the wrapped stream, so code that
    probes sys.stdout / sys.stderr keeps working while they are replaced.
    """
    def __init__(self, stream, logger, level):
        self.stream = stream
        self.logger = logger
        self.level = level
        # Kept for backward compatibility; not used for buffering.
        self.linebuf = ''

    def write(self, buf):
        """Log each line of `buf`, forward `buf` to the wrapped stream and return its result."""
        for line in buf.rstrip().splitlines():
            self.logger.log(self.level, line.rstrip())
        # io.TextIOBase.write contract: return the number of characters written.
        return self.stream.write(buf)

    def flush(self):
        self.stream.flush()

    def __getattr__(self, name):
        # Only called for attributes not found normally. Guard 'stream' so an
        # instance without it (e.g. mid-construction) raises instead of recursing.
        if name == 'stream':
            raise AttributeError(name)
        return getattr(self.stream, name)

class CaptureStdStreams:
    """Context manager that tees sys.stdout / sys.stderr into a logger.

    While active, stdout writes are logged at INFO and stderr writes at ERROR
    (through StreamToLogger) and still reach the original streams. The
    original streams are restored on exit, also when the body raises; the
    exception itself is not suppressed. `__enter__` returns None.
    """
    def __init__(self, logger):
        self.logger = logger

    def __enter__(self):
        # Remember the real streams so __exit__ can put them back.
        self.sys_stdout, self.sys_stderr = sys.stdout, sys.stderr
        self.stdout = StreamToLogger(self.sys_stdout, self.logger, logging.INFO)
        self.stderr = StreamToLogger(self.sys_stderr, self.logger, logging.ERROR)
        sys.stdout, sys.stderr = self.stdout, self.stderr

    def __exit__(self, exc_type, exc_value, traceback):
        sys.stdout, sys.stderr = self.sys_stdout, self.sys_stderr

class ModelLogger:
    """Context manager yielding a per-model child logger that also writes to `<mpath>/log.txt`.

    The child logger is `root_logger.getChild(mname)`. Its file handler is
    created on construction, which truncates log.txt (mode 'w'). On exit the
    handler is detached and closed. This keeps file descriptors from piling up
    across many models and stops re-runs in the same kernel from attaching
    duplicate handlers to the same child logger.
    """
    def __init__(self, root_logger, mname, mpath):
        self.root_logger = root_logger
        self.logger = root_logger.getChild(mname)

        self.formatter = logging.Formatter(
            fmt='[{name}] {asctime} {levelname}: {message}',
            datefmt='%m/%d/%Y %H:%M:%S',
            style='{'
        )

        self.fh = logging.FileHandler(mpath / 'log.txt', mode='w')
        self.fh.setFormatter(self.formatter)
        self.logger.addHandler(self.fh)

    def __enter__(self):
        return self.logger

    def __exit__(self, type, value, traceback):
        # Detach and close the file handler; exceptions from the body still propagate.
        self.logger.removeHandler(self.fh)
        self.fh.close()

class Timer:
    """Context manager measuring the wall-clock seconds spent in its body.

    `with Timer() as t: ...` records `t.start` on entry and sets `t.time`
    (elapsed seconds, via perf_counter) on exit. Exceptions are not suppressed.
    """
    def __enter__(self):
        self.start = perf_counter()
        return self

    def __exit__(self, type, value, traceback):
        finished = perf_counter()
        self.time = finished - self.start

def dict_apply(obj, key, func):
    """Recursively replace every value stored under `key` with `func(value)`.

    Nested dicts and lists are walked and mutated in place, and `obj` itself is
    returned. A value whose key matches is transformed but not descended into.
    Any other object (scalars, tuples, ...) is returned unchanged.
    """
    if isinstance(obj, list):
        for idx in range(len(obj)):
            obj[idx] = dict_apply(obj[idx], key, func)
        return obj
    if not isinstance(obj, dict):
        return obj
    for k in obj:
        obj[k] = func(obj[k]) if k == key else dict_apply(obj[k], key, func)
    return obj

def get_file_hash(filename, root=None):
    """Content hash of a data file, used to make model IDs depend on file contents.

    - Non-JSON files: SHA-256 hex digest of the raw bytes.
    - JSON files: the document is loaded, and every value under a 'file' key (at
      any depth) is replaced by the hash of that file, resolved relative to the
      JSON file's directory. The MD5 of the key-sorted dump is returned, so a
      metadata file's hash changes when any file it references changes.

    Args:
        filename: path to the file (str or path-like).
        root: optional directory that `filename` is relative to.

    Returns:
        Hex digest string.
    """
    file_path = pathlib.Path(root) / filename if root else pathlib.Path(filename)

    if file_path.suffix == '.json':
        with open(file_path, 'r') as f:
            meta = json.load(f)
        # Replace referenced file paths by their (recursive) content hashes.
        meta = dict_apply(copy.deepcopy(meta), 'file', lambda fn: get_file_hash(fn, file_path.parent))
        return hashlib.md5(json.dumps(meta, sort_keys=True).encode('utf-8')).hexdigest()

    # The two branches use different hash functions; kept as-is because these
    # digests feed the model IDs that name cached result directories.
    with open(file_path, 'rb') as f:
        return hashlib.sha256(f.read()).hexdigest()

def get_model_id(model_desc):
    """Deterministic ID for a model description that ignores its 'name'.

    The description is deep-copied, so the caller's dict is untouched. Its
    'name' entry is removed (KeyError if absent), and every value under a 'file'
    key is replaced by the hash of that file's contents. The MD5 hex digest of
    the key-sorted JSON dump is returned.

    NOTE(review): descriptions built in this notebook keep data paths under keys
    such as 'population' / 'survey' / 'margin', not 'file'. Unless
    salk_turnout_models nests 'file' keys, the ID depends on the paths rather
    than on file contents. Confirm.
    """
    stripped = copy.deepcopy(model_desc)
    del stripped['name']
    hashed = dict_apply(stripped, 'file', get_file_hash)
    canonical = json.dumps(hashed, sort_keys=True)
    return hashlib.md5(canonical.encode('utf-8')).hexdigest()

def run_model(model_desc, population_file, model_data, path_prefix='./tmp', root_logger=LOGGER, progress=None, progress_postfix=None):
    """Fit (or reload) one model and return its identifiers plus summary statistics.

    Results are cached under `<path_prefix>/<model_id>` (model_id from
    get_model_id), with a `<name>-<model_id>` symlink next to it:

    - if `summary.json` exists, its contents are returned directly;
    - else if `draws.parquet` exists, the fitted draws are reloaded and scored;
    - else the model is fitted with `tm.run_model`, with stdout/stderr and log
      records captured to the model's `log.txt`.

    Scoring consists of:

    - margin distances between population and model (get_distances);
    - sampler diagnostics from `idata.nc` (mean R-hat, divergences);
    - a Huber-robust regression of posterior-mean outcome coefficients on the
      'real' ones from heckman_coefs.json (slope and MAE).

    The summary is then written to `summary.json`, `draws.parquet` is shrunk to
    per-chain means, and `idata.nc` is deleted.

    Args:
        model_desc: model description dict (see get_model_description).
        population_file: population data path (.parquet, otherwise read as CSV).
        model_data: placeholder -> path mapping; only 'heckman_coefs.json' is read here.
        path_prefix: directory holding the per-model result folders.
        root_logger: parent logger for the per-model logger.
        progress: optional progress bar whose postfix shows the current model.
        progress_postfix: extra postfix fields for `progress`.

    Returns:
        dict with 'model_path', 'model_id', 'model_name' and, on success, the
        summary statistics (including 'fit_time'). If fitting raises, only the
        three identifiers are returned and the error is logged.
    """
    def aggregate_draws_to_chain_means(draws_df: pd.DataFrame) -> pd.DataFrame:
        # Collapse posterior draws to per-chain means (drops `draw`), keeping cell dims.
        # N_census is taken from the first row per group (presumably constant across draws).
        if 'draw' not in draws_df.columns or 'chain' not in draws_df.columns or 'N' not in draws_df.columns:
            return draws_df

        group_cols = [c for c in draws_df.columns if c not in ['N', 'draw']]
        if 'N_census' in draws_df.columns:
            group_cols_wo_nc = [c for c in group_cols if c != 'N_census']
            agg_df = (
                draws_df
                .groupby(group_cols_wo_nc, observed=False, sort=False)
                .agg(N=('N', 'mean'), N_census=('N_census', 'first'))
                .reset_index()
            )
        else:
            agg_df = (
                draws_df
                .groupby(group_cols, observed=False, sort=False)['N']
                .mean()
                .reset_index()
            )

        return agg_df

    model_id = get_model_id(model_desc)
    model_name = f'{model_desc["name"]}-{model_id}'
    model_path = pathlib.Path(path_prefix) / model_id
    model_path.mkdir(parents=True, exist_ok=True)

    # Human-readable '<name>-<id>' symlink next to the ID-named result directory.
    link_path = pathlib.Path(path_prefix) / model_name
    if not link_path.is_symlink():
        os.symlink(model_path.name, link_path)

    if progress:
        if progress_postfix:
            progress.set_postfix({**progress_postfix, 'model': model_name})
        else:
            progress.set_postfix({'model': model_name})
    else:
        print('Model:', model_name)

    summary_path = model_path / 'summary.json'

    # Fully cached: return the stored summary without refitting or rescoring.
    if summary_path.exists():
        with open(summary_path, 'r') as f:
            summary_data = json.load(f)
        # Cleanup legacy idata.nc if present
        try:
            os.unlink(model_path / 'idata.nc')
        except FileNotFoundError:
            pass
        return {
            'model_path': str(model_path),
            'model_id': model_id,
            'model_name': model_desc['name'],
        } | summary_data

    # Load or fit model
    if (model_path / 'draws.parquet').exists():
        mod_df = pd.read_parquet(model_path / 'draws.parquet')
        with open(model_path / 'time.txt') as f:
            fit_time = float(f.read())
    else:
        with ModelLogger(root_logger, model_name, model_path) as logger, CaptureStdStreams(logger):
            try:
                with open(model_path / 'model_desc.json', 'w') as f:
                    json.dump(model_desc, f)

                with Timer() as timer:
                    run_result = tm.run_model(model_desc, sample_kwargs=model_desc.get('sample_kwargs'), save_path=str(model_path))
                    mod_df = run_result['draws']

                root_logger.info(f'Model {model_name} run in {timer.time:.3f} seconds')
                fit_time = timer.time

                with open(model_path / 'time.txt', 'w') as f:
                    f.write(f'{timer.time:.3f}')

                with open(model_path / 'model_id.txt', 'w') as f:
                    f.write(model_id)

                if not (model_path / 'heckman_coefs.json').exists():
                    shutil.copyfile(model_data['heckman_coefs.json'], model_path / 'heckman_coefs.json')
            except KeyboardInterrupt as e:
                raise e
            except Exception as e:
                root_logger.error(f'Error running model {model_name}: {e}')
                traceback.print_exc(file=sys.stderr)
                return {
                    'model_path': str(model_path),
                    'model_id': model_id,
                    'model_name': model_desc['name'],
                }

    model_identifiers = {
        'model_path': str(model_path),
        'model_id': model_id,
        'model_name': model_desc['name'],
    }

    # Population
    pop_cols = ['age_group', 'education', 'gender', 'nationality', 'electoral_district', 'unit', 'voting_intent']
    pop_dtype = {col: 'category' for col in pop_cols}
    if str(population_file).endswith('.parquet'):
        pop_df = pd.read_parquet(population_file)[pop_cols]
    else:
        pop_df = pd.read_csv(population_file, dtype=pop_dtype)[pop_cols]

    # Distances between population and model margins over the model's input columns
    margin_cols = model_desc['input_cols']
    distances = get_distances(pop_df, mod_df, margin_cols, 'voting_intent')

    # Diagnostics from idata.nc
    idata = az.from_netcdf(model_path / 'idata.nc')
    sdf = az.summary(idata)
    mean_rhat = sdf.r_hat.mean().item()
    divergences = idata.sample_stats.diverging.sum().item()

    model_fit = {
        'fit_time': fit_time,
        'mean_rhat': mean_rhat,
        'divergences': divergences,
    }

    posterior = idata.posterior

    coefs = get_coefs(model_path, posterior=posterior)
    ocoefs = coefs[coefs['process'] == 'outcome']

    # Fit robust regression using Huber's T norm
    X_with_const = sm.add_constant(ocoefs['real'].values)
    y = ocoefs['model'].values
    rlm_model = sm.RLM(y, X_with_const, M=sm.robust.norms.HuberT())
    rlm_fit = rlm_model.fit()
    y_pred = rlm_fit.predict(X_with_const)
    mae = np.mean(np.abs(y - y_pred))

    model_coefs = {}
    if 'rho' in posterior:
        model_coefs['sp_slope_vi_mean'] = posterior['rho'].mean(dim=['chain', 'draw']).item()
        model_coefs['sp_slope_vi_sd'] = posterior['rho'].std(dim=['chain', 'draw']).item()

    rlm_stats = {
        'rlm_slope': rlm_fit.params[1],
        'rlm_mae': mae,
    }

    # Persist aggregates for future runs
    summary_data = distances | rlm_stats | model_fit | model_coefs
    with open(summary_path, 'w', encoding='utf-8') as f:
        json.dump(summary_data, f, ensure_ascii=False, indent=2)

    # Shrink draws.parquet: per-chain means (drop draw)
    if 'draw' in mod_df.columns:
        agg_df = aggregate_draws_to_chain_means(mod_df)
        agg_df.to_parquet(model_path / 'draws.parquet', index=False)

    # Remove idata.nc after extracting summary stats
    try:
        os.unlink(model_path / 'idata.nc')
    except FileNotFoundError:
        pass

    return model_identifiers | summary_data

In [None]:
# Base model description; get_model_description deep-copies it and adds the
# model-specific keys. 'census_data.csv' is a placeholder that the run loop
# later swaps for a real path (via get_model_data / replace_value).
MODEL_DESCRIPTION_TEMPLATE = dict(
    outcome_col='voting_intent',
    population='census_data.csv',
)

def replace_value(obj, orig_value, new_value):
    """Recursively replace every leaf equal to `orig_value` with `new_value`.

    Nested dicts (values only, never keys) and lists are updated in place and
    returned. Any other object is returned as-is, unless it compares equal to
    `orig_value`, in which case `new_value` is returned instead.
    """
    if isinstance(obj, dict):
        for key in obj:
            obj[key] = replace_value(obj[key], orig_value, new_value)
        return obj
    if isinstance(obj, list):
        for idx, item in enumerate(obj):
            obj[idx] = replace_value(item, orig_value, new_value)
        return obj
    return new_value if obj == orig_value else obj

def get_model_description(name, config):
    """Build a model description dict from a compact experiment config.

    Starts from a deep copy of MODEL_DESCRIPTION_TEMPLATE and sets 'name',
    'model_type' and sorted 'input_cols'. It then adds:

    - 'survey' placeholder for model types BP/GG/PM/FS;
    - 'margin' placeholder (encoding the sorted `margin_cols`, default
      ['unit']) for model types EI/GG/PM/FS;
    - 'priors' with 'scale_sigma' when `priors_scale_sigma` is truthy;
    - 'interactions', 'sample_kwargs', 'multilevel', 'imr', 'centered',
      'margin_dist' copied from `config` when present and not None.

    The caller's `config` is not modified.
    """
    model_desc = copy.deepcopy(MODEL_DESCRIPTION_TEMPLATE)
    model_desc['name'] = name
    model_desc['model_type'] = config['model_type']
    model_desc['input_cols'] = sorted(config.get('input_cols', ['age_group', 'gender', 'education', 'unit', 'nationality']))

    # Parenthesised so the walrus binds the config value itself; without the
    # parentheses it bound the boolean `config.get('interactions') is not None`.
    if (interactions := config.get('interactions')) is not None:
        model_desc['interactions'] = interactions

    if model_desc['model_type'] in ['BP', 'GG', 'PM', 'FS']:
        model_desc['survey'] = 'survey_data.csv'

    if model_desc['model_type'] in ['EI', 'GG', 'PM', 'FS']:
        margin_cols = sorted(config.get('margin_cols', ['unit']))
        margin_file_name = f'{"_".join(margin_cols)}_margins_data.csv' if len(margin_cols) > 0 else 'margins_data.csv'
        model_desc['margin'] = margin_file_name

    if config.get('priors_scale_sigma'):
        # Copy so adding 'scale_sigma' does not mutate the caller's config['priors'].
        model_desc['priors'] = copy.deepcopy(config.get('priors', {}))
        model_desc['priors']['scale_sigma'] = config['priors_scale_sigma']

    for key in ('sample_kwargs', 'multilevel', 'imr', 'centered', 'margin_dist'):
        if config.get(key) is not None:
            model_desc[key] = config[key]

    return model_desc

def get_model_data(data_name, margin_cols=None, tmp_data_prefix='../tmp/data'):
    """Map the data-file placeholders used in model descriptions to paths for one dataset.

    Args:
        data_name: dataset directory name under `tmp_data_prefix`.
        margin_cols: iterable of column lists. Each adds a
            '<cols>_margins_data.csv' placeholder (columns sorted, joined by '_').
            Defaults to [['unit']]; None is used instead of a mutable default.
        tmp_data_prefix: root directory of the generated datasets.

    Returns:
        dict placeholder -> path. The census path is fixed ('../data/census.csv').
    """
    if margin_cols is None:
        margin_cols = [['unit']]
    data_dir = f'{tmp_data_prefix}/{data_name}'
    return {
        'census_data.csv': '../data/census.csv',
        'population.csv': f'{data_dir}/population.parquet',
        'survey_data.csv': f'{data_dir}/estonia_selection.csv',
        'margins_data.csv': f'{data_dir}/estonia_margins.csv',
        'heckman_coefs.json': f'{data_dir}/heckman_coefs.json',
    } | {
        f'{"_".join(sorted(cols))}_margins_data.csv': f'{data_dir}/estonia_{"_".join(sorted(cols))}_margins.csv' for cols in margin_cols
    }
    
def create_symlink(target_path, link_path):
    """Point `link_path` at `target_path`, replacing a symlink already there.

    `target_path` is stored verbatim; the callers here pass a bare name, i.e.
    a target relative to the link's directory. A non-symlink that already
    exists at `link_path` is left alone, so os.symlink raises FileExistsError.
    """
    link = pathlib.Path(link_path) if isinstance(link_path, str) else link_path
    if link.is_symlink():
        os.unlink(link)
    os.symlink(target_path, link)

def get_data_list(file_path, n_seeds_limit=1000):
    """Load the dataset catalogue, keeping only datasets with the smallest seeds.

    `file_path` is a JSON object mapping experiment name -> list of
    [data_id, data_config] pairs, each config containing a 'seed'. For each
    experiment, the distinct seeds are sorted. Only entries whose seed is among
    the first `n_seeds_limit` of them are kept, in their original order.

    Returns:
        dict experiment name -> filtered list of [data_id, data_config] pairs.
    """
    with open(file_path, 'r') as f:
        data_list = json.load(f)

    filtered = {}
    for data_name, data_descs in data_list.items():
        # np.unique returns the distinct seeds sorted. The allowed set is built
        # once per experiment instead of re-slicing the array for every entry.
        allowed_seeds = set(np.unique([data_config['seed'] for _, data_config in data_descs])[:n_seeds_limit].tolist())
        filtered[data_name] = [[data_id, data_config] for data_id, data_config in data_descs if data_config['seed'] in allowed_seeds]
    return filtered

In [None]:
# Experiment inputs: the dataset catalogue (limited to the 10 smallest distinct
# seeds per experiment), the demographic columns, and one base config per model family.
tmp_data_prefix = '../tmp/data'
data_list = get_data_list(f'{tmp_data_prefix}/data_list.json', n_seeds_limit=10)

demography_cols = ['age_group', 'gender', 'education', 'unit', 'nationality', 'electoral_district']
input_cols = list(filter(lambda col: col != 'electoral_district', demography_cols))
# Every single demographic column plus every pair of them.
all_margin_cols = [[col] for col in demography_cols] + [list(pair) for pair in it.combinations(demography_cols, 2)]

bp_config = dict(model_type='BP', input_cols=input_cols, centered=False)
ei_config = dict(model_type='EI', input_cols=input_cols, centered=True)
gg_config = dict(model_type='GG', input_cols=input_cols, centered=False)
pm_config = dict(model_type='PM', input_cols=input_cols, centered=True)
fs_config = dict(model_type='FS', input_cols=input_cols, centered=True)

# One description per model family, without and with interaction terms.
common_models = [
    get_model_description(name, config)
    for name, config in [
        ('1_bp', bp_config),
        ('2_ei', ei_config),
        ('3_gg', gg_config),
        ('4_pm', pm_config),
        ('5_fs', fs_config),
    ]
]

int_models = [
    get_model_description(name, config | {'interactions': True})
    for name, config in [
        ('1_int_bp', bp_config),
        ('2_int_ei', ei_config),
        ('3_int_gg', gg_config),
        ('4_int_pm', pm_config),
        ('5_int_fs', fs_config),
    ]
]

def get_margin_model_name(margin_cols):
    """Short label for a margin configuration used in model names.

    Returns 'tpl' (presumably "topline") when there are no margin columns,
    otherwise the column names joined by '_'.
    """
    if len(margin_cols) > 0:
        return '_'.join(margin_cols)
    return 'tpl'

margin_cols_list = [[], ['unit'], ['electoral_district']]

# Each margin-using family (EI, GG, PM, FS) is paired with every margin
# configuration, using all demographic columns (incl. electoral_district) as
# inputs. The descriptions are grouped by family in that order.
ei_margin_models, gg_margin_models, pm_margin_models, fs_margin_models = (
    [
        get_model_description(
            f'{prefix}_margin_{get_margin_model_name(margin_cols)}_{family}',
            config | {'input_cols': demography_cols, 'margin_cols': margin_cols},
        )
        for margin_cols in margin_cols_list
    ]
    for prefix, family, config in [
        ('2', 'ei', ei_config),
        ('3', 'gg', gg_config),
        ('4', 'pm', pm_config),
        ('5', 'fs', fs_config),
    ]
)
margin_models = ei_margin_models + gg_margin_models + pm_margin_models + fs_margin_models

# Every model family with priors_scale_sigma set to 0.5.
scale_models = [
    get_model_description(f'{prefix}_scale_{family}', config | {'priors_scale_sigma': 0.5})
    for prefix, family, config in [
        ('1', 'bp', bp_config),
        ('2', 'ei', ei_config),
        ('3', 'gg', gg_config),
        ('4', 'pm', pm_config),
        ('5', 'fs', fs_config),
    ]
]

# Model sets to run on each experiment's datasets (keys must match data_list).
# Optional extra sets go in a trailing comment AFTER the comma. If the comma is
# inside the comment, the dict literal no longer parses (SyntaxError).
experiments = {
    'est-default': common_models,  # + int_models + scale_models
    'est-electoral-district': margin_models,
    'est-no-selection': common_models,
    'est-non-response': common_models,
    'est-overreport-const': common_models,
    'est-heck-cor': common_models,
    'est-hcoef-cor': common_models,
    'est-hcoef-sigma': common_models,
    'est-non-normal-error': common_models,
    'est-agg-bias': common_models,
    'est-noise': common_models,
    'est-sample-size': common_models,
    'est-int': common_models + int_models,
}

model_path_prefix = '../tmp/models'

model_results = []

model_counts = {exp_name: len(experiments[exp_name]) * len(data_list[exp_name]) for exp_name in experiments}
print(model_counts)

# Fit/score every (experiment, dataset, model) combination; run_model caches per model ID.
# The context manager closes the progress bar even if a run raises.
with tqdm(total=sum(model_counts.values())) as progress:
    for experiment_name, model_descriptions in experiments.items():
        for data_name, data_desc in data_list[experiment_name]:
            model_data = get_model_data(data_name, all_margin_cols)

            for desc in model_descriptions:
                # Apply data configuration to the model description (placeholders -> dataset paths)
                model_desc = copy.deepcopy(desc)

                for key, value in model_data.items():
                    model_desc = replace_value(model_desc, key, value)

                progress_postfix = {'experiment': experiment_name, 'dataset': data_name, 'description': data_desc}
                model_result = run_model(model_desc, model_data['population.csv'], model_data, path_prefix=model_path_prefix, progress=progress, progress_postfix=progress_postfix)
                model_result['data_name'] = '-'.join(data_name.split('-')[:-1])
                model_result['data_id'] = data_name
                # Inner double quotes keep this f-string valid before Python 3.12 (PEP 701).
                model_result['desc'] = ','.join([f'{k.split("/")[-1]}={v}' for k, v in data_desc.items()])
                is_success = model_result.get('fit_time') is not None

                model_results.append(model_result)

                if is_success:
                    # '<name>' symlink points at the latest successful run of this description.
                    create_symlink(pathlib.Path(model_result['model_path']).name, pathlib.Path(model_path_prefix) / model_desc['name'])

                progress.update(1)

model_results_df = pd.concat([pd.DataFrame(data=result, index=[result['model_path']]) for result in model_results])

ignored_cols = ['model_path']
ordered_cols = ['model_id', 'data_name', 'data_id', 'desc', 'model_name']
other_cols = [col for col in model_results_df.columns if col not in ordered_cols and col not in ignored_cols]

model_results_df = model_results_df[ordered_cols + other_cols]
results_dir = pathlib.Path(model_path_prefix)
model_results_path = results_dir / f'{datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")}_model_results.csv'
model_results_df.to_csv(model_results_path)
# Stable 'model_results.csv' name pointing at the newest timestamped export.
create_symlink(model_results_path.name, results_dir / 'model_results.csv')