In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset

from typing import List, Tuple, Optional, Dict
import numpy as np
from Models import MoELSTM
import os
from collections import OrderedDict
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader

from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error
from darts import TimeSeries
from darts.dataprocessing.transformers import Scaler
import random
from Models import MoELSTM, LSTMModel, train_model
from Preprocess import (
    compute_metrics,
    convert_timeseries_to_numpy,
    create_dataloader,
    load_building_series,
    split_series_list,
)
import pandas as pd
from collections import defaultdict
import os
import torch
import torch.optim as optim
from tqdm import tqdm


from Models import model_fn
from tqdm import tqdm
from my_utils import train_model, load_energy_data_feather, get_weights, set_weights


In [9]:
from AggregationStrategy import sync_aggregate,average_weights,sync_aggregate_norm,sync_aggregate_softmax, fedavgm_update

In [10]:
# def average_weights(weights_list):
#     """Averages model weights provided as a list of get_weights outputs."""
#     avg_weights = []
#     num_models = len(weights_list)
#     for layer_weights in zip(*weights_list):
#         avg_layer = np.mean(np.array(layer_weights), axis=0)
#         avg_weights.append(avg_layer)
#     return avg_weights

In [11]:
# Load the full training table for a quick look at its contents and schema.
df = pd.read_feather(path="train_final.feather")

In [12]:
# First five rows of the raw table.
df.head(n=5)

Unnamed: 0,building_id,meter,timestamp,meter_reading,primary_use,air_temperature
7593144,0,0,2016-05-21 01:00:00,72.221012,Education,25.6
7593145,1,0,2016-05-21 01:00:00,39.611586,Education,25.6
7593146,2,0,2016-05-21 01:00:00,1.920567,Education,25.6
7593147,3,0,2016-05-21 01:00:00,111.532464,Education,25.6
7593148,4,0,2016-05-21 01:00:00,456.734799,Education,25.6


In [13]:
# Column dtypes, non-null counts and (shallow) memory usage, printed to stdout.
df.info(buf=None)

<class 'pandas.core.frame.DataFrame'>
Index: 11712248 entries, 7593144 to 20216099
Data columns (total 6 columns):
 #   Column           Dtype         
---  ------           -----         
 0   building_id      int64         
 1   meter            int64         
 2   timestamp        datetime64[ns]
 3   meter_reading    float64       
 4   primary_use      object        
 5   air_temperature  float64       
dtypes: datetime64[ns](1), float64(2), int64(2), object(1)
memory usage: 625.5+ MB


In [14]:


# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# Architectures to benchmark; each name is passed to `model_fn`.
MODEL_NAMES = ["lstm", "gru"]

# Federation size and the fraction of clients sampled per round.
NUM_CLIENTS = 1410
CLIENT_FRAC = 0.15

# Optimisation schedule.
NUM_ROUNDS = 50
LOCAL_EPOCHS = 5
LR = 0.001

# Train on the GPU whenever torch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Source data (alternative file: "meter_0_data_cleaned.feather").
DATA_FILE = "train_final.feather"


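### FedAvg

Baseline. Each round samples `int(CLIENT_FRAC * NUM_CLIENTS)` distinct clients uniformly at random. Each sampled client trains a copy of the current global model for `LOCAL_EPOCHS` epochs. The global model is then replaced by `average_weights` of the local weights, and a checkpoint is saved after every round.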

In [8]:

# FedAvg baseline. For every architecture, start from a freshly built global
# model. Each round: draw int(CLIENT_FRAC * NUM_CLIENTS) distinct clients
# uniformly at random, train a copy of the global model on each one, and
# replace the global weights with `average_weights` of the local results.
for model_name in MODEL_NAMES:
    print(f"Starting experiment with model: {model_name}")

    # Checkpoints for this architecture go under results/<model_name>/.
    model_dir = os.path.join("results", model_name)
    os.makedirs(model_dir, exist_ok=True)

    global_model = model_fn(model_name).to(DEVICE)
    global_weights = get_weights(global_model)
    clients_per_round = int(CLIENT_FRAC * NUM_CLIENTS)

    for rnd in range(NUM_ROUNDS):
        round_no = rnd + 1
        print(f"Round {round_no}/{NUM_ROUNDS}")
        sampled_clients = random.sample(range(NUM_CLIENTS), k=clients_per_round)

        local_weights = []
        for cid in tqdm(sampled_clients, desc="Training clients"):
            # Each client trains its own model instance, initialised with the
            # current global weights.
            local_model = model_fn(model_name).to(DEVICE)
            set_weights(local_model, global_weights)
            train_loader, test_loader = load_energy_data_feather(cid, filepath=DATA_FILE)

            updated_weights, fin_loss = train_model(
                local_model,
                train_loader,
                device=DEVICE,
                learning_rate=LR,
                loss_fn=None,
                optimizer_class=optim.Adam,
                epochs=LOCAL_EPOCHS,
            )
            local_weights.append(updated_weights)

        # Server step: the averaged local weights become the new global model.
        global_weights = average_weights(local_weights)
        set_weights(global_model, global_weights)

        checkpoint_path = os.path.join(model_dir, f"{model_name}_round_{round_no}_fedAvg.pt")
        torch.save(global_model.state_dict(), checkpoint_path)
        print(f"Saved global model to {checkpoint_path}")


Starting experiment with model: lstm
Round 1/50


Training clients: 100%|██████████| 211/211 [03:25<00:00,  1.03it/s]


Saved global model to results/lstm/lstm_round_1_fedAvg.pt
Round 2/50


Training clients: 100%|██████████| 211/211 [03:51<00:00,  1.10s/it]


Saved global model to results/lstm/lstm_round_2_fedAvg.pt
Round 3/50


Training clients: 100%|██████████| 211/211 [03:53<00:00,  1.11s/it]


Saved global model to results/lstm/lstm_round_3_fedAvg.pt
Round 4/50


Training clients: 100%|██████████| 211/211 [03:53<00:00,  1.11s/it]


Saved global model to results/lstm/lstm_round_4_fedAvg.pt
Round 5/50


Training clients: 100%|██████████| 211/211 [03:56<00:00,  1.12s/it]


Saved global model to results/lstm/lstm_round_5_fedAvg.pt
Round 6/50


Training clients: 100%|██████████| 211/211 [04:01<00:00,  1.14s/it]


Saved global model to results/lstm/lstm_round_6_fedAvg.pt
Round 7/50


Training clients: 100%|██████████| 211/211 [03:57<00:00,  1.13s/it]


Saved global model to results/lstm/lstm_round_7_fedAvg.pt
Round 8/50


Training clients: 100%|██████████| 211/211 [03:55<00:00,  1.11s/it]


Saved global model to results/lstm/lstm_round_8_fedAvg.pt
Round 9/50


Training clients: 100%|██████████| 211/211 [03:56<00:00,  1.12s/it]


Saved global model to results/lstm/lstm_round_9_fedAvg.pt
Round 10/50


Training clients: 100%|██████████| 211/211 [04:01<00:00,  1.15s/it]


Saved global model to results/lstm/lstm_round_10_fedAvg.pt
Round 11/50


Training clients: 100%|██████████| 211/211 [03:54<00:00,  1.11s/it]


Saved global model to results/lstm/lstm_round_11_fedAvg.pt
Round 12/50


Training clients: 100%|██████████| 211/211 [03:49<00:00,  1.09s/it]


Saved global model to results/lstm/lstm_round_12_fedAvg.pt
Round 13/50


Training clients: 100%|██████████| 211/211 [03:54<00:00,  1.11s/it]


Saved global model to results/lstm/lstm_round_13_fedAvg.pt
Round 14/50


Training clients: 100%|██████████| 211/211 [03:54<00:00,  1.11s/it]


Saved global model to results/lstm/lstm_round_14_fedAvg.pt
Round 15/50


Training clients: 100%|██████████| 211/211 [03:57<00:00,  1.12s/it]


Saved global model to results/lstm/lstm_round_15_fedAvg.pt
Round 16/50


Training clients: 100%|██████████| 211/211 [03:55<00:00,  1.12s/it]


Saved global model to results/lstm/lstm_round_16_fedAvg.pt
Round 17/50


Training clients: 100%|██████████| 211/211 [04:01<00:00,  1.14s/it]


Saved global model to results/lstm/lstm_round_17_fedAvg.pt
Round 18/50


Training clients: 100%|██████████| 211/211 [04:03<00:00,  1.15s/it]


Saved global model to results/lstm/lstm_round_18_fedAvg.pt
Round 19/50


Training clients: 100%|██████████| 211/211 [03:59<00:00,  1.13s/it]


Saved global model to results/lstm/lstm_round_19_fedAvg.pt
Round 20/50


Training clients: 100%|██████████| 211/211 [03:56<00:00,  1.12s/it]


Saved global model to results/lstm/lstm_round_20_fedAvg.pt
Round 21/50


Training clients: 100%|██████████| 211/211 [04:01<00:00,  1.14s/it]


Saved global model to results/lstm/lstm_round_21_fedAvg.pt
Round 22/50


Training clients: 100%|██████████| 211/211 [03:57<00:00,  1.12s/it]


Saved global model to results/lstm/lstm_round_22_fedAvg.pt
Round 23/50


Training clients: 100%|██████████| 211/211 [04:01<00:00,  1.14s/it]


Saved global model to results/lstm/lstm_round_23_fedAvg.pt
Round 24/50


Training clients: 100%|██████████| 211/211 [03:57<00:00,  1.12s/it]


Saved global model to results/lstm/lstm_round_24_fedAvg.pt
Round 25/50


Training clients: 100%|██████████| 211/211 [03:57<00:00,  1.12s/it]


Saved global model to results/lstm/lstm_round_25_fedAvg.pt
Round 26/50


Training clients: 100%|██████████| 211/211 [03:57<00:00,  1.13s/it]


Saved global model to results/lstm/lstm_round_26_fedAvg.pt
Round 27/50


Training clients: 100%|██████████| 211/211 [03:59<00:00,  1.13s/it]


Saved global model to results/lstm/lstm_round_27_fedAvg.pt
Round 28/50


Training clients: 100%|██████████| 211/211 [04:03<00:00,  1.15s/it]


Saved global model to results/lstm/lstm_round_28_fedAvg.pt
Round 29/50


Training clients: 100%|██████████| 211/211 [04:02<00:00,  1.15s/it]


Saved global model to results/lstm/lstm_round_29_fedAvg.pt
Round 30/50


Training clients: 100%|██████████| 211/211 [03:57<00:00,  1.13s/it]


Saved global model to results/lstm/lstm_round_30_fedAvg.pt
Round 31/50


Training clients: 100%|██████████| 211/211 [03:57<00:00,  1.13s/it]


Saved global model to results/lstm/lstm_round_31_fedAvg.pt
Round 32/50


Training clients: 100%|██████████| 211/211 [03:54<00:00,  1.11s/it]


Saved global model to results/lstm/lstm_round_32_fedAvg.pt
Round 33/50


Training clients: 100%|██████████| 211/211 [04:03<00:00,  1.15s/it]


Saved global model to results/lstm/lstm_round_33_fedAvg.pt
Round 34/50


Training clients: 100%|██████████| 211/211 [04:02<00:00,  1.15s/it]


Saved global model to results/lstm/lstm_round_34_fedAvg.pt
Round 35/50


Training clients: 100%|██████████| 211/211 [03:58<00:00,  1.13s/it]


Saved global model to results/lstm/lstm_round_35_fedAvg.pt
Round 36/50


Training clients: 100%|██████████| 211/211 [03:56<00:00,  1.12s/it]


Saved global model to results/lstm/lstm_round_36_fedAvg.pt
Round 37/50


Training clients: 100%|██████████| 211/211 [04:02<00:00,  1.15s/it]


Saved global model to results/lstm/lstm_round_37_fedAvg.pt
Round 38/50


Training clients: 100%|██████████| 211/211 [03:58<00:00,  1.13s/it]


Saved global model to results/lstm/lstm_round_38_fedAvg.pt
Round 39/50


Training clients: 100%|██████████| 211/211 [04:00<00:00,  1.14s/it]


Saved global model to results/lstm/lstm_round_39_fedAvg.pt
Round 40/50


Training clients: 100%|██████████| 211/211 [04:03<00:00,  1.15s/it]


Saved global model to results/lstm/lstm_round_40_fedAvg.pt
Round 41/50


Training clients: 100%|██████████| 211/211 [03:55<00:00,  1.11s/it]


Saved global model to results/lstm/lstm_round_41_fedAvg.pt
Round 42/50


Training clients: 100%|██████████| 211/211 [04:01<00:00,  1.14s/it]


Saved global model to results/lstm/lstm_round_42_fedAvg.pt
Round 43/50


Training clients: 100%|██████████| 211/211 [03:57<00:00,  1.13s/it]


Saved global model to results/lstm/lstm_round_43_fedAvg.pt
Round 44/50


Training clients: 100%|██████████| 211/211 [03:56<00:00,  1.12s/it]


Saved global model to results/lstm/lstm_round_44_fedAvg.pt
Round 45/50


Training clients: 100%|██████████| 211/211 [03:59<00:00,  1.13s/it]


Saved global model to results/lstm/lstm_round_45_fedAvg.pt
Round 46/50


Training clients: 100%|██████████| 211/211 [03:52<00:00,  1.10s/it]


Saved global model to results/lstm/lstm_round_46_fedAvg.pt
Round 47/50


Training clients: 100%|██████████| 211/211 [03:58<00:00,  1.13s/it]


Saved global model to results/lstm/lstm_round_47_fedAvg.pt
Round 48/50


Training clients: 100%|██████████| 211/211 [03:59<00:00,  1.14s/it]


Saved global model to results/lstm/lstm_round_48_fedAvg.pt
Round 49/50


Training clients: 100%|██████████| 211/211 [03:57<00:00,  1.12s/it]


Saved global model to results/lstm/lstm_round_49_fedAvg.pt
Round 50/50


Training clients: 100%|██████████| 211/211 [03:59<00:00,  1.13s/it]


Saved global model to results/lstm/lstm_round_50_fedAvg.pt
Starting experiment with model: gru
Round 1/50


Training clients: 100%|██████████| 211/211 [04:15<00:00,  1.21s/it]


Saved global model to results/gru/gru_round_1_fedAvg.pt
Round 2/50


Training clients: 100%|██████████| 211/211 [04:13<00:00,  1.20s/it]


Saved global model to results/gru/gru_round_2_fedAvg.pt
Round 3/50


Training clients: 100%|██████████| 211/211 [04:13<00:00,  1.20s/it]


Saved global model to results/gru/gru_round_3_fedAvg.pt
Round 4/50


Training clients: 100%|██████████| 211/211 [04:13<00:00,  1.20s/it]


Saved global model to results/gru/gru_round_4_fedAvg.pt
Round 5/50


Training clients: 100%|██████████| 211/211 [04:09<00:00,  1.18s/it]


Saved global model to results/gru/gru_round_5_fedAvg.pt
Round 6/50


Training clients: 100%|██████████| 211/211 [04:11<00:00,  1.19s/it]


Saved global model to results/gru/gru_round_6_fedAvg.pt
Round 7/50


Training clients: 100%|██████████| 211/211 [04:12<00:00,  1.20s/it]


Saved global model to results/gru/gru_round_7_fedAvg.pt
Round 8/50


Training clients: 100%|██████████| 211/211 [04:07<00:00,  1.17s/it]


Saved global model to results/gru/gru_round_8_fedAvg.pt
Round 9/50


Training clients: 100%|██████████| 211/211 [04:13<00:00,  1.20s/it]


Saved global model to results/gru/gru_round_9_fedAvg.pt
Round 10/50


Training clients: 100%|██████████| 211/211 [04:08<00:00,  1.18s/it]


Saved global model to results/gru/gru_round_10_fedAvg.pt
Round 11/50


Training clients: 100%|██████████| 211/211 [04:12<00:00,  1.20s/it]


Saved global model to results/gru/gru_round_11_fedAvg.pt
Round 12/50


Training clients: 100%|██████████| 211/211 [04:12<00:00,  1.20s/it]


Saved global model to results/gru/gru_round_12_fedAvg.pt
Round 13/50


Training clients: 100%|██████████| 211/211 [04:12<00:00,  1.20s/it]


Saved global model to results/gru/gru_round_13_fedAvg.pt
Round 14/50


Training clients: 100%|██████████| 211/211 [04:16<00:00,  1.21s/it]


Saved global model to results/gru/gru_round_14_fedAvg.pt
Round 15/50


Training clients: 100%|██████████| 211/211 [04:12<00:00,  1.20s/it]


Saved global model to results/gru/gru_round_15_fedAvg.pt
Round 16/50


Training clients: 100%|██████████| 211/211 [04:15<00:00,  1.21s/it]


Saved global model to results/gru/gru_round_16_fedAvg.pt
Round 17/50


Training clients: 100%|██████████| 211/211 [04:15<00:00,  1.21s/it]


Saved global model to results/gru/gru_round_17_fedAvg.pt
Round 18/50


Training clients: 100%|██████████| 211/211 [04:12<00:00,  1.20s/it]


Saved global model to results/gru/gru_round_18_fedAvg.pt
Round 19/50


Training clients: 100%|██████████| 211/211 [04:11<00:00,  1.19s/it]


Saved global model to results/gru/gru_round_19_fedAvg.pt
Round 20/50


Training clients: 100%|██████████| 211/211 [04:15<00:00,  1.21s/it]


Saved global model to results/gru/gru_round_20_fedAvg.pt
Round 21/50


Training clients: 100%|██████████| 211/211 [04:15<00:00,  1.21s/it]


Saved global model to results/gru/gru_round_21_fedAvg.pt
Round 22/50


Training clients: 100%|██████████| 211/211 [04:12<00:00,  1.20s/it]


Saved global model to results/gru/gru_round_22_fedAvg.pt
Round 23/50


Training clients: 100%|██████████| 211/211 [04:19<00:00,  1.23s/it]


Saved global model to results/gru/gru_round_23_fedAvg.pt
Round 24/50


Training clients: 100%|██████████| 211/211 [04:16<00:00,  1.21s/it]


Saved global model to results/gru/gru_round_24_fedAvg.pt
Round 25/50


Training clients: 100%|██████████| 211/211 [04:14<00:00,  1.21s/it]


Saved global model to results/gru/gru_round_25_fedAvg.pt
Round 26/50


Training clients: 100%|██████████| 211/211 [04:10<00:00,  1.19s/it]


Saved global model to results/gru/gru_round_26_fedAvg.pt
Round 27/50


Training clients: 100%|██████████| 211/211 [04:16<00:00,  1.21s/it]


Saved global model to results/gru/gru_round_27_fedAvg.pt
Round 28/50


Training clients: 100%|██████████| 211/211 [04:12<00:00,  1.20s/it]


Saved global model to results/gru/gru_round_28_fedAvg.pt
Round 29/50


Training clients: 100%|██████████| 211/211 [04:13<00:00,  1.20s/it]


Saved global model to results/gru/gru_round_29_fedAvg.pt
Round 30/50


Training clients: 100%|██████████| 211/211 [04:15<00:00,  1.21s/it]


Saved global model to results/gru/gru_round_30_fedAvg.pt
Round 31/50


Training clients: 100%|██████████| 211/211 [04:13<00:00,  1.20s/it]


Saved global model to results/gru/gru_round_31_fedAvg.pt
Round 32/50


Training clients: 100%|██████████| 211/211 [04:12<00:00,  1.20s/it]


Saved global model to results/gru/gru_round_32_fedAvg.pt
Round 33/50


Training clients: 100%|██████████| 211/211 [04:14<00:00,  1.20s/it]


Saved global model to results/gru/gru_round_33_fedAvg.pt
Round 34/50


Training clients: 100%|██████████| 211/211 [04:15<00:00,  1.21s/it]


Saved global model to results/gru/gru_round_34_fedAvg.pt
Round 35/50


Training clients: 100%|██████████| 211/211 [04:15<00:00,  1.21s/it]


Saved global model to results/gru/gru_round_35_fedAvg.pt
Round 36/50


Training clients: 100%|██████████| 211/211 [04:13<00:00,  1.20s/it]


Saved global model to results/gru/gru_round_36_fedAvg.pt
Round 37/50


Training clients: 100%|██████████| 211/211 [04:16<00:00,  1.21s/it]


Saved global model to results/gru/gru_round_37_fedAvg.pt
Round 38/50


Training clients: 100%|██████████| 211/211 [04:10<00:00,  1.19s/it]


Saved global model to results/gru/gru_round_38_fedAvg.pt
Round 39/50


Training clients: 100%|██████████| 211/211 [04:13<00:00,  1.20s/it]


Saved global model to results/gru/gru_round_39_fedAvg.pt
Round 40/50


Training clients: 100%|██████████| 211/211 [04:13<00:00,  1.20s/it]


Saved global model to results/gru/gru_round_40_fedAvg.pt
Round 41/50


Training clients: 100%|██████████| 211/211 [04:12<00:00,  1.20s/it]


Saved global model to results/gru/gru_round_41_fedAvg.pt
Round 42/50


Training clients: 100%|██████████| 211/211 [04:18<00:00,  1.23s/it]


Saved global model to results/gru/gru_round_42_fedAvg.pt
Round 43/50


Training clients: 100%|██████████| 211/211 [04:17<00:00,  1.22s/it]


Saved global model to results/gru/gru_round_43_fedAvg.pt
Round 44/50


Training clients: 100%|██████████| 211/211 [04:14<00:00,  1.21s/it]


Saved global model to results/gru/gru_round_44_fedAvg.pt
Round 45/50


Training clients: 100%|██████████| 211/211 [04:17<00:00,  1.22s/it]


Saved global model to results/gru/gru_round_45_fedAvg.pt
Round 46/50


Training clients: 100%|██████████| 211/211 [04:16<00:00,  1.21s/it]


Saved global model to results/gru/gru_round_46_fedAvg.pt
Round 47/50


Training clients: 100%|██████████| 211/211 [04:15<00:00,  1.21s/it]


Saved global model to results/gru/gru_round_47_fedAvg.pt
Round 48/50


Training clients: 100%|██████████| 211/211 [04:13<00:00,  1.20s/it]


Saved global model to results/gru/gru_round_48_fedAvg.pt
Round 49/50


Training clients: 100%|██████████| 211/211 [04:15<00:00,  1.21s/it]


Saved global model to results/gru/gru_round_49_fedAvg.pt
Round 50/50


Training clients: 100%|██████████| 211/211 [04:17<00:00,  1.22s/it]

Saved global model to results/gru/gru_round_50_fedAvg.pt





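### Difficulty-Aware FedAvg

A per-client difficulty score is tracked from the trend of each client's local training loss (`TimeSeriesDifficultyWeight`). Clients are then sampled with probability proportional to the inverse of that score, with a floor of `min_prob` before renormalisation.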

In [15]:
class TimeSeriesDifficultyWeight:
    """Per-client difficulty tracker driven by the trend of local training loss.

    For each client the tracker keeps two EMA scores:
      * learn   -- accumulates decreases in loss between consecutive updates,
      * unlearn -- accumulates increases in loss,
    each weighted by ``delta * log(cur / prev)``, which is non-negative because
    both factors share a sign for non-negative losses. The raw difficulty is
    ``unlearn / learn``; it is smoothed with a second EMA (``ema_difficulty``).

    Args:
        num_clients: number of clients; client ids index ``[0, num_clients)``.
        accumulate_iters: EMA horizon; momentum is
            ``(accumulate_iters - 1) / accumulate_iters``.
        device: torch device for the state tensors. ``None`` picks ``"cuda"``
            when available, else ``"cpu"`` (same rule as the notebook's DEVICE).
    """

    def __init__(self, num_clients, accumulate_iters=20, device=None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.num_clients = num_clients
        # Loss seen at each client's previous update. Initialised to 1.0, so a
        # client's first update measures its loss trend relative to 1.0.
        self.last_loss = torch.ones(num_clients, dtype=torch.float32, device=device)
        self.learn_score = torch.zeros(num_clients, dtype=torch.float32, device=device)
        self.unlearn_score = torch.zeros(num_clients, dtype=torch.float32, device=device)
        # Every client starts with difficulty 1.0.
        self.ema_difficulty = torch.ones(num_clients, dtype=torch.float32, device=device)
        self.accumulate_iters = accumulate_iters

    def update(self, cid: int, loss_history: List[float]) -> float:
        """Update client ``cid``'s difficulty from its latest loss.

        Args:
            cid: client index into the state tensors.
            loss_history: per-epoch losses from the client's local training;
                only the last entry is used.

        Returns:
            The client's smoothed difficulty after the update.

        Raises:
            ValueError: if ``loss_history`` is empty.
        """
        if len(loss_history) == 0:
            raise ValueError(f"loss_history for client {cid} is empty")

        device = self.last_loss.device
        current_loss = torch.tensor(loss_history[-1], dtype=torch.float32, device=device)
        previous_loss = self.last_loss[cid]
        delta = current_loss - previous_loss
        ratio = torch.log((current_loss + 1e-8) / (previous_loss + 1e-8))

        # delta and ratio have the same sign, so their product is a non-negative
        # magnitude of change. (The previous `-delta * ratio` for the learning
        # branch was always <= 0, which drove learn_score and the difficulty
        # ratio negative.)
        change = delta * ratio
        zero = torch.zeros_like(change)
        learn = torch.where(delta < 0, change, zero)
        unlearn = torch.where(delta >= 0, change, zero)

        # EMA update of the learn / unlearn scores.
        momentum = (self.accumulate_iters - 1) / self.accumulate_iters
        self.learn_score[cid] = momentum * self.learn_score[cid] + (1 - momentum) * learn
        self.unlearn_score[cid] = momentum * self.unlearn_score[cid] + (1 - momentum) * unlearn

        # Raw difficulty: how much the client "unlearns" relative to how much it
        # learns (an earlier variant used diff_ratio ** (1 / 5)).
        difficulty = (self.unlearn_score[cid] + 1e-8) / (self.learn_score[cid] + 1e-8)

        # Smooth difficulty over rounds.
        self.ema_difficulty[cid] = momentum * self.ema_difficulty[cid] + (1 - momentum) * difficulty

        self.last_loss[cid] = current_loss
        return self.ema_difficulty[cid].item()

    def get_normalized_weights(self, client_ids: List[int]) -> List[float]:
        """Difficulties of ``client_ids`` rescaled to sum to 1.

        Falls back to uniform weights when the difficulties sum to 0, and
        returns ``[]`` for an empty selection.
        """
        if len(client_ids) == 0:
            return []
        weights = [self.ema_difficulty[cid].item() for cid in client_ids]
        total = sum(weights)
        if total == 0:
            return [1.0 / len(client_ids)] * len(client_ids)
        return [w / total for w in weights]

    def get_sampling_probabilities(self, min_prob=0.05):
        """Sampling distribution over all clients, favouring low difficulty.

        Probabilities are proportional to ``1 / difficulty``. After a first
        normalisation each entry is floored at ``min_prob``, and the result is
        renormalised to sum to 1.

        Returns:
            numpy array of length ``num_clients``.
        """
        difficulty = self.ema_difficulty
        inv_difficulty = 1.0 / (difficulty + 1e-6)
        inv_difficulty = inv_difficulty / inv_difficulty.sum()
        probs = torch.clamp(inv_difficulty, min=min_prob)
        return (probs / probs.sum()).cpu().numpy()



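## SCAFFOLD

Local training applies the SCAFFOLD control-variate correction (`train_model_scaffold`). The experiment loop below draws clients from the difficulty tracker's sampling distribution rather than uniformly.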

In [16]:
def train_model_scaffold(
    local_model,
    train_loader,
    global_weights,    # x
    server_c,          # c
    client_ci,         # c_i
    device=DEVICE,
    learning_rate=LR,
    loss_fn=None,
    optimizer_class=optim.Adam,
    epochs=LOCAL_EPOCHS,
):
    """Train one client with the SCAFFOLD drift correction.

    After every optimiser step each parameter is shifted by
    ``-learning_rate * (c - c_i)``. After local training the client control
    variate is refreshed with SCAFFOLD "option II":
    ``c_i+ = c_i - c + (x - y) / (K * learning_rate)``, where ``x`` are the
    global weights, ``y`` the local weights after training and ``K`` the total
    number of local steps.

    Args:
        local_model: model already loaded with the global weights.
        train_loader: iterable of ``(X_batch, y_batch)`` pairs.
        global_weights: global weights ``x``, in the format ``get_weights`` returns.
        server_c: server control variate ``c``, one array per weight entry.
        client_ci: this client's control variate ``c_i``, same layout.
        device: device to train on; ``None`` picks CUDA when available.
        learning_rate: step size for the optimiser, the correction and the
            ``c_i`` update.
        loss_fn: loss function; defaults to ``nn.MSELoss()``.
        optimizer_class: called as ``optimizer_class(params, lr=learning_rate)``.
        epochs: number of local passes over ``train_loader``.

    Returns:
        ``(delta_y, delta_c, new_ci, local_weights, loss_history)``:
        ``delta_y = y - x``, ``delta_c = c_i+ - c_i``, ``new_ci = c_i+``,
        ``local_weights = y``, and ``loss_history`` holds the mean batch loss
        of each epoch.

    Raises:
        ValueError: if no local step is taken (empty loader or ``epochs <= 0``);
            ``K`` would be 0 and ``c_i+`` would be nan.

    NOTE(review): option II assumes plain SGD steps of size ``learning_rate``.
    With Adam (the default) the effective step is adaptive, so ``c_i`` is only
    a heuristic estimate. The correction also zips ``local_model.parameters()``
    with ``server_c`` / ``client_ci``, which assumes ``get_weights`` returns
    arrays in parameter order with no extra buffers. TODO confirm.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    local_model.to(device)
    loss_fn = loss_fn or nn.MSELoss()
    optimizer = optimizer_class(local_model.parameters(), lr=learning_rate)
    loss_history = []

    # The correction lr * (c - c_i) is constant during local training. Build
    # it once on the parameter's device instead of re-uploading numpy arrays
    # for every parameter on every batch.
    corrections = [
        (p, learning_rate * (
            torch.as_tensor(sc_np, dtype=p.dtype, device=p.device)
            - torch.as_tensor(ci_np, dtype=p.dtype, device=p.device)
        ))
        for p, sc_np, ci_np in zip(local_model.parameters(), server_c, client_ci)
    ]

    local_model.train()
    total_steps = 0

    for _ in range(epochs):
        epoch_loss = 0.0
        epoch_batches = 0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            # Targets shaped (..., 1) are squeezed to match the model output.
            if y_batch.dim() == 3 and y_batch.shape[-1] == 1:
                y_batch = y_batch.squeeze(-1)

            optimizer.zero_grad()
            output = local_model(X_batch)
            loss = loss_fn(output, y_batch)
            loss.backward()
            optimizer.step()

            # SCAFFOLD correction, applied after the regular optimiser step.
            with torch.no_grad():
                for p, correction in corrections:
                    p -= correction

            epoch_loss += loss.item()
            epoch_batches += 1

        if epoch_batches == 0:
            raise ValueError(
                "train_loader yielded no batches; SCAFFOLD needs at least one local step"
            )
        loss_history.append(epoch_loss / epoch_batches)
        total_steps += epoch_batches

    if total_steps == 0:
        raise ValueError("epochs must be >= 1 for SCAFFOLD (no local steps were taken)")

    # Model delta y - x.
    local_weights = get_weights(local_model)
    delta_y = [lw - gw for lw, gw in zip(local_weights, global_weights)]

    # Option-II control-variate refresh; each c_i+ is a fresh array, so the
    # caller's c_i is not modified in place.
    K = total_steps
    new_ci = []
    delta_c = []
    for gw, lw, ci, sc in zip(global_weights, local_weights, client_ci, server_c):
        ci_new = ci - sc + (gw - lw) / (K * learning_rate)
        new_ci.append(ci_new)
        delta_c.append(ci_new - ci)

    return delta_y, delta_c, new_ci, local_weights, loss_history

In [None]:


# SCAFFOLD with difficulty-aware client sampling. Per architecture: a fresh
# global model, a zero server control variate c, and zero client control
# variates c_i. Each round samples clients from the difficulty tracker's
# distribution, runs SCAFFOLD local training, then applies
#   x <- x + average_weights(delta_y)
#   c <- c + (|S| / N) * average_weights(delta_c)
# NOTE(review): clients are sampled non-uniformly, but their deltas are combined
# with `average_weights` (presumably an unweighted mean -- verify). That is not
# the uniform-sampling estimator SCAFFOLD's c update assumes; confirm this is
# intended.
for model_name in MODEL_NAMES:
    difficulty_tracker = TimeSeriesDifficultyWeight(num_clients=NUM_CLIENTS)
    print(f"Starting experiment with model: {model_name}")

    # Create the checkpoint directory here. torch.save does not create parent
    # directories, and this cell must not depend on the FedAvg cell having
    # already run in this kernel.
    model_dir = os.path.join("results", model_name)
    os.makedirs(model_dir, exist_ok=True)

    global_model = model_fn(model_name).to(DEVICE)
    global_weights = get_weights(global_model)
    print(f"Using device: {DEVICE}")

    server_c = [np.zeros_like(w) for w in global_weights]
    client_cs = {cid: [np.zeros_like(w) for w in global_weights] for cid in range(NUM_CLIENTS)}

    for rnd in range(NUM_ROUNDS):
        print(f"Round {rnd+1}/{NUM_ROUNDS}")

        # === Difficulty-aware sampling (without replacement) ===
        sampling_probs = difficulty_tracker.get_sampling_probabilities(min_prob=0.001)
        sampled_clients = np.random.choice(
            np.arange(NUM_CLIENTS),
            size=int(CLIENT_FRAC * NUM_CLIENTS),
            replace=False,
            p=sampling_probs
        )
        print(f"Sampled {len(sampled_clients)} clients")

        local_weight_deltas = []
        local_c_deltas = []

        for cid in tqdm(sampled_clients):
            local_model = model_fn(model_name).to(DEVICE)
            set_weights(local_model, global_weights)

            train_loader, _ = load_energy_data_feather(cid, filepath=DATA_FILE)

            delta_y, delta_c, new_ci, local_weights, loss_history = train_model_scaffold(
                local_model, train_loader,
                global_weights=global_weights,
                server_c=server_c,
                client_ci=client_cs[cid],
                device=DEVICE,
                learning_rate=LR,
                loss_fn=None,
                optimizer_class=optim.Adam,
                epochs=LOCAL_EPOCHS
            )

            # === Difficulty update from this client's loss trend ===
            difficulty_tracker.update(cid, loss_history)

            local_weight_deltas.append(delta_y)
            local_c_deltas.append(delta_c)
            client_cs[cid] = new_ci

        # === Aggregate and update global weights (server step size 1) ===
        mean_delta_y = average_weights(local_weight_deltas)
        global_weights = [gw + mean_delta_y[i] for i, gw in enumerate(global_weights)]

        # Server control variate: scaled by the participating fraction |S| / N.
        mean_delta_c = average_weights(local_c_deltas)
        frac = len(sampled_clients) / NUM_CLIENTS
        server_c = [sc + frac * mean_delta_c[i] for i, sc in enumerate(server_c)]

        set_weights(global_model, global_weights)

        ckpt_path = os.path.join(model_dir, f"{model_name}_round_{rnd+1}_scaffold_diff001.pt")
        torch.save(global_model.state_dict(), ckpt_path)
        print(f"Saved: {ckpt_path}")


Starting experiment with model: lstm
Using device: cuda
Round 1/50
Sampled 211 clients


100%|██████████| 211/211 [08:16<00:00,  2.36s/it]


Saved: results/lstm/lstm_round_1_scaffold_diff001.pt
Round 2/50
Sampled 211 clients


100%|██████████| 211/211 [08:29<00:00,  2.42s/it]


Saved: results/lstm/lstm_round_2_scaffold_diff001.pt
Round 3/50
Sampled 211 clients


100%|██████████| 211/211 [07:32<00:00,  2.14s/it]


Saved: results/lstm/lstm_round_3_scaffold_diff001.pt
Round 4/50
Sampled 211 clients


100%|██████████| 211/211 [07:25<00:00,  2.11s/it]


Saved: results/lstm/lstm_round_4_scaffold_diff001.pt
Round 5/50
Sampled 211 clients


100%|██████████| 211/211 [07:25<00:00,  2.11s/it]


Saved: results/lstm/lstm_round_5_scaffold_diff001.pt
Round 6/50
Sampled 211 clients


100%|██████████| 211/211 [07:23<00:00,  2.10s/it]


Saved: results/lstm/lstm_round_6_scaffold_diff001.pt
Round 7/50
Sampled 211 clients


100%|██████████| 211/211 [07:29<00:00,  2.13s/it]


Saved: results/lstm/lstm_round_7_scaffold_diff001.pt
Round 8/50
Sampled 211 clients


100%|██████████| 211/211 [07:08<00:00,  2.03s/it]


Saved: results/lstm/lstm_round_8_scaffold_diff001.pt
Round 9/50
Sampled 211 clients


100%|██████████| 211/211 [07:07<00:00,  2.03s/it]


Saved: results/lstm/lstm_round_9_scaffold_diff001.pt
Round 10/50
Sampled 211 clients


100%|██████████| 211/211 [06:59<00:00,  1.99s/it]


Saved: results/lstm/lstm_round_10_scaffold_diff001.pt
Round 11/50
Sampled 211 clients


100%|██████████| 211/211 [07:11<00:00,  2.05s/it]


Saved: results/lstm/lstm_round_11_scaffold_diff001.pt
Round 12/50
Sampled 211 clients


100%|██████████| 211/211 [07:00<00:00,  1.99s/it]


Saved: results/lstm/lstm_round_12_scaffold_diff001.pt
Round 13/50
Sampled 211 clients


100%|██████████| 211/211 [07:05<00:00,  2.02s/it]


Saved: results/lstm/lstm_round_13_scaffold_diff001.pt
Round 14/50
Sampled 211 clients


100%|██████████| 211/211 [07:03<00:00,  2.01s/it]


Saved: results/lstm/lstm_round_14_scaffold_diff001.pt
Round 15/50
Sampled 211 clients


100%|██████████| 211/211 [07:08<00:00,  2.03s/it]


Saved: results/lstm/lstm_round_15_scaffold_diff001.pt
Round 16/50
Sampled 211 clients


100%|██████████| 211/211 [07:00<00:00,  1.99s/it]


Saved: results/lstm/lstm_round_16_scaffold_diff001.pt
Round 17/50
Sampled 211 clients


100%|██████████| 211/211 [07:08<00:00,  2.03s/it]


Saved: results/lstm/lstm_round_17_scaffold_diff001.pt
Round 18/50
Sampled 211 clients


100%|██████████| 211/211 [07:04<00:00,  2.01s/it]


Saved: results/lstm/lstm_round_18_scaffold_diff001.pt
Round 19/50
Sampled 211 clients


100%|██████████| 211/211 [07:05<00:00,  2.02s/it]


Saved: results/lstm/lstm_round_19_scaffold_diff001.pt
Round 20/50
Sampled 211 clients


100%|██████████| 211/211 [07:05<00:00,  2.02s/it]


Saved: results/lstm/lstm_round_20_scaffold_diff001.pt
Round 21/50
Sampled 211 clients


100%|██████████| 211/211 [07:10<00:00,  2.04s/it]


Saved: results/lstm/lstm_round_21_scaffold_diff001.pt
Round 22/50
Sampled 211 clients


100%|██████████| 211/211 [07:01<00:00,  2.00s/it]


Saved: results/lstm/lstm_round_22_scaffold_diff001.pt
Round 23/50
Sampled 211 clients


100%|██████████| 211/211 [07:05<00:00,  2.02s/it]


Saved: results/lstm/lstm_round_23_scaffold_diff001.pt
Round 24/50
Sampled 211 clients


100%|██████████| 211/211 [07:04<00:00,  2.01s/it]


Saved: results/lstm/lstm_round_24_scaffold_diff001.pt
Round 25/50
Sampled 211 clients


100%|██████████| 211/211 [07:01<00:00,  2.00s/it]


Saved: results/lstm/lstm_round_25_scaffold_diff001.pt
Round 26/50
Sampled 211 clients


100%|██████████| 211/211 [07:08<00:00,  2.03s/it]


Saved: results/lstm/lstm_round_26_scaffold_diff001.pt
Round 27/50
Sampled 211 clients


100%|██████████| 211/211 [07:08<00:00,  2.03s/it]


Saved: results/lstm/lstm_round_27_scaffold_diff001.pt
Round 28/50
Sampled 211 clients


100%|██████████| 211/211 [07:02<00:00,  2.00s/it]


Saved: results/lstm/lstm_round_28_scaffold_diff001.pt
Round 29/50
Sampled 211 clients


100%|██████████| 211/211 [07:03<00:00,  2.01s/it]


Saved: results/lstm/lstm_round_29_scaffold_diff001.pt
Round 30/50
Sampled 211 clients


100%|██████████| 211/211 [07:07<00:00,  2.02s/it]


Saved: results/lstm/lstm_round_30_scaffold_diff001.pt
Round 31/50
Sampled 211 clients


100%|██████████| 211/211 [07:05<00:00,  2.02s/it]


Saved: results/lstm/lstm_round_31_scaffold_diff001.pt
Round 32/50
Sampled 211 clients


100%|██████████| 211/211 [07:17<00:00,  2.07s/it]


Saved: results/lstm/lstm_round_32_scaffold_diff001.pt
Round 33/50
Sampled 211 clients


100%|██████████| 211/211 [07:05<00:00,  2.02s/it]


Saved: results/lstm/lstm_round_33_scaffold_diff001.pt
Round 34/50
Sampled 211 clients


100%|██████████| 211/211 [07:10<00:00,  2.04s/it]


Saved: results/lstm/lstm_round_34_scaffold_diff001.pt
Round 35/50
Sampled 211 clients


100%|██████████| 211/211 [07:06<00:00,  2.02s/it]


Saved: results/lstm/lstm_round_35_scaffold_diff001.pt
Round 36/50
Sampled 211 clients


100%|██████████| 211/211 [07:06<00:00,  2.02s/it]


Saved: results/lstm/lstm_round_36_scaffold_diff001.pt
Round 37/50
Sampled 211 clients


100%|██████████| 211/211 [07:06<00:00,  2.02s/it]


Saved: results/lstm/lstm_round_37_scaffold_diff001.pt
Round 38/50
Sampled 211 clients


100%|██████████| 211/211 [07:05<00:00,  2.02s/it]


Saved: results/lstm/lstm_round_38_scaffold_diff001.pt
Round 39/50
Sampled 211 clients


100%|██████████| 211/211 [07:05<00:00,  2.01s/it]


Saved: results/lstm/lstm_round_39_scaffold_diff001.pt
Round 40/50
Sampled 211 clients


100%|██████████| 211/211 [06:58<00:00,  1.99s/it]


Saved: results/lstm/lstm_round_40_scaffold_diff001.pt
Round 41/50
Sampled 211 clients


100%|██████████| 211/211 [07:18<00:00,  2.08s/it]


Saved: results/lstm/lstm_round_41_scaffold_diff001.pt
Round 42/50
Sampled 211 clients


100%|██████████| 211/211 [07:06<00:00,  2.02s/it]


Saved: results/lstm/lstm_round_42_scaffold_diff001.pt
Round 43/50
Sampled 211 clients


100%|██████████| 211/211 [06:55<00:00,  1.97s/it]


Saved: results/lstm/lstm_round_43_scaffold_diff001.pt
Round 44/50
Sampled 211 clients


100%|██████████| 211/211 [07:07<00:00,  2.03s/it]


Saved: results/lstm/lstm_round_44_scaffold_diff001.pt
Round 45/50
Sampled 211 clients


100%|██████████| 211/211 [07:06<00:00,  2.02s/it]


Saved: results/lstm/lstm_round_45_scaffold_diff001.pt
Round 46/50
Sampled 211 clients


100%|██████████| 211/211 [07:07<00:00,  2.02s/it]


Saved: results/lstm/lstm_round_46_scaffold_diff001.pt
Round 47/50
Sampled 211 clients


100%|██████████| 211/211 [07:15<00:00,  2.07s/it]


Saved: results/lstm/lstm_round_47_scaffold_diff001.pt
Round 48/50
Sampled 211 clients


100%|██████████| 211/211 [07:07<00:00,  2.03s/it]


Saved: results/lstm/lstm_round_48_scaffold_diff001.pt
Round 49/50
Sampled 211 clients


100%|██████████| 211/211 [07:18<00:00,  2.08s/it]


Saved: results/lstm/lstm_round_49_scaffold_diff001.pt
Round 50/50
Sampled 211 clients


100%|██████████| 211/211 [07:07<00:00,  2.03s/it]


Saved: results/lstm/lstm_round_50_scaffold_diff001.pt
Starting experiment with model: gru
Using device: cuda
Round 1/50
Sampled 211 clients


100%|██████████| 211/211 [07:35<00:00,  2.16s/it]


Saved: results/gru/gru_round_1_scaffold_diff001.pt
Round 2/50
Sampled 211 clients


100%|██████████| 211/211 [07:33<00:00,  2.15s/it]


Saved: results/gru/gru_round_2_scaffold_diff001.pt
Round 3/50
Sampled 211 clients


100%|██████████| 211/211 [07:20<00:00,  2.09s/it]


Saved: results/gru/gru_round_3_scaffold_diff001.pt
Round 4/50
Sampled 211 clients


100%|██████████| 211/211 [07:23<00:00,  2.10s/it]


Saved: results/gru/gru_round_4_scaffold_diff001.pt
Round 5/50
Sampled 211 clients


100%|██████████| 211/211 [07:35<00:00,  2.16s/it]


Saved: results/gru/gru_round_5_scaffold_diff001.pt
Round 6/50
Sampled 211 clients


100%|██████████| 211/211 [07:29<00:00,  2.13s/it]


Saved: results/gru/gru_round_6_scaffold_diff001.pt
Round 7/50
Sampled 211 clients


100%|██████████| 211/211 [07:32<00:00,  2.15s/it]


Saved: results/gru/gru_round_7_scaffold_diff001.pt
Round 8/50
Sampled 211 clients


100%|██████████| 211/211 [07:32<00:00,  2.14s/it]


Saved: results/gru/gru_round_8_scaffold_diff001.pt
Round 9/50
Sampled 211 clients


100%|██████████| 211/211 [07:33<00:00,  2.15s/it]


Saved: results/gru/gru_round_9_scaffold_diff001.pt
Round 10/50
Sampled 211 clients


100%|██████████| 211/211 [07:30<00:00,  2.13s/it]


Saved: results/gru/gru_round_10_scaffold_diff001.pt
Round 11/50
Sampled 211 clients


100%|██████████| 211/211 [07:32<00:00,  2.15s/it]


Saved: results/gru/gru_round_11_scaffold_diff001.pt
Round 12/50
Sampled 211 clients


100%|██████████| 211/211 [07:29<00:00,  2.13s/it]


Saved: results/gru/gru_round_12_scaffold_diff001.pt
Round 13/50
Sampled 211 clients


100%|██████████| 211/211 [07:29<00:00,  2.13s/it]


Saved: results/gru/gru_round_13_scaffold_diff001.pt
Round 14/50
Sampled 211 clients


100%|██████████| 211/211 [07:35<00:00,  2.16s/it]


Saved: results/gru/gru_round_14_scaffold_diff001.pt
Round 15/50
Sampled 211 clients


100%|██████████| 211/211 [07:31<00:00,  2.14s/it]


Saved: results/gru/gru_round_15_scaffold_diff001.pt
Round 16/50
Sampled 211 clients


100%|██████████| 211/211 [07:27<00:00,  2.12s/it]


Saved: results/gru/gru_round_16_scaffold_diff001.pt
Round 17/50
Sampled 211 clients


100%|██████████| 211/211 [07:29<00:00,  2.13s/it]


Saved: results/gru/gru_round_17_scaffold_diff001.pt
Round 18/50
Sampled 211 clients


100%|██████████| 211/211 [07:27<00:00,  2.12s/it]


Saved: results/gru/gru_round_18_scaffold_diff001.pt
Round 19/50
Sampled 211 clients


100%|██████████| 211/211 [07:28<00:00,  2.12s/it]


Saved: results/gru/gru_round_19_scaffold_diff001.pt
Round 20/50
Sampled 211 clients


100%|██████████| 211/211 [07:35<00:00,  2.16s/it]


Saved: results/gru/gru_round_20_scaffold_diff001.pt
Round 21/50
Sampled 211 clients


100%|██████████| 211/211 [07:31<00:00,  2.14s/it]


Saved: results/gru/gru_round_21_scaffold_diff001.pt
Round 22/50
Sampled 211 clients


100%|██████████| 211/211 [07:28<00:00,  2.13s/it]


Saved: results/gru/gru_round_22_scaffold_diff001.pt
Round 23/50
Sampled 211 clients


100%|██████████| 211/211 [07:29<00:00,  2.13s/it]


Saved: results/gru/gru_round_23_scaffold_diff001.pt
Round 24/50
Sampled 211 clients


100%|██████████| 211/211 [07:36<00:00,  2.16s/it]


Saved: results/gru/gru_round_24_scaffold_diff001.pt
Round 25/50
Sampled 211 clients


100%|██████████| 211/211 [07:34<00:00,  2.16s/it]


Saved: results/gru/gru_round_25_scaffold_diff001.pt
Round 26/50
Sampled 211 clients


100%|██████████| 211/211 [07:31<00:00,  2.14s/it]


Saved: results/gru/gru_round_26_scaffold_diff001.pt
Round 27/50
Sampled 211 clients


100%|██████████| 211/211 [07:32<00:00,  2.14s/it]


Saved: results/gru/gru_round_27_scaffold_diff001.pt
Round 28/50
Sampled 211 clients


100%|██████████| 211/211 [07:32<00:00,  2.15s/it]


Saved: results/gru/gru_round_28_scaffold_diff001.pt
Round 29/50
Sampled 211 clients


100%|██████████| 211/211 [07:41<00:00,  2.19s/it]


Saved: results/gru/gru_round_29_scaffold_diff001.pt
Round 30/50
Sampled 211 clients


100%|██████████| 211/211 [07:33<00:00,  2.15s/it]


Saved: results/gru/gru_round_30_scaffold_diff001.pt
Round 31/50
Sampled 211 clients


100%|██████████| 211/211 [07:33<00:00,  2.15s/it]


Saved: results/gru/gru_round_31_scaffold_diff001.pt
Round 32/50
Sampled 211 clients


100%|██████████| 211/211 [07:35<00:00,  2.16s/it]


Saved: results/gru/gru_round_32_scaffold_diff001.pt
Round 33/50
Sampled 211 clients


100%|██████████| 211/211 [07:35<00:00,  2.16s/it]


Saved: results/gru/gru_round_33_scaffold_diff001.pt
Round 34/50
Sampled 211 clients


100%|██████████| 211/211 [07:32<00:00,  2.15s/it]


Saved: results/gru/gru_round_34_scaffold_diff001.pt
Round 35/50
Sampled 211 clients


100%|██████████| 211/211 [07:37<00:00,  2.17s/it]


Saved: results/gru/gru_round_35_scaffold_diff001.pt
Round 36/50
Sampled 211 clients


100%|██████████| 211/211 [07:32<00:00,  2.15s/it]


Saved: results/gru/gru_round_36_scaffold_diff001.pt
Round 37/50
Sampled 211 clients


100%|██████████| 211/211 [07:26<00:00,  2.11s/it]


Saved: results/gru/gru_round_37_scaffold_diff001.pt
Round 38/50
Sampled 211 clients


100%|██████████| 211/211 [07:00<00:00,  1.99s/it]


Saved: results/gru/gru_round_38_scaffold_diff001.pt
Round 39/50
Sampled 211 clients


100%|██████████| 211/211 [07:03<00:00,  2.01s/it]


Saved: results/gru/gru_round_39_scaffold_diff001.pt
Round 40/50
Sampled 211 clients


100%|██████████| 211/211 [06:57<00:00,  1.98s/it]


Saved: results/gru/gru_round_40_scaffold_diff001.pt
Round 41/50
Sampled 211 clients


100%|██████████| 211/211 [07:07<00:00,  2.03s/it]


Saved: results/gru/gru_round_41_scaffold_diff001.pt
Round 42/50
Sampled 211 clients


100%|██████████| 211/211 [07:03<00:00,  2.01s/it]


Saved: results/gru/gru_round_42_scaffold_diff001.pt
Round 43/50
Sampled 211 clients


100%|██████████| 211/211 [07:06<00:00,  2.02s/it]


Saved: results/gru/gru_round_43_scaffold_diff001.pt
Round 44/50
Sampled 211 clients


100%|██████████| 211/211 [07:00<00:00,  1.99s/it]


Saved: results/gru/gru_round_44_scaffold_diff001.pt
Round 45/50
Sampled 211 clients


100%|██████████| 211/211 [07:07<00:00,  2.02s/it]


Saved: results/gru/gru_round_45_scaffold_diff001.pt
Round 46/50
Sampled 211 clients


100%|██████████| 211/211 [07:05<00:00,  2.02s/it]


Saved: results/gru/gru_round_46_scaffold_diff001.pt
Round 47/50
Sampled 211 clients


100%|██████████| 211/211 [07:04<00:00,  2.01s/it]


Saved: results/gru/gru_round_47_scaffold_diff001.pt
Round 48/50
Sampled 211 clients


100%|██████████| 211/211 [07:04<00:00,  2.01s/it]


Saved: results/gru/gru_round_48_scaffold_diff001.pt
Round 49/50
Sampled 211 clients


100%|██████████| 211/211 [07:04<00:00,  2.01s/it]


Saved: results/gru/gru_round_49_scaffold_diff001.pt
Round 50/50
Sampled 211 clients


100%|██████████| 211/211 [07:01<00:00,  2.00s/it]

Saved: results/gru/gru_round_50_scaffold_diff001.pt





: 

## SCAFFOLD without Clustering


In [21]:
# Federated SCAFFOLD training with a single global model (no clustering).
# NOTE: `train_model_scaffold` is defined in a later cell of this notebook; that
# cell must be executed before this one (Restart & Run All would otherwise fail).

# Server-side step size applied to the averaged client delta (1.0 = vanilla SCAFFOLD).
# Set once here rather than re-assigned inside every round.
SERVER_LR = 1.0

for model_name in MODEL_NAMES:
    print(f"Starting experiment with model: {model_name}")

    # === SINGLE global model, weights, c, and all clients use it ===
    global_model = model_fn(model_name).to(DEVICE)
    global_weights = get_weights(global_model)
    print(f"Using device: {DEVICE}")

    # torch.save does not create directories, so make sure results/<model_name>/ exists.
    ckpt_dir = os.path.join("results", model_name)
    os.makedirs(ckpt_dir, exist_ok=True)

    # SCAFFOLD control variates: server c and one c_i per client, all zero at start.
    server_c = [np.zeros_like(w) for w in global_weights]
    client_cs = {cid: [np.zeros_like(w) for w in global_weights] for cid in range(NUM_CLIENTS)}

    # Always sample at least one client per round; aggregating zero client deltas
    # would be meaningless.
    num_sampled = max(1, int(CLIENT_FRAC * NUM_CLIENTS))

    for rnd in range(NUM_ROUNDS):
        print(f"Round {rnd+1}/{NUM_ROUNDS}")

        sampled_clients = random.sample(range(NUM_CLIENTS), k=num_sampled)
        print(f"Sampled {len(sampled_clients)} clients")

        local_weight_deltas = []
        local_c_deltas = []

        for cid in tqdm(sampled_clients):
            # Every sampled client starts from the current global weights.
            local_model = model_fn(model_name).to(DEVICE)
            set_weights(local_model, global_weights)

            train_loader, _ = load_energy_data_feather(cid, filepath=DATA_FILE)

            delta_y, delta_c, new_ci, local_weights, fin_loss = train_model_scaffold(
                local_model, train_loader,
                global_weights=global_weights,
                server_c=server_c,
                client_ci=client_cs[cid],
                device=DEVICE,
                learning_rate=LR,
                loss_fn=None,
                optimizer_class=optim.Adam,
                epochs=LOCAL_EPOCHS,
            )

            local_weight_deltas.append(delta_y)
            local_c_deltas.append(delta_c)
            client_cs[cid] = new_ci

        # === Aggregate and update global weights: x <- x + SERVER_LR * avg(delta_y) ===
        # average_weights comes from AggregationStrategy (presumably an unweighted mean).
        mean_delta_y = average_weights(local_weight_deltas)
        global_weights = [
            gw + SERVER_LR * dy for gw, dy in zip(global_weights, mean_delta_y)
        ]

        # Server control variate: c <- c + (|S| / N) * avg(delta_c)
        mean_delta_c = average_weights(local_c_deltas)
        frac = len(sampled_clients) / NUM_CLIENTS
        server_c = [
            sc + frac * dc for sc, dc in zip(server_c, mean_delta_c)
        ]

        set_weights(global_model, global_weights)

        ckpt_path = os.path.join(ckpt_dir, f"{model_name}_round_{rnd+1}_scaffold_sr.pt")
        torch.save(global_model.state_dict(), ckpt_path)
        print(f"Saved: {ckpt_path}")


Starting experiment with model: lstm
Using device: cuda
Round 1/5
Sampled 211 clients


100%|██████████| 211/211 [06:35<00:00,  1.87s/it]


Saved: results/lstm/lstm_round_1_scaffold_sr.pt
Round 2/5
Sampled 211 clients


100%|██████████| 211/211 [06:46<00:00,  1.93s/it]


Saved: results/lstm/lstm_round_2_scaffold_sr.pt
Round 3/5
Sampled 211 clients


100%|██████████| 211/211 [06:54<00:00,  1.97s/it]


Saved: results/lstm/lstm_round_3_scaffold_sr.pt
Round 4/5
Sampled 211 clients


100%|██████████| 211/211 [06:49<00:00,  1.94s/it]


Saved: results/lstm/lstm_round_4_scaffold_sr.pt
Round 5/5
Sampled 211 clients


100%|██████████| 211/211 [06:46<00:00,  1.93s/it]


Saved: results/lstm/lstm_round_5_scaffold_sr.pt
Starting experiment with model: gru
Using device: cuda
Round 1/5
Sampled 211 clients


100%|██████████| 211/211 [06:39<00:00,  1.89s/it]


Saved: results/gru/gru_round_1_scaffold_sr.pt
Round 2/5
Sampled 211 clients


100%|██████████| 211/211 [06:41<00:00,  1.90s/it]


Saved: results/gru/gru_round_2_scaffold_sr.pt
Round 3/5
Sampled 211 clients


100%|██████████| 211/211 [07:07<00:00,  2.03s/it]


Saved: results/gru/gru_round_3_scaffold_sr.pt
Round 4/5
Sampled 211 clients


100%|██████████| 211/211 [07:12<00:00,  2.05s/it]


Saved: results/gru/gru_round_4_scaffold_sr.pt
Round 5/5
Sampled 211 clients


100%|██████████| 211/211 [07:15<00:00,  2.07s/it]


Saved: results/gru/gru_round_5_scaffold_sr.pt
Starting experiment with model: moe_lstm
Using device: cuda
Round 1/5
Sampled 211 clients


100%|██████████| 211/211 [09:40<00:00,  2.75s/it]


Saved: results/moe_lstm/moe_lstm_round_1_scaffold_sr.pt
Round 2/5
Sampled 211 clients


100%|██████████| 211/211 [10:01<00:00,  2.85s/it]


Saved: results/moe_lstm/moe_lstm_round_2_scaffold_sr.pt
Round 3/5
Sampled 211 clients


100%|██████████| 211/211 [10:07<00:00,  2.88s/it]


Saved: results/moe_lstm/moe_lstm_round_3_scaffold_sr.pt
Round 4/5
Sampled 211 clients


100%|██████████| 211/211 [10:11<00:00,  2.90s/it]


Saved: results/moe_lstm/moe_lstm_round_4_scaffold_sr.pt
Round 5/5
Sampled 211 clients


100%|██████████| 211/211 [08:27<00:00,  2.40s/it]


Saved: results/moe_lstm/moe_lstm_round_5_scaffold_sr.pt
Starting experiment with model: moe_gru
Using device: cuda
Round 1/5
Sampled 211 clients


100%|██████████| 211/211 [08:37<00:00,  2.45s/it]


Saved: results/moe_gru/moe_gru_round_1_scaffold_sr.pt
Round 2/5
Sampled 211 clients


100%|██████████| 211/211 [08:03<00:00,  2.29s/it]


Saved: results/moe_gru/moe_gru_round_2_scaffold_sr.pt
Round 3/5
Sampled 211 clients


100%|██████████| 211/211 [07:26<00:00,  2.12s/it]


Saved: results/moe_gru/moe_gru_round_3_scaffold_sr.pt
Round 4/5
Sampled 211 clients


100%|██████████| 211/211 [07:33<00:00,  2.15s/it]


Saved: results/moe_gru/moe_gru_round_4_scaffold_sr.pt
Round 5/5
Sampled 211 clients


100%|██████████| 211/211 [07:32<00:00,  2.14s/it]

Saved: results/moe_gru/moe_gru_round_5_scaffold_sr.pt





In [34]:
# Quick sanity check: can this kernel's torch build see a GPU, and which
# device did the experiment configuration (DEVICE) resolve to?
cuda_available = torch.cuda.is_available()
print("torch.cuda.is_available():", cuda_available)
print(f"Using device: {DEVICE}")


torch.cuda.is_available(): False
Using device: cpu


In [35]:
import sys
import torch

# Environment report: which interpreter is running and whether its torch build
# was compiled with CUDA (torch.version.cuda is None for CPU-only builds).
_env_report = (
    ("sys.executable:", sys.executable),
    ("torch.__version__:", torch.__version__),
    ("torch.version.cuda:", torch.version.cuda),
    ("torch.cuda.is_available():", torch.cuda.is_available()),
)
for _label, _value in _env_report:
    print(_label, _value)


sys.executable: /home/user/anaconda3/envs/Priyanka/bin/python
torch.__version__: 2.6.0.dev20241112
torch.version.cuda: None
torch.cuda.is_available(): False


In [27]:
def train_model_scaffold(
    local_model,
    train_loader,         # global_model, train_loader
    global_weights,    # x
    server_c,          # c
    client_ci,         # cᵢ
    device=DEVICE,
    learning_rate= LR,
    loss_fn=None,
    optimizer_class=optim.Adam,
    epochs= LOCAL_EPOCHS  # 50
):
    """Train one client with the SCAFFOLD drift correction.

    After every optimizer step, each parameter is shifted by
    ``-learning_rate * (c - c_i)`` (server minus client control variate). After
    training, the client control variate is refreshed with SCAFFOLD "option II":

        c_i+ = c_i - c + (x - y_i) / (K * learning_rate)

    Here x = ``global_weights``, y_i = the trained local weights, and K = the
    number of local optimizer steps taken.

    Args:
        local_model: nn.Module to train. The caller is expected to have loaded
            ``global_weights`` into it, because delta_y is measured against
            ``global_weights``.
        train_loader: iterable of ``(X_batch, y_batch)`` pairs.
        global_weights: x, as returned by ``get_weights`` (numpy arrays in the
            callers shown in this notebook).
        server_c: server control variate c, one array per weight, in the same order.
        client_ci: this client's control variate c_i. It is not mutated; a new
            list is returned instead.
        device: torch device. ``None`` selects CUDA when available, else CPU.
        learning_rate: local learning rate. It is also used to scale the
            correction and in the c_i update.
        loss_fn: loss module. Defaults to ``nn.MSELoss()``.
        optimizer_class: called as ``optimizer_class(params, lr=learning_rate)``.
        epochs: number of passes over ``train_loader``.

    Returns:
        A tuple ``(delta_y, delta_c, new_ci, local_weights, loss_history)``:
        - ``delta_y`` is ``y_i - x``.
        - ``delta_c`` is ``c_i+ - c_i``.
        - ``new_ci`` is the refreshed client control variate.
        - ``local_weights`` are the trained weights (``get_weights`` output).
        - ``loss_history`` holds the mean training loss per epoch. It is ``nan``
          for an epoch in which the loader yielded no batches.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    local_model.to(device)
    loss_fn = loss_fn or nn.MSELoss()
    params = list(local_model.parameters())
    optimizer = optimizer_class(params, lr=learning_rate)
    loss_history = []

    # c and c_i are constant during local training, so build the per-parameter
    # correction lr * (c - c_i) once instead of converting numpy -> tensor on
    # every batch.
    # NOTE(review): this pairing assumes get_weights() yields arrays in the same
    # order as local_model.parameters(); zip() silently truncates on a length mismatch.
    corrections = [
        learning_rate * (
            torch.tensor(sc_np, dtype=p.dtype, device=p.device)
            - torch.tensor(ci_np, dtype=p.dtype, device=p.device)
        )
        for p, sc_np, ci_np in zip(params, server_c, client_ci)
    ]

    local_model.train()
    total_steps = 0

    for _ in range(epochs):
        epoch_loss = 0.0
        n_batches = 0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            # Drop a trailing singleton target dim, e.g. (B, T, 1) -> (B, T).
            if y_batch.dim() == 3 and y_batch.shape[-1] == 1:
                y_batch = y_batch.squeeze(-1)

            optimizer.zero_grad()
            output = local_model(X_batch)
            loss = loss_fn(output, y_batch)
            loss.backward()
            optimizer.step()

            # SCAFFOLD correction applied after the regular step. With plain SGD,
            # step-then-correct equals the paper's y <- y - lr * (g - c_i + c).
            # With an adaptive optimizer such as Adam, the correction bypasses the
            # adaptive scaling, so it is only an approximation of SCAFFOLD.
            with torch.no_grad():
                for p, correction in zip(params, corrections):
                    p -= correction

            epoch_loss += loss.item()
            n_batches += 1

        total_steps += n_batches
        # Avoid ZeroDivisionError when the loader yields no batches.
        loss_history.append(epoch_loss / n_batches if n_batches else float("nan"))

    # Compute deltas
    local_weights = get_weights(local_model)
    delta_y = [lw - gw for lw, gw in zip(local_weights, global_weights)]

    # K = number of local optimizer steps actually taken.
    K = total_steps
    new_ci = []
    delta_c = []

    for gw, lw, ci, sc in zip(global_weights, local_weights, client_ci, server_c):
        if K > 0:
            # Option II: (x - y_i) / (K * lr) estimates the client's average
            # gradient (exactly so under plain SGD).
            ci_new = ci - sc + (gw - lw) / (K * learning_rate)
        else:
            # No local steps (empty loader or epochs == 0). The option-II estimate
            # would divide by zero, so keep c_i unchanged (delta_c = 0).
            ci_new = ci
        # new_ci is a fresh list; the caller's client_ci arrays are never modified.
        new_ci.append(ci_new)
        delta_c.append(ci_new - ci)

    return delta_y, delta_c, new_ci, local_weights, loss_history

In [38]:
import sys
import torch

# Re-check of the Python/torch environment: interpreter path, torch build,
# CUDA toolkit version the build was compiled against, and GPU visibility.
_env_report = (
    ("sys.executable:", sys.executable),
    ("torch.__version__:", torch.__version__),
    ("torch.version.cuda:", torch.version.cuda),
    ("torch.cuda.is_available():", torch.cuda.is_available()),
)
for _label, _value in _env_report:
    print(_label, _value)


sys.executable: /home/user/anaconda3/envs/Priyanka/bin/python
torch.__version__: 2.6.0.dev20241112
torch.version.cuda: None
torch.cuda.is_available(): False


## SCAFFOLD with Clustering

In [None]:

# # Main experiment loop
# for model_name in MODEL_NAMES:
#     print(f"Starting experiments for model: {model_name}")

#     # Initialize per-cluster model weights
#     cluster_models = {}
#     cluster_weights = {}
#     global_c = {}
#     client_cs = {}

#     for cluster_name in CLUSTERS:
#         model = model_fn(model_name).to(DEVICE) # model_fn returns batch_size, output_size
#         # create one model per cluster and store its initial weights
#         # model = model_fn(model_name)
#         # print("[DEBUG] model_fn output:", model)
#         # print("[DEBUG] type(model_fn output):", type(model))
#         # model.to(DEVICE)
#         cluster_models[cluster_name] = model
#         cluster_weights[cluster_name] = get_weights(model)

#         # Initialize c_i for each cluster
#         global_c[cluster_name] = [np.zeros_like(p) for p in cluster_weights[cluster_name]]
#         # client_cs[cluster_name] = {cid: [torch.zeros_like(p) for p in cluster_weights[cluster_name]] for cid in CLUSTERS[cluster_name]}
#         for cid in CLUSTERS[cluster_name]:
#             client_cs[cid] = [np.zeros_like(p) for p in cluster_weights[cluster_name]]

#     for rnd in range(NUM_ROUNDS):
#         print(f"\Round {rnd+1}/{NUM_ROUNDS}")

#         for cluster_name, client_ids in CLUSTERS.items():  # howw random clients id ? selection
#             print(f" Processing {cluster_name} with {len(client_ids)} clients")

#             # Sample a fraction of clients from the cluster
#             sampled_clients = random.sample(client_ids, k=int(CLIENT_FRAC * len(client_ids)))
#             print(f"Sampling {len(sampled_clients)} Clients")
#             local_weight_deltas = []
#             local_c_deltas = [] # Store local c_i deltas for each client

#             for cid in tqdm(sampled_clients, desc=f"Training {cluster_name}"):
#                 local_model = model_fn(model_name).to(DEVICE) # batch_size, output_size not this its Pythorch object an instance of nn.Module
#                 set_weights(local_model, cluster_weights[cluster_name])

#                 train_loader, test_loader = load_energy_data_feather(cid, filepath=DATA_FILE)
#                 # updated_weights, fin_loss = train_model_scaffold(
#                 #     local_model, train_loader,
#                 #     device=DEVICE,
#                 #     learning_rate=LR,
#                 #     loss_fn=None,
#                 #     optimizer_class=optim.Adam,
#                 #     epochs=LOCAL_EPOCHS
#                 # )
#                 #Local SCAFFOLD training
#                 delta_y, delta_c, new_ci, local_weights, fin_loss = train_model_scaffold(
#                     local_model, train_loader,
#                     global_weights = cluster_weights[cluster_name],  # 
#                     device=DEVICE,
#                     learning_rate=LR,
#                     loss_fn=None,
#                     optimizer_class=optim.Adam,
#                     epochs=LOCAL_EPOCHS,
#                     server_c =global_c[cluster_name],
#                     client_ci=client_cs[cid]
#                 )


#                 # local_weights.append(updated_weights)
#                 local_weight_deltas.append(delta_y)
#                 # Update client c_i
#                 client_cs[cid] = new_ci  # Update c_i for this client
#                 local_c_deltas.append(delta_c)  # Store local c_i deltas for each client)

#             # Aggregate and update cluster model
#             mean_delta_y = average_weights(local_weight_deltas)
#             SERVER_LR = 1.0  # or tune this!

#             updated_cluster_weights = [
#                 gw + SERVER_LR * dy for gw, dy in zip(cluster_weights[cluster_name], mean_delta_y)
#             ]

#             # set_weights(cluster_models[cluster_name], updated_cluster_weights)
#             # updated_cluster_weights = average_weights(local_weights)
#             set_weights(cluster_models[cluster_name], updated_cluster_weights)
#             cluster_weights[cluster_name] = updated_cluster_weights

#             mean_delta_c = average_weights(local_c_deltas, client_weights=None)

#             cluster_size = len(client_ids)
#             frac = len(sampled_clients) / cluster_size 

#             # Update global c for the cluster
#             global_c[cluster_name] = [
#                 gc + frac * dc for gc, dc in zip(global_c[cluster_name], mean_delta_c) 
#             ]


#             # Save checkpoint
#             ckpt_path = os.path.join("results", model_name, cluster_name, f"{model_name}_{cluster_name}_round_{rnd+1}_scaffold.pt")
#             torch.save(cluster_models[cluster_name].state_dict(), ckpt_path)
#             # torch.save({
#             #     'model_state': cluster_models[cluster_name].state_dict(),
#             #     'global_c': global_c[cluster_name],
#             #     'client_cs': {cid: client_cs[cid] for cid in client_ids}  
#             # }, ckpt_path)
#             # print(f"Saved model: {ckpt_path}")


Starting experiments for model: lstm
\Round 1/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0:   0%|          | 0/80 [00:00<?, ?it/s]

Training cluster_0: 100%|██████████| 80/80 [03:44<00:00,  2.81s/it]


 Processing other with 428 clients
Sampling 64 Clients


Training other: 100%|██████████| 64/64 [02:56<00:00,  2.76s/it]


 Processing cluster_1 with 269 clients
Sampling 40 Clients


Training cluster_1: 100%|██████████| 40/40 [01:52<00:00,  2.81s/it]


 Processing cluster_2 with 179 clients
Sampling 26 Clients


Training cluster_2: 100%|██████████| 26/26 [01:14<00:00,  2.85s/it]


\Round 2/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0: 100%|██████████| 80/80 [03:51<00:00,  2.89s/it]


 Processing other with 428 clients
Sampling 64 Clients


Training other: 100%|██████████| 64/64 [03:05<00:00,  2.90s/it]


 Processing cluster_1 with 269 clients
Sampling 40 Clients


Training cluster_1: 100%|██████████| 40/40 [01:57<00:00,  2.94s/it]


 Processing cluster_2 with 179 clients
Sampling 26 Clients


Training cluster_2: 100%|██████████| 26/26 [01:10<00:00,  2.72s/it]


\Round 3/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0: 100%|██████████| 80/80 [03:51<00:00,  2.89s/it]


 Processing other with 428 clients
Sampling 64 Clients


Training other: 100%|██████████| 64/64 [03:00<00:00,  2.83s/it]


 Processing cluster_1 with 269 clients
Sampling 40 Clients


Training cluster_1: 100%|██████████| 40/40 [01:54<00:00,  2.87s/it]


 Processing cluster_2 with 179 clients
Sampling 26 Clients


Training cluster_2: 100%|██████████| 26/26 [01:16<00:00,  2.94s/it]


\Round 4/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0: 100%|██████████| 80/80 [03:52<00:00,  2.90s/it]


 Processing other with 428 clients
Sampling 64 Clients


Training other: 100%|██████████| 64/64 [03:00<00:00,  2.82s/it]


 Processing cluster_1 with 269 clients
Sampling 40 Clients


Training cluster_1: 100%|██████████| 40/40 [01:58<00:00,  2.96s/it]


 Processing cluster_2 with 179 clients
Sampling 26 Clients


Training cluster_2: 100%|██████████| 26/26 [01:17<00:00,  2.99s/it]


\Round 5/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0: 100%|██████████| 80/80 [03:50<00:00,  2.89s/it]


 Processing other with 428 clients
Sampling 64 Clients


Training other: 100%|██████████| 64/64 [03:01<00:00,  2.84s/it]


 Processing cluster_1 with 269 clients
Sampling 40 Clients


Training cluster_1: 100%|██████████| 40/40 [01:57<00:00,  2.93s/it]


 Processing cluster_2 with 179 clients
Sampling 26 Clients


Training cluster_2: 100%|██████████| 26/26 [01:14<00:00,  2.85s/it]


\Round 6/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0: 100%|██████████| 80/80 [03:51<00:00,  2.89s/it]


 Processing other with 428 clients
Sampling 64 Clients


Training other: 100%|██████████| 64/64 [03:02<00:00,  2.85s/it]


 Processing cluster_1 with 269 clients
Sampling 40 Clients


Training cluster_1: 100%|██████████| 40/40 [01:54<00:00,  2.87s/it]


 Processing cluster_2 with 179 clients
Sampling 26 Clients


Training cluster_2: 100%|██████████| 26/26 [01:15<00:00,  2.90s/it]


\Round 7/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0:  61%|██████▏   | 49/80 [02:26<01:33,  3.02s/it]

In [None]:
# #  Config
# # List of models to experiment with
# MODEL_NAMES = ["moe_gru"]

# # Config
# NUM_CLIENTS = 1410
# CLIENT_FRAC = 0.15
# NUM_ROUNDS = 10
# LOCAL_EPOCHS = 10
# LR = 0.001
# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# DATA_FILE ="train_final.feather" # "meter_0_data_cleaned.feather"

In [None]:

# # Main experiment loop
# for model_name in MODEL_NAMES:
#     print(f"Starting experiments for model: {model_name}")

#     # Initialize per-cluster model weights
#     cluster_models = {}
#     cluster_weights = {}
#     global_c = {}
#     client_cs = {}

#     for cluster_name in CLUSTERS:
#         model = model_fn(model_name).to(DEVICE) # model_fn returns batch_size, output_size
#         # create one model per cluster and store its initial weights
#         cluster_models[cluster_name] = model
#         cluster_weights[cluster_name] = get_weights(model)

#         # Initialize c_i for each cluster
#         global_c[cluster_name] = [np.zeros_like(p) for p in cluster_weights[cluster_name]]
#         # client_cs[cluster_name] = {cid: [torch.zeros_like(p) for p in cluster_weights[cluster_name]] for cid in CLUSTERS[cluster_name]}
#         for cid in CLUSTERS[cluster_name]:
#             client_cs[cid] = [np.zeros_like(p) for p in cluster_weights[cluster_name]]

#     for rnd in range(NUM_ROUNDS):
#         print(f"\Round {rnd+1}/{NUM_ROUNDS}")

#         for cluster_name, client_ids in CLUSTERS.items():  # howw random clients id ? selection
#             print(f" Processing {cluster_name} with {len(client_ids)} clients")

#             # Sample a fraction of clients from the cluster
#             sampled_clients = random.sample(client_ids, k=int(CLIENT_FRAC * len(client_ids)))
#             print(f"Sampling {len(sampled_clients)} Clients")
#             local_weight_deltas = []
#             local_c_deltas = [] # Store local c_i deltas for each client

#             for cid in tqdm(sampled_clients, desc=f"Training {cluster_name}"):
#                 local_model = model_fn(model_name).to(DEVICE) # batch_size, output_size not this its Pythorch object an instance of nn.Module
#                 set_weights(local_model, cluster_weights[cluster_name])

#                 train_loader, test_loader = load_energy_data_feather(cid, filepath=DATA_FILE)
#                 # updated_weights, fin_loss = train_model_scaffold(
#                 #     local_model, train_loader,
#                 #     device=DEVICE,
#                 #     learning_rate=LR,
#                 #     loss_fn=None,
#                 #     optimizer_class=optim.Adam,
#                 #     epochs=LOCAL_EPOCHS
#                 # )
#                 #Local SCAFFOLD training
#                 delta_y, delta_c, new_ci, local_weights, fin_loss = train_model_scaffold(
#                     local_model, train_loader,
#                     global_weights = cluster_weights[cluster_name],  # 
#                     device=DEVICE,
#                     learning_rate=LR,
#                     loss_fn=None,
#                     optimizer_class=optim.Adam,
#                     epochs=LOCAL_EPOCHS,
#                     server_c =global_c[cluster_name],
#                     client_ci=client_cs[cid]
#                 )


#                 # local_weights.append(updated_weights)
#                 local_weight_deltas.append(delta_y)
#                 # Update client c_i
#                 client_cs[cid] = new_ci  # Update c_i for this client
#                 local_c_deltas.append(delta_c)  # Store local c_i deltas for each client)

#             # Aggregate and update cluster model
#             mean_delta_y = average_weights(local_weight_deltas)
#             SERVER_LR = 1.0  # or tune this!

#             updated_cluster_weights = [
#                 gw + SERVER_LR * dy for gw, dy in zip(cluster_weights[cluster_name], mean_delta_y)
#             ]

#             # set_weights(cluster_models[cluster_name], updated_cluster_weights)
#             # updated_cluster_weights = average_weights(local_weights)
#             set_weights(cluster_models[cluster_name], updated_cluster_weights)
#             cluster_weights[cluster_name] = updated_cluster_weights

#             mean_delta_c = average_weights(local_c_deltas, client_weights=None)

#             cluster_size = len(client_ids)
#             frac = len(sampled_clients) / cluster_size 

#             # Update global c for the cluster
#             global_c[cluster_name] = [
#                 gc + frac * dc for gc, dc in zip(global_c[cluster_name], mean_delta_c) 
#             ]


#             # Save checkpoint
#             ckpt_path = os.path.join("results", model_name, cluster_name, f"{model_name}_{cluster_name}_round_{rnd+1}_scaff.pt")
#             # torch.save(cluster_models[cluster_name].state_dict(), ckpt_path)
#             torch.save({
#                 'model_state': cluster_models[cluster_name].state_dict(),
#                 'global_c': global_c[cluster_name],
#                 'client_cs': {cid: client_cs[cid] for cid in client_ids}  
#             }, ckpt_path)
#             print(f"Saved model: {ckpt_path}")


Starting experiments for model: moe_gru
\Round 1/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0:   0%|          | 0/80 [00:00<?, ?it/s]

Training cluster_0: 100%|██████████| 80/80 [00:45<00:00,  1.75it/s]


Saved model: results/moe_gru/cluster_0/moe_gru_cluster_0_round_1_scaff.pt
 Processing other with 428 clients
Sampling 64 Clients


Training other: 100%|██████████| 64/64 [00:36<00:00,  1.75it/s]


Saved model: results/moe_gru/other/moe_gru_other_round_1_scaff.pt
 Processing cluster_1 with 269 clients
Sampling 40 Clients


Training cluster_1: 100%|██████████| 40/40 [00:23<00:00,  1.72it/s]


Saved model: results/moe_gru/cluster_1/moe_gru_cluster_1_round_1_scaff.pt
 Processing cluster_2 with 179 clients
Sampling 26 Clients


Training cluster_2: 100%|██████████| 26/26 [00:14<00:00,  1.75it/s]


Saved model: results/moe_gru/cluster_2/moe_gru_cluster_2_round_1_scaff.pt
\Round 2/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0: 100%|██████████| 80/80 [00:46<00:00,  1.71it/s]


Saved model: results/moe_gru/cluster_0/moe_gru_cluster_0_round_2_scaff.pt
 Processing other with 428 clients
Sampling 64 Clients


Training other: 100%|██████████| 64/64 [00:37<00:00,  1.70it/s]


Saved model: results/moe_gru/other/moe_gru_other_round_2_scaff.pt
 Processing cluster_1 with 269 clients
Sampling 40 Clients


Training cluster_1: 100%|██████████| 40/40 [00:23<00:00,  1.69it/s]


Saved model: results/moe_gru/cluster_1/moe_gru_cluster_1_round_2_scaff.pt
 Processing cluster_2 with 179 clients
Sampling 26 Clients


Training cluster_2: 100%|██████████| 26/26 [00:15<00:00,  1.69it/s]


Saved model: results/moe_gru/cluster_2/moe_gru_cluster_2_round_2_scaff.pt
\Round 3/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0: 100%|██████████| 80/80 [00:47<00:00,  1.68it/s]


Saved model: results/moe_gru/cluster_0/moe_gru_cluster_0_round_3_scaff.pt
 Processing other with 428 clients
Sampling 64 Clients


Training other: 100%|██████████| 64/64 [00:36<00:00,  1.73it/s]


Saved model: results/moe_gru/other/moe_gru_other_round_3_scaff.pt
 Processing cluster_1 with 269 clients
Sampling 40 Clients


Training cluster_1: 100%|██████████| 40/40 [00:23<00:00,  1.71it/s]


Saved model: results/moe_gru/cluster_1/moe_gru_cluster_1_round_3_scaff.pt
 Processing cluster_2 with 179 clients
Sampling 26 Clients


Training cluster_2: 100%|██████████| 26/26 [00:15<00:00,  1.64it/s]


Saved model: results/moe_gru/cluster_2/moe_gru_cluster_2_round_3_scaff.pt
\Round 4/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0: 100%|██████████| 80/80 [00:46<00:00,  1.71it/s]


Saved model: results/moe_gru/cluster_0/moe_gru_cluster_0_round_4_scaff.pt
 Processing other with 428 clients
Sampling 64 Clients


Training other: 100%|██████████| 64/64 [00:37<00:00,  1.69it/s]


Saved model: results/moe_gru/other/moe_gru_other_round_4_scaff.pt
 Processing cluster_1 with 269 clients
Sampling 40 Clients


Training cluster_1: 100%|██████████| 40/40 [00:23<00:00,  1.71it/s]


Saved model: results/moe_gru/cluster_1/moe_gru_cluster_1_round_4_scaff.pt
 Processing cluster_2 with 179 clients
Sampling 26 Clients


Training cluster_2: 100%|██████████| 26/26 [00:15<00:00,  1.73it/s]


Saved model: results/moe_gru/cluster_2/moe_gru_cluster_2_round_4_scaff.pt
\Round 5/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0: 100%|██████████| 80/80 [00:46<00:00,  1.72it/s]


Saved model: results/moe_gru/cluster_0/moe_gru_cluster_0_round_5_scaff.pt
 Processing other with 428 clients
Sampling 64 Clients


Training other: 100%|██████████| 64/64 [00:36<00:00,  1.74it/s]


Saved model: results/moe_gru/other/moe_gru_other_round_5_scaff.pt
 Processing cluster_1 with 269 clients
Sampling 40 Clients


Training cluster_1: 100%|██████████| 40/40 [00:23<00:00,  1.70it/s]


Saved model: results/moe_gru/cluster_1/moe_gru_cluster_1_round_5_scaff.pt
 Processing cluster_2 with 179 clients
Sampling 26 Clients


Training cluster_2: 100%|██████████| 26/26 [00:15<00:00,  1.64it/s]


Saved model: results/moe_gru/cluster_2/moe_gru_cluster_2_round_5_scaff.pt
\Round 6/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0: 100%|██████████| 80/80 [00:47<00:00,  1.67it/s]


Saved model: results/moe_gru/cluster_0/moe_gru_cluster_0_round_6_scaff.pt
 Processing other with 428 clients
Sampling 64 Clients


Training other: 100%|██████████| 64/64 [00:37<00:00,  1.70it/s]


Saved model: results/moe_gru/other/moe_gru_other_round_6_scaff.pt
 Processing cluster_1 with 269 clients
Sampling 40 Clients


Training cluster_1: 100%|██████████| 40/40 [00:24<00:00,  1.66it/s]


Saved model: results/moe_gru/cluster_1/moe_gru_cluster_1_round_6_scaff.pt
 Processing cluster_2 with 179 clients
Sampling 26 Clients


Training cluster_2: 100%|██████████| 26/26 [00:16<00:00,  1.62it/s]


Saved model: results/moe_gru/cluster_2/moe_gru_cluster_2_round_6_scaff.pt
\Round 7/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0: 100%|██████████| 80/80 [00:48<00:00,  1.66it/s]


Saved model: results/moe_gru/cluster_0/moe_gru_cluster_0_round_7_scaff.pt
 Processing other with 428 clients
Sampling 64 Clients


Training other: 100%|██████████| 64/64 [00:37<00:00,  1.70it/s]


Saved model: results/moe_gru/other/moe_gru_other_round_7_scaff.pt
 Processing cluster_1 with 269 clients
Sampling 40 Clients


Training cluster_1: 100%|██████████| 40/40 [00:23<00:00,  1.68it/s]


Saved model: results/moe_gru/cluster_1/moe_gru_cluster_1_round_7_scaff.pt
 Processing cluster_2 with 179 clients
Sampling 26 Clients


Training cluster_2: 100%|██████████| 26/26 [00:15<00:00,  1.66it/s]


Saved model: results/moe_gru/cluster_2/moe_gru_cluster_2_round_7_scaff.pt
\Round 8/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0: 100%|██████████| 80/80 [00:48<00:00,  1.66it/s]


Saved model: results/moe_gru/cluster_0/moe_gru_cluster_0_round_8_scaff.pt
 Processing other with 428 clients
Sampling 64 Clients


Training other: 100%|██████████| 64/64 [00:37<00:00,  1.70it/s]


Saved model: results/moe_gru/other/moe_gru_other_round_8_scaff.pt
 Processing cluster_1 with 269 clients
Sampling 40 Clients


Training cluster_1: 100%|██████████| 40/40 [00:23<00:00,  1.67it/s]


Saved model: results/moe_gru/cluster_1/moe_gru_cluster_1_round_8_scaff.pt
 Processing cluster_2 with 179 clients
Sampling 26 Clients


Training cluster_2: 100%|██████████| 26/26 [00:15<00:00,  1.63it/s]


Saved model: results/moe_gru/cluster_2/moe_gru_cluster_2_round_8_scaff.pt
\Round 9/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0: 100%|██████████| 80/80 [00:47<00:00,  1.68it/s]


Saved model: results/moe_gru/cluster_0/moe_gru_cluster_0_round_9_scaff.pt
 Processing other with 428 clients
Sampling 64 Clients


Training other: 100%|██████████| 64/64 [00:37<00:00,  1.69it/s]


Saved model: results/moe_gru/other/moe_gru_other_round_9_scaff.pt
 Processing cluster_1 with 269 clients
Sampling 40 Clients


Training cluster_1: 100%|██████████| 40/40 [00:23<00:00,  1.68it/s]


Saved model: results/moe_gru/cluster_1/moe_gru_cluster_1_round_9_scaff.pt
 Processing cluster_2 with 179 clients
Sampling 26 Clients


Training cluster_2: 100%|██████████| 26/26 [00:16<00:00,  1.61it/s]


Saved model: results/moe_gru/cluster_2/moe_gru_cluster_2_round_9_scaff.pt
\Round 10/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0: 100%|██████████| 80/80 [00:48<00:00,  1.65it/s]


Saved model: results/moe_gru/cluster_0/moe_gru_cluster_0_round_10_scaff.pt
 Processing other with 428 clients
Sampling 64 Clients


Training other: 100%|██████████| 64/64 [00:38<00:00,  1.67it/s]


Saved model: results/moe_gru/other/moe_gru_other_round_10_scaff.pt
 Processing cluster_1 with 269 clients
Sampling 40 Clients


Training cluster_1: 100%|██████████| 40/40 [00:23<00:00,  1.70it/s]


Saved model: results/moe_gru/cluster_1/moe_gru_cluster_1_round_10_scaff.pt
 Processing cluster_2 with 179 clients
Sampling 26 Clients


Training cluster_2: 100%|██████████| 26/26 [00:15<00:00,  1.65it/s]

Saved model: results/moe_gru/cluster_2/moe_gru_cluster_2_round_10_scaff.pt





In [None]:
# model_name = "lstm"  # or any other model you want to test