In [None]:
import numpy as np
import pandas as pd
import plotly.express as px
from plotly.subplots import make_subplots
import scipy.stats as ss
import scikit_posthocs as sp
import util

In [None]:
def add_var_scatter_plot(fig, x, y, color, name=None, showlegend=True, **kwargs):
    """
    Add a mean line with a shaded +/- 1 standard-deviation band to ``fig``.

    @param
    fig: go.Figure (also works for make_subplots figures); traces are added to it and it is returned.
    x: sequence of x positions, one per row of ``y``.
    y: 2-D array-like; mean and sample std (ddof=1) are taken along axis 1
       (presumably the trial axis -- confirm with the caller), so ``len(y) == len(x)``.
    color: int, index into plotly's qualitative palette (10 colors, 0 to 9); larger values wrap around.
    name: str, the name of plot (legend label of the mean line).
    showlegend: bool, whether the mean line gets a legend entry.
    **kwargs: forwarded to every ``fig.add_scatter`` call (e.g. ``row=1, col=2`` for subplots).
              ``legendgroup`` may be given to override the default group (``name``).
    @return
    fig, the same figure object.
    """
    colors = [list(px.colors.hex_to_rgb(_hex)) for _hex in px.colors.qualitative.Plotly]
    base = colors[color % len(colors)]  # wrap instead of raising IndexError for color >= 10
    rgb = 'rgb' + str(tuple(base))
    rgba = 'rgba' + str(tuple(base + [0.3]))  # opacity = 0.3
    mean, std = np.mean(y, axis=1), np.std(y, ddof=1, axis=1)
    # Put the mean line and both band edges in one legend group so that clicking the
    # legend entry hides the band too -- and the same method in every other subplot.
    group = kwargs.pop('legendgroup', name)
    fig.add_scatter(x=x, y=mean, name=name, mode='markers+lines', line=dict(color=rgb),
                    showlegend=showlegend, legendgroup=group, **kwargs)
    # Invisible upper edge of the band; the next trace fills down to it ("tonexty").
    fig.add_scatter(x=x, y=mean+std, mode='lines', line=dict(width=0), showlegend=False,
                    hoverinfo='none', legendgroup=group, **kwargs)
    fig.add_scatter(x=x, y=mean-std, mode='lines', fill="tonexty", line=dict(width=0),
                    showlegend=False, hoverinfo='none', fillcolor=rgba, legendgroup=group, **kwargs)
    return fig

# Two-moon: accuracy vs. Wasserstein distance

In [None]:
# Two-moon: accuracy of each method against the min-max scaled Wasserstein distance.
# key: color, title
settings = {'gst': (4, 'GradualSelfTrain'),
            'gdamf': (3, 'GDAMF')}

fig = px.scatter()
for key in settings:
    color, name = settings[key]
    res = pd.read_pickle(f'two-moon_{key}.pkl')
    # Average the WD over axis 1, ignoring NaN entries (presumably failed trials -- confirm).
    x = np.nanmean(res['wd'], axis=1)
    # Min-max scale to [0, 1]. Use nan-aware min/max to match the nanmean above: plain
    # .min()/.max() return NaN as soon as one entry is NaN, which would blank the whole curve.
    x_min, x_max = np.nanmin(x, axis=0), np.nanmax(x, axis=0)
    x = (x - x_min) / (x_max - x_min)
    fig = add_var_scatter_plot(fig, x, res['acc'], color, name)

fig.update_layout(
    xaxis_title='Wasserstein distance between domains', yaxis_title='Accuracy',
    font=dict(family="PTSerif", size=14,),
    width=550, height=400, margin=dict(t=30, b=30, r=30),
    legend=dict(orientation="h", bordercolor="Black", borderwidth=0.3, yanchor="bottom", y=-0.35, xanchor="center", x=0.5),
)

fig.write_image('r1_two-moon.pdf')
fig.show()

# Wasserstein distance vs. number of intermediate domains

In [None]:
# Wasserstein distance (res['wd']) against the number of intermediate domains, per dataset.
# Each column of res['wd'] is min-max scaled along axis 0, so every curve lies in [0, 1].
# key: color, title
settings = {'cover': (2, 'Cover Type'),
            'portraits': (1, 'Portraits'),
            'mnist': (0, 'Rotating MNIST'),}

fig = px.scatter()

for key, (color, name) in settings.items():
    res = pd.read_pickle(f'wd_{key}.pkl')
    wd = res['wd']
    wd_min = wd.min(axis=0)
    wd_range = wd.max(axis=0) - wd_min
    fig = add_var_scatter_plot(fig, res['num_inter'], (wd - wd_min) / wd_range, color, name)

# Dashed reference line at the middle of the scaled range.
fig.add_hline(y=0.5, line_width=1, line_dash="dash", line_color='gray')
fig.update_layout(
    xaxis_title='number of intermediate domains', yaxis_title='Wasserstein distance',
    font=dict(family="PTSerif", size=14,),
    width=550, height=400, margin=dict(t=30, b=30, r=30),
    legend=dict(orientation="h", bordercolor="Black", borderwidth=0.3, yanchor="bottom", y=-0.35, xanchor="center", x=0.5),
)

fig.write_image('r1_wd.pdf')
fig.show()

# Accuracy vs. budget

In [None]:
# Accuracy as a function of the budget: one subplot per dataset, mean +/- std band per method.
titles = ['Rotating MNIST', 'Portraits', 'Cover Type', 'Gas Sensor']
fig = make_subplots(rows=1, cols=len(titles), y_title='accuracy', shared_xaxes=True, shared_yaxes=True,
                    subplot_titles=titles, horizontal_spacing=0.03)

# dataset key -> subplot column (1-based; must agree with the order of `titles`)
data_dict = {'gas': 4, 'cover': 3, 'portraits': 2, 'mnist': 1}
# method key -> (palette color index, legend label); commented entries are
# additional ablations that can be switched back on.
method_dict = {#'targetonly-al': (6, 'Target only w AL'),
               'targetonly': (5, 'Target only'),
               #'gdamf-ws': (4, 'GDAMF w/o warm-starting'),
               #'gdamf-abl': (3, 'GDAMF w/o AL, intermediate'),
               #'gdamf-direct': (2, 'GDAMF w/o intermediate'),
               #'gdamf-rnd': (1, 'GDAMF w/o AL'),
               'gdamf': (3, 'GDAMF')}

for i, (d_key, pos) in enumerate(data_dict.items()):
    # Only the first dataset drawn contributes legend entries, so each method is listed once.
    showlegend = i == 0
    for m_key, (color, legend) in method_dict.items():
        res = pd.read_pickle(f'{d_key}_{m_key}.pkl')
        fig = add_var_scatter_plot(fig, res['budgets'], res['acc'], color, legend, showlegend, row=1, col=pos)

fig.update_layout(
    **{f'xaxis{k}_title': 'budgets' for k in range(1, len(titles) + 1)},
    font=dict(family="PTSerif", size=14,),
    width=1400, height=400, margin=dict(t=30, b=30, r=30),
    legend=dict(orientation="h", bordercolor="Black", borderwidth=0.3, yanchor="bottom", y=-0.35, xanchor="center", x=0.5),
)

fig.write_image('r1_budget.pdf')
fig.show()

# Number of queried samples per query cost

In [None]:
# Number of queried samples at each query-cost level (bar = mean, error bar = sample std),
# one subplot per dataset.
titles = ['Rotating MNIST', 'Portraits', 'Cover Type', 'Gas Sensor']
fig = make_subplots(rows=1, cols=len(titles), y_title='Number of queried samples', shared_xaxes=True, shared_yaxes=False,
                    subplot_titles=titles, horizontal_spacing=0.03)

# key: position (1-based subplot column, must agree with the order of `titles`)
data_dict = {'gas': 4,
             'cover': 3,
             'portraits': 2,
             'mnist': 1}
# key: color, legend
# Legend labels use the method name GDAMF, as in the other figures and the results table
# (they previously read 'GDAMW').
method_dict = {'query': ('#636EFA', 'GDAMF'),
               'query-ws': ('#EF553B', 'GDAMF w/o warm-starting')}

for i, d in enumerate(data_dict):
    showlegend = i == 0  # legend entries from one dataset only, so each method is listed once
    for m in method_dict:
        # These pickles are named '<method>_<dataset>.pkl' (the reverse of the budget files).
        res = pd.read_pickle(f'{m}_{d}.pkl')
        queried = res['query'][-1]
        mean = queried.mean(axis=0)
        # Sample std (ddof=1), consistent with the shaded bands drawn by add_var_scatter_plot;
        # numpy's .std defaults to ddof=0.
        std = queried.std(axis=0, ddof=1)
        if d == 'gas':
            # NOTE(review): appends two NaN levels to the gas curves, presumably because gas has
            # two fewer query-cost levels than the other datasets -- confirm.
            mean, std = np.append(mean, [np.nan, np.nan]), np.append(std, [np.nan, np.nan])
        x = np.arange(mean.size) + 1  # query-cost levels numbered from 1
        fig.add_bar(x=x, y=mean, error_y=dict(type='data', array=std), showlegend=showlegend,
                    row=1, col=data_dict[d], marker_color=method_dict[m][0], name=method_dict[m][1])

xt = 'Query cost'
fig.update_layout(xaxis1_title=xt, xaxis2_title=xt, xaxis3_title=xt, xaxis4_title=xt,
                  font=dict(family="PTSerif", size=14,),
                  width=1400, height=400, margin=dict(t=30, b=30, r=30),
                  legend=dict(orientation="h", bordercolor="Black", borderwidth=0.3, yanchor="bottom", y=-0.35, xanchor="center", x=0.5))

fig.write_image('r1_query.pdf')
fig.show()

# table and statistical tests

In [None]:
# dataset key -> column name in the results table
col_dict = {'mnist': 'Rotating MNIST', 'portraits': 'Portraits', 'cover': 'Cover Type', 'gas': 'Gas Sensor'}
# method key (also the pickle-file suffix) -> row name in the results table
row_dict = {'gdamf': 'GDAMF',
            'gdamf-rnd': 'GDAMF w/o AL',
            'gdamf-direct': 'GDAMF w/o intermediate',
            'gdamf-abl': 'GDAMF w/o AL/intermediate',
            'gdamf-ws': 'GDAMF w/o warm-starting',
            'targetonly': 'Target Only',
            'gst': 'GradualSelfTrain',
            'dsaoda': 'DS-AODA',
            'gift': 'GIFT',
            'aux': 'AuxSelfTrain'}


def _method_rows(dataset_key, method_key):
    """Load ``<dataset>_<method>.pkl`` and return its 'acc' entries as table rows.

    - 'gst' methods: row 0 of the accuracy array, one row labelled with the method name.
    - GDAMF variants and Target Only: the last row (``acc[-1, :]``), one labelled row.
    - Any other method: three rows labelled ``<name>-low``, ``<name>-mid``, ``<name>-high``.
    """
    acc = pd.read_pickle(f'{dataset_key}_{method_key}.pkl')['acc']
    label = row_dict[method_key]
    if 'gst' in method_key:
        frame = pd.DataFrame(acc[0]).T
    elif 'gdamf' in method_key or 'target' in method_key:
        frame = pd.DataFrame(acc[-1, :]).T
    else:
        levels = [f'{label}-{level}' for level in ['low', 'mid', 'high']]
        return pd.DataFrame(acc, index=levels)
    frame.index = [label]
    return frame


# dataset key -> DataFrame stacking the rows of every method for that dataset
data = {c: pd.concat([_method_rows(c, r) for r in row_dict]) for c in col_dict}

## table

In [None]:
# make table: one column per dataset, one row per method/setting; each cell is
# util.rounded_statistics of that row's values (presumably one value per trial).
columns = []
for c, frame in data.items():
    cells = [util.rounded_statistics(row) for _, row in frame.iterrows()]
    columns.append(pd.Series(data=cells, name=col_dict[c], index=frame.index))
table = pd.DataFrame(columns).T
# add average: place every dataset's values side by side, then summarise each row
pooled = pd.concat(list(data.values()), axis=1)
table['Average'] = [util.rounded_statistics(row) for _, row in pooled.iterrows()]
# print for latex: "<row name> & <cell> & <cell> ..."
for key, cells in table.iterrows():
    print(f'{key} & ' + ' & '.join(cells.tolist()))

# Best method per column (needs `idx` from the statistical-tests cell):
#table.loc[idx].applymap(lambda s: s.split('±')[0]).astype(float).idxmax(axis=0)

## statistical tests

In [None]:
# Rows compared in the Friedman test and the Nemenyi post-hoc test.
# GDAMF is the reference method whose pairwise p-values are reported.
idx = ['GDAMF', 'Target Only', 'GradualSelfTrain',
       'DS-AODA-low', 'DS-AODA-mid', 'DS-AODA-high',
       'GIFT-low', 'GIFT-mid', 'GIFT-high',
       'AuxSelfTrain-low', 'AuxSelfTrain-mid', 'AuxSelfTrain-high']
reference = 'GDAMF'

st = []
for c in data:
    # After .T: rows = blocks (trials), columns = methods, as both tests expect.
    # NOTE(review): .dropna() runs before .T, so a method with any NaN value is dropped from
    # the comparison entirely (rather than dropping that trial for all methods) -- confirm intended.
    df = data[c].loc[idx].dropna().T.copy()
    _, pvalue = ss.friedmanchisquare(*df.values.T)
    txt = 'Friedman Result ' + col_dict[c] + f' pvalue = {pvalue}'
    print(txt)
    # p-values of the reference method vs. every other method; drop the self-comparison by
    # label rather than by position so the result does not depend on row order.
    st.append(sp.posthoc_nemenyi_friedman(df)[reference].drop(reference))
st = pd.concat(st, axis=1)
st.columns = [col_dict[c] for c in data]  # same order as the loop above
print()
print('Nemenyi Result')
for key, v in st.round(3).astype(str).iterrows():
    txt = ' & '.join(v.tolist())
    print(f'{key} & {txt}')