# Federated PyTorch UNET Tutorial

In [None]:
# Install PyTorch with pip if it is not already installed
!pip install torch

First of all we need to set up our OpenFL workspace. To do this, simply run the `fx.init()` command as follows:

In [None]:
import openfl.native as fx

# Set up the default workspace, logging, etc., and install additional requirements
fx.init('torch_unet_kvasir')

In [None]:
# Import installed modules
import PIL
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from skimage import io
from torchvision import transforms as tsf
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

from os import listdir

from openfl.federated import FederatedModel, FederatedDataSet
from openfl.utilities import TensorKey
from openfl.utilities import validate_file_hash

Download the Kvasir dataset, verify the archive's checksum and extract it.

In [None]:
!wget 'https://datasets.simula.no/downloads/hyper-kvasir/hyper-kvasir-segmented-images.zip' -O kvasir.zip
ZIP_SHA384 = 'e30d18a772c6520476e55b610a4db457237f151e'\
    '19182849d54b49ae24699881c1e18e0961f77642be900450ef8b22e7'
validate_file_hash('./kvasir.zip', ZIP_SHA384)
!unzip -n kvasir.zip -d ./data

Now we are ready to define our dataset and model to perform federated learning on.

In [None]:
# Directory created by the unzip step; it holds the images/ and masks/ subfolders
DATA_PATH = './data/segmented-images/'

In [None]:
def read_data(image_path, mask_path):
    """
    Read an image and its segmentation mask from disk.

    Args:
        image_path: path to the image file
        mask_path: path to the mask file; only its first channel is kept

    Returns:
        Tuple ``(img, mask)`` where ``img`` is the array returned by
        ``io.imread`` and ``mask`` is channel 0 of the mask image cast
        to ``np.uint8``.

    Raises:
        ValueError: if the image does not have three channels along axis 2.
    """
    img = io.imread(image_path)
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, and a 2-D (grayscale) image would otherwise fail with an
    # opaque IndexError on `shape[2]`.
    if img.ndim < 3 or img.shape[2] != 3:
        raise ValueError(
            'Expected a 3-channel image at {!r}, got array of shape {}'.format(
                image_path, img.shape
            )
        )
    mask = io.imread(mask_path)
    return (img, mask[:, :, 0].astype(np.uint8))


class KvasirDataset(Dataset):
    """
    Kvasir dataset contains 1000 images for all collaborators.

    The sorted list of ``jpg`` files is sharded round-robin between the
    collaborators; the last eighth of every shard is held out for validation
    and the rest is used for training.

    Args:
        data_path: path to dataset on disk; must contain the ``images`` and
            ``masks`` subdirectories
        collaborator_count: total number of collaborators
        collaborator_num: number of current collaborator
        is_validation: validation option

    Raises:
        ValueError: if the collaborator's shard has 8 or fewer images.
    """

    def __init__(self, data_path, collaborator_count, collaborator_num, is_validation):
        # Local import: only `listdir` is imported from `os` at file level.
        from os.path import join

        # Build the directories from `data_path` instead of hard-coding them.
        # The trailing '' keeps a path separator at the end so __getitem__ can
        # append file names by plain concatenation.
        self.images_path = join(data_path, 'images', '')
        self.masks_path = join(data_path, 'masks', '')

        # Sorting makes the sharding deterministic across collaborators.
        all_names = [
            img_name
            for img_name in sorted(listdir(self.images_path))
            if len(img_name) > 3 and img_name.endswith('jpg')
        ]

        self.is_validation = is_validation
        self.images_names = self._select_names(
            all_names, collaborator_count, collaborator_num, is_validation
        )

        # Images: resize, convert to a tensor and normalize each channel
        # with mean 0.5 / std 0.5.
        self.img_trans = tsf.Compose([
            tsf.ToPILImage(),
            tsf.Resize((332, 332)),
            tsf.ToTensor(),
            tsf.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
        # Masks: nearest-neighbour resize so no intermediate label values
        # are interpolated.
        self.mask_trans = tsf.Compose([
            tsf.ToPILImage(),
            tsf.Resize((332, 332), interpolation=PIL.Image.NEAREST),
            tsf.ToTensor()])

    @staticmethod
    def _select_names(names, collaborator_count, collaborator_num, is_validation):
        """Return this collaborator's train or validation slice of ``names``."""
        # int() lets callers pass the indices as strings (e.g. from a config).
        shard = names[int(collaborator_num):: int(collaborator_count)]
        # With 8 or fewer images the validation slice would be empty, and
        # `shard[-0:]` would silently return the whole shard.
        if len(shard) <= 8:
            raise ValueError(
                'Collaborator shard needs more than 8 images, got {}'.format(len(shard))
            )
        validation_size = len(shard) // 8
        if is_validation:
            return shard[-validation_size:]
        return shard[: -validation_size]

    def __getitem__(self, index):
        """Return the transformed ``(image, mask)`` pair at ``index`` as numpy arrays."""
        name = self.images_names[index]
        # Image and mask share the same file name in their respective folders.
        img, mask = read_data(self.images_path + name, self.masks_path + name)
        img = self.img_trans(img).numpy()
        mask = self.mask_trans(mask).numpy()
        return img, mask

    def __len__(self):
        """Return the number of images in this split."""
        return len(self.images_names)

Here we override `FederatedDataSet` methods because we do not want to use the default batch generator from `FederatedDataSet`; batches are served by PyTorch `DataLoader`s instead.

In [None]:
class KvasirFederatedDataset(FederatedDataSet):
    """Federated data object that serves Kvasir data through PyTorch ``DataLoader``s.

    Overrides the loader, size, shape and split methods so that data comes
    from ``KvasirDataset`` instead of arrays held by the base class.
    """

    def __init__(self, collaborator_count=1, collaborator_num=0, batch_size=1, **kwargs):
        """Instantiate the data object
        Args:
            collaborator_count: total number of collaborators
            collaborator_num: number of current collaborator
            batch_size:  the batch size of the data loader
            **kwargs: additional arguments, passed to super init
        """
        # The base class receives empty lists; the data is served by the
        # loaders created below.
        super().__init__([], [], [], [], batch_size, num_classes=2, **kwargs)

        # Normalize both indices to int once and pass the normalized values on:
        # KvasirDataset uses them as slice indices, which must be integers.
        collaborator_count = int(collaborator_count)
        self.collaborator_num = int(collaborator_num)

        self.batch_size = batch_size

        self.training_set = KvasirDataset(
            DATA_PATH, collaborator_count, self.collaborator_num, is_validation=False
        )
        self.valid_set = KvasirDataset(
            DATA_PATH, collaborator_count, self.collaborator_num, is_validation=True
        )

        self.train_loader = self.get_train_loader()
        self.val_loader = self.get_valid_loader()

    def get_valid_loader(self, num_batches=None):
        """Return a new, unshuffled loader over the validation set.

        ``num_batches`` is accepted but not used.
        """
        return DataLoader(self.valid_set, num_workers=8, batch_size=self.batch_size)

    def get_train_loader(self, num_batches=None):
        """Return a new, shuffled loader over the training set.

        ``num_batches`` is accepted but not used.
        """
        return DataLoader(
            self.training_set, num_workers=8, batch_size=self.batch_size, shuffle=True
        )

    def get_train_data_size(self):
        """Return the number of training samples."""
        return len(self.training_set)

    def get_valid_data_size(self):
        """Return the number of validation samples."""
        return len(self.valid_set)

    def get_feature_shape(self):
        """Return the shape of the first validation image (this loads it from disk)."""
        return self.valid_set[0][0].shape

    def split(self, collaborator_count, shuffle=True, equally=True):
        """Return one ``KvasirFederatedDataset`` per collaborator.

        ``shuffle`` and ``equally`` are accepted but not used: the sharding
        is done by ``KvasirDataset`` from the collaborator index.
        """
        return [
            KvasirFederatedDataset(
                collaborator_count, collaborator_num, self.batch_size
            )
            for collaborator_num in range(collaborator_count)
        ]

Our U-Net model

In [None]:
def soft_dice_loss(output, target):
    """Return one minus the batch-averaged soft Dice score.

    Both tensors are flattened per sample (dimension 0 is the batch).
    The ``+ 1`` smoothing term is added to the intersection *before* the
    factor of two is applied, so a per-sample score can exceed 1 and the
    returned loss can be negative.

    Args:
        output: predicted tensor, batch first
        target: ground-truth tensor with the same number of elements per sample

    Returns:
        Scalar tensor with the loss value.
    """
    batch = target.size(0)
    pred_flat = output.view(batch, -1)
    true_flat = target.view(batch, -1)
    overlap = (pred_flat * true_flat).sum(1)
    total = pred_flat.sum(1) + true_flat.sum(1)
    per_sample = 2.0 * (overlap + 1) / (total + 1)
    return 1 - per_sample.sum() / batch


def soft_dice_coef(output, target):
    """Return the soft Dice score summed (not averaged) over the batch.

    Uses the same per-sample formula as ``soft_dice_loss``:
    ``2 * (intersection + 1) / (sum(output) + sum(target) + 1)``.

    Args:
        output: predicted tensor, batch first
        target: ground-truth tensor with the same number of elements per sample

    Returns:
        Scalar tensor holding the sum of the per-sample scores.
    """
    batch = target.size(0)
    pred_flat = output.view(batch, -1)
    true_flat = target.view(batch, -1)
    overlap = (pred_flat * true_flat).sum(1)
    total = pred_flat.sum(1) + true_flat.sum(1)
    per_sample = 2.0 * (overlap + 1) / (total + 1)
    return per_sample.sum()


class DoubleConv(nn.Module):
    """Two consecutive stages of 3x3 convolution, batch norm and ReLU.

    Args:
        in_ch: number of input channels of the first convolution
        out_ch: number of output channels of both convolutions
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.in_ch, self.out_ch = in_ch, out_ch

        # The first stage maps in_ch -> out_ch, the second out_ch -> out_ch.
        # padding=1 with a 3x3 kernel keeps the spatial size unchanged.
        layers = []
        for stage_in in (in_ch, out_ch):
            layers.append(nn.Conv2d(stage_in, out_ch, 3, padding=1))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both convolution stages to ``x``."""
        return self.conv(x)


class Down(nn.Module):
    """Downscaling step: 2x2 max pooling followed by a ``DoubleConv``.

    Args:
        in_ch: number of input channels
        out_ch: number of output channels
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_ch, out_ch)
        self.mpconv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool ``x`` and run it through the double convolution."""
        return self.mpconv(x)


class Up(nn.Module):
    """Upscaling step: upsample, pad to the skip connection, concatenate, convolve.

    Args:
        in_ch: number of channels of the concatenated tensor fed to the
            double convolution
        out_ch: number of output channels
        bilinear: use bilinear ``nn.Upsample`` instead of a transposed
            convolution
    """

    def __init__(self, in_ch, out_ch, bilinear=False):
        super().__init__()
        self.in_ch, self.out_ch = in_ch, out_ch
        if not bilinear:
            # Learned upsampling that also halves the channel count.
            upsampler = nn.ConvTranspose2d(in_ch, in_ch // 2, 2, stride=2)
        else:
            # NOTE(review): Upsample leaves the channel count unchanged, so the
            # concatenation in forward() may not match self.conv's in_ch in
            # this mode - confirm before enabling bilinear=True.
            upsampler = nn.Upsample(
                scale_factor=2, mode="bilinear", align_corners=True
            )
        self.Up = upsampler
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it with ``x2`` and convolve their concatenation."""
        upsampled = self.Up(x1)

        # Pad the upsampled map so that its height/width match the skip
        # connection; any odd remainder goes to the right/bottom side.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        left, top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, (left, gap_w - left, top, gap_h - top))

        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class UNet(nn.Module):
    """U-Net with four downscaling and four upscaling steps and a sigmoid output.

    ``validate`` relies on ``rebuild_model``, ``device`` and ``data_loader``,
    none of which are defined in this class - presumably they are supplied by
    the OpenFL task runner that ``FederatedModel`` combines with this model
    (confirm against OpenFL).

    Args:
        n_channels: number of channels of the input image
        n_classes: number of output channels
    """

    def __init__(self, n_channels=3, n_classes=1):
        super().__init__()
        # Encoder: channel count doubles at every downscaling step.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        # Decoder: mirrors the encoder, each step consumes a skip connection.
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)
        # 1x1 convolution producing one map per class.
        self.outc = nn.Conv2d(64, n_classes, 1)

    def forward(self, x):
        """Return the sigmoid-activated segmentation map for ``x``."""
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        # Each decoder step is paired with the encoder output of matching depth.
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        x = torch.sigmoid(x)
        return x

    def validate(
        self, col_name, round_num, input_tensor_dict, use_tqdm=False, **kwargs
    ):
        """Validate. Redefine function from PyTorchTaskRunner, to use our validation.

        Args:
            col_name: collaborator name, used as the origin of the metric
            round_num: current round number
            input_tensor_dict: tensors passed to ``rebuild_model``
            use_tqdm: wrap the validation loader in a tqdm progress bar
            **kwargs: must contain ``apply``; the value ``"local"`` tags the
                metric ``validate_local``, anything else ``validate_agg``

        Returns:
            Tuple of a dict mapping the ``dice_coef`` ``TensorKey`` to the
            mean Dice coefficient over the validation samples, and an empty
            dict.

        Raises:
            KeyError: if ``apply`` is missing from ``kwargs``.
        """
        self.rebuild_model(round_num, input_tensor_dict, validation=True)
        self.eval()
        self.to(self.device)
        val_score = 0
        total_samples = 0

        loader = self.data_loader.get_valid_loader()
        if use_tqdm:
            # tqdm is not among the notebook's imports, so import it here;
            # without this the branch failed with NameError.
            import tqdm
            loader = tqdm.tqdm(loader, desc="validate")

        with torch.no_grad():
            for data, target in loader:
                samples = target.shape[0]
                total_samples += samples
                # as_tensor accepts tensors and arrays alike without the
                # copy-construct warning torch.tensor() emits for tensors.
                data, target = (
                    torch.as_tensor(data).to(self.device),
                    torch.as_tensor(target).to(self.device),
                )
                output = self(data)
                # Dice coefficient accumulated over the samples of this batch
                val = soft_dice_coef(output, target)
                val_score += val.sum().cpu().numpy()

        origin = col_name
        suffix = "validate"
        if kwargs["apply"] == "local":
            suffix += "_local"
        else:
            suffix += "_agg"
        tags = ("metric", suffix)
        output_tensor_dict = {
            # Accumulated score divided by the sample count gives the mean
            TensorKey("dice_coef", origin, round_num, True, tags): np.array(
                val_score / total_samples
            )
        }
        return output_tensor_dict, {}


def optimizer(x):
    """Return an Adam optimizer with learning rate 1e-3 over the parameters ``x``."""
    return optim.Adam(x, lr=1e-3)

Create the `KvasirFederatedDataset` object. The federated datasets for the individual collaborators will be created by its `split()` method.

In [None]:
# Uses the constructor defaults (collaborator_count=1, collaborator_num=0);
# batch_size=6 is the batch size of its train and validation loaders
fl_data = KvasirFederatedDataset(batch_size=6)

The `FederatedModel` object is a wrapper around your Keras, TensorFlow or PyTorch model that makes it compatible with OpenFL. It provides a built-in federated training function, which will be used during training. Using its `setup` function, collaborator models and datasets can be obtained automatically for the experiment.

In [None]:
# Wrap the PyTorch model class, the optimizer factory, the loss function and
# the data object into a federated model
fl_model = FederatedModel(
    build_model=UNet,
    optimizer=optimizer,
    loss_fn=soft_dice_loss,
    data_loader=fl_data,
)

In [None]:
# Obtain the collaborator models for a two-collaborator experiment
collaborator_models = fl_model.setup(num_collaborators=2)
# Map collaborator names to their models; this dict is passed to fx.run_experiment
collaborators = {'one': collaborator_models[0], 'two': collaborator_models[1]}

We can see the current FL plan values by running the `fx.get_plan()` function

In [None]:
# Print the current values of the FL plan. Each of these can be overridden
print(fx.get_plan())

Now we are ready to run our experiment. If we want to pass in custom FL plan settings, we can easily do that with the `override_config` parameter

In [None]:
# Run the experiment with rounds_to_train overridden to 30 and keep the
# trained FederatedModel it returns
final_fl_model = fx.run_experiment(
    collaborators,
    override_config={'aggregator.settings.rounds_to_train': 30},
)

In [None]:
# Save the final model with save_native under the name 'final_pytorch_model'
final_fl_model.save_native('final_pytorch_model')

Let's visually evaluate the results

In [None]:
# Plot image / prediction / target triplets from the first collaborator's
# validation loader, using the final trained model
collaborator = collaborator_models[0]
loader = collaborator.runner.data_loader.get_valid_loader()
model = final_fl_model.model
device = final_fl_model.runner.device
model.eval()
model.to(device)
with torch.no_grad():
    # zip with range(5) limits the plots to the first five batches
    for (images, targets), _ in zip(loader, range(5)):
        preds = model(images.to(device))
        for image, pred, target in zip(images, preds, targets):
            panels = (
                # * 0.5 + 0.5 undoes the mean=0.5 / std=0.5 normalization
                ("img", image.permute(1, 2, 0).data.cpu().numpy() * 0.5 + 0.5),
                ("pred", pred[0].data.cpu().numpy()),
                ("targ", target[0].data.cpu().numpy()),
            )
            plt.figure(figsize=(10, 10))
            # Subplot codes 131, 132, 133: one row, three columns
            for position, (title, pixels) in enumerate(panels, start=131):
                plt.subplot(position)
                plt.imshow(pixels)
                plt.title(title)
            plt.show()