In [None]:
import ray
ray.init()
from fl_strategies.fed_avg import FedAvg
import yaml
from fedrec.utilities.logger import NoOpLogger
from fedrec.python_executors.aggregator import AggregatorConfig
from fedrec.multiprocessing.process_manager import RayProcessManager
from typing import Callable, Dict
import sys
from json import dumps
from fedrec.multiprocessing.jobber import Jobber
from kafka import KafkaProducer
class JobExecutor:
    """Run an aggregator/trainer actor on a single node and drive it via a Jobber.

    In this notebook the class is launched inside Ray, either directly through
    ``ray.remote(JobExecutor)`` or via ``RayProcessManager.distribute``.
    """

    def __init__(self,
                 actorCls: Callable,
                 config: Dict,
                 actor_config: AggregatorConfig,
                 logger: NoOpLogger,
                 **kwargs) -> None:
        """Construct the worker actor and wrap it in a Jobber.

        Args:
            actorCls: class of the actor to run (e.g. ``FedAvg``). It is
                called as ``actorCls(0, config["model"], actor_config,
                logger, **kwargs)``.
            config: experiment configuration; must contain ``"model"`` and
                ``config["multiprocessing"]["communications"]``.
            actor_config: actor-specific configuration passed through to
                ``actorCls``.
            logger: logger handed to both the actor and the Jobber.
            **kwargs: extra keyword arguments forwarded to ``actorCls``.
        """
        # Ensure the project packages are imported in this process before the
        # worker is built. NOTE(review): presumably needed so a fresh Ray
        # worker process can resolve classes from these packages — confirm.
        # sys.modules is keyed by module *name*, so membership is tested
        # against its keys (its values are module objects, never strings).
        required = ("experiments", "fedrec", "fl_strategies")
        if not all(name in sys.modules for name in required):
            import experiments  # noqa: F401  (imported only for its side effect)
            import fedrec  # noqa: F401
            import fl_strategies  # noqa: F401
        self.config = config
        # NOTE(review): the leading 0 is presumably the worker/actor id —
        # confirm against actorCls's constructor signature.
        self.worker = actorCls(0, config["model"], actor_config, logger, **kwargs)

        # The Jobber drives the worker using the communications settings.
        self.jobber = Jobber(
            self.worker, logger, config["multiprocessing"]["communications"])

    def run(self):
        """Start the Jobber and return whatever ``Jobber.run()`` returns."""
        return self.jobber.run()

def init_kafka(config):
    """Create a KafkaProducer whose message values are sent as UTF-8 JSON.

    Args:
        config: mapping with ``producer_url`` and ``producer_port`` keys; they
            are joined as ``"<url>:<port>"`` to form the single bootstrap
            server address.

    Returns:
        A ``KafkaProducer`` that serializes each value with ``json.dumps``
        and encodes the result as UTF-8 bytes.
    """
    bootstrap_address = f"{config['producer_url']}:{config['producer_port']}"

    def _to_json_bytes(value):
        return dumps(value).encode('utf-8')

    return KafkaProducer(
        bootstrap_servers=[bootstrap_address],
        value_serializer=_to_json_bytes,
    )

In [None]:
# Experiment configuration (read below for "model" and
# "multiprocessing" -> "communications").
# NOTE(review): yaml.FullLoader is acceptable for this trusted local file;
# prefer yaml.safe_load if the config could ever come from untrusted input.
with open("../configs/dlrm_fl.yml", mode="r") as cfg:
    config = yaml.load(cfg, Loader=yaml.FullLoader)

# RNG seeds, later unpacked into AggregatorConfig(**ag_config).
ag_config = dict(
    # Seed for shuffling the training data.
    data_seed=100,
    # Seed for initializing the model.
    init_seed=100,
    # Seed for computing the model's training loss; only relevant when the
    # model has internal randomness (e.g. dropout).
    model_seed=100,
)



In [None]:
# Kafka connection settings live under multiprocessing -> communications.
kafka_settings = config["multiprocessing"]["communications"]
producer = init_kafka(kafka_settings)

In [None]:
# Smoke-test the Kafka connection by publishing one JSON message to the
# 'testing1' topic. send() is asynchronous and only returns a future, so block
# on it: an unreachable broker or rejected record then raises here instead of
# failing silently in the background.
producer.send('testing1', value={"test": "test_abhinav"}).get(timeout=10)

In [None]:
# Build the aggregator config from the seeds and launch a FedAvg aggregator
# inside a Ray actor wrapping JobExecutor.
aggregator_cfg = AggregatorConfig(**ag_config)
JHhook = ray.remote(JobExecutor)
je_Aggregator = JHhook.remote(FedAvg,
                              config,
                              aggregator_cfg,
                              NoOpLogger())
# run.remote() returns immediately with an ObjectRef. Keep it: Ray only
# re-raises an exception from JobExecutor.run when the result is fetched,
# e.g. ray.get(aggregator_run_ref).
aggregator_run_ref = je_Aggregator.run.remote()
aggregator_run_ref

In [None]:
# Shut down the Ray runtime started by ray.init() at the top of the notebook;
# the je_Aggregator actor launched earlier presumably goes away with it.
# NOTE(review): the next cell uses RayProcessManager — confirm it initializes
# Ray itself, otherwise ray.init() must be called again before it.
ray.shutdown()


In [None]:
# Launch the FedAvg aggregator through RayProcessManager: register one
# JobExecutor instance under the name "FedAvg", then invoke its "run" method.
rpm = RayProcessManager()
# NOTE(review): assumes distribute() forwards the trailing positional args to
# JobExecutor(actorCls, config, actor_config, logger) — confirm. JobExecutor
# itself reads config["model"] and config["multiprocessing"], so it needs the
# full experiment config (as in the direct ray.remote launch), not
# config["model"].
rpm.distribute(JobExecutor, FedAvg.__name__, 1,
               FedAvg, config, aggregator_cfg, NoOpLogger())

rpm.start(FedAvg.__name__, "run")