In [None]:
import numpy as np
import matplotlib.pyplot as plt
# %matplotlib notebook
%matplotlib widget

import sys
sys.path.append('/home/dtward/data/csh_data/emlddmm')
import emlddmm
import csv
from skimage.measure import marching_cubes
from glob import glob
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from os.path import split,join,splitext

import pandas as pd
from scipy.stats import multivariate_normal
import os

In [None]:
# Input files: the atlas ontology table (CSV, parsed with csv.reader below)
# and the label volume (VTK, loaded with emlddmm.read_data below).
ontology_name = '/nafs/dtward/dong/upenn_atlas/atlas_info_KimRef_FPbasedLabel_v2.7.csv'
seg_name = '/nafs/dtward/dong/upenn_atlas/UPenn_labels_reoriented_origin.vtk'

### Load UPenn ontology + Generate lists of descendents

In [None]:
# Load the label volume. xS is used below as the list of per-axis voxel coordinates,
# and S as the label image (indexed as S[0] later, so presumably it carries a leading
# singleton channel axis — TODO confirm against emlddmm.read_data).
xS,S,_,_ = emlddmm.read_data(seg_name)

In [None]:
# Expand the per-axis coordinate vectors in xS into a dense coordinate grid:
# the last axis of XS stacks one coordinate per entry of xS, so XS[i,j,k] is the
# physical location of voxel (i,j,k).
XS = np.stack(np.meshgrid(*xS,indexing='ij'),-1)

In [None]:
# Center of mass of region 2294 (Definition of the 1st moment of the indicator function for label 2294)
# np.sum(XS*(S[0,...,None]==2294))/np.sum(S==2294)

In [None]:
# Column positions inside the ontology CSV
parent_column = 7 # 8 for allen, 7 for yongsoo
label_column = 0 # 0 for both
shortname_column = 2# 3 for allen, 2 for yongsoo
longname_column = 1# 2 for allen, 1 for yongsoo

# ontology maps label id -> (short name, long name, parent label id);
# a row with an empty parent field gets parent -1.
ontology = dict()
with open(ontology_name) as f:
    csvreader = csv.reader(f, delimiter=',', quotechar='"')
    for row_number, row in enumerate(csvreader):
        if row_number == 0:
            # the first row holds the column names
            headers = row
            print(headers)
            continue
        parent = int(row[parent_column]) if row[parent_column] else -1
        ontology[int(row[label_column])] = (row[shortname_column],row[longname_column],parent)


In [None]:
# To find all the descendants of a label we first need its direct children:
# children maps parent label -> list of labels whose parent it is.
children = dict()
for o, (_, _, parent) in ontology.items():
    children.setdefault(parent, []).append(o)

In [None]:
# now we go from children to descendents
# Every parent gets its OWN freshly built list, so `children` is never mutated
# (a plain dict(children) copy would share the list objects), and a `seen` set
# guarantees each descendant appears once and that a cyclic parent link cannot
# make the walk run forever.
descendents = dict()
for o in children:
    collected = list(children[o])
    seen = set(collected)
    # `collected` grows while we walk it, which turns this into a breadth-first
    # traversal of the whole subtree below o
    for node in collected:
        for grandchild in children.get(node, ()): # leaves have no entry in children
            if grandchild not in seen:
                seen.add(grandchild)
                collected.append(grandchild)
    descendents[o] = collected
# label 0 is deliberately given no descendants (kept from the original cell)
descendents[0] = []

In [None]:
# Same mapping as `descendents`, but every ontology label is also a member of its
# own list (appended last). The lists are copied first so that `descendents` is
# left untouched and re-running this cell does not append the label twice.
descendents_and_self = {key: list(members) for key, members in descendents.items()}
for o in ontology:
    members = descendents_and_self.setdefault(o, [])
    if o not in members:
        members.append(o)

In [None]:
# Display the parsed ontology: label id -> (short name, long name, parent label id)
ontology

## Generate Boolean masks for CP, CPr (+CPre), CPi, CPc (+CPce)

In [None]:
# find all the descendants of Caudoputamen - rostral
# and Caudoputamen - rostral extreme

In [None]:
# Boolean mask of the whole CP: label 672 together with all of its descendants
cp = list(descendents_and_self[672])
Scp = np.isin(S, cp)

In [None]:
# Label lists for the structures one level below CP, with the "extreme" parts merged in
rostral = list(dict.fromkeys(
    list(descendents_and_self[2376])      # rostral extreme
    + list(descendents_and_self[2491])    # rostral
))

intermediate = list(descendents_and_self[2492]) # intermediate

caudal_ = list(descendents_and_self[2495]) # caudal extreme
caudal = list(descendents_and_self[2496]) + caudal_ # caudal, then caudal extreme

In [None]:
# Boolean mask: voxels whose label is in the rostral list
Srostral = np.isin(S, rostral)

In [None]:
# Boolean mask: voxels whose label is in the intermediate list
Sintermediate = np.isin(S, intermediate)

In [None]:
# Boolean mask: voxels whose label is in the caudal list
Scaudal = np.isin(S, caudal)

In [None]:
# Number of voxels of each CP subdivision per slice: summing over axes (0,2,3)
# leaves axis 1 of the mask, so the x-axis of the plot is the index along that axis.
fig,ax = plt.subplots()
for subdivision_mask, tag in [(Srostral,'r'),(Sintermediate,'i'),(Scaudal,'c')]:
    ax.plot(np.sum(subdivision_mask>0,axis=(0,2,3)),label=tag)
ax.set_xlabel('index along axis 1 of S')
ax.set_ylabel('number of voxels')
ax.set_title('Extent of rostral (r), intermediate (i) and caudal (c) CP')
ax.legend()

In [None]:
# TODO
# add one more level down the tree

def generate_region_mask(S, regionID, label_map=None):
    """Return a boolean mask of the voxels of S that belong to a region.

    Parameters
    ----------
    S : array of labels
        Label volume; the mask has the same shape.
    regionID : hashable
        Key looked up in ``label_map``.
    label_map : mapping, optional
        Maps a region id to the list of labels that make up the region.
        Defaults to the notebook-level ``descendents_and_self`` (region plus
        all of its descendants).

    Returns
    -------
    ndarray of bool
        True where the label of S is one of ``label_map[regionID]``. The result
        is always boolean, including when the label list is empty.
    """
    if label_map is None:
        label_map = descendents_and_self
    # one vectorised membership test instead of one full-volume pass per label
    return np.isin(S, list(label_map[regionID]))

# Build one boolean mask per region of interest. Each call collects the region id
# together with all of its descendants (see generate_region_mask), so every mask has
# the same shape as S. The ids are label ids of the ontology loaded above.

# === Striatum Dorsal Region (STRd) ===

# CP regions
Scpr_m   = generate_region_mask(S,2294)
Scpr_imd = generate_region_mask(S,2295)
Scpr_imv = generate_region_mask(S,2296)
Scpr_l   = generate_region_mask(S,2497)
Scpi_dm  = generate_region_mask(S,2498)
Scpi_vm  = generate_region_mask(S,2500)
Scpi_dl  = generate_region_mask(S,2499)
Scpi_vl  = generate_region_mask(S,2501)
Scpc_d   = generate_region_mask(S,2493)
Scpc_i   = generate_region_mask(S,2494)
Scpc_v   = generate_region_mask(S,2490)

# Non-CP regions
S_lss = generate_region_mask(S,2001)
S_ast = generate_region_mask(S,2050)

# === Striatum Ventral Region (STRv) ===
S_acb  = generate_region_mask(S,56)  # Accumbens nucleus
S_ipac = generate_region_mask(S,998) # Interstitial nucleus of the posterior limb of the anterior commissure
S_tu   = generate_region_mask(S,754) # Olfactory tubercle

# === (10/3/23) Striatum adjacent regions ===
S_lsx = generate_region_mask(S,275)  # Lateral septal complex
S_samy = generate_region_mask(S,278) # Striatum-like amygdalar nuclei
S_pal = generate_region_mask(S,803)  # Pallidum

# === (10/4/23) Striatum adjacent regions ===
S_lv =   generate_region_mask(S,81)  # Lateral ventricle
S_cc =   generate_region_mask(S,776) # Corpus callosum
S_cl =   generate_region_mask(S,583) # Claustrum
S_en =   generate_region_mask(S,942) # Endopiriform nucleus
S_pir =  generate_region_mask(S,961) # Piriform nucleus
S_ic =   generate_region_mask(S,6)   # Internal capsule
S_mfbc = generate_region_mask(S,768) # Cerebrum related (fiber tracts)

In [None]:
# Report the number of voxels in every mask, one line per mask, in a fixed order.
voxel_count_report = [
    ('Scp', Scp), ('Scpr', Srostral), ('Scpi', Sintermediate), ('Scpc', Scaudal),
    ('Scpr_m', Scpr_m), ('Scpr_imd', Scpr_imd), ('Scpr_imv', Scpr_imv), ('Scpr_l', Scpr_l),
    ('Scpi_dm', Scpi_dm), ('Scpi_vm', Scpi_vm), ('Scpi_dl', Scpi_dl), ('Scpi_vl', Scpi_vl),
    ('Scpc_d', Scpc_d), ('Scpc_i', Scpc_i), ('Scpc_v', Scpc_v),
    ('S_lss', S_lss), ('S_ast', S_ast),
    ('S_acb', S_acb), ('S_ipac', S_ipac), ('S_tu', S_tu),
    ('S_lsx', S_lsx), ('S_samy', S_samy), ('S_pal', S_pal),
    ('S_lv', S_lv), ('S_cc', S_cc), ('S_cl', S_cl), ('S_en', S_en),
    ('S_pir', S_pir), ('S_ic', S_ic), ('S_mfbc', S_mfbc),
]
for mask_label, region_mask in voxel_count_report:
    print(f'N ({mask_label}): {np.sum(region_mask):,}')


In [None]:
# I think what I'd like to do is assign a gaussian to each region
# then give the neurons a distribution based on the Gaussian
# to do this I should load a set of neurons
# and also start visualizing

# what I think would make sense is to take all the cp structures
# blur them a lot
# then assign probabilities
# we also want to look at the neurons though

In [None]:
# i want to start by visualizing
# we need to load swc files
# and we need to contour the surfaces
# Voxel spacing along each axis, taken as the gap between the first two coordinates
# (this equals the spacing everywhere only if each axis is uniformly sampled).
d = [x[1] - x[0] for x in xS]


In [None]:
# Downsample the CP mask by the same factor along all three axes before meshing.
# Scp[0] drops the leading axis of the mask; xSd holds the matching coarse coordinates.
# NOTE(review): downmode presumably takes the mode within each block — confirm in emlddmm.
down = 16
xSd,Scpd = emlddmm.downmode(xS,Scp[0],down=[down,down,down])

In [None]:
# Voxel spacing of the downsampled grid (gap between the first two coordinates per axis)
dd = [x[1] - x[0] for x in xSd]

In [None]:
# Extract the CP surface at the 0.5 level of the downsampled mask, with vertices scaled
# by the voxel spacing dd. The first coordinate of every axis is then added so the
# vertices are expressed in the same coordinates as xSd.
verts,faces,normals,values = marching_cubes(Scpd,level=0.5,spacing=dd)
verts = verts + np.array([x[0] for x in xSd])

## Load SWC files and Display 2x2 figure

In [None]:
# swcdir = '../swc_out_v08' # this is tme07
# swcdir = '../dragonfly_tme09-1/swc_out_v08'
# swcdir = '/home/abenneck/dragonfly_work/dragonfly_outputs/TME08-1/dragonfly_joint_outputs'

# Collect the SWC files of one brain, leaving out any file with 'permuted' in its path.
brain = 'TME08-1'
swcdir = f'/home/abenneck/dragonfly_work/dragonfly_outputs/{brain}/dragonfly_joint_outputs'
files = glob(join(swcdir,'*.swc'))
files = [f for f in files if 'permuted' not in f]
# the bare name below displays the selected paths as the cell output
files

In [None]:
# Read one point per SWC file: fields 2-4 of the first data row.
# x ends up as an array with one row of three coordinates per file.
x = []
for file in files:
    with open(file)  as f:
        for line in f:
            # header line written by the export (contains the author name)
            if 'Tward' in line:
                continue
            # blank lines and '#' comment lines carry no sample either
            if not line.strip() or line.lstrip().startswith('#'):
                continue
            # comma-separated rows are split on ','; anything else on runs of
            # whitespace (sep=None). An empty-string separator is not valid for
            # str.split and raises ValueError.
            sep = ',' if ',' in line else None
            coords = [float(c) for c in line.split(sep)[2:5]]
            x.append(coords)
            break # only the first data row of each file is needed

x = np.array(x)
            

### Generate + Save figure

In [None]:
# Marker size for the points, and transparency / edge width for the CP surface
s = 5
alpha = 0.25
lw = 0.25

# Shared axis limits: a cube centred on the surface whose side is the largest extent
# of the surface, so the three axes use the same scale.
vertsmin = np.min(verts,0)
vertsmax = np.max(verts,0)
vertsc = vertsmin*0.5 + vertsmax*0.5
vertsd = vertsmax-vertsmin
vertsd = np.max(vertsd)
lim = vertsc[None] + np.array([-1,1])[...,None]/2*vertsd

fig = plt.figure()
# One 3D panel per view. Each entry is the (elev, azim) pair passed to view_init;
# None keeps matplotlib's default 3D view for the first panel.
panel_views = [None, (0,90), (0,0), (90,0)]
for panel_number, view in enumerate(panel_views, start=1):
    ax = fig.add_subplot(2,2,panel_number,projection='3d')
    # a collection can only live on one axes, so each panel gets its own mesh
    mesh = Poly3DCollection(verts[faces],ec=[0.0,0.0,0.0,0.1],lw=lw,alpha=alpha,)
    if view is not None:
        ax.view_init(*view)
    ax.add_collection3d(mesh)
    ax.set_xlim(lim[:,0])
    ax.set_ylim(lim[:,1])
    ax.set_zlim(lim[:,2])
    ax.set_xlabel('x0')
    ax.set_ylabel('x1')
    ax.set_zlabel('x2')
    ax.scatter(x[:,0],x[:,1],x[:,2],s=s)

# fig.suptitle(swcdir.split('/')[-2])
# fig.savefig('CP_SWC_QC_figure_'+swcdir.split('/')[-2]+'.jpg')

fig.suptitle(f'{brain}')
# fig.savefig(join(f'/home/abenneck/dragonfly_work/dragonfly_outputs/{brain}/',f'CP_SWC_QC_figure_{brain}.jpg'))

In [None]:
# raise Exception(f'End of QC figure generation for {brain}')

In [None]:
# so what is a plan going forward
# I'd like to take every blob as a gaussian
# the height is its volume
# the mean is its mean
# the covariance is its covariance
# the only trouble here is the left right issue
# another possibility is to just blur the labels
# this may get rid of small structures though
# but I could give less blur to the small structures

In [None]:
# first going forward
# 1. make a figure like this to get a sense of the uncertainty for each brain
# 2. make a probabilistic version of the CP structures.
# I'd like to model each structure as a gaussian blob (ellipsoids with soft boundaries)
# this needs three parameters
# the mean (a 3 element vector)
# the covariance (3x3 symmetric matrix)
# and the height/amplitude of the gaussian (one positive number)

# The height (amplitude) should be the number of voxels in the structure
# the mean, is going to be the first moment of the segmentation
# the covariance, is the second central moment of the segmentation.

### Compute + Save Gaussian parameters for CP

In [None]:
def generate_gaussian_param(subregion_mask, hemi, fname, save_path='', grid=None):
    """Fit a Gaussian blob to one hemisphere of a region mask and save its parameters.

    The blob is described by the moments of the mask restricted to one hemisphere:
    height = number of voxels, mu = mean voxel coordinate (first moment),
    cov = covariance of the voxel coordinates (second central moment, divided by
    the number of voxels).

    Parameters
    ----------
    subregion_mask : boolean array
        Region mask whose last three axes match the three coordinate axes.
    hemi : str
        'R' keeps voxels whose second coordinate is >= 0; any other value keeps
        voxels whose second coordinate is < 0.
    fname : str
        Name of the .npz file to write.
    save_path : str, optional
        Output directory. '' (default) selects the notebook's gaussian_parameters
        directory.
    grid : sequence of three 1D arrays, optional
        Per-axis voxel coordinates. Defaults to the notebook-level ``xS``.

    Writes an .npz file with keys ``height`` (1-element array), ``mu`` (3,) and
    ``cov`` (3,3). If the mask has no voxel in the requested hemisphere, height is
    0 and mu / cov are filled with NaN.
    """
    axes = xS if grid is None else grid
    if save_path == '':
        save_path = '/home/abenneck/dragonfly_work/gaussian_parameters/'

    if hemi == 'R':
        LR_indicator = (axes[1][:,None]>=0) # Right hemisphere
    else:
        LR_indicator = (axes[1][:,None]<0)  # Left hemisphere

    # ===== Compute Gaussian parameters =====

    # Work on the coordinates of the selected voxels only, instead of building a
    # full-volume temporary for every moment.
    hemi_mask = np.logical_and(subregion_mask, LR_indicator)
    voxel_idx = np.nonzero(hemi_mask)

    # number of voxels in 'subregion'
    Ncp = int(voxel_idx[0].size)

    if Ncp == 0:
        print(f'Warning: no voxels in hemisphere {hemi} for {fname}; saving NaN parameters')
        mucp = np.full(3, np.nan)
        covcp = np.full((3, 3), np.nan)
    else:
        # the last three index arrays address the three spatial axes, whether or not
        # the mask carries leading (e.g. singleton) axes
        coords = np.stack([np.asarray(axes[k])[voxel_idx[k - 3]] for k in range(3)], axis=1)
        # first moment
        mucp = coords.mean(axis=0)
        # second central moment (normalised by N, not N-1)
        centered = coords - mucp
        covcp = centered.T @ centered / Ncp

    # Save Gaussian parameters in .npz file with keys [height, mu, cov]
    np.savez(join(save_path,fname),height=[Ncp],mu=mucp,cov=covcp)
    print(f'Saved {fname}\n')

# Set genParam to True to (re)compute and save the Gaussian parameters of every
# region, for the left and then the right hemisphere.
genParam = False
if genParam:
    # (mask, file-name prefix) in the order the files are written
    gaussian_targets = [
        (Scp, 'cp'), (Srostral, 'cpr'), (Sintermediate, 'cpi'), (Scaudal, 'cpc'),
        (Scpr_m, 'cpr_m'), (Scpr_imd, 'cpr_imd'), (Scpr_imv, 'cpr_imv'), (Scpr_l, 'cpr_l'),
        (Scpi_dm, 'cpi_dm'), (Scpi_vm, 'cpi_vm'), (Scpi_dl, 'cpi_dl'), (Scpi_vl, 'cpi_vl'),
        (Scpc_d, 'cpc_d'), (Scpc_i, 'cpc_i'), (Scpc_v, 'cpc_v'),
        (S_lss, 'lss'), (S_ast, 'ast'),
        (S_acb, 'acb'), (S_ipac, 'ipac'), (S_tu, 'tu'),
        (S_lsx, 'lsx'), (S_samy, 'samy'), (S_pal, 'pal'),
        (S_lv, 'lv'), (S_cc, 'cc'), (S_cl, 'cl'), (S_en, 'en'),
        (S_pir, 'pir'), (S_ic, 'ic'), (S_mfbc, 'mfbc'),
    ]
    for hemi in ['L','R']:
        for region_mask, prefix in gaussian_targets:
            generate_gaussian_param(region_mask, hemi, f'{prefix}_param_{hemi}.npz')

### Load Gaussian parameters given hemi and subregion

In [None]:
# Choose which saved Gaussian to inspect: hemisphere ('L' or 'R') and the subregion
# prefix used in the file name (e.g. 'cp', 'cpr', 'cpc_v', 'acb').
# hemi = 'L'
hemi = 'R'
subregion = 'cpc_v'

fname = f'{subregion}_param_{hemi}.npz'
fpath = join('/home/abenneck/dragonfly_work/gaussian_parameters/',fname)
out = np.load(fpath)

# keys written by generate_gaussian_param: height (1-element array), mu, cov
h = out['height'][0]
mu = out['mu']
cov = out['cov']

print(f'Displaying parameters for {fname};\n')
print(f'height: {h}')
print(f'mu: {mu}')
print(f'cov: {cov}')

In [None]:
# once we have parameters for each gaussian
# then we can evaluate them at every point in space
# so for a given cell body, we can find a distribution over this number of regions
# for each level of the tree, you can get a distribution over all the structures in that level
# start with R-I-C level
# TODO, come up with some measure of concordance
#   that is related to nick's labels, but allows some slop

In [None]:
# we don't really believe the uncertainty is given by the voxel size (we're not getting 1 voxel accurate registration)
# we believe it is related to the anatomy

### Compute distance transform for each mask

In [None]:
from scipy.ndimage import distance_transform_edt
import time

def generateDistanceTransform(xS, mask, hemi, fname, save_path = '', scale_factor = 10):
    """Turn one hemisphere of a region mask into a normalised exponential falloff map.

    The map is exp(-d / scale_factor), where d is the Euclidean distance transform of
    the complement of the hemisphere-restricted mask (0 inside the region, growing
    with distance outside, measured in voxels). The map is divided by its sum.

    Parameters
    ----------
    xS : sequence of 1D arrays
        Per-axis coordinates; only xS[1] is used, to split the hemispheres at 0.
    mask : boolean array
        Region mask.
    hemi : str
        'R' keeps xS[1] >= 0, anything else keeps xS[1] < 0.
    fname : str
        Name of the .npz file to write (key ``mask``, stored as float32).
    save_path : str, optional
        Output directory; '' selects the notebook's distance_parameters directory.
    scale_factor : number, optional
        Decay length of the exponential.

    Returns
    -------
    ndarray
        The normalised map at full precision (the saved copy is float32).
    """
    t0 = time.time()

    # hemisphere selector, broadcast along the second-to-last axis of the mask
    in_hemisphere = (xS[1][:,None]>=0) if hemi == 'R' else (xS[1][:,None]<0)

    out_dir = '/home/abenneck/dragonfly_work/distance_parameters/' if save_path == '' else save_path

    distance_to_region = distance_transform_edt(1-mask*in_hemisphere)
    weights = np.exp(-distance_to_region/scale_factor)
    weights = weights / np.sum(weights)

    # Save distance transform parameters in .npz file with keys [mask]
    np.savez(join(out_dir,fname), mask = weights.astype(np.float32))
    print(f'Saved {fname} after {time.time() - t0 : .2f} s')

    return weights

# Set saveDistanceParam to True to (re)compute and save the distance-transform map of
# every region, for the left and then the right hemisphere.
saveDistanceParam = False
if saveDistanceParam:
    # (mask, file-name prefix) in the order the files are written
    distance_targets = [
        (Scp, 'cp'), (Srostral, 'cpr'), (Sintermediate, 'cpi'), (Scaudal, 'cpc'),
        (Scpr_m, 'cpr_m'), (Scpr_imd, 'cpr_imd'), (Scpr_imv, 'cpr_imv'), (Scpr_l, 'cpr_l'),
        (Scpi_dm, 'cpi_dm'), (Scpi_vm, 'cpi_vm'), (Scpi_dl, 'cpi_dl'), (Scpi_vl, 'cpi_vl'),
        (Scpc_d, 'cpc_d'), (Scpc_i, 'cpc_i'), (Scpc_v, 'cpc_v'),
        (S_lss, 'lss'), (S_ast, 'ast'),
        (S_acb, 'acb'), (S_ipac, 'ipac'), (S_tu, 'tu'),
        (S_lsx, 'lsx'), (S_samy, 'samy'), (S_pal, 'pal'),
        (S_lv, 'lv'), (S_cc, 'cc'), (S_cl, 'cl'), (S_en, 'en'),
        (S_pir, 'pir'), (S_ic, 'ic'), (S_mfbc, 'mfbc'),
    ]
    for hemi in ['L','R']:
        for region_mask, prefix in distance_targets:
            generateDistanceTransform(xS, region_mask, hemi, f'{prefix}_dist_param_{hemi}.npz')

### Load dist param

In [None]:
# Spot check of one saved distance-transform map (left-hemisphere CP):
# load the file, keep the array stored under 'mask', and print a single voxel.
out = np.load('/home/abenneck/dragonfly_work/distance_parameters/cp_dist_param_L.npz')
out = out['mask']

# voxel indices, not physical coordinates; the negative z counts from the end of the axis
x, y, z = 100, 100, -100

print(out[0, x, y, z])

### Preprocessing helper functions

In [None]:
def L1_filter_gauss(fileName):
    """True for Gaussian-parameter files of the CP itself or a first-level CP part.

    These are names containing 'cp' that are at most 15 characters long
    (e.g. 'cp_param_L.npz', 'cpc_param_L.npz').
    """
    is_cp_file = fileName != '.ipynb_checkpoints' and 'cp' in fileName
    return is_cp_file and len(fileName) <= 15

def L2_filter_gauss(fileName):
    """True for Gaussian-parameter files of second-level CP parts.

    These are names containing 'cp' that are longer than 15 characters
    (e.g. 'cpc_d_param_L.npz').
    """
    is_cp_file = fileName != '.ipynb_checkpoints' and 'cp' in fileName
    return is_cp_file and len(fileName) > 15

def non_CP_filter_gauss(fileName):
    """True for every remaining entry except the '.ipynb_checkpoints' folder."""
    is_cp_level = L2_filter_gauss(fileName) or L1_filter_gauss(fileName)
    return not is_cp_level and fileName != '.ipynb_checkpoints'

def L1_filter_dist(fileName):
    """True for distance-transform files of the CP itself or a first-level CP part.

    These are names containing 'cp' that are at most 21 characters long
    (e.g. 'cp_dist_param_L.npz', 'cpc_dist_param_L.npz').
    """
    is_cp_file = fileName != '.ipynb_checkpoints' and 'cp' in fileName
    return is_cp_file and len(fileName) <= 21

def L2_filter_dist(fileName):
    """True for distance-transform files of second-level CP parts.

    These are names containing 'cp' that are longer than 21 characters
    (e.g. 'cpc_d_dist_param_L.npz').
    """
    is_cp_file = fileName != '.ipynb_checkpoints' and 'cp' in fileName
    return is_cp_file and len(fileName) > 21

def non_CP_filter_dist(fileName):
    """True for every remaining entry except the '.ipynb_checkpoints' folder."""
    is_cp_level = L2_filter_dist(fileName) or L1_filter_dist(fileName)
    return not is_cp_level and fileName != '.ipynb_checkpoints'

def get_coord_list(region_id, XS, S):
    """Collect the coordinates of every voxel that belongs to a region.

    The region is the set of labels listed in ``descendents_and_self[region_id]``.
    Labels are visited in that order, and for each one the rows of ``XS`` selected
    by ``S[0] == label`` are appended.

    Returns a list with one coordinate entry per voxel (empty if no voxel matches).
    """
    label_volume = S[0]
    allCoords = []
    for idNum in descendents_and_self[region_id]:
        # extending with an empty selection is a no-op, so no emptiness check is needed
        allCoords.extend(XS[label_volume == idNum])
    return allCoords

### Probability computation helper functions

In [None]:
# Distance with vectorization
def dist_from_cp(allCoords, soma_location, useForLoop = False):
    """Return the smallest Euclidean distance from a point to a set of points.

    Parameters
    ----------
    allCoords : array of shape (N, D)
        Candidate points, one per row.
    soma_location : sequence of length D
        The query point.
    useForLoop : bool, optional
        If True, use a pure-Python loop over the rows instead of the vectorised
        numpy expression. In this mode an empty ``allCoords`` gives -1.

    Returns
    -------
    float
        The minimum distance.
    """
    if useForLoop:
        # local import: the notebook's import cell does not bring in `math`
        import math
        return min((math.dist(c, soma_location) for c in allCoords), default=-1)

    # row-wise difference -> squared -> summed per row -> sqrt -> smallest distance
    offsets = allCoords - soma_location
    return np.min(np.sqrt(np.sum(offsets**2, axis=1)))

# Round neuron coord to nearest atlas coord
def roundCoord(sub_xS, coord):
    """Snap a coordinate to the nearest value of a coordinate axis.

    Parameters
    ----------
    sub_xS : 1D array
        Coordinates of the grid points along one axis.
    coord : number
        Coordinate to snap.

    Returns
    -------
    The element of ``sub_xS`` closest to ``coord``. Values below / above the axis
    range are clipped to its minimum / maximum. The result is always an actual
    grid value, so it can be located in ``sub_xS`` with an equality test whatever
    the spacing or offset of the grid. For an ascending axis, a coordinate exactly
    halfway between two grid points snaps to the upper one.
    """
    grid = np.asarray(sub_xS)
    lowest, highest = np.min(grid), np.max(grid)
    if coord < lowest:
        return lowest
    if coord > highest:
        return highest
    gaps = np.abs(grid - coord)
    # argmin on the reversed array picks the LAST of several equally close points
    nearest = gaps.size - 1 - np.argmin(gaps[::-1])
    return grid[nearest]

def append_nonCP_regions(allProb):
    """Insert copies of the ACB columns next to the CP columns of each level.

    For level 0, 1 and 2 the columns 'ACB<level>_L' and 'ACB<level>_R' (copies of
    'ACB_L' and 'ACB_R') are inserted directly after 'CP_R', 'CPr_R' and 'CPr_m_R'
    respectively. The frame is modified in place and also returned.
    """
    for level, anchor in enumerate(['CP_R', 'CPr_R', 'CPr_m_R']):
        left_name = f'ACB{level}_L'
        right_name = f'ACB{level}_R'
        allProb.insert(allProb.columns.get_loc(anchor)+1, left_name, allProb['ACB_L'])
        allProb.insert(allProb.columns.get_loc(left_name)+1, right_name, allProb['ACB_R'])

    return allProb

# Absolute probabilities => (Normalize probabilities at multiple levels) => Conditional probabilities
def normalize_prob(allProb):
    """Rescale each group of probability columns so every row sums to 1 within the group.

    Four groups are normalised independently: level 0 (CP + ACB0), level 1
    (CPc/CPi/CPr + ACB1), level 2 (the CP sub-subregions + ACB2) and the non-CP
    regions. The frame is modified in place and also returned.
    """
    L0_labels = ['CP_L','CP_R','ACB0_L','ACB0_R']
    L1_labels = ['CPc_L','CPc_R','CPi_L','CPi_R','CPr_L','CPr_R','ACB1_L','ACB1_R']
    L2_labels = ['CPc_d_L','CPc_d_R','CPc_i_L','CPc_i_R','CPc_v_L','CPc_v_R','CPi_dl_L','CPi_dl_R','CPi_dm_L','CPi_dm_R','CPi_vl_L','CPi_vl_R','CPi_vm_L','CPi_vm_R','CPr_imd_L','CPr_imd_R','CPr_imv_L','CPr_imv_R','CPr_l_L','CPr_l_R','CPr_m_L','CPr_m_R','ACB2_L','ACB2_R']
    nonCP_labels = ['ACB_L','ACB_R','AST_L','AST_R','CC_L','CC_R','CL_L','CL_R','EN_L','EN_R','IC_L','IC_R','IPAC_L','IPAC_R','LSS_L','LSS_R','LSX_L','LSX_R','LV_L','LV_R','MFBC_L','MFBC_R','PAL_L','PAL_R','PIR_L','PIR_R','SAMY_L','SAMY_R','TU_L','TU_R']

    for group in (L0_labels, L1_labels, L2_labels, nonCP_labels):
        # divide every row of the group by that row's total
        row_totals = np.sum(allProb[group], axis=1)
        allProb[group] = allProb[group].div(row_totals, axis=0)

    return allProb

### Postprocessing helper functions

In [None]:
# Filter to remove the .ipynb_checkpoints file from lists
def brain_filter(dirName):
    """Return False for the entries to leave out ('.ipynb_checkpoints', 'TME20-1'), True otherwise."""
    return not (dirName == '.ipynb_checkpoints' or dirName == 'TME20-1')

# Return a pd.DataFrame (1 x len (df)) of the column name in df containing the corresponding value in vals for each row of df
def xlookup(vals, df, cName = ''):
    """For each row of df, find the first column whose value equals vals[row label].

    Returns a one-column DataFrame (column name ``cName``, one row per row of df)
    holding those column names.
    """
    matched_columns = [row[row == vals[idx]].keys()[0] for idx, row in df.iterrows()]
    return pd.DataFrame(data=matched_columns, columns = [cName])

# Return a pd.Series (1 x len(df)) of the nth largest value in each row of df
def nmax(df, n):
    """Return a Series holding, for each row of df, its n-th largest value (n=1 is the maximum)."""
    nth_values = [sorted(row, reverse=True)[n-1] for _, row in df.iterrows()]
    return pd.Series(data=nth_values)

def compute_top_n_maximums(data, all_labels, n=3):
    """Insert the n most probable labels (and their probabilities) for every label group.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain a 'fname' column and every column named in ``all_labels``.
        Modified in place.
    all_labels : sequence of sequences of str
        One list of column names per level; level i produces the columns
        'L{i} (r)' (label name) and 'P(L{i},r)' (its value) for r = 1..n.
    n : int, optional
        Number of ranks to report per level.

    Returns
    -------
    pandas.DataFrame
        ``data``, with the new columns inserted right after 'fname', level after
        level, in the order L (1), P 1, L (2), P 2, ...

    Notes
    -----
    Ranks are taken from a stable descending sort, so tied values yield distinct
    labels (in the column order of the group) rather than the same label twice.
    NaN values sort last. Values are inserted by position, so the index of
    ``data`` does not have to be 0..len-1.

    Raises
    ------
    ValueError
        If a group has fewer than ``n`` labels.
    """
    anchor = 'fname' # column after which the next new column is inserted
    for i, labels in enumerate(all_labels):
        labels = list(labels)
        if n > len(labels):
            raise ValueError(f'level {i} has {len(labels)} labels, cannot report the top {n}')

        values = data[labels].to_numpy()
        # sort on a float copy, but report the values exactly as stored in `data`
        order = np.argsort(-values.astype(float), axis=1, kind='stable')[:, :n]
        ranked_names = np.asarray(labels, dtype=object)[order]
        ranked_values = np.take_along_axis(values, order, axis=1)

        for rank in range(n):
            name_col = f'L{i} ({rank+1})'
            prob_col = f'P(L{i},{rank+1})'
            data.insert(data.columns.get_loc(anchor)+1, name_col, ranked_names[:, rank])
            data.insert(data.columns.get_loc(name_col)+1, prob_col, ranked_values[:, rank])
            anchor = prob_col

    return data

## Generate probability distributions, assign labels to neurons, and save csv files locally

In [None]:
# Brains to process (each is a sub-directory name of dragonfly_outputs)
# brains = sorted(filter(brain_filter, os.listdir('/home/abenneck/dragonfly_work/dragonfly_outputs/')))
brains = ['TME07-1', 'TME09-1', 'TME10-1', 'TME10-3', 'TME12-1']
# brains = ['TME08-1']

# Coordinates of every CP voxel, as one array with a row per voxel; used below to
# measure how far each soma lies from the CP.
allRegionCoords = get_coord_list(672, XS, S) # 672 is regionID for CP
allRegionCoords = np.asarray(allRegionCoords)
# Constant factor applied to every region's distance-transform value below
uniform_prior = 1/100

# Load the saved 'height' of every region's Gaussian. The files are taken in the order
# L1 files, L2 files, non-CP files (each group sorted by name), so the position of an
# entry in allPriors identifies its region and hemisphere.
print('Loading regional amplitudes')
start = time.time()
# Create list of priors from presaved parameters
paramDir = '/home/abenneck/dragonfly_work/gaussian_parameters'
L1_files = sorted(filter(L1_filter_gauss,os.listdir(paramDir)))
L2_files = sorted(filter(L2_filter_gauss,os.listdir(paramDir)))
nonCP_files = sorted(filter(non_CP_filter_gauss,os.listdir(paramDir)))
allPriors = list()
for fileList in [L1_files, L2_files, nonCP_files]:
    for file in fileList:
        data = np.load(os.path.join(paramDir,file))
        norm_factor = data['height'].item()
        allPriors.append(norm_factor)
print(f'Finished loading amp in {time.time() - start} s\n')

# Load the saved distance-transform map of every region, with the same grouping and
# sorting as above. NOTE(review): the two lists are later indexed with the same
# position, which assumes both directories hold the same regions — confirm.
print('Loading distance transform distributions')
start = time.time()
# Create list of distributions from presaved distance transforms
# paramDir = '/home/abenneck/dragonfly_work/distance_parameters'
paramDir = '/home/abenneck/dragonfly_work/distance_parameters'
L1_files = sorted(filter(L1_filter_dist,os.listdir(paramDir)))
L2_files = sorted(filter(L2_filter_dist,os.listdir(paramDir)))
nonCP_files = sorted(filter(non_CP_filter_dist,os.listdir(paramDir)))
allTransforms = list()
for fileList in [L1_files, L2_files, nonCP_files]:
    for file in fileList:
        data = np.load(os.path.join(paramDir,file))
        probDist = data['mask']
        allTransforms.append(probDist)
print(f'Finished loading distributions in {time.time() - start} s\n')

# if True:
#     raise Exception('Rewrite some of the below code')

print('Generating probabilities')
# Generate probabilities: one CSV per brain, one row per mapped neuron
for brain in brains:
    # Time each brain on its own so the summary printed at the end of the iteration
    # is per brain rather than cumulative over all brains processed so far
    start = time.time()

    # Define relevant directories
    neuronDir = f'/home/abenneck/dragonfly_work/dragonfly_outputs/{brain}/dragonfly_joint_outputs/'

    # Note: The ACB#_x columns are inserted after computing absolute probabilities, but before normalization
    L0_col = ['mouse ID', 'slice', 'hemi', 'neuron ID','dist (CP)','fname','CP_L','CP_R']
    L1_col = ['CPc_L','CPc_R','CPi_L','CPi_R','CPr_L','CPr_R',]
    L2_col = ['CPc_d_L','CPc_d_R','CPc_i_L','CPc_i_R','CPc_v_L','CPc_v_R','CPi_dl_L','CPi_dl_R','CPi_dm_L','CPi_dm_R','CPi_vl_L','CPi_vl_R','CPi_vm_L','CPi_vm_R','CPr_imd_L','CPr_imd_R','CPr_imv_L','CPr_imv_R','CPr_l_L','CPr_l_R','CPr_m_L','CPr_m_R']
    nonCP_col = ['ACB_L','ACB_R','AST_L','AST_R','CC_L','CC_R','CL_L','CL_R','EN_L','EN_R','IC_L','IC_R','IPAC_L','IPAC_R','LSS_L','LSS_R','LSX_L','LSX_R','LV_L','LV_R','MFBC_L','MFBC_R','PAL_L','PAL_R','PIR_L','PIR_R','SAMY_L','SAMY_R','TU_L','TU_R']
    allCol = np.concatenate([L0_col, L1_col, L2_col, nonCP_col])
    allProb = pd.DataFrame(columns=allCol)

    # List the directory once instead of re-listing and re-sorting it for every file
    neuronFiles = sorted(os.listdir(neuronDir))

    # Generate regional probabilities for every neuron in neuronDir
    row_idx = 0
    for i, file in enumerate(neuronFiles):
        if "mapped.swc" in file:
            # Load soma coordinates: the first line of the file is parsed as the CSV
            # header, so its fields 2-4 are the coordinates of the first sample
            data = pd.read_csv(os.path.join(neuronDir,file))
            x = float(data.columns[2])
            y = float(data.columns[3])
            z = float(data.columns[4])
            soma_location = [x,y,z]

            # Get idx for soma location in xS
            x_ind = np.where(xS[0] == roundCoord(xS[0], x))[0].item()
            y_ind = np.where(xS[1] == roundCoord(xS[1], y))[0].item()
            z_ind = np.where(xS[2] == roundCoord(xS[2], z))[0].item()

            # Get metadata from filename (fields separated by '_')
            mouseID = file.split('_')[1]
            sliceID = file.split('_')[4][:-1]
            hemiID = file.split('_')[4][-1]
            neuronID = file.split('_')[0]

            # Compute probabilities from distance transform probabilities.
            # Only the NUMBER of saved priors is used here: every region is weighted
            # by the same uniform_prior, not by its saved height.
            cpDist = dist_from_cp(allRegionCoords,soma_location)
            neuronProb = [mouseID, sliceID, hemiID, neuronID, cpDist, file]
            for j in range(len(allPriors)):
                neuronProb.append(uniform_prior*allTransforms[j][0, x_ind, y_ind, z_ind])

            allProb.loc[row_idx] = neuronProb
            row_idx+=1
            print(f'{i}/{len(neuronFiles)} Finished {file}\n')

    # Insert ACB at each level
    allProb = append_nonCP_regions(allProb)

    # Normalize probabilities at each level
    allProb = normalize_prob(allProb)

    # Identify top 3 regions for each neuron and insert columns into allProb
    L0_labels = ['CP_L','CP_R','ACB0_L','ACB0_R']
    L1_labels = ['CPc_L','CPc_R','CPi_L','CPi_R','CPr_L','CPr_R','ACB1_L','ACB1_R']
    L2_labels = ['CPc_d_L','CPc_d_R','CPc_i_L','CPc_i_R','CPc_v_L','CPc_v_R','CPi_dl_L','CPi_dl_R','CPi_dm_L','CPi_dm_R','CPi_vl_L','CPi_vl_R','CPi_vm_L','CPi_vm_R','CPr_imd_L','CPr_imd_R','CPr_imv_L','CPr_imv_R','CPr_l_L','CPr_l_R','CPr_m_L','CPr_m_R','ACB2_L','ACB2_R']
    all_labels = [L0_labels, L1_labels, L2_labels]

    allProb = compute_top_n_maximums(allProb, all_labels)

    # Save probability distributions as one .csv file (tempDir is the file path)
    tempDir = f'/home/abenneck/dragonfly_work/{brain}_neuron_region_prob.csv'
    allProb.to_csv(tempDir, index=False)

    # A brain without any mapped neuron must not divide by zero
    elapsed = time.time() - start
    per_neuron = elapsed / len(allProb) if len(allProb) else float('nan')
    print(f'Finished {brain} in {elapsed} s with an average of {per_neuron} s / neuron\n')

In [None]:
# Deliberate stop: this exception halts a "Run All" here so the cells below only run
# when executed by hand.
raise Exception('End of postprocessing')

In [None]:
# Reload a previously saved per-neuron probability table.
# NOTE(review): TME08-1 is not in the active `brains` list above, so this file comes
# from an earlier run — confirm it exists before running.
data = pd.read_csv('/home/abenneck/dragonfly_work/TME08-1_neuron_region_prob.csv')

In [None]:
# # Create list of multivariate Gaussian RVs from presaved parameters
# allDist = list()
# for fileList in [L1_files, L2_files]:
#     for file in fileList:
#         data = np.load(os.path.join(paramDir,file))
#         norm_factor = data['height'].item()
#         mu = data['mu']
#         cov = data['cov']
#         regionDist = multivariate_normal(mean=mu, cov=cov, allow_singular=False)
#         # regionDist = 
#         allDist.append([regionDist,norm_factor])

In [None]:
# pd.set_option('display.max_columns', None)
# pd.set_option('display.max_columns', 10)