# Bandwidth

Add the following pinned dependencies to your `requirements.txt` file:
```
flwr==1.10.0
ray==2.6.3
flwr-datasets[vision]==0.2.0
torch==2.2.1
torchvision==0.17.1
matplotlib==3.8.3
scikit-learn==1.4.2
seaborn==0.13.2
ipywidgets==8.1.2
transformers==4.42.4
accelerate==0.30.0
```

#### 1. Import dependencies

In [9]:
from flwr.client.mod import parameters_size_mod

from collections import OrderedDict
import logging
from logging import INFO
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common.logger import (
    ConsoleHandler,
    console_handler,
    FLOWER_LOGGER,
    LOG_COLORS,
)
from logging import LogRecord
from typing import Dict, List, Optional, Tuple, Union

from flwr.server import ServerAppComponents
from flwr.client import Client, ClientApp, NumPyClient
from flwr.client.mod import parameters_size_mod
from flwr.common import (
    Context,
    EvaluateRes,
    FitIns,
    FitRes,
    MessageType,
    Parameters,
    Scalar,
    Context,
    parameters_to_ndarrays,
    ndarrays_to_parameters,
)
from flwr.common.logger import (
    console_handler,
    log,
    update_console_handler,
)
from flwr.server import ClientManager, ServerApp, ServerConfig
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation
import torch
from transformers import AutoModelForCausalLM, GPTNeoXForCausalLM


# Customize logging for the course.
class InfoFilter(logging.Filter):
    """Logging filter that passes only records whose level is exactly INFO.

    Records at any other level (DEBUG, WARNING, ERROR, ...) are rejected.
    """

    def filter(self, record):
        is_info = record.levelno == INFO
        return is_info


# Detach Flower's stock console handler; a customised handler is attached
# to FLOWER_LOGGER further down in this cell.
FLOWER_LOGGER.removeHandler(console_handler)

# Backend configuration for the Simulation Engine: only ERROR-level logging
# from the backend, so notebook output stays readable. (These init_args are
# presumably forwarded to the backend's initialisation — confirm against the
# Flower simulation docs.)
from logging import ERROR

backend_setup = {
    "init_args": {
        "logging_level": ERROR,
        "log_to_driver": True,
    }
}

class ConsoleHandlerV2(ConsoleHandler):
    """Flower console handler with a compact ``LEVEL: message`` layout.

    The output is controlled by the ``json``, ``colored`` and ``timestamps``
    attributes, presumably stored by ``ConsoleHandler.__init__`` from the
    keyword arguments of the same names (no ``__init__`` override is needed
    because this subclass adds no state of its own).
    """

    def format(self, record: LogRecord) -> str:
        """Format *record*, prefixing the level name with its colour code."""
        if self.json:
            log_fmt = "{lvl='%(levelname)s', time='%(asctime)s', msg='%(message)s'}"
        else:
            # Level names missing from LOG_COLORS (e.g. custom levels) fall
            # back to no colour instead of raising KeyError.
            color = LOG_COLORS.get(record.levelname, "") if self.colored else ""
            reset = LOG_COLORS["RESET"] if self.colored else ""
            timestamp = "%(asctime)s" if self.timestamps else ""
            log_fmt = f"{color}%(levelname)s {timestamp}{reset}: %(message)s"
        return logging.Formatter(log_fmt).format(record)


# Re-running this cell must not stack handlers: every attached handler emits
# each record, so leftovers from earlier runs would print every message more
# than once. The ConsoleHandlerV2 class object is redefined on each run, so
# stale instances are matched by class name rather than with isinstance().
for _handler in list(FLOWER_LOGGER.handlers):
    if type(_handler).__name__ == "ConsoleHandlerV2":
        FLOWER_LOGGER.removeHandler(_handler)

# Colourised, timestamp-free handler that only lets INFO records through.
console_handlerv2 = ConsoleHandlerV2(
    timestamps=False,
    json=False,
    colored=True,
)
console_handlerv2.setLevel(INFO)
console_handlerv2.addFilter(InfoFilter())
FLOWER_LOGGER.addHandler(console_handlerv2)


def set_weights(net, parameters):
    """Load a list of arrays into *net*, matched positionally to its state dict.

    The i-th array becomes the tensor for the i-th key of
    ``net.state_dict()``. Loading is strict, so any state-dict key left
    without a value makes ``load_state_dict`` raise.
    """
    new_state = OrderedDict()
    for key, array in zip(net.state_dict().keys(), parameters):
        new_state[key] = torch.tensor(array)
    net.load_state_dict(new_state, strict=True)


def get_weights(net):
    """Return every state-dict tensor of *net* as a NumPy array (on CPU).

    Arrays come out in ``state_dict()`` order, the same positional order
    that ``set_weights`` uses when loading them back.
    """
    return [tensor.cpu().numpy() for tensor in net.state_dict().values()]


#### 2. Define the model

* Initialize the model.

In [10]:
# Load the pretrained EleutherAI/pythia-14m causal LM, caching the downloaded
# files under ./pythia-14m/cache. This single instance is later shared by
# every simulated client (client_fn wraps it) and seeds the server's
# initial parameters.
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-14m",
    cache_dir="./pythia-14m/cache",
)



You can find more information about the model on its Hugging Face page: [EleutherAI/pythia-14m](https://huggingface.co/EleutherAI/pythia-14m).

* Compute the model's size in memory.

In [11]:
# Total in-memory size of the model: element count x bytes per element,
# summed over every tensor in the state dict.
vals = model.state_dict().values()
total_size_bytes = sum(tensor.numel() * tensor.element_size() for tensor in vals)
total_size_mb = int(total_size_bytes / 1024 / 1024)  # whole MB, rounded down

log(INFO, "Model size is: {} MB".format(total_size_mb))

[92mINFO [0m: Model size is: 53 MB
[92mINFO [0m: Model size is: 53 MB


* Define the FlowerClient.

In [12]:
class FlowerClient(NumPyClient):
    """Stub NumPy client: loads the weights it receives and reports them back.

    Neither ``fit`` nor ``evaluate`` does any real work; they exist so the
    model makes a full round trip between server and client.
    """

    def __init__(self, net):
        self.net = net

    def fit(self, parameters, config):
        """Load *parameters* into the model and return the model's weights."""
        set_weights(self.net, parameters)
        # No actual training here: report the loaded weights with a nominal
        # example count of 1 and no metrics.
        num_examples = 1
        metrics = {}
        return get_weights(self.net), num_examples, metrics

    def evaluate(self, parameters, config):
        """Load *parameters* and report a dummy loss of 0.0 over 1 example."""
        set_weights(self.net, parameters)
        # No actual evaluation here.
        loss, num_examples = 0.0, 1
        return loss, num_examples, {"accuracy": 0}


def client_fn(context: Context) -> Client:
    """Build the client for one simulated node.

    Every node wraps the same module-level ``model`` instance; *context* is
    not used. The NumPy client is converted with ``to_client()``, so the
    return value is a ``Client`` (not a ``FlowerClient``).
    """
    return FlowerClient(model).to_client()


# ClientApp that builds each node's client via client_fn. The
# parameters_size_mod mod is applied to every client message (judging by
# the simulation output, it logs the sizes of the parameters being exchanged).
client = ClientApp(
    client_fn,
    mods=[parameters_size_mod],
)

* Define the custom strategy: BandwidthTrackingFedAvg.

In [13]:
# Per-transfer model sizes in whole MB (rounded down), appended by
# BandwidthTrackingFedAvg: one entry per model sent and per model received.
bandwidth_sizes = []


def _parameters_size_mb(parameters):
    """Return the in-memory size of *parameters* in whole MB (rounded down).

    The payload is deserialized with ``parameters_to_ndarrays`` and the
    ``nbytes`` of all resulting arrays are summed.
    """
    ndas = parameters_to_ndarrays(parameters)
    return int(sum(n.nbytes for n in ndas) / (1024**2))


class BandwidthTrackingFedAvg(FedAvg):
    """FedAvg that logs and records the size of every model sent or received.

    Each size is appended to the module-level ``bandwidth_sizes`` list; the
    aggregation and client selection themselves are delegated to FedAvg.
    """

    def aggregate_fit(self, server_round, results, failures):
        if not results:
            return None, {}

        # Track sizes of models received from clients.
        for _, res in results:
            size = _parameters_size_mb(res.parameters)
            log(INFO, f"Server receiving model size: {size} MB")
            bandwidth_sizes.append(size)

        # Call FedAvg for actual aggregation.
        return super().aggregate_fit(server_round, results, failures)

    def configure_fit(self, server_round, parameters, client_manager):
        # Call FedAvg for actual configuration.
        instructions = super().configure_fit(
            server_round, parameters, client_manager
        )

        # Track sizes of models to be sent. Instructions may share one
        # Parameters object, so each distinct payload is deserialized once
        # rather than once per client (all payloads stay alive in
        # `instructions`, so their id() values are unique here).
        size_by_payload = {}
        for _, ins in instructions:
            key = id(ins.parameters)
            if key not in size_by_payload:
                size_by_payload[key] = _parameters_size_mb(ins.parameters)
            size = size_by_payload[key]
            log(INFO, f"Server sending model size: {size} MB")
            bandwidth_sizes.append(size)

        return instructions

In [14]:
# Serialize the model's current weights once; the strategy uses them as the
# starting global model.
params = ndarrays_to_parameters(get_weights(model))

def server_fn(context: Context):
    """Assemble the ServerApp components for a single bandwidth-tracked round.

    The strategy starts from the module-level ``params`` and uses
    ``fraction_evaluate=0.0`` (no clients sampled for federated evaluation).
    *context* is not used.
    """
    return ServerAppComponents(
        strategy=BandwidthTrackingFedAvg(
            fraction_evaluate=0.0,
            initial_parameters=params,
        ),
        config=ServerConfig(num_rounds=1),
    )


# ServerApp whose strategy and config are built by server_fn.
server = ServerApp(server_fn=server_fn)

* Run the simulation.

In [15]:
# Run the federation in-process with 2 simulated client nodes; backend_setup
# restricts the simulation backend's logging to ERROR level.
run_simulation(server_app=server,
               client_app=client,
               num_supernodes=2,
               backend_config=backend_setup
               )

[92mINFO [0m: Starting Flower ServerApp, config: num_rounds=1, no round_timeout
[92mINFO [0m: Starting Flower ServerApp, config: num_rounds=1, no round_timeout
[92mINFO [0m: 
[92mINFO [0m: 
[92mINFO [0m: [INIT]
[92mINFO [0m: [INIT]
[92mINFO [0m: Using initial global parameters provided by strategy
[92mINFO [0m: Using initial global parameters provided by strategy
[92mINFO [0m: Evaluating initial global parameters
[92mINFO [0m: Evaluating initial global parameters
[92mINFO [0m: 
[92mINFO [0m: 
[92mINFO [0m: [ROUND 1]
[92mINFO [0m: [ROUND 1]
[92mINFO [0m: Server sending model size: 53 MB
[92mINFO [0m: Server sending model size: 53 MB
[92mINFO [0m: Server sending model size: 53 MB
[92mINFO [0m: Server sending model size: 53 MB
[92mINFO [0m: configure_fit: strategy sampled 2 clients (out of 2)
[92mINFO [0m: configure_fit: strategy sampled 2 clients (out of 2)
[2m[36m(ClientAppActor pid=845)[0m [92mINFO [0m:      {'fitins.parameters': {'parameter

* Log how much bandwidth was used!

In [16]:
# Sum of every per-transfer size recorded by the strategy. Each entry was
# already rounded down to whole MB, so this can slightly undercount.
log(INFO, "Total bandwidth used: {} MB".format(sum(bandwidth_sizes)))

[92mINFO [0m: Total bandwidth used: 106 MB
[92mINFO [0m: Total bandwidth used: 106 MB
