In [12]:
from classifier import ManifoldClassifier
import numpy as np
import pandas as pd
import utils
import xarray as xr
import matplotlib.pyplot as plt
from core import Core, use_dim_type
from component import ComponentGroups
from sklearn.model_selection import train_test_split

In [None]:
# Load both cohorts and the protein-coding gene ids, then stack the cohorts along axis 0.
# Labels: every AD row is 1, every Normal row is 0 (AD rows come first).
data_AD = np.load("data/AD.npy")
data_NL = np.load("data/Normal.npy")
protein_coding_indices = pd.read_csv('data/protein_coding_ID.csv', index_col=0).index.to_numpy()

X = np.concatenate([data_AD, data_NL], axis=0)
y = np.repeat(np.array([1, 0], dtype=np.int8), [len(data_AD), len(data_NL)])
print(X.shape, y.shape)

In [14]:
# Give the person axis explicit integer ids so splits keep track of which persons they contain.
person_ids = np.arange(X.shape[0])
X = xr.DataArray(X, dims=['person', 'gene', 'section'], coords={'person': person_ids})
y = xr.DataArray(y, dims=['person'])

In [15]:
# Fit manifolds on every person once and store them on the ManifoldClassifier class
# (presumably reused by classifier instances to avoid refitting — confirm in classifier.py).
# NOTE(review): X still contains the future test persons here; this is only leak-free if
# fit_manifolds is label-free and per-person — confirm.
ManifoldClassifier.manifolds_cache = (
    Core(data=X.values, dataType="all")
    .fit_manifolds()
)

In [16]:
# scikit-learn scorer names, grouped by metric family (alphabetical overall).
scoring = [
    "accuracy", "average_precision", "balanced_accuracy",
    "f1", "f1_macro", "f1_micro", "f1_weighted",
    "jaccard", "jaccard_macro", "jaccard_micro", "jaccard_weighted",
    "matthews_corrcoef",
    "precision", "precision_macro", "precision_micro", "precision_weighted",
    "recall", "recall_macro", "recall_micro", "recall_weighted",
    "roc_auc",
]

In [17]:
# Subset of the scorer names that matter most when comparing configurations.
important_metrics = ["f1_macro", "jaccard_macro", "recall_macro", "roc_auc"]

In [None]:
# Experiment configuration.
# n_clusters: number of sub-clusters for (class 0 / Normal, class 1 / AD).
# n_genes: top_k genes kept by the analysis step.
# Previously tried (n_clusters, n_genes): ((13, 18), 2), ((2, 7), 19), ((4, 9), 24),
# ((4, 10), 3), ((6, 9), 20), ((2, 9), 2), ((2, 1), 1). Those assignments were all
# overwritten, so only the setting below ever took effect.
n_clusters, n_genes = (3, 13), 14

score_agg_method = "min"
dist_type = "hausdorff"
use_dim: use_dim_type = "1D4D"


def feature_getter(cgs: ComponentGroups) -> np.ndarray:
    """Build the clustering features for a ComponentGroups object.

    Places ``cgs.total_curvatures`` and ``cgs.areas`` side by side along axis 1,
    giving shape (n, 2) when both are 1-D arrays of length n.
    """
    return np.stack([cgs.total_curvatures, cgs.areas], axis=1)

def typical_analyzer(cgs: ComponentGroups, people: np.ndarray, attr_mean: np.ndarray):
    """Return the member of ``people`` closest to the given (curvature, area) target.

    Args:
        cgs: object exposing per-person ``total_curvatures`` and ``areas`` arrays.
        people: indices of the candidate persons.
        attr_mean: two values, (target total curvature, target area).

    Returns:
        The element of ``people`` with the smallest squared Euclidean distance to
        ``attr_mean`` in (total curvature, area) space; the first one wins on ties.

    NOTE(review): both attributes are compared on their raw scales, so if they differ
    by orders of magnitude one of them dominates the distance — confirm this is intended.
    """
    target_curvature, target_area = attr_mean
    curvature_gap = cgs.total_curvatures[people] - target_curvature
    area_gap = cgs.areas[people] - target_area
    squared_distance = curvature_gap ** 2 + area_gap ** 2
    return people[np.argmin(squared_distance)]


# Build the classifier.
# cluster_configs[i] configures the sub-clustering of class i
# (0 = Normal, 1 = AD, matching how the results cell indexes classifier.cores).
classifier = ManifoldClassifier(
    use_dim=use_dim,
    fit_manifold_config=dict(),
    cluster_configs=[
        dict(n_group=n_clusters[0], feature_getter=feature_getter),
        dict(n_group=n_clusters[1], feature_getter=feature_getter),
    ],
    analyze_config=dict(top_k=n_genes,
                        typical_analyzer=typical_analyzer,
                        selectable_indices=protein_coding_indices),
    classify_config=dict(score_agg_method=score_agg_method,
                         dist_type=dist_type),
)

# A fixed seed makes the stratified split, and every metric computed from it,
# reproducible under Restart Kernel -> Run All.
RANDOM_STATE = 42
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, shuffle=True, random_state=RANDOM_STATE
)
classifier.fit(X_train, y_train)

class_names = ['Normal', 'AD']
utils.show_classification_result(y_train, classifier.predict_proba(X_train), names=class_names)
utils.show_classification_result(y_test, classifier.predict_proba(X_test), names=class_names)

In [None]:
# Sanity check: compare the utils binary metrics with hand-computed confusion-matrix formulas.
a = np.array([0, 1, 1, 1])  # ground truth (first argument, y_true in sklearn's convention)
b = np.array([1, 1, 1, 1])  # predictions (second argument, y_pred)

actual_pos = a == 1
predicted_pos = b == 1
TP = (actual_pos & predicted_pos).sum()
TN = (~actual_pos & ~predicted_pos).sum()
FP = (~actual_pos & predicted_pos).sum()
FN = (actual_pos & ~predicted_pos).sum()

print("recall:", utils.recall_score(a, b, average='binary'), TP / (TP + FN))
print("precision:", utils.precision_score(a, b, average='binary'), TP / (TP + FP))
print("f1:", utils.f1_score(a, b, average='binary'), 2 * TP / (2 * TP + FP + FN))
print("accuracy:", utils.accuracy_score(a, b), (TP + TN) / len(a))

# Author's note: the goal is to maximize recall.
# Balanced accuracy (the mean of per-class recalls in sklearn) is shown as the cell output.
utils.balanced_accuracy_score(a, b)

In [20]:
labels = ['Normal', 'AD']

n1 = classifier.cluster_configs[0]['n_group']
n2 = classifier.cluster_configs[1]['n_group']
# Plain numpy copy of the training labels, so boolean masks index numpy arrays directly
# rather than relying on implicit xarray -> numpy conversion.
y_train_np = np.asarray(y_train)
# Every training person must be Normal (0) or AD (1). Otherwise the np.empty label array
# below would keep uninitialised entries.
assert set(np.unique(y_train_np)) <= {0, 1}, "expected binary labels (0=Normal, 1=AD)"

# --- 2-class (Normal vs AD) results on the train and test splits ---
for split_name, X_split, y_split in (('Train', X_train, y_train), ('Test', X_test, y_test)):
    title = f'2 Class Classification {split_name} Results for {labels[0]}/{labels[1]}'
    print(title)
    utils.show_classification_result(y_split, classifier.predict_proba(X_split), normalize='pred', names=labels)
    plt.suptitle(title)
    plt.show()

# --- (n1 + n2)-class results: each Normal and AD sub-cluster treated as its own class ---
title = f'{n1 + n2} Class Classification Results for {labels[1]}/{labels[0]}'
pred_scores = classifier.mixed_core.classify_with_typical(use_dim=use_dim, data=X_train.values, **classifier.classify_config)
assert pred_scores.shape == (len(X_train), n1 + n2)
# Normal sub-clusters keep ids [0, n1); AD sub-clusters are shifted to [n1, n1 + n2).
# A wide integer dtype avoids the int8 wrap-around that occurs once n1 + n2 > 127.
true_label = np.empty(len(X_train), dtype=np.int64)
true_label[y_train_np == 0] = np.array(classifier.cores[0].group_result.person2group)
true_label[y_train_np == 1] = np.array(classifier.cores[1].group_result.person2group) + n1
print(title)
utils.show_classification_result(true_label, pred_scores, normalize=None)
plt.suptitle(title)

plt.figure(figsize=(6, 2))
# minlength keeps one count per sub-cluster even when the highest-numbered ones are empty.
# Without it the AD slice can be shorter than range(n1, n1 + n2), and plt.bar raises.
class_counts = np.bincount(true_label, minlength=n1 + n2)
plt.bar(range(n1), class_counts[:n1], label=labels[0])
plt.bar(range(n1, n1 + n2), class_counts[n1:n1 + n2], color='#896989', label=labels[1])
plt.xlabel('Class')
plt.ylabel('Count')
plt.title('Frequency of Each Class in true_label')
plt.legend()
plt.show()


def _show_group_results(group: int):
    """Show sub-cluster classification results for one class of the training set.

    Args:
        group: 0 for Normal, 1 for AD. Used to index ``labels``,
            ``classifier.cores`` and ``classifier.cluster_configs``.

    Returns:
        Tuple ``(title, pred_scores, true_label)`` for that class.
    """
    n_group = classifier.cluster_configs[group]['n_group']
    core = classifier.cores[group]
    group_title = f'{n_group} Class Classification Results for {labels[group]}s'
    group_scores = core.classify_with_typical(use_dim=use_dim, data=X_train[y_train_np == group].values, **classifier.classify_config)
    group_true = np.array(core.group_result.person2group)
    print(group_title)
    utils.show_classification_result(group_true, group_scores, normalize='pred')
    plt.suptitle(group_title)

    plt.figure(figsize=(3, 2))
    utils.plot_bincount(group_true, title=f'Frequency of Each {labels[group]} Class', normalize=True)
    plt.show()
    return group_title, group_scores, group_true


# Per-class sub-cluster results. After the loop, title/pred_scores/true_label hold the AD-class values.
for group in (0, 1):
    title, pred_scores, true_label = _show_group_results(group)
