In [None]:
from itertools import pairwise

import joblib
import networkx as nx
import pandas as pd
from icd9cms.icd9 import search
from scipy import stats

In [None]:
def contain(contrainer, target):
    """Tell whether `target` is a member of `contrainer`.

    Written as a plain function so that it can be handed to
    ``Series.apply`` with `target` supplied as a keyword argument.
    """
    is_member = target in contrainer
    return is_member

In [None]:
als_code = '335'
confidence_level = 0.95

# NOTE: unpickling can execute arbitrary code; only load trusted cache files.
patients = pd.read_pickle('caches/patients.pkl')
records = pd.read_pickle('caches/records.pkl')

# One set of disease codes per patient id.
comorbidities = records.groupby('pid')['disease'].apply(set)

# Case status does not depend on the disease under test: compute it once
# instead of once per loop iteration.
cases = comorbidities.apply(contain, target=als_code)

# Force every contingency table to the full 2x2 layout. Without this, a
# disease present in every patient (or an ALS code present in none) yields a
# 1x2 / 2x1 crosstab, which `odds_ratio` rejects with a ValueError.
outcomes = [False, True]

odds_ratio_records = {}
for disease in records['disease'].unique():
    exposed = comorbidities.apply(contain, target=disease)
    table = pd.crosstab(
        exposed, cases, rownames=['exposed'], colnames=['case'],
    ).reindex(index=outcomes, columns=outcomes, fill_value=0)

    odds_ratio = stats.contingency.odds_ratio(table)
    interval = odds_ratio.confidence_interval(confidence_level)
    odds_ratio_records[disease] = {
        'odds_ratio': odds_ratio.statistic,
        'low': interval.low,
        'high': interval.high,
    }

# One row per disease code; columns: odds_ratio, low, high.
odds_ratios = pd.DataFrame.from_dict(odds_ratio_records, orient='index')
odds_ratios

In [None]:
def desc(code):
    """Look up `code` and return the short description of the match."""
    node = search(code)
    return node.short_desc


def parent(code):
    """Look up `code` and return the short description of its parent."""
    node = search(code)
    parent_node = node.parent
    return parent_node.short_desc

In [None]:
min_odds_ratio = 2
min_odds_ratio_low = 1

# A disease counts as associated when both its odds ratio and the lower
# bound of its confidence interval reach their thresholds (NaN never does).
is_associated = (
    (odds_ratios['odds_ratio'] >= min_odds_ratio)
    & (odds_ratios['low'] >= min_odds_ratio_low)
)
associated_diseases = odds_ratios[is_associated].index
display(associated_diseases)

# Filter and sort without `inplace`, so the result is a fresh frame.
odds_ratios = (
    odds_ratios.loc[associated_diseases]
    .sort_values('odds_ratio', ascending=False)
)

# Look up descriptions only for the rows that are kept; the discarded rows
# never reach the exported table.
odds_ratios['class'] = [parent(code) for code in odds_ratios.index]
odds_ratios['name'] = [desc(code) for code in odds_ratios.index]

# Built with a comprehension rather than `apply(axis=1)`: on an empty
# selection `apply` returns a DataFrame, which cannot be assigned to a column.
odds_ratios['confidence_interval'] = [
    f'{low:.2f}-{high:.2f}'
    for low, high in zip(odds_ratios['low'], odds_ratios['high'])
]
odds_ratios.to_csv('tables/odds-ratios.csv')

In [None]:
min_difference = 1

tree = joblib.load('caches/tree.joblib')
transformer = joblib.load('caches/transformer.joblib')

for source, target in tree.edges:
    # Run the stored positions of both endpoints through the transformer's
    # inverse transform; one row per endpoint, one column per input feature.
    positions = [tree.nodes[node]['position'] for node in (source, target)]
    cumulations = pd.DataFrame(
        transformer.inverse_transform(positions),
        index=[source, target],
        columns=transformer.feature_names_in_,
    )
    differences = cumulations.loc[target] - cumulations.loc[source]

    # Features that grew by at least `min_difference` along this edge, and
    # the subset of them that is also in `associated_diseases`.
    increased = set(differences.loc[differences >= min_difference].index)
    associated = increased & set(associated_diseases)

    tree.edges[source, target].update(
        increased=increased,
        associated=associated,
    )

In [None]:
def mark_branch_only_diseases(tree, current, root=0):
    """Recursively mark the branch-only diseases for each edge."""
    for successor in tree.successors(current):
        unique = tree.edges[current, successor]['associated'].copy()

        # remove the diseases that are inherited from the ancestors
        for source, target in pairwise(
                nx.shortest_path(tree, root, current)):
            unique -= tree.edges[source, target]['associated']

        # remove the diseases that are shared with other successors
        for other in tree.successors(current):
            if other == successor:  # skip itself
                continue

            unique -= tree.edges[current, other]['associated']

        tree.edges[current, successor]['unique'] = unique
        mark_branch_only_diseases(tree, successor, root)

In [None]:
# Annotate every edge reachable from node 0 with its 'unique' diseases.
mark_branch_only_diseases(tree, current=0, root=0)

In [None]:
# Write the annotated tree back to the cache file it was loaded from.
joblib.dump(value=tree, filename='caches/tree.joblib')