# Dendritic Spine Classification

In [None]:
# Root directory of the dataset analysed in this notebook; the metrics, manual
# classification and CSV-export paths below are all built from it.
# NOTE(review): the "0.025 0.025 0.1" prefix looks like a grid/voxel spacing — confirm.
dataset_path = "0.025 0.025 0.1 dataset"

## Load Metrics

In [None]:
from spine_metrics import SpineMetricDataset
from spine_fitter import SpineGrouping


# Merged expert labelling: the reference grouping used as training/evaluation
# targets in the SVM sweep and as the baseline in the expert-accuracy cell.
merged_grouping = SpineGrouping().load(
    f"{dataset_path}/manual_classification/manual_classification_merged_reduced.json"
)

# Every metric, restricted to the spines that appear in the merged grouping.
metrics = SpineMetricDataset().load(f"{dataset_path}/metrics.csv").get_spines_subset(
    merged_grouping.samples
)

# Feature subsets compared by the classifier: the scalar descriptors, and the
# chord-distribution feature on its own.
classic = metrics.get_metrics_subset([
    "OpenAngle", "CVD", "AverageDistance",
    "LengthVolumeRatio", "LengthAreaRatio", "JunctionArea",
    "Length", "Area", "Volume", "ConvexHullVolume", "ConvexHullRatio",
])
chords = metrics.get_metrics_subset(["OldChordDistribution"])

print("Merged expert classification")
display(merged_grouping.show(metrics))

## SVM Classification

In [None]:
from sklearn import svm
from sklearn.model_selection import GridSearchCV
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import rbf_kernel
from sklearn.metrics import pairwise_distances
from sklearn.linear_model import LinearRegression
from ipywidgets import widgets    
import csv
from notebook_widgets import create_dir
    

# Report how many spines are in the merged expert grouping, then the size of each group.
print(len(merged_grouping.samples))
for (key, value) in merged_grouping.groups.items():
    print(f"{key}: {len(value)}")

# Number of leading feature columns treated as histogram bins by histogram_intersection
# and combined_kernel; combined_kernel feeds the remaining columns to an RBF kernel.
# NOTE(review): presumably the bin count of OldChordDistribution — confirm.
histogram_size = 100


def histogram_intersection(histogram_1, histogram_2):
    """Return the histogram-intersection similarity of two histograms.

    This is the sum of the bin-wise minima over the first ``histogram_size`` bins.
    Indexing beyond the end of either histogram raises ``IndexError``.
    """
    total = 0
    for bin_index in range(histogram_size):
        total += min(histogram_1[bin_index], histogram_2[bin_index])
    return total


def histogram_intersection_pairwise(X, Y):
    """Return the matrix of histogram-intersection values between rows of X and rows of Y.

    Entry (i, j) is ``histogram_intersection(X[i], Y[j])``.
    """
    # pairwise_distances accepts any callable metric; here the callable is a
    # similarity rather than a distance, so the result is usable as a kernel matrix.
    similarity = pairwise_distances(X, Y, metric=histogram_intersection)
    return similarity


def combined_kernel(X, Y):
    """Custom SVM kernel: histogram intersection plus RBF.

    The first ``histogram_size`` columns of X and Y are compared with histogram
    intersection. All remaining columns are compared with an RBF kernel. The two
    kernel matrices are summed.
    """
    split = histogram_size
    histogram_part = histogram_intersection_pairwise(X[:, :split], Y[:, :split])
    rbf_part = rbf_kernel(X[:, split:], Y[:, split:])
    return histogram_part + rbf_part


# Accuracy curves keyed by feature-set label. perform_classificaiton stores one list
# of held-out accuracies (one entry per train ratio) per label; export_graphs writes
# them to CSV.
scores = {}


def perform_classificaiton(label, metrics, ratios, kernel="rbf"):
    """Sweep train/test split ratios and plot SVM accuracy for one feature set.

    For each ratio, a training subset is drawn with
    ``merged_grouping.get_balanced_subset(ratio)``. An ``SVC`` is tuned over ``C``
    by 4-fold ``GridSearchCV`` and scored on the remaining labelled spines. The
    resulting accuracy curve is drawn on the current matplotlib axes, printed as a
    maximum, and stored in the module-level ``scores`` dict under ``label``.

    Args:
        label: Name of the feature set; the legend entry and the key in ``scores``.
        metrics: Metric dataset whose features are classified.
        ratios: Non-empty sequence of training fractions to evaluate.
        kernel: Kernel passed to ``sklearn.svm.SVC``.

    Returns:
        The predicted grouping of all labelled spines (training spines included)
        from the ratio that reached the highest held-out accuracy.

    Raises:
        ValueError: If ``ratios`` is empty.
    """
    if len(ratios) == 0:
        raise ValueError("ratios must contain at least one train ratio")

    # Keep only spines present in the merged expert grouping.
    metrics = metrics.get_spines_subset(merged_grouping.samples)

    output = []
    pred_groupings = []
    for train_ratio in ratios:
        train_grouping = merged_grouping.get_balanced_subset(train_ratio)
        train_metrics = metrics.get_spines_subset(train_grouping.samples)
        train_names = train_metrics.ordered_spine_names

        # Held-out set: every labelled spine that was not drawn for training.
        test_grouping = merged_grouping.get_spines_subset(
            merged_grouping.samples.difference(train_grouping.samples))
        test_metrics = metrics.get_spines_subset(test_grouping.samples)
        test_names = test_metrics.ordered_spine_names

        # NOTE(review): assumes as_array() rows follow ordered_spine_names — confirm.
        train_target = [train_grouping.get_group(name) for name in train_names]
        test_target = [test_grouping.get_group(name) for name in test_names]

        clf = GridSearchCV(svm.SVC(kernel=kernel), {'C': [1, 10]}, cv=4)
        clf.fit(train_metrics.as_array(), train_target)
        output.append(clf.score(test_metrics.as_array(), test_target))

        # Predict every labelled spine (not just the test ones) for the side-by-side display.
        pred_labels = clf.predict(metrics.as_array())
        pred_groups = {group: set() for group in ["Thin", "Mushroom", "Stubby"]}
        for name, pred_label in zip(metrics.ordered_spine_names, pred_labels):
            pred_groups[pred_label].add(name)
        pred_groupings.append(SpineGrouping(metrics.spine_names, pred_groups))

    plt.plot(ratios, output, label=label)
    print(f"{label} max accuracy: {max(output)}")

    scores[label] = output

    return pred_groupings[int(np.argmax(output))]

# Training fractions to sweep: 0.1, 0.2, ..., 0.9 of the labelled spines.
ratios = np.linspace(0.1, 0.9, 9)

res_groupings = {}

# Feature sets to compare. Each key is the plot legend entry and the key used in
# `scores` and `res_groupings`.
input_metrics = {"Classic": classic, "Combined": metrics, "Chord": chords}

for label, metric_set in input_metrics.items():
    res_groupings[label] = perform_classificaiton(label, metric_set, ratios)

plt.gca().set(xlabel="train / (train + test)", ylabel="mean accuracy")
plt.legend()

plt.show()


def export_graphs(_):
    """Write the accuracy curves collected in ``scores`` to a CSV file.

    Used as an ipywidgets button callback; the single positional argument (the
    clicked button) is ignored. The file is written to
    ``output/<dataset_path>_classification_accuracy.csv``. Its first row holds the
    train ratios; each following row holds a feature-set label and its accuracies.
    """
    create_dir("output")
    filename = f"output/{dataset_path}_classification_accuracy.csv"
    # The csv module must control line endings itself; without newline="" every
    # row is followed by a blank line on Windows.
    with open(filename, mode="w", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)
        writer.writerow(["train_ratio", *ratios])
        for label, score in scores.items():
            # Unpacking works whether score is a list, tuple or array.
            writer.writerow([label, *score])
    print(f"Saved classification accuracy graphs to '{filename}'.")


# Button that writes the accuracy curves in `scores` to CSV via export_graphs.
export_button = widgets.Button(description="Export to .csv")
export_button.on_click(export_graphs)
display(export_button)


# For each feature set, show the manual (merged expert) grouping next to the
# best-accuracy SVM prediction returned by perform_classificaiton. Both are rendered
# with that feature set's metrics.
for label, res_grouping in res_groupings.items():
    metric_set = input_metrics[label]
    display(widgets.HBox([widgets.VBox([widgets.Label("Manual classification:"), merged_grouping.show(metric_set)]),
                          widgets.VBox([widgets.Label(f"{label} metrics SVM classification:"), res_grouping.show(metric_set)])]))

## Two Datasets: Train on One, Evaluate on the Other

In [None]:
from sklearn import svm
from sklearn.model_selection import GridSearchCV
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import rbf_kernel
from sklearn.metrics import pairwise_distances
from sklearn.linear_model import LinearRegression
from typing import Tuple    
from spine_metrics import SpineMetricDataset
from spine_fitter import SpineGrouping
from ipywidgets import widgets

    
def load_dataset(dataset_path) -> Tuple[SpineGrouping, SpineMetricDataset, SpineMetricDataset, SpineMetricDataset]:
    """Load one dataset directory.

    Returns:
        A 4-tuple ``(grouping, metrics, classic, chords)``:
        - the merged expert grouping;
        - all metrics restricted to that grouping's spines;
        - the classic scalar-metric subset;
        - the ``OldChordDistribution`` subset.
    """
    classic_names = [
        "OpenAngle", "CVD", "AverageDistance",
        "LengthVolumeRatio", "LengthAreaRatio", "JunctionArea",
        "Length", "Area", "Volume", "ConvexHullVolume", "ConvexHullRatio",
    ]

    grouping_file = f"{dataset_path}/manual_classification/manual_classification_merged_reduced.json"
    grouping = SpineGrouping().load(grouping_file)

    all_metrics = SpineMetricDataset().load(f"{dataset_path}/metrics.csv").get_spines_subset(grouping.samples)

    return (
        grouping,
        all_metrics,
        all_metrics.get_metrics_subset(classic_names),
        all_metrics.get_metrics_subset(["OldChordDistribution"]),
    )


# Each entry is the tuple returned by load_dataset:
# (merged grouping, combined metrics, classic metrics, chord metrics).
train = load_dataset("train 0.025 0.025 0.1 dataset")
test = load_dataset("test 0.025 0.025 0.1 dataset")
# NOTE(review): `union` is not read anywhere in this cell — confirm it is still needed.
union = load_dataset("0.025 0.025 0.1 dataset")

# NOTE(review): roles are swapped here. The classifiers below are trained on the
# "test ..." directory and evaluated on the "train ..." directory. This looks like a
# deliberate experiment; confirm before reporting the printed accuracies.
train, test = test, train


# Re-declares the value from the earlier cell; do_classification below does not read it.
histogram_size = 100

def do_classification(label, train_grouping, train_metrics, test_grouping, test_metrics, kernel="rbf"):
    """Train an SVM on one dataset, evaluate it on another, and display the result.

    An ``SVC`` is tuned over ``C`` by 4-fold ``GridSearchCV`` on the training
    dataset. Its accuracy on the test dataset is printed. The test dataset's
    expert grouping is then displayed next to the predicted grouping.

    Args:
        label: Feature-set name used in the printed accuracy line.
        train_grouping: Expert grouping supplying the training targets.
        train_metrics: Features of the training spines.
        test_grouping: Expert grouping supplying the evaluation targets.
        test_metrics: Features of the test spines.
        kernel: Kernel passed to ``sklearn.svm.SVC``.

    Returns:
        The predicted grouping of the test spines.
    """
    train_names = train_metrics.ordered_spine_names
    test_names = test_metrics.ordered_spine_names

    # NOTE(review): assumes as_array() rows follow ordered_spine_names — confirm.
    train_target = [train_grouping.get_group(name) for name in train_names]
    test_target = [test_grouping.get_group(name) for name in test_names]

    train_data = train_metrics.as_array()
    test_data = test_metrics.as_array()

    clf = GridSearchCV(svm.SVC(kernel=kernel), {'C': [1, 10]}, cv=4)
    clf.fit(train_data, train_target)
    print(f"{label} metrics SVM accuracy = {clf.score(test_data, test_target)}")

    # Reuse the test feature matrix rather than rebuilding it.
    pred_labels = clf.predict(test_data)
    pred_groups = {group: set() for group in ["Thin", "Mushroom", "Stubby"]}
    for name, pred_label in zip(test_names, pred_labels):
        pred_groups[pred_label].add(name)
    pred_grouping = SpineGrouping(test_metrics.spine_names, pred_groups)
    display(widgets.HBox([test_grouping.show(test_metrics), pred_grouping.show(test_metrics)]))

    # Callers assign the result, so return the prediction instead of an implicit None.
    return pred_grouping

# Indices into the load_dataset tuples: [0] grouping, [1] combined metrics,
# [2] classic metrics, [3] chord metrics. Each call trains on `train`,
# evaluates on `test`, and holds whatever do_classification returns.
classic_grouping = do_classification("Classic", train[0], train[2], test[0], test[2])
combined_grouping = do_classification("Combined", train[0], train[1], test[0], test[1])
chord_grouping_rbf = do_classification("Chord", train[0], train[3], test[0], test[3])

## Expert Accuracy

In [None]:
from pathlib import Path

# Each expert's individual labelling: manual_classification_<one character>.json.
# Path.glob yields files in arbitrary (file-system) order, so sort the paths to make
# the printed index follow the file names.
path = Path(f"{dataset_path}/manual_classification")
classifications_paths = sorted(path.glob("manual_classification_?.json"))
groupings = [SpineGrouping().load(str(classification_path)) for classification_path in classifications_paths]

# Agreement of each individual expert with the merged grouping, computed once per expert.
accuracies = [SpineGrouping.accuracy(merged_grouping, grouping) for grouping in groupings]
for i, accuracy in enumerate(accuracies):
    print(f"{i + 1}: {accuracy:.2f}")
if accuracies:
    print(f"Mean: {np.mean(accuracies):.2f}")
else:
    # np.mean([]) would print "nan" with a RuntimeWarning; say what is missing instead.
    print(f"No manual_classification_?.json files found in '{path}'.")