In [None]:
import json
import warnings
from itertools import chain

import joblib
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

In [None]:
def parse(data):
    """Parse patient data from raw JSON data.

    Parameters
    ----------
    data : dict
        Maps each patient id to a dict with the keys ``BirthDay``,
        ``DeathDay``, ``sex`` and ``ICD``. ``ICD`` maps a disease code to a
        dict whose values are lists of record dates (one list per type of
        record).

    Returns
    -------
    patients : pandas.DataFrame
        One row per patient, indexed by ``pid``, with the columns ``birth``,
        ``death`` and ``sex`` holding the raw (unconverted) values.
    records : pandas.DataFrame
        One row per (patient, date, disease) entry with the columns ``pid``,
        ``date`` and ``disease``. Dates from all record types are
        concatenated without de-duplication.

    Both frames keep their columns even when ``data`` is empty.
    """
    patients = []
    records = []

    for pid, data_patient in data.items():
        patients.append({
            'pid': pid,
            'birth': data_patient['BirthDay'],
            'death': data_patient['DeathDay'],
            'sex': data_patient['sex'],
        })

        for disease, dates_by_type in data_patient['ICD'].items():
            # merge dates of different types of records
            for date in chain.from_iterable(dates_by_type.values()):
                records.append({
                    'pid': pid,
                    'date': date,
                    'disease': disease,
                })

    # explicit columns keep the schema (and let set_index find 'pid') when
    # the input holds no patients or no records at all
    patients = pd.DataFrame(
        patients, columns=['pid', 'birth', 'death', 'sex'],
    ).set_index('pid')
    records = pd.DataFrame(records, columns=['pid', 'date', 'disease'])
    return patients, records

In [None]:
# parse the data of ALS (335.20) patients
# JSON is UTF-8; an explicit encoding avoids the platform-dependent default
with open('data/patients-icd9cm33520.json', encoding='utf-8') as stream:
    data = json.load(stream)
patients, records = parse(data)
patients['ALS'] = True

# parse the data of non-ALS patients
with open('data/patients-icd9cm335-without.json', encoding='utf-8') as stream:
    data = json.load(stream)
patients_without, records_without = parse(data)
patients_without['ALS'] = False

In [None]:
patients = pd.concat([patients, patients_without])
# birth is stored at month precision (YYYYMM), so it parses to the 1st of
# that month; death is stored as a full date (YYYYMMDD)
patients['birth'] = pd.to_datetime(patients['birth'], format='%Y%m')
patients['death'] = pd.to_datetime(patients['death'], format='%Y%m%d')
# assign the result back: an inplace replace on a column selection is a
# chained assignment, which pandas warns about and which, under
# Copy-on-Write, never writes back to `patients`
patients['sex'] = patients['sex'].replace({'M': 'male', 'F': 'female'})

In [None]:
# both frames carry their own 0..n-1 index; ignore_index avoids leaving
# duplicate labels behind, which index-based joins would silently multiply
records = pd.concat([records, records_without], ignore_index=True)
records['date'] = pd.to_datetime(records['date'], format='%Y%m%d')
# left-pad the ICD codes with zeros so they are at least 3 characters long
records['disease'] = records['disease'].str.pad(3, side='left', fillchar='0')

In [None]:
def age(birth, date=pd.Timestamp(year=2013, month=12, day=31)):
    """Return the age, in completed years, of someone born on ``birth``.

    The age is measured on ``date`` (31 Dec 2013 unless given). One year is
    taken off when the birthday has not yet come round in ``date``'s year.
    """
    birthday_reached = (birth.month, birth.day) <= (date.month, date.day)
    return date.year - birth.year - (not birthday_reached)


def show_histogram(patients):
    """Plot histograms of sex and age, split into ALS and non-ALS patients.

    Draws two stacked panels on a newly created figure, which becomes the
    current pyplot figure (so a following ``plt.savefig`` saves it). Ages are
    computed with ``age`` at its default reference date.

    Parameters
    ----------
    patients : pandas.DataFrame
        Needs the columns ``ALS`` (True/False), ``sex`` (strings) and
        ``birth`` (timestamps). It is not modified; a copy is annotated.
    """
    data = patients.copy()
    data['ALS'] = data['ALS'].map({True: 'ALS', False: 'Non-ALS'})
    data['Sex'] = data['sex'].str.capitalize()
    data['Age'] = data['birth'].apply(age)

    # an explicit figure keeps the panels from being drawn onto whatever
    # figure pyplot happens to consider current
    figure, (ax_sex, ax_age) = plt.subplots(2, 1)

    ax_sex.set_title('Histogram of Gender (Sex)')
    sns.histplot(data, x='Sex', hue='ALS', legend=True, ax=ax_sex)
    ax_sex.set_xlabel(None)

    ax_age.set_title('Histogram of Age')
    sns.histplot(data, x='Age', hue='ALS', legend=True, ax=ax_age)
    ax_age.set_xlabel(None)

    # keep the lower panel's title clear of the upper panel's tick labels
    figure.tight_layout()

In [None]:
# sex and age distributions of ALS vs. non-ALS patients before matching
show_histogram(patients=patients)

In [None]:
# number of control cases for each ALS patient
control_ratio = 4
# the threshold of birth difference to trigger a warning
warning_delta = pd.Timedelta(days=1000)

# For each ALS patient, greedily pick the `control_ratio` non-ALS patients of
# the same sex with the closest birth dates. Picked controls are removed from
# the pool, so no control is shared between ALS patients; the outcome
# therefore depends on the order in which ALS patients are visited.
controls = []
is_als = patients['ALS']
candidates = patients[~is_als].copy()
for pid, patient in patients[is_als].iterrows():
    sub_candidates = candidates[candidates['sex'] == patient['sex']].copy()
    sub_candidates['delta'] = (
        sub_candidates['birth'] - patient['birth']
    ).abs()
    sub_controls = sub_candidates.sort_values('delta').head(control_ratio)

    # the same-sex pool can run dry; say so instead of silently under-matching
    if len(sub_controls) < control_ratio:
        warnings.warn(
            f'only {len(sub_controls)} of {control_ratio} controls '
            f'found ({pid=})'
        )
    for delta in sub_controls['delta']:
        if delta > warning_delta:
            warnings.warn(f'delta is too large ({delta=})')

    controls.append(sub_controls)
    candidates = candidates.drop(sub_controls.index)

controls = pd.concat(controls)

In [None]:
# final cohort: every ALS patient plus the matched controls. The helper
# 'delta' column used during matching is not patient data, so it is dropped
# instead of leaking (as NaN for ALS rows) into the cached frame. The ALS
# mask is recomputed here so re-running this cell gives the same cohort.
patients = pd.concat([
    patients[patients['ALS']],
    controls.drop(columns='delta', errors='ignore'),
])
patients.to_pickle('caches/patients.pkl')

show_histogram(patients)
plt.savefig('figures/patients-histogram.png', dpi=300)

# keep only the records of patients in the cohort
records = records[records['pid'].isin(patients.index)]
records.to_pickle('caches/records.pkl')

In [None]:
ALS_patients = patients.loc[patients['ALS']].index
ALS_records = records.loc[records['pid'].isin(ALS_patients)]

# one-hot encode the diagnoses, then collapse to one row per (pid, date):
# a disease is flagged when recorded at least once that day, so repeated
# records on the same day count only once
dummies = (
    pd.get_dummies(ALS_records['disease'])
    .join(ALS_records[['pid', 'date']])
    .groupby(['pid', 'date'])
    .any()
)
dummies

In [None]:
# order each patient's rows by date before accumulating
dummies = dummies.sort_index()

# running count of every disease per patient (the boolean flags sum as 0/1)
cumulations = dummies.groupby('pid').cumsum()

# sanity check: every (pid, date) row comes from at least one record, so
# relative to the patient's previous row at least one counter must grow
previous = cumulations.groupby('pid').shift(1, fill_value=0)
increasing = (cumulations - previous) > 0
assert increasing.any(axis='columns').all()

cumulations.to_pickle('caches/cumulations.pkl')

cumulations

In [None]:
# standardize every disease counter, then project onto principal components.
# A fixed random_state makes the projection reproducible whenever
# scikit-learn's 'auto' policy picks a randomized/ARPACK SVD solver (it can
# once n_components is restricted); the exact solver ignores it.
transformer = Pipeline([
    ('scaler', StandardScaler()),
    ('PCA', PCA(random_state=0)),
])
transformer

In [None]:
transformer.fit(cumulations)
explained_ratios = pd.Series(transformer['PCA'].explained_variance_ratio_)
explained_ratios.index += 1  # number of components starts from 1

figure, ax = plt.subplots()
ax.set_xlabel('Number of Components')
ax.set_ylabel('Explained Variance Ratio')
p1, = ax.plot(explained_ratios, label='Explained Variance Ratio')
ax.grid(axis='x')

# the cumulative ratio goes on a secondary y axis sharing the same x axis
twin = ax.twinx()
twin.set_ylabel('Cumulative Explained Variance Ratio')
p2, = twin.plot(
    explained_ratios.cumsum(), color='orange', linestyle='--',
    label='Cumulative Explained Variance Ratio'
)
twin.grid(axis='y')

# the legend lives on the twin axes (drawn on top) so no line covers it;
# the x axis is shared, so one set_xlim applies to both axes
twin.legend(handles=[p1, p2], loc='center right')
ax.set_xlim(0, 100)

figure.savefig('figures/PCA-elbow.png', dpi=300)

In [None]:
n_components = 60

# refit the whole pipeline, this time keeping only the leading components
transformer.set_params(PCA__n_components=n_components)
vectors = pd.DataFrame(
    transformer.fit_transform(cumulations),
    index=cumulations.index,
)
display(vectors)

joblib.dump(transformer, 'caches/transformer.joblib')
vectors.to_pickle('caches/vectors.pkl')

# share of the variance retained by the kept components
total = transformer['PCA'].explained_variance_ratio_.sum()
print(f'{total=:.2%} ({n_components=})')