In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset

from typing import List, Tuple, Optional, Dict
import numpy as np
from Models import MoELSTM
import os
from collections import OrderedDict
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader

from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error
from darts import TimeSeries
from darts.dataprocessing.transformers import Scaler
import random
from Models import MoELSTM, LSTMModel, train_model
from Preprocess import (
    compute_metrics,
    convert_timeseries_to_numpy,
    create_dataloader,
    load_building_series,
    split_series_list,
)
import pandas as pd
from collections import defaultdict
import os
import torch
import torch.optim as optim
from tqdm import tqdm


from Models import model_fn
from tqdm import tqdm
from my_utils import train_model_transformer, load_energy_data_feather_transformer, get_weights, set_weights


In [None]:
from AggregationStrategy import sync_aggregate,average_weights,sync_aggregate_norm,sync_aggregate_softmax, fedavgm_update

In [None]:
# Load the full training table from the Feather file for interactive inspection.
df = pd.read_feather("train_final.feather")

In [None]:
# Preview the first rows of the loaded table.
df.head()

In [None]:
# Print the column names, dtypes and non-null counts of the loaded table.
df.info()

In [None]:


# ---------------------------------------------------------------------------
# Experiment configuration (read by the training cells below)
# ---------------------------------------------------------------------------

# Model architectures to run; each name is handed to `model_fn`.
MODEL_NAMES = ["transformer"]

# Feather file passed to the per-client data loader.
# Alternative dataset: "meter_0_data_cleaned.feather"
DATA_FILE = "train_final.feather"

# Federated schedule.
NUM_CLIENTS = 1410   # size of the client population that rounds sample from
CLIENT_FRAC = 0.15   # fraction of the population sampled in every round
NUM_ROUNDS = 50      # number of communication rounds
LOCAL_EPOCHS = 5     # local training epochs per sampled client

# Optimisation settings.
LR = 1e-3
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


### FedAvg

In [None]:

def train_model_transformer(model, train_loader, device=None, learning_rate=0.001, loss_fn=None, optimizer_class=optim.Adam, epochs=50):
    """Train ``model`` on ``train_loader`` and return its weights and loss curve.

    NOTE(review): this definition shadows the ``train_model_transformer``
    imported from ``my_utils`` at the top of the notebook.

    Args:
        model: Module to train. It is moved to ``device`` and updated in place.
        train_loader: Sized iterable yielding ``(X_batch, y_batch)`` tensor pairs.
        device: Device to train on. ``None`` selects CUDA when available, else CPU.
        learning_rate: Learning rate passed to the optimizer as ``lr``.
        loss_fn: Loss callable ``loss_fn(prediction, target)``. ``None`` selects
            ``nn.MSELoss()``.
        optimizer_class: Optimizer constructor called as
            ``optimizer_class(model.parameters(), lr=learning_rate)``.
        epochs: Number of full passes over ``train_loader``.

    Returns:
        A ``(weights, loss_history)`` tuple: ``weights`` is the result of
        ``get_weights(model)`` after training, and ``loss_history`` is a list
        holding the mean batch loss of each epoch.

    Raises:
        ValueError: If ``train_loader`` contains no batches (previously this
            surfaced as a ``ZeroDivisionError`` after the first epoch).
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader is empty: cannot train on zero batches")

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.to(device)
    criterion = loss_fn if loss_fn is not None else nn.MSELoss()
    optimizer = optimizer_class(model.parameters(), lr=learning_rate)
    loss_history = []

    model.train()
    for _ in range(epochs):
        running_loss = 0.0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            optimizer.zero_grad()
            output = model(X_batch)
            # squeeze(-1) drops a trailing singleton dimension from the
            # prediction (if there is one) before it is compared to y_batch.
            loss = criterion(output.squeeze(-1), y_batch)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        loss_history.append(running_loss / num_batches)

    return get_weights(model), loss_history

In [None]:

# FedAvg: every round, a random subset of clients trains a copy of the global
# model locally and the resulting weights are averaged into the new global model.
for model_name in MODEL_NAMES:
    print(f"Starting experiment with model: {model_name}")

    # Directory to save checkpoints
    model_dir = os.path.join("results", model_name)
    os.makedirs(model_dir, exist_ok=True)

    # Init model and weights
    global_model = model_fn(model_name).to(DEVICE)
    global_weights = get_weights(global_model)

    # Loop-invariant sample size. Floor it at one client (when any exist) so
    # the aggregation step never receives an empty list for small configs.
    clients_per_round = min(NUM_CLIENTS, max(1, int(CLIENT_FRAC * NUM_CLIENTS)))

    for rnd in range(NUM_ROUNDS):
        print(f"Round {rnd+1}/{NUM_ROUNDS}")
        sampled_clients = random.sample(range(NUM_CLIENTS), k=clients_per_round)
        local_weights = []

        for cid in tqdm(sampled_clients, desc="Training clients"):
            # Each client starts from the current global weights.
            local_model = model_fn(model_name).to(DEVICE)
            set_weights(local_model, global_weights)
            # The second loader returned by the helper is not used here.
            train_loader, _ = load_energy_data_feather_transformer(cid, filepath=DATA_FILE)

            # The second return value is the per-epoch loss history, which
            # plain FedAvg does not consume.
            updated_weights, loss_history = train_model_transformer(
                local_model, train_loader,
                device=DEVICE,
                learning_rate=LR,
                loss_fn=None,
                optimizer_class=optim.Adam,
                epochs=LOCAL_EPOCHS
            )
            local_weights.append(updated_weights)

        # Federated averaging
        global_weights = average_weights(local_weights)
        set_weights(global_model, global_weights)

        # Save model checkpoint
        checkpoint_path = os.path.join(model_dir, f"{model_name}_round_{rnd+1}_fedAvg.pt")
        torch.save(global_model.state_dict(), checkpoint_path)
        print(f"Saved global model to {checkpoint_path}")


### Difficulty-Aware FedAvg

In [None]:
class TimeSeriesDifficultyWeight:
    """Track a per-client "difficulty" score derived from training-loss trends.

    For every client the tracker keeps the last reported loss, an exponential
    moving average (EMA) of how much the loss improved ("learn"), an EMA of how
    much it got worse ("unlearn"), and an EMA of the ratio unlearn / learn,
    which is the difficulty score.

    Args:
        num_clients: Number of clients; client ids index tensors of this length.
        accumulate_iters: EMA horizon; the momentum is
            ``(accumulate_iters - 1) / accumulate_iters``.
        device: Device holding the state tensors. ``None`` falls back to the
            notebook-level ``DEVICE`` constant.
    """

    def __init__(self, num_clients, accumulate_iters=20, device=None):
        self.num_clients = num_clients
        self.device = DEVICE if device is None else device
        self.last_loss = torch.ones(num_clients, dtype=torch.float32, device=self.device)
        self.learn_score = torch.zeros(num_clients, dtype=torch.float32, device=self.device)
        self.unlearn_score = torch.zeros(num_clients, dtype=torch.float32, device=self.device)
        self.ema_difficulty = torch.ones(num_clients, dtype=torch.float32, device=self.device)
        self.accumulate_iters = accumulate_iters

    def update(self, cid: int, loss_history: List[float]) -> float:
        """Fold a client's latest loss into its difficulty score.

        Only the last entry of ``loss_history`` is used; it is compared with
        the loss stored from the client's previous update (initially 1.0).

        Args:
            cid: Client index.
            loss_history: Per-epoch losses; must be non-empty (an empty list
                raises ``IndexError``). Losses are assumed to be non-negative,
                otherwise the logarithm below is undefined.

        Returns:
            The client's updated EMA difficulty as a Python float.
        """
        current_loss = torch.tensor(loss_history[-1], dtype=torch.float32, device=self.device)
        previous_loss = self.last_loss[cid]
        delta = current_loss - previous_loss
        ratio = torch.log((current_loss + 1e-8) / (previous_loss + 1e-8))

        # delta and ratio always share a sign, so their product is a
        # non-negative magnitude of the change. (The previous `-delta * ratio`
        # made the learn score negative whenever the loss decreased, which in
        # turn produced negative / unstable difficulty ratios.)
        change = delta * ratio
        zero = torch.tensor(0.0, device=current_loss.device)
        learn = torch.where(delta < 0, change, zero)
        unlearn = torch.where(delta >= 0, change, zero)

        # EMA update
        momentum = (self.accumulate_iters - 1) / self.accumulate_iters
        self.learn_score[cid] = momentum * self.learn_score[cid] + (1 - momentum) * learn
        self.unlearn_score[cid] = momentum * self.unlearn_score[cid] + (1 - momentum) * unlearn

        # Difficulty score: loss increases relative to loss decreases. The
        # epsilons keep the ratio finite when a score is still zero.
        difficulty = (self.unlearn_score[cid] + 1e-8) / (self.learn_score[cid] + 1e-8)

        # Smooth difficulty over rounds
        self.ema_difficulty[cid] = momentum * self.ema_difficulty[cid] + (1 - momentum) * difficulty

        self.last_loss[cid] = current_loss
        return self.ema_difficulty[cid].item()

    def get_normalized_weights(self, client_ids: List[int]) -> List[float]:
        """Return the difficulties of ``client_ids`` scaled to sum to 1.

        Returns an empty list for empty input and uniform weights when the
        difficulties sum to zero.
        """
        if len(client_ids) == 0:
            return []
        weights = [self.ema_difficulty[cid].item() for cid in client_ids]
        total = sum(weights)
        if total == 0:
            return [1.0 / len(client_ids)] * len(client_ids)
        return [w / total for w in weights]

    def get_sampling_probabilities(self, min_prob=0.05):
        """Return sampling probabilities over all clients as a NumPy array.

        Probabilities are proportional to the inverse difficulty (lower
        difficulty -> higher probability), floored at ``min_prob`` and then
        renormalized to sum to 1.

        NOTE(review): the floor is applied to probabilities that already sum
        to 1, so when ``min_prob * num_clients > 1`` most entries are raised to
        the floor and the result is pulled towards uniform. Confirm this is
        intended for large client populations.
        """
        inv_difficulty = 1.0 / (self.ema_difficulty + 1e-6)
        inv_difficulty = inv_difficulty / inv_difficulty.sum()
        probs = torch.clamp(inv_difficulty, min=min_prob)
        return (probs / probs.sum()).cpu().numpy()



## SCAFFOLD with difficulty-aware client sampling

In [None]:
def train_model_scaffold(
    local_model,
    train_loader,
    global_weights,    # x
    server_c,          # c
    client_ci,         # c_i
    device=DEVICE,
    learning_rate=LR,
    loss_fn=None,
    optimizer_class=optim.Adam,
    epochs=LOCAL_EPOCHS
):
    """Train one client with the SCAFFOLD drift correction.

    After every optimizer step each parameter is additionally shifted by
    ``-learning_rate * (c - c_i)``, where ``c`` is the server control variate
    and ``c_i`` the client control variate.

    Args:
        local_model: Model to train; expected to already hold the global
            weights. It is moved to ``device`` and updated in place.
        train_loader: Sized iterable yielding ``(X_batch, y_batch)`` tensor pairs.
        global_weights: Global weights ``x`` at the start of the round, as a
            list aligned with ``get_weights(local_model)``.
        server_c: Server control variate ``c`` (list of arrays).
        client_ci: This client's control variate ``c_i`` (list of arrays).
        device: Training device. ``None`` selects CUDA when available, else CPU.
        learning_rate: Optimizer learning rate, also used to scale the correction.
        loss_fn: Loss callable; ``None`` selects ``nn.MSELoss()``.
        optimizer_class: Optimizer constructor called as
            ``optimizer_class(local_model.parameters(), lr=learning_rate)``.
        epochs: Number of passes over ``train_loader``.

    Returns:
        ``(delta_y, delta_c, new_ci, local_weights, loss_history)``:
        weight change ``y_i - x``, control-variate change ``c_i_new - c_i``,
        the new client control variate, the trained weights from
        ``get_weights(local_model)``, and the mean batch loss of each epoch.

    Raises:
        ValueError: If no local step would be taken (empty loader or
            ``epochs <= 0``); the control-variate update divides by the
            number of steps.

    NOTE(review): the update ``c_i - c + (x - y_i) / (K * lr)`` assumes plain
    gradient steps of size ``lr``; with an adaptive optimizer such as the
    default Adam it is only an approximation. Confirm this is intended.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    num_batches = len(train_loader)
    if num_batches == 0 or epochs <= 0:
        raise ValueError(
            "SCAFFOLD needs at least one local step: got "
            f"{num_batches} batches and epochs={epochs}"
        )

    local_model.to(device)
    criterion = loss_fn if loss_fn is not None else nn.MSELoss()
    optimizer = optimizer_class(local_model.parameters(), lr=learning_rate)
    loss_history = []

    # c and c_i do not change during local training, so build the per-parameter
    # correction lr * (c - c_i) once instead of on every step.
    # NOTE(review): parameters are paired positionally with server_c/client_ci.
    # This assumes those lists hold exactly one entry per element of
    # local_model.parameters(), in the same order; if get_weights also returns
    # non-parameter state (e.g. buffers) the pairing would be misaligned.
    # TODO confirm against get_weights.
    corrections = [
        learning_rate * (
            torch.tensor(sc_np, dtype=p.dtype, device=p.device)
            - torch.tensor(ci_np, dtype=p.dtype, device=p.device)
        )
        for p, sc_np, ci_np in zip(local_model.parameters(), server_c, client_ci)
    ]

    local_model.train()
    total_steps = 0

    for _ in range(epochs):
        running_loss = 0.0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            optimizer.zero_grad()
            output = local_model(X_batch)
            loss = criterion(output.squeeze(-1), y_batch)
            loss.backward()
            optimizer.step()

            # SCAFFOLD correction, applied after the regular optimizer step.
            with torch.no_grad():
                for p, correction in zip(local_model.parameters(), corrections):
                    p -= correction

            running_loss += loss.item()
            total_steps += 1

        loss_history.append(running_loss / num_batches)

    # Weight change produced by local training: y_i - x.
    local_weights = get_weights(local_model)
    delta_y = [lw - gw for lw, gw in zip(local_weights, global_weights)]

    # New control variate: c_i_new = c_i - c + (x - y_i) / (K * lr), with K
    # the number of local steps. Fresh arrays are created, so the caller's
    # client_ci list is left untouched.
    scale = total_steps * learning_rate
    new_ci = []
    delta_c = []
    for gw, lw, ci, sc in zip(global_weights, local_weights, client_ci, server_c):
        ci_new = ci - sc + (gw - lw) / scale
        new_ci.append(ci_new)
        delta_c.append(ci_new - ci)

    return delta_y, delta_c, new_ci, local_weights, loss_history

In [None]:


# SCAFFOLD with difficulty-aware sampling: clients are drawn with probabilities
# from the difficulty tracker, trained with control-variate correction, and
# their weight / control-variate deltas are averaged on the server.
for model_name in MODEL_NAMES:
    difficulty_tracker = TimeSeriesDifficultyWeight(num_clients=NUM_CLIENTS)
    print(f"Starting experiment with model: {model_name}")

    # Create the checkpoint directory here as well, so this cell does not
    # depend on the FedAvg cell having been run first.
    model_dir = os.path.join("results", model_name)
    os.makedirs(model_dir, exist_ok=True)

    global_model = model_fn(model_name).to(DEVICE)
    global_weights = get_weights(global_model)
    print(f"Using device: {DEVICE}")

    # Control variates start at zero: one server copy plus one per client.
    server_c = [np.zeros_like(w) for w in global_weights]
    client_cs = {cid: [np.zeros_like(w) for w in global_weights] for cid in range(NUM_CLIENTS)}

    # Loop-invariant number of clients drawn per round.
    clients_per_round = int(CLIENT_FRAC * NUM_CLIENTS)

    for rnd in range(NUM_ROUNDS):
        print(f"Round {rnd+1}/{NUM_ROUNDS}")

        # === Difficulty-aware sampling ===
        # NOTE(review): the probabilities sum to 1 over NUM_CLIENTS entries, so
        # with 1410 clients at most 20 of them can reach the 0.05 floor on
        # their own; the rest are raised to it, which makes the draw close to
        # uniform. Confirm this floor is intended.
        sampling_probs = difficulty_tracker.get_sampling_probabilities(min_prob=0.05)
        sampled_clients = np.random.choice(
            np.arange(NUM_CLIENTS),
            size=clients_per_round,
            replace=False,
            p=sampling_probs
        )
        print(f"Sampled {len(sampled_clients)} clients")

        local_weight_deltas = []
        local_c_deltas = []

        for cid in tqdm(sampled_clients):
            # np.random.choice yields NumPy integers; use a plain int so the
            # loader and dict see the same id type as in the FedAvg cell.
            cid = int(cid)

            local_model = model_fn(model_name).to(DEVICE)
            set_weights(local_model, global_weights)

            train_loader, _ = load_energy_data_feather_transformer(cid, filepath=DATA_FILE)

            delta_y, delta_c, new_ci, local_weights, loss_history = train_model_scaffold(
                local_model, train_loader,
                global_weights=global_weights,
                server_c=server_c,
                client_ci=client_cs[cid],
                device=DEVICE,
                learning_rate=LR,
                loss_fn=None,
                optimizer_class=optim.Adam,
                epochs=LOCAL_EPOCHS
            )

            # === Difficulty update ===
            difficulty_tracker.update(cid, loss_history)

            local_weight_deltas.append(delta_y)
            local_c_deltas.append(delta_c)
            client_cs[cid] = new_ci

        # === Aggregate and update global weights ===
        # x <- x + mean(delta_y)
        mean_delta_y = average_weights(local_weight_deltas)
        global_weights = [gw + mean_delta_y[i] for i, gw in enumerate(global_weights)]

        # c <- c + (sampled / total) * mean(delta_c)
        mean_delta_c = average_weights(local_c_deltas)
        frac = len(sampled_clients) / NUM_CLIENTS
        server_c = [sc + frac * mean_delta_c[i] for i, sc in enumerate(server_c)]

        set_weights(global_model, global_weights)

        ckpt_path = os.path.join(model_dir, f"{model_name}_round_{rnd+1}_scaffold_diff.pt")
        torch.save(global_model.state_dict(), ckpt_path)
        print(f"Saved: {ckpt_path}")
