# In [44]:
import secrets
from dataclasses import dataclass
from typing import List

from alphafed.contractor.common import ContractEvent
from alphafed.contractor.task_message_contractor import (
    ApplySharedFileSendingDataEvent, TaskMessageContractor,
    TaskMessageEventFactory)

# In [45]:
@dataclass
class CheckinEvent(ContractEvent):
    """An event of checkin for a specific task.

    Fields:
        peer_id:
            ID of the node that is checking in.
        nonce:
            Token generated by the sender; the matching response carries the
            same value so that the sender can recognize its own reply.
    """

    TYPE = 'checkin'

    peer_id: str
    nonce: str

    @classmethod
    def contract_to_event(cls, contract: dict) -> 'CheckinEvent':
        """Build a ``CheckinEvent`` from its contract dictionary.

        Args:
            contract:
                Dictionary read through the keys ``type``, ``peer_id`` and
                ``nonce``.

        Return:
            The event described by ``contract``.

        Raises:
            AssertionError:
                If ``type`` differs from ``TYPE``, or if ``peer_id`` or
                ``nonce`` is not a non-empty string.
        """
        event_type = contract.get('type')
        peer_id = contract.get('peer_id')
        nonce = contract.get('nonce')
        # The message text means "wrong contract type".
        assert event_type == cls.TYPE, f'合约类型错误: {event_type}'
        assert peer_id and isinstance(peer_id, str), f'invalid peer_id: {peer_id}'
        # Both conditions must hold: with ``or`` any truthy non-string value
        # (e.g. an int) would be accepted as a nonce.
        assert nonce and isinstance(nonce, str), f'invalid nonce: {nonce}'
        return CheckinEvent(peer_id=peer_id, nonce=nonce)


@dataclass
class CheckinResponseEvent(ContractEvent):
    """Reply to a checkin event.

    Fields:
        round:
            Round number reported by the responder (a non-negative int).
        aggregator:
            ID of the aggregator node.
        nonce:
            The nonce copied from the checkin event being answered.
    """

    TYPE = 'checkin_response'

    round: int
    aggregator: str
    nonce: str

    @classmethod
    def contract_to_event(cls, contract: dict) -> 'CheckinResponseEvent':
        """Build a ``CheckinResponseEvent`` from its contract dictionary.

        Raises:
            AssertionError:
                If ``type`` differs from ``TYPE``, ``round`` is not an int
                ``>= 0``, or ``aggregator`` / ``nonce`` is not a non-empty
                string.
        """
        kind = contract.get('type')
        round_no = contract.get('round')
        aggregator_id = contract.get('aggregator')
        token = contract.get('nonce')

        # The first message text means "wrong contract type".
        assert kind == cls.TYPE, f'合约类型错误: {kind}'
        assert isinstance(round_no, int) and round_no >= 0, f'invalid round: {round_no}'
        assert aggregator_id and isinstance(aggregator_id, str), (
            f'invalid aggregator: {aggregator_id}'
        )
        assert token and isinstance(token, str), f'invalid nonce: {token}'

        return CheckinResponseEvent(round=round_no,
                                    aggregator=aggregator_id,
                                    nonce=token)


@dataclass
class StartRoundEvent(ContractEvent):
    """Announcement that a new round of training starts.

    Fields:
        round:
            Number of the round being started (a positive int).
        calculators:
            IDs of the nodes taking part in the round.
        aggregator:
            ID of the aggregator node.
    """

    TYPE = 'start_round'

    round: int
    calculators: List[str]
    aggregator: str

    @classmethod
    def contract_to_event(cls, contract: dict) -> 'StartRoundEvent':
        """Build a ``StartRoundEvent`` from its contract dictionary.

        Raises:
            AssertionError:
                If ``type`` differs from ``TYPE``, ``round`` is not an int
                ``> 0``, ``calculators`` is not a non-empty list of non-empty
                strings, or ``aggregator`` is not a non-empty string.
        """
        kind = contract.get('type')
        round_no = contract.get('round')
        peer_ids = contract.get('calculators')
        aggregator_id = contract.get('aggregator')

        # The first message text means "wrong contract type".
        assert kind == cls.TYPE, f'合约类型错误: {kind}'
        assert isinstance(round_no, int) and round_no > 0, f'invalid round: {round_no}'
        assert (
            peer_ids
            and isinstance(peer_ids, list)
            and all(_peer and isinstance(_peer, str) for _peer in peer_ids)
        ), f'invalid participants: {peer_ids}'
        assert aggregator_id and isinstance(aggregator_id, str), (
            f'invalid aggregator: {aggregator_id}'
        )

        return StartRoundEvent(round=round_no,
                               calculators=peer_ids,
                               aggregator=aggregator_id)


@dataclass
class ReadyForAggregationEvent(ContractEvent):
    """Notification that the aggregator is ready for aggregation.

    Fields:
        round:
            Number of the round concerned (a positive int).
    """

    TYPE = 'ready_for_aggregation'

    round: int

    @classmethod
    def contract_to_event(cls, contract: dict) -> 'ReadyForAggregationEvent':
        """Build a ``ReadyForAggregationEvent`` from its contract dictionary.

        Raises:
            AssertionError:
                If ``type`` differs from ``TYPE`` or ``round`` is not an int
                ``> 0``.
        """
        kind = contract.get('type')
        round_no = contract.get('round')

        # The first message text means "wrong contract type".
        assert kind == cls.TYPE, f'合约类型错误: {kind}'
        assert isinstance(round_no, int) and round_no > 0, f'invalid round: {round_no}'

        return ReadyForAggregationEvent(round=round_no)


@dataclass
class CloseRoundEvent(ContractEvent):
    """Notification that a specific round of training is closed.

    Fields:
        round:
            Number of the round being closed (a positive int).
    """

    TYPE = 'close_round'

    round: int

    @classmethod
    def contract_to_event(cls, contract: dict) -> 'CloseRoundEvent':
        """Build a ``CloseRoundEvent`` from its contract dictionary.

        Raises:
            AssertionError:
                If ``type`` differs from ``TYPE`` or ``round`` is not an int
                ``> 0``.
        """
        kind = contract.get('type')
        round_no = contract.get('round')

        # The first message text means "wrong contract type".
        assert kind == cls.TYPE, f'合约类型错误: {kind}'
        assert isinstance(round_no, int) and round_no > 0, f'invalid round: {round_no}'

        return CloseRoundEvent(round=round_no)


# Both names are plain aliases of ApplySharedFileSendingDataEvent: no new class
# is created, so the two aliases refer to the very same event type.
# NOTE(review): presumably they exist to name the two transfer directions
# (results upload / parameter distribution) — confirm with the authors.
UploadTrainingResultsEvent = ApplySharedFileSendingDataEvent
DistributeParametersEvent = ApplySharedFileSendingDataEvent


class SimpleFedAvgEventFactory(TaskMessageEventFactory):
    """Event factory whose class map also contains the FedAvg events."""

    # Events declared for this scheduler, keyed by their ``TYPE`` tag.
    _CLASS_MAP = {
        _event_cls.TYPE: _event_cls
        for _event_cls in (
            CheckinEvent,
            CheckinResponseEvent,
            StartRoundEvent,
            ReadyForAggregationEvent,
            CloseRoundEvent,
        )
    }
    # The base factory's entries are merged last, so on a key clash the base
    # entry is the one that is kept.
    _CLASS_MAP.update(TaskMessageEventFactory._CLASS_MAP)


class SimpleFedAvgContractor(TaskMessageContractor):
    """Contractor that publishes the FedAvg coordination events of a task."""

    def __init__(self, task_id: str):
        """Bind the contractor to ``task_id`` and install the FedAvg event factory."""
        super().__init__(task_id=task_id)
        self._event_factory = SimpleFedAvgEventFactory

    def checkin(self, peer_id: str) -> str:
        """Broadcast a checkin event on behalf of ``peer_id``.

        :return
            The freshly generated nonce (32 hex characters) carried by the
            event; the matching response is expected to carry the same value.
        """
        token = secrets.token_hex(16)
        self._new_contract(targets=self.EVERYONE,
                           event=CheckinEvent(peer_id=peer_id, nonce=token))
        return token

    def respond_check_in(self,
                         round: int,
                         aggregator: str,
                         nonce: str,
                         requester_id: str):
        """Answer a checkin event; the reply is addressed to ``requester_id`` only."""
        reply = CheckinResponseEvent(round=round, aggregator=aggregator, nonce=nonce)
        self._new_contract(targets=[requester_id], event=reply)

    def start_round(self,
                    calculators: List[str],
                    round: int,
                    aggregator: str):
        """Broadcast the event that starts round ``round`` of training."""
        self._new_contract(
            targets=self.EVERYONE,
            event=StartRoundEvent(round=round,
                                  calculators=calculators,
                                  aggregator=aggregator),
        )

    def notify_ready_for_aggregation(self, round: int):
        """Broadcast that the aggregator is ready for aggregation in ``round``."""
        self._new_contract(targets=self.EVERYONE,
                           event=ReadyForAggregationEvent(round=round))

    def close_round(self, round: int):
        """Broadcast the event that closes round ``round`` of training."""
        self._new_contract(targets=self.EVERYONE,
                           event=CloseRoundEvent(round=round))

# In [46]:
import io
import os
import sys
import traceback
from abc import ABCMeta, abstractmethod
from typing import Dict, Tuple
from zipfile import ZipFile

import torch
from alphafed import get_result_dir, logger
from alphafed.data_channel.shared_file_data_channel import \
    SharedFileDataChannel
from alphafed.scheduler import Scheduler
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# In [47]:
class SimpleFedAvgScheduler(Scheduler, metaclass=ABCMeta):
    """A simple FedAvg implementation as an example of customized scheduler.

    The initiator of the task acts as the aggregator: it gathers the
    participants, and in every round distributes its parameters, collects the
    locally trained parameters of every participant and averages them.
    Every other node acts as a data owner: it receives the parameters, trains
    one epoch and sends the result back.

    Subclasses provide the model, the optimizer, the data loaders and the
    training / testing steps through the abstract methods below.
    """

    # Names of the states handed to ``_switch_status`` while the task runs.
    _INIT = 'init'
    _GETHORING = 'gethering'
    _READY = 'ready'
    _IN_A_ROUND = 'in_a_round'
    _UPDATING = 'updating'
    _CALCULATING = 'calculating'
    _WAIT_FOR_AGGR = 'wait_4_aggr'
    _AGGREGATING = 'aggregating'
    _PERSISTING = 'persisting'
    _CLOSING_ROUND = 'closing_round'
    _FINISHING = 'finishing'

    def __init__(self, clients: int, rounds: int):
        """Init.

        Args:
            clients:
                The number of calculators.
            rounds:
                The number of training rounds.
        """
        super().__init__()
        self.state = self._INIT

        self.clients = clients
        self.rounds = rounds

        # IDs of the nodes that have checked in (filled by the initiator).
        self._participants = []

    @abstractmethod
    def build_model(self) -> Module:
        """Return a model object which will be used for training."""

    @property
    def model(self) -> Module:
        """Get the model object which is used for training.

        The model is built by ``build_model`` on first access and then cached.
        """
        if not hasattr(self, '_model'):
            self._model = self.build_model()
        return self._model

    @abstractmethod
    def build_optimizer(self, model: Module) -> Optimizer:
        """Return a optimizer object which will be used for training.

        Args:
            model:
                The model object which is used for training.
        """

    @property
    def optimizer(self) -> Optimizer:
        """Get the optimizer object which is used for training.

        The optimizer is built by ``build_optimizer`` on first access and
        then cached.
        """
        if not hasattr(self, '_optimizer'):
            self._optimizer = self.build_optimizer(model=self.model)
        return self._optimizer

    @abstractmethod
    def build_train_dataloader(self) -> DataLoader:
        """Define the training dataloader.

        You can transform the dataset, do some preprocess to the dataset.

        Return:
            training dataloader
        """

    @property
    def train_loader(self) -> DataLoader:
        """Get the training dataloader object (built on first access, then cached)."""
        if not hasattr(self, '_train_loader'):
            self._train_loader = self.build_train_dataloader()
        return self._train_loader

    @abstractmethod
    def build_test_dataloader(self) -> DataLoader:
        """Define the testing dataloader.

        You can transform the dataset, do some preprocess to the dataset. If you do not
        want to do testing after training, simply make it return None.

        Return:
            testing dataloader
        """

    @property
    def test_loader(self) -> DataLoader:
        """Get the testing dataloader object (built on first access, then cached)."""
        if not hasattr(self, '_test_loader'):
            self._test_loader = self.build_test_dataloader()
        return self._test_loader

    @abstractmethod
    def state_dict(self) -> Dict[str, Tensor]:
        """Get the params that need to train and update.

        Only the params returned by this function will be updated and saved during aggregation.

        Return:
            Dict[str, Tensor], the model params keyed by name.
        """

    @abstractmethod
    def load_state_dict(self, state_dict: Dict[str, Tensor]):
        """Load the params that trained and updated.

        Only the params returned by state_dict() should be loaded by this function.
        The tensors handed in may live on a different device than the model,
        so implementations should copy them rather than adopt them as they are.
        """

    @abstractmethod
    def train_an_epoch(self):
        """Define the training steps in an epoch."""

    @abstractmethod
    def test(self):
        """Define the testing steps.

        If you do not want to do testing after training, simply make it pass.
        """

    def _setup_context(self, id: str, task_id: str, is_initiator: bool = False):
        """Initialize the per-task state of this node.

        Args:
            id:
                The unique ID of this node.
            task_id:
                The ID of the task being run.
            is_initiator:
                Whether this node is the initiator (and hence the aggregator).
        """
        assert id, 'must specify a unique id for every participant'
        assert task_id, 'must specify a task_id for every participant'

        self.id = id
        self.task_id = task_id
        self._result_dir = get_result_dir(self.task_id)
        self._log_dir = os.path.join(self._result_dir, 'tb_logs')
        self.tb_writer = SummaryWriter(log_dir=self._log_dir)

        self.is_initiator = is_initiator

        self.contractor = SimpleFedAvgContractor(task_id=task_id)
        self.data_channel = SharedFileDataChannel(self.contractor)
        # Touch the lazy properties so the model and the optimizer are built now.
        self.model
        self.optimizer
        self.round = 0

    def _run(self, id: str, task_id: str, is_initiator: bool = False, recover: bool = False):
        """Set up the local context and run the whole task.

        Any exception raised while running is pushed to the log and not
        re-raised.

        Args:
            id:
                The unique ID of this node.
            task_id:
                The ID of the task being run.
            is_initiator:
                Whether this node is the initiator of the task.
            recover:
                If true and this node is the initiator, try to resume from the
                previous run instead of cleaning its progress data.
        """
        self._setup_context(id=id, task_id=task_id, is_initiator=is_initiator)
        self.push_log(message='Local context is ready.')
        try:
            if self.is_initiator and recover:
                self._recover_progress()
            else:
                self._clean_progress()
            self._launch_process()
        except Exception:
            # Push the error to the Playground front end; it helps to find out
            # the cause of the failure and fix it. The lines produced by the
            # traceback module already end with a newline, so they are
            # concatenated as they are.
            self.push_log(traceback.format_exc())

    def _recover_progress(self):
        """Try to recover and continue from last running."""
        # If the previous run failed for some occasional reason and the cause
        # has been removed, the logic that resumes from the point of failure
        # (instead of starting over) can be provided here.
        pass

    def _clean_progress(self):
        """Clean existing progress data."""
        # If an earlier run left data behind that could disturb a fresh run,
        # the logic that cleans the environment can be provided here.
        pass

    def _launch_process(self):
        """Check in, run ``self.rounds`` rounds and, as initiator, close the task."""
        self.push_log(f'Node {self.id} is up.')

        self._switch_status(self._GETHORING)
        self._check_in()

        self._switch_status(self._READY)
        self.round = 1

        for _ in range(self.rounds):
            self._switch_status(self._IN_A_ROUND)
            self._run_a_round()
            self._switch_status(self._READY)
            self.round += 1

        if self.is_initiator:
            self.push_log(f'Obtained the final results of task {self.task_id}')
            self._switch_status(self._FINISHING)
            self.test()
            self._close_task()

    def _check_in(self):
        """Check in task and get ready.

        As an initiator (and default the aggregator), records each participants
        and launches training process.
        As a participant, checkins and gets ready for training.
        """
        if self.is_initiator:
            self.push_log('Waiting for participants taking part in ...')
            self._wait_for_gathering()
        else:
            is_checked_in = False
            # the aggregator may be in special state so can not response
            # correctly nor in time, then retry periodically
            self.push_log('Checking in the task ...')
            while not is_checked_in:
                is_checked_in = self._check_in_task()
            self.push_log(f'Node {self.id} have taken part in the task.')

    def _wait_for_gathering(self):
        """Wait until ``self.clients`` distinct participants have checked in.

        Every checkin event is answered, including repeated ones from a
        participant that is already registered.
        """
        logger.debug('_wait_for_gathering ...')
        for _event in self.contractor.contract_events():
            if isinstance(_event, CheckinEvent):
                if _event.peer_id not in self._participants:
                    self._participants.append(_event.peer_id)
                    self.push_log(f'Welcome a new participant ID: {_event.peer_id}.')
                    self.push_log(f'There are {len(self._participants)} participants now.')
                self.contractor.respond_check_in(round=self.round,
                                                 aggregator=self.id,
                                                 nonce=_event.nonce,
                                                 requester_id=_event.peer_id)
                # ``>=`` rather than ``==`` so that the loop also terminates
                # when ``clients`` is zero or negative.
                if len(self._participants) >= self.clients:
                    break
        self.push_log('All participants gethered.')

    def _check_in_task(self) -> bool:
        """Try to check in the task.

        Return:
            True if a response carrying the nonce of this attempt arrived,
            False if the event stream ended (timeout of 30 is passed to the
            contractor) without one.
        """
        nonce = self.contractor.checkin(peer_id=self.id)
        logger.debug('_wait_for_check_in_response ...')
        for _event in self.contractor.contract_events(timeout=30):
            if isinstance(_event, CheckinResponseEvent):
                # Responses to other checkin attempts are ignored.
                if _event.nonce != nonce:
                    continue
                self.round = _event.round
                self._aggregator = _event.aggregator
                return True
        return False

    def _run_a_round(self):
        """Perform a round of FedAvg calculation.

        As an aggregator, starts the round, distributes latest parameters to
        all participants, collects their updates and makes aggregation.
        As a participant, receives the parameters, calculates and uploads
        parameter update.
        """
        if self.is_initiator:
            self._run_as_aggregator()
        else:
            self._run_as_data_owner()

    def _run_as_aggregator(self):
        """Run the aggregator side of one round."""
        self._start_round()
        self._distribute_model()
        self._process_aggregation()
        self._close_round()

    def _start_round(self):
        """Prepare and start calculation of a round."""
        self.push_log(f'Begin the training of round {self.round}.')
        self.contractor.start_round(round=self.round,
                                    calculators=self._participants,
                                    aggregator=self.id)
        self.push_log(f'Calculation of round {self.round} is started.')

    def _distribute_model(self):
        """Send the current parameters to every participant.

        Raises:
            RuntimeError:
                If the number of accepting participants differs from the
                number of registered participants.
        """
        buffer = io.BytesIO()
        torch.save(self.state_dict(), buffer)
        self.push_log('Distributing parameters ...')
        accept_list = self.data_channel.send_stream(source=self.id,
                                                    target=self._participants,
                                                    data_stream=buffer.getvalue())
        self.push_log(f'Successfully distributed parameters to: {accept_list}')
        if len(self._participants) != len(accept_list):
            reject_list = [_target for _target in self._participants
                           if _target not in accept_list]
            self.push_log(f'Failed to distribute parameters to: {reject_list}')
            raise RuntimeError('Failed to distribute parameters to some participants.')
        self.push_log('Distributed parameters to all participants.')

    def _process_aggregation(self):
        """Collect the training results and load their average.

        Raises:
            RuntimeError:
                If fewer than ``self.clients`` results were received.
        """
        self._switch_status(self._WAIT_FOR_AGGR)
        self.contractor.notify_ready_for_aggregation(round=self.round)
        self.push_log('Now waiting for executing calculation ...')
        accum_result, result_count = self._wait_for_calculation()
        if result_count < self.clients:
            self.push_log('Task failed because some calculation results lost.')
            raise RuntimeError('Task failed because some calculation results lost.')
        self.push_log(f'Received {result_count} copies of calculation results.')

        self._switch_status(self._AGGREGATING)
        self.push_log('Begin to aggregate and update parameters.')
        for _key, _accum in accum_result.items():
            if _accum.dtype in (
                torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64
            ):
                # Integer tensors can not hold a fractional mean, so the
                # quotient is truncated.
                logger.warning(f'average a int value may lose precision: {_key=}')
                _accum.div_(result_count, rounding_mode='trunc')
            else:
                _accum.div_(result_count)
        self.load_state_dict(accum_result)
        self.push_log('Obtained a new version of parameters.')

    @staticmethod
    def _load_state_stream(stream: bytes) -> Dict[str, Tensor]:
        """Deserialize a state dict received through the data channel.

        The tensors are mapped onto the CPU, so that a stream saved by a node
        with a GPU can also be read by a node without one.

        Args:
            stream:
                The bytes produced by ``torch.save`` on the sending node.
        """
        # NOTE(security): torch.load unpickles the stream, which can execute
        # arbitrary code. Only accept streams from trusted participants.
        return torch.load(io.BytesIO(stream), map_location='cpu')

    def _wait_for_calculation(self) -> Tuple[Dict[str, torch.Tensor], int]:
        """Receive one training result per participant and sum them up.

        Return:
            The element-wise sum of the received parameters and the number of
            results received.
        """
        result_count = 0
        accum_result = self.state_dict()
        # The tensors returned by state_dict() are zeroed in place and then
        # used as accumulators.
        for _param in accum_result.values():
            if isinstance(_param, torch.Tensor):
                _param.zero_()

        self.push_log('Waiting for training results ...')
        while result_count < len(self._participants):
            source, training_result = self.data_channel.receive_stream(receiver=self.id)
            _new_state_dict = self._load_state_stream(training_result)
            for _key, _accum in accum_result.items():
                # The received tensor is moved to the accumulator's device
                # first; in-place addition across devices raises an error.
                _accum.add_(_new_state_dict[_key].to(_accum.device))
            result_count += 1
            self.push_log(f'Received calculation results from ID: {source}')
        return accum_result, result_count

    def _close_round(self):
        """Close current round when finished."""
        self._switch_status(self._CLOSING_ROUND)
        self.contractor.close_round(round=self.round)
        self.push_log(f'The training of Round {self.round} complete.')

    def _run_as_data_owner(self):
        """Run the data owner side of one round.

        Waits for the round to start, loads the parameters sent by the
        aggregator, trains one epoch, uploads the result once the aggregator
        is ready and finally waits for the round to be closed.
        """
        self._wait_for_starting_round()
        self._switch_status(self._UPDATING)
        self._wait_for_updating_model()

        self._switch_status(self._CALCULATING)
        self.push_log('Begin to run calculation ...')
        self.train_an_epoch()
        self.push_log('Local calculation complete.')

        self._wait_for_uploading_model()
        buffer = io.BytesIO()
        torch.save(self.state_dict(), buffer)
        self.push_log('Pushing local update to the aggregator ...')
        self.data_channel.send_stream(source=self.id,
                                      target=[self._aggregator],
                                      data_stream=buffer.getvalue())
        self.push_log('Successfully pushed local update to the aggregator.')
        self._switch_status(self._CLOSING_ROUND)
        self._wait_for_closing_round()

        self.push_log(f'ID: {self.id} finished training task of round {self.round}.')

    def _wait_for_starting_round(self):
        """Wait for starting a new round of training.

        Returns on the first ``StartRoundEvent``; its content is not inspected.
        """
        self.push_log(f'Waiting for training of round {self.round} begin ...')
        for _event in self.contractor.contract_events():
            if isinstance(_event, StartRoundEvent):
                self.push_log(f'Training of round {self.round} begins.')
                return

    def _wait_for_updating_model(self):
        """Wait for receiving latest parameters from aggregator and load them."""
        self.push_log('Waiting for receiving latest parameters from the aggregator ...')
        _, parameters = self.data_channel.receive_stream(receiver=self.id)
        new_state_dict = self._load_state_stream(parameters)
        self.load_state_dict(new_state_dict)
        self.push_log('Successfully received latest parameters.')
        return

    def _wait_for_uploading_model(self):
        """Wait until the aggregator announces it is ready for aggregation."""
        self.push_log('Waiting for aggregation begin ...')
        for _event in self.contractor.contract_events():
            if isinstance(_event, ReadyForAggregationEvent):
                return

    def _wait_for_closing_round(self):
        """Wait for closing current round of training."""
        self.push_log(f'Waiting for closing signal of training round {self.round} ...')
        for _event in self.contractor.contract_events():
            if isinstance(_event, CloseRoundEvent):
                return

    def _close_task(self, is_succ: bool = True):
        """Close the FedAvg calculation.

        As an initiator, prepares the report and model files, uploads them and
        notifies the completion of the task.
        As a participant, do nothing besides logging.

        Args:
            is_succ:
                The result reported with the completion notification.
        """
        self.push_log(f'Closing task {self.task_id} ...')
        if self.is_initiator:
            self._switch_status(self._FINISHING)
            report_file_path, model_file_path = self._prepare_task_output()
            self.contractor.upload_metric_report(receivers=self.contractor.EVERYONE,
                                                 report_file=report_file_path)
            self.contractor.upload_model(receivers=self.contractor.EVERYONE,
                                         model_file=model_file_path)
            # Report the outcome given by the caller rather than a constant.
            self.contractor.notify_task_completion(result=is_succ)
        self.push_log(f'Task {self.task_id} closed. Byebye!')

    def _prepare_task_output(self) -> Tuple[str, str]:
        """Generate final output files of the task.

        Writes ``report.zip`` (the content of the tensorboard log directory)
        and ``model.pt`` (the current state dict) into the result directory.

        Return:
            Local paths of the report file and model file.
        """
        self.push_log('Uploading task achievement and closing task ...')

        os.makedirs(self._result_dir, exist_ok=True)

        report_file = os.path.join(self._result_dir, "report.zip")
        with ZipFile(report_file, 'w') as report_zip:
            for path, _, filenames in os.walk(self._log_dir):
                # Archive names are made relative to the result directory.
                rel_dir = os.path.relpath(path=path, start=self._result_dir)
                rel_dir = rel_dir.lstrip('.')  # ./file => file
                for _file in filenames:
                    rel_path = os.path.join(rel_dir, _file)
                    report_zip.write(os.path.join(path, _file), rel_path)
        report_file_path = os.path.abspath(report_file)

        model_file = os.path.join(self._result_dir, "model.pt")
        with open(model_file, 'wb') as f:
            torch.save(self.state_dict(), f)
        model_file_path = os.path.abspath(model_file)

        self.push_log('Task achievement files are ready.')
        return report_file_path, model_file_path


# In [48]:
import os
from time import time
from typing import Dict

import torch
import torch.nn.functional as F
from alphafed import get_dataset_dir, logger
from torch.nn import Conv2d, Dropout2d, Linear, Module
from torch.optim import SGD, Optimizer
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# In [49]:
class ConvNet(Module):
    """Small convolutional classifier: two conv layers followed by two linear layers.

    The forward pass returns log-probabilities over 10 classes. The flatten
    step hard-codes 320 features per sample, which is what a 1x28x28 input
    yields (20 channels of 4x4 after the second pooling).
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = Conv2d(1, 10, kernel_size=5)
        self.conv2 = Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = Dropout2d()
        self.fc1 = Linear(320, 50)
        self.fc2 = Linear(50, 10)

    def forward(self, x):
        """Return the log-softmax scores of the input batch ``x``."""
        # First stage: convolution -> 2x2 max pooling -> ReLU.
        stage1 = F.relu(F.max_pool2d(self.conv1(x), 2))

        # Second stage: convolution -> channel dropout -> 2x2 max pooling -> ReLU.
        dropped = self.conv2_drop(self.conv2(stage1))
        stage2 = F.relu(F.max_pool2d(dropped, 2))

        # Classifier head on the flattened feature maps; this dropout is only
        # active while the module is in training mode.
        flat = stage2.view(-1, 320)
        hidden = F.relu(self.fc1(flat))
        hidden = F.dropout(hidden, training=self.training)
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=-1)

# In [50]:
class SimpleTaskScheduler(SimpleFedAvgScheduler):
    """FedAvg scheduler that trains ``ConvNet`` on the torchvision MNIST dataset."""

    def __init__(self,
                 clients: int,
                 rounds: int,
                 batch_size: int,
                 learning_rate: float,
                 momentum: float) -> None:
        """Init.

        Args:
            clients:
                The number of calculators.
            rounds:
                The number of training rounds.
            batch_size:
                Batch size of both the training and the testing dataloader.
            learning_rate:
                Learning rate of the SGD optimizer.
            momentum:
                Momentum of the SGD optimizer.
        """
        super().__init__(clients=clients, rounds=rounds)
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.momentum = momentum

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Fixed seed for torch's random number generator.
        self.seed = 42
        torch.manual_seed(self.seed)

    def build_model(self) -> Module:
        """Return a new ``ConvNet`` placed on ``self.device``."""
        model = ConvNet()
        return model.to(self.device)

    def build_optimizer(self, model: Module) -> Optimizer:
        """Return an SGD optimizer over the parameters of ``model``.

        Args:
            model:
                The model whose parameters are optimized. ``None`` falls back
                to ``self.model``.
        """
        # Optimize the model that was handed in instead of silently ignoring
        # the argument.
        target = self.model if model is None else model
        return SGD(target.parameters(),
                   lr=self.learning_rate,
                   momentum=self.momentum)

    def _mnist_loader(self, train: bool, shuffle: bool) -> DataLoader:
        """Build a dataloader over the MNIST split selected by ``train``."""
        dataset = datasets.MNIST(
            get_dataset_dir(self.task_id),
            train=train,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])
        )
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle)

    def build_train_dataloader(self) -> DataLoader:
        """Return the shuffled dataloader over the MNIST training split."""
        return self._mnist_loader(train=True, shuffle=True)

    def build_test_dataloader(self) -> DataLoader:
        """Return the unshuffled dataloader over the MNIST testing split."""
        return self._mnist_loader(train=False, shuffle=False)

    def state_dict(self) -> Dict[str, torch.Tensor]:
        """Return the state dict of the model."""
        return self.model.state_dict()

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
        """Load ``state_dict`` into the model."""
        self.model.load_state_dict(state_dict)

    def train_an_epoch(self) -> None:
        """Train the model for one pass over the training dataloader."""
        self.model.train()
        for data, labels in self.train_loader:
            data: torch.Tensor
            labels: torch.Tensor
            data, labels = data.to(self.device), labels.to(self.device)
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = F.nll_loss(output, labels)
            loss.backward()
            self.optimizer.step()

    def test(self):
        """Evaluate the model on the testing dataloader.

        The average loss and the accuracy are logged, and written to
        tensorboard together with the run time, using ``self.round`` as the
        step.
        """
        start = time()
        self.model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, labels in self.test_loader:
                data, labels = data.to(self.device), labels.to(self.device)
                output: torch.Tensor = self.model(data)
                # Losses are summed here and averaged over the dataset below.
                test_loss += F.nll_loss(output, labels, reduction='sum').item()
                # Index of the highest score is the predicted class.
                pred = output.max(1, keepdim=True)[1]
                correct += pred.eq(labels.view_as(pred)).sum().item()

        test_loss /= len(self.test_loader.dataset)
        accuracy = correct / len(self.test_loader.dataset)
        correct_rate = 100. * accuracy
        logger.info(f'Test set: Average loss: {test_loss:.4f}')
        logger.info(
            f'Test set: Accuracy: {accuracy} ({correct_rate:.2f}%)'
        )

        end = time()

        self.tb_writer.add_scalar('timer/run_time', end - start, self.round)
        self.tb_writer.add_scalar('test_results/average_loss', test_loss, self.round)
        self.tb_writer.add_scalar('test_results/accuracy', accuracy, self.round)
        self.tb_writer.add_scalar('test_results/correct_rate', correct_rate, self.round)

# In [51]:
# Configure a FedAvg run with 2 calculators and 5 training rounds, using
# SGD (lr 0.01, momentum 0.9) on batches of 128 samples.
scheduler = SimpleTaskScheduler(
    clients=2,
    rounds=5,
    batch_size=128,
    learning_rate=0.01,
    momentum=0.9,
)

# Hand the scheduler over to the task identified by the ID below.
scheduler.launch_task(task_id='943d472cc5a74d17a6a01d0e9a8f4707')