In [1]:
import os
import warnings
from math import sqrt

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx
import plotly.graph_objects as go
import plotly.offline
import scipy.io
from scipy.stats import ks_2samp
from tqdm import tqdm

from sbemdb import SBEMDB
from cleandb import clean_db, clean_db_uct
# Ignore warnings for the whole notebook. This uses the stdlib `warnings` module:
# `np.warnings` was only an accidental re-export and was removed in NumPy 1.24.
warnings.filterwarnings('ignore')

# Visualization of coherence values, hence clean_db is necessary
db = SBEMDB()  # connect to DB
db = clean_db(db)  # must be clean_db here because it is based on coherence values
x, y, z = db.segments(444)  # coordinates of tree 444; plot_graph draws them as the 'DE3' line trace

ModuleNotFoundError: No module named 'sbemdb'

In [2]:
# Synapses onto tree 444 from every other tree; the extended result unpacks into
# eight parallel per-synapse sequences.
synapse_query = 'pre.tid!=444 and post.tid=444'
xx, yy, zz, pretid, posttid, synid, prenid, postnid = db.synapses(synapse_query, extended=True)

In [3]:
# Read in the color scheme: a phase -> colour lookup table (columns phi, r, g, b are used below).
path_color = os.path.join('..', 'data', 'cohcolor.csv')
color_map = pd.read_csv(path_color)

# Raw text lines of the same CSV (header first), kept for code that parses the file text
# directly. `with` guarantees the handle is closed even if reading fails.
with open(path_color) as f:
    lines = f.readlines()

# Read in per-ROI phase and magnitude values. Each file gets its own path variable
# (previously the phase path overwrote `path_color`).
path_phase = os.path.join('..', 'data', 'roi_phase_93.mat')
phase_mat = scipy.io.loadmat(path_phase)
path_alpha = os.path.join('..', 'data', 'roi_mag_93.mat')
mag_mat = scipy.io.loadmat(path_alpha)

phase_values = phase_mat['roi_phase'][:, 0]  # first entry of each row: one phase per ROI
mag_values = mag_mat['roi_mag'][0]            # first row: one magnitude per ROI
n_roi = len(phase_values)
if len(mag_values) < n_roi:
    raise ValueError(f'roi_mag has {len(mag_values)} values but roi_phase has {n_roi}')

# Build the table in one step; growing it row by row with .loc is quadratic and
# leaves every column with object dtype.
col_names = ['roi', 'Phase', 'Alpha']
data_values = pd.DataFrame({
    'roi': np.arange(1, n_roi + 1),              # ROI numbers are 1-based, row n-1 holds ROI n
    'Phase': np.round(phase_values + np.pi, 3),  # +pi shift to [0, 2pi] (assumes input phase in [-pi, pi])
    'Alpha': mag_values[:n_roi],                 # raw magnitude (no sqrt applied)
}, columns=col_names)
data_values

     roi  Phase     Alpha
0      1  3.054  0.853725
1      2  3.085  0.867127
2      3  5.494  0.734109
3      4  1.299  0.751747
4      5  5.932  0.818183
5      6  6.096  0.638267
6      7   5.88  0.716588
7      8  5.217  0.688959
8      9  3.541  0.750446
9     10  5.808  0.646029
10    11  4.558  0.201835
11    12  2.277  0.256543
12    13  0.276  0.491388
13    14  3.268  0.689449
14    15  5.884   0.63235
15    16  0.991  0.265862
16    17  2.684  0.921983
17    18  4.842  0.757642
18    19  3.006  0.865632
19    20  3.417  0.421367
20    21  5.605  0.818301
21    22  5.173   0.47355
22    23  5.791  0.456681
23    24  1.893  0.674163
24    25  5.061  0.196861
25    26  0.138  0.899309
26    27   6.09  0.570121
27    28  6.196  0.848481
28    29  0.109  0.763524
29    30  6.283  0.636438
..   ...    ...       ...
220  221  5.374  0.620381
221  222  4.693  0.799145
222  223  2.827  0.899291
223  224  3.783  0.595257
224  225  3.802  0.857773
225  226  2.577  0.879061
226  227  3.

In [4]:
def is_roi(roi):
    """Return a boolean mask over ``data_values`` rows whose 'roi' column equals ``roi``."""
    mask = data_values['roi'].eq(roi)
    return mask

def to_8bit_rgb(rgb):
    """Scale a colour channel given on a 0-1 scale to the 0-255 range.

    The product is rounded with ``round`` (half to even) and returned as an
    ``int``; no clamping is applied.
    """
    scaled = rgb * 255
    return int(round(scaled))

# Colour pipeline used by get_color: tid -> ROI (db.mapping.sbem2roi) -> phase and
# alpha (data_values) -> r, g, b from color_map -> scaled to 8-bit (to_8bit_rgb) ->
# blended toward grey according to alpha (alpha_grey_convert).

# rgb(211,211,211)

def alpha_grey_convert(r, g, b, alpha=0):
    """Blend an (r, g, b) colour toward mid-grey (128, 128, 128) by ``alpha``.

    alpha >= 1 returns the colour unchanged, alpha <= 0 returns pure grey, and
    anything in between gives ``int(alpha * channel + 128 * (1 - alpha))`` per
    channel (truncated toward zero).
    """
    grey = 128
    if alpha >= 1:
        return (r, g, b)
    if alpha <= 0:
        return (grey, grey, grey)
    offset = grey * (1 - alpha)
    return tuple(int(alpha * channel + offset) for channel in (r, g, b))

def _phase_to_rgb(phase):
    """Return the 8-bit (r, g, b) of the ``color_map`` row whose ``phi`` is nearest ``phase``.

    This replaces the former exact ``color_map.phi == phase`` match. That match
    returned an empty Series whenever ``phase`` did not appear in ``phi``
    bit-for-bit (e.g. a float parsed or rounded slightly differently), and the
    following ``int(...)`` then raised TypeError. When an exact match exists it
    is still the row selected (distance 0). With duplicate ``phi`` values the
    first such row wins. An empty ``color_map`` makes ``idxmin`` raise ValueError.
    """
    distance = (color_map['phi'] - phase).abs()
    row = color_map.loc[distance.idxmin()]
    return to_8bit_rgb(row['r']), to_8bit_rgb(row['g']), to_8bit_rgb(row['b'])


def get_color(db, tid):
    """Return the plotly colour string 'rgb(r,g,b)' for tree ``tid``.

    The tid is mapped to its ROI via ``db.mapping.sbem2roi``. That ROI's Phase
    and Alpha are read from row ``roi - 1`` of ``data_values``. The nearest
    ``color_map`` colour is then blended toward grey by ``alpha_grey_convert``.
    Lookup failures (unmapped tid, out-of-range ROI) propagate to the caller.
    """
    roi = db.mapping.sbem2roi[tid]
    # data_values is built with ROI n stored at position n - 1.
    row = data_values.iloc[roi - 1]
    r, g, b = alpha_grey_convert(*_phase_to_rgb(row['Phase']), row['Alpha'])
    return f'rgb({r},{g},{b})'


def get_color_bar(phase):
    """Return the ``color_map`` colour nearest ``phase`` as 'rgb(r,g,b)' (no alpha blending)."""
    r, g, b = _phase_to_rgb(phase)
    return f'rgb({r},{g},{b})'


In [5]:
# Group synapse coordinates by presynaptic tree, counting each synapse id only once.
synapses = {}  # presynaptic tid -> list of [x, y, z]
sids = {}      # synapse ids already seen (value is always True)
for syn_x, syn_y, syn_z, pre_tid, syn_id in zip(xx, yy, zz, pretid, synid):
    if syn_id in sids:
        continue
    synapses.setdefault(pre_tid, []).append([syn_x, syn_y, syn_z])
    sids[syn_id] = True

In [7]:
# Colorbar tick labels and their positions: five evenly spaced ticks from 1 to 9.
color_names = ['0°', '-90°', '±180°', '90°', '0°']
color_vals = np.linspace(1, 9, num=len(color_names))

def _phase_colorscale():
    """Return the colorbar colour list: one 'rgb(r,g,b)' string per ``color_map`` row, in file order.

    Built from the parsed ``color_map`` DataFrame instead of re-parsing the raw
    CSV ``lines``. The previous loop iterated ``range(1, len(lines) - 1)``,
    which always skipped the last line of the file. It also re-matched
    ``float(text)`` against ``color_map.phi`` by exact equality, silently
    dropping any row whose parsed value differed.
    """
    return [
        f'rgb({to_8bit_rgb(r)},{to_8bit_rgb(g)},{to_8bit_rgb(b)})'
        for r, g, b in color_map[['r', 'g', 'b']].itertuples(index=False)
    ]


def _colorbar_trace(colorscale):
    """Return an invisible Scatter3d whose only purpose is to draw the phase colorbar.

    The synapse traces use fixed 'rgb(...)' marker colours. Plotly draws a marker
    colorbar only when ``marker.showscale`` is True and ``marker.color`` is
    numeric, so the colorbar was never shown when attached to a synapse trace.
    Here it hangs off a single hidden point whose numeric colour range spans
    ``color_vals``, so the tick labels land on ``tickvals``.
    """
    return go.Scatter3d(
        x=[None], y=[None], z=[None],
        mode='markers',
        showlegend=False,
        hoverinfo='skip',
        marker=dict(
            color=[color_vals[0]],
            cmin=color_vals[0],
            cmax=color_vals[-1],
            colorscale=colorscale,
            showscale=True,
            size=0,
            colorbar=dict(
                title="Phase of coherence",
                x=1.18,
                y=0.5,
                tickvals=color_vals,
                ticktext=color_names,
            ),
        ),
    )


def _synapse_trace(tid, coords):
    """Return the marker trace for the synapses made by presynaptic tree ``tid``.

    ``coords`` is a list of [x, y, z] triples. The colour comes from
    ``get_color(db, tid)``; any lookup error propagates to the caller.
    """
    color = get_color(db, tid)  # computed once (was evaluated twice per trace)
    return go.Scatter3d(
        x=[c[0] for c in coords],
        y=[c[1] for c in coords],
        z=[c[2] for c in coords],
        mode='markers',
        surfacecolor=color,
        hoverinfo='text',
        hovertext=f'{tid}',
        name=f'{tid:4} - {db.mapping.sbem2can[tid]}',
        marker=dict(
            color=color,
            size=4.5,
            opacity=1,
        ),
    )


def plot_graph(synapses, min_synapses=1,
               filename='synapses_coherencePhase_LocalBend.html', auto_open=True):
    """Plot synapses onto tree 444 in 3D, coloured by the presynaptic tree's coherence phase.

    Traces, in order: the 'DE3' line (module-level ``x``, ``y``, ``z``), one marker
    trace per presynaptic tree, the soma nodes of tree 444 (``typ==1``), and an
    invisible trace carrying the phase colorbar.

    Parameters
    ----------
    synapses : dict
        Presynaptic tid -> list of [x, y, z] synapse coordinates.
    min_synapses : int, default 1
        Trees with fewer synapses than this are not drawn.
    filename : str
        HTML file written by ``plotly.offline.plot``.
    auto_open : bool, default True
        Forwarded to ``plotly.offline.plot`` (open the HTML file in a browser).

    Returns
    -------
    plotly.graph_objects.Figure
        The figure, which is also written to ``filename`` and shown via ``fig.show()``.

    Trees whose trace cannot be built (e.g. no ROI mapping) are skipped and
    listed in a printed summary instead of disappearing silently. Relies on
    module-level ``db``, ``x``, ``y``, ``z``, ``color_map``, ``color_vals`` and
    ``color_names``.
    """
    each_tree_scatter = []
    skipped = {}
    for tid, coords in synapses.items():
        if len(coords) < min_synapses:
            continue
        try:
            each_tree_scatter.append(_synapse_trace(tid, coords))
        except Exception as err:  # best effort: one unmappable tree must not abort the figure
            skipped[tid] = repr(err)
    if skipped:
        # print rather than warnings.warn: the notebook's first cell ignores all warnings
        print(f'plot_graph: skipped {len(skipped)} tree(s): {skipped}')

    nodes = db.nodexyz('tid==444 and typ==1')  # (x, y, z, nid)
    soma = go.Scatter3d(x=nodes[0], y=nodes[1], z=nodes[2],
                        name='Soma',
                        hoverinfo=[],
                        hovertext=[],
                        mode='markers',
                        marker=dict(
                            color="rgb(0,0,0)",
                            size=6,
                            opacity=1,
                        ))
    de3 = go.Scatter3d(x=x, y=y, z=z,
                       mode='lines',
                       hoverinfo='text',
                       hovertext=[],
                       opacity=0.7,
                       name='DE3',
                       marker=dict(
                           color=1,
                           size=6,
                           colorscale='Viridis',
                           opacity=1,
                       ))

    traces = [de3] + each_tree_scatter + [soma]
    colorscale = _phase_colorscale()
    if len(colorscale) >= 2:  # an evenly spaced colorscale needs at least two colours
        traces.append(_colorbar_trace(colorscale))

    fig = go.Figure(data=traces)
    fig.update_layout(
        scene=dict(
            xaxis=dict(nticks=5, range=[100, 300], showbackground=False, showticklabels=False, title=''),
            yaxis=dict(nticks=13, range=[150, 750], showbackground=False, showticklabels=False, title=''),
            zaxis=dict(nticks=6, range=[100, 350], showbackground=False, showticklabels=False, title=''),
            aspectmode='data'
        )
    )

    plotly.offline.plot(fig, filename=filename, auto_open=auto_open)
    fig.show()
    return fig
    
# Render the figure for every presynaptic tree with at least one synapse (the default threshold).
plot_graph(synapses, min_synapses=1);