# Federated MedMNIST3D 

In [1]:
# Import dependencies (install them beforehand if they are not already installed)
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
import torch.nn.functional as F

import medmnist

## Connect to the Federation

In [2]:
# Create a federation
from openfl.interface.interactive_api.federation import Federation

# Please use the same identifier that was used in the signed certificate
client_id = 'api'
# Only referenced by the commented-out mTLS option below
cert_dir = 'cert'
director_node_fqdn = 'localhost'
director_port=50051
# 1) Run with API layer - Director mTLS
# If the user wants to enable mTLS, they must provide the CA root chain and a signed key pair to the federation interface
# cert_chain = f'{cert_dir}/root_ca.crt'
# api_certificate = f'{cert_dir}/{client_id}.crt'
# api_private_key = f'{cert_dir}/{client_id}.key'

# federation = Federation(
#     client_id=client_id,
#     director_node_fqdn=director_node_fqdn,
#     director_port=director_port,
#     cert_chain=cert_chain,
#     api_cert=api_certificate,
#     api_private_key=api_private_key
# )

# --------------------------------------------------------------------------------------------------------------------

# 2) Run with TLS disabled (trusted environment)
# Federation can also determine local fqdn automatically
federation = Federation(
    client_id=client_id,
    director_node_fqdn=director_node_fqdn,
    director_port=director_port, 
    tls=False
)


In [3]:
# Fetch the shard registry from the federation object created above
shard_registry = federation.get_shard_registry()
# Bare expression so the registry is rendered as the cell output
shard_registry

{'one': {'shard_info': node_info {
    name: "one"
  }
  shard_description: "MedMNIST dataset, shard number 1 out of 1"
  sample_shape: "28"
  sample_shape: "28"
  sample_shape: "28"
  target_shape: "1"
  target_shape: "1",
  'is_online': True,
  'is_experiment_running': False,
  'last_updated': '2022-05-25 22:50:19',
  'current_time': '2022-05-25 22:50:26',
  'valid_duration': seconds: 120,
  'experiment_name': 'ExperimentName Mock'}}

In [4]:
# First, request a dummy_shard_desc that holds information about the federated dataset
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)
dummy_shard_dataset = dummy_shard_desc.get_dataset('train')
# Inspect the first (sample, target) pair; the shapes presumably mirror the
# sample_shape / target_shape entries of the shard registry — confirm
sample, target = dummy_shard_dataset[0]
f"Sample shape: {sample.shape}, target shape: {target.shape}"

'Sample shape: (28, 28, 28), target shape: (1, 1)'

## Describing FL experiment

In [5]:
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment

## Load MedMNIST INFO

In [6]:
from medmnist import INFO, Evaluator

# Name of the MedMNIST subset; used below as the key into INFO
data_flag = 'synapsemnist3d'
#data_flag = 'bloodmnist'
download = True

# NOTE(review): download, num_epochs, batch_size, gamma and milestones are not
# referenced by any other cell visible in this notebook — confirm before relying on them
num_epochs = 3
batch_size = 128

lr = 0.001
gamma=0.1
# Evaluates to [1.5, 2.25] for num_epochs = 3, i.e. floats rather than integer epoch indices
milestones = [0.5 * num_epochs, 0.75 * num_epochs]

# Metadata dictionary of the selected subset (displayed in a later cell)
info = INFO[data_flag]
task = info['task']
n_channels = info['n_channels']
# Number of classes = number of entries in the label mapping
n_classes = len(info['label'])

In [7]:
# Use the first CUDA device when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [8]:
info

{'python_class': 'SynapseMNIST3D',
 'description': 'The SynapseMNIST3D is a new 3D volume dataset to classify whether a synapse is excitatory or inhibitory. It uses a 3D image volume of an adult rat acquired by a multi-beam scanning electron microscope. The original data is of the size 100×100×100um^3 and the resolution 8×8×30nm^3, where a (30um)^3 sub-volume was used in the MitoEM dataset with dense 3D mitochondria instance segmentation labels. Three neuroscience experts segment a pyramidal neuron within the whole volume and proofread all the synapses on this neuron with excitatory/inhibitory labels. For each labeled synaptic location, we crop a 3D volume of 1024×1024×1024nm^3 and resize it into 28×28×28 voxels. Finally, the dataset is randomly split with a ratio of 7:1:2 into training, validation and test set.',
 'url': 'https://zenodo.org/record/5208230/files/synapsemnist3d.npz?download=1',
 'MD5': '1235b78a3cd6280881dd7850a78eadb6',
 'task': 'binary-class',
 'label': {'0': 'inhibit

### Register dataset

In [9]:
from wspace_utils.utils import Transform3D, model_to_syncbn

# Selects whether the Transform3D instances are built with a `mul` argument.
shape_transform = False

if shape_transform:
    train_transform = Transform3D(mul='random')
    eval_transform = Transform3D(mul='0.5')
else:
    train_transform = Transform3D()
    eval_transform = Transform3D()


In [10]:
from PIL import Image

class TransformedDataset(Dataset):
    """Wrap an indexable dataset of ``(img, label)`` pairs for volumetric training.

    For every item the image is divided by 255 and stacked along a new leading
    channel axis (3 copies when ``as_rgb`` is true, otherwise 1), then passed
    through ``transform`` if one was given. The label is passed through
    ``target_transform`` if one was given; otherwise it is cast with
    ``label.astype(int)``, so it must be a NumPy-like object in that case.
    """

    def __init__(self, dataset, transform=None, target_transform=None, as_rgb=False):
        """Store the wrapped dataset and the optional transforms.

        Args:
            dataset: Indexable object with ``__len__`` whose items unpack
                into ``(img, label)``.
            transform: Optional callable applied to the stacked image.
            target_transform: Optional callable applied to the label.
            as_rgb: When true, the image is repeated three times along the
                channel axis instead of once.
        """
        self.dataset = dataset
        self.transform = transform
        self.target_transform = target_transform
        self.as_rgb = as_rgb

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the transformed ``(img, label)`` pair stored at ``index``."""
        img, label = self.dataset[index]

        # The target transform is applied exactly once; without one the label
        # is only cast to an integer dtype.
        if self.target_transform is not None:
            label = self.target_transform(label)
        else:
            label = label.astype(int)

        # Add the channel axis: one copy for single-channel input, three for RGB.
        # Dividing by 255 presumably maps 8-bit intensities to [0, 1] — confirm
        # against the shard descriptor.
        channels = 3 if self.as_rgb else 1
        img = np.stack([img / 255.] * channels, axis=0)

        if self.transform is not None:
            img = self.transform(img)

        return img, label


In [11]:
class MedMnistFedDataset(DataInterface):
    """Data interface that builds train/validation loaders from a shard descriptor."""

    def __init__(self, **kwargs):
        """Keep the keyword arguments; the loaders read 'train_bs' and 'valid_bs' from them."""
        self.kwargs = kwargs

    @property
    def shard_descriptor(self):
        """Return the shard descriptor most recently assigned to this interface."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Store the descriptor and wrap its 'train' and 'val' splits.

        This method will be called during a collaborator initialization.
        Local shard_descriptor will be set by Envoy.
        """
        self._shard_descriptor = shard_descriptor

        # Training split, wrapped with the training-time transform.
        train_split = shard_descriptor.get_dataset('train')
        self.train_set = TransformedDataset(
            train_split, transform=train_transform, as_rgb=False)

        # Validation split, wrapped with the evaluation-time transform.
        val_split = shard_descriptor.get_dataset('val')
        self.valid_set = TransformedDataset(
            val_split, transform=eval_transform, as_rgb=False)

    def get_train_loader(self, **kwargs):
        """
        Build the shuffled training loader.

        Output of this method will be provided to tasks with optimizer in contract
        """
        loader_options = dict(
            num_workers=8, batch_size=self.kwargs['train_bs'], shuffle=True)
        return DataLoader(self.train_set, **loader_options)

    def get_valid_loader(self, **kwargs):
        """
        Build the validation loader.

        Output of this method will be provided to tasks without optimizer in contract
        """
        loader_options = dict(num_workers=8, batch_size=self.kwargs['valid_bs'])
        return DataLoader(self.valid_set, **loader_options)

    def get_train_data_size(self):
        """
        Number of training samples (information for aggregation).
        """
        train_size = len(self.train_set)
        return train_size

    def get_valid_data_size(self):
        """
        Number of validation samples (information for aggregation).
        """
        valid_size = len(self.valid_set)
        return valid_size
    

### Create MedMNIST federated dataset

In [12]:
# The keyword arguments are stored on the instance; 'train_bs' / 'valid_bs' are
# read back as the batch sizes of the train / validation loaders
fed_dataset = MedMnistFedDataset(train_bs=64, valid_bs=512)

In [13]:
# Smoke test: attach the dummy shard descriptor and iterate over the train loader
fed_dataset.shard_descriptor = dummy_shard_desc
for i, (sample, target) in enumerate(fed_dataset.get_train_loader()):
    # Convert the batch to a NumPy array; `sample` is inspected again in the next cell
    sample = (np.array(sample))
    print(sample.shape, target.shape)

(10, 1, 28, 28, 28) torch.Size([10, 1, 1])


In [14]:
# `sample` is the NumPy array left over from the last iteration of the loop in the previous cell
print(f"dtype = {sample.dtype}, type = {type(sample)}")


dtype = float32, type = <class 'numpy.ndarray'>


## Describe the model and optimizer

In [15]:
class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions, each followed by batch normalisation.

    The shortcut is an empty ``nn.Sequential`` (identity) unless the stride is
    not 1 or the channel count changes; in that case it is a 1x1 convolution
    followed by batch normalisation.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Alternative normalisation noted by the author for every norm layer:
        # nn.GroupNorm(num_groups=2, num_channels=<channels>)
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride != 1 or in_planes != out_planes:
            # Project the input so it can be added to the main branch.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply conv-bn-relu, conv-bn, add the shortcut, then a final ReLU."""
        out = self.conv1(x)
        out = F.relu(self.bn1(out))
        out = self.conv2(out)
        out = self.bn2(out)
        out += self.shortcut(x)
        return F.relu(out)


class Bottleneck(nn.Module):
    """Residual block with a 1x1 -> 3x3 -> 1x1 convolution stack.

    The last convolution outputs ``expansion * planes`` channels. The shortcut
    is an empty ``nn.Sequential`` (identity) unless the stride is not 1 or the
    input channel count differs from ``expansion * planes``; then it is a 1x1
    convolution followed by batch normalisation.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Alternative normalisation noted by the author for every norm layer:
        # nn.GroupNorm(num_groups=2, num_channels=<channels>)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        # Only the 3x3 convolution carries the block's stride.
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        if stride != 1 or in_planes != out_planes:
            # Project the input so it can be added to the main branch.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the three conv-bn stages, add the shortcut, then a final ReLU."""
        out = self.conv1(x)
        out = F.relu(self.bn1(out))
        out = self.conv2(out)
        out = F.relu(self.bn2(out))
        out = self.conv3(out)
        out = self.bn3(out)
        out += self.shortcut(x)
        return F.relu(out)


class ResNet(nn.Module):
    """Residual network: a 3x3 stem, four stages of blocks, pooling and a linear head.

    Args:
        block: Block class exposing an ``expansion`` attribute and the
            constructor signature ``block(in_planes, planes, stride)``.
        num_blocks: Sequence with the number of blocks for each of the four stages.
        in_channels: Channel count of the input.
        num_classes: Size of the output of the final linear layer.
    """

    def __init__(self, block, num_blocks, in_channels=1, num_classes=2):
        super().__init__()
        # Running channel count; updated by _make_layer as stages are built.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(
            in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        # Alternative noted by the author: nn.GroupNorm(num_groups=2, num_channels=64)
        self.bn1 = nn.BatchNorm2d(64)

        depth1, depth2, depth3, depth4 = num_blocks[0], num_blocks[1], num_blocks[2], num_blocks[3]
        self.layer1 = self._make_layer(block, 64, depth1, stride=1)
        self.layer2 = self._make_layer(block, 128, depth2, stride=2)
        self.layer3 = self._make_layer(block, 256, depth3, stride=2)
        self.layer4 = self._make_layer(block, 512, depth4, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use 1."""
        block_strides = [stride, *([1] * (num_blocks - 1))]
        stage = []
        for block_stride in block_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            # The next block consumes what this one produced.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Run the stem, the four stages, global average pooling and the linear head."""
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        out = self.avgpool(out)
        # Flatten everything but the batch axis before the linear layer.
        out = out.view(out.size(0), -1)
        return self.linear(out)


def ResNet18(in_channels, num_classes):
    """Build a ResNet of BasicBlock units with [2, 2, 2, 2] blocks per stage."""
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(
        BasicBlock,
        blocks_per_stage,
        in_channels=in_channels,
        num_classes=num_classes,
    )


def ResNet50(in_channels, num_classes):
    """Build a ResNet of Bottleneck units with [3, 4, 6, 3] blocks per stage."""
    blocks_per_stage = [3, 4, 6, 3]
    return ResNet(
        Bottleneck,
        blocks_per_stage,
        in_channels=in_channels,
        num_classes=num_classes,
    )

In [16]:
from acsconv.converters import ACSConverter, Conv3dConverter, Conv2_5dConverter

# Build the ResNet18 defined above, then pass it through Conv3dConverter and
# model_to_syncbn; the printed model in the next cell shows Conv3d and
# SynchronizedBatchNorm3d layers
model = ResNet18(in_channels=n_channels, num_classes=n_classes)
model = model_to_syncbn(Conv3dConverter(model, i3d_repeat_axis=None))

criterion = nn.CrossEntropyLoss()
# NOTE(review): the optimizer is built from the wrapper's parameters, while the
# registration cell passes `model.model` — presumably the same tensors; confirm
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

The ``converters`` are currently experimental. It may not support operations including (but not limited to) Functions in ``torch.nn.functional`` that involved data dimension


In [17]:
# NOTE(review): `models` is not used by any cell visible in this notebook
from torchvision import models

# Print the converted architecture for inspection
print(model)

Conv3dConverter(
ResNet(
  (conv1): Conv3d(1, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
  (bn1): SynchronizedBatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (bn1): SynchronizedBatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (bn2): SynchronizedBatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (bn1): SynchronizedBatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), 

### Register model

In [18]:
from copy import deepcopy

# Dotted path of the framework adapter plugin handed to ModelInterface
framework_adapter = 'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'
# Register the network stored in the converter's `.model` attribute together with the optimizer
MI = ModelInterface(model=model.model, optimizer=optimizer, framework_plugin=framework_adapter)

# Save the initial model state
initial_model = deepcopy(model.model)

In [19]:
# Display the type of `model`; the output shows it is the Conv3dConverter wrapper
type(model)

acsconv.converters.conv3d_converter.Conv3dConverter

## Define and register FL tasks

In [20]:
# Task registry; the train/validate functions below are attached through its decorators
TI = TaskInterface()

from logging import getLogger

import torch
# NOTE: this rebinds `tqdm` to the module (the first cell imported the tqdm
# function under the same name); the tasks below call `tqdm.tqdm(...)`
import tqdm

# NOTE(review): `logger` is not used by any cell visible in this notebook
logger = getLogger(__name__)

# Extra keyword arguments passed to TI.add_kwargs for the train task
train_custom_params={'criterion':criterion,'task':task}

# Task interface currently supports only standalone functions.
@TI.add_kwargs(**train_custom_params)
@TI.register_fl_task(model='model', data_loader='train_loader',
                     device='device', optimizer='optimizer')
def train(model, train_loader, device, optimizer, criterion, task):
    """Run one pass over ``train_loader`` and return the mean batch loss.

    Args:
        model: Network to optimise; it is switched to train mode and moved to ``device``.
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        device: Device on which the batches and the model are placed.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        task: Task string; ``'multi-label, binary-class'`` selects float
            multi-label targets, anything else is treated as single-label
            classification with integer class indices.

    Returns:
        dict: ``{'train_loss': <mean of the per-batch losses>}`` (NaN when the
        loader yields no batch).
    """
    model.train()
    model.to(device)

    batch_losses = []
    for inputs, targets in tqdm.tqdm(train_loader, desc="train"):
        inputs, targets = inputs.to(device), targets.to(device)

        if task == 'multi-label, binary-class':
            # Keep one row of float labels per sample.
            targets = targets.reshape(targets.size(0), -1).to(torch.float32)
        else:
            # Collapse every singleton axis so the class indices are 1-D.
            # Batches of shape (B, 1, 1) would otherwise remain (B, 1) after a
            # single squeeze, which index-based losses such as CrossEntropyLoss
            # reject.
            targets = targets.reshape(-1).long()

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    mean_loss = np.mean(batch_losses) if batch_losses else float('nan')
    return {'train_loss': mean_loss,}



# Evaluator from the medmnist package, built for the 'val' split of the selected subset
val_evaluator = medmnist.Evaluator(data_flag, 'val')
# Extra keyword arguments passed to TI.add_kwargs for the validate task
val_custom_params={'criterion':criterion, 'evaluator':val_evaluator}

@TI.add_kwargs(**val_custom_params)
@TI.register_fl_task(model='model', data_loader='val_loader', device='device')
def validate(model, val_loader, device, criterion, evaluator):
    """Evaluate ``model`` on ``val_loader`` and return accuracy and mean batch loss.

    Args:
        model: Network to evaluate; it is switched to eval mode and moved to ``device``.
        val_loader: Iterable yielding ``(inputs, targets)`` batches.
        device: Device on which the batches and the model are placed.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        evaluator: Accepted so the registered keyword arguments keep working;
            it is not used by this function.

    Returns:
        dict: ``{'acc': <fraction of correct argmax predictions>,
        'test_loss': <mean of the per-batch losses>}``. Both are 0.0 when the
        loader yields no batch.

    NOTE(review): targets are treated as single-label class indices; a
    multi-label task would need a different metric — confirm against the
    dataset's task type.
    """
    model.eval()
    model.to(device)

    correct = 0
    total_samples = 0
    batch_losses = []

    with torch.no_grad():
        for inputs, targets in tqdm.tqdm(val_loader, desc="validate"):
            inputs = inputs.to(device)
            # Collapse singleton axes so the class indices are 1-D.
            targets = targets.reshape(-1).long().to(device)

            outputs = model(inputs)
            batch_losses.append(criterion(outputs, targets).item())

            predictions = outputs.argmax(dim=1)
            correct += predictions.eq(targets).sum().item()
            total_samples += targets.size(0)

    acc = correct / total_samples if total_samples else 0.0
    test_loss = sum(batch_losses) / len(batch_losses) if batch_losses else 0.0

    return {'acc': acc,
            'test_loss': test_loss,
            }

In [21]:
#print(arr.flags['WRITEABLE'])

## Time to start a federated learning experiment

In [22]:
# Create an experiment in the federation
experiment_name = 'medmnist_exp'
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [23]:
# NOTE: with autoreload enabled, this call raised a pickling error
# The following command zips the workspace and python requirements to be transferred to collaborator nodes
fl_experiment.start(model_provider=MI, 
                    task_keeper=TI,
                    data_loader=fed_dataset,
                    rounds_to_train=3,
                    opt_treatment='RESET',
                    device_assignment_policy='CUDA_PREFERRED')

In [None]:
# If the user wants to stop the IPython session, they can reconnect later and check how the experiment is going:
# fl_experiment.restore_experiment_state(model_interface)
# (`model_interface` is not defined in this notebook; presumably the ModelInterface
# registered above as `MI` — confirm)

# Stream the experiment's metrics into the cell output
fl_experiment.stream_metrics(tensorboard_logs=False)

### 