In [2]:
import itertools as it
import warnings

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF

warnings.filterwarnings("ignore")
import pandas as pd

In [3]:
Xs = []
Ys = []
spaces = []
for i in range(1, 9):
    Xs.append(np.load(f"initial_data/function_{i}/initial_inputs.npy"))
    Y = np.load(f"initial_data/function_{i}/initial_outputs.npy")
    if i == 2:
        Y = np.where(Y <= 0, 1e-10, Y)
        Y = np.log(Y)
    Ys.append(Y)

# Read the feedback data and stack it with the input data
df = pd.read_csv('feedback_data/605_data.csv')
df = df.drop(columns=['Unnamed: 0', 'timestamp', 'student_id'], axis=1)
for index, col in enumerate(df.columns):
    if 'output' in col:
        y_array = np.array(df[col], dtype=float)
        Ys[index-8] = np.hstack((y_array, Ys[index-8]))
        continue
    series = []
    for i in range(len(df[col])):
        series.append(df[col][i].replace("[", "").replace("]", "").split())
    series_array = np.array(series, dtype=float)
    Xs[index] = np.vstack((series_array, Xs[index]))
    


In [None]:
# Carry out a UCB based predictions for all the functions
dims = [2, 2, 3, 4, 4, 5, 6, 8]
kernel = 1.0 * RBF(1.0)
for i, dim in enumerate(dims):
    alpha = 1e-10
    if i == 1:
        # Since the description for function states that the observations are noisy
        # TODO: is the alpha value correct ????
        alpha = 0.1**2
    gpr = GaussianProcessRegressor(kernel=kernel, alpha=alpha)
    gpr.fit(Xs[i], Ys[i])
    x1 = np.linspace(0.0, 1, 15)
    X_grid = np.fromiter(it.chain(*it.product(x1, repeat=dim)), dtype=float).reshape(-1,dim)
    # print(X_grid)
    mean, std = gpr.predict(X_grid, return_std = True)
    # print(mean)
    ucb = mean + 1.96 * std
    idx_max = np.argmax(ucb)
    next_query = X_grid[idx_max]
    next_qery_string = "-".join([f"{x:.6f}" for x in next_query])
    print(f"The next query for the function {i+1} is {next_qery_string}")
    

The next query for the function 1 is 0.642857-0.642857
The next query for the function 2 is 0.000000-0.000000
The next query for the function 3 is 1.000000-0.000000-0.000000
The next query for the function 4 is 0.000000-0.000000-0.000000-0.071429
The next query for the function 5 is 1.000000-1.000000-1.000000-1.000000
The next query for the function 6 is 0.000000-0.000000-0.000000-0.000000-0.071429
The next query for the function 7 is 0.000000-0.000000-0.000000-0.000000-0.000000-0.000000
