# Federated PyTorch TinyImageNet Tutorial
## Using low-level Python API

In [None]:
!pip install torchvision==0.8.1

## Connect to the Federation

In [None]:
# Create a federation object that connects this notebook to the Director.
from openfl.interface.interactive_api.federation import Federation

# Please use the same identifier that was used in the signed certificate.
client_id = 'api'
# cert_dir = '/path/to/pki/api/cert'
# director_node_fqdn = 'director.example.com'
# 1) Run with API layer - Director mTLS
# To enable mTLS, the user must provide the CA root chain and a signed key pair to the federation interface.
# cert_chain = f'{cert_dir}/root_ca.crt'
# api_certificate = f'{cert_dir}/{client_id}.crt'
# api_private_key = f'{cert_dir}/{client_id}.key'

# federation = Federation(client_id=client_id, director_node_fqdn=director_node_fqdn, director_port='50051',
#                        cert_chain=cert_chain, api_cert=api_certificate, api_private_key=api_private_key)

# --------------------------------------------------------------------------------------------------------------------

# 2) Run with TLS disabled (trusted environment)
# The Federation can also determine the local FQDN automatically.
federation = Federation(client_id=client_id, director_node_fqdn='localhost', director_port='50051', tls=False)


In [None]:
# Ask the federation for its shard registry.
shard_registry = federation.get_shard_registry()
# NOTE: a notebook cell only displays its last expression, so this bare name shows nothing.
shard_registry
# Displayed as the cell output; presumably the target shape reported by the shards - TODO confirm.
federation.target_shape

In [None]:
# Request a dummy shard descriptor of size 10 for local experiments with the data pipeline.
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)
# Indexing the descriptor yields a (sample, target) pair.
sample, target = dummy_shard_desc[0]

## Creating a FL experiment using Interactive API

In [None]:
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment

In [None]:
import torchvision
from torchvision import transforms as T

# Per-channel normalisation with the mean/std values conventionally used
# together with ImageNet-pretrained torchvision models.
normalize = T.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Flip, rotation (10 degrees) and random-resized-crop (to 64) are applied
# together, as one group, with probability 0.8.
augmentation = T.RandomApply(
    [
        T.RandomHorizontalFlip(),
        T.RandomRotation(10),
        T.RandomResizedCrop(64),
    ],
    p=.8,
)

# Training pipeline: force RGB, augment, convert to tensor, normalise.
training_transform = T.Compose(
    [
        T.Lambda(lambda image: image.convert("RGB")),
        augmentation,
        T.ToTensor(),
        normalize,
    ]
)

# Validation pipeline: identical to training minus the augmentation step.
valid_transform = T.Compose(
    [
        T.Lambda(lambda image: image.convert("RGB")),
        T.ToTensor(),
        normalize,
    ]
)




In [None]:
from torch.utils.data import Dataset


class TransformedDataset(Dataset):
    """Wrap a dataset of ``(image, label)`` pairs and apply optional transforms."""

    def __init__(self, dataset, transform=None, target_transform=None):
        """Store the wrapped dataset and the optional callables.

        Args:
            dataset: Sized, indexable collection yielding ``(image, label)``.
            transform: Optional callable applied to every image.
            target_transform: Optional callable applied to every label.
        """
        self.dataset = dataset
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the (possibly transformed) ``(image, label)`` pair at ``index``."""
        image, target = self.dataset[index]
        # The label is handled before the image.
        if self.target_transform:
            target = self.target_transform(target)
        return self.transform_image(image), target

    def transform_image(self, img):
        """Apply ``transform`` to ``img`` when one is set; otherwise return ``img`` as is."""
        if self.transform:
            return self.transform(img)
        return img


In [None]:
class TinyImageNetDataset(DataInterface):
    """Data interface that builds train/validation loaders from a shard descriptor.

    Keyword arguments given to the constructor are kept in ``self.kwargs``;
    ``train_bs`` and ``valid_bs`` are read when the loaders are created.
    """

    def __init__(self, **kwargs):
        """Store loader settings (``train_bs``, ``valid_bs``) for later use."""
        self.kwargs = kwargs

    @property
    def shard_descriptor(self):
        """Return the shard descriptor that was last assigned."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Describe per-collaborator procedures or sharding.

        This method will be called during a collaborator initialization.
        Local shard_descriptor  will be set by Envoy.

        Besides storing the descriptor, it wraps the ``'train'`` and ``'val'``
        datasets in ``TransformedDataset`` with the matching transform.
        """
        self._shard_descriptor = shard_descriptor

        self.train_set = TransformedDataset(
            self._shard_descriptor.get_dataset('train'),
            transform=training_transform
        )
        # The keyword must be spelled ``transform``: a misspelled keyword makes
        # the TransformedDataset constructor raise TypeError.
        self.valid_set = TransformedDataset(
            self._shard_descriptor.get_dataset('val'),
            transform=valid_transform
        )

    def get_train_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks with optimizer in contract
        """
        return DataLoader(
            self.train_set, num_workers=8, batch_size=self.kwargs['train_bs'], shuffle=True
        )

    def get_valid_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks without optimizer in contract
        """
        return DataLoader(self.valid_set, num_workers=8, batch_size=self.kwargs['valid_bs'])

    def get_train_data_size(self):
        """
        Information for aggregation
        """
        return len(self.train_set)

    def get_valid_data_size(self):
        """
        Information for aggregation
        """
        return len(self.valid_set)
    

In [None]:
# Build the data interface with a batch size of 8 for both training and validation.
fed_dataset = TinyImageNetDataset(train_bs=8, valid_bs=8)
# Optional local smoke test: attach the dummy shard descriptor and iterate the train loader.
# fed_dataset.shard_descriptor = dummy_shard_desc
# for i, (sample, target) in enumerate(fed_dataset.get_train_loader()):
#     print(sample.shape)

### Describe the model and optimizer

In [None]:
import os
import glob
from torch.utils.data import Dataset, DataLoader
from PIL import Image

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
"""
MobileNetV2 model
"""

class Net(nn.Module):
    """Pretrained MobileNetV2 whose final classifier layer is replaced by a 200-output linear layer."""

    def __init__(self):
        """Load the pretrained backbone, freeze it, and attach a fresh classifier layer."""
        super().__init__()
        self.model = torchvision.models.mobilenet_v2(pretrained=True)
        # Freeze every pretrained weight ...
        self.model.requires_grad_(False)
        # ... then swap in a new final layer. It is created after the freeze,
        # so its parameters keep the default requires_grad=True.
        self.model.classifier[1] = torch.nn.Linear(
            in_features=1280, out_features=200, bias=True)

    def forward(self, x):
        """Run ``x`` through the wrapped model and return its output."""
        # Call the module itself rather than ``.forward`` so registered hooks run.
        return self.model(x)


model_net = Net()

In [None]:
# Hand only the parameters that still require gradients to the optimizer;
# frozen parameters are left out.
params_to_update = [param for param in model_net.parameters() if param.requires_grad]

optimizer_adam = optim.Adam(params_to_update, lr=1e-4)

def cross_entropy(output, target):
    """Return the cross-entropy loss of ``output`` against ``target``.

    Thin wrapper around ``torch.nn.functional.cross_entropy``.
    """
    loss = F.cross_entropy(output, target)
    return loss

### Register model

In [None]:
from copy import deepcopy

# Dotted path of the framework adapter plugin, passed to ModelInterface as a string.
framework_adapter = 'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'
# Bundle the model and its optimizer together with the adapter path.
model_interface = ModelInterface(model=model_net, optimizer=optimizer_adam, framework_plugin=framework_adapter)

# Keep an independent copy of the untrained model so it can be compared with the trained one later.
initial_model = deepcopy(model_net)

## Define and register FL tasks

In [None]:
task_interface = TaskInterface()
import torch

import tqdm

# The Interactive API supports registering functions defined in the main module or imported.
def function_defined_in_notebook(some_parameter):
    """Print a message that echoes ``some_parameter``."""
    message = f'Also I accept a parameter and it is {some_parameter}'
    print(message)

# Task interface currently supports only standalone functions.
@task_interface.add_kwargs(**{'some_parameter': 42})
@task_interface.register_fl_task(model='net_model', data_loader='train_loader',
                                 device='device', optimizer='optimizer')
def train(net_model, train_loader, optimizer, device, loss_fn=cross_entropy, some_parameter=None):
    """Run one pass over ``train_loader`` and return the mean training loss.

    Args:
        net_model: Model to train; put in train mode and moved to ``device``.
        train_loader: Iterable of ``(data, target)`` batches.
        optimizer: Optimizer stepped once per batch.
        device: Device to run on; replaced by ``'cpu'`` when CUDA is unavailable.
        loss_fn: Callable invoked as ``loss_fn(output=..., target=...)``.
        some_parameter: Extra value forwarded to ``function_defined_in_notebook``.

    Returns:
        dict: ``{'train_loss': mean of the per-batch losses}``.
    """
    if not torch.cuda.is_available():
        device = 'cpu'

    function_defined_in_notebook(some_parameter)

    train_loader = tqdm.tqdm(train_loader, desc="train")
    net_model.train()
    net_model.to(device)

    losses = []

    for data, target in train_loader:
        # as_tensor reuses an existing tensor instead of copying it, which
        # torch.tensor() would do (with a UserWarning) on every batch.
        data = torch.as_tensor(data).to(device)
        target = torch.as_tensor(target).to(device)
        optimizer.zero_grad()
        output = net_model(data)
        loss = loss_fn(output=output, target=target)
        loss.backward()
        optimizer.step()
        losses.append(loss.detach().cpu().numpy())

    return {'train_loss': np.mean(losses)}


@task_interface.register_fl_task(model='net_model', data_loader='val_loader', device='device')
def validate(net_model, val_loader, device):
    """Evaluate ``net_model`` on ``val_loader`` and return its top-1 accuracy.

    Args:
        net_model: Model to evaluate; put in eval mode and moved to ``device``.
        val_loader: Iterable of ``(data, target)`` batches.
        device: Device to run on; replaced by ``'cpu'`` when CUDA is unavailable.

    Returns:
        dict: ``{'acc': correct predictions / total samples}``.
    """
    # Fall back to CPU so a CUDA device string does not crash a CPU-only node.
    if not torch.cuda.is_available():
        device = 'cpu'

    net_model.eval()
    net_model.to(device)

    val_loader = tqdm.tqdm(val_loader, desc="validate")
    val_score = 0
    total_samples = 0

    with torch.no_grad():
        for data, target in val_loader:
            samples = target.shape[0]
            total_samples += samples
            # as_tensor avoids the extra copy (and warning) of torch.tensor().
            data = torch.as_tensor(data).to(device)
            target = torch.as_tensor(target).to(device, dtype=torch.int64)
            output = net_model(data)
            pred = output.argmax(dim=1)
            # Compare element-wise: give the target the same shape as ``pred``.
            # A (B, 1) prediction against a (B,) target would broadcast to a
            # (B, B) matrix and over-count the matches.
            val_score += pred.eq(target.reshape(pred.shape)).sum().cpu().numpy()

    return {'acc': val_score / total_samples}

## Time to start a federated learning experiment

In [None]:
# Create an experiment in the federation.
experiment_name = 'tinyimagenet_test_experiment'
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [None]:
# NOTE: using autoreload in this notebook was observed to cause a pickling error.

# The following command zips the workspace and Python requirements to be transferred to collaborator nodes.
fl_experiment.start(model_provider=model_interface, 
                    task_keeper=task_interface,
                    data_loader=fed_dataset,
                    rounds_to_train=10,
                    opt_treatment='CONTINUE_GLOBAL')

In [None]:
"""
MobileNetV2 model
"""

class Net(nn.Module):
    """Pretrained MobileNetV2 whose final classifier layer is replaced by a 200-output linear layer."""

    def __init__(self):
        """Load the pretrained backbone, freeze it, and attach a fresh classifier layer."""
        super().__init__()
        self.model = torchvision.models.mobilenet_v2(pretrained=True)
        # Freeze every pretrained weight ...
        self.model.requires_grad_(False)
        # ... then swap in a new final layer. It is created after the freeze,
        # so its parameters keep the default requires_grad=True.
        self.model.classifier[1] = torch.nn.Linear(
            in_features=1280, out_features=200, bias=True)

    def forward(self, x):
        """Run ``x`` through the wrapped model and return its output."""
        # Call the module itself rather than ``.forward`` so registered hooks run.
        return self.model(x)


model_net = Net()

In [None]:
# Hand only the parameters that still require gradients to the optimizer;
# frozen parameters are left out.
params_to_update = [param for param in model_net.parameters() if param.requires_grad]

optimizer_adam = optim.Adam(params_to_update, lr=1e-4)

def cross_entropy(output, target):
    """Return the cross-entropy loss of ``output`` against ``target``.

    Thin wrapper around ``torch.nn.functional.cross_entropy``.
    """
    loss = F.cross_entropy(output, target)
    return loss

### Prepare data

We ask the user to keep all test data in the `data/` folder under the workspace, as that folder will not be sent to collaborators.

### Define Federated Learning tasks

In [None]:
def train(net_model, train_loader, optimizer, device, loss_fn=cross_entropy, some_parameter=None):
    """Run one pass over ``train_loader`` and return the mean training loss.

    Args:
        net_model: Model to train; put in train mode and moved to ``device``.
        train_loader: Iterable of ``(data, target)`` batches.
        optimizer: Optimizer stepped once per batch.
        device: Device the model and batches are moved to.
        loss_fn: Callable invoked as ``loss_fn(output=..., target=...)``.
        some_parameter: Extra value forwarded to ``function_defined_in_notebook``,
            which requires exactly one argument.

    Returns:
        dict: ``{'train_loss': mean of the per-batch losses}``.
    """
    function_defined_in_notebook(some_parameter)

    net_model.train()
    net_model.to(device)

    losses = []

    for data, target in train_loader:
        # as_tensor reuses an existing tensor instead of copying it.
        data = torch.as_tensor(data).to(device)
        # Targets keep their own dtype: the default cross-entropy loss expects
        # integer class indices, so they must not be cast to float32.
        target = torch.as_tensor(target).to(device)
        optimizer.zero_grad()
        output = net_model(data)
        loss = loss_fn(output=output, target=target)
        loss.backward()
        optimizer.step()
        losses.append(loss.detach().cpu().numpy())

    return {'train_loss': np.mean(losses)}


def validate(net_model, val_loader, device):
    """Evaluate ``net_model`` on ``val_loader`` and return its top-1 accuracy.

    Args:
        net_model: Model to evaluate; put in eval mode and moved to ``device``.
        val_loader: Iterable of ``(data, target)`` batches.
        device: Device the model and batches are moved to.

    Returns:
        dict: ``{'acc': correct predictions / total samples}``.
    """
    net_model.eval()
    net_model.to(device)

    val_score = 0
    total_samples = 0

    with torch.no_grad():
        for data, target in val_loader:
            samples = target.shape[0]
            total_samples += samples
            # as_tensor avoids the extra copy (and warning) of torch.tensor().
            data = torch.as_tensor(data).to(device)
            target = torch.as_tensor(target).to(device, dtype=torch.int64)
            output = net_model(data)
            pred = output.argmax(dim=1)
            # Compare element-wise: give the target the same shape as ``pred``.
            # A (B, 1) prediction against a (B,) target would broadcast to a
            # (B, B) matrix and over-count the matches.
            val_score += pred.eq(target.reshape(pred.shape)).sum().cpu().numpy()

    return {'acc': val_score / total_samples}

## Describing FL experiment

In [None]:
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment

### Register model

In [None]:
from copy import deepcopy

# Dotted path of the framework adapter plugin, passed to ModelInterface as a string.
framework_adapter = 'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'
# Bundle the model and its optimizer together with the adapter path.
MI = ModelInterface(model=model_net, optimizer=optimizer_adam, framework_plugin=framework_adapter)

# Keep an independent copy of the untrained model so it can be compared with the trained one later.
initial_model = deepcopy(model_net)

### Register dataset

In [None]:
class UserDataset:
    """Placeholder dataset that only reports the path it was created with."""

    def __init__(self, path_to_local_data):
        """Print a message containing ``path_to_local_data``; nothing is stored."""
        message = f'User Dataset initialized with {path_to_local_data}'
        print(message)
        
        
class OpenflMixin:
    """Mixin declaring the ``_delayed_init`` hook that subclasses must provide."""

    def _delayed_init(self):
        """Always raise ``NotImplementedError``; subclasses override this hook."""
        raise NotImplementedError()
        
        
class FedDataset(OpenflMixin):
    """Demo wrapper that defers creating the user dataset until ``_delayed_init``."""

    def __init__(self, UserDataset):
        """Keep a reference to the dataset class; nothing is instantiated here."""
        self.user_dataset_class = UserDataset
        print('We implement all abstract methods from mixin in this class')

    def _delayed_init(self, data_path):
        """Instantiate the stored dataset class with ``data_path``."""
        print('This method is called on the collaborator node')
        dataset_cls = self.user_dataset_class
        dataset_obj = dataset_cls(data_path)


# Demonstrate the two-step construction: wrap the class first, initialise later.
fed_dataset = FedDataset(UserDataset)
fed_dataset._delayed_init('data path on the collaborator node')

In [None]:
# Per-channel normalisation with the mean/std values conventionally used
# together with ImageNet-pretrained torchvision models.
normalize = T.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Flip, rotation (10 degrees) and random-resized-crop (to 64) are applied
# together, as one group, with probability 0.8.
augmentation = T.RandomApply(
    [
        T.RandomHorizontalFlip(),
        T.RandomRotation(10),
        T.RandomResizedCrop(64),
    ],
    p=.8,
)

# Training pipeline: force RGB, augment, convert to tensor, normalise.
training_transform = T.Compose(
    [
        T.Lambda(lambda image: image.convert("RGB")),
        augmentation,
        T.ToTensor(),
        normalize,
    ]
)

# Validation pipeline: identical to training minus the augmentation step.
valid_transform = T.Compose(
    [
        T.Lambda(lambda image: image.convert("RGB")),
        T.ToTensor(),
        normalize,
    ]
)

In [None]:
class FedDataset(DataInterface):
    """Data interface that builds both splits lazily and shards the training split."""

    def __init__(self, UserDatasetClass, **kwargs):
        """Remember the dataset class and the loader settings.

        ``kwargs`` is kept as ``self.kwargs``; ``train_bs`` and ``valid_bs``
        are read from it when the loaders are built.
        """
        self.UserDatasetClass = UserDatasetClass
        self.kwargs = kwargs

    def _delayed_init(self, data_path='1,1'):
        """Create the train/validation sets and shard the training set.

        ``data_path`` is parsed as ``'<rank>,<world_size>'``.
        """
        # Every collaborator sees the same dataset on the same path, so
        # `data_path` carries the sharding information instead of a location.
        self.rank, self.world_size = map(int, data_path.split(','))

        # The local dataset is loaded on the collaborator node at this point.
        dataset_cls = self.UserDatasetClass
        self.train_set = dataset_cls(is_validation=False, transform=training_transform)
        self.valid_set = dataset_cls(is_validation=True, transform=valid_transform)

        # Do the actual sharding
        self._do_sharding(self.rank, self.world_size)

    def _do_sharding(self, rank, world_size):
        """Keep every ``world_size``-th training image path, starting at index ``rank - 1``."""
        # Tightly coupled to the dataset implementation: it must expose `image_paths`.
        first_index = rank - 1
        self.train_set.image_paths = self.train_set.image_paths[first_index::world_size]

    def get_train_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks with optimizer in contract
        """
        dataset = self.train_set
        batch_size = self.kwargs['train_bs']
        return DataLoader(dataset, num_workers=8, batch_size=batch_size, shuffle=True)

    def get_valid_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks without optimizer in contract
        """
        dataset = self.valid_set
        batch_size = self.kwargs['valid_bs']
        return DataLoader(dataset, num_workers=8, batch_size=batch_size)

    def get_train_data_size(self):
        """
        Information for aggregation
        """
        train_size = len(self.train_set)
        return train_size

    def get_valid_data_size(self):
        """
        Information for aggregation
        """
        valid_size = len(self.valid_set)
        return valid_size


fed_dataset = FedDataset(TinyImageNetDataset, train_bs=8, valid_bs=8)

### Register tasks

In [None]:
TI = TaskInterface()
import torch

import tqdm

# The Interactive API supports registering functions defined in the main module or imported.
def function_defined_in_notebook(some_parameter):
    """Print a message that echoes ``some_parameter``."""
    message = f'Also I accept a parameter and it is {some_parameter}'
    print(message)

# Task interface currently supports only standalone functions.
@TI.add_kwargs(**{'some_parameter': 42})
@TI.register_fl_task(model='net_model', data_loader='train_loader',
                     device='device', optimizer='optimizer')
def train(net_model, train_loader, optimizer, device, loss_fn=cross_entropy, some_parameter=None):
    """Run one pass over ``train_loader`` and return the mean training loss.

    Args:
        net_model: Model to train; put in train mode and moved to ``device``.
        train_loader: Iterable of ``(data, target)`` batches.
        optimizer: Optimizer stepped once per batch.
        device: Device to run on; replaced by ``'cpu'`` when CUDA is unavailable.
        loss_fn: Callable invoked as ``loss_fn(output=..., target=...)``.
        some_parameter: Extra value forwarded to ``function_defined_in_notebook``.

    Returns:
        dict: ``{'train_loss': mean of the per-batch losses}``.
    """
    if not torch.cuda.is_available():
        device = 'cpu'

    function_defined_in_notebook(some_parameter)

    train_loader = tqdm.tqdm(train_loader, desc="train")
    net_model.train()
    net_model.to(device)

    losses = []

    for data, target in train_loader:
        # as_tensor reuses an existing tensor instead of copying it, which
        # torch.tensor() would do (with a UserWarning) on every batch.
        data = torch.as_tensor(data).to(device)
        target = torch.as_tensor(target).to(device)
        optimizer.zero_grad()
        output = net_model(data)
        loss = loss_fn(output=output, target=target)
        loss.backward()
        optimizer.step()
        losses.append(loss.detach().cpu().numpy())

    return {'train_loss': np.mean(losses)}


@TI.register_fl_task(model='net_model', data_loader='val_loader', device='device')
def validate(net_model, val_loader, device):
    """Evaluate ``net_model`` on ``val_loader`` and return its top-1 accuracy.

    Args:
        net_model: Model to evaluate; put in eval mode and moved to ``device``.
        val_loader: Iterable of ``(data, target)`` batches.
        device: Device to run on; replaced by ``'cpu'`` when CUDA is unavailable.

    Returns:
        dict: ``{'acc': correct predictions / total samples}``.
    """
    # Fall back to CPU so a CUDA device string does not crash a CPU-only node.
    if not torch.cuda.is_available():
        device = 'cpu'

    net_model.eval()
    net_model.to(device)

    val_loader = tqdm.tqdm(val_loader, desc="validate")
    val_score = 0
    total_samples = 0

    with torch.no_grad():
        for data, target in val_loader:
            samples = target.shape[0]
            total_samples += samples
            # as_tensor avoids the extra copy (and warning) of torch.tensor().
            data = torch.as_tensor(data).to(device)
            target = torch.as_tensor(target).to(device, dtype=torch.int64)
            output = net_model(data)
            pred = output.argmax(dim=1)
            # Compare element-wise: give the target the same shape as ``pred``.
            # A (B, 1) prediction against a (B,) target would broadcast to a
            # (B, B) matrix and over-count the matches.
            val_score += pred.eq(target.reshape(pred.shape)).sum().cpu().numpy()

    return {'acc': val_score / total_samples}

## Time to start a federated learning experiment

In [None]:
# Create a federation
from openfl.interface.interactive_api.federation import Federation

# 1) Run with aggregator-collaborator mTLS
# To enable mTLS, the user must provide the CA root chain and a signed key pair to the federation interface.
cert_chain = 'cert/cert_chain.crt'
agg_certificate = 'cert/agg_certificate.crt'
agg_private_key = 'cert/agg_private.key'

federation = Federation(central_node_fqdn='some.fqdn',
                       cert_chain=cert_chain, agg_certificate=agg_certificate, agg_private_key=agg_private_key)
# Maps a collaborator name to its data path string.
col_data_paths = {'one': '1,1',}
federation.register_collaborators(col_data_paths=col_data_paths)

# --------------------------------------------------------------------------------------------------------------------

# 2) Run with TLS disabled (trusted environment)
# NOTE: this rebinds `federation`, replacing the mTLS object created above.
# The Federation can also determine the local FQDN automatically.
federation = Federation(central_node_fqdn='localhost', tls=False)
# The first number, which is the collaborator's rank, may also be passed as a CUDA device identifier.
col_data_paths = {'one': '1,2',
                'two': '2,2'}
federation.register_collaborators(col_data_paths=col_data_paths)

#### Certification of an aggregator
* fx workspace certify: creates the cert folder and the CA, as well as the cert_chain
* fx aggregator generate-cert-request --fqdn `FQDN`: you can pass a specific aggregator FQDN if you want
* fx aggregator certify --fqdn `FQDN` --silent: signs the aggregator's certificate
<br> After that, just pass the paths to the required certs to the Federation API

#### Certification of a collaborator
Just follow the usual procedure: <br>
fx collaborator generate-cert-request -d {DATA_PATH} -n {COL} 

fx collaborator certify --request-pkg {COL_DIRECTORY}/{FED_WORKSPACE}/col_{COL}_to_agg_cert_request.zip

fx collaborator certify --import {FED_DIRECTORY}/agg_to_col_{COL}_signed_cert.zip

In [None]:
# Create an experiment in the federation.
fl_experiment = FLExperiment(federation=federation,)

In [None]:
# NOTE: using autoreload in this notebook was observed to cause a pickling error.

# The following command zips the workspace and Python requirements to be transferred to collaborator nodes.
fl_experiment.prepare_workspace_distribution(model_provider=MI, task_keeper=TI, data_loader=fed_dataset, rounds_to_train=2, \
                              opt_treatment='CONTINUE_GLOBAL')


In [None]:
# This command starts the aggregator server.
fl_experiment.start_experiment(model_provider=MI)

# While the aggregator server blocks the notebook, the collaborators can be started.
# For a test run, type this console command from the workspace directory:
# `fx collaborator start -d data.yaml -n {col_name}` for each collaborator.
# For a distributed experiment, transfer the zipped workspace to the collaborator nodes and run
# `fx workspace import --archive {workspace_name}.zip`, then cd to the workspace and start the collaborators.

## Now we validate the best model!

In [None]:
# Retrieve the best model from the experiment (presumably after training has finished - TODO confirm).
best_model = fl_experiment.get_best_model()

In [None]:
# Initialise the datasets locally; with no argument the default data_path '1,1' is used.
fed_dataset._delayed_init()

In [None]:
# Validate the initial (untrained) model on the CPU as a baseline.
validate(initial_model, fed_dataset.get_valid_loader(), 'cpu')

In [None]:
# Validate the trained model on the CPU for comparison with the baseline.
validate(best_model, fed_dataset.get_valid_loader(), 'cpu')