In [None]:
from utils_general import *
from utils_methods import *
import json
import logging
# from utils_methods_FedDC import train_FedDC

In [None]:
class config(object):
    """Flat container for the experiment's hyperparameters and paths.

    Every setting is stored as an instance attribute, so later cells read it
    (or add new ones) via ``args.<name>``.
    """

    def __init__(self):
        # Insertion order is preserved, so attributes are created in this order.
        settings = {
            "dataset": "cifar10",
            "device": "cuda",
            "comm_round": 2,
            "lr": 0.01,
            "batch_size": 64,
            "epochs": 50,
            "n_parties": 10,
            "seed": 42,
            "alg": "fedavg",
            "rootdir": "./result/bench2/",
            "datadir": "./Folder/",
            "beta": 0.1,
            "model_name": 'cifar10_LeNet',  # Model type
        }
        for name, value in settings.items():
            setattr(self, name, value)


args = config()

In [None]:
def get_logger(logger_path):
    """Configure the root logger to write to ``logger_path`` and to the console.

    Safe to call repeatedly (e.g. when a notebook cell is re-run): handlers
    left on the root logger by a previous call are removed and closed first.
    Without this, ``logging.basicConfig`` would silently do nothing, so the new
    log file would never be used. Each call would also stack another console
    handler, which duplicates every message.

    Args:
        logger_path: Path of the log file. It is opened with ``filemode='w'``,
            so an existing file at that path is truncated.

    Returns:
        The root ``logging.Logger``, at INFO level, with one file handler
        (formatted) and one console ``StreamHandler``.
    """
    root = logging.getLogger()
    # basicConfig is a no-op when the root logger already has handlers, so
    # detach (and close, to release file descriptors) any existing ones first.
    for handler in root.handlers[:]:
        root.removeHandler(handler)
        handler.close()

    logging.basicConfig(
        filename=logger_path,
        format='[%(asctime)s] %(levelname)s: %(message)s',
        datefmt='%m-%d %H:%M',
        level=logging.DEBUG,
        filemode='w'
    )

    logger = logging.getLogger()
    # The effective level is INFO; this overrides the DEBUG passed above.
    logger.setLevel(logging.INFO)
    # Mirror log records to the console (stderr by default).
    ch = logging.StreamHandler()
    logger.addHandler(ch)

    return logger

In [None]:
# Seed every RNG the experiment touches (Python, NumPy, PyTorch CPU, and the
# current CUDA device when one is present) so runs are repeatable.
seed = args.seed
for _seeder in (random.seed, np.random.seed, torch.manual_seed):
    _seeder(seed)
del _seeder
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

In [None]:
# Dataset initialization.
# Client data is partitioned with the 'Drichlet' rule, using `args.beta` as
# the rule argument (presumably the Dirichlet concentration; smaller values
# would mean more heterogeneous clients, so confirm in DatasetObject).
# NOTE(review): the 'Drichlet' spelling is passed verbatim to DatasetObject;
# keep it unless the project helper is changed to match a corrected spelling.
data_path = args.datadir
n_client = args.n_parties

data_obj = DatasetObject(
    dataset=args.dataset.upper(),
    data_path=data_path,
    n_client=n_client,
    seed=args.seed,
    rule='Drichlet',
    rule_arg=args.beta,
    unbalanced_sgm=0,
)


In [None]:
# Common hyperparameters, attached to `args` so every later cell reads them
# from one place.
args.weight_decay = 1e-3
# NOTE(review): presumably the fraction of clients sampled per round — confirm in train_FedAvg.
args.act_prob = 1.0


def model_func():
    """Build a fresh, untrained model of architecture ``args.model_name``."""
    return client_model(args.model_name)


# Shared starting weights, drawn from the torch RNG seeded in an earlier cell.
init_model = model_func()

# Experiment directory layout: <rootdir>/<tag>/seed<seed>/, where the tag
# encodes the settings that distinguish one run configuration from another.
tag = f"{args.dataset}-{args.model_name}-{args.alg}-N{args.n_parties}-beta{args.beta}-ep{args.epochs}-lr{args.lr}-round{args.comm_round}"

args.exp_dir = os.path.join(args.rootdir, tag)
args.logdir = os.path.join(args.exp_dir, f"seed{args.seed}")
# exist_ok: re-running this cell, or repeating a configuration, must not
# crash with FileExistsError.
os.makedirs(args.logdir, exist_ok=True)

# One timestamp for both files so a run's argument dump and log pair up
# (two separate now() calls could straddle a second boundary).
run_stamp = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M-%S")

args.argument_path = 'experiment_arguments-%s.json' % run_stamp
with open(os.path.join(args.logdir, args.argument_path), 'w') as f:
    # Dump the actual settings. str(args) only produced the object's repr
    # ("<config object at 0x...>"). default=str guards against any
    # non-JSON value added to args later.
    json.dump(vars(args), f, indent=2, default=str)
device = torch.device(args.device)

args.log_file_name = 'experiment_log-%s' % run_stamp
log_path = args.log_file_name + '.log'
logger = get_logger(logger_path=os.path.join(args.logdir, log_path))


print(f'trying {args.alg}')

# Run FedAvg. Every hyperparameter comes from `args`. The original call
# referenced bare names (act_prob, learning_rate, epoch, ...) that are never
# defined. It also passed `exp_dir, log_dir` positionally after keyword
# arguments, which is a SyntaxError.
(fed_mdls_sel, trn_perf_sel, tst_perf_sel,
 fed_mdls_all, trn_perf_all, tst_perf_all) = train_FedAvg(
    data_obj=data_obj,
    act_prob=args.act_prob,
    learning_rate=args.lr,
    batch_size=args.batch_size,
    epoch=args.epochs,
    com_amount=args.comm_round,
    weight_decay=args.weight_decay,
    model_func=model_func,
    init_model=init_model,
    # NOTE(review): `suffix` was undefined; FedDyn-style drivers pass the
    # model name here — confirm the intended value.
    suffix=args.model_name,
    trial=False,
    data_path=args.datadir,
    rand_seed=args.seed,
    overwrite=True,
    # NOTE(review): assumes train_FedAvg's parameters are named exp_dir and
    # log_dir, matching the names used in the original call — confirm.
    exp_dir=args.exp_dir,
    log_dir=args.logdir,
)
        

