In [None]:
def initialise_gp(kernel, mean, dataset, obs_stddev=1.0e-3):
    """Build a GP posterior from a prior and a Gaussian likelihood.

    Args:
        kernel: GPJax kernel used for the prior covariance.
        mean: GPJax mean function used for the prior.
        dataset: training data; only ``dataset.n`` is read here, to tell the
            likelihood how many datapoints it covers.
        obs_stddev: observation-noise standard deviation passed to the
            Gaussian likelihood. Defaults to ``1.0e-3``, the value that was
            previously hard-coded.

    Returns:
        The posterior formed as ``prior * likelihood``.
    """
    prior = gpx.gps.Prior(kernel=kernel, mean_function=mean)
    # NOTE(review): float64 is only honoured when JAX's x64 mode is enabled;
    # otherwise JAX truncates this array to float32 and emits a warning.
    likelihood = gpx.likelihoods.Gaussian(
        num_datapoints=dataset.n,
        obs_stddev=jnp.array([obs_stddev], dtype=jnp.float64),
    )
    return prior * likelihood

# Imported here so this cell does not depend on a later cell having run first
# (otherwise Restart & Run All fails with NameError on `jit`).
from jax import jit

# Define the GP: zero mean function and RBF kernel, conditioned on the p53 data.
mean = gpx.mean_functions.Zero()
kernel = gpx.kernels.RBF()
posterior = initialise_gp(kernel, mean, p53_gpjax_dataset)

# Negative marginal log-likelihood (negative=True, so lower is better),
# jit-compiled for repeated evaluation.
mll = jit(gpx.objectives.ConjugateMLL(negative=True))

In [None]:
# Optimise the GP hyperparameters with GPJax's SciPy-backed fitter by
# minimising the negative conjugate marginal log-likelihood.
SCIPY_MAX_ITERS = 1000  # upper bound on optimiser iterations; tune as needed

objective = gpx.objectives.ConjugateMLL(negative=True)
opt_posterior, history = gpx.fit_scipy(
    model=posterior,
    train_data=p53_gpjax_dataset,
    objective=objective,
    max_iters=SCIPY_MAX_ITERS,
    verbose=True,  # print optimisation details
)

In [None]:
# Train the GP hyperparameters with a first-order optimiser (Adam), as an
# alternative to the SciPy fit above. Starts again from the untrained `posterior`.
import optax
from tqdm import tqdm
from gpjax.objectives import ConjugateMLL as c_mll
# `jit` is also used by the model-definition cell; keep these imports available.
from jax import jit, value_and_grad

NUM_ITERATIONS = 100
LEARNING_RATE = 0.01
LOG_EVERY = 10  # print the loss every LOG_EVERY iterations

# With negative=True the objective already returns the *negative* marginal
# log-likelihood, so it is minimised directly -- no extra sign flip. It must be
# *called* with (model, data); constructing it with those arguments is wrong.
neg_log_likelihood = c_mll(negative=True)
optimizer = optax.adam(learning_rate=LEARNING_RATE)

# In GPJax the model object itself carries the hyperparameters (there is no
# separate `params` object), so we hand the posterior to GPJax's optax-based
# training loop instead of hand-rolling value_and_grad over an undefined `params`.
opt_posterior_adam, history_adam = gpx.fit(
    model=posterior,
    objective=neg_log_likelihood,
    train_data=p53_gpjax_dataset,
    optim=optimizer,
    num_iters=NUM_ITERATIONS,
)

# history_adam holds one loss value per iteration.
for i in range(0, NUM_ITERATIONS, LOG_EVERY):
    print(f"Iteration {i}: NLL = {float(history_adam[i])}")