In [None]:
import os
# Pin every BLAS / OpenMP / numexpr backend to a single thread (the shell
# equivalent is `export OMP_NUM_THREADS=1`, etc.). These are set before
# `import numpy` below because the backends typically read them at load time.
# Values must be strings: os.environ only accepts str.
os.environ.update({
    "OMP_NUM_THREADS": "1",
    "OPENBLAS_NUM_THREADS": "1",
    "MKL_NUM_THREADS": "1",
    "VECLIB_MAXIMUM_THREADS": "1",
    "NUMEXPR_NUM_THREADS": "1",
})
# Expose only GPU 2 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = '2'

import numpy as np
import matplotlib.pyplot as plt
import multielec_src.fitting as fitting
import multielec_src.multielec_utils as mutils
import multielec_src.old_labview_data_reader as oldlv
from scipy.io import loadmat
import multiprocessing as mp
import statsmodels.api as sm
from copy import deepcopy, copy
import visionloader as vl
from mpl_toolkits.mplot3d import Axes3D
from gsort.gsort_core_scripts import *
from gsort.gsort_data_loader import *
from estim_utils.triplet import get_triplet_index

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
RUN_TYPE = "triplet"

# --- gsort arguments ---
# Patterns to analyse, each built by get_triplet_index from three electrode
# numbers (presumably a triplet-stimulation pattern index — see estim_utils).
ELECTRODE_ORDERING = [
    get_triplet_index(23, 32, 35),
    get_triplet_index(490, 491, 481),
    get_triplet_index(488, 489, 479),
]

# Visual-stimulus (vision) analysis used to identify cells.
VISUAL_ANALYSIS_BASE = "/Volumes/Analysis"
DATASET = "2023-06-14-1"
VSTIM_DATARUN = "streamed/data000"

# Electrical-stimulation run whose responses are gsorted.
ESTIM_ANALYSIS_BASE = "/Volumes/Analysis"
ESTIM_DATARUN = "data003"
estim_analysis_path = os.path.join(ESTIM_ANALYSIS_BASE, DATASET, ESTIM_DATARUN)

NOISE_THRESH = 6

CELL_TYPES = ['nc']
EXCLUDED_TYPES = ['bad', 'dup']

# Everything gsort needs, loaded once and shared (via module globals) with the
# multiprocessing workers below.
(cell_data_dict, cells_to_gsort_dict, mutual_cells, total_cell_to_electrode_list,
 END_TIME_LIMIT, START_TIME_LIMIT, CLUSTER_DELAY, NOISE) = load_vision_data_for_gsort(
    RUN_TYPE, VISUAL_ANALYSIS_BASE, DATASET, VSTIM_DATARUN,
    noise_thresh=NOISE_THRESH,
    cell_types=CELL_TYPES,
    excluded_types=EXCLUDED_TYPES,
    patterns=np.array(ELECTRODE_ORDERING),
)

In [None]:
def generate_noise(signals, factor=0, noise=None, rng=None):
    """Draw zero-mean Gaussian noise shaped like ``signals``, scaled by ``factor``.

    Parameters
    ----------
    signals : ndarray, shape (n_trials, n_channels, n_samples)
        Only its shape is used.
    factor : float, default 0
        Multiplier applied to the noise. With the default of 0 the result is
        all zeros, i.e. noise injection is disabled.
    noise : array-like of length n_channels (or scalar), optional
        Per-channel standard deviation, broadcast along axis 1 of ``signals``.
        Defaults to the module-level ``NOISE`` loaded for gsort.
    rng : numpy.random.Generator, optional
        Random source. Defaults to the legacy global ``np.random`` state.

    Returns
    -------
    ndarray of float64 with the same shape as ``signals``.
    """
    signals_shape = np.shape(signals)
    if factor == 0:
        # Avoid drawing n_trials * n_channels * n_samples normals only to zero them.
        return np.zeros(signals_shape)

    if noise is None:
        noise = NOISE
    if rng is None:
        rng = np.random

    # Per-channel std as shape (1, n_channels, 1) so it broadcasts over trials and samples.
    scale = np.asarray(noise, dtype=float).reshape(1, -1, 1)
    return rng.normal(loc=0.0, scale=scale, size=signals_shape) * factor

In [None]:
def gsort_subsampled(p, k, num_trials, rng=None):
    """Resample ``num_trials`` trials for pattern ``p`` / amplitude index ``k`` and gsort them.

    Parameters
    ----------
    p : pattern identifier
        Key into ``cells_to_gsort_dict`` and pattern argument for the data reader.
    k : int
        Amplitude index passed to ``oldlv.get_oldlabview_pp_data``.
    num_trials : int
        Number of trials to draw, with replacement, from the recorded trials.
    rng : numpy.random.Generator, optional
        Random source for the resampling. Defaults to ``np.random.default_rng()``,
        a fresh, OS-entropy-seeded generator per call. This function runs inside
        ``multiprocessing`` workers; with the ``fork`` start method every worker
        (and every newly created pool) inherits an identical copy of the legacy
        global ``np.random`` state, so ``np.random.choice`` produced correlated
        resamples across workers and across successive allocation calls.

    Returns
    -------
    (result, cells)
        ``result`` is what ``run_pattern_movie_live`` returns (presumably one
        activation probability per cell — TODO confirm), or zeros when
        ``num_trials == 0``. ``cells`` is the array of cell IDs gsorted for ``p``.
    """
    cells = np.array(cells_to_gsort_dict[p])
    if num_trials == 0:
        # Nothing to sort: report zero for every candidate cell.
        return np.zeros(len(cells)), cells

    if rng is None:
        rng = np.random.default_rng()

    signal = oldlv.get_oldlabview_pp_data(estim_analysis_path, p, k)
    preloaded_data = (cell_data_dict, cells_to_gsort_dict[p], mutual_cells,
                      total_cell_to_electrode_list, END_TIME_LIMIT, START_TIME_LIMIT,
                      CLUSTER_DELAY, NOISE)

    # Resample recorded trials (axis 0 of `signal`) with replacement.
    subsampled = signal[rng.choice(len(signal), size=num_trials, replace=True)]
    subsampled = subsampled + generate_noise(subsampled)

    return run_pattern_movie_live(subsampled, preloaded_data), cells

In [None]:
def gsort_trial_allocation(trials, patterns, cells, NUM_THREADS=24):
    """Run ``gsort_subsampled`` for every (pattern, amplitude) cell of ``trials`` in parallel.

    Parameters
    ----------
    trials : ndarray, shape (len(patterns), n_amps)
        Number of trials to resample at each pattern/amplitude (cast to int).
    patterns : sequence
        Pattern identifiers, one per row of ``trials``.
    cells : 1-D array-like
        **Sorted** cell IDs; defines the first axis of the output.
    NUM_THREADS : int
        Number of worker processes.

    Returns
    -------
    all_probs : ndarray, shape (len(cells), len(patterns), n_amps)
        gsort output per cell; entries stay 0 for cells not sorted on a pattern.

    Raises
    ------
    AssertionError
        If ``trials`` and ``patterns`` differ in length.
    ValueError
        If gsort reports a cell ID that is not present in ``cells``.
    """
    assert len(trials) == len(patterns), "Trials matrix first dimension must match number of patterns"

    cells = np.asarray(cells)
    n_patterns, n_amps = trials.shape[0], trials.shape[1]
    all_probs = np.zeros((len(cells), n_patterns, n_amps))

    # One gsort job per (pattern, amplitude), in row-major order.
    index_pairs = [(i, j) for i in range(n_patterns) for j in range(n_amps)]
    input_list = [(patterns[i], j, int(trials[i][j])) for i, j in index_pairs]

    # The context manager tears the workers down even if a job raises
    # (pool.close() alone neither joined nor cleaned up on error).
    # starmap preserves input order, so results line up with index_pairs.
    with mp.Pool(processes=NUM_THREADS) as pool:
        mp_output = pool.starmap(gsort_subsampled, input_list)

    for (i, j), (probs, cell_ids) in zip(index_pairs, mp_output):
        cell_ids = np.asarray(cell_ids)
        # searchsorted returns an insertion point, not a match: verify every
        # returned cell really is in `cells` instead of scattering into a wrong row.
        cell_inds = np.searchsorted(cells, cell_ids)
        if not (np.all(cell_inds < len(cells))
                and np.array_equal(cells[cell_inds], cell_ids)):
            raise ValueError(
                f"gsort returned cell IDs not present in `cells` "
                f"(pattern {patterns[i]}, amplitude index {j})")
        all_probs[cell_inds, i, j] = probs

    return all_probs

In [None]:
# Stimulus amplitude scan for each pattern, in ELECTRODE_ORDERING order,
# stacked so that amps_scan[pattern_idx] is that pattern's amplitude list.
amps_list = [mutils.get_stim_amps_newlv(estim_analysis_path, pattern)
             for pattern in ELECTRODE_ORDERING]
amps_scan = np.array(amps_list)

# All cell IDs known to gsort, sorted so np.searchsorted can map IDs to rows.
all_cells = np.sort(np.array(list(cell_data_dict)))

In [None]:
# total_budget = 50000
num_iters = 5  # adaptive allocation rounds after the initial fit

# Optional seed for the global NumPy RNG used below for the initial amplitude
# selection. None keeps the previous unseeded behaviour.
SEED = None
if SEED is not None:
    np.random.seed(SEED)

# Trials allocated so far: one row per pattern, one column per amplitude.
T_prev = np.zeros(amps_scan.shape[:2], dtype=float)
budget = T_prev.size * 0.5  # trial budget passed to fisher_sampling_1elec each round; int(total_budget / num_iters)
reg = None  # 20, 50
T_step_size = 0.05  # 0.05, 0.01
T_n_steps = 5000

init_trials = 10
init_amps = 1000
ms = [1, 2, 3, 4]
disambiguate = False
verbose = True
R2_cutoff = 0
prob_low = 1 / init_trials
min_inds = 50

# Seed every pattern with init_trials trials at a random subset of amplitudes.
# Clamp the subset size: np.random.choice(replace=False) raises if asked for more
# amplitudes than the pattern has.
n_amps = T_prev.shape[1]
n_init_amps = min(init_amps, n_amps)
for pattern_trials in T_prev:
    init_inds = np.random.choice(n_amps, size=n_init_amps, replace=False)
    pattern_trials[init_inds] = init_trials  # row view: writes into T_prev

# The uniform-sampling arm starts from the identical allocation and sort.
T_prev_uniform = T_prev.copy()

probs_empirical = gsort_trial_allocation(T_prev, ELECTRODE_ORDERING, all_cells, NUM_THREADS=48)
probs_empirical_uniform = probs_empirical.copy()

# # CHECK DATASET IF NEEDED
# Optional manual check of the initial allocation (uncomment to use). Uses the
# deprecated Axes3D(fig, auto_add_to_figure=False) constructor; prefer
# fig.add_subplot(projection='3d') if re-enabling.

# for i in range(len(probs_empirical)):
#     for j in range(len(probs_empirical[i])):

#         print(all_cells[i], j+1)
#         sampled_inds = np.where(T_prev[j] > 0)[0]

#         fig = plt.figure(1)
#         fig.clear()
#         ax = Axes3D(fig, auto_add_to_figure=False)
#         fig.add_axes(ax)
#         plt.xlabel(r'$I_1$ ($\mu$A)', fontsize=16)
#         plt.ylabel(r'$I_2$ ($\mu$A)', fontsize=16)
#         plt.xlim(-2, 2)
#         plt.ylim(-2, 2)
#         ax.set_zlim(-2, 2)
#         ax.set_zlabel(r'$I_3$ ($\mu$A)', fontsize=16)

#         scat = ax.scatter(amps_scan[j][sampled_inds, 0], 
#                             amps_scan[j][sampled_inds, 1],
#                             amps_scan[j][sampled_inds, 2], marker='o', 
#                             c=probs_empirical[i][j][sampled_inds], s=20, alpha=0.8, vmin=0, vmax=1)
#         plt.show()

#         w_inits = []

#         for m in ms:
#             w_init = np.array(np.random.normal(size=(m, amps_scan[j].shape[1]+1)))
#             w_inits.append(w_init)

#         X_clean, p_clean, T_clean = fitting.disambiguate_fitting(amps_scan[j][sampled_inds],
#                                                                  probs_empirical[i][j][sampled_inds],
#                                                                  T_prev[j][sampled_inds],
#                                                                  w_inits)

#         fig = plt.figure(2)
#         fig.clear()
#         ax = Axes3D(fig, auto_add_to_figure=False)
#         fig.add_axes(ax)
#         plt.xlabel(r'$I_1$ ($\mu$A)', fontsize=16)
#         plt.ylabel(r'$I_2$ ($\mu$A)', fontsize=16)
#         plt.xlim(-2, 2)
#         plt.ylim(-2, 2)
#         ax.set_zlim(-2, 2)
#         ax.set_zlabel(r'$I_3$ ($\mu$A)', fontsize=16)

#         scat = ax.scatter(X_clean[:, 0], 
#                             X_clean[:, 1],
#                             X_clean[:, 2], marker='o', 
#                             c=p_clean, s=20, alpha=0.8, vmin=0, vmax=1)
#         plt.show()

#         input()

# Total trials used by each arm, appended once per iteration of the loop below.
# (performances / performances_uniform tracking is disabled: it needs the
# ground-truth params_true / probs_true_scan, not defined in the visible notebook.)
# performances = []
# performances_uniform = []
num_samples, num_samples_uniform = [], []

# Set to False for an unattended run (e.g. Restart & Run All): the per-cell 3-D
# inspection plots and the blocking input() pause are then skipped.
INSPECT_FITS = True


def _scatter_amps_3d(fig_num, points, colors, lim):
    """Show a 3-D scatter of amplitude vectors coloured by a probability in [0, 1].

    ``points`` is indexed as ``points[:, 0]``, ``points[:, 1]``, ``points[:, 2]``
    (one column per electrode current); every axis is limited to [-lim, lim].
    Uses ``fig.add_subplot(projection='3d')`` instead of the deprecated
    ``Axes3D(fig, auto_add_to_figure=False)`` + ``fig.add_axes`` pattern.
    """
    fig = plt.figure(fig_num)
    fig.clear()
    ax = fig.add_subplot(projection='3d')
    ax.set_xlabel(r'$I_1$ ($\mu$A)', fontsize=16)
    ax.set_ylabel(r'$I_2$ ($\mu$A)', fontsize=16)
    ax.set_zlabel(r'$I_3$ ($\mu$A)', fontsize=16)
    ax.set_xlim(-lim, lim)
    ax.set_ylim(-lim, lim)
    ax.set_zlim(-lim, lim)
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], marker='o',
               c=colors, s=20, alpha=0.8, vmin=0, vmax=1)
    plt.show()


# Keyword arguments shared by every fisher_sampling_1elec call.
fisher_kwargs = dict(
    T_step_size=T_step_size,
    T_n_steps=T_n_steps,
    verbose=verbose,
    budget=budget,
    ms=ms,
    reg=reg,
    return_probs=True,
    disambiguate=disambiguate,
    R2_cutoff=R2_cutoff,
    min_prob=prob_low,
    min_inds=min_inds,
)

# Iteration 0 fits the initial allocation; every later iteration refits after
# adding the trials requested in the previous round. That gives num_iters + 1
# fits and num_iters allocation rounds.
for iter_cnt in range(num_iters + 1):
    if iter_cnt == 0:
        (T_new, w_inits_array, t_final,
         probs_curr, params_curr) = fitting.fisher_sampling_1elec(
            probs_empirical, T_prev, amps_scan, **fisher_kwargs)

        # performance = get_performance_array(params_true, probs_curr, probs_true_scan)
        # performance_uniform = performance

        # Both arms start from the same initial weights.
        w_inits_array_uniform = deepcopy(w_inits_array)
    else:
        # Warm-start the optimal arm from the previous round's weights.
        (T_new, w_inits_array, t_final,
         probs_curr, params_curr) = fitting.fisher_sampling_1elec(
            probs_empirical, T_prev, amps_scan,
            w_inits_array=w_inits_array, **fisher_kwargs)

        # performance = get_performance_array(params_true, probs_curr, probs_true_scan)

        # Refit the uniform-sampling arm with the same surface model.
        input_list_uniform = fitting.generate_input_list(
            probs_empirical_uniform, amps_scan, T_prev_uniform,
            w_inits_array_uniform, prob_low,
            disambiguate=disambiguate, min_inds=min_inds)

        # The context manager tears the workers down even if a fit raises.
        with mp.Pool(processes=24) as pool:
            mp_output_uniform = pool.starmap(fitting.fit_surface, input_list_uniform)

        n_cells, n_patterns = probs_empirical_uniform.shape[:2]
        params_curr_uniform = np.zeros((n_cells, n_patterns), dtype=object)
        w_inits_array_uniform = np.zeros((n_cells, n_patterns), dtype=object)
        probs_curr_uniform = np.zeros(probs_empirical_uniform.shape)

        # Results are assumed to follow the row-major (cell, pattern) order of
        # input_list_uniform — TODO confirm against fitting.generate_input_list.
        for cnt, (i, j) in enumerate(np.ndindex(n_cells, n_patterns)):
            params_curr_uniform[i][j] = mp_output_uniform[cnt][0]
            w_inits_array_uniform[i][j] = mp_output_uniform[cnt][1]
            probs_curr_uniform[i][j] = fitting.sigmoidND_nonlinear(
                sm.add_constant(amps_scan[j], has_constant='add'),
                params_curr_uniform[i][j])

        # performance_uniform = get_performance_array(params_true, probs_curr_uniform, probs_true_scan)

    # print(performance, performance_uniform)
    # performances.append(performance)
    # performances_uniform.append(performance_uniform)

    if INSPECT_FITS:
        for i in range(len(probs_empirical)):
            for j in range(len(probs_empirical[i])):
                # Skip pairs whose fitted weights are all -inf in column 0
                # (presumably the "no fit / unresponsive" sentinel).
                if np.all(params_curr[i][j][:, 0] == -np.inf):
                    continue

                print(all_cells[i], j + 1)
                sampled_inds = np.where(T_prev[j] > 0)[0]
                print(params_curr[i][j])

                # Fitted surface over every amplitude.
                _scatter_amps_3d(0, amps_scan[j], probs_curr[i][j], 1.8)
                # Empirical probabilities at the amplitudes sampled so far.
                _scatter_amps_3d(1, amps_scan[j][sampled_inds],
                                 probs_empirical[i][j][sampled_inds], 2)

                X_clean, p_clean, T_clean = fitting.disambiguate_fitting(
                    amps_scan[j][sampled_inds],
                    probs_empirical[i][j][sampled_inds],
                    T_prev[j][sampled_inds],
                    w_inits_array[i][j])
                # Points/probabilities returned by disambiguate_fitting.
                _scatter_amps_3d(2, X_clean, p_clean, 2)

                input()

    num_samples.append(np.sum(T_prev))
    num_samples_uniform.append(np.sum(T_prev_uniform))

    if iter_cnt == num_iters:
        break  # final fit done; no further allocation round

    T_prev = T_new + T_prev
    probs_empirical = gsort_trial_allocation(T_prev, ELECTRODE_ORDERING, all_cells, NUM_THREADS=48)

    # Uniform arm: one extra trial at every (pattern, amplitude).
    # NOTE(review): that is T_prev.size trials per round, while the optimal arm's
    # budget is 0.5 * T_prev.size (set alongside T_prev), so the uniform arm gets
    # ~2x the trials if fisher_sampling_1elec spends its full budget — confirm this
    # is the intended comparison. The commented alternative matches trial counts:
    # random_extra = np.random.choice(len(T_new.flatten()), size=int(np.sum(T_new)), replace=True)
    # T_new_uniform = np.array(np.bincount(random_extra, minlength=len(T_new.flatten())).astype(int).reshape(T_new.shape), dtype=float)
    T_new_uniform = np.ones_like(T_prev_uniform, dtype=float)
    T_prev_uniform = T_prev_uniform + T_new_uniform
    probs_empirical_uniform = gsort_trial_allocation(T_prev_uniform, ELECTRODE_ORDERING, all_cells, NUM_THREADS=48)

    # Cumulative optimal-arm trials per (pattern, amplitude) after this round.
    plt.figure()
    plt.plot(T_prev.flatten())
    plt.xlabel('(pattern, amplitude) index')
    plt.ylabel('Cumulative trials')
    plt.show()

In [None]:
def get_performance_array(true_params, curr_probs, true_probs):
    """Mean per-(cell, pattern) RMSE between current and true activation probabilities.

    Parameters
    ----------
    true_params : nested indexable [cell][pattern] -> 2-D ndarray
        Ground-truth weights. A (cell, pattern) pair is skipped when every entry
        of column 0 is -inf (presumably the "no response" sentinel).
    curr_probs, true_probs : nested indexable [cell][pattern] -> 1-D array
        Estimated and true probabilities over the amplitude scan.

    Returns
    -------
    float
        Average RMSE over the pairs that were not skipped, or NaN when every pair
        is skipped (previously this raised ZeroDivisionError).
    """
    rmses = []
    for i, params_row in enumerate(true_params):
        for j, params in enumerate(params_row):
            # `not`/continue instead of `~np.all(...)`: `~` on a plain Python bool is
            # bitwise (~True == -2, which is truthy).
            if np.all(params[:, 0] == -np.inf):
                continue
            diff = curr_probs[i][j] - true_probs[i][j]
            rmses.append(np.sqrt(np.sum(diff ** 2) / len(true_probs[i][j])))

    if not rmses:
        return np.nan
    return sum(rmses) / len(rmses)

In [None]:
baseline_trials = 20
# Baseline: the same fixed number of trials at every (pattern, amplitude).
T_prev_baseline = np.full(T_prev.shape, baseline_trials, dtype=float)

# NOTE(review): sample_spikes_array, probs_true_scan and params_true are not
# defined in the visible notebook (they look like leftovers from a simulated
# ground-truth version), so this cell raises NameError on a fresh kernel. For
# real data the analogous sampling call would be
# gsort_trial_allocation(T_prev_baseline, ELECTRODE_ORDERING, all_cells), but the
# RMSE below still needs ground truth.
probs_empirical_baseline = sample_spikes_array(probs_true_scan, T_prev_baseline, NUM_THREADS=24)

n_cells_baseline, n_patterns_baseline = probs_empirical_baseline.shape[:2]

# Random initial weights: per (cell, pattern), one (m, amps_scan[j].shape[1] + 1)
# matrix for each model order m in ms (the +1 presumably matches the constant
# column added by sm.add_constant).
w_inits_array_baseline = np.zeros((n_cells_baseline, n_patterns_baseline), dtype=object)
for i in range(n_cells_baseline):
    for j in range(n_patterns_baseline):
        w_inits_array_baseline[i][j] = [
            np.random.normal(size=(m, amps_scan[j].shape[1] + 1)) for m in ms
        ]

input_list_baseline = fitting.generate_input_list(
    probs_empirical_baseline, amps_scan, T_prev_baseline, w_inits_array_baseline,
    1 / baseline_trials, disambiguate=disambiguate, min_inds=min_inds)

# The context manager tears the workers down even if a fit raises.
with mp.Pool(processes=24) as pool:
    mp_output_baseline = pool.starmap(fitting.fit_surface, input_list_baseline)

params_curr_baseline = np.zeros((n_cells_baseline, n_patterns_baseline), dtype=object)
w_inits_array_baseline = np.zeros((n_cells_baseline, n_patterns_baseline), dtype=object)
probs_curr_baseline = np.zeros(probs_empirical_baseline.shape)

# Results are assumed to follow the row-major (cell, pattern) order of the input list.
for cnt, (i, j) in enumerate(np.ndindex(n_cells_baseline, n_patterns_baseline)):
    params_curr_baseline[i][j] = mp_output_baseline[cnt][0]
    w_inits_array_baseline[i][j] = mp_output_baseline[cnt][1]
    probs_curr_baseline[i][j] = fitting.sigmoidND_nonlinear(
        sm.add_constant(amps_scan[j], has_constant='add'),
        params_curr_baseline[i][j])

performance_baseline = get_performance_array(params_true, probs_curr_baseline, probs_true_scan)

In [None]:
# Learning curves: RMSE vs. average trials per (pattern, amplitude) for each
# sampling strategy, with the fixed-trial baseline as a reference line.
# NOTE(review): `performances` / `performances_uniform` are not populated in the
# visible notebook (their construction is commented out in the AL loop).
fig, ax = plt.subplots(figsize=(10, 8))

avg_trials = np.array(num_samples) / T_prev.shape[0] / T_prev.shape[1]
avg_trials_uniform = np.array(num_samples_uniform) / T_prev_uniform.shape[0] / T_prev_uniform.shape[1]

ax.plot(avg_trials, performances, linewidth=4, c='tab:blue', label='Optimal Sampling')
ax.plot(avg_trials_uniform, performances_uniform, linewidth=4, c='tab:red', label='Random Sampling')
ax.axhline(performance_baseline, c='k', linestyle='--', linewidth=2,
           label=f'{baseline_trials} Uniform Trials')

ax.tick_params(axis='both', labelsize=18)
ax.set_xlabel('Average Trials per Current Level', fontsize=24)
ax.set_ylabel(r'RMSE', fontsize=24)
ax.legend(fontsize=20)

# fig.savefig(f'triplet_AL_full_{DATASET}.png', dpi=300)

In [None]:
# CHECK DATASET IF NEEDED
# Manual inspection: for each responsive (cell, pattern) pair show the true,
# fitted and empirically sampled activation probabilities, then pause on input().
# NOTE(review): all_probs, params_true, cells, clean_inds_array, spont_limit and
# probs_true_scan are not defined in the visible notebook (leftovers from a
# simulation version?), so this cell raises NameError on a fresh kernel.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection on older Matplotlib

for i in range(len(all_probs)):
    for j in range(len(all_probs[i])):
        # Skip pairs whose true weights are all -inf in column 0
        # (presumably the "no response" sentinel).
        if np.all(params_true[i][j][:, 0] == -np.inf):
            continue

        print(cells[i], j+1)
        print(params_true[i][j])
        print(np.count_nonzero(all_probs[i][j][clean_inds_array[i][j]] >= spont_limit))

        sampled_inds = np.where(T_prev[j] > 0)[0]

        # (figure number, points, colours). Each panel gets its own figure: the third
        # panel previously reused figure 1 and overwrote the fitted-surface plot on
        # non-inline backends.
        panels = [
            (0, amps_scan[j], probs_true_scan[i][j]),
            (1, amps_scan[j], probs_curr[i][j]),
            (2, amps_scan[j][sampled_inds], probs_empirical[i][j][sampled_inds]),
        ]
        for fig_num, points, colors in panels:
            if fig_num == 1:
                print(params_curr[i][j])

            fig = plt.figure(fig_num)
            fig.clear()
            # Replaces the deprecated Axes3D(fig, auto_add_to_figure=False) + fig.add_axes(ax).
            ax = fig.add_subplot(projection='3d')
            ax.set_xlabel(r'$I_1$ ($\mu$A)', fontsize=16)
            ax.set_ylabel(r'$I_2$ ($\mu$A)', fontsize=16)
            ax.set_zlabel(r'$I_3$ ($\mu$A)', fontsize=16)
            ax.set_xlim(-1.8, 1.8)
            ax.set_ylim(-1.8, 1.8)
            ax.set_zlim(-1.8, 1.8)
            ax.scatter(points[:, 0], points[:, 1], points[:, 2], marker='o',
                       c=colors, s=20, alpha=0.8, vmin=0, vmax=1)
            plt.show()

        input()