# PrivateFL Implementation in SecretFlow

This notebook implements the PrivateFL method from the paper [PRIVATEFL: Accurate, Differentially Private Federated Learning via Personalized Data Transformation](https://www.usenix.org/system/files/sec23fall-prepub-427-yang-yuchen.pdf)
by Yuchen Yang∗, Bo Hui∗, Haolin Yuan∗, Neil Gong†, and Yinzhi Cao (The Johns Hopkins University; †Duke University).

The implementation has been modified and adapted to work with the SecretFlow framework
for demonstration and educational purposes.

# PrivateFL: 基于个性化数据转换的精确差分隐私联邦学习

PrivateFL是一种新的差分隐私联邦学习方法，旨在通过个性化数据转换来提高模型精度。其核心思想包括：

1. 观察到差分隐私（DP）会在联邦学习中引入额外的客户端异质性，从而降低模型精度。
2. 为每个客户端学习一个差分隐私的个性化数据转换，以减少DP引入的异质性。
3. 数据转换与本地模型同时学习，优化以最小化学习损失并最大化本地客户端模型效用。
4. 可与现有的个性化联邦学习方法和DP效用改进方法结合，进一步提高精度。

本实现基于SecretFlow框架，展示了PrivateFL在中央差分隐私（CDP）和本地差分隐私（LDP）设置下的应用。

## 1.modelUtil

In [1]:
import torch
from torchvision.models import alexnet, resnet18
from torch.nn.functional import relu, softmax, max_pool2d
from torch.nn.utils import spectral_norm
from torch import nn, tanh
import copy
from opacus.grad_sample import register_grad_sampler
from typing import Dict
import torchvision
from collections import OrderedDict
from numpy import median
import numpy as np
import torch.nn.functional as func


def agg_weights(weights):
    """Average a list of model state dicts key by key.

    Args:
        weights: Non-empty sequence of mappings from entry name to tensor.
            Every mapping is indexed with the keys of the first one.

    Returns:
        A deep copy of ``weights[0]`` whose entries hold the element-wise
        mean over all mappings. The inputs themselves are not modified.
    """
    num_models = len(weights)
    with torch.no_grad():
        averaged = copy.deepcopy(weights[0])
        for name in averaged:
            # Sum the remaining models into the copy, then divide by the count.
            for other in weights[1:]:
                averaged[name] += other[name]
            averaged[name] = torch.div(averaged[name], num_models)
    return averaged


def evaluate_global(users, test_dataloders, users_index):
    """Evaluate the selected clients and report their pooled accuracy.

    For every index in ``users_index`` the client's ``evaluate`` is called
    with the test loader at the same index, and the result is passed through
    ``sf.reveal`` to obtain a ``(corrects, total)`` pair.

    Args:
        users: Indexable collection of client objects exposing ``evaluate``.
        test_dataloders: Indexable collection of test loaders, aligned with
            ``users``.
        users_index: Iterable of indices selecting the clients to evaluate.

    Returns:
        Summed correct predictions divided by summed sample count, or ``0``
        when no sample was evaluated.
    """
    correct_total = 0
    sample_total = 0
    for idx in users_index:
        corrects, total = sf.reveal(users[idx].evaluate(test_dataloders[idx]))
        correct_total += corrects
        sample_total += total

    # Nothing evaluated: report it and avoid dividing by zero.
    if not sample_total > 0:
        print("没有评估任何样本")
        return 0

    acc = correct_total / sample_total
    print(f"全局准确率: {acc:.4f}")
    return acc


# Personalized data transformation layer
class InputNorm(nn.Module):
    """Learnable affine input transformation ``gamma * x + beta``.

    ``gamma`` holds one scale per channel and ``beta`` one shift per channel
    and pixel position.

    Args:
        num_channel: Number of input channels.
        num_feature: Height and width of the (square) input.
    """

    def __init__(self, num_channel, num_feature):
        super().__init__()
        self.num_channel = num_channel
        self.gamma = nn.Parameter(torch.ones(num_channel))
        self.beta = nn.Parameter(torch.zeros(num_channel, num_feature, num_feature))

    def forward(self, x):
        """Scale each channel of ``x`` by ``gamma`` and add ``beta``.

        The previous implementation only handled 1 and 3 channels and
        silently returned ``None`` for anything else. Broadcasting ``gamma``
        as ``(C, 1, 1)`` gives the same result for those two cases and works
        for every channel count.
        """
        scale = self.gamma.view(-1, 1, 1)
        return x * scale + self.beta


class resnet18(torch.nn.Module):
    """ResNet-18 classifier built on the pretrained torchvision backbone.

    The backbone's final fully connected layer is replaced by a fresh linear
    layer with ``num_classes`` outputs.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=True)
        # Swap the pretrained head for one sized to this task.
        backbone.fc = torch.nn.Linear(backbone.fc.in_features, num_classes)
        self.backbone = backbone

    def forward(self, x):
        """Return ``(logits, probabilities)`` for the batch ``x``."""
        logits = self.backbone(x)
        probs = softmax(logits, dim=-1)
        return logits, probs


class resnet18_IN(torch.nn.Module):
    """ResNet-18 classifier preceded by an ``InputNorm`` transformation.

    The ``InputNorm`` layer is built for 3-channel inputs of side 150 when
    ``num_classes == 8`` and of side 120 otherwise.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=True)
        # Swap the pretrained head for one sized to this task.
        backbone.fc = torch.nn.Linear(backbone.fc.in_features, num_classes)
        self.backbone = backbone
        input_side = 150 if num_classes == 8 else 120
        self.norm = InputNorm(3, input_side)

    def forward(self, x):
        """Apply the input transformation, then return ``(logits, probabilities)``."""
        logits = self.backbone(self.norm(x))
        probs = softmax(logits, dim=-1)
        return logits, probs


class alexnet(torch.nn.Module):
    """AlexNet classifier: pretrained torchvision backbone plus a linear head.

    The backbone is kept intact; an extra linear layer maps the output of
    its last classifier layer to ``num_classes`` logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = torchvision.models.alexnet(pretrained=True)
        self.backbone = backbone
        backbone_outputs = backbone.classifier[-1].out_features
        self.fc = torch.nn.Linear(backbone_outputs, num_classes)

    def forward(self, x):
        """Return ``(logits, probabilities)`` for the batch ``x``."""
        logits = self.fc(self.backbone(x))
        probs = softmax(logits, dim=-1)
        return logits, probs


class alexnet_IN(torch.nn.Module):
    """AlexNet classifier preceded by an ``InputNorm`` transformation.

    Uses the pretrained torchvision backbone, an extra linear head with
    ``num_classes`` outputs, and an ``InputNorm`` built for 3-channel inputs
    of side 150.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = torchvision.models.alexnet(pretrained=True)
        self.backbone = backbone
        backbone_outputs = backbone.classifier[-1].out_features
        self.fc = torch.nn.Linear(backbone_outputs, num_classes)
        self.norm = InputNorm(3, 150)

    def forward(self, x):
        """Apply the input transformation, then return ``(logits, probabilities)``."""
        features = self.backbone(self.norm(x))
        logits = self.fc(features)
        probs = softmax(logits, dim=-1)
        return logits, probs


class mnist_fully_connected_IN(nn.Module):
    """Three-layer bias-free MLP for 28x28 inputs with an ``InputNorm`` front end.

    Layer widths are 784 -> 600 -> 100 -> ``num_classes``.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.hidden1 = 600
        self.hidden2 = 100
        self.fc1 = nn.Linear(28 * 28, self.hidden1, bias=False)
        self.fc2 = nn.Linear(self.hidden1, self.hidden2, bias=False)
        self.fc3 = nn.Linear(self.hidden2, num_classes, bias=False)
        # Registered as a submodule; forward() uses the functional relu.
        self.relu = nn.ReLU(inplace=False)
        self.norm = InputNorm(1, 28)

    def forward(self, x):
        """Transform, flatten and classify ``x``; return ``(logits, probabilities)``."""
        transformed = self.norm(x)
        flat = transformed.view(-1, 28 * 28)
        hidden = relu(self.fc2(relu(self.fc1(flat))))
        logits = self.fc3(hidden)
        return logits, softmax(logits, dim=1)


class mnist_fully_connected(nn.Module):
    """Three-layer bias-free MLP for 28x28 inputs.

    Layer widths are 784 -> 600 -> 100 -> ``num_classes``.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.hidden1 = 600
        self.hidden2 = 100
        self.fc1 = nn.Linear(28 * 28, self.hidden1, bias=False)
        self.fc2 = nn.Linear(self.hidden1, self.hidden2, bias=False)
        self.fc3 = nn.Linear(self.hidden2, num_classes, bias=False)
        # Registered as a submodule; forward() uses the functional relu.
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Flatten and classify ``x``; return ``(logits, probabilities)``."""
        flat = x.view(-1, 28 * 28)
        hidden = relu(self.fc2(relu(self.fc1(flat))))
        logits = self.fc3(hidden)
        return logits, softmax(logits, dim=1)


class purchase_fully_connected(nn.Module):
    """Four-layer bias-free MLP with tanh activations for 600-feature inputs.

    Layer widths are 600 -> 512 -> 256 -> 128 -> ``num_classes``.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(600, 512, bias=False)
        self.fc2 = nn.Linear(512, 256, bias=False)
        self.fc3 = nn.Linear(256, 128, bias=False)
        self.fc4 = nn.Linear(128, num_classes, bias=False)

    def forward(self, x):
        """Classify ``x``; return ``(logits, probabilities)``."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = tanh(layer(hidden))
        logits = self.fc4(hidden)
        return logits, softmax(logits, dim=1)


class purchase_fully_connected_IN(nn.Module):
    """Four-layer bias-free tanh MLP with a ``FeatureNorm`` front end.

    Layer widths are 600 -> 512 -> 256 -> 128 -> ``num_classes``; the input
    first goes through a ``FeatureNorm`` over 600 features.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(600, 512, bias=False)
        self.fc2 = nn.Linear(512, 256, bias=False)
        self.fc3 = nn.Linear(256, 128, bias=False)
        self.fc4 = nn.Linear(128, num_classes, bias=False)
        self.norm = FeatureNorm(600)

    def forward(self, x):
        """Transform and classify ``x``; return ``(logits, probabilities)``."""
        hidden = self.norm(x)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = tanh(layer(hidden))
        logits = self.fc4(hidden)
        return logits, softmax(logits, dim=1)


class linear_model(nn.Module):
    """Single linear layer classifier.

    Args:
        num_classes: Number of output logits.
        input_shape: Number of input features.
        bn_stats: Accepted for signature compatibility with the other
            ``linear_model*`` classes and ignored here. ``CDPUser`` passes
            ``bn_stats=...`` to every model whose name contains
            "linear_model", which previously raised ``TypeError`` for this
            class.
    """

    def __init__(self, num_classes, input_shape=512, bn_stats=False):
        super().__init__()
        self.fc1 = nn.Linear(input_shape, num_classes, bias=True)

    def forward(self, x):
        """Return ``(logits, probabilities)`` for the batch ``x``."""
        logits = self.fc1(x)
        return logits, softmax(logits, dim=1)


def standardize(x, bn_stats):
    """Standardize ``x`` along dimension 1 using precomputed statistics.

    Args:
        x: Input tensor with at least two dimensions; dimension 1 is matched
            against the statistics.
        bn_stats: ``None`` (``x`` is returned unchanged) or a
            ``(mean, variance)`` pair of tensors.

    Returns:
        ``(x - mean) / sqrt(variance + 1e-5)``, with entries whose variance
        is exactly zero forced to zero. The input tensor is not modified.
    """
    if bn_stats is None:
        return x

    mean, var = bn_stats
    # Broadcast the statistics along dimension 1 of x.
    broadcast_shape = [1] * x.dim()
    broadcast_shape[1] = -1
    mean = mean.to(x.device).reshape(broadcast_shape)
    var = var.to(x.device).reshape(broadcast_shape)

    x = (x - mean) / torch.sqrt(var + 1e-5)

    # Zero-variance features carry no signal; mask them out.
    x *= var != 0
    return x


class linear_model_DN(nn.Module):
    """Linear classifier applied to standardized input features.

    Args:
        num_classes: Number of output logits.
        input_shape: Number of input features.
        bn_stats: When truthy, the mean and variance are loaded from
            ``transfer/cifar100_resnext_{mean,var}.npy``; otherwise a zero
            mean and unit variance are used.
    """

    def __init__(self, num_classes, input_shape=512, bn_stats=False):
        super().__init__()
        if bn_stats:
            mean = np.load("transfer/cifar100_resnext_mean.npy")
            var = np.load("transfer/cifar100_resnext_var.npy")
            self.bn_stats = (torch.from_numpy(mean), torch.from_numpy(var))
        else:
            self.bn_stats = (torch.zeros(input_shape), torch.ones(input_shape))
        self.fc1 = nn.Linear(input_shape, num_classes, bias=True)

    def forward(self, x):
        """Standardize ``x`` and return ``(logits, probabilities)``."""
        logits = self.fc1(standardize(x, self.bn_stats))
        return logits, softmax(logits, dim=1)


class FeatureNorm(nn.Module):
    """Learnable affine transformation for flat feature vectors.

    ``gamma`` is a single scalar scale and ``beta`` a per-feature shift of
    shape ``(1, feature_shape)``.
    """

    def __init__(self, feature_shape):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(1))
        self.beta = nn.Parameter(torch.zeros(1, feature_shape))

    def forward(self, x):
        """Return ``gamma * x + beta`` for a 2-D batch ``x``."""
        # gamma has a single entry, so contracting over "j" is a scalar scale.
        scaled = torch.einsum("ni, j->ni", x, self.gamma)
        return scaled + self.beta


@register_grad_sampler(FeatureNorm)
def compute_grad_sample(
    layer: FeatureNorm, activations: torch.Tensor, backprops: torch.Tensor
) -> Dict[nn.Parameter, torch.Tensor]:
    """
    Computes per sample gradients for a ``FeatureNorm`` layer

    The ``gamma`` entry is the per-sample sum over features of
    ``backprops * activations``; the ``beta`` entry is ``backprops`` with any
    dimensions between the first and the last summed out.

    Args:
        layer: Layer whose ``gamma`` and ``beta`` parameters key the result
        activations: Activations (the "nk" subscripts require a 2-D tensor)
        backprops: Backpropagations (same 2-D requirement for ``gamma``)
    """
    grad_samples = {
        layer.gamma: torch.einsum("nk,nk->n", backprops, activations),
    }
    if layer.beta is not None:
        grad_samples[layer.beta] = torch.einsum("n...i->ni", backprops)

    return grad_samples


class linear_model_DN_IN(nn.Module):
    """Linear classifier with ``FeatureNorm`` followed by standardization.

    Args:
        num_classes: Number of output logits.
        input_shape: Number of input features.
        bn_stats: When truthy, the mean and variance are loaded from
            ``cifar100_resnext_mean.npy`` and ``cifar100_resnext_var.npy``;
            otherwise a zero mean and unit variance are used.
    """

    def __init__(self, num_classes, input_shape, bn_stats=False):
        super().__init__()
        if bn_stats:
            mean = np.load("cifar100_resnext_mean.npy")
            # Fixed: the variance used to be read from the *mean* file.
            var = np.load("cifar100_resnext_var.npy")
            self.bn_stats = (torch.from_numpy(mean), torch.from_numpy(var))
        else:
            self.bn_stats = (torch.zeros(input_shape), torch.ones(input_shape))
        self.backbone = nn.Linear(input_shape, num_classes, bias=True)
        self.norm = FeatureNorm(input_shape)

    def forward(self, x):
        """Transform, standardize and classify ``x``; return ``(logits, probabilities)``."""
        x = self.norm(x)
        x = standardize(x, self.bn_stats)
        logits = self.backbone(x)
        return logits, softmax(logits, dim=1)


@register_grad_sampler(InputNorm)
def compute_grad_sample(
    layer: InputNorm, activations: torch.Tensor, backprops: torch.Tensor
) -> Dict[nn.Parameter, torch.Tensor]:
    """
    Computes per sample gradients for an ``InputNorm`` layer

    The ``gamma`` entry is, per sample and per channel, the sum of
    ``backprops * activations`` over all trailing dimensions; the ``beta``
    entry is ``backprops`` itself (the "nijk" subscripts require 4-D input).

    Args:
        layer: Layer whose ``gamma`` and ``beta`` parameters key the result
        activations: Activations
        backprops: Backpropagations
    """
    grad_samples = {
        layer.gamma: torch.einsum("nk...,nk...->nk", backprops, activations),
    }
    if layer.beta is not None:
        grad_samples[layer.beta] = torch.einsum("nijk->nijk", backprops)

    return grad_samples

## 2.dataloader

In [2]:
import os
import torch.utils.data
from PIL import Image
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np


class CHMNIST(torch.utils.data.Dataset):
    """Image-folder dataset with a fixed 90/10 train/test split.

    ``root`` must contain one sub-directory per class whose name starts with
    a two-digit, 1-based class number; the label is that number minus one.

    Args:
        root: Directory holding the class folders.
        train: Select the training part (``True``) or the test part.
        download: Stored on the instance only; nothing is downloaded here.
        transform: Optional callable applied to each opened image.
    """

    def __init__(self, root="data/CHMNIST", train=True, download=True, transform=None):
        self.images = []
        self.root = root
        self.targets = []
        self.train = train
        self.download = download
        self.transform = transform

        x_train, x_test, y_train, y_test = self._train_test_split()

        if self.train:
            self._setup_dataset(x_train, y_train)
        else:
            self._setup_dataset(x_test, y_test)

    def _train_test_split(self):
        """List every image with its label and split the list 90/10.

        Directory listings are sorted because ``os.listdir`` returns entries
        in arbitrary order: without sorting, the seeded split depends on the
        file system, and the separately built train and test instances are
        not guaranteed to see the same partition. Entries of ``root`` that
        are not directories are skipped.
        """
        img_names = []
        img_label = []
        for folder_name in sorted(os.listdir(self.root)):
            folder_path = os.path.join(self.root, folder_name)
            if not os.path.isdir(folder_path):
                continue
            file_names = sorted(os.listdir(folder_path))
            if not file_names:
                continue
            label = int(folder_name[0:2]) - 1
            for img_name in file_names:
                img_names.append(os.path.join(folder_path, img_name))
                img_label.append(label)

        x_train, x_test, y_train, y_test = train_test_split(
            img_names, img_label, train_size=0.9, random_state=1
        )

        return x_train, x_test, y_train, y_test

    def _setup_dataset(self, x, y):
        """Store the image paths ``x`` and labels ``y`` served by this instance."""
        self.images = x
        self.targets = y

    def __getitem__(self, item):
        """Open image ``item``, apply the transform if set, return ``(img, label)``."""
        img_fn = self.images[item]
        label = self.targets[item]
        img = Image.open(img_fn)
        if self.transform:
            img = self.transform(img)
        return img, label

    def __len__(self):
        """Number of images in the selected part."""
        return len(self.images)


class Purchase(torch.utils.data.Dataset):
    """Tabular dataset read from a CSV file with a fixed 80/20 split.

    The first CSV column is the 1-based label (stored minus one); the other
    columns are the features, converted to single-precision floats.

    Args:
        root: Path of the CSV file.
        train: Select the training part (``True``) or the test part.
        download: Stored on the instance only; nothing is downloaded here.
        transform: Stored on the instance only; ``__getitem__`` does not
            apply it.
    """

    def __init__(
        self,
        root="data/purchase/dataset_purchase",
        train=True,
        download=True,
        transform=None,
    ):
        self.images = []
        self.root = root
        self.targets = []
        self.train = train
        self.download = download
        self.transform = transform

        x_train, x_test, y_train, y_test = self._train_test_split()

        if self.train:
            self._setup_dataset(x_train, y_train)
        else:
            self._setup_dataset(x_test, y_test)

    def _train_test_split(self):
        """Read the CSV and split features and labels 80/20 with a fixed seed."""
        df = pd.read_csv(self.root)

        img_names = df.iloc[:, 1:].to_numpy(dtype="f")
        img_label = df.iloc[:, 0].to_numpy() - 1
        x_train, x_test, y_train, y_test = train_test_split(
            img_names, img_label, train_size=0.8, random_state=1
        )

        return x_train, x_test, y_train, y_test

    def _setup_dataset(self, x, y):
        """Store the feature rows ``x`` and labels ``y`` served by this instance."""
        self.images = x
        self.targets = y

    def __getitem__(self, item):
        """Return the ``(features, label)`` pair at position ``item``."""
        img = self.images[item]
        label = self.targets[item]
        return img, label

    def __len__(self):
        """Number of rows in the selected part.

        Added so ``len()`` and a ``DataLoader`` built directly on the dataset
        work, matching ``CHMNIST``.
        """
        return len(self.images)

## 3.datasets

In [3]:
import random
from collections import defaultdict
import numpy as np
import torch.utils.data
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, MNIST, FashionMNIST, CIFAR100, EMNIST

np.random.seed(2022)


def get_datasets(data_name, dataroot, preprocess=None):
    """
    get_datasets returns the train/test splits of the requested dataset
    :param data_name: name of dataset, choose from
        [mnist, cifar10, cifar100, fashionmnist, emnist, purchase, chmnist]
    :param dataroot: root to data dir (passed to the dataset constructor)
    :param preprocess: optional transform replacing the default one; only
        honoured for cifar10 and cifar100, ignored for the other datasets
    :return: train_set, test_set (tuple of pytorch datasets)
    :raises ValueError: if data_name is not one of the supported names
    """
    if data_name in ("cifar10", "cifar100"):
        if preprocess is None:
            # Default pipeline: resize to 120 (cifar10) or 224 (cifar100).
            target_size = 120 if data_name == "cifar10" else 224
            transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Resize(target_size),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]
            )
        else:
            transform = preprocess
        data_obj = CIFAR10 if data_name == "cifar10" else CIFAR100
    elif data_name in ("mnist", "fashionmnist", "emnist"):
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
        )
        if data_name == "mnist":
            data_obj = MNIST
        elif data_name == "fashionmnist":
            data_obj = FashionMNIST
        else:
            data_obj = EMNIST
    elif data_name == "purchase":
        transform = transforms.Compose([transforms.ToTensor()])
        data_obj = Purchase
    elif data_name == "chmnist":
        normalization = transforms.Normalize(
            (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        )
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Resize((150, 150)), normalization]
        )
        data_obj = CHMNIST
    else:
        raise ValueError(
            "choose data_name from ['mnist', 'cifar10', 'cifar100', 'fashionmnist', 'emnist', 'purchase', 'chmnist']"
        )

    # EMNIST additionally needs to know which split to load.
    extra_kwargs = {"split": "digits"} if data_name == "emnist" else {}

    train_set = data_obj(
        dataroot, train=True, transform=transform, download=True, **extra_kwargs
    )
    test_set = data_obj(dataroot, train=False, transform=transform, **extra_kwargs)

    return train_set, test_set


def get_num_classes_samples(dataset):
    """
    extracts label statistics from a dataset
    :param dataset: pytorch dataset (or Subset) exposing a ``targets`` attribute
    :return: number of distinct classes, per-class sample counts (ordered by
        sorted class value), and the labels themselves
    """
    if isinstance(dataset, torch.utils.data.Subset):
        # A Subset only sees the labels selected by its indices.
        targets = dataset.dataset.targets
        if isinstance(targets, list):
            targets = np.array(targets)
        data_labels_list = targets[dataset.indices]
    else:
        targets = dataset.targets
        data_labels_list = np.array(targets) if isinstance(targets, list) else targets

    classes, num_samples = np.unique(data_labels_list, return_counts=True)
    return len(classes), num_samples, data_labels_list


def gen_classes_per_node(
    dataset, num_users, classes_per_user=2, high_prob=0.6, low_prob=0.4
):
    """
    creates the class assignment of each client
    :param dataset: pytorch dataset object
    :param num_users: number of clients
    :param classes_per_user: number of classes assigned to each client
    :param high_prob: not used by this implementation
    :param low_prob: not used by this implementation
    :return: dictionary with a "class" list (classes of each client) and a
        "prob" list (share of each of those classes given to the client)
    """
    num_classes, _, _ = get_num_classes_samples(dataset)

    # Every class must be handed out the same number of times.
    total_slots = classes_per_user * num_users
    assert total_slots % num_classes == 0, "equal classes appearance is needed"
    slots_per_class = total_slots // num_classes

    # Each class starts with `slots_per_class` open slots of equal share.
    class_dict = {}
    for cls in range(num_classes):
        ones = np.array([1] * slots_per_class)
        class_dict[cls] = {
            "count": slots_per_class,
            "prob": (ones / ones.sum()).tolist(),
        }

    class_partitions = defaultdict(list)
    for _user in range(num_users):
        chosen = []
        for _slot in range(classes_per_user):
            # Pick at random among the classes with the most open slots,
            # excluding the ones this client already holds.
            remaining = [class_dict[cls]["count"] for cls in range(num_classes)]
            candidates = np.where(np.array(remaining) == max(remaining))[0]
            candidates = np.setdiff1d(candidates, np.array(chosen))
            chosen.append(np.random.choice(candidates))
            class_dict[chosen[-1]]["count"] -= 1
        class_partitions["class"].append(chosen)
        class_partitions["prob"].append(
            [class_dict[cls]["prob"].pop() for cls in chosen]
        )
    return class_partitions


def gen_data_split(dataset, num_users, class_partitions):
    """
    divide data indexes for each client based on class_partition
    :param dataset: pytorch dataset object (train/val/test)
    :param num_users: number of clients
    :param class_partitions: mapping with a "class" list and a "prob" list
        giving, per client, its classes and the share of each class it gets
    :return: list with one list of sample indexes per client
    """
    num_classes, num_samples, data_labels_list = get_num_classes_samples(dataset)

    # Sample indexes grouped by class.
    data_class_idx = {
        cls: np.where(data_labels_list == cls)[0] for cls in range(num_classes)
    }

    # Shuffle with NumPy's global generator: the file seeds np.random but not
    # the stdlib `random` module that was used here before, so the split was
    # not reproducible from run to run.
    for class_indices in data_class_idx.values():
        np.random.shuffle(class_indices)

    user_data_idx = [[] for _ in range(num_users)]
    for usr_i in range(num_users):
        for c, p in zip(
            class_partitions["class"][usr_i], class_partitions["prob"][usr_i]
        ):
            # Hand the next `p` share of class `c` to this client and remove
            # it from the pool so no sample is assigned twice.
            end_idx = int(num_samples[c] * p)
            user_data_idx[usr_i].extend(data_class_idx[c][:end_idx])
            data_class_idx[c] = data_class_idx[c][end_idx:]
        # Keep an even number of samples per client.
        if len(user_data_idx[usr_i]) % 2 == 1:
            user_data_idx[usr_i] = user_data_idx[usr_i][:-1]

    return user_data_idx


def gen_classes_id(num_users=10, num_classes_per_user=2, classes=10):
    """Assign two distinct classes to every client, each with share 0.5.

    Two pools of all classes are kept: each client draws its first class from
    the first pool and its second class from the second pool, both without
    replacement, and the second class must differ from the first.

    The old fallback was guarded by ``tmp is None``, which can never be true
    for a list, so when the only class left in the second pool was the
    client's own first class ``np.random.choice`` raised on an empty list.
    That case is now resolved by swapping second classes with the previous
    client.

    Args:
        num_users: Number of clients.
        num_classes_per_user: Number of class pools created; only the first
            two are used.
        classes: Number of classes.

    Returns:
        ``defaultdict`` with a ``"class"`` list (one two-element array per
        client) and a ``"prob"`` list (``[0.5, 0.5]`` per client).

    Raises:
        ValueError: If two distinct classes cannot be assigned to a client.
    """
    class_partitions = defaultdict(list)
    class_counts = [list(range(classes)) for _ in range(num_classes_per_user)]
    user_data_classes = []
    for _ in range(num_users):
        classes_user = np.random.choice(class_counts[0], size=1)
        first = classes_user[0]
        class_counts[0].remove(first)

        # Second class: anything left in the second pool except `first`.
        candidates = [c for c in class_counts[1] if c != first]
        if candidates:
            classes_user = np.append(
                classes_user, np.random.choice(candidates, size=1)
            )
            class_counts[1].remove(classes_user[1])
        else:
            if not user_data_classes or first not in class_counts[1]:
                raise ValueError(
                    "cannot assign two distinct classes to every user"
                )
            # Only `first` is left in the second pool. The previous client
            # holds neither copy of `first`, so it can take `first` as its
            # second class and hand over its old one.
            previous = user_data_classes[-1]
            classes_user = np.append(classes_user, previous[1])
            previous[1] = first
            class_counts[1].remove(first)
        user_data_classes.append(classes_user)
    for c in user_data_classes:
        class_partitions["class"].append(c)
        class_partitions["prob"].append([0.5, 0.5])
    return class_partitions


def gen_classes(num_users=10, num_classes_per_user=6, classes=10):
    """Assign each client a contiguous, wrapping window of classes.

    Client ``u`` receives classes ``u, u + 1, ..., u + num_classes_per_user - 1``
    taken modulo ``classes``, each with share ``1 / num_classes_per_user``.
    The wrap-around used to be hard-coded as ``% 10``, which produced
    out-of-range class ids whenever ``classes`` was not 10.

    Args:
        num_users: Number of clients.
        num_classes_per_user: Number of classes given to each client.
        classes: Total number of classes.

    Returns:
        ``defaultdict`` with a ``"class"`` list (one array per client) and a
        ``"prob"`` list (one list of equal shares per client).
    """
    class_partitions = defaultdict(list)
    for user in range(num_users):
        window = np.arange(user, user + num_classes_per_user) % classes
        class_partitions["class"].append(window)
        class_partitions["prob"].append(
            [1 / num_classes_per_user] * num_classes_per_user
        )
    return class_partitions


def gen_random_loaders(
    data_name,
    data_path,
    num_users,
    bz,
    num_classes_per_user,
    num_classes,
    preprocess=None,
):
    """
    generates the per-client loaders for every split returned by get_datasets
    :param data_name: name of dataset (see get_datasets)
    :param data_path: root path for data dir
    :param num_users: number of clients
    :param bz: batch size
    :param num_classes_per_user: number of classes assigned to each client
    :param num_classes: total number of classes
    :param preprocess: optional transform forwarded to get_datasets
    :return: list with, for each split, a list of one dataloader per client
    """
    loader_params = {
        "batch_size": bz,
        "shuffle": False,
        "pin_memory": True,
        "num_workers": 0,
    }
    dataloaders = []
    datasets = get_datasets(data_name, data_path, preprocess=preprocess)
    cls_partitions = None
    # Client-by-class share matrix; filled below but not returned.
    distribution = np.zeros((num_users, num_classes))
    for split_idx, split in enumerate(datasets):
        if split_idx == 0:
            # The class assignment is drawn once, on the first split, and
            # reused for the following ones.
            cls_partitions = gen_classes_per_node(
                split, num_users, num_classes_per_user
            )
            print("\n每个客户端的类别分布:")
            for index in range(num_users):
                user_classes = cls_partitions["class"][index]
                user_probs = cls_partitions["prob"][index]
                print(f"客户端 {index + 1}:")
                for class_idx, prob in zip(user_classes, user_probs):
                    print(f"  类别 {class_idx}: 概率 {prob:.4f}")
                distribution[index][user_classes] = user_probs

            # Stays enabled for every split processed from here on.
            loader_params["shuffle"] = True

        usr_subset_idx = gen_data_split(split, num_users, cls_partitions)
        subsets = [torch.utils.data.Subset(split, idx) for idx in usr_subset_idx]
        dataloaders.append(
            [torch.utils.data.DataLoader(subset, **loader_params) for subset in subsets]
        )

    return dataloaders

## 4.FedUser

CDPUser 和 LDPUser

这两个类代表了PrivateFL中的客户端。它们的主要特点是：

- 包含个性化数据转换层（InputNorm），这是PrivateFL的核心创新。
- 在训练过程中同时优化数据转换和模型参数。

In [4]:
import logging
from secretflow import PYUObject, proxy

from collections import OrderedDict
import torchmetrics
import opacus
from opacus.validators import ModuleValidator
from opacus.utils.batch_memory_manager import BatchMemoryManager
import numpy as np
import time

# from modelUtil import *


@proxy(PYUObject)
class CDPUser:
    """Federated-learning client that trains with plain SGD.

    Args:
        index: Client identifier, used in log messages.
        device: Stored on the instance; this class does not move the model
            or the data to it.
        model: Model class to instantiate (its ``__name__`` selects the
            construction path), or a string.
        input_shape: Passed to models whose name contains "linear_model".
        n_classes: Number of output classes.
        train_dataloader: Iterable of ``(images, labels)`` training batches.
        epochs: Local epochs run by each call to ``train``.
        max_norm: Stored as ``self.max_norm``.
        disc_lr: Learning rate of the model parameters.
        flr: Learning rate of the ``norm`` (transformation) parameters of
            models whose name contains "IN".
    """

    def __init__(
        self,
        index,
        device,
        model,
        input_shape,
        n_classes,
        train_dataloader,
        epochs,
        max_norm=1.0,
        disc_lr=5e-3,
        flr=1e-1,
    ):
        print(f"初始化 CDPUser 参数: index={index}, device={device}, model={model}")

        self.index = index
        self.device = device

        model_name = model.__name__ if isinstance(model, type) else model
        if "linear_model" in model_name:
            # Precomputed statistics are requested only for 1024-d features.
            self.model = model(
                num_classes=n_classes,
                input_shape=input_shape,
                bn_stats=bool(input_shape == 1024),
            )
        else:
            self.model = model(num_classes=n_classes)
        self.train_dataloader = train_dataloader
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.disc_lr = disc_lr
        # The metric stays on its default device.
        self.acc_metric = torchmetrics.Accuracy(
            task="multiclass", num_classes=n_classes
        )
        self.max_norm = max_norm
        self.epochs = epochs
        self.flr = flr
        # self.agg is False for "IN" models: set_model_state_dict then also
        # skips the local `norm` entries.
        self.agg = True
        if "IN" in model_name:
            self.optim = torch.optim.SGD(
                [
                    # The transformation layer (self.model.norm) gets its own
                    # learning rate (self.flr), so it can be optimised
                    # differently from the rest of the model.
                    {"params": self.model.norm.parameters(), "lr": self.flr},
                    {
                        "params": [
                            v
                            for k, v in self.model.named_parameters()
                            if "norm" not in k
                        ]
                    },
                ],
                lr=self.disc_lr,
            )
            self.agg = False
        else:
            self.optim = torch.optim.SGD(self.model.parameters(), self.disc_lr)

    def train(self):
        """Run ``self.epochs`` local epochs and log accuracy and mean loss per epoch."""
        self.model.train()
        loading = []
        for _ in range(self.epochs):
            losses = []
            for images, labels in self.train_dataloader:
                loading.append(self.optim.zero_grad())
                logits, preds = self.model(images)
                loss = self.loss_fn(logits, labels)
                loading.append(loss.backward())
                loading.append(self.optim.step())
                loading.append(self.acc_metric(preds, labels))
                losses.append(loss.item())
            # NOTE(review): `sf` is not imported in this cell; presumably the
            # secretflow module alias defined elsewhere in the file — confirm.
            sf.wait(loading)
            logging.info(
                f"Client: {self.index} ACC: {self.acc_metric.compute()}, Loss:{np.mean(losses)}"
            )
            self.acc_metric.reset()

    def evaluate(self, dataloader):
        """Count correct predictions of the local model on ``dataloader``.

        Returns:
            ``(corrects, total)``: the number of correct predictions as a
            0-d NumPy array and the number of evaluated samples as an int.
            An empty dataloader yields ``(array(0), 0)``; it used to raise
            ``AttributeError`` because the counter was still a plain int.
        """
        logging.warning(f"Client {self.index} start evaluating")
        self.model.eval()
        testing_corrects = 0
        testing_sum = 0
        with torch.no_grad():
            for images, labels in dataloader:
                _, preds = self.model(images)
                matches = torch.argmax(preds, dim=1) == labels
                testing_corrects += int(matches.sum().item())
                testing_sum += len(labels)
        return np.asarray(testing_corrects), testing_sum

    def get_model_state_dict(self):
        """Return the state dict of the local model."""
        return self.model.state_dict()

    def set_model_state_dict(self, weights):
        """Copy aggregated weights into the local model, keeping local layers.

        Entries whose key contains "bn" are never overwritten. When
        ``self.agg`` is false, entries whose key contains "norm" or
        "downsample.1" are kept local as well.

        Args:
            weights: Mapping from state-dict key to tensor; must provide
                every key that is copied.
        """
        local_markers = ("bn",) if self.agg else ("norm", "bn", "downsample.1")
        # state_dict() tensors share storage with the model, so copying into
        # them updates the model in place; build the dict once, not per key.
        for key, tensor in self.model.state_dict().items():
            if any(marker in key for marker in local_markers):
                continue
            tensor.data.copy_(weights[key])


@proxy(PYUObject)
class LDPUser(CDPUser):
    """Federated-learning client that trains with local DP through Opacus.

    On top of ``CDPUser`` it validates the model with
    ``ModuleValidator.fix``, rebuilds a plain SGD optimizer over all
    parameters, and wraps model, optimizer and data loader with an Opacus
    ``PrivacyEngine`` calibrated to ``target_epsilon`` and ``target_delta``.

    NOTE(review): ``n_classes`` and ``input_shape`` are forwarded positionally
    to ``CDPUser.__init__``, whose signature names those two positions
    ``input_shape, n_classes`` — verify the argument order against callers.
    """

    def __init__(
        self,
        index,
        device,
        model,
        n_classes,
        input_shape,
        train_dataloader,
        epochs,
        rounds,
        target_epsilon,
        target_delta,
        sr,
        max_norm=2.0,
        disc_lr=5e-1,
        mp_bs=3,
    ):
        super().__init__(
            index,
            device,
            model,
            n_classes,
            input_shape,
            train_dataloader,
            epochs=epochs,
            max_norm=max_norm,
            disc_lr=disc_lr,
        )
        self.rounds = rounds
        self.target_epsilon = target_epsilon
        self.epsilon = 0
        self.delta = target_delta
        # fix() may return a different module, so the optimizer is rebuilt
        # over its parameters before the privacy engine wraps both.
        self.model = ModuleValidator.fix(self.model)
        self.optim = torch.optim.SGD(self.model.parameters(), self.disc_lr)
        self.sr = sr
        self.make_local_private()
        self.mp_bs = mp_bs

        model_name = model.__name__ if isinstance(model, type) else model
        # "IN" models keep their transformation layer local.
        self.agg = "IN" not in model_name

    def make_local_private(self):
        """Wrap model, optimizer and data loader with an Opacus ``PrivacyEngine``."""
        engine = opacus.PrivacyEngine()
        self.privacy_engine = engine
        private_parts = engine.make_private_with_epsilon(
            module=self.model,
            optimizer=self.optim,
            data_loader=self.train_dataloader,
            epochs=self.epochs * self.rounds * self.sr,
            target_epsilon=self.target_epsilon,
            target_delta=self.delta,
            max_grad_norm=self.max_norm,
        )
        self.model, self.optim, self.train_dataloader = private_parts

    def train(self):
        """Run ``self.epochs`` private epochs, then log accuracy and spent epsilon."""
        self.model.train()
        loading = []
        for _ in range(self.epochs):
            # Physical batches are capped at self.mp_bs samples.
            with BatchMemoryManager(
                data_loader=self.train_dataloader,
                max_physical_batch_size=self.mp_bs,
                optimizer=self.optim,
            ) as batch_loader:
                for images, labels in batch_loader:
                    zero_result = self.optim.zero_grad()
                    logits, preds = self.model(images)
                    loss = self.loss_fn(logits, labels)
                    backward_result = loss.backward()
                    step_result = self.optim.step()
                    batch_acc = self.acc_metric(preds, labels)
                    loading.extend(
                        (zero_result, backward_result, step_result, batch_acc)
                    )
        sf.wait(loading)
        self.epsilon = self.privacy_engine.get_epsilon(self.delta)
        logging.info(
            f"Client: {self.index} ACC: {self.acc_metric.compute()}, episilon: {self.epsilon}"
        )
        self.acc_metric.reset()

## 5.FedServer

In [5]:
from collections import OrderedDict
import opacus
from opacus.validators import ModuleValidator


@proxy(PYUObject)
class CDPServer:
    """Server side of the central-DP (CDP) setting.

    Holds the global model. ``agg_updates`` folds client state dicts into it:
    each client's per-tensor delta is clipped to a median-based norm, the
    clipped deltas are summed, Gaussian noise is added, and the result is
    divided by ``sample_clients`` before being applied to the global weights.
    """

    def __init__(
        self,
        device,
        model,
        input_shape,
        n_classes,
        noise_multiplier=1,
        sample_clients=10,
        disc_lr=1,
    ):
        """Build the global model and store the aggregation settings.

        Args:
            device: Stored on ``self.device``; not otherwise used here.
            model: Callable that builds the network. If it is a class its
                ``__name__`` is inspected, otherwise the object itself is
                searched for the ``"linear_model"`` / ``"IN"`` substrings.
            input_shape: Forwarded to ``model`` only when its name contains
                ``"linear_model"``.
            n_classes: Passed to ``model`` as ``num_classes``.
            noise_multiplier: Multiplied with each tensor's clipping norm to
                obtain the standard deviation of the added Gaussian noise.
            sample_clients: Divisor applied to the noisy summed update.
            disc_lr: Stored on ``self.disc_lr``; not used by this class.
        """
        print(f"初始化 CDPServer 参数: device={device}, model={model}")
        model_name = model.__name__ if isinstance(model, type) else model
        if "linear_model" in model_name:
            self.model = model(num_classes=n_classes, input_shape=input_shape)
        else:
            self.model = model(num_classes=n_classes)
        self.disc_lr = disc_lr
        self.device = device
        self.sample_clients = sample_clients
        self.noise_multiplier = noise_multiplier
        self.trainable_names = [k for k, _ in self.model.named_parameters()]
        # Models whose name contains "IN" additionally keep their "norm" and
        # "downsample.1" entries out of aggregation (see ``_skip_key``).
        self.agg = "IN" not in model_name

    def _skip_key(self, key):
        """Return True when state-dict entry ``key`` is left out of aggregation."""
        if "bn" in key:
            return True
        if not self.agg:
            return "norm" in key or "downsample.1" in key
        return False

    def get_median_norm(self, weights):
        """Return the per-parameter clipping norm.

        For every named parameter, the L2 norm of (global - client) is
        computed for each client state dict in ``weights``; the clipping norm
        is the median of those norms, capped at 10.

        Args:
            weights: Sequence of client state dicts keyed like the global
                model's parameters.

        Returns:
            OrderedDict mapping parameter name to its clipping norm.
        """
        logging.warning("Calculating median norm")
        median_norm = OrderedDict()
        for k, v in self.model.named_parameters():
            base = v.detach()
            # ``.item()`` yields plain Python floats, so ``median`` never has
            # to convert tensors itself.
            norms = [(base - client_state[k]).norm(2).item() for client_state in weights]
            median_norm[k] = min(float(median(norms)), 10)
        return median_norm

    def get_model_state_dict(self):
        """Return the global model's ``state_dict()``."""
        return self.model.state_dict()

    def agg_updates(self, weights):
        """Apply the clipped, noised, averaged client updates to the global model.

        Args:
            weights: Sequence of client state dicts keyed like the global
                model's state dict.
        """
        logging.warning("CDP Server Aggregating updates")
        with torch.no_grad():
            norms = self.get_median_norm(weights)
            for k, v in self.get_model_state_dict().items():
                if self._skip_key(k):
                    continue
                if k not in norms:
                    # Norms exist only for named parameters. Other state-dict
                    # entries (buffers) have no clipping bound, so they are
                    # left untouched instead of failing on the missing key.
                    continue
                clip = norms[k]
                sumed_grad = torch.zeros_like(v)
                for client_state in weights:
                    grad = client_state[k] - v
                    # Scale the delta down to at most ``clip`` in L2 norm.
                    sumed_grad += grad * min(1, clip / grad.norm(2))
                sigma = clip * self.noise_multiplier
                # Draw the noise on the same device/dtype as the tensor it is
                # added to.
                sumed_grad += torch.normal(
                    0, sigma, v.shape, dtype=v.dtype, device=v.device
                )
                value = v + sumed_grad / self.sample_clients
                # ``v`` is the tensor held by the model's state dict, so the
                # in-place copy updates the global model.
                v.copy_(value)


@proxy(PYUObject)
class LDPServer(CDPServer):
    """Server side of the local-DP (LDP) setting.

    Builds the global model like ``CDPServer``, then passes it through
    ``ModuleValidator.fix`` and the Opacus privacy engine's ``_prepare_model``.
    ``agg_updates`` averages client deltas without clipping or adding noise.
    """

    # NOTE(review): ``CDPServer`` is itself decorated with ``@proxy`` -
    # confirm that subclassing the decorated class behaves as intended.

    def __init__(
        self,
        device,
        model,
        n_classes,
        input_shape,
        noise_multiplier=1,
        sample_clients=10,
        disc_lr=1,
    ):
        """Build the global model and prepare it with Opacus.

        Args:
            device: Forwarded to the parent constructor.
            model: Callable that builds the network; its name (or the object
                itself when it is not a class) is searched for ``"IN"``.
            n_classes: Number of output classes for the model.
            input_shape: Input shape forwarded to the parent constructor.
            noise_multiplier: Forwarded to the parent constructor.
            sample_clients: Divisor applied to the summed client deltas.
            disc_lr: Forwarded to the parent constructor.
        """
        # The parent declares ``input_shape`` before ``n_classes`` while this
        # signature has them the other way round; forward them by keyword so
        # the two values cannot be swapped.
        super().__init__(
            device,
            model,
            input_shape=input_shape,
            n_classes=n_classes,
            noise_multiplier=noise_multiplier,
            sample_clients=sample_clients,
            disc_lr=disc_lr,
        )
        self.model = ModuleValidator.fix(self.model)
        self.privacy_engine = opacus.PrivacyEngine()
        self.model = self.privacy_engine._prepare_model(self.model)
        model_name = model.__name__ if isinstance(model, type) else model
        self.agg = "IN" not in model_name

    def agg_updates(self, weights):
        """Average client deltas into the global model.

        Args:
            weights: Sequence of client state dicts keyed like the global
                model's state dict.
        """
        logging.warning("LDP Server aggregating updates")
        # Entries whose key contains any of these substrings are not
        # aggregated; "IN" models (``self.agg`` False) skip two more.
        skipped = ("bn",) if self.agg else ("bn", "norm", "downsample.1")
        with torch.no_grad():
            for k, v in self.get_model_state_dict().items():
                if any(tag in k for tag in skipped):
                    continue
                sumed_grad = torch.zeros_like(v)
                for client_state in weights:
                    sumed_grad += client_state[k] - v
                value = v + sumed_grad / self.sample_clients
                # ``v`` is the tensor held by the model's state dict, so the
                # in-place copy updates the global model.
                v.copy_(value)

## 6.FedAverage

In [6]:
from datetime import date
import argparse
import time

# Wall-clock start of the run.
start_time = time.time()

# Experiment settings, fixed in the notebook rather than parsed from argv.
args = argparse.Namespace(
    data="mnist",
    nclient=50,
    nclass=10,
    ncpc=2,
    model="mnist_fully_connected_IN",
    mode="CDP",
    round=60,
    epsilon=2,
    physical_bs=64,
    sr=1.0,
    lr=5e-3,
    flr=1e-2,
    E=1,
)

today = date.today().isoformat()

# Dataset and federation layout.
DATA_NAME, NUM_CLIENTS = args.data, args.nclient
NUM_CLASSES, NUM_CLASES_PER_CLIENT = args.nclass, args.ncpc

# Model and privacy mode.
MODEL, MODE = args.model, args.mode

# Optimisation schedule.
EPOCHS, ROUNDS = 1, args.round
BATCH_SIZE = 64
LEARNING_RATE_DIS, LEARNING_RATE_F = args.lr, args.flr
mp_bs = args.physical_bs

# Differential-privacy budget and per-round client sampling rate.
target_epsilon, target_delta = args.epsilon, 1e-3
sample_rate = args.sr

In [None]:
import secretflow as sf

# Initialise SecretFlow: shut down any running instance first, then register
# one party for the server and one per client.
sf.shutdown()
party_names = ["server"] + [f"client_{i}" for i in range(args.nclient)]
sf.init(party_names, address="local", num_gpus=1)

# Create a PYU (Party Unit) for the server and for each client.
server_pyu = sf.PYU("server")
client_pyus = [sf.PYU(f"client_{i}") for i in range(args.nclient)]

os.makedirs(f"log/E{args.E}", exist_ok=True)

# Keyword arguments shared by both modes; mode-specific ones are added below.
user_param = {"disc_lr": LEARNING_RATE_DIS, "epochs": EPOCHS}
server_param = {}
if MODE not in ("LDP", "CDP"):
    raise ValueError("Choose mode from [CDP, LDP]")
if MODE == "LDP":
    user_obj, server_obj = LDPUser, LDPServer
    user_param.update(
        rounds=ROUNDS,
        target_epsilon=target_epsilon,
        target_delta=target_delta,
        sr=sample_rate,
        mp_bs=mp_bs,
    )
else:
    user_obj, server_obj = CDPUser, CDPServer
    user_param["flr"] = LEARNING_RATE_F
    # The server's noise multiplier is derived from the target budget over
    # ROUNDS steps at the configured sample rate.
    server_param.update(
        noise_multiplier=opacus.accountants.utils.get_noise_multiplier(
            target_epsilon=target_epsilon,
            target_delta=target_delta,
            sample_rate=sample_rate,
            steps=ROUNDS,
        ),
        sample_clients=sample_rate * NUM_CLIENTS,
    )

In [8]:
# Dataset root per dataset name; any name not listed uses "~/torch_data".
root = {
    "purchase": "data/purchase/dataset_purchase",
    "chmnist": "data/CHMNIST",
}.get(DATA_NAME, "~/torch_data")

# Per-client train and test loaders.
train_dataloaders, test_dataloaders = gen_random_loaders(
    DATA_NAME,
    root,
    NUM_CLIENTS,
    BATCH_SIZE,
    NUM_CLASES_PER_CLIENT,
    NUM_CLASSES,
)

# print(user_param)


每个客户端的类别分布:
客户端 1:
  类别 0: 概率 0.1000
  类别 2: 概率 0.1000
客户端 2:
  类别 3: 概率 0.1000
  类别 1: 概率 0.1000
客户端 3:
  类别 4: 概率 0.1000
  类别 7: 概率 0.1000
客户端 4:
  类别 5: 概率 0.1000
  类别 6: 概率 0.1000
客户端 5:
  类别 9: 概率 0.1000
  类别 8: 概率 0.1000
客户端 6:
  类别 9: 概率 0.1000
  类别 1: 概率 0.1000
客户端 7:
  类别 4: 概率 0.1000
  类别 5: 概率 0.1000
客户端 8:
  类别 6: 概率 0.1000
  类别 0: 概率 0.1000
客户端 9:
  类别 8: 概率 0.1000
  类别 2: 概率 0.1000
客户端 10:
  类别 7: 概率 0.1000
  类别 3: 概率 0.1000
客户端 11:
  类别 6: 概率 0.1000
  类别 9: 概率 0.1000
客户端 12:
  类别 7: 概率 0.1000
  类别 5: 概率 0.1000
客户端 13:
  类别 0: 概率 0.1000
  类别 3: 概率 0.1000
客户端 14:
  类别 4: 概率 0.1000
  类别 8: 概率 0.1000
客户端 15:
  类别 2: 概率 0.1000
  类别 1: 概率 0.1000
客户端 16:
  类别 2: 概率 0.1000
  类别 0: 概率 0.1000
客户端 17:
  类别 7: 概率 0.1000
  类别 8: 概率 0.1000
客户端 18:
  类别 3: 概率 0.1000
  类别 9: 概率 0.1000
客户端 19:
  类别 1: 概率 0.1000
  类别 6: 概率 0.1000
客户端 20:
  类别 4: 概率 0.1000
  类别 5: 概率 0.1000
客户端 21:
  类别 3: 概率 0.1000
  类别 1: 概率 0.1000
客户端 22:
  类别 4: 概率 0.1000
  类别 7: 概率 0.1000
客户端 23:
  类别 0: 概率 0.1000
  

In [9]:
# The ``device`` argument is the party's PYU: in SecretFlow a compute device
# is represented by a PYU.
# Model-related keyword arguments are identical for every user and the server.
_shared_kwargs = {
    "model": globals()[MODEL],
    "input_shape": None,
    "n_classes": NUM_CLASSES,
}

# One user per client, each placed on its own PYU with its own train loader.
users = [
    user_obj(
        client_idx,
        client_pyus[client_idx],
        device=client_pyus[client_idx],
        **_shared_kwargs,
        train_dataloader=train_dataloaders[client_idx],
        **user_param,
    )
    for client_idx in range(NUM_CLIENTS)
]

server = server_obj(
    server_pyu,
    device=server_pyu,
    **_shared_kwargs,
    **server_param,
)

In [10]:
# Give every client the server's initial model.
server_state_dict = server.get_model_state_dict()
for user in users:
    # Move the server's state dict to the client's device before loading it.
    user.set_model_state_dict(server_state_dict.to(user.device))

best_acc = 0
# ``round_idx`` rather than ``round`` so the builtin is not shadowed.
for round_idx in range(ROUNDS):
    # Sample a fixed-size subset of clients without replacement and train them.
    random_index = np.random.choice(
        NUM_CLIENTS, int(sample_rate * NUM_CLIENTS), replace=False
    )
    for index in random_index:
        users[index].train()

    if MODE == "LDP":
        # NOTE(review): ``agg_weights`` receives whatever
        # ``get_model_state_dict()`` returns on the driver side - confirm it
        # can consume those objects without an explicit device transfer.
        weights_agg = agg_weights(
            [users[index].get_model_state_dict() for index in random_index]
        )
        for user in users:
            user.set_model_state_dict(weights_agg)
    else:
        # Ship the sampled clients' weights to the server, aggregate there,
        # then broadcast the new global model back to every client.
        server.agg_updates(
            [
                users[index].get_model_state_dict().to(server.device)
                for index in random_index
            ]
        )
        server_state_dict = server.get_model_state_dict()
        for user in users:
            user.set_model_state_dict(server_state_dict.to(user.device))

    print(f"Round: {round_idx+1}")
    acc = evaluate_global(users, test_dataloaders, range(NUM_CLIENTS))
    if acc > best_acc:
        best_acc = acc
    if MODE == "LDP":
        # Stop as soon as any client has exceeded the target budget.
        eps = max([user.epsilon for user in users])
        print(f"Epsilon: {eps}")
        if eps > target_epsilon:
            break


end_time = time.time()
print("Use time: {:.2f}h".format((end_time - start_time) / 3600.0))
print(f"Best accuracy: {best_acc}")

# Build the single-row result table directly instead of going through the
# private ``DataFrame._append``.
results_df = pd.DataFrame(
    [
        {
            "data": DATA_NAME,
            "num_client": NUM_CLIENTS,
            "ncpc": NUM_CLASES_PER_CLIENT,
            "mode": MODE,
            "model": MODEL,
            "epsilon": target_epsilon,
            "accuracy": best_acc,
        }
    ],
    columns=["data", "num_client", "ncpc", "mode", "model", "epsilon", "accuracy"],
)
results_df.to_csv(
    f"log/E{args.E}/{DATA_NAME}_{NUM_CLIENTS}_{NUM_CLASES_PER_CLIENT}_{MODE}_{MODEL}_{target_epsilon}.csv",
    index=False,
)

sf.shutdown()

Round: 1
全局准确率: 0.0846
Round: 2
全局准确率: 0.1086
Round: 3
全局准确率: 0.1934
Round: 4
全局准确率: 0.1451
Round: 5
全局准确率: 0.1716
Round: 6
全局准确率: 0.1297
Round: 7
全局准确率: 0.1223
Round: 8
全局准确率: 0.2045
Round: 9
全局准确率: 0.2635
Round: 10
全局准确率: 0.3911
Round: 11
全局准确率: 0.5992
Round: 12
全局准确率: 0.6683
Round: 13
全局准确率: 0.8053
Round: 14
全局准确率: 0.8786
Round: 15
全局准确率: 0.9037
Round: 16
全局准确率: 0.9173
Round: 17
全局准确率: 0.9308
Round: 18
全局准确率: 0.9382
Round: 19
全局准确率: 0.9411
Round: 20
全局准确率: 0.9459
Round: 21
全局准确率: 0.9485
Round: 22
全局准确率: 0.9516
Round: 23
全局准确率: 0.9520
Round: 24
全局准确率: 0.9550
Round: 25
全局准确率: 0.9520
Round: 26
全局准确率: 0.9588
Round: 27
全局准确率: 0.9607
Round: 28
全局准确率: 0.9584
Round: 29
全局准确率: 0.9598
Round: 30
全局准确率: 0.9646
Round: 31
全局准确率: 0.9599
Round: 32
全局准确率: 0.9641
Round: 33
全局准确率: 0.9648
Round: 34
全局准确率: 0.9640
Round: 35
全局准确率: 0.9582
Round: 36
全局准确率: 0.9650
Round: 37
全局准确率: 0.9655
Round: 38
全局准确率: 0.9654
Round: 39
全局准确率: 0.9649
Round: 40
全局准确率: 0.9665
Round: 41
全局准确率: 0.9675
Round: 42
全局准确率: 0.9645
R

