# Workflow Interface 102: 
# Vision Transformer for Image Classification using MedMNIST
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/intel/openfl/blob/develop/openfl-tutorials/experimental/Vision_Transformer/Workflow_Interface_102_Vision_Transformer.ipynb)

Introduced in the seminal paper "Attention Is All You Need", transformers have revolutionized natural language processing by using self-attention mechanisms to capture global dependencies in textual data. Building on this, Dosovitskiy et al. introduced one of the first successful and empirically validated pure transformer models for image classification in [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929v2).


| <img src="images/vision_transformer.png" width="512"> | 
|:--:| 
| *[source](https://arxiv.org/abs/2010.11929v2)* |

In contrast to traditional convolutional neural networks, which focus on capturing local image features within a spatial window using a sliding filter, the self-attention mechanism enables vision transformers to capture global relationships between image patches.
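To make the idea of global relationships between patches concrete, here is a minimal pure-Python sketch of single-head self-attention. It is an illustration only: it uses identity query/key/value projections (a real ViT learns these), and the function names and toy patch vectors are assumptions, not part of the tutorial's model.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(tokens):
    """Single-head self-attention with identity projections (illustrative only).

    tokens: list of equal-length float vectors, one per image patch.
    Returns (outputs, weights), where weights[i][j] is how strongly
    patch i attends to patch j.
    """
    dim = len(tokens[0])
    outputs, weights = [], []
    for query in tokens:
        # Score the query against every patch, not only its spatial neighbours
        scores = [
            sum(q * k for q, k in zip(query, key)) / math.sqrt(dim)
            for key in tokens
        ]
        row = softmax(scores)
        weights.append(row)
        # Each output vector is a weighted mix of all patch vectors
        outputs.append(
            [sum(w * tok[d] for w, tok in zip(row, tokens)) for d in range(dim)]
        )
    return outputs, weights


# Four toy "patch embeddings"; patches 0 and 3 point in similar directions
patches = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [1.0, 0.1]]
outputs, weights = self_attention(patches)
```

Every row of `weights` covers all patches, so a patch can draw on any other patch regardless of where it sits in the image.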

In this tutorial, you will learn how to set up a horizontal federated learning workflow using the OpenFL Experimental Workflow Interface to train a vision transformer to classify images from the MedMNIST dataset. This notebook expands on the use case from the [first](https://github.com/intel/openfl/blob/develop/openfl-tutorials/experimental/Workflow_Interface_101_MNIST.ipynb) quick start notebook. Its objective is to demonstrate how a user can modify the workflow interface for different use cases.

# Getting Started

First, we install the necessary dependencies for the workflow interface and the vision transformer.

In [None]:
# !pip install git+https://github.com/intel/openfl.git
# !pip install -r requirements_workflow_interface.txt
# !pip install -r requirements_vision_transformer.txt

# Uncomment this if running in Google Colab
#!pip install -r https://raw.githubusercontent.com/intel/openfl/develop/openfl-tutorials/experimental/Vision_Transformer/requirements_workflow_interface.txt
#!pip install -r https://raw.githubusercontent.com/intel/openfl/develop/openfl-tutorials/experimental/Vision_Transformer/requirements_vision_transformer.txt

#import os
#os.environ["USERNAME"] = "colab"
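
After installing, it can help to confirm that the packages this notebook imports are actually available. The following is a small sketch using only the standard library; the package list is an assumption drawn from the imports used later in this notebook, not from the requirements files.

```python
from importlib.util import find_spec

# Top-level packages imported later in this notebook (assumed list)
REQUIRED = ["openfl", "torch", "torchvision", "transformers", "medmnist", "metaflow"]


def missing_packages(names):
    """Return the names that cannot be found in the current environment."""
    return [name for name in names if find_spec(name) is None]


missing = missing_packages(REQUIRED)
print("Missing packages:", missing or "none")
```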

# Setting up the experiment

For those of you who are familiar with a standard deep learning training pipeline, you may recognize that this section demonstrates many familiar steps such as setting up your data and defining your dataloader, model, parameters, helper functions, etc.

We start by importing the [MedMNIST](https://github.com/MedMNIST/MedMNIST/) package and defining our dataset. This cell prints information about the package and lists the available datasets. We will use the PathMNIST dataset, a colon pathology dataset comprising 107,180 unique 2D images. We will train our vision transformer to classify each image as one of 9 classes.

| <img src="images/pathmnist.png" width="1024"> | 
|:--:| 
| *Sample of images [(source)](https://medmnist.com/)* |

Set `data_flag` to choose a different dataset.

In [None]:
# https://github.com/MedMNIST/MedMNIST/blob/main/examples/getting_started.ipynb
import medmnist
from medmnist import INFO, Evaluator

print(f"MedMNIST v{medmnist.__version__} @ {medmnist.HOMEPAGE}")

print('\n---- List of Available datasets ----\n')
for key in INFO:
    print(key)
    
print('\n------------------------------------\n')

data_flag = 'pathmnist'
print(f'Chosen dataset: {data_flag}')

info = INFO[data_flag]
task = info['task']
n_channels = info['n_channels']
n_classes = len(info['label'])

DataClass = getattr(medmnist, info['python_class'])

Next, we load our dataset and prepare it to be consumed by our model. We will use the Hugging Face Transformers library's implementation of the [vision transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit) pretrained on ImageNet-21k as the backbone of our network. To that end, we use `ViTImageProcessor`, which provides the parameters needed to process and transform our dataset.

In [None]:
import os
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from transformers import ViTImageProcessor

import time
import numpy as np

# preprocessing
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')

image_mean, image_std = processor.image_mean, processor.image_std
h = processor.size["height"]
w = processor.size["width"]

train_transforms = transforms.Compose([
    transforms.Resize([h, w]),
    transforms.ToTensor(),
    transforms.Normalize(mean=image_mean, std=image_std)
    ])

test_transforms = transforms.Compose([
    transforms.Resize([h, w]),
    transforms.ToTensor(),
    transforms.Normalize(mean=image_mean, std=image_std)
    ])


# load the data
medmnist_train = DataClass(split='train', transform=train_transforms, download=True)
medmnist_test = DataClass(split='test', transform=test_transforms, download=True)

# For demonstration purposes, we take a subset to reduce overall size and training time
##################
subset_indices = range(320)
medmnist_train = Subset(medmnist_train, subset_indices)
medmnist_test = Subset(medmnist_test, subset_indices)
##################
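
The resize step matters because the transformer sees each image as a sequence of fixed-size patches. As a rough sketch, assuming the 224x224 input size and 16x16 patch size suggested by the checkpoint name `vit-base-patch16-224-in21k`, and three colour channels, the sequence length works out as follows. The helper name is hypothetical.

```python
def vit_sequence_shape(image_size, patch_size, channels):
    """Return (num_patches, tokens_with_cls, flattened_patch_dim)."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    num_patches = per_side * per_side
    # One extra token for the [CLS] embedding that the classifier reads
    return num_patches, num_patches + 1, patch_size * patch_size * channels


num_patches, num_tokens, patch_dim = vit_sequence_shape(224, 16, 3)
print(num_patches, num_tokens, patch_dim)  # 196 197 768
```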

We now define our network and inference function. As previously noted, our network uses a pretrained vision transformer backbone, `ViTModel`. We add a custom classification head, which enables us to fine-tune the model on the chosen PathMNIST dataset.

In [None]:
import torch
import torch.nn as nn
from transformers import ViTModel


class CustomVisionTransformer(nn.Module):
    def __init__(self, num_classes):
        super(CustomVisionTransformer, self).__init__()
        self.backbone = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.classifier = nn.Linear(self.backbone.config.hidden_size, num_classes)

    def forward(self, x):
        # Extract features from the transformer
        features = self.backbone(x)
        # Take the hidden state from the [CLS] token
        cls_token = features.last_hidden_state[:, 0, :]
        # Pass it through the classification head
        logits = self.classifier(cls_token)
        return logits
    
    
def inference(model, test_loader, criterion):
    model.eval()

    correct = 0
    test_loss = 0

    with torch.no_grad():
        for data, labels in test_loader:
            outputs = model(data)
            test_loss += criterion(outputs, labels.flatten())
            
            _, predicted = torch.max(outputs, 1)
            
            correct += (predicted == labels.flatten()).sum().item()
            
    test_loss /= len(test_loader.dataset)

    accuracy = float(correct / len(test_loader.dataset))
    return accuracy
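
The accuracy computation in `inference` can be traced by hand with plain lists. This sketch assumes labels arrive as a column of shape (N, 1), which is why the cell above calls `labels.flatten()`; the helper name and the toy logits are made up for illustration.

```python
def accuracy_from_logits(logits, labels):
    """logits: per-sample lists of class scores; labels: list of [label] rows."""
    flat_labels = [row[0] for row in labels]  # mirrors labels.flatten()
    # Index of the largest score per row, mirrors torch.max(outputs, 1)
    predicted = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predicted, flat_labels))
    return correct / len(flat_labels)


logits = [[0.1, 2.0, -1.0], [1.5, 0.2, 0.3], [0.0, 0.1, 0.9], [0.4, 0.3, 0.2]]
labels = [[1], [0], [1], [0]]
print(accuracy_from_logits(logits, labels))  # 0.75
```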

# Setting up the OpenFL Workflow Interface

We will now set up the experimental OpenFL workflow interface in order to fine-tune our model in a horizontal federated learning framework. We import the `FLSpec`, `LocalRuntime`, and placement decorators.

- `FLSpec` – Defines the flow specification. User defined flows are subclasses of this.
- `Runtime` – Defines where the flow runs, infrastructure for task transitions (how information gets sent). The `LocalRuntime` runs the flow on a single node.
- `aggregator`/`collaborator` – Placement decorators that define where each task is assigned.

In [None]:
from copy import deepcopy

from openfl.experimental.interface import FLSpec, Aggregator, Collaborator
from openfl.experimental.runtime import LocalRuntime
from openfl.experimental.placement import aggregator, collaborator


def FedAvg(models, weights=None):
    new_model = models[0]
    state_dicts = [model.state_dict() for model in models]
    state_dict = new_model.state_dict()
    # Iterate over the first model's keys so this also works with a single collaborator
    for key in state_dict:
        state_dict[key] = torch.from_numpy(np.average([state[key].numpy() for state in state_dicts],
                                                      axis=0, 
                                                      weights=weights))
    new_model.load_state_dict(state_dict)
    return new_model
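
To see what the optional `weights` argument does, here is a pure-Python sketch of the same key-by-key averaging on toy "state dicts" made of plain lists. It is not the notebook's `FedAvg`; the function name, parameter names, and values are illustrative assumptions.

```python
def federated_average(state_dicts, weights=None):
    """Average parameters key by key; weights default to uniform."""
    if weights is None:
        weights = [1.0] * len(state_dicts)
    total = float(sum(weights))
    averaged = {}
    for key in state_dicts[0]:
        length = len(state_dicts[0][key])
        averaged[key] = [
            sum(w * sd[key][i] for w, sd in zip(weights, state_dicts)) / total
            for i in range(length)
        ]
    return averaged


site_a = {"classifier.weight": [1.0, 2.0], "classifier.bias": [0.0]}
site_b = {"classifier.weight": [3.0, 6.0], "classifier.bias": [1.0]}

uniform = federated_average([site_a, site_b])
# Weighting by a hypothetical sample count: site_a has 3x the data of site_b
by_samples = federated_average([site_a, site_b], weights=[3, 1])
print(uniform)     # {'classifier.weight': [2.0, 4.0], 'classifier.bias': [0.5]}
print(by_samples)  # {'classifier.weight': [1.5, 3.0], 'classifier.bias': [0.25]}
```

With uniform weights every collaborator counts equally; passing per-site sample counts pulls the average toward sites that hold more data.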

Now we come to the flow definition. The OpenFL Workflow Interface adopts the Metaflow convention that every workflow begins with a `start` task and concludes with an `end` task. The aggregator begins with a base model and optimizer. It starts the flow with the `start` task, where the list of collaborators is extracted from the runtime (`self.collaborators = self.runtime.collaborators`) and then used as the list of participants that run the task named in `self.next`, `aggregated_model_validation`. The model, optimizer, and anything else not explicitly excluded in the `self.next` call are passed from the `start` task on the aggregator to the `aggregated_model_validation` task on each collaborator. Where a task runs is determined by the placement decorator that precedes its definition (`@aggregator` or `@collaborator`). Once each collaborator (defined in the runtime) completes the `aggregated_model_validation` task, it passes its current state on to the `train` task, from `train` to `local_model_validation`, and finally to `join` at the aggregator. In `join`, the model weights are averaged and the next round can begin. Throughout the process, we save the collaborator models as well as the final aggregated model.

| <img src="images/workflow.png" width="512"> | 
|:--:| 
| *General OpenFL Workflow Interface architecture* |

In [None]:
class FederatedFlow(FLSpec):
    def __init__(self, model, optimizer, criterion, rounds=2, epochs=3, **kwargs):
        super().__init__(**kwargs)
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.rounds = rounds
        self.epochs = epochs

    @aggregator
    def start(self):
        print(f'Performing initialization for model')
        self.collaborators = self.runtime.collaborators
        self.private = 10
        self.current_round = 0
        self.next(self.aggregated_model_validation,foreach='collaborators',exclude=['private'])

    @collaborator
    def aggregated_model_validation(self):
        print(f'Round: {self.current_round+1}\n-------------------------------')
        print(f'Performing aggregated model validation for collaborator {self.input}')
        self.agg_validation_score = inference(self.model, self.test_loader, self.criterion)
        print(f'{self.input} value of {self.agg_validation_score}')
        self.next(self.train)

    @collaborator
    def train(self):
        os.makedirs(os.path.join('weights', f'{self.input}'), exist_ok=True)
            
        best_acc = 0.0
        
        print(f"{self.input}")
        for t in range(self.epochs):
            for phase in ['train', 'val']:
                
                if phase == 'train':
                    self.model.train()
                    self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.5)
                    train_loss = 0.0

                    for batch_idx, (images, labels) in enumerate(self.train_loader):
                        self.optimizer.zero_grad()
                        outputs = self.model(images)

                        loss = self.criterion(outputs, labels.flatten())
                        loss.backward()
                        self.optimizer.step()

                        train_loss += loss.item() * images.size(0)
                        # Normalize by the true sample count rather than batches * last batch size
                        data_size = len(self.train_loader.dataset)
                        
                else:
                    self.local_validation_score = inference(self.model, self.test_loader, self.criterion)
            
            self.loss = train_loss/data_size
            print(f'Epoch {t+1} | Train Loss: {self.loss:.4f} | Local Acc: {self.local_validation_score:.4f}')

            if phase == 'val' and self.local_validation_score > best_acc:
                best_acc = self.local_validation_score
                torch.save(self.model.state_dict(), os.path.join('weights', f'{self.input}','model.pth'))
                torch.save(self.optimizer.state_dict(), os.path.join('weights', f'{self.input}','optimizer.pth'))
                
        self.training_completed = True
        self.next(self.local_model_validation)
        
    @collaborator
    def local_model_validation(self):
        self.local_validation_score = inference(self.model,self.test_loader, self.criterion)
        print(f'Doing local model validation for collaborator {self.input}: {self.local_validation_score}')
        self.next(self.join, exclude=['training_completed'])

    @aggregator
    def join(self,inputs):
        self.average_loss = sum(input.loss for input in inputs)/len(inputs)
        self.aggregated_model_accuracy = sum(input.agg_validation_score for input in inputs)/len(inputs)
        self.local_model_accuracy = sum(input.local_validation_score for input in inputs)/len(inputs)
        print(f'Average aggregated model validation values = {self.aggregated_model_accuracy}')
        print(f'Average training loss = {self.average_loss}')
        print(f'Average local model validation values = {self.local_model_accuracy}')
        self.model = FedAvg([input.model for input in inputs])
        self.optimizer = [input.optimizer for input in inputs][0]
        
        torch.save(self.model.state_dict(), os.path.join('weights', 'aggregated_model.pth'))
        torch.save(self.optimizer.state_dict(), os.path.join('weights', 'aggregated_optimizer.pth'))
        
        self.current_round += 1
        if self.current_round < self.rounds:
            self.next(self.aggregated_model_validation, foreach='collaborators', exclude=['private'])
        else:
            self.next(self.end)
        
    @aggregator
    def end(self):
        print(f'This is the end of the flow')  
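
The order in which tasks are visited can be sketched without OpenFL at all. The following plain-Python trace only mirrors the `self.next` transitions in `FederatedFlow`; the function name is hypothetical, and the real runtime may interleave collaborators differently within a round.

```python
def trace_flow(collaborators, rounds):
    """Return the ordered list of (placement, task, participant) visits."""
    trace = [("aggregator", "start", "aggregator")]
    for _ in range(rounds):
        # foreach='collaborators' fans each collaborator task out to every participant
        for task in ("aggregated_model_validation", "train", "local_model_validation"):
            for name in collaborators:
                trace.append(("collaborator", task, name))
        # join runs once per round on the aggregator and averages the models
        trace.append(("aggregator", "join", "aggregator"))
    trace.append(("aggregator", "end", "aggregator"))
    return trace


names = ["Portland", "Seattle", "Chandler", "Bangalore"]
trace = trace_flow(names, rounds=2)
print(len(trace))                                       # 28
print(sum(1 for _, task, _ in trace if task == "train"))  # 8
```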

You'll notice in the `FederatedFlow` definition above that there were certain attributes that the flow was not initialized with, namely the `train_loader` and `test_loader` for each of the collaborators. These are **private_attributes** that are exposed only through the runtime. Each participant has its own set of private attributes: a dictionary where the key is the attribute name, and the value is the object that will be made accessible through that participant's task. 

Below, we segment shards of the PathMNIST dataset for **four collaborators**: Portland, Seattle, Chandler, and Bangalore. Each has its own slice of the dataset that is accessible via the `train_loader` or `test_loader` attribute. Note that the private attributes are flexible, and you can choose to pass in a completely different type of object to any of the collaborators or the aggregator (with an arbitrary name). These private attributes will always be filtered out of the current state when transferring from collaborator to aggregator, or vice versa.
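
The filtering rule can be pictured with plain dictionaries. This is a conceptual sketch only, not OpenFL's implementation: `state_to_send` is a hypothetical helper, and the state values are placeholders.

```python
def state_to_send(state, private_attribute_names, exclude=()):
    """Return the part of a participant's state that travels to the next task."""
    blocked = set(private_attribute_names) | set(exclude)
    return {name: value for name, value in state.items() if name not in blocked}


collaborator_state = {
    "model": "<model weights>",
    "loss": 0.42,
    "training_completed": True,
    "train_loader": "<local DataLoader>",
    "test_loader": "<local DataLoader>",
}

sent = state_to_send(
    collaborator_state,
    private_attribute_names=["train_loader", "test_loader"],
    exclude=["training_completed"],  # as in self.next(self.join, exclude=[...])
)
print(sorted(sent))  # ['loss', 'model']
```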

You'll see that, for the sake of this demonstration, we simply sample an even amount of data from our main dataset and assign it to each collaborator. It is also here that we define `BATCH_SIZE`.

In [None]:
BATCH_SIZE = 8

# Setup participants
aggregator = Aggregator()
aggregator.private_attributes = {}

# Setup collaborators with private attributes
collaborator_names = ['Portland', 'Seattle', 'Chandler','Bangalore']
collaborators = [Collaborator(name=name) for name in collaborator_names]

for idx, collaborator in enumerate(collaborators):
    train_subset_indices = np.array(range(idx,len(medmnist_train),len(collaborators)))
    local_train = Subset(medmnist_train, train_subset_indices)
    
    test_subset_indices = np.array(range(idx,len(medmnist_test),len(collaborators)))
    local_test = Subset(medmnist_test, test_subset_indices)
    collaborator.private_attributes = {
            'train_loader': DataLoader(dataset=local_train, batch_size=BATCH_SIZE, shuffle=True),
            'test_loader': DataLoader(dataset=local_test, batch_size=BATCH_SIZE, shuffle=True)
    }

local_runtime = LocalRuntime(aggregator=aggregator, collaborators=collaborators, backend='single_process')
print(f'Local runtime collaborators = {local_runtime.collaborators}')
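
The interleaved `range(idx, n, k)` split used above can be checked in isolation. This sketch uses the 320-sample subset and four collaborators from this notebook; the helper name is an assumption.

```python
def shard_indices(num_samples, num_collaborators):
    """Interleaved split: collaborator idx gets idx, idx + k, idx + 2k, ..."""
    return [
        list(range(idx, num_samples, num_collaborators))
        for idx in range(num_collaborators)
    ]


shards = shard_indices(320, 4)
print([len(s) for s in shards])  # [80, 80, 80, 80]
print(shards[1][:3])             # [1, 5, 9]

# With BATCH_SIZE = 8, each collaborator sees 80 / 8 = 10 batches per epoch
batches_per_epoch = len(shards[0]) // 8
```

The shards are disjoint and together cover every index, so no sample is shared between collaborators.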

Now that we have our flow and runtime defined, let's run the experiment!

We will begin by defining a base model, optimizer, and loss function that will be used by each collaborator. You may also set the number of rounds and epochs here if you do not wish to use the default values.

In [None]:
model = CustomVisionTransformer(n_classes)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
criterion = nn.CrossEntropyLoss()

flflow = FederatedFlow(model=model, optimizer=optimizer, criterion=criterion)
flflow.runtime = local_runtime
flflow.run()

Now that the flow has completed, let's get the final model and accuracy. Note that the reported aggregated model accuracy was measured at the start of the final round, before that round's training and aggregation. The saved model, however, is the final aggregated model produced in the last `join` task.

In [None]:
print(f'Sample of the final model weights: {flflow.model.state_dict()["classifier.weight"][0]}')

print(f'\nFinal aggregated model accuracy for {flflow.rounds} rounds of training: {flflow.aggregated_model_accuracy}')

We can get the final model and all other aggregator attributes after the flow completes. But what if we want to inspect an intermediate task and its specific output in detail? This is where **checkpointing** and reuse of Metaflow tooling come in handy.

Let's make a tweak to the flow object and run the experiment one more time (we can even use our previous model and optimizer as a base for the experiment).

In [None]:
flflow2 = FederatedFlow(model=flflow.model, optimizer=flflow.optimizer, criterion=flflow.criterion, 
                        checkpoint=True)

flflow2.runtime = local_runtime
flflow2.run()

Now that the flow is complete, let's dig into some of the information captured along the way. 

**Note:** this required `checkpoint=True` to be set

In [None]:
run_id = flflow2._run_id

In [None]:
import metaflow
from metaflow import Metaflow, Flow, Task, Step

In [None]:
m = Metaflow()
list(m)

Let's look at the latest run that generated some results:

In [None]:
f = Flow('FederatedFlow').latest_run
list(f)

And the list of tasks in its `train` step:

In [None]:
s = Step(f'FederatedFlow/{run_id}/train')
list(s)

Now we see **4x** tasks: **4** collaborators each performed **x** rounds of model training (8 tasks with the default `rounds=2`).
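
The task count and the path strings passed to `Step` and `Task` follow a simple pattern, sketched here in plain Python. The helper names and the run id are made-up placeholders; in the notebook the real run id comes from `flflow2._run_id`.

```python
def expected_task_count(num_collaborators, rounds):
    """One train task per collaborator per round."""
    return num_collaborators * rounds


def task_pathspec(flow_name, run_id, step_name, task_id):
    """Build a Metaflow-style pathspec: FlowName/run_id/step_name/task_id."""
    return f"{flow_name}/{run_id}/{step_name}/{task_id}"


print(expected_task_count(4, 2))  # 8
# "1700000000" is a placeholder run id used only for illustration
print(task_pathspec("FederatedFlow", "1700000000", "train", 3))
```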

In [None]:
t = Task(f'FederatedFlow/{run_id}/train/3')
t

In [None]:
t.data

In [None]:
t.data.model

Now let's look at its log output (stdout)

In [None]:
print(t.stdout)

And any error logs? (stderr)

In [None]:
print(t.stderr)