# Jarvis Data `Client()`

The `Client()` object, part of the `jarvis.train.client` module, provides a simple yet powerful interface for loading and preprocessing data for neural network training. In this notebook we will explore the following functionality:

* creating `Client()` objects and generators
* customizing `Client()` via `*.yml` files
* customizing `Client()` via class overload

# Environment

### Jarvis library

Jarvis is a custom Python package to facilitate data science and deep learning for healthcare. Let us begin by installing the Jarvis library and required dependencies:

```python
# --- Install jarvis (only in Google Colab or local runtime)
%pip install jarvis-md
```

### Imports

Use the following lines to import any needed libraries:

```python
import os
from jarvis.train.client import Client
from jarvis.train import datasets
from jarvis.utils.general import *
```

# Data

The data used in this tutorial will consist of a small toy cohort of non-contrast CT images of the head and corresponding whole brain segmentation masks.

### Download

The custom `datasets.download(...)` method can be used to download a local copy of the dataset. By default the dataset will be archived at `/data/raw/ct_bet_demo`; if needed, an alternate location may be specified using `datasets.download(name=..., path=...)`.

```python
# --- Download data
paths = datasets.download(name='ct/bet-demo')
```

While the location of downloaded data is not required for baseline functionality, manipulating the underlying data structures and configuration files may be needed to implement customized training. As a result, it may be useful to note this location for future reference. All code and key configuration files are located relative to a project `code` root directory. This directory can be retrieved in one of two ways:

1. The return value of a `datasets.download(...)` call will always be a Python dictionary with two entries:

* `paths['code']`: root directory of all Python code and configuration files
* `paths['data']`: root directory of all raw data

Note that by default, these two directories are identical and placed in `/data/raw/[dataset_name]`.

2. Alternatively, the paths can be recovered using the `jarvis.utils.general.tools` module:

```python
# --- Retrieve paths using Jarvis tools
from jarvis.utils.general import tools as jtools

paths = jtools.get_paths('ct/bet-demo')
```

Within the `paths['code']` directory, a standard folder hierarchy is typically found:

```
| --- data /
      | --- csvs /
          | --- db-all.csv.gz
          ...
      | --- ymls /
          | --- db-all.yml
          | --- client.yml
          ...
```

Keep a note of this structure and reference this figure as needed in the rest of this tutorial. As we will be manipulating the `client.yml` file throughout the course of this tutorial, let us save a reference to this location now:

```python
# --- Set CLIENT_PATH for tutorial
CLIENT_PATH = '{}/data/ymls/client.yml'.format(paths['code'])
assert os.path.exists(CLIENT_PATH)
```

# Training

The primary function of the Jarvis data `Client()` is to provide a seamless interface for loading heterogeneous, multidimensional healthcare datasets for machine learning training. The key consideration is that for complex datasets, the *method* by which one chooses to load data (single slice, multiple slices, 3D, patch-based, etc.) is in fact a key hyperparameter that will need to be tested empirically for best results. In addition, various permutations in data normalization, augmentation and prediction targets (binary vs. multiclass vs. multiple loss functions, etc.) may need to be optimized during experimentation.

The Jarvis `Client()` aims to consolidate these various options into a simple API that can be customized via a single `configs` Python dictionary or through a client `YAML` file. After the required parameters are defined, an instantiated `Client()` object can be used to create training and validation Python generators that can be used agnostically in a number of machine learning libraries, including Tensorflow / Keras.

## Creating a `Client()` object

There are two ways of creating a `Client()` object. The first and simplest method is to use the `datasets.prepare(name=..., keyword=...)` function. The `name` argument references the dataset name (identical to the name used for data download). The optional `keyword` argument can be used to specify a particular predefined client configuration (otherwise the first `client*.yml` file in the `/data/ymls/` directory will be chosen).

```python
# --- Creating client / generators - option (1)
gen_train, gen_valid, client = datasets.prepare('ct/bet-demo')
```

This method is convenient, and in fact creates the Python generators needed for algorithm training; however, it gives the user somewhat limited control over the `Client()` object itself. For greater flexibility, consider a two-step approach: first define the client manually by pointing to a specific `*.yml` file, then create the required Python generators in a second step.

```python
# --- Create client / generators - option (2)
client = Client(CLIENT_PATH)
gen_train, gen_valid = client.create_generators()
```

## Python generators

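At this point, the generic Python generators above will yield training and validation data that can be plugged into standard machine learning library APIs. In Tensorflow / Keras, use the `model.fit(...)` method:

```python
model.fit(
    x=gen_train,
    validation_data=gen_valid,
    steps_per_epoch=...,    # batches per epoch (required by Keras if the generator never ends)
    validation_steps=...,   # validation batches per epoch (same reason)
    epochs=...)
```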

Note that any number of valid Python iterator techniques can be used to access the data manually:

```python
# --- Load next training batch
xs, ys = next(gen_train)
```

## `xs` and `ys` 

Each iteration through a `Client()`-created Python generator yields two dictionaries, `xs` and `ys`, which conform to the Tensorflow / Keras API for model input(s) and output(s). Yielding Python dictionaries allows flexible definition of one or more inputs and outputs, as long as the same names are used by the `Client()` Python generator and the Tensorflow / Keras 2.0 `Model()`. The schematic below highlights the key considerations:

```python
from tensorflow.keras import Input, Model, layers

# --- Define model

Client()             Model()

xs.keys()            inputs = {}
* 'dat_1'       ==>  inputs['dat_1'] = Input(..., name='dat_1')
* 'dat_2'       ==>  inputs['dat_2'] = Input(..., name='dat_2')
* ...

ys.keys()            logits = {}
* 'lbl_1'       ==>  logits['lbl_1'] = layers.Conv3D(..., name='lbl_1')(...)
* 'lbl_2'       ==>  logits['lbl_2'] = layers.Conv3D(..., name='lbl_2')(...)
* ...

# --- Instantiate model

model = Model(inputs=inputs, outputs=logits)

# --- Compile model

model.compile(
    optimizer=...,
    loss={
        'lbl_1': ..., 
        'lbl_2': ...},
    metrics={
        'lbl_1': ...,
        'lbl_2': ...})

```

To summarize,

1. All keys in `xs.keys()` must:

* match a corresponding key in the `inputs` dict that is eventually passed to `Model(inputs=inputs, ...)`
* match the `name` of the Keras input tensor defined via `Input(..., name=name)`

2. All keys in `ys.keys()` must:

* match a corresponding key in the `logits` dict that is eventually passed to `Model(..., outputs=logits)`
* match the `name` of the Keras output logit tensor defined via `layers.__(..., name=name)(...)`
* match a corresponding key in the `loss` and `metrics` arguments during the `model.compile(...)` call
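These naming rules can be sanity-checked before calling `model.fit(...)`. The helper below is a hypothetical sketch (it is not part of Jarvis or Keras) that compares the keys of one batch from the generator against the input names, output names and `loss` keys you plan to use for the Keras model, and returns any mismatches it finds.

```python
def check_key_alignment(xs, ys, input_names, output_names, loss_keys):
    """Return a list of mismatches between generator keys and model names.

    Hypothetical helper: xs / ys are the dicts yielded by the generator,
    input_names / output_names are the `name=` values given to Input(...)
    and to the output layers, loss_keys are the keys of the compile() loss dict.
    An empty list means every name lines up.
    """
    problems = []
    for key in sorted(set(xs) - set(input_names)):
        problems.append(f"xs['{key}'] has no matching Input(name='{key}')")
    for key in sorted(set(input_names) - set(xs)):
        problems.append(f"Input '{key}' is never fed by xs")
    for key in sorted(set(ys) - set(output_names)):
        problems.append(f"ys['{key}'] has no output layer named '{key}'")
    for key in sorted(set(ys) - set(loss_keys)):
        problems.append(f"ys['{key}'] has no entry in the loss dict")
    return problems

# --- Example: names aligned with the client.yml used in this tutorial
print(check_key_alignment({'dat': 0, 'msk': 0}, {'bet': 0},
                          input_names=['dat', 'msk'],
                          output_names=['bet'], loss_keys=['bet']))
```

In practice `xs` and `ys` would come from `next(gen_train)`; only their keys are inspected.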

### Testing

For algorithm testing, it may be necessary to load data in a different manner than during training. For example, even if single 2D slices from a 3D volume are used during training (or smaller patches from a large single 2D image), it will be important to perform algorithm testing of the full volume. To create dedicated Python generators in *test mode* (full 3D matrix shape), pass the `test=True` argument into the `client.create_generators(...)` method:

```python
# --- Create test generators
test_train, test_valid = client.create_generators(test=True)
```

Note that in most scenarios, the `test_valid` generator will load the primary cohort needed for testing (e.g. the validation split). The `test_train` generator can be used to iterate through the training cohort itself if comparisons between train / valid performance are desired.

# Configuration

The configuration details for a `Client()` object are defined in a client `*.yml` file (the same file that is passed to the `Client()` constructor when instantiating a new object):

```python
# --- Instantiate a new Client
client = Client(CLIENT_PATH)
```

The following shows the contents of the `client.yml` file in this example:

```yml
_id:
  project: ct/bet-demo
  version: null
_db: /data/ymls/db-sum.yml
batch:
  fold: -1
  sampling:
    bg: 0.5
    fg: 0.5
  size: 8
specs:
  xs:
    dat:
      dtype: float32
      loads: dat
      norms:
        clip:
          max: 256
          min: 0
        scale: 64
        shift: 32
      shape:
      - 1
      - 512
      - 512
      - 1
    msk:
      dtype: float32
      loads: msk
      norms:
        clip:
          max: 1
          min: 0
      shape:
      - 1
      - 512
      - 512
      - 1
  ys:
    bet:
      dtype: uint8
      loads: bet
      norms:
        clip:
          max: 1
      shape:
      - 1
      - 512
      - 512
      - 1
```

The `_id` and `_db` fields are protected attributes that should not be changed. For more information, see the corresponding tutorial for creating `Client()` objects. In this tutorial, we will focus on customizing the `batch` and `specs` attributes.

## `Client()` Data

In general, a total of `N` training examples is defined for each cohort, where a *training example* is the minimal single unit used for training (e.g. a single 2D slice, multi-slice slab, 2D patch, etc.). Each unit of training is specified as a single row in a master `*.csv(.gz)` file (see folder hierarchy above). The data itself is stored as two attributes (Pandas `DataFrame` objects) of the `client.db` variable:

* `client.db.fnames`: all relevant file names for serialized data
* `client.db.header`: all other relevant data in addition to file names

```python
# --- Pandas DataFrame for serialized data file fnames
client.db.fnames
```

```python
# --- Pandas DataFrame for all other non file name data
client.db.header
```

Note that file names are often repeated in this representation because each **row** represents a single 2D slice: since many 2D slices are used for training (each row being one *unit* of training), the file containing a full 3D volume is repeated once per slice. To see the full combined data for any single row, use the `client.db.row(...)` method:

```python
# --- Show specification for first row of data
client.db.row(0)
```

Note that this is the raw metadata used to load data by the `Client()` object.
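To make the row-per-slice layout concrete, the toy example below builds two small DataFrames shaped like `client.db.fnames` and `client.db.header` for a single 3-slice volume. The file names, the integer `coord` values and the `combined_row` helper are illustrative assumptions; `combined_row` only approximates the kind of combined view that `client.db.row(...)` returns.

```python
import pandas as pd

# --- One 3D volume with 3 slices -> three rows sharing the same file names
fnames = pd.DataFrame({
    'dat': ['vol_000/dat.hdf5'] * 3,   # hypothetical file names
    'bet': ['vol_000/bet.hdf5'] * 3,
})
header = pd.DataFrame({
    'coord': [0, 1, 2],                # hypothetical slice location per row
    'fg': [False, True, True],         # slice contains foreground mask
    'bg': [True, False, False],        # opposite of fg
})

def combined_row(fnames, header, index):
    """Concatenate the file-name and header entries for one row."""
    return pd.concat([fnames.iloc[index], header.iloc[index]])

print(fnames['dat'].value_counts())    # the same file appears once per slice
print(combined_row(fnames, header, 1))
```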

## Batch Composition

As above, each row of data represents one *unit* of training. To create a training batch from these individual examples, a number of variables are defined in the `batch` attribute of the `Client()` `YAML` file, which together specify the composition of each data batch used for algorithm training. 

```yml
batch:
  fold: -1
  sampling:
    fg: 0.5
    bg: 0.5 
```

The composition (source) of each data batch used for network training derives from two primary factors:

* cross validation fold for current experiment
* stratified sampling (if any)

### Validation fold

For a given experiment, the specified `fold` determines the data split used for validation. For example, if `fold` is set to 0, then all rows with a value of 0 in the `valid` column will be used for validation. Set the fold to -1 to use **all data** for both training and validation (i.e. no held-out validation set). By default, the dataset is split into five validation folds. To perform `N`-fold cross validation, repeat each experiment `N` times, each time setting the fold to a different integer between `0` and `N-1`.
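The fold logic can be sketched with pandas. This is not the Jarvis implementation; it assumes only what is described above: an integer `valid` column holding each row's fold, and `fold=-1` meaning every row is used for both training and validation. The `split_by_fold` name is hypothetical.

```python
import pandas as pd

def split_by_fold(header, fold):
    """Return (train_rows, valid_rows) for a given validation fold."""
    if fold == -1:
        # --- No held-out data: all rows serve as both training and validation
        return header, header
    is_valid = header['valid'] == fold
    return header[~is_valid], header[is_valid]

# --- Ten rows spread across the default five folds
header = pd.DataFrame({'valid': [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]})

# --- 5-fold cross validation: one experiment per fold
for fold in range(5):
    train, valid = split_by_fold(header, fold)
    print(fold, len(train), len(valid))   # 8 training rows, 2 validation rows
```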

### Stratified sampling

For certain experiments, it is critical that specific rows be sampled during training at a higher or lower frequency than others. This technique is known as stratified sampling. A common example is to oversample images containing a positive disease finding (since pathology tends to be rare) so that the distribution of positive and negative training examples is balanced.

The syntax to define stratified sampling is a series of key-value pairs such that:

* key: name of header column, containing a boolean vector, that defines rows which are part of a specific cohort
* value: rate between `[0, 1]` for which to sample from this specific cohort

All values in total must add to 1.0.

In the example above, the `fg` column contains a value of True or False for each row depending on whether or not a positive mask value is present in the corresponding label. The `bg` column contains the opposite information. 
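The sampling rule can be sketched with NumPy. The function below is a hypothetical illustration rather than the Jarvis sampler: for each slot in a batch it first picks a cohort according to the configured rates, then draws a random row from that cohort. With `fg: 0.5` and `bg: 0.5`, about half of the rows drawn over many batches are foreground, even when foreground rows make up only 10% of the data.

```python
import numpy as np

def sample_batch_indices(header, sampling, batch_size, rng=None):
    """Draw row indices for one batch using stratified sampling.

    header: mapping of column name -> boolean array (one entry per row)
    sampling: mapping of column name -> sampling rate (rates sum to 1.0)
    """
    rng = np.random.default_rng() if rng is None else rng
    keys = list(sampling)
    rates = np.array([sampling[k] for k in keys], dtype=float)
    assert np.isclose(rates.sum(), 1.0), 'sampling rates must sum to 1.0'

    # --- Row indices belonging to each cohort
    cohorts = [np.flatnonzero(np.asarray(header[k])) for k in keys]

    # --- Choose a cohort for every batch slot, then a random row within it
    choices = rng.choice(len(keys), size=batch_size, p=rates)
    return np.array([rng.choice(cohorts[c]) for c in choices])

fg = np.array([False] * 90 + [True] * 10)   # foreground is rare (10% of rows)
header = {'fg': fg, 'bg': ~fg}
idx = sample_batch_indices(header, {'fg': 0.5, 'bg': 0.5}, batch_size=8,
                           rng=np.random.default_rng(0))
print(idx, fg[idx].mean())   # fraction of fg rows in this single batch
```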

### Batch size

Use this field to specify the number of training examples (batch size) used for each iteration.

## Data Specifications

```yml
specs:
  xs:
    dat:
      dtype: float32
      loads: dat
      norms:
        clip: 
          min: 0
          max: 256
        shift: 64
        scale: 64
      shape:
      - 1
      - 512
      - 512
      - 1
  ...
```

The `specs` entry in the configuration file defines the type, shape and origin of data to be loaded, as well as any necessary normalization operations to perform on the loaded data. The two primary entries in `specs` are `xs` and `ys` which represent training input data and output label(s), respectively. For each entry `xs` and `ys`, one or many individual volumes may be defined depending on architecture requirements.

For each individual volume, the following parameters must be defined:

* `dtype`: str representing the data type (often `float32` for input data, `uint8` for output label)
* `shape`: 4D shape of input (Z x Y x X x channel); note channel often == 1; for 2D data, Z == 1
* `loads`: column (if any) to load data from (see below for more details)
* `norms`: normalization parameter (see below for more details); use `null` keyword to indicate no normalization
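Together, `dtype` and `shape` fully describe the array that each entry produces. As a rough sketch (the `allocate` helper is hypothetical, not a Jarvis function), a parsed `specs` dictionary can be turned into zero-filled placeholder arrays of the declared shape and `dtype`, which is useful for checking a configuration before any data is loaded. Note that YAML `null` becomes Python `None`.

```python
import numpy as np

# --- A parsed subset of the client.yml `specs` entry (norms omitted)
specs = {
    'xs': {
        'dat': {'dtype': 'float32', 'loads': 'dat', 'shape': [1, 512, 512, 1]},
        'msk': {'dtype': 'float32', 'loads': None, 'shape': [1, 512, 512, 1]},
    },
    'ys': {
        'bet': {'dtype': 'uint8', 'loads': 'bet', 'shape': [1, 512, 512, 1]},
    },
}

def allocate(specs_group):
    """Create one zero-filled array per entry with the declared shape and dtype."""
    return {
        name: np.zeros(spec['shape'], dtype=spec['dtype'])
        for name, spec in specs_group.items()
    }

xs, ys = allocate(specs['xs']), allocate(specs['ys'])
for name, arr in {**xs, **ys}.items():
    print(name, arr.shape, arr.dtype)
```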

### Data loading

The `loads` entry determines the column name from which to populate data for this input or output variable. If the column is part of the `fnames` DataFrame then the corresponding file will be loaded (at the slice location specified by `coord` if the data is 3D or 4D). If the column is part of the `header` DataFrame then the raw value will be converted to a Numpy array.

If no corresponding data should be loaded, use the keyword `null`. In this scenario, the corresponding array should be dynamically defined in an overloaded `client.preprocess(...)` method.

### Normalization parameters

A total of up to three parameters can be used to define a data normalization strategy:

* clip (includes min and/or max)
* shift
* scale

These parameters are implemented in the following normalization formula:

```
arr = (arr.clip(min=..., max=...) - shift) / scale
```
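The formula translates directly into NumPy. The sketch below illustrates the formula only (the `normalize` name and its defaults are assumptions, not the Jarvis code). With the settings from the `specs` example above (`clip` to `[0, 256]`, `shift: 64`, `scale: 64`), a soft-tissue voxel of 40 HU maps to `(40 - 64) / 64 = -0.375`.

```python
import numpy as np

def normalize(arr, clip_min=None, clip_max=None, shift=0.0, scale=1.0):
    """Apply arr = (arr.clip(min=clip_min, max=clip_max) - shift) / scale."""
    arr = np.asarray(arr, dtype='float32')
    if clip_min is not None or clip_max is not None:
        # --- Only clip when at least one bound is given
        arr = np.clip(arr, clip_min, clip_max)
    return (arr - shift) / scale

# --- Example Hounsfield Unit values: air, water, soft tissue, bone
hu = np.array([-1000, 0, 40, 400])
print(normalize(hu, clip_min=0, clip_max=256, shift=64, scale=64))
# values: -1.0, -1.0, -0.375, 3.0
```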

There are three ways to define these normalization parameters:

1. **Constant value (integer or float)**. Use a literal value if you know that the required parameter value is constant for all data. This is most commonly used for CT imaging data (voxel values are precalibrated in Hounsfield Units).

```yml
norms:
  clip: 
    min: 0
    max: 256
  shift: 64
  scale: 64
```

2. **Numpy function (@ keyword)**. Use a keyword prefixed by `@` to reference a valid Numpy function (e.g. `@mean` for `np.mean(...)`), applied dynamically to each input image upon load. Common usage here includes the `@mean` and `@std` functions to implement a simple z-score transformation. *Important*: these methods should be used only if the input image is guaranteed to provide valid return(s); as a counterexample, the standard deviation of a uniform 2D image (as seen at the top and bottom of 3D volumes) is zero, so scaling by it divides by zero. Thus this option should rarely be used for 3D or 4D volumes (see option 3 below instead).

```yml
norms: 
  shift: '@mean'
  scale: '@std'
```

3. **Column name**. Use a regular str (no @ prefix) to indicate a column name containing any custom normalization parameters. For 3D volumes, a common strategy is to normalize each slice by the mean and standard deviation of the entire volume. However, since loading the entire 3D volume for each slice during training is inefficient, volume statistics can instead be stored in each row of the `*.csv` file and simply referenced for preprocessing. A z-score normalization implemented using this technique is the recommended approach for MR imaging data.

```yml
norms:
  shift: mu
  scale: sd
```
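The three options differ only in where each parameter value comes from. The resolver below is a hypothetical sketch of that lookup (not the Jarvis implementation): numbers are used as-is, `@`-prefixed strings call the NumPy function of the same name on the loaded array, and any other string is read from the row's metadata. The final line shows one possible guard against the zero standard deviation case described in option 2.

```python
import numpy as np

def resolve_param(value, arr, row):
    """Resolve one normalization parameter: constant, '@func' or column name."""
    if isinstance(value, (int, float)):
        return float(value)                          # option 1: constant
    if value.startswith('@'):
        return float(getattr(np, value[1:])(arr))    # option 2: '@mean' -> np.mean(arr)
    return float(row[value])                         # option 3: column in the *.csv row

arr = np.array([[10.0, 20.0], [30.0, 40.0]])
row = {'mu': 25.0, 'sd': 10.0}                     # hypothetical per-volume statistics

shift = resolve_param('@mean', arr, row)           # from the image itself
scale = resolve_param('sd', arr, row)              # from the row metadata
scale = scale if scale > 0 else 1.0                # avoid dividing by zero
print((arr - shift) / scale)
```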

### Mapping

As an alternative to the above, occasionally the data needs to be altered via a mapping of current to target values. This is most commonly performed to transform model outputs. For example, if the serialized label contains a total of 5 values (0, 1, 2, 3, 4) and you wish to binarize the output into class 0 (0, 2, 4) and class 1 (1, 3), the following mapping can be specified:

```yml
norms:
  mapping:
    0: 0
    1: 1
    2: 0
    3: 1
    4: 0
```
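A mapping like this can be applied efficiently with a NumPy lookup table: build an array whose index is the old value and whose entry is the new value, then index it with the label array. This is a sketch of the technique only, assuming labels are small non-negative integers (values missing from the mapping would map to 0 here); Jarvis may implement `mapping` differently.

```python
import numpy as np

mapping = {0: 0, 1: 1, 2: 0, 3: 1, 4: 0}

# --- Build a lookup table so that lut[old_value] == new_value
lut = np.zeros(max(mapping) + 1, dtype='uint8')
for old, new in mapping.items():
    lut[old] = new

lbl = np.array([[0, 1, 2], [3, 4, 1]], dtype='uint8')
print(lut[lbl])
# [[0 1 0]
#  [1 0 1]]
```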

# Overloading `Client()`

While the majority of common data loading permutations can be accomplished by simple modifications to the client `YAML` file, occasionally additional Python code will be needed to enhance functionality. This is especially true if you wish to perform custom data augmentation, or you need to manually create a custom mask array for modifying / weighting the loss function.

To accomplish these tasks, the easiest strategy is to overload the `preprocess(...)` method of the `Client()` class. The method allows a user to insert any arbitrary number of modifications to the loaded dataset prior to release (yielding) by the Python generator.

To do so, either create a new `Client` class via object inheritance, or use the Jarvis `@overload` decorator:

```python
from jarvis.train.client import Client
from jarvis.utils.general import overload

@overload(Client)
def preprocess(self, arrays, **kwargs):
    ...
    return arrays
```

Note that the signature for the `preprocess(...)` method includes a single Python dictionary, `arrays`, which consists of a nested dictionary containing both `xs` and `ys` variables, as well as additional metadata in the `kwargs` argument as needed. Within this method, **any** number of modifications to the `arrays` variable may be performed. Most commonly this will include code to modify the input(s) and/or output(s) pixel values, or to generate a custom training mask for modifying the loss function. Some additional special use cases are described below.
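As an example of custom data augmentation, the sketch below randomly flips the input and label left-right together so that the two stay aligned. It is written as a plain function so it can be tried on its own; in a notebook you would decorate it with `@overload(Client)` as shown above. The keys `dat` and `bet` follow the `client.yml` example, flat access via `arrays['dat']` follows the examples in the sections below, and the 50% flip probability is an arbitrary choice; all of these are assumptions.

```python
import numpy as np

def preprocess(self, arrays, **kwargs):
    """Randomly flip image and label along the X axis (same flip for both)."""
    if np.random.rand() < 0.5:
        # --- axis=-2 is X for a (Z, Y, X, channel) array
        arrays['dat'] = np.flip(arrays['dat'], axis=-2).copy()
        arrays['bet'] = np.flip(arrays['bet'], axis=-2).copy()
    return arrays

# --- Standalone check with random data (self is unused here)
dat = np.random.rand(1, 8, 8, 1).astype('float32')
bet = (dat > 0.5).astype('uint8')
out = preprocess(None, {'dat': dat.copy(), 'bet': bet.copy()})
print(np.array_equal(out['bet'], (out['dat'] > 0.5).astype('uint8')))  # True
```

Because the label here is derived from the image, the check stays `True` whether or not the flip was applied, which confirms both arrays received the same transform.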

## Creating a new array

To create a new array, such as a `msk` variable, that is not already defined in the current `Client()` configuration, consider the following workflow:

1. Update the client `YAML` file to generate the new array under `specs`. As above, the `loads: null` attribute can be specified such that an empty array is generated (does not *load* any existing column of data).

```yml
specs:
  xs:
    dat: ... (existing from above) ...
    msk:
      dtype: float32
      loads: null
      norms: null
      shape:
      - 1
      - 512
      - 512
      - 1
  ...
```

2. Overload the `Client`, ensuring to populate `arrays['msk']` with the appropriate data

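```python
from jarvis.train.client import Client
from jarvis.utils.general import overload

@overload(Client)
def preprocess(self, arrays, **kwargs):

    # --- Create arrays['msk'] with the float32 dtype declared in specs
    arrays['msk'] = (arrays['lbl'] > 0).astype('float32')

    return arrays
```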

Note that this approach is preferred, as the correct Tensorflow / Keras `Input(...)` will be automatically created when invoking the `client.get_inputs(...)` method. 

If the `yml` file is **not** updated, then `client.get_inputs(...)` will only yield the inputs that are currently defined. Thus, the new `msk` input will need to be created manually:

```python
# --- Create existing inputs
from tensorflow.keras import Input
inputs = client.get_inputs(Input)

# --- Add additional dictionary entry for new msk
inputs['msk'] = Input(shape=..., name='msk')
```

## Modifying array shape

Occasionally the input shape loaded by the `Client()` is not the final shape you wish to use for model training. For example, you may want to load a multi-slice input, perform some sort of data augmentation (including oblique projection out of the current z-axis plane), and then extract just the middle slice for training. To accomplish this, consider the following workflow:

1. Update the client `YAML` file to indicate that the final `input` shape is different from the shape of the `saved` data:

```yml
specs:
  xs:
    dat:
      dtype: float32
      loads: dat
      norms: ... (from above) ...
      shape:
        input: [1, 512, 512, 1]
        saved: [3, 512, 512, 1]
```

2. Overload the `Client`, ensuring to populate `arrays['dat']` with the appropriate data shape

```python
@overload(Client)
def preprocess(self, arrays, **kwargs):
    
    # --- Perform some sort of modification to arrays['dat']
    pass

    # --- Finalize arrays['dat'] from 3-slices to 1-slice
    arrays['dat'] = arrays['dat'][1:-1]
    
    return arrays
```

Note that it is critical to overload the `Client` when specifying different `input` and `saved` shapes, otherwise the data yielded by the Python generator will not match the `Input` shape expected by the Tensorflow / Keras model.
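The `[1:-1]` slice keeps exactly one slice only because three slices were saved. A slightly more general sketch (the `center_slices` helper is hypothetical) crops the central `n` slices along Z while keeping the leading Z dimension, so the result still matches a 4D `input` shape; centering is exact when `Z - n` is even.

```python
import numpy as np

def center_slices(arr, n=1):
    """Keep the central n slices along Z (axis 0) of a (Z, Y, X, C) array."""
    z = arr.shape[0]
    start = (z - n) // 2
    return arr[start:start + n]

saved = np.arange(3 * 4 * 4).reshape(3, 4, 4, 1).astype('float32')
middle = center_slices(saved)               # same result as saved[1:-1]
print(middle.shape)                         # (1, 4, 4, 1)
print(np.array_equal(middle, saved[1:-1]))  # True
```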