# Vision Transformer (ViT) + OpenFL for Dogs & Cats Classification

In [None]:
# Install dependencies if not already installed
!pip install -r requirements.txt

## Import Libraries

In [None]:
import os
import random
from copy import deepcopy

from linformer import Linformer

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
from torch.optim.lr_scheduler import StepLR

from torchvision import transforms

import tqdm

from vit_pytorch.efficient import ViT

# Connect to the Federation

In [None]:
# Create a federation
from openfl.interface.interactive_api.federation import Federation

# Please use the same identifier (client_id) that was used in the signed certificate.
client_id = 'api'
director_node_fqdn = 'localhost'
director_port = 50051

# 1) Run with API layer - Director mTLS
# To enable mTLS, the user must provide the CA root chain and a signed
# key pair to the Federation interface. Uncomment the lines below to use it.
# cert_chain = 'cert/root_ca.crt'
# api_certificate = 'cert/frontend.crt'
# api_private_key = 'cert/frontend.key'

# federation = Federation(
#     client_id=client_id,
#     director_node_fqdn=director_node_fqdn,
#     director_port=director_port,
#     tls=True,
#     cert_chain=cert_chain,
#     api_cert=api_certificate,
#     api_private_key=api_private_key
# )

# --------------------------------------------------------------------------------------------------------------------

# 2) Run with TLS disabled (only in a trusted environment: traffic is not encrypted)
# Federation can also determine the local FQDN automatically
federation = Federation(
    client_id=client_id,
    director_node_fqdn=director_node_fqdn,
    director_port=director_port,
    tls=False
)

In [None]:
# Ask the director for its shard registry (presumably one entry per connected
# envoy -- verify); the bare name on the last line echoes it in the notebook.
shard_registry = federation.get_shard_registry()
shard_registry

In [None]:
federation.target_shape  # echo the target shape reported by the federation

In [None]:
# First, request a dummy shard descriptor that holds information about the
# federated dataset (size=10 presumably sets the number of dummy samples -- confirm),
# then pull one training sample to inspect its shapes.
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)
dummy_shard_dataset = dummy_shard_desc.get_dataset('train')
sample, target = dummy_shard_dataset[0]
# The last expression of the cell is echoed by the notebook.
f"Sample shape: {sample.shape}, target shape: {target.shape}"

## Creating a FL experiment using Interactive API

In [None]:
from openfl.interface.interactive_api.experiment import DataInterface, FLExperiment, ModelInterface, TaskInterface

In [None]:
# Training settings: batch size for both data loaders, Adam learning rate, RNG seed.
batch_size, lr, seed = 64, 3e-5, 42

In [None]:
def seed_everything(seed):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and every CUDA device),
    records the seed in ``PYTHONHASHSEED`` and puts cuDNN into deterministic
    mode so repeated runs produce the same results.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    # PYTHONHASHSEED only affects Python interpreters started after this call
    # (fresh child processes); the running interpreter's hash seed is fixed.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every visible GPU, including the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN benchmark mode autotunes algorithms per run and may pick
    # non-deterministic kernels, which would defeat the flag above.
    torch.backends.cudnn.benchmark = False


seed_everything(seed)  # seed all RNGs once, before the model and data loaders are built

### Register dataset

In [None]:
class DogsCatsShardDataset(Dataset):
    """Wrap a raw shard dataset with torchvision preprocessing.

    Each item of ``dataset`` is assumed to be an ``(img, label)`` pair where
    ``img`` is accepted by ``transforms.ToPILImage`` and ``label`` is
    indexable, its first element being the class id -- TODO confirm against
    the shard descriptor.

    Args:
        dataset: Sized, indexable dataset (e.g. ``shard_descriptor.get_dataset('train')``).
        transform_type: ``"train"`` for random resized crop + horizontal flip
            augmentation, or ``"val"`` / ``"test"`` for a deterministic
            resize + center crop.

    Raises:
        ValueError: If ``transform_type`` is not one of the values above.
    """

    def __init__(self, dataset, transform_type="train"):
        self._dataset = dataset

        # Image augmentation
        if transform_type == "train":
            self.transform = transforms.Compose(
                [
                    transforms.ToPILImage(),
                    transforms.Resize((224, 224)),
                    transforms.RandomResizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                ]
            )
        elif transform_type in ("val", "test"):
            # Validation and test share one deterministic pipeline.
            self.transform = transforms.Compose(
                [
                    transforms.ToPILImage(),
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                ]
            )
        else:
            raise ValueError("Invalid transform type: {}".format(transform_type))

    @property
    def filelength(self):
        """Number of samples in the wrapped dataset (kept for backward compatibility)."""
        return len(self._dataset)

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, idx):
        """Return ``(image_array, class_id)`` for sample ``idx``.

        The image goes through ``self.transform`` and is returned as a NumPy
        array (``ToTensor`` yields a (C, 224, 224) float image for the built-in
        pipelines); the class id is ``label[0]``.
        """
        img, label = self._dataset[idx]
        img_transformed = self.transform(img).numpy()
        return img_transformed, label[0]


# Now you can implement your data loaders using dummy_shard_desc
class DogsCatsSD(DataInterface):
    """OpenFL data interface splitting each collaborator's shard into train/validation.

    Constructor keyword arguments are forwarded to ``DataInterface`` and read
    back through ``self.kwargs``: ``train_bs`` and ``valid_bs`` are required
    batch sizes; ``num_workers`` (optional, default 8) sets the number of
    DataLoader worker processes.

    Args:
        validation_fraction: Fraction of the shard held out for validation.
            The held-out samples are the last ones in shard order.
    """

    def __init__(self, validation_fraction=1/5, **kwargs):
        super().__init__(**kwargs)

        self.validation_fraction = validation_fraction

    @property
    def shard_descriptor(self):
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Describe per-collaborator procedures or sharding.

        This method will be called during a collaborator initialization.
        The local shard_descriptor will be set by the Envoy.
        """
        self._shard_descriptor = shard_descriptor
        raw_dataset = shard_descriptor.get_dataset('train')
        # Two views of the same raw samples: augmented for training and
        # deterministic for validation, so validation metrics are not
        # perturbed by random crops/flips.
        self._shard_dataset = DogsCatsShardDataset(raw_dataset)
        self._shard_val_dataset = DogsCatsShardDataset(raw_dataset, transform_type="val")

        self.train_indexes, self.val_indexes = self._split_indexes(
            len(self._shard_dataset), self.validation_fraction
        )

    @staticmethod
    def _split_indexes(n_samples, validation_fraction):
        """Return (train, val) index arrays over range(n_samples); val is the tail."""
        # Hold out at least one sample, but never more than the shard has
        # (an empty shard used to yield the out-of-range index -1).
        validation_size = min(n_samples, max(1, int(n_samples * validation_fraction)))
        split = n_samples - validation_size
        return np.arange(split), np.arange(split, n_samples)

    def _make_loader(self, dataset, indexes, batch_size):
        """Build a DataLoader drawing ``indexes`` of ``dataset`` in random order."""
        return DataLoader(
            dataset,
            num_workers=self.kwargs.get('num_workers', 8),
            batch_size=batch_size,
            sampler=SubsetRandomSampler(indexes),
        )

    def get_train_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks with optimizer in contract
        """
        return self._make_loader(self._shard_dataset, self.train_indexes, self.kwargs['train_bs'])

    def get_valid_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks without optimizer in contract
        """
        return self._make_loader(self._shard_val_dataset, self.val_indexes, self.kwargs['valid_bs'])

    def get_train_data_size(self):
        """
        Information for aggregation: number of training samples in this shard.
        """
        return len(self.train_indexes)

    def get_valid_data_size(self):
        """
        Information for aggregation: number of validation samples in this shard.
        """
        return len(self.val_indexes)

In [None]:
# Sanity check: attach the dummy shard descriptor locally and print the shape
# of every training batch. (During the experiment each Envoy sets its own
# shard descriptor on this object.)
fed_dataset = DogsCatsSD(train_bs=batch_size, valid_bs=batch_size)
fed_dataset.shard_descriptor = dummy_shard_desc
for sample, target in fed_dataset.get_train_loader():
    print(sample.shape)

### Describe a model and optimizer

#### Linformer

In [None]:
# Linformer encoder passed to the ViT below as its transformer.
# seq_len has to equal the number of patch tokens plus one class token:
# (image_size // patch_size) ** 2 + 1 = (224 // 32) ** 2 + 1 = 50 for that ViT,
# so keep it in sync if image_size or patch_size change. dim=128 matches the
# ViT's dim (presumably required so its patch embeddings fit -- confirm).
efficient_transformer = Linformer(
    dim=128,
    seq_len=49 + 1,  # 7x7 patches + 1 cls-token
    depth=12,
    heads=8,
    k=64
)

#### Vision Transformer (ViT)

In [None]:
# ViT classifier: 224x224 images with 3 channels cut into 32x32 patches
# ((224 // 32) ** 2 = 49 patches), 2 output classes (dogs vs. cats).
# Attention is provided by `efficient_transformer`; its dim and seq_len
# (49 + 1 class token) presumably must agree with these settings.
model = ViT(
    dim=128,
    image_size=224,
    patch_size=32,
    num_classes=2,
    transformer=efficient_transformer,
    channels=3,
)

In [None]:
# Loss function: cross-entropy over the model's class logits; the train and
# validate tasks below use it.
criterion = nn.CrossEntropyLoss()
# Optimizer: Adam over all model parameters with the learning rate `lr` set above.
optimizer = optim.Adam(model.parameters(), lr=lr)

#### Register model

In [None]:
# Import path of OpenFL's PyTorch framework adapter plugin.
framework_adapter = 'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'
# Bundle the model, its optimizer and the adapter for the experiment.
MI = ModelInterface(model=model, optimizer=optimizer, framework_plugin=framework_adapter)

# Save the initial model state as an independent deep copy, unaffected by
# later in-place changes to `model` (not referenced again in the visible cells).
initial_model = deepcopy(model)

### Define and register FL tasks

In [None]:
from openfl.interface.aggregation_functions import Median


TI = TaskInterface()  # registry for the FL tasks (train / validate) decorated below


# The Interactive API supports registering functions defined in the main module or imported from other modules.
def function_defined_in_notebook(some_parameter):
    """Demo helper called from a task: print the value it received."""
    message = 'Also I accept a parameter and it is {}'.format(some_parameter)
    print(message)


# The Interactive API supports overriding the aggregation function; the train
# task below is aggregated with Median instead of the framework default.
aggregation_function = Median()


# Task interface currently supports only standalone functions.
@TI.add_kwargs(**{'some_parameter': 42})
@TI.register_fl_task(model='model', data_loader='train_loader',
                     device='device', optimizer='optimizer', round_num='round_num')
@TI.set_aggregation_function(aggregation_function)
def train(model, train_loader, optimizer, round_num, device, loss_fn=criterion, some_parameter=None):
    """Run one local training epoch and report sample-weighted metrics.

    Args:
        model: Network trained in place.
        train_loader: Iterable of ``(data, target)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        round_num: Current federated round; positions the LR schedule.
        device: Device the model and batches are moved to.
        loss_fn: Loss function; defaults to the notebook's ``criterion``.
            Assumed to use mean reduction (as ``nn.CrossEntropyLoss`` does by default).
        some_parameter: Demo keyword injected via ``TI.add_kwargs``.

    Returns:
        dict with ``'loss'`` and ``'accuracy'`` averaged over every sample seen
        in this epoch (both 0.0 if the loader yields nothing).
    """
    function_defined_in_notebook(some_parameter)

    # Careful: constructing the scheduler already performs one step(). Therefore:
    # * with one epoch per round, do NOT call scheduler.step() at all;
    # * with several epochs per round, call scheduler.step() after every epoch
    #   EXCEPT the last one.
    # (`verbose` is deprecated since PyTorch 2.2 but is still accepted.)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.1, verbose=True, last_epoch=round_num-1)
    train_loader = tqdm.tqdm(train_loader, desc="train")
    model.train()
    model.to(device)

    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    for data, target in train_loader:
        # as_tensor reuses tensors the DataLoader already produced instead of
        # copying them (torch.tensor() copies and warns on tensor input).
        data = torch.as_tensor(data).to(device)
        target = torch.as_tensor(target).to(device, dtype=torch.long)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch is not over-counted.
        n = target.size(0)
        total_loss += loss.item() * n
        total_correct += (output.argmax(dim=1) == target).sum().item()
        total_samples += n

    if total_samples == 0:
        return {'loss': 0.0, 'accuracy': 0.0}
    return {'loss': total_loss / total_samples, 'accuracy': total_correct / total_samples}


@TI.register_fl_task(model='model', data_loader='val_loader', device='device')
def validate(model, val_loader, device, loss_fn=criterion):
    """Evaluate ``model`` on the validation loader without updating it.

    Args:
        model: Network to evaluate.
        val_loader: Iterable of ``(data, target)`` batches.
        device: Device the model and batches are moved to.
        loss_fn: Loss function; defaults to the notebook's ``criterion``
            (same default as ``train``). Assumed to use mean reduction.

    Returns:
        dict with ``'val_loss'`` and ``'val_accuracy'`` averaged over every
        validation sample (both 0.0 if the loader yields nothing).
    """
    model.eval()
    model.to(device)

    val_loader = tqdm.tqdm(val_loader, desc="validate")

    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for data, target in val_loader:
            # as_tensor avoids copying tensors the DataLoader already produced.
            data = torch.as_tensor(data).to(device)
            target = torch.as_tensor(target).to(device, dtype=torch.long)
            val_output = model(data)
            val_loss = loss_fn(val_output, target)

            # Weight by batch size so a smaller final batch is not over-counted.
            n = target.size(0)
            total_loss += val_loss.item() * n
            total_correct += (val_output.argmax(dim=1) == target).sum().item()
            total_samples += n

    if total_samples == 0:
        return {'val_loss': 0.0, 'val_accuracy': 0.0}
    return {'val_loss': total_loss / total_samples, 'val_accuracy': total_correct / total_samples}

## Time to start a federated learning experiment

In [None]:
# Create an experiment in the federation
experiment_name = 'ViT_DogsCats_experiment'
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [None]:
# The following command zips the workspace and Python requirements to be transferred to collaborator nodes,
# then starts the experiment with the model, tasks and data interface defined above.
# NOTE(review): 'CONTINUE_GLOBAL' presumably carries the aggregated optimizer state across rounds --
# confirm against the OpenFL documentation.
fl_experiment.start(model_provider=MI,
                    task_keeper=TI,
                    data_loader=fed_dataset,
                    rounds_to_train=5,
                    opt_treatment='CONTINUE_GLOBAL',
                    device_assignment_policy='CUDA_PREFERRED')

In [None]:
# If you stop the IPython session and later reconnect, uncomment the line below
# to restore the experiment state before checking how the experiment is going.
# fl_experiment.restore_experiment_state(MI)

fl_experiment.stream_metrics()