# Federated Learning Applications to Wireless Networks - A Demonstration

Federated Learning (FL) offers many advantages for future wireless networks.
As a distributed learning technique, FL is very promising for IoT and edge devices, where the environment is often bandwidth-limited and energy efficiency is at a premium.
In addition, the privacy of data collection and distribution is a growing concern, since these devices often store sensitive information such as real-time location data and even medical information.

In this demonstration we will examine the benefits of using FL to train Deep Neural Networks (DNNs) for wireless network systems.
Specifically, we will see that FL is applicable in the areas of:

- "Green" communications
- Low bandwidth environments
- Data privacy

To demonstrate these low-bandwidth, "green", and privacy-centric advantages, we will build a Convolutional Neural Network (CNN) image classifier and train it with FL techniques on the standard CIFAR10 dataset. Although each CIFAR10 image is small, the training set as a whole is far larger than the parameters of a modest classifier. If network devices were required to transmit their images to a centralized server for model training, there would be considerable bandwidth overhead. Instead, FL lowers communication overhead by training a local model on each device and transmitting only the model parameters, not the raw data, over the network. Because the images themselves are never transmitted, this approach is also privacy-centric by nature: the global model at the central server does not need the data itself, only the parameters learned by each local model.
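The server-side step described above, combining the parameters of locally trained models into one global model, can be sketched as follows. This is a minimal sketch assuming the common federated averaging (FedAvg) rule, where each client's parameters are weighted by how many training samples it holds; the helper `federated_average` and the toy two-client setup are hypothetical illustrations, not necessarily the aggregation used in this demo.

```python
import copy

import torch


def federated_average(client_models, client_sizes):
    """Hypothetical helper: sample-weighted average of client parameters (FedAvg-style)."""
    total = float(sum(client_sizes))
    global_model = copy.deepcopy(client_models[0])  # same architecture as the clients
    global_state = global_model.state_dict()
    client_states = [m.state_dict() for m in client_models]
    for key in list(global_state):
        original_dtype = global_state[key].dtype
        # Weighted sum across clients, computed in float so integer buffers
        # (e.g. BatchNorm's num_batches_tracked) can be combined as well.
        averaged = sum(state[key].float() * (n / total)
                       for state, n in zip(client_states, client_sizes))
        # Cast back to the original dtype (integer buffers are truncated).
        global_state[key] = averaged.to(original_dtype)
    global_model.load_state_dict(global_state)
    return global_model


# Toy example: two "clients" sharing the same tiny architecture.
client_a, client_b = torch.nn.Linear(2, 1), torch.nn.Linear(2, 1)
with torch.no_grad():
    client_a.weight.fill_(0.0)
    client_a.bias.fill_(0.0)
    client_b.weight.fill_(1.0)
    client_b.bias.fill_(1.0)

# Client B holds three times as much data, so it contributes 3/4 of each parameter.
global_model = federated_average([client_a, client_b], client_sizes=[100, 300])
print(global_model.weight, global_model.bias)
```

Only the `state_dict` tensors cross the network in this scheme; the clients' training images never leave the devices.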

## TOC

This demo is organized into the following sections:

- [Setup environment](#setup-environment)
- [Load datasets](#load-datasets)
    - [CIFAR10](#cifar10)
- ...

## Setup environment

In [255]:
import os

In [256]:
workspace_root = '/fl-demo-workspace'
data_path = os.path.join(workspace_root, 'data')

## Load datasets

In [257]:
import torch.utils.data
import torchvision
import torchvision.transforms as transforms

### CIFAR10

The CIFAR10 dataset is a collection of $60{,}000$ color images of size $32\times32$ with $3$ color channels, split into $50{,}000$ training and $10{,}000$ test images across 10 classes.

The classes are represented by the set: $c \in \{\text{plane}, \text{car}, \text{bird}, \text{cat}, \text{deer}, \text{dog}, \text{frog}, \text{horse}, \text{ship}, \text{truck}\}$
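To put the dataset's size in perspective, the arithmetic below works out how many bytes the training images occupy. It assumes the standard 50,000-image training split, 1 byte per value for the raw 8-bit images, and 4 bytes per value once `transforms.ToTensor()` converts them to float32 tensors.

```python
# Per-image and training-set sizes for CIFAR10 (50,000 training images).
channels, height, width = 3, 32, 32
n_train = 50_000

values_per_image = channels * height * width  # pixel values per image
raw_bytes_per_image = values_per_image * 1    # uint8: 1 byte per value
float_bytes_per_image = values_per_image * 4  # float32: 4 bytes per value

raw_train_bytes = n_train * raw_bytes_per_image
float_train_bytes = n_train * float_bytes_per_image

print(f"values per image:     {values_per_image}")               # 3072
print(f"raw training set:     {raw_train_bytes / 1e6:.1f} MB")   # 153.6 MB
print(f"float32 training set: {float_train_bytes / 1e6:.1f} MB") # 614.4 MB
```

The float32 figure is close to the roughly 619 MB measured later in this notebook, which suggests that measurement reflects the transformed tensors rather than the raw 8-bit images.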

In [258]:
# Define label list (list index == CIFAR10 integer label, so the order matters!)
labels_cifar10 = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

In [259]:
# Define CIFAR10 image transformations for train/test sets.
transform_train_cifar10 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test_cifar10 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
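`transforms.ToTensor()` scales 8-bit pixel values into $[0, 1]$, and `transforms.Normalize(mean, std)` then computes $(x - \mu_c)/\sigma_c$ for each channel $c$. The plain-Python sketch below reproduces that arithmetic for a single pixel, using the mean/std values from the cell above; the helper `normalize_pixel` is hypothetical and only illustrates the formula.

```python
# Per-channel mean and std copied from the transforms above (R, G, B).
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2023, 0.1994, 0.2010)


def normalize_pixel(rgb_uint8):
    """Hypothetical helper: mimic ToTensor() + Normalize() for one RGB pixel."""
    normalized = []
    for value, mu, sigma in zip(rgb_uint8, MEAN, STD):
        x = value / 255.0                    # ToTensor(): 0..255 -> 0.0..1.0
        normalized.append((x - mu) / sigma)  # Normalize(): (x - mean) / std
    return tuple(normalized)


print(normalize_pixel((255, 255, 255)))  # a white pixel
print(normalize_pixel((0, 0, 0)))        # a black pixel
```

Because the train and test pipelines are identical, both sets are mapped onto the same roughly zero-centered scale.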

In [260]:
# Automatically load CIFAR10 dataset into train/test sets.
trainset_cifar10 = torchvision.datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform_train_cifar10)
testset_cifar10 = torchvision.datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform_test_cifar10)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /fl-demo-workspace/data/cifar-10-python.tar.gz




Extracting /fl-demo-workspace/data/cifar-10-python.tar.gz to /fl-demo-workspace/data
Files already downloaded and verified


In [261]:
# Define data loaders for train/test sets.
trainloader_cifar10 = torch.utils.data.DataLoader(trainset_cifar10, batch_size=64, shuffle=True, num_workers=2)
testloader_cifar10 = torch.utils.data.DataLoader(testset_cifar10, batch_size=100, shuffle=False, num_workers=2) # No need to shuffle for evaluation.

## Define DNN Models

In [262]:
import torch
import torch.nn
import torch.optim

In [263]:
class CIFAR10Classifier(torch.nn.Module):
    """Small CNN classifier for 3x32x32 CIFAR10 images with 10 output classes."""
    def __init__(self):
        super().__init__()

        # Convolutional feature extractor. Shape comments are (channels x height x width).
        self.block_cnn = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=3, out_channels=32, kernel_size=(3,3), padding=1), # 3x32x32 --> 32x32x32
            torch.nn.BatchNorm2d(num_features=32),
            torch.nn.ReLU(inplace=True),
            torch.nn.MaxPool2d(kernel_size=(2,2), stride=2), # 32x32x32 --> 32x16x16
            torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3,3), padding=1), # 32x16x16 --> 64x16x16
            torch.nn.BatchNorm2d(num_features=64),
            torch.nn.ReLU(inplace=True),
            # No second max-pool here, so the spatial size stays 16x16.
            torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3,3), padding=1), # 64x16x16 --> 128x16x16
            torch.nn.BatchNorm2d(num_features=128),
            torch.nn.ReLU(inplace=True),
        )

        # Flattened CNN-block output size: 128 channels x 16 x 16.
        in_features = 128*16*16
        self.block_linear = torch.nn.Sequential(
            torch.nn.Linear(in_features=in_features, out_features=10),
        )

    def forward(self, x):

        # Feed input through the sequential Conv/BatchNorm/ReLU/MaxPool layers.
        out = self.block_cnn(x)

        # Flatten the CNN-block output to fit into the linear block:
        # (batch, channels, height, width) --> (batch, channels * height * width)
        out = out.view(out.size(0), -1)

        # Feed the flattened output into the linear block.
        out = self.block_linear(out) # outputs (batch, 10 classes)

        # Return raw logits: CrossEntropyLoss applies log-softmax internally,
        # and argmax over the logits gives the predicted label.
        return out
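As a sanity check on the size of `CIFAR10Classifier`, its trainable parameters can be counted by hand: a `Conv2d` layer has $C_{in} \cdot C_{out} \cdot k^2$ weights plus $C_{out}$ biases, a `BatchNorm2d` layer has $2C$ trainable parameters (its running mean and variance are buffers, not parameters), and a `Linear` layer has $n_{in} \cdot n_{out} + n_{out}$. The plain-Python sketch below applies these formulas; the helper names are hypothetical.

```python
def conv2d_params(c_in, c_out, k):
    """Weights (c_in * c_out * k * k) plus one bias per output channel."""
    return c_in * c_out * k * k + c_out


def batchnorm2d_params(c):
    """Learnable scale and shift per channel (running stats are buffers)."""
    return 2 * c


def linear_params(n_in, n_out):
    """Weight matrix plus bias vector."""
    return n_in * n_out + n_out


total = (
    conv2d_params(3, 32, 3) + batchnorm2d_params(32)
    + conv2d_params(32, 64, 3) + batchnorm2d_params(64)
    + conv2d_params(64, 128, 3) + batchnorm2d_params(128)
    + linear_params(128 * 16 * 16, 10)
)
print(f"trainable parameters: {total:,}")         # 421,386
print(f"float32 size: {total * 4 / 1e6:.2f} MB")  # 1.69 MB
```

About 78% of these parameters (327,690 of 421,386) sit in the final `Linear` layer. With `torch` available, `sum(p.numel() for p in CIFAR10Classifier().parameters())` should give the same count.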

## Training

In [264]:
import time
from contextlib import contextmanager

class timecontext:
    """Elapsed time context manager."""
    def __enter__(self):
        self.seconds = time.time()
        return self
    
    def __exit__(self, type, value, traceback):
        self.seconds = time.time() - self.seconds

@contextmanager
def timecontextprint(description='Elapsed time'):
    """Context manager to print elapsed time from call."""
    with timecontext() as t:
        yield t
    print(f"{description}: {t.seconds} seconds")

In [265]:
def compute_accuracy(pred, truth):
    """Compute accuracy of predictions versus truth."""
    pred = pred.float()
    truth = truth.float()
    return (pred == truth).sum().float()/truth.size(0)*100.0

In [266]:
def compute_model_accuracy(model, loader, device='cpu'):
    """Compute accuracy of a model for a given dataset loader."""
    model.to(device)
    model.eval() # Use BatchNorm running statistics during evaluation.
    ys, y_preds = [], []
    with torch.no_grad(): # Gradients are not needed for evaluation.
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            ys.append(y)
            y_preds.append(torch.argmax(model(x), dim=1))

    y = torch.cat(ys, dim=0)
    y_pred = torch.cat(y_preds, dim=0)
    return compute_accuracy(y_pred, y)

### Traditional Network (not FL!)

In [267]:
def train(model, loader, epoch, optim, criterion, device='cpu'):
    """Helper to train a single model."""
    model.to(device) # Send model to desired device.
    model.train() # Put the model into training mode.
    for e in range(epoch):
        running_loss = 0.0
        for i, data in enumerate(loader):
            # unpack the data, and send data to desired device.
            inputs, labels = data[0].to(device), data[1].to(device)

            # Zero the parameter gradients.
            optim.zero_grad()

            # Evaluate the model.
            preds = model(inputs)

            # Compute losses.
            loss = criterion(preds, labels)

            # Back-propagate, and step the optimizer.
            loss.backward()
            optim.step()

            # Accumulate the loss for this epoch.
            running_loss += loss.item()

        # Report epoch results.
        print(f"[{e}] loss: {running_loss}")

In [268]:
# Set runtime device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [269]:
# Define traditionally-trained singular model.
model_trad = CIFAR10Classifier()
model_trad_store = os.path.join(workspace_root, 'model_trad.pt')

In [270]:
load_from_file = False

# Load model from store file.
if load_from_file and os.path.exists(model_trad_store):
    model_trad.load_state_dict(torch.load(model_trad_store, map_location=device)) # Works whether the file was saved on GPU or CPU.
    print(f'Loaded traditional model from file: {model_trad_store}')

# Train model.
else:

    # Learning hyperparameters.
    epoch = 10
    lr = 1e-3
    print(f'Training traditional model: epoch={epoch}, lr={lr}')

    # Train the model.
    # Display training time too.
    with timecontextprint() as elapsed:
        optim = torch.optim.Adam(model_trad.parameters(), lr=lr)
        criterion = torch.nn.CrossEntropyLoss(reduction='mean')
        train(model_trad, loader=trainloader_cifar10, epoch=epoch, optim=optim, criterion=criterion, device=device)

    # Store model state to file.
    torch.save(model_trad.state_dict(), model_trad_store)
    print(f'Saved traditional model to file: {model_trad_store}')

Training traditional model: epoch=10, lr=0.001
[0] loss: 1251.3343421816826
[1] loss: 678.6561494767666
[2] loss: 547.6483214199543
[3] loss: 449.7029107809067
[4] loss: 368.1199539154768
[5] loss: 291.6405478566885
[6] loss: 227.499128960073
[7] loss: 172.9789839796722
[8] loss: 132.75634049251676
[9] loss: 106.83524388633668
Elapsed time: 100.69234728813171 seconds
Saved traditional model to file: /fl-demo-workspace/model_trad.pt
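The loss printed by `train` is a sum of per-batch mean losses over the whole epoch, so its scale depends on the number of batches. With 50,000 training images and `batch_size=64`, each epoch has $\lceil 50000/64 \rceil = 782$ batches; dividing by that count gives an approximate per-batch average (approximate because the final batch holds only 16 images). The sketch below applies this to the first and last epochs logged above.

```python
import math

n_train, batch_size = 50_000, 64
n_batches = math.ceil(n_train / batch_size)  # 781 full batches + 1 batch of 16

# Summed epoch losses copied from the training log above.
epoch_losses = {0: 1251.3343421816826, 9: 106.83524388633668}

for epoch, summed_loss in epoch_losses.items():
    print(f"epoch {epoch}: mean loss per batch ~ {summed_loss / n_batches:.3f}")
```

This puts the average cross-entropy at roughly 1.60 after the first epoch and 0.14 after the tenth.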


In [271]:
# Evaluate the model.
train_acc_trad = compute_model_accuracy(model_trad, trainloader_cifar10, device=device)
test_acc_trad = compute_model_accuracy(model_trad, testloader_cifar10, device=device)
print(f'Training accuracy: {train_acc_trad}, Testing accuracy: {test_acc_trad}')

Training accuracy: 95.58199310302734, Testing accuracy: 73.47999572753906


In [272]:
import sys
size_cifar10_train = sum(sys.getsizeof(img.storage()) + sys.getsizeof(lbl) for img,lbl in trainset_cifar10)
size_model_trad = sum(sys.getsizeof(p.storage()) for p in model_trad.parameters())
print(f"CIFAR10 train images total size: {size_cifar10_train:.4e} bytes")
print(f"Traditional model parameters total size: {size_model_trad:.4e} bytes")

CIFAR10 train images total size: 6.1938e+08 bytes
Traditional model parameters total size: 1.6866e+06 bytes


As you can see, the entire CIFAR10 training set, held uncompressed in memory as normalized float32 tensors, occupies roughly 619.38 MB, whereas the model parameters occupy only about 1.69 MB. Transmitting the parameters instead of the images therefore requires roughly 367x less data!
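The comparison above counts a single transfer of the parameters. In FL, however, parameters are typically exchanged every communication round, so a back-of-envelope estimate should accumulate that traffic. The sketch below assumes the measured sizes above, $K$ clients that together hold the training set, and one parameter upload per client per round (downloads of the global model are ignored); `n_clients = 10` and the helper `uplink_bytes_fl` are hypothetical.

```python
import math

dataset_bytes = 6.1938e8  # measured size of the CIFAR10 training tensors (above)
model_bytes = 1.6866e6    # measured size of the model parameters (above)

n_clients = 10  # hypothetical number of FL clients sharing the training set


def uplink_bytes_fl(rounds, clients=n_clients):
    """Total bytes uploaded if every client sends its parameters once per round."""
    return rounds * clients * model_bytes


# Largest number of rounds for which FL still uploads less than the raw data.
break_even = math.floor(dataset_bytes / (n_clients * model_bytes))
print(f"single-transfer ratio: {dataset_bytes / model_bytes:.0f}x")
print(f"break-even rounds with {n_clients} clients: {break_even}")
print(f"upload after {break_even} rounds: {uplink_bytes_fl(break_even) / 1e6:.1f} MB")
```

Under these assumptions, the bandwidth savings hold as long as training converges in fewer rounds than the break-even point; more clients or a larger model lowers that threshold.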

### Federated Learning

Sources:

- https://towardsdatascience.com/preserving-data-privacy-in-deep-learning-part-1-a04894f78029
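Simulating FL on a single machine requires dividing the training set among clients, each of which trains only on its own share. The sketch below shows one simple option: an IID split of the 50,000 training indices across a hypothetical number of clients. The helper `split_indices_iid` is illustrative; in PyTorch, each index list could then be wrapped with `torch.utils.data.Subset` to build a per-client `DataLoader`.

```python
import random


def split_indices_iid(n_samples, n_clients, seed=0):
    """Hypothetical helper: shuffle sample indices and deal them out evenly to clients."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # seeded for reproducibility
    # Client k receives every n_clients-th shuffled index starting at k,
    # so shard sizes differ by at most one sample.
    return [indices[k::n_clients] for k in range(n_clients)]


shards = split_indices_iid(n_samples=50_000, n_clients=10)
print([len(s) for s in shards])
```

An IID split is the easiest case. Real devices usually see skewed, non-IID subsets of the classes, which this split does not model.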