# FedGen个性化联邦学习例子

In [None]:
# Enable IPython's autoreload extension in mode 2.
%load_ext autoreload
%autoreload 2

## 在SecretFlow环境中创建3个实体[Alice, Bob, Charlie]，其中Alice、Bob和Charlie是三个PYU，Alice和Bob的角色是client，Charlie的角色是server。

In [None]:
import secretflow as sf

# Report the installed SecretFlow version.
print('The version of SecretFlow: {}'.format(sf.__version__))

# Stop any runtime that may still be running from an earlier execution,
# then start a fresh local runtime hosting the three parties.
sf.shutdown()
party_names = ['alice', 'bob', 'charlie']
sf.init(party_names, address='local')

# One PYU per party: alice and bob act as clients, charlie as the server.
alice, bob, charlie = (sf.PYU(name) for name in party_names)

In [None]:
# NOTE(review): `spu` is not referenced again in this notebook (the FedGen
# aggregator below runs on the `charlie` PYU); confirm it is still needed.
spu_cluster = sf.utils.testing.cluster_def(['alice', 'bob'])
spu = sf.SPU(spu_cluster)

## 导入相关依赖

In [None]:
from sfl.ml.nn.core.torch import (
    metric_wrapper,
    optim_wrapper,
    BaseModule,
    TorchModel,
)
from sfl.ml.nn import FLModel
from torchmetrics import Accuracy, Precision
from sfl.security.aggregation import SecureAggregator
from sfl.utils.simulation.datasets_fl import load_mnist
from torch import nn, optim
from torch.nn import functional as F
import torch

## 数据划分：这里按2:8的比例模拟数据不平衡（Alice占20%，Bob占80%）

In [None]:
# Imbalanced 2:8 split of MNIST between the two clients.
client_shares = {alice: 0.2, bob: 0.8}
(train_data, train_label), (test_data, test_label) = load_mnist(
    parts=client_shares,
    normalized_x=True,
    categorical_y=True,
    is_torch=True,
)

## 定义一个神经网络模型，输出为logits（未经过softmax）

In [None]:
class ConvNet(BaseModule):
    """Small ConvNet for MNIST used as the FedGen client model.

    Pipeline: conv(1 -> 3 channels, 3x3) -> max-pool(3) -> ReLU -> flatten ->
    linear head producing ``num_classes`` logits. For 1x28x28 inputs the
    flattened feature vector has ``3 * 8 * 8 = 192`` entries (``fc_in_dim``).

    Args:
        kl_div_loss: loss module stored as ``self.kl_div_loss``; it is not used
            inside ``forward`` (presumably read by the FedGen strategy — confirm).
        num_classes: number of output logits of the classifier head.
    """

    def __init__(self, kl_div_loss, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        # 28x28 -> 26x26 after the 3x3 conv -> 8x8 after the 3x3 max-pool,
        # times 3 channels = 192 features.
        self.fc_in_dim = 192
        # Size the head from num_classes (previously hard-coded to 10) so the
        # layer always agrees with the stored num_classes attribute.
        self.fc = nn.Linear(self.fc_in_dim, num_classes)
        self.kl_div_loss = kl_div_loss
        self.num_classes = num_classes

    def forward(self, x, start_layer_idx=0):
        """Return un-normalized class logits.

        Args:
            x: an image batch for the full network, or a batch of
                ``fc_in_dim``-sized features when ``start_layer_idx == -1``.
            start_layer_idx: ``-1`` skips the convolutional feature extractor
                and feeds ``x`` straight into the classifier head (presumably
                for the FedGen generator's 192-d latents — confirm); any other
                value runs the whole network.

        Returns:
            Logits whose last dimension is ``num_classes``.
        """
        if start_layer_idx != -1:
            x = F.relu(F.max_pool2d(self.conv1(x), 3))
            x = x.view(-1, self.fc_in_dim)
        return self.fc(x)

定义神经网络模型的损失函数和优化器

In [None]:
# Optimizer factory: Adam with learning rate 0.01.
optim_fn = optim_wrapper(optim.Adam, lr=0.01)
# Loss factory: the CrossEntropyLoss class itself (not an instance) is handed
# to TorchModel and to the FedGen generator.
loss_fn = nn.CrossEntropyLoss

## FedGen相关准备工作：生成器模型等

DiversityLoss是一个自定义的损失函数类，在训练FedGen的生成器时使用。

In [None]:
class DiversityLoss(nn.Module):
    """Diversity regularizer for the FedGen generator.

    The loss equals ``exp(mean(-d_noise * d_layer))``, where ``d_noise`` and
    ``d_layer`` are the pairwise distance matrices of the noise vectors and of
    the (flattened) generated layer. Minimizing it pushes apart generated
    outputs whose input noises are far apart.
    """

    def __init__(self, metric):
        """Create the loss.

        Args:
            metric: distance used on the generated layer: ``'l1'``, ``'l2'`` or
                ``'cosine'``; any other value makes ``compute_distance`` raise
                ``ValueError``.
        """
        super().__init__()
        self.metric = metric
        # Similarity along the feature axis of the 3-D pairwise tensors.
        self.cosine = nn.CosineSimilarity(dim=2)

    def compute_distance(self, tensor1, tensor2, metric):
        """Distance between two broadcastable 3-D tensors, reduced over dim 2.

        ``'l1'`` is the mean absolute difference, ``'l2'`` the mean squared
        difference, and ``'cosine'`` is ``1 - cosine similarity``.

        Raises:
            ValueError: if ``metric`` is not one of the names above.
        """
        if metric == 'l1':
            per_feature = torch.abs(tensor1 - tensor2)
        elif metric == 'l2':
            per_feature = torch.pow(tensor1 - tensor2, 2)
        elif metric == 'cosine':
            return 1 - self.cosine(tensor1, tensor2)
        else:
            raise ValueError(metric)
        return per_feature.mean(dim=(2,))

    def pairwise_distance(self, tensor, how):
        """All-pairs distance between the rows of a 2-D tensor.

        Returns an ``(n, n)`` tensor whose entry ``[i, j]`` is the distance
        between row ``j`` and row ``i`` under metric ``how``.
        """
        num_rows, num_features = tensor.size(0), tensor.size(1)
        # tiled[i, j, :] == tensor[j] for every i.
        tiled = tensor.expand((num_rows, num_rows, num_features))
        # column[i, 0, :] == tensor[i]; broadcasts across j.
        column = tensor.unsqueeze(dim=1)
        return self.compute_distance(tiled, column, how)

    def forward(self, noises, layer):
        """Compute the diversity loss.

        Args:
            noises: 2-D tensor of noise vectors, one row per sample.
            layer: generator output for the same samples; every dimension
                after the first is flattened before distances are taken.

        Returns:
            Scalar tensor ``exp(mean(-noise_dist * layer_dist))``.
        """
        flat = layer if layer.dim() <= 2 else layer.view((layer.size(0), -1))
        layer_dist = self.pairwise_distance(flat, how=self.metric)
        noise_dist = self.pairwise_distance(noises, how='l2')
        return torch.exp(torch.mean(-noise_dist * layer_dist))

FedGen需要一个generator（生成器）模型，以及相关的训练参数设置、优化器等。

In [None]:
from sfl.ml.nn.fl.backend.torch.strategy.fed_gen import (
    FedGenGeneratorModel,
    FedGenActor,
)

num_classes = 10

# Loss modules used by FedGen.
kl_div_loss = nn.KLDivLoss(reduction="batchmean")
diversity_loss = DiversityLoss(metric='l1')
# NOTE(review): cross_entropy_loss is not referenced elsewhere in this
# notebook; confirm whether it is still needed.
cross_entropy_loss = nn.CrossEntropyLoss()

# latent_dimension=192 equals ConvNet.fc_in_dim, presumably so generated
# latents can be fed to the classifier head via start_layer_idx=-1 — confirm.
generator = FedGenGeneratorModel(
    num_classes=num_classes,
    noise_dim=64,
    hidden_dimension=256,
    latent_dimension=192,
    loss_fn=loss_fn,
    optim_fn=optim_fn,
    diversity_loss=diversity_loss,
)

## 进行联邦学习

In [None]:
from sfl.security.aggregation.stateful_fedgen_aggregator import (
    StatefulFedGenAggregator,
)

# Client model definition. ConvNet's constructor takes (kl_div_loss,
# num_classes); the matching keyword arguments below are presumably forwarded
# to it by TorchModel when each client builds its model — confirm.
# The metrics use the same num_classes as the model head.
model_def = TorchModel(
    model_fn=ConvNet,
    loss_fn=loss_fn,
    optim_fn=optim_fn,
    metrics=[
        metric_wrapper(
            Accuracy, task="multiclass", num_classes=num_classes, average='micro'
        ),
        metric_wrapper(
            Precision, task="multiclass", num_classes=num_classes, average='micro'
        ),
    ],
    kl_div_loss=kl_div_loss,
    num_classes=num_classes,
)

# alice and bob are the FL clients; charlie hosts the server-side FedGen actor
# (which holds the generator) and the aggregator. One client list is shared by
# the aggregator and the FLModel so the two cannot drift apart.
device_list = [alice, bob]
server_actor = FedGenActor(device=charlie, generator=generator)
aggregator = StatefulFedGenAggregator(charlie, device_list, server_actor)

# specify params
fl_model = FLModel(
    server=charlie,
    device_list=device_list,
    model=model_def,
    strategy="fed_gen",  # fl strategy
    backend="torch",  # backend support ['tensorflow', 'torch']
    aggregator=aggregator,
    generator=generator,
)
# Train for 20 epochs with batch size 32, validating on the test split;
# aggregate_freq controls how often client models are aggregated.
history = fl_model.fit(
    train_data,
    train_label,
    validation_data=(test_data, test_label),
    epochs=20,
    batch_size=32,
    aggregate_freq=1,
)

## 绘制结果

In [None]:
from matplotlib import pyplot as plt

# Plot the global model's per-epoch training and validation accuracy.
global_history = history["global_history"]
for curve_key in ('multiclassaccuracy', 'val_multiclassaccuracy'):
    plt.plot(global_history[curve_key])
plt.title('FLModel accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()