In [1]:
import torch
import torch.nn.functional as F
import torchvision
from torch.nn import Module
from torch.optim import SGD
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from cnn_net import ConvNet
from simple_fed_avg.scheduler import SimpleFedAvgScheduler


class MyScheduler(SimpleFedAvgScheduler):
    """FedAvg scheduler that trains a ``ConvNet`` on MNIST.

    The aggregator only loads the MNIST test split and uses it in ``run_test``;
    collaborators only load the MNIST training split and use it in
    ``train_an_epoch``.
    """

    # Per-channel normalization statistics applied to MNIST images.
    MNIST_MEAN = (0.1307,)
    MNIST_STD = (0.3081,)

    def __init__(self,
                 rounds: int,
                 batch_size: int = 64,
                 lr: float = 0.01,
                 momentum: float = 0.9,
                 data_root: str = 'data') -> None:
        """Create the scheduler.

        Args:
            rounds: number of training rounds, forwarded to the base class.
            batch_size: mini-batch size used by both the train and test loaders.
            lr: SGD learning rate.
            momentum: SGD momentum factor.
            data_root: directory MNIST is read from (downloaded into if missing).
        """
        super().__init__(rounds)
        self._batch_size = batch_size
        self._lr = lr
        self._momentum = momentum
        self._data_root = data_root

    def _build_mnist_loader(self, train: bool) -> DataLoader:
        """Build a DataLoader over the MNIST train split (shuffled) or test split (in order)."""
        dataset = torchvision.datasets.MNIST(
            self._data_root,
            train=train,
            download=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(self.MNIST_MEAN, self.MNIST_STD),
            ]),
        )
        # Only the training split is shuffled; evaluation iterates in a fixed order.
        return DataLoader(dataset, batch_size=self._batch_size, shuffle=train)

    def before_check_in(self, is_aggregator: bool):
        """Initialize local data resources before checking in.

        Args:
            is_aggregator: True on the aggregator node, which only needs the test
                loader; False on collaborators, which only need the train loader.
        """
        if is_aggregator:
            self.test_loader = self._build_mnist_loader(train=False)
        else:
            self.train_loader = self._build_mnist_loader(train=True)

    def before_training(self, is_aggregator: bool):
        """Create the SGD optimizer over the model's parameters before training starts."""
        self.optimizer = SGD(params=self.model.parameters(),
                             lr=self._lr,
                             momentum=self._momentum)

    @property
    def model(self) -> Module:
        """Return the model used for training, creating a ``ConvNet`` lazily on first access."""
        if not hasattr(self, '_model'):
            self._model = ConvNet()
        return self._model

    def train_an_epoch(self):
        """Run one epoch of training over ``self.train_loader``."""
        self.model.train()
        for data, labels in self.train_loader:
            self.optimizer.zero_grad()
            output = self.model(data)
            # NOTE(review): nll_loss assumes ConvNet outputs log-probabilities
            # (e.g. log_softmax) — TODO confirm in cnn_net.
            loss = F.nll_loss(output, labels)
            loss.backward()
            self.optimizer.step()

    def run_test(self, writer: SummaryWriter):
        """Evaluate the model on ``self.test_loader`` and log the results.

        Writes the per-sample average NLL loss and the accuracy (in percent) to
        ``writer``, using ``self.round`` as the global step.
        """
        self.model.eval()
        test_loss = 0.0
        correct = 0
        with torch.no_grad():
            for data, labels in self.test_loader:
                output: torch.Tensor = self.model(data)
                # Sum per batch so that dividing by the dataset size below gives an
                # exact per-sample average even when the last batch is smaller.
                test_loss += F.nll_loss(output, labels, reduction='sum').item()
                pred = output.argmax(dim=1)
                correct += (pred == labels).sum().item()

        num_samples = len(self.test_loader.dataset)
        test_loss /= num_samples
        correct_rate = 100. * correct / num_samples
        writer.add_scalar('test_results/average_loss', test_loss, self.round)
        writer.add_scalar('test_results/correct_rate', correct_rate, self.round)

In [2]:
import logging
from alphafed import logger, mock_context

# Hide DEBUG logs by raising the level to INFO; comment this line out if you
# need the more detailed debug output.
logger.setLevel(logging.INFO)

# Mock launch script for a collaborator. Copy this cell into its own notebook
# and run it there. The scheduler must be constructed exactly as on the
# aggregator side.
scheduler = MyScheduler(rounds=5)

# All IDs below must match the ones used by the aggregator.
aggregator_id = '4d41ca74-078c-494c-bf7f-324534460d10'
col_id_1 = 'e71690e4-347c-4fda-be6f-c8d49040c692'
col_id_2 = 'fab1898f-ed8a-47de-ba17-62fcd8916189'
task_id = 'a701073d-e2bb-401b-a1f0-bf2a8e58dd2a'

# This notebook runs as the first collaborator inside a mocked context that
# knows about all three nodes.
local_node_id = col_id_1
participant_nodes = [aggregator_id, col_id_1, col_id_2]
with mock_context(id=local_node_id, nodes=participant_nodes):
    scheduler._run(id=local_node_id, task_id=task_id)

2023-02-06 07:06:00,626|INFO|scheduler|push_log|118:
节点 e71690e4-347c-4fda-be6f-c8d49040c692 初始化完毕。
2023-02-06 07:06:00,627|INFO|scheduler|push_log|118:
发送集合请求，等待聚合方响应。
2023-02-06 07:06:01,668|INFO|scheduler|push_log|118:
收到响应，集合成功。
2023-02-06 07:06:01,670|INFO|scheduler|push_log|118:
节点 e71690e4-347c-4fda-be6f-c8d49040c692 准备就绪，可以开始执行计算任务。
2023-02-06 07:06:01,671|INFO|scheduler|push_log|118:
等待训练开始信号 ...
2023-02-06 07:06:06,698|INFO|scheduler|push_log|118:
开始训练 ...
2023-02-06 07:06:06,699|INFO|scheduler|push_log|118:
等待接收全局模型 ...
2023-02-06 07:06:08,335|INFO|scheduler|push_log|118:
接收全局模型成功。
2023-02-06 07:06:08,336|INFO|scheduler|push_log|118:
开始训练本地模型 ...
2023-02-06 07:06:25,408|INFO|scheduler|push_log|118:
训练本地模型完成。
2023-02-06 07:06:25,411|INFO|scheduler|push_log|118:
准备发送本地模型 ...
2023-02-06 07:06:26,615|INFO|scheduler|push_log|118:
发送本地模型完成。
2023-02-06 07:06:26,616|INFO|scheduler|push_log|118:
等待训练开始信号 ...
2023-02-06 07:06:28,631|INFO|scheduler|push_log|118:
开始训练 ...
2023-02-06 07: