# AM216 Final Project
## Zachary Miller

### Abstract
Neuron classification based on morphology is an extremely important task in neuroinformatics. Such classifications can provide insight into data collected with methods that do not directly capture the property the classifier is meant to estimate. For example, neurons reconstructed from electron microscopy data come with incredibly detailed morphological information, but lack information about neurotransmitter type or functional activity. A classifier could learn subtle relationships between morphology and these other features, and thus fill in the missing information after the fact. Here, I investigate the potential of graph neural networks for classifying neurons encoded as attributed graphs.

### Project Motivation
The goal of this project was to investigate the efficacy of representing neuron morphologies as attributed graphs for classification tasks. The standard format for working with neuron morphologies is the .swc format, which stores a neuron as a set of spatial points ("nodes") connected by a series of edges. Though very convenient for visualization and qualitative tasks, .swc files are difficult to use for classification, since there are currently no classification models that can work directly with them. The traditional approach to this problem has been to pick several features of interest, extract them into a feature vector for each neuron, and perform classification on these feature vectors alone. However, this approach has seen limited success for classification tasks in the past and has several drawbacks. Most notably, it requires a human to hand-pick features that might be relevant to the classification task. Such features are rarely known a priori and will change depending on the specifics of the task. Furthermore, the features that would lead to the best classification results will not necessarily come in a form that is intuitive to humans, and therefore may never even be considered for the feature vectors. It is therefore of great interest to find a more native data representation for neuron morphologies that would allow a model to learn from the data directly.

One such representation could be an attributed graph (AG), that is, a graph in which each node has an associated vector of data, called an "attribute". In fact, the .swc file format is already essentially an AG, so converting a .swc file to an AG is very easy. Combined with recent advances in graph neural networks (GNNs), attributed graphs may offer a far better way of encoding neuron morphology data for classification tasks, especially when a lot of data is available for training the GNN. This project is an initial investigation of the ability of GNNs to classify neuron morphologies based on AG representations.
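As a rough illustration of how direct that conversion is, the sketch below parses SWC text into a `networkx` graph. It assumes the standard seven-column SWC layout (id, type, x, y, z, radius, parent id, with -1 marking the root) and stores `diam` as twice the SWC radius so that the attribute names match those read during preprocessing. The actual conversion in this project was done with natverse in R, not with this code.

```python
import networkx as nx


def swc_to_graph(swc_text):
    """Parse SWC text into an undirected networkx graph.

    Each SWC row is: id, type, x, y, z, radius, parent_id (parent_id -1 marks
    the root). Every node gets X, Y, Z and diam (2 * radius) attributes, and
    every non-root node gets an edge to its parent.
    """
    graph = nx.Graph()
    parents = {}
    for line in swc_text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and header comments
        fields = line.split()
        node_id = int(fields[0])
        x, y, z, radius = (float(v) for v in fields[2:6])
        graph.add_node(node_id, X=x, Y=y, Z=z, diam=2.0 * radius)
        parents[node_id] = int(fields[6])
    # Add parent edges once all nodes exist, in case a child precedes its parent
    for child, parent in parents.items():
        if parent != -1:
            graph.add_edge(child, parent)
    return graph


# A toy three-point neuron: a root (soma) followed by a two-segment branch
example = """# toy neuron
1 1 0.0 0.0 0.0 2.0 -1
2 3 1.0 0.0 0.0 0.5 1
3 3 2.0 1.0 0.0 0.5 2
"""
g = swc_to_graph(example)
print(g.number_of_nodes(), g.number_of_edges())  # 3 2
print(g.nodes[1]["diam"])  # 4.0
```

The SWC type column (soma, axon, dendrite, ...) is ignored here; it could be appended to the attribute vector if it were useful for a given task.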

### Data and Methods
For this project, I started with a data set containing 3,119 unlabeled, spatially registered .swc files representing neurons in the larval zebrafish. Using the natverse suite of R libraries for working with neuron morphologies (see the accompanying R script), these neurons were labeled into 10 classes and converted into AGs. The labels were generated using the NBLAST algorithm and hierarchical clustering as implemented in `nat.nblast`. This approach has been used with great success to cluster neurons into cell types in numerous previous studies, and I treat the 10 clusters identified with NBLAST as ground truth. It is worth noting, however, that these 10 clusters do not correspond to any established, morphologically meaningful division of these cells (though cells within a cluster are more "similar" to each other than cells across clusters). This project was simply an exploratory analysis to see whether the GCN could learn to classify cells using non-arbitrary labels. The resulting graphs and their labels were then read into Python, where I added a self edge to each node and converted each graph into the format accepted by the Deep Graph Library (`dgl`), discussed further below.
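The labeling itself was done in R with `nat.nblast`. For readers working in Python, a rough analogue of the clustering step is sketched below with SciPy. It assumes a precomputed square matrix of normalized NBLAST scores (self-scores of 1), averages the two scoring directions, uses one minus the averaged score as a distance, and applies Ward linkage; these are assumptions for illustration, not a description of what `nat.nblast` does internally.

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.spatial.distance import squareform


def cluster_from_scores(scores, n_clusters):
    """Cut a hierarchical clustering of a normalized similarity matrix into
    at most n_clusters flat clusters (labels are 1..n_clusters)."""
    scores = np.asarray(scores, dtype=float)
    # Average the two query/target directions, then convert to distances
    dist = 1.0 - (scores + scores.T) / 2.0
    np.fill_diagonal(dist, 0.0)
    # Ward linkage on the condensed distance vector (an assumed choice)
    tree = linkage(squareform(dist, checks=False), method="ward")
    return fcluster(tree, t=n_clusters, criterion="maxclust")


# Toy scores for four neurons forming two obvious groups, {0, 1} and {2, 3}
toy_scores = np.array([[1.0, 0.9, 0.1, 0.1],
                       [0.9, 1.0, 0.1, 0.1],
                       [0.1, 0.1, 1.0, 0.8],
                       [0.1, 0.1, 0.8, 1.0]])
labels = cluster_from_scores(toy_scores, n_clusters=2)
```

On a full score matrix, calling `cluster_from_scores(score_matrix, n_clusters=10)` would give labels analogous to the 10 classes used here.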

```python
import os
import random

# Tell DGL to use its TensorFlow backend; this must happen before importing dgl
os.environ.setdefault("DGLBACKEND", "tensorflow")

import numpy as np
import pandas as pd
import networkx as nx
import tensorflow as tf
import dgl
from sklearn.preprocessing import OneHotEncoder

from classifier import GCN

# Set paths to data and labels
data_path = "/home/zack/Desktop/Lab_Work/Data/neuron_morphologies/Zebrafish/aligned_040120/Zbrain_neurons_graphs"
lbl_path = "/home/zack/Desktop/Lab_Work/Data/neuron_morphologies/Zebrafish/aligned_040120/test_NBLAST_labels.csv"

# Read in the data
graph_list = []
lbl_list = []
name_list = []

# Read in the labels and remove the file extension from the names
lbls_df = pd.read_csv(lbl_path, index_col=0)
lbls_df.index = [os.path.splitext(name)[0] for name in lbls_df.index]

# Convert every graph file in the data directory
for filename in os.listdir(data_path):
    file_path = os.path.join(data_path, filename)
    if os.path.isdir(file_path):
        continue

    # Load the graph and relabel its nodes as integers 0..n-1, so that row i of
    # node_atbs describes node i of the DGL graph (dgl.from_networkx renumbers
    # non-integer node labels itself, possibly in a different order)
    nx_graph = nx.read_gml(file_path)
    nx_graph = nx.convert_node_labels_to_integers(nx_graph)

    # Aggregate the node attributes (X, Y, Z, diameter) into one numpy array
    node_atbs = np.zeros((nx_graph.number_of_nodes(), 4))
    for node_id, attrs in nx_graph.nodes(data=True):
        node_atbs[node_id] = [attrs['X'], attrs['Y'], attrs['Z'], attrs['diam']]

    # Add "self" edges so each node is included in its own convolution
    nx_graph.add_edges_from((node, node) for node in list(nx_graph.nodes()))

    # Create the DGL graph (dgl >= 0.5 API) and attach the node attributes
    dgl_graph = dgl.from_networkx(nx_graph)
    dgl_graph.ndata['data'] = tf.convert_to_tensor(node_atbs, dtype=tf.float32)

    # Add all the elements to lists
    graph_list.append(dgl_graph)
    lbl_list.append(lbls_df.loc[filename, "nblast_cluster"])
    name_list.append(filename)
```

For the actual classification of the cells, I chose to use a graph convolutional neural network (GCN) implemented with the Deep Graph Library (`dgl`) and `TensorFlow` (`tf`). I chose the GCN because of its simplicity and its track record in graph classification, particularly with molecules encoded as graphs. The idea behind a GCN is similar to that of a traditional convolutional neural network for image classification. However, rather than convolving a function over groups of nearby pixels in an image, the GCN convolves over the features of "nearby" nodes in the graph. Repeating this operation several times in series can allow the network to learn detailed combinations of local features that map onto the classes it is trying to learn. If we denote a graph by $G$ and the node attribute vectors after layer $l$ by $h^{(l)}$, then a single graph convolutional layer learns a function $h^{(l+1)}=f(h^{(l)},G)$. There are many different ways to formulate $f(\cdot)$ so that it can be tuned during training. I chose to use dgl's base implementation of $f(\cdot)$, which updates each node $i$ as $h_i^{(l+1)} = \sigma\left(b^{(l)} + \sum_{j\in\mathcal{N}(i)}\frac{1}{c_{ij}}h_j^{(l)}W^{(l)}\right)$. Here, $\mathcal{N}(i)$ is the neighborhood set of node $i$ (which includes $i$ itself, because of the self edges added during preprocessing), $c_{ij}$ is a normalization constant given by $c_{ij} = \sqrt{|\mathcal{N}(i)|}\sqrt{|\mathcal{N}(j)|}$, $\sigma$ is an activation function, $b^{(l)}$ is a learnable bias, and $W^{(l)}$ is a learnable weight matrix. When a graph is passed through a layer, this update is applied to every node in the graph to form $h^{(l+1)}$. Consequently, $N$ consecutive convolution layers allow each node's final representation to draw on information from all nodes up to $N$ hops away.
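A small NumPy sketch of the layer formula above may make it more concrete. It assumes a symmetric 0/1 adjacency matrix with the self edges already on the diagonal, so that $|\mathcal{N}(i)|$ is simply the row sum, and it uses a leaky ReLU with an arbitrarily chosen slope for $\sigma$. It checks the per-node sum against the equivalent matrix form $\sigma(b + D^{-1/2} A D^{-1/2} h W)$, then uses powers of $A$ to show which nodes can influence node 0 after $k$ layers. This is illustrative code, not dgl's implementation.

```python
import numpy as np


def leaky_relu(x, alpha=0.01):
    # alpha (the negative slope) is an arbitrary illustrative value
    return np.where(x > 0, x, alpha * x)


def gcn_layer_loop(adj, h, W, b):
    """Per-node form: h_i' = sigma(b + sum_{j in N(i)} h_j W / c_ij),
    with c_ij = sqrt(|N(i)|) * sqrt(|N(j)|)."""
    deg = adj.sum(axis=1)  # |N(i)|, counting the self edge
    out = np.zeros((h.shape[0], W.shape[1]))
    for i in range(h.shape[0]):
        for j in np.nonzero(adj[i])[0]:
            out[i] += (h[j] @ W) / (np.sqrt(deg[i]) * np.sqrt(deg[j]))
    return leaky_relu(out + b)


def gcn_layer_matrix(adj, h, W, b):
    """Matrix form of the same layer: sigma(b + D^{-1/2} A D^{-1/2} h W)."""
    d_inv_sqrt = 1.0 / np.sqrt(adj.sum(axis=1))
    a_norm = d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    return leaky_relu(a_norm @ h @ W + b)


# Toy 4-node chain 0-1-2-3 with self edges on the diagonal
adj = np.eye(4)
for i, j in [(0, 1), (1, 2), (2, 3)]:
    adj[i, j] = adj[j, i] = 1.0

rng = np.random.default_rng(0)
h = rng.normal(size=(4, 4))   # 4 features per node, like (X, Y, Z, diam)
W = rng.normal(size=(4, 32))
b = np.zeros(32)
print(np.allclose(gcn_layer_loop(adj, h, W, b), gcn_layer_matrix(adj, h, W, b)))  # True

# Nodes that can influence node 0 after k layers: nonzero entries of row 0 of A^k
reach = [np.nonzero(np.linalg.matrix_power(adj, k)[0])[0].tolist() for k in (1, 2, 3)]
print(reach)  # [[0, 1], [0, 1, 2], [0, 1, 2, 3]]
```

The matrix form is how such layers are usually computed in practice, since it turns the per-node loop into a sparse matrix product.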

For my network, I used 3 consecutive graph convolution layers as described above. For the activation function, I used a leaky ReLU (a traditional ReLU was leading to many "dead" units). These layers were followed by averaging the final node attribute vectors $h$ across all nodes in the graph. The resulting vector was then fed into a single dense layer of 10 units (one for each class). This network architecture is implemented in `classifier.py`. I then created a 70/15/15 train/validation/test split and trained the network using backpropagation, with categorical cross-entropy as the loss, the Adam optimizer (learning rate 0.001), and early stopping on the validation loss.
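The steps after the convolutions can be sketched in NumPy: a mean readout over nodes, a dense layer producing one logit per class, and the categorical cross-entropy that `CategoricalCrossentropy(from_logits=True)` computes for a single graph. This illustrates the architecture described above and is not the contents of `classifier.py`; the 32-feature width is assumed from the `GCN(4, 32, 10)` call in the training code, and the random weights are placeholders.

```python
import numpy as np


def readout_and_classify(h_final, W_dense, b_dense):
    """Average the node vectors into one graph vector, then apply a dense
    layer that outputs one logit per class."""
    graph_vec = h_final.mean(axis=0)       # (hidden,)
    return graph_vec @ W_dense + b_dense   # (n_classes,)


def crossentropy_from_logits(one_hot, logits):
    """-sum_k y_k * log(softmax(logits)_k), using a stable log-sum-exp."""
    shifted = logits - logits.max()
    log_softmax = shifted - np.log(np.exp(shifted).sum())
    return -np.sum(one_hot * log_softmax)


rng = np.random.default_rng(1)
h_final = rng.normal(size=(50, 32))   # 50 nodes x 32 hidden features (placeholder)
W_dense = rng.normal(size=(32, 10))   # placeholder weights for 10 classes
b_dense = np.zeros(10)

logits = readout_and_classify(h_final, W_dense, b_dense)
y = np.eye(10)[3]                     # one-hot label for class 3
loss = crossentropy_from_logits(y, logits)

# The mean readout ignores node order: shuffling the nodes gives the same logits
perm = rng.permutation(50)
print(np.allclose(logits, readout_and_classify(h_final[perm], W_dense, b_dense)))  # True

# With uninformative (all-equal) logits the loss is log(10), about 2.3026
print(crossentropy_from_logits(y, np.zeros(10)))
```

The permutation check is the main reason a mean (or any symmetric) readout is used: the graph vector must not depend on how the nodes of a neuron happen to be numbered, and every neuron has a different number of nodes.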

```python
# One-hot encode the labels (scikit-learn >= 1.2 calls this argument
# `sparse_output`; older versions call it `sparse`)
lbl_arr = np.asarray(lbl_list)[:, np.newaxis]
enc = OneHotEncoder(sparse_output=False)
lbl_arr = enc.fit_transform(lbl_arr)

# Shuffle, then make a 70/15/15 train/validation/test split
combined_list = list(zip(name_list, graph_list, lbl_arr))
random.shuffle(combined_list)
split_1 = int(0.7*len(combined_list))
split_2 = int(0.15*len(combined_list))
train_data = combined_list[:split_1]
val_data = combined_list[split_1:split_1+split_2]
test_data = combined_list[split_1+split_2:]

# Build the model: 4 node features in (X, Y, Z, diam), 10 classes out
model = GCN(4, 32, 10)
loss_func = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
EPOCHS = 20
PATIENCE = 2

train_loss_list = []
val_loss_list = []

print("Training...")
for epoch in range(EPOCHS):
    epoch_train_loss = 0.0
    epoch_val_loss = 0.0

    # One gradient step per graph (batch size 1)
    for (name, graph, lbl) in train_data:
        lbl = lbl.reshape(1, 10)
        with tf.GradientTape() as tape:
            prediction = model(graph)
            loss = loss_func(lbl, prediction)
        # Compute and apply the gradients outside the tape context
        grads = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        epoch_train_loss += float(loss)

    print('Epoch {}, total training loss {:.4f}'.format(epoch, epoch_train_loss))
    train_loss_list.append(epoch_train_loss)

    print("Calculating validation loss...")
    for (name, graph, lbl) in val_data:
        lbl = lbl.reshape(1, 10)
        prediction = model(graph)
        epoch_val_loss += float(loss_func(lbl, prediction))

    print('Epoch {}, total validation loss {:.4f}'.format(epoch, epoch_val_loss))
    val_loss_list.append(epoch_val_loss)

    # Early stopping: stop once the validation loss is higher than in each of
    # the previous PATIENCE epochs (the window excludes the current epoch)
    if epoch >= PATIENCE:
        if max(val_loss_list[epoch-PATIENCE:epoch]) < epoch_val_loss:
            print('Training stopped on epoch {}'.format(epoch))
            break

    # Can't use standard ways of saving a tf model since this is a custom model.
    # Normally this means you have to set the input shape manually and then you
    # can save, but in this case the input shape is variable. I will have to figure
    # this out in the future...

    # model_path = '/home/zack/Documents/AM216/final_project/best_model/1/'
    # tf.saved_model.save(model, model_path)

# Test the model
true_lbls = np.asarray([element[2] for element in test_data])
model_predictions = np.zeros((len(test_data), 10))

test_loss = 0.0
for i, (name, graph, lbl) in enumerate(test_data):
    lbl = lbl.reshape(1, 10)
    prediction = model(graph)
    test_loss += float(loss_func(lbl, prediction))
    # One-hot encode the predicted class
    model_predictions[i, np.argmax(prediction)] = 1

# A prediction is correct when its class matches the true label's class
acc = np.mean(np.argmax(true_lbls, axis=1) == np.argmax(model_predictions, axis=1))

print("Test loss was {:.4f} for {:.2f} percent accuracy".format(test_loss, 100*acc))
```

```
Training...
Epoch 0, total training loss 9919.8867
Calculating validation loss...
Epoch 0, total validation loss 906.4484
Epoch 1, total training loss 4162.9302
Calculating validation loss...
Epoch 1, total validation loss 794.7467
Epoch 2, total training loss 3265.9395
Calculating validation loss...
Epoch 2, total validation loss 765.2902
Epoch 3, total training loss 2913.6863
Calculating validation loss...
Epoch 3, total validation loss 677.8873
Epoch 4, total training loss 2729.7449
Calculating validation loss...
Epoch 4, total validation loss 535.4236
Epoch 5, total training loss 2586.6021
Calculating validation loss...
Epoch 5, total validation loss 510.5650
Epoch 6, total training loss 2477.9734
Calculating validation loss...
Epoch 6, total validation loss 488.9573
Epoch 7, total training loss 2400.1892
Calculating validation loss...
Epoch 7, total validation loss 489.0369
Epoch 8, total training loss 2336.7012
Calculating validation loss...
Epoch 8, total validation loss 482.826
```
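The early-stopping rule can be checked by hand against the validation losses logged above. The sketch reads the rule as "stop once the current validation loss is higher than each of the previous `PATIENCE` = 2 epochs' losses", which is an interpretation of the loop's intent. Under that reading, epoch 7's small uptick (489.0369 versus 488.9573) does not trigger a stop, because epoch 5's loss (510.5650) was higher still, so no stop occurs through epoch 8.

```python
def should_stop(val_losses, patience):
    """True if the latest loss is higher than each of the previous `patience` losses."""
    if len(val_losses) <= patience:
        return False
    return max(val_losses[-patience - 1:-1]) < val_losses[-1]


# Total validation losses for epochs 0-8, copied from the log above
logged = [906.4484, 794.7467, 765.2902, 677.8873, 535.4236,
          510.5650, 488.9573, 489.0369, 482.826]
stops = [epoch for epoch in range(len(logged)) if should_stop(logged[:epoch + 1], 2)]
print(stops)  # []

# If the comparison window also includes the current epoch, the condition can
# never hold, because the maximum already includes the latest loss itself
print(any(max(logged[e - 2:e + 1]) < logged[e] for e in range(2, len(logged))))  # False
```

A more common variant stops when the best validation loss has not improved for `PATIENCE` epochs; on these numbers it would not have stopped either, since epoch 8 set a new minimum.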

### Results and Discussion
72.92 percent accuracy across 10 classes is very promising given the simplicity of the model outlined above, especially considering that previous cell type classification results tend to hover around 75-90% accuracy. Two things in particular could be greatly improved in the future. First, the network architecture could be much more sophisticated. This includes adding more layers to the model so that it considers higher order neighbors of each node, using more sophisticated convolution functions, and adding more dense layers to the network. Second, the step that averages the node vectors $h$ into a single fixed-length representation of the graph could be replaced with a richer pooling operation, though this step must remain differentiable to allow training with backpropagation. Together, these adjustments could greatly improve upon the accuracy of the simple model used here.
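One differentiable alternative to the mean readout is a softmax-weighted ("attention") readout, in which a gate vector scores each node and the graph vector is the weighted average of the node vectors. The NumPy sketch below is a swapped-in technique for illustration, not something tried in this project; `w_gate` stands in for what would be a learnable parameter in a real model. With a zero gate it reduces exactly to mean pooling, so it generalizes the readout used here rather than replacing it.

```python
import numpy as np


def mean_readout(h):
    """The readout used in this project: a plain average over nodes."""
    return h.mean(axis=0)


def attention_readout(h, w_gate):
    """Softmax-weighted average of node vectors. The per-node weights come from
    a gate vector (learnable in a real model), so the readout stays
    differentiable while letting some nodes count more than others."""
    scores = h @ w_gate              # one score per node
    scores = scores - scores.max()   # subtract the max for numerical stability
    weights = np.exp(scores) / np.exp(scores).sum()
    return weights @ h               # (hidden,)


rng = np.random.default_rng(2)
h = rng.normal(size=(20, 32))        # 20 nodes x 32 features (placeholder)

# With a zero gate every node gets weight 1/20, recovering the mean readout
print(np.allclose(attention_readout(h, np.zeros(32)), mean_readout(h)))  # True

# With a nonzero gate the result is still a weighted average of the nodes
pooled = attention_readout(h, rng.normal(size=32))
```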

One point of skepticism I have, however, concerns the difficulty of the classification task at hand. Clustering with NBLAST is known to weight the spatial location of neurons heavily compared to more local geometric features. Therefore, when clustering 3,119 neurons spanning the entire brain of the larval zebrafish, it is possible that NBLAST simply divided the neurons into 10 groups based on their soma location, or some other purely spatial feature. It could be that the convolutional layers of the GCN above are doing nothing more than calculating an approximate average of the node features, three of the four of which are spatial coordinates, and that the dense layer is doing all the heavy lifting for classification. If that is the case, the GCN might generalize well to neuron classification tasks that are spatially separable but perform poorly on tasks in which the classes are spatially similar but geometrically distinct. To test whether a GCN can handle the latter situation, it should be trained and tested on a dataset in which the neurons all lie in the same region of the brain and differ only in branching patterns or other geometric features. It is worth noting, however, that an earlier version of the classifier with only a single convolutional layer showed far worse accuracy, which may indicate that the convolutional layers are doing something useful beyond simply averaging the spatial locations of the nodes.
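The spatial hypothesis could also be probed with a baseline that sees only each neuron's mean node attributes (essentially its centroid plus mean diameter). If such a baseline approaches the GCN's accuracy on the NBLAST labels, the labels are largely spatially separable. The sketch below uses scikit-learn and assumes each neuron's (num_nodes × 4) attribute array, as built during preprocessing, has been kept in a list; the demo runs on synthetic, well-separated data only to show the mechanics and says nothing about the zebrafish dataset.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split


def mean_node_features(node_atbs_list):
    """One row per neuron: the mean of each node attribute (X, Y, Z, diam)."""
    return np.stack([np.asarray(atbs).mean(axis=0) for atbs in node_atbs_list])


# Synthetic stand-in (illustrative only, not project data): three groups of
# "neurons" whose nodes are scattered around different positions in space
rng = np.random.default_rng(0)
centers = np.array([[0.0, 0.0, 0.0, 1.0],
                    [50.0, 0.0, 0.0, 1.0],
                    [0.0, 50.0, 0.0, 1.0]])
neurons, labels = [], []
for label, center in enumerate(centers):
    for _ in range(40):
        n_nodes = rng.integers(20, 60)
        neurons.append(center + rng.normal(scale=[5.0, 5.0, 5.0, 0.1], size=(n_nodes, 4)))
        labels.append(label)

X = mean_node_features(neurons)
X_train, X_test, y_train, y_test = train_test_split(
    X, labels, test_size=0.3, random_state=0, stratify=labels)
baseline = LogisticRegression(max_iter=1000).fit(X_train, y_train)
print("Centroid-only baseline accuracy:", baseline.score(X_test, y_test))
```

On real data, the same split used for the GCN should be reused for the baseline so that the two accuracies are directly comparable.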

### Conclusion
The results obtained above are promising, especially within the context of neuron classification tasks. However, they were obtained using data with synthetic labels, and it is unclear how well they would generalize to real-world classification tasks. Additionally, the method used here may struggle with classification tasks in which the classes can only be differentiated geometrically. Still, as preliminary results on the efficacy of GNNs for classifying neuron morphologies represented as AGs, they seem to warrant further investigation. A natural first step would be to compare the GCN's accuracy on this dataset, with the same synthetic labels, against that of more traditional feature-vector-based classification methods.