In [None]:
import textwrap
from itertools import combinations, pairwise
from math import dist, log2

import joblib
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
import seaborn as sns
from multiprocess import Pool

from utils import locate_on_2D_space

In [None]:
# Load the cached tree and its 2D layout.
# NOTE(review): joblib caches are pickles — only load files produced by this
# project, never files from an untrusted source.
tree = joblib.load('caches/tree.joblib')
layout = joblib.load('caches/layout.joblib')

# Terminal nodes are the nodes without outgoing edges. They are materialised
# as a list (not a generator) so that `terminals` is still usable after the
# trajectories have been built and when later cells refer to it.
terminals = [node for node in tree.nodes if tree.out_degree(node) == 0]

# One trajectory per terminal: the path from node 0 to that terminal.
trajectories = [
    nx.shortest_path(tree, 0, terminal) for terminal in terminals
]
trajectories

In [None]:
def nearest_node(vector, tree):
    """Return the index of the tree node whose position is closest to `vector`.

    Every node's ``'position'`` attribute is compared with `vector` using the
    Euclidean distance (`math.dist`). When several nodes are equally close,
    the one that comes first in the tree's node order wins.
    """
    closest_index, _ = min(
        tree.nodes.data('position'),
        key=lambda entry: dist(entry[1], vector),
    )
    return closest_index

In [None]:
# Load the cached vectors.
# NOTE(review): this is a pickle — only read caches produced by this project.
vectors = pd.read_pickle('caches/vectors.pkl')

# Map every row of `vectors` to the index of its nearest tree node.
vector_nodes = vectors.apply(
    lambda row: nearest_node(row, tree),
    axis=1,
)
vector_nodes

In [None]:
def assign_group(last, trajectories):
    """Return the index of the trajectory that contains the node `last`.

    Returns the position (in `trajectories`) of the single trajectory that
    contains `last`, or ``-1`` when more than one trajectory contains it
    (ambiguous group).

    Raises:
        ValueError: if no trajectory contains `last`.
    """
    matches = [
        position
        for position, path in enumerate(trajectories)
        if last in path
    ]

    if not matches:
        raise ValueError('no group found')

    # exactly one hit is a definite group; several hits are ambiguous
    return matches[0] if len(matches) == 1 else -1

In [None]:
# Last assigned node per patient: the index is sorted first so that `.last()`
# picks the final entry of each `pid` group in index order.
last_nodes = (
    vector_nodes
    .sort_index()
    .groupby('pid')
    .last()
)

# Trajectory group per patient (-1 marks an ambiguous last node).
groups = last_nodes.apply(
    lambda node: assign_group(node, trajectories),
)
groups

In [None]:
def to_edge_text(row):
    """Format the ``'source'`` and ``'target'`` entries of `row` as an edge label.

    Both values are padded to a minimum width of 2 and joined with ``' -> '``.
    """
    source, target = row['source'], row['target']
    return f'{source:2} -> {target:2}'


def to_diseases_text(diseases):
    """Join an iterable of strings into a single comma-separated string."""
    separator = ', '
    return separator.join(diseases)


# Calculate the number of patients passing through each edge: a patient
# counts towards every edge on the path from node 0 to their last node.
#
# The counter is reset to 0 on every edge first, so that
#   * re-running this cell does not add a second round of counts on top of
#     the previous run (the tree object is shared across cells), and
#   * edges no patient passes through carry an explicit 0 instead of a
#     missing attribute (which would surface as NaN below and as a KeyError
#     wherever `n_patients` is read with `[...]`).
nx.set_edge_attributes(tree, 0, 'n_patients')
for last in last_nodes:
    for edge in pairwise(nx.shortest_path(tree, 0, last)):
        tree.edges[edge]['n_patients'] += 1

# One row per edge: (source, target, attribute dict).
edge_rows = pd.DataFrame(
    tree.edges.data(),
    columns=['source', 'target', 'data'],
)

# Expand the attribute dicts into columns, labelled "source -> target".
data = pd.DataFrame(
    edge_rows['data'].to_list(),
    index=edge_rows.apply(to_edge_text, axis='columns'),
)

# These columns hold collections of disease names; flatten them to text.
for column in ['increased', 'associated', 'unique']:
    data[column] = data[column].apply(to_diseases_text)

data.to_csv('tables/edge-data.csv')
data

In [None]:
# Project every row of `vectors` onto the 2D layout in parallel.
# NOTE(review): the helper is defined inside the `with` block, i.e. after the
# worker pool has been created — presumably `multiprocess` serialises it when
# sending tasks to the workers; confirm before reordering these statements.
with Pool() as pool:

    def _locate_on_2D_space(x):
        # Single-argument wrapper for `pool.map`: binds the shared `tree`,
        # `layout` and `scaler=5e-4` so only the row `x` varies per task.
        return locate_on_2D_space(x, tree, layout, scaler=5e-4)

    # One task per row of `vectors`; results come back in input order.
    points = pool.map(_locate_on_2D_space, vectors.values)

# Re-attach the original index and name the two result components x and y.
points = pd.DataFrame(points, columns=['x', 'y'], index=vectors.index)
points

In [None]:
def wrap_text(text, shorten_width=5*4, wrap_width=5*2+1):
    """Truncate `text` and break the result into several lines.

    The text is first collapsed and truncated to at most `shorten_width`
    characters with `textwrap.shorten`, then wrapped so that each line is at
    most `wrap_width` characters wide with `textwrap.fill`.
    """
    return textwrap.fill(
        textwrap.shorten(text, width=shorten_width),
        width=wrap_width,
    )

In [None]:
# Whole tree with per-edge disease labels and a sample of patient points.
fig, ax = plt.subplots(figsize=(6.4*2.5, 4.8*2.5))

# Edge thickness grows with log2 of the patient count. Edges without a
# recorded count are treated as 0 patients (width 0) instead of raising a
# KeyError.
width = [
    log2(value.get('n_patients', 0) + 1) * 2
    for value in tree.edges.values()
]
nx.draw_networkx(
    tree, layout, ax=ax, edge_color='grey', width=width, alpha=0.8,
)

# Overlay a reproducible 20 % sample of each patient's points, coloured by
# patient id.
sns.scatterplot(
    points.groupby('pid').sample(frac=0.2, random_state=0),
    x='x', y='y', hue='pid', s=50, alpha=0.5, legend=False, ax=ax,
)

# Label each edge with its (shortened and wrapped) 'unique' entries.
edge_labels = {
    edge: wrap_text(', '.join(values['unique']))
    for edge, values in tree.edges.items()
}
nx.draw_networkx_edge_labels(
    tree, layout, edge_labels=edge_labels, rotate=False,
    font_size=6, clip_on=True, ax=ax,
)

fig.savefig('figures/tree-diseases.png', dpi=300)

In [None]:
def connecting_subgraph(tree, nodes):
    """Return the subgraph of `tree` that connects the given nodes.

    The result contains every given node plus all nodes on the shortest
    paths between each pair of them. Edge direction is ignored when the
    paths are searched, but the returned subgraph is a view of the original
    `tree`.

    Args:
        tree: graph to take the subgraph from.
        nodes: iterable of node identifiers; duplicates are allowed and are
            only considered once.

    Returns:
        ``tree.subgraph(...)`` over the collected nodes. A single distinct
        node yields a one-node subgraph; an empty input yields an empty one.
        Pairs with no connecting path contribute only their endpoints.
    """
    graph = tree.to_undirected()

    # De-duplicate while keeping order: callers may pass one entry per
    # record, and pairing the raw entries would be quadratic in records
    # rather than in distinct nodes.
    unique_nodes = list(dict.fromkeys(nodes))

    # Seed with the nodes themselves so a lone node is not dropped.
    connected_nodes = set(unique_nodes)
    for position, source in enumerate(unique_nodes):
        if source not in graph:
            # nothing to connect; `subgraph` ignores unknown nodes anyway
            continue
        # paths from this source only, instead of all pairs of the tree
        paths = nx.single_source_shortest_path(graph, source)
        for target in unique_nodes[position + 1:]:
            # unreachable targets add no intermediate nodes
            connected_nodes.update(paths.get(target, ()))

    return tree.subgraph(connected_nodes)

In [None]:
# One subtree per trajectory, keyed by the trajectory's index.
subtrees = {
    index: tree.subgraph(trajectory)
    for index, trajectory in enumerate(trajectories)
}

# Patients whose group is -1 (ambiguous) share a single extra subtree,
# stored under key -1, that connects all nodes assigned to those patients.
undefined_pids = groups.loc[groups.eq(-1)].index
undefined_nodes = vector_nodes.loc[
    vector_nodes.index.get_level_values('pid').isin(undefined_pids)
]
subtrees[-1] = connecting_subgraph(tree, undefined_nodes)

In [None]:
# Disease code used to select the records of interest.
# NOTE(review): presumably the code for ALS, going by the variable name —
# confirm against the coding scheme used in `records`.
als = '335'

# NOTE(review): this is a pickle — only read caches produced by this project.
records = pd.read_pickle('caches/records.pkl')

# Records with that disease code, ordered by patient and date.
als_records = (
    records.loc[records['disease'] == als]
    .sort_values(['pid', 'date'])
)
als_index = als_records.set_index(['pid', 'date']).index

# Locate the matching vectors on the 2D layout.
# NOTE(review): unlike the cell that builds `points`, no `scaler` is passed
# here, so the helper's default applies — confirm this is intended.
als_points = vectors.loc[vectors.index.isin(als_index)].apply(
    locate_on_2D_space, axis=1, tree=tree, layout=layout,
)
als_points = pd.DataFrame(
    list(als_points.values),
    columns=['x', 'y'],
    index=als_points.index,
)

In [None]:
# Only groups with at least this many patients get a panel.
min_support = 3

counts = groups.value_counts()
counts = counts[counts >= min_support]

# Panel grid: ceiling division so any column count works, and at least one
# row so the figure height is never zero when no group qualifies.
n_columns = 2
n_rows = max(1, -(-len(counts) // n_columns))
plt.figure(figsize=(6.4*n_columns, 4.8*n_rows))
for i, (group, count) in enumerate(counts.items()):
    subtree = subtrees[group]

    ax = plt.subplot(n_rows, n_columns, i+1)

    if group == -1:
        label = f'Uncategorized Group ({count} patients)'
    else:
        label = f'Trajectory {group} ({count} patients)'
    ax.set_title(label)

    # Edge thickness grows with log2 of the patient count; edges without a
    # recorded count are treated as 0 patients instead of raising KeyError.
    width = [
        log2(value.get('n_patients', 0) + 1)
        for value in subtree.edges.values()
    ]
    nx.draw_networkx(
        subtree, layout, ax=ax, width=width, edge_color='grey', alpha=0.8,
    )

    # Label each edge with its (shortened and wrapped) 'associated' entries.
    edge_labels = {
        edge: wrap_text(', '.join(values['associated']))
        for edge, values in subtree.edges.items()
    }
    nx.draw_networkx_edge_labels(
        subtree, layout, edge_labels=edge_labels, rotate=False,
        font_size=10, clip_on=True, ax=ax,
    )

    # ALS points of the patients in this group.
    sub_als_points = als_points[
        als_points.index.get_level_values('pid').isin(
            groups[groups == group].index
        )]
    # Sort the index before taking the first row per patient, as is done for
    # the last node per patient, so "first" does not depend on the storage
    # order of `vectors`.
    first_als_points = sub_als_points.sort_index().groupby('pid').first()
    ax.scatter(
        first_als_points['x'], first_als_points['y'],
        s=100, c='red', alpha=0.5, label='first ALS'
    )

    ax.legend(loc='lower left')

plt.savefig('figures/trajectories.png', dpi=300)