# Federated PyTorch UNET Tutorial
## Using low-level Python API

In [1]:
# Install dependencies if not already installed
!pip install torchvision==0.8.1
!pip install scikit-image

You should consider upgrading via the '/home/davidyuk/.virtualenvs/openfl/bin/python -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/home/davidyuk/.virtualenvs/openfl/bin/python -m pip install --upgrade pip' command.[0m


### Describe the model and optimizer

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim

In [12]:
"""
UNet model definition
"""
from layers import soft_dice_coef, soft_dice_loss, DoubleConv, Down, Up

class UNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=1):
        super().__init__()
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)
        self.outc = nn.Conv2d(64, n_classes, 1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        x = torch.sigmoid(x)
        return x
    
    # Lookup by parent module (as in current inspect)
    if hasattr(object, '__module__'):
        object_ = sys.modules.get(object.__module__)
        if hasattr(object_, '__file__'):
            return object_.__file__
    
    # If parent module is __main__, lookup by methods (NEW)
    for name, member in inspect.getmembers(object):
        if inspect.isfunction(member) and object.__qualname__ + '.' + member.__name__ == member.__qualname__:
            return inspect.getfile(member)
    else:
        raise TypeError('Source for {!r} not found'.format(object))
inspect.getfile = new_getfile

obj = UNet
cell_code = "".join(inspect.linecache.getlines(new_getfile(obj)))
class_code = extract_symbols(cell_code, obj.__name__)[0][0]
print(class_code)

In [14]:
optimizer_adam = optim.Adam(model_unet.parameters(), lr=1e-4)

### Prepare data

We ask user to keep all the test data in `data/` folder under the workspace as it will not be sent to collaborators

In [15]:
import os
from hashlib import sha384
import PIL
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as tsf

In [None]:
os.makedirs('data', exist_ok=True)
!wget -nc 'https://datasets.simula.no/hyper-kvasir/hyper-kvasir-segmented-images.zip' -O ./data/kvasir.zip
ZIP_SHA384 = 'e30d18a772c6520476e55b610a4db457237f151e'\
    '19182849d54b49ae24699881c1e18e0961f77642be900450ef8b22e7'
assert sha384(open('./data/kvasir.zip', 'rb').read(
    os.path.getsize('./data/kvasir.zip'))).hexdigest() == ZIP_SHA384
!unzip -n ./data/kvasir.zip -d ./data

In [10]:
rank_worldsize='3,4'
(int(num) for num in rank_worldsize.split(','))

<generator object <genexpr> at 0x7f616dce3f90>

In [40]:
DATA_PATH = './data/segmented-images/'
import numpy as np

def read_data(image_path, mask_path):
    """
    Read image and mask from disk.
    """
    img = Image.open(image_path)
    img = np.asarray(img)
    assert(img.shape[2] == 3)
    mask = Image.open(mask_path)
    mask = np.asarray(mask)
    return (img, mask[:, :, 0].astype(np.uint8))


class KvasirDataset(Dataset):
    """
    Kvasir dataset contains 1000 images for all collaborators.
    Args:
        data_path: path to dataset on disk
        collaborator_count: total number of collaborators
        collaborator_num: number of current collaborator
        is_validation: validation option
    """

    def __init__(self, images_path = './kvasir_data/segmented-images/images/', \
                        masks_path = './kvasir_data/segmented-images/masks/',
                        validation_fraction=1/8, is_validation=False):

        self.images_path = images_path
        self.masks_path = masks_path
        self.images_names = [
            img_name
            for img_name in sorted(os.listdir(self.images_path))
            if len(img_name) > 3 and img_name[-3:] == 'jpg'
        ]

        assert(len(self.images_names) > 2), "Too few images"
        
        validation_size = max(1, int(len(self.images_names) * validation_fraction))
        
        if is_validation:
            self.images_names = self.images_names[-validation_size :]
        else:
            self.images_names = self.images_names[: -validation_size]

        # Prepare transforms
        self.img_trans = tsf.Compose([
            tsf.ToPILImage(),
            tsf.Resize((332, 332)),
            tsf.ToTensor(),
            tsf.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
        self.mask_trans = tsf.Compose([
            tsf.ToPILImage(),
            tsf.Resize((332, 332), interpolation=PIL.Image.NEAREST),
            tsf.ToTensor()])
        

    def __getitem__(self, index):
        name = self.images_names[index]
        img, mask = read_data(self.images_path + name, self.masks_path + name)
        img = self.img_trans(img).numpy()
        mask = self.mask_trans(mask).numpy()
        return img, mask

    def __len__(self):
        return len(self.images_names)

In [41]:
img = Image.open('./kvasir_data/segmented-images/images/cju0qkwl35piu0993l0dewei2.jpg')
img.resize()

In [42]:
dset = KvasirDataset()
dset[0][0]



array([[[-1., -1., -1., ..., -1., -1., -1.],
        [-1., -1., -1., ..., -1., -1., -1.],
        [-1., -1., -1., ..., -1., -1., -1.],
        ...,
        [-1., -1., -1., ..., -1., -1., -1.],
        [-1., -1., -1., ..., -1., -1., -1.],
        [-1., -1., -1., ..., -1., -1., -1.]],

       [[-1., -1., -1., ..., -1., -1., -1.],
        [-1., -1., -1., ..., -1., -1., -1.],
        [-1., -1., -1., ..., -1., -1., -1.],
        ...,
        [-1., -1., -1., ..., -1., -1., -1.],
        [-1., -1., -1., ..., -1., -1., -1.],
        [-1., -1., -1., ..., -1., -1., -1.]],

       [[-1., -1., -1., ..., -1., -1., -1.],
        [-1., -1., -1., ..., -1., -1., -1.],
        [-1., -1., -1., ..., -1., -1., -1.],
        ...,
        [-1., -1., -1., ..., -1., -1., -1.],
        [-1., -1., -1., ..., -1., -1., -1.],
        [-1., -1., -1., ..., -1., -1., -1.]]], dtype=float32)

### Define Federated Learning tasks

In [7]:
def train(unet_model, train_loader, optimizer, device, loss_fn=soft_dice_loss):
    
    function_defined_in_notebook()
    
    unet_model.train()
    unet_model.to(device)

    losses = []

    for data, target in train_loader:
        data, target = torch.tensor(data).to(device), torch.tensor(
            target).to(device, dtype=torch.float32)
        optimizer.zero_grad()
        output = unet_model(data)
        loss = loss_fn(output=output, target=target)
        loss.backward()
        optimizer.step()
        losses.append(loss.detach().cpu().numpy())
        
    return {'train_loss': np.mean(losses),}


def validate(unet_model, val_loader, device):
    unet_model.eval()
    unet_model.to(device)

    val_score = 0
    total_samples = 0

    with torch.no_grad():
        for data, target in val_loader:
            samples = target.shape[0]
            total_samples += samples
            data, target = torch.tensor(data).to(device), \
                torch.tensor(target).to(device, dtype=torch.int64)
            output = unet_model(data)
            val = soft_dice_coef(output, target)
            val_score += val.sum().cpu().numpy()
            
    return {'dice_coef': val_score / total_samples,}

## Describing FL experiment

In [1]:
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment

In [2]:
from torch.utils.data import Dataset, DataLoader
class A(Dataset, DataInterface):
    pass
a = A()
# dir(a)

### Register model

In [9]:
from copy import deepcopy

framework_adapter = 'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'
MI = ModelInterface(model=model_unet, optimizer=optimizer_adam, framework_plugin=framework_adapter)

# Save the initial model state
initial_model = deepcopy(model_unet)

### Register dataset

We extract User dataset class implementation.
Is it convinient?
What if the dataset is not a class?

In [10]:
class UserDataset:
    def __init__(self, path_to_local_data):
        print(f'User Dataset initialized with {path_to_local_data}')
        
        
class OpenflMixin:   
    def _delayed_init(self):
        raise NotImplementedError
        
        
class FedDataset(OpenflMixin):
    def __init__(self, UserDataset):
        self.user_dataset_class = UserDataset
        print('We implement all abstract methods from mixin in this class')
        
    def _delayed_init(self, data_path):
        print('This method is called on the collaborator node')
        dataset_obj = self.user_dataset_class(data_path)
        
        
fed_dataset = FedDataset(UserDataset)
fed_dataset._delayed_init('data path on the collaborator node')

We implement all abstract methods from mixin in this class
This method is called on the collaborator node
User Dataset initialized with data path on the collaborator node


In [11]:
# class FedDataset(DataInterface):
#     def __init__(self, UserDatasetClass, **kwargs):
#         self.UserDatasetClass = UserDatasetClass
#         self.kwargs = kwargs
    
#     def _delayed_init(self, data_path='1,1'):
#         # With the next command the local dataset will be loaded on the collaborator node
#         # For this example we have the same dataset on the same path, and we will shard it
#         # So we use `data_path` information for this purpose.
#         self.rank, self.world_size = [int(part) for part in data_path.split(',')]
        
#         validation_fraction=1/8
#         self.train_set = self.UserDatasetClass(validation_fraction=validation_fraction, is_validation=False)
#         self.valid_set = self.UserDatasetClass(validation_fraction=validation_fraction, is_validation=True)
        
#         # Do the actual sharding
#         self._do_sharding( self.rank, self.world_size)
        
#     def _do_sharding(self, rank, world_size):
#         # This method relies on the dataset's implementation
#         # i.e. coupled in a bad way
#         self.train_set.images_names = self.train_set.images_names[ rank-1 :: world_size ]

#     def get_train_loader(self, **kwargs):
#         """
#         Output of this method will be provided to tasks with optimizer in contract
#         """
#         return DataLoader(
#             self.train_set, num_workers=8, batch_size=self.kwargs['train_bs'], shuffle=True
#             )

#     def get_valid_loader(self, **kwargs):
#         """
#         Output of this method will be provided to tasks without optimizer in contract
#         """
#         return DataLoader(self.valid_set, num_workers=8, batch_size=self.kwargs['valid_bs'])

#     def get_train_data_size(self):
#         """
#         Information for aggregation
#         """
#         return len(self.train_set)

#     def get_valid_data_size(self):
#         """
#         Information for aggregation
#         """
#         return len(self.valid_set)
    
# fed_dataset = FedDataset(KvasirDataset, train_bs=8, valid_bs=8)

### Register tasks

In [None]:
TI = TaskInterface()
import torch

import tqdm

# The Interactive API supports registering functions definied in main module or imported.
def function_defined_in_notebook(some_parameter):
    print(f'Also I accept a parameter and it is {some_parameter}')

# Task interface currently supports only standalone functions.
@TI.add_kwargs(**{'some_parameter': 42})
@TI.register_fl_task(model='unet_model', data_loader='train_loader', \
                     device='device', optimizer='optimizer')     
def train(unet_model, train_loader, optimizer, device, loss_fn=soft_dice_loss, some_parameter=None):
    if not torch.cuda.is_available():
        device = 'cpu'
    
    function_defined_in_notebook(some_parameter)
    
    train_loader = tqdm.tqdm(train_loader, desc="train")
    
    unet_model.train()
    unet_model.to(device)

    losses = []

    for data, target in train_loader:
        data, target = torch.tensor(data).to(device), torch.tensor(
            target).to(device, dtype=torch.float32)
        optimizer.zero_grad()
        output = unet_model(data)
        loss = loss_fn(output=output, target=target)
        loss.backward()
        optimizer.step()
        losses.append(loss.detach().cpu().numpy())
        
    return {'train_loss': np.mean(losses),}


@TI.register_fl_task(model='unet_model', data_loader='val_loader', device='device')     
def validate(unet_model, val_loader, device):
    unet_model.eval()
    unet_model.to(device)
    
    val_loader = tqdm.tqdm(val_loader, desc="validate")

    val_score = 0
    total_samples = 0

    with torch.no_grad():
        for data, target in val_loader:
            samples = target.shape[0]
            total_samples += samples
            data, target = torch.tensor(data).to(device), \
                torch.tensor(target).to(device, dtype=torch.int64)
            output = unet_model(data)
            val = soft_dice_coef(output, target)
            val_score += val.sum().cpu().numpy()
            
    return {'dice_coef': val_score / total_samples,}

## Time to start a federated learning experiment

In [1]:
# Create a federation
from openfl.interface.interactive_api.federation import Federation

# 1) Run with API layer - Director mTLS 
# If the user wants to enable mTLS their must provide CA root chain, and signed key pair to the federation interface
cert_chain = 'cert/cert_chain.crt'
API_certificate = 'cert/API_certificate.crt'
API_private_key = 'cert/API_private.key'

federation = Federation(director_node_fqdn='some.fqdn', disable_tls=False,
                       cert_chain=cert_chain, API_cert=API_certificate, API_private_key=API_private_key)

# --------------------------------------------------------------------------------------------------------------------

# 2) Run with TLS disabled (trusted environment)
# Federation can also determine local fqdn automatically
federation = Federation(director_node_fqdn='localhost', disable_tls=True)

In [2]:
# First, request a dummy_shard_desc that holds information about the federated dataset 
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)

(300, 300, 3) (6,) 10


In [None]:
"1,6"

In [17]:
# Now you can implement you data loaders using dummy_shard_desc
class KvasirSD(DataInterface, Dataset):

    def __init__(self, shard_descriptor, validation_fraction=1/8, is_validation=False, **kwargs):
        super().__init__(**kwargs)

        self.shard_descriptor = shard_descriptor
        
        self.validation_fraction = validation_fraction
        
        validation_size = max(1, int(len(self.shard_descriptor) * validation_fraction))
        self.train_indeces = list(range())
        
#         if is_validation:
#             self.images_names = self.images_names[-validation_size :]
#         else:
#             self.images_names = self.images_names[: -validation_size]
        
        # Prepare transforms
        self.img_trans = tsf.Compose([
            tsf.ToPILImage(),
            tsf.Resize((332, 332)),
            tsf.ToTensor(),
            tsf.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
        self.mask_trans = tsf.Compose([
            tsf.ToPILImage(),
            tsf.Resize((332, 332), interpolation=PIL.Image.NEAREST),
            tsf.ToTensor()])
        

    def __getitem__(self, index):
        img, mask = self.shard_descriptor[index]
        img = self.img_trans(img).numpy()
        mask = self.mask_trans(mask).numpy()
        return img, mask

    def __len__(self):
        return len(self.shard_descriptor)
    
    
    def get_train_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks with optimizer in contract
        """
        return DataLoader(
            self.train_set, num_workers=8, batch_size=self.kwargs['train_bs'], shuffle=True
            )

    def get_valid_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks without optimizer in contract
        """
        return DataLoader(self.valid_set, num_workers=8, batch_size=self.kwargs['valid_bs'])

    def get_train_data_size(self):
        """
        Information for aggregation
        """
        return len(self.train_set)

    def get_valid_data_size(self):
        """
        Information for aggregation
        """
        return len(self.valid_set)

NameError: name 'DataInterface' is not defined

SyntaxError: invalid syntax (<ipython-input-22-9778031892f4>, line 1)

# Long-Living entities update

* We now may have director running on another machine.
* We use Federation API to communicate with Director.
* Federation object should hold a Director's client (for user service)
* Keeping in mind that several API instances may be connacted to one Director.


* We do not think for now how we start a Director.
* But it knows the data shape and target shape for the DataScience problem in the Federation.
* Director holds the list of connected envoys, we do not need to specify it anymore.
* Director and Envoys are responsible for encrypting connections, we do not need to worry about certs.


* Yet we MUST have a cert to communicate to the Director.
* We MUST know the FQDN of a Director.
* Director communicates data and target shape to the Federation interface object.


* Experiment API may use this info to construct a dummy dataset and a `shard descriptor` stub.

#### Certification of an aggregator
* fx workspace certify: creates cert folder and CA as well as cert_chain
* fx aggregator generate-cert-request --fqdn `FQDN`: you can pass a specific aggregator FQDN if you want
* fx aggregator certify --fqdn `FQDN` --silent: signes aggregators cert
<br> After that just pass the paths to required certs to the Federation API

#### Certification of a collaborator
just follow the usual procedure: <br>
fx collaborator generate-cert-request -d {DATA_PATH} -n {COL} 

fx collaborator certify --request-pkg {COL_DIRECTORY}/{FED_WORKSPACE}/col_{COL}_to_agg_cert_request.zip

fx collaborator certify --import {FED_DIRECTORY}/agg_to_col_{COL}_signed_cert.zip

In [2]:
# Interface generator
class IFO:
    def __init__(self):
        self.a = 1
    
class EXP:
    def __init__(self):
        self.ifo = IFO()
        
    def get_ifo(self):
        return self.ifo
        
exp = EXP()
ifo = exp.get_ifo()
ifo.a = 2

# print(exp.ifo.a)
exp2 = exp.__class__()
exp.ifo.a

2

In [None]:
# create an experimnet in federation
fl_experiment = FLExperiment(federation=federation,)

In [None]:
# If I use autoreload I got a pickling error

# The following command zips the workspace and python requirements to be transfered to collaborator nodes
fl_experiment.prepare_workspace_distribution(model_provider=MI, task_keeper=TI, data_loader=fed_dataset, rounds_to_train=7, \
                              opt_treatment='CONTINUE_GLOBAL')
# This command starts the aggregator server. You can pass 'metric' log_level
fl_experiment.start_experiment(model_provider=MI, log_level='INFO', log_file='federation.log')

# When the aggregator server blocks the notebook one can start collaborators
# For the test run just type console command from the workspace directory:
# `fx collaborator start -d data.yaml -n {col_name}` for all collaborators
# For the distributed experiment transfer zipped workspace to the collaborator nodes and run
# `fx workspace import --archive {workspace_name}.zip` cd to the workspace and start collaborators

## Now we validate the best model!

In [None]:
best_model = fl_experiment.get_best_model()

In [None]:
fed_dataset._delayed_init()

In [None]:
# Validating initial model
validate(initial_model, fed_dataset.get_valid_loader(), 'cpu')

In [None]:
# Validating trained model
validate(best_model, fed_dataset.get_valid_loader(), 'cpu')

## We can tune model further!

In [None]:
MI = ModelInterface(model=best_model, optimizer=optimizer_adam, framework_plugin=framework_adapter)
fl_experiment.start_experiment(model_provider=MI, task_keeper=TI, data_loader=fed_dataset, rounds_to_train=4, \
                              opt_treatment='CONTINUE_GLOBAL')

In [None]:
best_model = fl_experiment.get_best_model()
# Validating trained model
validate(best_model, fed_dataset.get_valid_loader(), 'cpu')

In [40]:
a = (np.zeros((2,4)), np.ones(2,), 2*np.ones(2,))
a = [elem for elem in a if elem.shape==2 else elem.newaxis()]
np.concatenate(a)

SyntaxError: invalid syntax (<ipython-input-40-b0f7ed81de85>, line 2)

In [15]:
class A:
    def __init__(self, **kwargs):
        print("Class A", kwargs)
        
class B:
    def __init__(self, **kwargs):
        super().__init__()
        print("Class B", kwargs)
        
class C(B, A):
    def __init__(self, **kwargs):
        super().__init__()
        print("Class C", kwargs)
        
# class A:
#     def __init__(self, **kwargs):
#         print("Class A", kwargs)

In [16]:
c = C(x=1, z=5)

Class A {}
Class B {}
Class C {'x': 1, 'z': 5}
