In [1]:
import numpy as np
import pickle
import matplotlib.pyplot as plt
from scipy.io import loadmat
import visionloader as vl
from sklearn.preprocessing import PolynomialFeatures, StandardScaler
from sklearn.cluster import KMeans, SpectralClustering
import os
import src.fitting as fitting
import src.multielec_utils as mutils
from mpl_toolkits.mplot3d import Axes3D
import cvxpy as cp
from scipy.optimize import minimize
import statsmodels.api as sm

%load_ext autoreload
%autoreload 2
%matplotlib notebook

# Load triplet-stimulation electrical responses (elecResps) and stimulus amplitudes

In [2]:
# Root of the analysis volume and of the triplet gsort outputs. Both can be
# overridden through environment variables so the notebook can run on another
# machine; the defaults are the original lab paths.
ANALYSIS_BASE = os.environ.get("ANALYSIS_BASE", "/Volumes/Analysis")
gsort_path = os.environ.get(
    "GSORT_PATH",
    "/Volumes/Scratch/Users/praful/triplet_gsort_v2_30um_raphe-affinity_cosine",
)

In [3]:
# Electrical-stimulation run: dataset id, run path within it, and full path.
dataset, estim = "2020-10-18-0", "data003/data003-all"
electrical_path = os.path.join(ANALYSIS_BASE, dataset, estim)

In [4]:
# White-noise (visual) run: used for cell types and the electrode map.
wnoise = "kilosort_data000/data000"
vis_datapath = os.path.join(ANALYSIS_BASE, dataset, wnoise)
vis_datarun = os.path.basename(wnoise)  # final path component, e.g. "data000"

vcd = vl.load_vision_data(vis_datapath, vis_datarun,
                          include_neurons=True, include_ei=True,
                          include_params=True, include_noise=True)

coords = vcd.get_electrode_map()

In [5]:
# Hyperparameters shared by every getWeights call.
degree, l2_reg = 1, 0.0        # polynomial degree / ridge penalty of the OLS comparison fit
clusters_t = clusters_nt = 15  # spectral clusters for target / non-target cells
negll_thr = 0.94               # a cluster's MLE fit is kept only if its negative log-likelihood is below this

In [6]:
# Stimulation pattern and the cells analysed on it.
p = 2
targets, nontargets = np.array([12]), np.array([6, 11])

# getWeights keeps only points with p_thr < probability < p_upper.
# NOTE(review): 2/19 presumably corresponds to 2 spikes out of 19 trials — confirm.
p_thr, p_upper = 2 / 19, 1
random_state = 0  # seed for SpectralClustering

In [None]:
def getWeights(p, n, degree, l2_reg, n_clusters, nll_thr, points_per_cluster=50, selec_vec=None, show_clusters=False):
    """Fit per-cluster logistic activation models for one cell under triplet stimulation.

    For every amplitude index ``k`` of pattern ``p`` this loads the gsort pickle
    for cell ``n`` (``cosine_prob`` and ``num_trials``). It then keeps points with
    ``p_thr < prob < p_upper`` and groups them with spectral clustering in
    amplitude space. Each cluster gets a maximum-likelihood fit of
    ``fitting.negLL``. A ridge OLS fit on the logit of the probabilities is also
    computed and printed for comparison.

    Reads the module-level globals ``gsort_path``, ``dataset``, ``estim``,
    ``wnoise``, ``electrical_path``, ``p_thr``, ``p_upper`` and ``random_state``.

    Parameters
    ----------
    p : int
        Stimulation pattern number.
    n : int
        Cell id whose responses are loaded.
    degree : int
        Polynomial degree of the OLS comparison fit.
    l2_reg : float
        Ridge penalty of the OLS comparison fit.
    n_clusters : int
        Number of spectral clusters.
    nll_thr : float
        A cluster's MLE weights are kept only if its negative log-likelihood
        is below this value.
    points_per_cluster : int
        A cluster's weights are kept only if it has more than this many points.
    selec_vec : array-like of length 3, optional
        Amplitude vector highlighted with a red star in the overview plot.
        ``None`` or an all-zero vector draws nothing.
    show_clusters : bool
        If True, also plot every cluster and its fitted-vs-observed probabilities.

    Returns
    -------
    np.ndarray of shape (n_kept, 4)
        MLE weight vectors (intercept first) of the accepted clusters. The
        result is 2-D even when nothing is kept, so it always stacks with
        ``np.vstack``.
    """
    poly = PolynomialFeatures(degree)
    x0 = np.array([-1, 1, 1, 1])  # MLE starting point: intercept, then one weight per electrode
    empty = np.empty((0, x0.size))

    filepath = os.path.join(gsort_path, dataset, estim, wnoise, f"p{p}")
    amplitudes = mutils.get_stim_amps_newlv(electrical_path, p)
    num_pts = len(amplitudes)

    triplet_probs = np.zeros(num_pts)
    trials = np.zeros(num_pts, dtype=int)
    for k in range(num_pts):
        fname = os.path.join(filepath, f"gsort_tri_v2_n{n}_p{p}_k{k}.pkl")
        # NOTE: pickle.load can execute arbitrary code; these files are
        # assumed to be trusted outputs of our own gsort pipeline.
        with open(fname, "rb") as f:
            prob_dict = pickle.load(f)
        triplet_probs[k] = prob_dict["cosine_prob"][0]
        trials[k] = prob_dict["num_trials"]

    # The strict upper bound keeps log(y / (1 - y)) below finite when p_upper == 1.
    good_inds = np.where((triplet_probs > p_thr) & (triplet_probs < p_upper))[0]
    X = amplitudes[good_inds]
    y = triplet_probs[good_inds]
    good_trials = trials[good_inds]

    highlight = None
    if selec_vec is not None and np.any(np.asarray(selec_vec) != 0):
        highlight = selec_vec
    _plot_amplitude_scatter(X, y, highlight=highlight)

    # No cluster can hold more than len(y) points, so nothing could be kept.
    # SpectralClustering also needs at least n_clusters samples.
    if len(y) <= points_per_cluster or len(y) < n_clusters:
        print(f"Cell {n}: only {len(y)} usable points; skipping clustering.")
        return empty

    clustering = SpectralClustering(n_clusters=n_clusters,
                                    assign_labels='discretize',
                                    affinity='nearest_neighbors',
                                    random_state=random_state).fit(X)

    pp_weights = []
    for i in range(n_clusters):
        inds = np.where(clustering.labels_ == i)[0]
        if len(inds) == 0:
            continue
        X_cluster = X[inds]
        y_cluster = y[inds]
        trials_cluster = good_trials[inds]

        if show_clusters:
            _plot_amplitude_scatter(X_cluster, y_cluster, lim=1.8)

        X_bin, y_bin = fitting.convertToBinaryClassifier(y_cluster, trials_cluster, X_cluster)

        results = minimize(fitting.negLL, x0=x0, args=(X_bin, y_bin, False, 'none'))
        print(results.x)

        # Ridge OLS on the logit, kept only as a diagnostic. lstsq on the
        # normal equations gives the same answer as inverting the Gram matrix
        # when it is invertible. Unlike inv, it does not raise on tiny or
        # degenerate clusters (e.g. l2_reg == 0).
        OLS_X = poly.fit_transform(X_cluster)
        OLS_y = np.log(y_cluster / (1 - y_cluster))
        gram = l2_reg * np.eye(OLS_X.shape[-1]) + OLS_X.T @ OLS_X
        OLS_w = np.linalg.lstsq(gram, OLS_X.T @ OLS_y, rcond=None)[0]
        print(OLS_w)

        nll_MLE = fitting.negLL(results.x, X_bin, y_bin, False, 'none')
        nll_OLS = fitting.negLL(OLS_w, X_bin, y_bin, False, 'none')
        print(nll_MLE)
        print(nll_OLS)

        if show_clusters:
            # has_constant='add': with the default 'skip', statsmodels does not
            # add the intercept column if any column is already constant. That
            # happens when an electrode's amplitude is fixed within a cluster.
            mle_design = sm.add_constant(X_cluster, has_constant='add')
            _plot_fit_diagnostics(y_cluster,
                                  fitting.fsigmoid(OLS_X, OLS_w),
                                  fitting.fsigmoid(mle_design, results.x))

        # Accept well-fit, large-enough clusters with a negative intercept.
        if nll_MLE < nll_thr and results.x[0] < 0 and len(inds) > points_per_cluster:
            pp_weights.append(results.x)

    if not pp_weights:
        return empty
    return np.vstack(pp_weights)


def _plot_amplitude_scatter(points, colors, highlight=None, lim=None):
    """3-D scatter of amplitude vectors (I_1, I_2, I_3), coloured by `colors`.

    Uses ``fig.add_subplot(projection='3d')``. Bare ``Axes3D(fig)`` no longer
    adds itself to the figure in recent matplotlib. ``plt.xlabel``-style calls
    would then target a stray 2-D axes.
    """
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.set_xlabel(r'$I_1$')
    ax.set_ylabel(r'$I_2$')
    ax.set_zlabel(r'$I_3$')
    if lim is not None:
        ax.set_xlim(-lim, lim)
        ax.set_ylim(-lim, lim)
        ax.set_zlim(-lim, lim)
    scat = ax.scatter(points[:, 0], points[:, 1], points[:, 2],
                      marker='o', s=20, c=colors, alpha=0.8)
    if highlight is not None:
        ax.scatter(highlight[0], highlight[1], highlight[2], c='tab:red', marker='*', s=100)
    fig.colorbar(scat, ax=ax)
    plt.show()


def _plot_fit_diagnostics(y_obs, fitted_ols, fitted_mle):
    """Scatter fitted vs observed probabilities of one cluster for both fits."""
    fig, ax = plt.subplots()
    ax.scatter(fitted_ols, y_obs, label='OLS')
    ax.scatter(fitted_mle, y_obs, label='MLE')
    ax.plot([0, 1], [0, 1], c='k', linestyle='--')  # identity line
    ax.set_xlabel('fitted probability')
    ax.set_ylabel('observed probability')
    ax.legend()
    plt.show()

In [None]:
# Fit every non-target cell. Pool the accepted per-cluster weight vectors
# into one (n_models, n_params) array.
nt_weights = []
for nt in nontargets:
    print(vcd.get_cell_type_for_cell(nt))
    nt_weights.append(getWeights(p, nt, degree, l2_reg, clusters_nt, negll_thr, show_clusters=True))

# A cell with no accepted cluster yields an empty array, which np.vstack
# cannot stack next to (k, n_params) arrays. Drop such cells first.
nt_weights = [w for w in nt_weights if np.size(w) > 0]
if not nt_weights:
    raise ValueError("No non-target cell produced an accepted weight vector.")
nt_weights = np.vstack(nt_weights)

In [None]:
# Accepted non-target models, one row per cluster (presumably intercept first,
# then one weight per electrode — the MLE vector returned by getWeights).
nt_weights

In [None]:
# Fit every target cell. Pool the accepted per-cluster weight vectors
# into one (n_models, n_params) array.
t_weights = []
for t in targets:
    print(vcd.get_cell_type_for_cell(t))
    t_weights.append(getWeights(p, t, degree, l2_reg, clusters_t, negll_thr, show_clusters=True))

# A cell with no accepted cluster yields an empty array, which np.vstack
# cannot stack next to (k, n_params) arrays. Drop such cells first.
t_weights = [w for w in t_weights if np.size(w) > 0]
if not t_weights:
    raise ValueError("No target cell produced an accepted weight vector.")
t_weights = np.vstack(t_weights)

In [None]:
# Accepted target models, one row per cluster; each row is solved separately
# in the selectivity optimisation below.
t_weights

In [None]:
def compute_pt(T):
    """Return 1 / (exp(T) + 1), the logistic sigmoid of -T.

    Evaluated as exp(-logaddexp(0, T)) so that large |T| neither overflows
    nor produces NaN. Works element-wise on scalars and arrays.
    """
    return np.exp(-np.logaddexp(0, T))


def compute_pn(T):
    """Return exp(T) / (exp(T) + 1), the logistic sigmoid of T.

    The naive formula gives inf / inf = NaN once exp(T) overflows
    (T above roughly 709 in float64). exp(T - logaddexp(0, T)) stays finite.
    Works element-wise on scalars and arrays.
    """
    return np.exp(T - np.logaddexp(0, T))

In [None]:
# Per-electrode amplitude bound for the search. This is a magic constant from
# the original analysis; presumably the largest amplitude in the stimulus set
# (same units as mutils.get_stim_amps_newlv) — confirm.
I_max = 1.78125

if nt_weights.ndim != 2 or len(nt_weights) == 0:
    # Without non-target constraints T_n is unbounded below and the problem
    # has no attained optimum.
    raise ValueError("nt_weights must be a non-empty 2-D array of non-target models")

# For each target model, find amplitudes x that minimise exp(T_n) + exp(T_t).
# T_t bounds the negated target linear predictor from above. T_n bounds the
# largest non-target linear predictor from above. Failed solves are recorded
# as selectivity -inf with a NaN amplitude vector. They therefore sort first
# and are never selected as the best stimulus.
selec_vals = np.zeros(len(t_weights))
x_vals = np.zeros((len(t_weights), 3))
for i in range(len(t_weights)):
    print(i)
    T_t = cp.Variable(1)
    T_n = cp.Variable(1)
    x = cp.Variable(3)  # one amplitude per electrode of the triplet

    constraints = [
        -t_weights[i, 0] - t_weights[i, 1:] @ x <= T_t,
        nt_weights[:, 0] + nt_weights[:, 1:] @ x <= T_n,
        x <= I_max,
        -I_max <= x,
    ]

    prob = cp.Problem(cp.Minimize(cp.exp(T_n) + cp.exp(T_t)), constraints)
    opt_val = prob.solve()
    print(f"Optimal value {opt_val}")

    if prob.status not in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE):
        # Variable values are None after an infeasible/unbounded/failed solve.
        print(f"Solver did not reach an optimum for target model {i}: {prob.status}")
        selec_vals[i] = -np.inf
        x_vals[i] = np.nan
        continue

    p_t = compute_pt(T_t.value)[0]
    p_n = compute_pn(T_n.value)[0]
    print(f"optimal p_t {p_t}")
    print(f"optimal p_n {p_n}")
    print(f"optimal T_t {T_t.value[0]}")
    print(f"optimal T_n {T_n.value[0]}")
    print(x.value)

    selec_vals[i] = p_t * (1 - p_n)
    x_vals[i] = x.value

In [None]:
# Selectivity p_t * (1 - p_n) for each target model, in t_weights row order.
selec_vals

In [None]:
# Selectivity values in ascending order.
np.sort(selec_vals)

In [None]:
# Optimal amplitude vectors, ordered by ascending selectivity (same order as the cell above).
x_vals[np.argsort(selec_vals)]

In [None]:
# Amplitude vector of the most selective target model (last index of the
# ascending argsort); it is marked on the overview plots below.
selec_vec = x_vals[np.argsort(selec_vals)[-1]]

In [None]:
# Re-run the non-target fits only to redraw their overview scatter with the
# chosen stimulus (selec_vec) marked; the returned weights are discarded.
for nt in nontargets:
    getWeights(p, nt, degree, l2_reg,
               n_clusters=clusters_nt, nll_thr=negll_thr,
               selec_vec=selec_vec, show_clusters=False)

In [None]:
# Re-run the target fits only to redraw their overview scatter with the
# chosen stimulus (selec_vec) marked; the returned weights are discarded.
for t in targets:
    getWeights(p, t, degree, l2_reg,
               n_clusters=clusters_t, nll_thr=negll_thr,
               selec_vec=selec_vec, show_clusters=False)