# Comparison plots for Homo-PSI

In [None]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before each cell runs, so edits to local helper modules are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from data_cleaner import load_full_df, create_individual_dfs

### Config style

In [None]:
# Global figure styling for every plot in this notebook.
sns.set_style()

# Text is rendered through LaTeX in a serif face; the preamble loads the
# mathptmx package.
plt.rcParams.update({
    'text.usetex': True,
    'font.family': 'serif',
    'font.size': 15,
    'figure.figsize': (5.5, 4),
    'text.latex.preamble': r'\usepackage{mathptmx}',
})

# Prefix prepended to every saved figure's file name ('' leaves the names as-is).
FIG_DIR = ''

### Load data

In [None]:
# Directories handed to the data_cleaner helper.
# NOTE(review): presumably raw input first, aggregated output second --
# confirm against data_cleaner.create_individual_dfs.
raw_dir = '../data/raw/'
agg_dir = '../data/agg/'
create_individual_dfs(raw_dir, agg_dir)

In [None]:
# Load the combined measurements and keep only rows with SetNum > 4.
# .copy() detaches the filtered frame from the loaded one, so the in-place
# latency adjustment below cannot trigger pandas' SettingWithCopyWarning.
data = load_full_df('../data/').query('SetNum > 4').copy()

# Add artificial network delay to our X-MS and CA-MS data.
RTT_S = 0.1  # one round trip, in seconds
# NOTE(review): 100/8 presumably models a 100 Mbit/s link expressed in
# (mega)bytes per second -- confirm the intended bandwidth and unit.
LINK_MIB_PER_S = 100 / 8

# Select rows with a boolean mask rather than index labels: the mask touches
# exactly the matching rows even when the frame's index is not unique.
adj_mask = data['name'].isin(["CA-MS", "X-MS"])
data.loc[adj_mask, 'Latency'] += RTT_S  # add 1 rtt
data.loc[adj_mask, 'Latency'] += data.loc[adj_mask, 'com_MiB'] / LINK_MIB_PER_S  # add transfer cost

# data

## Plot base

In [None]:
# Per-series drawing options, keyed by the value of the data's 'name' column.
# Each entry is forwarded as keyword arguments to Axes.errorbar by base_plot.

# The two aggregation-based series appear with identical styling in both
# figures, so they are declared once and copied into each config.
_AGG_STYLES = {
    'CA-MS': dict(label="CA-Agg $(P_{8k})$", color='#33a02c', fmt='*-'),
    'X-MS': dict(label="X-Agg $(P_{32k})$", color='#b2df8a', fmt='o:'),
}

# Series shown in the main comparison figures.
main_plot_config = {
    **{series: dict(style) for series, style in _AGG_STYLES.items()},
    'EMP-CA': dict(label="CA-EMP", color='#1f78b4', fmt='*-'),
    'EMP-X': dict(label="X-EMP", color='#a6cee3', fmt='o:'),
    'Circuit-PSI': dict(label="Circuit-PSI", color='#e31a1c', fmt='x-.'),
}

# Series shown in the SpOT comparison figure.
spot_plot_config = {
    **{series: dict(style) for series, style in _AGG_STYLES.items()},
    'SpOT': dict(label="SpOT", color='#e31a1c', fmt='x-.'),
}

In [None]:
def base_plot(
    ax: plt.Axes,
    df: pd.DataFrame,
    target_y: str,
    y_label: str,
    config: dict[str, dict[str, str]],
    ):
    """Draw one log-log metric-vs-SetNum panel with an error-bar curve per series.

    The rows of *df* are grouped by ``(SetNum, name)``; for each series listed
    in *config* the group mean of *target_y* is plotted against ``SetNum``,
    with the standard error of the mean as the error bar.

    Args:
        ax: Axes to draw on; its scales, labels and x ticks are overwritten.
        df: Measurements with at least the columns ``SetNum``, ``name`` and
            *target_y*.
        target_y: Name of the column to aggregate and plot on the y axis.
        y_label: Text for the y-axis label.
        config: Maps a value of the ``name`` column to the keyword arguments
            (e.g. ``label``, ``color``, ``fmt``) forwarded to ``ax.errorbar``.
            Series are drawn in the mapping's iteration order.

    Returns:
        None. The function only mutates *ax*.
    """
    ax.set_yscale('log')
    ax.set_xscale('log')
    # Raw string: the backslash must reach LaTeX untouched, and a plain
    # '\#' is an invalid escape sequence that newer Pythons warn about.
    ax.set_xlabel(r'\#Documents')
    ax.set_ylabel(y_label)

    # Labelled major ticks at every other power of two, with unlabelled
    # minor ticks at the powers in between.
    major_ticks = [8, 32, 128, 512, 2048, 8192]
    ax.set_xticks(major_ticks, major_ticks)
    minor_ticks = [16, 64, 256, 1024, 4096]
    ax.xaxis.set_minor_locator(plt.FixedLocator(minor_ticks))
    ax.xaxis.set_minor_formatter(plt.NullFormatter())

    # Aggregate only the requested column: this avoids computing statistics
    # for every other column and keeps unrelated non-numeric columns from
    # breaking the mean/sem aggregation.
    stats = (
        df.groupby(['SetNum', 'name'])[target_y]
        .agg(['mean', 'sem'])
        .reset_index()
    )

    for name, style in config.items():
        series = stats[stats['name'] == name]
        ax.errorbar(
            series['SetNum'],
            series['mean'],
            yerr=series['sem'],
            linewidth=2,
            **style,
        )

    


### Latency

In [None]:
# Latency panel for every series in main_plot_config.
fig, ax = plt.subplots()
base_plot(
    ax=ax,
    df=data,
    target_y='Latency',
    y_label='Latency (s)',
    config=main_plot_config,
)

# Save a tightly cropped PDF, then display the figure.
latency_pdf = FIG_DIR + "doc_latency.pdf"
fig.savefig(latency_pdf, bbox_inches='tight', pad_inches=0.01)
plt.show()

### Client computation

In [None]:
# Client-side computation panel for every series in main_plot_config.
fig, ax = plt.subplots()
base_plot(
    ax=ax,
    df=data,
    target_y='client_comp',
    y_label="Client's computation cost (s)",
    config=main_plot_config,
)

# Save a tightly cropped PDF, then display the figure.
client_pdf = FIG_DIR + "doc_client.pdf"
fig.savefig(client_pdf, bbox_inches='tight', pad_inches=0.01)
plt.show()

### Communication

In [None]:
# Communication (transfer volume) panel for every series in main_plot_config.
fig, ax = plt.subplots()
base_plot(
    ax=ax,
    df=data,
    target_y='com_MiB',
    y_label="Transfer cost (MiB)",
    config=main_plot_config,
)

# Save a tightly cropped PDF, then display the figure.
com_pdf = FIG_DIR + "doc_com.pdf"
fig.savefig(com_pdf, bbox_inches='tight', pad_inches=0.01)
plt.show()

### Server cost

In [None]:
# Server-side computation panel; this is the only single-panel figure that
# also carries a legend.
fig, ax = plt.subplots()
base_plot(
    ax=ax,
    df=data,
    target_y='server_comp',
    y_label="Server's computation cost (s)",
    config=main_plot_config,
)

# Legend box given as (x, y, width, height) in figure coordinates; "expand"
# stretches the three columns across the full width of that box.
legend_box = (0.002, 0.92, 0.9, .102)
ax.legend(
    loc='lower left',
    bbox_to_anchor=legend_box,
    bbox_transform=fig.transFigure,
    ncol=3,
    mode="expand",
    borderaxespad=0.,
)

server_pdf = FIG_DIR + "doc_server.pdf"
fig.savefig(server_pdf, bbox_inches='tight', pad_inches=0.01)
plt.show()

## Triple plot

In [None]:
# One row of three panels: latency, client computation, transfer cost.
# plt.subplots(1, 3) already returns a 1-D array of three Axes, so it can be
# unpacked directly.
fig, axes = plt.subplots(1, 3, figsize=(18, 3.9))
ax_latency, ax_client, ax_transfer = axes

base_plot(ax_latency, data, 'Latency', 'Latency (s)', main_plot_config)
base_plot(ax_client, data, 'client_comp', "Client's computation cost (s)", main_plot_config)
base_plot(ax_transfer, data, 'com_MiB', "Transfer cost (MiB)", main_plot_config)

# A single legend, built from the first panel's series and placed in figure
# coordinates above the row. Entries keep the drawing order of
# main_plot_config; to reorder them, pass explicit handles=/labels= obtained
# from ax_latency.get_legend_handles_labels().
ax_latency.legend(
    bbox_to_anchor=(0.2, 0.92, 0.60, .102),
    loc='lower left',
    ncol=5,
    mode="expand",
    borderaxespad=0.,
    bbox_transform=fig.transFigure,
)

fig.savefig(FIG_DIR + "doc_search.pdf", bbox_inches='tight', pad_inches=0.01)
plt.show()

## SpOT plot

In [None]:
# 2x2 grid comparing the series in spot_plot_config on all four metrics.
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# (row, column) of the panel, the metric column to plot, and its y-axis label.
panels = [
    ((0, 0), 'Latency', 'Latency (s)'),
    ((0, 1), 'client_comp', "Client's computation cost (s)"),
    ((1, 1), 'server_comp', "Server's computation cost (s)"),
    ((1, 0), 'com_MiB', "Transfer cost (MiB)"),
]
for position, metric, axis_label in panels:
    base_plot(axes[position], data, metric, axis_label, spot_plot_config)

# One legend for the whole grid, taken from the top-left panel and positioned
# in figure coordinates above the panels.
legend_box = (0.25, 0.90, 0.50, .102)
axes[0, 0].legend(
    loc='lower left',
    bbox_to_anchor=legend_box,
    bbox_transform=fig.transFigure,
    ncol=3,
    mode="expand",
    borderaxespad=0.,
)

spot_pdf = FIG_DIR + "doc_spot.pdf"
fig.savefig(spot_pdf, bbox_inches='tight', pad_inches=0.01)
plt.show()

## Getting improvement factors

In [None]:
# Mean and standard error of every column, per series, at SetNum == 1024.
# The aggregation stays the cell's last expression so the notebook displays
# the resulting table.
at_1024 = data[data['SetNum'] == 1024]
at_1024.groupby(['SetNum', 'name']).agg(['mean', 'sem'])