In [1]:
from hashlib import md5
from time import time
from typing import Dict, Iterator, List, Set, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from cnn_net import HalfConvNet
from torch.utils.data import DataLoader, Subset

from alphafed import logger
from alphafed.hetero_nn import HeteroNNHostScheduler


# The inference model used in this demo is tiny, so it is defined inline in the
# notebook. A more complex model could live in its own file and be imported,
# the same way HalfConvNet is.
class InferModule(nn.Module):
    """Two-layer fully connected classifier applied to the fused feature tensor.

    Maps a 40-dimensional input to log-probabilities over 10 classes through a
    20-unit ReLU hidden layer.
    """

    def __init__(self) -> None:
        super().__init__()
        # Layers are created in this order so parameter initialisation consumes
        # the RNG exactly as before.
        self.fc1 = nn.Linear(in_features=40, out_features=20)
        self.fc2 = nn.Linear(in_features=20, out_features=10)

    def forward(self, input):
        """Return log-softmax scores over the last dimension of the output."""
        hidden = self.fc1(input).relu()
        logits = self.fc2(hidden)
        return logits.log_softmax(dim=-1)


class DemoHeteroHost(HeteroNNHostScheduler):
    """Host-side scheduler of the heterogeneous-NN MNIST demo.

    This side loads the MNIST labels, keeps only part of every image (see
    `_erase_right`), projects it through `HalfConvNet`, and trains/evaluates
    `InferModule` on the concatenation of its own projection and the one
    stored under the `'demo_collaborator'` key.
    """

    # Test sample IDs are shifted by this offset so they never collide with
    # training sample IDs (training IDs are < offset, test IDs are >= offset).
    _TEST_ID_OFFSET = 100000

    def __init__(self,
                 feature_key: str,
                 batch_size: int,
                 max_rounds: int = 0,
                 calculation_timeout: int = 300,
                 log_rounds: int = 0) -> None:
        """Initialise the scheduler.

        Parameters defined by the `HeteroNNHostScheduler` parent class:

        feature_key:
            Key identifying this participant's features, so that feature
            sources can be told apart (and handled differently) during
            feature fusion.
        max_rounds:
            Number of training rounds.
        calculation_timeout:
            Timeout for local training (unit not shown here; presumably
            seconds - confirm against the parent class).
        log_rounds:
            Run a test and record the current training quality every this
            many rounds.

        Parameters added by `DemoHeteroHost`:

        batch_size:
            Batch size used by the train and test data loaders.
        """
        super().__init__(feature_key=feature_key,
                         max_rounds=max_rounds,
                         calculation_timeout=calculation_timeout,
                         log_rounds=log_rounds)
        self.batch_size = batch_size

    def load_local_ids(self) -> List[str]:
        """Return the sample IDs held by this participant.

        Per the original demo notes: the host holds the first 20000 MNIST
        training samples and the first 5000 test samples, while the
        collaborator holds training samples 10000-30000 and test samples
        3000-7000, so the intersection is 10000 training and 2000 test
        samples. Test IDs are shifted by `_TEST_ID_OFFSET` so they do not
        overlap with training IDs.
        """
        train_ids = [str(i) for i in range(0, 20000)]
        test_ids = [str(i) for i in range(self._TEST_ID_OFFSET,
                                          self._TEST_ID_OFFSET + 5000)]
        return train_ids + test_ids

    def split_dataset(self, id_intersection: Set[str]) -> Tuple[Set[int], Set[int]]:
        """Split the intersected IDs into training and test ID sets.

        The split follows the ID convention itself (training IDs are below
        `_TEST_ID_OFFSET`, test IDs are at or above it) rather than assuming
        a fixed number of training samples, so it stays correct whatever the
        size of the intersection is.

        Returns:
            `(train_ids, test_ids)` as sets of integer IDs. Integers are
            returned because `iterate_train_feature` and
            `iterate_test_feature` do integer arithmetic on them.
        """
        ids = [int(_id) for _id in id_intersection]
        train_ids = {i for i in ids if i < self._TEST_ID_OFFSET}
        test_ids = {i for i in ids if i >= self._TEST_ID_OFFSET}

        logger.info(f'Got {len(train_ids)} intersecting samples for training.')
        logger.info(f'Got {len(test_ids)} intersecting samples for testing.')

        return train_ids, test_ids

    def build_feature_model(self) -> nn.Module:
        """Create the local feature-extraction model."""
        return HalfConvNet()

    def build_feature_optimizer(self, feature_model: nn.Module) -> optim.Optimizer:
        """Create the SGD optimizer for the feature model."""
        return optim.SGD(feature_model.parameters(), lr=0.01, momentum=0.9)

    def _erase_right(self, _image: torch.Tensor) -> torch.Tensor:
        """Keep only the first 14 entries along the last dimension.

        Assumes a 4-D batch whose last dimension is the image width (28 for
        MNIST), i.e. this keeps the left half of every image - TODO confirm.
        """
        return _image[:, :, :, :14]

    def _load_mnist_dataset(self, train: bool):
        """Build the normalised MNIST dataset (train or test split)."""
        return torchvision.datasets.MNIST(
            'data',
            train=train,
            download=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307,), (0.3081,))
            ])
        )

    def _iterate_features(self,
                          feature_model: nn.Module,
                          loader: DataLoader) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
        """Yield `(feature_projection, labels)` for every batch of `loader`."""
        for _data, _labels in loader:
            _data = self._erase_right(_data)
            yield feature_model(_data), _labels

    def iterate_train_feature(self,
                              feature_model: nn.Module,
                              train_ids: Set[int]) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
        """Yield `(feature_projection, labels)` batches of training data.

        Samples are re-ordered every round by a hash derived from
        "sample ID + current round": this reshuffles the data per round while
        keeping every participant's order aligned.
        """
        train_dataset = self._load_mnist_dataset(train=True)
        # `bytes(n)` with an int builds `n` zero bytes, so the key is the md5
        # of a zero buffer whose length is `id + current_round`.
        # NOTE(review): this exact key must match the one used by the other
        # participants, otherwise batches are no longer aligned - do not
        # change it on one side only.
        ordered_ids = sorted(
            train_ids,
            key=lambda x: md5(bytes(x + self.current_round)).digest()
        )

        self.train_loader = DataLoader(Subset(train_dataset, ordered_ids),
                                       batch_size=self.batch_size,
                                       shuffle=False)
        yield from self._iterate_features(feature_model, self.train_loader)

    def iterate_test_feature(self,
                             feature_model: nn.Module,
                             test_ids: Set[int]) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
        """Yield `(feature_projection, labels)` batches of test data."""
        test_dataset = self._load_mnist_dataset(train=False)
        # Undo the ID offset to recover indices into the MNIST test split.
        # NOTE(review): the order follows the iteration order of `test_ids`;
        # presumably every participant iterates it identically - confirm.
        test_indices = [(i - self._TEST_ID_OFFSET) for i in test_ids]

        self.test_loader = DataLoader(Subset(test_dataset, test_indices),
                                      batch_size=self.batch_size,
                                      shuffle=False)
        yield from self._iterate_features(feature_model, self.test_loader)

    def build_infer_model(self) -> nn.Module:
        """Create the inference model applied to the fused features."""
        return InferModule()

    def build_infer_optimizer(self, infer_model: nn.Module) -> optim.Optimizer:
        """Create the SGD optimizer for the inference model."""
        return optim.SGD(infer_model.parameters(), lr=0.01, momentum=0.9)

    def _fuse_features(self, feature_projection: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Concatenate host and collaborator projections along dim 1."""
        return torch.concat((feature_projection['demo_host'],
                             feature_projection['demo_collaborator']), dim=1)

    def train_a_batch(self, feature_projection: Dict[str, torch.Tensor], labels: torch.Tensor):
        """Run one optimisation step of the inference model on a fused batch.

        `self.optimizer` and `self.infer_model` are not set in this class;
        presumably the parent scheduler provides them - confirm.
        """
        fusion_tensor = self._fuse_features(feature_projection)
        self.optimizer.zero_grad()
        out = self.infer_model(fusion_tensor)
        loss = F.nll_loss(out, labels)
        loss.backward()
        self.optimizer.step()

    def run_test(self,
                 batched_feature_projections: List[Dict[str, torch.Tensor]],
                 batched_labels: List[torch.Tensor]):
        """Evaluate the inference model and write metrics to TensorBoard.

        Testing also needs every participant to provide the feature tensors
        of its local test data; each item of `batched_feature_projections`
        maps a feature key to that participant's projection for one batch.

        Logged scalars: evaluation time, per-sample average NLL loss,
        accuracy (0-1) and correct rate (0-100).
        """
        start = time()
        total_loss = 0.0
        correct = 0
        total = 0
        # No gradients are needed for evaluation; this also avoids keeping
        # the autograd graph of every batch alive through the loss sum.
        with torch.no_grad():
            for _feature_projection, _labels in zip(batched_feature_projections,
                                                    batched_labels):
                out: torch.Tensor = self.infer_model(self._fuse_features(_feature_projection))
                # Sum (not mean) per batch so that dividing by the sample
                # count below yields a true per-sample average.
                total_loss += F.nll_loss(out, _labels, reduction='sum').item()
                pred = out.max(1, keepdim=True)[1]
                correct += pred.eq(_labels.view_as(pred)).sum().item()
                total += _labels.size(0)

        end = time()

        if total == 0:
            # Nothing to evaluate: avoid dividing by zero.
            logger.info('No test samples received, skip recording test results.')
            return

        test_loss = total_loss / total
        accuracy = correct / total
        correct_rate = 100. * accuracy

        self.tb_writer.add_scalar('timer/run_time', end - start, self.current_round)
        self.tb_writer.add_scalar('test_results/average_loss', test_loss, self.current_round)
        self.tb_writer.add_scalar('test_results/accuracy', accuracy, self.current_round)
        self.tb_writer.add_scalar('test_results/correct_rate', correct_rate, self.current_round)

In [2]:
from alphafed import mock_context


# Mock launch script for the aggregator (host) side.

# Made-up IDs used only for this simulated run.
aggregator_id = '9afb99da-2e4c-4676-a434-5f312f583947'
collaborate_id = 'ec72a25c-ff83-47f2-a4bc-74321acace99'
task_id = '5a8c6b50-6267-4258-afb0-a4e7413ead58'

scheduler = DemoHeteroHost(
    feature_key='demo_host',
    # Every participant must use the same batch_size, otherwise the data
    # cannot be aligned.
    batch_size=128,
    # Local resources may make the run slow; keep this low while debugging.
    max_rounds=5,
    # Local resources may make the run slow; too small a value easily times out.
    calculation_timeout=1800,
    log_rounds=1
)

# In a real run the algorithm fetches the list of participating node IDs from
# the task manager. The mock environment cannot call that interface, so the
# list is supplied through the `nodes` argument instead.
with mock_context(id=aggregator_id, nodes=[aggregator_id, collaborate_id]):
    # `_run` is defined by the lowest-level Scheduler interface; both the
    # horizontal and the heterogeneous federated frameworks implement it.
    scheduler._run(id=aggregator_id, task_id=task_id, is_initiator=True)

2023-02-02 08:39:39,538|DEBUG|scheduler|_switch_status|125:
self.status='init'
2023-02-02 08:39:39,541|INFO|scheduler|push_log|118:
Begin to validate local context.
2023-02-02 08:39:39,543|INFO|scheduler|push_log|118:
Local context is ready.
2023-02-02 08:39:39,544|INFO|scheduler|push_log|118:
Node 9afb99da-2e4c-4676-a434-5f312f583947 is up.
2023-02-02 08:39:39,545|DEBUG|scheduler|_switch_status|125:
self.status='gethoring'
2023-02-02 08:39:39,546|INFO|scheduler|push_log|118:
Waiting for participants taking part in ...
2023-02-02 08:39:45,561|INFO|scheduler|push_log|118:
Welcome a new partner ID: ec72a25c-ff83-47f2-a4bc-74321acace99.
2023-02-02 08:39:45,562|INFO|scheduler|push_log|118:
There are 1 partners now.
2023-02-02 08:39:45,569|INFO|scheduler|push_log|118:
All partners have gethored.
2023-02-02 08:39:45,569|DEBUG|scheduler|_switch_status|125:
self.status='ready'
2023-02-02 08:39:45,570|DEBUG|scheduler|_switch_status|125:
self.status='synchronizing'
2023-02-02 08:39:45,571|INFO|s