In [1]:
from hashlib import md5
from typing import List, Set, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from cnn_net import HalfConvNet
from torch.utils.data import DataLoader, Subset

from alphafed import logger
from alphafed.hetero_nn import HeteroNNCollaboratorScheduler


class DemoHeteroCollaborator(HeteroNNCollaboratorScheduler):
    """Demo collaborator for heterogeneous federated learning on MNIST.

    The collaborator only contributes features: it keeps the image columns
    from ``_KEPT_COLUMN_START`` onward (the right part of each image), runs
    them through its own feature model and yields the resulting feature
    batches. Labels are discarded on this side.
    """

    # Test sample IDs are shifted by this offset so that they never collide
    # with training sample IDs.
    _TEST_ID_OFFSET = 100000
    # First image column kept by this party; everything to its left is dropped.
    _KEPT_COLUMN_START = 14
    # Root directory handed to torchvision for the MNIST files.
    _MNIST_ROOT = 'data'

    def __init__(self,
                 feature_key: str,
                 batch_size: int,
                 schedule_timeout: int = 30,
                 is_feature_trainable: bool = True) -> None:
        """Create the collaborator.

        Args:
            feature_key: Forwarded unchanged to the base scheduler.
            batch_size: Batch size used by both the train and test loaders.
            schedule_timeout: Forwarded unchanged to the base scheduler.
            is_feature_trainable: Forwarded unchanged to the base scheduler.
        """
        super().__init__(feature_key=feature_key,
                         schedule_timeout=schedule_timeout,
                         is_feature_trainable=is_feature_trainable)
        self.batch_size = batch_size

    def load_local_ids(self) -> List[str]:
        """Return the sample IDs held by this collaborator, as strings.

        The aggregator takes the first 20000 MNIST training samples and the
        first 5000 test samples; the collaborator takes training samples
        10000 - 30000 and test samples 3000 - 7000. After intersection there
        are 10000 training samples and 2000 test samples. To avoid ID
        collisions every test sample ID is shifted by ``_TEST_ID_OFFSET`` so
        that it cannot coincide with a training sample ID.
        """
        train_ids = [str(i) for i in range(10000, 30000)]
        test_ids = [str(i) for i in range(self._TEST_ID_OFFSET + 3000,
                                          self._TEST_ID_OFFSET + 7000)]
        return train_ids + test_ids

    def split_dataset(self, id_intersection: Set[str]) -> Tuple[Set[str], Set[str]]:
        """Split the intersected IDs into a training set and a test set.

        In this demo training sample IDs are < ``_TEST_ID_OFFSET`` and test
        sample IDs are >= ``_TEST_ID_OFFSET``. The split is done on that
        threshold rather than on a fixed count, so it stays correct when the
        intersection does not contain exactly 10000 training samples.

        Args:
            id_intersection: IDs shared by all participants, as numeric strings.

        Returns:
            ``(train_ids, test_ids)``. Note that, despite the annotation, the
            IDs in both sets are ``int`` values (they are converted here and
            consumed as integers by the iterate methods of this class).
        """
        ids = sorted(int(_id) for _id in id_intersection)
        train_ids = [i for i in ids if i < self._TEST_ID_OFFSET]
        test_ids = [i for i in ids if i >= self._TEST_ID_OFFSET]

        logger.info(f'Got {len(train_ids)} intersecting samples for training.')
        logger.info(f'Got {len(test_ids)} intersecting samples for testing.')

        return set(train_ids), set(test_ids)

    def build_feature_model(self) -> nn.Module:
        """Build this collaborator's feature model.

        In practice every collaborator may define its own feature model; it
        does not have to match any other participant's.
        """
        return HalfConvNet()

    def build_feature_optimizer(self, feature_model: nn.Module) -> optim.Optimizer:
        """Build the optimizer for ``feature_model``.

        In practice every collaborator may define its own optimizer; it does
        not have to match any other participant's.
        """
        return optim.SGD(feature_model.parameters(), lr=0.01, momentum=0.9)

    def _erase_left(self, _image: torch.Tensor) -> torch.Tensor:
        """Drop the leftmost ``_KEPT_COLUMN_START`` entries of the last axis.

        The input must be 4-dimensional; presumably it is a
        (batch, channel, height, width) image batch from the DataLoader.
        """
        return _image[:, :, :, self._KEPT_COLUMN_START:]

    def _load_mnist(self, train: bool) -> torchvision.datasets.MNIST:
        """Load (downloading if necessary) the MNIST train or test split."""
        transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            # Normalize(mean, std) with the same constants for both splits.
            torchvision.transforms.Normalize((0.1307,), (0.3081,))
        ])
        return torchvision.datasets.MNIST(self._MNIST_ROOT,
                                          train=train,
                                          download=True,
                                          transform=transform)

    def _iter_features(self, feature_model: nn.Module, loader: DataLoader):
        """Yield ``feature_model`` outputs for every batch of ``loader``."""
        for images, _ in loader:  # labels are discarded
            yield feature_model(self._erase_left(images))

    def iterate_train_feature(self,
                              feature_model: nn.Module,
                              train_ids: Set[str]) -> torch.Tensor:
        """Yield one feature batch per training batch (this is a generator).

        Args:
            feature_model: Model applied to each cropped image batch.
            train_ids: Training sample IDs; used as integer indices into the
                MNIST training set.

        Yields:
            The output of ``feature_model`` for each batch.
        """
        train_dataset = self._load_mnist(train=True)

        # Pseudo-shuffle that changes with self.current_round but is fully
        # deterministic. bytes(n) for an int n is n zero bytes, so the sort
        # key is the md5 digest of a zero-filled buffer of that length.
        # NOTE(review): the other participants presumably need to produce the
        # same order for batches to line up - do not change this key on one
        # side only.
        ordered_ids = sorted(
            train_ids,
            key=lambda x: md5(bytes(x + self.current_round)).digest()
        )

        self.train_loader = DataLoader(Subset(train_dataset, ordered_ids),
                                       batch_size=self.batch_size,
                                       shuffle=False)

        yield from self._iter_features(feature_model, self.train_loader)

    def iterate_test_feature(self,
                             feature_model: nn.Module,
                             test_ids: Set[str]) -> torch.Tensor:
        """Yield one feature batch per test batch (this is a generator).

        Args:
            feature_model: Model applied to each cropped image batch.
            test_ids: Test sample IDs (integers carrying ``_TEST_ID_OFFSET``).

        Yields:
            The output of ``feature_model`` for each batch.
        """
        test_dataset = self._load_mnist(train=False)

        # Remove the ID offset to get real dataset indices. Sorting makes the
        # batch order explicit instead of depending on set iteration order.
        dataset_indices = sorted(i - self._TEST_ID_OFFSET for i in test_ids)

        self.test_loader = DataLoader(Subset(test_dataset, dataset_indices),
                                      batch_size=self.batch_size,
                                      shuffle=False)

        yield from self._iter_features(feature_model, self.test_loader)

In [2]:
from alphafed import mock_context


# Mock launch script for the collaborator; it has to be copied into a separate
# notebook and executed there.
# Unlike horizontal federated learning, in heterogeneous federated learning the
# collaborator's scheduler instance differs from the aggregator's.

# All three IDs below must match the values configured on the aggregator side.
aggregator_id = '9afb99da-2e4c-4676-a434-5f312f583947'
collaborate_id = 'ec72a25c-ff83-47f2-a4bc-74321acace99'
task_id = '5a8c6b50-6267-4258-afb0-a4e7413ead58'

scheduler = DemoHeteroCollaborator(
    feature_key='demo_collaborator',
    # Every participant must use the same batch_size, otherwise the data
    # cannot be aligned.
    batch_size=128,
)

# In a real run the algorithm obtains the list of participating Node IDs from
# the task manager. The mock environment cannot call that interface, so the
# list is supplied through the `nodes` argument instead.
participant_nodes = [aggregator_id, collaborate_id]

with mock_context(id=collaborate_id, nodes=participant_nodes):  # run inside the mock debugging environment
    # `_run` is defined by the lowest-level Scheduler interface; both the
    # horizontal and the heterogeneous federated frameworks implement it.
    scheduler._run(id=collaborate_id, task_id=task_id)

2023-02-02 08:39:44,874|DEBUG|scheduler|_switch_status|125:
self.status='init'
2023-02-02 08:39:44,878|INFO|scheduler|push_log|118:
Begin to validate local context.
2023-02-02 08:39:44,879|INFO|scheduler|push_log|118:
Local context is ready.
2023-02-02 08:39:44,880|INFO|scheduler|push_log|118:
Node ec72a25c-ff83-47f2-a4bc-74321acace99 is up.
2023-02-02 08:39:44,880|DEBUG|scheduler|_switch_status|125:
self.status='gethoring'
2023-02-02 08:39:44,881|INFO|scheduler|push_log|118:
Checking in the task ...
2023-02-02 08:39:44,889|DEBUG|hetero_nn|_check_in|974:
_wait_for_check_in_response ...
2023-02-02 08:39:45,914|INFO|scheduler|push_log|118:
Node ec72a25c-ff83-47f2-a4bc-74321acace99 have taken part in the task.
2023-02-02 08:39:45,915|DEBUG|scheduler|_switch_status|125:
self.status='ready'
2023-02-02 08:39:45,916|DEBUG|scheduler|_switch_status|125:
self.status='synchronizing'
2023-02-02 08:39:45,917|INFO|scheduler|push_log|118:
Waiting for synchronizing state with the host ...
2023-02-02 0