<a href="https://colab.research.google.com/github/rupeshgyawali/federated-covid-xray-detection/blob/main/federated/parallel_simulation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Download and unzip dataset

In [None]:
# Download the dataset archive from Google Drive by file id and unpack it
# quietly (-q) into the current working directory.
# The two commented-out commands fetch a differently named archive
# (…_Database.zip) — presumably an earlier release; kept for reference.
# !gdown --id 1ZMgUQkwNqvMrZ8QaQmSbiDqXOWAewwou
# !unzip -q COVID-19_Radiography_Database.zip
!gdown --id 1bum9Sehb3AzUMHLhBMuowPKyr_PCrB3a
!unzip -q COVID-19_Radiography_Dataset.zip

Downloading...
From: https://drive.google.com/uc?id=1bum9Sehb3AzUMHLhBMuowPKyr_PCrB3a
To: /content/COVID-19_Radiography_Dataset.zip
100% 814M/814M [00:03<00:00, 234MB/s]


In [None]:
# Flatten the extracted layout so that every class directory directly
# contains its image files:
#   1. delete each class's masks/ directory,
#   2. move the files from <class>/images/ up into <class>/,
#   3. delete the images/ directories that are left behind.
# NOTE(review): `**` only recurses when the shell's globstar option is on;
# otherwise it acts like `*`, which still matches the class directories here.
# NOTE(review): this cell is not idempotent — once it has succeeded, the
# masks/ and images/ directories no longer exist, so a second run will error.
!rm -r COVID-19_Radiography_Dataset/**/masks/
!mv COVID-19_Radiography_Dataset/COVID/images/* COVID-19_Radiography_Dataset/COVID/
!mv COVID-19_Radiography_Dataset/Lung_Opacity/images/* COVID-19_Radiography_Dataset/Lung_Opacity/
!mv COVID-19_Radiography_Dataset/Normal/images/* COVID-19_Radiography_Dataset/Normal/
!mv COVID-19_Radiography_Dataset/Viral\ Pneumonia/images/* COVID-19_Radiography_Dataset/Viral\ Pneumonia
!rm -r COVID-19_Radiography_Dataset/**/images

### Install necessary libraries

In [None]:
# Pin Flower to the release this notebook was written and run against
# (0.18.0, as shown in the recorded install output). The client/server calls
# below use that release's API, so an unpinned install can break a fresh run.
!pip install flwr==0.18.0

Collecting flwr
  Downloading flwr-0.18.0-py3-none-any.whl (106 kB)
[K     |████████████████████████████████| 106 kB 4.3 MB/s 
Collecting grpcio<=1.43.0,>=1.27.2
  Downloading grpcio-1.43.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.1 MB)
[K     |████████████████████████████████| 4.1 MB 28.1 MB/s 
Collecting importlib-metadata<2.0.0,>=1.4.0
  Downloading importlib_metadata-1.7.0-py2.py3-none-any.whl (31 kB)
Installing collected packages: importlib-metadata, grpcio, flwr
  Attempting uninstall: importlib-metadata
    Found existing installation: importlib-metadata 4.11.3
    Uninstalling importlib-metadata-4.11.3:
      Successfully uninstalled importlib-metadata-4.11.3
  Attempting uninstall: grpcio
    Found existing installation: grpcio 1.44.0
    Uninstalling grpcio-1.44.0:
      Successfully uninstalled grpcio-1.44.0
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the

### Import necessary libraries

In [None]:
import os
import math
import time
from multiprocessing import Process

import flwr as fl
import tensorflow as tf
from flwr.server.strategy import FedAvg

### Get Compiled Model

In [None]:
def get_compiled_model():
    """Build and compile the CNN classifier used by the federated clients.

    Architecture, in order:
      * Rescaling by 1/256 on inputs of shape (256, 256, 3)
      * three Conv2D(3x3, padding='same', ReLU) + MaxPool2D stages with
        16, 32 and 64 filters respectively
      * Flatten, Dense(128, ReLU)
      * Dense(4) emitting one raw logit per class (no softmax)

    Returns:
        A ``tf.keras`` Sequential model compiled with the Adam optimizer,
        sparse categorical cross-entropy computed from logits, and an
        ``accuracy`` metric.
    """
    num_classes = 4
    layers = tf.keras.layers

    # Input scaling first, then the three convolution/pooling stages.
    stack = [layers.Rescaling(1./256, input_shape=(256, 256, 3))]
    for n_filters in (16, 32, 64):
        stack.append(layers.Conv2D(n_filters, 3, padding='same', activation='relu'))
        stack.append(layers.MaxPool2D())

    # Classification head; the last layer outputs logits, matching
    # from_logits=True in the loss below.
    stack.extend([
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(num_classes),
    ])

    model = tf.keras.models.Sequential(stack)
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])

    return model

### Prepare Dataset

In [None]:
# Load the images from the class sub-directories and split them 80/20.
# Both calls pass the same seed and validation_split; only `subset` differs,
# selecting the training part for train_ds and the held-out part for test_ds.
train_ds = tf.keras.utils.image_dataset_from_directory(
    '/content/COVID-19_Radiography_Dataset',
    validation_split=0.2,
    subset="training",
    seed=123,
)
test_ds = tf.keras.utils.image_dataset_from_directory(
    '/content/COVID-19_Radiography_Dataset',
    validation_split=0.2,
    subset="validation",
    seed=123,
)

Found 21165 files belonging to 4 classes.
Using 16932 files for training.
Found 21165 files belonging to 4 classes.
Using 4233 files for validation.


### Partition Dataset

In [None]:
def partition_dataset(n_partition, partition_index):
    """Return one client's shard of the global training and test datasets.

    Args:
        n_partition: Total number of shards each dataset is split into.
        partition_index: Index of the shard to return.

    Returns:
        A ``(train_shard, test_shard)`` tuple obtained by calling ``.shard``
        on the module-level ``train_ds`` and ``test_ds`` with the same
        arguments.
    """
    train_shard = train_ds.shard(n_partition, partition_index)
    test_shard = test_ds.shard(n_partition, partition_index)
    return train_shard, test_shard

### Federated Client

In [None]:
class FederatedClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates a Keras model locally.

    Args:
        model: Compiled Keras model whose weights are exchanged with the server.
        train_ds: Dataset used for local training in ``fit``.
        test_ds: Dataset used for validation in ``fit`` and for ``evaluate``;
            it must yield ``(inputs, labels)`` pairs, since it is passed on
            its own to Keras as validation/evaluation data.
    """

    def __init__(self, model, train_ds, test_ds) -> None:
        self.model = model
        self.train_ds = train_ds
        self.test_ds = test_ds

    def get_parameters(self, config=None):
        """Return the current model weights.

        ``config`` is accepted and ignored, so the method can be called both
        without arguments and with a ``config`` keyword.
        """
        return self.model.get_weights()

    def fit(self, parameters, config):
        """Load the given weights, train for one epoch, and return the result.

        Returns:
            ``(updated_weights, len(train_ds), {})``.
        """
        self.model.set_weights(parameters)
        self.model.fit(self.train_ds, validation_data=self.test_ds, epochs=1)
        # NOTE(review): len() of a batched dataset presumably counts batches,
        # not individual images — confirm this is the intended weight for
        # server-side aggregation.
        return self.model.get_weights(), len(self.train_ds), {}

    def evaluate(self, parameters, config):
        """Load the given weights and evaluate them on the local test data.

        Returns:
            ``(loss, len(test_ds), {"accuracy": acc})``.
        """
        self.model.set_weights(parameters)
        # The dataset already carries its labels; Keras rejects a separate
        # `y` argument when the input is a dataset, so pass test_ds alone.
        loss, acc = self.model.evaluate(self.test_ds)
        return loss, len(self.test_ds), {"accuracy": acc}

#### Start client

In [None]:
def start_client(dataset, model):
    """Create a FederatedClient and connect it to the server on port 5700.

    Args:
        dataset: ``(train, test)`` pair of datasets for this client.
        model: Compiled model the client trains and evaluates.
    """
    train_split, test_split = dataset
    federated_client = FederatedClient(model, train_split, test_split)

    # This call is made for its side effect: it runs the client against
    # the server address given here.
    fl.client.start_numpy_client("0.0.0.0:5700", client=federated_client)

### Start Server

In [None]:
def start_server(num_rounds, num_clients, fraction_fit):
    """Start the Flower server with a FedAvg strategy on port 5700.

    Args:
        num_rounds: Number of federated rounds, passed to the server config.
        num_clients: Number of clients the simulation launches. Used as
            ``min_available_clients`` (never below the previous fixed value
            of 2) so the server waits for every simulated client to be
            available instead of starting as soon as two are.
        fraction_fit: Passed to FedAvg as ``fraction_fit``.
    """
    strategy = FedAvg(
        min_available_clients=max(2, num_clients),
        fraction_fit=fraction_fit,
    )
    # Listen on port 5700 on all interfaces; clients connect to this same port.
    fl.server.start_server(
        server_address='[::]:5700',
        strategy=strategy,
        config={"num_rounds": num_rounds},
    )

### Simulation

In [None]:
def run_simulation(num_rounds: int, num_clients: int, fraction_fit: float):
    """Run a federated simulation with one server and several client processes.

    The server is started in its own process, followed after a short pause by
    ``num_clients`` client processes. Each client receives its own dataset
    shard from ``partition_dataset`` and a freshly compiled model. The call
    blocks until every spawned process has exited.

    Args:
        num_rounds: Forwarded to ``start_server``.
        num_clients: Number of client processes (and dataset shards).
        fraction_fit: Forwarded to ``start_server``.
    """
    server = Process(target=start_server, args=(num_rounds, num_clients, fraction_fit))
    server.start()
    workers = [server]

    # Give the server a couple of seconds to come up before clients connect.
    time.sleep(2)

    # Launch one process per client, each with its own shard and model.
    for client_index in range(num_clients):
        shard = partition_dataset(num_clients, client_index)
        client_model = get_compiled_model()
        worker = Process(target=start_client, args=(shard, client_model))
        worker.start()
        workers.append(worker)

    # Wait for the server and all clients to finish.
    for worker in workers:
        worker.join()

In [None]:
# Run a single federated round with four simulated clients; fraction_fit is
# forwarded to the server's FedAvg strategy.
run_simulation(
    num_rounds=1,
    num_clients=4,
    fraction_fit=0.5,
)