In [1]:
from time import time

import torch
import torch.nn.functional as F
import torchvision
from cnn_net import ConvNet
from torch import nn, optim
from torch.utils.data import DataLoader

from alphafed.fed_avg import FedAvgScheduler


class DemoAvg(FedAvgScheduler):
    """FedAvg demo scheduler: a ``ConvNet`` trained with SGD on MNIST.

    Supplies the model, optimizer, train/test data loaders, the logic for one
    local training epoch, and the logic for one evaluation pass.
    """

    def __init__(self,
                 max_rounds: int = 0,
                 merge_epochs: int = 1,
                 calculation_timeout: int = 300,
                 log_rounds: int = 0,
                 involve_aggregator: bool = False,
                 batch_size: int = 128,
                 learning_rate: float = 0.01,
                 momentum: float = 0.9,
                 seed: int = 42) -> None:
        """Initialize the scheduler.

        Parameters defined by the parent class `FedAvgScheduler`:
        max_rounds:
            Number of training rounds.
        merge_epochs:
            Number of local epochs to train before each parameter aggregation.
        calculation_timeout:
            Timeout for local training.
        log_rounds:
            Run a test every this many rounds to evaluate and record the
            current training performance.
        involve_aggregator:
            Whether the aggregator also trains on its own local data. By
            default it does not take part in training and only aggregates.

        Extension parameters added by `DemoAvg`:
        batch_size, learning_rate, momentum:
            Training hyper-parameters (DataLoader batch size, SGD learning
            rate and SGD momentum).
        seed:
            Seed passed to `torch.manual_seed` (default 42).
        """
        super().__init__(max_rounds=max_rounds,
                         merge_epochs=merge_epochs,
                         calculation_timeout=calculation_timeout,
                         log_rounds=log_rounds,
                         involve_aggregator=involve_aggregator)
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.momentum = momentum

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.seed = seed
        # Seeds PyTorch's global RNG, which drives model weight initialization
        # and the shuffling order of DataLoaders without an explicit generator.
        torch.manual_seed(self.seed)

    def build_model(self) -> nn.Module:
        """Return a fresh ``ConvNet`` placed on ``self.device``."""
        model = ConvNet()
        return model.to(self.device)

    def build_optimizer(self, model: nn.Module) -> optim.Optimizer:
        """Return an SGD-with-momentum optimizer over ``model``'s parameters."""
        return optim.SGD(model.parameters(),
                         lr=self.learning_rate,
                         momentum=self.momentum)

    def _build_mnist_dataloader(self, train: bool, shuffle: bool) -> DataLoader:
        """Return a DataLoader over the MNIST train (``train=True``) or test split.

        The dataset is downloaded into the ``data`` directory when missing and
        each image is converted to a tensor and normalized with mean 0.1307 and
        std 0.3081.
        """
        transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,))
        ])
        dataset = torchvision.datasets.MNIST(
            'data',  # download directory
            train=train,
            download=True,
            transform=transform
        )
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle)

    # Downloads the MNIST training split over the network for this demo. In a
    # real deployment the developer must implement their own data loading.
    def build_train_dataloader(self) -> DataLoader:
        """Return a shuffled DataLoader over the MNIST training split."""
        return self._build_mnist_dataloader(train=True, shuffle=True)

    # This demo uses no validation set, so build_validation_dataloader is not
    # implemented.

    # Downloads the MNIST test split over the network for this demo. In a real
    # deployment the developer must implement their own data loading.
    def build_test_dataloader(self) -> DataLoader:
        """Return an unshuffled DataLoader over the MNIST test split."""
        return self._build_mnist_dataloader(train=False, shuffle=False)

    def train_an_epoch(self) -> None:
        """Run one epoch of local training over ``self.train_loader``.

        The code is the same as ordinary (non-federated) local training.
        """
        self.model.train()
        for data, labels in self.train_loader:
            data, labels = data.to(self.device), labels.to(self.device)
            self.optimizer.zero_grad()
            output = self.model(data)
            # NOTE(review): nll_loss assumes ConvNet returns log-probabilities
            # (e.g. a log_softmax output) -- confirm in cnn_net.
            loss = F.nll_loss(output, labels)
            loss.backward()
            self.optimizer.step()

    def run_test(self) -> None:
        """Evaluate the model on ``self.test_loader`` and log to TensorBoard.

        At step ``self.current_round`` (the current training round number),
        this logs:

        - the evaluation wall time,
        - the average per-sample NLL loss,
        - the accuracy as a fraction in [0, 1],
        - the accuracy as a percentage.

        Raises:
            ValueError: if the test dataset is empty.
        """
        # Monotonic clock: elapsed time is unaffected by system clock changes.
        from time import perf_counter

        self.model.eval()
        start = perf_counter()
        test_loss = 0.0
        correct = 0
        with torch.no_grad():
            for data, labels in self.test_loader:
                data, labels = data.to(self.device), labels.to(self.device)
                output: torch.Tensor = self.model(data)
                # Sum per batch; dividing by the dataset size below yields an
                # exact per-sample mean regardless of a short last batch.
                test_loss += F.nll_loss(output, labels, reduction='sum').item()
                pred = output.argmax(dim=1)
                correct += pred.eq(labels).sum().item()
        elapsed = perf_counter() - start

        num_samples = len(self.test_loader.dataset)
        if num_samples == 0:
            raise ValueError('test dataset is empty; cannot compute test metrics')
        test_loss /= num_samples
        accuracy = correct / num_samples
        correct_rate = 100. * accuracy

        self.tb_writer.add_scalar('timer/run_time', elapsed, self.current_round)
        self.tb_writer.add_scalar('test_results/average_loss', test_loss, self.current_round)
        self.tb_writer.add_scalar('test_results/accuracy', accuracy, self.current_round)
        self.tb_writer.add_scalar('test_results/correct_rate', correct_rate, self.current_round)

In [2]:
from alphafed import mock_context

# Simulated launch script for a participant (collaborator) node.
# Copy this cell into a separate notebook and run it there.

# Task and node identifiers; every one must match the aggregator's config.
task_id = 'cbb3ffd0-838c-41ca-a41a-7c11cae29181'
aggregator_id = '1bb9feba-7b53-455b-b127-0eb19ffc9d3f'
col_id_1 = '663ad4b0-b617-409f-8bc9-3682b30f7f30'
col_id_2 = '0fc1a571-2920-47bf-9e4e-b4edb7fa2caa'

# The scheduler must be constructed exactly as it is on the aggregator side.
scheduler = DemoAvg(calculation_timeout=1800, max_rounds=3, log_rounds=1)

# Run this node (col_id_1) inside the mock debugging context, which is told
# about all three nodes of the task.
with mock_context(id=col_id_1,
                  nodes=[aggregator_id, col_id_1, col_id_2]):
    scheduler._run(id=col_id_1, task_id=task_id)

2023-02-03 08:35:02,040|DEBUG|scheduler|_switch_status|125:
self.status='init'
2023-02-03 08:35:02,063|INFO|scheduler|push_log|118:
Begin to validate local context.
2023-02-03 08:35:02,064|INFO|scheduler|push_log|118:
Local context is ready.
2023-02-03 08:35:02,065|INFO|scheduler|push_log|118:
Node 663ad4b0-b617-409f-8bc9-3682b30f7f30 is up.
2023-02-03 08:35:02,066|DEBUG|scheduler|_switch_status|125:
self.status='gethering'
2023-02-03 08:35:02,066|INFO|scheduler|push_log|118:
Checking in the task ...
2023-02-03 08:35:02,077|DEBUG|fed_avg|_wait_for_check_in_response|479:
_wait_for_check_in_response ...
2023-02-03 08:35:03,104|INFO|scheduler|push_log|118:
Node 663ad4b0-b617-409f-8bc9-3682b30f7f30 have taken part in the task.
2023-02-03 08:35:03,105|DEBUG|scheduler|_switch_status|125:
self.status='ready'
2023-02-03 08:35:03,105|DEBUG|scheduler|_switch_status|125:
self.status='synchronizing'
2023-02-03 08:35:03,106|INFO|scheduler|push_log|118:
Synchronizing round state ...
2023-02-03 08:35