# Federated Text-Only (Tabular Features)

In [None]:
# Install dependencies if not already installed
import os
import pandas as pd
import PIL
from PIL import Image

import tqdm
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import utils, transforms
import torch.nn.functional as F

import glob
import nibabel as nib
import time

from torch.utils.tensorboard import SummaryWriter

## Connect to the Federation

In [None]:
# Create a federation (client-side connection to the OpenFL director).
from openfl.interface.interactive_api.federation import Federation

# Please use the same identifier that was used in the signed certificate.
# NOTE(review): presumably only relevant when TLS is enabled; tls=False below.
client_id = 'api'
director_node_fqdn = 'ai2'
director_port=50051

# Run with TLS disabled (only appropriate in a trusted environment).
# Federation can also determine local fqdn automatically
federation = Federation(
    client_id=client_id,
    director_node_fqdn=director_node_fqdn,
    director_port=director_port, 
    tls=False
)


In [None]:
# Fetch the shard registry from the director; the bare name on the last line
# makes the notebook display it.
shard_registry = federation.get_shard_registry()
shard_registry

In [None]:
# First, request a dummy_shard_desc that holds information about the federated dataset 
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)
dummy_shard_dataset = dummy_shard_desc.get_dataset('train')
# Inspect one training item; the f-string on the last line is what the
# notebook displays for this cell.
sample, target = dummy_shard_dataset[0]
f"Sample shape: {sample.shape}, target shape: {target.shape}"

## Describing the FL experiment

In [None]:
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment

## Experiment setup

#### Parameters

In [None]:
# Core training configuration.
batch_size = 8
experiment_name = "FL_txt_only"
epochs = 200

# Seed the torch and NumPy RNGs, plus a dedicated torch.Generator.
# NOTE(review): `generator` is not passed to any DataLoader in this notebook.
myseed = 14
torch.manual_seed(myseed)
np.random.seed(myseed)
generator = torch.Generator()
generator.manual_seed(myseed)

num_classes = 1

# Every candidate tabular feature, and the subset actually fed to the model.
all_columns = [
    'AGE', 'PTGENDER', 'ADAS11', 'MMSE', 'FAQ',
    'RAVLT_immediate', 'RAVLT_learning', 'RAVLT_forgetting',
    'CDRSB', 'APOE4',
]
required_columns = ['AGE', 'PTGENDER', 'APOE4']

# Tag the run with the seed and the chosen (lower-cased) columns, unless all
# columns are used.
# NOTE(review): the all-columns name 'img_full10' looks inherited from the image
# notebook and omits the seed; confirm it is intended for this text-only run.
experiment_name = (
    'img_full10'
    if len(all_columns) == len(required_columns)
    else f"{experiment_name}_{myseed}_{'_'.join(required_columns).lower()}"
)
print(f"{experiment_name}")


### Register dataset

In [None]:
class TransformedDataset(Dataset):
    """Expose the selected tabular features and the label of each record.

    ``input_dataframe[idx]`` must return one record that can be indexed both by
    the key ``'labels'`` and by a list of column names.
    NOTE(review): the concrete container comes from the shard descriptor's
    ``get_dataset`` and is not visible here; a sequence of ``pandas.Series``
    rows is assumed. Confirm.
    """

    def __init__(self, input_dataframe, transform=None, required_columns=required_columns):
        """Initialize Dataset.

        Args:
            input_dataframe: Indexable collection of records.
            transform: Optional callable applied to each feature tensor.
            required_columns: Feature column names to extract. Defaults to the
                notebook-level ``required_columns`` (bound at class definition).
        """
        self.input_df = input_dataframe
        self.transform = transform
        # Keep our own copy so the columns travel with the instance instead of
        # being re-read from a global at item time; a list (not a tuple) keeps
        # pandas label-list selection semantics.
        self.required_columns = list(required_columns)

    def __len__(self):
        """Return the number of records."""
        return len(self.input_df)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for record ``idx``.

        ``features`` is a float64 tensor ordered like ``required_columns``
        (after ``transform``, if one was given); ``label`` is ``record['labels']``
        unchanged.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        line = self.input_df[idx]

        # Get label
        y = line['labels']

        # Get tabular features. Going through NumPy converts a label-indexed
        # Series by value instead of relying on positional __getitem__ lookups.
        tabular = torch.tensor(
            np.asarray(line[self.required_columns], dtype=np.float64),
            dtype=torch.float64,
        )
        if self.transform is not None:
            tabular = self.transform(tabular)

        return tabular, y

In [None]:
class MultiINPUTFedDataset(DataInterface):
    """OpenFL data interface serving the tabular train/val splits of a shard.

    Constructor keyword arguments are stored in ``self.kwargs``; ``train_bs``
    and ``valid_bs`` select the loader batch sizes.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    @property
    def shard_descriptor(self):
        """Shard descriptor attached to this interface."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Attach a shard descriptor and build the per-split datasets.

        Called during collaborator initialization; the local shard_descriptor
        is set by the Envoy.
        """
        self._shard_descriptor = shard_descriptor
        self.train_set = self._wrap_split('train')
        self.valid_set = self._wrap_split('val')

    def _wrap_split(self, split_name):
        """Wrap one shard split in a TransformedDataset without a transform."""
        return TransformedDataset(
            self._shard_descriptor.get_dataset(split_name),
            required_columns=required_columns,
            transform=None,
        )

    def get_train_loader(self, **kwargs):
        """
        Loader given to tasks whose contract includes an optimizer (shuffled).
        """
        train_bs = self.kwargs['train_bs']
        return DataLoader(self.train_set, num_workers=1, batch_size=train_bs, shuffle=True)

    def get_valid_loader(self, **kwargs):
        """
        Loader given to tasks whose contract has no optimizer (unshuffled).
        """
        valid_bs = self.kwargs['valid_bs']
        return DataLoader(self.valid_set, num_workers=1, batch_size=valid_bs)

    def get_train_data_size(self):
        """Number of training records, used to weight aggregation."""
        return len(self.train_set)

    def get_valid_data_size(self):
        """Number of validation records, used to weight aggregation."""
        return len(self.valid_set)
    

### Create the ADNI federated dataset

In [None]:
# Local smoke test: attach a real ADNI shard descriptor and inspect one batch
# before launching the federated run.
TEST_fed_dataset = MultiINPUTFedDataset(train_bs=batch_size, valid_bs=batch_size, test_bs=batch_size)
from walter_sd_test import MultiINPUTShardDescriptor as misd
adni_num = 1
TEST_fed_dataset.shard_descriptor = misd(adni_num=adni_num,
                                         data_dir=f'/home/user1/fast_storage/a{adni_num}',
                                         img_dir=f'ADNI{adni_num}_ALL_T1',
                                         csv_path='/home/user1/fast_storage/ADNI_csv')

# Print only the first batch. The previous `if not i == 1` printed every batch
# except the second one and walked the entire loader.
first_batch = next(iter(TEST_fed_dataset.get_train_loader()), None)
if first_batch is None:
    print('The training loader yielded no batches.')
else:
    sample, target = first_batch
    print(sample, target)
    print(sample.shape, target.shape)

In [None]:
# Data interface shipped to the collaborators; batch sizes follow the
# `batch_size` experiment parameter instead of a hard-coded 8.
fed_dataset = MultiINPUTFedDataset(train_bs=batch_size, valid_bs=batch_size, test_bs=batch_size)

## Describe the model and optimizer

## Text-Only Net (tabular features)

In [None]:
class TextNN(nn.Module):
    """Fully connected network over tabular features with a single output.

    Layout: num_variables -> 50 -> 50 -> 10 -> 1, with ReLU after each hidden
    layer. The output is a raw logit; no sigmoid is applied here.
    """

    def __init__(self, num_variables):
        """Build the layers.

        Args:
            num_variables: Number of input feature columns.
        """
        super().__init__()
        # Registration order (relu, then ln1..ln4) is kept so the printed
        # structure and parameter order stay the same.
        self.relu = nn.ReLU()
        self.ln1 = nn.Linear(num_variables, 50)
        self.ln2 = nn.Linear(50, 50)
        self.ln3 = nn.Linear(50, 10)
        self.ln4 = nn.Linear(10, 1)

    def forward(self, tab):
        """Map features (last dim = num_variables) to one logit per row."""
        hidden = tab
        for layer in (self.ln1, self.ln2, self.ln3):
            hidden = self.relu(layer(hidden))
        return self.ln4(hidden)

# One input per selected feature column; float64 weights to match the
# float64 feature tensors the dataset produces.
model = TextNN(len(required_columns)).double()
print(model)

print('Total Parameters:',
      sum(p.numel() for p in model.parameters()))
print('Trainable Parameters:',
      sum(p.numel() for p in model.parameters() if p.requires_grad))

In [None]:
# Adam over every model parameter with a fixed learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# The network outputs raw logits, so the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss()

### Register model

In [None]:
from copy import deepcopy

# Dotted path of OpenFL's PyTorch framework adapter plugin.
framework_adapter = 'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'
# Register the model and its optimizer with the federation via the adapter.
MI = ModelInterface(model=model, optimizer=optimizer, framework_plugin=framework_adapter)

# Save the initial model state: an independent deep copy, so later updates to
# `model` do not affect it.
initial_model = deepcopy(model)

## Define and register FL tasks

In [None]:
TI = TaskInterface()

train_custom_params = {'criterion': criterion}


# Task interface currently supports only standalone functions.
@TI.add_kwargs(**train_custom_params)
@TI.register_fl_task(model='model', data_loader='train_loader',
                     device='device', optimizer='optimizer')
def train(model, train_loader, device, optimizer, criterion):
    """Run one local training epoch and report per-sample averaged metrics.

    Args:
        model: Network producing one logit per sample.
        train_loader: Iterable of ``(features, labels)`` batches.
        device: Device the model and batches are moved to.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss on ``(logits, float targets)``. Batch-mean reduction is
            assumed, as with the default ``nn.BCEWithLogitsLoss()``.

    Returns:
        dict with ``train_loss`` and ``train_acc`` averaged over all samples
        seen in the epoch (both ``nan`` if the loader yields no samples).
    """
    train_loader = tqdm.tqdm(train_loader, desc="train")

    loss_sum, correct, total_samples = 0.0, 0, 0
    model.train()
    model.to(device)

    for tab, labels in train_loader:
        # as_tensor avoids the copy and warning torch.tensor() makes when it is
        # handed a tensor (the default collate already yields tensors).
        tab = torch.as_tensor(tab).to(device)
        labels = torch.as_tensor(labels).to(device, dtype=torch.int64)
        labels = labels.unsqueeze(1).float()

        optimizer.zero_grad()
        pred = model(tab)
        loss = criterion(pred.float(), labels)
        loss.backward()
        optimizer.step()

        batch_n = tab.size(0)
        # Weight by batch size so a short final batch does not skew the mean.
        loss_sum += loss.item() * batch_n
        # A logit >= 0 corresponds to sigmoid probability >= 0.5.
        pred_labels = (pred >= 0).float()
        correct += (pred_labels == labels).sum().item()
        total_samples += batch_n

    if total_samples == 0:
        return {'train_loss': float('nan'), 'train_acc': float('nan')}
    return {'train_loss': loss_sum / total_samples,
            'train_acc': correct / total_samples}


val_custom_params = {'criterion': criterion}


@TI.add_kwargs(**val_custom_params)
@TI.register_fl_task(model='model', data_loader='val_loader', device='device')
def validate(model, val_loader, device, criterion):
    """Evaluate ``model`` on the local validation split without updating it.

    Args:
        model: Network producing one logit per sample.
        val_loader: Iterable of ``(features, labels)`` batches.
        device: Device the model and batches are moved to.
        criterion: Loss on ``(logits, float targets)``. Batch-mean reduction is
            assumed, as with the default ``nn.BCEWithLogitsLoss()``.

    Returns:
        dict with ``val_loss`` and ``val_acc`` averaged over all validation
        samples (both ``nan`` if the loader yields no samples).
    """
    val_loader = tqdm.tqdm(val_loader, desc="validate")

    loss_sum, correct, total_samples = 0.0, 0, 0
    model.eval()
    model.to(device)
    with torch.no_grad():
        for tab, labels in val_loader:
            # as_tensor avoids copying tensors the default collate already built.
            tab = torch.as_tensor(tab).to(device)
            labels = torch.as_tensor(labels).to(device, dtype=torch.int64)
            labels = labels.unsqueeze(1).float()

            pred = model(tab)
            loss = criterion(pred.float(), labels)

            batch_n = tab.size(0)
            # Weight by batch size so a short final batch does not skew the mean.
            loss_sum += loss.item() * batch_n
            # A logit >= 0 corresponds to sigmoid probability >= 0.5.
            pred_labels = (pred >= 0).float()
            correct += (pred_labels == labels).sum().item()
            total_samples += batch_n

    if total_samples == 0:
        return {'val_loss': float('nan'), 'val_acc': float('nan')}
    return {'val_loss': loss_sum / total_samples,
            'val_acc': correct / total_samples}

## Time to start a federated learning experiment

In [None]:
# Create an experiment in the federation, named after the configured run.
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [None]:
# The following command zips the workspace and Python requirements to be
# transferred to the collaborator nodes, then starts the experiment.
# `epochs` is used as the number of federated rounds.
# NOTE(review): opt_treatment='RESET' presumably resets optimizer state each
# round; confirm against the OpenFL documentation.
fl_experiment.start(model_provider=MI, 
                    task_keeper=TI,
                    data_loader=fed_dataset,
                    rounds_to_train=epochs,
                    opt_treatment='RESET',
                    device_assignment_policy='CUDA_PREFERRED',
                    pip_install_options="")

In [None]:
# If the IPython session was stopped, reconnect and check how the experiment is
# going: restore the experiment state, then stream its metrics (with TensorBoard logs).
fl_experiment.restore_experiment_state(MI)
fl_experiment.stream_metrics(tensorboard_logs=True)

### 

In [None]:
#FLexperiment.get_best_model()