# MNIST classifier in IBM FL

## Outline:
- [Federated Learning(FL)](#intro)
- [Digit Recognition](#mnist)
- [Parties](#Parties)
    - [Party Configuration](#Party-Configuration)
    - [Party Setup](#Party-Setup)
- [Register All Parties Before Starting Training](#Register-All-Parties-Before-Starting-Training)
- [Visualize Results](#Visualize-Results)
- [Shut Down](#Shut-Down)

## Federated Learning (FL) <a name="intro"></a>

**Federated Learning (FL)** is a distributed machine learning process in which each participant node (or party) keeps its data locally and interacts with other participants via a learning protocol.
One main driver behind FL is the need to avoid sharing data with others due to privacy and confidentiality concerns.
Another driver is to improve the speed of training a machine learning model by leveraging other participants' training processes.
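To make the learning protocol more concrete: in each round, parties train locally and send model updates, and the aggregator fuses them into a new global model. IBM FL supports several fusion algorithms. The sketch below shows one common choice, weighted federated averaging (FedAvg), with plain NumPy arrays standing in for model weights. It is an illustration under these assumptions, not necessarily the fusion algorithm this tutorial's aggregator is configured with.

```python
import numpy as np

def fed_avg(party_weights, party_sizes):
    """Weighted average of per-party model weights (FedAvg-style fusion).

    party_weights: one list of NumPy arrays (layer weights) per party
    party_sizes:   number of local training samples at each party
    """
    total = float(sum(party_sizes))
    fused = []
    # Average each layer across parties, weighting by local dataset size
    for layer_idx in range(len(party_weights[0])):
        layer = sum(w[layer_idx] * (n / total)
                    for w, n in zip(party_weights, party_sizes))
        fused.append(layer)
    return fused

# Two parties, each with a single 2-element "layer"
p0 = [np.array([1.0, 2.0])]
p1 = [np.array([3.0, 6.0])]
print(fed_avg([p0, p1], [100, 300]))  # -> [array([2.5, 5. ])]
```

Weighting by dataset size means a party with three times as much data pulls the fused weights three times as hard.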

Setting up such a federated learning system requires building a communication infrastructure, converting machine learning algorithms to federated settings and, in some cases, understanding the intricacies of security and privacy-enabling techniques such as differential privacy and multi-party computation.

In this Notebook we use [IBM FL](https://github.com/IBM/federated-learning-lib) to have multiple parties train a classifier to recognise handwritten digits in the [MNIST dataset](http://yann.lecun.com/exdb/mnist/). 

For a more technical dive into IBM FL, refer to the whitepaper [here](https://arxiv.org/pdf/2007.10987.pdf).

In the following cells, we set up each of the components of a Federated Learning network (see the figure below), in which all involved parties collaboratively train a shared digit classifier without exchanging their raw data. In this notebook we default to 5 parties, but depending on your resources you may use fewer parties.

<img src="images/FL_Network.png" width="720"/>
<figcaption><center>Modified from Image Source: <a href="https://arxiv.org/pdf/2007.10987.pdf">IBM Federated Learning: An Enterprise Framework White Paper V0.1</a></center></figcaption>

## Digit Recognition <a name="mnist"></a>



<img src="images/MnistExamples.png" width="512"/>
<figcaption><center>Image Source: Josef Steppan / CC BY-SA <a href="https://creativecommons.org/licenses/by-sa/4.0">Wikimedia Commons</a></center></figcaption>

The problem at hand is to classify 28x28 grayscale images of handwritten digits (0 through 9). MNIST contains 60,000 training images and 10,000 test images.

### Getting things ready

We begin by setting this party's ID and the dataset name, and by looking up the IP address on which this party will listen. The party's training data is generated in a later step.

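```python
import os
import socket
import sys

# Make modules in the current directory importable
sys.path.append('.')
# Run from the directory that contains generate_data.py and examples/
os.chdir('.')

party_id = 0        # ID of this party
dataset = 'mnist'

# Look up this machine's IP address so the Aggregator can reach the party
hostname = socket.gethostname()
IPAddr = socket.gethostbyname(hostname)
```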
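`socket.gethostbyname(socket.gethostname())` returns whatever address the hostname resolves to, and on some Linux systems that is a loopback address such as `127.0.1.1`, which an Aggregator on another machine cannot reach. One common workaround, sketched below as a hypothetical helper `get_outbound_ip`, asks the OS which local interface it would use to reach an external address; connecting a UDP socket only selects a route and sends no packets. The probe address `8.8.8.8` is an arbitrary assumption, and the helper falls back to loopback when there is no route.

```python
import socket

def get_outbound_ip(probe_host="8.8.8.8", probe_port=80):
    """Hypothetical helper: IP of the interface used for outbound traffic.

    connect() on a UDP socket only picks a route; no packet is sent.
    Falls back to 127.0.0.1 when no route is available (e.g. offline).
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        s.connect((probe_host, probe_port))
        return s.getsockname()[0]
    except OSError:
        return "127.0.0.1"
    finally:
        s.close()

print(get_outbound_ip())
```

If the printed address differs from `IPAddr` and the Aggregator runs on another machine, the printed address is likely the one to pass to `get_party_config`.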

## Parties

Each party holds its own dataset, which it keeps private and uses to answer queries received from the Aggregator. Because each party may store its data in a different format, IBM FL offers an abstraction called a Data Handler, which allows custom implementations for retrieving the data at each participating party. A local training handler at each party controls the training that happens on the party side.
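As a rough illustration of the Data Handler idea, the class below loads one party's `.npz` file and returns train/test splits. The class name `NpzDataHandler`, its `get_data` method, and the array keys `x_train`, `y_train`, `x_test`, `y_test` are assumptions for this sketch; IBM FL's own `DataHandler` base class and the `MnistKerasDataHandler` configured later in this notebook define the real interface.

```python
import os
import tempfile

import numpy as np

class NpzDataHandler:
    """Hypothetical data handler: reads one party's local .npz file."""

    def __init__(self, npz_file):
        self.npz_file = npz_file

    def get_data(self):
        # The data is read from local disk only; it never leaves the party
        with np.load(self.npz_file) as data:
            x_train, y_train = data["x_train"], data["y_train"]
            x_test, y_test = data["x_test"], data["y_test"]
        # Add a channel axis so images have shape (N, 28, 28, 1)
        x_train = x_train.reshape(-1, 28, 28, 1)
        x_test = x_test.reshape(-1, 28, 28, 1)
        return (x_train, y_train), (x_test, y_test)

# Usage with a tiny synthetic file standing in for a party's .npz file
demo_path = os.path.join(tempfile.mkdtemp(), "demo_party.npz")
np.savez(demo_path,
         x_train=np.zeros((4, 28, 28)), y_train=np.arange(4),
         x_test=np.ones((2, 28, 28)), y_test=np.arange(2))
(x_tr, y_tr), (x_te, y_te) = NpzDataHandler(demo_path).get_data()
print(x_tr.shape, x_te.shape)  # -> (4, 28, 28, 1) (2, 28, 28, 1)
```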

### Party Configuration

**Note**: in a typical FL setting, the parties may have very different configurations from each other. In this simplified example, however, the configuration differs little across parties: only the party ID, the port the party listens on, and the path of its data file depend on `party_id`. All of these are set in `get_party_config`, described below.
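The party-specific values can be listed directly. The sketch below mirrors the expressions used in `get_party_config` (`'party' + str(party_id)`, `8085 + party_id`, and the `data_party<ID>.npz` path) for the default of 5 parties; `party_specific_fields` is a hypothetical helper written for illustration, not part of IBM FL.

```python
def party_specific_fields(party_id, base_port=8085,
                          data_dir="examples/data/mnist/random"):
    """Hypothetical helper: the config values that differ between parties."""
    return {
        "id": "party" + str(party_id),
        "port": base_port + party_id,
        "npz_file": data_dir + "/data_party" + str(party_id) + ".npz",
    }

# One line per party; everything else in the config is shared
for pid in range(5):
    print(party_specific_fields(pid))
```

Deriving the port from `party_id` keeps parties on the same host from colliding on one port.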

<img src="images/arch_party.png" width="680"/>
<figcaption><center>Image Source: <a href="https://arxiv.org/pdf/2007.10987.pdf">IBM Federated Learning: An Enterprise Framework White Paper V0.1</a></center></figcaption>

### Party Setup
Below, we set up the configuration for the party, including network-level details and the model specification. If you are running this notebook in a distributed environment on separate nodes, you need to split the data locally and obtain the model `.h5` file generated by the Aggregator.
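If you prepare each party's data yourself, the core of the split is simple: shuffle the sample indices and give each party a disjoint slice. The sketch below does an IID (uniformly random) split with NumPy. `split_iid` is a hypothetical helper, and the `generate_data.py` script used in this notebook may split data differently, for example by label.

```python
import numpy as np

def split_iid(x, y, num_parties, points_per_party, seed=0):
    """Hypothetical helper: give each party a disjoint random subset."""
    if num_parties * points_per_party > len(x):
        raise ValueError("not enough samples for the requested split")
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(x))
    splits = []
    for p in range(num_parties):
        # Consecutive slices of a permutation never overlap
        part = idx[p * points_per_party:(p + 1) * points_per_party]
        splits.append((x[part], y[part]))
    return splits

# Toy data: 10 "images" of shape 2x2 with labels 0..9
x = np.arange(10 * 4).reshape(10, 2, 2)
y = np.arange(10)
parts = split_iid(x, y, num_parties=3, points_per_party=3)
print([p[1].tolist() for p in parts])
```

Each node could then save its slice locally, for example with `np.savez`.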

#### Building Blocks in Party Configuration:

Now we configure party-specific settings in the `get_party_config` function, which specifies the model-related configuration as well as the other parameters needed for the federated learning setup. The model-related settings are identical to those in the YAML configuration files generated when running the example from the terminal.

We then call this function for the party in a later cell.

- `local_training`: the local training handler; handles the train and evaluate commands, and also initializes the data, environment, and model

- `protocol_handler`: the party protocol handler communicates with the Aggregator, bridging between the Aggregator and the local training handler

- `aggregator`: IP and port at which the Aggregator is running, so the party can connect to it

- `connection`: how the party itself is reached, i.e., its own IP, port, and ID plus TLS settings; here a Flask-based connection (`FlaskConnection`)

- `data`: information needed to initialize a data handler class; includes the data file path (`npz_file`), the data handler class name, and the module where that class is defined

- `model`: details about the model, including its class name, the module that defines it (`path`), and the model specification (`spec`). In this example, we use the `KerasFLModel` class with a compiled Keras model stored in the `.h5` file given under `spec`
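Most of the blocks above pair a class `name` with a module `path`. A framework can turn such a pair into a class at run time with `importlib`; the sketch below shows the idea with a standard-library class, so it runs without IBM FL installed. `load_class` is a hypothetical helper, and IBM FL's internal loader may differ in its details.

```python
import importlib

def load_class(block):
    """Hypothetical helper: resolve a {'name': ..., 'path': ...} config block."""
    module = importlib.import_module(block["path"])  # import the module by dotted path
    return getattr(module, block["name"])           # look up the class by name

# Same shape as e.g. {'name': 'KerasFLModel', 'path': 'ibmfl.model.keras_fl_model'}
cls = load_class({"name": "OrderedDict", "path": "collections"})
print(cls)  # -> <class 'collections.OrderedDict'>
```

Keeping class names and module paths as strings is what lets the same configuration dictionary select different data handlers or models without code changes.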

**Note**: as of this release, all parties registered with the Aggregator participate in training. Letting registered parties dynamically skip training in certain rounds will be supported in subsequent versions.

### Step __: Run the cells in this section to choose your training data samples

```python
## Run this cell to choose between selecting digits or generating data randomly from the dataset
import subprocess

import ipywidgets as widgets
from ipywidgets import HTML, GridBox, Layout

choice_header = HTML(value='<h4>How would you like to generate the training data?</h4>',
                     layout=Layout(width='auto', grid_area='header'))
choice_radio = widgets.RadioButtons(
    options=['Select Digits', 'Randomly from Dataset'],
    value='Randomly from Dataset',  # default choice
    disabled=False,
    layout=Layout(grid_area='selection')  # grid_area belongs in the layout
)

choice_made = choice_radio.value

def choice_radio_eventhandler(change):
    # Keep the global choice in sync with the radio buttons
    global choice_made
    choice_made = change.new

choice_radio.observe(choice_radio_eventhandler, names='value')

GridBox(children=[choice_header, choice_radio],
        layout=Layout(width='auto',
                      grid_template_rows='auto',
                      grid_template_columns='98%',
                      grid_template_areas='"header" "selection"'))
```

```python
## Run this cell to choose digits or generate data, depending on the choice you made
import numpy as np
from IPython.display import display

############## UI elements for selecting digits ##############
select_header = HTML(value="<h4>Select label(s) you'd like in your dataset</h4>",
                     layout=Layout(width='auto', grid_area='header'))
# One checkbox per digit, 0 through 9
select_label_checkbox = widgets.Box(
    [widgets.Checkbox(value=False, description='Digit {}'.format(d),
                      disabled=False, indent=False)
     for d in range(10)],
    layout=Layout(grid_area='selection'))

selected = set()  # digits that are currently checked

def checkbox_eventhandler(change):
    # Add a digit when its box is checked and remove it when unchecked
    digit = int(change.owner.description.split(' ')[1])
    if change.new:
        selected.add(digit)
    else:
        selected.discard(digit)

for checkbox in select_label_checkbox.children:
    checkbox.observe(checkbox_eventhandler, names='value')

def update_checkboxes(*args):
    # Pre-select three random digits as a starting point
    for i in np.random.permutation(10)[:3]:
        select_label_checkbox.children[i].value = True

update_checkboxes()

############## UI element for 'Randomly from Dataset' ##############
info_header = HTML(value='<h4><center>Head to the <em>train step</em> as you now have the data to train!</center></h4>',
                   layout=Layout(width='auto', grid_area='header'))

if choice_made == 'Randomly from Dataset':
    # Invoke generate_data.py: 1 party (-n), 200 data points per party (-pp)
    cmd_to_run = 'python generate_data.py -n 1 --dataset mnist -pp 200'
    process = subprocess.run(cmd_to_run, shell=True,
                             stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE,
                             universal_newlines=True)
    print(process.stdout)
    if process.returncode != 0:
        print(process.stderr)  # show the error if data generation failed
    display(GridBox(children=[info_header],
                    layout=Layout(width='auto',
                                  grid_template_rows='auto',
                                  grid_template_columns='98%',
                                  grid_template_areas='"header"')))

elif choice_made == 'Select Digits':
    # Show the digit checkboxes; the data is generated in the next cell
    display(GridBox(children=[select_header, select_label_checkbox],
                    layout=Layout(width='auto',
                                  grid_template_rows='auto',
                                  grid_template_columns='98%',
                                  grid_template_areas='"header" "selection"')))
```

```python
## Run this cell to generate data if you chose 'Select Digits' earlier
if choice_made == 'Select Digits':
    if not selected:
        print('Please select at least one digit above first.')
    else:
        # Pass the selected digits as a space-separated list, e.g. "--labels 1 4 7"
        arg_str = ' '.join(str(d) for d in sorted(selected))
        cmd_to_run = 'python generate_data.py -n 1 --dataset mnist -pp 200 --labels ' + arg_str
        process = subprocess.run(cmd_to_run, shell=True,
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.PIPE,
                                 universal_newlines=True)
        print(process.stdout)
        if process.returncode != 0:
            print(process.stderr)  # show the error if data generation failed
```
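The cells above build the `generate_data.py` command as a single shell string. A variant, sketched here with a hypothetical `build_generate_cmd` helper, builds an argument list instead, so `subprocess.run` needs no `shell=True` and each label is passed as its own argument. It uses `sys.executable` so the script runs under the same Python as the notebook. The flags are exactly those used above; nothing else about the script's command-line interface is assumed.

```python
import sys

def build_generate_cmd(labels=None, num_parties=1, points_per_party=200):
    """Hypothetical helper: argument list for generate_data.py (flags as used above)."""
    cmd = [sys.executable, "generate_data.py",
           "-n", str(num_parties), "--dataset", "mnist",
           "-pp", str(points_per_party)]
    if labels:
        cmd += ["--labels"] + [str(d) for d in sorted(labels)]
    return cmd

print(build_generate_cmd({7, 1, 4}))
# To run it: subprocess.run(build_generate_cmd(selected), capture_output=True, text=True)
```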

```python
def get_party_config(party_id, IPAddr):
    """Return the configuration dictionary for the party with ID party_id."""
    party_config = {
        # Where the Aggregator listens; replace with your Aggregator's IP and port
        'aggregator': {
            'ip': '172.30.110.3',
            'port': 5000
        },
        # How this party itself is reached; each party gets its own port
        'connection': {
            'info': {
                'ip': IPAddr,
                'port': 8085 + party_id,
                'id': 'party' + str(party_id),
                'tls_config': {
                    'enable': False
                }
            },
            'name': 'FlaskConnection',
            'path': 'ibmfl.connection.flask_connection',
            'sync': False
        },
        # This party's local data file and the handler class that loads it
        'data': {
            'info': {
                'npz_file': 'examples/data/mnist/random/data_party' + str(party_id) + '.npz'
            },
            'name': 'MnistKerasDataHandler',
            'path': 'ibmfl.util.data_handlers.mnist_keras_data_handler'
        },
        'local_training': {
            'name': 'LocalTrainingHandler',
            'path': 'ibmfl.party.training.local_training_handler'
        },
        # Keras model wrapper plus the compiled model (.h5) generated by the Aggregator
        'model': {
            'name': 'KerasFLModel',
            'path': 'ibmfl.model.keras_fl_model',
            'spec': {
                'model_definition': 'examples/configs/keras_classifier/compiled_keras.h5',
                'model_name': 'keras-cnn'
            }
        },
        'protocol_handler': {
            'name': 'PartyProtocolHandler',
            'path': 'ibmfl.party.party_protocol_handler'
        }
    }
    return party_config
```


### Running the Party

Now we call the `get_party_config` function to set up the party, and then `start()` it.

Finally, we register the party with the Aggregator.
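Registration contacts the Aggregator at the `ip` and `port` given in the `aggregator` block, so the Aggregator must already be running. Before calling `register_party()`, you can check that its port accepts TCP connections. `aggregator_reachable` below is a hypothetical helper that uses only the standard library.

```python
import socket

def aggregator_reachable(ip, port, timeout=3.0):
    """Hypothetical helper: True if a TCP connection to ip:port succeeds."""
    try:
        with socket.create_connection((ip, port), timeout=timeout):
            return True
    except OSError:  # refused, unreachable, or timed out
        return False

# Example with the address from the 'aggregator' block of get_party_config:
# aggregator_reachable('172.30.110.3', 5000)
```

An open port only shows that something is listening there; it does not prove that it is the IBM FL Aggregator.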

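```python
from ibmfl.party.party import Party

# Build this party's configuration, then start its local service
party_config = get_party_config(party_id, IPAddr)
party = Party(config_dict=party_config)
party.start()

# Register with the Aggregator (it must already be running)
party.register_party()
```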

## Register All Parties Before Starting Training

We have now started and registered this party. Next, start and register the rest of the parties. Once all parties have registered, go back to the Aggregator's notebook to start training.

## Visualize Results

Here we plot, for each digit, how accurately this party's model classifies a random sample of MNIST test images of that digit.
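The quantity plotted for each digit is the percentage of sampled test images of that digit that the model labels correctly. Given arrays of true and predicted labels, this per-digit accuracy can be computed with NumPy as sketched below. `per_digit_accuracy` is a hypothetical helper, shown here on made-up labels.

```python
import numpy as np

def per_digit_accuracy(y_true, y_pred, num_classes=10):
    """Hypothetical helper: accuracy (%) per true class; NaN if a class is absent."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    acc = np.full(num_classes, np.nan)
    for c in range(num_classes):
        mask = y_true == c  # samples whose true label is c
        if mask.any():
            acc[c] = 100.0 * np.mean(y_pred[mask] == c)
    return acc

# Made-up labels: digit 0 is always right, digit 1 is right half the time
print(per_digit_accuracy([0, 0, 1, 1], [0, 0, 1, 7], num_classes=2))  # -> [100.  50.]
```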

```python
from keras.datasets import mnist

# Load the MNIST test split (the training split is not needed here)
_, (X_test, Y_test) = mnist.load_data()
```

```python
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import accuracy_score

sample_count = 100  # test images sampled per digit

for digit in range(10):
    # Draw sample_count test images of this digit (with replacement, fixed seed)
    X_digit = X_test[Y_test == digit]
    np.random.seed(123)
    rand_v = np.random.randint(0, X_digit.shape[0], sample_count)
    test_digits = X_digit[rand_v]

    # Predict all samples in one batch and take the most likely class
    preds = party.fl_model.predict(test_digits.reshape(-1, 28, 28, 1))
    y_pred = np.argmax(preds, axis=1)
    y_true = np.full(sample_count, digit)
    acc = accuracy_score(y_true, y_pred) * 100

    fig = plt.figure(constrained_layout=True, figsize=(10, 10))
    spec = gridspec.GridSpec(ncols=6, nrows=4, figure=fig)

    # Left: one example image of the digit
    ax_img = fig.add_subplot(spec[0:1, 0:2])
    ax_img.imshow(test_digits[0], cmap='gray')
    ax_img.axis('off')

    # Right: this party's accuracy on the sampled images
    ax_bar = fig.add_subplot(spec[0:1, 3:])
    x_pos = 3
    ax_bar.bar(x_pos, acc, width=1)
    ax_bar.set_xticks([x_pos])
    ax_bar.set_xticklabels(['Party ' + str(party_id)])
    ax_bar.set_ylabel('Average Prediction Accuracy\nover ' + str(sample_count) + ' samples')
    ax_bar.set_title('Average Prediction Accuracy for label ' + str(digit))
    ax_bar.set_xlim(0, 6)
    ax_bar.set_ylim(0, 100)
    plt.show()
```

## Shut Down

Invoke the `stop()` method on each of the network participants to terminate the service.

```python
party.stop()
```