# Classify ALS status using AFQ-Insight

In [1]:
import afqinsight as afqi
import matplotlib.pyplot as plt
import numpy as np
import pickle

from mpl_toolkits.mplot3d import Axes3D

from bokeh.io import output_notebook
from bokeh.layouts import row, column, widgetbox
from bokeh.models import BoxSelectTool, HoverTool, Title
from bokeh.palettes import Spectral10
from bokeh.plotting import figure, show, ColumnDataSource

%matplotlib notebook

In [2]:
# Configure Bokeh so that `show(...)` renders figures inline in this notebook.
output_notebook()

## Load the data

In [3]:
# Load the AFQ classification data; the 'class' column is the target and
# subjects labelled 'ALS' form the positive class.
afq_data = afqi.load_afq_data(
    '../afqinsight/data/classification_data',
    target_col='class',
    binary_positive='ALS'
)

In [4]:
# Pull out the feature matrix, labels, feature groups and column index.
x = afq_data.x
y = afq_data.y
groups = afq_data.groups
cols = afq_data.cols

## Find the optimal feature coefficients $\widehat{\beta}$

We run the search twice: once scoring candidate hyperparameters by classification accuracy, and once by ROC AUC.

In [11]:
# Settings shared by both hyperparameter searches, so the two runs differ only
# in the score they optimize (and in where their trials are stored).
cv_random_state = 42
cv_search_kwargs = dict(
    max_evals_per_cv=1000,
    verbose=1,
    random_state=cv_random_state,
    clf_threshold=0.5,
)

# The trial directory name encodes the random state; building it from
# `cv_random_state` keeps the two in sync if the seed is changed.
# NOTE(review): 'cv10' is hard-coded and presumably reflects the number of CV
# folds used by fit_hyperparams_cv -- update it if that changes.
trials_dir_template = './cv_trials_cv10_rs{rs:d}_{metric:s}'

hp_cv_res_acc = afqi.fit_hyperparams_cv(
    x, y, groups,
    score='accuracy',
    trials_pickle_dir=trials_dir_template.format(rs=cv_random_state, metric='accuracy'),
    **cv_search_kwargs
)

hp_cv_res_auc = afqi.fit_hyperparams_cv(
    x, y, groups,
    score='roc_auc',
    trials_pickle_dir=trials_dir_template.format(rs=cv_random_state, metric='rocauc'),
    **cv_search_kwargs
)

100%|██████████| 10/10 [00:19<00:00,  1.98s/it]
100%|██████████| 10/10 [00:18<00:00,  1.80s/it]


Let's examine the regularization parameters ($\alpha_1$, $\alpha_2$) selected in each CV split of the ROC AUC search.

In [12]:
# (alpha1, alpha2) selected in each CV split of the ROC AUC search.
[(res.alpha1, res.alpha2) for res in hp_cv_res_auc]

[(0.35147294619698494, 0.0684492039099914),
 (0.7221609766243215, 0.054933506562241594),
 (0.32844055241108483, 0.0231990189431205),
 (0.014323199760265448, 0.0029343349221855677),
 (0.5621782149255845, 0.035436219117816456),
 (0.940745576729, 0.012248556128239284),
 (0.5563691445451803, 0.02160470829610179),
 (0.25755070816421, 0.031464226297018),
 (0.31974962754306113, 0.023145691410387116),
 (0.4390627535470068, 0.04232965823815161)]

Next, look at the classification scores achieved by each search, summarized as the mean (and variance) across CV splits.

In [13]:
def print_results_summary(hp_cv_results):
    """Print the mean and variance, across CV splits, of each classification score.

    Parameters
    ----------
    hp_cv_results : sequence
        Per-split results (e.g. from ``afqi.fit_hyperparams_cv``). Each element
        must expose ``.test`` and ``.train`` score objects, each with
        ``accuracy``, ``auc`` and ``avg_precision`` attributes.

    Prints a small fixed-width table (test scores first, then train scores)
    and returns None.
    """
    row_fmt = '{stat:15s} {mean:7.5g} ({var:7.5g})'

    test_scores = [res.test for res in hp_cv_results]
    train_scores = [res.train for res in hp_cv_results]

    # (row label, per-split score objects, attribute to read)
    stat_specs = [
        ('test accuracy', test_scores, 'accuracy'),
        ('test AUC', test_scores, 'auc'),
        ('test avg prec', test_scores, 'avg_precision'),
        ('train accuracy', train_scores, 'accuracy'),
        ('train AUC', train_scores, 'auc'),
        ('train avg prec', train_scores, 'avg_precision'),
    ]
    # Gather every score before printing anything, so a missing attribute
    # fails before any of the table is written.
    summary = [
        (label, [getattr(score, attr) for score in scores])
        for label, scores, attr in stat_specs
    ]

    print('Statistic         mean   (variance)')
    print('--------------  ------- ------------')
    for label, values in summary:
        print(row_fmt.format(stat=label, mean=np.mean(values), var=np.var(values)))

In [14]:
# Scores for the search that optimized classification accuracy.
print_results_summary(hp_cv_res_acc)

Statistic         mean   (variance)
--------------  ------- ------------
test accuracy   0.85833 (0.015347)
test AUC        0.93056 (0.013773)
test avg prec   0.95056 (0.0070028)
train accuracy  0.93528 (0.0029949)
train AUC       0.98175 (0.00034641)
train avg prec  0.98367 (0.00026708)


In [15]:
# Scores for the search that optimized ROC AUC.
print_results_summary(hp_cv_res_auc)

Statistic         mean   (variance)
--------------  ------- ------------
test accuracy   0.84167 (0.024236)
test AUC        0.91944 (0.013341)
test avg prec   0.94222 (0.0068037)
train accuracy  0.94686 (0.0024612)
train AUC       0.98751 (0.00019168)
train avg prec  0.98856 (0.00014737)


In [16]:
# Plot each CV split's classification probability for every subject (row of x),
# together with the ground-truth labels and the classification threshold.
p = figure(plot_width=700, plot_height=700, toolbar_location='above')
p.title.text = 'Classification probabilities for each CV split'
p.add_layout(
    Title(text='Click on legend entries to hide/show the corresponding lines',
          align="left"), 'right'
)

# Same value that was passed as clf_threshold to fit_hyperparams_cv.
clf_threshold = 0.5
subject_idx = np.arange(len(y))
names = ['cv_idx = {i:d}'.format(i=i) for i in range(len(hp_cv_res_auc))]

hover = HoverTool(
    tooltips=[("index", "$index"),],
    mode='vline',
)
hover.point_policy = 'snap_to_data'
hover.line_policy = 'nearest'

for cv_idx, (res, name) in enumerate(zip(hp_cv_res_auc, names)):
    # Cycle the 10-color palette: zipping against Spectral10 directly would
    # silently drop every CV split after the 10th.
    color = Spectral10[cv_idx % len(Spectral10)]
    # NOTE(review): relies on the private helper afqi.insight._sigmoid.
    p.line(subject_idx, afqi.insight._sigmoid(x.dot(res.beta_hat)),
           line_width=2, color=color, alpha=0.8, legend=name)

p.line(subject_idx, y, line_width=3, alpha=0.8, legend='ground truth')
p.line(subject_idx, clf_threshold * np.ones(len(y)),
       line_width=2, line_dash='dashed', alpha=0.8, legend='threshold')
p.add_tools(hover)
p.legend.location = 'top_right'
p.legend.click_policy = 'hide'

show(p)

Using the hover tool on the chart above, we can see that subjects 05, 07, 08, 16, 19, 30, 32, 35, and 36 are hard to classify: their predicted probabilities are consistently close to the classification threshold of 0.5. We should open AFQ-Browser and compare these subjects with the rest of the subjects in their group.

Here are links to a running instance of AFQ-Browser with the hard-to-classify subjects selected:
- [False negatives](https://yeatmanlab.github.io/Sarica_2017/?table[prevSort][count]=2&table[prevSort][order]=ascending&table[prevSort][key]=&table[sort][count]=2&table[sort][order]=ascending&table[sort][key]=class&table[selectedRows][subject_005]=true&table[selectedRows][subject_007]=true&table[selectedRows][subject_008]=true&table[selectedRows][subject_016]=true&table[selectedRows][subject_019]=true&table[selectedRows][subject_030]=false&table[selectedRows][subject_032]=false&table[selectedRows][subject_035]=false&table[selectedRows][subject_036]=false&plots[checkboxes][right-corticospinal]=true&plots[zoom][rd][scale]=1&plots[zoom][rd][translate][0]=-3&plots[zoom][rd][translate][1]=-21&plots[zoom][fa][scale]=2.1140360811227614&plots[zoom][fa][translate][0]=-27.244995845837778&plots[zoom][fa][translate][1]=-106.10468474511174&plots[plotKey]=fa&plots[errorType]=stderr&plots[lineOpacity]=0.09355440414507772)
- [False positives](https://yeatmanlab.github.io/Sarica_2017/?table[prevSort][count]=2&table[prevSort][order]=ascending&table[prevSort][key]=&table[sort][count]=2&table[sort][order]=ascending&table[sort][key]=class&table[selectedRows][subject_005]=false&table[selectedRows][subject_007]=false&table[selectedRows][subject_008]=false&table[selectedRows][subject_016]=false&table[selectedRows][subject_019]=false&table[selectedRows][subject_030]=true&table[selectedRows][subject_032]=true&table[selectedRows][subject_035]=true&table[selectedRows][subject_036]=true&plots[checkboxes][right-corticospinal]=true&plots[zoom][rd][scale]=1&plots[zoom][rd][translate][0]=-3&plots[zoom][rd][translate][1]=-21&plots[zoom][fa][scale]=2.1140360811227614&plots[zoom][fa][translate][0]=-27.244995845837778&plots[zoom][fa][translate][1]=-106.10468474511174&plots[plotKey]=fa&plots[errorType]=stderr&plots[lineOpacity]=0.09355440414507772)

Let's sort the features by importance, using each feature's coefficient averaged across the CV splits.

In [17]:
# One dict per feature column (keys such as metric, tractID, nodeID).
feature_dicts = afqi.multicol2dicts(cols, tract_symmetry=False)

# Average each coefficient across the CV splits of the ROC AUC search.
mean_beta = np.mean([res.beta_hat for res in hp_cv_res_auc], axis=0)

sorted_features = afqi.sort_features(feature_dicts, mean_beta)

# Display only the top of the ranking; the full list has one entry per
# feature and is far too long to read in the notebook.
n_top_features = 10
sorted_features[:n_top_features]

[({'metric': 'fa', 'nodeID': 34, 'tractID': 'Right Corticospinal'},
  -0.057948077336283155),
 ({'metric': 'fa', 'nodeID': 33, 'tractID': 'Right Corticospinal'},
  -0.05704323654664243),
 ({'metric': 'fa', 'nodeID': 35, 'tractID': 'Right Corticospinal'},
  -0.056429104777020275),
 ({'metric': 'fa', 'nodeID': 32, 'tractID': 'Right Corticospinal'},
  -0.05275280574110338),
 ({'metric': 'fa', 'nodeID': 38, 'tractID': 'Right Corticospinal'},
  -0.052441055527429635),
 ({'metric': 'fa', 'nodeID': 36, 'tractID': 'Right Corticospinal'},
  -0.05226105070542332),
 ({'metric': 'fa', 'nodeID': 37, 'tractID': 'Right Corticospinal'},
  -0.05123490760998841),
 ({'metric': 'fa', 'nodeID': 39, 'tractID': 'Right Corticospinal'},
  -0.051035424287296724),
 ({'metric': 'fa', 'nodeID': 40, 'tractID': 'Right Corticospinal'},
  -0.048913298890418316),
 ({'metric': 'fa', 'nodeID': 41, 'tractID': 'Right Corticospinal'},
  -0.04646919053954916),
 ({'metric': 'fa', 'nodeID': 31, 'tractID': 'Right Corticospinal'

It's nice to see the top few features in a sorted list, but let's plot the coefficients to get a feel for their distributions.

In [18]:
# Regroup the CV-averaged coefficients so they can be indexed as
# beta_hats[tract][metric] by plot_betas below.
beta_hats = afqi.beta_hat_by_groups(
    mean_beta,
    columns=cols,
    drop_zeros=True
)

In [19]:
def plot_betas(beta_hat, all_metrics, ecdf=False):
    """Plot coefficients per tract as a vertical stack of Bokeh figures.

    Parameters
    ----------
    beta_hat : mapping
        Nested mapping indexed as ``beta_hat[tract][metric]``, giving the
        coefficients for that tract/metric pair (e.g. the output of
        ``afqi.beta_hat_by_groups``).
    all_metrics : sequence
        Every metric name that may appear in ``beta_hat``. Colors are assigned
        from this list, so a metric has the same color in every tract figure.
    ecdf : bool, optional
        If True, plot the empirical CDF of each metric's coefficients as
        points. Otherwise (the default), plot the coefficients as a line
        against their index.

    Notes
    -----
    Metrics whose coefficients are all zero for a tract are not drawn.
    The figures are displayed with ``show``; nothing is returned.
    """
    # Cycle the 10-color palette so that more than 10 metrics no longer raise
    # IndexError (metrics beyond the 10th reuse colors).
    colors = {
        metric: Spectral10[idx % len(Spectral10)]
        for idx, metric in enumerate(all_metrics)
    }

    figs = []
    for tract in beta_hat.keys():
        fig = figure(plot_width=750, plot_height=250, toolbar_location='right')
        fig.title.text = tract
        if ecdf:
            fig.xaxis.axis_label = 'coefficient value'
            fig.yaxis.axis_label = 'ECDF'
        else:
            fig.xaxis.axis_label = 'coefficient index'
            fig.yaxis.axis_label = 'coefficient value'

        for metric in beta_hat[tract].keys():
            b = beta_hat[tract][metric]
            # np.asarray also handles plain lists, for which `b == 0` is a
            # scalar False and `all(...)` would raise TypeError.
            if not np.any(np.asarray(b) != 0):
                continue
            if ecdf:
                cdf = afqi.utils.ecdf(b)
                fig.circle(cdf.x, cdf.y,
                           size=5, color=colors[metric],
                           alpha=0.8, legend=metric)
            else:
                fig.line(np.arange(len(b)), b,
                         line_width=4, color=colors[metric],
                         alpha=0.8, legend=metric)

        fig.legend.location = 'bottom_right'
        fig.legend.click_policy = 'hide'
        figs.append(fig)

    show(column(figs))

First, let's plot the coefficients themselves.

In [20]:
# Metric names come from the 'metric' level of `cols`, so each metric keeps a
# single color across all tract figures.
plot_betas(
    beta_hat=beta_hats,
    all_metrics=cols.levels[cols.names.index('metric')]
)

Now let's plot the empirical cumulative distribution function (ECDF) of each set of coefficients.

In [21]:
# Same figures as above, but showing the ECDF of each metric's coefficients.
plot_betas(
    beta_hat=beta_hats,
    all_metrics=cols.levels[cols.names.index('metric')],
    ecdf=True
)

## Output coefficients to AFQ-Browser CSV format

In [22]:
# Write the CV-averaged coefficients (scale_beta=True) in AFQ-Browser CSV format,
# reading from workdir_in and writing to workdir_out. The returned OutputFiles
# lists the files that were written.
afqi.output_beta_to_afq(
    workdir_in='../afqinsight/data/classification_data',
    workdir_out='../afqinsight/data_with_weights',
    columns=cols,
    beta_hat=mean_beta,
    scale_beta=True
)

OutputFiles(nodes_file='/Users/Adam/code/projects/afq/insight/afqinsight/data_with_weights/nodes.csv', subjects_file='/Users/Adam/code/projects/afq/insight/afqinsight/data_with_weights/subjects.csv')