In [1]:
import os
# Cap each numerical backend at one thread and select the visible GPU.
# Equivalent shell form: `export <NAME>=<value>` for every entry below.
# NOTE(review): presumably this has to run before numpy & co. are imported for the
# thread caps to take effect — confirm.
os.environ.update({
    "OMP_NUM_THREADS": "1",
    "OPENBLAS_NUM_THREADS": "1",
    "MKL_NUM_THREADS": "1",
    "VECLIB_MAXIMUM_THREADS": "1",
    "NUMEXPR_NUM_THREADS": "1",
    "CUDA_VISIBLE_DEVICES": '3',
})

import numpy as np
import matplotlib.pyplot as plt
import multielec_src.fitting as fitting
import multielec_src.multielec_utils as mutils
import multielec_src.old_labview_data_reader as oldlv
from scipy.io import loadmat
import multiprocessing as mp
import statsmodels.api as sm
from copy import deepcopy, copy
import visionloader as vl
from mpl_toolkits.mplot3d import Axes3D
from gsort.gsort_core_scripts import *
from gsort.gsort_data_loader import *
from itertools import product

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
# ---- Run configuration -------------------------------------------------------
RUN_TYPE = "single"
CELL_TYPES = ['parasol']
EXCLUDED_TYPES = ['bad', 'dup']
NOISE_THRESH = 2

# gsort args: locations of the visual-stimulation and electrical-stimulation analyses
DATASET = "2020-10-18-0"
VISUAL_ANALYSIS_BASE = "/Volumes/Analysis"
VSTIM_DATARUN = "kilosort_data000/data000"
ESTIM_ANALYSIS_BASE = "/Volumes/Analysis"
ESTIM_DATARUN = "data001"
estim_analysis_path = os.path.join(ESTIM_ANALYSIS_BASE, DATASET, ESTIM_DATARUN)

# Stimulation patterns to analyse: 2..518 with the listed bad electrodes removed.
BAD_ELECS_519 = np.array([1, 130, 259, 260, 389, 390, 519], dtype=int)
patterns = np.setdiff1d(np.arange(2, 519, dtype=int), BAD_ELECS_519)

# Everything gsort needs from the visual-stimulation run, loaded once up front.
(
    cell_data_dict,
    cells_to_gsort_dict,
    mutual_cells,
    total_cell_to_electrode_list,
    END_TIME_LIMIT,
    START_TIME_LIMIT,
    CLUSTER_DELAY,
    NOISE,
) = load_vision_data_for_gsort(
    RUN_TYPE, VISUAL_ANALYSIS_BASE, DATASET, VSTIM_DATARUN,
    noise_thresh=NOISE_THRESH,
    cell_types=CELL_TYPES,
    excluded_types=EXCLUDED_TYPES,
    patterns=patterns,
)

In [3]:
# def generate_noise(signals, factor=0):

#     noise_tensor = np.zeros(signals.shape)
#     for i in range(signals.shape[0]):
#         noise_tensor[i] = np.random.normal(loc=np.zeros(len(NOISE)), scale=NOISE, 
#                                             size=(signals.shape[2], signals.shape[1])).T
    
#     return noise_tensor * factor

In [4]:
# Stimulation current amplitudes swept for every pattern: 39 strictly increasing levels.
# These values are used as the x-axis of the activation curves, which this notebook
# labels in microamps.
Ivals = np.array([0.10053543, 0.11310236, 0.11938583, 0.13195276, 0.14451969,                        
                       0.16337008, 0.17593701, 0.1947874 , 0.2136378 , 0.23877165,
                       0.25762205, 0.2780315 , 0.30330709, 0.35385827, 0.37913386,
                       0.42968504, 0.45496063, 0.50551181, 0.55606299, 0.60661417,
                       0.68244094, 0.73299213, 0.8088189 , 0.88464567, 0.98574803,
                       1.10433071, 1.20472441, 1.30511811, 1.40551181, 1.60629921,
                       1.70669291, 1.90748031, 2.10826772, 2.30905512, 2.50984252,
                       2.81102362, 3.11220472, 3.41338583, 3.71456693])

In [5]:
# TODO explicitly remove rows with only one nonzero element

# Every pattern is probed with the same amplitude sweep, stored as a column vector,
# so amps_scan ends up with shape (n_patterns, n_amplitudes, 1).
amps_list = [Ivals.reshape(-1, 1) for _ in range(len(patterns))]
amps_scan = np.array(amps_list)

# Sorted ids of every cell loaded from the visual-stimulation run.
all_cells = np.sort(np.array(list(cell_data_dict.keys())))

In [6]:
# On-disk cache of the full-data ("true") trial bookkeeping for this dataset.
signal_lengths_file = f'signal_lengths_{DATASET}_1elec.npy'
trials_mat_true_file = f'trials_mat_true_{DATASET}_1elec.npy'

try:
    signal_lengths = np.load(signal_lengths_file)
    # trials_mat_true is an object array, so it is stored pickled.
    # Only load cache files written by this notebook: unpickling untrusted data is unsafe.
    trials_mat_true = np.load(trials_mat_true_file, allow_pickle=True)

except OSError:
    # Cache miss (file absent or unreadable): scan every (pattern, amplitude) recording.
    # Only OSError is caught so that Ctrl-C or a genuine bug is not mistaken for a
    # missing cache and answered with a long, silent recomputation.
    trials_mat_true = np.zeros((amps_scan.shape[0], amps_scan.shape[1]), dtype=object)
    signal_lengths = np.zeros((amps_scan.shape[0], amps_scan.shape[1]), dtype=int)
    for i in range(len(patterns)):
        p = patterns[i]
        for k in range(len(amps_scan[i])):
            print(p, k)

            try:
                signal = oldlv.get_oldlabview_pp_data(estim_analysis_path, p, k)
                # len(signal) is the size of the axis that gsort_trials later subsamples.
                signal_lengths[i][k] = len(signal)
                trials_mat_true[i][k] = np.arange(len(signal), dtype=int)
            except Exception:
                # Best effort: a recording that cannot be loaded is marked as empty.
                # `Exception` (not a bare except) keeps the loop interruptible.
                signal_lengths[i][k] = 0
                trials_mat_true[i][k] = []
                print(f"Signal not found for pattern {p} amplitude {k}")

    np.save(signal_lengths_file, signal_lengths)
    np.save(trials_mat_true_file, trials_mat_true)

In [7]:
def init_trials_mat(new_T):
    """Build an empty trial-index matrix matching a 2-D allocation matrix.

    Parameters
    ----------
    new_T : np.ndarray
        2-D array; only its shape is used.

    Returns
    -------
    np.ndarray
        Object array of ``new_T.shape`` in which every entry is its own empty list.
    """
    trials_mat = np.zeros(new_T.shape, dtype=object)

    # Each row is a view into trials_mat, so writing through it fills the matrix.
    for row in trials_mat:
        for col in range(len(row)):
            row[col] = []

    return trials_mat

In [8]:
def update_trials_mat(trials_mat_, new_T, lengths, rng=None):
    """Append freshly resampled trial indices to a copy of a trial-index matrix.

    For every entry (i, j), ``int(new_T[i][j])`` indices are drawn uniformly with
    replacement from ``range(lengths[i][j])`` and appended to the indices already
    stored there. Entries with ``lengths[i][j] == 0`` are left untouched.

    Parameters
    ----------
    trials_mat_ : np.ndarray
        Object array of index sequences (as built by ``init_trials_mat``). Not modified.
    new_T : np.ndarray
        Number of additional trials to draw per entry; same shape as ``trials_mat_``.
    lengths : np.ndarray
        Number of trials available per entry; same shape as ``new_T``.
    rng : np.random.Generator or np.random.RandomState, optional
        Source of randomness. Defaults to NumPy's global RNG (the previous behaviour);
        pass a seeded generator for reproducible resampling.

    Returns
    -------
    np.ndarray
        Deep copy of ``trials_mat_`` with the new indices appended.
    """
    trials_mat = deepcopy(trials_mat_)
    assert new_T.shape == lengths.shape
    assert new_T.shape == trials_mat.shape

    # The np.random module and Generator/RandomState objects all expose `.choice`.
    sampler = np.random if rng is None else rng

    for i in range(len(new_T)):
        for j in range(len(new_T[i])):
            n_available = lengths[i][j]
            if n_available == 0:
                # Nothing was recorded for this entry, so there is nothing to sample.
                continue

            new_inds = sampler.choice(n_available, size=int(new_T[i][j]), replace=True)

            if len(trials_mat[i][j]) == 0:
                trials_mat[i][j] = new_inds
            else:
                trials_mat[i][j] = np.concatenate([trials_mat[i][j], new_inds])

    return trials_mat

In [9]:
def gsort_trials(p, k, inds):
    """Run gsort on a subset of the trials recorded for one pattern/amplitude.

    Parameters
    ----------
    p : pattern number, used to locate the recording and to look up the cells to sort.
    k : amplitude index of the recording.
    inds : sequence of trial indices to keep (may be empty).

    Returns
    -------
    tuple
        ``(result, cells)`` where ``cells`` is ``cells_to_gsort_dict[p]`` as an array and
        ``result`` is the output of ``run_pattern_movie_live`` on the selected trials,
        or one zero per cell when ``inds`` is empty.
    """
    if len(inds) > 0:
        recorded = oldlv.get_oldlabview_pp_data(estim_analysis_path, p, k)
        preloaded_data = (
            cell_data_dict,
            cells_to_gsort_dict[p],
            mutual_cells,
            total_cell_to_electrode_list,
            END_TIME_LIMIT,
            START_TIME_LIMIT,
            CLUSTER_DELAY,
            NOISE,
        )
        # Keep only the requested trials (first axis) before sorting.
        result = run_pattern_movie_live(recorded[inds, :, :], preloaded_data)
    else:
        # No trials selected: skip loading and report zero for every cell.
        result = np.zeros(len(cells_to_gsort_dict[p]))

    return result, np.array(cells_to_gsort_dict[p])

In [10]:
def gsort_trial_allocation(trials_mat, patterns, cells, NUM_THREADS=24):
    """Run ``gsort_trials`` for every (pattern, amplitude) entry in parallel.

    Parameters
    ----------
    trials_mat : np.ndarray
        2-D object array; entry (i, j) holds the trial indices to use for
        ``patterns[i]`` at amplitude index ``j``.
    patterns : sequence
        Pattern numbers, one per row of ``trials_mat``.
    cells : np.ndarray
        Cell ids; results are placed by ``np.searchsorted`` into this array.
        NOTE(review): this assumes ``cells`` is sorted and contains every id returned by
        ``gsort_trials`` — confirm against the caller.
    NUM_THREADS : int
        Number of worker processes in the pool.

    Returns
    -------
    np.ndarray
        Array of shape ``(len(cells), n_patterns, n_amplitudes)``; entries for cells
        that were not sorted on a given pattern stay at zero.
    """
    assert len(trials_mat) == len(patterns), "Trials matrix first dimension must match number of patterns"

    all_probs = np.zeros((len(cells), trials_mat.shape[0], trials_mat.shape[1]))

    # One task per (pattern row, amplitude column), in row-major order.
    grid = [(i, j) for i in range(len(trials_mat)) for j in range(len(trials_mat[i]))]
    input_list = [(patterns[i], j, trials_mat[i][j]) for i, j in grid]

    # The context manager tears the pool down even if a worker raises, so worker
    # processes are not leaked across repeated calls.
    with mp.Pool(processes=NUM_THREADS) as pool:
        mp_output = pool.starmap(gsort_trials, input_list)

    # starmap preserves task order, so results line up with `grid`.
    for (i, j), (probs, sorted_cells) in zip(grid, mp_output):
        cell_inds = np.searchsorted(cells, sorted_cells)
        all_probs[cell_inds, i, j] = probs

    return all_probs

In [11]:
# ---- Fit settings ------------------------------------------------------------
ms = [1]            # one initial weight matrix with m rows is drawn per entry
spont_limit = 0.2   # a (cell, pattern) pair is fit only if its max probability exceeds this
zero_prob = 0.01    # sets the upper clip applied to the initial intercept
slope_bound = 10    # initial slopes are clipped to [-slope_bound, slope_bound]
noise_limit = 0.1   # only points with noise_limit < p < 1 - noise_limit are fit
min_int_inds = 3    # minimum number of such points required to attempt a fit

# Empirical probabilities from all recorded trials, cached on disk.
probs_true_file = f'probs_true_{DATASET}_1elec.npy'
try:
    probs_true = np.load(probs_true_file)
except OSError:
    # Cache miss only: Ctrl-C and real errors propagate instead of triggering a
    # full gsort recomputation.
    probs_true = gsort_trial_allocation(trials_mat_true, patterns, all_cells, NUM_THREADS=48)
    np.save(probs_true_file, probs_true)

# Fitted curves and parameters per (cell, pattern); pairs that are never fit keep zeros.
probs_pred_true = np.zeros(probs_true.shape)
params_true_array = np.zeros((probs_true.shape[0], probs_true.shape[1]), dtype=object)
for i in range(len(probs_true)):
    for j in range(len(probs_true[i])):
        if np.amax(probs_true[i][j]) > spont_limit:

            print(all_cells[i], patterns[j])
            # Keep only the informative, unsaturated part of the curve.
            good_inds = np.where((probs_true[i][j] > noise_limit) &
                                 (probs_true[i][j] < 1 - noise_limit))[0]
            if len(good_inds) < min_int_inds:
                continue

            X = deepcopy(amps_scan[j][good_inds])
            probs = deepcopy(probs_true[i][j][good_inds])
            T = deepcopy(signal_lengths[j][good_inds])
            X, probs, T = fitting.get_monotone_probs_and_amps(X, probs, T)
            if len(X) < min_int_inds:
                continue

            # Random starting weights: column 0 is clipped from above at logit(z),
            # the remaining columns are clipped to +/- slope_bound.
            w_inits = []
            for m in ms:
                w_init = np.array(np.random.normal(size=(m, amps_scan[j].shape[1]+1)))
                z = 1 - (1 - zero_prob)**(1/len(w_init))
                w_init[:, 0] = np.clip(w_init[:, 0], None, np.log(z/(1-z)))
                w_init[:, 1:] = np.clip(w_init[:, 1:], -slope_bound, slope_bound)
                w_inits.append(w_init)

            params_true, _, _ = fitting.fit_surface(X, probs, T,
                                                    w_inits)
            print(params_true)
            # Evaluate the fitted model on the full amplitude sweep.
            probs_pred_true[i][j] = fitting.sigmoidND_nonlinear(
                                                        sm.add_constant(amps_scan[j], has_constant='add'), 
                                                        params_true)
            params_true_array[i][j] = params_true

            # plt.close('all')

            # fig = plt.figure(0)
            # fig.clear()
            # plt.scatter(amps_scan[j][:, 0], probs_true[i][j])
            # plt.ylim(-0.1, 1.1)
            # plt.xlim(-0.1, 4.1)
            # plt.xlabel(r'Current $\mu$A')
            # plt.ylabel('Activation Probability')
            # plt.show()

            # fig = plt.figure(0)
            # fig.clear()
            # plt.scatter(X[:, 0], probs)
            # plt.plot(amps_scan[j][:, 0], probs_pred_true[i][j])
            # plt.ylim(-0.1, 1.1)
            # plt.xlim(-0.1, 4.1)
            # plt.xlabel(r'Current $\mu$A')
            # plt.ylabel('Activation Probability')
            # plt.show()

            # input()

1 2
[[-9.00314507  6.15971845]]
1 9
1 508
1 512
[[-4.59511985  1.96092537]]
1 515
2 2
2 3
2 9
[[-6.18028207  2.51848533]]
2 408
2 414
2 507
2 508
2 512
3 2
3 3
3 504
3 508
3 512
3 515
6 3
[[-8.0573527   2.30742655]]
6 5
6 6
[[-12.44486449  10.        ]]
6 10
[[-16.10581783   8.53067618]]
6 15
[[-11.70092952   6.52846535]]
6 19
10 4
[[-4.59511985  4.15081124]]
10 5
11 2
[[-12.11671948   8.04613587]]
11 5
[[-5.39441263  3.0838513 ]]
11 6
[[-7.41521284  3.78338772]]
11 12
11 15
[[-9.68781113  3.48002408]]
11 16
[[-8.66022047  2.65458328]]
11 21
[[-8.55262619  2.90633326]]
11 517
12 3
12 5
[[-15.71753287   8.37124449]]
12 6
12 10
12 12
12 15
12 512
12 517
13 480
[[-5.56028034  1.69206694]]
13 481
[[-5.89741675  1.64623531]]
13 490
13 491
14 5
14 7
[[-13.83058587   6.78208206]]
14 8
[[-12.51005753   6.39043983]]
14 11
[[-4.59511985  1.03308469]]
14 12
16 4
16 11
[[-10.61490298   8.586288  ]]
16 16
16 20
[[-4.59511985  3.90563045]]
16 25
16 29
17 16
17 20
17 25
17 29
[[-6.96010345  4.6617601

In [None]:
# Persist the fitted parameters for this dataset. params_true_array is an object array
# (entries are either the fit_surface parameters or the initial 0 for pairs that were
# never fit), so loading this file back requires np.load(..., allow_pickle=True).
np.save(f'params_true_{DATASET}_1elec.npy', params_true_array)

In [None]:
def get_performance_array(true_probs, curr_probs, spont_limit=spont_limit):
    """Mean RMSE between two sets of probability curves.

    Only curves whose reference maximum exceeds ``spont_limit`` are scored.

    Parameters
    ----------
    true_probs : array-like
        Reference curves, indexed as ``true_probs[i][j]``.
    curr_probs : array-like
        Curves to evaluate, indexed the same way as ``true_probs``.
    spont_limit : float
        A curve is scored only if ``max(true_probs[i][j]) > spont_limit``. The default
        is the notebook-level ``spont_limit`` at the time this cell was run.

    Returns
    -------
    float
        Average over the scored curves of the root-mean-square difference, or NaN when
        no curve exceeds ``spont_limit`` (previously this raised ZeroDivisionError).
    """
    rmse_per_curve = []
    for i in range(len(true_probs)):
        for j in range(len(true_probs[i])):
            if np.amax(true_probs[i][j]) > spont_limit:
                residual = curr_probs[i][j] - true_probs[i][j]
                rmse_per_curve.append(np.sqrt(np.sum(residual**2) / len(true_probs[i][j])))

    if not rmse_per_curve:
        # Nothing qualified: report "undefined" instead of dividing by zero.
        return np.nan

    return sum(rmse_per_curve) / len(rmse_per_curve)

In [None]:
# ---- Experiment settings -----------------------------------------------------
num_iters = 5

init_trials = 10    # trials given to each initially sampled amplitude
init_amps = 20      # only used by the alternative random initialisation noted below
prob_low = 1 / init_trials
exploit_factor = 0.5
R2_cutoff = 0
disambiguate = False
verbose = True

reg = None # 20, 50
T_step_size = 0.1 # 0.05, 0.01
T_n_steps = 1000    # 5000

# ---- Initial allocation: init_trials on every other amplitude of every pattern ----
# (Alternative that was tried: init_amps amplitudes per pattern chosen at random with
#  np.random.choice(..., replace=False).)
T_prev = np.zeros(amps_scan.shape[:2], dtype=float)
T_prev[:, ::2] = init_trials
budget = T_prev.size * 0.5 #int(total_budget / num_iters)

T_prev_uniform = T_prev.copy()

trials_mat_prev = update_trials_mat(init_trials_mat(T_prev), T_prev, signal_lengths)
trials_mat_prev_uniform = deepcopy(trials_mat_prev)

# Both strategies start from the same initial trials and empirical probabilities.
probs_empirical = gsort_trial_allocation(trials_mat_prev, patterns, all_cells, NUM_THREADS=48)
probs_empirical_uniform = probs_empirical.copy()

# Per-iteration history for the adaptive and the uniform strategy.
performances, performances_uniform = [], []
num_samples, num_samples_uniform = [], []

# Keyword arguments shared by every allocation step; only the warm-start weights
# differ between the first and the later iterations.
# (Passing t_final=t_final back in was tried and is currently disabled.)
fisher_kwargs = dict(
    T_step_size=T_step_size,
    T_n_steps=T_n_steps,
    verbose=verbose, budget=budget, ms=ms, reg=reg,
    return_probs=True,
    disambiguate=disambiguate,
    R2_cutoff=R2_cutoff,
    min_prob=prob_low,
    min_inds=0,
    min_clean_inds=0,
    exploit_factor=exploit_factor,
    single_elec=True,
)

# Per-pair diagnostic plots for the current fit are produced by the inspection cell at
# the end of the notebook (a variant that cleaned the sampled points with
# fitting.disambiguate_sigmoid(probs, noise_limit=0.1, thr_prob=0.25) used to live here).
iter_cnt = 0
while True:
    # ---- Adaptive strategy: fit current data and get the next allocation ----
    # The first pass has no previous fit to warm-start from.
    warm_start = {} if iter_cnt == 0 else dict(w_inits_array=w_inits_array)
    T_new, w_inits_array, t_final, probs_curr, params_curr = fitting.fisher_sampling_1elec(
                                    probs_empirical,
                                    T_prev, amps_scan,
                                    **fisher_kwargs, **warm_start)

    performance = get_performance_array(probs_pred_true, probs_curr)

    # ---- Uniform (random) strategy ----
    if iter_cnt == 0:
        # Both strategies share the initial trials, hence the same first fit.
        performance_uniform = performance
        w_inits_array_uniform = deepcopy(w_inits_array)

    else:
        input_list_uniform = fitting.generate_input_list(probs_empirical_uniform, amps_scan,
                                                            T_prev_uniform, w_inits_array_uniform, prob_low,
                                                            disambiguate=disambiguate)

        # The context manager tears the pool down even if a fit raises.
        with mp.Pool(processes=24) as pool:
            mp_output_uniform = pool.starmap(fitting.fit_surface, input_list_uniform)

        params_curr_uniform = np.zeros((probs_empirical_uniform.shape[0], probs_empirical_uniform.shape[1]), dtype=object)
        w_inits_array_uniform = np.zeros((probs_empirical_uniform.shape[0], probs_empirical_uniform.shape[1]), dtype=object)
        probs_curr_uniform = np.zeros(probs_empirical_uniform.shape)

        # Results are consumed in row-major (cell, pattern) order.
        cnt = 0
        for i in range(len(probs_empirical_uniform)):
            for j in range(len(probs_empirical_uniform[i])):
                params_curr_uniform[i][j] = mp_output_uniform[cnt][0]
                w_inits_array_uniform[i][j] = mp_output_uniform[cnt][1]

                probs_curr_uniform[i][j] = fitting.sigmoidND_nonlinear(
                                        sm.add_constant(amps_scan[j], has_constant='add'),
                                        params_curr_uniform[i][j])

                cnt += 1

        performance_uniform = get_performance_array(probs_pred_true, probs_curr_uniform)

    print(performance, performance_uniform)

    performances.append(performance)
    performances_uniform.append(performance_uniform)

    # Trial totals that the fits above were based on.
    num_samples.append(np.sum(T_prev))
    num_samples_uniform.append(np.sum(T_prev_uniform))

    iter_cnt += 1

    # The loop body runs num_iters + 1 times: the initial fit plus num_iters refinements.
    if iter_cnt > num_iters:
        break

    # ---- Spend the new adaptive allocation ----
    T_prev = T_new + T_prev
    trials_mat_prev = update_trials_mat(trials_mat_prev, T_new, signal_lengths)
    probs_empirical = gsort_trial_allocation(trials_mat_prev, patterns, all_cells, NUM_THREADS=48)

    plt.figure()
    plt.plot(T_prev.flatten())
    plt.show()

    # ---- Spend the same number of trials uniformly at random ----
    # NOTE(review): update_trials_mat skips entries whose signal length is 0, but the
    # trial counters are still incremented for them, so the reported sample counts can
    # exceed the trials actually drawn — confirm whether this matters.
    random_extra = np.random.choice(T_new.size, size=int(np.sum(T_new)), replace=True)
    T_new_uniform = np.bincount(random_extra, minlength=T_new.size).reshape(T_new.shape).astype(float)
    # Alternative that was tried: T_new_uniform = np.ones_like(T_prev_uniform, dtype=float)
    T_prev_uniform = T_prev_uniform + T_new_uniform
    trials_mat_prev_uniform = update_trials_mat(trials_mat_prev_uniform, T_new_uniform, signal_lengths)
    probs_empirical_uniform = gsort_trial_allocation(trials_mat_prev_uniform, patterns, all_cells, NUM_THREADS=48)

In [None]:
# ---- Baseline: a fixed number of trials at every amplitude of every pattern ----
baseline_trials = 20
T_prev_baseline = np.ones_like(T_prev, dtype=float) * baseline_trials

trials_mat_prev_baseline = init_trials_mat(T_prev_baseline)
trials_mat_prev_baseline = update_trials_mat(trials_mat_prev_baseline, T_prev_baseline, signal_lengths)

probs_empirical_baseline = gsort_trial_allocation(trials_mat_prev_baseline, patterns, all_cells, NUM_THREADS=48)

# Random starting weights per (cell, pattern), clipped the same way as in the
# ground-truth fit: column 0 from above at logit(z), the rest to +/- slope_bound.
w_inits_array_baseline = np.zeros((probs_empirical_baseline.shape[0], probs_empirical_baseline.shape[1]), dtype=object)
for i in range(len(w_inits_array_baseline)):
    for j in range(len(w_inits_array_baseline[i])):
        w_inits = []

        for m in ms:
            w_init = np.array(np.random.normal(size=(m, amps_scan[j].shape[1]+1)))
            z = 1 - (1 - zero_prob)**(1/len(w_init))
            w_init[:, 0] = np.clip(w_init[:, 0], None, np.log(z/(1-z)))
            w_init[:, 1:] = np.clip(w_init[:, 1:], -slope_bound, slope_bound)
            w_inits.append(w_init)

        w_inits_array_baseline[i][j] = w_inits

input_list_baseline = fitting.generate_input_list(probs_empirical_baseline, amps_scan, T_prev_baseline,
                                                    w_inits_array_baseline, 1 / baseline_trials,
                                                    disambiguate=disambiguate)

# The context manager tears the pool down even if a fit raises.
with mp.Pool(processes=24) as pool:
    mp_output_baseline = pool.starmap(fitting.fit_surface, input_list_baseline)

params_curr_baseline = np.zeros((probs_empirical_baseline.shape[0], probs_empirical_baseline.shape[1]), dtype=object)
w_inits_array_baseline = np.zeros((probs_empirical_baseline.shape[0], probs_empirical_baseline.shape[1]), dtype=object)
probs_curr_baseline = np.zeros(probs_empirical_baseline.shape)

# Results are consumed in row-major (cell, pattern) order.
cnt = 0
for i in range(len(probs_empirical_baseline)):
    for j in range(len(probs_empirical_baseline[i])):
        params_curr_baseline[i][j] = mp_output_baseline[cnt][0]
        w_inits_array_baseline[i][j] = mp_output_baseline[cnt][1]

        # Evaluate the fitted model on the full amplitude sweep.
        probs_curr_baseline[i][j] = fitting.sigmoidND_nonlinear(
                                sm.add_constant(amps_scan[j], has_constant='add'),
                                params_curr_baseline[i][j])

        cnt += 1

performance_baseline = get_performance_array(probs_pred_true, probs_curr_baseline)

In [None]:
plt.figure(figsize=(10, 8))

# x-axis: total trials spent so far, averaged over every (pattern, amplitude) slot.
avg_trials_optimal = np.array(num_samples)/T_prev.shape[0]/T_prev.shape[1]
avg_trials_uniform = np.array(num_samples_uniform)/T_prev_uniform.shape[0]/T_prev_uniform.shape[1]

curve_style = dict(linewidth=4)
plt.plot(avg_trials_optimal, performances, c='tab:blue', label='Optimal Sampling', **curve_style)
plt.plot(avg_trials_uniform, performances_uniform, c='tab:red', label='Random Sampling', **curve_style)

# Horizontal reference: error reached by the fixed-trials baseline.
plt.axhline(performance_baseline, c='k', linestyle='--', linewidth=2,
            label=f'{baseline_trials} Uniform Trials')

plt.xlabel('Average Trials per Current Level', fontsize=24)
plt.ylabel(r'RMSE', fontsize=24)
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)
plt.legend(fontsize=20)

# plt.savefig(f'triplet_AL_full_{dataset}.png', dpi=300)

In [None]:
# Step through every fitted (cell, pattern) pair, comparing the current adaptive fit
# with the full-data fit. Press Enter after each pair to move on to the next.
for i in range(len(probs_pred_true)):
    for j in range(len(probs_pred_true[i])):
        if np.amax(probs_pred_true[i][j]) > spont_limit:

            sampled_inds = np.where(T_prev[j] > 0)[0]
            # Report the real pattern number; patterns skips bad electrodes, so it is
            # not simply j + 1.
            print(all_cells[i], patterns[j])

            currents = amps_scan[j][:, 0]

            # Figure 0: current fit and sampled points (marker size = trial count)
            # against the full-data fit and the full-data empirical probabilities.
            fig = plt.figure(0)
            plt.plot(currents, probs_curr[i][j])
            plt.scatter(currents[sampled_inds], probs_empirical[i][j][sampled_inds],
                        s=T_prev[j][sampled_inds])
            plt.plot(currents, probs_pred_true[i][j])
            plt.scatter(currents, probs_true[i][j])
            plt.ylim(-0.1, 1.1)
            plt.xlim(-0.1, 4.1)
            plt.xlabel(r'Current ($\mu$A)')
            plt.ylabel('Activation Probability')
            plt.show()

            error_cell = np.sqrt(np.sum((probs_curr[i][j] - probs_pred_true[i][j])**2) / len(probs_pred_true[i][j]))
            print(f'Error: {error_cell}')

            # Figure 1: the sampled points after fitting.get_monotone_probs_and_amps.
            X, probs, T = (deepcopy(amps_scan[j][sampled_inds]), deepcopy(probs_empirical[i][j][sampled_inds]), deepcopy(T_prev[j][sampled_inds]))
            X, probs, T = fitting.get_monotone_probs_and_amps(X, probs, T)

            fig = plt.figure(1)
            plt.scatter(X[:, 0], probs, s=T)
            plt.ylim(-0.1, 1.1)
            plt.xlim(-0.1, 4.1)
            plt.xlabel(r'Current ($\mu$A)')
            plt.ylabel('Activation Probability')
            plt.show()

            input()