# Federated Learning Applications to Wireless Networks - A Demonstration

Federated Learning (FL) offers several advantages for future wireless networks.
As a distributed learning technique, FL is promising for IoT and edge devices, which often operate in bandwidth-limited environments where energy efficiency is at a premium.
In addition, the privacy of data collection and distribution is a growing concern, since these devices hold sensitive information such as real-time location data and even medical information.

In this demonstration we will examine the benefits of using FL to train Deep Neural Networks (DNNs) for wireless network systems.
Specifically, we will see that FL is applicable in the areas of:

- "Green" communications
- Low bandwidth environments
- Data privacy

To demonstrate the advantages of low-bandwidth, "green", and privacy-centric communication, we will build a Convolutional Neural Network (CNN) image classifier and train it with FL techniques on the standard CIFAR10 dataset.

Taken together, the CIFAR10 images form a sizable dataset. If network devices were required to transmit these images to a centralized server for model training, the raw data would add considerable bandwidth overhead. We will show that FL can lower this communication overhead by training local models and transmitting only model parameters, not raw data, over the network.

Because the images themselves are never transmitted, this approach is privacy-centric by nature. The global model at the central server does not need to know the details of the data, only the outcome of each local model.
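
The parameter exchange described above can be sketched in a few lines. The text does not name an aggregation rule, so this sketch assumes the widely used federated averaging scheme (FedAvg), in which the server averages client parameters weighted by each client's local sample count. The function name `federated_average` and the dictionary-of-lists representation are hypothetical stand-ins; with PyTorch the same arithmetic would run over the tensors in each model's `state_dict()`.

```python
from typing import Dict, List, Sequence

# Hypothetical stand-in for a model's named parameters (name -> flat list of values).
Params = Dict[str, List[float]]


def federated_average(client_params: Sequence[Params],
                      client_sizes: Sequence[int]) -> Params:
    """Average client parameters, weighting each client by its sample count."""
    if not client_params or len(client_params) != len(client_sizes):
        raise ValueError("need one sample count per client")
    total = sum(client_sizes)
    if total <= 0:
        raise ValueError("total sample count must be positive")

    averaged: Params = {}
    for name in client_params[0]:
        length = len(client_params[0][name])
        averaged[name] = [
            sum(params[name][i] * n for params, n in zip(client_params, client_sizes)) / total
            for i in range(length)
        ]
    return averaged


# Two toy clients: only these small parameter lists cross the network,
# never the images each client trained on.
clients = [
    {"w": [1.0, 2.0], "b": [0.0]},
    {"w": [3.0, 4.0], "b": [1.0]},
]
global_params = federated_average(clients, client_sizes=[100, 300])
print(global_params)  # {'w': [2.5, 3.5], 'b': [0.75]}
```

The second client holds three times as much data, so the averaged values sit three quarters of the way toward its parameters.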

## TOC

This demo is organized into the following sections:

- [Setup environment](#setup-environment)
- [Load datasets](#load-datasets)
    - [CIFAR10](#cifar10)
- ...

## Setup environment

In [66]:
import os

In [67]:
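# Expand '~' explicitly so the path resolves regardless of which library consumes it.
workspace_root = os.path.expanduser('~/Desktop/fl-demo-workspace')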
data_path = os.path.join(workspace_root, 'data')

## Load datasets

In [68]:
import torchvision
import torchvision.transforms as transforms

### CIFAR10

The CIFAR10 dataset is a collection of $60{,}000$ images of size $32\times32$ with $3$ color channels, split evenly into $10$ classes ($50{,}000$ images for training and $10{,}000$ for testing).

The classes are represented by the set: $c \in \{\text{plane}, \text{car}, \text{bird}, \text{cat}, \text{deer}, \text{dog}, \text{frog}, \text{horse}, \text{ship}, \text{truck}\}$
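
As a rough worked example, the figures above give the raw size of the dataset that a centralized approach would have to move over the network. The sketch assumes uncompressed storage at one byte per channel value (8-bit color); the constant names are illustrative only.

```python
# Dataset dimensions as stated in the text.
NUM_IMAGES = 60_000
HEIGHT, WIDTH, CHANNELS = 32, 32, 3

# Assumption: 8-bit color, i.e. one byte per channel value, no compression.
BYTES_PER_VALUE = 1

bytes_per_image = HEIGHT * WIDTH * CHANNELS * BYTES_PER_VALUE
total_bytes = NUM_IMAGES * bytes_per_image
total_mib = total_bytes / 2**20

print(f"Bytes per image: {bytes_per_image}")    # 3072
print(f"Total raw bytes: {total_bytes}")        # 184320000
print(f"Total raw size:  {total_mib:.2f} MiB")  # 175.78 MiB
```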

In [69]:
# Define label list (indices are important!)
labels_cifar10 = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

In [70]:
# Define CIFAR10 image transformations for train/test sets.
transform_train_cifar10 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test_cifar10 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
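
To make the two transform steps concrete, the sketch below reproduces their arithmetic on a single pixel in plain Python: `ToTensor` scales 8-bit values into $[0, 1]$, and `Normalize` then computes $(x - \text{mean}) / \text{std}$ per channel. The helper `normalize_pixel` is hypothetical and only illustrates the math; it is not part of torchvision.

```python
from typing import Sequence, Tuple

# Per-channel constants copied from the transforms above.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)


def normalize_pixel(rgb_uint8: Sequence[int],
                    mean: Sequence[float] = CIFAR10_MEAN,
                    std: Sequence[float] = CIFAR10_STD) -> Tuple[float, ...]:
    """Apply the ToTensor scaling and Normalize arithmetic to one RGB pixel."""
    scaled = [value / 255.0 for value in rgb_uint8]  # ToTensor: [0, 255] -> [0, 1]
    return tuple((s - m) / d for s, m, d in zip(scaled, mean, std))  # Normalize


white = normalize_pixel((255, 255, 255))
black = normalize_pixel((0, 0, 0))
print("white:", [round(v, 3) for v in white])
print("black:", [round(v, 3) for v in black])
```

Bright pixels map to positive values and dark pixels to negative ones, so the network sees inputs roughly centered on zero.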

In [71]:
# Automatically load CIFAR10 dataset into train/test sets.
trainset_cifar10 = torchvision.datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform_train_cifar10)
testset_cifar10 = torchvision.datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform_test_cifar10)

Files already downloaded and verified
Files already downloaded and verified


In [72]:
import torch.utils.data

# Define data loaders for train/test sets.
trainloader_cifar10 = torch.utils.data.DataLoader(trainset_cifar10, batch_size=64, shuffle=True, num_workers=2)
testloader_cifar10 = torch.utils.data.DataLoader(testset_cifar10, batch_size=100, shuffle=False, num_workers=2)

## Define DNN Models

In [73]:
import torch
import torch.nn
import torch.optim
import torch.utils.data

In [74]:
class CIFAR10Classifier(torch.nn.Module):
    def __init__(self): # , w: int, h: int, n_channels: int, n_classes: int, n_conv_layers: int = 3
        super().__init__()

        # conv_layers = []
        # in_channels = n_channels
        # out_channels = n_channels*2
        # for i in range(n_conv_layers):
        #     conv_layers.append(torch.nn.Conv2d(in_channels=n_channels))

        self.block_cnn = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=3, out_channels=63, kernel_size=(3,3), padding=1), # 3x32x32 --> 63x32x32
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=(2,2), stride=2), # 32x32 --> 16x16
            torch.nn.Conv2d(in_channels=63, out_channels=126, kernel_size=(3,3), padding=1), # 63x16x16 --> 126x16x16
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=(2,2), stride=2), # 16x16 --> 8x8
            torch.nn.Conv2d(in_channels=126, out_channels=252, kernel_size=(3,3), padding=1), # 126x8x8 --> 252x8x8
            torch.nn.ReLU(),
            # torch.nn.Conv2d(in_channels=252, out_channels=252, kernel_size=(3,3), padding=1), # 252x8x8 --> 252x8x8
            # torch.nn.ReLU(),
        )

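        self.block_linear = torch.nn.Sequential(
            # Note: there is no activation between these two layers, so together
            # they compose into a single affine map of the flattened features.
            torch.nn.Linear(in_features=252*8*8, out_features=1000),
            torch.nn.Linear(in_features=1000, out_features=10),
        )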

    def forward(self, x):

        # Feed input through sequential CNN/MaxPool layers.
        out = self.block_cnn(x)

        # Reshape CNN-block output to fit into linear-block.
        # Reshapes into: (batches, channels, height, width) --> (batches, channels * height * width)
        out = out.view(out.size(0), -1)

        # Feed reshaped CNN-block output into linear-block.
        out = self.block_linear(out)
        return out
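
Because FL transmits model parameters, it helps to know how many the classifier above has. The sketch below traces the feature-map size through the layers and counts weights and biases using the standard formulas for convolutional and linear layers. The helper names are hypothetical, and the payload figure assumes 32-bit floats with no compression.

```python
def conv2d_out(size: int, kernel: int = 3, padding: int = 1, stride: int = 1) -> int:
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size: int, kernel: int = 2, stride: int = 2) -> int:
    """Spatial output size of a max-pool along one dimension."""
    return (size - kernel) // stride + 1


def conv2d_params(c_in: int, c_out: int, kernel: int = 3) -> int:
    """Weights plus one bias per output channel."""
    return c_out * c_in * kernel * kernel + c_out


def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus one bias per output feature."""
    return n_in * n_out + n_out


size = 32                      # CIFAR10 images are 32x32
channels = [3, 63, 126, 252]   # channel widths used by block_cnn
total_params = 0
for i, (c_in, c_out) in enumerate(zip(channels, channels[1:])):
    size = conv2d_out(size)
    total_params += conv2d_params(c_in, c_out)
    if i < 2:                  # only the first two convolutions are followed by a max-pool
        size = pool_out(size)

flat_features = channels[-1] * size * size
total_params += linear_params(flat_features, 1000) + linear_params(1000, 10)

# Assumption: each parameter is sent as a 32-bit float, uncompressed.
payload_bytes = total_params * 4

print(f"Flattened features: {flat_features}")   # 16128
print(f"Total parameters:   {total_params}")    # 16498362
print(f"Payload per upload: {payload_bytes / 2**20:.1f} MiB")
```

The flattened feature count matches the `in_features=252*8*8` of the first linear layer, and that layer alone accounts for most of the parameters.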

## Training

In [75]:
import time
from contextlib import contextmanager

class timecontext:
    """Elapsed time context manager."""
    def __enter__(self):
        self.seconds = time.time()
        return self
    
    def __exit__(self, type, value, traceback):
        self.seconds = time.time() - self.seconds

@contextmanager
def timecontextprint(description='Elapsed time'):
    """Context manager to print elapsed time from call."""
    with timecontext() as t:
        yield t
    print(f"{description}: {t.seconds} seconds")

### Traditional Network (not FL!)

In [76]:
def compute_accuracy(pred, truth):
    pred = pred.float()
    truth = truth.float()
    return (pred == truth).sum().float()/truth.size(0)*100.0

In [77]:
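def evaluate(model, loader, device='cpu'):
    """Return the classification accuracy (in percent) of a model over a loader."""
    model.eval()
    ys, y_preds = [], []
    # Disable gradient tracking: inference does not need the autograd graph.
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            ys.append(y)
            y_preds.append(torch.argmax(model(x), dim=1))

    y = torch.cat(ys, dim=0)
    y_pred = torch.cat(y_preds, dim=0)
    return compute_accuracy(y_pred, y)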

In [78]:
def train(model, loader, epoch, optim, criterion, device='cpu'):
    """Helper to train a single model."""
    model.train()
    for e in range(epoch):
        running_loss = 0.0
        for x, y in loader:

            # Send data to desired device.
            x = x.to(device)
            y = y.to(device)

            # Evaluate the model.
            y_pred = model(x)

            # Compute losses.
            loss = criterion(y_pred, y)

            # Zero the gradient, back-propagate, and step the optimizer.
            optim.zero_grad()
            loss.backward()
            optim.step()

            # Accumulate the loss for this epoch.
            running_loss += loss.item()

        # Report epoch results.
        print(f'Epoch {e}: loss {running_loss}')

In [79]:
# Set runtime device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [80]:
# Define a single, centrally trained model (the non-FL baseline).
model_trad = CIFAR10Classifier()

In [65]:
# Learning hyperparameters.
epoch = 1
lr = 1e-1

# Train the model.
# Display training time too.
with timecontextprint() as elapsed:
    model_trad.to(device)
    optim = torch.optim.Adam(model_trad.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss(reduction='mean')
    train(model_trad, loader=trainloader_cifar10, epoch=epoch, optim=optim, criterion=criterion, device=device)

KeyboardInterrupt: 

In [None]:
# Evaluate the model on the train and test loaders defined earlier.
model_trad.eval()
train_acc = evaluate(model_trad, trainloader_cifar10, device=device)
test_acc = evaluate(model_trad, testloader_cifar10, device=device)
print(f'Training accuracy: {train_acc}, Testing accuracy: {test_acc}')