In [24]:
import warnings
from pathlib import Path
from time import sleep

import numpy as np
import pandas as pd
from caveclient import CAVEclient
from requests.exceptions import HTTPError

class CAVE:
    """Convenience wrapper around ``CAVEclient`` for the ``minnie65_public_<version>`` datastack.

    Every download method returns a pandas DataFrame in which the 3-D point
    column has been passed through ``rescale_position`` and split into one
    column per axis.
    """

    # Cell-type labels requested by download_excitatory_cells / download_inhibitory_cells.
    EXCITATORY_TYPES = ('23P', '4P', '5P-IT', '5P-ET', '5P-NP', '6P-IT', '6P-CT')
    INHIBITORY_TYPES = ('BC', 'MC', 'BPC', 'NGC')

    # A synapse query returning at least this many rows is treated as truncated
    # and re-issued on halves of its id list.
    # NOTE(review): presumably the materialization service's per-query row cap — confirm.
    QUERY_ROW_LIMIT = 500000

    @staticmethod
    def get_cell_type(pt_root_id, cell_df):
        """Return the cell type recorded for ``pt_root_id`` in ``cell_df``.

        Args:
            pt_root_id: root id to look up.
            cell_df: DataFrame with ``pt_root_id`` and ``cell_type`` columns.

        Returns:
            The ``cell_type`` of the first matching row, or ``"Unknown"`` when no
            row matches or either column is missing.
        """
        try:
            matches = cell_df.loc[cell_df['pt_root_id'] == pt_root_id, 'cell_type']
        except KeyError:
            return "Unknown"
        return matches.iloc[0] if len(matches) else "Unknown"

    @staticmethod
    def rescale_position(position):
        """Multiply an (x, y, z) position element-wise by (4, 4, 40) / 1000.

        NOTE(review): presumably converts 4x4x40 nm voxel coordinates to
        micrometres — confirm against the datastack's voxel resolution.
        """
        scale_vector = np.array([4/1000, 4/1000, 40/1000])
        return position * scale_vector

    @staticmethod
    def _expand_position(df, column, new_columns):
        """Rescale the 3-vectors in ``df[column]`` and replace that column with one column per axis.

        The new columns are appended at the end (same order as before). An empty
        ``df`` yields empty, correctly named columns instead of raising.
        """
        scaled = [CAVE.rescale_position(position) for position in df[column]]
        coords = np.array(scaled, dtype=float).reshape(-1, len(new_columns))
        position_df = pd.DataFrame(coords, columns=new_columns, index=df.index)
        return pd.concat([df.drop(columns=column), position_df], axis=1)

    @staticmethod
    def _cell_type_lookup(cell_df):
        """Build a ``pt_root_id -> cell_type`` dict keeping the first row per id (as ``get_cell_type`` does).

        Returns an empty dict when ``cell_df`` is None or lacks either column.
        """
        if cell_df is None:
            return {}
        try:
            first_rows = cell_df.drop_duplicates(subset=['pt_root_id'])
            return dict(zip(first_rows['pt_root_id'], first_rows['cell_type']))
        except KeyError:
            return {}

    def __init__(self, version='v343'):
        """Create a ``CAVEclient`` for datastack ``minnie65_public_<version>``."""
        self.version = version
        self.client = CAVEclient(f'minnie65_public_{version}')

    def download_cells(self, filter_dict):
        """Query ``aibs_soma_nuc_metamodel_preds_v117`` with ``filter_dict``.

        Args:
            filter_dict: passed as ``filter_in_dict`` to ``query_table``,
                e.g. ``{'cell_type': ['BC', 'MC']}``.

        Returns:
            DataFrame with ``pt_root_id``, ``cell_type`` and ``pt_x``/``pt_y``/``pt_z``,
            keeping only the first row per ``pt_root_id``.
        """
        cell_df = self.client.materialize.query_table('aibs_soma_nuc_metamodel_preds_v117',
                                                      filter_in_dict=filter_dict,
                                                      select_columns=['pt_root_id', 'cell_type', 'pt_position'])
        cell_df = CAVE._expand_position(cell_df, 'pt_position', ['pt_x', 'pt_y', 'pt_z'])
        # One row per pt_root_id so cell-type lookups by root id are unambiguous.
        return cell_df.drop_duplicates(subset=['pt_root_id'])

    def download_excitatory_cells(self):
        """Download cells whose ``cell_type`` is in ``EXCITATORY_TYPES``."""
        return self.download_cells({'cell_type': list(self.EXCITATORY_TYPES)})

    def download_inhibitory_cells(self):
        """Download cells whose ``cell_type`` is in ``INHIBITORY_TYPES``."""
        return self.download_cells({'cell_type': list(self.INHIBITORY_TYPES)})

    def download_synapses(self, filter_dict, cell_df=None):
        """Query ``synapses_pni_2`` with ``filter_dict`` and label both partners' cell types.

        Args:
            filter_dict: passed as ``filter_in_dict`` to ``query_table``,
                e.g. ``{'post_pt_root_id': [...]}``.
            cell_df: optional table with ``pt_root_id``/``cell_type`` columns (e.g. from
                ``download_cells``). Ids not found in it, or every id when it is None,
                are labelled ``'Unknown'``.

        Returns:
            DataFrame with ``id``, ``pre_pt_root_id``, ``post_pt_root_id``, ``size``,
            ``cell_type_pre``, ``cell_type_post`` and ``ctr_pt_x``/``_y``/``_z``.
        """
        syn_df = self.client.materialize.query_table('synapses_pni_2',
                                                     filter_in_dict=filter_dict,
                                                     select_columns=['id', 'pre_pt_root_id', 'post_pt_root_id', 'ctr_pt_position', 'size'])

        # One dict lookup per synapse instead of scanning cell_df once per synapse.
        cell_types = CAVE._cell_type_lookup(cell_df)
        syn_df['cell_type_pre'] = [cell_types.get(root_id, 'Unknown') for root_id in syn_df['pre_pt_root_id']]
        syn_df['cell_type_post'] = [cell_types.get(root_id, 'Unknown') for root_id in syn_df['post_pt_root_id']]

        return CAVE._expand_position(syn_df, 'ctr_pt_position', ['ctr_pt_x', 'ctr_pt_y', 'ctr_pt_z'])

    def _download_by_root_ids(self, column, root_ids, cell_df=None):
        """Download synapses whose ``column`` is in ``root_ids``, bisecting while results look truncated.

        A query returning ``QUERY_ROW_LIMIT`` rows or more is re-issued on each
        half of the id list, recursively, so no part is silently cut off. A single
        id that still hits the limit cannot be split further and triggers a warning.
        """
        if isinstance(root_ids, (int, np.integer)):
            root_ids = [root_ids]
        # Plain Python ints, independent of the container/dtype the caller passed.
        root_ids = [int(root_id) for root_id in root_ids]

        syn_df = self.download_synapses({column: root_ids}, cell_df)
        if len(syn_df) < self.QUERY_ROW_LIMIT:
            return syn_df
        if len(root_ids) < 2:
            warnings.warn(f"Synapse query on {column}={root_ids} returned {len(syn_df)} rows; "
                          "the result may be truncated and cannot be split further.",
                          stacklevel=2)
            return syn_df

        half = len(root_ids) // 2
        parts = [self._download_by_root_ids(column, ids, cell_df)
                 for ids in (root_ids[:half], root_ids[half:])]
        return pd.concat(parts, axis=0, ignore_index=True)

    def download_input_synapses(self, post_pt_root_ids, cell_df=None):
        """Download synapses onto ``post_pt_root_ids`` (one int, numpy integer, or iterable of ids)."""
        return self._download_by_root_ids('post_pt_root_id', post_pt_root_ids, cell_df)

    def download_output_synapses(self, pre_pt_root_ids, cell_df=None):
        """Download synapses made by ``pre_pt_root_ids`` (one int, numpy integer, or iterable of ids)."""
        return self._download_by_root_ids('pre_pt_root_id', pre_pt_root_ids, cell_df)

    @staticmethod
    def _iter_chunks(fetch, root_ids, cell_df, timeout, chunk_size, max_retries=1):
        """Yield ``(fetch(chunk_ids, cell_df), chunk_ids)`` for consecutive ``chunk_size`` slices of ``root_ids``.

        On ``HTTPError`` a chunk is retried up to ``max_retries`` times, sleeping
        ``timeout`` seconds before each retry; the final failure propagates.
        """
        for start in range(0, len(root_ids), chunk_size):
            chunk_ids = root_ids[start:start + chunk_size]
            attempt = 0
            while True:
                try:
                    syn_df = fetch(chunk_ids, cell_df)
                    break
                except HTTPError:
                    if attempt >= max_retries:
                        raise
                    attempt += 1
                    print(f"Chunk {start // chunk_size} failed, retrying")
                    sleep(timeout)
            yield syn_df, chunk_ids

    def download_input_synapses_list(self, post_pt_root_ids, cell_df=None, timeout=600, chunk_size=150, max_retries=1):
        """Lazily yield ``(syn_df, chunk_ids)`` of input synapses for ``chunk_size`` post-synaptic ids at a time."""
        yield from self._iter_chunks(self.download_input_synapses, post_pt_root_ids, cell_df,
                                     timeout, chunk_size, max_retries)

    def download_output_synapses_list(self, pre_pt_root_ids, cell_df=None, timeout=600, chunk_size=750, max_retries=1):
        """Lazily yield ``(syn_df, chunk_ids)`` of output synapses for ``chunk_size`` pre-synaptic ids at a time."""
        yield from self._iter_chunks(self.download_output_synapses, pre_pt_root_ids, cell_df,
                                     timeout, chunk_size, max_retries)

In [25]:
# Pass the datastack version explicitly (same as the default) so re-runs query the same minnie65_public version.
client = CAVE(version='v343')

In [3]:
# Inhibitory cell-type predictions (BC, MC, BPC, NGC), one row per pt_root_id.
inh_cells = client.download_inhibitory_cells()

Using deprecated pyarrow serialization method, please upgrade CAVEClient>=5.9.0 with pip install --upgrade caveclient


In [4]:
# Preview the first few inhibitory cells.
inh_cells.head(n=5)

Unnamed: 0,pt_root_id,cell_type,pt_x,pt_y,pt_z
0,864691135207734905,NGC,324.032,432.96,679.8
1,864691135758479438,NGC,309.568,421.12,706.0
2,864691135535098473,NGC,293.248,416.96,853.36
3,864691136143741847,MC,315.776,441.472,844.24
4,864691136745551332,MC,314.432,476.736,681.08


In [5]:
# Sanity check: True when every pt_root_id in inh_cells is unique (download_cells de-duplicates by root id).
inh_cells['pt_root_id'].is_unique

True

In [6]:
# Excitatory cell-type predictions (23P, 4P, 5P-IT, 5P-ET, 5P-NP, 6P-IT, 6P-CT), one row per pt_root_id.
exc_cells = client.download_excitatory_cells()

Using deprecated pyarrow serialization method, please upgrade CAVEClient>=5.9.0 with pip install --upgrade caveclient


In [9]:
# Preview the first few excitatory cells.
exc_cells.head(n=5)

Unnamed: 0,pt_root_id,cell_type,pt_x,pt_y,pt_z
0,864691135639004475,23P,284.544,442.112,808.8
1,864691135771677771,23P,290.304,434.624,811.64
2,864691135864089470,23P,318.528,485.824,670.16
3,864691135560505569,23P,320.512,496.0,662.52
4,864691136315868311,23P,320.576,507.712,664.88


In [10]:
# Sanity check: True when every pt_root_id in exc_cells is unique (download_cells de-duplicates by root id).
exc_cells['pt_root_id'].is_unique

True

In [11]:
# Write both cell tables to CSV; create the output directory if it does not exist yet.
DATA_DIR = Path('./data')
DATA_DIR.mkdir(parents=True, exist_ok=True)
inh_cells.to_csv(DATA_DIR / 'inh_cells.csv', index=False)
exc_cells.to_csv(DATA_DIR / 'exc_cells.csv', index=False)

In [12]:
# Example post-synaptic cell: the exc_cells row *labelled* 100 (label-based, as before).
# .at reads the id straight from the int64 column; chained .loc[row][col] goes through a
# row Series upcast to the frame's common dtype, which would round 18-digit ids if float.
example_row = 100
inp_synapses = client.download_input_synapses(int(exc_cells.at[example_row, 'pt_root_id']), exc_cells)

Using deprecated pyarrow serialization method, please upgrade CAVEClient>=5.9.0 with pip install --upgrade caveclient


In [13]:
# Preview the first few input synapses.
inp_synapses.head(n=5)

Unnamed: 0,id,pre_pt_root_id,post_pt_root_id,size,cell_type_pre,cell_type_post,ctr_pt_x,ctr_pt_y,ctr_pt_z
0,16569774,864691135315207470,864691136175025414,7420,Unknown,23P,368.84,458.104,725.8
1,31643589,864691135325253678,864691136175025414,572,Unknown,23P,415.096,450.32,766.76
2,16781211,864691135653107859,864691136175025414,2876,Unknown,23P,361.336,470.976,742.8
3,44934067,864691136122666321,864691136175025414,1808,Unknown,23P,450.768,461.84,775.36
4,9634201,864691134997035473,864691136175025414,1216,Unknown,23P,337.52,424.92,711.04


In [14]:
# Number of input synapses (rows).
inp_synapses.shape[0]

2118

In [16]:
# Input synapses whose presynaptic partner has a known cell type.
known_pre = inp_synapses['cell_type_pre'].ne('Unknown')
inp_synapses.loc[known_pre].head()

Unnamed: 0,id,pre_pt_root_id,post_pt_root_id,size,cell_type_pre,cell_type_post,ctr_pt_x,ctr_pt_y,ctr_pt_z
490,36476187,864691135991510209,864691136175025414,2328,23P,23P,429.464,430.808,778.52
717,17077104,864691136175025414,864691136175025414,632,23P,23P,359.936,466.608,749.12
741,37674975,864691135462736285,864691136175025414,1272,4P,23P,437.736,491.976,762.92
755,16396792,864691135396676641,864691136175025414,444,23P,23P,369.496,449.976,745.92
1002,44691599,864691135855971374,864691136175025414,2840,4P,23P,445.624,453.168,731.24


In [18]:
# Look up the cell type of the row labelled 0; .at avoids a dtype-upcast row Series for the 64-bit id.
CAVE.get_cell_type(int(exc_cells.at[0, 'pt_root_id']), exc_cells)

'23P'

In [19]:
# Same lookup done directly; .item() raises unless exactly one row matches.
first_root_id = int(exc_cells.at[0, 'pt_root_id'])
exc_cells.loc[exc_cells['pt_root_id'] == first_root_id, 'cell_type'].item()

'23P'