# Deploy a communication compression scheme in FedLab

This tutorial provides comprehensive examples of implementing a communication-efficiency scheme in FedLab.

We use two baseline gradient compression algorithms as examples: top-k for gradient sparsification and QSGD for gradient quantization.

## Compress example

In [1]:
import sys
sys.path.append("../")

from fedlab.contrib.compressor.quantization import QSGDCompressor
from fedlab.contrib.compressor.topk import TopkCompressor
import torch

# QSGD quantizer configured with n_bit=8
qsgd_compressor = QSGDCompressor(n_bit=8)
# top-k sparsifier keeping the top 5% of the gradient entries
tpk_compressor = TopkCompressor(compress_ratio=0.05)


In [2]:
# ---- top-k sparsification round trip ----
tensor = torch.randn(100)
shape = tensor.shape
print("To be compressed tensor:", tensor)

# compression returns the kept entries as (values, indices)
values, indices = tpk_compressor.compress(tensor)
for label, part in (("Compressed results top-k values:", values),
                    ("Compressed results top-k indices:", indices)):
    print(label, part)

# decompression rebuilds a tensor from the kept values/indices and the original shape
decompressed = tpk_compressor.decompress(values, indices, shape)
print("Decompressed results:", decompressed)

In [3]:
# ---- QSGD quantization ----
tensor = torch.randn(100)
shape = tensor.shape
print("To be compressed tensor:", tensor)

# compression yields three parts: (norm, signs, values)
norm, signs, values = qsgd_compressor.compress(tensor)
for label, part in (("Compressed results QSGD norm:", norm),
                    ("Compressed results QSGD signs:", signs),
                    ("Compressed results QSGD values:", values)):
    print(label, part)


In [4]:
# decompress: the three QSGD parts are handed over bundled in one list
qsgd_parts = [norm, signs, values]
decompressed = qsgd_compressor.decompress(qsgd_parts)
print("Decompressed results:", decompressed)

## Use compressor in federated learning

For example, on the client side we can compress the tensors to be uploaded and send the compressed results to the server. The server then decompresses them according to the agreed-upon compression scheme.

In this Jupyter notebook, we use the standalone scenario as an example.

In [5]:
from fedlab.contrib.algorithm.basic_client import SGDSerialClientTrainer, SGDClientTrainer
from fedlab.contrib.algorithm.basic_server import SyncServerHandler

class CompressSerialClientTrainer(SGDSerialClientTrainer):
    """Serial SGD client trainer whose uplink payload is compressed.

    Call ``setup_compressor`` before ``uplink_package`` is read; each
    client's upload is then replaced by the compressor's output.
    """

    def setup_compressor(self, compressor):
        """Attach the compressor used to encode uploads.

        Args:
            compressor: object exposing ``compress(tensor)``, e.g. a
                ``TopkCompressor`` or ``QSGDCompressor``. The server must use
                a matching compressor to decode the uploads.
        """
        self.compressor = compressor

    @property
    def uplink_package(self):
        """Return the parent's uplink packages with their first element compressed.

        Each entry of the parent's package becomes
        ``[self.compressor.compress(content[0])]``. Only ``content[0]`` is
        sent (presumably the model parameters -- verify against
        ``SGDSerialClientTrainer.uplink_package``); any further elements of
        ``content`` are dropped.
        """
        # Evaluate super() here, in the method scope, not inside the comprehension.
        package = super().uplink_package
        return [[self.compressor.compress(content[0])] for content in package]

class CompressServerHandler(SyncServerHandler):
    """Synchronous server handler that decompresses client uploads before loading them."""

    # Compression schemes whose payload layout ``load`` knows how to unpack.
    _SUPPORTED_TYPES = ("topk", "qsgd")

    def setup_compressor(self, compressor, type):
        """Attach the compressor and declare which scheme the clients use.

        Args:
            compressor: object exposing ``decompress``; must match the
                compressor used by the clients.
            type: ``"topk"`` or ``"qsgd"``. The parameter name shadows the
                builtin but is kept so keyword callers keep working.

        Raises:
            ValueError: if ``type`` is not a supported scheme. Failing here,
                rather than in ``load``, surfaces misconfiguration before
                training starts.
        """
        if type not in self._SUPPORTED_TYPES:
            raise ValueError(
                f"Unsupported compression type {type!r}; "
                f"expected one of {self._SUPPORTED_TYPES}"
            )
        self.compressor = compressor
        self.type = type

    def load(self, payload) -> bool:
        """Decompress ``payload[0]`` and pass it on to ``SyncServerHandler.load``.

        Args:
            payload: list whose first element is the client's compressed
                tuple: ``(values, indices)`` for ``"topk"`` or
                ``(norm, signs, values)`` for ``"qsgd"``.

        Returns:
            Whatever ``SyncServerHandler.load`` returns for the decompressed
            payload.

        Raises:
            ValueError: if ``self.type`` is not a supported scheme.
        """
        if self.type == "topk":
            values, indices = payload[0]
            # The dense shape comes from the server's own parameters; this
            # assumes clients compress tensors of that same shape.
            decompressed_payload = self.compressor.decompress(
                values, indices, self.model_parameters.shape
            )
        elif self.type == "qsgd":
            norm, signs, values = payload[0]
            decompressed_payload = self.compressor.decompress((norm, signs, values))
        else:
            # Without this, an unknown type left decompressed_payload unbound.
            raise ValueError(
                f"Unsupported compression type {self.type!r}; "
                f"expected one of {self._SUPPORTED_TYPES}"
            )

        return super().load([decompressed_payload])

In [6]:
# main: this part follows the pipeline in pipeline_tutorial.ipynb,
# but replaces the handler and trainer with the compression-aware ones defined above.

# configuration
import os
from opcode import cmp_op  # NOTE(review): not used in the cells shown here; likely an accidental import
from munch import Munch
from fedlab.models.mlp import MLP

model = MLP(784, 10)
# Instantiate Munch. Binding the bare class (``args = Munch``) would store
# every setting below as a class attribute shared by all Munch objects.
args = Munch()

args.total_client = 100
args.alpha = 0.5
args.seed = 42
# Preprocess (partition) MNIST only if the partition output file is missing.
args.preprocess = not os.path.exists("../datasets/mnist/fedmnist/train/data2.pkl")
args.cuda = torch.cuda.is_available()
args.cmp_op = "qsgd"  # compression scheme: "topk" or "qsgd"

args.k = 0.1  # top-k compress_ratio
args.bit = 8  # QSGD n_bit

if args.cmp_op == "topk":
    compressor = TopkCompressor(compress_ratio=args.k)
elif args.cmp_op == "qsgd":
    compressor = QSGDCompressor(n_bit=args.bit)
else:
    # Fail early instead of leaving `compressor` undefined for later cells.
    raise ValueError(f"Unknown compression op {args.cmp_op!r}; expected 'topk' or 'qsgd'")

from torchvision import transforms
from fedlab.contrib.dataset.partitioned_mnist import PartitionedMNIST

fed_mnist = PartitionedMNIST(root="../datasets/mnist/",
                             path="../datasets/mnist/fedmnist/",
                             num_clients=args.total_client,
                             partition="noniid-labeldir",
                             dir_alpha=args.alpha,
                             seed=args.seed,
                             preprocess=args.preprocess,
                             download=True,
                             verbose=True,
                             transform=transforms.Compose([
                                 transforms.ToPILImage(),
                                 transforms.ToTensor()
                             ]))

dataset = fed_mnist.get_dataset(0)  # get the 0-th client's dataset
dataloader = fed_mnist.get_dataloader(
    0,
    batch_size=128)  # get the 0-th client's dataset loader with batch size 128


In [7]:
# ---- client side ----
from fedlab.contrib.algorithm.basic_client import SGDSerialClientTrainer, SGDClientTrainer

# local training hyper-parameters
args.epochs = 5
args.batch_size = 128
args.lr = 0.1

# serial trainer handling `args.total_client` clients
# (a single-client SGDClientTrainer could be used here instead)
trainer = CompressSerialClientTrainer(model, args.total_client, cuda=args.cuda)
trainer.setup_dataset(fed_mnist)
trainer.setup_optim(args.epochs, args.batch_size, args.lr)
trainer.setup_compressor(compressor)

# ---- server side ----
from fedlab.contrib.algorithm.basic_server import SyncServerHandler

# federation-level configuration
args.com_round = 10
args.sample_ratio = 0.1

handler = CompressServerHandler(
    model=model,
    global_round=args.com_round,
    num_clients=args.total_client,
    sample_ratio=args.sample_ratio,
    cuda=args.cuda,
)
# the server must decode with the same compressor/scheme the clients use
handler.setup_compressor(compressor, args.cmp_op)

In [8]:
from fedlab.utils.functional import evaluate
from fedlab.core.standalone import StandalonePipeline

from torch import nn
from torch.utils.data import DataLoader
import torchvision

class EvalPipeline(StandalonePipeline):
    """Standalone FL pipeline that evaluates the global model after every round."""

    def __init__(self, handler, trainer, test_loader):
        """Store the handler/trainer pair and the centralized test loader.

        Args:
            handler: server handler, e.g. ``CompressServerHandler``.
            trainer: client trainer, e.g. ``CompressSerialClientTrainer``.
            test_loader: iterable of test batches passed to ``evaluate``.
        """
        super().__init__(handler, trainer)
        self.test_loader = test_loader

    def main(self):
        """Run rounds until the handler signals stop, printing test metrics each round."""
        # The criterion holds no per-round state, so build it once.
        criterion = nn.CrossEntropyLoss()

        # `not` rather than `is False`: robust to any falsy stop flag.
        while not self.handler.if_stop:
            # server side: choose participants and build the broadcast
            sampled_clients = self.handler.sample_clients()
            broadcast = self.handler.downlink_package

            # client side: local training, then (compressed) uploads
            self.trainer.local_process(broadcast, sampled_clients)
            uploads = self.trainer.uplink_package

            # server side: decompress and load each client's upload
            for pack in uploads:
                self.handler.load(pack)

            loss, acc = evaluate(self.handler.model, criterion, self.test_loader)
            print(f"Centralized Evaluation round {self.handler.round}: loss {loss:.4f}, test accuracy {acc:.4f}")


# centralized MNIST test split used for the per-round evaluation
test_data = torchvision.datasets.MNIST(
    root="../datasets/mnist/",
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)
test_loader = DataLoader(test_data, batch_size=1024)

# wire server, clients and evaluation together, then run the training loop
standalone_eval = EvalPipeline(handler=handler, trainer=trainer, test_loader=test_loader)
standalone_eval.main()