# In[ ]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import tensorflow as tf
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))

from utils.trainer_utils import TrainConfig
from flearn.trainer.fedavg import FedAvg
from flearn.trainer.fesem import FeSEM
from flearn.trainer.fedgroup import FedGroup
from flearn.trainer.ifca import IFCA

def main(dataset, model, trainer, glr, dynamic, swap_p, shift_type, RAC, RCC):
    """Build the training config for one experiment and run the chosen trainer.

    Args:
        dataset: Dataset name forwarded to ``TrainConfig`` (e.g. ``'femnist'``).
        model: Model name forwarded to ``TrainConfig`` (e.g. ``'mlp'``).
        trainer: Trainer name; one of ``'fedavg'``, ``'fesem'``, ``'ifca'``,
            ``'fedgroup'``. Also forwarded to ``TrainConfig``.
        glr: Stored as ``trainer_config['group_agg_lr']``.
        dynamic: Stored as ``trainer_config['dynamic']``.
        swap_p: Stored as ``trainer_config['swap_p']``.
        shift_type: Stored as ``trainer_config['shift_type']``.
        RAC: Stored as ``trainer_config['RAC']``.
        RCC: Stored as ``trainer_config['RCC']``.

    Raises:
        ValueError: If ``trainer`` is not a supported trainer name. This is
            checked before ``TrainConfig`` is built, so a typo fails at once.
    """
    trainer_classes = {
        'fedavg': FedAvg,
        'fesem': FeSEM,
        'ifca': IFCA,
        'fedgroup': FedGroup,
    }
    # Validate up front: otherwise an unknown name would only surface later as
    # an AttributeError when `.train()` is called on the plain string.
    if trainer not in trainer_classes:
        raise ValueError(
            f"Unknown trainer {trainer!r}; expected one of {sorted(trainer_classes)}"
        )

    config = TrainConfig(dataset, model, trainer)
    config.trainer_config['group_agg_lr'] = glr
    config.trainer_config['RAC'] = RAC
    config.trainer_config['RCC'] = RCC
    config.trainer_config['dynamic'] = dynamic
    config.trainer_config['swap_p'] = swap_p
    config.trainer_config['shift_type'] = shift_type

    # Keep `trainer` as the name string; the instance gets its own variable.
    runner = trainer_classes[trainer](config)
    runner.train()

# Every task tuple is unpacked positionally into main():
# (dataset, model, trainer, glr, dynamic, swap_p, shift_type, RAC, RCC)


def _shift_tasks(swap_p, shift_type):
    """Return the benchmark task list for one distribution-shift scenario.

    The shift scenarios share the same eight runs. They differ only in the
    client swap probability ``swap_p`` and the ``shift_type`` label.
    """
    return [
        ('femnist', 'mlp', 'fedavg', 0, False, swap_p, shift_type, False, False),  # FedAvg
        ('femnist', 'mlp', 'ifca', 0, False, swap_p, shift_type, False, False),  # IFCA
        ('femnist', 'mlp', 'fesem', 0, False, swap_p, shift_type, False, False),  # FeSEM
        ('femnist', 'mlp', 'fedgroup', 0, False, swap_p, shift_type, False, False),  # FedGroup
        ('femnist', 'mlp', 'fedgroup', 0, True, swap_p, shift_type, False, False),  # FlexCFL
        ('femnist', 'mlp', 'fedgroup', 5.0, True, swap_p, shift_type, False, False),  # FlexCFL with group aggregation
        ('femnist', 'mlp', 'ifca', 0, True, swap_p, shift_type, False, False),  # IFCA with client migration
        ('femnist', 'mlp', 'fesem', 0, True, swap_p, shift_type, False, False),  # FeSEM with client migration
    ]


# Experiment task lists keyed by distribution-shift scenario.
TASK_LISTS = {
    # No distribution shift
    'none': [
        ('femnist', 'mlp', 'fedavg', 0, False, 0, None, False, False),  # FedAvg
        ('femnist', 'mlp', 'ifca', 0, False, 0, None, False, False),  # IFCA
        ('femnist', 'mlp', 'fesem', 0, False, 0, None, False, False),  # FeSEM
        ('femnist', 'mlp', 'fedgroup', 0, False, 0, None, False, False),  # FlexCFL
        ('femnist', 'mlp', 'fedgroup', 0.1, False, 0, None, False, False),  # FlexCFL with group aggregation
        ('femnist', 'mlp', 'fedgroup', 0, False, 0, None, True, False),  # FlexCFL-RAC
        ('femnist', 'mlp', 'fedgroup', 0, False, 0, None, False, True),  # FlexCFL-RCC
    ],
    'all': _shift_tasks(0.05, "all"),  # Swap all
    'part': _shift_tasks(0.05, "part"),  # Swap part
    'increment': _shift_tasks(0, "increment"),  # Incremental
}

# Select which scenario to run: 'none', 'all', 'part' or 'increment'.
SCENARIO = 'none'
task_list = TASK_LISTS[SCENARIO]

# Run every selected task in order; each tuple is unpacked positionally into main().
for idx, task in enumerate(task_list, start=1):
    print(f'Running task {idx}/{len(task_list)}')
    main(*task)