# In [None]:
from sklearn.metrics import mean_squared_error
from function import Kernel_sobo, choose_lam_r_quantile
import numpy as np
import scipy.stats as stats
import tqdm
# plot the excess risk for a fixed sample size n and a varying kernel decay exponent alpha
import matplotlib.pyplot as plt
# Seed NumPy's global RNG once: generate_data draws from np.random.rand /
# np.random.normal, so this makes every simulated data set reproducible.
np.random.seed(10)


def scale_kernel(alpha, x, y, n_terms=100):
    """Evaluate the truncated cosine-series ("spike") kernel k_alpha(x, y).

    Computes, element-wise,

        1 + 2 * sum_{k=0}^{n_terms-1} cos(2*pi*k*(x - y)) / (k + 1)**alpha

    so ``x`` and ``y`` may be scalars or NumPy arrays that broadcast
    against each other (e.g. a column against a row to build a Gram matrix).

    Parameters
    ----------
    alpha : float
        Decay exponent of the series weights; larger values make the
        weights 1/(k+1)**alpha shrink faster.
    x, y : float or array_like
        Points at which to evaluate the kernel.
    n_terms : int, optional
        Number of series terms kept (default 100, the previously
        hard-coded value).

    Returns
    -------
    float or numpy.ndarray
        Kernel value(s), with the broadcast shape of ``x - y``.
    """
    # NOTE(review): the k = 0 term is constant (cos 0 = 1) and adds 2 to the
    # leading 1; kept as-is so existing results are reproduced exactly.
    diff = np.subtract(x, y)  # loop-invariant, computed once
    series = sum(
        np.cos(2 * np.pi * k * diff) / ((k + 1) ** alpha) for k in range(n_terms)
    )
    return 2 * series + 1


# specify kernel function
def kernel_spike(alpha, x, y):
    """Build the spike-kernel Gram matrix between two 1-D point sets.

    Parameters
    ----------
    alpha : float
        Decay exponent forwarded to ``scale_kernel``.
    x : array_like of shape (n,)
        First set of points.
    y : array_like of shape (m,)
        Second set of points.

    Returns
    -------
    numpy.ndarray of shape (n, m)
        ``K[i, j] == scale_kernel(alpha, x[i], y[j])``.
    """
    # Broadcasting a column of x against a row of y evaluates all n*m pairs
    # in one vectorised call, instead of n*m separate Python-level series sums.
    x_col = np.asarray(x, dtype=float).reshape(-1, 1)
    y_row = np.asarray(y, dtype=float).reshape(1, -1)
    return np.asarray(scale_kernel(alpha, x_col, y_row), dtype=float)


# underlying regression function used as the ground truth in the experiment
def f_0(x):
    """Return f_0(x) = scale_kernel(3.5, x, 0) * sin(x).

    Ground-truth function for the 1-dimensional KRR example
    (Example S1 in the supplementary material).
    """
    envelope = scale_kernel(3.5, x, 0)
    oscillation = np.sin(x)
    return envelope * oscillation

# Standard deviation of the Gaussian noise drawn in generate_data.
sd = 2
# Quantile level: generate_data shifts the noise so its tau-quantile is 0, and
# the experiment below passes it to choose_lam_r_quantile.
tau = 0.5
# Ground-truth function used both to simulate responses and as the target y_true.
f_true = f_0

def generate_data(n, f, *, noise_sd=None, quantile=None):
    """Simulate ``n`` noisy observations whose conditional ``quantile``-quantile is ``f``.

    Inputs are drawn uniformly on [0, 1) and sorted. Each response is
    ``f(x) + eps - q``, where ``eps ~ N(0, noise_sd**2)`` and ``q`` is the
    ``quantile``-quantile of that normal distribution. The noise term
    ``eps - q`` therefore has its ``quantile``-quantile at 0.

    Parameters
    ----------
    n : int
        Number of samples.
    f : callable
        True function (e.g. ``f_0``); it is called once on the whole
        sorted input array.
    noise_sd : float, optional
        Noise standard deviation. Defaults to the module-level ``sd``,
        which is looked up at call time.
    quantile : float, optional
        Quantile level in (0, 1). Defaults to the module-level ``tau``,
        which is looked up at call time.

    Returns
    -------
    x_train : numpy.ndarray of shape (n,)
        Sorted inputs.
    y_train : numpy.ndarray
        Simulated responses.
    """
    if noise_sd is None:
        noise_sd = sd
    if quantile is None:
        quantile = tau
    # Draw the inputs before the noise so the global RNG stream is consumed in
    # the same order as before (reproducibility under np.random.seed).
    x_train = np.sort(np.random.rand(n))
    signal = f(x_train)
    noise = np.random.normal(0, noise_sd, n)
    shift = stats.norm.ppf(quantile, loc=0, scale=noise_sd)
    y_train = signal + noise - shift
    return x_train, y_train



# Experiment: for a fixed sample size n, vary the kernel decay exponent alpha
# and compare the excess risk of the truncated vs. full fits returned by
# choose_lam_r_quantile, over iter_num simulated data sets per alpha.
alpha_list = [2, 4, 6, 8, 10]
n = 300
iter_num = 50
# mse_list[j, i, 0] / mse_list[j, i, 1]: excess risk of the truncated / full
# fit on replicate j for alpha_list[i].
mse_list = np.zeros([iter_num, len(alpha_list), 2])
mse_mean = np.zeros([len(alpha_list), 2])
# NOTE: despite its name, mse_var stores standard deviations (np.std), not variances.
mse_var = np.zeros([len(alpha_list), 2])

for i, alpha in enumerate(alpha_list):
    for j in tqdm.tqdm(range(iter_num)):
        x_train, y_train = generate_data(n, f_true)
        # generate_data shifts the noise so f_true is the tau-quantile target.
        y_true = f_true(x_train)
        K = kernel_spike(alpha, x_train, x_train)
        mse_list[j, i, 0] = choose_lam_r_quantile(K, y_train, y_true, truncation=True, tau=tau, loss_type="excess_risk")
        mse_list[j, i, 1] = choose_lam_r_quantile(K, y_train, y_true, truncation=False, tau=tau, loss_type="excess_risk")
    # Summarise over replicates; column 0 = truncated, column 1 = full.
    mse_mean[i] = mse_list[:, i, :].mean(axis=0)
    mse_var[i] = mse_list[:, i, :].std(axis=0)
    print("alpha=", alpha, ",truncated mean mse=", format(mse_mean[i, 0], '.3f'), ",full mean mse=", format(mse_mean[i, 1], '.3f'))
    print("alpha=", alpha, ",truncated std mse=", format(mse_var[i, 0], '.3f'), ",full std mse=", format(mse_var[i, 1], '.3f'))
print("hyperparameters is", "n =", n, "alpha_list = ", alpha_list, "iter_num = ", iter_num, "tau = ", tau, "sd = ", sd)

plt.plot(alpha_list, mse_mean[:, 0], label="truncated mean")
plt.plot(alpha_list, mse_mean[:, 1], label="full")
# Shaded bands: mean +/- one standard deviation across the replicates.
plt.fill_between(alpha_list, mse_mean[:, 0] - mse_var[:, 0], mse_mean[:, 0] + mse_var[:, 0], alpha=0.2)
plt.fill_between(alpha_list, mse_mean[:, 1] - mse_var[:, 1], mse_mean[:, 1] + mse_var[:, 1], alpha=0.2)
plt.legend()
plt.xlabel('alpha')
plt.ylabel('excess risk')
plt.title(f'KQR, tau={tau}')  # previously hard-coded "tau=0.5"
plt.show()

