In [None]:
import numpy as np
import pandas as pd
import plotly.express as px
from plotly.subplots import make_subplots
#import util

In [None]:
def add_var_scatter_plot(fig, x, y, color, name=None, showlegend=True, **kwargs):
    """
    Draw the row-wise mean of `y` against `x`, surrounded by a shaded mean ± std band.

    @param
    fig: go.Figure (any object exposing `add_scatter`); three traces are appended to it
    x: x coordinates, passed unchanged to every trace
    y: 2-D array-like; mean and sample standard deviation (ddof=1) are taken along axis 1
    color: int, index into px.colors.qualitative.Plotly (10 colors prepared, select 0 to 9)
    name: str, the name of plot (used for the mean trace only)
    showlegend: bool, whether the mean trace gets a legend entry
    kwargs: forwarded unchanged to every `add_scatter` call (e.g. row=..., col=...)

    @return
    the same `fig` that was passed in
    """
    channels = list(px.colors.hex_to_rgb(px.colors.qualitative.Plotly[color]))
    line_color = f'rgb{tuple(channels)}'
    band_color = f'rgba{tuple(channels + [0.3])}'  # same hue, opacity = 0.3

    center = np.mean(y, axis=1)
    spread = np.std(y, ddof=1, axis=1)

    # 1) The mean curve: the only trace that can appear in the legend.
    fig.add_scatter(
        x=x, y=center, name=name, mode='markers+lines',
        line={'color': line_color}, showlegend=showlegend, **kwargs,
    )
    # 2) Invisible upper bound of the band.
    fig.add_scatter(
        x=x, y=center + spread, mode='lines', line={'width': 0},
        showlegend=False, hoverinfo='none', **kwargs,
    )
    # 3) Invisible lower bound, filled with "tonexty"; it must be added right
    #    after the upper bound so the fill lands between the two bounds.
    fig.add_scatter(
        x=x, y=center - spread, mode='lines', fill="tonexty", line={'width': 0},
        showlegend=False, hoverinfo='none', fillcolor=band_color, **kwargs,
    )
    return fig


def add_box_plot(fig, y, color, **kwargs):
    """
    Append one box trace, filled with a palette color and outlined in black, to `fig`.

    @param
    fig: go.Figure (any object exposing `add_box`)
    y: array-like sample for the box; singleton dimensions are squeezed away,
       but the result is always kept at least 1-D
    color: int, index into px.colors.qualitative.Plotly followed by
           px.colors.qualitative.D3 (so more indices are valid than with the
           Plotly palette alone)
    kwargs: forwarded unchanged to `fig.add_box` (e.g. name=..., showlegend=..., row=..., col=...)

    @return
    the same `fig` that was passed in
    """
    palette = px.colors.qualitative.Plotly + px.colors.qualitative.D3
    rgb = 'rgb' + str(tuple(px.colors.hex_to_rgb(palette[color])))
    black = 'rgb(0,0,0)'
    # squeeze() turns a single-value sample (e.g. [[0.9]]) into a 0-d array;
    # atleast_1d restores a length-1 vector so add_box always receives a 1-D+ array.
    y = np.atleast_1d(np.array(y).squeeze())
    fig.add_box(y=y, fillcolor=rgb, line=dict(color=black), **kwargs)
    return fig


def print_mean_std(array, rnd=3, ddof=0):
    """
    Format the mean and standard deviation of `array` as the string 'mean±std'.

    Despite the name, nothing is printed; the formatted text is returned.

    @param
    array: array-like; mean and std are taken over all elements (no axis is given)
    rnd: int, number of decimals both values are rounded to
    ddof: int, delta degrees of freedom handed to np.std. The default 0 keeps the
          original population standard deviation; pass 1 to get the sample
          standard deviation that add_var_scatter_plot uses for its shaded band.

    @return
    str, e.g. '2.5±1.118'
    """
    mean = round(np.mean(array), rnd)
    std = round(np.std(array, ddof=ddof), rnd)
    txt = f'{mean}±{std}'
    return txt

# two-moon

In [None]:
# Two-moon: accuracy plotted against the min-max scaled row mean of res['wd'].
# result-file key -> (palette index, legend title)
settings = {
    'gst': (4, 'GradualSelfTrain'),
    'gdamf': (3, 'GDAMF'),
}

fig = px.scatter()
for key, (color, name) in settings.items():
    res = pd.read_pickle(f'./result/two-moon_{key}.pkl')
    x = np.nanmean(res['wd'], axis=1)  # NaN entries are ignored in the mean
    x_low, x_high = x.min(axis=0), x.max(axis=0)
    x = (x - x_low) / (x_high - x_low)  # scaling to [0, 1]
    fig = add_var_scatter_plot(fig, x, res['acc'], color, name)

# Horizontal legend box placed below the plot area.
legend_style = dict(orientation="h", bordercolor="Black", borderwidth=0.3,
                    yanchor="bottom", y=-0.35, xanchor="center", x=0.5)
fig.update_layout(
    xaxis_title='Wasserstein distance between domains',
    yaxis_title='Accuracy',
    font=dict(family="PTSerif", size=14),
    width=550,
    height=400,
    margin=dict(t=30, b=30, r=30),
    legend=legend_style,
)
fig.update_annotations(font=dict(size=22))

fig.write_image('r2_two-moon.pdf')
fig.show()

# Wasserstein distance (WD)

In [None]:
# Scaled Wasserstein distance versus the number of intermediate domains.
# dataset key -> (palette index, legend title)
settings = {
    'cover': (2, 'Cover Type'),
    'portraits': (1, 'Portraits'),
    'mnist': (0, 'Rotating MNIST'),
}

fig = px.scatter()

for key, (color, name) in settings.items():
    res = pd.read_pickle(f'./result/wd_{key}.pkl')
    x, y = res['num_inter'], res['wd']
    y_low, y_high = y.min(axis=0), y.max(axis=0)
    y_scaled = (y - y_low) / (y_high - y_low)  # min-max scaling along axis 0
    # Position whose row mean is closest to the threshold = 0.5.
    # NOTE(review): the count printed is th_idx+1, not x[th_idx]; this assumes
    # res['num_inter'] is 1, 2, 3, ... — TODO confirm.
    th_idx = np.argmin(np.abs(y_scaled.mean(axis=1) - 0.5))
    print(f'{name},  num inter domain: {th_idx+1}')
    fig = add_var_scatter_plot(fig, x, y_scaled, color, name)

# Dashed guide line at the 0.5 threshold used above.
fig.add_hline(y=0.5, line_width=1, line_dash="dash", line_color='gray')

legend_style = dict(orientation="h", bordercolor="Black", borderwidth=0.3,
                    yanchor="bottom", y=-0.35, xanchor="center", x=0.5)
fig.update_layout(
    xaxis_title='number of intermediate domains',
    yaxis_title='Wasserstein distance',
    font=dict(family="PTSerif", size=14),
    width=550,
    height=400,
    margin=dict(t=30, b=30, r=30),
    legend=legend_style,
)
fig.update_annotations(font=dict(size=22))

fig.write_image('r2_wd.pdf')
fig.show()

# Budgets

In [None]:
# Accuracy as a function of the budget, one subplot per dataset.
titles = ['Two Moon', 'Rotating MNIST', 'Portraits', 'Cover Type', 'Gas Sensor']
fig = make_subplots(rows=1, cols=len(titles), y_title='Accuracy', shared_xaxes=True, shared_yaxes=True,
                    subplot_titles=titles, horizontal_spacing=0.03)

# key: position (1-based subplot column, in the same order as `titles`)
data_dict = {'gas': 5,
             'cover': 4,
             'portraits': 3,
             'mnist': 2,
             'rotmoon': 1}
# key: color, legend (comment entries in or out to choose which methods are drawn)
method_dict = {#'targetonly-al': (6, 'Target only w AL'),
               'targetonly': (5, 'Target only'),
               #'gdamf-ws': (4, 'GDAMF w/o warm-starting'),
               #'gdamf-abl': (3, 'GDAMF w/o AL, intermediate'),
               #'gdamf-direct': (2, 'GDAMF w/o intermediate'),
               #'gdamf-rnd': (1, 'GDAMF w/o AL'),
               'gdamf': (3, 'GDAMF')}

for i, d_key in enumerate(data_dict):
    # Only the first dataset contributes legend entries, so each method is listed once.
    showlegend = i == 0
    pos = data_dict[d_key]
    for m_key in method_dict:
        color, legend = method_dict[m_key]
        res = pd.read_pickle(f'./result/{d_key}_{m_key}.pkl')
        x = res['budgets']
        y = res['acc']

        fig = add_var_scatter_plot(fig, x, y, color, legend, showlegend, row=1, col=pos)

# Build one x-axis title per subplot from len(titles). A hand-written
# xaxis1..xaxis4 list would leave the fifth subplot (Gas Sensor) unlabeled.
xaxis_titles = {f'xaxis{i+1}_title': 'budgets' for i in range(len(titles))}
fig.update_layout(
    **xaxis_titles,
    font=dict(family="PTSerif", size=14,),
    width=1800, height=400, margin=dict(t=30, b=30, r=30),
    legend=dict(orientation="h", bordercolor="Black", borderwidth=0.3, yanchor="bottom", y=-0.35, xanchor="center", x=0.5),
)
fig.update_annotations(font=dict(size=22))
fig.write_image('r3_budget.pdf')
fig.show()

# Query

In [None]:
# Bar charts of the mean number of queried samples (with std error bars), one subplot per dataset.
titles = ['Two Moon','Rotating MNIST', 'Portraits', 'Cover Type', 'Gas Sensor']
fig = make_subplots(
    rows=1, cols=len(titles),
    subplot_titles=titles, horizontal_spacing=0.03,
    y_title='Number of queried samples', shared_xaxes=True, shared_yaxes=False,
)

# dataset key -> 1-based subplot column
data_dict = {
    'gas': 5,
    'cover': 4,
    'portraits': 3,
    'mnist': 2,
    'rotmoon': 1,
}
# method key -> (bar color, legend label)
method_dict = {
    'query': ('#636EFA', 'GDAMF'),
    'query-ws': ('#EF553B', 'GDAMF w/o warm-starting'),
}

for i, d in enumerate(data_dict):
    showlegend = (i == 1)  # legend entries are taken from the second dataset only
    for m in method_dict:
        bar_color, bar_name = method_dict[m]
        res = pd.read_pickle(f'./result/{m}_{d}.pkl')
        last_query = res['query'][-1]
        mean, std = last_query.mean(axis=0), last_query.std(axis=0)
        if d in ('gas', 'rotmoon'):
            # Two NaN slots are appended for these two datasets
            # (presumably so they show as many x positions as the others — TODO confirm).
            pad = [np.nan, np.nan]
            mean, std = np.append(mean, pad), np.append(std, pad)
        x = np.arange(mean.size) + 1  # 1-based bar positions
        fig.add_bar(
            x=x, y=mean, error_y=dict(type='data', array=std),
            name=bar_name, marker_color=bar_color, showlegend=showlegend,
            row=1, col=data_dict[d],
        )

# One 'Query cost' title per subplot x axis.
xaxis_titles = {f'xaxis{i+1}_title':'Query cost' for i in range(len(data_dict))}
legend_style = dict(orientation="h", bordercolor="Black", borderwidth=0.3,
                    yanchor="bottom", y=-0.35, xanchor="center", x=0.5)
fig.update_layout(
    **xaxis_titles,
    font=dict(family="PTSerif", size=14),
    width=1800,
    height=400,
    margin=dict(t=30, b=30, r=30),
    legend=legend_style,
)
fig.update_annotations(font=dict(size=22))
fig.write_image('r3_query.pdf')
fig.show()

# Box plot

## ablation study

In [None]:
# Ablation study: one box per GDAMF variant, one subplot per dataset.
# dataset key -> subplot title
data_dict = {
    'rotmoon':'Two Moon',
    'mnist': 'Rotating MNIST',
    'portraits': 'Portraits',
    'cover': 'Cover Type',
    'gas': 'Gas Sensor',
}
# method key -> legend label; the position in this dict is also the palette index
method_dict = {
    'gdamf': 'GDAMF',
    'gdamf-rnd': 'GDAMF w/o AL',
    'gdamf-direct': 'GDAMF w/o intermediate',
    'gdamf-abl': 'GDAMF w/o AL/intermediate',
    'gdamf-ws': 'GDAMF w/o warm-starting',
}

fig = make_subplots(
    rows=1, cols=len(data_dict), y_title='Accuracy', shared_xaxes=True,
    subplot_titles=list(data_dict.values()), horizontal_spacing=0.04,
)

for col, data in enumerate(data_dict):
    showlegend = (col == 0)  # legend entries come from the first subplot only
    for color, (method, name) in enumerate(method_dict.items()):
        loaded = pd.read_pickle(f'./result/{data}_{method}.pkl')
        acc = loaded['acc'][-1]  # only the last entry of 'acc' is plotted
        fig = add_box_plot(fig, acc, color=color, row=1, col=col+1, name=name, showlegend=showlegend)

# The x axes carry no information for box plots, so hide all of them.
xaxis_visible = {f'xaxis{i+1}_visible': False for i in range(len(data_dict))}
legend_style = dict(orientation="h", bordercolor="Black", borderwidth=0.3,
                    yanchor="bottom", y=-0.2, xanchor="center", x=0.5)
fig.update_layout(
    **xaxis_visible,
    width=1800,
    height=400,
    margin=dict(t=30, b=30, r=30),
    font=dict(family="PTSerif", size=18),
    legend=legend_style,
)
fig.update_annotations(font=dict(size=22))
fig.write_image('r3_ablation_study.pdf')
fig.show()


# print table info
# The traces are split into len(data_dict) consecutive groups, one per dataset,
# and mean±std is reported for the data of every box.
traces_per_dataset = np.array_split(fig.data, len(data_dict))
for n, result in zip(data_dict, traces_per_dataset):
    for r in result:
        name, array = r['name'], r['y']
        txt = print_mean_std(array)
        print(f'{n} {name} {txt}')

## compare with baseline methods

In [None]:
# Comparison with baseline methods: box plots per method, one subplot per dataset.
# dataset key -> subplot title
data_dict = {
    'rotmoon':'Two Moon',
    'mnist': 'Rotating MNIST',
    'portraits': 'Portraits',
    'cover': 'Cover Type',
    'gas': 'Gas Sensor',
}
# method key -> legend label
method_dict = {
    'gdamf': 'GDAMF',
    'targetonly': 'Target Only',
    'gst': 'GradualSelfTrain',
    'dsaoda': 'DS-AODA',
    'gift': 'GIFT',
    'aux': 'AuxSelfTrain',
}

fig = make_subplots(
    rows=1, cols=len(data_dict), y_title='Accuracy', shared_xaxes=True,
    subplot_titles=list(data_dict.values()), horizontal_spacing=0.04,
)


for col, data in enumerate(data_dict):
    showlegend = (col == 0)  # legend entries come from the first subplot only
    color = 0  # palette index, advanced by one for every box drawn in this subplot
    for method, name in method_dict.items():
        res = pd.read_pickle(f'./result/{data}_{method}.pkl')['acc']
        if any(tag in method for tag in ('gst', 'gdamf', 'target')):
            # These methods contribute a single box built from the last entry of 'acc'.
            boxes = [(res[-1], name)]
        else:
            # Every other method contributes up to three boxes, labelled -low / -mid / -high
            # in the order the entries appear in 'acc'.
            boxes = [(acc, name + level) for acc, level in zip(res, ['-low', '-mid', '-high'])]
        for acc, box_name in boxes:
            fig = add_box_plot(fig, acc, color=color, row=1, col=col+1, name=box_name, showlegend=showlegend)
            color += 1

# The x axes carry no information for box plots, so hide all of them.
xaxis_visible = {f'xaxis{i+1}_visible': False for i in range(len(data_dict))}
legend_style = dict(orientation="h", bordercolor="Black", borderwidth=0.3,
                    yanchor="bottom", y=-0.4, xanchor="center", x=0.5)
fig.update_layout(
    **xaxis_visible,
    width=1800,
    height=400,
    margin=dict(t=30, b=30, r=30),
    font=dict(family="PTSerif", size=18),
    legend=legend_style,
)
fig.update_annotations(font=dict(size=22))
fig.write_image('r3_baseline.pdf')
fig.show()

# print table info
# The traces are split into len(data_dict) consecutive groups, one per dataset,
# and mean±std is reported for the data of every box.
traces_per_dataset = np.array_split(fig.data, len(data_dict))
for n, result in zip(data_dict, traces_per_dataset):
    for r in result:
        name, array = r['name'], r['y']
        txt = print_mean_std(array)
        print(f'{n} {name} {txt}')

## Source Only

In [None]:
# Print mean±std of the 'acc' entry of the source-only results, one line per dataset.
# dataset key -> display name
data_dict = {
    'rotmoon':'Two Moon',
    'mnist': 'Rotating MNIST',
    'portraits': 'Portraits',
    'cover': 'Cover Type',
    'gas': 'Gas Sensor',
}
# method key -> display name
method_dict = {'sourceonly': 'Source Only'}


for method in method_dict:
    for data, name in data_dict.items():
        res = pd.read_pickle(f'./result/{data}_{method}.pkl')['acc']
        txt = print_mean_std(res)  # statistics over every element of 'acc'
        print(f'{name}: {txt}')


## Query Number of DSAODA

In [None]:
# Print the mean accuracy and mean (integer-truncated) query count of DSAODA per dataset.
# dataset key -> display name
data_dict = {
    'rotmoon':'Two Moon',
    'mnist': 'Rotating MNIST',
    'portraits': 'Portraits',
    'cover': 'Cover Type',
    'gas': 'Gas Sensor',
}
# method key -> display name (not read by the loop below)
method_dict = {'dsaoda': 'DSAODA'}

for data, name in data_dict.items():
    res = pd.read_pickle(f'./result/query_dsaoda_{data}.pkl')
    # assumes the pickle holds exactly two values, in (accuracy, query count) order — TODO confirm
    acc, qnum = res.values()
    acc_mean = [round(np.mean(entry, axis=0), 3) for entry in acc]
    qnum_mean = [np.mean(entry, axis=0).astype(int) for entry in qnum]
    print(f'{name}: Accuracy {acc_mean}, Query {qnum_mean}\n')

