In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import numpy as np
from Models import MoELSTM
import os
from collections import OrderedDict
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader

from typing import List, Tuple, Optional, Dict
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error
from darts import TimeSeries
from darts.dataprocessing.transformers import Scaler
import random
from Models import MoELSTM, LSTMModel, train_model
from Preprocess import (
    compute_metrics,
    convert_timeseries_to_numpy,
    create_dataloader,
    load_building_series,
    split_series_list,
)
from Models import model_fn
from tqdm import tqdm
from my_utils import train_model, load_energy_data_feather, get_weights, set_weights


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from AggregationStrategy import sync_aggregate,average_weights,sync_aggregate_norm,sync_aggregate_softmax, fedavgm_update

In [3]:
# def average_weights(weights_list):
#     """Averages model weights provided as a list of get_weights outputs."""
#     avg_weights = []
#     num_models = len(weights_list)
#     for layer_weights in zip(*weights_list):
#         avg_layer = np.mean(np.array(layer_weights), axis=0)
#         avg_weights.append(avg_layer)
#     return avg_weights

In [4]:


# Experiment configuration
# ------------------------
# Architectures to run; each name is handed to model_fn() by the experiment loops.
MODEL_NAMES = ["lstm", "gru", "moe_lstm", "moe_gru"]

# Federated-learning settings
NUM_CLIENTS = 1000   # client ids 0..NUM_CLIENTS-1 are the pool for the non-clustered runs
CLIENT_FRAC = 0.15   # fraction of the available clients sampled in each round
NUM_ROUNDS = 10      # aggregation rounds per architecture
LOCAL_EPOCHS = 10    # passed to train_model as `epochs`
LR = 0.001           # passed to train_model as `learning_rate`
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DATA_FILE = "train_cleaned_reindex.feather"  # alternative: "meter_0_data_cleaned.feather"

# Seed the global RNGs so that client sampling (random.sample) and any
# NumPy / PyTorch randomness driven by the global generators is repeatable
# after a kernel restart.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


In [4]:
# Display the device chosen in the config cell ("cuda" or "cpu").
DEVICE

'cuda'

### Clustered

In [None]:
import os
import torch
import torch.optim as optim
from tqdm import tqdm
import random
import numpy as np

# Hand-written placeholder assignment of client ids to clusters: three
# consecutive blocks of 30 ids each (0-29, 30-59 and 60-89).
CLUSTERS = {
    f"cluster_{idx}": list(range(first_id, first_id + 30))
    for idx, first_id in enumerate((0, 30, 60))
}

# Make sure results/<model>/<cluster>/ exists for every model/cluster pair
# before any checkpoint is written.
for model_name in MODEL_NAMES:
    model_root = os.path.join("results", model_name)
    for cluster_name in CLUSTERS.keys():
        os.makedirs(os.path.join(model_root, cluster_name), exist_ok=True)


In [None]:

# Clustered federated training
# ----------------------------
# For every architecture in MODEL_NAMES, one independent model is kept per
# cluster. In each round every cluster samples a fraction of its own clients,
# trains a copy of the cluster model on each sampled client, averages the
# returned weights with average_weights(), and saves a checkpoint.
for model_name in MODEL_NAMES:
    print(f"Starting experiments for model: {model_name}")

    # One model object and one weight snapshot per cluster.
    cluster_models = {}
    cluster_weights = {}

    for cluster_name in CLUSTERS:
        model = model_fn(model_name).to(DEVICE)
        cluster_models[cluster_name] = model
        cluster_weights[cluster_name] = get_weights(model)

    for rnd in range(NUM_ROUNDS):
        # "\n" leaves a blank line between rounds in the log.
        print(f"\nRound {rnd+1}/{NUM_ROUNDS}")

        for cluster_name, client_ids in CLUSTERS.items():
            print(f" Processing {cluster_name} with {len(client_ids)} clients")

            if not client_ids:
                # Nothing to train on: leave this cluster's weights untouched.
                print(f" Skipping {cluster_name}: no clients assigned")
                continue

            # Sample a fraction of the cluster's clients, but never fewer than
            # one: int() truncates, so a small cluster would otherwise sample
            # nobody and average_weights() would receive an empty list.
            num_sampled = max(1, int(CLIENT_FRAC * len(client_ids)))
            sampled_clients = random.sample(client_ids, k=num_sampled)
            local_weights = []

            for cid in tqdm(sampled_clients, desc=f"Training {cluster_name}"):
                # Each client starts from the current cluster weights.
                local_model = model_fn(model_name).to(DEVICE)
                set_weights(local_model, cluster_weights[cluster_name])

                # Only the training loader is needed here; the second loader
                # returned by load_energy_data_feather is ignored.
                train_loader, _ = load_energy_data_feather(cid, filepath=DATA_FILE)
                # The second value returned by train_model is not used.
                updated_weights, _ = train_model(
                    local_model, train_loader,
                    device=DEVICE,
                    learning_rate=LR,
                    loss_fn=None,
                    optimizer_class=optim.Adam,
                    epochs=LOCAL_EPOCHS
                )
                local_weights.append(updated_weights)

            # Aggregate the client results and update the cluster model.
            updated_cluster_weights = average_weights(local_weights)
            set_weights(cluster_models[cluster_name], updated_cluster_weights)
            cluster_weights[cluster_name] = updated_cluster_weights

            # Save one checkpoint per model / cluster / round.
            ckpt_path = os.path.join("results", model_name, cluster_name, f"{model_name}_{cluster_name}_round_{rnd+1}.pt")
            torch.save(cluster_models[cluster_name].state_dict(), ckpt_path)
            print(f"Saved model: {ckpt_path}")



üåê Starting experiments for model: lstm

üåÄ Round 1/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.31it/s]


‚úÖ Saved model: results/lstm/cluster_0/lstm_cluster_0_round_1.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.28it/s]


‚úÖ Saved model: results/lstm/cluster_1/lstm_cluster_1_round_1.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.02it/s]


‚úÖ Saved model: results/lstm/cluster_2/lstm_cluster_2_round_1.pt

üåÄ Round 2/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.46it/s]


‚úÖ Saved model: results/lstm/cluster_0/lstm_cluster_0_round_2.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.07it/s]


‚úÖ Saved model: results/lstm/cluster_1/lstm_cluster_1_round_2.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.74it/s]


‚úÖ Saved model: results/lstm/cluster_2/lstm_cluster_2_round_2.pt

üåÄ Round 3/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.48it/s]


‚úÖ Saved model: results/lstm/cluster_0/lstm_cluster_0_round_3.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.58it/s]


‚úÖ Saved model: results/lstm/cluster_1/lstm_cluster_1_round_3.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.10it/s]


‚úÖ Saved model: results/lstm/cluster_2/lstm_cluster_2_round_3.pt

üåÄ Round 4/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.99it/s]


‚úÖ Saved model: results/lstm/cluster_0/lstm_cluster_0_round_4.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.53it/s]


‚úÖ Saved model: results/lstm/cluster_1/lstm_cluster_1_round_4.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.88it/s]


‚úÖ Saved model: results/lstm/cluster_2/lstm_cluster_2_round_4.pt

üåÄ Round 5/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.83it/s]


‚úÖ Saved model: results/lstm/cluster_0/lstm_cluster_0_round_5.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.47it/s]


‚úÖ Saved model: results/lstm/cluster_1/lstm_cluster_1_round_5.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.25it/s]


‚úÖ Saved model: results/lstm/cluster_2/lstm_cluster_2_round_5.pt

üåÄ Round 6/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.84it/s]


‚úÖ Saved model: results/lstm/cluster_0/lstm_cluster_0_round_6.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.87it/s]


‚úÖ Saved model: results/lstm/cluster_1/lstm_cluster_1_round_6.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.85it/s]


‚úÖ Saved model: results/lstm/cluster_2/lstm_cluster_2_round_6.pt

üåÄ Round 7/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.95it/s]


‚úÖ Saved model: results/lstm/cluster_0/lstm_cluster_0_round_7.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.51it/s]


‚úÖ Saved model: results/lstm/cluster_1/lstm_cluster_1_round_7.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.39it/s]


‚úÖ Saved model: results/lstm/cluster_2/lstm_cluster_2_round_7.pt

üåÄ Round 8/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.92it/s]


‚úÖ Saved model: results/lstm/cluster_0/lstm_cluster_0_round_8.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.98it/s]


‚úÖ Saved model: results/lstm/cluster_1/lstm_cluster_1_round_8.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.11it/s]


‚úÖ Saved model: results/lstm/cluster_2/lstm_cluster_2_round_8.pt

üåÄ Round 9/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.02it/s]


‚úÖ Saved model: results/lstm/cluster_0/lstm_cluster_0_round_9.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.42it/s]


‚úÖ Saved model: results/lstm/cluster_1/lstm_cluster_1_round_9.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.97it/s]


‚úÖ Saved model: results/lstm/cluster_2/lstm_cluster_2_round_9.pt

üåÄ Round 10/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.46it/s]


‚úÖ Saved model: results/lstm/cluster_0/lstm_cluster_0_round_10.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.09it/s]


‚úÖ Saved model: results/lstm/cluster_1/lstm_cluster_1_round_10.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.03it/s]


‚úÖ Saved model: results/lstm/cluster_2/lstm_cluster_2_round_10.pt

üåê Starting experiments for model: gru

üåÄ Round 1/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.47it/s]


‚úÖ Saved model: results/gru/cluster_0/gru_cluster_0_round_1.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.49it/s]


‚úÖ Saved model: results/gru/cluster_1/gru_cluster_1_round_1.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.93it/s]


‚úÖ Saved model: results/gru/cluster_2/gru_cluster_2_round_1.pt

üåÄ Round 2/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.11it/s]


‚úÖ Saved model: results/gru/cluster_0/gru_cluster_0_round_2.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.04it/s]


‚úÖ Saved model: results/gru/cluster_1/gru_cluster_1_round_2.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  4.00it/s]


‚úÖ Saved model: results/gru/cluster_2/gru_cluster_2_round_2.pt

üåÄ Round 3/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.40it/s]


‚úÖ Saved model: results/gru/cluster_0/gru_cluster_0_round_3.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.34it/s]


‚úÖ Saved model: results/gru/cluster_1/gru_cluster_1_round_3.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.05it/s]


‚úÖ Saved model: results/gru/cluster_2/gru_cluster_2_round_3.pt

üåÄ Round 4/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.12it/s]


‚úÖ Saved model: results/gru/cluster_0/gru_cluster_0_round_4.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.18it/s]


‚úÖ Saved model: results/gru/cluster_1/gru_cluster_1_round_4.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.20it/s]


‚úÖ Saved model: results/gru/cluster_2/gru_cluster_2_round_4.pt

üåÄ Round 5/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.68it/s]


‚úÖ Saved model: results/gru/cluster_0/gru_cluster_0_round_5.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.17it/s]


‚úÖ Saved model: results/gru/cluster_1/gru_cluster_1_round_5.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.53it/s]


‚úÖ Saved model: results/gru/cluster_2/gru_cluster_2_round_5.pt

üåÄ Round 6/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.02it/s]


‚úÖ Saved model: results/gru/cluster_0/gru_cluster_0_round_6.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.63it/s]


‚úÖ Saved model: results/gru/cluster_1/gru_cluster_1_round_6.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.45it/s]


‚úÖ Saved model: results/gru/cluster_2/gru_cluster_2_round_6.pt

üåÄ Round 7/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.51it/s]


‚úÖ Saved model: results/gru/cluster_0/gru_cluster_0_round_7.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.01it/s]


‚úÖ Saved model: results/gru/cluster_1/gru_cluster_1_round_7.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.04it/s]


‚úÖ Saved model: results/gru/cluster_2/gru_cluster_2_round_7.pt

üåÄ Round 8/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.88it/s]


‚úÖ Saved model: results/gru/cluster_0/gru_cluster_0_round_8.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.95it/s]


‚úÖ Saved model: results/gru/cluster_1/gru_cluster_1_round_8.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.16it/s]


‚úÖ Saved model: results/gru/cluster_2/gru_cluster_2_round_8.pt

üåÄ Round 9/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.42it/s]


‚úÖ Saved model: results/gru/cluster_0/gru_cluster_0_round_9.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.04it/s]


‚úÖ Saved model: results/gru/cluster_1/gru_cluster_1_round_9.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.36it/s]


‚úÖ Saved model: results/gru/cluster_2/gru_cluster_2_round_9.pt

üåÄ Round 10/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.67it/s]


‚úÖ Saved model: results/gru/cluster_0/gru_cluster_0_round_10.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.08it/s]


‚úÖ Saved model: results/gru/cluster_1/gru_cluster_1_round_10.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00,  4.62it/s]


‚úÖ Saved model: results/gru/cluster_2/gru_cluster_2_round_10.pt

üåê Starting experiments for model: moe_lstm

üåÄ Round 1/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.22it/s]


‚úÖ Saved model: results/moe_lstm/cluster_0/moe_lstm_cluster_0_round_1.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.23it/s]


‚úÖ Saved model: results/moe_lstm/cluster_1/moe_lstm_cluster_1_round_1.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  2.99it/s]


‚úÖ Saved model: results/moe_lstm/cluster_2/moe_lstm_cluster_2_round_1.pt

üåÄ Round 2/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.60it/s]


‚úÖ Saved model: results/moe_lstm/cluster_0/moe_lstm_cluster_0_round_2.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.36it/s]


‚úÖ Saved model: results/moe_lstm/cluster_1/moe_lstm_cluster_1_round_2.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.25it/s]


‚úÖ Saved model: results/moe_lstm/cluster_2/moe_lstm_cluster_2_round_2.pt

üåÄ Round 3/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.19it/s]


‚úÖ Saved model: results/moe_lstm/cluster_0/moe_lstm_cluster_0_round_3.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.21it/s]


‚úÖ Saved model: results/moe_lstm/cluster_1/moe_lstm_cluster_1_round_3.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.23it/s]


‚úÖ Saved model: results/moe_lstm/cluster_2/moe_lstm_cluster_2_round_3.pt

üåÄ Round 4/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.42it/s]


‚úÖ Saved model: results/moe_lstm/cluster_0/moe_lstm_cluster_0_round_4.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.12it/s]


‚úÖ Saved model: results/moe_lstm/cluster_1/moe_lstm_cluster_1_round_4.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.06it/s]


‚úÖ Saved model: results/moe_lstm/cluster_2/moe_lstm_cluster_2_round_4.pt

üåÄ Round 5/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.42it/s]


‚úÖ Saved model: results/moe_lstm/cluster_0/moe_lstm_cluster_0_round_5.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.07it/s]


‚úÖ Saved model: results/moe_lstm/cluster_1/moe_lstm_cluster_1_round_5.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.06it/s]


‚úÖ Saved model: results/moe_lstm/cluster_2/moe_lstm_cluster_2_round_5.pt

üåÄ Round 6/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.43it/s]


‚úÖ Saved model: results/moe_lstm/cluster_0/moe_lstm_cluster_0_round_6.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.19it/s]


‚úÖ Saved model: results/moe_lstm/cluster_1/moe_lstm_cluster_1_round_6.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.35it/s]


‚úÖ Saved model: results/moe_lstm/cluster_2/moe_lstm_cluster_2_round_6.pt

üåÄ Round 7/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.06it/s]


‚úÖ Saved model: results/moe_lstm/cluster_0/moe_lstm_cluster_0_round_7.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.27it/s]


‚úÖ Saved model: results/moe_lstm/cluster_1/moe_lstm_cluster_1_round_7.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  2.99it/s]


‚úÖ Saved model: results/moe_lstm/cluster_2/moe_lstm_cluster_2_round_7.pt

üåÄ Round 8/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.23it/s]


‚úÖ Saved model: results/moe_lstm/cluster_0/moe_lstm_cluster_0_round_8.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.41it/s]


‚úÖ Saved model: results/moe_lstm/cluster_1/moe_lstm_cluster_1_round_8.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.43it/s]


‚úÖ Saved model: results/moe_lstm/cluster_2/moe_lstm_cluster_2_round_8.pt

üåÄ Round 9/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.24it/s]


‚úÖ Saved model: results/moe_lstm/cluster_0/moe_lstm_cluster_0_round_9.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.22it/s]


‚úÖ Saved model: results/moe_lstm/cluster_1/moe_lstm_cluster_1_round_9.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.41it/s]


‚úÖ Saved model: results/moe_lstm/cluster_2/moe_lstm_cluster_2_round_9.pt

üåÄ Round 10/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.47it/s]


‚úÖ Saved model: results/moe_lstm/cluster_0/moe_lstm_cluster_0_round_10.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.24it/s]


‚úÖ Saved model: results/moe_lstm/cluster_1/moe_lstm_cluster_1_round_10.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.27it/s]


‚úÖ Saved model: results/moe_lstm/cluster_2/moe_lstm_cluster_2_round_10.pt

üåê Starting experiments for model: moe_gru

üåÄ Round 1/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.27it/s]


‚úÖ Saved model: results/moe_gru/cluster_0/moe_gru_cluster_0_round_1.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.20it/s]


‚úÖ Saved model: results/moe_gru/cluster_1/moe_gru_cluster_1_round_1.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.30it/s]


‚úÖ Saved model: results/moe_gru/cluster_2/moe_gru_cluster_2_round_1.pt

üåÄ Round 2/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.49it/s]


‚úÖ Saved model: results/moe_gru/cluster_0/moe_gru_cluster_0_round_2.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.58it/s]


‚úÖ Saved model: results/moe_gru/cluster_1/moe_gru_cluster_1_round_2.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.26it/s]


‚úÖ Saved model: results/moe_gru/cluster_2/moe_gru_cluster_2_round_2.pt

üåÄ Round 3/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.08it/s]


‚úÖ Saved model: results/moe_gru/cluster_0/moe_gru_cluster_0_round_3.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.22it/s]


‚úÖ Saved model: results/moe_gru/cluster_1/moe_gru_cluster_1_round_3.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.20it/s]


‚úÖ Saved model: results/moe_gru/cluster_2/moe_gru_cluster_2_round_3.pt

üåÄ Round 4/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.47it/s]


‚úÖ Saved model: results/moe_gru/cluster_0/moe_gru_cluster_0_round_4.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.10it/s]


‚úÖ Saved model: results/moe_gru/cluster_1/moe_gru_cluster_1_round_4.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.35it/s]


‚úÖ Saved model: results/moe_gru/cluster_2/moe_gru_cluster_2_round_4.pt

üåÄ Round 5/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.27it/s]


‚úÖ Saved model: results/moe_gru/cluster_0/moe_gru_cluster_0_round_5.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.09it/s]


‚úÖ Saved model: results/moe_gru/cluster_1/moe_gru_cluster_1_round_5.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.28it/s]


‚úÖ Saved model: results/moe_gru/cluster_2/moe_gru_cluster_2_round_5.pt

üåÄ Round 6/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.29it/s]


‚úÖ Saved model: results/moe_gru/cluster_0/moe_gru_cluster_0_round_6.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.62it/s]


‚úÖ Saved model: results/moe_gru/cluster_1/moe_gru_cluster_1_round_6.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.26it/s]


‚úÖ Saved model: results/moe_gru/cluster_2/moe_gru_cluster_2_round_6.pt

üåÄ Round 7/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.33it/s]


‚úÖ Saved model: results/moe_gru/cluster_0/moe_gru_cluster_0_round_7.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.51it/s]


‚úÖ Saved model: results/moe_gru/cluster_1/moe_gru_cluster_1_round_7.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.09it/s]


‚úÖ Saved model: results/moe_gru/cluster_2/moe_gru_cluster_2_round_7.pt

üåÄ Round 8/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.43it/s]


‚úÖ Saved model: results/moe_gru/cluster_0/moe_gru_cluster_0_round_8.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.37it/s]


‚úÖ Saved model: results/moe_gru/cluster_1/moe_gru_cluster_1_round_8.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.27it/s]


‚úÖ Saved model: results/moe_gru/cluster_2/moe_gru_cluster_2_round_8.pt

üåÄ Round 9/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.24it/s]


‚úÖ Saved model: results/moe_gru/cluster_0/moe_gru_cluster_0_round_9.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.13it/s]


‚úÖ Saved model: results/moe_gru/cluster_1/moe_gru_cluster_1_round_9.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.26it/s]


‚úÖ Saved model: results/moe_gru/cluster_2/moe_gru_cluster_2_round_9.pt

üåÄ Round 10/10
üîπ Processing cluster_0 with 30 clients


Training cluster_0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.43it/s]


‚úÖ Saved model: results/moe_gru/cluster_0/moe_gru_cluster_0_round_10.pt
üîπ Processing cluster_1 with 30 clients


Training cluster_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.12it/s]


‚úÖ Saved model: results/moe_gru/cluster_1/moe_gru_cluster_1_round_10.pt
üîπ Processing cluster_2 with 30 clients


Training cluster_2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  3.53it/s]

‚úÖ Saved model: results/moe_gru/cluster_2/moe_gru_cluster_2_round_10.pt





### FedAvg

In [None]:
import os
import torch
import torch.optim as optim
from tqdm import tqdm


# # Weighted averaging function (uniform weights for now)
# def average_weights(weights_list):
#     return [np.mean(np.stack(layer_weights), axis=0) for layer_weights in zip(*weights_list)]

# Requires model_fn, which is imported from Models in the first cell.

# FedAvg baseline
# ---------------
# For each architecture: build a global model with model_fn(), then in every
# round train a copy of it on a random subset of clients, replace the global
# weights with average_weights() of the client results, and save a checkpoint.
for model_name in MODEL_NAMES:
    print(f"Starting experiment with model: {model_name}")

    # Checkpoints for this architecture go under results/<model_name>/.
    model_dir = os.path.join("results", model_name)
    os.makedirs(model_dir, exist_ok=True)

    # Global model and its current weights.
    global_model = model_fn(model_name).to(DEVICE)
    global_weights = get_weights(global_model)

    # Clients trained per round. int() truncates, so clamp to at least one;
    # otherwise a small CLIENT_FRAC * NUM_CLIENTS would sample nobody and
    # average_weights() would receive an empty list.
    clients_per_round = max(1, int(CLIENT_FRAC * NUM_CLIENTS))

    for rnd in range(NUM_ROUNDS):
        print(f"Round {rnd+1}/{NUM_ROUNDS}")
        sampled_clients = random.sample(range(NUM_CLIENTS), k=clients_per_round)
        local_weights = []

        for cid in tqdm(sampled_clients, desc="Training clients"):
            # Each client starts from the current global weights.
            local_model = model_fn(model_name).to(DEVICE)
            set_weights(local_model, global_weights)
            # Only the training loader is needed; the second loader is ignored.
            train_loader, _ = load_energy_data_feather(cid, filepath=DATA_FILE)

            # The second value returned by train_model is not used.
            updated_weights, _ = train_model(
                local_model, train_loader,
                device=DEVICE,
                learning_rate=LR,
                loss_fn=None,
                optimizer_class=optim.Adam,
                epochs=LOCAL_EPOCHS
            )
            local_weights.append(updated_weights)

        # Federated averaging: the aggregated weights become the new global weights.
        global_weights = average_weights(local_weights)
        set_weights(global_model, global_weights)

        # Save one checkpoint per round.
        checkpoint_path = os.path.join(model_dir, f"{model_name}_round_{rnd+1}_fedAvg.pt")
        torch.save(global_model.state_dict(), checkpoint_path)
        print(f"Saved global model to {checkpoint_path}")


Starting experiment with model: lstm
Round 1/10


Training clients:   0%|          | 0/150 [00:00<?, ?it/s]

Training clients: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 150/150 [00:48<00:00,  3.07it/s]


Saved global model to results/lstm/lstm_round_1_fedAvg.pt
Round 2/10


Training clients: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 150/150 [00:47<00:00,  3.14it/s]


Saved global model to results/lstm/lstm_round_2_fedAvg.pt
Round 3/10


Training clients: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 150/150 [00:48<00:00,  3.11it/s]


Saved global model to results/lstm/lstm_round_3_fedAvg.pt
Round 4/10


Training clients: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 150/150 [00:48<00:00,  3.12it/s]


Saved global model to results/lstm/lstm_round_4_fedAvg.pt
Round 5/10


Training clients: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 150/150 [00:47<00:00,  3.14it/s]


Saved global model to results/lstm/lstm_round_5_fedAvg.pt
Round 6/10


Training clients:  97%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã| 145/150 [00:46<00:01,  2.96it/s]

### FedAvgM

In [None]:


# Configuration for the FedAvgM run: architectures, federated-learning
# settings, compute device and input file.
MODEL_NAMES = ["lstm", "gru", "moe_lstm", "moe_gru"]

NUM_CLIENTS, CLIENT_FRAC = 1000, 0.15   # client pool size, fraction sampled per round
NUM_ROUNDS, LOCAL_EPOCHS = 10, 10       # aggregation rounds, `epochs` passed to train_model
LR = 1e-3                               # `learning_rate` passed to train_model

# Use the GPU when torch can see one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Alternative input file: "meter_0_data_cleaned.feather"
DATA_FILE = "train_cleaned_reindex.feather"


In [None]:
import os
import random

import numpy as np
import torch
import torch.optim as optim
from tqdm import tqdm

# FedAvgM experiment: for every architecture in MODEL_NAMES, run NUM_ROUNDS
# rounds of client sampling, local training and server aggregation through
# fedavgm_update, saving the global model after each round.
#
# model_fn, get_weights, set_weights, train_model, load_energy_data_feather and
# fedavgm_update come from the import cells at the top of the notebook; the
# upper-case settings come from the config cell.

# Clients trained per round. Computed once because it is loop-invariant, and
# clamped to at least one client (never more than the pool) so that a small
# CLIENT_FRAC cannot truncate to 0 and hand the aggregator an empty list.
clients_per_round = min(NUM_CLIENTS, max(1, int(CLIENT_FRAC * NUM_CLIENTS)))

for model_name in MODEL_NAMES:
    print(f"Starting experiment with model: {model_name}")

    # Checkpoints for this architecture are written under results/<model_name>/.
    model_dir = os.path.join("results", model_name)
    os.makedirs(model_dir, exist_ok=True)

    # Fresh global model; its initial weights are what round 1 hands to clients.
    global_model = model_fn(model_name).to(DEVICE)
    global_weights = get_weights(global_model)

    # One zero-filled buffer per weight array. It is passed to fedavgm_update
    # and replaced by that function's second return value every round.
    # NOTE(review): presumably the server-side momentum state - confirm in
    # AggregationStrategy.fedavgm_update.
    velocity = [np.zeros_like(w) for w in global_weights]

    for rnd in range(NUM_ROUNDS):
        print(f"Round {rnd+1}/{NUM_ROUNDS}")

        # Draw this round's participants without replacement.
        sampled_clients = random.sample(range(NUM_CLIENTS), k=clients_per_round)
        local_weights = []

        for cid in tqdm(sampled_clients, desc="Training clients"):
            # Each client starts from a new model loaded with the current
            # global weights.
            local_model = model_fn(model_name).to(DEVICE)
            set_weights(local_model, global_weights)

            # Only train_loader is consumed below; test_loader is not used in
            # this cell.
            train_loader, test_loader = load_energy_data_feather(cid, filepath=DATA_FILE)

            # fin_loss is returned by train_model but not recorded here.
            updated_weights, fin_loss = train_model(
                local_model, train_loader,
                device=DEVICE,
                learning_rate=LR,
                loss_fn=None,
                optimizer_class=optim.Adam,
                epochs=LOCAL_EPOCHS
            )
            local_weights.append(updated_weights)

        # Server update: fedavgm_update receives the previous global weights,
        # the clients' weights and the velocity, and returns the new global
        # weights together with the new velocity.
        global_weights, velocity = fedavgm_update(global_weights, local_weights, velocity)
        set_weights(global_model, global_weights)

        # Save one checkpoint per round, tagged with the aggregation strategy.
        checkpoint_path = os.path.join(model_dir, f"{model_name}_round_{rnd+1}_fedAvgM.pt")
        torch.save(global_model.state_dict(), checkpoint_path)
        print(f"Saved global model to {checkpoint_path}")


Starting experiment with model: lstm
Round 1/10


Training clients: 100%|██████████| 15/15 [00:42<00:00,  2.83s/it]


Saved global model to results\lstm\lstm_round_1_kr_norm.pt
Round 2/10


Training clients:  40%|████      | 6/15 [00:20<00:30,  3.39s/it]


KeyboardInterrupt: 

### Kuramoto FedAvg

In [None]:


# Experiment settings for this section (same values as the configuration cell
# near the top of the notebook).

# Architectures to run, in this order.
MODEL_NAMES = [
    "lstm",
    "gru",
    "moe_lstm",
    "moe_gru",
]

# Federated-learning hyperparameters.
NUM_CLIENTS = 1000   # size of the client-id pool that each round samples from
CLIENT_FRAC = 0.15   # fraction of that pool trained in every round
NUM_ROUNDS = 10      # rounds run per model
LOCAL_EPOCHS = 10    # epochs handed to train_model for each client
LR = 0.001           # learning rate handed to train_model

# Run on the GPU whenever PyTorch reports one, otherwise on the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"

# Feather file passed to the data loader.
# Alternative used previously: "meter_0_data_cleaned.feather"
DATA_FILE = "train_cleaned_reindex.feather"


In [None]:
import os
import random

import torch
import torch.optim as optim
from tqdm import tqdm

# Kuramoto FedAvg experiment: for every architecture in MODEL_NAMES, run
# NUM_ROUNDS rounds of client sampling, local training and server aggregation
# through sync_aggregate, saving the global model after each round.
#
# model_fn, get_weights, set_weights, train_model, load_energy_data_feather and
# sync_aggregate come from the import cells at the top of the notebook; the
# upper-case settings come from the config cell.

# Clients trained per round. Computed once because it is loop-invariant, and
# clamped to at least one client (never more than the pool) so that a small
# CLIENT_FRAC cannot truncate to 0 and hand the aggregator an empty list.
clients_per_round = min(NUM_CLIENTS, max(1, int(CLIENT_FRAC * NUM_CLIENTS)))

for model_name in MODEL_NAMES:
    print(f"Starting experiment with model: {model_name}")

    # Checkpoints for this architecture are written under results/<model_name>/.
    model_dir = os.path.join("results", model_name)
    os.makedirs(model_dir, exist_ok=True)

    # Fresh global model; its initial weights are what round 1 hands to clients.
    global_model = model_fn(model_name).to(DEVICE)
    global_weights = get_weights(global_model)

    for rnd in range(NUM_ROUNDS):
        print(f"Round {rnd+1}/{NUM_ROUNDS}")

        # Draw this round's participants without replacement.
        sampled_clients = random.sample(range(NUM_CLIENTS), k=clients_per_round)
        local_weights = []

        for cid in tqdm(sampled_clients, desc="Training clients"):
            # Each client starts from a new model loaded with the current
            # global weights.
            local_model = model_fn(model_name).to(DEVICE)
            set_weights(local_model, global_weights)

            # Only train_loader is consumed below; test_loader is not used in
            # this cell.
            train_loader, test_loader = load_energy_data_feather(cid, filepath=DATA_FILE)

            # fin_loss is returned by train_model but not recorded here.
            updated_weights, fin_loss = train_model(
                local_model, train_loader,
                device=DEVICE,
                learning_rate=LR,
                loss_fn=None,
                optimizer_class=optim.Adam,
                epochs=LOCAL_EPOCHS
            )
            local_weights.append(updated_weights)

        # Server update: sync_aggregate receives the previous global weights and
        # this round's client weights; its return value becomes the new global
        # weights.
        global_weights = sync_aggregate(global_weights, local_weights)
        set_weights(global_model, global_weights)

        # Save one checkpoint per round, tagged with the aggregation strategy.
        checkpoint_path = os.path.join(model_dir, f"{model_name}_round_{rnd+1}_kr.pt")
        torch.save(global_model.state_dict(), checkpoint_path)
        print(f"Saved global model to {checkpoint_path}")


In [None]:
import os
import random

import torch
import torch.optim as optim
from tqdm import tqdm

# Kuramoto FedAvg (norm variant) experiment: for every architecture in
# MODEL_NAMES, run NUM_ROUNDS rounds of client sampling, local training and
# server aggregation through sync_aggregate_norm, saving the global model after
# each round.
#
# model_fn, get_weights, set_weights, train_model, load_energy_data_feather and
# sync_aggregate_norm come from the import cells at the top of the notebook; the
# upper-case settings come from the config cell.

# Clients trained per round. Computed once because it is loop-invariant, and
# clamped to at least one client (never more than the pool) so that a small
# CLIENT_FRAC cannot truncate to 0 and hand the aggregator an empty list.
clients_per_round = min(NUM_CLIENTS, max(1, int(CLIENT_FRAC * NUM_CLIENTS)))

for model_name in MODEL_NAMES:
    print(f"Starting experiment with model: {model_name}")

    # Checkpoints for this architecture are written under results/<model_name>/.
    model_dir = os.path.join("results", model_name)
    os.makedirs(model_dir, exist_ok=True)

    # Fresh global model; its initial weights are what round 1 hands to clients.
    global_model = model_fn(model_name).to(DEVICE)
    global_weights = get_weights(global_model)

    for rnd in range(NUM_ROUNDS):
        print(f"Round {rnd+1}/{NUM_ROUNDS}")

        # Draw this round's participants without replacement.
        sampled_clients = random.sample(range(NUM_CLIENTS), k=clients_per_round)
        local_weights = []

        for cid in tqdm(sampled_clients, desc="Training clients"):
            # Each client starts from a new model loaded with the current
            # global weights.
            local_model = model_fn(model_name).to(DEVICE)
            set_weights(local_model, global_weights)

            # Only train_loader is consumed below; test_loader is not used in
            # this cell.
            train_loader, test_loader = load_energy_data_feather(cid, filepath=DATA_FILE)

            # fin_loss is returned by train_model but not recorded here.
            updated_weights, fin_loss = train_model(
                local_model, train_loader,
                device=DEVICE,
                learning_rate=LR,
                loss_fn=None,
                optimizer_class=optim.Adam,
                epochs=LOCAL_EPOCHS
            )
            local_weights.append(updated_weights)

        # Server update: sync_aggregate_norm receives the previous global
        # weights and this round's client weights; its return value becomes the
        # new global weights.
        global_weights = sync_aggregate_norm(global_weights, local_weights)
        set_weights(global_model, global_weights)

        # Save one checkpoint per round, tagged with the aggregation strategy.
        checkpoint_path = os.path.join(model_dir, f"{model_name}_round_{rnd+1}_kr_norm.pt")
        torch.save(global_model.state_dict(), checkpoint_path)
        print(f"Saved global model to {checkpoint_path}")


In [None]:


# Experiment settings for this section (same values as the configuration cell
# near the top of the notebook).

# Architectures to run, in this order.
MODEL_NAMES = [
    "lstm",
    "gru",
    "moe_lstm",
    "moe_gru",
]

# Federated-learning hyperparameters.
NUM_CLIENTS = 1000   # size of the client-id pool that each round samples from
CLIENT_FRAC = 0.15   # fraction of that pool trained in every round
NUM_ROUNDS = 10      # rounds run per model
LOCAL_EPOCHS = 10    # epochs handed to train_model for each client
LR = 0.001           # learning rate handed to train_model

# Run on the GPU whenever PyTorch reports one, otherwise on the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"

# Feather file passed to the data loader.
# Alternative used previously: "meter_0_data_cleaned.feather"
DATA_FILE = "train_cleaned_reindex.feather"


In [None]:
import os
import random

import torch
import torch.optim as optim
from tqdm import tqdm

# Kuramoto FedAvg (softmax variant) experiment: for every architecture in
# MODEL_NAMES, run NUM_ROUNDS rounds of client sampling, local training and
# server aggregation through sync_aggregate_softmax, saving the global model
# after each round.
#
# model_fn, get_weights, set_weights, train_model, load_energy_data_feather and
# sync_aggregate_softmax come from the import cells at the top of the notebook;
# the upper-case settings come from the config cell.

# Clients trained per round. Computed once because it is loop-invariant, and
# clamped to at least one client (never more than the pool) so that a small
# CLIENT_FRAC cannot truncate to 0 and hand the aggregator an empty list.
clients_per_round = min(NUM_CLIENTS, max(1, int(CLIENT_FRAC * NUM_CLIENTS)))

for model_name in MODEL_NAMES:
    print(f"Starting experiment with model: {model_name}")

    # Checkpoints for this architecture are written under results/<model_name>/.
    model_dir = os.path.join("results", model_name)
    os.makedirs(model_dir, exist_ok=True)

    # Fresh global model; its initial weights are what round 1 hands to clients.
    global_model = model_fn(model_name).to(DEVICE)
    global_weights = get_weights(global_model)

    for rnd in range(NUM_ROUNDS):
        print(f"Round {rnd+1}/{NUM_ROUNDS}")

        # Draw this round's participants without replacement.
        sampled_clients = random.sample(range(NUM_CLIENTS), k=clients_per_round)
        local_weights = []

        for cid in tqdm(sampled_clients, desc="Training clients"):
            # Each client starts from a new model loaded with the current
            # global weights.
            local_model = model_fn(model_name).to(DEVICE)
            set_weights(local_model, global_weights)

            # Only train_loader is consumed below; test_loader is not used in
            # this cell.
            train_loader, test_loader = load_energy_data_feather(cid, filepath=DATA_FILE)

            # fin_loss is returned by train_model but not recorded here.
            updated_weights, fin_loss = train_model(
                local_model, train_loader,
                device=DEVICE,
                learning_rate=LR,
                loss_fn=None,
                optimizer_class=optim.Adam,
                epochs=LOCAL_EPOCHS
            )
            local_weights.append(updated_weights)

        # Server update: sync_aggregate_softmax receives the previous global
        # weights and this round's client weights; its return value becomes the
        # new global weights.
        global_weights = sync_aggregate_softmax(global_weights, local_weights)
        set_weights(global_model, global_weights)

        # Save one checkpoint per round, tagged with the aggregation strategy.
        checkpoint_path = os.path.join(model_dir, f"{model_name}_round_{rnd+1}_kr_sft.pt")
        torch.save(global_model.state_dict(), checkpoint_path)
        print(f"Saved global model to {checkpoint_path}")


Starting experiment with model: lstm
Round 1/10


Training clients: 100%|██████████| 10/10 [01:16<00:00,  7.61s/it]


SYNC Weights:[1.1280927881409405e-24, 3.455221414490252e-24, 2.7815943992592906e-10, 1.5342428063772783e-20, 8.519412258989334e-11, 1.0, 4.7079616409486216e-23, 0.0, 0.0, 1.4587423713408043e-08] 
Saved global model to results\lstm\lstm_round_1_kr.pt
Round 2/10


Training clients: 100%|██████████| 10/10 [01:16<00:00,  7.66s/it]


SYNC Weights:[0.051666800892013795, 2.0060822225145682e-07, 0.0038842655760772567, 0.33163288880904007, 0.05313184572821639, 0.17136356681197534, 4.1494413146093236e-12, 2.694967957872787e-05, 0.3882710659981862, 2.2412009829798056e-05] 
Saved global model to results\lstm\lstm_round_2_kr.pt
Round 3/10


Training clients:  30%|███       | 3/10 [00:19<00:46,  6.64s/it]


KeyboardInterrupt: 

### Diff-Aware FedAvg

### Diff-Sync FedAvg