# Federated PyTorch UNet Tutorial
## Using low-level Python API

In [1]:
# Automatically re-import edited Python modules (e.g. the local `layers` module)
# before each cell runs, so changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Install dependencies if not already installed
!pip install torch



### Describe the model and optimizer

In [9]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from layers import down, up, double_conv, soft_dice_coef, soft_dice_loss

In [10]:
class UNet(nn.Module):
    """
    U-Net style segmentation network assembled from the `layers` helpers.

    Four `down` stages encode the input. Four `up` stages then decode it, each
    also receiving the matching encoder output as a skip connection. A 1x1
    convolution followed by a sigmoid produces per-pixel scores in (0, 1).

    Args:
        n_channels: number of input channels (3 for RGB images).
        n_classes: number of output channels.
    """

    def __init__(self, n_channels=3, n_classes=1):
        super(UNet, self).__init__()
        # Encoder: channel width doubles at every stage, 64 -> 1024.
        self.inc = double_conv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 1024)
        # Decoder: each `up` takes the previous output plus an encoder skip tensor.
        self.up1 = up(1024, 512)
        self.up2 = up(512, 256)
        self.up3 = up(256, 128)
        self.up4 = up(128, 64)
        # 1x1 projection to the requested number of output channels.
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        """Return sigmoid-activated segmentation scores for the batch `x`."""
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        bottleneck = self.down4(skip4)

        decoded = self.up1(bottleneck, skip4)
        decoded = self.up2(decoded, skip3)
        decoded = self.up3(decoded, skip2)
        decoded = self.up4(decoded, skip1)

        logits = self.outc(decoded)
        return torch.sigmoid(logits)


model_unet = UNet()

In [13]:
# Adam over every UNet parameter with a fixed learning rate of 1e-3.
optimizer_adam = optim.Adam(params=model_unet.parameters(), lr=0.001)

### Prepare data

We ask the user to keep all test data in the `data/` folder under the workspace, since that folder will not be sent to collaborators.

In [40]:
import os
from hashlib import sha384
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as tsf
from skimage import io

In [41]:
os.makedirs('data', exist_ok=True)
!wget -nc 'https://datasets.simula.no/hyper-kvasir/hyper-kvasir-segmented-images.zip' -O ./data/kvasir.zip
ZIP_SHA384 = 'e30d18a772c6520476e55b610a4db457237f151e'\
    '19182849d54b49ae24699881c1e18e0961f77642be900450ef8b22e7'
assert sha384(open('./data/kvasir.zip', 'rb').read(
    os.path.getsize('./data/kvasir.zip'))).hexdigest() == ZIP_SHA384
!unzip -n ./data/kvasir.zip -d ./data

File ‘./data/kvasir.zip’ already there; not retrieving.
Archive:  ./data/kvasir.zip


In [56]:
DATA_PATH = './data/segmented-images/'


def read_data(image_path, mask_path):
    """
    Read an RGB image and its segmentation mask from disk.

    Args:
        image_path: path to the image file.
        mask_path: path to the mask file.

    Returns:
        Tuple `(img, mask)`: the image array as returned by `io.imread`, and a
        2-D `uint8` mask. If the mask file has channels, the first one is used.

    Raises:
        ValueError: if the image is not an (H, W, 3) array.
    """
    img = io.imread(image_path)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if img.ndim != 3 or img.shape[2] != 3:
        raise ValueError(f'Expected a 3-channel image at {image_path}, got shape {img.shape}')
    mask = io.imread(mask_path)
    # Masks may be stored as RGB (keep the first channel) or already as a single channel.
    if mask.ndim == 3:
        mask = mask[:, :, 0]
    return img, mask.astype(np.uint8)


class KvasirDataset(Dataset):
    """
    Image/mask segmentation dataset read from two directories on disk.

    Every `.jpg` file in `images_path` is an image. Its mask is the file with
    the same name in `masks_path`. Items are returned as `(img, mask)` numpy
    arrays:
    - the image is resized to `IMAGE_SIZE` and normalized to [-1, 1];
    - the mask is resized with nearest-neighbour interpolation, so label values
      are never blended.

    Args:
        images_path: directory containing the `.jpg` images.
        masks_path: directory containing the masks, named like their images.

    Raises:
        ValueError: if `images_path` holds fewer than `MIN_IMAGES` `.jpg` files.
    """

    # Spatial size (height, width) that every image and mask is resized to.
    IMAGE_SIZE = (332, 332)
    # Sanity threshold kept from the original code (it required more than 8 images).
    MIN_IMAGES = 9

    def __init__(self, images_path='./data/segmented-images/images/',
                 masks_path='./data/segmented-images/masks/'):
        self.images_path = images_path
        self.masks_path = masks_path
        # Sorted so item order is deterministic across runs and machines.
        self.images_names = [
            name for name in sorted(os.listdir(self.images_path))
            if name.lower().endswith('.jpg')
        ]
        if len(self.images_names) < self.MIN_IMAGES:
            raise ValueError(
                f'Expected at least {self.MIN_IMAGES} .jpg images in {self.images_path}, '
                f'found {len(self.images_names)}')

        # Images: resize, convert to a float tensor, then map each channel to [-1, 1].
        self.img_trans = tsf.Compose([
            tsf.ToPILImage(),
            tsf.Resize(self.IMAGE_SIZE),
            tsf.ToTensor(),
            tsf.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
        # Masks: nearest-neighbour resize so no new values appear between labels.
        self.mask_trans = tsf.Compose([
            tsf.ToPILImage(),
            tsf.Resize(self.IMAGE_SIZE, interpolation=tsf.InterpolationMode.NEAREST),
            tsf.ToTensor()])

    def __getitem__(self, index):
        """Return the `(img, mask)` pair at `index` as transformed numpy arrays."""
        name = self.images_names[index]
        # os.path.join works whether or not the directory paths end with a separator.
        img, mask = read_data(os.path.join(self.images_path, name),
                              os.path.join(self.masks_path, name))
        return self.img_trans(img).numpy(), self.mask_trans(mask).numpy()

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.images_names)

### Define Federated Learning tasks

In [47]:
def function_defined_in_notebook():
    """
    Print a fixed marker message.

    NOTE(review): looks like a probe to check whether helpers defined in the
    notebook travel with the `train` task; remove once that is confirmed.
    """
    message = 'I will cause problems'
    print(message)

    
    
def train(unet_model, train_loader, optimizer, device, loss_fn=soft_dice_loss):
    """
    Run one pass of `unet_model` over `train_loader`, updating its weights.

    Args:
        unet_model: model to train; it is switched to train mode and moved to `device`.
        train_loader: iterable of `(data, target)` batches (tensors or array-likes).
        optimizer: optimizer over `unet_model`'s parameters.
        device: device to run on (e.g. `'cpu'` or `'cuda'`).
        loss_fn: callable invoked as `loss_fn(output=..., target=...)`;
            must return a single-element tensor.

    Returns:
        `{'train_loss': mean of per-batch losses}` (`nan` if the loader is empty).
    """
    # NOTE(review): leftover probe call; remove once notebook-defined helpers
    # are confirmed to be shipped with this task.
    function_defined_in_notebook()

    unet_model.train()
    unet_model.to(device)

    losses = []
    for data, target in train_loader:
        # `torch.as_tensor` reuses existing tensors. The original called the
        # undefined alias `pt`, and `torch.tensor(tensor)` copies with a warning.
        data = torch.as_tensor(data).to(device)
        target = torch.as_tensor(target).to(device, dtype=torch.float32)
        optimizer.zero_grad()
        output = unet_model(data)
        loss = loss_fn(output=output, target=target)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    return {'train_loss': np.mean(losses)}


def validate(unet_model, val_loader, device):
    """
    Score `unet_model` on `val_loader` with `soft_dice_coef`, without gradients.

    Args:
        unet_model: model to evaluate; it is switched to eval mode and moved to `device`.
        val_loader: iterable of `(data, target)` batches whose `target` has a
            leading batch dimension.
        device: device to run on (e.g. `'cpu'` or `'cuda'`).

    Returns:
        `{'dice_coef': sum of soft_dice_coef values / number of samples}`, or
        `{'dice_coef': nan}` when the loader yields no batches.
    """
    unet_model.eval()
    unet_model.to(device)

    val_score = 0.0
    total_samples = 0

    with torch.no_grad():
        for data, target in val_loader:
            total_samples += target.shape[0]
            # `torch.as_tensor` avoids the copy and warning `torch.tensor()` gives for tensors.
            data = torch.as_tensor(data).to(device)
            target = torch.as_tensor(target).to(device, dtype=torch.int64)
            output = unet_model(data)
            val_score += soft_dice_coef(output, target).sum().item()

    if total_samples == 0:
        # Empty loader: report nan (like np.mean([])) instead of raising ZeroDivisionError.
        return {'dice_coef': float('nan')}
    return {'dice_coef': val_score / total_samples}

## Describe the FL experiment

In [57]:
from openfl.interface.python_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment

  return torch._C._cuda_getDeviceCount() > 0


### Register model

In [60]:
# Dotted import path of the framework-adapter plugin that ModelInterface should load.
framework_adapter = 'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'
# Register the UNet and its Adam optimizer with OpenFL.
MI = ModelInterface(
    model=model_unet,
    optimizer=optimizer_adam,
    framework_plugin=framework_adapter,
)

### Register dataset

In [None]:
k = KvasirDataset


class Federated_dataset(DataInterface):
    """
    Work-in-progress OpenFL DataInterface wrapper around a Kvasir dataset class.

    Args:
        k: dataset class (or factory) to wrap, e.g. `KvasirDataset`.

    NOTE(review): no loader methods are implemented yet, and this cell was never
    executed (`In [None]`). Complete it before registering it with an experiment.
    """

    def __init__(self, k):
        # Run the DataInterface initializer; the original skipped it, so any
        # base-class state it sets up was presumably never created.
        super().__init__()
        # Keep the dataset class instead of silently discarding the argument.
        self.kvasir_dataset = k

    def init_kvasir(self):
        """Re-run the DataInterface base initializer (kept for backward compatibility)."""
        super().__init__()

In [11]:
class first:
    """Demo base class whose initializer prints its own name."""

    def __init__(self):
        print('first')


class second:
    """
    Second demo base class.

    Never reached from `third.init`, because `first.__init__` does not chain to
    `super().__init__()`.
    """

    def __init__(self):
        print('second')


class third(first, second):
    """
    Demonstrates method resolution order with two base classes.

    `__init__` deliberately skips the base initializers. `init` calls
    `super().__init__()`, which resolves to `first.__init__`, the next class in
    the MRO `third -> first -> second -> object`.
    """

    def __init__(self):
        print('third')

    def init(self):
        super(third, self).__init__()


thrd = third()

third


In [12]:
# `init` calls super().__init__() from an ordinary method. The MRO of `third`
# (first before second) resolves it to first.__init__, so this prints 'first'.
thrd.init()

first
