In [None]:
import random
import numpy as np
from sklearn import tree
from sklearn import svm
from sklearn import metrics
import matplotlib.pyplot as plt
import matplotlib.patches as patches
%matplotlib inline
import json
import os
import sys
# NOTE(review): machine-specific absolute path. The working directory is moved
# to the project checkout and the same directory is put on sys.path so that the
# `from src.data.plots import ...` line right below can be resolved. On any
# other machine both lines have to be edited before the notebook can run.
os.chdir('/local/home/dhaziza/entrack')
sys.path.append('/local/home/dhaziza/entrack/')
from src.data.plots import do_bars_comparison_plot, load_datasets, print_class_samples, plot_smt_metric, load_tensors_dump

# Class id convention: the position of a name in CLASS_NAMES is the integer
# class id (CLASS_HC, CLASS_AD and CLASS_MCI are looked up from the list, so
# reordering the list changes the ids).
CLASS_NAMES = ['HC', 'AD', 'MCI']
CLASS_COUNT = len(CLASS_NAMES)
CLASS_HC = CLASS_NAMES.index('HC')
CLASS_AD = CLASS_NAMES.index('AD')
CLASS_MCI = CLASS_NAMES.index('MCI')

# Reference run ids, grouped by model family, reused as the "baseline" entry in
# the comparison plots further down.
# Batch size 48, bn, py_streaming, no 512fc
BASELINE_RUNS = {
    'cnn': ['20180409-195126', '20180411-005423', '20180410-151641', '20180407-101054', '20180407-101051'],
    'svm': ['20180411-181756', '20180411-163159'],
    'inception': ['20180411-193238'],
}
# Seeds only the stdlib `random` generator; NumPy's generator is not seeded here.
random.seed(0)

# Per-class model outputs, loaded through the project helper. The cells below
# treat each returned object as a mapping from class id to the outputs of the
# samples of that class.
data_per_class_train, data_per_class_valid, data_per_class_all = load_datasets(num_classes=CLASS_COUNT)
print_class_samples(data_per_class_train, 'train')
print_class_samples(data_per_class_valid, 'valid')
# Plot the 'eval/classifier/loss' metric for the CNN baseline runs next to the
# single run labelled 'Wrong Split'.
plot_smt_metric({
    'Baseline': BASELINE_RUNS['cnn'],
    'Wrong Split': '20180411-111955',
}, 'eval/classifier/loss')

In [None]:
def decision_argmax(probas):
    """Decision rule: predict the class with the largest score.

    Args:
        probas: sequence or array of per-class scores for one sample.

    Returns:
        Index of the largest entry (the first one in case of a tie).
    """
    best_class = np.argmax(probas)
    return best_class

def decision_heur(probas):
    """Decision rule that only predicts AD when the model is very confident.

    AD is returned when its score is strictly above 0.95. Otherwise AD is
    excluded and the best of the remaining classes is returned.

    Args:
        probas: 1-D sequence or array of per-class scores for one sample,
            indexed by class id. Plain lists are accepted as well as arrays.

    Returns:
        The predicted class id.
    """
    # Private float copy: accepts plain sequences and never mutates the input.
    scores = np.array(probas, dtype=float)
    if scores[CLASS_AD] > 0.95:
        return CLASS_AD
    # Exclude AD using its actual index rather than a hard-coded position.
    # -inf (instead of 0) guarantees AD cannot win the argmax even when every
    # other score is negative; for non-negative scores the result is the same
    # as zeroing the entry.
    scores[CLASS_AD] = -np.inf
    return np.argmax(scores)

def check_accuracy(dataset, decision_fn):
    """Print the accuracy of a decision rule per class and averaged over classes.

    Args:
        dataset: mapping from integer class id to a sized iterable holding the
            model output of every sample whose true class is that id.
        decision_fn: callable taking one sample output and returning the
            predicted class id.

    Classes without any sample are reported as ``<no sample>`` and are left
    out of the mean, so that an empty class does not count as 0% accuracy and
    drag the mean per-class accuracy down.
    """
    accuracies = []
    for _class, samples in dataset.items():
        total_count = len(samples)
        if total_count == 0:
            print('  Class %d: <no sample>' % _class)
            continue
        accurate_count = sum(
            1 for sample_output in samples
            if decision_fn(sample_output) == _class
        )
        acc = 100.0 * accurate_count / total_count
        print('  Class %d: accuracy %.2f%% (%d/%d correct)' % (_class, acc, accurate_count, total_count))
        accuracies.append(acc)
    if accuracies:
        print('Mean Per Class Accuracy: %.2f%%' % np.mean(accuracies))
    else:
        # Nothing to average: avoid np.mean([]) (nan plus a RuntimeWarning).
        print('Mean Per Class Accuracy: <no sample>')

# Report accuracy of the plain argmax rule on the combined dataset.
check_accuracy(data_per_class_all, decision_argmax)

In [None]:
def plot_histograms():
    """Show histograms of the model outputs, grouped by true class.

    Reads the notebook-level ``data_per_class_all`` (indexed by class id, each
    entry indexed as ``[:, output_column]``) and draws, in order:

    1. one figure per output column ``k`` with the values of that column for
       every true class;
    2. the difference between the HC and AD columns for every true class;
    3. for every true class, the column belonging to that same class.

    Every figure is normalised (``density=True``) and shown immediately.
    """
    probas_range = 0.5
    class_ids = range(CLASS_COUNT)
    legend_entries = ['input_class=%s' % name for name in CLASS_NAMES]

    def _show_grouped_hist(series_per_class, value_range, title):
        # One histogram series per true class, drawn side by side.
        plt.hist(series_per_class, range=value_range, density=True)
        plt.title(title)
        plt.legend(legend_entries)
        plt.ylabel('Probability density')
        plt.show()

    for k in class_ids:
        _show_grouped_hist(
            [data_per_class_all[k1][:, k] for k1 in class_ids],
            (0., probas_range),
            'P(%s|input_class)' % CLASS_NAMES[k],
        )

    _show_grouped_hist(
        [
            data_per_class_all[k1][:, CLASS_HC] - data_per_class_all[k1][:, CLASS_AD]
            for k1 in class_ids
        ],
        (-probas_range, probas_range),
        'P(HC - AD|input_class)',
    )

    # NOTE(review): this plot uses the same symmetric range as the difference
    # plot above. If the values are probabilities, anything above 0.5 lies
    # outside the range and is not drawn -- confirm that this is intended.
    _show_grouped_hist(
        [data_per_class_all[k1][:, k1] for k1 in class_ids],
        (-probas_range, probas_range),
        'P(input_class|input_class)',
    )

In [None]:
# Flatten the per-class training outputs: x[j] collects column j of every
# non-empty class, and `labels` holds the matching class name for each sample.
x = [[], [], []]
# One colour per class id; scatter_plot below indexes this list by class id.
colors = ['red', 'green', 'blue']
labels = []
for k, v in data_per_class_train.items():
    if len(v) > 0:
        c = len(v[:, 0])  # number of samples of class k
        x[0] += list(v[:, 0])
        x[1] += list(v[:, 1])
        # The third column is only collected when the outputs have more than
        # two columns, so x[2] can be shorter than `labels`.
        if v.shape[1] > 2:
            x[2] += list(v[:, 2])
        labels += [CLASS_NAMES[k]] * c

def scatter_plot(axis0, axis1, dataset=None):
    """Scatter two output columns against each other, one colour per class.

    Args:
        axis0: column index drawn on the x axis (also used as a class id to
            pick the colour of the first background rectangle and the x label).
        axis1: column index drawn on the y axis (same double use for the second
            rectangle and the y label).
        dataset: container indexed by class id whose entries are indexed as
            ``[:, column]``; defaults to the notebook-level ``data_per_class_all``.

    The figure also contains the diagonal y = x and two translucent rotated
    rectangles anchored at (-50, -50). The axis limits found after the scatter
    calls are restored at the end, so these decorations do not rescale the view.
    """
    samples_by_class = data_per_class_all if dataset is None else dataset
    fig, ax = plt.subplots()

    for class_id, color in enumerate(colors):
        samples = samples_by_class[class_id]
        if len(samples) > 0:
            ax.scatter(
                samples[:, axis0],
                samples[:, axis1],
                c=color,
                label=CLASS_NAMES[class_id],
                alpha=0.2,
                s=12,
            )

    # Limits as chosen by matplotlib for the scattered points only.
    x_limits = ax.get_xlim()
    y_limits = ax.get_ylim()
    diagonal = [
        np.min([x_limits, y_limits]),  # smallest bound over both axes
        np.max([x_limits, y_limits]),  # largest bound over both axes
    ]
    ax.plot(diagonal, diagonal, 'k-', alpha=0.75, zorder=0)

    # Background rectangles: first the one coloured after axis0, then axis1.
    for color_index, rotation in ((axis0, -45.0), (axis1, 45.0)):
        background = patches.Rectangle(
            (-50., -50.),   # (x,y)
            600,            # width
            600,            # height
            facecolor=colors[color_index],
            angle=rotation,
            alpha=0.15,
        )
        ax.add_patch(background)

    # Undo any rescaling caused by the diagonal and the rectangles.
    ax.set_xlim(x_limits)
    ax.set_ylim(y_limits)

    ax.set_xlabel('logits[%s]' % CLASS_NAMES[axis0])
    ax.set_ylabel('logits[%s]' % CLASS_NAMES[axis1])
    ax.legend()
    ax.grid(True)
    plt.show()

# Column 0 against column 1 on the default (combined) dataset. The other two
# column pairs are kept below, disabled.
scatter_plot(0, 1)
#scatter_plot(0, 2)
#scatter_plot(1, 2)

In [None]:
# Bar chart comparing 3DCNN and SVM runs for different values of N (see the
# title string), next to the shared baseline runs.
# Each dict value appears to hold one slot per series name in ('3DCNN', 'SVM');
# a slot is either a single run id or a list of run ids. How several runs in one
# slot are combined is decided inside do_bars_comparison_plot -- not visible here.
do_bars_comparison_plot({
    'baseline': (BASELINE_RUNS['cnn'], BASELINE_RUNS['svm']),
    '200': ('20180410-124714', '20180410-151305'),
    '300': ('20180410-124751', '20180410-151343'),
    '400': ('20180410-124805', '20180410-162844'),
    '500': (['20180410-144746', '20180410-124820'], '20180410-162921'),
    '600': (['20180410-124838', '20180410-141605', '20180410-141608'], '20180410-163325'),
}, ('3DCNN', 'SVM'), 'N', 'Accuracy when training on a dataset of size 1000 with up to N patients')

In [None]:
# Bar chart of four single SVM runs (one run id per entry, one series).
# NOTE(review): the numeric key prefixes ('01__', '02__', ...) presumably only
# fix the bar order -- confirm against do_bars_comparison_plot.
do_bars_comparison_plot({
    '01__b32': ('20180408-194221', ),
    '02__b48': ('20180408-194545', ),
    '03__b48_bn': ('20180408-231353', ),
    '04__b48_bn_noscale': ('20180408-231638', ),
}, ('SVM', ), 'Run', 'SVM accuracy')

In [None]:
# Bar chart of 3DCNN runs with different data-processing settings (see title);
# the last entry reuses the CNN baseline runs. The y axis is restricted through
# the ylim argument.
do_bars_comparison_plot({
    '01__nonorm': ('20180410-173322', ),
    '02__nonorm_bn': ('20180410-163340', ),
    '03__dn': ('20180410-151726', ),
    '04_baseline__dn_bn': (BASELINE_RUNS['cnn'], ),
}, ['3DCNN'], 'Run', 'Accuracy with various data processing', ylim=[0.75, 0.9])

In [None]:
# Bar chart over batch sizes with two series, '3DCNN_BN' and '3DCNN_NOBN'.
# Each dict value appears to hold one slot per series, in that order; a slot is
# a single run id or a list of run ids. The last entry uses the CNN baseline
# runs for the first series.
do_bars_comparison_plot({
    '01__4': ('20180411-010527', '20180411-065451', ),
    '02__8': ('20180410-230522', '20180411-071435', ),
    '03__16': (['20180410-230647', '20180411-102359'], '20180411-011551', ),
    '04__32': (['20180410-230608', '20180407-020135', '20180406-222733'], '20180411-014406', ),
    '05_baseline__48': (BASELINE_RUNS['cnn'], ['20180410-151726', '20180411-100354']),
}, ['3DCNN_BN', '3DCNN_NOBN'], 'Run', 'Accuracy different batch sizes', ylim=[0.80, 0.90])

In [None]:
# Bar chart over the number of subjects per class with three series
# ('3DCNN', 'Inceptionv3', 'SVM'); each dict value appears to hold one slot per
# series, in that order. The baseline entry is declared first in the literal;
# NOTE(review): its '999__' prefix presumably sorts it last in the plot --
# confirm against do_bars_comparison_plot.
do_bars_comparison_plot({
    '999__baseline': (BASELINE_RUNS['cnn'], BASELINE_RUNS['inception'], BASELINE_RUNS['svm']),
    '010__10': ('20180411-165025', '20180411-145142', '20180411-183437', ),
    '020__20': ('20180411-130243', '20180411-120122', '20180411-183512', ),
    '040__40': ('20180411-144539', '20180411-130135', '20180411-190147', ),
    '100__100': ('20180411-135452', '20180411-152131', '20180411-183810', ),
    '150__150': ('20180411-172647', '20180411-195756', '20180411-192843'),
    '200__200': ('20180411-142421', '20180411-165904', '20180411-190023', ),
    '300__300': ('20180411-172727', '20180411-195802', ['20180411-192830', '20180411-201431']),
    '600__600': ('20180411-173518', '20180411-195730', '20180411-190904', ),
}, ['3DCNN', 'Inceptionv3', 'SVM'], 'Number of subjects per class', 'Accuracy with various number of patients [batch_size=16]', ylim=[0.50, 0.90])

In [None]:
# Bar chart over the number of images per patient for a single '3DCNN' series;
# every entry groups several run ids in one list, and the last entry reuses the
# CNN baseline runs.
do_bars_comparison_plot({
    '01__1': (['20180411-203444', '20180411-222724', '20180411-212836'],),
    '02__2': (['20180411-205157', '20180411-212840'], ),
    '03__3': (['20180411-203650', '20180411-222726', '20180411-212847'], ),
    '04__4': (['20180411-204710', '20180411-212857'], ),
    '05__5': (['20180411-203615', '20180411-223602', '20180411-221621', '20180411-215725'], ),
    '999__baseline': (BASELINE_RUNS['cnn'], ),
}, ['3DCNN',], 'Number of images per patient', 'Accuracy changes induced by adding more images of the same 350 patients', ylim=[0.50, 0.90])

In [None]:
def roc_curves(runs, avg_scores):
    """Plot one ROC curve per run for the AD class against all other classes.

    For every run, the AD score of each image is collected over the selected
    dump entries, reduced to its median, and compared with the image's true
    label (argmax of its label vector).

    Args:
        runs: iterable of ``(run_label, run_reason)`` pairs. ``run_label`` is
            passed to ``load_tensors_dump``; ``run_reason`` is the legend text.
        avg_scores: iterable of integer positions into the dump's key sequence
            (negative values count from the end) selecting the entries whose
            scores are pooled per image.
    """
    plt.figure()

    for run_label, run_reason in runs:
        dump_tensors = load_tensors_dump(run_label)
        # dict.keys() is a non-subscriptable view on Python 3, so materialise
        # the keys once per run before indexing by position.
        # NOTE(review): positions follow the mapping's iteration order; this
        # presumably matches the chronological order of the dumps -- confirm
        # against load_tensors_dump.
        dump_keys = list(dump_tensors.keys())
        image_data = {}
        for step in avg_scores:
            run_data = dump_tensors[dump_keys[step]]
            for _run_data in run_data:
                for i, image_label in enumerate(_run_data['classifier/image_label']):
                    image_label = image_label[0]
                    if image_label not in image_data:
                        image_data[image_label] = {
                            'scores': [],
                            'label': np.argmax(_run_data['classifier/labels'][i]),
                        }
                    image_data[image_label]['scores'].append(_run_data['classifier/proba'][i][CLASS_AD])
        # One point per image: the median AD score over the pooled entries.
        y_scores = [np.median(d['scores']) for d in image_data.values()]
        y_true = [d['label'] for d in image_data.values()]
        fpr, tpr, thresholds = metrics.roc_curve(y_true, y_scores, pos_label=CLASS_AD)
        roc_auc = metrics.auc(fpr, tpr)
        plt.plot(fpr, tpr,
                 lw=2, label='%s (area = %0.3f)' % (run_reason, roc_auc))

    # Chance-level diagonal for reference.
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic for AD class')
    plt.legend(loc="lower right")
    plt.show()

# ROC curves for one run per batch size. avg_scores=range(-10, 0) selects
# positions -10..-1, i.e. the last ten entries of each run's dump.
roc_curves([
    ('20180411-005423', 'batch_size=48'),
    ('20180410-230608', 'batch_size=32'),
    ('20180410-230647', 'batch_size=16'),
    ('20180410-230522', 'batch_size=8'),
    ('20180411-010527', 'batch_size=4')
], avg_scores=range(-10, 0))