In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import numpy as np
from Models import MoELSTM
import os
from collections import OrderedDict
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader

from typing import List, Tuple, Optional, Dict
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error
from darts import TimeSeries
from darts.dataprocessing.transformers import Scaler
import random
from Models import MoELSTM, LSTMModel, train_model
from Preprocess import (
    compute_metrics,
    convert_timeseries_to_numpy,
    create_dataloader,
    load_building_series,
    split_series_list,
)
import pandas as pd
from collections import defaultdict
import os
import torch
import torch.optim as optim
from tqdm import tqdm


from Models import model_fn
from tqdm import tqdm
from my_utils import train_model, load_energy_data_feather, get_weights, set_weights


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from AggregationStrategy import sync_aggregate,average_weights,sync_aggregate_norm,sync_aggregate_softmax, fedavgm_update

In [4]:
# def average_weights(weights_list):
#     """Averages model weights provided as a list of get_weights outputs."""
#     avg_weights = []
#     num_models = len(weights_list)
#     for layer_weights in zip(*weights_list):
#         avg_layer = np.mean(np.array(layer_weights), axis=0)
#         avg_weights.append(avg_layer)
#     return avg_weights

In [5]:
# Load the full training table from the Feather file into a DataFrame.
df = pd.read_feather("train_final.feather")

In [6]:
# Preview the first rows to sanity-check the loaded columns.
df.head()

Unnamed: 0,building_id,meter,timestamp,meter_reading,primary_use,air_temperature
7593144,0,0,2016-05-21 01:00:00,72.221012,Education,25.6
7593145,1,0,2016-05-21 01:00:00,39.611586,Education,25.6
7593146,2,0,2016-05-21 01:00:00,1.920567,Education,25.6
7593147,3,0,2016-05-21 01:00:00,111.532464,Education,25.6
7593148,4,0,2016-05-21 01:00:00,456.734799,Education,25.6


In [7]:
# Summarise the index range, column dtypes and memory footprint.
df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 11712248 entries, 7593144 to 20216099
Data columns (total 6 columns):
 #   Column           Dtype         
---  ------           -----         
 0   building_id      int64         
 1   meter            int64         
 2   timestamp        datetime64[ns]
 3   meter_reading    float64       
 4   primary_use      object        
 5   air_temperature  float64       
dtypes: datetime64[ns](1), float64(2), int64(2), object(1)
memory usage: 625.5+ MB


In [8]:

def cluster_buildings_top3_primary_use(df: pd.DataFrame) -> dict:
    """
    Group building ids by their ``primary_use`` category.

    The most frequent ``primary_use`` values (up to three, ranked by number of
    rows in ``df``) become ``'cluster_0'``, ``'cluster_1'`` and ``'cluster_2'``.
    Every other building, including those with a missing ``primary_use``,
    goes to ``'other'``.

    Args:
        df (pd.DataFrame): Input DataFrame with 'building_id' and 'primary_use'.

    Returns:
        dict: Maps cluster name -> list of building ids. A key is present only
        when at least one building falls into it, and keys appear in the order
        in which their first building is encountered in ``df``. A building
        listed under several distinct ``primary_use`` values appears once per
        distinct pair.
    """
    # Rank categories by row count; fewer than three may exist.
    top_uses = df['primary_use'].value_counts().nlargest(3).index.tolist()

    # Category -> cluster label. Built from however many categories exist, so
    # no fixed positions are indexed (avoids IndexError with < 3 categories).
    label_for_use = {use: f"cluster_{rank}" for rank, use in enumerate(top_uses)}

    # One row per distinct (building_id, primary_use) pair.
    unique_buildings = df[['building_id', 'primary_use']].drop_duplicates()

    clusters = defaultdict(list)
    # Iterate over the two columns directly instead of iterrows(), which builds
    # a Series object for every row.
    for bldg_id, use in zip(unique_buildings['building_id'],
                            unique_buildings['primary_use']):
        clusters[label_for_use.get(use, 'other')].append(bldg_id)

    return dict(clusters)


In [9]:
# Three most frequent primary_use values, ranked by row count.
top3_uses = df['primary_use'].value_counts().nlargest(3).index.tolist()

In [10]:
# Show the ranked categories.
top3_uses

['Education', 'Office', 'Entertainment/public assembly']

In [11]:
# Build the clusters, then print each one's size and its first five building ids.
clusters = cluster_buildings_top3_primary_use(df)

for name in clusters:
    ids = clusters[name]
    print(f"{name}: {len(ids)} buildings {ids[:5]}...")


cluster_0: 537 buildings [0, 1, 2, 3, 4]...
other: 428 buildings [6, 12, 27, 33, 34]...
cluster_1: 269 buildings [9, 15, 17, 19, 21]...
cluster_2: 179 buildings [10, 59, 87, 88, 40]...


In [12]:


# --- Experiment configuration ---

# Model architectures every experiment is repeated for (passed to model_fn).
MODEL_NAMES = ["lstm", "gru", "moe_lstm", "moe_gru"]

# Federated-learning schedule.
# NOTE(review): the clustering output above lists 537 + 428 + 269 + 179 = 1413
# (building_id, primary_use) pairs, while the non-clustered experiments draw
# client ids from range(NUM_CLIENTS) — confirm 1410 is the intended count.
NUM_CLIENTS = 1410
CLIENT_FRAC = 0.15   # fraction of clients sampled in each round
NUM_ROUNDS = 10
LOCAL_EPOCHS = 10
LR = 0.001

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DATA_FILE = "train_final.feather"  # alternative: "meter_0_data_cleaned.feather"

# Seed every RNG the experiments draw from (client sampling uses `random`,
# model initialisation uses torch) so a Restart & Run All is repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


### Naive Model
Predicts each value as the reading from 24 hours earlier (lag-24). Acts as a lower-bound baseline.

### Global Models
A single model trained on all clients' data, with no federated aggregation.

In [15]:
# Keyword arguments shared by every training call in this cell.
train_kwargs = dict(
    device=DEVICE,
    learning_rate=LR,
    loss_fn=None,
    optimizer_class=optim.Adam,
    epochs=LOCAL_EPOCHS,
)

for model_name in MODEL_NAMES:
    print(f"Starting experiment with model: {model_name}")

    # Checkpoints for this architecture go under results/<model_name>/
    model_dir = os.path.join("results", model_name)
    os.makedirs(model_dir, exist_ok=True)

    # One model instance is handed to train_model for every client in turn,
    # with no aggregation step (presumably train_model updates it in place).
    global_model = model_fn(model_name).to(DEVICE)
    sampled_clients = list(range(NUM_CLIENTS))

    for cid in tqdm(sampled_clients, desc="Training clients"):
        train_loader, test_loader = load_energy_data_feather(cid, filepath=DATA_FILE)
        updated_weights, fin_loss = train_model(global_model, train_loader, **train_kwargs)

    # Saved once per architecture, after the pass over all clients.
    checkpoint_path = os.path.join(model_dir, f"{model_name}_global_model.pt")
    torch.save(global_model.state_dict(), checkpoint_path)
    print(f"Saved global model to {checkpoint_path}")

Starting experiment with model: lstm


Training clients: 100%|██████████| 1410/1410 [10:34<00:00,  2.22it/s]


Saved global model to results/lstm/lstm_global_model.pt
Starting experiment with model: gru


Training clients: 100%|██████████| 1410/1410 [10:28<00:00,  2.24it/s]


Saved global model to results/gru/gru_global_model.pt
Starting experiment with model: moe_lstm


Training clients: 100%|██████████| 1410/1410 [12:42<00:00,  1.85it/s]


Saved global model to results/moe_lstm/moe_lstm_global_model.pt
Starting experiment with model: moe_gru


Training clients: 100%|██████████| 1410/1410 [12:38<00:00,  1.86it/s]

Saved global model to results/moe_gru/moe_gru_global_model.pt





## Federated Learning Without Clustering

### FedAvg

In [None]:

# Loop invariants: per-round sample size and the keyword arguments shared by
# every local training call.
clients_per_round = int(CLIENT_FRAC * NUM_CLIENTS)
train_kwargs = dict(
    device=DEVICE,
    learning_rate=LR,
    loss_fn=None,
    optimizer_class=optim.Adam,
    epochs=LOCAL_EPOCHS,
)

for model_name in MODEL_NAMES:
    print(f"Starting experiment with model: {model_name}")

    # Checkpoints for this architecture go under results/<model_name>/
    model_dir = os.path.join("results", model_name)
    os.makedirs(model_dir, exist_ok=True)

    # Fresh global model; its weights are the starting point for round 1.
    global_model = model_fn(model_name).to(DEVICE)
    global_weights = get_weights(global_model)

    for rnd in range(NUM_ROUNDS):
        print(f"Round {rnd+1}/{NUM_ROUNDS}")
        sampled_clients = random.sample(range(NUM_CLIENTS), k=clients_per_round)
        local_weights = []

        for cid in tqdm(sampled_clients, desc="Training clients"):
            # New local model loaded with the current global weights.
            local_model = model_fn(model_name).to(DEVICE)
            set_weights(local_model, global_weights)
            train_loader, test_loader = load_energy_data_feather(cid, filepath=DATA_FILE)

            updated_weights, fin_loss = train_model(local_model, train_loader, **train_kwargs)
            local_weights.append(updated_weights)

        # FedAvg: the averaged client weights become the new global weights.
        global_weights = average_weights(local_weights)
        set_weights(global_model, global_weights)

        # One checkpoint per round.
        checkpoint_path = os.path.join(model_dir, f"{model_name}_round_{rnd+1}_fedAvg.pt")
        torch.save(global_model.state_dict(), checkpoint_path)
        print(f"Saved global model to {checkpoint_path}")


### FedAvgM 

In [None]:

# Loop invariants: per-round sample size and the keyword arguments shared by
# every local training call.
clients_per_round = int(CLIENT_FRAC * NUM_CLIENTS)
train_kwargs = dict(
    device=DEVICE,
    learning_rate=LR,
    loss_fn=None,
    optimizer_class=optim.Adam,
    epochs=LOCAL_EPOCHS,
)

for model_name in MODEL_NAMES:
    print(f"Starting experiment with model: {model_name}")

    # Checkpoints for this architecture go under results/<model_name>/
    model_dir = os.path.join("results", model_name)
    os.makedirs(model_dir, exist_ok=True)

    # Fresh global model plus a zero-initialised velocity entry per weight.
    global_model = model_fn(model_name).to(DEVICE)
    global_weights = get_weights(global_model)
    velocity = list(map(np.zeros_like, global_weights))

    for rnd in range(NUM_ROUNDS):
        print(f"Round {rnd+1}/{NUM_ROUNDS}")
        sampled_clients = random.sample(range(NUM_CLIENTS), k=clients_per_round)
        local_weights = []

        for cid in tqdm(sampled_clients, desc="Training clients"):
            # New local model loaded with the current global weights.
            local_model = model_fn(model_name).to(DEVICE)
            set_weights(local_model, global_weights)
            train_loader, test_loader = load_energy_data_feather(cid, filepath=DATA_FILE)

            updated_weights, fin_loss = train_model(local_model, train_loader, **train_kwargs)
            local_weights.append(updated_weights)

        # FedAvgM: fedavgm_update returns the new weights and the new velocity,
        # which is carried into the next round.
        global_weights, velocity = fedavgm_update(global_weights, local_weights, velocity)
        set_weights(global_model, global_weights)

        # One checkpoint per round.
        checkpoint_path = os.path.join(model_dir, f"{model_name}_round_{rnd+1}_fedAvgM.pt")
        torch.save(global_model.state_dict(), checkpoint_path)
        print(f"Saved global model to {checkpoint_path}")


Starting experiment with model: lstm
Round 1/10


Training clients: 100%|██████████| 211/211 [01:35<00:00,  2.21it/s]


Saved global model to results/lstm/lstm_round_1_fedAvgM.pt
Round 2/10


Training clients: 100%|██████████| 211/211 [01:36<00:00,  2.19it/s]


Saved global model to results/lstm/lstm_round_2_fedAvgM.pt
Round 3/10


Training clients: 100%|██████████| 211/211 [01:36<00:00,  2.20it/s]


Saved global model to results/lstm/lstm_round_3_fedAvgM.pt
Round 4/10


Training clients: 100%|██████████| 211/211 [01:35<00:00,  2.21it/s]


Saved global model to results/lstm/lstm_round_4_fedAvgM.pt
Round 5/10


Training clients: 100%|██████████| 211/211 [01:36<00:00,  2.18it/s]


Saved global model to results/lstm/lstm_round_5_fedAvgM.pt
Round 6/10


Training clients: 100%|██████████| 211/211 [01:36<00:00,  2.19it/s]


Saved global model to results/lstm/lstm_round_6_fedAvgM.pt
Round 7/10


Training clients: 100%|██████████| 211/211 [01:35<00:00,  2.20it/s]


Saved global model to results/lstm/lstm_round_7_fedAvgM.pt
Round 8/10


Training clients: 100%|██████████| 211/211 [01:35<00:00,  2.20it/s]


Saved global model to results/lstm/lstm_round_8_fedAvgM.pt
Round 9/10


Training clients: 100%|██████████| 211/211 [01:35<00:00,  2.22it/s]


Saved global model to results/lstm/lstm_round_9_fedAvgM.pt
Round 10/10


Training clients: 100%|██████████| 211/211 [01:36<00:00,  2.19it/s]


Saved global model to results/lstm/lstm_round_10_fedAvgM.pt
Starting experiment with model: gru
Round 1/10


Training clients: 100%|██████████| 211/211 [01:34<00:00,  2.22it/s]


Saved global model to results/gru/gru_round_1_fedAvgM.pt
Round 2/10


Training clients: 100%|██████████| 211/211 [01:34<00:00,  2.23it/s]


Saved global model to results/gru/gru_round_2_fedAvgM.pt
Round 3/10


Training clients: 100%|██████████| 211/211 [01:37<00:00,  2.16it/s]


Saved global model to results/gru/gru_round_3_fedAvgM.pt
Round 4/10


Training clients: 100%|██████████| 211/211 [01:34<00:00,  2.24it/s]


Saved global model to results/gru/gru_round_4_fedAvgM.pt
Round 5/10


Training clients: 100%|██████████| 211/211 [01:34<00:00,  2.24it/s]


Saved global model to results/gru/gru_round_5_fedAvgM.pt
Round 6/10


Training clients: 100%|██████████| 211/211 [01:41<00:00,  2.08it/s]


Saved global model to results/gru/gru_round_6_fedAvgM.pt
Round 7/10


Training clients: 100%|██████████| 211/211 [01:46<00:00,  1.99it/s]


Saved global model to results/gru/gru_round_7_fedAvgM.pt
Round 8/10


Training clients: 100%|██████████| 211/211 [01:46<00:00,  1.98it/s]


Saved global model to results/gru/gru_round_8_fedAvgM.pt
Round 9/10


Training clients: 100%|██████████| 211/211 [01:47<00:00,  1.95it/s]


Saved global model to results/gru/gru_round_9_fedAvgM.pt
Round 10/10


Training clients: 100%|██████████| 211/211 [01:49<00:00,  1.93it/s]


Saved global model to results/gru/gru_round_10_fedAvgM.pt
Starting experiment with model: moe_lstm
Round 1/10


Training clients: 100%|██████████| 211/211 [02:21<00:00,  1.49it/s]


Saved global model to results/moe_lstm/moe_lstm_round_1_fedAvgM.pt
Round 2/10


Training clients: 100%|██████████| 211/211 [02:22<00:00,  1.48it/s]


Saved global model to results/moe_lstm/moe_lstm_round_2_fedAvgM.pt
Round 3/10


Training clients: 100%|██████████| 211/211 [02:24<00:00,  1.46it/s]


Saved global model to results/moe_lstm/moe_lstm_round_3_fedAvgM.pt
Round 4/10


Training clients: 100%|██████████| 211/211 [02:25<00:00,  1.45it/s]


Saved global model to results/moe_lstm/moe_lstm_round_4_fedAvgM.pt
Round 5/10


Training clients: 100%|██████████| 211/211 [02:24<00:00,  1.46it/s]


Saved global model to results/moe_lstm/moe_lstm_round_5_fedAvgM.pt
Round 6/10


Training clients: 100%|██████████| 211/211 [02:23<00:00,  1.47it/s]


Saved global model to results/moe_lstm/moe_lstm_round_6_fedAvgM.pt
Round 7/10


Training clients: 100%|██████████| 211/211 [02:25<00:00,  1.45it/s]


Saved global model to results/moe_lstm/moe_lstm_round_7_fedAvgM.pt
Round 8/10


Training clients: 100%|██████████| 211/211 [02:23<00:00,  1.47it/s]


Saved global model to results/moe_lstm/moe_lstm_round_8_fedAvgM.pt
Round 9/10


Training clients: 100%|██████████| 211/211 [02:26<00:00,  1.44it/s]


Saved global model to results/moe_lstm/moe_lstm_round_9_fedAvgM.pt
Round 10/10


Training clients: 100%|██████████| 211/211 [02:24<00:00,  1.46it/s]


Saved global model to results/moe_lstm/moe_lstm_round_10_fedAvgM.pt
Starting experiment with model: moe_gru
Round 1/10


Training clients: 100%|██████████| 211/211 [02:26<00:00,  1.44it/s]


Saved global model to results/moe_gru/moe_gru_round_1_fedAvgM.pt
Round 2/10


Training clients: 100%|██████████| 211/211 [02:25<00:00,  1.45it/s]


Saved global model to results/moe_gru/moe_gru_round_2_fedAvgM.pt
Round 3/10


Training clients:  10%|▉         | 21/211 [00:14<02:00,  1.57it/s]

### FedAdam

### Kuramoto FedAvg

In [None]:

# Loop invariants: per-round sample size and the keyword arguments shared by
# every local training call.
clients_per_round = int(CLIENT_FRAC * NUM_CLIENTS)
train_kwargs = dict(
    device=DEVICE,
    learning_rate=LR,
    loss_fn=None,
    optimizer_class=optim.Adam,
    epochs=LOCAL_EPOCHS,
)

for model_name in MODEL_NAMES:
    print(f"Starting experiment with model: {model_name}")

    # Checkpoints for this architecture go under results/<model_name>/
    model_dir = os.path.join("results", model_name)
    os.makedirs(model_dir, exist_ok=True)

    # Fresh global model; its weights are the starting point for round 1.
    global_model = model_fn(model_name).to(DEVICE)
    global_weights = get_weights(global_model)

    for rnd in range(NUM_ROUNDS):
        print(f"Round {rnd+1}/{NUM_ROUNDS}")
        sampled_clients = random.sample(range(NUM_CLIENTS), k=clients_per_round)
        local_weights = []

        for cid in tqdm(sampled_clients, desc="Training clients"):
            # New local model loaded with the current global weights.
            local_model = model_fn(model_name).to(DEVICE)
            set_weights(local_model, global_weights)
            train_loader, test_loader = load_energy_data_feather(cid, filepath=DATA_FILE)

            updated_weights, fin_loss = train_model(local_model, train_loader, **train_kwargs)
            local_weights.append(updated_weights)

        # Kuramoto aggregation: sync_aggregate receives the current global
        # weights together with the client weights.
        global_weights = sync_aggregate(global_weights, local_weights)
        set_weights(global_model, global_weights)

        # One checkpoint per round.
        checkpoint_path = os.path.join(model_dir, f"{model_name}_round_{rnd+1}_kr.pt")
        torch.save(global_model.state_dict(), checkpoint_path)
        print(f"Saved global model to {checkpoint_path}")


### Kuramoto-Norm FedAvg

In [None]:

# Loop invariants: per-round sample size and the keyword arguments shared by
# every local training call.
clients_per_round = int(CLIENT_FRAC * NUM_CLIENTS)
train_kwargs = dict(
    device=DEVICE,
    learning_rate=LR,
    loss_fn=None,
    optimizer_class=optim.Adam,
    epochs=LOCAL_EPOCHS,
)

for model_name in MODEL_NAMES:
    print(f"Starting experiment with model: {model_name}")

    # Checkpoints for this architecture go under results/<model_name>/
    model_dir = os.path.join("results", model_name)
    os.makedirs(model_dir, exist_ok=True)

    # Fresh global model; its weights are the starting point for round 1.
    global_model = model_fn(model_name).to(DEVICE)
    global_weights = get_weights(global_model)

    for rnd in range(NUM_ROUNDS):
        print(f"Round {rnd+1}/{NUM_ROUNDS}")
        sampled_clients = random.sample(range(NUM_CLIENTS), k=clients_per_round)
        local_weights = []

        for cid in tqdm(sampled_clients, desc="Training clients"):
            # New local model loaded with the current global weights.
            local_model = model_fn(model_name).to(DEVICE)
            set_weights(local_model, global_weights)
            train_loader, test_loader = load_energy_data_feather(cid, filepath=DATA_FILE)

            updated_weights, fin_loss = train_model(local_model, train_loader, **train_kwargs)
            local_weights.append(updated_weights)

        # Kuramoto-Norm aggregation: sync_aggregate_norm receives the current
        # global weights together with the client weights.
        global_weights = sync_aggregate_norm(global_weights, local_weights)
        set_weights(global_model, global_weights)

        # One checkpoint per round.
        checkpoint_path = os.path.join(model_dir, f"{model_name}_round_{rnd+1}_kr_norm.pt")
        torch.save(global_model.state_dict(), checkpoint_path)
        print(f"Saved global model to {checkpoint_path}")


### Kuramoto-Softmax

In [None]:

# Loop invariants: per-round sample size and the keyword arguments shared by
# every local training call.
clients_per_round = int(CLIENT_FRAC * NUM_CLIENTS)
train_kwargs = dict(
    device=DEVICE,
    learning_rate=LR,
    loss_fn=None,
    optimizer_class=optim.Adam,
    epochs=LOCAL_EPOCHS,
)

for model_name in MODEL_NAMES:
    print(f"Starting experiment with model: {model_name}")

    # Checkpoints for this architecture go under results/<model_name>/
    model_dir = os.path.join("results", model_name)
    os.makedirs(model_dir, exist_ok=True)

    # Fresh global model; its weights are the starting point for round 1.
    global_model = model_fn(model_name).to(DEVICE)
    global_weights = get_weights(global_model)

    for rnd in range(NUM_ROUNDS):
        print(f"Round {rnd+1}/{NUM_ROUNDS}")
        sampled_clients = random.sample(range(NUM_CLIENTS), k=clients_per_round)
        local_weights = []

        for cid in tqdm(sampled_clients, desc="Training clients"):
            # New local model loaded with the current global weights.
            local_model = model_fn(model_name).to(DEVICE)
            set_weights(local_model, global_weights)
            train_loader, test_loader = load_energy_data_feather(cid, filepath=DATA_FILE)

            updated_weights, fin_loss = train_model(local_model, train_loader, **train_kwargs)
            local_weights.append(updated_weights)

        # Kuramoto-Softmax aggregation: sync_aggregate_softmax receives the
        # current global weights together with the client weights.
        global_weights = sync_aggregate_softmax(global_weights, local_weights)
        set_weights(global_model, global_weights)

        # One checkpoint per round.
        checkpoint_path = os.path.join(model_dir, f"{model_name}_round_{rnd+1}_kr_sft.pt")
        torch.save(global_model.state_dict(), checkpoint_path)
        print(f"Saved global model to {checkpoint_path}")


### DiffAware FedAvg

## Federated Learning With Clustering

In [None]:
import os
import torch
import torch.optim as optim
from tqdm import tqdm
import random
import numpy as np

# Cluster name -> building ids; shared by every clustered experiment below.
CLUSTERS = cluster_buildings_top3_primary_use(df)

# Pre-create results/<model_name>/<cluster_name>/ for every combination so the
# training cells can write checkpoints straight away.
for model_name in MODEL_NAMES:
    model_root = os.path.join("results", model_name)
    for cluster_name in CLUSTERS:
        os.makedirs(os.path.join(model_root, cluster_name), exist_ok=True)


### FedAvg

In [None]:

# Main experiment loop
# Clustered FedAvg: every cluster in CLUSTERS keeps its own model, which is
# trained only on clients sampled from that cluster and aggregated with
# average_weights.
for model_name in MODEL_NAMES:
    print(f"Starting experiments for model: {model_name}")

    # Per-cluster state: the model object and its current weight list.
    cluster_models = {}
    cluster_weights = {}

    for cluster_name in CLUSTERS:
        model = model_fn(model_name).to(DEVICE)
        cluster_models[cluster_name] = model
        cluster_weights[cluster_name] = get_weights(model)

    for rnd in range(NUM_ROUNDS):
        # "\n" (not "\R", an invalid escape that printed a literal backslash)
        # puts a blank line between rounds.
        print(f"\nRound {rnd+1}/{NUM_ROUNDS}")

        for cluster_name, client_ids in CLUSTERS.items():
            print(f" Processing {cluster_name} with {len(client_ids)} clients")

            # Sample a fraction of clients from the cluster
            sampled_clients = random.sample(client_ids, k=int(CLIENT_FRAC * len(client_ids)))
            print(f"Sampling {len(sampled_clients)} Clients")
            local_weights = []

            for cid in tqdm(sampled_clients, desc=f"Training {cluster_name}"):
                # New local model loaded with this cluster's current weights.
                local_model = model_fn(model_name).to(DEVICE)
                set_weights(local_model, cluster_weights[cluster_name])

                train_loader, test_loader = load_energy_data_feather(cid, filepath=DATA_FILE)
                updated_weights, fin_loss = train_model(
                    local_model, train_loader,
                    device=DEVICE,
                    learning_rate=LR,
                    loss_fn=None,
                    optimizer_class=optim.Adam,
                    epochs=LOCAL_EPOCHS
                )
                local_weights.append(updated_weights)

            # Aggregate and update cluster model
            updated_cluster_weights = average_weights(local_weights)
            set_weights(cluster_models[cluster_name], updated_cluster_weights)
            cluster_weights[cluster_name] = updated_cluster_weights

            # Save one checkpoint per cluster per round
            ckpt_path = os.path.join("results", model_name, cluster_name, f"{model_name}_{cluster_name}_round_{rnd+1}.pt")
            torch.save(cluster_models[cluster_name].state_dict(), ckpt_path)
            print(f"Saved model: {ckpt_path}")


### FedAvgM

In [None]:
import os
import torch
import torch.optim as optim
from tqdm import tqdm

# Main experiment loop
# Keyword arguments shared by every local training call in this cell.
train_kwargs = dict(
    device=DEVICE,
    learning_rate=LR,
    loss_fn=None,
    optimizer_class=optim.Adam,
    epochs=LOCAL_EPOCHS,
)

for model_name in MODEL_NAMES:
    print(f"Starting experiments for model: {model_name}")

    # Per-cluster state: model, current weights, and FedAvgM velocity.
    cluster_models, cluster_weights, cluster_velocities = {}, {}, {}

    for cluster_name in CLUSTERS:
        model = model_fn(model_name).to(DEVICE)
        weights = get_weights(model)
        velocity = list(map(np.zeros_like, weights))  # zero velocity to start

        cluster_models[cluster_name] = model
        cluster_weights[cluster_name] = weights
        cluster_velocities[cluster_name] = velocity

    for rnd in range(NUM_ROUNDS):
        print(f"\nRound {rnd+1}/{NUM_ROUNDS}")

        for cluster_name, client_ids in CLUSTERS.items():
            print(f" Processing {cluster_name} with {len(client_ids)} clients")

            # Draw this round's participants from the cluster's own clients.
            sampled_clients = random.sample(client_ids, k=int(CLIENT_FRAC * len(client_ids)))
            print(f"Sampling {len(sampled_clients)} Clients")
            local_weights = []

            for cid in tqdm(sampled_clients, desc=f"Training {cluster_name}"):
                # New local model loaded with this cluster's current weights.
                local_model = model_fn(model_name).to(DEVICE)
                set_weights(local_model, cluster_weights[cluster_name])

                train_loader, test_loader = load_energy_data_feather(cid, filepath=DATA_FILE)
                updated_weights, fin_loss = train_model(local_model, train_loader, **train_kwargs)
                local_weights.append(updated_weights)

            # FedAvgM step: fedavgm_update returns the new weights and velocity.
            new_weights, new_velocity = fedavgm_update(
                cluster_weights[cluster_name],
                local_weights,
                cluster_velocities[cluster_name],
            )

            # Store the result so the next round starts from it.
            set_weights(cluster_models[cluster_name], new_weights)
            cluster_weights[cluster_name] = new_weights
            cluster_velocities[cluster_name] = new_velocity

            # One checkpoint per cluster per round.
            ckpt_dir = os.path.join("results", model_name, cluster_name)
            os.makedirs(ckpt_dir, exist_ok=True)
            ckpt_path = os.path.join(ckpt_dir, f"{model_name}_{cluster_name}_round_{rnd+1}_fedAvgM.pt")
            torch.save(cluster_models[cluster_name].state_dict(), ckpt_path)
            print(f"Saved model: {ckpt_path}")


### FedAdam

In [None]:

# Main experiment loop
# Clustered FedAdam: the averaged client update is treated as a pseudo-gradient
# and applied with a server-side Adam step, one optimiser state per cluster.
#
# This cell used to be a verbatim copy of the FedAvg cell: it aggregated with
# plain averaging and wrote to the same checkpoint paths, overwriting the
# FedAvg results. Checkpoints now carry a `_fedAdam` suffix.
#
# NOTE(review): the server hyper-parameters are untuned starting values —
# confirm them before reporting results.
SERVER_LR = 0.01   # server step size
BETA_1 = 0.9       # decay of the first-moment estimate
BETA_2 = 0.99      # decay of the second-moment estimate
TAU = 1e-3         # keeps the denominator away from zero

for model_name in MODEL_NAMES:
    print(f"Starting experiments for model: {model_name}")

    # Per-cluster state: model, current weights, and the two Adam moments.
    cluster_models = {}
    cluster_weights = {}
    cluster_first_moments = {}
    cluster_second_moments = {}

    for cluster_name in CLUSTERS:
        model = model_fn(model_name).to(DEVICE)
        initial_weights = get_weights(model)
        cluster_models[cluster_name] = model
        cluster_weights[cluster_name] = initial_weights
        # assumes get_weights returns a list of float NumPy arrays — TODO confirm
        cluster_first_moments[cluster_name] = [np.zeros_like(w) for w in initial_weights]
        cluster_second_moments[cluster_name] = [np.zeros_like(w) for w in initial_weights]

    for rnd in range(NUM_ROUNDS):
        # "\n" (not "\R", an invalid escape that printed a literal backslash)
        print(f"\nRound {rnd+1}/{NUM_ROUNDS}")

        for cluster_name, client_ids in CLUSTERS.items():
            print(f" Processing {cluster_name} with {len(client_ids)} clients")

            # Sample a fraction of clients from the cluster
            sampled_clients = random.sample(client_ids, k=int(CLIENT_FRAC * len(client_ids)))
            print(f"Sampling {len(sampled_clients)} Clients")
            local_weights = []

            for cid in tqdm(sampled_clients, desc=f"Training {cluster_name}"):
                # New local model loaded with this cluster's current weights.
                local_model = model_fn(model_name).to(DEVICE)
                set_weights(local_model, cluster_weights[cluster_name])

                train_loader, test_loader = load_energy_data_feather(cid, filepath=DATA_FILE)
                updated_weights, fin_loss = train_model(
                    local_model, train_loader,
                    device=DEVICE,
                    learning_rate=LR,
                    loss_fn=None,
                    optimizer_class=optim.Adam,
                    epochs=LOCAL_EPOCHS
                )
                local_weights.append(updated_weights)

            # ---- FedAdam aggregation ----
            averaged_weights = average_weights(local_weights)
            updated_cluster_weights = []
            new_first_moments = []
            new_second_moments = []
            for w, w_avg, m, v in zip(
                cluster_weights[cluster_name],
                averaged_weights,
                cluster_first_moments[cluster_name],
                cluster_second_moments[cluster_name],
            ):
                # Pseudo-gradient: how far the client average moved this layer.
                delta = w_avg - w
                m = BETA_1 * m + (1.0 - BETA_1) * delta
                v = BETA_2 * v + (1.0 - BETA_2) * np.square(delta)
                step = SERVER_LR * m / (np.sqrt(v) + TAU)
                # Keep the layer's original dtype so set_weights sees the same types.
                updated_cluster_weights.append(np.asarray(w + step, dtype=np.asarray(w).dtype))
                new_first_moments.append(m)
                new_second_moments.append(v)

            # Update model, weights and optimiser state for the next round
            set_weights(cluster_models[cluster_name], updated_cluster_weights)
            cluster_weights[cluster_name] = updated_cluster_weights
            cluster_first_moments[cluster_name] = new_first_moments
            cluster_second_moments[cluster_name] = new_second_moments

            # Save checkpoint under a FedAdam-specific name
            ckpt_path = os.path.join("results", model_name, cluster_name, f"{model_name}_{cluster_name}_round_{rnd+1}_fedAdam.pt")
            torch.save(cluster_models[cluster_name].state_dict(), ckpt_path)
            print(f"Saved model: {ckpt_path}")


### Kuramoto FedAvg

In [None]:

# Main experiment loop
# Clustered Kuramoto FedAvg: every cluster in CLUSTERS keeps its own model,
# trained on clients sampled from that cluster and aggregated with
# sync_aggregate.
for model_name in MODEL_NAMES:
    print(f"Starting experiments for model: {model_name}")

    # Per-cluster state: the model object and its current weight list.
    cluster_models = {}
    cluster_weights = {}

    for cluster_name in CLUSTERS:
        model = model_fn(model_name).to(DEVICE)
        cluster_models[cluster_name] = model
        cluster_weights[cluster_name] = get_weights(model)

    for rnd in range(NUM_ROUNDS):
        # "\n" (not "\R", an invalid escape that printed a literal backslash)
        print(f"\nRound {rnd+1}/{NUM_ROUNDS}")

        for cluster_name, client_ids in CLUSTERS.items():
            print(f" Processing {cluster_name} with {len(client_ids)} clients")

            # Sample a fraction of clients from the cluster
            sampled_clients = random.sample(client_ids, k=int(CLIENT_FRAC * len(client_ids)))
            print(f"Sampling {len(sampled_clients)} Clients")
            local_weights = []

            for cid in tqdm(sampled_clients, desc=f"Training {cluster_name}"):
                # New local model loaded with this cluster's current weights.
                local_model = model_fn(model_name).to(DEVICE)
                set_weights(local_model, cluster_weights[cluster_name])

                train_loader, test_loader = load_energy_data_feather(cid, filepath=DATA_FILE)
                updated_weights, fin_loss = train_model(
                    local_model, train_loader,
                    device=DEVICE,
                    learning_rate=LR,
                    loss_fn=None,
                    optimizer_class=optim.Adam,
                    epochs=LOCAL_EPOCHS
                )
                local_weights.append(updated_weights)

            # Aggregate against THIS cluster's current weights. Passing
            # `updated_cluster_weights` here used a variable that is undefined
            # on a fresh kernel and otherwise holds whatever cluster or model
            # was aggregated last.
            updated_cluster_weights = sync_aggregate(cluster_weights[cluster_name], local_weights)
            set_weights(cluster_models[cluster_name], updated_cluster_weights)
            cluster_weights[cluster_name] = updated_cluster_weights

            # Save one checkpoint per cluster per round
            ckpt_path = os.path.join("results", model_name, cluster_name, f"{model_name}_{cluster_name}_round_{rnd+1}_kr.pt")
            torch.save(cluster_models[cluster_name].state_dict(), ckpt_path)
            print(f"Saved model: {ckpt_path}")


### Kuramoto Softmax FedAvg

In [None]:

# Main experiment loop
# Clustered Kuramoto-Softmax FedAvg: every cluster in CLUSTERS keeps its own
# model, trained on clients sampled from that cluster and aggregated with
# sync_aggregate_softmax.
for model_name in MODEL_NAMES:
    print(f"Starting experiments for model: {model_name}")

    # Per-cluster state: the model object and its current weight list.
    cluster_models = {}
    cluster_weights = {}

    for cluster_name in CLUSTERS:
        model = model_fn(model_name).to(DEVICE)
        cluster_models[cluster_name] = model
        cluster_weights[cluster_name] = get_weights(model)

    for rnd in range(NUM_ROUNDS):
        # "\n" (not "\R", an invalid escape that printed a literal backslash)
        print(f"\nRound {rnd+1}/{NUM_ROUNDS}")

        for cluster_name, client_ids in CLUSTERS.items():
            print(f" Processing {cluster_name} with {len(client_ids)} clients")

            # Sample a fraction of clients from the cluster
            sampled_clients = random.sample(client_ids, k=int(CLIENT_FRAC * len(client_ids)))
            print(f"Sampling {len(sampled_clients)} Clients")
            local_weights = []

            for cid in tqdm(sampled_clients, desc=f"Training {cluster_name}"):
                # New local model loaded with this cluster's current weights.
                local_model = model_fn(model_name).to(DEVICE)
                set_weights(local_model, cluster_weights[cluster_name])

                train_loader, test_loader = load_energy_data_feather(cid, filepath=DATA_FILE)
                updated_weights, fin_loss = train_model(
                    local_model, train_loader,
                    device=DEVICE,
                    learning_rate=LR,
                    loss_fn=None,
                    optimizer_class=optim.Adam,
                    epochs=LOCAL_EPOCHS
                )
                local_weights.append(updated_weights)

            # Aggregate against THIS cluster's current weights. Passing
            # `updated_cluster_weights` here used a variable that is undefined
            # on a fresh kernel and otherwise holds whatever cluster or model
            # was aggregated last.
            updated_cluster_weights = sync_aggregate_softmax(cluster_weights[cluster_name], local_weights)
            set_weights(cluster_models[cluster_name], updated_cluster_weights)
            cluster_weights[cluster_name] = updated_cluster_weights

            # Save one checkpoint per cluster per round
            ckpt_path = os.path.join("results", model_name, cluster_name, f"{model_name}_{cluster_name}_round_{rnd+1}_kr_softmax.pt")
            torch.save(cluster_models[cluster_name].state_dict(), ckpt_path)
            print(f"Saved model: {ckpt_path}")


### Kuramoto Norm FedAvg

In [None]:
# DEVICE
# Main experiment loop
# Clustered Kuramoto-Norm FedAvg: every cluster in CLUSTERS keeps its own
# model, trained on clients sampled from that cluster and aggregated with
# sync_aggregate_norm.
for model_name in MODEL_NAMES:
    print(f"Starting experiments for model: {model_name}")

    # Per-cluster state: the model object and its current weight list.
    cluster_models = {}
    cluster_weights = {}

    for cluster_name in CLUSTERS:
        model = model_fn(model_name).to(DEVICE)
        cluster_models[cluster_name] = model
        cluster_weights[cluster_name] = get_weights(model)

    for rnd in range(NUM_ROUNDS):
        # "\n" (not "\R", an invalid escape that printed a literal backslash)
        print(f"\nRound {rnd+1}/{NUM_ROUNDS}")

        for cluster_name, client_ids in CLUSTERS.items():
            print(f" Processing {cluster_name} with {len(client_ids)} clients")

            # Sample a fraction of clients from the cluster
            sampled_clients = random.sample(client_ids, k=int(CLIENT_FRAC * len(client_ids)))
            print(f"Sampling {len(sampled_clients)} Clients")
            local_weights = []

            for cid in tqdm(sampled_clients, desc=f"Training {cluster_name}"):
                # New local model loaded with this cluster's current weights.
                local_model = model_fn(model_name).to(DEVICE)
                set_weights(local_model, cluster_weights[cluster_name])

                train_loader, test_loader = load_energy_data_feather(cid, filepath=DATA_FILE)
                updated_weights, fin_loss = train_model(
                    local_model, train_loader,
                    device=DEVICE,
                    learning_rate=LR,
                    loss_fn=None,
                    optimizer_class=optim.Adam,
                    epochs=LOCAL_EPOCHS
                )
                local_weights.append(updated_weights)

            # Two fixes here:
            #  * use sync_aggregate_norm — this cell previously called the
            #    softmax variant while saving under the `_kr_norm` suffix;
            #  * aggregate against THIS cluster's current weights rather than
            #    `updated_cluster_weights`, which is undefined on a fresh kernel
            #    and otherwise belongs to the last cluster or model aggregated.
            updated_cluster_weights = sync_aggregate_norm(cluster_weights[cluster_name], local_weights)
            set_weights(cluster_models[cluster_name], updated_cluster_weights)
            cluster_weights[cluster_name] = updated_cluster_weights

            # Save one checkpoint per cluster per round
            ckpt_path = os.path.join("results", model_name, cluster_name, f"{model_name}_{cluster_name}_round_{rnd+1}_kr_norm.pt")
            torch.save(cluster_models[cluster_name].state_dict(), ckpt_path)
            print(f"Saved model: {ckpt_path}")


'cuda'

### Clustered

Starting experiments for model: lstm
\Round 1/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0: 100%|██████████| 80/80 [00:36<00:00,  2.19it/s]


Saved model: results/lstm/cluster_0/lstm_cluster_0_round_1.pt
 Processing other with 428 clients
Sampling 64 Clients


Training other: 100%|██████████| 64/64 [00:28<00:00,  2.24it/s]


Saved model: results/lstm/other/lstm_other_round_1.pt
 Processing cluster_1 with 269 clients
Sampling 40 Clients


Training cluster_1: 100%|██████████| 40/40 [00:17<00:00,  2.22it/s]


Saved model: results/lstm/cluster_1/lstm_cluster_1_round_1.pt
 Processing cluster_2 with 179 clients
Sampling 26 Clients


Training cluster_2: 100%|██████████| 26/26 [00:11<00:00,  2.17it/s]


Saved model: results/lstm/cluster_2/lstm_cluster_2_round_1.pt
\Round 2/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0: 100%|██████████| 80/80 [00:35<00:00,  2.25it/s]


Saved model: results/lstm/cluster_0/lstm_cluster_0_round_2.pt
 Processing other with 428 clients
Sampling 64 Clients


Training other: 100%|██████████| 64/64 [00:28<00:00,  2.27it/s]


Saved model: results/lstm/other/lstm_other_round_2.pt
 Processing cluster_1 with 269 clients
Sampling 40 Clients


Training cluster_1: 100%|██████████| 40/40 [00:18<00:00,  2.21it/s]


Saved model: results/lstm/cluster_1/lstm_cluster_1_round_2.pt
 Processing cluster_2 with 179 clients
Sampling 26 Clients


Training cluster_2: 100%|██████████| 26/26 [00:12<00:00,  2.14it/s]


Saved model: results/lstm/cluster_2/lstm_cluster_2_round_2.pt
\Round 3/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0: 100%|██████████| 80/80 [00:35<00:00,  2.24it/s]


Saved model: results/lstm/cluster_0/lstm_cluster_0_round_3.pt
 Processing other with 428 clients
Sampling 64 Clients


Training other: 100%|██████████| 64/64 [00:28<00:00,  2.22it/s]


Saved model: results/lstm/other/lstm_other_round_3.pt
 Processing cluster_1 with 269 clients
Sampling 40 Clients


Training cluster_1: 100%|██████████| 40/40 [00:18<00:00,  2.15it/s]


Saved model: results/lstm/cluster_1/lstm_cluster_1_round_3.pt
 Processing cluster_2 with 179 clients
Sampling 26 Clients


Training cluster_2: 100%|██████████| 26/26 [00:11<00:00,  2.23it/s]


Saved model: results/lstm/cluster_2/lstm_cluster_2_round_3.pt
\Round 4/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0: 100%|██████████| 80/80 [00:36<00:00,  2.19it/s]


Saved model: results/lstm/cluster_0/lstm_cluster_0_round_4.pt
 Processing other with 428 clients
Sampling 64 Clients


Training other: 100%|██████████| 64/64 [00:29<00:00,  2.16it/s]


Saved model: results/lstm/other/lstm_other_round_4.pt
 Processing cluster_1 with 269 clients
Sampling 40 Clients


Training cluster_1: 100%|██████████| 40/40 [00:18<00:00,  2.20it/s]


Saved model: results/lstm/cluster_1/lstm_cluster_1_round_4.pt
 Processing cluster_2 with 179 clients
Sampling 26 Clients


Training cluster_2: 100%|██████████| 26/26 [00:12<00:00,  2.12it/s]


Saved model: results/lstm/cluster_2/lstm_cluster_2_round_4.pt
\Round 5/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0: 100%|██████████| 80/80 [00:37<00:00,  2.15it/s]


Saved model: results/lstm/cluster_0/lstm_cluster_0_round_5.pt
 Processing other with 428 clients
Sampling 64 Clients


Training other: 100%|██████████| 64/64 [00:29<00:00,  2.19it/s]


Saved model: results/lstm/other/lstm_other_round_5.pt
 Processing cluster_1 with 269 clients
Sampling 40 Clients


Training cluster_1: 100%|██████████| 40/40 [00:18<00:00,  2.21it/s]


Saved model: results/lstm/cluster_1/lstm_cluster_1_round_5.pt
 Processing cluster_2 with 179 clients
Sampling 26 Clients


Training cluster_2: 100%|██████████| 26/26 [00:11<00:00,  2.19it/s]


Saved model: results/lstm/cluster_2/lstm_cluster_2_round_5.pt
\Round 6/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0: 100%|██████████| 80/80 [00:36<00:00,  2.17it/s]


Saved model: results/lstm/cluster_0/lstm_cluster_0_round_6.pt
 Processing other with 428 clients
Sampling 64 Clients


Training other: 100%|██████████| 64/64 [00:29<00:00,  2.20it/s]


Saved model: results/lstm/other/lstm_other_round_6.pt
 Processing cluster_1 with 269 clients
Sampling 40 Clients


Training cluster_1: 100%|██████████| 40/40 [00:18<00:00,  2.17it/s]


Saved model: results/lstm/cluster_1/lstm_cluster_1_round_6.pt
 Processing cluster_2 with 179 clients
Sampling 26 Clients


Training cluster_2: 100%|██████████| 26/26 [00:11<00:00,  2.17it/s]


Saved model: results/lstm/cluster_2/lstm_cluster_2_round_6.pt
\Round 7/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0: 100%|██████████| 80/80 [00:37<00:00,  2.15it/s]


Saved model: results/lstm/cluster_0/lstm_cluster_0_round_7.pt
 Processing other with 428 clients
Sampling 64 Clients


Training other: 100%|██████████| 64/64 [00:28<00:00,  2.25it/s]


Saved model: results/lstm/other/lstm_other_round_7.pt
 Processing cluster_1 with 269 clients
Sampling 40 Clients


Training cluster_1: 100%|██████████| 40/40 [00:18<00:00,  2.20it/s]


Saved model: results/lstm/cluster_1/lstm_cluster_1_round_7.pt
 Processing cluster_2 with 179 clients
Sampling 26 Clients


Training cluster_2: 100%|██████████| 26/26 [00:11<00:00,  2.17it/s]


Saved model: results/lstm/cluster_2/lstm_cluster_2_round_7.pt
\Round 8/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0: 100%|██████████| 80/80 [00:36<00:00,  2.22it/s]


Saved model: results/lstm/cluster_0/lstm_cluster_0_round_8.pt
 Processing other with 428 clients
Sampling 64 Clients


Training other: 100%|██████████| 64/64 [00:29<00:00,  2.20it/s]


Saved model: results/lstm/other/lstm_other_round_8.pt
 Processing cluster_1 with 269 clients
Sampling 40 Clients


Training cluster_1: 100%|██████████| 40/40 [00:18<00:00,  2.19it/s]


Saved model: results/lstm/cluster_1/lstm_cluster_1_round_8.pt
 Processing cluster_2 with 179 clients
Sampling 26 Clients


Training cluster_2: 100%|██████████| 26/26 [00:11<00:00,  2.25it/s]


Saved model: results/lstm/cluster_2/lstm_cluster_2_round_8.pt
\Round 9/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0: 100%|██████████| 80/80 [00:35<00:00,  2.22it/s]


Saved model: results/lstm/cluster_0/lstm_cluster_0_round_9.pt
 Processing other with 428 clients
Sampling 64 Clients


Training other: 100%|██████████| 64/64 [00:28<00:00,  2.23it/s]


Saved model: results/lstm/other/lstm_other_round_9.pt
 Processing cluster_1 with 269 clients
Sampling 40 Clients


Training cluster_1: 100%|██████████| 40/40 [00:18<00:00,  2.18it/s]


Saved model: results/lstm/cluster_1/lstm_cluster_1_round_9.pt
 Processing cluster_2 with 179 clients
Sampling 26 Clients


Training cluster_2: 100%|██████████| 26/26 [00:11<00:00,  2.21it/s]


Saved model: results/lstm/cluster_2/lstm_cluster_2_round_9.pt
\Round 10/10
 Processing cluster_0 with 537 clients
Sampling 80 Clients


Training cluster_0: 100%|██████████| 80/80 [00:36<00:00,  2.17it/s]


Saved model: results/lstm/cluster_0/lstm_cluster_0_round_10.pt
 Processing other with 428 clients
Sampling 64 Clients


Training other: 100%|██████████| 64/64 [00:29<00:00,  2.20it/s]


Saved model: results/lstm/other/lstm_other_round_10.pt
 Processing cluster_1 with 269 clients
Sampling 40 Clients


Training cluster_1: 100%|██████████| 40/40 [00:18<00:00,  2.20it/s]


Saved model: results/lstm/cluster_1/lstm_cluster_1_round_10.pt
 Processing cluster_2 with 179 clients
Sampling 26 Clients


Training cluster_2:  19%|█▉        | 5/26 [00:02<00:10,  1.94it/s]

### FedAvg

In [None]:
import os
import torch
import torch.optim as optim
from tqdm import tqdm


# # Weighted averaging function (uniform weights for now)
# def average_weights(weights_list):
#     return [np.mean(np.stack(layer_weights), axis=0) for layer_weights in zip(*weights_list)]

# model_fn, get_weights and set_weights come from the import cells at the top of the notebook.

# FedAvg experiment: one global model per architecture, aggregated with
# average_weights after every communication round.

# Keyword arguments shared by every local training call in this cell.
train_kwargs = dict(
    device=DEVICE,
    learning_rate=LR,
    loss_fn=None,
    optimizer_class=optim.Adam,
    epochs=LOCAL_EPOCHS,
)

for model_name in MODEL_NAMES:
    print(f"Starting experiment with model: {model_name}")

    # Checkpoints for this architecture are written to results/<model_name>.
    model_dir = os.path.join("results", model_name)
    os.makedirs(model_dir, exist_ok=True)

    # Fresh global model; its weights are the starting point for round 1.
    global_model = model_fn(model_name).to(DEVICE)
    global_weights = get_weights(global_model)

    for rnd in range(NUM_ROUNDS):
        print(f"Round {rnd+1}/{NUM_ROUNDS}")

        # Pick this round's participants from the client id range.
        sampled_clients = random.sample(
            range(NUM_CLIENTS),
            k=int(CLIENT_FRAC * NUM_CLIENTS),
        )

        local_weights = []
        for cid in tqdm(sampled_clients, desc="Training clients"):
            # Each client model is loaded with the current global weights.
            local_model = model_fn(model_name).to(DEVICE)
            set_weights(local_model, global_weights)

            # Only train_loader is used below; test_loader is not consumed here.
            train_loader, test_loader = load_energy_data_feather(cid, filepath=DATA_FILE)

            updated_weights, fin_loss = train_model(local_model, train_loader, **train_kwargs)
            local_weights.append(updated_weights)

        # Merge the client results and push them into the global model.
        global_weights = average_weights(local_weights)
        set_weights(global_model, global_weights)

        # Persist the global model for this round.
        checkpoint_name = f"{model_name}_round_{rnd+1}_fedAvg.pt"
        checkpoint_path = os.path.join(model_dir, checkpoint_name)
        torch.save(global_model.state_dict(), checkpoint_path)
        print(f"Saved global model to {checkpoint_path}")


Starting experiment with model: lstm
Round 1/10


Training clients:   0%|          | 0/150 [00:00<?, ?it/s]

Training clients: 100%|██████████| 150/150 [00:48<00:00,  3.07it/s]


Saved global model to results/lstm/lstm_round_1_fedAvg.pt
Round 2/10


Training clients: 100%|██████████| 150/150 [00:47<00:00,  3.14it/s]


Saved global model to results/lstm/lstm_round_2_fedAvg.pt
Round 3/10


Training clients: 100%|██████████| 150/150 [00:48<00:00,  3.11it/s]


Saved global model to results/lstm/lstm_round_3_fedAvg.pt
Round 4/10


Training clients: 100%|██████████| 150/150 [00:48<00:00,  3.12it/s]


Saved global model to results/lstm/lstm_round_4_fedAvg.pt
Round 5/10


Training clients: 100%|██████████| 150/150 [00:47<00:00,  3.14it/s]


Saved global model to results/lstm/lstm_round_5_fedAvg.pt
Round 6/10


Training clients:  97%|█████████▋| 145/150 [00:46<00:01,  2.96it/s]

### FedAvgM

In [None]:


# Experiment configuration used by the training loops in this notebook.

# Architectures to run; each name is passed to model_fn by the loops.
MODEL_NAMES = [
    "lstm",
    "gru",
    "moe_lstm",
    "moe_gru",
]

# Federated schedule
NUM_CLIENTS = 1_000   # client ids are sampled from range(NUM_CLIENTS)
CLIENT_FRAC = 0.15    # fraction of clients sampled per round
NUM_ROUNDS = 10       # communication rounds per architecture
LOCAL_EPOCHS = 10     # epochs handed to train_model for each client
LR = 1e-3             # learning rate handed to train_model

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Feather file given to load_energy_data_feather
# (previous alternative: "meter_0_data_cleaned.feather").
DATA_FILE = "train_cleaned_reindex.feather"


In [None]:
import os
import torch
import torch.optim as optim
from tqdm import tqdm


# # Weighted averaging function (uniform weights for now)
# def average_weights(weights_list):
#     return [np.mean(np.stack(layer_weights), axis=0) for layer_weights in zip(*weights_list)]

# model_fn, get_weights and set_weights come from the import cells at the top of the notebook.

# FedAvgM experiment: same round structure as the FedAvg cell, but the global
# weights are updated through fedavgm_update, which also carries a velocity
# state from one round to the next.

# Keyword arguments shared by every local training call in this cell.
train_kwargs = dict(
    device=DEVICE,
    learning_rate=LR,
    loss_fn=None,
    optimizer_class=optim.Adam,
    epochs=LOCAL_EPOCHS,
)

for model_name in MODEL_NAMES:
    print(f"Starting experiment with model: {model_name}")

    # Checkpoints for this architecture are written to results/<model_name>.
    model_dir = os.path.join("results", model_name)
    os.makedirs(model_dir, exist_ok=True)

    # Fresh global model; its weights are the starting point for round 1.
    global_model = model_fn(model_name).to(DEVICE)
    global_weights = get_weights(global_model)

    # One zero-filled array per entry of global_weights, reset per architecture.
    velocity = list(map(np.zeros_like, global_weights))

    for rnd in range(NUM_ROUNDS):
        print(f"Round {rnd+1}/{NUM_ROUNDS}")

        # Pick this round's participants from the client id range.
        sampled_clients = random.sample(
            range(NUM_CLIENTS),
            k=int(CLIENT_FRAC * NUM_CLIENTS),
        )

        local_weights = []
        for cid in tqdm(sampled_clients, desc="Training clients"):
            # Each client model is loaded with the current global weights.
            local_model = model_fn(model_name).to(DEVICE)
            set_weights(local_model, global_weights)

            # Only train_loader is used below; test_loader is not consumed here.
            train_loader, test_loader = load_energy_data_feather(cid, filepath=DATA_FILE)

            updated_weights, fin_loss = train_model(local_model, train_loader, **train_kwargs)
            local_weights.append(updated_weights)

        # fedavgm_update returns the new global weights and the new velocity.
        global_weights, velocity = fedavgm_update(global_weights, local_weights, velocity)
        set_weights(global_model, global_weights)

        # Persist the global model for this round.
        checkpoint_name = f"{model_name}_round_{rnd+1}_fedAvgM.pt"
        checkpoint_path = os.path.join(model_dir, checkpoint_name)
        torch.save(global_model.state_dict(), checkpoint_path)
        print(f"Saved global model to {checkpoint_path}")


Starting experiment with model: lstm
Round 1/10


Training clients: 100%|██████████| 15/15 [00:42<00:00,  2.83s/it]


Saved global model to results\lstm\lstm_round_1_kr_norm.pt
Round 2/10


Training clients:  40%|████      | 6/15 [00:20<00:30,  3.39s/it]


KeyboardInterrupt: 

### Kuramoto FedAvg

In [None]:


# Experiment configuration used by the training loops in this notebook.

# Architectures to run; each name is passed to model_fn by the loops.
MODEL_NAMES = [
    "lstm",
    "gru",
    "moe_lstm",
    "moe_gru",
]

# Federated schedule
NUM_CLIENTS = 1_000   # client ids are sampled from range(NUM_CLIENTS)
CLIENT_FRAC = 0.15    # fraction of clients sampled per round
NUM_ROUNDS = 10       # communication rounds per architecture
LOCAL_EPOCHS = 10     # epochs handed to train_model for each client
LR = 1e-3             # learning rate handed to train_model

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Feather file given to load_energy_data_feather
# (previous alternative: "meter_0_data_cleaned.feather").
DATA_FILE = "train_cleaned_reindex.feather"


In [None]:
import os
import torch
import torch.optim as optim
from tqdm import tqdm


# # Weighted averaging function (uniform weights for now)
# def average_weights(weights_list):
#     return [np.mean(np.stack(layer_weights), axis=0) for layer_weights in zip(*weights_list)]

# model_fn, get_weights and set_weights come from the import cells at the top of the notebook.

# Kuramoto FedAvg experiment: one global model per architecture, merged each
# round by sync_aggregate from the current global weights and the client results.

# Keyword arguments shared by every local training call in this cell.
train_kwargs = dict(
    device=DEVICE,
    learning_rate=LR,
    loss_fn=None,
    optimizer_class=optim.Adam,
    epochs=LOCAL_EPOCHS,
)

for model_name in MODEL_NAMES:
    print(f"Starting experiment with model: {model_name}")

    # Checkpoints for this architecture are written to results/<model_name>.
    model_dir = os.path.join("results", model_name)
    os.makedirs(model_dir, exist_ok=True)

    # Fresh global model; its weights are the starting point for round 1.
    global_model = model_fn(model_name).to(DEVICE)
    global_weights = get_weights(global_model)

    for rnd in range(NUM_ROUNDS):
        print(f"Round {rnd+1}/{NUM_ROUNDS}")

        # Pick this round's participants from the client id range.
        sampled_clients = random.sample(
            range(NUM_CLIENTS),
            k=int(CLIENT_FRAC * NUM_CLIENTS),
        )

        local_weights = []
        for cid in tqdm(sampled_clients, desc="Training clients"):
            # Each client model is loaded with the current global weights.
            local_model = model_fn(model_name).to(DEVICE)
            set_weights(local_model, global_weights)

            # Only train_loader is used below; test_loader is not consumed here.
            train_loader, test_loader = load_energy_data_feather(cid, filepath=DATA_FILE)

            updated_weights, fin_loss = train_model(local_model, train_loader, **train_kwargs)
            local_weights.append(updated_weights)

        # Merge the client results into the global weights.
        global_weights = sync_aggregate(global_weights, local_weights)
        set_weights(global_model, global_weights)

        # Persist the global model for this round.
        checkpoint_name = f"{model_name}_round_{rnd+1}_kr.pt"
        checkpoint_path = os.path.join(model_dir, checkpoint_name)
        torch.save(global_model.state_dict(), checkpoint_path)
        print(f"Saved global model to {checkpoint_path}")


In [None]:
import os
import torch
import torch.optim as optim
from tqdm import tqdm


# # Weighted averaging function (uniform weights for now)
# def average_weights(weights_list):
#     return [np.mean(np.stack(layer_weights), axis=0) for layer_weights in zip(*weights_list)]

# model_fn, get_weights and set_weights come from the import cells at the top of the notebook.

# Kuramoto-norm experiment: one global model per architecture, merged each
# round by sync_aggregate_norm from the current global weights and the client
# results.

# Keyword arguments shared by every local training call in this cell.
train_kwargs = dict(
    device=DEVICE,
    learning_rate=LR,
    loss_fn=None,
    optimizer_class=optim.Adam,
    epochs=LOCAL_EPOCHS,
)

for model_name in MODEL_NAMES:
    print(f"Starting experiment with model: {model_name}")

    # Checkpoints for this architecture are written to results/<model_name>.
    model_dir = os.path.join("results", model_name)
    os.makedirs(model_dir, exist_ok=True)

    # Fresh global model; its weights are the starting point for round 1.
    global_model = model_fn(model_name).to(DEVICE)
    global_weights = get_weights(global_model)

    for rnd in range(NUM_ROUNDS):
        print(f"Round {rnd+1}/{NUM_ROUNDS}")

        # Pick this round's participants from the client id range.
        sampled_clients = random.sample(
            range(NUM_CLIENTS),
            k=int(CLIENT_FRAC * NUM_CLIENTS),
        )

        local_weights = []
        for cid in tqdm(sampled_clients, desc="Training clients"):
            # Each client model is loaded with the current global weights.
            local_model = model_fn(model_name).to(DEVICE)
            set_weights(local_model, global_weights)

            # Only train_loader is used below; test_loader is not consumed here.
            train_loader, test_loader = load_energy_data_feather(cid, filepath=DATA_FILE)

            updated_weights, fin_loss = train_model(local_model, train_loader, **train_kwargs)
            local_weights.append(updated_weights)

        # Merge the client results into the global weights.
        global_weights = sync_aggregate_norm(global_weights, local_weights)
        set_weights(global_model, global_weights)

        # Persist the global model for this round.
        checkpoint_name = f"{model_name}_round_{rnd+1}_kr_norm.pt"
        checkpoint_path = os.path.join(model_dir, checkpoint_name)
        torch.save(global_model.state_dict(), checkpoint_path)
        print(f"Saved global model to {checkpoint_path}")


In [None]:


# Experiment configuration used by the training loops in this notebook.

# Architectures to run; each name is passed to model_fn by the loops.
MODEL_NAMES = [
    "lstm",
    "gru",
    "moe_lstm",
    "moe_gru",
]

# Federated schedule
NUM_CLIENTS = 1_000   # client ids are sampled from range(NUM_CLIENTS)
CLIENT_FRAC = 0.15    # fraction of clients sampled per round
NUM_ROUNDS = 10       # communication rounds per architecture
LOCAL_EPOCHS = 10     # epochs handed to train_model for each client
LR = 1e-3             # learning rate handed to train_model

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Feather file given to load_energy_data_feather
# (previous alternative: "meter_0_data_cleaned.feather").
DATA_FILE = "train_cleaned_reindex.feather"


In [None]:
import os
import torch
import torch.optim as optim
from tqdm import tqdm


# # Weighted averaging function (uniform weights for now)
# def average_weights(weights_list):
#     return [np.mean(np.stack(layer_weights), axis=0) for layer_weights in zip(*weights_list)]

# model_fn, get_weights and set_weights come from the import cells at the top of the notebook.

# Kuramoto-softmax experiment: one global model per architecture, merged each
# round by sync_aggregate_softmax from the current global weights and the
# client results.

# Keyword arguments shared by every local training call in this cell.
train_kwargs = dict(
    device=DEVICE,
    learning_rate=LR,
    loss_fn=None,
    optimizer_class=optim.Adam,
    epochs=LOCAL_EPOCHS,
)

for model_name in MODEL_NAMES:
    print(f"Starting experiment with model: {model_name}")

    # Checkpoints for this architecture are written to results/<model_name>.
    model_dir = os.path.join("results", model_name)
    os.makedirs(model_dir, exist_ok=True)

    # Fresh global model; its weights are the starting point for round 1.
    global_model = model_fn(model_name).to(DEVICE)
    global_weights = get_weights(global_model)

    for rnd in range(NUM_ROUNDS):
        print(f"Round {rnd+1}/{NUM_ROUNDS}")

        # Pick this round's participants from the client id range.
        sampled_clients = random.sample(
            range(NUM_CLIENTS),
            k=int(CLIENT_FRAC * NUM_CLIENTS),
        )

        local_weights = []
        for cid in tqdm(sampled_clients, desc="Training clients"):
            # Each client model is loaded with the current global weights.
            local_model = model_fn(model_name).to(DEVICE)
            set_weights(local_model, global_weights)

            # Only train_loader is used below; test_loader is not consumed here.
            train_loader, test_loader = load_energy_data_feather(cid, filepath=DATA_FILE)

            updated_weights, fin_loss = train_model(local_model, train_loader, **train_kwargs)
            local_weights.append(updated_weights)

        # Merge the client results into the global weights.
        global_weights = sync_aggregate_softmax(global_weights, local_weights)
        set_weights(global_model, global_weights)

        # Persist the global model for this round.
        checkpoint_name = f"{model_name}_round_{rnd+1}_kr_sft.pt"
        checkpoint_path = os.path.join(model_dir, checkpoint_name)
        torch.save(global_model.state_dict(), checkpoint_path)
        print(f"Saved global model to {checkpoint_path}")


Starting experiment with model: lstm
Round 1/10


Training clients: 100%|██████████| 10/10 [01:16<00:00,  7.61s/it]


SYNC Weights:[1.1280927881409405e-24, 3.455221414490252e-24, 2.7815943992592906e-10, 1.5342428063772783e-20, 8.519412258989334e-11, 1.0, 4.7079616409486216e-23, 0.0, 0.0, 1.4587423713408043e-08] 
Saved global model to results\lstm\lstm_round_1_kr.pt
Round 2/10


Training clients: 100%|██████████| 10/10 [01:16<00:00,  7.66s/it]


SYNC Weights:[0.051666800892013795, 2.0060822225145682e-07, 0.0038842655760772567, 0.33163288880904007, 0.05313184572821639, 0.17136356681197534, 4.1494413146093236e-12, 2.694967957872787e-05, 0.3882710659981862, 2.2412009829798056e-05] 
Saved global model to results\lstm\lstm_round_2_kr.pt
Round 3/10


Training clients:  30%|███       | 3/10 [00:19<00:46,  6.64s/it]


KeyboardInterrupt: 

### Diff-Aware FedAvg

### Diff-Sync FedAvg