In [1]:
import numpy as np
import torch
from torch.distributions.normal import Normal
import matplotlib.pyplot as plt
from data_grunwald import get_data, get_X_show


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seed NumPy's global RNG before any data is drawn
# (presumably get_data samples from it -- confirm in data_grunwald).
seed = 0
np.random.seed(seed)

# Experiment configuration.
y = 0.0
d_x = 101  # number of fourier features
n_train = 100
n_test = 10000
if_misspecified = False

# The train and test splits use identical generator settings and differ only
# in the number of points requested; train is drawn first, then test.
_shared_data_kwargs = dict(y=y, d_x=d_x, if_misspecified=if_misspecified)
X_train, Y_train, X_train_orig = get_data(n_data=n_train, **_shared_data_kwargs)
X_test, Y_test, X_test_orig = get_data(n_data=n_test, **_shared_data_kwargs)


In [3]:
# Sanity check: display the shape of the test feature matrix; with the settings
# above it is expected to be (n_test, d_x).
X_test.shape


(10000, 101)

In [8]:
def get_log_risks_lambs(var_prior):
    """Evaluate the test log-risk of Bayesian linear regression over a temperature grid.

    For each temperature ``lamb`` in ``np.linspace(0.1, 3.0, 20)`` the Gaussian
    likelihood precision is scaled to ``beta = 40.0 * lamb``, the closed-form
    posterior over the weights is computed from the training data, and the
    mean negative log-density of the test targets under the resulting
    posterior predictive distribution is recorded.

    The function reads the notebook-level globals ``X_train``, ``Y_train``,
    ``X_test``, ``Y_test`` and ``d_x``; they must be defined before calling.
    NOTE(review): ``Y_test`` is assumed to broadcast element-wise against the
    per-point predictive mean/std -- confirm its shape against ``get_data``.

    Args:
        var_prior: variance of the isotropic zero-mean Gaussian prior on the
            weights (its inverse is the prior precision ``alpha``).

    Returns:
        list of float: one log-risk value per temperature, in grid order.
    """
    log_risks_lambs = []
    lambs = np.linspace(0.1, 3.0, 20)

    # settings that do not depend on the temperature
    alpha = 1 / var_prior  # prior precision
    beta_orig = 40.0  # gaussian likelihood precision before absorbing temperature

    for lamb in lambs:
        # progress output: current temperature followed by a blank line
        print(lamb)
        print()

        # gaussian likelihood precision after absorbing temperature
        beta = beta_orig * lamb

        # compute posterior dist., see bishop eq.3.53 and 3.54
        S_N_inv = (
            alpha * np.identity(d_x) + beta * X_train.T @ X_train
        )  # posterior precision
        S_N = np.linalg.inv(S_N_inv)  # posterior covariance
        m_N = beta * S_N @ X_train.T @ Y_train  # posterior mean

        # compute posterior predictive dist., see bishop eq.3.58 and 3.59
        post_pred_mean = X_test @ m_N
        # Only the per-point predictive variances are needed, i.e. the diagonal
        # of I / beta + X_test @ S_N @ X_test.T. Computing that diagonal
        # directly avoids materialising an (n_test x n_test) matrix.
        post_pred_var = 1.0 / beta + np.sum((X_test @ S_N) * X_test, axis=1)
        post_pred_std = np.sqrt(post_pred_var)
        post_pred = Normal(torch.tensor(post_pred_mean), torch.tensor(post_pred_std))

        # get log-risk: average negative predictive log-density on the test set
        log_risk = -post_pred.log_prob(torch.tensor(Y_test)).mean().item()
        log_risks_lambs.append(log_risk)

    return log_risks_lambs


In [9]:
# Run the temperature sweep twice: first with a broad prior (variance 100),
# then with a narrow prior (variance 0.01). The generator is consumed in
# order, so the calls execute in the same sequence as they are listed.
log_risks_lambs_100, log_risks_lambs_001 = (
    get_log_risks_lambs(prior_variance) for prior_variance in (100, 0.01)
)


0.1

0.25263157894736843

0.4052631578947369

0.5578947368421052

0.7105263157894737

0.8631578947368421

1.0157894736842106

1.168421052631579

1.3210526315789475

1.473684210526316

1.6263157894736844

1.7789473684210528

1.931578947368421

2.0842105263157893

2.236842105263158

2.3894736842105266

2.542105263157895

2.694736842105263

2.8473684210526318

3.0

0.1

0.25263157894736843

0.4052631578947369

0.5578947368421052

0.7105263157894737

0.8631578947368421

1.0157894736842106

1.168421052631579

1.3210526315789475

1.473684210526316

1.6263157894736844

1.7789473684210528

1.931578947368421

2.0842105263157893

2.236842105263158

2.3894736842105266

2.542105263157895

2.694736842105263

2.8473684210526318

3.0



In [1]:
# Plot the test log-risk against the temperature for both prior variances.
# Requires the import cell and the sweep cell to have been run first
# (uses plt, np, log_risks_lambs_001 and log_risks_lambs_100).
fig, ax = plt.subplots(figsize=(6, 5))
ax.grid()

# Must match the temperature grid used inside get_log_risks_lambs.
lambs = np.linspace(0.1, 3.0, 20)

ax.plot(lambs, log_risks_lambs_001, "-*", label="var_prior=0.01")
# "-." is the dash-dot line style; the previous "_." combined two marker
# symbols, which matplotlib rejects as an illegal format string.
ax.plot(lambs, log_risks_lambs_100, "-.", label="var_prior=100.0")

ax.set_xlabel("lambda")
ax.set_ylabel("true log-risk")
ax.set_title("grunwald, N=100, d_x=101, var_likelihood=0.025")
ax.legend()

fig.savefig("grunwald_result.jpg", bbox_inches="tight")
plt.show()


NameError: name 'plt' is not defined