In [None]:
import numpy as np
import pickle
import matplotlib.pyplot as plt
from scipy.io import loadmat
import visionloader as vl
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import PolynomialFeatures, StandardScaler
from sklearn.decomposition import PCA
import statsmodels.api as sm
import os
import src.fitting as fitting
import src.multielec_utils as mutils
from mpl_toolkits.mplot3d import Axes3D
from sklearn.cluster import KMeans
from scipy.optimize import minimize
from sklearn.cluster import SpectralClustering

%load_ext autoreload
%autoreload 2
%matplotlib ipympl

# Load triplet-stimulation electrical responses (gsort activation probabilities) and amplitudes

In [None]:
# Root of the lab analysis volume and the directory holding the triplet gsort
# results. Both can be overridden through environment variables so the notebook
# runs on machines where the volumes are mounted elsewhere; the defaults are the
# original hard-coded locations.
ANALYSIS_BASE = os.environ.get("ANALYSIS_BASE", "/Volumes/Analysis")
gsort_path = os.environ.get(
    "GSORT_PATH",
    "/Volumes/Scratch/Users/praful/triplet_gsort_v2_30um_periphery-affinity_cosine",
)

In [None]:
# Electrical-stimulation recording: the dataset id and the estim run inside it.
dataset = "2020-10-18-5"
estim = "data006/data006-all_v2"
electrical_path = os.path.join(ANALYSIS_BASE,
                               dataset,
                               estim)

In [None]:
# Visual (wnoise) run of the same dataset: load it through Vision to get the
# electrode map.
wnoise = "kilosort_data002/data002"
vis_datapath = os.path.join(ANALYSIS_BASE, dataset, wnoise)
vis_datarun = wnoise.rsplit('/', 1)[-1]  # run name = last path component
vision_opts = dict(
    include_neurons=True,
    include_ei=True,
    include_params=True,
    include_noise=True,
)
vcd = vl.load_vision_data(vis_datapath, vis_datarun, **vision_opts)

coords = vcd.get_electrode_map()

In [None]:
# Analysis parameters:
#   patterns -- stimulation pattern ids to load
#   n        -- cell id embedded in the gsort result file names (presumably the target neuron)
#   p_thr    -- points whose activation probability is <= p_thr are discarded
patterns, n, p_thr = np.array([2]), 220, 0.08

In [None]:
# Load every pattern's gsort results, record which electrodes were stimulated,
# and count above-threshold points whose amplitudes are all negative or all
# positive. amplitudes, triplet_probs and good_inds are left holding the LAST
# pattern's values; the plotting and clustering cells below rely on them.
all_elecs = []
neg_inds_total = 0
pos_inds_total = 0

for i, p in enumerate(patterns):
    print(f"Pattern {p}")
    filepath = os.path.join(gsort_path, dataset, estim, wnoise, f"p{p}")

    triplet_elecs = mutils.get_stim_elecs_newlv(electrical_path, p)
    amplitudes = mutils.get_stim_amps_newlv(electrical_path, p)
    num_pts = len(amplitudes)

    # One pickle per amplitude index k; keep the first entry of its "cosine_prob".
    triplet_probs = np.zeros(num_pts)
    for k in range(num_pts):
        pkl_name = f"gsort_tri_v2_n{n}_p{p}_k{k}.pkl"
        with open(os.path.join(filepath, pkl_name), "rb") as f:
            prob_dict = pickle.load(f)
        triplet_probs[k] = prob_dict["cosine_prob"][0]

    above_thr = triplet_probs > p_thr
    neg_inds = np.flatnonzero(np.all(amplitudes < 0, axis=1) & above_thr)
    pos_inds = np.flatnonzero(np.all(amplitudes > 0, axis=1) & above_thr)
    good_inds = np.flatnonzero(above_thr)
    all_elecs.append(triplet_elecs)

    pos_inds_total += pos_inds.size
    neg_inds_total += neg_inds.size

In [None]:
# Number of above-threshold points (any polarity) for the last pattern loaded.
good_inds.size

In [None]:
# 3D scatter of the above-threshold amplitude combinations of the last pattern
# loaded, colored by activation probability.
fig = plt.figure()
# Axes3D(fig) no longer attaches itself to the figure in matplotlib >= 3.6, so it
# would not be drawn (and plt.xlabel would create a stray 2D axes). Create the 3D
# axes through the figure and label it directly instead.
ax = fig.add_subplot(projection='3d')
ax.set_xlabel(r'$I_1$')
ax.set_ylabel(r'$I_2$')
ax.set_zlabel(r'$I_3$')
ax.set_title('Above-threshold amplitudes')

good_amps = amplitudes[good_inds]
scat = ax.scatter(good_amps[:, 0], good_amps[:, 1], good_amps[:, 2],
                  marker='o', s=20, c=triplet_probs[good_inds], alpha=0.8)

clb = fig.colorbar(scat, ax=ax)
clb.set_label('Activation probability')
plt.show()

In [None]:
# Split the above-threshold amplitude combinations (last pattern loaded) into two
# groups with spectral clustering on a nearest-neighbours affinity graph. A fixed
# random_state makes the label assignment reproducible on Restart & Run All.
clustering = SpectralClustering(
    n_clusters=2,
    assign_labels='discretize',
    affinity='nearest_neighbors',
    random_state=0,
).fit(amplitudes[good_inds])

In [None]:
# 3D scatter of the above-threshold amplitude combinations of the last pattern
# loaded, colored by spectral-clustering label.
fig = plt.figure(1)
fig.clear()
# Axes3D(fig) no longer attaches itself to the figure in matplotlib >= 3.6, so it
# would not be drawn (and plt.xlabel would create a stray 2D axes). Create the 3D
# axes through the figure and label it directly instead.
ax = fig.add_subplot(projection='3d')
ax.set_xlabel(r'$I_1$')
ax.set_ylabel(r'$I_2$')
ax.set_zlabel(r'$I_3$')
ax.set_title('Spectral clustering of above-threshold amplitudes')

good_amps = amplitudes[good_inds]
scat = ax.scatter(good_amps[:, 0], good_amps[:, 1], good_amps[:, 2],
                  marker='o', s=20, c=clustering.labels_, alpha=0.8)

plt.show()

In [None]:
# # Optional (disabled): render a rotating animation of the 3D probability scatter.
# # NOTE(review): this cell uses `animation`, which is not imported in this
# # notebook; it would need `from matplotlib import animation`. In matplotlib
# # >= 3.6, `Axes3D(fig)` also no longer attaches itself to the figure;
# # `fig.add_subplot(projection='3d')` is the replacement.

# # Create a figure and a 3D Axes
# fig = plt.figure(1)
# ax = Axes3D(fig)
# plt.xlabel(r'$I_1$')
# plt.ylabel(r'$I_2$')
# ax.set_zlabel(r'$I_3$')

# # Define the init and animate functions (both are explained in the
# # matplotlib animation tutorial). Only the elevation and azimuth change
# # between frames and no plotted objects are modified, so neither function
# # needs to return updated artists (return values are explained in the
# # tutorial).
# def init():
#     ax.scatter(amplitudes[:, 0][good_inds], 
#                amplitudes[:, 1][good_inds],
#                amplitudes[:, 2][good_inds], marker='o', s=20, c=triplet_probs[good_inds], alpha=0.8)

#     return fig,

# def animate(i):
#     ax.view_init(elev=10., azim=i)
#     return fig,

# # Animate: one frame per degree of azimuth
# anim = animation.FuncAnimation(fig, animate, init_func=init,
#                                frames=360, interval=20, blit=True)
# # Save
# # anim.save('/Volumes/Lab/Users/praful/thresh_surface_comp_p.gif', writer='imagemagick', fps=30)
# plt.show()

In [None]:
# Above-threshold all-positive / all-negative point counts summed over patterns.
(pos_inds_total, neg_inds_total)

In [None]:
# Sorted unique set of every electrode stimulated by any pattern. Each pattern's
# electrode list is flattened and concatenated first, so patterns that use
# different numbers of electrodes do not produce a ragged np.array (which modern
# numpy rejects).
all_elecs_array = np.unique(np.concatenate([np.ravel(elecs) for elecs in all_elecs]))

In [None]:
# Sorted unique stimulating electrodes across all patterns.
all_elecs_array

In [None]:
# Per-pattern accumulators (one entry appended per pattern) for the
# positive- and negative-polarity design matrices, probabilities and trial counts.
X_pos, y_pos, trials_pos = [], [], []
X_neg, y_neg, trials_neg = [], [], []

In [None]:
# Build the per-polarity design matrices for the multi-electrode fit. Each row is
# one above-threshold amplitude combination, laid out with one column per
# electrode in all_elecs_array (zero for electrodes the pattern does not use), so
# rows from different patterns share a common column space.
# The accumulators are (re)initialized here so re-running this cell does not
# append duplicate rows, and does not fail after the stacking cells below have
# replaced trials_pos / trials_neg with arrays.
X_pos, y_pos, trials_pos = [], [], []
X_neg, y_neg, trials_neg = [], [], []

n_elecs = len(all_elecs_array)

for i, p in enumerate(patterns):
    print("Pattern " + str(p))
    filepath = os.path.join(gsort_path,
                            dataset, estim, wnoise, "p" + str(p))

    triplet_elecs = mutils.get_stim_elecs_newlv(electrical_path, p)
    amplitudes = mutils.get_stim_amps_newlv(electrical_path, p)

    num_pts = len(amplitudes)

    # One gsort result file per amplitude index k: read the activation
    # probability and the number of trials behind it.
    triplet_probs = np.zeros(num_pts)
    triplet_trials = np.zeros(num_pts, dtype=int)
    for k in range(num_pts):
        fname = "gsort_tri_v2_n" + str(n) + "_p" + str(p) + "_k" + str(k) + ".pkl"
        with open(os.path.join(filepath, fname), "rb") as f:
            prob_dict = pickle.load(f)
        triplet_probs[k] = prob_dict["cosine_prob"][0]
        triplet_trials[k] = prob_dict["num_trials"]

    above_thr = triplet_probs > p_thr
    neg_inds = np.where(np.all(amplitudes < 0, axis=1) & above_thr)[0]
    pos_inds = np.where(np.all(amplitudes > 0, axis=1) & above_thr)[0]

    # np.searchsorted returns an insertion point even when a value is absent, so
    # confirm every stimulating electrode really is in all_elecs_array before
    # using the indices as column positions.
    elec_inds = np.searchsorted(all_elecs_array, triplet_elecs)
    assert np.all(elec_inds < n_elecs) and np.array_equal(
        all_elecs_array[elec_inds], np.asarray(triplet_elecs)
    ), f"Pattern {p}: electrodes {triplet_elecs} are not all in all_elecs_array"

    # Same construction for both polarities; only the selected rows differ.
    for inds, X_acc, y_acc, trials_acc in (
        (pos_inds, X_pos, y_pos, trials_pos),
        (neg_inds, X_neg, y_neg, trials_neg),
    ):
        X_padded = np.zeros((len(inds), n_elecs))
        X_padded[:, elec_inds] = amplitudes[inds]
        X_acc.append(X_padded)
        y_acc.append(triplet_probs[inds])
        trials_acc.append(triplet_trials[inds])

    print(f"  {len(pos_inds)} positive / {len(neg_inds)} negative points above p_thr={p_thr}")

In [None]:
# Stack the per-pattern negative-polarity pieces into single arrays.
all_amps_neg, all_probs_neg, trials_neg = (
    np.vstack(X_neg),
    np.hstack(y_neg),
    np.hstack(trials_neg),
)

In [None]:
# Stack the per-pattern positive-polarity pieces into single arrays.
all_amps_pos, all_probs_pos, trials_pos = (
    np.vstack(X_pos),
    np.hstack(y_pos),
    np.hstack(trials_pos),
)

In [None]:
# Positive-polarity design matrix: one row per point, one column per electrode in all_elecs_array.
all_amps_pos

In [None]:
# Sanity check: design matrix rows, probabilities and trial counts should line up.
tuple(arr.shape for arr in (all_amps_pos, all_probs_pos, trials_pos))

In [None]:
# Hold out test_size of the points for evaluation, separately for each polarity.
# A fixed random_state makes the split (and therefore the fitted weights and the
# figures below) reproducible on Restart & Run All. The held-out trial counts get
# explicit names instead of `_`, which IPython uses for the last cell output.
test_size = 0.2
split_seed = 0
(train_amps_pos, test_amps_pos,
 train_probs_pos, test_probs_pos,
 train_trials_pos, test_trials_pos) = train_test_split(
    all_amps_pos, all_probs_pos, trials_pos, test_size=test_size, random_state=split_seed)
(train_amps_neg, test_amps_neg,
 train_probs_neg, test_probs_neg,
 train_trials_neg, test_trials_neg) = train_test_split(
    all_amps_neg, all_probs_neg, trials_neg, test_size=test_size, random_state=split_seed)

In [None]:
# Negative polarity: build the training design matrix for the logistic fit.
# `degree` is the polynomial degree; `interaction` selects whether cross-electrode
# interaction terms are used (the plotting cell below mirrors this choice).
# convertToBinaryClassifier presumably expands each (probability, trial count)
# point into per-trial binary outcomes -- see src.fitting.
degree, interaction = 4, True
multi_X, multi_y = fitting.convertToBinaryClassifier(
    train_probs_neg,
    train_trials_neg,
    train_amps_neg,
    degree,
    interaction,
)

In [None]:
# Shapes of the negative-polarity design matrix and targets.
tuple(arr.shape for arr in (multi_X, multi_y))

In [None]:
# Fit the negative-polarity weights by minimizing fitting.negLL, starting from
# weight 1 on the first feature (the polynomial bias term) and 0 elsewhere.
n_features = multi_X.shape[-1]
mu = np.r_[1, np.zeros(n_features - 1)]
multi_results = minimize(
    fitting.negLL,
    x0=mu,
    args=(multi_X, multi_y, False, "none"),
)
multi_weights = multi_results.x

In [None]:
# Fitted weights of the negative-polarity model.
multi_weights

In [None]:
# Visual goodness of fit for the negative-polarity model: project each point onto
# the fitted weights (w^T x) and compare the observed activation probability with
# the logistic curve.
# NOTE: degree, interaction and multi_weights are all overwritten by the
# positive-polarity cells further down, so run this cell before those.
if interaction:
    # PolynomialFeatures only learns the input dimensionality, so fit it on the
    # training set and reuse it for the test set.
    poly = PolynomialFeatures(degree)
    train_X = poly.fit_transform(train_amps_neg)
    test_X = poly.transform(test_amps_neg)
else:
    train_X = fitting.noInteractionPoly(train_amps_neg, degree)
    test_X = fitting.noInteractionPoly(test_amps_neg, degree)

test_y = test_probs_neg
train_y = train_probs_neg

fig, ax = plt.subplots(num=2, clear=True)
ax.scatter(train_X @ multi_weights, train_y, label='Train')
ax.scatter(test_X @ multi_weights, test_y, label='Test')
ax.set_xlabel(r'$w^Tx$', fontsize=16)
ax.set_ylabel('Activation Probability', fontsize=16)
ax.set_title('Negative-polarity fit')
sigmoid_x = np.linspace(-4, 4, 100)
ax.plot(sigmoid_x, 1 / (1 + np.exp(-sigmoid_x)))
ax.set_xlim(-5, 5)
ax.legend(fontsize=14)
plt.show()

In [None]:
# Positive polarity: build the training design matrix for the logistic fit.
# `degree` is the polynomial degree; `interaction` selects whether cross-electrode
# interaction terms are used (the plotting cell below mirrors this choice).
# convertToBinaryClassifier presumably expands each (probability, trial count)
# point into per-trial binary outcomes -- see src.fitting.
degree, interaction = 4, True
multi_X, multi_y = fitting.convertToBinaryClassifier(
    train_probs_pos,
    train_trials_pos,
    train_amps_pos,
    degree,
    interaction,
)

In [None]:
# Shapes of the positive-polarity design matrix and targets.
tuple(arr.shape for arr in (multi_X, multi_y))

In [None]:
# Fit the positive-polarity weights by minimizing fitting.negLL, starting from
# weight 1 on the first feature (the polynomial bias term) and 0 elsewhere.
n_features = multi_X.shape[-1]
mu = np.r_[1, np.zeros(n_features - 1)]
multi_results = minimize(
    fitting.negLL,
    x0=mu,
    args=(multi_X, multi_y, False, "none"),
)
multi_weights = multi_results.x

In [None]:
# Fitted weights of the positive-polarity model.
multi_weights

In [None]:
# Visual goodness of fit for the positive-polarity model: project each point onto
# the fitted weights (w^T x) and compare the observed activation probability with
# the logistic curve. Uses degree, interaction and multi_weights from the
# positive-polarity cells directly above.
if interaction:
    # PolynomialFeatures only learns the input dimensionality, so fit it on the
    # training set and reuse it for the test set.
    poly = PolynomialFeatures(degree)
    train_X = poly.fit_transform(train_amps_pos)
    test_X = poly.transform(test_amps_pos)
else:
    train_X = fitting.noInteractionPoly(train_amps_pos, degree)
    test_X = fitting.noInteractionPoly(test_amps_pos, degree)

test_y = test_probs_pos
train_y = train_probs_pos

fig, ax = plt.subplots(num=3, clear=True)
ax.scatter(train_X @ multi_weights, train_y, label='Train')
ax.scatter(test_X @ multi_weights, test_y, label='Test')
ax.set_xlabel(r'$w^Tx$', fontsize=16)
ax.set_ylabel('Activation Probability', fontsize=16)
ax.set_title('Positive-polarity fit')
sigmoid_x = np.linspace(-4, 4, 100)
ax.plot(sigmoid_x, 1 / (1 + np.exp(-sigmoid_x)))
ax.legend(fontsize=14)
plt.show()