In [None]:
%matplotlib notebook
import random as rd
from pathlib import Path

import numpy as np
import pandas as pd
from scipy.optimize import minimize
import matplotlib.pyplot as plt
import LFPy
import brainsignals.neural_simulations as ns
from brainsignals.neural_simulations import return_hay_cell
import neuron
np.random.seed(5678)

In [None]:
def insert_synaptic_input(idx, cell, syn_scale):
    """Attach one Exp2Syn synapse to segment `idx` of `cell`, firing once at t = 1.

    Parameters
    ----------
    idx : int
        Index of the cell segment that receives the synapse.
    cell : LFPy.Cell
        Cell the synapse is attached to.
    syn_scale : float
        Multiplier on the base synaptic weight of 0.002.

    Returns
    -------
    tuple
        ``(synapse, cell)`` -- the new LFPy.Synapse and the same cell object.
    """
    synapse = LFPy.Synapse(
        cell,
        idx=idx,
        e=0.,                      # reversal potential
        weight=0.002 * syn_scale,  # base weight scaled by syn_scale
        record_current=True,       # record synapse current
        syntype='Exp2Syn',
        tau1=1,                    # rise time constant
        tau2=3,                    # decay time constant
    )
    # A single presynaptic spike at t = 1 (simulation time units, presumably ms).
    synapse.set_spike_times(np.array([1.]))
    return synapse, cell

In [None]:
# Simulate the active Hay cell driven by one strong synapse, then keep only
# the [t0, t1) window of the recordings.
tstop = 10
# Time window to extract spike from:
t0 = 3
t1 = 8

dt = 2**-6
cell = return_hay_cell(tstop=tstop, dt=dt, make_passive=False)
# NOTE(review): presumably re-orients the morphology so the axon points
# downward -- confirm in brainsignals.neural_simulations.
ns.point_axon_down(cell)
# Synapse on segment 0 (presumably the soma -- confirm) with 20x the base weight.
syn, cell = insert_synaptic_input(0, cell, 20)
cell.simulate(rec_imem=True, rec_vmem=True)
# Sample indices closest to t0 and t1.
t0_idx = np.argmin(np.abs(cell.tvec - t0))
t1_idx = np.argmin(np.abs(cell.tvec - t1))

# Crop the recordings in place (t1_idx itself is excluded). tvec is cropped
# last and shifted so the window starts at 0; this relies on tvec still being
# uncropped when cell.tvec[t0_idx] is read.
cell.vmem = cell.vmem[:, t0_idx:t1_idx]
cell.imem = cell.imem[:, t0_idx:t1_idx]
cell.tvec = cell.tvec[t0_idx:t1_idx] - cell.tvec[t0_idx]
# Membrane potential of segment 0 over the cropped window.
fig = plt.figure()
plt.plot(cell.tvec, cell.vmem[0])
fig.show()

In [None]:
# Find the time step with the largest current dipole moment (largest
# |p_x| + |p_y| + |p_z|). NOTE(review): this row is presumably the source of
# the p_near vector used in the localisation experiment -- confirm.
cdm = LFPy.CurrentDipoleMoment(cell)
# P holds the dipole moment over time with one row per component (p_x, p_y,
# p_z), as the DataFrame construction below assumes.
P = cdm.get_transformation_matrix() @ cell.imem
print(np.max(np.abs(P)))

# One row per time step; 'abs_sum' is the L1 norm of the dipole moment.
p_df = pd.DataFrame(P.T, columns = ['p_x', 'p_y', 'p_z'])
p_df_abs = p_df.abs()
p_df['abs_sum'] = p_df_abs.sum(axis=1)
max_sum = p_df['abs_sum'].idxmax()
# Display the dipole-moment components at that time step.
p_df.loc[max_sum]

In [None]:
# generate disk electrode

def generate_disk_points(N, p0, n, disk_radius):
    """Sample `n` points uniformly over a flat disk.

    The disk is centred on `p0` and lies in the plane perpendicular to `N`.

    Parameters
    ----------
    N : array_like, shape (3,)
        Disk normal; may have any non-zero length.
    p0 : array_like, shape (3,)
        Disk centre.
    n : int
        Number of points to sample.
    disk_radius : float
        Disk radius (its absolute value is used). A radius of 0 places every
        point at `p0`.

    Returns
    -------
    numpy.ndarray, shape (n, 3)
        Coordinates of the sampled points.

    Raises
    ------
    ValueError
        If `N` is the zero vector.
    """
    N = np.asarray(N, dtype=float)
    p0 = np.asarray(p0, dtype=float)
    if not np.any(N):
        raise ValueError("disk normal N must be a non-zero vector")
    radius = abs(disk_radius)

    # All random numbers come from NumPy's global generator so that
    # np.random.seed(...) makes the electrode geometry reproducible; the
    # stdlib `random` module is not seeded alongside it.

    # In-plane orthonormal basis: crossing N with a random vector gives a
    # direction perpendicular to N, crossing again gives the second one.
    # Redraw in the (practically impossible) case the random vector is
    # parallel to N.
    while True:
        alfa = np.cross(N, np.random.uniform(0, 1, size=3))
        if np.linalg.norm(alfa) > 1e-12 * np.linalg.norm(N):
            break
    beta = np.cross(N, alfa)
    basis = np.stack([alfa / np.linalg.norm(alfa), beta / np.linalg.norm(beta)])

    C = np.empty((n, 3))
    if radius == 0:
        # Degenerate disk: rejection sampling below could never accept a point.
        C[:] = p0
        return C

    # Rejection sampling in batches: draw (u, v) uniformly in the bounding
    # square and keep pairs strictly inside the circle. About pi/4 of draws
    # are accepted, so oversampling by 2x usually fills C in a single pass.
    filled = 0
    while filled < n:
        needed = n - filled
        uv = np.random.uniform(-radius, radius, size=(2 * needed, 2))
        accepted = uv[(uv ** 2).sum(axis=1) < radius ** 2][:needed]
        C[filled:filled + len(accepted)] = accepted @ basis + p0
        filled += len(accepted)
    return C

In [None]:
def gen_random_disk_set(n_elecs, disk_radius, n, x0, x1):
    """Generate `n_elecs` disk electrodes with random centres and orientations.

    Each disk gets a normal drawn uniformly from the six axis-aligned unit
    vectors (+-x, +-y, +-z) and a centre drawn uniformly from [x0, x1)^3;
    `n` points are then sampled on it with `generate_disk_points`.

    Parameters
    ----------
    n_elecs : int
        Number of disks.
    disk_radius : float
        Radius of every disk.
    n : int
        Points sampled per disk.
    x0, x1 : float
        Bounds of the cube the disk centres are drawn from.

    Returns
    -------
    numpy.ndarray, shape (n_elecs * n, 3)
        Points of all disks, stacked disk by disk; shape (0, 3) when
        `n_elecs` is 0.
    """
    N_vecs = np.array([(1, 0, 0), (0, 1, 0), (0, 0, 1), (-1, 0, 0), (0, -1, 0), (0, 0, -1)])

    disks = []
    for _ in range(n_elecs):
        # Same draw order as before (normal first, then centre).
        N = N_vecs[np.random.randint(len(N_vecs))]
        c0 = np.random.uniform(x0, x1, size=3)
        disks.append(generate_disk_points(N, c0, n, disk_radius))

    # Stack once rather than growing the array inside the loop (quadratic
    # copying); also avoids an unbound result when n_elecs == 0.
    if not disks:
        return np.empty((0, 3))
    return np.vstack(disks)

In [None]:
def gen_disk_set(n_elecs_set, disk_radius, set_width, n, x0, x1):
    """Generate `n_elecs_set` randomly placed sets of four disk electrodes.

    Each set is centred at a point drawn uniformly from [x0, x1)^3. Its four
    disks sit at distance ``set_width / 2`` from that point along +x, -x, +y
    and -y, each with its normal equal to that offset direction, and each
    carries `n` sampled points.

    Parameters
    ----------
    n_elecs_set : int
        Number of four-disk sets.
    disk_radius : float
        Radius of every disk.
    set_width : float
        Distance between opposite disks of a set.
    n : int
        Points sampled per disk.
    x0, x1 : float
        Bounds of the cube the set centres are drawn from.

    Returns
    -------
    numpy.ndarray, shape (4 * n_elecs_set * n, 3)
        Points of all disks, set by set and within a set in the order
        +x, -x, +y, -y; shape (0, 3) when `n_elecs_set` is 0.
    """
    normal_vec_set = np.array([(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0)])
    set_radius = set_width / 2

    disks = []
    for _ in range(n_elecs_set):
        pos_elec_set = np.random.uniform(x0, x1, size=3)
        for N in normal_vec_set:
            center = pos_elec_set + set_radius * N
            disks.append(generate_disk_points(N, center, n, disk_radius))

    # Stack once rather than growing the array inside the loop (quadratic
    # copying); also avoids an unbound result when n_elecs_set == 0.
    if not disks:
        return np.empty((0, 3))
    return np.vstack(disks)

In [None]:
# generate cylinder electrode 
# =======================================================================
# Input parameters:
# =======================================================================
# start_z       : starting point of cylinder height   
# end_z         : ending point of cylinder height 
# mid_x         : midpoint x-axis
# mid_y         : midpoint y-axis
# radius        : cylinder radius
# num_points    : number of points to constitute the cylinder
# test_scale    : parameter for testing, default as 1
# =======================================================================
# Outputs       : C, a numpy array of coordinates with shape (num_points,3) 

def gen_sylinder_points(start_z, end_z, mid_x, mid_y, radius, num_points, test_scale=1):
    """Sample points on the curved surface of a z-aligned cylinder.

    Heights are drawn uniformly from ``[test_scale*start_z, test_scale*end_z)``
    and angles uniformly from ``[0, 2*pi)``. The cylinder axis passes through
    ``(mid_x, mid_y)``.

    Parameters
    ----------
    start_z, end_z : float
        Lower and upper end of the cylinder.
    mid_x, mid_y : float
        Position of the cylinder axis.
    radius : float
        Cylinder radius.
    num_points : int
        Number of points to sample.
    test_scale : float, optional
        Scales the z-range and the radius, but not the axis position.
        Defaults to 1 (no scaling).

    Returns
    -------
    numpy.ndarray, shape (num_points, 3)
        The sampled (x, y, z) coordinates.
    """
    # Draw heights before angles so seeded results match earlier runs.
    z_ = np.random.uniform(test_scale * start_z, test_scale * end_z, num_points)
    theta = np.random.uniform(0, 2 * np.pi, num_points)
    scaled_radius = radius * test_scale
    x_ = mid_x + scaled_radius * np.cos(theta)
    y_ = mid_y + scaled_radius * np.sin(theta)
    return np.column_stack((x_, y_, z_))

In [None]:
# Generate sets of four disk electrodes at the given set positions; also return the centre of every disk
def g_d_s(disk_radius, set_radius, n_points, n_elecs_set, center_pos):
    """Generate sets of four disk electrodes around fixed set centres.

    For each of the first `n_elecs_set` entries of `center_pos`, four disks
    are placed at distance `set_radius` along +x, -x, +y and -y. Each disk's
    normal equals its offset direction, and each disk carries `n_points`
    points.

    Parameters
    ----------
    disk_radius : float
        Radius of every disk.
    set_radius : float
        Distance from a set centre to each of its disks.
    n_points : int
        Points sampled per disk.
    n_elecs_set : int
        Number of sets (entries of `center_pos` used).
    center_pos : sequence of length-3 array_like
        Set centres.

    Returns
    -------
    elec_coords : numpy.ndarray, shape (4 * n_elecs_set * n_points, 3)
        Points of all disks, set by set and within a set in the order
        +x, -x, +y, -y; shape (0, 3) when `n_elecs_set` is 0.
    centers : list of numpy.ndarray
        Centre of each disk, in the same order as the blocks of `elec_coords`.
    """
    normal_vec_set = np.array([(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0)])
    centers = []
    disks = []
    for i in range(n_elecs_set):
        pos_elec_set = center_pos[i]
        for N in normal_vec_set:
            center = pos_elec_set + set_radius * N
            disks.append(generate_disk_points(N, center, n_points, disk_radius))
            centers.append(center)

    # Stack once rather than growing the array inside the loop (quadratic
    # copying); also avoids an unbound result when n_elecs_set == 0.
    elec_coords = np.vstack(disks) if disks else np.empty((0, 3))
    return elec_coords, centers

In [None]:
# Generate a set of cylinder electrodes at the given positions
def g_s_s(n_elecs, n_points, syl_height, syl_width, center_locs, test_scale=1):
    """Generate z-aligned cylinder electrodes centred on fixed locations.

    Each of the first `n_elecs` entries of `center_locs` is the midpoint
    (x, y, z) of one cylinder of height `syl_height` and diameter
    `syl_width`. Each cylinder gets `n_points` points via
    `gen_sylinder_points`.

    Parameters
    ----------
    n_elecs : int
        Number of cylinders.
    n_points : int
        Points sampled per cylinder.
    syl_height : float
        Cylinder height.
    syl_width : float
        Cylinder diameter.
    center_locs : sequence of length-3 array_like
        Cylinder midpoints.
    test_scale : float, optional
        Passed through to `gen_sylinder_points`; defaults to 1.

    Returns
    -------
    numpy.ndarray, shape (n_elecs * n_points, 3)
        Points of all cylinders, stacked cylinder by cylinder; shape (0, 3)
        when `n_elecs` is 0.
    """
    cylinders = []
    for i in range(n_elecs):
        mid_x, mid_y, mid_z = center_locs[i][0], center_locs[i][1], center_locs[i][2]
        start_z = mid_z - (syl_height / 2)
        end_z = start_z + syl_height
        cylinders.append(
            gen_sylinder_points(start_z, end_z, mid_x, mid_y, (syl_width / 2), n_points, test_scale)
        )

    # Stack once rather than growing the array inside the loop (quadratic
    # copying); also avoids an unbound result when n_elecs == 0.
    if not cylinders:
        return np.empty((0, 3))
    return np.vstack(cylinders)

In [None]:
# plot set of spatially extended electrodes 

def plot_elecs(elec_coords, dipole_pos):
    """Show electrode points and one dipole position in a new 3-D figure.

    Parameters
    ----------
    elec_coords : sequence of length-3 points
        Electrode point coordinates.
    dipole_pos : length-3 array_like
        Dipole position, drawn with the next colour in the cycle.

    Returns
    -------
    The 3-D Axes, so callers can add further artists.
    """
    xs = [point[0] for point in elec_coords]
    ys = [point[1] for point in elec_coords]
    zs = [point[2] for point in elec_coords]

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.set_aspect(aspect='auto')
    for axis_name in 'xyz':
        getattr(ax, f'set_{axis_name}label')(axis_name)

    ax.scatter(xs, ys, zs)
    ax.scatter(dipole_pos[0], dipole_pos[1], dipole_pos[2])
    return ax

In [None]:
def plot_res_two(elec_coords, dipole_pos_1, dipole_pos_2, opt, filename, title):
    """Plot electrodes, two true dipole positions and a fitted one, and save it.

    Parameters
    ----------
    elec_coords : sequence of length-3 points
        Electrode point coordinates.
    dipole_pos_1, dipole_pos_2 : length-3 array_like
        True dipole positions (next colours in the cycle).
    opt : length-3 array_like
        Estimated dipole position, drawn in red.
    filename : str
        The figure is written to ``f'{filename}.png'``.
    title : str
        Figure suptitle.

    Returns
    -------
    The 3-D Axes.
    """
    xs = [point[0] for point in elec_coords]
    ys = [point[1] for point in elec_coords]
    zs = [point[2] for point in elec_coords]

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.set_aspect(aspect='auto')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')

    ax.scatter(xs, ys, zs)
    for true_pos in (dipole_pos_1, dipole_pos_2):
        ax.scatter(true_pos[0], true_pos[1], true_pos[2])
    ax.scatter(opt[0], opt[1], opt[2], color = 'red')

    fig.suptitle(title, fontsize=10)
    fig.savefig(f'{filename}.png', bbox_inches='tight')
    return ax

In [None]:
def plot_res(elec_coords, dipole_pos, opt, filename=None, title=None):
    """Plot electrodes, the true dipole position and the fitted position.

    Parameters
    ----------
    elec_coords : sequence of length-3 points
        Electrode point coordinates.
    dipole_pos : length-3 array_like
        True dipole position (next colour in the cycle).
    opt : length-3 array_like
        Estimated dipole position, drawn in red.
    filename : str, optional
        If given, the figure is saved as ``f'{filename}.png'``, matching
        `plot_res_two`.
    title : str, optional
        If given, used as the figure suptitle.

    Returns
    -------
    The 3-D Axes.
    """
    X = [point[0] for point in elec_coords]
    Y = [point[1] for point in elec_coords]
    Z = [point[2] for point in elec_coords]

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.set_aspect(aspect='auto')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')
    ax.scatter(X, Y, Z)
    ax.scatter(dipole_pos[0], dipole_pos[1], dipole_pos[2])
    ax.scatter(opt[0], opt[1], opt[2], color = 'red')

    # Title/saving used to be commented-out code; they are now opt-in.
    if title is not None:
        fig.suptitle(title, fontsize=10)
    if filename is not None:
        fig.savefig(f'{filename}.png', bbox_inches='tight')
    return ax

In [None]:
def V_e_avg(V_e, n_points, n_elecs):
    """Average consecutive blocks of `n_points` values, one block per electrode.

    Block ``k`` is ``V_e[k*n_points:(k+1)*n_points]``; the result is a float
    array of length `n_elecs` holding each block's sum divided by `n_points`.
    """
    averages = np.zeros(shape=(n_elecs, ))
    for k in range(n_elecs):
        block = V_e[k * n_points:(k + 1) * n_points]
        averages[k] = np.sum(block) / n_points
    return averages

In [None]:
# method for calculating extracellular potential through dipole approximation
# ===========================================================================
# Input parameters:
# ===========================================================================
# x           :  3D coordinates of dipole
# elec_locs   :  coordinates for points constituting the electrode surfaces
# p           :  dipole moment
# n_points    :  number of points on each electrode surface
# n_elecs     :  number of electrode surfaces
# ===========================================================================
# Outputs     : average extracellular potential for each electrode surface
# ===========================================================================

sigma = 0.3  # extracellular conductivity used below (units presumably S/m -- confirm)
def dipole_potential(x, elec_locs, p, n_points, n_elecs):
    """Electrode-averaged potential of a point current dipole.

    Evaluates the point-dipole formula

        V(r) = p . (r - x) / (4 * pi * sigma * |r - x|**3)

    at every row of `elec_locs`, then averages each block of `n_points`
    consecutive values (one block per electrode) with `V_e_avg`.

    Returns
    -------
    numpy.ndarray of length `n_elecs`.
    """
    displacement = elec_locs - x
    distance = np.linalg.norm(displacement, axis=1)
    # NOTE(review): an electrode point exactly at x gives a division by zero.
    point_potentials = displacement @ p.T / (4 * np.pi * sigma * distance ** 3)
    return V_e_avg(point_potentials, n_points, n_elecs)

In [None]:
def dipole_potential_min(test_pos, elec_locs, V_e_measured, p, n_points, n_elecs):
    """Objective for locating a dipole with known moment `p`.

    For each electrode, the misfit between the model potential at `test_pos`
    and `V_e_measured` is divided by ``sqrt(|measured * model|)``. The
    squared normalised misfits are summed and returned as a scalar.
    """
    V_e_model = dipole_potential(test_pos, elec_locs, p, n_points, n_elecs)
    # Normalizing seemed to help. NOTE(review): a zero measured or model
    # potential makes this scale zero (inf/nan residual).
    scale = np.sqrt(np.abs(V_e_measured * V_e_model))
    normalized_residuals = (V_e_model - V_e_measured) / scale
    return np.sum(normalized_residuals ** 2)

In [None]:
# Localisation experiment. For each disk-electrode radius, compute the
# electrode potentials from one distant dipole (p_far) plus three nearby
# dipoles (p_near, acting as biological noise). Then try to recover the
# distant dipole's position by fitting a single dipole with moment p_far.
num_trials = 20
box_width = 20000

x0 = -box_width / 2
x1 = box_width / 2

# NOTE(review): p_near presumably comes from the peak dipole-moment row found
# in the Hay-cell cell above -- confirm.
p_near = np.array([196.481, -78.868, 795.242])
p_far = np.array([1., 1., 1.]) * 1e8

set_radius = 1000
n_points_disk = 1000
n_points_point = 1
n_elecs_set = 5
n_elecs_tot = n_elecs_set*4  # each set holds four disks
elec_set_center_positions = np.array([(4000, -4000, -2000), (-4000, 4000, 1500), (4000, 4000, 4000),
                                      (-4000, -4000, 2000), (0, -4000, -4000)])

# Only the disk centres (pts) are used: the near dipoles are placed 20 units
# away from three of the disks.
dks, pts = g_d_s(10, set_radius, n_points_disk, n_elecs_set, elec_set_center_positions)
dipole_near_pos = pts[18]+np.array([0, 20, 0])
dipole_near_pos_2 = pts[10]+np.array([0, 20, 0])
dipole_near_pos_3 = pts[1]+np.array([-20, 0, 0])

dipole_far_pos = np.array([8000, 8500, -8500])
disk_radii = np.linspace(0.1, 300, num=22)

# disk_errors_tot[trial, radius_index] = distance between true and fitted
# far-dipole position.
disk_errors_tot = np.zeros(shape=(num_trials, len(disk_radii)))
# One list of failure records (error > 10) per radius.
disk_fails_tot = []

for i, disk_radius in enumerate(disk_radii):
    disk_errors = np.zeros(num_trials)  # NOTE(review): never filled; results go to disk_errors_tot
    disk_fails = []
    disks, points = g_d_s(disk_radius, set_radius, n_points_disk, n_elecs_set, elec_set_center_positions)
    V_e_measured_near = dipole_potential(dipole_near_pos, disks, p_near, n_points_disk, n_elecs_tot)
    V_e_measured_near_2 = dipole_potential(dipole_near_pos_2, disks, p_near, n_points_disk, n_elecs_tot)
    V_e_measured_near_3 = dipole_potential(dipole_near_pos_3, disks, p_near, n_points_disk, n_elecs_tot)
    V_e_measured_far = dipole_potential(dipole_far_pos, disks, p_far, n_points_disk, n_elecs_tot)
    V_e_measured = V_e_measured_near + V_e_measured_far + V_e_measured_near_2 + V_e_measured_near_3

    for trial_idx in range(num_trials):

        retry = True
        disk_counter = 0
        while retry:
            dipole_pos0 = np.random.uniform(x0, x1, size=(3))
            disk_opt = minimize(dipole_potential_min, dipole_pos0,
                        options = {'maxiter': 50000, 'maxls': 100, 'gtol': 1e-12},
                        method='L-BFGS-B',
                        args =(disks, V_e_measured, p_far, n_points_disk, n_elecs_tot),
                        bounds=((x0*1.2, x1*1.2), (x0*1.2, x1*1.2), (x0*1.2, x1*1.2)), tol=1e-12
                        )
            # Judge the fit against the far dipole's own potentials (noise-free).
            V_e_best_guess_disk = dipole_potential(disk_opt.x, disks, p_far, n_points_disk, n_elecs_tot)
            rel_error_disk = np.mean(np.abs((V_e_measured_far - V_e_best_guess_disk) / V_e_measured_far))

            if (rel_error_disk < 0.05) or (disk_counter > 100):
                # If it doesn't work well, we just retry with a different initial guess
                retry = False

            disk_counter += 1
        disk_errors_tot[trial_idx, i] = np.sqrt(np.sum((dipole_far_pos - disk_opt.x)**2))

        if disk_errors_tot[trial_idx, i] > 10:
            disk_fails.append([disks, dipole_far_pos, disk_errors_tot[trial_idx, i], disk_opt.x, V_e_best_guess_disk,
                          V_e_measured_far, dipole_pos0, disk_opt, disk_radius])

    # One entry per radius. This was previously indented inside the trial
    # loop, which appended the same list num_trials times per radius.
    disk_fails_tot.append(disk_fails)

In [None]:
# Persist the localisation errors as text: one row per trial, one column per
# entry of disk_radii.
np.savetxt('disk_errors_bio_noise_rerun_orig_radii.txt', disk_errors_tot)

In [None]:
# Scatter the log10 localisation error of every trial against disk radius.
fig_name = 'bio_noise_errors_r_300'
fig_folder = 'b_noise_rerun'
# Root folder for saved figures. Relative to the working directory instead of
# a hard-coded, user-specific absolute path; change here if needed.
fig_root = Path('.')

fig, ax = plt.subplots()
ax.set_title('Errors in locating distant neuron with biological noise present')
ax.set_xlabel('electrode radius (µm)')
ax.set_ylabel('log10 localization error (µm)')
# No legend: the per-trial scatters carry no labels, so a legend would be empty.
for i in range(num_trials):
    ax.scatter(disk_radii, np.log10(disk_errors_tot[i]), s=10)

fig_dir = fig_root / fig_folder
fig_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(fig_dir / fig_name, bbox_inches='tight')

In [None]:
# Tabulate the errors: one column per disk radius, one row per trial.
d = {
    f'errors_radius_{radius}': [trial_errors[i] for trial_errors in disk_errors_tot]
    for i, radius in enumerate(disk_radii)
}
df = pd.DataFrame(d)
df.describe()