In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from typing import List
# Config
NUM_CLIENTS = 10
NUM_EPOCHS = 5
LOCAL_EPOCHS = 3
BATCH_SIZE = 64
LR = 0.01
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Model
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28*28, 200),
            nn.ReLU(),
            nn.Linear(200, 10)
        )

    def forward(self, x):
        return self.model(x)

# Dataset
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root="./data", train=True, download=False, transform=transform)
test_dataset = datasets.MNIST(root="./data", train=False, download=False, transform=transform)

In [20]:

# Split into clients
def split_dataset(dataset, num_clients):
    data_len = len(dataset)
    indices = np.arange(data_len)
    np.random.shuffle(indices)
    split_size = data_len // num_clients
    return [Subset(dataset, indices[i*split_size:(i+1)*split_size]) for i in range(num_clients)]

client_datasets = split_dataset(train_dataset, NUM_CLIENTS)

# Local training
def train_local(model, train_data, epochs=1):
    model = model.to(DEVICE)
    model.train()
    loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    optimizer = optim.SGD(model.parameters(), lr=LR)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        for x, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()
    return model.state_dict()

# Model aggregation (FedAvg)
def average_models(state_dicts):
    avg_model = {}
    for key in state_dicts[0]:
        avg_model[key] = sum([sd[key] for sd in state_dicts]) / len(state_dicts)
    return avg_model


In [21]:


# Test
def test(model, test_data):
    model.eval()
    loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total

# Federated training
global_model = MLP().to(DEVICE)
for epoch in range(NUM_EPOCHS):
    local_weights = []
    for i in range(NUM_CLIENTS):
        local_model = MLP().to(DEVICE)
        local_model.load_state_dict(global_model.state_dict())
        local_weights.append(train_local(local_model, client_datasets[i], epochs=LOCAL_EPOCHS))

    averaged_weights = average_models(local_weights)
    global_model.load_state_dict(averaged_weights)

    acc = test(global_model, test_dataset)
    print(f"Round {epoch+1}, Test Accuracy: {acc:.4f}")


Round 1, Test Accuracy: 0.7448
Round 2, Test Accuracy: 0.8265
Round 3, Test Accuracy: 0.8595
Round 4, Test Accuracy: 0.8759
Round 5, Test Accuracy: 0.8851


In [22]:
from AggregationStrategy import sync_aggregate,average_weights,sync_aggregate_norm,sync_aggregate_softmax, fedavgm_update

In [23]:
def get_weights(model):
    return [p.detach().cpu().numpy() for p in model.parameters()]


def set_weights(model, weights):
    params = list(model.parameters())
    if len(weights) != len(params):
        raise ValueError(f"Mismatch in weights ({len(weights)}) and model parameters ({len(params)})")

    for p, w in zip(params, weights):
        p.data = torch.tensor(w, dtype=p.dtype, device=p.device)

In [24]:
def train_local(model, train_data, epochs=1):
    model = model.to(DEVICE)
    model.train()
    loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    optimizer = optim.SGD(model.parameters(), lr=LR)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        for x, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()
    return get_weights(model)

In [25]:


# Test
def test(model, test_data):
    model.eval()
    loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total

# Federated training
global_model = MLP().to(DEVICE)
global_weights = get_weights(global_model)
for epoch in range(NUM_EPOCHS):
    local_weights = []
    for i in range(NUM_CLIENTS):
        local_model = MLP().to(DEVICE)
        local_model.load_state_dict(global_model.state_dict())
        local_weights.append(train_local(local_model, client_datasets[i], epochs=LOCAL_EPOCHS))

    sync_weights = sync_aggregate(global_weights,local_weights)
    set_weights(global_model,sync_weights)
    # global_model.load_state_dict(averaged_weights)

    acc = test(global_model, test_dataset)
    print(f"Round {epoch+1}, Test Accuracy: {acc:.4f}")


SYNC Weights:[ -86.81517  -142.6893   -364.6838     -9.192195   -4.08542  -110.847046
   36.528458  202.88916   167.32198   311.63342 ] 
Round 1, Test Accuracy: 0.2338
SYNC Weights:[-1.3198408  -1.8052368   3.354141   -0.20724773 -1.8543218   1.8870451
  0.18543218 -0.583566   -0.19633995  0.5890199 ] 
Round 2, Test Accuracy: 0.1031
SYNC Weights:[-2.7790623  -0.0652362  -2.870393   -0.96549577  4.0968337  -4.122928
  2.8051567  -0.24789758  1.4743382   2.7268732 ] 
Round 3, Test Accuracy: 0.1839
SYNC Weights:[ -4.0026927  -8.267124  -11.656964    0.6575396   3.2110918 -12.601779
  14.804217    3.823944    5.3433075   8.7395315] 
Round 4, Test Accuracy: 0.2173
SYNC Weights:[ -1.200724   17.819838  -14.1357975  -1.8283752   5.212234  -44.017452
   6.221934    4.475426    1.5554835  25.952015 ] 
Round 5, Test Accuracy: 0.2118


In [26]:

global_model = MLP().to(DEVICE)
global_weights = get_weights(global_model)
for epoch in range(NUM_EPOCHS):
    local_weights = []
    for i in range(NUM_CLIENTS):
        local_model = MLP().to(DEVICE)
        local_model.load_state_dict(global_model.state_dict())
        local_weights.append(train_local(local_model, client_datasets[i], epochs=LOCAL_EPOCHS))

    sync_weights = sync_aggregate_softmax(global_weights,local_weights)
    set_weights(global_model,sync_weights)
    # global_model.load_state_dict(averaged_weights)

    acc = test(global_model, test_dataset)
    print(f"Round {epoch+1}, Test Accuracy: {acc:.4f}")


SYNC Weights:[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] 
Round 1, Test Accuracy: 0.7071
SYNC Weights:[5.324934164434305e-44, 0.0, 5.692147580394696e-27, 1.3877241161447416e-34, 0.0, 1.0, 0.0, 2.1498160520285614e-40, 0.0, 0.0] 
Round 2, Test Accuracy: 0.8275
SYNC Weights:[0.3264521062374115, 0.031513601541519165, 0.08028356730937958, 0.0024079345166683197, 0.20452912151813507, 0.001194120617583394, 0.000317467754939571, 0.0019059550249949098, 0.3264521062374115, 0.024943992495536804] 
Round 3, Test Accuracy: 0.8587
SYNC Weights:[0.4740835726261139, 0.0017221422167494893, 0.4740835726261139, 0.0001965281117008999, 0.007014804519712925, 2.895207580877468e-05, 4.846006504521938e-06, 0.0005457857041619718, 0.03688597306609154, 0.005433955695480108] 
Round 4, Test Accuracy: 0.8667
SYNC Weights:[0.0021372162736952305, 0.18804344534873962, 0.0009546735091134906, 0.1437452882528305, 0.020047184079885483, 0.42097073793411255, 0.1572105884552002, 0.05368135869503021, 0.007486575283110142,

In [27]:

global_model = MLP().to(DEVICE)
global_weights = get_weights(global_model)
for epoch in range(NUM_EPOCHS):
    local_weights = []
    for i in range(NUM_CLIENTS):
        local_model = MLP().to(DEVICE)
        local_model.load_state_dict(global_model.state_dict())
        local_weights.append(train_local(local_model, client_datasets[i], epochs=LOCAL_EPOCHS))

    sync_weights = sync_aggregate_softmax(global_weights,local_weights)
    set_weights(global_model,sync_weights)
    # global_model.load_state_dict(averaged_weights)

    acc = test(global_model, test_dataset)
    print(f"Round {epoch+1}, Test Accuracy: {acc:.4f}")


SYNC Weights:[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] 
Round 1, Test Accuracy: 0.6989
SYNC Weights:[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0] 
Round 2, Test Accuracy: 0.8251
SYNC Weights:[0.6339377164840698, 0.04061416909098625, 0.02072162553668022, 0.0012551635736599565, 0.16502106189727783, 0.06727717071771622, 0.00038659514393657446, 0.02072162553668022, 0.04061416909098625, 0.00945064052939415] 
Round 3, Test Accuracy: 0.8528
SYNC Weights:[4.014902515336871e-06, 0.0032507923897355795, 4.9464390031062067e-05, 0.004940362647175789, 0.00022949933190830052, 5.6869950640248135e-05, 0.9912301301956177, 4.9464390031062067e-05, 7.517306949011981e-05, 0.00011424360855016857] 
Round 4, Test Accuracy: 0.8543
SYNC Weights:[0.0, 0.5, 0.0, 5.701456694999843e-14, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0] 
Round 5, Test Accuracy: 0.8588


In [28]:

global_model = MLP().to(DEVICE)
global_weights = get_weights(global_model)
for epoch in range(NUM_EPOCHS):
    local_weights = []
    for i in range(NUM_CLIENTS):
        local_model = MLP().to(DEVICE)
        local_model.load_state_dict(global_model.state_dict())
        local_weights.append(train_local(local_model, client_datasets[i], epochs=LOCAL_EPOCHS))

    sync_weights = sync_aggregate_norm(global_weights,local_weights)
    set_weights(global_model,sync_weights)
    # global_model.load_state_dict(averaged_weights)

    acc = test(global_model, test_dataset)
    print(f"Round {epoch+1}, Test Accuracy: {acc:.4f}")


SYNC Weights:[np.float32(0.2745864), np.float32(0.57365245), np.float32(1.0), np.float32(0.611318), np.float32(0.094377644), np.float32(0.4010134), np.float32(0.33258045), np.float32(0.035223752), np.float32(0.3645687), np.float32(0.0)] 
Round 1, Test Accuracy: 0.7613
SYNC Weights:[np.float32(0.8306452), np.float32(0.8709678), np.float32(0.5806452), np.float32(0.7741936), np.float32(0.8225807), np.float32(0.9758065), np.float32(0.0), np.float32(1.0), np.float32(0.6048387), np.float32(0.7741936)] 
Round 2, Test Accuracy: 0.8621
SYNC Weights:[np.float32(0.58227843), np.float32(0.7341772), np.float32(0.59493667), np.float32(1.0), np.float32(0.6835443), np.float32(0.75949365), np.float32(0.0), np.float32(0.59493667), np.float32(0.58227843), np.float32(0.35443035)] 
Round 3, Test Accuracy: 0.8779
SYNC Weights:[np.float32(0.6724138), np.float32(0.6206897), np.float32(0.7413793), np.float32(0.87931037), np.float32(0.44827586), np.float32(0.8448276), np.float32(0.0), np.float32(1.0), np.float3

In [29]:

global_model = MLP().to(DEVICE)
global_weights = get_weights(global_model)
for epoch in range(NUM_EPOCHS):
    local_weights = []
    for i in range(NUM_CLIENTS):
        local_model = MLP().to(DEVICE)
        local_model.load_state_dict(global_model.state_dict())
        local_weights.append(train_local(local_model, client_datasets[i], epochs=LOCAL_EPOCHS))

    sync_weights = sync_aggregate_norm(global_weights,local_weights)
    set_weights(global_model,sync_weights)
    # global_model.load_state_dict(averaged_weights)

    acc = test(global_model, test_dataset)
    print(f"Round {epoch+1}, Test Accuracy: {acc:.4f}")


SYNC Weights:[np.float32(0.43710375), np.float32(0.4477811), np.float32(0.90397066), np.float32(1.0), np.float32(0.35542205), np.float32(0.6184184), np.float32(0.63610274), np.float32(0.5550217), np.float32(0.17464131), np.float32(0.0)] 
Round 1, Test Accuracy: 0.7568
SYNC Weights:[np.float32(0.88235295), np.float32(0.14285715), np.float32(0.56302524), np.float32(0.45378152), np.float32(0.2857143), np.float32(0.62184876), np.float32(0.0), np.float32(0.5378151), np.float32(1.0), np.float32(0.29411766)] 
Round 2, Test Accuracy: 0.8597
SYNC Weights:[np.float32(0.9866666), np.float32(0.44), np.float32(0.9333333), np.float32(0.61333334), np.float32(0.013333332), np.float32(0.97333336), np.float32(0.0), np.float32(0.8666667), np.float32(1.0), np.float32(0.61333334)] 
Round 3, Test Accuracy: 0.8775
SYNC Weights:[np.float32(0.18279569), np.float32(0.46236557), np.float32(0.19354837), np.float32(0.32258064), np.float32(1.0), np.float32(0.0), np.float32(0.8172043), np.float32(0.13978493), np.flo

In [37]:
class TimeSeriesDifficultyWeight:
    def __init__(self, num_clients, accumulate_iters=20):
        self.num_clients = num_clients
        self.last_loss = torch.ones(num_clients).float().to(DEVICE)
        self.learn_score = torch.zeros(num_clients).float().to(DEVICE)
        self.unlearn_score = torch.zeros(num_clients).float().to(DEVICE)
        self.ema_difficulty = torch.ones(num_clients).float().to(DEVICE)
        self.accumulate_iters = accumulate_iters

    def update(self, cid: int, loss_history: List[float]) -> float:
        """
        Update difficulty based on loss trend for a client.
        Expects a list of per-epoch losses.
        """
        current_loss = torch.tensor(loss_history[-1], dtype=torch.float32).to(DEVICE)
        previous_loss = self.last_loss[cid]
        delta = current_loss - previous_loss
        ratio = torch.log((current_loss + 1e-8) / (previous_loss + 1e-8))

        learn = torch.where(delta < 0, -delta * ratio, torch.tensor(0.0, device=current_loss.device))
        unlearn = torch.where(delta >= 0, delta * ratio, torch.tensor(0.0, device=current_loss.device))

        # EMA update
        momentum = (self.accumulate_iters - 1) / self.accumulate_iters
        self.learn_score[cid] = momentum * self.learn_score[cid] + (1 - momentum) * learn
        self.unlearn_score[cid] = momentum * self.unlearn_score[cid] + (1 - momentum) * unlearn

        # Difficulty score
        diff_ratio = (self.unlearn_score[cid] + 1e-8) / (self.learn_score[cid] + 1e-8)
        difficulty = diff_ratio #torch.pow(diff_ratio, 1 / 5)

        # Smooth difficulty over rounds
        self.ema_difficulty[cid] = momentum * self.ema_difficulty[cid] + (1 - momentum) * difficulty

        self.last_loss[cid] = current_loss
        return self.ema_difficulty[cid].item()

    def get_normalized_weights(self, client_ids: List[int]) -> List[float]:
        weights = [self.ema_difficulty[cid].item() for cid in client_ids]
        total = sum(weights)
        if total == 0:
            return [1.0 / len(client_ids)] * len(client_ids)
        return [w / total for w in weights]


In [38]:
def train_model(model, dataloader, device, learning_rate, loss_fn, optimizer_class, epochs):
    model.train()
    optimizer = optimizer_class(model.parameters(), lr=learning_rate)
    criterion = torch.nn.MSELoss() if loss_fn is None else loss_fn

    loss_history = []

    for epoch in range(epochs):
        running_loss = 0.0
        for xb, yb in dataloader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            preds = model(xb)
            loss = criterion(preds, yb)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        loss_history.append(running_loss / len(dataloader))
    
    return get_weights(model), loss_history

def train_local2(model, train_data, epochs=1):
    model = model.to(DEVICE)
    model.train()
    loss_history = []
    loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    optimizer = optim.SGD(model.parameters(), lr=LR)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        running_loss = 0 
        for x, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        loss_history.append(running_loss / len(loader))
    return get_weights(model), loss_history


In [39]:
# difficulty_tracker = TimeSeriesDifficultyWeight(num_clients=NUM_CLIENTS)

# for rnd in range(NUM_ROUNDS):
#     sampled_clients = random.sample(range(NUM_CLIENTS), int(CLIENT_FRAC * NUM_CLIENTS))
#     local_weights = []
#     difficulty_scores = []

#     for cid in tqdm(sampled_clients):
#         model = model_fn(model_name).to(DEVICE)
#         set_weights(model, global_weights)
#         train_loader, _ = load_energy_data_feather(cid, filepath=DATA_FILE)
#         updated_weights, loss_history = train_model(
#             model, train_loader,
#             device=DEVICE, learning_rate=LR,
#             loss_fn=None, optimizer_class=optim.Adam,
#             epochs=LOCAL_EPOCHS
#         )
#         local_weights.append(updated_weights)

#         # Update difficulty
#         difficulty = difficulty_tracker.update(cid, loss_history)
#         difficulty_scores.append(difficulty)

#     # Normalize difficulty scores
#     normalized_weights = difficulty_tracker.get_normalized_weights(sampled_clients)

#     # Difficulty-aware weighted aggregation
#     global_weights = average_weights(local_weights, client_weights=normalized_weights)
#     set_weights(global_model, global_weights)


In [40]:

difficulty_tracker = TimeSeriesDifficultyWeight(num_clients=NUM_CLIENTS)
global_model = MLP().to(DEVICE)
global_weights = get_weights(global_model)
for epoch in range(NUM_EPOCHS):
    local_weights = []
    local_weights = []
    difficulty_scores = []
    for cid in range(NUM_CLIENTS):
        local_model = MLP().to(DEVICE)
        local_model.load_state_dict(global_model.state_dict())
        updated_weights, loss_history = train_local2(local_model, client_datasets[cid], epochs=LOCAL_EPOCHS)
        local_weights.append(updated_weights)
        difficulty = difficulty_tracker.update(cid, loss_history)
        difficulty_scores.append(difficulty)

    # Normalize difficulty scores
    normalized_weights = difficulty_tracker.get_normalized_weights(range(NUM_CLIENTS))
    global_weights = average_weights(local_weights, client_weights=normalized_weights)

    # sync_weights = sync_aggregate_norm(global_weights,local_weights)
    set_weights(global_model,global_weights)
    # global_model.load_state_dict(averaged_weights)

    acc = test(global_model, test_dataset)
    print(f"Round {epoch+1}, Test Accuracy: {acc:.4f}")


Round 1, Test Accuracy: 0.7628
Round 2, Test Accuracy: 0.8346
Round 3, Test Accuracy: 0.8635
Round 4, Test Accuracy: 0.8796
Round 5, Test Accuracy: 0.8874


In [42]:

difficulty_tracker = TimeSeriesDifficultyWeight(num_clients=NUM_CLIENTS)
global_model = MLP().to(DEVICE)
global_weights = get_weights(global_model)
for epoch in range(NUM_EPOCHS):
    local_weights = []
    local_weights = []
    difficulty_scores = []
    for cid in range(NUM_CLIENTS):
        local_model = MLP().to(DEVICE)
        local_model.load_state_dict(global_model.state_dict())
        updated_weights, loss_history = train_local2(local_model, client_datasets[cid], epochs=LOCAL_EPOCHS)
        local_weights.append(updated_weights)
        difficulty = difficulty_tracker.update(cid, loss_history)
        difficulty_scores.append(difficulty)

    # Normalize difficulty scores
    normalized_weights = difficulty_tracker.get_normalized_weights(range(NUM_CLIENTS))
    global_weights = average_weights(local_weights, client_weights=normalized_weights)
    print(f"normalized_weights {normalized_weights}")

    # sync_weights = sync_aggregate_norm(global_weights,local_weights)
    set_weights(global_model,global_weights)
    # global_model.load_state_dict(averaged_weights)

    acc = test(global_model, test_dataset)
    print(f"Round {epoch+1}, Test Accuracy: {acc:.4f}")


normalized_weights [0.09928567285029549, 0.10047606193106842, 0.10306820453686075, 0.1009480691539008, 0.09907237475317177, 0.10058047066814692, 0.10317079035889194, 0.09695239269951486, 0.09875751302937048, 0.09768845001877859]
Round 1, Test Accuracy: 0.7486
normalized_weights [0.09928567342350057, 0.10047606528495681, 0.10306820144287422, 0.10094806868653755, 0.09907237522979231, 0.10058047272431864, 0.10317078865635229, 0.09695239490616511, 0.09875751201842248, 0.09768844762708001]
Round 2, Test Accuracy: 0.8294
normalized_weights [0.09928567267514167, 0.10047606405340082, 0.10306820451792226, 0.10094806763763792, 0.09907237062346645, 0.10058047342087864, 0.1031707922430277, 0.09695239419872785, 0.09875750964989091, 0.09768845097990579]
Round 3, Test Accuracy: 0.8586
normalized_weights [0.09928567049219009, 0.10047606236244967, 0.10306821035629402, 0.1009480701159033, 0.09907236884912164, 0.10058047276661411, 0.10317079265936736, 0.0969523975093752, 0.09875750824217758, 0.0976884466