In [1]:
import os
# Cap every BLAS / OpenMP / NumExpr thread pool at a single thread. These
# variables are read when the numerical libraries are loaded, so they must be
# set before numpy (or anything that imports it) is imported below.
# Presumably this avoids CPU oversubscription when mp.Pool later runs many
# worker processes at once (NUM_THREADS=48 in this notebook) — confirm intent.
os.environ["OMP_NUM_THREADS"] = "1" # export OMP_NUM_THREADS=1
os.environ["OPENBLAS_NUM_THREADS"] = "1" # export OPENBLAS_NUM_THREADS=1
os.environ["MKL_NUM_THREADS"] = "1" # export MKL_NUM_THREADS=1
os.environ["VECLIB_MAXIMUM_THREADS"] = "1" # export VECLIB_MAXIMUM_THREADS=1
os.environ["NUMEXPR_NUM_THREADS"] = "1" # export NUMEXPR_NUM_THREADS=1
# Expose only the first GPU to any CUDA-using library.
os.environ["CUDA_VISIBLE_DEVICES"]= '0'

import numpy as np
import matplotlib.pyplot as plt
import multielec_src.fitting as fitting
import multielec_src.multielec_utils as mutils
import multielec_src.old_labview_data_reader as oldlv
from scipy.io import loadmat
import multiprocessing as mp
import statsmodels.api as sm
from copy import deepcopy, copy
import visionloader as vl
from mpl_toolkits.mplot3d import Axes3D
from gsort.gsort_core_scripts import *
from gsort.gsort_data_loader import *
from estim_utils.triplet import get_triplet_index
from itertools import product

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
RUN_TYPE = "triplet"

#gsort args
# Pattern ids of the three-electrode stimulation sites analysed here; each id
# comes from get_triplet_index(e1, e2, e3) (presumably three electrode numbers —
# confirm against estim_utils.triplet). The order defines the pattern axis of
# amps_scan, signal_lengths and the probability arrays below.
ELECTRODE_ORDERING = [get_triplet_index(176, 200, 190), get_triplet_index(517, 5, 6), get_triplet_index(111, 121, 120)]
# NOTE(review): absolute mount paths; the notebook only runs where these exist.
VISUAL_ANALYSIS_BASE = "/Volumes/Analysis"
DATASET = "2020-10-18-0"
VSTIM_DATARUN = "kilosort_data000/data000"
ESTIM_ANALYSIS_BASE = "/Volumes/Scratch/Users/praful/pp_out_v2"
ESTIM_DATARUN = "data003"
estim_analysis_path = os.path.join(ESTIM_ANALYSIS_BASE, DATASET, ESTIM_DATARUN)
# Passed through as noise_thresh to load_vision_data_for_gsort.
NOISE_THRESH = 2

CELL_TYPES = ['parasol', 'midget']
EXCLUDED_TYPES = ['bad', 'dup']

# load_vision_data_for_gsort comes from the gsort star-imports above. Its outputs
# are bundled into `preloaded_data` by gsort_trials; cells_to_gsort_dict is indexed
# by pattern id, and the keys of cell_data_dict define all_cells.
cell_data_dict, cells_to_gsort_dict, mutual_cells, total_cell_to_electrode_list, \
    END_TIME_LIMIT, START_TIME_LIMIT, CLUSTER_DELAY, NOISE = \
        load_vision_data_for_gsort(RUN_TYPE, VISUAL_ANALYSIS_BASE, DATASET, VSTIM_DATARUN, \
            noise_thresh=NOISE_THRESH, cell_types = CELL_TYPES, excluded_types=EXCLUDED_TYPES, patterns=np.array(ELECTRODE_ORDERING))

In [3]:
# def generate_noise(signals, factor=0):

#     noise_tensor = np.zeros(signals.shape)
#     for i in range(signals.shape[0]):
#         noise_tensor[i] = np.random.normal(loc=np.zeros(len(NOISE)), scale=NOISE, 
#                                             size=(signals.shape[2], signals.shape[1])).T
    
#     return noise_tensor * factor

In [4]:
# Stimulation amplitudes for each pattern, stacked so that the first axis of
# amps_scan follows ELECTRODE_ORDERING.
amps_list = [mutils.get_stim_amps_newlv(estim_analysis_path, pattern)
             for pattern in ELECTRODE_ORDERING]
amps_scan = np.array(amps_list)

# Every cell id known to the visual-stimulus analysis, in ascending order
# (gsort_trial_allocation relies on this ordering for np.searchsorted).
all_cells = np.sort(np.array(list(cell_data_dict)))

In [5]:
# Load, or build and cache, the number of recorded trials for every
# (pattern, amplitude) pair, plus the "use every trial" index lists that serve
# as the ground-truth allocation.
try:
    signal_lengths = np.load(f'signal_lengths_{DATASET}.npy')
    # allow_pickle is needed for the object array of index lists; the file is
    # this notebook's own cache, not external input.
    trials_mat_true = np.load(f'trials_mat_true_{DATASET}.npy', allow_pickle=True)

except Exception:
    # Cache missing or unreadable: rebuild it. `except Exception` (not a bare
    # `except:`) so that Ctrl-C still stops the notebook.
    trials_mat_true = np.zeros((amps_scan.shape[0], amps_scan.shape[1]), dtype=object)
    signal_lengths = np.zeros((amps_scan.shape[0], amps_scan.shape[1]), dtype=int)
    for i, p in enumerate(ELECTRODE_ORDERING):
        for k in range(len(amps_scan[i])):
            print(p, k)
            try:
                signal = oldlv.get_oldlabview_pp_data(estim_analysis_path, p, k)
                signal_lengths[i][k] = len(signal)
                trials_mat_true[i][k] = np.arange(len(signal), dtype=int)
            except Exception:
                # A bare `except:` here also swallowed KeyboardInterrupt, so an
                # interrupted run recorded the amplitude as missing and then
                # wrote that corrupted result to the cache below.
                signal_lengths[i][k] = 0
                trials_mat_true[i][k] = []
                print(f"Signal not found for pattern {p} amplitude {k}")

    np.save(f'signal_lengths_{DATASET}.npy', signal_lengths)
    np.save(f'trials_mat_true_{DATASET}.npy', trials_mat_true)

In [6]:
def init_trials_mat(new_T):
    """Create an object array shaped like ``new_T`` holding a fresh empty list per cell.

    Each cell later accumulates the trial indices sampled for one
    (pattern, amplitude) pair (see update_trials_mat).
    """
    trials_mat = np.empty(new_T.shape, dtype=object)

    n_rows = new_T.shape[0]
    for row in range(n_rows):
        for col in range(new_T.shape[1]):
            # A distinct list per cell; they must not alias each other.
            trials_mat[row, col] = []

    return trials_mat

In [7]:
def update_trials_mat(trials_mat, new_T, lengths):
    """Append newly sampled trial indices to each (pattern, amplitude) cell.

    For every cell whose recorded trial count ``lengths[i][j]`` is non-zero,
    ``int(new_T[i][j])`` indices are drawn uniformly *with replacement* from
    ``range(lengths[i][j])`` and appended to ``trials_mat[i][j]``. Cells with
    no recorded trials are left untouched.

    ``trials_mat`` is modified in place and also returned.
    """
    assert new_T.shape == lengths.shape
    assert new_T.shape == trials_mat.shape

    for row in range(len(new_T)):
        for col in range(len(new_T[row])):
            n_available = lengths[row][col]
            if n_available == 0:
                continue

            sampled = np.random.choice(n_available, size=int(new_T[row][col]), replace=True)

            existing = trials_mat[row][col]
            trials_mat[row][col] = (sampled if len(existing) == 0
                                    else np.concatenate([existing, sampled]))

    return trials_mat

In [8]:
def gsort_trials(p, k, inds):
    """Run gsort on a subset of trials for one pattern and amplitude index.

    Parameters
    ----------
    p : pattern id (a key of ``cells_to_gsort_dict``).
    k : amplitude index passed to ``oldlv.get_oldlabview_pp_data``.
    inds : trial indices to keep from the loaded signal (first axis).

    Returns
    -------
    (probs, cell_ids) : the result of ``run_pattern_movie_live`` on the selected
        trials (all zeros when ``inds`` is empty) and the ids of the cells that
        were sorted for pattern ``p``.
    """
    pattern_cells = cells_to_gsort_dict[p]
    if not len(inds):
        return np.zeros(len(pattern_cells)), np.array(pattern_cells)

    signal = oldlv.get_oldlabview_pp_data(estim_analysis_path, p, k)
    preloaded_data = (cell_data_dict, pattern_cells, mutual_cells, total_cell_to_electrode_list,
                      END_TIME_LIMIT, START_TIME_LIMIT, CLUSTER_DELAY, NOISE)

    sorted_probs = run_pattern_movie_live(signal[inds, :, :], preloaded_data)
    # Cell ids are read after the call, as before, in case gsort touches the list.
    return sorted_probs, np.array(pattern_cells)

In [9]:
def gsort_trial_allocation(trials_mat, patterns, cells, NUM_THREADS=24):
    """Run gsort in parallel on every (pattern, amplitude) cell of ``trials_mat``.

    Parameters
    ----------
    trials_mat : 2-D object array; ``trials_mat[i][j]`` holds the trial indices
        to sort for pattern ``patterns[i]`` at amplitude index ``j``.
    patterns : sequence of pattern ids, one per row of ``trials_mat``.
    cells : sorted 1-D array of every cell id gsort may report; results are
        placed with ``np.searchsorted``, so the order must be ascending.
    NUM_THREADS : number of worker processes in the pool.

    Returns
    -------
    all_probs : float array of shape (len(cells), n_patterns, n_amplitudes).
        Entries for cells gsort did not report stay 0.

    Raises
    ------
    ValueError
        If gsort returns a cell id that does not appear in ``cells``.
    """
    assert len(trials_mat) == len(patterns), "Trials matrix first dimension must match number of patterns"

    all_probs = np.zeros((len(cells), trials_mat.shape[0], trials_mat.shape[1]))

    positions = [(i, j) for i in range(len(trials_mat)) for j in range(len(trials_mat[i]))]
    input_list = [(patterns[i], j, trials_mat[i][j]) for i, j in positions]

    # The context manager terminates the workers even if a task raises, so a
    # failed run does not leave NUM_THREADS orphaned processes behind.
    with mp.Pool(processes=NUM_THREADS) as pool:
        mp_output = pool.starmap(gsort_trials, input_list)

    cells = np.asarray(cells)
    for (i, j), (cell_probs, cell_ids) in zip(positions, mp_output):
        cell_inds = np.searchsorted(cells, cell_ids)
        # searchsorted returns an insertion point even for ids absent from
        # `cells`; check every id was really found before writing results.
        if not np.all(cell_inds < len(cells)) or np.any(cells[cell_inds] != cell_ids):
            raise ValueError(f"gsort returned cell ids not present in `cells` "
                             f"for pattern {patterns[i]}, amplitude index {j}")
        all_probs[cell_inds, i, j] = cell_probs

    return all_probs

In [10]:
# Ground-truth activation probabilities from gsort on every recorded trial,
# cached per dataset because the full run is expensive (48 worker processes).
try:
    probs_true = np.load(f'probs_true_{DATASET}.npy')
except Exception:
    # Missing or unreadable cache: recompute. `except Exception` rather than a
    # bare `except:` so that Ctrl-C during the load does not start the rerun.
    probs_true = gsort_trial_allocation(trials_mat_true, ELECTRODE_ORDERING, all_cells, NUM_THREADS=48)
    np.save(f'probs_true_{DATASET}.npy', probs_true)

In [11]:
# Prior evaluation grid: a 41 x 41 x 41 lattice over [-4, 4] on each of the three
# electrodes (flattened to rows of (I1, I2, I3)), repeated once per pattern,
# with 20 pseudo-trials at every grid point.
amps_prior = np.repeat(
    np.array(np.meshgrid(*([np.linspace(-4, 4, 41)] * 3))).T.reshape(-1, 3)[np.newaxis],
    len(ELECTRODE_ORDERING), axis=0)
T_prior = np.full(amps_prior.shape[:2], 20.0)

In [12]:
def get_rows_with_single_nonzero_column(array, col_index):
    """Return indices of rows whose only non-zero entry is in column ``col_index``.

    A row qualifies when ``array[row, col_index] != 0`` and every other column
    of that row is exactly zero.
    """
    nonzero = array != 0
    other_cols = np.arange(array.shape[1]) != col_index

    only_target = nonzero[:, col_index] & ~nonzero[:, other_cols].any(axis=1)
    return np.flatnonzero(only_target)


In [13]:
def get_combinations_of_indices(array):
    """Enumerate every way of picking one position per distinct value of ``array``.

    Positions are grouped by value (values in ``np.unique`` order). The
    Cartesian product of the groups is returned as a list of index arrays,
    one entry per group.
    """
    groups = []
    for value in np.unique(array):
        groups.append(np.where(array == value)[0])
    return list(map(np.array, product(*groups)))

In [14]:
def remove_shared_rows(a, b):
    """
    Returns a copy of a numpy array 'a' with the rows that are shared in another numpy array 'b' removed.

    Rows are compared by their exact byte pattern after ``b`` is cast to
    ``a.dtype``, so ``b`` may have a different numeric dtype or memory layout.
    An empty ``b`` removes nothing.

    Parameters
    ----------
    a : 2-D array of shape (n, d).
    b : array of rows with d columns (any shape that reshapes to (-1, d)).

    Returns
    -------
    The rows of ``a`` that do not appear in ``b``, in their original order.
    """
    a = np.asarray(a)
    n_cols = a.shape[1]

    # A void view of each row needs C-contiguous data and identical dtypes;
    # otherwise .view() raises (non-contiguous input) or compares the wrong bytes
    # (e.g. int rows in b would never match float rows in a).
    a_c = np.ascontiguousarray(a)
    b_c = np.ascontiguousarray(b, dtype=a.dtype).reshape(-1, n_cols)

    row_dtype = np.dtype((np.void, a_c.dtype.itemsize * n_cols))
    a_void = a_c.view(row_dtype).reshape(-1)
    b_void = b_c.view(row_dtype).reshape(-1)

    keep = np.isin(a_void, b_void, invert=True)

    return a[keep]

In [15]:
def _plot_probs_3d(amps, probs, lim=4):
    """Show a 3-D scatter of ``probs`` over the first three columns of ``amps``.

    Points are coloured on a fixed [0, 1] scale and every axis spans [-lim, lim] (uA).
    """
    fig = plt.figure()
    # add_subplot(projection='3d') replaces Axes3D(fig, auto_add_to_figure=False);
    # that keyword was deprecated in Matplotlib 3.4 and later releases reject it.
    ax = fig.add_subplot(projection='3d')
    ax.set_xlabel(r'$I_1$ ($\mu$A)', fontsize=16)
    ax.set_ylabel(r'$I_2$ ($\mu$A)', fontsize=16)
    ax.set_zlabel(r'$I_3$ ($\mu$A)', fontsize=16)
    ax.set_xlim(-lim, lim)
    ax.set_ylim(-lim, lim)
    ax.set_zlim(-lim, lim)
    ax.scatter(amps[:, 0], amps[:, 1], amps[:, 2], marker='o',
               c=probs, s=20, alpha=0.8, vmin=0, vmax=1)
    plt.show()


def _combine_same_sign(inds, all_params, all_threshs, all_elecs, n_elecs):
    """Merge the single-electrode sigmoids ``inds`` (all same threshold sign) into one vector.

    The sigmoid with the smallest |threshold| is the anchor and keeps its bias
    and slope. Every other slope is rescaled by ``slope / its_bias * anchor_bias``,
    so each electrode's crossing point -bias/slope is unchanged under the
    shared anchor bias.

    Returns (vector of length n_elecs + 1, sum of |threshold| over ``inds``).
    """
    # argmin gives one scalar index even with tied thresholds; the old
    # np.where(== min) form returned an index array and failed on ties.
    anchor = inds[np.argmin(np.absolute(all_threshs[inds]))]
    anchor_bias = all_params[anchor, 0]

    vector = np.zeros(n_elecs + 1)
    vector[0] = anchor_bias
    vector[all_elecs[anchor] + 1] = all_params[anchor, all_elecs[anchor] + 1]
    for k in np.setdiff1d(inds, anchor):
        col = all_elecs[k] + 1
        # NOTE(review): if k is on the same electrode as the anchor (a 2-sigmoid
        # fit), this overwrites the anchor's slope there — kept as before.
        vector[col] = all_params[k, col] / all_params[k, 0] * anchor_bias

    return vector, np.sum(np.absolute(all_threshs[inds]))


def generate_prior_from_triplet(true_params, prior_amps, prior_trials, threshold_max=10.0):
    """Derive candidate prior mean vectors for a triplet fit from a known model.

    ``true_params`` is evaluated on the ``prior_amps`` grid with
    ``fitting.sigmoidND_nonlinear``. For each electrode, the grid points where
    only that electrode is non-zero are fit with 1- and 2-sigmoid models,
    weighted by ``prior_trials``. Every fitted sigmoid whose crossing point
    -bias/slope has magnitude below ``threshold_max`` is kept.

    The returned rows are, in order:

    1. one combined vector per threshold sign (positive / negative), built by
       ``_combine_same_sign`` and sorted by summed |threshold|;
    2. each kept single-electrode sigmoid sorted by |threshold|, skipping rows
       identical to one already in (1).

    Each row is ``[bias, w_1, ..., w_n]`` with ``n = prior_amps.shape[1]``.
    When no sigmoid is kept, an empty ``(0, n + 1)`` array is returned, so
    callers can pad it. Diagnostic plots and prints are produced along the way.
    """
    n_elecs = prior_amps.shape[1]
    probs_pred_prior = fitting.sigmoidND_nonlinear(
                                                    sm.add_constant(prior_amps, has_constant='add'),
                                                    true_params)
    _plot_probs_3d(prior_amps, probs_pred_prior, lim=4)

    # --- 1-D fits along each electrode axis -------------------------------
    plt.figure()
    all_params = []
    all_threshs = []
    all_elecs = []
    for i in range(n_elecs):
        row_inds = get_rows_with_single_nonzero_column(prior_amps, i)
        X = prior_amps[row_inds, i]
        probs = probs_pred_prior[row_inds]
        T = prior_trials[row_inds]

        w_inits = [np.array(np.random.normal(size=(m, 2))) for m in (1, 2)]

        params, _, _ = fitting.fit_surface(X.reshape(-1, 1), probs, T, w_inits)
        pred = fitting.sigmoidND_nonlinear(sm.add_constant(X.reshape(-1, 1), has_constant='add'), params)

        for row in params:
            bias, slope = row[0], row[1]
            thresh = -bias / slope
            if np.absolute(thresh) < threshold_max:
                params_vec = np.zeros(n_elecs + 1)
                params_vec[0] = bias
                params_vec[i + 1] = slope
                all_params.append(params_vec)
                all_threshs.append(thresh)
                all_elecs.append(i)

        print(params)
        plt.scatter(X, probs)
        plt.plot(X, pred)
        plt.ylim(-0.1, 1.1)

    plt.show()

    all_params = np.array(all_params)
    all_threshs = np.array(all_threshs)
    all_elecs = np.array(all_elecs)

    print(all_params)

    if len(all_params) == 0:
        # Nothing passed threshold_max; previously this crashed further down.
        empty = np.zeros((0, n_elecs + 1))
        print(empty)
        return empty

    # --- (1) one combined vector per threshold sign -----------------------
    combined = []
    combined_threshs = []
    for inds in (np.where(all_threshs > 0)[0], np.where(all_threshs < 0)[0]):
        if len(inds) > 0:
            vector, thresh_sum = _combine_same_sign(inds, all_params, all_threshs,
                                                    all_elecs, n_elecs)
            combined.append(vector)
            combined_threshs.append(thresh_sum)

    # reshape keeps a (0, n_elecs + 1) shape when every threshold is exactly 0.
    all_vectors = np.array(combined).reshape(-1, n_elecs + 1)[np.argsort(combined_threshs)]

    # --- (2) single-electrode sigmoids, closest threshold first -----------
    # Each all_params row already has the [bias, 0, .., slope, .., 0] layout.
    singles = all_params[np.argsort(np.absolute(all_threshs))]

    all_vectors = np.vstack((all_vectors, remove_shared_rows(singles, all_vectors)))

    print(all_vectors)
    return all_vectors

In [16]:
def generate_cov(prior_mean, bias_var=50, slope_var=50, pcorr=0):
    """Build a block-diagonal prior covariance for the flattened ``prior_mean``.

    Each row of ``prior_mean`` is one sigmoid ``[bias, w_1, ..., w_n]`` and gets
    its own (n+1) x (n+1) block. Sigmoids are independent of one another.
    Within a block, ``pcorr`` is the common correlation coefficient between
    every pair of parameters:

    - var(bias) = bias_var, var(w_k) = slope_var
    - cov(bias, w_k) = pcorr * sqrt(bias_var * slope_var)
    - cov(w_k, w_l) = pcorr * slope_var   (k != l)

    The original code used ``sqrt(bias_var * slope_var)`` for the slope-slope
    term too. That equals ``slope_var`` only when ``bias_var == slope_var``, so
    results are unchanged for the defaults and for ``pcorr == 0``.

    Returns
    -------
    Array of shape (prior_mean.size, prior_mean.size).
    """
    block = prior_mean.shape[1]
    n_params = prior_mean.size

    # The per-sigmoid block does not depend on the row, so build it once.
    cross = np.sqrt(bias_var * slope_var) * pcorr
    sub_cov = np.full((block, block), slope_var * pcorr, dtype=float)
    sub_cov[0, :] = cross
    sub_cov[:, 0] = cross
    np.fill_diagonal(sub_cov, slope_var)
    sub_cov[0, 0] = bias_var

    prior_cov = np.zeros((n_params, n_params))
    for r in range(len(prior_mean)):
        start = r * block
        prior_cov[start:start + block, start:start + block] = sub_cov

    return prior_cov


In [17]:
# Numbers of sigmoids tried per fit: one random initial weight matrix of shape
# (m, n_elecs + 1) is drawn per m, and one prior mean is built per m.
ms = [1, 2, 3, 4, 5, 6]
# A (cell, pattern) pair is analysed only if at least `min_active_inds`
# amplitudes have a true activation probability above `spont_limit`.
spont_limit = 0.2
min_active_inds = 200
# Passed as the first element of `reg` to the MAP fit below.
regmap = 0
# bad_trials = 5

probs_pred_true = np.zeros(probs_true.shape)

# Initial subsample: for each pattern, `init_amps` distinct amplitude indices
# get `init_trials` trials each (np.random.choice with replace=False raises if a
# pattern has fewer than init_amps amplitudes).
# NOTE(review): no RNG seed is set anywhere visible, so runs are not reproducible.
T_prev = np.zeros((amps_scan.shape[0], amps_scan.shape[1]), dtype=float)
init_trials = 10
init_amps = 1000

for k in range(len(T_prev)):
    init_inds = np.random.choice(np.arange(len(T_prev[k]), dtype=int), size=init_amps,
                                replace=False)
    T_prev[k][init_inds] = init_trials

trials_mat_prev = init_trials_mat(T_prev)
trials_mat_prev = update_trials_mat(trials_mat_prev, T_prev, signal_lengths)

# Empirical probabilities from the subsample, shaped (cell, pattern, amplitude).
probs_empirical = gsort_trial_allocation(trials_mat_prev, ELECTRODE_ORDERING, all_cells, NUM_THREADS=48)
# Per-(cell, pattern) results filled in by the loop below.
priors_array = np.zeros((probs_empirical.shape[0], probs_empirical.shape[1]), dtype=object)
params_true_array = np.zeros((probs_empirical.shape[0], probs_empirical.shape[1]), dtype=object)
# Interactive inspection, one (cell, pattern) pair at a time. probs_true is
# indexed [cell, pattern, amplitude] (the layout gsort_trial_allocation returns).
# For each sufficiently active pair:
#   1. fit the full data ("true" fit) and plot data / cleaned data / fit;
#   2. build prior means and covariances from that fit via generate_prior_from_triplet;
#   3. refit the initial subsample with MLE and with MAP (using the priors) and
#      plot both.
# `input()` pauses after each pair, so this cell needs a human at the keyboard.
for i in range(len(probs_true)):
    for j in range(len(probs_true[i])):

        # NOTE(review): `>` here but `>=` at the second check further down.
        if len(np.where(probs_true[i][j] > spont_limit)[0]) >= min_active_inds:

            # if all_cells[i] != 85 or j+1 != 3:
            #     continue

            print(all_cells[i], j+1)
            clean_inds = mutils.triplet_cleaning(amps_scan[j], probs_true[i][j], signal_lengths[j],
                                                 return_inds=True)
            above_spont = np.where(probs_true[i][j][clean_inds] >= spont_limit)[0]
            if len(above_spont) >= min_active_inds:

                dirty_inds = np.setdiff1d(np.arange(len(amps_scan[j]), dtype=int),
                                          clean_inds)
                # NOTE(review): mutates probs_true in place (the array loaded from
                # probs_true_{DATASET}.npy), so later cells see the zeroed values.
                probs_true[i][j][dirty_inds] = 0

                w_inits = []
                for m in ms:
                    w_init = np.array(np.random.normal(size=(m, amps_scan[j].shape[1]+1)))
                    w_inits.append(w_init)
                
                X, probs, T = fitting.disambiguate_fitting(amps_scan[j],
                                                           probs_true[i][j],
                                                           signal_lengths[j],
                                                           w_inits)

                params_true, _, _ = fitting.fit_surface(X, probs, T,
                                                        w_inits)
                print(params_true)
                probs_pred_true[i][j] = fitting.sigmoidND_nonlinear(
                                                            sm.add_constant(amps_scan[j], has_constant='add'), 
                                                            params_true)
                params_true_array[i][j] = params_true
    
                plt.close('all')

                # Figure 0: true probabilities (after cleaning) on the scan grid.
                fig = plt.figure(0)
                fig.clear()
                ax = Axes3D(fig, auto_add_to_figure=False)
                fig.add_axes(ax)
                plt.xlabel(r'$I_1$ ($\mu$A)', fontsize=16)
                plt.ylabel(r'$I_2$ ($\mu$A)', fontsize=16)
                plt.xlim(-2, 2)
                plt.ylim(-2, 2)
                ax.set_zlim(-2, 2)
                ax.set_zlabel(r'$I_3$ ($\mu$A)', fontsize=16)

                scat = ax.scatter(amps_scan[j][:, 0], 
                                    amps_scan[j][:, 1],
                                    amps_scan[j][:, 2], marker='o', 
                                    c=probs_true[i][j], s=20, alpha=0.8, vmin=0, vmax=1)
                plt.show()

                # Figure 1: the points actually used for the true fit (after disambiguation).
                fig = plt.figure(1)
                fig.clear()
                ax = Axes3D(fig, auto_add_to_figure=False)
                fig.add_axes(ax)
                plt.xlabel(r'$I_1$ ($\mu$A)', fontsize=16)
                plt.ylabel(r'$I_2$ ($\mu$A)', fontsize=16)
                plt.xlim(-2, 2)
                plt.ylim(-2, 2)
                ax.set_zlim(-2, 2)
                ax.set_zlabel(r'$I_3$ ($\mu$A)', fontsize=16)

                scat = ax.scatter(X[:, 0], 
                                    X[:, 1],
                                    X[:, 2], marker='o', 
                                    c=probs, s=20, alpha=0.8, vmin=0, vmax=1)
                plt.show()

                # Figure 2: the true fit evaluated on the full scan grid.
                fig = plt.figure(2)
                fig.clear()
                ax = Axes3D(fig, auto_add_to_figure=False)
                fig.add_axes(ax)
                plt.xlabel(r'$I_1$ ($\mu$A)', fontsize=16)
                plt.ylabel(r'$I_2$ ($\mu$A)', fontsize=16)
                plt.xlim(-2, 2)
                plt.ylim(-2, 2)
                ax.set_zlim(-2, 2)
                ax.set_zlabel(r'$I_3$ ($\mu$A)', fontsize=16)

                scat = ax.scatter(amps_scan[j][:, 0], 
                                    amps_scan[j][:, 1],
                                    amps_scan[j][:, 2], marker='o', 
                                    c=probs_pred_true[i][j], s=20, alpha=0.8, vmin=0, vmax=1)
                plt.show()

                # One (mean, cov) prior per number of sigmoids m in `ms`.
                all_vectors = generate_prior_from_triplet(params_true, amps_prior[j], T_prior[j])
                priors = []
                for m in ms:
                    if m > len(all_vectors):
                        # Pad with rows of bias -50 and zero weights; presumably
                        # these act as "never fires" sigmoids — confirm against
                        # fitting.sigmoidND_nonlinear.
                        diff_stack = np.zeros((m - len(all_vectors), all_vectors.shape[1]))
                        diff_stack[:, 0] = -50
                        prior_mean = np.vstack((all_vectors, diff_stack))
                    else:
                        prior_mean = all_vectors[:m, :]

                    prior_cov = generate_cov(prior_mean)

                    prior_probs = fitting.sigmoidND_nonlinear(sm.add_constant(amps_prior[j], has_constant='add'), 
                                                                prior_mean)

                    print(prior_mean)
                    fig = plt.figure()
                    fig.clear()
                    ax = Axes3D(fig, auto_add_to_figure=False)
                    fig.add_axes(ax)
                    plt.xlabel(r'$I_1$ ($\mu$A)', fontsize=16)
                    plt.ylabel(r'$I_2$ ($\mu$A)', fontsize=16)
                    plt.xlim(-4, 4)
                    plt.ylim(-4, 4)
                    ax.set_zlim(-4, 4)
                    ax.set_zlabel(r'$I_3$ ($\mu$A)', fontsize=16)

                    scat = ax.scatter(amps_prior[j][:, 0], 
                                        amps_prior[j][:, 1],
                                        amps_prior[j][:, 2], marker='o', 
                                        c=prior_probs, s=20, alpha=0.8, vmin=0, vmax=1)
                    plt.show()

                    priors.append((prior_mean.flatten(), prior_cov))

                priors_array[i][j] = priors
                
                # Refit using only the amplitudes sampled in the initial subsample.
                sampled_inds = np.where(T_prev[j] > 0)[0]
                # unsampled_inds = np.where(T_prev[j] == 0)[0]

                # probs_empirical[i][j][unsampled_inds] = 0
                # T_prev[j][unsampled_inds] = bad_trials

                T = T_prev[j][sampled_inds]
                probs = probs_empirical[i][j][sampled_inds]
                X = amps_scan[j][sampled_inds]

                clean_inds = mutils.triplet_cleaning(X, probs, T, return_inds=True)
                dirty_inds = np.setdiff1d(np.arange(len(X), dtype=int),
                                          clean_inds)
                probs[dirty_inds] = 0
                X, probs, T = fitting.disambiguate_fitting(X, probs, T, w_inits)
                                                            # priors=priors_array[i][j])

                params_sub_MLE, _, _ = fitting.fit_surface(X, probs, T, w_inits)
                params_sub_MAP, _, _ = fitting.fit_surface(X, probs, T, w_inits, reg_method='MAP',
                                                           reg=(regmap, priors_array[i][j]), opt_verbose=True)

                print(params_sub_MLE, params_sub_MAP)

                probs_sub_MLE = fitting.sigmoidND_nonlinear(
                                                            sm.add_constant(amps_scan[j], has_constant='add'), 
                                                            params_sub_MLE)

                probs_sub_MAP = fitting.sigmoidND_nonlinear(
                                                            sm.add_constant(amps_scan[j], has_constant='add'), 
                                                            params_sub_MAP)

                # Figure 5: MLE fit of the subsample on the full grid.
                fig = plt.figure(5)
                fig.clear()
                ax = Axes3D(fig, auto_add_to_figure=False)
                fig.add_axes(ax)
                plt.xlabel(r'$I_1$ ($\mu$A)', fontsize=16)
                plt.ylabel(r'$I_2$ ($\mu$A)', fontsize=16)
                plt.xlim(-2, 2)
                plt.ylim(-2, 2)
                ax.set_zlim(-2, 2)
                ax.set_zlabel(r'$I_3$ ($\mu$A)', fontsize=16)

                scat = ax.scatter(amps_scan[j][:, 0], 
                                    amps_scan[j][:, 1],
                                    amps_scan[j][:, 2], marker='o', 
                                    c=probs_sub_MLE, s=20, alpha=0.8, vmin=0, vmax=1)
                plt.show()

                # Figure 6: MAP fit of the subsample on the full grid.
                fig = plt.figure(6)
                fig.clear()
                ax = Axes3D(fig, auto_add_to_figure=False)
                fig.add_axes(ax)
                plt.xlabel(r'$I_1$ ($\mu$A)', fontsize=16)
                plt.ylabel(r'$I_2$ ($\mu$A)', fontsize=16)
                plt.xlim(-2, 2)
                plt.ylim(-2, 2)
                ax.set_zlim(-2, 2)
                ax.set_zlabel(r'$I_3$ ($\mu$A)', fontsize=16)

                scat = ax.scatter(amps_scan[j][:, 0], 
                                    amps_scan[j][:, 1],
                                    amps_scan[j][:, 2], marker='o', 
                                    c=probs_sub_MAP, s=20, alpha=0.8, vmin=0, vmax=1)
                plt.show()

                # Figure 7: the subsampled points used for both fits.
                fig = plt.figure(7)
                fig.clear()
                ax = Axes3D(fig, auto_add_to_figure=False)
                fig.add_axes(ax)
                plt.xlabel(r'$I_1$ ($\mu$A)', fontsize=16)
                plt.ylabel(r'$I_2$ ($\mu$A)', fontsize=16)
                plt.xlim(-2, 2)
                plt.ylim(-2, 2)
                ax.set_zlim(-2, 2)
                ax.set_zlabel(r'$I_3$ ($\mu$A)', fontsize=16)

                scat = ax.scatter(X[:, 0], 
                                    X[:, 1],
                                    X[:, 2], marker='o', 
                                    c=probs, s=20, alpha=0.8, vmin=0, vmax=1)
                plt.show()

                # Blocks until Enter is pressed; remove to run unattended.
                input()

# Written only if no file exists yet, and only if the loop above finishes
# (an interrupt skips this save).
if not os.path.isfile(f'params_true_{DATASET}.npy'):
    np.save(f'params_true_{DATASET}.npy', params_true_array)

KeyboardInterrupt: 

In [None]:
def get_performance_array(true_probs, curr_probs, spont_limit=spont_limit, min_inds=min_active_inds):
    """Mean per-(cell, pattern) RMSE between ``curr_probs`` and ``true_probs``.

    Only (cell, pattern) pairs whose true probabilities exceed ``spont_limit``
    at ``min_inds`` or more amplitudes are included. Raises ZeroDivisionError
    when no pair qualifies.
    """
    rmses = [
        np.sqrt(np.sum((curr_probs[c][p] - true_probs[c][p])**2) / len(true_probs[c][p]))
        for c in range(len(true_probs))
        for p in range(len(true_probs[c]))
        if np.count_nonzero(true_probs[c][p] > spont_limit) >= min_inds
    ]
    return sum(rmses) / len(rmses)

In [None]:
def get_performance_array_CE(true_probs, true_trials, curr_probs, spont_limit=None, min_inds=None,
                             eps=1e-12):
    """Mean trial-weighted binary cross-entropy of fitted vs. ground-truth probabilities.

    For every (cell i, pattern j) pair with at least ``min_inds`` amplitudes
    whose true probability exceeds ``spont_limit``::

        CE_ij = -sum_k T_k * (p_k * log(q_k) + (1 - p_k) * log(1 - q_k)) / sum_k T_k

    where ``p = true_probs[i][j]``, ``q = curr_probs[i][j]`` and
    ``T = true_trials[j]``. The result is the average of ``CE_ij`` over the
    scored pairs.

    Parameters
    ----------
    true_probs : nested sequence indexed as ``true_probs[i][j]``
        Ground-truth probabilities, a 1-D array over amplitudes per pair.
    true_trials : sequence indexed as ``true_trials[j]``
        Per-amplitude trial counts used as weights (assumed one row per
        pattern, aligned with ``true_probs[i][j]`` -- verify against caller).
    curr_probs : nested sequence with the same layout as ``true_probs``
        Fitted probabilities to score.
    spont_limit : float, optional
        Activity threshold. Defaults to the notebook-global ``spont_limit``,
        read at call time.
    min_inds : int, optional
        Minimum above-threshold amplitudes for scoring. Defaults to the
        notebook-global ``min_active_inds``, read at call time.
    eps : float, optional
        Fitted probabilities are clipped to ``[eps, 1 - eps]`` so saturated fits
        give a large finite loss rather than inf/nan.

    Returns
    -------
    float
        Mean cross-entropy over scored pairs, or ``nan`` when no pair qualifies.
        Pairs whose pattern has zero total trials are skipped, because their
        weighted mean is undefined.
    """
    # Late-bind defaults so a re-run config cell is honoured (see get_performance_array).
    if spont_limit is None:
        spont_limit = globals()['spont_limit']
    if min_inds is None:
        min_inds = globals()['min_active_inds']

    losses = []
    for i, true_cell in enumerate(true_probs):
        for j, p_true in enumerate(true_cell):
            p_true = np.asarray(p_true, dtype=float)
            if np.count_nonzero(p_true > spont_limit) < min_inds:
                continue

            trials = np.asarray(true_trials[j], dtype=float)
            total_trials = np.sum(trials)
            if total_trials <= 0:
                continue

            q = np.clip(np.asarray(curr_probs[i][j], dtype=float), eps, 1.0 - eps)
            log_lik = p_true * np.log(q) + (1.0 - p_true) * np.log1p(-q)
            # Sum over amplitudes so each pair contributes one scalar loss.
            losses.append(-np.sum(trials * log_lik) / total_trials)

    return float(np.mean(losses)) if losses else float('nan')

In [None]:
# Active-learning comparison: Fisher-information-guided ("optimal") trial
# allocation vs. random allocation that spends the same number of new trials
# each iteration. Both arms start from the same random initial allocation and
# are scored every iteration with get_performance_array against probs_pred_true.

# total_budget = 50000
num_iters = 5

T_prev = np.zeros((amps_scan.shape[0], amps_scan.shape[1]), dtype=float)
budget = T_prev.shape[0] * T_prev.shape[1] * 0.5  # int(total_budget / num_iters)
reg = None  # 20, 50
T_step_size = 0.05  # 0.05, 0.01
T_n_steps = 5000    # 5000

init_trials = 10
init_amps = 1000
disambiguate = False
verbose = True
R2_cutoff = 0
prob_low = 1 / init_trials
min_inds = 50
exploit_factor = 0.75

# Initial allocation: `init_trials` trials at `init_amps` distinct, randomly
# chosen amplitudes for every row of T_prev.
for i in range(len(T_prev)):
    init_inds = np.random.choice(np.arange(len(T_prev[i]), dtype=int), size=init_amps,
                                 replace=False)
    T_prev[i][init_inds] = init_trials

# T_prev[:, np.where(np.count_nonzero(amps_scan[0], axis=1) <= 1)[0]] = init_trials
T_prev_uniform = deepcopy(T_prev)

trials_mat_prev = init_trials_mat(T_prev)
trials_mat_prev = update_trials_mat(trials_mat_prev, T_prev, signal_lengths)
trials_mat_prev_uniform = deepcopy(trials_mat_prev)

probs_empirical = gsort_trial_allocation(trials_mat_prev, ELECTRODE_ORDERING, all_cells, NUM_THREADS=48)
probs_empirical_uniform = deepcopy(probs_empirical)

# Keyword arguments shared by every fitting.fisher_sampling_1elec call below.
fisher_kwargs = dict(T_step_size=T_step_size,
                     T_n_steps=T_n_steps,
                     verbose=verbose, budget=budget, ms=ms, reg=reg,
                     return_probs=True,
                     disambiguate=disambiguate,
                     R2_cutoff=R2_cutoff,
                     min_prob=prob_low,
                     min_inds=min_inds,
                     priors_array=priors_array,
                     regmap=regmap,
                     exploit_factor=exploit_factor)


def refit_surfaces(probs_emp, T, w_inits, min_prob, n_processes=24):
    """Refit every (cell, pattern) surface in parallel from fixed trial counts.

    Builds the fit inputs with fitting.generate_input_list, runs
    fitting.fit_surface on each in a process pool, and evaluates each fitted
    surface on amps_scan[j].

    Returns
    -------
    (params, w_inits_out, probs)
        ``params`` and ``w_inits_out`` are (n_cells, n_patterns) object arrays
        holding element 0 and element 1 of each fit_surface result. ``probs``
        has the shape of ``probs_emp`` and holds the fitted probabilities.
    """
    input_list = fitting.generate_input_list(probs_emp, amps_scan, T, w_inits, min_prob,
                                             priors_array=priors_array, regmap=regmap,
                                             disambiguate=disambiguate, min_inds=min_inds)
    # The context manager tears the workers down even if a fit raises.
    with mp.Pool(processes=n_processes) as pool:
        mp_output = pool.starmap(fitting.fit_surface, input_list)

    n_cells, n_patterns = probs_emp.shape[0], probs_emp.shape[1]
    if len(mp_output) != n_cells * n_patterns:
        raise ValueError(f'expected {n_cells * n_patterns} fits, got {len(mp_output)}')

    params = np.zeros((n_cells, n_patterns), dtype=object)
    w_inits_out = np.zeros((n_cells, n_patterns), dtype=object)
    probs = np.zeros(probs_emp.shape)
    # Fits are assumed to come back in row-major (cell, pattern) order, which
    # is the order np.ndindex walks.
    for (i, j), fit in zip(np.ndindex(n_cells, n_patterns), mp_output):
        params[i][j] = fit[0]
        w_inits_out[i][j] = fit[1]
        probs[i][j] = fitting.sigmoidND_nonlinear(
            sm.add_constant(amps_scan[j], has_constant='add'), params[i][j])
    return params, w_inits_out, probs


performances = []
performances_uniform = []
num_samples = []
num_samples_uniform = []

iter_cnt = 0
while True:
    if iter_cnt == 0:
        T_new, w_inits_array, t_final, probs_curr, params_curr = fitting.fisher_sampling_1elec(
            probs_empirical, T_prev, amps_scan, **fisher_kwargs)
    else:
        # Warm-start from the weights returned by the previous iteration.
        T_new, w_inits_array, t_final, probs_curr, params_curr = fitting.fisher_sampling_1elec(
            probs_empirical, T_prev, amps_scan, w_inits_array=w_inits_array, **fisher_kwargs)
    performance = get_performance_array(probs_pred_true, probs_curr)

    if iter_cnt == 0:
        # Both arms share the initial data, so their first score is identical.
        performance_uniform = performance
        w_inits_array_uniform = deepcopy(w_inits_array)
    else:
        params_curr_uniform, w_inits_array_uniform, probs_curr_uniform = refit_surfaces(
            probs_empirical_uniform, T_prev_uniform, w_inits_array_uniform, prob_low)
        performance_uniform = get_performance_array(probs_pred_true, probs_curr_uniform)

    print(performance, performance_uniform)

    performances.append(performance)
    performances_uniform.append(performance_uniform)

    num_samples.append(np.sum(T_prev))
    num_samples_uniform.append(np.sum(T_prev_uniform))

    iter_cnt += 1
    if iter_cnt > num_iters:
        break

    # Optimal arm: add the Fisher-chosen trials and recompute empirical probabilities.
    T_prev = T_new + T_prev
    trials_mat_prev = update_trials_mat(trials_mat_prev, T_new, signal_lengths)
    probs_empirical = gsort_trial_allocation(trials_mat_prev, ELECTRODE_ORDERING, all_cells, NUM_THREADS=48)

    plt.figure()
    plt.plot(T_prev.flatten())
    plt.xlabel('Flattened T_prev index')
    plt.ylabel('Cumulative trials')
    plt.show()

    # Random arm: spend the same number of new trials, drawn uniformly with
    # replacement over all entries of T_new.
    random_extra = np.random.choice(T_new.size, size=int(np.sum(T_new)), replace=True)
    T_new_uniform = np.bincount(random_extra, minlength=T_new.size).reshape(T_new.shape).astype(float)
    # T_new_uniform = np.ones_like(T_prev_uniform, dtype=float)
    T_prev_uniform = T_prev_uniform + T_new_uniform
    trials_mat_prev_uniform = update_trials_mat(trials_mat_prev_uniform, T_new_uniform, signal_lengths)
    probs_empirical_uniform = gsort_trial_allocation(trials_mat_prev_uniform, ELECTRODE_ORDERING, all_cells, NUM_THREADS=48)

In [None]:
# Baseline: a fixed allocation of `baseline_trials` trials at every entry of
# T_prev, fit once, for comparison against the active-learning curves.
baseline_trials = 20
T_prev_baseline = np.full_like(T_prev, baseline_trials, dtype=float)

trials_mat_prev_baseline = init_trials_mat(T_prev_baseline)
trials_mat_prev_baseline = update_trials_mat(trials_mat_prev_baseline, T_prev_baseline, signal_lengths)

probs_empirical_baseline = gsort_trial_allocation(trials_mat_prev_baseline, ELECTRODE_ORDERING, all_cells, NUM_THREADS=48)
n_cells_baseline, n_patterns_baseline = probs_empirical_baseline.shape[0], probs_empirical_baseline.shape[1]

# Random initial weights per (cell, pattern): one (m, amps_scan[j].shape[1] + 1)
# standard-normal matrix per entry of `ms`. The +1 presumably matches the
# constant column added by sm.add_constant when the surface is evaluated.
w_inits_array_baseline = np.zeros((n_cells_baseline, n_patterns_baseline), dtype=object)
for i, j in np.ndindex(n_cells_baseline, n_patterns_baseline):
    w_inits_array_baseline[i][j] = [np.random.normal(size=(m, amps_scan[j].shape[1] + 1)) for m in ms]

input_list_baseline = fitting.generate_input_list(probs_empirical_baseline, amps_scan, T_prev_baseline,
                                                  w_inits_array_baseline, 1 / baseline_trials,
                                                  priors_array=priors_array, regmap=regmap,
                                                  disambiguate=disambiguate, min_inds=min_inds)

# The context manager tears the worker pool down even if a fit raises.
with mp.Pool(processes=24) as pool:
    mp_output_baseline = pool.starmap(fitting.fit_surface, input_list_baseline)

if len(mp_output_baseline) != n_cells_baseline * n_patterns_baseline:
    raise ValueError(f'expected {n_cells_baseline * n_patterns_baseline} fits, '
                     f'got {len(mp_output_baseline)}')

params_curr_baseline = np.zeros((n_cells_baseline, n_patterns_baseline), dtype=object)
w_inits_array_baseline = np.zeros((n_cells_baseline, n_patterns_baseline), dtype=object)
probs_curr_baseline = np.zeros(probs_empirical_baseline.shape)

# Fits are assumed to come back in row-major (cell, pattern) order, which is
# the order np.ndindex walks.
for (i, j), fit in zip(np.ndindex(n_cells_baseline, n_patterns_baseline), mp_output_baseline):
    params_curr_baseline[i][j] = fit[0]
    w_inits_array_baseline[i][j] = fit[1]
    probs_curr_baseline[i][j] = fitting.sigmoidND_nonlinear(
        sm.add_constant(amps_scan[j], has_constant='add'), params_curr_baseline[i][j])

performance_baseline = get_performance_array(probs_pred_true, probs_curr_baseline)

In [None]:
# RMSE vs. average trials per (pattern, amplitude) entry for optimal and random
# sampling, with the fixed uniform-allocation baseline as a horizontal reference.
fig, ax = plt.subplots(figsize=(10, 8))

ax.plot(np.array(num_samples) / T_prev.size, performances,
        linewidth=4, c='tab:blue', label='Optimal Sampling')
ax.plot(np.array(num_samples_uniform) / T_prev_uniform.size, performances_uniform,
        linewidth=4, c='tab:red', label='Random Sampling')
ax.axhline(performance_baseline, c='k', linestyle='--', linewidth=2,
           label=f'{baseline_trials} Uniform Trials')

ax.tick_params(axis='both', labelsize=18)
ax.set_xlabel('Average Trials per Current Level', fontsize=24)
ax.set_ylabel(r'RMSE', fontsize=24)
ax.set_title(f'Active vs. random sampling ({DATASET})', fontsize=24)
ax.legend(fontsize=20)

# fig.savefig(f'triplet_AL_full_{DATASET}.png', dpi=300)
plt.show()

In [None]:
def _scatter3d(fig_num, points, colors):
    """Draw a 3-D scatter of current triplets, coloured by probability, in figure ``fig_num``.

    The first three columns of ``points`` are plotted as I_1, I_2, I_3 (uA) on
    fixed [-2, 2] axes, with the colour scale pinned to [0, 1].
    """
    fig = plt.figure(fig_num)
    fig.clear()
    # add_subplot(projection='3d') replaces the deprecated
    # Axes3D(fig, auto_add_to_figure=False) + fig.add_axes(ax) pattern.
    ax = fig.add_subplot(projection='3d')
    ax.set_xlabel(r'$I_1$ ($\mu$A)', fontsize=16)
    ax.set_ylabel(r'$I_2$ ($\mu$A)', fontsize=16)
    ax.set_zlabel(r'$I_3$ ($\mu$A)', fontsize=16)
    ax.set_xlim(-2, 2)
    ax.set_ylim(-2, 2)
    ax.set_zlim(-2, 2)
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], marker='o',
               c=colors, s=20, alpha=0.8, vmin=0, vmax=1)
    plt.show()


# Visual check of every scored (cell, pattern) pair: ground truth, active-learning
# fit, sampled empirical data, cleaned data, and an MLE refit on the cleaned data.
for i in range(len(probs_pred_true)):
    for j in range(len(probs_pred_true[i])):
        if np.count_nonzero(probs_pred_true[i][j] > spont_limit) >= min_active_inds:

            sampled_inds = np.where(T_prev[j] > 0)[0]
            print(all_cells[i], j+1)

            _scatter3d(0, amps_scan[j], probs_pred_true[i][j])
            _scatter3d(1, amps_scan[j], probs_curr[i][j])

            error_cell = np.sqrt(np.sum((probs_curr[i][j] - probs_pred_true[i][j])**2) / len(probs_pred_true[i][j]))
            print(f'Error: {error_cell}')

            X_sampled = amps_scan[j][sampled_inds]
            T_sampled = T_prev[j][sampled_inds]
            # Integer-array indexing returns a copy, so zeroing the ambiguous
            # points below does not mutate the global probs_empirical.
            p_sampled = probs_empirical[i][j][sampled_inds]

            _scatter3d(2, X_sampled, p_sampled)

            clean_inds = mutils.triplet_cleaning(X_sampled, p_sampled, T_sampled,
                                                 return_inds=True)
            dirty_inds = np.setdiff1d(np.arange(len(sampled_inds), dtype=int), clean_inds)
            p_sampled[dirty_inds] = 0

            X_clean, p_clean, T_clean = fitting.disambiguate_fitting(X_sampled, p_sampled,
                                                                     T_sampled,
                                                                     w_inits_array[i][j])
            #                                                        priors=priors_array[i][j])

            params_MLE, _, _ = fitting.fit_surface(X_clean, p_clean, T_clean, w_inits_array[i][j])
            probs_MLE = fitting.sigmoidND_nonlinear(sm.add_constant(amps_scan[j], has_constant='add'),
                                                    params_MLE)

            _scatter3d(3, X_clean, p_clean)
            _scatter3d(4, amps_scan[j], probs_MLE)

            # Pause so each pair can be inspected; press Enter to continue.
            input()