# Import libraries

In [None]:
!pip install ecos


In [None]:
from abc import abstractmethod
import os
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np

import copy
import cvxpy as cp

import argparse

import time

import glob

from torch.utils import data
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import matplotlib.pyplot as plt

import ecos

import math
from scipy.linalg import sqrtm


# Model

In [None]:
class Model(nn.Module):
    """Small model zoo used by the federated-learning experiments.

    The architecture is selected by ``name``:

    * ``"linear"``   -- a single ``nn.Linear``; ``args`` must be ``[n_dim, n_out]``.
    * ``"mnist"``    -- 784-200-200-10 MLP.
    * ``"emnist"``   -- 784-100-100-10 MLP.
    * ``"cifar10"`` / ``"cifar100"`` -- two 5x5 conv layers (64 channels) with
      2x2 max-pooling, followed by a 384-192-``n_cls`` MLP head.
    * ``"resnet18"`` -- torchvision ResNet-18 with a 10-way head and every
      BatchNorm layer replaced by GroupNorm (2 groups).

    Attribute names and creation order are kept stable on purpose: the flat
    parameter vectors built by ``get_model_params`` / consumed by ``set_model``
    follow the order of ``named_parameters()``.

    Args:
        name: architecture identifier (see above).
        args: ``[n_dim, n_out]`` for ``"linear"``; ignored otherwise.

    Raises:
        ValueError: if ``name`` is not one of the supported architectures.
    """

    def __init__(self, name, args=None):
        super(Model, self).__init__()
        self.name = name

        if self.name == "linear":
            [self.n_dim, self.n_out] = args
            self.fc = nn.Linear(self.n_dim, self.n_out)

        elif self.name in ("mnist", "emnist"):
            # Same MLP topology; only the hidden width differs.
            hidden = 200 if self.name == "mnist" else 100
            self.n_cls = 10
            self.fc1 = nn.Linear(1 * 28 * 28, hidden)
            self.fc2 = nn.Linear(hidden, hidden)
            self.fc3 = nn.Linear(hidden, self.n_cls)

        elif self.name in ("cifar10", "cifar100"):
            # Same CNN topology; only the number of output classes differs.
            self.n_cls = 10 if self.name == "cifar10" else 100
            self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5)
            self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5)
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
            self.fc1 = nn.Linear(64 * 5 * 5, 384)
            self.fc2 = nn.Linear(384, 192)
            self.fc3 = nn.Linear(192, self.n_cls)

        elif self.name == "resnet18":
            self.n_cls = 10
            resnet18 = models.resnet18()
            resnet18.fc = nn.Linear(512, self.n_cls)

            # Change BN to GN. Reassigning existing submodule names keeps their
            # position in the module dict, so the parameter order is unchanged.
            resnet18.bn1 = self._group_norm(64)
            for layer in (
                resnet18.layer1,
                resnet18.layer2,
                resnet18.layer3,
                resnet18.layer4,
            ):
                for block in layer:
                    block.bn1 = self._group_norm(block.bn1.num_features)
                    block.bn2 = self._group_norm(block.bn2.num_features)
                    if block.downsample is not None:
                        block.downsample[1] = self._group_norm(
                            block.downsample[1].num_features
                        )

            # Any leftover BatchNorm would add running-stat buffers, making the
            # state_dict larger than the parameter set.
            assert len(dict(resnet18.named_parameters()).keys()) == len(
                resnet18.state_dict().keys()
            ), "More BN layers are there..."
            self.model = resnet18

        else:
            raise ValueError("Unknown model name: {!r}".format(name))

    @staticmethod
    def _group_norm(num_channels):
        """Return the 2-group GroupNorm used in place of BatchNorm."""
        return nn.GroupNorm(num_groups=2, num_channels=num_channels)

    def forward(self, x):
        if self.name == "linear":
            x = self.fc(x)

        elif self.name in ("mnist", "emnist"):
            x = x.view(-1, 1 * 28 * 28)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)

        elif self.name in ("cifar10", "cifar100"):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = x.view(-1, 64 * 5 * 5)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)

        elif self.name == "resnet18":
            x = self.model(x)

        return x


# Args

In [None]:
def args_parser(argv=None):
    """Build the experiment configuration (FedDyn + RIS-assisted AirComp FL).

    Args:
        argv: optional list of command-line tokens. When None, an empty list is
            parsed, so the defaults are returned. This keeps the function usable
            inside a notebook, where ``sys.argv`` holds kernel arguments.

    Returns:
        argparse.Namespace holding every setting defined below.
    """

    def str2bool(value):
        # ``type=bool`` would map every non-empty string (including "False")
        # to True, so parse the common spellings explicitly.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(
            "expected a boolean value, got {!r}".format(value)
        )

    parser = argparse.ArgumentParser()

    # FedDyn setup
    parser.add_argument(
        "--algorithm_name", type=str, default="FedDyn", help="algorithm name"
    )

    parser.add_argument("--n_clients", type=int, default=30, help="number of clients")

    parser.add_argument(
        "--comm_rounds", type=int, default=50, help="number of communication rounds"
    )

    parser.add_argument("--dataset_name", type=str, default="cifar10")

    parser.add_argument("--lr", type=float, default=0.03, help="learning rate")

    parser.add_argument(
        "--act_prob", type=float, default=0.9, help="probability of active clients"
    )

    parser.add_argument(
        "--lr_decay_per_round",
        type=float,
        default=0.99,
        help="learning rate decay per round",
    )

    parser.add_argument("--batch_size", type=int, default=50, help="batch size")

    parser.add_argument(
        "--epoch", type=int, default=5, help="local epoch for client training"
    )

    parser.add_argument("--weight_decay", type=float, default=1e-2, help="weight decay")

    parser.add_argument(
        "--max_norm", type=float, default=10, help="max norm for gradient clipping"
    )

    parser.add_argument("--model_name", type=str, default="cifar10", help="model name")

    parser.add_argument(
        "--rule", type=str, default="iid", help="the rule of data partitioning"
    )

    parser.add_argument("--rand_seed", type=int, default=1, help="random seed")

    parser.add_argument("--save_period", type=int, default=1, help="save period")

    parser.add_argument("--print_per", type=int, default=5, help="print period")

    # RIS FL setup
    parser.add_argument(
        "--n_RIS_ele", type=int, default=40, help="number of RIS elements"
    )

    parser.add_argument(
        "--n_receive_ant", type=int, default=5, help="number of receive antennas"
    )

    parser.add_argument(
        "--alpha_direct", type=float, default=3.76, help="path loss component"
    )

    parser.add_argument(
        "--SNR", type=float, default=90.0, help="noise variance/0.1W in dB"
    )

    parser.add_argument(
        "--location_range",
        type=int,
        default=30,
        help="location range between clients and RIS",
    )

    parser.add_argument(
        "--Jmax", type=int, default=50, help="number of maximum Gibbs Outer loops"
    )

    parser.add_argument(
        "--tau", type=float, default=0.03, help="tau, the SCA regularization term"
    )

    parser.add_argument(
        "--nit", type=int, default=100, help="I_max, number of maximum SCA loops"
    )

    parser.add_argument(
        "--threshold",
        type=float,
        default=1e-2,
        help="epsilon, SCA early stopping criteria",
    )

    parser.add_argument(
        "--transmit_power", type=float, default=0.003, help="transmit power"
    )

    parser.add_argument(
        "--noiseless", type=str2bool, default=False, help="whether the channel is noiseless"
    )

    args = parser.parse_args(args=[] if argv is None else argv)

    return args


# Utils

In [None]:
def get_model_params(model_list, n_par=None):
    """Flatten the parameters of each model into one row of a float32 matrix.

    Parameters are concatenated in ``named_parameters()`` order. ``set_model``
    reads them back in the same order.

    Args:
        model_list: sequence of ``nn.Module`` objects sharing one architecture.
        n_par: number of parameters per model. When None it is counted from
            ``model_list[0]``.

    Returns:
        A new ``np.ndarray`` of shape ``(len(model_list), n_par)``, dtype float32.
    """
    if n_par is None:
        n_par = sum(param.numel() for param in model_list[0].parameters())

    param_mat = np.zeros((len(model_list), n_par), dtype=np.float32)
    for row, mdl in zip(param_mat, model_list):
        # ``row`` is a view into ``param_mat``, so slice assignment fills it.
        idx = 0
        for param in mdl.parameters():
            flat = param.detach().cpu().numpy().reshape(-1)
            row[idx : idx + flat.size] = flat
            idx += flat.size
    return param_mat


def set_model(model, params, device):
    """Load a flat parameter vector into ``model`` in place.

    ``params`` is consumed in ``named_parameters()`` order, the layout produced by
    ``get_model_params``. Each slice is reshaped to its parameter's shape, moved
    to ``device`` and copied into the parameter. Parameters are written directly,
    without a deep copy or ``load_state_dict``, so models that also own buffers
    (e.g. BatchNorm running stats) are supported.

    Args:
        model: module whose parameters are overwritten.
        params: 1-D array-like holding at least ``sum(p.numel())`` values.
        device: device the values are moved to before copying.

    Returns:
        The same ``model`` object.
    """
    idx = 0
    with torch.no_grad():
        for param in model.parameters():
            length = param.numel()
            chunk = torch.as_tensor(params[idx : idx + length]).reshape(param.shape)
            param.copy_(chunk.to(device))
            idx += length
    return model


def get_acc_loss(
    data_x, data_y, model, dataset_name, device, w_decay=None, batch_size=50
):
    """Evaluate the mean cross-entropy loss and the accuracy of ``model``.

    Args:
        data_x, data_y: evaluation samples and labels. They are wrapped in the
            file's ``Dataset`` class; ``data_x.shape[0]`` is the sample count.
        model: classifier returning logits of shape ``(batch, n_classes)``.
        dataset_name: forwarded to ``Dataset``.
        device: device used for evaluation. ``model`` is moved there.
        w_decay: if not None, ``w_decay / 2 * ||params||^2`` is added to the loss.
        batch_size: evaluation batch size.

    Returns:
        ``(loss, accuracy)``: the per-sample average loss (plus the optional L2
        term) and the fraction of correctly classified samples.

    Note:
        ``model`` is left on ``device`` and in training mode.
    """
    loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
    n_tst = data_x.shape[0]
    tst_gen = data.DataLoader(
        Dataset(data_x, data_y, dataset_name=dataset_name),
        batch_size=batch_size,
        shuffle=False,
    )

    loss_overall = 0.0
    acc_overall = 0
    model.eval()
    model = model.to(device)
    with torch.no_grad():
        for batch_x, batch_y in tst_gen:
            batch_x = batch_x.to(device)
            # Labels may come as float / column tensors; flatten and cast once.
            batch_y = batch_y.to(device).reshape(-1).long()
            y_pred = model(batch_x)

            loss_overall += loss_fn(y_pred, batch_y).item()
            acc_overall += (y_pred.argmax(dim=1) == batch_y).sum().item()

    loss_overall /= n_tst
    if w_decay is not None:
        # Add L2 loss
        params = get_model_params([model], n_par=None)
        loss_overall += w_decay / 2 * np.sum(params * params)

    model.train()
    return loss_overall, acc_overall / n_tst


def save_performance(
    communication_rounds,
    tst_perf_all,
    algorithm_name,
    data_obj_name,
    model_name,
    n_clients,
    noiseless,
    iid_str,
):
    """Plot the test-accuracy curve to PDF and save the raw metrics as ``.npy``.

    Files are written under ``Output/<data_obj_name>/``; the directory is
    created if it does not exist. The figure is always closed afterwards, so
    repeated calls during training do not accumulate open matplotlib figures.

    Args:
        communication_rounds: number of rounds plotted (x axis ``1..rounds``).
        tst_perf_all: array whose column 1 holds test accuracy per round.
        algorithm_name, data_obj_name, model_name, n_clients: used for labels
            and file names.
        noiseless: selects the "noiseless"/"noisy" file-name tag.
        iid_str: partitioning tag; appears only in the ``.npy`` file name.
    """
    channel_tag = "noiseless" if noiseless else "noisy"
    os.makedirs("Output/{}".format(data_obj_name), exist_ok=True)

    fig = plt.figure(figsize=(6, 5))
    try:
        plt.plot(
            np.arange(communication_rounds) + 1,
            tst_perf_all[:, 1],
            label=algorithm_name,
            linewidth=2.5,
            color="red",
        )
        plt.ylabel("Test Accuracy", fontsize=16)
        plt.xlabel("Communication Rounds", fontsize=16)
        plt.legend(fontsize=16, loc="lower right", bbox_to_anchor=(1.015, -0.02))
        plt.grid()
        plt.xlim([0, communication_rounds + 1])
        plt.title(data_obj_name, fontsize=16)
        plt.xticks(fontsize=16)
        plt.yticks(fontsize=16)
        plt.savefig(
            "Output/{}/{}_{}cln_{}comm_{}_{}.pdf".format(
                data_obj_name,
                algorithm_name,
                n_clients,
                communication_rounds,
                channel_tag,
                model_name,
            ),
            dpi=1000,
            bbox_inches="tight",
        )
    finally:
        plt.close(fig)

    np.save(
        "Output/{}/{}_{}cln_{}comm_{}_{}_{}_tst_perf_all.npy".format(
            data_obj_name,
            algorithm_name,
            n_clients,
            communication_rounds,
            channel_tag,
            iid_str.lower(),
            model_name,
        ),
        tst_perf_all,
    )


def evaluate_performance(
    cent_x,
    cent_y,
    tst_x,
    tst_y,
    dataset_name,
    avg_model,
    all_model,
    device,
    tst_perf_sel,
    trn_perf_sel,
    tst_perf_all,
    trn_perf_all,
    t,
):
    """Evaluate two global models on the test and centralized training sets.

    For ``avg_model`` ("sel") and ``all_model`` ("all"), the
    ``[loss, accuracy]`` on ``(tst_x, tst_y)`` and on ``(cent_x, cent_y)`` is
    stored at row ``t`` of the matching metrics array, and a progress line is
    printed. The order of evaluation is: sel/test, sel/cent, all/test, all/cent.
    """
    jobs = [
        (tst_x, tst_y, avg_model, tst_perf_sel,
         "\n**** Communication sel %3d, Test Accuracy: %.4f, Loss: %.4f"),
        (cent_x, cent_y, avg_model, trn_perf_sel,
         "**** Communication sel %3d, Cent Accuracy: %.4f, Loss: %.4f"),
        (tst_x, tst_y, all_model, tst_perf_all,
         "**** Communication all %3d, Test Accuracy: %.4f, Loss: %.4f"),
        (cent_x, cent_y, all_model, trn_perf_all,
         "**** Communication all %3d, Cent Accuracy: %.4f, Loss: %.4f\n"),
    ]
    for eval_x, eval_y, net, store, template in jobs:
        loss_val, acc_val = get_acc_loss(eval_x, eval_y, net, dataset_name, device)
        store[t] = [loss_val, acc_val]
        print(template % (t + 1, acc_val, loss_val))


# Server (re-defines the Utils helpers above; these later definitions are the ones in effect)

In [None]:
def get_model_params(model_list, n_par=None):
    """Flatten the parameters of each model into one row of a float32 matrix.

    Parameters are concatenated in ``named_parameters()`` order. ``set_model``
    reads them back in the same order.

    Args:
        model_list: sequence of ``nn.Module`` objects sharing one architecture.
        n_par: number of parameters per model. When None it is counted from
            ``model_list[0]``.

    Returns:
        A new ``np.ndarray`` of shape ``(len(model_list), n_par)``, dtype float32.
    """
    if n_par is None:
        n_par = sum(param.numel() for param in model_list[0].parameters())

    param_mat = np.zeros((len(model_list), n_par), dtype=np.float32)
    for row, mdl in zip(param_mat, model_list):
        # ``row`` is a view into ``param_mat``, so slice assignment fills it.
        idx = 0
        for param in mdl.parameters():
            flat = param.detach().cpu().numpy().reshape(-1)
            row[idx : idx + flat.size] = flat
            idx += flat.size
    return param_mat


def set_model(model, params, device):
    """Load a flat parameter vector into ``model`` in place.

    ``params`` is consumed in ``named_parameters()`` order, the layout produced by
    ``get_model_params``. Each slice is reshaped to its parameter's shape, moved
    to ``device`` and copied into the parameter. Parameters are written directly,
    without a deep copy or ``load_state_dict``, so models that also own buffers
    (e.g. BatchNorm running stats) are supported.

    Args:
        model: module whose parameters are overwritten.
        params: 1-D array-like holding at least ``sum(p.numel())`` values.
        device: device the values are moved to before copying.

    Returns:
        The same ``model`` object.
    """
    idx = 0
    with torch.no_grad():
        for param in model.parameters():
            length = param.numel()
            chunk = torch.as_tensor(params[idx : idx + length]).reshape(param.shape)
            param.copy_(chunk.to(device))
            idx += length
    return model


def get_acc_loss(
    data_x, data_y, model, dataset_name, device, w_decay=None, batch_size=50
):
    """Evaluate the mean cross-entropy loss and the accuracy of ``model``.

    Args:
        data_x, data_y: evaluation samples and labels. They are wrapped in the
            file's ``Dataset`` class; ``data_x.shape[0]`` is the sample count.
        model: classifier returning logits of shape ``(batch, n_classes)``.
        dataset_name: forwarded to ``Dataset``.
        device: device used for evaluation. ``model`` is moved there.
        w_decay: if not None, ``w_decay / 2 * ||params||^2`` is added to the loss.
        batch_size: evaluation batch size.

    Returns:
        ``(loss, accuracy)``: the per-sample average loss (plus the optional L2
        term) and the fraction of correctly classified samples.

    Note:
        ``model`` is left on ``device`` and in training mode.
    """
    loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
    n_tst = data_x.shape[0]
    tst_gen = data.DataLoader(
        Dataset(data_x, data_y, dataset_name=dataset_name),
        batch_size=batch_size,
        shuffle=False,
    )

    loss_overall = 0.0
    acc_overall = 0
    model.eval()
    model = model.to(device)
    with torch.no_grad():
        for batch_x, batch_y in tst_gen:
            batch_x = batch_x.to(device)
            # Labels may come as float / column tensors; flatten and cast once.
            batch_y = batch_y.to(device).reshape(-1).long()
            y_pred = model(batch_x)

            loss_overall += loss_fn(y_pred, batch_y).item()
            acc_overall += (y_pred.argmax(dim=1) == batch_y).sum().item()

    loss_overall /= n_tst
    if w_decay is not None:
        # Add L2 loss
        params = get_model_params([model], n_par=None)
        loss_overall += w_decay / 2 * np.sum(params * params)

    model.train()
    return loss_overall, acc_overall / n_tst


def save_performance(
    communication_rounds,
    tst_perf_all,
    algorithm_name,
    data_obj_name,
    model_name,
    n_clients,
    noiseless,
    iid_str,
):
    """Plot the test-accuracy curve to PDF and save the raw metrics as ``.npy``.

    Files are written under ``Output/<data_obj_name>/``; the directory is
    created if it does not exist. The figure is always closed afterwards, so
    repeated calls during training do not accumulate open matplotlib figures.

    Args:
        communication_rounds: number of rounds plotted (x axis ``1..rounds``).
        tst_perf_all: array whose column 1 holds test accuracy per round.
        algorithm_name, data_obj_name, model_name, n_clients: used for labels
            and file names.
        noiseless: selects the "noiseless"/"noisy" file-name tag.
        iid_str: partitioning tag; appears only in the ``.npy`` file name.
    """
    channel_tag = "noiseless" if noiseless else "noisy"
    os.makedirs("Output/{}".format(data_obj_name), exist_ok=True)

    fig = plt.figure(figsize=(6, 5))
    try:
        plt.plot(
            np.arange(communication_rounds) + 1,
            tst_perf_all[:, 1],
            label=algorithm_name,
            linewidth=2.5,
            color="red",
        )
        plt.ylabel("Test Accuracy", fontsize=16)
        plt.xlabel("Communication Rounds", fontsize=16)
        plt.legend(fontsize=16, loc="lower right", bbox_to_anchor=(1.015, -0.02))
        plt.grid()
        plt.xlim([0, communication_rounds + 1])
        plt.title(data_obj_name, fontsize=16)
        plt.xticks(fontsize=16)
        plt.yticks(fontsize=16)
        plt.savefig(
            "Output/{}/{}_{}cln_{}comm_{}_{}.pdf".format(
                data_obj_name,
                algorithm_name,
                n_clients,
                communication_rounds,
                channel_tag,
                model_name,
            ),
            dpi=1000,
            bbox_inches="tight",
        )
    finally:
        plt.close(fig)

    np.save(
        "Output/{}/{}_{}cln_{}comm_{}_{}_{}_tst_perf_all.npy".format(
            data_obj_name,
            algorithm_name,
            n_clients,
            communication_rounds,
            channel_tag,
            iid_str.lower(),
            model_name,
        ),
        tst_perf_all,
    )


def evaluate_performance(
    cent_x,
    cent_y,
    tst_x,
    tst_y,
    dataset_name,
    avg_model,
    all_model,
    device,
    tst_perf_sel,
    trn_perf_sel,
    tst_perf_all,
    trn_perf_all,
    t,
):
    """Evaluate two global models on the test and centralized training sets.

    For ``avg_model`` ("sel") and ``all_model`` ("all"), the
    ``[loss, accuracy]`` on ``(tst_x, tst_y)`` and on ``(cent_x, cent_y)`` is
    stored at row ``t`` of the matching metrics array, and a progress line is
    printed. The order of evaluation is: sel/test, sel/cent, all/test, all/cent.
    """
    jobs = [
        (tst_x, tst_y, avg_model, tst_perf_sel,
         "\n**** Communication sel %3d, Test Accuracy: %.4f, Loss: %.4f"),
        (cent_x, cent_y, avg_model, trn_perf_sel,
         "**** Communication sel %3d, Cent Accuracy: %.4f, Loss: %.4f"),
        (tst_x, tst_y, all_model, tst_perf_all,
         "**** Communication all %3d, Test Accuracy: %.4f, Loss: %.4f"),
        (cent_x, cent_y, all_model, trn_perf_all,
         "**** Communication all %3d, Cent Accuracy: %.4f, Loss: %.4f\n"),
    ]
    for eval_x, eval_y, net, store, template in jobs:
        loss_val, acc_val = get_acc_loss(eval_x, eval_y, net, dataset_name, device)
        store[t] = [loss_val, acc_val]
        print(template % (t + 1, acc_val, loss_val))


# Communication

In [None]:
# Over-the-air computation (AirComp) aggregation
class AirComp(object):
    """Over-the-air computation of a weighted average of client signals.

    Selected clients send normalised (zero-mean, unit-variance) signals that
    are pre-scaled so their effective channels ``f^H h_m`` are inverted. The
    receiver combines the superimposed signal with the beamformer ``f``,
    removes the common scaling and restores the means. The result is
    ``sum_m K_m g_m / sum_m K_m`` plus receiver noise. Equation tags (11),
    (17a) and (17b) refer to the paper this notebook implements.
    """

    def __init__(self, n_receive_ant, weight_list, transmit_power):
        # Number of receive antennas at the server.
        self.n_receive_ant = n_receive_ant
        # Per-client weights K_m, indexed with the 0/1 selection mask in transmit().
        self.weight_list = weight_list
        # Per-client transmit power budget.
        self.transmit_power = transmit_power
        # When False, transmit() returns the raw received matrix instead of the estimate.
        self.need_air_comp = True

    def transmit(self, d, signal, x, f, h, sigma):
        """Simulate one AirComp transmission.

        Args:
            d: signal length (number of columns of ``signal`` and of the noise).
            signal: array of shape ``(n_selected, d)``; row i belongs to the
                i-th selected client.
            x: 0/1 selection vector over all clients.
            f: receive beamformer of length ``n_receive_ant``.
            h: channel matrix of shape ``(n_receive_ant, n_clients)``.
            sigma: noise scale; the noise power is ``sigma * transmit_power``.

        Returns:
            The real length-``d`` estimate (eq. 11) if ``need_air_comp`` is True,
            otherwise the complex ``(n_receive_ant, d)`` received matrix.
        """
        selected = x == 1
        n_ant = self.n_receive_ant
        k_sel = self.weight_list[selected]  # K_m of the selected clients

        # Effective scalar channel f^H h_m of every selected client.
        eff = f.conj() @ h[:, selected]
        eff_pow = np.abs(eff) ** 2

        # Per-client statistics used to normalise the transmitted signals.
        mu = np.mean(signal, axis=1)
        weighted_mu = k_sel @ mu
        variance = np.var(signal, axis=1)
        std = variance ** 0.5

        # Common receive scaling, limited by the worst client (17a).
        eta = np.min(self.transmit_power * eff_pow / k_sel ** 2 / variance)
        root_eta = eta ** 0.5
        # Per-client transmit scalars inverting the effective channels (17b).
        tx = k_sel * root_eta * std * eff.conj() / eff_pow

        noise_power = sigma * self.transmit_power
        # Complex Gaussian receiver noise (real part drawn first, then imaginary).
        noise = (
            (np.random.randn(n_ant, d) + 1j * np.random.randn(n_ant, d))
            / (2) ** 0.5
            * noise_power ** 0.5
        )

        # Each client sends its zero-mean signal divided by its std, scaled by tx.
        sent = (tx / std)[:, np.newaxis] * (signal - mu[:, np.newaxis])
        received = h[:, selected] @ sent + noise
        if not self.need_air_comp:
            return received
        # Undo the receive scaling, add back the weighted mean, normalise (11).
        return np.real((f.conj() @ received / root_eta + weighted_mu)) / np.sum(k_sel)

# Build the wireless channel model
class Channel(object):
    """Random geometry and channel generator for the RIS-assisted uplink.

    ``generate`` draws client positions, computes distance-based path loss for
    the direct client->PS links and the cascaded client->RIS->PS links, and
    samples i.i.d. complex Gaussian small-scale fading for every link.
    The exact order of the ``np.random`` calls in ``generate`` defines the
    random stream, so reordering statements changes the generated channels.
    """
    def __init__(
            self,
            SNR, # signal-to-noise ratio in dB (used as 10 ** (-SNR / 10) in generate)
            n_clients, # number of edge devices
            location_range, # range used when placing client coordinates
            fc, # carrier frequency (enters the path-loss term 3e8 / fc)
            alpha_direct, # path-loss exponent of the direct links (PL in the paper)
            n_RIS_ele, # number of RIS elements (L in the paper)
            n_receive_ant, # number of antennas at the PS
            User_Gain, # antenna gain at user (G_D)
            x0, # returned unchanged by generate(); original note: "initial device coordinates". NOTE(review): appears to be consumed as an initial 0/1 selection vector downstream — confirm
            BS, # position of the PS; components [0], [1], [2] are used as x, y, z
            BS_Gain, # antenna gain at PS (G_PS)
            RIS, # position of the RIS; components [0], [1], [2] are used as x, y, z
            RIS_Gain, # antenna gain at RIS (G_RIS)
            dimen_RIS, # size of one RIS element
    ):
        self.SNR = SNR
        self.n_clients = n_clients
        self.location_range = location_range
        self.fc = fc
        self.alpha_direct = alpha_direct
        self.n_RIS_ele = n_RIS_ele
        self.n_receive_ant = n_receive_ant
        self.User_Gain = User_Gain
        self.x = x0

        self.BS = BS
        self.BS_Gain = BS_Gain

        self.RIS = RIS
        self.RIS_Gain = RIS_Gain
        self.dimen_RIS = dimen_RIS

    def generate(self):
        """Draw one random realisation of client positions and channels.

        Returns:
            (h_d, G, x, sigma):
            h_d   -- direct client->PS channels, shape (n_receive_ant, n_clients);
            G     -- cascaded client->RIS->PS channels, shape
                     (n_receive_ant, n_RIS_ele, n_clients), with
                     G[:, :, j] = H_RB @ diag(h_UR[:, j]);
            x     -- the ``x0`` given to the constructor;
            sigma -- 10 ** (-SNR / 10) divided by ref ** 2.
        """
        ref = (1e-10) ** 0.5
        sigma_n = np.power(10, -self.SNR/10)
        sigma = sigma_n / ref**2 # normalised by ref**2 because the channels below are divided by ref

        # setting 2: about half the clients near x = 0, the rest around x = 200
        dx2 = (
            np.random.rand(int(self.n_clients - np.round(self.n_clients / 2)))
            * self.location_range + 200
        ) # far group: x in [200, 200 + location_range)
        dx1 = (
            np.random.rand(int(np.round(self.n_clients /2 ))) * self.location_range - self.location_range
        ) # near group: x in [-location_range, 0)

        dx = np.concatenate((dx1, dx2)) # x-coordinates of all devices
        np.random.shuffle(dx)

        dy = np.random.rand(self.n_clients) * 20 - 10 # y-coordinates in [-10, 10)

        d_UR = (
            (dx - self.RIS[0]) ** 2 + (dy - self.RIS[1]) ** 2 + self.RIS[2] ** 2
        ) ** 0.5 # distance from each user to the RIS (users at z = 0)

        d_RB = np.linalg.norm(self.BS - self.RIS) # distance from the RIS to the PS
        d_direct = (
            (dx - self.BS[0]) ** 2 + (dy - self.BS[1]) ** 2 + self.BS[2] ** 2
        ) ** 0.5 # distance from each user to the PS (users at z = 0)

        PL_direct = (
            self.BS_Gain * self.User_Gain
            * (3 * 10 ** 8 / self.fc / 4 / np.pi / d_direct) ** self.alpha_direct
        ) # path loss of the direct device-PS channels (exponent alpha_direct)
        PL_RIS = (
            self.BS_Gain * self.User_Gain * self.RIS_Gain
            * self.n_RIS_ele ** 2
            * self.dimen_RIS ** 2
            / 4
            / np.pi
            * (3 * 10**8 / self.fc / 4 / np.pi / d_UR) ** 2
            * (3 * 10**8 / self.fc / 4 / np.pi / d_RB) ** 2
        ) # path loss of the cascaded device-RIS-PS channels

        h_d = (
            np.random.randn(self.n_receive_ant, self.n_clients)
            + 1j * np.random.randn(self.n_receive_ant, self.n_clients)
        ) / 2 ** 0.5 # small-scale fading coefficients
        h_d = h_d @ np.diag(PL_direct**0.5) / ref # apply per-user path loss, normalise by ref

        H_RB = (
            np.random.randn(self.n_receive_ant, self.n_RIS_ele)
            + 1j * np.random.randn(self.n_receive_ant, self.n_RIS_ele)
        ) / 2**0.5 # RIS to PS

        h_UR = (
            np.random.randn(self.n_RIS_ele, self.n_clients)
            + 1j * np.random.randn(self.n_RIS_ele, self.n_clients)
        ) / 2**0.5
        h_UR = h_UR @ np.diag(PL_RIS**0.5) / ref # User to RIS (cascaded path loss folded in here)

        G = np.zeros(
            [self.n_receive_ant, self.n_RIS_ele, self.n_clients], dtype = complex
        ) # cascaded channels: user -> each RIS element -> each PS antenna
        for j in range(self.n_clients): # for each user
            G[:, :, j] = H_RB @ np.diag(h_UR[:, j]) # RIS-to-PS channel times diag(user j -> RIS)

        return h_d, G, self.x, sigma # direct channels, cascaded channels, x0, noise scale

# Optimization: Gibbs sampling over device selection + SCA for the receive beamformer and RIS phases
class Gibbs(object):
    """Joint device selection, receive beamforming and RIS phase design.

    Outer loop (``sampling``, "Algorithm 2"): Gibbs sampling over 0/1 device
    selection vectors. Inner loop (``sca_fmincon``, "Algorithm 1"): successive
    convex approximation (SCA) of the receive beamformer ``f`` and the RIS
    phase vector ``theta`` for a fixed selection, solved with CVXPY/ECOS.
    """

    def __init__(
            self,
            n_clients, # number of users (M)
            n_receive_ant, # number of antennas at the PS (N)
            n_RIS_ele, # number of RIS elements (L)
            Jmax, # number of Gibbs sampling iterations
            weight_list, # device weights (K_m)
            tau, # SCA regularization term
            nit, # maximum number of SCA iterations (I_max)
            threshold, # relative-change early-stopping threshold for SCA
    ):
        self.n_clients = n_clients
        self.n_receive_ant = n_receive_ant
        self.n_RIS_ele = n_RIS_ele
        self.Jmax = Jmax
        self.weight_list = weight_list

        # SCA_based optimization
        self.tau = tau
        self.nit = nit
        self.threshold = threshold

    def optimize(self, h_d, G, x0, sigma):
        """Run Gibbs sampling and return the design from its final iteration.

        Args:
            h_d: direct channels; column i belongs to client i.
            G: cascaded channels; ``G[:, :, i]`` is client i's (N, L) matrix.
            x0: initial 0/1 selection vector of length ``n_clients``.
            sigma: noise scale used in the objective.

        Returns:
            (x_optim, f_optim, h_optim): the selection vector and beamformer
            stored at iteration ``Jmax`` (the last sample, not necessarily the
            lowest-objective one), and the effective channels
            ``h_d[:, i] + G[:, :, i] @ theta`` for all clients.
        """
        x_store, obj_new, f_store, theta_store = self.sampling(h_d, G, x0, sigma) # jointly optimise selection, f and theta

        x_optim = x_store[self.Jmax]
        f_optim = f_store[:, self.Jmax]
        theta_optim = theta_store[:, self.Jmax]
        h_optim = np.zeros([self.n_receive_ant, self.n_clients], dtype = complex)

        for i in range(self.n_clients):
            h_optim[:, i] = h_d[:, i] + G[:, :, i] @ theta_optim

        return x_optim, f_optim, h_optim

    # algorithm 2
    def sampling(self, h_d, G, x0, sigma):
        """Gibbs sampling over device-selection vectors (Algorithm 2).

        At every step, the current vector and all M single-bit flips are
        evaluated with ``find_obj_inner``. One of them is drawn with probability
        proportional to ``exp(-obj / beta)``. The temperature ``beta`` starts at
        ``min(1, obj(x0))``, is multiplied by 0.9 after each step (floor 1e-4),
        and is raised whenever the probabilities come out NaN (e.g. when every
        weight underflows to 0).

        Returns:
            x_store: (Jmax + 1, M) visited selection vectors.
            obj_new: (Jmax + 1,) objective of each visited vector.
            f_store: (N, Jmax + 1) receive beamformers.
            theta_store: (L, Jmax + 1) RIS phase vectors.
        """
        N = self.n_receive_ant
        L = self.n_RIS_ele
        M = self.n_clients

        K = self.weight_list / np.mean(self.weight_list) # normalise K by its mean
        K2 = K ** 2
        Ksum2 = sum(K)**2 # (sum of K)^2
        x = x0

        # per-iteration history
        obj_new = np.zeros(self.Jmax + 1)
        f_store = np.zeros([N, self.Jmax + 1], dtype = complex)
        theta_store = np.zeros([L, self.Jmax + 1], dtype = complex)
        x_store = np.zeros([self.Jmax + 1, M], dtype = int)

        # first loop
        ind = 0
        [obj_new[ind], x_store[ind, :], f, theta] = self.find_obj_inner(
            x, K, K2, Ksum2, h_d, G, None, None, sigma
        )

        theta_store[:, ind] = copy.deepcopy(theta)
        f_store[:, ind] = copy.deepcopy(f)
        beta = min(1, obj_new[ind])
        alpha = 0.9

        # warm starts for each candidate transition (row 0 = no change)
        f_loop = np.tile(f, (M+1, 1))
        theta_loop = np.tile(theta, (M+1, 1))

        for j in range(self.Jmax):

            # store possible transition solution and their obj
            X_sample = np.zeros([M+1, M], dtype = int)
            Temp = np.zeros(M+1)

            # first transition -> no change
            X_sample[0, :] = copy.deepcopy(x)
            Temp[0] = copy.deepcopy(obj_new[ind])
            f_loop[0] = copy.deepcopy(f)
            theta_loop[0] = copy.deepcopy(theta)

            # 2 - M+1 transition, change only 1 position
            for m in range(M):

                # flip the m-th position
                x_sam = copy.deepcopy(x)
                x_sam[m] = copy.deepcopy((x_sam[m] + 1) % 2)
                X_sample[m+1, :] = copy.deepcopy(x_sam)
                Temp[m+1], _, f_loop[m+1], theta_loop[m+1] = self.find_obj_inner(
                    x_sam,
                    K,
                    K2,
                    Ksum2,
                    h_d,
                    G,
                    f_loop[m+1],
                    theta_loop[m+1],
                    sigma,
                )
            temp2 = Temp

            Lambda = np.exp(-1 * temp2/beta)
            Lambda = Lambda / sum(Lambda)
            while np.isnan(Lambda).any():
                beta = beta / alpha
                Lambda = np.exp(-1.0 * temp2 / beta)
                Lambda = Lambda/sum(Lambda)

            kk_prime = np.random.choice(M+1, p = Lambda)
            x = copy.deepcopy(X_sample[kk_prime, :])
            f = copy.deepcopy(f_loop[kk_prime])
            theta = copy.deepcopy(theta_loop[kk_prime])
            ind += 1
            obj_new[ind] = copy.deepcopy(Temp[kk_prime])
            x_store[ind, :] = copy.deepcopy(x)
            theta_store[:, ind] = copy.deepcopy(theta)
            f_store[:, ind] = copy.deepcopy(f)

            beta = max(alpha * beta, 1e-4)

        return x_store, obj_new, f_store, theta_store

    # evaluate a selection vector
    def find_obj_inner(self, x, K, K2, Ksum2, h_d, G, f0, theta0, sigma):
        """Evaluate the objective of selection ``x`` after optimising f and theta.

        For a non-empty selection S, the objective is::

            max_{m in S} sigma * K_m^2 / |f^H h_m|^2 / (sum_{m in S} K_m)^2
            + 4 / Ksum2 * (sum_{m not in S} K_m)^2

        where ``h_m = h_d[:, m] + G[:, :, m] @ theta`` and (f, theta) come from
        ``sca_fmincon`` warm-started at (f0, theta0). An empty selection scores
        ``inf``; it returns all-ones theta and f taken from the first direct
        channel.

        Returns:
            (obj, x, f, theta).
        """
        N = self.n_receive_ant
        L = self.n_RIS_ele
        M = self.n_clients

        if sum(x) == 0:
            obj = np.inf

            theta = np.ones([L], dtype=complex)
            f = h_d[:, 0] / np.linalg.norm(h_d[:, 0])
        else:
            index = x == 1

            f, theta, _ = self.sca_fmincon(
                h_d[:, index], G[:, :, index], f0, theta0, x, K2[index]
            )

            h = np.zeros([N, M], dtype=complex)
            for i in range(M):
                h[:, i] = h_d[:, i] + G[:, :, i] @ theta
            gain = K2 / (np.abs(np.conjugate(f) @ h) ** 2) * sigma
            obj = (
                np.max(gain[index]) / (sum(K[index])) ** 2
                + 4 / Ksum2 * (sum(K[~index])) ** 2
            )
        return obj, x, f, theta

    # Algorithm 1
    def sca_fmincon(self, h_d, G, f, theta, x, K2):  # (25)
        """SCA update of the beamformer ``f`` and RIS phases ``theta`` (Algorithm 1).

        Tracks ``obj = min_i |f^H h_i|^2 / K2_i`` (eq. 24) over the selected
        clients, where ``h_i = h_d[:, i] + G[:, :, i] @ theta``. Each iteration
        solves a convex surrogate in ``mu`` with CVXPY/ECOS. The new ``f`` is
        normalised to unit norm and every entry of ``theta`` to unit modulus.
        The loop stops after ``nit`` iterations, or earlier when the relative
        change ``|obj - obj_pre| / min(1, |obj|)`` is at most ``threshold``.

        Args:
            h_d, G: direct / cascaded channels of the selected clients only.
            f, theta: warm starts. None means: f = first effective channel
                normalised, theta = all ones.
            x: full 0/1 selection vector; only ``sum(x)`` is used here.
            K2: squared weights of the selected clients.

        Returns:
            (f, theta, result), where ``result`` holds the objective of every
            completed iteration (empty when ``nit == 0``).
        """
        N = self.n_receive_ant
        L = self.n_RIS_ele
        I = sum(x)

        if theta is None:
            theta = np.ones([L], dtype=complex)
        result = np.zeros(self.nit)
        h = np.zeros([N, I], dtype=complex)
        for i in range(I):
            h[:, i] = h_d[:, i] + G[:, :, i] @ theta

        if f is None:
            f = h[:, 0] / np.linalg.norm(h[:, 0])

        obj = min(np.abs(np.conjugate(f) @ h) ** 2 / K2)

        it = -1  # stays -1 when nit == 0, so the history slice below is empty
        for it in range(self.nit):
            obj_pre = copy.deepcopy(obj)
            a = np.zeros([N, I], dtype=complex)
            b = np.zeros([L, I], dtype=complex)
            c = np.zeros([1, I], dtype=complex)
            F_cro = np.outer(f, np.conjugate(f))
            for i in range(I):  # (26)
                a[:, i] = (
                    self.tau * K2[i] * f + np.outer(h[:, i], np.conjugate(h[:, i])) @ f
                )

                b[:, i] = (
                    self.tau * K2[i] * theta + G[:, :, i].conj().T @ F_cro @ h[:, i]
                )

                c[:, i] = (
                    np.abs(np.conjugate(f) @ h[:, i]) ** 2
                    + 2 * self.tau * K2[i] * (L + 1)
                    + 2
                    * np.real(
                        (theta.conj().T) @ (G[:, :, i].conj().T) @ F_cro @ h[:, i]
                    )
                )

            # convex optimization
            mu = cp.Variable(I, nonneg=True)
            obj = cp.Minimize(
                cp.real(2 * cp.norm(a @ mu) + 2 * cp.norm(b @ mu, 1) - c @ mu)
            )
            prob = cp.Problem(obj, [K2 @ mu == 1])
            mu.value = 1 / K2
            prob.solve(solver='ECOS')

            fn = a @ mu.value
            thetan = b @ mu.value
            fn = fn / np.linalg.norm(fn)
            f = fn

            # project every RIS coefficient onto the unit circle
            thetan = thetan / np.abs(thetan)
            theta = thetan

            for i in range(I):
                h[:, i] = h_d[:, i] + G[:, :, i] @ theta
            obj = min(np.abs(np.conjugate(f) @ h) ** 2 / K2)  # (24)
            result[it] = copy.deepcopy(obj)

            if np.abs(obj - obj_pre) / min(1, abs(obj)) <= self.threshold:
                break

        # result[it] is written before the convergence check, so the valid
        # history is result[0:it + 1] (the old result[0:it] dropped the last one).
        result = result[: it + 1]
        return f, theta, result

# Client

In [None]:
class Client:
    """One federated-learning participant.

    Holds the client's local data, model and per-client state, and delegates
    local training to the pluggable ``algorithm`` object.
    """

    def __init__(
        self,
        algorithm,  # local training algorithm
        device,
        weight,
        train_data_X,
        train_data_Y,
        model,
        client_param,
        malicious=False,  # whether this client acts as an adversary
        attack_params=None,  # e.g. {'poison_fraction', 'epsilon', 'poison_labels'}
        client_id=None,
    ):
        # Training machinery.
        self.algorithm = algorithm
        self.device = device
        self.model = model
        self.client_param = client_param

        # Local data and aggregation weight.
        self.weight = weight
        self.train_data_X = train_data_X
        self.train_data_Y = train_data_Y

        # Adversary configuration; a falsy attack_params becomes a fresh empty dict.
        self.malicious = malicious
        self.attack_params = attack_params if attack_params else {}
        self.client_id = client_id

    def local_train(self, inputs: dict):
        """Run local training by calling ``algorithm.local_train(self, inputs)``.

        ``inputs`` is forwarded unchanged, and the algorithm's return value is
        discarded.
        """
        trainer = self.algorithm
        trainer.local_train(self, inputs)

    def aggregate(self, algorithm):
        """Replace this client's algorithm object with ``algorithm``."""
        self.algorithm = algorithm

# Adversary FGSM

In [None]:
def fgsm_image_attack(
    model,
    x_batch,
    y_batch,
    epsilon=2/255.0,
    device='cpu',
    targeted=False,
    target_labels=None,
    clip_min=0.0,
    clip_max=1.0,
):
    """Craft adversarial examples with a single Fast Gradient Sign Method step.

    The gradient is taken with respect to the inputs only (``torch.autograd.grad``),
    so the model's parameter ``.grad`` buffers are neither cleared nor polluted,
    which makes the helper safe to call in the middle of a training step. The
    model is switched to eval mode for the forward pass and its previous
    train/eval mode is restored afterwards.

    Args:
        model: ``torch.nn.Module`` classifier returning logits.
        x_batch: input batch tensor, e.g. shape (B, C, H, W).
        y_batch: true labels, shape (B,), integer dtype.
        epsilon: L-infinity magnitude of the perturbation.
        device: device on which the attack is computed.
        targeted: if True and ``target_labels`` is given, step towards
            ``target_labels`` (descend their loss) instead of away from ``y_batch``.
        target_labels: target classes for the targeted variant, shape (B,).
        clip_min, clip_max: bounds applied to the adversarial images. The
            defaults (0, 1) match un-normalized ``ToTensor`` images. When a
            ``Normalize`` transform was applied, pass the bounds of the
            normalized space, or ``None`` to skip clamping on that side.

    Returns:
        Detached tensor of adversarial images on ``device``.
    """
    was_training = model.training
    model.eval()
    try:
        x = x_batch.clone().detach().to(device).requires_grad_(True)
        y = y_batch.clone().detach().to(device)
        logits = model(x)
        if targeted and target_labels is not None:
            # targeted: minimize loss for the target -> move against the gradient
            loss = F.cross_entropy(logits, target_labels.to(device))
            factor = -1.0
        else:
            loss = F.cross_entropy(logits, y)
            factor = 1.0
        # Gradient w.r.t. the input only; parameter .grad fields stay untouched
        (grad,) = torch.autograd.grad(loss, x)
    finally:
        model.train(was_training)

    x_adv = x.detach() + factor * epsilon * torch.sign(grad)
    if clip_min is not None or clip_max is not None:
        x_adv = torch.clamp(x_adv, min=clip_min, max=clip_max)
    return x_adv.detach()

# Algorithm

In [None]:
# algorithm_base :
class Algorithm:
    """Base class for federated-learning algorithms.

    Stores the hyper-parameters shared by every algorithm. Subclasses implement
    ``local_train`` (client-side update) and ``aggregate`` (server-side update).

    NOTE: the class does not derive from ``abc.ABC``, so ``@abstractmethod``
    documents intent only and does not prevent instantiation.
    """

    def __init__(
        self,
        name,  # algorithm name
        lr,  # learning rate
        lr_decay_per_round,  # learning-rate decay factor applied after each round
        batch_size,
        epoch,  # number of local training epochs
        weight_decay,  # L2 penalty coefficient (guards against overfitting)
        model_func,  # callable that builds a fresh model
        n_param,  # number of model parameters
        max_norm,  # gradient clipping: if ||g||_2 > max_norm, g = g * (max_norm / ||g||_2)
        noiseless,  # flag turning the noise on/off
        dataset_name,
        save_period,  # period at which the model is saved
        print_per,  # print progress every n steps (subclasses here count local epochs)
    ):
        self.name = name
        self.lr = lr
        self.lr_decay_per_round = lr_decay_per_round
        self.batch_size = batch_size
        self.epoch = epoch
        self.weight_decay = weight_decay
        self.model_func = model_func
        self.n_param = n_param
        self.max_norm = max_norm
        self.noiseless = noiseless
        self.dataset_name = dataset_name
        self.save_period = save_period
        self.print_per = print_per

    @abstractmethod
    def local_train(self, client=None, inputs=None):
        """Run the client-side update; subclasses take ``(client, inputs)``."""
        raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method for client local training.")

    @abstractmethod
    def aggregate(self, server=None, inputs=None):
        """Run the server-side update; subclasses take ``(server, inputs)``."""
        raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method for server aggregation. ")

class AlgorithmFactory:
    """Build an ``Algorithm`` subclass instance from a parsed argument namespace."""

    # Names accepted by ``create_algorithm``.
    SUPPORTED_ALGORITHMS = ("FedDyn", "FedProx", "SCAFFOLD", "FedAvg")

    def __init__(self, args):
        self.args = args

    def _common_args(self):
        """Return the positional constructor arguments shared by every algorithm."""
        args = self.args
        return (
            args.lr,
            args.lr_decay_per_round,
            args.batch_size,
            args.epoch,
            args.weight_decay,
            args.model_func,
            args.n_param,
            args.max_norm,
            args.noiseless,
            args.data_obj.dataset,
            args.save_period,
            args.print_per,
        )

    def create_algorithm(self, algorithm_name) -> "Algorithm":
        """Instantiate the algorithm called ``algorithm_name``.

        Raises:
            ValueError: if ``algorithm_name`` is not one of
                ``SUPPORTED_ALGORITHMS``.
        """
        if algorithm_name not in self.SUPPORTED_ALGORITHMS:
            raise ValueError(f"Unknown algorithm name: {algorithm_name}")

        if algorithm_name == "FedDyn":
            alpha_coef = 1e-2
            return FedDyn(*self._common_args(), alpha_coef)

        if algorithm_name == "FedProx":
            mu = 1e-4
            return FedProx(*self._common_args(), mu)

        if algorithm_name == "SCAFFOLD":
            n_data_per_client = (
                np.concatenate(self.args.data_obj.clnt_x, axis=0).shape[0]
                / self.args.n_clients
            )
            n_iter_per_epoch = np.ceil(n_data_per_client / self.args.batch_size)
            n_minibatch = (self.args.epoch * n_iter_per_epoch).astype(np.int64)
            # SCAFFOLD reports per iteration, so print_per is rescaled from
            # epochs to iterations. NOTE(review): this mutates the shared args
            # object (calling the factory twice rescales it twice) - kept
            # because other code may read the rescaled value; confirm.
            self.args.print_per = self.args.print_per * n_iter_per_epoch
            global_learning_rate = 1
            # _common_args() is read after the rescale so SCAFFOLD gets the new print_per
            return SCAFFOLD(*self._common_args(), n_minibatch, global_learning_rate)

        return FedAvg(*self._common_args())

# FedAvg
class FedAvg(Algorithm):
    """Federated Averaging (FedAvg).

    Each selected client warm-starts from the server model and runs plain local
    SGD; the server then forms a weighted average of the client parameters.
    """

    def __init__(
        self,
        lr,
        lr_decay_per_round,
        batch_size,
        epoch,
        weight_decay,
        model_func,
        n_param,
        max_norm,
        noiseless,
        dataset_name,
        save_period,
        print_per,
    ):
        super().__init__(
            "FedAvg",
            lr,
            lr_decay_per_round,
            batch_size,
            epoch,
            weight_decay,
            model_func,
            n_param,
            max_norm,
            noiseless,
            dataset_name,
            save_period,
            print_per,
        )

    def local_train(self, client: Client, inputs: dict):
        """Run local SGD on ``client`` starting from the current server model.

        Reads ``inputs["avg_model"]`` (model whose parameters initialize the
        client model) and ``inputs["curr_round"]`` (round index used for the
        learning-rate decay). Stores the trained model in ``client.model`` and
        its flattened parameters in ``client.client_param``.
        """
        self.device = client.device

        client.model = self.model_func().to(self.device)
        client.model.load_state_dict(
            copy.deepcopy(dict(inputs["avg_model"].named_parameters()))
        )

        # Make every parameter trainable before local SGD
        for params in client.model.parameters():
            params.requires_grad = True

        print("client model parameters: ", get_model_params([client.model], self.n_param)[0],)

        client.model = self.__train_model(
            client.model,
            client.train_data_X,
            client.train_data_Y,
            inputs["curr_round"]
        )
        updated_param = get_model_params([client.model], self.n_param)[0]
        print("updated model parameters: ", updated_param)

        client.client_param = updated_param

    def __train_model(self, model, trn_x, trn_y, curr_round):
        """Train ``model`` for ``self.epoch`` epochs of mini-batch SGD, then freeze it.

        The learning rate is ``lr * lr_decay_per_round ** curr_round`` and
        gradients are clipped to ``max_norm``. Returns the model in eval mode
        with ``requires_grad`` disabled on every parameter.
        """
        decayed_lr = self.lr * (self.lr_decay_per_round ** curr_round)

        n_trn = trn_x.shape[0]
        trn_gen = data.DataLoader(
            Dataset(trn_x, trn_y, train = True, dataset_name = self.dataset_name),
            batch_size = self.batch_size,
            shuffle = True,
        )
        # Summed loss divided by the batch size below = per-sample mean
        loss_fn = torch.nn.CrossEntropyLoss(reduction = "sum")

        optimizer = torch.optim.SGD(
            model.parameters(), lr = decayed_lr, weight_decay = self.weight_decay
        )
        model.train()
        model = model.to(self.device)

        for e in range(self.epoch):
            # training
            trn_gen_iter = iter(trn_gen)
            for _ in range(int(np.ceil(n_trn / self.batch_size))):
                batch_x, batch_y = next(trn_gen_iter)
                batch_x = batch_x.to(self.device)
                batch_y = batch_y.to(self.device)

                y_pred = model(batch_x)
                loss = loss_fn(y_pred, batch_y.reshape(-1).long())
                loss = loss / batch_y.size(0)

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    parameters = model.parameters(), max_norm = self.max_norm
                )
                optimizer.step()

            if (e + 1) % self.print_per == 0:
                loss_trn, acc_trn = get_acc_loss(
                    trn_x, trn_y, model, self.dataset_name, self.weight_decay
                )
                print(
                    "epoch %3d, training accuracy: %.4f, loss: %.4f" % (e+1, acc_trn, loss_trn)
                )
                model.train()

        # Freeze the trained model
        for params in model.parameters():
            params.requires_grad = False
        model.eval()

        return model

    def aggregate(self, server: Server, inputs: dict):
        """Server-side FedAvg update.

        Reads ``inputs["clients_list"]``, ``inputs["selected_clients_idx"]``,
        ``inputs["weight_list"]`` (one weight per client) and, when
        ``noiseless`` is False, the pre-aggregated ``inputs["avg_model_param"]``.
        Sets ``server.avg_model`` (weighted average over the selected clients,
        or the provided parameters) and ``server.all_model`` (weighted average
        over all clients).
        """
        clients_list = inputs["clients_list"]
        selected_clients_idx = inputs["selected_clients_idx"]
        weight_list = inputs["weight_list"]

        clients_param_list = np.array([client.client_param for client in clients_list])
        # Column vector so it broadcasts against the (n_clients, n_param) matrix
        weight_list = np.asarray(weight_list).reshape((-1, 1))
        print("weight_list.shape = ", weight_list.shape)

        avg_model_param = (
            inputs["avg_model_param"]
            if not self.noiseless
            else np.sum(
                clients_param_list[selected_clients_idx]
                * weight_list[selected_clients_idx]
                / np.sum(weight_list[selected_clients_idx]),
                axis = 0,
            )
        )

        print("avg_model_param = ", avg_model_param)

        server.avg_model = set_model(self.model_func(), avg_model_param, server.device)
        server.all_model = set_model(
            self.model_func(),
            np.sum(clients_param_list * weight_list / np.sum(weight_list), axis = 0),
            server.device,
        )

    # Backward-compatible alias for the previously misspelled method name
    aggreate = aggregate

# FedDyn
class FedDyn(Algorithm):
    """Federated Dynamic Regularization (FedDyn).

    Each client minimizes its empirical loss plus a linear correction term and
    an L2 term (applied through SGD ``weight_decay``), both scaled by
    ``alpha_coef / client.weight``; the client's correction vector is updated
    after every local round.
    """

    def __init__(
        self,
        lr,
        lr_decay_per_round,
        batch_size,
        epoch,
        weight_decay,
        model_func,
        n_param,
        max_norm,
        noiseless,
        dataset_name,
        save_period,
        print_per,
        alpha_coef,
    ):
        super().__init__(
            "FedDyn",
            lr,
            lr_decay_per_round,
            batch_size,
            epoch,
            weight_decay,
            model_func,
            n_param,
            max_norm,
            noiseless,
            dataset_name,
            save_period,
            print_per,
        )

        self.alpha_coef = alpha_coef  # base regularization strength

    # override
    def local_train(self, client: Client, inputs: dict):
        """Run one FedDyn local update for ``client``.

        Reads from ``inputs``:
          - ``"cloud_model"``: model used as the warm start;
          - ``"local_param"``: this client's correction vector, updated in place;
          - ``"cloud_model_param_tensor"``: cloud parameters used inside the
            training loss (presumably a torch tensor on the training device);
          - ``"cloud_model_param"``: cloud parameters as an array;
          - ``"curr_round"``: round index for the learning-rate decay.
        Stores the flattened trained parameters in ``client.client_param``.
        """
        self.device = client.device

        client.model = self.model_func().to(self.device)
        model = client.model
        # Warm start from the current cloud model
        model.load_state_dict(
            copy.deepcopy(dict(inputs["cloud_model"].named_parameters()))
        )
        for params in model.parameters():
            params.requires_grad = True

        # Scale the regularization strength down by the client's weight
        alpha_coef_adpt = self.alpha_coef / client.weight  # adaptive alpha coef
        local_param_list_curr = torch.tensor(
            inputs["local_param"], dtype=torch.float32, device=self.device
        )  # = local_grad_vector
        print("local_param_list_curr = ", local_param_list_curr)
        print("cloud_model_param_tensor = ", inputs["cloud_model_param_tensor"])
        client.model = self.__train_model(
            model,
            alpha_coef_adpt,
            inputs["cloud_model_param_tensor"],
            local_param_list_curr,
            client.train_data_X,
            client.train_data_Y,
            inputs["curr_round"],
        )
        # Flattened model parameters after the FedDyn local update
        curr_model_par = get_model_params([client.model], self.n_param)[0]
        print("curr_model_par = ", curr_model_par)

        # No need to scale up hist terms. They are -\nabla/alpha and alpha is already scaled.
        # `+=` updates the caller's array in place when it is a numpy array/view
        # (NOTE(review): assumes inputs["local_param"] is such an array - confirm).
        inputs["local_param"] += (
            curr_model_par - inputs["cloud_model_param"]
        )

        client.client_param = curr_model_par

    def __train_model(
        self,
        model,
        alpha_coef_adpt,
        avg_mdl_param,
        local_grad_vector,
        trn_x,
        trn_y,
        curr_round,
    ):
        """Minimize the FedDyn local objective with mini-batch SGD, then freeze the model.

        Per batch the loss is the mean cross-entropy plus
        ``alpha_coef_adpt * <theta, local_grad_vector - avg_mdl_param>``; the
        quadratic term comes from SGD ``weight_decay = alpha_coef_adpt + weight_decay``.
        Returns the model in eval mode with ``requires_grad`` disabled.
        """
        decayed_lr = self.lr * (self.lr_decay_per_round**curr_round)

        n_trn = trn_x.shape[0]
        trn_gen = data.DataLoader(
            Dataset(trn_x, trn_y, train=True, dataset_name=self.dataset_name),
            batch_size=self.batch_size,
            shuffle=True,
        )
        loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")

        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=decayed_lr,
            weight_decay=alpha_coef_adpt + self.weight_decay,
        )
        model.train()
        model = model.to(self.device)

        for e in range(self.epoch):
            # Training
            epoch_loss = 0
            trn_gen_iter = iter(trn_gen)
            for _ in range(int(np.ceil(n_trn / self.batch_size))):
                batch_x, batch_y = next(trn_gen_iter)
                batch_x = batch_x.to(self.device)
                batch_y = batch_y.to(self.device)

                y_pred = model(batch_x)

                ## Get f_i estimate
                loss_f_i = loss_fn(y_pred, batch_y.reshape(-1).long())
                loss_f_i = loss_f_i / batch_y.size(0)

                # Linear penalty on the current parameter estimates: flatten all
                # parameters into one vector with a single concatenation
                local_par_list = torch.cat(
                    [param.reshape(-1) for param in model.parameters()], 0
                )

                loss_algo = alpha_coef_adpt * torch.sum(
                    local_par_list * (-avg_mdl_param + local_grad_vector)
                )
                loss = loss_f_i + loss_algo

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    parameters=model.parameters(), max_norm=self.max_norm
                )  # Clip gradients
                optimizer.step()
                epoch_loss += loss.item() * batch_y.size(0)

            if (e + 1) % self.print_per == 0:
                epoch_loss /= n_trn
                if self.weight_decay is not None:
                    # Add the L2 term (applied by the optimizer) to complete f_i
                    param_vec = get_model_params([model], self.n_param)
                    epoch_loss += (
                        (alpha_coef_adpt + self.weight_decay)
                        / 2
                        * np.sum(param_vec * param_vec)
                    )
                print("Epoch %3d, Training Loss: %.4f" % (e + 1, epoch_loss))
                model.train()

        # Freeze model
        for params in model.parameters():
            params.requires_grad = False
        model.eval()

        return model

    # override
    def aggregate(self, server: Server, inputs: dict):
        """Server-side FedDyn update.

        Reads ``inputs["clients_list"]``, ``inputs["selected_clnts_idx"]``,
        ``inputs["local_param_list"]`` and, when ``noiseless`` is False,
        ``inputs["avg_mdl_param"]``. Sets ``server.avg_model`` (mean over the
        selected clients, or the provided parameters) and ``server.all_model``
        (mean over all clients), and writes the new cloud parameters and model
        to ``inputs["cloud_model_param"]`` / ``inputs["cloud_model"]``.
        """
        clients_list = inputs["clients_list"]
        selected_clnts_idx = inputs["selected_clnts_idx"]

        clients_param_list = np.array([client.client_param for client in clients_list])

        avg_mdl_param = (
            inputs["avg_mdl_param"]
            if not self.noiseless
            else np.mean(clients_param_list[selected_clnts_idx], axis=0)
        )

        print("avg_mdl_param = ", avg_mdl_param)

        # Cloud model = average model shifted by the mean correction term
        inputs["cloud_model_param"] = avg_mdl_param + np.mean(
            inputs["local_param_list"], axis=0
        )

        server.avg_model = set_model(self.model_func(), avg_mdl_param, server.device)
        server.all_model = set_model(
            self.model_func(), np.mean(clients_param_list, axis=0), server.device
        )
        inputs["cloud_model"] = set_model(
            self.model_func().to(server.device),
            inputs["cloud_model_param"],
            server.device,
        )


# Dataset

In [None]:
import os

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms


class DatasetObject:
    """Download a torchvision dataset and partition it across federated clients.

    The partition is cached under ``<data_path>/<name>/`` as ``clnt_x.npy``,
    ``clnt_y.npy``, ``tst_x.npy`` and ``tst_y.npy``. When that folder already
    exists it is loaded instead of being regenerated.

    Args:
        dataset: "mnist", "cifar10", "cifar100" or "emnist" ("fashion_mnist"
            is recognised only when loading an existing cached partition).
        n_client: number of clients.
        rule: "dirichlet" (label-skewed) or "iid".
        unbalanced_sgm: sigma of the log-normal client-size distribution;
            0 gives equal-sized clients.
        rule_arg: Dirichlet concentration parameter for ``rule == "dirichlet"``.
    """

    # (channels, width, height, n_cls) of every dataset this class knows about
    _META = {
        "mnist": (1, 28, 28, 10),
        "cifar10": (3, 32, 32, 10),
        "cifar100": (3, 32, 32, 100),
        "fashion_mnist": (1, 28, 28, 10),
        "emnist": (1, 28, 28, 10),
    }

    def __init__(self, dataset, n_client, rule, unbalanced_sgm=0, rule_arg=""):
        self.dataset = dataset
        self.n_client = n_client
        self.rule = rule
        self.rule_arg = rule_arg
        rule_arg_str = rule_arg if isinstance(rule_arg, str) else "%.3f" % rule_arg
        self.name = "%s_%d_%s_%s" % (
            self.dataset,
            self.n_client,
            self.rule,
            rule_arg_str,
        )
        self.name += "_%f" % unbalanced_sgm if unbalanced_sgm != 0 else ""
        self.unbalanced_sgm = unbalanced_sgm
        self.data_path = "Data"
        self.set_data()

    def set_data(self):
        """Build (or load from the cache) the per-client partition and print stats."""
        folder = "%s/%s" % (self.data_path, self.name)
        # Prepare data if not ready
        if not os.path.exists(folder):
            self._prepare_data(folder)
        else:
            self._load_cached(folder)
        self._print_class_frequencies()

    # ------------------------------------------------------------------ helpers

    def _set_meta(self):
        """Set channels/width/height/n_cls for a known dataset name."""
        meta = self._META.get(self.dataset)
        if meta is not None:
            self.channels, self.width, self.height, self.n_cls = meta

    @staticmethod
    def _stack_client_arrays(arrays):
        """Stack per-client arrays, falling back to a 1-D object array when ragged.

        ``np.asarray`` on arrays of different shapes raises on NumPy >= 1.24,
        which happens whenever ``unbalanced_sgm != 0`` makes client sizes differ.
        """
        if len({np.shape(a) for a in arrays}) <= 1:
            return np.asarray(arrays)
        stacked = np.empty(len(arrays), dtype=object)
        for i, arr in enumerate(arrays):
            stacked[i] = arr
        return stacked

    def _prepare_data(self, folder):
        """Download the raw data, partition it across clients and cache it in ``folder``."""
        trn_x, trn_y, tst_x, tst_y = self._load_raw()

        # Shuffle Data
        rand_perm = np.random.permutation(len(trn_y))
        trn_x = trn_x[rand_perm]
        trn_y = trn_y[rand_perm]

        self.trn_x = trn_x
        self.trn_y = trn_y
        self.tst_x = tst_x
        self.tst_y = tst_y

        clnt_data_list = self._client_sizes(len(trn_y))

        if self.rule == "dirichlet":
            clnt_x, clnt_y = self._split_dirichlet(trn_x, trn_y, clnt_data_list)
        elif (
            self.rule == "iid"
            and self.dataset == "cifar100"
            and self.unbalanced_sgm == 0
        ):
            clnt_x, clnt_y = self._split_iid_cifar100(trn_x, trn_y)
        elif self.rule == "iid":
            clnt_x, clnt_y = self._split_iid(trn_x, trn_y, clnt_data_list)
        else:
            raise ValueError("Unknown partition rule: %s" % self.rule)

        self.clnt_x = clnt_x
        self.clnt_y = clnt_y

        # Save data
        os.makedirs(folder, exist_ok=True)

        np.save("%s/clnt_x.npy" % folder, clnt_x)
        np.save("%s/clnt_y.npy" % folder, clnt_y)

        np.save("%s/tst_x.npy" % folder, tst_x)
        np.save("%s/tst_y.npy" % folder, tst_y)

    def _load_raw(self):
        """Download the raw splits and return ``(trn_x, trn_y, tst_x, tst_y)`` as numpy.

        Labels are reshaped to (n, 1).

        Raises:
            ValueError: if the dataset cannot be downloaded by this class.
        """
        root = "%s/Raw" % self.data_path
        if self.dataset == "mnist":
            transform = transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            )
            trnset = torchvision.datasets.MNIST(
                root=root, train=True, download=True, transform=transform
            )
            tstset = torchvision.datasets.MNIST(
                root=root, train=False, download=True, transform=transform
            )

        elif self.dataset == "cifar10":
            transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(
                        mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262]
                    ),
                ]
            )
            trnset = torchvision.datasets.CIFAR10(
                root=root, train=True, download=True, transform=transform
            )
            tstset = torchvision.datasets.CIFAR10(
                root=root, train=False, download=True, transform=transform
            )

        elif self.dataset == "cifar100":
            print(self.dataset)
            # mean and std are validated here: https://gist.github.com/weiaicunzai/e623931921efefd4c331622c344d8151
            transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(
                        mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761]
                    ),
                ]
            )
            trnset = torchvision.datasets.CIFAR100(
                root=root, train=True, download=True, transform=transform
            )
            tstset = torchvision.datasets.CIFAR100(
                root=root, train=False, download=True, transform=transform
            )

        elif self.dataset == "emnist":
            transform = transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            )
            trnset = torchvision.datasets.EMNIST(
                root=root, split="letters", train=True, download=True, transform=transform
            )
            tstset = torchvision.datasets.EMNIST(
                root=root, split="letters", train=False, download=True, transform=transform
            )

            # Keep only labels 1..10 and shift them to 0..9
            filtered_indices = trnset.targets.clone().detach() <= 10
            trnset.targets = trnset.targets[filtered_indices] - 1
            trnset.data = trnset.data[filtered_indices]

            filtered_indices = tstset.targets.clone().detach() <= 10
            tstset.targets = tstset.targets[filtered_indices] - 1
            tstset.data = tstset.data[filtered_indices]

        else:
            raise ValueError("Unsupported dataset for download: %s" % self.dataset)

        self._set_meta()

        # Load each split as a single full-size batch
        num_workers = 0 if self.dataset == "cifar100" else 1
        trn_load = torch.utils.data.DataLoader(
            trnset, batch_size=len(trnset), shuffle=False, num_workers=num_workers
        )
        tst_load = torch.utils.data.DataLoader(
            tstset, batch_size=len(tstset), shuffle=False, num_workers=num_workers
        )
        trn_itr = iter(trn_load)
        tst_itr = iter(tst_load)
        # labels are of shape (n_data,)
        trn_x, trn_y = next(trn_itr)
        tst_x, tst_y = next(tst_itr)

        return (
            trn_x.numpy(),
            trn_y.numpy().reshape(-1, 1),
            tst_x.numpy(),
            tst_y.numpy().reshape(-1, 1),
        )

    def _client_sizes(self, n_total):
        """Return the number of training samples assigned to each client."""
        n_data_per_clnt = int(n_total / self.n_client)
        if self.unbalanced_sgm == 0:
            return (np.ones(self.n_client) * n_data_per_clnt).astype(int)

        # Draw from lognormal distribution
        clnt_data_list = np.random.lognormal(
            mean=np.log(n_data_per_clnt),
            sigma=self.unbalanced_sgm,
            size=self.n_client,
        )
        clnt_data_list = (clnt_data_list / np.sum(clnt_data_list) * n_total).astype(int)
        diff = np.sum(clnt_data_list) - n_total

        # Add/Subtract the excess number starting from first client
        if diff != 0:
            for clnt_i in range(self.n_client):
                if clnt_data_list[clnt_i] > diff:
                    clnt_data_list[clnt_i] -= diff
                    break
        return clnt_data_list

    def _split_dirichlet(self, trn_x, trn_y, clnt_data_list):
        """Label-skewed split: each client's classes follow a Dirichlet prior.

        ``clnt_data_list`` is consumed (decremented to zero) in place.
        """
        cls_priors = np.random.dirichlet(
            alpha=[self.rule_arg] * self.n_cls, size=self.n_client
        )
        prior_cumsum = np.cumsum(cls_priors, axis=1)
        idx_list = [np.where(trn_y == i)[0] for i in range(self.n_cls)]
        cls_amount = [len(idx_list[i]) for i in range(self.n_cls)]

        clnt_x = [
            np.zeros(
                (clnt_data_list[clnt__], self.channels, self.height, self.width)
            ).astype(np.float32)
            for clnt__ in range(self.n_client)
        ]
        clnt_y = [
            np.zeros((clnt_data_list[clnt__], 1)).astype(np.int64)
            for clnt__ in range(self.n_client)
        ]

        while np.sum(clnt_data_list) != 0:
            curr_clnt = np.random.randint(self.n_client)
            # If current node is full resample a client
            print("Remaining Data: %d" % np.sum(clnt_data_list))
            if clnt_data_list[curr_clnt] <= 0:
                continue
            clnt_data_list[curr_clnt] -= 1
            curr_prior = prior_cumsum[curr_clnt]
            while True:
                cls_label = np.argmax(np.random.uniform() <= curr_prior)
                # Redraw class label if trn_y is out of that class
                if cls_amount[cls_label] <= 0:
                    continue
                cls_amount[cls_label] -= 1
                sample_idx = idx_list[cls_label][cls_amount[cls_label]]
                clnt_x[curr_clnt][clnt_data_list[curr_clnt]] = trn_x[sample_idx]
                clnt_y[curr_clnt][clnt_data_list[curr_clnt]] = trn_y[sample_idx]
                break

        clnt_x = self._stack_client_arrays(clnt_x)
        clnt_y = self._stack_client_arrays(clnt_y)

        cls_means = np.zeros((self.n_client, self.n_cls))
        for clnt in range(self.n_client):
            for cls in range(self.n_cls):
                cls_means[clnt, cls] = np.mean(clnt_y[clnt] == cls)
        prior_real_diff = np.abs(cls_means - cls_priors)
        print("--- Max deviation from prior: %.4f" % np.max(prior_real_diff))
        print("--- Min deviation from prior: %.4f" % np.min(prior_real_diff))
        return clnt_x, clnt_y

    def _split_iid_cifar100(self, trn_x, trn_y):
        """Perfect IID partition for CIFAR-100: every client gets equal samples per class.

        Raises:
            ValueError: if the number of clients does not divide the per-class
                sample count (500 for the CIFAR-100 training set).
        """
        n_per_cls = len(trn_y) // self.n_cls
        if n_per_cls % self.n_client != 0:
            raise ValueError(
                "n_client=%d must divide the per-class sample count %d"
                % (self.n_client, n_per_cls)
            )
        idx = np.argsort(trn_y[:, 0])
        n_data_per_clnt = len(trn_y) // self.n_client
        # clnt_x dtype needs to be float32, the same as weights
        clnt_x = np.zeros(
            (self.n_client, n_data_per_clnt, self.channels, self.height, self.width),
            dtype=np.float32,
        )
        # NOTE(review): labels kept as float32 here (int64 in the other splits)
        clnt_y = np.zeros((self.n_client, n_data_per_clnt, 1), dtype=np.float32)
        trn_x = trn_x[idx]  # sorted by class
        trn_y = trn_y[idx]
        n_cls_sample_per_device = n_data_per_clnt // self.n_cls
        for i in range(self.n_client):  # devices
            for j in range(self.n_cls):  # class
                dst = slice(n_cls_sample_per_device * j, n_cls_sample_per_device * (j + 1))
                src = slice(
                    n_per_cls * j + n_cls_sample_per_device * i,
                    n_per_cls * j + n_cls_sample_per_device * (i + 1),
                )
                clnt_x[i, dst] = trn_x[src]
                clnt_y[i, dst] = trn_y[src]
        return clnt_x, clnt_y

    def _split_iid(self, trn_x, trn_y, clnt_data_list):
        """IID split: consecutive slices of the (already shuffled) training data."""
        bounds = np.concatenate(([0], np.cumsum(clnt_data_list)))
        clnt_x = [trn_x[bounds[i]:bounds[i + 1]] for i in range(self.n_client)]
        clnt_y = [trn_y[bounds[i]:bounds[i + 1]] for i in range(self.n_client)]
        return self._stack_client_arrays(clnt_x), self._stack_client_arrays(clnt_y)

    def _load_cached(self, folder):
        """Load a previously saved partition from ``folder``."""
        print("Data is already downloaded in the folder.")
        self.clnt_x = np.load("%s/clnt_x.npy" % folder, allow_pickle=True)
        self.clnt_y = np.load("%s/clnt_y.npy" % folder, allow_pickle=True)
        self.n_client = len(self.clnt_x)

        self.tst_x = np.load("%s/tst_x.npy" % folder, allow_pickle=True)
        self.tst_y = np.load("%s/tst_y.npy" % folder, allow_pickle=True)

        self._set_meta()

    def _print_class_frequencies(self):
        """Print per-client and test-set class frequencies and sample counts."""
        print("Class frequencies:")
        count = 0
        for clnt in range(self.n_client):
            print(
                "Client %3d: " % clnt
                + ", ".join(
                    [
                        "%.3f" % np.mean(self.clnt_y[clnt] == cls)
                        for cls in range(self.n_cls)
                    ]
                )
                + ", Amount:%d" % self.clnt_y[clnt].shape[0]
            )
            count += self.clnt_y[clnt].shape[0]

        print("Total Amount:%d" % count)
        print("--------")

        print(
            "      Test: "
            + ", ".join(
                ["%.3f" % np.mean(self.tst_y == cls) for cls in range(self.n_cls)]
            )
            + ", Amount:%d" % self.tst_y.shape[0]
        )


def _as_client_array(per_client):
    """Pack a list of per-client arrays into one numpy array.

    If every entry has the same shape this is exactly ``np.asarray(per_client)``
    (a dense array with a leading client axis). Otherwise ``np.asarray`` raises
    ``ValueError`` on NumPy >= 1.24, so a 1-D ``dtype=object`` array holding
    one array per client is built instead.
    """
    if len({np.shape(arr) for arr in per_client}) <= 1:
        return np.asarray(per_client)
    packed = np.empty(len(per_client), dtype=object)
    for idx, arr in enumerate(per_client):
        packed[idx] = arr
    return packed


def generate_syn_logistic(
    dimension,
    n_clnt,
    n_cls,
    avg_data=4,
    alpha=1.0,
    beta=0.0,
    theta=0.0,
    iid_sol=False,
    iid_dat=False,
):
    """Generate a synthetic federated multi-class logistic-regression dataset.

    Each client i draws ``n_i`` features from N(mean_x[i], cov_x), where
    cov_x is diagonal with variances (j+1)^-1.2. Labels are
    ``argmax(x @ W_i + b_i)``.

    Args:
        dimension: number of features per sample.
        n_clnt: number of clients.
        n_cls: number of classes.
        avg_data: client sample counts are lognormal with median
            ``avg_data + 1e-3`` and are then truncated to int.
        alpha: std of the per-client mean ``mean_W[i]`` from which the
            client's true weights/bias are drawn (heterogeneity of minimizers).
        beta: std of the per-client offset ``B[i]`` from which the client's
            feature mean is drawn (heterogeneity of feature distributions).
        theta: sigma of the lognormal sample-count distribution
            (0 gives every client ``int(avg_data + 1e-3)`` samples).
        iid_sol: if True, all clients share one W, b drawn from N(0, 1).
        iid_dat: if True, every client's feature mean is 0.

    Returns:
        ``(data_x, data_y)``. Client i's features have shape
        ``(n_i, dimension)`` and its labels have shape ``(n_i, 1)``. When all
        ``n_i`` are equal these are dense 3-D arrays with a leading client
        axis; otherwise they are 1-D object arrays with one entry per client.
    """
    # alpha is for minimizer of each client
    # beta  is for distribution of points
    # theta is for number of data points

    diagonal = np.zeros(dimension)
    for j in range(dimension):
        diagonal[j] = np.power((j + 1), -1.2)
    cov_x = np.diag(diagonal)

    samples_per_user = (
        np.random.lognormal(mean=np.log(avg_data + 1e-3), sigma=theta, size=n_clnt)
    ).astype(int)
    print("samples per user")
    print(samples_per_user)
    print("sum %d" % np.sum(samples_per_user))

    data_x = list(range(n_clnt))
    data_y = list(range(n_clnt))

    mean_W = np.random.normal(0, alpha, n_clnt)
    B = np.random.normal(0, beta, n_clnt)

    mean_x = np.zeros((n_clnt, dimension))

    if not iid_dat:  # If IID then make all 0s.
        for i in range(n_clnt):
            mean_x[i] = np.random.normal(B[i], 1, dimension)

    # NOTE: this draw is always overwritten, either by the iid_sol draw below
    # or by client 0's draw inside the loop. It is kept only so the global RNG
    # stream (and therefore previously generated/cached datasets) is unchanged.
    sol_W = np.random.normal(mean_W[0], 1, (dimension, n_cls))
    sol_B = np.random.normal(mean_W[0], 1, (1, n_cls))

    if iid_sol:  # Then make vectors come from 0 mean distribution
        sol_W = np.random.normal(0, 1, (dimension, n_cls))
        sol_B = np.random.normal(0, 1, (1, n_cls))

    for i in range(n_clnt):
        if not iid_sol:
            sol_W = np.random.normal(mean_W[i], 1, (dimension, n_cls))
            sol_B = np.random.normal(mean_W[i], 1, (1, n_cls))

        data_x[i] = np.random.multivariate_normal(mean_x[i], cov_x, samples_per_user[i])
        data_y[i] = np.argmax((np.matmul(data_x[i], sol_W) + sol_B), axis=1).reshape(
            -1, 1
        )

    # Clients usually hold different numbers of samples when theta > 0, so the
    # per-client arrays may be ragged; plain np.asarray would raise on them.
    return _as_client_array(data_x), _as_client_array(data_y)


class DatasetSynthetic:
    """Synthetic federated logistic-regression dataset with an on-disk cache.

    The dataset is keyed by a name built from ``name_prefix`` and every
    generation parameter. If ``<data_path>/<name>/data_x.npy`` and
    ``data_y.npy`` both exist, they are loaded. Otherwise the data is generated
    with ``generate_syn_logistic`` and saved there.

    Attributes:
        dataset: always ``"synt"``.
        name: cache key / directory name.
        n_cls: number of classes.
        clnt_x, clnt_y: per-client features / labels.
        tst_x, tst_y: all clients' data concatenated (used as the test set).
        n_client: number of clients.
    """

    def __init__(
        self,
        alpha,
        beta,
        theta,
        iid_sol,
        iid_data,
        n_dim,
        n_clnt,
        n_cls,
        avg_data,
        name_prefix,
        data_path="Data",
    ):
        """Load the cached dataset or generate and cache it.

        ``data_path`` is the cache root; it defaults to ``"Data"``, the
        previously hard-coded location.
        """
        self.dataset = "synt"
        self.name = name_prefix + "_"
        self.name += "%d_%d_%d_%d_%f_%f_%f_%s_%s" % (
            n_dim,
            n_clnt,
            n_cls,
            avg_data,
            alpha,
            beta,
            theta,
            iid_sol,
            iid_data,
        )
        # Mirrors DatasetObject, which also exposes the class count.
        self.n_cls = n_cls

        data_dir = os.path.join(data_path, self.name)
        x_file = os.path.join(data_dir, "data_x.npy")
        y_file = os.path.join(data_dir, "data_y.npy")

        # Check for the files, not just the directory. A run interrupted after
        # creating the directory would otherwise leave a cache that every
        # later run tries (and fails) to load.
        if not (os.path.isfile(x_file) and os.path.isfile(y_file)):
            # Generate data
            print("Sythetize")
            data_x, data_y = generate_syn_logistic(
                dimension=n_dim,
                n_clnt=n_clnt,
                n_cls=n_cls,
                avg_data=avg_data,
                alpha=alpha,
                beta=beta,
                theta=theta,
                iid_sol=iid_sol,
                iid_dat=iid_data,
            )
            # makedirs: the parent data_path may not exist yet, and exist_ok
            # tolerates a leftover directory from an interrupted run.
            os.makedirs(data_dir, exist_ok=True)
            np.save(x_file, data_x)
            np.save(y_file, data_y)
        else:
            # Load data. allow_pickle is needed because ragged per-client data
            # is stored as an object array. Only load caches you trust.
            print("Load")
            data_x = np.load(x_file, allow_pickle=True)
            data_y = np.load(y_file, allow_pickle=True)

        # Per-client class frequencies.
        for clnt in range(n_clnt):
            print(
                ", ".join(["%.4f" % np.mean(data_y[clnt] == t) for t in range(n_cls)])
            )

        self.clnt_x = data_x
        self.clnt_y = data_y

        self.tst_x = np.concatenate(self.clnt_x, axis=0)
        self.tst_y = np.concatenate(self.clnt_y, axis=0)
        self.n_client = len(data_x)
        print(self.clnt_x.shape)


class Dataset(torch.utils.data.Dataset):
    """Map-style torch dataset over in-memory client or test arrays.

    ``data_y`` defaults to ``True``. Any bool value is a sentinel for
    "no labels", and ``__getitem__`` then returns only the input.

    Behaviour by ``dataset_name``:

    * ``"mnist"``, ``"fashion_mnist"``, ``"emnist"``, ``"synt"``: ``data_x`` is
      converted once to a float tensor and indexed along its first axis.
      Labels become a float tensor.
    * ``"cifar10"``, ``"cifar100"``: images stay numpy arrays and are converted
      per item with ``transforms.ToTensor()`` after the first axis is moved
      last. Items are therefore assumed channel-first; the padding/crop code
      hard-codes 3x32x32. With ``train=True``, a random horizontal flip and a
      random 32x32 crop of a 4-pixel zero-padded image are each applied with
      probability 0.5. Labels become float32.
    * ``"shakespeare"``: ``data_x`` becomes a long tensor; labels a float tensor.

    Raises:
        ValueError: for any other ``dataset_name``.
    """

    def __init__(self, data_x, data_y=True, train=False, dataset_name=""):
        self.name = dataset_name
        # Any bool (the default True) marks the dataset as unlabelled.
        has_labels = not isinstance(data_y, bool)

        if self.name in ("mnist", "fashion_mnist", "synt", "emnist"):
            self.X_data = torch.tensor(data_x).float()
            self.y_data = torch.tensor(data_y).float() if has_labels else data_y

        elif self.name in ("cifar10", "cifar100"):
            self.train = train
            self.transform = transforms.Compose([transforms.ToTensor()])
            self.X_data = data_x
            self.y_data = data_y.astype("float32") if has_labels else data_y

        elif self.name == "shakespeare":
            self.X_data = torch.tensor(data_x).long()
            self.y_data = torch.tensor(data_y).float() if has_labels else data_y

        else:
            # Fail here rather than later with an AttributeError on X_data.
            raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

    def __len__(self):
        return len(self.X_data)

    def __getitem__(self, idx):
        unlabelled = isinstance(self.y_data, bool)

        if self.name in ("mnist", "fashion_mnist", "synt", "emnist"):
            X = self.X_data[idx, :]
            return X if unlabelled else (X, self.y_data[idx])

        if self.name in ("cifar10", "cifar100"):
            img = self.X_data[idx]
            if self.train:
                img = (
                    np.flip(img, axis=2).copy() if (np.random.rand() > 0.5) else img
                )  # Horizontal flip
                if np.random.rand() > 0.5:
                    # Random cropping from a zero-padded copy.
                    pad = 4
                    extended_img = np.zeros((3, 32 + pad * 2, 32 + pad * 2)).astype(
                        np.float32
                    )
                    extended_img[:, pad:-pad, pad:-pad] = img
                    dim_1, dim_2 = np.random.randint(pad * 2 + 1, size=2)
                    img = extended_img[:, dim_1 : dim_1 + 32, dim_2 : dim_2 + 32]
            # ToTensor expects channel-last input.
            img = self.transform(np.moveaxis(img, 0, -1))
            return img if unlabelled else (img, self.y_data[idx])

        if self.name == "shakespeare":
            x = self.X_data[idx]
            # Previously indexed y_data unconditionally, crashing when unlabelled.
            return x if unlabelled else (x, self.y_data[idx])

        raise ValueError(f"Unsupported dataset_name: {self.name!r}")


# Main

In [None]:
def _load_pretrained_state_dict(path_pt, device):
    """Load a checkpoint and return its parameter dict without "module." prefixes.

    Accepts a raw state dict or a dict wrapping one under ``"state_dict"`` or
    ``"model_state_dict"``. Raises RuntimeError for any other format.
    """
    ckpt = torch.load(path_pt, map_location=device)
    if isinstance(ckpt, dict) and "state_dict" in ckpt:
        sd = ckpt["state_dict"]
    elif isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        sd = ckpt["model_state_dict"]
    elif isinstance(ckpt, dict):
        sd = ckpt
    else:
        raise RuntimeError(f"Unsupported checkpoint format: {type(ckpt)} for {path_pt}")

    # Strip only a *leading* DataParallel "module." prefix. str.replace would
    # also mangle keys that merely contain "module." further inside.
    prefix = "module."
    return {
        (k[len(prefix):] if isinstance(k, str) and k.startswith(prefix) else k): v
        for k, v in sd.items()
    }


def run_experiment_for_pretrained(path_pt, tag, channel_type, args_override=None, malicious_ids=None, attack_params=None):
    """
    Run one full FedDyn experiment starting from the pretrained checkpoint at path_pt.

    Args:
        path_pt: checkpoint path (raw state dict, or a dict holding one under
            "state_dict" / "model_state_dict"; a DataParallel "module." prefix is stripped).
        tag: short string for filenames (e.g. 'feddyn_ch0').
        channel_type: 0 = use channel.generate() as is; 1 = same but with h_d
            (presumably the direct link) zeroed; 2 = use generate_rician_single_ris.
            Only used when args.noiseless is False.
        args_override: optional dict of attribute overrides for args_parser() values.
        malicious_ids: client indices that are adversarial (default: none).
        attack_params: dict passed to malicious clients
            (keys: poison_fraction, epsilon, poison_labels).

    Returns:
        The saved filepath returned by save_performance.

    Raises:
        ValueError: unsupported channel_type (non-noiseless mode) or
            act_prob <= 0 (noiseless mode, where it would loop forever).
        RuntimeError: unsupported checkpoint format. Errors from
            load_state_dict (e.g. key mismatch) are re-raised after a diagnostic print.
    """
    # config
    args = args_parser()
    if args_override:
        for k, v in args_override.items():
            setattr(args, k, v)
    args.channel_type = channel_type
    # This driver always runs FedDyn. Fix the algorithm BEFORE building
    # weight_list: the normalisation below is keyed on algorithm_name. Setting
    # it later left raw sample counts in weight_list whenever the parsed or
    # overridden args named a different algorithm.
    args.algorithm_name = "FedDyn"

    # Fail fast on bad settings instead of after data loading / mid-run.
    if not args.noiseless and args.channel_type not in (0, 1, 2):
        raise ValueError("Unsupported channel type")
    if args.noiseless and args.act_prob <= 0:
        raise ValueError("act_prob must be > 0 in noiseless mode")

    # Data
    data_obj = DatasetObject(dataset="cifar10", n_client=args.n_clients, rule=args.rule, unbalanced_sgm=0)
    client_x_all = data_obj.clnt_x
    client_y_all = data_obj.clnt_y
    # Pooled training data, used for the train metrics in evaluate_performance.
    cent_x = np.concatenate(client_x_all, axis=0)
    cent_y = np.concatenate(client_y_all, axis=0)

    # weights: per-client sample counts, rescaled to mean 1 for FedDyn/SCAFFOLD.
    weight_list = np.asarray([len(client_y_all[i]) for i in range(args.n_clients)])
    if args.algorithm_name in ("FedDyn", "SCAFFOLD"):
        weight_list = weight_list / np.sum(weight_list) * args.n_clients

    # device & model factory
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def model_func():
        return Model(args.model_name)

    # load pretrained checkpoint robustly
    print(f"\nLoading checkpoint {path_pt} for tag {tag} -> device {device}")
    sd_clean = _load_pretrained_state_dict(path_pt, device)

    init_model = model_func().to(device)
    try:
        init_model.load_state_dict(sd_clean)
    except Exception as e:
        print("Failed to load checkpoint into model:", e)
        # provide diagnostic (the keys actually offered to load_state_dict)
        print("Sample checkpoint keys:", list(sd_clean.keys())[:50])
        raise
    else:
        print("Loaded pretrained weights into init_model OK.")

    init_par_list = get_model_params([init_model])[0]
    n_param = len(init_par_list)
    args.n_param = n_param

    np.random.seed(args.rand_seed)

    # Channel setup (use args defaults and some constants)
    fc = 915 * 10**6  # presumably the carrier frequency in Hz (915 MHz)
    # Antenna gains given in dB and converted to linear scale: 10^(dB/10).
    BS_Gain = 10 ** (5.0 / 10)
    RIS_Gain = 10 ** (5.0 / 10)
    User_Gain = 10 ** (0.0 / 10)
    dimen_RIS = 1.0 / 10
    # Presumably 3-D positions of the base station and RIS -- confirm units.
    BS = np.array([-50, 0, 10])
    RIS = np.array([0, 0, 10])
    x0 = np.ones([args.n_clients], dtype=int)

    channel = Channel(
        SNR=args.SNR,
        n_clients=args.n_clients,
        location_range=args.location_range,
        fc=fc,
        alpha_direct=args.alpha_direct,
        n_RIS_ele=args.n_RIS_ele,
        n_receive_ant=args.n_receive_ant,
        User_Gain=User_Gain,
        x0=x0,
        BS=BS,
        BS_Gain=BS_Gain,
        RIS=RIS,
        RIS_Gain=RIS_Gain,
        dimen_RIS=dimen_RIS,
    )

    gibbs = Gibbs(
        n_clients=args.n_clients,
        n_receive_ant=args.n_receive_ant,
        n_RIS_ele=args.n_RIS_ele,
        Jmax=args.Jmax,
        weight_list=weight_list,
        tau=args.tau,
        nit=args.nit,
        threshold=args.threshold,
    )

    air_comp = AirComp(
        n_receive_ant=args.n_receive_ant,
        weight_list=weight_list,
        transmit_power=args.transmit_power,
    )

    # Algorithm / Server setup (algorithm_name was fixed to FedDyn above)
    args.data_obj = data_obj
    args.air_comp = air_comp
    args.model_func = model_func
    args.init_model = init_model
    algorithm_factory = AlgorithmFactory(args)
    algorithm = algorithm_factory.create_algorithm(args.algorithm_name)

    # malicious setup default (if user didn't pass)
    if malicious_ids is None:
        malicious_ids = []  # default no attackers
    if attack_params is None:
        attack_params = {"poison_fraction": 0.0, "epsilon": 0.0, "poison_labels": None}
    malicious_set = set(malicious_ids)

    # create clients list (each starts with init params); an object ndarray so
    # it can be fancy-indexed with the selected indices each round.
    clients_list = np.array([
        Client(
            algorithm=algorithm,
            device=device,
            weight=float(weight_list[i]),
            train_data_X=client_x_all[i],
            train_data_Y=client_y_all[i],
            model=init_model,                     # template; local_train will re-init model
            client_param=np.copy(init_par_list),
            malicious=(i in malicious_set),
            attack_params=attack_params if (i in malicious_set) else None,
            client_id=i
        )
        for i in range(args.n_clients)
    ])

    # Save initial model for reproducibility
    out_init_path = f"Output/{data_obj.name}/{tag}_init_mdl.pt"
    os.makedirs(os.path.dirname(out_init_path), exist_ok=True)
    torch.save(init_model.state_dict(), out_init_path)
    print("Saved init model to", out_init_path)

    # FedDyn-specific init. state_dict() (unlike named_parameters()) also
    # carries buffers, so models with e.g. BatchNorm stats load strictly too.
    if args.algorithm_name == "FedDyn":
        local_param_list = np.zeros((args.n_clients, n_param)).astype("float32")
        cloud_model = model_func().to(device)
        cloud_model.load_state_dict(init_model.state_dict())
        cloud_model_param = get_model_params([cloud_model], n_param)[0]

    # performance buffers
    trn_perf_sel = np.zeros((args.comm_rounds, 2))
    trn_perf_all = np.zeros((args.comm_rounds, 2))
    tst_perf_sel = np.zeros((args.comm_rounds, 2))
    tst_perf_all = np.zeros((args.comm_rounds, 2))

    avg_model = model_func().to(device)
    avg_model.load_state_dict(init_model.state_dict())
    all_model = model_func().to(device)
    all_model.load_state_dict(init_model.state_dict())
    server = Server(avg_model, all_model, device, algorithm)

    # COMMUNICATION ROUNDS
    for t in range(args.comm_rounds):
        print(f"\n=== Round {t+1}/{args.comm_rounds} (model={tag}) ===")

        # Channel/Gibbs selection
        if not args.noiseless:
            # Always drawn first (even for type 2) so the RNG stream is unchanged.
            h_d, G, x, sigma = channel.generate()
            if args.channel_type == 0:
                pass
            elif args.channel_type == 1:
                h_d = np.zeros_like(h_d)
            elif args.channel_type == 2:
                h_d, G, x, sigma = generate_rician_single_ris(args, channel_seed=args.rand_seed, round_idx=t)
            else:
                raise ValueError("Unsupported channel type")

            start = time.time()
            x_optim, f_optim, h_optim = gibbs.optimize(h_d, G, x, sigma)
            end = time.time()
            print("Gibbs time:", end - start, "s")
        else:
            # noiseless: random selection with act_prob, reseeding until at
            # least one client is active (act_prob > 0 was checked above).
            inc_seed = 0
            x_optim = np.array([0])
            while np.sum(x_optim) == 0:
                np.random.seed(t + args.rand_seed + inc_seed)
                active_clients = np.random.uniform(size=args.n_clients)
                x_optim = (active_clients <= args.act_prob).astype(np.int8)
                inc_seed += 1
            x_optim = x_optim.astype(np.int8)

        selected_clnts_idx = np.where(x_optim == 1)[0]
        selected_clnts = clients_list[selected_clnts_idx]
        print("Selected indices:", selected_clnts_idx)

        # Clients local training
        for clnt_idx, client in zip(selected_clnts_idx, selected_clnts):
            inputs = {"curr_round": t, "avg_model": server.avg_model}
            if args.algorithm_name == "FedDyn":
                inputs["cloud_model"] = cloud_model
                inputs["cloud_model_param"] = cloud_model_param
                inputs["cloud_model_param_tensor"] = torch.tensor(cloud_model_param, dtype=torch.float32, device=device)
                inputs["local_param"] = local_param_list[clnt_idx]

            # pass attack params only for malicious clients (you can add schedule logic here)
            if client.malicious:
                inputs["attack_params"] = client.attack_params

            print(f" -> Training client {client.client_id} (malicious={client.malicious})")
            client.local_train(inputs)

            # FedDyn update of local_param
            if args.algorithm_name == "FedDyn":
                local_param_list[clnt_idx] = inputs["local_param"]

        # Aggregation via AirComp or noiseless averaging
        inputs_agg = {"clients_list": clients_list, "selected_clnts_idx": selected_clnts_idx, "weight_list": weight_list}
        if args.algorithm_name == "FedDyn":
            inputs_agg["local_param_list"] = local_param_list
            inputs_agg["cloud_model"] = cloud_model
            inputs_agg["cloud_model_param"] = cloud_model_param

        # Stack only the selected clients' flat parameter vectors. Building the
        # full (n_clients, n_param) matrix just to index it copied every
        # client's model each round. The reshape keeps shape (0, n_param) if
        # nothing was selected.
        selected_params = np.array(
            [clients_list[i].client_param for i in selected_clnts_idx]
        ).reshape(len(selected_clnts_idx), n_param)

        if not args.noiseless:
            # air_comp.transmit expects shape info as n_param
            inputs_agg["avg_mdl_param"] = air_comp.transmit(
                n_param,
                selected_params,
                x_optim,
                f_optim,
                h_optim,
                sigma,
            )
        else:
            # noiseless: weighted average of selected clients
            weights_sel = weight_list[selected_clnts_idx].reshape((-1, 1))
            inputs_agg["avg_mdl_param"] = np.sum(selected_params * weights_sel / np.sum(weights_sel), axis=0)

        # Server aggregates (FedDyn.update logic inside algorithm)
        server.aggregate(inputs_agg)

        if args.algorithm_name == "FedDyn":
            cloud_model = inputs_agg["cloud_model"]
            cloud_model_param = inputs_agg["cloud_model_param"]

        # Evaluate & log
        evaluate_performance(
            cent_x, cent_y,
            data_obj.tst_x, data_obj.tst_y,
            data_obj.dataset,
            server.avg_model, server.all_model, device,
            tst_perf_sel, trn_perf_sel, tst_perf_all, trn_perf_all, t
        )

    # Save results for this run
    run_id = time.strftime("%Y%m%d-%H%M%S")
    outpath = save_performance(
        args.comm_rounds,
        tst_perf_all,
        algorithm.name + f"_{tag}",
        data_obj.name,
        args.model_name,
        args.n_clients,
        args.noiseless,
        args.rule,
        channel_type=args.channel_type,
        run_id=run_id,
        out_dir="Output/experiments"
    )
    print("Saved results to", outpath)
    return outpath


# ---------- MAIN wrapper: specify your two pretrained FedDyn model files + channel types ----------
def main(channel_type_override=None, data_dir="/kaggle/input/fl-edge-ai-feddyn-cifar10-iid-3channels-base/data"):
    """Run the adversarial FedDyn experiment once per pretrained checkpoint.

    Args:
        channel_type_override: if not None, this channel type (0, 1, 2) is used
            for every run instead of each model's own channel_type.
        data_dir: directory holding the checkpoint files (defaults to the
            Kaggle input dataset used originally).

    Returns:
        List of ``(tag, saved_path)`` tuples, one per run.

    Raises:
        FileNotFoundError: if any checkpoint is missing. All files are checked
            before the first (long) run starts.
    """
    # Update these to the exact filenames you have in the dataset.
    # Each entry: tag, path, and the channel_type the model was trained with (0,1,2)
    model_runs = [
        {"tag": "feddyn_ch0", "path": os.path.join(data_dir, "feddyn_ch0.pt"), "channel_type": 0},
        {"tag": "feddyn_ch2", "path": os.path.join(data_dir, "feddyn_ch2.pt"), "channel_type": 2},
    ]

    # Example attacker config (same for both runs here); change per-run by extending model_runs entries
    default_malicious_ids = [2, 7, 18]   # indices of malicious clients
    default_attack_params = {"poison_fraction": 0.3, "epsilon": 4.0 / 255.0, "poison_labels": None}

    # Validate every checkpoint up front so a missing second file is not
    # discovered only after the first experiment has finished.
    missing = [entry["path"] for entry in model_runs if not os.path.exists(entry["path"])]
    if missing:
        raise FileNotFoundError("Model file not found: " + ", ".join(missing))

    saved_paths = []
    for entry in model_runs:
        tag = entry["tag"]
        path = entry["path"]
        ch_type = entry["channel_type"] if channel_type_override is None else channel_type_override
        print(f"\nRunning model {tag} (channel_type={ch_type}) from file: {path}")
        out = run_experiment_for_pretrained(
            path_pt=path,
            tag=tag,
            channel_type=ch_type,
            args_override=None,
            # Fresh copies per run (defensive): nothing a run does to these can
            # leak into the next run's configuration.
            malicious_ids=list(default_malicious_ids),
            attack_params=dict(default_attack_params),
        )
        saved_paths.append((tag, out))

    print("\nAll runs finished. Saved files:")
    for tag, p in saved_paths:
        print(tag, "->", p)
    return saved_paths


# Script / notebook-cell entry point: run every configured pretrained-model
# experiment, using each model's own channel type (no override).
if __name__ == "__main__":
    saved = main()