In [1]:
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import torch
from flsim.utils.timing.training_duration_distribution import (
PerUserHalfNormalDurationDistribution,
PerUserHalfNormalDurationDistributionConfig,
)
from flsim.utils.timing.training_time_estimator import (
AsyncTrainingTimeEstimator)
from omegaconf import OmegaConf

def run(duration_min, rounds_async_fl, users_per_round_async_fl, num_users):
    duration_std = 1.25

    num_examples = None
    training_dist = PerUserHalfNormalDurationDistribution(
        **OmegaConf.structured(
            PerUserHalfNormalDurationDistributionConfig(
                training_duration_sd=duration_std,
                training_duration_min=duration_min,
            )
        )
    )

    epochs_async_fl = int(rounds_async_fl * users_per_round_async_fl / num_users)
    print(f"Epochs: {epochs_async_fl}")
    async_estimator = AsyncTrainingTimeEstimator(
        total_users=num_users,
        users_per_round=users_per_round_async_fl,
        epochs=epochs_async_fl,
        num_examples=num_examples,
        training_dist=training_dist,
    )
    async_time = async_estimator.training_time()
    print(f"Async {async_time}")

def bandiwth(num_clients, rounds_async_fl, dim_model):
        bytes_comm = dim_model * num_clients * rounds_async_fl
        # convert from bytes to GB
        bytes_comm = bytes_comm / 1024 / 1024 / 1024
        print(f"Bandwidth: {bytes_comm}")


In [2]:
torch.manual_seed(0)

print("REPLACE-BG")
run(duration_min=0.03, rounds_async_fl=1659, users_per_round_async_fl=16, num_users=180)
bandiwth(16, 1659, 260000)
print("CELEBA")
run(duration_min=0.003, rounds_async_fl=411, users_per_round_async_fl=128, num_users=2337)
bandiwth(128, 411, 31000)
print("SENT140")
run(duration_min=0.015, rounds_async_fl=387, users_per_round_async_fl=256, num_users=3488)
bandiwth(256, 387, 1e6)

REPLACE-BG
Epochs: 147
Async 1704.694091796875
Bandwidth: 6.427466869354248
CELEBA
Epochs: 22
Async 404.9036865234375
Bandwidth: 1.518845558166504
SENT140
Epochs: 28
Async 392.9781188964844
Bandwidth: 92.26799011230469


In [3]:
import pandas as pd

def run_sa(path, num_clients, rounds_async_fl):
    df = pd.read_csv(path)
    df = df[df["clients"] == num_clients]
    df = df.groupby("clients").mean()
    total_time = (num_clients * df["avg client computation time (ms)"]) * rounds_async_fl + rounds_async_fl * df["avg server computation time (ms)"] + rounds_async_fl * (df['decryptors'] * df['avg decryptors computation time (ms)'])
    # convert from ms to minutes
    total_time = total_time / 60000
    #convert to scalar
    total_time = total_time.sum()
    print(f"SA {total_time}")
    
def bandwith_sa(path, num_clients, rounds_async_fl):
    df = pd.read_csv(path)
    df = df[df["clients"] == num_clients]
    df = df.groupby("clients").mean()
    total_bandwidth = (num_clients * df["avg client bytes sent"]) * rounds_async_fl + rounds_async_fl * ((df['avg decryptors bytes sent'] + df['avg decryptors bytes received'] )* df['decryptors'])
    # convert from bytes to GB 
    total_bandwidth = total_bandwidth / 1e9
    #convert to scalar
    total_bandwidth = total_bandwidth.sum()
    print(f"SA bandiwtth {total_bandwidth}")


In [9]:
print("REPLACE-BG")
print("Buffalo")
run_sa("/home/taiello/projects/fl-med-devices/results_fl/ours_ss_f0.01.csv", 16,1659)
bandwith_sa("/home/taiello/projects/fl-med-devices/results_fl/ours_ss_f0.01.csv", 16,1659)
print("DPSecAgg")
run_sa("/home/taiello/projects/fl-med-devices/results_fl/stevens_f0.01.csv", 16,1659)
bandwith_sa("/home/taiello/projects/fl-med-devices/results_fl/stevens_f0.01.csv", 16,1659)

Buffalo
SA 169.04480857610702
SA bandiwtth 14.521054464
DPSecAgg
SA 4322.917563529015
SA bandiwtth 16.377064032


In [10]:
print("Buffalo")
run_sa("/home/taiello/projects/fl-med-devices/results_fl/ours_ss_f0.01.csv", 128,411)
bandwith_sa("/home/taiello/projects/fl-med-devices/results_fl/ours_ss_f0.01.csv", 128,411)
print("DPSecAgg")
run_sa("/home/taiello/projects/fl-med-devices/results_fl/stevens_f0.01.csv", 128,411)
bandwith_sa("/home/taiello/projects/fl-med-devices/results_fl/stevens_f0.01.csv", 128,411)

Buffalo
SA 347.64872247219085
SA bandiwtth 9.830541312
DPSecAgg
SA 8120.888476519584
SA bandiwtth 13.201359456


In [4]:
print("Buffalo")
run_sa("/home/taiello/projects/fl-med-devices/results_fl/ours_ss_f0.01.csv", 256,387)
bandwith_sa("/home/taiello/projects/fl-med-devices/results_fl/ours_ss_f0.01.csv", 256,387)
print("DPSecAgg")
run_sa("/home/taiello/projects/fl-med-devices/results_fl/stevens_f0.01.csv", 256,387)
bandwith_sa("/home/taiello/projects/fl-med-devices/results_fl/stevens_f0.01.csv", 256,387)

Buffalo
SA 822.7216954350472
SA bandiwtth 389.616095232
DPSecAgg
SA 15392.618955688476
SA bandiwtth 395.938981728


In [5]:
print("Buffalo+")
run_sa("/home/taiello/projects/fl-med-devices/results_fl/ours_f0.01.csv", 16, 1659)
bandwith_sa("/home/taiello/projects/fl-med-devices/results_fl/ours_f0.01.csv", 16, 1659)
print("LightVeriFL")
run_sa("/home/taiello/projects/fl-med-devices/results_fl/lightveri-fl_f0.01.csv", 16, 1659)
bandwith_sa("/home/taiello/projects/fl-med-devices/results_fl/lightveri-fl_f0.01.csv", 16, 1659)

Buffalo+
SA 767.4119146633149
SA bandiwtth 0.1132738656
LightVeriFL
SA 119068.91036276102
SA bandiwtth 0.2106956544
