# Lung Segmentation Pegasus Workflow

Precise detection of the borders of organs and lesions in medical images such as X-rays, CT, or MRI scans is an essential step towards correct diagnosis and treatment planning. We implement a workflow that employs supervised learning techniques to locate lungs on X-ray images. The lung instance segmentation workflow uses [Chest X-ray](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4256233/) images to predict lung masks with a [U-Net](https://arxiv.org/abs/1505.04597) model.

The workflow uses the publicly available **Chest X-ray Masks and Labels** dataset (high-resolution X-ray images and masks). The dataset is split into training, validation, and test sets before the workflow starts; each set consists of original lung images and their associated masks. The **pre-processing** step resizes the images (lungs and masks) and normalizes the lung X-rays. In addition, for each lung image and mask pair in the training set, two new pairs are generated through **image augmentation** (e.g., rotations, flips). Next, the training and validation data are passed to the UNet **hyperparameter optimization** step, which explores different learning rates. The **training** step fine-tunes the UNet model with the recommended learning rate on the concatenated training and validation sets and saves the weights. The trained model is then used for **inference** to generate masks for the test X-ray images. Finally, the **evaluation** step generates a PDF file with the scores for relevant performance metrics and example lung segmentation images produced by the model.
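The pre-processing and augmentation described above can be pictured with a minimal sketch. The real logic lives in `bin/preprocess/preprocess.py`, which is not shown here and probably relies on OpenCV (`opencv-python` ships in the container); this version uses NumPy only so it runs anywhere. The target size of 256, min-max normalization, and the choice of one horizontal flip plus one 90-degree rotation are assumptions for illustration. The key point is that each augmentation is applied to the X-ray and its mask together, so they stay aligned.

```python
import numpy as np

def normalize(image: np.ndarray) -> np.ndarray:
    """Min-max scale pixel intensities to the range [0, 1]."""
    image = image.astype(np.float32)
    lo, hi = image.min(), image.max()
    if hi == lo:  # constant image: avoid dividing by zero
        return np.zeros_like(image)
    return (image - lo) / (hi - lo)

def resize_nearest(image: np.ndarray, size: int) -> np.ndarray:
    """Nearest-neighbour resize to size x size (a NumPy stand-in for cv2.resize)."""
    rows = np.arange(size) * image.shape[0] // size
    cols = np.arange(size) * image.shape[1] // size
    return image[rows][:, cols]

def augment_pair(image: np.ndarray, mask: np.ndarray):
    """Return two new (image, mask) pairs; each transform is applied to both
    arrays so the mask stays aligned with the X-ray."""
    flipped = (np.fliplr(image), np.fliplr(mask))
    rotated = (np.rot90(image), np.rot90(mask))
    return [flipped, rotated]

# Synthetic 512x512 "X-ray" and mask, standing in for a real training pair.
rng = np.random.default_rng(0)
xray = rng.integers(0, 256, size=(512, 512), dtype=np.uint8)
mask = (xray > 128).astype(np.uint8)

img = normalize(resize_nearest(xray, 256))
msk = resize_nearest(mask, 256)  # masks are resized but not normalized
pairs = augment_pair(img, msk)
print(img.shape, len(pairs))  # (256, 256) 2
```

Together with the original pair, this yields three training pairs per image, matching the three `_norm.png` outputs the preprocess job declares for each training image.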

![Lung Segmentation](img/segmentation.png)

**Machine learning steps in the workflow:**
<br>
<img src="img/ml_steps.png" style="width: 850px;"/>
<br>

## Container
All tools required to execute the jobs are included in the container available on Docker Hub:
<br>[Lung Segmentation Container](https://hub.docker.com/r/papajim/lung-segmentation), which runs Python and uses the machine learning libraries defined in `Docker/Dockerfile`:
* scikit-learn 
* tensorflow==2.1.0
* h5py 
* numpy==1.18.4 
* pandas 
* opencv-python 
* keras==2.3.1 
* optuna 
* segmentation_models
* matplotlib
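If you run the scripts outside the container, a quick sanity check of the pinned versions can save debugging time. This is a minimal sketch using the standard-library `importlib.metadata` (Python 3.8+); the `PINS` dictionary is copied from the list above, and the helper names are hypothetical. Unpinned packages are not checked.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned in the container's package list; unpinned packages
# (scikit-learn, h5py, pandas, ...) are not checked here.
PINS = {"tensorflow": "2.1.0", "numpy": "1.18.4", "keras": "2.3.1"}

def installed_versions(names):
    """Look up the installed version of each distribution; absent ones are omitted."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            pass  # not installed: leave it out of the result
    return found

def check_pins(pins, installed):
    """Return {package: problem} for every pin that is missing or mismatched."""
    problems = {}
    for pkg, wanted in pins.items():
        found = installed.get(pkg)
        if found is None:
            problems[pkg] = "not installed"
        elif found != wanted:
            problems[pkg] = f"found {found}, expected {wanted}"
    return problems

if __name__ == "__main__":
    report = check_pins(PINS, installed_versions(PINS))
    print(report or "all pinned versions match")
```

Keeping the lookup separate from the comparison makes `check_pins` easy to test with a hand-written dictionary.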

## Input Data
Sample input data is provided in the `inputs` directory, which contains images and masks for training and testing.
<br>`inputs/train_images` **:** consists of 512x512 chest X-ray images for training
<br>`inputs/train_masks` **:** consists of 512x512 lung masks for training
<br>`inputs/test_images` **:** consists of 256x256 chest X-ray images for testing
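The workflow code matches each X-ray to its mask by file name: judging from the slicing in `train_test_val_split`, a mask for `<stem>.png` is expected to be named `<stem>_mask.png`. A minimal sketch for checking a directory listing before planning, assuming that naming convention (the function and example file names are hypothetical):

```python
from pathlib import Path

def pair_images_with_masks(image_names, mask_names, ignore=frozenset()):
    """Match each X-ray PNG to '<stem>_mask.png'.

    Returns (pairs, unmatched): pairs maps image name -> mask name, unmatched
    lists images that have no mask. Non-PNG files, files with 'mask' in their
    name, and names in `ignore` are skipped, as in train_test_val_split.
    """
    masks = set(mask_names)
    pairs, unmatched = {}, []
    for name in sorted(image_names):
        if not name.endswith(".png") or "mask" in name.lower() or name in ignore:
            continue
        candidate = Path(name).stem + "_mask.png"
        if candidate in masks:
            pairs[name] = candidate
        else:
            unmatched.append(name)
    return pairs, unmatched

pairs, missing = pair_images_with_masks(
    ["CHNCXR_0001_0.png", "CHNCXR_0002_0.png", "notes.txt"],
    ["CHNCXR_0001_0_mask.png"],
)
print(pairs)    # {'CHNCXR_0001_0.png': 'CHNCXR_0001_0_mask.png'}
print(missing)  # ['CHNCXR_0002_0.png']
```

In practice you would pass the names from `inputs/train_images` and `inputs/train_masks`, and `IGNORE_IMAGES` as `ignore`.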


## Workflow
The workflow pre-processes the input data and then trains a machine learning model to automatically predict lung masks.

<img src="img/workflow.png" style="width: 600px;"/>

<br>The jobs in the workflow are described in the table below:

| Job Label         | Description                                                               |
| ------------------|---------------------------------------------------------------------------|
| preprocess_train  | data preprocessing and augmentation for the training set of X-ray images  |
| preprocess_val    | data preprocessing for the validation set of X-ray images                 |
| preprocess_test   | data preprocessing for the test set of X-ray images                       |
| hpo               | hyperparameter optimization (learning rate) for the UNet model            |
| train_model       | training and fine-tuning the UNet model                                   |
| predict_masks     | predicting lung masks for the test images                                 |
| evaluate          | generating a PDF report with scores for relevant performance metrics      |
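The `hpo` job explores learning rates. The container ships Optuna, and the checkpoint name `study_checkpoint.pkl` suggests an Optuna study drives the real search in `bin/model/hpo.py`, but that script is not shown here. The sketch below swaps in plain random search over a log-uniform range so it runs without Optuna; the range, trial count, and the toy objective (a hypothetical stand-in for "train UNet with this learning rate and return the validation loss") are all assumptions.

```python
import math
import random

def sample_log_uniform(rng: random.Random, low: float, high: float) -> float:
    """Draw a value uniformly on a log scale between low and high."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))

def search_learning_rate(objective, n_trials=20, low=1e-5, high=1e-1, seed=4):
    """Plain random search over the learning rate; lower objective is better.
    Returns (best_lr, best_score)."""
    rng = random.Random(seed)
    best_lr, best_score = None, float("inf")
    for _ in range(n_trials):
        lr = sample_log_uniform(rng, low, high)
        score = objective(lr)
        if score < best_score:
            best_lr, best_score = lr, score
    return best_lr, best_score

def toy_validation_loss(lr: float) -> float:
    """Hypothetical stand-in for training UNet and measuring validation loss.
    It is minimised at lr = 1e-3 purely for illustration."""
    return (math.log10(lr) + 3) ** 2

best_lr, best_loss = search_learning_rate(toy_validation_loss, n_trials=50)
print(f"best learning rate ~ {best_lr:.2e} (toy loss {best_loss:.3f})")
```

Sampling on a log scale gives each order of magnitude equal weight, which is the usual choice for learning rates; the best value is what `train_model` would then read from the results file.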


## 1. Create the Lung Segmentation Workflow

By now, you have a good idea about the Pegasus Workflow API.
We now create the workflow for the Lung segmentation based on the picture above.
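The edges in the picture are not declared by hand in the code that follows: Pegasus derives them from the files each job lists with `add_inputs` and `add_outputs`. A minimal sketch of that idea, using simplified, hypothetical file names that follow the workflow's naming pattern (masks and the evaluate job's training inputs are omitted for brevity):

```python
def infer_edges(jobs):
    """Derive (parent, child) edges from file usage.

    jobs maps a job name to (inputs, outputs), both sets of file names. A child
    depends on a parent when it reads a file the parent writes.
    """
    producer = {}
    for name, (_, outputs) in jobs.items():
        for f in outputs:
            producer[f] = name
    edges = set()
    for name, (inputs, _) in jobs.items():
        for f in inputs:
            parent = producer.get(f)
            if parent is not None and parent != name:
                edges.add((parent, name))
    return sorted(edges)

# One training, validation, and test image each (hypothetical names).
jobs = {
    "preprocess_train": ({"train_a.png"}, {"train_a_norm.png"}),
    "preprocess_val": ({"val_b.png"}, {"val_b_norm.png"}),
    "preprocess_test": ({"test_c.png"}, {"test_c_norm.png"}),
    "hpo": ({"train_a_norm.png", "val_b_norm.png"}, {"study_results.txt"}),
    "train_model": ({"study_results.txt", "train_a_norm.png", "val_b_norm.png"}, {"model.h5"}),
    "predict_masks": ({"model.h5", "test_c_norm.png"}, {"pred_c_norm_mask.png"}),
    "evaluate": ({"test_c_norm.png", "pred_c_norm_mask.png"}, {"EvaluationAnalysis.pdf"}),
}
for parent, child in infer_edges(jobs):
    print(f"{parent} -> {child}")
```

Because dependencies come from file names, the logical file names passed to `File(...)` must match exactly between the producing and the consuming job.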

All workflow parameters and input dataset locations are set at the beginning of the workflow code. The workflow runs on the sample dataset included in the repository under the `inputs` directory.
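The `train_test_val_split` function below walks over the shuffled images and assigns the i-th image (0-based) to training while `i + 1 <= 0.7 * l`, to validation while `i + 1 <= 0.9 * l`, and to test otherwise, i.e. roughly a 70/20/10 split. The helper below is a hypothetical reimplementation of just those thresholds, to show how many images land in each split:

```python
def split_counts(n: int) -> dict:
    """Count how many of n shuffled images land in each split, using the same
    thresholds as train_test_val_split (0-based index i, compared as i + 1)."""
    counts = {"train": 0, "val": 0, "test": 0}
    for i in range(n):
        if i + 1 <= 0.7 * n:
            counts["train"] += 1
        elif i + 1 <= 0.9 * n:
            counts["val"] += 1
        else:
            counts["test"] += 1
    return counts

print(split_counts(10))  # {'train': 7, 'val': 2, 'test': 1}
print(split_counts(1))   # {'train': 0, 'val': 0, 'test': 1}
```

With a single image, everything goes to the test split and the training jobs receive no data, so `num_inputs` should be set to a reasonably large number, or to `-1`, which uses every available image.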

In [None]:
import logging as log
import math
import sys, os
from argparse import ArgumentParser
from pathlib import Path
import pandas as pd
import json
import random
import numpy as np

log.basicConfig(level=log.INFO)

# --- Import Pegasus API -----------------------------------------------------------
from Pegasus.api import *

# --- Top Directory Setup ----------------------------------------------------------
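# __file__ is not defined inside a Jupyter notebook, so use the current working
# directory (assumes the notebook is started from the repository root).
top_dir = Path.cwd().resolve()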


######################## WORKFLOW PARAMETERS ########################
IGNORE_IMAGES = {'CHNCXR_0025_0.png', 'CHNCXR_0036_0.png', 'CHNCXR_0037_0.png', 'CHNCXR_0038_0.png', 'CHNCXR_0039_0.png', 'CHNCXR_0040_0.png', 'CHNCXR_0065_0.png', 'CHNCXR_0181_0.png', 'CHNCXR_0182_0.png', 'CHNCXR_0183_0.png', 'CHNCXR_0184_0.png', 'CHNCXR_0185_0.png', 'CHNCXR_0186_0.png', 'CHNCXR_0187_0.png', 'CHNCXR_0188_0.png', 'CHNCXR_0189_0.png', 'CHNCXR_0190_0.png', 'CHNCXR_0191_0.png', 'CHNCXR_0192_0.png', 'CHNCXR_0193_0.png', 'CHNCXR_0194_0.png', 'CHNCXR_0195_0.png', 'CHNCXR_0196_0.png', 'CHNCXR_0197_0.png', 'CHNCXR_0198_0.png', 'CHNCXR_0199_0.png', 'CHNCXR_0200_0.png', 'CHNCXR_0201_0.png', 'CHNCXR_0202_0.png', 'CHNCXR_0203_0.png', 'CHNCXR_0204_0.png', 'CHNCXR_0205_0.png', 'CHNCXR_0206_0.png', 'CHNCXR_0207_0.png', 'CHNCXR_0208_0.png', 'CHNCXR_0209_0.png', 'CHNCXR_0210_0.png', 'CHNCXR_0211_0.png', 'CHNCXR_0212_0.png', 'CHNCXR_0213_0.png', 'CHNCXR_0214_0.png', 'CHNCXR_0215_0.png', 'CHNCXR_0216_0.png', 'CHNCXR_0217_0.png', 'CHNCXR_0218_0.png', 'CHNCXR_0219_0.png', 'CHNCXR_0220_0.png', 'CHNCXR_0336_1.png', 'CHNCXR_0341_1.png', 'CHNCXR_0342_1.png', 'CHNCXR_0343_1.png', 'CHNCXR_0344_1.png', 'CHNCXR_0345_1.png', 'CHNCXR_0346_1.png', 'CHNCXR_0347_1.png', 'CHNCXR_0348_1.png', 'CHNCXR_0349_1.png', 'CHNCXR_0350_1.png', 'CHNCXR_0351_1.png', 'CHNCXR_0352_1.png', 'CHNCXR_0353_1.png', 'CHNCXR_0354_1.png', 'CHNCXR_0355_1.png', 'CHNCXR_0356_1.png', 'CHNCXR_0357_1.png', 'CHNCXR_0358_1.png', 'CHNCXR_0359_1.png', 'CHNCXR_0360_1.png', 'CHNCXR_0481_1.png', 'CHNCXR_0482_1.png', 'CHNCXR_0483_1.png', 'CHNCXR_0484_1.png', 'CHNCXR_0485_1.png', 'CHNCXR_0486_1.png', 'CHNCXR_0487_1.png', 'CHNCXR_0488_1.png', 'CHNCXR_0489_1.png', 'CHNCXR_0490_1.png', 'CHNCXR_0491_1.png', 'CHNCXR_0492_1.png', 'CHNCXR_0493_1.png', 'CHNCXR_0494_1.png', 'CHNCXR_0495_1.png', 'CHNCXR_0496_1.png', 'CHNCXR_0497_1.png', 'CHNCXR_0498_1.png', 'CHNCXR_0499_1.png', 'CHNCXR_0500_1.png', 'CHNCXR_0502_1.png', 'CHNCXR_0505_1.png', 'CHNCXR_0560_1.png', 'CHNCXR_0561_1.png', 'CHNCXR_0562_1.png', 'CHNCXR_0563_1.png', 
'CHNCXR_0564_1.png', 'CHNCXR_0565_1.png'}
NUM_OF_HPO_JOBS = 1
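num_inputs = -1   # number of lung images to use; -1 uses every available image
gpus = False      # set to True to request one GPU for the hpo, train_model, and predict_masks jobs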

# --- Get input files --------------------------------------------------------------
LUNG_IMG_DIR = Path("inputs/train_images")
LUNG_MASK_IMG_DIR = Path("inputs/train_masks")


# --- Data Preprocessing function --------------------------------------------------
def train_test_val_split(preprocess, training_input_files, mask_files, processed_training_files, processed_val_files, processed_test_files, training_masks, val_masks, test_masks, num_inputs):
    # random.shuffle() below draws from Python's `random` module, so seed it as well
    np.random.seed(4)
    random.seed(4)
    process_jobs = [Job(preprocess).add_args("--type", group) for group in ["train", "val", "test"]]
    augmented_masks = []


    # --- Write ReplicaCatalog -----------------------------------------------------
    rc = ReplicaCatalog()

    # add mask images to rc
    for f in LUNG_MASK_IMG_DIR.iterdir():
        if f.name.endswith(".png"):
            if f.name in IGNORE_IMAGES:
                continue
            
            mask_files.append(File(f.name))
            rc.add_replica(site="local", lfn=f.name, pfn=f.resolve())
    
    # Register the HPO checkpoint and the model helper scripts. If a file does not
    # exist yet (e.g. the checkpoint before the first HPO run), create an empty placeholder.
    for fname in ["inputs/checkpoints/study_checkpoint.pkl", "bin/model/unet.py", "bin/model/utils.py"]:
        p = top_dir / fname
        if not p.exists():
            p.parent.mkdir(parents=True, exist_ok=True)
            p.touch()
        replicaFile = File(p.name)
        rc.add_replica(site="local", lfn=replicaFile, pfn=p)

    for f in LUNG_IMG_DIR.iterdir():
        if f.name.endswith(".png") and ("mask" not in f.name.lower()) and (f.name not in IGNORE_IMAGES):
            training_input_files.append(f)

    random.shuffle(training_input_files)
    l = len(training_input_files) if num_inputs == -1 else num_inputs
    print('Length ', l)

    i = 0
    for file in training_input_files[:l]:  # only the first l shuffled images are used
        if i+1 <= 0.7*l:
            f = File("train_{}".format(file.name))
            rc.add_replica(site="local", lfn=f, pfn=file.resolve()) 

            process_jobs[0].add_inputs(f)
            log.info("preprocess_train adding input {}".format(f))
            op_file1 = File(f.lfn.replace(".png", "_norm.png"))
            op_file2 = File(f.lfn.replace(".png", "_0_norm.png"))
            op_file3 = File(f.lfn.replace(".png", "_1_norm.png"))
            op_mask2 = File(file.name.replace(".png", "_0_mask.png"))
            op_mask3 = File(file.name.replace(".png", "_1_mask.png"))

            for m in mask_files:
                mname = m.lfn[0:-9]
                if file.name[0:-4] == mname:
                    training_masks.append(m)
                    break

            process_jobs[0].add_outputs(op_file1, op_file2, op_file3, op_mask2, op_mask3)
            augmented_masks.extend([op_mask2, op_mask3])
            processed_training_files.extend([op_file1, op_file2, op_file3])

        elif i+1 <= 0.9*l:
            f = File("val_{}".format(file.name))
            rc.add_replica(site="local", lfn=f, pfn=file.resolve())

            process_jobs[1].add_inputs(f)
            log.info("preprocess_val adding input {}".format(f))
            op_file = File(f.lfn.replace(".png", "_norm.png"))
            for m in mask_files:
                mname = m.lfn[0:-9]
                if file.name[0:-4] == mname:
                    val_masks.append(m)
                    break
                    
            process_jobs[1].add_outputs(op_file)
            processed_val_files.append(op_file)

        else:
            f = File("test_{}".format(file.name))
            rc.add_replica(site="local", lfn=f, pfn=file.resolve())

            process_jobs[2].add_inputs(f)
            op_file = File(f.lfn.replace(".png", "_norm.png"))
            for m in mask_files:
                mname = m.lfn[0:-9]
                if file.name[0:-4] == mname:
                    test_masks.append(m)
                    break

            process_jobs[2].add_outputs(op_file)
            log.info("preprocess_test adding input {}".format(f))
            processed_test_files.append(op_file)

        i += 1

    log.info("writing rc with {} files collected from: {}".format(len(training_input_files)+len(mask_files), [LUNG_IMG_DIR, LUNG_MASK_IMG_DIR]))
    rc.write()
    process_jobs[0].add_inputs(*training_masks)
    training_masks.extend(augmented_masks)
    return process_jobs



# --- Write SiteCatalog --------------------------------------------------------
sc = SiteCatalog()
shared_scratch_dir = os.path.join(top_dir, "scratch")
local_storage_dir = os.path.join(top_dir, "output")

local = Site("local")\
            .add_directories(
                Directory(Directory.SHARED_SCRATCH, shared_scratch_dir)
                    .add_file_servers(FileServer("file://" + shared_scratch_dir, Operation.ALL)),
                Directory(Directory.LOCAL_STORAGE, local_storage_dir)
                    .add_file_servers(FileServer("file://" + local_storage_dir, Operation.ALL))
            )

condorpool = Site("condorpool")\
                .add_pegasus_profile(
                    style="condor",
                    data_configuration="condorio"
                )\
                .add_condor_profile(universe="vanilla")

sc.add_sites(local, condorpool)
sc.write()


# --- Write Properties ---------------------------------------------------------
props = Properties()
props["pegasus.mode"] = "development"
props.write()



# --- Write TransformationCatalog ----------------------------------------------
tc = TransformationCatalog()

# all jobs to be run in the following container
unet_wf_cont = Container(
                "unet_wf_model",
                Container.SINGULARITY,
                image="docker:///papajim/lung-segmentation:latest",
                image_site="docker_hub"
            )

tc.add_containers(unet_wf_cont)

preprocess = Transformation(
                "preprocess",
                site="local",
                pfn=top_dir / "bin/preprocess/preprocess.py",
                is_stageable=True,
                container=unet_wf_cont
            )

unet = Transformation(
                "unet",
                site="local",
                pfn=top_dir / "bin/model/unet.py",
                is_stageable=True,
                container=unet_wf_cont
            )

utils = Transformation(
                "utils",
                site="local",
                pfn=top_dir / "bin/model/utils.py",
                is_stageable=True,
                container=unet_wf_cont
            )

hpo_task = Transformation( 
                "hpo",
                site="local",
                pfn=top_dir / "bin/model/hpo.py",
                is_stageable=True,
                container=unet_wf_cont
            ).add_pegasus_profile(cores=8, runtime=14400)
hpo_task.add_profiles(Namespace.CONDOR, key='request_memory', value='8 GB')


train_model = Transformation( 
                "train_model",
                site="local",
                pfn=top_dir / "bin/model/train_model.py",
                is_stageable=True,
                container=unet_wf_cont
            ).add_pegasus_profile(cores=8, runtime=7200)
train_model.add_profiles(Namespace.CONDOR, key='request_memory', value='8 GB')

predict_masks = Transformation( 
                "predict_masks",
                site="local",
                pfn=top_dir / "bin/model/prediction.py",
                is_stageable=True,
                container=unet_wf_cont
            ).add_pegasus_profile(cores=8, runtime=3600)


evaluate_model = Transformation( 
                "evaluate",
                site="local",
                pfn=top_dir / "bin/model/evaluate.py",
                is_stageable=True,
                container=unet_wf_cont
            )

if gpus:
    hpo_task.add_pegasus_profile(gpus=1)
    train_model.add_pegasus_profile(gpus=1)
    predict_masks.add_pegasus_profile(gpus=1)

tc.add_transformations(preprocess, hpo_task, train_model, predict_masks, evaluate_model, unet, utils)

log.info("writing tc with transformations: {}, containers: {}".format([k for k in tc.transformations], [k for k in tc.containers]))
tc.write()


# --- Generate and run Workflow ------------------------------------------------
wf = Workflow("lung-instance-segmentation-wf")

#create preprocess job
training_input_files = []
mask_files = []
processed_training_files = []
processed_val_files = []
processed_test_files = []
training_masks = []
val_masks = []
test_masks = []
process_jobs = train_test_val_split(preprocess, training_input_files, mask_files, processed_training_files, processed_val_files, processed_test_files, training_masks, val_masks, test_masks, num_inputs)
wf.add_jobs(*process_jobs)
log.info("generated 3 preprocess jobs")

# create hpo job
log.info("generating hpo job")
hpo_checkpoint_result = File("study_checkpoint.pkl")
study_result_list = []
unet_file = File("unet.py")
study_result = File("study_results.txt")
study_result_list.append(study_result)
hpo_job = Job(hpo_task)\
            .add_args("--results_file", study_result)\
            .add_inputs(*processed_training_files, *processed_val_files, *training_masks, *val_masks, unet_file)\
            .add_outputs(study_result)\
            .add_checkpoint(hpo_checkpoint_result)
wf.add_jobs(hpo_job)

# create training job
log.info("generating train_model job")
model = File("model.h5")
utils_file = File("utils.py")
train_job = Job(train_model)\
                .add_args("--params_file", study_result_list[0])\
                .add_inputs(study_result_list[0], *processed_training_files, *processed_val_files, *training_masks, *val_masks, unet_file, utils_file)\
                .add_outputs(model)
wf.add_jobs(train_job)

# create mask prediction job
log.info("generating prediction job; using {} test lung images".format(len(processed_test_files)))
predicted_masks = [File("pred_"+f.lfn.replace(".png", "_mask.png")[5:]) for f in processed_test_files]
predict_job = Job(predict_masks)\
                .add_inputs(model, *processed_test_files, unet_file)\
                .add_outputs(*predicted_masks)
wf.add_jobs(predict_job)

# create evaluate job
pdf_analysis = File("EvaluationAnalysis.pdf")
evaluate_job = Job(evaluate_model)\
                .add_inputs(*processed_training_files, *processed_test_files, *predicted_masks, *test_masks, unet_file)\
                .add_outputs(pdf_analysis)
wf.add_jobs(evaluate_job)

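The `evaluate` job compares the predicted masks with the test masks, but the metrics it reports are computed inside `bin/model/evaluate.py`, which is not shown here. Dice and IoU (Jaccard index) are the standard overlap scores for segmentation masks, so the sketch below shows how they can be computed for a pair of binary masks; treat it as an illustration, not as the workflow's exact metric code. The small `eps` term is an assumption that keeps both scores defined when a mask is empty.

```python
import numpy as np

def dice_and_iou(pred, truth, eps=1e-7):
    """Dice coefficient and IoU (Jaccard index) for two binary masks.

    Any non-zero pixel counts as foreground, so 0/1 and 0/255 masks both work.
    eps keeps the scores defined (and equal to 1.0) when both masks are empty.
    """
    pred = np.asarray(pred) != 0
    truth = np.asarray(truth) != 0
    inter = np.logical_and(pred, truth).sum()
    union = np.logical_or(pred, truth).sum()
    dice = (2.0 * inter + eps) / (pred.sum() + truth.sum() + eps)
    iou = (inter + eps) / (union + eps)
    return float(dice), float(iou)

pred = np.zeros((4, 4), dtype=np.uint8)
pred[0, :] = 255            # 4 predicted lung pixels
truth = np.zeros((4, 4), dtype=np.uint8)
truth[0:2, 0:2] = 255       # 4 true lung pixels, 2 of them shared with pred
dice, iou = dice_and_iou(pred, truth)
print(round(dice, 3), round(iou, 3))  # 0.5 0.333
```

If the model outputs probabilities rather than 0/255 masks, threshold them first (0.5 is a common, but here assumed, cut-off).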

## 2. Plan and Submit the Workflow

We will now plan and submit the workflow for execution. The jobs run on the **condorpool** site, i.e., the selected ACCESS resource, and the outputs are staged back to the **local** site.

In [None]:
wf.plan(submit=True, dir="runs", sites=["condorpool"], output_sites=["local"])

After the workflow has been successfully planned and submitted, you can use the Python `Workflow` object to monitor its status. The output shows the number of jobs in each state, including whether jobs are idle or running.

In [None]:
wf.status()

## 3. Launch Pilot Jobs on ACCESS Resources

At this point you should have some idle jobs in the queue. They are idle because there are no resources to execute them on yet. Resources can be brought in with the HTCondor Annex tool, which sends pilot jobs (also called glideins) to the ACCESS resource providers. These pilots have the following properties:

* A pilot can run multiple user jobs: it stays active until no more user jobs are available or until its end of life has been reached, whichever comes first.
* A pilot is partitionable: job slots are created dynamically based on the resource requirements of the user jobs. This means multiple user jobs can fit on a compute node at the same time.
* A pilot only runs jobs for the user who started it.
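To make "partitionable" concrete, here is a toy first-fit illustration of carving one pilot's resources into slots. It is not HTCondor's matchmaking algorithm; the 16-core / 32 GB node and the job names are hypothetical, and the 8-core / 8 GB request simply mirrors the `cores` and `request_memory` profiles set on the `hpo` and `train_model` transformations.

```python
def pack_jobs(node_cores, node_mem_gb, jobs):
    """Toy first-fit packing of job requests onto one partitionable pilot.

    jobs is a list of (name, cores, mem_gb). Each job that still fits in the
    remaining cores and memory gets its own slot; the rest keep waiting.
    Returns (placed, waiting).
    """
    free_cores, free_mem = node_cores, node_mem_gb
    placed, waiting = [], []
    for name, cores, mem in jobs:
        if cores <= free_cores and mem <= free_mem:
            free_cores -= cores
            free_mem -= mem
            placed.append(name)
        else:
            waiting.append(name)
    return placed, waiting

# Three hypothetical 8-core / 8 GB jobs on a hypothetical 16-core / 32 GB node.
placed, waiting = pack_jobs(16, 32, [("job_a", 8, 8), ("job_b", 8, 8), ("job_c", 8, 8)])
print(placed, waiting)  # ['job_a', 'job_b'] ['job_c']
```

Cores run out before memory here, so the third job waits until a slot is released; in the real system it would then be matched to the freed resources or to another pilot.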

The process of starting pilots is described in the [ACCESS Pegasus Documentation](https://xsedetoaccess.ccs.uky.edu/confluence/redirect/ACCESS+Pegasus.html).

## 4. Statistics

Depending on whether the workflow finished successfully, you have different options for what to do next. If the workflow failed, you can use `wf.analyze()` to get help finding out what went wrong. If the workflow finished successfully, you can pull some statistics from the provenance database:

In [None]:
wf.statistics()