 <a name="top"></a>
# Understanding how medGAN works

Author: [Sylvain Combettes](https://github.com/sylvaincom).

Edward Choi's original repository: [medgan](https://github.com/mp2893/medgan). <br/>
My own medGAN repository: [medgan](https://github.com/sylvaincom/medgan).

The final goal of my project is to use medGAN on my own dataset. For that, I first need to understand how medGAN works. In this notebook, I provide a few explanations that can help to better understand medGAN. Because the MIMIC-III dataset is subject to confidentiality restrictions, I cleared the cell outputs.

Before reading this notebook, be sure to have read [A few additional tips on how to run Edward Choi's medGAN](https://github.com/sylvaincom/medgan/blob/master/tips-for-medgan.md).

---
### Table of Contents

- [Loading the MIMIC-III dataset](#load-mimic)
- [How to interpret `gen-samples.npy`?](#gen-samples)
- [Comparing the (fake) generated samples to the real-life original ones](#comparison)

---
### Imports

In [None]:
import numpy as np
import os
import pickle
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import matplotlib.transforms as mtransforms

---
# Loading the MIMIC-III dataset <a name="load-mimic"></a>

## `ADMISSIONS.csv` file

In [None]:
adm_data = pd.read_csv("ADMISSIONS.csv")
print(adm_data.shape)
adm_data.head()

In [None]:
n,p = adm_data.shape
for f in adm_data:
    print('Missing values in {}: {}%'.format(f, sum(adm_data[f].isna())*100/n))

We have a lot of missing values.

## `DIAGNOSES_ICD.csv` file

In [None]:
diag_data = pd.read_csv('DIAGNOSES_ICD.csv')
print(diag_data.shape)
diag_data.head()

In [None]:
n,p = diag_data.shape
for f in diag_data:
    print('Missing values in {}: {}%'.format(f, sum(diag_data[f].isna())*100/n))

We have very few missing values.
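
The per-column loops above can also be written as a single vectorized expression. The snippet below is a minimal sketch on a toy data frame (the column values are made up for illustration and are not taken from MIMIC-III); `isna().mean()` gives the fraction of missing entries in each column.

```python
import numpy as np
import pandas as pd

# Toy stand-in for a MIMIC-III table (made-up values, for illustration only).
toy = pd.DataFrame({
    "SUBJECT_ID": [1, 2, 3, 4],
    "DEATHTIME": [np.nan, np.nan, "2101-10-31", np.nan],
    "LANGUAGE": ["ENGL", np.nan, "ENGL", "SPAN"],
})

# Fraction of missing values per column, expressed as a percentage.
missing_pct = toy.isna().mean() * 100
print(missing_pct.sort_values(ascending=False))
```

Sorting the result puts the most incomplete columns first, which is convenient for wide tables.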

Does one `ICD9_CODE` appear distinctly more often than the others? No; our dataset is balanced:

In [None]:
diag_data['ICD9_CODE'].value_counts(normalize=True).head()

---
# How to interpret `gen-samples.npy`? <a name="gen-samples"></a>

We load `gen-samples.csv`, which is `medgan.py`'s output (`gen-samples.npy`) converted to a CSV file.

In [None]:
gen_data = pd.read_csv("gen-samples.csv", sep=';')
print(gen_data.shape)
gen_data.head(10)

Some questions about this data frame:
* What do the columns correspond to? They look like neither those of `ADMISSIONS.csv` nor those of `DIAGNOSES_ICD.csv`.
* What do the rows correspond to?
* Why are the values not binary?
* Why is one row out of two composed only of missing values (`NaN`)?

We can find some answers in an issue opened on Edward Choi's GitHub repository: [How to interpret the samples?](https://github.com/mp2893/medgan/issues/3). As in the `.matrix` file, each row corresponds to a single synthetic patient and each column corresponds to a specific ICD9 diagnosis code. We can use the `.types` file created by `process_mimic.py` to map each column to a specific ICD9 diagnosis code. The header comment of `process_mimic.py` gives more information about the `.types` file:
```python
# Output files
# <output file>.pids: cPickled Python list of unique Patient IDs. Used for intermediate processing
# <output file>.matrix: Numpy float32 matrix. Each row corresponds to a patient. Each column corresponds to a ICD9 diagnosis code.
# <output file>.types: cPickled Python dictionary that maps string diagnosis codes to integer diagnosis codes.
```

What is ICD-9? See [ICD-9](https://en.wikipedia.org/wiki/International_Statistical_Classification_of_Diseases_and_Related_Health_Problems#ICD-9) and [List of ICD-9 codes](https://en.wikipedia.org/wiki/List_of_ICD-9_codes).
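
To make the column-to-code mapping concrete, here is a minimal sketch. It assumes the `.types` dictionary maps string codes to integer column indices (as the header comment above states) and that the sample columns are labelled by integer index. The dictionary entries and sample values below are made up for illustration.

```python
import pandas as pd

# Hypothetical excerpt of a `.types` dictionary: string code -> column index.
# (The real dictionary is produced by process_mimic.py; these entries are made up.)
types = {"D_401": 0, "D_427": 1, "D_250": 2}

# Invert the mapping so that a column index gives back its diagnosis code.
index_to_code = {index: code for code, index in types.items()}

# Toy generated samples: 2 synthetic patients, 3 diagnosis columns.
samples = pd.DataFrame([[1, 0, 1], [0, 0, 1]])

# Relabel the integer columns with the corresponding diagnosis codes.
labelled = samples.rename(columns=index_to_code)
print(labelled)
```

If the data frame was read from a CSV file with a header, the column labels may be strings rather than integers and would need converting first.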

We need to round the values ourselves:

In [None]:
gen_data = gen_data.round(0)
gen_data.head()

The rows made up only of missing values carry no information, so we delete them:

In [None]:
gen_data = gen_data.dropna()
print(gen_data.shape)
gen_data.head()

In [None]:
gen_data.describe()
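
The CSV conversion step can be skipped by loading the `.npy` file directly. The sketch below assumes that `gen-samples.npy` is a plain 2-D NumPy array of values between 0 and 1, which is an assumption about the output format and not something verified here. It uses a toy array and a hypothetical file name so that it runs on its own.

```python
import os
import tempfile

import numpy as np
import pandas as pd

# Toy stand-in for medGAN's output: 3 synthetic patients, 4 codes, values in [0, 1].
toy_output = np.array([[0.91, 0.02, 0.40, 0.75],
                       [0.10, 0.60, 0.05, 0.01],
                       [0.49, 0.51, 0.99, 0.00]], dtype=np.float32)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy-gen-samples.npy")  # hypothetical file name
    np.save(path, toy_output)
    loaded = np.load(path)  # load the array directly, no CSV conversion

# Threshold at 0.5 to obtain binary values.
binary = pd.DataFrame((loaded >= 0.5).astype(int))
print(binary)
```

An explicit threshold also makes the treatment of ties visible: `round` in NumPy and pandas rounds halves to the nearest even number, so a value of exactly 0.5 becomes 0, whereas `>= 0.5` maps it to 1.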

## `.types` file

_cPickled Python dictionary that maps string diagnosis codes to integer diagnosis codes._

In [None]:
# Use a context manager so the file handle is closed after loading
with open('training-data.types', 'rb') as f:
    map_dict = pickle.load(f)
# print('An excerpt of the mapping dictionary is:', dict(list(map_dict.items())[:10]))
print(type(map_dict))
map_dict

In [None]:
map_pd = pd.DataFrame(list(map_dict.items()))
print(map_pd.shape)
map_pd.head(10)

Thus, as its name suggests, `process_mimic.py` is tightly coupled to the MIMIC-III dataset. We will probably not use `process_mimic.py` on our own dataset and will only run `medgan.py`. From `process_mimic.py`, we only need to understand how the generated `.matrix` file is constructed (lines 109 to 119).
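
For our own dataset, the essential step is to produce a matrix with one row per patient and one column per code. The following is a minimal sketch of that idea and not a copy of `process_mimic.py`: the input structure (`patients`), the helper `build_matrix` and the toy values are all assumptions made for illustration.

```python
import numpy as np

# Hypothetical input: for each patient, a list of visits; each visit is a list
# of integer diagnosis codes (column indices). Values are made up.
patients = [
    [[0, 2], [2]],        # patient 0: two visits
    [[1]],                # patient 1: one visit
    [[0], [1, 3], [3]],   # patient 2: three visits
]
num_codes = 4


def build_matrix(patients, num_codes, binary=True):
    """Aggregate visits into one row per patient (binary flags or counts)."""
    matrix = np.zeros((len(patients), num_codes), dtype=np.float32)
    for row, visits in enumerate(patients):
        for visit in visits:
            for code in visit:
                if binary:
                    matrix[row, code] = 1.0   # code seen at least once
                else:
                    matrix[row, code] += 1.0  # number of occurrences
    return matrix


binary_matrix = build_matrix(patients, num_codes, binary=True)
count_matrix = build_matrix(patients, num_codes, binary=False)
print(binary_matrix)
print(count_matrix)
```

The resulting array has the same layout as the one described for the `.matrix` file: float32, patients in rows, codes in columns.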

## `.matrix` file

_Numpy float32 matrix. Each row corresponds to a patient. Each column corresponds to a ICD9 diagnosis code._

In [None]:
# Use a context manager so the file handle is closed after loading
with open('training-data.matrix', 'rb') as f:
    input_data_array = pickle.load(f)
print(type(input_data_array))
input_data_array

In [None]:
input_data_pd = pd.DataFrame(input_data_array)
print(input_data_pd.shape)
input_data_pd.head(10)

As we chose, the input data is binary. 

Note that the input of `medgan` and its [output](#gen-samples) have the same number of columns, and that both contain binary values (the output only after rounding). Thus, `gen-samples.npy` is a synthetic but realistic dataset that mirrors the `.matrix` file.

## `.pids` file

_cPickled Python list of unique Patient IDs. Used for intermediate processing_

In [None]:
# Use a context manager so the file handle is closed after loading
with open('training-data.pids', 'rb') as f:
    id_list = pickle.load(f)
print(type(id_list))
id_list

In [None]:
id_pd = pd.DataFrame(id_list)
print(id_pd.shape)
id_pd.head(10)

---
# Comparing the (fake) generated samples to the real-life original ones  <a name="comparison"></a>

In this section, we assess how closely the (fake) generated dataset matches the original one. As in Choi's paper, we use the dimension-wise probability.

## Probability distribution of input data

In [None]:
n, p = input_data_pd.shape
print(n, p)

In [None]:
input_prop = input_data_pd.sum()
input_prop_list = input_prop.tolist()

plt.plot(input_prop_list)
plt.xlabel('Index of variable')
plt.ylabel('Frequency of 1')
plt.title('input_data_pd')
plt.show()

In [None]:
# Approximate probability of 1 for the j-th variable (j between 0 and 1070):
j = 0
print(round(sum(input_data_pd[j])*100/n,2), '%')

In [None]:
proba_input = input_data_pd.mean().tolist()  # proportion of 1s in each column

For each feature (dimension), we take the proportion of `1`s as an estimate of the Bernoulli success probability _p_.

In [None]:
print(proba_input[0:10])

plt.plot(proba_input)
plt.xlabel('Index of variable')
plt.ylabel('Bernoulli success probability')
plt.title('input_data_pd')
plt.show()

## Probability distribution of output data

In [None]:
n, p = gen_data.shape
print(n, p)

In [None]:
gen_data_prop_list = gen_data.sum().tolist()

plt.plot(gen_data_prop_list)
plt.xlabel('Index of variable')
plt.ylabel('Frequency of 1')
plt.title('gen_data')
plt.show()

In [None]:
proba_output = gen_data.mean().tolist()  # proportion of 1s in each column

In [None]:
print(proba_output[0:10])

plt.plot(proba_output)
plt.xlabel('Index of variable')
plt.ylabel('Bernoulli success probability')
plt.title('gen_data')
plt.show()

## Comparison: dimension-wise probability

In [None]:
fig, ax = plt.subplots()
ax.scatter(proba_input, proba_output, c='black', label='Bernoulli success probability')
# Identity line y = x drawn in data coordinates, so it marks equal probabilities
# whatever the axis limits are (ax.transAxes would only join the axes corners).
ax.plot([0, 1], [0, 1], color='red', label='ideal: y = x')

plt.title('Dimension-wise probability performance of medGAN')
plt.xlabel('Bernoulli success probability (real data)')
plt.ylabel('Bernoulli success probability (generated data)')
plt.legend()
plt.show()

The diagonal red line indicates the ideal case, where every feature has the same Bernoulli success probability in the real data and in the (fake) generated data. Based on the previous graph, we can say that medGAN performs very well on this metric.
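
The visual impression from the scatter plot can be complemented with a number. The sketch below computes two simple summaries of agreement between the per-feature probabilities. These metrics are our own choice and are not taken from Choi's paper, and the two arrays are made-up stand-ins for `proba_input` and `proba_output`.

```python
import numpy as np

# Made-up per-feature probabilities standing in for proba_input / proba_output.
real_probs = np.array([0.30, 0.05, 0.12, 0.50])
fake_probs = np.array([0.28, 0.06, 0.10, 0.55])

# Mean absolute gap between real and generated probabilities (0 is ideal).
mean_abs_diff = np.mean(np.abs(real_probs - fake_probs))

# Pearson correlation across features (1 is ideal).
correlation = np.corrcoef(real_probs, fake_probs)[0, 1]

print(f"mean absolute difference: {mean_abs_diff:.4f}")
print(f"Pearson correlation: {correlation:.4f}")
```

With the real lists, `np.array(proba_input)` and `np.array(proba_output)` would replace the toy arrays, provided both have the same length and column order.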

Back to [top](#top).