In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import numpy as np
import pandas as pd
pd.options.display.max_columns = None
from pandas_ply import install_ply, X, sym_call
install_ply(pd)
import copy


from support.mexlet.jupyter import print_df


def _print(_a):
    """Show ``_a`` via ``print_df`` and hand the same object back.

    Because the argument is returned unchanged, the call can be embedded in an
    expression or method chain without altering the value flowing through it.
    """
    shown = _a
    print_df(shown)
    return shown

In [2]:
from support.mexlet.records.aggregator import MongoAggregator, PdAggregator
from support.database.mongo import get_table as orig_get_table

# Identifiers of the experiment records stored in MongoDB.
dbname = 'causal_da'
table_name = 'icml2020'
recording_set = 'icml2020'
port = 27017
# Cache directory for query results; keep the trailing slash because file
# names are appended by plain string concatenation.
record_pickle_base = 'pickle/records/'


def get_table(table_name, dbname, port):
    """Open a table through the project's MongoDB helper ``orig_get_table``.

    Credentials are taken from the environment (``MONGO_HOST``, ``MONGO_USER``,
    ``MONGO_PASSWORD``) so that they need not be committed with the notebook.
    The previously hard-coded values remain as fallbacks so existing local
    setups keep working. TODO: drop the fallbacks once the variables are set.

    Parameters
    ----------
    table_name : str
        Name of the table/collection to open.
    dbname : str
        Database that contains the table.
    port : int
        Port of the MongoDB server.

    Returns
    -------
    Whatever ``orig_get_table`` returns; in this notebook it is passed to
    ``MongoAggregator``.
    """
    import os  # local import: ``os`` is otherwise only imported in a later cell

    host = os.environ.get('MONGO_HOST', 'localhost')
    user = os.environ.get('MONGO_USER', 'me')
    password = os.environ.get('MONGO_PASSWORD', 'pass')
    return orig_get_table(table_name, host, port, user, password, dbname)

## Read baseline results (MongoDB query cached as a pickle)

In [3]:
import os

BASELINE_METHODS = [
    'naive',
    'gdm',
    'copula',
    'iw-base',
    'tr-adaboost',
]

# Baseline results are cached as a pickle so the notebook can be re-run without
# a MongoDB connection. Delete the pickle to refresh it from the database.
# NOTE: unpickling can execute code -- only load caches you produced yourself.
_baseline_pickle = record_pickle_base + 'icml2020_baseline_records.pkl'
if os.path.exists(_baseline_pickle):
    baseline_records = pd.read_pickle(_baseline_pickle)
else:
    _table = get_table(table_name, dbname=dbname, port=port)
    baseline_records = MongoAggregator(_table, query={
        'recording_set': recording_set,
        'data': 'gasoline',
        'method': {'$in': BASELINE_METHODS}
    }).get_results_pd()
    os.makedirs(record_pickle_base, exist_ok=True)
    baseline_records.to_pickle(_baseline_pickle)

## Read proposed-method results and their Sacred metrics (MongoDB queries cached as pickles)

In [4]:
import os

# Results of the proposed method ('causal-da') and the Sacred metric logs of
# its runs. Both are cached as pickles; delete a pickle to refresh it from
# MongoDB. NOTE: unpickling can execute code -- only load caches you created.
_proposed_pickle = record_pickle_base + 'icml2020_proposed_records.pkl'
_metrics_pickle = record_pickle_base + 'icml2020_proposed_raw_metrics.pkl'

if os.path.exists(_proposed_pickle):
    proposed_records = pd.read_pickle(_proposed_pickle)
else:
    _table = get_table(table_name, dbname=dbname, port=port)
    proposed_records = MongoAggregator(_table, query={
        'recording_set': recording_set,
        'data': 'gasoline',
        'method': 'causal-da'
    }).get_results_pd()
    os.makedirs(record_pickle_base, exist_ok=True)
    proposed_records.to_pickle(_proposed_pickle)

if os.path.exists(_metrics_pickle):
    raw_metrics = pd.read_pickle(_metrics_pickle)
else:
    # Only fetch the metric logs of the runs listed in proposed_records.
    _run_id = proposed_records['sacred_run_id'].tolist()
    _table = get_table('metrics', dbname='sacred', port=port)
    raw_metrics = MongoAggregator(_table, query={'run_id': {'$in': _run_id}}).get_results_pd(index=None)
    os.makedirs(record_pickle_base, exist_ok=True)
    raw_metrics.to_pickle(_metrics_pickle)

In [5]:
# DataFrame.append is deprecated (removed in pandas 2.0); pd.concat is the
# supported equivalent. The two frames have different columns, so sort=True
# keeps the previous result (union of columns, sorted) and silences the
# "non-concatenation axis is not aligned" FutureWarning.
all_records = pd.concat([proposed_records, baseline_records], sort=True)

of pandas will change to not sort by default.

To accept the future behavior, pass 'sort=False'.


  sort=sort)


In [6]:
# Module-level frame that every result helper below reads from.
data = all_records
# Group keys handed to VirtualValidation.fit (extended with 'method' there).
cv_group = ['data', 'data_run_id']

In [7]:
from support.mexlet.views.sacred import cleanse_nans_from_metrics

# Clean the two Sacred metric series used for virtual early stopping
# (presumably removes NaN entries -- see support.mexlet.views.sacred).
metrics_all = cleanse_nans_from_metrics(
    raw_metrics,
    'epoch_model AugKRR_MSE',         # reported test metric
    'epoch_model target_cv_KRR_MSE',  # model-selection metric
)

In [8]:
# Virtually early-stop and cross-validate
from support.mexlet.records import VirtualValidation
from support.mexlet.views import pd_add_column
from support.mexlet.views.sacred import select_by_argmin, select_by_quantile

def _get_proposed(method_names, selector, rename=None):
    """Join virtually early-stopped Sacred metrics onto ``data``.

    Parameters
    ----------
    method_names : list of str
        Values of ``data['method']`` whose Sacred runs are considered.
    selector : str
        Name of the Sacred metric that decides the stopping step.
    rename : str, optional
        If given, the ``method`` column of the returned frame is overwritten
        with this value.

    Returns
    -------
    pandas.DataFrame
        ``data`` joined (``sacred_run_id`` = ``run_id``, via ``pd_add_column``)
        with the outputs of two ``select_by_argmin`` calls, presumably named
        ``test_metric`` ('epoch_model AugKRR_MSE' at the selected step) and
        ``min_selector`` (the selector value at that step).
    """
    run_ids = data.ply_where(
        lambda frame: np.isin(frame['method'], method_names)
    )['sacred_run_id']
    relevant = metrics_all.ply_where(X.run_id.isin(run_ids))
    metrics = relevant[['run_id', 'name', 'values', 'steps']]

    test_metric = 'epoch_model AugKRR_MSE'

    # Virtual early stopping from the Sacred metrics log.
    selected_values = select_by_argmin(metrics, test_metric, selector, lambda _: 'test_metric')
    selector_values = select_by_argmin(metrics, selector, selector, lambda _: 'min_selector')

    merged = data
    for extra in (selected_values, selector_values):
        merged = pd_add_column(merged, extra, 'sacred_run_id', 'run_id')

    if rename is not None:
        merged['method'] = rename
    return merged


## Select proposed method
def get_proposed_method(selector):
    """Results of the proposed method ('causal-da') after virtual validation.

    ``selector`` names the Sacred metric used for early stopping. Within each
    ``cv_group + ['method']`` group, VirtualValidation chooses by
    ``min_selector`` with ``larger_is_better=False``.

    Returns the columns ``cv_group + ['method', 'target_c', 'test_metric',
    'sacred_run_id']``.
    """
    group_cols = cv_group + ['method']
    candidates = _get_proposed(['causal-da'], selector)
    criteria = [('min_selector', {'larger_is_better': False})]
    validated = VirtualValidation(candidates).fit(group_cols, criteria)
    return validated[group_cols + ['target_c', 'test_metric', 'sacred_run_id']]


def _fetch_naive_baseline(method, basename):
    """Extract one naive baseline variant from the ``basename`` records.

    The column ``KRR_MSE_<method>`` is renamed to ``test_metric``. Rows are
    collapsed to the first entry per (data, data_run_id) via ``groupby().first()``,
    and ``method`` is set to ``method`` for every row.
    """
    score_column = f'KRR_MSE_{method}'
    base_rows = data.ply_where(X.method == basename)
    base_rows = base_rows.rename(columns={score_column: 'test_metric'})
    firsts = base_rows.groupby(["data", "data_run_id"]).first().reset_index()
    firsts['method'] = method
    return firsts


def get_naive_baseline(method):
    """Naive-baseline variant ``method`` restricted to the summary-table columns."""
    variant = _fetch_naive_baseline(method, 'naive')
    return variant[[*cv_group, 'target_c', 'method', 'test_metric']]


def get_iw():
    """Cross-validated results of importance weighting ('iw-base'), one method per alpha.

    Each distinct ``alpha`` becomes its own method named
    ``iw-base(alpha=<value>)``. VirtualValidation then selects within each
    ``cv_group + ['method']`` group by ``loocv_score`` with
    ``larger_is_better=False``.

    Returns
    -------
    pandas.DataFrame
        Columns ``cv_group + ['target_c', 'method', 'test_metric']``, where
        ``test_metric`` is the recorded ``MSE``.
    """
    iw_rows = (data
               .ply_where(X.method == 'iw-base')
               .ply_select('*', test_metric=X.MSE))
    _cv = cv_group + ['method']

    # Collect one frame per alpha and concatenate once. The previous loop grew a
    # frame with DataFrame.append (quadratic, removed in pandas 2.0) and assigned
    # into a filtered slice (SettingWithCopyWarning); .assign returns a new frame.
    pieces = []
    for alpha in iw_rows['alpha'].unique():
        subset = iw_rows.ply_where(X.alpha == alpha)
        if subset.empty:
            # e.g. a NaN alpha never compares equal; the old loop added no rows either.
            continue
        suffix = subset['alpha'].apply(lambda a: f'(alpha={a})')
        pieces.append(subset.assign(method=subset['method'] + suffix))

    if pieces:
        labelled = pd.concat(pieces)
    else:
        labelled = pd.DataFrame(columns=iw_rows.columns)

    selected = VirtualValidation(labelled).fit(
        _cv, [('loocv_score', {'larger_is_better': False})])
    return selected[cv_group + ['target_c', 'method', 'test_metric']]


def get_copula():
    """Copula-baseline rows with ``MSE`` exposed as ``test_metric`` (no model selection)."""
    wanted = cv_group + ['target_c', 'method', 'test_metric']
    copula_rows = data.ply_where(X.method == 'copula').ply_select('*', test_metric=X.MSE)
    return copula_rows[wanted]


def get_gdm():
    """GDM baseline after VirtualValidation on ``valid_error`` (smaller is better).

    ``MSE`` is exposed as ``test_metric``; the summary-table columns are returned.
    """
    gdm_rows = data.ply_where(X.method == 'gdm').ply_select('*', test_metric=X.MSE)
    criteria = [('valid_error', {'larger_is_better': False})]
    validated = VirtualValidation(gdm_rows).fit(cv_group + ['method'], criteria)
    return validated[cv_group + ['target_c', 'method', 'test_metric']]


def get_tradaboost():
    """Retrieve the results of TrAdaBoost.

    Note
    ----
    TrAdaBoost runs its cross-validation internally, so no CV is needed here.
    To keep the output in the same shape as the other methods, which pass
    through VirtualValidation, 'max_threads' is copied into a dummy
    ``cv_error`` column and VirtualValidation is run on that column.
    """
    selection = [('cv_error', {'larger_is_better': False})]
    tr_rows = (data
               .ply_where(X.method == 'tr-adaboost')
               .ply_select('*', test_metric=X.MSE, cv_error=X.max_threads))
    validated = VirtualValidation(tr_rows).fit(cv_group + ['method'], selection)
    return validated[cv_group + ['target_c', 'method', 'test_metric']]


# Stack every method's per-run results into one long table.
# A single pd.concat replaces the chained DataFrame.append calls (deprecated,
# removed in pandas 2.0). The pieces have different column sets, so
# sort=True keeps the previous outcome (sorted union of columns) and silences
# the FutureWarning about unaligned columns.
_result_table = pd.concat(
    [
        get_proposed_method('epoch_model target_cv_KRR_MSE'),
        get_naive_baseline('TargetOnly'),
        get_naive_baseline('SourceOnly'),
        get_naive_baseline('SourceAndTargetTargetValidate'),
        get_naive_baseline('LOO'),
        get_iw(),
        get_copula(),
        get_gdm(),
        get_tradaboost(),
    ],
    sort=True,
).sort_values(['data', 'data_run_id'])

of pandas will change to not sort by default.

To accept the future behavior, pass 'sort=False'.


  sort=sort)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy


In [9]:
import copy

def normalize(table, normalizing_method):
    """Add ``normalizer_val`` and ``relative_test_metric`` columns to ``table``.

    ``normalizer_val`` is the ``test_metric`` of ``normalizing_method`` on the
    same run. The run is identified by the string concatenation of ``data`` and
    ``data_run_id``. ``relative_test_metric = test_metric / normalizer_val``.
    The join goes through ``pd_add_column`` on a temporary ``tag`` column,
    which is dropped afterwards.
    """
    tagged = table.ply_select('*', tag=X.data + X.data_run_id)
    reference = (table
                 .ply_where(X.method == normalizing_method)
                 .ply_select('*', tag=X.data + X.data_run_id, normalizer_val=X.test_metric))
    reference = reference[['tag', 'normalizer_val']]
    joined = pd_add_column(tagged, reference, 'tag', 'tag').drop(columns='tag')
    joined['relative_test_metric'] = joined['test_metric'] / joined['normalizer_val']
    return joined

In [10]:
# Aggregation
import copy


def reformat(_result_table):
    """Aggregate per-run results into the formatted summary table.

    Steps:
      1. take element ``[0]`` of every ``target_c`` value (presumably a
         one-element sequence holding the target domain name -- TODO confirm);
      2. express each ``test_metric`` relative to the 'LOO' method on the same
         run (``normalize``);
      3. compute mean and standard error of the relative metric per
         (target domain, method);
      4. within each target domain, mark the rows whose mean equals the smallest
         mean among non-'LOO' methods (ties are all marked), then format cells
         as ``mean (stderr)``, bolding the marked ones. 'LOO' is shown as an
         integer without a stderr.

    Parameters
    ----------
    _result_table : pandas.DataFrame
        Per-run results with at least ``data``, ``data_run_id``, ``target_c``,
        ``method`` and ``test_metric``. The caller's frame is not modified.

    Returns
    -------
    (pandas.DataFrame, dict)
        The formatted table pivoted to one row per target domain and one column
        per method, and a dict mapping each method to the number of target
        domains where it was marked best.
    """
    result_table = _result_table.copy()
    # Flatten the target label
    result_table['target_c'] = result_table['target_c'].apply(lambda x: x[0])
    # Concatenate LOO score for each group as a column
    result_table = normalize(result_table, 'LOO')
    result_table['grouping'] = result_table['target_c']

    # Mean and standard error per (grouping, method), joined on _merge_key.
    # The stderr copies of 'grouping'/'method' are renamed to avoid clashing
    # with the mean table's columns.
    grouped = result_table.groupby(['grouping', 'method'])
    result_table_mean = (grouped.mean().reset_index()
                         .ply_select('*', _merge_key=X.grouping + X.method)
                         .rename(columns={'relative_test_metric': 'rel_test_mean'}))
    result_table_stderr = (grouped.sem().reset_index()
                           .ply_select('*', _merge_key=X.grouping + X.method)
                           .rename(columns={'relative_test_metric': 'rel_test_stderr',
                                            'grouping': 'grouping_stderr',
                                            'method': 'method_stderr'}))
    _res = pd_add_column(result_table_mean, result_table_stderr, '_merge_key')

    #######
    ## Mark the best score (the 'LOO' reference is excluded from the minimum)
    ###
    _min_vals = (_res.ply_where(X.method != 'LOO')
                 .groupby('grouping')[['grouping', 'rel_test_mean']]
                 .transform('min')
                 .drop_duplicates('grouping')
                 .rename(columns={'rel_test_mean': 'rel_test_mean_min_in_group'})
                 )
    _res = (pd_add_column(_res, _min_vals, 'grouping')
            .ply_select('*', is_min=X.rel_test_mean_min_in_group == X.rel_test_mean))

    def _fstr(x):
        # 'LOO' is the normalizer, so only its (integer-formatted) mean is shown.
        if x['method'] == 'LOO':
            return f"{x['rel_test_mean']:.0f}"
        ret = f"{x['rel_test_mean']:.2f} ({x['rel_test_stderr']:.2f})"
        if x['is_min']:
            return "\\textbf{" + ret + "}"
        else:
            return ret
    #######
    _res['relative_test_metric'] = _res.apply(_fstr, axis=1)

    best_score_counts = _res.groupby('method')['is_min'].sum().apply(int).to_dict()

    # Reformat the table
    _res = _res.pivot(index='grouping', columns='method', values='relative_test_metric')
    return _res, best_score_counts


# Methods listed here are dropped before aggregation.
exclude = []
# Its best-score count is also exported under the generic key 'proposed'.
proposed_name = 'causal-da'

result_table, best_score_counts = reformat(
    _result_table.ply_where(lambda frame: ~np.isin(frame['method'], exclude))
)

_print(result_table)
best_score_counts['proposed'] = best_score_counts[proposed_name]

method,LOO,SourceAndTargetTargetValidate,SourceOnly,TargetOnly,causal-da,copula,gdm,iw-base(alpha=0.0),iw-base(alpha=0.5),iw-base(alpha=0.95),tr-adaboost
grouping,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
AUSTRIA,1,9.84 (0.62),9.67 (0.57),5.88 (1.60),\textbf{5.39 (1.86)},27.33 (0.77),31.56 (1.39),39.72 (0.74),39.45 (0.72),39.18 (0.76),5.78 (2.15)
BELGIUM,1,9.48 (0.91),8.19 (0.68),10.70 (7.50),\textbf{7.94 (2.19)},119.86 (2.64),89.10 (4.12),105.15 (2.96),105.28 (2.95),104.30 (2.95),8.10 (1.88)
CANADA,1,156.65 (10.69),157.74 (8.83),5.16 (1.36),\textbf{3.84 (0.98)},406.91 (1.59),516.90 (4.45),592.21 (1.87),591.21 (1.84),589.87 (1.91),51.94 (30.06)
DENMARK,1,28.12 (1.67),30.79 (0.93),3.26 (0.61),\textbf{3.23 (0.63)},14.46 (0.79),16.84 (0.85),22.15 (1.10),22.11 (1.10),21.72 (1.07),25.60 (13.11)
FRANCE,1,3.05 (0.11),4.67 (0.41),2.79 (1.10),\textbf{1.92 (0.66)},156.29 (1.96),91.69 (1.34),116.32 (1.27),116.54 (1.25),115.29 (1.28),52.65 (25.83)
GERMANY,1,210.59 (14.99),229.65 (9.13),16.99 (8.04),\textbf{6.71 (1.23)},929.03 (4.85),739.29 (11.81),817.50 (4.60),818.13 (4.55),812.60 (4.57),341.03 (157.80)
GREECE,1,5.75 (0.68),5.30 (0.90),3.80 (2.21),\textbf{3.55 (1.79)},23.05 (0.53),26.90 (1.89),47.07 (1.92),45.50 (1.82),45.72 (2.00),11.78 (2.36)
IRELAND,1,12.34 (0.58),135.57 (5.64),\textbf{3.05 (0.34)},4.35 (1.25),26.60 (0.59),3.84 (0.22),6.38 (0.13),6.31 (0.14),6.16 (0.13),23.40 (17.50)
ITALY,1,39.27 (2.52),35.29 (1.83),\textbf{13.00 (4.15)},14.05 (4.81),343.10 (10.04),226.95 (11.14),244.25 (8.50),244.84 (8.58),242.60 (8.46),87.34 (24.05)
JAPAN,1,8.38 (1.07),\textbf{8.10 (1.05)},10.55 (4.67),12.32 (4.95),71.02 (5.08),95.58 (7.89),135.24 (13.57),134.89 (13.50),134.16 (13.43),18.81 (4.59)


In [11]:
import jinja2
import os
from support.mexlet.latex import tex_escape
def abbrev(x):
    """Return the ISO 3166-1 alpha-3 code for a country name as spelled in the data.

    Some names appear truncated to eight characters ("NETHERLA", "SWITZERL").
    Unknown names raise ``KeyError``.
    """
    codes = {
        "AUSTRIA": "AUT",
        "BELGIUM": "BEL",
        "CANADA": "CAN",
        "DENMARK": "DNK",
        "FRANCE": "FRA",
        "GERMANY": "DEU",
        "GREECE": "GRC",
        "IRELAND": "IRL",
        "ITALY": "ITA",
        "JAPAN": "JPN",
        "NETHERLA": "NLD",
        "NORWAY": "NOR",
        "SPAIN": "ESP",
        "SWEDEN": "SWE",
        "SWITZERL": "CHE",
        "TURKEY": "TUR",
        "U.K.": "GBR",
        "U.S.A.": "USA",
    }
    return codes[x]

TEMPLATES_DIR = '.'
def output_tex(_a, file=None, template_name='results_icml2020.tex.j2', count_best=None):
    r"""Render ``_a`` through a LaTeX Jinja2 template, print it, and optionally save it.

    The environment uses LaTeX-friendly delimiters (``\BLOCK{...}`` blocks,
    ``\#{...}`` comments, ``%%`` line statements, ``%#`` line comments) so that
    template syntax does not clash with TeX braces.
    See https://miller-blog.com/latex-with-jinja2/.

    Parameters
    ----------
    _a : pandas.DataFrame
        Passed to the template as ``rows``.
    file : str or path-like, optional
        If given, the rendered TeX is also written there (UTF-8).
    template_name : str
        Template file looked up in ``TEMPLATES_DIR``.
    count_best : dict, optional
        Exposed to the template as ``count_best``; defaults to the notebook's
        ``best_score_counts``.
    """
    if count_best is None:
        count_best = best_score_counts
    latex_jinja_env = jinja2.Environment(
        # Raw strings: '\B' and '\#' are invalid escape sequences in plain
        # literals (SyntaxWarning since Python 3.12); the values are unchanged.
        block_start_string=r'\BLOCK{',
        block_end_string='}',
        variable_start_string='{{',
        variable_end_string='}}',
        comment_start_string=r'\#{',
        comment_end_string='}',
        line_statement_prefix='%%',
        line_comment_prefix='%#',
        trim_blocks=True,
        autoescape=False,
        loader=jinja2.FileSystemLoader(TEMPLATES_DIR))
    template = latex_jinja_env.get_template(template_name)
    template.globals['tex_escape'] = tex_escape
    template.globals['abbrev'] = abbrev
    template.globals['count_best'] = count_best
    rendered_tex = template.render(rows=_a)
    if file is not None:
        # Explicit encoding: the platform default is not guaranteed to be UTF-8.
        with open(file, 'w', encoding='utf-8') as _f:
            _f.write(rendered_tex)
    print(rendered_tex)

from pathlib import Path

# The LaTeX table is written under ./output; create the directory on first run.
Path('output').mkdir(parents=False, exist_ok=True)
output_tex(result_table, file='output/results_icml2020.tex')

\begin{tabular}[t]{|*1{p{8mm}|}|*1{p{8mm}}||*4{p{10mm}|}*1{p{10mm}|}*5{p{10mm}|}}
\hline

\hhline{|=||=||=|=|=|=|=|=|=|=|=|=|}
AUT & 1 & 5.88 (1.60) & \textbf{5.39 (1.86)} & 9.67 (0.57) & 9.84 (0.62) & 5.78 (2.15) & 31.56 (1.39) & 27.33 (0.77) & 39.72 (0.74) & 39.45 (0.72) & 39.18 (0.76)\\
\hline
BEL & 1 & 10.70 (7.50) & \textbf{7.94 (2.19)} & 8.19 (0.68) & 9.48 (0.91) & 8.10 (1.88) & 89.10 (4.12) & 119.86 (2.64) & 105.15 (2.96) & 105.28 (2.95) & 104.30 (2.95)\\
\hline
CAN & 1 & 5.16 (1.36) & \textbf{3.84 (0.98)} & 157.74 (8.83) & 156.65 (10.69) & 51.94 (30.06) & 516.90 (4.45) & 406.91 (1.59) & 592.21 (1.87) & 591.21 (1.84) & 589.87 (1.91)\\
\hline
DNK & 1 & 3.26 (0.61) & \textbf{3.23 (0.63)} & 30.79 (0.93) & 28.12 (1.67) & 25.60 (13.11) & 16.84 (0.85) & 14.46 (0.79) & 22.15 (1.10) & 22.11 (1.10) & 21.72 (1.07)\\
\hline
FRA & 1 & 2.79 (1.10) & \textbf{1.92 (0.66)} & 4.67 (0.41) & 3.05 (0.11) & 52.65 (25.83) & 91.69 (1.34) & 156.29 (1.96) & 116.32 (1.27) & 116.54 (1.25) & 115.29 (1.28)\