In [1]:
# Install dependencies if not already installed
!pip install torchvision==0.8.1
!pip -q install vit_pytorch linformer

Looking in indexes: http://proxypip-icv.inn.intel.com:8088/root/pypi/+simple/
Collecting torchvision==0.8.1
  Using cached http://proxypip-icv.inn.intel.com:8088/root/pypi/%2Bf/337/820e680e51938/torchvision-0.8.1-cp38-cp38-manylinux1_x86_64.whl (12.8 MB)
Collecting pillow>=4.1.1
  Using cached http://proxypip-icv.inn.intel.com:8088/root/pypi/%2Bf/a5a/4532a12314149/Pillow-8.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
Collecting torch==1.7.0
  Using cached http://proxypip-icv.inn.intel.com:8088/root/pypi/%2Bf/110/54f26eee5c311/torch-1.7.0-cp38-cp38-manylinux1_x86_64.whl (776.8 MB)
Collecting future
  Using cached future-0.18.2-py3-none-any.whl
Collecting dataclasses
  Using cached http://proxypip-icv.inn.intel.com:8088/root/pypi/%2Bf/454/a69d788c7fda4/dataclasses-0.6-py3-none-any.whl (14 kB)
Installing collected packages: future, dataclasses, torch, pillow, torchvision
Successfully installed dataclasses-0.6 future-0.18.2 pillow-8.4.0 torch-1.7.0 torchvision-0.8.

In [2]:
!pip install matplotlib

Looking in indexes: http://proxypip-icv.inn.intel.com:8088/root/pypi/+simple/
Collecting matplotlib
  Using cached http://proxypip-icv.inn.intel.com:8088/root/pypi/%2Bf/208/9b9014792dcc8/matplotlib-3.5.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl (11.3 MB)
Collecting fonttools>=4.22.0
  Using cached http://proxypip-icv.inn.intel.com:8088/root/pypi/%2Bf/680/71406009e7ef6/fonttools-4.28.1-py3-none-any.whl (873 kB)
Collecting cycler>=0.10
  Using cached http://proxypip-icv.inn.intel.com:8088/root/pypi/%2Bf/3a2/7e95f763a428a/cycler-0.11.0-py3-none-any.whl (6.4 kB)
Collecting kiwisolver>=1.0.1
  Using cached http://proxypip-icv.inn.intel.com:8088/root/pypi/%2Bf/b6a/5431940f28b6d/kiwisolver-1.3.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.2 MB)
Collecting setuptools-scm>=4
  Using cached http://proxypip-icv.inn.intel.com:8088/root/pypi/%2Bf/4c6/4444b1d49c406/setuptools_scm-6.3.2-py3-none-any.whl (33 kB)
Collecting tomli>=1.0.0
  Using cached http://proxypip-icv.inn.int

## Import Libraries

In [3]:
from __future__ import print_function

from itertools import chain
import os
import random
import zipfile

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from linformer import Linformer

from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm.notebook import tqdm

from vit_pytorch.efficient import ViT

## Connect to the Federation

In [4]:
# Create a federation
from openfl.interface.interactive_api.federation import Federation

# Please use the same identifier that was used in the signed certificate
client_id = 'api'
director_node_fqdn = 'nnlicv901.inn.intel.com'
director_port = 50051

# 1) Run with API layer - Director mTLS
# If the user wants to enable mTLS, they must provide the CA root chain and a signed key pair to the federation interface
# cert_chain = 'cert/root_ca.crt'
# api_certificate = 'cert/frontend.crt'
# api_private_key = 'cert/frontend.key'

# federation = Federation(
#     client_id=client_id,
#     director_node_fqdn=director_node_fqdn,
#     director_port=director_port,
#     tls=True,
#     cert_chain=cert_chain,
#     api_cert=api_certificate,
#     api_private_key=api_private_key
# )

# --------------------------------------------------------------------------------------------------------------------

# 2) Run with TLS disabled (trusted environment)
# Federation can also determine the local FQDN automatically
federation = Federation(
    client_id=client_id,
    director_node_fqdn=director_node_fqdn,
    director_port=director_port,
    tls=False
)

In [5]:
# Ask the director for its shard registry and display it
# (presumably the shards/envoys currently registered — confirm against the OpenFL docs).
# NOTE(review): the recorded run of this cell failed with a director-side
# KeyError on 'last_updated'; re-run against a healthy director before sharing.
shard_registry = federation.get_shard_registry()
shard_registry

_InactiveRpcError: <_InactiveRpcError of RPC that terminated with:
	status = StatusCode.UNKNOWN
	details = "Unexpected <class 'KeyError'>: 'last_updated'"
	debug_error_string = "{"created":"@1637168761.697088416","description":"Error received from peer ipv4:10.125.90.225:50051","file":"src/core/lib/surface/call.cc","file_line":1062,"grpc_message":"Unexpected <class 'KeyError'>: 'last_updated'","grpc_status":2}"
>

In [None]:
# Display the target shape reported by the federation (the recorded run shows ['1']).
federation.target_shape

['1']

In [None]:
# First, request a dummy_shard_desc that holds information about the federated dataset
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)
dummy_shard_dataset = dummy_shard_desc.get_dataset('train')
# Every item unpacks into a (sample, target) pair; inspect the shapes of the first one.
sample, target = dummy_shard_dataset[0]
f"Sample shape: {sample.shape}, target shape: {target.shape}"

'Sample shape: (300, 300, 3), target shape: (1,)'

## Creating an FL experiment using the Interactive API

In [None]:
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment

In [None]:
# Training settings
batch_size = 64  # passed to DogsCatsSD as both train_bs and valid_bs
#epochs = 10
lr = 3e-5  # learning rate of the Adam optimizer
gamma = 0.7  # decay factor of the StepLR scheduler
seed = 42  # passed to seed_everything
# NOTE(review): no later cell visible here reads this global; the FL tasks
# receive `device` as an argument instead — confirm whether it is still needed.
device = 'cuda'

In [None]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and CUDA),
    exports ``PYTHONHASHSEED`` and configures cuDNN for reproducible runs.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # Hash randomisation is fixed at interpreter start-up, so this only
    # influences Python processes launched after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select different algorithms from run to run;
    # turn it off so the deterministic flag above is not undermined.
    torch.backends.cudnn.benchmark = False

# Seed once, before the data loaders and the model are created in the cells below.
seed_everything(seed)

### Register dataset

In [None]:
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import transforms

class DogsCatsShardDataset(Dataset):
    """Torch ``Dataset`` view over a shard dataset that applies image transforms.

    Every item of the wrapped dataset is unpacked as ``(image, label)``. The
    image is sent through the torchvision pipeline selected by
    ``transform_type`` and returned as a NumPy array together with ``label[0]``.
    """

    def __init__(self, dataset, transform_type="train"):
        """Keep a reference to ``dataset`` and build the transform pipeline.

        Args:
            dataset: Indexable, sized collection of ``(image, label)`` pairs.
            transform_type: One of ``"train"``, ``"val"`` or ``"test"``.

        Raises:
            ValueError: If ``transform_type`` is not one of the values above.
        """
        self._dataset = dataset

        # Image augmentation for training; "val" and "test" share the same
        # resize + centre-crop preprocessing.
        if transform_type == "train":
            core_steps = [
                transforms.Resize((224, 224)),
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
            ]
        elif transform_type in ("val", "test"):
            core_steps = [
                transforms.Resize(256),
                transforms.CenterCrop(224),
            ]
        else:
            raise ValueError("Invalid transform type: {}".format(transform_type))

        # Every pipeline starts from a PIL image and ends with a tensor.
        self.transform = transforms.Compose(
            [transforms.ToPILImage(), *core_steps, transforms.ToTensor()]
        )

    def __len__(self):
        """Return the number of samples; the value is also stored on ``self.filelength``."""
        size = len(self._dataset)
        self.filelength = size
        return size

    def __getitem__(self, idx):
        """Return ``(transformed image as a NumPy array, label[0])`` for item ``idx``."""
        image, target = self._dataset[idx]
        transformed = self.transform(image)
        return transformed.numpy(), target[0]

# Now you can implement your data loaders using dummy_shard_desc
class DogsCatsSD(DataInterface):
    """Data interface that serves train/validation loaders over one shard.

    The shard's 'train' dataset is split by position: the trailing
    ``validation_fraction`` of the samples (at least one sample when the shard
    is non-empty) is held out for validation, the rest is used for training.

    Keyword arguments forwarded to ``DataInterface`` and read from ``self.kwargs``:
        train_bs: Batch size of the training loader (required).
        valid_bs: Batch size of the validation loader (required).
        num_workers: Number of DataLoader worker processes (optional, default 8).
    """

    def __init__(self, validation_fraction=1/5, **kwargs):
        """Store the validation fraction and forward ``kwargs`` to the base class."""
        super().__init__(**kwargs)

        self.validation_fraction = validation_fraction

    @property
    def shard_descriptor(self):
        """Return the shard descriptor assigned through the setter."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Describe per-collaborator procedures or sharding.

        This method will be called during a collaborator initialization.
        Local shard_descriptor  will be set by Envoy.
        """
        self._shard_descriptor = shard_descriptor

        # Two views over the same underlying samples: training applies random
        # augmentation, validation uses the deterministic resize + centre crop
        # so that validation metrics are not perturbed by random crops/flips.
        raw_dataset = shard_descriptor.get_dataset('train')
        self._shard_dataset = DogsCatsShardDataset(raw_dataset, transform_type="train")
        self._val_dataset = DogsCatsShardDataset(raw_dataset, transform_type="val")

        total = len(self._shard_dataset)
        # Hold out at least one sample, but never more than the shard contains
        # (an empty shard would otherwise yield the invalid index -1).
        validation_size = min(total, max(1, int(total * self.validation_fraction)))
        split = total - validation_size

        self.train_indeces = np.arange(split)
        self.val_indeces = np.arange(split, total)

    def get_train_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks with optimizer in contract
        """
        train_sampler = SubsetRandomSampler(self.train_indeces)

        return DataLoader(
            self._shard_dataset,
            num_workers=self.kwargs.get('num_workers', 8),
            batch_size=self.kwargs['train_bs'],
            sampler=train_sampler
        )

    def get_valid_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks without optimizer in contract
        """
        val_sampler = SubsetRandomSampler(self.val_indeces)
        return DataLoader(
            self._val_dataset,
            num_workers=self.kwargs.get('num_workers', 8),
            batch_size=self.kwargs['valid_bs'],
            sampler=val_sampler
        )

    def get_train_data_size(self):
        """
        Information for aggregation
        """
        return len(self.train_indeces)

    def get_valid_data_size(self):
        """
        Information for aggregation
        """
        return len(self.val_indeces)

In [None]:
# Local smoke test: plug the dummy shard descriptor into the data interface
# and print the shape of every training batch.
fed_dataset = DogsCatsSD(train_bs=batch_size, valid_bs=batch_size)
fed_dataset.shard_descriptor = dummy_shard_desc
# NOTE: this loop rebinds the notebook-level names `sample` and `target`.
for i, (sample, target) in enumerate(fed_dataset.get_train_loader()):
    print(sample.shape)

torch.Size([8, 3, 224, 224])


### Describe a model and optimizer

#### Linformer

In [None]:
# Linformer backbone; it is passed to the ViT model below as its `transformer`.
efficient_transformer = Linformer(
    dim=128,
    seq_len=49+1,  # 7x7 patches + 1 cls-token (224 / 32 = 7 patches per side)
    depth=12,
    heads=8,
    k=64
)

#### Vision Transformer

In [None]:
# Two-class ViT over 224x224, 3-channel inputs cut into 32x32 patches.
# `dim` is set to the same value as the Linformer's `dim` above.
model = ViT(
    dim=128,
    image_size=224,
    patch_size=32,
    num_classes=2,
    transformer=efficient_transformer,
    channels=3,
)

In [None]:
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
# NOTE(review): no cell visible in this notebook calls `scheduler.step()` or
# passes the scheduler to the experiment — confirm whether it is still needed.
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

#### Register model

In [None]:
from copy import deepcopy

# Dotted path of the PyTorch framework adapter plugin; it is passed to ModelInterface as a string.
framework_adapter = 'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'
MI = ModelInterface(model=model, optimizer=optimizer, framework_plugin=framework_adapter)

# Save the initial model state
# NOTE(review): `initial_model` is not read by any later cell visible here.
initial_model = deepcopy(model)

### Define and register FL tasks

In [None]:
TI = TaskInterface()

# NOTE: this rebinds the name `tqdm` — imported as the `tqdm.notebook.tqdm`
# callable in the imports cell — to the top-level `tqdm` module; the tasks
# below call `tqdm.tqdm(...)`.
import tqdm
from openfl.component.aggregation_functions import Median

# The Interactive API supports registering functions defined in the main module or imported.
def function_defined_in_notebook(some_parameter):
    """Print a one-line message that echoes ``some_parameter``."""
    message = f'Also I accept a parameter and it is {some_parameter}'
    print(message)

# The Interactive API supports overriding the aggregation function.
# This Median() instance is attached to the `train` task below via TI.set_aggregation_function.
aggregation_function = Median()

# Task interface currently supports only standalone functions.
@TI.add_kwargs(**{'some_parameter': 42})
@TI.register_fl_task(model='model', data_loader='train_loader',
                     device='device', optimizer='optimizer')
@TI.set_aggregation_function(aggregation_function)
def train(model, train_loader, optimizer, device, loss_fn=criterion, some_parameter=None):
    """Run one pass over ``train_loader`` and update ``model`` in place.

    Args:
        model: Network to train; it is switched to train mode and moved to ``device``.
        train_loader: Iterable of ``(data, target)`` batches that supports ``len()``.
        optimizer: Optimizer stepping the parameters of ``model``.
        device: Device the model and the batches are moved to.
        loss_fn: Loss applied to ``(output, target)``; defaults to the notebook's ``criterion``.
        some_parameter: Demo value forwarded to ``function_defined_in_notebook``.

    Returns:
        dict with the keys ``'loss'`` and ``'accuracy'``, each averaged over the batches.
    """
    function_defined_in_notebook(some_parameter)
    epoch_loss = 0
    epoch_accuracy = 0

    train_loader = tqdm.tqdm(train_loader, desc="train")
    num_batches = len(train_loader)
    model.train()
    model.to(device)

    for data, target in train_loader:
        # `as_tensor` accepts tensors and arrays alike; unlike `torch.tensor`
        # it neither copies an existing tensor nor emits the copy-construct warning.
        data = torch.as_tensor(data).to(device)
        target = torch.as_tensor(target).to(device, dtype=torch.long)

        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == target).float().mean()
        # Each batch contributes 1/num_batches of its value to the running mean.
        epoch_accuracy += acc.cpu().numpy() / num_batches
        epoch_loss += loss.detach().cpu().numpy() / num_batches

    return {'loss' : epoch_loss, 'accuracy': epoch_accuracy}

@TI.register_fl_task(model='model', data_loader='val_loader', device='device')
def validate(model, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` with gradient tracking disabled.

    Args:
        model: Network to evaluate; it is switched to eval mode and moved to ``device``.
        val_loader: Iterable of ``(data, target)`` batches that supports ``len()``.
        device: Device the model and the batches are moved to.

    Returns:
        dict with the keys ``'val_loss'`` and ``'val_accuracy'``, each averaged
        over the batches. The loss is computed with the notebook-level ``criterion``.
    """
    model.eval()
    model.to(device)

    val_loader = tqdm.tqdm(val_loader, desc="validate")
    num_batches = len(val_loader)

    epoch_val_accuracy = 0
    epoch_val_loss = 0
    with torch.no_grad():
        for data, target in val_loader:
            # `as_tensor` accepts tensors and arrays alike; unlike `torch.tensor`
            # it neither copies an existing tensor nor emits the copy-construct warning.
            data = torch.as_tensor(data).to(device)
            target = torch.as_tensor(target).to(device, dtype=torch.long)

            val_output = model(data)
            val_loss = criterion(val_output, target)

            acc = (val_output.argmax(dim=1) == target).float().mean()
            # Each batch contributes 1/num_batches of its value to the running mean.
            epoch_val_accuracy += acc.cpu().numpy() / num_batches
            epoch_val_loss += val_loss.detach().cpu().numpy() / num_batches

    return {'val_loss' : epoch_val_loss, 'val_accuracy': epoch_val_accuracy}

## Time to start a federated learning experiment

In [None]:
# Create an experiment in the federation
experiment_name = 'ViT_DogsCats_experiment'
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [None]:
# The following command zips the workspace and Python requirements to be transferred to collaborator nodes
fl_experiment.start(model_provider=MI, 
                    task_keeper=TI,
                    data_loader=fed_dataset,
                    rounds_to_train=50,
                    opt_treatment='CONTINUE_GLOBAL',
                    device_assignment_policy='CUDA_PREFERRED')

In [None]:
# If the user wants to stop the IPython session, they can reconnect later and check how the experiment is going:
# fl_experiment.restore_experiment_state(MI)

fl_experiment.stream_metrics()