In [3]:
import numpy as np
import pickle
import matplotlib.pyplot as plt
from scipy.io import loadmat
import visionloader as vl
import statsmodels.api as sm
import os

# Load in triplet stim elecResps and amplitudes

In [4]:
# Electrical-stimulation dataset (piece/datarun) and the analysis root it lives under.
analysis_base = '/Volumes/Analysis/'
dataset = '2021-05-27-4/data003'
electrical_path = os.path.join(analysis_base, dataset)  # note: no trailing separator

In [6]:
# Vision-format recording (loaded with visionloader); supplies the electrode map.
vis_datarun = 'data001'
vis_datapath = '/Volumes/Analysis/2021-05-27-0/kilosort_data001/' + vis_datarun
vision_flags = dict(include_neurons=True, include_ei=True,
                    include_params=True, include_noise=True)
vcd = vl.load_vision_data(vis_datapath, vis_datarun, **vision_flags)

coords = vcd.get_electrode_map()

In [7]:
# Stimulation pattern number (selects pattern_files/p<p>.mat) and the cell ID
# that appears in the gsort output filenames.
p, n = 24, 3352

In [None]:
# Load pattern p's stimulating electrodes and amplitude table, then, for every
# amplitude row k, cell n's gsort spike probability and mean residual.
# NOTE(review): this yields '.../<dataset>p<p>/' with no '/' between the dataset
# and 'p<p>'; confirm that matches the gsort output layout.
gsort_base = ('/Volumes/Scratch/Users/praful/triplet_gsort_v2_30um_periphery-affinity/'
              + dataset + 'p' + str(p) + '/')

# electrical_path has no trailing separator, so join instead of '+' (which gave
# '.../data003pattern_files/...'). The .mat is loaded once and both fields read.
pattern_struct = loadmat(os.path.join(electrical_path, 'pattern_files', 'p' + str(p) + '.mat'),
                         squeeze_me=True, struct_as_record=False)['patternStruct']
triplet_elecs = pattern_struct.stimElecs
amplitudes = pattern_struct.amplitudes

# Per the original note, only dataset 2021-02-13-6 needs the amplitude table
# truncated to its first 5999 rows; every other dataset uses all of them.
if dataset.split('/')[0] == '2021-02-13-6':
    amplitudes = amplitudes[:5999]

num_pts = len(amplitudes)

target_spike_probs = np.zeros(num_pts)
target_valid = np.zeros(num_pts)  # not populated here (the 'valid_graph' read is disabled)
residuals = []
for k in range(num_pts):
    if k % 500 == 0:
        print(f'loading {k}/{num_pts}')
    file_suffix = '_n' + str(n) + '_p' + str(p) + '_k' + str(k) + '.pkl'
    # The outputs live under gsort_base (the old code referenced an undefined
    # `filepath`). pickle.load runs arbitrary code: only point this at trusted files.
    with open(gsort_base + 'gsort_tri_v2' + file_suffix, 'rb') as f:
        target_spike_probs[k] = pickle.load(f)['prob']

    with open(gsort_base + 'residual_tri_v2' + file_suffix, 'rb') as f:
        residuals.append(np.mean(pickle.load(f)['res'], axis=0))

residuals = np.array(residuals)

In [None]:
# Mark the triplet's stimulating electrodes on the array layout. triplet_elecs
# holds 1-based electrode numbers, hence the `- 1` when indexing coords.
xcoords = coords[:, 0]
ycoords = coords[:, 1]

fig, ax = plt.subplots(1, 1, figsize=(10, 8))
ax.tick_params(labelsize=22)

for elec in triplet_elecs:
    elec_x, elec_y = xcoords[elec - 1], ycoords[elec - 1]
    ax.text(elec_x, elec_y, str(elec), color='k', fontsize=22)
    ax.scatter(elec_x, elec_y, marker='*', c='k', s=300)

In [None]:
# Stimulating electrodes of pattern p (1-based numbers; plotted via coords[elec - 1]).
triplet_elecs

In [None]:
import smart_sigmoid_triplet as sst
import estim_utils_pkv as utils

In [None]:
# Sorted distinct stimulation amplitudes found on the first electrode. Later
# cells compare every electrode's amplitudes against this list, which assumes
# all three electrodes were driven at the same set of levels.
current_levels = np.unique(amplitudes.T[0])

In [None]:
# Fit an activation curve along one free electrode for a single slice of the
# triplet amplitude grid (the other two electrodes held at fixed levels), and
# plot raw vs. monotonized probabilities together with the fitted curve.
noise_thr = 0.2   # probabilities below this count as "no response"
split_size = 3    # a gap of more than this many sweep steps starts a new segment
free_el = 1       # amplitudes column (0-based) that is swept
level2 = 18       # index into current_levels for fixed_els[1]
level1 = 5        # index into current_levels for fixed_els[0]

fixed_els = np.setdiff1d(np.arange(3), free_el)
slice_label = (r'$I_' + str(fixed_els[0] + 1) + '$ = ' + str(current_levels[level1])
               + ', ' + '$I_' + str(fixed_els[1] + 1) + '$ = ' + str(current_levels[level2]))

# Pattern indices of this slice, ordered by the free electrode's amplitude.
fixed_inds = np.where((amplitudes[:, fixed_els[0]] == current_levels[level1]) &
                      (amplitudes[:, fixed_els[1]] == current_levels[level2]))[0]
fixed_inds = fixed_inds[np.argsort(amplitudes[fixed_inds, free_el])]
free_amps = amplitudes[fixed_inds, free_el]
slice_probs = target_spike_probs[fixed_inds]

# Above-noise positions *within the slice* (0 .. len(fixed_inds)-1), split into
# contiguous segments. Built over len(fixed_inds), not len(current_levels), so
# an incomplete slice cannot index past its end.
low_inds = np.where(slice_probs < noise_thr)[0]
high_inds = np.setdiff1d(np.arange(len(fixed_inds)), low_inds)
high_split = np.split(high_inds, np.where(np.diff(high_inds) > split_size)[0] + 1)
target_spike_probs_mono = np.copy(target_spike_probs)
xsigmoid = np.linspace(-2, 2, 100)


def _segment_line(seg):
    """Least-squares (slope, intercept) of slice probability vs. free amplitude
    over slice positions `seg`. Slope is 0 when the segment has no amplitude
    spread (a single point), as for an ordinary least-squares fit with intercept."""
    seg_amps, seg_probs = free_amps[seg], slice_probs[seg]
    if np.ptp(seg_amps) == 0:
        return 0.0, float(np.mean(seg_probs))
    slope, intercept = np.polyfit(seg_amps, seg_probs, 1)
    return slope, intercept


fig, ax = plt.subplots(figsize=(10, 8))
ax.tick_params(labelsize=16)

try:
    if len(high_split[0]) == 0:
        pass  # nothing above noise: no curve to fit

    elif len(high_split) == 1:
        seg = high_split[0]
        first, last = seg[0], seg[-1]
        if len(seg) == 1 and 0 < first < len(fixed_inds) - 1:
            # Lone above-noise point with a neighbour on each side of the sweep.
            # Neighbours are fixed_inds[first -/+ 1] (adjacent along the free
            # electrode), not pattern index fixed_inds[first] -/+ 1.
            centre = residuals[fixed_inds[first]]
            neg_side_norm = np.linalg.norm(centre - residuals[fixed_inds[first - 1]], ord='fro')
            pos_side_norm = np.linalg.norm(centre - residuals[fixed_inds[first + 1]], ord='fro')
            if neg_side_norm > pos_side_norm:
                target_spike_probs_mono[fixed_inds[:first]] = 1
                target_spike_probs_mono[fixed_inds[last + 1:]] = 0
            else:
                target_spike_probs_mono[fixed_inds[last + 1:]] = 1
                target_spike_probs_mono[fixed_inds[:first]] = 0

            print(pos_side_norm, neg_side_norm)

        else:
            # Saturate whichever tail the segment's linear trend points toward.
            slope, intercept = _segment_line(seg)
            if slope < 0:
                target_spike_probs_mono[fixed_inds[:first]] = 1
                target_spike_probs_mono[fixed_inds[last + 1:]] = 0
            elif slope > 0:
                target_spike_probs_mono[fixed_inds[:first]] = 0
                target_spike_probs_mono[fixed_inds[last + 1:]] = 1
            ax.plot(free_amps[seg], free_amps[seg] * slope + intercept, c='tab:blue')

        _, popt, _ = utils.fit_sigmoid(free_amps, target_spike_probs_mono[fixed_inds],
                                       param_guess=[1, np.mean(free_amps[seg])])
        ax.plot(xsigmoid, utils.sigmoid(xsigmoid, popt[0], popt[1]), c='tab:orange')

    elif len(high_split) == 2:
        seg0, seg1 = high_split
        # Two response segments: saturate both outer tails.
        target_spike_probs_mono[fixed_inds[:seg0[0]]] = 1
        target_spike_probs_mono[fixed_inds[seg1[-1] + 1:]] = 1

        for seg in (seg0, seg1):
            slope, intercept = _segment_line(seg)
            ax.plot(free_amps[seg], free_amps[seg] * slope + intercept, c='tab:blue')

        _, popt, _ = utils.fit_double_sigmoid(free_amps, target_spike_probs_mono[fixed_inds],
                                              param_guess=[1, np.mean(free_amps[seg0]),
                                                           1, np.mean(free_amps[seg1])])
        ax.plot(xsigmoid, utils.double_sigmoid(xsigmoid, popt[0], popt[1], popt[2], popt[3]),
                c='tab:orange')

    else:
        raise ValueError('Too many splits.')

    ax.scatter(free_amps, slice_probs, label=slice_label + ' (raw)', c='tab:blue')
    ax.scatter(free_amps, target_spike_probs_mono[fixed_inds],
               label=slice_label + ' (monotonized)', c='tab:orange')

except Exception as err:
    # Best effort: report why the fit failed (a bare `except:` used to hide even
    # NameErrors), then show the data without a fitted curve.
    print(f'Activation-curve fit failed for {slice_label}: {err!r}')
    ax.scatter(free_amps, slice_probs, label=slice_label + ' (raw)', c='tab:blue')
    ax.plot(free_amps, target_spike_probs_mono[fixed_inds],
            label=slice_label + ' (monotonized)', c='tab:orange')

ax.set_ylabel('Activation Probability', fontsize=22)
ax.set_xlabel(r'$I_' + str(free_el + 1) + '$', fontsize=22)
ax.legend(fontsize=16)

In [None]:
# Monotonized activation probabilities along the free-electrode sweep of the last fitted slice.
target_spike_probs_mono[fixed_inds]

In [None]:
# Parameters returned by the most recent utils.fit_sigmoid / utils.fit_double_sigmoid call.
popt

In [None]:
# Contiguous runs of above-noise positions along the last slice's free-electrode sweep.
high_split

In [None]:
# Sweep the first fixed electrode over every current level (second fixed
# electrode held at current_levels[level2]) and fit an activation curve along
# the free electrode for each slice. Fitted curves are stacked row-wise into
# activation_curve_3d (row = level1 index; a row stays 0 when no fit was made).
noise_thr = 0.2   # probabilities below this count as "no response"
split_size = 3    # a gap of more than this many sweep steps starts a new segment
free_el = 1
level2 = 4

currentRange = range(20)
xsigmoid = np.linspace(-2, 2, 100)
activation_curve_3d = np.zeros((len(currentRange), len(xsigmoid)))
fixed_els = np.setdiff1d(np.arange(3), free_el)

# A fresh figure with one axes: plt.figure(0) could reuse an earlier figure, and
# plt.xticks before add_subplot(111) styled a different axes than the one drawn on.
fig, ax = plt.subplots(figsize=(15, 12))
ax.tick_params(labelsize=16)
NUM_COLORS = len(currentRange)
cm = plt.get_cmap('gist_rainbow')
ax.set_prop_cycle('color', [cm(1. * i / NUM_COLORS) for i in range(NUM_COLORS)])

for cnt, level1 in enumerate(currentRange):
    # Pattern indices of this slice, ordered by the free electrode's amplitude.
    fixed_inds = np.where((amplitudes[:, fixed_els[0]] == current_levels[level1]) &
                          (amplitudes[:, fixed_els[1]] == current_levels[level2]))[0]
    fixed_inds = fixed_inds[np.argsort(amplitudes[fixed_inds, free_el])]
    free_amps = amplitudes[fixed_inds, free_el]
    slice_label = (r'$I_' + str(fixed_els[0] + 1) + '$ = ' + str(current_levels[level1])
                   + ', ' + '$I_' + str(fixed_els[1] + 1) + '$ = ' + str(current_levels[level2]))

    # Above-noise positions within the slice, split into contiguous segments.
    low_inds = np.where(target_spike_probs[fixed_inds] < noise_thr)[0]
    high_inds = np.setdiff1d(np.arange(len(fixed_inds)), low_inds)
    high_split = np.split(high_inds, np.where(np.diff(high_inds) > split_size)[0] + 1)
    target_spike_probs_mono = np.copy(target_spike_probs)

    curve = None
    try:
        if len(high_split[0]) == 0:
            pass  # nothing above noise in this slice

        elif len(high_split) == 1:
            seg = high_split[0]
            first, last = seg[0], seg[-1]
            if len(seg) == 1 and 0 < first < len(fixed_inds) - 1:
                # Lone above-noise point: compare residuals with its sweep
                # neighbours fixed_inds[first -/+ 1] (not pattern index -/+ 1).
                centre = residuals[fixed_inds[first]]
                neg_side_norm = np.linalg.norm(centre - residuals[fixed_inds[first - 1]], ord='fro')
                pos_side_norm = np.linalg.norm(centre - residuals[fixed_inds[first + 1]], ord='fro')
                if neg_side_norm > pos_side_norm:
                    target_spike_probs_mono[fixed_inds[:first]] = 1
                    target_spike_probs_mono[fixed_inds[last + 1:]] = 0
                else:
                    target_spike_probs_mono[fixed_inds[last + 1:]] = 1
                    target_spike_probs_mono[fixed_inds[:first]] = 0
            else:
                # OLS slope of the segment; 0 for a single point, where it is undefined.
                seg_amps = free_amps[seg]
                slope = (np.polyfit(seg_amps, target_spike_probs[fixed_inds[seg]], 1)[0]
                         if np.ptp(seg_amps) > 0 else 0.0)
                if slope < 0:
                    target_spike_probs_mono[fixed_inds[:first]] = 1
                    target_spike_probs_mono[fixed_inds[last + 1:]] = 0
                elif slope > 0:
                    target_spike_probs_mono[fixed_inds[:first]] = 0
                    target_spike_probs_mono[fixed_inds[last + 1:]] = 1

            _, popt, _ = utils.fit_sigmoid(free_amps, target_spike_probs_mono[fixed_inds],
                                           param_guess=[1, np.mean(free_amps[seg])])
            curve = utils.sigmoid(xsigmoid, popt[0], popt[1])

        elif len(high_split) == 2:
            seg0, seg1 = high_split
            # Two response segments: saturate both outer tails.
            target_spike_probs_mono[fixed_inds[:seg0[0]]] = 1
            target_spike_probs_mono[fixed_inds[seg1[-1] + 1:]] = 1
            _, popt, _ = utils.fit_double_sigmoid(free_amps, target_spike_probs_mono[fixed_inds],
                                                  param_guess=[1, np.mean(free_amps[seg0]),
                                                               1, np.mean(free_amps[seg1])])
            curve = utils.double_sigmoid(xsigmoid, popt[0], popt[1], popt[2], popt[3])

        else:
            raise ValueError('Too many splits.')

    except Exception as err:
        # Best effort: one bad slice should not stop the sweep, but say why it failed.
        print(f'Activation-curve fit failed for {slice_label}: {err!r}')
        curve = None

    if curve is not None:
        activation_curve_3d[cnt] = curve
        ax.scatter(free_amps, target_spike_probs_mono[fixed_inds], label=slice_label)
        ax.plot(xsigmoid, curve)

ax.set_ylabel('Activation Probability', fontsize=22)
ax.set_xlabel(r'$I_' + str(free_el + 1) + '$', fontsize=22)
ax.legend(fontsize=16)

In [None]:
from mpl_toolkits import mplot3d
%matplotlib notebook
from matplotlib.animation import FuncAnimation
import matplotlib.animation as animation

In [None]:
# Surface-plot grid: I = level of fixed_els[0] (over currentRange), S = free amplitude.
I, S = np.meshgrid(np.take(current_levels, currentRange), xsigmoid)

In [None]:
# Activation probability as a surface over (fixed_els[0] level, free amplitude).
surf_xlabel = r'$I_' + str(fixed_els[0] + 1) + '$'
surf_ylabel = r'$I_' + str(free_el + 1) + '$'
surf_title = r'$I_' + str(fixed_els[1] + 1) + '$ = ' + str(current_levels[level2])

fig = plt.figure(figsize=(10, 8))
ax = plt.axes(projection="3d")
ax.plot_surface(I, S, activation_curve_3d.T)
ax.set_xlabel(surf_xlabel)
ax.set_ylabel(surf_ylabel)
ax.set_zlabel('Activation Probability')
ax.set_title(surf_title)

In [None]:
def _segment_slope(x, y):
    """Ordinary-least-squares slope of ``y`` on ``x`` (with intercept).

    Returns 0.0 when ``x`` has no spread (e.g. a single point), where the slope
    is undefined; callers then leave the curve's tails untouched.
    """
    if np.ptp(x) == 0:
        return 0.0
    return np.polyfit(x, y, 1)[0]


def _fit_slice_curve(fixed_inds, free_el, xsigmoid, noise_thr, split_size):
    """Monotonize one free-electrode sweep and fit a single or double sigmoid.

    ``fixed_inds`` are pattern indices of the slice, sorted by the amplitude on
    ``free_el``. Reads module-level ``amplitudes``, ``target_spike_probs``,
    ``residuals`` and ``utils``.

    Returns the fitted curve evaluated at ``xsigmoid``, or None when no point of
    the slice reaches ``noise_thr``. Raises ValueError('Too many splits.') when
    the above-noise points form more than two segments; errors raised by the
    ``utils`` fitters propagate.
    """
    amps = amplitudes[fixed_inds, free_el]
    probs = target_spike_probs[fixed_inds]
    mono = probs.copy()

    # Above-noise positions within the slice (0 .. len(fixed_inds)-1), split
    # wherever consecutive positions are more than split_size steps apart.
    high_inds = np.setdiff1d(np.arange(len(fixed_inds)), np.where(probs < noise_thr)[0])
    high_split = np.split(high_inds, np.where(np.diff(high_inds) > split_size)[0] + 1)

    if len(high_split[0]) == 0:
        return None

    if len(high_split) == 1:
        seg = high_split[0]
        first, last = seg[0], seg[-1]
        if len(seg) == 1 and 0 < first < len(fixed_inds) - 1:
            # Lone above-noise point with a neighbour on each side of the sweep
            # (neighbours are fixed_inds[first -/+ 1], not pattern index -/+ 1):
            # the side whose residual differs more from this point's is set to 1,
            # the other side to 0.
            centre = residuals[fixed_inds[first]]
            neg_side_norm = np.linalg.norm(centre - residuals[fixed_inds[first - 1]], ord='fro')
            pos_side_norm = np.linalg.norm(centre - residuals[fixed_inds[first + 1]], ord='fro')
            if neg_side_norm > pos_side_norm:
                mono[:first] = 1
                mono[last + 1:] = 0
            else:
                mono[:first] = 0
                mono[last + 1:] = 1
        else:
            # Saturate whichever tail the segment's linear trend points toward.
            slope = _segment_slope(amps[seg], probs[seg])
            if slope < 0:
                mono[:first] = 1
                mono[last + 1:] = 0
            elif slope > 0:
                mono[:first] = 0
                mono[last + 1:] = 1
        _, popt, _ = utils.fit_sigmoid(amps, mono, param_guess=[1, np.mean(amps[seg])])
        return utils.sigmoid(xsigmoid, popt[0], popt[1])

    if len(high_split) == 2:
        seg0, seg1 = high_split
        # Two response segments: saturate both outer tails.
        mono[:seg0[0]] = 1
        mono[seg1[-1] + 1:] = 1
        _, popt, _ = utils.fit_double_sigmoid(
            amps, mono,
            param_guess=[1, np.mean(amps[seg0]), 1, np.mean(amps[seg1])])
        return utils.double_sigmoid(xsigmoid, popt[0], popt[1], popt[2], popt[3])

    raise ValueError('Too many splits.')


def get_act_curve(free_el, target_prob=0.5, noise_thr=0.2, split_size=3,
                  current_range=range(3, 17), xsigmoid=None, verbose=True):
    """Fit activation curves over a 2-D grid of fixed-electrode levels and
    return the amplitudes where each curve crosses ``target_prob``.

    For every pair (level1, level2) drawn from ``current_range``, the two
    electrodes other than ``free_el`` are held at ``current_levels[level1]`` and
    ``current_levels[level2]``; the slice along ``free_el`` is monotonized and
    fit with ``_fit_slice_curve``. Reads module-level ``amplitudes``,
    ``target_spike_probs``, ``residuals``, ``current_levels`` and ``utils``.

    Parameters
    ----------
    free_el : int
        Column (0, 1 or 2) of ``amplitudes`` that is swept.
    target_prob : float
        Probability level whose crossings define a threshold.
    noise_thr : float
        Probabilities below this are treated as no response.
    split_size : int
        A gap of more than this many sweep steps starts a new response segment.
    current_range : sequence of int
        Indices into ``current_levels`` used for both fixed electrodes.
    xsigmoid : array-like or None
        Free-electrode amplitudes at which fitted curves are evaluated;
        defaults to ``np.linspace(-2, 2, 100)``.
    verbose : bool
        Print each threshold found and each slice whose fit failed.

    Returns
    -------
    activation_curve_3d : ndarray, shape (len(current_range), len(current_range), len(xsigmoid))
        ``[i, j]`` is the curve with level2 = current_range[i] and
        level1 = current_range[j]; left at 0 where no fit was made.
    thresholds : ndarray, shape (n_thresholds, 3)
        Rows ``[free amplitude, fixed_els[0] amplitude, fixed_els[1] amplitude]``;
        shape (0, 3) when no crossing is found.
    """
    if xsigmoid is None:
        xsigmoid = np.linspace(-2, 2, 100)
    n_levels = len(current_range)
    activation_curve_3d = np.zeros((n_levels, n_levels, len(xsigmoid)))
    thresholds = []
    fixed_els = np.setdiff1d(np.arange(3), free_el)

    for i, level2 in enumerate(current_range):
        for j, level1 in enumerate(current_range):
            # Pattern indices of this slice, ordered by the free electrode's amplitude.
            fixed_inds = np.where((amplitudes[:, fixed_els[0]] == current_levels[level1]) &
                                  (amplitudes[:, fixed_els[1]] == current_levels[level2]))[0]
            fixed_inds = fixed_inds[np.argsort(amplitudes[fixed_inds, free_el])]

            try:
                curve = _fit_slice_curve(fixed_inds, free_el, xsigmoid, noise_thr, split_size)
            except Exception as err:
                # Best effort: one bad slice should not abort the grid; it stays at 0.
                if verbose:
                    print(f'get_act_curve: fit failed for level1={level1}, level2={level2}: {err!r}')
                curve = None
            if curve is not None:
                activation_curve_3d[i, j] = curve

            # Grid indices where the curve changes side of target_prob.
            idx = np.argwhere(np.diff(np.sign(activation_curve_3d[i, j] - target_prob))).flatten()
            for ind in idx:
                threshold = np.array([xsigmoid[ind], current_levels[level1], current_levels[level2]])
                if verbose:
                    print(threshold)
                thresholds.append(threshold)

    # reshape keeps the (N, 3) layout even when no threshold was found.
    return activation_curve_3d, np.array(thresholds).reshape(-1, 3)

In [None]:
# Full grid of fitted curves for the current free electrode, plus every point
# where a curve crosses 50% activation.
full_act_curve, thresholds = get_act_curve(free_el=free_el, target_prob=0.5)

In [None]:
# Electrode held by each threshold column: [free_el, fixed_els[0], fixed_els[1]].
order = np.hstack(([free_el], fixed_els))

In [None]:
# Reorder threshold columns from (free, fixed_0, fixed_1) into electrode order
# (I1, I2, I3): argsort of the permutation `order` is its inverse.
thresholds_sorted = thresholds[:, np.argsort(order)]

In [None]:
# Threshold points with columns in electrode order [I1, I2, I3].
thresholds_sorted

In [None]:
# Design matrix [1, I1, I2] and target I3 for a plane fit of the threshold
# surface. has_constant='add' forces the intercept column even if I1 or I2 is
# constant across all thresholds; statsmodels' default ('skip') would silently
# omit it and shift the column layout the next cells rely on.
X = sm.add_constant(thresholds_sorted[:, :2], has_constant='add')
y = thresholds_sorted[:, 2]

In [None]:
# Least-squares plane fit I3 ≈ X @ coef; X already carries the intercept column,
# so no extra intercept is fitted. LinearRegression was never imported here;
# numpy's lstsq gives the same solution, and `clf.coef_` keeps the same name.
from types import SimpleNamespace

plane_coef, *_ = np.linalg.lstsq(X, y, rcond=None)
clf = SimpleNamespace(coef_=plane_coef)

In [None]:
# Plane-fit prediction of I3 at each threshold point (rows of X are [1, I1, I2]).
X @ clf.coef_

In [None]:
# Threshold points in electrode space (I1, I2, I3).
fig = plt.figure(figsize=(10, 8))
ax = plt.axes(projection="3d")
ax.scatter(*thresholds_sorted[:, :3].T)
ax.set_xlabel(r'$I_1$')
ax.set_ylabel(r'$I_2$')
ax.set_zlabel(r'$I_3$')

In [None]:
from mpl_toolkits.mplot3d import Axes3D
import networkx as nx
from matplotlib.animation import PillowWriter 

In [None]:
# Create a figure and a 3D Axes
fig = plt.figure()
ax = Axes3D(fig)
plt.xlabel(r'$I_1$')
plt.ylabel(r'$I_2$')
ax.set_zlabel(r'$I_3$')

# Create an init function and the animate functions.
# Both are explained in the tutorial. Since we are changing
# the the elevation and azimuth and no objects are really
# changed on the plot we don't have to return anything from
# the init and animate function. (return value is explained
# in the tutorial.
def init():
    ax.scatter(thresholds_sorted[:, 0], 
               thresholds_sorted[:, 1],
               thresholds_sorted[:, 2], marker='o', s=20, c="tab:blue", alpha=0.6)
    
#     ax.scatter(thresholds2_sorted[:, 0], 
#                thresholds2_sorted[:, 1],
#                thresholds2_sorted[:, 2], marker='o', s=20, c="tab:orange", alpha=0.6)
    
    return fig,

def animate(i):
    ax.view_init(elev=10., azim=i)
    return fig,

# Animate
anim = animation.FuncAnimation(fig, animate, init_func=init,
                               frames=360, interval=20, blit=True)
# Save
# anim.save('/Volumes/Lab/Users/praful/thresh_surface_comp_p.gif', writer='imagemagick', fps=30)
plt.show()

In [None]:
# Animate get_act_curve's grid: one frame per level of fixed_els[1], each a
# surface over (fixed_els[0] level, free amplitude).
# The grid must use the same level range as get_act_curve (default range(3, 17)).
# The old cell reused currentRange = range(20) and I, S from the single-sweep
# cells, which mismatched full_act_curve's shape and ran frames past its end.
anim_levels = range(3, 17)
assert full_act_curve.shape == (len(anim_levels), len(anim_levels), len(xsigmoid)), \
    'anim_levels / xsigmoid must match the arguments used for get_act_curve'
I_anim, S_anim = np.meshgrid(current_levels[list(anim_levels)], xsigmoid)

fig = plt.figure()
# Figure.gca() no longer accepts a projection argument (removed in Matplotlib 3.6).
ax = fig.add_subplot(projection="3d")


def update(frame, fig):
    """Redraw the surface for level index `frame` of the second fixed electrode."""
    ax = fig.axes[0]
    ax.clear()  # clear() already removes the previous frame's surface
    surf = ax.plot_surface(I_anim, S_anim, full_act_curve[frame].T)
    ax.set_xlabel(r'$I_' + str(fixed_els[0] + 1) + '$')
    ax.set_ylabel(r'$I_' + str(free_el + 1) + '$')
    ax.set_zlabel('Activation Probability')
    ax.set_title(r'$I_' + str(fixed_els[1] + 1) + '$ = ' + str(current_levels[anim_levels[frame]]))
    return surf,


# blit=False: each frame clears and redraws the axes (labels, title, surface).
ani = FuncAnimation(fig, update, fargs=[fig], frames=len(anim_levels), blit=False, repeat=True)
plt.show()

ani.save('/Volumes/Lab/Users/praful/test.gif', writer='imagemagick', fps=10)