In [None]:
import os
os.environ["OMP_NUM_THREADS"] = "1" # export OMP_NUM_THREADS=1
os.environ["OPENBLAS_NUM_THREADS"] = "1" # export OPENBLAS_NUM_THREADS=1
os.environ["MKL_NUM_THREADS"] = "1" # export MKL_NUM_THREADS=1
os.environ["VECLIB_MAXIMUM_THREADS"] = "1" # export VECLIB_MAXIMUM_THREADS=1
os.environ["NUMEXPR_NUM_THREADS"] = "1" # export NUMEXPR_NUM_THREADS=1

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import src.fitting as fitting
import src.multielec_utils as mutils
import statsmodels.api as sm
import jax
import jax.numpy as jnp
import multiprocessing as mp
from itertools import product

%matplotlib inline
%load_ext autoreload
%autoreload 2

# jax.config.update('jax_platform_name', 'cpu')

In [None]:
def activation_probs(x, w):
    """Probability that at least one site is activated at each current level.

    Parameters
    ----------
    x : array, c x d
        Current levels (one row per stimulus).
    w : array, n x d
        Site weights (one row per site).

    Returns
    -------
    array, c
        ``1 - prod_over_sites(1 - sigmoid(w_site . x_level))``.
    """
    # Linear drive of every site at every current level: n x c
    site_drive = jnp.dot(w, jnp.transpose(x))
    # Probability that a given site stays silent: n x c
    site_silent = 1 - jax.nn.sigmoid(site_drive)
    # The response is absent only if every site is silent: c
    all_silent = jnp.prod(site_silent, 0)

    return 1 - all_silent

In [None]:
# Path definitions
# NOTE(review): these are absolute, machine-specific mount points; the cell only
# runs where these volumes are mounted.
ANALYSIS_BASE = "/Volumes/Analysis"
MATFILE_BASE = "/Volumes/Scratch/Users/praful/triplet_gsort_matfiles_20220420"
gsort_path = None
gsort_path_1elec = "/Volumes/Scratch/Users/praful/single_gsort_v2_30um_periphery-affinity_cosine"

# Dataset identifiers, joined below into the electrical-stimulation and visual data paths
dataset = "2020-10-06-7"
estim = "data003/data003-all"
estim_1elec = "data001"
wnoise = "kilosort_data000/data000"
electrical_path = os.path.join(ANALYSIS_BASE, dataset, estim)
vis_datapath = os.path.join(ANALYSIS_BASE, dataset, wnoise)

# Second argument of get_stim_amps_newlv; presumably the stimulation pattern index -- TODO confirm
p = 1

# Stimulation amplitudes returned by the project helper; the plotting cell uses
# columns 0..2 as I_1..I_3, so this is presumably (num_levels, 3) -- TODO confirm
X_expt_orig = mutils.get_stim_amps_newlv(electrical_path, p)
# Ground-truth site weights used to simulate responses: one row per site; the
# first column multiplies the constant column that is prepended to X below.
w_true = jnp.array([[-5.98518703, -5.73843676, -1.36037982, -0.05980741],
       [-5.98518703, -2.28047189, -2.93318102, -4.31001908],
       [-5.98518703,  5.39557745,  1.95279497,  1.8031558 ],
       [-5.98518703, -0.25671708,  2.89097144,  3.80746902]])
# Alternative two-site ground truth, kept for reference:
# w_true = jnp.array([[-5.68501006,  2.44477339,  3.23685565,  2.75812431],
#        [-5.66911426, -2.57285102, -3.49945348, -2.8179713 ]])

# Add a constant column (has_constant='add' forces it even if one already exists)
# so the first weight of every site acts as a bias term.
X = jnp.array(sm.add_constant(X_expt_orig, has_constant='add'))
p_true = activation_probs(X, w_true) # ground-truth activation probability at each current level

In [None]:
# 3D scatter of every stimulation amplitude triplet, colored by the
# ground-truth activation probability (color scale fixed to [0, 1]).
fig = plt.figure()
fig.clear()
ax = Axes3D(fig, auto_add_to_figure=False)
fig.add_axes(ax)

# Axis labels, set directly on the 3D axes
ax.set_xlabel(r'$I_1$ ($\mu$A)', fontsize=16)
ax.set_ylabel(r'$I_2$ ($\mu$A)', fontsize=16)
ax.set_zlabel(r'$I_3$ ($\mu$A)', fontsize=16)

# Same symmetric range on all three current axes
ax.set_xlim(-1.8, 1.8)
ax.set_ylim(-1.8, 1.8)
ax.set_zlim(-1.8, 1.8)

scat = ax.scatter(
    X_expt_orig[:, 0],
    X_expt_orig[:, 1],
    X_expt_orig[:, 2],
    marker='o', c=p_true, s=20, alpha=0.8, vmin=0, vmax=1,
)

In [None]:
def sample_spikes(p_true, t):
    """Simulate an empirical response probability for each current level.

    Parameters
    ----------
    p_true : array-like, c
        True response probability at each current level.
    t : array-like, c
        Number of trials at each current level (cast to int).

    Returns
    -------
    list of length c
        Fraction of simulated trials with a response; 0.5 is used as a
        placeholder wherever no trials were run.
    """
    probs = np.array(p_true)
    n_trials = np.array(t).astype(int)

    def _observed_rate(idx):
        # No trials at this level: fall back to the uninformative value
        if n_trials[idx] == 0:
            return 0.5
        # Draw n_trials[idx] Bernoulli outcomes and average them
        outcomes = np.random.choice(np.array([0, 1]),
                                    p=np.array([1-probs[idx], probs[idx]]),
                                    size=n_trials[idx])
        return np.mean(outcomes)

    return [_observed_rate(idx) for idx in range(len(probs))]

In [None]:
def neg_log_likelihood(w, x, y, trials, l2_reg=0):
    """Binomial negative log-likelihood of the multi-site model plus an L2 penalty.

    Parameters
    ----------
    w : array, n x d
        Site weights.
    x : array, c x d
        Current levels.
    y : array, c
        Empirical response probability at each current level.
    trials : array, c
        Number of trials at each current level (cast to int).
    l2_reg : float, optional
        L2 regularization strength; the penalty is ``l2_reg / 2 * sum(w**2)``.

    Returns
    -------
    scalar
        Negative log-likelihood plus penalty.
    """
    p_model = activation_probs(x, w) # dimensions: c
    # Keep probabilities away from 0 and 1 so both logs stay finite. The bounds
    # are passed positionally because the a_min/a_max keywords are deprecated
    # in newer JAX releases.
    p_model = jnp.clip(p_model, 1e-5, 1 - 1e-5)

    trials = trials.astype(int)

    nll = -jnp.sum(trials * y * jnp.log(p_model) + trials * (1 - y) * jnp.log(1 - p_model))

    # Sum of squares equals the squared Frobenius norm but, unlike
    # jnp.linalg.norm(w)**2, has a finite (zero) gradient at w == 0.
    penalty = l2_reg/2 * jnp.sum(jnp.square(w))

    return nll + penalty

In [None]:
def optimize_w(x, w, y, trials, l2_reg=0, zero_prob=0.01, step_size=0.0001, n_steps=100, wtol=1e-4):
    """Fit site weights by projected gradient descent on ``neg_log_likelihood``.

    After every gradient step the first column of ``w`` (the weight on
    ``x[:, 0]``, the constant column when ``x`` is built with
    ``sm.add_constant`` as in this notebook) is capped so that the combined
    response probability of all sites at zero current is at most ``zero_prob``.

    Parameters
    ----------
    x : array, c x d
        Current levels.
    w : array, n x d
        Initial site weights.
    y : array, c
        Empirical response probability at each current level.
    trials : array, c
        Number of trials at each current level.
    l2_reg : float, optional
        L2 regularization strength forwarded to ``neg_log_likelihood``.
    zero_prob : float, optional
        Upper bound on the model response probability when only the constant
        term is active.
    step_size : float, optional
        Gradient-descent step size.
    n_steps : int, optional
        Maximum number of iterations.
    wtol : float, optional
        Stop when ``||w - w_prev|| / w.size`` falls to this value or below.

    Returns
    -------
    losses : list
        Regularized loss after each iteration, evaluated at the projected
        weights, so ``losses[-1]`` is the loss of the returned ``w``.
    w : array, n x d
        Fitted weights.
    """
    n_sites = len(w)
    # Per-site probability z such that 1 - (1 - z)**n_sites == zero_prob
    z = 1 - (1 - zero_prob)**(1/n_sites)
    # Largest admissible bias, sigmoid(max_bias) == z; loop-invariant, so computed once
    max_bias = np.log(z/(1-z))

    @jax.jit
    def update(x, w, y, trials, l2_reg):
        grads = jax.grad(neg_log_likelihood)(w, x, y, trials, l2_reg=l2_reg)
        return grads

    losses = []
    prev_w = w
    for _ in range(n_steps):
        grad = update(x, w, y, trials, l2_reg)
        w = w - step_size * grad
        # Project onto the bias constraint *before* recording the loss so the
        # loss trace describes the weights that are actually kept.
        w = w.at[:, 0].set(jnp.minimum(w[:, 0], max_bias))
        losses += [neg_log_likelihood(w, x, y, trials, l2_reg=l2_reg)]

        # Converged once the per-parameter change in weights is small enough
        if jnp.linalg.norm(w - prev_w) / len(w.ravel()) <= wtol:
            break
        prev_w = w

    return losses, w

In [None]:
def fisher_info(x, w, y, t):
    """Fisher-information design loss for a trial allocation ``t``.

    Computes ``trace(J @ inv(I_w) @ J.T)``, where ``J`` is the Jacobian of the
    model probabilities with respect to the flattened weights and
    ``I_w = J.T @ diag(t / (p * (1 - p))) @ J / c``.

    Parameters
    ----------
    x : array, c x d
        Current levels.
    w : array, n x d
        Site weights.
    y : array, c
        Empirical probability for each current level (unused; kept so the
        signature matches the other model functions).
    t : array, c
        Number of trials for each current level.

    Returns
    -------
    scalar
        The trace loss (smaller is better).
    """
    # Bounds passed positionally: the a_min/a_max keywords are deprecated in newer JAX.
    p_model = jnp.clip(activation_probs(x, w), 1e-5, 1 - 1e-5) # c
    # Binomial information about each probability, kept as a vector instead of
    # a dense c x c diagonal matrix.
    info_p = t / (p_model * (1 - p_model)) # c
    J = jax.jacfwd(activation_probs, argnums=1)(x, w).reshape((len(x), w.shape[0]*w.shape[1])) # c x (n*d)
    # J.T * info_p scales column j of J.T by info_p[j], i.e. J.T @ diag(info_p)
    I_w = jnp.dot(J.T * info_p, J) / len(x) # (n*d) x (n*d)

    # trace(J @ inv(I_w) @ J.T) == sum(J * (inv(I_w) @ J.T).T); solving the
    # linear system avoids the explicit inverse and the c x c product.
    loss = jnp.sum(J * jnp.linalg.solve(I_w, J.T).T)
    # sign, logdet = jnp.linalg.slogdet(I_w)
    # loss = -sign * logdet
    return loss

In [None]:
def euclidean_proj_simplex(v, s=1):
    r""" Compute the Euclidean projection on a positive simplex
    Solves the optimisation problem (using the algorithm from [1]):
        min_w 0.5 * || w - v ||_2^2 , s.t. \sum_i w_i = s, w_i >= 0 
    Parameters
    ----------
    v: (n,) numpy array,
       n-dimensional vector to project
    s: int, optional, default: 1,
       radius of the simplex
    Returns
    -------
    w: (n,) numpy array,
       Euclidean projection of v on the simplex
    Raises
    ------
    AssertionError
        If s is not strictly positive.
    ValueError
        If v is not 1-D.
    Notes
    -----
    The complexity of this algorithm is in O(n log(n)) as it involves sorting v.
    Better alternatives exist for high-dimensional sparse vectors (cf. [1])
    However, this implementation still easily scales to millions of dimensions.
    References
    ----------
    [1] Efficient Projections onto the l1-Ball for Learning in High Dimensions
        John Duchi, Shai Shalev-Shwartz, Yoram Singer, and Tushar Chandra.
        International Conference on Machine Learning (ICML 2008)
        http://www.cs.berkeley.edu/~jduchi/projects/DuchiSiShCh08.pdf
    """
    assert s > 0, "Radius s must be strictly positive (%d <= 0)" % s
    n, = v.shape  # will raise ValueError if v is not 1-D
    # check if we are already on the simplex
    # (np.all replaces np.alltrue, which was removed in NumPy 2.0)
    if v.sum() == s and np.all(v >= 0):
        # best projection: itself!
        return v
    # get the array of cumulative sums of a sorted (decreasing) copy of v
    u = np.sort(v)[::-1]
    cssv = np.cumsum(u)
    # get the number of > 0 components of the optimal solution
    rho = np.nonzero(u * np.arange(1, n+1) > (cssv - s))[0][-1]
    # compute the Lagrange multiplier associated to the simplex constraint
    theta = (cssv[rho] - s) / (rho + 1.0)
    # compute the projection by thresholding v using theta
    w = (v - theta).clip(min=0)
    return w

In [None]:
def optimize_fisher(x, w, y, t_prev, t, reg=0, step_size=0.001, n_steps=100, step_cnt_decrement=5, reltol=-np.inf, T_budget=5000):
    """Gradient descent on a new-trial allocation ``t`` that minimizes the Fisher loss.

    The objective is
    ``fisher_info(x, w, y, t_prev + |t|) + reg * | sum(|t|) - T_budget |``,
    so the allocation is pushed toward spending ``T_budget`` trials in total.

    Parameters
    ----------
    x : array, c x d
        Current levels.
    w : array, n x d
        Current weight estimate.
    y : array, c
        Empirical probabilities (forwarded to ``fisher_info``).
    t_prev : array, c
        Trials already collected at each current level.
    t : array, c
        Initial guess for the new trials; its absolute value is what counts.
    reg : float, optional
        Weight of the budget penalty.
    step_size : float, optional
        Initial step size; multiplied by 0.95 whenever
        ``step % step_cnt_decrement == 0`` (including step 0).
    n_steps : int, optional
        Number of gradient steps.
    step_cnt_decrement : int, optional
        Period, in steps, of the step-size decay.
    reltol : float, optional
        Accepted for interface compatibility; no convergence test is applied,
        so all ``n_steps`` iterations always run.
    T_budget : float, optional
        Target total number of new trials.

    Returns
    -------
    losses : numpy array, n_steps x 3
        Per step: Fisher loss, total new trials ``sum(|t|)``, regularized loss,
        all evaluated after the update.
    t : array, c
        Final (signed, continuous) allocation.
    """

    def objective_terms(t, x, w, y, t_prev):
        # Single definition of the objective, shared by the gradient and the trace
        fisher = fisher_info(x, w, y, t_prev + jnp.absolute(t))
        total_trials = jnp.sum(jnp.absolute(t))
        regularized = fisher + reg * (jnp.absolute(total_trials - T_budget))
        return fisher, total_trials, regularized

    @jax.jit
    def update(x, w, y, t_prev, t):
        # reg and T_budget are closed over, so they are constants of the compiled gradient
        fisher_lambda = lambda t, x, w, y, t_prev: objective_terms(t, x, w, y, t_prev)[2]
        grads = jax.grad(fisher_lambda)(t, x, w, y, t_prev)

        return grads

    losses = []
    for step in range(n_steps):
        grad = update(x, w, y, t_prev, t)
        t = t - step_size * grad

        # Evaluate the Fisher term once per step and reuse it for the regularized loss
        losses += [list(objective_terms(t, x, w, y, t_prev))]

        if step % step_cnt_decrement == 0:
            step_size = step_size * 0.95

    return np.array(losses), t

In [None]:
def binary_search_fisher(reg_array, low, high, x, w, y, t_prev, t, T_budget, budget_tol=0.1, step_size=0.001, n_steps=100, step_cnt_decrement=5, reltol=-np.inf):
    """Search ``reg_array`` for a penalty weight whose optimized allocation meets the budget.

    An index is accepted when the rounded total of new trials lies within
    ``(1 +/- budget_tol) * T_budget``. The endpoints ``high`` and ``low`` are
    probed first, then the midpoint; the search recurses on the half that
    should contain the target. Each optimization is warm-started from the
    allocation returned by the previous probe.

    Parameters
    ----------
    reg_array : sequence
        Candidate penalty weights.
    low, high : int
        Inclusive index range of ``reg_array`` to search.
    x, w, y, t_prev, t
        Forwarded to ``optimize_fisher``.
    T_budget : float
        Target total number of new trials; also forwarded to
        ``optimize_fisher`` so the penalty targets the same budget that the
        acceptance test checks.
    budget_tol : float, optional
        Relative tolerance on the budget.
    step_size, n_steps, step_cnt_decrement, reltol
        Forwarded to ``optimize_fisher``.

    Returns
    -------
    (losses, t, index) on success, or ``-1`` when the interval is empty or the
    endpoint budgets do not bracket ``T_budget``. Callers must check for
    ``-1`` before unpacking.
    """
    # Empty interval: nothing left to probe (also stops negative indices from wrapping around)
    if low > high:
        return -1

    budget_min = (1 - budget_tol) * T_budget
    budget_max = (1 + budget_tol) * T_budget

    def probe(idx, t_start):
        # Optimize with reg_array[idx] and report the rounded trial total
        trace, t_opt = optimize_fisher(x, w, y, t_prev, t_start, reg=reg_array[idx], step_size=step_size, n_steps=n_steps,
                                       step_cnt_decrement=step_cnt_decrement, reltol=reltol, T_budget=T_budget)
        spent = jnp.sum(jnp.round(jnp.absolute(t_opt), 0))
        print(reg_array[idx], spent)
        return trace, t_opt, spent

    losses, t, budget_high = probe(high, t)
    if budget_min <= budget_high <= budget_max:
        return losses, t, high

    losses, t, budget_low = probe(low, t)
    if budget_min <= budget_low <= budget_max:
        return losses, t, low

    # Continue only if the endpoints bracket the target budget
    if not (budget_high > T_budget and budget_low < T_budget):
        return -1

    mid = (high + low) // 2
    losses, t, budget_mid = probe(mid, t)
    if budget_min <= budget_mid <= budget_max:
        return losses, t, mid

    if budget_mid > T_budget:
        # Target lies between low and mid; both ends were already probed
        return binary_search_fisher(reg_array, low+1, mid-1, x, w, y, t_prev, t, T_budget, budget_tol=budget_tol, step_size=step_size, n_steps=n_steps,
                                    step_cnt_decrement=step_cnt_decrement, reltol=reltol)

    # Target lies between mid and high
    return binary_search_fisher(reg_array, mid+1, high-1, x, w, y, t_prev, t, T_budget, budget_tol=budget_tol, step_size=step_size, n_steps=n_steps,
                                step_cnt_decrement=step_cnt_decrement, reltol=reltol)
    

In [None]:
def get_performance_AL(X, w_meas, p_true):
    """Root-mean-square error between model and reference probabilities.

    Parameters
    ----------
    X : array, c x d
        Current levels.
    w_meas : array, n x d
        Estimated site weights.
    p_true : array, c
        Reference activation probability at each current level.

    Returns
    -------
    scalar
        ``sqrt(sum((p_model - p_true)**2) / len(X))``.
    """
    residual = activation_probs(X, w_meas) - p_true
    mean_sq_err = jnp.sum(residual**2) / len(X)

    return jnp.sqrt(mean_sq_err)

In [None]:
%matplotlib inline

# --- Experiment configuration ---
# Trials acquired over the whole active-learning run, split evenly across iterations
total_budget = 10000
num_iters = 5
budget = int(total_budget / num_iters)  # new trials requested per iteration
# Weight of the budget penalty passed to optimize_fisher; the trailing comment
# is an earlier grid of candidate weights.
reg = 10#np.flip(np.logspace(-5, 3, 100000, base=2))
# Candidate numbers of sites, tried in this order during model selection
ms = [2, 3, 4, 5]
# Independent repetitions of the whole experiment
num_restarts = 100
# L2 strength used for the null-model likelihood; the optimize_w calls in the
# loop below pass the literal 0.01, so keep the two in sync if this changes.
l2_reg = 0.01
# Minimum pseudo-R2 gain required to accept a model with more sites
R2_thresh = 0.02

# Initial design: init_size distinct amplitudes, init_trials trials each
init_size = 200
init_trials = 5

# One entry per restart: RMSE trajectories (active learning / random baseline)
# and the cumulative trial counts at which they were measured
performance_stack = []
performance_stack_random = []
num_samples_stack = []

# Main experiment: for every restart, run active learning and a random-sampling
# baseline from the same initial data and record RMSE against p_true after each round.
for restart in range(num_restarts):
    print('Restart', restart + 1)
    # Initialize amplitudes
    # (init_size distinct indices into the rows of X)
    init_inds = np.random.choice(len(X), replace=False, size=init_size)

    # Initialize trials
    # T_prev holds the trials collected so far at each amplitude
    T_prev = jnp.zeros(len(X_expt_orig))
    T_prev = T_prev.at[init_inds].set(init_trials)
    # The random baseline starts from the same trial counts ...
    T_prev_random = jnp.copy(T_prev)

    # ... and from the same simulated responses (0.5 where nothing was sampled)
    p_empirical = jnp.array(sample_spikes(p_true, T_prev))
    p_empirical_random = jnp.copy(p_empirical)

    # Initialize weights

    # One random starting point per candidate number of sites
    w_inits = []
    for m in ms:
        w_init = jnp.array(np.random.normal(size=(m, X.shape[1])))
        w_inits.append(w_init)

    # Shallow copy of the list: both arms share initial values but are updated independently
    w_inits_random = w_inits.copy()

    performances = []
    performances_random = []
    num_samples = []

    # Number of completed acquisition rounds
    cnt = 0

    while True:
        # reg = regs[cnt]
        # Total trials collected so far by the active-learning arm
        num_samples.append(np.sum(np.absolute(np.array(T_prev)).astype(int)))
        # Amplitudes with at least one trial (only used by the commented-out plot below)
        sampled_inds = np.where(np.absolute(np.array(T_prev)).astype(int) > 0)[0]

        # fig = plt.figure()
        # fig.clear()
        # ax = Axes3D(fig, auto_add_to_figure=False)
        # fig.add_axes(ax)
        # plt.xlabel(r'$I_1$ ($\mu$A)', fontsize=16)
        # plt.ylabel(r'$I_2$ ($\mu$A)', fontsize=16)
        # plt.xlim(-1.8, 1.8)
        # plt.ylim(-1.8, 1.8)
        # ax.set_zlim(-1.8, 1.8)
        # ax.set_zlabel(r'$I_3$ ($\mu$A)', fontsize=16)

        # scat = ax.scatter(X_expt_orig[sampled_inds, 0], 
        #             X_expt_orig[sampled_inds, 1],
        #             X_expt_orig[sampled_inds, 2], marker='o', c=p_empirical[sampled_inds], s=20, alpha=0.8, vmin=0, vmax=1)

        # plt.show()

        # Same for the random baseline (only used by the commented-out plot below)
        sampled_inds_random = np.where(np.absolute(np.array(T_prev_random)).astype(int) > 0)[0]

        # fig = plt.figure()
        # fig.clear()
        # ax = Axes3D(fig, auto_add_to_figure=False)
        # fig.add_axes(ax)
        # plt.xlabel(r'$I_1$ ($\mu$A)', fontsize=16)
        # plt.ylabel(r'$I_2$ ($\mu$A)', fontsize=16)
        # plt.xlim(-1.8, 1.8)
        # plt.ylim(-1.8, 1.8)
        # ax.set_zlim(-1.8, 1.8)
        # ax.set_zlabel(r'$I_3$ ($\mu$A)', fontsize=16)

        # scat = ax.scatter(X_expt_orig[sampled_inds_random , 0], 
        #             X_expt_orig[sampled_inds_random , 1],
        #             X_expt_orig[sampled_inds_random , 2], marker='o', c=p_empirical_random[sampled_inds_random], s=20, alpha=0.8, vmin=0, vmax=1)

        # plt.show()

        # --- Active-learning arm: null model for the pseudo-R2 ---
        # Intercept-only model whose bias is the log-odds of the pooled response rate.
        # NOTE(review): a pooled rate of exactly 0 or 1 makes beta_null infinite.
        ybar = jnp.sum(p_empirical * T_prev) / jnp.sum(T_prev)
        beta_null = jnp.log(ybar / (1 - ybar))
        null_weights = jnp.concatenate((jnp.array([beta_null]), jnp.zeros(X_expt_orig.shape[-1])))
        nll_null = neg_log_likelihood(null_weights[None, :], X, p_empirical, T_prev, l2_reg=l2_reg)

        # weights_BIC = []
        # BICs = []
        # Optimize w
        # Forward selection over the number of sites: fit each candidate in ms
        # order and stop as soon as the pseudo-R2 gain is at most R2_thresh.
        for i in range(len(ms)):

            losses, w_final = optimize_w(X, w_inits[i], p_empirical, T_prev, l2_reg=0.01, step_size=0.001, n_steps=3500)
            p_pred = activation_probs(X, w_final)

            # BICs.append(len(w_final.ravel()) * jnp.log(len(X)) + 2 * losses[-1])
            # weights_BIC.append(w_final)

            # fig = plt.figure()
            # fig.clear()
            # ax = Axes3D(fig, auto_add_to_figure=False)
            # fig.add_axes(ax)
            # plt.xlabel(r'$I_1$ ($\mu$A)', fontsize=16)
            # plt.ylabel(r'$I_2$ ($\mu$A)', fontsize=16)
            # plt.xlim(-1.8, 1.8)
            # plt.ylim(-1.8, 1.8)
            # ax.set_zlim(-1.8, 1.8)
            # ax.set_zlabel(r'$I_3$ ($\mu$A)', fontsize=16)

            # scat = ax.scatter(X_expt_orig[:, 0], 
            #             X_expt_orig[:, 1],
            #             X_expt_orig[:, 2], marker='o', c=p_pred, s=20, alpha=0.8, vmin=0, vmax=1)

            # plt.show()

            # plt.figure()
            # plt.plot(losses)
            # plt.axhline(neg_log_likelihood(w_true, X, p_empirical, T_prev), linestyle='--', c='k')
            # plt.show()

            # Warm-start this model size from its fitted weights in the next round
            w_inits[i] = w_final

            # Pseudo-R2 = 1 - (final regularized NLL) / (null regularized NLL)
            if i == 0:
                last_R2 = 1 - losses[-1] / nll_null
                last_opt = w_final

            else:
                new_R2 = 1 - losses[-1] / nll_null
                if new_R2 - last_R2 <= R2_thresh:
                    # Not enough improvement: keep the previous, smaller model
                    break
                else:
                    last_R2 = new_R2
                    last_opt = w_final

        # w_final = weights_BIC[jnp.argmin(jnp.array(BICs))]
        # Selected model for the active-learning arm and its RMSE against the ground truth
        w_final = last_opt
        performance = get_performance_AL(X, w_final, p_true)
        performances.append(performance)

        # print(jnp.array(BICs), w_final)


        # --- Random baseline: same null model and model selection on its own data ---
        # (ybar, beta_null, null_weights, nll_null, last_R2 and last_opt are reused
        # from here on; w_final for the active-learning arm was already captured above)
        ybar = jnp.sum(p_empirical_random * T_prev_random) / jnp.sum(T_prev_random)
        beta_null = jnp.log(ybar / (1 - ybar))
        null_weights = jnp.concatenate((jnp.array([beta_null]), jnp.zeros(X_expt_orig.shape[-1])))
        nll_null = neg_log_likelihood(null_weights[None, :], X, p_empirical_random, T_prev_random, l2_reg=l2_reg)
        # weights_BIC_random = []
        # BICs_random = []
        # Optimize w
        for i in range(len(ms)):

            losses_random, w_final_random = optimize_w(X, w_inits_random[i], p_empirical_random, T_prev_random, l2_reg=0.01, step_size=0.001, n_steps=3500)
            p_pred_random = activation_probs(X, w_final_random)

            # BICs_random.append(len(w_final_random.ravel()) * jnp.log(len(X)) + 2 * losses_random[-1])
            # weights_BIC_random.append(w_final_random)

            # fig = plt.figure()
            # fig.clear()
            # ax = Axes3D(fig, auto_add_to_figure=False)
            # fig.add_axes(ax)
            # plt.xlabel(r'$I_1$ ($\mu$A)', fontsize=16)
            # plt.ylabel(r'$I_2$ ($\mu$A)', fontsize=16)
            # plt.xlim(-1.8, 1.8)
            # plt.ylim(-1.8, 1.8)
            # ax.set_zlim(-1.8, 1.8)
            # ax.set_zlabel(r'$I_3$ ($\mu$A)', fontsize=16)

            # scat = ax.scatter(X_expt_orig[:, 0], 
            #             X_expt_orig[:, 1],
            #             X_expt_orig[:, 2], marker='o', c=p_pred_random, s=20, alpha=0.8, vmin=0, vmax=1)

            # plt.show()

            # plt.figure()
            # plt.plot(losses_random)
            # plt.axhline(neg_log_likelihood(w_true, X, p_empirical_random, T_prev_random), linestyle='--', c='k')
            # plt.show()

            w_inits_random[i] = w_final_random

            if i == 0:
                last_R2 = 1 - losses_random[-1] / nll_null
                last_opt = w_final_random

            else:
                new_R2 = 1 - losses_random[-1] / nll_null
                if new_R2 - last_R2 <= R2_thresh:
                    break
                else:
                    last_R2 = new_R2
                    last_opt = w_final_random

        # w_final_random = weights_BIC_random[jnp.argmin(jnp.array(BICs_random))]
        w_final_random = last_opt
        performance_random = get_performance_AL(X, w_final_random, p_true)
        performances_random.append(performance_random)

        # print(jnp.array(BICs_random), w_final_random)
        print(performance, performance_random)

        # Performance is evaluated num_iters + 1 times (initial data plus one per
        # acquisition round); stop once all rounds have been acquired.
        if cnt >= num_iters:
            break

        # explore = performance
        # explore_batch = 0#int(explore * budget) * 2
        # exploit_batch = budget - explore_batch

        # --- Active-learning acquisition ---
        # Start from one trial at every amplitude and optimize the allocation
        # under the current active-learning model.
        T_new_init = jnp.zeros(len(T_prev)) + 1
        losses, t_final = optimize_fisher(X, w_final, p_empirical, T_prev, T_new_init, reg=reg, step_size=0.01, n_steps=1000, T_budget=budget)
        # losses, t_final, reg_ind = binary_search_fisher(regs, 0, len(regs)-1, X, w_final, p_empirical, T_prev, T_new_init, budget, budget_tol=0.2, step_size=0.01, n_steps=1000)
        # print(regs[reg_ind])

        # fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        # axs[0].plot(losses[:, 0])
        # axs[0].set_ylabel('Fisher Loss (A-optimality)')
        # axs[1].plot(losses[:, 1])
        # axs[1].set_ylabel('Total Trials')
        # axs[2].plot(losses[:, 2])
        # axs[2].set_ylabel('Regularized Loss, reg=' + str(reg))

        # fig.tight_layout() # Or equivalently,  "plt.tight_layout()"
        # plt.show()

        # random_draws_explore = np.random.choice(len(X), size=explore_batch)
        # T_new_explore = jnp.array(np.bincount(random_draws_explore, minlength=len(X)))

        # Convert the continuous, signed allocation into whole trial counts
        T_new = jnp.round(jnp.absolute(t_final), 0)#(t_final + T_new_explore), 0)
        # print(jnp.sum(T_new))
        # plt.figure()
        # plt.plot(T_new)
        # plt.show()

        # If rounding left the allocation under budget, spend the remainder on
        # amplitudes drawn uniformly at random (with replacement).
        # NOTE(review): an allocation over budget is not trimmed.
        if jnp.sum(T_new) < budget:
            random_extra = np.random.choice(len(X), size=int(budget - jnp.sum(T_new)))#, p=np.array(jnp.absolute(t_final)))
            T_new_extra = jnp.array(np.bincount(random_extra, minlength=len(X))).astype(int)
            T_new = T_new + T_new_extra

            # print(jnp.sum(T_new))
            # plt.figure()
            # plt.plot(T_new)
            # plt.show()

        # Simulate the newly requested trials
        p_new = jnp.array(sample_spikes(p_true, T_new))

        # Pool old and new data with a trial-weighted average; amplitudes that
        # still have zero trials give 0/0 = NaN and are reset to 0.5.
        p_tmp = (p_new * T_new + p_empirical * T_prev) / (T_prev + T_new)
        T_tmp = T_prev + T_new
        p_tmp = p_tmp.at[jnp.isnan(p_tmp)].set(0.5)

        p_empirical = p_tmp
        T_prev = T_tmp
        # print(jnp.sum(T_tmp))

        # --- Random-baseline acquisition ---
        # Same number of new trials as the active-learning arm, placed uniformly at random
        random_draws = np.random.choice(len(X), size=int(jnp.sum(T_new)))
        T_new_random = jnp.array(np.bincount(random_draws, minlength=len(X))).astype(int)
        p_new_random = jnp.array(sample_spikes(p_true, T_new_random))

        p_tmp_random = (p_new_random * T_new_random + p_empirical_random * T_prev_random) / (T_prev_random + T_new_random)
        T_tmp_random = T_prev_random + T_new_random
        p_tmp_random = p_tmp_random.at[jnp.isnan(p_tmp_random)].set(0.5)

        p_empirical_random = p_tmp_random
        T_prev_random = T_tmp_random

        cnt += 1

    # Store this restart's trajectories
    performance_stack.append(performances)
    performance_stack_random.append(performances_random)
    num_samples_stack.append(num_samples)

In [None]:
# RMSE versus number of trials, averaged over restarts, for both sampling strategies.
# Error bars show one standard deviation across restarts.
plt.figure(figsize=(10, 8))
trials_axis = np.mean(np.array(num_samples_stack), axis=0)
for series_stack, series_label, series_color in (
    (performance_stack, 'Active Learning', 'tab:blue'),
    (performance_stack_random, 'Random Baseline', 'tab:orange'),
):
    series_values = np.array(series_stack)
    plt.errorbar(trials_axis, np.mean(series_values, 0),
                 yerr=np.std(series_values, axis=0), fmt='o', ls='-', linewidth=4, elinewidth=2,
                 label=series_label, c=series_color, alpha=0.5)
# plt.yscale('log')
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)
plt.xlabel('Number of Trials Sampled', fontsize=24)
plt.ylabel(r'RMSE', fontsize=24)
plt.legend(fontsize=20)
# plt.ylim(0.36, 0.4)