In [1]:
from typing import Optional

from p2pfl.communication.commands.message.start_learning_command import StartLearningCommand
from p2pfl.communication.protocols.communication_protocol import CommunicationProtocol
from p2pfl.communication.protocols.protobuff.grpc import GrpcCommunicationProtocol
from p2pfl.learning.aggregators.aggregator import Aggregator
from p2pfl.learning.aggregators.fedavg import FedAvg
from p2pfl.learning.aggregators.scaffold import Scaffold
from p2pfl.learning.dataset.p2pfl_dataset import P2PFLDataset
from p2pfl.learning.frameworks.learner import Learner
from p2pfl.learning.frameworks.learner_factory import LearnerFactory
from p2pfl.learning.frameworks.p2pfl_model import P2PFLModel
from p2pfl.learning.frameworks.simulation import try_init_learner_with_ray
from p2pfl.node_state import NodeState
from p2pfl.stages.workflows import LearningWorkflow


class Node:

    def __init__(
        self,
        model: P2PFLModel,
        data: P2PFLDataset,
        address: str = "127.0.0.1",
        learner: Optional[Learner] = None,
        aggregator: Optional[Aggregator] = None,
        protocol: Optional[CommunicationProtocol] = None,
        simulation: bool = False,
        **kwargs,
    ) -> None:

        self._communication_protocol = GrpcCommunicationProtocol() if protocol is None else protocol
        self.addr = self._communication_protocol.set_addr(address)

        # Aggregator
        self.aggregator = FedAvg() if aggregator is None else aggregator
        self.aggregator.set_addr(self.addr)

        # Learner
        if learner is None:  # if no learner, use factory default
            learner = LearnerFactory.create_learner(model)()
        self.learner = try_init_learner_with_ray(learner)
        self.learner.set_addr(self.addr)
        self.learner.set_model(model)
        self.learner.set_data(data)
        self.learner.indicate_aggregator(self.aggregator)

        # State
        self.__running = False
        self.state = NodeState(self.addr, simulation=simulation)
        self.simulation = simulation  # so far it does not contribute much

        # Workflow
        self.learning_workflow = LearningWorkflow()

        # Commands
        commands = [
            StartLearningCommand(self.__start_learning_thread),
        ]
        self._communication_protocol.add_command(commands)


Node(
    None,
    None,
    protocol=GrpcCommunicationProtocol(),
    address="address",
    aggregator=Scaffold()
)

  from .autonotebook import tqdm as notebook_tqdm
2025-02-25 15:58:35,426	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2025-02-25 15:58:38,749	INFO worker.py:1812 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8266 [39m[22m
2025-02-25 15:58:47.693338: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1740499127.718084    2766 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1740499127.724369    2766 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-25 15:58:47.745898: I tensorflow/core/platform/cpu_feature_guard.cc:

ValueError: The address is invalid.

In [2]:
#
# This file is part of the federated_learning_p2p (p2pfl) distribution (see https://github.com/pguijas/p2pfl).
# Copyright (c) 2022 Pedro Guijas Bravo.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#

"""GRPC communication protocol."""

import random
from abc import abstractmethod
from datetime import datetime
from functools import wraps
from typing import Any, Callable, Optional, Union

from p2pfl.communication.commands.command import Command
from p2pfl.communication.commands.message.heartbeat_command import HeartbeatCommand
from p2pfl.communication.protocols.communication_protocol import CommunicationProtocol
from p2pfl.communication.protocols.exceptions import CommunicationError, ProtocolNotStartedError
from p2pfl.communication.protocols.protobuff.client import ProtobuffClient
from p2pfl.communication.protocols.protobuff.gossiper import Gossiper
from p2pfl.communication.protocols.protobuff.heartbeater import Heartbeater
from p2pfl.communication.protocols.protobuff.neighbors import Neighbors
from p2pfl.communication.protocols.protobuff.proto import node_pb2
from p2pfl.communication.protocols.protobuff.server import ProtobuffServer
from p2pfl.settings import Settings
from p2pfl.utils.node_component import allow_no_addr_check


def running(func):
    """Ensure that the server is running before executing a method."""

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if not self._server.is_running():
            raise ProtocolNotStartedError("The protocol has not been started.")
        return func(self, *args, **kwargs)

    return wrapper


class ProtobuffCommunicationProtocol(CommunicationProtocol):
    """
    Protobuff communication protocol.

    Args:
        addr: Address of the node.
        commands: Commands to add to the communication protocol.

    .. todo:: https://grpc.github.io/grpc/python/grpc_asyncio.html
    .. todo:: Decouple the heeartbeat command.

    """

    def __init__(
        self,
        commands: Optional[list[Command]] = None,
    ) -> None:
        """Initialize the GRPC communication protocol."""
        print("1")
        # (addr) Super
        CommunicationProtocol.__init__(self)
        # Neighbors
        print("2")
        self._neighbors = Neighbors(self.bluid_client)
        # Gossip
        print("3")
        self._gossiper = Gossiper(self._neighbors)
        # GRPC
        print("4")
        self._server = self.build_server(self._gossiper, self._neighbors, commands)
        # Hearbeat
        print("5")
        self._heartbeater = Heartbeater(self._neighbors, self.build_msg)
        # Commands
        print("6")
        self.add_command(HeartbeatCommand(self._heartbeater))
        if commands is None:
            commands = []
        self.add_command(commands)

    @allow_no_addr_check
    @abstractmethod
    def bluid_client(self, *args, **kwargs) -> ProtobuffClient:
        """Build client function."""
        pass

    @allow_no_addr_check
    @abstractmethod
    def build_server(self, *args, **kwargs) -> ProtobuffServer:
        """Build server function."""
        pass

    def set_addr(self, addr: str) -> str:
        """Set the addr of the node."""
        # Delegate on server
        self._server.set_addr(addr)
        # Set on super
        raise NotImplementedError("The method set_addr must be implemented in the subclass.")
        print("PROPAGAR ESTO A COMPONENTES!!!!!!!!")
        print("PROPAGAR ESTO A COMPONENTES!!!!!!!!")
        print("PROPAGAR ESTO A COMPONENTES!!!!!!!!")
        print("PROPAGAR ESTO A COMPONENTES!!!!!!!!")
        addr = self._server
        return super().set_addr(addr)

    def start(self) -> None:
        """Start the GRPC communication protocol."""
        self._server.start()
        self._heartbeater.start()
        self._gossiper.start()

    @running
    def stop(self) -> None:
        """Stop the GRPC communication protocol."""
        self._heartbeater.stop()
        self._gossiper.stop()
        self._neighbors.clear_neighbors()
        self._server.stop()

    def add_command(self, cmds: Union[Command, list[Command]]) -> None:
        """
        Add a command to the communication protocol.

        Args:
            cmds: The command to add.

        """
        self._server.add_command(cmds)

    @running
    def connect(self, addr: str, non_direct: bool = False) -> bool:
        """
        Connect to a neighbor.

        Args:
            addr: The address to connect to.
            non_direct: The non direct flag.

        """
        return self._neighbors.add(addr, non_direct=non_direct)

    @running
    def disconnect(self, nei: str, disconnect_msg: bool = True) -> None:
        """
        Disconnect from a neighbor.

        Args:
            nei: The neighbor to disconnect from.
            disconnect_msg: The disconnect message flag.

        """
        self._neighbors.remove(nei, disconnect_msg=disconnect_msg)

    def build_msg(self, cmd: str, args: Optional[list[str]] = None, round: Optional[int] = None) -> node_pb2.RootMessage:
        """
        Build a RootMessage to send to the neighbors.

        Args:
            cmd: Command of the message.
            args: Arguments of the message.
            round: Round of the message.

        Returns:
            RootMessage to send.

        """
        if round is None:
            round = -1
        if args is None:
            args = []
        hs = hash(str(cmd) + str(args) + str(datetime.now()) + str(random.randint(0, 100000)))
        args = [str(a) for a in args]

        return node_pb2.RootMessage(
            source=self.addr,
            round=round,
            cmd=cmd,
            message=node_pb2.Message(
                ttl=Settings.gossip.TTL,
                hash=hs,
                args=args,
            ),
        )

    def build_weights(
        self,
        cmd: str,
        round: int,
        serialized_model: bytes,
        contributors: Optional[list[str]] = None,
        weight: int = 1,
    ) -> node_pb2.RootMessage:
        """
        Build a RootMessage with a Weights payload to send to the neighbors.

        Args:
            cmd: Command of the message.
            round: Round of the message.
            serialized_model: Serialized model to send.
            contributors: List of contributors.
            weight: Weight of the message (number of samples).

        Returns:
            RootMessage to send.

        """
        if contributors is None:
            contributors = []
        return node_pb2.RootMessage(
            source=self.addr,
            round=round,
            cmd=cmd,
            weights=node_pb2.Weights(
                weights=serialized_model,
                contributors=contributors,
                num_samples=weight,
            ),
        )

    @running
    def send(
        self,
        nei: str,
        msg: Union[node_pb2.RootMessage],
        raise_error: bool = False,
        remove_on_error: bool = True,
    ) -> None:
        """
        Send a message to a neighbor.

        Args:
            nei: The neighbor to send the message.
            msg: The message to send.
            raise_error: If raise error.
            remove_on_error: If remove on error.

        """
        try:
            self._neighbors.get(nei).send(msg, raise_error=raise_error, disconnect_on_error=remove_on_error)
        except CommunicationError as e:
            if remove_on_error:
                self._neighbors.remove(nei)
            if raise_error:
                raise e

    @running
    def broadcast(self, msg: node_pb2.RootMessage, node_list: Optional[list[str]] = None) -> None:
        """
        Broadcast a message to all neighbors.

        Args:
            msg: The message to broadcast.
            node_list: Optional node list.

        """
        neis = self._neighbors.get_all(only_direct=True)
        neis_clients = [nei[0] for nei in neis.values()]
        for nei in neis_clients:
            nei.send(msg)

    @running
    def get_neighbors(self, only_direct: bool = False) -> dict[str, Any]:
        """
        Get the neighbors.

        Args:
            only_direct: The only direct flag.

        """
        return self._neighbors.get_all(only_direct)

    @running
    def wait_for_termination(self) -> None:
        """
        Get the neighbors.

        Args:
            only_direct: The only direct flag.

        """
        self._server.wait_for_termination()

    @running
    def gossip_weights(
        self,
        early_stopping_fn: Callable[[], bool],
        get_candidates_fn: Callable[[], list[str]],
        status_fn: Callable[[], Any],
        model_fn: Callable[[str], Any],
        period: Optional[float] = None,
        create_connection: bool = False,
    ) -> None:
        """
        Gossip model weights.

        Args:
            early_stopping_fn: The early stopping function.
            get_candidates_fn: The get candidates function.
            status_fn: The status function.
            model_fn: The model function.
            period: The period.
            create_connection: The create connection flag.

        """
        if period is None:
            period = Settings.gossip.MODELS_PERIOD
        self._gossiper.gossip_weights(
            early_stopping_fn,
            get_candidates_fn,
            status_fn,
            model_fn,
            period,
            create_connection,
        )


In [3]:
#
# This file is part of the federated_learning_p2p (p2pfl) distribution
# (see https://github.com/pguijas/p2pfl).
# Copyright (c) 2024 Pedro Guijas Bravo.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#

"""GRPC implementation of the `CommunicationProtocol`."""

from typing import Optional

from p2pfl.communication.commands.command import Command
from p2pfl.communication.protocols.protobuff.grpc.client import GrpcClient
from p2pfl.communication.protocols.protobuff.grpc.server import GrpcServer
from p2pfl.utils.node_component import allow_no_addr_check


class GrpcCommunicationProtocol(ProtobuffCommunicationProtocol):
    """GRPC communication protocol."""

    def __init__(self, commands: Optional[list[Command]] = None) -> None:
        """Initialize the GRPC communication protocol."""
        print("a")
        super().__init__(commands)
        print("b")

    @allow_no_addr_check
    def bluid_client(self, *args, **kwargs) -> GrpcClient:
        """Build client function."""
        return GrpcClient(*args, **kwargs)

    @allow_no_addr_check
    def build_server(self, *args, **kwargs) -> GrpcServer:
        """Build server function."""
        return GrpcServer(*args, **kwargs)

GrpcCommunicationProtocol()

a
1
2
3
4
TODO ADDR
TODO ADDR
TODO ADDR
TODO ADDR
TODO ADDR
TODO ADDR
TODO ADDR
TODO ADDR
TODO ADDR
TODO ADDR
TODO ADDR
TODO ADDR
TODO ADDR
TODO ADDR
5
6


ValueError: Addr must be set before calling this method.