In [None]:
import logging
import random
import time
from collections import defaultdict
from functools import wraps
from itertools import chain, combinations
from math import dist

import elpigraph
import joblib
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
from tqdm import tqdm

In [None]:
# Shared random seed: it seeds `random` before every tree fit and is the
# `random_state` of the subsample drawn for the embedding plot.
seed = 1683180350

In [None]:
# Route INFO-level (and above) records of the root logger to a log file;
# `log_execution` writes one such record per decorated call.
# Note: `basicConfig` does nothing if the root logger already has handlers,
# e.g. when this cell is re-run in the same kernel.
logging.basicConfig(
    filename='compute_tree.log',
    level=logging.INFO,
)

In [None]:
def log_execution(callable, available_types=(str, int, float)):
    """Decorate a callable so that every call is logged with its timing.

    After each call one INFO record of the form
    ``name(arg,...,key=value,...)=result(elapsed s)`` is emitted through the
    root logger.

    Parameters
    ----------
    callable : callable
        The function to wrap.
    available_types : type or tuple of types
        Only arguments that are instances of these types are written to the
        log; all other arguments (e.g. arrays) are left out of the message.

    Returns
    -------
    callable
        A wrapper that forwards all arguments to ``callable`` and returns
        its result unchanged. ``__name__`` and ``__doc__`` are copied from
        ``callable``.
    """

    def is_available_type(value):
        return isinstance(value, available_types)

    @wraps(callable)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic, so the elapsed time cannot be distorted
        # by system clock adjustments
        start = time.perf_counter()
        result = callable(*args, **kwargs)
        elapsed = time.perf_counter() - start

        # Positional and keyword arguments are filtered separately so that a
        # positional tuple is never mistaken for a (key, value) pair.
        shown_args = [str(arg) for arg in args if is_available_type(arg)]
        shown_kwargs = [
            f'{key}={value}'
            for key, value in kwargs.items()
            if is_available_type(value)
        ]

        # Numeric results keep two decimals; anything that cannot be
        # formatted that way (None, str, arrays, ...) is logged by its repr
        # instead of raising after the computation has already finished.
        try:
            shown_result = f'{result:.2f}'
        except (TypeError, ValueError):
            shown_result = repr(result)

        # generate logging message
        name = getattr(callable, '__name__', repr(callable))
        arguments = ','.join(chain(shown_args, shown_kwargs))
        logging.info(f'{name}({arguments})={shown_result}({elapsed:.2f}s)')

        return result

    return wrapper

In [None]:
def compute_principle_tree(X, n_nodes, lambda_=0.05, mu=0.1):
    """Compute the elastic principal tree of ``X``.

    Parameters
    ----------
    X : array-like
        Data points, passed unchanged to
        ``elpigraph.computeElasticPrincipalTree``.
    n_nodes : int
        Forwarded as ``NumNodes``.
    lambda_ : float, default 0.05
        Forwarded as ``Lambda``.
    mu : float, default 0.1
        Forwarded as ``Mu``.

    Returns
    -------
    The first element of the sequence returned by
    ``elpigraph.computeElasticPrincipalTree``.
    """
    # Reset both global generators before every fit so repeated calls start
    # from the same RNG state.
    # NOTE(review): whether elpigraph draws from `random`, from NumPy's
    # global generator, or from neither is not visible here - confirm.
    random.seed(seed)
    np.random.seed(seed)
    return elpigraph.computeElasticPrincipalTree(
        X, NumNodes=n_nodes, Lambda=lambda_, Mu=mu,
    )[0]


@log_execution
def tree_energy(X, n_nodes):
    """Fit a principal tree with ``n_nodes`` nodes and return its energy.

    The returned value is the ``'ENERGY'`` entry of the ``'FinalReport'``
    mapping of the fitted tree.
    """
    fitted = compute_principle_tree(X, n_nodes)
    report = fitted.get('FinalReport')
    return report.get('ENERGY')

In [None]:
# Cached data the principal tree is fitted to (`vectors.values` is the input
# of the tree computation).
# NOTE(review): presumably a DataFrame carrying a `pid` key, since it is
# grouped by 'pid' for the embedding plot - confirm against the producer.
# NOTE(review): unpickling executes code from the file; only load trusted
# caches.
vectors = pd.read_pickle('caches/vectors.pkl')

In [None]:
# Set `readonly` to False to recompute the energies (one tree fit per node
# count) instead of reading the cached table.
readonly = True
max_n_nodes = 100

if readonly:
    # Squeeze the single energy column so that `energies` is a Series here,
    # exactly like in the recompute branch below.
    energies = pd.read_csv(
        'tables/energies.csv', index_col='n_nodes',
    ).squeeze('columns')
else:
    energies = {}
    for n in tqdm(range(1, max_n_nodes + 1)):
        energies[n] = tree_energy(vectors.values, n)

    energies = pd.Series(energies, name='energy')
    energies.index.name = 'n_nodes'
    energies.to_csv('tables/energies.csv')

# Draw on an explicitly created figure rather than on whatever figure
# happens to be current in the pyplot state machine.
fig, ax = plt.subplots()
ax.set_title('Elbow Plot for Elastic Principal Tree')
ax.plot(energies.iloc[:max_n_nodes])
ax.set_xlabel('Number of Nodes')
ax.set_ylabel('Elastic Energy')
ax.set_xticks(range(0, max_n_nodes + 1, 25))
ax.grid()

fig.savefig('figures/tree-elbow.png', dpi=300)

In [None]:
# Set `readonly` to False to refit the tree and refresh the cache file.
readonly = True

if not readonly:
    # fit the tree with the chosen number of nodes and cache the result
    n_nodes = 25
    elpigraph_tree = compute_principle_tree(vectors.values, n_nodes)
    joblib.dump(elpigraph_tree, 'caches/elpigraph_tree.pkl')
else:
    # reuse the previously fitted tree
    elpigraph_tree = joblib.load('caches/elpigraph_tree.pkl')

In [None]:
# the fraction of data to plot
fraction = 0.05

# Draw the same fraction of rows from every `pid` group; plotting only a
# random sample of the records reduces memory usage.
sampled_points = (
    vectors
    .groupby('pid')
    .sample(frac=fraction, random_state=seed)
    .values
)

plt.title('Elastic Principal Tree Embedding on Data (PCA)')
elpigraph.plot.PlotPG(
    sampled_points,
    PG=elpigraph_tree,
    Do_PCA=False,
    show_text=False,
)
plt.xlabel('Component #1')
plt.ylabel('Component #2')
# the tick values are hidden on both axes
plt.xticks([])
plt.yticks([])

plt.savefig('figures/tree-embed.png', dpi=300)

In [None]:
# positions of the tree nodes as a table
node_positions = pd.DataFrame(data=elpigraph_tree['NodePositions'])
# first entry of the tree's 'Edges', as a two-column (source, target) table
edges = pd.DataFrame(
    data=elpigraph_tree['Edges'][0],
    columns=['source', 'target'],
)

In [None]:
# NOTE(review): joblib.load unpickles the file; only load trusted caches.
transformer = joblib.load('caches/transformer.joblib')

# calculate the position without any disease, i.e. the transformed
# all-zero feature row
initial = pd.DataFrame(
    np.zeros((1, transformer.n_features_in_)),
    columns=transformer.feature_names_in_,
)
initial_position = transformer.transform(initial).squeeze()

# distances to the initial position
distances = node_positions.apply(
    dist, args=(initial_position,), axis='columns')
# Use the distance rank (0 = closest node) as the new node indices.
# method='first' keeps the ranks unique: with the default 'average' method,
# tied distances get fractional ranks that truncate to the same integer and
# would silently merge distinct nodes.
ranks = distances.rank(method='first').astype(int) - 1

# `edges` and `node_positions` are only read, never rebound or relabelled,
# so this cell can be re-run without first re-running the cell above.
tree = nx.DiGraph()
# the direction is same as the distance increasing from initial position:
# every edge points from the lower-ranked node to the higher-ranked one
tree.add_edges_from(
    tuple(sorted((int(ranks.loc[source]), int(ranks.loc[target]))))
    for source, target in zip(edges['source'], edges['target'])
)
# mark nodes with their positions (add_node only updates the attributes of
# nodes that already exist, so the node order set by the edges is kept)
for label, row in node_positions.iterrows():
    tree.add_node(int(ranks.loc[label]), position=row.values)

joblib.dump(tree, 'caches/tree.joblib')

In [None]:
# Calculate the Euclidean distances between node positions; they are the
# target distances of the layout.
# Both directions of every pair are stored: the layout looks up each
# (source, target) combination separately and uses a very large default
# distance for combinations missing from the dictionary.
distances = defaultdict(dict)
for source, target in combinations(tree.nodes, 2):
    separation = dist(
        tree.nodes[source]['position'], tree.nodes[target]['position']
    )
    distances[source][target] = separation
    distances[target][source] = separation
layout = nx.kamada_kawai_layout(tree, dist=distances)
nx.draw_networkx(tree, layout)
plt.savefig('figures/tree.png', dpi=300)

joblib.dump(layout, 'caches/layout.joblib')