# Federated PyTorch UNET Tutorial
## Using low-level Python API

In [1]:
# %load_ext autoreload
# %autoreload 2

In [2]:
# Install dependencies if not already installed
!pip install torchvision
!pip install scikit-image
!pip install dill



### Describe the model and optimizer

In [3]:
import torch.nn as nn
import torch.optim as optim
from model import UNet, soft_dice_loss, soft_dice_coef

In [4]:
"""
A pytorch
"""
model_unet = UNet()

In [5]:
optimizer_adam = optim.Adam(model_unet.parameters(), lr=1e-3)

### Prepare data

We ask user to keep all the test data in `data/` folder under the workspace as it will not be sent to collaborators

In [6]:
import os
from hashlib import sha384
import PIL
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as tsf
from skimage import io

In [7]:
os.makedirs('data', exist_ok=True)
!wget -nc 'https://datasets.simula.no/hyper-kvasir/hyper-kvasir-segmented-images.zip' -O ./data/kvasir.zip
ZIP_SHA384 = 'e30d18a772c6520476e55b610a4db457237f151e'\
    '19182849d54b49ae24699881c1e18e0961f77642be900450ef8b22e7'
assert sha384(open('./data/kvasir.zip', 'rb').read(
    os.path.getsize('./data/kvasir.zip'))).hexdigest() == ZIP_SHA384
!unzip -n ./data/kvasir.zip -d ./data

File ‘./data/kvasir.zip’ already there; not retrieving.
Archive:  ./data/kvasir.zip


In [8]:
DATA_PATH = './data/segmented-images/'
import numpy as np

def read_data(image_path, mask_path):
    """
    Read image and mask from disk.
    """
    img = io.imread(image_path)
    assert(img.shape[2] == 3)
    mask = io.imread(mask_path)
    return (img, mask[:, :, 0].astype(np.uint8))


class KvasirDataset(Dataset):
    """
    Kvasir dataset contains 1000 images for all collaborators.
    Args:
        data_path: path to dataset on disk
        collaborator_count: total number of collaborators
        collaborator_num: number of current collaborator
        is_validation: validation option
    """

#     def __init__(self, data_path, collaborator_count, collaborator_num, is_validation):
    def __init__(self, images_path = './data/segmented-images/images/', \
                        masks_path = './data/segmented-images/masks/',
                        validation_fraction=1/8, is_validation=False):

        self.images_path = images_path
        self.masks_path = masks_path
        self.images_names = [img_name for img_name in sorted(os.listdir(
            self.images_path)) if len(img_name) > 3 and img_name[-3:] == 'jpg']

        assert(len(self.images_names) > 2), "Too few images"
        
        validation_size = max(1, int(len(self.images_names) * validation_fraction))
        
        if is_validation:
            self.images_names = self.images_names[-validation_size :]
        else:
            self.images_names = self.images_names[: -validation_size]
        
        # Prepare transforms
        self.img_trans = tsf.Compose([
            tsf.ToPILImage(),
            tsf.Resize((332, 332)),
            tsf.ToTensor(),
            tsf.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
        self.mask_trans = tsf.Compose([
            tsf.ToPILImage(),
            tsf.Resize((332, 332), interpolation=PIL.Image.NEAREST),
            tsf.ToTensor()])
        

    def __getitem__(self, index):
        name = self.images_names[index]
        img, mask = read_data(self.images_path + name, self.masks_path + name)
        img = self.img_trans(img).numpy()
        mask = self.mask_trans(mask).numpy()
        return img, mask

    def __len__(self):
        return len(self.images_names)

### Define Federated Learning tasks

In [9]:
def function_defined_in_notebook():
    print('I will cause problems')

    
    
def train(unet_model, train_loader, optimizer, device, loss_fn=soft_dice_loss):
    
    function_defined_in_notebook()
    
    unet_model.train()
    unet_model.to(device)

    losses = []

    for data, target in train_loader:
        data, target = torch.tensor(data).to(device), torch.tensor(
            target).to(device, dtype=torch.float32)
        optimizer.zero_grad()
        output = unet_model(data)
        loss = loss_fn(output=output, target=target)
        loss.backward()
        optimizer.step()
        losses.append(loss.detach().cpu().numpy())
        
    return {'train_loss': np.mean(losses),}


def validate(unet_model, val_loader, device):
    unet_model.eval()
    unet_model.to(device)

    val_score = 0
    total_samples = 0

    with torch.no_grad():
        for data, target in val_loader:
            samples = target.shape[0]
            total_samples += samples
            data, target = torch.tensor(data).to(device), \
                torch.tensor(target).to(device, dtype=torch.int64)
            output = unet_model(data)
            val = soft_dice_coef(output, target)
            val_score += val.sum().cpu().numpy()
            
    return {'dice_coef': val_score / total_samples,}

## Describing FL experiment

In [10]:
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment

### Register model

In [11]:
framework_adapter = 'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'
MI = ModelInterface(model=model_unet, optimizer=optimizer_adam, framework_plugin=framework_adapter)

### Register dataset

We extract User dataset class implementation.
Is it convinient?
What if the dataset is not a class?

In [12]:
class UserDataset:
    def __init__(self, path_to_local_data):
        print(f'User Dataset initialized with {path_to_local_data}')
        
        
class OpenflMixin:   
    def _delayed_init(self):
        raise NotImplementedError
        
        
class FedDataset(UserDataset, OpenflMixin):
    def __init__(self):
        print('We implement all abstract methods from mixin in this class')
        
    def _delayed_init(self, data_path):
        print('This method is called on the collaborator node')
        super().__init__(data_path)
        
        
fed_dataset = FedDataset()
fed_dataset._delayed_init('data path on the collaborator node')

We implement all abstract methods from mixin in this class
This method is called on the collaborator node
User Dataset initialized with data path on the collaborator node


In [13]:
class FedDataset(DataInterface):
    def __init__(self, UserDatasetClass, **kwargs):
        self.UserDatasetClass = UserDatasetClass
        self.kwargs = kwargs
    
    def _delayed_init(self, data_path='1,1'):
        # With the next command the local dataset will be loaded on the collaborator node
        # For this example we have the same dataset on the same path, and we will shard it
        # So we use `data_path` information for this purpose.
        self.rank, self.world_size = [int(part) for part in data_path.split(',')]
        
        validation_fraction=1/8
        self.train_set = self.UserDatasetClass(validation_fraction=validation_fraction, is_validation=False)
        self.valid_set = self.UserDatasetClass(validation_fraction=validation_fraction, is_validation=True)
        
        # Do the actual sharding
        self._do_sharding( self.rank, self.world_size)
        
    def _do_sharding(self, rank, world_size):
        # This method relies on the dataset's implementation
        # i.e. coupled in a bad way
        self.train_set.images_names = self.train_set.images_names[ rank-1 :: world_size ]
        self.valid_set.images_names = self.valid_set.images_names[ rank-1 :: world_size ]
        

    def get_train_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks with optimizer in contract
        """
        return DataLoader(
            self.train_set, num_workers=8, batch_size=self.kwargs['train_bs'], shuffle=True
            )

    def get_valid_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks without optimizer in contract
        """
        return DataLoader(self.valid_set, num_workers=8, batch_size=self.kwargs['valid_bs'])

    def get_train_data_size(self):
        """
        Information for aggregation
        """
        return len(self.train_set)

    def get_valid_data_size(self):
        """
        Information for aggregation
        """
        return len(self.valid_set)
    
fed_dataset = FedDataset(KvasirDataset, train_bs=8, valid_bs=16)

### Register tasks

In [14]:
TI = TaskInterface()
import torch

import tqdm

def function_defined_in_notebook(some_parameter):
    print('I will cause problems')
    print(f'Also I accept a parameter and it is {some_parameter}')

# We do not actually need to register additional kwargs, Just serialize them
@TI.add_kwargs(**{'some_parameter': 42})
@TI.register_fl_task(model='unet_model', data_loader='train_loader', \
                     device='device', optimizer='optimizer')     
def train(unet_model, train_loader, optimizer, device, loss_fn=soft_dice_loss, some_parameter=None):
    if not torch.cuda.is_available():
        device = 'cpu'
    
    function_defined_in_notebook(some_parameter)
    
    train_loader = tqdm.tqdm(train_loader, desc="train")
    
    unet_model.train()
    unet_model.to(device)

    losses = []

    for data, target in train_loader:
        data, target = torch.tensor(data).to(device), torch.tensor(
            target).to(device, dtype=torch.float32)
        optimizer.zero_grad()
        output = unet_model(data)
        loss = loss_fn(output=output, target=target)
        loss.backward()
        optimizer.step()
        losses.append(loss.detach().cpu().numpy())
        
    return {'train_loss': np.mean(losses),}


@TI.register_fl_task(model='unet_model', data_loader='val_loader', device='device')     
def validate(unet_model, val_loader, device):
    unet_model.eval()
    unet_model.to(device)
    
    val_loader = tqdm.tqdm(val_loader, desc="validate")

    val_score = 0
    total_samples = 0

    with torch.no_grad():
        for data, target in val_loader:
            samples = target.shape[0]
            total_samples += samples
            data, target = torch.tensor(data).to(device), \
                torch.tensor(target).to(device, dtype=torch.int64)
            output = unet_model(data)
            val = soft_dice_coef(output, target)
            val_score += val.sum().cpu().numpy()
            
    return {'dice_coef': val_score / total_samples,}


# @TI.register_fl_task(model='unet_model', data_loader='val_loader', device='device')     
# def test_task(np_array):
#     linear = nn.Linear(10, 5)
#     return linear(torch.tensor(np_array, dtype=torch.float))


## Time to start a federated learning experiment

In [15]:
# Create a federation
from openfl.interface.interactive_api.federation import Federation
# will determine fqdn by itself
federation = Federation(central_node_fqdn='localhost', disable_tls=True)
col_data_paths = {'one': '1,2',
                'two': '2,2'}
federation.register_collaborators(col_data_paths=col_data_paths)

In [16]:
# create an experimnet in federation
fl_experiment = FLExperiment(federation=federation)

In [17]:
# If I use autoreload I got a pickling error
fl_experiment.start_experiment(model_provider=MI, task_keeper=TI, data_loader=fed_dataset, rounds_to_train=5, \
                              opt_treatment='CONTINUE_GLOBAL')

tried to remove tensor: __opt_state_needed not present in the tensor dict
tried to remove tensor: __opt_state_needed not present in the tensor dict


Aggregator port = 50992


gRPC is running on insecure channel with TLS disabled.


In [18]:
fl_experiment.plan.config['tasks']

{'defaults': None,
 'settings': {},
 'train': {'function': 'train', 'kwargs': {'some_parameter': 42}},
 'localy_tuned_model_validate': {'function': 'validate',
  'kwargs': {'apply': 'local'}},
 'aggregated_model_validate': {'function': 'validate',
  'kwargs': {'apply': 'global'}}}

In [19]:
fl_experiment.plan.config['assigner']['settings']['task_groups'][0]['tasks']

['train', 'localy_tuned_model_validate', 'aggregated_model_validate']

In [20]:
fl_experiment.plan.Build(fl_experiment.plan.config['api_layer']['required_plugin_components']['serializer_plugin'], {})

<openfl.plugins.interface_serializer.dill_serializer.Dill_Serializer at 0x7f46be0f5310>

In [21]:
fl_experiment.plan.config

{'aggregator': {'template': 'openfl.component.Aggregator',
  'settings': {'db_store_rounds': 1,
   'init_state_path': 'save/init.pbuf',
   'best_state_path': 'save/best.pbuf',
   'last_state_path': 'save/last.pbuf',
   'rounds_to_train': 5,
   'aggregator_uuid': 'aggregator_plan.yaml_b08d15c8',
   'federation_uuid': 'plan.yaml_b08d15c8',
   'authorized_cols': ['one', 'two'],
   'assigner': <openfl.component.assigner.random_grouped_assigner.RandomGroupedAssigner at 0x7f45aafe0f50>,
   'defaults': 'plan/defaults/aggregator.yaml',
   'initial_tensor_dict': {'inc.conv.0.weight': array([[[[-0.15275218,  0.00385852, -0.02395825],
             [ 0.08870542, -0.11938288, -0.02559398],
             [-0.0622753 ,  0.05740096,  0.14895089]],
    
            [[-0.18637218, -0.17971292, -0.17063731],
             [ 0.08530347,  0.10000641,  0.12413657],
             [-0.02264823, -0.18337001, -0.11606432]],
    
            [[-0.11480568,  0.04559895,  0.04251714],
             [-0.06672994, -0.09

In [22]:
yaml.safe_dump()

NameError: name 'yaml' is not defined

In [None]:
import dill
with open('./tasks.pkl', 'wb') as f:
    dill.dump(TI, f,recurse=True)

In [None]:
import numpy as np
arr = np.arange(0,10)
test_task(arr)

In [None]:
import dill
with open('./model.pkl', 'wb') as f:
    dill.dump(MI, f, recurse=True)
# Pickling class    
# with open('./model_cls.pkl', 'wb') as f:
#     dill.dump(UNet, f, recurse=True)

In [None]:
UNet.__module__

In [None]:
import torch
a = torch.rand(1,3,128,128)
model_unet.forward(a)

In [None]:
import dill
with open('./dataloader.pkl', 'wb') as f:
    dill.dump(fed_dataset, f, recurse=True)

In [None]:
with open('./data.yaml', 'w') as f:
    for col_name, data_path in {'one': '1,2',
            'two': '2,2'}.items():
        f.write(f'{col_name},{data_path}\n')

In [None]:
task_contract = dict()
task_contract['optimizer'] = 1

In [None]:
validation = True if task_contract['optimizer'] is not None else False
validation