# Federated MultiINPUT

In [None]:
# Install dependencies if not already installed
import glob
import os
import time
from functools import partial

import nibabel as nib
import numpy as np
import pandas as pd
import PIL
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import tqdm
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from torchvision import utils, transforms

## Connect to the Federation

In [None]:
# Open a connection to the director that coordinates the federation
from openfl.interface.interactive_api.federation import Federation

# Must be the same identifier that was used in the signed certificate
client_id = 'api'
director_node_fqdn = 'ai2'
director_port = 50051

# TLS is disabled here, which is only appropriate in a trusted environment.
# Federation can also determine the local FQDN automatically.
federation = Federation(client_id=client_id,
                        director_node_fqdn=director_node_fqdn,
                        director_port=director_port,
                        tls=False)


In [None]:
# Ask the federation for its shard registry; the bare name on the last line
# makes the notebook display the result as the cell output.
shard_registry = federation.get_shard_registry()
shard_registry

In [None]:
# First, request a dummy_shard_desc that holds information about the federated dataset
# (size=10 is forwarded as-is; presumably the number of dummy samples - TODO confirm)
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)
dummy_shard_dataset = dummy_shard_desc.get_dataset('train')
# Unpack the first item as a (sample, target) pair; the f-string below is the
# cell output and reports both shapes.
sample, target = dummy_shard_dataset[0]
f"Sample shape: {sample.shape}, target shape: {target.shape}"

## Describing the FL experiment

In [None]:
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment

## Load MedMNIST INFO

#### Parameters

In [None]:
# --- Training hyper-parameters ---
batch_size = 8
experiment_name = "FL_MultiInput"
epochs = 200

# --- Reproducibility: seed torch, numpy and a dedicated torch generator ---
myseed = 11
torch.manual_seed(myseed)
np.random.seed(myseed)
generator = torch.Generator()
generator.manual_seed(myseed)

num_classes = 1

# Every tabular column that may be fed to the network
all_columns = [
    'AGE', 'PTGENDER', 'ADAS11', 'MMSE', 'FAQ',
    'RAVLT_immediate', 'RAVLT_learning', 'RAVLT_forgetting',
    'CDRSB', 'APOE4',
]

# Subset of tabular columns used by this experiment
required_columns = ['AGE', 'PTGENDER', 'APOE4']

# The experiment name records the seed and the selected columns, unless the
# selection has as many entries as the full column list.
if len(required_columns) != len(all_columns):
    experiment_name = f"{experiment_name}_{myseed}_img_{'_'.join(required_columns).lower()}"
else:
    experiment_name = 'img_full10'
print(f"{experiment_name}")


### Register dataset

In [None]:
class TransformedDataset(Dataset):
    """Dataset pairing a NIfTI image with tabular features and a label.

    Each record obtained from ``input_dataframe[idx]`` must expose the keys
    ``'labels'`` and ``'IMG_PATH'`` plus every name in ``required_columns``.
    """

    def __init__(self, input_dataframe, transform=None, required_columns=required_columns):
        """Initialize Dataset.

        Args:
            input_dataframe: Indexable collection of records; must support
                ``len()`` and item access by index.
            transform: Optional callable applied to the image array loaded
                from disk, before it is converted to a tensor.
            required_columns: Names of the tabular features extracted from
                each record.
        """
        self.input_df = input_dataframe
        self.transform = transform
        # Keep the column selection on the instance so the constructor
        # argument is actually honoured instead of the module-level global.
        self.required_columns = required_columns

    def __len__(self):
        """Length of dataset."""
        return len(self.input_df)

    def __getitem__(self, idx):
        """Return ``(image, tabular, label)`` for the record at ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        line = self.input_df[idx]

        # Get label
        y = line['labels']

        # Get tabular features as a float64 tensor
        tabular = torch.DoubleTensor(line[self.required_columns])

        # Get image: load the NIfTI file and read its voxel data as floats
        image = nib.load(line['IMG_PATH']).get_fdata()

        if self.transform is not None:
            image = self.transform(image)

        # Convert to a tensor, then prepend a singleton axis (the network in
        # this notebook is built with a single input channel).
        image = transforms.functional.to_tensor(image)
        image = image.unsqueeze(dim=0)

        return image, tabular, y

In [None]:
class MultiINPUTFedDataset(DataInterface):
    """Data interface that exposes the local shard to the federated tasks.

    The keyword arguments given at construction are kept in ``self.kwargs``;
    ``train_bs`` and ``valid_bs`` are read from it as loader batch sizes.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    @property
    def shard_descriptor(self):
        """Shard descriptor currently attached to this interface."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Describe per-collaborator procedures or sharding.

        This method will be called during a collaborator initialization.
        Local shard_descriptor  will be set by Envoy.
        """
        self._shard_descriptor = shard_descriptor

        # Wrap the raw 'train' and 'val' splits so that items come out as
        # (image, tabular, label) triples.
        train_split = self._shard_descriptor.get_dataset('train')
        self.train_set = TransformedDataset(
            train_split, transform=None, required_columns=required_columns)

        val_split = self._shard_descriptor.get_dataset('val')
        self.valid_set = TransformedDataset(
            val_split, transform=None, required_columns=required_columns)

    def get_train_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks with optimizer in contract
        """
        loader_options = dict(num_workers=1,
                              batch_size=self.kwargs['train_bs'],
                              shuffle=True)
        return DataLoader(self.train_set, **loader_options)

    def get_valid_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks without optimizer in contract
        """
        loader_options = dict(num_workers=1,
                              batch_size=self.kwargs['valid_bs'])
        return DataLoader(self.valid_set, **loader_options)

    def get_train_data_size(self):
        """
        Information for aggregation
        """
        n_train = len(self.train_set)
        return n_train

    def get_valid_data_size(self):
        """
        Information for aggregation
        """
        n_valid = len(self.valid_set)
        return n_valid
    

### Create the MultiINPUT federated dataset

In [None]:
# Use the notebook-level ``batch_size`` (instead of a repeated literal) so the
# loaders follow the value set in the parameters cell.
fed_dataset = MultiINPUTFedDataset(train_bs=batch_size, valid_bs=batch_size, test_bs=batch_size)

## Describe the model and optimizer

## Multi_input net: 3D ResNet + MLP

In [None]:
def get_inplanes():
    """Return a new list with the channel widths of the four residual stages."""
    base_width = 64
    # Each stage doubles the width of the previous one: 64, 128, 256, 512
    return [base_width << stage for stage in range(4)]


def conv3x3x3(in_planes, out_planes, stride=1):
    """Build a bias-free ``Conv3d`` with a 3x3x3 kernel and padding of 1."""
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv3d(in_planes, out_planes, **conv_options)


def conv1x1x1(in_planes, out_planes, stride=1):
    """Build a bias-free pointwise (1x1x1 kernel) ``Conv3d``."""
    conv_options = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv3d(in_planes, out_planes, **conv_options)


class BasicBlock(nn.Module):
    """Residual block made of two 3x3x3 convolutions.

    ``expansion`` is the ratio between the block's output channels and
    ``planes`` (1 for this block).
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        """Create the two conv/batch-norm pairs.

        Args:
            in_planes: Number of input channels.
            planes: Number of channels produced by both convolutions.
            stride: Stride of the first convolution.
            downsample: Optional callable applied to the block input to
                produce the shortcut; ``None`` means the input is used as is.
        """
        super().__init__()

        # First stage (carries the stride)
        self.conv1 = conv3x3x3(in_planes, planes, stride=stride)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        # Second stage (stride 1)
        self.conv2 = conv3x3x3(planes, planes)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


class Bottleneck(nn.Module):
    """Residual block with a 1x1x1 -> 3x3x3 -> 1x1x1 convolution stack.

    The last convolution outputs ``planes * expansion`` channels.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        """Create the three conv/batch-norm pairs.

        Args:
            in_planes: Number of input channels.
            planes: Channel count of the two inner convolutions.
            stride: Stride of the middle 3x3x3 convolution.
            downsample: Optional callable applied to the block input to
                produce the shortcut; ``None`` means the input is used as is.
        """
        super().__init__()

        out_planes = planes * self.expansion

        # Pointwise reduction
        self.conv1 = conv1x1x1(in_planes, planes)
        self.bn1 = nn.BatchNorm3d(planes)
        # Spatial convolution (carries the stride)
        self.conv2 = conv3x3x3(planes, planes, stride=stride)
        self.bn2 = nn.BatchNorm3d(planes)
        # Pointwise expansion
        self.conv3 = conv1x1x1(planes, out_planes)
        self.bn3 = nn.BatchNorm3d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


class ResNet(nn.Module):
    """3D ResNet image branch fused with an MLP over tabular features.

    The image goes through the residual stack and a linear layer producing
    ``img_contribution`` features. The tabular vector goes through three
    linear layers producing ``tabular_contribution`` features. Both are
    concatenated, passed through a ReLU and mapped to ``n_classes`` outputs.

    Args:
        block: Residual block class; must define ``expansion`` and accept
            ``(in_planes, planes, stride, downsample)``.
        layers: Number of blocks in each of the four residual stages.
        block_inplanes: Base channel width of each of the four stages.
        n_input_channels: Channels of the input volume.
        conv1_t_size: Kernel size of the stem convolution along its first
            spatial axis (the other two are fixed to 7).
        conv1_t_stride: Stride of the stem convolution along its first
            spatial axis (the other two are fixed to 2).
        no_max_pool: Skip the max-pool after the stem when true.
        shortcut_type: ``'A'`` for a parameter-free shortcut (strided
            sub-sampling plus zero channel padding); any other value selects
            a 1x1x1 convolution followed by batch norm.
        widen_factor: Multiplier applied to every entry of ``block_inplanes``.
        n_classes: Size of the final output.
        img_contribution: Number of features produced by the image branch.
        tabular_val: Number of input tabular features.
        tabular_contribution: Number of features produced by the tabular
            branch.
    """

    def __init__(self, block, layers, block_inplanes,
                 n_input_channels=3, conv1_t_size=7,
                 conv1_t_stride=1, no_max_pool=False,
                 shortcut_type='B', widen_factor=1.0,
                 n_classes=400, img_contribution=10,
                 tabular_val=10, tabular_contribution=10):
        super().__init__()

        block_inplanes = [int(x * widen_factor) for x in block_inplanes]

        self.in_planes = block_inplanes[0]
        self.no_max_pool = no_max_pool

        self.conv1 = nn.Conv3d(n_input_channels,
                               self.in_planes,
                               kernel_size=(conv1_t_size, 7, 7),
                               stride=(conv1_t_stride, 2, 2),
                               padding=(conv1_t_size // 2, 3, 3),
                               bias=False)
        self.bn1 = nn.BatchNorm3d(self.in_planes)
        # Single activation shared by the image stem and the tabular branch.
        # The original assigned ``self.relu`` twice; the later non-inplace
        # ReLU was the one in effect, so that is the one kept.
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, block_inplanes[0], layers[0],
                                       shortcut_type)
        self.layer2 = self._make_layer(block,
                                       block_inplanes[1],
                                       layers[1],
                                       shortcut_type,
                                       stride=2)
        self.layer3 = self._make_layer(block,
                                       block_inplanes[2],
                                       layers[2],
                                       shortcut_type,
                                       stride=2)
        self.layer4 = self._make_layer(block,
                                       block_inplanes[3],
                                       layers[3],
                                       shortcut_type,
                                       stride=2)

        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        # The image head outputs ``img_contribution`` features (not class
        # scores) so that it can be concatenated with the tabular features.
        self.fc = nn.Linear(block_inplanes[3] * block.expansion, img_contribution)

        # Tabular branch: ``tabular_val`` input columns -> 50 -> 50 ->
        # ``tabular_contribution`` features.
        self.ln1 = nn.Linear(tabular_val, 50)
        self.ln2 = nn.Linear(50, 50)
        self.ln3 = nn.Linear(50, tabular_contribution)
        # Fusion head over the concatenated image and tabular features.
        self.ln4 = nn.Linear(tabular_contribution + img_contribution, n_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _downsample_basic_block(self, x, planes, stride):
        """Parameter-free shortcut: sub-sample ``x`` and zero-pad its channels.

        Args:
            x: Input of shape (batch, channels, d1, d2, d3).
            planes: Number of channels required on the output.
            stride: Sub-sampling stride applied to the three spatial axes.

        Returns:
            Tensor with ``planes`` channels: the sub-sampled input followed by
            zero-filled channels.
        """
        out = F.avg_pool3d(x, kernel_size=1, stride=stride)
        # ``new_zeros`` allocates the padding with the same dtype and device
        # as ``out``, so float64 models and GPU tensors concatenate cleanly.
        zero_pads = out.new_zeros(out.size(0), planes - out.size(1),
                                  out.size(2), out.size(3), out.size(4))
        # Concatenate ``out`` itself (not ``out.data``) so gradients keep
        # flowing through the shortcut.
        return torch.cat([out, zero_pads], dim=1)

    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1):
        """Build one residual stage of ``blocks`` blocks.

        Only the first block receives ``stride`` and, when the shape changes,
        a shortcut projection. ``self.in_planes`` is updated to the stage's
        output width as a side effect.
        """
        downsample = None
        if stride != 1 or self.in_planes != planes * block.expansion:
            if shortcut_type == 'A':
                downsample = partial(self._downsample_basic_block,
                                     planes=planes * block.expansion,
                                     stride=stride)
            else:
                downsample = nn.Sequential(
                    conv1x1x1(self.in_planes, planes * block.expansion, stride),
                    nn.BatchNorm3d(planes * block.expansion))

        layers = [
            block(in_planes=self.in_planes,
                  planes=planes,
                  stride=stride,
                  downsample=downsample)
        ]
        self.in_planes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_planes, planes))

        return nn.Sequential(*layers)

    def forward(self, x, tab):
        """Fuse image and tabular inputs into ``n_classes`` outputs.

        Args:
            x: Image batch of shape (batch, n_input_channels, d1, d2, d3).
            tab: Tabular batch of shape (batch, tabular_val).

        Returns:
            Tensor of shape (batch, n_classes); no final activation is
            applied.
        """
        # Image branch
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if not self.no_max_pool:
            x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)

        x = x.view(x.size(0), -1)
        x = self.fc(x)

        # Tabular branch
        tab = self.ln1(tab)
        tab = self.relu(tab)
        tab = self.ln2(tab)
        tab = self.relu(tab)
        tab = self.ln3(tab)
        tab = self.relu(tab)

        # Fusion: concatenate along the feature axis, activate, classify
        x = torch.cat((x, tab), dim=1)
        x = self.relu(x)

        x = self.ln4(x)

        return x


def ResNet18(in_channels, num_classes, **kwargs):
    """Build the 18-layer multi-input ResNet.

    Args:
        in_channels: Number of channels of the input volume.
        num_classes: Size of the network output.
        **kwargs: Extra keyword arguments forwarded to ``ResNet`` (for
            example ``tabular_val``).

    Returns:
        A ``ResNet`` built from ``BasicBlock`` with two blocks per stage.
    """
    # ``ResNet`` needs the stage widths as its third positional argument and
    # names these options ``n_input_channels`` / ``n_classes``.
    return ResNet(BasicBlock, [2, 2, 2, 2], get_inplanes(),
                  n_input_channels=in_channels, n_classes=num_classes,
                  **kwargs)


def generate_model(model_depth, **kwargs):
    """Build a ``ResNet`` of the requested depth.

    Args:
        model_depth: One of 10, 18, 34, 50, 101, 152 or 200.
        **kwargs: Forwarded unchanged to the ``ResNet`` constructor.

    Returns:
        The constructed ``ResNet``.

    Raises:
        AssertionError: If ``model_depth`` is not a supported depth.
    """
    assert model_depth in [10, 18, 34, 50, 101, 152, 200]

    # depth -> (block class, number of blocks in each of the four stages)
    architectures = {
        10: (BasicBlock, [1, 1, 1, 1]),
        18: (BasicBlock, [2, 2, 2, 2]),
        34: (BasicBlock, [3, 4, 6, 3]),
        50: (Bottleneck, [3, 4, 6, 3]),
        101: (Bottleneck, [3, 4, 23, 3]),
        152: (Bottleneck, [3, 8, 36, 3]),
        200: (Bottleneck, [3, 24, 36, 3]),
    }
    block_cls, blocks_per_stage = architectures[model_depth]

    return ResNet(block_cls, blocks_per_stage, get_inplanes(), **kwargs)


# Build the 18-layer multi-input network (single-channel volumes, one tabular
# input per required column) and cast its parameters to float64, the dtype of
# the DoubleTensor tabular features produced by the dataset.
model = generate_model(
    18,
    n_input_channels=1,
    widen_factor=1.0,
    n_classes=1,
    img_contribution=10,
    tabular_val=len(required_columns),
    tabular_contribution=10,
).double()
print(model)


# Report the parameter counts
print('Total Parameters:',
      sum(p.numel() for p in model.parameters()))
print('Trainable Parameters:',
      sum(p.numel() for p in model.parameters() if p.requires_grad))

In [None]:
## Optimizer: Adam over every model parameter
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
## Loss: binary cross-entropy computed on raw logits
criterion = nn.BCEWithLogitsLoss()

### Register model

In [None]:
from copy import deepcopy

# Dotted path naming the PyTorch framework adapter plugin handed to ModelInterface
framework_adapter = 'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'
# Bundle the model and its optimizer; MI is later passed to the experiment as model_provider
MI = ModelInterface(model=model, optimizer=optimizer, framework_plugin=framework_adapter)

# Save the initial model state
# (an independent copy, so it is not affected by later changes to ``model``)
initial_model = deepcopy(model)

## Define and register FL tasks

In [None]:
TI = TaskInterface()

train_custom_params = {'criterion': criterion}


# Task interface currently supports only standalone functions.
@TI.add_kwargs(**train_custom_params)
@TI.register_fl_task(model='model', data_loader='train_loader',
                     device='device', optimizer='optimizer')
def train(model, train_loader, device, optimizer, criterion):
    """Run one pass over ``train_loader`` and update ``model``.

    Args:
        model: Network called as ``model(img, tabular)``.
        train_loader: Iterable of ``(img, tabular, labels)`` batches.
        device: Device on which the model and the batches are placed.
        optimizer: Optimizer stepping the model parameters.
        criterion: Loss called as ``criterion(logits, targets)``.

    Returns:
        dict with ``'train_loss'`` (mean of the per-batch losses) and
        ``'train_acc'`` (fraction of correctly classified samples).
    """
    train_loader = tqdm.tqdm(train_loader, desc="train")

    batch_losses = []
    correct, total_samples = 0, 0

    model.train()
    model.to(device)

    for img, tabular, labels in train_loader:
        # ``as_tensor`` leaves existing tensors untouched, unlike
        # ``torch.tensor`` which copies them and emits a warning.
        img = torch.as_tensor(img).to(device)
        tabular = torch.as_tensor(tabular).to(device)
        labels = torch.as_tensor(labels).to(device, dtype=torch.int64)
        optimizer.zero_grad()

        # Compute output
        pred = model(img, tabular)
        # Targets must be float and shaped (batch, 1) like the logits
        labels = labels.unsqueeze(1).float()
        loss = criterion(pred.float(), labels)
        loss.backward()
        optimizer.step()

        # Update loss
        batch_losses.append(loss.item())
        # A logit >= 0 is a sigmoid probability >= 0.5, i.e. class 1
        pred_labels = (pred >= 0).float()
        # Count samples instead of averaging per-batch accuracies, so a
        # smaller final batch is not over-weighted.
        correct += (pred_labels == labels).sum().item()
        total_samples += img.size(0)

    # Guard against an empty loader instead of dividing by zero
    return {'train_loss': np.mean(batch_losses) if batch_losses else float('nan'),
            'train_acc': correct / total_samples if total_samples else 0.0}


val_custom_params = {'criterion': criterion}


@TI.add_kwargs(**val_custom_params)
@TI.register_fl_task(model='model', data_loader='val_loader', device='device')
def validate(model, val_loader, device, criterion):
    """Evaluate ``model`` on ``val_loader`` without updating it.

    Args:
        model: Network called as ``model(img, tabular)``.
        val_loader: Iterable of ``(img, tabular, labels)`` batches.
        device: Device on which the model and the batches are placed.
        criterion: Loss called as ``criterion(logits, targets)``.

    Returns:
        dict with ``'val_loss'`` (mean of the per-batch losses) and
        ``'val_acc'`` (fraction of correctly classified samples).
    """
    val_loader = tqdm.tqdm(val_loader, desc="validate")

    batch_losses = []
    correct, total_samples = 0, 0

    model.eval()
    model.to(device)
    with torch.no_grad():
        for img, tabular, labels in val_loader:
            # ``as_tensor`` leaves existing tensors untouched, unlike
            # ``torch.tensor`` which copies them and emits a warning.
            img = torch.as_tensor(img).to(device)
            tabular = torch.as_tensor(tabular).to(device)
            labels = torch.as_tensor(labels).to(device, dtype=torch.int64)

            # Compute output
            pred = model(img, tabular)
            # Targets must be float and shaped (batch, 1) like the logits
            labels = labels.unsqueeze(1).float()
            loss = criterion(pred.float(), labels)

            # Update loss
            batch_losses.append(loss.item())
            # A logit >= 0 is a sigmoid probability >= 0.5, i.e. class 1
            pred_labels = (pred >= 0).float()
            # Count samples instead of averaging per-batch accuracies, so a
            # smaller final batch is not over-weighted.
            correct += (pred_labels == labels).sum().item()
            total_samples += img.size(0)

    # Guard against an empty loader instead of dividing by zero
    return {'val_loss': np.mean(batch_losses) if batch_losses else float('nan'),
            'val_acc': correct / total_samples if total_samples else 0.0}

## Time to start a federated learning experiment

In [None]:
# Create an experiment in the federation, named after the configured experiment_name
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [None]:
# The following command zips the workspace and python requirements to be transfered to collaborator nodes
fl_experiment.start(model_provider=MI, 
                    task_keeper=TI,
                    data_loader=fed_dataset,
                    rounds_to_train=epochs,
                    opt_treatment='RESET',
                    device_assignment_policy='CUDA_PREFERRED',
                    pip_install_options="")

In [None]:
# If the user wants to stop the IPython session, they can reconnect later and
# check how the experiment is going: restore the experiment state from the
# model interface, then stream the metrics (with TensorBoard logging enabled).
fl_experiment.restore_experiment_state(MI)
fl_experiment.stream_metrics(tensorboard_logs=True)

### 

In [None]:
#FLexperiment.get_best_model()