# Federated PyTorch UNET Tutorial using Workflow API

In [None]:
# Install dependencies if not already installed
%pip install torch
%pip install matplotlib
%pip install ray

In [None]:
# Import installed modules
import PIL
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from skimage import io
from torchvision import transforms as tsf
from torch.utils.data import Dataset, DataLoader

from os import listdir

from openfl.experimental.workflow.interface import FLSpec, Aggregator, Collaborator
from openfl.experimental.workflow.runtime import LocalRuntime
from openfl.experimental.workflow.placement import aggregator, collaborator
from openfl.utilities import validate_file_hash

Download and unpack the Kvasir dataset:

In [None]:
# Fetch the segmented-images archive and store it as kvasir.zip (-O sets the output file name).
!wget 'https://datasets.simula.no/downloads/hyper-kvasir/hyper-kvasir-segmented-images.zip' -O kvasir.zip
# Expected digest of the archive (96 hex characters), split over two literals to keep lines short.
ZIP_SHA384 = ('66cd659d0e8afd8c83408174'
            '1ade2b75dada8d4648b816f2533c8748b1658efa3d49e205415d4116faade2c5810e241e')
# Check the downloaded file against the expected digest before unpacking it.
# NOTE(review): presumably this raises on a mismatch -- confirm against openfl.utilities.
validate_file_hash('./kvasir.zip', ZIP_SHA384)
# Unpack into ./data; -n never overwrites files that already exist.
!unzip -n kvasir.zip -d ./data

Now we are ready to define our dataset and model to perform federated learning on.

In [4]:
# Root folder of the extracted dataset; KvasirDataset appends 'images/' and 'masks/' to it.
DATA_PATH = './data/segmented-images/'

In [5]:
def read_data(image_path, mask_path):
    """
    Load an image and its segmentation mask from disk.

    Args:
        image_path: path of the image file
        mask_path: path of the matching mask file

    Returns:
        Tuple ``(image, mask)``: the image exactly as returned by
        ``io.imread`` and the first channel of the mask cast to ``np.uint8``.

    Raises:
        AssertionError: if the third dimension of the image is not 3.
    """
    image = io.imread(image_path)
    assert image.shape[2] == 3
    # Only the first channel of the mask file is kept.
    first_mask_channel = io.imread(mask_path)[:, :, 0]
    return image, first_mask_channel.astype(np.uint8)


class KvasirDataset(Dataset):
    """
    Kvasir dataset contains 1000 images for all collaborators.

    The sorted list of ``jpg`` files found under ``DATA_PATH + 'images/'`` is
    dealt out round-robin: a collaborator keeps every ``collaborator_count``-th
    file starting at index ``collaborator_num``. The last eighth of that shard
    is the validation split and the remainder is the training split.

    Args:
        collaborator_count: total number of collaborators
        collaborator_num: zero-based index of the current collaborator
        is_validation: if True expose the validation split, otherwise the
            training split

    Raises:
        AssertionError: if the collaborator's shard holds 8 images or fewer,
            because the validation split would then be empty.
    """

    def __init__(self, collaborator_count, collaborator_num, is_validation):
        self.images_path = DATA_PATH + 'images/'
        self.masks_path = DATA_PATH + 'masks/'
        self.images_names = [
            img_name
            for img_name in sorted(listdir(self.images_path))
            if len(img_name) > 3 and img_name.endswith('jpg')
        ]

        # Round-robin sharding between collaborators.
        self.images_names = self.images_names[collaborator_num::collaborator_count]
        self.is_validation = is_validation
        assert len(self.images_names) > 8, (
            'each collaborator needs more than 8 images to get a non-empty validation split'
        )
        validation_size = len(self.images_names) // 8
        if is_validation:
            self.images_names = self.images_names[-validation_size:]
        else:
            self.images_names = self.images_names[:-validation_size]

        self.img_trans = tsf.Compose([
            tsf.ToPILImage(),
            tsf.Resize((332, 332)),
            tsf.ToTensor(),
            tsf.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
        # Nearest-neighbour resizing keeps mask values from being blended.
        # torchvision expects its own InterpolationMode enum here; integer PIL
        # constants are only accepted as a deprecated fallback.
        self.mask_trans = tsf.Compose([
            tsf.ToPILImage(),
            tsf.Resize((332, 332), interpolation=tsf.InterpolationMode.NEAREST),
            tsf.ToTensor()])

    def __getitem__(self, index):
        """Return the transformed ``(image, mask)`` pair at ``index`` as numpy arrays."""
        name = self.images_names[index]
        # Image and mask share the same file name in their respective folders.
        img, mask = read_data(self.images_path + name, self.masks_path + name)
        img = self.img_trans(img).numpy()
        mask = self.mask_trans(mask).numpy()
        return img, mask

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.images_names)

Define the model:

In [11]:
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class Down(nn.Module):
    """Encoder stage: 2x2 max pooling followed by a ``DoubleConv``."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        pooling = nn.MaxPool2d(2)
        double_conv = DoubleConv(in_ch, out_ch)
        self.mpconv = nn.Sequential(pooling, double_conv)

    def forward(self, x):
        """Pool ``x`` and pass the result through the double convolution."""
        return self.mpconv(x)


class Up(nn.Module):
    """Decoder stage: upsample, pad to the skip tensor's size, concatenate, convolve.

    With ``bilinear=False`` (the default) a stride-2 transposed convolution
    halves the channel count while upsampling; with ``bilinear=True`` an
    ``nn.Upsample`` layer is used instead. The concatenated tensor is fed to
    ``DoubleConv(in_ch, out_ch)``.
    """

    def __init__(self, in_ch, out_ch, bilinear=False):
        super().__init__()
        self.in_ch, self.out_ch = in_ch, out_ch
        self.Up = (
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            if bilinear
            else nn.ConvTranspose2d(in_ch, in_ch // 2, 2, stride=2)
        )
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it with the skip tensor ``x2`` and fuse both."""
        upsampled = self.Up(x1)

        # Height/width gap between the skip tensor and the upsampled tensor,
        # split between the two sides of each axis.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        left, top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, (left, gap_w - left, top, gap_h - top))

        return self.conv(torch.cat([x2, upsampled], dim=1))


class UNet(nn.Module):
    """U-Net built from one input ``DoubleConv``, four ``Down`` and four ``Up`` stages.

    Args:
        n_channels: number of channels of the input tensor
        n_classes: number of output channels produced by the final 1x1 convolution
    """

    def __init__(self, n_channels=3, n_classes=1):
        super().__init__()
        # Encoder
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        # Decoder
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)
        # Per-pixel output head
        self.outc = nn.Conv2d(64, n_classes, 1)

    def forward(self, x):
        """Run the encoder/decoder and return the sigmoid of the output head."""
        # Encoder pass: keep every intermediate activation as a skip connection.
        skips = [self.inc(x)]
        for down_stage in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down_stage(skips[-1]))

        # Decoder pass: start from the deepest activation and consume the
        # remaining skips from deepest to shallowest.
        features = skips.pop()
        for up_stage in (self.up1, self.up2, self.up3, self.up4):
            features = up_stage(features, skips.pop())

        return torch.sigmoid(self.outc(features))

The next step is to set up the participants (an `Aggregator` and a few `Collaborator`s that will train the model), partition the dataset among the collaborators, and pass the participants to the appropriate runtime environment (in our case, a `LocalRuntime`).


In [12]:
# The aggregator keeps no private data of its own.
aggregator_ = Aggregator()
aggregator_.private_attributes = {}

# Two collaborators; each one receives loaders over its own shard of the dataset.
collaborator_names = ['collaborator' + str(i) for i in range(2)]
collaborators = [Collaborator(name=name) for name in collaborator_names]

for collaborator_idx, collaborator_ in enumerate(collaborators):
    train_set = KvasirDataset(len(collaborators), collaborator_idx, is_validation=False)
    test_set = KvasirDataset(len(collaborators), collaborator_idx, is_validation=True)
    collaborator_.private_attributes = {
        'train_loader': DataLoader(train_set, num_workers=8, batch_size=6, shuffle=True),
        'test_loader': DataLoader(test_set, num_workers=8, batch_size=6),
    }

# Run every participant inside the current process.
local_runtime = LocalRuntime(
    aggregator=aggregator_,
    collaborators=collaborators,
    backend='single_process',
)

Define an aggregation algorithm, optimizer and a loss function:

In [13]:
def FedAvg(models, weights=None):
    """
    Federated averaging of model parameters.

    Every entry of the state dict (parameters and buffers) is averaged
    across ``models`` and cast back to the dtype/device of the first model's
    entry. The input models are left untouched.

    Args:
        models (list[torch.nn.Module]): Models to aggregate; all must share
            the same architecture (identical state-dict keys and shapes).
        weights (list[float], optional): One weight per model.
            Defaults to equal weights if not specified.

    Returns:
        torch.nn.Module: New model (a copy of the first model, in training
        mode and without gradients) holding the averaged state.

    Raises:
        ValueError: If ``models`` is empty.
    """
    # Local import keeps this cell self-contained; `copy` is standard library.
    import copy

    models = list(models)
    if not models:
        raise ValueError("FedAvg requires at least one model to aggregate")

    # Build each state dict once instead of once per key.
    state_dicts = [model.state_dict() for model in models]

    averaged_state = {}
    for key, reference in state_dicts[0].items():
        stacked = np.stack([sd[key].detach().cpu().numpy() for sd in state_dicts])
        averaged = np.average(stacked, weights=weights, axis=0)
        # Cast back so integer buffers (e.g. BatchNorm counters) keep their dtype.
        averaged_state[key] = torch.tensor(averaged, dtype=reference.dtype).to(reference.device)

    # Copying the first model works for constructors that take arguments and
    # keeps the model on its original device.
    new_model = copy.deepcopy(models[0])
    for param in new_model.parameters():
        param.grad = None  # do not carry over stale gradients from the copy
    new_model.train()  # same mode as a freshly constructed module
    new_model.load_state_dict(averaged_state)
    return new_model

def get_optimizer(model, lr=1e-3):
    """
    Create a fresh Adam optimizer over all parameters of ``model``.

    Args:
        model (torch.nn.Module): Model whose parameters are optimized.
        lr (float, optional): Learning rate. Defaults to 1e-3.

    Returns:
        torch.optim.Adam: Newly constructed optimizer.
    """
    return optim.Adam(model.parameters(), lr=lr)

def soft_dice_loss(output, target):
    """
    Soft Dice loss: one minus the mean per-sample soft Dice coefficient.

    Args:
        output (torch.Tensor): Predictions; the first dimension is the batch.
        target (torch.Tensor): Ground truth with the same number of elements
            per sample as ``output``.

    Returns:
        torch.Tensor: Scalar loss; 0 for a perfect prediction.
    """
    num = target.size(0)
    # Delegate to soft_dice_coef so the loss and the reported metric cannot drift apart.
    return 1 - soft_dice_coef(output, target).sum() / num

def soft_dice_coef(output, target):
    """
    Per-sample soft Dice coefficient with additive smoothing of 1.

    Computes ``(2 * sum(output * target) + 1) / (sum(output) + sum(target) + 1)``
    for every sample. The smoothing term is added once to the numerator and
    once to the denominator, so the result is exactly 1 when prediction and
    target agree, including the case where both are empty.

    Args:
        output (torch.Tensor): Predictions; the first dimension is the batch.
        target (torch.Tensor): Ground truth with the same number of elements
            per sample as ``output``.

    Returns:
        torch.Tensor: 1-D tensor with one coefficient per sample.
    """
    num = target.size(0)
    # reshape (unlike view) also accepts non-contiguous tensors.
    m1 = output.reshape(num, -1)
    m2 = target.reshape(num, -1)
    intersection = (m1 * m2).sum(1)
    return (2.0 * intersection + 1) / (m1.sum(1) + m2.sum(1) + 1)

Set up work to be executed by the aggregator and the collaborators by extending `FLSpec`:

In [None]:
class FederatedFlow(FLSpec):
    """
    Federated training flow for the segmentation model.

    One round consists of: validating the aggregated model on every
    collaborator, one pass over each collaborator's ``train_loader``,
    validating the locally trained model, and averaging the collaborator
    models on the aggregator with ``FedAvg``. The round repeats until
    ``rounds`` rounds have completed.

    Args:
        model: Model to train.
        optimizer: Initial optimizer (replaced by a fresh one at the start of
            every ``train`` step).
        rounds: Number of federation rounds to run.
        **kwargs: Forwarded to ``FLSpec``.
    """

    def __init__(self, model=None, optimizer=None, rounds=10, **kwargs):
        super().__init__(**kwargs)
        self.model = model
        self.optimizer = optimizer
        self.n_rounds = rounds
        self.loss = 0.

    @aggregator
    def start(self):
        """Initialise round bookkeeping and fan out to all collaborators."""
        print('Performing initialization for model')
        self.collaborators = self.runtime.collaborators
        self.current_round = 0
        self.next(self.aggregated_model_validation, foreach='collaborators')

    def evaluate_segmentation_accuracy(self, data_loader):
        """
        Average soft Dice coefficient of ``self.model`` over ``data_loader``.

        Args:
            data_loader: Iterable yielding ``(data, target)`` batches.

        Returns:
            float: Mean per-sample Dice coefficient, or 0.0 if the loader
            yields no samples.
        """
        self.model.eval()
        total_dice_score = 0.0
        total_samples = 0
        with torch.no_grad():
            for batch_data, batch_target in data_loader:
                total_samples += batch_target.shape[0]
                output = self.model(batch_data)
                # .item() yields a plain Python float, so the score stored on
                # the flow is not a numpy scalar.
                total_dice_score += soft_dice_coef(output, batch_target).sum().item()

        if total_samples == 0:
            print("\nValidation set is empty. Returning score as 0.0\n")
            return 0.0

        avg_dice_score = total_dice_score / total_samples
        print(f"\nValidation Results: Average Dice Coefficient: {avg_dice_score:.4f} (over {total_samples} samples)\n")
        return avg_dice_score

    @collaborator
    def aggregated_model_validation(self):
        """Score the model received from the aggregator on the local test loader."""
        print(f'Performing aggregated model validation for collaborator {self.input}, model: {id(self.model)}')
        self.agg_validation_score = self.evaluate_segmentation_accuracy(self.test_loader)

        self.next(self.train)

    @collaborator
    def train(self):
        """
        Train ``self.model`` for one pass over ``self.train_loader``.

        ``self.loss`` is set to the loss of the last processed batch, so it is
        refreshed every round even when the loader is too short to reach a
        logging threshold. If the loader yields nothing, the previous value
        is kept.
        """
        # Log after processing a quarter of the samples
        log_threshold = .25

        self.model.train()

        # A new optimizer is created each round because the model was replaced
        # by the aggregated one.
        self.optimizer = get_optimizer(self.model)
        last_loss = None
        for batch_idx, (data, target) in enumerate(self.train_loader):
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = soft_dice_loss(output, target)
            loss.backward()
            self.optimizer.step()
            last_loss = loss.item()

            if (len(data) * batch_idx) / len(self.train_loader.dataset) >= log_threshold:
                print('Train Epoch: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    batch_idx * len(data), len(self.train_loader.dataset),
                    100. * batch_idx / len(self.train_loader), last_loss))
                log_threshold += .25
                # The file names are fixed, so each save overwrites the previous one.
                torch.save(self.model.state_dict(), 'model.pth')
                torch.save(self.optimizer.state_dict(), 'optimizer.pth')

        if last_loss is not None:
            self.loss = last_loss

        self.next(self.local_model_validation)

    @collaborator
    def local_model_validation(self):
        """Score the locally trained model on the local test loader."""
        print(f'Performing local model validation for collaborator {self.input}')
        self.local_validation_score = self.evaluate_segmentation_accuracy(self.test_loader)
        print(
            f'Done with local model validation for collaborator {self.input}, Accuracy: {self.local_validation_score}')
        self.next(self.join)

    @aggregator
    def join(self, inputs):
        """
        Average the collaborator models and metrics, then start the next round or finish.

        Args:
            inputs: One flow state per collaborator, each exposing ``model``,
                ``optimizer``, ``loss``, ``agg_validation_score`` and
                ``local_validation_score``.
        """
        print('joining')
        self.model = FedAvg([collab_state.model for collab_state in inputs])
        self.optimizer = inputs[0].optimizer
        self.current_round += 1

        n_inputs = len(inputs)
        self.average_loss = sum(collab_state.loss for collab_state in inputs) / n_inputs
        self.aggregated_model_accuracy = sum(
            collab_state.agg_validation_score for collab_state in inputs) / n_inputs
        self.local_model_accuracy = sum(
            collab_state.local_validation_score for collab_state in inputs) / n_inputs
        print(f'Average aggregated model accuracy = {self.aggregated_model_accuracy}')
        print(f'Average training loss = {self.average_loss}')
        print(f'Average local model validation values = {self.local_model_accuracy}')

        if self.current_round < self.n_rounds:
            self.next(self.aggregated_model_validation, foreach='collaborators')
        else:
            self.next(self.end)

    @aggregator
    def end(self):
        """Terminal step of the flow."""
        print('Flow ended')

Finally, run the federation:

In [None]:
# Build the initial global model and run 30 rounds on the local runtime.
model = UNet()
flflow = FederatedFlow(
    model=model,
    optimizer=get_optimizer(model),
    rounds=30,
    checkpoint=False,
)
flflow.runtime = local_runtime
flflow.run()