From a2ce5cd3c557856da05866789d3956d4dd28c0db Mon Sep 17 00:00:00 2001 From: Sunitha Basodi Date: Wed, 4 Sep 2024 10:45:17 +0000 Subject: [PATCH] Updating intial weights based on local betas and fixing bug computing mean_y_global --- scripts/local.py | 2 ++ scripts/remote.py | 24 +++++++++++++++++++----- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/scripts/local.py b/scripts/local.py index e639d76..17a8023 100644 --- a/scripts/local.py +++ b/scripts/local.py @@ -80,6 +80,7 @@ def local_1(args): output_dict = { "beta_vector_local": beta_vector, "beta_vec_size": beta_vec_size, + "augmented_X_labels": list(augmented_X.columns), "number_of_regressions": len(y_labels), "computation_phase": "local_1" } @@ -275,3 +276,4 @@ def start(PARAM_DICT): else: raise ValueError("Error occurred at Local") + diff --git a/scripts/remote.py b/scripts/remote.py index dae5481..b40a8c4 100644 --- a/scripts/remote.py +++ b/scripts/remote.py @@ -61,9 +61,13 @@ def remote_1(args): #log(args, args['state']) """Need this function for performing multi-shot regression""" input_list = args["input"] - first_user_id = list(input_list.keys())[0] + first_user_id = list(sorted(input_list.keys()))[0] beta_vec_size = input_list[first_user_id]["beta_vec_size"] number_of_regressions = input_list[first_user_id]["number_of_regressions"] + mean_local_betas = np.array( + [np.array(input_list[site]["beta_vector_local"]) for site in input_list]).mean( + axis=0) + # Initial setup beta1 = 0.9 @@ -78,6 +82,15 @@ def remote_1(args): np.zeros((number_of_regressions, beta_vec_size), dtype=float) for _ in range(4) ] + + #Update initial weights based on the local beta's + local_X_labels=['const'] + local_X_labels.extend(args['cache']['X_labels']) + augmented_X_labels=input_list[first_user_id]['augmented_X_labels'] + for label_idx, curr_label in enumerate(local_X_labels): + idx = augmented_X_labels.index(curr_label) + wp[:, idx] = mean_local_betas[:, label_idx] + prev_cost = [None] * number_of_regressions iter_flag = 1 @@ -253,11 +266,11 @@ def remote_3(args): site: input_list[site]["local_stats_list"] for site in sorted_site_ids }] - mean_y_local = [input_list[site]["mean_y_local"] for site in input_list] - count_y_local = [np.array(input_list[site]["count_local"]) for site in input_list] - mean_y_global = np.array(mean_y_local) * np.array(count_y_local) + mean_y_local = np.array([input_list[site]["mean_y_local"] for site in input_list]) + count_y_local = np.array([input_list[site]["count_local"] for site in input_list]) + mean_y_global = mean_y_local * count_y_local #mean_y_global = np.average(mean_y_global, axis=0) - mean_y_global = mean_y_global.sum(axis=0) / np.sum(count_y_local) + mean_y_global = mean_y_global.sum(axis=0) / count_y_local.sum(axis=0) dof_global = sum(count_y_local) - avg_beta_vector.shape[1] @@ -417,3 +430,4 @@ def start(PARAM_DICT): else: raise ValueError("Error occurred at Remote") +