In [1]:
import pandas as pd

In [2]:
# Experiment dimensions: how many agents took part and how many
# class-count columns (cls_0, cls_1, ...) each result row carries.
n_agents, n_classes = 10, 10

# Load the result tables of both runs; the file names are keyed by the
# agent count. `no_route_data` serves as the baseline further down.
no_route_data, oracle_data = [
    pd.read_csv(csv_name)
    for csv_name in (
        f'mnist_ll_no_routing_n{n_agents}.csv',
        f'mnist_ll_oracle_n{n_agents}.csv',
    )
]

In [3]:
# Show the oracle-run results as this cell's output to inspect the columns.
oracle_data

Unnamed: 0,agent_id,ll_time,local,train_acc,val_acc,test_acc,past_task_test_acc,cls_0,cls_1,cls_2,cls_3,cls_4,cls_5,cls_6,cls_7,cls_8,cls_9
0,0,0,True,1.000000,0.984375,0.988542,0.988542,0,0,0,0,0,128,0,128,0,0
1,1,0,True,0.992188,0.981771,0.981610,0.981610,128,0,128,0,0,0,0,0,0,0
2,2,0,True,1.000000,0.988281,0.992071,0.992071,0,128,0,0,0,0,0,0,0,128
3,3,0,True,0.996094,0.973958,0.979137,0.979137,0,128,0,0,0,0,0,0,128,0
4,4,0,True,1.000000,0.997396,0.998109,0.998109,128,128,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,5,4,False,0.992188,0.930990,0.928125,0.954000,696,385,399,776,702,250,664,266,1134,1079
96,6,4,False,0.996094,0.953125,0.948310,0.948600,235,771,248,555,716,883,441,993,383,756
97,7,4,False,0.980469,0.875000,0.907380,0.945800,356,746,690,226,691,414,703,1077,1108,229
98,8,4,False,0.984375,0.936198,0.947712,0.951400,230,393,534,755,849,592,965,901,383,252


In [4]:
# compute the average number of data points per agent per class
# after sharing at round ll_time
def avg_data_per_agent(df, ll_time, num_agents=None, num_classes=None):
    """
    Average number of data points per agent per class at round `ll_time`.

    Only rows with `local == False` are considered (the "after sharing"
    rows). For every agent id in `range(num_agents)` the first matching
    row is taken and its `cls_0 ... cls_{num_classes - 1}` columns, which
    hold the number of data points per class, are summed up.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns `agent_id`, `ll_time`, `local` and
        `cls_<i>` for every `i` in `range(num_classes)`.
    ll_time : int
        Value of the `ll_time` column to select.
    num_agents : int, optional
        Number of agents to average over; falls back to the notebook-level
        `n_agents` when omitted.
    num_classes : int, optional
        Number of `cls_<i>` columns to sum; falls back to the
        notebook-level `n_classes` when omitted.

    Returns
    -------
    float
        Total data-point count over all agents divided by
        `num_agents * num_classes`.

    Raises
    ------
    IndexError
        If some agent has no row with `local == False` at `ll_time`.
    """
    if num_agents is None:
        num_agents = n_agents
    if num_classes is None:
        num_classes = n_classes

    cls_cols = [f'cls_{i}' for i in range(num_classes)]
    # The round / local filter is identical for every agent: apply it once
    # instead of re-scanning the whole frame inside the loop.
    shared = df[df['ll_time'].eq(ll_time) & df['local'].eq(False)]

    total = 0
    for agent in range(num_agents):
        agent_rows = shared.loc[shared['agent_id'] == agent, cls_cols]
        if agent_rows.empty:
            # Same exception type the bare `.values[0]` lookup would raise,
            # but naming the row that is actually missing.
            raise IndexError(
                f"no row with local == False for agent_id={agent}, "
                f"ll_time={ll_time}"
            )
        # First matching row only; the numpy sum keeps NaN visible instead
        # of silently skipping it.
        total += agent_rows.iloc[0].to_numpy().sum()
    return total / (num_agents * num_classes)


In [5]:
def avg_metric_per_agent(df, ll_time, metric_key, num_agents=None):
    """
    Average of column `metric_key` over agents at round `ll_time`,
    after assimilation (i.e. using only rows with `local == False`).

    For every agent id in `range(num_agents)` the first matching row is
    taken and its `metric_key` value is added to the running total.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns `agent_id`, `ll_time`, `local` and
        `metric_key`.
    ll_time : int
        Value of the `ll_time` column to select.
    metric_key : str
        Name of the column to average (e.g. "test_acc").
    num_agents : int, optional
        Number of agents to average over; falls back to the notebook-level
        `n_agents` when omitted.

    Returns
    -------
    float
        Sum of the per-agent metric values divided by `num_agents`.

    Raises
    ------
    IndexError
        If some agent has no row with `local == False` at `ll_time`.
    """
    if num_agents is None:
        num_agents = n_agents

    # The round / local filter is identical for every agent: apply it once
    # instead of re-scanning the whole frame inside the loop.
    assimilated = df[df['ll_time'].eq(ll_time) & df['local'].eq(False)]

    total = 0
    for agent in range(num_agents):
        values = assimilated.loc[assimilated['agent_id'] == agent, metric_key]
        if values.empty:
            # Same exception type the bare `.values[0]` lookup would raise,
            # but naming the row that is actually missing.
            raise IndexError(
                f"no row with local == False for agent_id={agent}, "
                f"ll_time={ll_time}"
            )
        total += values.iloc[0]  # first matching row only
    return total / num_agents

In [6]:
# For ll_time 0..4, print avg_data_per_agent for the no-routing run
# (baseline) and the oracle run, plus the oracle's relative gain
# over the baseline.
for ll_time in range(5):
    baseline, our = (
        avg_data_per_agent(run_df, ll_time)
        for run_df in (no_route_data, oracle_data)
    )
    improv = (our - baseline) / baseline
    print("no_route:", baseline)
    print("oracle:", our)
    # the doubled newline leaves a blank line between rounds
    print(f"improvement: {improv :.3%}", end="\n\n")


no_route: 25.6
oracle: 60.2
improvement: 135.156%

no_route: 51.2
oracle: 188.08
improvement: 267.344%

no_route: 76.8
oracle: 328.88
improvement: 328.229%

no_route: 102.4
oracle: 469.68
improvement: 358.672%

no_route: 128.0
oracle: 610.48
improvement: 376.938%



In [7]:
# For ll_time 0..4, print the per-agent average of `test_acc` for the
# no-routing run (baseline) and the oracle run, plus the oracle's relative
# gain over the baseline.
for ll_time in range(5):
    baseline, our = (
        avg_metric_per_agent(run_df, ll_time, "test_acc")
        for run_df in (no_route_data, oracle_data)
    )
    improv = (our - baseline) / baseline
    # one line per value, then a blank line between rounds
    print(
        f"no_route test task: {baseline:.3%}",
        f"oracle test task: {our:.3%}",
        f"improvement: {improv :.3%}",
        sep="\n",
        end="\n\n",
    )

no_route test task: 98.622%
oracle test task: 98.936%
improvement: 0.319%

no_route test task: 95.442%
oracle test task: 96.912%
improvement: 1.539%

no_route test task: 94.691%
oracle test task: 96.510%
improvement: 1.921%

no_route test task: 93.559%
oracle test task: 94.528%
improvement: 1.035%

no_route test task: 92.508%
oracle test task: 93.531%
improvement: 1.106%



With n_agents = 10 and an average of about 610 data points per agent per class (oracle run, final round), the performance is not in line with the previous analysis. Hypothesis: the validation set used for offline buffer integration needs to be changed. We might be stopping too soon, i.e., the validation set is still too small!

In [8]:
# For ll_time 0..4, print the per-agent average of `past_task_test_acc`
# for the no-routing run (baseline) and the oracle run, plus the oracle's
# relative gain over the baseline.
metric = "past_task_test_acc"
for ll_time in range(5):
    baseline, our = (
        avg_metric_per_agent(run_df, ll_time, "past_task_test_acc")
        for run_df in (no_route_data, oracle_data)
    )
    improv = (our - baseline) / baseline
    # one line per value, then a blank line between rounds
    print(
        f"no_route past_task_test_acc: {baseline:.3%}",
        f"oracle past_task_test_acc: {our:.3%}",
        f"improvement: {improv :.3%}",
        sep="\n",
        end="\n\n",
    )

no_route past_task_test_acc: 98.622%
oracle past_task_test_acc: 98.936%
improvement: 0.319%

no_route past_task_test_acc: 95.704%
oracle past_task_test_acc: 97.028%
improvement: 1.384%

no_route past_task_test_acc: 94.587%
oracle past_task_test_acc: 96.227%
improvement: 1.734%

no_route past_task_test_acc: 93.657%
oracle past_task_test_acc: 95.430%
improvement: 1.893%

no_route past_task_test_acc: 93.262%
oracle past_task_test_acc: 94.999%
improvement: 1.862%

