In [2]:
import numpy as np
from sklearn.mixture import GaussianMixture
from tqdm import tqdm
from core.Skeleton import Skeleton

from core.MicronsCAVE import CAVE

class AxonModel(Skeleton):
    """``Skeleton`` subclass whose smoothed paths are summarised by a Gaussian mixture.

    On construction the base ``Skeleton`` is built from ``cell_info`` and
    ``syn_group``, the skeleton is smoothed, and a full-covariance
    ``GaussianMixture`` with one component per smoothed path is fitted to the
    node positions. The fitted model is stored in ``self.gmm``.
    """

    @staticmethod
    def create_axon_score_dict(axon_models, branches, collapsed=True):
        """Score each branch's positions under the GMM of every cell in its sequence.

        Parameters
        ----------
        axon_models : mapping
            Indexed by cell id; each value must expose a fitted ``gmm``
            (e.g. an ``AxonModel``).
        branches : iterable
            Objects with ``branch_id``, ``cell_id_sequence`` and
            ``syn_pos_sequence``; both sequences are looked up by the key
            ``'collapsed'`` or ``'raw'``.
        collapsed : bool, default True
            Use the ``'collapsed'`` sequences when True, ``'raw'`` otherwise.

        Returns
        -------
        dict
            ``branch_id -> list``, one entry per cell id in the branch's cell
            sequence (same order). Each entry is ``gmm.score_samples`` evaluated
            on the branch's whole position sequence.
        """
        seq_key = 'collapsed' if collapsed else 'raw'
        score_dict = {}
        for branch in tqdm(branches):
            cell_seq = branch.cell_id_sequence[seq_key]
            pos_seq = branch.syn_pos_sequence[seq_key]
            # NOTE(review): every cell's GMM scores the *entire* pos_seq, not
            # only the position(s) associated with that cell.
            score_dict[branch.branch_id] = [
                axon_models[cell_id].gmm.score_samples(pos_seq) for cell_id in cell_seq
            ]
        return score_dict

    def __init__(self, cell_info, syn_group, syn_k=8, soma_k=8, twig_length=4,
                 single_syn_std=5, random_state=None):
        """Build the skeleton, smooth it and fit ``self.gmm``.

        Parameters
        ----------
        cell_info, syn_group, syn_k, soma_k
            Forwarded unchanged to ``Skeleton.__init__``.
        twig_length : int, default 4
            Forwarded to ``self.smooth`` (called with ``prune_unknown=False``).
        single_syn_std : float, default 5
            Standard deviation used when a path gives no usable spread:
            single-node paths and zero-variance axes get precision
            ``1 / single_syn_std**2``.
        random_state : int, RandomState or None, default None
            Forwarded to ``GaussianMixture`` to make the fit reproducible.
        """
        self.single_syn_std = single_syn_std
        super().__init__(cell_info, syn_group, syn_k, soma_k)
        self.smooth(twig_length, prune_unknown=False)
        self.fit_gmm(random_state=random_state)

    def fit_gmm(self, min_path_length=4, random_state=None):
        """Fit ``self.gmm`` with one full-covariance component per smoothed path.

        Each path from ``get_paths(smoothed=True, duplicate_tail=True)`` seeds one
        component at the mean of its node positions. The initial precision
        depends on the number of nodes in the path:

        * ``>= min_path_length``: inverse sample covariance (``np.cov``, ddof=1).
          If that covariance is singular or its inverse is not finite and
          positive-definite, the diagonal estimate below is used instead.
        * exactly 1: isotropic, ``1 / single_syn_std**2`` on the diagonal.
        * otherwise: diagonal ``1 / np.var`` (ddof=0); zero-variance axes get
          ``1 / single_syn_std**2``.

        The node positions of all paths are then pooled and the mixture is
        fitted with ``GaussianMixture.fit`` (EM).

        Parameters
        ----------
        min_path_length : int, default 4
            Minimum node count for a path to use a full covariance estimate.
        random_state : int, RandomState or None, default None
            Forwarded to ``GaussianMixture``.

        Raises
        ------
        ValueError
            If ``get_paths`` yields no paths.
        """
        paths = self.get_paths(smoothed=True, duplicate_tail=True)
        all_positions = []
        all_means = []
        all_precisions = []
        for path in paths:
            path_positions = np.array([self.smooth_mst.nodes[node]['pos'] for node in path])
            all_positions.append(path_positions)
            all_means.append(np.mean(path_positions, axis=0))
            all_precisions.append(self._path_precision(path_positions, min_path_length))

        if not all_means:
            raise ValueError('AxonModel.fit_gmm: skeleton has no paths to fit')

        gmm = GaussianMixture(n_components=len(all_means), covariance_type='full',
                              means_init=np.array(all_means),
                              precisions_init=np.array(all_precisions),
                              random_state=random_state)
        gmm.fit(np.concatenate(all_positions, axis=0))
        self.gmm = gmm

    def _path_precision(self, path_positions, min_path_length):
        """Initial precision matrix for one path's component (rules in ``fit_gmm``)."""
        if len(path_positions) >= min_path_length:
            precision = self._full_precision(path_positions)
            if precision is not None:
                return precision
        return self._diagonal_precision(path_positions)

    @staticmethod
    def _full_precision(path_positions):
        """Inverse sample covariance, or None if it is not finite and positive-definite."""
        try:
            precision = np.linalg.inv(np.cov(path_positions.T))
            # GaussianMixture rejects precisions_init that are not positive-definite;
            # cholesky raises LinAlgError in that case.
            np.linalg.cholesky(precision)
        except np.linalg.LinAlgError:
            # Singular covariance, e.g. coplanar/collinear path points.
            return None
        return precision if np.all(np.isfinite(precision)) else None

    def _diagonal_precision(self, path_positions):
        """Diagonal precision from per-axis variance, ``1/single_syn_std**2`` where undefined."""
        floor = 1 / (self.single_syn_std ** 2)
        n_dims = path_positions.shape[1]
        if len(path_positions) == 1:
            return np.diag(np.full(n_dims, floor))
        with np.errstate(divide='ignore'):
            inv_var = 1 / np.var(path_positions, axis=0)
        # A zero-variance axis gives inf; use the single-synapse precision instead.
        inv_var[np.isinf(inv_var)] = floor
        return np.diag(inv_var)

In [10]:
import pandas as pd

# pt_root_id values are 18-digit integers (> 2**53). Force int64 so pandas
# fails loudly instead of silently upcasting to float64 and corrupting the ids.
exc_cells = pd.read_csv('data/exc_cells.csv', dtype={'pt_root_id': 'int64'})
client = CAVE()
# Download the output synapses of the cell in row 2000 of exc_cells.
syn_table = client.download_output_synapses(int(exc_cells.loc[2000].pt_root_id), cell_df=exc_cells)
print(len(syn_table))
syn_table.head()

Using deprecated pyarrow serialization method, please upgrade CAVEClient>=5.9.0 with pip install --upgrade caveclient


49


Unnamed: 0,id,pre_pt_root_id,post_pt_root_id,size,cell_type_pre,cell_type_post,ctr_pt_x,ctr_pt_y,ctr_pt_z
0,33919160,864691136296739611,864691135677443844,11640,23P,Unknown,421.296,601.576,797.2
1,40159760,864691136296739611,864691136296739611,184,23P,23P,429.288,563.52,715.56
2,18024958,864691136296739611,864691135117914536,3356,23P,Unknown,360.6,616.12,779.68
3,53728319,864691136296739611,864691136118416408,4492,23P,5P-ET,469.48,623.096,772.28
4,43435552,864691136296739611,864691134744365116,5496,23P,Unknown,450.808,615.152,782.28


In [11]:
# All exc_cells rows that share the root id of row 2000 (the same row whose
# output synapses are downloaded in the previous cell).
target_root_id = exc_cells.loc[2000, 'pt_root_id']
cell_info = exc_cells[exc_cells['pt_root_id'] == target_root_id]

axn = AxonModel(cell_info, syn_table, twig_length=4, single_syn_std=5)

In [12]:
# Fitted mixture: one component per smoothed skeleton path.
fitted_gmm = axn.gmm
fitted_gmm

In [13]:
# The smoothed paths that AxonModel.fit_gmm uses to seed its components.
smoothed_paths = axn.get_paths(smoothed=True, duplicate_tail=True)
smoothed_paths

{(4, 3, 5, 10, 45, 47, 37),
 (16,
  31,
  42,
  28,
  0,
  12,
  7,
  40,
  21,
  22,
  11,
  34,
  29,
  25,
  9,
  13,
  19,
  35,
  15,
  18,
  46),
 (30, 27, 32, 43, 41, 6, 48, 20, 2, 26),
 (30, 38, 24, 14, 17, 1),
 (36, 44, 39, 23, 8, 33, 30)}